import numpy as np

# Sizes matching the GPT-2 checkpoint loaded further down:
# a 768-wide residual stream and 12 transformer blocks.
d_model = 768
n_layers = 12

# Placeholder activations: one standard-normal d_model-vector per layer,
# used to prototype the visualisation before real model outputs exist.
residual_stream = np.random.randn(n_layers, d_model)
residual_stream.shape  # -> (12, 768)
Let’s use einops to fold each 768-dimensional vector into a 16 × 48 grid, which is much easier to visualize.
from einops import rearrange, reduce, repeat
# Fold each layer's flat vector into a (row, col) grid with 16 rows; einops
# infers col = d_model / 16 and fills the grid row-major.
residual_stream = rearrange(residual_stream, "layer (row col) -> layer row col", row=16)
residual_stream.shape  # -> (12, 16, 48)
Now we make the simplest possible plot of such a tensor: a heatmap of a single layer.
import matplotlib.pyplot as plt

# Quick look: render the first layer's grid as a heatmap.
layer_zero = residual_stream[0]
plt.imshow(layer_zero, cmap='viridis')
plt.title("Residual Stream at Layer 0")
plt.show()
We want to make this much prettier for the blog.
# Presentation version: black background, black cell edges, no axes.
fig, ax = plt.subplots(1, 1, figsize=(15, 5))
ax.axis(False)
fig.set_facecolor('black')
ax.set_facecolor('black')
# vmin/vmax pin the colour scale to [-3, 3] so every frame drawn into this
# mesh shares one scale; values outside that range saturate. Note that
# pcolormesh places row 0 at the bottom, unlike imshow.
cmesh = ax.pcolormesh(residual_stream[0], edgecolors='k', cmap='viridis', vmin=-3, vmax=3)

# Now we can think about animating this.
from matplotlib.animation import FuncAnimation


def anim_function(frame_num):
    """Swap layer ``frame_num`` of ``residual_stream`` into the existing mesh.

    Returns the updated artist wrapped in a 1-tuple, the shape
    FuncAnimation expects from its update callback.
    """
    layer_grid = residual_stream[frame_num, :, :]
    cmesh.set_array(layer_grid)
    return (cmesh,)


# One frame per layer, 30 ms apart.
frame_ids = np.arange(residual_stream.shape[0])
anim = FuncAnimation(fig, anim_function, frames=frame_ids, interval=30)
from IPython.display import HTML

# Render the animation as a self-contained JavaScript player in the notebook.
HTML(anim.to_jshtml())

# Now let's do something with actual model activations!
from huggingface_hub import login

# NOTE(review): login() prompts interactively; the public "gpt2" checkpoint
# should not need authentication, so this call can probably be dropped.
login()

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# Give the tokenizer a pad token (reusing EOS) so that `padding=True` in the
# tokenizer call works; GPT-2 defines none by default.
tokenizer.pad_token = tokenizer.eos_token

# Loader output:
# GPT2LMHeadModel LOAD REPORT from: gpt2
# Key | Status | |
# ---------------------+------------+--+-
# h.{0...11}.attn.bias | UNEXPECTED | |
# Notes:
# - UNEXPECTED :can be ignored when loading from different task/architecture; not ok if you expect identical arch.
# Tokenize one prompt into PyTorch tensors. With a single short sequence,
# padding and truncation change nothing, but they keep the call batch-ready.
inputs = tokenizer(
    "Paris is the capital of",
    padding=True,
    truncation=True,
    return_tensors="pt",
)
print(f"Tokenized {inputs['input_ids'].shape[0]} sequences with max length {inputs['input_ids'].shape[1]}.")
# Output: Tokenized 1 sequences with max length 5.
# Inference only: skip autograd bookkeeping, and ask the model to return the
# hidden states (embedding output plus the output of every block).
with torch.no_grad():
    outputs = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        output_hidden_states=True,
    )
