Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import spaces | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| import numpy as np | |
| import math | |
| import os | |
| import pickle | |
| import requests | |
| import textwrap | |
| import subprocess | |
| import shutil | |
| import time | |
| from dataclasses import dataclass | |
| from typing import Optional | |
| from transformers import AutoTokenizer | |
| # ============================================================================== | |
| # ------------------------- VERSION 1: SHARED SETUP ---------------------------- | |
| # ============================================================================== | |
def setup_environment():
    """Best-effort one-time setup of the nanoGPT Shakespeare character data for V1.

    Steps, each skipped when its output already exists:
      1. clone https://github.com/karpathy/nanoGPT into ``./nanoGPT``;
      2. copy ``nanoGPT/data/shakespeare_char`` to ``./shakespeare_char``;
      3. run that directory's ``prepare.py`` to produce ``meta.pkl``.

    Returns immediately when ``shakespeare_char/meta.pkl`` already exists.
    Failures are reported with ``print`` and never raised, so a missing ``git``
    binary or a failing ``prepare.py`` does not crash the app at import time.
    """
    import sys  # local import: only needed to locate the running interpreter

    nano_gpt_repo_path = 'nanoGPT'
    data_dir_path = 'shakespeare_char'
    meta_path = os.path.join(data_dir_path, 'meta.pkl')
    if os.path.exists(meta_path):
        return
    print("Required data not found. Starting one-time setup...")

    if not os.path.exists(nano_gpt_repo_path):
        try:
            subprocess.run(
                ['git', 'clone', 'https://github.com/karpathy/nanoGPT.git'],
                check=True, capture_output=True, text=True,
            )
        except subprocess.CalledProcessError as e:
            print(f"Error cloning repository: {e.stderr}")
        except OSError as e:
            # e.g. FileNotFoundError when the git executable is not installed.
            print(f"Error cloning repository: {e}")

    source_data_dir = os.path.join(nano_gpt_repo_path, 'data', 'shakespeare_char')
    if not os.path.exists(data_dir_path) and os.path.exists(source_data_dir):
        try:
            shutil.copytree(source_data_dir, data_dir_path)
        except OSError as e:  # shutil.Error is a subclass of OSError
            print(f"Error copying {source_data_dir}: {e}")
            return

    prepare_script_path = os.path.join(data_dir_path, 'prepare.py')
    if os.path.exists(prepare_script_path) and not os.path.exists(meta_path):
        try:
            # sys.executable runs prepare.py with the same interpreter/venv as the app.
            subprocess.run(
                [sys.executable, 'prepare.py'],
                check=True, cwd=data_dir_path, capture_output=True, text=True,
            )
        except subprocess.CalledProcessError as e:
            print(f"Error running prepare.py: {e.stderr}")
        except OSError as e:
            print(f"Error running prepare.py: {e}")


setup_environment()
def download_file(url, filename, timeout=30):
    """Download ``url`` to ``filename`` unless ``filename`` already exists.

    The body is streamed into ``filename + '.part'`` and renamed to
    ``filename`` only after it has been fully written. An interrupted download
    therefore never leaves a truncated file behind, which the existence check
    on later runs would otherwise accept as complete.

    Args:
        url: HTTP(S) URL to fetch.
        filename: destination path.
        timeout: seconds passed to ``requests.get``. Without it, a stalled
            server could block app start-up indefinitely.

    Network and filesystem errors are printed, not raised (best-effort).
    Callers should check ``os.path.exists(filename)`` afterwards.
    """
    import contextlib  # local import: only used for best-effort cleanup

    if os.path.exists(filename):
        return
    print(f"Downloading '{filename}'...")
    tmp_path = f"{filename}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Atomic on the same filesystem: readers see either nothing or the full file.
        os.replace(tmp_path, filename)
    except (requests.exceptions.RequestException, OSError) as e:
        print(f"Error downloading {url}: {e}")
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
| # ============================================================================== | |
| # ---------------------- VERSION 1: ARCHITECTURE & LOGIC ----------------------- | |
| # ============================================================================== | |
# V1 Constants and Meta Loading
v1_data_dir = './shakespeare_char/'
v1_meta_url = 'https://huggingface.co/spaces/thejagstudio/NanoDiffusion/resolve/main/meta.pkl'
v1_meta_path = 'meta.pkl'
download_file(v1_meta_url, v1_meta_path)

# Fallbacks, used when meta.pkl is missing or unreadable. With empty maps,
# v1_decode renders every id as '?'.
v1_vocab_size = 65
v1_itos = {}
v1_stoi = {}
if os.path.exists(v1_meta_path):
    # SECURITY NOTE(review): pickle.load can execute arbitrary code, and this
    # file was fetched over the network above. This is only acceptable because
    # the URL points at this project's own Space; consider a JSON vocab instead.
    try:
        with open(v1_meta_path, 'rb') as f:
            meta = pickle.load(f)
        # All three lookups happen before any assignment, so a missing key
        # cannot leave the globals half-updated.
        v1_vocab_size, v1_itos, v1_stoi = meta['vocab_size'], meta['itos'], meta['stoi']
    except (OSError, EOFError, pickle.UnpicklingError, KeyError, TypeError) as e:
        print(f"Failed to read V1 vocabulary from {v1_meta_path}: {e}")

v1_context_length = 256  # characters per generated sample; also used as the V1 model block_size
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def v1_decode(indices_tensor: torch.Tensor):
    """Turn a tensor of V1 character ids back into a string.

    For tensors with more than one dimension, a leading axis of size 1 is
    squeezed away first. Ids not present in ``v1_itos`` become '?'.
    """
    flat = indices_tensor.squeeze(0) if indices_tensor.dim() > 1 else indices_tensor
    return ''.join(v1_itos.get(idx, '?') for idx in flat.cpu().numpy())
def wrap_text(long_text, width=80):
    """Re-wrap each line of ``long_text`` to at most ``width`` columns.

    Empty lines are emitted as empty lines, so paragraph breaks survive.
    """
    pieces = []
    for line in long_text.splitlines():
        if not line:
            pieces.append('')
            continue
        pieces.append(textwrap.fill(line, width=width))
    return "\n".join(pieces)
@dataclass
class V1_GPTConfig:
    """Hyper-parameters for the V1 character-level diffusion transformer.

    This must be a dataclass: the model-loading code builds it with keyword
    arguments (``V1_GPTConfig(**v1_model_args)``). A plain class with only
    class attributes rejects that call with ``TypeError``.
    """
    block_size: int = 1024  # maximum sequence length (rows in the position-embedding table)
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    cond_dim: int = 64  # width of the noise-level conditioning vector fed to adaLN layers
    dropout: float = 0.0
    bias: bool = False  # whether Linear/LayerNorm layers inside the blocks use a bias
class V1_MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout.

    The hidden layer is ``4 * config.n_embd`` wide.
    """

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        # Attribute names are the state_dict keys that load_state_dict must match.
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
class V1_SelfAttention(nn.Module):
    """Bidirectional (non-causal) multi-head self-attention over the full sequence.

    Uses ``F.scaled_dot_product_attention`` when torch provides it. Otherwise
    it computes softmax(Q K^T / sqrt(head_dim)) V explicitly. Neither path
    applies a mask.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # One fused projection yields q, k and v.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')

    def _split_heads(self, t, batch, seq, width):
        # (B, T, C) -> (B, n_head, T, C // n_head)
        return t.view(batch, seq, self.n_head, width // self.n_head).transpose(1, 2)

    def forward(self, x):
        batch, seq, width = x.size()
        q, k, v = (self._split_heads(part, batch, seq, width)
                   for part in self.c_attn(x).split(self.n_embd, dim=2))
        if self.flash:
            drop_p = self.dropout if self.training else 0
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=drop_p, is_causal=False)
        else:
            scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            y = self.attn_dropout(F.softmax(scores, dim=-1)) @ v
        merged = y.transpose(1, 2).contiguous().view(batch, seq, width)
        return self.resid_dropout(self.c_proj(merged))
def v1_modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Affine modulation used by the adaLN layers: ``x * (1 + scale) + shift``."""
    scaled = x * (1 + scale)
    return scaled + shift
def v1_bias_add_scale(x: torch.Tensor, bias: Optional[torch.Tensor], scale: torch.Tensor, residual: Optional[torch.Tensor]) -> torch.Tensor:
    """Return ``residual + scale * (x + bias)``.

    ``bias`` and ``residual`` may each be None, in which case that term is skipped.
    """
    shifted = x if bias is None else x + bias
    out = scale * shifted
    return out if residual is None else residual + out
class V1_DDiTBlock(nn.Module):
    """Transformer block conditioned on ``c`` through adaLN shift/scale/gate terms.

    ``adaLN_modulation`` maps ``c`` to six ``n_embd``-wide chunks: shift,
    scale and gate for the attention branch, then the same three for the MLP
    branch. It is zero-initialized in ``__init__``.
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, bias=config.bias)
        self.attn = V1_SelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = V1_MLP(config)
        self.adaLN_modulation = nn.Linear(config.cond_dim, 6 * config.n_embd)
        self.adaLN_modulation.weight.data.zero_()
        self.adaLN_modulation.bias.data.zero_()

    def forward(self, x, c):
        # [:, None] inserts a sequence axis so every term broadcasts over positions.
        mods = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mods

        # NOTE(review): the attention branch runs self.attn twice and re-applies
        # ln_1 to the first attention output, so shift_msa/scale_msa only act
        # through that first pass. This differs from a standard DiT block, but
        # the loaded checkpoint was presumably trained with this exact forward
        # pass. Confirm before "fixing" it, since changing it would invalidate
        # the pretrained weights.
        first_pass = self.attn(v1_modulate(self.ln_1(x), shift_msa, scale_msa))
        h = v1_bias_add_scale(self.attn(self.ln_1(first_pass)), None, gate_msa, x)

        mlp_in = v1_modulate(self.ln_2(h), shift_mlp, scale_mlp)
        return v1_bias_add_scale(self.mlp(mlp_in), None, gate_mlp, h)
