| """ | |
| --- | |
| title: Autoencoder for Stable Diffusion | |
| summary: > | |
| Annotated PyTorch implementation/tutorial of the autoencoder | |
| for stable diffusion. | |
| --- | |
| # Autoencoder for [Stable Diffusion](../index.html) | |
| This implements the auto-encoder model used to map between image space and latent space. | |
| We have kept to the model definition and naming unchanged from | |
| [CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion) | |
| so that we can load the checkpoints directly. | |
| """ | |

from typing import List

import torch
import torch.nn.functional as F
from torch import nn

# `spaces` is only available when running on Hugging Face Spaces
try:
    import spaces
except ImportError:
    pass


class Autoencoder(nn.Module):
    """
    ## Autoencoder

    This consists of the encoder and decoder modules.
    """

    def __init__(
        self, encoder: "Encoder", decoder: "Decoder", emb_channels: int, z_channels: int
    ):
        """
        :param encoder: is the encoder
        :param decoder: is the decoder
        :param emb_channels: is the number of dimensions in the quantized embedding space
        :param z_channels: is the number of channels in the embedding space
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        # Convolution to map from embedding space to
        # quantized embedding space moments (mean and log variance)
        self.quant_conv = nn.Conv2d(2 * z_channels, 2 * emb_channels, 1)
        # Convolution to map from quantized embedding space back to
        # embedding space
        self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, 1)

    def encode(self, img: torch.Tensor) -> "GaussianDistribution":
        """
        ### Encode images to latent representation

        :param img: is the image tensor with shape `[batch_size, img_channels, img_height, img_width]`
        """
        # Get embeddings with shape `[batch_size, z_channels * 2, z_height, z_width]`
        # Debug output of parameter and activation ranges
        print(f"encoder parameters max: {max([p.max() for p in self.encoder.parameters()])}")
        print(f"encoder parameters min: {min([p.min() for p in self.encoder.parameters()])}")
        print(f"img.dtype: {img.dtype}")
        z = self.encoder(img)
        print(f"z.max(): {z.max()}, z.min(): {z.min()}")
        # Get the moments in the quantized embedding space
        moments = self.quant_conv(z)
        print(f"moments.max(): {moments.max()}, moments.min(): {moments.min()}")
        # Return the distribution
        return GaussianDistribution(moments)

    def decode(self, z: torch.Tensor):
        """
        ### Decode images from latent representation

        :param z: is the latent representation with shape `[batch_size, emb_channels, z_height, z_width]`
        """
        # Map to embedding space from the quantized representation
        z = self.post_quant_conv(z)
        # Decode the image of shape `[batch_size, channels, height, width]`
        return self.decoder(z)

    def forward(self, x):
        # Encode the image, sample a latent from the posterior, and decode it back
        posterior = self.encode(x)
        z = posterior.sample()
        dec = self.decode(z)
        return dec, posterior
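
# Illustration (added for exposition, not part of the original source): with the
# default configuration used in `create_model` below (`channels=128`,
# `channel_multipliers=[1, 2, 4, 4]`, `latent_dim=4`), the encoder downsamples by
# a factor of 8. A `[1, 3, 512, 512]` image therefore maps to moments of shape
# `[1, 8, 64, 64]`, a sampled latent of shape `[1, 4, 64, 64]`, and `decode`
# maps that latent back to a `[1, 3, 512, 512]` image.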


class Encoder(nn.Module):
    """
    ## Encoder module
    """

    def __init__(
        self,
        *,
        channels: int,
        channel_multipliers: List[int],
        n_resnet_blocks: int,
        in_channels: int,
        z_channels: int
    ):
        """
        :param channels: is the number of channels in the first convolution layer
        :param channel_multipliers: are the multiplicative factors for the number of channels in the
            subsequent blocks
        :param n_resnet_blocks: is the number of resnet layers at each resolution
        :param in_channels: is the number of channels in the image
        :param z_channels: is the number of channels in the embedding space
        """
        super().__init__()

        # Number of blocks of different resolutions.
        # The resolution is halved at the end of each top-level block
        n_resolutions = len(channel_multipliers)

        # Initial $3 \times 3$ convolution layer that maps the image to `channels`
        self.conv_in = nn.Conv2d(in_channels, channels, 3, stride=1, padding=1)

        # Number of channels in each top-level block
        channels_list = [m * channels for m in [1] + channel_multipliers]

        # List of top-level blocks
        self.down = nn.ModuleList()
        # Create top-level blocks
        for i in range(n_resolutions):
            # Each top-level block consists of multiple ResNet blocks and down-sampling
            resnet_blocks = nn.ModuleList()
            # Add ResNet blocks
            for _ in range(n_resnet_blocks):
                resnet_blocks.append(ResnetBlock(channels, channels_list[i + 1]))
                channels = channels_list[i + 1]
            # Top-level block
            down = nn.Module()
            down.block = resnet_blocks
            # Down-sampling at the end of each top-level block except the last
            if i != n_resolutions - 1:
                down.downsample = DownSample(channels)
            else:
                down.downsample = nn.Identity()
            #
            self.down.append(down)

        # Final ResNet blocks with attention
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(channels, channels)
        self.mid.attn_1 = AttnBlock(channels)
        self.mid.block_2 = ResnetBlock(channels, channels)

        # Map to embedding space with a $3 \times 3$ convolution
        self.norm_out = normalization(channels)
        self.conv_out = nn.Conv2d(channels, 2 * z_channels, 3, stride=1, padding=1)

    def forward(self, img: torch.Tensor):
        """
        :param img: is the image tensor with shape `[batch_size, img_channels, img_height, img_width]`
        """
        # Map to `channels` with the initial convolution
        x = self.conv_in(img)

        # Top-level blocks
        for down in self.down:
            # ResNet blocks
            for block in down.block:
                x = block(x)
            # Down-sampling
            x = down.downsample(x)

        # Final ResNet blocks with attention
        x = self.mid.block_1(x)
        x = self.mid.attn_1(x)
        x = self.mid.block_2(x)

        # Normalize and map to embedding space
        x = self.norm_out(x)
        x = swish(x)
        x = self.conv_out(x)

        #
        return x
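
# Illustration (added for exposition, not part of the original source): with
# `channels=128` and `channel_multipliers=[1, 2, 4, 4]` (the values used in
# `create_model` below), `channels_list` is `[128, 128, 256, 512, 512]`, so the
# four top-level blocks run at 128, 256, 512 and 512 channels, with down-sampling
# after the first three.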


class Decoder(nn.Module):
    """
    ## Decoder module
    """

    def __init__(
        self,
        *,
        channels: int,
        channel_multipliers: List[int],
        n_resnet_blocks: int,
        out_channels: int,
        z_channels: int
    ):
        """
        :param channels: is the number of channels in the final convolution layer
        :param channel_multipliers: are the multiplicative factors for the number of channels in the
            previous blocks, in reverse order
        :param n_resnet_blocks: is the number of resnet layers at each resolution
        :param out_channels: is the number of channels in the image
        :param z_channels: is the number of channels in the embedding space
        """
        super().__init__()

        # Number of blocks of different resolutions.
        # The resolution is doubled at the end of each top-level block
        num_resolutions = len(channel_multipliers)

        # Number of channels in each top-level block, in the reverse order
        channels_list = [m * channels for m in channel_multipliers]

        # Number of channels in the top-level block
        channels = channels_list[-1]

        # Initial $3 \times 3$ convolution layer that maps the embedding space to `channels`
        self.conv_in = nn.Conv2d(z_channels, channels, 3, stride=1, padding=1)

        # ResNet blocks with attention
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(channels, channels)
        self.mid.attn_1 = AttnBlock(channels)
        self.mid.block_2 = ResnetBlock(channels, channels)

        # List of top-level blocks
        self.up = nn.ModuleList()
        # Create top-level blocks
        for i in reversed(range(num_resolutions)):
            # Each top-level block consists of multiple ResNet blocks and up-sampling
            resnet_blocks = nn.ModuleList()
            # Add ResNet blocks
            for _ in range(n_resnet_blocks + 1):
                resnet_blocks.append(ResnetBlock(channels, channels_list[i]))
                channels = channels_list[i]
            # Top-level block
            up = nn.Module()
            up.block = resnet_blocks
            # Up-sampling at the end of each top-level block except the first
            if i != 0:
                up.upsample = UpSample(channels)
            else:
                up.upsample = nn.Identity()
            # Prepend to be consistent with the checkpoint
            self.up.insert(0, up)

        # Map to image space with a $3 \times 3$ convolution
        self.norm_out = normalization(channels)
        self.conv_out = nn.Conv2d(channels, out_channels, 3, stride=1, padding=1)

    def forward(self, z: torch.Tensor):
        """
        :param z: is the embedding tensor with shape `[batch_size, z_channels, z_height, z_width]`
        """
        # Map to `channels` with the initial convolution
        h = self.conv_in(z)

        # ResNet blocks with attention
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # Top-level blocks
        for up in reversed(self.up):
            # ResNet blocks
            for block in up.block:
                h = block(h)
            # Up-sampling
            h = up.upsample(h)

        # Normalize and map to image space
        h = self.norm_out(h)
        h = swish(h)
        img = self.conv_out(h)

        #
        return img


class GaussianDistribution:
    """
    ## Gaussian Distribution
    """

    def __init__(self, parameters: torch.Tensor):
        """
        :param parameters: are the means and log of variances of the embedding of shape
            `[batch_size, z_channels * 2, z_height, z_width]`
        """
        # Split mean and log of variance
        print(f"parameters.max(): {parameters.max()}, parameters.min(): {parameters.min()}")
        self.mean, log_var = torch.chunk(parameters, 2, dim=1)
        # Clamp the log of variances
        self.log_var = torch.clamp(log_var, -30.0, 20.0)
        # Calculate standard deviation
        self.std = torch.exp(0.5 * self.log_var)
        self.var = torch.exp(self.log_var)

    def sample(self):
        # Sample from the distribution
        print(f"self.mean.max(): {self.mean.max()}, self.mean.min(): {self.mean.min()}")
        print(f"self.std.max(): {self.std.max()}, self.std.min(): {self.std.min()}")
        return self.mean + self.std * torch.randn_like(self.std)

    def kl(self):
        return 0.5 * torch.sum(
            torch.pow(self.mean, 2) + self.var - 1.0 - self.log_var, dim=[1, 2, 3]
        )
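
# Note (added for exposition): `GaussianDistribution.sample` uses the
# reparameterization trick, $z = \mu + \sigma \odot \epsilon$ with
# $\epsilon \sim \mathcal{N}(0, I)$, and `kl` is the closed-form KL divergence
# from the standard normal,
# $D_{KL}\big(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1)\big)
#  = \frac{1}{2} \sum \big(\mu^2 + \sigma^2 - 1 - \log \sigma^2\big)$,
# summed over the channel and spatial dimensions.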


class AttnBlock(nn.Module):
    """
    ## Attention block
    """

    def __init__(self, channels: int):
        """
        :param channels: is the number of channels
        """
        super().__init__()
        # Group normalization
        self.norm = normalization(channels)
        # Query, key and value mappings
        self.q = nn.Conv2d(channels, channels, 1)
        self.k = nn.Conv2d(channels, channels, 1)
        self.v = nn.Conv2d(channels, channels, 1)
        # Final $1 \times 1$ convolution layer
        self.proj_out = nn.Conv2d(channels, channels, 1)
        # Attention scaling factor
        self.scale = channels ** -0.5

    def forward(self, x: torch.Tensor):
        """
        :param x: is the tensor of shape `[batch_size, channels, height, width]`
        """
        # Normalize `x`
        x_norm = self.norm(x)
        # Get query, key and value embeddings
        q = self.q(x_norm)
        k = self.k(x_norm)
        v = self.v(x_norm)

        # Reshape the query, key and value embeddings from
        # `[batch_size, channels, height, width]` to
        # `[batch_size, channels, height * width]`
        b, c, h, w = q.shape
        q = q.view(b, c, h * w)
        k = k.view(b, c, h * w)
        v = v.view(b, c, h * w)

        # Compute $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$
        attn = torch.einsum("bci,bcj->bij", q, k) * self.scale
        attn = F.softmax(attn, dim=2)

        # Compute $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$
        out = torch.einsum("bij,bcj->bci", attn, v)

        # Reshape back to `[batch_size, channels, height, width]`
        out = out.view(b, c, h, w)
        # Final $1 \times 1$ convolution layer
        out = self.proj_out(out)

        # Add residual connection
        return x + out
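
# Note (added for exposition): this is single-head self-attention over the
# `height * width` spatial positions, so the attention matrix has shape
# `[batch_size, height * width, height * width]` and its memory cost grows
# quadratically with resolution. In this model it is only applied in the
# lowest-resolution `mid` stage of the encoder and decoder.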


class UpSample(nn.Module):
    """
    ## Up-sampling layer
    """

    def __init__(self, channels: int):
        """
        :param channels: is the number of channels
        """
        super().__init__()
        # $3 \times 3$ convolution mapping
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x: torch.Tensor):
        """
        :param x: is the input feature map with shape `[batch_size, channels, height, width]`
        """
        # Up-sample by a factor of $2$
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        # Apply convolution
        return self.conv(x)


class DownSample(nn.Module):
    """
    ## Down-sampling layer
    """

    def __init__(self, channels: int):
        """
        :param channels: is the number of channels
        """
        super().__init__()
        # $3 \times 3$ convolution with stride length of $2$ to down-sample by a factor of $2$
        self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=0)

    def forward(self, x: torch.Tensor):
        """
        :param x: is the input feature map with shape `[batch_size, channels, height, width]`
        """
        # Add padding
        x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        # Apply convolution
        return self.conv(x)
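
# Note (added for exposition): the padding above is asymmetric, adding one row
# and one column of zeros on the bottom and right only. Combined with the
# stride-2 convolution that has no built-in padding, an even-sized feature map
# of height $h$ is reduced to exactly $h / 2$, mirroring the down-sampling used
# in CompVis/stable-diffusion.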


class ResnetBlock(nn.Module):
    """
    ## ResNet Block
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        :param in_channels: is the number of channels in the input
        :param out_channels: is the number of channels in the output
        """
        super().__init__()
        # First normalization and convolution layer
        self.norm1 = normalization(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1)
        # Second normalization and convolution layer
        self.norm2 = normalization(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1)
        # `in_channels` to `out_channels` mapping layer for residual connection
        if in_channels != out_channels:
            self.nin_shortcut = nn.Conv2d(
                in_channels, out_channels, 1, stride=1, padding=0
            )
        else:
            self.nin_shortcut = nn.Identity()

    def forward(self, x: torch.Tensor):
        """
        :param x: is the input feature map with shape `[batch_size, channels, height, width]`
        """
        h = x
        # First normalization and convolution layer
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)
        # Second normalization and convolution layer
        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)
        # Map and add residual
        return self.nin_shortcut(x) + h


def swish(x: torch.Tensor):
    """
    ### Swish activation

    $x \cdot \sigma(x)$ (also known as SiLU)
    """
    return x * torch.sigmoid(x)


def normalization(channels: int):
    """
    ### Group normalization

    This is a helper function, with a fixed number of groups and `eps`.
    """
    return nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)
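
# Note (added for exposition): `nn.GroupNorm` with 32 groups requires the number
# of channels to be divisible by 32, which holds for every channel count used in
# this model (128, 256 and 512).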


def restore_ae_from_sd(model, path):
    """
    ### Load autoencoder weights from a Stable Diffusion checkpoint
    """

    def remove_prefix(text, prefix):
        if text.startswith(prefix):
            return text[len(prefix):]
        return text

    # Load the full Stable Diffusion checkpoint
    checkpoint = torch.load(path)
    # checkpoint = torch.load(path, map_location="cpu")
    ckpt_state_dict = checkpoint["state_dict"]
    # The autoencoder weights are stored under the `first_stage_model.` prefix
    new_ckpt_state_dict = {}
    for k, v in ckpt_state_dict.items():
        new_k = remove_prefix(k, "first_stage_model.")
        new_ckpt_state_dict[new_k] = v
    # `strict=False` ignores the remaining (non-autoencoder) keys in the checkpoint,
    # but every autoencoder parameter must be present
    missing_keys, extra_keys = model.load_state_dict(new_ckpt_state_dict, strict=False)
    assert len(missing_keys) == 0


def create_model(in_channels, out_channels, latent_dim=4):
    """
    ### Create an autoencoder with the Stable Diffusion configuration
    """
    encoder = Encoder(
        z_channels=latent_dim,
        in_channels=in_channels,
        channels=128,
        channel_multipliers=[1, 2, 4, 4],
        n_resnet_blocks=2,
    )
    decoder = Decoder(
        out_channels=out_channels,
        z_channels=latent_dim,
        channels=128,
        channel_multipliers=[1, 2, 4, 4],
        n_resnet_blocks=2,
    )
    autoencoder = Autoencoder(
        emb_channels=latent_dim, encoder=encoder, decoder=decoder, z_channels=latent_dim
    )
    return autoencoder
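

# A minimal usage sketch (added for illustration, not part of the original source).
# The checkpoint path is hypothetical, and the random input stands in for a real
# image normalized to `[-1, 1]`.
if __name__ == "__main__":
    # Build the autoencoder with the default Stable Diffusion configuration
    autoencoder = create_model(in_channels=3, out_channels=3)
    # Optionally restore weights from a Stable Diffusion checkpoint, e.g.
    # restore_ae_from_sd(autoencoder, "sd-v1-4.ckpt")
    autoencoder.eval()

    # A dummy batch with one 512x512 RGB image
    img = torch.randn(1, 3, 512, 512)
    with torch.no_grad():
        # Encode to a distribution over latents and sample from it
        posterior = autoencoder.encode(img)
        z = posterior.sample()
        # Decode the latent back to image space
        reconstruction = autoencoder.decode(z)

    print(z.shape)               # torch.Size([1, 4, 64, 64])
    print(reconstruction.shape)  # torch.Size([1, 3, 512, 512])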