fixing memory
inference.py (26 lines changed)
@@ -126,6 +126,20 @@ def get_memory_usage(device: torch.device) -> float:
         return None


+def get_peak_memory_usage(device: torch.device) -> float:
+    """Get peak memory usage in MB since last reset."""
+    if device.type == 'cuda':
+        try:
+            return torch.cuda.max_memory_allocated(device) / (1024 ** 2)  # MB
+        except RuntimeError:
+            return None
+    elif device.type == 'mps':
+        # MPS doesn't have a direct memory query, return None
+        return None
+    else:
+        return None
+
+
 def benchmark_inference(
     model: TransformerModel,
     tokenizer: SimpleTokenizer,
@@ -158,8 +172,6 @@ def benchmark_inference(
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats(device)

-    memory_before = get_memory_usage(device)
-
     generated_text, generated_ids, input_ids, gen_time = generate_text(
         model=model,
         tokenizer=tokenizer,
@@ -172,8 +184,8 @@ def benchmark_inference(
         optimized=False,
     )

-    memory_after = get_memory_usage(device)
-    memory_used = memory_after - memory_before if memory_after and memory_before else None
+    # Use peak memory for more accurate measurement
+    memory_used = get_peak_memory_usage(device)

     generated_ids = remove_trailing_padding(generated_ids, tokenizer.pad_token_id)
     prompt_length = len(input_ids[0])
@@ -222,8 +234,6 @@ def benchmark_inference(
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats(device)

-    memory_before = get_memory_usage(device)
-
     generated_text, generated_ids, input_ids, gen_time = generate_text(
         model=model,
         tokenizer=tokenizer,
@@ -236,8 +246,8 @@ def benchmark_inference(
         optimized=True,
    )

-    memory_after = get_memory_usage(device)
-    memory_used = memory_after - memory_before if memory_after and memory_before else None
+    # Use peak memory for more accurate measurement
+    memory_used = get_peak_memory_usage(device)

     generated_ids = remove_trailing_padding(generated_ids, tokenizer.pad_token_id)
     prompt_length = len(input_ids[0])
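
For context, the measurement pattern this commit switches to has two steps: torch.cuda.reset_peak_memory_stats zeroes the allocator's high-water mark before the run, and torch.cuda.max_memory_allocated reads it afterwards. Below is a minimal, self-contained sketch of that pattern; the measure_peak_mb helper and the matmul workload are illustrative stand-ins, not part of inference.py:

import torch

def measure_peak_mb(fn, device: torch.device):
    """Run fn() and return its peak CUDA memory use in MB (None off-CUDA)."""
    if device.type != 'cuda':
        # Mirrors the commit's behavior: no peak-memory counter for MPS/CPU here
        return None
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
    fn()
    torch.cuda.synchronize(device)  # make sure the workload actually finished
    return torch.cuda.max_memory_allocated(device) / (1024 ** 2)

if torch.cuda.is_available():
    device = torch.device('cuda')
    # Hypothetical workload: a large matmul allocates a transient result tensor
    peak = measure_peak_mb(
        lambda: torch.randn(2048, 2048, device=device)
                @ torch.randn(2048, 2048, device=device),
        device,
    )
    print(f"peak memory: {peak:.1f} MB")

The peak counter is also why the memory_before/memory_after delta could be dropped: a difference of two memory_allocated snapshots misses transient tensors (e.g. intermediate activations) that are freed before generation returns, while max_memory_allocated captures the true high-water mark between resets.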