import json
import os

import torch
import torchvision
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms


class SquarePad:
    """Pad an image to a square canvas by replicating its edge pixels.

    The shorter side is padded on both ends so the original content stays
    centred; when the difference is odd the extra pixel goes to the
    right/bottom.
    """

    def __call__(self, image):
        max_wh = max(image.size)
        p_left, p_top = [(max_wh - s) // 2 for s in image.size]
        p_right, p_bottom = [
            max_wh - (s + pad) for s, pad in zip(image.size, [p_left, p_top])
        ]
        padding = (p_left, p_top, p_right, p_bottom)
        return transforms.functional.pad(image, padding, padding_mode='edge')


class BaseDataset(Dataset):
    """Shared setup for the train/test datasets.

    Loads the JSON index found under ``dataset_path`` and builds the RGB
    preprocessing pipeline (square pad, bicubic resize, tensor conversion,
    normalisation with the mean/std stored on the instance).

    Args:
        img_size: Side length, in pixels, that RGB images are resized to.
        dataset_path: Root directory that the JSON index and all relative
            image/mask locations are resolved against.
        inpainter: Key into each entry's ``['masks'][0]['inpainters']``
            mapping; only read by subclasses that need it (may be ``None``).
    """

    def __init__(self, img_size, dataset_path, inpainter):
        self.IMAGENET_MEAN = [0.485, 0.456, 0.406]
        self.IMAGENET_STD = [0.229, 0.224, 0.225]
        self.img_size = img_size
        self.dataset_path = dataset_path
        self.inpainter = inpainter
        # Alternative layout with the index directly under dataset_path:
        #   os.path.join(dataset_path, 'DFDS_V2.0_2Percent.json')
        self.json_path = os.path.join(dataset_path, 'DFDS_V2/DFDS_V2.0_2Percent.json')
        self.data = self.load_json()
        # The first 500 index entries form the training split.
        self.data_train = self.data[0:500]
        self.rgb_transform = transforms.Compose([
            SquarePad(),
            transforms.Resize(
                (img_size, img_size),
                interpolation=transforms.InterpolationMode.BICUBIC,
            ),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD),
        ])

    def load_json(self):
        """Return the parsed content of ``self.json_path``."""
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(self.json_path, 'r', encoding='utf-8') as file:
            return json.load(file)

    def _resolve(self, location):
        """Join an index location onto ``dataset_path``.

        Leading slashes are stripped first so ``os.path.join`` does not
        treat the location as absolute and discard ``dataset_path``.
        """
        return os.path.join(self.dataset_path, location.lstrip('/'))

    @staticmethod
    def _open_rgb_or_gray(path, mode):
        """Open ``path`` and return a copy converted to ``mode``.

        The underlying file handle is closed before returning; ``convert``
        yields an independent in-memory image.
        """
        with Image.open(path) as image:
            return image.convert(mode)


class TrainDataset(BaseDataset):
    """Paired (original, inpainted) training samples.

    Each item is ``(rgb_pos, gt_pos, rgb_neg, gt_neg)``: the untouched base
    image with an all-zero mask, and an inpainted version of an image with
    its binarised edit mask. All four tensors are ``img_size`` square.
    """

    # Keys of entry['masks'][0]['inpainters'], in the order their samples
    # are laid out in the path lists built by load_dataset().
    INPAINTERS = (
        'SD1_Inpaint',
        'SD1.5_Inpaint',
        'SD2_Inpaint',
        'SDXL_Inpaint',
        'SD3_Inpaint',
        'SD3.5_Inpaint',
        'kadinsky2.2_Inpaint',
        'kadinsky3.1_Inpaint',
        'FLUX_SHNELL_Inpaint',
        'FLUX_DEV_FILL',
    )

    def __init__(self, img_size, dataset_path):
        super().__init__(img_size=img_size, dataset_path=dataset_path, inpainter=None)
        self.gt_transform = transforms.Compose([
            SquarePad(),
            transforms.Resize(
                (img_size, img_size),
                interpolation=transforms.InterpolationMode.BICUBIC,
            ),
            transforms.ToTensor(),
        ])
        self.img_paths_pos, self.img_paths_neg, self.mask_paths_neg = self.load_dataset()

    def load_dataset(self):
        """Build the three parallel path lists for the training split.

        Returns:
            ``(positive_imgs, negative_imgs, negative_masks)``. Each list has
            ``len(INPAINTERS) * len(data_train)`` entries, grouped by
            inpainter: index ``k * len(data_train) + i`` refers to training
            entry ``i`` and inpainter ``INPAINTERS[k]``.
        """
        repeats = len(self.INPAINTERS)
        base_imgs = [self._resolve(entry['base_image_location']) for entry in self.data_train]
        edit_masks = [
            self._resolve(entry['masks'][0]['edited_mask_location'])
            for entry in self.data_train
        ]

        # Base images and masks do not depend on the inpainter, so the same
        # per-entry list is simply repeated once per inpainter.
        positive_imgs = base_imgs * repeats
        negative_masks = edit_masks * repeats
        negative_imgs = [
            self._resolve(entry['masks'][0]['inpainters'][name])
            for name in self.INPAINTERS
            for entry in self.data_train
        ]
        return positive_imgs, negative_imgs, negative_masks

    def __len__(self):
        # NOTE(review): the path lists hold len(INPAINTERS) groups, but only
        # the first three (SD1, SD1.5, SD2) are reachable with this length.
        # Kept as-is in case training on a subset is intentional - confirm.
        return len(self.data_train) * 3

    def __getitem__(self, idx):
        img_path_pos = self.img_paths_pos[idx]
        img_path_neg = self.img_paths_neg[idx]
        mask_path_neg = self.mask_paths_neg[idx]

        img_pos = self._open_rgb_or_gray(img_path_pos.replace('/Open_V7/', ''), 'RGB')
        img_neg = self._open_rgb_or_gray(img_path_neg.replace('/Open_V7/', ''), 'RGB')
        rgb_pos = self.rgb_transform(img_pos)
        rgb_neg = self.rgb_transform(img_neg)

        # An untouched image has an empty mask. Padding and resizing an
        # all-zero mask yields zeros again, so build the final tensor directly.
        gt_pos = torch.zeros([1, self.img_size, self.img_size])

        gt_neg = self._open_rgb_or_gray(mask_path_neg.replace('/Open_V7/', ''), 'L')
        gt_neg = self.gt_transform(gt_neg)
        # Bicubic resizing produces intermediate values; re-binarise.
        gt_neg = torch.where(gt_neg > 0.5, 1., .0)

        return rgb_pos, gt_pos, rgb_neg, gt_neg


