Spaces:
Runtime error
Runtime error
| import yaml | |
| import torch | |
| import random | |
| import numpy as np | |
| import os | |
| import sys | |
| import matplotlib.pyplot as plt | |
| from einops import repeat | |
| import cv2 | |
| import time | |
| import torch.nn.functional as F | |
# Names exported by ``from <this module> import *``.
__all__ = [
    "decode_mask_to_onehot",
    "encode_onehot_to_mask",
    "Logger",
    "get_coords_grid",
    "get_coords_grid_float",
    "draw_bboxes",
    "Infos",
    "inv_normalize_img",
    "make_numpy_img",
    "get_metrics",
]
class Infos(object):
    """Collect per-batch losses and detection results and report them.

    The object keeps running state for one training/validation run (current
    pattern, epoch, batch, learning rate, throughput) together with two
    accumulators that are filled batch by batch:

    * ``loss_all``   -- ``{loss_name: [value, ...]}``
    * ``result_all`` -- ``{result_name: [item, ...]}``; ``cal_metrics`` reads
      the ``'pred_all'`` and ``'target_all'`` entries.

    NOTE(review): ``COCO`` and ``COCOeval`` are used by ``cal_metrics`` but are
    not imported in the visible part of this file; they presumably come from
    pycocotools -- confirm and import them at module level.
    """

    # Tags of the 12 summary values, in the order used to index ``self.metrics``.
    _METRIC_KEYS = (
        'AP_m_all_100',
        'AP_50_all_100',
        'AP_75_all_100',
        'AP_m_small_100',
        'AP_m_medium_100',
        'AP_m_large_100',
        'AR_m_all_1',
        'AR_m_all_10',
        'AR_m_all_100',
        'AR_m_small_100',
        'AR_m_medium_100',
        'AR_m_large_100',
    )

    def __init__(self, phase, class_names=None):
        """Create an empty tracker.

        Args:
            phase: task identifier; only ``'od'`` is accepted.
            class_names: sequence of class names, indexed by class id.
        """
        assert phase in ['od'], "Error in Infos"
        self.phase = phase
        self.class_names = class_names
        self.register()
        self.pattern = 'train'
        self.epoch_id = 0
        self.max_epoch = 0
        self.batch_id = 0
        self.batch_num = 0
        self.lr = 0
        self.fps_data_load = 0
        self.fps = 0
        self.val_metric = 0
        # Initialised here so print_epoch_state_infos() does not raise
        # AttributeError when set_epoch_training_time() was never called.
        self.epoch_training_time = 0

    def set_epoch_training_time(self, data):
        self.epoch_training_time = data

    def set_pattern(self, data):
        self.pattern = data

    def set_epoch_id(self, data):
        self.epoch_id = data

    def set_max_epoch(self, data):
        self.max_epoch = data

    def set_batch_id(self, data):
        self.batch_id = data

    def set_batch_num(self, data):
        self.batch_num = data

    def set_lr(self, data):
        self.lr = data

    def set_fps_data_load(self, data):
        self.fps_data_load = data

    def set_fps(self, data):
        self.fps = data

    def clear_cache(self):
        """Drop all accumulated losses and results."""
        self.register()

    def get_val_metric(self):
        """Return the value stored by the last ``cal_metrics`` call."""
        return self.val_metric

    def _build_coco_api(self, boxes_per_image, box_start, with_score):
        """Build an indexed ``COCO`` object from per-image box arrays.

        Args:
            boxes_per_image: list with one 2-D array per image; column 0 is
                the class id and columns ``box_start:box_start + 4`` are
                ``l, t, r, b``.
            box_start: index of the first coordinate column.
            with_score: when true, column 1 is stored as ``'score'``.
        """
        api = COCO()
        api.dataset['images'] = []
        api.dataset['annotations'] = []
        ann_id = 0
        for image_id, boxes in enumerate(boxes_per_image):
            # Register every image exactly once, including images without any
            # box, so that they still take part in the evaluation.
            api.dataset['images'].append({'id': image_id})
            for j in range(boxes.shape[0]):
                left_top = boxes[j, box_start:box_start + 2]
                width_height = boxes[j, box_start + 2:box_start + 4] - left_top
                annotation = {
                    'image_id': image_id,
                    "category_id": int(boxes[j, 0]),
                    # l, t, r, b -> x, y, w, h
                    "bbox": np.hstack([left_top, width_height]),
                    "area": np.prod(width_height),
                    "id": ann_id,
                    "iscrowd": 0,
                }
                if with_score:
                    annotation['score'] = boxes[j, 1]
                api.dataset['annotations'].append(annotation)
                ann_id += 1
        api.dataset['categories'] = [
            {"id": i, "supercategory": c, "name": c}
            for i, c in enumerate(self.class_names)
        ]
        api.createIndex()
        return api

    def cal_metrics(self):
        """Evaluate accumulated predictions against accumulated targets.

        Stores the summary values in ``self.metrics`` and the second one in
        ``self.val_metric``.
        """
        if self.phase == 'od':
            # targets: [class, l, t, r, b]; predictions: [class, score, l, t, r, b]
            coco_api_gt = self._build_coco_api(
                self.result_all['target_all'], box_start=1, with_score=False)
            coco_api_pred = self._build_coco_api(
                self.result_all['pred_all'], box_start=2, with_score=True)

            coco_eval = COCOeval(coco_api_gt, coco_api_pred, "bbox")
            coco_eval.params.imgIds = coco_api_gt.getImgIds()
            coco_eval.evaluate()
            coco_eval.accumulate()
            stats = coco_eval.summarize()
            if stats is None:
                # Stock pycocotools returns None from summarize() and exposes
                # the values on ``.stats`` instead.
                stats = coco_eval.stats
            self.metrics = stats
            self.val_metric = self.metrics[1]

    def print_epoch_state_infos(self, logger):
        """Log the epoch summary, run ``cal_metrics`` and log its result table.

        Args:
            logger: object with a ``write(str)`` method.
        """
        infos_str = 'Pattern: %s Epoch [%d,%d], time: %d loss: %.4f' % \
                    (self.pattern, self.epoch_id, self.max_epoch, self.epoch_training_time,
                     np.mean(self.loss_all['loss']))
        logger.write(infos_str + '\n')

        time_start = time.time()
        self.cal_metrics()
        time_end = time.time()
        logger.write('Pattern: %s Epoch Eval_time: %d\n' % (self.pattern, (time_end - time_start)))

        if self.phase == 'od':
            titleStr = 6 * ['Average Precision'] + 6 * ['Average Recall']
            typeStr = 6 * ['(AP)'] + 6 * ['(AR)']
            iouStr = 12 * ['0.50:0.95']
            iouStr[1] = '0.50'
            iouStr[2] = '0.75'
            areaRng = 3 * ['all'] + ['small', 'medium', 'large'] + 3 * ['all'] + ['small', 'medium', 'large']
            maxDets = 6 * [100] + [1, 10, 100] + 3 * [100]
            row_fmt = '{:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}\n'
            for i in range(12):
                logger.write(row_fmt.format(
                    titleStr[i], typeStr[i], iouStr[i], areaRng[i], maxDets[i], self.metrics[i]))

    def save_epoch_state_infos(self, writer):
        """Write the 12 epoch metrics as scalars, using the epoch id as step.

        Args:
            writer: object with an ``add_scalar(tag, value, step)`` method.
        """
        step = self.epoch_id
        for i, key in enumerate(self._METRIC_KEYS):
            writer.add_scalar('%s/epoch/%s' % (self.pattern, key), self.metrics[i], step)

    def print_batch_state_infos(self, logger):
        """Log the current batch state and the most recent ``'loss'`` value."""
        infos_str = 'Pattern: %s [%d,%d][%d,%d], lr: %5f, fps_data_load: %.2f, fps: %.2f' % \
                    (self.pattern, self.epoch_id, self.max_epoch, self.batch_id,
                     self.batch_num, self.lr, self.fps_data_load, self.fps)
        # add loss
        infos_str += ', loss: %.4f' % self.loss_all['loss'][-1]
        logger.write(infos_str + '\n')

    def save_batch_state_infos(self, writer):
        """Write the learning rate and the latest value of every loss as scalars."""
        step = self.epoch_id * self.batch_num + self.batch_id
        writer.add_scalar('%s/lr' % self.pattern, self.lr, step)
        for key, value in self.loss_all.items():
            writer.add_scalar('%s/%s' % (self.pattern, key), value[-1], step)

    def save_results(self, img_batch, prior_mean, prior_std, vis_dir, *args, **kwargs):
        """Save visualisations for a random ~30% of the current batch.

        For each sampled image a PNG is written that stacks, top to bottom,
        the de-normalised image, the image with predicted boxes and the image
        with target boxes.

        Args:
            img_batch: tensor whose first dimension is the batch.
            prior_mean: per-channel mean forwarded to ``inv_normalize_img``.
            prior_std: per-channel std forwarded to ``inv_normalize_img``.
            vis_dir: root output directory; files go to ``vis_dir/<pattern>/``.
        """
        batch_size = img_batch.size(0)
        k = min(batch_size, max(1, int(0.3 * batch_size)))
        ids = np.random.choice(range(batch_size), k, replace=False)

        out_dir = os.path.join(vis_dir, self.pattern)
        os.makedirs(out_dir, exist_ok=True)

        for img_id in ids:
            img = img_batch[img_id].detach().cpu()
            # Negative index: the current batch is the last ``batch_size``
            # entries appended to the accumulators.
            pred = self.result_all['pred_all'][img_id - batch_size]
            target = self.result_all['target_all'][img_id - batch_size]

            img = make_numpy_img(inv_normalize_img(img, prior_mean, prior_std))
            # Keyword arguments: draw_bboxes takes ``color`` before
            # ``class_names``, so positional passing would swap the two.
            pred_draw = draw_bboxes(img, pred, color=(255, 0, 0), class_names=self.class_names)
            target_draw = draw_bboxes(img, target, color=(0, 255, 0), class_names=self.class_names)

            vis = np.concatenate([img / 255., pred_draw / 255., target_draw / 255.], axis=0)
            vis = np.clip(vis, a_min=0, a_max=1)
            file_name = os.path.join(out_dir, f'{self.epoch_id}_{self.batch_id}_{img_id}.png')
            plt.imsave(file_name, vis)

    def register(self):
        """Reset both accumulators to the empty, unregistered state."""
        self.is_registered_result = False
        self.result_all = {}

        self.is_registered_loss = False
        self.loss_all = {}

    def register_result(self, data: dict):
        """Create an empty list for every key of ``data`` in ``result_all``."""
        for key in data:
            self.result_all[key] = []
        self.is_registered_result = True

    def append_result(self, data: dict):
        """Extend each result list with the items of the matching ``data`` value."""
        if not self.is_registered_result:
            self.register_result(data)
        for key, value in data.items():
            self.result_all[key] += value

    def register_loss(self, data: dict):
        """Create an empty list for every key of ``data`` in ``loss_all``."""
        for key in data:
            self.loss_all[key] = []
        self.is_registered_loss = True

    def append_loss(self, data: dict):
        """Append each loss tensor of ``data`` as a detached CPU numpy value."""
        if not self.is_registered_loss:
            self.register_loss(data)
        for key, value in data.items():
            self.loss_all[key].append(value.detach().cpu().numpy())
