import json
import os

import torch
import torchvision
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class SquarePad:
    """Pad a PIL image to a square by replicating its edge pixels."""

    def __call__(self, image):
        max_wh = max(image.size)
        # Split the padding as evenly as possible between opposite sides.
        p_left, p_top = [(max_wh - s) // 2 for s in image.size]
        p_right, p_bottom = [max_wh - (s + pad) for s, pad in zip(image.size, [p_left, p_top])]
        padding = (p_left, p_top, p_right, p_bottom)
        return transforms.functional.pad(image, padding, padding_mode='edge')
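
# Quick sanity check for SquarePad (illustrative only; the 300x200 size is an
# arbitrary example): a landscape image comes back square, padded top/bottom.
#
#   >>> img = Image.new('RGB', (300, 200))
#   >>> SquarePad()(img).size
#   (300, 300)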


class BaseDataset(Dataset):
    """Shared setup for the DFDS_V2 train/test datasets: annotation loading
    and the common RGB preprocessing pipeline."""

    def __init__(self, img_size, dataset_path, inpainter):
        self.IMAGENET_MEAN = [0.485, 0.456, 0.406]
        self.IMAGENET_STD = [0.229, 0.224, 0.225]
        self.img_size = img_size
        self.dataset_path = dataset_path
        self.inpainter = inpainter
        self.json_path = os.path.join(dataset_path, 'DFDS_V2/DFDS_V2.0_2Percent.json')
        # self.json_path = os.path.join(dataset_path, 'DFDS_V2.0_2Percent.json')
        self.data = self.load_json()
        # The first 500 records are reserved for training (see TrainDataset);
        # TestDataset draws on records 500-699.
        self.data_train = self.data[0:500]
        self.rgb_transform = transforms.Compose([
            SquarePad(),
            transforms.Resize((img_size, img_size), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
        ])

    def load_json(self):
        with open(self.json_path, 'r') as file:
            data = json.load(file)
        return data
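
# Shape of one annotation record, inferred from the field accesses below
# (the key names are real; the example values are placeholders):
#
#   {
#     "base_image_location": "/images/0001.png",
#     "masks": [
#       {
#         "edited_mask_location": "/masks/0001.png",
#         "inpainters": {"SD1_Inpaint": "/inpaints/sd1/0001.png", ...}
#       }
#     ]
#   }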


class TrainDataset(BaseDataset):
    def __init__(self, img_size, dataset_path):
        super().__init__(img_size=img_size, dataset_path=dataset_path, inpainter=None)
        self.gt_transform = transforms.Compose([
            SquarePad(),
            transforms.Resize((img_size, img_size), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor()
        ])
        self.img_paths_pos, self.img_paths_neg, self.mask_paths_neg = self.load_dataset()

    def load_dataset(self):
        # Each base image yields one negative (inpainted) sample per
        # inpainter, so the positive-image and mask lists are tiled to match.
        inpainters = ['SD1_Inpaint', 'SD1.5_Inpaint', 'SD2_Inpaint', 'SDXL_Inpaint',
                      'SD3_Inpaint', 'SD3.5_Inpaint', 'kadinsky2.2_Inpaint',
                      'kadinsky3.1_Inpaint', 'FLUX_SHNELL_Inpaint', 'FLUX_DEV_FILL']
        positive_imgs = [os.path.join(self.dataset_path, data['base_image_location'].lstrip('/'))
                         for data in self.data_train] * len(inpainters)
        negative_imgs = [os.path.join(self.dataset_path, data['masks'][0]['inpainters'][name].lstrip('/'))
                         for name in inpainters for data in self.data_train]
        negative_masks = [os.path.join(self.dataset_path, data['masks'][0]['edited_mask_location'].lstrip('/'))
                          for data in self.data_train] * len(inpainters)
        return positive_imgs, negative_imgs, negative_masks

    def __len__(self):
        # One sample per (base image, inpainter) pair.
        return len(self.img_paths_pos)

    def __getitem__(self, idx):
        img_path_pos, img_path_neg, gt_neg = self.img_paths_pos[idx], self.img_paths_neg[idx], self.mask_paths_neg[idx]
        # Paths in the JSON carry an '/Open_V7/' segment that is not present
        # on disk here, so it is stripped.
        img_pos = Image.open(img_path_pos.replace('/Open_V7/', '')).convert('RGB')
        img_neg = Image.open(img_path_neg.replace('/Open_V7/', '')).convert('RGB')
        rgb_pos = self.rgb_transform(img_pos)
        rgb_neg = self.rgb_transform(img_neg)
        # Positive (authentic) images get an all-zero mask at the source
        # resolution, routed through the same transform as the real masks.
        gt_pos = torch.zeros([1, img_pos.size[1], img_pos.size[0]])
        gt_pos = torchvision.transforms.functional.to_pil_image(gt_pos)
        gt_pos = self.gt_transform(gt_pos)
        gt_neg = Image.open(gt_neg.replace('/Open_V7/', '')).convert('L')
        gt_neg = self.gt_transform(gt_neg)
        # Binarize the resized mask.
        gt_neg = torch.where(gt_neg > 0.5, 1., .0)
        return rgb_pos, gt_pos, rgb_neg, gt_neg


class TestDataset(BaseDataset):
    def __init__(self, img_size, dataset_path, inpainter):
        super().__init__(img_size=img_size, dataset_path=dataset_path, inpainter=inpainter)
        # Test masks stay at the original image resolution: no padding or
        # resizing is applied to the ground truth.
        self.gt_transform = transforms.Compose([transforms.ToTensor()])
        self.img_paths, self.mask_paths, self.labels = self.load_dataset()

    def load_dataset(self):
        # Records 500-599 serve as authentic (positive) images, records
        # 600-699 as negatives from the selected inpainter.
        positive_imgs = [os.path.join(self.dataset_path, data['base_image_location'].lstrip('/')) for data in self.data[500:600]]
        positive_masks = [None] * len(self.data[500:600])
        negative_imgs = [os.path.join(self.dataset_path, data['masks'][0]['inpainters'][self.inpainter].lstrip('/')) for data in self.data[600:700]]
        negative_masks = [os.path.join(self.dataset_path, data['masks'][0]['edited_mask_location'].lstrip('/')) for data in self.data[600:700]]
        labels = [0.0] * len(self.data[500:600]) + [1.0] * len(self.data[600:700])
        return positive_imgs + negative_imgs, positive_masks + negative_masks, labels

    def __len__(self):
        return len(self.data[500:700])

    def __getitem__(self, idx):
        img_path, gt, label = self.img_paths[idx], self.mask_paths[idx], self.labels[idx]
        # Strip the '/Open_V7/' segment and point 'data/' at the DFDS_V2
        # subfolder, matching the on-disk layout.
        img = Image.open(img_path.replace('/Open_V7/', '').replace('data/', 'data/DFDS_V2/')).convert('RGB')
        rgb = self.rgb_transform(img)
        if gt is None:
            # Authentic image: all-zero mask at the original resolution.
            gt = torch.zeros([1, img.size[1], img.size[0]])
        else:
            gt = Image.open(gt.replace('/Open_V7/', '').replace('data/', 'data/DFDS_V2/')).convert('L')
            gt = self.gt_transform(gt)
            gt = torch.where(gt > 0.5, 1., .0)
        return rgb, label, gt, img_path


def get_data_loader(split, img_size, batch_size, dataset_path, inpainter=None):
    if split == 'train':
        dataset = TrainDataset(img_size, dataset_path)
        data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=8, drop_last=True, pin_memory=False)
    elif split == 'test':
        dataset = TestDataset(img_size, dataset_path, inpainter)
        data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, num_workers=8, drop_last=False, pin_memory=False)
    else:
        raise ValueError(f"Unknown split: {split!r}")
    return data_loader
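

# Minimal usage sketch (illustrative only: './data' and the 256/8 values are
# placeholder assumptions; 'SD1_Inpaint' is one of the inpainter keys listed
# above). Test masks keep their original resolution, so the test loader is
# safest with batch_size=1 unless all images share a size.
if __name__ == '__main__':
    train_loader = get_data_loader('train', img_size=256, batch_size=8, dataset_path='./data')
    test_loader = get_data_loader('test', img_size=256, batch_size=1, dataset_path='./data', inpainter='SD1_Inpaint')
    rgb_pos, gt_pos, rgb_neg, gt_neg = next(iter(train_loader))
    print(rgb_pos.shape, gt_pos.shape, rgb_neg.shape, gt_neg.shape)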