Fix optimized attention mask handling for training
- Fix mask format conversion (float to boolean) for scaled_dot_product_attention
- Fix mask dimensions for proper broadcasting [batch, 1, seq_len, seq_len]
- Resolve conflict between is_causal and custom mask parameters
- Enable training with optimized attention and KV caching
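For context, here is a minimal standalone sketch of the conversion these fixes describe (not part of the commit; the tensor names and sizes are illustrative). It builds a float mask in the module's 1.0 = valid / 0.0 = masked convention, converts it to the boolean, broadcastable form that `F.scaled_dot_product_attention` expects, and disables `is_causal` because the custom mask already encodes causality:

```python
import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 8, 4, 16

# Hypothetical float mask in the module's convention: 1.0 = valid, 0.0 = masked
mask = torch.tril(torch.ones(seq_len, seq_len)).expand(batch, -1, -1)

# scaled_dot_product_attention wants a boolean mask (True = attend) that
# broadcasts against scores of shape [batch, heads, seq_len, seq_len]
attn_mask = (mask > 0.5).unsqueeze(1)  # [batch, seq_len, seq_len] -> [batch, 1, seq_len, seq_len]

Q = K = V = torch.randn(batch, heads, seq_len, head_dim)

# Passing both a custom attn_mask and is_causal=True conflicts, so causal
# masking is turned off whenever an explicit mask is supplied
out = F.scaled_dot_product_attention(Q, K, V, attn_mask=attn_mask, is_causal=False)
print(out.shape)  # torch.Size([2, 8, 4, 16])
```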
@@ -130,11 +130,30 @@ class OptimizedMultiHeadAttention(nn.Module):
         # Compute attention scores with optimized computation
         if self.use_flash_attention and hasattr(F, 'scaled_dot_product_attention'):
             # Use PyTorch's optimized scaled dot product attention
+            # Prepare mask for scaled_dot_product_attention
+            # Note: scaled_dot_product_attention expects boolean mask where True=attend, False=mask
+            # Our mask is float where 1.0=valid, 0.0=masked, so we need to convert
+            attn_mask = None
+            if mask is not None:
+                # Convert float mask to boolean: 1.0 -> True (attend), 0.0 -> False (mask)
+                if mask.dim() == 2:
+                    # [seq_len, seq_len] -> [1, 1, seq_len, seq_len] for broadcasting
+                    attn_mask = (mask > 0.5).unsqueeze(0).unsqueeze(0)
+                elif mask.dim() == 3:
+                    # [batch, seq_len, seq_len] -> [batch, 1, seq_len, seq_len] for broadcasting
+                    attn_mask = (mask > 0.5).unsqueeze(1)
+                else:
+                    attn_mask = (mask > 0.5)
+
+            # Use causal masking only if no custom mask is provided
+            # If we have a custom mask, it should already include causal masking
+            use_causal = self.causal and (attn_mask is None)
+
             output = F.scaled_dot_product_attention(
                 Q, K, V,
-                attn_mask=mask,
+                attn_mask=attn_mask,
                 dropout_p=self.dropout.p if self.training else 0.0,
-                is_causal=self.causal,
+                is_causal=use_causal,
             )
             attention_weights = None  # Flash attention doesn't return weights
         else:
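As a quick sanity check on the new mask semantics (again a sketch, not part of the commit), the boolean-mask path through `F.scaled_dot_product_attention` should agree with attention computed by hand from the same converted mask:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq, dim = 2, 4, 5, 8
Q, K, V = (torch.randn(batch, heads, seq, dim) for _ in range(3))

# Float mask in the module's convention (causal here), converted as in the diff
float_mask = torch.tril(torch.ones(seq, seq))
bool_mask = (float_mask > 0.5).unsqueeze(0).unsqueeze(0)  # [1, 1, seq, seq]

fast = F.scaled_dot_product_attention(Q, K, V, attn_mask=bool_mask, is_causal=False)

# Reference: masked softmax attention computed by hand
scores = (Q @ K.transpose(-2, -1)) / dim ** 0.5
scores = scores.masked_fill(~bool_mask, float('-inf'))
ref = torch.softmax(scores, dim=-1) @ V

print(torch.allclose(fast, ref, atol=1e-5))  # expected: True
```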