import numpy as np
d_model = 768
n_layers = 12

# Synthetic stand-in for a residual stream: one d_model-dimensional vector per
# layer, drawn from a standard normal distribution. A fixed seed keeps every
# figure below identical from run to run (important for a published post).
rng = np.random.default_rng(0)
residual_stream = rng.standard_normal((n_layers, d_model))

residual_stream.shape
(12, 768)

Let’s use einops to fold each 768-dimensional vector into a 16 × 48 grid, which is much easier to visualize than a long 1-D array.

from einops import rearrange, reduce, repeat

# Fold each flat vector into a grid with 16 rows; einops infers the column
# count (input length / 16) and raises if the length is not divisible by 16.
residual_stream = rearrange(
    residual_stream,
    "layer (row col) -> layer row col",
    row=16,
)

residual_stream.shape
(12, 16, 48)

Let’s start with the simplest possible plot: a heatmap of a single layer’s grid.

import matplotlib.pyplot as plt

# Quick first look: draw layer 0's grid as an image using pyplot's defaults.
plt.imshow(residual_stream[0, :, :], cmap="viridis")
plt.title("Residual Stream at Layer 0")
plt.show()

For the blog we want something much prettier: no axes, a black background, and thin black borders between the cells.

# Blog styling: hide the axes, paint the figure (and axes) background black,
# and outline every cell in black. The colour range is pinned to [-3, 3] so
# that later updates of this mesh share one colour scale.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(15, 5))
ax.axis(False)
fig.set_facecolor('black')
ax.set_facecolor('black')

cmesh = ax.pcolormesh(
    residual_stream[0],
    cmap='viridis',
    edgecolors='k',
    vmin=-3,
    vmax=3,
)

Next, let’s animate the heatmap so it steps through the layers one frame at a time.

from matplotlib.animation import FuncAnimation

def anim_function(frame_num):
    """Show layer ``frame_num`` in the existing heatmap.

    Updates the colour data of the global ``cmesh`` in place instead of
    creating a new plot, and returns the modified artist in a 1-tuple (the
    iterable-of-artists form FuncAnimation accepts from its update callback).
    """
    cmesh.set_array(residual_stream[frame_num])
    return (cmesh,)


# One frame per layer, 30 ms between frames.
anim = FuncAnimation(
    fig,
    anim_function,
    frames=np.arange(residual_stream.shape[0]),
    interval=30,
)

from IPython.display import HTML

# Render the animation as a self-contained JavaScript/HTML player.
HTML(anim.to_jshtml())

Now let’s do the same thing with real model activations!

import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer

# Authenticate with the Hugging Face Hub (prompts for a token).
# NOTE(review): the public "gpt2" checkpoint presumably does not require
# this; it matters only if you switch to a gated model.
login()

model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-2's tokenizer has no padding token, so reuse end-of-sequence for
# padding (needed for padding=True when tokenizing).
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)
GPT2LMHeadModel LOAD REPORT from: gpt2
Key                  | Status     |  | 
---------------------+------------+--+-
h.{0...11}.attn.bias | UNEXPECTED |  | 

Notes:
- UNEXPECTED    :can be ignored when loading from different task/architecture; not ok if you expect identical arch.
# Tokenize a single prompt into PyTorch tensors. padding=True relies on
# tokenizer.pad_token having been set when the tokenizer was loaded.
inputs = tokenizer(
    "Paris is the capital of",
    return_tensors="pt",
    padding=True,
    truncation=True,
)

# input_ids is (num_sequences, max_length).
input_ids = inputs["input_ids"]
print(f"Tokenized {input_ids.shape[0]} sequences with max length {input_ids.shape[1]}.")
Tokenized 1 sequences with max length 5.
# Inference-only forward pass: no autograd graph is built. Passing
# output_hidden_states=True makes the model also return the hidden state
# after every stage, in addition to its usual outputs.
with torch.no_grad():
    outputs = model(
        output_hidden_states=True,
        attention_mask=inputs["attention_mask"],
        input_ids=inputs["input_ids"],
    )
type(outputs.hidden_states)
tuple
# Last entry of hidden_states, first sequence in the batch, last token position.
# Indexing with -1 instead of a hard-coded 12 keeps this correct for GPT-2
# variants with a different number of transformer blocks.
# NOTE(review): Hugging Face's GPT-2 appends this final entry after applying the
# final LayerNorm (ln_f), so it is not the raw output of the last block's
# residual add — confirm before reading meaning into its magnitudes.
outputs.hidden_states[-1][0, -1, :]
tensor([ 3.9707e-01,  6.5592e-01, -1.0789e-01, -5.1339e-01,  2.2478e-01,
        -3.5340e-01,  2.9402e+00,  5.0266e-01, -1.1648e-01, -3.3925e-01,
        -1.4641e-01,  1.3185e-01,  5.6856e-01, -2.4193e-01,  2.3383e-01,
        -6.1204e-01, -1.6270e-01,  2.6729e-01,  1.0813e-01, -3.3550e+00,
         3.8122e-01,  3.1480e-01,  3.6961e-01, -1.1843e-01, -4.1384e-02,
         2.7716e-01, -6.5405e-01, -6.8845e-01,  1.7058e-01,  2.3255e-01,
         2.1136e-02, -2.7776e-01,  1.7016e-01,  3.5261e-01,  6.2866e-01,
        -2.2970e-01,  5.0996e+01, -3.4832e-01, -4.9283e-01,  1.6009e-01,
        -2.5912e-01, -4.3954e-02, -6.6652e-02,  1.7371e-02, -4.9989e-02,
         7.7331e-02, -2.0900e-01, -5.2491e-01, -9.1295e-01, -7.5307e-02,
        -5.9557e-02, -3.5249e-01,  9.1986e-02, -4.4967e-02,  1.7474e-01,
         1.0622e+00, -3.1931e-02, -2.3144e-01,  1.4521e-01, -4.0620e-01,
        -5.8457e-01,  1.4042e-02,  1.0694e-01, -1.2151e-01, -1.1069e+00,
         4.4808e-01, -1.6879e-01,  4.8419e-02, -1.2701e+00,  1.9134e-01,
        -7.3997e-01, -6.9388e-01,  9.3711e-03, -3.0020e-01, -2.9383e-02,
         1.1450e-01,  2.2709e-01, -6.2942e-01,  3.4850e-01,  4.6560e-01,
         2.9459e-01, -2.3642e-01,  4.6218e-01, -7.3631e-02,  4.3862e-01,
         1.2286e-01,  9.8999e-01, -1.0245e+00,  2.4983e-01, -1.2601e-01,
         4.3485e-01,  6.3421e-01,  3.8452e-01,  9.8958e-03, -6.1150e-02,
        -2.8366e-01, -3.5584e-01, -3.4806e-01,  4.6458e-02,  1.0344e-01,
        -1.7640e-01,  2.6008e-01,  3.0842e-01, -5.9029e-01,  6.2168e-02,
         8.0234e-02, -2.1006e-01, -1.3377e+00,  4.7618e-01, -2.3790e-01,
         5.6921e-01, -1.8896e-01,  8.6591e-02,  1.3231e-02,  1.2710e-01,
        -5.3515e-01,  3.2699e-01, -2.2347e-01,  2.8715e-01, -1.1292e+00,
         7.7704e-02,  5.1702e-01,  2.5250e-02,  8.0246e-01, -2.2748e-01,
         4.8657e-01,  4.7917e-01,  5.0834e-01, -1.8833e-01,  5.0228e-01,
        -3.7779e-01, -3.9931e-01,  2.4244e-01,  2.1779e-01, -1.0248e-01,
         4.3215e-01,  1.9822e-01, -4.4591e-01,  5.6154e-01,  3.0105e-01,
        -4.6189e-01, -8.0389e-01,  9.8432e-01, -3.1944e-01,  9.4833e-02,
         1.1400e-01, -1.5258e-01,  8.6682e-02, -5.9506e-03,  1.0061e+00,
         3.5015e-01, -5.4344e-01, -1.4929e-02,  2.9061e-01,  7.7716e-01,
         5.7593e-02,  4.1433e-01,  2.0678e-01,  1.1295e-01, -3.3525e-01,
        -6.4733e-01, -2.0853e-01, -6.8152e-01, -2.3764e-01, -7.2244e-03,
         5.3895e-02, -1.7540e-02, -3.5193e+00, -2.7689e-01, -1.4060e-01,
         2.2860e-01, -2.3910e-01,  1.4758e-01,  4.6928e-01, -4.1897e-01,
        -8.1610e-01,  6.8833e-01, -5.3819e-01, -5.4749e-02, -1.4693e-01,
         1.5520e-01,  3.4751e-01, -1.0811e-02, -1.1496e-01, -3.7802e-01,
        -5.4832e-02, -1.7603e-01, -1.5183e-01, -3.5338e-03, -1.5314e-02,
        -1.0327e-01, -4.9390e-01, -9.5211e-03,  2.2363e-01, -8.6061e-02,
        -1.6430e+00,  4.8663e-01,  2.2926e-01, -1.5008e-01,  4.6846e-03,
         1.8621e-01,  5.8974e-01, -2.3210e-01, -3.8068e-02,  5.2106e-02,
        -2.1625e-01,  1.3113e-01, -1.6528e-01, -3.0899e-01, -8.6106e-02,
         1.1730e-01, -2.2622e-01, -3.6722e-01, -1.5751e-01, -3.6578e-02,
        -3.0102e-01, -4.9017e-02, -4.6044e-01,  4.7928e-01,  3.5405e-02,
         9.5526e-02, -1.9290e-01,  3.5486e-01,  3.5565e-01, -4.3436e-01,
        -3.7214e-01,  3.7490e-01, -1.4289e-01,  1.4892e-01,  4.4548e-01,
        -6.1742e-02,  3.3417e-01,  1.5622e+00, -1.6550e-01, -9.5622e-02,
        -3.1609e-01,  3.5710e-01,  1.5547e-01,  3.7913e-01, -4.5316e-01,
        -3.0411e-01,  6.3645e-01,  4.6493e-01, -2.7021e-01,  2.8007e-01,
         3.5020e-01,  4.3037e-01, -2.0467e-01, -1.7361e-01,  5.0707e-02,
         5.2277e-01, -1.6470e-01,  2.9059e-01,  9.7539e-02, -8.1670e-02,
         3.0347e-01,  3.0409e-01, -1.9271e-01, -3.8301e-01,  5.5685e-01,
        -2.9989e-01,  7.5335e-01, -1.9442e-01, -8.1730e-01, -5.7642e-02,
         4.5157e-01, -1.0122e+00,  3.4106e-02,  3.7177e-01,  4.8461e-01,
         1.0395e+00,  3.5976e-01,  1.9491e-01, -1.8765e-01,  9.0858e-02,
        -1.2970e-01, -6.9582e-02, -3.3006e-01,  1.2637e-01, -1.4938e-01,
         1.7854e-01, -3.3438e-02,  2.9249e-01, -6.0153e-02, -3.9808e-01,
        -4.7156e-02,  1.6853e-01, -2.4003e-02, -9.1483e-02,  3.4247e-01,
         2.2795e-01,  3.3932e-01, -3.3874e-01, -5.3320e-02, -1.2861e-02,
        -1.2411e-02,  3.7387e-02,  5.8254e-02, -2.5397e-01,  5.2348e-01,
         8.6658e-01, -3.5760e-01,  1.3084e-01, -1.1010e-01,  5.1624e-01,
         6.0349e-01, -4.8093e-01,  1.6971e-01,  1.2630e+00, -2.6303e-01,
         1.8936e-01,  2.1825e-01,  2.9491e-02, -3.4005e-01, -4.3871e+01,
         4.7595e-02, -1.4228e-01, -1.0846e-01, -1.5160e-01,  1.0329e-01,
         3.7053e-01, -1.2320e-02,  3.1999e-01,  5.6906e-02,  3.0835e-01,
         2.0784e-01, -3.3139e-01, -4.5700e-01, -1.5465e-01,  1.6962e-01,
        -3.5235e-01, -7.5174e-03,  4.5752e-01,  4.0825e-01, -1.1673e-01,
        -1.0488e-02,  2.5212e-02, -1.6757e-01,  2.3002e-01, -1.8673e-02,
        -1.7697e-01, -2.1341e-01, -6.6391e-01, -1.8274e-01,  7.0124e-01,
         2.5572e-01,  1.5160e-01, -2.0538e-01,  1.8113e-01, -4.6497e-02,
        -2.5627e-01, -4.6529e-01,  1.9979e-01, -1.2916e-01, -3.4473e-01,
        -4.8907e-01, -3.6988e-01,  1.0921e-01, -1.2943e-01, -1.2346e-01,
         8.8096e-02,  1.2546e+00,  6.3094e+00, -2.5674e-01, -4.3520e-01,
        -2.3859e+00,  4.8746e-01, -3.9125e-02, -4.1557e-02, -2.7922e-01,
         8.7475e-02,  2.1376e-01, -6.5235e-02, -1.2277e+00, -2.0121e+01,
         4.7794e-01,  3.3857e-01,  1.3460e-01,  4.3976e-01,  1.7539e-01,
         2.3180e-01, -3.4298e-01,  2.6690e-01, -3.0966e-01, -8.0459e-01,
         1.0486e-01,  2.3068e-01, -2.7034e-01, -7.2113e-02,  1.2232e+00,
        -3.3541e-02,  1.2416e-01, -5.4055e-02,  3.2458e-01, -2.6819e-01,
        -2.7736e-01, -1.0285e-01, -3.2138e-01,  4.5720e-01,  1.7563e-01,
        -5.7128e-02,  1.5178e-01, -4.0476e-01,  5.8006e-02,  3.9195e-02,
        -2.5981e-01,  2.0811e-01, -2.8853e-01,  6.6633e+00, -5.7252e-01,
        -5.5131e-02,  7.3943e-02,  3.7546e-01,  5.8285e-02, -6.3121e-01,
        -5.0900e-01, -5.1948e-02, -1.1226e-01, -1.6612e-01, -1.6518e-01,
         4.8815e-01,  1.6613e-01, -4.5851e-01, -5.6316e-01, -3.2561e-02,
        -1.4305e-01, -2.9152e-01, -6.2530e-01, -1.1286e-01,  6.0920e-03,
         7.0946e+01,  4.8999e-01,  1.3975e-01, -1.2525e-01, -1.5065e-01,
         1.6868e-01, -2.7682e-02, -4.1910e-01,  2.6702e-01, -1.3076e+00,
         4.4607e-01,  1.2542e-01, -2.8018e+00,  1.4158e-01, -5.2851e-01,
        -3.3754e-01,  1.9284e-02,  1.0982e+00,  6.1203e-02,  2.5273e-01,
        -6.5044e-01, -1.6463e-01, -3.6856e-01,  7.3381e-02,  2.8394e-01,
         7.1713e-01, -4.5587e-01, -3.6499e-01, -3.1927e-01,  2.5299e-01,
         3.1554e-01, -3.9655e-01, -1.6431e-01, -3.2882e-01,  4.8321e-01,
         2.6438e-02,  3.2431e-01, -7.3670e-01,  3.7471e-01,  2.6125e-01,
         1.1955e-01, -2.1661e-01,  4.1328e-01,  1.6487e-01,  4.8600e-02,
        -1.4378e-02, -7.9637e-02, -4.6863e-01, -4.1781e-01,  1.7027e+00,
        -8.9309e-01,  1.2432e+00,  2.1844e-01,  5.7696e-02,  1.4689e-01,
         1.5076e-01,  5.9482e-02, -6.2378e-01, -2.9110e-01, -1.6801e-01,
        -3.6049e-01, -3.8028e-02, -2.2511e-01, -5.7139e-01, -4.8377e-01,
        -4.3101e-02,  1.8708e+02, -2.1507e-01,  1.9170e-01, -3.0552e-02,
         2.2827e-01,  1.0583e-01, -7.0139e-02, -2.7798e+00,  1.4771e-01,
         8.5176e-02, -2.1687e-01,  3.3596e-01,  5.2070e-01,  1.1699e-01,
         9.5625e-02,  7.2405e-02, -7.4354e-01, -4.1796e-01, -3.1036e-01,
         1.9051e-01,  1.3422e-01,  1.3123e-02,  6.9384e-02,  4.9266e-01,
        -3.0280e-01,  4.1873e-01,  3.8501e-01,  5.0008e-02, -4.1961e-01,
        -9.2081e-01, -8.8505e-01, -4.4598e-01,  3.2327e-01,  3.3958e-01,
         2.7451e-01,  1.5717e-02, -1.3682e-01, -1.1711e-01,  8.5470e-01,
         2.0400e-01, -3.2440e-03, -6.6013e-01, -9.5547e-01,  1.1471e-01,
        -2.7788e-01, -2.0668e-01,  2.4634e-01,  6.9303e-01, -4.9175e-01,
         5.9871e-01,  2.5664e-01, -3.5525e-02, -6.6084e-01,  3.5962e-03,
         7.7517e-01,  3.3645e-01,  4.1101e-01,  6.3882e-01, -1.1079e+00,
         3.5268e-01, -1.6977e-01, -3.6987e-02,  2.5442e-01,  3.5130e-01,
        -3.1887e-02, -4.1854e-01,  1.6195e-01, -1.2116e-01,  3.4556e-01,
        -3.1018e-01, -1.4006e-01, -2.2971e-02, -2.4656e-01, -1.3791e-02,
        -7.8256e-02,  1.3936e-01,  2.5287e-01,  1.2679e-01,  1.0444e-02,
        -7.2022e-01,  6.0788e-02,  1.3312e-01, -5.9214e-01, -2.5595e-02,
        -1.9300e-01,  2.0226e-01,  8.7149e-02, -2.8428e-01, -7.0516e-02,
         2.0947e-01, -2.4657e-01, -2.4485e-01, -9.9435e-02,  3.1029e-01,
        -3.1992e-02, -6.6761e-02, -6.7534e-01, -8.6753e-01, -6.0732e-02,
         1.5726e-01,  7.2841e-02, -1.3594e-01,  2.1770e-01, -1.9269e-01,
         7.0538e-02,  2.4697e-01, -2.1211e-01,  3.9387e-01, -1.1869e-01,
         3.3987e-01, -2.5016e-01,  3.6958e-02,  1.6831e+00,  1.3615e-02,
        -1.7321e-01,  5.6813e-01,  6.5152e-01, -1.0398e-01,  3.8231e-01,
         3.3899e-01,  8.4053e-01,  1.7321e-01,  2.9557e-01,  4.3555e-01,
        -3.8328e-01,  7.0625e-01,  1.4007e-01,  5.9916e-01,  3.2180e-01,
         2.8273e-01,  5.9825e-02,  9.4940e-01, -4.5199e-02,  2.8532e-01,
        -7.0141e-02,  5.2700e-02, -1.2425e-01, -4.9037e-02, -3.8386e-01,
        -6.8675e-01,  5.6606e-02,  1.1794e-01, -4.0015e-01,  1.5555e-01,
         4.9732e-01, -2.5473e-02,  6.1942e-01, -4.3914e-01,  3.6841e-01,
        -1.0106e+00,  5.7044e-01,  6.3400e-01,  1.8109e-01, -1.2689e+00,
         3.0568e-01, -2.4669e-01, -4.7047e-01, -6.9111e-01, -1.8980e-01,
         9.3604e-01, -2.0410e-01, -1.1872e-01, -2.7357e-02,  2.3875e-01,
         3.1092e-01,  1.1390e-02, -2.4784e-01, -7.8412e-03, -6.9789e-01,
         1.5111e-01, -2.7125e-01, -2.5387e-01, -1.8569e-01,  5.4281e-01,
        -1.2238e-01,  5.8334e-01, -1.0270e-01, -6.6581e-01,  5.0467e-02,
         2.2914e-01, -6.2707e-01,  3.1402e-01, -2.5495e-01,  2.1849e-01,
        -8.2420e-01,  1.0366e-01, -6.9238e-02, -1.6138e-01,  2.4286e-02,
         8.9909e-02, -2.2955e-02,  1.1647e-01,  3.7806e-01,  3.2999e-01,
         9.7268e-01,  2.0536e-01,  3.2675e-01,  1.2438e-01,  9.5249e-02,
         2.0198e-01, -1.1161e-01, -1.8637e-01, -5.8571e-02, -4.2029e-01,
        -7.9599e-02,  1.5202e-01,  2.0047e-01, -1.0811e+00, -6.1192e-01,
        -1.4036e-01,  5.7603e-01, -1.3407e-01, -1.2577e-01, -3.3293e-01,
        -1.4304e-01, -5.7476e-01, -3.0210e-01, -1.0654e-01,  1.8689e-01,
        -1.1735e-01,  2.6910e-02,  3.3609e-01, -3.8913e-01,  1.3305e-02,
        -9.7757e-01, -5.0708e-01, -3.9805e-01, -2.5544e-01,  5.2961e-01,
         7.5593e-01,  2.7601e-01, -1.0440e+00,  4.2858e-01, -3.3402e-01,
         4.9627e-01, -5.5644e-02,  2.6034e-01, -3.1860e-01,  1.9822e-01,
        -6.6199e-01,  5.5858e-01, -1.4615e-01, -1.3237e-01, -4.9860e-02,
         6.2007e-02, -1.5592e-02,  1.9362e-01, -3.9611e-01, -5.7654e-01,
        -1.5250e+00, -1.2473e-01,  1.1547e-01,  1.7411e-02, -4.1657e-01,
        -3.4960e-02,  3.0048e-01, -3.1938e-01,  4.4102e-01, -6.5498e-01,
         3.7197e-01, -1.3987e-01,  1.6034e-01, -6.5139e-01,  3.1342e-01,
        -2.8109e-01,  1.6483e-01,  1.8485e-01, -7.5067e-01,  1.2059e-01,
         2.5470e-01, -8.5622e-02,  4.0347e-02])

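Now let’s step through the residual stream: collect the final token’s hidden state at every layer and browse the layers with a slider. For GPT-2 small there are 13 hidden states — the embedding output plus one per transformer block.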

# Stack the final token's hidden state from every entry of hidden_states into a
# (num_states, hidden_size) array. Both sizes are read from the model output
# rather than hard-coded (13 / d_model), so this also works for other GPT-2 sizes.
n_states = len(outputs.hidden_states)
hidden_size = outputs.hidden_states[0].shape[-1]
residual_stream = np.zeros((n_states, hidden_size))

for layer, hidden in enumerate(outputs.hidden_states):
    # First sequence in the batch, last token position. .cpu() makes .numpy()
    # safe even if the model has been moved to a GPU.
    residual_stream_layer = hidden[0, -1, :].cpu().numpy()
    residual_stream[layer] = residual_stream_layer
    print(f"Layer {layer}: Residual stream shape {residual_stream_layer.shape}")
Layer 0: Residual stream shape (768,)
Layer 1: Residual stream shape (768,)
Layer 2: Residual stream shape (768,)
Layer 3: Residual stream shape (768,)
Layer 4: Residual stream shape (768,)
Layer 5: Residual stream shape (768,)
Layer 6: Residual stream shape (768,)
Layer 7: Residual stream shape (768,)
Layer 8: Residual stream shape (768,)
Layer 9: Residual stream shape (768,)
Layer 10: Residual stream shape (768,)
Layer 11: Residual stream shape (768,)
Layer 12: Residual stream shape (768,)
# Fold each layer's hidden vector into a 16-row grid; einops infers the column
# count from the vector length (and raises if it is not divisible by 16).
residual_stream = rearrange(
    residual_stream,
    "layer (row col) -> layer row col",
    row=16,
)

residual_stream.shape
(13, 16, 48)
from ipywidgets import interact, interactive, IntSlider, VBox, HBox, Output
import matplotlib.pyplot as plt

output = Output()


def plot_layer(layer=0, vmin=-3, vmax=3):
    """Redraw the residual-stream heatmap for one layer inside ``output``.

    Args:
        layer: Index into the first axis of ``residual_stream``.
        vmin, vmax: Colour-scale limits; values outside them are drawn in the
            end colours of the colormap. The defaults keep the original
            [-3, 3] range, which saturates on real activations with large
            outliers — widen them (e.g. to ``np.abs(residual_stream).max()``)
            to see those.
    """
    with output:
        # wait=True keeps the previous figure visible until the new one is
        # ready, avoiding flicker while the slider is dragged.
        output.clear_output(wait=True)
        fig, ax = plt.subplots(1, 1, figsize=(15, 5))
        ax.axis(False)
        fig.set_facecolor('white')
        ax.set_facecolor('white')

        ax.pcolormesh(
            residual_stream[layer],
            edgecolors='white',
            cmap='viridis',
            vmin=vmin,
            vmax=vmax,
        )
        ax.set_title(f'Layer {layer}', color='black', fontsize=16)
        plt.show()

# `display` is only an implicit builtin inside IPython/Jupyter; import it
# explicitly so this cell also works when the notebook is run as a script.
from IPython.display import display
from ipywidgets import Layout

# One slider position per stored layer; moving it redraws the heatmap.
slider = IntSlider(min=0, max=residual_stream.shape[0] - 1, step=1, value=0, description='Layer:')
slider.observe(lambda change: plot_layer(change['new']), names='value')

# Initial plot
plot_layer(0)

# Display with slider below and centered.
# NOTE(review): `background_color` does not appear to be a documented
# ipywidgets Layout property and may be silently ignored — verify.
display(
    VBox(
        [output, HBox([slider], layout=Layout(justify_content='center'))],
        layout=Layout(background_color='white'),
    )
)