fix: update frequencies when updating the rope base value (#40)
- fix: update frequencies when updating the rope base value (d8cbc92c8650d6bdc8e5afb28785625a98ccfab1)
- Update rotary.py (90873c4a21ac932b2df31d0e35e56b9c55460470)
- Update rotary.py (071760a5bbecc7b738c64583a3b5b337cd6d0667)
- Update rotary.py (1eb2361d4e9bdeedc1516196f02f199515916d30)
- Update rotary.py (066b97bdf39f4031bf1ddee4c706d5c842fb8748)
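
Background for the fix: the cached cos/sin tables of a rotary embedding are built from inverse frequencies that are a direct function of the base. Before this change, assigning a new value to the `base` property only stored the number; `inv_freq` and the cos/sin caches were never rebuilt, so the new base had no effect. The sketch below is an illustration of that dependence, assuming the usual flash-attn-style formula that `_compute_inv_freq` in this file appears to follow; the standalone helper name `compute_inv_freq` is made up for the example and is not code from this diff.

import torch

def compute_inv_freq(dim: int, base: float, device=None) -> torch.Tensor:
    # inv_freq[i] = 1 / base ** (2 * i / dim) for i = 0 .. dim // 2 - 1;
    # every entry depends on `base`, so changing the base invalidates the
    # cos/sin tables that were built from the old inv_freq.
    return 1.0 / (
        base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
    )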
rotary.py CHANGED
@@ -493,8 +493,16 @@ class RotaryEmbedding(torch.nn.Module):
 
     @base.setter
     def base(self, new_base):
+        new_base = float(new_base)
         if new_base > 0:
-            self._base = new_base
+            if self._base != new_base:  # only update if the base value has changed
+                self._base = new_base
+                self._update_cos_sin_cache(
+                    self._seq_len_cached,
+                    device=self.inv_freq.device,
+                    dtype=self._cos_cached.dtype if self._cos_cached is not None else None,
+                    rotary_base_changed=True,
+                )
         else:
             raise ValueError("Rotary base value must be positive")
 
@@ -507,21 +515,27 @@ class RotaryEmbedding(torch.nn.Module):
             )
         )
 
-    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
+    def _update_cos_sin_cache(
+        self, seqlen, device=None, dtype=None, rotary_base_changed=False
+    ):
         # Reset the tables if the sequence length has changed,
         # if we're on a new device (possibly due to tracing for instance),
         # or if we're switching from inference mode to training
+        # or if the rotary base value was changed
         if (
             seqlen > self._seq_len_cached
             or self._cos_cached is None
             or self._cos_cached.device != device
             or self._cos_cached.dtype != dtype
             or (self.training and self._cos_cached.is_inference())
+            or rotary_base_changed
         ):
             self._seq_len_cached = seqlen
             # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
             # And the output of arange can be quite large, so bf16 would lose a lot of precision.
             # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
+            if rotary_base_changed:
+                self.inv_freq = self._compute_inv_freq(device=device)
             if self.pos_idx_in_fp32:
                 t = torch.arange(seqlen, device=device, dtype=torch.float32)
                 # We want fp32 here as well since inv_freq will be multiplied with t, and the output
@@ -535,6 +549,7 @@ class RotaryEmbedding(torch.nn.Module):
             else:
                 t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                 inv_freq = self.inv_freq
+
             # Don't do einsum, it converts fp32 to fp16 under AMP
             # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
             freqs = torch.outer(t, inv_freq)
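
A hypothetical usage sketch of the behaviour after the fix; the `RotaryEmbedding(dim, base=...)` constructor signature is assumed from the flash-attn-style rotary.py this module resembles and is not part of the diff.

rope = RotaryEmbedding(dim=64, base=10000.0)  # constructor signature assumed

# With the fix, assigning a different base recomputes inv_freq and rebuilds the
# cos/sin caches via _update_cos_sin_cache(..., rotary_base_changed=True).
rope.base = 500000.0

# Re-assigning the same value is a no-op thanks to the `self._base != new_base`
# guard, and a non-positive value still raises ValueError.
rope.base = 500000.0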