support eager attention
- config.json +2 -0
- modeling_aria.py +1 -0
- vision_encoder.py +0 -1
config.json CHANGED
@@ -30,8 +30,10 @@
   },
   "torch_dtype": "bfloat16",
   "transformers_version": "4.45.0",
+  "_attn_implementation": "flash_attention_2",
   "vision_config": {
     "_flash_attn_2_enabled": true,
+    "_attn_implementation": "flash_attention_2",
     "architectures": [
       "AriaVisionModel"
     ],
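Both the top-level config and "vision_config" now declare `_attn_implementation` explicitly, so flash attention remains the default but is stated where `transformers` resolves attention backends rather than being hardcoded in the model code. Together with the two code changes below, this is what lets a caller request eager attention at load time; see the usage sketch after the last diff.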
modeling_aria.py CHANGED
@@ -133,6 +133,7 @@ class AriaForConditionalGeneration(AriaPretrainedModel):
     def __init__(self, config: AriaConfig):
         super().__init__(config)
 
+        config.vision_config._attn_implementation = config._attn_implementation
         self.vision_tower = AriaVisionModel(config.vision_config)
         self.multi_modal_projector = build_mm_projector(config)
         self.vocab_size = config.text_config.vocab_size
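`AriaVisionModel` is built from `config.vision_config` alone, so a backend selected on the top-level config (for example via `from_pretrained(attn_implementation=...)`) would never reach the vision tower without this copy. A minimal, self-contained sketch of the mechanism, using a plain `SiglipVisionConfig` (the parent class of `AriaVisionConfig`) so it runs without the Aria code:

from transformers import SiglipVisionConfig

# Stand-in for config.vision_config; AriaVisionConfig subclasses SiglipVisionConfig.
vision_config = SiglipVisionConfig()
vision_config._attn_implementation = "flash_attention_2"  # what config.json declares

# The new line in AriaForConditionalGeneration.__init__: the backend chosen on
# the top-level config overwrites whatever the sub-config carried.
requested = "eager"  # e.g. from from_pretrained(attn_implementation="eager")
vision_config._attn_implementation = requested
print(vision_config._attn_implementation)  # -> eager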
vision_encoder.py CHANGED
@@ -38,7 +38,6 @@ class AriaVisionConfig(SiglipVisionConfig):
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self._attn_implementation = "flash_attention_2"
 
 
 class IdentityOp(torch.nn.Module):
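With the hardcoded assignment removed, `AriaVisionConfig` no longer overwrites the caller's choice, so eager attention works end to end. A minimal loading sketch (the "rhymes-ai/Aria" repo id is an assumption for illustration; it is not named in this commit):

import torch
from transformers import AutoModelForCausalLM

# Request eager attention; the kwarg sets config._attn_implementation, which the
# new line in AriaForConditionalGeneration.__init__ propagates to the vision tower.
model = AutoModelForCausalLM.from_pretrained(
    "rhymes-ai/Aria",        # illustrative repo id
    trust_remote_code=True,  # the checkpoint ships custom modeling code
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
)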