| # draw bboxes on image, bboxes with classID | |
def draw_bboxes(img, bboxes, color=(255, 0, 0), class_names=None, is_show_score=True):
    """Draw labelled bounding boxes on a copy of ``img``.

    Args:
        img: image as a numpy array or torch tensor. It is cast to uint8 and
            copied, so the caller's array is not modified.
        bboxes: rows of either ``[class_idx, l, t, r, b]`` (5 values) or
            ``[class_idx, score, l, t, r, b]`` (6 values); numpy array or
            torch tensor.
        color: colour of the box outline, forwarded to ``cv2.rectangle``.
        class_names: optional sequence indexed by ``class_idx``. When it is
            missing or empty the numeric class index is used as label text.
        is_show_score: when true, a filled label patch with the class name
            (plus the score for 6-value rows) is drawn at the top-left corner.

    Returns:
        The annotated image as a uint8 numpy array.
    """
    assert img is not None, "In draw_bboxes, img is None"
    if torch.is_tensor(img):
        img = img.detach().cpu().numpy()
    img = img.astype(np.uint8).copy()

    if torch.is_tensor(bboxes):
        bboxes = bboxes.detach().cpu().numpy()

    for bbox in bboxes:
        bbox = np.asarray(bbox)
        class_idx = int(bbox[0])
        # Fall back to the numeric id so the label is always defined.
        class_name = class_names[class_idx] if class_names else str(class_idx)

        has_score = len(bbox) == 6
        coords = bbox[2:] if has_score else bbox[1:]
        # Plain Python ints: ``np.int`` no longer exists in recent NumPy, and
        # cv2 point arguments are safest as built-in ints.
        left, top, right, bottom = (int(v) for v in coords[:4])

        if is_show_score:
            cv2.rectangle(img, pt1=(left - 2, top - 15), pt2=(left + 15, top + 1),
                          color=(0, 0, 255), thickness=-1)
            if has_score:
                label = '%s:%.2f' % (class_name, bbox[1])
            else:
                label = '%s' % class_name
            cv2.putText(img, text=label, org=(left - 1, top - 7),
                        fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                        fontScale=0.2, color=(255, 255, 255), thickness=1)
        cv2.rectangle(img, pt1=(left, top), pt2=(right, bottom), color=color, thickness=2)
    return img
