Drug Molecule Generation with VAE

<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/generative/molecule_generation/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Drug Molecule Generation with VAE"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Drug Molecule Generation with VAE"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Drug Molecule Generation with VAE</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); 
</script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink active" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink2" href="/examples/generative/ddim/">Denoising Diffusion Implicit Models</a> <a class="nav-sublink2" href="/examples/generative/random_walks_with_stable_diffusion/">A walk through latent space with Stable Diffusion</a> <a class="nav-sublink2" href="/examples/generative/dreambooth/">DreamBooth</a> <a class="nav-sublink2" href="/examples/generative/ddpm/">Denoising Diffusion Probabilistic Models</a> <a class="nav-sublink2" 
href="/examples/generative/fine_tune_via_textual_inversion/">Teach StableDiffusion new concepts via Textual Inversion</a> <a class="nav-sublink2" href="/examples/generative/finetune_stable_diffusion/">Fine-tuning Stable Diffusion</a> <a class="nav-sublink2" href="/examples/generative/vae/">Variational AutoEncoder</a> <a class="nav-sublink2" href="/examples/generative/dcgan_overriding_train_step/">GAN overriding Model.train_step</a> <a class="nav-sublink2" href="/examples/generative/wgan_gp/">WGAN-GP overriding Model.train_step</a> <a class="nav-sublink2" href="/examples/generative/conditional_gan/">Conditional GAN</a> <a class="nav-sublink2" href="/examples/generative/cyclegan/">CycleGAN</a> <a class="nav-sublink2" href="/examples/generative/gan_ada/">Data-efficient GANs with Adaptive Discriminator Augmentation</a> <a class="nav-sublink2" href="/examples/generative/deep_dream/">Deep Dream</a> <a class="nav-sublink2" href="/examples/generative/gaugan/">GauGAN for conditional image generation</a> <a class="nav-sublink2" href="/examples/generative/pixelcnn/">PixelCNN</a> <a class="nav-sublink2" href="/examples/generative/stylegan/">Face image generation with StyleGAN</a> <a class="nav-sublink2" href="/examples/generative/vq_vae/">Vector-Quantized Variational Autoencoders</a> <a class="nav-sublink2" href="/examples/generative/neural_style_transfer/">Neural style transfer</a> <a class="nav-sublink2" href="/examples/generative/adain/">Neural Style Transfer with AdaIN</a> <a class="nav-sublink2" href="/examples/generative/gpt2_text_generation_with_keras_hub/">GPT2 Text Generation with KerasHub</a> <a class="nav-sublink2" href="/examples/generative/text_generation_gpt/">GPT text generation from scratch with KerasHub</a> <a class="nav-sublink2" href="/examples/generative/text_generation_with_miniature_gpt/">Text generation with a miniature GPT</a> <a class="nav-sublink2" href="/examples/generative/lstm_character_level_text_generation/">Character-level text generation with 
LSTM</a> <a class="nav-sublink2" href="/examples/generative/text_generation_fnet/">Text Generation using FNet</a> <a class="nav-sublink2 active" href="/examples/generative/molecule_generation/">Drug Molecule Generation with VAE</a> <a class="nav-sublink2" href="/examples/generative/wgan-graphs/">WGAN-GP with R-GCN for the generation of small molecular graphs</a> <a class="nav-sublink2" href="/examples/generative/random_walks_with_stable_diffusion_3/">A walk through latent space with Stable Diffusion 3</a> <a class="nav-sublink2" href="/examples/generative/real_nvp/">Density estimation using Real NVP</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparameter Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-link" href="/keras_cv/" role="tab" aria-selected="">KerasCV: Computer Vision Workflows</a> <a class="nav-link" href="/keras_nlp/" role="tab" aria-selected="">KerasNLP: Natural Language Workflows</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = 
navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> 
<span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/generative/'>Generative Deep Learning</a> / Drug Molecule Generation with VAE </div> <div class='k-content'> <h1 id="drug-molecule-generation-with-vae">Drug Molecule Generation with VAE</h1> <p><strong>Author:</strong> <a href="https://www.linkedin.com/in/victor-basu-520958147">Victor Basu</a><br> <strong>Date created:</strong> 2022/03/10<br> <strong>Last modified:</strong> 2022/03/24<br> <strong>Description:</strong> Implementing a Convolutional Variational AutoEncoder (VAE) for Drug Discovery.</p> <div class='example_version_banner keras_2'>ⓘ This example uses Keras 2</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/generative/ipynb/molecule_generation.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/generative/molecule_generation.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>In this example, we use a Variational Autoencoder to generate molecules for drug discovery. We use the research papers <a href="https://arxiv.org/abs/1610.02415">Automatic chemical design using a data-driven continuous representation of molecules</a> and <a href="https://arxiv.org/abs/1805.11973">MolGAN: An implicit generative model for small molecular graphs</a> as references.</p> <p>The model described in the paper <strong>Automatic chemical design using a data-driven continuous representation of molecules</strong> generates new molecules via efficient exploration of open-ended spaces of chemical compounds. The model consists of three components: Encoder, Decoder and Predictor. 
The Encoder converts the discrete representation of a molecule into a real-valued continuous vector, and the Decoder converts these continuous vectors back to discrete molecule representations. The Predictor estimates chemical properties from the latent continuous vector representation of the molecule. Continuous representations allow the use of gradient-based optimization to efficiently guide the search for optimized functional compounds.</p> <p><img alt="Diagram of the molecule autoencoder with property prediction and latent-space optimization" src="https://bit.ly/3CtPMzM" /></p> <p><strong>Figure (a)</strong> - A diagram of the autoencoder used for molecule design, including the joint property prediction model. Starting from a discrete molecule representation, such as a SMILES string, the encoder network converts each molecule into a vector in the latent space, which is effectively a continuous molecule representation. Given a point in the latent space, the decoder network produces a corresponding SMILES string. A multilayer perceptron network estimates the value of target properties associated with each molecule.</p> <p><strong>Figure (b)</strong> - Gradient-based optimization in continuous latent space. After training a surrogate model <code>f(z)</code> to predict the properties of molecules based on their latent representation <code>z</code>, we can optimize <code>f(z)</code> with respect to <code>z</code> to find new latent representations expected to match specific desired properties. These new latent representations can then be decoded into SMILES strings, at which point their properties can be tested empirically.</p> <p>For an explanation and implementation of MolGAN, please refer to the Keras example <a href="https://bit.ly/3pU6zXK"><strong>WGAN-GP with R-GCN for the generation of small molecular graphs</strong></a> by Alexander Kensert. Many of the functions used in the present example are taken from that example.</p> <hr /> <h2 id="setup">Setup</h2> <p>RDKit is an open-source toolkit for cheminformatics and machine learning. 
It is particularly handy in the drug discovery domain. In this example, RDKit is used to conveniently and efficiently convert SMILES strings into molecule objects, and then to extract sets of atoms and bonds from those objects.</p> <p>Quoting from <a href="https://keras.io/examples/generative/wgan-graphs/">WGAN-GP with R-GCN for the generation of small molecular graphs</a>:</p> <p><strong>"SMILES expresses the structure of a given molecule in the form of an ASCII string. The SMILES string is a compact encoding which, for smaller molecules, is relatively human-readable. Encoding molecules as a string both alleviates and facilitates database and/or web searching of a given molecule. RDKit uses algorithms to accurately transform a given SMILES to a molecule object, which can then be used to compute a great number of molecular properties/features."</strong></p> <div class="codehilite"><pre><span></span><code><span class="err">!</span><span class="n">pip</span> <span class="o">-</span><span class="n">q</span> <span class="n">install</span> <span class="n">rdkit</span><span class="o">-</span><span class="n">pypi</span><span class="o">==</span><span class="mf">2021.9.4</span> </code></pre></div> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">ast</span> <span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span> <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span> <span class="kn">from</span> <span class="nn">tensorflow</span> <span class="kn">import</span> <span class="n">keras</span> <span class="kn">from</span> <span class="nn">tensorflow.keras</span> <span class="kn">import</span> <span class="n">layers</span> <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span 
class="k">as</span> <span class="nn">plt</span> <span class="kn">from</span> <span class="nn">rdkit</span> <span class="kn">import</span> <span class="n">Chem</span><span class="p">,</span> <span class="n">RDLogger</span> <span class="kn">from</span> <span class="nn">rdkit.Chem</span> <span class="kn">import</span> <span class="n">BondType</span> <span class="kn">from</span> <span class="nn">rdkit.Chem.Draw</span> <span class="kn">import</span> <span class="n">MolsToGridImage</span> <span class="n">RDLogger</span><span class="o">.</span><span class="n">DisableLog</span><span class="p">(</span><span class="s2">&quot;rdApp.*&quot;</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="dataset">Dataset</h2> <p>We use the <a href="https://bit.ly/3IVBI4x"><strong>ZINC – A Free Database of Commercially Available Compounds for Virtual Screening</strong></a> dataset. 
The dataset provides molecules as SMILES strings, along with their respective molecular properties such as <strong>logP</strong> (octanol–water partition coefficient), <strong>SAS</strong> (synthetic accessibility score) and <strong>QED</strong> (Quantitative Estimate of Drug-likeness).</p> <div class="codehilite"><pre><span></span><code><span class="n">csv_path</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">get_file</span><span class="p">(</span> <span class="s2">&quot;250k_rndm_zinc_drugs_clean_3.csv&quot;</span><span class="p">,</span> <span class="s2">&quot;https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv&quot;</span><span class="p">,</span> <span class="p">)</span> <span class="n">df</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_csv</span><span class="p">(</span><span class="n">csv_path</span><span class="p">)</span> <span class="n">df</span><span class="p">[</span><span class="s2">&quot;smiles&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">df</span><span class="p">[</span><span class="s2">&quot;smiles&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="k">lambda</span> <span class="n">s</span><span class="p">:</span> <span class="n">s</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">&quot;</span><span class="p">,</span> <span class="s2">&quot;&quot;</span><span class="p">))</span> <span class="n">df</span><span class="o">.</span><span class="n">head</span><span class="p">()</span> </code></pre></div> <div class="k-default-codeblock"> <div 
class="codehilite"><pre><span></span><code>Downloading data from https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv 22606589/22606589 [==============================] - 0s 0us/step </code></pre></div> </div> <div> <style scoped> .dataframe tbody tr th:only-of-type { vertical-align: middle; } .dataframe tbody tr th { vertical-align: top; } .dataframe thead th { text-align: right; } </style> <table border="1" class="dataframe"> <thead> <tr style="text-align: right;"> <th></th> <th>smiles</th> <th>logP</th> <th>qed</th> <th>SAS</th> </tr> </thead> <tbody> <tr> <th>0</th> <td>CC(C)(C)c1ccc2occ(CC(=O)Nc3ccccc3F)c2c1</td> <td>5.05060</td> <td>0.702012</td> <td>2.084095</td> </tr> <tr> <th>1</th> <td>C[C@@H]1CC(Nc2cncc(-c3nncn3C)c2)C[C@@H](C)C1</td> <td>3.11370</td> <td>0.928975</td> <td>3.432004</td> </tr> <tr> <th>2</th> <td>N#Cc1ccc(-c2ccc(O[C@@H](C(=O)N3CCCC3)c3ccccc3)...</td> <td>4.96778</td> <td>0.599682</td> <td>2.470633</td> </tr> <tr> <th>3</th> <td>CCOC(=O)[C@@H]1CCCN(C(=O)c2nc(-c3ccc(C)cc3)n3c...</td> <td>4.00022</td> <td>0.690944</td> <td>2.822753</td> </tr> <tr> <th>4</th> <td>N#CC1=C(SCC(=O)Nc2cccc(Cl)c2)N=C([O-])[C@H](C#...</td> <td>3.60956</td> <td>0.789027</td> <td>4.035182</td> </tr> </tbody> </table> </div> <hr /> <h2 id="hyperparameters">Hyperparameters</h2> <div class="codehilite"><pre><span></span><code><span class="n">SMILE_CHARSET</span> <span class="o">=</span> <span class="s1">&#39;[&quot;C&quot;, &quot;B&quot;, &quot;F&quot;, &quot;I&quot;, &quot;H&quot;, &quot;O&quot;, &quot;N&quot;, &quot;S&quot;, &quot;P&quot;, &quot;Cl&quot;, &quot;Br&quot;]&#39;</span> <span class="n">bond_mapping</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;SINGLE&quot;</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span> <span 
class="s2">&quot;DOUBLE&quot;</span><span class="p">:</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">&quot;TRIPLE&quot;</span><span class="p">:</span> <span class="mi">2</span><span class="p">,</span> <span class="s2">&quot;AROMATIC&quot;</span><span class="p">:</span> <span class="mi">3</span><span class="p">}</span> <span class="n">bond_mapping</span><span class="o">.</span><span class="n">update</span><span class="p">(</span> <span class="p">{</span><span class="mi">0</span><span class="p">:</span> <span class="n">BondType</span><span class="o">.</span><span class="n">SINGLE</span><span class="p">,</span> <span class="mi">1</span><span class="p">:</span> <span class="n">BondType</span><span class="o">.</span><span class="n">DOUBLE</span><span class="p">,</span> <span class="mi">2</span><span class="p">:</span> <span class="n">BondType</span><span class="o">.</span><span class="n">TRIPLE</span><span class="p">,</span> <span class="mi">3</span><span class="p">:</span> <span class="n">BondType</span><span class="o">.</span><span class="n">AROMATIC</span><span class="p">}</span> <span class="p">)</span> <span class="n">SMILE_CHARSET</span> <span class="o">=</span> <span class="n">ast</span><span class="o">.</span><span class="n">literal_eval</span><span class="p">(</span><span class="n">SMILE_CHARSET</span><span class="p">)</span> <span class="n">MAX_MOLSIZE</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">df</span><span class="p">[</span><span class="s2">&quot;smiles&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">str</span><span class="o">.</span><span class="n">len</span><span class="p">())</span> <span class="n">SMILE_to_index</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">((</span><span class="n">c</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span 
class="n">i</span><span class="p">,</span> <span class="n">c</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">SMILE_CHARSET</span><span class="p">))</span> <span class="n">index_to_SMILE</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">((</span><span class="n">i</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">c</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">SMILE_CHARSET</span><span class="p">))</span> <span class="n">atom_mapping</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span><span class="n">SMILE_to_index</span><span class="p">)</span> <span class="n">atom_mapping</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">index_to_SMILE</span><span class="p">)</span> <span class="n">BATCH_SIZE</span> <span class="o">=</span> <span class="mi">100</span> <span class="n">EPOCHS</span> <span class="o">=</span> <span class="mi">10</span> <span class="n">VAE_LR</span> <span class="o">=</span> <span class="mf">5e-4</span> <span class="n">NUM_ATOMS</span> <span class="o">=</span> <span class="mi">120</span> <span class="c1"># Maximum number of atoms</span> <span class="n">ATOM_DIM</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">SMILE_CHARSET</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span> <span class="c1"># Number of atom types, plus one &quot;non-atom&quot; channel</span> <span class="n">BOND_DIM</span> <span class="o">=</span> <span class="mi">4</span> <span class="o">+</span> <span class="mi">1</span> <span class="c1"># Number of bond types, plus one &quot;non-bond&quot; channel</span> <span class="n">LATENT_DIM</span> <span class="o">=</span> <span class="mi">435</span> <span class="c1"># Size of the latent space</span> <span class="k">def</span> <span 
class="nf">smiles_to_graph</span><span class="p">(</span><span class="n">smiles</span><span class="p">):</span> <span class="c1"># Converts SMILES to molecule object</span> <span class="n">molecule</span> <span class="o">=</span> <span class="n">Chem</span><span class="o">.</span><span class="n">MolFromSmiles</span><span class="p">(</span><span class="n">smiles</span><span class="p">)</span> <span class="c1"># Initialize adjacency and feature tensor</span> <span class="n">adjacency</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">BOND_DIM</span><span class="p">,</span> <span class="n">NUM_ATOMS</span><span class="p">,</span> <span class="n">NUM_ATOMS</span><span class="p">),</span> <span class="s2">&quot;float32&quot;</span><span class="p">)</span> <span class="n">features</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">NUM_ATOMS</span><span class="p">,</span> <span class="n">ATOM_DIM</span><span class="p">),</span> <span class="s2">&quot;float32&quot;</span><span class="p">)</span> <span class="c1"># loop over each atom in molecule</span> <span class="k">for</span> <span class="n">atom</span> <span class="ow">in</span> <span class="n">molecule</span><span class="o">.</span><span class="n">GetAtoms</span><span class="p">():</span> <span class="n">i</span> <span class="o">=</span> <span class="n">atom</span><span class="o">.</span><span class="n">GetIdx</span><span class="p">()</span> <span class="n">atom_type</span> <span class="o">=</span> <span class="n">atom_mapping</span><span class="p">[</span><span class="n">atom</span><span class="o">.</span><span class="n">GetSymbol</span><span class="p">()]</span> <span class="n">features</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span 
class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">ATOM_DIM</span><span class="p">)[</span><span class="n">atom_type</span><span class="p">]</span> <span class="c1"># loop over one-hop neighbors</span> <span class="k">for</span> <span class="n">neighbor</span> <span class="ow">in</span> <span class="n">atom</span><span class="o">.</span><span class="n">GetNeighbors</span><span class="p">():</span> <span class="n">j</span> <span class="o">=</span> <span class="n">neighbor</span><span class="o">.</span><span class="n">GetIdx</span><span class="p">()</span> <span class="n">bond</span> <span class="o">=</span> <span class="n">molecule</span><span class="o">.</span><span class="n">GetBondBetweenAtoms</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">)</span> <span class="n">bond_type_idx</span> <span class="o">=</span> <span class="n">bond_mapping</span><span class="p">[</span><span class="n">bond</span><span class="o">.</span><span class="n">GetBondType</span><span class="p">()</span><span class="o">.</span><span class="n">name</span><span class="p">]</span> <span class="n">adjacency</span><span class="p">[</span><span class="n">bond_type_idx</span><span class="p">,</span> <span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">],</span> <span class="p">[</span><span class="n">j</span><span class="p">,</span> <span class="n">i</span><span class="p">]]</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># Where no bond, add 1 to last channel (indicating &quot;non-bond&quot;)</span> <span class="c1"># Notice: channels-first</span> <span class="n">adjacency</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">adjacency</span><span 
class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># Where no atom, add 1 to last column (indicating &quot;non-atom&quot;)</span> <span class="n">features</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">features</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">return</span> <span class="n">adjacency</span><span class="p">,</span> <span class="n">features</span> <span class="k">def</span> <span class="nf">graph_to_molecule</span><span class="p">(</span><span class="n">graph</span><span class="p">):</span> <span class="c1"># Unpack graph</span> <span class="n">adjacency</span><span class="p">,</span> <span class="n">features</span> <span class="o">=</span> <span class="n">graph</span> <span class="c1"># RWMol is a molecule object intended to be edited</span> <span class="n">molecule</span> <span class="o">=</span> <span class="n">Chem</span><span class="o">.</span><span class="n">RWMol</span><span class="p">()</span> <span class="c1"># Remove &quot;no atoms&quot; &amp; atoms with no bonds</span> <span class="n">keep_idx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span> <span class="p">(</span><span class="n">np</span><span 
class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">features</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="o">!=</span> <span class="n">ATOM_DIM</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">adjacency</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">)</span> <span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="n">features</span> <span class="o">=</span> <span class="n">features</span><span class="p">[</span><span class="n">keep_idx</span><span class="p">]</span> <span class="n">adjacency</span> <span class="o">=</span> <span class="n">adjacency</span><span class="p">[:,</span> <span class="n">keep_idx</span><span class="p">,</span> <span class="p">:][:,</span> <span class="p">:,</span> <span class="n">keep_idx</span><span class="p">]</span> <span class="c1"># Add atoms to molecule</span> <span class="k">for</span> <span class="n">atom_type_idx</span> <span class="ow">in</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">features</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span> <span class="n">atom</span> <span class="o">=</span> <span class="n">Chem</span><span class="o">.</span><span class="n">Atom</span><span class="p">(</span><span 
class="n">atom_mapping</span><span class="p">[</span><span class="n">atom_type_idx</span><span class="p">])</span> <span class="n">_</span> <span class="o">=</span> <span class="n">molecule</span><span class="o">.</span><span class="n">AddAtom</span><span class="p">(</span><span class="n">atom</span><span class="p">)</span> <span class="c1"># Add bonds between atoms in molecule; based on the upper triangles</span> <span class="c1"># of the [symmetric] adjacency tensor</span> <span class="p">(</span><span class="n">bonds_ij</span><span class="p">,</span> <span class="n">atoms_i</span><span class="p">,</span> <span class="n">atoms_j</span><span class="p">)</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">triu</span><span class="p">(</span><span class="n">adjacency</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="p">(</span><span class="n">bond_ij</span><span class="p">,</span> <span class="n">atom_i</span><span class="p">,</span> <span class="n">atom_j</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">bonds_ij</span><span class="p">,</span> <span class="n">atoms_i</span><span class="p">,</span> <span class="n">atoms_j</span><span class="p">):</span> <span class="k">if</span> <span class="n">atom_i</span> <span class="o">==</span> <span class="n">atom_j</span> <span class="ow">or</span> <span class="n">bond_ij</span> <span class="o">==</span> <span class="n">BOND_DIM</span> <span class="o">-</span> <span class="mi">1</span><span class="p">:</span> <span class="k">continue</span> <span class="n">bond_type</span> <span class="o">=</span> <span class="n">bond_mapping</span><span class="p">[</span><span class="n">bond_ij</span><span class="p">]</span> 
<span class="n">molecule</span><span class="o">.</span><span class="n">AddBond</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">atom_i</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">atom_j</span><span class="p">),</span> <span class="n">bond_type</span><span class="p">)</span> <span class="c1"># Sanitize the molecule; for more information on sanitization, see</span> <span class="c1"># https://www.rdkit.org/docs/RDKit_Book.html#molecular-sanitization</span> <span class="n">flag</span> <span class="o">=</span> <span class="n">Chem</span><span class="o">.</span><span class="n">SanitizeMol</span><span class="p">(</span><span class="n">molecule</span><span class="p">,</span> <span class="n">catchErrors</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="c1"># Let&#39;s be strict. If sanitization fails, return None</span> <span class="k">if</span> <span class="n">flag</span> <span class="o">!=</span> <span class="n">Chem</span><span class="o">.</span><span class="n">SanitizeFlags</span><span class="o">.</span><span class="n">SANITIZE_NONE</span><span class="p">:</span> <span class="k">return</span> <span class="kc">None</span> <span class="k">return</span> <span class="n">molecule</span> </code></pre></div> <hr /> <h2 id="generate-training-set">Generate training set</h2> <div class="codehilite"><pre><span></span><code><span class="n">train_df</span> <span class="o">=</span> <span class="n">df</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">frac</span><span class="o">=</span><span class="mf">0.75</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">)</span> <span class="c1"># random state is a seed value</span> <span class="n">train_df</span><span class="o">.</span><span 
class="n">reset_index</span><span class="p">(</span><span class="n">drop</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">adjacency_tensor</span><span class="p">,</span> <span class="n">feature_tensor</span><span class="p">,</span> <span class="n">qed_tensor</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[],</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">8000</span><span class="p">):</span> <span class="n">adjacency</span><span class="p">,</span> <span class="n">features</span> <span class="o">=</span> <span class="n">smiles_to_graph</span><span class="p">(</span><span class="n">train_df</span><span class="o">.</span><span class="n">loc</span><span class="p">[</span><span class="n">idx</span><span class="p">][</span><span class="s2">&quot;smiles&quot;</span><span class="p">])</span> <span class="n">qed</span> <span class="o">=</span> <span class="n">train_df</span><span class="o">.</span><span class="n">loc</span><span class="p">[</span><span class="n">idx</span><span class="p">][</span><span class="s2">&quot;qed&quot;</span><span class="p">]</span> <span class="n">adjacency_tensor</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">adjacency</span><span class="p">)</span> <span class="n">feature_tensor</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">features</span><span class="p">)</span> <span class="n">qed_tensor</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">qed</span><span class="p">)</span> <span class="n">adjacency_tensor</span> <span class="o">=</span> <span class="n">np</span><span 
class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">adjacency_tensor</span><span class="p">)</span> <span class="n">feature_tensor</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">feature_tensor</span><span class="p">)</span> <span class="n">qed_tensor</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">qed_tensor</span><span class="p">)</span> <span class="k">class</span> <span class="nc">RelationalGraphConvLayer</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">units</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">kernel_initializer</span><span class="o">=</span><span class="s2">&quot;glorot_uniform&quot;</span><span class="p">,</span> <span class="n">bias_initializer</span><span class="o">=</span><span class="s2">&quot;zeros&quot;</span><span class="p">,</span> <span class="n">kernel_regularizer</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">bias_regularizer</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span> <span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span 
class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span> <span class="o">=</span> <span class="n">units</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">activations</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">activation</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_bias</span> <span class="o">=</span> <span class="n">use_bias</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernel_initializer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">kernel_initializer</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias_initializer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">bias_initializer</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernel_regularizer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">regularizers</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">kernel_regularizer</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias_regularizer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">regularizers</span><span 
class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">bias_regularizer</span><span class="p">)</span> <span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="n">bond_dim</span> <span class="o">=</span> <span class="n">input_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span> <span class="n">atom_dim</span> <span class="o">=</span> <span class="n">input_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">][</span><span class="mi">2</span><span class="p">]</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernel</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">bond_dim</span><span class="p">,</span> <span class="n">atom_dim</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span><span class="p">),</span> <span class="n">initializer</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">kernel_initializer</span><span class="p">,</span> <span class="n">regularizer</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">kernel_regularizer</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;W&quot;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> 
<span class="p">)</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_bias</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">bond_dim</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span><span class="p">),</span> <span class="n">initializer</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">bias_initializer</span><span class="p">,</span> <span class="n">regularizer</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">bias_regularizer</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;b&quot;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">built</span> <span class="o">=</span> <span class="kc">True</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span> <span class="n">adjacency</span><span class="p">,</span> <span class="n">features</span> <span class="o">=</span> <span class="n">inputs</span> <span 
class="c1"># Aggregate information from neighbors</span> <span class="n">x</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">adjacency</span><span class="p">,</span> <span class="n">features</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:])</span> <span class="c1"># Apply linear transformation</span> <span class="n">x</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernel</span><span class="p">)</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_bias</span><span class="p">:</span> <span class="n">x</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="c1"># Reduce bond types dim</span> <span class="n">x_reduced</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># Apply non-linear transformation</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="n">x_reduced</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="build-the-encoder-and-decoder">Build the Encoder and Decoder</h2> <p>The Encoder takes as input a molecule's graph adjacency matrix and feature matrix. 
These features are processed by one or more graph convolution layers, averaged over the atom dimension (global average pooling) into a single vector, and then passed through several Dense layers to derive <code>z_mean</code> and <code>log_var</code>, the mean and log-variance of the molecule's latent-space representation.</p> <p><strong>Graph Convolution layer</strong>: The relational graph convolution layer implements non-linearly transformed neighbourhood aggregations. We can define these layers as follows:</p> <p><code>H_hat**(l+1) = σ(D_hat**(-1) * A_hat * H_hat**(l) * W**(l))</code></p> <p>Where <code>σ</code> denotes the non-linear transformation (commonly a ReLU activation), <code>A_hat</code> the adjacency tensor, <code>H_hat**(l)</code> the feature tensor at the <code>l</code>-th layer, <code>D_hat**(-1)</code> the inverse of the diagonal degree tensor of <code>A_hat</code>, and <code>W**(l)</code> the trainable weight tensor at the <code>l</code>-th layer. Specifically, for each bond type (relation), the degree tensor holds on its diagonal the number of bonds of that type attached to each atom. Note that the <code>RelationalGraphConvLayer</code> above omits the <code>D_hat**(-1)</code> normalization and sums the per-bond-type results before applying <code>σ</code>.</p> <p>Source: <a href="https://keras.io/examples/generative/wgan-graphs/">WGAN-GP with R-GCN for the generation of small molecular graphs</a></p> <p>The Decoder takes as input the latent-space representation and predicts the graph adjacency matrix and feature matrix of the corresponding molecules.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">get_encoder</span><span class="p">(</span> <span class="n">gconv_units</span><span class="p">,</span> <span class="n">latent_dim</span><span class="p">,</span> <span class="n">adjacency_shape</span><span class="p">,</span> <span class="n">feature_shape</span><span class="p">,</span> <span class="n">dense_units</span><span class="p">,</span> <span class="n">dropout_rate</span> <span class="p">):</span> <span class="n">adjacency</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span
class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="n">adjacency_shape</span><span class="p">)</span> <span class="n">features</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="n">feature_shape</span><span class="p">)</span> <span class="c1"># Propagate through one or more graph convolutional layers</span> <span class="n">features_transformed</span> <span class="o">=</span> <span class="n">features</span> <span class="k">for</span> <span class="n">units</span> <span class="ow">in</span> <span class="n">gconv_units</span><span class="p">:</span> <span class="n">features_transformed</span> <span class="o">=</span> <span class="n">RelationalGraphConvLayer</span><span class="p">(</span><span class="n">units</span><span class="p">)(</span> <span class="p">[</span><span class="n">adjacency</span><span class="p">,</span> <span class="n">features_transformed</span><span class="p">]</span> <span class="p">)</span> <span class="c1"># Reduce 2-D representation of molecule to 1-D</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">GlobalAveragePooling1D</span><span class="p">()(</span><span class="n">features_transformed</span><span class="p">)</span> <span class="c1"># Propagate through one or more densely connected layers</span> <span class="k">for</span> <span class="n">units</span> <span class="ow">in</span> <span class="n">dense_units</span><span class="p">:</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="p">,</span> <span 
class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_rate</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">z_mean</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">latent_dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&quot;float32&quot;</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;z_mean&quot;</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">log_var</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">latent_dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&quot;float32&quot;</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;log_var&quot;</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">encoder</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">([</span><span class="n">adjacency</span><span class="p">,</span> <span class="n">features</span><span class="p">],</span> <span class="p">[</span><span class="n">z_mean</span><span class="p">,</span> <span class="n">log_var</span><span class="p">],</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;encoder&quot;</span><span class="p">)</span> <span 
class="k">return</span> <span class="n">encoder</span> <span class="k">def</span> <span class="nf">get_decoder</span><span class="p">(</span><span class="n">dense_units</span><span class="p">,</span> <span class="n">dropout_rate</span><span class="p">,</span> <span class="n">latent_dim</span><span class="p">,</span> <span class="n">adjacency_shape</span><span class="p">,</span> <span class="n">feature_shape</span><span class="p">):</span> <span class="n">latent_inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">latent_dim</span><span class="p">,))</span> <span class="n">x</span> <span class="o">=</span> <span class="n">latent_inputs</span> <span class="k">for</span> <span class="n">units</span> <span class="ow">in</span> <span class="n">dense_units</span><span class="p">:</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;tanh&quot;</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_rate</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Map outputs of previous layer (x) to [continuous] adjacency tensors (x_adjacency)</span> <span class="n">x_adjacency</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span 
class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">math</span><span class="o">.</span><span class="n">reduce_prod</span><span class="p">(</span><span class="n">adjacency_shape</span><span class="p">))(</span><span class="n">x</span><span class="p">)</span> <span class="n">x_adjacency</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Reshape</span><span class="p">(</span><span class="n">adjacency_shape</span><span class="p">)(</span><span class="n">x_adjacency</span><span class="p">)</span> <span class="c1"># Symmetrize tensors in the last two dimensions</span> <span class="n">x_adjacency</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_adjacency</span> <span class="o">+</span> <span class="n">tf</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">x_adjacency</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">)))</span> <span class="o">/</span> <span class="mi">2</span> <span class="n">x_adjacency</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)(</span><span class="n">x_adjacency</span><span class="p">)</span> <span class="c1"># Map outputs of previous layer (x) to [continuous] feature tensors (x_features)</span> <span class="n">x_features</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span
class="n">Dense</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">math</span><span class="o">.</span><span class="n">reduce_prod</span><span class="p">(</span><span class="n">feature_shape</span><span class="p">))(</span><span class="n">x</span><span class="p">)</span> <span class="n">x_features</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Reshape</span><span class="p">(</span><span class="n">feature_shape</span><span class="p">)(</span><span class="n">x_features</span><span class="p">)</span> <span class="n">x_features</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">2</span><span class="p">)(</span><span class="n">x_features</span><span class="p">)</span> <span class="n">decoder</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span> <span class="n">latent_inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="p">[</span><span class="n">x_adjacency</span><span class="p">,</span> <span class="n">x_features</span><span class="p">],</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;decoder&quot;</span> <span class="p">)</span> <span class="k">return</span> <span class="n">decoder</span> </code></pre></div> <hr /> <h2 id="build-the-sampling-layer">Build the Sampling layer</h2> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">Sampling</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span 
class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">z_mean</span><span class="p">,</span> <span class="n">z_log_var</span> <span class="o">=</span> <span class="n">inputs</span> <span class="n">batch</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">z_log_var</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="n">dim</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">z_log_var</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span> <span class="n">epsilon</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">backend</span><span class="o">.</span><span class="n">random_normal</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">dim</span><span class="p">))</span> <span class="k">return</span> <span class="n">z_mean</span> <span class="o">+</span> <span class="n">tf</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="mf">0.5</span> <span class="o">*</span> <span class="n">z_log_var</span><span class="p">)</span> <span class="o">*</span> <span class="n">epsilon</span> </code></pre></div> <hr /> <h2 id="build-the-vae">Build the VAE</h2> <p>This model is trained to optimize four losses:</p> <ul> <li>Categorical crossentropy</li> <li>KL divergence loss</li> <li>Property prediction loss</li> <li>Graph loss (gradient penalty)</li> </ul> <p>The categorical crossentropy loss function measures the model's reconstruction accuracy. 
The KL divergence loss regularizes the latent space by keeping the encoder's distribution close to the prior. The property prediction loss measures the error between the predicted and actual property (QED) after running the latent representation through a property prediction model; this loss is optimized via binary crossentropy. The gradient penalty is further guided by the model's property (QED) prediction.</p> <p>A gradient penalty is a soft constraint enforcing 1-Lipschitz continuity, introduced as an improvement upon the weight clipping scheme of the original Wasserstein GAN (WGAN). ("1-Lipschitz continuity" means that the norm of the gradient is at most 1 at every point of the function.) It works by adding a regularization term to the loss function.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">MoleculeGenerator</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">encoder</span><span class="p">,</span> <span class="n">decoder</span><span class="p">,</span> <span class="n">max_len</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">encoder</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span> <span class="o">=</span> <span class="n">decoder</span> <span class="bp">self</span><span class="o">.</span><span class="n">property_prediction_layer</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span
class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_len</span> <span class="o">=</span> <span class="n">max_len</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_total_loss_tracker</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">Mean</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;train_total_loss&quot;</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">val_total_loss_tracker</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">Mean</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;val_total_loss&quot;</span><span class="p">)</span> <span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span> <span class="n">adjacency_tensor</span><span class="p">,</span> <span class="n">feature_tensor</span><span class="p">,</span> <span class="n">qed_tensor</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="n">graph_real</span> <span class="o">=</span> <span class="p">[</span><span class="n">adjacency_tensor</span><span class="p">,</span> <span class="n">feature_tensor</span><span class="p">]</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">qed_tensor</span><span 
class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span> <span class="n">z_mean</span><span class="p">,</span> <span class="n">z_log_var</span><span class="p">,</span> <span class="n">qed_pred</span><span class="p">,</span> <span class="n">gen_adjacency</span><span class="p">,</span> <span class="n">gen_features</span> <span class="o">=</span> <span class="bp">self</span><span class="p">(</span> <span class="n">graph_real</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span> <span class="p">)</span> <span class="n">graph_generated</span> <span class="o">=</span> <span class="p">[</span><span class="n">gen_adjacency</span><span class="p">,</span> <span class="n">gen_features</span><span class="p">]</span> <span class="n">total_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_compute_loss</span><span class="p">(</span> <span class="n">z_log_var</span><span class="p">,</span> <span class="n">z_mean</span><span class="p">,</span> <span class="n">qed_tensor</span><span class="p">,</span> <span class="n">qed_pred</span><span class="p">,</span> <span class="n">graph_real</span><span class="p">,</span> <span class="n">graph_generated</span> <span class="p">)</span> <span class="n">grads</span> <span class="o">=</span> <span class="n">tape</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">total_loss</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span 
class="n">apply_gradients</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">grads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">))</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_total_loss_tracker</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">total_loss</span><span class="p">)</span> <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;loss&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_total_loss_tracker</span><span class="o">.</span><span class="n">result</span><span class="p">()}</span> <span class="k">def</span> <span class="nf">_compute_loss</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">z_log_var</span><span class="p">,</span> <span class="n">z_mean</span><span class="p">,</span> <span class="n">qed_true</span><span class="p">,</span> <span class="n">qed_pred</span><span class="p">,</span> <span class="n">graph_real</span><span class="p">,</span> <span class="n">graph_generated</span> <span class="p">):</span> <span class="n">adjacency_real</span><span class="p">,</span> <span class="n">features_real</span> <span class="o">=</span> <span class="n">graph_real</span> <span class="n">adjacency_gen</span><span class="p">,</span> <span class="n">features_gen</span> <span class="o">=</span> <span class="n">graph_generated</span> <span class="n">adjacency_loss</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_mean</span><span class="p">(</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_sum</span><span class="p">(</span> <span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span 
class="o">.</span><span class="n">categorical_crossentropy</span><span class="p">(</span><span class="n">adjacency_real</span><span class="p">,</span> <span class="n">adjacency_gen</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="p">)</span> <span class="p">)</span> <span class="n">features_loss</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_mean</span><span class="p">(</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_sum</span><span class="p">(</span> <span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">categorical_crossentropy</span><span class="p">(</span><span class="n">features_real</span><span class="p">,</span> <span class="n">features_gen</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="p">)</span> <span class="p">)</span> <span class="n">kl_loss</span> <span class="o">=</span> <span class="o">-</span><span class="mf">0.5</span> <span class="o">*</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_sum</span><span class="p">(</span> <span class="mi">1</span> <span class="o">+</span> <span class="n">z_log_var</span> <span class="o">-</span> <span class="n">tf</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">z_mean</span><span class="p">)</span> <span class="o">-</span> <span class="n">tf</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">z_log_var</span><span class="p">),</span> <span class="mi">1</span> <span class="p">)</span> <span class="n">kl_loss</span> <span class="o">=</span> <span 
class="n">tf</span><span class="o">.</span><span class="n">reduce_mean</span><span class="p">(</span><span class="n">kl_loss</span><span class="p">)</span> <span class="n">property_loss</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_mean</span><span class="p">(</span> <span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">binary_crossentropy</span><span class="p">(</span><span class="n">qed_true</span><span class="p">,</span> <span class="n">qed_pred</span><span class="p">)</span> <span class="p">)</span> <span class="n">graph_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_gradient_penalty</span><span class="p">(</span><span class="n">graph_real</span><span class="p">,</span> <span class="n">graph_generated</span><span class="p">)</span> <span class="k">return</span> <span class="n">kl_loss</span> <span class="o">+</span> <span class="n">property_loss</span> <span class="o">+</span> <span class="n">graph_loss</span> <span class="o">+</span> <span class="n">adjacency_loss</span> <span class="o">+</span> <span class="n">features_loss</span> <span class="k">def</span> <span class="nf">_gradient_penalty</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">graph_real</span><span class="p">,</span> <span class="n">graph_generated</span><span class="p">):</span> <span class="c1"># Unpack graphs</span> <span class="n">adjacency_real</span><span class="p">,</span> <span class="n">features_real</span> <span class="o">=</span> <span class="n">graph_real</span> <span class="n">adjacency_generated</span><span class="p">,</span> <span class="n">features_generated</span> <span class="o">=</span> <span class="n">graph_generated</span> <span class="c1"># Generate interpolated graphs (adjacency_interp and features_interp)</span> <span class="n">alpha</span> 
<span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">])</span> <span class="n">alpha</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">alpha</span><span class="p">,</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="n">adjacency_interp</span> <span class="o">=</span> <span class="p">(</span><span class="n">adjacency_real</span> <span class="o">*</span> <span class="n">alpha</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">alpha</span><span class="p">)</span> <span class="o">*</span> <span class="n">adjacency_generated</span> <span class="n">alpha</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">alpha</span><span class="p">,</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="n">features_interp</span> <span class="o">=</span> <span class="p">(</span><span class="n">features_real</span> <span class="o">*</span> <span class="n">alpha</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">alpha</span><span 
class="p">)</span> <span class="o">*</span> <span class="n">features_generated</span> <span class="c1"># Compute the logits of interpolated graphs</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span> <span class="n">tape</span><span class="o">.</span><span class="n">watch</span><span class="p">(</span><span class="n">adjacency_interp</span><span class="p">)</span> <span class="n">tape</span><span class="o">.</span><span class="n">watch</span><span class="p">(</span><span class="n">features_interp</span><span class="p">)</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">logits</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="p">(</span> <span class="p">[</span><span class="n">adjacency_interp</span><span class="p">,</span> <span class="n">features_interp</span><span class="p">],</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span> <span class="p">)</span> <span class="c1"># Compute the gradients with respect to the interpolated graphs</span> <span class="n">grads</span> <span class="o">=</span> <span class="n">tape</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="p">[</span><span class="n">adjacency_interp</span><span class="p">,</span> <span class="n">features_interp</span><span class="p">])</span> <span class="c1"># Compute the gradient penalty</span> <span class="n">grads_adjacency_penalty</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">tf</span><span class="o">.</span><span 
class="n">norm</span><span class="p">(</span><span class="n">grads</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span> <span class="o">**</span> <span class="mi">2</span> <span class="n">grads_features_penalty</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">tf</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">grads</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">2</span><span class="p">))</span> <span class="o">**</span> <span class="mi">2</span> <span class="k">return</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_mean</span><span class="p">(</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_mean</span><span class="p">(</span><span class="n">grads_adjacency_penalty</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="p">(</span><span class="o">-</span><span class="mi">2</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span> <span class="o">+</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_mean</span><span class="p">(</span><span class="n">grads_features_penalty</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">inference</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">):</span> <span class="n">z</span> <span class="o">=</span> <span class="n">tf</span><span 
class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">LATENT_DIM</span><span class="p">))</span> <span class="n">reconstruction_adjacency</span><span class="p">,</span> <span class="n">reconstruction_features</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">decoder</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">z</span><span class="p">)</span> <span class="c1"># obtain one-hot encoded adjacency tensor</span> <span class="n">adjacency</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">reconstruction_adjacency</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="n">adjacency</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">adjacency</span><span class="p">,</span> <span class="n">depth</span><span class="o">=</span><span class="n">BOND_DIM</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># Remove potential self-loops from adjacency</span> <span class="n">adjacency</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">linalg</span><span class="o">.</span><span class="n">set_diag</span><span class="p">(</span><span class="n">adjacency</span><span class="p">,</span> <span class="n">tf</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span 
class="n">adjacency</span><span class="p">)[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]))</span> <span class="c1"># obtain one-hot encoded feature tensor</span> <span class="n">features</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">reconstruction_features</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span> <span class="n">features</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">features</span><span class="p">,</span> <span class="n">depth</span><span class="o">=</span><span class="n">ATOM_DIM</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span> <span class="k">return</span> <span class="p">[</span> <span class="n">graph_to_molecule</span><span class="p">([</span><span class="n">adjacency</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">features</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">numpy</span><span class="p">()])</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span> <span class="p">]</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">z_mean</span><span class="p">,</span> <span class="n">log_var</span> <span class="o">=</span> <span class="bp">self</span><span 
class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">z</span> <span class="o">=</span> <span class="n">Sampling</span><span class="p">()([</span><span class="n">z_mean</span><span class="p">,</span> <span class="n">log_var</span><span class="p">])</span> <span class="n">gen_adjacency</span><span class="p">,</span> <span class="n">gen_features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span class="n">z</span><span class="p">)</span> <span class="n">property_pred</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">property_prediction_layer</span><span class="p">(</span><span class="n">z_mean</span><span class="p">)</span> <span class="k">return</span> <span class="n">z_mean</span><span class="p">,</span> <span class="n">log_var</span><span class="p">,</span> <span class="n">property_pred</span><span class="p">,</span> <span class="n">gen_adjacency</span><span class="p">,</span> <span class="n">gen_features</span> </code></pre></div> <hr /> <h2 id="train-the-model">Train the model</h2> <div class="codehilite"><pre><span></span><code><span class="n">vae_optimizer</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="n">VAE_LR</span><span class="p">)</span> <span class="n">encoder</span> <span class="o">=</span> <span class="n">get_encoder</span><span class="p">(</span> <span class="n">gconv_units</span><span class="o">=</span><span class="p">[</span><span class="mi">9</span><span class="p">],</span> <span class="n">adjacency_shape</span><span class="o">=</span><span 
class="p">(</span><span class="n">BOND_DIM</span><span class="p">,</span> <span class="n">NUM_ATOMS</span><span class="p">,</span> <span class="n">NUM_ATOMS</span><span class="p">),</span> <span class="n">feature_shape</span><span class="o">=</span><span class="p">(</span><span class="n">NUM_ATOMS</span><span class="p">,</span> <span class="n">ATOM_DIM</span><span class="p">),</span> <span class="n">latent_dim</span><span class="o">=</span><span class="n">LATENT_DIM</span><span class="p">,</span> <span class="n">dense_units</span><span class="o">=</span><span class="p">[</span><span class="mi">512</span><span class="p">],</span> <span class="n">dropout_rate</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="p">)</span> <span class="n">decoder</span> <span class="o">=</span> <span class="n">get_decoder</span><span class="p">(</span> <span class="n">dense_units</span><span class="o">=</span><span class="p">[</span><span class="mi">128</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">512</span><span class="p">],</span> <span class="n">dropout_rate</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">latent_dim</span><span class="o">=</span><span class="n">LATENT_DIM</span><span class="p">,</span> <span class="n">adjacency_shape</span><span class="o">=</span><span class="p">(</span><span class="n">BOND_DIM</span><span class="p">,</span> <span class="n">NUM_ATOMS</span><span class="p">,</span> <span class="n">NUM_ATOMS</span><span class="p">),</span> <span class="n">feature_shape</span><span class="o">=</span><span class="p">(</span><span class="n">NUM_ATOMS</span><span class="p">,</span> <span class="n">ATOM_DIM</span><span class="p">),</span> <span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">MoleculeGenerator</span><span class="p">(</span><span class="n">encoder</span><span 
class="p">,</span> <span class="n">decoder</span><span class="p">,</span> <span class="n">MAX_MOLSIZE</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">vae_optimizer</span><span class="p">)</span> <span class="n">history</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">([</span><span class="n">adjacency_tensor</span><span class="p">,</span> <span class="n">feature_tensor</span><span class="p">,</span> <span class="n">qed_tensor</span><span class="p">],</span> <span class="n">epochs</span><span class="o">=</span><span class="n">EPOCHS</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/10 250/250 [==============================] - 24s 84ms/step - loss: 68958.3946 Epoch 2/10 250/250 [==============================] - 20s 79ms/step - loss: 68819.8421 Epoch 3/10 250/250 [==============================] - 20s 79ms/step - loss: 68830.6720 Epoch 4/10 250/250 [==============================] - 20s 79ms/step - loss: 68816.1486 Epoch 5/10 250/250 [==============================] - 20s 79ms/step - loss: 68825.9977 Epoch 6/10 250/250 [==============================] - 19s 78ms/step - loss: 68818.0771 Epoch 7/10 250/250 [==============================] - 19s 77ms/step - loss: 68815.8525 Epoch 8/10 250/250 [==============================] - 20s 78ms/step - loss: 68820.5459 Epoch 9/10 250/250 [==============================] - 21s 83ms/step - loss: 68806.9465 Epoch 10/10 250/250 [==============================] - 21s 84ms/step - loss: 68805.9879 </code></pre></div> </div> <hr /> <h2 id="inference">Inference</h2> <p>We use our model to generate new valid molecules from different points of the latent space.</p> <h3 id="generate-unique-molecules-with-the-model">Generate unique Molecules with the model</h3> <div 
class="codehilite"><pre><span></span><code><span class="n">molecules</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">inference</span><span class="p">(</span><span class="mi">1000</span><span class="p">)</span> <span class="n">MolsToGridImage</span><span class="p">(</span> <span class="p">[</span><span class="n">m</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="n">molecules</span> <span class="k">if</span> <span class="n">m</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">][:</span><span class="mi">1000</span><span class="p">],</span> <span class="n">molsPerRow</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">subImgSize</span><span class="o">=</span><span class="p">(</span><span class="mi">260</span><span class="p">,</span> <span class="mi">160</span><span class="p">)</span> <span class="p">)</span> </code></pre></div> <p><img alt="png" src="/img/examples/generative/molecule_generation/molecule_generation_21_0.png" /></p> <h3 id="display-latent-space-clusters-with-respect-to-molecular-properties-qae">Display latent space clusters with respect to molecular properties (QAE)</h3> <hr /> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">plot_latent</span><span class="p">(</span><span class="n">vae</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="n">labels</span><span class="p">):</span> <span class="c1"># display a 2D plot of the property in the latent space</span> <span class="n">z_mean</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">vae</span><span class="o">.</span><span class="n">encoder</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> <span 
class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">12</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">z_mean</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">z_mean</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">c</span><span class="o">=</span><span class="n">labels</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">colorbar</span><span class="p">()</span> <span class="n">plt</span><span class="o">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s2">&quot;z[0]&quot;</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s2">&quot;z[1]&quot;</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> <span class="n">plot_latent</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="p">[</span><span class="n">adjacency_tensor</span><span class="p">[:</span><span class="mi">8000</span><span class="p">],</span> <span class="n">feature_tensor</span><span class="p">[:</span><span class="mi">8000</span><span class="p">]],</span> <span class="n">qed_tensor</span><span class="p">[:</span><span class="mi">8000</span><span class="p">])</span> </code></pre></div> <p><img alt="png" src="/img/examples/generative/molecule_generation/molecule_generation_23_0.png" /></p> <hr /> <h2 id="conclusion">Conclusion</h2> <p>In this example, we combined model architectures from two papers, "Automatic chemical design using a 
data-driven continuous representation of molecules" from 2016 and the "MolGAN" paper from 2018. The former paper treats SMILES inputs as strings and seeks to generate molecule strings in SMILES format, while the latter paper represents molecules as graphs (a combination of adjacency matrices and feature matrices) and seeks to generate molecules as graphs.</p> <p>This hybrid approach enables a new type of directed gradient-based search through chemical space.</p> <p>This example is also available on Hugging Face:</p> <table> <thead> <tr> <th style="text-align: center;">Trained Model</th> <th style="text-align: center;">Demo</th> </tr> </thead> <tbody> <tr> <td style="text-align: center;"><a href="https://huggingface.co/keras-io/drug-molecule-generation-with-VAE"><img alt="Generic badge" src="https://img.shields.io/badge/%F0%9F%A4%97%20Model-molecule%20generation%20with%20VAE-black.svg" /></a></td> <td style="text-align: center;"><a href="https://huggingface.co/spaces/keras-io/generating-drug-molecule-with-VAE"><img alt="Generic badge" src="https://img.shields.io/badge/%F0%9F%A4%97%20Spaces-molecule%20generation%20with%20VAE-black.svg" /></a></td> </tr> </tbody> </table> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#drug-molecule-generation-with-vae'>Drug Molecule Generation with VAE</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#dataset'>Dataset</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#hyperparameters'>Hyperparameters</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#-generate-training-set'> Generate training set</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#build-the-encoder-and-decoder'>Build the Encoder and Decoder</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#build-the-sampling-layer'>Build the Sampling layer</a> </div> <div class='k-outline-depth-2'> ◆ <a
href='#build-the-vae'>Build the VAE</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#train-the-model'>Train the model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#inference'>Inference</a> </div> <div class='k-outline-depth-3'> <a href='#generate-unique-molecules-with-the-model'>Generate unique molecules with the model</a> </div> <div class='k-outline-depth-3'> <a href='#display-latent-space-clusters-with-respect-to-molecular-properties-qae'>Display latent space clusters with respect to molecular properties (QED)</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#conclusion'>Conclusion</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>
