Add support for sliding window via ring kv cache #81

Merged: 3 commits, Jun 25, 2025

4 changes: 2 additions & 2 deletions install_dev.py
@@ -5,7 +5,7 @@

def install_torch_nightly_deps():
"""Install torch related dependencies from pinned nightly"""
EXECUTORCH_NIGHTLY_VERSION = "dev20250620"
EXECUTORCH_NIGHTLY_VERSION = "dev20250625"
TORCHAO_NIGHTLY_VERSION = "dev20250620"
# Torch nightly is aligned with pinned nightly in https://github.com/pytorch/executorch/blob/main/install_requirements.py#L74
TORCH_NIGHTLY_VERSION = "dev20250601"
@@ -34,7 +34,7 @@ def install_dep_from_source():
"-m",
"pip",
"install",
"git+https://github.com/huggingface/transformers@51f94ea06d19a6308c61bbb4dc97c40aabd12bad#egg=transformers", # v4.52.4
"git+https://github.com/huggingface/transformers@37367c7d9fd23401c26e79a2b26253ab2d1b7236#egg=transformers",
]
)
subprocess.check_call(
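
For context, install_dev.py pins the ExecuTorch/torch nightlies by date tag and installs transformers from a specific commit. A minimal sketch of that pinning pattern follows; the package list, version-string format, and index URL are illustrative and not copied from the actual script.

```python
import subprocess
import sys

# Date tags of the pinned nightlies (values taken from the diff above).
EXECUTORCH_NIGHTLY_VERSION = "dev20250625"
TORCH_NIGHTLY_VERSION = "dev20250601"


def install_pinned_deps():
    """Install a date-pinned nightly plus a commit-pinned transformers build."""
    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            # Illustrative pin format; the real requirement strings may differ.
            f"torch==2.8.0.{TORCH_NIGHTLY_VERSION}",
            "--extra-index-url",
            "https://download.pytorch.org/whl/nightly/cpu",  # illustrative index
        ]
    )
    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            # Commit-pinned install, mirroring the transformers line above.
            "git+https://github.com/huggingface/transformers@37367c7d9fd23401c26e79a2b26253ab2d1b7236#egg=transformers",
        ]
    )


if __name__ == "__main__":
    install_pinned_deps()
```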
254 changes: 202 additions & 52 deletions optimum/executorch/attentions/custom_kv_cache.py
@@ -9,21 +9,19 @@
import torch


# If transformers is not installed, raise an ImportError
try:
-from transformers.cache_utils import StaticCache
+from transformers.cache_utils import HybridCache, StaticCache
except ImportError:
# If transformers is not installed, raise an ImportError
try:
from transformers.cache_utils import StaticCache
except ImportError:
raise ImportError("transformers is not installed. Please install it to use StaticCache.")
raise ImportError("transformers is not installed. Please install it to use Static/HybridCache.")

try:
from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
CustomKVCache,
+CustomRingKVCache,
)
except ImportError:
raise ImportError("ExecutorTorch is not installed. Please install it to use CustomKVCache.")
raise ImportError("ExecutorTorch is not installed. Please install it to use Custom Cache.")


class ETCustomStaticCache(StaticCache):
@@ -55,7 +53,7 @@ def __init__(
assert device is None or device == "cpu", "Device must be None or 'cpu'"

# Create a list of CustomKVCache instances, one per layer
-kv_cache_list = []
+self.kv_cache = torch.nn.ModuleList()
for _ in range(config.num_hidden_layers):
layer_cache = CustomKVCache(
max_batch_size=self.max_batch_size,
@@ -64,8 +62,7 @@ def __init__(
head_dim=self.head_dim,
dtype=dtype,
)
-kv_cache_list.append(layer_cache)
-self.kv_cache = torch.nn.ModuleList(kv_cache_list)
+self.kv_cache.append(layer_cache)

def update(
self,
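
One practical effect of building self.kv_cache as a torch.nn.ModuleList in the hunk above (instead of collecting the layer caches in a plain Python list) is that each per-layer cache becomes a registered submodule, so its k_cache and v_cache buffers are visible to named_buffers(), state_dict(), and device moves. A small self-contained illustration, using a toy stand-in rather than ExecutorTorch's CustomKVCache:

```python
import torch


class ToyLayerCache(torch.nn.Module):
    """Stand-in for a per-layer KV cache that owns two buffers."""

    def __init__(self, n_heads: int = 2, head_dim: int = 4, max_len: int = 8):
        super().__init__()
        self.register_buffer("k_cache", torch.zeros(1, n_heads, max_len, head_dim))
        self.register_buffer("v_cache", torch.zeros(1, n_heads, max_len, head_dim))


class WithPlainList(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Plain list: the caches are NOT registered as submodules.
        self.kv_cache = [ToyLayerCache() for _ in range(2)]


class WithModuleList(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # ModuleList: each cache becomes a registered submodule.
        self.kv_cache = torch.nn.ModuleList(ToyLayerCache() for _ in range(2))


print(len(list(WithPlainList().named_buffers())))   # 0 -> buffers are invisible
print(len(list(WithModuleList().named_buffers())))  # 4 -> k/v buffers per layer are tracked
```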
@@ -180,6 +177,135 @@ def from_legacy_cache(
)


# TODO: Decide whether this should inherit from HybridCache or StaticCache
class ETCustomHybridCache(HybridCache):
"""
Custom Hybrid KV Cache implementation for ExecutorTorch that inherits from Hugging Face's HybridCache
but uses ExecutorTorch's CustomKVCache for global layers and CustomRingKVCache for sliding window layers.
"""

def __init__(
self,
config,
max_batch_size: int,
max_cache_len: Optional[int] = None,
device: Union[torch.device, str, None] = None,
dtype: torch.dtype = torch.float32,
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
):
super().__init__(
config=config,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
device=device,
dtype=dtype,
layer_device_map=layer_device_map,
)

# Make sure layer_device_map is None
assert layer_device_map is None
assert device is None or device == "cpu", "Device must be None or 'cpu'"

self.cache_position = None
# Create a list of cache instances, one per layer
# Use CustomKVCache for global layers and CustomRingKVCache for sliding window layers
self.kv_cache = torch.nn.ModuleList()
for layer_idx in range(config.num_hidden_layers):
# Newer versions of transformers define is_sliding
# for HybridCache
if self.is_sliding[layer_idx]:
# This is a sliding window layer
layer_cache = CustomRingKVCache(
max_batch_size=self.max_batch_size,
max_context_length=self.sliding_window_len,
n_heads=self.num_key_value_heads,
head_dim=self.head_dim,
dtype=dtype,
)
else:
layer_cache = CustomKVCache(
max_batch_size=self.max_batch_size,
max_context_length=self.max_cache_len,
n_heads=self.num_key_value_heads,
head_dim=self.head_dim,
dtype=dtype,
)
self.kv_cache.append(layer_cache)

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`
using ExecutorTorch's CustomKVCache or CustomRingKVCache depending on the layer type.

Args:
key_states (`torch.Tensor`):
The new key states to cache. Shape: [batch_size, n_heads, seq_len, head_dim]
value_states (`torch.Tensor`):
The new value states to cache. Shape: [batch_size, n_heads, seq_len, head_dim]
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache update.

Returns:
A tuple containing the updated key and value states.
"""
assert cache_kwargs is not None

# Get cache position from cache_kwargs (used by HybridCache)
cache_position = cache_kwargs.get("cache_position")
assert cache_position is not None
assert isinstance(cache_position, torch.Tensor)
self.cache_position = cache_position

# Get the cache instance for this layer (either CustomKVCache or CustomRingKVCache)
layer_cache = self.kv_cache[layer_idx]

# Use the cache's update method
# Both CustomKVCache and CustomRingKVCache have the same update interface
k_out, v_out = layer_cache.update(
input_pos=cache_position,
k_val=key_states,
v_val=value_states,
)

return k_out, v_out

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
if layer_idx is None:
layer_idx = 0

# For CustomRingKVCache, we need to handle the sequence length differently
layer_cache = self.kv_cache[layer_idx]
if self.is_sliding[layer_idx]:
# CustomRingKVCache has a cache_positions_manager which
# maintains the cache position for each slot in the kv cache.
# We return the max position + 1 to indicate the max position
# seen so far. It is not entirely clear that this is the correct
# interpretation of sequence length.
return layer_cache.cache_positions_manager.cache_positions.max().item() + 1
return (layer_cache.k_cache[0, :, 0].any(dim=-1)).sum()

def get_layer_cache(self, layer_idx: int):
"""
Get the cache for a specific layer. This method is dynamo-traceable.

Args:
layer_idx (int): The layer index

Returns:
The cache instance for the specified layer (CustomKVCache or CustomRingKVCache)
"""
return self.kv_cache[layer_idx]
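
A rough usage sketch for ETCustomHybridCache follows; it is not taken from this PR. It assumes a Gemma-2-style config (HybridCache derives its per-layer is_sliding pattern from the model config) and an environment where ExecutorTorch's custom kv-cache module is installed; the Gemma2Config values are illustrative and exact config fields can vary across transformers versions.

```python
import torch
from transformers import Gemma2Config

from optimum.executorch.attentions.custom_kv_cache import ETCustomHybridCache

# Small config that mixes sliding-window and global attention layers.
config = Gemma2Config(
    num_hidden_layers=4,
    hidden_size=256,
    num_attention_heads=4,
    num_key_value_heads=1,
    head_dim=64,
    sliding_window=16,  # sliding-window layers get CustomRingKVCache
)

cache = ETCustomHybridCache(
    config=config,
    max_batch_size=1,
    max_cache_len=64,
    dtype=torch.float32,
)

# Single decode step at position 3 for layer 0 (whether layer 0 is a sliding
# or a global layer depends on the config's layer pattern).
k = torch.zeros(1, config.num_key_value_heads, 1, config.head_dim)
v = torch.zeros(1, config.num_key_value_heads, 1, config.head_dim)
k_out, v_out = cache.update(
    k, v, layer_idx=0, cache_kwargs={"cache_position": torch.tensor([3])}
)

# For sliding layers, get_seq_length() is derived from the ring buffer's
# position bookkeeping (max position seen so far + 1).
print(k_out.shape, cache.get_seq_length(0))
```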


def replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype):
"""
Replace all KV caches in the module with ETCustomStaticCache.
@@ -192,22 +318,6 @@ def replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype):
Returns:
The modified module
"""
-# Ensure custom ops are registered
-try:
-op = torch.ops.llama.update_cache
-assert op is not None
-except Exception:
-try:
-from executorch.extension.llm.custom_ops import custom_ops # noqa: F401
-
-op = torch.ops.llama.update_cache
-assert op is not None
-except ImportError:
-raise ImportError(
-"ExecutorTorch custom operations are not available. "
-"Please install executorch with custom operations support."
-)

# Recursively replace KV caches
return _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype)
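
As a reading aid for the rewritten function in the next hunk, here is a simplified mirror of the dispatch it introduces: the exported wrapper is probed for a static_cache attribute (TorchExportableModuleWithStaticCache) or a cache attribute (TorchExportableModuleWithHybridCache), and the matching ExecutorTorch-backed cache is swapped in. This is an illustrative reduction, not the library code:

```python
def choose_et_cache_kind(module) -> str:
    """Toy version of the branch structure in _replace_with_et_custom_kv_cache."""
    if hasattr(module, "static_cache"):
        # StaticCache wrapper -> replaced with ETCustomStaticCache
        return "static"
    elif hasattr(module, "cache"):
        # HybridCache wrapper -> replaced with ETCustomHybridCache
        return "hybrid"
    raise ValueError(
        "Module must have either 'static_cache' (TorchExportableModuleWithStaticCache) "
        "or 'cache' (TorchExportableModuleWithHybridCache) attribute"
    )


class _FakeStaticWrapper:
    static_cache = object()  # stand-in for a transformers StaticCache


print(choose_et_cache_kind(_FakeStaticWrapper()))  # -> "static"
```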

@@ -223,33 +333,73 @@ def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype):
Returns:
The modified module
"""
-assert hasattr(module, "static_cache")
-assert isinstance(
-module.static_cache, StaticCache
-), "Only StaticCache transform is supported. Hybrid cache with local global attention is not yet supported"
-# TODO: Add replace_cache to exported module
-# in transformer's executorch.py
-if getattr(module, "replace_cache", None) is not None:
-static_cache = ETCustomStaticCache(
-config=config,
-max_batch_size=generation_config.cache_config.batch_size,
-max_cache_len=generation_config.cache_config.max_cache_len,
-device=generation_config.cache_config.device,
-dtype=cache_dtype,
-)
-module.replace_cache(static_cache)
# Check if module has static_cache (TorchExportableModuleWithStaticCache)
if hasattr(module, "static_cache"):
assert isinstance(module.static_cache, StaticCache), f"Expected StaticCache, got {type(module.static_cache)}"

# TODO: Add replace_cache to exported module
# in transformer's executorch.py
if getattr(module, "replace_cache", None) is not None:
static_cache = ETCustomStaticCache(
config=config,
max_batch_size=generation_config.cache_config.batch_size,
max_cache_len=generation_config.cache_config.max_cache_len,
device=generation_config.cache_config.device,
dtype=cache_dtype,
)
module.replace_cache(static_cache)
else:
module.static_cache = ETCustomStaticCache(
config=config,
max_batch_size=generation_config.cache_config.batch_size,
max_cache_len=generation_config.cache_config.max_cache_len,
device=generation_config.cache_config.device,
dtype=cache_dtype,
)
# Don't know why we need to do this even though
# CustomKVCache registers the attributes
for i in range(len(module.static_cache.kv_cache)):
setattr(module, f"key_cache_{i}", module.static_cache.kv_cache[i].k_cache)
setattr(module, f"value_cache_{i}", module.static_cache.kv_cache[i].v_cache)

# Check if module has cache (TorchExportableModuleWithHybridCache)
elif hasattr(module, "cache"):
assert isinstance(module.cache, HybridCache), f"Expected HybridCache, got {type(module.cache)}"

# Replace with ETCustomHybridCache
if getattr(module, "replace_cache", None) is not None:
hybrid_cache = ETCustomHybridCache(
config=config,
max_batch_size=generation_config.cache_config.batch_size,
max_cache_len=generation_config.cache_config.max_cache_len,
device=generation_config.cache_config.device,
dtype=cache_dtype,
)
module.replace_cache(hybrid_cache)
else:
module.cache = ETCustomHybridCache(
config=config,
max_batch_size=generation_config.cache_config.batch_size,
max_cache_len=generation_config.cache_config.max_cache_len,
device=generation_config.cache_config.device,
dtype=cache_dtype,
)
# Register cache attributes for each layer
for i in range(len(module.cache.kv_cache)):
setattr(module, f"key_cache_{i}", module.cache.kv_cache[i].k_cache)
setattr(module, f"value_cache_{i}", module.cache.kv_cache[i].v_cache)
if module.cache.is_sliding[i]:
# Register cache_positions as buffer for sliding window layers
# This prevents it from being traced as a constant
module.register_buffer(
f"cache_positions_{i}",
module.cache.kv_cache[i].cache_positions_manager.cache_positions,
persistent=False,
)
else:
-module.static_cache = ETCustomStaticCache(
-config=config,
-max_batch_size=generation_config.cache_config.batch_size,
-max_cache_len=generation_config.cache_config.max_cache_len,
-device=generation_config.cache_config.device,
-dtype=cache_dtype,
+raise ValueError(
+"Module must have either 'static_cache' (TorchExportableModuleWithStaticCache) "
+"or 'cache' (TorchExportableModuleWithHybridCache) attribute"
)
-# Dont know why we need to this even though
-# CustomKVCache registers the attributes
-for i in range(len(module.static_cache.kv_cache)):
-setattr(module, f"key_cache_{i}", module.static_cache.kv_cache[i].k_cache)
-setattr(module, f"value_cache_{i}", module.static_cache.kv_cache[i].v_cache)

return module
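
The loops above also expose each layer's tensors directly on the exported module (key_cache_{i}, value_cache_{i}, and, for sliding layers, cache_positions_{i} registered as a buffer so it is not traced as a constant). A toy mirror of that registration step, using stand-in tensors instead of ExecutorTorch cache objects:

```python
import torch

# Stand-in module and layer layout: layer 0 sliding-window, layer 1 global.
module = torch.nn.Module()
is_sliding = [True, False]

for i, sliding in enumerate(is_sliding):
    # k/v stand-ins with shape [batch, n_heads, max_len, head_dim].
    setattr(module, f"key_cache_{i}", torch.zeros(1, 1, 8, 4))
    setattr(module, f"value_cache_{i}", torch.zeros(1, 1, 8, 4))
    if sliding:
        # Mirrors the register_buffer call above: registering the tensor (rather
        # than leaving it a plain attribute) keeps it tracked by the module;
        # persistent=False keeps it out of state_dict.
        module.register_buffer(
            f"cache_positions_{i}", torch.zeros(8, dtype=torch.long), persistent=False
        )

print([name for name, _ in module.named_buffers()])  # ['cache_positions_0']
```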