@@ -38,13 +38,33 @@ class CausalLMExportableModule(torch.nn.Module):
     This module ensures that the exported model is compatible with ExecuTorch.
     """
 
-    def __init__(self, model, use_custom_kv_cache=False):
+    def __init__(self, model, use_custom_kv_cache=False, use_custom_sdpa=False):
         super().__init__()
         self.model = model
         self.config = model.config
         self.use_custom_kv_cache = use_custom_kv_cache
+        self.use_custom_sdpa = use_custom_sdpa
         self.metadata = save_config_to_constant_methods(model.config, model.generation_config)
 
+    def _register_attention_mask_for_4_53(self, exportable_module: torch.nn.Module):
+        if is_transformers_version(">=", "4.53.0.dev0"):
+            from transformers.integrations.executorch import sdpa_mask_without_vmap
+            from transformers.masking_utils import AttentionMaskInterface
+            from transformers.modeling_utils import AttentionInterface
+
+            _custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache(exportable_module)
+            if self.use_custom_sdpa:
+                if self.use_custom_kv_cache:
+                    AttentionInterface.register("custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache)
+                    AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", sdpa_mask_without_vmap)
+                    # Manually set the attention implementation to custom_sdpa_ring_kv_cache.
+                    # This handles both regular sdpa and the variant for sliding-window/local attention.
+                    exportable_module.model.model.config._attn_implementation = "custom_sdpa_ring_kv_cache"
+                else:
+                    # Manually set the attention implementation to custom_sdpa
+                    # (custom SDPA kernel without the ring KV cache).
+                    exportable_module.model.model.config._attn_implementation = "custom_sdpa"
+
     def export(self, input_ids=None, cache_position=None) -> Dict[str, ExportedProgram]:
         example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long)
         example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)
@@ -57,12 +77,7 @@ def export(self, input_ids=None, cache_position=None) -> Dict[str, ExportedProgram]:
         max_batch_size = 1
         max_cache_len = 4094
         exportable_module = TorchExportableModuleForDecoderOnlyLM(self.model, max_batch_size, max_cache_len)
-
-        from transformers.modeling_utils import AttentionInterface
-
-        _custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache(exportable_module)
-        AttentionInterface.register("custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache)
-        exportable_module.model.model.config._attn_implementation = "custom_sdpa_ring_kv_cache"
+        self._register_attention_mask_for_4_53(exportable_module)
         if self.use_custom_kv_cache:
             from optimum.executorch.attentions.custom_kv_cache import (
                 replace_with_et_custom_kv_cache,
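
For reference, a minimal usage sketch of the new flag as it appears in this diff; the checkpoint name below is only an illustrative assumption, while the constructor arguments and export() return type follow the signatures shown above.

# Hypothetical usage sketch, not part of the diff itself.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")  # assumed example checkpoint
exportable = CausalLMExportableModule(model, use_custom_kv_cache=True, use_custom_sdpa=True)
programs = exportable.export()  # returns a Dict[str, ExportedProgram]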