CINXE.COM

Do Machine Learning Models Memorize or Generalize?

<!-- @license Copyright 2020 Google. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --> <!DOCTYPE html> <html> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <link rel="apple-touch-icon" sizes="180x180" href="https://pair.withgoogle.com/images/favicon/apple-touch-icon.png"> <link rel="icon" type="image/png" sizes="32x32" href="https://pair.withgoogle.com/images/favicon/favicon-32x32.png"> <link rel="icon" type="image/png" sizes="16x16" href="https://pair.withgoogle.com/images/favicon/favicon-16x16.png"> <link rel="mask-icon" href="https://pair.withgoogle.com/images/favicon/safari-pinned-tab.svg" color="#00695c"> <script> !(function(){ var url = window.location.href if (url.split('#')[0].split('?')[0].slice(-1) != '/' && !url.includes('.html')) window.location = url + '/' })() </script> <title>Do Machine Learning Models Memorize or Generalize?</title> <meta property="og:title" content="Do Machine Learning Models Memorize or Generalize?"> <meta property="og:url" content="https://pair.withgoogle.com/explorables/grokking/"> <meta name="og:description" content="An interactive introduction to grokking and mechanistic interpretability."> <meta property="og:image" content="https://pair.withgoogle.com/explorables/images/grokking.png"> <meta name="twitter:card" content="summary_large_image"> <link rel="stylesheet" type="text/css" href="../style_v2.css"> <link href='https://fonts.googleapis.com/css?family=Roboto+Slab:400,500,700|Roboto:700,500,300' rel='stylesheet' type='text/css'> <link href="https://fonts.googleapis.com/css?family=Google+Sans:400,500,700" rel="stylesheet"> <meta name="viewport" content="width=device-width"> </head> <body> <div class='header'> <div class='header-left'> <a href='https://pair.withgoogle.com/'> <img src='../images/pair-logo.svg' style='width: 100px'></img> </a> <a href='../'>Explorables</a> </div> </div> <h1 class='headline'>Do Machine Learning Models Memorize or Generalize?</h1> <div class='byline'> <div class='authors'> By <b> Adam Pearce</b>, <b> Asma Ghandeharioun</b>, <b> Nada Hussein</b>, <b> Nithum Thain</b>, <b> Martin Wattenberg</b> and <b> Lucas Dixon </b> </div> <div class='date'>August 2023</div> </div> <p>In 2021, researchers made a striking discovery while training a series of tiny models on toy tasks <a class='citestart' key='Grokking'></a>. They found a set of models that suddenly flipped from memorizing their training data to correctly generalizing on unseen inputs after training for much longer. This phenomenon – where generalization seems to happen abruptly and long after fitting the training data – is called <em>grokking</em> and has sparked a flurry of interest <a class='citestart' key='Omnigrok Universality Zhong23 ProgressParity gromov'></a>.</p> <div class='sticky-container'> <div class='mod-top-accuracy row sticky'></div> <p>Do more complex models also suddenly generalize after they’re trained longer? Large language models can certainly seem like they have a rich understanding of the world, but they might just be regurgitating memorized bits of the enormous amount of text they’ve been trained on <a class='citestart' key='Parrots Othello'></a>. How can we tell if they’re generalizing or memorizing?</p> <p>In this article we’ll examine the training dynamics of a tiny model and reverse engineer the solution it finds – and in the process provide an illustration of the exciting emerging field of mechanistic interpretability <a class='citestart' key='MechInterp ProgressMeasures'></a>. While it isn’t yet clear how to apply these techniques to today’s largest models, starting small makes it easier to develop intuitions as we progress towards answering these critical questions about large language models.</p> <h3 id="grokking-modular-addition">Grokking Modular Addition</h3> <p>Modular addition is essentially the fruit fly of grokking.<a class='footstart' key='modular'></a> The above line chart comes from a model trained to predict <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>a</mi><mo>+</mo><mi>b</mi><mtext> </mtext><mo lspace="0.22em" rspace="0.22em"><mrow><mi mathvariant="normal">m</mi><mi mathvariant="normal">o</mi><mi mathvariant="normal">d</mi></mrow></mo><mtext> </mtext><mn>67</mn></mrow><annotation encoding="application/x-tex">a + b \bmod 67</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6667em;vertical-align:-0.0833em;"></span><span class="mord mathnormal">a</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord mathnormal">b</span><span class="mspace" style="margin-right:0.0556em;"></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin"><span class="mord"><span class="mord mathrm">mod</span></span></span><span class="mspace" style="margin-right:0.0556em;"></span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">67</span></span></span></span>.<a class='footstart' key='67' ></a> We start by randomly dividing all the <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>a</mi><mo separator="true">,</mo><mi>b</mi></mrow><annotation encoding="application/x-tex">a, b</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord mathnormal">a</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathnormal">b</span></span></span></span> pairs into test and training datasets. Over thousands of training steps, the training data is used to adjust the model into outputting correct answers, while the test data is only used to check if the model has learned a general solution.</p> <p>The model’s architecture is similarly simple: <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mtext>ReLU</mtext><mrow><mo fence="true">(</mo><msub><mi mathvariant="bold">a</mi><mtext>one-hot</mtext></msub><msub><mi mathvariant="bold">W</mi><mtext>input</mtext></msub><mo>+</mo><msub><mi mathvariant="bold">b</mi><mtext>one-hot</mtext></msub><msub><mi mathvariant="bold">W</mi><mtext>input</mtext></msub><mo fence="true">)</mo></mrow><msub><mi mathvariant="bold">W</mi><mtext>output</mtext></msub></mrow><annotation encoding="application/x-tex">\text{ReLU}\left(\mathbf{a}_{\text{one-hot}} \mathbf{W}_{\text{input}} + \mathbf{b}_{\text{one-hot}} \mathbf{W}_{\text{input}}\right) \mathbf{W}_{\text{output}} </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.0361em;vertical-align:-0.2861em;"></span><span class="mord text"><span class="mord">ReLU</span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="minner"><span class="mopen delimcenter" style="top:0em;">(</span><span class="mord"><span class="mord mathbf">a</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">one-hot</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord"><span class="mord mathbf" style="margin-right:0.01597em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3175em;"><span style="top:-2.55em;margin-left:-0.016em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">input</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord"><span class="mord mathbf">b</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">one-hot</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord"><span class="mord mathbf" style="margin-right:0.01597em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3175em;"><span style="top:-2.55em;margin-left:-0.016em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">input</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span><span class="mclose delimcenter" style="top:0em;">)</span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathbf" style="margin-right:0.01597em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.016em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">output</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span> — a one-layer MLP with 24 neurons.<a class='footstart' key='playground'></a> All the weights of the model are shown in the heatmap below; you can see how they change during training by mousing over the line chart above.</p> <div class='sticky-container'> <div class='mod-top-weights row x-sticky x-sticky-lower'></div> <p>The model makes a prediction by selecting the two columns of <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi mathvariant="bold">W</mi><mtext>input</mtext></msub></mrow><annotation encoding="application/x-tex">\mathbf{W}_{\text{input}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9722em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord mathbf" style="margin-right:0.01597em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3175em;"><span style="top:-2.55em;margin-left:-0.016em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">input</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span> corresponding to inputs <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>a</mi></mrow><annotation encoding="application/x-tex">a</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mord mathnormal">a</span></span></span></span> and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>b</mi></mrow><annotation encoding="application/x-tex">b</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord mathnormal">b</span></span></span></span> then adding them together to create a vector of 24 separate numbers. Next it sets all the negative numbers in the vector to 0 and finally outputs the column of <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi mathvariant="bold">W</mi><mtext>output</mtext></msub></mrow><annotation encoding="application/x-tex">\mathbf{W}_{\text{output}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9722em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord mathbf" style="margin-right:0.01597em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.016em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">output</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span> that’s closest to the updated vector.</p> <p>The weights of the model are initially quite noisy but start to exhibit periodic patterns as accuracy on the test data increases and the model <animate data-animate='top-switches'>switches</animate> to generalizing. By the end of training, each neuron — each row of the heatmap — cycles through high and low values several times as the input number increases from 0 to 66.</p> <p>This is easier to see if we group the neurons by how often they cycle at the end of training and chart each of them as a separate line:</p> <p></div> </div></p> <div class='mod-top-waves row'></div> <p>The periodic patterns suggest the model is learning some sort of mathematical structure; the fact that it happens when the model starts to solve the test examples hints that it’s related to the model generalizing. But <em>why</em> does the model move away from the memorizing solution? And <em>what</em> is the generalizing solution?</p> <h3 id="generalizing-with-1s-and-0s">Generalizing With 1s and 0s</h3> <p>Figuring out both of these questions simultaneously is hard. Let’s make an even simpler task, one where we know what the generalizing solution should look like and try to understand why the model eventually learns it.</p> <p>We’ll take random sequences of thirty 1s and 0s and train our model to predict if there is an odd number of 1s in the first three digits. e.g. <digits>000110010110001010111001001011</digits> is <digits>0</digits> while <digits>010110010110001010111001001011</digits> is <digits>1</digits> — basically a slightly trickier XOR with some distraction noise. A generalizing model should only use the first three digits of the sequence; if the model is memorizing the training data, it will also use the subsequent distracting digits <a class='citestart' key='ProgressParity TwoCircuits'></a>.</p> <p>Our model is again a one-layer MLP, trained on a fixed batch of 1,200 sequences.<a class='footstart' key='sp-model'></a> At first only training accuracy increases — the model is memorizing the training data. As with modular arithmetic, test accuracy is essentially random and then sharply rises as the model learns a general solution.</p> <div class='parity-accuracy row'></div> <p>While <animate data-animate='sp-mem'>memorizing</animate> , the model looks dense and noisy with lots of high magnitude weights (shown as dark red and blue squares) spread across the chart below – the model is using all the inputs to make a prediction. As the model <animate data-animate='sp-gen'>generalizes</animate> and gets perfect test accuracy, we see all the weights connected to the distracting digits gray out with very low values and the model focusing on the first three digits — mirroring the generalized structure we expected!<a class='footstart' key='sp-solution'></a></p> <div class='parity-weights row'></div> <p>With this simplified example it’s easier to see why this happens: we’re pushing our model to do two things during training — output a high probability for the correct label (called minimizing <em>loss</em> <a class='footstart' key='loss'></a>) and have weights with low magnitudes (known as <em>weight decay</em> <a class='footstart' key='sp-l2'></a>). <span class='inline-train'>Train loss</span> actually slightly increases before the model generalizes as it exchanges loss related to outputting the correct label for having lower weights.</p> <div class='parity-loss row'></div> <p>The sharp drop in <span class='inline-test'>test loss</span> makes it appear like the model makes a sudden shift to generalization. But if we look at the weights of the model over training, most of them smoothly interpolate between the two solutions. The rapid generalization occurs when the last weights connected to the distracting digits are pruned by weight decay.</p> <div class='parity-weights-trajectory row'></div> <h3 id="when-does-grokking-happen-">When Does Grokking Happen?</h3> <p>It’s important to note that grokking is a contingent phenomenon — it goes away if model size, weight decay, data size and other hyper parameters aren’t just right. With too little weight decay, the model can’t escape overfitting the training data.<a class='footstart' key='overfit'></a> Adding more weight decay pushes the model to generalize after memorizing. Increasing weight decay even more causes test and train loss to fall together; the model goes straight to generalizing. And with too much weight decay the model will fail to learn anything.</p> <p>Below, we’ve trained over a thousand models on the 1s and 0s task with different hyperparameters. Training is noisy so nine models have been trained for each set of hyperparameters.</p> <div class='sparse-parity-sweep'></div> <p><br> <br> <br></p> <p>We can induce memorization and generalization on this somewhat contrived 1s and 0s task — but why does it happen with modular addition? Let’s first understand a little more about how a one-layer MLP can solve modular addition by constructing a generalizing solution that’s interpretable.</p> <h3 id="modular-addition-with-five-neurons">Modular Addition With Five Neurons</h3> <p>Recall that our modular arithmetic problem <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>a</mi><mo>+</mo><mi>b</mi><mtext> </mtext><mo lspace="0.22em" rspace="0.22em"><mrow><mi mathvariant="normal">m</mi><mi mathvariant="normal">o</mi><mi mathvariant="normal">d</mi></mrow></mo><mtext> </mtext><mn>67</mn></mrow><annotation encoding="application/x-tex">a + b \bmod 67</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6667em;vertical-align:-0.0833em;"></span><span class="mord mathnormal">a</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord mathnormal">b</span><span class="mspace" style="margin-right:0.0556em;"></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin"><span class="mord"><span class="mord mathrm">mod</span></span></span><span class="mspace" style="margin-right:0.0556em;"></span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">67</span></span></span></span> is naturally periodic, with answers wrapping around if the sum ever passes 67. Mathematically, this can be mirrored by thinking of the sum as wrapping <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>a</mi></mrow><annotation encoding="application/x-tex">a</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mord mathnormal">a</span></span></span></span> and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>b</mi></mrow><annotation encoding="application/x-tex">b</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord mathnormal">b</span></span></span></span> around a circle. The weights of the generalizing model also had periodic patterns, indicating that the solution might use this property.</p> <p>We can train a simpler model with a head start on the problem, constructing an embedding matrix that places <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>a</mi></mrow><annotation encoding="application/x-tex">a</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mord mathnormal">a</span></span></span></span> and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>b</mi></mrow><annotation encoding="application/x-tex">b</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord mathnormal">b</span></span></span></span> on a circle by computing <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>cos</mi><mo>⁡</mo></mrow><annotation encoding="application/x-tex">\cos</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mop">cos</span></span></span></span> and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>sin</mi><mo>⁡</mo></mrow><annotation encoding="application/x-tex">\sin</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6679em;"></span><span class="mop">sin</span></span></span></span> for each possible input number <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>i</mi><mo>∈</mo><mo stretchy="false">{</mo><mn>0</mn><mo separator="true">,</mo><mn>1</mn><mo separator="true">,</mo><mo>…</mo><mo separator="true">,</mo><mn>66</mn><mo stretchy="false">}</mo></mrow><annotation encoding="application/x-tex">i \in \{0, 1, \ldots, 66\}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6986em;vertical-align:-0.0391em;"></span><span class="mord mathnormal">i</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">∈</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">{</span><span class="mord">0</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">1</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="minner">…</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">66</span><span class="mclose">}</span></span></span></span>.<a class='footstart' key='unit-circle'></a></p> <script type="math/tex"> $$ \mathbf{W}_{\text{embed}} = \begin{pmatrix} \dots & \ldots \\ \cos(i\frac{2\pi}{67}) & \sin(i \frac{2\pi}{67}) \\ \dots & \dots \\ \end{pmatrix} \quad $$ </script> <div class='row'><div class='embed'></div></div> <p><br> Then we train <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi mathvariant="bold">W</mi><mtext>in-proj</mtext></msub></mrow><annotation encoding="application/x-tex">\mathbf{W}_{\text{in-proj}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9722em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord mathbf" style="margin-right:0.01597em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3175em;"><span style="top:-2.55em;margin-left:-0.016em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">in-proj</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span> and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi mathvariant="bold">W</mi><mtext>out-proj</mtext></msub></mrow><annotation encoding="application/x-tex">\mathbf{W}_{\text{out-proj}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9722em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord mathbf" style="margin-right:0.01597em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3175em;"><span style="top:-2.55em;margin-left:-0.016em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">out-proj</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span> in this one-layer MLP:</p> <script type="math/tex"> $$ \begin{aligned} \text{activations} & = \text{ReLU}\left(\mathbf{a}_{\text{one-hot}} \mathbf{W}_{\text{embed}} \mathbf{W}_{\text{in-proj}} + \mathbf{b}_{\text{one-hot}} \mathbf{W}_{\text{embed}} \mathbf{W}_{\text{in-proj}}\right) \\ \text{logits} & = \text{activations} \mathbf{W}_{\text{out-proj}} \mathbf{W}_{\text{embed}}^{\top} \end{aligned} $$ </script> <p>With just five neurons the model finds a solution with perfect accuracy.</p> <div class='sticky-container'> <div class='five-neuron-accuracy row sticky'></div> <div class='five-neuron-embed row'></div> <p>Eyeballing the trained parameters, all the neurons <animate data-animate='five-neuron-converge'>converge</animate> to roughly equal norms. If we directly plot their <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>cos</mi><mo>⁡</mo></mrow><annotation encoding="application/x-tex">\cos</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mop">cos</span></span></span></span> and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>sin</mi><mo>⁡</mo></mrow><annotation encoding="application/x-tex">\sin</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6679em;"></span><span class="mop">sin</span></span></span></span> components, they’re essentially evenly distributed around a circle:</p> <div class='five-neuron-circle row'></div> <p>Connect the adjacent neurons on the <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi mathvariant="bold">W</mi><mtext>in-proj</mtext></msub></mrow><annotation encoding="application/x-tex">\mathbf{W}_{\text{in-proj}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9722em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord mathbf" style="margin-right:0.01597em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3175em;"><span style="top:-2.55em;margin-left:-0.016em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">in-proj</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span> circle and an intriguing pattern emerges: <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi mathvariant="bold">W</mi><mtext>out-proj</mtext></msub></mrow><annotation encoding="application/x-tex">\mathbf{W}_{\text{out-proj}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9722em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord mathbf" style="margin-right:0.01597em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3175em;"><span style="top:-2.55em;margin-left:-0.016em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">out-proj</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span> is rotating around the circle twice as fast as <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi mathvariant="bold">W</mi><mtext>in-proj</mtext></msub></mrow><annotation encoding="application/x-tex">\mathbf{W}_{\text{in-proj}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9722em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord mathbf" style="margin-right:0.01597em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3175em;"><span style="top:-2.55em;margin-left:-0.016em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">in-proj</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span>.</p> <p></div></p> <div class='five-neuron-circle-2 row'></div> <p>The details of how this solution works aren’t essential — check out <a href="#appendix-a-how-the-circular-construction-works">Appendix A</a> to see how the doubled rotation allows the model to map inputs like <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mn>1</mn><mo>+</mo><mn>0</mn><mtext> </mtext><mo lspace="0.22em" rspace="0.22em"><mrow><mi mathvariant="normal">m</mi><mi mathvariant="normal">o</mi><mi mathvariant="normal">d</mi></mrow></mo><mtext> </mtext><mn>67</mn></mrow><annotation encoding="application/x-tex">1 + 0 \bmod 67</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.7278em;vertical-align:-0.0833em;"></span><span class="mord">1</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord">0</span><span class="mspace" style="margin-right:0.0556em;"></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin"><span class="mord"><span class="mord mathrm">mod</span></span></span><span class="mspace" style="margin-right:0.0556em;"></span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">67</span></span></span></span> and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mn>2</mn><mo>+</mo><mn>66</mn><mtext> </mtext><mo lspace="0.22em" rspace="0.22em"><mrow><mi mathvariant="normal">m</mi><mi mathvariant="normal">o</mi><mi mathvariant="normal">d</mi></mrow></mo><mtext> </mtext><mn>67</mn></mrow><annotation encoding="application/x-tex">2 + 66 \bmod 67</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.7278em;vertical-align:-0.0833em;"></span><span class="mord">2</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord">66</span><span class="mspace" style="margin-right:0.0556em;"></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin"><span class="mord"><span class="mord mathrm">mod</span></span></span><span class="mspace" style="margin-right:0.0556em;"></span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">67</span></span></span></span> to the same place — but we have found a 20 parameter construction that solves modular addition. Can we find the same algorithm hidden in the 3,216 parameter model we started with? And why does the larger model switch to the generalizing solution after memorizing?</p> <h3 id="it-s-full-of-stars">It’s Full of Stars</h3> <p>Here’s the <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>a</mi><mo>+</mo><mi>b</mi><mtext> </mtext><mo lspace="0.22em" rspace="0.22em"><mrow><mi mathvariant="normal">m</mi><mi mathvariant="normal">o</mi><mi mathvariant="normal">d</mi></mrow></mo><mtext> </mtext><mn>67</mn></mrow><annotation encoding="application/x-tex">a + b \bmod 67</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6667em;vertical-align:-0.0833em;"></span><span class="mord mathnormal">a</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord mathnormal">b</span><span class="mspace" style="margin-right:0.0556em;"></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin"><span class="mord"><span class="mord mathrm">mod</span></span></span><span class="mspace" style="margin-right:0.0556em;"></span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">67</span></span></span></span> model that we started with — it’s trained from scratch with no built-in periodicity.</p> <div class='sticky-container'> <div class='mod-bot-accuracy sticky'></div> <div class='mod-bot-waves row'></div> <p>Unlike the constructed solution, where <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi mathvariant="bold">W</mi><mtext>embed</mtext></msub></mrow><annotation encoding="application/x-tex">\mathbf{W}_{\text{embed}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8361em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathbf" style="margin-right:0.01597em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:-0.016em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">embed</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> rotates around the circle once, this model has many different frequencies.</p> <p>Below, we’ve isolated the frequencies using the discrete Fourier transform (DFT).<a class='footstart' key='dft'></a> This factors out the learned periodic patterns across inputs, leaving us with the equivalent of <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi mathvariant="bold">W</mi><mtext>in-proj</mtext></msub></mrow><annotation encoding="application/x-tex">\mathbf{W}_{\text{in-proj}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9722em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord mathbf" style="margin-right:0.01597em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3175em;"><span style="top:-2.55em;margin-left:-0.016em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">in-proj</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span> and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi mathvariant="bold">W</mi><mtext>out-proj</mtext></msub></mrow><annotation encoding="application/x-tex">\mathbf{W}_{\text{out-proj}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9722em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord mathbf" style="margin-right:0.01597em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3175em;"><span style="top:-2.55em;margin-left:-0.016em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">out-proj</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span> from the constructed solution. For each neuron, this gives a <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>cos</mi><mo>⁡</mo></mrow><annotation encoding="application/x-tex">\cos</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mop">cos</span></span></span></span> and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>sin</mi><mo>⁡</mo></mrow><annotation encoding="application/x-tex">\sin</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6679em;"></span><span class="mop">sin</span></span></span></span> value for every possible periodic frequency from 1 to 33. The wave charts we show above use this to group neurons into frequencies by finding their largest <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>cos</mi><mo>⁡</mo></mrow><annotation encoding="application/x-tex">\cos</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mop">cos</span></span></span></span> and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>sin</mi><mo>⁡</mo></mrow><annotation encoding="application/x-tex">\sin</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6679em;"></span><span class="mop">sin</span></span></span></span> value across all frequencies.<a class='footstart' key='dft-sort'></a></p> <div class='mod-bot-dft row'></div> <p>Just like in the 1s and 0s task, weight decay encourages this representation to become much sparser as the model <animate data-animate='bot-gen'>generalizes</animate> .</p> <p>Grouping neurons by their final trained frequencies, and plotting the <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>cos</mi><mo>⁡</mo></mrow><annotation encoding="application/x-tex">\cos</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mop">cos</span></span></span></span> and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>sin</mi><mo>⁡</mo></mrow><annotation encoding="application/x-tex">\sin</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6679em;"></span><span class="mop">sin</span></span></span></span> components of the DFT for each neuron, we see the same star shapes from the constructed solution appear.</p> <div class='mod-bot-freqs-hidden row'></div> <div class='mod-bot-freqs-out row'></div> <p><strong>This trained model is using the same algorithm as our constructed solution!</strong> Below, the contribution to the output generated by the neurons in each frequency are shown and we can see them calculating <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>cos</mi><mo>⁡</mo><mfrac><mrow><mn>2</mn><mi>π</mi><mo stretchy="false">(</mo><mi>a</mi><mo>+</mo><mi>b</mi><mo stretchy="false">)</mo><mi>f</mi><mi>r</mi><mi>e</mi><mi>q</mi></mrow><mn>67</mn></mfrac></mrow><annotation encoding="application/x-tex"> \cos\frac{2\pi (a + b) freq}{67}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.355em;vertical-align:-0.345em;"></span><span class="mop">cos</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.01em;"><span style="top:-2.655em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">67</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.485em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">2</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">π</span><span class="mopen mtight">(</span><span class="mord mathnormal mtight">a</span><span class="mbin mtight">+</span><span class="mord mathnormal mtight">b</span><span class="mclose mtight">)</span><span class="mord mathnormal mtight" style="margin-right:0.10764em;">f</span><span class="mord mathnormal mtight">re</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span> .<a class='footstart' key='logit-wave'></a></p> <p><p class='mod-bot-hide-on-sweep-change'>Notice what happens to the group of neurons with a frequency of 7 when test loss <animate data-animate='bot-improve'>improves</animate> after the short plateau at 45,000 steps — they start to snap into a star shape and their outputs more closely approximate a wave.</p> <div class='row mod-bot-sliders'></div> <div class='row mod-bot-logits'></div> <p>To lower loss without using higher weights (which would be punished by weight decay), the model uses several frequencies, taking advantage of constructive interference.<a class='citestart' key='ProgressMeasures'></a> There’s nothing magical about the frequencies 4, 5, 7 and 26 — click through other training runs below to see variations of this algorithm get learned.</p> <p></div></p> <div class='row mod-bot-seeds'></div> <h3 id="open-questions">Open Questions</h3> <p>While we now have a solid understanding of the mechanisms a one-layer MLP uses to solve modular addition and why they emerge during training, there are still many interesting open questions about memorization and generalization.</p> <h4 id="which-model-constraints-work-best-">Which Model Constraints Work Best?</h4> <p>Directly training the model visualized above — <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mtext>ReLU</mtext><mrow><mo fence="true">(</mo><msub><mi>a</mi><mtext>one-hot</mtext></msub><msub><mtext mathvariant="bold">W</mtext><mtext>input</mtext></msub><mo>+</mo><msub><mi>b</mi><mtext>one-hot</mtext></msub><msub><mtext mathvariant="bold">W</mtext><mtext>input</mtext></msub><mo fence="true">)</mo></mrow><msub><mtext mathvariant="bold">W</mtext><mtext>output</mtext></msub></mrow><annotation encoding="application/x-tex">\text{ReLU} \left(a_{\text{one-hot}}\textbf{W}_{\text{input}} + b_{\text{one-hot}}\textbf{W}_{\text{input}} \right) \textbf{W}_{\text{output}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.0361em;vertical-align:-0.2861em;"></span><span class="mord text"><span class="mord">ReLU</span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="minner"><span class="mopen delimcenter" style="top:0em;">(</span><span class="mord"><span class="mord mathnormal">a</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">one-hot</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord"><span class="mord text"><span class="mord textbf">W</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3175em;"><span style="top:-2.55em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">input</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord"><span class="mord mathnormal">b</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">one-hot</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord"><span class="mord text"><span class="mord textbf">W</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3175em;"><span style="top:-2.55em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">input</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span><span class="mclose delimcenter" style="top:0em;">)</span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord text"><span class="mord textbf">W</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">output</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span> — does not actually result in generalization on modular arithmetic, even with the addition of weight decay. At least one of the matrices has to be factored:</p> <p><span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><msub><mtext mathvariant="bold">W</mtext><mtext>input</mtext></msub><mo>=</mo><msub><mi mathvariant="bold">W</mi><mtext>embed</mtext></msub><msub><mi mathvariant="bold">W</mi><mtext>in-proj</mtext></msub></mrow><annotation encoding="application/x-tex"> \textbf{W}_{\text{input}} = \mathbf{W}_{\text{embed}} \mathbf{W}_{\text{in-proj}} </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9722em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord text"><span class="mord textbf">W</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3175em;"><span style="top:-2.55em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">input</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.9722em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord mathbf" style="margin-right:0.01597em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:-0.016em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">embed</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord"><span class="mord mathbf" style="margin-right:0.01597em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3175em;"><span style="top:-2.55em;margin-left:-0.016em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">in-proj</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span></span></p> <p><span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><msub><mtext mathvariant="bold">W</mtext><mtext>output</mtext></msub><mo>=</mo><msub><mtext mathvariant="bold">W</mtext><mtext>out-proj</mtext></msub><msubsup><mtext mathvariant="bold">W</mtext><mtext>embed</mtext><mi mathvariant="normal">⊤</mi></msubsup></mrow><annotation encoding="application/x-tex"> \textbf{W}_{\text{output}} = \textbf{W}_{\text{out-proj}} \textbf{W}_{\text{embed}}^{\top} </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9722em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord text"><span class="mord textbf">W</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">output</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1.2112em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord text"><span class="mord textbf">W</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3175em;"><span style="top:-2.55em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">out-proj</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span><span class="mord"><span class="mord text"><span class="mord textbf">W</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9251em;"><span style="top:-2.453em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">embed</span></span></span></span></span><span style="top:-3.139em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">⊤</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.247em;"><span></span></span></span></span></span></span></span></span></span></span></p> <p>We observed that the generalizing solution is sparse after taking the discrete Fourier transformation, but the collapsed matrices have high norms. This suggests that direct weight decay on <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mtext mathvariant="bold">W</mtext><mtext>output</mtext></msub></mrow><annotation encoding="application/x-tex">\textbf{W}_\text{output}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9722em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord text"><span class="mord textbf">W</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord text mtight"><span class="mord mtight">output</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span> and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mtext mathvariant="bold">W</mtext><mtext>input</mtext></msub></mrow><annotation encoding="application/x-tex">\textbf{W}_{\text{input}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9722em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord text"><span class="mord textbf">W</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3175em;"><span style="top:-2.55em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">input</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span> doesn’t provide the right inductive bias for the task.</p> <p>Broadly speaking, weight decay does steer a wide variety of models away from memorizing their training data <a class='citestart' key='DoubleDescent double-demystified'></a>. Other techniques that help avoid overfitting include dropout, smaller models and even numerically unstable optimization algorithms <a class='citestart' key='Slingshot'></a>. These approaches interact in complex, non-linear ways, making it difficult to predict <em>a priori</em> which will ultimately induce generalization. Collapsing <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi mathvariant="bold">W</mi><mtext>embed</mtext></msub><msub><mi mathvariant="bold">W</mi><mtext>in-proj</mtext></msub></mrow><annotation encoding="application/x-tex">\mathbf{W}_{\text{embed}} \mathbf{W}_{\text{in-proj}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9722em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord mathbf" style="margin-right:0.01597em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:-0.016em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">embed</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord"><span class="mord mathbf" style="margin-right:0.01597em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3175em;"><span style="top:-2.55em;margin-left:-0.016em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">in-proj</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span> instead of <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mtext mathvariant="bold">W</mtext><mtext>out-proj</mtext></msub><msubsup><mtext mathvariant="bold">W</mtext><mtext>embed</mtext><mi mathvariant="normal">⊤</mi></msubsup></mrow><annotation encoding="application/x-tex">\textbf{W}_{\text{out-proj}} \textbf{W}_{\text{embed}}^{\top}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.2112em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord text"><span class="mord textbf">W</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3175em;"><span style="top:-2.55em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">out-proj</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span><span class="mord"><span class="mord text"><span class="mord textbf">W</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9251em;"><span style="top:-2.453em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">embed</span></span></span></span></span><span style="top:-3.139em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">⊤</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.247em;"><span></span></span></span></span></span></span></span></span></span>, for example, helps in some setups and hurts in others:</p> <div class='sweep-mod'></div> <h4 id="why-is-memorization-easier-than-generalization-">Why Is Memorization Easier Than Generalization?</h4> <p>One theory: there can be many more ways to memorize a training set than there are generalizing solutions. So statistically, memorization should be more likely to happen first, especially if we have no or little regularization. Regularization techniques, like weight decay, prioritize certain solutions over others, for example, preferring “sparse” solutions over “dense” ones.</p> <p>Recent work suggests that generalization is associated with well-structured representations <a class='citestart' key='EffectiveTheory'></a>. However, it’s not a necessary condition; some MLP variations without symmetric inputs learn less “circular” representations when solving modular addition <a class='citestart' key='Zhong23'></a>. We also observed that well-structured representations are not a sufficient condition for generalization. This small model (trained with no weight decay) starts generalizing, then switches to memorizing with periodic embeddings.</p> <div class='open-q-mem-0-accuracy row'></div> <div class='open-q-mem-0-weights row'></div> <p>It’s even possible to find hyperparameters where models start generalizing, then switch to memorizing, then switch back to generalizing! <a class='footstart' key='open-q-mem'></a></p> <div class='open-q-mem-1-accuracy row'></div> <div class='open-q-mem-1-weights row'></div> <h4 id="what-about-larger-models-">What About Larger Models?</h4> <p>Does grokking happen in larger models trained on real world tasks? Earlier observations reported the grokking phenomenon in algorithmic tasks in small transformers and MLPs <a class='citestart' key='Grokking ProgressMeasures Zhong23'></a>. Grokking has subsequently been found in more complex tasks involving images, text, and tabular data within certain ranges of hyperparameters <a class='citestart' key='Omnigrok Goldilocks'></a>. It’s also possible that the largest models, which are able to do many types of tasks, may be grokking many things at different speeds during training <a class='citestart' key='quantization'></a>.</p> <p>There have also been promising results in predicting grokking before it happens. Though some require knowledge of the generalizing solution <a class='citestart' key='ProgressMeasures'></a> or the overall data domain <a class='citestart' key='StructuralGrokking'></a>, some rely solely on the analysis of the training loss <a class='citestart' key='PredictingGrokking'></a> and might also apply to larger models — hopefully we’ll be able to build tools and techniques that can tell us when a model is parroting memorized information and when it’s using richer models.</p> <p>Understanding the solution to modular addition wasn’t trivial. Do we have any hope of understanding larger models? One route forward — like our digression into the 20 parameter model and the even simpler boolean parity problem — is to: 1) train simpler models with more inductive biases and fewer moving parts, 2) use them to explain inscrutable parts of how a larger model works, 3) repeat as needed. We believe this could be a fruitful approach to better understanding larger models, and complementary to efforts that aim to use larger models to explain smaller ones and other work to disentangle internal representations <a class='citestart' key='explain multiple-choice TMOS'></a>. Moreover, this kind of mechanistic approach to interpretability, in time, may help identify patterns that themselves ease or automate the uncovering of algorithms learned by neural networks.</p> <h3 id="credits">Credits</h3> <p>Thanks to Ardavan Saeedi, Crystal Qian, Emily Reif, Fernanda Viégas, Kathy Meier-Hellstern, Mahima Pushkarna, Minsuk Chang, Neel Nanda and Ryan Mullins for their help with this piece.</p> <p><a href="https://github.com/PAIR-code/ai-explorables/tree/master/server-side/grokking">Model training code</a> // <a href="https://github.com/PAIR-code/ai-explorables/tree/master/source/grokking">Visualization code</a></p> <h3 id="appendix-a-how-the-circular-construction-works">Appendix A: How the Circular Construction Works</h3> <p>We can almost calculate <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>a</mi><mo>+</mo><mi>b</mi><mtext> </mtext><mo lspace="0.22em" rspace="0.22em"><mrow><mi mathvariant="normal">m</mi><mi mathvariant="normal">o</mi><mi mathvariant="normal">d</mi></mrow></mo><mtext> </mtext><mn>67</mn></mrow><annotation encoding="application/x-tex">a + b \bmod 67</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6667em;vertical-align:-0.0833em;"></span><span class="mord mathnormal">a</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord mathnormal">b</span><span class="mspace" style="margin-right:0.0556em;"></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin"><span class="mord"><span class="mord mathrm">mod</span></span></span><span class="mspace" style="margin-right:0.0556em;"></span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">67</span></span></span></span> using two circular embeddings and a completely linear model.</p> <div class='sticky-container'> <div class='slider-container row sticky appendix'></div> <div class='circle-vis row'></div> <p><p>It works! But we’re cheating a bit, do you see how <strong>unembed</strong> loops around the circle twice? We need to output a single prediction for “<v></v>“ — not separate predictions for “<v></v>“ and “<v2></v2>“. Directly adding the two predictions for a number together won’t work since they’re on opposite sides of the circles and will cancel each other out.</p> <p>Instead, let’s incorporate a <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mtext>ReLU</mtext><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\text{ReLU}(x)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord text"><span class="mord">ReLU</span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span> to fix the repeated outputs.</p> <div class='proj-vis row'></div> <p>We’ve essentially wrapped the circle around in on itself and the model outputs a single prediction for “<v></v>“.</p> <p>Formally, this is the constructed model:</p> <script type="math/tex"> $$ \begin{aligned} \text{activations} & = \text{ReLU}\left(\mathbf{a}_{\text{one-hot}} \mathbf{W}_{\text{embed}} \mathbf{W}_{\text{in-proj}} + \mathbf{b}_{\text{one-hot}} \mathbf{W}_{\text{embed}} \mathbf{W}_{\text{in-proj}}\right) \\ \text{logits} & = \text{activations} \mathbf{W}_{\text{out-proj}} \mathbf{W}_{\text{embed}}^{\top} \end{aligned} $$ </script> <p>With modulus <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>M</mi></mrow><annotation encoding="application/x-tex">M</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span></span></span></span> and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span> evenly spaced neurons/directions:</p> <script type="math/tex"> $$ \mathbf{W}_{\text{embed}} = \begin{pmatrix} \dots & \ldots \\ \cos(i\frac{2\pi}{M}) & \sin(i \frac{2\pi}{M}) \\ \dots & \dots \\ \end{pmatrix}, \quad \mathbf{W}_{\text{in-proj}}^T = \begin{pmatrix} \dots & \ldots \\ \cos(i\frac{2\pi}{N}) & \sin(i \frac{2\pi}{N}) \\ \dots & \dots \\ \end{pmatrix}, \quad \mathbf{W}_{\text{out-proj}} = \begin{pmatrix} \dots & \dots \\ \cos(2i\frac{2\pi}{N}) & \sin(2i\frac{2\pi}{N}) \\ \dots & \dots \\ \end{pmatrix}. $$ </script> <p><br></p> <p>Interestingly this circle has a few wrinkles: this construction doesn’t give an exact answer!</p> <p></div></p> <div class='debug-vis row'></div> <div class='appendix num-inputs row'> <span>Neurons <input type="number" class='n_neurons' min="3" max="10" value="5"></span> <span>Modulus <input type="number" class='modulus' min="12" max="500" value="67"></span> </div> <p><br></p> <p>Using <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msup><mi>x</mi><mn>2</mn></msup></mrow><annotation encoding="application/x-tex">x^2</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8141em;"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8141em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span></span></span></span> instead of <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mtext>ReLU</mtext><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\text{ReLU}(x)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord text"><span class="mord">ReLU</span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span> as the activation function, as suggested by <a class='citestart' key='gromov'></a> gives a provably exact solution!</p> <p>For simplicity, let <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>ω</mi><mo>:</mo><mo>=</mo><mfrac><mrow><mn>2</mn><mi>π</mi></mrow><mi>M</mi></mfrac></mrow><annotation encoding="application/x-tex">\omega:=\frac{2\pi}{M}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">ω</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">:=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1.1901em;vertical-align:-0.345em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8451em;"><span style="top:-2.655em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.10903em;">M</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">2</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">π</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span> (the angle between numbers in <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi mathvariant="bold">W</mi><mtext>embed</mtext></msub></mrow><annotation encoding="application/x-tex">\mathbf{W}_{\text{embed}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8361em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathbf" style="margin-right:0.01597em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:-0.016em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">embed</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span>) and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>θ</mi><mo>:</mo><mo>=</mo><mfrac><mrow><mn>2</mn><mi>π</mi></mrow><mi>N</mi></mfrac></mrow><annotation encoding="application/x-tex">\theta := \frac{2\pi}{N}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">:=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1.1901em;vertical-align:-0.345em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8451em;"><span style="top:-2.655em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.10903em;">N</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">2</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">π</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span> (the angle between neurons in <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msubsup><mi mathvariant="bold">W</mi><mtext>in-proj</mtext><mi>T</mi></msubsup></mrow><annotation encoding="application/x-tex">\mathbf{W}_{\text{in-proj}}^T</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.2419em;vertical-align:-0.4006em;"></span><span class="mord"><span class="mord mathbf" style="margin-right:0.01597em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8413em;"><span style="top:-2.4355em;margin-left:-0.016em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">in-proj</span></span></span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.13889em;">T</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.4006em;"><span></span></span></span></span></span></span></span></span></span>).</p> <p>Let’s rewrite <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msup><mtext>logits</mtext><mrow><mi>a</mi><mo separator="true">,</mo><mi>b</mi></mrow></msup></mrow><annotation encoding="application/x-tex">\text{logits}^{a, b}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.1279em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord text"><span class="mord">logits</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9334em;"><span style="top:-3.1473em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">a</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">b</span></span></span></span></span></span></span></span></span></span></span></span> as an <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>M</mi></mrow><annotation encoding="application/x-tex">M</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span></span></span></span>-dimensional vector <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo stretchy="false">∥</mo><mrow><mi>l</mi><msup><mo stretchy="false">∥</mo><mi>M</mi></msup></mrow></mrow><annotation encoding="application/x-tex"> \lVert \it{l} \rVert ^M </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.0913em;vertical-align:-0.25em;"></span><span class="mopen">∥</span><span class="mord"><span class="mord"><span class="mord mathit">l</span></span><span class="mclose"><span class="mclose">∥</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8413em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathit mtight">M</span></span></span></span></span></span></span></span></span></span></span></span> where:</p> <p><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>l</mi><mi>j</mi></msub><mo>=</mo><msubsup><mo>∑</mo><mrow><mi>i</mi><mo>=</mo><mn>0</mn></mrow><mrow><mi>N</mi><mo>−</mo><mn>1</mn></mrow></msubsup><mo fence="true" stretchy="true" minsize="2.4em" maxsize="2.4em">(</mo><mo fence="true" stretchy="true" minsize="1.2em" maxsize="1.2em">[</mo><mi>cos</mi><mo>⁡</mo><mo stretchy="false">(</mo><mi>a</mi><mi>ω</mi><mo>−</mo><mi>i</mi><mi>θ</mi><mo stretchy="false">)</mo><mo>+</mo><mi>cos</mi><mo>⁡</mo><mo stretchy="false">(</mo><mi>b</mi><mi>ω</mi><mo>−</mo><mi>i</mi><mi>θ</mi><mo stretchy="false">)</mo><msup><mo fence="true" stretchy="true" minsize="1.2em" maxsize="1.2em">]</mo><mn>2</mn></msup><mi>cos</mi><mo>⁡</mo><mo stretchy="false">(</mo><mi>j</mi><mi>ω</mi><mo>−</mo><mn>2</mn><mi>i</mi><mi>θ</mi><mo stretchy="false">)</mo><mo stretchy="false">)</mo><mo fence="true" stretchy="true" minsize="2.4em" maxsize="2.4em">)</mo></mrow><annotation encoding="application/x-tex">l_{j} = \sum_{i=0}^{N-1} \biggl(\bigl[ \cos(a\omega-i\theta) + \cos(b\omega-i\theta) \bigl]^2\cos(j\omega-2i\theta)) \biggr) </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9805em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.01968em;">l</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:-0.0197em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:2.4em;vertical-align:-0.95em;"></span><span class="mop"><span class="mop op-symbol small-op" style="position:relative;top:0em;">∑</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9812em;"><span style="top:-2.4003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span><span class="mrel mtight">=</span><span class="mord mtight">0</span></span></span></span><span style="top:-3.2029em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.10903em;">N</span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2997em;"><span></span></span></span></span></span></span><span class="mopen"><span class="delimsizing size3">(</span></span><span class="mopen"><span class="delimsizing size1">[</span></span><span class="mop">cos</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.03588em;">aω</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">i</span><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mop">cos</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.03588em;">bω</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:1.404em;vertical-align:-0.35em;"></span><span class="mord mathnormal">i</span><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="mclose">)</span><span class="mopen"><span class="mopen"><span class="delimsizing size1">]</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:1.054em;"><span style="top:-3.3029em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span><span class="mop">cos</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.03588em;">jω</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:2.4em;vertical-align:-0.95em;"></span><span class="mord">2</span><span class="mord mathnormal">i</span><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="mclose">))</span><span class="mclose"><span class="delimsizing size3">)</span></span></span></span></span></p> <p>This follows from the <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msup><mtext>logits</mtext><mrow><mi>a</mi><mo separator="true">,</mo><mi>b</mi></mrow></msup></mrow><annotation encoding="application/x-tex">\text{logits}^{a,b}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.1279em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord text"><span class="mord">logits</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9334em;"><span style="top:-3.1473em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">a</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">b</span></span></span></span></span></span></span></span></span></span></span></span> equation above by plugging in the definitions of <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mtext mathvariant="bold">W</mtext><mtext>in-proj</mtext></msub></mrow><annotation encoding="application/x-tex">\textbf{W}_\text{in-proj}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9722em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord text"><span class="mord textbf">W</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3175em;"><span style="top:-2.55em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord text mtight"><span class="mord mtight">in-proj</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span> and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mtext mathvariant="bold">W</mtext><mtext>out-proj</mtext></msub></mrow><annotation encoding="application/x-tex">\textbf{W}_\text{out-proj}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9722em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord text"><span class="mord textbf">W</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3175em;"><span style="top:-2.55em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord text mtight"><span class="mord mtight">out-proj</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span> and applying the trigonometric identity that <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>cos</mi><mo>⁡</mo><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo><mi>cos</mi><mo>⁡</mo><mo stretchy="false">(</mo><mi>y</mi><mo stretchy="false">)</mo><mo>+</mo><mi>sin</mi><mo>⁡</mo><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo><mi>sin</mi><mo>⁡</mo><mo stretchy="false">(</mo><mi>y</mi><mo stretchy="false">)</mo><mo>=</mo><mi>cos</mi><mo>⁡</mo><mo stretchy="false">(</mo><mi>x</mi><mo>−</mo><mi>y</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\cos(x)\cos(y) + \sin(x)\sin(y) = \cos(x-y)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mop">cos</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop">cos</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mop">sin</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop">sin</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mop">cos</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="mclose">)</span></span></span></span>.</p> <p>We can then prove the following:</p> <p><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mrow><mi mathvariant="normal">arg max</mi><mo>⁡</mo></mrow><mi>c</mi></msub><mtext> ⁣</mtext><msup><mtext>logits</mtext><mrow><mi>a</mi><mo separator="true">,</mo><mi>b</mi></mrow></msup><mo>=</mo><mi>a</mi><mo>+</mo><mi>b</mi><mtext> </mtext><mo lspace="0.22em" rspace="0.22em"><mrow><mi mathvariant="normal">m</mi><mi mathvariant="normal">o</mi><mi mathvariant="normal">d</mi></mrow></mo><mtext> </mtext><mi>M</mi></mrow><annotation encoding="application/x-tex"> \argmax_c \! \text{logits}^{a,b} = a + b \bmod M </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.1776em;vertical-align:-0.2441em;"></span><span class="mop"><span class="mop"><span class="mord mathrm" style="margin-right:0.01389em;">arg</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathrm">max</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.0573em;"><span style="top:-2.4559em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">c</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2441em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:-0.1667em;"></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord text"><span class="mord">logits</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9334em;"><span style="top:-3.1473em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">a</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">b</span></span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6667em;vertical-align:-0.0833em;"></span><span class="mord mathnormal">a</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord mathnormal">b</span><span class="mspace" style="margin-right:0.0556em;"></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin"><span class="mord"><span class="mord mathrm">mod</span></span></span><span class="mspace" style="margin-right:0.0556em;"></span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span></span></span></span></p> <p>Applying the two trigonometric identities of <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>cos</mi><mo>⁡</mo><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo><mo>+</mo><mi>cos</mi><mo>⁡</mo><mo stretchy="false">(</mo><mi>y</mi><mo stretchy="false">)</mo><mo>=</mo><mn>2</mn><mi>cos</mi><mo>⁡</mo><mo stretchy="false">(</mo><mfrac><mrow><mi>x</mi><mo>−</mo><mi>y</mi></mrow><mn>2</mn></mfrac><mo stretchy="false">)</mo><mi>cos</mi><mo>⁡</mo><mo stretchy="false">(</mo><mfrac><mrow><mi>x</mi><mo>+</mo><mi>y</mi></mrow><mn>2</mn></mfrac><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\cos(x) + \cos(y) = 2 \cos(\frac{x-y}{2}) \cos(\frac{x+y}{2})</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mop">cos</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mop">cos</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1.1994em;vertical-align:-0.345em;"></span><span class="mord">2</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop">cos</span><span class="mopen">(</span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8544em;"><span style="top:-2.655em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">2</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.4461em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">x</span><span class="mbin mtight">−</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop">cos</span><span class="mopen">(</span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8544em;"><span style="top:-2.655em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">2</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.4461em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">x</span><span class="mbin mtight">+</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose">)</span></span></span></span> and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msup><mrow><mi>cos</mi><mo>⁡</mo></mrow><mn>2</mn></msup><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo><mi>c</mi><mi>o</mi><mi>s</mi><mo stretchy="false">(</mo><mi>y</mi><mo stretchy="false">)</mo><mo>=</mo><mn>1</mn><mi mathvariant="normal">/</mi><mn>4</mn><mo fence="true" stretchy="true" minsize="1.2em" maxsize="1.2em">[</mo><mn>2</mn><mi>cos</mi><mo>⁡</mo><mo stretchy="false">(</mo><mi>y</mi><mo stretchy="false">)</mo><mo>+</mo><mi>cos</mi><mo>⁡</mo><mo stretchy="false">(</mo><mn>2</mn><mi>x</mi><mo>−</mo><mi>y</mi><mo stretchy="false">)</mo><mo>+</mo><mi>cos</mi><mo>⁡</mo><mo stretchy="false">(</mo><mn>2</mn><mi>x</mi><mo>+</mo><mi>y</mi><mo stretchy="false">)</mo><mo fence="true" stretchy="true" minsize="1.2em" maxsize="1.2em">]</mo></mrow><annotation encoding="application/x-tex">\cos^2(x)cos(y) = 1/4 \bigl[ 2\cos(y) + \cos(2x-y) + \cos (2x+y) \bigl] </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.0641em;vertical-align:-0.25em;"></span><span class="mop"><span class="mop">cos</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8141em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mord mathnormal">cos</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1.2em;vertical-align:-0.35em;"></span><span class="mord">1/4</span><span class="mopen"><span class="delimsizing size1">[</span></span><span class="mord">2</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop">cos</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mop">cos</span><span class="mopen">(</span><span class="mord">2</span><span class="mord mathnormal">x</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mop">cos</span><span class="mopen">(</span><span class="mord">2</span><span class="mord mathnormal">x</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:1.2em;vertical-align:-0.35em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="mclose">)</span><span class="mopen"><span class="delimsizing size1">]</span></span></span></span></span>, we have:</p> <script type="math/tex"> $$ \begin{aligned} \text{logits}^{a, b} & = \sum_{i=0}^{N-1} \biggl(\bigl[ \cos(a\omega-i\theta) + \cos(b\omega-i\theta) \bigl]^2 \cos(c\omega-2i\theta) \biggl) \\ & = \sum_{i=0}^{N-1} \biggl(2 \bigl[ \cos(\frac{a-b}{2}\omega)\cos(\frac{a+b}{2}\omega-i\theta) \bigl] ^2 \cos(c\omega-2i\theta) \biggl) \\ & = \cos^2(\frac{a-b}{2}\omega) \sum_{i=0}^{N-1} \biggl(2\cos(c\omega - 2i\theta) + \cos((a+b-c)\omega) +\cos((a+b+c)\omega-4i\theta) \biggl) \end{aligned} $$ </script> <p>Note that <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo>∑</mo><mi>cos</mi><mo>⁡</mo><mo stretchy="false">(</mo><msub><mi>γ</mi><mi>i</mi></msub><mo stretchy="false">)</mo><mo>=</mo><mn>0</mn></mrow><annotation encoding="application/x-tex">\sum \cos(\gamma_{i})=0</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mop op-symbol small-op" style="position:relative;top:0em;">∑</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop">cos</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05556em;">γ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:-0.0556em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">0</span></span></span></span> where <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>γ</mi><mi>i</mi></msub></mrow><annotation encoding="application/x-tex">\gamma_{i}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05556em;">γ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:-0.0556em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> is equally spread around the circle. The first and the third sum terms wrap around the circle with <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mn>2</mn><mi>θ</mi></mrow><annotation encoding="application/x-tex">2\theta</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord">2</span><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span></span></span></span> and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mn>4</mn><mi>θ</mi></mrow><annotation encoding="application/x-tex">4\theta</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord">4</span><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span></span></span></span> increments respectively. The sum of the first terms equals zero for <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>&gt;</mo><mn>2</mn></mrow><annotation encoding="application/x-tex">N \gt 2</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.7224em;vertical-align:-0.0391em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">&gt;</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">2</span></span></span></span> and the sum of the third terms equals zero for <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>&gt;</mo><mn>4</mn></mrow><annotation encoding="application/x-tex">N \gt 4</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.7224em;vertical-align:-0.0391em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">&gt;</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">4</span></span></span></span>. Therefore, we have:</p> <p><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msup><mtext>logits</mtext><mrow><mi>a</mi><mo separator="true">,</mo><mi>b</mi></mrow></msup><mo>=</mo><msup><mrow><mi>cos</mi><mo>⁡</mo></mrow><mn>2</mn></msup><mo stretchy="false">(</mo><mfrac><mrow><mi>a</mi><mo>−</mo><mi>b</mi></mrow><mn>2</mn></mfrac><mi>ω</mi><mo stretchy="false">)</mo><mi>cos</mi><mo>⁡</mo><mo stretchy="false">(</mo><mo stretchy="false">(</mo><mi>a</mi><mo>+</mo><mi>b</mi><mo>−</mo><mi>c</mi><mo stretchy="false">)</mo><mi>ω</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex"> \text{logits}^{a,b} = \cos^2(\frac{a-b}{2}\omega) \cos((a+b-c)\omega) </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.1279em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord text"><span class="mord">logits</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9334em;"><span style="top:-3.1473em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">a</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">b</span></span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1.2251em;vertical-align:-0.345em;"></span><span class="mop"><span class="mop">cos</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8141em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8801em;"><span style="top:-2.655em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">2</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">a</span><span class="mbin mtight">−</span><span class="mord mathnormal mtight">b</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mord mathnormal" style="margin-right:0.03588em;">ω</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop">cos</span><span class="mopen">((</span><span class="mord mathnormal">a</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.7778em;vertical-align:-0.0833em;"></span><span class="mord mathnormal">b</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">c</span><span class="mclose">)</span><span class="mord mathnormal" style="margin-right:0.03588em;">ω</span><span class="mclose">)</span></span></span></span></p> <p>Since the first term is a positive constant w.r.t inputs, the equation is maximized when <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>cos</mi><mo>⁡</mo><mo stretchy="false">(</mo><mo stretchy="false">(</mo><mi>a</mi><mo>+</mo><mi>b</mi><mo>−</mo><mi>c</mi><mo stretchy="false">)</mo><mi>ω</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\cos((a+b-c)\omega)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mop">cos</span><span class="mopen">((</span><span class="mord mathnormal">a</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.7778em;vertical-align:-0.0833em;"></span><span class="mord mathnormal">b</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">c</span><span class="mclose">)</span><span class="mord mathnormal" style="margin-right:0.03588em;">ω</span><span class="mclose">)</span></span></span></span> is maximized, which is when <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>c</mi><mo>=</mo><mi>a</mi><mo>+</mo><mi>b</mi><mtext> </mtext><mo lspace="0.22em" rspace="0.22em"><mrow><mi mathvariant="normal">m</mi><mi mathvariant="normal">o</mi><mi mathvariant="normal">d</mi></mrow></mo><mtext> </mtext><mi>M</mi></mrow><annotation encoding="application/x-tex">c = a + b \bmod M</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mord mathnormal">c</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6667em;vertical-align:-0.0833em;"></span><span class="mord mathnormal">a</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord mathnormal">b</span><span class="mspace" style="margin-right:0.0556em;"></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin"><span class="mord"><span class="mord mathrm">mod</span></span></span><span class="mspace" style="margin-right:0.0556em;"></span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span></span></span></span>.</p> <p>Essentially <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mtext>ReLU</mtext><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\text{ReLU}(x)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord text"><span class="mord">ReLU</span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span> activations with weight decay (a very typical model setup) gives the model an inductive bias that’s close enough to the exact generalizing solution of <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msup><mi>x</mi><mn>2</mn></msup></mrow><annotation encoding="application/x-tex">x^2</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8141em;"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8141em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span></span></span></span> activations with a sparse discrete Fourier transform to push in the direction of generalization but not so close that it won’t also learn to fit the training data with memorization.</p> <p><br></p> <h3 id="footnotes">Footnotes</h3> <p><a class='footend' key='modular'></a> In modular addition, we have two input numbers, <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>a</mi></mrow><annotation encoding="application/x-tex">a</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mord mathnormal">a</span></span></span></span> and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>b</mi></mrow><annotation encoding="application/x-tex">b</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord mathnormal">b</span></span></span></span>, and a modulus <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>m</mi></mrow><annotation encoding="application/x-tex">m</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mord mathnormal">m</span></span></span></span>. We want to find the remainder of <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>a</mi><mo>+</mo><mi>b</mi></mrow><annotation encoding="application/x-tex">a + b</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6667em;vertical-align:-0.0833em;"></span><span class="mord mathnormal">a</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord mathnormal">b</span></span></span></span> when divided by <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>m</mi></mrow><annotation encoding="application/x-tex">m</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mord mathnormal">m</span></span></span></span>. <span class='fn-break'></span> This type of addition is often called clock-face addition, because when adding two times, we often report the result modulo 12 (i.e. 5 hours after 8 o’clock is 1 o’clock). <span class='fn-break'></span> Modular addition sounds simple and it is. We can easily train 1,000s of models and treat them like fruit flies in neuroscience: small enough such that it is feasible to extract their <a href="https://www.science.org/doi/abs/10.1126/science.add9330">connectome</a> synapse-by-synapse, yet providing new interesting insights about the system more broadly. We can get a good understanding of the small models we’ve trained by visualizing all their internals.</p> <p><a class='footend' key='67'></a>67 isn’t a magic number – we could pick many numbers to illustrate grokking, but 67 is not so small that the task is trivial and also not so large that the visualizations are overwhelming.</p> <p><a class='footend' key='playground'></a> The model is trained with cross-entropy loss, AdamW and full batches. The <a href="#which-model-constraints-work-best-">section on regularization</a> and <a href="https://colab.research.google.com/github/PAIR-code/ai-explorables/blob/master/server-side/grokking/MLP_Modular_Addition.ipynb">training colab</a> have additional details. <span class='fn-break'></span> If you’re not familiar with <a href="https://en.wikipedia.org/wiki/Multilayer_perceptron">MLPs</a>, <a href="http://playground.tensorflow.org">playground.tensorflow.org</a> is a great place to start. <span class='fn-break'></span> A quick notation explanation: The columns of <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi mathvariant="bold">W</mi><mtext>input</mtext></msub></mrow><annotation encoding="application/x-tex">\mathbf{W}_{\text{input}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9722em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord mathbf" style="margin-right:0.01597em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3175em;"><span style="top:-2.55em;margin-left:-0.016em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">input</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span> and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi mathvariant="bold">W</mi><mtext>ouput</mtext></msub></mrow><annotation encoding="application/x-tex">\mathbf{W}_{\text{ouput}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9722em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord mathbf" style="margin-right:0.01597em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.016em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">ouput</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span> represent the numbers from 0 to 66. <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi mathvariant="bold">a</mi><mtext>one-hot</mtext></msub></mrow><annotation encoding="application/x-tex">\mathbf{a}_{\text{one-hot}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.5944em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathbf">a</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">one-hot</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi mathvariant="bold">b</mi><mtext>one-hot</mtext></msub></mrow><annotation encoding="application/x-tex">\mathbf{b}_{\text{one-hot}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8444em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathbf">b</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">one-hot</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> are how we <a href="https://en.wikipedia.org/wiki/One-hot">encode</a> the model’s inputs; each pick a single column from <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi mathvariant="bold">W</mi><mtext>input</mtext></msub></mrow><annotation encoding="application/x-tex">\mathbf{W}_{\text{input}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9722em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord mathbf" style="margin-right:0.01597em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3175em;"><span style="top:-2.55em;margin-left:-0.016em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">input</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span>. <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mtext>ReLU</mtext></mrow><annotation encoding="application/x-tex">\text{ReLU}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord text"><span class="mord">ReLU</span></span></span></span></span> replaces negative numbers with 0s; it is a <a href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks">fancy</a>) way of writing <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>max</mi><mo>⁡</mo><mo stretchy="false">(</mo><mi>x</mi><mo separator="true">,</mo><mn>0</mn><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\max(x, 0)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mop">max</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">0</span><span class="mclose">)</span></span></span></span>.</p> <p><a class='footend' key='sp-model'></a> With a small twist — we’re only outputting 1 or 0, so <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi mathvariant="bold">W</mi><mtext>output</mtext></msub></mrow><annotation encoding="application/x-tex">\mathbf{W}_{\text{output}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9722em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord mathbf" style="margin-right:0.01597em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.016em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">output</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span> can be a single column. In the modular addition task we needed a column for every output number. <span class='fn-break'></span> The last column of <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi mathvariant="bold">W</mi><mtext>input</mtext></msub></mrow><annotation encoding="application/x-tex">\mathbf{W}_{\text{input}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9722em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord mathbf" style="margin-right:0.01597em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3175em;"><span style="top:-2.55em;margin-left:-0.016em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">input</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span> is also fixed to 1 to provide a <a href="https://stackoverflow.com/questions/2480650/what-is-the-role-of-the-bias-in-neural-networks">bias term</a>.</p> <p><a class='footend' key='sp-solution'></a><a href="https://arxiv.org/pdf/2303.11873.pdf#page=8">Appendix D</a> of “A Tale of Two Circuits: Grokking as Competition of Sparse and Dense Subnetworks” has an explanation of the 4 neuron solution generalizing solution here</p> <p><a class='footend' key='loss'></a> So far we’ve been charting <a href="https://developers.google.com/machine-learning/crash-course/classification/accuracy">accuracy</a>, the percentage of sequences where the correct label is the most likely. Training typically instead optimizes a differentiable objective function. All the models in this post use <a href="https://ml-cheatsheet.readthedocs.io/en/latest/loss_functions.html">cross entropy loss</a> which heavily penalizes incorrect predictions with high probabilities. <span class='fn-break'></span> Note that while some formulations of loss include a weight decay or regularization term, the loss plots here depict the cross entropy component alone.</p> <p><a class='footend' key='sp-l2'></a> On the 1s and 0s task here, we use L1 weight decay <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>L</mi><mn>1</mn><mo stretchy="false">(</mo><mi mathvariant="bold">w</mi><mo stretchy="false">)</mo><mo>=</mo><msub><mo>∑</mo><mi>i</mi></msub><mi mathvariant="normal">∣</mi><msub><mi>w</mi><mi>i</mi></msub><mi mathvariant="normal">∣</mi></mrow><annotation encoding="application/x-tex">L1(\mathbf{w}) = \sum_{i} |w_i|</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">L</span><span class="mord">1</span><span class="mopen">(</span><span class="mord mathbf" style="margin-right:0.01597em;">w</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1.0497em;vertical-align:-0.2997em;"></span><span class="mop"><span class="mop op-symbol small-op" style="position:relative;top:0em;">∑</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.162em;"><span style="top:-2.4003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2997em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">∣</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02691em;">w</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:-0.0269em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord">∣</span></span></span></span>. <span class='fn-break'></span> L2 weight decay <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>L</mi><mn>2</mn><mo stretchy="false">(</mo><mi mathvariant="bold">w</mi><mo stretchy="false">)</mo><mo>=</mo><msub><mo>∑</mo><mi>i</mi></msub><msubsup><mi>w</mi><mi>i</mi><mn>2</mn></msubsup></mrow><annotation encoding="application/x-tex">L2(\mathbf{w}) = \sum_{i} w_i^2</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">L</span><span class="mord">2</span><span class="mopen">(</span><span class="mord mathbf" style="margin-right:0.01597em;">w</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1.1138em;vertical-align:-0.2997em;"></span><span class="mop"><span class="mop op-symbol small-op" style="position:relative;top:0em;">∑</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.162em;"><span style="top:-2.4003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2997em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02691em;">w</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8141em;"><span style="top:-2.4413em;margin-left:-0.0269em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2587em;"><span></span></span></span></span></span></span></span></span></span> is a more typical choice. It pushes for <a href="https://explained.ai/regularization/L1vsL2.html">lots of small weights</a> leading to redundant neurons on this task: <span class='fn-break'></span> <img src='img/sp-l2.gif'></img></p> <p><a class='footend' key='overfit'></a> A model overfits the training data when it performs well on the training data but poorly on the test data — this is what we see with our memorizing models. In general, simpler models are less prone to overfitting as, due to their simplicity, decision rules are coarser and are required to make more generalizations. Of course, if a model is too simple for a task, it may not be able to learn good decision rules that capture the nuances of the task. Researchers force models to be simpler through a variety of techniques, including having models with fewer parameters or encouraging the parameters that the model does have to be small in size with weight decay.</p> <p><a class='footend' key='unit-circle'></a> Computing <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>cos</mi><mo>⁡</mo><mo stretchy="false">(</mo><mi>i</mi><mfrac><mrow><mn>2</mn><mi>π</mi></mrow><mn>67</mn></mfrac><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\cos(i\frac{2\pi}{67})</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.1901em;vertical-align:-0.345em;"></span><span class="mop">cos</span><span class="mopen">(</span><span class="mord mathnormal">i</span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8451em;"><span style="top:-2.655em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">67</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">2</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">π</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose">)</span></span></span></span> and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>sin</mi><mo>⁡</mo><mo stretchy="false">(</mo><mi>i</mi><mfrac><mrow><mn>2</mn><mi>π</mi></mrow><mn>67</mn></mfrac><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\sin (i\frac{2\pi}{67})</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.1901em;vertical-align:-0.345em;"></span><span class="mop">sin</span><span class="mopen">(</span><span class="mord mathnormal">i</span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8451em;"><span style="top:-2.655em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">67</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">2</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">π</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose">)</span></span></span></span> gives us points evenly spaced around the unit circle. <span class='fn-break'></span> Here’s what <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mtext mathvariant="bold">W</mtext><mtext>embed</mtext></msub></mrow><annotation encoding="application/x-tex">\textbf{W}_{\text{embed}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8361em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord text"><span class="mord textbf">W</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">embed</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> looks like on the unit circle: <span class='fn-break'></span> <img src='img/w_embed.png' width=319></img></p> <p><a class='footend' key='dft'></a> The <a href="https://www.youtube.com/watch?v=spUNpyF58BY">Discrete Fourier Transform</a> helps analyze the periodic nature of a sequence of values (in this case the <a href="https://colab.research.google.com/drive/1F6_1_cWXE5M7WocUcpQWp3v8z4b1jL20#scrollTo=iSPxi3ElsujY">weights for a particular neuron</a>) by breaking it down into sine and cosine functions. The more periodic a function is, the easier it is to represent with sine and cosines, and the sparser the output of the DFT.</p> <p><a class='footend' key='dft-sort'></a> We’ve reindexed the neurons by their final frequency and phase to make this grouping easier to see .</p> <p><a class='footend' key='logit-wave'></a> The model generates probabilities by taking the dot product of the neuron activations for a given input with <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mtext mathvariant="bold">W</mtext><mtext>output</mtext></msub></mrow><annotation encoding="application/x-tex">\textbf{W}_{\text{output}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9722em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord text"><span class="mord textbf">W</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">output</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span> and softmaxing. If we calculate the dot product using only the activations from neurons of a single frequency, we can see which outputs the frequency group is making more or less likely. <span class='fn-break'></span> <a href="#appendix-a-how-the-circular-construction-works">Appendix A</a> explains why these logits form a wave — each group of frequencies is essentially outputting how close the correct answer is to every number on a version of <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mtext mathvariant="bold">W</mtext><mtext>embed</mtext></msub></mrow><annotation encoding="application/x-tex">\textbf{W}_{\text{embed}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8361em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord text"><span class="mord textbf">W</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">embed</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> with the group’s frequency.</p> <p><a class='footend' key='open-q-mem'></a>Both of these models are <a href="https://colab.sandbox.google.com/github/PAIR-code/ai-explorables/blob/master/server-side/grokking/MLP_Modular_Addition.ipynb#scrollTo=5hJqK4jx0vC7">quite small</a>. The bottom model has tweaked hyperparameters to encourage eventual generalization: it’s slightly larger to allow it to exit local minimums, it has more training data (making low loss memorizing solutions harder to find) and it has weight decay.</p> <h3 id="references">References</h3> <p><a class='citeend' key='Grokking'></a> <a href="https://arxiv.org/pdf/2201.02177.pdf">Grokking: Generalization Beyond Overfitting On Small Algorithmic Datasets</a> Power, A., Burda, Y., Edwards, H., Babuschkin, I., &amp; Misra, V. (2022). arXiv preprint arXiv:2201.02177.</p> <p><a class='citeend' key='Omnigrok'></a> <a href="https://arxiv.org/pdf/2210.01117.pdf">Omnigrok: Grokking Beyond Algorithmic Data</a> Liu, Z., Michaud, E. J., &amp; Tegmark, M. (2022, September). In The Eleventh International Conference on Learning Representations.</p> <p><a class='citeend' key='Universality'></a> <a href="https://arxiv.org/abs/2302.03025">A Toy Model of Universality: Reverse Engineering How Networks Learn Group Operations</a> Chughtai, B., Chan, L., Nanda, N. (2023). International Conference on Machine Learning.</p> <p><a class='citeend' key='Zhong23'></a><a href="https://arxiv.org/pdf/2306.17844.pdf">The Clock and the Pizza: Two Stories in Mechanistic Explanation of Neural Networks</a> Zhong, Z., Liu, Z., Tegmark, M., &amp; Andreas, J. (2023). arXiv preprint arXiv:2306.17844.</p> <p><a class='citeend' key='ProgressParity'></a> <a href="https://arxiv.org/abs/2207.08799">Hidden Progress in Deep Learning: SGD Learns Parities Near the Computational Limit</a> Boaz Barak, Benjamin L. Edelman, Surbhi Goel, Sham Kakade, Eran Malach, Cyril Zhang. (2022) Advances in Neural Information Processing Systems, 35, 21750-21764.</p> <p><a class='citeend' key='gromov'></a><a href="https://arxiv.org/abs/2301.02679">Grokking modular arithmetic</a> Andrey Gromov (2023). arXiv preprint arXiv:2301.02679.</p> <p><a class='citeend' key='Parrots'></a><a href="https://dl.acm.org/doi/pdf/10.1145/3442188.3445922?uuid=f2qngt2LcFCbgtaZ2024">On the Dangers of Stochastic Parrots: Can Language Models Be Too Big?🦜</a> Bender, E. M., Gebru, T., McMillan-Major, A., &amp; Shmitchell, S. (2021, March). <em>In Proceedings of the 2021 ACM conference on fairness, accountability, and transparency</em> (pp. 610-623).</p> <p><a class='citeend' key='Othello'></a> <a href="https://openreview.net/pdf?id=DeG07_TcZvT">Emergent World Representations: Exploring a Sequence Model Trained on a Synthetic Task</a> Li, K., Hopkins, A. K., Bau, D., Viégas, F., Pfister, H., &amp; Wattenberg, M. (2022, September). <em>In The Eleventh International Conference on Learning Representations</em>.</p> <p><a class='citeend' key='MechInterp'></a> <a href="https://transformer-circuits.pub/2022/mech-interp-essay/index.html">Mechanistic Interpretability, Variables, and the Importance of Interpretable Bases</a> Olah, C., 2022. Transformer Circuits Thread.</p> <p><a class='citeend' key='ProgressMeasures'></a> <a href="https://openreview.net/pdf?id=9XFSbDPmdW">Progress Measures for Grokking via Mechanistic Interpretability</a> Nanda, N., Chan, L., Lieberum, T., Smith, J., &amp; Steinhardt, J. (2022, September). In The Eleventh International Conference on Learning Representations.</p> <p><a class='citeend' key='TwoCircuits'></a> <a href="https://arxiv.org/abs/2303.11873">A Tale of Two Circuits: Grokking as Competition of Sparse and Dense Subnetworks</a> William Merrill, Nikolaos Tsilivis, Aman Shukla. (2023). arXiv preprint arXiv:2303.11873.</p> <p><a class='citeend' key='DoubleDescent'></a><a href="https://arxiv.org/pdf/2303.06173.pdf">Unifying Grokking and Double Descent</a> Davies, X., Langosco, L., &amp; Krueger, D. (2022, November). In NeurIPS ML Safety Workshop.</p> <p><a class='citeend' key='double-demystified'></a><a href="https://arxiv.org/abs/2303.14151">Double Descent Demystified: Identifying, Interpreting &amp; Ablating the Sources of a Deep Learning Puzzle</a> Rylan Schaeffer, R., Khona, M., Robertson, Z., Boopathy, A., Pistunova, K., Rocks, J., Rani Fiete, I., &amp; Koyejo, O. (2023). arXiv preprint arXiv:2303.14151.</p> <p><a class='citeend' key='Slingshot'></a> <a href="https://arxiv.org/pdf/2206.04817.pdf">The Slingshot Mechanism: An Empirical Study of Adaptive Optimizers and the Grokking Phenomenon</a> Thilak, V., Littwin, E., Zhai, S., Saremi, O., Paiss, R., &amp; Susskind, J. (2022). arXiv preprint arXiv:2206.04817.</p> <p><a class='citeend' key='EffectiveTheory'></a><a href="https://arxiv.org/pdf/2205.10343.pdf">Towards Understanding Grokking: An Effective Theory of Representation Learning</a> Liu, Z., Kitouni, O., Nolte, N. S., Michaud, E., Tegmark, M., &amp; Williams, M. (2022). Advances in Neural Information Processing Systems, 35, 34651-34663.</p> <p><a class='citeend' key='Goldilocks'></a><a href="https://arxiv.org/pdf/1807.02581.pdf">The Goldilocks Zone: Towards Better Understanding of Neural Network Loss Landscapes</a> Fort, S., &amp; Scherlis, A. (2019, July). In Proceedings of the AAAI conference on artificial intelligence (Vol. 33, No. 01, pp. 3574-3581).</p> <p><a class='citeend' key='quantization'></a><a href="https://arxiv.org/abs/2303.13506">The Quantization Model of Neural Scaling</a> Eric J. Michaud, Ziming Liu, Uzay Girit, Max Tegmark, O. (2023). arXiv preprint arXiv:2303.13506.</p> <p><a class='citeend' key='StructuralGrokking'></a> <a href="https://arxiv.org/pdf/2305.18741.pdf">Grokking of Hierarchical Structure in Vanilla Transformers</a> Murty, S., Sharma, P., Andreas, J., &amp; Manning, C. D. (2023). arXiv preprint arXiv:2305.18741.</p> <p><a class='citeend' key='PredictingGrokking'></a> <a href="https://arxiv.org/pdf/2306.13253.pdf">Predicting Grokking Long Before it Happens: A Look Into the Loss Landscape of Models Which Grok</a> Notsawo Jr, P., Zhou, H., Pezeshki, M., Rish, I., &amp; Dumas, G. (2023). arXiv preprint arXiv:2306.13253.</p> <p><a class='citeend' key='explain'></a><a href="https://openaipublic.blob.core.windows.net/neuron-explainer/paper/index.html">Language models can explain neurons in language models</a> Bills, S., Cammarata, N., Mossing, D., Tillman, H., Gao, L., Goh, G., Sutskever, I., Leike, J., Wu, J., &amp; Saunders, W. 2023. OpenAI Blog</p> <p><a class='citeend' key='multiple-choice'></a><a href="https://arxiv.org/abs/2307.09458">Does Circuit Analysis Interpretability Scale? Evidence from Multiple Choice Capabilities in Chinchilla</a> Tom Lieberum, Matthew Rahtz, János Kramár, Neel Nanda, Geoffrey Irving, Rohin Shah, Vladimir Mikulik (2023). arXiv preprint arXiv:2307.09458.</p> <p><a class='citeend' key='TMOS'></a><a href="https://transformer-circuits.pub/2022/toy_model/index.html">Toy Models of Superposition</a> Elhage, N., Hume, T., Olsson, C., Schiefer, N., Henighan, T., Kravec, S., Hatfield-Dodds, Z., Lasenby, R., Drain, D., Chen, C., Grosse, R., McCandlish, S., Kaplan, J., Amodei, D., Wattenberg, M. and Olah, C., 2022. Transformer Circuits Thread.</p> <p><a class='citeend' key='Connectome'></a> <a href="https://www.science.org/doi/abs/10.1126/science.add9330">The Connectome of an Insect Brain</a> Winding, M., Pedigo, B. D., Barnes, C. L., Patsolic, H. G., Park, Y., Kazimiers, T., … &amp; Zlatic, M. (2023). Science, 379(6636), eadd9330.</p> <p><a class='citeend' key='Multiscale'></a> <a href="https://proceedings.mlr.press/v162/pezeshki22a/pezeshki22a.pdf">Multi-Scale Feature Learning Dynamics: Insights for Double Descent</a> Pezeshki, M., Mitra, A., Bengio, Y., &amp; Lajoie, G. (2022, June). In the International Conference on Machine Learning (pp. 17669-17690). PMLR.</p> <p><a class='citeend' key='superposition'></a><a href="https://transformer-circuits.pub/2023/toy-double-descent/index.html">Superposition, Memorization, and Double Descent</a> Henighan, T., Carter, S., Hume, T., Elhage, N., Lasenby, R., Fort, S., Schiefer, N., and Olah, C., 2023. Transformer Circuits Thread.</p> <h3 id="more-explorables">More Explorables</h3> <p><p id='recirc'></p></p> <div class='recirc-feedback-form'></div> <p><link rel='stylesheet' href='../third_party/footnote_v2.css'></p> <p><link rel='stylesheet' href='../third_party/citation_v2.css'></p> <link rel='stylesheet' href='style.css'> <script id='MathJax-script' async src='https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js'></script> <script defer src='https://cdn.jsdelivr.net/npm/katex@0.16.8/dist/contrib/mathtex-script-type.min.js' integrity='sha384-jiBVvJ8NGGj5n7kJaiWwWp9AjC+Yh8rhZY3GtAX8yU28azcLgoRo4oukO87g7zDT' crossorigin='anonymous'></script> <script src='../third_party/d3_.js'></script> <script src='../third_party/d3-scale-chromatic.v1.min.js'></script> <script src='../third_party/tfjsv3.18.0.js'></script> <script src='../third_party/npyjs-global.js'></script> <script src='../third_party/swoopy-drag.js'></script> <script src='../third_party/footnote_v2.js'></script> <script src='../third_party/citation_v2.js'></script> <script src='util.js'></script> <script src='init-accuracy-chart.js'></script> <script src='init-animate-steps.js'></script> <script src='init-embed-vis.js'></script> <script src='init-input-sliders.js'></script> <script src='init-swoopy.js'></script> <p><link rel='stylesheet' href='mod-top/style.css'></p> <script src='mod-top/init-waves.js'></script> <script src='mod-top/init.js'></script> <p><link rel='stylesheet' href='sparse-parity/style.css'></p> <script src='sparse-parity/init.js'></script> <script src='sparse-parity/init-weight-trajectory.js'></script> <p><link rel='stylesheet' href='sweep-sparse-parity/style.css'></p> <script defer src='sweep-sparse-parity/init.js'></script> <p><link rel='stylesheet' href='sweep-mod/style.css'></p> <script defer src='sweep-mod/sweep-mod-charts.js'></script> <script defer src='sweep-mod/init.js'></script> <p><link rel='stylesheet' href='hand-weights/style.css'></p> <p><link rel='stylesheet' href='hand-weights/sliders.css'></p> <script src='hand-weights/init-embed-vis.js'></script> <script src='hand-weights/init-activation-vis.js'></script> <script src='hand-weights/init-circle-weights-vis.js'></script> <script src='hand-weights/init-circle-input-vis.js'></script> <script src='hand-weights/init-circle-weights-freq.js'></script> <script defer src='hand-weights/init.js'></script> <p><link rel='stylesheet' href='five-neurons/style.css'></p> <script src='five-neurons/five-circle.js'></script> <script defer src='five-neurons/init.js'></script> <p><link rel='stylesheet' href='mod-bot/style.css'></p> <script src='mod-bot/init-bot-freqs.js'></script> <script src='mod-bot/init-bot-logits.js'></script> <script src='mod-bot/seeds/init-seeds.js'></script> <script defer src='mod-bot/init.js'></script> <script src='open-q-mem/init-0.js'></script> <script src='open-q-mem/init-1.js'></script> <script defer src='circle-freq/circle-freq-init.js'></script> <p><link rel='stylesheet' href='appendix/style.css'></p> <script src='appendix/debug-vis.js'></script> <script src='appendix/init-circle-weights-vis.js'></script> <script src='appendix/init-proj-vis.js'></script> <script src='appendix/init-sliders.js'></script> <script src='appendix/line-error-vis.js'></script> <script src='appendix/debug-reuleaux.js'></script> <script defer src='appendix/init.js'></script> <script defer src='../third_party/recirc.js'></script> </body> <script async src="https://www.googletagmanager.com/gtag/js?id=UA-138505774-1"></script> <script> if (window.location.origin === 'https://pair.withgoogle.com'){ window.dataLayer = window.dataLayer || []; function gtag(){dataLayer.push(arguments);} gtag('js', new Date()); gtag('config', 'UA-138505774-1'); } </script> <script> // Tweaks for displaying in an iframe if (window !== window.parent){ // Open links in a new tab Array.from(document.querySelectorAll('a')) .forEach(e => { // skip anchor links if (e.href && e.href[0] == '#') return e.setAttribute('target', '_blank') e.setAttribute('rel', 'noopener noreferrer') }) // Remove recirc h3 Array.from(document.querySelectorAll('h3')) .forEach(e => { if (e.textContent != 'More Explorables') return e.parentNode.removeChild(e) }) // Remove recirc container var recircEl = document.querySelector('#recirc') recircEl.parentNode.removeChild(recircEl) } </script> </html>

Pages: 1 2 3 4 5 6 7 8 9 10