| import argparse | |
| import json | |
| import os | |
| import math | |
| from functools import partial | |
| import cv2 | |
| import numpy as np | |
| import yaml | |
| import torch | |
| from einops import rearrange | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| import datasets | |
| import models | |
| import utils | |
# Compute device for the model and every batch tensor: the first CUDA GPU when
# PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
def batched_predict(model, img, bsize):
    """Run ``model`` on ``img`` without gradient tracking, at most ``bsize`` samples per call.

    Args:
        model: Callable applied to a batch tensor. When chunking happens, its
            output must be a tensor, or a list/tuple of tensors, whose dim 0 is
            the batch dimension (assumption about the model — confirm for new models).
        img: Input tensor; dimension 0 is split into chunks.
        bsize: Maximum number of samples per forward pass. ``None`` or a
            non-positive value disables chunking (a single forward pass).

    Returns:
        The model output for the whole input. When the input was split, the
        per-chunk outputs are concatenated along dim 0; list/tuple outputs are
        concatenated element-wise and returned as the same container type.

    Raises:
        TypeError: if the input was chunked and the model output is neither a
            tensor nor a list/tuple of tensors.
    """
    with torch.no_grad():
        # Single pass when chunking is disabled or unnecessary (original behavior).
        if not bsize or bsize <= 0 or img.shape[0] <= bsize:
            return model(img)
        outputs = [model(chunk) for chunk in torch.split(img, int(bsize), dim=0)]

    first = outputs[0]
    if torch.is_tensor(first):
        return torch.cat(outputs, dim=0)
    if isinstance(first, (list, tuple)):
        # Concatenate the i-th output of every chunk together.
        merged = [torch.cat(parts, dim=0) for parts in zip(*outputs)]
        return tuple(merged) if isinstance(first, tuple) else merged
    raise TypeError(
        f'cannot merge chunked model outputs of type {type(first).__name__}'
    )
def _tensor_to_uint8_hwc(tensor):
    """Convert one (C, H, W) image tensor to a uint8 (H, W, C) numpy array.

    Values are multiplied by 255 and clipped to [0, 255] before the cast, so the
    tensor is presumably expected in [0, 1] — confirm against the data pipeline.
    """
    arr = tensor.cpu().numpy() * 255
    arr = np.clip(arr, a_min=0, a_max=255).astype(np.uint8)
    return rearrange(arr, 'C H W -> H W C')


