# PATS / app.py
# EdBianchi's picture
# Update app.py
# d7bab42 verified
import gradio as gr
import av
import numpy as np
from PIL import Image
import tempfile
import os
def sample_frame_indices(num_frames, fps, total_frames):
    """
    Pick frame indices spread uniformly over the whole video.

    Args:
        num_frames (int): How many indices the caller asks for
        fps (float): Frames per second; accepted but ignored by this sampler
        total_frames (int): Number of frames available in the video
    Returns:
        list: Every index (0 .. total_frames - 1) when the video has no more
            than num_frames frames, otherwise num_frames evenly spaced
            indices running from 0 to total_frames - 1
    """
    wants_subset = total_frames > num_frames
    if not wants_subset:
        # Nothing to choose between: hand back every available frame
        return [*range(total_frames)]

    last_index = total_frames - 1
    evenly_spaced = np.linspace(0, last_index, num=num_frames, dtype=int)
    return evenly_spaced.tolist()
def _indices_within_segment(start_frame, end_frame, count, total_frames):
    """Pick `count` frame indices from the half-open range [start_frame, end_frame)."""
    if count == 1:
        # Single frame: take the middle of the segment
        middle = start_frame + (end_frame - start_frame) // 2
        return [min(middle, total_frames - 1)]

    span = end_frame - start_frame
    if span <= 0:
        # Empty segment: the caller pads the overall result afterwards
        return []
    if span <= count:
        # Segment shorter than requested: use every frame, repeating cyclically
        return [start_frame + (i % span) for i in range(count)]
    # Enough frames: spread the picks evenly over the segment
    return np.linspace(start_frame, end_frame - 1, count, dtype=int).tolist()


def sample_frame_indices_efficient_segments(num_frames, segment_duration, num_segments, container):
    """
    Enhanced frame sampling strategy that distributes frames across temporal
    segments of the video for better temporal coverage and content diversity.

    Segment start times are spread evenly between the start of the video and
    the last position where a full segment still fits; num_frames is split as
    evenly as possible between the segments.

    Args:
        num_frames (int): Total number of frames to sample
        segment_duration (float): Requested duration of each segment in seconds
            (capped at 80% of video_duration / num_segments)
        num_segments (int): Number of segments to sample from; values larger
            than num_frames are reduced to num_frames so that every segment
            contributes a frame and the segments still span the whole video
        container (av.container): PyAV container object; only
            container.streams.video[0].average_rate and .frames are read
    Returns:
        list: Sorted frame indices. On the segment path the list has exactly
            num_frames entries (possibly with repeats). When the function
            falls back to sample_frame_indices (unusable fps, fewer frames
            than requested, num_segments < 1, or segments shorter than 0.5 s)
            the list can be shorter than num_frames. An empty list is returned
            for num_frames <= 0.
    """
    # Get video properties
    video_stream = container.streams.video[0]
    total_video_frames = video_stream.frames
    # A missing or zero frame rate must not reach the division below
    average_rate = video_stream.average_rate
    video_fps = float(average_rate) if average_rate else 0.0

    if num_frames <= 0:
        return []

    # Fallback to uniform sampling if the video is too short or its metadata
    # is unusable. A reported frame count of 0 also lands here (presumably the
    # container did not report a count - verify against PyAV's Stream.frames).
    if video_fps <= 0 or total_video_frames < num_frames or num_segments < 1:
        return sample_frame_indices(num_frames, 4, total_video_frames)

    # With more segments than frames only the leading segments would receive a
    # frame, leaving the end of the video unsampled.
    num_segments = min(num_segments, num_frames)
    video_duration = total_video_frames / video_fps

    # Split num_frames between segments; the first `extra_frames` get one more
    base_frames_per_segment, extra_frames = divmod(num_frames, num_segments)

    # Keep segments within their share of the video, leaving some buffer
    max_segment_duration = video_duration / num_segments * 0.8
    effective_segment_duration = min(segment_duration, max_segment_duration)

    # If segments would be too small, fall back to uniform sampling
    if effective_segment_duration < 0.5:
        return sample_frame_indices(num_frames, 4, total_video_frames)

    # Segment start times, distributed so the last segment ends at the video end
    if num_segments == 1:
        segment_starts = [0]
    else:
        max_start_time = max(0, video_duration - effective_segment_duration)
        segment_starts = np.linspace(0, max_start_time, num_segments)

    all_indices = []
    for i, start_time in enumerate(segment_starts):
        segment_frames = base_frames_per_segment + (1 if i < extra_frames else 0)

        # Convert the time window to frame indices
        start_frame = int(start_time * video_fps)
        end_frame = min(
            int((start_time + effective_segment_duration) * video_fps),
            total_video_frames,
        )
        if start_frame >= end_frame:
            # Degenerate window: widen it to half a second where possible
            end_frame = min(start_frame + int(0.5 * video_fps), total_video_frames)

        all_indices.extend(
            _indices_within_segment(start_frame, end_frame, segment_frames, total_video_frames)
        )

    if not all_indices:
        # No segment produced a frame: use uniform sampling instead
        return sample_frame_indices(num_frames, 4, total_video_frames)

    # Segments never yield more than their quota, so only a shortfall is
    # possible; pad it by repeating the collected indices cyclically.
    collected = len(all_indices)
    if collected < num_frames:
        padding = [all_indices[i % collected] for i in range(num_frames - collected)]
        all_indices.extend(padding)

    # Keep every index inside the video and restore temporal order
    bounded = np.clip(np.array(all_indices), 0, total_video_frames - 1)
    return np.sort(bounded).tolist()
def extract_frames_at_indices(video_path, frame_indices):
    """
    Extract frames from video at specified indices.

    The video is decoded from the start and decoding stops as soon as every
    requested index has been seen.

    Args:
        video_path (str): Path to video file
        frame_indices (list): List of frame indices to extract; may contain
            repeated or unsorted values
    Returns:
        list: PIL Images in the same order as frame_indices, one per requested
            index (a repeated index yields the same image object again).
            Indices that were not reached while decoding are left out.
    """
    wanted = set(frame_indices)
    if not wanted:
        # Nothing requested: avoid opening and decoding the whole video
        return []

    images_by_index = {}
    # The context manager closes the container even if decoding raises
    with av.open(video_path) as container:
        for frame_idx, frame in enumerate(container.decode(video=0)):
            if frame_idx not in wanted:
                continue
            images_by_index[frame_idx] = frame.to_image()
            # Stop if we've collected all frames
            if len(images_by_index) == len(wanted):
                break

    # Rebuild the result from frame_indices so that callers pairing frames
    # with indices positionally stay aligned when indices repeat
    return [images_by_index[idx] for idx in frame_indices if idx in images_by_index]
def process_video(video_file, num_frames, segment_duration, num_segments):
    """
    Main processing function for Gradio interface.

    Reads the video metadata, selects frame indices with
    sample_frame_indices_efficient_segments, extracts those frames and pairs
    each one with a caption for display.

    Args:
        video_file: Uploaded video file (passed to av.open), or None
        num_frames (int): Number of frames to sample
        segment_duration (float): Duration of each segment in seconds
        num_segments (int): Number of segments
    Returns:
        tuple: (list of (image, caption) pairs, info string, indices list).
            On failure, or when no file was given, the first and last
            elements are empty lists and the info string carries the message;
            exceptions are not propagated to the caller.
    """
    if video_file is None:
        return [], "Please upload a video file", []
    try:
        # UI sliders may deliver numeric values as floats; counts must be ints
        num_frames = int(num_frames)
        num_segments = int(num_segments)

        # The context manager closes the container on every path, including
        # when reading metadata or sampling raises
        with av.open(video_file) as container:
            video_stream = container.streams.video[0]
            # Get video info; a missing frame rate is reported as 0
            average_rate = video_stream.average_rate
            video_fps = float(average_rate) if average_rate else 0.0
            total_frames = video_stream.frames
            video_duration = total_frames / video_fps if video_fps > 0 else 0
            # Get frame indices using the sampling function
            frame_indices = sample_frame_indices_efficient_segments(
                num_frames, segment_duration, num_segments, container
            )

        # Extract frames at selected indices (re-opens the file by path)
        frames = extract_frames_at_indices(video_file, frame_indices)

        # Create info string (markdown; kept flush-left so it is not indented
        # inside the rendered text)
        info = f"""
**Video Information:**
- Total frames: {total_frames}
- FPS: {video_fps:.2f}
- Duration: {video_duration:.2f} seconds
"""
        # Pair each frame copy with a caption naming its frame index
        labeled_frames = [
            (frame.copy(), f"Frame {idx} (Sample {i+1}/{num_frames})")
            for i, (frame, idx) in enumerate(zip(frames, frame_indices))
        ]
        return labeled_frames, info, frame_indices
    except Exception as e:
        # UI boundary: surface the problem as text instead of raising
        return [], f"Error processing video: {e}", []
