"""Gradio front-end for training a small GPT-style language model on text files.

The module defines a word-level tokenizer, a dataset wrapper, a compact
causal transformer and an application object that wires them to a Gradio UI.
"""

import glob
import json
import os
import pickle
import struct
import threading
import traceback
from collections import Counter
from datetime import datetime
from pathlib import Path

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader


class SimpleTokenizer:
    """Whitespace word-level tokenizer with three reserved special tokens.

    IDs 0, 1 and 2 are always the padding, end-of-sequence and unknown
    tokens; every other token receives the next free ID when it is added.
    """

    def __init__(self):
        self.vocab = {}
        self.inverse_vocab = {}
        self.vocab_size = 0

        # The special tokens must be distinct strings: if they were equal,
        # only the first would be registered and real words would be handed
        # the IDs reserved for the others.
        self.pad_token = "<pad>"
        self.pad_token_id = 0
        self.eos_token = "<eos>"
        self.eos_token_id = 1
        self.unk_token = "<unk>"
        self.unk_token_id = 2

        for special in (self.pad_token, self.eos_token, self.unk_token):
            self.add_token(special)

    def add_token(self, token):
        """Register ``token``; return True if it was new, False otherwise."""
        if token in self.vocab:
            return False
        self.vocab[token] = self.vocab_size
        self.inverse_vocab[self.vocab_size] = token
        self.vocab_size += 1
        return True

    def build_vocab_from_texts(self, texts, max_vocab_size=10000):
        """Add the most frequent whitespace tokens of ``texts`` to the vocabulary.

        Tokens are added in descending frequency until the vocabulary holds
        ``max_vocab_size`` entries. Tokens already present do not consume
        any of the remaining budget, so the method can be called repeatedly
        as more data is loaded.
        """
        print("Building vocabulary from training data...")

        token_counter = Counter()
        for text in texts:
            token_counter.update(text.split())

        for token, _ in token_counter.most_common():
            if self.vocab_size >= max_vocab_size:
                break
            self.add_token(token)

        print(f"Vocabulary built with {self.vocab_size} tokens")

    def tokenize(self, text):
        """Split ``text`` on whitespace and map each word to its ID.

        Words that are not in the vocabulary map to ``unk_token_id``.
        """
        unk = self.unk_token_id
        return [self.vocab.get(word, unk) for word in text.split()]

    def encode(self, text, max_length=None, padding=False, truncation=False):
        """Tokenize ``text``, optionally truncating and/or right-padding.

        Truncation and padding only take effect when ``max_length`` is a
        positive number, matching the behaviour callers already rely on.
        """
        token_ids = self.tokenize(text)
        if not max_length:
            return token_ids

        if truncation and len(token_ids) > max_length:
            token_ids = token_ids[:max_length]
        if padding and len(token_ids) < max_length:
            token_ids = token_ids + [self.pad_token_id] * (max_length - len(token_ids))
        return token_ids

    def decode(self, token_ids):
        """Join the tokens for ``token_ids`` with spaces, skipping padding."""
        words = [
            self.inverse_vocab.get(token_id, self.unk_token)
            for token_id in token_ids
            if token_id != self.pad_token_id
        ]
        return " ".join(words)

    def to_dict(self):
        """Return a plain-dict snapshot of the vocabulary for serialization."""
        return {"vocab": dict(self.vocab)}

    @classmethod
    def from_dict(cls, data):
        """Rebuild a tokenizer from the mapping produced by :meth:`to_dict`."""
        tokenizer = cls()
        tokenizer.vocab = dict(data["vocab"])
        tokenizer.inverse_vocab = {idx: tok for tok, idx in tokenizer.vocab.items()}
        tokenizer.vocab_size = max(tokenizer.vocab.values(), default=-1) + 1
        return tokenizer


