282 lines
10 KiB
Python
282 lines
10 KiB
Python
"""KV cache with paged allocation and prefix sharing.
|
|
|
|
Implementation based on techniques from:
|
|
Cache-Craft: Managing Chunk-Caches for Efficient Retrieval-Augmented Generation.
|
|
|
|
See docs/CITATIONS.md for full citation details.
|
|
"""
|
|
|
|
import copy
|
|
import hashlib
|
|
from typing import Any, Optional
|
|
|
|
from llmds.paged_allocator import PagedAllocator
|
|
|
|
|
|
class KVCache:
    """
    KV cache with paged allocation, prefix sharing, and deduplication.

    Implements copy-on-write (COW) for prefix sharing: shared pages are
    read-only until a write occurs, at which point they are copied.

    Reference:
        Cache-Craft: Managing Chunk-Caches for Efficient Retrieval-Augmented Generation.

    **Copy-on-Write Semantics:**
    - Shared pages (from prefix sharing) are read-only
    - Attempts to modify shared pages trigger lazy copying
    - Each sequence maintains its own copy of modified pages
    - Original shared pages remain unchanged for other sequences

    **Bookkeeping invariants:**
    - Every page referenced by a sequence has an entry in ``_page_refs``
    - ``_prefix_map`` only references pages that are still allocated;
      entries are trimmed when their pages are freed

    Supports hash-based deduplication of repeated system prompts.
    """

    def __init__(
        self,
        page_size: int = 512,
        max_pages: int = 10000,
        enable_prefix_sharing: bool = True,
    ):
        """
        Initialize KV cache.

        Args:
            page_size: Size of each KV cache page in tokens
            max_pages: Maximum number of pages to allocate
            enable_prefix_sharing: Enable prefix sharing optimization
        """
        self.allocator = PagedAllocator(page_size, max_pages)
        self.page_size = page_size
        self._sequences: dict[int, list[int]] = {}  # seq_id -> list[page_ids]
        self._kv_data: dict[int, Any] = {}  # page_id -> KV data
        self._prefix_map: dict[str, list[int]] = {}  # hash -> page_ids
        self._page_refs: dict[int, int] = {}  # page_id -> reference count
        self._shared_pages: set[int] = set()  # page_ids that are shared (read-only)
        self._enable_prefix_sharing = enable_prefix_sharing
        self._seq_counter = 0
        self._prefix_shares = 0

    def _hash_prefix(self, prefix: list[int]) -> str:
        """
        Compute hash of prefix tokens.

        Only the first 100 tokens are hashed, so two longer prefixes that
        agree on those tokens map to the same key. ``attach`` compares the
        actual page data before reusing a page, so such a collision results
        in a copy rather than in wrong data being shared.

        Args:
            prefix: Prefix token ids

        Returns:
            Hex-encoded SHA-256 digest of the (truncated) prefix
        """
        prefix_str = ",".join(map(str, prefix[:100]))  # Limit length
        return hashlib.sha256(prefix_str.encode()).hexdigest()

    def _copy_if_shared(self, page_id: int, seq_id: int) -> int:
        """
        Copy-on-write: if page is shared, create a new copy.

        The caller's reference to the original page is released and replaced
        by a reference to the freshly allocated private page.

        Args:
            page_id: Original page ID (may be shared)
            seq_id: Sequence ID requesting the copy (currently unused; kept
                for interface compatibility)

        Returns:
            New page_id if copied, original page_id if not shared
        """
        if page_id not in self._shared_pages:
            return page_id

        # Page is shared - need to copy
        new_page_id = self.allocator.alloc(1)[0]

        # Copy the data (an untracked page is treated as empty)
        if page_id in self._kv_data:
            self._kv_data[new_page_id] = copy.deepcopy(self._kv_data[page_id])
        else:
            self._kv_data[new_page_id] = []

        # Release the caller's reference to the original page
        remaining = self._page_refs.get(page_id, 1) - 1
        if remaining <= 0:
            self._shared_pages.discard(page_id)
            self._page_refs.pop(page_id, None)
        else:
            self._page_refs[page_id] = remaining

        # New page is not shared (single owner)
        self._page_refs[new_page_id] = 1

        return new_page_id

    def _drop_freed_prefix_pages(self, freed_pages: list[int]) -> None:
        """
        Remove freed pages from the prefix map.

        Each prefix entry is truncated at the first freed page so the entry
        keeps describing a contiguous run of live pages from the start of
        the prefix; entries left empty are deleted. Without this, a later
        ``attach`` with the same prefix would take references to pages that
        were already returned to the allocator.

        Args:
            freed_pages: Page ids that are being returned to the allocator
        """
        if not freed_pages or not self._prefix_map:
            return

        freed = set(freed_pages)
        # Iterate over a snapshot of the keys: entries may be deleted below.
        for prefix_hash in list(self._prefix_map):
            pages = self._prefix_map[prefix_hash]
            keep = 0
            for pid in pages:
                if pid in freed:
                    break
                keep += 1
            if keep == len(pages):
                continue
            if keep:
                self._prefix_map[prefix_hash] = pages[:keep]
            else:
                del self._prefix_map[prefix_hash]

    def attach(
        self,
        seq_id: int,
        kv_tokens: list[Any],
        prefix_tokens: Optional[list[int]] = None,
    ) -> None:
        """
        Attach KV cache for a sequence.

        Implements copy-on-write: if prefix sharing is used, shared pages
        are referenced but will be copied on first write.

        If ``seq_id`` is already attached, its previous pages are detached
        first. Only the pages that cannot be shared are requested from the
        allocator.

        Args:
            seq_id: Sequence identifier
            kv_tokens: KV tokens to cache
            prefix_tokens: Optional prefix tokens for sharing
        """
        if seq_id in self._sequences:
            self.detach(seq_id)

        pages_needed = (len(kv_tokens) + self.page_size - 1) // self.page_size

        # Look up reusable prefix pages before allocating anything, so that
        # pages covered by sharing are never allocated (and leaked).
        prefix_hash: Optional[str] = None
        shared_prefix_pages: list[int] = []
        if self._enable_prefix_sharing and prefix_tokens:
            prefix_hash = self._hash_prefix(prefix_tokens)
            shared_prefix_pages = self._prefix_map.get(prefix_hash, [])[:pages_needed]

        page_ids: list[int]
        if shared_prefix_pages:
            # Allocate the non-shared suffix first: if the allocator fails,
            # no reference counts have been touched yet.
            fresh_needed = pages_needed - len(shared_prefix_pages)
            fresh_pages = list(self.allocator.alloc(fresh_needed)) if fresh_needed else []

            # Reference shared pages (will be copied on write if needed)
            for shared_page_id in shared_prefix_pages:
                self._page_refs[shared_page_id] = self._page_refs.get(shared_page_id, 0) + 1
                self._shared_pages.add(shared_page_id)

            page_ids = shared_prefix_pages + fresh_pages
            self._prefix_shares += 1
        else:
            page_ids = list(self.allocator.alloc(pages_needed))
            if prefix_hash is not None:
                # First time seeing this prefix - mark these pages as potential shared
                num_prefix_pages = min(
                    (len(prefix_tokens) + self.page_size - 1) // self.page_size,
                    pages_needed,
                )
                # Never register an empty entry: it could not be shared and
                # would block a later, real registration of this prefix.
                if num_prefix_pages:
                    self._prefix_map[prefix_hash] = page_ids[:num_prefix_pages]

        # Store KV data with copy-on-write semantics
        # For shared pages: if data differs, trigger COW; otherwise, reference existing
        for i, page_id in enumerate(page_ids):
            start = i * self.page_size
            end = min(start + self.page_size, len(kv_tokens))
            page_data = kv_tokens[start:end]

            if page_id in self._shared_pages:
                if self._kv_data.get(page_id, []) == page_data:
                    # Data matches - just keep referencing the shared page
                    continue
                # Data differs - trigger copy-on-write
                page_id = self._copy_if_shared(page_id, seq_id)
                page_ids[i] = page_id

            # Page is private to this sequence - safe to write directly
            self._kv_data[page_id] = page_data
            self._page_refs.setdefault(page_id, 1)

        self._sequences[seq_id] = page_ids

    def detach(self, seq_id: int) -> None:
        """
        Detach and free KV cache for a sequence.

        Decrements reference counts for shared pages. Pages are only freed
        when their reference count reaches zero. Freed pages are also
        removed from the prefix map so they cannot be shared again.

        Unknown sequence ids are ignored.

        Args:
            seq_id: Sequence identifier
        """
        if seq_id not in self._sequences:
            return

        pages_to_free: list[int] = []
        for page_id in self._sequences[seq_id]:
            if page_id in self._shared_pages:
                # Shared page - decrement reference count
                remaining = self._page_refs.get(page_id, 1) - 1
                if remaining > 0:
                    # Still referenced by another sequence - keep it alive
                    self._page_refs[page_id] = remaining
                    continue
                self._shared_pages.discard(page_id)

            # No more references - drop bookkeeping and free
            self._kv_data.pop(page_id, None)
            self._page_refs.pop(page_id, None)
            pages_to_free.append(page_id)

        if pages_to_free:
            self._drop_freed_prefix_pages(pages_to_free)
            self.allocator.free(pages_to_free)

        del self._sequences[seq_id]

    def get(self, seq_id: int) -> Optional[list[Any]]:
        """
        Get KV cache for a sequence.

        Returns a copy of the data to prevent external modifications
        from affecting shared pages.

        Args:
            seq_id: Sequence identifier

        Returns:
            List of KV tokens or None if not found
        """
        if seq_id not in self._sequences:
            return None

        kv_tokens: list[Any] = []
        for page_id in self._sequences[seq_id]:
            if page_id in self._kv_data:
                # Return copy to prevent external modification of shared pages
                kv_tokens.extend(copy.deepcopy(self._kv_data[page_id]))

        return kv_tokens

    def stats(self) -> dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dictionary with cache statistics
        """
        alloc_stats = self.allocator.stats()
        return {
            "total_sequences": len(self._sequences),
            "total_pages": alloc_stats.total_pages,
            "allocated_pages": alloc_stats.allocated_pages,
            "free_pages": alloc_stats.free_pages,
            "prefix_shares": self._prefix_shares,
            "prefix_map_size": len(self._prefix_map),
            "shared_pages_count": len(self._shared_pages),
            "total_page_refs": sum(self._page_refs.values()),
        }

    def hook_speculative_decode(self, seq_id: int, draft_tokens: list[int]) -> None:
        """
        Hook for speculative decoding compatibility.

        Placeholder API for future implementation; currently a no-op.

        Args:
            seq_id: Sequence identifier
            draft_tokens: Draft tokens from speculative decoding
        """
        # Placeholder for speculative decoding integration
        pass
|
|
|