#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
IMPROVED Training Script for SyncNetFCN on VoxCeleb2
Key Fixes:
1. Corrected loss function: L1 regression loss for continuous offset prediction
2. Removed dummy data fallback
3. Reduced logging overhead
4. Added proper metrics tracking (MAE, RMSE, ±1 frame accuracy, sync score)
5. Added temporal consistency regularization
6. Better learning rate scheduling
Usage:
python train_syncnet_fcn_improved.py --data_dir E:/voxceleb2_dataset/VoxCeleb2/dev --pretrained_model data/syncnet_v2.model --checkpoint checkpoints/syncnet_fcn_epoch2.pth
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import os
import argparse
import numpy as np
from SyncNetModel_FCN import StreamSyncFCN
import glob
import random
import cv2
import subprocess
from scipy.io import wavfile
import python_speech_features
import gc
class VoxCeleb2DatasetImproved(Dataset):
"""Improved VoxCeleb2 dataset loader with fixed label format and no dummy data."""
def __init__(self, data_dir, max_offset=15, video_length=25, temp_dir='temp_dataset'):
"""
Args:
data_dir: Path to VoxCeleb2 root directory
max_offset: Maximum frame offset for negative samples
video_length: Number of frames per clip
temp_dir: Temporary directory for audio extraction
"""
self.data_dir = data_dir
self.max_offset = max_offset
self.video_length = video_length
self.temp_dir = temp_dir
os.makedirs(temp_dir, exist_ok=True)
# Find all video files
self.video_files = glob.glob(os.path.join(data_dir, '**', '*.mp4'), recursive=True)
print(f"Found {len(self.video_files)} videos in dataset")
# Track failed samples
self.failed_samples = set()
def __len__(self):
return len(self.video_files)
def _extract_audio_mfcc(self, video_path):
"""Extract audio and compute MFCC features."""
video_id = os.path.splitext(os.path.basename(video_path))[0]
audio_path = os.path.join(self.temp_dir, f'{video_id}_audio.wav')
try:
# Extract audio using FFmpeg
cmd = ['ffmpeg', '-y', '-i', video_path, '-ac', '1', '-ar', '16000',
'-vn', '-acodec', 'pcm_s16le', audio_path]
result = subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, timeout=30)
if result.returncode != 0:
raise RuntimeError(f"FFmpeg failed")
# Read audio and compute MFCC
sample_rate, audio = wavfile.read(audio_path)
# Ensure audio is 1D
if isinstance(audio, np.ndarray) and len(audio.shape) > 1:
audio = audio.mean(axis=1)
if not isinstance(audio, np.ndarray) or audio.size == 0:
raise ValueError(f"Audio data is empty")
# Compute MFCC
mfcc = python_speech_features.mfcc(audio, sample_rate, numcep=13)
mfcc_tensor = torch.FloatTensor(mfcc.T).unsqueeze(0).unsqueeze(0)
# Clean up temp file
if os.path.exists(audio_path):
try:
os.remove(audio_path)
except Exception:
pass
return mfcc_tensor
except Exception as e:
if os.path.exists(audio_path):
try:
os.remove(audio_path)
except Exception:
pass
raise RuntimeError(f"Failed to extract audio: {e}")
def _extract_video_frames(self, video_path, target_size=(112, 112)):
"""Extract video frames as tensor."""
cap = cv2.VideoCapture(video_path)
frames = []
while True:
ret, frame = cap.read()
if not ret:
break
frame = cv2.resize(frame, target_size)
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frames.append(frame.astype(np.float32) / 255.0)
cap.release()
if not frames:
raise ValueError(f"No frames extracted")
frames_array = np.stack(frames, axis=0)
video_tensor = torch.FloatTensor(frames_array).permute(3, 0, 1, 2).unsqueeze(0)
return video_tensor
def _crop_or_pad_video(self, video_tensor, target_length):
"""Crop or pad video to target length."""
B, C, T, H, W = video_tensor.shape
if T > target_length:
start = random.randint(0, T - target_length)
return video_tensor[:, :, start:start+target_length, :, :]
elif T < target_length:
pad_length = target_length - T
last_frame = video_tensor[:, :, -1:, :, :].repeat(1, 1, pad_length, 1, 1)
return torch.cat([video_tensor, last_frame], dim=2)
else:
return video_tensor
def _crop_or_pad_audio(self, audio_tensor, target_length):
"""Crop or pad audio to target length."""
B, C, F, T = audio_tensor.shape
if T > target_length:
start = random.randint(0, T - target_length)
return audio_tensor[:, :, :, start:start+target_length]
elif T < target_length:
pad_length = target_length - T
padding = torch.zeros(B, C, F, pad_length)
return torch.cat([audio_tensor, padding], dim=3)
else:
return audio_tensor
def __getitem__(self, idx):
"""
Returns:
audio: [1, 13, T] MFCC features
video: [3, T_frames, H, W] video frames
offset: Ground truth offset in frames (integer in [-max_offset, +max_offset])
"""
video_path = self.video_files[idx]
# Skip previously failed samples
if idx in self.failed_samples:
return self.__getitem__((idx + 1) % len(self))
# Balanced offset distribution
# 20% synced (offset=0), 80% distributed across other offsets
if random.random() < 0.2:
offset = 0
else:
# Exclude 0 from choices
offset_choices = [o for o in range(-self.max_offset, self.max_offset + 1) if o != 0]
offset = random.choice(offset_choices)
# Log occasionally (every 1000 samples instead of random 1%)
if idx % 1000 == 0:
print(f"[INFO] Processing sample {idx}: offset={offset}")
max_retries = 3
for attempt in range(max_retries):
try:
# Extract audio MFCC features
audio = self._extract_audio_mfcc(video_path)
# Extract video frames
video = self._extract_video_frames(video_path)
# Apply temporal offset for negative samples
if offset != 0:
if offset > 0:
# Shift video forward (cut from beginning)
video = video[:, :, offset:, :, :]
else:
# Shift video backward (cut from end)
video = video[:, :, :offset, :, :]
# Crop/pad to fixed length
video = self._crop_or_pad_video(video, self.video_length)
audio = self._crop_or_pad_audio(audio, self.video_length * 4)
# Remove batch dimension
audio = audio.squeeze(0) # [1, 13, T]
video = video.squeeze(0) # [3, T, H, W]
# Validate shapes
if audio.shape[0] != 1 or audio.shape[1] != 13:
raise ValueError(f"Audio MFCC shape mismatch: {audio.shape}")
if audio.shape[2] != self.video_length * 4:
# Force fix length if mismatch (should be handled by crop_or_pad but double check)
audio = self._crop_or_pad_audio(audio.unsqueeze(0), self.video_length * 4).squeeze(0)
if video.shape[0] != 3 or video.shape[2] != 112 or video.shape[3] != 112:
raise ValueError(f"Video frame shape mismatch: {video.shape}")
if video.shape[1] != self.video_length:
# Force fix length
video = self._crop_or_pad_video(video.unsqueeze(0), self.video_length).squeeze(0)
# Final check
if audio.shape != (1, 13, self.video_length * 4) or video.shape != (3, self.video_length, 112, 112):
raise ValueError(f"Final shape mismatch: Audio {audio.shape}, Video {video.shape}")
return {
'audio': audio,
'video': video,
'offset': torch.tensor(offset, dtype=torch.long), # Integer offset, not binary
}
except Exception as e:
if attempt == max_retries - 1:
# Mark as failed and try next sample
self.failed_samples.add(idx)
if idx % 100 == 0: # Only log occasionally
print(f"[WARN] Sample {idx} failed after {max_retries} attempts: {str(e)[:100]}")
return self.__getitem__((idx + 1) % len(self))
continue
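# Illustrative smoke test for the dataset above (not called anywhere in the
# training pipeline). The directory argument is a placeholder for any local
# VoxCeleb2-style folder of .mp4 files.
def _inspect_one_sample(data_dir, idx=0):
    """Load a single sample and print its tensor shapes for a quick sanity check."""
    dataset = VoxCeleb2DatasetImproved(data_dir, max_offset=15)
    sample = dataset[idx]
    # Expected: audio [1, 13, 100], video [3, 25, 112, 112], offset scalar long tensor
    print('audio:', tuple(sample['audio'].shape),
          'video:', tuple(sample['video'].shape),
          'offset:', int(sample['offset']))
    return sample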
class OffsetRegressionLoss(nn.Module):
"""L1 regression loss for continuous offset prediction."""
def __init__(self):
super(OffsetRegressionLoss, self).__init__()
self.l1 = nn.L1Loss() # More robust to outliers than MSE
def forward(self, predicted_offsets, target_offsets):
"""
Args:
predicted_offsets: [B, 1, T] - model output (continuous offset predictions)
target_offsets: [B] - ground truth offset in frames (float)
Returns:
loss: scalar
"""
B, C, T = predicted_offsets.shape
# Average over time dimension
predicted_offsets_avg = predicted_offsets.mean(dim=2).squeeze(1) # [B]
# L1 loss
loss = self.l1(predicted_offsets_avg, target_offsets.float())
return loss
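# Minimal sketch of the loss's shape contract, assuming the model emits a
# dense per-frame offset map of shape [B, 1, T] and the target is one offset
# per clip. Synthetic tensors only; not used by the training loop.
def _demo_offset_regression_loss():
    criterion = OffsetRegressionLoss()
    predicted = torch.full((2, 1, 25), 3.0)  # both clips predict +3 frames everywhere
    targets = torch.tensor([3, -2])          # ground-truth offsets per clip
    # Predictions are averaged over time, so the loss is mean(|3-3|, |3-(-2)|) = 2.5
    loss = criterion(predicted, targets)
    print(f'demo L1 offset loss: {loss.item():.2f}')  # ~2.50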
def temporal_consistency_loss(predicted_offsets):
"""
Encourage smooth predictions over time.
Args:
predicted_offsets: [B, 1, T]
Returns:
consistency_loss: scalar
"""
# Compute difference between adjacent timesteps
temporal_diff = predicted_offsets[:, :, 1:] - predicted_offsets[:, :, :-1]
consistency_loss = (temporal_diff ** 2).mean()
return consistency_loss
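# Tiny worked example of the smoothness penalty above (illustrative only):
# predictions [0, 1, 3] over time give adjacent diffs [1, 2], so the penalty
# is mean(1^2, 2^2) = 2.5. Constant predictions incur zero penalty.
def _demo_temporal_consistency():
    preds = torch.tensor([[[0.0, 1.0, 3.0]]])        # shape [B=1, 1, T=3]
    print(temporal_consistency_loss(preds).item())   # 2.5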
def compute_metrics(predicted_offsets, target_offsets, max_offset=125):
"""
Compute comprehensive metrics for offset regression.
Args:
predicted_offsets: [B, 1, T]
target_offsets: [B]
Returns:
dict with metrics
"""
B, C, T = predicted_offsets.shape
# Average over time
predicted_offsets_avg = predicted_offsets.mean(dim=2).squeeze(1) # [B]
# Mean absolute error
mae = torch.abs(predicted_offsets_avg - target_offsets).mean()
# Root mean squared error
rmse = torch.sqrt(((predicted_offsets_avg - target_offsets) ** 2).mean())
# Error buckets
acc_1frame = (torch.abs(predicted_offsets_avg - target_offsets) <= 1).float().mean()
acc_1sec = (torch.abs(predicted_offsets_avg - target_offsets) <= 25).float().mean()
# Strict Sync Score (1 - error/25_frames)
# 1.0 = perfect sync
# 0.0 = >1 second error (unusable)
abs_error = torch.abs(predicted_offsets_avg - target_offsets)
sync_score = 1.0 - (abs_error / 25.0) # 25 frames = 1 second
sync_score = torch.clamp(sync_score, 0.0, 1.0).mean()
return {
'mae': mae.item(),
'rmse': rmse.item(),
'acc_1frame': acc_1frame.item(),
'acc_1sec': acc_1sec.item(),
'sync_score': sync_score.item()
}
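# Worked example of the metric definitions above, using synthetic predictions.
# A 5-frame absolute error maps to sync_score = 1 - 5/25 = 0.8; errors of 25
# frames or more (1 second at 25 fps) clamp to 0. Illustrative only.
def _demo_compute_metrics():
    predicted = torch.full((1, 1, 25), 5.0)  # constant prediction of +5 frames
    targets = torch.tensor([0.0])            # true offset is 0 -> 5-frame error
    metrics = compute_metrics(predicted, targets)
    # Expected: mae=5.0, rmse=5.0, acc_1frame=0.0, acc_1sec=1.0, sync_score=0.8
    print(metrics)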
def train_epoch(model, dataloader, optimizer, criterion, device, epoch_num):
"""Train for one epoch with regression metrics."""
model.train()
total_loss = 0
total_offset_loss = 0
total_consistency_loss = 0
metrics_accum = {'mae': 0, 'rmse': 0, 'acc_1frame': 0, 'acc_1sec': 0, 'sync_score': 0}
num_batches = 0
for batch_idx, batch in enumerate(dataloader):
audio = batch['audio'].to(device)
video = batch['video'].to(device)
offsets = batch['offset'].to(device)
# Forward pass
optimizer.zero_grad()
predicted_offsets, _, _ = model(audio, video)
# Compute losses
offset_loss = criterion(predicted_offsets, offsets)
consistency_loss = temporal_consistency_loss(predicted_offsets)
# Combined loss
loss = offset_loss + 0.1 * consistency_loss
# Backward pass
loss.backward()
optimizer.step()
# Statistics
total_loss += loss.item()
total_offset_loss += offset_loss.item()
total_consistency_loss += consistency_loss.item()
# Compute metrics
with torch.no_grad():
metrics = compute_metrics(predicted_offsets, offsets)
for key in metrics_accum:
metrics_accum[key] += metrics[key]
num_batches += 1
# Log every 10 batches
if batch_idx % 10 == 0:
print(f' Batch {batch_idx}/{len(dataloader)}, '
f'Loss: {loss.item():.4f}, '
f'MAE: {metrics["mae"]:.2f} frames, '
f'Score: {metrics["sync_score"]:.4f}')
# Clean up
del audio, video, offsets, predicted_offsets
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Average metrics
avg_loss = total_loss / num_batches
avg_offset_loss = total_offset_loss / num_batches
avg_consistency_loss = total_consistency_loss / num_batches
for key in metrics_accum:
metrics_accum[key] /= num_batches
return avg_loss, avg_offset_loss, avg_consistency_loss, metrics_accum
def main():
parser = argparse.ArgumentParser(description='Train SyncNetFCN (Improved)')
parser.add_argument('--data_dir', type=str, required=True, help='VoxCeleb2 root directory')
parser.add_argument('--pretrained_model', type=str, default='data/syncnet_v2.model',
help='Pretrained SyncNet model')
parser.add_argument('--checkpoint', type=str, default=None,
help='Resume from checkpoint (optional)')
parser.add_argument('--batch_size', type=int, default=4, help='Batch size (default: 4)')
parser.add_argument('--epochs', type=int, default=20, help='Number of epochs')
parser.add_argument('--lr', type=float, default=0.00001, help='Learning rate (lowered from 0.001)')
parser.add_argument('--output_dir', type=str, default='checkpoints_improved', help='Output directory')
parser.add_argument('--use_attention', action='store_true', help='Use attention model')
parser.add_argument('--num_workers', type=int, default=2, help='DataLoader workers')
parser.add_argument('--max_offset', type=int, default=125, help='Max offset in frames (default: 125)')
parser.add_argument('--unfreeze_epoch', type=int, default=10, help='Epoch to unfreeze all layers (default: 10)')
args = parser.parse_args()
# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
# Create output directory
os.makedirs(args.output_dir, exist_ok=True)
# Create model with transfer learning (max_offset=125 for ±5 seconds)
print(f'Creating model with max_offset={args.max_offset}...')
model = StreamSyncFCN(
max_offset=args.max_offset, # ±5 seconds at 25fps
pretrained_syncnet_path=args.pretrained_model,
auto_load_pretrained=True,
use_attention=args.use_attention
)
# Load from checkpoint if provided
start_epoch = 0
if args.checkpoint and os.path.exists(args.checkpoint):
print(f'Loading checkpoint: {args.checkpoint}')
checkpoint = torch.load(args.checkpoint, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
start_epoch = checkpoint.get('epoch', 0)
print(f'Resuming from epoch {start_epoch}')
model = model.to(device)
print(f'Model created. Pretrained conv layers loaded and frozen.')
# Dataset and dataloader
print(f'Loading dataset with max_offset={args.max_offset}...')
dataset = VoxCeleb2DatasetImproved(args.data_dir, max_offset=args.max_offset)
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
num_workers=args.num_workers, pin_memory=True)
# Loss and optimizer (REGRESSION)
criterion = OffsetRegressionLoss()
# Only optimize non-frozen parameters
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=args.lr)
# Learning rate scheduler
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
optimizer,
T_0=5, # Restart every 5 epochs
T_mult=2, # Double restart period each time
eta_min=1e-7 # Minimum LR
)
print(f'Trainable parameters: {sum(p.numel() for p in trainable_params):,}')
print(f'Frozen parameters: {sum(p.numel() for p in model.parameters() if not p.requires_grad):,}')
print(f'Learning rate: {args.lr}')
# Training loop
print('\nStarting training...')
print('='*80)
best_sync_score = 0.0
for epoch in range(start_epoch, start_epoch + args.epochs):
print(f'\nEpoch {epoch+1}/{start_epoch + args.epochs}')
print('-'*80)
# Unfreeze layers if reached unfreeze_epoch
if epoch + 1 == args.unfreeze_epoch:
print(f'\n🔓 Unfreezing all layers for fine-tuning at epoch {epoch+1}...')
model.unfreeze_all_layers()
# Lower learning rate for fine-tuning
new_lr = args.lr * 0.1
print(f'📉 Lowering learning rate to {new_lr} for fine-tuning')
# Re-initialize optimizer with all parameters
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=new_lr)
# Re-initialize scheduler
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
optimizer, T_0=5, T_mult=2, eta_min=1e-8
)
print(f'Trainable parameters now: {sum(p.numel() for p in trainable_params):,}')
avg_loss, avg_offset_loss, avg_consistency_loss, metrics = train_epoch(
model, dataloader, optimizer, criterion, device, epoch
)
# Step scheduler
scheduler.step()
current_lr = optimizer.param_groups[0]['lr']
print(f'\nEpoch {epoch+1} Summary:')
print(f' Total Loss: {avg_loss:.4f}')
print(f' Offset Loss: {avg_offset_loss:.4f}')
print(f' Consistency Loss: {avg_consistency_loss:.4f}')
print(f' MAE: {metrics["mae"]:.2f} frames ({metrics["mae"]/25:.3f} seconds)')
print(f' RMSE: {metrics["rmse"]:.2f} frames')
print(f' Sync Score: {metrics["sync_score"]:.4f} (1.0=Perfect, 0.0=>1s Error)')
print(f' ±1 Frame Acc: {metrics["acc_1frame"]*100:.2f}%')
print(f' Within 1s Acc: {metrics["acc_1sec"]*100:.2f}%')
print(f' Learning Rate: {current_lr:.2e}')
# Save checkpoint
checkpoint_path = os.path.join(args.output_dir, f'syncnet_fcn_improved_epoch{epoch+1}.pth')
torch.save({
'epoch': epoch + 1,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'loss': avg_loss,
'offset_loss': avg_offset_loss,
'metrics': metrics,
}, checkpoint_path)
print(f' Checkpoint saved: {checkpoint_path}')
# Save best model based on Sync Score
if metrics['sync_score'] > best_sync_score:
best_sync_score = metrics['sync_score']
best_path = os.path.join(args.output_dir, 'syncnet_fcn_best.pth')
torch.save({
'epoch': epoch + 1,
'model_state_dict': model.state_dict(),
'metrics': metrics,
}, best_path)
print(f' ✓ New best model saved! (Score: {best_sync_score:.4f})')
print('\n' + '='*80)
print('Training complete!')
print(f'Best Sync Score: {best_sync_score:.4f}')
print(f'Models saved to: {args.output_dir}')
if __name__ == '__main__':
main()