"""
Fixed video action prediction with the proper TimeSformer tensor format.

This version resolves the tensor compatibility issues by explicitly building
the 5D input that TimesformerForVideoClassification expects:
[batch_size, num_frames, channels, height, width].
"""
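
# Requires: torch, transformers, pillow, numpy, and at least one of
# decord / opencv-python for frame decoding.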

import argparse
import json
import logging
import sys
import warnings
from pathlib import Path
from typing import List, Optional, Tuple

# Silence noisy warnings before the heavy imports below trigger them.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

import numpy as np
import torch
from PIL import Image

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Optional decoding backends: decord is tried first, OpenCV is the fallback.
try:
    import cv2
    HAS_CV2 = True
except ImportError:
    HAS_CV2 = False
    cv2 = None

try:
    import decord
    HAS_DECORD = True
except ImportError:
    HAS_DECORD = False
    decord = None

MODEL_ID = "facebook/timesformer-base-finetuned-k400"
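
# This checkpoint is fine-tuned on Kinetics-400 (400 action classes) and
# takes clips of 8 RGB frames at 224x224.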


def read_video_frames_cv2(video_path: Path, num_frames: int = 8) -> List[Image.Image]:
    """Read frames using OpenCV with robust error handling."""
    if not HAS_CV2:
        raise RuntimeError("OpenCV not available")

    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        raise RuntimeError(f"Cannot open video: {video_path}")

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total_frames == 0:
        cap.release()
        raise RuntimeError("Video has no frames")

    # Sample at a uniform stride, clamped to valid indices.
    if total_frames <= num_frames:
        frame_indices = list(range(total_frames))
    else:
        step = max(1, total_frames // num_frames)
        frame_indices = [i * step for i in range(num_frames)]
    frame_indices = [min(idx, total_frames - 1) for idx in frame_indices]

    frames = []
    for idx in frame_indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if ret:
            # OpenCV decodes to BGR; convert to RGB for PIL and the model.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame_rgb))

    cap.release()

    # Pad short reads by repeating the last frame, or black frames if nothing decoded.
    while len(frames) < num_frames:
        if frames:
            frames.append(frames[-1].copy())
        else:
            frames.append(Image.new('RGB', (224, 224), (0, 0, 0)))

    return frames[:num_frames]
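
# Note: CAP_PROP_POS_FRAMES seeking is codec-dependent and can be slow or
# imprecise for some containers, which is one reason decord is tried first.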


def read_video_frames_decord(video_path: Path, num_frames: int = 8) -> List[Image.Image]:
    """Read frames using decord."""
    if not HAS_DECORD:
        raise RuntimeError("Decord not available")

    vr = decord.VideoReader(str(video_path))
    total_frames = len(vr)
    if total_frames == 0:
        raise RuntimeError("Video has no frames")

    if total_frames <= num_frames:
        indices = list(range(total_frames))
    else:
        step = max(1, total_frames // num_frames)
        indices = [i * step for i in range(num_frames)]
    indices = [min(idx, total_frames - 1) for idx in indices]

    try:
        # Batch decode is fastest; fall back to frame-by-frame on failure.
        frame_arrays = vr.get_batch(indices).asnumpy()
        frames = [Image.fromarray(frame) for frame in frame_arrays]
    except Exception:
        frames = []
        for idx in indices:
            try:
                frames.append(Image.fromarray(vr[idx].asnumpy()))
            except Exception:
                continue

    # Pad short reads, as in the OpenCV path.
    while len(frames) < num_frames:
        if frames:
            frames.append(frames[-1].copy())
        else:
            frames.append(Image.new('RGB', (224, 224), (0, 0, 0)))

    return frames[:num_frames]


def read_video_frames(video_path: Path, num_frames: int = 8) -> List[Image.Image]:
    """Read video frames, trying decord first and falling back to OpenCV."""
    last_error = None

    if HAS_DECORD:
        try:
            frames = read_video_frames_decord(video_path, num_frames)
            if frames:
                logging.debug(f"Successfully read {len(frames)} frames using decord")
                return frames
        except Exception as e:
            last_error = e
            logging.debug(f"Decord failed: {e}")

    if HAS_CV2:
        try:
            frames = read_video_frames_cv2(video_path, num_frames)
            if frames:
                logging.debug(f"Successfully read {len(frames)} frames using OpenCV")
                return frames
        except Exception as e:
            last_error = e
            logging.debug(f"OpenCV failed: {e}")

    if last_error:
        raise RuntimeError(f"Failed to read video frames: {last_error}")
    raise RuntimeError("No video reading library available")


def normalize_frames(frames: List[Image.Image], target_size: Tuple[int, int] = (224, 224)) -> List[Image.Image]:
    """Normalize frames to a consistent RGB format and size."""
    if not frames:
        raise RuntimeError("No frames to normalize")

    normalized = []
    for i, frame in enumerate(frames):
        try:
            if frame.mode != 'RGB':
                frame = frame.convert('RGB')
            if frame.size != target_size:
                frame = frame.resize(target_size, Image.Resampling.LANCZOS)
            normalized.append(frame)
        except Exception as e:
            # Substitute a black frame rather than aborting the whole clip.
            logging.warning(f"Error normalizing frame {i}: {e}")
            normalized.append(Image.new('RGB', target_size, (0, 0, 0)))

    return normalized
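
# After normalize_frames, every frame is a 224x224 RGB PIL image, so the
# tensor-building step below can assume a fixed layout.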


def create_timesformer_tensor(
    frames: List[Image.Image],
    image_mean: Optional[List[float]] = None,
    image_std: Optional[List[float]] = None,
) -> torch.Tensor:
    """
    Create a properly formatted tensor for the TimeSformer model.

    TimeSformer expects a 5D input tensor:
        [batch_size, num_frames, channels, height, width]
    For 8 frames of 224x224: [1, 8, 3, 224, 224].

    If image_mean/image_std are given (e.g. read off the checkpoint's image
    processor), per-channel normalization is applied; otherwise pixel values
    are left in [0, 1].
    """
    if len(frames) != 8:
        raise ValueError(f"Expected 8 frames, got {len(frames)}")

    frame_tensors = []
    for frame in frames:
        if frame.mode != 'RGB':
            frame = frame.convert('RGB')
        if frame.size != (224, 224):
            frame = frame.resize((224, 224), Image.Resampling.LANCZOS)

        # HWC uint8 -> CHW float32 in [0, 1].
        array = np.asarray(frame, dtype=np.float32) / 255.0
        frame_tensors.append(torch.from_numpy(array).permute(2, 0, 1))

    # [num_frames, channels, height, width]
    video_tensor = torch.stack(frame_tensors, dim=0)

    # Optional per-channel normalization, broadcast over frames and pixels.
    if image_mean is not None and image_std is not None:
        mean = torch.tensor(image_mean, dtype=torch.float32).view(1, 3, 1, 1)
        std = torch.tensor(image_std, dtype=torch.float32).view(1, 3, 1, 1)
        video_tensor = (video_tensor - mean) / std

    # Add the batch dimension: [1, num_frames, channels, height, width].
    video_tensor = video_tensor.unsqueeze(0)

    logging.debug(f"Created tensor with shape: {video_tensor.shape}")
    logging.debug(f"Tensor dtype: {video_tensor.dtype}")
    logging.debug(f"Tensor range: [{video_tensor.min():.3f}, {video_tensor.max():.3f}]")

    return video_tensor
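
# For reference: the checkpoint's image processor can build an equivalent
# (resized and normalized) input in one call, e.g.
#
#     pixel_values = processor(frames, return_tensors="pt").pixel_values
#
# The manual path above is kept so the exact tensor layout stays explicit.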


def load_model(device: Optional[str] = None):
    """Load the TimeSformer model and image processor."""
    try:
        from transformers import AutoImageProcessor, TimesformerForVideoClassification

        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        logging.info(f"Loading model on device: {device}")

        processor = AutoImageProcessor.from_pretrained(MODEL_ID)
        model = TimesformerForVideoClassification.from_pretrained(MODEL_ID)
        model.to(device)
        model.eval()

        logging.info("Model loaded successfully")
        return processor, model, device

    except Exception as e:
        logging.error(f"Failed to load model: {e}")
        raise RuntimeError(f"Model loading failed: {e}")


def predict_actions(video_path: str, top_k: int = 5) -> List[Tuple[str, float]]:
    """
    Predict actions in a video using the TimeSformer model.

    Args:
        video_path: Path to video file
        top_k: Number of top predictions to return

    Returns:
        List of (action_label, confidence_score) tuples
    """
    video_path = Path(video_path)
    if not video_path.exists():
        raise FileNotFoundError(f"Video file not found: {video_path}")

    try:
        processor, model, device = load_model()

        logging.info(f"Processing video: {video_path.name}")
        frames = read_video_frames(video_path, num_frames=8)
        frames = normalize_frames(frames, target_size=(224, 224))
        logging.info(f"Extracted and normalized {len(frames)} frames")

        # Build the [1, 8, 3, 224, 224] input, normalized with the mean/std
        # stats the checkpoint's processor reports.
        pixel_values = create_timesformer_tensor(
            frames, image_mean=processor.image_mean, image_std=processor.image_std
        )
        pixel_values = pixel_values.to(device)

        logging.info("Running model inference...")
        with torch.no_grad():
            outputs = model(pixel_values=pixel_values)
            logits = outputs.logits

        # Convert logits to probabilities and take the top-k classes.
        probabilities = torch.softmax(logits, dim=-1)[0]
        top_probs, top_indices = torch.topk(probabilities, k=top_k)

        results = []
        for prob, idx in zip(top_probs, top_indices):
            label = model.config.id2label[idx.item()]
            results.append((label, float(prob.item())))

        logging.info(f"Generated {len(results)} predictions successfully")
        if results:
            top_label, top_conf = results[0]
            logging.info(f"Top prediction: {top_label} ({top_conf:.3f})")

        return results

    except Exception as e:
        logging.error(f"Prediction failed: {e}")
        raise RuntimeError(f"Video processing error: {e}")


def main():
    """Command line interface."""
    parser = argparse.ArgumentParser(description="Predict actions in video using TimeSformer")
    parser.add_argument("video", type=str, help="Path to video file")
    parser.add_argument("--top-k", type=int, default=5, help="Number of top predictions")
    parser.add_argument("--json", action="store_true", help="Output as JSON")
    parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose logging")
    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    try:
        predictions = predict_actions(args.video, top_k=args.top_k)

        if args.json:
            output = [{"label": label, "confidence": confidence}
                      for label, confidence in predictions]
            print(json.dumps(output, indent=2))
        else:
            print(f"\nTop {len(predictions)} predictions for: {args.video}")
            print("-" * 60)
            for i, (label, confidence) in enumerate(predictions, 1):
                print(f"{i:2d}. {label:<35} {confidence:.4f}")

        return 0

    except Exception as e:
        print(f"Error: {e}")
        if args.verbose:
            import traceback
            traceback.print_exc()
        return 1
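
# Command-line usage (script filename is illustrative):
#
#     python predict_video_actions.py clip.mp4 --top-k 5
#     python predict_video_actions.py clip.mp4 --json -v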


if __name__ == "__main__":
    sys.exit(main())