Commit e1d9cd1

refactor: leverage ExecutorTorch's CustomKVCache for per-layer cache management (#78)

Summary:
- Replace separate key/value cache tensors with CustomKVCache instances
- Use the CustomKVCache.update() method instead of manual torch.ops.llama.update_cache calls
- Create a self.kv_cache list with one CustomKVCache per layer
- Maintain compatibility with the parent StaticCache class
- Fix type-safety issues in the get_seq_length method
- Simplify cache-update logic by leveraging ExecutorTorch's proven implementation

Test Plan: CI

Reviewers:
Subscribers:
Tasks:
Tags:
1 parent ae9da65 commit e1d9cd1
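
The heart of the change shows up in the constructor diff below: instead of appending raw torch.zeros buffers to self.key_cache / self.value_cache, the cache now builds one CustomKVCache per decoder layer and collects them in a torch.nn.ModuleList. As a standalone illustration (the helper name build_per_layer_kv_caches is hypothetical and not part of this commit; the CustomKVCache import path and keyword arguments are taken from the diff), the pattern is roughly:

import torch
from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
    CustomKVCache,
)


def build_per_layer_kv_caches(num_layers, max_batch_size, max_cache_len,
                              num_kv_heads, head_dim, dtype=torch.float32):
    # One CustomKVCache per decoder layer, mirroring ETCustomStaticCache.__init__.
    caches = [
        CustomKVCache(
            max_batch_size=max_batch_size,
            max_context_length=max_cache_len,
            n_heads=num_kv_heads,
            head_dim=head_dim,
            dtype=dtype,
        )
        for _ in range(num_layers)
    ]
    # ModuleList keeps each layer cache registered as a submodule, so the
    # k_cache / v_cache buffers that CustomKVCache registers stay attached
    # to the parent module.
    return torch.nn.ModuleList(caches)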

File tree: 1 file changed, +38 −39 lines

optimum/executorch/attentions/custom_kv_cache.py

Lines changed: 38 additions & 39 deletions
@@ -18,6 +18,13 @@
 except ImportError:
     raise ImportError("transformers is not installed. Please install it to use StaticCache.")
 
+try:
+    from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
+        CustomKVCache,
+    )
+except ImportError:
+    raise ImportError("ExecutorTorch is not installed. Please install it to use CustomKVCache.")
+
 
 class ETCustomStaticCache(StaticCache):
     """
@@ -45,26 +52,20 @@ def __init__(
 
         # make sure layer_device_map is none
         assert layer_device_map is None
-
-        # Clear existing caches
-        self.key_cache = []
-        self.value_cache = []
-
-        # Initialize cache buffers with our custom shape
-        cache_shape = (
-            self.max_batch_size,
-            self.max_cache_len,
-            self.num_key_value_heads,
-            self.head_dim,
-        )
         assert device is None or device == "cpu", "Device must be None or 'cpu'"
 
+        # Create a list of CustomKVCache instances, one per layer
+        kv_cache_list = []
         for _ in range(config.num_hidden_layers):
-            self.new_layer_key_cache = torch.zeros(cache_shape, dtype=dtype, device="cpu")
-            self.new_layer_value_cache = torch.zeros(cache_shape, dtype=dtype, device="cpu")
-
-            self.key_cache.append(self.new_layer_key_cache)
-            self.value_cache.append(self.new_layer_value_cache)
+            layer_cache = CustomKVCache(
+                max_batch_size=self.max_batch_size,
+                max_context_length=self.max_cache_len,
+                n_heads=self.num_key_value_heads,
+                head_dim=self.head_dim,
+                dtype=dtype,
+            )
+            kv_cache_list.append(layer_cache)
+        self.kv_cache = torch.nn.ModuleList(kv_cache_list)
 
     def update(
         self,
@@ -75,7 +76,7 @@ def update(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`
-        using custom operations.
+        using ExecutorTorch's CustomKVCache.
 
         Args:
             key_states (`torch.Tensor`):
@@ -95,32 +96,28 @@
         # Get cache position from cache_kwargs (used by StaticCache)
        cache_position = cache_kwargs.get("cache_position")
         assert cache_position is not None
+        assert isinstance(cache_position, torch.Tensor)
 
-        # Get the current cache for this layer
-        k_out = self.key_cache[layer_idx]
-        v_out = self.value_cache[layer_idx]
-
-        # Transpose key and value states to match our cache shape
-        # From [batch_size, n_heads, seq_len, head_dim] to [batch_size, seq_len, n_heads, head_dim]
-        k_val = key_states.transpose(1, 2)
-        v_val = value_states.transpose(1, 2)
+        # Get the CustomKVCache instance for this layer
+        layer_cache = self.kv_cache[layer_idx]
 
-        # Use custom operations to update the cache
-        # Update cache with indices for more complex update patterns
-        assert isinstance(cache_position, torch.Tensor)
-        start_pos = cache_position[0].item()
-        _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
-        _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
+        # Use the CustomKVCache's update method
+        # CustomKVCache expects input_pos, k_val, v_val and handles the transpose internally
+        k_out, v_out = layer_cache.update(
+            input_pos=cache_position,
+            k_val=key_states,
+            v_val=value_states,
+        )
 
-        # Return the updated cache in the format expected by the model
-        # Transpose back from [batch_size, seq_len, n_heads, head_dim] to [batch_size, n_heads, seq_len, head_dim]
-        return k_out.transpose(1, 2), v_out.transpose(1, 2)
+        return k_out, v_out
 
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
         # Occupied cache == any slot in the 2nd dim (sequence length) holds a non-zero value
         # This is different from StaticCache which checks the 3rd dim
-        return (self.key_cache[layer_idx][0, :, 0].any(dim=-1)).sum()
+        if layer_idx is None:
+            layer_idx = 0
+        return (self.kv_cache[layer_idx].k_cache[0, :, 0].any(dim=-1)).sum()
 
     @classmethod
     def from_legacy_cache(
@@ -249,8 +246,10 @@ def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dt
        device=generation_config.cache_config.device,
        dtype=cache_dtype,
     )
-    for i in range(len(module.static_cache.key_cache)):
-        setattr(module, f"key_cache_{i}", module.static_cache.key_cache[i])
-        setattr(module, f"value_cache_{i}", module.static_cache.value_cache[i])
+    # Don't know why we need to do this even though
+    # CustomKVCache registers the attributes
+    for i in range(len(module.static_cache.kv_cache)):
+        setattr(module, f"key_cache_{i}", module.static_cache.kv_cache[i].k_cache)
+        setattr(module, f"value_cache_{i}", module.static_cache.kv_cache[i].v_cache)
 
     return module
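
For context, here is a minimal, hypothetical usage sketch of the refactored class; it is not part of this commit. The constructor keyword arguments are assumed to mirror transformers' StaticCache, the config values are made up, and both transformers and ExecutorTorch must be installed:

import torch
from transformers import LlamaConfig
from optimum.executorch.attentions.custom_kv_cache import ETCustomStaticCache

# Tiny illustrative config; real values come from the model being exported.
config = LlamaConfig(
    hidden_size=256,
    num_attention_heads=8,
    num_key_value_heads=4,
    num_hidden_layers=2,
    intermediate_size=512,
)

cache = ETCustomStaticCache(
    config=config,
    max_batch_size=1,      # assumed kwarg names, following StaticCache
    max_cache_len=128,
    device=None,           # must be None or "cpu" per the assert in __init__
    dtype=torch.float32,
)

# Projected key/value states in the usual transformers layout:
# [batch_size, n_kv_heads, seq_len, head_dim]
seq_len = 4
k = torch.randn(1, config.num_key_value_heads, seq_len, cache.head_dim)
v = torch.randn(1, config.num_key_value_heads, seq_len, cache.head_dim)

k_out, v_out = cache.update(
    key_states=k,
    value_states=v,
    layer_idx=0,
    cache_kwargs={"cache_position": torch.arange(seq_len)},
)
print(cache.get_seq_length(0))  # occupied slots in layer 0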