| class V1_DDitFinalLayer(nn.Module): | |
| def __init__(self, config): | |
| super().__init__() | |
| self.norm_final = nn.LayerNorm(config.n_embd, bias=config.bias) | |
| self.linear = nn.Linear(config.n_embd, config.vocab_size) | |
| self.linear.weight.data.zero_() | |
| self.linear.bias.data.zero_() | |
| self.adaLN_modulation = nn.Linear(config.cond_dim, 2 * config.n_embd) | |
| self.adaLN_modulation.weight.data.zero_() | |
| self.adaLN_modulation.bias.data.zero_() | |
| def forward(self, x, c): | |
| shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2) | |
| x = v1_modulate(self.norm_final(x), shift, scale) | |
| x = self.linear(x) | |
| return x | |
class V1_TimestepEmbedder(nn.Module):
    """Embed scalar noise levels into ``hidden_size``-dim vectors.

    Each value is expanded into sinusoidal features of width
    ``frequency_embedding_size`` and then passed through Linear -> SiLU -> Linear.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Sinusoidal features for a 1-D tensor ``t`` -> tensor of shape (len(t), dim).

        The first half of the columns are cosines and the second half sines of
        ``t`` times geometrically spaced frequencies, starting at 1 and
        decreasing towards ``1 / max_period``. An odd ``dim`` is padded with
        one zero column.

        This must be a staticmethod: ``forward`` calls it as
        ``self.timestep_embedding(t, dim)``. As a plain method, that call would
        bind the module itself to ``t``.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(t_freq)
class V1_GPT(nn.Module):
    """V1 score network: a bidirectional transformer conditioned on the noise level.

    ``forward(idx, sigma)`` returns values of shape (batch, seq, vocab_size).
    For every position, the entry of its current token is forced to 0. The
    sampling code exponentiates these outputs and treats them as log-scores.
    """

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config
        # Maps the noise level to a cond_dim-wide conditioning vector.
        self.sigma_map = V1_TimestepEmbedder(config.cond_dim)
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([V1_DDiTBlock(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = V1_DDitFinalLayer(config)
        self.apply(self._init_weights)
        # Output projections (params named *c_proj.weight) get std scaled by 1/sqrt(2 * n_layer).
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, 0.02) and zero Linear biases.

        Applied to every submodule, so this also overwrites the zero init that
        V1_DDiTBlock / V1_DDitFinalLayer give their own Linear weights.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, sigma):
        """Score every vocabulary entry at every position.

        Args:
            idx: (batch, seq) token ids with seq <= config.block_size.
            sigma: noise levels, flattened to 1-D (presumably one per batch element).

        Returns:
            (batch, seq, vocab_size) tensor, with 0 at each position's current token.
        """
        sigma = sigma.reshape(-1)
        _, t = idx.size()
        c = F.silu(self.sigma_map(sigma))
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        # Positions are built on idx's device rather than the module-level
        # `device` global, so the model also works after being moved elsewhere.
        pos = torch.arange(0, t, dtype=torch.long, device=idx.device)
        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x, c)
        x = self.transformer.ln_f(x)
        x = self.lm_head(x, c)
        # Force the output for the token currently at each position to exactly 0.
        x = torch.scatter(x, -1, idx[..., None], torch.zeros_like(x[..., :1]))
        return x
class V1_GeometricNoise:
    """Geometric noise schedule from ``sigma_min`` at t=0 to ``sigma_max`` at t=1.

    ``total_noise(t) = sigma_min**(1-t) * sigma_max**t``, and ``rate_noise``
    is its derivative with respect to ``t``. Calling the object returns
    ``(total_noise(t), rate_noise(t))``.
    """

    def __init__(self, sigma_min=1e-4, sigma_max=20):
        self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max]).to(device)

    def total_noise(self, t):
        lo, hi = self.sigmas
        return lo ** (1 - t) * hi ** t

    def rate_noise(self, t):
        lo, hi = self.sigmas
        return lo ** (1 - t) * hi ** t * (hi.log() - lo.log())

    def __call__(self, t):
        return self.total_noise(t), self.rate_noise(t)
| # --- V1 Inference Logic --- | |
def v1_transition(x_t: torch.Tensor, delta_sigma: torch.Tensor) -> torch.Tensor:
    """Per-position transition probabilities over the V1 vocabulary.

    Every token other than the current one (``x_t``) receives
    ``(1 - exp(-delta_sigma)) / V``, where ``V = v1_vocab_size``. The current
    token receives the remaining mass, so each row sums to 1. The result has
    shape ``x_t.shape + (V,)`` when ``delta_sigma`` broadcasts against it.
    """
    index = x_t[..., None]
    off_diag = (1 - torch.exp(-delta_sigma[..., None])) / v1_vocab_size
    trans = off_diag * torch.ones(*x_t.shape, v1_vocab_size, device=x_t.device)
    trans = trans.scatter(-1, index, torch.zeros_like(trans))
    stay = 1 - trans.sum(dim=-1, keepdim=True)
    return trans.scatter(-1, index, stay)
def v1_staggered_score(score, delta_sigma):
    """Staggered score used by the V1 sampler's update step.

    Returns ``(e - 1) / (V * e) * score.sum(-1) + score / e``, where
    ``e = exp(-delta_sigma)`` is broadcast over the vocabulary axis and
    ``V = v1_vocab_size``.
    """
    e = torch.exp(-delta_sigma)[..., None]
    total = score.sum(dim=-1, keepdim=True)
    return ((e - 1) / (v1_vocab_size * e)) * total + score / e
def v1_sample_categorical(probs: torch.Tensor) -> torch.Tensor:
    """Draw one index along the last axis of ``probs`` using the Gumbel-max trick.

    A small epsilon keeps both logarithms finite when probabilities or
    uniform draws are zero.
    """
    eps = 1e-10
    uniform = torch.rand_like(probs)
    gumbel = -torch.log(-torch.log(uniform + eps) + eps)
    scores = torch.log(probs + eps) + gumbel
    return scores.argmax(dim=-1)
# --- V1 Model Loading ---
V1_WEIGHTS_URL = 'https://huggingface.co/spaces/thejagstudio/NanoDiffusion/resolve/main/final_model.pth?download=true'


def _build_v1_model(config):
    """Construct V1_GPT for ``config``, load the pretrained weights and return it on ``device`` in eval mode."""
    model = V1_GPT(config)
    # weights_only=True restricts unpickling to tensors and plain containers;
    # the file is downloaded from the network, so arbitrary pickles are unsafe.
    state_dict = torch.hub.load_state_dict_from_url(
        V1_WEIGHTS_URL, map_location=device, weights_only=True
    )
    model.load_state_dict(state_dict)
    return model.to(device).eval()


print("Initializing V1 Model...")
v1_model_args = dict(n_layer=6, n_head=6, n_embd=384, cond_dim=64,
                     bias=False, vocab_size=v1_vocab_size, block_size=v1_context_length, dropout=0.2)
v1_config = None
v1_model = None
try:
    # Config and model construction also sit inside the try. Any failure
    # (bad config, download error, state_dict mismatch) then leaves
    # v1_model = None, and the V1 tab reports it instead of the app crashing.
    v1_config = V1_GPTConfig(**v1_model_args)
    v1_model = _build_v1_model(v1_config)
    print("V1 Model loaded successfully.")
except Exception as e:  # top-level boundary: degrade to "V1 unavailable"
    print(f"Failed to load V1 model: {e}")
    v1_model = None
v1_noise = V1_GeometricNoise(sigma_min=1e-4, sigma_max=20)
def _v1_denoise_update(x, curr_sigma_bar, delta_sigma):
    """Apply one V1 reverse-diffusion update to ``x`` and return the resampled token ids.

    The computation runs under ``torch.no_grad()`` so no autograd graph is
    built. The context covers only the computation, so grad mode is not left
    disabled while the calling generator is suspended at a ``yield``.
    """
    with torch.no_grad():
        log_score = v1_model(x, curr_sigma_bar)
        stag_score = v1_staggered_score(torch.exp(log_score), delta_sigma)
        probs = stag_score * v1_transition(x, delta_sigma)
        return v1_sample_categorical(probs)


def v1_generate_stream(steps, speed):
    """Stream V1 generation frames: the random start, each denoising step, then the final text.

    Generation and display share one generator, so cancelling the Gradio
    event stops the computation immediately.

    Args:
        steps: number of denoising steps (coerced to int, at least 1).
        speed: visualization speed. Frames are ``0.5 / max(speed, 0.1)``
            seconds apart; the per-step delay is skipped when ``speed >= 20``.

    Yields:
        A header plus wrapped text for each frame, or an error string when the
        model failed to load.
    """
    if v1_model is None:
        yield "Error: V1 Model not loaded"
        return
    # Guard steps <= 0, which would divide by zero in step_size below.
    steps = max(int(steps), 1)
    speed = float(speed)
    eps = 1e-5
    delay = 0.5 / max(speed, 0.1)

    x = torch.randint(0, v1_vocab_size, (1, v1_context_length), device=device)
    yield f"--- Initial Random Noise ---\n\n{wrap_text(v1_decode(x[0]))}"
    time.sleep(delay)

    # Time runs from t=1 down to eps in `steps` equal increments.
    timesteps = torch.linspace(1, eps, steps + 1, device=device)
    step_size = (1 - eps) / steps
    for i in range(steps):
        t = timesteps[i] * torch.ones(x.shape[0], 1, device=device)
        curr_sigma_bar = v1_noise(t)[0]
        next_sigma_bar = v1_noise(t - step_size)[0]
        x = _v1_denoise_update(x, curr_sigma_bar, curr_sigma_bar - next_sigma_bar)
        yield f"--- Denoising Step {i + 1}/{steps} ---\n\n{wrap_text(v1_decode(x[0]))}"
        if speed < 20:
            time.sleep(delay)

    # Final update at t = eps, using the whole remaining noise level as delta_sigma.
    t = timesteps[steps] * torch.ones(x.shape[0], 1, device=device)
    curr_sigma_bar = v1_noise(t)[0]
    x = _v1_denoise_update(x, curr_sigma_bar, curr_sigma_bar)
    yield f"--- Final Denoised Text (Step {steps}) ---\n\n{wrap_text(v1_decode(x[0]))}"
| # ============================================================================== | |
| # ---------------------- VERSION 2: ARCHITECTURE & LOGIC ----------------------- | |
| # ============================================================================== | |
# Local path of the V2 checkpoint. It must be a file path, not a URL: the
# loader checks os.path.exists before torch.load. Override it with the
# V2_MODEL_PATH environment variable instead of editing this file.
V2_MODEL_PATH = os.environ.get("V2_MODEL_PATH", "checkpoints/model_fp32.pt")
class V2_RMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis with a learned per-channel gain.

    ``y = weight * x / sqrt(mean(x**2, -1) + eps)``. Unlike LayerNorm, there
    is no mean subtraction and no bias.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.weight * (x * inv_rms)
class V2_RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) cos/sin tables, with a small cache.

    ``forward(x, seq_len=None)`` returns ``(cos, sin)``, each shaped
    (1, 1, seq_len, dim) for even ``dim``. ``seq_len`` defaults to
    ``x.shape[-2]``.

    Inverse frequencies are computed lazily on ``x``'s device and recomputed
    when the device changes. Tables are computed in the inverse-frequency
    dtype (float32), not ``x.dtype``. They are cached per
    (seq_len, device, dtype), and at most 10 entries are kept (oldest
    evicted first). ``max_position_embeddings`` is stored but not used to
    bound ``seq_len``.
    """

    def __init__(self, dim, max_position_embeddings=16384, base=100000, scaling_factor=1.0):
        super().__init__()
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.base = base
        self.max_position_embeddings = max_position_embeddings
        self.inv_freq = None
        self._cache = {}

    def _update_freqs(self, device):
        # scaling_factor rescales the base; 1.0 leaves it unchanged.
        scaled_base = self.base * (self.scaling_factor ** (self.dim / (self.dim - 2)))
        exponents = torch.arange(0, self.dim, 2).float().to(device) / self.dim
        self.inv_freq = 1.0 / (scaled_base ** exponents)

    def forward(self, x, seq_len=None):
        seq_len = x.shape[-2] if seq_len is None else seq_len
        stale = self.inv_freq is None or self.inv_freq.device != x.device
        if stale:
            self._update_freqs(x.device)
        key = (seq_len, x.device, x.dtype)
        cached = self._cache.get(key)
        if cached is not None:
            return cached
        positions = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        angles = torch.einsum("i,j->ij", positions, self.inv_freq)
        table = torch.cat((angles, angles), dim=-1)
        cos, sin = table.cos()[None, None, :, :], table.sin()[None, None, :, :]
        self._cache[key] = (cos, sin)
        if len(self._cache) > 10:
            del self._cache[next(iter(self._cache))]
        return cos, sin
