Initial commit: SheepOp LLM - Transformer-based language model implementation
- Complete transformer implementation from scratch - Training pipeline with gradient accumulation and mixed precision - Optimized inference with KV caching - Multi-format data processing (PDFs, images, code, text) - Comprehensive documentation - Apache 2.0 license - Example training plots included in docs/images/
This commit is contained in:
405
training/__init__.py
Normal file
405
training/__init__.py
Normal file
@@ -0,0 +1,405 @@
|
||||
"""
|
||||
Training utilities and training loop
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.optim import AdamW
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR
|
||||
from typing import Dict, Optional, Callable
|
||||
from pathlib import Path
|
||||
import json
|
||||
import sys
|
||||
from tqdm import tqdm
|
||||
import math
|
||||
from .metrics import TrainingMetrics
|
||||
|
||||
|
||||
class Trainer:
    """
    Trainer class for language model training.

    Includes gradient accumulation, mixed precision training, and checkpointing.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader,
        val_loader=None,
        optimizer=None,
        scheduler=None,
        device: str = 'cuda',
        max_epochs: int = 10,
        gradient_accumulation_steps: int = 1,
        max_grad_norm: float = 1.0,
        use_amp: bool = True,
        save_dir: str = './checkpoints',
        log_interval: int = 100,
        eval_interval: int = 1000,
    ):
        """
        Args:
            model: Model to train
            train_loader: Training data loader
            val_loader: Validation data loader (optional)
            optimizer: Optimizer (if None, AdamW is used)
            scheduler: Learning rate scheduler (optional)
            device: Device to train on (string or torch.device)
            max_epochs: Maximum number of epochs
            gradient_accumulation_steps: Gradient accumulation steps
            max_grad_norm: Maximum gradient norm for clipping
            use_amp: Whether to use mixed precision training (CUDA only)
            save_dir: Directory to save checkpoints
            log_interval: Logging interval (in optimizer steps)
            eval_interval: Evaluation interval (in optimizer steps)
        """
        # Accept either a device string or a torch.device.
        # (The original performed this conversion twice; once is enough.)
        self.device = torch.device(device) if isinstance(device, str) else device

        self.model = model.to(self.device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.max_epochs = max_epochs
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.max_grad_norm = max_grad_norm
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.log_interval = log_interval
        self.eval_interval = eval_interval

        # Fall back to AdamW with conservative defaults when no optimizer given.
        if optimizer is None:
            self.optimizer = AdamW(
                self.model.parameters(),
                lr=1e-4,
                betas=(0.9, 0.999),
                weight_decay=0.01,
            )
        else:
            self.optimizer = optimizer

        self.scheduler = scheduler

        # Mixed precision is only enabled on CUDA devices.
        self.device_type = self.device.type
        self.use_amp = use_amp and self.device_type == 'cuda'  # Only use AMP for CUDA

        if self.use_amp:
            # Use new device-agnostic API
            self.scaler = torch.amp.GradScaler('cuda')
            self.autocast_dtype = torch.float16
        else:
            self.scaler = None
            self.autocast_dtype = None

        # -100 is the conventional "ignore this position" label value.
        self.criterion = nn.CrossEntropyLoss(ignore_index=-100)

        # Training state
        self.current_epoch = 0
        self.global_step = 0
        self.best_val_loss = float('inf')

        # Training metrics tracking
        self.metrics = TrainingMetrics(save_dir=save_dir)

    def _compute_batch_loss(self, input_ids, labels):
        """Forward pass + cross-entropy for one batch, honoring AMP autocast.

        Returns the mean loss (before any gradient-accumulation scaling).
        """
        if self.use_amp:
            with torch.amp.autocast('cuda', dtype=self.autocast_dtype):
                logits, _ = self.model(input_ids)
                # Flatten to (batch*seq, vocab) vs (batch*seq,) for CE loss.
                return self.criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
        logits, _ = self.model(input_ids)
        return self.criterion(logits.view(-1, logits.size(-1)), labels.view(-1))

    def _optimizer_step(self):
        """Clip gradients, step optimizer (and scheduler), and reset grads."""
        if self.use_amp:
            # Unscale before clipping so the norm is measured in real units.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.max_grad_norm
            )
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.max_grad_norm
            )
            self.optimizer.step()

        if self.scheduler is not None:
            self.scheduler.step()

        self.optimizer.zero_grad()
        self.global_step += 1

    def train_epoch(self) -> Dict[str, float]:
        """Train for one epoch and return {'loss': average training loss}."""
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        progress_bar = tqdm(
            self.train_loader,
            desc=f"Epoch {self.current_epoch + 1}",
            mininterval=0.1,
            maxinterval=1.0,
            file=sys.stderr,  # Write to stderr to avoid buffering issues
            dynamic_ncols=True,  # Auto-adjust to terminal width
            disable=False,  # Explicitly enable
        )

        for batch_idx, batch in enumerate(progress_bar):
            input_ids = batch['input_ids'].to(self.device)
            labels = batch['labels'].to(self.device)

            # Scale the loss so accumulated gradients average correctly.
            loss = self._compute_batch_loss(input_ids, labels) / self.gradient_accumulation_steps

            # Backward pass (scaled when AMP is active).
            if self.use_amp:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()

            # Only step the optimizer every `gradient_accumulation_steps` batches.
            if (batch_idx + 1) % self.gradient_accumulation_steps == 0:
                self._optimizer_step()

            total_loss += loss.item() * self.gradient_accumulation_steps
            num_batches += 1

            # Logging
            if self.global_step % self.log_interval == 0:
                avg_loss = total_loss / num_batches
                lr = self.optimizer.param_groups[0]['lr']
                progress_bar.set_postfix({
                    'loss': f'{avg_loss:.4f}',
                    'lr': f'{lr:.2e}',
                })
                progress_bar.refresh()  # Force immediate refresh
                sys.stderr.flush()  # Force flush stderr to ensure progress bar displays

                # Log metrics
                self.metrics.log(
                    epoch=self.current_epoch,
                    step=self.global_step,
                    train_loss=avg_loss,
                    lr=lr,
                )

            # Mid-epoch evaluation; keep the best checkpoint by validation loss.
            if self.val_loader is not None and self.global_step % self.eval_interval == 0:
                val_loss = self.evaluate()
                if val_loss < self.best_val_loss:
                    self.best_val_loss = val_loss
                    self.save_checkpoint(is_best=True)
                # BUGFIX: evaluate() puts the model in eval mode; restore
                # training mode so dropout/batchnorm behave for the rest
                # of the epoch.
                self.model.train()

        # Guard against an empty loader (would raise ZeroDivisionError).
        avg_loss = total_loss / max(num_batches, 1)

        # Log epoch metrics
        self.metrics.log(
            epoch=self.current_epoch,
            step=self.global_step,
            train_loss=avg_loss,
            lr=self.optimizer.param_groups[0]['lr'],
        )

        return {'loss': avg_loss}

    @torch.no_grad()
    def evaluate(self) -> float:
        """Evaluate on validation set; returns average validation loss.

        Returns float('inf') when there is no validation loader or it
        yields no batches, so the best-checkpoint comparison stays safe.
        """
        if self.val_loader is None:
            return float('inf')

        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        for batch in tqdm(
            self.val_loader,
            desc="Evaluating",
            mininterval=0.1,
            maxinterval=1.0,
            file=sys.stderr,  # Write to stderr to avoid buffering issues
            dynamic_ncols=True,  # Auto-adjust to terminal width
            disable=False,  # Explicitly enable
        ):
            input_ids = batch['input_ids'].to(self.device)
            labels = batch['labels'].to(self.device)
            total_loss += self._compute_batch_loss(input_ids, labels).item()
            num_batches += 1

        # Empty loader: report inf rather than dividing by zero (which
        # would also wrongly mark a 0.0 loss as the new best).
        if num_batches == 0:
            return float('inf')

        return total_loss / num_batches

    def _generate_plots(self):
        """Best-effort generation of training plots; never raises."""
        print("\n📊 Generating training plots...")
        try:
            self.metrics.plot_training_curve()
            self.metrics.plot_loss_by_epoch()
            self.metrics.print_summary()
        except Exception as e:
            print(f"Warning: Could not generate plots: {e}")

    def train(self):
        """Main training loop with per-epoch checkpointing.

        On KeyboardInterrupt, saves a resumable checkpoint and generates
        plots before re-raising.
        """
        try:
            for epoch in range(self.current_epoch, self.max_epochs):
                self.current_epoch = epoch

                # Train epoch
                train_metrics = self.train_epoch()

                # Evaluation at end of epoch
                if self.val_loader is not None:
                    val_loss = self.evaluate()
                    print(f"Epoch {epoch + 1}: Train Loss = {train_metrics['loss']:.4f}, "
                          f"Val Loss = {val_loss:.4f}")
                else:
                    print(f"Epoch {epoch + 1}: Train Loss = {train_metrics['loss']:.4f}")

                # Save checkpoint
                self.save_checkpoint()

            # Generate plots at end of training
            self._generate_plots()

        except KeyboardInterrupt:
            print("\n\n⚠️ Training interrupted by user!")
            print(f"💾 Saving checkpoint at epoch {self.current_epoch + 1}...")
            self.save_checkpoint()
            print(f"✅ Checkpoint saved! You can resume with:")
            print(f" python3 train.py --data <data> --resume {self.save_dir}/checkpoint_epoch_{self.current_epoch}.pt")

            # Generate plots before exiting
            self._generate_plots()

            raise

    def save_checkpoint(self, is_best: bool = False, model_config: Optional[dict] = None):
        """Save model checkpoint.

        Args:
            is_best: Also copy the checkpoint to 'best_checkpoint.pt'.
            model_config: Optional model configuration to embed.
        """
        checkpoint = {
            'epoch': self.current_epoch,
            'global_step': self.global_step,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'best_val_loss': self.best_val_loss,
        }

        # Save model config if provided
        if model_config is not None:
            checkpoint['model_config'] = model_config

        if self.scheduler is not None:
            checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()

        # Save regular checkpoint
        checkpoint_path = self.save_dir / f'checkpoint_epoch_{self.current_epoch}.pt'
        torch.save(checkpoint, checkpoint_path)

        # Save best checkpoint
        if is_best:
            best_path = self.save_dir / 'best_checkpoint.pt'
            torch.save(checkpoint, best_path)
            print(f"Saved best checkpoint with val_loss = {self.best_val_loss:.4f}")

    def load_checkpoint(self, checkpoint_path: str):
        """Load model checkpoint and restore training state."""
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if 'scheduler_state_dict' in checkpoint and self.scheduler is not None:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        # NOTE(review): resuming restarts the *saved* epoch rather than the
        # next one — confirm this matches the intended resume semantics.
        self.current_epoch = checkpoint.get('epoch', 0)
        self.global_step = checkpoint.get('global_step', 0)
        self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))

        print(f"Loaded checkpoint from epoch {self.current_epoch}")
