#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import os
import os.path as osp
import shutil

import torch

from yolov6.utils.events import LOGGER
from yolov6.utils.torch_utils import fuse_model

def load_state_dict(weights, model, map_location=None):
    """Load weights from a checkpoint file, assigning only the weights whose layer name and shape match the model."""
    ckpt = torch.load(weights, map_location=map_location)
    state_dict = ckpt['model'].float().state_dict()
    model_state_dict = model.state_dict()
    state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict and v.shape == model_state_dict[k].shape}
    model.load_state_dict(state_dict, strict=False)
    del ckpt, state_dict, model_state_dict
    return model

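# Usage sketch (illustrative: 'build_model' stands in for whatever constructs the network,
# and the weight path is an assumption):
#   model = build_model(cfg, num_classes, device)
#   model = load_state_dict('weights/yolov6s.pt', model, map_location='cpu')
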
def load_checkpoint(weights, map_location=None, inplace=True, fuse=True):
    """Load a model from a checkpoint file, preferring the EMA weights when they are present."""
    # note: 'inplace' is not used inside this function
    LOGGER.info("Loading checkpoint from {}".format(weights))
    ckpt = torch.load(weights, map_location=map_location)  # load
    model = ckpt['ema' if ckpt.get('ema') else 'model'].float()
    if fuse:
        LOGGER.info("\nFusing model...")
        model = fuse_model(model).eval()
    else:
        model = model.eval()
    return model

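# Usage sketch (the weight path and input tensor are assumptions):
#   model = load_checkpoint('weights/yolov6s.pt', map_location='cpu', fuse=True)
#   preds = model(img_tensor)  # img_tensor: a preprocessed NCHW batch on the same device
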
def save_checkpoint(ckpt, is_best, save_dir, model_name=""):
    """Save checkpoint to disk."""
    if not osp.exists(save_dir):
        os.makedirs(save_dir)
    filename = osp.join(save_dir, model_name + '.pt')
    torch.save(ckpt, filename)
    if is_best:
        best_filename = osp.join(save_dir, 'best_ckpt.pt')
        shutil.copyfile(filename, best_filename)

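# Usage sketch (the dict keys mirror what load_checkpoint/strip_optimizer read; the
# surrounding training objects are assumptions):
#   ckpt = {'model': model, 'ema': ema.ema, 'updates': ema.updates,
#           'optimizer': optimizer.state_dict(), 'epoch': epoch}
#   save_checkpoint(ckpt, is_best=True, save_dir='runs/train/exp/weights', model_name='last_ckpt')
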
def strip_optimizer(ckpt_dir, epoch):
    """Strip training-only state (optimizer, EMA, update counter) from the saved best/last checkpoints and convert the model to FP16, leaving smaller, inference-only files."""
    for s in ['best', 'last']:
        ckpt_path = osp.join(ckpt_dir, '{}_ckpt.pt'.format(s))
        if not osp.exists(ckpt_path):
            continue
        ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
        if ckpt.get('ema'):
            ckpt['model'] = ckpt['ema']  # replace model with ema
        for k in ['optimizer', 'ema', 'updates']:  # keys
            ckpt[k] = None
        ckpt['epoch'] = epoch
        ckpt['model'].half()  # convert to FP16
        for p in ckpt['model'].parameters():
            p.requires_grad = False
        torch.save(ckpt, ckpt_path)

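# Usage sketch (typically run once training ends; the directory and epoch are assumptions):
#   strip_optimizer('runs/train/exp/weights', epoch=299)
#   # best_ckpt.pt / last_ckpt.pt now hold FP16, inference-only weights.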