zhouzaida committed
Commit · 7718375
1 Parent(s): 9e6c322

add sdpa back

Browse files: modeling_kimi_vl.py (+33 -1)
modeling_kimi_vl.py CHANGED
@@ -145,6 +145,38 @@ def multihead_attention(
     return attn_out
 
 
+def sdpa_attention(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    q_cu_seqlens: Optional[torch.Tensor] = None,
+    k_cu_seqlens: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """SDPA attention.
+
+    Args:
+        q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim),
+            or (tot_seqlens, num_heads, head_dim) if packing.
+    """
+    seq_length = q.shape[0]
+    attention_mask = torch.zeros(
+        [1, seq_length, seq_length], device=q.device, dtype=torch.bool
+    )
+    for i in range(1, len(q_cu_seqlens)):
+        attention_mask[
+            ...,
+            q_cu_seqlens[i - 1] : q_cu_seqlens[i],
+            q_cu_seqlens[i - 1] : q_cu_seqlens[i],
+        ] = True
+    q = q.transpose(0, 1)
+    k = k.transpose(0, 1)
+    v = v.transpose(0, 1)
+    attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
+    attn_output = attn_output.transpose(0, 1)
+    attn_output = attn_output.reshape(seq_length, -1)
+    return attn_output
+
+
 def eager_attention(
     q: torch.Tensor,
     k: torch.Tensor,
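Note (not part of the diff): the added kernel runs F.scaled_dot_product_attention over a packed batch, where several variable-length sequences are concatenated along dim 0 and q_cu_seqlens holds their cumulative lengths, so a block-diagonal boolean mask keeps tokens from attending across sequence boundaries. A minimal standalone sketch of that pattern, with made-up sequence lengths and head sizes, follows:

import torch
import torch.nn.functional as F

# Assumed toy shapes: three sequences of lengths 3, 5 and 4 packed along dim 0,
# giving q, k, v of shape (tot_seqlens, num_heads, head_dim).
seqlens = torch.tensor([3, 5, 4])
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32), seqlens.cumsum(0).to(torch.int32)])
num_heads, head_dim = 2, 8
tot = int(seqlens.sum())
q = k = v = torch.randn(tot, num_heads, head_dim)

# Block-diagonal mask: True where attention is allowed, i.e. within one sequence.
mask = torch.zeros(1, tot, tot, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    s, e = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
    mask[..., s:e, s:e] = True

# SDPA treats the leading dim as batch, so move heads in front of the sequence dim.
out = F.scaled_dot_product_attention(
    q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1), mask, dropout_p=0.0
)
out = out.transpose(0, 1).reshape(tot, -1)  # (tot_seqlens, num_heads * head_dim)
print(out.shape)  # torch.Size([12, 16])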
@@ -178,6 +210,7 @@ def eager_attention
 
 VL_VISION_ATTENTION_FUNCTIONS = {
     "flash_attention_2": multihead_attention,
+    "sdpa": sdpa_attention,
     "eager": eager_attention,
 }
 
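Usage note (an assumption, not shown in this commit): with _supports_sdpa = True on the vision tower and the "sdpa" key registered above, requesting SDPA at load time should route the vision attention through the new kernel. The repo id and dtype below are illustrative guesses, not taken from this file:

import torch
from transformers import AutoModelForCausalLM

# Hypothetical load call; "sdpa" here is the key added to
# VL_VISION_ATTENTION_FUNCTIONS in this commit.
model = AutoModelForCausalLM.from_pretrained(
    "moonshotai/Kimi-VL-A3B-Instruct",  # assumed repo id
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    attn_implementation="sdpa",
)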
@@ -2230,7 +2263,6 @@ class MoonVitPretrainedModel(PreTrainedModel):
     _no_split_modules = ["PackingTransformer"]
     _supports_flash_attn_2 = True
     _supports_sdpa = True
-
     def __init__(self, config: MoonViTConfig, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
         config = deepcopy(config)