thejagstudio committed on
Commit
94f7913
·
1 Parent(s): 719d04f

Upload 8 files

Files changed (7)
  1. README.md +12 -10
  2. app-0.py +467 -0
  3. app.py +798 -0
  4. final_model.pth +3 -0
  5. meta.pkl +3 -0
  6. model_epoch_1.pth +3 -0
  7. requirements.txt +5 -0
README.md CHANGED
@@ -1,10 +1,12 @@
- ---
- title: NanoDiffusion
- emoji: 🍃
- colorFrom: yellow
- colorTo: green
- sdk: docker
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: Diffusion GPT
+ emoji: 🖊️📖🌀
+ colorFrom: gray
+ colorTo: blue
+ sdk: gradio
+ sdk_version: 5.49.1
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app-0.py ADDED
@@ -0,0 +1,467 @@
+ import gradio as gr
+ import spaces
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ import numpy as np
+ import math
+ import os
+ import pickle
+ import requests
+ import textwrap
+ import subprocess
+ import shutil
+ import time
+ from dataclasses import dataclass
+ from typing import Optional
+
+ # --- 1. Automated Environment and Data Setup ---
+
+ def setup_environment():
+     """
+     Checks for and sets up the necessary data and code.
+     - Clones nanoGPT if not present.
+     - Copies the shakespeare_char dataset directory.
+     - Runs the data preparation script to create meta.pkl and binary files.
+     This function makes the script self-contained.
+     """
+     nano_gpt_repo_path = 'nanoGPT'
+     data_dir_path = 'shakespeare_char'
+     meta_path = os.path.join(data_dir_path, 'meta.pkl')
+
+     if os.path.exists(meta_path):
+         print("Dataset and metadata found. Skipping setup.")
+         return
+
+     print("Required data not found. Starting one-time setup...")
+
+     if not os.path.exists(nano_gpt_repo_path):
+         print("Cloning nanoGPT repository...")
+         try:
+             subprocess.run(
+                 ['git', 'clone', 'https://github.com/karpathy/nanoGPT.git'],
+                 check=True, capture_output=True, text=True
+             )
+             print("Cloned successfully.")
+         except subprocess.CalledProcessError as e:
+             print(f"Error cloning repository: {e.stderr}")
+             raise
+     else:
+         print("nanoGPT repository already exists.")
+
+     source_data_dir = os.path.join(nano_gpt_repo_path, 'data', 'shakespeare_char')
+     if not os.path.exists(data_dir_path):
+         print(f"Copying '{source_data_dir}' to '{data_dir_path}'...")
+         shutil.copytree(source_data_dir, data_dir_path)
+         print("Copied successfully.")
+     else:
+         print(f"'{data_dir_path}' directory already exists.")
+
+     prepare_script_path = os.path.join(data_dir_path, 'prepare.py')
+     if not os.path.exists(meta_path):
+         print(f"Running data preparation script: '{prepare_script_path}'...")
+         try:
+             subprocess.run(
+                 ['python', 'prepare.py'],
+                 check=True, cwd=data_dir_path, capture_output=True, text=True
+             )
+             print("Data preparation script finished successfully.")
+         except subprocess.CalledProcessError as e:
+             print(f"Error running prepare.py: {e.stderr}")
+             raise
+
+     print("Setup complete.")
+
+ setup_environment()
+
+ # --- 2. Global Setup & Helper Functions ---
+
+ data_dir = './shakespeare_char/'
+
+ def download_file(url, filename):
+     """Downloads a file from a URL if it doesn't exist."""
+     if os.path.exists(filename):
+         print(f"'{filename}' already exists. Skipping download.")
+         return
+     print(f"Downloading '{filename}' from '{url}'...")
+     try:
+         response = requests.get(url, stream=True)
+         response.raise_for_status()  # Check for download errors
+         with open(filename, 'wb') as f:
+             for chunk in response.iter_content(chunk_size=8192):
+                 f.write(chunk)
+         print("Download complete.")
+     except requests.exceptions.RequestException as e:
+         print(f"Error downloading {url}: {e}")
+         raise
+
+ # Define file URLs and local paths
+ meta_url = 'https://huggingface.co/spaces/thejagstudio/diffusion-gpt/resolve/main/meta.pkl'
+ meta_path = 'meta.pkl'
+ download_file(meta_url, meta_path)
+ with open(meta_path, 'rb') as f:
+     meta = pickle.load(f)
+
+ vocab_size = meta['vocab_size']
+ itos = meta['itos']
+ stoi = meta['stoi']
+ context_length = 256
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ def decode(indices_tensor: torch.Tensor):
+     if indices_tensor.dim() > 1:
+         indices_tensor = indices_tensor.squeeze(0)
+     indices = indices_tensor.cpu().numpy()
+     return ''.join([itos.get(i, '?') for i in indices])
+
+ def wrap_text(long_text, width=80):
+     paragraphs = long_text.splitlines()
+     wrapped = [textwrap.fill(p, width=width) if p else '' for p in paragraphs]
+     return "\n".join(wrapped)
+
+
+ # --- 3. Model Architecture (Identical to Notebook) ---
+
+ @dataclass
+ class GPTConfig:
+     block_size: int = 1024
+     vocab_size: int = 50304
+     n_layer: int = 12
+     n_head: int = 12
+     n_embd: int = 768
+     cond_dim: int = 64
+     dropout: float = 0.0
+     bias: bool = False
+
+ class MLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
+         self.gelu = nn.GELU()
+         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
+         self.dropout = nn.Dropout(config.dropout)
+     def forward(self, x):
+         x = self.c_fc(x)
+         x = self.gelu(x)
+         x = self.c_proj(x)
+         x = self.dropout(x)
+         return x
+
+ class SelfAttention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         assert config.n_embd % config.n_head == 0
+         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
+         self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
+         self.attn_dropout = nn.Dropout(config.dropout)
+         self.resid_dropout = nn.Dropout(config.dropout)
+         self.n_head = config.n_head
+         self.n_embd = config.n_embd
+         self.dropout = config.dropout
+         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
+     def forward(self, x):
+         B, T, C = x.size()
+         q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+         if self.flash:
+             # is_causal=False: diffusion denoising attends bidirectionally.
+             y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=False)
+         else:
+             att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+             att = F.softmax(att, dim=-1)
+             att = self.attn_dropout(att)
+             y = att @ v
+         y = y.transpose(1, 2).contiguous().view(B, T, C)
+         y = self.resid_dropout(self.c_proj(y))
+         return y
+
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+     return x * (1 + scale) + shift
+
+ def bias_add_scale(x: torch.Tensor, bias: Optional[torch.Tensor], scale: torch.Tensor, residual: Optional[torch.Tensor]) -> torch.Tensor:
+     if bias is not None:
+         out = scale * (x + bias)
+     else:
+         out = scale * x
+     if residual is not None:
+         out = residual + out
+     return out
+
+ class DDiTBlock(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.ln_1 = nn.LayerNorm(config.n_embd, bias=config.bias)
+         self.attn = SelfAttention(config)
+         self.ln_2 = nn.LayerNorm(config.n_embd, bias=config.bias)
+         self.mlp = MLP(config)
+         self.adaLN_modulation = nn.Linear(config.cond_dim, 6 * config.n_embd)
+         self.adaLN_modulation.weight.data.zero_()
+         self.adaLN_modulation.bias.data.zero_()
+     def forward(self, x, c):
+         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
+         x_skip = x
+         x = modulate(self.ln_1(x), shift_msa, scale_msa)
+         x = self.attn(x)
+         # NOTE: attention runs a second time on a re-normalized copy below. A
+         # standard adaLN-Zero block applies it once, but this matches the
+         # notebook the checkpoint was trained with, so it is kept as-is.
+         x = bias_add_scale(self.attn(self.ln_1(x)), None, gate_msa, x_skip)
+         x = bias_add_scale(self.mlp(modulate(self.ln_2(x), shift_mlp, scale_mlp)), None, gate_mlp, x)
+         return x
+
+ class DDitFinalLayer(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.norm_final = nn.LayerNorm(config.n_embd, bias=config.bias)
+         self.linear = nn.Linear(config.n_embd, config.vocab_size)
+         self.linear.weight.data.zero_()
+         self.linear.bias.data.zero_()
+         self.adaLN_modulation = nn.Linear(config.cond_dim, 2 * config.n_embd)
+         self.adaLN_modulation.weight.data.zero_()
+         self.adaLN_modulation.bias.data.zero_()
+     def forward(self, x, c):
+         shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
+         x = modulate(self.norm_final(x), shift, scale)
+         x = self.linear(x)
+         return x
+
+ class TimestepEmbedder(nn.Module):
+     def __init__(self, hidden_size, frequency_embedding_size=256):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+             nn.SiLU(),
+             nn.Linear(hidden_size, hidden_size, bias=True),
+         )
+         self.frequency_embedding_size = frequency_embedding_size
+     @staticmethod
+     def timestep_embedding(t, dim, max_period=10000):
+         half = dim // 2
+         freqs = torch.exp(
+             -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+         ).to(device=t.device)
+         args = t[:, None].float() * freqs[None]
+         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+         if dim % 2:
+             embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+         return embedding
+     def forward(self, t):
+         t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+         t_emb = self.mlp(t_freq)
+         return t_emb
+
+ class GPT(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         assert config.vocab_size is not None
+         assert config.block_size is not None
+         self.config = config
+         self.sigma_map = TimestepEmbedder(config.cond_dim)
+         self.transformer = nn.ModuleDict(dict(
+             wte = nn.Embedding(config.vocab_size, config.n_embd),
+             wpe = nn.Embedding(config.block_size, config.n_embd),
+             drop = nn.Dropout(config.dropout),
+             h = nn.ModuleList([DDiTBlock(config) for _ in range(config.n_layer)]),
+             ln_f = nn.LayerNorm(config.n_embd, bias=config.bias),
+         ))
+         self.lm_head = DDitFinalLayer(config)
+         self.apply(self._init_weights)
+         for pn, p in self.named_parameters():
+             if pn.endswith('c_proj.weight'):
+                 torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+     def forward(self, idx, sigma):
+         sigma = sigma.reshape(-1)
+         b, t = idx.size()
+         c = F.silu(self.sigma_map(sigma))
+         assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+         pos = torch.arange(0, t, dtype=torch.long, device=idx.device)
+         tok_emb = self.transformer.wte(idx)
+         pos_emb = self.transformer.wpe(pos)
+         x = self.transformer.drop(tok_emb + pos_emb)
+         for block in self.transformer.h:
+             x = block(x, c)
+         x = self.transformer.ln_f(x)
+         x = self.lm_head(x, c)
+         # Zero the output at each position's current token: the model predicts
+         # (log) score ratios to *other* tokens only.
+         x = torch.scatter(x, -1, idx[..., None], torch.zeros_like(x[..., :1]))
+         return x
+
+ class GeometricNoise:
+     """Geometric schedule: sigma(t) = sigma_min**(1-t) * sigma_max**t."""
+     def __init__(self, sigma_min=1e-4, sigma_max=20):
+         self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max]).to(device)
+     def rate_noise(self, t):
+         return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t * (self.sigmas[1].log() - self.sigmas[0].log())
+     def total_noise(self, t):
+         return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t
+     def __call__(self, t):
+         return self.total_noise(t), self.rate_noise(t)
+
+ # --- 4. Inference & Sampling Logic (Identical to Notebook) ---
+
+ def transition(x_t: torch.Tensor, delta_sigma: torch.Tensor) -> torch.Tensor:
+     # Uniform transition kernel: spread (1 - exp(-delta_sigma)) of the mass
+     # uniformly over the vocabulary; the remainder stays on the current token.
+     base_prob = (1 - torch.exp(-delta_sigma[..., None])) / vocab_size
+     trans = torch.ones(*x_t.shape, vocab_size, device=x_t.device) * base_prob
+     trans = trans.scatter(-1, x_t[..., None], torch.zeros_like(trans))
+     diag_fill = 1 - trans.sum(dim=-1, keepdim=True)
+     trans = trans.scatter(-1, x_t[..., None], diag_fill)
+     return trans
+
+ def staggered_score(score, delta_sigma):
+     exp_factor = torch.exp(-delta_sigma)[..., None]
+     correction = ((exp_factor - 1) / (vocab_size * exp_factor)) * score.sum(dim=-1, keepdim=True)
+     return correction + score / exp_factor
+
+ def sample_categorical(probs: torch.Tensor) -> torch.Tensor:
+     # Gumbel-max trick: equivalent to sampling from the (unnormalized) probs.
+     eps = 1e-10
+     gumbel_noise = -torch.log(-torch.log(torch.rand_like(probs) + eps) + eps)
+     return torch.argmax(torch.log(probs + eps) + gumbel_noise, dim=-1)
+
+
+ # --- 5. Model Initialization and Loading ---
+
+ print("Initializing and loading the pretrained model...")
+ model_args = dict(n_layer=6, n_head=6, n_embd=384, cond_dim=64,
+                   bias=False, vocab_size=vocab_size, block_size=context_length, dropout=0.2)
+ config = GPTConfig(**model_args)
+ model = GPT(config)
+
+ model.load_state_dict(
+     torch.hub.load_state_dict_from_url(
+         'https://huggingface.co/spaces/thejagstudio/diffusion-gpt/resolve/main/final_model.pth?download=true',
+         # 'https://huggingface.co/spaces/thejagstudio/diffusion-gpt/resolve/main/model_epoch_1.pth?download=true',
+         map_location=device
+     )
+ )
+ model.to(device)
+ model.eval()
+
+ noise = GeometricNoise(sigma_min=1e-4, sigma_max=20)
+ print("Model loaded successfully.")
+
+
+ # --- 6. Gradio Interface Logic ---
+ @spaces.GPU
+ def generate_text(steps):
+     """
+     Fast generation phase. Runs the diffusion process and stores all
+     intermediate frames in a list, then returns the final text and the list.
+     """
+     steps = int(steps)
+     eps = 1e-5
+
+     # List to store each frame of the diffusion process
+     diffusion_frames = []
+
+     # Start with a random sample
+     x = torch.randint(0, vocab_size, (1, context_length), device=device)
+     initial_text = f"--- Initial Random Noise ---\n\n{wrap_text(decode(x[0]))}"
+     diffusion_frames.append(initial_text)
+
+     timesteps = torch.linspace(1, eps, steps + 1, device=device)
+     step_size = (1 - eps) / steps
+
+     with torch.no_grad():
+         for i in range(steps):
+             t = timesteps[i] * torch.ones(x.shape[0], 1, device=device)
+             curr_sigma_bar = noise(t)[0]
+
+             next_sigma_bar = noise(t - step_size)[0]
+             delta_sigma = curr_sigma_bar - next_sigma_bar
+
+             log_score = model(x, curr_sigma_bar)
+             score = torch.exp(log_score)
+
+             stag_score = staggered_score(score, delta_sigma)
+             probs = stag_score * transition(x, delta_sigma)
+             x = sample_categorical(probs)
+
+             # Store the frame
+             progress_text = f"--- Denoising Step {i + 1}/{steps} ---\n\n{wrap_text(decode(x[0]))}"
+             diffusion_frames.append(progress_text)
+
+         # Final denoising step
+         t = timesteps[steps] * torch.ones(x.shape[0], 1, device=device)
+         curr_sigma_bar = noise(t)[0]
+         delta_sigma = curr_sigma_bar
+
+         log_score = model(x, curr_sigma_bar)
+         score = torch.exp(log_score)
+         stag_score = staggered_score(score, delta_sigma)
+         probs = stag_score * transition(x, delta_sigma)
+         x = sample_categorical(probs)
+
+     final_text = f"--- Final Denoised Text (Step {steps}) ---\n\n{wrap_text(decode(x[0]))}"
+     diffusion_frames.append(final_text)
+
+     # Return the final text and the complete list of frames
+     return final_text, diffusion_frames
+
+ def replay_diffusion(frames, replay_speed):
+     """
+     Slow replay phase. Iterates through the stored frames and yields them
+     with a delay to create an animation effect.
+     """
+     delay = 0.5 / replay_speed  # Calculate delay based on speed multiplier
+     for frame in frames:
+         yield frame
+         time.sleep(delay)
+
+ # Define the Gradio UI
+ css = '''.gradio-container > .fillable {max-width: 720px !important}
+ h3{margin-top: 1em}
+ p{margin-top: 0}
+ textarea{font-family: monospace; background-color: black}
+ '''
+ with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
+     gr.Markdown(
+         """
+         # LLaDA-inspired diffusion language model
+         ### A tiny 11M-parameter character-based diffusion model
+         """
+     )
+
+     generate_button = gr.Button("Generate", variant="primary")
+
+     output_textbox = gr.Textbox(
+         label="Generated Text",
+         lines=15,
+         interactive=False,
+         show_copy_button=True,
+         placeholder="Generation will appear here..."
+     )
+     with gr.Row():
+         steps_slider = gr.Slider(
+             minimum=64,
+             maximum=512,
+             value=128,
+             step=1,
+             label="Denoising Steps",
+             info="Number of steps in the generation process."
+         )
+         speed_slider = gr.Slider(
+             minimum=1,
+             maximum=20,
+             value=10,
+             step=1,
+             label="Replay Speed",
+             info="Controls the speed of the animation after generation.",
+             visible=False
+         )
+
+     diffusion_frames_state = gr.State([])
+
+     generate_event = generate_button.click(
+         fn=generate_text,
+         inputs=[steps_slider],
+         outputs=[output_textbox, diffusion_frames_state]
+     ).then(
+         fn=replay_diffusion,
+         inputs=[diffusion_frames_state, speed_slider],
+         outputs=[output_textbox]
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
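
A quick note on the sampling math in `generate_text` above: each denoising step multiplies a "staggered" score by a uniform transition kernel and Gumbel-samples the result. A minimal self-contained sketch on a toy vocabulary (the random `score` tensor is a stand-in for `exp(model(x, sigma))`; everything else mirrors the functions in the file):

```python
import torch

vocab_size = 5
x = torch.randint(0, vocab_size, (1, 8))      # toy sequence of 8 token ids
delta_sigma = torch.tensor([[0.3]])           # noise removed in this step

def transition(x_t, delta_sigma):
    # Uniform kernel: (1 - exp(-delta_sigma))/V off-diagonal, rest on the diagonal.
    base_prob = (1 - torch.exp(-delta_sigma[..., None])) / vocab_size
    trans = torch.ones(*x_t.shape, vocab_size) * base_prob
    trans = trans.scatter(-1, x_t[..., None], torch.zeros_like(trans))
    diag_fill = 1 - trans.sum(dim=-1, keepdim=True)
    return trans.scatter(-1, x_t[..., None], diag_fill)

def staggered_score(score, delta_sigma):
    exp_factor = torch.exp(-delta_sigma)[..., None]
    correction = ((exp_factor - 1) / (vocab_size * exp_factor)) * score.sum(dim=-1, keepdim=True)
    return correction + score / exp_factor

score = torch.rand(1, 8, vocab_size)          # stand-in for exp(model log-score)
probs = staggered_score(score, delta_sigma) * transition(x, delta_sigma)

# Gumbel-max sampling, as in sample_categorical above.
gumbel = -torch.log(-torch.log(torch.rand_like(probs) + 1e-10) + 1e-10)
x_next = torch.argmax(torch.log(probs + 1e-10) + gumbel, dim=-1)
print(x_next.shape)  # torch.Size([1, 8])
```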
app.py ADDED
@@ -0,0 +1,798 @@
+ import gradio as gr
+ import spaces
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ import numpy as np
+ import math
+ import os
+ import pickle
+ import requests
+ import textwrap
+ import subprocess
+ import shutil
+ import time
+ from dataclasses import dataclass
+ from typing import Optional
+ from transformers import AutoTokenizer
+
+ # ==============================================================================
+ # ------------------------- VERSION 1: SHARED SETUP ---------------------------
+ # ==============================================================================
+
+ def setup_environment():
+     """Checks for and sets up the necessary data for V1."""
+     nano_gpt_repo_path = 'nanoGPT'
+     data_dir_path = 'shakespeare_char'
+     meta_path = os.path.join(data_dir_path, 'meta.pkl')
+
+     if os.path.exists(meta_path):
+         return
+
+     print("Required data not found. Starting one-time setup...")
+     if not os.path.exists(nano_gpt_repo_path):
+         try:
+             subprocess.run(['git', 'clone', 'https://github.com/karpathy/nanoGPT.git'], check=True, capture_output=True, text=True)
+         except subprocess.CalledProcessError as e:
+             # Non-fatal: V1 falls back to the downloaded meta.pkl below.
+             print(f"Error cloning repository: {e.stderr}")
+
+     source_data_dir = os.path.join(nano_gpt_repo_path, 'data', 'shakespeare_char')
+     if not os.path.exists(data_dir_path) and os.path.exists(source_data_dir):
+         shutil.copytree(source_data_dir, data_dir_path)
+
+     # Check if we can run prepare.py
+     prepare_script_path = os.path.join(data_dir_path, 'prepare.py')
+     if os.path.exists(prepare_script_path) and not os.path.exists(meta_path):
+         subprocess.run(['python', 'prepare.py'], check=True, cwd=data_dir_path, capture_output=True, text=True)
+
+ setup_environment()
+
+ def download_file(url, filename):
+     if os.path.exists(filename):
+         return
+     print(f"Downloading '{filename}'...")
+     try:
+         response = requests.get(url, stream=True)
+         response.raise_for_status()
+         with open(filename, 'wb') as f:
+             for chunk in response.iter_content(chunk_size=8192):
+                 f.write(chunk)
+     except requests.exceptions.RequestException as e:
+         print(f"Error downloading {url}: {e}")
+
+ # ==============================================================================
+ # ---------------------- VERSION 1: ARCHITECTURE & LOGIC ----------------------
+ # ==============================================================================
+
+ # V1 Constants and Meta Loading
+ v1_data_dir = './shakespeare_char/'
+ v1_meta_url = 'https://huggingface.co/spaces/thejagstudio/diffusion-gpt/resolve/main/meta.pkl'
+ v1_meta_path = 'meta.pkl'
+ download_file(v1_meta_url, v1_meta_path)
+
+ v1_vocab_size = 65  # Fallback: nanoGPT's shakespeare_char vocabulary size
+ v1_itos = {}
+ v1_stoi = {}
+
+ if os.path.exists(v1_meta_path):
+     with open(v1_meta_path, 'rb') as f:
+         meta = pickle.load(f)
+     v1_vocab_size = meta['vocab_size']
+     v1_itos = meta['itos']
+     v1_stoi = meta['stoi']
+
+ v1_context_length = 256
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ def v1_decode(indices_tensor: torch.Tensor):
+     if indices_tensor.dim() > 1:
+         indices_tensor = indices_tensor.squeeze(0)
+     indices = indices_tensor.cpu().numpy()
+     return ''.join([v1_itos.get(i, '?') for i in indices])
+
+ def wrap_text(long_text, width=80):
+     paragraphs = long_text.splitlines()
+     wrapped = [textwrap.fill(p, width=width) if p else '' for p in paragraphs]
+     return "\n".join(wrapped)
+
+ @dataclass
+ class V1_GPTConfig:
+     block_size: int = 1024
+     vocab_size: int = 50304
+     n_layer: int = 12
+     n_head: int = 12
+     n_embd: int = 768
+     cond_dim: int = 64
+     dropout: float = 0.0
+     bias: bool = False
+
+ class V1_MLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
+         self.gelu = nn.GELU()
+         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
+         self.dropout = nn.Dropout(config.dropout)
+     def forward(self, x):
+         x = self.c_fc(x)
+         x = self.gelu(x)
+         x = self.c_proj(x)
+         x = self.dropout(x)
+         return x
+
+ class V1_SelfAttention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         assert config.n_embd % config.n_head == 0
+         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
+         self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
+         self.attn_dropout = nn.Dropout(config.dropout)
+         self.resid_dropout = nn.Dropout(config.dropout)
+         self.n_head = config.n_head
+         self.n_embd = config.n_embd
+         self.dropout = config.dropout
+         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
+     def forward(self, x):
+         B, T, C = x.size()
+         q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+         if self.flash:
+             y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=False)
+         else:
+             att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+             att = F.softmax(att, dim=-1)
+             att = self.attn_dropout(att)
+             y = att @ v
+         y = y.transpose(1, 2).contiguous().view(B, T, C)
+         y = self.resid_dropout(self.c_proj(y))
+         return y
+
+ def v1_modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+     return x * (1 + scale) + shift
+
+ def v1_bias_add_scale(x: torch.Tensor, bias: Optional[torch.Tensor], scale: torch.Tensor, residual: Optional[torch.Tensor]) -> torch.Tensor:
+     if bias is not None:
+         out = scale * (x + bias)
+     else:
+         out = scale * x
+     if residual is not None:
+         out = residual + out
+     return out
+
+ class V1_DDiTBlock(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.ln_1 = nn.LayerNorm(config.n_embd, bias=config.bias)
+         self.attn = V1_SelfAttention(config)
+         self.ln_2 = nn.LayerNorm(config.n_embd, bias=config.bias)
+         self.mlp = V1_MLP(config)
+         self.adaLN_modulation = nn.Linear(config.cond_dim, 6 * config.n_embd)
+         self.adaLN_modulation.weight.data.zero_()
+         self.adaLN_modulation.bias.data.zero_()
+     def forward(self, x, c):
+         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
+         x_skip = x
+         x = v1_modulate(self.ln_1(x), shift_msa, scale_msa)
+         x = self.attn(x)
+         # NOTE: attention runs a second time on a re-normalized copy below;
+         # kept as-is to match the notebook the checkpoint was trained with.
+         x = v1_bias_add_scale(self.attn(self.ln_1(x)), None, gate_msa, x_skip)
+         x = v1_bias_add_scale(self.mlp(v1_modulate(self.ln_2(x), shift_mlp, scale_mlp)), None, gate_mlp, x)
+         return x
+
+ class V1_DDitFinalLayer(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.norm_final = nn.LayerNorm(config.n_embd, bias=config.bias)
+         self.linear = nn.Linear(config.n_embd, config.vocab_size)
+         self.linear.weight.data.zero_()
+         self.linear.bias.data.zero_()
+         self.adaLN_modulation = nn.Linear(config.cond_dim, 2 * config.n_embd)
+         self.adaLN_modulation.weight.data.zero_()
+         self.adaLN_modulation.bias.data.zero_()
+     def forward(self, x, c):
+         shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
+         x = v1_modulate(self.norm_final(x), shift, scale)
+         x = self.linear(x)
+         return x
+
+ class V1_TimestepEmbedder(nn.Module):
+     def __init__(self, hidden_size, frequency_embedding_size=256):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+             nn.SiLU(),
+             nn.Linear(hidden_size, hidden_size, bias=True),
+         )
+         self.frequency_embedding_size = frequency_embedding_size
+     @staticmethod
+     def timestep_embedding(t, dim, max_period=10000):
+         half = dim // 2
+         freqs = torch.exp(
+             -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+         ).to(device=t.device)
+         args = t[:, None].float() * freqs[None]
+         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+         if dim % 2:
+             embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+         return embedding
+     def forward(self, t):
+         t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+         t_emb = self.mlp(t_freq)
+         return t_emb
+
+ class V1_GPT(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         assert config.vocab_size is not None
+         assert config.block_size is not None
+         self.config = config
+         self.sigma_map = V1_TimestepEmbedder(config.cond_dim)
+         self.transformer = nn.ModuleDict(dict(
+             wte = nn.Embedding(config.vocab_size, config.n_embd),
+             wpe = nn.Embedding(config.block_size, config.n_embd),
+             drop = nn.Dropout(config.dropout),
+             h = nn.ModuleList([V1_DDiTBlock(config) for _ in range(config.n_layer)]),
+             ln_f = nn.LayerNorm(config.n_embd, bias=config.bias),
+         ))
+         self.lm_head = V1_DDitFinalLayer(config)
+         self.apply(self._init_weights)
+         for pn, p in self.named_parameters():
+             if pn.endswith('c_proj.weight'):
+                 torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+     def forward(self, idx, sigma):
+         sigma = sigma.reshape(-1)
+         b, t = idx.size()
+         c = F.silu(self.sigma_map(sigma))
+         assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+         pos = torch.arange(0, t, dtype=torch.long, device=idx.device)
+         tok_emb = self.transformer.wte(idx)
+         pos_emb = self.transformer.wpe(pos)
+         x = self.transformer.drop(tok_emb + pos_emb)
+         for block in self.transformer.h:
+             x = block(x, c)
+         x = self.transformer.ln_f(x)
+         x = self.lm_head(x, c)
+         # Zero the entry for each position's current token (SEDD-style scores).
+         x = torch.scatter(x, -1, idx[..., None], torch.zeros_like(x[..., :1]))
+         return x
+
+ class V1_GeometricNoise:
+     def __init__(self, sigma_min=1e-4, sigma_max=20):
+         self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max]).to(device)
+     def rate_noise(self, t):
+         return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t * (self.sigmas[1].log() - self.sigmas[0].log())
+     def total_noise(self, t):
+         return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t
+     def __call__(self, t):
+         return self.total_noise(t), self.rate_noise(t)
+
+ # --- V1 Inference Logic ---
+ def v1_transition(x_t: torch.Tensor, delta_sigma: torch.Tensor) -> torch.Tensor:
+     base_prob = (1 - torch.exp(-delta_sigma[..., None])) / v1_vocab_size
+     trans = torch.ones(*x_t.shape, v1_vocab_size, device=x_t.device) * base_prob
+     trans = trans.scatter(-1, x_t[..., None], torch.zeros_like(trans))
+     diag_fill = 1 - trans.sum(dim=-1, keepdim=True)
+     trans = trans.scatter(-1, x_t[..., None], diag_fill)
+     return trans
+
+ def v1_staggered_score(score, delta_sigma):
+     exp_factor = torch.exp(-delta_sigma)[..., None]
+     correction = ((exp_factor - 1) / (v1_vocab_size * exp_factor)) * score.sum(dim=-1, keepdim=True)
+     return correction + score / exp_factor
+
+ def v1_sample_categorical(probs: torch.Tensor) -> torch.Tensor:
+     eps = 1e-10
+     gumbel_noise = -torch.log(-torch.log(torch.rand_like(probs) + eps) + eps)
+     return torch.argmax(torch.log(probs + eps) + gumbel_noise, dim=-1)
+
+ # --- V1 Model Loading ---
+ print("Initializing V1 Model...")
+ v1_model_args = dict(n_layer=6, n_head=6, n_embd=384, cond_dim=64,
+                      bias=False, vocab_size=v1_vocab_size, block_size=v1_context_length, dropout=0.2)
+ v1_config = V1_GPTConfig(**v1_model_args)
+ v1_model = V1_GPT(v1_config)
+ try:
+     v1_model.load_state_dict(
+         torch.hub.load_state_dict_from_url(
+             'https://huggingface.co/spaces/thejagstudio/diffusion-gpt/resolve/main/final_model.pth?download=true',
+             map_location=device
+         )
+     )
+     v1_model.to(device)
+     v1_model.eval()
+     print("V1 Model loaded successfully.")
+ except Exception as e:
+     print(f"Failed to load V1 model: {e}")
+     v1_model = None
+
+ v1_noise = V1_GeometricNoise(sigma_min=1e-4, sigma_max=20)
+
+
+ def v1_generate_stream(steps, speed):
+     """
+     Generator function for V1 that yields frames directly.
+     Combines generation and replay so the stream can be stopped immediately.
+     """
+     if v1_model is None:
+         yield "Error: V1 Model not loaded"
+         return
+
+     steps = int(steps)
+     speed = float(speed)
+     eps = 1e-5
+
+     # Delay between frames, derived from the speed slider (as in V2):
+     # 0.5 s is the base constant; higher speed scales it down.
+     delay = 0.5 / max(speed, 0.1)
+
+     x = torch.randint(0, v1_vocab_size, (1, v1_context_length), device=device)
+     initial_text = f"--- Initial Random Noise ---\n\n{wrap_text(v1_decode(x[0]))}"
+     yield initial_text
+     time.sleep(delay)
+
+     timesteps = torch.linspace(1, eps, steps + 1, device=device)
+     step_size = (1 - eps) / steps
+
+     with torch.no_grad():
+         for i in range(steps):
+             t = timesteps[i] * torch.ones(x.shape[0], 1, device=device)
+             curr_sigma_bar = v1_noise(t)[0]
+
+             next_sigma_bar = v1_noise(t - step_size)[0]
+             delta_sigma = curr_sigma_bar - next_sigma_bar
+
+             log_score = v1_model(x, curr_sigma_bar)
+             score = torch.exp(log_score)
+
+             stag_score = v1_staggered_score(score, delta_sigma)
+             probs = stag_score * v1_transition(x, delta_sigma)
+             x = v1_sample_categorical(probs)
+
+             progress_text = f"--- Denoising Step {i + 1}/{steps} ---\n\n{wrap_text(v1_decode(x[0]))}"
+             yield progress_text
+
+             # Artificial delay for visualization
+             if speed < 20:
+                 time.sleep(delay)
+
+         t = timesteps[steps] * torch.ones(x.shape[0], 1, device=device)
+         curr_sigma_bar = v1_noise(t)[0]
+         delta_sigma = curr_sigma_bar
+
+         log_score = v1_model(x, curr_sigma_bar)
+         score = torch.exp(log_score)
+         stag_score = v1_staggered_score(score, delta_sigma)
+         probs = stag_score * v1_transition(x, delta_sigma)
+         x = v1_sample_categorical(probs)
+
+     final_text = f"--- Final Denoised Text (Step {steps}) ---\n\n{wrap_text(v1_decode(x[0]))}"
+     yield final_text
+
+ # ==============================================================================
+ # ---------------------- VERSION 2: ARCHITECTURE & LOGIC ----------------------
+ # ==============================================================================
+
+ # PLEASE UPDATE THIS PATH TO YOUR ACTUAL LOCAL FILE OR URL
+ V2_MODEL_PATH = "checkpoints/model_fp32.pt"
+
+ class V2_RMSNorm(nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-6):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def forward(self, x):
+         var = x.pow(2).mean(-1, keepdim=True)
+         x = x * torch.rsqrt(var + self.eps)
+         return self.weight * x
+
+ class V2_RotaryEmbedding(nn.Module):
+     def __init__(self, dim, max_position_embeddings=16384, base=100000, scaling_factor=1.0):
+         super().__init__()
+         self.scaling_factor = scaling_factor
+         self.dim = dim
+         self.base = base
+         self.max_position_embeddings = max_position_embeddings
+         self.inv_freq = None
+         self._cache = {}
+
+     def _update_freqs(self, device):
+         # NTK-aware base scaling (a no-op when scaling_factor == 1.0).
+         base = self.base * (self.scaling_factor ** (self.dim / (self.dim - 2)))
+         inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+         self.inv_freq = inv_freq
+
+     def forward(self, x, seq_len=None):
+         if seq_len is None:
+             seq_len = x.shape[-2]
+
+         if self.inv_freq is None or self.inv_freq.device != x.device:
+             self._update_freqs(x.device)
+
+         cache_key = (seq_len, x.device, x.dtype)
+         if cache_key in self._cache:
+             return self._cache[cache_key]
+
+         t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
+         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+         emb = torch.cat((freqs, freqs), dim=-1)
+
+         cos = emb.cos()[None, None, :, :]
+         sin = emb.sin()[None, None, :, :]
+
+         self._cache[cache_key] = (cos, sin)
+         if len(self._cache) > 10:
+             self._cache.pop(next(iter(self._cache)))
+
+         return cos, sin
+
+ def v2_apply_rotary_pos_emb(q, k, cos, sin):
+     def rotate_half(x):
+         x1 = x[..., : x.shape[-1] // 2]
+         x2 = x[..., x.shape[-1] // 2 :]
+         return torch.cat((-x2, x1), dim=-1)
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+ class V2_DiffusionAttention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = self.hidden_size // self.num_heads
+         self.num_key_value_heads = config.num_key_value_heads
+         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+         self.use_flash_attn = config.use_flash_attn
+
+         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+     def forward(self, hidden_states, freqs_cis, attention_mask=None, past_kv=None):
+         bsz, q_len, _ = hidden_states.size()
+
+         q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         k = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+         v = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+         cos, sin = freqs_cis
+         cos = cos[:, :, :q_len, :]
+         sin = sin[:, :, :q_len, :]
+         q, k = v2_apply_rotary_pos_emb(q, k, cos, sin)
+
+         if past_kv is not None:
+             cache_k, cache_v = past_kv
+             k = torch.cat([cache_k, k], dim=2)
+             v = torch.cat([cache_v, v], dim=2)
+
+         current_kv = (k, v)
+         # Grouped-query attention: expand KV heads to match the query heads.
+         k = k.repeat_interleave(self.num_key_value_groups, dim=1)
+         v = v.repeat_interleave(self.num_key_value_groups, dim=1)
+
+         attn_mask = None
+         if attention_mask is not None:
+             attn_mask = attention_mask[:, None, None, :].to(dtype=q.dtype)
+             attn_mask = (1.0 - attn_mask) * torch.finfo(q.dtype).min
+
+         output = F.scaled_dot_product_attention(
+             q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
+         )
+
+         output = output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)
+         return self.o_proj(output), current_kv
+
+ class V2_MLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+         self.act_fn = nn.SiLU()
+
+     def forward(self, x):
+         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+ class V2_BlockDiffusionBlock(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.self_attn = V2_DiffusionAttention(config)
+         self.mlp = V2_MLP(config)
+         self.input_layernorm = V2_RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = V2_RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.use_activation_checkpointing = config.use_activation_checkpointing
+
+     def forward(self, hidden_states, freqs_cis, attention_mask, past_kv):
+         return self._forward(hidden_states, freqs_cis, attention_mask, past_kv)
+
+     def _forward(self, hidden_states, freqs_cis, attention_mask, past_kv):
+         residual = hidden_states
+         hidden_states = self.input_layernorm(hidden_states)
+         attn_out, new_kv = self.self_attn(hidden_states, freqs_cis, attention_mask, past_kv)
+         hidden_states = residual + attn_out
+
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = residual + self.mlp(hidden_states)
+         return hidden_states, new_kv
+
+ @dataclass
+ class V2_ModelConfig:
+     vocab_size: int = 151936
+     hidden_size: int = 1024
+     intermediate_size: int = 2816
+     num_hidden_layers: int = 16
+     num_attention_heads: int = 16
+     num_key_value_heads: int = 4
+     max_position_embeddings: int = 16384
+     rms_norm_eps: float = 1e-6
+     rope_theta: float = 100000.0
+     pad_token_id: int = 0
+     mask_token_id: int = 1
+     use_flash_attn: bool = True
+     use_activation_checkpointing: bool = False
+     attention_dropout: float = 0.0
+     hidden_dropout: float = 0.0
+
+ # Alias kept so checkpoints pickled under the original class name can load.
+ ModelConfig = V2_ModelConfig
+
+ class V2_DiffusionLLM(nn.Module):
+     def __init__(self, config: V2_ModelConfig):
+         super().__init__()
+         self.config = config
+         pad_idx = config.pad_token_id if config.pad_token_id < config.vocab_size else None
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=pad_idx)
+
+         self.layers = nn.ModuleList([V2_BlockDiffusionBlock(config) for _ in range(config.num_hidden_layers)])
+         self.norm = V2_RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+         self.rotary_emb = V2_RotaryEmbedding(
+             config.hidden_size // config.num_attention_heads,
+             config.max_position_embeddings
+         )
+         self.lm_head.weight = self.embed_tokens.weight  # weight tying
+
+     def forward(self, input_ids, attention_mask=None, past_key_values=None):
+         bsz, seqlen = input_ids.shape
+         hidden_states = self.embed_tokens(input_ids)
+         freqs_cis = self.rotary_emb(hidden_states, seq_len=seqlen)
+
+         if past_key_values is None:
+             past_key_values = [None] * len(self.layers)
+
+         new_kvs = []
+         for i, layer in enumerate(self.layers):
+             hidden_states, kv = layer(hidden_states, freqs_cis, attention_mask, past_key_values[i])
+             new_kvs.append(kv)
+
+         hidden_states = self.norm(hidden_states)
+         logits = self.lm_head(hidden_states)
+         return logits, new_kvs
+
+ # Alias kept so checkpoints pickled under the original class name can load.
+ DiffusionLLM = V2_DiffusionLLM
+
+ # --- V2 Loading Logic ---
+ print("Initializing V2 components...")
+ v2_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
+ if v2_tokenizer.pad_token is None:
+     v2_tokenizer.pad_token = v2_tokenizer.eos_token
+
+ v2_model = None
+ v2_config = None
+
+ if os.path.exists(V2_MODEL_PATH):
+     print(f"Loading V2 model from {V2_MODEL_PATH}...")
+     try:
+         checkpoint = torch.load(V2_MODEL_PATH, map_location=device, weights_only=False)
+         v2_config = checkpoint['config']
+         v2_model = V2_DiffusionLLM(v2_config)
+         state_dict = checkpoint['model_state']
+         state_dict = {k: v.float() for k, v in state_dict.items()}
+         v2_model.load_state_dict(state_dict)
+         v2_model = v2_model.to(device)
+         v2_model.eval()
+         print("V2 Model loaded.")
+     except Exception as e:
+         print(f"Error loading V2 model: {e}")
+ else:
+     print(f"V2 Model file not found at {V2_MODEL_PATH}. Version 2 tab will not work without it.")
+
+
+ @torch.no_grad()
+ def v2_generate_block_diffusion(prompt, steps, block_size, max_new_tokens, replay_speed):
+     """
+     Refactored to yield frames for real-time streaming.
+     """
+     if v2_model is None:
+         yield "Error: V2 Model not found. Check path."
+         return
+
+     v2_model.eval()
+     # Handle inputs
+     steps = int(steps)
+     block_size = int(block_size)
+     max_new_tokens = int(max_new_tokens)
+     speed = float(replay_speed)
+
+     prompt_ids = v2_tokenizer.encode(prompt, return_tensors="pt").to(device)
+     config = v2_model.config
+     num_blocks = max_new_tokens // block_size
+
+     context_ids = prompt_ids
+
+     # Helper params
+     temperature = 1.0
+     top_k = 40
+     top_p = 0.9
+     repetition_penalty = 1.2
+
+     # Calculate delay based on speed slider
+     delay = 0.5 / max(speed, 0.1)
+
+     for block_idx in range(num_blocks):
+         mask_block = torch.full((1, block_size), config.mask_token_id, device=device)
+         is_masked = torch.ones(1, block_size, dtype=torch.bool, device=device)
+
+         for step_idx in range(steps):
+             # --- SNAPSHOT & YIELD ---
+             # Decode context
+             ctx_str = v2_tokenizer.decode(context_ids[0], skip_special_tokens=True)
+
+             # Decode block with masking visual
+             block_tokens = mask_block[0].tolist()
+             block_vis = []
+             for i, tid in enumerate(block_tokens):
+                 if is_masked[0, i]:
+                     block_vis.append("░")  # Mask symbol
+                 else:
+                     block_vis.append(v2_tokenizer.decode([tid], skip_special_tokens=False))
+
+             block_str = "".join(block_vis)
+
+             frame_text = (f"--- Generating Block {block_idx+1}/{num_blocks} | Step {step_idx+1}/{steps} ---\n\n"
+                           f"{ctx_str}{block_str}")
+
+             yield frame_text
+
+             # Artificial delay to visualize the step
+             if speed < 20:  # If max speed, skip sleep
+                 time.sleep(delay)
+             # ------------------------
+
+             full_input = torch.cat([context_ids, mask_block], dim=1)
+             attention_mask = torch.ones_like(full_input, dtype=torch.float32)
+
+             logits, _ = v2_model(full_input, attention_mask=attention_mask)
+             block_logits = logits[:, -block_size:, :]
+
+             # Repetition penalty
+             if repetition_penalty != 1.0:
+                 seen_tokens = set(context_ids[0].tolist())
+                 for i in range(block_size):
+                     if not is_masked[0, i]:
+                         seen_tokens.add(mask_block[0, i].item())
+                 for token_id in seen_tokens:
+                     if token_id < block_logits.shape[-1]:
+                         if block_logits[0, :, token_id].mean() > 0:
+                             block_logits[:, :, token_id] /= repetition_penalty
+                         else:
+                             block_logits[:, :, token_id] *= repetition_penalty
+
+             block_logits = block_logits / temperature
+
+             # Top-K
+             if top_k > 0:
+                 top_k_logits, top_k_indices = torch.topk(block_logits, top_k, dim=-1)
+                 block_logits = torch.full_like(block_logits, float('-inf'))
+                 block_logits.scatter_(-1, top_k_indices, top_k_logits)
+
+             # Top-P
+             if top_p < 1.0:
+                 sorted_logits, sorted_indices = torch.sort(block_logits, descending=True, dim=-1)
+                 cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+                 sorted_indices_to_remove = cumulative_probs > top_p
+                 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                 sorted_indices_to_remove[..., 0] = 0
+                 indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
+                 block_logits[indices_to_remove] = float('-inf')
+
+             probs = F.softmax(block_logits, dim=-1)
+             probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
+             probs = probs.clamp(min=1e-10)
+             probs = probs / probs.sum(dim=-1, keepdim=True)
+
+             sampled_tokens = torch.multinomial(probs.view(-1, probs.size(-1)), num_samples=1)
+             sampled_tokens = sampled_tokens.view(1, block_size)
+
+             confidence = probs.gather(-1, sampled_tokens.unsqueeze(-1)).squeeze(-1)
+
+             tokens_to_unmask = max(1, block_size // steps)
+             if step_idx == steps - 1:
+                 tokens_to_unmask = is_masked.sum().item()
+
+             if tokens_to_unmask > 0 and is_masked.sum() > 0:
+                 masked_confidence = confidence.clone()
+                 masked_confidence[~is_masked] = -1.0
+                 num_to_unmask = min(tokens_to_unmask, is_masked.sum().item())
+                 _, top_indices = torch.topk(masked_confidence.view(-1), num_to_unmask)
+
+                 for idx in top_indices:
+                     mask_block[0, idx] = sampled_tokens[0, idx]
+                     is_masked[0, idx] = False
+
+         context_ids = torch.cat([context_ids, mask_block], dim=1)
+
+     generated_ids = context_ids[0].tolist()
+     final_text = v2_tokenizer.decode(generated_ids, skip_special_tokens=True)
+     yield final_text
+
+
+ # ==============================================================================
+ # ------------------------------- GRADIO UI -----------------------------------
+ # ==============================================================================
+
+ css = '''.gradio-container > .fillable {max-width: 900px !important}
+ h3{margin-top: 1em}
+ p{margin-top: 0}
+ textarea{font-family: monospace; background-color: #1a1b1e; color: #e0e0e0}
+ '''
+
+ with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
+     gr.Markdown("# Diffusion Language Models Playground")
+
+     with gr.Tabs():
+
+         # --- TAB 1: VERSION 1 (CHAR DIFFUSION) ---
+         with gr.Tab("Version 1: Character Diffusion (NanoGPT)"):
+             gr.Markdown("### A tiny 11M-parameter character-level diffusion model (continuous-time, discrete tokens).")
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     v1_steps = gr.Slider(64, 512, 128, step=1, label="Denoising Steps")
+                     v1_speed = gr.Slider(1, 20, 10, step=1, label="Generation/Replay Speed")
+                     with gr.Row():
+                         v1_btn = gr.Button("Generate", variant="primary")
+                         v1_stop = gr.Button("Stop", variant="stop")
+                 with gr.Column(scale=2):
+                     v1_out = gr.Textbox(label="Generated Text", lines=15, interactive=False)
+
+             # V1 Logic: Merged generation and replay for proper stopping
+             v1_event = v1_btn.click(v1_generate_stream, inputs=[v1_steps, v1_speed], outputs=[v1_out])
+             v1_stop.click(fn=None, inputs=None, outputs=None, cancels=[v1_event])
+
+         # --- TAB 2: VERSION 2 (BLOCK DIFFUSION) ---
+         with gr.Tab("Version 2: Block Diffusion (LLaDA-style)"):
+             gr.Markdown("### Block-based diffusion using the Qwen tokenizer.")
+             if v2_model is None:
+                 gr.Warning(f"V2 Model not loaded. Please check path: {V2_MODEL_PATH}")
+
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     v2_prompt = gr.Textbox(label="Prompt", value="The king went to the")
+                     v2_steps = gr.Slider(4, 64, 16, step=1, label="Steps per Block")
+                     v2_block_size = gr.Slider(8, 126, 32, step=8, label="Block Size")
+                     v2_max_tokens = gr.Slider(32, 1024, 128, step=32, label="Total New Tokens")
+                     v2_speed = gr.Slider(1, 20, 1, step=1, label="Generation/Replay Speed")
+                     with gr.Row():
+                         v2_btn = gr.Button("Generate", variant="primary")
+                         v2_stop = gr.Button("Stop", variant="stop")
+                 with gr.Column(scale=2):
+                     v2_out = gr.Textbox(label="Generated Text", lines=15, interactive=False)
+
+             # V2 Logic
+             v2_event = v2_btn.click(
+                 v2_generate_block_diffusion,
+                 inputs=[v2_prompt, v2_steps, v2_block_size, v2_max_tokens, v2_speed],
+                 outputs=[v2_out]
+             )
+             v2_stop.click(fn=None, inputs=None, outputs=None, cancels=[v2_event])
+
+ if __name__ == "__main__":
+     demo.launch()
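
For the V2 path, the core idea is confidence-based unmasking: each step samples every masked position in the current block, then commits only the highest-confidence samples. A toy sketch of just that rule (the random `probs` tensor stands in for the model's softmax output; `mask_id` mirrors `config.mask_token_id`):

```python
import torch

block_size, vocab_size, steps = 8, 100, 4
mask_id = 1                                   # stand-in for config.mask_token_id
mask_block = torch.full((1, block_size), mask_id)
is_masked = torch.ones(1, block_size, dtype=torch.bool)

for step in range(steps):
    # Stand-in for the model: random per-position distributions over the vocab.
    probs = torch.rand(1, block_size, vocab_size).softmax(dim=-1)
    sampled = torch.multinomial(probs.view(-1, vocab_size), 1).view(1, block_size)
    confidence = probs.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)

    # Commit the highest-confidence masked positions; finish on the last step.
    n = max(1, block_size // steps) if step < steps - 1 else int(is_masked.sum())
    conf = confidence.clone()
    conf[~is_masked] = -1.0                   # never re-commit unmasked slots
    _, top = torch.topk(conf.view(-1), min(n, int(is_masked.sum())))
    mask_block[0, top] = sampled[0, top]
    is_masked[0, top] = False

print(mask_block.tolist(), bool(is_masked.any()))  # fully unmasked: False
```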
final_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:86da61c7e77f7062e51ec2974a723d6f195581e385ed802373d178ade0c81483
+ size 47047458
meta.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f221fcf4332ffda72912fad6e4b64e7988d6c4cba7ffcee98edbbc23f9a8400d
+ size 913
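
meta.pkl is the nanoGPT shakespeare_char vocabulary file; both apps read `vocab_size`, `stoi`, and `itos` from it. A quick way to inspect it locally, assuming it has been downloaded as the apps do at startup:

```python
# Inspect the character vocabulary V1 decodes with (sketch; assumes meta.pkl
# is present locally, as after the apps' download_file step).
import pickle

with open("meta.pkl", "rb") as f:
    meta = pickle.load(f)

print(meta["vocab_size"])          # 65 for the Shakespeare character set
print(len(meta["itos"]), len(meta["stoi"]))  # index<->character lookup tables
```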
model_epoch_1.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:413c52559c94e2e71d991e45147c8773e02d007b095443ff2ffaed5174e147a7
+ size 47038242
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ gradio
+ torch>=2.0.0
+ numpy
+ requests
+ transformers
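
Note that both apps also import `spaces`, which Hugging Face Spaces preinstalls but this file does not list, so a local run likely needs it installed explicitly. A hedged local-bootstrap sketch:

```python
# Hypothetical local bootstrap (not part of this commit): install the listed
# requirements plus `spaces`, then launch the Gradio app.
import subprocess
import sys

subprocess.run([sys.executable, "-m", "pip", "install",
                "-r", "requirements.txt", "spaces"], check=True)
subprocess.run([sys.executable, "app.py"], check=True)
```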