# Create Gradio interface
# Builds the Blocks app bound to `demo`: an introductory markdown section,
# a row with an input column and an output column, and the button wiring.
# NOTE(review): the nesting below assumes both columns sit inside the single
# Row and that the click handler is registered at Blocks level - confirm
# against the deployed app.
with gr.Blocks(title="PATS: Proficiency-Aware Temporal Sampling for Multi-View Sports Skill Assessment") as demo:
    # Static introduction. The text lines are flush-left because they are part
    # of the string literal passed to gr.Markdown.
    gr.Markdown("""
# PATS: Proficiency-Aware Temporal Sampling for Multi-View Sports Skill Assessment
PATS (Proficiency-Aware Temporal Sampling) is a novel video sampling strategy designed specifically for automated sports skill assessment.
Unlike traditional methods that randomly sample frames or use uniform intervals, PATS preserves complete fundamental movements within continuous temporal segments.
The paper presenting PATS has been accepted at the 2025 4th IEEE Sport Technology and Research Workshop.
This tool showcases the PATS sampling strategy. Find out more at the project page: https://edowhite.github.io/PATS
## Core Concept
The key insight is that athletic proficiency manifests through structured temporal patterns that require observing complete, uninterrupted movements.
PATS addresses this by:
- **Extracting continuous temporal segments** rather than isolated frames
- **Preserving natural movement flow** essential for distinguishing expert from novice performance
- **Distributing multiple segments** across the video timeline to maximize information coverage
## Performance
When applied to SkillFormer on the EgoExo4D benchmark, PATS achieves:
- **Consistent improvements** across all viewing configurations (+0.65% to +3.05%)
- **Substantial domain-specific gains:** +26.22% in bouldering, +2.39% in music, +1.13% in basketball
""")
    with gr.Row():
        with gr.Column(scale=1):
            # Input components
            video_input = gr.Video(label="Upload Video")
            gr.Markdown("### Sampling Parameters")
            # Total frame count handed to process_video as `num_frames`
            num_frames = gr.Slider(
                minimum=1,
                maximum=50,
                value=8,
                step=1,
                label="Number of Frames to Sample",
                info="Total number of frames to extract from the video"
            )
            # Segment count handed to process_video as `num_segments`
            num_segments = gr.Slider(
                minimum=1,
                maximum=20,
                value=4,
                step=1,
                label="Number of Segments",
                info="Number of temporal segments to divide the video into"
            )
            # Per-segment length in seconds, handed over as `segment_duration`
            segment_duration = gr.Slider(
                minimum=0.5,
                maximum=10.0,
                value=2.0,
                step=0.5,
                label="Segment Duration (seconds)",
                info="Duration of each segment for sampling"
            )
            process_btn = gr.Button("Process Video", variant="primary")
        with gr.Column(scale=2):
            # Output components
            info_output = gr.Markdown(label="Processing Information")
            gallery_output = gr.Gallery(
                label="Sampled Frames",
                show_label=True,
                elem_id="gallery",
                columns=4,
                rows=3,
                height="auto"
            )
            # Created with visible=False; still listed as the third output of
            # the click handler below
            indices_output = gr.JSON(label="Frame Indices", visible=False)
    # Connect the processing function
    # `inputs` follows process_video's parameter order (video_file, num_frames,
    # segment_duration, num_segments), which differs from the order in which
    # the sliders are created above; `outputs` follows its return tuple
    # (frames, info, indices).
    process_btn.click(
        fn=process_video,
        inputs=[video_input, num_frames, segment_duration, num_segments],
        outputs=[gallery_output, info_output, indices_output]
    )
# Launch the app
# Only starts the server when the file is run as a script, not on import
if __name__ == "__main__":
    demo.launch()