class TextDataset(Dataset):
    """Dataset yielding fixed-length token tensors for each non-blank text."""

    def __init__(self, texts, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Blank texts would produce all-padding rows, so they are dropped.
        self.texts = [text for text in texts if text.strip()]

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        token_ids = self.tokenizer.encode(
            self.texts[idx],
            max_length=self.max_length,
            padding=True,
            truncation=True,
        )
        ids = torch.tensor(token_ids, dtype=torch.long)
        # Labels equal the inputs; the model shifts them by one internally.
        return {'input_ids': ids, 'labels': ids.clone()}


class SimpleGPT(nn.Module):
    """A small decoder-only (causal) transformer language model.

    Token ID 0 is treated as padding throughout: its embedding is fixed at
    zero, it is masked out of attention and it is ignored by the loss.
    """

    def __init__(self, vocab_size, d_model=512, n_layers=6, n_heads=8, max_seq_len=512):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len

        self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            batch_first=True,
            dropout=0.1,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        self.dropout = nn.Dropout(0.1)
        self.output_layer = nn.Linear(d_model, vocab_size)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            # Re-initialising overwrote the zero row that padding_idx sets up;
            # restore it so padding contributes nothing to the embeddings.
            if module.padding_idx is not None:
                with torch.no_grad():
                    module.weight[module.padding_idx].zero_()

    def forward(self, input_ids, labels=None):
        """Run the model and optionally compute the next-token loss.

        Args:
            input_ids: integer tensor indexed as (batch, sequence).
            labels: optional tensor of the same shape as ``input_ids``.
                Pass the inputs themselves; the shift by one position for
                next-token prediction is done here.

        Returns:
            Dict with ``'logits'`` (one row of vocabulary scores per input
            position) and ``'loss'`` (mean cross-entropy over non-padding
            targets, or ``None`` when ``labels`` is not given).

        Raises:
            ValueError: if the sequence is longer than ``max_seq_len``.
        """
        batch_size, seq_len = input_ids.shape
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}"
            )

        # Defensive guard: out-of-range IDs would crash the embedding lookup.
        input_ids = torch.clamp(input_ids, 0, self.vocab_size - 1)

        positions = torch.arange(seq_len, device=input_ids.device)
        positions = positions.unsqueeze(0).expand(batch_size, seq_len)
        x = self.token_embedding(input_ids) + self.position_embedding(positions)

        # True marks positions that must NOT be attended to.
        padding_mask = input_ids == 0
        # Causal mask: position i may only attend to positions <= i. Without
        # it the model can read the token it is asked to predict.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=input_ids.device),
            diagonal=1,
        )
        x = self.transformer(x, mask=causal_mask, src_key_padding_mask=padding_mask)

        x = self.dropout(x)
        logits = self.output_layer(x)

        loss = None
        if labels is not None:
            labels = torch.clamp(labels, 0, self.vocab_size - 1)
            # Position t predicts token t + 1.
            shift_logits = logits[:, :-1, :].reshape(-1, self.vocab_size)
            shift_labels = labels[:, 1:].reshape(-1)
            loss_fn = nn.CrossEntropyLoss(ignore_index=0, reduction='sum')
            summed = loss_fn(shift_logits, shift_labels)
            # clamp avoids 0/0 when a batch holds no non-padding targets.
            target_count = (shift_labels != 0).sum().clamp(min=1)
            loss = summed / target_count

        return {'logits': logits, 'loss': loss}


