miiii
MECHANISTIC INTERPRETABILITY ON IRREDUCIBLE INTEGER IDENTIFIERS
Noah Syrkis & Anders Søgaard
November 29, 2024

Introduction

Recent years have seen deep learning (DL) models achieve remarkable proficiency in complex computational tasks, including protein structure prediction [1], strategic reasoning [2], and natural language generation [3]—areas previously thought to be the exclusive domain of human intelligence. Traditional (symbolic) programming allows functions like 𝑓(𝑥, 𝑦) = cos(𝑎𝑥) + sin(𝑏𝑦) to be implemented in code with clear typographical isomorphism—meaning the code's structure directly mirrors the mathematical notation. For example, in the language Haskell: f x y = cos(a * x) + sin(b * y). In contrast, DL models are inherently sub-symbolic, meaning that the models' atomic constituents (often 32-bit floating-point numbers centered around 0) do not map directly to mathematical vocabulary. For reference, a DL-based sketch of the aforementioned function is given below. Indeed, the increasing prevalence of DL can be understood as a transition from symbolic to sub-symbolic algorithms.

Precursors to modern DL methods learned how to weigh human-designed features [4], with later works learning to create features from data to subsequently weigh [5], [6]—in combination with tree-search strategies, in the case of games [7]. Very recent DL work has even eliminated tree search in the case of chess, mapping directly from observation space to action space [8]. Pure DL methods are thus becoming ubiquitous but remain largely inscrutable, with recent works still attempting to define what interpretability even means in the DL context [9]. Given the sub-symbolic nature of DL models, it is unsurprising that their interpretation remains difficult.

Mathematically, DL refers to a set of methods that combine linear maps (matrix multiplications) with non-linearities (activation functions). Formally, all the potential numerical values of a given model's weights 𝑊 can be thought of as a hypothesis space ℋ. Often, ℋ is determined by human decisions (number of layers, kinds of layers, sizes of layers, etc.). ℋ is then navigated using some optimization heuristic, such as gradient descent, in the hope of finding a 𝑊 that "performs well" (i.e., successfully minimizes some loss, often computed by a differentiable function with respect to 𝑊) on whatever training data is present. This vast, sub-symbolic hypothesis space, while enabling impressive performance and the solving of relatively exotic¹ tasks, makes it challenging to understand how any one particular solution actually works (i.e., a black-box algorithm).

The ways in which a given model can minimize its loss can be placed on a continuum: on one side, we have overfitting, i.e., remembering the training data (functioning as an archive, akin to lossy or even lossless compression); on the other, we have generalization, i.e., learning the rules that govern the relationship between input and output (functioning as an algorithm).

Attempting to give a mechanistic explanation of a given DL model's behavior necessarily entails the existence of a mechanism. Mechanistic interpretability (MI) assumes this mechanism to be general, thus making generalization a necessary (though insufficient) condition. Generalization ensures that there is a mechanism/algorithm present to be uncovered (necessity); however, it is possible for that algorithm to be so obscurely implemented that reverse engineering, for all intents and purposes, is impossible (insufficiency).
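As a minimal illustration of what "sub-symbolic" means in practice, the sketch below fits a small two-layer network to the function above by gradient descent. The architecture, sizes, and hyper-parameters are arbitrary choices made for exposition (this is not an implementation from the paper); the point is that the fitted behavior ends up spread across thousands of floats, none of which corresponds to 𝑎, 𝑏, cos, or sin.

```python
import numpy as np

a, b = 2.0, 3.0
f = lambda x, y: np.cos(a * x) + np.sin(b * y)           # the symbolic target function

rng = np.random.default_rng(0)
X = rng.uniform(-np.pi, np.pi, (1024, 2))                 # training inputs (x, y)
Y = f(X[:, 0], X[:, 1])[:, None]                          # targets

# A tiny two-layer perceptron: 2 -> 64 -> 1 with tanh non-linearity (arbitrary sizes).
W1, b1 = rng.normal(0, 0.5, (2, 64)), np.zeros(64)
W2, b2 = rng.normal(0, 0.5, (64, 1)), np.zeros(1)

lr = 1e-2
for _ in range(5000):                                     # plain gradient descent on MSE
    H = np.tanh(X @ W1 + b1)                              # hidden activations
    P = H @ W2 + b2                                       # predictions
    dP = 2 * (P - Y) / len(X)                             # dL/dP for mean squared error
    dW2, db2 = H.T @ dP, dP.sum(0)
    dH = dP @ W2.T * (1 - H ** 2)                         # back-propagate through tanh
    dW1, db1 = X.T @ dH, dH.sum(0)
    W1, b1, W2, b2 = W1 - lr * dW1, b1 - lr * db1, W2 - lr * dW2, b2 - lr * db2
```

After training, the function's behavior is distributed across W1, b1, W2, and b2, with no single parameter corresponding to 𝑎, 𝑏, cos, or sin; this is precisely the sub-symbolic property that makes interpretation hard.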
Various forms of regularization are used to incentivize the emergence of algorithmic (generalized) and interpretable behavior, rather than archiving (overfitted) behavior [10], [11], [12].

As of yet, no MI work has explored the effect of multi-task learning, the focus of this paper. Multi-task learning also has a regularizing effect [13]. Formally, each task in a set of tasks (a set often called an environment) has its own hypothesis space. When minimizing the losses across all tasks in parallel, generalizing 𝑊's are incentivized, as these help lower the loss across tasks (in contrast to memorizing 𝑊's, which lower the loss for one task only). A 𝑊 derived from a multi-task training process can thus be thought of as lying in the intersection of the high-performing regions of all the tasks' hypothesis spaces.

In this spirit, the present paper builds on the work of Nanda et al. (2023), which trains a transformer [15] model to perform modular addition, as seen in Eq. 1. The task is denoted 𝒯nanda throughout the paper.

(x_0 + x_1) \bmod p, \quad x_0, x_1 < p, \quad p = 113    (1)

The task of this paper instead focuses on predicting remainders modulo all primes 𝑞 less than 𝑝, where 𝑥 is interpreted as 𝑥0𝑝⁰ + 𝑥1𝑝¹, formally shown in Eq. 2, and is referred to as 𝒯miiii:

(x_0 p^0 + x_1 p^1) \bmod q, \quad x_0, x_1 < p, \quad q < p, \quad p = 113    (2)

𝒯miiii differentiates itself from 𝒯nanda in two significant ways: 1) it is non-commutative, and 2) it is, as mentioned, multi-task. These differences present unique challenges for mechanistic interpretation, as the model must both learn to handle the order-dependent nature of the inputs and develop shared representations across multiple modular arithmetic tasks. Further, as 𝒯miiii is harder than 𝒯nanda, the model can be expected to generalize more slowly when trained on the former. Therefore, Lee et al. (2024)'s recent method for speeding up generalization, which posits that the gradients of the model parameters through time can be viewed as the sum of 1) a slow-varying generalizing component (which is boosted) and 2) a fast-varying overfitting component (which is suppressed), is (successfully) replicated here to make training tractable.

Figure 1: Visualizing natural numbers less than 12769 in polar coordinates (𝑛, 𝑛 mod 2𝜋). Left: union of numbers with remainder 0 mod 17 and mod 23 (note the two spirals). Middle: numbers with remainder 0 mod 11. Right: prime numbers. The figure is shown here to encourage the reader to think in periodic terms.

More generally, modular arithmetic on primes is a particularly useful task for MI, as it ensures uniformity among the output classes, allows for comparison with other MI work [14], and, from a number-theoretic point of view, primes contain mysteries ranging from the trivially solved—are there infinitely many primes?—to the deceptively difficult—can every even number larger than 4 be written as the sum of two primes? The latter, known as Goldbach's Conjecture, remains unsolved after centuries. The choice of using every prime less than the square root of the largest number in the dataset also serves the following purpose: to test whether a given natural number is prime, it suffices to test that it is not a multiple of any prime less than its square root. The set of tasks trained for here can thus be viewed in conjunction as a single prime-detection task (primes are the only samples whose target vector contains no zeros, since a prime is not a multiple of any of the factors 𝑞). There are about 𝑛/ln(𝑛) primes less than 𝑛.
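To make Eq. 2 concrete, the following sketch enumerates the full input space for a given 𝑝 and builds one remainder target per prime 𝑞 < 𝑝 (illustrative code only; variable names and structure are not taken from the released miiii package):

```python
import numpy as np

p = 113

def primes_below(n):
    """Simple sieve of Eratosthenes."""
    sieve = np.ones(n, dtype=bool)
    sieve[:2] = False
    for i in range(2, int(n ** 0.5) + 1):
        if sieve[i]:
            sieve[i * i::i] = False
    return np.flatnonzero(sieve)

qs = primes_below(p)                           # the 29 primes below 113
x0, x1 = np.meshgrid(np.arange(p), np.arange(p), indexing="ij")
x0, x1 = x0.ravel(), x1.ravel()                # all p**2 = 12769 digit pairs (x0, x1)
n = x0 * p ** 0 + x1 * p ** 1                  # the number each pair represents (Eq. 2)
X = np.stack([x0, x1], axis=1)                 # model inputs: the two base-p digits
Y = np.stack([n % q for q in qs], axis=1)      # one remainder task per prime q < p

assert len(qs) == 29 and Y.shape == (p ** 2, 29)
coprime_to_all_qs = (Y != 0).all(axis=1)       # samples not divisible by any prime q < p
```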
To provide insight into the periodic structure of these remainders for natural numbers less than 12769 (and to motivate thinking in rotational terms), Figure 1 visualizes various modular patterns in polar coordinates (𝑛, 𝑛 mod 2𝜋). One could imagine tightening or loosening the spiral by multiplying 2𝜋 by a constant so as to align the multiples of a given number in a straight line (imagining this is encouraged).

Background and related work

Multiple papers describe the use of deep learning to detect prime numbers [17], [18], [19]. None are particularly promising as prime-detection algorithms: they do not provide speedups, they use more memory, and they are less accurate than traditional methods. However, in exploring the foundations of deep learning, the task of prime detection is interesting, as it is a simple task that is difficult to learn, and it is synthetic, meaning that arbitrary amounts of data can be generated by a simple algorithm.

Mechanistic Interpretability (MI)

MI is a relatively new field focused on reverse-engineering the internal mechanisms of neural networks. Lipton (2018) explored different definitions of interpretability in this context. MI can be contrasted with other forms of interpretability, such as feature-importance analysis: while feature importance measures correlations between inputs and outputs (e.g., red pixels correlating with "rose" classifications), MI aims to understand how the model actually processes information (i.e., the mechanism).

Methods and tools used so far in MI include activation visualization across ordered samples, singular value decomposition of weight matrices, and ablation studies to identify critical circuits. Conmy et al. (2023) even successfully automate circuit² discovery. Many reverse-engineering methods from other fields, such as computational neuroscience or signal processing, almost certainly have their uses here as well.

In spite of deep learning's practical successes, uncertainty remains about its theoretical underpinnings, echoing the interpretability debate. Recent work attempts to place different DL architectures and concepts in either a geometric [21], information-theoretic [22], or even category-theoretic [23] context. However, no unified theory has emerged. Much interesting deep learning research thus focuses on practical, simple, or algorithmic tasks with known solutions and architectures. For example, grokking [24], the (relatively) sudden generalization after overfitting, elaborated on later, is a recent and practical discovery.

Case study: modular addition

One such practical discovery is made by Nanda et al. (2023). A single-layer transformer model with ReLU activation function was trained to perform modular addition (𝒯nanda). Nanda et al. (2023)'s analysis of their trained model exemplifies MI methodology. They discovered that: 1) the embedding layer learns trigonometric lookup tables of sine and cosine values as per Eq. 3; 2) the feed-forward network combines these through multiplication and trigonometric identities (Eq. 4); and 3) the final layer performs the equivalent of an argmax (Eq. 5).

x_0 \to \sin(w x_0), \cos(w x_0)    (3.1)
x_1 \to \sin(w x_1), \cos(w x_1)    (3.2)

\sin(w(x_0 + x_1)) = \sin(w x_0)\cos(w x_1) + \cos(w x_0)\sin(w x_1)    (4.1)
\cos(w(x_0 + x_1)) = \cos(w x_0)\cos(w x_1) - \sin(w x_0)\sin(w x_1)    (4.2)

\mathrm{Logit}(c) \propto \cos(w(x_0 + x_1 - c))    (5.1)
= \cos(w(x_0 + x_1))\cos(w c) + \sin(w(x_0 + x_1))\sin(w c)    (5.2)
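As a quick numerical sanity check of this construction (a sketch with arbitrarily chosen frequencies, not weights read from the trained model), the cosine logits of Eq. 5.1 do indeed peak at the correct residue:

```python
import numpy as np

p = 113
ks = np.array([3, 17, 42])                     # hypothetical frequency indices; w = 2*pi*k/p
ws = 2 * np.pi * ks / p

def logits(x0, x1):
    c = np.arange(p)                           # candidate residues
    # Eq. 5.1 summed over the chosen frequencies: cos(w(x0 + x1 - c))
    return np.cos(ws[:, None] * (x0 + x1 - c)[None, :]).sum(axis=0)

for x0, x1 in [(0, 0), (47, 101), (112, 112)]:
    assert int(np.argmax(logits(x0, x1))) == (x0 + x1) % p
```

Eq. 5.2 is simply the angle-difference expansion of the same quantity, which is the form the network can assemble from the products in Eq. 4.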
Generalization and grokking

Power et al. (2022) show that generalization can happen "[...] well past the point of overfitting", dubbing the phenomenon "grokking". The phenomenon is now well established [14], [25], [26]. Nanda et al. (2023) show that a generalized circuit "arises from the gradual amplification of structured mechanisms encoded in the weights," rather than from a relatively sudden and stochastic encounter with an appropriate region of ℋ. The important word of the quote is thus "gradual".

By regarding the series of gradients through time as a stochastic signal, Lee et al. (2024) propose decomposing that signal. Conceptually, Lee et al. (2024) argue that in the case of gradient descent, the ordered sequence of gradient updates can be viewed as consisting of two components: 1) a fast-varying overfitting component, and 2) a slow-varying generalizing component. The general algorithm explaining the relationship between input and output is the same for all samples, whereas the weight updates that memorize particular samples are unique to those samples. Though not proven, this intuition bears out, in that generalization is sped up fifty-fold in some cases.

This echoes the idea that generalized circuits go through gradual amplification [14]. To the extent that this phenomenon is widespread, it bodes well for generalizable DL, in that the generalizing signal one would want to amplify might exist long before the model is fully trained and could potentially be boosted in a targeted way by the method described by Lee et al. (2024).

Perhaps the most widespread loss functions used in deep learning are mean cross-entropy, Eq. 6.1 (for classification), and mean squared error, Eq. 6.2 (for regression).

L_{MCE} = \frac{1}{n} \sum_{i=1}^{n} \sum_{j=1}^{k} y_{ij} \ln\frac{1}{\hat{y}_{ij}}    (6.1)
L_{MSE} = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2    (6.2)

These have various computational and mathematical properties that make them convenient to use, though they have been shown to struggle with generalizing to out-of-distribution data [27], [22].

Multi-task learning in deep learning

As stated, multi-task learning has been shown to have a regularizing effect [13], [28], as a hypothesis 𝑊 that performs well across all of the tasks' hypothesis spaces is more likely to be general. Viewed information-theoretically, this concept is reminiscent of Shannon (2001)'s asymptotic equipartition property [30] or, even more generally, of the law of large numbers, which states that the more samples we have from a distribution, the closer our estimates of its properties will be to the true underlying properties.

In the context of 𝒯miiii, multi-task learning is done by having the last layer output predictions for all tasks in parallel. Thus, whereas 𝒯nanda outputs a single one-hot 1×113 vector over the potential remainders, 𝒯miiii, as we shall see, outputs a 1×𝑞 vector for each prime 𝑞 < 𝑝 (i.e., 29 output-task vectors when 𝑝 = 113). The embedding layer and the transformer block are thus shared across all tasks, meaning that representations that perform well across tasks are incentivized.

Transformer architecture

Transformers combine self-attention (a communication mechanism) with feed-forward layers (a computation mechanism). The original transformer block [15] used extensive regularization—layer norm [10], dropout, weight decay, and residual connections are all integral components of the original architecture, though recent years have seen simplifications yielding similar performance [31], [32].

Input tokens are embedded into a 𝑑-dimensional space using learned token and positional embeddings:

z = \mathrm{TokenEmbed}(x) + \mathrm{PosEmbed}(\mathrm{pos})    (7)

Each transformer block comprises multi-head attention:

\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^{T}}{\sqrt{d_k}}\right) V    (8)

where 𝑄, 𝐾, and 𝑉 are linear projections of the input. Attention heads are combined through addition rather than concatenation (a transformer-specific detail chosen to align with Nanda et al. (2023)). This is followed by a feed-forward network with ReLU activation:

\mathrm{FFN}(z) = \mathrm{ReLU}(z W_{in}) W_{out}    (9)

mapping from 𝑑 → 4𝑑 → 𝑑 dimensions, before finally:

\hat{y} = z W_{unembed}    (10)

Each component includes residual connections and dropout.
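For concreteness, a minimal numpy sketch of this single-block forward pass is given below. It is illustrative only and not the released miiii implementation: dropout is omitted, a single unembedding matrix stands in for the per-task output blocks described above, and the index used for the ŷ query token is a hypothetical choice.

```python
import numpy as np

def softmax(a):
    a = a - a.max(axis=-1, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=-1, keepdims=True)

def block(z, prm):
    attn = 0
    for Wq, Wk, Wv, Wo in prm["heads"]:                  # per-head projections
        q, k, v = z @ Wq, z @ Wk, z @ Wv
        att = softmax(q @ k.T / np.sqrt(Wq.shape[1]))    # Eq. 8
        attn = attn + att @ v @ Wo                       # heads are summed, not concatenated
    z = z + attn                                         # residual connection (dropout omitted)
    ffn = np.maximum(z @ prm["W_in"], 0) @ prm["W_out"]  # Eq. 9: ReLU MLP, d -> 4d -> d
    return z + ffn                                       # residual connection

def forward(x0, x1, prm, p=113):
    seq = np.array([x0, x1, p])                          # [x0, x1, ŷ]; p is a hypothetical id
                                                         # for the query token at the last position
    z = prm["tok"][seq] + prm["pos"]                     # Eq. 7
    return block(z, prm)[-1] @ prm["W_u"]                # Eq. 10: logits read off the last position

def init(p=113, d=256, dk=64, heads=4, out=113, seed=0):
    r = np.random.default_rng(seed)
    w = lambda *s: r.normal(0, 0.02, s)
    return {"tok": w(p + 1, d), "pos": w(3, d),
            "heads": [(w(d, dk), w(d, dk), w(d, dk), w(dk, d)) for _ in range(heads)],
            "W_in": w(d, 4 * d), "W_out": w(4 * d, d), "W_u": w(d, out)}

logits = forward(47, 101, init())                        # untrained weights; shapes only
```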
Methods

How exactly a given model implements an algorithm is a non-trivial question—even modular addition is implemented in a relatively obscure way [14], as per Eq. 3, Eq. 4, and Eq. 5.

This investigation probes the fundamental algorithmic structures internalized by a transformer model trained on a set of basic prime-number-related modular arithmetic tasks with slight variations in complexity. This approach provides insights into how and why specific algorithmic patterns emerge from seemingly straightforward learning processes. As stated, the setup here differentiates itself from 𝒯nanda in two crucial ways: 1) it is non-commutative; and 2) it is multi-task.

Tasks

Stated plainly: the task 𝒯miiii predicts the remainder when dividing a two-digit base-𝑝 number by each prime 𝑞 less than 𝑝. The set of primes we construct tasks for is thus {𝑞 : 𝑞 prime, 𝑞 < 𝑝}. For 𝑝 = 113, this yields 29 parallel tasks, one for each prime less than 𝑝. Each task predicts a remainder in the range [0, 𝑞 − 1]. This means smaller primes like 2 and 3 require binary and ternary classification, respectively, while the largest prime less than 𝑝, 109, requires predictions across 109 classes. The tasks thus naturally vary in difficulty: predicting mod 2 requires distinguishing odd from even numbers (which in binary amounts to looking at the last bit), while predicting mod 109 involves selecting between many relatively similar classes. From an information-theoretic perspective, the expected cross-entropy for an 𝑛-class problem is ln(𝑛), which has implications for the construction of the loss function, as discussed under Training below.

Additionally, a baseline task 𝒯basis was constructed by shuffling the 𝑦-labels of 𝒯miiii, and a task-ablation test 𝒯masked was constructed by masking away the four simplest tasks, 𝑞 ∈ {2, 3, 5, 7}.

Data

Input Space (𝑋). Each input 𝑥 ∈ 𝑋 represents a number in base 𝑝 using two digits, (𝑥0, 𝑥1), where the represented number is 𝑥0𝑝⁰ + 𝑥1𝑝¹. For example, with 𝑝 = 11, the input space consists of all pairs (𝑥0, 𝑥1) where 𝑥0, 𝑥1 < 11, representing numbers up to 11² − 1 = 120. This yields a dataset of 121 samples. Figure 2 visualizes this input space, with each cell representing the value 𝑥0𝑝⁰ + 𝑥1𝑝¹.

Figure 2: Visualizing 𝑋 (for a small dataset where 𝑝 = 11). Each cell represents the tuple (𝑥0, 𝑥1). The top left shows 0 (0,0), and the bottom right shows 120 (10,10)—both in base 11.

Output Space (𝑌). For each input 𝑥, a vector 𝑦 ∈ 𝑌 contains the remainder when dividing by each prime less than 𝑝. For 𝑝 = 11, this means predicting the remainder when dividing by 2, 3, 5, and 7. Each element 𝑦𝑖 ranges from 0 to 𝑞𝑖 − 1, where 𝑞𝑖 is the 𝑖-th prime. Figure 3 visualizes these remainders, with each subplot showing the remainder pattern for a specific prime divisor. For comparison, the rightmost plot shows the output space of [14]'s modular addition task.

Figure 3: Visualizing tasks in 𝑌 (for 𝑝 = 11). 𝑥0 and 𝑥1 vary on the two axes, with the remainder modulo 𝑞 ∈ {2, 3, 5, 7} indicated by the square size. Note the innate periodicity of the modulo operator.
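As a small worked example of this encoding (with 𝑝 = 11, as in the figures above): the pair (𝑥0, 𝑥1) = (4, 7) represents the number 4·11⁰ + 7·11¹ = 81, whose target vector is 𝑦 = (81 mod 2, 81 mod 3, 81 mod 5, 81 mod 7) = (1, 0, 1, 4).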
Model

The model follows the original transformer architecture [15], with several key design choices aligned with recent work on mechanistic interpretability [14], [16]: biases are disabled, and layer normalization is not used. The model consists of three main components: an embedding layer, transformer blocks, and an output layer. All weights are initialized following He et al. (2015). The model processes vectors of the kind seen in Eq. 11, writing the eventual result to the last position.

[x_0 \quad x_1 \quad \hat{y}]    (11)

Training

Hyper-parameter optimization was conducted using Optuna [34], searching over the space in Table 1.

dropout: 0, 1/2, 1/5, 1/10
λ:       0, 1/2, 2
wd:      0, 1/10, 1/2, 1
d:       128, 256
lr:      3e-4, 1e-4
heads:   4, 8

Table 1: Hyper-parameter search space for training.

The model is trained using AdamW [35] with 𝛽1 = 0.9, 𝛽2 = 0.98, following Nanda et al. (2023). To handle the varying number of classes across tasks (from 2 classes for mod 2 to 109 classes for mod 109), a modified (weighted) mean cross-entropy loss (Eq. 6.1) is created, correcting for the difference in the expected loss within each task. Note that 𝔼[L_MCE] = −ln(1/𝑞) = ln(𝑞) for a uniform prediction, where 𝑞 is the number of classes within the task in question. Correcting for this, the loss function becomes as shown in Eq. 12.3.

L_{𝒯miiii} = \sum_{q \in \{q\}} \frac{L_{MCE}^{q}}{\ln(q)}    (12.1)
= \sum_{q \in \{q\}} \frac{1}{\ln(q)} \frac{1}{n} \sum_{i=1}^{n} \sum_{j=0}^{q-1} y_{ij}^{q} \ln\frac{1}{\hat{y}_{ij}^{q}}    (12.2)
= -\sum_{q \in \{q\}} \frac{\sum_{i=1}^{n} \sum_{j=0}^{q-1} y_{ij}^{q} \ln(\hat{y}_{ij}^{q})}{n \ln(q)}    (12.3)

To accelerate generalization, gradient filtering as per Lee et al. (2024) is implemented and replicated:

g_t = \nabla_{\theta} L + \lambda (\alpha e_{t-1} + (1 - \alpha) g_{t-1})    (13)

where 𝑒𝑡 is the exponential moving average of the gradients with decay rate 𝛼 = 0.98, and 𝜆 controls the influence of the slow-varying component.

The training uses full-batch gradient descent on the entire dataset of 𝑝² samples (12769 when 𝑝 = 113). The model is evaluated on a held-out validation set after each epoch, tracking per-task accuracy and loss. As in the setup used for 𝒯nanda, training was done on thirty percent of the total dataset, with the rest used for validation (1000 samples) and testing (the remainder). Further, as 𝒯miiii involves learning 29 tasks (when 𝑝 = 113) rather than 1, and due to the task's non-commutativity, a larger hidden dimension of 256 was added to the hyper-parameter search space, as well as the option of 8 heads (𝒯nanda was solved with a hidden dimension of 128 and 4 heads). The number of transformer blocks was kept at 1, as this ensures consistency with 𝒯nanda (and, as we shall see in the results, full generalization was still possible).

Training was done on an NVIDIA GeForce RTX 4090 GPU, with Python 3.11 and extensive use of JAX 0.4.35 and its associated ecosystem. Neuron activations were calculated at every training step and logged for later analysis.
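A sketch of the task-weighted loss (Eq. 12) and of the gradient filtering (Eq. 13) is given below. It is illustrative pseudocode in plain numpy; the released miiii package and Lee et al. (2024)'s reference implementation may differ in details such as exactly where 𝜆 and 𝛼 enter.

```python
import numpy as np

def weighted_mce(logits_per_task, targets_per_task, qs):
    """Eq. 12: per-task mean cross-entropy, each task scaled by 1/ln(q)."""
    total = 0.0
    for logits, y, q in zip(logits_per_task, targets_per_task, qs):
        z = logits - logits.max(axis=1, keepdims=True)            # numerical stability
        logp = z - np.log(np.exp(z).sum(axis=1, keepdims=True))   # log-softmax
        total += -logp[np.arange(len(y)), y].mean() / np.log(q)   # L_MCE^q / ln(q)
    return total

def grokfast_ema(grads, ema, lam, alpha=0.98):
    """Eq. 13-style filtering: amplify the slow-varying (EMA) component of the gradients.
    Follows the EMA variant of Lee et al. (2024); bookkeeping details may differ."""
    ema = {k: alpha * ema[k] + (1 - alpha) * g for k, g in grads.items()}
    return {k: g + lam * ema[k] for k, g in grads.items()}, ema

# Per step (sketch): compute grads, then `grads, ema = grokfast_ema(grads, ema, lam)`,
# and pass the filtered grads to the AdamW update.
```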
Visualization

Much of the data worked with here is inherently high-dimensional. For training, for example, we have 𝑛 steps, two splits (train/valid), about 𝑝/ln(𝑝) tasks, and two metrics (accuracy and loss). This, along with the inherent opaqueness of deep learning models, motivated the development of a custom visualization library, esch³, to visualize attention weights, intermediate representations, training metrics, and more. To familiarize the reader with visualizing the inner workings of a trained model, an essential plot type to keep in mind is seen in Figure 4. As there are only 12769 samples when 𝑝 = 113, all samples can be fed to the model at once. Inspecting a specific activation thus yields a 1×12769 vector 𝑣, which can be reshaped into a 113×113 matrix, with the two axes, 𝑥0 and 𝑥1, varying from 0 to 112. The top-left corner then shows the given activation's value for the sample (0𝑝⁰ + 0𝑝¹), and so on.

Figure 4: Plotting a neuron: (left) the activation of a particular neuron as 𝑥0 and 𝑥1 vary from 0 to 𝑝; (right) the same, processed with a fast Fourier transform to show the active frequencies (𝜔).

Note that in esch plots, when appropriate, only the top-leftmost 37×37 slice is shown so as not to overwhelm the reader.

Mechanistic interpretability process

Recall that a composition of linear maps is itself a linear map. Therefore, as a mechanistic-interpretability rule of thumb, one should look at the outputs of the non-linear transformations. In our case, those are the attention weights and the intermediate representations within the transformer block's feed-forward module (which follow the ReLU activation). Additionally, the embedding layers will be inspected using Fourier analysis and singular value decomposition. As mentioned above, our interpretability approach combines activation visualization with frequency analysis to understand the learned algorithmic patterns. Following Nanda et al. (2023), we analyze both the attention patterns and the learned representations through several lenses.

Attention visualization

esch, the custom visualization library, is used to visualize attention weights and intermediate representations. The library allows for the visualization of attention patterns across different layers, as well as of the intermediate representations at each layer. These visualizations provide insights into the learned patterns and help identify potential areas of improvement.

The fast Fourier transform

As periodicity was established by Nanda et al. (2023) as a fundamental feature of the model trained on 𝒯nanda, the fast Fourier transform (FFT) algorithm is used to detect which frequencies are in play. Note that any square image can be described as a sum of 2D sine and cosine waves varying in frequency from 1 to half the size of the image (plus a constant). This is a fundamental tool used in signal processing; the theory is briefly outlined in the appendix for reference. This analysis helps identify the dominant frequencies in the model's computational patterns. Recall that a vector can be described as a linear combination of other, periodic vectors, as per the discrete Fourier transform. The default basis of the one-hot encoded representation of the input is thus the identity matrix. This can be projected into a Fourier basis by multiplying with the discrete Fourier transform (DFT) matrix, visualized in the appendix.
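Concretely, the projection just described can be sketched in a few hypothetical lines (not the esch/miiii code): the learned token-embedding matrix is multiplied by a real Fourier basis over the 𝑝 tokens, and the norm of each resulting row shows how strongly each frequency is used. Plots of such norms are the kind of view shown in Figures 11–13.

```python
import numpy as np

p, d = 113, 256
W_E = np.random.default_rng(0).normal(size=(p, d))    # stand-in for the learned token embeddings

# Real Fourier basis over the p tokens: a constant row plus cos/sin rows for k = 1..(p-1)/2.
n = np.arange(p)
ks = np.arange(1, p // 2 + 1)
F = np.concatenate([np.ones((1, p)),
                    np.cos(2 * np.pi * ks[:, None] * n[None, :] / p),
                    np.sin(2 * np.pi * ks[:, None] * n[None, :] / p)], axis=0)
F /= np.linalg.norm(F, axis=1, keepdims=True)          # orthonormal rows (p is odd)

fourier_embed = F @ W_E                                # token embeddings in the Fourier basis
norms = np.linalg.norm(fourier_embed, axis=1)          # one magnitude per Fourier component
```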
Results and analysis

Hyper-parameter optimization

The best-performing hyper-parameters for training the model on 𝒯miiii are listed in Table 2. Notably, the model did not converge when 𝜆 = 0, confirming the utility of the gradient amplification method proposed by Lee et al. (2024) in the context of 𝒯miiii.

dropout: 1/10; λ: 1/2; wd: 1/3; d: 256; lr: 3×10⁻⁴; heads: 4

Table 2: Result of the hyper-parameter search over 𝒯miiii.

Model Performance

Figure 5 shows the training and validation accuracy on 𝒯miiii over time. The model achieves a perfect accuracy of 1 on the validation set across all 29 tasks. The cross-entropy loss in Figure 6 echoes this. In short—and to use the terminology of Power et al. (2022)—the model "grokked" on all tasks. Interestingly, the tasks corresponding to modulo 2, 3, 5, and 7 generalized in succession, while the remaining 25 tasks generalized around epoch 40000 in no particular order. This might suggest that the model initially learned solutions for the simpler tasks and later developed a more general computational strategy that allowed it to generalize across the remaining, more complex tasks.

Figure 5: Accuracy training "curves": training (top) and validation (bottom) accuracy over time (𝑥-axis in log scale). We see grokking occur on all tasks, first for 𝑞 ∈ {2, 3, 5, 7} in that order, and then for the remaining 25 in no particular order.

Figure 6: Cross-entropy loss (Eq. 6.1) on training (top) and validation (bottom) over time (note the log scale on the 𝑥-axis).

Embeddings

Positional embeddings play a crucial role in transformers by encoding the position of tokens in a sequence. Figure 7 compares the positional embeddings of models trained on 𝒯nanda and 𝒯miiii.

For 𝒯nanda, which involves a commutative task, the positional embeddings are virtually identical, with a Pearson correlation of 0.95, reflecting that the position of the input tokens does not significantly alter their contribution to the task. In contrast, for 𝒯miiii the positional embeddings have a Pearson correlation of −0.64, indicating that the embeddings for the two positions differ. This difference is expected due to the non-commutative nature of the task, where the order of 𝑥0 and 𝑥1 matters (𝑥0𝑝⁰ ≠ 𝑥0𝑝¹). This confirms that the model appropriately encodes position information for solving the tasks.

Figure 7: Positional embeddings for (𝑥0, 𝑥1) for models trained on 𝒯nanda (top) and 𝒯miiii (bottom). Pearson's correlation is 0.95 and −0.64, respectively. This reflects the commutativity of 𝒯nanda and the lack thereof for 𝒯miiii. Hollow cells indicate negative numbers.

Recall that a matrix 𝐌 of size 𝑚×𝑛 can be decomposed into its singular values, 𝐌 = 𝐔𝚺𝐕ᵀ (with the transpose being the complex conjugate when 𝐌 is complex), where 𝐔 is 𝑚×𝑚, 𝚺 is an 𝑚×𝑛 rectangular diagonal matrix (whose diagonal is represented as a flat vector throughout this paper), and 𝐕ᵀ is an 𝑛×𝑛 matrix. Intuitively, this can be thought of as a rotation in the input space, followed by a scaling, followed by a rotation in the output space.

Figure 8 displays the singular values of the token embeddings learned for 𝒯nanda and 𝒯miiii. The singular values for 𝒯miiii are more diffuse, indicating that a larger number of components are needed to capture the variance in the embeddings compared to 𝒯nanda. This suggests that the token embeddings for 𝒯miiii encode more complex information, reflecting the increased complexity of the multi-task learning scenario.

Figure 8: First 83 of 113 singular values (truncated for clarity) of 𝐔 for 𝒯nanda (top) and 𝒯miiii (bottom). The ticks indicate the points where 50% and 90% of the variance is accounted for. We thus see that for 𝒯miiii the embedding space is much more crammed.

Figure 9: 𝒯nanda's most significant singular vectors of 𝐔 (cutoff at 0.5, as per Figure 8) from the singular value decomposition. Note that these look periodic!

Figure 10: 𝒯miiii's most significant vectors of 𝐔. Note that, as in Figure 9, we still observe periodicity, but there are more frequencies in play, as further explored in Figure 12.

Figure 9 and Figure 10 present the most significant singular vectors of 𝐔 for 𝒯nanda and 𝒯miiii, respectively. Visual inspection shows periodicity in the top vectors for both models, but the 𝒯miiii model requires more vectors to capture the same amount of variance, consistent with the diffuse singular values observed in Figure 8.
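The corresponding computation is short; the sketch below uses a random stand-in for the learned token-embedding matrix and treats squared singular values as explained variance (a common convention, assumed here):

```python
import numpy as np

W_E = np.random.default_rng(0).normal(size=(113, 256))   # stand-in for learned token embeddings
U, S, Vt = np.linalg.svd(W_E, full_matrices=False)       # M = U Σ Vᵀ

var = S ** 2 / (S ** 2).sum()                            # variance explained per component
cum = var.cumsum()
k50 = int(np.searchsorted(cum, 0.5)) + 1                 # components covering 50% of the variance
k90 = int(np.searchsorted(cum, 0.9)) + 1                 # components covering 90% of the variance
# k50 and k90 correspond to the tick positions marked in Figure 8.
```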
To further understand the structure of the token embeddings, we applied the fast Fourier transform (FFT). Only a few frequencies are active for 𝒯nanda, as seen in Figure 11, consistent with the model implementing a cosine–sine lookup table as described in Nanda et al. (2023).

For the 𝒯miiii model, we observe a broader spectrum of active frequencies (Figure 12). This is expected, as the model has to represent periodicity corresponding to 29 primes. Comparing with 𝒯basis in Figure 13, the periodicity is understood to be a structure inherent to the data that is picked up by the model.

Figure 11: 𝒯nanda tokens in the Fourier basis. Note how all tokens are essentially linear combinations of the five most dominant Fourier basis vectors. The sparsity echoes the finding in Figure 8 that very few directions in the embedding space are used.

Figure 12: The periodicity in the 𝒯miiii embeddings involves a much larger fraction of the Fourier basis, echoing the multiple tasks and their innate difference in frequency (recall that each task is performed on a unique prime 𝑞).

Figure 13: Embeddings for 𝒯basis in the Fourier basis show no periodicity. The periodicity is indeed an artifact of the modulo operator.

Analysis of Neuron Activations and Frequencies

To understand the internal mechanisms developed by the model, we analyzed the post-ReLU neuron activations feeding into the output weight matrix 𝑊out (Eq. 9) for the model trained on 𝒯miiii. Figure 14 shows that these activations exhibit periodic patterns with respect to (𝑥0, 𝑥1). This periodicity aligns with the modular arithmetic nature of the tasks, mirroring Nanda et al. (2023) (𝒯nanda).

Figure 14: The activations of the first three neurons immediately following the ReLU in Eq. 9 as 𝑥0 and 𝑥1 vary (top). Note that we only show the top 37×37 corner of the full 113×113 sample matrix. Here, too, we see periodicity, confirmed by a Fourier transform (bottom). Neurons are reactive to highly particular frequencies in their input domains.

For comparison, Figure 15 shows the neuron activations for a model trained on 𝒯basis. These activations do not exhibit periodicity, confirming that the observed periodic patterns in the models trained on 𝒯miiii and 𝒯nanda are indeed a result of the modulo operations inherent in the tasks.

Figure 15: Neuron activations for the model trained on 𝒯basis. As in Figure 13, no periodicity is observed for the baseline.

The analysis of the active frequencies through training, using the fast Fourier transform (FFT), is illustrated in Figure 16, with the core finding, a spike in frequency activation around epoch 16384, visible in Table 3. The top plot shows the different frequencies of the transformer block's feed-forward neurons evolving as the model learns. The bottom plot displays the variance of the frequency activations and the number of frequencies exceeding a significance threshold of 𝜇 + 2𝜎 (i.e., how many spots like the ones in the bottom row of Figure 14 are active). Initially, a handful of frequencies become dominant as the model generalizes on the first four tasks. As training progresses and the model begins to generalize on the remaining tasks, more frequencies become significant, suggesting that the model is developing more complex internal representations to handle the additional tasks.

Figure 16: Top: frequency dominance averaged over neurons through training (the average activation of frequencies as shown in Figure 14). We see four distinct phases: 1) no significant frequencies, 2) frequencies gradually emerging, 3) a wall of frequencies, and 4) a frequency count similar to phase 2. Bottom: the total number of active frequencies at the corresponding time step. A frequency 𝜔 is considered active when its amplitude exceeds 𝜇 + 2𝜎 (a frequently used signal-processing default).
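A hypothetical helper mirroring the 𝜇 + 2𝜎 criterion above, which is the kind of computation behind the counts in Table 3 (the paper's exact aggregation across neurons may differ):

```python
import numpy as np

def active_frequencies(act, p=113):
    """act: flat activation vector of one neuron over all p**2 samples."""
    grid = act.reshape(p, p)                      # rows: x0, columns: x1 (cf. Figure 4)
    spec = np.abs(np.fft.fft2(grid))              # 2D Fourier magnitude spectrum
    spec[0, 0] = 0                                # drop the constant (DC) component
    mu, sigma = spec.mean(), spec.std()
    return np.argwhere(spec > mu + 2 * sigma)     # indices of active (ω_x0, ω_x1) pairs

# Counting, per epoch, the active frequencies across all feed-forward neurons
# yields a |ω| figure like the one reported in Table 3.
```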
epoch: 256, 1024, 4096, 16384, 65536
|𝜔|:   0,   0,    10,   18,    10

Table 3: Number of active frequencies on 𝒯miiii over the epochs.

Figure 17 shows the L2 norms of the gradients through time for the different weight matrices of the model trained on 𝒯miiii. The gradient norms provide insight into how different parts of the model are being updated during training. As with Nanda et al. (2023), the attention layer converges quickly, echoing their finding that it does not contribute much to solving their modular arithmetic task.

Figure 17: L2 norms of the gradients over time for the different weight matrices of the model trained on 𝒯miiii. The row order corresponds to how deep in the model the weight is used. This shows which parts of the model are updated at which point during training. e.* are the embeddings (Eq. 7), a.* the attention-layer weights (Eq. 8), and w.* the weights of the feed-forward module (Eq. 9), with e.u being the final un-embedding layer.

These results demonstrate that more frequencies are involved when training on 𝒯miiii than on 𝒯nanda. The increased frequency components reflect the need for the model to encode multiple periodic patterns corresponding to the various modular arithmetic tasks. Combining the analysis of the embeddings and of the transformer-block neurons, we see that:

1. Many more frequencies are in play for 𝒯miiii than for 𝒯nanda.
2. Neurons remain highly reactive to a very small set of frequencies.
3. The periodicity is an artifact of the modulo group, as shown by the analysis of 𝒯basis.

Attention Patterns

Figure 18 shows that, in contrast to the model trained on 𝒯nanda, where attention heads may focus jointly on both input tokens in a periodic fashion, the attention heads of the 𝒯miiii model focus exclusively on one digit or the other. This behavior could be due to the non-commutative nature of the task, where the position of each digit significantly affects the outcome. Nanda et al. (2023) conclude that the attention mechanism does not contribute significantly⁴ to the solving of 𝒯nanda, and it will thus not be explored further here either.

Figure 18: Attention from ŷ to 𝑥0 for the four attention heads in the 𝒯miiii model. The attention heads tend to focus on one digit, reflecting the non-commutative nature of the task.

Overall, our results demonstrate that the model effectively learns to solve multiple modular arithmetic tasks by developing internal representations that capture the periodic nature of these tasks. The analysis of embeddings and neuron activations provides insight into how the model generalizes from simpler to more complex tasks, possibly through the reuse and integration of learned circuits. Interestingly, as will be discussed below, there are four significant shifts in the number of frequencies active in the neurons.

Lastly, as shown in the appendix, when dropout was disabled, the model's performance diverged on the validation set (overfitting), even with the other kinds of regularization in place (multi-task learning, L2, etc.). The appendix also shows the accuracy through training for 𝒯masked: masking away 𝑞 ∈ {2, 3, 5, 7} does not notably delay grokking on the remaining tasks, and the spike after generalization to the easy tasks (now 𝑞 ∈ {11, 13, 17, 19}) remains.

Discussion

Recall that, when viewed individually, the sub-tasks of 𝒯miiii differ from 𝒯nanda only in commutativity (making the task harder) and in the prime modulus, since 𝑞 < 𝑝 (making the task easier for smaller 𝑞's like {2, 3, 5, 7}—though less so as 𝑞 approaches 𝑝).
Figure 7 indicates that the model learns to account for the non-commutativity by means of the positional embeddings.

During training, the changes in the number of active neuron frequencies 𝜔 in the feed-forward layer (see Figure 16) echo the loss and accuracy progression seen in Figure 5 and Figure 6. Further, as the model groks on the primes 𝑞 ∈ {2, 3, 5, 7}, a handful of frequencies become dominant, similar to the original model trained on 𝒯nanda.

We thus have that: 1) the model learns to correct for the non-commutativity, and 2) the model is marked by periodicity, as per both the token-embedding and feed-forward-layer analyses above. Combining these facts, we can assume the learned mechanism to be extremely similar (and perhaps identical, when ignoring commutativity) to the one outlined by Nanda et al. (2023) in Eq. 3, Eq. 4, and Eq. 5.

As the remaining 25 tasks are learned, we see a temporary spike in the number of active frequencies, which disappears again as the model generalizes fully (reaches perfect accuracy). The observation that the number of active frequencies is the same before and after the remaining 25 tasks are learned indicates a reuse of circuitry. However, the fact of the spike suggests that additional circuits form during the generalization process. Viewed in the context of Lee et al. (2024)'s method for boosting slow-varying generalizing components, a question emerges: are there circuits that only facilitate the development of generalization but are not present in the generalized mechanism? Real life gives plenty of examples of phenomena that make themselves obsolete (e.g., a medicine that fully eradicates an illness and is thus no longer needed). Viewed this way, the spike suggests we might divide the set of circuits in two: 1) those useful for the mechanism, and 2) those useful for learning the mechanism.

Future work

A logical next step would be to explore the validity of the notion that some circuits help learning while others help solve the problem. This might yield insight into how to improve Lee et al. (2024)'s grokking-speedup heuristic. Aspects of the circuit-discovery workflow could be automated with the methods outlined by Conmy et al. (2023).

Additionally, variations on 𝒯miiii are likely to be a good avenue for discovery: divisibility rather than the remainder could be predicted; experiments training for more epochs could be conducted with larger values of 𝑝; and a more in-depth mapping of the shared circuitry could be undertaken, for example by testing which types of ablations break which tasks—for example, how can performance be degraded on one grokked task without affecting the others?

The code associated with this paper is available as a PyPI package (pip install miiii) to facilitate the exploration of these questions (as well as replication of the findings at hand).

Conclusion

This paper explores the impact of multi-task learning on mechanistic interpretability by training a transformer model on a non-commutative, multi-task modular arithmetic problem, 𝒯miiii. The model successfully generalizes across all tasks, learning complex internal representations that capture the unique periodic nature of modular arithmetic across multiple primes.
Analysis reveals that while the model reuses and integrates circuits for the simpler tasks, additional circuits may form during training to facilitate generalization to the more complex tasks.

These findings highlight that multi-task learning influences the emergence and complexity of internal mechanisms, posing challenges for mechanistic interpretability but also offering insights into how models internalize algorithms. Understanding these dynamics is important for advancing interpretability and reliability in deep learning systems. Future work includes exploring the distinction between circuits that aid in learning versus those that contribute to the final mechanism/solution, and investigating how variations in task design impact the development of internal representations. Advancing the understanding of how deep learning models handle multiple tasks contributes to the broader goal of making these models more interpretable and reliable.

¹ Try manually writing a function, in a language of your choice, that classifies dogs and cats from images.
² In the context of MI, "circuit" refers to a subgraph of a neural network that performs a particular function.
³ https://github.com/syrkis/esch
⁴ Neel Nanda has also stated (in a YouTube video) that a multi-layer perceptron, rather than a transformer block, would probably have been more appropriate for his setup. The transformer is used here due to the non-commutativity of 𝒯miiii, and to stay close to Nanda's work.

References

[1] J. Jumper et al., "Highly Accurate Protein Structure Prediction with AlphaFold," Nature, vol. 596, no. 7873, pp. 583–589, Aug. 2021, doi: 10.1038/s41586-021-03819-2.
[2] E. Dinan et al., "Human-Level Play in the Game of Diplomacy by Combining Language Models with Strategic Reasoning," Science, vol. 378, no. 6624, pp. 1067–1074, Dec. 2022, doi: 10.1126/science.ade9097.
[3] A. Radford and K. Narasimhan, "Improving Language Understanding by Generative Pre-Training," 2018.
[4] C. E. Shannon, "Programming a Computer for Playing Chess," Computer Chess Compendium, pp. 2–13, 1950, doi: 10.1007/978-1-4757-1968-0_1.
[5] G. Tesauro, "TD-Gammon, a Self-Teaching Backgammon Program, Achieves Master-Level Play," 1993.
[6] D. Silver et al., "Mastering the Game of Go without Human Knowledge," Nature, vol. 550, no. 7676, pp. 354–359, Oct. 2017, doi: 10.1038/nature24270.
[7] C. B. Browne et al., "A Survey of Monte Carlo Tree Search Methods," IEEE Transactions on Computational Intelligence and AI in Games, vol. 4, no. 1, pp. 1–43, Mar. 2012, doi: 10.1109/TCIAIG.2012.2186810.
[8] A. Ruoss et al., "Grandmaster-Level Chess Without Search," arXiv:2402.04494, Feb. 2024, doi: 10.48550/arXiv.2402.04494.
[9] Z. C. Lipton, "The Mythos of Model Interpretability: In Machine Learning, the Concept of Interpretability Is Both Important and Slippery," Queue, vol. 16, no. 3, pp. 31–57, Jun. 2018, doi: 10.1145/3236386.3241340.
[10] J. L. Ba, J. R. Kiros, and G. E. Hinton, "Layer Normalization," arXiv:1607.06450, Jul. 2016, doi: 10.48550/arXiv.1607.06450.
[11] A. Krizhevsky, I. Sutskever, and G. E. Hinton, "ImageNet Classification with Deep Convolutional Neural Networks," Communications of the ACM, vol. 60, no. 6, pp. 84–90, May 2017, doi: 10.1145/3065386.
[12] A. Krogh and J. Hertz, "A Simple Weight Decay Can Improve Generalization," in Advances in Neural Information Processing Systems, Morgan-Kaufmann, 1991.
[13] J. Baxter, "A Model of Inductive Bias Learning," arXiv:1106.0245, Jun. 2011, doi: 10.48550/arXiv.1106.0245.
[14] N. Nanda, L. Chan, T. Lieberum, J. Smith, and J. Steinhardt, "Progress Measures for Grokking via Mechanistic Interpretability," arXiv:2301.05217, Oct. 2023.
[15] A. Vaswani et al., "Attention Is All You Need," arXiv:1706.03762, Dec. 2017, doi: 10.48550/arXiv.1706.03762.
[16] J. Lee, B. G. Kang, K. Kim, and K. M. Lee, "Grokfast: Accelerated Grokking by Amplifying Slow Gradients," arXiv:2405.20233, Jun. 2024.
[17] L. Egri and T. R. Shultz, "A Compositional Neural-network Solution to Prime-number Testing," 2006.
[18] S. Lee and S. Kim, "Exploring Prime Number Classification: Achieving High Recall Rate and Rapid Convergence with Sparse Encoding," arXiv:2402.03363, Feb. 2024.
[19] D. Wu, J. Yang, M. U. Ahsan, and K. Wang, "Classification of Integers Based on Residue Classes via Modern Deep Learning Algorithms," Apr. 2023, doi: 10.1016/j.patter.2023.100860.
[20] A. Conmy, A. N. Mavor-Parker, A. Lynch, S. Heimersheim, and A. Garriga-Alonso, "Towards Automated Circuit Discovery for Mechanistic Interpretability," arXiv:2304.14997, Oct. 2023, doi: 10.48550/arXiv.2304.14997.
[21] M. M. Bronstein, J. Bruna, T. Cohen, and P. Veličković, "Geometric Deep Learning: Grids, Groups, Graphs, Geodesics, and Gauges," arXiv:2104.13478, May 2021, doi: 10.48550/arXiv.2104.13478.
[22] S. Yu, L. Sanchez Giraldo, and J. Principe, "Information-Theoretic Methods in Deep Neural Networks: Recent Advances and Emerging Opportunities," in Proceedings of the Thirtieth International Joint Conference on Artificial Intelligence, Montreal, Canada: International Joint Conferences on Artificial Intelligence Organization, Aug. 2021, pp. 4669–4678, doi: 10.24963/ijcai.2021/633.
[23] B. Gavranović, P. Lessard, A. Dudzik, T. von Glehn, J. G. M. Araújo, and P. Veličković, "Categorical Deep Learning: An Algebraic Theory of Architectures," arXiv:2402.15332, Feb. 2024, doi: 10.48550/arXiv.2402.15332.
[24] A. Power, Y. Burda, H. Edwards, I. Babuschkin, and V. Misra, "Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets," arXiv:2201.02177, Jan. 2022, doi: 10.48550/arXiv.2201.02177.
[25] A. I. Humayun, R. Balestriero, and R. Baraniuk, "Deep Networks Always Grok and Here Is Why," arXiv:2402.15555, Jun. 2024, doi: 10.48550/arXiv.2402.15555.
[26] B. Wang, X. Yue, Y. Su, and H. Sun, "Grokked Transformers Are Implicit Reasoners: A Mechanistic Journey to the Edge of Generalization," arXiv:2405.15071, May 2024, doi: 10.48550/arXiv.2405.15071.
[27] H. J. Jeon and B. Van Roy, "An Information-Theoretic Framework for Deep Learning," Advances in Neural Information Processing Systems, vol. 35, pp. 3279–3291, Dec. 2022.
[28] A. Maurer, "Bounds for Linear Multi-Task Learning."
[29] C. E. Shannon, "A Mathematical Theory of Communication," ACM SIGMOBILE Mobile Computing and Communications Review, vol. 5, no. 1, pp. 3–55, Jan. 2001, doi: 10.1145/584091.584093.
[30] T. M. Cover and J. A. Thomas, Elements of Information Theory, 2nd ed. Hoboken, NJ: Wiley-Interscience, 2006.
[31] B. He and T. Hofmann, "Simplifying Transformer Blocks," arXiv:2311.01906, Nov. 2023, doi: 10.48550/arXiv.2311.01906.
[32] M. Hosseini and P. Hosseini, "You Need to Pay Better Attention," arXiv:2403.01643, Mar. 2024, doi: 10.48550/arXiv.2403.01643.
[33] K. He, X. Zhang, S. Ren, and J. Sun, "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification," arXiv:1502.01852, Feb. 2015, doi: 10.48550/arXiv.1502.01852.
[34] T. Akiba, S. Sano, T. Yanase, T. Ohta, and M. Koyama, "Optuna: A Next-generation Hyperparameter Optimization Framework," in Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining (KDD '19), New York, NY, USA: Association for Computing Machinery, Jul. 2019, pp. 2623–2631, doi: 10.1145/3292500.3330701.
[35] I. Loshchilov and F. Hutter, "Decoupled Weight Decay Regularization," arXiv:1711.05101, Jan. 2019, doi: 10.48550/arXiv.1711.05101.