Initial commit: SheepOp LLM - Transformer-based language model implementation
- Complete transformer implementation from scratch - Training pipeline with gradient accumulation and mixed precision - Optimized inference with KV caching - Multi-format data processing (PDFs, images, code, text) - Comprehensive documentation - Apache 2.0 license - Example training plots included in docs/images/
This commit is contained in:
109
utils.py
Normal file
109
utils.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
Utility functions for model evaluation and metrics
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import List, Dict
|
||||
import numpy as np
|
||||
import sys
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def compute_accuracy(model: nn.Module, data_loader, device: str = 'cuda') -> float:
    """
    Compute token-level accuracy.

    Args:
        model: Model to evaluate; its forward must return (logits, ...)
        data_loader: Iterable of batches with 'input_ids' and 'labels'
        device: Device to use

    Returns:
        Accuracy score
    """
    model.eval()
    n_correct, n_valid = 0, 0

    progress = tqdm(
        data_loader,
        desc="Computing accuracy",
        mininterval=0.1,
        maxinterval=1.0,
        file=sys.stderr,  # Write to stderr to avoid buffering issues
        dynamic_ncols=True,  # Auto-adjust to terminal width
        disable=False,  # Explicitly enable
    )

    with torch.no_grad():
        for batch in progress:
            tokens = batch['input_ids'].to(device)
            targets = batch['labels'].to(device)

            logits, _ = model(tokens)
            preds = logits.argmax(dim=-1)

            # Positions labelled -100 are padding/ignored and excluded from scoring.
            valid = targets.ne(-100)
            n_correct += preds.eq(targets).logical_and(valid).sum().item()
            n_valid += valid.sum().item()

    return n_correct / n_valid if n_valid > 0 else 0.0
|
||||
|
||||
|
||||
def compute_metrics(model: nn.Module, data_loader, device: str = 'cuda') -> Dict[str, float]:
    """
    Compute various evaluation metrics: mean per-token loss, token-level
    accuracy, and perplexity.

    Args:
        model: Model to evaluate; its forward must return (logits, ...)
        data_loader: Iterable of batches with 'input_ids' and 'labels'
        device: Device to use

    Returns:
        Dictionary of metrics with keys 'loss', 'accuracy', 'perplexity'
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total_tokens = 0

    # Sum (not mean) reduction so we can normalize by the number of
    # non-padding tokens ourselves, across the whole dataset.
    criterion = nn.CrossEntropyLoss(ignore_index=-100, reduction='sum')

    with torch.no_grad():
        for batch in tqdm(
            data_loader,
            desc="Computing metrics",
            mininterval=0.1,
            maxinterval=1.0,
            file=sys.stderr,  # Write to stderr to avoid buffering issues
            dynamic_ncols=True,  # Auto-adjust to terminal width
            disable=False,  # Explicitly enable
        ):
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)

            logits, _ = model(input_ids)
            logits = logits.view(-1, logits.size(-1))
            labels_flat = labels.view(-1)

            # Loss (summed over non-ignored tokens)
            loss = criterion(logits, labels_flat)
            total_loss += loss.item()

            # Accuracy over non-padding positions only
            predictions = torch.argmax(logits, dim=-1)
            mask = (labels_flat != -100)
            correct += ((predictions == labels_flat) * mask).sum().item()
            total_tokens += mask.sum().item()

    avg_loss = total_loss / total_tokens if total_tokens > 0 else 0.0
    accuracy = correct / total_tokens if total_tokens > 0 else 0.0
    # Fix: perplexity was previously forced to inf whenever avg_loss == 0,
    # but exp(0) == 1 is the correct perplexity for a zero loss. Only report
    # inf when there were no tokens to evaluate at all.
    perplexity = float(np.exp(avg_loss)) if total_tokens > 0 else float('inf')

    return {
        'loss': avg_loss,
        'accuracy': accuracy,
        'perplexity': perplexity,
    }
|
||||
|
||||
|
||||
Reference in New Issue
Block a user