type(outputs.hidden_states)  # -> tuple
outputs.hidden_states[12][0, -1, :]tensor([ 3.9707e-01, 6.5592e-01, -1.0789e-01, -5.1339e-01, 2.2478e-01,
-3.5340e-01, 2.9402e+00, 5.0266e-01, -1.1648e-01, -3.3925e-01,
-1.4641e-01, 1.3185e-01, 5.6856e-01, -2.4193e-01, 2.3383e-01,
-6.1204e-01, -1.6270e-01, 2.6729e-01, 1.0813e-01, -3.3550e+00,
3.8122e-01, 3.1480e-01, 3.6961e-01, -1.1843e-01, -4.1384e-02,
2.7716e-01, -6.5405e-01, -6.8845e-01, 1.7058e-01, 2.3255e-01,
2.1136e-02, -2.7776e-01, 1.7016e-01, 3.5261e-01, 6.2866e-01,
-2.2970e-01, 5.0996e+01, -3.4832e-01, -4.9283e-01, 1.6009e-01,
-2.5912e-01, -4.3954e-02, -6.6652e-02, 1.7371e-02, -4.9989e-02,
7.7331e-02, -2.0900e-01, -5.2491e-01, -9.1295e-01, -7.5307e-02,
-5.9557e-02, -3.5249e-01, 9.1986e-02, -4.4967e-02, 1.7474e-01,
1.0622e+00, -3.1931e-02, -2.3144e-01, 1.4521e-01, -4.0620e-01,
-5.8457e-01, 1.4042e-02, 1.0694e-01, -1.2151e-01, -1.1069e+00,
4.4808e-01, -1.6879e-01, 4.8419e-02, -1.2701e+00, 1.9134e-01,
-7.3997e-01, -6.9388e-01, 9.3711e-03, -3.0020e-01, -2.9383e-02,
1.1450e-01, 2.2709e-01, -6.2942e-01, 3.4850e-01, 4.6560e-01,
2.9459e-01, -2.3642e-01, 4.6218e-01, -7.3631e-02, 4.3862e-01,
1.2286e-01, 9.8999e-01, -1.0245e+00, 2.4983e-01, -1.2601e-01,
4.3485e-01, 6.3421e-01, 3.8452e-01, 9.8958e-03, -6.1150e-02,
-2.8366e-01, -3.5584e-01, -3.4806e-01, 4.6458e-02, 1.0344e-01,
-1.7640e-01, 2.6008e-01, 3.0842e-01, -5.9029e-01, 6.2168e-02,
8.0234e-02, -2.1006e-01, -1.3377e+00, 4.7618e-01, -2.3790e-01,
5.6921e-01, -1.8896e-01, 8.6591e-02, 1.3231e-02, 1.2710e-01,
-5.3515e-01, 3.2699e-01, -2.2347e-01, 2.8715e-01, -1.1292e+00,
7.7704e-02, 5.1702e-01, 2.5250e-02, 8.0246e-01, -2.2748e-01,
4.8657e-01, 4.7917e-01, 5.0834e-01, -1.8833e-01, 5.0228e-01,
-3.7779e-01, -3.9931e-01, 2.4244e-01, 2.1779e-01, -1.0248e-01,
4.3215e-01, 1.9822e-01, -4.4591e-01, 5.6154e-01, 3.0105e-01,
-4.6189e-01, -8.0389e-01, 9.8432e-01, -3.1944e-01, 9.4833e-02,
1.1400e-01, -1.5258e-01, 8.6682e-02, -5.9506e-03, 1.0061e+00,
3.5015e-01, -5.4344e-01, -1.4929e-02, 2.9061e-01, 7.7716e-01,
5.7593e-02, 4.1433e-01, 2.0678e-01, 1.1295e-01, -3.3525e-01,
-6.4733e-01, -2.0853e-01, -6.8152e-01, -2.3764e-01, -7.2244e-03,
5.3895e-02, -1.7540e-02, -3.5193e+00, -2.7689e-01, -1.4060e-01,
2.2860e-01, -2.3910e-01, 1.4758e-01, 4.6928e-01, -4.1897e-01,
-8.1610e-01, 6.8833e-01, -5.3819e-01, -5.4749e-02, -1.4693e-01,
1.5520e-01, 3.4751e-01, -1.0811e-02, -1.1496e-01, -3.7802e-01,
-5.4832e-02, -1.7603e-01, -1.5183e-01, -3.5338e-03, -1.5314e-02,
-1.0327e-01, -4.9390e-01, -9.5211e-03, 2.2363e-01, -8.6061e-02,
-1.6430e+00, 4.8663e-01, 2.2926e-01, -1.5008e-01, 4.6846e-03,
1.8621e-01, 5.8974e-01, -2.3210e-01, -3.8068e-02, 5.2106e-02,
-2.1625e-01, 1.3113e-01, -1.6528e-01, -3.0899e-01, -8.6106e-02,
1.1730e-01, -2.2622e-01, -3.6722e-01, -1.5751e-01, -3.6578e-02,
-3.0102e-01, -4.9017e-02, -4.6044e-01, 4.7928e-01, 3.5405e-02,
9.5526e-02, -1.9290e-01, 3.5486e-01, 3.5565e-01, -4.3436e-01,
-3.7214e-01, 3.7490e-01, -1.4289e-01, 1.4892e-01, 4.4548e-01,
-6.1742e-02, 3.3417e-01, 1.5622e+00, -1.6550e-01, -9.5622e-02,
-3.1609e-01, 3.5710e-01, 1.5547e-01, 3.7913e-01, -4.5316e-01,
-3.0411e-01, 6.3645e-01, 4.6493e-01, -2.7021e-01, 2.8007e-01,
3.5020e-01, 4.3037e-01, -2.0467e-01, -1.7361e-01, 5.0707e-02,
5.2277e-01, -1.6470e-01, 2.9059e-01, 9.7539e-02, -8.1670e-02,
3.0347e-01, 3.0409e-01, -1.9271e-01, -3.8301e-01, 5.5685e-01,
-2.9989e-01, 7.5335e-01, -1.9442e-01, -8.1730e-01, -5.7642e-02,
4.5157e-01, -1.0122e+00, 3.4106e-02, 3.7177e-01, 4.8461e-01,
1.0395e+00, 3.5976e-01, 1.9491e-01, -1.8765e-01, 9.0858e-02,
-1.2970e-01, -6.9582e-02, -3.3006e-01, 1.2637e-01, -1.4938e-01,
1.7854e-01, -3.3438e-02, 2.9249e-01, -6.0153e-02, -3.9808e-01,
-4.7156e-02, 1.6853e-01, -2.4003e-02, -9.1483e-02, 3.4247e-01,
2.2795e-01, 3.3932e-01, -3.3874e-01, -5.3320e-02, -1.2861e-02,
-1.2411e-02, 3.7387e-02, 5.8254e-02, -2.5397e-01, 5.2348e-01,
8.6658e-01, -3.5760e-01, 1.3084e-01, -1.1010e-01, 5.1624e-01,
6.0349e-01, -4.8093e-01, 1.6971e-01, 1.2630e+00, -2.6303e-01,
1.8936e-01, 2.1825e-01, 2.9491e-02, -3.4005e-01, -4.3871e+01,
4.7595e-02, -1.4228e-01, -1.0846e-01, -1.5160e-01, 1.0329e-01,
3.7053e-01, -1.2320e-02, 3.1999e-01, 5.6906e-02, 3.0835e-01,
2.0784e-01, -3.3139e-01, -4.5700e-01, -1.5465e-01, 1.6962e-01,
-3.5235e-01, -7.5174e-03, 4.5752e-01, 4.0825e-01, -1.1673e-01,
-1.0488e-02, 2.5212e-02, -1.6757e-01, 2.3002e-01, -1.8673e-02,
-1.7697e-01, -2.1341e-01, -6.6391e-01, -1.8274e-01, 7.0124e-01,
2.5572e-01, 1.5160e-01, -2.0538e-01, 1.8113e-01, -4.6497e-02,
-2.5627e-01, -4.6529e-01, 1.9979e-01, -1.2916e-01, -3.4473e-01,
-4.8907e-01, -3.6988e-01, 1.0921e-01, -1.2943e-01, -1.2346e-01,
8.8096e-02, 1.2546e+00, 6.3094e+00, -2.5674e-01, -4.3520e-01,
-2.3859e+00, 4.8746e-01, -3.9125e-02, -4.1557e-02, -2.7922e-01,
8.7475e-02, 2.1376e-01, -6.5235e-02, -1.2277e+00, -2.0121e+01,
4.7794e-01, 3.3857e-01, 1.3460e-01, 4.3976e-01, 1.7539e-01,
2.3180e-01, -3.4298e-01, 2.6690e-01, -3.0966e-01, -8.0459e-01,
1.0486e-01, 2.3068e-01, -2.7034e-01, -7.2113e-02, 1.2232e+00,
-3.3541e-02, 1.2416e-01, -5.4055e-02, 3.2458e-01, -2.6819e-01,
-2.7736e-01, -1.0285e-01, -3.2138e-01, 4.5720e-01, 1.7563e-01,
-5.7128e-02, 1.5178e-01, -4.0476e-01, 5.8006e-02, 3.9195e-02,
-2.5981e-01, 2.0811e-01, -2.8853e-01, 6.6633e+00, -5.7252e-01,
-5.5131e-02, 7.3943e-02, 3.7546e-01, 5.8285e-02, -6.3121e-01,
-5.0900e-01, -5.1948e-02, -1.1226e-01, -1.6612e-01, -1.6518e-01,
4.8815e-01, 1.6613e-01, -4.5851e-01, -5.6316e-01, -3.2561e-02,
-1.4305e-01, -2.9152e-01, -6.2530e-01, -1.1286e-01, 6.0920e-03,
7.0946e+01, 4.8999e-01, 1.3975e-01, -1.2525e-01, -1.5065e-01,
1.6868e-01, -2.7682e-02, -4.1910e-01, 2.6702e-01, -1.3076e+00,
4.4607e-01, 1.2542e-01, -2.8018e+00, 1.4158e-01, -5.2851e-01,
-3.3754e-01, 1.9284e-02, 1.0982e+00, 6.1203e-02, 2.5273e-01,
-6.5044e-01, -1.6463e-01, -3.6856e-01, 7.3381e-02, 2.8394e-01,
7.1713e-01, -4.5587e-01, -3.6499e-01, -3.1927e-01, 2.5299e-01,
3.1554e-01, -3.9655e-01, -1.6431e-01, -3.2882e-01, 4.8321e-01,
2.6438e-02, 3.2431e-01, -7.3670e-01, 3.7471e-01, 2.6125e-01,
1.1955e-01, -2.1661e-01, 4.1328e-01, 1.6487e-01, 4.8600e-02,
-1.4378e-02, -7.9637e-02, -4.6863e-01, -4.1781e-01, 1.7027e+00,
-8.9309e-01, 1.2432e+00, 2.1844e-01, 5.7696e-02, 1.4689e-01,
1.5076e-01, 5.9482e-02, -6.2378e-01, -2.9110e-01, -1.6801e-01,
-3.6049e-01, -3.8028e-02, -2.2511e-01, -5.7139e-01, -4.8377e-01,
-4.3101e-02, 1.8708e+02, -2.1507e-01, 1.9170e-01, -3.0552e-02,
2.2827e-01, 1.0583e-01, -7.0139e-02, -2.7798e+00, 1.4771e-01,
8.5176e-02, -2.1687e-01, 3.3596e-01, 5.2070e-01, 1.1699e-01,
9.5625e-02, 7.2405e-02, -7.4354e-01, -4.1796e-01, -3.1036e-01,
1.9051e-01, 1.3422e-01, 1.3123e-02, 6.9384e-02, 4.9266e-01,
-3.0280e-01, 4.1873e-01, 3.8501e-01, 5.0008e-02, -4.1961e-01,
-9.2081e-01, -8.8505e-01, -4.4598e-01, 3.2327e-01, 3.3958e-01,
2.7451e-01, 1.5717e-02, -1.3682e-01, -1.1711e-01, 8.5470e-01,
2.0400e-01, -3.2440e-03, -6.6013e-01, -9.5547e-01, 1.1471e-01,
-2.7788e-01, -2.0668e-01, 2.4634e-01, 6.9303e-01, -4.9175e-01,
5.9871e-01, 2.5664e-01, -3.5525e-02, -6.6084e-01, 3.5962e-03,
7.7517e-01, 3.3645e-01, 4.1101e-01, 6.3882e-01, -1.1079e+00,
3.5268e-01, -1.6977e-01, -3.6987e-02, 2.5442e-01, 3.5130e-01,
-3.1887e-02, -4.1854e-01, 1.6195e-01, -1.2116e-01, 3.4556e-01,
-3.1018e-01, -1.4006e-01, -2.2971e-02, -2.4656e-01, -1.3791e-02,
-7.8256e-02, 1.3936e-01, 2.5287e-01, 1.2679e-01, 1.0444e-02,
-7.2022e-01, 6.0788e-02, 1.3312e-01, -5.9214e-01, -2.5595e-02,
-1.9300e-01, 2.0226e-01, 8.7149e-02, -2.8428e-01, -7.0516e-02,
2.0947e-01, -2.4657e-01, -2.4485e-01, -9.9435e-02, 3.1029e-01,
-3.1992e-02, -6.6761e-02, -6.7534e-01, -8.6753e-01, -6.0732e-02,
1.5726e-01, 7.2841e-02, -1.3594e-01, 2.1770e-01, -1.9269e-01,
7.0538e-02, 2.4697e-01, -2.1211e-01, 3.9387e-01, -1.1869e-01,
3.3987e-01, -2.5016e-01, 3.6958e-02, 1.6831e+00, 1.3615e-02,
-1.7321e-01, 5.6813e-01, 6.5152e-01, -1.0398e-01, 3.8231e-01,
3.3899e-01, 8.4053e-01, 1.7321e-01, 2.9557e-01, 4.3555e-01,
-3.8328e-01, 7.0625e-01, 1.4007e-01, 5.9916e-01, 3.2180e-01,
2.8273e-01, 5.9825e-02, 9.4940e-01, -4.5199e-02, 2.8532e-01,
-7.0141e-02, 5.2700e-02, -1.2425e-01, -4.9037e-02, -3.8386e-01,
-6.8675e-01, 5.6606e-02, 1.1794e-01, -4.0015e-01, 1.5555e-01,
4.9732e-01, -2.5473e-02, 6.1942e-01, -4.3914e-01, 3.6841e-01,
-1.0106e+00, 5.7044e-01, 6.3400e-01, 1.8109e-01, -1.2689e+00,
3.0568e-01, -2.4669e-01, -4.7047e-01, -6.9111e-01, -1.8980e-01,
9.3604e-01, -2.0410e-01, -1.1872e-01, -2.7357e-02, 2.3875e-01,
3.1092e-01, 1.1390e-02, -2.4784e-01, -7.8412e-03, -6.9789e-01,
1.5111e-01, -2.7125e-01, -2.5387e-01, -1.8569e-01, 5.4281e-01,
-1.2238e-01, 5.8334e-01, -1.0270e-01, -6.6581e-01, 5.0467e-02,
2.2914e-01, -6.2707e-01, 3.1402e-01, -2.5495e-01, 2.1849e-01,
-8.2420e-01, 1.0366e-01, -6.9238e-02, -1.6138e-01, 2.4286e-02,
8.9909e-02, -2.2955e-02, 1.1647e-01, 3.7806e-01, 3.2999e-01,
9.7268e-01, 2.0536e-01, 3.2675e-01, 1.2438e-01, 9.5249e-02,
2.0198e-01, -1.1161e-01, -1.8637e-01, -5.8571e-02, -4.2029e-01,
-7.9599e-02, 1.5202e-01, 2.0047e-01, -1.0811e+00, -6.1192e-01,
-1.4036e-01, 5.7603e-01, -1.3407e-01, -1.2577e-01, -3.3293e-01,
-1.4304e-01, -5.7476e-01, -3.0210e-01, -1.0654e-01, 1.8689e-01,
-1.1735e-01, 2.6910e-02, 3.3609e-01, -3.8913e-01, 1.3305e-02,
-9.7757e-01, -5.0708e-01, -3.9805e-01, -2.5544e-01, 5.2961e-01,
7.5593e-01, 2.7601e-01, -1.0440e+00, 4.2858e-01, -3.3402e-01,
4.9627e-01, -5.5644e-02, 2.6034e-01, -3.1860e-01, 1.9822e-01,
-6.6199e-01, 5.5858e-01, -1.4615e-01, -1.3237e-01, -4.9860e-02,
6.2007e-02, -1.5592e-02, 1.9362e-01, -3.9611e-01, -5.7654e-01,
-1.5250e+00, -1.2473e-01, 1.1547e-01, 1.7411e-02, -4.1657e-01,
-3.4960e-02, 3.0048e-01, -3.1938e-01, 4.4102e-01, -6.5498e-01,
3.7197e-01, -1.3987e-01, 1.6034e-01, -6.5139e-01, 3.1342e-01,
-2.8109e-01, 1.6483e-01, 1.8485e-01, -7.5067e-01, 1.2059e-01,
2.5470e-01, -8.5622e-02, 4.0347e-02])
Now we step through the residual stream of the final token, one layer at a time.
# Stack the last token's vector from every entry of hidden_states into a
# (num_states, hidden_size) array. The number of states is taken from the
# model output instead of being hard-coded.
# NOTE(review): in Hugging Face's GPT-2 the final entry of hidden_states is
# recorded after the final LayerNorm (ln_f), so the last "layer" is not the
# raw residual stream — confirm against the installed transformers version.
residual_stream = np.stack(
    [state[0, -1, :].numpy() for state in outputs.hidden_states]
)
for layer, layer_vector in enumerate(residual_stream):
    print(f"Layer {layer}: Residual stream shape {layer_vector.shape}")
