Commit 71ef017 · set use_flash_attn if not available
Parent(s): 807ba34

Files changed:
- mha.py (+0, -2)
- modeling_xlm_roberta.py (+16, -3)
mha.py CHANGED
@@ -10,8 +10,6 @@ import torch
 import torch.nn as nn
 from einops import rearrange, repeat
 
-from flash_attn.utils.distributed import get_dim_for_local_rank
-
 try:
     from flash_attn import (
         flash_attn_kvpacked_func,
modeling_xlm_roberta.py CHANGED
@@ -1,6 +1,5 @@
 # This implementation was adopted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/bert.py
 # Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
-
 # Copyright (c) 2022, Tri Dao.
 # This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
 # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
@@ -8,6 +7,7 @@
 
 # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
 
+import importlib.util
 import logging
 import re
 from collections import OrderedDict
@@ -65,8 +65,21 @@ except ImportError:
 logger = logging.getLogger(__name__)
 
 
+def get_use_flash_attn(config: XLMRobertaFlashConfig):
+    if not config.use_flash_attn:
+        return False
+    if not torch.cuda.is_available():
+        return False
+    if importlib.util.find_spec("flash_attn") is None:
+        logger.warning(
+            'flash_attn is not installed. Using PyTorch native attention implementation.'
+        )
+        return False
+    return True
+
+
 def create_mixer_cls(config, cross_attn=False, return_residual=False):
-    use_flash_attn =
+    use_flash_attn = get_use_flash_attn(config)
     fused_bias_fc = getattr(config, "fused_bias_fc", False)
     rotary_kwargs = {}
     if config.position_embedding_type == "rotary":
@@ -169,7 +182,7 @@ def _init_weights(module, initializer_range=0.02):
 class XLMRobertaEncoder(nn.Module):
     def __init__(self, config: XLMRobertaFlashConfig):
         super().__init__()
-        self.use_flash_attn =
+        self.use_flash_attn = get_use_flash_attn(config)
         self.layers = nn.ModuleList(
             [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
         )
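
For reference, below is a minimal standalone sketch of the fallback logic this commit introduces. The SimpleNamespace config is a hypothetical stand-in for XLMRobertaFlashConfig, used only for illustration; on a machine without a CUDA device or without the flash_attn package installed, the helper returns False and the model falls back to PyTorch's native attention implementation.

import importlib.util
import logging
from types import SimpleNamespace

import torch

logger = logging.getLogger(__name__)


def get_use_flash_attn(config):
    # Same checks as the helper added in this commit: use flash attention only
    # when it is requested, a CUDA device is present, and flash_attn is importable.
    if not config.use_flash_attn:
        return False
    if not torch.cuda.is_available():
        return False
    if importlib.util.find_spec("flash_attn") is None:
        logger.warning(
            "flash_attn is not installed. Using PyTorch native attention implementation."
        )
        return False
    return True


# Hypothetical stand-in config; the real model passes an XLMRobertaFlashConfig.
config = SimpleNamespace(use_flash_attn=True)
print(get_use_flash_attn(config))  # False on CPU-only machines or when flash_attn is missing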