import os
import sys
import json
import time
import importlib.util
from pathlib import Path

from flask import Flask, request, jsonify, Response, stream_with_context
from flask_cors import CORS
import torch
from transformers import AutoTokenizer

app = Flask(__name__, static_folder='static', static_url_path='/static')
CORS(app)

# Global state
model = None
tokenizer = None
config = None
device = None
DiffusionLLM = None
chat_function = None


def find_file(filename, search_dirs=None):
    """Find a file in current directory or parent directories."""
    if search_dirs is None:
        search_dirs = [
            os.path.dirname(__file__),                   # Current directory
            os.path.dirname(os.path.dirname(__file__)),  # Parent directory
            os.getcwd(),                                 # Working directory
        ]

    for directory in search_dirs:
        filepath = os.path.join(directory, filename)
        if os.path.exists(filepath):
            print(f"Found {filename} at: {filepath}")
            return filepath
    return None


def try_import_module(filepath, module_name):
    """Dynamically import a Python file as a module."""
    if not filepath or not os.path.exists(filepath):
        return None
    try:
        # Add the directory to sys.path
        module_dir = os.path.dirname(filepath)
        if module_dir not in sys.path:
            sys.path.insert(0, module_dir)

        spec = importlib.util.spec_from_file_location(module_name, filepath)
        if spec is None:
            print(f"Could not create spec for {filepath}")
            return None

        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
        print(f"Successfully imported {module_name} from {filepath}")
        return module
    except Exception as e:
        print(f"Error importing {filepath}: {e}")
        import traceback
        traceback.print_exc()
        return None


def load_model_internal():
    """Load the model and tokenizer."""
    global model, tokenizer, config, device, DiffusionLLM, chat_function

    if model is not None:
        return True

    try:
        print("=" * 60)
        print("Starting model loading process...")
        print("=" * 60)

        # Find and import infer-base.py
        base_path = find_file("infer-base.py")
        if base_path is None:
            raise RuntimeError(
                "Could not find infer-base.py. "
                "Make sure it's in the same directory as app.py or its parent directory."
            )

        print(f"\nImporting infer-base.py from: {base_path}")
        base_mod = try_import_module(base_path, "infer_base")
        if base_mod is None:
            raise RuntimeError("Failed to import infer-base.py")

        # Check for DiffusionLLM class
        if not hasattr(base_mod, 'DiffusionLLM'):
            print("Available attributes in infer_base:", dir(base_mod))
            raise RuntimeError("DiffusionLLM class not found in infer-base.py")

        DiffusionLLM = base_mod.DiffusionLLM
        print("✓ Successfully loaded DiffusionLLM class")

        # Find and import infer-chat.py
        chat_path = find_file("infer-chat.py")
        if chat_path is None:
            raise RuntimeError("Could not find infer-chat.py")

        print(f"\nImporting infer-chat.py from: {chat_path}")
        chat_mod = try_import_module(chat_path, "infer_chat")
        if chat_mod is None or not hasattr(chat_mod, 'chat'):
            raise RuntimeError("Failed to import chat function from infer-chat.py")

        chat_function = chat_mod.chat
        print("✓ Successfully loaded chat function")

        # Setup pickling workaround for torch.load
        try:
            if hasattr(base_mod, 'ModelConfig'):
                sys.modules['__main__'].ModelConfig = base_mod.ModelConfig
            sys.modules['__main__'].DiffusionLLM = DiffusionLLM
            print("✓ Configured pickle support for model loading")
        except Exception as e:
            print(f"Warning: Could not setup pickle workaround: {e}")

        # Set device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"\n✓ Using device: {device}")

        # Load tokenizer
        print("\nLoading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        print("✓ Tokenizer loaded")

        # Find model checkpoint
        checkpoint_dirs = [
            "checkpoints",
            "../checkpoints",
            "./checkpoints",
            os.path.join(os.path.dirname(__file__), "checkpoints"),
            os.path.join(os.path.dirname(__file__), "../checkpoints"),
        ]

        model_path = None
        for checkpoint_dir in checkpoint_dirs:
            best_path = os.path.join(checkpoint_dir, "best_model.pt")
            fp32_path = os.path.join(checkpoint_dir, "model_fp32.pt")
            if os.path.exists(best_path):
                model_path = best_path
                break
            elif os.path.exists(fp32_path):
                model_path = fp32_path
                break

        if model_path is None:
            raise RuntimeError(
                "Could not find model checkpoint. "
                "Looking for:\n"
                "  - checkpoints/best_model.pt\n"
                "  - checkpoints/model_fp32.pt\n"
                f"Searched directories: {checkpoint_dirs}"
            )

        print(f"\n✓ Found model checkpoint: {model_path}")
        print("Loading model weights (this may take a minute)...")

        # Load model
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
        config = checkpoint['config']

        print("Creating model...")
        model = DiffusionLLM(config)

        print("Loading state dict...")
        state_dict = checkpoint['model_state']
        state_dict = {k: v.float() for k, v in state_dict.items()}
        model.load_state_dict(state_dict)
        model = model.to(device)
        model.eval()

        num_params = sum(p.numel() for p in model.parameters()) / 1e6
        print(f"\n{'=' * 60}")
        print("✓✓✓ MODEL LOADED SUCCESSFULLY ✓✓✓")
        print(f"{'=' * 60}")
        print(f"Parameters: {num_params:.1f}M")
        if 'step' in checkpoint:
            print(f"Training steps: {checkpoint['step']}")
        if 'best_val_loss' in checkpoint:
            print(f"Best validation loss: {checkpoint['best_val_loss']:.4f}")
        print(f"{'=' * 60}\n")

        return True

    except Exception as e:
        print("\n" + "=" * 60)
        print("ERROR LOADING MODEL")
        print("=" * 60)
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        print("=" * 60 + "\n")
        return False


def create_streaming_visualizer():
    """Create a visualizer that yields SSE events instead of printing to terminal."""

    def visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
        # Normalize inputs to lists
        if not isinstance(mask_blocks, list):
            mask_blocks = [mask_blocks]
            is_masked_list = [is_masked_list]

        # Decode context
        try:
            context_text = tok.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ')
        except Exception:
            context_text = str(context_ids[0].tolist())

        # Build blocks visualization
        all_blocks = []
        for block_idx, (mask_block, is_masked) in enumerate(zip(mask_blocks, is_masked_list)):
            block_tokens = mask_block[0].tolist()
            block_data = []
            for i, token_id in enumerate(block_tokens):
                if is_masked[0, i]:
                    block_data.append({
                        'type': 'masked',
                        'text': '███'
                    })
                else:
                    try:
                        token_text = tok.decode([token_id], skip_special_tokens=False)
                    except Exception:
                        token_text = str(int(token_id))
                    block_data.append({
                        'type': 'revealed',
                        'text': token_text
                    })
            all_blocks.append({
                'block_index': block_idx,
                'tokens': block_data
            })

        # Return data structure that will be sent as SSE
        return {
            'context': context_text,
            'blocks': all_blocks,
            'num_blocks': len(mask_blocks)
        }

    return visualizer


@app.route('/')
def index():
    """Serve the main HTML page."""
    return app.send_static_file('index.html')


@app.route('/api/load', methods=['POST'])
def load_model_endpoint():
    """Load the model."""
    global model

    data = request.json or {}
    check_only = data.get('check_only', False)

    if check_only:
        return jsonify({
            'loaded': model is not None,
            'message': 'Model is loaded' if model is not None else 'Model not loaded'
        })

    if model is not None:
        return jsonify({
            'loaded': True,
            'message': 'Model already loaded'
        })

    success = load_model_internal()
    if success:
        return jsonify({
            'loaded': True,
            'message': 'Model loaded successfully'
        })
    else:
        return jsonify({
            'loaded': False,
            'message': 'Failed to load model. Check server logs for details.'
        }), 500
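
# Both generation endpoints below read the same JSON body; the values shown are
# the defaults applied when a field is omitted (the instruction text here is
# illustrative only):
#
#   {
#     "instruction": "Explain diffusion language models",   # required, non-empty
#     "steps": 64,
#     "block_size": 128,
#     "max_new_tokens": 128,
#     "parallel_blocks": 1
#   }
#
# /api/generate returns the finished text in a single JSON response, while
# /api/generate-stream emits Server-Sent Events of type 'start', 'update'
# (one per visualizer callback), and finally 'complete' or 'error'.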


@app.route('/api/generate', methods=['POST'])
def generate():
    """Generate response without streaming."""
    global model, tokenizer, config, device, chat_function

    if model is None:
        return jsonify({'error': 'Model not loaded'}), 400
    if chat_function is None:
        return jsonify({'error': 'Chat function not available'}), 400

    data = request.json
    instruction = data.get('instruction', '')
    steps = data.get('steps', 64)
    block_size = data.get('block_size', 128)
    max_new_tokens = data.get('max_new_tokens', 128)
    parallel_blocks = data.get('parallel_blocks', 1)

    if not instruction:
        return jsonify({'error': 'No instruction provided'}), 400

    try:
        # Generate response
        raw_output, response = chat_function(
            model, tokenizer, instruction,
            steps=steps,
            block_size=block_size,
            max_new_tokens=max_new_tokens,
            temperature=0.8,
            top_k=50,
            top_p=0.9,
            repetition_penalty=1.2,
            no_repeat_ngram_size=3,
            verbose=False,
            visualize_fn=None,
            parallel_blocks=parallel_blocks,
        )
        return jsonify({
            'response': response,
            'raw_output': raw_output
        })
    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500


@app.route('/api/generate-stream', methods=['POST'])
def generate_stream():
    """Generate response with streaming visualization."""
    global model, tokenizer, config, device, chat_function

    if model is None:
        return jsonify({'error': 'Model not loaded'}), 400
    if chat_function is None:
        return jsonify({'error': 'Chat function not available'}), 400

    data = request.json
    instruction = data.get('instruction', '')
    steps = data.get('steps', 64)
    block_size = data.get('block_size', 128)
    max_new_tokens = data.get('max_new_tokens', 128)
    parallel_blocks = data.get('parallel_blocks', 1)

    if not instruction:
        return jsonify({'error': 'No instruction provided'}), 400

    def generate_events():
        try:
            # Use a queue so the visualizer callback can hand events to this generator
            import queue
            event_queue = queue.Queue()
            generation_complete = {'done': False, 'result': None}

            def streaming_visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
                """This gets called during generation - we need to send events immediately"""
                visualizer = create_streaming_visualizer()
                data = visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear)
                # Put the update in the queue so it can be yielded immediately
                event_queue.put({'type': 'update', 'data': data})

            # Start generation in a separate thread so we can yield events as they come
            import threading

            def run_generation():
                try:
                    raw_output, response = chat_function(
                        model, tokenizer, instruction,
                        steps=steps,
                        block_size=block_size,
                        max_new_tokens=max_new_tokens,
                        temperature=0.8,
                        top_k=50,
                        top_p=0.9,
                        repetition_penalty=1.2,
                        no_repeat_ngram_size=3,
                        verbose=False,
                        visualize_fn=streaming_visualizer,
                        parallel_blocks=parallel_blocks,
                    )
                    generation_complete['result'] = (raw_output, response)
                except Exception as e:
                    generation_complete['result'] = ('error', str(e))
                finally:
                    generation_complete['done'] = True
                    event_queue.put(None)  # Signal completion

            # Start generation thread
            gen_thread = threading.Thread(target=run_generation)
            gen_thread.daemon = True
            gen_thread.start()

            # Yield start event
            yield f"data: {json.dumps({'type': 'start', 'message': 'Generation started'})}\n\n"

            # Yield events as they come from the queue
            while not generation_complete['done'] or not event_queue.empty():
                try:
                    event = event_queue.get(timeout=0.1)
                    if event is None:  # Completion signal
                        break
                    yield f"data: {json.dumps(event)}\n\n"
                except queue.Empty:
                    continue

            # Wait for thread to finish
            gen_thread.join(timeout=1.0)

            # Send final response
            if generation_complete['result']:
                raw_output, response = generation_complete['result']
                if raw_output == 'error':
                    yield f"data: {json.dumps({'type': 'error', 'error': response})}\n\n"
                else:
                    yield f"data: {json.dumps({'type': 'complete', 'response': response, 'raw_output': raw_output})}\n\n"

        except Exception as e:
            import traceback
            traceback.print_exc()
            yield f"data: {json.dumps({'type': 'error', 'error': str(e)})}\n\n"

    return Response(
        stream_with_context(generate_events()),
        mimetype='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'X-Accel-Buffering': 'no'
        }
    )


@app.route('/api/test-stream', methods=['GET'])
def test_stream():
    """Test streaming endpoint."""
    def generate():
        for i in range(10):
            yield f"data: {json.dumps({'message': f'Test message {i+1}'})}\n\n"
            time.sleep(0.5)
        yield f"data: {json.dumps({'message': 'Stream complete'})}\n\n"

    return Response(
        stream_with_context(generate()),
        mimetype='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'X-Accel-Buffering': 'no'
        }
    )


if __name__ == '__main__':
    app.run(debug=True, host='0.0.0.0', port=7860, threaded=True)
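
# ---------------------------------------------------------------------------
# Minimal client sketch (not part of the server). It assumes the optional
# `requests` package and the default host/port configured above; run it from a
# separate process while app.py is serving, adapting the payload as needed:
#
#   import json, requests
#
#   requests.post("http://localhost:7860/api/load", json={})
#   with requests.post("http://localhost:7860/api/generate-stream",
#                      json={"instruction": "Hello", "max_new_tokens": 64},
#                      stream=True) as r:
#       for line in r.iter_lines():
#           if line.startswith(b"data: "):
#               event = json.loads(line[len(b"data: "):])
#               print(event.get("type"))
# ---------------------------------------------------------------------------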