import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn.parallel import DistributedDataParallel as DDP
from ptflops import get_model_complexity_info

from .DarkIR import DarkIR

def create_model(opt, rank, adapter=False):
    '''
    Creates the DarkIR model and wraps it in DistributedDataParallel.
    opt: the 'network' section of the yaml config.
    '''
    name = opt['name']

    model = DarkIR(img_channel=opt['img_channels'],
                   width=opt['width'],
                   middle_blk_num_enc=opt['middle_blk_num_enc'],
                   middle_blk_num_dec=opt['middle_blk_num_dec'],
                   enc_blk_nums=opt['enc_blk_nums'],
                   dec_blk_nums=opt['dec_blk_nums'],
                   dilations=opt['dilations'],
                   extra_depth_wise=opt['extra_depth_wise'])

    if rank == 0:
        # Report model complexity only once, from the rank-0 process
        print(f'Using {name} network')
        input_size = (3, 256, 256)
        macs, params = get_model_complexity_info(model, input_size, print_per_layer_stat=False)
        print(f'Computational complexity at {input_size}: {macs}')
        print('Number of parameters: ', params)
    else:
        macs, params = None, None

    model.to(rank)
    model = DDP(model, device_ids=[rank], find_unused_parameters=adapter)

    return model, macs, params

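# Example usage (illustrative sketch only; the config values below are assumptions
# inferred from the keys read above, not taken from a real yaml file):
#
#   network_opt = {'name': 'DarkIR', 'img_channels': 3, 'width': 32,
#                  'middle_blk_num_enc': 2, 'middle_blk_num_dec': 2,
#                  'enc_blk_nums': [1, 2, 3], 'dec_blk_nums': [3, 1, 1],
#                  'dilations': [1, 4], 'extra_depth_wise': True}
#   model, macs, params = create_model(network_opt, rank=0)
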
def create_optim_scheduler(opt, model):
    '''
    Returns the optimizer and its learning-rate scheduler.
    opt: the 'train' section of the yaml config.
    '''
    optim = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                              lr=opt['lr_initial'],
                              weight_decay=opt['weight_decay'],
                              betas=opt['betas'])

    if opt['lr_scheme'] == 'CosineAnnealing':
        scheduler = CosineAnnealingLR(optim, T_max=opt['epochs'], eta_min=opt['eta_min'])
    else:
        raise NotImplementedError('scheduler not implemented')

    return optim, scheduler

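# Example usage (illustrative sketch; the training hyperparameters below are
# assumptions, not values from an actual config):
#
#   train_opt = {'lr_initial': 1e-3, 'weight_decay': 1e-3, 'betas': [0.9, 0.9],
#                'lr_scheme': 'CosineAnnealing', 'epochs': 500, 'eta_min': 1e-6}
#   optim, scheduler = create_optim_scheduler(train_opt, model)
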
def load_weights(model, old_weights):
    '''
    Loads the weights of a pretrained model, keeping only the entries whose keys
    also exist in the new model's state dict.
    '''
    new_weights = model.state_dict()
    new_weights.update({k: v for k, v in old_weights.items() if k in new_weights})
    model.load_state_dict(new_weights)
    return model

def load_optim(optim, optim_weights):
    '''
    Loads the optimizer state, keeping only the entries whose keys also exist
    in the new optimizer's state dict.
    '''
    optim_new_weights = optim.state_dict()
    optim_new_weights.update({k: v for k, v in optim_weights.items() if k in optim_new_weights})
    # state_dict() returns a copy, so the merged state must be loaded back
    # explicitly for the update above to take effect.
    optim.load_state_dict(optim_new_weights)
    return optim

def resume_model(model,
                 optim,
                 scheduler,
                 path_model,
                 rank, resume: str = None):
    '''
    Returns the loaded weights of the model and optimizer if the resume flag is set.
    '''
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    if resume:
        checkpoints = torch.load(path_model, map_location=map_location, weights_only=False)
        weights = checkpoints['model_state_dict']
        model = load_weights(model, old_weights=weights)
        optim = load_optim(optim, optim_weights=checkpoints['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoints['scheduler_state_dict'])
        start_epochs = checkpoints['epoch']
        if rank == 0: print('Loaded weights')
    else:
        start_epochs = 0
        if rank == 0: print('Starting the training from scratch')
    return model, optim, scheduler, start_epochs

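# Example usage (illustrative sketch; the checkpoint path is hypothetical):
#
#   model, optim, scheduler, start_epochs = resume_model(
#       model, optim, scheduler,
#       path_model='./models/last_checkpoint.pt', rank=rank, resume='resume')
#   for epoch in range(start_epochs, train_opt['epochs']):
#       ...
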
def find_different_keys(dict1, dict2):
    # Keys present in only one of the two dicts (symmetric difference)
    different_keys = set(dict1.keys()) ^ set(dict2.keys())
    return different_keys


def number_common_keys(dict1, dict2):
    # Finding common keys
    common_keys = set(dict1.keys()) & set(dict2.keys())
    # Counting the number of common keys
    common_keys_count = len(common_keys)
    return common_keys_count

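# Example (plain-Python sketch, not tied to any particular checkpoint):
#
#   a = {'conv.weight': 1, 'conv.bias': 2}
#   b = {'conv.weight': 1, 'adapter.weight': 3}
#   find_different_keys(a, b)   # {'conv.bias', 'adapter.weight'}
#   number_common_keys(a, b)    # 1
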
# # Function to add 'modules_list' prefix after the first numeric index
# def add_middle_prefix(state_dict, middle_prefix, target_strings):
#     new_state_dict = {}
#     for key, value in state_dict.items():
#         for target in target_strings:
#             if target in key:
#                 parts = key.split('.')
#                 # Find the first numeric index after the target string
#                 for i, part in enumerate(parts):
#                     if part == target:
#                         # Insert the middle prefix after the first numeric index
#                         if i + 1 < len(parts) and parts[i + 1].isdigit():
#                             parts.insert(i + 2, middle_prefix)
#                         break
#                 new_key = '.'.join(parts)
#                 new_state_dict[new_key] = value
#                 break
#         else:
#             new_state_dict[key] = value
#     return new_state_dict

# # Function to adjust keys for 'middle_blks.' prefix
# def adjust_middle_blks_keys(state_dict, target_prefix, middle_prefix):
#     new_state_dict = {}
#     for key, value in state_dict.items():
#         if target_prefix in key:
#             parts = key.split('.')
#             # Find the target prefix and adjust the key
#             for i, part in enumerate(parts):
#                 if part == target_prefix.rstrip('.'):
#                     if i + 1 < len(parts) and parts[i + 1].isdigit():
#                         # Swap the numerical part and the middle prefix
#                         new_key = '.'.join(parts[:i + 1] + [middle_prefix] + parts[i + 1:i + 2] + parts[i + 2:])
#                         new_state_dict[new_key] = value
#                     break
#         else:
#             new_state_dict[key] = value
#     return new_state_dict

# def resume_nafnet(model,
#                   optim,
#                   scheduler,
#                   path_adapter,
#                   path_model,
#                   rank, resume: str = None):
#     '''
#     Returns the loaded weights of model and optimizer if resume flag is True
#     '''
#     map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
#     # first load the model weights
#     checkpoints = torch.load(path_model, map_location=map_location, weights_only=False)
#     weights = checkpoints
#     if rank == 0:
#         print(len(weights), len(model.state_dict().keys()))
#         different_keys = find_different_keys(weights, model.state_dict())
#         filtered_keys = {item for item in different_keys if 'adapter' not in item}
#         print(filtered_keys)
#         print(len(filtered_keys))
#     model = load_weights(model, old_weights=weights)
#     # now if needed load the adapter weights
#     if resume:
#         checkpoints = torch.load(path_adapter, map_location=map_location, weights_only=False)
#         weights = checkpoints
#         model = load_weights(model, old_weights=weights)
#         # optim = load_optim(optim, optim_weights=checkpoints['optimizer_state_dict'])
#         scheduler.load_state_dict(checkpoints['scheduler_state_dict'])
#         start_epochs = checkpoints['epoch']
#         if rank == 0: print('Loaded weights')
#     else:
#         start_epochs = 0
#         if rank == 0: print('Starting from zero the training')
#     return model, optim, scheduler, start_epochs

def save_checkpoint(model, optim, scheduler, metrics_eval, metrics_train, paths, adapter=False, rank=None):
    '''
    Saves the .pt checkpoint of the model after each epoch.
    Only rank 0 writes to disk; the other ranks return the current best PSNR unchanged.
    '''
    best_psnr = metrics_train['best_psnr']
    if rank != 0:
        return best_psnr

    # Normalize metrics_eval to a dict of per-dataset metric dicts
    if type(next(iter(metrics_eval.values()))) != dict:
        metrics_eval = {'metrics': metrics_eval}

    weights = model.state_dict()
    # Save the model after every epoch
    model_to_save = {
        'epoch': metrics_train['epoch'],
        'model_state_dict': weights,
        'optimizer_state_dict': optim.state_dict(),
        'loss': metrics_train['train_loss'],
        'scheduler_state_dict': scheduler.state_dict()
    }

    try:
        torch.save(model_to_save, paths['new'])
        # Save the best model if the new valid_psnr is higher than the best one so far
        if next(iter(metrics_eval.values()))['valid_psnr'] >= metrics_train['best_psnr']:
            torch.save(model_to_save, paths['best'])
            metrics_train['best_psnr'] = next(iter(metrics_eval.values()))['valid_psnr']  # update best psnr
    except Exception as e:
        print(f"Error saving model: {e}")

    return metrics_train['best_psnr']

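# Example usage (illustrative sketch; the dict layouts just mirror the keys read
# above, and the paths and metric values are assumptions):
#
#   paths = {'new': './models/last_checkpoint.pt', 'best': './models/best_checkpoint.pt'}
#   metrics_train = {'epoch': epoch, 'train_loss': train_loss, 'best_psnr': best_psnr}
#   metrics_eval = {'LOLv2': {'valid_psnr': 21.3, 'valid_ssim': 0.83}}
#   best_psnr = save_checkpoint(model, optim, scheduler, metrics_eval,
#                               metrics_train, paths, rank=rank)
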
__all__ = ['create_model', 'resume_model', 'create_optim_scheduler', 'save_checkpoint',
           'load_optim', 'load_weights']