diff --git a/models/optimized_attention.py b/models/optimized_attention.py index d9d8225..c1ddc08 100644 --- a/models/optimized_attention.py +++ b/models/optimized_attention.py @@ -130,11 +130,30 @@ class OptimizedMultiHeadAttention(nn.Module): # Compute attention scores with optimized computation if self.use_flash_attention and hasattr(F, 'scaled_dot_product_attention'): # Use PyTorch's optimized scaled dot product attention + # Prepare mask for scaled_dot_product_attention + # Note: scaled_dot_product_attention expects boolean mask where True=attend, False=mask + # Our mask is float where 1.0=valid, 0.0=masked, so we need to convert + attn_mask = None + if mask is not None: + # Convert float mask to boolean: 1.0 -> True (attend), 0.0 -> False (mask) + if mask.dim() == 2: + # [seq_len, seq_len] -> [1, 1, seq_len, seq_len] for broadcasting + attn_mask = (mask > 0.5).unsqueeze(0).unsqueeze(0) + elif mask.dim() == 3: + # [batch, seq_len, seq_len] -> [batch, 1, seq_len, seq_len] for broadcasting + attn_mask = (mask > 0.5).unsqueeze(1) + else: + attn_mask = (mask > 0.5) + + # Use causal masking only if no custom mask is provided + # If we have a custom mask, it should already include causal masking + use_causal = self.causal and (attn_mask is None) + output = F.scaled_dot_product_attention( Q, K, V, - attn_mask=mask, + attn_mask=attn_mask, dropout_p=self.dropout.p if self.training else 0.0, - is_causal=self.causal, + is_causal=use_causal, ) attention_weights = None # Flash attention doesn't return weights else: