File size: 12,723 Bytes
a7bcb92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6f9f28
a7bcb92
e6f9f28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7bcb92
e6f9f28
 
 
a7bcb92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
import gradio as gr
import av
import numpy as np
from PIL import Image
import tempfile
import os

def sample_frame_indices(num_frames, fps, total_frames):
    """
    Uniformly spaced frame indices; the simple fallback sampler.

    Args:
        num_frames (int): How many indices to produce
        fps (float): Unused; kept so the signature stays compatible
        total_frames (int): Number of frames in the video

    Returns:
        list: Every index ``0 .. total_frames - 1`` when the video has at most
        ``num_frames`` frames; otherwise ``num_frames`` integer indices evenly
        spaced from 0 to ``total_frames - 1`` inclusive (rounded down).
    """
    if total_frames > num_frames:
        # Evenly spaced positions over the full video, cast to integers
        evenly_spaced = np.linspace(0, total_frames - 1, num=num_frames, dtype=int)
        return [int(position) for position in evenly_spaced]

    # Short video: nothing to choose, every frame is used
    return list(range(total_frames))

def sample_frame_indices_efficient_segments(num_frames, segment_duration, num_segments, container):
    """
    Enhanced frame sampling strategy that distributes frames across temporal segments
    of the video for better temporal coverage and content diversity.

    ``num_segments`` windows of at most ``segment_duration`` seconds are laid out
    evenly from the start to the end of the video. The requested frames are split
    as evenly as possible across the windows (earlier windows receive the
    remainder) and spread uniformly inside each window. When this layout is not
    possible (too few frames, no usable frame rate, windows shorter than 0.5 s)
    the function falls back to uniform sampling via ``sample_frame_indices``.

    Args:
        num_frames (int): Total number of frames to sample (floats are truncated)
        segment_duration (float): Requested duration of each segment in seconds;
            capped at 80% of each segment's share of the video duration
        num_segments (int): Number of segments to sample from. Clamped to the
            range [1, num_frames] so every segment receives at least one frame.
        container (av.container): PyAV container object; only the metadata of
            its first video stream is read

    Returns:
        list: Sorted frame indices (Python ints). The segment-based path returns
        exactly ``num_frames`` indices, which may repeat when a segment holds
        fewer frames than its quota; the uniform fallback returns fewer when the
        video has fewer than ``num_frames`` frames. ``[]`` when ``num_frames``
        is not positive.
    """
    num_frames = int(num_frames)
    if num_frames <= 0:
        return []

    # Get video properties. The rate may be missing (None) and the frame count
    # may be 0 when the container does not record them; treat both as unknown.
    video_stream = container.streams.video[0]
    rate = video_stream.average_rate
    video_fps = float(rate) if rate else 0.0
    total_video_frames = video_stream.frames or 0

    # Without a positive frame rate, seconds cannot be mapped to frame indices
    if video_fps <= 0:
        return sample_frame_indices(num_frames, 4, total_video_frames)

    video_duration = total_video_frames / video_fps

    # Fallback to uniform sampling if the video is too short or has no duration
    if total_video_frames < num_frames or video_duration <= 0:
        return sample_frame_indices(num_frames, 4, total_video_frames)

    # Segment start times are laid out before frame quotas are assigned, so a
    # segment with a zero quota would leave an unsampled gap (previously the
    # trailing segments went empty, crowding every sample into the start of the
    # video). Never use more segments than frames, and at least one segment.
    num_segments = max(1, min(int(num_segments), num_frames))

    # Frames per segment: the first ``extra_frames`` segments get one extra
    # frame so that the quotas sum to exactly num_frames
    base_frames_per_segment, extra_frames = divmod(num_frames, num_segments)

    # Cap the segment length below each segment's share of the video so that
    # consecutive segments never overlap
    max_segment_duration = video_duration / num_segments * 0.8
    effective_segment_duration = min(segment_duration, max_segment_duration)

    # If segments would be too small, fall back to uniform sampling
    if effective_segment_duration < 0.5:
        return sample_frame_indices(num_frames, 4, total_video_frames)

    # Distribute segment starts evenly; the last segment ends at the video end
    if num_segments == 1:
        segment_starts = [0]
    else:
        max_start_time = max(0, video_duration - effective_segment_duration)
        segment_starts = np.linspace(0, max_start_time, num_segments)

    all_indices = []
    for i, start_time in enumerate(segment_starts):
        # Always >= 1 because num_segments <= num_frames
        segment_frames = base_frames_per_segment + (1 if i < extra_frames else 0)

        # Convert the segment's time window to a half-open frame range
        start_frame = int(start_time * video_fps)
        end_frame = min(int((start_time + effective_segment_duration) * video_fps), total_video_frames)

        # Ensure we have a valid range (at least 0.5 seconds)
        if start_frame >= end_frame:
            end_frame = min(start_frame + int(0.5 * video_fps), total_video_frames)

        if segment_frames == 1:
            # Single frame: take the middle of the segment
            frame_idx = start_frame + (end_frame - start_frame) // 2
            segment_indices = [min(frame_idx, total_video_frames - 1)]
        elif end_frame - start_frame <= segment_frames:
            # Segment too short: take every available frame and repeat them
            # cyclically until the quota is met
            available_frames = list(range(start_frame, end_frame))
            segment_indices = (
                [available_frames[j % len(available_frames)] for j in range(segment_frames)]
                if available_frames
                else []
            )
        else:
            # Multiple frames: distribute evenly within the segment
            segment_indices = np.linspace(start_frame, end_frame - 1, segment_frames, dtype=int).tolist()

        all_indices.extend(segment_indices)

    all_indices = np.asarray(all_indices, dtype=int)

    # Ensure we have exactly num_frames - this is critical
    if len(all_indices) > num_frames:
        # Too many frames: keep num_frames of them, evenly spaced
        step = len(all_indices) / num_frames
        all_indices = all_indices[[int(i * step) for i in range(num_frames)]]
    elif len(all_indices) < num_frames:
        if len(all_indices) == 0:
            return sample_frame_indices(num_frames, 4, total_video_frames)
        # Too few frames: np.resize repeats the collected indices cyclically
        all_indices = np.resize(all_indices, num_frames)

    # Clamp into the valid range and restore temporal order
    all_indices = np.sort(np.clip(all_indices, 0, total_video_frames - 1))

    # Internal invariant of the segment path
    assert len(all_indices) == num_frames, f"Expected {num_frames} frames, got {len(all_indices)}"

    return all_indices.tolist()