def v2_apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate ``q`` and ``k`` by the RoPE angles encoded in ``cos``/``sin``.

    Uses the "rotate half" layout: the last axis is split at its midpoint
    into (a, b), and the rotated partner is (-b, a).
    """
    def _rotate_half(t):
        mid = t.shape[-1] // 2
        return torch.cat((-t[..., mid:], t[..., :mid]), dim=-1)

    q_rot = q * cos + _rotate_half(q) * sin
    k_rot = k * cos + _rotate_half(k) * sin
    return q_rot, k_rot
class V2_DiffusionAttention(nn.Module):
    """Bidirectional grouped-query attention with RoPE and an optional KV cache.

    ``num_key_value_heads`` key/value heads are each shared by
    ``num_attention_heads // num_key_value_heads`` query heads.

    forward(hidden_states, freqs_cis, attention_mask=None, past_kv=None):
        hidden_states: (batch, q_len, hidden_size).
        freqs_cis: (cos, sin) tables; only their first q_len positions are used.
        attention_mask: optional (batch, kv_len), where 1 = attend and 0 = ignore.
        past_kv: optional (k, v) from an earlier call, prepended on the sequence axis.
    Returns (output, (k, v)). The returned k/v include any past entries and
    are taken before the shared heads are repeated.

    NOTE(review): RoPE positions always start at 0 for the new tokens, even
    when past_kv is given. Verify this before relying on past_kv for
    incremental decoding. ``use_flash_attn`` is stored but not consulted;
    F.scaled_dot_product_attention is always used.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.use_flash_attn = config.use_flash_attn
        kv_width = self.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, kv_width, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, kv_width, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def _heads(self, proj, states, n_heads):
        # (B, T, hidden) -> (B, n_heads, T, head_dim)
        bsz, q_len, _ = states.shape
        return proj(states).view(bsz, q_len, n_heads, self.head_dim).transpose(1, 2)

    def forward(self, hidden_states, freqs_cis, attention_mask=None, past_kv=None):
        bsz, q_len, _ = hidden_states.size()
        q = self._heads(self.q_proj, hidden_states, self.num_heads)
        k = self._heads(self.k_proj, hidden_states, self.num_key_value_heads)
        v = self._heads(self.v_proj, hidden_states, self.num_key_value_heads)

        cos, sin = (table[:, :, :q_len, :] for table in freqs_cis)
        q, k = v2_apply_rotary_pos_emb(q, k, cos, sin)

        if past_kv is not None:
            k = torch.cat([past_kv[0], k], dim=2)
            v = torch.cat([past_kv[1], v], dim=2)
        current_kv = (k, v)

        # Repeat each shared KV head so every query head has a partner.
        k = k.repeat_interleave(self.num_key_value_groups, dim=1)
        v = v.repeat_interleave(self.num_key_value_groups, dim=1)

        attn_mask = None
        if attention_mask is not None:
            # 1 -> additive 0 (attend); 0 -> dtype minimum (effectively -inf).
            keep = attention_mask[:, None, None, :].to(dtype=q.dtype)
            attn_mask = (1.0 - keep) * torch.finfo(q.dtype).min

        out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
        )
        out = out.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)
        return self.o_proj(out), current_kv
