From c97e91f85ed33386dd140caf9c9f85d713b023d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=B3=BD?= <1596468259@qq.com> Date: Mon, 11 Aug 2025 17:56:28 +0800 Subject: [PATCH] adapt latest mindspeed --- docs/features/lora_finetune.md | 3 - docs/user-guide/getting_start.md | 2 +- docs/user-guide/model-migration.md | 2 +- examples/cogvideox/README.md | 2 +- examples/dancegrpo/README.md | 2 +- examples/deepseekvl2/README.md | 2 +- examples/glm4.1v/README.md | 2 +- examples/hunyuanvideo/README.md | 2 +- examples/internvl2.5/README.md | 2 +- examples/internvl2/README.md | 2 +- examples/internvl3/README.md | 2 +- examples/llava1.5/README.md | 2 +- examples/opensora1.0/README.md | 2 +- examples/opensora1.2/README.md | 2 +- examples/opensora2.0/README.md | 2 +- examples/opensoraplan1.2/README.md | 2 +- examples/opensoraplan1.3/README.md | 2 +- examples/opensoraplan1.5/README.md | 2 +- examples/qwen2.5omni/README.md | 2 +- examples/qwen2.5vl/README.md | 2 +- examples/qwen2vl/README.md | 2 +- examples/qwen3vl_dev/README.md | 2 +- examples/rl/README.md | 2 +- examples/stepvideo/README.md | 2 +- examples/vae/README.md | 2 +- examples/whisper/README.md | 2 +- .../models/audio/omni_audio_encoder.py | 88 ++----------------- .../vision_encoders/qwen2vl_vit_model.py | 25 +++++- mindspeed_mm/patchs/ring_attn_patch.py | 51 ++++++++--- mindspeed_mm/patchs/ulysses_patches.py | 6 ++ mindspeed_mm/patchs/validate_args_patch.py | 1 + .../utils/transformer_model_config.py | 11 +++ .../inference_qwen2vl_7b.json | 4 +- .../inference_qwen2vl_7b.json | 4 +- 34 files changed, 115 insertions(+), 128 deletions(-) diff --git a/docs/features/lora_finetune.md b/docs/features/lora_finetune.md index d1b6e868..71e5336f 100644 --- a/docs/features/lora_finetune.md +++ b/docs/features/lora_finetune.md @@ -59,9 +59,6 @@ LoRA权重和原始权重合并方法: ### 注意事项 - -- 当前依赖的core版本存在一处参数校验错误,使能lora时需手动修改 `MindSpeed/mindspeed/features_manager/memory/swap_attention.py`文件中的一处代码,文件内搜索`if is_enable_lora:`,修改为`if is_enable_lora and args.swap_attention:`。 - - **冻结模块**:多模态模型中可能存在部分模块参数冻结的情况,冻结的模块不会参与 LoRA 微调。 ## 参考文献 diff --git a/docs/user-guide/getting_start.md b/docs/user-guide/getting_start.md index b18a2a57..89d490e1 100644 --- a/docs/user-guide/getting_start.md +++ b/docs/user-guide/getting_start.md @@ -35,7 +35,7 @@ conda activate test git clone https://gitee.com/ascend/MindSpeed.git cd MindSpeed # checkout commit from MindSpeed core_v0.12.1 -git checkout 6d63944cb2470a0bebc38dfb65299b91329b8d92 +git checkout 5176c6f5f133111e55a404d82bd2dc14a809a6ab pip install -r requirements.txt pip3 install -e . cd .. diff --git a/docs/user-guide/model-migration.md b/docs/user-guide/model-migration.md index 564c35fd..1ea410ba 100644 --- a/docs/user-guide/model-migration.md +++ b/docs/user-guide/model-migration.md @@ -174,7 +174,7 @@ pip install torch_npu-2.7.1*-cp310-cp310-manylinux_2_28_aarch64.whl git clone https://gitee.com/ascend/MindSpeed.git cd MindSpeed # checkout commit from MindSpeed core_r0.8.0 -git checkout 6d63944cb2470a0bebc38dfb65299b91329b8d92 +git checkout 5176c6f5f133111e55a404d82bd2dc14a809a6ab pip install -r requirements.txt pip3 install -e . cd .. diff --git a/examples/cogvideox/README.md b/examples/cogvideox/README.md index f7310f2d..6c9cf171 100644 --- a/examples/cogvideox/README.md +++ b/examples/cogvideox/README.md @@ -99,7 +99,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh git clone https://gitee.com/ascend/MindSpeed.git cd MindSpeed # checkout commit from MindSpeed core_r0.12.1 -git checkout 6d63944cb2470a0bebc38dfb65299b91329b8d92 +git checkout 5176c6f5f133111e55a404d82bd2dc14a809a6ab pip install -r requirements.txt pip install -e . cd .. diff --git a/examples/dancegrpo/README.md b/examples/dancegrpo/README.md index c577bd7b..49d8c0f9 100644 --- a/examples/dancegrpo/README.md +++ b/examples/dancegrpo/README.md @@ -78,7 +78,7 @@ pip install torch_npu-2.7.1*.whl # 安装加速库 git clone https://gitee.com/ascend/MindSpeed.git cd MindSpeed -git checkout 6d63944cb2470a0bebc38dfb65299b91329b8d92 +git checkout 5176c6f5f133111e55a404d82bd2dc14a809a6ab cp -r mindspeed ../MindSpeed-MM/ cd .. diff --git a/examples/deepseekvl2/README.md b/examples/deepseekvl2/README.md index 80535f7e..818a2e55 100644 --- a/examples/deepseekvl2/README.md +++ b/examples/deepseekvl2/README.md @@ -63,7 +63,7 @@ pip install torch_npu-2.7.1*-cp310-cp310-manylinux_2_28_aarch64.whl git clone https://gitee.com/ascend/MindSpeed.git cd MindSpeed # checkout commit from MindSpeed core_r0.12.1 -git checkout 6d63944cb2470a0bebc38dfb65299b91329b8d92 +git checkout 5176c6f5f133111e55a404d82bd2dc14a809a6ab pip install -r requirements.txt pip3 install -e . cd .. diff --git a/examples/glm4.1v/README.md b/examples/glm4.1v/README.md index 8a8c9a08..230e6e58 100644 --- a/examples/glm4.1v/README.md +++ b/examples/glm4.1v/README.md @@ -66,7 +66,7 @@ pip install torch_npu-2.7.1*-cp310-cp310-manylinux_2_28_aarch64.whl git clone https://gitee.com/ascend/MindSpeed.git cd MindSpeed # checkout commit from MindSpeed core_r0.12.1 -git checkout 6d63944cb2470a0bebc38dfb65299b91329b8d92 +git checkout 5176c6f5f133111e55a404d82bd2dc14a809a6ab pip install -r requirements.txt pip3 install -e . cd .. diff --git a/examples/hunyuanvideo/README.md b/examples/hunyuanvideo/README.md index 2f88044a..6c63d183 100644 --- a/examples/hunyuanvideo/README.md +++ b/examples/hunyuanvideo/README.md @@ -96,7 +96,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh git clone https://gitee.com/ascend/MindSpeed.git cd MindSpeed # checkout commit from MindSpeed core_r0.12.1 -git checkout 6d63944cb2470a0bebc38dfb65299b91329b8d92 +git checkout 5176c6f5f133111e55a404d82bd2dc14a809a6ab pip install -r requirements.txt pip install -e . cd .. diff --git a/examples/internvl2.5/README.md b/examples/internvl2.5/README.md index c410c85a..ea78106e 100644 --- a/examples/internvl2.5/README.md +++ b/examples/internvl2.5/README.md @@ -83,7 +83,7 @@ pip install torch_npu-2.7.1*-cp310-cp310-manylinux_2_28_aarch64.whl git clone https://gitee.com/ascend/MindSpeed.git cd MindSpeed # checkout commit from MindSpeed core_r0.12.1 -git checkout 6d63944cb2470a0bebc38dfb65299b91329b8d92 +git checkout 5176c6f5f133111e55a404d82bd2dc14a809a6ab pip install -r requirements.txt pip3 install -e . cd .. diff --git a/examples/internvl2/README.md b/examples/internvl2/README.md index e807ebd0..813198b6 100644 --- a/examples/internvl2/README.md +++ b/examples/internvl2/README.md @@ -90,7 +90,7 @@ pip install torch_npu-2.7.1*-cp310-cp310-manylinux_2_28_aarch64.whl git clone https://gitee.com/ascend/MindSpeed.git cd MindSpeed # checkout commit from MindSpeed core_r0.12.1 -git checkout 6d63944cb2470a0bebc38dfb65299b91329b8d92 +git checkout 5176c6f5f133111e55a404d82bd2dc14a809a6ab pip install -r requirements.txt pip3 install -e . cd .. diff --git a/examples/internvl3/README.md b/examples/internvl3/README.md index 2c8609c8..d8082f44 100644 --- a/examples/internvl3/README.md +++ b/examples/internvl3/README.md @@ -81,7 +81,7 @@ pip install torch_npu-2.7.1*-cp310-cp310-manylinux_2_28_aarch64.whl git clone https://gitee.com/ascend/MindSpeed.git cd MindSpeed # checkout commit from MindSpeed core_r0.12.1 -git checkout 6d63944cb2470a0bebc38dfb65299b91329b8d92 +git checkout 5176c6f5f133111e55a404d82bd2dc14a809a6ab pip install -r requirements.txt pip3 install -e . cd .. diff --git a/examples/llava1.5/README.md b/examples/llava1.5/README.md index 568e8455..39d9c70b 100644 --- a/examples/llava1.5/README.md +++ b/examples/llava1.5/README.md @@ -85,7 +85,7 @@ commit_id=3e337ad git clone https://gitee.com/ascend/MindSpeed.git cd MindSpeed # checkout commit from MindSpeed core_r0.12.1 - git checkout 6d63944cb2470a0bebc38dfb65299b91329b8d92 + git checkout 5176c6f5f133111e55a404d82bd2dc14a809a6ab pip install -r requirements.txt pip3 install -e . cd .. diff --git a/examples/opensora1.0/README.md b/examples/opensora1.0/README.md index 1601e06c..05c4600d 100644 --- a/examples/opensora1.0/README.md +++ b/examples/opensora1.0/README.md @@ -68,7 +68,7 @@ git clone https://gitee.com/ascend/MindSpeed.git cd MindSpeed # checkout commit from MindSpeed core_r0.12.1 - git checkout 6d63944cb2470a0bebc38dfb65299b91329b8d92 + git checkout 5176c6f5f133111e55a404d82bd2dc14a809a6ab pip install -r requirements.txt pip3 install -e . cd .. diff --git a/examples/opensora1.2/README.md b/examples/opensora1.2/README.md index df8c80e2..3ba39fdb 100644 --- a/examples/opensora1.2/README.md +++ b/examples/opensora1.2/README.md @@ -69,7 +69,7 @@ git clone https://gitee.com/ascend/MindSpeed.git cd MindSpeed # checkout commit from MindSpeed core_r0.12.1 - git checkout 6d63944cb2470a0bebc38dfb65299b91329b8d92 + git checkout 5176c6f5f133111e55a404d82bd2dc14a809a6ab pip install -r requirements.txt pip3 install -e . cd .. diff --git a/examples/opensora2.0/README.md b/examples/opensora2.0/README.md index 89709b48..b9763e22 100644 --- a/examples/opensora2.0/README.md +++ b/examples/opensora2.0/README.md @@ -81,7 +81,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh git clone https://gitee.com/ascend/MindSpeed.git cd MindSpeed # checkout commit from MindSpeed core_r0.12.1 -git checkout 6d63944cb2470a0bebc38dfb65299b91329b8d92 +git checkout 5176c6f5f133111e55a404d82bd2dc14a809a6ab pip install -r requirements.txt pip3 install -e . cd .. diff --git a/examples/opensoraplan1.2/README.md b/examples/opensoraplan1.2/README.md index 736f9997..88dde3d2 100644 --- a/examples/opensoraplan1.2/README.md +++ b/examples/opensoraplan1.2/README.md @@ -83,7 +83,7 @@ commit_id=adb2a20 git clone https://gitee.com/ascend/MindSpeed.git cd MindSpeed # checkout commit from MindSpeed core_r0.12.1 - git checkout 6d63944cb2470a0bebc38dfb65299b91329b8d92 + git checkout 5176c6f5f133111e55a404d82bd2dc14a809a6ab pip install -r requirements.txt pip3 install -e . cd .. diff --git a/examples/opensoraplan1.3/README.md b/examples/opensoraplan1.3/README.md index 79cc4811..ec3edf43 100644 --- a/examples/opensoraplan1.3/README.md +++ b/examples/opensoraplan1.3/README.md @@ -87,7 +87,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh git clone https://gitee.com/ascend/MindSpeed.git cd MindSpeed # checkout commit from MindSpeed core_r0.12.1 -git checkout 6d63944cb2470a0bebc38dfb65299b91329b8d92 +git checkout 5176c6f5f133111e55a404d82bd2dc14a809a6ab pip install -r requirements.txt pip3 install -e . cd .. diff --git a/examples/opensoraplan1.5/README.md b/examples/opensoraplan1.5/README.md index f83df931..ace8e37d 100644 --- a/examples/opensoraplan1.5/README.md +++ b/examples/opensoraplan1.5/README.md @@ -55,7 +55,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh git clone https://gitee.com/ascend/MindSpeed.git cd MindSpeed # checkout commit from MindSpeed core_r0.12.1 -git checkout 6d63944cb2470a0bebc38dfb65299b91329b8d92 +git checkout 5176c6f5f133111e55a404d82bd2dc14a809a6ab pip install -r requirements.txt pip install -e . cd .. diff --git a/examples/qwen2.5omni/README.md b/examples/qwen2.5omni/README.md index 7c306fc2..3e7e9309 100644 --- a/examples/qwen2.5omni/README.md +++ b/examples/qwen2.5omni/README.md @@ -77,7 +77,7 @@ mkdir logs data ckpt git clone https://gitee.com/ascend/MindSpeed.git cd MindSpeed # checkout commit from MindSpeed core_r0.12.1 -git checkout 6d63944cb2470a0bebc38dfb65299b91329b8d92 +git checkout 5176c6f5f133111e55a404d82bd2dc14a809a6ab # 安装mindspeed及依赖 pip install -e . cd .. diff --git a/examples/qwen2.5vl/README.md b/examples/qwen2.5vl/README.md index 62fff2f6..c7c5f3d9 100644 --- a/examples/qwen2.5vl/README.md +++ b/examples/qwen2.5vl/README.md @@ -78,7 +78,7 @@ mkdir logs data ckpt git clone https://gitee.com/ascend/MindSpeed.git cd MindSpeed # checkout commit from MindSpeed core_r0.12.1 -git checkout 6d63944cb2470a0bebc38dfb65299b91329b8d92 +git checkout 5176c6f5f133111e55a404d82bd2dc14a809a6ab # 安装mindspeed及依赖 pip install -e . cd .. diff --git a/examples/qwen2vl/README.md b/examples/qwen2vl/README.md index 08ff8cfc..f64f61a9 100644 --- a/examples/qwen2vl/README.md +++ b/examples/qwen2vl/README.md @@ -85,7 +85,7 @@ mkdir logs data ckpt git clone https://gitee.com/ascend/MindSpeed.git cd MindSpeed # checkout commit from MindSpeed core_r0.12.1 -git checkout 6d63944cb2470a0bebc38dfb65299b91329b8d92 +git checkout 5176c6f5f133111e55a404d82bd2dc14a809a6ab # 安装mindspeed及依赖 pip install -e . cd .. diff --git a/examples/qwen3vl_dev/README.md b/examples/qwen3vl_dev/README.md index d2c189be..01f8088a 100644 --- a/examples/qwen3vl_dev/README.md +++ b/examples/qwen3vl_dev/README.md @@ -98,7 +98,7 @@ pip install torch_npu-2.7.1*-cp310-cp310-manylinux_2_28_aarch64.whl git clone https://gitee.com/ascend/MindSpeed.git cd MindSpeed # checkout commit from MindSpeed core_r0.12.1 -git checkout 6d63944cb2470a0bebc38dfb65299b91329b8d92 +git checkout 5176c6f5f133111e55a404d82bd2dc14a809a6ab pip install -r requirements.txt pip3 install -e . cd .. diff --git a/examples/rl/README.md b/examples/rl/README.md index 27443a61..2fd55f86 100644 --- a/examples/rl/README.md +++ b/examples/rl/README.md @@ -84,7 +84,7 @@ pip install -r MindSpeed-MM/examples/rl/requirements.txt git clone https://gitee.com/ascend/MindSpeed.git cd MindSpeed # checkout commit from MindSpeed core_r0.12.1 -git checkout 6d63944cb2470a0bebc38dfb65299b91329b8d92 +git checkout 5176c6f5f133111e55a404d82bd2dc14a809a6ab pip install -r requirements.txt cp -r mindspeed ../MindSpeed-MM/ cd .. diff --git a/examples/stepvideo/README.md b/examples/stepvideo/README.md index bfe60a99..019df19b 100644 --- a/examples/stepvideo/README.md +++ b/examples/stepvideo/README.md @@ -83,7 +83,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh git clone https://gitee.com/ascend/MindSpeed.git cd MindSpeed # checkout commit from MindSpeed core_r0.12.1 -git checkout 6d63944cb2470a0bebc38dfb65299b91329b8d92 +git checkout 5176c6f5f133111e55a404d82bd2dc14a809a6ab pip install -r requirements.txt pip install -e . cd .. diff --git a/examples/vae/README.md b/examples/vae/README.md index c5edd9ee..39204787 100644 --- a/examples/vae/README.md +++ b/examples/vae/README.md @@ -65,7 +65,7 @@ git clone https://gitee.com/ascend/MindSpeed.git cd MindSpeed # checkout commit from MindSpeed core_r0.12.1 - git checkout 6d63944cb2470a0bebc38dfb65299b91329b8d92 + git checkout 5176c6f5f133111e55a404d82bd2dc14a809a6ab pip install -r requirements.txt pip install -e . cd .. diff --git a/examples/whisper/README.md b/examples/whisper/README.md index be9a931f..0fa7bbf6 100644 --- a/examples/whisper/README.md +++ b/examples/whisper/README.md @@ -66,7 +66,7 @@ git clone https://gitee.com/ascend/MindSpeed.git cd MindSpeed # checkout commit from MindSpeed core_r0.12.1 - git checkout 6d63944cb2470a0bebc38dfb65299b91329b8d92 + git checkout 5176c6f5f133111e55a404d82bd2dc14a809a6ab pip install -r requirements.txt pip3 install -e . cd .. diff --git a/mindspeed_mm/models/audio/omni_audio_encoder.py b/mindspeed_mm/models/audio/omni_audio_encoder.py index 60d44d6c..2480daeb 100644 --- a/mindspeed_mm/models/audio/omni_audio_encoder.py +++ b/mindspeed_mm/models/audio/omni_audio_encoder.py @@ -60,88 +60,16 @@ class QwenOmniAudioSelfAttention(Qwen2vlVitSelfAttention): self.linear_qkv.bias.register_hook(freeze_k_bias_grad_hook) - def forward( - self, - hidden_states, - attention_mask, - key_value_states=None, - inference_context=None, - rotary_pos_emb=None, - rotary_pos_cos=None, - rotary_pos_sin=None, - attention_bias=None, - packed_seq_params=None, - sequence_len_offset=None, - inference_params=None, - ): - # hidden_states: [sq, b, h] - # For self attention we just duplicate the rotary_pos_emb if it isn't already - if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple): - rotary_pos_emb = (rotary_pos_emb,) * 2 - - query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) - - if self.config.context_parallel_size > key.shape[2]: - key = key.repeat_interleave( - query.shape[2] // key.shape[2], dim=2 - ) - value = value.repeat_interleave( - query.shape[2] // value.shape[2], dim=2 - ) - # =================================================== - # Adjust key, value, and rotary_pos_emb for inference - # =================================================== - query, key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference( - inference_context, - query, - key, - value, - rotary_pos_emb, - rotary_pos_cos, - rotary_pos_sin, - sequence_len_offset, + def apply_rotary_pos_emb_qk(self, rotary_pos_emb, query, key): + q_pos_emb, k_pos_emb = rotary_pos_emb + query = apply_rotary_pos_emb( + query, q_pos_emb, config=self.config, + ) + key = apply_rotary_pos_emb( + key, k_pos_emb, config=self.config, ) - # ================================================ - # relative positional embedding (rotary embedding) - # ================================================ - if rotary_pos_emb is not None: - q_pos_emb, k_pos_emb = rotary_pos_emb - - query = apply_rotary_pos_emb( - query, q_pos_emb, config=self.config, - ) - key = apply_rotary_pos_emb( - key, k_pos_emb, config=self.config, - ) - - # ================================== - # core attention computation - # ================================== - if self.checkpoint_core_attention and self.training: - core_attn_out = self._checkpointed_attention_forward( - query, - key, - value, - attention_mask, - attn_mask_type=attn_mask_type, - packed_seq_params=packed_seq_params, - ) - else: - core_attn_out = self.core_attention( - query, - key, - value, - attention_mask, - attn_mask_type=attn_mask_type, - packed_seq_params=packed_seq_params, - ) - - # ================= - # Output. [sq, b, h] - # ================= - output, bias = self.linear_proj(core_attn_out) - return output, bias + return query, key QKV_SIZE = 3 diff --git a/mindspeed_mm/models/vision/vision_encoders/qwen2vl_vit_model.py b/mindspeed_mm/models/vision/vision_encoders/qwen2vl_vit_model.py index 62ce86d2..75620859 100644 --- a/mindspeed_mm/models/vision/vision_encoders/qwen2vl_vit_model.py +++ b/mindspeed_mm/models/vision/vision_encoders/qwen2vl_vit_model.py @@ -211,6 +211,13 @@ class Qwen2vlVitSelfAttention(SelfAttention): attn_mask_type=attn_mask_type ) + def apply_rotary_pos_emb_qk(self, rotary_pos_emb, query, key): + query = apply_rotary_pos_emb_vision(query, rotary_pos_emb, + use_fused_rope=self.config.use_fused_rotary_pos_emb) + key = apply_rotary_pos_emb_vision(key, rotary_pos_emb, + use_fused_rope=self.config.use_fused_rotary_pos_emb) + return query, key + def forward( self, hidden_states, @@ -260,10 +267,13 @@ class Qwen2vlVitSelfAttention(SelfAttention): # absolute positional embedding. # otherwise, only relative positional embedding takes effect if rotary_pos_emb is not None: - query = apply_rotary_pos_emb_vision(query, rotary_pos_emb, - use_fused_rope=self.config.use_fused_rotary_pos_emb) - key = apply_rotary_pos_emb_vision(key, rotary_pos_emb, - use_fused_rope=self.config.use_fused_rotary_pos_emb) + query, key = self.apply_rotary_pos_emb_qk(rotary_pos_emb, query, key) + + # Adapt origin TND format + if packed_seq_params is not None: + query = query.squeeze(1) + key = key.squeeze(1) + value = value.squeeze(1) # ================================== # core attention computation @@ -287,6 +297,13 @@ class Qwen2vlVitSelfAttention(SelfAttention): packed_seq_params=packed_seq_params, ) + if packed_seq_params is not None: + # reshape to same output shape as unpacked case + # from (t, np, hn) to (t, b=1, h=np*hn) + # t is the pack size = sum (sq_i) + # note that batch is a dummy dimension in the packed case + core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1) + # ================= # Output. [sq, b, h] # ================= diff --git a/mindspeed_mm/patchs/ring_attn_patch.py b/mindspeed_mm/patchs/ring_attn_patch.py index 26d8ff91..1869b033 100644 --- a/mindspeed_mm/patchs/ring_attn_patch.py +++ b/mindspeed_mm/patchs/ring_attn_patch.py @@ -78,16 +78,39 @@ def vlm_cp_dot_product_attention_forward( attention_mask = get_attention_mask() if self.config.attention_mask_type == 'causal': self.config.sparse_mode = 2 - if self.config.reset_attention_mask: + if getattr(self.config, 'reset_attention_mask', False): if self.config.attention_mask_type == 'general': self.config.sparse_mode = 2 if not (self.config.context_parallel_size == 1 or self.config.context_parallel_algo == 'ulysses_cp_algo'): self.config.sparse_mode = 1 sparse_mode = self.config.sparse_mode + is_ulysses_algo = (getattr(self.config, 'context_parallel_algo', None) == 'ulysses_cp_algo') + if packed_seq_params is not None and self.config.attention_mask_type == 'causal': + attention_mask = torch.triu( + torch.ones((2048, 2048), + device='npu', dtype=torch.bool), diagonal=1) + sparse_mode = 2 ensure_valid(attention_bias is None, 'Attention bias is not supported for DotProductAttention.') - seq_length, bsz, n_head = query.shape[0], query.shape[1], query.shape[2] + if packed_seq_params is not None and not is_ulysses_algo: + #TND + T, n_head, D = query.shape[0], query.shape[1], query.shape[2] + else: + seq_length, bsz, n_head, head_dim = query.shape[0], query.shape[1], query.shape[2], query.shape[3] + + if packed_seq_params is not None and not is_ulysses_algo: + # TND + cp_size = parallel_state.get_context_parallel_world_size() + actual_seq_qlen = packed_seq_params.cu_seqlens_q.tolist() + actual_seq_kvlen = packed_seq_params.cu_seqlens_kv.tolist() + shape_order = 'TND' + else: + # SBH + actual_seq_qlen = None if packed_seq_params is None else packed_seq_params.cu_seqlens_q.tolist() + actual_seq_kvlen = None if packed_seq_params is None else packed_seq_params.cu_seqlens_kv.tolist() + query, key, value = [rearrange(x, 's b h d -> s b (h d)') for x in [query, key, value]] + shape_order = 'SBH' if attn_mask_type == AttnMaskType.no_mask: sparse_mode = 0 # default mask @@ -116,6 +139,7 @@ def vlm_cp_dot_product_attention_forward( attn_para['next_tokens'] = self.config.next_tockens attn_para['keep_prob'] = 1 - self.attention_dropout.p attn_para['sparse_mode'] = sparse_mode + attn_para['n_head'] = n_head output = ulyssesattn_context_parallel(query, key, value, attn_para, self.ulysses_comm_para) return output @@ -153,8 +177,10 @@ def vlm_cp_dot_product_attention_forward( if cp_para['causal']: attention_mask = torch.triu(torch.ones([2048, 2048], dtype=torch.bool, device=query.device), diagonal=1) - query, key, value = [rearrange(x, 's b h d -> s b (h d)') for x in [query, key, value]] if self.config.context_parallel_algo in ['megatron_cp_algo', 'hybrid_cp_algo']: + is_general_eod = ((getattr(self.config, 'attention_mask_type', None) == 'general') and (packed_seq_params is not None)) + if is_general_eod: + query, key, value = [rearrange(x, '(b s) n d -> s b (n d)', b=self.config.micro_batch_size) for x in [query, key, value]] cp_para['cp_global_ranks'] = cp_global_ranks if self.config.use_cp_send_recv_overlap: if cp_expanded_by_2d_tp: @@ -180,23 +206,19 @@ def vlm_cp_dot_product_attention_forward( output = ringattn_context_parallel(query, key, value, n_head, cp_para, scale, attention_mask, self.attention_dropout.p, packed_seq_params) + if is_general_eod: + output = rearrange(output, 's b (n d) -> (b s) n d', n=n_head) else: cp_para['scheduling_info'] = get_scheduling_info() output = adaptive_attn_context_parallel(query, key, value, n_head, cp_para, scale, attention_mask, self.attention_dropout.p) else: - if packed_seq_params is not None: # TND - cp_size = parallel_state.get_context_parallel_world_size() - actual_seq_qlen = packed_seq_params.cu_seqlens_q.tolist() - actual_seq_kvlen = packed_seq_params.cu_seqlens_kv.tolist() - query, key, value = [rearrange(x, 's b h d -> (b s) h d') for x in [query, key, value]] + # For EoD ulysses + if packed_seq_params is not None: + query, key, value = [rearrange(x, 's b (h d) -> (b s) h d', d=head_dim) for x in [query, key, value]] shape_order = 'TND' - else: # SBH - actual_seq_qlen = None - actual_seq_kvlen = None - query, key, value = [rearrange(x, 's b h d -> s b (h d)') for x in [query, key, value]] - shape_order = 'SBH' + if self.config.use_fusion_attn_v2: output = npu_fusion_attention( query, key, value, n_head, shape_order, @@ -229,7 +251,8 @@ def vlm_cp_dot_product_attention_forward( actual_seq_kvlen=actual_seq_kvlen )[0] if packed_seq_params is not None: - output = rearrange(output, '(b s) h d -> s b (h d)', s=seq_length, b=bsz) + output = rearrange(output, '(b s) h d -> s b (h d)', b=bsz) + shape_order = 'TND' return output diff --git a/mindspeed_mm/patchs/ulysses_patches.py b/mindspeed_mm/patchs/ulysses_patches.py index b7dbbf39..5fd016a8 100644 --- a/mindspeed_mm/patchs/ulysses_patches.py +++ b/mindspeed_mm/patchs/ulysses_patches.py @@ -74,6 +74,10 @@ class UlyssesContextAttention(torch.nn.Module): args_list[0] = attention_mask args = tuple(args_list) + if packed_seq_params is not None: + query = query.unsqueeze(1) + key = key.unsqueeze(1) + value = value.unsqueeze(1) scatter_sizes_query = cal_split_sizes(query.shape[self.scatter_idx], dist.get_world_size(self.spg)) scatter_sizes_key = cal_split_sizes(key.shape[self.scatter_idx], dist.get_world_size(self.spg)) scatter_sizes_value = cal_split_sizes(value.shape[self.scatter_idx], dist.get_world_size(self.spg)) @@ -98,6 +102,8 @@ class UlyssesContextAttention(torch.nn.Module): else: output = all_to_all(context_layer, self.spg, self.gather_idx, self.scatter_idx, gather_sizes, scatter_sizes_query) output = output.reshape(output.shape[0], output.shape[1], -1) + if packed_seq_params is not None: + output = output.squeeze(1) # out e.g., [s/p::h] return output diff --git a/mindspeed_mm/patchs/validate_args_patch.py b/mindspeed_mm/patchs/validate_args_patch.py index b792731e..c32f3f01 100644 --- a/mindspeed_mm/patchs/validate_args_patch.py +++ b/mindspeed_mm/patchs/validate_args_patch.py @@ -38,6 +38,7 @@ def validate_args(args, defaults=None): args.num_layers = safe_getattr(args.mm.model.text_decoder, 'num_layers', args.num_layers) args.hidden_size = safe_getattr(args.mm.model.text_decoder, 'hidden_size', args.hidden_size) args.num_attention_heads = safe_getattr(args.mm.model.text_decoder, 'num_attention_heads', args.num_attention_heads) + args.num_query_groups = safe_getattr(args.mm.model.text_decoder, 'num_query_groups', args.num_query_groups) args.max_position_embeddings = safe_getattr(args.mm.model.text_decoder, 'max_position_embeddings', args.max_position_embeddings) args.ffn_hidden_size = safe_getattr(args.mm.model.text_decoder, 'ffn_hidden_size', args.ffn_hidden_size) diff --git a/mindspeed_mm/utils/transformer_model_config.py b/mindspeed_mm/utils/transformer_model_config.py index 375a0085..8bd1f8d9 100644 --- a/mindspeed_mm/utils/transformer_model_config.py +++ b/mindspeed_mm/utils/transformer_model_config.py @@ -56,6 +56,17 @@ def get_model_config(config): else: t_config["activation_func"] = F.gelu + if t_config.get("kv_channels") is None and t_config.get("hidden_size") and t_config.get("num_attention_heads"): + t_config["kv_channels"] = t_config["hidden_size"] // t_config["num_attention_heads"] + if t_config.get("ffn_hidden_size") is None and t_config.get("hidden_size"): + t_config["ffn_hidden_size"] = 4 * t_config["hidden_size"] + if t_config.get("num_attention_heads") is None: + t_config["num_attention_heads"] = 0 + if t_config.get("num_query_groups") is None and t_config.get("num_attention_heads"): + t_config["num_query_groups"] = t_config.get("num_attention_heads") + if t_config.get("cp_comm_type") is None: + t_config["cp_comm_type"] = None + if getattr(global_args, "multi_latent_attention", False): t_config["rope_type"] = "rope" trans_config = MLATransformerConfig(**t_config) diff --git a/tests/st/run_configs/inference_qwen2vl_7B_pp1/inference_qwen2vl_7b.json b/tests/st/run_configs/inference_qwen2vl_7B_pp1/inference_qwen2vl_7b.json index d5c39bd0..a083b64e 100644 --- a/tests/st/run_configs/inference_qwen2vl_7B_pp1/inference_qwen2vl_7b.json +++ b/tests/st/run_configs/inference_qwen2vl_7B_pp1/inference_qwen2vl_7b.json @@ -38,7 +38,9 @@ "activation_func": "gelu", "bf16": true, "params_dtype": "bf16", - "freeze": true + "freeze": true, + "layernorm_epsilon": 1e-06, + "normalization": "LayerNorm" } }, "text_decoder": { diff --git a/tests/st/run_configs/inference_qwen2vl_7B_pp4/inference_qwen2vl_7b.json b/tests/st/run_configs/inference_qwen2vl_7B_pp4/inference_qwen2vl_7b.json index d5c39bd0..a083b64e 100644 --- a/tests/st/run_configs/inference_qwen2vl_7B_pp4/inference_qwen2vl_7b.json +++ b/tests/st/run_configs/inference_qwen2vl_7B_pp4/inference_qwen2vl_7b.json @@ -38,7 +38,9 @@ "activation_func": "gelu", "bf16": true, "params_dtype": "bf16", - "freeze": true + "freeze": true, + "layernorm_epsilon": 1e-06, + "normalization": "LayerNorm" } }, "text_decoder": { -- Gitee