#!/usr/bin/python
#-*- coding: utf-8 -*-
"""
Fully Convolutional SyncNet (FCN-SyncNet) - CLASSIFICATION VERSION
Key difference from regression version:
- Output: Probability distribution over discrete offset classes
- Loss: CrossEntropyLoss instead of MSE
- Avoids regression-to-mean problem
Offset classes: -15 to +15 frames (31 classes total)
Class 0 = -15 frames, Class 15 = 0 frames, Class 30 = +15 frames
Author: Enhanced version based on original SyncNet
Date: 2025-12-04
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import cv2
import os
import subprocess
from scipy.io import wavfile
import python_speech_features
class TemporalCorrelation(nn.Module):
"""
Compute correlation between audio and video features across time.
"""
def __init__(self, max_displacement=15):
super(TemporalCorrelation, self).__init__()
self.max_displacement = max_displacement
def forward(self, feat1, feat2):
"""
Args:
feat1: [B, C, T] - visual features
feat2: [B, C, T] - audio features
Returns:
correlation: [B, 2*max_displacement+1, T] - correlation map
"""
B, C, T = feat1.shape
max_disp = self.max_displacement
# Normalize features
feat1 = F.normalize(feat1, dim=1)
feat2 = F.normalize(feat2, dim=1)
# Pad feat2 for shifting
feat2_padded = F.pad(feat2, (max_disp, max_disp), mode='replicate')
corr_list = []
for offset in range(-max_disp, max_disp + 1):
shifted_feat2 = feat2_padded[:, :, offset+max_disp:offset+max_disp+T]
corr = (feat1 * shifted_feat2).sum(dim=1, keepdim=True)
corr_list.append(corr)
correlation = torch.cat(corr_list, dim=1)
return correlation
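# Shape sketch for TemporalCorrelation (toy sizes assumed, not from the repo):
# two [B=2, C=512, T=100] feature maps with max_displacement=15 yield a
# [2, 31, 100] map — one cosine-similarity channel per temporal shift:
#
#   corr_layer = TemporalCorrelation(max_displacement=15)
#   v, a = torch.randn(2, 512, 100), torch.randn(2, 512, 100)
#   assert corr_layer(v, a).shape == (2, 31, 100)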
class ChannelAttention(nn.Module):
"""Squeeze-and-Excitation style channel attention."""
def __init__(self, channels, reduction=16):
super(ChannelAttention, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool1d(1)
self.fc = nn.Sequential(
nn.Linear(channels, channels // reduction, bias=False),
nn.ReLU(inplace=True),
nn.Linear(channels // reduction, channels, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, t = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1)
return x * y.expand_as(x)
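# ChannelAttention is shape-preserving: a [B, C, T] input is rescaled per
# channel by a sigmoid gate in (0, 1), so the output shape equals the input shape.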
class TemporalAttention(nn.Module):
"""Self-attention over temporal dimension."""
def __init__(self, channels):
super(TemporalAttention, self).__init__()
self.query_conv = nn.Conv1d(channels, channels // 8, 1)
self.key_conv = nn.Conv1d(channels, channels // 8, 1)
self.value_conv = nn.Conv1d(channels, channels, 1)
self.gamma = nn.Parameter(torch.zeros(1))
def forward(self, x):
B, C, T = x.size()
query = self.query_conv(x).permute(0, 2, 1)
key = self.key_conv(x)
value = self.value_conv(x)
attention = torch.bmm(query, key)
attention = F.softmax(attention, dim=-1)
out = torch.bmm(value, attention.permute(0, 2, 1))
out = self.gamma * out + x
return out
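# TemporalAttention is likewise shape-preserving ([B, C, T] in and out), but it
# builds a T x T attention matrix, so memory grows quadratically with sequence
# length. Because gamma is initialized to zero, the block starts as an identity
# mapping and learns how much attention to blend in during training.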
class FCN_AudioEncoder(nn.Module):
"""Fully convolutional audio encoder."""
def __init__(self, output_channels=512):
super(FCN_AudioEncoder, self).__init__()
self.conv_layers = nn.Sequential(
nn.Conv2d(1, 64, kernel_size=(3,3), stride=(1,1), padding=(1,1)),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 192, kernel_size=(3,3), stride=(1,1), padding=(1,1)),
nn.BatchNorm2d(192),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=(3,3), stride=(1,2)),
nn.Conv2d(192, 384, kernel_size=(3,3), padding=(1,1)),
nn.BatchNorm2d(384),
nn.ReLU(inplace=True),
nn.Conv2d(384, 256, kernel_size=(3,3), padding=(1,1)),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=(3,3), padding=(1,1)),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=(3,3), stride=(2,2)),
nn.Conv2d(256, 512, kernel_size=(5,1), stride=(5,1), padding=(0,0)),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
)
self.channel_conv = nn.Sequential(
nn.Conv1d(512, 512, kernel_size=1),
nn.BatchNorm1d(512),
nn.ReLU(inplace=True),
nn.Conv1d(512, output_channels, kernel_size=1),
nn.BatchNorm1d(output_channels),
)
self.channel_attn = ChannelAttention(output_channels)
def forward(self, x):
x = self.conv_layers(x)
        # Avoid naming the frequency dim "F": it would shadow torch.nn.functional.
        B, C, n_freq, T = x.size()
        x = x.view(B, C * n_freq, T)
x = self.channel_conv(x)
x = self.channel_attn(x)
return x
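# Rough shape walk-through (assuming 13-coefficient MFCCs at 100 fps, i.e. a
# 10 ms hop): the two pooling stages stride the time axis by 2 each, and the
# final (5,1)-kernel, stride-5 conv collapses the frequency axis to 1, so a
# [B, 1, 13, 100] input comes out as roughly [B, 512, 24] — about 4x temporal
# downsampling, which lands near the 25 fps video feature rate.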
class FCN_VideoEncoder(nn.Module):
"""Fully convolutional video encoder."""
def __init__(self, output_channels=512):
super(FCN_VideoEncoder, self).__init__()
self.conv_layers = nn.Sequential(
nn.Conv3d(3, 96, kernel_size=(5,7,7), stride=(1,2,2), padding=(2,3,3)),
nn.BatchNorm3d(96),
nn.ReLU(inplace=True),
nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1)),
nn.Conv3d(96, 256, kernel_size=(3,5,5), stride=(1,2,2), padding=(1,2,2)),
nn.BatchNorm3d(256),
nn.ReLU(inplace=True),
nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1)),
nn.Conv3d(256, 256, kernel_size=(3,3,3), padding=(1,1,1)),
nn.BatchNorm3d(256),
nn.ReLU(inplace=True),
nn.Conv3d(256, 256, kernel_size=(3,3,3), padding=(1,1,1)),
nn.BatchNorm3d(256),
nn.ReLU(inplace=True),
nn.Conv3d(256, 256, kernel_size=(3,3,3), padding=(1,1,1)),
nn.BatchNorm3d(256),
nn.ReLU(inplace=True),
nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1)),
nn.Conv3d(256, 512, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1)),
nn.BatchNorm3d(512),
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool3d((None, 1, 1))
)
self.channel_conv = nn.Sequential(
nn.Conv1d(512, 512, kernel_size=1),
nn.BatchNorm1d(512),
nn.ReLU(inplace=True),
nn.Conv1d(512, output_channels, kernel_size=1),
nn.BatchNorm1d(output_channels),
)
self.channel_attn = ChannelAttention(output_channels)
def forward(self, x):
x = self.conv_layers(x)
B, C, T, H, W = x.size()
x = x.view(B, C, T)
x = self.channel_conv(x)
x = self.channel_attn(x)
return x
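# The video encoder keeps the frame axis intact (every temporal stride is 1)
# and collapses the spatial axes with AdaptiveAvgPool3d((None, 1, 1)), so
# [B, 3, T, 112, 112] maps to [B, 512, T] — one embedding per input frame.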
class SyncNetFCN_Classification(nn.Module):
"""
Fully Convolutional SyncNet with CLASSIFICATION output.
Treats offset detection as a multi-class classification problem:
- num_classes = 2 * max_offset + 1 (e.g., 251 classes for max_offset=125)
- Class index = offset + max_offset (e.g., offset -5 → class 120)
- Uses CrossEntropyLoss for training
- Default: ±125 frames = ±5 seconds at 25fps
This avoids the regression-to-mean problem encountered with MSE loss.
Architecture:
1. Audio encoder: MFCC → temporal features
2. Video encoder: frames → temporal features
3. Correlation layer: compute audio-video similarity over time
4. Classifier: predict offset class probabilities
"""
def __init__(self, embedding_dim=512, max_offset=125, dropout=0.3):
super(SyncNetFCN_Classification, self).__init__()
self.embedding_dim = embedding_dim
self.max_offset = max_offset
        self.num_classes = 2 * max_offset + 1  # e.g., 251 classes for the default max_offset=125
# Encoders
self.audio_encoder = FCN_AudioEncoder(output_channels=embedding_dim)
self.video_encoder = FCN_VideoEncoder(output_channels=embedding_dim)
# Temporal correlation
self.correlation = TemporalCorrelation(max_displacement=max_offset)
# Classifier head (replaces regressor)
self.classifier = nn.Sequential(
nn.Conv1d(self.num_classes, 128, kernel_size=3, padding=1),
nn.BatchNorm1d(128),
nn.ReLU(inplace=True),
nn.Dropout(dropout),
nn.Conv1d(128, 64, kernel_size=3, padding=1),
nn.BatchNorm1d(64),
nn.ReLU(inplace=True),
nn.Dropout(dropout),
# Output: class logits for each timestep
nn.Conv1d(64, self.num_classes, kernel_size=1),
)
# Global classifier (for single prediction from sequence)
self.global_classifier = nn.Sequential(
nn.AdaptiveAvgPool1d(1),
nn.Flatten(),
nn.Linear(self.num_classes, 128),
nn.ReLU(inplace=True),
nn.Dropout(dropout),
nn.Linear(128, self.num_classes),
)
def forward_audio(self, audio_mfcc):
"""Extract audio features."""
return self.audio_encoder(audio_mfcc)
def forward_video(self, video_frames):
"""Extract video features."""
return self.video_encoder(video_frames)
def forward(self, audio_mfcc, video_frames, return_temporal=False):
"""
Forward pass with audio-video offset classification.
Args:
audio_mfcc: [B, 1, F, T] - MFCC features
video_frames: [B, 3, T', H, W] - video frames
return_temporal: If True, also return per-timestep predictions
Returns:
class_logits: [B, num_classes] - global offset class logits
temporal_logits: [B, num_classes, T] - per-timestep logits (if return_temporal)
audio_features: [B, C, T_a] - audio embeddings
video_features: [B, C, T_v] - video embeddings
"""
# Extract features
if audio_mfcc.dim() == 3:
audio_mfcc = audio_mfcc.unsqueeze(1)
audio_features = self.audio_encoder(audio_mfcc)
video_features = self.video_encoder(video_frames)
# Align temporal dimensions
min_time = min(audio_features.size(2), video_features.size(2))
audio_features = audio_features[:, :, :min_time]
video_features = video_features[:, :, :min_time]
# Compute correlation
correlation = self.correlation(video_features, audio_features)
# Per-timestep classification
temporal_logits = self.classifier(correlation)
# Global classification (aggregate over time)
class_logits = self.global_classifier(temporal_logits)
if return_temporal:
return class_logits, temporal_logits, audio_features, video_features
return class_logits, audio_features, video_features
def predict_offset(self, class_logits):
"""
Convert class logits to offset prediction.
Args:
class_logits: [B, num_classes] - classification logits
Returns:
offsets: [B] - predicted offset in frames
confidences: [B] - prediction confidence (softmax probability)
"""
probs = F.softmax(class_logits, dim=1)
predicted_class = probs.argmax(dim=1)
offsets = predicted_class - self.max_offset # Convert class to offset
confidences = probs.max(dim=1).values
return offsets, confidences
def offset_to_class(self, offset):
"""Convert offset value to class index."""
return offset + self.max_offset
def class_to_offset(self, class_idx):
"""Convert class index to offset value."""
return class_idx - self.max_offset
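# Per-timestep sketch (an assumption: you want an offset track rather than one
# global value, e.g. to spot drift within a clip). temporal_logits from
# forward(..., return_temporal=True) has shape [B, num_classes, T], so:
#
#   _, temporal_logits, _, _ = model(mfcc, frames, return_temporal=True)
#   per_step_offset = temporal_logits.argmax(dim=1) - model.max_offset  # [B, T]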
class StreamSyncFCN_Classification(nn.Module):
"""
Streaming-capable FCN SyncNet with classification output.
Includes preprocessing, transfer learning, and inference utilities.
"""
def __init__(self, embedding_dim=512, max_offset=125,
window_size=25, stride=5, buffer_size=100,
pretrained_syncnet_path=None, auto_load_pretrained=True,
dropout=0.3):
super(StreamSyncFCN_Classification, self).__init__()
self.window_size = window_size
self.stride = stride
self.buffer_size = buffer_size
self.max_offset = max_offset
self.num_classes = 2 * max_offset + 1
# Initialize classification model
self.fcn_model = SyncNetFCN_Classification(
embedding_dim=embedding_dim,
max_offset=max_offset,
dropout=dropout
)
# Auto-load pretrained weights
if auto_load_pretrained and pretrained_syncnet_path:
self.load_pretrained_syncnet(pretrained_syncnet_path)
self.reset_buffers()
def reset_buffers(self):
"""Reset temporal buffers."""
self.logits_buffer = []
self.frame_count = 0
def load_pretrained_syncnet(self, syncnet_model_path, freeze_conv=True, verbose=True):
"""Load conv layers from original SyncNet."""
if verbose:
print(f"Loading pretrained SyncNet from: {syncnet_model_path}")
try:
pretrained = torch.load(syncnet_model_path, map_location='cpu')
if isinstance(pretrained, dict):
pretrained_dict = pretrained.get('model_state_dict', pretrained.get('state_dict', pretrained))
else:
pretrained_dict = pretrained.state_dict()
fcn_dict = self.fcn_model.state_dict()
loaded_count = 0
for key in list(pretrained_dict.keys()):
if key.startswith('netcnnaud.'):
idx = key.split('.')[1]
param = '.'.join(key.split('.')[2:])
new_key = f'audio_encoder.conv_layers.{idx}.{param}'
if new_key in fcn_dict and pretrained_dict[key].shape == fcn_dict[new_key].shape:
fcn_dict[new_key] = pretrained_dict[key]
loaded_count += 1
elif key.startswith('netcnnlip.'):
idx = key.split('.')[1]
param = '.'.join(key.split('.')[2:])
new_key = f'video_encoder.conv_layers.{idx}.{param}'
if new_key in fcn_dict and pretrained_dict[key].shape == fcn_dict[new_key].shape:
fcn_dict[new_key] = pretrained_dict[key]
loaded_count += 1
self.fcn_model.load_state_dict(fcn_dict, strict=False)
if verbose:
print(f"✓ Loaded {loaded_count} pretrained conv parameters")
if freeze_conv:
for name, param in self.fcn_model.named_parameters():
if 'conv_layers' in name:
param.requires_grad = False
if verbose:
print("✓ Froze pretrained conv layers")
except Exception as e:
if verbose:
print(f"⚠ Could not load pretrained weights: {e}")
def load_fcn_checkpoint(self, checkpoint_path, verbose=True):
"""Load FCN classification checkpoint."""
checkpoint = torch.load(checkpoint_path, map_location='cpu')
if 'model_state_dict' in checkpoint:
state_dict = checkpoint['model_state_dict']
else:
state_dict = checkpoint
# Try to load directly first
try:
self.fcn_model.load_state_dict(state_dict, strict=True)
if verbose:
print(f"✓ Loaded full checkpoint from {checkpoint_path}")
    except RuntimeError:
# Load only matching keys
model_dict = self.fcn_model.state_dict()
pretrained_dict = {k: v for k, v in state_dict.items()
if k in model_dict and v.shape == model_dict[k].shape}
model_dict.update(pretrained_dict)
self.fcn_model.load_state_dict(model_dict, strict=False)
if verbose:
print(f"✓ Loaded {len(pretrained_dict)}/{len(state_dict)} parameters from {checkpoint_path}")
return checkpoint.get('epoch', None)
def unfreeze_all_layers(self, verbose=True):
"""Unfreeze all layers for fine-tuning."""
for param in self.fcn_model.parameters():
param.requires_grad = True
if verbose:
print("✓ Unfrozen all layers for fine-tuning")
def forward(self, audio_mfcc, video_frames, return_temporal=False):
"""Forward pass through FCN model."""
return self.fcn_model(audio_mfcc, video_frames, return_temporal)
def extract_audio_mfcc(self, video_path, temp_dir='temp'):
"""Extract audio and compute MFCC."""
os.makedirs(temp_dir, exist_ok=True)
audio_path = os.path.join(temp_dir, 'temp_audio.wav')
cmd = ['ffmpeg', '-y', '-i', video_path, '-ac', '1', '-ar', '16000',
'-vn', '-acodec', 'pcm_s16le', audio_path]
subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
sample_rate, audio = wavfile.read(audio_path)
mfcc = python_speech_features.mfcc(audio, sample_rate, numcep=13).T
mfcc_tensor = torch.FloatTensor(mfcc).unsqueeze(0).unsqueeze(0)
if os.path.exists(audio_path):
os.remove(audio_path)
return mfcc_tensor
def extract_video_frames(self, video_path, target_size=(112, 112)):
"""Extract video frames as tensor."""
cap = cv2.VideoCapture(video_path)
frames = []
while True:
ret, frame = cap.read()
if not ret:
break
frame = cv2.resize(frame, target_size)
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frames.append(frame.astype(np.float32) / 255.0)
cap.release()
if not frames:
raise ValueError(f"No frames extracted from {video_path}")
frames_array = np.stack(frames, axis=0)
video_tensor = torch.FloatTensor(frames_array).permute(3, 0, 1, 2).unsqueeze(0)
return video_tensor
def detect_offset(self, video_path, temp_dir='temp', verbose=True):
"""
Detect AV offset using classification approach.
Args:
video_path: Path to video file
temp_dir: Temporary directory for audio extraction
verbose: Print progress information
Returns:
offset: Predicted offset in frames (positive = audio ahead)
confidence: Classification confidence (0-1)
class_probs: Full probability distribution over offset classes
"""
if verbose:
print(f"Processing: {video_path}")
# Extract features
mfcc = self.extract_audio_mfcc(video_path, temp_dir)
video = self.extract_video_frames(video_path)
if verbose:
print(f" Audio MFCC: {mfcc.shape}, Video: {video.shape}")
# Run inference
self.fcn_model.eval()
with torch.no_grad():
class_logits, _, _ = self.fcn_model(mfcc, video)
offset, confidence = self.fcn_model.predict_offset(class_logits)
class_probs = F.softmax(class_logits, dim=1)
offset = offset.item()
confidence = confidence.item()
if verbose:
print(f" Detected offset: {offset:+d} frames")
print(f" Confidence: {confidence:.4f}")
        return offset, confidence, class_probs.squeeze(0).cpu().numpy()
def process_video_file(self, video_path, temp_dir='temp', verbose=True):
"""Alias for detect_offset for compatibility."""
offset, confidence, _ = self.detect_offset(video_path, temp_dir, verbose)
return offset, confidence
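# Minimal inference sketch (assumes a trained checkpoint at the hypothetical
# path 'checkpoints/fcn_cls.pth' and ffmpeg on PATH, since detect_offset
# shells out to it):
#
#   model = StreamSyncFCN_Classification(auto_load_pretrained=False)
#   model.load_fcn_checkpoint('checkpoints/fcn_cls.pth')
#   offset, confidence, probs = model.detect_offset('clip.mp4')
#   # offset > 0 means audio is ahead of video (in frames)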
def create_classification_criterion(max_offset=125, label_smoothing=0.1):
"""
Create loss function for classification training.
Args:
        max_offset: Maximum offset value (accepted for API symmetry; the loss itself does not use it)
label_smoothing: Label smoothing factor (0 = no smoothing)
Returns:
criterion: CrossEntropyLoss with optional label smoothing
"""
return nn.CrossEntropyLoss(label_smoothing=label_smoothing)
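# Note: nn.CrossEntropyLoss's label_smoothing argument requires PyTorch 1.10+.
# Smoothing spreads a little probability mass onto the non-target classes,
# which is a reasonable fit here since neighbouring offset classes are
# "almost right".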
def train_step_classification(model, audio, video, target_offset, criterion, optimizer, device):
"""
Single training step for classification model.
Args:
model: SyncNetFCN_Classification or StreamSyncFCN_Classification
audio: [B, 1, F, T] audio MFCC
video: [B, 3, T, H, W] video frames
target_offset: [B] target offset in frames (-max_offset to +max_offset)
criterion: CrossEntropyLoss
optimizer: Optimizer
device: torch device
Returns:
loss: Training loss value
accuracy: Classification accuracy
"""
model.train()
optimizer.zero_grad()
audio = audio.to(device)
video = video.to(device)
# Convert offset to class index
if hasattr(model, 'fcn_model'):
target_class = target_offset + model.fcn_model.max_offset
else:
target_class = target_offset + model.max_offset
target_class = target_class.long().to(device)
    # Forward pass (both model variants share the same call signature)
    class_logits, _, _ = model(audio, video)
# Compute loss
loss = criterion(class_logits, target_class)
# Backward pass
loss.backward()
optimizer.step()
# Compute accuracy
predicted_class = class_logits.argmax(dim=1)
accuracy = (predicted_class == target_class).float().mean().item()
return loss.item(), accuracy
def validate_classification(model, dataloader, criterion, device, max_offset=125):
"""
Validate classification model.
Returns:
avg_loss: Average validation loss
accuracy: Classification accuracy
mean_error: Mean absolute error in frames
"""
model.eval()
total_loss = 0
correct = 0
total = 0
total_error = 0
with torch.no_grad():
for audio, video, target_offset in dataloader:
audio = audio.to(device)
video = video.to(device)
target_class = (target_offset + max_offset).long().to(device)
            class_logits, _, _ = model(audio, video)
loss = criterion(class_logits, target_class)
total_loss += loss.item() * audio.size(0)
predicted_class = class_logits.argmax(dim=1)
correct += (predicted_class == target_class).sum().item()
total += audio.size(0)
# Mean absolute error
predicted_offset = predicted_class - max_offset
target_offset_dev = target_class - max_offset
total_error += (predicted_offset - target_offset_dev).abs().sum().item()
return total_loss / total, correct / total, total_error / total
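# Epoch-loop sketch tying the two helpers together (assumed names: train_loader
# and val_loader are DataLoaders yielding (audio, video, target_offset) batches,
# matching what validate_classification iterates over):
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   model = SyncNetFCN_Classification(max_offset=125).to(device)
#   criterion = create_classification_criterion(max_offset=125)
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
#   for epoch in range(10):
#       for audio, video, target_offset in train_loader:
#           loss, acc = train_step_classification(
#               model, audio, video, target_offset, criterion, optimizer, device)
#       val_loss, val_acc, val_mae = validate_classification(
#           model, val_loader, criterion, device, max_offset=125)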
if __name__ == "__main__":
print("Testing SyncNetFCN_Classification...")
    # Test model creation with the default max_offset (251 offset classes)
model = SyncNetFCN_Classification(embedding_dim=512, max_offset=125)
print(f"Number of classes: {model.num_classes}")
# Test forward pass
audio_input = torch.randn(2, 1, 13, 100)
video_input = torch.randn(2, 3, 25, 112, 112)
class_logits, audio_feat, video_feat = model(audio_input, video_input)
print(f"Class logits: {class_logits.shape}")
print(f"Audio features: {audio_feat.shape}")
print(f"Video features: {video_feat.shape}")
# Test prediction
offsets, confidences = model.predict_offset(class_logits)
print(f"Predicted offsets: {offsets}")
print(f"Confidences: {confidences}")
# Test with temporal output
class_logits, temporal_logits, _, _ = model(audio_input, video_input, return_temporal=True)
print(f"Temporal logits: {temporal_logits.shape}")
# Test training step
print("\nTesting training step...")
criterion = create_classification_criterion(max_offset=125, label_smoothing=0.1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
target_offset = torch.tensor([3, -5]) # Example target offsets
loss, acc = train_step_classification(
model, audio_input, video_input, target_offset,
criterion, optimizer, 'cpu'
)
print(f"Training loss: {loss:.4f}, Accuracy: {acc:.2%}")
# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print("\nTesting StreamSyncFCN_Classification...")
stream_model = StreamSyncFCN_Classification(
embedding_dim=512, max_offset=125,
pretrained_syncnet_path=None, auto_load_pretrained=False
)
class_logits, _, _ = stream_model(audio_input, video_input)
print(f"Stream model class logits: {class_logits.shape}")
print("\n✓ All tests passed!")