| '''create dataset and dataloader''' | |
| import logging | |
| import torch | |
| import torch.utils.data | |
def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
    """Build a ``torch.utils.data.DataLoader`` for ``dataset``.

    Args:
        dataset: Dataset object handed to the ``DataLoader``.
        dataset_opt: Mapping with the dataset settings. ``'phase'`` is always
            read; ``'n_workers'`` and ``'batch_size'`` are read only when the
            phase is ``'train'``.
        opt: Global options mapping. Required for the ``'train'`` phase, where
            ``opt['dist']`` selects distributed mode and ``opt['gpu_ids']`` is
            used to scale the worker count in non-distributed mode.
        sampler: Optional sampler forwarded to the training ``DataLoader``.
            It is not used for non-training phases.

    Returns:
        For the ``'train'`` phase, a loader with ``drop_last=True`` and
        ``pin_memory=False``. For any other phase, a loader with
        ``batch_size=1``, no shuffling, one worker and ``pin_memory=True``.

    Raises:
        TypeError: If the phase is ``'train'`` and ``opt`` is ``None``.
        ValueError: If, in distributed mode, ``dataset_opt['batch_size']`` is
            not divisible by the world size.
    """
    phase = dataset_opt['phase']
    if phase != 'train':
        # Every non-training phase is read one sample at a time, in order.
        return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1,
                                           pin_memory=True)

    if opt is None:
        # Fail with a readable message instead of "'NoneType' object is not
        # subscriptable" on the lookups below.
        raise TypeError("create_dataloader() requires `opt` when dataset_opt['phase'] is 'train'")

    total_batch_size = dataset_opt['batch_size']
    if opt['dist']:
        world_size = torch.distributed.get_world_size()
        num_workers = dataset_opt['n_workers']
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if total_batch_size % world_size != 0:
            raise ValueError(
                'batch_size ({}) must be divisible by the distributed world size ({}).'.format(
                    total_batch_size, world_size))
        # Each process loads an equal share of the global batch.
        batch_size = total_batch_size // world_size
        # Never shuffle here; ordering is left to `sampler`
        # (presumably a DistributedSampler - confirm against the caller).
        shuffle = False
    else:
        num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids'])
        batch_size = total_batch_size
        # DataLoader raises ValueError when `shuffle=True` is combined with an
        # explicit sampler, so only shuffle when no sampler was supplied.
        shuffle = sampler is None

    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                                       num_workers=num_workers, sampler=sampler, drop_last=True,
                                       pin_memory=False)
def create_dataset(dataset_opt):
    """Instantiate the dataset class selected by ``dataset_opt['mode']``.

    Args:
        dataset_opt: Mapping with the dataset settings. ``'mode'`` selects the
            class (``'test'``, ``'train'`` or ``'td'``), ``'name'`` is used in
            the log message, and the whole mapping is passed to the dataset
            constructor.

    Returns:
        The constructed dataset instance.

    Raises:
        NotImplementedError: If ``dataset_opt['mode']`` is not a supported mode.
    """
    mode = dataset_opt['mode']
    # The dataset module is imported inside the matching branch, so only the
    # selected one is loaded.
    if mode == 'test':
        from data.coco_test_dataset import imageTestDataset as D
    elif mode == 'train':
        from data.coco_dataset import CoCoDataset as D
    elif mode == 'td':
        from data.test_dataset_td import imageTestDataset as D
    else:
        # '{}' instead of '{:s}': the latter raises TypeError/ValueError for a
        # non-string mode (e.g. None), hiding the error we actually want.
        raise NotImplementedError('Dataset [{}] is not recognized.'.format(mode))

    logger = logging.getLogger('base')
    # Former stray `print(mode)`; kept as a debug-level record.
    logger.debug('Dataset mode: %s', mode)

    dataset = D(dataset_opt)
    # Lazy %-style args also tolerate a non-string dataset name.
    logger.info('Dataset [%s - %s] is created.', dataset.__class__.__name__,
                dataset_opt['name'])
    return dataset