def get_coords_grid(h_end, w_end, h_start=0, w_start=0, h_steps=None, w_steps=None, is_normalize=False):
    """Return a ``(2, H, W)`` tensor of x (index 0) and y (index 1) coordinates.

    Args:
        h_end, w_end: last coordinate along height / width (inclusive).
        h_start, w_start: first coordinate along height / width.
        h_steps, w_steps: number of samples per axis; defaults to
            ``int(end - start) + 1``.
        is_normalize: divide y by ``h_end`` and x by ``w_end``.
    """
    n_rows = int(h_end - h_start) + 1 if h_steps is None else h_steps
    n_cols = int(w_end - w_start) + 1 if w_steps is None else w_steps

    ys = torch.linspace(h_start, h_end, n_rows)
    xs = torch.linspace(w_start, w_end, n_cols)
    if is_normalize:
        ys, xs = ys / h_end, xs / w_end

    grid_y, grid_x = torch.meshgrid(ys, xs)
    # x first, then y
    return torch.stack((grid_x, grid_y), dim=0)
def get_coords_grid_float(ht, wd, scale, is_normalize=False):
    """Return a ``(2, ht, wd)`` grid of interior sample positions in ``(0, scale)``.

    Each axis is sampled with ``n + 2`` evenly spaced points on ``[0, scale]``
    and the two end points are dropped. Index 0 holds x, index 1 holds y.
    With ``is_normalize`` the positions are divided by ``scale``.
    """
    ys = torch.linspace(0, scale, ht + 2)
    xs = torch.linspace(0, scale, wd + 2)
    if is_normalize:
        ys, xs = ys / scale, xs / scale

    # drop the first and last sample of each axis
    grid_y, grid_x = torch.meshgrid(ys[1:-1], xs[1:-1])
    return torch.stack((grid_x, grid_y), dim=0)
