"""
Custom Training Loop Module

Provides custom training loop implementation with fine-grained control over training.
"""

import os
from dataclasses import dataclass, field
from typing import Optional, Dict, Any, Callable

import torch
from torch.utils.data import DataLoader

try:
    from tqdm import tqdm
except ImportError:
    # The progress bar is purely cosmetic; fall back to a silent stand-in that
    # supports the only two things this module uses: iteration and set_postfix.
    class tqdm:
        """Minimal no-op replacement used when tqdm is not installed."""

        def __init__(self, iterable, **kwargs):
            self._iterable = iterable

        def __iter__(self):
            return iter(self._iterable)

        def set_postfix(self, *args, **kwargs):
            """Accept and ignore progress-bar annotations."""


@dataclass
class TrainingConfig:
    """Configuration for custom training loop."""

    num_epochs: int = 3
    learning_rate: float = 2e-4
    # Only reported in the start-up banner; batching itself is done by the
    # data loaders handed to TrainingLoop.
    batch_size: int = 4
    # Number of micro-batches whose gradients are summed per optimizer step.
    gradient_accumulation_steps: int = 4
    max_grad_norm: float = 1.0
    # Optimizer steps over which the learning rate ramps linearly from 0.
    warmup_steps: int = 100
    # logging_steps / eval_steps are counted in optimizer steps.
    logging_steps: int = 10
    eval_steps: int = 500
    # save_steps and output_dir are not read anywhere in this module.
    save_steps: int = 500
    output_dir: str = "./models/output"
    # Resolved once, when this module is imported.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"


