Skip to content

Commit

Permalink
Fix SDPA dispatch & make SDPA CI compatible with torch<2.1.1 (hugging…
Browse files Browse the repository at this point in the history
…face#27940)

fix sdpa dispatch
  • Loading branch information
fxmarty authored Dec 11, 2023
1 parent 7ea21f1 commit 9f18cc6
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
15 changes: 8 additions & 7 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1244,6 +1244,7 @@ def _autoset_attn_implementation(
# Here we use config._attn_implementation_internal to check whether the attention implementation was explicitly set by the user.
# The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager").
# The `hasattr` check is needed because some Transformers tests do not call `PretrainedConfig.__init__` (e.g. test_no_super_init_config_and_model).
requested_attn_implementation = None
if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None:
if config._attn_implementation != "flash_attention_2" and use_flash_attention_2:
raise ValueError(
Expand All @@ -1260,9 +1261,7 @@ def _autoset_attn_implementation(
raise ValueError(message + ".")

# If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
hard_check_only = True
else:
hard_check_only = False
requested_attn_implementation = config._attn_implementation_internal

if use_flash_attention_2:
logger.warning_once(
Expand All @@ -1275,13 +1274,15 @@ def _autoset_attn_implementation(
config,
torch_dtype=torch_dtype,
device_map=device_map,
hard_check_only=hard_check_only,
hard_check_only=False,
check_device_map=check_device_map,
)
elif cls._supports_sdpa or config._attn_implementation == "sdpa":
elif requested_attn_implementation in [None, "sdpa"]:
# use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
config = cls._check_and_enable_sdpa(config, hard_check_only=hard_check_only)
elif not hard_check_only:
config = cls._check_and_enable_sdpa(
config, hard_check_only=False if requested_attn_implementation is None else True
)
else:
config._attn_implementation = "eager"

return config
Expand Down
3 changes: 2 additions & 1 deletion tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
is_flax_available,
is_tf_available,
is_torch_fx_available,
is_torch_sdpa_available,
)
from transformers.utils.generic import ModelOutput

Expand Down Expand Up @@ -778,7 +779,7 @@ def _create_and_check_torchscript(self, config, inputs_dict):
configs_no_init.torchscript = True
for model_class in self.all_model_classes:
for attn_implementation in ["eager", "sdpa"]:
if attn_implementation == "sdpa" and not model_class._supports_sdpa:
if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()):
continue

configs_no_init._attn_implementation = attn_implementation
Expand Down

0 comments on commit 9f18cc6

Please sign in to comment.