class AITrainerApp:
    """Application state and callbacks behind the Gradio training UI."""

    def __init__(self):
        self.tokenizer = SimpleTokenizer()
        self.model = None
        self.training_data = []

        # Default model configuration
        self.model_config = {
            "d_model": 512,
            "n_layers": 6,
            "n_heads": 8,
            "max_seq_len": 512
        }

        # Training control
        self.training_thread = None
        self.stop_training_flag = False
        self.training_status = "Ready - Load training data to begin"
        self.output_log = "Training output will appear here...\n"
        # The log is appended to from both the UI and the training thread.
        self._log_lock = threading.Lock()

    def is_training(self):
        """Return True while the background training thread is alive."""
        thread = self.training_thread
        return thread is not None and thread.is_alive()

    def get_device(self, device_type="auto"):
        """Return the torch device for ``device_type``.

        ``"auto"`` and ``"cuda"`` both resolve to CUDA when it is available
        and fall back to the CPU otherwise; anything else selects the CPU.
        """
        if device_type in ("auto", "cuda") and torch.cuda.is_available():
            return torch.device('cuda')
        return torch.device('cpu')

    def log_output(self, message):
        """Append ``message`` to the output log and return the full log."""
        with self._log_lock:
            self.output_log += message + "\n"
            return self.output_log

    def refresh_status(self):
        """Return the current status, log and start-button state for the UI."""
        return (
            self.training_status,
            self.output_log,
            gr.update(interactive=not self.is_training()),
        )

    def verify_model_file(self, file_path):
        """Cheap sanity check on a checkpoint path.

        Returns:
            ``(ok, message)`` where ``ok`` is False if the file is missing,
            smaller than 1 KB, or could not be inspected.
        """
        try:
            if not os.path.exists(file_path):
                return False, "File does not exist"
            if os.path.getsize(file_path) < 1024:  # Less than 1KB
                return False, "File is too small to be a valid model"
            return True, "File appears valid"
        except Exception as e:
            return False, f"Error verifying file: {str(e)}"

    @staticmethod
    def _uploaded_name(file_info):
        """Best-effort display name for an uploaded file (path or file object)."""
        if isinstance(file_info, (str, os.PathLike)):
            return os.fspath(file_info)
        return str(getattr(file_info, "name", file_info))

    @staticmethod
    def _read_uploaded_text(file_info):
        """Return the UTF-8 text of an upload.

        Accepts a path, an object whose ``name`` attribute is a path on
        disk, or a readable file-like object returning bytes or str.
        """
        if isinstance(file_info, (str, os.PathLike)):
            path = os.fspath(file_info)
        else:
            path = getattr(file_info, "name", None)

        if isinstance(path, str) and os.path.isfile(path):
            with open(path, "r", encoding="utf-8") as handle:
                return handle.read()

        data = file_info.read()
        return data.decode('utf-8') if isinstance(data, bytes) else data

    def load_training_files(self, files):
        """Read uploaded files, chunk them and extend the training data.

        Each file is split into chunks of ``max_seq_len`` words so that a
        chunk fits into one training sequence without being truncated.

        Returns:
            ``(status_message, output_log)``. On the first unreadable file
            an error message is returned and nothing is added.
        """
        if not files:
            return "No files selected", self.output_log
        if not isinstance(files, (list, tuple)):
            files = [files]

        chunk_size = self.model_config["max_seq_len"]
        total_texts = []
        for file_info in files:
            name = self._uploaded_name(file_info)
            try:
                content = self._read_uploaded_text(file_info)
            except Exception as e:
                error_msg = f"Error reading {name}: {str(e)}"
                self.log_output(error_msg)
                return error_msg, self.output_log

            chunks = self.split_into_chunks(content, chunk_size)
            total_texts.extend(chunks)
            self.log_output(f"Loaded {len(chunks)} chunks from {name}")

        self.training_data.extend(total_texts)

        # Build vocabulary from all training texts
        self.tokenizer.build_vocab_from_texts(self.training_data, max_vocab_size=10000)

        status_msg = f"Loaded {len(total_texts)} text chunks from {len(files)} files"
        self.log_output(status_msg)
        self.log_output(f"Vocabulary size: {self.tokenizer.vocab_size}")
        return status_msg, self.output_log

    def split_into_chunks(self, text, chunk_size):
        """Split ``text`` into space-joined chunks of ``chunk_size`` words."""
        words = text.split()
        return [
            ' '.join(words[start:start + chunk_size])
            for start in range(0, len(words), chunk_size)
        ]

    def view_training_data(self):
        """Return a text preview of the first 50 loaded chunks."""
        if not self.training_data:
            return "No training data loaded"

        return "".join(
            f"Chunk {i+1}:\n{text}\n\n{'='*50}\n\n"
            for i, text in enumerate(self.training_data[:50])
        )

    def start_training(self, d_model, n_layers, n_heads, batch_size, learning_rate, epochs, device_type):
        """Validate the UI settings and launch training in a daemon thread.

        Returns:
            ``(status, output_log, start_button_update)``. The start button
            is disabled only when a training thread is (or becomes) active.
        """
        if self.is_training():
            msg = "Training is already running"
            self.log_output(msg)
            return msg, self.output_log, gr.update(interactive=False)

        if not self.training_data:
            error_msg = "Error: No training data loaded!"
            self.log_output(error_msg)
            return error_msg, self.output_log, gr.update(interactive=True)

        try:
            d_model, n_layers, n_heads = int(d_model), int(n_layers), int(n_heads)
            batch_size, epochs = int(batch_size), int(epochs)
            learning_rate = float(learning_rate)
            if min(d_model, n_layers, n_heads, batch_size, epochs) <= 0 or learning_rate <= 0:
                raise ValueError("all numeric settings must be positive")
            if d_model % n_heads != 0:
                raise ValueError("Embedding Size must be divisible by Number of Heads")
        except (TypeError, ValueError) as e:
            error_msg = f"Error: invalid training settings: {e}"
            self.log_output(error_msg)
            return error_msg, self.output_log, gr.update(interactive=True)

        self.stop_training_flag = False
        self.training_status = "Training started..."
        self.log_output("Starting training...")

        # Update model config from UI
        self.model_config.update({
            "d_model": d_model,
            "n_layers": n_layers,
            "n_heads": n_heads
        })

        self.training_thread = threading.Thread(
            target=self.train_model,
            args=(batch_size, learning_rate, epochs, device_type),
            daemon=True,
        )
        self.training_thread.start()

        return "Training started...", self.output_log, gr.update(interactive=False)

    def stop_training(self):
        """Ask the training loop to stop at the next batch boundary."""
        self.stop_training_flag = True
        self.training_status = "Stopping training..."
        self.log_output("Stopping training...")
        return "Stopping training...", self.output_log, gr.update(interactive=True)

    def train_model(self, batch_size, learning_rate, epochs, device_type):
        """Train a fresh model on the loaded data (runs in the worker thread).

        Progress and errors are reported through ``training_status`` and the
        output log; exceptions are caught and logged rather than raised.
        """
        try:
            max_seq_len = self.model_config["max_seq_len"]
            dataset = TextDataset(self.training_data, self.tokenizer, max_length=max_seq_len)
            dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
            num_batches = len(dataloader)

            # Snapshot the vocabulary size the model is built with.
            vocab_size = self.tokenizer.vocab_size
            model = SimpleGPT(
                vocab_size=vocab_size,
                d_model=self.model_config["d_model"],
                n_layers=self.model_config["n_layers"],
                n_heads=self.model_config["n_heads"],
                max_seq_len=max_seq_len,
            )
            self.model = model

            optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

            device = self.get_device(device_type)
            model.to(device)
            self.log_output(f"Using device: {device}")

            for epoch in range(epochs):
                if self.stop_training_flag:
                    break

                model.train()
                total_loss = 0
                total_batches = 0

                for batch_idx, batch in enumerate(dataloader):
                    if self.stop_training_flag:
                        break

                    optimizer.zero_grad()

                    input_ids = batch['input_ids'].to(device)
                    labels = batch['labels'].to(device)

                    # The vocabulary can grow if files are loaded mid-training;
                    # the model clamps such IDs, so only warn here.
                    max_id = input_ids.max().item()
                    if max_id >= vocab_size:
                        self.log_output(f"Warning: Found token ID {max_id} but vocab size is {vocab_size}")

                    outputs = model(input_ids=input_ids, labels=labels)
                    loss = outputs['loss']

                    if torch.isnan(loss) or torch.isinf(loss):
                        self.log_output("Warning: NaN or Inf loss detected, skipping batch")
                        continue

                    loss.backward()

                    # Gradient clipping to prevent explosions
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                    optimizer.step()

                    loss_value = loss.item()
                    total_loss += loss_value
                    total_batches += 1

                    if batch_idx % 10 == 0:
                        status_msg = f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}/{num_batches}, Loss: {loss_value:.4f}"
                        self.training_status = status_msg
                        if batch_idx % 50 == 0:  # Log less frequently to avoid UI slowdown
                            self.log_output(status_msg)

                if total_batches > 0:
                    avg_loss = total_loss / total_batches
                    epoch_msg = f"Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}"
                    self.training_status = epoch_msg
                    self.log_output(epoch_msg)

            if self.stop_training_flag:
                final_msg = "Training stopped by user."
            else:
                final_msg = "Training completed successfully!"
            self.training_status = final_msg
            self.log_output(final_msg)

        except Exception as e:
            error_msg = f"Training error: {str(e)}"
            self.training_status = error_msg
            self.log_output(error_msg)
            self.log_output(traceback.format_exc())
        finally:
            self.stop_training_flag = False

        # Kept for backward compatibility; ignored when run as a thread target.
        return gr.update(interactive=True)

    def save_model(self, file_path):
        """Write model weights, vocabulary and config to ``file_path``.

        The tokenizer is stored as a plain dict (key ``'tokenizer_vocab'``)
        rather than a pickled object, so the checkpoint contains only
        tensors and built-in types.

        Returns:
            ``(status_message, output_log)``.
        """
        if self.model is None:
            self.log_output("Error: No model to save!")
            return "Error: No model to save!", self.output_log
        if not file_path:
            self.log_output("Error: No save path given!")
            return "Error: No save path given!", self.output_log

        try:
            torch.save({
                'model_state_dict': self.model.state_dict(),
                'tokenizer_vocab': self.tokenizer.to_dict(),
                'config': dict(self.model_config),
                'training_data_info': {
                    'num_chunks': len(self.training_data),
                    'vocab_size': self.tokenizer.vocab_size
                }
            }, file_path)

            success_msg = f"Model saved to {file_path}"
            self.training_status = success_msg
            self.log_output(success_msg)
            return success_msg, self.output_log
        except Exception as e:
            error_msg = f"Error saving model: {str(e)}"
            self.log_output(error_msg)
            return error_msg, self.output_log

    @staticmethod
    def _tokenizer_from_checkpoint(checkpoint):
        """Return the tokenizer stored in ``checkpoint`` (new or legacy format)."""
        if 'tokenizer_vocab' in checkpoint:
            return SimpleTokenizer.from_dict(checkpoint['tokenizer_vocab'])
        # Legacy checkpoints stored the tokenizer object itself.
        return checkpoint['tokenizer']

    def load_model(self, file_path):
        """Load a checkpoint written by :meth:`save_model`.

        Application state is only replaced once the whole checkpoint has
        loaded successfully.

        Returns:
            ``(status, output_log, d_model, n_layers, n_heads)``; the last
            three are no-op UI updates when loading fails.
        """
        unchanged = (gr.update(), gr.update(), gr.update())

        if not file_path:
            return ("No file selected", self.output_log) + unchanged

        if self.is_training():
            error_msg = "Error loading model: training is in progress"
            self.log_output(error_msg)
            return (error_msg, self.output_log) + unchanged

        is_valid, reason = self.verify_model_file(file_path)
        if not is_valid:
            error_msg = f"Error loading model: {reason}"
            self.log_output(error_msg)
            return (error_msg, self.output_log) + unchanged

        try:
            # SECURITY NOTE: torch.load unpickles data. Legacy checkpoints
            # that embed a tokenizer object need full unpickling, which can
            # execute arbitrary code; only load files from trusted sources.
            checkpoint = torch.load(file_path, map_location='cpu')

            tokenizer = self._tokenizer_from_checkpoint(checkpoint)
            config = dict(checkpoint['config'])
            model = SimpleGPT(
                vocab_size=tokenizer.vocab_size,
                d_model=config["d_model"],
                n_layers=config["n_layers"],
                n_heads=config["n_heads"],
                max_seq_len=config["max_seq_len"],
            )
            model.load_state_dict(checkpoint['model_state_dict'])
        except Exception as e:
            error_msg = f"Error loading model: {str(e)}"
            self.log_output(error_msg)
            return (error_msg, self.output_log) + unchanged

        self.model_config = config
        self.model = model
        self.tokenizer = tokenizer

        success_msg = f"Model loaded from {file_path}"
        self.training_status = success_msg
        self.log_output(success_msg)
        return (
            success_msg,
            self.output_log,
            config['d_model'],
            config['n_layers'],
            config['n_heads'],
        )


