# llm-rag-ds-optimizer/llmds/kv_cache.py
"""KV cache with paged allocation and prefix sharing.
Implementation based on techniques from:
Cache-Craft: Managing Chunk-Caches for Efficient Retrieval-Augmented Generation.
See docs/CITATIONS.md for full citation details.
"""
import copy
import hashlib
from typing import Any, Optional
from llmds.paged_allocator import PagedAllocator
class KVCache:
    """
    KV cache with paged allocation, prefix sharing, and deduplication.

    Implements copy-on-write (COW) for prefix sharing: shared pages are
    read-only until a diverging write occurs, at which point the writing
    sequence receives a private copy and the original page remains intact
    for the other sequences.

    Reference:
        Cache-Craft: Managing Chunk-Caches for Efficient Retrieval-Augmented
        Generation. See docs/CITATIONS.md for full citation details.

    **Copy-on-Write Semantics:**
    - Shared pages (from prefix sharing) are read-only
    - Attempts to write different data to a shared page trigger lazy copying
    - Each sequence maintains its own copy of modified pages
    - Original shared pages remain unchanged for other sequences

    Supports hash-based deduplication of repeated system prompts via
    ``_prefix_map`` (prefix hash -> page ids).
    """

    def __init__(
        self,
        page_size: int = 512,
        max_pages: int = 10000,
        enable_prefix_sharing: bool = True,
    ):
        """
        Initialize KV cache.

        Args:
            page_size: Size of each KV cache page in tokens
            max_pages: Maximum number of pages to allocate
            enable_prefix_sharing: Enable prefix sharing optimization
        """
        self.allocator = PagedAllocator(page_size, max_pages)
        self.page_size = page_size
        self._sequences: dict[int, list[int]] = {}  # seq_id -> list[page_ids]
        self._kv_data: dict[int, Any] = {}  # page_id -> KV data
        self._prefix_map: dict[str, list[int]] = {}  # hash -> page_ids
        self._page_refs: dict[int, int] = {}  # page_id -> reference count
        self._shared_pages: set[int] = set()  # page_ids that are shared (read-only)
        self._enable_prefix_sharing = enable_prefix_sharing
        self._seq_counter = 0
        self._prefix_shares = 0  # number of attaches that reused a cached prefix

    def _hash_prefix(self, prefix: list[int]) -> str:
        """Compute hash of prefix tokens.

        Only the first 100 tokens are hashed to bound hashing cost. A
        collision between prefixes that agree on those tokens is harmless:
        ``attach`` compares page data and falls back to copy-on-write on
        any mismatch.
        """
        prefix_str = ",".join(map(str, prefix[:100]))  # Limit length
        return hashlib.sha256(prefix_str.encode()).hexdigest()

    def _drop_prefix_entries(self, freed_page_ids: list[int]) -> None:
        """Remove prefix-map entries that reference any freed page.

        Without this pruning, a later ``attach`` with the same prefix hash
        would resurrect page ids the allocator has already recycled.
        """
        freed = set(freed_page_ids)
        stale = [h for h, pids in self._prefix_map.items() if freed.intersection(pids)]
        for h in stale:
            del self._prefix_map[h]

    def _release_page(self, page_id: int) -> bool:
        """Drop one reference to ``page_id``; tear down bookkeeping at zero.

        Returns:
            True if the page reached zero references — the caller must then
            return it to the allocator.
        """
        self._page_refs[page_id] = self._page_refs.get(page_id, 1) - 1
        if self._page_refs[page_id] > 0:
            return False
        self._shared_pages.discard(page_id)
        self._kv_data.pop(page_id, None)
        del self._page_refs[page_id]
        return True

    def _copy_if_shared(self, page_id: int, seq_id: int) -> int:
        """
        Copy-on-write: if page is shared, create a new copy.

        Args:
            page_id: Original page ID (may be shared)
            seq_id: Sequence ID requesting the copy (currently unused; kept
                for API compatibility and future per-sequence accounting)

        Returns:
            New page_id if copied, original page_id if not shared
        """
        if page_id not in self._shared_pages:
            return page_id
        # Page is shared - allocate a private replacement and copy its data.
        new_page_id = self.allocator.alloc(1)[0]
        self._kv_data[new_page_id] = copy.deepcopy(self._kv_data.get(page_id, []))
        # Drop this sequence's reference to the original. If that was the
        # last reference, return the page to the allocator (the previous
        # implementation leaked it here) and invalidate any prefix-map
        # entries that still point at it.
        if self._release_page(page_id):
            self.allocator.free([page_id])
            self._drop_prefix_entries([page_id])
        # New page has a single owner and is therefore writable.
        self._page_refs[new_page_id] = 1
        return new_page_id

    def attach(
        self,
        seq_id: int,
        kv_tokens: list[Any],
        prefix_tokens: Optional[list[int]] = None,
    ) -> None:
        """
        Attach KV cache for a sequence.

        Implements copy-on-write: if prefix sharing is used, shared pages
        are referenced but will be copied on first diverging write.
        Re-attaching an existing ``seq_id`` replaces its previous cache.

        Args:
            seq_id: Sequence identifier
            kv_tokens: KV tokens to cache
            prefix_tokens: Optional prefix tokens for sharing
        """
        if seq_id in self._sequences:
            self.detach(seq_id)
        pages_needed = (len(kv_tokens) + self.page_size - 1) // self.page_size
        page_ids: list[int] = []
        num_prefix_pages = 0
        prefix_hash: Optional[str] = None
        # Try prefix sharing if enabled
        if self._enable_prefix_sharing and prefix_tokens:
            prefix_hash = self._hash_prefix(prefix_tokens)
            shared_prefix_pages = self._prefix_map.get(prefix_hash)
            if shared_prefix_pages is not None:
                # Reference shared pages (copied on write if data differs).
                num_prefix_pages = min(len(shared_prefix_pages), pages_needed)
                for shared_page_id in shared_prefix_pages[:num_prefix_pages]:
                    self._page_refs[shared_page_id] = (
                        self._page_refs.get(shared_page_id, 0) + 1
                    )
                    self._shared_pages.add(shared_page_id)
                page_ids.extend(shared_prefix_pages[:num_prefix_pages])
                self._prefix_shares += 1
        # Allocate only the pages not covered by the shared prefix. (The
        # previous implementation allocated all `pages_needed` pages up
        # front and leaked the ones displaced by shared prefix pages.)
        remaining = pages_needed - num_prefix_pages
        if remaining > 0:
            page_ids.extend(self.allocator.alloc(remaining))
        if prefix_hash is not None and prefix_hash not in self._prefix_map:
            # First time seeing this prefix - register its pages as sharing
            # candidates for subsequent sequences.
            candidate_count = min(
                (len(prefix_tokens) + self.page_size - 1) // self.page_size,
                pages_needed,
            )
            self._prefix_map[prefix_hash] = page_ids[:candidate_count]
        # Store KV data with copy-on-write semantics.
        for i, page_id in enumerate(page_ids):
            start = i * self.page_size
            end = min(start + self.page_size, len(kv_tokens))
            page_data = kv_tokens[start:end]
            if page_id in self._shared_pages:
                # Shared page - only write if the data actually diverges.
                if self._kv_data.get(page_id, []) != page_data:
                    # Data differs - trigger copy-on-write, then write the
                    # private copy.
                    page_id = self._copy_if_shared(page_id, seq_id)
                    page_ids[i] = page_id
                    self._kv_data[page_id] = page_data
                # Matching data: keep referencing the shared page as-is.
            else:
                # Sole owner - safe to write directly.
                self._kv_data[page_id] = page_data
                self._page_refs.setdefault(page_id, 1)
        self._sequences[seq_id] = page_ids

    def detach(self, seq_id: int) -> None:
        """
        Detach and free KV cache for a sequence.

        Decrements reference counts for shared pages. Pages are only freed
        when their reference count reaches zero; prefix-map entries that
        point at freed pages are dropped so later attaches cannot reference
        recycled page ids.

        Args:
            seq_id: Sequence identifier (no-op if unknown)
        """
        page_ids = self._sequences.pop(seq_id, None)
        if page_ids is None:
            return
        pages_to_free = [pid for pid in page_ids if self._release_page(pid)]
        if pages_to_free:
            self.allocator.free(pages_to_free)
            self._drop_prefix_entries(pages_to_free)

    def get(self, seq_id: int) -> Optional[list[Any]]:
        """
        Get KV cache for a sequence.

        Returns a copy of the data to prevent external modifications
        from affecting shared pages.

        Args:
            seq_id: Sequence identifier

        Returns:
            List of KV tokens or None if not found
        """
        if seq_id not in self._sequences:
            return None
        kv_tokens: list[Any] = []
        for page_id in self._sequences[seq_id]:
            if page_id in self._kv_data:
                # Deep copy so callers cannot mutate shared pages in place.
                kv_tokens.extend(copy.deepcopy(self._kv_data[page_id]))
        return kv_tokens

    def stats(self) -> dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dictionary with cache statistics
        """
        alloc_stats = self.allocator.stats()
        return {
            "total_sequences": len(self._sequences),
            "total_pages": alloc_stats.total_pages,
            "allocated_pages": alloc_stats.allocated_pages,
            "free_pages": alloc_stats.free_pages,
            "prefix_shares": self._prefix_shares,
            "prefix_map_size": len(self._prefix_map),
            "shared_pages_count": len(self._shared_pages),
            "total_page_refs": sum(self._page_refs.values()),
        }

    def hook_speculative_decode(self, seq_id: int, draft_tokens: list[int]) -> None:
        """
        Hook for speculative decoding compatibility.

        Placeholder API for future implementation; currently a no-op.

        Args:
            seq_id: Sequence identifier
            draft_tokens: Draft tokens from speculative decoding
        """
        # Placeholder for speculative decoding integration
        pass