
Commit 918b247

Make room for transformer changes requiring attn_mask fn
Summary: Latest changes in transformers 4.53 require that custom attention functions have a corresponding mask-generation function.
Test Plan: CI won't pass without these changes on top of 4.53.
Reviewers:
Subscribers:
Tasks:
Tags:
1 parent afafb97 commit 918b247
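For context, a minimal sketch of the pattern that transformers >= 4.53 expects and that this commit accommodates: a custom attention function registered through AttentionInterface should have a mask-generation function registered under the same key through AttentionMaskInterface. The key name and the fallback implementation below are illustrative only, not part of this commit.

# Sketch only: registration pattern assumed by transformers >= 4.53.
# "my_custom_sdpa" and its body are hypothetical placeholders.
import torch.nn.functional as F

from transformers.integrations.executorch import sdpa_mask_without_vmap
from transformers.masking_utils import AttentionMaskInterface
from transformers.modeling_utils import AttentionInterface


def my_custom_sdpa(module, query, key, value, attention_mask, **kwargs):
    # Hypothetical custom attention; here it simply defers to PyTorch SDPA.
    out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
    return out, None


# Register the attention function and its mask function under the same name,
# then select it via the model config, e.g.:
#   model.config._attn_implementation = "my_custom_sdpa"
AttentionInterface.register("my_custom_sdpa", my_custom_sdpa)
AttentionMaskInterface.register("my_custom_sdpa", sdpa_mask_without_vmap)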

6 files changed (+42, -17 lines)

install_dev.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@

 def install_torch_nightly_deps():
     """Install torch related dependencies from pinned nightly"""
-    EXECUTORCH_NIGHTLY_VERSION = "dev20250620"
+    EXECUTORCH_NIGHTLY_VERSION = "dev20250625"
     TORCHAO_NIGHTLY_VERSION = "dev20250620"
     # Torch nightly is aligned with pinned nightly in https://github.com/pytorch/executorch/blob/main/install_requirements.py#L74
     TORCH_NIGHTLY_VERSION = "dev20250601"
@@ -34,7 +34,7 @@ def install_dep_from_source():
             "-m",
             "pip",
             "install",
-            "git+https://github.com/huggingface/transformers@51f94ea06d19a6308c61bbb4dc97c40aabd12bad#egg=transformers",  # v4.52.4
+            "git+https://github.com/huggingface/transformers@37367c7d9fd23401c26e79a2b26253ab2d1b7236#egg=transformers",
         ]
     )
     subprocess.check_call(

optimum/executorch/attentions/custom_kv_cache.py

Lines changed: 2 additions & 2 deletions
@@ -213,7 +213,7 @@ def __init__(
         for layer_idx in range(config.num_hidden_layers):
             # newer version of transfomer has is_sliding defined
             # for HybridCache
-            if self.is_sliding_list[layer_idx]:
+            if self.is_sliding[layer_idx]:
                 # This is a sliding window layer
                 layer_cache = CustomRingKVCache(
                     max_batch_size=self.max_batch_size,
@@ -388,7 +388,7 @@ def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dt
     for i in range(len(module.cache.kv_cache)):
         setattr(module, f"key_cache_{i}", module.cache.kv_cache[i].k_cache)
         setattr(module, f"value_cache_{i}", module.cache.kv_cache[i].v_cache)
-        if module.cache.is_sliding_list[i]:
+        if module.cache.is_sliding[i]:
             # Register cache_positions as buffer for sliding window layers
             # This prevents it from being traced as a constant
             module.register_buffer(

optimum/exporters/executorch/integrations.py

Lines changed: 22 additions & 7 deletions
@@ -38,13 +38,33 @@ class CausalLMExportableModule(torch.nn.Module):
     This module ensures that the exported model is compatible with ExecuTorch.
     """

-    def __init__(self, model, use_custom_kv_cache=False):
+    def __init__(self, model, use_custom_kv_cache=False, use_custom_sdpa=False):
         super().__init__()
         self.model = model
         self.config = model.config
         self.use_custom_kv_cache = use_custom_kv_cache
+        self.use_custom_sdpa = use_custom_sdpa
         self.metadata = save_config_to_constant_methods(model.config, model.generation_config)

+    def _register_attention_mask_for_4_53(self, exportable_module: torch.nn.Module):
+        if is_transformers_version(">=", "4.53.0.dev0"):
+            from transformers.integrations.executorch import sdpa_mask_without_vmap
+            from transformers.masking_utils import AttentionMaskInterface
+            from transformers.modeling_utils import AttentionInterface
+
+            _custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache(exportable_module)
+            if self.use_custom_sdpa:
+                if self.use_custom_kv_cache:
+                    AttentionInterface.register("custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache)
+                    AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", sdpa_mask_without_vmap)
+                    # Manually set the attention implementation to custom_sdpa_ring_kv_cache
+                    # This handles both regular sdpa and one for sliding window/local attention
+                    exportable_module.model.model.config._attn_implementation = "custom_sdpa_ring_kv_cache"
+                else:
+                    # Manually set the attention implementation to custom_sdpa_ring_kv_cache
+                    # This handles both regular sdpa and one for sliding window/local attention
+                    exportable_module.model.model.config._attn_implementation = "custom_sdpa"
+
     def export(self, input_ids=None, cache_position=None) -> Dict[str, ExportedProgram]:
         example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long)
         example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)
@@ -57,12 +77,7 @@ def export(self, input_ids=None, cache_position=None) -> Dict[str, ExportedProgr
             max_batch_size = 1
             max_cache_len = 4094
             exportable_module = TorchExportableModuleForDecoderOnlyLM(self.model, max_batch_size, max_cache_len)
-
-            from transformers.modeling_utils import AttentionInterface
-
-            _custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache(exportable_module)
-            AttentionInterface.register("custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache)
-            exportable_module.model.model.config._attn_implementation = "custom_sdpa_ring_kv_cache"
+            self._register_attention_mask_for_4_53(exportable_module)
             if self.use_custom_kv_cache:
                 from optimum.executorch.attentions.custom_kv_cache import (
                     replace_with_et_custom_kv_cache,

optimum/exporters/executorch/recipes/xnnpack.py

Lines changed: 4 additions & 1 deletion
@@ -97,7 +97,10 @@ def _lower_to_executorch(

     exported_progs = model.export()

-    if model.config._attn_implementation == "custom_sdpa":
+    if (
+        model.config._attn_implementation == "custom_sdpa"
+        or model.config._attn_implementation == "custom_sdpa_ring_kv_cache"
+    ):
         # Sanity check to make sure the exported program contains the custom sdpa operator.
         if not any(
             node.op == "call_function" and "custom_sdpa" in str(node.target)
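As a reading aid, a minimal sketch of the sanity check this recipe now applies to both custom SDPA variants; the helper name is hypothetical, and the node test mirrors the diff above.

# Sketch: confirm an exported program contains a custom_sdpa call before lowering.
def contains_custom_sdpa(exported_program) -> bool:
    return any(
        node.op == "call_function" and "custom_sdpa" in str(node.target)
        for node in exported_program.graph_module.graph.nodes
    )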

optimum/exporters/executorch/tasks/causal_lm.py

Lines changed: 2 additions & 1 deletion
@@ -58,6 +58,7 @@ def load_causal_lm_model(model_name_or_path: str, **kwargs) -> CausalLMExportabl
     use_custom_kv_cache = kwargs.get("use_custom_kv_cache", False)
     attn_implementation = kwargs.get("attn_implementation", "custom_sdpa" if use_custom_sdpa else "sdpa")
     cache_implementation = kwargs.get("cache_implementation", "static")
+    use_custom_sdpa = use_custom_sdpa or attn_implementation == "custom_sdpa"
     max_length = kwargs.get("max_length", 2048)
     config = kwargs.get("config", None)

@@ -126,4 +127,4 @@ def load_causal_lm_model(model_name_or_path: str, **kwargs) -> CausalLMExportabl
 
     unwrap_tensor_subclass(eager_model)
 
-    return CausalLMExportableModule(eager_model, use_custom_kv_cache)
+    return CausalLMExportableModule(eager_model, use_custom_kv_cache, use_custom_sdpa)
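A minimal usage sketch of how either kwarg now reaches CausalLMExportableModule; the invocation and model id are hypothetical placeholders, not taken from this commit.

# Sketch: both flags flow into CausalLMExportableModule after this change.
module = load_causal_lm_model(
    "some-org/some-causal-lm",  # placeholder model id
    use_custom_sdpa=True,       # passing attn_implementation="custom_sdpa" also enables this
    use_custom_kv_cache=True,
)
exported_programs = module.export()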

tests/models/test_modeling_gemma3.py

Lines changed: 10 additions & 4 deletions
@@ -112,6 +112,10 @@ def test_gemma3_text_generation(self):
     @slow
     @pytest.mark.run_slow
     @pytest.mark.skipif(is_linux_ci, reason="OOM on linux runner")
+    @pytest.mark.skipif(
+        parse(transformers.__version__) < parse("4.53.0.dev0") or parse(torchao.__version__) < parse("0.11.0"),
+        reason="Only available on transformers >= 4.53.0.dev0 and torchao >= 0.11.0",
+    )
     def test_gemma3_text_generation_with_custom_sdpa(self):
         # TODO: Until https://github.com/huggingface/optimum/issues/2127 is fixed, have to use non-gated model on CI
         # model_id = "google/gemma-3-1b-it"
@@ -181,8 +185,8 @@ def test_gemma3_text_generation_with_custom_sdpa_float16(self):
     @slow
     @pytest.mark.run_slow
     @pytest.mark.skipif(
-        parse(torchao.__version__) < parse("0.11.0.dev0"),
-        reason="Only available on torchao >= 0.11.0.dev0",
+        parse(transformers.__version__) < parse("4.53.0.dev0") or parse(torchao.__version__) < parse("0.11.0"),
+        reason="Only available on transformers >= 4.53.0.dev0 and torchao >= 0.11.0",
     )
     def test_gemma3_text_generation_with_custom_sdpa_8da4w_8we(self):
         # TODO: Until https://github.com/huggingface/optimum/issues/2127 is fixed, have to use non-gated model on CI
@@ -217,9 +221,11 @@ def test_gemma3_text_generation_with_custom_sdpa_8da4w_8we(self):
 
         self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
 
+    @slow
+    @pytest.mark.run_slow
     @pytest.mark.skipif(
-        parse(transformers.__version__) < parse("4.52.0") or parse(torchao.__version__) < parse("0.11.0"),
-        reason="Only available on transformers >= 4.52.0 and torchao >= 0.11.0",
+        parse(transformers.__version__) < parse("4.53.0.dev0") or parse(torchao.__version__) < parse("0.11.0"),
+        reason="Only available on transformers >= 4.53.0.dev0 and torchao >= 0.11.0",
     )
     def test_gemma3_text_generation_with_custom_sdpa_kv_cache_8da4w_8we(self):
         # TODO: Until https://github.com/huggingface/optimum/issues/2127 is fixed, have to use non-gated model on CI