class TestDataset(BaseDataset):
    """Evaluation samples for a single inpainter.

    Index entries 500-599 supply untouched base images (label ``0.0``, empty
    mask); entries 600-699 supply images produced by ``inpainter`` (label
    ``1.0``) together with their edit masks. Each item is
    ``(rgb, label, gt, img_path)``; ``gt`` keeps the source resolution since
    the mask transform here only converts to a tensor.
    """

    def __init__(self, img_size, dataset_path, inpainter):
        super().__init__(img_size=img_size, dataset_path=dataset_path, inpainter=inpainter)
        self.gt_transform = transforms.Compose([transforms.ToTensor()])
        self.img_paths, self.mask_paths, self.labels = self.load_dataset()

    def load_dataset(self):
        """Return ``(image_paths, mask_paths, labels)`` as parallel lists.

        ``mask_paths`` holds ``None`` for every untouched image.
        """
        originals = self.data[500:600]
        edited = self.data[600:700]

        positive_imgs = [self._resolve(entry['base_image_location']) for entry in originals]
        positive_masks = [None] * len(originals)
        negative_imgs = [
            self._resolve(entry['masks'][0]['inpainters'][self.inpainter])
            for entry in edited
        ]
        negative_masks = [
            self._resolve(entry['masks'][0]['edited_mask_location'])
            for entry in edited
        ]
        labels = [0.0] * len(originals) + [1.0] * len(edited)
        return positive_imgs + negative_imgs, positive_masks + negative_masks, labels

    def __len__(self):
        return len(self.img_paths)

    @staticmethod
    def _fix_path(path):
        """Apply the path rewrites this dataset uses before opening a file."""
        return path.replace('/Open_V7/', '').replace('data/', 'data/DFDS_V2/')

    def __getitem__(self, idx):
        img_path, gt, label = self.img_paths[idx], self.mask_paths[idx], self.labels[idx]

        img = self._open_rgb_or_gray(self._fix_path(img_path), 'RGB')
        rgb = self.rgb_transform(img)

        if gt is None:
            # PIL reports size as (width, height); tensors are (C, H, W).
            gt = torch.zeros([1, img.size[1], img.size[0]])
        else:
            gt = self._open_rgb_or_gray(self._fix_path(gt), 'L')
            gt = self.gt_transform(gt)
        gt = torch.where(gt > 0.5, 1., .0)

        return rgb, label, gt, img_path


def get_data_loader(split, img_size, batch_size, dataset_path, inpainter=None):
    """Create the ``DataLoader`` for the requested split.

    Args:
        split: ``'train'`` or ``'test'``.
        img_size: Side length RGB inputs are resized to.
        batch_size: Number of samples per batch.
        dataset_path: Dataset root directory.
        inpainter: Inpainter key used to pick the edited images of the test
            split; ignored for ``'train'``.

    Returns:
        A shuffled loader that drops the last partial batch for ``'train'``;
        an ordered loader that keeps every sample for ``'test'``.

    Raises:
        ValueError: If ``split`` is neither ``'train'`` nor ``'test'``.
    """
    if split == 'train':
        dataset = TrainDataset(img_size, dataset_path)
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=8,
            drop_last=True,
            pin_memory=False,
        )
    if split == 'test':
        dataset = TestDataset(img_size, dataset_path, inpainter)
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=8,
            drop_last=False,
            pin_memory=False,
        )
    raise ValueError(f"Unknown split {split!r}; expected 'train' or 'test'")