class TrainingLoop:
    """
    Custom training loop for fine-grained control over the training process.

    Provides manual control over:
    - Forward/backward passes
    - Gradient accumulation
    - Learning rate scheduling
    - Logging and evaluation
    - Checkpointing

    Batches are moved to ``config.device`` before each forward pass. The model
    itself is never moved by this class, so the caller must place it on the
    same device beforehand.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        train_dataloader: DataLoader,
        eval_dataloader: Optional[DataLoader] = None,
        config: Optional[TrainingConfig] = None
    ):
        """
        Initialize custom training loop.

        Args:
            model: PyTorch model to train. Its forward pass is called with the
                batch dict as keyword arguments and must return an object
                exposing a ``loss`` attribute.
            train_dataloader: Training data loader
            eval_dataloader: Optional evaluation data loader
            config: Training configuration (defaults to ``TrainingConfig()``)
        """
        self.model = model
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        self.config = config or TrainingConfig()

        self.optimizer = None
        self.scheduler = None
        self.global_step = 0  # number of optimizer steps taken so far
        self.current_epoch = 0

    def setup_optimizer(self, optimizer_class=torch.optim.AdamW, **optimizer_kwargs):
        """
        Setup optimizer and learning rate scheduler.

        Args:
            optimizer_class: Optimizer class to use
            **optimizer_kwargs: Additional optimizer arguments
        """
        self.optimizer = optimizer_class(
            self.model.parameters(),
            lr=self.config.learning_rate,
            **optimizer_kwargs
        )

        # Linear warmup from 0 to the base learning rate, constant afterwards.
        def lr_lambda(current_step: int):
            if current_step < self.config.warmup_steps:
                return float(current_step) / float(max(1, self.config.warmup_steps))
            return 1.0

        self.scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer,
            lr_lambda
        )

    def train_step(self, batch: Dict[str, torch.Tensor]) -> float:
        """
        Perform a single forward/backward pass on one micro-batch.

        Gradients are accumulated, not applied; the optimizer step is taken by
        ``train_epoch``.

        Args:
            batch: Batch of training data

        Returns:
            The unscaled loss of this micro-batch, so that reported training
            loss is directly comparable with the evaluation loss.
        """
        # Move batch to device
        batch = {k: v.to(self.config.device) for k, v in batch.items()}

        # Forward pass
        outputs = self.model(**batch)
        loss = outputs.loss

        # Only the gradient is scaled for accumulation, so that the summed
        # gradient of a full window equals the gradient of the mean loss.
        (loss / self.config.gradient_accumulation_steps).backward()

        return loss.item()

    def _optimizer_step(self) -> None:
        """Clip accumulated gradients, apply them, and advance the schedule."""
        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.config.max_grad_norm
        )
        self.optimizer.step()
        self.scheduler.step()
        self.optimizer.zero_grad()
        self.global_step += 1

    def _log_and_evaluate(self, avg_loss: float) -> None:
        """Run the periodic logging/evaluation hooks for the current step."""
        if self.global_step % self.config.logging_steps == 0:
            print(f"Step {self.global_step}: loss={avg_loss:.4f}")

        if (
            self.eval_dataloader is not None
            and self.global_step % self.config.eval_steps == 0
        ):
            eval_metrics = self.evaluate()
            print(f"Evaluation: {eval_metrics}")
            # evaluate() leaves the model in eval mode.
            self.model.train()

    def train_epoch(self) -> Dict[str, float]:
        """
        Train for one epoch.

        An optimizer step is taken every ``gradient_accumulation_steps``
        micro-batches. If the epoch ends in the middle of an accumulation
        window, the pending gradients are applied in one final step instead of
        leaking into the next epoch or being dropped.

        Logging and evaluation are triggered once per optimizer step (when
        ``global_step`` hits a multiple of ``logging_steps`` / ``eval_steps``),
        never per micro-batch.

        Returns:
            Training metrics: mean micro-batch loss and the epoch index
        """
        if self.optimizer is None:
            self.setup_optimizer()

        self.model.train()
        accumulation = self.config.gradient_accumulation_steps
        total_loss = 0.0
        num_batches = 0
        pending = 0  # micro-batches whose gradients are not applied yet

        progress_bar = tqdm(
            self.train_dataloader,
            desc=f"Epoch {self.current_epoch + 1}/{self.config.num_epochs}"
        )

        for batch in progress_bar:
            total_loss += self.train_step(batch)
            num_batches += 1
            pending += 1

            if pending >= accumulation:
                self._optimizer_step()
                pending = 0
                self._log_and_evaluate(total_loss / num_batches)

            # Update progress bar
            progress_bar.set_postfix({
                "loss": total_loss / num_batches,
                "lr": self.scheduler.get_last_lr()[0]
            })

        # Flush the trailing partial accumulation window, if any.
        if pending:
            self._optimizer_step()
            self._log_and_evaluate(total_loss / num_batches)

        return {
            "loss": total_loss / max(num_batches, 1),
            "epoch": self.current_epoch
        }

    def evaluate(self) -> Dict[str, float]:
        """
        Evaluate model on validation set.

        Switches the model to eval mode and does not switch it back.

        Returns:
            Evaluation metrics (``eval_loss`` is the mean batch loss), or an
            empty dict when no evaluation data loader was provided
        """
        if self.eval_dataloader is None:
            return {}

        self.model.eval()
        total_loss = 0
        num_batches = 0

        with torch.no_grad():
            for batch in tqdm(self.eval_dataloader, desc="Evaluating"):
                batch = {k: v.to(self.config.device) for k, v in batch.items()}
                outputs = self.model(**batch)
                total_loss += outputs.loss.item()
                num_batches += 1

        return {
            "eval_loss": total_loss / max(num_batches, 1)
        }

    def train(self, callback: Optional[Callable] = None) -> Dict[str, Any]:
        """
        Run full training loop.

        Always iterates epochs ``0 .. num_epochs - 1``; a default optimizer is
        created if ``setup_optimizer`` has not been called.

        Args:
            callback: Optional callback function called after each epoch as
                ``callback(epoch, train_metrics)``

        Returns:
            Training history with per-epoch ``train_loss`` and ``eval_loss``
        """
        if self.optimizer is None:
            self.setup_optimizer()

        print(f"Starting training for {self.config.num_epochs} epochs")
        print(f"Device: {self.config.device}")
        print(f"Batch size: {self.config.batch_size}")
        print(f"Gradient accumulation: {self.config.gradient_accumulation_steps}")

        history = {
            "train_loss": [],
            "eval_loss": []
        }

        for epoch in range(self.config.num_epochs):
            self.current_epoch = epoch

            # Train epoch
            train_metrics = self.train_epoch()
            history["train_loss"].append(train_metrics["loss"])

            # Evaluate. Compare against None rather than relying on
            # truthiness, which would call len() on the loader and fails for
            # loaders without a length.
            if self.eval_dataloader is not None:
                eval_metrics = self.evaluate()
                history["eval_loss"].append(eval_metrics.get("eval_loss", 0))

            # Callback
            if callback:
                callback(epoch, train_metrics)

        print("✅ Training complete!")
        return history

    def save_checkpoint(self, path: str) -> None:
        """
        Save training checkpoint.

        Stores model, optimizer and scheduler state plus the step/epoch
        counters. Missing parent directories of ``path`` are created.

        Args:
            path: Path to save checkpoint

        Raises:
            RuntimeError: If no optimizer has been set up yet.
        """
        if self.optimizer is None or self.scheduler is None:
            raise RuntimeError(
                "Cannot save checkpoint: call setup_optimizer() or train() first"
            )

        checkpoint = {
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "global_step": self.global_step,
            "epoch": self.current_epoch
        }

        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

        torch.save(checkpoint, path)
        print(f"Checkpoint saved to: {path}")

    def load_checkpoint(self, path: str) -> None:
        """
        Load training checkpoint.

        If no optimizer has been set up, the default one is created first so
        its state can be restored; call ``setup_optimizer`` beforehand when
        the checkpoint was written with a non-default optimizer.

        Args:
            path: Path to checkpoint
        """
        if self.optimizer is None or self.scheduler is None:
            self.setup_optimizer()

        # SECURITY: torch.load unpickles the file; on PyTorch versions where
        # weights_only is not the default this can execute arbitrary code.
        # Only load checkpoints from trusted sources.
        # map_location lets a checkpoint written on GPU be restored on a
        # CPU-only host (and vice versa).
        checkpoint = torch.load(path, map_location=self.config.device)

        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        self.global_step = checkpoint["global_step"]
        self.current_epoch = checkpoint["epoch"]

        print(f"Checkpoint loaded from: {path}")