import math
import warnings
from collections.abc import Sequence
from functools import partial
from typing import Any, Callable, Optional, Tuple, Union

import torch
from torch import nn

from .fc import FC_CLASS_REGISTRY
from .norm import NORM_CLASS_REGISTRY

# transformer_engine is optional; fall back to None if it is not installed.
try:
    import transformer_engine.pytorch as te
except ImportError:
    te = None

def torch_default_param_init_fn_(module: nn.Module, **kwargs: Any) -> None:
    del kwargs
    if hasattr(module, "reset_parameters") and isinstance(
        module.reset_parameters, Callable
    ):
        module.reset_parameters()

def fused_init_helper_(module: nn.Module, init_fn_: Callable) -> None:
    # Initialize each block of a fused parameter (e.g. a fused QKV projection)
    # independently, using the split points recorded in `module._fused`.
    _fused = getattr(module, "_fused", None)
    if _fused is None:
        raise RuntimeError(
            "fused_init_helper_ called on a module without a _fused attribute."
        )
    assert isinstance(module.weight, torch.Tensor)
    (dim, splits) = _fused
    splits = (0, *splits, module.weight.size(dim))
    for s, e in zip(splits[:-1], splits[1:]):
        slice_indices = [slice(None)] * module.weight.ndim
        slice_indices[dim] = slice(s, e)
        init_fn_(module.weight[tuple(slice_indices)])

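# A minimal sketch of the `_fused` convention assumed by fused_init_helper_
# (the attribute itself is set by model code elsewhere; this example is purely
# illustrative). For a hypothetical fused QKV projection of width d_model, the
# module could be tagged as
#
#     qkv = nn.Linear(d_model, 3 * d_model)
#     qkv._fused = (0, (d_model, 2 * d_model))  # split dim 0 at d_model and 2 * d_model
#
# so the Q, K, and V blocks are each initialized independently rather than as
# one oversized matrix, which would change the fan-out seen by Xavier/Kaiming
# style initializers.
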
def generic_param_init_fn_(
    module: nn.Module,
    init_fn_: Callable,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    **kwargs: Any,
) -> None:
    del kwargs
    # Resolve the divisor applied to weights that feed the residual stream.
    if init_div_is_residual is False:
        div_is_residual = 1.0
    elif init_div_is_residual is True:
        div_is_residual = math.sqrt(2 * n_layers)
    elif isinstance(init_div_is_residual, (float, int)):
        div_is_residual = init_div_is_residual
    elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric():
        div_is_residual = float(init_div_is_residual)
    else:
        raise ValueError(
            f"Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}"
        )
    if isinstance(module, tuple(set(FC_CLASS_REGISTRY.values()))):
        if hasattr(module, "_fused"):
            fused_init_helper_(module, init_fn_)
        else:
            init_fn_(module.weight)
        if module.bias is not None:
            assert isinstance(module.bias, torch.Tensor)
            torch.nn.init.zeros_(module.bias)
        if init_div_is_residual is not False and getattr(module, "_is_residual", False):
            with torch.no_grad():
                module.weight.div_(div_is_residual)
    elif isinstance(module, nn.Embedding):
        if emb_init_std is not None:
            std = emb_init_std
            if std == 0:
                warnings.warn("Embedding layer initialized to 0.")
            emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
        elif emb_init_uniform_lim is not None:
            lim = emb_init_uniform_lim
            if isinstance(lim, Sequence):
                if len(lim) > 2:
                    raise ValueError(
                        f"Uniform init requires a min and a max limit. User input: {lim}."
                    )
                if lim[0] == lim[1]:
                    warnings.warn(f"Embedding layer initialized to {lim[0]}.")
            else:
                if lim == 0:
                    warnings.warn("Embedding layer initialized to 0.")
                lim = [-lim, lim]
            (a, b) = lim
            emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
        else:
            emb_init_fn_ = init_fn_
        emb_init_fn_(module.weight)
    elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
        if hasattr(module, "weight") and isinstance(module.weight, torch.Tensor):
            torch.nn.init.ones_(module.weight)
        if hasattr(module, "bias") and isinstance(module.bias, torch.Tensor):
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.MultiheadAttention):
        if module._qkv_same_embed_dim:
            assert module.in_proj_weight is not None
            assert (
                module.q_proj_weight is None
                and module.k_proj_weight is None
                and module.v_proj_weight is None
            )
            assert d_model is not None
            _d = d_model
            splits = (0, _d, 2 * _d, 3 * _d)
            for s, e in zip(splits[:-1], splits[1:]):
                init_fn_(module.in_proj_weight[s:e])
        else:
            assert (
                module.q_proj_weight is not None
                and module.k_proj_weight is not None
                and module.v_proj_weight is not None
            )
            assert module.in_proj_weight is None
            init_fn_(module.q_proj_weight)
            init_fn_(module.k_proj_weight)
            init_fn_(module.v_proj_weight)
        if module.in_proj_bias is not None:
            torch.nn.init.zeros_(module.in_proj_bias)
        if module.bias_k is not None:
            torch.nn.init.zeros_(module.bias_k)
        if module.bias_v is not None:
            torch.nn.init.zeros_(module.bias_v)
        init_fn_(module.out_proj.weight)
        if init_div_is_residual is not False and getattr(
            module.out_proj, "_is_residual", False
        ):
            with torch.no_grad():
                module.out_proj.weight.div_(div_is_residual)
        if module.out_proj.bias is not None:
            torch.nn.init.zeros_(module.out_proj.bias)
    elif te is not None and isinstance(module, te.LayerNormMLP):
        if isinstance(module.layer_norm_weight, torch.Tensor):
            torch.nn.init.ones_(module.layer_norm_weight)
        if isinstance(module.layer_norm_bias, torch.Tensor):
            torch.nn.init.zeros_(module.layer_norm_bias)
        init_fn_(module.fc1_weight)
        if module.fc1_bias is not None:
            assert isinstance(module.fc1_bias, torch.Tensor)
            torch.nn.init.zeros_(module.fc1_bias)
        init_fn_(module.fc2_weight)
        if module.fc2_bias is not None:
            assert isinstance(module.fc2_bias, torch.Tensor)
            torch.nn.init.zeros_(module.fc2_bias)
        # fc2 projects back into the residual stream, so its weight gets the
        # residual rescaling.
        with torch.no_grad():
            module.fc2_weight.div_(div_is_residual)
    else:
        # Only error out if this module directly owns parameters that none of
        # the branches above handled; parameter-free containers pass through.
        for _ in module.parameters(recurse=False):
            raise NotImplementedError(
                f"{module.__class__.__name__} parameters are not initialized by param_init_fn."
            )

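# Modules flagged with `_is_residual` (set by model code elsewhere) have their
# weights divided by `div_is_residual` after initialization. With the default
# init_div_is_residual=True this is the GPT-2-style rescaling of residual
# projections by sqrt(2 * n_layers); e.g. for n_layers=32 the divisor is
# sqrt(64) = 8.
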
def _normal_init_(std: float, mean: float = 0.0) -> Callable:
    return partial(torch.nn.init.normal_, mean=mean, std=std)


def _normal_param_init_fn_(
    module: nn.Module,
    std: float,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    **kwargs: Any,
) -> None:
    del kwargs
    init_fn_ = _normal_init_(std=std)
    generic_param_init_fn_(
        module=module,
        init_fn_=init_fn_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )

def baseline_param_init_fn_(
    module: nn.Module,
    init_std: Optional[float],
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    **kwargs: Any,
) -> None:
    del kwargs
    if init_std is None:
        raise ValueError(
            "You must set model.init_config['init_std'] to a float value to use the default initialization scheme."
        )
    _normal_param_init_fn_(
        module=module,
        std=init_std,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )

def small_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: int,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    **kwargs: Any,
) -> None:
    del kwargs
    std = math.sqrt(2 / (5 * d_model))
    _normal_param_init_fn_(
        module=module,
        std=std,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )

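# For reference, this "small init" standard deviation sqrt(2 / (5 * d_model))
# appears to follow the scheme of Nguyen & Salazar (2019), "Transformers
# without Tears", as adopted by the GPT-NeoX initialization code referenced in
# the next function; e.g. for d_model = 4096 it gives std ≈ 0.0099.
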
def neox_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: int,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    **kwargs: Any,
) -> None:
    """From section 2.3.1 of GPT-NeoX-20B: An Open-Source Autoregressive
    Language Model, Black et al. (2022).

    See https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151
    and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py
    """
    del kwargs
    residual_div = n_layers / math.sqrt(10)
    small_param_init_fn_(
        module=module,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=residual_div,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )

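# Why n_layers / sqrt(10): combined with the small-init std used by
# small_param_init_fn_, residual-branch weights end up with
#     sqrt(2 / (5 * d_model)) / (n_layers / sqrt(10)) = 2 / (n_layers * sqrt(d_model)),
# which matches the 2 / (L * sqrt(d)) scale the GPT-NeoX-20B setup uses for
# output/residual projections (section 2.3.1 cited above).
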
def kaiming_uniform_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    fan_mode: str = "fan_in",
    init_nonlinearity: str = "leaky_relu",
    **kwargs: Any,
) -> None:
    del kwargs
    # Note: `init_gain` is passed to Kaiming init's `a` parameter, i.e. the
    # negative slope of leaky_relu; with the default of 0 this is equivalent
    # to standard ReLU Kaiming initialization.
    kaiming_uniform_ = partial(
        nn.init.kaiming_uniform_,
        a=init_gain,
        mode=fan_mode,
        nonlinearity=init_nonlinearity,
    )
    generic_param_init_fn_(
        module=module,
        init_fn_=kaiming_uniform_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )

def kaiming_normal_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    fan_mode: str = "fan_in",
    init_nonlinearity: str = "leaky_relu",
    **kwargs: Any,
) -> None:
    del kwargs
    kaiming_normal_ = partial(
        torch.nn.init.kaiming_normal_,
        a=init_gain,
        mode=fan_mode,
        nonlinearity=init_nonlinearity,
    )
    generic_param_init_fn_(
        module=module,
        init_fn_=kaiming_normal_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )

def xavier_uniform_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    **kwargs: Any,
) -> None:
    del kwargs
    xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
    generic_param_init_fn_(
        module=module,
        init_fn_=xavier_uniform_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )

def xavier_normal_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    **kwargs: Any,
) -> None:
    del kwargs
    xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
    generic_param_init_fn_(
        module=module,
        init_fn_=xavier_normal_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )

MODEL_INIT_REGISTRY = {
    "default_": torch_default_param_init_fn_,
    "baseline_": baseline_param_init_fn_,
    "kaiming_uniform_": kaiming_uniform_param_init_fn_,
    "kaiming_normal_": kaiming_normal_param_init_fn_,
    "neox_init_": neox_param_init_fn_,
    "small_init_": small_param_init_fn_,
    "xavier_uniform_": xavier_uniform_param_init_fn_,
    "xavier_normal_": xavier_normal_param_init_fn_,
}
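
# Minimal usage sketch (illustrative, not part of the library API): registry
# entries are applied per-module via nn.Module.apply, with model-wide settings
# bound through functools.partial. The toy model below and the assumption that
# FC_CLASS_REGISTRY resolves plain nn.Linear layers are examples only; run as
# a module (python -m <package>.param_init_fns) so the relative imports work.
if __name__ == "__main__":
    toy_model = nn.Sequential(
        nn.Embedding(128, 64),
        nn.Linear(64, 64),
        nn.ReLU(),
        nn.Linear(64, 64),
    )
    init_fn = MODEL_INIT_REGISTRY["kaiming_normal_"]
    # n_layers drives the residual rescaling (sqrt(2 * n_layers) by default);
    # d_model is only strictly needed for nn.MultiheadAttention modules.
    toy_model.apply(partial(init_fn, n_layers=2, d_model=64))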