add mask_first_token

- config.json +1 -0
- modeling_lsg_bart.py +6 -1
config.json CHANGED

@@ -52,6 +52,7 @@
   },
   "length_penalty": 2.0,
   "lsh_num_pre_rounds": 1,
+  "mask_first_token": false,
   "max_length": 512,
   "max_position_embeddings": 4096,
   "min_length": 128,
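The new key defaults to false, so existing checkpoints keep their previous behaviour. A minimal sketch of turning it on at load time; the checkpoint path is a placeholder, and this assumes the custom LSGBartConfig is picked up through trust_remote_code, with mask_first_token overridable like any other config key:

# Minimal sketch (hypothetical checkpoint path, not a real model id).
from transformers import AutoConfig, AutoModelForSeq2SeqLM

ckpt = "path/to/lsg-bart-checkpoint"  # placeholder
config = AutoConfig.from_pretrained(ckpt, trust_remote_code=True, mask_first_token=True)
model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, config=config, trust_remote_code=True)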
modeling_lsg_bart.py CHANGED

@@ -31,6 +31,7 @@ class LSGBartConfig(BartConfig):
         base_model_prefix="lsg",
         block_size=128,
         lsh_num_pre_rounds=1,
+        mask_first_token=False,
         num_global_tokens=1,
         pass_global_tokens_to_decoder=True,
         pool_with_global=True,
@@ -47,6 +48,7 @@ class LSGBartConfig(BartConfig):
         self.base_model_prefix = base_model_prefix
         self.block_size = block_size
         self.lsh_num_pre_rounds = lsh_num_pre_rounds
+        self.mask_first_token = mask_first_token
         self.num_global_tokens = num_global_tokens
         self.pass_global_tokens_to_decoder = pass_global_tokens_to_decoder
         self.pool_with_global = pool_with_global
@@ -711,6 +713,7 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
         assert hasattr(config, "block_size") and hasattr(config, "adaptive")
         self.block_size = config.block_size
         self.adaptive = config.adaptive
+        self.mask_first_token = config.mask_first_token
         self.pool_with_global = config.pool_with_global
         self.pass_global_tokens_to_decoder = config.pass_global_tokens_to_decoder

@@ -737,7 +740,9 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):

         if attention_mask is None:
             attention_mask = torch.ones(n, t, device=inputs_.device)
-
+        if self.mask_first_token:
+            attention_mask[:, 0] = 0
+
         b = self.block_size * 2
         pad = t % self.block_size
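For reference, a standalone sketch of what the new encoder branch does, with variable names following the diff (n is the batch size, t the sequence length; the rest of LSGBartEncoder.forward is omitted). When mask_first_token is enabled, the first position of the attention mask is zeroed, so the first token (typically BOS) is excluded from attention, presumably because it duplicates the prepended global token:

import torch

n, t = 2, 8
mask_first_token = True

attention_mask = torch.ones(n, t)      # built automatically when none is passed
if mask_first_token:
    attention_mask[:, 0] = 0           # first token no longer attended to

print(attention_mask[0])               # tensor([0., 1., 1., 1., 1., 1., 1., 1.])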