#!/usr/bin/python # -*- coding: utf-8 -*- """ IMPROVED Training Script for SyncNetFCN on VoxCeleb2 Key Fixes: 1. Corrected loss function: CrossEntropyLoss for offset prediction (31 classes) 2. Removed dummy data fallback 3. Reduced logging overhead 4. Added proper metrics tracking (exact accuracy, ±1 frame accuracy, MAE) 5. Added temporal consistency regularization 6. Better learning rate scheduling Usage: python train_syncnet_fcn_improved.py --data_dir E:/voxceleb2_dataset/VoxCeleb2/dev --pretrained_model data/syncnet_v2.model --checkpoint checkpoints/syncnet_fcn_epoch2.pth """ import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import Dataset, DataLoader import os import argparse import numpy as np from SyncNetModel_FCN import StreamSyncFCN import glob import random import cv2 import subprocess from scipy.io import wavfile import python_speech_features import time class VoxCeleb2DatasetImproved(Dataset): """Improved VoxCeleb2 dataset loader with fixed label format and no dummy data.""" def __init__(self, data_dir, max_offset=15, video_length=25, temp_dir='temp_dataset'): """ Args: data_dir: Path to VoxCeleb2 root directory max_offset: Maximum frame offset for negative samples video_length: Number of frames per clip temp_dir: Temporary directory for audio extraction """ self.data_dir = data_dir self.max_offset = max_offset self.video_length = video_length self.temp_dir = temp_dir os.makedirs(temp_dir, exist_ok=True) # Find all video files self.video_files = glob.glob(os.path.join(data_dir, '**', '*.mp4'), recursive=True) print(f"Found {len(self.video_files)} videos in dataset") # Track failed samples self.failed_samples = set() def __len__(self): return len(self.video_files) def _extract_audio_mfcc(self, video_path): """Extract audio and compute MFCC features.""" video_id = os.path.splitext(os.path.basename(video_path))[0] audio_path = os.path.join(self.temp_dir, f'{video_id}_audio.wav') try: # Extract audio using FFmpeg cmd = ['ffmpeg', '-y', '-i', video_path, '-ac', '1', '-ar', '16000', '-vn', '-acodec', 'pcm_s16le', audio_path] result = subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, timeout=30) if result.returncode != 0: raise RuntimeError(f"FFmpeg failed") # Read audio and compute MFCC sample_rate, audio = wavfile.read(audio_path) # Ensure audio is 1D if isinstance(audio, np.ndarray) and len(audio.shape) > 1: audio = audio.mean(axis=1) if not isinstance(audio, np.ndarray) or audio.size == 0: raise ValueError(f"Audio data is empty") # Compute MFCC mfcc = python_speech_features.mfcc(audio, sample_rate, numcep=13) mfcc_tensor = torch.FloatTensor(mfcc.T).unsqueeze(0).unsqueeze(0) # Clean up temp file if os.path.exists(audio_path): try: os.remove(audio_path) except Exception: pass return mfcc_tensor except Exception as e: if os.path.exists(audio_path): try: os.remove(audio_path) except Exception: pass raise RuntimeError(f"Failed to extract audio: {e}") def _extract_video_frames(self, video_path, target_size=(112, 112)): """Extract video frames as tensor.""" cap = cv2.VideoCapture(video_path) frames = [] while True: ret, frame = cap.read() if not ret: break frame = cv2.resize(frame, target_size) frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frames.append(frame.astype(np.float32) / 255.0) cap.release() if not frames: raise ValueError(f"No frames extracted") frames_array = np.stack(frames, axis=0) video_tensor = torch.FloatTensor(frames_array).permute(3, 0, 1, 2).unsqueeze(0) return video_tensor def _crop_or_pad_video(self, video_tensor, target_length): """Crop or pad video to target length.""" B, C, T, H, W = video_tensor.shape if T > target_length: start = random.randint(0, T - target_length) return video_tensor[:, :, start:start+target_length, :, :] elif T < target_length: pad_length = target_length - T last_frame = video_tensor[:, :, -1:, :, :].repeat(1, 1, pad_length, 1, 1) return torch.cat([video_tensor, last_frame], dim=2) else: return video_tensor def _crop_or_pad_audio(self, audio_tensor, target_length): """Crop or pad audio to target length.""" B, C, F, T = audio_tensor.shape if T > target_length: start = random.randint(0, T - target_length) return audio_tensor[:, :, :, start:start+target_length] elif T < target_length: pad_length = target_length - T padding = torch.zeros(B, C, F, pad_length) return torch.cat([audio_tensor, padding], dim=3) else: return audio_tensor def __getitem__(self, idx): """ Returns: audio: [1, 13, T] MFCC features video: [3, T_frames, H, W] video frames offset: Ground truth offset in frames (integer from -15 to +15) """ video_path = self.video_files[idx] # Skip previously failed samples if idx in self.failed_samples: return self.__getitem__((idx + 1) % len(self)) # Balanced offset distribution # 20% synced (offset=0), 80% distributed across other offsets if random.random() < 0.2: offset = 0 else: # Exclude 0 from choices offset_choices = [o for o in range(-self.max_offset, self.max_offset + 1) if o != 0] offset = random.choice(offset_choices) # Log occasionally (every 1000 samples instead of random 1%) if idx % 1000 == 0: print(f"[INFO] Processing sample {idx}: offset={offset}") max_retries = 3 for attempt in range(max_retries): try: # Extract audio MFCC features audio = self._extract_audio_mfcc(video_path) # Extract video frames video = self._extract_video_frames(video_path) # Apply temporal offset for negative samples if offset != 0: if offset > 0: # Shift video forward (cut from beginning) video = video[:, :, offset:, :, :] else: # Shift video backward (cut from end) video = video[:, :, :offset, :, :] # Crop/pad to fixed length video = self._crop_or_pad_video(video, self.video_length) audio = self._crop_or_pad_audio(audio, self.video_length * 4) # Remove batch dimension audio = audio.squeeze(0) # [1, 13, T] video = video.squeeze(0) # [3, T, H, W] # Validate shapes if audio.shape[0] != 1 or audio.shape[1] != 13: raise ValueError(f"Audio MFCC shape mismatch: {audio.shape}") if audio.shape[2] != self.video_length * 4: # Force fix length if mismatch (should be handled by crop_or_pad but double check) audio = self._crop_or_pad_audio(audio.unsqueeze(0), self.video_length * 4).squeeze(0) if video.shape[0] != 3 or video.shape[2] != 112 or video.shape[3] != 112: raise ValueError(f"Video frame shape mismatch: {video.shape}") if video.shape[1] != self.video_length: # Force fix length video = self._crop_or_pad_video(video.unsqueeze(0), self.video_length).squeeze(0) # Final check if audio.shape != (1, 13, 100) or video.shape != (3, 25, 112, 112): raise ValueError(f"Final shape mismatch: Audio {audio.shape}, Video {video.shape}") return { 'audio': audio, 'video': video, 'offset': torch.tensor(offset, dtype=torch.long), # Integer offset, not binary } except Exception as e: if attempt == max_retries - 1: # Mark as failed and try next sample self.failed_samples.add(idx) if idx % 100 == 0: # Only log occasionally print(f"[WARN] Sample {idx} failed after {max_retries} attempts: {str(e)[:100]}") return self.__getitem__((idx + 1) % len(self)) continue class OffsetRegressionLoss(nn.Module): """L1 regression loss for continuous offset prediction.""" def __init__(self): super(OffsetRegressionLoss, self).__init__() self.l1 = nn.L1Loss() # More robust to outliers than MSE def forward(self, predicted_offsets, target_offsets): """ Args: predicted_offsets: [B, 1, T] - model output (continuous offset predictions) target_offsets: [B] - ground truth offset in frames (float) Returns: loss: scalar """ B, C, T = predicted_offsets.shape # Average over time dimension predicted_offsets_avg = predicted_offsets.mean(dim=2).squeeze(1) # [B] # L1 loss loss = self.l1(predicted_offsets_avg, target_offsets.float()) return loss def temporal_consistency_loss(predicted_offsets): """ Encourage smooth predictions over time. Args: predicted_offsets: [B, 1, T] Returns: consistency_loss: scalar """ # Compute difference between adjacent timesteps temporal_diff = predicted_offsets[:, :, 1:] - predicted_offsets[:, :, :-1] consistency_loss = (temporal_diff ** 2).mean() return consistency_loss def compute_metrics(predicted_offsets, target_offsets, max_offset=125): """ Compute comprehensive metrics for offset regression. Args: predicted_offsets: [B, 1, T] target_offsets: [B] Returns: dict with metrics """ B, C, T = predicted_offsets.shape # Average over time predicted_offsets_avg = predicted_offsets.mean(dim=2).squeeze(1) # [B] # Mean absolute error mae = torch.abs(predicted_offsets_avg - target_offsets).mean() # Root mean squared error rmse = torch.sqrt(((predicted_offsets_avg - target_offsets) ** 2).mean()) # Error buckets acc_1frame = (torch.abs(predicted_offsets_avg - target_offsets) <= 1).float().mean() acc_1sec = (torch.abs(predicted_offsets_avg - target_offsets) <= 25).float().mean() # Strict Sync Score (1 - error/25_frames) # 1.0 = perfect sync # 0.0 = >1 second error (unusable) abs_error = torch.abs(predicted_offsets_avg - target_offsets) sync_score = 1.0 - (abs_error / 25.0) # 25 frames = 1 second sync_score = torch.clamp(sync_score, 0.0, 1.0).mean() return { 'mae': mae.item(), 'rmse': rmse.item(), 'acc_1frame': acc_1frame.item(), 'acc_1sec': acc_1sec.item(), 'sync_score': sync_score.item() } def train_epoch(model, dataloader, optimizer, criterion, device, epoch_num): """Train for one epoch with regression metrics.""" model.train() total_loss = 0 total_offset_loss = 0 total_consistency_loss = 0 metrics_accum = {'mae': 0, 'rmse': 0, 'acc_1frame': 0, 'acc_1sec': 0, 'sync_score': 0} num_batches = 0 import gc for batch_idx, batch in enumerate(dataloader): audio = batch['audio'].to(device) video = batch['video'].to(device) offsets = batch['offset'].to(device) # Forward pass optimizer.zero_grad() predicted_offsets, _, _ = model(audio, video) # Compute losses offset_loss = criterion(predicted_offsets, offsets) consistency_loss = temporal_consistency_loss(predicted_offsets) # Combined loss loss = offset_loss + 0.1 * consistency_loss # Backward pass loss.backward() optimizer.step() # Statistics total_loss += loss.item() total_offset_loss += offset_loss.item() total_consistency_loss += consistency_loss.item() # Compute metrics with torch.no_grad(): metrics = compute_metrics(predicted_offsets, offsets) for key in metrics_accum: metrics_accum[key] += metrics[key] num_batches += 1 # Log every 10 batches if batch_idx % 10 == 0: print(f' Batch {batch_idx}/{len(dataloader)}, ' f'Loss: {loss.item():.4f}, ' f'MAE: {metrics["mae"]:.2f} frames, ' f'Score: {metrics["sync_score"]:.4f}') # Clean up del audio, video, offsets, predicted_offsets gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() # Average metrics avg_loss = total_loss / num_batches avg_offset_loss = total_offset_loss / num_batches avg_consistency_loss = total_consistency_loss / num_batches for key in metrics_accum: metrics_accum[key] /= num_batches return avg_loss, avg_offset_loss, avg_consistency_loss, metrics_accum def main(): parser = argparse.ArgumentParser(description='Train SyncNetFCN (Improved)') parser.add_argument('--data_dir', type=str, required=True, help='VoxCeleb2 root directory') parser.add_argument('--pretrained_model', type=str, default='data/syncnet_v2.model', help='Pretrained SyncNet model') parser.add_argument('--checkpoint', type=str, default=None, help='Resume from checkpoint (optional)') parser.add_argument('--batch_size', type=int, default=4, help='Batch size (default: 4)') parser.add_argument('--epochs', type=int, default=20, help='Number of epochs') parser.add_argument('--lr', type=float, default=0.00001, help='Learning rate (lowered from 0.001)') parser.add_argument('--output_dir', type=str, default='checkpoints_improved', help='Output directory') parser.add_argument('--use_attention', action='store_true', help='Use attention model') parser.add_argument('--num_workers', type=int, default=2, help='DataLoader workers') parser.add_argument('--max_offset', type=int, default=125, help='Max offset in frames (default: 125)') parser.add_argument('--unfreeze_epoch', type=int, default=10, help='Epoch to unfreeze all layers (default: 10)') args = parser.parse_args() # Device device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f'Using device: {device}') # Create output directory os.makedirs(args.output_dir, exist_ok=True) # Create model with transfer learning (max_offset=125 for ±5 seconds) print(f'Creating model with max_offset={args.max_offset}...') model = StreamSyncFCN( max_offset=args.max_offset, # ±5 seconds at 25fps pretrained_syncnet_path=args.pretrained_model, auto_load_pretrained=True, use_attention=args.use_attention ) # Load from checkpoint if provided start_epoch = 0 if args.checkpoint and os.path.exists(args.checkpoint): print(f'Loading checkpoint: {args.checkpoint}') checkpoint = torch.load(args.checkpoint, map_location='cpu') model.load_state_dict(checkpoint['model_state_dict']) start_epoch = checkpoint.get('epoch', 0) print(f'Resuming from epoch {start_epoch}') model = model.to(device) print(f'Model created. Pretrained conv layers loaded and frozen.') # Dataset and dataloader print(f'Loading dataset with max_offset={args.max_offset}...') dataset = VoxCeleb2DatasetImproved(args.data_dir, max_offset=args.max_offset) dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True) # Loss and optimizer (REGRESSION) criterion = OffsetRegressionLoss() # Only optimize non-frozen parameters trainable_params = [p for p in model.parameters() if p.requires_grad] optimizer = optim.Adam(trainable_params, lr=args.lr) # Learning rate scheduler scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts( optimizer, T_0=5, # Restart every 5 epochs T_mult=2, # Double restart period each time eta_min=1e-7 # Minimum LR ) print(f'Trainable parameters: {sum(p.numel() for p in trainable_params):,}') print(f'Frozen parameters: {sum(p.numel() for p in model.parameters() if not p.requires_grad):,}') print(f'Learning rate: {args.lr}') # Training loop print('\\nStarting training...') print('='*80) best_tolerance_acc = 0 for epoch in range(start_epoch, start_epoch + args.epochs): print(f'\\nEpoch {epoch+1}/{start_epoch + args.epochs}') print('-'*80) # Unfreeze layers if reached unfreeze_epoch if epoch + 1 == args.unfreeze_epoch: print(f'\\nšŸ”“ Unfreezing all layers for fine-tuning at epoch {epoch+1}...') model.unfreeze_all_layers() # Lower learning rate for fine-tuning new_lr = args.lr * 0.1 print(f'šŸ“‰ Lowering learning rate to {new_lr} for fine-tuning') # Re-initialize optimizer with all parameters trainable_params = [p for p in model.parameters() if p.requires_grad] optimizer = optim.Adam(trainable_params, lr=new_lr) # Re-initialize scheduler scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts( optimizer, T_0=5, T_mult=2, eta_min=1e-8 ) print(f'Trainable parameters now: {sum(p.numel() for p in trainable_params):,}') avg_loss, avg_offset_loss, avg_consistency_loss, metrics = train_epoch( model, dataloader, optimizer, criterion, device, epoch ) # Step scheduler scheduler.step() current_lr = optimizer.param_groups[0]['lr'] print(f'\nEpoch {epoch+1} Summary:') print(f' Total Loss: {avg_loss:.4f}') print(f' Offset Loss: {avg_offset_loss:.4f}') print(f' Consistency Loss: {avg_consistency_loss:.4f}') print(f' MAE: {metrics["mae"]:.2f} frames ({metrics["mae"]/25:.3f} seconds)') print(f' RMSE: {metrics["rmse"]:.2f} frames') print(f' Sync Score: {metrics["sync_score"]:.4f} (1.0=Perfect, 0.0=>1s Error)') print(f' <1 Frame Acc: {metrics["acc_1frame"]*100:.2f}%') print(f' <1 Second Acc: {metrics["acc_1sec"]*100:.2f}%') print(f' Learning Rate: {current_lr:.2e}') # Save checkpoint checkpoint_path = os.path.join(args.output_dir, f'syncnet_fcn_improved_epoch{epoch+1}.pth') torch.save({ 'epoch': epoch + 1, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), 'loss': avg_loss, 'offset_loss': avg_offset_loss, 'metrics': metrics, }, checkpoint_path) print(f' Checkpoint saved: {checkpoint_path}') # Save best model based on Sync Score if metrics['sync_score'] > best_tolerance_acc: best_tolerance_acc = metrics['sync_score'] best_path = os.path.join(args.output_dir, 'syncnet_fcn_best.pth') torch.save({ 'epoch': epoch + 1, 'model_state_dict': model.state_dict(), 'metrics': metrics, }, best_path) print(f' āœ“ New best model saved! (Score: {best_tolerance_acc:.4f})') print('\n' + '='*80) print('Training complete!') print(f'Best Sync Score: {best_tolerance_acc:.4f}') print(f'Models saved to: {args.output_dir}') if __name__ == '__main__': main()