from .transenet import *
from .srcnn import SRCNN
from .fsrcnn import FSRCNN
from .lgcnet import LGCNET
from .dcm import DIM
from .vdsr import VDSR
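# The classes re-exported above let a driver script build a network by name.
# A minimal sketch of that pattern (hypothetical driver code; `args.model`
# and the per-module make_model() factories are assumptions taken from the
# commented-out wrapper below):
#
#     from importlib import import_module
#     module = import_module('model.' + args.model.lower())  # e.g. model.srcnn
#     net = module.make_model(args)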
# import os
# from importlib import import_module
#
# import torch
# import torch.nn as nn
#
#
# class Model(nn.Module):
#     def __init__(self, args, ckp):
#         super(Model, self).__init__()
#         print('Making model...')
#
#         self.scale = args.scale
#         self.idx_scale = 0
#         self.self_ensemble = args.self_ensemble
#         self.chop = args.chop
#         self.precision = args.precision
#         self.cpu = args.cpu
#         self.device = torch.device('cpu' if args.cpu else 'cuda')
#         self.n_GPUs = args.n_GPUs
#         self.save_models = args.save_models
#
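#         # Resolve the architecture named by args.model to its module and
#         # build it through that module's make_model() factory.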
#         module = import_module('model.' + args.model.lower())
#         self.model = module.make_model(args).to(self.device)
#         if args.precision == 'half': self.model.half()
#
#         if not args.cpu and args.n_GPUs > 1:
#             self.model = nn.DataParallel(self.model, range(args.n_GPUs))
#
#         self.load(
#             ckp.dir,
#             pre_train=args.pre_train,
#             resume=args.resume,
#             cpu=args.cpu
#         )
#         if args.print_model: print(self.model)
#
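#     # Eval-time dispatch: geometric self-ensemble (forward_x8) and/or
#     # patch-wise chopping are applied only when not training; the training
#     # path is always a plain forward pass.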
#     def forward(self, x):
#
#         if self.self_ensemble and not self.training:
#             if self.chop:
#                 forward_function = self.forward_chop
#             else:
#                 forward_function = self.model.forward
#
#             return self.forward_x8(x, forward_function)
#         elif self.chop and not self.training:
#             return self.forward_chop(x)
#         else:
#             return self.model(x)
#
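#     # nn.DataParallel hides the network behind a .module attribute; unwrap
#     # it so saving and loading always touch the bare model.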
#     def get_model(self):
#         if self.n_GPUs == 1:
#             return self.model
#         else:
#             return self.model.module
#
#     def state_dict(self, **kwargs):
#         target = self.get_model()
#         return target.state_dict(**kwargs)
#
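#     # model_latest.pt is overwritten every epoch, model_best.pt mirrors the
#     # best checkpoint so far, and a per-epoch snapshot is written when
#     # save_models is enabled.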
#     def save(self, apath, epoch, is_best=False):
#         target = self.get_model()
#         torch.save(
#             target.state_dict(),
#             os.path.join(apath, 'model', 'model_latest.pt')
#         )
#         if is_best:
#             torch.save(
#                 target.state_dict(),
#                 os.path.join(apath, 'model', 'model_best.pt')
#             )
#
#         if self.save_models:
#             torch.save(
#                 target.state_dict(),
#                 os.path.join(apath, 'model', 'model_{}.pt'.format(epoch))
#             )
#
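#     # Resume semantics: 1 restores model_latest.pt, 0 loads the optional
#     # pre_train file, and any other value loads the epoch snapshot
#     # model_{resume}.pt.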
#     def load(self, apath, pre_train='.', resume=-1, cpu=False):
#         if cpu:
#             kwargs = {'map_location': lambda storage, loc: storage}
#         else:
#             kwargs = {}
#
#         if resume == 1:  # resume from the latest checkpoint
#             print('Loading model from the model_latest.pt file...')
#             self.get_model().load_state_dict(
#                 torch.load(
#                     os.path.join(apath, 'model', 'model_latest.pt'),
#                     **kwargs
#                 ),
#                 strict=False
#             )
#         elif resume == 0:  # start from a pre-trained model file, if given
#             if pre_train != '.':
#                 print('Loading model from {}'.format(pre_train))
#                 self.get_model().load_state_dict(
#                     torch.load(pre_train, **kwargs),
#                     strict=False
#                 )
#         else:  # load the snapshot saved at epoch `resume`
#             self.get_model().load_state_dict(
#                 torch.load(
#                     os.path.join(apath, 'model', 'model_{}.pt'.format(resume)),
#                     **kwargs
#                 ),
#                 strict=False
#             )
#
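#     # Memory-bounded inference: split the input into four overlapping
#     # quadrants, super-resolve each one (recursing while a patch is still
#     # larger than min_size pixels), then stitch the upscaled quadrants back
#     # together, cropping away the overlap.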
#     def forward_chop(self, x, shave=10, min_size=160000):
#         scale = self.scale[self.idx_scale]
#         n_GPUs = min(self.n_GPUs, 4)
#         b, c, h, w = x.size()
#         h_half, w_half = h // 2, w // 2
#         h_size, w_size = h_half + shave, w_half + shave
#         lr_list = [
#             x[:, :, 0:h_size, 0:w_size],
#             x[:, :, 0:h_size, (w - w_size):w],
#             x[:, :, (h - h_size):h, 0:w_size],
#             x[:, :, (h - h_size):h, (w - w_size):w]]
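#         # Each quadrant extends `shave` pixels past the midlines so the
#         # seams can be cropped away after upscaling.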
#
#         if w_size * h_size < min_size:
#             sr_list = []
#             for i in range(0, 4, n_GPUs):
#                 lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
#                 sr_batch = self.model(lr_batch)
#                 sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))
#         else:
#             sr_list = [
#                 self.forward_chop(patch, shave=shave, min_size=min_size)
#                 for patch in lr_list
#             ]
#
#         h, w = scale * h, scale * w
#         h_half, w_half = scale * h_half, scale * w_half
#         h_size, w_size = scale * h_size, scale * w_size
#         shave *= scale
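#         # All coordinates now refer to SR space for the stitching below.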
#
#         output = x.new(b, c, h, w)
#         output[:, :, 0:h_half, 0:w_half] \
#             = sr_list[0][:, :, 0:h_half, 0:w_half]
#         output[:, :, 0:h_half, w_half:w] \
#             = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
#         output[:, :, h_half:h, 0:w_half] \
#             = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
#         output[:, :, h_half:h, w_half:w] \
#             = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]
#
#         return output
#
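#     # Geometric self-ensemble ("x8"): run the model on all eight
#     # flip/transpose variants of the input, invert each transform on the
#     # corresponding output, and average the results.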
#     def forward_x8(self, x, forward_function):
#         def _transform(v, op):
#             if self.precision != 'single': v = v.float()
#
#             v2np = v.data.cpu().numpy()
#             if op == 'v':
#                 tfnp = v2np[:, :, :, ::-1].copy()  # flip the width axis
#             elif op == 'h':
#                 tfnp = v2np[:, :, ::-1, :].copy()  # flip the height axis
#             elif op == 't':
#                 tfnp = v2np.transpose((0, 1, 3, 2)).copy()  # swap H and W
#
#             ret = torch.Tensor(tfnp).to(self.device)
#             if self.precision == 'half': ret = ret.half()
#
#             return ret
#
#         # Doubling the list per transform yields all 8 combinations.
#         lr_list = [x]
#         for tf in 'v', 'h', 't':
#             lr_list.extend([_transform(t, tf) for t in lr_list])
#
#         sr_list = [forward_function(aug) for aug in lr_list]
#         for i in range(len(sr_list)):
#             if i > 3:
#                 sr_list[i] = _transform(sr_list[i], 't')
#             if i % 4 > 1:
#                 sr_list[i] = _transform(sr_list[i], 'h')
#             if (i % 4) % 2 == 1:
#                 sr_list[i] = _transform(sr_list[i], 'v')
#
#         output_cat = torch.cat(sr_list, dim=0)
#         output = output_cat.mean(dim=0, keepdim=True)
#
#         return output
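#
# Hypothetical usage of the wrapper above, assuming an argparse-style `args`
# and a checkpoint helper `ckp` exposing a .dir attribute (both assumptions):
#
#     model = Model(args, ckp)
#     model.eval()
#     with torch.no_grad():
#         sr = model(lr)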