# Output:
# Layer 0: Residual stream shape (768,)
# Layer 1: Residual stream shape (768,)
# Layer 2: Residual stream shape (768,)
# Layer 3: Residual stream shape (768,)
# Layer 4: Residual stream shape (768,)
# Layer 5: Residual stream shape (768,)
# Layer 6: Residual stream shape (768,)
# Layer 7: Residual stream shape (768,)
# Layer 8: Residual stream shape (768,)
# Layer 9: Residual stream shape (768,)
# Layer 10: Residual stream shape (768,)
# Layer 11: Residual stream shape (768,)
# Layer 12: Residual stream shape (768,)
# Fold each layer's vector into a 16-row grid (einops infers the column
# count from the vector length).
residual_stream = rearrange(residual_stream, "layer (row col) -> layer row col", row=16)
residual_stream.shape  # -> (13, 16, 48)
from ipywidgets import interact, interactive, IntSlider, VBox, HBox, Output, Layout
from IPython.display import display
import matplotlib.pyplot as plt

# Widget that hosts the figure so it can be redrawn in place on each change.
output = Output()


def plot_layer(layer=0):
    """Redraw ``residual_stream[layer]`` as a mesh inside the ``output`` widget.

    ``clear_output(wait=True)`` defers clearing until the new figure is
    ready, which avoids flicker while dragging the slider. The colour range
    is fixed to [-3, 3]; real activations can be far larger and then
    saturate the colour map.
    """
    with output:
        output.clear_output(wait=True)
        fig, ax = plt.subplots(1, 1, figsize=(15, 5))
        ax.axis(False)
        fig.set_facecolor('white')
        ax.set_facecolor('white')
        ax.pcolormesh(residual_stream[layer], edgecolors='white', cmap='viridis', vmin=-3, vmax=3)
        ax.set_title(f'Layer {layer}', color='black', fontsize=16)
        plt.show()


slider = IntSlider(min=0, max=residual_stream.shape[0] - 1, step=1, value=0, description='Layer:')
# Redraw whenever the slider's value changes.
slider.observe(lambda change: plot_layer(change['new']), names='value')

# Initial plot
plot_layer(0)

# Display with the slider below the figure, centred. ipywidgets' Layout has
# no background-colour trait, so no background is set on the outer box.
display(VBox([output, HBox([slider], layout=Layout(justify_content='center'))]))