# Create the app instance
app = AITrainerApp()

# Create Gradio interface
with gr.Blocks(title="AI Text Generation Trainer") as demo:
    gr.Markdown("# AI Text Generation Trainer")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Controls")

            # Data management
            gr.Markdown("### Data Management")
            file_input = gr.File(file_count="multiple", label="Training Files")
            load_files_btn = gr.Button("Load Text Files")
            view_data_btn = gr.Button("View Training Data")
            data_preview = gr.Textbox(label="Training Data Preview", lines=10, interactive=False)

            # Device selection
            gr.Markdown("### Device Selection")
            device_type = gr.Radio(
                choices=["auto", "cpu", "cuda"],
                value="auto",
                label="Processing Device"
            )
            device_info = gr.Textbox(
                label="Device Info",
                value=f"GPU available: {'Yes' if torch.cuda.is_available() else 'No'}",
                interactive=False
            )

            # Model configuration
            gr.Markdown("### Model Configuration")
            d_model = gr.Number(value=512, label="Embedding Size")
            n_layers = gr.Number(value=6, label="Number of Layers")
            n_heads = gr.Number(value=8, label="Number of Heads")

            # Training parameters
            gr.Markdown("### Training Parameters")
            batch_size = gr.Number(value=4, label="Batch Size")
            learning_rate = gr.Number(value=0.001, label="Learning Rate")
            epochs = gr.Number(value=3, label="Epochs")

            # Training controls
            gr.Markdown("### Training Control")
            start_btn = gr.Button("Start Training", variant="primary")
            stop_btn = gr.Button("Stop Training")
            # Training runs in a background thread, so the UI has to ask
            # for the latest status explicitly.
            refresh_btn = gr.Button("Refresh Status")

            # Export buttons
            gr.Markdown("### Export Model")
            save_path = gr.Textbox(label="Save Path", value="model.pth")
            save_btn = gr.Button("Save Model")
            load_path = gr.Textbox(label="Load Path", value="model.pth")
            load_model_btn = gr.Button("Load Model")

        with gr.Column(scale=2):
            gr.Markdown("## Status & Output")
            status = gr.Textbox(label="Status", value=app.training_status, interactive=False)
            output = gr.Textbox(label="Output Log", value=app.output_log, lines=20, interactive=False)

    # Define event handlers
    load_files_btn.click(
        app.load_training_files,
        inputs=[file_input],
        outputs=[status, output]
    )

    view_data_btn.click(
        app.view_training_data,
        inputs=[],
        outputs=[data_preview]
    )

    start_btn.click(
        app.start_training,
        inputs=[d_model, n_layers, n_heads, batch_size, learning_rate, epochs, device_type],
        outputs=[status, output, start_btn]
    )

    stop_btn.click(
        app.stop_training,
        inputs=[],
        outputs=[status, output, start_btn]
    )

    refresh_btn.click(
        app.refresh_status,
        inputs=[],
        outputs=[status, output, start_btn]
    )

    save_btn.click(
        app.save_model,
        inputs=[save_path],
        outputs=[status, output]
    )

    load_model_btn.click(
        app.load_model,
        inputs=[load_path],
        outputs=[status, output, d_model, n_layers, n_heads]
    )

if __name__ == "__main__":
    demo.launch()