fix vae nan bug
vqvae.py CHANGED
@@ -56,9 +56,13 @@ class Autoencoder(nn.Module):
         :param img: is the image tensor with shape `[batch_size, img_channels, img_height, img_width]`
         """
         # Get embeddings with shape `[batch_size, z_channels * 2, z_height, z_height]`
+        print(f"encoder parameters max: {max([p.max() for p in self.encoder.parameters()])}")
+        print(f"encoder parameters min: {min([p.min() for p in self.encoder.parameters()])}")
         z = self.encoder(img)
+        print(f"z.max(): {z.max()}, z.min(): {z.min()}")
         # Get the moments in the quantized embedding space
         moments = self.quant_conv(z)
+        print(f"moments.max(): {moments.max()}, moments.min(): {moments.min()}")
         # Return the distribution
         return GaussianDistribution(moments)
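This hunk instruments `Autoencoder.encode` with debug prints that dump the extremes of the encoder weights, the embedding `z`, and the `moments` tensor, so the first place where values blow up or turn into NaN shows up in the Space logs. The same check can be written as a small helper that only reports problem tensors; this is a sketch for illustration, and the `check_finite` name is made up rather than anything in vqvae.py:

```python
import torch


def check_finite(name: str, t: torch.Tensor) -> None:
    # Hypothetical helper, not part of the commit: report NaN/Inf counts and extremes.
    if not torch.isfinite(t).all():
        n_nan = int(torch.isnan(t).sum())
        n_inf = int(torch.isinf(t).sum())
        print(f"{name}: {n_nan} NaNs, {n_inf} Infs, min={t.min().item()}, max={t.max().item()}")


# Usage mirroring the prints in the diff (inside Autoencoder.encode):
#   check_finite("z", z)
#   check_finite("moments", moments)
```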
@@ -284,6 +288,7 @@ class GaussianDistribution:
         `[batch_size, z_channels * 2, z_height, z_height]`
         """
         # Split mean and log of variance
+        print(f"parameters.max(): {parameters.max()}, parameters.min(): {parameters.min()}")
         self.mean, log_var = torch.chunk(parameters, 2, dim=1)
         # Clamp the log of variances
         self.log_var = torch.clamp(log_var, -30.0, 20.0)
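In `GaussianDistribution.__init__` the new print brackets the incoming `parameters` before they are split into mean and log-variance. The clamp on `log_var` is what keeps the standard deviation finite, since the usual derivation is `std = exp(0.5 * log_var)` and `float32` overflows once the exponent gets large. A minimal, stand-alone sketch of that failure mode (not repository code):

```python
import torch

# Without a clamp, a large log-variance overflows float32 when exponentiated.
log_var = torch.tensor([10.0, 20.0, 200.0])
print(torch.exp(0.5 * log_var))
# tensor([1.4841e+02, 2.2026e+04, inf]) -- the inf propagates into the sample
# and typically shows up as NaN a few operations later.

# With the clamp used in GaussianDistribution, the worst case stays finite.
clamped = torch.clamp(log_var, -30.0, 20.0)
print(torch.exp(0.5 * clamped))
# tensor([1.4841e+02, 2.2026e+04, 2.2026e+04])
```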
@@ -293,6 +298,8 @@ class GaussianDistribution:

     def sample(self):
         # Sample from the distribution
+        print(f"self.mean.max(): {self.mean.max()}, self.mean.min(): {self.mean.min()}")
+        print(f"self.std.max(): {self.std.max()}, self.std.min(): {self.std.min()}")
         return self.mean + self.std * torch.randn_like(self.std)

     def kl(self):
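`sample()` draws a latent with the reparameterization trick, `mean + std * eps` with `eps ~ N(0, I)`, so the two prints bracket the only inputs that can make the sample non-finite. For reference, a self-contained sketch of such a distribution object; the `std = exp(0.5 * log_var)` line is an assumption based on the standard VAE formulation, since this diff does not show how `self.std` is set:

```python
import torch


class DiagonalGaussian:
    """Illustrative stand-in for GaussianDistribution; not the repository class."""

    def __init__(self, parameters: torch.Tensor):
        # Split mean and log of variance along the channel dimension
        self.mean, log_var = torch.chunk(parameters, 2, dim=1)
        # Clamp the log of variances so exp() stays finite in float32
        self.log_var = torch.clamp(log_var, -30.0, 20.0)
        # Assumed derivation of the standard deviation (standard VAE practice)
        self.std = torch.exp(0.5 * self.log_var)

    def sample(self) -> torch.Tensor:
        # Reparameterization trick: mean + std * eps, eps ~ N(0, I)
        return self.mean + self.std * torch.randn_like(self.std)


moments = torch.randn(1, 8, 32, 32)  # [batch_size, z_channels * 2, z_height, z_width]
dist = DiagonalGaussian(moments)
print(torch.isfinite(dist.sample()).all())  # tensor(True)
```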