fixing memory measurement (use peak allocation instead of before/after deltas)

Carlos Gutierrez
2025-11-16 16:22:12 -05:00
parent 5fe3dc0753
commit a1e703423c
2 changed files with 35 additions and 16 deletions


@@ -126,6 +126,20 @@ def get_memory_usage(device: torch.device) -> float:
         return None
 
 
+def get_peak_memory_usage(device: torch.device) -> float:
+    """Get peak memory usage in MB since the last reset."""
+    if device.type == 'cuda':
+        try:
+            return torch.cuda.max_memory_allocated(device) / (1024 ** 2)  # MB
+        except RuntimeError:
+            return None
+    elif device.type == 'mps':
+        # MPS doesn't have a direct peak-memory query, return None
+        return None
+    else:
+        return None
+
+
 def benchmark_inference(
     model: TransformerModel,
     tokenizer: SimpleTokenizer,
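
For reference, the pattern the new helper builds on is PyTorch's peak-memory tracking: torch.cuda.reset_peak_memory_stats() clears the high-water mark of the caching allocator, and torch.cuda.max_memory_allocated() reads it back in bytes. A minimal standalone sketch of that flow (the matmul workload is a made-up stand-in, not code from this repo):

import torch

device = torch.device('cuda')
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats(device)

# Workload to measure; a large matmul stands in for model inference here.
x = torch.randn(4096, 4096, device=device)
y = x @ x
del x, y  # even after freeing, the peak counter remembers the high-water mark

peak_mb = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
print(f"peak allocated: {peak_mb:.1f} MB")
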
@@ -158,8 +172,6 @@ def benchmark_inference(
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats(device)
 
-    memory_before = get_memory_usage(device)
-
     generated_text, generated_ids, input_ids, gen_time = generate_text(
         model=model,
         tokenizer=tokenizer,
@@ -172,8 +184,8 @@ def benchmark_inference(
         optimized=False,
     )
-    memory_after = get_memory_usage(device)
-    memory_used = memory_after - memory_before if memory_after and memory_before else None
+    # Use peak memory for more accurate measurement
+    memory_used = get_peak_memory_usage(device)
 
     generated_ids = remove_trailing_padding(generated_ids, tokenizer.pad_token_id)
 
     prompt_length = len(input_ids[0])
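
The deleted delta had two weaknesses the peak counter avoids: the "if memory_after and memory_before" guard treats a legitimate reading of exactly 0.0 MB as falsy and collapses the result to None, and an after-minus-before subtraction misses transient allocations that are freed before the second reading. A toy illustration of the second point (the ~1 GiB buffer size is arbitrary, chosen only to make the spike visible):

import torch

device = torch.device('cuda')
torch.cuda.reset_peak_memory_stats(device)

before = torch.cuda.memory_allocated(device)
tmp = torch.empty(1024, 1024, 256, device=device)  # ~1 GiB transient buffer
del tmp                                            # freed before the second reading
after = torch.cuda.memory_allocated(device)

delta_mb = (after - before) / (1024 ** 2)                        # ~0 MB, misses the spike
peak_mb = torch.cuda.max_memory_allocated(device) / (1024 ** 2)  # ~1024 MB, records it
print(f"delta: {delta_mb:.1f} MB, peak: {peak_mb:.1f} MB")
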
@@ -222,8 +234,6 @@ def benchmark_inference(
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats(device)
 
-    memory_before = get_memory_usage(device)
-
     generated_text, generated_ids, input_ids, gen_time = generate_text(
         model=model,
         tokenizer=tokenizer,
@@ -236,8 +246,8 @@ def benchmark_inference(
         optimized=True,
     )
-    memory_after = get_memory_usage(device)
-    memory_used = memory_after - memory_before if memory_after and memory_before else None
+    # Use peak memory for more accurate measurement
+    memory_used = get_peak_memory_usage(device)
 
     generated_ids = remove_trailing_padding(generated_ids, tokenizer.pad_token_id)
 
     prompt_length = len(input_ids[0])