Monohydroxides committed on
Commit 84346fd · verified · 1 Parent(s): c8d3a40

Upload modeling_llada.py to support attention_mask as an input.
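For context, this lets callers pass a `transformers`-style padding mask straight into the model's forward pass. The snippet below is a minimal usage sketch, not part of this commit: the repository id is a hypothetical placeholder, and it assumes the repo registers its classes for `trust_remote_code` loading and that the tokenizer defines a pad token.

```python
# Minimal usage sketch (assumptions: placeholder repo id, remote-code classes registered,
# tokenizer has a pad token). Not taken from this repository.
import torch
from transformers import AutoModel, AutoTokenizer

repo_id = "your-namespace/your-llada-repo"  # hypothetical placeholder
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True, torch_dtype=torch.bfloat16)

batch = tokenizer(
    ["A short prompt.", "A noticeably longer prompt that forces padding in the batch."],
    return_tensors="pt",
    padding=True,
)

with torch.no_grad():
    # attention_mask follows the transformers convention: 1 = attend to this token, 0 = padding.
    outputs = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
```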

Files changed (1)
  1. modeling_llada.py +23 -5
modeling_llada.py CHANGED
@@ -87,6 +87,7 @@ def init_weights(
 ) -> None:
     """
     Initialize weights of a linear or embedding module.
+
     :param config: The model config.
     :param module: The linear or embedding submodule to initialize.
     :param d: The effective input dimensionality of the weights. This could be smaller than the actual dimensions
@@ -648,12 +649,12 @@ class LLaDABlock(nn.Module):
             k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
             v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
 
-        # Modify: MDM set causal to False, and with no attn_mask.
+        # Modify: MDM set causal to False.
         return F.scaled_dot_product_attention(
             q,
             k,
             v,
-            attn_mask=None,
+            attn_mask=attn_mask,
             dropout_p=dropout_p,
             is_causal=False,
         )
@@ -711,7 +712,7 @@ class LLaDABlock(nn.Module):
                 q,
                 k,
                 v,
-                attn_mask=None,
+                attn_mask=attention_bias,
                 dropout_p=0.0 if not self.training else self.config.attention_dropout,
                 is_causal=False,
             )
@@ -1156,7 +1157,20 @@ class LLaDAModel(nn.Module):
             alibi_bias = alibi_attention_bias(seq_len, self.config, device)
         self.__cache["alibi_attention_bias"] = alibi_bias
         return alibi_bias
-
+
+    def get_bidirectional_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
+        if (bidirectional_bias := self.__cache.get("bidirectional_attention_bias")) is not None and bidirectional_bias.shape[
+            -1
+        ] >= seq_len:
+            if bidirectional_bias.device != device:
+                bidirectional_bias = bidirectional_bias.to(device)
+                self.__cache["bidirectional_attention_bias"] = bidirectional_bias
+            return bidirectional_bias
+        with torch.autocast(device.type, enabled=False):
+            bidirectional_bias = torch.zeros((1, 1, seq_len, seq_len), device=device, dtype=torch.float)
+        self.__cache["bidirectional_attention_bias"] = bidirectional_bias
+        return bidirectional_bias
+
     def forward(
         self,
         input_ids: torch.LongTensor,
@@ -1176,16 +1190,20 @@ class LLaDAModel(nn.Module):
             which input IDs are masked. A `1` value in the mask means that
             the corresponding input ID should *not* be ignored. A `0` means
             that the corresponding input ID is masked.
+
            This has the same meaning as the `attention_mask` in HuggingFace's `transformers`
            library.
        :param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`,
            `(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used
            to introduce causal or other biases.
+
            If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]`
            indicates that the i-th element in the sequence is allowed to attend to the j-th
            element in the sequence.
+
            If the tensor is a float tensor, it will just be added to the attention
            scores before the softmax.
+
            The default is causal, which corresponds to a lower-diagonal byte matrix of ones.
        :param past_key_values: Pre-computed keys and values for each attention block.
            Can be used to speed up sequential decoding. The `input_ids` which have
@@ -1252,7 +1270,7 @@ class LLaDAModel(nn.Module):
                     self.__cache, past_length + seq_len, x.device
                 ) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
             elif attention_bias is None:
-                attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
+                attention_bias = self.get_bidirectional_attention_bias(past_length + seq_len, x.device)
             elif attention_bias.dtype in (torch.int8, torch.bool):
                 attention_bias = attention_bias.to(dtype=torch.float)
                 attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
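The docstring and the `masked_fill_` line above describe the bias semantics: a 0/1 (bool or byte) mask is converted to an additive float bias, and the new `get_bidirectional_attention_bias` replaces the former causal default with an all-zeros (full-attention) bias. Below is a small self-contained sketch of that logic under those semantics; it is not the repository's code, and all tensor names and shapes are illustrative only.

```python
# Standalone sketch of the bias semantics in the diff above; not the repository's code.
import torch
import torch.nn.functional as F

batch_size, n_heads, seq_len, head_dim = 2, 4, 6, 8
q = torch.randn(batch_size, n_heads, seq_len, head_dim)
k = torch.randn(batch_size, n_heads, seq_len, head_dim)
v = torch.randn(batch_size, n_heads, seq_len, head_dim)

# A transformers-style padding mask: 1 = keep, 0 = padding (last two tokens of sample 1 padded).
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
attention_mask[1, -2:] = False

# Expand to (batch, 1, 1, seq_len) and convert to an additive float bias, mirroring
# `attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)`.
bias = attention_mask[:, None, None, :].to(torch.float)
bias = bias.masked_fill(bias == 0.0, torch.finfo(bias.dtype).min)

# Bidirectional default (what `get_bidirectional_attention_bias` caches): an all-zeros bias,
# i.e. every position may attend to every other position, unlike a causal lower-triangular mask.
bidirectional_bias = torch.zeros(1, 1, seq_len, seq_len)

# The combined bias is passed as attn_mask with is_causal=False, as in the patched SDPA call.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias + bidirectional_bias, is_causal=False)
print(out.shape)  # torch.Size([2, 4, 6, 8])
```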