Commit
·
943cec2
1
Parent(s):
c55e591
feat: truncation option during init
Browse filesSigned-off-by: jupyterjazz <[email protected]>
- configuration_xlm_roberta.py +2 -0
- modeling_xlm_roberta.py +1 -0
configuration_xlm_roberta.py
CHANGED
|
@@ -32,6 +32,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
|
|
| 32 |
torch_dtype=None,
|
| 33 |
emb_pooler=None,
|
| 34 |
matryoshka_dimensions=None,
|
|
|
|
| 35 |
**kwargs,
|
| 36 |
):
|
| 37 |
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
|
@@ -61,6 +62,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
|
|
| 61 |
self.use_flash_attn = use_flash_attn
|
| 62 |
self.emb_pooler = emb_pooler
|
| 63 |
self.matryoshka_dimensions = matryoshka_dimensions
|
|
|
|
| 64 |
if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
|
| 65 |
self.torch_dtype = getattr(torch, torch_dtype)
|
| 66 |
else:
|
|
|
|
| 32 |
torch_dtype=None,
|
| 33 |
emb_pooler=None,
|
| 34 |
matryoshka_dimensions=None,
|
| 35 |
+
truncate_dim=None,
|
| 36 |
**kwargs,
|
| 37 |
):
|
| 38 |
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
|
|
|
| 62 |
self.use_flash_attn = use_flash_attn
|
| 63 |
self.emb_pooler = emb_pooler
|
| 64 |
self.matryoshka_dimensions = matryoshka_dimensions
|
| 65 |
+
self.truncate_dim = truncate_dim
|
| 66 |
if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
|
| 67 |
self.torch_dtype = getattr(torch, torch_dtype)
|
| 68 |
else:
|
modeling_xlm_roberta.py
CHANGED
|
@@ -578,6 +578,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
| 578 |
|
| 579 |
all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
|
| 580 |
|
|
|
|
| 581 |
if truncate_dim:
|
| 582 |
all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
|
| 583 |
|
|
|
|
| 578 |
|
| 579 |
all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
|
| 580 |
|
| 581 |
+
truncate_dim = truncate_dim or self.config.truncate_dim
|
| 582 |
if truncate_dim:
|
| 583 |
all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
|
| 584 |
|