# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Iterator, Optional, Sequence, Sized

import torch
from mmengine.dist import get_dist_info, sync_random_seed
from mmengine.registry import DATA_SAMPLERS
from torch.utils.data import Sampler

@DATA_SAMPLERS.register_module()
class MultiDataSampler(Sampler):
| """The default data sampler for both distributed and non-distributed | |
| environment. | |
| It has several differences from the PyTorch ``DistributedSampler`` as | |
| below: | |
| 1. This sampler supports non-distributed environment. | |
| 2. The round up behaviors are a little different. | |
| - If ``round_up=True``, this sampler will add extra samples to make the | |
| number of samples is evenly divisible by the world size. And | |
| this behavior is the same as the ``DistributedSampler`` with | |
| ``drop_last=False``. | |
| - If ``round_up=False``, this sampler won't remove or add any samples | |
| while the ``DistributedSampler`` with ``drop_last=True`` will remove | |
| tail samples. | |
| Args: | |
| dataset (Sized): The dataset. | |
| dataset_ratio (Sequence(int)) The ratios of different datasets. | |
| seed (int, optional): Random seed used to shuffle the sampler if | |
| :attr:`shuffle=True`. This number should be identical across all | |
| processes in the distributed group. Defaults to None. | |
| round_up (bool): Whether to add extra samples to make the number of | |
| samples evenly divisible by the world size. Defaults to True. | |
| """ | |
    def __init__(self,
                 dataset: Sized,
                 dataset_ratio: Sequence[int],
                 seed: Optional[int] = None,
                 round_up: bool = True) -> None:
        rank, world_size = get_dist_info()
        self.rank = rank
        self.world_size = world_size

        self.dataset = dataset
        self.dataset_ratio = dataset_ratio

        if seed is None:
            seed = sync_random_seed()
        self.seed = seed
        self.epoch = 0
        self.round_up = round_up

        if self.round_up:
            self.num_samples = math.ceil(len(self.dataset) / world_size)
            self.total_size = self.num_samples * self.world_size
        else:
            self.num_samples = math.ceil(
                (len(self.dataset) - rank) / world_size)
            self.total_size = len(self.dataset)
        # Per-sample weights: each sub-dataset is scaled by max(sizes) / size
        # so that its expected share of draws is proportional to its entry in
        # ``dataset_ratio``, regardless of its raw length.
        self.sizes = [len(sub_dataset) for sub_dataset in self.dataset.datasets]

        dataset_weight = [
            torch.ones(s) * max(self.sizes) / s * r / sum(self.dataset_ratio)
            for r, s in zip(self.dataset_ratio, self.sizes)
        ]
        self.weights = torch.cat(dataset_weight)
    def __iter__(self) -> Iterator[int]:
        """Iterate the indices."""
        # deterministically shuffle based on epoch and seed
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        indices = torch.multinomial(
            self.weights, len(self.weights), generator=g,
            replacement=True).tolist()

        # add extra samples to make it evenly divisible
        if self.round_up:
            indices = (
                indices *
                int(self.total_size / len(indices) + 1))[:self.total_size]

        # subsample
        indices = indices[self.rank:self.total_size:self.world_size]

        return iter(indices)
    def __len__(self) -> int:
        """The number of samples in this rank."""
        return self.num_samples
    def set_epoch(self, epoch: int) -> None:
        """Set the epoch for this sampler.

        This ensures all replicas use a different random ordering for each
        epoch. Otherwise, the next iteration of this sampler will yield the
        same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
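

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hypothetical
# example of wiring MultiDataSampler to a ``ConcatDataset``. The dummy
# dataset class, sizes, ratios and batch size below are illustrative
# assumptions only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from collections import Counter

    from torch.utils.data import ConcatDataset, DataLoader, Dataset

    class _DummyDataset(Dataset):
        """A stand-in dataset whose items are just its own label."""

        def __init__(self, label: int, length: int) -> None:
            self.label = label
            self.length = length

        def __len__(self) -> int:
            return self.length

        def __getitem__(self, index: int) -> int:
            return self.label

    # Two sub-datasets of different sizes, sampled at a 1:2 ratio.
    concat = ConcatDataset([_DummyDataset(0, 100), _DummyDataset(1, 300)])
    sampler = MultiDataSampler(concat, dataset_ratio=[1, 2], seed=0)
    sampler.set_epoch(0)

    loader = DataLoader(concat, batch_size=8, sampler=sampler)
    counts = Counter(int(label) for batch in loader for label in batch)
    # Roughly one third of the draws should come from dataset 0 and two
    # thirds from dataset 1, regardless of their raw lengths.
    print(counts)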