Hack90 committed
Commit 8f80b43 · verified · 1 Parent(s): 2b126ef

Update modeling_llada.py

Files changed (1)
  1. modeling_llada.py +5 -6
modeling_llada.py CHANGED
@@ -30,7 +30,6 @@ from transformers import PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.auto import AutoModel
 from transformers.cache_utils import Cache
-from transformers import AutoConfig
 
 from .configuration_llada import (
     LLaDAConfig,
@@ -637,7 +636,7 @@ class LLaDABlock(nn.Module):
         """
         if self.flash_attn_func is not None and attn_mask is None:
             r = self.flash_attn_func(
-                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p, causal=is_causal
+                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p, causal=False
             )
             return r.transpose(1, 2)
         else:
@@ -657,7 +656,7 @@ class LLaDABlock(nn.Module):
                 v,
                 attn_mask=None,
                 dropout_p=dropout_p,
-                # is_causal=False,
+                is_causal=False,
             )
 
     def attention(
@@ -713,9 +712,9 @@ class LLaDABlock(nn.Module):
             q,
             k,
             v,
-            attn_mask=attention_bias,
+            attn_mask=None,
             dropout_p=0.0 if not self.training else self.config.attention_dropout,
-            is_causal=attention_bias is None,
+            is_causal=False,
         )
 
         # Re-assemble all head outputs side-by-side.
@@ -1491,4 +1490,4 @@ class LLaDAModelLM(PreTrainedModel):
         self.model.transformer.ff_out = self.model.transformer.wte
 
 # Register the model so that it is available for transformer pipelines, auto-loading, etc.
-AutoModel.register(LLaDAConfig, LLaDAModelLM)
+AutoModel.register(LLaDAConfig, LLaDAModelLM)
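Taken together, the attention hunks make both the FlashAttention path (`causal=False`) and the `torch.nn.functional.scaled_dot_product_attention` fallback (`attn_mask=None`, `is_causal=False`) unconditionally non-causal, matching the bidirectional attention a masked diffusion model such as LLaDA expects. A minimal sketch of what the enforced flags do, in plain PyTorch with illustrative shapes (none of the sizes below come from the model config):

```python
import torch
import torch.nn.functional as F

# Illustrative sizes only: batch, heads, sequence length, head dim.
B, H, T, D = 1, 2, 4, 8
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

# The path this commit enforces: no mask, is_causal=False,
# so every position attends to every other position.
out_bidir = F.scaled_dot_product_attention(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
)

# For contrast, the causal path the old code could take when
# attention_bias was None.
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# The outputs differ at every position except the last one, which
# already sees the whole sequence under a causal mask.
print(torch.allclose(out_bidir, out_causal))                          # False
print(torch.allclose(out_bidir[..., -1, :], out_causal[..., -1, :]))  # True
```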
 
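The restored `AutoModel.register(LLaDAConfig, LLaDAModelLM)` call runs when `modeling_llada.py` is imported, so the custom class resolves through the Auto API. A usage sketch under the usual remote-code loading flow; the repo id is a placeholder, not taken from this commit:

```python
from transformers import AutoModel

# trust_remote_code=True lets transformers execute the repo's
# modeling_llada.py, including the AutoModel.register call at its end.
model = AutoModel.from_pretrained(
    "some-org/llada-checkpoint",  # hypothetical repo id
    trust_remote_code=True,
)
```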