import torch
import math


class RopeND:
    """N-dimensional rotary position embedding (RoPE): the head dimension is
    split across `nd` axes in proportion to `nd_split`, each axis with its
    own frequency base."""

    def __init__(self, head_dim=64, nd=3, max_lens=(1024, 64, 64), nd_split=(2, 1, 1),
                 bases=(1000, 1000, 1000), auto_base=True, cache_longer=1):
        self.nd = nd
        self.head_dim = head_dim
        self.max_lens = max_lens
        self.nd_split = nd_split
        # Each axis gets an even slice of head_dim, proportional to nd_split.
        self.split_dims = [2 * i * (head_dim // 2 // sum(nd_split)) for i in nd_split]
        assert sum(self.split_dims) == head_dim
        self.auto_base = auto_base
        if auto_base:
            # Empirical: pick base = kL/pi so that cos(theta) reaches -1 when the
            # length reaches kL; at L = 1 the difference (1/base)**(1/32) ~ 0.7-0.8 ~ pi/4.
            # For the traditional L = 4096, 8L/pi ~ 10.4k, matching the usual base of 10k.
            self.bases = [(int(8 * l / math.pi) // 100 + 1) * 100 for l in self.max_lens]
            print(f"Bases for rope: {self.bases}")
        else:
            self.bases = bases
        self.cache_longer = cache_longer
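    # Worked example of the arithmetic above with the defaults (a sketch, not
    # part of the original file): head_dim=64 and nd_split=(2, 1, 1) give
    # head_dim // 2 // sum(nd_split) = 64 // 2 // 4 = 8, so
    # split_dims = [2*2*8, 2*1*8, 2*1*8] = [32, 16, 16] (sums to 64).
    # With max_lens=(1024, 64, 64), auto_base yields bases = [2700, 200, 200].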
    def generate_cos_sin_mix2d(self, max_len, dim, device, base=1000):
        # linspace runs from 0 to head_dim inclusive, so every axis slice spans
        # the same [1, 1/base] frequency band that a full-width RoPE would,
        # however narrow the slice.
        inv_freq = 1.0 / (base ** (torch.linspace(start=0, end=self.head_dim, steps=dim // 2,
                                                  device=device).float() / self.head_dim))
        assert inv_freq.size(0) * 2 == dim, f"inv_freq.size(0) = {inv_freq.size(0)}, required dim = {dim}"
        # Outer product of positions and inverse frequencies; cache_longer > 1
        # precomputes a table longer than max_len.
        t = torch.arange(max_len * self.cache_longer, device=device).type_as(inv_freq)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        freqs = torch.cat([freqs, freqs], dim=1)
        return freqs.cos().to(torch.float), freqs.sin().to(torch.float)
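    # Example (a sketch, not from the original file): max_len=64, dim=16,
    # base=200, cache_longer=1 returns cos/sin tables of shape (64, 16);
    # row t holds cos/sin of t * inv_freq, with the frequencies duplicated
    # along the feature axis to pair with rotate_half below.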
    def generate_pos_embs_mix2d(self, position_ids, device=None):
        if device is None:
            device = position_ids.device
        # position_ids is expected as (nd, L); promote a 1-D input (nd == 1).
        if position_ids.dim() == 1:
            position_ids = position_ids.unsqueeze(0)
        position_ids = position_ids.to(device)
        cos_emb_all, sin_emb_all = [], []
        for i in range(self.nd):
            dim_i = self.split_dims[i]
            base_i = self.bases[i]
            max_len_i = self.max_lens[i]
            # Lazily build and cache the cos/sin tables for axis i.
            if not hasattr(self, f"cos_{i}"):
                _cos, _sin = self.generate_cos_sin_mix2d(max_len=max_len_i, dim=dim_i, device=device, base=base_i)
                setattr(self, f"cos_{i}", _cos)
                setattr(self, f"sin_{i}", _sin)
            # .to(device) guards against tables cached on a different device.
            cos_emb_all.append(getattr(self, f"cos_{i}").to(device)[position_ids[i, :], :])
            sin_emb_all.append(getattr(self, f"sin_{i}").to(device)[position_ids[i, :], :])
        cos_emb = torch.cat(cos_emb_all, dim=-1)
        sin_emb = torch.cat(sin_emb_all, dim=-1)
        return cos_emb, sin_emb
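    # Example (a sketch, not from the original file):
    #   rope = RopeND()                    # split_dims = [32, 16, 16]
    #   ids = torch.zeros(3, 5).long()     # (nd, L)
    #   cos, sin = rope.generate_pos_embs_mix2d(ids)
    #   cos.shape, sin.shape               # both (5, 64) = (L, head_dim)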
    def __call__(self, q, k, position_ids):
        """q, k: (N, N_head, L, C) with C == head_dim; position_ids: (nd, L)."""
        cos_emb, sin_emb = self.generate_pos_embs_mix2d(position_ids, device=q.device)

        def rotate_half(x):
            """Rotates half the hidden dims of the input."""
            x1 = x[..., : x.shape[-1] // 2]
            x2 = x[..., x.shape[-1] // 2:]
            return torch.cat((-x2, x1), dim=-1)

        def apply_rotary_pos_emb(q, k, cos, sin):
            """Applies Rotary Position Embedding to the query and key tensors.

            Args:
                q (`torch.Tensor`): The query tensor.
                k (`torch.Tensor`): The key tensor.
                cos (`torch.Tensor`): The cosine part of the rotary embedding.
                sin (`torch.Tensor`): The sine part of the rotary embedding.
            Returns:
                `tuple(torch.Tensor)` comprising the query and key tensors rotated
                using the Rotary Position Embedding.
            """
            # Broadcast the (L, C) tables over batch and head dims: (1, 1, L, C).
            cos = cos.unsqueeze(0).unsqueeze(0)
            sin = sin.unsqueeze(0).unsqueeze(0)
            # Rotate in float32 for numerical stability, then cast back.
            dtype = q.dtype
            q = q.to(torch.float)
            k = k.to(torch.float)
            q_embed = (q * cos) + (rotate_half(q) * sin)
            k_embed = (k * cos) + (rotate_half(k) * sin)
            return q_embed.to(dtype), k_embed.to(dtype)

        q, k = apply_rotary_pos_emb(q, k, cos_emb, sin_emb)
        return q, k
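
# Minimal usage sketch (assumed setup, not part of the original file):
# q/k are (batch, n_heads, seq_len, head_dim); position_ids is (nd, seq_len)
# with one row of integer coordinates per axis, e.g. (t, h, w) for video.
if __name__ == "__main__":
    rope = RopeND(head_dim=64, nd=3, max_lens=(1024, 64, 64), nd_split=(2, 1, 1))
    L = 8
    q = torch.randn(1, 4, L, 64)
    k = torch.randn(1, 4, L, 64)
    # Hypothetical coordinates: one frame laid out as a 2x4 spatial grid.
    t_ids = torch.arange(L) // 8          # temporal index, all zeros here
    h_ids = (torch.arange(L) % 8) // 4    # row index: 0,0,0,0,1,1,1,1
    w_ids = torch.arange(L) % 4           # column index: 0,1,2,3,0,1,2,3
    position_ids = torch.stack([t_ids, h_ids, w_ids], dim=0)  # (nd, L)
    q_rot, k_rot = rope(q, k, position_ids)
    print(q_rot.shape, k_rot.shape)  # torch.Size([1, 4, 8, 64]) twice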