try:
    from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
        CustomKVCache,
+        CustomRingKVCache,
    )
except ImportError:
    raise ImportError("ExecutorTorch is not installed. Please install it to use CustomKVCache.")

+try:
+    from transformers.cache_utils import HybridCache
+except ImportError:
+    # If transformers is not installed, raise an ImportError
+    raise ImportError("transformers is not installed. Please install it to use HybridCache.")
+

class ETCustomStaticCache(StaticCache):
    """
@@ -55,7 +65,7 @@ def __init__(
        assert device is None or device == "cpu", "Device must be None or 'cpu'"

        # Create a list of CustomKVCache instances, one per layer
-        kv_cache_list = []
+        self.kv_cache = torch.nn.ModuleList()
        for _ in range(config.num_hidden_layers):
            layer_cache = CustomKVCache(
                max_batch_size=self.max_batch_size,
@@ -64,8 +74,7 @@ def __init__(
                head_dim=self.head_dim,
                dtype=dtype,
            )
-            kv_cache_list.append(layer_cache)
-        self.kv_cache = torch.nn.ModuleList(kv_cache_list)
+            self.kv_cache.append(layer_cache)

    def update(
        self,
@@ -180,6 +189,135 @@ def from_legacy_cache(
        )


+# Need to figure out if I have to inherit from HybridCache or StaticCache
+class ETCustomHybridCache(HybridCache):
+    """
+    Custom Hybrid KV Cache implementation for ExecutorTorch that inherits from Hugging Face's HybridCache
+    but uses ExecutorTorch's CustomKVCache for global layers and CustomRingKVCache for sliding window layers.
+    """
+
+    def __init__(
+        self,
+        config,
+        max_batch_size: int,
+        max_cache_len: Optional[int] = None,
+        device: Union[torch.device, str, None] = None,
+        dtype: torch.dtype = torch.float32,
+        layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
+    ):
+        super().__init__(
+            config=config,
+            max_batch_size=max_batch_size,
+            max_cache_len=max_cache_len,
+            device=device,
+            dtype=dtype,
+            layer_device_map=layer_device_map,
+        )
+
+        # Make sure layer_device_map is None
+        assert layer_device_map is None
+        assert device is None or device == "cpu", "Device must be None or 'cpu'"
+
+        self.cache_position = None
+        # Create a list of cache instances, one per layer
+        # Use CustomKVCache for global layers and CustomRingKVCache for sliding window layers
+        self.kv_cache = torch.nn.ModuleList()
+        for layer_idx in range(config.num_hidden_layers):
+            # Newer versions of transformers define is_sliding
+            # for HybridCache
+            if self.is_sliding_list[layer_idx]:
+                # This is a sliding window layer
+                layer_cache = CustomRingKVCache(
+                    max_batch_size=self.max_batch_size,
+                    max_context_length=self.sliding_window_len,
+                    n_heads=self.num_key_value_heads,
+                    head_dim=self.head_dim,
+                    dtype=dtype,
+                )
+            else:
+                layer_cache = CustomKVCache(
+                    max_batch_size=self.max_batch_size,
+                    max_context_length=self.max_cache_len,
+                    n_heads=self.num_key_value_heads,
+                    head_dim=self.head_dim,
+                    dtype=dtype,
+                )
+            self.kv_cache.append(layer_cache)
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`
+        using ExecutorTorch's CustomKVCache or CustomRingKVCache depending on the layer type.
+
+        Args:
+            key_states (`torch.Tensor`):
+                The new key states to cache. Shape: [batch_size, n_heads, seq_len, head_dim]
+            value_states (`torch.Tensor`):
+                The new value states to cache. Shape: [batch_size, n_heads, seq_len, head_dim]
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache update.
+
+        Returns:
+            A tuple containing the updated key and value states.
+        """
+        assert cache_kwargs is not None
+
+        # Get cache position from cache_kwargs (used by HybridCache)
+        cache_position = cache_kwargs.get("cache_position")
+        assert cache_position is not None
+        assert isinstance(cache_position, torch.Tensor)
+        self.cache_position = cache_position
+
+        # Get the cache instance for this layer (either CustomKVCache or CustomRingKVCache)
+        layer_cache = self.kv_cache[layer_idx]
+
+        # Use the cache's update method
+        # Both CustomKVCache and CustomRingKVCache have the same update interface
+        k_out, v_out = layer_cache.update(
+            input_pos=cache_position,
+            k_val=key_states,
+            v_val=value_states,
+        )
+
+        return k_out, v_out
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        if layer_idx is None:
+            layer_idx = 0
+
+        # For CustomRingKVCache, we need to handle the sequence length differently
+        layer_cache = self.kv_cache[layer_idx]
+        if self.is_sliding[layer_idx]:
+            # CustomRingKVCache has a cache_positions_manager which maintains the
+            # cache position for each slot in the kv cache. We return the max
+            # position + 1 to indicate the max position seen so far. Not sure if
+            # that's the correct interpretation of sequence length.
+            return layer_cache.cache_positions_manager.cache_positions.max().item() + 1
+        return (layer_cache.k_cache[0, :, 0].any(dim=-1)).sum()
+
+    def get_layer_cache(self, layer_idx: int):
+        """
+        Get the cache for a specific layer. This method is dynamo-traceable.
+
+        Args:
+            layer_idx (int): The layer index
+
+        Returns:
+            The cache instance for the specified layer (CustomKVCache or CustomRingKVCache)
+        """
+        return self.kv_cache[layer_idx]
+
+
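A minimal usage sketch of the class above, assuming a hybrid-attention (sliding + global) config such as Gemma 2's; the model id, sizes, and random tensors below are illustrative only and not part of this change:

# Hedged example: exercises the update()/get_seq_length() interface documented above.
import torch
from transformers import AutoConfig

config = AutoConfig.from_pretrained("google/gemma-2-2b")  # assumed hybrid (sliding + global) model
cache = ETCustomHybridCache(
    config=config,
    max_batch_size=1,
    max_cache_len=128,
    device="cpu",
    dtype=torch.float32,
)

# One decode step: key/value states are [batch_size, n_heads, seq_len, head_dim],
# and cache_kwargs must carry the cache_position tensor, as update() asserts.
cache_position = torch.tensor([0])
for layer_idx in range(config.num_hidden_layers):
    k = torch.randn(1, config.num_key_value_heads, 1, config.head_dim)
    v = torch.randn(1, config.num_key_value_heads, 1, config.head_dim)
    k_out, v_out = cache.update(k, v, layer_idx, cache_kwargs={"cache_position": cache_position})

print(cache.get_seq_length(0))  # tokens seen so far for layer 0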

def replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype):
    """
    Replace all KV caches in the module with ETCustomStaticCache.
@@ -192,22 +330,6 @@ def replace_with_et_custom_kv_cache(module, config, generation_config, cache_dty
    Returns:
        The modified module
    """
-    # Ensure custom ops are registered
-    try:
-        op = torch.ops.llama.update_cache
-        assert op is not None
-    except Exception:
-        try:
-            from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
-
-            op = torch.ops.llama.update_cache
-            assert op is not None
-        except ImportError:
-            raise ImportError(
-                "ExecutorTorch custom operations are not available. "
-                "Please install executorch with custom operations support."
-            )
-
    # Recursively replace KV caches
    return _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype)

@@ -223,33 +345,73 @@ def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dt
    Returns:
        The modified module
    """
-    assert hasattr(module, "static_cache")
-    assert isinstance(
-        module.static_cache, StaticCache
-    ), "Only StaticCache transform is supported. Hybrid cache with local global attention is not yet supported"
-    # TODO: Add replace_cache to exported module
-    # in transformer's executorch.py
-    if getattr(module, "replace_cache", None) is not None:
-        static_cache = ETCustomStaticCache(
-            config=config,
-            max_batch_size=generation_config.cache_config.batch_size,
-            max_cache_len=generation_config.cache_config.max_cache_len,
-            device=generation_config.cache_config.device,
-            dtype=cache_dtype,
-        )
-        module.replace_cache(static_cache)
+    # Check if module has static_cache (TorchExportableModuleWithStaticCache)
+    if hasattr(module, "static_cache"):
+        assert isinstance(module.static_cache, StaticCache), f"Expected StaticCache, got {type(module.static_cache)}"
+
+        # TODO: Add replace_cache to exported module
+        # in transformer's executorch.py
+        if getattr(module, "replace_cache", None) is not None:
+            static_cache = ETCustomStaticCache(
+                config=config,
+                max_batch_size=generation_config.cache_config.batch_size,
+                max_cache_len=generation_config.cache_config.max_cache_len,
+                device=generation_config.cache_config.device,
+                dtype=cache_dtype,
+            )
+            module.replace_cache(static_cache)
+        else:
+            module.static_cache = ETCustomStaticCache(
+                config=config,
+                max_batch_size=generation_config.cache_config.batch_size,
+                max_cache_len=generation_config.cache_config.max_cache_len,
+                device=generation_config.cache_config.device,
+                dtype=cache_dtype,
+            )
+            # Don't know why we need to do this even though
+            # CustomKVCache registers the attributes
+            for i in range(len(module.static_cache.kv_cache)):
+                setattr(module, f"key_cache_{i}", module.static_cache.kv_cache[i].k_cache)
+                setattr(module, f"value_cache_{i}", module.static_cache.kv_cache[i].v_cache)
+
+    # Check if module has cache (TorchExportableModuleWithHybridCache)
+    elif hasattr(module, "cache"):
+        assert isinstance(module.cache, HybridCache), f"Expected HybridCache, got {type(module.cache)}"
+
+        # Replace with ETCustomHybridCache
+        if getattr(module, "replace_cache", None) is not None:
+            hybrid_cache = ETCustomHybridCache(
+                config=config,
+                max_batch_size=generation_config.cache_config.batch_size,
+                max_cache_len=generation_config.cache_config.max_cache_len,
+                device=generation_config.cache_config.device,
+                dtype=cache_dtype,
+            )
+            module.replace_cache(hybrid_cache)
+        else:
+            module.cache = ETCustomHybridCache(
+                config=config,
+                max_batch_size=generation_config.cache_config.batch_size,
+                max_cache_len=generation_config.cache_config.max_cache_len,
+                device=generation_config.cache_config.device,
+                dtype=cache_dtype,
+            )
+            # Register cache attributes for each layer
+            for i in range(len(module.cache.kv_cache)):
+                setattr(module, f"key_cache_{i}", module.cache.kv_cache[i].k_cache)
+                setattr(module, f"value_cache_{i}", module.cache.kv_cache[i].v_cache)
+                if module.cache.is_sliding_list[i]:
+                    # Register cache_positions as buffer for sliding window layers
+                    # This prevents it from being traced as a constant
+                    module.register_buffer(
+                        f"cache_positions_{i}",
+                        module.cache.kv_cache[i].cache_positions_manager.cache_positions,
+                        persistent=False,
+                    )
    else:
-        module.static_cache = ETCustomStaticCache(
-            config=config,
-            max_batch_size=generation_config.cache_config.batch_size,
-            max_cache_len=generation_config.cache_config.max_cache_len,
-            device=generation_config.cache_config.device,
-            dtype=cache_dtype,
+        raise ValueError(
+            "Module must have either 'static_cache' (TorchExportableModuleWithStaticCache) "
+            "or 'cache' (TorchExportableModuleWithHybridCache) attribute"
        )
-        # Dont know why we need to this even though
-        # CustomKVCache registers the attributes
-        for i in range(len(module.static_cache.kv_cache)):
-            setattr(module, f"key_cache_{i}", module.static_cache.kv_cache[i].k_cache)
-            setattr(module, f"value_cache_{i}", module.static_cache.kv_cache[i].v_cache)

    return module
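For orientation, a hedged sketch of how this entry point might be driven; `build_exportable_module` is a hypothetical placeholder for whatever constructs the TorchExportableModuleWithStaticCache / TorchExportableModuleWithHybridCache wrapper, and the `cache_config` fields simply mirror the attributes read above (`batch_size`, `max_cache_len`, `device`):

# Hedged example: assumes `module` is a wrapper exposing either a `static_cache` or a
# `cache` attribute, as checked by _replace_with_et_custom_kv_cache above.
import torch
from transformers import AutoModelForCausalLM, GenerationConfig

model = AutoModelForCausalLM.from_pretrained("my-model")  # placeholder model id
generation_config = GenerationConfig(
    use_cache=True,
    cache_implementation="static",
    # Assumed to be converted by GenerationConfig into an object with attribute access,
    # since the code above reads generation_config.cache_config.batch_size etc.
    cache_config={"batch_size": 1, "max_cache_len": 128, "device": "cpu"},
)

module = build_exportable_module(model, generation_config)  # hypothetical helper, not defined in this file
module = replace_with_et_custom_kv_cache(module, model.config, generation_config, cache_dtype=torch.float32)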