diff --git a/examples/qwen2.5omni/finetune_qwen2_5_omni_7b.sh b/examples/qwen2.5omni/finetune_qwen2_5_omni_7b.sh
index f3fe29435ff2577985c1df1898fb9b4c83266330..761f8755f0eb754b7829bf55d364c43a6c58f551 100644
--- a/examples/qwen2.5omni/finetune_qwen2_5_omni_7b.sh
+++ b/examples/qwen2.5omni/finetune_qwen2_5_omni_7b.sh
@@ -49,6 +49,7 @@ GPT_ARGS="
     --pipeline-model-parallel-size ${PP} \
     --micro-batch-size ${MBS} \
     --global-batch-size ${GBS} \
+    --context-parallel-size ${CP} \
     --tokenizer-type NullTokenizer \
     --vocab-size 152064 \
     --seq-length 3072 \
diff --git a/mindspeed_mm/models/audio/audio_model.py b/mindspeed_mm/models/audio/audio_model.py
index 841561f99ac347c088e8c27124780f6e460744f0..0b4e2cfa9ba99c386ec528fddd6a4fad8e83e6fa 100644
--- a/mindspeed_mm/models/audio/audio_model.py
+++ b/mindspeed_mm/models/audio/audio_model.py
@@ -1,14 +1,19 @@
-#Copyright 2025 The Qwen team; Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The Qwen team; Alibaba Group and the HuggingFace Inc. team. All rights reserved.

 import math

 import torch
-from torch import nn as nn
-from torch.nn import functional as F
-
+from megatron.core import mpu
 from megatron.core.transformer import TransformerConfig, ModuleSpec
 from megatron.core.transformer.transformer_block import TransformerBlock
 from megatron.training import get_args
+from mindspeed.core.context_parallel.ulysses_context_parallel.unaligned_cp.mapping import (
+    cal_split_sizes,
+    split_forward_gather_backward,
+    gather_forward_split_backward
+)
+from torch import nn as nn
+from torch.nn import functional as F

 from mindspeed_mm.models.audio.omni_audio_encoder import SinusoidsPositionEmbedding, AudioLinear
 from mindspeed_mm.models.common.module import MultiModalModule
@@ -181,6 +186,16 @@ class OmniAudioEncoder(MultiModalModule):
         ).to(torch.int32)
         hidden_states = hidden_states.unsqueeze(0).transpose(0, 1)
         seq_len, _, _ = hidden_states.shape
+        if mpu.get_context_parallel_world_size() > 1:
+            split_gather_sizes = cal_split_sizes(hidden_states.shape[0], mpu.get_context_parallel_world_size())
+            hidden_states = split_forward_gather_backward(
+                hidden_states,
+                mpu.get_context_parallel_group(),
+                0,
+                split_gather_sizes,
+                "down"
+            )
+
         if get_args().use_flash_attn:
             attention_mask = None
         else:
@@ -192,6 +207,15 @@ class OmniAudioEncoder(MultiModalModule):
                 attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = 0

         hidden_states = self.blocks(hidden_states, attention_mask=attention_mask, cu_seqlens=cu_seqlens)
+        if mpu.get_context_parallel_world_size() > 1:
+            hidden_states = gather_forward_split_backward(
+                hidden_states,
+                mpu.get_context_parallel_group(),
+                0,
+                split_gather_sizes,
+                "up"
+            )
+
         hidden_states = hidden_states.squeeze(1)
         hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0)
         token_audio_list = []
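
The shell change only adds the `--context-parallel-size ${CP}` flag; the `CP` variable itself is assumed to be defined alongside `TP`/`PP` at the top of the script. The flag is valid only when the world size factors cleanly across all parallel dimensions, with the data-parallel size taken from the remainder. A minimal sketch of that constraint follows; `infer_data_parallel` is a hypothetical helper for illustration, not a Megatron API (the real launcher performs the equivalent validation internally):

```python
# Illustrative check of the parallel layout implied by --context-parallel-size.
# infer_data_parallel is a hypothetical name, not part of this patch.
def infer_data_parallel(world_size: int, tp: int, pp: int, cp: int) -> int:
    denom = tp * pp * cp
    if world_size % denom != 0:
        raise ValueError(f"world_size={world_size} is not divisible by TP*PP*CP={denom}")
    return world_size // denom


# E.g. 8 devices with TP=1, PP=1, CP=2 leave a data-parallel size of 4.
print(infer_data_parallel(world_size=8, tp=1, pp=1, cp=2))  # -> 4
```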
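
The Python change brackets the audio transformer blocks with an unaligned sequence split ("down") before `self.blocks` and the matching gather ("up") after it, so each CP rank runs the blocks over only its shard of the sequence while, per the helper names, gradients are gathered on the way back. Below is a single-process sketch of the shard-size math and the round trip; `cal_split_sizes_sketch` is an assumed stand-in for MindSpeed's `cal_split_sizes`, not its actual implementation:

```python
import torch


def cal_split_sizes_sketch(seq_len: int, world_size: int) -> list[int]:
    # Assumed behaviour: distribute seq_len as evenly as possible, with the
    # first (seq_len % world_size) ranks taking one extra row.
    base, rem = divmod(seq_len, world_size)
    return [base + (1 if rank < rem else 0) for rank in range(world_size)]


cp_world_size = 4                    # pretend --context-parallel-size 4
hidden = torch.randn(10, 1, 1280)    # [seq_len, batch, hidden], as in the encoder

sizes = cal_split_sizes_sketch(hidden.shape[0], cp_world_size)
print(sizes)  # [3, 3, 2, 2] -- seq_len=10 does not divide evenly by cp=4

# "down": in the real code each CP rank keeps exactly one of these shards.
shards = list(hidden.split(sizes, dim=0))

# "up": concatenating the shards restores the original sequence, which is
# what the gather after self.blocks does across the CP group.
restored = torch.cat(shards, dim=0)
assert torch.equal(hidden, restored)
```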