Commit 2b8b948

Add support for sliding window via ring kv cache
Summary: This diff adds:
- A ring KV cache for local attention, where the ring buffer writes new entries into the cache in a ring fashion rather than sliding the cache contents.
- A corresponding SDPA update so it accepts a mask that reflects the state of the ring buffer. Because a slot in the ring buffer may hold a different token position than its index suggests, a plain causal mask can no longer be used.

With this change, the custom_kv_cache option supports both global attention and local/global attention models like Gemma 3.

Test Plan: Test added

Reviewers:
Subscribers:
Tasks:
Tags:
1 parent e1d9cd1 commit 2b8b948
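To make the idea in the commit message concrete, here is a minimal, self-contained sketch of a ring-buffer KV cache; the ToyRingKVCache name and its layout are invented for illustration and are not the ExecuTorch CustomRingKVCache. It shows how new entries overwrite slots in a ring fashion and why the attention mask has to be derived from the token position stored in each slot rather than from a plain causal mask.

import torch


class ToyRingKVCache:
    # Illustrative ring-buffer KV cache for a sliding window of size `window`.
    def __init__(self, window: int, n_heads: int, head_dim: int, dtype=torch.float32):
        self.window = window
        self.k = torch.zeros(1, n_heads, window, head_dim, dtype=dtype)
        self.v = torch.zeros(1, n_heads, window, head_dim, dtype=dtype)
        # Token position currently stored in each slot; -1 means the slot is empty.
        self.slot_positions = torch.full((window,), -1, dtype=torch.long)

    def update(self, input_pos, k_val, v_val):
        # New entries overwrite slots in a ring fashion instead of sliding the cache.
        slots = input_pos % self.window
        self.k[:, :, slots] = k_val
        self.v[:, :, slots] = v_val
        self.slot_positions[slots] = input_pos
        return self.k, self.v

    def mask_for(self, query_pos):
        # Slot order no longer matches token order, so causality and the sliding
        # window must be checked against the position held by each slot.
        valid = self.slot_positions >= 0
        causal = self.slot_positions.unsqueeze(0) <= query_pos.unsqueeze(1)
        in_window = self.slot_positions.unsqueeze(0) > (query_pos.unsqueeze(1) - self.window)
        return valid.unsqueeze(0) & causal & in_window  # [num_queries, window]


# With window=4, token position 5 lands in slot 1, so the slots end up holding
# positions [4, 5, 2, 3]: the mask must come from slot_positions, not slot order.
cache = ToyRingKVCache(window=4, n_heads=1, head_dim=8)
for pos in range(6):
    kv = torch.randn(1, 1, 1, 8)
    cache.update(torch.tensor([pos]), kv, kv)
print(cache.slot_positions)            # tensor([4, 5, 2, 3])
print(cache.mask_for(torch.tensor([5])))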

File tree

4 files changed: +336 -55 lines changed


optimum/executorch/attentions/custom_kv_cache.py

Lines changed: 207 additions & 45 deletions
@@ -21,10 +21,20 @@
 try:
     from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
         CustomKVCache,
+        CustomRingKVCache,
     )
 except ImportError:
     raise ImportError("ExecutorTorch is not installed. Please install it to use CustomKVCache.")

+try:
+    from transformers.cache_utils import HybridCache
+except ImportError:
+    # If transformers is not installed, raise an ImportError
+    try:
+        from transformers.cache_utils import HybridCache
+    except ImportError:
+        raise ImportError("transformers is not installed. Please install it to use HybridCache.")
+

 class ETCustomStaticCache(StaticCache):
     """
@@ -55,7 +65,7 @@ def __init__(
         assert device is None or device == "cpu", "Device must be None or 'cpu'"

         # Create a list of CustomKVCache instances, one per layer
-        kv_cache_list = []
+        self.kv_cache = torch.nn.ModuleList()
         for _ in range(config.num_hidden_layers):
             layer_cache = CustomKVCache(
                 max_batch_size=self.max_batch_size,
@@ -64,8 +74,7 @@ def __init__(
                 head_dim=self.head_dim,
                 dtype=dtype,
             )
-            kv_cache_list.append(layer_cache)
-        self.kv_cache = torch.nn.ModuleList(kv_cache_list)
+            self.kv_cache.append(layer_cache)

     def update(
         self,
@@ -180,6 +189,135 @@ def from_legacy_cache(
         )


+# Need to figure out if I have to inherit from HybridCache or StaticCache
+class ETCustomHybridCache(HybridCache):
+    """
+    Custom Hybrid KV Cache implementation for ExecutorTorch that inherits from Hugging Face's HybridCache
+    but uses ExecutorTorch's CustomKVCache for global layers and CustomRingKVCache for sliding window layers.
+    """
+
+    def __init__(
+        self,
+        config,
+        max_batch_size: int,
+        max_cache_len: Optional[int] = None,
+        device: Union[torch.device, str, None] = None,
+        dtype: torch.dtype = torch.float32,
+        layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
+    ):
+        super().__init__(
+            config=config,
+            max_batch_size=max_batch_size,
+            max_cache_len=max_cache_len,
+            device=device,
+            dtype=dtype,
+            layer_device_map=layer_device_map,
+        )
+
+        # Make sure layer_device_map is None
+        assert layer_device_map is None
+        assert device is None or device == "cpu", "Device must be None or 'cpu'"
+
+        self.cache_position = None
+        # Create a list of cache instances, one per layer.
+        # Use CustomKVCache for global layers and CustomRingKVCache for sliding window layers.
+        self.kv_cache = torch.nn.ModuleList()
+        for layer_idx in range(config.num_hidden_layers):
+            # Newer versions of transformers define is_sliding
+            # for HybridCache
+            if self.is_sliding_list[layer_idx]:
+                # This is a sliding window layer
+                layer_cache = CustomRingKVCache(
+                    max_batch_size=self.max_batch_size,
+                    max_context_length=self.sliding_window_len,
+                    n_heads=self.num_key_value_heads,
+                    head_dim=self.head_dim,
+                    dtype=dtype,
+                )
+            else:
+                layer_cache = CustomKVCache(
+                    max_batch_size=self.max_batch_size,
+                    max_context_length=self.max_cache_len,
+                    n_heads=self.num_key_value_heads,
+                    head_dim=self.head_dim,
+                    dtype=dtype,
+                )
+            self.kv_cache.append(layer_cache)
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`,
+        using ExecutorTorch's CustomKVCache or CustomRingKVCache depending on the layer type.
+
+        Args:
+            key_states (`torch.Tensor`):
+                The new key states to cache. Shape: [batch_size, n_heads, seq_len, head_dim]
+            value_states (`torch.Tensor`):
+                The new value states to cache. Shape: [batch_size, n_heads, seq_len, head_dim]
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache update.
+
+        Returns:
+            A tuple containing the updated key and value states.
+        """
+        assert cache_kwargs is not None
+
+        # Get cache position from cache_kwargs (used by HybridCache)
+        cache_position = cache_kwargs.get("cache_position")
+        assert cache_position is not None
+        assert isinstance(cache_position, torch.Tensor)
+        self.cache_position = cache_position
+
+        # Get the cache instance for this layer (either CustomKVCache or CustomRingKVCache)
+        layer_cache = self.kv_cache[layer_idx]
+
+        # Use the cache's update method.
+        # Both CustomKVCache and CustomRingKVCache have the same update interface.
+        k_out, v_out = layer_cache.update(
+            input_pos=cache_position,
+            k_val=key_states,
+            v_val=value_states,
+        )
+
+        return k_out, v_out
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        if layer_idx is None:
+            layer_idx = 0
+
+        # For CustomRingKVCache, we need to handle the sequence length differently
+        layer_cache = self.kv_cache[layer_idx]
+        if self.is_sliding[layer_idx]:
+            # CustomRingKVCache has a cache_positions_manager, which
+            # maintains the cache position for each slot in the kv cache.
+            # We return the max position + 1 to indicate the max position
+            # seen so far. Not sure if that is the correct interpretation
+            # of sequence length.
+            return layer_cache.cache_positions_manager.cache_positions.max().item() + 1
+        return (layer_cache.k_cache[0, :, 0].any(dim=-1)).sum()
+
+    def get_layer_cache(self, layer_idx: int):
+        """
+        Get the cache for a specific layer. This method is dynamo-traceable.
+
+        Args:
+            layer_idx (int): The layer index
+
+        Returns:
+            The cache instance for the specified layer (CustomKVCache or CustomRingKVCache)
+        """
+        return self.kv_cache[layer_idx]
+
+
 def replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype):
     """
     Replace all KV caches in the module with ETCustomStaticCache.
@@ -192,22 +330,6 @@ def replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype):
     Returns:
         The modified module
     """
-    # Ensure custom ops are registered
-    try:
-        op = torch.ops.llama.update_cache
-        assert op is not None
-    except Exception:
-        try:
-            from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
-
-            op = torch.ops.llama.update_cache
-            assert op is not None
-        except ImportError:
-            raise ImportError(
-                "ExecutorTorch custom operations are not available. "
-                "Please install executorch with custom operations support."
-            )
-
     # Recursively replace KV caches
     return _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype)

@@ -223,33 +345,73 @@ def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype):
     Returns:
         The modified module
     """
-    assert hasattr(module, "static_cache")
-    assert isinstance(
-        module.static_cache, StaticCache
-    ), "Only StaticCache transform is supported. Hybrid cache with local global attention is not yet supported"
-    # TODO: Add replace_cache to exported module
-    # in transformer's executorch.py
-    if getattr(module, "replace_cache", None) is not None:
-        static_cache = ETCustomStaticCache(
-            config=config,
-            max_batch_size=generation_config.cache_config.batch_size,
-            max_cache_len=generation_config.cache_config.max_cache_len,
-            device=generation_config.cache_config.device,
-            dtype=cache_dtype,
-        )
-        module.replace_cache(static_cache)
+    # Check if module has static_cache (TorchExportableModuleWithStaticCache)
+    if hasattr(module, "static_cache"):
+        assert isinstance(module.static_cache, StaticCache), f"Expected StaticCache, got {type(module.static_cache)}"
+
+        # TODO: Add replace_cache to exported module
+        # in transformer's executorch.py
+        if getattr(module, "replace_cache", None) is not None:
+            static_cache = ETCustomStaticCache(
+                config=config,
+                max_batch_size=generation_config.cache_config.batch_size,
+                max_cache_len=generation_config.cache_config.max_cache_len,
+                device=generation_config.cache_config.device,
+                dtype=cache_dtype,
+            )
+            module.replace_cache(static_cache)
+        else:
+            module.static_cache = ETCustomStaticCache(
+                config=config,
+                max_batch_size=generation_config.cache_config.batch_size,
+                max_cache_len=generation_config.cache_config.max_cache_len,
+                device=generation_config.cache_config.device,
+                dtype=cache_dtype,
+            )
+        # Don't know why we need to do this even though
+        # CustomKVCache registers the attributes
+        for i in range(len(module.static_cache.kv_cache)):
+            setattr(module, f"key_cache_{i}", module.static_cache.kv_cache[i].k_cache)
+            setattr(module, f"value_cache_{i}", module.static_cache.kv_cache[i].v_cache)
+
+    # Check if module has cache (TorchExportableModuleWithHybridCache)
+    elif hasattr(module, "cache"):
+        assert isinstance(module.cache, HybridCache), f"Expected HybridCache, got {type(module.cache)}"
+
+        # Replace with ETCustomHybridCache
+        if getattr(module, "replace_cache", None) is not None:
+            hybrid_cache = ETCustomHybridCache(
+                config=config,
+                max_batch_size=generation_config.cache_config.batch_size,
+                max_cache_len=generation_config.cache_config.max_cache_len,
+                device=generation_config.cache_config.device,
+                dtype=cache_dtype,
+            )
+            module.replace_cache(hybrid_cache)
+        else:
+            module.cache = ETCustomHybridCache(
+                config=config,
+                max_batch_size=generation_config.cache_config.batch_size,
+                max_cache_len=generation_config.cache_config.max_cache_len,
+                device=generation_config.cache_config.device,
+                dtype=cache_dtype,
+            )
+        # Register cache attributes for each layer
+        for i in range(len(module.cache.kv_cache)):
+            setattr(module, f"key_cache_{i}", module.cache.kv_cache[i].k_cache)
+            setattr(module, f"value_cache_{i}", module.cache.kv_cache[i].v_cache)
+            if module.cache.is_sliding_list[i]:
+                # Register cache_positions as a buffer for sliding window layers.
+                # This prevents it from being traced as a constant.
+                module.register_buffer(
+                    f"cache_positions_{i}",
+                    module.cache.kv_cache[i].cache_positions_manager.cache_positions,
+                    persistent=False,
+                )
     else:
-        module.static_cache = ETCustomStaticCache(
-            config=config,
-            max_batch_size=generation_config.cache_config.batch_size,
-            max_cache_len=generation_config.cache_config.max_cache_len,
-            device=generation_config.cache_config.device,
-            dtype=cache_dtype,
+        raise ValueError(
+            "Module must have either 'static_cache' (TorchExportableModuleWithStaticCache) "
+            "or 'cache' (TorchExportableModuleWithHybridCache) attribute"
        )
-        # Dont know why we need to this even though
-        # CustomKVCache registers the attributes
-        for i in range(len(module.static_cache.kv_cache)):
-            setattr(module, f"key_cache_{i}", module.static_cache.kv_cache[i].k_cache)
-            setattr(module, f"value_cache_{i}", module.static_cache.kv_cache[i].v_cache)

     return module
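The commit message also notes that SDPA must accept a mask matching the ring buffer's state; that change lives in the other files of this commit, which are not shown here. As a hedged sketch of what such a mask looks like to the attention call, a position-derived boolean mask can be passed to torch.nn.functional.scaled_dot_product_attention instead of relying on is_causal. The slot_positions values below are stand-in data, not output of the real cache.

import torch
import torch.nn.functional as F

n_heads, head_dim, window = 2, 8, 4

# Stand-in ring-buffer state after 6 tokens with a window of 4:
# slot i currently holds the token at position slot_positions[i].
slot_positions = torch.tensor([4, 5, 2, 3])
k_cache = torch.randn(1, n_heads, window, head_dim)
v_cache = torch.randn(1, n_heads, window, head_dim)

# One new query token at position 5.
query_pos = torch.tensor([5])
q = torch.randn(1, n_heads, 1, head_dim)

# True means "may attend": the slot must hold a causally visible position
# that is still inside the sliding window.
causal = slot_positions.unsqueeze(0) <= query_pos.unsqueeze(1)
in_window = slot_positions.unsqueeze(0) > (query_pos.unsqueeze(1) - window)
attn_mask = (causal & in_window).view(1, 1, 1, window)  # broadcasts over batch and heads

out = F.scaled_dot_product_attention(q, k_cache, v_cache, attn_mask=attn_mask)
print(out.shape)  # torch.Size([1, 2, 1, 8])

The diff above registers each sliding layer's cache_positions as a non-persistent buffer, which keeps exactly this kind of position state from being baked into the exported graph as a constant.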
