#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# This code is based on
# https://github.com/ultralytics/yolov5/blob/master/utils/dataloaders.py

import os

from torch.utils.data import dataloader, distributed

from .datasets import TrainValDataset
from yolov6.utils.events import LOGGER
from yolov6.utils.torch_utils import torch_distributed_zero_first


def create_dataloader(
    path,
    img_size,
    batch_size,
    stride,
    hyp=None,
    augment=False,
    check_images=False,
    check_labels=False,
    pad=0.0,
    rect=False,
    rank=-1,
    workers=8,
    shuffle=False,
    data_dict=None,
    task="Train",
):
    """Create general dataloader.

    Returns dataloader and dataset
    """
    if rect and shuffle:
        LOGGER.warning(
            "WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False"
        )
        shuffle = False
    # In DDP, let rank 0 initialize the dataset first; other ranks wait, then reuse its work.
    with torch_distributed_zero_first(rank):
        dataset = TrainValDataset(
            path,
            img_size,
            batch_size,
            augment=augment,
            hyp=hyp,
            rect=rect,
            check_images=check_images,
            check_labels=check_labels,
            stride=int(stride),
            pad=pad,
            rank=rank,
            data_dict=data_dict,
            task=task,
        )

    batch_size = min(batch_size, len(dataset))
    # Cap workers at CPUs per DDP process, at the batch size, and at the requested value.
    workers = min(
        [
            os.cpu_count() // int(os.getenv("WORLD_SIZE", 1)),
            batch_size if batch_size > 1 else 0,
            workers,
        ]
    )  # number of workers
    # Attach a DistributedSampler only when running under DDP (rank != -1).
    sampler = (
        None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
    )
    return (
        TrainValDataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle and sampler is None,
            num_workers=workers,
            sampler=sampler,
            pin_memory=True,
            collate_fn=TrainValDataset.collate_fn,
        ),
        dataset,
    )
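

# --- Illustrative sketch, not part of the original file ----------------------
# The worker count above is min(CPUs per DDP process, batch size, requested
# workers). The hypothetical helper below factors that rule out so the
# arithmetic is easy to check in isolation: with 16 CPUs, WORLD_SIZE=2,
# batch_size=32 and workers=8 it returns min(8, 32, 8) == 8.
def _suggest_num_workers(batch_size, requested_workers, world_size=1, cpu_count=None):
    """Hypothetical helper mirroring the worker heuristic in create_dataloader."""
    cpu_count = os.cpu_count() if cpu_count is None else cpu_count
    return min(
        cpu_count // int(world_size),
        batch_size if batch_size > 1 else 0,
        requested_workers,
    )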


class TrainValDataLoader(dataloader.DataLoader):
    """Dataloader that reuses workers

    Uses same syntax as vanilla DataLoader
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Swap in a sampler that never stops, so worker processes are kept alive
        # across epochs instead of being respawned on every __iter__ call.
        object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler))
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for i in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler:
    """Sampler that repeats forever

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)
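

# --- Hypothetical usage sketch, not part of the original file ----------------
# Assumptions: a YOLO-format image directory at "./data/images/train" and a
# minimal data_dict carrying the class count and names ("nc", "names"); adjust
# both to your own dataset config before running.
if __name__ == "__main__":
    data_dict = {"nc": 80, "names": [f"class_{i}" for i in range(80)]}  # assumed minimal fields
    train_loader, train_dataset = create_dataloader(
        path="./data/images/train",  # hypothetical path
        img_size=640,
        batch_size=16,
        stride=32,
        augment=False,
        rank=-1,  # single process, so no DistributedSampler is attached
        workers=4,
        shuffle=True,
        data_dict=data_dict,
        task="Train",
    )
    LOGGER.info(f"batches per epoch: {len(train_loader)}")
    # TrainValDataset.collate_fn is expected to yield (images, labels, paths, shapes).
    for imgs, labels, paths, shapes in train_loader:
        LOGGER.info(f"first batch: images {tuple(imgs.shape)}, labels {tuple(labels.shape)}")
        break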