Adding paper
This commit is contained in:
@@ -125,6 +125,7 @@ class TransformerBlock(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
mask: Optional[torch.Tensor] = None,
|
mask: Optional[torch.Tensor] = None,
|
||||||
|
use_cache: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Forward pass through transformer block.
|
Forward pass through transformer block.
|
||||||
@@ -132,6 +133,7 @@ class TransformerBlock(nn.Module):
|
|||||||
Args:
|
Args:
|
||||||
x: Input tensor [batch_size, seq_len, d_model]
|
x: Input tensor [batch_size, seq_len, d_model]
|
||||||
mask: Optional attention mask
|
mask: Optional attention mask
|
||||||
|
use_cache: Whether to use KV cache (for optimized attention)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Output tensor [batch_size, seq_len, d_model]
|
Output tensor [batch_size, seq_len, d_model]
|
||||||
@@ -139,7 +141,11 @@ class TransformerBlock(nn.Module):
|
|||||||
# Pre-norm self-attention with residual connection
|
# Pre-norm self-attention with residual connection
|
||||||
residual = x
|
residual = x
|
||||||
x = self.norm1(x)
|
x = self.norm1(x)
|
||||||
attn_out, _ = self.self_attn(x, x, x, mask=mask)
|
# Pass use_cache to attention layer if it's OptimizedMultiHeadAttention
|
||||||
|
if isinstance(self.self_attn, OptimizedMultiHeadAttention):
|
||||||
|
attn_out, _ = self.self_attn(x, x, x, mask=mask, use_cache=use_cache)
|
||||||
|
else:
|
||||||
|
attn_out, _ = self.self_attn(x, x, x, mask=mask)
|
||||||
x = residual + self.dropout(attn_out)
|
x = residual + self.dropout(attn_out)
|
||||||
|
|
||||||
# Pre-norm feed-forward with residual connection
|
# Pre-norm feed-forward with residual connection
|
||||||
|
|||||||
@@ -119,8 +119,17 @@ class OptimizedMultiHeadAttention(nn.Module):
|
|||||||
|
|
||||||
# Use KV cache if available and enabled
|
# Use KV cache if available and enabled
|
||||||
if use_cache and self.kv_cache is not None:
|
if use_cache and self.kv_cache is not None:
|
||||||
# Append new keys and values to cache
|
# Get current cache length
|
||||||
self.kv_cache.append(K, V)
|
cached_len = self.kv_cache.keys.shape[2] if self.kv_cache.keys.numel() > 0 else 0
|
||||||
|
|
||||||
|
# Only append new tokens (those after the cached length)
|
||||||
|
if cached_len < seq_len:
|
||||||
|
# Extract only the new tokens to append
|
||||||
|
new_k = K[:, :, cached_len:, :] # [batch_size, num_heads, new_seq_len, d_k]
|
||||||
|
new_v = V[:, :, cached_len:, :]
|
||||||
|
self.kv_cache.append(new_k, new_v)
|
||||||
|
|
||||||
|
# Use all cached keys and values
|
||||||
K = self.kv_cache.keys
|
K = self.kv_cache.keys
|
||||||
V = self.kv_cache.values
|
V = self.kv_cache.values
|
||||||
kv_seq_len = K.shape[2]
|
kv_seq_len = K.shape[2]
|
||||||
@@ -328,9 +337,11 @@ class OptimizedInference:
|
|||||||
|
|
||||||
generated = input_ids.clone()
|
generated = input_ids.clone()
|
||||||
|
|
||||||
|
# Process initial prompt to populate cache
|
||||||
|
logits, _ = self.model(generated, use_cache=True)
|
||||||
|
|
||||||
for _ in range(max_length - input_ids.shape[1]):
|
for _ in range(max_length - input_ids.shape[1]):
|
||||||
# Forward pass
|
# Get next token logits from last position
|
||||||
logits, _ = self.model(generated)
|
|
||||||
next_token_logits = logits[:, -1, :] / temperature
|
next_token_logits = logits[:, -1, :] / temperature
|
||||||
|
|
||||||
# Apply top-k filtering
|
# Apply top-k filtering
|
||||||
@@ -370,6 +381,9 @@ class OptimizedInference:
|
|||||||
|
|
||||||
# Append to generated sequence
|
# Append to generated sequence
|
||||||
generated = torch.cat([generated, next_token], dim=1)
|
generated = torch.cat([generated, next_token], dim=1)
|
||||||
|
|
||||||
|
# Forward pass with cache enabled (processes full sequence but uses cached K/V)
|
||||||
|
logits, _ = self.model(generated, use_cache=True)
|
||||||
|
|
||||||
# Clear KV cache
|
# Clear KV cache
|
||||||
for module in self.model.modules():
|
for module in self.model.modules():
|
||||||
|
|||||||
@@ -129,6 +129,7 @@ class TransformerModel(nn.Module):
|
|||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
use_cache: bool = False,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
"""
|
"""
|
||||||
Forward pass through the transformer model.
|
Forward pass through the transformer model.
|
||||||
@@ -136,6 +137,7 @@ class TransformerModel(nn.Module):
|
|||||||
Args:
|
Args:
|
||||||
input_ids: Token indices [batch_size, seq_len]
|
input_ids: Token indices [batch_size, seq_len]
|
||||||
attention_mask: Optional attention mask [batch_size, seq_len]
|
attention_mask: Optional attention mask [batch_size, seq_len]
|
||||||
|
use_cache: Whether to use KV cache (for optimized attention)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
logits: Output logits [batch_size, seq_len, vocab_size]
|
logits: Output logits [batch_size, seq_len, vocab_size]
|
||||||
@@ -166,7 +168,7 @@ class TransformerModel(nn.Module):
|
|||||||
|
|
||||||
# Pass through transformer blocks
|
# Pass through transformer blocks
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
x = layer(x, mask=attention_mask)
|
x = layer(x, mask=attention_mask, use_cache=use_cache)
|
||||||
|
|
||||||
# Final layer norm
|
# Final layer norm
|
||||||
x = self.final_norm(x)
|
x = self.final_norm(x)
|
||||||
|
|||||||
Reference in New Issue
Block a user