Commit df388cc
Parent(s): c572a14

Corrected rotary embedding

attention.py  CHANGED  (+36, -16)
@@ -28,7 +28,7 @@ class RotaryEmbedding(nn.Module):
         d_rotary: int,
         rotary_base: float = 10000.0,
         initial_cos_sin_cache_len: int = 2048,
-        device: torch.device =
+        device: torch.device | None = None,
     ) -> None:
         super().__init__()
         self.d_rotary = d_rotary
@@ -37,31 +37,37 @@ class RotaryEmbedding(nn.Module):
         self.dtype = torch.float32
         self._update_cos_sin_cache(seqlen=initial_cos_sin_cache_len)

-    def _update_cos_sin_cache(
+    def _update_cos_sin_cache(
+        self,
+        seqlen: int,
+        device: str | None = None,
+        dtype: torch.dtype | None = None,
+    ) -> None:
         # only call this function when seqlen is larger than _max_seqlen
         self._max_seqlen = seqlen

         # m * theta_i = m * base^(-2i/d) = m * (1 / base^(2i/d)), where i in [1, d/2]
         m = torch.arange(
             seqlen,
-            device=
-            dtype=
+            device=device,
+            dtype=torch.float32,
         )
         theta_i = 1.0 / (
             self.rotary_base ** (
                 torch.arange(
                     start=0,
                     end=self.d_rotary,
-
-
+                    step=2,
+                    device=device,
+                    dtype=torch.float32,
                 ) / self.d_rotary
             )
         )
         # torch.outer, since torch.einsum converts from fp32 to fp16 if used with torch.amp
         # TODO: does this matter if I'm disabling torch.autocast?
         m_theta_i = torch.outer(m, theta_i)
-        self._cos_cached = torch.cos(m_theta_i).to(
-        self._sin_cached = torch.sin(m_theta_i).to(
+        self._cos_cached = torch.cos(m_theta_i).to(dtype)
+        self._sin_cached = torch.sin(m_theta_i).to(dtype)

         # TODO: scale_base caching is labelled as not yet done in Phi2
         """
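The corrected `_update_cos_sin_cache` builds position-frequency tables of width `d_rotary // 2`; the `step=2` added here halves the number of frequencies so the tables match the half-width x1/x2 chunks used later. A minimal standalone sketch of the same computation (the function name is illustrative, not part of attention.py):

import torch

def build_rope_cache(seqlen: int, d_rotary: int, base: float = 10000.0):
    # theta_i = 1 / base^(2i/d_rotary) for i = 0, 2, ..., d_rotary - 2
    m = torch.arange(seqlen, dtype=torch.float32)  # positions 0 .. seqlen-1
    theta_i = 1.0 / (base ** (torch.arange(0, d_rotary, 2, dtype=torch.float32) / d_rotary))
    m_theta_i = torch.outer(m, theta_i)            # (seqlen, d_rotary // 2)
    return torch.cos(m_theta_i), torch.sin(m_theta_i)

cos, sin = build_rope_cache(seqlen=2048, d_rotary=32)
print(cos.shape)  # torch.Size([2048, 16])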
@@ -90,14 +96,17 @@ class RotaryEmbedding(nn.Module):
         sin: torch.FloatTensor, # dim: (_max_seqlen, d_rotary)
     ) -> torch.FloatTensor:
         seqlen = x.shape[1]
-
+        x_to_rotate = x[..., :self.d_rotary]
+        x_to_keep_unrotated = x[..., self.d_rotary:]
+        x1, x2 = x_to_rotate.chunk(2, dim=-1) # dim: (batch_size, seqlen, Optional[n_qkv], n_heads, d_rotary/2)
         broadcast_rearrange = "s d -> s 1 d" if x1.ndim == 4 else "s d -> s 1 1 d"
         c, s = rearrange(cos[:seqlen], broadcast_rearrange), rearrange(sin[:seqlen], broadcast_rearrange)
         x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]] # make sure rotary embedding is in float32
-
+        x_rotated = cast(
             torch.FloatTensor,
             torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1).to(x.dtype)
         )
+        return torch.cat([x_rotated, x_to_keep_unrotated], axis=-1)

     def forward(
         self,
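This hunk turns `_apply_rotary_emb_qkv` into a partial rotary: only the first `d_rotary` channels of each head are rotated and the remainder passes through untouched. The same transform on a plain (batch, seqlen, n_heads, d_head) tensor, as a self-contained sketch with illustrative names:

import torch

def apply_partial_rotary(x, cos, sin, d_rotary):
    # x: (batch, seqlen, n_heads, d_head); cos/sin: (seqlen, d_rotary // 2)
    x_rot, x_pass = x[..., :d_rotary], x[..., d_rotary:]
    x1, x2 = x_rot.chunk(2, dim=-1)                    # two halves of the rotated slice
    c = cos[: x.shape[1]].view(1, x.shape[1], 1, -1)   # broadcast over batch and heads
    s = sin[: x.shape[1]].view(1, x.shape[1], 1, -1)
    rotated = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1)
    return torch.cat([rotated, x_pass], dim=-1)        # unrotated channels pass through

cos, sin = torch.randn(128, 16), torch.randn(128, 16)  # stand-ins for the cached tables
x = torch.randn(2, 128, 4, 80)
assert apply_partial_rotary(x, cos, sin, d_rotary=32).shape == x.shape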
@@ -107,9 +116,11 @@ class RotaryEmbedding(nn.Module):
         if (
             not self._max_seqlen
             or self._max_seqlen < x.shape[1] + seqlen_offset
+            or self._cos_cached.device != x.device
+            or self._cos_cached.dtype != x.dtype
             or (self.training and self._cos_cached.is_inference())
         ):
-            self._update_cos_sin_cache(seqlen=x.shape[1] + seqlen_offset)
+            self._update_cos_sin_cache(seqlen=x.shape[1] + seqlen_offset, device=x.device, dtype=x.dtype)
         return self._apply_rotary_emb_qkv(
             x,
             cast(torch.FloatTensor, self._cos_cached[seqlen_offset:]),
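With the new device and dtype checks, the tables are rebuilt transparently whenever the input outgrows the cache or lives on a different device or dtype. Assuming the constructor and forward signature shown in these hunks (the import path and the q/k packing are assumptions taken from the MHA hunks below), usage looks roughly like:

import torch
from attention import RotaryEmbedding  # import path assumed

rotary = RotaryEmbedding(d_rotary=32, initial_cos_sin_cache_len=2048)
qk = torch.randn(2, 4096, 2, 32, 80)    # (batch, seqlen, q/k, n_heads, d_head); v stays out of the rotary path
out = rotary(qk, seqlen_offset=0)       # cache silently rebuilt for seqlen 4096 on qk's device/dtype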
@@ -269,7 +280,8 @@ class MHA(nn.Module):
             else RotaryEmbedding
         )
         self.rotary_emb = rotary_cls(
-            d_rotary=math.ceil((d_embedding // n_attn_heads) / 2), # d_rotary is half of d_head
+            # d_rotary=math.ceil((d_embedding // n_attn_heads) / 2), # d_rotary is half of d_head
+            d_rotary=32, # TODO: figure out why Phi2 uses this
             initial_cos_sin_cache_len=initial_cos_sin_cache_len,
         )

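On the TODO: if the intent is to match Phi-2's published configuration (an assumption, not something this diff states), 32 is a 0.4 partial-rotary fraction of the head dimension rather than half of it:

# Arithmetic sketch only; the Phi-2 hyperparameters below are assumptions, not from attention.py.
import math

d_embedding = 2560                      # Phi-2 hidden size
n_attn_heads = 32                       # Phi-2 attention heads
d_head = d_embedding // n_attn_heads    # 80
partial_rotary_factor = 0.4             # fraction of each head Phi-2 rotates
print(int(d_head * partial_rotary_factor))  # 32 -> the hard-coded value
print(math.ceil(d_head / 2))                # 40 -> what the removed math.ceil line produced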
@@ -378,12 +390,20 @@ class MHA(nn.Module):
         kv_cache: KVCache,
         key_padding_mask: torch.BoolTensor | None,
     ) -> torch.FloatTensor:
-
-
-
+        qk = qkv[:, :, :2, :, :]
+        qk = self.rotary_emb(
+            qk,
             seqlen_offset = 0 if kv_cache is None else kv_cache.seqlen_offset,
         )
-
+        v = cast(torch.FloatTensor, qkv[:, :, 2, :, :])
+        q = qk[:, :, 0, :, :]
+        kv = torch.cat(
+            [
+                qk[:, :, 1, :, :].unsqueeze(2),
+                v.unsqueeze(2),
+            ],
+            dim=2,
+        )
         self._update_kv_cache(kv, kv_cache, self.block_n)
         causal = False # turning off causal mask for cross attention

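The added lines rotate only the packed q/k slice and re-join the untouched v so the KV cache still stores plain values. The same reshuffling in isolation, with the rotary call stubbed out and illustrative names:

import torch

def split_rotate_repack(qkv, rotary_fn):
    # qkv: (batch, seqlen, 3, n_heads, d_head), packed as in the hunk above
    qk = rotary_fn(qkv[:, :, :2, :, :])   # rotary applied to q and k only
    v = qkv[:, :, 2, :, :]                # v bypasses the rotary embedding
    q = qk[:, :, 0, :, :]
    kv = torch.cat([qk[:, :, 1, :, :].unsqueeze(2), v.unsqueeze(2)], dim=2)
    return q, kv                          # kv is what feeds _update_kv_cache

qkv = torch.randn(2, 16, 3, 4, 8)
q, kv = split_rotate_repack(qkv, rotary_fn=lambda t: t)  # identity stands in for self.rotary_emb
assert q.shape == (2, 16, 4, 8) and kv.shape == (2, 16, 2, 4, 8)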