def extract_frames_at_indices(video_path, frame_indices):
    """
    Extract frames from video at specified indices.

    The first video stream is decoded sequentially from the start while counting
    decoded frames; decoding stops as soon as every requested index has been seen.

    Args:
        video_path (str): Path to video file
        frame_indices (list): Frame indices to extract, in the order the images
            should be returned. Duplicates are allowed and each one yields an image.

    Returns:
        list: PIL Images, one per entry of ``frame_indices`` and in the same
        order, so the result can be zipped with ``frame_indices``. Indices the
        decoder never reaches (e.g. past the real end of the video) are omitted.
    """
    wanted = set(frame_indices)
    if not wanted:
        # Nothing requested: avoid opening and decoding the whole video
        return []

    decoded = {}
    # ``with`` closes the container even if decoding raises
    with av.open(video_path) as container:
        for frame_idx, frame in enumerate(container.decode(video=0)):
            if frame_idx in wanted:
                decoded[frame_idx] = frame.to_image()
                wanted.discard(frame_idx)
                # Stop if we've collected all frames
                if not wanted:
                    break

    # Expand back to one image per requested index (duplicates included)
    return [decoded[idx] for idx in frame_indices if idx in decoded]

def process_video(video_file, num_frames, segment_duration, num_segments):
    """
    Main processing function for Gradio interface.

    Reads the stream metadata, chooses frame indices with
    ``sample_frame_indices_efficient_segments``, decodes those frames and pairs
    each one with a caption for the gallery.

    Args:
        video_file: Path of the uploaded video (``None`` when nothing was uploaded)
        num_frames (int): Number of frames to sample (floats are truncated to int)
        segment_duration (float): Duration of each segment in seconds
        num_segments (int): Number of segments (floats are truncated to int)

    Returns:
        tuple: (list of (image, caption) pairs, Markdown info string, list of
        sampled frame indices). On failure the lists are empty and the info
        string carries the error message.
    """
    if video_file is None:
        return [], "Please upload a video file", []

    # Broad catch on purpose: this is the UI boundary, so any failure is shown
    # to the user in the info panel instead of crashing the request.
    try:
        # Slider values may arrive as floats; the sampler needs integer counts
        num_frames = int(num_frames)
        num_segments = int(num_segments)

        # ``with`` closes the container even if reading metadata or sampling raises
        with av.open(video_file) as container:
            video_stream = container.streams.video[0]

            # Get video info; a missing average rate is reported as 0 FPS
            rate = video_stream.average_rate
            video_fps = float(rate) if rate else 0.0
            total_frames = video_stream.frames
            video_duration = total_frames / video_fps if video_fps > 0 else 0

            # Get frame indices using the sampling function
            frame_indices = sample_frame_indices_efficient_segments(
                num_frames, segment_duration, num_segments, container
            )

        # Extract frames at selected indices (re-opens the file for decoding)
        frames = extract_frames_at_indices(video_file, frame_indices)

        # Create info string
        info = f"""
        **Video Information:**
        - Total frames: {total_frames}
        - FPS: {video_fps:.2f}
        - Duration: {video_duration:.2f} seconds
        """

        # Caption each image; the denominator is the number of frames actually
        # extracted, which can be below the requested count for short videos
        sample_count = len(frames)
        labeled_frames = [
            (frame.copy(), f"Frame {idx} (Sample {i + 1}/{sample_count})")
            for i, (frame, idx) in enumerate(zip(frames, frame_indices))
        ]

        return labeled_frames, info, frame_indices

    except Exception as e:
        return [], f"Error processing video: {str(e)}", []

