Upload modeling_llada.py to support attention_mask as an input.
modeling_llada.py  (CHANGED: +23, -5)
@@ -87,6 +87,7 @@ def init_weights(
 ) -> None:
     """
     Initialize weights of a linear or embedding module.
+
     :param config: The model config.
     :param module: The linear or embedding submodule to initialize.
     :param d: The effective input dimensionality of the weights. This could be smaller than the actual dimensions
@@ -648,12 +649,12 @@ class LLaDABlock(nn.Module):
         k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
         v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
 
-        # Modify: MDM set causal to False
+        # Modify: MDM set causal to False.
         return F.scaled_dot_product_attention(
             q,
             k,
             v,
-            attn_mask=None,
+            attn_mask=attn_mask,
             dropout_p=dropout_p,
             is_causal=False,
         )
@@ -711,7 +712,7 @@ class LLaDABlock(nn.Module):
             q,
             k,
             v,
-            attn_mask=None,
+            attn_mask=attention_bias,
             dropout_p=0.0 if not self.training else self.config.attention_dropout,
             is_causal=False,
         )
@@ -1156,7 +1157,20 @@ class LLaDAModel(nn.Module):
             alibi_bias = alibi_attention_bias(seq_len, self.config, device)
         self.__cache["alibi_attention_bias"] = alibi_bias
         return alibi_bias
-
+
+    def get_bidirectional_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
+        if (bidirectional_bias := self.__cache.get("bidirectional_attention_bias")) is not None and bidirectional_bias.shape[
+            -1
+        ] >= seq_len:
+            if bidirectional_bias.device != device:
+                bidirectional_bias = bidirectional_bias.to(device)
+                self.__cache["bidirectional_attention_bias"] = bidirectional_bias
+            return bidirectional_bias
+        with torch.autocast(device.type, enabled=False):
+            bidirectional_bias = torch.zeros((1, 1, seq_len, seq_len), device=device, dtype=torch.float)
+        self.__cache["bidirectional_attention_bias"] = bidirectional_bias
+        return bidirectional_bias
+
     def forward(
         self,
         input_ids: torch.LongTensor,
@@ -1176,16 +1190,20 @@ class LLaDAModel(nn.Module):
             which input IDs are masked. A `1` value in the mask means that
             the corresponding input ID should *not* be ignored. A `0` means
             that the corresponding input ID is masked.
+
             This has the same meaning as the `attention_mask` in HuggingFace's `transformers`
             library.
         :param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`,
             `(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used
             to introduce causal or other biases.
+
             If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]`
             indicates that the i-th element in the sequence is allowed to attend to the j-th
             element in the sequence.
+
             If the tensor is a float tensor, it will just be added to the attention
             scores before the softmax.
+
             The default is causal, which corresponds to a lower-diagonal byte matrix of ones.
         :param past_key_values: Pre-computed keys and values for each attention block.
             Can be used to speed up sequential decoding. The `input_ids` which have
@@ -1252,7 +1270,7 @@ class LLaDAModel(nn.Module):
                     self.__cache, past_length + seq_len, x.device
                 ) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
             elif attention_bias is None:
-                attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
+                attention_bias = self.get_bidirectional_attention_bias(past_length + seq_len, x.device)
             elif attention_bias.dtype in (torch.int8, torch.bool):
                 attention_bias = attention_bias.to(dtype=torch.float)
                 attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
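
With this change, an `attention_mask` passed to `forward` is no longer dropped on the SDPA path: the resulting `attention_bias` now reaches `F.scaled_dot_product_attention` through `attn_mask`, so padding positions can be excluded from attention. The following is a minimal usage sketch, not part of the commit; the checkpoint name and the EOS-as-pad fallback are assumptions.

# Hypothetical usage sketch, not from this commit. The repo id and pad-token
# fallback are assumptions; adjust them to the checkpoint that ships this file.
import torch
from transformers import AutoModel, AutoTokenizer

repo_id = "GSAI-ML/LLaDA-8B-Base"  # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # common fallback when no pad token is defined

prompts = ["A short prompt.", "A noticeably longer prompt that forces padding in the batch."]
batch = tokenizer(prompts, return_tensors="pt", padding=True)

# attention_mask follows the HuggingFace convention documented in forward():
# 1 = attend to this token, 0 = padding to be ignored.
with torch.no_grad():
    outputs = model(batch["input_ids"], attention_mask=batch["attention_mask"])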
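
For callers that supply `attention_bias` directly, the `forward` branch above converts a bool or int8 bias to an additive float bias by replacing zeros with `torch.finfo(dtype).min`, and when no bias is given the new `get_bidirectional_attention_bias` returns an all-zero bias, i.e. full bidirectional attention rather than the causal default the docstring still describes. A standalone sketch of that arithmetic (not the model's code):

import torch

seq_len = 4

# Bool bias in the documented (1, 1, seq_len, seq_len) layout:
# True at [..., i, j] means position i may attend to position j.
allowed = torch.ones(1, 1, seq_len, seq_len, dtype=torch.bool)
allowed[..., :, -1] = False  # e.g. treat the last position as padding

# The same conversion forward() applies to int8/bool biases:
bias = allowed.to(dtype=torch.float)
bias.masked_fill_(bias == 0.0, torch.finfo(bias.dtype).min)

# The new default bias is all zeros, so adding it leaves the scores unchanged
# (full attention, matching the "MDM set causal to False" comment).
bidirectional = torch.zeros(1, 1, seq_len, seq_len)

scores = torch.randn(1, 1, seq_len, seq_len)
probs = torch.softmax(scores + bias, dim=-1)
print(probs[0, 0, :, -1])  # ~0 everywhere: no attention mass on the masked column

probs_full = torch.softmax(scores + bidirectional, dim=-1)
assert torch.equal(probs_full, torch.softmax(scores, dim=-1))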