Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import numpy as np | |
| import torch | |
| from mmdet.registry import TASK_UTILS | |
| from .random_sampler import RandomSampler | |
@TASK_UTILS.register_module()
class IoUBalancedNegSampler(RandomSampler):
    """IoU Balanced Sampling.

    arXiv: https://arxiv.org/pdf/1904.02701.pdf (CVPR 2019)

    Sampling proposals according to their IoU. `floor_fraction` of needed RoIs
    are sampled randomly from proposals whose IoU is lower than `floor_thr`.
    The others are sampled from proposals whose IoU is higher than
    `floor_thr`. These proposals are sampled evenly from `num_bins` bins that
    split the IoU range into equal-width intervals.

    Args:
        num (int): number of proposals.
        pos_fraction (float): fraction of positive proposals.
        floor_thr (float): threshold (minimum) IoU for IoU balanced sampling,
            set to -1 if all using IoU balanced sampling.
        floor_fraction (float): sampling fraction of proposals under floor_thr.
        num_bins (int): number of bins in IoU balanced sampling.
    """

    def __init__(self,
                 num,
                 pos_fraction,
                 floor_thr=-1,
                 floor_fraction=0,
                 num_bins=3,
                 **kwargs):
        super().__init__(num, pos_fraction, **kwargs)
        assert floor_thr >= 0 or floor_thr == -1
        assert 0 <= floor_fraction <= 1
        assert num_bins >= 1

        self.floor_thr = floor_thr
        self.floor_fraction = floor_fraction
        self.num_bins = num_bins

    def sample_via_interval(self,
                            max_overlaps,
                            full_set,
                            num_expected,
                            floor_thr=None):
        """Sample according to the iou interval.

        The IoU range ``[floor_thr, max_overlaps.max()]`` is split into
        ``self.num_bins`` equal-width bins. Up to
        ``num_expected // self.num_bins`` indices are drawn at random from the
        members of ``full_set`` falling into each bin. If that yields fewer
        than ``num_expected`` indices, the shortfall is drawn at random from
        the members of ``full_set`` that were not picked yet.

        Args:
            max_overlaps (np.ndarray): IoU between bounding boxes and ground
                truth boxes.
            full_set (set(int)): A full set of indices of boxes.
            num_expected (int): Number of expected samples.
            floor_thr (float, optional): Lower IoU bound of the first bin.
                Defaults to ``self.floor_thr``.

        Returns:
            np.ndarray: Indices of samples.
        """
        if floor_thr is None:
            floor_thr = self.floor_thr
        max_iou = max_overlaps.max()
        iou_interval = (max_iou - floor_thr) / self.num_bins
        per_num_expected = num_expected // self.num_bins

        sampled_inds = []
        for i in range(self.num_bins):
            start_iou = floor_thr + i * iou_interval
            if i < self.num_bins - 1:
                end_iou = floor_thr + (i + 1) * iou_interval
                in_bin = np.logical_and(max_overlaps >= start_iou,
                                        max_overlaps < end_iou)
            else:
                # Close the last bin at max_iou: with half-open bins the
                # boxes reaching max_iou (or lost to float rounding of
                # floor + num_bins * interval) would belong to no bin.
                in_bin = max_overlaps >= start_iou
            tmp_inds = list(set(np.where(in_bin)[0]) & full_set)
            if len(tmp_inds) > per_num_expected:
                tmp_sampled_set = self.random_choice(tmp_inds,
                                                     per_num_expected)
            else:
                tmp_sampled_set = np.array(tmp_inds, dtype=np.int64)
            sampled_inds.append(tmp_sampled_set)

        sampled_inds = np.concatenate(sampled_inds)
        if len(sampled_inds) < num_expected:
            num_extra = num_expected - len(sampled_inds)
            # Explicit dtype: an empty np.array(list()) would be float64 and
            # upcast the concatenated indices.
            extra_inds = np.array(
                list(full_set - set(sampled_inds)), dtype=np.int64)
            if len(extra_inds) > num_extra:
                extra_inds = self.random_choice(extra_inds, num_extra)
            sampled_inds = np.concatenate([sampled_inds, extra_inds])

        return sampled_inds

    def _sample_neg(self, assign_result, num_expected, **kwargs):
        """Sample negative boxes.

        Negatives are the boxes with ``assign_result.gt_inds == 0``. If there
        are at most ``num_expected`` of them, all are returned. When there
        are none, the result keeps the ``(0, 1)`` shape produced by
        ``torch.nonzero``. Otherwise the negatives are split into a "floor"
        group and an "IoU sampling" group according to ``self.floor_thr``.
        About ``1 - floor_fraction`` of the budget goes to the IoU group,
        which is binned when ``num_bins >= 2``. The floor group fills the rest,
        and any remaining shortfall is filled from all negatives.

        Args:
            assign_result (:obj:`AssignResult`): The assigned results of boxes.
            num_expected (int): The number of expected negative samples.

        Returns:
            torch.Tensor: Sampled indices (long) on the device of
            ``assign_result.gt_inds``.
        """
        neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
        if neg_inds.numel() != 0:
            neg_inds = neg_inds.squeeze(1)
        if len(neg_inds) <= num_expected:
            return neg_inds

        max_overlaps = assign_result.max_overlaps.cpu().numpy()
        # balance sampling for negative samples
        neg_set = set(neg_inds.cpu().numpy())

        if self.floor_thr > 0:
            floor_set = set(
                np.where(
                    np.logical_and(max_overlaps >= 0,
                                   max_overlaps < self.floor_thr))[0])
            iou_sampling_set = set(
                np.where(max_overlaps >= self.floor_thr)[0])
        elif self.floor_thr == 0:
            floor_set = set(np.where(max_overlaps == 0)[0])
            iou_sampling_set = set(
                np.where(max_overlaps > self.floor_thr)[0])
        else:
            floor_set = set()
            iou_sampling_set = set(
                np.where(max_overlaps > self.floor_thr)[0])

        # The interval bins start at 0 when floor_thr is -1. Pass that floor
        # explicitly rather than overwriting self.floor_thr, which would
        # silently switch every later call to the floor_thr == 0 branch.
        interval_floor = max(self.floor_thr, 0)

        floor_neg_inds = list(floor_set & neg_set)
        iou_sampling_neg_inds = list(iou_sampling_set & neg_set)
        num_expected_iou_sampling = int(num_expected *
                                        (1 - self.floor_fraction))
        if len(iou_sampling_neg_inds) > num_expected_iou_sampling:
            if self.num_bins >= 2:
                iou_sampled_inds = self.sample_via_interval(
                    max_overlaps, set(iou_sampling_neg_inds),
                    num_expected_iou_sampling, floor_thr=interval_floor)
            else:
                iou_sampled_inds = self.random_choice(
                    iou_sampling_neg_inds, num_expected_iou_sampling)
        else:
            iou_sampled_inds = np.array(
                iou_sampling_neg_inds, dtype=np.int64)

        num_expected_floor = num_expected - len(iou_sampled_inds)
        if len(floor_neg_inds) > num_expected_floor:
            sampled_floor_inds = self.random_choice(
                floor_neg_inds, num_expected_floor)
        else:
            sampled_floor_inds = np.array(floor_neg_inds, dtype=np.int64)

        sampled_inds = np.concatenate(
            (sampled_floor_inds, iou_sampled_inds))
        if len(sampled_inds) < num_expected:
            num_extra = num_expected - len(sampled_inds)
            extra_inds = np.array(
                list(neg_set - set(sampled_inds)), dtype=np.int64)
            if len(extra_inds) > num_extra:
                extra_inds = self.random_choice(extra_inds, num_extra)
            sampled_inds = np.concatenate((sampled_inds, extra_inds))

        sampled_inds = torch.from_numpy(sampled_inds).long().to(
            assign_result.gt_inds.device)
        return sampled_inds