From 9f17e1db2418c79ec1448866e26091f891a34f09 Mon Sep 17 00:00:00 2001
From: Carlos Gutierrez
Date: Sun, 16 Nov 2025 16:44:55 -0500
Subject: [PATCH] Fix optimized attention mask handling for training

- Fix mask format conversion (float to boolean) for scaled_dot_product_attention
- Fix mask dimensions for proper broadcasting [batch, 1, seq_len, seq_len]
- Resolve conflict between is_causal and custom mask parameters
- Enable training with optimized attention and KV caching
---
 models/optimized_attention.py | 23 +++++++++++++++++++++--
 1 file changed, 21 insertions(+), 2 deletions(-)

diff --git a/models/optimized_attention.py b/models/optimized_attention.py
index d9d8225..c1ddc08 100644
--- a/models/optimized_attention.py
+++ b/models/optimized_attention.py
@@ -130,11 +130,30 @@ class OptimizedMultiHeadAttention(nn.Module):
         # Compute attention scores with optimized computation
         if self.use_flash_attention and hasattr(F, 'scaled_dot_product_attention'):
             # Use PyTorch's optimized scaled dot product attention
+            # Prepare mask for scaled_dot_product_attention
+            # Note: scaled_dot_product_attention expects boolean mask where True=attend, False=mask
+            # Our mask is float where 1.0=valid, 0.0=masked, so we need to convert
+            attn_mask = None
+            if mask is not None:
+                # Convert float mask to boolean: 1.0 -> True (attend), 0.0 -> False (mask)
+                if mask.dim() == 2:
+                    # [seq_len, seq_len] -> [1, 1, seq_len, seq_len] for broadcasting
+                    attn_mask = (mask > 0.5).unsqueeze(0).unsqueeze(0)
+                elif mask.dim() == 3:
+                    # [batch, seq_len, seq_len] -> [batch, 1, seq_len, seq_len] for broadcasting
+                    attn_mask = (mask > 0.5).unsqueeze(1)
+                else:
+                    attn_mask = (mask > 0.5)
+
+            # Use causal masking only if no custom mask is provided
+            # If we have a custom mask, it should already include causal masking
+            use_causal = self.causal and (attn_mask is None)
+
             output = F.scaled_dot_product_attention(
                 Q, K, V,
-                attn_mask=mask,
+                attn_mask=attn_mask,
                 dropout_p=self.dropout.p if self.training else 0.0,
-                is_causal=self.causal,
+                is_causal=use_causal,
             )
             attention_weights = None  # Flash attention doesn't return weights
         else:
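
Note for reviewers (not part of the patch): below is a minimal standalone sketch of the mask handling this change introduces. The tensor shapes, the float-mask convention (1.0 = valid, 0.0 = masked), and all variable names are illustrative assumptions taken from the commit message; the real module additionally carries dropout, the causal flag, and the KV cache.

# Standalone sketch of the patch's mask conversion (assumed shapes/names).
import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 4, 8, 16
Q = torch.randn(batch, heads, seq_len, head_dim)
K = torch.randn(batch, heads, seq_len, head_dim)
V = torch.randn(batch, heads, seq_len, head_dim)

# Float mask in the patch's convention: 1.0 = valid, 0.0 = masked.
# Here: a causal lower-triangular mask with the last two key positions padded out.
float_mask = torch.tril(torch.ones(seq_len, seq_len))
float_mask[:, -2:] = 0.0

# Convert to boolean (True = attend) and lift [S, S] -> [1, 1, S, S] for
# broadcasting over batch and heads, mirroring the 2-D branch of the patch.
attn_mask = (float_mask > 0.5).unsqueeze(0).unsqueeze(0)

# is_causal=False because the explicit mask already encodes causality;
# passing both a custom mask and is_causal=True is the conflict the patch resolves.
out = F.scaled_dot_product_attention(Q, K, V, attn_mask=attn_mask, is_causal=False)
print(out.shape)  # torch.Size([2, 4, 8, 16])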