"""
|
|
Transformer building blocks: Feed-forward networks and transformer blocks
|
|
"""
|
|
import torch
|
|
import torch.nn as nn
|
|
from typing import Optional
|
|
from .attention import MultiHeadAttention
|
|
from .optimized_attention import OptimizedMultiHeadAttention
|
|
|
|
|
|
class FeedForward(nn.Module):
    """
    Position-wise Feed-Forward Network.

    Implements two linear transformations with activation in between.
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        dropout: float = 0.1,
        activation: str = 'gelu',
        bias: bool = False,
    ):
        """
        Args:
            d_model: Model dimension
            d_ff: Feed-forward dimension (typically 4 * d_model)
            dropout: Dropout probability
            activation: Activation function ('gelu' or 'relu')
            bias: Whether to use bias in linear layers
        """
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff, bias=bias)
        self.linear2 = nn.Linear(d_ff, d_model, bias=bias)
        self.dropout = nn.Dropout(dropout)

        if activation == 'gelu':
            self.activation = nn.GELU()
        elif activation == 'relu':
            self.activation = nn.ReLU()
        else:
            raise ValueError(f"Unsupported activation: {activation}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor [batch_size, seq_len, d_model]

        Returns:
            Output tensor [batch_size, seq_len, d_model]
        """
        x = self.linear1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x

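
# Example usage (a minimal sketch; the dimensions below are illustrative and
# not taken from this repository's configs):
#
#     ff = FeedForward(d_model=512, d_ff=2048, dropout=0.1, activation='gelu')
#     x = torch.randn(2, 16, 512)   # [batch_size, seq_len, d_model]
#     y = ff(x)                     # -> [2, 16, 512], same shape as the input

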
class TransformerBlock(nn.Module):
    """
    Transformer block with self-attention and feed-forward network.

    Includes residual connections and layer normalization.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout: float = 0.1,
        activation: str = 'gelu',
        layer_norm_eps: float = 1e-5,
        bias: bool = False,
        causal: bool = False,
        use_optimized_attention: bool = False,
    ):
        """
        Args:
            d_model: Model dimension
            num_heads: Number of attention heads
            d_ff: Feed-forward dimension
            dropout: Dropout probability
            activation: Activation function
            layer_norm_eps: Epsilon for layer normalization
            bias: Whether to use bias in linear layers
            causal: Whether to use causal masking
            use_optimized_attention: Whether to use optimized attention with KV caching
        """
        super().__init__()

        # Self-attention with pre-norm architecture
        if use_optimized_attention:
            self.self_attn = OptimizedMultiHeadAttention(
                d_model=d_model,
                num_heads=num_heads,
                dropout=dropout,
                bias=bias,
                causal=causal,
                use_flash_attention=True,
            )
        else:
            self.self_attn = MultiHeadAttention(
                d_model=d_model,
                num_heads=num_heads,
                dropout=dropout,
                bias=bias,
                causal=causal,
            )
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)

        # Feed-forward network
        self.feed_forward = FeedForward(
            d_model=d_model,
            d_ff=d_ff,
            dropout=dropout,
            activation=activation,
            bias=bias,
        )
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
    ) -> torch.Tensor:
        """
        Forward pass through transformer block.

        Args:
            x: Input tensor [batch_size, seq_len, d_model]
            mask: Optional attention mask
            use_cache: Whether to use KV cache (for optimized attention)

        Returns:
            Output tensor [batch_size, seq_len, d_model]
        """
        # Pre-norm self-attention with residual connection
        residual = x
        x = self.norm1(x)
        # Pass use_cache to the attention layer only if it is OptimizedMultiHeadAttention
        if isinstance(self.self_attn, OptimizedMultiHeadAttention):
            attn_out, _ = self.self_attn(x, x, x, mask=mask, use_cache=use_cache)
        else:
            attn_out, _ = self.self_attn(x, x, x, mask=mask)
        x = residual + self.dropout(attn_out)

        # Pre-norm feed-forward with residual connection
        residual = x
        x = self.norm2(x)
        ff_out = self.feed_forward(x)
        x = residual + self.dropout(ff_out)

        return x
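
# Example usage (a minimal sketch; the hyperparameters below are illustrative
# and not taken from this repository's configs):
#
#     block = TransformerBlock(d_model=512, num_heads=8, d_ff=2048, causal=True)
#     x = torch.randn(2, 16, 512)   # [batch_size, seq_len, d_model]
#     y = block(x)                  # -> [2, 16, 512], same shape as the input
#
# Stacking several such blocks (e.g. in an nn.ModuleList) and applying them in
# sequence gives the body of a pre-norm transformer.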