| """ | |
| Gradio Web Interface for ACMDM Motion Generation | |
| Milestone 3: User-friendly web interface for text-to-motion generation | |
| """ | |
| import os | |
| import sys | |
| from os.path import join as pjoin, dirname, abspath | |
| # Add current directory to Python path for HuggingFace Spaces | |
| # This ensures models/ and utils/ can be imported | |
| current_dir = dirname(abspath(__file__)) | |
| if current_dir not in sys.path: | |
| sys.path.insert(0, current_dir) | |
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| from typing import Optional, Tuple, List | |
| import tempfile | |
| import random | |
| # Import from models and utils | |
| from models.AE_2D_Causal import AE_models | |
| from models.ACMDM import ACMDM_models | |
| from models.LengthEstimator import LengthEstimator | |
| from utils.back_process import back_process | |
| from utils.motion_process import plot_3d_motion, t2m_kinematic_chain | |
| # Global variables for model caching | |
| _models_cache = { | |
| 'ae': None, | |
| 'acmdm': None, | |
| 'length_estimator': None, | |
| 'stats': None, | |
| 'device': None, | |
| 'loaded': False | |
| } | |


def set_seed(seed):
    """Set random seed for reproducibility"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
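

# Note: torch.manual_seed also seeds the CUDA RNGs, and setting
# torch.backends.cudnn.benchmark = False keeps cuDNN from autotuning to different
# convolution kernels between runs, which helps keep generations reproducible.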


def load_models_cached(
    gpu_id: int = 0,
    model_name: str = 'ACMDM_Flow_S_PatchSize22',
    ae_name: str = 'AE_2D_Causal',
    ae_model: str = 'AE_Model',
    model_type: str = 'ACMDM-Flow-S-PatchSize22',
    dataset_name: str = 't2m',
    checkpoints_dir: str = './checkpoints',
    use_length_estimator: bool = False
) -> Tuple[dict, str]:
    """
    Load models with caching to avoid reloading on every request.
    Returns (models_dict, status_message)
    """
    global _models_cache

    # Check if models are already loaded
    if _models_cache['loaded']:
        return _models_cache, "Models already loaded (using cache)"

    try:
        # Determine device
        if gpu_id >= 0 and torch.cuda.is_available():
            device = torch.device(f"cuda:{gpu_id}")
        else:
            device = torch.device("cpu")
        _models_cache['device'] = device
        status_messages = [f"Using device: {device}"]

        # Load AE
        status_messages.append("Loading AE model...")
        ae = AE_models[ae_model](input_width=3)
        ae_ckpt_path = pjoin(checkpoints_dir, dataset_name, ae_name, 'model', 'latest.tar')
        if not os.path.exists(ae_ckpt_path):
            return None, f"Error: AE checkpoint not found at {ae_ckpt_path}"
        ae_ckpt = torch.load(ae_ckpt_path, map_location='cpu')
        ae.load_state_dict(ae_ckpt['ae'])
        ae.eval()
        ae.to(device)
        _models_cache['ae'] = ae
        status_messages.append(f"✓ Loaded AE from {ae_ckpt_path}")

        # Load ACMDM
        status_messages.append("Loading ACMDM model...")
        acmdm = ACMDM_models[model_type](input_dim=ae.output_emb_width, cond_mode='text')
        acmdm_ckpt_path = pjoin(checkpoints_dir, dataset_name, model_name, 'model', 'latest.tar')
        if not os.path.exists(acmdm_ckpt_path):
            return None, f"Error: ACMDM checkpoint not found at {acmdm_ckpt_path}"
        acmdm_ckpt = torch.load(acmdm_ckpt_path, map_location='cpu')
        missing_keys, unexpected_keys = acmdm.load_state_dict(acmdm_ckpt['ema_acmdm'], strict=False)
        assert len(unexpected_keys) == 0
        assert all([k.startswith('clip_model.') for k in missing_keys])
        acmdm.eval()
        acmdm.to(device)
        _models_cache['acmdm'] = acmdm
        status_messages.append(f"✓ Loaded ACMDM from {acmdm_ckpt_path}")

        # Load LengthEstimator if needed
        length_estimator = None
        if use_length_estimator:
            status_messages.append("Loading LengthEstimator...")
            length_estimator = LengthEstimator(input_size=512, output_size=1)
            length_estimator_path = pjoin(checkpoints_dir, dataset_name, 'length_estimator', 'model', 'latest.tar')
            if os.path.exists(length_estimator_path):
                length_ckpt = torch.load(length_estimator_path, map_location='cpu')
                length_estimator.load_state_dict(length_ckpt['model'])
                length_estimator.eval()
                length_estimator.to(device)
                _models_cache['length_estimator'] = length_estimator
                status_messages.append("✓ Loaded LengthEstimator")
            else:
                status_messages.append("✗ LengthEstimator not found, will use default length")

        # Load normalization stats
        status_messages.append("Loading normalization statistics...")
        after_mean = np.load(pjoin(checkpoints_dir, dataset_name, ae_name, 'AE_2D_Causal_Post_Mean.npy'))
        after_std = np.load(pjoin(checkpoints_dir, dataset_name, ae_name, 'AE_2D_Causal_Post_Std.npy'))
        joint_mean = np.load(f'utils/22x3_mean_std/{dataset_name}/22x3_mean.npy')
        joint_std = np.load(f'utils/22x3_mean_std/{dataset_name}/22x3_std.npy')
        eval_mean = np.load(f'utils/eval_mean_std/{dataset_name}/eval_mean.npy')
        eval_std = np.load(f'utils/eval_mean_std/{dataset_name}/eval_std.npy')
        _models_cache['stats'] = {
            'after_mean': after_mean,
            'after_std': after_std,
            'joint_mean': joint_mean,
            'joint_std': joint_std,
            'eval_mean': eval_mean,
            'eval_std': eval_std
        }
        _models_cache['loaded'] = True

        status_message = "\n".join(status_messages)
        return _models_cache, status_message

    except Exception as e:
        error_msg = f"Error loading models: {str(e)}"
        import traceback
        error_msg += f"\n\nTraceback:\n{traceback.format_exc()}"
        return None, error_msg
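

# Example (a sketch, not called by the app itself): warm the model cache before
# launching the UI, assuming the checkpoint layout described above exists, e.g.
# ./checkpoints/t2m/AE_2D_Causal/model/latest.tar:
#
#     models, msg = load_models_cached(gpu_id=-1)  # gpu_id=-1 forces CPU
#     print(msg)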


def estimate_motion_length(text: str, models_cache: dict) -> int:
    """Estimate motion length from text using LengthEstimator"""
    if models_cache['length_estimator'] is None:
        return None

    device = models_cache['device']
    acmdm = models_cache['acmdm']
    length_estimator = models_cache['length_estimator']

    with torch.no_grad():
        text_emb = acmdm.encode_text([text])
        pred_length = length_estimator(text_emb)
        pred_length = int(pred_length.item() * 4)
        pred_length = ((pred_length + 2) // 4) * 4
        pred_length = max(40, min(196, pred_length))
    return pred_length
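

# Note: the estimator's output is interpreted in latent tokens, so it is scaled by 4
# to get frames (one latent token covers 4 motion frames, matching
# latent_length = motion_length // 4 in generate_motion_single below), then rounded
# to a multiple of 4 and clamped to the 40-196 frame range exposed by the UI slider.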


def generate_motion_single(
    text: str,
    motion_length: Optional[int],
    cfg_scale: float,
    use_auto_length: bool,
    gpu_id: int,
    seed: int
) -> Tuple[Optional[str], str]:
    """
    Generate a single motion from text.
    Returns (video_path, status_message)
    """
    global _models_cache

    try:
        set_seed(seed)

        # Load models if not cached
        if not _models_cache['loaded']:
            models_cache, load_msg = load_models_cached(
                gpu_id=gpu_id,
                use_length_estimator=use_auto_length
            )
            if models_cache is None:
                return None, f"Failed to load models:\n{load_msg}"
        else:
            models_cache = _models_cache

        device = models_cache['device']
        ae = models_cache['ae']
        acmdm = models_cache['acmdm']
        stats = models_cache['stats']

        # Determine motion length (the auto-length checkbox overrides the manual slider)
        if use_auto_length and models_cache['length_estimator'] is not None:
            motion_length = estimate_motion_length(text, models_cache)
            status_msg = f"Estimated motion length: {motion_length} frames\n"
        elif motion_length is None or motion_length == 0:
            motion_length = 120  # Default
            status_msg = f"Using default motion length: {motion_length} frames\n"
        else:
            # Round to multiple of 4
            motion_length = ((motion_length + 2) // 4) * 4
            status_msg = f"Using specified motion length: {motion_length} frames\n"
| status_msg += f"Generating motion for: '{text}'...\n" | |
| # Generate motion | |
| with torch.no_grad(): | |
| latent_length = motion_length // 4 | |
| m_lens = torch.tensor([latent_length], device=device) | |
| pred_latents = acmdm.generate([text], m_lens, cfg_scale) | |
| # Denormalize latents | |
| pred_latents_np = pred_latents.permute(0, 2, 3, 1).detach().cpu().numpy() | |
| pred_latents_np = pred_latents_np * stats['after_std'] + stats['after_mean'] | |
| pred_latents_tensor = torch.from_numpy(pred_latents_np).to(device) | |
| # Decode through AE | |
| pred_motions = ae.decode(pred_latents_tensor.permute(0, 3, 1, 2)) | |
| # Denormalize motions | |
| pred_motions_np = pred_motions.permute(0, 2, 3, 1).detach().cpu().numpy() | |
| if stats['joint_mean'].ndim == 1: | |
| pred_motions_np = pred_motions_np * stats['joint_std'][np.newaxis, np.newaxis, :, np.newaxis] + stats['joint_mean'][np.newaxis, np.newaxis, :, np.newaxis] | |
| else: | |
| pred_motions_np = pred_motions_np * stats['joint_std'][np.newaxis, ..., np.newaxis] + stats['joint_mean'][np.newaxis, ..., np.newaxis] | |
| # Back process to get RIC format, then recover joint positions | |
| from utils.motion_process import recover_from_ric | |
| motion = pred_motions_np[0] # (22, 3, seq_len) | |
| motion = motion[:, :, :motion_length].transpose(2, 0, 1) # (seq_len, 22, 3) | |
| ric_data = back_process(motion, is_mesh=False) | |
| ric_tensor = torch.from_numpy(ric_data).float() | |
| joints = recover_from_ric(ric_tensor, joints_num=22).numpy() # (seq_len, 22, 3) | |
| # Create temporary file for video | |
| temp_dir = tempfile.mkdtemp() | |
| os.makedirs(temp_dir, exist_ok=True) | |
| video_path = pjoin(temp_dir, 'motion.mp4') | |
| # Generate video - plot_3d_motion returns the actual path (may be .mp4 or .gif) | |
| try: | |
| actual_video_path = plot_3d_motion( | |
| video_path, | |
| t2m_kinematic_chain, | |
| joints, | |
| title=text, | |
| fps=20, | |
| radius=4 | |
| ) | |
| except Exception as e: | |
| status_msg += f"\nโ Error in plot_3d_motion: {str(e)}" | |
| raise | |
| # Verify file exists - check both the returned path and original path | |
| if not os.path.exists(actual_video_path): | |
| # Check if file exists with original path | |
| if os.path.exists(video_path): | |
| actual_video_path = video_path | |
| # Check if it was saved as GIF instead | |
| elif os.path.exists(video_path.replace('.mp4', '.gif')): | |
| actual_video_path = video_path.replace('.mp4', '.gif') | |
| else: | |
| # List files in temp directory for debugging | |
| files_in_dir = os.listdir(temp_dir) if os.path.exists(temp_dir) else [] | |
| error_msg = f"Video file not found. Expected: {actual_video_path}\n" | |
| error_msg += f"Files in temp directory: {files_in_dir}" | |
| status_msg += f"\nโ {error_msg}" | |
| raise FileNotFoundError(error_msg) | |
| status_msg += f"\nโ Motion generated successfully! Video saved." | |
| return actual_video_path, status_msg | |
| except Exception as e: | |
| error_msg = f"Error generating motion: {str(e)}" | |
| import traceback | |
| error_msg += f"\n\nTraceback:\n{traceback.format_exc()}" | |
| return None, error_msg | |
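

# Example (a sketch; assumes the checkpoints are in place): generate a single clip
# outside the Gradio UI, e.g. as a quick smoke test:
#
#     video_path, msg = generate_motion_single(
#         "A person walks forward and then turns around.",
#         motion_length=120, cfg_scale=3.0, use_auto_length=False, gpu_id=-1, seed=3407
#     )
#     print(msg)
#     print("Video written to:", video_path)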


# Batch generation feature commented out
# def generate_motion_batch(
#     text_file_content: str,
#     cfg_scale: float,
#     use_auto_length: bool,
#     gpu_id: int,
#     seed: int
# ) -> Tuple[List[Optional[str]], str]:
#     """
#     Generate motions from a batch of text prompts.
#     Returns (list_of_video_paths, status_message)
#     """
#     # Parse text file
#     text_prompts = []
#     for line in text_file_content.strip().split('\n'):
#         line = line.strip()
#         if not line:
#             continue
#         if '#' in line:
#             parts = line.split('#')
#             text = parts[0].strip()
#             length_str = parts[1].strip() if len(parts) > 1 else 'NA'
#         else:
#             text = line
#             length_str = 'NA'
#
#         if length_str.upper() == 'NA':
#             motion_length = None if use_auto_length else 120
#         else:
#             try:
#                 motion_length = int(length_str)
#                 motion_length = ((motion_length + 2) // 4) * 4
#             except:
#                 motion_length = None if use_auto_length else 120
#
#         text_prompts.append((text, motion_length))
#
#     if not text_prompts:
#         return [], "No valid prompts found in file"
#
#     status_msg = f"Processing {len(text_prompts)} prompts...\n"
#     video_paths = []
#
#     # Create a persistent directory for batch videos (not temp, so Gradio can access them)
#     # Use a directory that Gradio can serve from
#     batch_output_dir = pjoin('generation', 'batch_temp')
#     os.makedirs(batch_output_dir, exist_ok=True)
#
#     # Also ensure the parent directory exists
#     os.makedirs('generation', exist_ok=True)
#
#     for idx, (text, motion_length) in enumerate(text_prompts):
#         try:
#             video_path, gen_msg = generate_motion_single(
#                 text, motion_length, cfg_scale, use_auto_length, gpu_id, seed + idx
#             )
#
#             # If a video was created, copy it to a persistent location for the batch gallery
#             if video_path and os.path.exists(video_path):
#                 # Create a unique filename for this batch item
#                 # Get the file extension from the original path
#                 file_ext = os.path.splitext(video_path)[1] or '.mp4'
#                 batch_video_path = pjoin(batch_output_dir, f'motion_{idx:04d}{file_ext}')
#
#                 # Copy the file to the persistent location
#                 import shutil
#                 try:
#                     shutil.copy2(video_path, batch_video_path)
#
#                     # Verify the copied file exists
#                     if os.path.exists(batch_video_path):
#                         # Use a relative path for Gradio (works better in HuggingFace Spaces)
#                         # Gradio can serve files from relative paths within the app directory
#                         video_paths.append(batch_video_path)
#                         file_size = os.path.getsize(batch_video_path)
#                         status_msg += f"\n[{idx+1}/{len(text_prompts)}] ✓ {text[:50]}... - Video saved ({file_size/1024:.1f} KB)"
#                     else:
#                         status_msg += f"\n[{idx+1}/{len(text_prompts)}] ✗ {text[:50]}... - Video copy failed (file not found after copy)"
#                         video_paths.append(None)
#                 except Exception as copy_error:
#                     status_msg += f"\n[{idx+1}/{len(text_prompts)}] ✗ {text[:50]}... - Copy error: {str(copy_error)}"
#                     video_paths.append(None)
#             else:
#                 status_msg += f"\n[{idx+1}/{len(text_prompts)}] ✗ {text[:50]}... - Generation failed (no video path returned)"
#                 video_paths.append(None)
#
#         except Exception as e:
#             status_msg += f"\n[{idx+1}/{len(text_prompts)}] ✗ Error: {str(e)}"
#             import traceback
#             status_msg += f"\n  Traceback: {traceback.format_exc()[:200]}"
#             video_paths.append(None)
#
#     # Filter out None values and verify files exist - Gradio Gallery needs valid paths
#     valid_video_paths = []
#     for path in video_paths:
#         if path is not None:
#             # Use a relative path (Gradio handles these better in cloud environments)
#             # Verify the file exists
#             if os.path.exists(path):
#                 # Verify it is a video file
#                 if path.lower().endswith(('.mp4', '.gif', '.mov', '.avi')):
#                     # Get the file size for debugging
#                     file_size = os.path.getsize(path)
#                     if file_size > 0:  # Ensure the file is not empty
#                         valid_video_paths.append(path)  # Keep the relative path
#                         status_msg += f"\n  ✓ Verified: {os.path.basename(path)} ({file_size/1024:.1f} KB)"
#                     else:
#                         status_msg += f"\n  ✗ Empty file: {path}"
#                 else:
#                     status_msg += f"\n  ✗ Skipping non-video file: {path}"
#             else:
#                 # Try the absolute path as a fallback
#                 abs_path = os.path.abspath(path)
#                 if os.path.exists(abs_path):
#                     valid_video_paths.append(abs_path)
#                     status_msg += f"\n  ✓ Found via absolute path: {abs_path}"
#                 else:
#                     status_msg += f"\n  ✗ File not found (relative: {path}, absolute: {abs_path})"
#
#     if len(valid_video_paths) == 0:
#         status_msg += "\n\n✗ No videos were successfully generated for the gallery."
#         status_msg += "\nCheck the error messages above for each prompt."
#         # Return an empty list explicitly
#         return [], status_msg
#     else:
#         status_msg += f"\n\n✓ Successfully generated {len(valid_video_paths)}/{len(text_prompts)} motions."
#         status_msg += f"\nVideos saved in: {batch_output_dir}"
#         status_msg += f"\nGallery will display {len(valid_video_paths)} video(s)."
#
#         # Debug: show the first few paths
#         status_msg += "\n\nFirst few video paths:\n"
#         for i, p in enumerate(valid_video_paths[:3]):
#             abs_p = os.path.abspath(p) if not os.path.isabs(p) else p
#             exists = "✓" if os.path.exists(p) else "✗"
#             status_msg += f"  {exists} [{i+1}] {p} (exists: {os.path.exists(p)})\n"
#         if len(valid_video_paths) > 3:
#             status_msg += f"  ... and {len(valid_video_paths) - 3} more\n"
#
#         # Return the list of file paths (Gradio Gallery expects a list of strings)
#         # Ensure all paths are valid and accessible
#         final_paths = []
#         for p in valid_video_paths:
#             if os.path.exists(p):
#                 # Normalize the path (use forward slashes for cross-platform compatibility)
#                 normalized_path = p.replace('\\', '/')
#                 final_paths.append(normalized_path)
#             else:
#                 status_msg += f"\nWarning: Path not found when returning: {p}"
#
#         if len(final_paths) != len(valid_video_paths):
#             status_msg += f"\nSome paths were filtered out. Returning {len(final_paths)}/{len(valid_video_paths)} videos."
#
#         return final_paths, status_msg


# Gradio Interface
def create_interface():
    """Create and configure the Gradio interface"""
    with gr.Blocks(title="ACMDM Motion Generation", theme=gr.themes.Soft()) as app:
        gr.Markdown("""
        # ACMDM Motion Generation

        Generate human motion from text descriptions using the ACMDM (Absolute Coordinates Make Motion Generation Easy) model.

        **How to use:**
        1. Enter a text description of the motion you want to generate
        2. Adjust motion length (or use auto-estimate)
        3. Click "Generate Motion" to create the animation
        4. View and download the generated video

        **Example prompts:**
        - "A person is running on a treadmill."
        - "Someone is doing jumping jacks."
        - "A person walks forward and then turns around."
        """)

        with gr.Tabs():
            # Single Generation Tab
            with gr.Tab("Single Motion Generation"):
                with gr.Row():
                    with gr.Column(scale=1):
                        text_input = gr.Textbox(
                            label="Motion Description",
                            placeholder="A person is running on a treadmill.",
                            lines=3,
                            value="A person is running on a treadmill."
                        )
                        with gr.Row():
                            motion_length = gr.Slider(
                                label="Motion Length (frames)",
                                minimum=40,
                                maximum=196,
                                value=120,
                                step=4,
                                info="Will be rounded to nearest multiple of 4"
                            )
                            use_auto_length = gr.Checkbox(
                                label="Auto-estimate length",
                                value=False,
                                info="Use length estimator (ignores manual length if checked)"
                            )
                        cfg_scale = gr.Slider(
                            label="CFG Scale",
                            minimum=1.0,
                            maximum=10.0,
                            value=3.0,
                            step=0.5,
                            info="Classifier-free guidance scale (higher = more aligned with text)"
                        )
                        with gr.Row():
                            gpu_id = gr.Number(
                                label="GPU ID",
                                value=0,
                                precision=0,
                                info="Use -1 for CPU"
                            )
                            seed = gr.Number(
                                label="Random Seed",
                                value=3407,
                                precision=0
                            )
                        generate_btn = gr.Button("Generate Motion", variant="primary", size="lg")

                    with gr.Column(scale=1):
                        video_output = gr.Video(
                            label="Generated Motion",
                            format="mp4"
                        )
                        status_output = gr.Textbox(
                            label="Status",
                            lines=10,
                            interactive=False
                        )

                # Update model status after generation
                def generate_and_update_status(text, motion_length, cfg_scale, use_auto_length, gpu_id, seed):
                    video_path, status_msg = generate_motion_single(
                        text, motion_length, cfg_scale, use_auto_length, gpu_id, seed
                    )
                    # Return the video and status message
                    return video_path, status_msg

                generate_btn.click(
                    fn=generate_and_update_status,
                    inputs=[text_input, motion_length, cfg_scale, use_auto_length, gpu_id, seed],
                    outputs=[video_output, status_output]
                )

            # Batch Generation Tab - COMMENTED OUT
            # with gr.Tab("Batch Generation"):
            #     with gr.Row():
            #         with gr.Column(scale=1):
            #             batch_text_input = gr.Textbox(
            #                 label="Text Prompts (one per line, format: text#length or text#NA)",
            #                 placeholder="A person is running on a treadmill.#120\nSomeone is doing jumping jacks.#NA",
            #                 lines=10,
            #                 info="Each line: 'text#length' or 'text#NA' for auto-estimate"
            #             )
            #
            #             batch_cfg_scale = gr.Slider(
            #                 label="CFG Scale",
            #                 minimum=1.0,
            #                 maximum=10.0,
            #                 value=3.0,
            #                 step=0.5
            #             )
            #
            #             batch_use_auto_length = gr.Checkbox(
            #                 label="Auto-estimate length for NA",
            #                 value=True
            #             )
            #
            #             batch_gpu_id = gr.Number(
            #                 label="GPU ID",
            #                 value=0,
            #                 precision=0
            #             )
            #
            #             batch_seed = gr.Number(
            #                 label="Random Seed",
            #                 value=3407,
            #                 precision=0
            #             )
            #
            #             batch_generate_btn = gr.Button("Generate Batch", variant="primary", size="lg")
            #
            #         with gr.Column(scale=1):
            #             batch_status_output = gr.Textbox(
            #                 label="Batch Status",
            #                 lines=15,
            #                 interactive=False
            #             )
            #             batch_video_gallery = gr.Gallery(
            #                 label="Generated Motions",
            #                 show_label=True,
            #                 elem_id="gallery",
            #                 columns=2,
            #                 rows=2,
            #                 height="auto",
            #                 type="filepath"  # Explicitly specify filepath type
            #             )
            #
            #     batch_generate_btn.click(
            #         fn=generate_motion_batch,
            #         inputs=[batch_text_input, batch_cfg_scale, batch_use_auto_length, batch_gpu_id, batch_seed],
            #         outputs=[batch_video_gallery, batch_status_output]
            #     )

            # Model Management Tab
            with gr.Tab("Model Management"):
                with gr.Row():
                    with gr.Column():
                        model_status = gr.Textbox(
                            label="Model Status",
                            lines=15,
                            interactive=False,
                            value="Models not loaded yet. They will be loaded automatically on first generation."
                        )
                        with gr.Row():
                            refresh_status_btn = gr.Button("Refresh Status", variant="primary")
                            reload_models_btn = gr.Button("Clear Cache", variant="secondary")
                        gr.Markdown("""
                        **Model Configuration:**
                        - Model Name: `ACMDM_Flow_S_PatchSize22`
                        - AE Name: `AE_2D_Causal`
                        - Dataset: `t2m`
                        - Checkpoints Directory: `./checkpoints`
                        """)

                def check_model_status():
                    """Check and display current model status"""
                    global _models_cache
                    if _models_cache['loaded']:
                        device = _models_cache['device']
                        status_lines = [
                            "✓ MODELS LOADED AND READY",
                            "=" * 50,
                            f"Device: {device}",
                            f"CUDA Available: {torch.cuda.is_available()}",
                        ]
                        if device.type == 'cuda':
                            status_lines.append(f"GPU: {torch.cuda.get_device_name(device.index if device.index is not None else 0)}")
                            status_lines.append(f"GPU Memory: {torch.cuda.get_device_properties(device.index if device.index is not None else 0).total_memory / 1e9:.2f} GB")
                        status_lines.extend([
                            "",
                            "Loaded Models:",
                            "  ✓ Autoencoder (AE_2D_Causal)",
                            "  ✓ ACMDM Diffusion Model",
                        ])
                        if _models_cache['length_estimator'] is not None:
                            status_lines.append("  ✓ Length Estimator")
                        else:
                            status_lines.append("  ✗ Length Estimator (not loaded)")
                        status_lines.extend([
                            "",
                            "Statistics Loaded:",
                            "  ✓ Post-AE Mean/Std",
                            "  ✓ Joint Mean/Std",
                            "  ✓ Eval Mean/Std",
                            "",
                            "Models are cached and ready for generation!",
                            "Tip: Use 'Clear Cache' to force reload on next generation."
                        ])
                        return "\n".join(status_lines)
                    else:
                        return (
                            "Models not loaded yet. They will be loaded automatically on first generation.\n"
                            "To load models now, go to the 'Single Motion Generation' tab and click 'Generate Motion'.\n"
                            "First generation will take 30-60 seconds (model loading time).\n"
                            "Subsequent generations will be much faster (5-15 seconds)."
                        )

                def reload_models():
                    """Clear model cache"""
                    global _models_cache
                    _models_cache['loaded'] = False
                    _models_cache['ae'] = None
                    _models_cache['acmdm'] = None
                    _models_cache['length_estimator'] = None
                    _models_cache['stats'] = None
                    _models_cache['device'] = None
                    return (
                        "Model cache cleared. Models will be automatically reloaded on your next generation request.\n"
                        "Click 'Refresh Status' after generating to see updated status."
                    )

                # Button callbacks
                refresh_status_btn.click(
                    fn=check_model_status,
                    outputs=model_status
                )
                reload_models_btn.click(
                    fn=reload_models,
                    outputs=model_status
                )

                # Update status on tab load
                app.load(
                    fn=check_model_status,
                    outputs=model_status
                )
| gr.Markdown(""" | |
| --- | |
| **Note:** First generation may take longer as models need to be loaded. Subsequent generations will be faster. | |
| **Tips:** | |
| - Use descriptive text prompts for better results | |
| - Adjust CFG scale: higher values (3-5) for more text alignment, lower values (1-2) for more diversity | |
| - Motion length should be a multiple of 4 (automatically rounded) | |
| - For batch processing, use the format: `text description#length` or `text description#NA` | |
| """) | |
| return app | |
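

# To run locally (a sketch; assumes the project's dependencies are installed and the
# checkpoints/ directory is populated):
#     python app.py
# then open http://localhost:7860 in a browser. On HuggingFace Spaces, app.py is
# launched automatically.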


if __name__ == "__main__":
    # Create and launch the interface
    app = create_interface()

    # Launch with sharing option
    app.launch(
        server_name="0.0.0.0",  # Allow external connections
        server_port=7860,       # Default Gradio port
        share=False,            # Set to True to create a temporary public Gradio share link
        show_error=True
    )