import torch


+ # If transformers is not installed, raise an ImportError
try:
-     from transformers.cache_utils import StaticCache
+     from transformers.cache_utils import HybridCache, StaticCache
except ImportError:
-     # If transformers is not installed, raise an ImportError
-     try:
-         from transformers.cache_utils import StaticCache
-     except ImportError:
-         raise ImportError("transformers is not installed. Please install it to use StaticCache.")
+     raise ImportError("transformers is not installed. Please install it to use Static/HybridCache.")


try:
    from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
        CustomKVCache,
+         CustomRingKVCache,
    )
except ImportError:
-     raise ImportError("ExecutorTorch is not installed. Please install it to use CustomKVCache.")
+     raise ImportError("ExecutorTorch is not installed. Please install it to use Custom Cache.")


class ETCustomStaticCache(StaticCache):
@@ -55,7 +53,7 @@ def __init__(
        assert device is None or device == "cpu", "Device must be None or 'cpu'"

        # Create a list of CustomKVCache instances, one per layer
-         kv_cache_list = []
+         self.kv_cache = torch.nn.ModuleList()
        for _ in range(config.num_hidden_layers):
            layer_cache = CustomKVCache(
                max_batch_size=self.max_batch_size,
@@ -64,8 +62,7 @@ def __init__(
                head_dim=self.head_dim,
                dtype=dtype,
            )
-             kv_cache_list.append(layer_cache)
-         self.kv_cache = torch.nn.ModuleList(kv_cache_list)
+             self.kv_cache.append(layer_cache)

    def update(
        self,
@@ -180,6 +177,135 @@ def from_legacy_cache(
        )


+
+ # Need to figure out if I have to inherit from HybridCache or StaticCache
+ class ETCustomHybridCache(HybridCache):
+     """
+     Custom Hybrid KV Cache implementation for ExecutorTorch that inherits from Hugging Face's HybridCache
+     but uses ExecutorTorch's CustomKVCache for global layers and CustomRingKVCache for sliding window layers.
+     """
+
+     def __init__(
+         self,
+         config,
+         max_batch_size: int,
+         max_cache_len: Optional[int] = None,
+         device: Union[torch.device, str, None] = None,
+         dtype: torch.dtype = torch.float32,
+         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
+     ):
+         super().__init__(
+             config=config,
+             max_batch_size=max_batch_size,
+             max_cache_len=max_cache_len,
+             device=device,
+             dtype=dtype,
+             layer_device_map=layer_device_map,
+         )
+
+         # Make sure layer_device_map is None
+         assert layer_device_map is None
+         assert device is None or device == "cpu", "Device must be None or 'cpu'"
+
+         self.cache_position = None
+         # Create a list of cache instances, one per layer.
+         # Use CustomKVCache for global layers and CustomRingKVCache for sliding window layers.
+         self.kv_cache = torch.nn.ModuleList()
+         for layer_idx in range(config.num_hidden_layers):
+             # Newer versions of transformers define is_sliding for HybridCache
+             if self.is_sliding[layer_idx]:
+                 # This is a sliding window layer
+                 layer_cache = CustomRingKVCache(
+                     max_batch_size=self.max_batch_size,
+                     max_context_length=self.sliding_window_len,
+                     n_heads=self.num_key_value_heads,
+                     head_dim=self.head_dim,
+                     dtype=dtype,
+                 )
+             else:
+                 layer_cache = CustomKVCache(
+                     max_batch_size=self.max_batch_size,
+                     max_context_length=self.max_cache_len,
+                     n_heads=self.num_key_value_heads,
+                     head_dim=self.head_dim,
+                     dtype=dtype,
+                 )
+             self.kv_cache.append(layer_cache)
+
+     def update(
+         self,
+         key_states: torch.Tensor,
+         value_states: torch.Tensor,
+         layer_idx: int,
+         cache_kwargs: Optional[Dict[str, Any]] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`,
+         using ExecutorTorch's CustomKVCache or CustomRingKVCache depending on the layer type.
+
+         Args:
+             key_states (`torch.Tensor`):
+                 The new key states to cache. Shape: [batch_size, n_heads, seq_len, head_dim]
+             value_states (`torch.Tensor`):
+                 The new value states to cache. Shape: [batch_size, n_heads, seq_len, head_dim]
+             layer_idx (`int`):
+                 The index of the layer to cache the states for.
+             cache_kwargs (`Dict[str, Any]`, `optional`):
+                 Additional arguments for the cache update.
+
+         Returns:
+             A tuple containing the updated key and value states.
+         """
+         assert cache_kwargs is not None
+
+         # Get cache position from cache_kwargs (used by HybridCache)
+         cache_position = cache_kwargs.get("cache_position")
+         assert cache_position is not None
+         assert isinstance(cache_position, torch.Tensor)
+         self.cache_position = cache_position
+
+         # Get the cache instance for this layer (either CustomKVCache or CustomRingKVCache)
+         layer_cache = self.kv_cache[layer_idx]
+
+         # Use the cache's update method; both CustomKVCache and CustomRingKVCache
+         # expose the same update interface.
+         k_out, v_out = layer_cache.update(
+             input_pos=cache_position,
+             k_val=key_states,
+             v_val=value_states,
+         )
+
+         return k_out, v_out
+
+     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+         if layer_idx is None:
+             layer_idx = 0
+
+         # For CustomRingKVCache, we need to handle the sequence length differently
+         layer_cache = self.kv_cache[layer_idx]
+         if self.is_sliding[layer_idx]:
+             # CustomRingKVCache has a cache_positions_manager that maintains the
+             # cache position for each slot in the kv cache. We return the max
+             # position + 1 to indicate the max position seen so far. Not sure if
+             # that's the correct interpretation of sequence length.
+             return layer_cache.cache_positions_manager.cache_positions.max().item() + 1
+         return (layer_cache.k_cache[0, :, 0].any(dim=-1)).sum()
+
+     def get_layer_cache(self, layer_idx: int):
+         """
+         Get the cache for a specific layer. This method is dynamo-traceable.
+
+         Args:
+             layer_idx (int): The layer index
+
+         Returns:
+             The cache instance for the specified layer (CustomKVCache or CustomRingKVCache)
+         """
+         return self.kv_cache[layer_idx]
+
+
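For intuition about the per-layer dispatch in ETCustomHybridCache.update above, here is a minimal, self-contained sketch of the same flow with a toy stand-in cache. ToyLayerCache, toy_update, and all shapes are invented for illustration; only the keyword names (cache_position, input_pos, k_val, v_val) and the one-cache-per-layer ModuleList structure mirror the diff, and this is not how CustomKVCache is implemented internally.

import torch
from typing import Any, Dict, Optional


class ToyLayerCache(torch.nn.Module):
    # Invented stand-in exposing the same update(input_pos, k_val, v_val) interface
    # that ETCustomHybridCache.update calls on its per-layer caches above.
    def __init__(self, max_len: int = 8, n_heads: int = 2, head_dim: int = 4):
        super().__init__()
        self.register_buffer("k_cache", torch.zeros(1, n_heads, max_len, head_dim))
        self.register_buffer("v_cache", torch.zeros(1, n_heads, max_len, head_dim))

    def update(self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor):
        # Write the new entries at the given positions and return the full buffers.
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache


# One cache per layer, registered via nn.ModuleList as in the diff.
layers = torch.nn.ModuleList([ToyLayerCache() for _ in range(2)])


def toy_update(layer_idx: int, key_states, value_states, cache_kwargs: Optional[Dict[str, Any]] = None):
    cache_position = cache_kwargs.get("cache_position")  # same contract as in the diff
    return layers[layer_idx].update(input_pos=cache_position, k_val=key_states, v_val=value_states)


k = torch.randn(1, 2, 3, 4)  # toy [batch, n_heads, seq_len, head_dim]
k_out, v_out = toy_update(0, k, k, {"cache_position": torch.arange(3)})
print(k_out.shape)  # torch.Size([1, 2, 8, 4])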
def replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype):
    """
    Replace all KV caches in the module with ETCustomStaticCache.
@@ -192,22 +318,6 @@ def replace_with_et_custom_kv_cache(module, config, generation_config, cache_dty
    Returns:
        The modified module
    """
-     # Ensure custom ops are registered
-     try:
-         op = torch.ops.llama.update_cache
-         assert op is not None
-     except Exception:
-         try:
-             from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
-
-             op = torch.ops.llama.update_cache
-             assert op is not None
-         except ImportError:
-             raise ImportError(
-                 "ExecutorTorch custom operations are not available. "
-                 "Please install executorch with custom operations support."
-             )
-
    # Recursively replace KV caches
    return _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype)
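As a rough, hypothetical call-site sketch (not part of this diff): install_et_caches and the SimpleNamespace config below are invented stand-ins. It assumes exported_module comes from an export wrapper such as transformers' TorchExportableModuleWithStaticCache or TorchExportableModuleWithHybridCache (the classes named in the comments in this diff), and that transformers and executorch are installed; only the cache_config fields actually read by the function are shown.

from types import SimpleNamespace

import torch

# Only the three cache_config fields read by the replacement helpers
# (batch_size, max_cache_len, device) are provided here.
generation_config = SimpleNamespace(
    cache_config=SimpleNamespace(batch_size=1, max_cache_len=1024, device="cpu")
)


def install_et_caches(exported_module, model_config):
    # `exported_module` is assumed to expose either a `static_cache` (StaticCache)
    # or a `cache` (HybridCache) attribute; `model_config` is the Hugging Face
    # model config carrying num_hidden_layers, head counts, etc.
    return replace_with_et_custom_kv_cache(
        exported_module, model_config, generation_config, cache_dtype=torch.float32
    )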
@@ -223,33 +333,73 @@ def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dt
    Returns:
        The modified module
    """
-     assert hasattr(module, "static_cache")
-     assert isinstance(
-         module.static_cache, StaticCache
-     ), "Only StaticCache transform is supported. Hybrid cache with local global attention is not yet supported"
-     # TODO: Add replace_cache to exported module
-     # in transformer's executorch.py
-     if getattr(module, "replace_cache", None) is not None:
-         static_cache = ETCustomStaticCache(
-             config=config,
-             max_batch_size=generation_config.cache_config.batch_size,
-             max_cache_len=generation_config.cache_config.max_cache_len,
-             device=generation_config.cache_config.device,
-             dtype=cache_dtype,
-         )
-         module.replace_cache(static_cache)
+     # Check if module has static_cache (TorchExportableModuleWithStaticCache)
+     if hasattr(module, "static_cache"):
+         assert isinstance(module.static_cache, StaticCache), f"Expected StaticCache, got {type(module.static_cache)}"
+
+         # TODO: Add replace_cache to exported module
+         # in transformer's executorch.py
+         if getattr(module, "replace_cache", None) is not None:
+             static_cache = ETCustomStaticCache(
+                 config=config,
+                 max_batch_size=generation_config.cache_config.batch_size,
+                 max_cache_len=generation_config.cache_config.max_cache_len,
+                 device=generation_config.cache_config.device,
+                 dtype=cache_dtype,
+             )
+             module.replace_cache(static_cache)
+         else:
+             module.static_cache = ETCustomStaticCache(
+                 config=config,
+                 max_batch_size=generation_config.cache_config.batch_size,
+                 max_cache_len=generation_config.cache_config.max_cache_len,
+                 device=generation_config.cache_config.device,
+                 dtype=cache_dtype,
+             )
+             # Don't know why we need to do this even though
+             # CustomKVCache registers the attributes
+             for i in range(len(module.static_cache.kv_cache)):
+                 setattr(module, f"key_cache_{i}", module.static_cache.kv_cache[i].k_cache)
+                 setattr(module, f"value_cache_{i}", module.static_cache.kv_cache[i].v_cache)
+
+     # Check if module has cache (TorchExportableModuleWithHybridCache)
+     elif hasattr(module, "cache"):
+         assert isinstance(module.cache, HybridCache), f"Expected HybridCache, got {type(module.cache)}"
+
+         # Replace with ETCustomHybridCache
+         if getattr(module, "replace_cache", None) is not None:
+             hybrid_cache = ETCustomHybridCache(
+                 config=config,
+                 max_batch_size=generation_config.cache_config.batch_size,
+                 max_cache_len=generation_config.cache_config.max_cache_len,
+                 device=generation_config.cache_config.device,
+                 dtype=cache_dtype,
+             )
+             module.replace_cache(hybrid_cache)
+         else:
+             module.cache = ETCustomHybridCache(
+                 config=config,
+                 max_batch_size=generation_config.cache_config.batch_size,
+                 max_cache_len=generation_config.cache_config.max_cache_len,
+                 device=generation_config.cache_config.device,
+                 dtype=cache_dtype,
+             )
+             # Register cache attributes for each layer
+             for i in range(len(module.cache.kv_cache)):
+                 setattr(module, f"key_cache_{i}", module.cache.kv_cache[i].k_cache)
+                 setattr(module, f"value_cache_{i}", module.cache.kv_cache[i].v_cache)
+                 if module.cache.is_sliding[i]:
+                     # Register cache_positions as a buffer for sliding window layers.
+                     # This prevents it from being traced as a constant.
+                     module.register_buffer(
+                         f"cache_positions_{i}",
+                         module.cache.kv_cache[i].cache_positions_manager.cache_positions,
+                         persistent=False,
+                     )
    else:
-         module.static_cache = ETCustomStaticCache(
-             config=config,
-             max_batch_size=generation_config.cache_config.batch_size,
-             max_cache_len=generation_config.cache_config.max_cache_len,
-             device=generation_config.cache_config.device,
-             dtype=cache_dtype,
+         raise ValueError(
+             "Module must have either 'static_cache' (TorchExportableModuleWithStaticCache) "
+             "or 'cache' (TorchExportableModuleWithHybridCache) attribute"
        )
-         # Dont know why we need to this even though
-         # CustomKVCache registers the attributes
-         for i in range(len(module.static_cache.kv_cache)):
-             setattr(module, f"key_cache_{i}", module.static_cache.kv_cache[i].k_cache)
-             setattr(module, f"value_cache_{i}", module.static_cache.kv_cache[i].v_cache)

    return module
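Finally, a small self-contained illustration of the counting trick that ETCustomHybridCache.get_seq_length uses for non-sliding layers. The [batch, seq_len, n_heads, head_dim] cache layout here is an assumption inferred from the indexing in the diff, not something the diff states.

import torch

max_seq_len, n_heads, head_dim = 8, 2, 4
k_cache = torch.zeros(1, max_seq_len, n_heads, head_dim)  # assumed [batch, seq, heads, head_dim] layout
k_cache[0, :3] = torch.randn(3, n_heads, head_dim)        # pretend 3 positions have been written

# k_cache[0, :, 0] is the [seq_len, head_dim] slice for head 0; a position counts
# as "filled" if any element in that slice is non-zero (this can undercount if a
# written key happens to be exactly all zeros).
filled = (k_cache[0, :, 0].any(dim=-1)).sum()
print(int(filled))  # 3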