def get_coords_vector_float(len, scale, is_normalize=False):
    """Return a ``(2, len, 1)`` tensor of interior positions along one axis.

    ``len + 2`` evenly spaced points are taken on ``[0, scale]`` and the two
    end points are dropped. Index 0 of the result is all zeros, index 1 holds
    the positions (divided by ``scale`` when ``is_normalize`` is set).
    """
    samples = torch.linspace(0, scale, len + 2)
    if is_normalize:
        samples = samples / scale

    positions, zeros = torch.meshgrid(samples[1:-1], torch.tensor([0.]))
    return torch.stack((zeros, positions), dim=0)
class Logger(object):
    """Append messages to a log file and optionally echo them to stdout.

    The file is opened in append mode when the object is created. Call
    ``close()`` (or use the object as a context manager) to release the file
    handle when logging is finished.
    """

    def __init__(self, filename="Default.log", is_terminal_show=True):
        """Open ``filename`` for appending.

        Args:
            filename: path of the log file.
            is_terminal_show: also write every message to ``sys.stdout``.
        """
        self.is_terminal_show = is_terminal_show
        if self.is_terminal_show:
            self.terminal = sys.stdout
        self.log = open(filename, "a")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Never suppress an exception raised inside the ``with`` body.
        return False

    def write(self, message):
        """Write ``message`` to the enabled sinks and flush immediately."""
        if self.is_terminal_show:
            self.terminal.write(message)
        self.log.write(message)
        self.flush()

    def flush(self):
        """Flush the terminal (when enabled) and the log file."""
        if self.is_terminal_show:
            self.terminal.flush()
        self.log.flush()

    def close(self):
        """Close the log file; calling it more than once is harmless."""
        if not self.log.closed:
            self.log.close()
class ParamsParser:
    """Expose the top-level keys of a YAML file as attributes.

    Attributes that are not keys of the file evaluate to ``None``.
    """

    def __init__(self, project_file):
        """Load ``project_file`` with ``yaml.safe_load``.

        An empty file is treated as an empty mapping.
        """
        with open(project_file) as stream:
            self.params = yaml.safe_load(stream) or {}

    def __getattr__(self, item):
        # __getattr__ only runs when normal lookup fails. If ``params`` itself
        # is missing (e.g. an instance created without __init__), looking it
        # up below would recurse forever, so fail cleanly instead.
        if item == 'params':
            raise AttributeError(item)
        return self.params.get(item, None)
def get_all_dict(dict_infos: dict) -> dict:
    """Flatten a nested dict into a single-level dict of its leaf entries.

    Keys whose value is itself a dict are dropped and replaced by that dict's
    (recursively flattened) entries. When the same leaf key occurs more than
    once, the entry encountered later wins.
    """
    return_dict = {}
    for key, value in dict_infos.items():
        if isinstance(value, dict):
            # update() accepts any hashable key; ``dict(..., **nested)`` would
            # raise TypeError for non-string keys.
            return_dict.update(get_all_dict(value))
        else:
            return_dict[key] = value
    return return_dict