class V2_MLP(nn.Module):
    """Gated feed-forward layer: ``down(silu(gate(x)) * up(x))``, with bias-free projections."""

    def __init__(self, config):
        super().__init__()
        d_model, d_ff = config.hidden_size, config.intermediate_size
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
class V2_BlockDiffusionBlock(nn.Module):
    """Pre-norm transformer layer: ``h = x + attn(norm(x))``, then ``h + mlp(norm(h))``.

    ``use_activation_checkpointing`` is read from the config but not acted on
    here; ``forward`` simply delegates to ``_forward``.
    """

    def __init__(self, config):
        super().__init__()
        self.self_attn = V2_DiffusionAttention(config)
        self.mlp = V2_MLP(config)
        self.input_layernorm = V2_RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = V2_RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.use_activation_checkpointing = config.use_activation_checkpointing

    def forward(self, hidden_states, freqs_cis, attention_mask, past_kv):
        return self._forward(hidden_states, freqs_cis, attention_mask, past_kv)

    def _forward(self, hidden_states, freqs_cis, attention_mask, past_kv):
        """Return (updated hidden states, (k, v) produced by the attention layer)."""
        attn_out, new_kv = self.self_attn(
            self.input_layernorm(hidden_states), freqs_cis, attention_mask, past_kv
        )
        h = hidden_states + attn_out
        h = h + self.mlp(self.post_attention_layernorm(h))
        return h, new_kv
