"""
Gradio Web Interface for ACMDM Motion Generation
Milestone 3: User-friendly web interface for text-to-motion generation
"""
import os
import sys
from os.path import join as pjoin, dirname, abspath
# Add current directory to Python path for HuggingFace Spaces
# This ensures models/ and utils/ can be imported
current_dir = dirname(abspath(__file__))
if current_dir not in sys.path:
sys.path.insert(0, current_dir)
import torch
import numpy as np
import gradio as gr
from typing import Optional, Tuple, List
import tempfile
import random
# Import from models and utils
from models.AE_2D_Causal import AE_models
from models.ACMDM import ACMDM_models
from models.LengthEstimator import LengthEstimator
from utils.back_process import back_process
from utils.motion_process import plot_3d_motion, t2m_kinematic_chain
# Global variables for model caching
_models_cache = {
'ae': None,
'acmdm': None,
'length_estimator': None,
'stats': None,
'device': None,
'loaded': False
}
def set_seed(seed):
"""Set random seed for reproducibility"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.benchmark = False
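# Note on reproducibility: torch.manual_seed covers the default generators and
# cudnn.benchmark is disabled above. If fully deterministic GPU runs are needed,
# the following (a sketch, not part of the original function) could be added as well:
#   torch.cuda.manual_seed_all(seed)
#   torch.backends.cudnn.deterministic = True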
def load_models_cached(
gpu_id: int = 0,
model_name: str = 'ACMDM_Flow_S_PatchSize22',
ae_name: str = 'AE_2D_Causal',
ae_model: str = 'AE_Model',
model_type: str = 'ACMDM-Flow-S-PatchSize22',
dataset_name: str = 't2m',
checkpoints_dir: str = './checkpoints',
use_length_estimator: bool = False
) -> Tuple[dict, str]:
"""
Load models with caching to avoid reloading on every request.
Returns (models_dict, status_message)
"""
global _models_cache
# Check if models are already loaded
if _models_cache['loaded']:
return _models_cache, "Models already loaded (using cache)"
try:
# Determine device
if gpu_id >= 0 and torch.cuda.is_available():
device = torch.device(f"cuda:{gpu_id}")
else:
device = torch.device("cpu")
_models_cache['device'] = device
status_messages = [f"Using device: {device}"]
# Load AE
status_messages.append("Loading AE model...")
ae = AE_models[ae_model](input_width=3)
ae_ckpt_path = pjoin(checkpoints_dir, dataset_name, ae_name, 'model', 'latest.tar')
if not os.path.exists(ae_ckpt_path):
return None, f"Error: AE checkpoint not found at {ae_ckpt_path}"
ae_ckpt = torch.load(ae_ckpt_path, map_location='cpu')
ae.load_state_dict(ae_ckpt['ae'])
ae.eval()
ae.to(device)
_models_cache['ae'] = ae
status_messages.append(f"✓ Loaded AE from {ae_ckpt_path}")
# Load ACMDM
status_messages.append("Loading ACMDM model...")
acmdm = ACMDM_models[model_type](input_dim=ae.output_emb_width, cond_mode='text')
acmdm_ckpt_path = pjoin(checkpoints_dir, dataset_name, model_name, 'model', 'latest.tar')
if not os.path.exists(acmdm_ckpt_path):
return None, f"Error: ACMDM checkpoint not found at {acmdm_ckpt_path}"
acmdm_ckpt = torch.load(acmdm_ckpt_path, map_location='cpu')
missing_keys, unexpected_keys = acmdm.load_state_dict(acmdm_ckpt['ema_acmdm'], strict=False)
assert len(unexpected_keys) == 0
assert all([k.startswith('clip_model.') for k in missing_keys])
acmdm.eval()
acmdm.to(device)
_models_cache['acmdm'] = acmdm
status_messages.append(f"✓ Loaded ACMDM from {acmdm_ckpt_path}")
# Load LengthEstimator if needed
length_estimator = None
if use_length_estimator:
status_messages.append("Loading LengthEstimator...")
length_estimator = LengthEstimator(input_size=512, output_size=1)
length_estimator_path = pjoin(checkpoints_dir, dataset_name, 'length_estimator', 'model', 'latest.tar')
if os.path.exists(length_estimator_path):
length_ckpt = torch.load(length_estimator_path, map_location='cpu')
length_estimator.load_state_dict(length_ckpt['model'])
length_estimator.eval()
length_estimator.to(device)
_models_cache['length_estimator'] = length_estimator
status_messages.append(f"✓ Loaded LengthEstimator")
else:
status_messages.append(f"⚠ LengthEstimator not found, will use default length")
# Load normalization stats
status_messages.append("Loading normalization statistics...")
after_mean = np.load(pjoin(checkpoints_dir, dataset_name, ae_name, 'AE_2D_Causal_Post_Mean.npy'))
after_std = np.load(pjoin(checkpoints_dir, dataset_name, ae_name, 'AE_2D_Causal_Post_Std.npy'))
joint_mean = np.load(f'utils/22x3_mean_std/{dataset_name}/22x3_mean.npy')
joint_std = np.load(f'utils/22x3_mean_std/{dataset_name}/22x3_std.npy')
eval_mean = np.load(f'utils/eval_mean_std/{dataset_name}/eval_mean.npy')
eval_std = np.load(f'utils/eval_mean_std/{dataset_name}/eval_std.npy')
_models_cache['stats'] = {
'after_mean': after_mean,
'after_std': after_std,
'joint_mean': joint_mean,
'joint_std': joint_std,
'eval_mean': eval_mean,
'eval_std': eval_std
}
_models_cache['loaded'] = True
status_message = "\n".join(status_messages)
return _models_cache, status_message
except Exception as e:
error_msg = f"Error loading models: {str(e)}"
import traceback
error_msg += f"\n\nTraceback:\n{traceback.format_exc()}"
return None, error_msg
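# Example usage (hypothetical values; gpu_id=-1 forces CPU). On failure the first
# return value is None and the message contains the traceback:
#   models, msg = load_models_cached(gpu_id=-1, use_length_estimator=True)
#   if models is None:
#       print(msg)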
def estimate_motion_length(text: str, models_cache: dict) -> int:
"""Estimate motion length from text using LengthEstimator"""
if models_cache['length_estimator'] is None:
return None
device = models_cache['device']
acmdm = models_cache['acmdm']
length_estimator = models_cache['length_estimator']
with torch.no_grad():
text_emb = acmdm.encode_text([text])
pred_length = length_estimator(text_emb)
pred_length = int(pred_length.item() * 4)
pred_length = ((pred_length + 2) // 4) * 4
pred_length = max(40, min(196, pred_length))
return pred_length
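# Worked example of the length post-processing above (illustrative numbers):
#   pred_length.item() = 27.3  ->  int(27.3 * 4) = 109
#   ((109 + 2) // 4) * 4 = 108            # round to the nearest multiple of 4
#   max(40, min(196, 108)) = 108          # clamp to the supported frame range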
def generate_motion_single(
text: str,
motion_length: Optional[int],
cfg_scale: float,
use_auto_length: bool,
gpu_id: int,
seed: int
) -> Tuple[Optional[str], str]:
"""
Generate a single motion from text.
Returns (video_path, status_message)
"""
global _models_cache
try:
set_seed(seed)
# Load models if not cached
if not _models_cache['loaded']:
models_cache, load_msg = load_models_cached(
gpu_id=gpu_id,
use_length_estimator=use_auto_length
)
if models_cache is None:
return None, f"Failed to load models:\n{load_msg}"
else:
models_cache = _models_cache
device = models_cache['device']
ae = models_cache['ae']
acmdm = models_cache['acmdm']
stats = models_cache['stats']
# Estimate length if needed
if use_auto_length or motion_length is None or motion_length == 0:
if use_auto_length and models_cache['length_estimator'] is not None:
motion_length = estimate_motion_length(text, models_cache)
status_msg = f"Estimated motion length: {motion_length} frames\n"
else:
motion_length = 120 # Default
status_msg = f"Using default motion length: {motion_length} frames\n"
else:
# Round to multiple of 4
motion_length = ((motion_length + 2) // 4) * 4
status_msg = f"Using specified motion length: {motion_length} frames\n"
status_msg += f"Generating motion for: '{text}'...\n"
# Generate motion
with torch.no_grad():
latent_length = motion_length // 4
m_lens = torch.tensor([latent_length], device=device)
pred_latents = acmdm.generate([text], m_lens, cfg_scale)
# Denormalize latents
pred_latents_np = pred_latents.permute(0, 2, 3, 1).detach().cpu().numpy()
pred_latents_np = pred_latents_np * stats['after_std'] + stats['after_mean']
pred_latents_tensor = torch.from_numpy(pred_latents_np).to(device)
# Decode through AE
pred_motions = ae.decode(pred_latents_tensor.permute(0, 3, 1, 2))
# Denormalize motions
pred_motions_np = pred_motions.permute(0, 2, 3, 1).detach().cpu().numpy()
if stats['joint_mean'].ndim == 1:
pred_motions_np = pred_motions_np * stats['joint_std'][np.newaxis, np.newaxis, :, np.newaxis] + stats['joint_mean'][np.newaxis, np.newaxis, :, np.newaxis]
else:
pred_motions_np = pred_motions_np * stats['joint_std'][np.newaxis, ..., np.newaxis] + stats['joint_mean'][np.newaxis, ..., np.newaxis]
# Back process to get RIC format, then recover joint positions
from utils.motion_process import recover_from_ric
motion = pred_motions_np[0] # (22, 3, seq_len)
motion = motion[:, :, :motion_length].transpose(2, 0, 1) # (seq_len, 22, 3)
ric_data = back_process(motion, is_mesh=False)
ric_tensor = torch.from_numpy(ric_data).float()
joints = recover_from_ric(ric_tensor, joints_num=22).numpy() # (seq_len, 22, 3)
# Create temporary file for video
temp_dir = tempfile.mkdtemp()
os.makedirs(temp_dir, exist_ok=True)
video_path = pjoin(temp_dir, 'motion.mp4')
# Generate video - plot_3d_motion returns the actual path (may be .mp4 or .gif)
try:
actual_video_path = plot_3d_motion(
video_path,
t2m_kinematic_chain,
joints,
title=text,
fps=20,
radius=4
)
except Exception as e:
status_msg += f"\nโš  Error in plot_3d_motion: {str(e)}"
raise
# Verify file exists - check both the returned path and original path
if not os.path.exists(actual_video_path):
# Check if file exists with original path
if os.path.exists(video_path):
actual_video_path = video_path
# Check if it was saved as GIF instead
elif os.path.exists(video_path.replace('.mp4', '.gif')):
actual_video_path = video_path.replace('.mp4', '.gif')
else:
# List files in temp directory for debugging
files_in_dir = os.listdir(temp_dir) if os.path.exists(temp_dir) else []
error_msg = f"Video file not found. Expected: {actual_video_path}\n"
error_msg += f"Files in temp directory: {files_in_dir}"
status_msg += f"\nโŒ {error_msg}"
raise FileNotFoundError(error_msg)
status_msg += f"\nโœ“ Motion generated successfully! Video saved."
return actual_video_path, status_msg
except Exception as e:
error_msg = f"Error generating motion: {str(e)}"
import traceback
error_msg += f"\n\nTraceback:\n{traceback.format_exc()}"
return None, error_msg
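# Example usage outside the UI (hypothetical prompt; gpu_id=-1 runs on CPU):
#   video_path, status = generate_motion_single(
#       "A person walks forward and then turns around.",
#       motion_length=120, cfg_scale=3.0, use_auto_length=False, gpu_id=-1, seed=3407)
#   print(status)        # progress / error log
#   print(video_path)    # path to the rendered .mp4 (or .gif), or None on failure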
# Batch generation feature commented out
# def generate_motion_batch(
# text_file_content: str,
# cfg_scale: float,
# use_auto_length: bool,
# gpu_id: int,
# seed: int
# ) -> Tuple[List[Optional[str]], str]:
# """
# Generate motions from a batch of text prompts.
# Returns (list_of_video_paths, status_message)
# """
# # Parse text file
# text_prompts = []
# for line in text_file_content.strip().split('\n'):
# line = line.strip()
# if not line:
# continue
# if '#' in line:
# parts = line.split('#')
# text = parts[0].strip()
# length_str = parts[1].strip() if len(parts) > 1 else 'NA'
# else:
# text = line
# length_str = 'NA'
#
# if length_str.upper() == 'NA':
# motion_length = None if use_auto_length else 120
# else:
# try:
# motion_length = int(length_str)
# motion_length = ((motion_length + 2) // 4) * 4
# except:
# motion_length = None if use_auto_length else 120
#
# text_prompts.append((text, motion_length))
#
# if not text_prompts:
# return [], "No valid prompts found in file"
#
# status_msg = f"Processing {len(text_prompts)} prompts...\n"
# video_paths = []
#
# # Create a persistent directory for batch videos (not temp, so Gradio can access them)
# # Use a directory that Gradio can serve from
# batch_output_dir = pjoin('generation', 'batch_temp')
# os.makedirs(batch_output_dir, exist_ok=True)
#
# # Also ensure the parent directory exists
# os.makedirs('generation', exist_ok=True)
#
# for idx, (text, motion_length) in enumerate(text_prompts):
# try:
# video_path, gen_msg = generate_motion_single(
# text, motion_length, cfg_scale, use_auto_length, gpu_id, seed + idx
# )
#
# # If video was created, copy it to persistent location for batch gallery
# if video_path and os.path.exists(video_path):
# # Create a unique filename for this batch item
# # Get file extension from original path
# file_ext = os.path.splitext(video_path)[1] or '.mp4'
# batch_video_path = pjoin(batch_output_dir, f'motion_{idx:04d}{file_ext}')
#
# # Copy file to persistent location
# import shutil
# try:
# shutil.copy2(video_path, batch_video_path)
#
# # Verify the copied file exists
# if os.path.exists(batch_video_path):
# # Use relative path for Gradio (works better in HuggingFace Spaces)
# # Gradio can serve files from relative paths within the app directory
# video_paths.append(batch_video_path)
# file_size = os.path.getsize(batch_video_path)
# status_msg += f"\n[{idx+1}/{len(text_prompts)}] โœ“ {text[:50]}... - Video saved ({file_size/1024:.1f} KB)"
# else:
# status_msg += f"\n[{idx+1}/{len(text_prompts)}] โš  {text[:50]}... - Video copy failed (file not found after copy)"
# video_paths.append(None)
# except Exception as copy_error:
# status_msg += f"\n[{idx+1}/{len(text_prompts)}] โš  {text[:50]}... - Copy error: {str(copy_error)}"
# video_paths.append(None)
# else:
# status_msg += f"\n[{idx+1}/{len(text_prompts)}] โŒ {text[:50]}... - Generation failed (no video path returned)"
# video_paths.append(None)
#
# except Exception as e:
# status_msg += f"\n[{idx+1}/{len(text_prompts)}] โŒ Error: {str(e)}"
# import traceback
# status_msg += f"\n Traceback: {traceback.format_exc()[:200]}"
# video_paths.append(None)
#
# # Filter out None values and verify files exist - Gradio Gallery needs valid paths
# valid_video_paths = []
# for path in video_paths:
# if path is not None:
# # Use relative path (Gradio handles these better in cloud environments)
# # Verify file exists
# if os.path.exists(path):
# # Verify it's a video file
# if path.lower().endswith(('.mp4', '.gif', '.mov', '.avi')):
# # Get file size for debugging
# file_size = os.path.getsize(path)
# if file_size > 0: # Ensure file is not empty
# valid_video_paths.append(path) # Keep relative path
# status_msg += f"\n โœ“ Verified: {os.path.basename(path)} ({file_size/1024:.1f} KB)"
# else:
# status_msg += f"\n โš  Empty file: {path}"
# else:
# status_msg += f"\n โš  Skipping non-video file: {path}"
# else:
# # Try absolute path as fallback
# abs_path = os.path.abspath(path)
# if os.path.exists(abs_path):
# valid_video_paths.append(abs_path)
# status_msg += f"\n โœ“ Found via absolute path: {abs_path}"
# else:
# status_msg += f"\n โš  File not found (relative: {path}, absolute: {abs_path})"
#
# if len(valid_video_paths) == 0:
# status_msg += "\n\nโŒ No videos were successfully generated for the gallery."
# status_msg += "\n๐Ÿ’ก Check the error messages above for each prompt."
# # Return empty list explicitly
# return [], status_msg
# else:
# status_msg += f"\n\nโœ… Successfully generated {len(valid_video_paths)}/{len(text_prompts)} motions."
# status_msg += f"\n๐Ÿ“ Videos saved in: {batch_output_dir}"
# status_msg += f"\n๐Ÿ“‹ Gallery will display {len(valid_video_paths)} video(s)."
#
# # Debug: Show first few paths
# status_msg += f"\n\n๐ŸŽฌ First few video paths:\n"
# for i, p in enumerate(valid_video_paths[:3]):
# abs_p = os.path.abspath(p) if not os.path.isabs(p) else p
# exists = "โœ“" if os.path.exists(p) else "โœ—"
# status_msg += f" {exists} [{i+1}] {p} (exists: {os.path.exists(p)})\n"
# if len(valid_video_paths) > 3:
# status_msg += f" ... and {len(valid_video_paths) - 3} more\n"
#
# # Return list of file paths (Gradio Gallery expects list of strings)
# # Ensure all paths are valid and accessible
# final_paths = []
# for p in valid_video_paths:
# if os.path.exists(p):
# # Normalize path (use forward slashes for cross-platform compatibility)
# normalized_path = p.replace('\\', '/')
# final_paths.append(normalized_path)
# else:
# status_msg += f"\nโš  Warning: Path not found when returning: {p}"
#
# if len(final_paths) != len(valid_video_paths):
# status_msg += f"\nโš  Some paths were filtered out. Returning {len(final_paths)}/{len(valid_video_paths)} videos."
#
# return final_paths, status_msg
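# Example input for the disabled batch function above, one prompt per line
# ('#NA' asks for an auto-estimated length, mirroring the placeholder used in the UI):
#   A person is running on a treadmill.#120
#   Someone is doing jumping jacks.#NA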
# Gradio Interface
def create_interface():
"""Create and configure the Gradio interface"""
with gr.Blocks(title="ACMDM Motion Generation", theme=gr.themes.Soft()) as app:
gr.Markdown("""
# 🎭 ACMDM Motion Generation
Generate human motion from text descriptions using the ACMDM (Absolute Coordinates Make Motion Generation Easy) model.
**How to use:**
1. Enter a text description of the motion you want to generate
2. Adjust motion length (or use auto-estimate)
3. Click "Generate Motion" to create the animation
4. View and download the generated video
**Example prompts:**
- "A person is running on a treadmill."
- "Someone is doing jumping jacks."
- "A person walks forward and then turns around."
""")
with gr.Tabs():
# Single Generation Tab
with gr.Tab("Single Motion Generation"):
with gr.Row():
with gr.Column(scale=1):
text_input = gr.Textbox(
label="Motion Description",
placeholder="A person is running on a treadmill.",
lines=3,
value="A person is running on a treadmill."
)
with gr.Row():
motion_length = gr.Slider(
label="Motion Length (frames)",
minimum=40,
maximum=196,
value=120,
step=4,
info="Will be rounded to nearest multiple of 4"
)
use_auto_length = gr.Checkbox(
label="Auto-estimate length",
value=False,
info="Use length estimator (ignores manual length if checked)"
)
cfg_scale = gr.Slider(
label="CFG Scale",
minimum=1.0,
maximum=10.0,
value=3.0,
step=0.5,
info="Classifier-free guidance scale (higher = more aligned with text)"
)
with gr.Row():
gpu_id = gr.Number(
label="GPU ID",
value=0,
precision=0,
info="Use -1 for CPU"
)
seed = gr.Number(
label="Random Seed",
value=3407,
precision=0
)
generate_btn = gr.Button("Generate Motion", variant="primary", size="lg")
with gr.Column(scale=1):
video_output = gr.Video(
label="Generated Motion",
format="mp4"
)
status_output = gr.Textbox(
label="Status",
lines=10,
interactive=False
)
# Update model status after generation
def generate_and_update_status(text, motion_length, cfg_scale, use_auto_length, gpu_id, seed):
video_path, status_msg = generate_motion_single(
text, motion_length, cfg_scale, use_auto_length, gpu_id, seed
)
# Return video path and status message
return video_path, status_msg
generate_btn.click(
fn=generate_and_update_status,
inputs=[text_input, motion_length, cfg_scale, use_auto_length, gpu_id, seed],
outputs=[video_output, status_output]
)
# Batch Generation Tab - COMMENTED OUT
# with gr.Tab("Batch Generation"):
# with gr.Row():
# with gr.Column(scale=1):
# batch_text_input = gr.Textbox(
# label="Text Prompts (one per line, format: text#length or text#NA)",
# placeholder="A person is running on a treadmill.#120\nSomeone is doing jumping jacks.#NA",
# lines=10,
# info="Each line: 'text#length' or 'text#NA' for auto-estimate"
# )
#
# batch_cfg_scale = gr.Slider(
# label="CFG Scale",
# minimum=1.0,
# maximum=10.0,
# value=3.0,
# step=0.5
# )
#
# batch_use_auto_length = gr.Checkbox(
# label="Auto-estimate length for NA",
# value=True
# )
#
# batch_gpu_id = gr.Number(
# label="GPU ID",
# value=0,
# precision=0
# )
#
# batch_seed = gr.Number(
# label="Random Seed",
# value=3407,
# precision=0
# )
#
# batch_generate_btn = gr.Button("Generate Batch", variant="primary", size="lg")
#
# with gr.Column(scale=1):
# batch_status_output = gr.Textbox(
# label="Batch Status",
# lines=15,
# interactive=False
# )
# batch_video_gallery = gr.Gallery(
# label="Generated Motions",
# show_label=True,
# elem_id="gallery",
# columns=2,
# rows=2,
# height="auto",
# type="filepath" # Explicitly specify filepath type
# )
#
# batch_generate_btn.click(
# fn=generate_motion_batch,
# inputs=[batch_text_input, batch_cfg_scale, batch_use_auto_length, batch_gpu_id, batch_seed],
# outputs=[batch_video_gallery, batch_status_output]
# )
# Model Management Tab
with gr.Tab("Model Management"):
with gr.Row():
with gr.Column():
model_status = gr.Textbox(
label="Model Status",
lines=15,
interactive=False,
value="โณ Models not loaded yet. They will be loaded automatically on first generation."
)
with gr.Row():
refresh_status_btn = gr.Button("🔄 Refresh Status", variant="primary")
reload_models_btn = gr.Button("🗑️ Clear Cache", variant="secondary")
gr.Markdown("""
**Model Configuration:**
- Model Name: `ACMDM_Flow_S_PatchSize22`
- AE Name: `AE_2D_Causal`
- Dataset: `t2m`
- Checkpoints Directory: `./checkpoints`
""")
def check_model_status():
"""Check and display current model status"""
global _models_cache
if _models_cache['loaded']:
device = _models_cache['device']
status_lines = [
"โœ“ MODELS LOADED AND READY",
"=" * 50,
f"๐Ÿ“ฑ Device: {device}",
f"๐Ÿ’พ CUDA Available: {torch.cuda.is_available()}",
]
if device.type == 'cuda':
status_lines.append(f"๐ŸŽฎ GPU: {torch.cuda.get_device_name(device.index if device.index is not None else 0)}")
status_lines.append(f"๐Ÿ’พ GPU Memory: {torch.cuda.get_device_properties(device.index if device.index is not None else 0).total_memory / 1e9:.2f} GB")
status_lines.extend([
"",
"๐Ÿ“ฆ Loaded Models:",
" โœ“ Autoencoder (AE_2D_Causal)",
" โœ“ ACMDM Diffusion Model",
])
if _models_cache['length_estimator'] is not None:
status_lines.append(" โœ“ Length Estimator")
else:
status_lines.append(" โš  Length Estimator (not loaded)")
status_lines.extend([
"",
"๐Ÿ“Š Statistics Loaded:",
" โœ“ Post-AE Mean/Std",
" โœ“ Joint Mean/Std",
" โœ“ Eval Mean/Std",
"",
"โœจ Models are cached and ready for generation!",
"๐Ÿ’ก Tip: Use 'Clear Cache' to force reload on next generation."
])
return "\n".join(status_lines)
else:
return (
"โณ Models not loaded yet. They will be loaded automatically on first generation.\n"
"๐Ÿ“ To load models now, go to 'Single Motion Generation' tab and click 'Generate Motion'.\n"
"โฑ๏ธ First generation will take 30-60 seconds (model loading time).\n"
"โšก Subsequent generations will be much faster (5-15 seconds)."
)
def reload_models():
"""Clear model cache"""
global _models_cache
_models_cache['loaded'] = False
_models_cache['ae'] = None
_models_cache['acmdm'] = None
_models_cache['length_estimator'] = None
_models_cache['stats'] = None
_models_cache['device'] = None
return (
"๐Ÿ—‘๏ธ Model cache cleared. Models will be automatically reloaded on your next generation request.\n"
"๐Ÿ’ก Click 'Refresh Status' after generating to see updated status."
)
# Button callbacks
refresh_status_btn.click(
fn=check_model_status,
outputs=model_status
)
reload_models_btn.click(
fn=reload_models,
outputs=model_status
)
# Update status on tab load
app.load(
fn=check_model_status,
outputs=model_status
)
gr.Markdown("""
---
**Note:** First generation may take longer as models need to be loaded. Subsequent generations will be faster.
**Tips:**
- Use descriptive text prompts for better results
- Adjust CFG scale: higher values (3-5) for more text alignment, lower values (1-2) for more diversity
- Motion length should be a multiple of 4 (automatically rounded)
- Batch generation is currently disabled; when enabled, prompts use the format: `text description#length` or `text description#NA`
""")
return app
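# Example: build the interface without launching it here (e.g. for tests or a custom
# launch). queue() is optional; the suggestion that it serializes requests so the
# global model cache is not hit concurrently is an assumption about typical
# single-GPU deployments, not part of the original script:
#   demo = create_interface()
#   demo.queue().launch(share=True)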
if __name__ == "__main__":
# Create and launch the interface
app = create_interface()
# Launch with sharing option
app.launch(
server_name="0.0.0.0", # Allow external connections
server_port=7860, # Default Gradio port
share=False, # Set to True for public link (requires ngrok)
show_error=True
)