def make_numpy_img(tensor_data):
    """Convert an image tensor to an ``H x W x 3`` numpy array.

    Accepted layouts: ``H x W`` and ``1 x H x W`` (the single channel is
    repeated three times), ``3 x H x W`` (moved to channel-last) and
    ``H x W x 3`` (kept as is). Anything else raises ``Exception``.
    """
    if len(tensor_data.shape) == 2:
        gray = tensor_data.unsqueeze(2)
        hwc = torch.cat([gray] * 3, dim=2)
    else:
        leading = tensor_data.size(0)
        if leading == 1:
            gray = tensor_data.permute((1, 2, 0))
            hwc = torch.cat([gray] * 3, dim=2)
        elif leading == 3:
            hwc = tensor_data.permute((1, 2, 0))
        elif tensor_data.size(2) == 3:
            hwc = tensor_data
        else:
            raise Exception('tensor_data apply to make_numpy_img error')
    return hwc.detach().cpu().numpy()
def print_infos(logger, writer, infos: dict):
    """Log one status line and write the matching scalars.

    The first eight values of ``infos`` (in insertion order) fill the fixed
    part of the line; every further entry is appended as ``, name: value`` and
    also written as a scalar. The step is
    ``epoch_id * batch_num + batch_id``.

    Args:
        logger: object with a ``write(str)`` method.
        writer: object with an ``add_scalar(tag, value, step)`` method.
        infos: ordered mapping; must contain the keys ``pattern``, ``lr``,
            ``epoch_id``, ``batch_num`` and ``batch_id``.
    """
    names = list(infos.keys())
    vals = list(infos.values())
    extras = list(zip(names[8:], vals[8:]))

    line = 'Pattern: %s [%d,%d][%d,%d], lr: %5f, fps_data_load: %.2f, fps: %.2f' % tuple(vals[:8])
    if len(vals) > 8:
        line += ''.join(f', {name}: {val:.4f}' for name, val in extras)
    logger.write(line + '\n')

    step = infos['epoch_id'] * infos['batch_num'] + infos['batch_id']
    writer.add_scalar('%s/lr' % infos['pattern'], infos['lr'], step)
    for name, val in extras:
        writer.add_scalar('%s/%s' % (infos['pattern'], name), val, step)
def invert_affine(origin_imgs, preds, pattern='train', input_size=512):
    """Rescale predicted boxes from network-input size back to image size.

    Only acts when ``pattern == 'val'``; otherwise ``preds`` is returned
    untouched. The ``'rois'`` arrays are modified in place.

    Args:
        origin_imgs: sequence of original images with shape ``(H, W, C)``.
        preds: sequence of dicts, each holding an ``'rois'`` array whose
            columns 0 and 2 are x coordinates and columns 1 and 3 are y
            coordinates.
        pattern: run mode.
        input_size: side length the boxes are currently expressed in
            (previously hard-coded to 512).

    Returns:
        ``preds`` (the same object that was passed in).
    """
    if pattern == 'val':
        for i, pred in enumerate(preds):
            rois = pred['rois']
            if len(rois) == 0:
                continue
            old_h, old_w, _ = origin_imgs[i].shape
            rois[:, [0, 2]] = rois[:, [0, 2]] / (input_size / old_w)
            rois[:, [1, 3]] = rois[:, [1, 3]] / (input_size / old_h)
    return preds
def save_output_infos(input, output, vis_dir, pattern, epoch_id, batch_id):
    """Save flow visualisations for a random subset of the batch.

    For every sampled item one JPEG is written that stacks, top to bottom:
    image 1, image 2, image 2 warped by the first predicted flow, and the
    min-max normalised flow magnitude.

    Args:
        input: dict with ``'ori_img1'`` and ``'ori_img2'`` batches.
        output: 3-tuple ``(flows, pf1s, pf2s)``; only ``flows[0]`` is used.
        vis_dir: root output directory; files go to ``vis_dir/<pattern>/``.
        pattern: sub-directory name (run mode).
        epoch_id, batch_id: used in the file name.

    NOTE(review): ``flow_to_warp`` and ``resample`` are not defined or
    imported in the visible part of this file -- confirm where they come from.
    """
    flows, pf1s, pf2s = output
    flow_batch = flows[0]
    batch_size = len(flow_batch)
    k = np.clip(int(0.2 * batch_size), a_min=2, a_max=batch_size)
    ids = np.random.choice(range(batch_size), k, replace=False)

    out_dir = os.path.join(vis_dir, pattern)
    os.makedirs(out_dir, exist_ok=True)

    for img_id in ids:
        img1 = input['ori_img1'][img_id:img_id + 1].to(flow_batch.device)
        img2 = input['ori_img2'][img_id:img_id + 1].to(flow_batch.device)

        flow = flow_batch[img_id:img_id + 1]
        warps = flow_to_warp(flow)
        warped_img2 = resample(img2, warps)

        ori_img1 = make_numpy_img(img1[0]) / 255.
        ori_img2 = make_numpy_img(img2[0]) / 255.
        warped_img2 = make_numpy_img(warped_img2[0]) / 255.

        flow_amplitude = torch.sqrt(flow[0, 0:1, ...] ** 2 + flow[0, 1:2, ...] ** 2)
        flow_amplitude = make_numpy_img(flow_amplitude)
        # min-max normalise; the epsilon guards against a constant flow field
        flow_amplitude = (flow_amplitude - np.min(flow_amplitude)) / \
                         (np.max(flow_amplitude) - np.min(flow_amplitude) + 1e-10)

        vis = np.concatenate([ori_img1, ori_img2, warped_img2, flow_amplitude], axis=0)
        vis = np.clip(vis, a_min=0, a_max=1)
        # img_id is part of the name so that the sampled items of one batch do
        # not overwrite each other.
        file_name = os.path.join(out_dir, f'{epoch_id}_{batch_id}_{img_id}.jpg')
        plt.imsave(file_name, vis)
