fixing memory

Carlos Gutierrez
2025-11-16 16:22:12 -05:00
parent 5fe3dc0753
commit a1e703423c
2 changed files with 35 additions and 16 deletions

View File

@@ -126,6 +126,20 @@ def get_memory_usage(device: torch.device) -> float:
         return None
 
 
+def get_peak_memory_usage(device: torch.device) -> float:
+    """Get peak memory usage in MB since last reset."""
+    if device.type == 'cuda':
+        try:
+            return torch.cuda.max_memory_allocated(device) / (1024 ** 2)  # MB
+        except RuntimeError:
+            return None
+    elif device.type == 'mps':
+        # MPS doesn't have direct memory query, return None
+        return None
+    else:
+        return None
+
+
 def benchmark_inference(
     model: TransformerModel,
     tokenizer: SimpleTokenizer,
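
For reference, the new helper relies on the standard PyTorch pattern: torch.cuda.reset_peak_memory_stats() zeroes a per-device high-water mark, and torch.cuda.max_memory_allocated() reads it back. A minimal sketch of that pattern outside the repo (measure_peak_mb and fn are illustrative names, not code from this commit):

import torch

def measure_peak_mb(fn, device: torch.device):
    """Run fn() and report the peak CUDA memory it allocated, in MB.

    Illustrative helper only -- not part of this commit.
    """
    if device.type != 'cuda':
        return fn(), None  # the peak counters below are CUDA-only
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)  # zero the high-water mark
    result = fn()
    return result, torch.cuda.max_memory_allocated(device) / (1024 ** 2)

if torch.cuda.is_available():
    dev = torch.device('cuda')
    a = torch.randn(1024, 1024, device=dev)
    _, peak = measure_peak_mb(lambda: a @ a, dev)
    print(f"peak during matmul: {peak:.1f} MB")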
@@ -158,8 +172,6 @@ def benchmark_inference(
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats(device)
 
-    memory_before = get_memory_usage(device)
-
     generated_text, generated_ids, input_ids, gen_time = generate_text(
         model=model,
         tokenizer=tokenizer,
@@ -172,8 +184,8 @@ def benchmark_inference(
         optimized=False,
     )
 
-    memory_after = get_memory_usage(device)
-    memory_used = memory_after - memory_before if memory_after and memory_before else None
+    # Use peak memory for more accurate measurement
+    memory_used = get_peak_memory_usage(device)
 
     generated_ids = remove_trailing_padding(generated_ids, tokenizer.pad_token_id)
     prompt_length = len(input_ids[0])
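
Why peak instead of an after-minus-before delta: any buffer that is allocated and freed inside generate_text() is invisible to the delta but is still captured by the peak counter. A hypothetical standalone illustration (the ~1 GiB scratch tensor is invented for the demo):

import torch

device = torch.device('cuda')  # assumes a CUDA device is present
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats(device)

before = torch.cuda.memory_allocated(device)
scratch = torch.empty(256, 1024, 1024, device=device)  # ~1 GiB, hypothetical transient buffer
del scratch                                            # freed again before the second reading
after = torch.cuda.memory_allocated(device)

print((after - before) / (1024 ** 2))                         # ~0 MB: the delta misses it
print(torch.cuda.max_memory_allocated(device) / (1024 ** 2))  # ~1024 MB: the peak catches it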
@@ -222,8 +234,6 @@ def benchmark_inference(
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats(device)
 
-    memory_before = get_memory_usage(device)
-
     generated_text, generated_ids, input_ids, gen_time = generate_text(
         model=model,
         tokenizer=tokenizer,
@@ -236,8 +246,8 @@ def benchmark_inference(
         optimized=True,
     )
 
-    memory_after = get_memory_usage(device)
-    memory_used = memory_after - memory_before if memory_after and memory_before else None
+    # Use peak memory for more accurate measurement
+    memory_used = get_peak_memory_usage(device)
 
     generated_ids = remove_trailing_padding(generated_ids, tokenizer.pad_token_id)
     prompt_length = len(input_ids[0])

View File

@@ -92,6 +92,13 @@ class InferenceMetrics:
         Returns:
             Dictionary with comparison statistics
         """
+        def safe_mean(values):
+            """Safely compute mean, returning None if no valid values."""
+            valid_values = [v for v in values if v is not None and not (isinstance(v, float) and np.isnan(v))]
+            if not valid_values:
+                return None
+            return np.mean(valid_values)
+
         runs = self.metrics['runs']
         optimized_runs = [r for r in runs if r['optimized']]
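
The helper's behavior on the edge cases the old inline expressions mishandled, shown with made-up inputs (when every entry was None, the old truthiness filter produced an empty list and np.mean([]) returned nan with a RuntimeWarning; it also dropped legitimate 0.0 readings):

import numpy as np

def safe_mean(values):
    """Safely compute mean, returning None if no valid values."""
    valid_values = [v for v in values if v is not None and not (isinstance(v, float) and np.isnan(v))]
    if not valid_values:
        return None
    return np.mean(valid_values)

print(safe_mean([10.0, None, float('nan'), 30.0]))  # 20.0
print(safe_mean([None, None]))                      # None (old code: nan + RuntimeWarning)
print(safe_mean([0.0, 0.0]))                        # 0.0 -- a truthiness filter would drop these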
@@ -103,16 +110,16 @@
                 'avg_tokens_per_sec': np.mean([r['tokens_per_second'] for r in optimized_runs]) if optimized_runs else 0,
                 'avg_time_per_token': np.mean([r['time_per_token'] for r in optimized_runs]) if optimized_runs else 0,
                 'avg_total_time': np.mean([r['total_time'] for r in optimized_runs]) if optimized_runs else 0,
-                'avg_memory_mb': np.mean([r['memory_used_mb'] for r in optimized_runs if r['memory_used_mb']]) if optimized_runs else None,
-                'avg_gpu_util': np.mean([r['gpu_utilization'] for r in optimized_runs if r['gpu_utilization']]) if optimized_runs else None,
+                'avg_memory_mb': safe_mean([r['memory_used_mb'] for r in optimized_runs]),
+                'avg_gpu_util': safe_mean([r['gpu_utilization'] for r in optimized_runs]),
             },
             'non_optimized': {
                 'count': len(non_optimized_runs),
                 'avg_tokens_per_sec': np.mean([r['tokens_per_second'] for r in non_optimized_runs]) if non_optimized_runs else 0,
                 'avg_time_per_token': np.mean([r['time_per_token'] for r in non_optimized_runs]) if non_optimized_runs else 0,
                 'avg_total_time': np.mean([r['total_time'] for r in non_optimized_runs]) if non_optimized_runs else 0,
-                'avg_memory_mb': np.mean([r['memory_used_mb'] for r in non_optimized_runs if r['memory_used_mb']]) if non_optimized_runs else None,
-                'avg_gpu_util': np.mean([r['gpu_utilization'] for r in non_optimized_runs if r['gpu_utilization']]) if non_optimized_runs else None,
+                'avg_memory_mb': safe_mean([r['memory_used_mb'] for r in non_optimized_runs]),
+                'avg_gpu_util': safe_mean([r['gpu_utilization'] for r in non_optimized_runs]),
             },
         }
@@ -124,7 +131,9 @@ class InferenceMetrics:
             comparison['speedup'] = None
 
         # Calculate memory reduction
-        if comparison['optimized']['avg_memory_mb'] and comparison['non_optimized']['avg_memory_mb']:
+        if (comparison['optimized']['avg_memory_mb'] is not None and
+                comparison['non_optimized']['avg_memory_mb'] is not None and
+                comparison['non_optimized']['avg_memory_mb'] > 0):
             memory_reduction = (1 - comparison['optimized']['avg_memory_mb'] / comparison['non_optimized']['avg_memory_mb']) * 100
             comparison['memory_reduction_percent'] = memory_reduction
         else:
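
The expanded guard closes two holes in the old truthiness check: a legitimate average of 0.0 MB is falsy and was silently skipped, and a zero denominator would raise ZeroDivisionError. A small sketch with hypothetical values:

opt_mb, non_opt_mb = 0.0, 512.0

if opt_mb and non_opt_mb:                  # old check: 0.0 is falsy, so a real
    print("old check runs")                # 0.0 MB average was silently skipped

if (opt_mb is not None and
        non_opt_mb is not None and
        non_opt_mb > 0):                   # > 0 also rules out ZeroDivisionError
    print((1 - opt_mb / non_opt_mb) * 100) # 100.0 (% reduction)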
@@ -356,7 +365,7 @@ class InferenceMetrics:
         print(f"  Average Tokens/Second: {comparison['optimized']['avg_tokens_per_sec']:.2f}")
         print(f"  Average Time/Token: {comparison['optimized']['avg_time_per_token']:.3f} ms")
         print(f"  Average Total Time: {comparison['optimized']['avg_total_time']:.3f} s")
-        if comparison['optimized']['avg_memory_mb']:
+        if comparison['optimized']['avg_memory_mb'] is not None:
             print(f"  Average Memory: {comparison['optimized']['avg_memory_mb']:.1f} MB")
 
         print(f"\nNon-Optimized Runs: {comparison['non_optimized']['count']}")
@@ -364,13 +373,13 @@ class InferenceMetrics:
         print(f"  Average Tokens/Second: {comparison['non_optimized']['avg_tokens_per_sec']:.2f}")
         print(f"  Average Time/Token: {comparison['non_optimized']['avg_time_per_token']:.3f} ms")
         print(f"  Average Total Time: {comparison['non_optimized']['avg_total_time']:.3f} s")
-        if comparison['non_optimized']['avg_memory_mb']:
+        if comparison['non_optimized']['avg_memory_mb'] is not None:
             print(f"  Average Memory: {comparison['non_optimized']['avg_memory_mb']:.1f} MB")
 
         if comparison['speedup']:
             print(f"\n🚀 SPEEDUP: {comparison['speedup']:.2f}x faster with optimizations")
 
-        if comparison['memory_reduction_percent']:
+        if comparison['memory_reduction_percent'] is not None:
             print(f"💾 MEMORY REDUCTION: {comparison['memory_reduction_percent']:.1f}%")
 
         print("=" * 70)