def _save_png(path, image):
    """Write ``image`` to ``path`` with OpenCV, creating missing parent directories.

    ``cv2.imwrite`` reports failure by returning ``False`` instead of raising, so
    the result is checked and a warning is printed rather than failing silently.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    if not cv2.imwrite(path, image):
        print(f'warning: failed to write image {path}')


def eval_psnr(loader, class_names, model,
              data_norm=None, eval_type=None, save_fig=False,
              scale_ratio=1, save_path=None, verbose=False, crop_border=4,
              cal_metrics=True,):
    """Evaluate ``model`` over ``loader`` and accumulate per-class PSNR / SSIM.

    Each batch is a dict containing at least ``'img'`` (input), ``'gt'``
    (target) and ``'class_name'``; an optional ``'filename'`` entry enables
    figure saving. Tensor entries are moved to the module-level ``device``.
    The input is normalized with ``data_norm['img']``, the model is called as
    ``model(img, gt_spatial_size)`` (the last element is used if it returns a
    list), and the prediction is de-normalized with ``data_norm['gt']``.

    Args:
        loader: Iterable of batch dicts (see above).
        class_names: Class names passed to ``utils.Averager``.
        model: Super-resolution model; put into eval mode here.
        data_norm: ``{'img': {'sub', 'div'}, 'gt': {'sub', 'div'}}`` per-channel
            normalization constants; identity normalization when ``None``.
        eval_type: ``None`` / ``'psnr+ssim'`` for PSNR+SSIM. ``'div2k-<s>'`` and
            ``'benchmark-<s>'`` only work with ``cal_metrics=False`` (see below).
        save_fig: When True and batches carry ``'filename'``, write the input,
            prediction and ground truth of every sample as PNGs under ``save_path``.
        scale_ratio: Only used in the saved prediction file name.
        save_path: Output directory for saved figures; required when ``save_fig``.
        verbose: Show running overall PSNR/SSIM in the progress bar.
        crop_border: Border passed to the metric functions; cast to int when truthy.
        cal_metrics: When False, metrics are not computed and every sample scores 1.

    Returns:
        ``(val_res_psnr.item(), val_res_ssim.item())`` — the averagers' results,
        indexed by class name and ``'all'`` (as the verbose display relies on).

    Raises:
        NotImplementedError: for an unknown ``eval_type``, or for a
            ``div2k``/``benchmark`` eval_type when ``cal_metrics`` is True.
        ValueError: if ``save_fig`` is True but ``save_path`` is None.
    """
    crop_border = int(crop_border) if crop_border else crop_border
    print('crop border: ', crop_border)
    if save_fig and save_path is None:
        raise ValueError('save_path must be provided when save_fig is True')
    model.eval()

    if data_norm is None:
        data_norm = {
            'img': {'sub': [0], 'div': [1]},
            'gt': {'sub': [0], 'div': [1]}
        }
    # Per-channel constants shaped (1, C, 1, 1) to broadcast over NCHW batches.
    t = data_norm['img']
    img_sub = torch.FloatTensor(t['sub']).view(1, -1, 1, 1).to(device)
    img_div = torch.FloatTensor(t['div']).view(1, -1, 1, 1).to(device)
    t = data_norm['gt']
    gt_sub = torch.FloatTensor(t['sub']).view(1, -1, 1, 1).to(device)
    gt_div = torch.FloatTensor(t['div']).view(1, -1, 1, 1).to(device)

    if eval_type is None or eval_type == 'psnr+ssim':
        metric_fn = [utils.calculate_psnr_pt, utils.calculate_ssim_pt]
    elif eval_type.startswith('div2k'):
        scale = int(eval_type.split('-')[1])
        metric_fn = partial(utils.calc_psnr, dataset='div2k', scale=scale)
    elif eval_type.startswith('benchmark'):
        scale = int(eval_type.split('-')[1])
        metric_fn = partial(utils.calc_psnr, dataset='benchmark', scale=scale)
    else:
        raise NotImplementedError
    # The loop below needs a [psnr_fn, ssim_fn] pair; the div2k/benchmark
    # branches build a single callable, which previously crashed mid-evaluation
    # with "'functools.partial' object is not subscriptable".
    if cal_metrics and not isinstance(metric_fn, (list, tuple)):
        raise NotImplementedError(
            f"eval_type '{eval_type}' provides a single PSNR callable; metric "
            "computation requires a [psnr_fn, ssim_fn] pair"
        )

    val_res_psnr = utils.Averager(class_names)
    val_res_ssim = utils.Averager(class_names)
    pbar = tqdm(loader, leave=False, desc='val')
    for batch in pbar:
        for k, v in batch.items():
            if torch.is_tensor(v):
                batch[k] = v.to(device)
        img = (batch['img'] - img_sub) / img_div
        with torch.no_grad():
            pred = model(img, batch['gt'].shape[-2:])
        if isinstance(pred, list):
            pred = pred[-1]
        pred = pred * gt_div + gt_sub

        if cal_metrics:
            res_psnr = metric_fn[0](pred, batch['gt'], crop_border=crop_border)
            res_ssim = metric_fn[1](pred, batch['gt'], crop_border=crop_border)
        else:
            # Placeholder scores so the averagers and file names still work.
            res_psnr = torch.ones(len(pred))
            res_ssim = torch.ones(len(pred))

        file_names = batch.get('filename', None)
        if file_names is not None and save_fig:
            # NOTE(review): cv2.imwrite treats the last axis as BGR; if the loader
            # yields RGB tensors the saved PNGs have R/B swapped — confirm.
            for idx in range(len(batch['img'])):
                ori_img = _tensor_to_uint8_hwc(batch['img'][idx])
                pred_img = _tensor_to_uint8_hwc(pred[idx])
                gt_img = _tensor_to_uint8_hwc(batch['gt'][idx])
                psnr = res_psnr[idx].cpu().numpy()
                ssim = res_ssim[idx].cpu().numpy()
                stem = f'{save_path}/{file_names[idx]}'
                _save_png(f'{stem}_Ori.png', ori_img)
                _save_png(f'{stem}_{scale_ratio}X_{psnr:.2f}_{ssim:.4f}.png', pred_img)
                _save_png(f'{stem}_GT.png', gt_img)

        val_res_psnr.add(batch['class_name'], res_psnr)
        val_res_ssim.add(batch['class_name'], res_ssim)
        if verbose:
            pbar.set_description(
                'val psnr: {:.4f} ssim: {:.4f}'.format(val_res_psnr.item()['all'], val_res_ssim.item()['all']))
    return val_res_psnr.item(), val_res_ssim.item()
def str2bool(value):
    """Parse a command-line boolean flag value.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``'False'``) as True; this accepts the usual true/false spellings instead.
    An empty string maps to False, matching ``bool('')``.

    Raises:
        argparse.ArgumentTypeError: for an unrecognized value.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def main():
    """Evaluate a checkpoint on the UC or AID test split and print PSNR/SSIM."""
    # Data root and split file for each supported --dataset_name.
    root_split_file = {
        'UC': {
            'root_path': '/data/kyanchen/datasets/UC/256',
            'split_file': 'data_split/UC_split.json',
        },
        'AID': {
            'root_path': '/data/kyanchen/datasets/AID',
            'split_file': 'data_split/AID_split.json',
        },
    }

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='configs/test_CNN.yaml')
    parser.add_argument('--model', default='checkpoints/EXP20220610_5/epoch-best.pth')
    parser.add_argument('--scale_ratio', default=4, type=float)
    parser.add_argument('--save_fig', default=False, type=str2bool)
    parser.add_argument('--save_path', default='tmp', type=str)
    parser.add_argument('--cal_metrics', default=True, type=str2bool)
    parser.add_argument('--return_class_metrics', default=False, type=str2bool)
    parser.add_argument('--dataset_name', default='UC', type=str,
                        choices=sorted(root_split_file))
    args = parser.parse_args()

    # Fail fast before building the dataset (previously `assert NameError`, a no-op).
    if not os.path.exists(args.model):
        raise FileNotFoundError(f'checkpoint not found: {args.model}')

    # The config is a local, trusted file.
    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    dataset_args = config['test_dataset']['dataset']['args']
    dataset_args['root_path'] = root_split_file[args.dataset_name]['root_path']
    dataset_args['split_file'] = root_split_file[args.dataset_name]['split_file']
    config['test_dataset']['wrapper']['args']['scale_ratio'] = args.scale_ratio

    spec = config['test_dataset']
    dataset = datasets.make(spec['dataset'])
    dataset = datasets.make(spec['wrapper'], args={'dataset': dataset})
    loader = DataLoader(dataset, batch_size=spec['batch_size'], num_workers=0, pin_memory=True, shuffle=False,
                        drop_last=False)

    # Load onto the CPU first so GPU-saved checkpoints also work on CPU-only
    # hosts; the model is moved to `device` afterwards.
    model_spec = torch.load(args.model, map_location='cpu')['model']
    print(model_spec['args'])
    model = models.make(model_spec, load_sd=True).to(device)

    with open(dataset_args['split_file']) as f:
        file_names = json.load(f)['test']
    # Class name = parent directory of each test file; sorted for a stable order.
    class_names = sorted({os.path.basename(os.path.dirname(x)) for x in file_names})

    crop_border = args.scale_ratio + 5
    # Beyond the dataset's maximum scale a wider border is cropped; the factor 48
    # is not explained in this file.
    max_scale = {'UC': 5, 'AID': 12}
    if args.scale_ratio > max_scale[args.dataset_name]:
        crop_border = int((args.scale_ratio - max_scale[args.dataset_name]) / 2 * 48)

    if args.save_fig:
        os.makedirs(args.save_path, exist_ok=True)
    res = eval_psnr(
        loader, class_names, model,
        data_norm=config.get('data_norm'),
        eval_type=config.get('eval_type'),
        crop_border=crop_border,
        verbose=True,
        save_fig=args.save_fig,
        scale_ratio=args.scale_ratio,
        save_path=args.save_path,
        cal_metrics=args.cal_metrics
    )
    psnr_res, ssim_res = res
    if args.return_class_metrics:
        keys = sorted(psnr_res.keys())
        print('psnr')
        for k in keys:
            print(f'{k}: {psnr_res[k]:0.2f}')
        print('ssim')
        for k in keys:
            print(f'{k}: {ssim_res[k]:0.4f}')
    print(f'psnr: {psnr_res["all"]:0.2f}')
    print(f'ssim: {ssim_res["all"]:0.4f}')


if __name__ == '__main__':
    main()