|
||||
|
||||
|
||||
def compute_perplexity(model: nn.Module, data_loader, device: str = 'cuda') -> float:
    """
    Compute perplexity on a dataset.

    Args:
        model: Model to evaluate; its forward must return (logits, ...) given input_ids
        data_loader: Iterable of batches with 'input_ids' and 'labels' tensors
        device: Device to use

    Returns:
        Perplexity score: exp(mean token-level cross-entropy), where tokens
        labeled -100 are ignored. Returns float('inf') when the loader yields
        no countable tokens (previously this raised ZeroDivisionError).
    """
    # Graceful fallback so the computation still works when tqdm is missing.
    try:
        from tqdm import tqdm as _tqdm
    except ImportError:
        def _tqdm(iterable, **kwargs):
            return iterable

    model.eval()
    total_loss = 0.0
    num_tokens = 0

    # Sum (not mean) reduction so the average is taken over tokens, not batches.
    criterion = nn.CrossEntropyLoss(ignore_index=-100, reduction='sum')

    with torch.no_grad():
        for batch in _tqdm(
            data_loader,
            desc="Computing perplexity",
            mininterval=0.1,
            maxinterval=1.0,
            file=sys.stderr,  # Write to stderr to avoid buffering issues
            dynamic_ncols=True,  # Auto-adjust to terminal width
            disable=False,  # Explicitly enable
        ):
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)

            logits, _ = model(input_ids)
            # Flatten to (batch*seq, vocab) vs (batch*seq,) for CE loss.
            logits = logits.view(-1, logits.size(-1))
            labels = labels.view(-1)

            loss = criterion(logits, labels)
            total_loss += loss.item()
            # Positions labeled -100 contribute neither loss nor token count.
            num_tokens += (labels != -100).sum().item()

    # Guard: empty loader or fully-masked labels.
    if num_tokens == 0:
        return float('inf')

    avg_loss = total_loss / num_tokens
    return math.exp(avg_loss)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user