{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dff5aab",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "import torch\n",
    "from torchinfo import summary\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8571fc67",
   "metadata": {},
   "source": [
    "The next cell loads the Tiny-GPT2 checkpoint (`sshleifer/tiny-gpt2`) and its matching tokenizer from Hugging Face Transformers. The model is a causal language model (it predicts each next token from the tokens before it), and the tokenizer converts text into the token IDs the model consumes. Every later section reuses these two objects. This is a deliberately tiny model, so expect its architecture to be more informative than its text predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ab7c096",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Eager attention so the model can return attention weights (output_attentions=True) in sections 9 and 10\n",
    "model = AutoModelForCausalLM.from_pretrained(\"sshleifer/tiny-gpt2\", attn_implementation=\"eager\")\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"sshleifer/tiny-gpt2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2881191c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample texts\n",
    "sample_sentences = [\n",
    "    \"The quick brown fox jumps over the lazy dog.\",\n",
    "    \"Artificial intelligence is revolutionizing the world.\",\n",
    "    \"This model is small but efficient.\",\n",
    "    \"Natural Language Processing is fun!\",\n",
    "    \"GPT models are powerful tools for text generation.\"\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57c536bb",
   "metadata": {},
   "source": [
    "### 📜 1. Token Tales: How Long Are Our Sentences, Really?\n",
    "\n",
    "Before a model even thinks, it has to read, and it reads in tokens. We start by measuring how many tokens each sentence becomes after GPT-2's byte-level Byte-Pair Encoding (BPE). Spoiler: the token count usually differs from the word count, because punctuation gets its own token and less common words are split into several subword pieces."
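   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3a1e0f2",
   "metadata": {},
   "source": [
    "A quick way to see why token counts and word counts diverge is to tokenize the same word with and without a leading space. GPT-2's byte-level BPE typically folds a preceding space into the token itself (displayed as `Ġ`), so `\"fox\"` and `\" fox\"` can map to different IDs, and a rarer word may split into several pieces. The sketch below assumes only the `sshleifer/tiny-gpt2` tokenizer used throughout this notebook; `tok` and `probe_words` are illustrative names, not part of the original analysis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3a1e0f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "# Same checkpoint as above; `tok` is an illustrative name\n",
    "tok = AutoTokenizer.from_pretrained(\"sshleifer/tiny-gpt2\")\n",
    "\n",
    "# Hypothetical probe words: identical text with and without a leading space\n",
    "probe_words = [\"fox\", \" fox\", \"revolutionizing\", \" revolutionizing\"]\n",
    "\n",
    "for word in probe_words:\n",
    "    ids = tok.encode(word)\n",
    "    pieces = tok.convert_ids_to_tokens(ids)\n",
    "    print(f\"{word!r:20} -> {len(ids)} token(s): {pieces} {ids}\")\n",
    "\n",
    "# Decoding the IDs restores the text, leading space included\n",
    "assert tok.decode(tok.encode(\" fox\")) == \" fox\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3a1e0f4",
   "metadata": {},
   "source": [
    "Folding the space into the token means no separate whitespace tokens are needed and decoding can restore the original spacing; the cost is that the same word can map to two different IDs depending on whether a space precedes it."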
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bcddece",
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualize_tokenization(sentences):\n",
    "    \"\"\"\n",
    "    Visualizes the tokenization process for a list of sentences using the GPT-2 tokenizer.\n",
    "\n",
    "    This function demonstrates how text is broken down into tokens and converted to token IDs\n",
    "    that can be processed by the model. For each sentence, it shows:\n",
    "    - The original sentence\n",
    "    - The tokens created by the tokenizer\n",
    "    - The numerical IDs corresponding to those tokens\n",
    "\n",
    "    Args:\n",
    "        sentences (list): A list of strings, where each string is a sentence to be tokenized\n",
    "\n",
    "    Prints:\n",
    "        For each sentence, displays:\n",
    "        - Original sentence\n",
    "        - List of tokens\n",
    "        - List of token IDs\n",
    "        - A separator line\n",
    "    \"\"\"\n",
    "    for sent in sentences:\n",
    "        tokens = tokenizer.tokenize(sent)\n",
    "        token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
    "        print(f\"Sentence: {sent}\")\n",
    "        print(f\"Tokens: {tokens}\")\n",
    "        print(f\"Token IDs: {token_ids}\")\n",
    "        print(\"-\" * 80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2355c826",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_token_length_histogram(sentences):\n",
    "    \"\"\"\n",
    "    Creates a bar plot showing the number of tokens for each input sentence.\n",
    "\n",
    "    This function:\n",
    "    1. Tokenizes each sentence using the GPT-2 tokenizer\n",
    "    2. Counts the number of tokens per sentence\n",
    "    3. Creates a bar plot comparing token lengths across sentences\n",
    "\n",
    "    Args:\n",
    "        sentences (list): List of strings, where each string is a sentence to analyze\n",
    "\n",
    "    Returns:\n",
    "        None: Displays a matplotlib plot showing token counts for each sentence\n",
    "    \"\"\"\n",
    "    token_lengths = [len(tokenizer.tokenize(s)) for s in sentences]\n",
    "    df = pd.DataFrame({\"Sentence\": sentences, \"Token Count\": token_lengths})\n",
    "\n",
    "    plt.figure(figsize=(8, 5))\n",
    "    sns.barplot(x=\"Sentence\", y=\"Token Count\", data=df)\n",
    "    plt.xticks(rotation=45, ha='right')\n",
    "    plt.title(\"Token Lengths per Sentence\")\n",
    "    plt.ylabel(\"Number of Tokens\")\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "935057e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the visualizations\n",
    "visualize_tokenization(sample_sentences)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab8b22fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_token_length_histogram(sample_sentences)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e264512",
   "metadata": {},
   "source": [
    "### 🏗️ 2. Model Summary: Meet the Neural Net\n",
    "\n",
    "Let’s introduce our protagonist. We extract a summary of the architecture: number of layers, embedding sizes, and what makes this model tick (hint: it’s mostly matrix multiplications with a sprinkle of normalization and attention)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "651b6d35",
   "metadata": {},
   "outputs": [],
   "source": [
    "summary(model, input_size=(1, 32), dtypes=[torch.long])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d19fb4c2",
   "metadata": {},
   "source": [
    "### 🧮 3. Total Parameters: Counting the Brains\n",
    "\n",
    "If neurons are the brain cells, parameters are the synapses. We count them all. Even a “tiny” model still needs one embedding row for each of GPT-2’s 50,257 vocabulary entries, so expect the token-embedding table to dominate the total."
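   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4b2f1a0",
   "metadata": {},
   "source": [
    "Before listing parameters one by one, it can help to predict the total from the model config alone. The sketch below is a back-of-the-envelope check, assuming the standard GPT-2 block layout (two LayerNorms, a fused QKV projection, an attention output projection, and a two-layer MLP of hidden width `n_inner`, which defaults to `4 * n_embd`), a final LayerNorm, and an LM head whose weight is tied to the token embedding. `expected_param_count` and `check_model` are hypothetical names introduced here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4b2f1a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM\n",
    "\n",
    "def expected_param_count(cfg):\n",
    "    \"\"\"Predict a GPT-2 parameter count from its config (assumes a tied LM head).\"\"\"\n",
    "    d = cfg.n_embd\n",
    "    inner = cfg.n_inner if cfg.n_inner is not None else 4 * d\n",
    "    per_block = (\n",
    "        2 * (2 * d)            # ln_1 and ln_2: weight + bias each\n",
    "        + (d * 3 * d + 3 * d)  # attn.c_attn: fused Q, K, V projection\n",
    "        + (d * d + d)          # attn.c_proj: attention output projection\n",
    "        + (d * inner + inner)  # mlp.c_fc: expand to the inner width\n",
    "        + (inner * d + d)      # mlp.c_proj: project back to n_embd\n",
    "    )\n",
    "    embeddings = cfg.vocab_size * d + cfg.n_positions * d  # wte + wpe\n",
    "    final_ln = 2 * d                                        # ln_f\n",
    "    return embeddings + cfg.n_layer * per_block + final_ln\n",
    "\n",
    "check_model = AutoModelForCausalLM.from_pretrained(\"sshleifer/tiny-gpt2\")\n",
    "# parameters() yields each tensor once, so a tied lm_head is not double-counted\n",
    "actual = sum(p.numel() for p in check_model.parameters())\n",
    "predicted = expected_param_count(check_model.config)\n",
    "print(f\"predicted: {predicted:,}  actual: {actual:,}\")\n",
    "assert predicted == actual, \"checkpoint deviates from the assumed GPT-2 layout\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4b2f1a2",
   "metadata": {},
   "source": [
    "As a sanity check on the formula itself: with GPT-2 small's configuration (`n_embd` = 768, 12 layers, 50,257 vocabulary entries, 1,024 positions) it gives 124,439,808, the familiar \"124M\" figure. If the assertion fails for another checkpoint, one of the stated assumptions (for example, tied LM-head weights) probably does not hold."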
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bbfff85",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call the method (without parentheses this only displays the bound method) and list each parameter's shape\n",
    "[(name, tuple(param.shape)) for name, param in model.named_parameters()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a14bcb93",
   "metadata": {},
   "outputs": [],
   "source": [
    "layer_names = []\n",
    "param_counts = []\n",
    "\n",
    "for name, param in model.named_parameters():\n",
    "    layer_names.append(name)\n",
    "    param_counts.append(param.numel())\n",
    "\n",
    "print(pd.DataFrame({\"Layer Name\": layer_names, \"Param Count\": param_counts}))\n",
    "\n",
    "plt.figure(figsize=(12, 6))\n",
    "plt.barh(layer_names, param_counts)\n",
    "plt.xlabel(\"Number of Parameters\")\n",
    "plt.title(\"Parameter Count by Layer\")\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aedff7ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(12, 8))\n",
    "plt.barh(layer_names, param_counts)\n",
    "plt.xlabel(\"Number of Parameters (log scale)\")\n",
    "plt.title(\"Parameter Count by Layer (Log Scale)\")\n",
    "plt.xscale(\"log\")\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c7fce21",
   "metadata": {},
   "source": [
    "### 🧱 4. Component Breakdown: What’s Inside the Box?\n",
    "\n",
    "We dissect the model into its fundamental parts:\n",
    "\n",
    "- Embeddings, which turn token IDs and positions into vectors,\n",
    "- Projection weights (fused QKV and attention output) for the dot-product sorcery of attention,\n",
    "- MLPs (feed-forward layers), which transform each token's vector independently,\n",
    "- LayerNorm for stabilization.\n",
    "\n",
    "We also ask: is this component quantizable? What’s its exact type? A full diagnostic report follows."
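   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5c3a2b0",
   "metadata": {},
   "source": [
    "Two implementation details shape the breakdown that follows. In the Hugging Face GPT-2 code, the attention and MLP projections are `Conv1D` modules (a linear layer whose weight is stored transposed, as `(in_features, out_features)`), not `torch.nn.Linear`; and the language-modeling head usually shares its weight tensor with the token embedding. The sketch below checks both on the loaded checkpoint; `inspect_model` and `leaf_types` are illustrative names."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5c3a2b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "from transformers import AutoModelForCausalLM\n",
    "\n",
    "inspect_model = AutoModelForCausalLM.from_pretrained(\"sshleifer/tiny-gpt2\")\n",
    "\n",
    "# Count leaf modules (modules without children) by class name\n",
    "leaf_types = Counter(\n",
    "    type(module).__name__\n",
    "    for _, module in inspect_model.named_modules()\n",
    "    if len(list(module.children())) == 0\n",
    ")\n",
    "print(leaf_types)\n",
    "\n",
    "# Weight tying: True means lm_head reuses the token-embedding Parameter,\n",
    "# so named_parameters() reports it only once, as transformer.wte.weight\n",
    "tied = inspect_model.lm_head.weight is inspect_model.transformer.wte.weight\n",
    "print(\"lm_head tied to wte:\", tied)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5c3a2b2",
   "metadata": {},
   "source": [
    "If `Conv1D` dominates the counts, any analysis that filters with `isinstance(module, torch.nn.Linear)` will only see the LM head. Likewise, when `tied` is `True`, an \"Output Head\" bucket built from `named_parameters()` will read 0, because the shared tensor is counted once under the embedding."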
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f43b4447",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_param_counts(model):\n",
    "    \"\"\"\n",
    "    Categorizes and counts the number of parameters in a transformer model by component type.\n",
    "\n",
    "    This function iterates through all named parameters of the provided model and groups them into\n",
    "    the following categories based on their parameter names:\n",
    "    - Embedding: Token and positional embeddings\n",
    "    - Attention: All attention-related parameters\n",
    "    - MLP: Feed-forward (MLP) parameters\n",
    "    - LayerNorm: Layer normalization parameters\n",
    "    - Output Head: Output (language modeling) head parameters\n",
    "    - Other: Any parameters not matching the above categories\n",
    "\n",
    "    Note: when the LM head is tied to the token embedding (the GPT-2 default), named_parameters()\n",
    "    yields that shared tensor only once, as transformer.wte.weight, so \"Output Head\" can stay at 0.\n",
    "\n",
    "    Args:\n",
    "        model: The transformer model (e.g., GPT-2) whose parameters will be analyzed.\n",
    "\n",
    "    Returns:\n",
    "        dict: A dictionary mapping component names to their total parameter counts.\n",
    "    \"\"\"\n",
    "    param_groups = {\n",
    "        \"Embedding\": 0,\n",
    "        \"Attention\": 0,\n",
    "        \"MLP\": 0,\n",
    "        \"LayerNorm\": 0,\n",
    "        \"Output Head\": 0,\n",
    "        \"Other\": 0,\n",
    "    }\n",
    "\n",
    "    for name, param in model.named_parameters():\n",
    "        if \"wte\" in name or \"wpe\" in name:\n",
    "            param_groups[\"Embedding\"] += param.numel()\n",
    "        elif \"attn\" in name:\n",
    "            param_groups[\"Attention\"] += param.numel()\n",
    "        elif \"mlp\" in name:\n",
    "            param_groups[\"MLP\"] += param.numel()\n",
    "        elif \"ln\" in name:\n",
    "            param_groups[\"LayerNorm\"] += param.numel()\n",
    "        elif \"lm_head\" in name:\n",
    "            param_groups[\"Output Head\"] += param.numel()\n",
    "        else:\n",
    "            param_groups[\"Other\"] += param.numel()\n",
    "\n",
    "    return param_groups"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60145f5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "param_counts = get_param_counts(model)\n",
    "param_counts"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b86c9608",
   "metadata": {},
   "source": [
    "### 📊 5. Bar Chart Bonanza: Parameters by Component\n",
    "\n",
    "Now that we know what each part is, let’s compare their weight, literally. A bar chart reveals which components are the parameter hogs. Spoiler: in full-size GPT-2 the feed-forward MLP layers flex hardest, but in a model this small the token-embedding table usually outweighs everything else combined."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58e1e379",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_param_bar_chart(param_counts):\n",
    "    \"\"\"\n",
    "    Plots a bar chart showing the number of parameters for each component of a model.\n",
    "\n",
    "    Args:\n",
    "        param_counts (dict): A dictionary mapping component names (str) to their parameter counts (int).\n",
    "\n",
    "    Displays:\n",
    "        A seaborn bar plot visualizing the parameter count for each model component.\n",
    "    \"\"\"\n",
    "    df = pd.DataFrame(list(param_counts.items()), columns=[\"Component\", \"Param Count\"])\n",
    "    plt.figure(figsize=(10, 5))\n",
    "    sns.barplot(x=\"Component\", y=\"Param Count\", data=df)\n",
    "    plt.title(\"Parameter Count per Component\")\n",
    "    plt.xticks(rotation=45)\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f897cb84",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_param_bar_chart(param_counts)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26073ba8",
   "metadata": {},
   "source": [
    "### Number of transformer blocks (decoder layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bc4144b",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Counts the number of transformer blocks in the loaded GPT-2 model.\n",
    "\n",
    "GPT-2 is decoder-only, so these blocks are decoder layers rather than encoder layers. This code\n",
    "iterates through the model's named parameters, collects the unique block indices from parameter\n",
    "names that start with 'transformer.h.', and prints the total next to config.n_layer as a cross-check.\n",
    "\n",
    "Returns:\n",
    "    None. Prints the number of transformer blocks in the model.\n",
    "\"\"\"\n",
    "block_indices = set()\n",
    "for name, _ in model.named_parameters():\n",
    "    if name.startswith(\"transformer.h.\"):\n",
    "        block_indices.add(name.split(\".\")[2])\n",
    "\n",
    "print(f\"Number of transformer blocks: {len(block_indices)}\")\n",
    "print(f\"config.n_layer: {model.config.n_layer}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be7842c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "total_params = 0\n",
    "print(f\"{'Layer Name':60} | {'Param Count':>12} | {'Component':40} | {'Shape':25} | {'Layer Type':15} | {'Quantizable'}\")\n",
    "print(\"-\" * 180)\n",
    "\n",
    "for name, param in model.named_parameters():\n",
    "    count = param.numel()\n",
    "    total_params += count\n",
    "    shape = list(param.shape)\n",
    "\n",
    "    # Classify component by parameter name. GPT-2's projections are transformers' Conv1D\n",
    "    # modules, which compute a linear map, so they are reported as \"Linear\".\n",
    "    if 'attn.c_attn.weight' in name:\n",
    "        component = \"QKV Projection Weight\"\n",
    "        layer_type = \"Linear\"\n",
    "        quant_type = \"Weights\"\n",
    "    elif 'attn.c_attn.bias' in name:\n",
    "        component = \"QKV Projection Bias\"\n",
    "        layer_type = \"Linear\"\n",
    "        quant_type = \"Weights\"\n",
    "    elif 'attn.c_proj.weight' in name:\n",
    "        component = \"Attention Output Projection Weight\"\n",
    "        layer_type = \"Linear\"\n",
    "        quant_type = \"Weights\"\n",
    "    elif 'attn.c_proj.bias' in name:\n",
    "        component = \"Attention Output Projection Bias\"\n",
    "        layer_type = \"Linear\"\n",
    "        quant_type = \"Weights\"\n",
    "    elif 'mlp.c_fc.weight' in name:\n",
    "        component = \"MLP FC Weight\"\n",
    "        layer_type = \"Linear\"\n",
    "        quant_type = \"Weights\"\n",
    "    elif 'mlp.c_fc.bias' in name:\n",
    "        component = \"MLP FC Bias\"\n",
    "        layer_type = \"Linear\"\n",
    "        quant_type = \"Weights\"\n",
    "    elif 'mlp.c_proj.weight' in name:\n",
    "        component = \"MLP Output Projection Weight\"\n",
    "        layer_type = \"Linear\"\n",
    "        quant_type = \"Weights\"\n",
    "    elif 'mlp.c_proj.bias' in name:\n",
    "        component = \"MLP Output Projection Bias\"\n",
    "        layer_type = \"Linear\"\n",
    "        quant_type = \"Weights\"\n",
    "    elif 'wte' in name:\n",
    "        component = \"Token Embedding\"\n",
    "        layer_type = \"Embedding\"\n",
    "        quant_type = \"Weights\"\n",
    "    elif 'wpe' in name:\n",
    "        component = \"Positional Embedding\"\n",
    "        layer_type = \"Embedding\"\n",
    "        quant_type = \"Weights\"\n",
    "    elif 'ln_' in name:  # matches ln_1, ln_2 and the final ln_f\n",
    "        component = \"LayerNorm\"\n",
    "        layer_type = \"LayerNorm\"\n",
    "        quant_type = \"Not Quantizable\"\n",
    "    elif 'lm_head' in name:  # usually skipped: a tied head is reported once, as wte\n",
    "        component = \"Language Modeling Head\"\n",
    "        layer_type = \"Linear\"\n",
    "        quant_type = \"Weights\"\n",
    "    else:\n",
    "        component = \"Unknown\"\n",
    "        layer_type = \"Unknown\"\n",
    "        quant_type = \"Weights\"\n",
    "\n",
    "    print(f\"{name:60} | {count:12,} | {component:40} | {str(shape):25} | {layer_type:15} | {quant_type}\")\n",
    "\n",
    "print(f\"\\nTotal Parameters: {total_params:,}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fda4673a",
   "metadata": {},
   "source": [
    "### 🧠 6. Memory Matters: How Much RAM Does Each Layer Eat?\n",
    "\n",
    "Next, we move into memory usage. The cell below measures parameter memory: the bytes needed to store each weight tensor (element count × bytes per element). Activation memory, the extra RAM a forward pass needs at inference, depends on batch size and sequence length and is not measured here. Tiny-GPT2 may be lean, but some layers are sneakily greedy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cb90b7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Estimates and prints the memory usage (in kilobytes) of each parameter tensor in the model.\n",
    "\n",
    "For each named parameter in the model, this code multiplies the number of elements by the size of\n",
    "one element in bytes (param.element_size(), e.g. 4 for float32), and prints the layer name along\n",
    "with its estimated memory usage. Values are reported in KB because the tensors in this tiny model are small.\n",
    "\n",
    "Returns:\n",
    "    None. Prints the estimated memory usage for each layer in KB, followed by the total in MB.\n",
    "\"\"\"\n",
    "total_bytes = 0\n",
    "print(\"\\nEstimated memory usage by layer (KB):\")\n",
    "for name, param in model.named_parameters():\n",
    "    num_bytes = param.numel() * param.element_size()\n",
    "    total_bytes += num_bytes\n",
    "    print(f\"{name:60}: {num_bytes / 1024:.2f} KB\")\n",
    "\n",
    "print(f\"\\nTotal parameter memory: {total_bytes / 1024**2:.2f} MB\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56988ba1",
   "metadata": {},
   "source": [
    "### 🧹 7. Sparse or Dense? Measuring Redundancy\n",
    "\n",
    "We ask: how sparse is each component? A high sparsity means more zeros, and potentially more room for optimization. Trained floating-point weights are rarely exactly zero, so expect values near 0% unless the model has been pruned. If you’re a fan of pruning, this one’s for you."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1cf4c90",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Calculates and prints the sparsity (percentage of zero weights) for each parameter tensor in the model.\n",
    "\n",
    "For each floating-point parameter in the model that requires gradients, this code computes the proportion of elements that are exactly zero,\n",
    "indicating the level of sparsity (potential redundancy or pruning opportunity) in each layer.\n",
    "\n",
    "Returns:\n",
    "    None. Prints the sparsity percentage for each parameter tensor in the model.\n",
    "\"\"\"\n",
    "print(\"\\nSparsity (percentage of zero weights):\")\n",
    "\n",
    "for name, param in model.named_parameters():\n",
    "    if param.requires_grad and param.is_floating_point():\n",
    "        num_zeros = torch.sum(param == 0).item()\n",
    "        total = param.numel()\n",
    "        sparsity = 100 * num_zeros / total\n",
    "        print(f\"{name:50}: {sparsity:.2f}% zeros\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ef8c872",
   "metadata": {},
   "source": [
    "### ⚙️ 8. FLOPs, Not Drops: Computation Cost by Component\n",
    "\n",
    "Let’s measure raw compute. We estimate, per token, how many floating-point operations (FLOPs) each linear projection performs during inference; attention-score matmuls, softmax, and LayerNorm are left out. More FLOPs ≠ better, but it does tell us where the computational hotspots are."
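   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6d4b3c0",
   "metadata": {},
   "source": [
    "The per-layer numbers below cover only the weight matrices. A whole-forward-pass estimate also needs the attention scores, whose cost grows with the square of the sequence length. The sketch below is a back-of-the-envelope formula under stated assumptions: dense matrix multiplies (no savings from the causal mask), 2 FLOPs per multiply-add, and softmax, GELU, LayerNorm, and bias additions ignored. `estimate_forward_flops` is a hypothetical helper, not a library function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6d4b3c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoConfig\n",
    "\n",
    "def estimate_forward_flops(cfg, seq_len):\n",
    "    \"\"\"Rough forward-pass FLOPs for one sequence of `seq_len` tokens (2 FLOPs per multiply-add).\"\"\"\n",
    "    d = cfg.n_embd\n",
    "    inner = cfg.n_inner if cfg.n_inner is not None else 4 * d\n",
    "    # Weight multiply-adds per token per block: QKV (d x 3d), attention output (d x d), MLP (d x inner, inner x d)\n",
    "    per_block_weights = d * 3 * d + d * d + 2 * d * inner\n",
    "    linear = 2 * seq_len * cfg.n_layer * per_block_weights\n",
    "    # Q @ K^T and attention-weights @ V: seq_len * seq_len * d multiply-adds each, summed over heads\n",
    "    attention = 2 * cfg.n_layer * 2 * seq_len * seq_len * d\n",
    "    # LM head: one (d x vocab_size) projection per token\n",
    "    lm_head = 2 * seq_len * d * cfg.vocab_size\n",
    "    return {\n",
    "        \"linear\": linear,\n",
    "        \"attention\": attention,\n",
    "        \"lm_head\": lm_head,\n",
    "        \"total\": linear + attention + lm_head,\n",
    "    }\n",
    "\n",
    "cfg = AutoConfig.from_pretrained(\"sshleifer/tiny-gpt2\")\n",
    "for seq_len in (8, 32, 128):\n",
    "    estimate = estimate_forward_flops(cfg, seq_len)\n",
    "    print(seq_len, {k: f\"{v:,}\" for k, v in estimate.items()})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6d4b3c2",
   "metadata": {},
   "source": [
    "Doubling `seq_len` doubles the linear and LM-head terms but quadruples the attention term, so attention becomes the main cost only once the sequence is long relative to `n_embd`."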
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06543941",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers.pytorch_utils import Conv1D\n",
    "\n",
    "total_flops = 0\n",
    "\n",
    "print(\"\\nEstimated FLOPs per token for each linear layer:\")\n",
    "for name, module in model.named_modules():\n",
    "    if isinstance(module, torch.nn.Linear):\n",
    "        in_features, out_features = module.in_features, module.out_features\n",
    "    elif isinstance(module, Conv1D):\n",
    "        # GPT-2's attention and MLP projections are Conv1D; their weight has shape (in_features, out_features)\n",
    "        in_features, out_features = module.weight.shape\n",
    "    else:\n",
    "        continue\n",
    "    # One multiply and one add per weight, per token\n",
    "    flops = 2 * in_features * out_features\n",
    "    total_flops += flops\n",
    "    print(f\"{name:50}: {flops:,} FLOPs\")\n",
    "\n",
    "print(f\"\\nTotal estimated linear FLOPs per token: {total_flops:,}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba187654",
   "metadata": {},
   "source": [
    "### 🧲 9. Attention, Please: Visualizing the Transformer’s Focus\n",
    "\n",
    "Ever wanted to know which words attend to which? We crack open the attention heads and lay bare their maps. You’ll see, visually, how the model makes sense of “The cat sat on the mat.” Because GPT-2 is causal, each token can attend only to itself and earlier tokens, so everything above the diagonal of each map is zero."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58078ab6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample input\n",
    "text = \"The cat sat on the mat.\"\n",
    "inputs = tokenizer(text, return_tensors=\"pt\")\n",
    "tokens = tokenizer.convert_ids_to_tokens(inputs[\"input_ids\"][0])\n",
    "\n",
    "# Forward pass with attention output (attention weights are only returned when requested)\n",
    "with torch.no_grad():\n",
    "    outputs = model(**inputs, output_attentions=True)\n",
    "    attentions = outputs.attentions  # one (batch_size, num_heads, seq_len, seq_len) tensor per layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "411efb2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize all heads in one layer (the first by default)\n",
    "def plot_attention_heads(attentions, tokens, layer=0):\n",
    "    layer_attn = attentions[layer][0]  # shape: (num_heads, seq_len, seq_len)\n",
    "    num_heads = layer_attn.shape[0]\n",
    "\n",
    "    fig, axes = plt.subplots(1, num_heads, figsize=(4 * num_heads, 4))\n",
    "    if num_heads == 1:\n",
    "        axes = [axes]  # ensure iterable if only 1 head\n",
    "\n",
    "    for h in range(num_heads):\n",
    "        sns.heatmap(\n",
    "            layer_attn[h].numpy(),\n",
    "            xticklabels=tokens,\n",
    "            yticklabels=tokens,\n",
    "            cmap=\"viridis\",\n",
    "            cbar=False,\n",
    "            ax=axes[h]\n",
    "        )\n",
    "        axes[h].set_title(f\"Head {h}\")\n",
    "        axes[h].tick_params(axis='x', rotation=45)\n",
    "        axes[h].tick_params(axis='y', rotation=0)\n",
    "\n",
    "    plt.suptitle(f\"Self-Attention Heads (Layer {layer})\", fontsize=16)\n",
    "    plt.tight_layout(rect=[0, 0, 1, 0.95])\n",
    "    plt.show()\n",
    "\n",
    "plot_attention_heads(attentions, tokens, layer=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c34ff068",
   "metadata": {},
   "source": [
    "### 🔁 10. Attention Rollout: Token Influence, End to End\n",
    "\n",
    "We go deeper: not just where attention points, but how information flows through layers. Attention rollout averages the heads in each layer, adds the identity matrix to account for residual connections, renormalizes each row, and multiplies these matrices across layers. The resulting heatmap shows cumulative token influence from start to finish."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee14cbff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rollout function\n",
    "def compute_attention_rollout(attentions):\n",
    "    \"\"\"\n",
    "    attentions: sequence of (batch, heads, seq_len, seq_len) tensors, one per layer\n",
    "    Returns: rollout attention matrix [seq_len, seq_len] for the first batch element\n",
    "    \"\"\"\n",
    "    rollout = torch.eye(attentions[0].size(-1))  # identity [seq_len x seq_len]\n",
    "\n",
    "    for layer_attn in attentions:\n",
    "        # Average over heads for this layer\n",
    "        avg_attn = layer_attn[0].mean(dim=0)  # shape: [seq_len, seq_len]\n",
    "        # Add residual connection (identity matrix)\n",
    "        avg_attn = avg_attn + torch.eye(avg_attn.size(0))\n",
    "        # Normalize rows to sum to 1\n",
    "        avg_attn = avg_attn / avg_attn.sum(dim=-1, keepdim=True)\n",
    "        rollout = avg_attn @ rollout  # propagate influence\n",
    "\n",
    "    return rollout.numpy()\n",
    "\n",
    "rollout_matrix = compute_attention_rollout(attentions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45c48cc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_attention_rollout(rollout, tokens):\n",
    "    plt.figure(figsize=(8, 6))\n",
    "    sns.heatmap(rollout, xticklabels=tokens, yticklabels=tokens, cmap=\"magma\", square=True)\n",
    "    plt.title(\"Attention Rollout: Cumulative Token Influence\")\n",
    "    plt.xlabel(\"Input Token\")\n",
    "    plt.ylabel(\"Output Token\")\n",
    "    plt.xticks(rotation=45)\n",
    "    plt.yticks(rotation=0)\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "plot_attention_rollout(rollout_matrix, tokens)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ffa22eb",
   "metadata": {},
   "source": [
    "### 🎯 11. Predicting the Future: Model Outputs & Logits\n",
    "\n",
    "Finally, we let the model talk. We give it a prompt, then visualize the top-k next-token predictions and plot the 50 most probable tokens of the softmax distribution. This is where the model’s inner monologue becomes readable, though a checkpoint this small is unlikely to produce sensible guesses."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "040c1ebf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample input\n",
    "text = \"The quick brown fox\"\n",
    "inputs = tokenizer(text, return_tensors=\"pt\")\n",
    "\n",
    "# Run model and get logits\n",
    "with torch.no_grad():\n",
    "    outputs = model(**inputs)\n",
    "    logits = outputs.logits  # shape: [batch, seq_len, vocab_size]\n",
    "\n",
    "# Focus on the last token's logits (for next-token prediction)\n",
    "last_token_logits = logits[0, -1, :]  # shape: [vocab_size]\n",
    "probs = torch.softmax(last_token_logits, dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbe454c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize Top-k Predictions\n",
    "\n",
    "def plot_top_k_predictions(probs, tokenizer, top_k=10):\n",
    "    top_probs, top_indices = torch.topk(probs, top_k)\n",
    "    top_tokens = tokenizer.convert_ids_to_tokens(top_indices.tolist())\n",
    "\n",
    "    plt.figure(figsize=(10, 5))\n",
    "    sns.barplot(x=top_tokens, y=top_probs.numpy(), palette=\"Blues_d\")\n",
    "    plt.title(\"Top-k Next Token Predictions\")\n",
    "    plt.xlabel(\"Token\")\n",
    "    plt.ylabel(\"Probability\")\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "plot_top_k_predictions(probs, tokenizer, top_k=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f203836",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the head of the softmax distribution (the top_n most probable tokens)\n",
    "\n",
    "def plot_softmax_distribution(probs, tokenizer, top_n=50):\n",
    "    sorted_probs, sorted_indices = torch.sort(probs, descending=True)\n",
    "    sorted_tokens = tokenizer.convert_ids_to_tokens(sorted_indices[:top_n].tolist())\n",
    "\n",
    "    plt.figure(figsize=(12, 5))\n",
    "    sns.barplot(x=sorted_tokens, y=sorted_probs[:top_n].numpy(), palette=\"viridis\")\n",
    "    plt.title(f\"Top {top_n} Tokens in Softmax Distribution\")\n",
    "    plt.xlabel(\"Token\")\n",
    "    plt.ylabel(\"Probability\")\n",
    "    plt.xticks(rotation=90)\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "plot_softmax_distribution(probs, tokenizer, top_n=50)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "quant-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}