# Create Gradio interface
# ``demo`` is the Gradio Blocks app: a Markdown description of PATS on top,
# then a row with two columns (upload + sampling controls on the left, results
# on the right). Components are placed in the order they are created below,
# so the statement order here defines the page layout.
with gr.Blocks(title="PATS: Proficiency-Aware Temporal Sampling for Multi-View Sports Skill Assessment") as demo:
    # Page header / project description.
    # NOTE(review): the text is indented inside the string literal; this relies
    # on gr.Markdown stripping the common leading whitespace (otherwise lines
    # indented 4+ spaces would render as code blocks) -- confirm for the Gradio
    # version in use.
    gr.Markdown("""
    # PATS: Proficiency-Aware Temporal Sampling for Multi-View Sports Skill Assessment

    PATS (Proficiency-Aware Temporal Sampling) is a novel video sampling strategy designed specifically for automated sports skill assessment. 
    Unlike traditional methods that randomly sample frames or use uniform intervals, PATS preserves complete fundamental movements within continuous temporal segments. 
    The paper presenting PATS has been accepted at the 2025 4th IEEE Sport Technology and Research Workshop.

    This tool showcases the PATS sampling strategy. Find out more at the project page: https://edowhite.github.io/PATS

    ## Core Concept
    The key insight is that athletic proficiency manifests through structured temporal patterns that require observing complete, uninterrupted movements. 
    PATS addresses this by:

    - **Extracting continuous temporal segments** rather than isolated frames
    - **Preserving natural movement flow** essential for distinguishing expert from novice performance
    - **Distributing multiple segments** across the video timeline to maximize information coverage

    ## Performance
    When applied to SkillFormer on the EgoExo4D benchmark, PATS achieves:
    
    - **Consistent improvements** across all viewing configurations (+0.65% to +3.05%)
    - **Substantial domain-specific gains:** +26.22% in bouldering, +2.39% in music, +1.13% in basketball

    """)

    with gr.Row():
        # Left column (relative scale 1): video upload, sampling controls, run button
        with gr.Column(scale=1):
            # Input components
            video_input = gr.Video(label="Upload Video")

            gr.Markdown("### Sampling Parameters")
            # Total frames requested from the sampler: 1-50, default 8.
            # NOTE(review): this module-level name shadows nothing inside
            # process_video (it uses its own parameter), but it is easy to
            # confuse with the function argument of the same name.
            num_frames = gr.Slider(
                minimum=1,
                maximum=50,
                value=8,
                step=1,
                label="Number of Frames to Sample",
                info="Total number of frames to extract from the video"
            )

            # Number of temporal segments: 1-20, default 4
            num_segments = gr.Slider(
                minimum=1,
                maximum=20,
                value=4,
                step=1,
                label="Number of Segments",
                info="Number of temporal segments to divide the video into"
            )

            # Requested seconds per segment: 0.5-10.0 in 0.5 steps, default 2.0
            segment_duration = gr.Slider(
                minimum=0.5,
                maximum=10.0,
                value=2.0,
                step=0.5,
                label="Segment Duration (seconds)",
                info="Duration of each segment for sampling"
            )

            process_btn = gr.Button("Process Video", variant="primary")

        # Right column (relative scale 2 vs. the left column's 1): results
        with gr.Column(scale=2):
            # Output components
            # Receives the Markdown info string returned by process_video
            info_output = gr.Markdown(label="Processing Information")
            # Receives the (image, caption) pairs returned by process_video
            gallery_output = gr.Gallery(
                label="Sampled Frames",
                show_label=True,
                elem_id="gallery",
                columns=4,
                rows=3,
                height="auto"
            )
            # Hidden holder for the list of sampled frame indices
            indices_output = gr.JSON(label="Frame Indices", visible=False)

    # Connect the processing function.
    # ``inputs`` order must match process_video's parameters
    # (video_file, num_frames, segment_duration, num_segments), and ``outputs``
    # order must match its returned tuple (labeled frames, info, indices).
    process_btn.click(
        fn=process_video,
        inputs=[video_input, num_frames, segment_duration, num_segments],
        outputs=[gallery_output, info_output, indices_output]
    )

# Launch the app
# Start the server only when this file is run as a script, not when imported
if __name__ == "__main__":
    demo.launch()