import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
import gc

# Free memory left over from earlier runs: first the cached Metal (MPS)
# allocations when that backend is present, then a full Python GC pass.
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

gc.collect()
13575
# Prefer Apple-silicon acceleration when available, otherwise run on CPU.
model_name = "microsoft/phi-3-mini-4k-instruct"
device = "mps" if torch.backends.mps.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to(device)

# Display the module tree in the notebook
model
Phi3ForCausalLM(
  (model): Phi3Model(
    (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
    (layers): ModuleList(
      (0-31): 32 x Phi3DecoderLayer(
        (self_attn): Phi3Attention(
          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (qkv_proj): Linear(in_features=3072, out_features=9216, bias=False)
        )
        (mlp): Phi3MLP(
          (gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False)
          (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
          (activation_fn): SiLUActivation()
        )
        (input_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
        (post_attention_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
        (resid_attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_mlp_dropout): Dropout(p=0.0, inplace=False)
      )
    )
    (norm): Phi3RMSNorm((3072,), eps=1e-05)
    (rotary_emb): Phi3RotaryEmbedding()
  )
  (lm_head): Linear(in_features=3072, out_features=32064, bias=False)
)
def decode_token_with_logitlens(model, device, tokenizer, input, tokens_to_gen=None):
    '''
    Apply the "logit lens": project every hidden state through ``model.lm_head``
    and record the top-1 token (and its probability) at each layer and position.

    Args:
        model: causal LM that exposes ``lm_head`` and ``generate`` and accepts
            ``output_hidden_states=True`` (e.g. an HF ``AutoModelForCausalLM``).
        device: device the tokenized prompt is moved to.
        tokenizer: tokenizer matching ``model``.
        input: prompt text.
        tokens_to_gen: if truthy, first generate this many new tokens
            (near-greedy sampling) and analyse prompt + generated tokens together.
            ``None`` or ``0`` analyses the prompt only.

    Returns:
        dict with
          'text_tokens':    [seq_len] decoded string of every analysed token,
          'decoded_tokens': [num_hidden_states][seq_len] top-1 token strings,
          'decoded_logits': [num_hidden_states][seq_len] softmax probability of
                            that top-1 token (a probability, not a raw logit).
        ``num_hidden_states`` is however many hidden states the model returns.
        Only the first batch element is analysed.

    NOTE(review): hidden states are fed to ``lm_head`` as-is, without the
    model's final norm; confirm whether that is the intended lens variant.
    '''
    inputs = tokenizer(input, return_tensors="pt").to(device)

    # Inference only: keep generate *and* the full forward pass under no_grad so
    # activations for every layer are not retained for a backward pass.
    with torch.no_grad():
        if tokens_to_gen:
            output = model.generate(
                inputs['input_ids'],
                attention_mask=inputs.get('attention_mask'),
                do_sample=True,
                top_p=0.95,
                temperature=0.001,  # effectively greedy while using the sampling path
                top_k=0,
                max_new_tokens=tokens_to_gen,
            )
            # Analyse exactly the ids the model produced. Decoding the new ids and
            # re-tokenizing the text can change token boundaries, and slicing the
            # last `tokens_to_gen` ids grabs prompt tokens if generation stopped early.
            inputs = {'input_ids': output, 'attention_mask': torch.ones_like(output)}

        text_tokens = [tokenizer.decode(tok_id) for tok_id in inputs['input_ids'][0].tolist()]

        hidden_states = model(**inputs, output_hidden_states=True).hidden_states

        decoded_tokens = []  # [num_hidden_states][seq_len]
        decoded_probs = []   # [num_hidden_states][seq_len]
        for hidden_state in hidden_states:
            # [seq_len, vocab] distribution for the first batch element
            probs = torch.nn.functional.softmax(model.lm_head(hidden_state[0]), dim=-1)
            top_probs, top_ids = probs.max(dim=-1)
            decoded_tokens.append([tokenizer.decode(tok_id) for tok_id in top_ids.tolist()])
            # one host transfer per layer instead of one .item() sync per position
            decoded_probs.append(top_probs.float().tolist())

    return {'text_tokens': text_tokens, 'decoded_tokens': decoded_tokens, 'decoded_logits': decoded_probs}
# Few-shot "food: taste" prompt in French; the model completes the last line.
text = "\n".join(["Citron: Acide", "Sucre: Sucré", "Piment: Épicé", "Miel:"])

text_tokenized = tokenizer.encode(text, return_tensors="pt")
text_tokenized = text_tokenized.to(device)

# Beam-sample decoding (4 beams with sampling), at most 3 new tokens
output = model.generate(
    text_tokenized,
    max_new_tokens=3,
    num_beams=4,
    do_sample=True,
)

text_decoded = tokenizer.decode(output[0])
print(text_decoded)
Citron: Acide
Sucre: Sucré
Piment: Épicé
Miel: Doux
# Run the logit lens on the prompt plus two generated tokens
dict_output = decode_token_with_logitlens(model, device, tokenizer, text, tokens_to_gen=2)
decoded_tokens, decoded_logits, text_tokens = (
    dict_output[key] for key in ('decoded_tokens', 'decoded_logits', 'text_tokens')
)
print(len(text_tokens), text_tokens)
25 ['Cit', 'ron', ':', 'A', 'cide', '\n', 'Su', 'cre', ':', 'S', 'uc', 'ré', '\n', 'P', 'iment', ':', 'É', 'pic', 'é', '\n', 'M', 'iel', ':', 'D', 'oux']
# One printed row per hidden state: the top-1 token at every position
for layer_tokens in decoded_tokens:
    print(layer_tokens)
['izen', 'unci', 'Screen', 'bove', 'zat', 'Datos', 'ites', 'opening', 'Screen', 'orted', 'Stra', 'imas', 'Datos', 'rior', 'Db', 'Screen', 'cze', 'éd', 'parc', 'Datos', 'arch', 'merge', 'Screen', 'enis', 'Terr']
['míst', '☺', 'míst', 'míst', '🌍', 'míst', 'becom', 'míst', 'míst', '↵', '\x97', '̶', 'míst', '↵', 'hbox', 'míst', '̶', '\x97', '☺', 'míst', '↵', 'Ê', 'míst', '♯', 'Č']
['míst', '☺', 'míst', 'míst', '', 'míst', 'cí', 'míst', 'míst', 'míst', '›', '̶', 'míst', 'míst', 'ão', 'míst', '̶', '�', '☺', 'míst', 'cí', 'Vé', '─', '♯', 'Č']
['option', 'Niem', '\x97', 'mí', 'has', '\x97', '马', 'pick', 'infatti', '⇔', '›', 'men', '\x97', 'OL', 'wind', '\x97', '⇔', 'Sver', 'living', '\x97', '↵', 'Vé', '\x97', 'Martí', 'Č']
['option', 'ouc', 'highly', 'ero', 'practices', '\x97', 'ivi', 'int', 'infatti', 'míst', 'lande', 'ợ', '', 'OL', 'al', '\x97', '⟶', 'nic', 'surr', '\x97', '⇔', 'rig', '\x97', 'LT', 'fil']
['option', 'r', 'Fich', 'mí', 'ity', '', '⟶', 'at', 'Fich', 'echter', 'uro', 'exclus', '', 'urg', 'tab', 'Fon', '🌍', 'nick', 'bean', '', 'cí', 'rig', 'Fon', 'Martí', 'fi']
['option', 'ville', 'alberga', 'dynamically', 'ogen', '', '♦', 'abet', 'ば', 'míst', '☺', 'ones', '', 'cí', 'anu', '⇔', 'Err', 'urg', 'itos', '', 'cí', 'rig', 'prima', 'Martí', 'brow']
['option', '\u2009', '⇔', 'ero', 'gra', '', 'pon', 'LES', 'cogn', '⇔', '☺', 'ば', '', 'cí', 'ão', '⇔', 'Err', 'ere', '➖', '', 'penas', 'pan', 'prima', '\x97', 'brow']
['option', '\u2009', '⇔', 'dynamically', 'ity', '✔', 'cre', 'ville', 'prima', 'OL', 'loc', 'emer', '', 'urg', 'on', 'prima', 'cot', 'ere', 'acon', '', 'urg', 'de', 'prima', '↔', '-']
['option', 'isch', '{\r', 'oret', 'de', '↔', 'cre', 'de', 'prima', 'urg', 'loc', 'taste', 'ext', 'OL', 'on', 'cons', 'aster', 'ere', 'ol', 'ext', 'urg', 'de', 'cons', 'our', '-']
['option', 'soft', '{\r', 'war', 'de', '✅', 'pon', 'de', 'pre', 'urg', 'loc', ']],', 'odor', 'FA', 'on', 'pre', 'cot', 'ere', 'flav', 'soft', 'urg', 'de', 'prima', '↔', '-']
['option', 'iar', '{\r', 'результате', 'de', 'con', 'oc', 'de', 'pre', 'urg', 'ita', 'exc', 'soft', 'ü', 'on', 'pre', 'oc', 'результате', 'ien', 'soft', 'urg', 'de', 'op', 'ign', '-']
['option', 'iar', '{\r', 'war', 'de', '});\r', 'qu', 'de', 'pre', 'eng', 'е', 'sens', 'soft', 'ü', 'on', 'pre', 'ol', 'ere', 'ien', 'soft', 'ü', 'eli', 'prom', 'lam', '-']
['option', 'iar', '{\r', 'ali', 'ity', '});\r', 'con', 'de', 'prom', 'eng', 'ill', 'Tout', 'soft', 'ü', 'on', 'rig', 'lam', 'ere', 'flav', 'etc', 'adr', 'eli', 'prom', 'riv', '-']
['option', 'ica', '{\r', 'ali', 'de', '{\r', 'im', 'de', 'li', 'urg', 'ub', 'de', 'adr', 'urg', 'on', 'sens', 'lam', 'ere', '/', 'etc', 'urg', 'eli', 'sens', 'ill', '-']
['option', 'ica', '{\r', 'ali', 'ity', '{\r', 'ver', 'de', 'li', 'urg', 'ub', 'ais', '\ufeff', 'urg', 'on', 'sens', 'vol', 'ere', 'option', 'etc', 'ast', 'loc', 'war', 'urg', '-']
['option', 'ica', '{\r', 'myster', 'ity', '{\r', 'ft', 'de', 'ase', 'urg', 'amer', ']]', '{`', 'urg', 'on', 'sens', 'lam', 'ere', 'option', 'etc', 'ast', 'loc', 'sens', 'riv', '-']
['option', 'ica', '{\r', 'lem', 'de', '});\r', 'pect', 'de', 'sens', 'urg', 'ub', '\ufeff', 'etc', 'urg', 'on', 'sens', 'vol', 'ere', '/', 'etc', 'ast', 'loc', 'sens', 'riv', '-']
['option', 'ic', '{\r', 'lem', 'de', '});\r', 'ft', 'de', 'sens', 'ill', 'ub', '"}', '{`', 'urg', 'al', 'sens', 'vol', 'enter', 'taste', 'etc', 'ast', 'loc', 'sens', 'riv', 'soft']
['option', 'et', '{\r', 'modern', 'de', '<|endoftext|>', 'il', 'de', '', 'our', 'ub', '/', '{`', 'urg', 'al', 'sens', 'vol', 'enter', '/', '<|endoftext|>', 'ast', 's', 'sens', 'urg', '-']
['option', 'et', '{\r', 'compreh', 'de', '          ', 'il', 'de', 'b', 'our', 'ub', '\n', 'etc', 'urg', 'al', 'sens', 'vol', 'enter', '/', 'etc', 'our', ':', 'sens', 'il', '-']
['option', 'et', '{\r', 'compreh', 'de', '{\r', 'ven', 'de', '', 'our', 'our', '\n', 'etc', 'eg', 'al', 'sens', 'vol', 'enter', '<|endoftext|>', 'etc', 'ur', ':', 'de', 'ur', '\n']
['option', 'et', '{\r', 'Sidenote', 'de', '\t', 'ffic', 'de', '', 'uc', 'ub', '\n', 't', 'eg', 'al', 's', 'vol', 'enter', '\n', '<|endoftext|>', 'ast', ':', 'sweet', 'iffer', '\n']
['option', 'et', '\r', 'Sidenote', 'de', '✅', 'plement', ':', '', 'que', 'urs', '\n', 't', 'eg', 'o', 'd', 'vol', 'enter', '/', 'le', 'ast', ':', 'sweet', 'iffer', '\n']
['option', 'et', '"', 'Sidenote', 'ac', '✅', 'plement', ':', '', 'uc', 'rose', '\n', '', 'ast', 'ary', 'bitter', 'vol', 'enter', '\n', '<|endoftext|>', 'ast', ':', 'sweet', 'ivid', '\n']
['option', 'et', '"', 'Sidenote', 'le', '', 'plement', ':', 'sugar', 'weet', 'rose', '\n', 'gr', 'ast', 'de', 'bitter', 'vol', 'enter', '\n', 'c', 'our', ':', 'sweet', 'ivid', '\n']
['option', 'et', '', 'Sidenote', 'ac', '', 'plement', ':', 'Sug', 'weet', 'rose', '\n', 'c', 'ig', ':', 'hot', 'vol', 'enter', '\n', 'c', 'ature', ':', 'sweet', 'ivid', '\n']
['option', 'et', 'Sidenote', 'Sidenote', 'ac', '', 'plement', ':', 'sol', 'weet', 'rose', '\n', '', 'iment', ':', 'sp', 'colog', 'enter', '\n', 'D', 'astic', ':', 'sweet', 'ul', '\n']
['option', 'et', 'Sidenote', 'Sidenote', 's', '', 'plement', ':', '', 'uc', 'rose', '\n', '', 'iment', ':', 'sp', 'per', 'et', '\n', 'po', 'astic', ':', 'swe', 'ul', '\n']
['option', 'et', '', 'Sidenote', '', '', 'cc', ':', '', 'atur', 'id', '\n', '', 'iment', ':', '', 'per', 'ur', '\n', '', 'int', ':', 'D', 'ess', '\n']
['Ł', 'et', '', 'Sidenote', '', '\n', 'cre', ':', '', 'uc', 'rose', '\n', '', 'iment', ':', '', 'per', 'ur', '\n', 'A', 'ent', ':', 'D', 'u', '\n']
['\u202f', 'i', '', '', '', '\n', 'it', ':', '', 'uc', 'rose', '\n', '\n', 'ain', ':', 'A', 'p', 'ur', '\n', '\n', 'ent', ':', 'D', 'u', '\n']
['izen', 'ella', '\n', 'Cit', 'cit', '\n', 'ivi', ':', 'Gl', 'uc', 'ré', '\n', '\n', 'om', ':', 'A', 'pic', 'é', '\n', 'C', 'iel', ':', 'D', 'oux', '\n']
# Find where "Miel: Doux" tokens start
search_text = "Miel: Doux"
full_text = ''.join(text_tokens)

# A per-token substring test cannot work: "Miel" is split across tokens
# ('M', 'iel'), and the per-token decodes printed above carry no leading
# spaces, so the joined text reads "...Miel:Doux". Match the space-free
# search text against the space-free joined text instead, then map the
# character offset back to the token that contains it.
compact_tokens = [token.replace(" ", "") for token in text_tokens]
match_pos = ''.join(compact_tokens).find(search_text.replace(" ", ""))

miel_start_idx = None
if match_pos != -1:
    chars_seen = 0
    for i, token in enumerate(compact_tokens):
        chars_seen += len(token)
        # first token whose span reaches past the match start contains it
        if chars_seen > match_pos:
            miel_start_idx = i
            break

# Print to see what we found
print(f"Full text tokens: {text_tokens}")
print(f"Miel starts at token index: {miel_start_idx}")

# Visualise up to 5 tokens starting at the match (end index is exclusive)
if miel_start_idx is not None:
    to_viz = (miel_start_idx, min(miel_start_idx + 5, len(text_tokens)))
else:
    # Fallback: show last few tokens if we can't find the search text
    to_viz = (max(0, len(text_tokens) - 5), len(text_tokens))
Full text tokens: ['Cit', 'ron', ':', 'A', 'cide', '\n', 'Su', 'cre', ':', 'S', 'uc', 'ré', '\n', 'P', 'iment', ':', 'É', 'pic', 'é', '\n', 'M', 'iel', ':', 'D', 'oux']
Miel starts at token index: None
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def plot_pretty_logit_lens(tokens_viz, logits_viz, text_tokens, start_idx, end_idx, start_layer=0,
                           save_path="images/logit_lens_viz.svg", cbar_label="Top-1 Token Probability"):
    """
    Draw a logit-lens heatmap: one row per layer, one column per token position.

    Cell colour comes from ``logits_viz`` (in this notebook, the softmax
    probability of the top-1 token) and cell text from ``tokens_viz``.

    Args:
        tokens_viz: per-layer lists of decoded top-1 tokens, already sliced to
            positions ``start_idx:end_idx``.
        logits_viz: array-like with the same [num_layers, end_idx - start_idx]
            shape as ``tokens_viz``.
        text_tokens: decoded input tokens, used for the column labels.
        start_idx, end_idx: the position slice (end exclusive) the data covers.
        start_layer: first layer to display; earlier rows are dropped.
        save_path: where the figure is written as SVG; parent directories are
            created if needed. Pass ``None`` to skip saving.
        cbar_label: colorbar title.

    Returns:
        The matplotlib Axes holding the heatmap.
    """
    from pathlib import Path

    # Drop the rows for layers before start_layer
    tokens_viz = tokens_viz[start_layer:]
    logits_viz = logits_viz[start_layer:]

    # Rows are labelled by absolute layer number, columns by "position: token"
    row_labels = [f"Layer {i}" for i in range(start_layer, start_layer + len(tokens_viz))]
    col_labels = [f"{i}: {text_tokens[i]}" for i in range(start_idx, end_idx)]

    # Apply the theme before the figure exists; keep a minimum height so a
    # handful of layers still yields a readable plot.
    sns.set_theme(style="white")
    plt.figure(figsize=(max(12, len(col_labels) * 2), max(4, len(row_labels) * 0.6)))

    # Colour = logits_viz values, cell text = tokens_viz strings
    ax = sns.heatmap(
        logits_viz,
        annot=tokens_viz,
        fmt="",                   # annotations are strings, so no number formatting
        cmap="mako",              # readable sequential colormap (try 'magma' or 'viridis')
        linewidths=0.5,
        linecolor='white',
        cbar_kws={'label': cbar_label},
        xticklabels=col_labels,
        yticklabels=row_labels,
        annot_kws={"size": 10, "weight": "bold"}
    )

    # Put token labels on top, angled so long tokens stay readable
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position('top')
    plt.xticks(rotation=45, ha='left')
    plt.yticks(rotation=0)

    plt.title("Logit Lens: Evolution of Token Predictions", pad=40, fontsize=16)
    plt.tight_layout()

    if save_path is not None:
        save_path = Path(save_path)
        # savefig does not create directories; a missing "images/" would raise
        save_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(
            save_path,
            format='svg',
            bbox_inches='tight',
            transparent=False,
            metadata={'Creator': 'Logit Lens Visualizer'}
        )

    plt.show()
    return ax

# --- Integration with your existing variables ---
# Slice every layer down to the positions selected in `to_viz` (end exclusive)
s_idx, e_idx = to_viz
final_tokens = [lay[s_idx:e_idx] for lay in decoded_tokens]
final_logits = np.array([logit[s_idx:e_idx] for logit in decoded_logits])

# start_layer=20: the plot begins at Layer 20 (rows 0-19 are dropped)
plot_pretty_logit_lens(final_tokens, final_logits, text_tokens, s_idx, e_idx, start_layer=20)

text = "The capital of the country containing Manchester is"

inputs = tokenizer(text, return_tensors="pt").to(device)

# Decoded string of every input token (used in the plot title)
text_tokens = [tokenizer.decode(tok_id) for tok_id in inputs['input_ids'][0]]

# Index into `hidden_states`. HF returns the embedding output plus one entry per
# decoder layer, so for this 32-layer model index 32 is the final layer's
# output. Lower it to inspect an intermediate layer.
layer_idx = 32

# Forward pass and projection are inference-only; no_grad avoids keeping every
# layer's activations for a backward pass.
with torch.no_grad():
    hidden_states = model(**inputs, output_hidden_states=True).hidden_states
    layer_hidden = hidden_states[layer_idx]

    # Project the last position's hidden state onto the vocabulary
    logits = model.lm_head(layer_hidden[0, -1, :])
    probs = torch.nn.functional.softmax(logits, dim=-1)

    # Get top 10
    top_10_probs, top_10_indices = torch.topk(probs, 10)

    # Decode tokens
    top_10_tokens = [tokenizer.decode(int(idx)) for idx in top_10_indices]
    top_10_probs = top_10_probs.float().cpu().numpy()

# Reverse so the most likely token is drawn at the top of the barh chart
top_10_tokens = top_10_tokens[::-1]
top_10_probs = top_10_probs[::-1]

# Create horizontal bar chart
fig, ax = plt.subplots(figsize=(10, 6))

y_positions = range(len(top_10_tokens))
ax.barh(y_positions, top_10_probs, color='steelblue', alpha=0.7)

ax.set_yticks(y_positions)
ax.set_yticklabels([f"{token}" for token in top_10_tokens])
ax.set_xlabel('Softmax Probability')
ax.set_title(f'Layer {layer_idx}, Token Position {len(text_tokens)-1} ("{text_tokens[-1]}") - Top 10 Predictions')
ax.set_xlim(0, max(top_10_probs) * 1.1)

ax.grid(axis='x', alpha=0.3)

# Print each probability just right of its bar
for i, prob in enumerate(top_10_probs):
    ax.text(prob + 0.001, i, f'{prob:.4f}', va='center', fontsize=9)

plt.tight_layout()
plt.show()

# import pickle

# # Save
# with open('logit_lens_results.pkl', 'wb') as f:
#     pickle.dump(dict_output, f)

# # Load later
# with open('logit_lens_results.pkl', 'rb') as f:
#     dict_output = pickle.load(f)
#     decoded_tokens = dict_output['decoded_tokens']
#     decoded_logits = dict_output['decoded_logits']
#     text_tokens = dict_output['text_tokens']