runtime fix
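This commit drops the hard CUDA requirement and instead selects a device at import time, then threads that device through model construction, checkpoint loading, and tensor creation. A minimal standalone sketch of the pattern (the `Linear` model and `checkpoint.pth` path are illustrative placeholders, not objects from app.py):

    import torch

    # Fall back to CPU when no GPU is visible, instead of raising an error.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Move the model to the chosen device; stands in for the DiT and VQ-VAE in app.py.
    model = torch.nn.Linear(8, 8).eval().requires_grad_(False).to(device)

    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    state_dict = torch.load("checkpoint.pth", map_location=torch.device(device))
    model.load_state_dict(state_dict, strict=False)

    # Tensors used for sampling must be created on the same device as the model.
    z = torch.randn(1, 8, device=device)
    with torch.no_grad():
        out = model(z)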
app.py
CHANGED
@@ -33,6 +33,10 @@ def set_seed(seed):
     torch.cuda.manual_seed_all(seed)
     random.seed(seed)
 
+if torch.cuda.is_available():
+    device = "cuda"
+else:
+    device = "cpu"
 
 def remove_prefix(text, prefix):
     if text.startswith(prefix):
@@ -176,9 +180,6 @@ class HandDiffOpts:
     num_workers: int = 10
     n_val_samples: int = 4
 
-if not torch.cuda.is_available():
-    raise ValueError("No GPU")
-
 # load models
 if NEW_MODEL:
     opts = HandDiffOpts()
@@ -202,15 +203,15 @@ if NEW_MODEL:
         latent_dim=opts.latent_dim,
         in_channels=opts.latent_dim+opts.n_keypoints+opts.n_mask,
         learn_sigma=True,
-    ).
+    ).to(device)
     # ckpt_state_dict = torch.load(model_path)['model_state_dict']
-    ckpt_state_dict = torch.load(model_path, map_location=torch.device(
+    ckpt_state_dict = torch.load(model_path, map_location=torch.device(device))['ema_state_dict']
     missing_keys, extra_keys = model.load_state_dict(ckpt_state_dict, strict=False)
     model.eval()
     print(missing_keys, extra_keys)
     assert len(missing_keys) == 0
     vae_state_dict = torch.load(vae_path)['state_dict']
-    autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).
+    autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).to(device)
     missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
     autoencoder.eval()
     assert len(missing_keys) == 0
@@ -225,18 +226,18 @@ else:
         latent_dim=opts.latent_dim,
         in_channels=opts.latent_dim+opts.n_keypoints+opts.n_mask,
         learn_sigma=True,
-    ).
+    ).to(device)
     ckpt_state_dict = torch.load(model_path)['state_dict']
     dit_state_dict = {remove_prefix(k, 'diffusion_backbone.'): v for k, v in ckpt_state_dict.items() if k.startswith('diffusion_backbone')}
     vae_state_dict = {remove_prefix(k, 'autoencoder.'): v for k, v in ckpt_state_dict.items() if k.startswith('autoencoder')}
     missing_keys, extra_keys = model.load_state_dict(dit_state_dict, strict=False)
     model.eval()
     assert len(missing_keys) == 0 and len(extra_keys) == 0
-    autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).
+    autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).to(device)
     missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
     autoencoder.eval()
     assert len(missing_keys) == 0 and len(extra_keys) == 0
-sam_predictor = init_sam(ckpt_path="./sam_vit_h_4b8939.pth")
+sam_predictor = init_sam(ckpt_path="./sam_vit_h_4b8939.pth", device=device)
 
 
 print("Mediapipe hand detector and SAM ready...")
@@ -312,7 +313,7 @@ def get_ref_anno(ref):
         img,
         keypts,
         hand_mask,
-        device=
+        device=device,
         target_size=(256, 256),
         latent_size=(32, 32),
     ):
@@ -348,7 +349,7 @@ def get_ref_anno(ref):
         img,
         keypts,
         hand_mask,
-        device=
+        device=device,
         target_size=opts.image_size,
         latent_size=opts.latent_size,
     )
@@ -405,7 +406,7 @@ def get_target_anno(target):
         )
         * kpts_valid[:, None, None],
         dtype=torch.float,
-        device=
+        device=device,
     )[None, ...]
     target_cond = torch.cat(
         [target_heatmaps, torch.zeros_like(target_heatmaps)[:, :1]], 1
@@ -525,12 +526,12 @@ def sample_diff(ref_cond, target_cond, target_keypts, num_gen, seed, cfg):
     set_seed(seed)
     z = torch.randn(
         (num_gen, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]),
-        device=
+        device=device,
    )
     target_cond = target_cond.repeat(num_gen, 1, 1, 1)
     ref_cond = ref_cond.repeat(num_gen, 1, 1, 1)
     # novel view synthesis mode = off
-    nvs = torch.zeros(num_gen, dtype=torch.int, device=
+    nvs = torch.zeros(num_gen, dtype=torch.int, device=device)
     z = torch.cat([z, z], 0)
     model_kwargs = dict(
         target_cond=torch.cat([target_cond, torch.zeros_like(target_cond)]),
@@ -546,7 +547,7 @@ def sample_diff(ref_cond, target_cond, target_keypts, num_gen, seed, cfg):
         clip_denoised=False,
         model_kwargs=model_kwargs,
         progress=True,
-        device=
+        device=device,
     ).chunk(2)
     sampled_images = autoencoder.decode(samples / opts.latent_scaling_factor)
     sampled_images = torch.clamp(sampled_images, min=-1.0, max=1.0)
@@ -635,14 +636,14 @@ def ready_sample(img_ori, inpaint_mask, keypts):
             inpaint_mask, dsize=opts.latent_size, interpolation=cv2.INTER_NEAREST
         ),
         dtype=torch.float,
-        device=
+        device=device,
     ).unsqueeze(0)[None, ...]
 
     def make_ref_cond(
         img,
         keypts,
         hand_mask,
-        device=
+        device=device,
         target_size=(256, 256),
         latent_size=(32, 32),
     ):
@@ -678,7 +679,7 @@ def ready_sample(img_ori, inpaint_mask, keypts):
         img,
         keypts,
         hand_mask * (1 - inpaint_mask),
-        device=
+        device=device,
         target_size=opts.image_size,
         latent_size=opts.latent_size,
     )
@@ -736,12 +737,12 @@ def sample_inpaint(
     jump_n_sample = quality
     cfg_scale = cfg
     z = torch.randn(
-        (N, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]), device=
+        (N, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]), device=device
     )
     target_cond_N = target_cond.repeat(N, 1, 1, 1)
     ref_cond_N = ref_cond.repeat(N, 1, 1, 1)
     # novel view synthesis mode = off
-    nvs = torch.zeros(N, dtype=torch.int, device=
+    nvs = torch.zeros(N, dtype=torch.int, device=device)
     z = torch.cat([z, z], 0)
     model_kwargs = dict(
         target_cond=torch.cat([target_cond_N, torch.zeros_like(target_cond_N)]),
@@ -759,7 +760,7 @@ def sample_inpaint(
         clip_denoised=False,
         model_kwargs=model_kwargs,
         progress=True,
-        device=
+        device=device,
         jump_length=jump_length,
         jump_n_sample=jump_n_sample,
     ).chunk(2)