
Commit fe38f14

Make room for transformers changes requiring attn_mask fn
Summary: Latest changes in transformers 4.53 require that custom attention functions have corresponding mask-generation functions.

Test Plan: CI won't pass without these changes on top of 4.53.

Reviewers:
Subscribers:
Tasks:
Tags:
1 parent e5193e1 commit fe38f14

File tree

4 files changed: +12 -6 lines changed

  optimum/executorch/attentions/custom_kv_cache.py
  optimum/exporters/executorch/integrations.py
  optimum/exporters/executorch/recipes/xnnpack.py
  tests/models/test_modeling_gemma3.py
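
For context on what this commit makes room for: starting with transformers 4.53, a custom attention function registered through AttentionInterface should have a mask-generation function registered under the same name through AttentionMaskInterface, so that mask creation is dispatched by the same implementation key. Below is a minimal sketch of that pairing; the kernel body and the key "my_custom_sdpa" are illustrative placeholders, not this repo's code.

    import torch
    from transformers import AttentionInterface
    from transformers.integrations.executorch import sdpa_mask_without_vmap
    from transformers.masking_utils import AttentionMaskInterface

    def my_custom_sdpa(module, query, key, value, attention_mask, scaling=None, dropout=0.0, **kwargs):
        # Illustrative kernel only: defer to plain SDPA with the pre-built mask
        # (GQA head repetition omitted for brevity).
        out = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=dropout, scale=scaling
        )
        # transformers expects (batch, seq, heads, head_dim) back from attention functions.
        return out.transpose(1, 2).contiguous(), None

    # The attention kernel and its mask builder must share a key so that setting
    # config._attn_implementation = "my_custom_sdpa" resolves both.
    AttentionInterface.register("my_custom_sdpa", my_custom_sdpa)
    AttentionMaskInterface.register("my_custom_sdpa", sdpa_mask_without_vmap)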

optimum/executorch/attentions/custom_kv_cache.py

Lines changed: 2 additions & 2 deletions
@@ -213,7 +213,7 @@ def __init__(
         for layer_idx in range(config.num_hidden_layers):
             # newer version of transfomer has is_sliding defined
             # for HybridCache
-            if self.is_sliding_list[layer_idx]:
+            if self.is_sliding[layer_idx]:
                 # This is a sliding window layer
                 layer_cache = CustomRingKVCache(
                     max_batch_size=self.max_batch_size,
@@ -388,7 +388,7 @@ def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dt
         for i in range(len(module.cache.kv_cache)):
             setattr(module, f"key_cache_{i}", module.cache.kv_cache[i].k_cache)
             setattr(module, f"value_cache_{i}", module.cache.kv_cache[i].v_cache)
-            if module.cache.is_sliding_list[i]:
+            if module.cache.is_sliding[i]:
                 # Register cache_positions as buffer for sliding window layers
                 # This prevents it from being traced as a constant
                 module.register_buffer(
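
The rename above follows newer transformers, where HybridCache exposes the per-layer flags as is_sliding rather than is_sliding_list. A self-contained toy sketch, not this repo's CustomKVCache/CustomRingKVCache classes, of why the per-layer flag matters for cache sizing:

    import torch

    def build_layer_caches(num_layers, is_sliding, max_cache_len, sliding_window,
                           batch=1, heads=8, head_dim=64):
        """Sliding-window layers only ever hold `sliding_window` KV slots; global layers
        need the full `max_cache_len`. Shapes here are illustrative assumptions."""
        caches = []
        for layer_idx in range(num_layers):
            length = sliding_window if is_sliding[layer_idx] else max_cache_len
            k = torch.zeros(batch, heads, length, head_dim)
            v = torch.zeros(batch, heads, length, head_dim)
            caches.append((k, v))
        return caches

    # e.g. a Gemma-3-like layout that interleaves sliding and global attention layers
    is_sliding = [(i % 6) != 5 for i in range(12)]
    caches = build_layer_caches(12, is_sliding, max_cache_len=4096, sliding_window=512)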

optimum/exporters/executorch/integrations.py

Lines changed: 4 additions & 1 deletion
@@ -49,10 +49,12 @@ def export(self, input_ids=None, cache_position=None) -> Dict[str, ExportedProgr
         example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long)
         example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)

-        if is_transformers_version(">=", "4.52.0.dev0"):
+        if is_transformers_version(">=", "4.53.0.dev0"):
             from transformers.integrations.executorch import (
                 TorchExportableModuleForDecoderOnlyLM,
+                sdpa_mask_without_vmap,
             )
+            from transformers.masking_utils import AttentionMaskInterface

             max_batch_size = 1
             max_cache_len = 4094
@@ -62,6 +64,7 @@ def export(self, input_ids=None, cache_position=None) -> Dict[str, ExportedProgr

             _custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache(exportable_module)
             AttentionInterface.register("custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache)
+            AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", sdpa_mask_without_vmap)
             exportable_module.model.model.config._attn_implementation = "custom_sdpa_ring_kv_cache"
             if self.use_custom_kv_cache:
                 from optimum.executorch.attentions.custom_kv_cache import (
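
For reference, a minimal standalone driver for this export() path. It assumes the TorchExportableModuleForDecoderOnlyLM API as of transformers 4.53 and uses an arbitrary small decoder-only checkpoint; both are assumptions rather than anything taken from this diff beyond the placeholder inputs and max_cache_len.

    import torch
    from transformers import AutoModelForCausalLM
    from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

    # Example checkpoint, not one used by this repo's tests.
    model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
    exportable_module = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size=1, max_cache_len=4094)

    # Same placeholder inputs the diff falls back to when the caller passes none.
    exported = exportable_module.export(
        input_ids=torch.tensor([[1]], dtype=torch.long),
        cache_position=torch.tensor([0], dtype=torch.long),
    )
    print(type(exported))  # expected: a torch.export.ExportedProgram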

optimum/exporters/executorch/recipes/xnnpack.py

Lines changed: 4 additions & 1 deletion
@@ -97,7 +97,10 @@ def _lower_to_executorch(

         exported_progs = model.export()

-        if model.config._attn_implementation == "custom_sdpa":
+        if (
+            model.config._attn_implementation == "custom_sdpa"
+            or model.config._attn_implementation == "custom_sdpa_ring_kv_cache"
+        ):
             # Sanity check to make sure the exported program contains the custom sdpa operator.
             if not any(
                 node.op == "call_function" and "custom_sdpa" in str(node.target)
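
The widened condition keeps the existing sanity check in force for the ring KV cache variant as well. Roughly, the check amounts to the sketch below, built from the context lines of this hunk; exported_progs is the Dict[str, ExportedProgram] that model.export() returns.

    def uses_custom_sdpa(exported_program) -> bool:
        # Look for a call_function node whose target is the custom SDPA operator.
        return any(
            node.op == "call_function" and "custom_sdpa" in str(node.target)
            for node in exported_program.graph_module.graph.nodes
        )

    if model.config._attn_implementation in ("custom_sdpa", "custom_sdpa_ring_kv_cache"):
        for name, prog in exported_progs.items():
            assert uses_custom_sdpa(prog), f"{name} was not exported with the custom SDPA operator"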

tests/models/test_modeling_gemma3.py

Lines changed: 2 additions & 2 deletions
@@ -218,8 +218,8 @@ def test_gemma3_text_generation_with_custom_sdpa_8da4w_8we(self):
         self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))

     @pytest.mark.skipif(
-        parse(transformers.__version__) < parse("4.52.0") or parse(torchao.__version__) < parse("0.11.0"),
-        reason="Only available on transformers >= 4.52.0 and torchao >= 0.11.0",
+        parse(transformers.__version__) < parse("4.53.0.dev0") or parse(torchao.__version__) < parse("0.11.0"),
+        reason="Only available on transformers >= 4.53.0.dev0 and torchao >= 0.11.0",
     )
     def test_gemma3_text_generation_with_custom_sdpa_kv_cache_8da4w_8we(self):
         # TODO: Until https://github.com/huggingface/optimum/issues/2127 is fixed, have to use non-gated model on CI
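
The tightened gate reuses the usual packaging-based skip pattern. A minimal sketch of it as a reusable marker, assuming parse here is packaging.version.parse:

    import pytest
    import torchao
    import transformers
    from packaging.version import parse

    requires_new_stack = pytest.mark.skipif(
        parse(transformers.__version__) < parse("4.53.0.dev0") or parse(torchao.__version__) < parse("0.11.0"),
        reason="Only available on transformers >= 4.53.0.dev0 and torchao >= 0.11.0",
    )

    @requires_new_stack
    def test_something():
        ...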
