diff --git a/inference.py b/inference.py
index ac82cd2..a741679 100644
--- a/inference.py
+++ b/inference.py
@@ -126,6 +126,20 @@ def get_memory_usage(device: torch.device) -> float:
     return None
 
 
+def get_peak_memory_usage(device: torch.device) -> float:
+    """Get peak memory usage in MB since last reset."""
+    if device.type == 'cuda':
+        try:
+            return torch.cuda.max_memory_allocated(device) / (1024 ** 2)  # MB
+        except RuntimeError:
+            return None
+    elif device.type == 'mps':
+        # MPS doesn't have direct memory query, return None
+        return None
+    else:
+        return None
+
+
 def benchmark_inference(
     model: TransformerModel,
     tokenizer: SimpleTokenizer,
@@ -158,8 +172,6 @@ def benchmark_inference(
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats(device)
 
-    memory_before = get_memory_usage(device)
-
     generated_text, generated_ids, input_ids, gen_time = generate_text(
         model=model,
         tokenizer=tokenizer,
@@ -172,8 +184,8 @@
         optimized=False,
     )
 
-    memory_after = get_memory_usage(device)
-    memory_used = memory_after - memory_before if memory_after and memory_before else None
+    # Use peak memory for more accurate measurement
+    memory_used = get_peak_memory_usage(device)
 
     generated_ids = remove_trailing_padding(generated_ids, tokenizer.pad_token_id)
     prompt_length = len(input_ids[0])
@@ -222,8 +234,6 @@
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats(device)
 
-    memory_before = get_memory_usage(device)
-
     generated_text, generated_ids, input_ids, gen_time = generate_text(
         model=model,
         tokenizer=tokenizer,
@@ -236,8 +246,8 @@
         optimized=True,
     )
 
-    memory_after = get_memory_usage(device)
-    memory_used = memory_after - memory_before if memory_after and memory_before else None
+    # Use peak memory for more accurate measurement
+    memory_used = get_peak_memory_usage(device)
 
     generated_ids = remove_trailing_padding(generated_ids, tokenizer.pad_token_id)
     prompt_length = len(input_ids[0])
diff --git a/inference_metrics.py b/inference_metrics.py
index c753b4e..1e1a002 100644
--- a/inference_metrics.py
+++ b/inference_metrics.py
@@ -92,6 +92,13 @@ class InferenceMetrics:
 
         Returns:
             Dictionary with comparison statistics
         """
+        def safe_mean(values):
+            """Safely compute mean, returning None if no valid values."""
+            valid_values = [v for v in values if v is not None and not (isinstance(v, float) and np.isnan(v))]
+            if not valid_values:
+                return None
+            return np.mean(valid_values)
+
         runs = self.metrics['runs']
         optimized_runs = [r for r in runs if r['optimized']]
@@ -103,16 +110,16 @@ class InferenceMetrics:
                 'avg_tokens_per_sec': np.mean([r['tokens_per_second'] for r in optimized_runs]) if optimized_runs else 0,
                 'avg_time_per_token': np.mean([r['time_per_token'] for r in optimized_runs]) if optimized_runs else 0,
                 'avg_total_time': np.mean([r['total_time'] for r in optimized_runs]) if optimized_runs else 0,
-                'avg_memory_mb': np.mean([r['memory_used_mb'] for r in optimized_runs if r['memory_used_mb']]) if optimized_runs else None,
-                'avg_gpu_util': np.mean([r['gpu_utilization'] for r in optimized_runs if r['gpu_utilization']]) if optimized_runs else None,
+                'avg_memory_mb': safe_mean([r['memory_used_mb'] for r in optimized_runs]),
+                'avg_gpu_util': safe_mean([r['gpu_utilization'] for r in optimized_runs]),
             },
             'non_optimized': {
                 'count': len(non_optimized_runs),
                 'avg_tokens_per_sec': np.mean([r['tokens_per_second'] for r in non_optimized_runs]) if non_optimized_runs else 0,
                 'avg_time_per_token': np.mean([r['time_per_token'] for r in non_optimized_runs]) if non_optimized_runs else 0,
                 'avg_total_time': np.mean([r['total_time'] for r in non_optimized_runs]) if non_optimized_runs else 0,
-                'avg_memory_mb': np.mean([r['memory_used_mb'] for r in non_optimized_runs if r['memory_used_mb']]) if non_optimized_runs else None,
-                'avg_gpu_util': np.mean([r['gpu_utilization'] for r in non_optimized_runs if r['gpu_utilization']]) if non_optimized_runs else None,
+                'avg_memory_mb': safe_mean([r['memory_used_mb'] for r in non_optimized_runs]),
+                'avg_gpu_util': safe_mean([r['gpu_utilization'] for r in non_optimized_runs]),
             },
         }
 
@@ -124,7 +131,9 @@ class InferenceMetrics:
             comparison['speedup'] = None
 
         # Calculate memory reduction
-        if comparison['optimized']['avg_memory_mb'] and comparison['non_optimized']['avg_memory_mb']:
+        if (comparison['optimized']['avg_memory_mb'] is not None and
+            comparison['non_optimized']['avg_memory_mb'] is not None and
+            comparison['non_optimized']['avg_memory_mb'] > 0):
             memory_reduction = (1 - comparison['optimized']['avg_memory_mb'] / comparison['non_optimized']['avg_memory_mb']) * 100
             comparison['memory_reduction_percent'] = memory_reduction
         else:
@@ -356,7 +365,7 @@ class InferenceMetrics:
         print(f"  Average Tokens/Second: {comparison['optimized']['avg_tokens_per_sec']:.2f}")
         print(f"  Average Time/Token: {comparison['optimized']['avg_time_per_token']:.3f} ms")
         print(f"  Average Total Time: {comparison['optimized']['avg_total_time']:.3f} s")
-        if comparison['optimized']['avg_memory_mb']:
+        if comparison['optimized']['avg_memory_mb'] is not None:
             print(f"  Average Memory: {comparison['optimized']['avg_memory_mb']:.1f} MB")
 
         print(f"\nNon-Optimized Runs: {comparison['non_optimized']['count']}")
@@ -364,13 +373,13 @@
         print(f"  Average Tokens/Second: {comparison['non_optimized']['avg_tokens_per_sec']:.2f}")
         print(f"  Average Time/Token: {comparison['non_optimized']['avg_time_per_token']:.3f} ms")
         print(f"  Average Total Time: {comparison['non_optimized']['avg_total_time']:.3f} s")
-        if comparison['non_optimized']['avg_memory_mb']:
+        if comparison['non_optimized']['avg_memory_mb'] is not None:
             print(f"  Average Memory: {comparison['non_optimized']['avg_memory_mb']:.1f} MB")
 
         if comparison['speedup']:
             print(f"\nšŸš€ SPEEDUP: {comparison['speedup']:.2f}x faster with optimizations")
-        if comparison['memory_reduction_percent']:
+        if comparison['memory_reduction_percent'] is not None:
             print(f"šŸ’¾ MEMORY REDUCTION: {comparison['memory_reduction_percent']:.1f}%")
 
         print("=" * 70)
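
Note on the inference.py change: a before/after delta of allocated memory only sees what is still allocated at the second reading, so transient buffers freed during generation never show up. Resetting the peak counter and reading torch.cuda.max_memory_allocated afterwards records the true high-water mark of the run, which is what the patch's "peak memory" comment is after. A minimal standalone sketch of the pattern; measure_peak_mb and the matmul workload are invented for the demo, not code from this repo:

import torch

def measure_peak_mb(fn, device: torch.device):
    """Run fn() and return (result, peak allocated CUDA memory in MB), or None off-CUDA."""
    if device.type != 'cuda':
        return fn(), None
    torch.cuda.empty_cache()                    # release cached blocks from earlier work
    torch.cuda.reset_peak_memory_stats(device)  # zero the high-water mark
    result = fn()
    torch.cuda.synchronize(device)              # make sure the kernels have actually run
    return result, torch.cuda.max_memory_allocated(device) / (1024 ** 2)

if torch.cuda.is_available():
    dev = torch.device('cuda')

    def workload():
        a = torch.randn(4096, 4096, device=dev)  # ~64 MB each in fp32
        b = torch.randn(4096, 4096, device=dev)
        return a @ b

    _, peak = measure_peak_mb(workload, dev)
    # Peak counts a, b, and the product while they coexist; a before/after
    # delta would only see the surviving result tensor.
    print(f"peak: {peak:.1f} MB")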
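On the inference_metrics.py side, safe_mean fixes two quiet bugs in the old truthiness filter: a legitimate 0.0 reading was dropped as falsy, and an all-None column fed np.mean an empty list, producing nan plus a RuntimeWarning. A quick standalone check of the new behaviour, using the same helper body as the patch with invented sample values:

import numpy as np

def safe_mean(values):
    """Safely compute mean, returning None if no valid values."""
    valid_values = [v for v in values if v is not None and not (isinstance(v, float) and np.isnan(v))]
    if not valid_values:
        return None
    return np.mean(valid_values)

print(safe_mean([512.0, None, 640.0]))    # 576.0 -- None entries are skipped
print(safe_mean([0.0, 10.0]))             # 5.0 -- zero survives, unlike a truthiness filter
print(safe_mean([None, float('nan')]))    # None -- no np.mean([]) nan warning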
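Finally, the widened guard on the memory-reduction calculation: "is not None" keeps a genuine 0.0 average from being treated as missing, while the "> 0" check still blocks a division by zero. A hypothetical standalone function mirroring that guard:

def memory_reduction_percent(optimized_mb, baseline_mb):
    """Percent memory saved by the optimized path; None when it can't be computed."""
    # Mirrors the patch's guard: explicit None checks plus a positive denominator.
    if optimized_mb is None or baseline_mb is None or baseline_mb <= 0:
        return None
    return (1 - optimized_mb / baseline_mb) * 100

print(memory_reduction_percent(480.0, 640.0))  # 25.0
print(memory_reduction_percent(480.0, 0.0))    # None -- no ZeroDivisionError
print(memory_reduction_percent(None, 640.0))   # None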