@dataclass
class V2_ModelConfig:
    """Hyper-parameters of the V2 block-diffusion LLM.

    A dataclass, so it can be built with keyword arguments and compared and
    printed field by field. Instances restored by unpickling a checkpoint get
    their fields from the pickle and do not pass through ``__init__``.
    """
    vocab_size: int = 151936
    hidden_size: int = 1024
    intermediate_size: int = 2816
    num_hidden_layers: int = 16
    num_attention_heads: int = 16
    num_key_value_heads: int = 4  # shared KV heads for grouped-query attention
    max_position_embeddings: int = 16384
    rms_norm_eps: float = 1e-6
    rope_theta: float = 100000.0
    pad_token_id: int = 0
    mask_token_id: int = 1  # id the generator places in not-yet-decoded positions
    use_flash_attn: bool = True
    use_activation_checkpointing: bool = False
    attention_dropout: float = 0.0
    hidden_dropout: float = 0.0


# Alias under the bare name ``ModelConfig``. Presumably the checkpoint pickled
# its config as ``ModelConfig``, and unpickling resolves classes by name.
# Verify against the training script before removing.
ModelConfig = V2_ModelConfig
class V2_DiffusionLLM(nn.Module):
    """V2 model: token embedding -> V2_BlockDiffusionBlock stack -> RMSNorm -> tied LM head.

    ``forward(input_ids, attention_mask=None, past_key_values=None)`` returns
    ``(logits, new_kvs)``. ``logits`` has shape (batch, seq, vocab_size), and
    ``new_kvs`` holds one (k, v) pair per layer.
    """

    def __init__(self, config: V2_ModelConfig):
        super().__init__()
        self.config = config
        # Use padding_idx only when the pad id is a valid row of the embedding table.
        pad_idx = None
        if config.pad_token_id < config.vocab_size:
            pad_idx = config.pad_token_id
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=pad_idx)
        self.layers = nn.ModuleList(
            [V2_BlockDiffusionBlock(config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = V2_RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # NOTE(review): RoPE uses V2_RotaryEmbedding's default base (100000)
        # rather than config.rope_theta. The two agree for the default config,
        # but a checkpoint with another rope_theta would have it ignored.
        # Confirm how the weights were trained before changing this.
        self.rotary_emb = V2_RotaryEmbedding(
            config.hidden_size // config.num_attention_heads,
            config.max_position_embeddings,
        )
        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.embed_tokens.weight

    def forward(self, input_ids, attention_mask=None, past_key_values=None):
        _, seqlen = input_ids.shape
        hidden_states = self.embed_tokens(input_ids)
        freqs_cis = self.rotary_emb(hidden_states, seq_len=seqlen)
        layer_pasts = [None] * len(self.layers) if past_key_values is None else past_key_values
        new_kvs = []
        for i, layer in enumerate(self.layers):
            hidden_states, kv = layer(hidden_states, freqs_cis, attention_mask, layer_pasts[i])
            new_kvs.append(kv)
        logits = self.lm_head(self.norm(hidden_states))
        return logits, new_kvs


# Bare-name alias (presumably for code or pickles that refer to ``DiffusionLLM``).
DiffusionLLM = V2_DiffusionLLM
# --- V2 Loading Logic ---
def _load_v2_model(path):
    """Load a V2 checkpoint and return ``(config, model)``, with the model on ``device`` in eval mode.

    Raises whatever torch.load / load_state_dict raise. Doing the work inside
    a function means the raw checkpoint and state_dict are freed on return
    instead of lingering as module globals.
    """
    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects
    # (needed for the pickled config) and can execute code. Only load
    # checkpoints you trust.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    config = checkpoint['config']
    model = V2_DiffusionLLM(config)
    # Cast every tensor to float32 before loading.
    state_dict = {k: v.float() for k, v in checkpoint['model_state'].items()}
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return config, model


print("Initializing V2 components...")
v2_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
if v2_tokenizer.pad_token is None:
    v2_tokenizer.pad_token = v2_tokenizer.eos_token
v2_model = None
v2_config = None
if os.path.exists(V2_MODEL_PATH):
    print(f"Loading V2 model from {V2_MODEL_PATH}...")
    try:
        # Globals are assigned only after every step succeeds, so a failed
        # load_state_dict leaves v2_model = None rather than a randomly
        # initialized model.
        v2_config, v2_model = _load_v2_model(V2_MODEL_PATH)
        print("V2 Model loaded.")
    except Exception as e:  # top-level boundary: the V2 tab reports the missing model
        print(f"Error loading V2 model: {e}")
else:
    print(f"V2 Model file not found at {V2_MODEL_PATH}. Version 2 tab will not work without it.")
def _v2_apply_repetition_penalty(block_logits, context_ids, mask_block, is_masked, penalty):
    """Penalize, in place, tokens already in the context or already decoded in the current block.

    For each such token id below the vocab size, its logit column is divided
    by ``penalty`` when the column's mean over block positions (batch 0) is
    positive, and multiplied by ``penalty`` otherwise. This is vectorized over
    token ids instead of looping one id at a time.
    """
    seen = set(context_ids[0].tolist())
    seen.update(mask_block[0][~is_masked[0]].tolist())
    vocab = block_logits.shape[-1]
    ids = [tid for tid in seen if tid < vocab]
    if not ids:
        return
    cols = torch.tensor(ids, device=block_logits.device)
    selected = block_logits[:, :, cols]
    positive = selected[0].mean(dim=0) > 0
    block_logits[:, :, cols] = torch.where(positive, selected / penalty, selected * penalty)


def _v2_sampling_probs(block_logits, temperature, top_k, top_p):
    """Convert logits to sampling probabilities with temperature, top-k and top-p filtering.

    As a defensive repair, NaNs (e.g. from a row that is entirely -inf) are
    zeroed, probabilities are clamped to at least 1e-10 and then renormalized,
    so ``torch.multinomial`` always gets a valid distribution. The clamp also
    leaves filtered tokens a tiny non-zero probability.
    """
    logits = block_logits / temperature
    if top_k > 0:
        k = min(top_k, logits.size(-1))  # topk raises if k exceeds the vocab size
        top_vals, top_idx = torch.topk(logits, k, dim=-1)
        logits = torch.full_like(logits, float('-inf')).scatter(-1, top_idx, top_vals)
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove_sorted = cum_probs > top_p
        # Shift right so the token that crosses top_p is kept; always keep the best token.
        remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
        remove_sorted[..., 0] = False
        remove = remove_sorted.scatter(-1, sorted_idx, remove_sorted)
        logits = logits.masked_fill(remove, float('-inf'))
    probs = F.softmax(logits, dim=-1)
    probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
    probs = probs.clamp(min=1e-10)
    return probs / probs.sum(dim=-1, keepdim=True)


def v2_generate_block_diffusion(prompt, steps, block_size, max_new_tokens, replay_speed):
    """Generate text block by block with the V2 model, yielding one frame per denoising step.

    Each block of ``block_size`` positions starts fully masked (filled with
    ``config.mask_token_id``). At every step:
      1. the model scores the context plus the block;
      2. a token is sampled for every block position;
      3. the ``max(1, block_size // steps)`` most confident still-masked
         positions are fixed (the last step fixes all that remain).
    The finished block is then appended to the context.

    Args:
        prompt: text to continue.
        steps: denoising steps per block (coerced to int, at least 1).
        block_size: tokens per block. It is clamped to [1, max_new_tokens], so
            a block larger than the token budget still produces output.
        max_new_tokens: total new-token budget; ``max_new_tokens // block_size``
            blocks are generated.
        replay_speed: visualization speed. Frames are ``0.5 / max(speed, 0.1)``
            seconds apart, with no delay when ``speed >= 20``.

    Yields:
        A progress frame (header, decoded context, and the block with masked
        positions drawn as a mask symbol) per step, then the final decoded text.
    """
    if v2_model is None:
        yield "Error: V2 Model not found. Check path."
        return
    v2_model.eval()
    steps = max(int(steps), 1)
    max_new_tokens = int(max_new_tokens)
    block_size = max(1, min(int(block_size), max_new_tokens))
    speed = float(replay_speed)

    prompt_ids = v2_tokenizer.encode(prompt, return_tensors="pt").to(device)
    config = v2_model.config
    num_blocks = max_new_tokens // block_size
    context_ids = prompt_ids

    # Fixed sampling hyper-parameters (not exposed in the UI).
    temperature = 1.0
    top_k = 40
    top_p = 0.9
    repetition_penalty = 1.2

    delay = 0.5 / max(speed, 0.1)

    for block_idx in range(num_blocks):
        mask_block = torch.full((1, block_size), config.mask_token_id, device=device)
        is_masked = torch.ones(1, block_size, dtype=torch.bool, device=device)
        # The context is constant while a block is being denoised, so decode it once.
        ctx_str = v2_tokenizer.decode(context_ids[0], skip_special_tokens=True)

        for step_idx in range(steps):
            block_str = "".join(
                "β" if masked else v2_tokenizer.decode([tid], skip_special_tokens=False)
                for tid, masked in zip(mask_block[0].tolist(), is_masked[0].tolist())
            )
            yield (f"--- Generating Block {block_idx+1}/{num_blocks} | Step {step_idx+1}/{steps} ---\n\n"
                   f"{ctx_str}{block_str}")
            if speed < 20:  # at max speed, skip the visualization delay
                time.sleep(delay)

            full_input = torch.cat([context_ids, mask_block], dim=1)
            attention_mask = torch.ones_like(full_input, dtype=torch.float32)
            # no_grad wraps only the forward pass. No autograd graph is built,
            # and grad mode is not left disabled while the generator is suspended.
            with torch.no_grad():
                logits, _ = v2_model(full_input, attention_mask=attention_mask)
            block_logits = logits[:, -block_size:, :]

            if repetition_penalty != 1.0:
                _v2_apply_repetition_penalty(
                    block_logits, context_ids, mask_block, is_masked, repetition_penalty
                )
            probs = _v2_sampling_probs(block_logits, temperature, top_k, top_p)

            sampled_tokens = torch.multinomial(
                probs.view(-1, probs.size(-1)), num_samples=1
            ).view(1, block_size)
            confidence = probs.gather(-1, sampled_tokens.unsqueeze(-1)).squeeze(-1)

            remaining = int(is_masked.sum().item())
            tokens_to_unmask = remaining if step_idx == steps - 1 else max(1, block_size // steps)
            num_to_unmask = min(tokens_to_unmask, remaining)
            if num_to_unmask > 0:
                # Already-decoded positions get confidence -1 so topk never picks them.
                masked_confidence = confidence.masked_fill(~is_masked, -1.0)
                _, top_indices = torch.topk(masked_confidence.view(-1), num_to_unmask)
                mask_block[0, top_indices] = sampled_tokens[0, top_indices]
                is_masked[0, top_indices] = False

        context_ids = torch.cat([context_ids, mask_block], dim=1)

    yield v2_tokenizer.decode(context_ids[0].tolist(), skip_special_tokens=True)
# ==============================================================================
# ------------------------------- GRADIO UI ------------------------------------
# ==============================================================================
css = '''.gradio-container > .fillable {max-width: 900px !important}
h3{margin-top: 1em}
p{margin-top: 0}
textarea{font-family: monospace; background-color: #1a1b1e; color: #e0e0e0}
'''
with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
    gr.Markdown("# Diffusion Language Models Playground")
    with gr.Tabs():
        # --- TAB 1: VERSION 1 (CHAR DIFFUSION) ---
        with gr.Tab("Version 1: Character Diffusion (NanoGPT)"):
            gr.Markdown("### Tiny 11M parameter character-based continuous diffusion.")
            with gr.Row():
                with gr.Column(scale=1):
                    v1_steps = gr.Slider(64, 512, 128, step=1, label="Denoising Steps")
                    v1_speed = gr.Slider(1, 20, 10, step=1, label="Generation/Replay Speed")
                    with gr.Row():
                        v1_btn = gr.Button("Generate", variant="primary")
                        v1_stop = gr.Button("Stop", variant="stop")
                with gr.Column(scale=2):
                    v1_out = gr.Textbox(label="Generated Text", lines=15, interactive=False)
            # Generation and display share one generator, so cancelling the
            # event stops the computation as well as the stream.
            v1_event = v1_btn.click(v1_generate_stream, inputs=[v1_steps, v1_speed], outputs=[v1_out])
            v1_stop.click(fn=None, inputs=None, outputs=None, cancels=[v1_event])

        # --- TAB 2: VERSION 2 (BLOCK DIFFUSION) ---
        with gr.Tab("Version 2: Block Diffusion (LLaDA-style)"):
            gr.Markdown("### Block-based diffusion using Qwen tokenizer.")
            if v2_model is None:
                # gr.Warning only displays while an event is being processed;
                # at build time it never reaches the user, so render the
                # notice inline instead.
                gr.Markdown(f"**Warning:** V2 Model not loaded. Please check path: `{V2_MODEL_PATH}`")
            with gr.Row():
                with gr.Column(scale=1):
                    v2_prompt = gr.Textbox(label="Prompt", value="The king went to the")
                    v2_steps = gr.Slider(4, 64, 16, step=1, label="Steps per Block")
                    # Max 128 lies on the step-8 grid (8, 16, ..., 128).
                    v2_block_size = gr.Slider(8, 128, 32, step=8, label="Block Size")
                    v2_max_tokens = gr.Slider(32, 1024, 128, step=32, label="Total New Tokens")
                    v2_speed = gr.Slider(1, 20, 1, step=1, label="Generation/Replay Speed")
                    with gr.Row():
                        v2_btn = gr.Button("Generate", variant="primary")
                        v2_stop = gr.Button("Stop", variant="stop")
                with gr.Column(scale=2):
                    v2_out = gr.Textbox(label="Generated Text", lines=15, interactive=False)
            v2_event = v2_btn.click(
                v2_generate_block_diffusion,
                inputs=[v2_prompt, v2_steps, v2_block_size, v2_max_tokens, v2_speed],
                outputs=[v2_out]
            )
            v2_stop.click(fn=None, inputs=None, outputs=None, cancels=[v2_event])

if __name__ == "__main__":
    demo.launch()