| """This file contains the model definition of TiTok. | |
| Copyright (2024) Bytedance Ltd. and/or its affiliates | |
| Licensed under the Apache License, Version 2.0 (the "License"); | |
| you may not use this file except in compliance with the License. | |
| You may obtain a copy of the License at | |
| http://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software | |
| distributed under the License is distributed on an "AS IS" BASIS, | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| See the License for the specific language governing permissions and | |
| limitations under the License. | |
| """ | |
import json
from pathlib import Path

import torch
import torch.nn as nn
from einops import rearrange
from omegaconf import OmegaConf
from huggingface_hub import PyTorchModelHubMixin

from modeling.modules.base_model import BaseModel
from modeling.modules.blocks import TiTokEncoder, TiTokDecoder
from modeling.quantizer.quantizer import VectorQuantizer, DiagonalGaussianDistribution
from modeling.modules.maskgit_vqgan import Encoder as Pixel_Encoder
from modeling.modules.maskgit_vqgan import Decoder as Pixel_Decoder
from modeling.modules.maskgit_vqgan import VectorQuantizer as Pixel_Quantizer


class PretrainedTokenizer(nn.Module):
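    """Frozen MaskGIT-VQGAN pixel tokenizer (encoder, quantizer, decoder).

    Loads pretrained weights and freezes every parameter. `encode` maps images
    to discrete codebook indices; `decode` maps indices back to images in [0, 1].
    """
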
    def __init__(self, pretrained_weight):
        super().__init__()
        conf = OmegaConf.create(
            {"channel_mult": [1, 1, 2, 2, 4],
             "num_resolutions": 5,
             "dropout": 0.0,
             "hidden_channels": 128,
             "num_channels": 3,
             "num_res_blocks": 2,
             "resolution": 256,
             "z_channels": 256})
        self.encoder = Pixel_Encoder(conf)
        self.decoder = Pixel_Decoder(conf)
        self.quantize = Pixel_Quantizer(
            num_embeddings=1024, embedding_dim=256, commitment_cost=0.25)
        # Load pretrained weights and freeze the whole tokenizer.
        self.load_state_dict(torch.load(pretrained_weight, map_location=torch.device("cpu")), strict=True)
        self.eval()
        for param in self.parameters():
            param.requires_grad = False

    def encode(self, x):
        # x: (B, 3, H, W) in [0, 1] -> discrete codebook indices on a spatial
        # grid (e.g. 16x16 for 256x256 inputs, given 4 downsampling stages).
        hidden_states = self.encoder(x)
        quantized_states, codebook_indices, codebook_loss = self.quantize(hidden_states)
        return codebook_indices.detach()

    def decode(self, codes):
        # Indices -> codebook embeddings -> reconstructed images, clamped to [0, 1].
        quantized_states = self.quantize.get_codebook_entry(codes)
        rec_images = self.decoder(quantized_states)
        rec_images = torch.clamp(rec_images, 0.0, 1.0)
        return rec_images.detach()

    def decode_tokens(self, codes):
        return self.decode(codes)
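

# Illustrative sketch (not part of the module): round-tripping an image through
# the frozen pixel tokenizer. The checkpoint path and shapes are assumptions.
#
#   tokenizer = PretrainedTokenizer("maskgit-vqgan-imagenet-f16-256.bin")
#   images = torch.rand(2, 3, 256, 256)   # RGB batch in [0, 1]
#   codes = tokenizer.encode(images)      # discrete indices, e.g. (2, 16, 16)
#   recon = tokenizer.decode(codes)       # (2, 3, 256, 256), clamped to [0, 1]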


class TiTok(BaseModel, PyTorchModelHubMixin,
            tags=["arxiv:2406.07550", "image-tokenization"],
            repo_url="https://github.com/bytedance/1d-tokenizer",
            license="apache-2.0"):
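    """TiTok: tokenizes images into a compact 1D sequence of latent tokens
    (arXiv:2406.07550), with either a VQ or a VAE bottleneck.

    Stage 1 trains the full tokenizer; stage 2 (`finetune_decoder=True`) freezes
    the encoder/quantizer/latent tokens and finetunes the decoder against a
    frozen MaskGIT-VQGAN pixel decoder.
    """
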
    def __init__(self, config):
        if isinstance(config, dict):
            config = OmegaConf.create(config)
        super().__init__()
        self.config = config
        # This should be False for stage 1 and True for stage 2.
        self.finetune_decoder = config.model.vq_model.get("finetune_decoder", True)
        self.quantize_mode = config.model.vq_model.get("quantize_mode", "vq")
        if self.quantize_mode not in ["vq", "vae"]:
            raise ValueError(f"Unsupported quantize mode {self.quantize_mode}.")
        if self.finetune_decoder and self.quantize_mode not in ["vq"]:
            raise ValueError("Only support finetune_decoder with vq quantization for now.")
        self.encoder = TiTokEncoder(config)
        self.decoder = TiTokDecoder(config)
        self.num_latent_tokens = config.model.vq_model.num_latent_tokens
        scale = self.encoder.width ** -0.5
        self.latent_tokens = nn.Parameter(
            scale * torch.randn(self.num_latent_tokens, self.encoder.width))
        self.apply(self._init_weights)

        if self.quantize_mode == "vq":
            self.quantize = VectorQuantizer(
                codebook_size=config.model.vq_model.codebook_size,
                token_size=config.model.vq_model.token_size,
                commitment_cost=config.model.vq_model.commitment_cost,
                use_l2_norm=config.model.vq_model.use_l2_norm,)
        elif self.quantize_mode == "vae":
            self.quantize = DiagonalGaussianDistribution
        else:
            raise NotImplementedError

        if self.finetune_decoder:
            # Freeze encoder/quantizer/latent tokens.
            self.latent_tokens.requires_grad_(False)
            self.encoder.eval()
            self.encoder.requires_grad_(False)
            self.quantize.eval()
            self.quantize.requires_grad_(False)
            # Include MaskGIT-VQGAN's quantizer and decoder.
            self.pixel_quantize = Pixel_Quantizer(
                num_embeddings=1024, embedding_dim=256, commitment_cost=0.25)
            self.pixel_decoder = Pixel_Decoder(OmegaConf.create(
                {"channel_mult": [1, 1, 2, 2, 4],
                 "num_resolutions": 5,
                 "dropout": 0.0,
                 "hidden_channels": 128,
                 "num_channels": 3,
                 "num_res_blocks": 2,
                 "resolution": 256,
                 "z_channels": 256}))

    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights and config to a local directory."""
        # Serialize the OmegaConf config to a plain dict and dump it as JSON.
        dict_config = OmegaConf.to_container(self.config)
        file_path = Path(save_directory) / "config.json"
        with open(file_path, 'w') as json_file:
            json.dump(dict_config, json_file, indent=4)
        super()._save_pretrained(save_directory)
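
    # Loading side (sketch): `PyTorchModelHubMixin` provides the inverse, e.g.
    #   titok = TiTok.from_pretrained("yucornetto/tokenizer_titok_l32_imagenet")
    # The Hub repo name above is illustrative, not guaranteed by this file.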

    def _init_weights(self, module):
        """Initialize the weights.

        Args:
            module: torch.nn.Module to initialize.
        """
        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
            module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def encode(self, x):
        if self.finetune_decoder:
            # Stage 2: the tokenizer is frozen, so run it without gradients and
            # zero out the quantizer losses (keys are kept for the training loop).
            with torch.no_grad():
                self.encoder.eval()
                self.quantize.eval()
                z = self.encoder(pixel_values=x, latent_tokens=self.latent_tokens)
                z_quantized, result_dict = self.quantize(z)
                result_dict["quantizer_loss"] *= 0
                result_dict["commitment_loss"] *= 0
                result_dict["codebook_loss"] *= 0
        else:
            z = self.encoder(pixel_values=x, latent_tokens=self.latent_tokens)
            if self.quantize_mode == "vq":
                z_quantized, result_dict = self.quantize(z)
            elif self.quantize_mode == "vae":
                posteriors = self.quantize(z)
                z_quantized = posteriors.sample()
                result_dict = posteriors
        # z_quantized: (B, token_size, 1, num_latent_tokens); see decode_tokens.
        return z_quantized, result_dict

    def decode(self, z_quantized):
        decoded = self.decoder(z_quantized)
        if self.finetune_decoder:
            # Stage 2: the TiTok decoder predicts logits over the frozen
            # MaskGIT-VQGAN codebook; a softmax-weighted sum of codebook
            # embeddings is then fed to the pixel decoder.
            quantized_states = torch.einsum(
                'nchw,cd->ndhw', decoded.softmax(1),
                self.pixel_quantize.embedding.weight)
            decoded = self.pixel_decoder(quantized_states)
        return decoded

    def decode_tokens(self, tokens):
        if self.quantize_mode == "vq":
            tokens = tokens.squeeze(1)
            batch, seq_len = tokens.shape  # B x N
            z_quantized = self.quantize.get_codebook_entry(
                tokens.reshape(-1)).reshape(batch, 1, seq_len, -1)
            z_quantized = rearrange(z_quantized, 'b h w c -> b c h w').contiguous()
        elif self.quantize_mode == "vae":
            z_quantized = tokens
        decoded = self.decode(z_quantized)
        return decoded
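
    # Sketch (vq mode): decoding discrete tokens directly; names and shapes are
    # assumptions for illustration.
    #   tokens = torch.randint(0, codebook_size, (1, 1, num_latent_tokens))  # (B, 1, N)
    #   images = titok.decode_tokens(tokens)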

    def forward(self, x):
        z_quantized, result_dict = self.encode(x)
        decoded = self.decode(z_quantized)
        return decoded, result_dict
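

# ------------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module). The config below
# only lists the keys referenced in this file; TiTokEncoder/TiTokDecoder read
# additional keys from the repo's yaml configs, so treat this as a sketch of the
# intended call pattern rather than a runnable script. All values are assumed.
#
#   config = OmegaConf.create({"model": {"vq_model": {
#       "quantize_mode": "vq",
#       "finetune_decoder": False,   # stage-1 setting
#       "codebook_size": 4096,       # assumed
#       "token_size": 12,            # assumed
#       "commitment_cost": 0.25,     # assumed
#       "use_l2_norm": True,         # assumed
#       "num_latent_tokens": 32,     # e.g. a 32-token TiTok variant
#   }}})
#   titok = TiTok(config)
#   images = torch.rand(1, 3, 256, 256)
#   decoded, result_dict = titok(images)  # reconstruction and quantizer losses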