feat: check flash-attn version if installed
#15
by reedcli - opened
- modeling_yi.py +10 -3
modeling_yi.py  CHANGED

@@ -4,6 +4,7 @@ from typing import List, Optional, Tuple, Union
 
 import torch.utils.checkpoint
 from einops import repeat
+from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.activations import ACT2FN
@@ -25,8 +26,12 @@ from .configuration_yi import YiConfig
 
 is_flash_attn_available = True
 try:
-    from flash_attn import flash_attn_func
-except ModuleNotFoundError:
+    from flash_attn import flash_attn_func, __version__
+
+    assert version.parse(__version__) >= version.parse(
+        "2.3.0"
+    ), "please update your flash_attn version (>= 2.3.0)"
+except ModuleNotFoundError:
     is_flash_attn_available = False
 
 logger = logging.get_logger(__name__)
@@ -539,7 +544,9 @@ class YiModel(YiPreTrainedModel):
     def _prepare_decoder_attention_mask(
         self, attention_mask, input_ids, inputs_embeds, past_key_values_length
    ):
-        input_shape = input_ids.shape
+        input_shape = (
+            input_ids.shape if input_ids is not None else inputs_embeds.shape[:-1]
+        )
         # create causal mask
         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
         combined_attention_mask = None
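
The version check itself: flash_attn is imported together with its __version__, and packaging.version gates the flash-attention path on >= 2.3.0. Note that an outdated install trips the assert (an AssertionError), which the except ModuleNotFoundError handler does not catch, so importing the model fails loudly with the "please update" message; only a missing flash_attn package falls back to is_flash_attn_available = False. A minimal sketch of why the comparison goes through packaging.version rather than a plain string compare (the version strings below are illustrative, not taken from the PR):

    from packaging import version

    floor = version.parse("2.3.0")           # minimum required by this patch
    print(version.parse("2.2.1") >= floor)   # False -> the assert would fire
    print(version.parse("2.10.0") >= floor)  # True: segments compare numerically, while a raw
                                             # string compare would call "2.10.0" older than "2.3.0"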
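
The attention-mask hunk is independent of the version check: _prepare_decoder_attention_mask now derives input_shape from inputs_embeds when input_ids is None, so callers that pass embeddings directly still get a correctly shaped causal mask. A small shape check of the two branches (tensor sizes here are made up for illustration, assuming the usual [bsz, seq_len] ids and [bsz, seq_len, hidden] embeddings):

    import torch

    input_ids = torch.zeros(2, 16, dtype=torch.long)   # [bsz, seq_len]
    inputs_embeds = torch.zeros(2, 16, 4096)            # [bsz, seq_len, hidden]
    print(input_ids.shape)                               # torch.Size([2, 16])
    print(inputs_embeds.shape[:-1])                      # torch.Size([2, 16]) -- same target shape either way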