diff --git a/MindIE/MultiModal/Flux.1-DEV/FLUX1dev/layers/attention_processor.py b/MindIE/MultiModal/Flux.1-DEV/FLUX1dev/layers/attention_processor.py
index ac02729309e5154a463e1bcc9a0a3d9889d535a9..e53573ce937ef944a952873be857e7f2ecf494b1 100644
--- a/MindIE/MultiModal/Flux.1-DEV/FLUX1dev/layers/attention_processor.py
+++ b/MindIE/MultiModal/Flux.1-DEV/FLUX1dev/layers/attention_processor.py
@@ -44,7 +44,8 @@ def apply_fa(query, key, value, attention_mask):
     heads = query.shape[-2]
     head_dim = query.shape[-1]
 
-    hidden_states = attention_forward(query, key, value, attn_mask=attention_mask)
+    hidden_states = attention_forward(query, key, value, opt_mode="manual", attn_mask=attention_mask,
+                                      op_type="fused_attn_score", layout="BSND")
 
     return hidden_states.reshape(batch_size, -1, head_dim * heads)
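
The patch changes the `attention_forward` call in `apply_fa` from the default dispatch to an explicit configuration: `opt_mode="manual"` plus `op_type="fused_attn_score"` and `layout="BSND"`. Below is a minimal sketch of how `apply_fa` reads after the patch, reconstructed from the hunk context; the import location and the `batch_size` line are assumptions and are not shown in the diff.

```python
# Sketch only: the mindiesd import path and batch_size derivation are assumptions;
# the remaining lines follow the hunk above.
from mindiesd import attention_forward  # assumed import location


def apply_fa(query, key, value, attention_mask):
    # Inputs are assumed to be laid out as (batch, seq_len, num_heads, head_dim), i.e. BSND.
    batch_size = query.shape[0]  # assumed; not visible in the hunk
    heads = query.shape[-2]
    head_dim = query.shape[-1]

    # Manually select the fused attention-score kernel with a BSND layout
    # instead of relying on the default operator selection.
    hidden_states = attention_forward(query, key, value, opt_mode="manual", attn_mask=attention_mask,
                                      op_type="fused_attn_score", layout="BSND")

    # Merge the head and head_dim axes back into one hidden dimension:
    # (batch, seq_len, heads * head_dim).
    return hidden_states.reshape(batch_size, -1, head_dim * heads)
```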