aducsdr commited on
Commit
0d1253b
·
verified ·
1 Parent(s): 1721077

Delete seedvr_manager (2).py

Browse files
Files changed (1) hide show
  1. seedvr_manager (2).py +0 -233
seedvr_manager (2).py DELETED
@@ -1,233 +0,0 @@
1
- # managers/seedvr_manager.py
2
- #
3
- # Copyright (C) 2025 Carlos Rodrigues dos Santos
4
- #
5
- # Version: 4.0.0 (Root Installer & Executor)
6
- #
7
- # This version fully adopts the logic from the functional hd_specialist.py example.
8
- # It acts as a setup manager: it clones the SeedVR repo and then copies all
9
- # necessary directories (projects, common, models, configs, ckpts) to the
10
- # application root. It also handles the pip installation of the Apex dependency.
11
- # This ensures that the SeedVR code runs in the exact file structure it expects.
12
-
13
- import torch
14
- import torch.distributed as dist
15
- import os
16
- import gc
17
- import logging
18
- import sys
19
- import subprocess
20
- from pathlib import Path
21
- from urllib.parse import urlparse
22
- from torch.hub import download_url_to_file
23
- import gradio as gr
24
- import mediapy
25
- from einops import rearrange
26
- import shutil
27
- from omegaconf import OmegaConf
28
-
29
- logger = logging.getLogger(__name__)
30
-
31
- # --- Caminhos Globais ---
32
- APP_ROOT = Path("/home/user/app")
33
- DEPS_DIR = APP_ROOT / "deps"
34
- SEEDVR_SPACE_DIR = DEPS_DIR / "SeedVR_Space"
35
- SEEDVR_SPACE_URL = "https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B"
36
-
37
- class SeedVrManager:
38
- def __init__(self, workspace_dir="deformes_workspace"):
39
- self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
40
- self.runner = None
41
- self.workspace_dir = workspace_dir
42
- self.is_initialized = False
43
- self._original_barrier = None
44
- self.setup_complete = False # Flag para rodar o setup apenas uma vez
45
- logger.info("SeedVrManager initialized. Setup will run on first use.")
46
-
47
- def _full_setup(self):
48
- """
49
- Executa todo o processo de setup uma única vez.
50
- """
51
- if self.setup_complete:
52
- return
53
-
54
- logger.info("--- Starting Full SeedVR Setup ---")
55
-
56
- # 1. Clonar o repositório se não existir
57
- if not SEEDVR_SPACE_DIR.exists():
58
- logger.info(f"Cloning SeedVR Space repo to {SEEDVR_SPACE_DIR}...")
59
- DEPS_DIR.mkdir(exist_ok=True, parents=True)
60
- subprocess.run(
61
- ["git", "clone", "--depth", "1", SEEDVR_SPACE_URL, str(SEEDVR_SPACE_DIR)],
62
- check=True, capture_output=True, text=True
63
- )
64
-
65
- # 2. Copiar as pastas necessárias para a raiz da aplicação
66
- required_dirs = ["projects", "common", "models", "configs_3b", "configs_7b"]
67
- for dirname in required_dirs:
68
- source = SEEDVR_SPACE_DIR / dirname
69
- target = APP_ROOT / dirname
70
- if not target.exists():
71
- logger.info(f"Copying '{dirname}' to application root...")
72
- shutil.copytree(source, target)
73
-
74
- # 3. Adicionar a raiz ao sys.path para garantir que os imports funcionem
75
- if str(APP_ROOT) not in sys.path:
76
- sys.path.insert(0, str(APP_ROOT))
77
- logger.info(f"Added '{APP_ROOT}' to sys.path.")
78
-
79
- # 4. Instalar dependências complexas como Apex
80
- try:
81
- import apex
82
- logger.info("Apex is already installed.")
83
- except ImportError:
84
- logger.info("Installing Apex dependency...")
85
- apex_url = 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/apex-0.1-cp310-cp310-linux_x86_64.whl'
86
- apex_wheel_path = _load_file_from_url(url=apex_url, model_dir=str(DEPS_DIR))
87
- subprocess.run(f"pip install {apex_wheel_path}", check=True, shell=True)
88
- logger.info("Apex installed successfully.")
89
-
90
- # 5. Baixar os modelos para a pasta ./ckpts na raiz
91
- ckpt_dir = APP_ROOT / 'ckpts'
92
- ckpt_dir.mkdir(exist_ok=True)
93
- pretrain_model_urls = {
94
- 'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
95
- 'dit_3b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
96
- 'dit_7b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-7B/resolve/main/seedvr2_ema_7b.pth',
97
- 'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
98
- 'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt'
99
- }
100
- for name, url in pretrain_model_urls.items():
101
- _load_file_from_url(url=url, model_dir=str(ckpt_dir))
102
-
103
- self.setup_complete = True
104
- logger.info("--- Full SeedVR Setup Complete ---")
105
-
106
- def _initialize_runner(self, model_version: str):
107
- if self.runner is not None: return
108
-
109
- # Garante que todo o ambiente está configurado antes de prosseguir
110
- self._full_setup()
111
-
112
- # Agora que o setup está feito, podemos importar os módulos
113
- from projects.video_diffusion_sr.infer import VideoDiffusionInfer
114
- from common.config import load_config
115
- from common.seed import set_seed
116
-
117
- if dist.is_available() and not dist.is_initialized():
118
- os.environ["MASTER_ADDR"] = "127.0.0.1"
119
- os.environ["MASTER_PORT"] = "12355"
120
- os.environ["RANK"] = str(0)
121
- os.environ["WORLD_SIZE"] = str(1)
122
- dist.init_process_group(backend='gloo')
123
- logger.info("Initialized torch.distributed process group.")
124
-
125
- logger.info(f"Initializing SeedVR2 {model_version} runner...")
126
- if model_version == '3B':
127
- config_path = APP_ROOT / 'configs_3b' / 'main.yaml'
128
- checkpoint_path = APP_ROOT / 'ckpts' / 'seedvr2_ema_3b.pth'
129
- else: # Assumimos 7B
130
- config_path = APP_ROOT / 'configs_7b' / 'main.yaml'
131
- checkpoint_path = APP_ROOT / 'ckpts' / 'seedvr2_ema_7b.pth'
132
-
133
- config = load_config(str(config_path))
134
-
135
- self.runner = VideoDiffusionInfer(config)
136
- OmegaConf.set_readonly(self.runner.config, False)
137
-
138
- self.runner.configure_dit_model(device=self.device, checkpoint=str(checkpoint_path))
139
- self.runner.configure_vae_model()
140
-
141
- if hasattr(self.runner.vae, "set_memory_limit"):
142
- self.runner.vae.set_memory_limit(**self.runner.config.vae.memory_limit)
143
-
144
- self.is_initialized = True
145
- logger.info(f"Runner for SeedVR2 {model_version} initialized and ready.")
146
-
147
- def _unload_runner(self):
148
- if self.runner is not None:
149
- del self.runner
150
- self.runner = None
151
- gc.collect()
152
- torch.cuda.empty_cache()
153
- self.is_initialized = False
154
- logger.info("Runner do SeedVR2 descarregado da VRAM.")
155
- if dist.is_initialized():
156
- dist.destroy_process_group()
157
- logger.info("Destroyed torch.distributed process group.")
158
-
159
- def process_video(self, input_video_path: str, output_video_path: str, prompt: str,
160
- model_version: str = '7B', steps: int = 100, seed: int = 666,
161
- progress: gr.Progress = None) -> str:
162
- try:
163
- self._initialize_runner(model_version)
164
-
165
- # Precisamos importar aqui, pois o sys.path é modificado no setup
166
- from common.seed import set_seed
167
- from data.image.transforms.divisible_crop import DivisibleCrop
168
- from data.image.transforms.na_resize import NaResize
169
- from data.video.transforms.rearrange import Rearrange
170
- from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
171
- from torchvision.transforms import Compose, Lambda, Normalize
172
- from torchvision.io.video import read_video
173
-
174
- set_seed(seed, same_across_ranks=True)
175
- self.runner.config.diffusion.timesteps.sampling.steps = steps
176
- self.runner.configure_diffusion()
177
-
178
- video_tensor = read_video(input_video_path, output_format="TCHW")[0] / 255.0
179
- res_h, res_w = video_tensor.shape[-2:]
180
- video_transform = Compose([
181
- NaResize(resolution=(res_h * res_w) ** 0.5, mode="area", downsample_only=False),
182
- Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
183
- DivisibleCrop((16, 16)),
184
- Normalize(0.5, 0.5),
185
- Rearrange("t c h w -> c t h w"),
186
- ])
187
- cond_latents = [video_transform(video_tensor.to(self.device))]
188
- input_videos = cond_latents
189
- self.runner.dit.to("cpu")
190
- self.runner.vae.to(self.device)
191
- cond_latents = self.runner.vae_encode(cond_latents)
192
- self.runner.vae.to("cpu"); gc.collect(); torch.cuda.empty_cache()
193
- self.runner.dit.to(self.device)
194
-
195
- pos_emb = torch.load(APP_ROOT / 'pos_emb.pt').to(self.device)
196
- neg_emb = torch.load(APP_ROOT / 'neg_emb.pt').to(self.device)
197
- text_embeds_dict = {"texts_pos": [pos_emb], "texts_neg": [neg_emb]}
198
-
199
- noises = [torch.randn_like(latent) for latent in cond_latents]
200
- conditions = [self.runner.get_condition(noise, latent_blur=latent, task="sr") for noise, latent in zip(noises, cond_latents)]
201
-
202
- with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
203
- video_tensors = self.runner.inference(noises=noises, conditions=conditions, dit_offload=True, **text_embeds_dict)
204
-
205
- self.runner.dit.to("cpu"); gc.collect(); torch.cuda.empty_cache()
206
- self.runner.vae.to(self.device)
207
- samples = self.runner.vae_decode(video_tensors)
208
- final_sample = samples[0]
209
- input_video_sample = input_videos[0]
210
- if final_sample.shape[1] < input_video_sample.shape[1]:
211
- input_video_sample = input_video_sample[:, :final_sample.shape[1]]
212
-
213
- final_sample = wavelet_reconstruction(rearrange(final_sample, "c t h w -> t c h w"), rearrange(input_video_sample, "c t h w -> t c h w"))
214
- final_sample = rearrange(final_sample, "t c h w -> t h w c")
215
- final_sample = final_sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
216
- final_sample_np = final_sample.to(torch.uint8).cpu().numpy()
217
-
218
- mediapy.write_video(output_video_path, final_sample_np, fps=24)
219
- logger.info(f"HD Mastered video saved to: {output_video_path}")
220
- return output_path
221
- finally:
222
- self._unload_runner()
223
-
224
- def _load_file_from_url(url, model_dir='./', file_name=None):
225
- os.makedirs(model_dir, exist_ok=True)
226
- filename = file_name or os.path.basename(urlparse(url).path)
227
- cached_file = os.path.abspath(os.path.join(model_dir, filename))
228
- if not os.path.exists(cached_file):
229
- logger.info(f'Downloading: "{url}" to {cached_file}')
230
- download_url_to_file(url, cached_file, hash_prefix=None, progress=True)
231
- return cached_file
232
-
233
- seedvr_manager_singleton = SeedVrManager()