Spaces:
Runtime error
Runtime error
enable zerogpu
Browse files
app.py
CHANGED
|
@@ -312,6 +312,7 @@ def get_ref_anno(ref):
|
|
| 312 |
point_labels=input_label,
|
| 313 |
multimask_output=False,
|
| 314 |
)
|
|
|
|
| 315 |
hand_mask = masks[0]
|
| 316 |
masked_img = img * hand_mask[..., None] + 255 * (1 - hand_mask[..., None])
|
| 317 |
ref_pose = visualize_hand(keypts, masked_img)
|
|
@@ -323,51 +324,48 @@ def get_ref_anno(ref):
|
|
| 323 |
|
| 324 |
@spaces.GPU(duration=120)
|
| 325 |
def make_ref_cond(
|
| 326 |
-
|
| 327 |
-
keypts,
|
| 328 |
-
hand_mask,
|
| 329 |
-
device=device,
|
| 330 |
-
target_size=(256, 256),
|
| 331 |
-
latent_size=(32, 32),
|
| 332 |
):
|
| 333 |
-
|
| 334 |
-
[
|
| 335 |
-
ToTensor(),
|
| 336 |
-
Resize(target_size),
|
| 337 |
-
Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
| 338 |
-
]
|
| 339 |
-
)
|
| 340 |
-
image = image_transform(img).to(device)
|
| 341 |
-
kpts_valid = check_keypoints_validity(keypts, target_size)
|
| 342 |
-
heatmaps = torch.tensor(
|
| 343 |
-
keypoint_heatmap(
|
| 344 |
-
scale_keypoint(keypts, target_size, latent_size), latent_size, var=1.0
|
| 345 |
-
)
|
| 346 |
-
* kpts_valid[:, None, None],
|
| 347 |
-
dtype=torch.float,
|
| 348 |
-
device=device,
|
| 349 |
-
)[None, ...]
|
| 350 |
-
mask = torch.tensor(
|
| 351 |
-
cv2.resize(
|
| 352 |
-
hand_mask.astype(int),
|
| 353 |
-
dsize=latent_size,
|
| 354 |
-
interpolation=cv2.INTER_NEAREST,
|
| 355 |
-
),
|
| 356 |
-
dtype=torch.float,
|
| 357 |
-
device=device,
|
| 358 |
-
).unsqueeze(0)[None, ...]
|
| 359 |
latent = opts.latent_scaling_factor * autoencoder.encode(image[None, ...]).sample()
|
| 360 |
-
return image[None, ...],
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 366 |
device=device,
|
| 367 |
-
|
| 368 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
)
|
| 370 |
-
|
|
|
|
| 371 |
if not REF_POSE_MASK:
|
| 372 |
heatmaps = torch.zeros_like(heatmaps)
|
| 373 |
mask = torch.zeros_like(mask)
|
|
|
|
| 312 |
point_labels=input_label,
|
| 313 |
multimask_output=False,
|
| 314 |
)
|
| 315 |
+
print("finished SAM")
|
| 316 |
hand_mask = masks[0]
|
| 317 |
masked_img = img * hand_mask[..., None] + 255 * (1 - hand_mask[..., None])
|
| 318 |
ref_pose = visualize_hand(keypts, masked_img)
|
|
|
|
| 324 |
|
| 325 |
@spaces.GPU(duration=120)
|
| 326 |
def make_ref_cond(
|
| 327 |
+
image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
):
|
| 329 |
+
print("ready to run autoencoder")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
latent = opts.latent_scaling_factor * autoencoder.encode(image[None, ...]).sample()
|
| 331 |
+
return image[None, ...], latent
|
| 332 |
+
|
| 333 |
+
image_transform = Compose(
|
| 334 |
+
[
|
| 335 |
+
ToTensor(),
|
| 336 |
+
Resize(opts.image_size),
|
| 337 |
+
Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
| 338 |
+
]
|
| 339 |
+
)
|
| 340 |
+
image = image_transform(img).to(device)
|
| 341 |
+
kpts_valid = check_keypoints_validity(keypts, opts.image_size)
|
| 342 |
+
heatmaps = torch.tensor(
|
| 343 |
+
keypoint_heatmap(
|
| 344 |
+
scale_keypoint(keypts, opts.image_size, opts.latent_size), opts.latent_size, var=1.0
|
| 345 |
+
)
|
| 346 |
+
* kpts_valid[:, None, None],
|
| 347 |
+
dtype=torch.float,
|
| 348 |
device=device,
|
| 349 |
+
)[None, ...]
|
| 350 |
+
mask = torch.tensor(
|
| 351 |
+
cv2.resize(
|
| 352 |
+
hand_mask.astype(int),
|
| 353 |
+
dsize=opts.latent_size,
|
| 354 |
+
interpolation=cv2.INTER_NEAREST,
|
| 355 |
+
),
|
| 356 |
+
dtype=torch.float,
|
| 357 |
+
device=device,
|
| 358 |
+
).unsqueeze(0)[None, ...]
|
| 359 |
+
image, latent = make_ref_cond(
|
| 360 |
+
image,
|
| 361 |
+
# keypts,
|
| 362 |
+
# hand_mask,
|
| 363 |
+
# device=device,
|
| 364 |
+
# target_size=opts.image_size,
|
| 365 |
+
# latent_size=opts.latent_size,
|
| 366 |
)
|
| 367 |
+
print("finished autoencoder")
|
| 368 |
+
|
| 369 |
if not REF_POSE_MASK:
|
| 370 |
heatmaps = torch.zeros_like(heatmaps)
|
| 371 |
mask = torch.zeros_like(mask)
|