Mechanistic Interpretability on Irreducible Integer Identifiers
Noah Syrkis and Anders Søgaard
November 29, 2024
Introduction
Recent years have seen deep learning (DL) models achieve remarkable proficiency in complex computational tasks, including protein structure prediction [1], strategic reasoning [2], and natural language generation [3], areas previously thought to be the exclusive domain of human intelligence. Traditional (symbolic) programming allows functions like $f(x, y) = \cos(a \cdot x) + \sin(b \cdot y)$ to be implemented in code with clear typographical isomorphism, meaning the code's structure directly mirrors the mathematical notation. For example, in the language Haskell: `f x y = cos (a * x) + sin (b * y)`. In contrast, DL models are inherently sub-symbolic, meaning that the models' atomic constituents (often 32-bit floating-point numbers centered around 0) do not map directly to mathematical vocabulary. For reference, shows a DL-based implementation of the aforementioned function. Indeed, the increasing prevalence of DL can be understood as a transition from symbolic to sub-symbolic algorithms.
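To make the contrast concrete, below is a minimal sketch (not the paper's code) of what a sub-symbolic implementation of the same function might look like: a small MLP in JAX is fitted to $f$, after which the learned behaviour lives in floating-point weight matrices rather than in symbols such as cos or sin. The constants $a$ and $b$, the layer sizes, and the training loop are arbitrary illustrative choices.

```python
# Sketch: approximating f(x, y) = cos(a*x) + sin(b*y) sub-symbolically with a small MLP.
import jax
import jax.numpy as jnp

a, b = 2.0, 3.0  # arbitrary constants, chosen for illustration

def f(x, y):     # the symbolic reference implementation
    return jnp.cos(a * x) + jnp.sin(b * y)

def init(key, sizes=(2, 64, 64, 1)):
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]

def mlp(params, xy):  # sub-symbolic: only matrices and point-wise non-linearities
    h = xy
    for W, c in params[:-1]:
        h = jnp.tanh(h @ W + c)
    W, c = params[-1]
    return (h @ W + c).squeeze(-1)

def loss(params, xy, target):
    return jnp.mean((mlp(params, xy) - target) ** 2)

key = jax.random.PRNGKey(0)
xy = jax.random.uniform(key, (1024, 2), minval=-jnp.pi, maxval=jnp.pi)
target = f(xy[:, 0], xy[:, 1])
params = init(key)

@jax.jit
def step(params):
    grads = jax.grad(loss)(params, xy, target)
    return jax.tree_util.tree_map(lambda w, g: w - 1e-2 * g, params, grads)

for _ in range(5_000):   # after training, the "algorithm" is spread across the weights
    params = step(params)
```

After training, no individual weight corresponds to $a$, $b$, cos, or sin; the mapping is distributed across the parameters, which is precisely what makes interpretation difficult.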
Precursors to modern DL methods learned how to weigh human-designed features [4], with later works learning to create features from data to subsequently weigh [5], [6], in combination with tree-search strategies in the case of games [7]. Very recent DL work has even eliminated tree search in the case of chess, mapping directly from observation space to action space [8]. Pure DL methods are thus becoming ubiquitous but remain largely inscrutable, with recent works still attempting to define what interpretability even means in the DL context [9]. Given the sub-symbolic nature of DL models, it is unsurprising that their interpretation remains difficult.
Mathematically, DL refers to a set of methods that combine linear maps (matrix multiplications) with non-linearities (activation functions). Formally, all the potential numerical values of a given model's weights $W$ can be thought of as a hypothesis space $\mathcal{H}$. Often, $\mathcal{H}$ is determined by human decisions (number of layers, kinds of layers, sizes of layers, etc.). $\mathcal{H}$ is then navigated using some optimization heuristic, such as gradient descent, in the hope of finding a $W$ that "performs well" (i.e., successfully minimizes some loss $\mathcal{L}$, often computed by a differentiable function with respect to $W$) on whatever training data is present. This vast, sub-symbolic hypothesis space, while enabling impressive performance and the solving of relatively exotic¹ tasks, makes it challenging to understand how any one particular solution actually works (i.e., a black-box algorithm).
The ways in which a given model can minimize $\mathcal{L}$ can be placed on a continuum: on one side, we have overfitting, i.e., remembering the training data (functioning as an archive, akin to lossy or even lossless compression); on the other, we have generalization, i.e., learning the rules that govern the relationship between input and output (functioning as an algorithm).
Giving a mechanistic explanation of a given DL model's behavior necessarily entails the existence of a mechanism. Mechanistic interpretability (MI) assumes this mechanism to be general, thus making generalization a necessary (though insufficient) condition. Generalization ensures that there is a mechanism/algorithm present to be uncovered (necessity); however, it is possible for that algorithm to be so obscurely implemented that reverse engineering it, for all intents and purposes, is impossible (insufficiency). Various forms of regularization are used to incentivize the emergence of algorithmic (generalized) and interpretable behavior, rather than archiving (over-fitted) behavior [10], [11], [12].
As of yet, no MI work has explored the effect of multi-task learning, the focus of this paper. Multi-task learning also has a regularizing effect [13]. Formally, the set of hypothesis spaces, one for each task in a set of tasks (often called an environment), is denoted by $\mathcal{H} \in \mathbb{H}$. When minimizing the losses across all tasks in parallel, generalizing $W$'s are thus incentivized, as these help lower the loss across tasks (in contrast to memorizing $W$'s, which lower the loss for only one task). A $W$ derived from a multi-task training process can thus be thought of as lying in the intersection of the high-performing regions of all $\mathcal{H} \in \mathbb{H}$.
In this spirit, the present paper builds on the work of Nanda et al. (2023), which trains a transformer [15] model to perform modular addition, as seen in Eq. 1. The task is denoted $\mathcal{T}_{\text{nanda}}$ throughout the paper.

$$ (x_0 + x_1) \bmod p, \quad \forall\, x_0, x_1 < p, \quad p = 113 \quad (1) $$
The task of this paper focuses on predicting remainders modulo all primes $q$ less than $p$, where $x$ is interpreted as $x_0 p^0 + x_1 p^1$, formally shown in Eq. 2, and is referred to as $\mathcal{T}_{\text{miiii}}$:

$$ (x_0 p^0 + x_1 p^1) \bmod q, \quad \forall\, x_0, x_1 < p, \quad \forall\, q < p, \quad p = 113 \quad (2) $$
$\mathcal{T}_{\text{miiii}}$ differentiates itself from $\mathcal{T}_{\text{nanda}}$ in two significant ways: 1) it is non-commutative, and 2) it is, as mentioned, multi-task. These differences present unique challenges for mechanistic interpretation, as the model must both learn to handle the order-dependent nature of the inputs and develop shared representations across multiple modular arithmetic tasks. Further, as $\mathcal{T}_{\text{miiii}}$ is harder than $\mathcal{T}_{\text{nanda}}$, the model can be expected to generalize more slowly when trained on the former. Therefore, Lee et al. (2024)'s recent work on speeding up generalization, which posits that the model parameters' gradients through time can be viewed as a sum of 1) a slowly varying generalizing component (which is boosted) and 2) a quickly varying overfitting component (which is suppressed), is (successfully) replicated here to make training tractable.
Figure 1: Visualizing natural numbers less than 12769 in polar coordinates $(n, n \bmod 2\pi)$. Left: union of numbers with remainder 0 mod 17 and mod 23 (see the two spirals). Middle: numbers with remainder 0 mod 11. Right: prime numbers. It is shown here to encourage the reader to think in periodic terms.
More generally, modular arithmetic on primes is a particularly useful task for MI, as it ensures uniformity among the output classes, allows for comparison with other MI work [14], and, from a number-theoretic point of view, primes contain mysteries ranging from the trivially solved (are there infinitely many primes?) to the deceptively difficult (can all even numbers larger than 4 be described as the sum of two primes?). The latter, known as Goldbach's Conjecture, remains unsolved after centuries. The choice of using every prime less than the square root of the largest number in the dataset also serves the following purpose: to test whether a given natural number is prime, it suffices to test that it is not a multiple of any prime less than or equal to its square root. The set of tasks trained for here can thus be viewed in conjunction as a single prime-detection task (primes are the only samples whose target vector contains no zeros, since a prime is not a multiple of any of the factors $q$). There are about $\frac{n}{\ln(n)}$ primes less than $n$.
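The following short sketch (illustrative, not from the paper) makes the prime-detection reading concrete: a number below $p^2$ is prime exactly when its vector of remainders modulo the primes $q < p$ contains no zeros, with the small primes $q$ themselves handled as a special case.

```python
# Sketch: the per-prime remainder tasks jointly act as a primality test for n < p**2.
import jax.numpy as jnp

p = 113
qs = jnp.array([q for q in range(2, p) if all(q % d for d in range(2, q))])  # primes < p

def remainders(n):
    return jnp.asarray(n) % qs        # the target vector of T_miiii for the number n

def is_prime(n):
    n = int(n)
    if n < 2:
        return False
    if n in set(qs.tolist()):         # the small primes divide themselves
        return True
    return bool(jnp.all(remainders(n) != 0))  # no prime factor below p => prime (for n < p**2)

assert is_prime(109) and not is_prime(111)    # 111 = 3 * 37
```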
To provide insight into the periodic structure of these remainders for natural numbers less than 12769 (and to motivate thinking in rotational terms), Figure 1 visualizes various modular patterns in polar coordinates $(n, n \bmod 2\pi)$. One could imagine tightening and loosening the spiral by multiplying $2\pi$ by a constant to align multiples of a given number in a straight line (imagining this is encouraged).
Background and related work
Multiple papers describe the use of deep learning to detect prime numbers [17], [18], [19]. None are particularly promising as prime-detection algorithms, as they do not provide speedups, use more memory, and are less accurate than traditional methods. However, in exploring the foundations of deep learning, the task of prime detection is interesting: it is a simple task that is difficult to learn, and it is synthetic, meaning that arbitrary amounts of data can be generated by a simple algorithm.
Mechanistic Interpretability (MI)
MI is a relatively new field focused on reverse-engineering the internal mechanisms of neural networks. Lipton (2018) explored different definitions of interpretability in this context. MI can be contrasted with other forms of interpretability, such as feature-importance analysis; while feature importance measures correlations between inputs and outputs (e.g., red pixels correlating with "rose" classifications), MI aims to understand how the model actually processes information (i.e., the mechanism).

Methods and tools used so far in MI include activation visualization across ordered samples, singular value decomposition of weight matrices, and ablation studies to identify critical circuits. Conmy et al. (2023) even successfully automate circuit² discovery. Many reverse-engineering methods from other fields, such as computational neuroscience or signal processing, almost certainly have their uses here as well.
In spite of deep learning's practical successes, uncertainty remains about its theoretical underpinnings, echoing the interpretability debate. Recent work attempts to place different DL architectures and concepts in either a geometric [21], an information-theoretic [22], or even a category-theoretic [23] context. However, no unified theory has emerged. Much interesting deep learning research thus focuses on practical, simple, or algorithmic tasks with known solutions and architectures. For example, grokking [24], the (relatively) sudden generalization after overfitting, as elaborated later, is a recent and practical discovery.
Case study: modular addition
One such practical discovery is made by Nanda et al. (2023). A single-layer transformer model with ReLU activation function was trained to perform modular addition ($\mathcal{T}_{\text{nanda}}$). Nanda et al. (2023)'s analysis of their trained model exemplifies MI methodology. They discovered that: 1) the embedding layer learns trigonometric lookup tables of sine and cosine values as per Eq. 3; 2) the feed-forward network combines these through multiplication and trigonometric identities (Eq. 4); and 3) the final layer performs the equivalent of an argmax (Eq. 5).
$$ x_0 \to \sin(w x_0),\ \cos(w x_0) \quad (3.1) $$

$$ x_1 \to \sin(w x_1),\ \cos(w x_1) \quad (3.2) $$

$$ \sin(w(x_0 + x_1)) = \sin(w x_0)\cos(w x_1) + \cos(w x_0)\sin(w x_1) \quad (4.1) $$

$$ \cos(w(x_0 + x_1)) = \cos(w x_0)\cos(w x_1) - \sin(w x_0)\sin(w x_1) \quad (4.2) $$

$$ \text{Logit}(c) \propto \cos(w(x_0 + x_1 - c)) \quad (5.1) $$

$$ = \cos(w(x_0 + x_1))\cos(w c) + \sin(w(x_0 + x_1))\sin(w c) \quad (5.2) $$
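As a sanity check of the algorithm sketched in Eq. 3-5, the following numerical illustration (not the trained model, which spreads the computation over several key frequencies) builds the logits for a single frequency and recovers $(x_0 + x_1) \bmod p$ via argmax.

```python
# Numerical illustration of Eq. 3-5: angle-addition identities plus a cosine read-out.
import jax.numpy as jnp

p = 113
w = 2 * jnp.pi * 7 / p           # one illustrative frequency; any k coprime to p works
x0, x1 = 17, 94                  # example inputs

sin_sum = jnp.sin(w * x0) * jnp.cos(w * x1) + jnp.cos(w * x0) * jnp.sin(w * x1)  # Eq. 4.1
cos_sum = jnp.cos(w * x0) * jnp.cos(w * x1) - jnp.sin(w * x0) * jnp.sin(w * x1)  # Eq. 4.2

cs = jnp.arange(p)
logits = cos_sum * jnp.cos(w * cs) + sin_sum * jnp.sin(w * cs)  # Eq. 5: cos(w(x0 + x1 - c))

assert int(jnp.argmax(logits)) == (x0 + x1) % p
```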
Generalization and grokking
Power et al. (2022) shows that generalization can happen "[...] well past the point of overfitting", dubbing the phenomenon "grokking". The phenomenon is now well established [14], [25], [26]. Nanda et al. (2023) shows that a generalized circuit "arises from the gradual amplification of structured mechanisms encoded in the weights," rather than being a relatively sudden and stochastic encounter of an appropriate region of $\mathcal{H}$. The important word of the quote is thus "gradual".
By regarding the series of gradients through time as a stochastic signal, Lee et al. (2024) proposes decomposing that signal. Conceptually, Lee et al. (2024) argues that, in the case of gradient descent, the ordered sequence of gradient updates can be viewed as consisting of two components: 1) a fast-varying overfitting component, and 2) a slow-varying generalizing component. The general algorithm explaining the relationship between input and output is the same for all samples, whereas the weight updates that memorize individual samples are unique to those samples. Though not proven, this intuition bears out in that generalization is sped up fifty-fold in some cases.
This echoes the idea that generalized circuits go through gradual amplification [14]. To the extent that this phenomenon is widespread, it bodes well for generalizable DL, in that the generalizing signal that one would want to amplify might exist long before the model is fully trained and could potentially be boosted in a targeted way by the method described by Lee et al. (2024).
Perhaps the most widespread loss functions used in deep learning are mean cross-entropy (Eq. 6.1, for classification) and mean squared error (Eq. 6.2, for regression).

$$ L_{\text{MCE}} = \frac{1}{n} \sum_{i=1}^{n} \sum_{j=1}^{k} y_{ij} \ln\!\left(\frac{1}{\hat{y}_{ij}}\right) \quad (6.1) $$

$$ L_{\text{MSE}} = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 \quad (6.2) $$
These have various computational and mathematical properties that make them convenient to use, though they have been shown to struggle with generalizing to out-of-distribution data [27], [22].
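A minimal sketch of the two losses (assuming one-hot targets $y$ and per-row probability predictions $\hat{y}$; variable names are illustrative):

```python
# Sketch of Eq. 6.1 and 6.2.
import jax.numpy as jnp

def mce(y, y_hat, eps=1e-9):      # Eq. 6.1: mean cross-entropy, y one-hot of shape (n, k)
    return jnp.mean(jnp.sum(y * jnp.log(1.0 / (y_hat + eps)), axis=-1))

def mse(y, y_hat):                # Eq. 6.2: mean squared error for regression targets
    return jnp.mean((y - y_hat) ** 2)
```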
Multi-task learning in deep learning
As stated, multi-task learning has been shown to have a regularizing effect [13], [28], as a hypothesis $W$ that performs well across all of the hypothesis spaces $\mathcal{H} \in \mathbb{H}$ is more likely to be general. Viewed information-theoretically, this concept is reminiscent of Shannon (2001)'s asymptotic equipartition property [30], or, even more generally, the law of large numbers, which states that the more samples we have of a distribution, the closer our estimates of its underlying properties will align with the true underlying properties.
In the context of $\mathcal{T}_{\text{miiii}}$, multi-task learning is done by having the last layer output predictions for all tasks in parallel. Thus, whereas $\mathcal{T}_{\text{nanda}}$ outputs a single one-hot $1 \times 113$ vector over the potential remainders, $\mathcal{T}_{\text{miiii}}$, as we shall see, outputs a $1 \times q$ vector for each prime $q < p$ (i.e., 29 output-task vectors when $p = 113$). The embedding layer and the transformer block are thus shared across all tasks, meaning that representations that perform well across tasks are incentivized.
Transformer architecture
Transformers combine self-attention (a communication mechanism) with feed-forward layers (a computation mechanism). The original transformer block [15] used extensive regularization: layer norm [10], dropout, weight decay, and residual connections are all integral components of the original architecture, though recent years have seen simplifications yielding similar performance [31], [32].
Input tokens are embedded into a $d$-dimensional space using learned token and positional embeddings:

$$ z = \text{TokenEmbed}(x) + \text{PosEmbed}(\text{pos}) \quad (7) $$
Each transformer block comprises multi-head attention:

$$ \text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^{T}}{\sqrt{d_k}}\right) V \quad (8) $$
where $Q$, $K$, and $V$ are linear projections of the input. Attention heads are combined through addition rather than concatenation (a transformer-specific detail to align with Nanda et al. (2023)). This is followed by a feed-forward network with ReLU activation:

$$ \text{FFN}(z) = \text{ReLU}(z W_{\text{in}}) W_{\text{out}} \quad (9) $$
mapping from $d \to 4d \to d$ dimensions, before finally:

$$ \hat{y} = z W_{\text{unembed}} \quad (10) $$
Each component includes residual connections and dropout.
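A schematic JAX sketch of the forward pass described by Eq. 7-10 is given below. Parameter names and sizes are illustrative, not the paper's code, but the sketch reflects the stated choices: no biases, no layer norm, heads combined by addition, a single block, and the prediction read from the last position.

```python
# Schematic single-block transformer forward pass (illustrative shapes and names).
import jax
import jax.numpy as jnp

def init_params(key, vocab, seq, d=256, heads=4):
    keys = iter(jax.random.split(key, 32))
    r = lambda *s: jax.random.normal(next(keys), s) * 0.02
    return {
        "tok_emb": r(vocab, d), "pos_emb": r(seq, d),
        "heads": [(r(d, d // heads), r(d, d // heads), r(d, d // heads), r(d // heads, d))
                  for _ in range(heads)],
        "W_in": r(d, 4 * d), "W_out": r(4 * d, d), "W_unembed": r(d, vocab),
    }

def forward(params, x):                           # x: (batch, seq) of token ids
    z = params["tok_emb"][x] + params["pos_emb"][jnp.arange(x.shape[1])]   # Eq. 7
    h = 0.0
    for Wq, Wk, Wv, Wo in params["heads"]:        # heads are added, not concatenated
        q, k, v = z @ Wq, z @ Wk, z @ Wv
        att = jax.nn.softmax(q @ jnp.swapaxes(k, -1, -2) / jnp.sqrt(q.shape[-1]), axis=-1)
        h = h + (att @ v) @ Wo                    # Eq. 8, projected back to d dimensions
    z = z + h                                     # residual connection
    z = z + jnp.maximum(z @ params["W_in"], 0.0) @ params["W_out"]          # Eq. 9
    return z[:, -1] @ params["W_unembed"]         # Eq. 10: logits from the last position

params = init_params(jax.random.PRNGKey(0), vocab=115, seq=3)   # illustrative sizes
logits = forward(params, jnp.zeros((8, 3), dtype=jnp.int32))    # shape (8, 115)
```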
Methods
How exactly a given model implements an algorithm is a non-trivial question; even modular addition is implemented in a relatively obscure way [14], as per Eq. 3, Eq. 4, and Eq. 5.
This investigation probes the fundamental algorithmic structures internalized by a transformer model trained on a set of basic prime-number-related modular arithmetic tasks with slight variations in complexity. This approach provides insights into how and why specific algorithmic patterns emerge from seemingly straightforward learning processes.

As stated, the setup here differentiates itself from $\mathcal{T}_{\text{nanda}}$ in two crucial ways: 1) it is non-commutative; and 2) it is multi-task.
Tasks
Stated plainly: the task $\mathcal{T}_{\text{miiii}}$ predicts the remainder when dividing a two-digit base-$p$ number by each prime $q$ less than $p$. The set of prime factors we construct tasks for is thus $\{q\} = \{q \in \mathbb{P} : q < p\}$. For $p = 113$, this yields 29 parallel tasks, one for each prime less than $p$. Each task predicts a remainder in the range $[0, q - 1]$. This means smaller primes like 2 and 3 require binary and ternary classification, respectively, while the largest prime less than $p$, 109, requires predictions across 109 classes. The tasks thus naturally vary in difficulty: predicting mod 2 requires distinguishing odd from even numbers (which in binary amounts to looking at the last bit), while predicting mod 109 involves selecting between many relatively similar classes. From an information-theoretical perspective, the expected cross-entropy for an $n$-class problem is $\ln(n)$, which has implications for the construction of the loss function, further discussed in the section on training.

Additionally, a baseline task $\mathcal{T}_{\text{basis}}$ was constructed by shuffling the $y$-labels of $\mathcal{T}_{\text{miiii}}$, and a task-ablation test $\mathcal{T}_{\text{masked}}$ was constructed by masking away the four simplest tasks, $q \in \{2, 3, 5, 7\}$.
Data
Input Space ($X$): Each input $x \in X$ represents a number in base $p$ using two digits, $(x_0, x_1)$, where the represented number is $x_0 p^0 + x_1 p^1$. For example, with $p = 11$, the input space consists of all pairs $(x_0, x_1)$ where $x_0, x_1 < 11$, representing numbers up to $11^2 - 1 = 120$. This yields a dataset of 121 samples. Figure 2 visualizes this input space, with each cell representing the value $x_0 p^0 + x_1 p^1$.
Figure 2: Visualizing $X$ (for a small dataset where $p = 11$). Each cell represents the tuple $(x_0, x_1)$. The top left shows 0 as $(0, 0)$, and the bottom right shows 120 as $(10, 10)$, both in base 11.
Output Space ($Y$): For each input $x$, a vector $y \in Y$ contains the remainder when dividing by each prime less than $p$. For $p = 11$, this means predicting the remainder when dividing by 2, 3, 5, and 7. Each element $y_i$ ranges from $0$ to $q_i - 1$, where $q_i$ is the $i$-th prime. Figure 3 visualizes these remainders, with each subplot showing the remainder pattern for a specific prime divisor. For comparison, the rightmost plot shows the output space of [14]'s modular addition task.
Figure 3: Visualizing the tasks in $Y$ (for $p = 11$). $x_0$ and $x_1$ vary along the two axes, with the remainder modulo $q \in \{2, 3, 5, 7\}$ indicated by the square size. Note the innate periodicity of the modulo operator.
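A small sketch of how such a dataset can be constructed for $p = 11$ (variable names are illustrative, not the paper's code):

```python
# Sketch: X holds all (x0, x1) pairs; Y holds x0*p^0 + x1*p^1 modulo each prime q < p.
import jax.numpy as jnp

p = 11
qs = jnp.array([2, 3, 5, 7])                               # primes less than p

x0, x1 = jnp.meshgrid(jnp.arange(p), jnp.arange(p), indexing="ij")
X = jnp.stack([x0.ravel(), x1.ravel()], axis=1)            # (121, 2) input pairs
n = X[:, 0] * p**0 + X[:, 1] * p**1                        # the represented numbers
Y = n[:, None] % qs[None, :]                               # (121, 4) remainder targets

print(X.shape, Y.shape)                                    # (121, 2) (121, 4)
```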
Model
The model follows the original transformer architecture [15] with several key design choices aligned with recent work on mechanistic interpretability [14], [16]: biases are disabled, and layer normalization is not used. The model consists of three main components: an embedding layer, transformer blocks, and an output layer. All weights are initialized following He et al. (2015). The model processes vectors of the kind seen in Eq. 11, writing the eventual result to the last position.

$$ [\, x_0 \;\; x_1 \;\; \hat{y} \,] \quad (11) $$
Training
Hyperparameter optimization was conducted using Optuna [34], searching over the space in Table 1.
dropout              λ            wd                  d           lr            heads
0, 1/2, 1/5, 1/10    0, 1/2, 2    0, 1/10, 1/2, 1     128, 256    3e-4, 1e-4    4, 8

Table 1: Hyperparameter search space for training.
The model is trained using AdamW [35] with $\beta_1 = 0.9$, $\beta_2 = 0.98$, following Nanda et al. (2023). To handle the varying number of classes across tasks (from 2 classes for mod 2 to 109 classes for mod 109), a modified (weighted) mean cross-entropy loss (Eq. 6.1) is created, correcting for the difference in the expected loss within each task. Note that $\mathbb{E}[L_{\text{MCE}}] = \ln(q)$, where $q$ is the number of classes within the task in question. Correcting for this, the loss function becomes as shown in Eq. 12.3.
$$ L_{\mathcal{T}_{\text{miiii}}} = \sum_{q \in \{q\}} \frac{L_{\text{MCE}}^{q}}{\ln(q)} \quad (12.1) $$

$$ = \sum_{q \in \{q\}} \frac{1}{\ln(q)} \cdot \frac{1}{n} \sum_{i=1}^{n} \sum_{j=0}^{q-1} y^{q}_{ij} \ln\!\left(\frac{1}{\hat{y}^{q}_{ij}}\right) \quad (12.2) $$

$$ = \sum_{q \in \{q\}} \sum_{i=1}^{n} \sum_{j=0}^{q-1} \frac{y^{q}_{ij} \ln\!\left(\frac{1}{\hat{y}^{q}_{ij}}\right)}{n \ln(q)} \quad (12.3) $$
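A minimal sketch of this task-weighted loss, assuming per-task predictions $\hat{y}^q$ (probabilities of shape $(n, q)$) and matching one-hot targets $y^q$; names are illustrative:

```python
# Sketch of Eq. 12: one cross-entropy term per prime q, each divided by its expected loss ln(q).
import jax.numpy as jnp

def miiii_loss(y, y_hat, qs, eps=1e-9):
    total = 0.0
    for q in qs:                                   # y[q], y_hat[q]: arrays of shape (n, q)
        ce = jnp.mean(jnp.sum(y[q] * jnp.log(1.0 / (y_hat[q] + eps)), axis=-1))
        total = total + ce / jnp.log(q)            # weight by 1 / ln(q)
    return total
```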
To accelerate generalization, gradient filtering as per Lee et al. (2024) is implemented and replicated.

$$ g_t = \nabla_\theta L + \lambda \left( \alpha\, e_{t-1} + (1 - \alpha)\, g_{t-1} \right) \quad (13) $$
where $e_t$ is the exponential moving average of the gradients with decay rate $\alpha = 0.98$, and $\lambda$ controls the influence of the slow-varying component.
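A minimal per-parameter sketch of this filter, in the spirit of Lee et al. (2024); the exact placement of the EMA term may differ from the paper's implementation:

```python
# Sketch: boost the slow-varying (generalizing) gradient component before the optimizer step.
import jax

def filter_grads(grads, ema, lam=0.5, alpha=0.98):
    # ema: running exponential moving average of gradients; grads: raw gradients at step t
    ema = jax.tree_util.tree_map(lambda e, g: alpha * e + (1 - alpha) * g, ema, grads)
    boosted = jax.tree_util.tree_map(lambda g, e: g + lam * e, grads, ema)
    return boosted, ema   # feed `boosted` to the optimizer, carry `ema` to the next step
```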
The training uses full-batch gradient descent with the entire dataset of $p^2$ samples (12769 when $p = 113$). The model is evaluated on a held-out validation set after each epoch, tracking per-task accuracy and loss. As in the setup used for $\mathcal{T}_{\text{nanda}}$, training was done on thirty percent of the total dataset, with the remainder used for validation (1000 samples) and testing (the rest). Further, as $\mathcal{T}_{\text{miiii}}$ involves learning 29 tasks (when $p = 113$) rather than 1, and due to each task's non-commutativity, a larger hidden dimension of 256 was added to the hyperparameter search space, as well as the option of 8 heads ($\mathcal{T}_{\text{nanda}}$ was solved with a hidden dimension of 128 and 4 heads). The number of transformer blocks was kept at 1, as this ensures consistency with $\mathcal{T}_{\text{nanda}}$ (and as full generalization was possible, as we shall see in the results).

Training was done on an NVIDIA GeForce RTX 4090 GPU, with Python 3.11 and extensive use of JAX 0.4.35 and its associated ecosystem. Neuron activations were calculated at every training step and logged for later analysis.
Visualization
Much of the data worked with here is inherently high-dimensional. For training, for example, we have $n$ steps, two splits (train/valid), about $\frac{p}{\ln(p)}$ tasks, and two metrics (accuracy and loss). This, along with the inherent opaqueness of deep learning models, motivated the development of a custom visualization library, esch³, to visualize attention weights, intermediate representations, training metrics, and more. To familiarize the reader with visualizing the inner workings of a trained model, an essential plot type to keep in mind is seen in Figure 4. As there are only 12769 samples when $p = 113$, all samples can be fed to the model at once. Inspecting a specific activation thus yields a $1 \times 12769$ vector $v$, which can be reshaped into a $113 \times 113$ matrix, with the two axes, $x_0$ and $x_1$, each varying from 0 to 112. The top-left corner then shows the given value for the sample $(0 \cdot p^0 + 0 \cdot p^1)$, and so on.
Figure 4: Plotting a neuron: (left) the activation of a particular neuron as $x_0$ and $x_1$ vary from $0$ to $p$; (right) the same data processed with a fast Fourier transform to show the active frequencies ($\omega$).
Note that in esch plots, when appropriate, only the top-left $37 \times 37$ slice is shown so as not to overwhelm the reader.
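A small sketch of the procedure behind Figure 4: reshape one neuron's activation over all $p^2$ samples into a $p \times p$ grid and inspect its spectrum with a 2D FFT. The activation below is a synthetic periodic stand-in, not a recorded one.

```python
# Sketch: reshape a per-sample activation vector into a (p, p) grid and take its 2D spectrum.
import jax.numpy as jnp

p = 113
x0, x1 = jnp.meshgrid(jnp.arange(p), jnp.arange(p), indexing="ij")
# stand-in for one neuron's activation over all p**2 samples (a real run would use logged activations)
grid = jnp.cos(2 * jnp.pi * 7 * x0 / p) * jnp.cos(2 * jnp.pi * 7 * x1 / p)
spectrum = jnp.abs(jnp.fft.fft2(grid))     # peaks mark the active frequencies (omega), as in Figure 4 (right)
```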
Mechanistic interpretability process
Recall that a composition of linear maps is itself a linear map. Therefore, as a mechanistic-interpretability rule of thumb, one should look at the outputs of the non-linear transformations. In our case, these are the attention weights and the intermediate representations within the transformer block's feed-forward layer (which follow the ReLU activation). Additionally, the embedding layers will be inspected using Fourier analysis and singular value decomposition. As mentioned, our interpretability approach combines activation visualization with frequency analysis to understand the learned algorithmic patterns. Following Nanda et al. (2023), we analyze both the attention patterns and the learned representations through several lenses:
Attention visualization
esch, the custom visualization library, is used to visualize attention weights and intermediate representations. The library allows for the visualization of attention patterns across different layers, as well as of intermediate representations at each layer. These visualizations provide insights into the learned patterns and help identify potential areas of improvement.
The fast Fourier transform
As periodicity is established by Nanda et al. (2023) as a fundamental feature of the model trained on $\mathcal{T}_{\text{nanda}}$, the fast Fourier transform (FFT) algorithm is used to detect which frequencies are in play. Note that any square image can be described as a sum of 2D sine and cosine waves, varying in frequency from 1 to half the size of the image (plus a constant). This is a fundamental tool used in signal processing. The theory is briefly outlined in for reference. This analysis helps identify the dominant frequencies in the model's computational patterns. Recall that a vector can be described as a linear combination of other, periodic vectors, as per the discrete Fourier transform.

The default basis of the one-hot encoded representation of the input is thus the identity matrix. This can be projected into a Fourier basis by multiplying with the discrete Fourier transform (DFT) matrix visualized in .
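A small sketch of this projection (the embedding matrix here is a random stand-in for the trained one):

```python
# Sketch: project (stand-in) token embeddings into the Fourier basis via the DFT matrix.
import jax
import jax.numpy as jnp

p, d = 113, 256
W_E = jax.random.normal(jax.random.PRNGKey(0), (p, d))    # stand-in for learned token embeddings
F = jnp.fft.fft(jnp.eye(p), axis=0)                        # p x p DFT matrix
fourier_emb = F @ W_E.astype(jnp.complex64)                # embeddings expressed in the Fourier basis
norms = jnp.abs(fourier_emb).sum(axis=1)                   # energy per frequency, as in Figures 11-12
```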
Results and analysis
Hyper-parameter optimization
The best-performing hyperparameters for training the model on $\mathcal{T}_{\text{miiii}}$ are listed in Table 2. Notably, the model did not converge when $\lambda = 0$, confirming the utility of the gradient amplification method proposed by Lee et al. (2024) in the context of $\mathcal{T}_{\text{miiii}}$.
dropout    λ      wd     d      lr          heads
1/10       1/2    1/3    256    3 × 10⁻⁴    4

Table 2: Result of the hyperparameter search over $\mathcal{T}_{\text{miiii}}$.
Model Performance
Figure 5 shows the training and validation accuracy on $\mathcal{T}_{\text{miiii}}$ over time. The model achieves a perfect accuracy of 1 on the validation set across all 29 tasks. The cross-entropy loss in Figure 6 echoes this. In short, and to use the terminology of Power et al. (2022), the model "grokked" on all tasks. Interestingly, the tasks corresponding to modulo 2, 3, 5, and 7 generalized in succession, while the remaining 25 tasks generalized around epoch 40000 in no particular order. This might suggest that the model initially learned solutions for the simpler tasks and later developed a more general computational strategy that allowed it to generalize across the remaining, more complex tasks.
Figure 5: Accuracy training "curves": training (top) and validation (bottom) accuracy over time ($x$-axis in log scale). We see grokking occur on all tasks, first for $q \in \{2, 3, 5, 7\}$, in that order, and then for the remaining 25 in no particular order.
Figure 6: Cross-entropy (Eq. 6.1) loss on training (top) and validation (bottom) over time (note the log scale on the $x$-axis).
Embeddings
Positional embeddings play a crucial role in transformers by encoding the position of tokens in a sequence. Figure 7 compares the positional embeddings of the models trained on $\mathcal{T}_{\text{nanda}}$ and $\mathcal{T}_{\text{miiii}}$.
For $\mathcal{T}_{\text{nanda}}$, which involves a commutative task, the positional embeddings are virtually identical, with a Pearson correlation of 0.95, reflecting that the position of the input tokens does not significantly alter their contribution to the task. In contrast, for $\mathcal{T}_{\text{miiii}}$, the positional embeddings have a Pearson correlation of −0.64, indicating that the embeddings for the two positions are different. This difference is expected due to the non-commutative nature of the task, where the order of $x_0$ and $x_1$ matters ($x_0 \cdot p^0 \neq x_0 \cdot p^1$). This confirms that the model appropriately encodes position information for solving the tasks.
Figure 7: Positional embeddings for $(x_0, x_1)$ for the models trained on $\mathcal{T}_{\text{nanda}}$ (top) and $\mathcal{T}_{\text{miiii}}$ (bottom). Pearson's correlation is 0.95 and −0.64, respectively. This reflects the commutativity of $\mathcal{T}_{\text{nanda}}$ and the lack thereof for $\mathcal{T}_{\text{miiii}}$. Hollow cells indicate negative numbers.
Recall that a matrix $\mathbf{M}$ of size $m \times n$ can be decomposed into its singular values, $\mathbf{M} = \mathbf{U} \mathbf{\Sigma} \mathbf{V}^{T}$ (with the transpose being the complex conjugate when $\mathbf{M}$ is complex), where $\mathbf{U}$ is $m \times m$, $\mathbf{\Sigma}$ is an $m \times n$ rectangular diagonal matrix (whose diagonal is represented as a flat vector throughout this paper), and $\mathbf{V}^{T}$ is an $n \times n$ matrix. Intuitively, this can be thought of as a rotation in the input space, followed by a scaling, followed by a rotation in the output space.
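A small sketch of this decomposition as used on the token embeddings (with a random stand-in matrix in place of the learned one):

```python
# Sketch: SVD of a (stand-in) token embedding matrix and the cumulative variance it explains.
import jax
import jax.numpy as jnp

W_E = jax.random.normal(jax.random.PRNGKey(0), (113, 256))   # stand-in for learned embeddings
U, S, Vt = jnp.linalg.svd(W_E, full_matrices=False)          # W_E = U @ diag(S) @ Vt
explained = jnp.cumsum(S**2) / jnp.sum(S**2)                 # variance accounted for (cf. the ticks in Figure 8)
k50 = int(jnp.searchsorted(explained, 0.5)) + 1              # number of components needed for 50% variance
```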
Figure 8 displays the singular values of the token embeddings learned for $\mathcal{T}_{\text{nanda}}$ and $\mathcal{T}_{\text{miiii}}$. The singular values for $\mathcal{T}_{\text{miiii}}$ are more diffuse, indicating that a larger number of components are needed to capture the variance in the embeddings compared to $\mathcal{T}_{\text{nanda}}$. This suggests that the token embeddings for $\mathcal{T}_{\text{miiii}}$ encode more complex information, reflecting the increased complexity of the multi-task learning scenario.
Figure 8: First 83 of 113 singular values (truncated for clarity) of the token embedding matrix for $\mathcal{T}_{\text{nanda}}$ (top) and $\mathcal{T}_{\text{miiii}}$ (bottom). The ticks indicate the points where 50% and 90% of the variance is accounted for. We thus see that, for $\mathcal{T}_{\text{miiii}}$, the embedding space is much more crammed.
Figure 9: $\mathcal{T}_{\text{nanda}}$'s most significant singular vectors of $\mathbf{U}$ (cutoff at 0.5, as per Figure 8) from the singular value decomposition. Note that these look periodic!

Figure 10: $\mathcal{T}_{\text{miiii}}$'s most significant vectors of $\mathbf{U}$. Note that, as in Figure 9, we still observe periodicity, but there are more frequencies in play, as further explored in Figure 12.
Figure 9 and Figure 10 present the most significant singular vectors of $\mathbf{U}$ for $\mathcal{T}_{\text{nanda}}$ and $\mathcal{T}_{\text{miiii}}$, respectively. Visual inspection shows periodicity in the top vectors for both models, but the $\mathcal{T}_{\text{miiii}}$ model requires more vectors to capture the same amount of variance, consistent with the diffuse singular values observed in Figure 8.
To further understand the structure of the token embeddings, we applied the fast Fourier transform (FFT). Only a few frequencies are active for $\mathcal{T}_{\text{nanda}}$, as seen in Figure 11, consistent with the model implementing a cosine-sine lookup table as described in Nanda et al. (2023).
For the $\mathcal{T}_{\text{miiii}}$ model, we observe a broader spectrum of active frequencies (Figure 12). This is expected, as the model must represent periodicity corresponding to 29 primes.
Comparing with $\mathcal{T}_{\text{basis}}$ in Figure 13, the periodicity is understood to be a structure inherent to the data that is picked up by the model.