import torch
import math
import time
import torch.distributed as dist
import logging


#################################################################################
#                                 DDP Functions                                 #
#################################################################################
def cleanup():
    dist.destroy_process_group()


#################################################################################
#                                 Util Functions                                #
#################################################################################
def lengths_to_mask(lengths, max_len):
    # max_len = max(lengths)
    mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1)
    return mask  # (b, len)


def get_mask_subset_prob(mask, prob):
    # Keep each valid (True) position of `mask` independently with probability `prob`.
    subset_mask = torch.bernoulli(mask, p=prob) & mask
    return subset_mask


def uniform(shape, device=None):
    return torch.zeros(shape, device=device).float().uniform_(0, 1)


def cosine_schedule(t):
    # Monotone schedule from 1 at t=0 down to 0 at t=1.
    return torch.cos(t * math.pi * 0.5)


def update_ema(model, ema_model, ema_decay):
    # Exponential moving average of parameters: ema = decay * ema + (1 - decay) * model.
    with torch.no_grad():
        for ema_param, model_param in zip(ema_model.parameters(), model.parameters()):
            ema_param.data.mul_(ema_decay).add_(model_param.data, alpha=(1 - ema_decay))


#################################################################################
#                               Logging Functions                               #
#################################################################################
def def_value():
    return 0.0


def create_logger(logging_dir):
    if dist.get_rank() == 0:
        # real logger on the main process
        logging.basicConfig(
            level=logging.INFO,
            format='[\033[34m%(asctime)s\033[0m] %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
            handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
        )
        logger = logging.getLogger(__name__)
    else:
        # dummy logger (does nothing) on all other ranks
        logger = logging.getLogger(__name__)
        logger.addHandler(logging.NullHandler())
    return logger


def update_lr_warm_up(nb_iter, warm_up_iter, optimizer, lr):
    # Linearly ramp the learning rate from ~0 up to `lr` over `warm_up_iter` iterations.
    current_lr = lr * (nb_iter + 1) / (warm_up_iter + 1)
    for param_group in optimizer.param_groups:
        param_group["lr"] = current_lr
    return current_lr


def save(file_name, ep, model, optimizer, scheduler, total_it, name, ema=None):
    state = {
        name: model.state_dict(),
        f"opt_{name}": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        'ep': ep,
        'total_it': total_it,
    }
    if ema is not None:
        # Drop the frozen CLIP weights from both state dicts to keep the checkpoint small.
        mardm_state_dict = model.state_dict()
        ema_mardm_state_dict = ema.state_dict()
        clip_weights = [e for e in mardm_state_dict.keys() if e.startswith('clip_model.')]
        for e in clip_weights:
            del mardm_state_dict[e]
            del ema_mardm_state_dict[e]
        state[name] = mardm_state_dict
        state[f"ema_{name}"] = ema_mardm_state_dict
    torch.save(state, file_name)


def print_current_loss(start_time, niter_state, total_niters, losses, epoch=None, sub_epoch=None,
                       inner_iter=None, tf_ratio=None, sl_steps=None):
    def as_minutes(s):
        m = math.floor(s / 60)
        s -= m * 60
        return '%dm %ds' % (m, s)

    def time_since(since, percent):
        now = time.time()
        s = now - since
        es = s / percent
        rs = es - s
        return '%s (- %s)' % (as_minutes(s), as_minutes(rs))

    if epoch is not None:
        print('ep/it:%2d-%4d niter:%6d' % (epoch, inner_iter, niter_state), end=" ")
    message = '%s completed: %3d%%' % (time_since(start_time, niter_state / total_niters),
                                       niter_state / total_niters * 100)
    for k, v in losses.items():
        message += ' %s: %.4f ' % (k, v)
    print(message)
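

#################################################################################
#                                  Usage Sketch                                 #
#################################################################################
# Minimal usage sketch of the mask, schedule, and EMA helpers on toy tensors.
# Assumes a single-process CPU run (no dist.init_process_group); the nn.Linear
# models below are hypothetical stand-ins for the networks trained with these
# utilities.
if __name__ == "__main__":
    import torch.nn as nn

    lengths = torch.tensor([3, 5, 2])
    mask = lengths_to_mask(lengths, max_len=6)        # (3, 6) bool, True = valid position
    subset = get_mask_subset_prob(mask, prob=0.5)     # random subset of the valid positions
    ramp = cosine_schedule(torch.linspace(0, 1, 5))   # ratio decaying from 1 to 0

    model = nn.Linear(4, 4)
    ema_model = nn.Linear(4, 4)
    ema_model.load_state_dict(model.state_dict())
    update_ema(model, ema_model, ema_decay=0.999)     # ema <- 0.999 * ema + 0.001 * model

    print(mask.shape, int(subset.sum()), ramp)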