import math

import torch

class RopeND:
    """N-dimensional rotary position embedding (RoPE): the head dimension is
    partitioned across `nd` position axes according to `nd_split`, and each
    slice gets its own cos/sin table."""

    def __init__(self, head_dim=64, nd=3, max_lens=[1024, 64, 64], nd_split=[2, 1, 1], bases=[1000, 1000, 1000],
                 auto_base=True, cache_longer=1):
        self.nd = nd
        self.head_dim = head_dim
        self.max_lens = max_lens
        self.nd_split = nd_split
        # Channels assigned to each axis; every share is even so it can be rotated in pairs.
        self.split_dims = [2 * s * (head_dim // 2 // sum(nd_split)) for s in nd_split]
        assert sum(self.split_dims) == head_dim, \
            f"split_dims {self.split_dims} must sum to head_dim {head_dim}"
        self.auto_base = auto_base
        if auto_base:
            # Empirical rule: pick base ~ kL / pi so that cos(theta) reaches -1 around length kL,
            # and for L = 1 the step (1/base)**(1/32) ~ 0.7-0.8 ~ pi/4.
            # For the traditional L = 4096, 8L / pi ~ 10.4k, hence the usual base of 10k.
            # e.g. max_lens = [1024, 64, 64] -> bases = [2700, 200, 200]
            self.bases = [(int(8 * l / math.pi) // 100 + 1) * 100 for l in self.max_lens]
            print(f"Bases for rope: {self.bases}")
        else:
            self.bases = bases
        self.cache_longer = cache_longer

    def generated_cos_sin_mix2d(self, max_len, dim, device, base=1000):
        # Inverse frequencies for dim // 2 rotation pairs; exponents are linearly
        # spaced over [0, 1] across the pairs.
        inv_freq = 1.0 / (base ** (torch.linspace(start=0, end=self.head_dim, steps=dim // 2,
                                                  device=device).float() / self.head_dim))
        assert inv_freq.size(0) * 2 == dim, f"inv_freq.size(0) = {inv_freq.size(0)}, required dim = {dim}"

        # Cache tables `cache_longer` times longer than max_len.
        t = torch.arange(max_len * self.cache_longer, device=device).type_as(inv_freq)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        freqs = torch.cat([freqs, freqs], dim=1)
        return freqs.cos().to(torch.float), freqs.sin().to(torch.float)

    def generate_pos_embs_mix2d(self, position_ids, device=None):
        """Gather per-axis cos/sin embeddings and concatenate them along the channel dim.

        position_ids: LongTensor of shape (nd, L), one row of positions per axis
        (a 1-D tensor is promoted to (1, L) and only covers the nd == 1 case).
        Returns cos_emb, sin_emb of shape (L, head_dim).
        """
        if device is None:
            device = position_ids.device

        if position_ids.dim() == 1:
            position_ids = position_ids.unsqueeze(0)

        cos_emb_all, sin_emb_all = [], []
        for i in range(self.nd):
            dim_i = self.split_dims[i]
            base_i = self.bases[i]
            max_len_i = self.max_lens[i]
            # Lazily build and cache the cos/sin tables for axis i on first use.
            if not hasattr(self, f"cos_{i}"):
                _cos, _sin = self.generated_cos_sin_mix2d(max_len=max_len_i, dim=dim_i, device=device, base=base_i)
                setattr(self, f"cos_{i}", _cos)
                setattr(self, f"sin_{i}", _sin)
            cos_emb_all.append(getattr(self, f"cos_{i}")[position_ids[i, :], :])
            sin_emb_all.append(getattr(self, f"sin_{i}")[position_ids[i, :], :])
        cos_emb = torch.cat(cos_emb_all, dim=-1)
        sin_emb = torch.cat(sin_emb_all, dim=-1)
        return cos_emb, sin_emb

    def __call__(self, q, k, position_ids):
        """q, k: (N, N_head, L, C) tensors with C == head_dim.
        position_ids: LongTensor of shape (nd, L), one row of positions per axis.
        """
        cos_emb, sin_emb = self.generate_pos_embs_mix2d(position_ids, device=q.device)

        def rotate_half(x):
            """Rotates half the hidden dims of the input."""
            x1 = x[..., : x.shape[-1] // 2]
            x2 = x[..., x.shape[-1] // 2:]
            return torch.cat((-x2, x1), dim=-1)

        def apply_rotary_pos_emb(q, k, cos, sin):
            """Applies Rotary Position Embedding to the query and key tensors.

            Args:
                q (`torch.Tensor`): The query tensor.
                k (`torch.Tensor`): The key tensor.
                cos (`torch.Tensor`): The cosine part of the rotary embedding.
                sin (`torch.Tensor`): The sine part of the rotary embedding.
            Returns:
                `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
            """
            # Broadcast the (L, C) tables over the batch and head dimensions: (1, 1, L, C).
            cos = cos.unsqueeze(0).unsqueeze(0)
            sin = sin.unsqueeze(0).unsqueeze(0)
            # Rotate in float32 for numerical stability, then cast back to the input dtype.
            dtype = q.dtype
            q = q.to(torch.float)
            k = k.to(torch.float)
            q_embed = (q * cos) + (rotate_half(q) * sin)
            k_embed = (k * cos) + (rotate_half(k) * sin)
            q_embed = q_embed.to(dtype)
            k_embed = k_embed.to(dtype)
            return q_embed, k_embed

        q, k = apply_rotary_pos_emb(q, k, cos_emb, sin_emb)
        return q, k
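

# --- Minimal usage sketch (illustrative only: the (time, height, width) reading of the
# --- three axes and all shapes below are assumptions, not part of the original file) ---
if __name__ == "__main__":
    rope = RopeND(head_dim=64, nd=3, max_lens=[1024, 64, 64], nd_split=[2, 1, 1])

    N, n_head, L = 2, 8, 16
    q = torch.randn(N, n_head, L, 64)
    k = torch.randn(N, n_head, L, 64)

    # position_ids has shape (nd, L): one row of integer positions per axis.
    position_ids = torch.stack([
        torch.arange(L),             # temporal index
        torch.randint(0, 64, (L,)),  # height index
        torch.randint(0, 64, (L,)),  # width index
    ], dim=0)

    q_rot, k_rot = rope(q, k, position_ids)
    print(q_rot.shape, k_rot.shape)  # both torch.Size([2, 8, 16, 64])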