def inv_normalize_img(img, prior_mean=(0, 0, 0), prior_std=(1, 1, 1)):
    """Undo mean/std normalisation and map the result to ``[0, 255]``.

    Computes ``clamp((img * std + mean) * 255, 0, 255)`` with ``mean`` and
    ``std`` applied per channel.

    Args:
        img: tensor of shape ``(C, H, W)``; a batched ``(B, C, H, W)`` tensor
            is also accepted because the statistics broadcast over leading
            dimensions.
        prior_mean: per-channel mean, length ``C``.
        prior_std: per-channel standard deviation, length ``C``.

    Returns:
        Float tensor with the same shape as ``img``.
    """
    # view(-1, 1, 1) gives shape (C, 1, 1), which broadcasts against both
    # (C, H, W) and (B, C, H, W).
    mean = torch.tensor(prior_mean, dtype=torch.float).to(img.device).view(-1, 1, 1)
    std = torch.tensor(prior_std, dtype=torch.float).to(img.device).view(-1, 1, 1)
    img = img * std + mean
    img = img * 255.
    img = torch.clamp(img, min=0, max=255)
    return img
def save_seg_output_infos(input, output, vis_dir, pattern, epoch_id, batch_id, prior_mean, prior_std):
    """Save segmentation visualisations for a random subset of the batch.

    For every sampled item one JPEG is written that stacks, top to bottom:
    the de-normalised image, the predicted label map and the target label map.

    Args:
        input: dict with ``'img'`` and ``'label'`` batches; each label is
            passed through ``encode_onehot_to_mask``.
        output: network output; the predicted label is ``argmax`` over dim 1.
        vis_dir: root output directory; files go to ``vis_dir/<pattern>/``.
        pattern: sub-directory name (run mode).
        epoch_id, batch_id: used in the file name.
        prior_mean, prior_std: forwarded to ``inv_normalize_img``.
    """
    pred_label = torch.argmax(output, 1)
    batch_size = len(pred_label)
    # The upper bound is the batch size (not the height of one label map).
    k = np.clip(int(0.2 * batch_size), a_min=1, a_max=batch_size)
    ids = np.random.choice(range(batch_size), k, replace=False)

    out_dir = os.path.join(vis_dir, pattern)
    os.makedirs(out_dir, exist_ok=True)

    for img_id in ids:
        img = input['img'][img_id].to(pred_label.device)
        target = input['label'][img_id].to(pred_label.device)

        img = make_numpy_img(inv_normalize_img(img, prior_mean, prior_std)) / 255.
        target = make_numpy_img(encode_onehot_to_mask(target))
        pred = make_numpy_img(pred_label[img_id])

        vis = np.concatenate([img, pred, target], axis=0)
        vis = np.clip(vis, a_min=0, a_max=1)
        # img_id is part of the name so that the sampled items of one batch do
        # not overwrite each other.
        file_name = os.path.join(out_dir, f'{epoch_id}_{batch_id}_{img_id}.jpg')
        plt.imsave(file_name, vis)
