"""
Training utilities and training loop
"""
import torch
import torch.nn as nn
from torch.optim import AdamW
from typing import Dict, Optional
from pathlib import Path
import sys
from tqdm import tqdm
import math
from .metrics import TrainingMetrics
class Trainer:
"""
Trainer class for language model training.
Includes gradient accumulation, mixed precision training, and checkpointing.
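
    Example (a minimal sketch; assumes the model returns (logits, cache) and
    the loaders yield dicts with 'input_ids' and 'labels' tensors):

        trainer = Trainer(model, train_loader, val_loader=val_loader,
                          device='cuda', max_epochs=3)
        trainer.train()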
"""
def __init__(
self,
model: nn.Module,
train_loader,
val_loader=None,
optimizer=None,
scheduler=None,
device: str = 'cuda',
max_epochs: int = 10,
gradient_accumulation_steps: int = 1,
max_grad_norm: float = 1.0,
use_amp: bool = True,
save_dir: str = './checkpoints',
log_interval: int = 100,
eval_interval: int = 1000,
):
"""
Args:
model: Model to train
train_loader: Training data loader
val_loader: Validation data loader (optional)
optimizer: Optimizer (if None, AdamW is used)
scheduler: Learning rate scheduler (optional)
device: Device to train on
max_epochs: Maximum number of epochs
gradient_accumulation_steps: Gradient accumulation steps
max_grad_norm: Maximum gradient norm for clipping
use_amp: Whether to use mixed precision training
save_dir: Directory to save checkpoints
log_interval: Logging interval
eval_interval: Evaluation interval
"""
# Convert device string to torch.device if needed
if isinstance(device, str):
self.device = torch.device(device)
else:
self.device = device
self.model = model.to(self.device)
self.train_loader = train_loader
self.val_loader = val_loader
self.max_epochs = max_epochs
self.gradient_accumulation_steps = gradient_accumulation_steps
self.max_grad_norm = max_grad_norm
self.save_dir = Path(save_dir)
self.save_dir.mkdir(parents=True, exist_ok=True)
self.log_interval = log_interval
self.eval_interval = eval_interval
# Setup optimizer
if optimizer is None:
self.optimizer = AdamW(
self.model.parameters(),
lr=1e-4,
betas=(0.9, 0.999),
weight_decay=0.01,
)
else:
self.optimizer = optimizer
# Setup scheduler
self.scheduler = scheduler
# Determine device type for AMP
device_type = self.device.type
# Setup mixed precision training (only for CUDA)
self.device_type = device_type
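        # GradScaler guards float16 against gradient underflow by scaling the
        # loss; CPU autocast uses bfloat16 and needs no scaler, so AMP is
        # restricted to CUDA here.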
self.use_amp = use_amp and device_type == 'cuda' # Only use AMP for CUDA
if self.use_amp:
# Use new device-agnostic API
self.scaler = torch.amp.GradScaler('cuda')
self.autocast_dtype = torch.float16
else:
self.scaler = None
self.autocast_dtype = None
# Loss function
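        # ignore_index=-100 follows the common convention of marking padding /
        # non-loss positions in `labels` with -100 so they are excluded.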
self.criterion = nn.CrossEntropyLoss(ignore_index=-100)
# Training state
self.current_epoch = 0
self.global_step = 0
self.best_val_loss = float('inf')
# Training metrics tracking
self.metrics = TrainingMetrics(save_dir=save_dir)
def train_epoch(self) -> Dict[str, float]:
"""Train for one epoch."""
self.model.train()
total_loss = 0.0
num_batches = 0
progress_bar = tqdm(
self.train_loader,
desc=f"Epoch {self.current_epoch + 1}",
mininterval=0.1,
maxinterval=1.0,
file=sys.stderr, # Write to stderr to avoid buffering issues
dynamic_ncols=True, # Auto-adjust to terminal width
disable=False, # Explicitly enable
)
for batch_idx, batch in enumerate(progress_bar):
input_ids = batch['input_ids'].to(self.device)
labels = batch['labels'].to(self.device)
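            # The loss is divided by gradient_accumulation_steps in both
            # branches below so that gradients summed over N micro-batches
            # match those of a single batch N times larger.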
# Forward pass with mixed precision (only for CUDA)
if self.use_amp:
with torch.amp.autocast('cuda', dtype=self.autocast_dtype):
logits, _ = self.model(input_ids)
# Reshape for loss computation
logits = logits.view(-1, logits.size(-1))
labels = labels.view(-1)
loss = self.criterion(logits, labels)
loss = loss / self.gradient_accumulation_steps
else:
logits, _ = self.model(input_ids)
# Reshape for loss computation
logits = logits.view(-1, logits.size(-1))
labels = labels.view(-1)
loss = self.criterion(logits, labels)
loss = loss / self.gradient_accumulation_steps
# Backward pass
if self.use_amp:
self.scaler.scale(loss).backward()
else:
loss.backward()
# Gradient accumulation
if (batch_idx + 1) % self.gradient_accumulation_steps == 0:
# Gradient clipping
if self.use_amp:
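                    # Unscale gradients before clipping so the norm threshold
                    # applies to true gradient values, not loss-scaled ones.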
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.max_grad_norm
)
self.scaler.step(self.optimizer)
self.scaler.update()
else:
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.max_grad_norm
)
self.optimizer.step()
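                # Step-based schedules (e.g. OneCycleLR) expect one step per
                # optimizer update, not one per epoch.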
if self.scheduler is not None:
self.scheduler.step()
self.optimizer.zero_grad()
self.global_step += 1
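            # Multiply back by gradient_accumulation_steps so the running
            # average logs the unscaled per-batch loss.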
total_loss += loss.item() * self.gradient_accumulation_steps
num_batches += 1
# Logging
if self.global_step % self.log_interval == 0:
avg_loss = total_loss / num_batches
lr = self.optimizer.param_groups[0]['lr']
progress_bar.set_postfix({
'loss': f'{avg_loss:.4f}',
'lr': f'{lr:.2e}',
})
progress_bar.refresh() # Force immediate refresh
sys.stderr.flush() # Force flush stderr to ensure progress bar displays
# Log metrics
self.metrics.log(
epoch=self.current_epoch,
step=self.global_step,
train_loss=avg_loss,
lr=lr,
)
# Evaluation
if self.val_loader is not None and self.global_step % self.eval_interval == 0:
val_loss = self.evaluate()
if val_loss < self.best_val_loss:
self.best_val_loss = val_loss
self.save_checkpoint(is_best=True)
avg_loss = total_loss / num_batches
# Log epoch metrics
self.metrics.log(
epoch=self.current_epoch,
step=self.global_step,
train_loss=avg_loss,
lr=self.optimizer.param_groups[0]['lr'],
)
return {'loss': avg_loss}
@torch.no_grad()
def evaluate(self) -> float:
"""Evaluate on validation set."""
if self.val_loader is None:
return float('inf')
self.model.eval()
total_loss = 0.0
num_batches = 0
for batch in tqdm(
self.val_loader,
desc="Evaluating",
mininterval=0.1,
maxinterval=1.0,
file=sys.stderr, # Write to stderr to avoid buffering issues
dynamic_ncols=True, # Auto-adjust to terminal width
disable=False, # Explicitly enable
):
input_ids = batch['input_ids'].to(self.device)
labels = batch['labels'].to(self.device)
if self.use_amp:
with torch.amp.autocast('cuda', dtype=self.autocast_dtype):
logits, _ = self.model(input_ids)
logits = logits.view(-1, logits.size(-1))
labels = labels.view(-1)
loss = self.criterion(logits, labels)
else:
logits, _ = self.model(input_ids)
logits = logits.view(-1, logits.size(-1))
labels = labels.view(-1)
loss = self.criterion(logits, labels)
total_loss += loss.item()
num_batches += 1
avg_loss = total_loss / num_batches
return avg_loss
def train(self):
"""Main training loop."""
try:
for epoch in range(self.current_epoch, self.max_epochs):
self.current_epoch = epoch
# Train epoch
train_metrics = self.train_epoch()
# Evaluation at end of epoch
if self.val_loader is not None:
val_loss = self.evaluate()
print(f"Epoch {epoch + 1}: Train Loss = {train_metrics['loss']:.4f}, "
f"Val Loss = {val_loss:.4f}")
else:
print(f"Epoch {epoch + 1}: Train Loss = {train_metrics['loss']:.4f}")
# Save checkpoint
self.save_checkpoint()
# Generate plots at end of training
print("\n📊 Generating training plots...")
try:
self.metrics.plot_training_curve()
self.metrics.plot_loss_by_epoch()
self.metrics.print_summary()
except Exception as e:
print(f"Warning: Could not generate plots: {e}")
except KeyboardInterrupt:
print("\n\n⚠️ Training interrupted by user!")
print(f"💾 Saving checkpoint at epoch {self.current_epoch + 1}...")
self.save_checkpoint()
print(f"✅ Checkpoint saved! You can resume with:")
print(f" python3 train.py --data <data> --resume {self.save_dir}/checkpoint_epoch_{self.current_epoch}.pt")
# Generate plots before exiting
print("\n📊 Generating training plots...")
try:
self.metrics.plot_training_curve()
self.metrics.plot_loss_by_epoch()
self.metrics.print_summary()
except Exception as e:
print(f"Warning: Could not generate plots: {e}")
raise
    def save_checkpoint(self, is_best: bool = False, model_config: Optional[dict] = None):
"""Save model checkpoint."""
checkpoint = {
'epoch': self.current_epoch,
'global_step': self.global_step,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'best_val_loss': self.best_val_loss,
}
# Save model config if provided
if model_config is not None:
checkpoint['model_config'] = model_config
if self.scheduler is not None:
checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
# Save regular checkpoint
checkpoint_path = self.save_dir / f'checkpoint_epoch_{self.current_epoch}.pt'
torch.save(checkpoint, checkpoint_path)
# Save best checkpoint
if is_best:
best_path = self.save_dir / 'best_checkpoint.pt'
torch.save(checkpoint, best_path)
print(f"Saved best checkpoint with val_loss = {self.best_val_loss:.4f}")
def load_checkpoint(self, checkpoint_path: str):
"""Load model checkpoint."""
        # PyTorch >= 2.6 defaults torch.load to weights_only=True; the
        # checkpoint holds optimizer state and Python scalars, so allow
        # full unpickling.
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if 'scheduler_state_dict' in checkpoint and self.scheduler is not None:
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
self.current_epoch = checkpoint.get('epoch', 0)
self.global_step = checkpoint.get('global_step', 0)
self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
print(f"Loaded checkpoint from epoch {self.current_epoch}")
def compute_perplexity(model: nn.Module, data_loader, device: str = 'cuda') -> float:
"""
Compute perplexity on a dataset.
Args:
model: Model to evaluate
data_loader: Data loader
device: Device to use
Returns:
Perplexity score
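
    Example (hypothetical model/loader names):
        ppl = compute_perplexity(model, val_loader, device='cuda')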
"""
model.eval()
total_loss = 0.0
num_tokens = 0
criterion = nn.CrossEntropyLoss(ignore_index=-100, reduction='sum')
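    # reduction='sum' plus an explicit count of non-ignored tokens yields an
    # exact per-token average even when batches differ in valid-token counts.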
with torch.no_grad():
for batch in tqdm(
data_loader,
desc="Computing perplexity",
mininterval=0.1,
maxinterval=1.0,
file=sys.stderr, # Write to stderr to avoid buffering issues
dynamic_ncols=True, # Auto-adjust to terminal width
disable=False, # Explicitly enable
):
input_ids = batch['input_ids'].to(device)
labels = batch['labels'].to(device)
logits, _ = model(input_ids)
logits = logits.view(-1, logits.size(-1))
labels = labels.view(-1)
loss = criterion(logits, labels)
total_loss += loss.item()
num_tokens += (labels != -100).sum().item()
avg_loss = total_loss / num_tokens
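    # Perplexity is the exponential of the mean per-token negative
    # log-likelihood: PPL = exp(-(1/N) * sum_i log p(x_i | x_<i)).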
perplexity = math.exp(avg_loss)
return perplexity