# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """ | |
| A set of basic tensor ops compatible with tpu, gpu, and multigpu | |
| """ | |

import pickle
from functools import update_wrapper
from typing import Any, Mapping

import torch

from ..state import PartialState
from .constants import CUDA_DISTRIBUTED_TYPES
from .dataclasses import DistributedType, TensorInformation
from .imports import is_torch_distributed_available, is_tpu_available


if is_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm

if is_torch_distributed_available():
    from torch.distributed import ReduceOp


def is_torch_tensor(tensor):
    return isinstance(tensor, torch.Tensor)


def is_torch_xpu_tensor(tensor):
    # `isinstance` expects a single type or a tuple of types as its second argument.
    return isinstance(
        tensor,
        (
            torch.xpu.FloatTensor,
            torch.xpu.ByteTensor,
            torch.xpu.IntTensor,
            torch.xpu.LongTensor,
            torch.xpu.HalfTensor,
            torch.xpu.DoubleTensor,
            torch.xpu.BFloat16Tensor,
        ),
    )


def is_tensor_information(tensor_info):
    return isinstance(tensor_info, TensorInformation)


def is_namedtuple(data):
    """
    Checks if `data` is a `namedtuple` or not. Can have false positives, but only if a user is trying to mimic a
    `namedtuple` perfectly.
    """
    data_type = type(data)
    bases = data_type.__bases__
    if len(bases) != 1 or bases[0] != tuple:
        return False
    fields = getattr(data_type, "_fields", None)
    if not isinstance(fields, tuple):
        return False
    return all(isinstance(member, str) for member in fields)


def honor_type(obj, generator):
    """
    Cast a generator to the same type as `obj` (list, tuple, or namedtuple).
    """
    # Some objects may not be able to instantiate from a generator directly
    if is_namedtuple(obj):
        return type(obj)(*list(generator))
    else:
        return type(obj)(generator)
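

# Illustrative usage (not part of the original module): `honor_type` preserves the
# container type, including namedtuples. The `Point` namedtuple below is made up for
# the example.
#
#     >>> from collections import namedtuple
#     >>> Point = namedtuple("Point", ["x", "y"])
#     >>> honor_type(Point(1, 2), (v * 10 for v in Point(1, 2)))
#     Point(x=10, y=20)
#     >>> honor_type([1, 2], (v * 10 for v in [1, 2]))
#     [10, 20]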


def recursively_apply(func, data, *args, test_type=is_torch_tensor, error_on_other_type=False, **kwargs):
    """
    Recursively apply a function on a data structure that is a nested list/tuple/dictionary of a given base type.

    Args:
        func (`callable`):
            The function to recursively apply.
        data (nested list/tuple/dictionary of the base type):
            The data on which to apply `func`.
        *args:
            Positional arguments that will be passed to `func` when applied on the unpacked data.
        test_type (`callable`, *optional*, defaults to `is_torch_tensor`):
            A function returning whether `func` should be applied to a given object.
        error_on_other_type (`bool`, *optional*, defaults to `False`):
            Whether to raise an error if, after unpacking `data`, we get an object that does not pass `test_type`.
            If `False`, the function will leave such objects unchanged.
        **kwargs:
            Keyword arguments that will be passed to `func` when applied on the unpacked data.

    Returns:
        The same data structure as `data` with `func` applied to every object passing `test_type`.
    """
    if isinstance(data, (tuple, list)):
        return honor_type(
            data,
            (
                recursively_apply(
                    func, o, *args, test_type=test_type, error_on_other_type=error_on_other_type, **kwargs
                )
                for o in data
            ),
        )
    elif isinstance(data, Mapping):
        return type(data)(
            {
                k: recursively_apply(
                    func, v, *args, test_type=test_type, error_on_other_type=error_on_other_type, **kwargs
                )
                for k, v in data.items()
            }
        )
    elif test_type(data):
        return func(data, *args, **kwargs)
    elif error_on_other_type:
        raise TypeError(
            f"Unsupported types ({type(data)}) passed to `{func.__name__}`. Only nested list/tuple/dicts of "
            f"objects that are valid for `{test_type.__name__}` should be passed."
        )
    return data
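

# Illustrative usage (not part of the original module): applying a function over an
# arbitrarily nested structure of tensors. Values are made up for demonstration.
#
#     >>> nested = {"a": [torch.tensor(1.0), torch.tensor(2.0)], "b": (torch.tensor(3.0),)}
#     >>> recursively_apply(lambda t: t * 2, nested)
#     {'a': [tensor(2.), tensor(4.)], 'b': (tensor(6.),)}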


def send_to_device(tensor, device, non_blocking=False, skip_keys=None):
    """
    Recursively sends the elements in a nested list/tuple/dictionary of tensors to a given device.

    Args:
        tensor (nested list/tuple/dictionary of `torch.Tensor`):
            The data to send to a given device.
        device (`torch.device`):
            The device to send the data to.
        non_blocking (`bool`, *optional*, defaults to `False`):
            Whether to perform the copy asynchronously when possible.
        skip_keys (`str` or `List[str]`, *optional*):
            A key or list of keys to ignore when `tensor` is a dictionary.

    Returns:
        The same data structure as `tensor` with all tensors sent to the proper device.
    """
    if isinstance(tensor, (tuple, list)):
        return honor_type(
            tensor, (send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys) for t in tensor)
        )
    elif isinstance(tensor, Mapping):
        if isinstance(skip_keys, str):
            skip_keys = [skip_keys]
        elif skip_keys is None:
            skip_keys = []
        return type(tensor)(
            {
                k: t if k in skip_keys else send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys)
                for k, t in tensor.items()
            }
        )
    elif hasattr(tensor, "to"):
        try:
            return tensor.to(device, non_blocking=non_blocking)
        except TypeError:  # .to() doesn't accept non_blocking as kwarg
            return tensor.to(device)
    else:
        return tensor
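

# Illustrative usage (not part of the original module): moving a batch to a device
# while leaving one key untouched. CUDA availability is an assumption of the example.
#
#     >>> batch = {"input_ids": torch.ones(2, 4, dtype=torch.long), "metadata": torch.tensor(0)}
#     >>> batch = send_to_device(batch, "cuda", skip_keys="metadata")
#     >>> batch["input_ids"].device, batch["metadata"].device
#     (device(type='cuda', index=0), device(type='cpu'))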


def get_data_structure(data):
    """
    Recursively gathers the information needed to rebuild a nested list/tuple/dictionary of tensors.

    Args:
        data (nested list/tuple/dictionary of `torch.Tensor`):
            The data to analyze.

    Returns:
        The same data structure as `data` with [`~utils.TensorInformation`] instead of tensors.
    """

    def _get_data_structure(tensor):
        return TensorInformation(shape=tensor.shape, dtype=tensor.dtype)

    return recursively_apply(_get_data_structure, data)


def initialize_tensors(data_structure):
    """
    Recursively initializes tensors from a nested list/tuple/dictionary of [`~utils.TensorInformation`].

    Args:
        data_structure (nested list/tuple/dictionary of [`~utils.TensorInformation`]):
            The data structure to initialize.

    Returns:
        The same data structure as `data_structure` with tensors instead of [`~utils.TensorInformation`].
    """

    def _initialize_tensor(tensor_info):
        return torch.empty(*tensor_info.shape, dtype=tensor_info.dtype)

    return recursively_apply(_initialize_tensor, data_structure, test_type=is_tensor_information)
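

# Illustrative usage (not part of the original module): a round trip through
# `get_data_structure` and `initialize_tensors`. Shapes are made up.
#
#     >>> info = get_data_structure({"x": torch.ones(4, 8)})
#     >>> placeholder = initialize_tensors(info)
#     >>> placeholder["x"].shape, placeholder["x"].dtype
#     (torch.Size([4, 8]), torch.float32)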


def find_batch_size(data):
    """
    Recursively finds the batch size in a nested list/tuple/dictionary of lists of tensors.

    Args:
        data (nested list/tuple/dictionary of `torch.Tensor`): The data from which to find the batch size.

    Returns:
        `int`: The batch size.
    """
    if isinstance(data, (tuple, list)):
        return find_batch_size(data[0])
    elif isinstance(data, Mapping):
        for k in data.keys():
            return find_batch_size(data[k])
    elif not isinstance(data, torch.Tensor):
        raise TypeError(f"Can only find the batch size of tensors but got {type(data)}.")
    return data.shape[0]
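

# Illustrative usage (not part of the original module): the batch size is read off the
# first tensor found in the structure.
#
#     >>> batch = {"input_ids": torch.ones(16, 128), "labels": torch.zeros(16)}
#     >>> find_batch_size(batch)
#     16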


def listify(data):
    """
    Recursively finds tensors in a nested list/tuple/dictionary and converts them to a list of numbers.

    Args:
        data (nested list/tuple/dictionary of `torch.Tensor`): The data to convert to regular numbers.

    Returns:
        The same data structure as `data` with lists of numbers instead of `torch.Tensor`.
    """

    def _convert_to_list(tensor):
        tensor = tensor.detach().cpu()
        if tensor.dtype == torch.bfloat16:
            # As of Numpy 1.21.4, NumPy does not support bfloat16 (see
            # https://github.com/numpy/numpy/blob/a47ecdea856986cd60eabbd53265c2ca5916ad5d/doc/source/user/basics.types.rst ).
            # Until Numpy adds bfloat16, we must convert to float32.
            tensor = tensor.to(torch.float32)
        return tensor.tolist()

    return recursively_apply(_convert_to_list, data)
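

# Illustrative usage (not part of the original module): tensors become plain Python
# lists of numbers.
#
#     >>> listify({"loss": torch.tensor([0.5, 0.25])})
#     {'loss': [0.5, 0.25]}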


def _tpu_gather(tensor):
    def _tpu_gather_one(tensor):
        if tensor.ndim == 0:
            tensor = tensor.clone()[None]
        return xm.all_gather(tensor)

    res = recursively_apply(_tpu_gather_one, tensor, error_on_other_type=True)
    xm.mark_step()
    return res


def _gpu_gather(tensor):
    def _gpu_gather_one(tensor):
        if tensor.ndim == 0:
            tensor = tensor.clone()[None]
        output_tensors = [torch.empty_like(tensor) for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(output_tensors, tensor)
        return torch.cat(output_tensors, dim=0)

    return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True)


_cpu_gather = _gpu_gather


def gather(tensor):
    """
    Recursively gather tensor in a nested list/tuple/dictionary of tensors from all devices.

    Args:
        tensor (nested list/tuple/dictionary of `torch.Tensor`):
            The data to gather.

    Returns:
        The same data structure as `tensor` with all tensors gathered across processes.
    """
    if PartialState().distributed_type == DistributedType.TPU:
        return _tpu_gather(tensor)
    elif PartialState().distributed_type in CUDA_DISTRIBUTED_TYPES:
        return _gpu_gather(tensor)
    elif PartialState().distributed_type == DistributedType.MULTI_NPU:
        return _gpu_gather(tensor)
    elif PartialState().distributed_type == DistributedType.MULTI_XPU:
        return _gpu_gather(tensor)
    elif PartialState().distributed_type == DistributedType.MULTI_CPU:
        return _cpu_gather(tensor)
    else:
        return tensor
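

# Illustrative usage (not part of the original module), assuming a hypothetical run on
# 2 processes where each process holds a tensor of shape (4, 10):
#
#     >>> local_logits = torch.randn(4, 10)  # per-process tensor
#     >>> all_logits = gather(local_logits)
#     >>> all_logits.shape  # concatenated along dim 0 across the 2 processes
#     torch.Size([8, 10])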


def _gpu_gather_object(object: Any):
    output_objects = [None for _ in range(PartialState().num_processes)]
    torch.distributed.all_gather_object(output_objects, object)
    # all_gather_object returns a list of lists, so we need to flatten it
    return [x for y in output_objects for x in y]


_cpu_gather_object = _gpu_gather_object


def gather_object(object: Any):
    """
    Recursively gather object in a nested list/tuple/dictionary of objects from all devices.

    Args:
        object (nested list/tuple/dictionary of picklable object):
            The data to gather.

    Returns:
        The same data structure as `object` with all the objects sent to every device.
    """
    if PartialState().distributed_type == DistributedType.TPU:
        raise NotImplementedError("gather objects in TPU is not supported")
    elif PartialState().distributed_type in CUDA_DISTRIBUTED_TYPES:
        return _gpu_gather_object(object)
    elif PartialState().distributed_type == DistributedType.MULTI_NPU:
        return _gpu_gather_object(object)
    elif PartialState().distributed_type == DistributedType.MULTI_XPU:
        return _gpu_gather_object(object)
    elif PartialState().distributed_type == DistributedType.MULTI_CPU:
        return _cpu_gather_object(object)
    else:
        return object


def _gpu_broadcast(data, src=0):
    def _gpu_broadcast_one(tensor, src=0):
        torch.distributed.broadcast(tensor, src=src)
        return tensor

    return recursively_apply(_gpu_broadcast_one, data, error_on_other_type=True, src=src)


def _tpu_broadcast(tensor, src=0, name="broadcast tensor"):
    if isinstance(tensor, (list, tuple)):
        # Pass `src` down so nested structures are broadcast from the same process.
        return honor_type(tensor, (_tpu_broadcast(t, src=src, name=f"{name}_{i}") for i, t in enumerate(tensor)))
    elif isinstance(tensor, Mapping):
        return type(tensor)({k: _tpu_broadcast(v, src=src, name=f"{name}_{k}") for k, v in tensor.items()})
    return xm.mesh_reduce(name, tensor, lambda x: x[src])


def broadcast(tensor, from_process: int = 0):
    """
    Recursively broadcast tensor in a nested list/tuple/dictionary of tensors to all devices.

    Args:
        tensor (nested list/tuple/dictionary of `torch.Tensor`):
            The data to broadcast.
        from_process (`int`, *optional*, defaults to 0):
            The process from which to send the data.

    Returns:
        The same data structure as `tensor` with all tensors broadcasted to the proper device.
    """
    if PartialState().distributed_type == DistributedType.TPU:
        return _tpu_broadcast(tensor, src=from_process, name="accelerate.utils.broadcast")
    elif PartialState().distributed_type in CUDA_DISTRIBUTED_TYPES:
        return _gpu_broadcast(tensor, src=from_process)
    elif PartialState().distributed_type == DistributedType.MULTI_NPU:
        return _gpu_broadcast(tensor, src=from_process)
    elif PartialState().distributed_type == DistributedType.MULTI_XPU:
        return _gpu_broadcast(tensor, src=from_process)
    elif PartialState().distributed_type == DistributedType.MULTI_CPU:
        return _gpu_broadcast(tensor, src=from_process)
    else:
        return tensor


def broadcast_object_list(object_list, from_process: int = 0):
    """
    Broadcast a list of picklable objects from one process to the others.

    Args:
        object_list (list of picklable objects):
            The list of objects to broadcast. This list will be modified inplace.
        from_process (`int`, *optional*, defaults to 0):
            The process from which to send the data.

    Returns:
        The same list containing the objects from process `from_process`.
    """
    if PartialState().distributed_type == DistributedType.TPU:
        for i, obj in enumerate(object_list):
            object_list[i] = xm.mesh_reduce("accelerate.utils.broadcast_object_list", obj, lambda x: x[from_process])
    elif PartialState().distributed_type in CUDA_DISTRIBUTED_TYPES:
        torch.distributed.broadcast_object_list(object_list, src=from_process)
    elif PartialState().distributed_type == DistributedType.MULTI_NPU:
        torch.distributed.broadcast_object_list(object_list, src=from_process)
    elif PartialState().distributed_type == DistributedType.MULTI_XPU:
        torch.distributed.broadcast_object_list(object_list, src=from_process)
    elif PartialState().distributed_type == DistributedType.MULTI_CPU:
        torch.distributed.broadcast_object_list(object_list, src=from_process)
    return object_list
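

# Illustrative usage (not part of the original module), assuming a hypothetical
# distributed run: every process ends up with process 0's list contents.
#
#     >>> objects = [None, None]
#     >>> if PartialState().is_main_process:
#     ...     objects = ["config", {"lr": 1e-3}]
#     >>> objects = broadcast_object_list(objects)  # now identical on every process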


def slice_tensors(data, tensor_slice):
    """
    Recursively takes a slice in a nested list/tuple/dictionary of tensors.

    Args:
        data (nested list/tuple/dictionary of `torch.Tensor`):
            The data to slice.
        tensor_slice (`slice`):
            The slice to take.

    Returns:
        The same data structure as `data` with all the tensors sliced.
    """

    def _slice_tensor(tensor, tensor_slice):
        return tensor[tensor_slice]

    return recursively_apply(_slice_tensor, data, tensor_slice)
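

# Illustrative usage (not part of the original module): taking the first two rows of
# every tensor in the structure.
#
#     >>> data = {"x": torch.arange(10).reshape(5, 2)}
#     >>> slice_tensors(data, slice(0, 2))["x"].shape
#     torch.Size([2, 2])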


def concatenate(data, dim=0):
    """
    Recursively concatenate the tensors in a nested list/tuple/dictionary of lists of tensors with the same shape.

    Args:
        data (nested list/tuple/dictionary of lists of tensors `torch.Tensor`):
            The data to concatenate.
        dim (`int`, *optional*, defaults to 0):
            The dimension on which to concatenate.

    Returns:
        The same data structure as `data` with all the tensors concatenated.
    """
    if isinstance(data[0], (tuple, list)):
        return honor_type(data[0], (concatenate([d[i] for d in data], dim=dim) for i in range(len(data[0]))))
    elif isinstance(data[0], Mapping):
        return type(data[0])({k: concatenate([d[k] for d in data], dim=dim) for k in data[0].keys()})
    elif not isinstance(data[0], torch.Tensor):
        raise TypeError(f"Can only concatenate tensors but got {type(data[0])}")
    return torch.cat(data, dim=dim)
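

# Illustrative usage (not part of the original module): concatenating a list of
# same-structure batches along the batch dimension.
#
#     >>> batches = [{"x": torch.ones(2, 3)}, {"x": torch.zeros(4, 3)}]
#     >>> concatenate(batches)["x"].shape
#     torch.Size([6, 3])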


def pad_across_processes(tensor, dim=0, pad_index=0, pad_first=False):
    """
    Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so
    they can safely be gathered.

    Args:
        tensor (nested list/tuple/dictionary of `torch.Tensor`):
            The data to gather.
        dim (`int`, *optional*, defaults to 0):
            The dimension on which to pad.
        pad_index (`int`, *optional*, defaults to 0):
            The value with which to pad.
        pad_first (`bool`, *optional*, defaults to `False`):
            Whether to pad at the beginning or the end.
    """

    def _pad_across_processes(tensor, dim=0, pad_index=0, pad_first=False):
        if dim >= len(tensor.shape):
            return tensor
        # Gather all sizes
        size = torch.tensor(tensor.shape, device=tensor.device)[None]
        sizes = gather(size).cpu()
        # Then pad to the maximum size
        max_size = max(s[dim] for s in sizes)
        if max_size == tensor.shape[dim]:
            return tensor
        old_size = tensor.shape
        new_size = list(old_size)
        new_size[dim] = max_size
        new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index
        if pad_first:
            indices = tuple(
                slice(max_size - old_size[dim], max_size) if i == dim else slice(None) for i in range(len(new_size))
            )
        else:
            indices = tuple(slice(0, old_size[dim]) if i == dim else slice(None) for i in range(len(new_size)))
        new_tensor[indices] = tensor
        return new_tensor

    return recursively_apply(
        _pad_across_processes, tensor, error_on_other_type=True, dim=dim, pad_index=pad_index, pad_first=pad_first
    )
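

# Illustrative usage (not part of the original module), assuming a hypothetical
# 2-process run where sequence lengths differ: shorter tensors are padded before
# gathering.
#
#     >>> tokens = torch.ones(4, 7, dtype=torch.long)  # length 7 here, 9 on the other process
#     >>> padded = pad_across_processes(tokens, dim=1, pad_index=0)
#     >>> padded.shape  # padded up to the maximum length across processes
#     torch.Size([4, 9])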


def reduce(tensor, reduction="mean"):
    """
    Recursively reduce the tensors in a nested list/tuple/dictionary of lists of tensors across all processes with a
    given operation.

    Args:
        tensor (nested list/tuple/dictionary of `torch.Tensor`):
            The data to reduce.
        reduction (`str`, *optional*, defaults to `"mean"`):
            A reduction method. Can be one of "mean", "sum", or "none".

    Returns:
        The same data structure as `tensor` with all the tensors reduced.
    """

    def _reduce_across_processes(tensor, reduction="mean"):
        state = PartialState()
        cloned_tensor = tensor.clone()
        if state.distributed_type == DistributedType.NO:
            return cloned_tensor
        if state.distributed_type == DistributedType.TPU:
            xm.all_reduce("sum", cloned_tensor)
        elif state.distributed_type in CUDA_DISTRIBUTED_TYPES:
            torch.distributed.all_reduce(cloned_tensor, ReduceOp.SUM)
        elif state.distributed_type == DistributedType.MULTI_NPU:
            torch.distributed.all_reduce(cloned_tensor, ReduceOp.SUM)
        elif state.distributed_type == DistributedType.MULTI_XPU:
            torch.distributed.all_reduce(cloned_tensor, ReduceOp.SUM)
        elif state.distributed_type == DistributedType.MULTI_CPU:
            torch.distributed.all_reduce(cloned_tensor, ReduceOp.SUM)
        if reduction == "mean":
            cloned_tensor /= state.num_processes
        return cloned_tensor

    return recursively_apply(_reduce_across_processes, tensor, error_on_other_type=True, reduction=reduction)
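

# Illustrative usage (not part of the original module), assuming a hypothetical
# 2-process run where process 0 holds tensor(2.) and process 1 holds tensor(4.):
#
#     >>> reduce(torch.tensor(2.0), reduction="mean")  # on process 0
#     tensor(3.)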


def convert_to_fp32(tensor):
    """
    Recursively converts the elements of a nested list/tuple/dictionary of tensors in FP16/BF16 precision to FP32.

    Args:
        tensor (nested list/tuple/dictionary of `torch.Tensor`):
            The data to convert from FP16/BF16 to FP32.

    Returns:
        The same data structure as `tensor` with all tensors that were in FP16/BF16 precision converted to FP32.
    """

    def _convert_to_fp32(tensor):
        return tensor.float()

    def _is_fp16_bf16_tensor(tensor):
        return hasattr(tensor, "dtype") and tensor.dtype in (torch.float16, torch.bfloat16)

    return recursively_apply(_convert_to_fp32, tensor, test_type=_is_fp16_bf16_tensor)
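

# Illustrative usage (not part of the original module): only half-precision tensors
# are upcast; other dtypes pass through untouched.
#
#     >>> mixed = {"h": torch.ones(2, dtype=torch.float16), "i": torch.ones(2, dtype=torch.int64)}
#     >>> {k: v.dtype for k, v in convert_to_fp32(mixed).items()}
#     {'h': torch.float32, 'i': torch.int64}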


class ConvertOutputsToFp32:
    """
    Decorator to apply to a function outputting tensors (like a model forward pass) that ensures the outputs in FP16
    precision will be converted back to FP32.

    Args:
        model_forward (`Callable`):
            The function whose outputs we want to treat.

    Returns:
        The same function as `model_forward` but with converted outputs.
    """

    def __init__(self, model_forward):
        self.model_forward = model_forward
        update_wrapper(self, model_forward)

    def __call__(self, *args, **kwargs):
        return convert_to_fp32(self.model_forward(*args, **kwargs))

    def __getstate__(self):
        raise pickle.PicklingError(
            "Cannot pickle a prepared model with automatic mixed precision, please unwrap the model with"
            " `Accelerator.unwrap_model(model)` before pickling it."
        )


def convert_outputs_to_fp32(model_forward):
    model_forward = ConvertOutputsToFp32(model_forward)

    def forward(*args, **kwargs):
        return model_forward(*args, **kwargs)

    # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
    forward.__wrapped__ = model_forward

    return forward


def find_device(data):
    """
    Finds the device on which a nested dict/list/tuple of tensors lies (assuming they are all on the same device).

    Args:
        data (nested list/tuple/dictionary of `torch.Tensor`): The data we want to know the device of.
    """
    if isinstance(data, Mapping):
        for obj in data.values():
            device = find_device(obj)
            if device is not None:
                return device
    elif isinstance(data, (tuple, list)):
        for obj in data:
            device = find_device(obj)
            if device is not None:
                return device
    elif isinstance(data, torch.Tensor):
        return data.device
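

# Illustrative usage (not part of the original module): the first tensor found
# determines the reported device; `None` is returned when no tensor is present.
#
#     >>> find_device({"a": [torch.ones(1)]})
#     device(type='cpu')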