def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on every parameter of the given network(s).

    Parameters:
        nets (network or list) -- a single network or a list of networks;
                                  ``None`` entries are skipped
        requires_grad (bool)   -- value assigned to each parameter's
                                  ``requires_grad`` flag
    """
    networks = nets if isinstance(nets, list) else [nets]
    for network in networks:
        if network is None:
            continue
        for parameter in network.parameters():
            parameter.requires_grad = requires_grad
def boolean_string(s):
    """Parse the exact strings ``'True'`` / ``'False'`` into a bool.

    Raises:
        ValueError: for any other value.
    """
    parsed = {'True': True, 'False': False}
    if s in parsed:
        return parsed[s]
    raise ValueError('Not a valid boolean string')
def cpt_pxl_cls_acc(pred_idx, target):
    """Return the fraction of elements where ``pred_idx`` equals ``target``.

    Both tensors are flattened and compared after conversion to int.
    """
    flat_pred = pred_idx.reshape(-1).int()
    flat_target = target.reshape(-1).int()
    hits = flat_pred == flat_target
    return hits.float().mean()
def cpt_batch_psnr(img, img_gt, PIXEL_MAX):
    """Return the mean PSNR over a batch of 4-D image tensors.

    The mean squared error is computed per item over dims 1-3, converted to
    ``20 * log10(PIXEL_MAX / sqrt(mse))`` and averaged over the batch.
    """
    squared_error = (img - img_gt) ** 2
    per_item_mse = squared_error.mean(dim=[1, 2, 3])
    per_item_psnr = 20 * torch.log10(PIXEL_MAX / torch.sqrt(per_item_mse))
    return per_item_psnr.mean()
def cpt_psnr(img, img_gt, PIXEL_MAX):
    """Return the PSNR of two numpy images: ``20 * log10(PIXEL_MAX / rmse)``."""
    rmse = np.sqrt(np.mean((img - img_gt) ** 2))
    return 20 * np.log10(PIXEL_MAX / rmse)
def cpt_rgb_ssim(img, img_gt):
    """Return the mean of ``sk_cpt_ssim`` over the first three channels.

    Both inputs are passed through ``clip_01`` first and are indexed as
    ``[:, :, channel]``.
    """
    img = clip_01(img)
    img_gt = clip_01(img_gt)
    total = 0
    for channel in range(3):
        total = total + sk_cpt_ssim(img[:, :, channel], img_gt[:, :, channel])
    return total / 3.0
def cpt_ssim(img, img_gt):
    """Return ``sk_cpt_ssim`` of the two images after ``clip_01`` on each."""
    clipped = clip_01(img)
    clipped_gt = clip_01(img_gt)
    score = sk_cpt_ssim(clipped, clipped_gt)
    return score
def decode_mask_to_onehot(mask, n_class):
    """Expand an integer label mask into a one-hot tensor.

    Args:
        mask: tensor of shape ``B x W x H`` or ``W x H`` holding class ids.
        n_class: number of classes ``n``.

    Returns:
        Float tensor of shape ``B x n x W x H`` for 3-D input or
        ``n x W x H`` for 2-D input, on the same device as ``mask``.
    """
    assert len(mask.shape) in [2, 3], "decode_mask_to_onehot error!"
    # Remember the input rank *before* adding the batch axis; checking it
    # afterwards would never match and 2-D input would keep a stray batch dim.
    is_single = len(mask.shape) == 2
    if is_single:
        mask = mask.unsqueeze(0)
    onehot = torch.zeros((mask.size(0), n_class, mask.size(1), mask.size(2))).to(mask.device)
    for i in range(n_class):
        onehot[:, i, ...] = mask == i
    if is_single:
        onehot = onehot.squeeze(0)
    return onehot
def encode_onehot_to_mask(onehot):
    """Collapse a one-hot (or score) tensor to a class-index mask.

    Args:
        onehot: tensor of shape ``B x n x W x H`` or ``n x W x H``.

    Returns:
        Tensor of shape ``B x W x H`` or ``W x H`` holding the argmax over
        the class axis.
    """
    rank = len(onehot.shape)
    assert rank in [3, 4], "encode_onehot_to_mask error!"
    # class axis is 0 for rank 3 and 1 for rank 4
    class_axis = rank - 3
    return torch.argmax(onehot, dim=class_axis)
def decode(pred, target=None, *args, **kwargs):
    """Turn raw detector outputs into per-image fused boxes.

    Args:
        pred: sequence of network outputs, indexed as
            big_cls_1(0), big_reg_1, small_cls_1(2), small_reg_1,
            big_cls_2(4), big_reg_2, small_cls_2(6), small_reg_2.
            Only indices 1, 3, 4, 5, 6 and 7 are read here.
        target: optional list of tensors (one ``[n, 5]`` tensor per image).
        **kwargs: must contain ``phase`` and ``img_size``; for phase ``'od'``
            also ``prior_box_wh``, ``conf_thres``, ``iou_thres`` and
            ``conf_type``.

    Returns:
        ``{"pred_all": [...], "target_all": [...] or None}`` where each entry
        of ``pred_all`` is an array with columns
        ``label, score, box coordinates``.

    NOTE(review): ``weighted_boxes_fusion`` is not imported in the visible
    part of this file. The original indentation was lost; as reconstructed
    here only phase ``'od'`` assigns ``pred_all`` -- confirm against the
    upstream source.
    """
    phase = kwargs['phase']
    img_size = kwargs['img_size']
    if phase == 'od':
        prior_box_wh = kwargs['prior_box_wh']
        conf_thres = kwargs['conf_thres']
        iou_thres = kwargs['iou_thres']
        conf_type = kwargs['conf_type']

        # channel 1 of the softmax over dim 1 is used as confidence
        conf_32 = F.softmax(pred[4], dim=1)[:, 1, ...]  # B H W
        conf_64 = F.softmax(pred[6], dim=1)[:, 1, ...]  # B H W
        keep_32 = conf_32 > conf_thres  # B H W
        keep_64 = conf_64 > conf_thres  # B H W

        # regressions of both stages are summed, scaled by the prior box
        # size and shifted by the cell coordinates (32x32 grid, step 8)
        boxes_32 = pred[1] + pred[5]  # B 4 H W
        boxes_32[:, 0::2, ...] *= prior_box_wh[0]
        boxes_32[:, 1::2, ...] *= prior_box_wh[1]
        grid_32 = get_coords_grid(31, 31, 0, 0)
        grid_32 *= 8
        grid_32 = torch.cat([grid_32, grid_32], dim=0)
        boxes_32 += grid_32.to(boxes_32.device)

        # same for the second head (64x64 grid, step 4)
        boxes_64 = pred[3] + pred[7]  # B 4 H W
        boxes_64[:, 0::2, ...] *= prior_box_wh[0]
        boxes_64[:, 1::2, ...] *= prior_box_wh[1]
        grid_64 = get_coords_grid(63, 63, 0, 0)
        grid_64 *= 4
        grid_64 = torch.cat([grid_64, grid_64], dim=0)
        boxes_64 += grid_64.to(boxes_32.device)

        pred_all = []
        for b in range(boxes_32.size(0)):
            picked_scores = torch.cat(
                (conf_32[b][keep_32[b]], conf_64[b][keep_64[b]]), dim=0)  # N + M
            picked_boxes = torch.cat(
                (boxes_32[b].permute((1, 2, 0))[keep_32[b]],
                 boxes_64[b].permute((1, 2, 0))[keep_64[b]]), dim=0)  # (N + M) x 4

            score_list = picked_scores.detach().cpu().numpy()
            boxes_list = picked_boxes.detach().cpu().numpy()
            # scale coordinates by the image size before fusion
            boxes_list[:, 0::2] /= img_size[0]
            boxes_list[:, 1::2] /= img_size[1]
            label_list = np.ones_like(score_list)

            # the number of candidate objects is capped at a preset 150
            max_candidates = 150
            boxes_list = boxes_list[:max_candidates, :]
            score_list = score_list[:max_candidates]
            label_list = label_list[:max_candidates]

            boxes, scores, labels = weighted_boxes_fusion(
                [boxes_list], [score_list], [label_list], weights=None,
                iou_thr=iou_thres, conf_type=conf_type)
            # back to pixel coordinates
            boxes[:, 0::2] *= img_size[0]
            boxes[:, 1::2] *= img_size[1]
            pred_all.append(
                np.concatenate((labels.reshape(-1, 1), scores.reshape(-1, 1), boxes), axis=1))

    target_all = None if target is None else [x.cpu().numpy() for x in target]
    return {"pred_all": pred_all, "target_all": target_all}
def get_metrics(phase, pred, target):
    """Compute per-class IoU, overall accuracy and F1 score.

    Args:
        phase: task identifier; the metrics are computed for ``'seg'``.
        pred: logits, tensor, nBatch*nClass*W*H
        target: labels, tensor, nBatch*nClass*W*H

    Returns:
        Tuple ``(IoU, OA, F1_score)``, each reduced over dims 0, 2 and 3.

    NOTE(review): the original indentation was lost; as reconstructed here the
    three results are only assigned when ``phase == 'seg'`` -- confirm against
    the upstream source.
    """
    if phase == 'seg':
        eps = 1e-15
        reduce_dims = (0, 2, 3)

        hard_pred = torch.argmax(pred.detach(), dim=1)
        pred = decode_mask_to_onehot(hard_pred, target.size(1))

        gt_is_pos = target == 1
        pred_is_pos = pred == 1
        # positives in the ground truth / in the prediction / in both
        gt_pos_sum = torch.sum(gt_is_pos, dim=reduce_dims)
        pred_pos_sum = torch.sum(pred_is_pos, dim=reduce_dims)
        true_pos_sum = torch.sum(gt_is_pos * pred_is_pos, dim=reduce_dims)

        precision = true_pos_sum / (pred_pos_sum + eps)
        recall = true_pos_sum / (gt_pos_sum + eps)

        union = pred_pos_sum + gt_pos_sum - true_pos_sum
        IoU = true_pos_sum / (union + eps)

        # mismatches = false positives + false negatives
        mismatches = pred_pos_sum + gt_pos_sum - 2 * true_pos_sum
        OA = 1 - mismatches / torch.sum(target >= 0, dim=reduce_dims)

        F1_score = 2 * precision * recall / (precision + recall + eps)
    return IoU, OA, F1_score