#!/usr/bin/env python
'''
    This file trains Stable Video Diffusion with my personal implementation, based on the diffusers training example code.
'''
import argparse
import logging
import math
import os, sys
import time
import random
import shutil
import warnings
import cv2
from PIL import Image
from einops import rearrange, repeat
from pathlib import Path
from omegaconf import OmegaConf
import imageio
import accelerate
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import RandomSampler
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from packaging import version
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
import diffusers
from diffusers import (
    AutoencoderKLTemporalDecoder,
    DDPMScheduler,
)
from diffusers.training_utils import EMAModel, compute_snr
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available, load_image, export_to_video
from diffusers.utils.import_utils import is_xformers_available
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils.torch_utils import randn_tensor
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

if is_wandb_available():
    import wandb

# Import files from the local folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from svd.pipeline_stable_video_diffusion import StableVideoDiffusionPipeline
from svd.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from data_loader.video_dataset import Video_Dataset, get_video_frames, tokenize_captions
from utils.img_utils import resize_with_antialiasing

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
# check_min_version("0.25.0.dev0")

logger = get_logger(__name__)
warnings.filterwarnings('ignore')

###################################################################################################################################################
def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Simple example of a Stable Video Diffusion training script.")
    parser.add_argument(
        "--config_path",
        type=str,
        default="config/train_image2video.yaml",
        help="Path to the training config file.",
    )
    args = parser.parse_args()

    return args
def log_validation(vae, unet, image_encoder, text_encoder, tokenizer, config, accelerator, weight_dtype, step,
                   parent_store_folder = None, force_close_flip = False, use_ambiguous_prompt = False):
    # This function will also be used in other files
    print("Running validation... ")

    # Init
    validation_source_folder = config["validation_img_folder"]

    # Init the pipeline
    pipeline = StableVideoDiffusionPipeline.from_pretrained(
        config["pretrained_model_name_or_path"],
        vae = accelerator.unwrap_model(vae),
        image_encoder = accelerator.unwrap_model(image_encoder),
        unet = accelerator.unwrap_model(unet),
        revision = None,        # Set None directly now
        torch_dtype = weight_dtype,
    )
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # Process all images in the folder
    frames_collection = []
    for image_name in sorted(os.listdir(validation_source_folder)):
        if accelerator.is_main_process:
            if parent_store_folder is None:
                validation_store_folder = os.path.join(config["validation_store_folder"] + "_" + config["scheduler"], "step_" + str(step), image_name)
            else:
                validation_store_folder = os.path.join(parent_store_folder, image_name)

            if os.path.exists(validation_store_folder):
                shutil.rmtree(validation_store_folder)
            os.makedirs(validation_store_folder)

        image_path = os.path.join(validation_source_folder, image_name, 'im_0.jpg')
        ref_image = load_image(image_path)
        ref_image = ref_image.resize((config["width"], config["height"]))

        # Decide the motion score in SVD (we mostly use a fixed value now)
        if config["motion_bucket_id"] is None:
            raise NotImplementedError("We need a fixed motion_bucket_id in the config")
        else:
            reflected_motion_bucket_id = config["motion_bucket_id"]
        print("Inference Motion Bucket ID is ", reflected_motion_bucket_id)

        # Prepare the text prompt
        if config["use_text"]:
            # Read the prompt file
            file_path = os.path.join(validation_source_folder, image_name, "lang.txt")
            file = open(file_path, 'r')
            prompt = file.readlines()[0]        # Only read the first line
            if use_ambiguous_prompt:
                prompt = prompt.split(" ")[0] + " this to there"
                print("We are creating ambiguous prompt, which is: ", prompt)
        else:
            prompt = ""
        # Use the same tokenization process as in the dataset preparation stage
        tokenized_prompt = tokenize_captions(prompt, tokenizer, config, is_train=False).unsqueeze(0).to(accelerator.device)     # Use unsqueeze to expand dim

        # Store the prompt for the sanity check
        f = open(os.path.join(validation_store_folder, "lang_cond.txt"), "a")
        f.write(prompt)
        f.close()

        # Flip the image by chance (we need to check whether there are any object position words [left|right] in the prompt text)
        flip = False
        if not force_close_flip:        # force_close_flip is True at test time; otherwise, we cannot compare under the same standard
            if random.random() < config["flip_aug_prob"]:
                if config["use_text"]:
                    if prompt.find("left") == -1 and prompt.find("right") == -1:    # Must not contain position words like left/right (up and down are ok)
                        flip = True
                else:
                    flip = True
            if flip:
                print("Use flip in validation!")
                ref_image = ref_image.transpose(Image.FLIP_LEFT_RIGHT)

        # Call the model for inference
        with torch.autocast("cuda"):
            frames = pipeline(
                ref_image,
                tokenized_prompt,
                config["use_text"],
                text_encoder,
                height = config["height"],
                width = config["width"],
                num_frames = config["video_seq_length"],
                num_inference_steps = config["num_inference_steps"],
                decode_chunk_size = 8,
                motion_bucket_id = reflected_motion_bucket_id,
                fps = 7,
                noise_aug_strength = config["inference_noise_aug_strength"],
            ).frames[0]

        # Store the frames
        # breakpoint()
        for idx, frame in enumerate(frames):
            frame.save(os.path.join(validation_store_folder, str(idx) + ".png"))
        imageio.mimsave(os.path.join(validation_store_folder, 'combined.gif'), frames)      # GIF storage quality is not high; we recommend checking the PNG images

        frames_collection.append(frames)

    # Cleaning process
    del pipeline
    torch.cuda.empty_cache()

    return frames_collection    # Return results based on the need
def tensor_to_vae_latent(inputs, vae):
    video_length = inputs.shape[1]

    inputs = rearrange(inputs, "b f c h w -> (b f) c h w")
    latents = vae.encode(inputs).latent_dist.mode()
    latents = rearrange(latents, "(b f) c h w -> b f c h w", f=video_length)   # Using f or b to rearrange should have the same effect
    latents = latents * vae.config.scaling_factor

    return latents
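# Shape sketch (added for clarity; an illustration, not part of the training path): assuming the dataset
# yields video_frames of shape (B, F, 3, H, W) in [-1, 1], the SVD VAE downsamples spatially by 8x with
# 4 latent channels, so tensor_to_vae_latent returns latents of shape (B, F, 4, H/8, W/8), already
# multiplied by the VAE scaling factor.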
def import_pretrained_text_encoder(pretrained_model_name_or_path: str, revision: str):
    ''' Import the text encoder class
    '''
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel
        return CLIPTextModel
    else:   # No other cases will be considered
        raise ValueError(f"{model_class} is not supported.")
def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
    """Draws samples from a lognormal distribution."""
    u = torch.rand(shape, dtype=dtype, device=device) * (1 - 2e-7) + 1e-7
    return torch.distributions.Normal(loc, scale).icdf(u).exp()
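# Note (added for clarity): the returned noise levels satisfy log(sigma) ~ N(loc, scale), i.e. an
# EDM-style log-normal sigma sampler implemented via the inverse CDF of a uniform draw. For example,
# loc=0.0 and scale=1.0 would give a median sigma of exp(0) = 1.0; the actual values used in training
# come from config["noise_mean"] and config["noise_std"] below.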
def get_add_time_ids(
    unet_config,
    expected_add_embed_dim,
    fps,
    motion_bucket_id,
    noise_aug_strength,
    dtype,
    batch_size,
    num_videos_per_prompt,
):
    # Construct the basic add_time_ids items
    add_time_ids = [fps, motion_bucket_id, noise_aug_strength]

    passed_add_embed_dim = unet_config.addition_time_embed_dim * len(add_time_ids)

    if expected_add_embed_dim != passed_add_embed_dim:
        raise ValueError(
            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
        )

    add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
    add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)

    return add_time_ids
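# Usage sketch (added for clarity; illustrative values): assuming batch_size=1 and num_videos_per_prompt=1,
# fps=7, motion_bucket_id=127 and noise_aug_strength=0.02 produce add_time_ids == tensor([[7., 127., 0.02]])
# of shape (1, 3). Each of the three scalars is later embedded with unet_config.addition_time_embed_dim
# channels, which is why the check above expects addition_time_embed_dim * 3 == expected_add_embed_dim.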
####################################################################################################################################################

def main(config):
    # Read the config settings
    resume_from_checkpoint = config["resume_from_checkpoint"]
    output_dir = config["output_dir"]
    logging_name = config["logging_name"]
    mixed_precision = config["mixed_precision"]
    report_to = config["report_to"]
    pretrained_model_name_or_path = config["pretrained_model_name_or_path"]
    pretrained_tokenizer_name_or_path = config["pretrained_tokenizer_name_or_path"]
    gradient_checkpointing = config["gradient_checkpointing"]
    learning_rate = config["learning_rate"]
    adam_beta1 = config["adam_beta1"]
    adam_beta2 = config["adam_beta2"]
    adam_weight_decay = config["adam_weight_decay"]
    adam_epsilon = config["adam_epsilon"]
    train_batch_size = config["train_batch_size"]
    dataloader_num_workers = config["dataloader_num_workers"]
    gradient_accumulation_steps = config["gradient_accumulation_steps"]
    num_train_iters = config["num_train_iters"]
    lr_warmup_steps = config["lr_warmup_steps"]
    checkpointing_steps = config["checkpointing_steps"]
    process_fps = config["process_fps"]
    train_noise_aug_strength = config["train_noise_aug_strength"]
    use_8bit_adam = config["use_8bit_adam"]
    scale_lr = config["scale_lr"]
    conditioning_dropout_prob = config["conditioning_dropout_prob"]
    checkpoints_total_limit = config["checkpoints_total_limit"]
    validation_step = config["validation_step"]
    partial_finetune = config['partial_finetune']

    # Default settings
    revision = None
    variant = "fp16"
    lr_scheduler = "constant"
    max_grad_norm = 1.0
    tracker_project_name = "img2video"
    num_videos_per_prompt = 1
    seed = 42
    # No CFG in training now

    # Define the accelerator
    logging_dir = Path(output_dir, logging_name)
    accelerator_project_config = ProjectConfiguration(project_dir=output_dir, logging_dir=logging_dir)
    accelerator = Accelerator(
        gradient_accumulation_steps = gradient_accumulation_steps,
        mixed_precision = mixed_precision,
        log_with = report_to,
        project_config = accelerator_project_config,
    )
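    # Note (added for clarity): with gradient accumulation handled by Accelerate, the effective batch size
    # is train_batch_size * accelerator.num_processes * gradient_accumulation_steps.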
    generator = torch.Generator(device=accelerator.device).manual_seed(seed)

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # Handle the repository creation
    if accelerator.is_main_process and resume_from_checkpoint != "latest":      # For the latest checkpoint version, we don't need to delete our folders
        # Validation folder
        validation_store_folder = config["validation_store_folder"] + "_" + config["scheduler"]
        print("We will remove ", validation_store_folder)
        if os.path.exists(validation_store_folder):
            archive_name = validation_store_folder + "_archive"
            if os.path.exists(archive_name):
                shutil.rmtree(archive_name)
            print("We will move to archive ", archive_name)
            os.rename(validation_store_folder, archive_name)
        os.makedirs(validation_store_folder)

        # Output dir
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        # os.makedirs(output_dir, exist_ok=True)

        # Log
        if os.path.exists("runs"):
            shutil.rmtree("runs")

        # Copy the config here for reference
        os.system(" cp config/train_image2video.yaml " + validation_store_folder + "/")
    # Load all modules needed
    feature_extractor = CLIPImageProcessor.from_pretrained(
        pretrained_model_name_or_path, subfolder="feature_extractor", revision=revision
    )   # This instance has no weights; it is just a preprocessing settings file
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        pretrained_model_name_or_path, subfolder="image_encoder", revision=revision, variant=variant
    )
    vae = AutoencoderKLTemporalDecoder.from_pretrained(
        pretrained_model_name_or_path, subfolder="vae", revision=revision, variant=variant
    )
    if config["load_unet_path"] != None:
        print("We will load UNet from ", config["load_unet_path"])
        unet = UNetSpatioTemporalConditionModel.from_pretrained(
            config["load_unet_path"],
            subfolder = "unet",
            low_cpu_mem_usage = True,
        )   # For this checkpoint, we don't have an fp16 variant, so we will read the fp32 weights
    else:
        print("We will only use SVD pretrained UNet")
        unet = UNetSpatioTemporalConditionModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder = "unet",
            low_cpu_mem_usage = True,
            variant = variant,
        )

    # Prepare the tokenizer if we use text
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_tokenizer_name_or_path,
        subfolder = "tokenizer",
        revision = revision,
        use_fast = False,
    )
    if config["use_text"]:
        # CLIP text encoder
        text_encoder_cls = import_pretrained_text_encoder(pretrained_tokenizer_name_or_path, revision)
        text_encoder = text_encoder_cls.from_pretrained(
            pretrained_tokenizer_name_or_path, subfolder = "text_encoder", revision = revision, variant = variant
        )
    else:
        text_encoder = None
    # Store the config because it is no longer accessible after accelerator.prepare (written to handle an unexplained phenomenon)
    unet_config = unet.config
    expected_add_embed_dim = unet.add_embedding.linear_1.in_features

    # Freeze vae + feature_extractor + image_encoder, but set unet to trainable
    vae.requires_grad_(False)
    image_encoder.requires_grad_(False)
    unet.requires_grad_(False)      # Will switch back to train mode later on
    if config["use_text"]:
        text_encoder.requires_grad_(False)      # Set with no grad needed (like the VAE), following other T2I papers

    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision,
    # as these weights are only used for inference; keeping them in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move vae + image_encoder to the gpu and cast to weight_dtype
    vae.to(accelerator.device, dtype=weight_dtype)
    image_encoder.to(accelerator.device, dtype=weight_dtype)
    # unet.to(accelerator.device, dtype=weight_dtype)
    if config["use_text"]:
        text_encoder.to(accelerator.device, dtype=weight_dtype)
    # Acceleration: `accelerate` 0.16.0 will have better support for customized saving
    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
            if accelerator.is_main_process:
                for i, model in enumerate(models):
                    model.save_pretrained(os.path.join(output_dir, "unet"))

                    # make sure to pop weight so that corresponding model is not saved again
                    weights.pop()

        def load_model_hook(models, input_dir):
            for i in range(len(models)):
                # pop models so that they are not loaded again
                model = models.pop()

                # load diffusers style weights into model
                load_model = UNetSpatioTemporalConditionModel.from_pretrained(input_dir, subfolder="unet")
                model.register_to_config(**load_model.config)
                model.load_state_dict(load_model.state_dict())
                del load_model

        accelerator.register_save_state_pre_hook(save_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)

    if gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    ################################ Make Training dataset ###############################
    train_dataset = Video_Dataset(config, device = accelerator.device, tokenizer = tokenizer)
    sampler = RandomSampler(train_dataset)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        sampler = sampler,
        batch_size = train_batch_size,
        num_workers = dataloader_num_workers * accelerator.num_processes,
    )
    #######################################################################################
    ####################################### Optimizer Setting #####################################################################
    if scale_lr:
        learning_rate = (
            learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
        )

    # 8-bit Adam to save more memory (usually we need this to save memory)
    if use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )
        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    # Switch the unet back to training mode
    unet.requires_grad_(True)

    ############################## For partial fine-tune setting ##############################
    parameters_list = []
    for name, param in unet.named_parameters():
        if partial_finetune:    # Our partial finetune only trains the attention layers, roughly ~190M params (TODO: check the exact value later)
            # Full Spatial: .transformer_blocks. && spatial_
            # Attn + All emb: attn && emb
            if name.find("attn") != -1 or name.find("emb") != -1:      # Only keep the attention and embedding layers trainable
                parameters_list.append(param)
                param.requires_grad = True
            else:
                param.requires_grad = False
        else:
            parameters_list.append(param)
            param.requires_grad = True

    # Double check what will be trained
    total_params_for_training = 0
    # if os.path.exists("param_lists.txt"):
    #     os.remove("param_lists.txt")
    # file1 = open("param_lists.txt", "a")
    for name, param in unet.named_parameters():
        # file1.write(name + "\n")
        if param.requires_grad:
            total_params_for_training += param.numel()
            print(name + " requires grad update")
    print("The total number of parameters that will be trained is ", total_params_for_training)
    ##########################################################################################

    # Optimizer creation
    optimizer = optimizer_cls(
        parameters_list,
        lr = learning_rate,
        betas = (adam_beta1, adam_beta2),
        weight_decay = adam_weight_decay,
        eps = adam_epsilon,
    )

    # Scheduler and training steps
    dataset_length = len(train_dataset)
    print("Dataset length read from the train side is ", dataset_length)
    num_update_steps_per_epoch = math.ceil(dataset_length / gradient_accumulation_steps)
    max_train_steps = num_train_iters * train_batch_size

    # Learning rate scheduler (we always use constant)
    lr_scheduler = get_scheduler(
        "constant",
        optimizer = optimizer,
        num_warmup_steps = lr_warmup_steps * accelerator.num_processes,
        num_training_steps = max_train_steps * accelerator.num_processes,
        num_cycles = 1,
        power = 1.0,
    )
    #####################################################################################################################################

    # Prepare everything with our `accelerator`.
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )

    # We need to RECALCULATE our total training steps as the size of the training dataloader may have changed.
    print("accelerator.num_processes is ", accelerator.num_processes)
    print("num_train_iters is ", num_train_iters)
    num_train_epochs = math.ceil(num_train_iters * accelerator.num_processes * gradient_accumulation_steps / dataset_length)
    print("num_train_epochs is ", num_train_epochs)

    # We need to initialize the trackers we use, and also store our configuration.
    if accelerator.is_main_process:     # Only on the main process!
        tracker_config = dict(vars(args))       # Note: `args` is the module-level Namespace created under __main__
        accelerator.init_trackers(tracker_project_name, tracker_config)

    # Train!
    logger.info("***** Running training *****")
    logger.info(f" Dataset Length = {dataset_length}")
    logger.info(f" Num Epochs = {num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {train_batch_size}")
    logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {max_train_steps}")

    # Load the closest / best weights
    global_step = 0     # Tracks the current iteration
    first_epoch = 0
    if resume_from_checkpoint:
        if resume_from_checkpoint != "latest":
            path = os.path.basename(resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None
            print("We will resume the latest weight ", path)

        if path is None:    # Don't resume
            accelerator.print(
                f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            resume_from_checkpoint = None
            initial_global_step = 0
        else:   # Resume from the closest checkpoint
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(output_dir, path))
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch
    else:
        initial_global_step = 0

    if accelerator.is_main_process:
        print("Initial Learning rate is ", optimizer.param_groups[0]['lr'])
        print("global_step will start from ", global_step)

    progress_bar = tqdm(
        range(initial_global_step, max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    # Prepare the tensorboard log
    writer = SummaryWriter()
    ######################################################### Auxiliary Function #################################################################
    def encode_clip(pixel_values, prompt):
        ''' Build the encoder hidden states
            pixel_values: first frame pixel information
            prompt: tokenized language prompt
        '''
        ########################################## Prepare the Image Embedding #####################################################
        # pixel_values is in the range [-1, 1]
        pixel_values = resize_with_antialiasing(pixel_values, (224, 224))
        pixel_values = (pixel_values + 1.0) / 2.0       # [-1, 1] -> [0, 1]

        # Normalize the image for CLIP input
        pixel_values = feature_extractor(
            images=pixel_values,
            do_normalize=True,
            do_center_crop=False,
            do_resize=False,
            do_rescale=False,
            return_tensors="pt",
        ).pixel_values

        # The following is the same as _encode_image in the SVD pipeline
        pixel_values = pixel_values.to(device=accelerator.device, dtype=weight_dtype)
        image_embeddings = image_encoder(pixel_values).image_embeds
        image_embeddings = image_embeddings.unsqueeze(1)

        # Duplicate image embeddings for each generation per prompt, using an mps-friendly method
        bs_embed, seq_len, _ = image_embeddings.shape
        image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
        encoder_hidden_states = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
        #############################################################################################################################

        ########################################## Prepare the Text embedding if needed #############################################
        if config["use_text"]:
            text_embeddings = text_encoder(prompt)[0]

            # Concat the two embeddings together along dim 1
            encoder_hidden_states = torch.cat((text_embeddings, encoder_hidden_states), dim=1)
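            # Note (added for clarity): with the standard CLIP tokenizer this should be 77 text tokens plus
            # 1 image embedding token, i.e. a sequence length of 78, which is presumably why the LayerNorm
            # below is hard-coded to (78, 1024). If the tokenizer max length differs, this shape would too.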
            # Layer norm over the last two dims
            layer_norm = nn.LayerNorm((78, 1024)).to(device=accelerator.device, dtype=weight_dtype)
            encoder_hidden_states_norm = layer_norm(encoder_hidden_states)

            # Return
            return encoder_hidden_states_norm
        else:   # Just return the default (image-only) hidden states
            return encoder_hidden_states
        #############################################################################################################################
    ####################################################################################################################################

    ############################################################################################################################
    # For the training loop, we mimic the T2I training code in diffusers
    for epoch in range(first_epoch, num_train_epochs):
        unet.train()
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                # batch is a torch tensor in the range [-1, 1] with no other pre-processing
                video_frames = batch["video_frames"].to(weight_dtype).to(accelerator.device, non_blocking=True)
                reflected_motion_bucket_id = batch["reflected_motion_bucket_id"]
                prompt = batch["prompt"]

                # Images to VAE latent space
                latents = tensor_to_vae_latent(video_frames, vae)

                ##################################### Add Noise ########################################
                bsz, num_frames = latents.shape[:2]

                # Encode the first frame
                conditional_pixel_values = video_frames[:, 0, :, :, :]     # First frame
                # Following AnimateSomething, we use a constant to replace cond_sigmas
                conditional_pixel_values = conditional_pixel_values + torch.randn_like(conditional_pixel_values) * train_noise_aug_strength
                conditional_latents = vae.encode(conditional_pixel_values).latent_dist.mode()      # mode() returns the mean, with no std influence
                conditional_latents = repeat(conditional_latents, 'b c h w -> b f c h w', f=num_frames)    # Copied across the frame axis to match the noise shape

                # Add noise to the latents according to the noise magnitude at each timestep
                # This is the forward diffusion process
                sigmas = rand_log_normal(shape=[bsz,], loc=config["noise_mean"], scale=config["noise_std"]).to(latents.device)
                sigmas = sigmas[:, None, None, None, None]
                noisy_latents = latents + torch.randn_like(latents) * sigmas
                inp_noisy_latents = noisy_latents / ((sigmas**2 + 1) ** 0.5)
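                # Note (added for clarity): the division above is the EDM-style input scaling
                # c_in = 1 / sqrt(sigma^2 + 1) (i.e. sigma_data = 1), which keeps the UNet input
                # at roughly unit variance across sampled noise levels.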
                # Encoder hidden states built from the first frame and the prompt
                encoder_hidden_states = encode_clip(video_frames[:, 0, :, :, :].float(), prompt)    # First Frame + Text Prompt

                # Conditioning dropout to support classifier-free guidance during inference. For more details
                # check out section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800 (InstructPix2Pix).
                if conditioning_dropout_prob != 0:
                    random_p = torch.rand(bsz, device=latents.device, generator=generator)

                    # Sample masks for the edit prompts.
                    prompt_mask = random_p < 2 * conditioning_dropout_prob
                    prompt_mask = prompt_mask.reshape(bsz, 1, 1)

                    # Final text conditioning.
                    null_conditioning = torch.zeros_like(encoder_hidden_states)
                    encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states)

                    # Sample masks for the original images.
                    image_mask_dtype = conditional_latents.dtype
                    image_mask = 1 - ((random_p >= conditioning_dropout_prob).to(image_mask_dtype) * (random_p < 3 * conditioning_dropout_prob).to(image_mask_dtype))
                    image_mask = image_mask.reshape(bsz, 1, 1, 1)

                    # Final image conditioning.
                    conditional_latents = image_mask * conditional_latents
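                    # Note (added for clarity): with p = conditioning_dropout_prob, this scheme drops only the
                    # text with probability p (random_p < p), both text and image with probability p
                    # (p <= random_p < 2p), and only the image with probability p (2p <= random_p < 3p);
                    # both conditions are kept with probability 1 - 3p.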
                # Concatenate the `conditional_latents` with the `noisy_latents`.
                inp_noisy_latents = torch.cat([inp_noisy_latents, conditional_latents], dim=2)

                # Ground-truth target (the clean latents)
                target = latents
                ##########################################################################################

                ################################ Other Embedding and Conditioning ###################################
                reflected_motion_bucket_id = torch.sum(reflected_motion_bucket_id) / len(reflected_motion_bucket_id)
                reflected_motion_bucket_id = int(reflected_motion_bucket_id.cpu().detach().numpy())
                # print("Training reflected_motion_bucket_id is ", reflected_motion_bucket_id)

                added_time_ids = get_add_time_ids(
                    unet_config,
                    expected_add_embed_dim,
                    process_fps,
                    reflected_motion_bucket_id,
                    train_noise_aug_strength,
                    weight_dtype,
                    train_batch_size,
                    num_videos_per_prompt,
                )   # The same as the SVD pipeline's _get_add_time_ids
                added_time_ids = added_time_ids.to(accelerator.device)

                timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(accelerator.device)
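                # Note (added for clarity): the timestep conditioning follows the EDM convention
                # c_noise = 0.25 * ln(sigma), which is how the SVD pipeline conditions the UNet on
                # the noise level at inference time.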
                #####################################################################################################

                ###################################### Predict Noise ######################################
                model_pred = unet(
                    inp_noisy_latents,
                    timesteps,
                    encoder_hidden_states,
                    added_time_ids = added_time_ids,
                ).sample

                # Denoise the latents
                c_out = -sigmas / ((sigmas**2 + 1) ** 0.5)
                c_skip = 1 / (sigmas**2 + 1)
                denoised_latents = model_pred * c_out + c_skip * noisy_latents
                weighing = (1 + sigmas ** 2) * (sigmas ** -2.0)
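                # Note (added for clarity): c_skip and c_out follow EDM-style preconditioning with
                # sigma_data = 1 (the sign of c_out is a convention choice), so the "denoised" prediction is
                # x_hat = c_skip * x_noisy + c_out * F(x), and `weighing` = (sigma^2 + 1) / sigma^2 is the
                # corresponding lambda(sigma) loss weight.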
                ##########################################################################################

                ############################### Calculate Loss and Update Optimizer #######################
                # MSE loss
                loss = torch.mean(
                    (weighing.float() * (denoised_latents.float() - target.float()) ** 2).reshape(target.shape[0], -1),
                    dim=1,
                )
                loss = loss.mean()

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean()
                train_loss += avg_loss.item() / gradient_accumulation_steps

                # Update Tensorboard
                writer.add_scalar('Loss/train-Loss-Step', avg_loss, global_step)

                # Backpropagate
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                ##########################################################################################

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0
                ########################################## Checkpoints #########################################
                if global_step != 0 and global_step % checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        start = time.time()
                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                        if checkpoints_total_limit is not None:
                            checkpoints = os.listdir(output_dir)
                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                            if len(checkpoints) >= checkpoints_total_limit:
                                num_to_remove = len(checkpoints) - checkpoints_total_limit + 1
                                removing_checkpoints = checkpoints[0:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = os.path.join(output_dir, removing_checkpoint)
                                    shutil.rmtree(removing_checkpoint)

                        save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")
                        print("Save time use " + str(time.time() - start) + " s")
                ########################################################################################################
            # Update the log
            logs = {"step_loss": loss.detach().item(), "lr": optimizer.param_groups[0]['lr']}
            progress_bar.set_postfix(**logs)

            ##################################### Validation every `validation_step` iterations #######################################
            if accelerator.is_main_process:
                if global_step % validation_step == 0:      # Validate every `validation_step` steps
                    if config["validation_img_folder"] is not None:
                        log_validation(
                            vae,
                            unet,
                            image_encoder,
                            text_encoder,
                            tokenizer,
                            config,
                            accelerator,
                            weight_dtype,
                            global_step,
                            use_ambiguous_prompt = config["mix_ambiguous"],
                        )
            ###############################################################################################################

            # Update the step counter and break if needed
            global_step += 1
            if global_step >= max_train_steps:
                break
    ############################################################################################################################

if __name__ == "__main__":
    args = parse_args()
    config = OmegaConf.load(args.config_path)
    main(config)
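# Example launch (a sketch; the actual script filename and config path depend on your setup):
#   accelerate launch train_image2video.py --config_path config/train_image2video.yaml
# Single-GPU debugging alternative:
#   python train_image2video.py --config_path config/train_image2video.yaml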