except ImportError:
    raise ImportError("transformers is not installed. Please install it to use StaticCache.")

+try:
+    from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
+        CustomKVCache,
+    )
+except ImportError:
+    raise ImportError("ExecutorTorch is not installed. Please install it to use CustomKVCache.")
+

class ETCustomStaticCache(StaticCache):
    """
@@ -45,26 +52,20 @@ def __init__(
        # make sure layer_device_map is none
        assert layer_device_map is None
-
-        # Clear existing caches
-        self.key_cache = []
-        self.value_cache = []
-
-        # Initialize cache buffers with our custom shape
-        cache_shape = (
-            self.max_batch_size,
-            self.max_cache_len,
-            self.num_key_value_heads,
-            self.head_dim,
-        )
        assert device is None or device == "cpu", "Device must be None or 'cpu'"

+        # Create a list of CustomKVCache instances, one per layer
+        kv_cache_list = []
        for _ in range(config.num_hidden_layers):
-            self.new_layer_key_cache = torch.zeros(cache_shape, dtype=dtype, device="cpu")
-            self.new_layer_value_cache = torch.zeros(cache_shape, dtype=dtype, device="cpu")
-
-            self.key_cache.append(self.new_layer_key_cache)
-            self.value_cache.append(self.new_layer_value_cache)
+            layer_cache = CustomKVCache(
+                max_batch_size=self.max_batch_size,
+                max_context_length=self.max_cache_len,
+                n_heads=self.num_key_value_heads,
+                head_dim=self.head_dim,
+                dtype=dtype,
+            )
+            kv_cache_list.append(layer_cache)
+        self.kv_cache = torch.nn.ModuleList(kv_cache_list)

    def update(
        self,
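For orientation, a minimal sketch of the per-layer layout produced by the rewritten __init__ (illustrative only, not part of this diff; the constructor keyword names and the model id are assumptions that mirror StaticCache):

import torch
from transformers import AutoConfig

# Illustrative only; constructor arguments are assumed to mirror StaticCache's signature.
config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-1B")  # any decoder config works
cache = ETCustomStaticCache(
    config=config,
    max_batch_size=1,
    max_cache_len=128,
    device="cpu",
    dtype=torch.float32,
)
layer0 = cache.kv_cache[0]       # one CustomKVCache module per hidden layer
# seq_len sits in dim 1 here (cf. the removed cache_shape), not dim 2 as in StaticCache
print(layer0.k_cache.shape)      # expected: (1, 128, num_key_value_heads, head_dim)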
@@ -75,7 +76,7 @@ def update(
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`
-        using custom operations.
+        using ExecutorTorch's CustomKVCache.

        Args:
            key_states (`torch.Tensor`):
@@ -95,32 +96,28 @@ def update(
        # Get cache position from cache_kwargs (used by StaticCache)
        cache_position = cache_kwargs.get("cache_position")
        assert cache_position is not None
+        assert isinstance(cache_position, torch.Tensor)

-        # Get the current cache for this layer
-        k_out = self.key_cache[layer_idx]
-        v_out = self.value_cache[layer_idx]
-
-        # Transpose key and value states to match our cache shape
-        # From [batch_size, n_heads, seq_len, head_dim] to [batch_size, seq_len, n_heads, head_dim]
-        k_val = key_states.transpose(1, 2)
-        v_val = value_states.transpose(1, 2)
+        # Get the CustomKVCache instance for this layer
+        layer_cache = self.kv_cache[layer_idx]

-        # Use custom operations to update the cache
-        # Update cache with indices for more complex update patterns
-        assert isinstance(cache_position, torch.Tensor)
-        start_pos = cache_position[0].item()
-        _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
-        _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
+        # Use the CustomKVCache's update method
+        # CustomKVCache expects input_pos, k_val, v_val and handles the transpose internally
+        k_out, v_out = layer_cache.update(
+            input_pos=cache_position,
+            k_val=key_states,
+            v_val=value_states,
+        )

-        # Return the updated cache in the format expected by the model
-        # Transpose back from [batch_size, seq_len, n_heads, head_dim] to [batch_size, n_heads, seq_len, head_dim]
-        return k_out.transpose(1, 2), v_out.transpose(1, 2)
+        return k_out, v_out

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # Occupied cache == any slot in the 2nd dim (sequence length) holds a non-zero value
        # This is different from StaticCache which checks the 3rd dim
-        return (self.key_cache[layer_idx][0, :, 0].any(dim=-1)).sum()
+        if layer_idx is None:
+            layer_idx = 0
+        return (self.kv_cache[layer_idx].k_cache[0, :, 0].any(dim=-1)).sum()

    @classmethod
    def from_legacy_cache(
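Continuing that sketch (again illustrative, not part of the diff), the rewritten update() is driven exactly like StaticCache.update(), with cache_position passed via cache_kwargs and key/value states in the usual [batch, n_heads, seq_len, head_dim] layout:

# Illustrative only: fake projected states; 4 new token positions written at once.
key_states = torch.randn(1, cache.num_key_value_heads, 4, cache.head_dim)
value_states = torch.randn(1, cache.num_key_value_heads, 4, cache.head_dim)
k_out, v_out = cache.update(
    key_states,
    value_states,
    layer_idx=0,
    cache_kwargs={"cache_position": torch.arange(4)},
)
# get_seq_length() counts occupied positions by scanning layer 0's k_cache
print(cache.get_seq_length())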
@@ -249,8 +246,10 @@ def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dt
        device=generation_config.cache_config.device,
        dtype=cache_dtype,
    )
-    for i in range(len(module.static_cache.key_cache)):
-        setattr(module, f"key_cache_{i}", module.static_cache.key_cache[i])
-        setattr(module, f"value_cache_{i}", module.static_cache.value_cache[i])
+    # Don't know why we need to do this even though
+    # CustomKVCache registers the attributes
+    for i in range(len(module.static_cache.kv_cache)):
+        setattr(module, f"key_cache_{i}", module.static_cache.kv_cache[i].k_cache)
+        setattr(module, f"value_cache_{i}", module.static_cache.kv_cache[i].v_cache)

    return module
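A small illustrative check of the aliasing the setattr loop is meant to guarantee (`decoder` is a placeholder module, not a name from this diff):

# Illustrative only: `decoder` stands in for whatever module owns `static_cache`.
decoder = _replace_with_et_custom_kv_cache(decoder, config, generation_config, torch.float32)
assert decoder.key_cache_0 is decoder.static_cache.kv_cache[0].k_cache
assert decoder.value_cache_0 is decoder.static_cache.kv_cache[0].v_cache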