import os
import shutil
import argparse
import random
import numpy as np
from datetime import datetime
from tqdm import tqdm
import importlib
import copy
import librosa
from pathlib import Path
import json
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, WeightedRandomSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import wandb
from diffusers.optimization import get_scheduler
from omegaconf import OmegaConf

from emage_evaltools.mertic import FGD, BC, L1div
from emage_utils.motion_io import beat_format_load, beat_format_save, MASK_DICT, recover_from_mask, recover_from_mask_ts
import emage_utils.rotation_conversions as rc
from emage_utils import fast_render
from emage_utils.motion_rep_transfer import get_motion_rep_numpy

# --------------------------------- loss here --------------------------------- #
class GeodesicLoss(nn.Module):
    def __init__(self):
        super(GeodesicLoss, self).__init__()

    def compute_geodesic_distance(self, m1, m2):
        m1 = m1.reshape(-1, 3, 3)
        m2 = m2.reshape(-1, 3, 3)
        m = torch.bmm(m1, m2.transpose(1, 2))  # relative rotation m1 @ m2^T
        cos = (m[:, 0, 0] + m[:, 1, 1] + m[:, 2, 2] - 1) / 2  # (trace - 1) / 2
        cos = torch.clamp(cos, min=-1 + 1e-6, max=1 - 1e-6)  # keep acos numerically stable
        theta = torch.acos(cos)
        return theta

    def __call__(self, m1, m2, reduction="mean"):
        loss = self.compute_geodesic_distance(m1, m2)
        if reduction == "mean":
            return loss.mean()
        elif reduction == "none":
            return loss
        else:
            raise RuntimeError(f"unsupported reduction: {reduction}")

GeodesicLossFn = GeodesicLoss()

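# A minimal, hypothetical sanity check (not called anywhere in this script):
# feeding identical rotation batches to GeodesicLossFn should give a distance
# near zero (bounded away from exactly zero by the acos clamp above).
def _demo_geodesic_loss():
    m = rc.rotation_6d_to_matrix(torch.randn(2, 4, 6))  # random but valid rotations via 6D Gram-Schmidt
    print(GeodesicLossFn(m, m))  # ~0
    print(GeodesicLossFn(m, m, reduction="none").shape)  # per-rotation distances, shape [8]
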
def contrastive_loss(features, labels, margin=1.0):
    # features: [bs, n, c]
    # labels: [bs, 1]
    # first, reduce features along the time (or sequence) dimension
    feats = features.mean(dim=1)  # [bs, c]
    lbs = labels.squeeze(-1)  # [bs]
    # compute pairwise distances
    dist = torch.cdist(feats, feats, p=2)  # [bs, bs]
    pos_mask = (lbs.unsqueeze(0) == lbs.unsqueeze(1)).float()  # [bs, bs]
    # positive pairs: distance should be small
    pos_loss = pos_mask * dist
    # negative pairs: distance should be large, enforced by a margin-based hinge
    neg_loss = (1.0 - pos_mask) * F.relu(margin - dist)
    return pos_loss.mean() + neg_loss.mean()

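# A small, hypothetical usage sketch (not called anywhere in this script): with
# three samples where the first two share a label, the same-label pairs
# contribute their distance and the cross-label pairs contribute a margin hinge.
def _demo_contrastive_loss():
    feats = torch.randn(3, 10, 8)           # [bs=3, n=10 frames, c=8 channels]
    labels = torch.tensor([[0], [0], [1]])  # [bs, 1]; first two samples share a class
    print(contrastive_loss(feats, labels, margin=1.0))  # scalar loss
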
def get_weighted_sampler(dataset):
    # collect labels and weight each sample by the inverse of its class frequency
    labels = []
    for item in dataset.data_list:
        labels.append(item["content_label"])
    labels = np.array(labels)
    class_counts = np.bincount(labels)
    weights = 1.0 / class_counts[labels]
    sampler = WeightedRandomSampler(
        weights=weights,
        num_samples=len(weights),  # usually the same as the dataset size
        replacement=True,  # typically True for weighted sampling
    )
    return sampler

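# A hypothetical illustration (not called anywhere in this script) of the
# inverse-frequency weighting above: with labels [0, 0, 0, 1], each class-0
# item gets weight 1/3 and the lone class-1 item gets weight 1.0, so both
# classes are drawn with roughly equal probability when sampling with replacement.
def _demo_weighted_sampler():
    labels = np.array([0, 0, 0, 1])
    weights = 1.0 / np.bincount(labels)[labels]  # [1/3, 1/3, 1/3, 1.0]
    sampler = WeightedRandomSampler(weights=weights, num_samples=len(weights), replacement=True)
    print(list(sampler))  # indices drawn in one pass, e.g. [3, 0, 3, 1]
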
# --------------------------------- train, val, test fn here --------------------------------- #
def inference_fn(cfg, model, device, test_path, save_path):
    actual_model = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
    actual_model.eval()
    torch.set_grad_enabled(False)
    test_list = []
    for data_meta_path in test_path:
        with open(data_meta_path, "r") as f:
            test_list.extend(json.load(f))
    test_list = [item for item in test_list if item.get("mode") == "test"]
    # keep only the first occurrence of each video_id (set.add returns None, so the condition passes once per id)
    seen_ids = set()
    test_list = [item for item in test_list if not (item["video_id"] in seen_ids or seen_ids.add(item["video_id"]))]
    save_list = []
    start_time = time.time()
    total_length = 0
    for test_file in tqdm(test_list, desc="Testing"):
        audio, _ = librosa.load(test_file["audio_path"], sr=cfg.audio_sr)
        audio = torch.from_numpy(audio).to(device).unsqueeze(0)
        speaker_id = torch.zeros(1, 1).to(device).long()
        motion_pred = actual_model(audio, speaker_id, seed_frames=4, seed_motion=None)["motion_axis_angle"]
        t = motion_pred.shape[1]
        motion_pred = motion_pred.cpu().numpy().reshape(t, -1)
        beat_format_save(os.path.join(save_path, f"{test_file['video_id']}_output.npz"), motion_pred, upsample=30 // cfg.pose_fps)
        save_list.append(
            {
                "audio_path": test_file["audio_path"],
                "motion_path": os.path.join(save_path, f"{test_file['video_id']}_output.npz"),
                "video_id": test_file["video_id"],
            }
        )
        total_length += t
    time_cost = time.time() - start_time
    print(f"\n cost {time_cost:.2f} seconds to generate {total_length / cfg.pose_fps:.2f} seconds of motion")
    return test_list, save_list

def train_val_fn(cfg, batch, model, device, mode="train", optimizer=None, lr_scheduler=None, fgd_evaluator=None):
    model.train() if mode == "train" else model.eval()
    torch.set_grad_enabled(mode == "train")
    joint_mask = MASK_DICT[cfg.model.joint_mask]
    if mode == "train":
        optimizer.zero_grad()
    motion_gt = batch["motion"].to(device)
    audio = batch["audio"].to(device)
    rhythm = batch["rhythm_label"].to(device)
    content = batch["content_label"].to(device)
    bs, t, jc = motion_gt.shape
    j = jc // 3
    speaker_id = torch.zeros(bs, 1).to(device).long()
    motion_gt = rc.axis_angle_to_rotation_6d(motion_gt.reshape(bs, t, j, 3)).reshape(bs, t, j * 6)
    all_pred = model(audio, speaker_id, seed_frames=4, seed_motion=motion_gt, return_axis_angle=False)
    motion_pred = all_pred["motion"]
    motion_pred = rc.rotation_6d_to_matrix(motion_pred.reshape(bs, t, j, 6))
    motion_gt = rc.rotation_6d_to_matrix(motion_gt.reshape(bs, t, j, 6))
    loss = GeodesicLossFn(motion_pred, motion_gt)
    loss_dict = {"loss": loss}

    # feature disentanglement loss
    rhythm_fea = all_pred["audio_fea_r"]
    content_fea = all_pred["audio_fea_c"]
    # if two features share the same rhythm (or content) class, the distance should be small; otherwise large
    rhythm_fea = F.normalize(rhythm_fea, dim=1)
    content_fea = F.normalize(content_fea, dim=1)
    rhythm_disentangle_loss = contrastive_loss(rhythm_fea, rhythm)
    content_disentangle_loss = contrastive_loss(content_fea, content)
    loss_dict["rhythm"] = rhythm_disentangle_loss
    loss_dict["content"] = content_disentangle_loss

    all_loss = sum(loss_dict.values())
    loss_dict["all_loss"] = all_loss

    if mode == "train":
        all_loss.backward()  # gradients must exist before they can be clipped
        if cfg.solver.max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.solver.max_grad_norm)
        optimizer.step()
        lr_scheduler.step()

    if mode == "val":
        motion_pred = rc.matrix_to_rotation_6d(motion_pred).reshape(bs, t, j * 6)
        motion_gt = rc.matrix_to_rotation_6d(motion_gt).reshape(bs, t, j * 6)
        padded_pred = recover_from_mask_ts(motion_pred, joint_mask)
        padded_gt = recover_from_mask_ts(motion_gt, joint_mask)
        fgd_evaluator.update(padded_pred, padded_gt)
    return loss_dict

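# A hypothetical shape walkthrough (not called anywhere in this script) of the
# rotation-representation round trip used in train_val_fn: axis-angle from the
# batch -> 6D for the network, 6D -> matrix for the geodesic loss, and matrix
# -> 6D again for the FGD update. It assumes rc exposes the same conversions
# used above with pytorch3d-style shape conventions.
def _demo_rotation_round_trip(bs=2, t=8, j=55):
    aa = torch.randn(bs, t, j * 3) * 0.1                        # axis-angle, [bs, t, j*3]
    r6d = rc.axis_angle_to_rotation_6d(aa.reshape(bs, t, j, 3)).reshape(bs, t, j * 6)
    mat = rc.rotation_6d_to_matrix(r6d.reshape(bs, t, j, 6))    # [bs, t, j, 3, 3]
    back = rc.matrix_to_rotation_6d(mat).reshape(bs, t, j * 6)  # back to [bs, t, j*6]
    print(torch.allclose(r6d, back, atol=1e-5))                 # True up to numeric error
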
# --------------------------------- main train loop here --------------------------------- #
def main(cfg):
    seed_everything(cfg.seed)
    os.environ["WANDB_API_KEY"] = cfg.wandb_key
    local_rank = int(os.environ["LOCAL_RANK"]) if "LOCAL_RANK" in os.environ else 0
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    torch.distributed.init_process_group(backend="nccl")

    log_dir = os.path.join(cfg.output_dir, cfg.exp_name)
    experiment_ckpt_dir = os.path.join(log_dir, "checkpoints")
    os.makedirs(experiment_ckpt_dir, exist_ok=True)
    if local_rank == 0 and cfg.validation.wandb:
        wandb.init(
            project=cfg.wandb_project,
            name=cfg.exp_name,
            entity=cfg.wandb_entity,
            dir=log_dir,
            config=OmegaConf.to_container(cfg),
        )

    # init
    if cfg.test:
        from models.disco_audio import DiscoAudioModel
        model = DiscoAudioModel.from_pretrained("/content/outputs/disco_audio/checkpoints/last").to(device)
    else:
        model = init_hf_class(cfg.model.name_pyfile, cfg.model.class_name, cfg.model).to(device)
    model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)

    # optimizer
    optimizer_cls = torch.optim.Adam
    optimizer = optimizer_cls(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=cfg.solver.learning_rate,
        betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
        weight_decay=cfg.solver.adam_weight_decay,
        eps=cfg.solver.adam_epsilon,
    )
    lr_scheduler = get_scheduler(
        cfg.solver.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=cfg.solver.lr_warmup_steps * cfg.solver.gradient_accumulation_steps,
        num_training_steps=cfg.solver.max_train_steps * cfg.solver.gradient_accumulation_steps,
    )

    # dataset
    train_dataset = init_class(cfg.data.name_pyfile, cfg.data.class_name, cfg, split="train")
    test_dataset = init_class(cfg.data.name_pyfile, cfg.data.class_name, cfg, split="test")
    train_sampler = get_weighted_sampler(train_dataset)
    test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
    train_loader = DataLoader(train_dataset, batch_size=cfg.data.train_bs, sampler=train_sampler, drop_last=True, num_workers=8)
    test_loader = DataLoader(test_dataset, batch_size=cfg.data.train_bs, sampler=test_sampler, drop_last=False, num_workers=8)

    # resume
    if cfg.resume_from_checkpoint:
        checkpoint = torch.load(cfg.resume_from_checkpoint, map_location="cpu")
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
        iteration = checkpoint["iteration"]
    else:
        iteration = 0
    if cfg.test:
        iteration = 0

    max_epochs = (cfg.solver.max_train_steps // len(train_loader)) + (1 if cfg.solver.max_train_steps % len(train_loader) != 0 else 0)
    start_epoch = iteration // len(train_loader)
    start_step_in_epoch = iteration % len(train_loader)
    fgd_evaluator = FGD(download_path="./emage_evaltools/")
    bc_evaluator = BC(download_path="./emage_evaltools/", sigma=0.3, order=7)
    l1div_evaluator = L1div()
    loss_meters = {}
    loss_meters_val = {}
    best_fgd_val = np.inf
    best_fgd_iteration_val = 0
    best_fgd_test = np.inf
    best_fgd_iteration_test = 0
    # train loop
    data_start = time.time()
    for epoch in range(start_epoch, max_epochs):
        # train_sampler.set_epoch(epoch)
        pbar = tqdm(train_loader, leave=True)
        for i, batch in enumerate(pbar):
            # for a correct resume when the dataset is very large: since the seed is fixed, skip already-seen batches
            if i < start_step_in_epoch:
                iteration += 1
                continue

            # test
            if iteration % cfg.validation.test_steps == 0 and local_rank == 0:
                test_save_path = os.path.join(log_dir, f"test_{iteration}")
                os.makedirs(test_save_path, exist_ok=True)
                with torch.no_grad():
                    test_list, save_list = inference_fn(cfg.model, model, device, cfg.data.test_meta_paths, test_save_path)
                if cfg.validation.evaluation:
                    metrics = evaluation_fn([True] * 55, test_list, save_list, fgd_evaluator, bc_evaluator, l1div_evaluator, device)
                if cfg.validation.visualization: visualization_fn(save_list, test_save_path, test_list, only_check_one=True)
                if cfg.validation.evaluation: best_fgd_test, best_fgd_iteration_test = log_test(model, metrics, iteration, best_fgd_test, best_fgd_iteration_test, cfg, local_rank, experiment_ckpt_dir, test_save_path)
                if cfg.test: return 0

            # validation (use a distinct loop variable so the outer train batch is not overwritten)
            if iteration % cfg.validation.validation_steps == 0:
                loss_meters = {}
                loss_meters_val = {}
                fgd_evaluator.reset()
                pbar_val = tqdm(test_loader, leave=True)
                data_start_val = time.time()
                for j, batch_val in enumerate(pbar_val):
                    data_time_val = time.time() - data_start_val
                    with torch.no_grad():
                        val_loss_dict = train_val_fn(cfg, batch_val, model, device, mode="val", fgd_evaluator=fgd_evaluator)
                    net_time_val = time.time() - data_start_val
                    val_loss_dict["fgd"] = fgd_evaluator.compute() if j == len(test_loader) - 1 else 0
                    log_train_val(cfg, val_loss_dict, local_rank, loss_meters_val, pbar_val, epoch, max_epochs, iteration, net_time_val, data_time_val, optimizer, "Val ")
                    data_start_val = time.time()
                    if cfg.debug and j > 1: break
                if local_rank == 0:
                    best_fgd_val, best_fgd_iteration_val = save_last_and_best_ckpt(
                        model, optimizer, lr_scheduler, iteration, experiment_ckpt_dir, best_fgd_val, best_fgd_iteration_val, val_loss_dict["fgd"], lower_is_better=True, metric_name="fgd")

            # train
            data_time = time.time() - data_start
            loss_dict = train_val_fn(cfg, batch, model, device, mode="train", optimizer=optimizer, lr_scheduler=lr_scheduler)
            net_time = time.time() - data_start - data_time
            log_train_val(cfg, loss_dict, local_rank, loss_meters, pbar, epoch, max_epochs, iteration, net_time, data_time, optimizer, "Train")
            data_start = time.time()
            iteration += 1
        start_step_in_epoch = 0

    if local_rank == 0 and cfg.validation.wandb:
        wandb.finish()
    torch.distributed.destroy_process_group()

# --------------------------------- utils fn here --------------------------------- #
def evaluation_fn(joint_mask, gt_list, pred_list, fgd_evaluator, bc_evaluator, l1_evaluator, device):
    fgd_evaluator.reset()
    bc_evaluator.reset()
    l1_evaluator.reset()
    # lvd_evaluator.reset()
    # mse_evaluator.reset()
    for test_file in tqdm(gt_list, desc="Evaluation"):
        # only load selected joints; guard against a missing prediction before indexing
        matched = [item for item in pred_list if item["video_id"] == test_file["video_id"]]
        if not matched:
            print(f"Missing prediction for {test_file['video_id']}")
            continue
        pred_file = matched[0]
        gt_dict = beat_format_load(test_file["motion_path"], joint_mask)
        pred_dict = beat_format_load(pred_file["motion_path"], joint_mask)
        motion_gt = gt_dict["poses"]
        motion_pred = pred_dict["poses"]
        # expressions_gt = gt_dict["expressions"]
        # expressions_pred = pred_dict["expressions"]
        betas = gt_dict["betas"]
        # motion_gt = recover_from_mask(motion_gt, joint_mask)  # t1*165
        # motion_pred = recover_from_mask(motion_pred, joint_mask)  # t2*165
        t = min(motion_gt.shape[0], motion_pred.shape[0])
        motion_gt = motion_gt[:t]
        motion_pred = motion_pred[:t]
        # expressions_gt = expressions_gt[:t]
        # expressions_pred = expressions_pred[:t]

        # bc and l1 require the position representation
        motion_position_pred = get_motion_rep_numpy(motion_pred, device=device, betas=betas)["position"]  # t*55*3
        motion_position_pred = motion_position_pred.reshape(t, -1)
        # ignore the first and last 2 s; this may apply to the beat dataset only
        audio_beat = bc_evaluator.load_audio(test_file["audio_path"], t_start=2 * 16000, t_end=int((t - 60) / 30 * 16000))
        motion_beat = bc_evaluator.load_motion(motion_position_pred, t_start=60, t_end=t - 60, pose_fps=30, without_file=True)
        bc_evaluator.compute(audio_beat, motion_beat, length=t - 120, pose_fps=30)
        # audio_beat = bc_evaluator.load_audio(test_file["audio_path"], t_start=0 * 16000, t_end=int((t - 0) / 30 * 16000))
        # motion_beat = bc_evaluator.load_motion(motion_position_pred, t_start=0, t_end=t - 0, pose_fps=30, without_file=True)
        # bc_evaluator.compute(audio_beat, motion_beat, length=t - 0, pose_fps=30)
        l1_evaluator.compute(motion_position_pred)

        # face_position_pred = get_motion_rep_numpy(motion_pred, device=device, expressions=expressions_pred, expression_only=True, betas=betas)["vertices"]
        # face_position_gt = get_motion_rep_numpy(motion_gt, device=device, expressions=expressions_gt, expression_only=True, betas=betas)["vertices"]
        # lvd_evaluator.compute(face_position_pred, face_position_gt)
        # mse_evaluator.compute(face_position_pred, face_position_gt)

        # fgd requires the rotation 6d representation
        motion_gt = torch.from_numpy(motion_gt).to(device).unsqueeze(0)
        motion_pred = torch.from_numpy(motion_pred).to(device).unsqueeze(0)
        motion_gt = rc.axis_angle_to_rotation_6d(motion_gt.reshape(1, t, 55, 3)).reshape(1, t, 55 * 6)
        motion_pred = rc.axis_angle_to_rotation_6d(motion_pred.reshape(1, t, 55, 3)).reshape(1, t, 55 * 6)
        fgd_evaluator.update(motion_pred.float(), motion_gt.float())
    metrics = {}
    metrics["fgd"] = fgd_evaluator.compute()
    metrics["bc"] = bc_evaluator.avg()
    metrics["l1"] = l1_evaluator.avg()
    # metrics["lvd"] = lvd_evaluator.avg()
    # metrics["mse"] = mse_evaluator.avg()
    return metrics

def visualization_fn(pred_list, save_path, gt_list=None, only_check_one=True):
    if gt_list is None:  # single visualization
        for i in range(len(pred_list)):
            fast_render.render_one_sequence(
                pred_list[i]["motion_path"],
                save_path,
                pred_list[i]["audio_path"],
                model_folder="./evaluation/smplx_models/",
            )
            if only_check_one: break
    else:  # paired visualization; pad the translation
        for i in range(len(pred_list)):
            npz_pred = np.load(pred_list[i]["motion_path"], allow_pickle=True)
            matched = [item for item in gt_list if item["video_id"] == pred_list[i]["video_id"]]
            if not matched:
                print(f"Missing ground truth for {pred_list[i]['video_id']}")
                continue
            gt_file = matched[0]
            npz_gt = np.load(gt_file["motion_path"], allow_pickle=True)
            t = npz_gt["poses"].shape[0]
            np.savez(
                os.path.join(save_path, f"{pred_list[i]['video_id']}_transpad.npz"),
                betas=npz_pred["betas"][:t],
                poses=npz_pred["poses"][:t],
                expressions=npz_pred["expressions"][:t],
                trans=npz_pred["trans"][:t],
                model="smplx2020",
                gender="neutral",
                mocap_frame_rate=30,
            )
            fast_render.render_one_sequence(
                os.path.join(save_path, f"{pred_list[i]['video_id']}_transpad.npz"),
                gt_file["motion_path"],
                save_path,
                pred_list[i]["audio_path"],
                model_folder="./evaluation/smplx_models/",
            )
            if only_check_one: break

def log_test(model, metrics, iteration, best_metrics, best_iteration, cfg, local_rank, experiment_ckpt_dir, video_save_path=None):
    if local_rank == 0:
        print(f"\n Test results at iteration {iteration}:")
        for key, value in metrics.items():
            print(f" {key}: {value:.10f}")
        if cfg.validation.wandb:
            for key, value in metrics.items():
                wandb.log({f"test/{key}": value}, step=iteration)
        if cfg.validation.wandb and cfg.validation.visualization:
            videos_to_log = []
            for filename in os.listdir(video_save_path):
                if filename.endswith(".mp4"):
                    videos_to_log.append(wandb.Video(os.path.join(video_save_path, filename)))
            if videos_to_log:
                wandb.log({"test/videos": videos_to_log}, step=iteration)
        if metrics["fgd"] < best_metrics:
            best_metrics = metrics["fgd"]
            best_iteration = iteration
            model.module.save_pretrained(os.path.join(experiment_ckpt_dir, "test_best"))
        message = f"Current test FGD: {metrics['fgd']:.4f} (best: {best_metrics:.4f} at iteration {best_iteration})"
        log_metric_with_box(message)
    return best_metrics, best_iteration

def log_metric_with_box(message):
    box_width = len(message) + 2
    border = "-" * box_width
    print(f"\n{border}")
    print(f"|{message}|")
    print(f"{border}\n")

def log_train_val(cfg, loss_dict, local_rank, loss_meters, pbar, epoch, max_epochs, iteration, net_time, data_time, optimizer, ptype="Train"):
    new_loss_dict = {}
    for k, v in loss_dict.items():
        if "fgd" in k: continue
        v_cpu = torch.as_tensor(v).float().cpu().item()
        if k not in loss_meters:
            loss_meters[k] = {"sum": 0, "count": 0}
        loss_meters[k]["sum"] += v_cpu
        loss_meters[k]["count"] += 1
        new_loss_dict[k] = v_cpu
    mem_used = torch.cuda.memory_reserved() / 1e9
    lr = optimizer.param_groups[0]["lr"]
    loss_str = " ".join([f"{k}: {new_loss_dict[k]:.4f}({loss_meters[k]['sum'] / loss_meters[k]['count']:.4f})" for k in new_loss_dict])
    desc = f"{ptype}: Epoch[{epoch}/{max_epochs}] Iter[{iteration}] {loss_str} lr: {lr:.2E} data_time: {data_time:.3f} net_time: {net_time:.3f} mem: {mem_used:.2f}GB"
    pbar.set_description(desc)
    pbar.bar_format = "{desc} {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]"
    if cfg.validation.wandb and local_rank == 0:
        for k, v in new_loss_dict.items():
            wandb.log({f"loss/{ptype}/{k}": v}, step=iteration)

def save_last_and_best_ckpt(model, optimizer, lr_scheduler, iteration, save_dir, previous_best, best_iteration, current, lower_is_better=True, metric_name="fgd"):
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "lr_scheduler_state_dict": lr_scheduler.state_dict(),
        "iteration": iteration,
    }
    torch.save(checkpoint, os.path.join(save_dir, "last.bin"))
    model.module.save_pretrained(os.path.join(save_dir, "last"))
    if (lower_is_better and current < previous_best) or (not lower_is_better and current > previous_best):
        previous_best = current
        best_iteration = iteration
        shutil.copy(os.path.join(save_dir, "last.bin"), os.path.join(save_dir, "best.bin"))
        model.module.save_pretrained(os.path.join(save_dir, "best"))
    message = f"Current iteration {iteration} {metric_name}: {current:.4f} (best: {previous_best:.4f} at iteration {best_iteration})"
    log_metric_with_box(message)
    return previous_best, best_iteration

def init_hf_class(module_name, class_name, config, **kwargs):
    module = importlib.import_module(module_name)
    model_class = getattr(module, class_name)
    config_class = model_class.config_class
    config = config_class(config_obj=config)
    instance = model_class(config, **kwargs)
    return instance

def init_class(module_name, class_name, config, **kwargs):
    module = importlib.import_module(module_name)
    model_class = getattr(module, class_name)
    instance = model_class(config, **kwargs)
    return instance

def seed_everything(seed):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # seeds every CUDA device, including the current one
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # benchmark=True would undermine cudnn determinism
    torch.backends.cudnn.enabled = True

def init_env():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/train/stage2.yaml")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--wandb", action="store_true")
    parser.add_argument("--visualization", action="store_true")
    parser.add_argument("--evaluation", action="store_true")
    parser.add_argument("--test", action="store_true")
    parser.add_argument("overrides", nargs=argparse.REMAINDER)
    args = parser.parse_args()
    config = OmegaConf.load(args.config)
    config.exp_name = os.path.splitext(os.path.basename(args.config))[0]
    if args.overrides: config = OmegaConf.merge(config, OmegaConf.from_dotlist(args.overrides))
    if args.debug:
        config.wandb_project = "debug"
        config.exp_name = "debug"
        config.solver.max_train_steps = 4
    else:
        run_time = datetime.now().strftime("%Y%m%d-%H%M")
        config.exp_name = config.exp_name + "_" + run_time
    if args.wandb:
        config.validation.wandb = True
    if args.visualization:
        config.validation.visualization = True
    if args.evaluation:
        config.validation.evaluation = True
    if args.test:
        config.test = True
    save_dir = os.path.join(config.output_dir, config.exp_name)
    os.makedirs(save_dir, exist_ok=True)
    # snapshot the resolved config and all *.py files for reproducibility
    sanity_check_dir = os.path.join(save_dir, "sanity_check")
    os.makedirs(sanity_check_dir, exist_ok=True)
    with open(os.path.join(sanity_check_dir, f"{config.exp_name}.yaml"), "w") as f:
        OmegaConf.save(config, f)
    current_dir = Path.cwd()
    for py_file in current_dir.rglob("*.py"):
        dest_path = Path(sanity_check_dir) / py_file.relative_to(current_dir)
        dest_path.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(py_file, dest_path)
    return config

if __name__ == "__main__":
    config = init_env()
    main(config)