enable zerogpu
app.py
CHANGED
@@ -210,13 +210,13 @@ if NEW_MODEL:
         learn_sigma=True,
     ).to(device)
     # ckpt_state_dict = torch.load(model_path)['model_state_dict']
-    ckpt_state_dict = torch.load(model_path, map_location=
+    ckpt_state_dict = torch.load(model_path, map_location='cpu')['ema_state_dict']
     missing_keys, extra_keys = model.load_state_dict(ckpt_state_dict, strict=False)
     model.eval()
     print(missing_keys, extra_keys)
     assert len(missing_keys) == 0
     vae_state_dict = torch.load(vae_path)['state_dict']
-    autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).to(device)
+    autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False) # .to(device)
     missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
     autoencoder.eval()
     assert len(missing_keys) == 0
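For reference, a minimal runnable sketch of the CPU-first checkpoint loading this hunk switches to. The helper name load_ema_weights and the generic torch.nn.Module argument are illustrative; map_location='cpu', the 'ema_state_dict' key, and the strict=False / missing-keys check come from the lines above.

import torch

def load_ema_weights(model: torch.nn.Module, ckpt_path: str) -> torch.nn.Module:
    # ZeroGPU Spaces expose no GPU at import time, so deserialize onto the CPU.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    state_dict = ckpt["ema_state_dict"]  # EMA weights, as selected in the hunk above
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    # Missing keys would mean the checkpoint does not match the architecture;
    # extra (unexpected) keys such as training-only buffers are tolerated.
    assert len(missing) == 0, missing
    return model.eval().requires_grad_(False)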
@@ -243,7 +243,7 @@ else:
     autoencoder.eval()
     assert len(missing_keys) == 0 and len(extra_keys) == 0
 sam_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="sam_vit_h_4b8939.pth", token=token)
-sam_predictor = init_sam(ckpt_path=sam_path, device=
+sam_predictor = init_sam(ckpt_path=sam_path, device='cpu')


 print("Mediapipe hand detector and SAM ready...")
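init_sam is this repo's own helper, so the sketch below only shows what such a helper typically wraps using the public segment_anything API (sam_model_registry, SamPredictor); treat the body as an assumption, not the app's actual implementation.

from segment_anything import SamPredictor, sam_model_registry

def init_sam_sketch(ckpt_path: str, device: str = "cpu") -> SamPredictor:
    sam = sam_model_registry["vit_h"](checkpoint=ckpt_path)  # ViT-H weights file downloaded above
    sam.to(device)  # keep on CPU; ZeroGPU only attaches CUDA inside @spaces.GPU calls
    return SamPredictor(sam)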
@@ -254,7 +254,7 @@ hands = mp_hands.Hands(
     min_detection_confidence=0.1,
 )

-
+@spaces.GPU(duration=120)
 def get_ref_anno(ref):
     if ref is None:
         return (
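A minimal sketch of the ZeroGPU pattern this decorator enables: weights stay on CPU at import time, and CUDA is only available while a @spaces.GPU-decorated function runs. The Linear model and the run_inference name are placeholders, not the app's API.

import spaces
import torch

# Loaded on CPU at import time, mirroring the hunks above (placeholder model).
model = torch.nn.Linear(4, 4).eval().requires_grad_(False)

@spaces.GPU(duration=120)  # a GPU is attached only for the duration of each call
def run_inference(x: torch.Tensor) -> torch.Tensor:
    model.to("cuda")                      # safe here: CUDA exists inside the decorated call
    with torch.no_grad():
        return model(x.to("cuda")).cpu()  # move the result back before returning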
@@ -301,6 +301,7 @@ def get_ref_anno(ref):
     elif keypts[21].sum() != 0:
         input_point = np.array(keypts[21:22])
         input_label = np.array([1])
+    print("ready to run SAM")
     masks, _, _ = sam_predictor.predict(
         point_coords=input_point,
         point_labels=input_label,
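Finally, a hedged usage sketch of the point-prompted predict() call in this hunk. It reuses the init_sam_sketch helper from above; the placeholder image and keypoint stand in for the app's real reference frame and MediaPipe keypoints, and multimask_output=False is an illustrative choice rather than a value shown in the diff.

import numpy as np

sam_predictor = init_sam_sketch("sam_vit_h_4b8939.pth", device="cpu")
image = np.zeros((512, 512, 3), dtype=np.uint8)   # placeholder RGB frame
sam_predictor.set_image(image)                    # must be called before predict()

input_point = np.array([[256.0, 256.0]])          # one (x, y) prompt, e.g. a wrist keypoint
input_label = np.array([1])                       # 1 marks a foreground point
print("ready to run SAM")
masks, scores, low_res = sam_predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=False,                       # one mask rather than three candidates
)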