Fix optimized attention mask handling for training

- Fix mask format conversion (float to boolean) for scaled_dot_product_attention
- Fix mask dimensions for proper broadcasting to [batch, 1, seq_len, seq_len]
- Resolve conflict between the is_causal flag and the custom mask parameter
- Enable training with optimized attention and KV caching
This commit is contained in:
Carlos Gutierrez
2025-11-16 16:44:55 -05:00
parent 3fef3e2689
commit 9f17e1db24

View File

@@ -130,11 +130,30 @@ class OptimizedMultiHeadAttention(nn.Module):
# Compute attention scores with optimized computation
if self.use_flash_attention and hasattr(F, 'scaled_dot_product_attention'):
# Use PyTorch's optimized scaled dot product attention
# Prepare mask for scaled_dot_product_attention
# Note: scaled_dot_product_attention expects boolean mask where True=attend, False=mask
# Our mask is float where 1.0=valid, 0.0=masked, so we need to convert
attn_mask = None
if mask is not None:
# Convert float mask to boolean: 1.0 -> True (attend), 0.0 -> False (mask)
if mask.dim() == 2:
# [seq_len, seq_len] -> [1, 1, seq_len, seq_len] for broadcasting
attn_mask = (mask > 0.5).unsqueeze(0).unsqueeze(0)
elif mask.dim() == 3:
# [batch, seq_len, seq_len] -> [batch, 1, seq_len, seq_len] for broadcasting
attn_mask = (mask > 0.5).unsqueeze(1)
else:
attn_mask = (mask > 0.5)
# Use causal masking only if no custom mask is provided
# If we have a custom mask, it should already include causal masking
use_causal = self.causal and (attn_mask is None)
output = F.scaled_dot_product_attention(
Q, K, V,
attn_mask=mask,
attn_mask=attn_mask,
dropout_p=self.dropout.p if self.training else 0.0,
is_causal=self.causal,
is_causal=use_causal,
)
attention_weights = None # Flash attention doesn't return weights
else: