Aduc-sdr committed on
Commit
6a4b97d
·
verified ·
1 Parent(s): 6e99af4

Update managers/seedvr_manager.py

Browse files
Files changed (1) hide show
  1. managers/seedvr_manager.py +68 -59
managers/seedvr_manager.py CHANGED
@@ -1,14 +1,12 @@
1
  # managers/seedvr_manager.py
2
  #
3
- # Copyright (C) 2025 Carlos Rodrigues dos Santos
4
  #
5
- # Version: 3.0.1 (Config Path Patch)
6
- #
7
- # This version adds a path patching mechanism. It copies the necessary VAE
8
- # configuration files from the cloned SeedVR dependency directory to the
9
- # location where the SeedVR code hardcodedly expects them, resolving the
10
- # FileNotFoundError during initialization.
11
 
 
12
  import torch
13
  import torch.distributed as dist
14
  import os
@@ -22,21 +20,54 @@ from torch.hub import download_url_to_file
22
  import gradio as gr
23
  import mediapy
24
  from einops import rearrange
25
- import shutil # <--- NOVO IMPORT
26
-
27
  from tools.tensor_utils import wavelet_reconstruction
28
 
29
  logger = logging.getLogger(__name__)
30
-
31
- # --- Gerenciamento de Dependências (sem alterações) ---
32
  DEPS_DIR = Path("./deps")
33
  SEEDVR_SPACE_DIR = DEPS_DIR / "SeedVR_Space"
34
  SEEDVR_SPACE_URL = "https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B"
35
 
36
- def setup_seedvr_dependencies():
37
- # ... (sem alterações aqui)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  if not SEEDVR_SPACE_DIR.exists():
39
- logger.info(f"SeedVR Space not found at '{SEEDVR_SPACE_DIR}'. Cloning from Hugging Face...")
40
  try:
41
  DEPS_DIR.mkdir(exist_ok=True)
42
  subprocess.run(
@@ -46,28 +77,29 @@ def setup_seedvr_dependencies():
46
  logger.info("SeedVR Space repository cloned successfully.")
47
  except subprocess.CalledProcessError as e:
48
  logger.error(f"Failed to clone SeedVR Space. Git stderr: {e.stderr}")
49
- raise RuntimeError("Could not clone the required SeedVR dependency from Hugging Face.")
50
  else:
51
  logger.info("Found local SeedVR Space repository.")
52
 
 
53
  if str(SEEDVR_SPACE_DIR.resolve()) not in sys.path:
54
  sys.path.insert(0, str(SEEDVR_SPACE_DIR.resolve()))
55
  logger.info(f"Added '{SEEDVR_SPACE_DIR.resolve()}' to sys.path.")
56
 
57
- setup_seedvr_dependencies()
58
 
59
  from projects.video_diffusion_sr.infer import VideoDiffusionInfer
60
  from common.config import load_config
61
  from common.seed import set_seed
62
- from data.image.transforms.divisible_crop import DivisibleCrop
63
- # ... (outros imports do seedvr sem alterações)
64
  from torchvision.io.video import read_video
65
  from omegaconf import OmegaConf
66
-
 
 
 
67
 
68
  class SeedVrManager:
69
  def __init__(self, workspace_dir="deformes_workspace"):
70
- # ... (sem alterações aqui)
71
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
72
  self.runner = None
73
  self.workspace_dir = workspace_dir
@@ -75,80 +107,56 @@ class SeedVrManager:
75
  self._original_barrier = None
76
  logger.info("SeedVrManager initialized. Model will be loaded on demand.")
77
 
78
- # <--- INÍCIO DA NOVA FUNÇÃO DE PATCH --->
79
  def _patch_config_paths(self):
80
- """
81
- Copies the VAE config directory from the cloned repo to the hardcoded
82
- path that the SeedVR library expects.
83
- """
84
  app_root = Path("/home/user/app")
85
  source_config_dir = SEEDVR_SPACE_DIR / "models" / "video_vae_v3"
86
  target_config_parent_dir = app_root / "models"
87
  target_config_dir = target_config_parent_dir / "video_vae_v3"
88
-
89
  if not source_config_dir.exists():
90
  logger.warning(f"Source VAE config directory not found at {source_config_dir}. Skipping patch.")
91
  return
92
-
93
  if target_config_dir.exists():
94
  logger.info(f"Target VAE config path {target_config_dir} already exists. Skipping copy.")
95
  return
96
-
97
  logger.info(f"Patching SeedVR config path: Copying {source_config_dir} to {target_config_dir}...")
98
  try:
99
- # Cria o diretório pai (/home/user/app/models) se ele não existir
100
  target_config_parent_dir.mkdir(parents=True, exist_ok=True)
101
- # Copia a árvore de diretórios inteira
102
  shutil.copytree(source_config_dir, target_config_dir)
103
  logger.info("Config path patched successfully.")
104
  except Exception as e:
105
  logger.error(f"Failed to patch SeedVR config path: {e}", exc_info=True)
106
  raise IOError("Could not patch the required SeedVR configuration paths.")
107
- # <--- FIM DA NOVA FUNÇÃO DE PATCH --->
108
 
109
  def _download_models(self):
110
- # ... (sem alterações aqui)
111
- logger.info("Verifying and downloading SeedVR2 model checkpoints...")
112
  ckpt_dir = SEEDVR_SPACE_DIR / 'ckpts'
113
  ckpt_dir.mkdir(exist_ok=True)
114
 
115
  pretrain_model_urls = {
116
  'vae_ckpt': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
117
  'dit_3b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
118
- 'dit_7b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-7B/resolve/main/seedvr2_ema_7b.pth',
119
  'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
120
  'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt'
121
  }
122
  for key, url in pretrain_model_urls.items():
123
  _load_file_from_url(url=url, model_dir=str(ckpt_dir))
124
- logger.info("SeedVR2 model checkpoints downloaded successfully.")
125
 
126
- def _initialize_runner(self, model_version: str):
127
- """Loads and configures the SeedVR model."""
128
  if self.runner is not None: return
129
-
130
- # Chama o patch ANTES de tentar carregar qualquer coisa
131
  self._patch_config_paths()
132
-
133
  self._download_models()
134
 
135
  if dist.is_available() and not dist.is_initialized():
136
- # ... (patch do barrier sem alterações)
137
- logger.info("Applying patch to disable torch.distributed.barrier for single-GPU inference.")
138
  self._original_barrier = dist.barrier
139
  dist.barrier = lambda *args, **kwargs: None
140
 
141
- logger.info(f"Initializing SeedVR2 {model_version} runner...")
142
- if model_version == '3B':
143
- config_path = SEEDVR_SPACE_DIR / 'configs_3b' / 'main.yaml'
144
- checkpoint_path = SEEDVR_SPACE_DIR / 'ckpts' / 'seedvr2_ema_3b.pth'
145
- elif model_version == '7B':
146
- config_path = SEEDVR_SPACE_DIR / 'configs_7b' / 'main.yaml'
147
- checkpoint_path = SEEDVR_SPACE_DIR / 'ckpts' / 'seedvr2_ema_7b.pth'
148
- else:
149
- raise ValueError(f"Unsupported SeedVR model version: {model_version}")
150
-
151
- # Agora, quando `load_config` for chamado, ele encontrará o arquivo no caminho esperado.
152
  config = load_config(str(config_path))
153
 
154
  self.runner = VideoDiffusionInfer(config)
@@ -158,28 +166,28 @@ class SeedVrManager:
158
  if hasattr(self.runner.vae, "set_memory_limit"):
159
  self.runner.vae.set_memory_limit(**self.runner.config.vae.memory_limit)
160
  self.is_initialized = True
161
- logger.info(f"Runner for SeedVR2 {model_version} initialized and ready.")
162
 
163
- # ... (o resto da classe e do arquivo permanece o mesmo)
164
  def _unload_runner(self):
 
165
  if self.runner is not None:
166
  del self.runner; self.runner = None
167
  gc.collect(); torch.cuda.empty_cache()
168
  self.is_initialized = False
169
  logger.info("SeedVR runner unloaded from VRAM.")
170
  if self._original_barrier is not None:
171
- logger.info("Restoring original torch.distributed.barrier function.")
172
  dist.barrier = self._original_barrier
173
  self._original_barrier = None
174
 
175
  def process_video(self, input_video_path: str, output_video_path: str, prompt: str,
176
- model_version: str = '3B', steps: int = 50, seed: int = 666,
177
- progress: gr.Progress = None) -> str:
178
  try:
179
- self._initialize_runner(model_version)
180
  set_seed(seed, same_across_ranks=True)
181
  self.runner.config.diffusion.timesteps.sampling.steps = steps
182
  self.runner.configure_diffusion()
 
183
  video_tensor = read_video(input_video_path, output_format="TCHW")[0] / 255.0
184
  res_h, res_w = video_tensor.shape[-2:]
185
  video_transform = Compose([
@@ -223,6 +231,7 @@ class SeedVrManager:
223
  self._unload_runner()
224
 
225
  def _load_file_from_url(url, model_dir='./', file_name=None):
 
226
  os.makedirs(model_dir, exist_ok=True)
227
  filename = file_name or os.path.basename(urlparse(url).path)
228
  cached_file = os.path.abspath(os.path.join(model_dir, filename))
 
1
  # managers/seedvr_manager.py
2
  #
3
+ # Version: 3.2.0 (3B Model Focus)
4
  #
5
+ # This version simplifies the manager to exclusively use the SeedVR 3B model.
6
+ # The 7B model download and selection logic have been removed to streamline
7
+ # the code and reduce resource usage.
 
 
 
8
 
9
+ # ... (imports remain the same) ...
10
  import torch
11
  import torch.distributed as dist
12
  import os
 
20
  import gradio as gr
21
  import mediapy
22
  from einops import rearrange
23
+ import shutil
 
24
  from tools.tensor_utils import wavelet_reconstruction
25
 
26
  logger = logging.getLogger(__name__)
27
+ # ... (setup_seedvr_environment_and_dependencies and the SeedVR imports remain the same) ...
28
+ # --- START OF THE DEPENDENCY AND ENVIRONMENT MANAGEMENT SECTION ---
29
  DEPS_DIR = Path("./deps")
30
  SEEDVR_SPACE_DIR = DEPS_DIR / "SeedVR_Space"
31
  SEEDVR_SPACE_URL = "https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B"
32
 
33
+ def setup_seedvr_environment_and_dependencies():
34
+ """
35
+ Performs all one-time setup tasks for SeedVR:
36
+ 1. Sets torch.distributed environment variables.
37
+ 2. Forces the installation of flash-attn.
38
+ 3. Clones the SeedVR repository for its code modules.
39
+ 4. Adds the repository to the Python path.
40
+ """
41
+ # 1. Configurar variáveis de ambiente para torch.distributed
42
+ if "MASTER_ADDR" not in os.environ:
43
+ os.environ["MASTER_ADDR"] = "127.0.0.1"
44
+ os.environ["MASTER_PORT"] = "12355" # Fixed, arbitrarily chosen port (not randomly generated)
45
+ os.environ["RANK"] = str(0)
46
+ os.environ["WORLD_SIZE"] = str(1)
47
+ logger.info("Set up environment variables for torch.distributed.")
48
+
49
+ # 2. Forçar a instalação do flash-attn
50
+ try:
51
+ import flash_attn
52
+ logger.info("flash-attn is already installed.")
53
+ except ImportError:
54
+ logger.info("Attempting to install flash-attn...")
55
+ try:
56
+ subprocess.run(
57
+ "pip install flash-attn --no-build-isolation",
58
+ env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
59
+ shell=True,
60
+ check=True,
61
+ capture_output=True,
62
+ text=True
63
+ )
64
+ logger.info("flash-attn installed successfully.")
65
+ except subprocess.CalledProcessError as e:
66
+ logger.error(f"Failed to install flash-attn. Stderr: {e.stderr}")
67
+
68
+ # 3. Clonar o repositório do SeedVR Space
69
  if not SEEDVR_SPACE_DIR.exists():
70
+ logger.info(f"SeedVR Space not found at '{SEEDVR_SPACE_DIR}'. Cloning...")
71
  try:
72
  DEPS_DIR.mkdir(exist_ok=True)
73
  subprocess.run(
 
77
  logger.info("SeedVR Space repository cloned successfully.")
78
  except subprocess.CalledProcessError as e:
79
  logger.error(f"Failed to clone SeedVR Space. Git stderr: {e.stderr}")
80
+ raise RuntimeError("Could not clone the required SeedVR dependency.")
81
  else:
82
  logger.info("Found local SeedVR Space repository.")
83
 
84
+ # 4. Adicionar o repositório ao path do Python
85
  if str(SEEDVR_SPACE_DIR.resolve()) not in sys.path:
86
  sys.path.insert(0, str(SEEDVR_SPACE_DIR.resolve()))
87
  logger.info(f"Added '{SEEDVR_SPACE_DIR.resolve()}' to sys.path.")
88
 
89
+ setup_seedvr_environment_and_dependencies()
90
 
91
  from projects.video_diffusion_sr.infer import VideoDiffusionInfer
92
  from common.config import load_config
93
  from common.seed import set_seed
 
 
94
  from torchvision.io.video import read_video
95
  from omegaconf import OmegaConf
96
+ from data.image.transforms.divisible_crop import DivisibleCrop
97
+ from data.image.transforms.na_resize import NaResize
98
+ from data.video.transforms.rearrange import Rearrange
99
+ from torchvision.transforms import Compose, Lambda, Normalize
100
 
101
  class SeedVrManager:
102
  def __init__(self, workspace_dir="deformes_workspace"):
 
103
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
104
  self.runner = None
105
  self.workspace_dir = workspace_dir
 
107
  self._original_barrier = None
108
  logger.info("SeedVrManager initialized. Model will be loaded on demand.")
109
 
 
110
  def _patch_config_paths(self):
111
+ # ... (sem alterações) ...
 
 
 
112
  app_root = Path("/home/user/app")
113
  source_config_dir = SEEDVR_SPACE_DIR / "models" / "video_vae_v3"
114
  target_config_parent_dir = app_root / "models"
115
  target_config_dir = target_config_parent_dir / "video_vae_v3"
 
116
  if not source_config_dir.exists():
117
  logger.warning(f"Source VAE config directory not found at {source_config_dir}. Skipping patch.")
118
  return
 
119
  if target_config_dir.exists():
120
  logger.info(f"Target VAE config path {target_config_dir} already exists. Skipping copy.")
121
  return
 
122
  logger.info(f"Patching SeedVR config path: Copying {source_config_dir} to {target_config_dir}...")
123
  try:
 
124
  target_config_parent_dir.mkdir(parents=True, exist_ok=True)
 
125
  shutil.copytree(source_config_dir, target_config_dir)
126
  logger.info("Config path patched successfully.")
127
  except Exception as e:
128
  logger.error(f"Failed to patch SeedVR config path: {e}", exc_info=True)
129
  raise IOError("Could not patch the required SeedVR configuration paths.")
 
130
 
131
  def _download_models(self):
132
+ logger.info("Verifying and downloading SeedVR2 3B model checkpoints...")
 
133
  ckpt_dir = SEEDVR_SPACE_DIR / 'ckpts'
134
  ckpt_dir.mkdir(exist_ok=True)
135
 
136
  pretrain_model_urls = {
137
  'vae_ckpt': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
138
  'dit_3b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
139
+ # 'dit_7b' REMOVIDO
140
  'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
141
  'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt'
142
  }
143
  for key, url in pretrain_model_urls.items():
144
  _load_file_from_url(url=url, model_dir=str(ckpt_dir))
145
+ logger.info("SeedVR2 3B model checkpoints downloaded successfully.")
146
 
147
+ def _initialize_runner(self): # <--- REMOVIDO model_version
 
148
  if self.runner is not None: return
 
 
149
  self._patch_config_paths()
 
150
  self._download_models()
151
 
152
  if dist.is_available() and not dist.is_initialized():
 
 
153
  self._original_barrier = dist.barrier
154
  dist.barrier = lambda *args, **kwargs: None
155
 
156
+ logger.info("Initializing SeedVR2 3B runner...")
157
+ config_path = SEEDVR_SPACE_DIR / 'configs_3b' / 'main.yaml'
158
+ checkpoint_path = SEEDVR_SPACE_DIR / 'ckpts' / 'seedvr2_ema_3b.pth'
159
+
 
 
 
 
 
 
 
160
  config = load_config(str(config_path))
161
 
162
  self.runner = VideoDiffusionInfer(config)
 
166
  if hasattr(self.runner.vae, "set_memory_limit"):
167
  self.runner.vae.set_memory_limit(**self.runner.config.vae.memory_limit)
168
  self.is_initialized = True
169
+ logger.info("Runner for SeedVR2 3B initialized and ready.")
170
 
 
171
  def _unload_runner(self):
172
+ # ... (sem alterações) ...
173
  if self.runner is not None:
174
  del self.runner; self.runner = None
175
  gc.collect(); torch.cuda.empty_cache()
176
  self.is_initialized = False
177
  logger.info("SeedVR runner unloaded from VRAM.")
178
  if self._original_barrier is not None:
 
179
  dist.barrier = self._original_barrier
180
  self._original_barrier = None
181
 
182
  def process_video(self, input_video_path: str, output_video_path: str, prompt: str,
183
+ steps: int = 50, seed: int = 666,
184
+ progress: gr.Progress = None) -> str: # <--- REMOVIDO model_version
185
  try:
186
+ self._initialize_runner() # <--- REMOVIDO model_version
187
  set_seed(seed, same_across_ranks=True)
188
  self.runner.config.diffusion.timesteps.sampling.steps = steps
189
  self.runner.configure_diffusion()
190
+ # ... (resto da função sem alterações) ...
191
  video_tensor = read_video(input_video_path, output_format="TCHW")[0] / 255.0
192
  res_h, res_w = video_tensor.shape[-2:]
193
  video_transform = Compose([
 
231
  self._unload_runner()
232
 
233
  def _load_file_from_url(url, model_dir='./', file_name=None):
234
+ # ... (sem alterações) ...
235
  os.makedirs(model_dir, exist_ok=True)
236
  filename = file_name or os.path.basename(urlparse(url).path)
237
  cached_file = os.path.abspath(os.path.join(model_dir, filename))