Spaces:
Runtime error
Runtime error
fix vae nan bug
Browse files
app.py
CHANGED
|
@@ -217,7 +217,21 @@ if NEW_MODEL:
|
|
| 217 |
model.eval()
|
| 218 |
print(missing_keys, extra_keys)
|
| 219 |
assert len(missing_keys) == 0
|
| 220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
# else:
|
| 222 |
# opts = HandDiffOpts()
|
| 223 |
# model_path = './finetune_epoch=5-step=130000.ckpt'
|
|
@@ -261,24 +275,8 @@ def get_ref_anno(ref):
|
|
| 261 |
None,
|
| 262 |
None,
|
| 263 |
)
|
| 264 |
-
|
| 265 |
-
vae_state_dict = torch.load(vae_path, map_location='cpu')['state_dict']
|
| 266 |
-
print(f"vae_state_dict encoder dtype: {vae_state_dict['encoder.conv_in.weight'].dtype}")
|
| 267 |
-
autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False)
|
| 268 |
-
print(f"autoencoder encoder dtype: {next(autoencoder.encoder.parameters()).dtype}")
|
| 269 |
-
print(f"encoder before load_state_dict parameters min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
|
| 270 |
-
print(f"encoder before load_state_dict parameters max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
|
| 271 |
missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
|
| 272 |
-
print(f"encoder after load_state_dict parameters min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
|
| 273 |
-
print(f"encoder after load_state_dict parameters max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
|
| 274 |
-
autoencoder = autoencoder.to(device)
|
| 275 |
-
autoencoder.eval()
|
| 276 |
-
print(f"encoder after eval() min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
|
| 277 |
-
print(f"encoder after eval() max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
|
| 278 |
-
print(f"autoencoder encoder after eval() dtype: {next(autoencoder.encoder.parameters()).dtype}")
|
| 279 |
-
assert len(missing_keys) == 0
|
| 280 |
|
| 281 |
-
|
| 282 |
img = ref["composite"][..., :3]
|
| 283 |
img = cv2.resize(img, opts.image_size, interpolation=cv2.INTER_AREA)
|
| 284 |
keypts = np.zeros((42, 2))
|
|
|
|
| 217 |
model.eval()
|
| 218 |
print(missing_keys, extra_keys)
|
| 219 |
assert len(missing_keys) == 0
|
| 220 |
+
vae_state_dict = torch.load(vae_path, map_location='cpu')['state_dict']
|
| 221 |
+
print(f"vae_state_dict encoder dtype: {vae_state_dict['encoder.conv_in.weight'].dtype}")
|
| 222 |
+
autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False)
|
| 223 |
+
print(f"autoencoder encoder dtype: {next(autoencoder.encoder.parameters()).dtype}")
|
| 224 |
+
print(f"encoder before load_state_dict parameters min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
|
| 225 |
+
print(f"encoder before load_state_dict parameters max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
|
| 226 |
+
missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
|
| 227 |
+
print(f"encoder after load_state_dict parameters min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
|
| 228 |
+
print(f"encoder after load_state_dict parameters max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
|
| 229 |
+
autoencoder = autoencoder.to(device)
|
| 230 |
+
autoencoder.eval()
|
| 231 |
+
print(f"encoder after eval() min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
|
| 232 |
+
print(f"encoder after eval() max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
|
| 233 |
+
print(f"autoencoder encoder after eval() dtype: {next(autoencoder.encoder.parameters()).dtype}")
|
| 234 |
+
assert len(missing_keys) == 0
|
| 235 |
# else:
|
| 236 |
# opts = HandDiffOpts()
|
| 237 |
# model_path = './finetune_epoch=5-step=130000.ckpt'
|
|
|
|
| 275 |
None,
|
| 276 |
None,
|
| 277 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
|
|
|
|
| 280 |
img = ref["composite"][..., :3]
|
| 281 |
img = cv2.resize(img, opts.image_size, interpolation=cv2.INTER_AREA)
|
| 282 |
keypts = np.zeros((42, 2))
|