Adding paper
This commit is contained in:
@@ -119,8 +119,17 @@ class OptimizedMultiHeadAttention(nn.Module):
|
||||
|
||||
# Use KV cache if available and enabled
|
||||
if use_cache and self.kv_cache is not None:
|
||||
# Append new keys and values to cache
|
||||
self.kv_cache.append(K, V)
|
||||
# Get current cache length
|
||||
cached_len = self.kv_cache.keys.shape[2] if self.kv_cache.keys.numel() > 0 else 0
|
||||
|
||||
# Only append new tokens (those after the cached length)
|
||||
if cached_len < seq_len:
|
||||
# Extract only the new tokens to append
|
||||
new_k = K[:, :, cached_len:, :] # [batch_size, num_heads, new_seq_len, d_k]
|
||||
new_v = V[:, :, cached_len:, :]
|
||||
self.kv_cache.append(new_k, new_v)
|
||||
|
||||
# Use all cached keys and values
|
||||
K = self.kv_cache.keys
|
||||
V = self.kv_cache.values
|
||||
kv_seq_len = K.shape[2]
|
||||
@@ -328,9 +337,11 @@ class OptimizedInference:
|
||||
|
||||
generated = input_ids.clone()
|
||||
|
||||
# Process initial prompt to populate cache
|
||||
logits, _ = self.model(generated, use_cache=True)
|
||||
|
||||
for _ in range(max_length - input_ids.shape[1]):
|
||||
# Forward pass
|
||||
logits, _ = self.model(generated)
|
||||
# Get next token logits from last position
|
||||
next_token_logits = logits[:, -1, :] / temperature
|
||||
|
||||
# Apply top-k filtering
|
||||
@@ -370,6 +381,9 @@ class OptimizedInference:
|
||||
|
||||
# Append to generated sequence
|
||||
generated = torch.cat([generated, next_token], dim=1)
|
||||
|
||||
# Forward pass with cache enabled (processes full sequence but uses cached K/V)
|
||||
logits, _ = self.model(generated, use_cache=True)
|
||||
|
||||
# Clear KV cache
|
||||
for module in self.model.modules():
|
||||
|
||||
Reference in New Issue
Block a user