Initial commit: SheepOp LLM - Transformer-based language model implementation
- Complete transformer implementation from scratch
- Training pipeline with gradient accumulation and mixed precision
- Optimized inference with KV caching
- Multi-format data processing (PDFs, images, code, text)
- Comprehensive documentation
- Apache 2.0 license
- Example training plots included in docs/images/
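The "KV caching" mentioned above is the standard trick of storing the attention keys and values of already-processed tokens so that each newly generated token only attends against the cache instead of re-encoding the whole prefix. The sketch below is illustrative only, not code from this repository (the repository's own implementation lives in models/optimized_attention.py and is used via generate_with_cache in inference.py below); the class and variable names here are invented for the example, and causal masking for the multi-token prompt pass is omitted for brevity.

import torch
import torch.nn.functional as F

class ToyCachedSelfAttention(torch.nn.Module):
    """Single-head self-attention with a key/value cache (toy example)."""

    def __init__(self, d_model: int = 64):
        super().__init__()
        self.q_proj = torch.nn.Linear(d_model, d_model)
        self.k_proj = torch.nn.Linear(d_model, d_model)
        self.v_proj = torch.nn.Linear(d_model, d_model)
        self.k_cache = None  # (batch, tokens_so_far, d_model)
        self.v_cache = None

    def forward(self, x_new: torch.Tensor) -> torch.Tensor:
        # x_new: (batch, new_tokens, d_model); during decoding this is a single token
        k_new, v_new = self.k_proj(x_new), self.v_proj(x_new)
        if self.k_cache is None:
            self.k_cache, self.v_cache = k_new, v_new
        else:
            # Append the new keys/values instead of recomputing the whole prefix
            self.k_cache = torch.cat([self.k_cache, k_new], dim=1)
            self.v_cache = torch.cat([self.v_cache, v_new], dim=1)
        q = self.q_proj(x_new)
        scores = q @ self.k_cache.transpose(1, 2) / (q.shape[-1] ** 0.5)
        return F.softmax(scores, dim=-1) @ self.v_cache

attn = ToyCachedSelfAttention()
prompt = torch.randn(1, 10, 64)   # prompt embeddings fill the cache once
_ = attn(prompt)
next_tok = torch.randn(1, 1, 64)  # the next token attends over all 11 cached positions
print(attn(next_tok).shape)       # torch.Size([1, 1, 64])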
inference.py | 444 lines added (new file)
@@ -0,0 +1,444 @@
"""
Inference script for generating text.

Optimized for production RAG systems with KV caching and efficient inference.
"""
import torch
import argparse
from pathlib import Path
import sys
import importlib.util
import time
from typing import Optional

# Ensure current directory is in path
project_root = Path(__file__).parent.absolute()
sys.path.insert(0, str(project_root))

# Explicitly import from the local data module to avoid conflicts with any other installed 'data' package
data_module_path = project_root / "data" / "__init__.py"
spec = importlib.util.spec_from_file_location("sheepop_data", data_module_path)
sheepop_data = importlib.util.module_from_spec(spec)
spec.loader.exec_module(sheepop_data)
SimpleTokenizer = sheepop_data.SimpleTokenizer

from models import TransformerModel
from models.optimized_attention import OptimizedInference
from inference_metrics import InferenceMetrics


def load_model(checkpoint_path: str, device: str = 'cuda', tokenizer=None):
    """Load model from checkpoint."""
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Get model config from checkpoint or use defaults
    model_config = checkpoint.get('model_config', {})

    # If no config is stored in the checkpoint, fall back to defaults
    if not model_config:
        print("Warning: No model_config found in checkpoint. Using defaults.")
        # Try to infer vocab_size from the tokenizer if provided
        if tokenizer is not None:
            vocab_size = tokenizer.vocab_size
        else:
            # Default vocab size - should match your tokenizer
            vocab_size = 128  # Default for SimpleTokenizer

        model_config = {
            'vocab_size': vocab_size,
            'd_model': 512,
            'num_layers': 6,
            'num_heads': 8,
            'd_ff': 2048,
            'max_seq_len': 512,
            'dropout': 0.1,
            'activation': 'gelu',
        }
        print(f"Using default config with vocab_size={vocab_size}")

    model = TransformerModel(**model_config)

    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    return model


def generate_text(
    model: TransformerModel,
    tokenizer: SimpleTokenizer,
    prompt: str,
    max_length: int = 100,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.95,
    device: str = 'cuda',
    optimized: bool = False,
):
    """
    Generate text from a prompt.

    Returns:
        tuple: (generated_text, generated_ids, input_ids, generation_time)
    """
    # Encode prompt
    input_ids = tokenizer.encode(prompt)
    input_ids = torch.tensor([input_ids], device=device)

    # Measure generation time
    start_time = time.time()

    if optimized:
        optimizer = model.get_optimized_inference()
        generated = optimizer.generate_with_cache(
            input_ids=input_ids,
            max_length=max_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
    else:
        generated = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=True,
        )

    generation_time = time.time() - start_time

    # Decode
    generated_ids = generated[0].cpu().tolist()
    generated_text = tokenizer.decode(generated_ids)

    return generated_text, generated_ids, input_ids, generation_time


def get_memory_usage(device: torch.device) -> Optional[float]:
    """Get current memory usage in MB, or None if the device does not expose it."""
    if device.type == 'cuda':
        return torch.cuda.memory_allocated(device) / (1024 ** 2)  # MB
    elif device.type == 'mps':
        # MPS doesn't expose a direct memory query here, so return None
        return None
    else:
        return None


def benchmark_inference(
    model: TransformerModel,
    tokenizer: SimpleTokenizer,
    prompt: str,
    max_length: int,
    temperature: float,
    top_k: int,
    top_p: float,
    device: torch.device,
    metrics: InferenceMetrics,
    run_name: str,
):
    """Run benchmark for both optimized and non-optimized inference."""

    def remove_trailing_padding(token_ids, pad_token_id):
        """Remove trailing padding tokens."""
        while token_ids and token_ids[-1] == pad_token_id:
            token_ids.pop()
        return token_ids

    print("\n" + "=" * 70)
    print(f"BENCHMARK RUN: {run_name}")
    print("=" * 70)

    results = {}

    # Run non-optimized first
print("\n🔴 Running NON-OPTIMIZED inference...")
|
||||
if device.type == 'cuda':
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats(device)
|
||||
|
||||
memory_before = get_memory_usage(device)
|
||||
|
||||
generated_text, generated_ids, input_ids, gen_time = generate_text(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
max_length=max_length,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
device=str(device),
|
||||
optimized=False,
|
||||
)
|
||||
|
||||
memory_after = get_memory_usage(device)
|
||||
memory_used = memory_after - memory_before if memory_after and memory_before else None
|
||||
|
||||
generated_ids = remove_trailing_padding(generated_ids, tokenizer.pad_token_id)
|
||||
prompt_length = len(input_ids[0])
|
||||
generated_length = len(generated_ids) - prompt_length
|
||||
|
||||
if generated_length > 0:
|
||||
tokens_per_sec = generated_length / gen_time
|
||||
time_per_token = (gen_time / generated_length) * 1000 # ms
|
||||
else:
|
||||
tokens_per_sec = 0
|
||||
time_per_token = 0
|
||||
|
||||
results['non_optimized'] = {
|
||||
'text': generated_text,
|
||||
'prompt_length': prompt_length,
|
||||
'generated_length': generated_length,
|
||||
'total_time': gen_time,
|
||||
'tokens_per_sec': tokens_per_sec,
|
||||
'time_per_token': time_per_token,
|
||||
'memory_mb': memory_used,
|
||||
}
|
||||
|
||||
print(f" ⏱️ Total Time: {gen_time:.3f} s")
|
||||
print(f" 📊 Tokens/Second: {tokens_per_sec:.2f}")
|
||||
print(f" ⚡ Time/Token: {time_per_token:.3f} ms")
|
||||
if memory_used:
|
||||
print(f" 💾 Memory Used: {memory_used:.1f} MB")
|
||||
print(f" 📝 Generated: {generated_text[:100]}...")
|
||||
|
||||
# Log metrics
|
||||
metrics.log_run(
|
||||
run_name=f"{run_name}_non_optimized",
|
||||
optimized=False,
|
||||
prompt_length=prompt_length,
|
||||
generated_length=generated_length,
|
||||
total_time=gen_time,
|
||||
tokens_per_second=tokens_per_sec,
|
||||
time_per_token=time_per_token,
|
||||
memory_used_mb=memory_used,
|
||||
device=str(device),
|
||||
)
|
||||
|
||||
    # Run optimized
    print("\n🟢 Running OPTIMIZED inference...")
    if device.type == 'cuda':
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)

    memory_before = get_memory_usage(device)

    generated_text, generated_ids, input_ids, gen_time = generate_text(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        device=str(device),
        optimized=True,
    )

    memory_after = get_memory_usage(device)
    memory_used = memory_after - memory_before if memory_after and memory_before else None

    generated_ids = remove_trailing_padding(generated_ids, tokenizer.pad_token_id)
    prompt_length = len(input_ids[0])
    generated_length = len(generated_ids) - prompt_length

    if generated_length > 0:
        tokens_per_sec = generated_length / gen_time
        time_per_token = (gen_time / generated_length) * 1000  # ms
    else:
        tokens_per_sec = 0
        time_per_token = 0

    results['optimized'] = {
        'text': generated_text,
        'prompt_length': prompt_length,
        'generated_length': generated_length,
        'total_time': gen_time,
        'tokens_per_sec': tokens_per_sec,
        'time_per_token': time_per_token,
        'memory_mb': memory_used,
    }

    print(f"  ⏱️ Total Time: {gen_time:.3f} s")
    print(f"  📊 Tokens/Second: {tokens_per_sec:.2f}")
    print(f"  ⚡ Time/Token: {time_per_token:.3f} ms")
    if memory_used:
        print(f"  💾 Memory Used: {memory_used:.1f} MB")
    print(f"  📝 Generated: {generated_text[:100]}...")

    # Log metrics
    metrics.log_run(
        run_name=f"{run_name}_optimized",
        optimized=True,
        prompt_length=prompt_length,
        generated_length=generated_length,
        total_time=gen_time,
        tokens_per_second=tokens_per_sec,
        time_per_token=time_per_token,
        memory_used_mb=memory_used,
        device=str(device),
    )

    # Calculate speedup
    if results['non_optimized']['tokens_per_sec'] > 0:
        speedup = results['optimized']['tokens_per_sec'] / results['non_optimized']['tokens_per_sec']
        print(f"\n🚀 SPEEDUP: {speedup:.2f}x faster with optimizations")

    if results['non_optimized']['memory_mb'] and results['optimized']['memory_mb']:
        memory_reduction = (1 - results['optimized']['memory_mb'] / results['non_optimized']['memory_mb']) * 100
        print(f"💾 MEMORY REDUCTION: {memory_reduction:.1f}%")

    print("=" * 70)

    return results


def main():
    parser = argparse.ArgumentParser(description='Generate text with SheepOp LLM')
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint')
    parser.add_argument('--prompt', type=str, required=True, help='Prompt text')
    parser.add_argument('--max-length', type=int, default=100, help='Maximum generation length')
    parser.add_argument('--temperature', type=float, default=1.0, help='Sampling temperature')
    parser.add_argument('--top-k', type=int, default=50, help='Top-k sampling')
    parser.add_argument('--top-p', type=float, default=0.95, help='Top-p (nucleus) sampling')
    parser.add_argument('--device', type=str, default='cuda', help='Device to use')
    parser.add_argument('--optimized', action='store_true', help='Use optimized inference with KV caching')
    parser.add_argument('--benchmark', action='store_true', help='Run benchmark comparing optimized vs non-optimized inference (for research)')
    parser.add_argument('--benchmark-dir', type=str, default='./inference_benchmarks', help='Directory to save benchmark results')

    args = parser.parse_args()

    # Setup device
    if args.device == 'cuda' and torch.cuda.is_available():
        device = torch.device('cuda')
    elif args.device == 'mps' and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')
    print(f"Using device: {device}")

    # Create tokenizer first (needed for vocab_size)
    tokenizer = SimpleTokenizer()

    # Load model
    print("Loading model...")
    model = load_model(args.checkpoint, device=device, tokenizer=tokenizer)
    print("Model loaded!")

    # Check if benchmarking mode
    if args.benchmark:
        print("\n🔬 BENCHMARK MODE: Comparing optimized vs non-optimized inference")
        print("=" * 70)

        # Initialize metrics
        metrics = InferenceMetrics(save_dir=args.benchmark_dir)

        # Run benchmark
        run_name = f"run_{int(time.time())}"
        results = benchmark_inference(
            model=model,
            tokenizer=tokenizer,
            prompt=args.prompt,
            max_length=args.max_length,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            device=device,
            metrics=metrics,
            run_name=run_name,
        )

        # Generate plots and summary
        print("\n📊 Generating comparison plots and data...")
        metrics.plot_comparison()
        metrics.plot_performance_over_time()
        metrics.export_to_csv()
        metrics.print_summary()

        print(f"\n✅ Benchmark complete! Results saved to: {args.benchmark_dir}")
        print(f"  - JSON metrics: {args.benchmark_dir}/inference_metrics.json")
        print(f"  - CSV export: {args.benchmark_dir}/inference_metrics.csv")
        print(f"  - Comparison plot: {args.benchmark_dir}/optimization_comparison.png")
        print(f"  - Performance plot: {args.benchmark_dir}/performance_over_time.png")

        return

    # Normal inference mode
    use_optimized = args.optimized if hasattr(args, 'optimized') else False

    if use_optimized:
        print("Using optimized inference with KV caching...")
        optimizer = model.get_optimized_inference()
    else:
        optimizer = None

    # Encode prompt
    input_ids = tokenizer.encode(args.prompt)
    input_ids = torch.tensor([input_ids], device=device)

    # Generate text
    print(f"Prompt: {args.prompt}")
    print("Generating...")

    # Filter out padding tokens from the end of the generated sequence
    def remove_trailing_padding(token_ids, pad_token_id):
        """Remove trailing padding tokens."""
        while token_ids and token_ids[-1] == pad_token_id:
            token_ids.pop()
        return token_ids

    if optimizer is not None:
        # Use optimized generation with KV cache
        generated = optimizer.generate_with_cache(
            input_ids=input_ids,
            max_length=args.max_length,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )
        generated_ids = generated[0].cpu().tolist()
        # Remove trailing padding
        generated_ids = remove_trailing_padding(generated_ids, tokenizer.pad_token_id)
        print(f"Generated {len(generated_ids)} tokens after removing padding (input had {len(input_ids[0])} tokens)")
    else:
        # Use standard generation
        generated = model.generate(
            input_ids=input_ids,
            max_length=args.max_length,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            do_sample=True,
        )
        generated_ids = generated[0].cpu().tolist()
        # Remove trailing padding
        generated_ids = remove_trailing_padding(generated_ids, tokenizer.pad_token_id)
        print(f"Generated {len(generated_ids)} tokens after removing padding (input had {len(input_ids[0])} tokens)")

    # Debug: Show some token statistics
    vocab_size = tokenizer.vocab_size
    valid_tokens = sum(1 for tid in generated_ids if tid in tokenizer.inv_vocab)
    unk_tokens = sum(1 for tid in generated_ids if tid not in tokenizer.inv_vocab)
    pad_tokens = sum(1 for tid in generated_ids if tid == tokenizer.pad_token_id)

    print("Token statistics:")
    print(f"  Valid tokens: {valid_tokens}/{len(generated_ids)}")
    print(f"  Unknown tokens: {unk_tokens}")
    print(f"  Pad tokens: {pad_tokens}")
    print(f"  Vocab size: {vocab_size}")
    print(f"  Token ID range: {min(generated_ids) if generated_ids else 'N/A'} - {max(generated_ids) if generated_ids else 'N/A'}")

    # Show first 20 token IDs for debugging
    print(f"  First 20 token IDs: {generated_ids[:20]}")

    generated_text = tokenizer.decode(generated_ids)

    print(f"\nGenerated: {generated_text}")
    print(f"Generated length: {len(generated_text)} characters")


if __name__ == '__main__':
    main()
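Example usage of the script above (the checkpoint path and prompt are placeholders; the flags are exactly those defined in main()):

    python inference.py --checkpoint path/to/checkpoint.pt --prompt "Once upon a time" --max-length 100 --optimized
    python inference.py --checkpoint path/to/checkpoint.pt --prompt "Once upon a time" --benchmark --benchmark-dir ./inference_benchmarks

The first command runs normal generation with KV caching enabled; the second runs the optimized-vs-non-optimized comparison and writes the JSON, CSV, and plot files listed in the benchmark branch of main().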