"""Training entry point for MuseTalk (stage 2 by default; see
configs/training/stage2.yaml). Trains the audio-conditioned inpainting UNet
with L1, VGG perceptual, GAN, mouth-region GAN, feature-matching, and sync
losses under HuggingFace Accelerate."""
import argparse
import logging
import math
import os
import random
import time
import warnings
from datetime import datetime, timedelta

import diffusers
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, LoggerType
from diffusers.utils import check_min_version
from einops import rearrange
from omegaconf import OmegaConf
from tqdm.auto import tqdm

from musetalk.loss.basic_loss import set_requires_grad
from musetalk.loss.syncnet import get_sync_loss
from musetalk.utils.training_utils import (
    initialize_dataloaders,
    initialize_loss_functions,
    initialize_models_and_optimizers,
    initialize_syncnet,
    initialize_vgg,
    validation,
)
from musetalk.utils.utils import (
    delete_additional_ckpt,
    get_mouth_region,
    process_audio_features,
    save_models,
    seed_everything,
)

logger = get_logger(__name__, log_level="INFO")
warnings.filterwarnings("ignore")
check_min_version("0.10.0.dev0")


def main(cfg):
    exp_name = cfg.exp_name
    save_dir = f"{cfg.output_dir}/{exp_name}"
    os.makedirs(save_dir, exist_ok=True)

    kwargs = DistributedDataParallelKwargs()
    process_group_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=5400))
    accelerator = Accelerator(
        gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
        log_with=LoggerType.TENSORBOARD,
        project_dir=os.path.join(save_dir, "tensorboard"),
        kwargs_handlers=[kwargs, process_group_kwargs],
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now, offset per rank so each
    # process draws different random frame windows.
    if cfg.seed is not None:
        print("cfg.seed", cfg.seed, accelerator.process_index)
        seed_everything(cfg.seed + accelerator.process_index)

    weight_dtype = torch.float32
    model_dict = initialize_models_and_optimizers(cfg, accelerator, weight_dtype)
    dataloader_dict = initialize_dataloaders(cfg)
    loss_dict = initialize_loss_functions(cfg, accelerator, model_dict['scheduler_max_steps'])
    syncnet = initialize_syncnet(cfg, accelerator, weight_dtype)
    vgg_IN, pyramid, downsampler = initialize_vgg(cfg, accelerator)

    # Prepare everything with our `accelerator`.
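    # `accelerator.prepare` wraps the net for DDP when multiple processes are
    # launched, shards both dataloaders across ranks, and registers the
    # optimizer and LR scheduler with the accelerator's state checkpointing
    # (`save_state`/`load_state`).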
    (
        model_dict['net'],
        model_dict['optimizer'],
        model_dict['lr_scheduler'],
        dataloader_dict['train_dataloader'],
        dataloader_dict['val_dataloader'],
    ) = accelerator.prepare(
        model_dict['net'],
        model_dict['optimizer'],
        model_dict['lr_scheduler'],
        dataloader_dict['train_dataloader'],
        dataloader_dict['val_dataloader'],
    )
    print(
        "length train/val",
        len(dataloader_dict['train_dataloader']),
        len(dataloader_dict['val_dataloader']),
    )

    # Calculate training steps and epochs
    num_update_steps_per_epoch = math.ceil(
        len(dataloader_dict['train_dataloader']) / cfg.solver.gradient_accumulation_steps
    )
    num_train_epochs = math.ceil(cfg.solver.max_train_steps / num_update_steps_per_epoch)

    # Initialize trackers on the main process
    if accelerator.is_main_process:
        run_time = datetime.now().strftime("%Y%m%d-%H%M")
        accelerator.init_trackers(
            cfg.exp_name,
            init_kwargs={"mlflow": {"run_name": run_time}},
        )

    # Calculate total batch size
    total_batch_size = (
        cfg.data.train_bs
        * accelerator.num_processes
        * cfg.solver.gradient_accumulation_steps
    )

    # Log training information
    logger.info("***** Running training *****")
    logger.info(f"Num Epochs = {num_train_epochs}")
    logger.info(f"Instantaneous batch size per device = {cfg.data.train_bs}")
    logger.info(
        f"Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(f"Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}")
    logger.info(f"Total optimization steps = {cfg.solver.max_train_steps}")

    global_step = 0
    first_epoch = 0

    # Load checkpoint if resuming training
    if cfg.resume_from_checkpoint:
        resume_dir = save_dir
        dirs = os.listdir(resume_dir)
        dirs = [d for d in dirs if d.startswith("checkpoint")]
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
        if len(dirs) > 0:
            path = dirs[-1]
            accelerator.load_state(os.path.join(resume_dir, path))
            accelerator.print(f"Resuming from checkpoint {path}")
            global_step = int(path.split("-")[1])
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = global_step % num_update_steps_per_epoch

    # Initialize progress bar
    progress_bar = tqdm(
        range(global_step, cfg.solver.max_train_steps),
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description("Steps")

    # Log model types
    print("log type of models")
    print("unet", model_dict['unet'].dtype)
    print("vae", model_dict['vae'].dtype)
    print("wav2vec", model_dict['wav2vec'].dtype)

    def get_ganloss_weight(step):
        """Calculate GAN loss weight based on training step."""
        if step < cfg.discriminator_train_params.start_gan:
            return 0.0
        return 1.0

    # Training loop
    for epoch in range(first_epoch, num_train_epochs):
        # Set models to training mode
        model_dict['unet'].train()
        if cfg.loss_params.gan_loss > 0:
            loss_dict['discriminator'].train()
        if cfg.loss_params.mouth_gan_loss > 0:
            loss_dict['mouth_discriminator'].train()

        # Initialize loss accumulators
        train_loss = 0.0
        train_loss_D = 0.0
        train_loss_D_mouth = 0.0
        l1_loss_accum = 0.0
        vgg_loss_accum = 0.0
        gan_loss_accum = 0.0
        gan_loss_accum_mouth = 0.0
        fm_loss_accum = 0.0
        sync_loss_accum = 0.0
        adapted_weight_accum = 0.0

        t_data_start = time.time()
        for step, batch in enumerate(dataloader_dict['train_dataloader']):
            t_data = time.time() - t_data_start
            t_model_start = time.time()

            with torch.no_grad():
                # Process input data
                pixel_values = batch["pixel_values_vid"].to(weight_dtype).to(
                    accelerator.device, non_blocking=True
                )
                bsz, num_frames, c, h, w = pixel_values.shape

                # Process reference images
                ref_pixel_values = batch["pixel_values_ref_img"].to(weight_dtype).to(
                    accelerator.device,
                    non_blocking=True,
                )

                # Get face mask for GAN
                pixel_values_face_mask = batch['pixel_values_face_mask']

                # Process audio features
                audio_prompts = process_audio_features(
                    cfg, batch, model_dict['wav2vec'], bsz, num_frames, weight_dtype
                )

                # Initialize adapted weight
                adapted_weight = 1

                # Process sync loss if enabled
                if cfg.loss_params.sync_loss > 0:
                    mels = batch['mel']

                    # Prepare frames for latentsync (combine channels and frames)
                    gt_frames = rearrange(pixel_values, 'b f c h w -> b (f c) h w')

                    # Use lower half of face for latentsync
                    height = gt_frames.shape[2]
                    gt_frames = gt_frames[:, :, height // 2:, :]

                    # Get audio embeddings
                    audio_embed = syncnet.get_audio_embed(mels)

                    # Calculate adapted weight based on audio-visual similarity;
                    # samples whose ground-truth similarity is implausibly low
                    # or high are down-weighted or skipped.
                    if cfg.use_adapted_weight:
                        vision_embed_gt = syncnet.get_vision_embed(gt_frames)
                        image_audio_sim_gt = F.cosine_similarity(
                            audio_embed, vision_embed_gt, dim=1
                        )[0]
                        if image_audio_sim_gt < 0.05 or image_audio_sim_gt > 0.65:
                            if cfg.adapted_weight_type == "cut_off":
                                adapted_weight = 0.0  # Skip this batch
                                print(
                                    f"\nThe i-a similarity in step {global_step} is "
                                    f"{image_audio_sim_gt}, set adapted_weight to {adapted_weight}."
                                )
                            elif cfg.adapted_weight_type == "linear":
                                adapted_weight = image_audio_sim_gt
                            else:
                                print(f"unknown adapted_weight_type: {cfg.adapted_weight_type}")
                                adapted_weight = 1

                    # Random frame selection for memory efficiency.
                    # NOTE: the hard-coded 16 is presumably the sampled clip
                    # length (cfg.data.n_sample_frames).
                    max_start = 16 - cfg.num_backward_frames
                    frames_left_index = random.randint(0, max_start) if max_start > 0 else 0
                    frames_right_index = frames_left_index + cfg.num_backward_frames
                else:
                    frames_left_index = 0
                    frames_right_index = cfg.data.n_sample_frames

                # Extract frames for backward pass
                pixel_values_backward = pixel_values[:, frames_left_index:frames_right_index, ...]
                ref_pixel_values_backward = ref_pixel_values[:, frames_left_index:frames_right_index, ...]
                pixel_values_face_mask_backward = pixel_values_face_mask[:, frames_left_index:frames_right_index, ...]
                audio_prompts_backward = audio_prompts[:, frames_left_index:frames_right_index, ...]
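                # The UNet is trained as an inpainting model: the target frame
                # with its lower half masked out and a same-identity reference
                # frame are VAE-encoded separately and later concatenated along
                # the channel axis as `input_latents`. Encoding stays under
                # `torch.no_grad()` because the VAE is not optimized here;
                # gradients are only needed from the UNet forward pass (and the
                # VAE decode it feeds) onward.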
                # Encode target images
                frames = rearrange(pixel_values_backward, 'b f c h w -> (b f) c h w')
                # Target latents (not consumed below; the losses operate in pixel space)
                latents = model_dict['vae'].encode(frames).latent_dist.mode()
                latents = latents * model_dict['vae'].config.scaling_factor
                latents = latents.float()

                # Create masked images (lower half of each frame set to -1)
                masked_pixel_values = pixel_values_backward.clone()
                masked_pixel_values[:, :, :, h // 2:, :] = -1
                masked_frames = rearrange(masked_pixel_values, 'b f c h w -> (b f) c h w')
                masked_latents = model_dict['vae'].encode(masked_frames).latent_dist.mode()
                masked_latents = masked_latents * model_dict['vae'].config.scaling_factor
                masked_latents = masked_latents.float()

                # Encode reference images
                ref_frames = rearrange(ref_pixel_values_backward, 'b f c h w -> (b f) c h w')
                ref_latents = model_dict['vae'].encode(ref_frames).latent_dist.mode()
                ref_latents = ref_latents * model_dict['vae'].config.scaling_factor
                ref_latents = ref_latents.float()

                # Prepare face mask and audio features
                pixel_values_face_mask_backward = rearrange(
                    pixel_values_face_mask_backward, "b f c h w -> (b f) c h w"
                )
                audio_prompts_backward = rearrange(
                    audio_prompts_backward, 'b f c h w -> (b f) c h w'
                )
                audio_prompts_backward = rearrange(
                    audio_prompts_backward, '(b f) c h w -> (b f) (c h) w', b=bsz
                )

            # Apply reference dropout (currently inactive)
            dropout = nn.Dropout(p=cfg.ref_dropout_rate)
            ref_latents = dropout(ref_latents)

            # Prepare model inputs: masked target and reference latents are
            # concatenated along the channel dimension
            input_latents = torch.cat([masked_latents, ref_latents], dim=1)
            input_latents = input_latents.to(weight_dtype)
            timesteps = torch.tensor([0], device=input_latents.device)

            # Forward pass
            latents_pred = model_dict['net'](
                input_latents,
                timesteps,
                audio_prompts_backward,
            )
            latents_pred = (1 / model_dict['vae'].config.scaling_factor) * latents_pred
            image_pred = model_dict['vae'].decode(latents_pred).sample

            # Convert to float
            image_pred = image_pred.float()
            frames = frames.float()

            # Calculate L1 loss
            l1_loss = loss_dict['L1_loss'](frames, image_pred)
            l1_loss_accum += l1_loss.item()
            loss = cfg.loss_params.l1_loss * l1_loss * adapted_weight
            adapted_weight_accum += float(adapted_weight)  # tracked for the periodic log below

            # Crop mouth regions for the mouth GAN loss if enabled
            if cfg.loss_params.mouth_gan_loss > 0:
                frames_mouth, image_pred_mouth = get_mouth_region(
                    frames, image_pred, pixel_values_face_mask_backward
                )
                pyramide_real_mouth = pyramid(downsampler(frames_mouth))
                pyramide_generated_mouth = pyramid(downsampler(image_pred_mouth))

            # Process VGG loss if enabled
            if cfg.loss_params.vgg_loss > 0:
                pyramide_real = pyramid(downsampler(frames))
                pyramide_generated = pyramid(downsampler(image_pred))
                loss_IN = 0
                for scale in cfg.loss_params.pyramid_scale:
                    x_vgg = vgg_IN(pyramide_generated['prediction_' + str(scale)])
                    y_vgg = vgg_IN(pyramide_real['prediction_' + str(scale)])
                    for i, weight in enumerate(cfg.loss_params.vgg_layer_weight):
                        value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
                        loss_IN += weight * value
                loss_IN /= sum(cfg.loss_params.vgg_layer_weight)
                loss += loss_IN * cfg.loss_params.vgg_loss * adapted_weight
                vgg_loss_accum += loss_IN.item()

            # Process GAN loss if enabled. This reuses the image pyramids built
            # for the VGG loss above, so gan_loss > 0 requires vgg_loss > 0.
            if cfg.loss_params.gan_loss > 0:
                set_requires_grad(loss_dict['discriminator'], False)
                loss_G = 0.
                discriminator_maps_generated = loss_dict['discriminator'](pyramide_generated)
                discriminator_maps_real = loss_dict['discriminator'](pyramide_real)
                for scale in loss_dict['disc_scales']:
                    key = 'prediction_map_%s' % scale
                    # LSGAN generator objective: push generated maps toward 1
                    value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
                    loss_G += value
                gan_loss_accum += loss_G.item()
                loss += loss_G * cfg.loss_params.gan_loss * get_ganloss_weight(global_step) * adapted_weight

                # Process feature matching loss if enabled
                if cfg.loss_params.fm_loss[0] > 0:
                    L_feature_matching = 0.
                    for scale in loss_dict['disc_scales']:
                        key = 'feature_maps_%s' % scale
                        for i, (a, b) in enumerate(
                            zip(discriminator_maps_real[key], discriminator_maps_generated[key])
                        ):
                            value = torch.abs(a - b).mean()
                            L_feature_matching += value * cfg.loss_params.fm_loss[i]
                    loss += L_feature_matching * adapted_weight
                    fm_loss_accum += L_feature_matching.item()

            # Process mouth GAN loss if enabled
            if cfg.loss_params.mouth_gan_loss > 0:
                set_requires_grad(loss_dict['mouth_discriminator'], False)
                loss_G = 0.
                mouth_discriminator_maps_generated = loss_dict['mouth_discriminator'](pyramide_generated_mouth)
                mouth_discriminator_maps_real = loss_dict['mouth_discriminator'](pyramide_real_mouth)
                for scale in loss_dict['disc_scales']:
                    key = 'prediction_map_%s' % scale
                    value = ((1 - mouth_discriminator_maps_generated[key]) ** 2).mean()
                    loss_G += value
                gan_loss_accum_mouth += loss_G.item()
                loss += loss_G * cfg.loss_params.mouth_gan_loss * get_ganloss_weight(global_step) * adapted_weight

                # Process feature matching loss for mouth if enabled
                if cfg.loss_params.fm_loss[0] > 0:
                    L_feature_matching = 0.
                    for scale in loss_dict['disc_scales']:
                        key = 'feature_maps_%s' % scale
                        for i, (a, b) in enumerate(
                            zip(mouth_discriminator_maps_real[key], mouth_discriminator_maps_generated[key])
                        ):
                            value = torch.abs(a - b).mean()
                            L_feature_matching += value * cfg.loss_params.fm_loss[i]
                    loss += L_feature_matching * adapted_weight
                    fm_loss_accum += L_feature_matching.item()

            # Process sync loss if enabled
            if cfg.loss_params.sync_loss > 0:
                pred_frames = rearrange(
                    image_pred, '(b f) c h w -> b (f c) h w',
                    f=pixel_values_backward.shape[1],
                )
                pred_frames = pred_frames[:, :, height // 2:, :]
                sync_loss, image_audio_sim_pred = get_sync_loss(
                    audio_embed,
                    gt_frames,
                    pred_frames,
                    syncnet,
                    adapted_weight,
                    frames_left_index=frames_left_index,
                    frames_right_index=frames_right_index,
                )
                sync_loss_accum += sync_loss.item()
                loss += sync_loss * cfg.loss_params.sync_loss * adapted_weight

            # Backward pass
            avg_loss = accelerator.gather(loss.repeat(cfg.data.train_bs)).mean()
            train_loss += avg_loss.item()
            accelerator.backward(loss)

            # Train discriminator if GAN loss is enabled
            if cfg.loss_params.gan_loss > 0:
                set_requires_grad(loss_dict['discriminator'], True)
                loss_D = loss_dict['discriminator_full'](frames, image_pred.detach())
                avg_loss_D = accelerator.gather(loss_D.repeat(cfg.data.train_bs)).mean()
                train_loss_D += avg_loss_D.item()
                loss_D = loss_D * get_ganloss_weight(global_step) * adapted_weight
                accelerator.backward(loss_D)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(
                        loss_dict['discriminator'].parameters(), cfg.solver.max_grad_norm
                    )
                if (global_step + 1) % cfg.solver.gradient_accumulation_steps == 0:
                    loss_dict['optimizer_D'].step()
                    loss_dict['scheduler_D'].step()
                    loss_dict['optimizer_D'].zero_grad()

            # Train mouth discriminator if mouth GAN loss is enabled
            if cfg.loss_params.mouth_gan_loss > 0:
                set_requires_grad(loss_dict['mouth_discriminator'], True)
                mouth_loss_D = loss_dict['mouth_discriminator_full'](
                    frames_mouth,
                    image_pred_mouth.detach(),
                )
                avg_mouth_loss_D = accelerator.gather(
                    mouth_loss_D.repeat(cfg.data.train_bs)
                ).mean()
                train_loss_D_mouth += avg_mouth_loss_D.item()
                mouth_loss_D = mouth_loss_D * get_ganloss_weight(global_step) * adapted_weight
                accelerator.backward(mouth_loss_D)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(
                        loss_dict['mouth_discriminator'].parameters(), cfg.solver.max_grad_norm
                    )
                if (global_step + 1) % cfg.solver.gradient_accumulation_steps == 0:
                    loss_dict['mouth_optimizer_D'].step()
                    loss_dict['mouth_scheduler_D'].step()
                    loss_dict['mouth_optimizer_D'].zero_grad()

            # Update main model
            if (global_step + 1) % cfg.solver.gradient_accumulation_steps == 0:
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(
                        model_dict['trainable_params'],
                        cfg.solver.max_grad_norm,
                    )
                model_dict['optimizer'].step()
                model_dict['lr_scheduler'].step()
                model_dict['optimizer'].zero_grad()

            # Update progress and log metrics
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({
                    "train_loss": train_loss,
                    "train_loss_D": train_loss_D,
                    "train_loss_D_mouth": train_loss_D_mouth,
                    "l1_loss": l1_loss_accum,
                    "vgg_loss": vgg_loss_accum,
                    "gan_loss": gan_loss_accum,
                    "gan_loss_mouth": gan_loss_accum_mouth,
                    "fm_loss": fm_loss_accum,
                    "sync_loss": sync_loss_accum,
                    "adapted_weight": adapted_weight_accum,
                    "lr": model_dict['lr_scheduler'].get_last_lr()[0],
                }, step=global_step)

                # Reset loss accumulators
                train_loss = 0.0
                l1_loss_accum = 0.0
                vgg_loss_accum = 0.0
                gan_loss_accum = 0.0
                gan_loss_accum_mouth = 0.0
                fm_loss_accum = 0.0
                sync_loss_accum = 0.0
                adapted_weight_accum = 0.0
                train_loss_D = 0.0
                train_loss_D_mouth = 0.0

                # Run validation if needed
                if global_step % cfg.val_freq == 0 or global_step == 10:
                    try:
                        validation(
                            cfg,
                            dataloader_dict['val_dataloader'],
                            model_dict['net'],
                            model_dict['vae'],
                            model_dict['wav2vec'],
                            accelerator,
                            save_dir,
                            global_step,
                            weight_dtype,
                            syncnet_score=adapted_weight,
                        )
                    except Exception as e:
                        print(f"An error occurred during validation: {e}")

                # Save checkpoint if needed
                if global_step % cfg.checkpointing_steps == 0:
                    save_path = os.path.join(save_dir, f"checkpoint-{global_step}")
                    try:
                        start_time = time.time()
                        if accelerator.is_main_process:
                            save_models(
                                accelerator,
                                model_dict['net'],
                                save_dir,
                                global_step,
                                cfg,
                                logger=logger,
                            )
                            delete_additional_ckpt(save_dir, cfg.total_limit)
                        elapsed_time = time.time() - start_time
                        if elapsed_time > 300:
                            print(f"Skipping storage as it took too long in step {global_step}.")
                        else:
                            print(f"Resume states saved at {save_dir} successfully in {elapsed_time}s.")
                    except Exception as e:
                        print(f"Error when saving model in step {global_step}:", e)

            # Update progress bar
            t_model = time.time() - t_model_start
            logs = {
                "step_loss": loss.detach().item(),
                "lr": model_dict['lr_scheduler'].get_last_lr()[0],
                "td": f"{t_data:.2f}s",
                "tm": f"{t_model:.2f}s",
            }
            t_data_start = time.time()
            progress_bar.set_postfix(**logs)

            if global_step >= cfg.solver.max_train_steps:
                break

        # Save model after each epoch
        if (epoch + 1) % cfg.save_model_epoch_interval == 0:
            try:
                start_time = time.time()
                if accelerator.is_main_process:
                    save_models(accelerator, model_dict['net'], save_dir, global_step, cfg)
                    # A step checkpoint may not have been written yet, so name
                    # the state directory from the current step explicitly.
                    save_path = os.path.join(save_dir, f"checkpoint-{global_step}")
                    accelerator.save_state(save_path)
                elapsed_time = time.time() - start_time
                if elapsed_time > 120:
                    print(f"Skipping storage as it took too long in step {global_step}.")
                else:
                    print(f"Model saved successfully in {elapsed_time}s.")
            except Exception as e:
                print(f"Error when saving model in step {global_step}:", e)
        accelerator.wait_for_everyone()

    # End training
    accelerator.end_training()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/training/stage2.yaml")
    args = parser.parse_args()
    config = OmegaConf.load(args.config)
    main(config)
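
# Example launch (hypothetical file name; substitute whatever this script is
# saved as, after configuring devices with `accelerate config`):
#   accelerate launch train.py --config ./configs/training/stage2.yaml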