Commit c2b20c9

Add support for sliding window via ring kv cache (#81)
* Add support for sliding window via ring kv cache

  Summary: This diff adds
  - A ring KV cache for local attention, where the ring buffer writes new entries into the cache in a ring fashion rather than sliding the cache.
  - A corresponding SDPA update so it accepts a mask that reflects the state of the ring buffer: a slot in the ring buffer may correspond to a different token position over time, so a plain causal mask is no longer sufficient.

  The custom_kv_cache option now supports both global attention and local/global attention (as in Gemma 3).

  Test Plan: test added.

* Gating imports on ET version

* Make room for transformers changes requiring an attn_mask fn

  Summary: the latest changes in transformers 4.53 require that custom attention functions have a corresponding mask-generation function.

  Test Plan: CI won't pass without changes on top of 4.53.
1 parent 6e71e2b commit c2b20c9
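The core idea in the first change: a sliding-window layer only ever keeps the last `window` tokens, so the ring buffer overwrites the oldest slot instead of shifting the whole cache, and a slot index no longer equals a token position. The attention mask therefore has to be derived from the positions currently stored in each slot rather than from a standard causal mask over slot indices. A minimal illustrative Python sketch of that bookkeeping (toy code, not the ExecuTorch CustomRingKVCache; `window`, `slot_positions`, `write`, and `ring_mask` are invented names):

    import torch

    # Toy ring-buffer bookkeeping for one sliding-window layer.
    window = 4                                   # sliding window size (hypothetical)
    slot_positions = torch.full((window,), -1)   # token position currently held by each slot

    def write(pos: int) -> int:
        """Write token `pos` into the ring: the slot index wraps around."""
        slot = pos % window
        slot_positions[slot] = pos
        return slot

    def ring_mask(query_pos: int) -> torch.Tensor:
        """Which cache slots the query at `query_pos` may attend to.
        A plain causal mask over slot indices would be wrong, because slot i
        does not hold token i; it holds whatever position was last written there."""
        valid = slot_positions >= 0
        causal = slot_positions <= query_pos
        in_window = slot_positions > query_pos - window
        return valid & causal & in_window

    for p in range(6):          # write 6 tokens into a window of 4
        write(p)
    print(slot_positions)       # tensor([4, 5, 2, 3]) -- slots hold out-of-order positions
    print(ring_mask(5))         # all True: positions 2..5, wherever they live in the ring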

File tree

6 files changed: +365 -67 lines changed


optimum/executorch/attentions/custom_kv_cache.py

Lines changed: 202 additions & 52 deletions
@@ -9,21 +9,19 @@
 import torch


+# If transformers is not installed, raise an ImportError
 try:
-    from transformers.cache_utils import StaticCache
+    from transformers.cache_utils import HybridCache, StaticCache
 except ImportError:
-    # If transformers is not installed, raise an ImportError
-    try:
-        from transformers.cache_utils import StaticCache
-    except ImportError:
-        raise ImportError("transformers is not installed. Please install it to use StaticCache.")
+    raise ImportError("transformers is not installed. Please install it to use Static/HybridCache.")

 try:
     from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
         CustomKVCache,
+        CustomRingKVCache,
     )
 except ImportError:
-    raise ImportError("ExecutorTorch is not installed. Please install it to use CustomKVCache.")
+    raise ImportError("ExecutorTorch is not installed. Please install it to use Custom Cache.")


 class ETCustomStaticCache(StaticCache):
@@ -55,7 +53,7 @@ def __init__(
         assert device is None or device == "cpu", "Device must be None or 'cpu'"

         # Create a list of CustomKVCache instances, one per layer
-        kv_cache_list = []
+        self.kv_cache = torch.nn.ModuleList()
         for _ in range(config.num_hidden_layers):
             layer_cache = CustomKVCache(
                 max_batch_size=self.max_batch_size,
@@ -64,8 +62,7 @@ def __init__(
                 head_dim=self.head_dim,
                 dtype=dtype,
             )
-            kv_cache_list.append(layer_cache)
-        self.kv_cache = torch.nn.ModuleList(kv_cache_list)
+            self.kv_cache.append(layer_cache)

     def update(
         self,
@@ -180,6 +177,135 @@ def from_legacy_cache(
         )


+# Need to figure out if I have to inherit from HybridCache or StaticCache
+class ETCustomHybridCache(HybridCache):
+    """
+    Custom Hybrid KV Cache implementation for ExecutorTorch that inherits from Hugging Face's HybridCache
+    but uses ExecutorTorch's CustomKVCache for global layers and CustomRingKVCache for sliding window layers.
+    """
+
+    def __init__(
+        self,
+        config,
+        max_batch_size: int,
+        max_cache_len: Optional[int] = None,
+        device: Union[torch.device, str, None] = None,
+        dtype: torch.dtype = torch.float32,
+        layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
+    ):
+        super().__init__(
+            config=config,
+            max_batch_size=max_batch_size,
+            max_cache_len=max_cache_len,
+            device=device,
+            dtype=dtype,
+            layer_device_map=layer_device_map,
+        )
+
+        # make sure layer_device_map is none
+        assert layer_device_map is None
+        assert device is None or device == "cpu", "Device must be None or 'cpu'"
+
+        self.cache_position = None
+        # Create a list of cache instances, one per layer
+        # Use CustomKVCache for global layers and CustomRingKVCache for sliding window layers
+        self.kv_cache = torch.nn.ModuleList()
+        for layer_idx in range(config.num_hidden_layers):
+            # newer version of transfomer has is_sliding defined
+            # for HybridCache
+            if self.is_sliding[layer_idx]:
+                # This is a sliding window layer
+                layer_cache = CustomRingKVCache(
+                    max_batch_size=self.max_batch_size,
+                    max_context_length=self.sliding_window_len,
+                    n_heads=self.num_key_value_heads,
+                    head_dim=self.head_dim,
+                    dtype=dtype,
+                )
+            else:
+                layer_cache = CustomKVCache(
+                    max_batch_size=self.max_batch_size,
+                    max_context_length=self.max_cache_len,
+                    n_heads=self.num_key_value_heads,
+                    head_dim=self.head_dim,
+                    dtype=dtype,
+                )
+            self.kv_cache.append(layer_cache)
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`
+        using ExecutorTorch's CustomKVCache or CustomRingKVCache depending on the layer type.
+
+        Args:
+            key_states (`torch.Tensor`):
+                The new key states to cache. Shape: [batch_size, n_heads, seq_len, head_dim]
+            value_states (`torch.Tensor`):
+                The new value states to cache. Shape: [batch_size, n_heads, seq_len, head_dim]
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache update.
+
+        Returns:
+            A tuple containing the updated key and value states.
+        """
+        assert cache_kwargs is not None
+
+        # Get cache position from cache_kwargs (used by HybridCache)
+        cache_position = cache_kwargs.get("cache_position")
+        assert cache_position is not None
+        assert isinstance(cache_position, torch.Tensor)
+        self.cache_position = cache_position
+
+        # Get the cache instance for this layer (either CustomKVCache or CustomRingKVCache)
+        layer_cache = self.kv_cache[layer_idx]
+
+        # Use the cache's update method
+        # Both CustomKVCache and CustomRingKVCache have the same update interface
+        k_out, v_out = layer_cache.update(
+            input_pos=cache_position,
+            k_val=key_states,
+            v_val=value_states,
+        )
+
+        return k_out, v_out
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        if layer_idx is None:
+            layer_idx = 0
+
+        # For CustomRingKVCache, we need to handle the sequence length differently
+        layer_cache = self.kv_cache[layer_idx]
+        if self.is_sliding[layer_idx]:
+            # CustomRingKVCache cache_position_manager which
+            # maintains cache position for each slot in the kv cache
+            # we return the max position + 1 to indicate max position
+            # seen so far. Not sure if thats the correct interpretation
+            # of sequence length
+            return layer_cache.cache_positions_manager.cache_positions.max().item() + 1
+        return (layer_cache.k_cache[0, :, 0].any(dim=-1)).sum()
+
+    def get_layer_cache(self, layer_idx: int):
+        """
+        Get the cache for a specific layer. This method is dynamo-traceable.
+
+        Args:
+            layer_idx (int): The layer index
+
+        Returns:
+            The cache instance for the specified layer (CustomKVCache or CustomRingKVCache)
+        """
+        return self.kv_cache[layer_idx]
+
+
 def replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype):
     """
     Replace all KV caches in the module with ETCustomStaticCache.
@@ -192,22 +318,6 @@ def replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype):
     Returns:
         The modified module
     """
-    # Ensure custom ops are registered
-    try:
-        op = torch.ops.llama.update_cache
-        assert op is not None
-    except Exception:
-        try:
-            from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
-
-            op = torch.ops.llama.update_cache
-            assert op is not None
-        except ImportError:
-            raise ImportError(
-                "ExecutorTorch custom operations are not available. "
-                "Please install executorch with custom operations support."
-            )
-
     # Recursively replace KV caches
     return _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype)

@@ -223,33 +333,73 @@ def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype):
     Returns:
         The modified module
     """
-    assert hasattr(module, "static_cache")
-    assert isinstance(
-        module.static_cache, StaticCache
-    ), "Only StaticCache transform is supported. Hybrid cache with local global attention is not yet supported"
-    # TODO: Add replace_cache to exported module
-    # in transformer's executorch.py
-    if getattr(module, "replace_cache", None) is not None:
-        static_cache = ETCustomStaticCache(
-            config=config,
-            max_batch_size=generation_config.cache_config.batch_size,
-            max_cache_len=generation_config.cache_config.max_cache_len,
-            device=generation_config.cache_config.device,
-            dtype=cache_dtype,
-        )
-        module.replace_cache(static_cache)
+    # Check if module has static_cache (TorchExportableModuleWithStaticCache)
+    if hasattr(module, "static_cache"):
+        assert isinstance(module.static_cache, StaticCache), f"Expected StaticCache, got {type(module.static_cache)}"
+
+        # TODO: Add replace_cache to exported module
+        # in transformer's executorch.py
+        if getattr(module, "replace_cache", None) is not None:
+            static_cache = ETCustomStaticCache(
+                config=config,
+                max_batch_size=generation_config.cache_config.batch_size,
+                max_cache_len=generation_config.cache_config.max_cache_len,
+                device=generation_config.cache_config.device,
+                dtype=cache_dtype,
+            )
+            module.replace_cache(static_cache)
+        else:
+            module.static_cache = ETCustomStaticCache(
+                config=config,
+                max_batch_size=generation_config.cache_config.batch_size,
+                max_cache_len=generation_config.cache_config.max_cache_len,
+                device=generation_config.cache_config.device,
+                dtype=cache_dtype,
+            )
+        # Dont know why we need to this even though
+        # CustomKVCache registers the attributes
+        for i in range(len(module.static_cache.kv_cache)):
+            setattr(module, f"key_cache_{i}", module.static_cache.kv_cache[i].k_cache)
+            setattr(module, f"value_cache_{i}", module.static_cache.kv_cache[i].v_cache)
+
+    # Check if module has cache (TorchExportableModuleWithHybridCache)
+    elif hasattr(module, "cache"):
+        assert isinstance(module.cache, HybridCache), f"Expected HybridCache, got {type(module.cache)}"
+
+        # Replace with ETCustomHybridCache
+        if getattr(module, "replace_cache", None) is not None:
+            hybrid_cache = ETCustomHybridCache(
+                config=config,
+                max_batch_size=generation_config.cache_config.batch_size,
+                max_cache_len=generation_config.cache_config.max_cache_len,
+                device=generation_config.cache_config.device,
+                dtype=cache_dtype,
+            )
+            module.replace_cache(hybrid_cache)
+        else:
+            module.cache = ETCustomHybridCache(
+                config=config,
+                max_batch_size=generation_config.cache_config.batch_size,
+                max_cache_len=generation_config.cache_config.max_cache_len,
+                device=generation_config.cache_config.device,
+                dtype=cache_dtype,
+            )
+        # Register cache attributes for each layer
+        for i in range(len(module.cache.kv_cache)):
+            setattr(module, f"key_cache_{i}", module.cache.kv_cache[i].k_cache)
+            setattr(module, f"value_cache_{i}", module.cache.kv_cache[i].v_cache)
+            if module.cache.is_sliding[i]:
+                # Register cache_positions as buffer for sliding window layers
+                # This prevents it from being traced as a constant
+                module.register_buffer(
+                    f"cache_positions_{i}",
+                    module.cache.kv_cache[i].cache_positions_manager.cache_positions,
+                    persistent=False,
+                )
     else:
-        module.static_cache = ETCustomStaticCache(
-            config=config,
-            max_batch_size=generation_config.cache_config.batch_size,
-            max_cache_len=generation_config.cache_config.max_cache_len,
-            device=generation_config.cache_config.device,
-            dtype=cache_dtype,
+        raise ValueError(
+            "Module must have either 'static_cache' (TorchExportableModuleWithStaticCache) "
+            "or 'cache' (TorchExportableModuleWithHybridCache) attribute"
         )
-        # Dont know why we need to this even though
-        # CustomKVCache registers the attributes
-        for i in range(len(module.static_cache.kv_cache)):
-            setattr(module, f"key_cache_{i}", module.static_cache.kv_cache[i].k_cache)
-            setattr(module, f"value_cache_{i}", module.static_cache.kv_cache[i].v_cache)

     return module
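For orientation, a hypothetical usage sketch of the hybrid cache added above, driven through the `update` interface from the diff. The model id, the config attributes used for shapes, and the concrete sizes are assumptions, and both executorch and a recent transformers are assumed to be installed; treat it as a sketch, not a tested snippet from this repository.

    import torch
    from transformers import AutoConfig

    from optimum.executorch.attentions.custom_kv_cache import ETCustomHybridCache

    # Example model id (assumption): any text config mixing sliding-window ("local")
    # and global attention layers, such as Gemma 3, should exercise both cache types.
    config = AutoConfig.from_pretrained("google/gemma-3-1b-it").get_text_config()

    cache = ETCustomHybridCache(
        config=config,
        max_batch_size=1,
        max_cache_len=128,
        dtype=torch.float32,
    )

    # Decode a single token at position 10 for layer 0, using the shape convention
    # from the docstring above: [batch_size, n_heads, seq_len, head_dim].
    # num_key_value_heads / head_dim as config attribute names are assumptions here.
    n_kv_heads, head_dim = config.num_key_value_heads, config.head_dim
    k = torch.zeros(1, n_kv_heads, 1, head_dim)
    v = torch.zeros(1, n_kv_heads, 1, head_dim)
    k_out, v_out = cache.update(
        k,
        v,
        layer_idx=0,
        cache_kwargs={"cache_position": torch.tensor([10])},
    )

    # Sliding-window layers are backed by CustomRingKVCache, global layers by CustomKVCache.
    print(type(cache.get_layer_cache(0)).__name__)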
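The `register_buffer(..., persistent=False)` calls for the sliding-window layers address an export subtlety: a plain tensor attribute on a module is lifted by `torch.export` as a baked-in constant, so in-place updates to the ring's `cache_positions` would not be reflected as mutable state, while a registered buffer is tracked as state the exported program can update. A standalone sketch of that distinction, independent of the ExecuTorch cache classes (the `Counter` module is invented for illustration):

    import torch
    import torch.export

    class Counter(torch.nn.Module):
        """Toy module that mutates per-slot positions in forward, like a ring cache advancing."""

        def __init__(self, as_buffer: bool):
            super().__init__()
            positions = torch.zeros(4, dtype=torch.long)
            if as_buffer:
                # Registered buffer: torch.export tracks it as mutable graph state.
                self.register_buffer("positions", positions, persistent=False)
            else:
                # Plain tensor attribute: torch.export lifts it as a frozen constant.
                self.positions = positions

        def forward(self, x):
            self.positions.add_(1)  # in-place update of the tracked positions
            return x + self.positions

    ep = torch.export.export(Counter(as_buffer=True), (torch.zeros(4, dtype=torch.long),))
    print(ep.graph_signature.buffers_to_mutate)  # the mutated buffer is recorded in the signature
    # With as_buffer=False, the same add_() would target a tensor that export treats as a constant.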
