Initial commit: SheepOp LLM - Transformer-based language model implementation
- Complete transformer implementation from scratch - Training pipeline with gradient accumulation and mixed precision - Optimized inference with KV caching - Multi-format data processing (PDFs, images, code, text) - Comprehensive documentation - Apache 2.0 license - Example training plots included in docs/images/
This commit is contained in:
35
models/__init__.py
Normal file
35
models/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
SheepOp LLM - A modern language model implementation
|
||||
Optimized for production RAG systems
|
||||
"""
|
||||
from .transformer import TransformerModel
|
||||
from .attention import MultiHeadAttention, PositionalEncoding
|
||||
from .blocks import TransformerBlock, FeedForward
|
||||
from .optimized_attention import (
|
||||
OptimizedMultiHeadAttention,
|
||||
RetrievalCache,
|
||||
OptimizedInference,
|
||||
KVCache,
|
||||
)
|
||||
from .prefetching import (
|
||||
PrefetchDataLoader,
|
||||
LookaheadRetriever,
|
||||
BatchPrefetcher,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'TransformerModel',
|
||||
'MultiHeadAttention',
|
||||
'PositionalEncoding',
|
||||
'TransformerBlock',
|
||||
'FeedForward',
|
||||
'OptimizedMultiHeadAttention',
|
||||
'RetrievalCache',
|
||||
'OptimizedInference',
|
||||
'KVCache',
|
||||
'PrefetchDataLoader',
|
||||
'LookaheadRetriever',
|
||||
'BatchPrefetcher',
|
||||
]
|
||||
|
||||
|
||||
220
models/attention.py
Normal file
220
models/attention.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""
|
||||
Multi-Head Attention mechanism from "Attention Is All You Need"
|
||||
Includes optimizations for long context and hallucination reduction
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
"""
|
||||
Multi-Head Attention mechanism with optional causal masking.
|
||||
|
||||
Features:
|
||||
- Scaled dot-product attention
|
||||
- Optional causal masking for autoregressive generation
|
||||
- Efficient attention computation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
num_heads: int,
|
||||
dropout: float = 0.1,
|
||||
bias: bool = False,
|
||||
causal: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
d_model: Model dimension
|
||||
num_heads: Number of attention heads
|
||||
dropout: Dropout probability
|
||||
bias: Whether to use bias in linear layers
|
||||
causal: Whether to use causal masking
|
||||
"""
|
||||
super().__init__()
|
||||
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
|
||||
|
||||
self.d_model = d_model
|
||||
self.num_heads = num_heads
|
||||
self.d_k = d_model // num_heads
|
||||
self.causal = causal
|
||||
|
||||
# Linear projections for Q, K, V
|
||||
self.q_proj = nn.Linear(d_model, d_model, bias=bias)
|
||||
self.k_proj = nn.Linear(d_model, d_model, bias=bias)
|
||||
self.v_proj = nn.Linear(d_model, d_model, bias=bias)
|
||||
self.out_proj = nn.Linear(d_model, d_model, bias=bias)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.scale = 1.0 / math.sqrt(self.d_k)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Forward pass of multi-head attention.
|
||||
|
||||
Args:
|
||||
query: Query tensor [batch_size, seq_len, d_model]
|
||||
key: Key tensor [batch_size, seq_len, d_model]
|
||||
value: Value tensor [batch_size, seq_len, d_model]
|
||||
mask: Optional attention mask [batch_size, seq_len, seq_len]
|
||||
|
||||
Returns:
|
||||
output: Attention output [batch_size, seq_len, d_model]
|
||||
attention_weights: Attention weights [batch_size, num_heads, seq_len, seq_len]
|
||||
"""
|
||||
batch_size, seq_len, _ = query.shape
|
||||
|
||||
# Project Q, K, V
|
||||
Q = self.q_proj(query) # [batch_size, seq_len, d_model]
|
||||
K = self.k_proj(key)
|
||||
V = self.v_proj(value)
|
||||
|
||||
# Reshape for multi-head attention
|
||||
Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2) # [batch_size, num_heads, seq_len, d_k]
|
||||
K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
|
||||
V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
|
||||
|
||||
# Compute attention scores
|
||||
scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale # [batch_size, num_heads, seq_len, seq_len]
|
||||
|
||||
# Apply causal mask if needed
|
||||
if self.causal:
|
||||
causal_mask = torch.triu(
|
||||
torch.ones(seq_len, seq_len, device=query.device, dtype=torch.bool),
|
||||
diagonal=1
|
||||
)
|
||||
scores.masked_fill_(causal_mask, float('-inf'))
|
||||
|
||||
# Apply external mask if provided
|
||||
if mask is not None:
|
||||
scores.masked_fill_(mask.unsqueeze(1) == 0, float('-inf'))
|
||||
|
||||
# Compute attention weights
|
||||
attention_weights = F.softmax(scores, dim=-1)
|
||||
attention_weights = self.dropout(attention_weights)
|
||||
|
||||
# Apply attention to values
|
||||
output = torch.matmul(attention_weights, V) # [batch_size, num_heads, seq_len, d_k]
|
||||
|
||||
# Concatenate heads
|
||||
output = output.transpose(1, 2).contiguous() # [batch_size, seq_len, num_heads, d_k]
|
||||
output = output.view(batch_size, seq_len, self.d_model) # [batch_size, seq_len, d_model]
|
||||
|
||||
# Final projection
|
||||
output = self.out_proj(output)
|
||||
|
||||
return output, attention_weights
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
"""
|
||||
Positional encoding for transformer models.
|
||||
Uses sinusoidal positional encoding as described in "Attention Is All You Need".
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
|
||||
"""
|
||||
Args:
|
||||
d_model: Model dimension
|
||||
max_len: Maximum sequence length
|
||||
dropout: Dropout probability
|
||||
"""
|
||||
super().__init__()
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
# Create positional encoding matrix
|
||||
pe = torch.zeros(max_len, d_model)
|
||||
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
|
||||
)
|
||||
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0) # [1, max_len, d_model]
|
||||
|
||||
# Register as buffer (not a parameter)
|
||||
self.register_buffer('pe', pe)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Add positional encoding to input.
|
||||
|
||||
Args:
|
||||
x: Input tensor [batch_size, seq_len, d_model]
|
||||
|
||||
Returns:
|
||||
Output with positional encoding added
|
||||
"""
|
||||
seq_len = x.shape[1]
|
||||
x = x + self.pe[:, :seq_len, :]
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
class RotaryPositionalEncoding(nn.Module):
|
||||
"""
|
||||
Rotary Position Embedding (RoPE) - More efficient for long sequences.
|
||||
Better for long-horizon execution tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, max_len: int = 8192):
|
||||
"""
|
||||
Args:
|
||||
d_model: Model dimension (must be even)
|
||||
max_len: Maximum sequence length
|
||||
"""
|
||||
super().__init__()
|
||||
assert d_model % 2 == 0, "d_model must be even for RoPE"
|
||||
|
||||
self.d_model = d_model
|
||||
self.max_len = max_len
|
||||
|
||||
# Precompute frequency matrix
|
||||
inv_freq = 1.0 / (10000 ** (torch.arange(0, d_model, 2).float() / d_model))
|
||||
self.register_buffer('inv_freq', inv_freq)
|
||||
|
||||
def forward(self, x: torch.Tensor, offset: int = 0) -> torch.Tensor:
|
||||
"""
|
||||
Apply rotary positional encoding.
|
||||
|
||||
Args:
|
||||
x: Input tensor [batch_size, seq_len, d_model]
|
||||
offset: Position offset for relative positions
|
||||
|
||||
Returns:
|
||||
Rotated input tensor
|
||||
"""
|
||||
seq_len = x.shape[1]
|
||||
device = x.device
|
||||
|
||||
# Generate position indices
|
||||
t = torch.arange(offset, offset + seq_len, device=device).type_as(self.inv_freq)
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
|
||||
# Apply rotation
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
||||
# Split input into two halves
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
|
||||
# Apply rotation
|
||||
rotated = torch.cat([
|
||||
x1 * cos - x2 * sin,
|
||||
x1 * sin + x2 * cos
|
||||
], dim=-1)
|
||||
|
||||
return rotated
|
||||
|
||||
|
||||
153
models/blocks.py
Normal file
153
models/blocks.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""
|
||||
Transformer building blocks: Feed-forward networks and transformer blocks
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional
|
||||
from .attention import MultiHeadAttention
|
||||
from .optimized_attention import OptimizedMultiHeadAttention
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
"""
|
||||
Position-wise Feed-Forward Network.
|
||||
Implements two linear transformations with activation in between.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
d_ff: int,
|
||||
dropout: float = 0.1,
|
||||
activation: str = 'gelu',
|
||||
bias: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
d_model: Model dimension
|
||||
d_ff: Feed-forward dimension (typically 4 * d_model)
|
||||
dropout: Dropout probability
|
||||
activation: Activation function ('gelu' or 'relu')
|
||||
bias: Whether to use bias in linear layers
|
||||
"""
|
||||
super().__init__()
|
||||
self.linear1 = nn.Linear(d_model, d_ff, bias=bias)
|
||||
self.linear2 = nn.Linear(d_ff, d_model, bias=bias)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
if activation == 'gelu':
|
||||
self.activation = nn.GELU()
|
||||
elif activation == 'relu':
|
||||
self.activation = nn.ReLU()
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation: {activation}")
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: Input tensor [batch_size, seq_len, d_model]
|
||||
|
||||
Returns:
|
||||
Output tensor [batch_size, seq_len, d_model]
|
||||
"""
|
||||
x = self.linear1(x)
|
||||
x = self.activation(x)
|
||||
x = self.dropout(x)
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
"""
|
||||
Transformer block with self-attention and feed-forward network.
|
||||
Includes residual connections and layer normalization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
num_heads: int,
|
||||
d_ff: int,
|
||||
dropout: float = 0.1,
|
||||
activation: str = 'gelu',
|
||||
layer_norm_eps: float = 1e-5,
|
||||
bias: bool = False,
|
||||
causal: bool = False,
|
||||
use_optimized_attention: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
d_model: Model dimension
|
||||
num_heads: Number of attention heads
|
||||
d_ff: Feed-forward dimension
|
||||
dropout: Dropout probability
|
||||
activation: Activation function
|
||||
layer_norm_eps: Epsilon for layer normalization
|
||||
bias: Whether to use bias in linear layers
|
||||
causal: Whether to use causal masking
|
||||
use_optimized_attention: Whether to use optimized attention with KV caching
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# Self-attention with pre-norm architecture
|
||||
if use_optimized_attention:
|
||||
self.self_attn = OptimizedMultiHeadAttention(
|
||||
d_model=d_model,
|
||||
num_heads=num_heads,
|
||||
dropout=dropout,
|
||||
bias=bias,
|
||||
causal=causal,
|
||||
use_flash_attention=True,
|
||||
)
|
||||
else:
|
||||
self.self_attn = MultiHeadAttention(
|
||||
d_model=d_model,
|
||||
num_heads=num_heads,
|
||||
dropout=dropout,
|
||||
bias=bias,
|
||||
causal=causal,
|
||||
)
|
||||
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
||||
|
||||
# Feed-forward network
|
||||
self.feed_forward = FeedForward(
|
||||
d_model=d_model,
|
||||
d_ff=d_ff,
|
||||
dropout=dropout,
|
||||
activation=activation,
|
||||
bias=bias,
|
||||
)
|
||||
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass through transformer block.
|
||||
|
||||
Args:
|
||||
x: Input tensor [batch_size, seq_len, d_model]
|
||||
mask: Optional attention mask
|
||||
|
||||
Returns:
|
||||
Output tensor [batch_size, seq_len, d_model]
|
||||
"""
|
||||
# Pre-norm self-attention with residual connection
|
||||
residual = x
|
||||
x = self.norm1(x)
|
||||
attn_out, _ = self.self_attn(x, x, x, mask=mask)
|
||||
x = residual + self.dropout(attn_out)
|
||||
|
||||
# Pre-norm feed-forward with residual connection
|
||||
residual = x
|
||||
x = self.norm2(x)
|
||||
ff_out = self.feed_forward(x)
|
||||
x = residual + self.dropout(ff_out)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
413
models/optimized_attention.py
Normal file
413
models/optimized_attention.py
Normal file
@@ -0,0 +1,413 @@
|
||||
"""
|
||||
Optimized attention mechanisms for production RAG systems
|
||||
Implements KV caching, optimized attention computation, and retrieval optimizations
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
from typing import Optional, Tuple, Dict, List
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class KVCache:
|
||||
"""Key-Value cache for efficient autoregressive generation."""
|
||||
keys: torch.Tensor # [batch_size, num_heads, seq_len, d_k]
|
||||
values: torch.Tensor # [batch_size, num_heads, seq_len, d_k]
|
||||
|
||||
def append(self, new_keys: torch.Tensor, new_values: torch.Tensor):
|
||||
"""Append new keys and values to the cache."""
|
||||
self.keys = torch.cat([self.keys, new_keys], dim=2)
|
||||
self.values = torch.cat([self.values, new_values], dim=2)
|
||||
|
||||
def clear(self):
|
||||
"""Clear the cache."""
|
||||
self.keys = None
|
||||
self.values = None
|
||||
|
||||
|
||||
class OptimizedMultiHeadAttention(nn.Module):
|
||||
"""
|
||||
Optimized Multi-Head Attention with KV caching and efficient computation.
|
||||
|
||||
Features:
|
||||
- KV cache for autoregressive generation
|
||||
- Optimized attention computation
|
||||
- Support for incremental decoding
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
num_heads: int,
|
||||
dropout: float = 0.1,
|
||||
bias: bool = False,
|
||||
causal: bool = False,
|
||||
use_flash_attention: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
d_model: Model dimension
|
||||
num_heads: Number of attention heads
|
||||
dropout: Dropout probability
|
||||
bias: Whether to use bias in linear layers
|
||||
causal: Whether to use causal masking
|
||||
use_flash_attention: Whether to use optimized flash attention (if available)
|
||||
"""
|
||||
super().__init__()
|
||||
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
|
||||
|
||||
self.d_model = d_model
|
||||
self.num_heads = num_heads
|
||||
self.d_k = d_model // num_heads
|
||||
self.causal = causal
|
||||
self.use_flash_attention = use_flash_attention
|
||||
|
||||
# Linear projections for Q, K, V
|
||||
self.q_proj = nn.Linear(d_model, d_model, bias=bias)
|
||||
self.k_proj = nn.Linear(d_model, d_model, bias=bias)
|
||||
self.v_proj = nn.Linear(d_model, d_model, bias=bias)
|
||||
self.out_proj = nn.Linear(d_model, d_model, bias=bias)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.scale = 1.0 / math.sqrt(self.d_k)
|
||||
|
||||
# KV cache for inference
|
||||
self.kv_cache: Optional[KVCache] = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
value: Optional[torch.Tensor] = None,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""
|
||||
Forward pass with optional KV caching.
|
||||
|
||||
Args:
|
||||
query: Query tensor [batch_size, seq_len, d_model]
|
||||
key: Key tensor [batch_size, seq_len, d_model] (if None, uses query)
|
||||
value: Value tensor [batch_size, seq_len, d_model] (if None, uses query)
|
||||
mask: Optional attention mask [batch_size, seq_len, seq_len]
|
||||
use_cache: Whether to use KV cache
|
||||
cache_position: Position in cache for incremental decoding
|
||||
|
||||
Returns:
|
||||
output: Attention output [batch_size, seq_len, d_model]
|
||||
attention_weights: Attention weights [batch_size, num_heads, seq_len, seq_len]
|
||||
"""
|
||||
if key is None:
|
||||
key = query
|
||||
if value is None:
|
||||
value = query
|
||||
|
||||
batch_size, seq_len, _ = query.shape
|
||||
|
||||
# Project Q, K, V
|
||||
Q = self.q_proj(query) # [batch_size, seq_len, d_model]
|
||||
K = self.k_proj(key)
|
||||
V = self.v_proj(value)
|
||||
|
||||
# Reshape for multi-head attention
|
||||
Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2) # [batch_size, num_heads, seq_len, d_k]
|
||||
K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
|
||||
V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
|
||||
|
||||
# Use KV cache if available and enabled
|
||||
if use_cache and self.kv_cache is not None:
|
||||
# Append new keys and values to cache
|
||||
self.kv_cache.append(K, V)
|
||||
K = self.kv_cache.keys
|
||||
V = self.kv_cache.values
|
||||
kv_seq_len = K.shape[2]
|
||||
else:
|
||||
kv_seq_len = seq_len
|
||||
|
||||
# Compute attention scores with optimized computation
|
||||
if self.use_flash_attention and hasattr(F, 'scaled_dot_product_attention'):
|
||||
# Use PyTorch's optimized scaled dot product attention
|
||||
output = F.scaled_dot_product_attention(
|
||||
Q, K, V,
|
||||
attn_mask=mask,
|
||||
dropout_p=self.dropout.p if self.training else 0.0,
|
||||
is_causal=self.causal,
|
||||
)
|
||||
attention_weights = None # Flash attention doesn't return weights
|
||||
else:
|
||||
# Standard attention computation
|
||||
scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale # [batch_size, num_heads, seq_len, kv_seq_len]
|
||||
|
||||
# Apply causal mask if needed
|
||||
if self.causal:
|
||||
causal_mask = torch.triu(
|
||||
torch.ones(seq_len, kv_seq_len, device=query.device, dtype=torch.bool),
|
||||
diagonal=1
|
||||
)
|
||||
scores.masked_fill_(causal_mask, float('-inf'))
|
||||
|
||||
# Apply external mask if provided
|
||||
if mask is not None:
|
||||
if mask.dim() == 2:
|
||||
mask = mask.unsqueeze(1).unsqueeze(1) # [batch_size, 1, seq_len, kv_seq_len]
|
||||
scores.masked_fill_(mask == 0, float('-inf'))
|
||||
|
||||
# Compute attention weights
|
||||
attention_weights = F.softmax(scores, dim=-1)
|
||||
attention_weights = self.dropout(attention_weights)
|
||||
|
||||
# Apply attention to values
|
||||
output = torch.matmul(attention_weights, V) # [batch_size, num_heads, seq_len, d_k]
|
||||
|
||||
# Concatenate heads
|
||||
output = output.transpose(1, 2).contiguous() # [batch_size, seq_len, num_heads, d_k]
|
||||
output = output.view(batch_size, seq_len, self.d_model) # [batch_size, seq_len, d_model]
|
||||
|
||||
# Final projection
|
||||
output = self.out_proj(output)
|
||||
|
||||
return output, attention_weights
|
||||
|
||||
def init_kv_cache(self, batch_size: int, max_length: int, device: torch.device):
|
||||
"""Initialize KV cache for inference."""
|
||||
self.kv_cache = KVCache(
|
||||
keys=torch.empty(batch_size, self.num_heads, 0, self.d_k, device=device),
|
||||
values=torch.empty(batch_size, self.num_heads, 0, self.d_k, device=device),
|
||||
)
|
||||
|
||||
def clear_cache(self):
|
||||
"""Clear the KV cache."""
|
||||
self.kv_cache = None
|
||||
|
||||
|
||||
class RetrievalCache:
|
||||
"""
|
||||
Approximate cache for retrieval results.
|
||||
Reduces expensive vector database lookups by caching similar queries.
|
||||
"""
|
||||
|
||||
def __init__(self, max_size: int = 1000, similarity_threshold: float = 0.9):
|
||||
"""
|
||||
Args:
|
||||
max_size: Maximum number of cached entries
|
||||
similarity_threshold: Minimum similarity to consider a cache hit
|
||||
"""
|
||||
self.max_size = max_size
|
||||
self.similarity_threshold = similarity_threshold
|
||||
self.cache: Dict[str, List[Dict]] = {} # query_hash -> retrieved_docs
|
||||
self.query_embeddings: Dict[str, torch.Tensor] = {} # query_hash -> embedding
|
||||
|
||||
def get(self, query_hash: str, query_embedding: torch.Tensor) -> Optional[List[Dict]]:
|
||||
"""
|
||||
Retrieve cached results if similar query exists.
|
||||
|
||||
Args:
|
||||
query_hash: Hash of the query
|
||||
query_embedding: Embedding of the query
|
||||
|
||||
Returns:
|
||||
Cached results if found, None otherwise
|
||||
"""
|
||||
# Check exact match first
|
||||
if query_hash in self.cache:
|
||||
return self.cache[query_hash]
|
||||
|
||||
# Check for similar queries
|
||||
best_match = None
|
||||
best_similarity = 0.0
|
||||
|
||||
for cached_hash, cached_embedding in self.query_embeddings.items():
|
||||
# Compute cosine similarity
|
||||
similarity = F.cosine_similarity(
|
||||
query_embedding.unsqueeze(0),
|
||||
cached_embedding.unsqueeze(0)
|
||||
).item()
|
||||
|
||||
if similarity > best_similarity:
|
||||
best_similarity = similarity
|
||||
best_match = cached_hash
|
||||
|
||||
if best_similarity >= self.similarity_threshold and best_match:
|
||||
return self.cache[best_match]
|
||||
|
||||
return None
|
||||
|
||||
def set(self, query_hash: str, query_embedding: torch.Tensor, results: List[Dict]):
|
||||
"""
|
||||
Store query and results in cache.
|
||||
|
||||
Args:
|
||||
query_hash: Hash of the query
|
||||
query_embedding: Embedding of the query
|
||||
results: Retrieved documents/results
|
||||
"""
|
||||
# Remove oldest entry if cache is full
|
||||
if len(self.cache) >= self.max_size:
|
||||
oldest_key = next(iter(self.cache))
|
||||
del self.cache[oldest_key]
|
||||
del self.query_embeddings[oldest_key]
|
||||
|
||||
self.cache[query_hash] = results
|
||||
self.query_embeddings[query_hash] = query_embedding
|
||||
|
||||
def clear(self):
|
||||
"""Clear the cache."""
|
||||
self.cache.clear()
|
||||
self.query_embeddings.clear()
|
||||
|
||||
|
||||
class OptimizedInference:
|
||||
"""
|
||||
Optimized inference utilities for production RAG systems.
|
||||
Includes prefetching, batching, and parallel processing.
|
||||
"""
|
||||
|
||||
def __init__(self, model: nn.Module, device: torch.device):
|
||||
"""
|
||||
Args:
|
||||
model: Model to use for inference
|
||||
device: Device to run inference on
|
||||
"""
|
||||
self.model = model
|
||||
self.device = device
|
||||
self.model.eval()
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_with_cache(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
max_length: int = 100,
|
||||
temperature: float = 1.0,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: float = 1.0,
|
||||
do_sample: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Generate with KV cache for efficient autoregressive generation.
|
||||
|
||||
Args:
|
||||
input_ids: Starting token indices [batch_size, seq_len]
|
||||
max_length: Maximum generation length
|
||||
temperature: Sampling temperature
|
||||
top_k: Top-k sampling parameter
|
||||
top_p: Nucleus sampling parameter
|
||||
do_sample: Whether to sample or use greedy decoding
|
||||
|
||||
Returns:
|
||||
Generated token sequences
|
||||
"""
|
||||
batch_size = input_ids.shape[0]
|
||||
device = input_ids.device
|
||||
|
||||
# Initialize KV cache in all attention layers
|
||||
for module in self.model.modules():
|
||||
if isinstance(module, OptimizedMultiHeadAttention):
|
||||
module.init_kv_cache(batch_size, max_length, device)
|
||||
|
||||
generated = input_ids.clone()
|
||||
|
||||
for _ in range(max_length - input_ids.shape[1]):
|
||||
# Forward pass
|
||||
logits, _ = self.model(generated)
|
||||
next_token_logits = logits[:, -1, :] / temperature
|
||||
|
||||
# Apply top-k filtering
|
||||
if top_k is not None and top_k > 0:
|
||||
indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
|
||||
next_token_logits[indices_to_remove] = float('-inf')
|
||||
|
||||
# Apply top-p (nucleus) filtering
|
||||
if top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
|
||||
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||
next_token_logits[indices_to_remove] = float('-inf')
|
||||
|
||||
# Sample or take argmax
|
||||
if do_sample:
|
||||
probs = torch.softmax(next_token_logits, dim=-1)
|
||||
next_token = torch.multinomial(probs, num_samples=1)
|
||||
else:
|
||||
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
|
||||
|
||||
# Early stopping: stop if EOS or padding token is generated (for batch_size=1)
|
||||
if batch_size == 1:
|
||||
eos_token_id = 3 # Default EOS token ID
|
||||
if next_token.item() == eos_token_id:
|
||||
break
|
||||
|
||||
# Early stopping: stop if padding token is generated (prevent generating padding)
|
||||
pad_token_id = 0 # Default padding token ID
|
||||
if next_token.item() == pad_token_id:
|
||||
break
|
||||
|
||||
# Append to generated sequence
|
||||
generated = torch.cat([generated, next_token], dim=1)
|
||||
|
||||
# Clear KV cache
|
||||
for module in self.model.modules():
|
||||
if isinstance(module, OptimizedMultiHeadAttention):
|
||||
module.clear_cache()
|
||||
|
||||
return generated
|
||||
|
||||
@torch.no_grad()
|
||||
def batch_generate(
|
||||
self,
|
||||
input_ids_list: List[torch.Tensor],
|
||||
max_length: int = 100,
|
||||
temperature: float = 1.0,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: float = 1.0,
|
||||
batch_size: int = 8,
|
||||
) -> List[torch.Tensor]:
|
||||
"""
|
||||
Generate for multiple prompts in batches for efficiency.
|
||||
|
||||
Args:
|
||||
input_ids_list: List of starting token sequences
|
||||
max_length: Maximum generation length
|
||||
temperature: Sampling temperature
|
||||
top_k: Top-k sampling parameter
|
||||
top_p: Nucleus sampling parameter
|
||||
batch_size: Batch size for processing
|
||||
|
||||
Returns:
|
||||
List of generated sequences
|
||||
"""
|
||||
results = []
|
||||
|
||||
for i in range(0, len(input_ids_list), batch_size):
|
||||
batch = input_ids_list[i:i + batch_size]
|
||||
|
||||
# Pad to same length
|
||||
max_len = max(seq.shape[1] for seq in batch)
|
||||
padded_batch = []
|
||||
for seq in batch:
|
||||
padding = torch.zeros(seq.shape[0], max_len - seq.shape[1],
|
||||
dtype=seq.dtype, device=seq.device)
|
||||
padded_batch.append(torch.cat([seq, padding], dim=1))
|
||||
|
||||
batch_tensor = torch.cat(padded_batch, dim=0)
|
||||
|
||||
# Generate for batch
|
||||
generated = self.generate_with_cache(
|
||||
batch_tensor,
|
||||
max_length=max_length,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
)
|
||||
|
||||
results.extend([gen for gen in generated])
|
||||
|
||||
return results
|
||||
|
||||
267
models/prefetching.py
Normal file
267
models/prefetching.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""
|
||||
Prefetching mechanism for parallel data loading and processing
|
||||
Optimizes RAG systems by prefetching retrieval results
|
||||
"""
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from typing import List, Dict, Optional, Callable, Any
|
||||
from threading import Thread
|
||||
from queue import Queue
|
||||
import time
|
||||
|
||||
|
||||
class PrefetchDataLoader:
|
||||
"""
|
||||
DataLoader with prefetching for parallel data loading.
|
||||
Reduces GPU idle time by prefetching batches in background threads.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataloader: DataLoader,
|
||||
prefetch_factor: int = 2,
|
||||
device: torch.device = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dataloader: Base DataLoader to wrap
|
||||
prefetch_factor: Number of batches to prefetch
|
||||
device: Device to prefetch batches to
|
||||
"""
|
||||
self.dataloader = dataloader
|
||||
self.prefetch_factor = prefetch_factor
|
||||
self.device = device
|
||||
|
||||
self.queue = Queue(maxsize=prefetch_factor)
|
||||
self.thread = None
|
||||
self._stop_thread = False
|
||||
|
||||
def _prefetch_worker(self):
|
||||
"""Worker thread that prefetches batches."""
|
||||
for batch in self.dataloader:
|
||||
if self._stop_thread:
|
||||
break
|
||||
|
||||
# Move to device if specified
|
||||
if self.device is not None:
|
||||
batch = {k: v.to(self.device, non_blocking=True)
|
||||
for k, v in batch.items()}
|
||||
|
||||
self.queue.put(batch)
|
||||
|
||||
self.queue.put(None) # Signal end of data
|
||||
|
||||
def __iter__(self):
|
||||
"""Start prefetching thread and return iterator."""
|
||||
self._stop_thread = False
|
||||
self.thread = Thread(target=self._prefetch_worker, daemon=True)
|
||||
self.thread.start()
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
"""Get next prefetched batch."""
|
||||
batch = self.queue.get()
|
||||
if batch is None:
|
||||
raise StopIteration
|
||||
return batch
|
||||
|
||||
def __len__(self):
|
||||
"""Return length of underlying dataloader."""
|
||||
return len(self.dataloader)
|
||||
|
||||
def stop(self):
|
||||
"""Stop prefetching thread."""
|
||||
self._stop_thread = True
|
||||
if self.thread is not None:
|
||||
self.thread.join()
|
||||
|
||||
|
||||
class LookaheadRetriever:
|
||||
"""
|
||||
Lookahead retrieval mechanism for RAG systems.
|
||||
Prefetches retrieval results for anticipated queries.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
retrieval_fn: Callable[[str], List[Dict]],
|
||||
lookahead_window: int = 3,
|
||||
prefetch_queue_size: int = 10,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
retrieval_fn: Function that takes a query and returns retrieved documents
|
||||
lookahead_window: Number of queries to look ahead
|
||||
prefetch_queue_size: Maximum size of prefetch queue
|
||||
"""
|
||||
self.retrieval_fn = retrieval_fn
|
||||
self.lookahead_window = lookahead_window
|
||||
self.prefetch_queue_size = prefetch_queue_size
|
||||
|
||||
self.prefetch_queue: Queue = Queue(maxsize=prefetch_queue_size)
|
||||
self.prefetch_thread: Optional[Thread] = None
|
||||
self._stop_thread = False
|
||||
|
||||
def _prefetch_worker(self, query_queue: Queue):
|
||||
"""Worker thread that prefetches retrieval results."""
|
||||
while not self._stop_thread:
|
||||
try:
|
||||
query = query_queue.get(timeout=1.0)
|
||||
if query is None:
|
||||
break
|
||||
|
||||
# Perform retrieval
|
||||
results = self.retrieval_fn(query)
|
||||
|
||||
# Add to prefetch queue
|
||||
try:
|
||||
self.prefetch_queue.put((query, results), timeout=0.1)
|
||||
except:
|
||||
pass # Queue full, skip
|
||||
|
||||
except:
|
||||
continue
|
||||
|
||||
def start_prefetching(self, query_stream: List[str]):
|
||||
"""Start prefetching retrieval results for query stream."""
|
||||
query_queue = Queue()
|
||||
|
||||
# Add queries to queue
|
||||
for query in query_stream:
|
||||
query_queue.put(query)
|
||||
query_queue.put(None) # Signal end
|
||||
|
||||
self._stop_thread = False
|
||||
self.prefetch_thread = Thread(target=self._prefetch_worker, args=(query_queue,), daemon=True)
|
||||
self.prefetch_thread.start()
|
||||
|
||||
def get(self, query: str, timeout: float = 1.0) -> Optional[List[Dict]]:
|
||||
"""
|
||||
Get retrieval results, checking prefetch queue first.
|
||||
|
||||
Args:
|
||||
query: Query string
|
||||
timeout: Timeout for checking prefetch queue
|
||||
|
||||
Returns:
|
||||
Retrieved documents or None if not found
|
||||
"""
|
||||
# Check prefetch queue
|
||||
while not self.prefetch_queue.empty():
|
||||
try:
|
||||
cached_query, results = self.prefetch_queue.get(timeout=timeout)
|
||||
if cached_query == query:
|
||||
return results
|
||||
# Put back if not matching
|
||||
self.prefetch_queue.put((cached_query, results))
|
||||
except:
|
||||
break
|
||||
|
||||
# Fallback to direct retrieval
|
||||
return self.retrieval_fn(query)
|
||||
|
||||
def stop(self):
|
||||
"""Stop prefetching thread."""
|
||||
self._stop_thread = True
|
||||
if self.prefetch_thread is not None:
|
||||
self.prefetch_thread.join()
|
||||
|
||||
|
||||
class BatchPrefetcher:
|
||||
"""
|
||||
Batched prefetching for multiple queries.
|
||||
Groups queries into batches for efficient retrieval.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_retrieval_fn: Callable[[List[str]], List[List[Dict]]],
|
||||
batch_size: int = 8,
|
||||
prefetch_factor: int = 2,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
batch_retrieval_fn: Function that takes list of queries and returns list of results
|
||||
batch_size: Size of batches for retrieval
|
||||
prefetch_factor: Number of batches to prefetch
|
||||
"""
|
||||
self.batch_retrieval_fn = batch_retrieval_fn
|
||||
self.batch_size = batch_size
|
||||
self.prefetch_factor = prefetch_factor
|
||||
|
||||
self.prefetch_queue: Queue = Queue(maxsize=prefetch_factor)
|
||||
self.prefetch_thread: Optional[Thread] = None
|
||||
self._stop_thread = False
|
||||
|
||||
def _prefetch_worker(self, query_queue: Queue):
|
||||
"""Worker thread that prefetches batches of retrieval results."""
|
||||
batch = []
|
||||
|
||||
while not self._stop_thread:
|
||||
try:
|
||||
query = query_queue.get(timeout=1.0)
|
||||
if query is None:
|
||||
# Process remaining batch
|
||||
if batch:
|
||||
results = self.batch_retrieval_fn(batch)
|
||||
for q, r in zip(batch, results):
|
||||
self.prefetch_queue.put((q, r))
|
||||
break
|
||||
|
||||
batch.append(query)
|
||||
|
||||
# Process batch when full
|
||||
if len(batch) >= self.batch_size:
|
||||
results = self.batch_retrieval_fn(batch)
|
||||
for q, r in zip(batch, results):
|
||||
try:
|
||||
self.prefetch_queue.put((q, r), timeout=0.1)
|
||||
except:
|
||||
pass # Queue full
|
||||
batch = []
|
||||
|
||||
except:
|
||||
continue
|
||||
|
||||
def start_prefetching(self, query_stream: List[str]):
|
||||
"""Start prefetching retrieval results for query stream."""
|
||||
query_queue = Queue()
|
||||
|
||||
for query in query_stream:
|
||||
query_queue.put(query)
|
||||
query_queue.put(None) # Signal end
|
||||
|
||||
self._stop_thread = False
|
||||
self.prefetch_thread = Thread(target=self._prefetch_worker, args=(query_queue,), daemon=True)
|
||||
self.prefetch_thread.start()
|
||||
|
||||
def get(self, query: str, timeout: float = 1.0) -> Optional[List[Dict]]:
|
||||
"""
|
||||
Get retrieval results from prefetch queue.
|
||||
|
||||
Args:
|
||||
query: Query string
|
||||
timeout: Timeout for checking prefetch queue
|
||||
|
||||
Returns:
|
||||
Retrieved documents or None if not found
|
||||
"""
|
||||
# Check prefetch queue
|
||||
while not self.prefetch_queue.empty():
|
||||
try:
|
||||
cached_query, results = self.prefetch_queue.get(timeout=timeout)
|
||||
if cached_query == query:
|
||||
return results
|
||||
# Put back if not matching
|
||||
self.prefetch_queue.put((cached_query, results))
|
||||
except:
|
||||
break
|
||||
|
||||
return None
|
||||
|
||||
def stop(self):
|
||||
"""Stop prefetching thread."""
|
||||
self._stop_thread = True
|
||||
if self.prefetch_thread is not None:
|
||||
self.prefetch_thread.join()
|
||||
|
||||
268
models/transformer.py
Normal file
268
models/transformer.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""
|
||||
Complete Transformer model for language modeling
|
||||
Incorporates best practices from multiple research papers
|
||||
Optimized for production RAG systems with KV caching and efficient inference
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional, Tuple
|
||||
from .blocks import TransformerBlock
|
||||
from .attention import PositionalEncoding
|
||||
from .optimized_attention import OptimizedMultiHeadAttention, RetrievalCache, OptimizedInference
|
||||
|
||||
|
||||
class TransformerModel(nn.Module):
|
||||
"""
|
||||
Full Transformer Language Model.
|
||||
|
||||
Features:
|
||||
- Multi-head self-attention
|
||||
- Positional encoding
|
||||
- Layer normalization
|
||||
- Residual connections
|
||||
- Causal masking for autoregressive generation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
d_model: int = 512,
|
||||
num_layers: int = 6,
|
||||
num_heads: int = 8,
|
||||
d_ff: int = 2048,
|
||||
max_seq_len: int = 512,
|
||||
dropout: float = 0.1,
|
||||
activation: str = 'gelu',
|
||||
layer_norm_eps: float = 1e-5,
|
||||
bias: bool = False,
|
||||
tie_weights: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
vocab_size: Vocabulary size
|
||||
d_model: Model dimension
|
||||
num_layers: Number of transformer layers
|
||||
num_heads: Number of attention heads
|
||||
d_ff: Feed-forward dimension
|
||||
max_seq_len: Maximum sequence length
|
||||
dropout: Dropout probability
|
||||
activation: Activation function ('gelu' or 'relu')
|
||||
layer_norm_eps: Epsilon for layer normalization
|
||||
bias: Whether to use bias in linear layers
|
||||
tie_weights: Whether to tie input and output embeddings
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.d_model = d_model
|
||||
self.num_layers = num_layers
|
||||
self.num_heads = num_heads
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
# Token embeddings
|
||||
self.token_embedding = nn.Embedding(vocab_size, d_model)
|
||||
|
||||
# Positional encoding
|
||||
self.pos_encoding = PositionalEncoding(
|
||||
d_model=d_model,
|
||||
max_len=max_seq_len,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
# Transformer blocks (use optimized attention if available)
|
||||
# Note: Set use_optimized_attention=True for production inference
|
||||
self.layers = nn.ModuleList([
|
||||
TransformerBlock(
|
||||
d_model=d_model,
|
||||
num_heads=num_heads,
|
||||
d_ff=d_ff,
|
||||
dropout=dropout,
|
||||
activation=activation,
|
||||
layer_norm_eps=layer_norm_eps,
|
||||
bias=bias,
|
||||
causal=True, # Causal masking for autoregressive generation
|
||||
use_optimized_attention=False, # Set to True for inference optimizations
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
# Final layer norm
|
||||
self.final_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
||||
|
||||
# Output projection
|
||||
self.output_proj = nn.Linear(d_model, vocab_size, bias=bias)
|
||||
|
||||
# Optionally tie weights
|
||||
if tie_weights:
|
||||
self.output_proj.weight = self.token_embedding.weight
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
# Retrieval cache for RAG systems
|
||||
self.retrieval_cache = RetrievalCache(max_size=1000, similarity_threshold=0.9)
|
||||
|
||||
# Initialize weights
|
||||
self._init_weights()
|
||||
|
||||
def _init_weights(self):
|
||||
"""Initialize weights following best practices."""
|
||||
# Initialize embeddings
|
||||
nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02)
|
||||
|
||||
# Initialize output projection
|
||||
if self.output_proj.weight is not self.token_embedding.weight:
|
||||
nn.init.normal_(self.output_proj.weight, mean=0.0, std=0.02)
|
||||
|
||||
# Initialize linear layers
|
||||
for module in self.modules():
|
||||
if isinstance(module, nn.Linear):
|
||||
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
nn.init.ones_(module.weight)
|
||||
nn.init.zeros_(module.bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""
|
||||
Forward pass through the transformer model.
|
||||
|
||||
Args:
|
||||
input_ids: Token indices [batch_size, seq_len]
|
||||
attention_mask: Optional attention mask [batch_size, seq_len]
|
||||
|
||||
Returns:
|
||||
logits: Output logits [batch_size, seq_len, vocab_size]
|
||||
attention_weights: Optional attention weights
|
||||
"""
|
||||
batch_size, seq_len = input_ids.shape
|
||||
|
||||
# Token embeddings
|
||||
x = self.token_embedding(input_ids) # [batch_size, seq_len, d_model]
|
||||
|
||||
# Add positional encoding
|
||||
x = self.pos_encoding(x)
|
||||
x = self.dropout(x)
|
||||
|
||||
# Create attention mask if not provided
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
batch_size, seq_len, device=input_ids.device, dtype=torch.bool
|
||||
)
|
||||
|
||||
# Expand mask for attention
|
||||
# [batch_size, seq_len] -> [batch_size, seq_len, seq_len]
|
||||
if attention_mask.dim() == 2:
|
||||
attention_mask = attention_mask.unsqueeze(1).expand(batch_size, seq_len, seq_len)
|
||||
|
||||
# Apply attention mask (invert: 1 for valid, 0 for masked)
|
||||
attention_mask = attention_mask.float()
|
||||
|
||||
# Pass through transformer blocks
|
||||
for layer in self.layers:
|
||||
x = layer(x, mask=attention_mask)
|
||||
|
||||
# Final layer norm
|
||||
x = self.final_norm(x)
|
||||
|
||||
# Output projection
|
||||
logits = self.output_proj(x) # [batch_size, seq_len, vocab_size]
|
||||
|
||||
return logits, None
|
||||
|
||||
def generate(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
max_length: int = 100,
|
||||
temperature: float = 1.0,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: float = 1.0,
|
||||
do_sample: bool = True,
|
||||
pad_token_id: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Autoregressive generation.
|
||||
|
||||
Args:
|
||||
input_ids: Starting token indices [batch_size, seq_len]
|
||||
max_length: Maximum generation length
|
||||
temperature: Sampling temperature
|
||||
top_k: Top-k sampling parameter
|
||||
top_p: Nucleus sampling parameter
|
||||
do_sample: Whether to sample or use greedy decoding
|
||||
pad_token_id: Padding token ID
|
||||
|
||||
Returns:
|
||||
Generated token sequences
|
||||
"""
|
||||
self.eval()
|
||||
device = input_ids.device
|
||||
batch_size = input_ids.shape[0]
|
||||
|
||||
generated = input_ids.clone()
|
||||
|
||||
with torch.no_grad():
|
||||
for _ in range(max_length - input_ids.shape[1]):
|
||||
# Forward pass
|
||||
logits, _ = self.forward(generated)
|
||||
next_token_logits = logits[:, -1, :] / temperature
|
||||
|
||||
# Apply top-k filtering
|
||||
if top_k is not None and top_k > 0:
|
||||
indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
|
||||
next_token_logits[indices_to_remove] = float('-inf')
|
||||
|
||||
# Apply top-p (nucleus) filtering
|
||||
if top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
|
||||
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
|
||||
# Remove tokens with cumulative probability above threshold
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||
next_token_logits[indices_to_remove] = float('-inf')
|
||||
|
||||
# Sample or take argmax
|
||||
if do_sample:
|
||||
probs = torch.softmax(next_token_logits, dim=-1)
|
||||
next_token = torch.multinomial(probs, num_samples=1)
|
||||
else:
|
||||
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
|
||||
|
||||
# Early stopping: stop if EOS or padding token is generated (for batch_size=1)
|
||||
if batch_size == 1:
|
||||
eos_token_id = getattr(self, 'eos_token_id', None) or 3 # Default EOS token
|
||||
if next_token.item() == eos_token_id:
|
||||
break
|
||||
|
||||
# Early stopping: stop if padding token is generated (prevent generating padding)
|
||||
if pad_token_id is not None and next_token.item() == pad_token_id:
|
||||
break
|
||||
|
||||
# Append to generated sequence
|
||||
generated = torch.cat([generated, next_token], dim=1)
|
||||
|
||||
return generated
|
||||
|
||||
def get_optimized_inference(self) -> OptimizedInference:
|
||||
"""
|
||||
Get optimized inference utility with KV caching and batching.
|
||||
|
||||
Returns:
|
||||
OptimizedInference instance
|
||||
"""
|
||||
return OptimizedInference(self, next(self.parameters()).device)
|
||||
|
||||
def get_num_params(self) -> int:
|
||||
"""Return the number of trainable parameters."""
|
||||
return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user