The Residual Stream
This post discusses the residual stream in transformer models and builds towards using it for two examples of mechanistic interpretability: Logit Lens (nostalgebraist 2020) and Linear Probes (Alain and Bengio 2018).
Transformers
Much of the description of a transformer here comes straight from the excellent “A Mathematical Framework for Transformer Circuits” (Elhage et al. 2021).
The transformer is made of a few crucial components:
- Token embedding, \(W_{E}\)
- Residual Stream, \(x_{i}\)
- Attention layer, \(Attn(x_{i}) = \sum_{h \in H} h(x_{i})\)
- MLP Layer, \(MLP(x_{i})\)
- Unembedding, \(W_{U}\)
I will give a high-level overview of what some of these elements do by considering the input sequence,
“The capital of the country containing Manchester1 is”
1 Getting a sensible answer from small LLMs is hard: GPT-2 prompted with “The capital of the country containing Lyon is” returned “the” with 30% confidence and “France” with only 4% confidence.
We will follow the final token “is” as this is the token from which the model will predict the next token.
The first step is to use the Token embedding, \(W_{E} \in \mathbb{R}^{d_{model} \times d_{vocab}}\), which acts as a look-up table: it maps the token “is”, represented initially as a one-hot token vector \(t \in \mathbb{R}^{d_{vocab}}\), to the \(d_{model}\)-dimensional vector that represents it,
\[ x_{0} = W_{E} t \tag{1}\]
This is the source of the residual stream: each process in the model adds something to this initial vector \(x_{0}\)2. The vector \(x_{0}\) starts as an isolated embedding of the token “is” but finishes as a representation capable of predicting the next token in the sequence.
2 There has been some progress towards not using the simple additive residual stream. One such example is Hyper-Connections (Zhu et al. 2025), which allow the strength of connections between layers to be learnt during training. With the residual stream being such a crucial tool for safety researchers, I think it would be a major setback if connection variants such as this became the norm.
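To make Equation 1 concrete, here is a minimal sketch (assuming the Hugging Face transformers library and PyTorch) that pulls the embedding of the final token straight out of GPT-2; positional embeddings are ignored for simplicity.

```python
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

prompt = "The capital of the country containing Manchester is"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# Hugging Face stores the embedding matrix as (d_vocab, d_model), so selecting
# a row is equivalent to W_E t for a one-hot token vector t.
W_E = model.transformer.wte.weight  # (50257, 768) for GPT-2 small
x0 = W_E[input_ids[0, -1]]          # x_0 for the final token " is"
print(x0.shape)                     # torch.Size([768])
```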
Most commonly, a layer3 of a transformer seems to refer to the combination of an Attention layer and an MLP layer, but these add to the residual stream sequentially,
3 (Elhage et al. 2021) also refers to this as a “residual block”.
\[ x_{n+1} = x_{n} + Attn(x_{n}) + MLP(x_{n} + Attn(x_{n})) \tag{2}\]
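As a rough sketch of Equation 2 in code, here is a toy PyTorch residual block (not any particular implementation; layer norms and causal masking are omitted to match the simplified equation):

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """One simplified transformer layer as in Equation 2: attention then MLP,
    each adding its output back into the residual stream."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attn_out, _ = self.attn(x, x, x)  # Attn(x_n)
        x = x + attn_out                  # x_n + Attn(x_n)
        x = x + self.mlp(x)               # ... + MLP(x_n + Attn(x_n))
        return x

x = torch.randn(1, 9, 768)              # (batch, sequence, d_model)
print(ResidualBlock(768, 12)(x).shape)  # torch.Size([1, 9, 768])
```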
Attention heads move information from the residual streams of other tokens in the sequence. For example, if we are going to answer the question,
“The capital of the country containing Manchester is”
our final representation \(x_{-1}\) is going to need to capture the multi-hop reasoning that the country containing Manchester is the United Kingdom and hence that the capital is London; to do so, the model needs to move information from these other tokens.
It must also know these geographical facts; research suggests that this knowledge is captured somewhere within the MLP layers (Nanda et al. 2024).
Why residuals?
There are a few reasons why this additive approach is favoured, but the inclusion of residual connections is likely a legacy of the Residual Neural Network (ResNet) (He et al. 2015). That paper suggested a residual stream formulation,
\[ x + f(x) \tag{3}\]
which has since also become known as a skip connection. The layers of the network return a “residual mapping” \(f(x)\), which is added to the main input. In their paper the primary motivation was to combat the degradation problem, where deeper networks have higher training and test error. This should be slightly surprising, as we might expect anything a 20-layer model can do to be replicable by a 50-layer model, for example by leaving the extra layers as identity mappings.
It turns out that for networks without a residual stream the correlation between gradients decays exponentially with depth (Balduzzi et al. 2018). That is to say, at a point \(x\) the gradient \(\nabla f(x)\) might be almost random compared to the gradient \(\nabla f(x + \delta)\) at a nearby point \(x + \delta\). If the gradients are almost random then so is gradient descent: we set off down the slope only to immediately find it rising again. In contrast, with skip connections the correlation between gradients decays sub-linearly.
A subtly different, additional benefit of using a residual stream is that it also addresses the vanishing gradient problem, whereby backpropagation in a deep network involves so many applications of the chain rule that, if any of the intermediate gradients are small, the total update to the weights vanishes.
Following this 3Blue1Brown explainer (Sanderson 2017), we consider a very simple network with single-neuron layers connected in sequence. Let’s say every layer consists of,
\[ f(x_{i}) = \sigma(W x_{i} + b) \tag{4}\]
The sigmoid function4, \(\sigma\), is defined as,
4 ReLU and GELU activation functions are designed to also combat the vanishing gradient problem.
\[ \sigma(z) = \frac{e^z}{1+e^z} \tag{5}\]
Its derivative, \(\sigma'(z) = \sigma(z)(1 - \sigma(z))\), is bounded within the interval \((0, 0.25]\). Hence, when we calculate the gradient of the loss, \(\nabla \mathcal{L}\), in backpropagation this term reduces the resulting value by at least a factor of 4 at each layer. As we update the weights according to \(-\eta \nabla \mathcal{L}\), when the gradient vanishes we stop learning.
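Written out for the single-neuron chain above (a sketch treating each \(x_{i}\) and \(W\) as scalars), the repeated chain rule makes the shrinkage explicit:

\[ \sigma'(z) = \sigma(z)\bigl(1 - \sigma(z)\bigr) \leq \frac{1}{4}, \qquad \frac{\partial \mathcal{L}}{\partial x_{0}} = \frac{\partial \mathcal{L}}{\partial x_{N}} \prod_{i=0}^{N-1} W \, \sigma'(W x_{i} + b) \]

Unless the weights \(W\) grow to compensate, this product shrinks by at least a factor of 4 for every layer it passes through.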
With a residual stream there will always be a non-zero gradient,
\[ x_{i+1} = x_{i} + f(x_{i}) \tag{6}\]
\[ \frac{\partial(x_{i+1})}{\partial x_{i}} = 1 + f'(x_{i}) \tag{7}\]
The constant term \(1\) means even with repeated applications of the chain rule it is always possible to trace through the network to get a non-zero update to earlier layers.
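Chaining Equation 7 across many layers (again in the scalar sketch) shows where the non-vanishing path comes from:

\[ \frac{\partial x_{N}}{\partial x_{0}} = \prod_{i=0}^{N-1} \bigl(1 + f'(x_{i})\bigr) = 1 + \sum_{i=0}^{N-1} f'(x_{i}) + \dots \]

The leading \(1\) is the path that skips every layer, so the gradient reaching the earliest layers cannot be driven to zero by small \(f'(x_{i})\) terms alone.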
The Residual Stream in Mechanistic Interpretability
With this consistent \(d_{model}\)-dimensional vector flowing through the model we have an obvious place to inspect the model’s thoughts. As each layer refines the representation of the “is” token in preparation for predicting the next, we can explore two ways to interpret this process:
- Logit Lens (nostalgebraist 2020)
- Linear Probes (Alain and Bengio 2018)
In the rest of this post I present some experiments that use these techniques to illustrate their value.
Logit Lens
The idea behind Logit Lens is simple: to the final residual stream vector \(x_{-1}\), after the last layer of the transformer, we apply the unembedding to get logits,
\[ T(t) = W_{U} x_{-1} \tag{8}\]
which allow us to predict words. For our input sequence GPT-2 returned:
- “London” with 13% confidence
- “Birmingham” with 8% confidence
- “not” with 6% confidence
- …
We can, however, apply the unembedding to intermediate vectors in the residual stream, getting a crude understanding of what the model is thinking at that point. This allows us, in a very simple way, to see the model’s thoughts progress through the layers.
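A sketch of what this might look like in code, assuming the Hugging Face transformers library (note that, following common practice, the final layer norm is applied before unembedding, a detail Equation 8 glosses over):

```python
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# A rough Logit Lens sketch: unembed the residual stream after every layer.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

prompt = "The capital of the country containing Manchester is"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

with torch.no_grad():
    out = model(input_ids, output_hidden_states=True)

# hidden_states[0] is the embedding output; hidden_states[i] is the residual
# stream after layer i. We look at the final token's vector in each.
for layer, h in enumerate(out.hidden_states):
    x = model.transformer.ln_f(h[0, -1])  # final layer norm, as at the output
    logits = model.lm_head(x)             # W_U x (Equation 8); tied to W_E in GPT-2
    token = tokenizer.decode(logits.argmax().item())
    print(f"layer {layer:2d}: {token!r}")
```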
Linear Probes
Linear probes approach the interpretability problem in a different way. Imagine we have two classes of inputs; a linear probe is a classifier trained on the residual stream vectors to distinguish between the classes. Examples might include:
- Proper punctuation vs. ALL CAPS
- Natural language vs. HTML
- Harmless text vs. Harmful text
The first case, we might imagine, is quite simple to learn: there is a simple pattern-matching problem for the model to spot between "Hello, can I help you?" and "HELLO, CAN I HELP YOU?". Likewise, the second is hopefully quite easy to discern, but perhaps requires some more thinking: "The main header of the page says Hello." looks very different to "<h1 class='main-header'> Hello </h1>", but it requires more knowledge to consistently separate the two.
Finally, examples of harmless and harmful text might be really tricky to separate, requiring genuine semantic understanding. For example, "There is a bug in your code, so you should **refactor the function.**" is quite harmless whereas "There is a bug in your code, so you should **disable the firewall.**" could be catastrophic.
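As an illustrative sketch of the first case, a linear probe can be as simple as a logistic regression fitted on residual stream vectors collected at a chosen layer (the layer choice and the tiny toy dataset below are placeholders, not the setup used in the experiments later in this post):

```python
import torch
from sklearn.linear_model import LogisticRegression
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# A minimal linear-probe sketch: collect residual stream vectors at one layer
# for two classes of text, then fit a logistic regression to separate them.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

LAYER = 6  # which residual stream position to probe (a free choice)

def residual_vector(text: str) -> torch.Tensor:
    """Residual stream vector of the final token after layer LAYER."""
    ids = tokenizer(text, return_tensors="pt").input_ids
    with torch.no_grad():
        hidden = model(ids, output_hidden_states=True).hidden_states
    return hidden[LAYER][0, -1]

# Toy placeholder dataset: class 0 = proper punctuation, class 1 = all caps.
texts = ["Hello, can I help you?", "HELLO, CAN I HELP YOU?",
         "The weather is nice today.", "THE WEATHER IS NICE TODAY."]
labels = [0, 1, 0, 1]

X = torch.stack([residual_vector(t) for t in texts]).numpy()
probe = LogisticRegression(max_iter=1000).fit(X, labels)
print(probe.predict(X))  # ideally the classes also separate on held-out text
```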