Make `configuration_phi4flash.py` and `modeling_phi4flash.py` compatible with standard sliding window config (#7)
- `sliding_window: list[Optional[int]]` -> `sliding_window: int` + `layer_types: list[str]` (commit 8928fa33c0a966cf0822f7601ec722d38dd61285)

Files changed:
- `configuration_phi4flash.py` (+10, -4)
- `modeling_phi4flash.py` (+9, -9)
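
For context, here is a minimal sketch of what the schema change means at the serialized-config level. The layer count and the alternating pattern are illustrative assumptions (chosen to match the default derived in `configuration_phi4flash.py` below), not the real checkpoint values:

```python
# Illustrative only: how the same per-layer schedule is expressed before and
# after this change, assuming 8 hidden layers and a 2047-token window.
num_hidden_layers = 8

# Old, non-standard schema: one Optional[int] entry per layer.
before = {
    "sliding_window": [None, 2047, None, 2047, None, None, None, None],
}

# New, standard schema: a single int plus per-layer type tags.
after = {
    "sliding_window": 2047,
    "layer_types": [
        "sliding_attention" if (i < num_hidden_layers // 2 and i % 2 == 1) else "full_attention"
        for i in range(num_hidden_layers)
    ],
    # -> ["full_attention", "sliding_attention", "full_attention", "sliding_attention",
    #     "full_attention", "full_attention", "full_attention", "full_attention"]
}
```

Downstream code can recover the old per-layer view as `[after["sliding_window"] if t == "sliding_attention" else None for t in after["layer_types"]]`.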
configuration_phi4flash.py (CHANGED)

@@ -112,6 +112,7 @@ class Phi4FlashConfig(PretrainedConfig):
         bos_token_id=1,
         eos_token_id=2,
         sliding_window=2047,
+        layer_types=None,
         mb_per_layer= 2,
         mamba_d_state=16,
         mamba_d_conv=4,
@@ -141,11 +142,16 @@ class Phi4FlashConfig(PretrainedConfig):
         self.use_cache = use_cache
         self.rope_theta = rope_theta
         self.mb_per_layer = mb_per_layer
-        self.sliding_window = [
-            …
-            for layer_idx in range(num_hidden_layers)
-        ]
+        self.sliding_window = sliding_window
+        self.layer_types = layer_types

+        if self.layer_types is None:
+            is_sliding = lambda i: i < num_hidden_layers // 2 and i % 2 == 1
+            self.layer_types = [
+                "sliding_attention" if is_sliding(layer_idx) else "full_attention"
+                for layer_idx in range(num_hidden_layers)
+            ]
+
         self.mamba_d_state = mamba_d_state
         self.mamba_d_conv = mamba_d_conv
         self.mamba_expand = mamba_expand
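
A quick usage sketch of the updated config (assuming, as the visible signature suggests, that the remaining constructor arguments keep sensible defaults; `num_hidden_layers=8` is only for illustration):

```python
# Run from the repository root so the local module is importable.
from configuration_phi4flash import Phi4FlashConfig

config = Phi4FlashConfig(num_hidden_layers=8)
print(config.sliding_window)   # 2047 -- now a plain int instead of a per-layer list
print(config.layer_types[:4])  # ['full_attention', 'sliding_attention', 'full_attention', 'sliding_attention']

# Passing layer_types explicitly overrides the default alternating pattern.
config = Phi4FlashConfig(num_hidden_layers=8, layer_types=["full_attention"] * 8)
```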
modeling_phi4flash.py (CHANGED)

@@ -129,7 +129,7 @@ def _get_cache(
         cache_to_check = self._cache.self_attention_cache if requires_cross_attention_cache else self._cache

         if cache_implementation == "sliding_window":
-            max_cache_len = min(self.config.sliding_window…
+            max_cache_len = min(self.config.sliding_window, max_cache_len)

         need_new_cache = (
             not hasattr(self, "_cache")
@@ -243,7 +243,7 @@ class SambaYCache(Cache):
         sliding_cache_shape = (
             self.max_batch_size,
             self.num_key_value_heads,
-            min(config.sliding_window…
+            min(config.sliding_window, max_cache_len),
             self.head_dim,
         )
         conv_cache_shape = (self.max_batch_size, intermediate_size, conv_kernel_size)
@@ -573,7 +573,7 @@ class SambaYFlashAttention2(SambaYAttention):
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        use_sliding_windows = self.config.sliding_window is not None and self.config.…
+        use_sliding_windows = self.config.sliding_window is not None and self.config.layer_types[self.layer_idx] == "sliding_attention"

        if past_key_value is not None:
@@ -710,8 +710,8 @@ class SambaYFlashAttention2(SambaYAttention):
                softmax_scale=softmax_scale,
                causal=causal,
                window_size=(
-                    self.config.…
-                    self.config.…
+                    self.config.sliding_window - 1,
+                    self.config.sliding_window - 1,
                ),
            )
@@ -735,8 +735,8 @@ class SambaYFlashAttention2(SambaYAttention):
                softmax_scale=softmax_scale,
                causal=causal,
                window_size=(
-                    self.config.…
-                    self.config.…
+                    self.config.sliding_window - 1,
+                    self.config.sliding_window - 1,
                ),
            )
@@ -1085,9 +1085,9 @@ class SambaYDecoderLayer(nn.Module):
            residual = residual.to(torch.float32)
            self_attn_weights = None
        else:
-            if self.config.sliding_window is not None and self.config.…
+            if self.config.sliding_window is not None and self.config.layer_types[self.layer_idx] == "sliding_attention" and attention_mask is not None:  # efficient SDPA and no padding
                if past_key_value is not None and cache_position[0] > 0:  # when decoding
-                    attention_mask = attention_mask[:, -self.config.…
+                    attention_mask = attention_mask[:, -self.config.sliding_window:]
            #hidden_states = self.input_layernorm2(hidden_states.to(dtype=self.input_layernorm2.weight.dtype))
            # Self Attention
            attn_outputs, self_attn_weights, yoco_key_values = self.attn(
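
All of the modeling changes reduce to the same per-layer lookup. The helper below is a sketch of that lookup, not code from the diff; the function name is hypothetical and the old-schema branch is inferred from the commit summary:

```python
from typing import Optional


def effective_sliding_window(config, layer_idx: int) -> Optional[int]:
    """Window size a given attention layer should use, or None for full attention."""
    if isinstance(config.sliding_window, list):
        # Old schema (inferred): config.sliding_window was list[Optional[int]].
        return config.sliding_window[layer_idx]
    # New schema: a single int gated by the per-layer type tag.
    if config.sliding_window is not None and config.layer_types[layer_idx] == "sliding_attention":
        return config.sliding_window
    return None
```

Call sites in `modeling_phi4flash.py` then use this value directly: the flash-attention paths pass `(window - 1, window - 1)` as `window_size`, and the decoder layer trims `attention_mask` to the last `window` positions while decoding.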