import shlex
import subprocess
from typing import Tuple

import torch


def outlier_hook(module, input):
    assert isinstance(module, torch.nn.Linear)
    tracer = OutlierTracer.get_instance()
    hvalue = tracer.get_hvalue(module.weight)
    if hvalue not in tracer.hvalue2outlier_idx:
        outlier_idx = find_outlier_dims(module.weight)
        tracer.outliers.append(outlier_idx)
        tracer.hvalues.append(hvalue)
        if len(tracer.outliers) > 1:
            # assign the current layer the outlier idx found from the weight
            # of the previous linear layer
            if tracer.outliers[-1].numel() > 0:
                assert tracer.outliers[-1].max() < module.weight.shape[1]
            tracer.hvalue2outlier_idx[hvalue] = tracer.outliers[-1]
        else:
            # first layer, we cannot use the weight for outlier detection
            # we follow a mixed approach:
            # (1) zscore test of std of hidden dimension
            # (2) magnitude > 6 test
            merged = input[0].view(-1, input[0].shape[-1])
            # (1) zscore test of std of hidden dimension
            outlier_idx = find_outlier_dims(merged, reduction_dim=1, zscore=3)
            # (2) magnitude > 6 test
            dims = (torch.abs(input[0]) > 6).sum(dim=list(range(len(input[0].shape) - 1)))
            outlier_idx2 = torch.where(dims > 0)[0]
            outlier_idx = torch.cat([outlier_idx, outlier_idx2]).unique()
            tracer.hvalue2outlier_idx[hvalue] = outlier_idx
    else:
        for hook in tracer.hooks:
            hook.remove()


class OutlierTracer(object):
    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self, model):
        self.last_w = None
        self.current_outlier_dims = None
        self.hvalues = []
        self.outliers = []
        self.hvalue2outlier_idx = {}
        self.initialized = True
        self.hooks = []

        for n, m in model.named_modules():
            if isinstance(m, torch.nn.Linear):
                self.hooks.append(m.register_forward_pre_hook(outlier_hook))

    def is_initialized(self):
        return getattr(self, 'initialized', False)

    def get_hvalue(self, weight):
        return weight.data.storage().data_ptr()

    def get_outliers(self, weight):
        if not self.is_initialized():
            print('Outlier tracer is not initialized...')
            return None
        hvalue = self.get_hvalue(weight)

        if hvalue in self.hvalue2outlier_idx:
            return self.hvalue2outlier_idx[hvalue]
        else:
            return None

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls.__new__(cls)
        return cls._instance
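
# Usage sketch (illustrative only): `model` and `batch` are placeholders for any
# torch.nn.Module containing nn.Linear layers and a matching input tensor; neither
# is defined in this module. One forward pass triggers the pre-hooks and fills
# tracer.hvalue2outlier_idx.
#
#     tracer = OutlierTracer.get_instance()
#     tracer.initialize(model)                    # registers forward pre-hooks on every nn.Linear
#     with torch.no_grad():
#         model(batch)
#     idx = tracer.get_outliers(model.fc.weight)  # outlier dimension indices, or None if untraced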


def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False):
    if rdm:
        # random baseline: return `topk` random dimension indices
        return torch.randint(0, weight.shape[1], size=(topk,), device=weight.device).long()

    # z-score of the per-dimension mean (computed, but not used for selection below)
    m = weight.mean(reduction_dim)
    mm = m.mean()
    mstd = m.std()
    zm = (m - mm) / mstd

    # z-score of the per-dimension standard deviation
    std = weight.std(reduction_dim)
    stdm = std.mean()
    stdstd = std.std()
    zstd = (std - stdm) / stdstd

    if topk is not None:
        val, idx = torch.topk(std.abs(), k=topk, dim=0)
    else:
        idx = torch.where(zstd > zscore)[0]

    return idx
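
# Illustrative sketch (names are local to the example, not part of the module):
# flag hidden dimensions whose per-dimension std is a z-score outlier. Column 7
# is inflated on purpose and should appear in the returned indices.
#
#     w = torch.randn(4096, 1024)
#     w[:, 7] *= 20
#     idx = find_outlier_dims(w, reduction_dim=0, zscore=4.0)  # expected to include 7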


def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None):
    """
    Replace linear modules with a new Linear module.

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        linear_replacement (`torch.nn.Module`):
            The linear module that replaces the old one. Only expects standard arguments.
            If other arguments need to be passed, use a lambda.
        skip_modules (`List[str]`, *optional*, defaults to `["lm_head"]`):
            List of module names not to convert.
        copy_weights (`bool`):
            Copy the weights from the old linear module to the new one.
        post_processing_function (`str`):
            A function name of the replacement linear class that is called
            after processing.
    """
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function)

        if isinstance(module, torch.nn.Linear) and name not in skip_modules:
            old_module = model._modules[name]
            model._modules[name] = linear_replacement(
                module.in_features,
                module.out_features,
                module.bias is not None,
            )
            if copy_weights:
                model._modules[name].weight = old_module.weight
                model._modules[name].bias = old_module.bias

            if post_processing_function is not None:
                func = getattr(module, post_processing_function, None)
                if func is not None:
                    func(module)
    return model
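
# Usage sketch (illustrative): any replacement class that accepts the standard
# (in_features, out_features, bias) constructor arguments works; the specific
# class shown here is an assumption, not prescribed by this module.
#
#     import bitsandbytes as bnb
#     model = replace_linear(model, bnb.nn.Linear8bitLt, skip_modules=["lm_head"], copy_weights=True)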


def execute_and_return(command_string: str) -> Tuple[str, str]:
    def _decode(subprocess_err_out_tuple):
        return tuple(
            to_decode.decode("UTF-8").strip()
            for to_decode in subprocess_err_out_tuple
        )

    def execute_and_return_decoded_std_streams(command_string):
        return _decode(
            subprocess.Popen(
                shlex.split(command_string),
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            ).communicate()
        )

    std_out, std_err = execute_and_return_decoded_std_streams(command_string)
    return std_out, std_err
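
# Usage sketch (illustrative command; any shell command line can be passed):
#
#     out, err = execute_and_return("python --version")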