From cc8d83502d453f59e7a73db75af3826868bdf229 Mon Sep 17 00:00:00 2001
From: ruanhao
Date: Tue, 26 Aug 2025 13:13:13 +0800
Subject: [PATCH 1/2] [Feature]qwen2.5omni add ulysses CP

---
 mindspeed_mm/models/audio/audio_model.py | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/mindspeed_mm/models/audio/audio_model.py b/mindspeed_mm/models/audio/audio_model.py
index 841561f9..a89de69b 100644
--- a/mindspeed_mm/models/audio/audio_model.py
+++ b/mindspeed_mm/models/audio/audio_model.py
@@ -13,6 +13,9 @@ from megatron.training import get_args
 from mindspeed_mm.models.audio.omni_audio_encoder import SinusoidsPositionEmbedding, AudioLinear
 from mindspeed_mm.models.common.module import MultiModalModule
 from mindspeed_mm.models.vision.vision_encoders.vision_transformer_block import Qwen2VLVisionTransformerBlock
+from megatron.core import mpu
+from mindspeed.core.context_parallel.ulysses_context_parallel.unaligned_cp.mapping import cal_split_sizes, split_forward_gather_backward, \
+    gather_forward_split_backward
 
 
 class AudioModel(MultiModalModule):
@@ -181,6 +184,16 @@ class OmniAudioEncoder(MultiModalModule):
         ).to(torch.int32)
         hidden_states = hidden_states.unsqueeze(0).transpose(0, 1)
         seq_len, _, _ = hidden_states.shape
+        if mpu.get_context_parallel_world_size() > 1 and get_args().context_parallel_algo == "ulysses_cp_algo":
+            split_gather_sizes = cal_split_sizes(hidden_states.shape[0], mpu.get_context_parallel_world_size())
+            hidden_states = split_forward_gather_backward(
+                hidden_states,
+                mpu.get_context_parallel_group(),
+                0,
+                split_gather_sizes,
+                "down"
+            )
+
         if get_args().use_flash_attn:
             attention_mask = None
         else:
@@ -192,6 +205,15 @@ class OmniAudioEncoder(MultiModalModule):
             attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = 0
 
         hidden_states = self.blocks(hidden_states, attention_mask=attention_mask, cu_seqlens=cu_seqlens)
+        if mpu.get_context_parallel_world_size() > 1 and get_args().context_parallel_algo == "ulysses_cp_algo":
+            hidden_states = gather_forward_split_backward(
+                hidden_states,
+                mpu.get_context_parallel_group(),
+                0,
+                split_gather_sizes,
+                "up"
+            )
+
         hidden_states = hidden_states.squeeze(1)
         hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0)
         token_audio_list = []
--
Gitee

From a3976e1e1932820c32112a1fe0bd5d0ef3195ede Mon Sep 17 00:00:00 2001
From: ruanhao
Date: Fri, 29 Aug 2025 12:55:48 +0800
Subject: [PATCH 2/2] [Feature]qwen2.5omni add ulysses CP

---
 .../qwen2.5omni/finetune_qwen2_5_omni_7b.sh |  1 +
 mindspeed_mm/models/audio/audio_model.py    | 20 ++++++++++---------
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/examples/qwen2.5omni/finetune_qwen2_5_omni_7b.sh b/examples/qwen2.5omni/finetune_qwen2_5_omni_7b.sh
index f3fe2943..761f8755 100644
--- a/examples/qwen2.5omni/finetune_qwen2_5_omni_7b.sh
+++ b/examples/qwen2.5omni/finetune_qwen2_5_omni_7b.sh
@@ -49,6 +49,7 @@ GPT_ARGS="
     --pipeline-model-parallel-size ${PP} \
     --micro-batch-size ${MBS} \
     --global-batch-size ${GBS} \
+    --context-parallel-size ${CP} \
     --tokenizer-type NullTokenizer \
     --vocab-size 152064 \
     --seq-length 3072 \
diff --git a/mindspeed_mm/models/audio/audio_model.py b/mindspeed_mm/models/audio/audio_model.py
index a89de69b..0b4e2cfa 100644
--- a/mindspeed_mm/models/audio/audio_model.py
+++ b/mindspeed_mm/models/audio/audio_model.py
@@ -1,21 +1,23 @@
-#Copyright 2025 The Qwen team; Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The Qwen team; Alibaba Group and the HuggingFace Inc. team. All rights reserved.
 
 import math
 
 import torch
-from torch import nn as nn
-from torch.nn import functional as F
-
+from megatron.core import mpu
 from megatron.core.transformer import TransformerConfig, ModuleSpec
 from megatron.core.transformer.transformer_block import TransformerBlock
 from megatron.training import get_args
+from mindspeed.core.context_parallel.ulysses_context_parallel.unaligned_cp.mapping import (
+    cal_split_sizes,
+    split_forward_gather_backward,
+    gather_forward_split_backward
+)
+from torch import nn as nn
+from torch.nn import functional as F
 
 from mindspeed_mm.models.audio.omni_audio_encoder import SinusoidsPositionEmbedding, AudioLinear
 from mindspeed_mm.models.common.module import MultiModalModule
 from mindspeed_mm.models.vision.vision_encoders.vision_transformer_block import Qwen2VLVisionTransformerBlock
-from megatron.core import mpu
-from mindspeed.core.context_parallel.ulysses_context_parallel.unaligned_cp.mapping import cal_split_sizes, split_forward_gather_backward, \
-    gather_forward_split_backward
 
 
 class AudioModel(MultiModalModule):
@@ -184,7 +186,7 @@ class OmniAudioEncoder(MultiModalModule):
         ).to(torch.int32)
         hidden_states = hidden_states.unsqueeze(0).transpose(0, 1)
         seq_len, _, _ = hidden_states.shape
-        if mpu.get_context_parallel_world_size() > 1 and get_args().context_parallel_algo == "ulysses_cp_algo":
+        if mpu.get_context_parallel_world_size() > 1:
             split_gather_sizes = cal_split_sizes(hidden_states.shape[0], mpu.get_context_parallel_world_size())
             hidden_states = split_forward_gather_backward(
                 hidden_states,
@@ -205,7 +207,7 @@ class OmniAudioEncoder(MultiModalModule):
             attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = 0
 
         hidden_states = self.blocks(hidden_states, attention_mask=attention_mask, cu_seqlens=cu_seqlens)
-        if mpu.get_context_parallel_world_size() > 1 and get_args().context_parallel_algo == "ulysses_cp_algo":
+        if mpu.get_context_parallel_world_size() > 1:
             hidden_states = gather_forward_split_backward(
                 hidden_states,
                 mpu.get_context_parallel_group(),
                 0,
                 split_gather_sizes,
--
Gitee
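
Note for reviewers unfamiliar with the Ulysses CP helpers: the two audio_model.py hunks wrap self.blocks in a split/compute/gather pattern along the sequence dimension. Below is a minimal, single-process sketch of that bookkeeping, not the MindSpeed implementation: cal_split_sizes_sketch is a hypothetical stand-in for cal_split_sizes (assumed to hand the first seq_len % world_size ranks one extra element), and plain split/cat stand in for split_forward_gather_backward ("down") and gather_forward_split_backward ("up"), which in MindSpeed also run the matching collective in the backward pass so gradients stay correct.

import torch
from torch import nn


def cal_split_sizes_sketch(seq_len, world_size):
    # Assumed behavior: spread seq_len as evenly as possible; the first
    # (seq_len % world_size) ranks each take one extra element, so uneven
    # ("unaligned") sequence lengths still split cleanly across CP ranks.
    base, rem = divmod(seq_len, world_size)
    return [base + (1 if rank < rem else 0) for rank in range(world_size)]


if __name__ == "__main__":
    cp_world_size = 4
    seq_len, batch, hidden = 10, 1, 8        # 10 tokens do not divide by 4 ranks
    x = torch.randn(seq_len, batch, hidden)  # [s, b, h], the encoder's layout
    blocks = nn.Identity()                   # stand-in for self.blocks

    sizes = cal_split_sizes_sketch(seq_len, cp_world_size)  # -> [3, 3, 2, 2]

    # "down" (split in forward, gather in backward): each CP rank keeps
    # exactly one shard of the sequence dimension (dim 0).
    shards = list(x.split(sizes, dim=0))

    # Each rank runs the transformer blocks on its local shard; in real
    # Ulysses attention, all-to-all exchanges let every head still attend
    # over the full sequence.
    shards = [blocks(s) for s in shards]

    # "up" (gather in forward, split in backward): reassemble the full
    # sequence so the later squeeze/split over aftercnn_lens sees all tokens.
    out = torch.cat(shards, dim=0)

    assert out.shape == x.shape
    print(sizes, tuple(out.shape))           # [3, 3, 2, 2] (10, 1, 8)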