Update RotaryEmbedding caching
#33
by beibin79 - opened

modelling_RW.py  +11 -15

modelling_RW.py CHANGED
@@ -56,13 +56,12 @@ class RotaryEmbedding(torch.nn.Module):
         base=10000,
     ):
         super().__init__()
-        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
+        inv_freq = 1.0 / (base
+            **(torch.arange(0, head_dim, 2).float() / head_dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.head_dim = head_dim
-        self.seq_len_cached = None
-        self.batch_size_cached = None
-        self.cos_cached: torch.Tensor | None = None
-        self.sin_cached: torch.Tensor | None = None
+        self.cos_cache_dict: dict = {}
+        self.sin_cache_dict: dict = {}
 
     def cos_sin(
         self,
@@ -70,27 +69,24 @@ class RotaryEmbedding(torch.nn.Module):
         device="cuda",
         dtype=torch.bfloat16,
     ) -> torch.Tensor:
-        if seq_len != self.seq_len_cached:
+        if seq_len not in self.cos_cache_dict or seq_len not in self.sin_cache_dict:
             self.seq_len_cached = seq_len
             t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
             freqs = torch.einsum("i,j->ij", t, self.inv_freq)
             emb = torch.cat((freqs, freqs), dim=-1).to(device)
-
             if dtype in [torch.float16, torch.bfloat16]:
                 emb = emb.float()
+            self.cos_cache_dict[seq_len] = emb.cos()[None, :, :].type(dtype)
+            self.sin_cache_dict[seq_len] = emb.sin()[None, :, :].type(dtype)
 
-            self.cos_cached = emb.cos()[None, :, :]
-            self.sin_cached = emb.sin()[None, :, :]
-
-            self.cos_cached = self.cos_cached.type(dtype)
-            self.sin_cached = self.sin_cached.type(dtype)
-
-        return self.cos_cached, self.sin_cached
+        return self.cos_cache_dict[seq_len], self.sin_cache_dict[seq_len]
 
     def forward(self, q, k):
         batch, seq_len, head_dim = q.shape
+        assert seq_len is not None, "seq_len must be known and not None"
        cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
-        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) *
+            sin)
 
 
 def _make_causal_mask(
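For reviewers, below is a minimal, self-contained sketch (not the full module) of the caching behaviour this change introduces: keying the cos/sin tables by seq_len means that alternating sequence lengths, e.g. a full-prompt forward pass followed by single-token decode steps, no longer evicts the cache the way the old single-slot seq_len_cached check did. The class name RotaryCacheSketch and the CPU/fp32 defaults are illustrative assumptions for a runnable example; only the cache logic mirrors the PR.

```python
import torch


class RotaryCacheSketch(torch.nn.Module):
    """Illustrative stand-alone version of the dict-based cos/sin cache."""

    def __init__(self, head_dim: int, base: int = 10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.head_dim = head_dim
        # One cache entry per sequence length instead of a single cached tensor.
        self.cos_cache_dict: dict = {}
        self.sin_cache_dict: dict = {}

    def cos_sin(self, seq_len: int, device="cpu", dtype=torch.float32):
        if seq_len not in self.cos_cache_dict or seq_len not in self.sin_cache_dict:
            t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(device)
            if dtype in [torch.float16, torch.bfloat16]:
                emb = emb.float()  # build the tables in fp32, cast when storing
            self.cos_cache_dict[seq_len] = emb.cos()[None, :, :].type(dtype)
            self.sin_cache_dict[seq_len] = emb.sin()[None, :, :].type(dtype)
        return self.cos_cache_dict[seq_len], self.sin_cache_dict[seq_len]


rot = RotaryCacheSketch(head_dim=64)
rot.cos_sin(128)  # prompt length: computed once, then cached
rot.cos_sin(1)    # decode step: separate entry, does not evict the 128 entry
rot.cos_sin(128)  # cache hit; the old seq_len_cached slot would recompute here
print(sorted(rot.cos_cache_dict.keys()))  # [1, 128]
```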