# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import abc as container_abcs
from collections import defaultdict
from copy import deepcopy
from itertools import chain

import torch

import bitsandbytes.functional as F


class MockArgs:
    def __init__(self, initial_data):
        for key in initial_data:
            setattr(self, key, initial_data[key])


class GlobalOptimManager:
    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        self.pid2config = {}
        self.index2config = {}
        self.optimizer = None
        self.uses_config_override = False
        self.module_weight_config_triple = []

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls.__new__(cls)
            cls._instance.initialize()
        return cls._instance

    def register_parameters(self, params):
        param_groups = list(params)
        if not isinstance(param_groups[0], dict):
            param_groups = [{"params": param_groups}]

        for group_index, group in enumerate(param_groups):
            for p_index, p in enumerate(group["params"]):
                if id(p) in self.pid2config:
                    self.index2config[(group_index, p_index)] = self.pid2config[
                        id(p)
                    ]

    def override_config(
        self, parameters, key=None, value=None, key_value_dict=None
    ):
        """
        Overrides the initial optimizer config for specific parameters.

        The key-values of the optimizer config for the input parameters are
        overridden. This can be optimizer parameters like "betas" or "lr", or
        8-bit specific parameters like "optim_bits" or "percentile_clipping".

        Parameters
        ----------
        parameters : torch.Tensor or list(torch.Tensor)
            The input parameters.
        key : str
            The hyperparameter to override.
        value : object
            The value for the hyperparameter.
        key_value_dict : dict
            A dictionary with multiple key-values to override.
        """
        self.uses_config_override = True
        if isinstance(parameters, torch.nn.Parameter):
            parameters = [parameters]
        if isinstance(parameters, torch.Tensor):
            parameters = [parameters]
        if key is not None and value is not None:
            assert key_value_dict is None
            key_value_dict = {key: value}

        if key_value_dict is not None:
            for p in parameters:
                if id(p) in self.pid2config:
                    self.pid2config[id(p)].update(key_value_dict)
                else:
                    self.pid2config[id(p)] = key_value_dict

    def register_module_override(self, module, param_name, config):
        self.module_weight_config_triple.append((module, param_name, config))
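

# Illustrative usage of the config-override mechanism (a hedged sketch; `model`
# stands in for any nn.Module and is not defined in this file): route one
# parameter to 32-bit optimizer state while everything else may stay 8-bit.
#
#   mng = GlobalOptimManager.get_instance()
#   mng.override_config(model.emb.weight, "optim_bits", 32)
#   # or several keys at once:
#   mng.override_config(
#       model.emb.weight,
#       key_value_dict={"optim_bits": 32, "percentile_clipping": 95},
#   )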


class Optimizer8bit(torch.optim.Optimizer):
    def __init__(self, params, defaults, optim_bits=32, is_paged=False):
        super().__init__(params, defaults)
        self.initialized = False
        self.name2qmap = {}
        self.is_paged = is_paged
        self.page_mng = F.GlobalPageManager.get_instance()

        self.mng = GlobalOptimManager.get_instance()
        self.non_castable_tensor_keys = {
            "qmap1",
            "qmap2",
            "max1",
            "max2",
            "new_max1",
            "new_max2",
            "state1",
            "state2",
            "gnorm_vec",
            "absmax1",
            "absmax2",
            "unorm_vec",
        }
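        # The keys above hold quantization maps and running statistics of the
        # optimizer state; load_state_dict() moves them to the parameter's
        # device but never casts their dtype (see cast() below).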

        if optim_bits == 8:
            self.fill_qmap()

    def fill_qmap(self):
        self.name2qmap["dynamic"] = F.create_dynamic_map(signed=True)
        self.name2qmap["udynamic"] = F.create_dynamic_map(signed=False)

    def __setstate__(self, state):
        super().__setstate__(state)

    def load_state_dict(self, state_dict):
        r"""Loads the optimizer state.

        Args:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        state_dict = deepcopy(state_dict)
        # Validate the state_dict
        groups = self.param_groups
        saved_groups = state_dict["param_groups"]

        if len(groups) != len(saved_groups):
            raise ValueError(
                "loaded state dict has a different number of "
                "parameter groups"
            )
        param_lens = (len(g["params"]) for g in groups)
        saved_lens = (len(g["params"]) for g in saved_groups)
        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
            raise ValueError(
                "loaded state dict contains a parameter group "
                "that doesn't match the size of optimizer's group"
            )

        # Update the state
        id_map = {
            old_id: p
            for old_id, p in zip(
                chain.from_iterable(g["params"] for g in saved_groups),
                chain.from_iterable(g["params"] for g in groups),
            )
        }

        def cast(param, value):
            r"""Make a deep copy of value, casting all tensors to device of param."""
            if isinstance(value, torch.Tensor):
                # Floating-point types are a bit special here. They are the only ones
                # that are assumed to always match the type of params.
                if param.is_floating_point() and value.dtype != torch.uint8:
                    value = value.to(param.dtype)
                return value
            elif isinstance(value, dict):
                for k, v in value.items():
                    if k in self.non_castable_tensor_keys:
                        value[k] = v.to(param.device)
                    else:
                        value[k] = cast(param, v)

                return value
            elif isinstance(value, container_abcs.Iterable):
                return type(value)(cast(param, v) for v in value)
            else:
                return value

        # Copy state assigned to params (and cast tensors to appropriate types).
        # State that is not assigned to params is copied as is (needed for
        # backward compatibility).
        state = defaultdict(dict)
        for k, v in state_dict["state"].items():
            if k in id_map:
                param = id_map[k]
                state[param] = cast(param, v)
            else:
                state[k] = v

        # Update parameter groups, setting their 'params' value
        def update_group(group, new_group):
            new_group["params"] = group["params"]
            return new_group

        param_groups = [
            update_group(g, ng) for g, ng in zip(groups, saved_groups)
        ]
        self.__setstate__({"state": state, "param_groups": param_groups})
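
    # Illustrative save/load round trip (a sketch using only the standard
    # torch serialization API; "opt.pt" is a placeholder path). The
    # quantization statistics named in non_castable_tensor_keys survive the
    # cycle because cast() above only moves them across devices and never
    # changes their dtype.
    #
    #   torch.save(opt.state_dict(), "opt.pt")
    #   opt.load_state_dict(torch.load("opt.pt"))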

    def to_gpu(self):
        for gindex, group in enumerate(self.param_groups):
            for pindex, p in enumerate(group["params"]):
                if p in self.state:
                    values = self.state[p]
                    for k, v in values.items():
                        if isinstance(v, torch.Tensor):
                            is_paged = getattr(v, 'is_paged', False)
                            if not is_paged:
                                self.state[p][k] = v.to(p.device)

    def check_overrides(self):
        for module, attr, config in self.mng.module_weight_config_triple:
            pmodule = getattr(module, attr)
            assert pmodule is not None
            assert isinstance(pmodule, torch.Tensor) or isinstance(
                pmodule, torch.nn.Parameter
            )
            found = False
            for gindex, group in enumerate(self.param_groups):
                if found:
                    break
                for pindex, p in enumerate(group["params"]):
                    if found:
                        break
                    if id(p) == id(pmodule):
                        # found the matching parameter
                        # init override
                        self.mng.pid2config[id(p)] = config
                        self.mng.index2config[
                            (gindex, pindex)
                        ] = self.mng.pid2config[id(p)]
                        found = True

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        overflows = []

        if not self.initialized:
            self.check_overrides()
            self.to_gpu()  # needed for fairseq pure fp16 training
            self.initialized = True

        #if self.is_paged: self.page_mng.prefetch_all()
        for gindex, group in enumerate(self.param_groups):
            for pindex, p in enumerate(group["params"]):
                if p.grad is None:
                    continue
                state = self.state[p]
                if len(state) == 0:
                    self.init_state(group, p, gindex, pindex)

                self.prefetch_state(p)
                self.update_step(group, p, gindex, pindex)
                torch.cuda.synchronize()
        if self.is_paged:
            # all paged operations are asynchronous, we need
            # to sync to make sure all tensors are in the right state
            torch.cuda.synchronize()

        return loss

    def get_config(self, gindex, pindex, group):
        config = {}
        config["betas"] = group["betas"]
        config["eps"] = group["eps"]
        config["weight_decay"] = group["weight_decay"]
        config["lr"] = group["lr"]
        config["optim_bits"] = self.args.optim_bits
        config["min_8bit_size"] = self.args.min_8bit_size
        config["percentile_clipping"] = self.args.percentile_clipping
        config["block_wise"] = self.args.block_wise
        config["max_unorm"] = self.args.max_unorm
        config["skip_zeros"] = self.args.skip_zeros

        if (gindex, pindex) in self.mng.index2config:
            config.update(self.mng.index2config[(gindex, pindex)])
        return config

    def init_state(self, group, p, gindex, pindex):
        raise NotImplementedError("init_state method needs to be overridden")

    def update_step(self, group, p, gindex, pindex):
        raise NotImplementedError(
            "The update_step method needs to be overridden"
        )

    def get_state_buffer(self, p, dtype=torch.float32):
        if not self.is_paged or p.numel() < 1e5:
            return torch.zeros_like(p, dtype=dtype, device=p.device)
        else:
            # > 1 MB
            buff = F.get_paged(*p.shape, dtype=dtype, device=p.device)
            F.fill(buff, 0)
            self.page_mng.paged_tensors.append(buff)
            return buff
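
    # Note: the paged branch above is taken only for tensors with at least
    # 1e5 elements when is_paged=True; smaller optimizer states always live
    # in regular device memory.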

    def prefetch_state(self, p):
        if self.is_paged:
            state = self.state[p]
            s1 = state['state1']
            is_paged = getattr(s1, 'is_paged', False)
            if is_paged:
                F.prefetch_tensor(state['state1'])
                if 'state2' in state:
                    F.prefetch_tensor(state['state2'])
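

# Concrete optimizers subclass Optimizer8bit and implement init_state() and
# update_step(). Optimizer2State keeps two running statistics per parameter
# (e.g. first and second moments for Adam-style rules), while Optimizer1State
# keeps a single one (e.g. a momentum buffer).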


class Optimizer2State(Optimizer8bit):
    def __init__(
        self,
        optimizer_name,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.0,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
        max_unorm=0.0,
        skip_zeros=False,
        is_paged=False
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if isinstance(betas, str):
            # format: '(beta1, beta2)'
            betas = betas.replace("(", "").replace(")", "").strip().split(",")
            betas = [float(b) for b in betas]
        for i in range(len(betas)):
            if not 0.0 <= betas[i] < 1.0:
                raise ValueError(
                    f"Invalid beta parameter at index {i}: {betas[i]}"
                )
        if not 0.0 <= weight_decay:
            raise ValueError(
                f"Invalid weight_decay value: {weight_decay}"
            )
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults, optim_bits, is_paged)

        if args is None:
            args = {}
            args["optim_bits"] = optim_bits
            args["percentile_clipping"] = 100
            args["min_8bit_size"] = min_8bit_size
            args["percentile_clipping"] = percentile_clipping
            args["block_wise"] = block_wise
            args["max_unorm"] = max_unorm
            args["skip_zeros"] = skip_zeros

            self.args = MockArgs(args)
        else:
            self.args = args

        self.optimizer_name = optimizer_name

    def init_state(self, group, p, gindex, pindex):
        config = self.get_config(gindex, pindex, group)

        if config["optim_bits"] == 32:
            dtype = torch.float32
        elif config["optim_bits"] == 8:
            dtype = torch.uint8
        else:
            raise NotImplementedError(
                f'Amount of optimizer bits not supported: {config["optim_bits"]}'
            )

        if p.numel() < config["min_8bit_size"]:
            dtype = torch.float32

        state = self.state[p]
        state["step"] = 0

        if dtype == torch.float32 or (
            dtype == torch.uint8 and p.numel() < 4096
        ):
            state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
            state["state2"] = self.get_state_buffer(p, dtype=torch.float32)
        elif dtype == torch.uint8:
            if state["step"] == 0:
                if "dynamic" not in self.name2qmap:
                    self.fill_qmap()
                self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
                    p.device
                )
                self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(
                    p.device
                )

            state["state1"] = self.get_state_buffer(p, dtype=torch.uint8)
            state["qmap1"] = self.name2qmap["dynamic"]

            state["state2"] = self.get_state_buffer(p, dtype=torch.uint8)
            state["qmap2"] = self.name2qmap["udynamic"]

            if config["block_wise"]:
                # block-wise quantization keeps one absmax scale per
                # 2048-element block of the state tensor
                n = p.numel()
                blocks = n // 2048
                blocks += 1 if n % 2048 > 0 else 0

                state["absmax1"] = torch.zeros(
                    (blocks,), dtype=torch.float32, device=p.device
                )
                state["absmax2"] = torch.zeros(
                    (blocks,), dtype=torch.float32, device=p.device
                )
            else:
                state["max1"] = torch.zeros(
                    (1,), dtype=torch.float32, device=p.device
                )
                state["new_max1"] = torch.zeros(
                    (1,), dtype=torch.float32, device=p.device
                )
                state["max2"] = torch.zeros(
                    (1,), dtype=torch.float32, device=p.device
                )
                state["new_max2"] = torch.zeros(
                    (1,), dtype=torch.float32, device=p.device
                )

        if config["percentile_clipping"] < 100:
            state["gnorm_vec"] = torch.zeros((100,), device=p.device)

        if config["max_unorm"] > 0.0:
            state["unorm_vec"] = torch.zeros((1,), device=p.device)

    def update_step(self, group, p, gindex, pindex):
        state = self.state[p]
        grad = p.grad

        config = self.get_config(gindex, pindex, group)

        state["step"] += 1
        step = state["step"]

        if config["percentile_clipping"] < 100:
            current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
                grad, state["gnorm_vec"], step, config["percentile_clipping"]
            )
        else:
            gnorm_scale = 1.0

        if state["state1"].dtype == torch.float:
            F.optimizer_update_32bit(
                self.optimizer_name,
                grad,
                p,
                state["state1"],
                config["betas"][0],
                config["eps"],
                step,
                config["lr"],
                state["state2"],
                config["betas"][1],
                config["weight_decay"],
                gnorm_scale,
                state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
                max_unorm=config["max_unorm"],
                skip_zeros=config["skip_zeros"],
            )

        elif state["state1"].dtype == torch.uint8 and not config["block_wise"]:
            F.optimizer_update_8bit(
                self.optimizer_name,
                grad,
                p,
                state["state1"],
                state["state2"],
                config["betas"][0],
                config["betas"][1],
                config["eps"],
                step,
                config["lr"],
                state["qmap1"],
                state["qmap2"],
                state["max1"],
                state["max2"],
                state["new_max1"],
                state["new_max2"],
                config["weight_decay"],
                gnorm_scale=gnorm_scale,
                unorm_vec=state["unorm_vec"]
                if config["max_unorm"] > 0.0
                else None,
                max_unorm=config["max_unorm"],
            )

            # swap maxes
            state["max1"], state["new_max1"] = state["new_max1"], state["max1"]
            state["max2"], state["new_max2"] = state["new_max2"], state["max2"]
        elif state["state1"].dtype == torch.uint8 and config["block_wise"]:
            F.optimizer_update_8bit_blockwise(
                self.optimizer_name,
                grad,
                p,
                state["state1"],
                state["state2"],
                config["betas"][0],
                config["betas"][1],
                config["eps"],
                step,
                config["lr"],
                state["qmap1"],
                state["qmap2"],
                state["absmax1"],
                state["absmax2"],
                config["weight_decay"],
                gnorm_scale=gnorm_scale,
                skip_zeros=config["skip_zeros"],
            )
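

# Illustrative usage of Optimizer2State (a hedged sketch, not a prescribed
# recipe; `model` and `loss` are placeholders). An Adam-style 8-bit optimizer
# with block-wise quantized state can be built by naming the update kernel:
#
#   opt = Optimizer2State("adam", model.parameters(), lr=1e-3,
#                         optim_bits=8, block_wise=True)
#   loss.backward()
#   opt.step()
#   opt.zero_grad()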


class Optimizer1State(Optimizer8bit):
    def __init__(
        self,
        optimizer_name,
        params,
        lr=1e-3,
        betas=(0.9, 0.0),
        eps=1e-8,
        weight_decay=0.0,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
        max_unorm=0.0,
        skip_zeros=False,
        is_paged=False
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        for i in range(len(betas)):
            if not 0.0 <= betas[i] < 1.0:
                raise ValueError(
                    f"Invalid beta parameter at index {i}: {betas[i]}"
                )
        if not 0.0 <= weight_decay:
            raise ValueError(
                f"Invalid weight_decay value: {weight_decay}"
            )
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults, optim_bits, is_paged)

        if args is None:
            args = {}
            args["optim_bits"] = optim_bits
            args["percentile_clipping"] = 100
            args["min_8bit_size"] = min_8bit_size
            args["percentile_clipping"] = percentile_clipping
            args["block_wise"] = block_wise
            args["max_unorm"] = max_unorm
            args["skip_zeros"] = skip_zeros

            self.args = MockArgs(args)
        else:
            self.args = args

        self.optimizer_name = optimizer_name

    def init_state(self, group, p, gindex, pindex):
        config = self.get_config(gindex, pindex, group)

        if config["optim_bits"] == 32:
            dtype = torch.float32
        elif config["optim_bits"] == 8:
            dtype = torch.uint8
        else:
            raise NotImplementedError(
                f'Amount of optimizer bits not supported: {config["optim_bits"]}'
            )

        if p.numel() < config["min_8bit_size"]:
            dtype = torch.float32

        state = self.state[p]
        state["step"] = 0

        if dtype == torch.float32 or (
            dtype == torch.uint8 and p.numel() < 4096
        ):
            state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
        elif dtype == torch.uint8:
            if state["step"] == 0:
                if "dynamic" not in self.name2qmap:
                    self.fill_qmap()
                self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
                    p.device
                )

            state["state1"] = self.get_state_buffer(p, dtype=torch.uint8)
            state["qmap1"] = self.name2qmap["dynamic"]

            if config["block_wise"]:
                n = p.numel()
                blocks = n // 2048
                blocks += 1 if n % 2048 > 0 else 0

                state["absmax1"] = torch.zeros(
                    (blocks,), dtype=torch.float32, device=p.device
                )
            else:
                state["max1"] = torch.zeros(
                    (1,), dtype=torch.float32, device=p.device
                )
                state["new_max1"] = torch.zeros(
                    (1,), dtype=torch.float32, device=p.device
                )

        if config["percentile_clipping"] < 100:
            state["gnorm_vec"] = torch.zeros((100,), device=p.device)

        if config["max_unorm"] > 0.0:
            state["unorm_vec"] = torch.zeros((1,), device=p.device)

    def update_step(self, group, p, gindex, pindex):
        state = self.state[p]
        grad = p.grad

        config = self.get_config(gindex, pindex, group)

        state["step"] += 1
        step = state["step"]

        if config["percentile_clipping"] < 100:
            current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
                grad, state["gnorm_vec"], step, config["percentile_clipping"]
            )
        else:
            gnorm_scale = 1.0

        if state["state1"].dtype == torch.float:
            F.optimizer_update_32bit(
                self.optimizer_name,
                grad,
                p,
                state["state1"],
                config["betas"][0],
                config["eps"],
                step,
                config["lr"],
                None,
                config['betas'][1],
                config["weight_decay"],
                gnorm_scale,
                state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
                max_unorm=config["max_unorm"],
                skip_zeros=config["skip_zeros"],
            )

        elif state["state1"].dtype == torch.uint8 and not config["block_wise"]:
            F.optimizer_update_8bit(
                self.optimizer_name,
                grad,
                p,
                state["state1"],
                None,
                config["betas"][0],
                config["betas"][1],
                config["eps"],
                step,
                config["lr"],
                state["qmap1"],
                None,
                state["max1"],
                None,
                state["new_max1"],
                None,
                config["weight_decay"],
                gnorm_scale,
                state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
                max_unorm=config["max_unorm"],
            )

            state["max1"], state["new_max1"] = state["new_max1"], state["max1"]
        elif state["state1"].dtype == torch.uint8 and config["block_wise"]:
            F.optimizer_update_8bit_blockwise(
                self.optimizer_name,
                grad,
                p,
                state["state1"],
                None,
                config["betas"][0],
                config["betas"][1],
                config["eps"],
                step,
                config["lr"],
                state["qmap1"],
                None,
                state["absmax1"],
                None,
                config["weight_decay"],
                gnorm_scale=gnorm_scale,
                skip_zeros=config["skip_zeros"],
            )
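

# Illustrative usage of Optimizer1State (a hedged sketch; `model` is a
# placeholder). Single-state rules such as momentum SGD or RMSprop fit this
# class, with the first entry of `betas` carrying the single smoothing
# constant:
#
#   opt = Optimizer1State("momentum", model.parameters(), lr=1e-2,
#                         betas=(0.9, 0.0), optim_bits=8)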