#!/usr/bin/python # -*- coding: utf-8 -*- """ Training Script for SyncNetFCN on VoxCeleb2 Usage: python train_syncnet_fcn_complete.py --data_dir E:/voxceleb2_dataset/VoxCeleb2/dev --pretrained_model data/syncnet_v2.model """ import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import Dataset, DataLoader import os import argparse import numpy as np from SyncNetModel_FCN import StreamSyncFCN import glob import random import cv2 import subprocess from scipy.io import wavfile import python_speech_features class VoxCeleb2Dataset(Dataset): """VoxCeleb2 dataset loader for sync training with real preprocessing.""" def __init__(self, data_dir, max_offset=15, video_length=25, temp_dir='temp_dataset'): """ Args: data_dir: Path to VoxCeleb2 root directory max_offset: Maximum frame offset for negative samples video_length: Number of frames per clip temp_dir: Temporary directory for audio extraction """ self.data_dir = data_dir self.max_offset = max_offset self.video_length = video_length self.temp_dir = temp_dir os.makedirs(temp_dir, exist_ok=True) # Find all video files self.video_files = glob.glob(os.path.join(data_dir, '**', '*.mp4'), recursive=True) print(f"Found {len(self.video_files)} videos in dataset") def __len__(self): return len(self.video_files) def _extract_audio_mfcc(self, video_path): """Extract audio and compute MFCC features.""" # Create unique temp audio file video_id = os.path.splitext(os.path.basename(video_path))[0] audio_path = os.path.join(self.temp_dir, f'{video_id}_audio.wav') try: # Extract audio using FFmpeg cmd = ['ffmpeg', '-y', '-i', video_path, '-ac', '1', '-ar', '16000', '-vn', '-acodec', 'pcm_s16le', audio_path] result = subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, timeout=30) if result.returncode != 0: raise RuntimeError(f"FFmpeg failed for {video_path}: {result.stderr.decode(errors='ignore')}") # Read audio and compute MFCC try: sample_rate, audio = wavfile.read(audio_path) except Exception as e: raise RuntimeError(f"wavfile.read failed for {audio_path}: {e}") # Ensure audio is 1D if isinstance(audio, np.ndarray) and len(audio.shape) > 1: audio = audio.mean(axis=1) # Check for empty or invalid audio if not isinstance(audio, np.ndarray) or audio.size == 0: raise ValueError(f"Audio data is empty or invalid for {audio_path}") # Compute MFCC try: mfcc = python_speech_features.mfcc(audio, sample_rate, numcep=13) except Exception as e: raise RuntimeError(f"MFCC extraction failed for {audio_path}: {e}") # Shape: [T, 13] -> [13, T] -> [1, 1, 13, T] mfcc_tensor = torch.FloatTensor(mfcc.T).unsqueeze(0).unsqueeze(0) # [1, 1, 13, T] # Clean up temp file if os.path.exists(audio_path): try: os.remove(audio_path) except Exception: pass return mfcc_tensor except Exception as e: # Clean up temp file on error if os.path.exists(audio_path): try: os.remove(audio_path) except Exception: pass raise RuntimeError(f"Failed to extract audio from {video_path}: {e}") def _extract_video_frames(self, video_path, target_size=(112, 112)): """Extract video frames as tensor.""" cap = cv2.VideoCapture(video_path) frames = [] while True: ret, frame = cap.read() if not ret: break # Resize and normalize frame = cv2.resize(frame, target_size) frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frames.append(frame.astype(np.float32) / 255.0) cap.release() if not frames: raise ValueError(f"No frames extracted from {video_path}") # Stack and convert to tensor [T, H, W, 3] -> [3, T, H, W] frames_array = np.stack(frames, axis=0) video_tensor = torch.FloatTensor(frames_array).permute(3, 0, 1, 2).unsqueeze(0) return video_tensor def _crop_or_pad_video(self, video_tensor, target_length): """Crop or pad video to target length.""" B, C, T, H, W = video_tensor.shape if T > target_length: # Random crop start = random.randint(0, T - target_length) return video_tensor[:, :, start:start+target_length, :, :] elif T < target_length: # Pad with last frame pad_length = target_length - T last_frame = video_tensor[:, :, -1:, :, :].repeat(1, 1, pad_length, 1, 1) return torch.cat([video_tensor, last_frame], dim=2) else: return video_tensor def _crop_or_pad_audio(self, audio_tensor, target_length): """Crop or pad audio to target length.""" B, C, T = audio_tensor.shape if T > target_length: # Random crop start = random.randint(0, T - target_length) return audio_tensor[:, :, start:start+target_length] elif T < target_length: # Pad with zeros pad_length = target_length - T padding = torch.zeros(B, C, pad_length) return torch.cat([audio_tensor, padding], dim=2) else: return audio_tensor def __getitem__(self, idx): """ Returns: audio: [1, 13, T] MFCC features video: [3, T_frames, H, W] video frames offset: Ground truth offset (0 for positive, non-zero for negative) label: 1 if in sync, 0 if out of sync """ import time video_path = self.video_files[idx] t0 = time.time() # Randomly decide if this should be positive (sync) or negative (out-of-sync) is_positive = random.random() > 0.5 if is_positive: offset = 0 label = 1 else: # Random offset between 1 and max_offset offset = random.randint(1, self.max_offset) * random.choice([-1, 1]) label = 0 # Log offset/label distribution occasionally if random.random() < 0.01: print(f"[INFO][VoxCeleb2Dataset] idx={idx}, path={video_path}, offset={offset}, label={label}") try: # Extract audio MFCC features t_audio0 = time.time() audio = self._extract_audio_mfcc(video_path) t_audio1 = time.time() # Log audio tensor shape/dtype if random.random() < 0.01: print(f"[INFO][Audio] idx={idx}, path={video_path}, shape={audio.shape}, dtype={audio.dtype}, time={t_audio1-t_audio0:.2f}s") # Extract video frames t_vid0 = time.time() video = self._extract_video_frames(video_path) t_vid1 = time.time() # Log number of frames if random.random() < 0.01: print(f"[INFO][Video] idx={idx}, path={video_path}, frames={video.shape[2] if video.dim()==5 else 'ERR'}, shape={video.shape}, dtype={video.dtype}, time={t_vid1-t_vid0:.2f}s") # Apply temporal offset for negative samples if not is_positive and offset != 0: if offset > 0: # Shift video forward (cut from beginning) video = video[:, :, offset:, :, :] else: # Shift video backward (cut from end) video = video[:, :, :offset, :, :] # Crop/pad to fixed length video = self._crop_or_pad_video(video, self.video_length) audio = self._crop_or_pad_audio(audio, self.video_length * 4) # Remove batch dimension (DataLoader will add it) # audio is [1, 1, 13, T], squeeze to [1, 13, T] audio = audio.squeeze(0) # [1, 13, T] video = video.squeeze(0) # [3, T, H, W] # Check for shape mismatches if audio.shape[0] != 13: raise ValueError(f"Audio MFCC shape mismatch: {audio.shape} for {video_path}") if video.shape[0] != 3 or video.shape[2] != 112 or video.shape[3] != 112: raise ValueError(f"Video frame shape mismatch: {video.shape} for {video_path}") t1 = time.time() if random.random() < 0.01: print(f"[INFO][Sample] idx={idx}, path={video_path}, total_time={t1-t0:.2f}s") dummy = False except Exception as e: # Fallback to dummy data if preprocessing fails # Only print occasionally to avoid spam import traceback print(f"[WARN][VoxCeleb2Dataset] idx={idx}, path={video_path}, ERROR_STAGE=__getitem__, error={str(e)[:100]}") traceback.print_exc(limit=1) audio = torch.randn(1, 13, self.video_length * 4) video = torch.randn(3, self.video_length, 112, 112) offset = 0 label = 1 dummy = True # Resource cleanup: ensure no temp files left behind (audio) temp_audio = os.path.join(self.temp_dir, f'{os.path.splitext(os.path.basename(video_path))[0]}_audio.wav') if os.path.exists(temp_audio): try: os.remove(temp_audio) except Exception: pass # Log dummy sample usage if dummy and random.random() < 0.5: print(f"[WARN][VoxCeleb2Dataset] idx={idx}, path={video_path}, DUMMY_SAMPLE_USED") return { 'audio': audio, 'video': video, 'offset': torch.tensor(offset, dtype=torch.float32), 'label': torch.tensor(label, dtype=torch.float32), 'dummy': dummy } class SyncLoss(nn.Module): """Binary cross-entropy loss for sync/no-sync classification.""" def __init__(self): super(SyncLoss, self).__init__() self.bce = nn.BCEWithLogitsLoss() def forward(self, sync_probs, labels): """ Args: sync_probs: [B, 2*K+1, T] sync probability distribution labels: [B] binary labels (1=sync, 0=out-of-sync) """ # Take max probability across offsets and time max_probs = sync_probs.max(dim=1)[0].max(dim=1)[0] # [B] # BCE loss loss = self.bce(max_probs, labels) return loss def train_epoch(model, dataloader, optimizer, criterion, device): """Train for one epoch.""" model.train() total_loss = 0 correct = 0 total = 0 import torch import gc for batch_idx, batch in enumerate(dataloader): audio = batch['audio'].to(device) video = batch['video'].to(device) labels = batch['label'].to(device) # Log dummy data in batch if 'dummy' in batch: num_dummy = batch['dummy'].sum().item() if hasattr(batch['dummy'], 'sum') else int(sum(batch['dummy'])) if num_dummy > 0: print(f"[WARN][train_epoch] Batch {batch_idx}: {num_dummy}/{len(labels)} dummy samples in batch!") # Forward pass optimizer.zero_grad() sync_probs, _, _ = model(audio, video) # Log tensor shapes if batch_idx % 50 == 0: print(f"[INFO][train_epoch] Batch {batch_idx}: audio {audio.shape}, video {video.shape}, sync_probs {sync_probs.shape}") # Compute loss loss = criterion(sync_probs, labels) # Backward pass loss.backward() optimizer.step() # Statistics total_loss += loss.item() pred = (sync_probs.max(dim=1)[0].max(dim=1)[0] > 0.5).float() correct += (pred == labels).sum().item() total += labels.size(0) # Log memory usage occasionally if batch_idx % 100 == 0 and torch.cuda.is_available(): mem = torch.cuda.memory_allocated() / 1024**2 print(f"[INFO][train_epoch] Batch {batch_idx}: GPU memory used: {mem:.2f} MB") if batch_idx % 10 == 0: print(f' Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}, Acc: {100*correct/total:.2f}%') # Clean up del audio, video, labels gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() avg_loss = total_loss / len(dataloader) accuracy = 100 * correct / total return avg_loss, accuracy def main(): parser = argparse.ArgumentParser(description='Train SyncNetFCN') parser.add_argument('--data_dir', type=str, required=True, help='VoxCeleb2 root directory') parser.add_argument('--pretrained_model', type=str, default='data/syncnet_v2.model', help='Pretrained SyncNet model') parser.add_argument('--batch_size', type=int, default=4, help='Batch size (default: 4)') parser.add_argument('--epochs', type=int, default=10, help='Number of epochs') parser.add_argument('--lr', type=float, default=0.001, help='Learning rate') parser.add_argument('--output_dir', type=str, default='checkpoints', help='Output directory') parser.add_argument('--use_attention', action='store_true', help='Use attention model') parser.add_argument('--num_workers', type=int, default=2, help='DataLoader workers') args = parser.parse_args() # Device device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f'Using device: {device}') # Create output directory os.makedirs(args.output_dir, exist_ok=True) # Create model with transfer learning print('Creating model...') model = StreamSyncFCN( pretrained_syncnet_path=args.pretrained_model, auto_load_pretrained=True, use_attention=args.use_attention ) model = model.to(device) print(f'Model created. Pretrained conv layers loaded and frozen.') # Dataset and dataloader print('Loading dataset...') dataset = VoxCeleb2Dataset(args.data_dir) dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True) # Loss and optimizer criterion = SyncLoss() # Only optimize non-frozen parameters trainable_params = [p for p in model.parameters() if p.requires_grad] optimizer = optim.Adam(trainable_params, lr=args.lr) print(f'Trainable parameters: {sum(p.numel() for p in trainable_params):,}') print(f'Frozen parameters: {sum(p.numel() for p in model.parameters() if not p.requires_grad):,}') # Training loop print('\nStarting training...') print('='*80) for epoch in range(args.epochs): print(f'\nEpoch {epoch+1}/{args.epochs}') print('-'*80) avg_loss, accuracy = train_epoch(model, dataloader, optimizer, criterion, device) print(f'\nEpoch {epoch+1} Summary:') print(f' Average Loss: {avg_loss:.4f}') print(f' Accuracy: {accuracy:.2f}%') # Save checkpoint checkpoint_path = os.path.join(args.output_dir, f'syncnet_fcn_epoch{epoch+1}.pth') torch.save({ 'epoch': epoch + 1, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': avg_loss, 'accuracy': accuracy, }, checkpoint_path) print(f' Checkpoint saved: {checkpoint_path}') print('\n' + '='*80) print('Training complete!') print(f'Final model saved to: {args.output_dir}') if __name__ == '__main__': main()