# Data augmentation transforms for segmentation ('seg'), object detection ('od')
# and change detection ('cd') pipelines.
| import numpy as np | |
| import cv2 | |
| import torch | |
class Compose(object):
    """Chain several transforms and apply them in order.

    Args:
        transforms (list of ``Transform`` objects): transforms applied sequentially.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, data):
        # Feed the output of each transform into the next one.
        for transform in self.transforms:
            data = transform(data)
        return data

    def __repr__(self):
        parts = [self.__class__.__name__ + '(']
        for transform in self.transforms:
            parts.append('\n')
            parts.append('    {0}'.format(transform))
        parts.append('\n)')
        return ''.join(parts)
class ConvertUcharToFloat(object):
    """Cast every array in ``data`` from uchar (uint8) to float32."""

    def __call__(self, data):
        # Returns a list regardless of the input container type.
        return [item.astype(np.float32) for item in data]
class RandomContrast(object):
    """Randomly scale image intensities by a factor in [lower, upper]."""

    def __init__(self, phase, lower=0.8, upper=1.2, prob=0.5):
        self.phase = phase
        self.lower = lower
        self.upper = upper
        self.prob = prob
        assert self.upper >= self.lower, "contrast upper must be >= lower!"
        assert self.lower > 0, "contrast lower must be non-negative!"

    def _maybe_scale(self, img):
        # With probability ``prob`` apply a random gain in place.
        if torch.rand(1) < self.prob:
            gain = torch.FloatTensor(1).uniform_(self.lower, self.upper)
            img *= gain.numpy()
        return img

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, target = data
            return_data = self._maybe_scale(img), target
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return_data = (self._maybe_scale(img1), label1,
                           self._maybe_scale(img2), label2)
        return return_data
class RandomBrightness(object):
    """Randomly shift image intensities by an offset in [-delta, delta]."""

    def __init__(self, phase, delta=10, prob=0.5):
        self.phase = phase
        self.delta = delta
        self.prob = prob
        assert 0. <= self.delta < 255., "brightness delta must between 0 to 255"

    def _maybe_shift(self, img):
        # With probability ``prob`` add a random offset in place.
        if torch.rand(1) < self.prob:
            offset = torch.FloatTensor(1).uniform_(- self.delta, self.delta)
            img += offset.numpy()
        return img

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, target = data
            return_data = self._maybe_shift(img), target
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return_data = (self._maybe_shift(img1), label1,
                           self._maybe_shift(img2), label2)
        return return_data
class ConvertColor(object):
    """Convert image color space (RGB<->HSV) for later img distortion."""

    def __init__(self, phase, current='RGB', target='HSV'):
        self.phase = phase
        self.current = current
        self.target = target

    def _convert(self, img):
        # Map the (current, target) pair onto the matching OpenCV code.
        if self.current == 'RGB' and self.target == 'HSV':
            return cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
        if self.current == 'HSV' and self.target == 'RGB':
            return cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
        raise NotImplementedError("Convert color fail!")

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, target = data
            return_data = self._convert(img), target
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return_data = (self._convert(img1), label1,
                           self._convert(img2), label2)
        return return_data
class RandomSaturation(object):
    """Randomly scale the saturation channel (HSV channel 1) of an image."""

    def __init__(self, phase, lower=0.8, upper=1.2, prob=0.5):
        self.phase = phase
        self.lower = lower
        self.upper = upper
        self.prob = prob
        assert self.upper >= self.lower, "saturation upper must be >= lower!"
        assert self.lower > 0, "saturation lower must be non-negative!"

    def _maybe_scale_s(self, img):
        # With probability ``prob`` scale the S channel in place.
        if torch.rand(1) < self.prob:
            gain = torch.FloatTensor(1).uniform_(self.lower, self.upper)
            img[:, :, 1] *= gain.numpy()
        return img

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, target = data
            return_data = self._maybe_scale_s(img), target
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return_data = (self._maybe_scale_s(img1), label1,
                           self._maybe_scale_s(img2), label2)
        return return_data
class RandomHue(object):
    """Randomly shift the hue channel (HSV channel 0), wrapping into [0, 360)."""

    def __init__(self, phase, delta=10., prob=0.5):
        self.phase = phase
        self.delta = delta
        self.prob = prob
        assert 0 <= self.delta < 360, "Hue delta must between 0 to 360!"

    def _maybe_shift_h(self, img):
        # With probability ``prob`` offset the H channel in place and wrap
        # values that leave the [0, 360] range.
        if torch.rand(1) < self.prob:
            offset = torch.FloatTensor(1).uniform_(-self.delta, self.delta)
            img[:, :, 0] += offset.numpy()
            img[:, :, 0][img[:, :, 0] > 360.0] -= 360.0
            img[:, :, 0][img[:, :, 0] < 0.0] += 360.0
        return img

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, target = data
            return_data = self._maybe_shift_h(img), target
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return_data = (self._maybe_shift_h(img1), label1,
                           self._maybe_shift_h(img2), label2)
        return return_data
class RandomChannelNoise(object):
    """Randomly permute the channel order of an image."""

    def __init__(self, phase, prob=0.4):
        self.phase = phase
        self.prob = prob
        # All 6 permutations of a 3-channel image.
        self.perms = ((0, 1, 2), (0, 2, 1),
                      (1, 0, 2), (1, 2, 0),
                      (2, 0, 1), (2, 1, 0))

    def _maybe_shuffle(self, img):
        # With probability ``prob`` pick a random channel order.
        if torch.rand(1) < self.prob:
            order = self.perms[torch.randint(0, len(self.perms), size=[])]
            img = img[:, :, order]
        return img

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, target = data
            return_data = self._maybe_shuffle(img), target
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return_data = (self._maybe_shuffle(img1), label1,
                           self._maybe_shuffle(img2), label2)
        return return_data
class ImgDistortion(object):
    """Apply SSD-style photometric distortion (brightness, contrast, HSV jitter)."""

    def __init__(self, phase, prob=0.5):
        self.phase = phase
        self.prob = prob
        # Contrast appears at both ends; exactly one of the two is kept per call.
        self.operation = [
            RandomContrast(phase),
            ConvertColor(phase, current='RGB', target='HSV'),
            RandomSaturation(phase),
            RandomHue(phase),
            ConvertColor(phase, current='HSV', target='RGB'),
            RandomContrast(phase)
        ]
        self.random_brightness = RandomBrightness(phase)
        self.random_light_noise = RandomChannelNoise(phase)

    def __call__(self, data):
        if torch.rand(1) < self.prob:
            data = self.random_brightness(data)
        # Randomly keep contrast either at the head or at the tail.
        ops = self.operation[:-1] if torch.rand(1) < self.prob else self.operation[1:]
        data = Compose(ops)(data)
        return self.random_light_noise(data)
class ExpandImg(object):
    """Randomly pad (expand) the image borders with the dataset mean color.

    With probability ``prob`` the input is returned unchanged; otherwise each
    side is padded by a random amount of at most ``expand_ratio`` times the
    corresponding image dimension. Labels are padded with 0; for 'od',
    box coordinates are shifted into the padded frame instead.
    """

    def __init__(self, phase, prior_mean, prob=0.5, expand_ratio=0.2):
        self.phase = phase
        # prior_mean is given normalized to [0, 1]; scale back to pixel range.
        self.prior_mean = np.array(prior_mean) * 255
        self.prob = prob
        self.expand_ratio = expand_ratio

    def _sample_margins(self, height, width):
        """Draw random (top, bottom, left, right) padding sizes in pixels."""
        ratio_width = self.expand_ratio * torch.rand([])
        ratio_height = self.expand_ratio * torch.rand([])
        left, right = torch.randint(high=int(max(1, width * ratio_width)), size=[2])
        # BUGFIX: vertical margins must scale with the image *height*;
        # the original used ``width * ratio_height`` here.
        top, bottom = torch.randint(high=int(max(1, height * ratio_height)), size=[2])
        return int(top), int(bottom), int(left), int(right)

    def _pad_img(self, img, top, bottom, left, right):
        # Fill new border pixels with the dataset mean color.
        return cv2.copyMakeBorder(
            img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=self.prior_mean)

    def _pad_label(self, label, top, bottom, left, right):
        # Padded label area is background (0).
        return cv2.copyMakeBorder(
            label, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)

    def __call__(self, data):
        if self.phase == 'seg':
            img, label = data
            # NOTE: expansion is *skipped* with probability ``prob``.
            if torch.rand(1) < self.prob:
                return data
            height, width, channels = img.shape
            top, bottom, left, right = self._sample_margins(height, width)
            img = self._pad_img(img, top, bottom, left, right)
            label = self._pad_label(label, top, bottom, left, right)
            return img, label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            if torch.rand(1) < self.prob:
                return data
            height, width, channels = img1.shape
            # Both timestamps get identical padding so they stay registered.
            top, bottom, left, right = self._sample_margins(height, width)
            img1 = self._pad_img(img1, top, bottom, left, right)
            label1 = self._pad_label(label1, top, bottom, left, right)
            img2 = self._pad_img(img2, top, bottom, left, right)
            label2 = self._pad_label(label2, top, bottom, left, right)
            return img1, label1, img2, label2
        elif self.phase == 'od':
            if torch.rand(1) < self.prob:
                return data
            img, label = data
            height, width, channels = img.shape
            top, bottom, left, right = self._sample_margins(height, width)
            img = self._pad_img(img, top, bottom, left, right)
            # Boxes are [cls, x1, y1, x2, y2]; shift by the new origin.
            label[:, 1::2] += left
            label[:, 2::2] += top
            return img, label
class RandomSampleCrop(object):
    """Randomly crop a patch of roughly ``original_size`` from the input.

    Arguments:
        img (Image): the image being input during training
        boxes (Tensor): the original bounding boxes in pt form
        label (Tensor): the class label for each bbox
        mode (float tuple): the min and max jaccard overlaps
    Return:
        (img, boxes, classes)
            img (Image): the cropped image
            boxes (Tensor): the adjusted bounding boxes in pt form
            label (Tensor): the class label for each bbox
    """

    def __init__(self,
                 phase,
                 original_size=[512, 512],
                 prob=0.5,
                 crop_scale_ratios_range=[0.8, 1.2],
                 aspect_ratio_range=[4./5, 5./4]):
        self.phase = phase
        self.prob = prob
        self.scale_range = crop_scale_ratios_range
        self.original_size = original_size  # (width, height)
        self.aspect_ratio_range = aspect_ratio_range  # h/w
        self.max_try_times = 500

    def _sample_crop_size(self, height, width):
        """Sample a (crop_h, crop_w) obeying the scale and aspect-ratio limits.

        Returns None when no valid size is found within ``max_try_times``.
        """
        for _ in range(self.max_try_times):
            crop_w = torch.randint(
                min(width, int(self.scale_range[0] * self.original_size[0])),
                min(width + 1, int(self.scale_range[1] * self.original_size[0])),
                size=[]
            )
            crop_h = torch.randint(
                min(height, int(self.scale_range[0] * self.original_size[1])),
                min(height + 1, int(self.scale_range[1] * self.original_size[1])),
                size=[]
            )
            # aspect ratio constraint
            if self.aspect_ratio_range[0] < crop_h / crop_w < self.aspect_ratio_range[1]:
                return crop_h, crop_w
        return None

    def __call__(self, data):
        if self.phase == 'seg':
            img, label = data
            # BUGFIX: numpy/cv2 images are (height, width, channels); the
            # original unpacked ``w, h, c`` and so swapped the crop bounds
            # on non-square images.
            h, w, c = img.shape
            if torch.rand(1) < self.prob:
                return data
            size = self._sample_crop_size(h, w)
            if size is None:
                print("try times over max threshold!", flush=True)
                return img, label
            crop_h, crop_w = size
            left = torch.randint(0, w - crop_w + 1, size=[])
            top = torch.randint(0, h - crop_h + 1, size=[])
            img = img[top:(top + crop_h), left:(left + crop_w), :]
            label = label[top:(top + crop_h), left:(left + crop_w)]
            return img, label
        elif self.phase == 'od':
            if torch.rand(1) < self.prob:
                return data
            img, label = data
            h, w, c = img.shape  # BUGFIX: same h/w swap as above
            size = self._sample_crop_size(h, w)
            if size is None:
                # ROBUSTNESS: the original 'od' branch looped forever when no
                # aspect ratio fit; give up like the 'seg' branch does.
                print("try times over max threshold!", flush=True)
                return img, label
            crop_h, crop_w = size
            left = torch.randint(0, w - crop_w + 1, size=[]).numpy()
            top = torch.randint(0, h - crop_h + 1, size=[]).numpy()
            crop_h = crop_h.numpy()
            crop_w = crop_w.numpy()
            img = img[top:(top + crop_h), left:(left + crop_w), :]
            if len(label):
                # keep overlap with gt box IF center in sampled patch
                centers = (label[:, 1:3] + label[:, 3:]) / 2.0
                # mask in all gt boxes that above and to the left of centers
                m1 = (left <= centers[:, 0]) * (top <= centers[:, 1])
                # mask in all gt boxes that under and to the right of centers
                # (BUGFIX: use >= like the other bounds; was a bare >)
                m2 = ((left + crop_w) >= centers[:, 0]) * ((top + crop_h) >= centers[:, 1])
                # mask in that both m1 and m2 are true
                mask = m1 * m2
                # take only matching gt boxes
                current_label = label[mask, :]
                # adjust to crop (by substracting crop's left,top)
                current_label[:, 1::2] -= left
                current_label[:, 2::2] -= top
                label = current_label
            return img, label
class RandomMirror(object):
    """Randomly flip images (and labels/boxes) horizontally."""

    def __init__(self, phase, prob=0.5):
        self.phase = phase
        self.prob = prob

    def __call__(self, data):
        if self.phase == 'seg':
            img, label = data
            if torch.rand(1) < self.prob:
                img, label = img[:, ::-1], label[:, ::-1]
            return img, label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            if torch.rand(1) < self.prob:
                # Flip both timestamps together to keep them registered.
                img1, label1 = img1[:, ::-1], label1[:, ::-1]
                img2, label2 = img2[:, ::-1], label2[:, ::-1]
            return img1, label1, img2, label2
        elif self.phase == 'od':
            img, label = data
            if torch.rand(1) < self.prob:
                _, width, _ = img.shape
                img = img[:, ::-1]
                # Boxes are [cls, x1, y1, x2, y2]; the mirrored (x1, x2) are
                # (width - x2, width - x1), i.e. columns 3 and 1 in reverse.
                label[:, 1::2] = width - label[:, 3::-2]
            return img, label
class RandomFlipV(object):
    """Randomly flip images (and labels/boxes) vertically."""

    def __init__(self, phase, prob=0.5):
        self.phase = phase
        self.prob = prob

    def __call__(self, data):
        if self.phase == 'seg':
            img, label = data
            if torch.rand(1) < self.prob:
                img, label = img[::-1, :], label[::-1, :]
            return img, label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            if torch.rand(1) < self.prob:
                # Flip both timestamps together to keep them registered.
                img1, label1 = img1[::-1, :], label1[::-1, :]
                img2, label2 = img2[::-1, :], label2[::-1, :]
            return img1, label1, img2, label2
        elif self.phase == 'od':
            img, label = data
            if torch.rand(1) < self.prob:
                height, _, _ = img.shape
                img = img[::-1, :]
                # Boxes are [cls, x1, y1, x2, y2]; the flipped (y1, y2) are
                # (height - y2, height - y1), i.e. columns 4 and 2 in reverse.
                label[:, 2::2] = height - label[:, 4:1:-2]
            return img, label
class Resize(object):
    """Resize images (bilinear) and mask labels (nearest) to ``size`` (w, h)."""

    def __init__(self, phase, size):
        self.phase = phase
        self.size = size

    def _resize_img(self, img):
        return cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR)

    def _resize_label(self, label):
        # Nearest neighbour keeps label values discrete.
        return cv2.resize(label, self.size, interpolation=cv2.INTER_NEAREST)

    def __call__(self, data):
        if self.phase == 'seg':
            img, label = data
            return self._resize_img(img), self._resize_label(label)
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return (self._resize_img(img1), self._resize_label(label1),
                    self._resize_img(img2), self._resize_label(label2))
        elif self.phase == 'od':
            img, label = data
            height, width, _ = img.shape
            img = self._resize_img(img)
            # Rescale box coordinates from the old to the new resolution.
            label[:, 1::2] = label[:, 1::2] / width * self.size[0]
            label[:, 2::2] = label[:, 2::2] / height * self.size[1]
            return img, label
class Normalize(object):
    """Scale images to [0, 1] then standardize with the dataset mean/std."""

    def __init__(self, phase, prior_mean, prior_std):
        self.phase = phase
        # Shaped (1, 1, C) so they broadcast over H x W x C images.
        self.prior_mean = np.array([[prior_mean]], dtype=np.float32)
        self.prior_std = np.array([[prior_std]], dtype=np.float32)

    def _normalize(self, img):
        img = img / 255.
        # Epsilon guards against a zero std.
        return (img - self.prior_mean) / (self.prior_std + 1e-10)

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, target = data
            return self._normalize(img), target
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return self._normalize(img1), label1, self._normalize(img2), label2
class InvNormalize(object):
    """Undo ``Normalize``: de-standardize and map back to [0, 255]."""

    def __init__(self, prior_mean, prior_std):
        # Shaped (1, 1, C) so they broadcast over H x W x C images.
        self.prior_mean = np.array([[prior_mean]], dtype=np.float32)
        self.prior_std = np.array([[prior_std]], dtype=np.float32)

    def __call__(self, img):
        restored = (img * self.prior_std + self.prior_mean) * 255.
        # Clamp to the valid pixel range.
        return np.clip(restored, a_min=0, a_max=255)
class Augmentations(object):
    """Build the full augmentation pipeline for the train/val/test patterns."""

    def __init__(self, size, prior_mean=0, prior_std=1, pattern='train', phase='seg', *args, **kwargs):
        self.size = size
        self.prior_mean = prior_mean
        self.prior_std = prior_std
        self.phase = phase
        # val/test share the same deterministic resize + normalize steps.
        eval_steps = [
            ConvertUcharToFloat(),
            Resize(self.phase, self.size),
            Normalize(self.phase, self.prior_mean, self.prior_std),
        ]
        augments = {
            'train': Compose([
                ConvertUcharToFloat(),
                ImgDistortion(self.phase),
                ExpandImg(self.phase, self.prior_mean),
                RandomSampleCrop(self.phase, original_size=self.size),
                RandomMirror(self.phase),
                RandomFlipV(self.phase),
                Resize(self.phase, self.size),
                Normalize(self.phase, self.prior_mean, self.prior_std),
            ]),
            'val': Compose(eval_steps),
            'test': Compose(eval_steps),
        }
        self.augment = augments[pattern]

    def __call__(self, data):
        return self.augment(data)