sync code
Browse files
- configuration_aria.py +7 -2
- modeling_aria.py +8 -0
- moe_lm.py +4 -2
configuration_aria.py
CHANGED
|
@@ -66,14 +66,19 @@ class AriaConfig(PretrainedConfig):
|
|
| 66 |
},
|
| 67 |
ignore_index=-100,
|
| 68 |
image_token_index=32000,
|
|
|
|
| 69 |
**kwargs,
|
| 70 |
):
|
| 71 |
super().__init__(**kwargs)
|
| 72 |
self.ignore_index = ignore_index
|
| 73 |
self.image_token_index = image_token_index
|
| 74 |
-
|
| 75 |
attn_implementation = kwargs.pop("attn_implementation", None)
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
# Convert the keys and values of projector_patch_to_query_dict to integers
|
| 79 |
# This ensures consistency even if they were provided as strings
|
|
|
|
| 66 |
},
|
| 67 |
ignore_index=-100,
|
| 68 |
image_token_index=32000,
|
| 69 |
+
tie_word_embeddings=False,
|
| 70 |
**kwargs,
|
| 71 |
):
|
| 72 |
super().__init__(**kwargs)
|
| 73 |
self.ignore_index = ignore_index
|
| 74 |
self.image_token_index = image_token_index
|
| 75 |
+
self.tie_word_embeddings = tie_word_embeddings
|
| 76 |
attn_implementation = kwargs.pop("attn_implementation", None)
|
| 77 |
+
|
| 78 |
+
# Set the default attention implementation to flash_attention_2 if not specified
|
| 79 |
+
self._attn_implementation = (
|
| 80 |
+
"flash_attention_2" if attn_implementation is None else attn_implementation
|
| 81 |
+
)
|
| 82 |
|
| 83 |
# Convert the keys and values of projector_patch_to_query_dict to integers
|
| 84 |
# This ensures consistency even if they were provided as strings
|
modeling_aria.py
CHANGED
|
@@ -165,6 +165,14 @@ class AriaForConditionalGeneration(AriaPretrainedModel, GenerationMixin):
|
|
| 165 |
"""Set the input embeddings for the language model."""
|
| 166 |
self.language_model.set_input_embeddings(value)
|
| 167 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
def set_moe_z_loss_coeff(self, value):
|
| 169 |
"""
|
| 170 |
Set the z-loss coefficient for Mixture of Experts (MoE) models.
|
|
|
|
| 165 |
"""Set the input embeddings for the language model."""
|
| 166 |
self.language_model.set_input_embeddings(value)
|
| 167 |
|
| 168 |
+
def get_output_embeddings(self):
|
| 169 |
+
"""Retrieve the output embeddings from the language model."""
|
| 170 |
+
return self.language_model.get_output_embeddings()
|
| 171 |
+
|
| 172 |
+
def set_output_embeddings(self, value):
|
| 173 |
+
"""Set the output embeddings for the language model."""
|
| 174 |
+
self.language_model.set_output_embeddings(value)
|
| 175 |
+
|
| 176 |
def set_moe_z_loss_coeff(self, value):
|
| 177 |
"""
|
| 178 |
Set the z-loss coefficient for Mixture of Experts (MoE) models.
|
moe_lm.py
CHANGED
|
@@ -255,7 +255,8 @@ class TopKRouter(nn.Module):
|
|
| 255 |
- top_indices: Indices of top-k experts for each token.
|
| 256 |
- tokens_per_expert: Number of tokens assigned to each expert.
|
| 257 |
"""
|
| 258 |
-
|
|
|
|
| 259 |
|
| 260 |
top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1)
|
| 261 |
scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32).type_as(logits)
|
|
@@ -267,7 +268,8 @@ class TopKRouter(nn.Module):
|
|
| 267 |
max=self.config.moe_num_experts - 1,
|
| 268 |
)
|
| 269 |
|
| 270 |
-
|
|
|
|
| 271 |
return scores, top_indices, tokens_per_expert
|
| 272 |
|
| 273 |
def forward(
|
|
|
|
| 255 |
- top_indices: Indices of top-k experts for each token.
|
| 256 |
- tokens_per_expert: Number of tokens assigned to each expert.
|
| 257 |
"""
|
| 258 |
+
if self.training:
|
| 259 |
+
logits = self.apply_z_loss(logits)
|
| 260 |
|
| 261 |
top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1)
|
| 262 |
scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32).type_as(logits)
|
|
|
|
| 268 |
max=self.config.moe_num_experts - 1,
|
| 269 |
)
|
| 270 |
|
| 271 |
+
if self.training:
|
| 272 |
+
scores = self.apply_aux_loss(logits, tokens_per_expert, scores)
|
| 273 |
return scores, top_indices, tokens_per_expert
|
| 274 |
|
| 275 |
def forward(
|