import os
from typing import List, Callable

import numpy as np
import torch
from torch.nn import functional as F

np_dtype_to_torch_dtype = {
    np.float16: torch.float16,
    np.float32: torch.float32,
    np.uint8: torch.uint8,
    np.int8: torch.int8,
    np.int32: torch.int32,
    np.int64: torch.int64,
    bool: torch.bool,
}
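
# Usage sketch (illustrative, not part of the original module): allocate a torch buffer
# matching a numpy array's shape and dtype via the table above. The lookup uses
# `arr.dtype.type` so the key is the numpy scalar type, which is what the table stores.
def _example_dtype_lookup(arr: np.ndarray) -> torch.Tensor:
    return torch.empty(arr.shape, dtype=np_dtype_to_torch_dtype[arr.dtype.type])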

# Diff between a requested index set and one level of cached indices:
#   off_elements:  requested elements missing from this cache level (must be fetched)
#   off_positions: cache slots whose elements are no longer requested (free to overwrite)
#   on_positions:  cache slots whose elements are still requested
class IndexDiff:
    def __init__(self, off_elements: torch.Tensor = None, off_positions: torch.Tensor = None, on_positions: torch.Tensor = None):
        self.off_elements = off_elements
        self.off_positions = off_positions
        self.on_positions = on_positions

def batch_copy(sources: List[torch.Tensor], copy_stream, indices=None, device='cuda'):
    # Copy (optionally index-selected) tensors to `device` on a dedicated copy stream
    # so the transfers can overlap with compute running on the default stream.
    with torch.cuda.stream(copy_stream):
        out = ()
        for src in sources:
            indexed = src[indices] if indices is not None else src
            dst = torch.empty(indexed.shape, device=device, dtype=src.dtype)
            dst.copy_(indexed, non_blocking=True)
            out += (dst,)
    return out
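
# Usage sketch (illustrative, assumes a CUDA device is available): gather selected rows
# of two pinned CPU tensors onto the GPU on a side stream, then make the default stream
# wait on the copies before using the results.
def _example_batch_copy():
    copy_stream = torch.cuda.Stream()
    weights = torch.randn(1024, 512, pin_memory=True)
    bias = torch.randn(1024, pin_memory=True)
    rows = torch.tensor([1, 5, 42])
    gpu_weights, gpu_bias = batch_copy([weights, bias], copy_stream, indices=rows)
    torch.cuda.current_stream().wait_stream(copy_stream)
    return gpu_weights, gpu_bias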

def mmap_to_tensor(torch_wrapped_mmap, pin_memory=False) -> torch.Tensor:
    # Materialize a torch tensor that wraps a memory-mapped file into regular
    # (optionally pinned) CPU memory.
    out = torch.empty(torch_wrapped_mmap.shape, dtype=torch_wrapped_mmap.dtype, device='cpu', pin_memory=pin_memory)
    out.copy_(torch_wrapped_mmap)
    return out
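
# Usage sketch (illustrative; the file name, dtype, and shape are assumptions): wrap an
# on-disk numpy memmap as a torch tensor, then materialize it into pinned CPU memory so
# later GPU copies can be asynchronous.
def _example_mmap_to_tensor(path: str = 'weights.bin') -> torch.Tensor:
    mm = np.memmap(path, dtype=np.float32, mode='r', shape=(1024, 1024))
    wrapped = torch.from_numpy(mm)
    return mmap_to_tensor(wrapped, pin_memory=True)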

# Assuming that each entry of cached_indices_list is one step down the memory hierarchy,
# compute the diff at each level of the hierarchy.
# e.g. the first iteration computes the indices that the GPU does not have,
# and the second iteration computes the indices *of that diff* that the CPU does not have.
def compute_index_diffs(new_indices: torch.Tensor, cached_indices_list: List[torch.Tensor], pin_memory=True):
    diffs = []
    current_diff = new_indices
    for cached_indices in cached_indices_list:
        if current_diff.size(0) == 0:
            # No need to go further down the hierarchy
            break
        # Elements of the current diff that are not present in this level's cached indices
        off_elements = torch.tensor(
            list(set(current_diff.tolist()).difference(set(cached_indices.tolist()))),
            device='cpu',
            dtype=torch.int32,
            pin_memory=pin_memory
        )
        # Mask over cached positions whose element is still requested ("on");
        # the complement marks positions that are free to be overwritten ("off").
        on_position_mask = torch.isin(cached_indices, current_diff, assume_unique=True)
        on_positions = torch.nonzero(on_position_mask).flatten()
        off_positions = torch.nonzero(~on_position_mask).flatten()[:off_elements.size(0)]
        diffs.append(IndexDiff(off_elements, off_positions, on_positions))
        current_diff = off_elements
    return diffs
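
# Usage sketch (illustrative): with a GPU cache holding [0, 1, 2, 3] and a CPU cache
# holding [0..7], requesting [2, 3, 4, 5] yields one diff per level. The first diff's
# off_elements are {4, 5} (missing on the GPU) and its off_positions point at the GPU
# slots holding 0 and 1; the second diff has no off_elements because the CPU cache
# already holds 4 and 5.
def _example_compute_index_diffs():
    new_indices = torch.tensor([2, 3, 4, 5], dtype=torch.int32)
    gpu_cached = torch.tensor([0, 1, 2, 3], dtype=torch.int32)
    cpu_cached = torch.arange(8, dtype=torch.int32)
    return compute_index_diffs(new_indices, [gpu_cached, cpu_cached], pin_memory=False)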

def topk_and_threshold(x, k, threshold=1):
    # Take the top-k entries of x and keep only the indices whose value exceeds the threshold.
    vals, indices = torch.topk(x, k, sorted=True)
    return indices[vals > threshold].int()

def load_mlp_sparsity_predictor(weight_path_prefix: str, layer_num: int, dtype: torch.dtype, device: str = 'cuda') -> Callable:
    path_prefix = f'{weight_path_prefix}decoder.layers.{layer_num}.attn.mlp-sparsity-predictor.'
    return load_predictor(path_prefix, dtype, device=device)

def load_predictor(path_prefix: str, dtype: torch.dtype, device: str = 'cuda') -> Callable:
    # A predictor is stored as two weight files ({prefix}1.weight and {prefix}2.weight)
    # and applied as two stacked bias-free linear layers.
    path = lambda i: os.path.expanduser(f'{path_prefix}{i}.weight')
    if os.path.exists(path(1)):
        l1 = torch.load(path(1)).to(device).to(dtype)
        l2 = torch.load(path(2)).to(device).to(dtype)
        return lambda x: F.linear(F.linear(x, l1), l2)
    else:
        print(f'could not find predictor at {path(1)}')
        return None
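
# Usage sketch (illustrative; the weight path prefix, k, and tensor sizes are
# assumptions): run one layer's sparsity predictor on a hidden state and keep the
# indices of the neurons predicted to be active.
def _example_predict_active_neurons(hidden_state: torch.Tensor, layer_num: int):
    predictor = load_mlp_sparsity_predictor('~/opt-weights/', layer_num, torch.float16)
    if predictor is None:
        return None
    scores = predictor(hidden_state)          # predicted activation score per neuron
    return topk_and_threshold(scores, k=256)  # indices of likely-active neurons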