From a18158422ed4706ccb79e2af8e2ca6641b84e1be Mon Sep 17 00:00:00 2001 From: cxiaolong <2845907121@qq.com> Date: Tue, 26 Aug 2025 10:42:43 +0800 Subject: [PATCH 1/8] add hetero parallel feature --- mindspeed_mm/arguments.py | 4 + mindspeed_mm/models/vlm_model.py | 57 ++++- mindspeed_mm/utils/hetero_parallel.py | 307 ++++++++++++++++++++++++++ pretrain_vlm.py | 8 + 4 files changed, 374 insertions(+), 2 deletions(-) create mode 100644 mindspeed_mm/utils/hetero_parallel.py diff --git a/mindspeed_mm/arguments.py b/mindspeed_mm/arguments.py index f3f24abd..dec8f3a3 100644 --- a/mindspeed_mm/arguments.py +++ b/mindspeed_mm/arguments.py @@ -98,6 +98,10 @@ def _add_training_args(parser): group.add_argument('--recompute-num-layers-skip-core-attention', type=int, default=0) + group.add_argument('--hetero-parallel', + action='store_true', + default=False, + help='apply different parallelism to different models') return parser diff --git a/mindspeed_mm/models/vlm_model.py b/mindspeed_mm/models/vlm_model.py index a0bd40b0..32eca2ee 100644 --- a/mindspeed_mm/models/vlm_model.py +++ b/mindspeed_mm/models/vlm_model.py @@ -10,7 +10,7 @@ from megatron.core.tensor_parallel import scatter_to_sequence_parallel_region from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region -from megatron.training import get_args +from megatron.training import get_args, print_rank_0 from megatron.training.arguments import core_transformer_config_from_args from mindspeed_mm.models.common.module_spec.get_layer_spec import get_vit_layer_spec, get_llm_layer_spec, \ @@ -21,6 +21,7 @@ from mindspeed_mm.models.common.module import MultiModalModule from mindspeed_mm.models.text_encoder.text_encoder import TextEncoder from mindspeed_mm.models.common.mm_gpt_model import MMGPTModel from mindspeed_mm.models.vision.vlm_attentionmask_for_llm import prepare_positionsids_mask_for_llm +from mindspeed_mm.utils.hetero_parallel import change_parallel_state class VLMModel(MultiModalModule): @@ -49,8 +50,9 @@ class VLMModel(MultiModalModule): def __init__(self, config) -> None: super().__init__(config=config) + args = get_args() - self.config = core_transformer_config_from_args(get_args()) + self.config = core_transformer_config_from_args(args) self.pre_process: bool = config.pre_process self.post_process: bool = config.post_process self.reward_process: bool = getattr(config, 'reward_process', False) @@ -90,6 +92,9 @@ class VLMModel(MultiModalModule): if self.add_audio_encoder: self.audio_encoder = self._build_audio_encoder_model(config.audio_encoder) + if args.hetero_parallel: + change_parallel_state('text_decoder') + def shared_embedding_or_output_weight(self): """ This is a convenience method to surface the language model's word embeddings, which is @@ -103,6 +108,22 @@ class VLMModel(MultiModalModule): vit_layer_spec = get_vit_layer_spec(config.vision_encoder) proj_layer_spec = get_projector_layer_spec(config.vision_projector) + if get_args().hetero_parallel: + change_parallel_state('image_encoder') + + self.pp_size = mpu.get_pipeline_model_parallel_world_size() + self.enable_vp = mpu.get_virtual_pipeline_model_parallel_world_size() is not None + if self.enable_vp: + self.vp_rank = mpu.get_virtual_pipeline_model_parallel_rank() + self.vp_size = mpu.get_virtual_pipeline_model_parallel_world_size() + self.pp_rank = mpu.get_pipeline_model_parallel_rank() + print_rank_0(f'initial: image_encoder pp size is {self.pp_size}') + print_rank_0(f'initial: 
image_encoder tp size is {mpu.get_tensor_model_parallel_world_size()}') + print_rank_0(f'initial: image_encoder dp size is {mpu.get_data_parallel_world_size()}') + self.image_encoder_DATA_PARALLEL_WORLD_SIZE= mpu.get_data_parallel_world_size() + self.image_encoder_DATA_PARALLEL_GROUP= mpu.get_data_parallel_group() + self.image_encoder_DATA_PARALLEL_RANK= mpu.get_data_parallel_rank() + if self.pp_size <= 1: return VisionModel( config=config, @@ -168,6 +189,21 @@ class VLMModel(MultiModalModule): def _build_audio_encoder_model(self, config): audio_layer_spec = get_audio_layer_spec(config.audio_encoder) + if get_args().hetero_parallel: + change_parallel_state('audio_encoder') + self.pp_size = mpu.get_pipeline_model_parallel_world_size() + self.enable_vp = mpu.get_virtual_pipeline_model_parallel_world_size() is not None + if self.enable_vp: + self.vp_rank = mpu.get_virtual_pipeline_model_parallel_rank() + self.vp_size = mpu.get_virtual_pipeline_model_parallel_world_size() + self.pp_rank = mpu.get_pipeline_model_parallel_rank() + print_rank_0(f'initial: audio_encoder pp size is {self.pp_size}') + print_rank_0(f'initial: audio_encoder tp size is {mpu.get_tensor_model_parallel_world_size()}') + print_rank_0(f'initial: audio_encoder dp size is {mpu.get_data_parallel_world_size()}') + self.audio_encoder_DATA_PARALLEL_WORLD_SIZE= mpu.get_data_parallel_world_size() + self.audio_encoder_DATA_PARALLEL_GROUP= mpu.get_data_parallel_group() + self.audio_encoder_DATA_PARALLEL_RANK= mpu.get_data_parallel_rank() + if self.pp_size <= 1: return AudioModel( config=config, @@ -229,6 +265,23 @@ class VLMModel(MultiModalModule): def _build_text_decoder_model(self, config): + if get_args().hetero_parallel: + change_parallel_state('text_decoder') + self.pre_process = mpu.is_pipeline_first_stage() + self.post_process = mpu.is_pipeline_last_stage() + self.pp_size = mpu.get_pipeline_model_parallel_world_size() + self.enable_vp = mpu.get_virtual_pipeline_model_parallel_world_size() is not None + if self.enable_vp: + self.vp_rank = mpu.get_virtual_pipeline_model_parallel_rank() + self.vp_size = mpu.get_virtual_pipeline_model_parallel_world_size() + self.pp_rank = mpu.get_pipeline_model_parallel_rank() + print_rank_0(f'initial: text_decoder pp size is {self.pp_size}') + print_rank_0(f'initial: text_decoder tp size is {mpu.get_tensor_model_parallel_world_size()}') + print_rank_0(f'initial: text_decoder dp size is {mpu.get_data_parallel_world_size()}') + self.text_decoder_DATA_PARALLEL_WORLD_SIZE= mpu.get_data_parallel_world_size() + self.text_decoder_DATA_PARALLEL_GROUP= mpu.get_data_parallel_group() + self.text_decoder_DATA_PARALLEL_RANK= mpu.get_data_parallel_rank() + if self.pp_size <= 1: return MMGPTModel( config=config, diff --git a/mindspeed_mm/utils/hetero_parallel.py b/mindspeed_mm/utils/hetero_parallel.py new file mode 100644 index 00000000..c304c429 --- /dev/null +++ b/mindspeed_mm/utils/hetero_parallel.py @@ -0,0 +1,307 @@ +import inspect +from functools import wraps +from copy import deepcopy + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from megatron.training import get_args +from megatron.core.parallel_state import initialize_model_parallel, is_initialized +import megatron.core.parallel_state as mpu + + +_Parallel_States_Dict = {} + + +def apply_hetero_parallel_hooks(model): + + if hasattr(model, 'image_encoder'): + model.image_encoder.register_forward_pre_hook(image_encoder_forward_pre_hook) + model.image_encoder.register_forward_hook(image_encoder_forward_hook) + if 
hasattr(model, 'audio_encoder'): + model.audio_encoder.register_forward_pre_hook(audio_encoder_forward_pre_hook) + model.audio_encoder.register_forward_hook(audio_encoder_forward_hook) + + + +def image_encoder_forward_pre_hook(module, input): + pixel_values, image_grid_thw = input + + change_parallel_state('text_decoder') + pixel_values, _ = all_gather_DP_group(pixel_values, pad_dim=0, remove_padding=True) + image_grid_thw, _ = all_gather_DP_group(image_grid_thw) + change_parallel_state('image_encoder') + + chunk_seq_lens = torch.stack( + [chunk.prod(dim=1).sum() + for chunk in + torch.chunk(image_grid_thw, chunks=mpu.get_data_parallel_world_size(), dim=0 + )]).tolist() + pixel_values = split_tensor_DP_group(pixel_values, pad_dim=0, chunk_seq_lens=chunk_seq_lens) # [B, S] + + image_grid_thw = split_tensor_DP_group(image_grid_thw, split_dim=0) + + return pixel_values, image_grid_thw + + +def image_encoder_forward_hook(module, input, output): + output, all_lens = all_gather_DP_group(output, cat_dim=0, pad_dim=0, remove_padding=True) + + change_parallel_state('text_decoder') + all_lens = [sum(all_lens[i:i+len(all_lens)//mpu.get_data_parallel_world_size()]) + for i in + range(0, len(all_lens), len(all_lens)//mpu.get_data_parallel_world_size())] + output = split_tensor_DP_group(output, pad_dim=0, split_dim=0, chunk_seq_lens=all_lens) + + + return output + + +def audio_encoder_forward_pre_hook(module, input): + input_features, feature_attention_mask = input + change_parallel_state('text_decoder') + input_features, _ = all_gather_DP_group(input_features) + feature_attention_mask, _ = all_gather_DP_group(feature_attention_mask) + change_parallel_state('audio_encoder') + input_features = split_tensor_DP_group(input_features) + feature_attention_mask = split_tensor_DP_group(feature_attention_mask) + + return input_features, feature_attention_mask + + +def audio_encoder_forward_hook(module, input, output): + output, all_lens = all_gather_DP_group(output, pad_token_id=0.0, cat_dim=0, pad_dim=0, remove_padding=True) + change_parallel_state('text_decoder') + all_lens = [sum(all_lens[i:i+len(all_lens)//mpu.get_data_parallel_world_size()]) + for i in + range(0, len(all_lens), len(all_lens)//mpu.get_data_parallel_world_size())] + output = split_tensor_DP_group(output, pad_dim=0, split_dim=0, chunk_seq_lens=all_lens) + + return output + + +def parallel_config_extract(args_dict): + targets = ["TP", "CP", "PP"] + results = [] + + def dfs(curr, par_key=None): + if isinstance(curr, dict): + if all(k in curr for k in targets) and par_key: + results.append({ + par_key: {k: curr[k] for k in targets} + }) + for k, v in curr.items(): + dfs(v, k) + elif isinstance(curr, list): + for item in curr: + dfs(item, par_key) + + dfs(args_dict) + return results + + +def initial_modules_mpu(reuse_module, args): + args_dict = args.to_dict() + + if is_initialized: + _Parallel_States_Dict[reuse_module] = {} + state_snapshot = { + k: v for k, v in vars((mpu)).items() + if k.startswith('_') and not k.startswith('__') and not inspect.isfunction(v) + } + _Parallel_States_Dict[reuse_module].update(state_snapshot) + + parallel_configs = parallel_config_extract(args_dict) + for parallel_config in parallel_configs: + module = next(iter(parallel_config)) + TP = parallel_config[module]["TP"] + CP = parallel_config[module]["CP"] + PP = parallel_config[module]["PP"] + + if module not in _Parallel_States_Dict: + _Parallel_States_Dict[module] = {} + mpu.destroy_model_parallel() + initialize_model_parallel( + tensor_model_parallel_size=TP, + 
pipeline_model_parallel_size=PP, + virtual_pipeline_model_parallel_size=None, + pipeline_model_parallel_split_rank=None, + use_sharp=False, + context_parallel_size=CP, + expert_model_parallel_size=1, + nccl_communicator_config_path=None, + distributed_timeout_minutes=30, + order="tp-cp-ep-dp-pp") + + state_snapshot = { + k: v for k, v in vars((mpu)).items() + if k.startswith('_') and not k.startswith('__') and not inspect.isfunction(v) + } + _Parallel_States_Dict[module].update(state_snapshot) + + +def change_parallel_state(module): + target_globals = vars(mpu) + source_globals = _Parallel_States_Dict[module] + + for k, v in source_globals.items(): + if k in target_globals: + target_globals[k] = v + + +def initial_megatron_hetero_parallel_wrapper(fn): + print('initial_megatron_hetero_parallel_wrapper activated') + @wraps(fn) + def wrapper(*args, **kwargs): + fn(*args, **kwargs) + args = get_args() + if args.hetero_parallel: + modules = ['image_encoder', 'audio_encoder', 'text_decoder'] + vlm_config = deepcopy(args.mm.model) + from pretrain_vlm import _configure_modules + _configure_modules(vlm_config, modules) + initial_modules_mpu(reuse_module='text_decoder', args=vlm_config) + return + return wrapper + +from mindspeed_mm import training +training.initialize_megatron = \ + initial_megatron_hetero_parallel_wrapper(training.initialize_megatron) + + +def all_gather_DP_group(tensor, + pad_token_id=None, + cat_dim=0, + pad_dim=1, + remove_padding=False, + parallel_state=None, + ): + """Gather tensors + 暂时只支持BSH、BD + """ + + if parallel_state is None: + group = mpu.get_data_parallel_group() + world_size = mpu.get_data_parallel_world_size() + else: + group = parallel_state['_DATA_PARALLEL_GROUP'] + world_size = torch.distributed.get_world_size(group=group) + if tensor is None: + return None, None + + if world_size == 1: + return tensor, None + + if pad_token_id is not None or remove_padding: + pad_token_id = 0 if pad_token_id is None else pad_token_id + local_len = torch.tensor([tensor.shape[pad_dim]], device='cuda') + all_lens = [torch.zeros_like(local_len) for _ in range(world_size)] + + dist.all_gather(all_lens, local_len, group=group) + all_lens = [l.item() for l in all_lens] + max_len = max(all_lens) + + pad_size = max_len - local_len + if pad_size > 0: + pad_dims = [0] * (2 * tensor.dim()) + # pad_dims: [B, S, H], [D_left, D_right, S_left, S_right, H_left, H_right] + pad_dims[2 * (tensor.dim() - pad_dim) - 1] = pad_size + tensor = F.pad(tensor, pad_dims, value=pad_token_id) + + if tensor.requires_grad ==True: + if remove_padding: + raise NotImplementedError('tensors that require grad and need removing padding are not implemented') + output = _AllGatherDp.apply(tensor, cat_dim) + else: + gathered = [torch.zeros_like(tensor) for _ in range(world_size)] + dist.all_gather(gathered, tensor, group=group) + + if remove_padding: + gathered = [g[:l] for g, l in zip(gathered, all_lens)] + output = torch.cat(gathered, dim=cat_dim).contiguous() + + if remove_padding: + return output, all_lens + return output, None + + +def split_tensor_DP_group(tensor, + split_dim=0, + pad_dim=1, + chunk_seq_lens=None, + all_lens=None, + parallel_state=None): + """split tensors + 暂时只支持bsh + chunk_seq_lens: split tensor sliding chunk_seq_lens + all_lens: all tensor origin lens(cat_dim) + if all_lens is None, split tensor per device equal or not remove padding, + if all_lens is not None, remove padding intra-dp, do not remove padding inter-dp + """ + + if parallel_state is None: + world_size = 
mpu.get_data_parallel_world_size() + group = mpu.get_data_parallel_group() + else: + group = parallel_state['_DATA_PARALLEL_GROUP'] + world_size = torch.distributed.get_world_size(group=group) + + if tensor is None: + return None + + if world_size == 1: + return tensor + + rank = torch.distributed.get_rank(group) + + if chunk_seq_lens: + chunk = torch.split(tensor, dim=split_dim, split_size_or_sections=chunk_seq_lens)[rank] + else: + chunks = torch.chunk(tensor, world_size, dim=split_dim) + chunk = chunks[rank] + if all_lens is not None: + # for not equal split, need remove padding + local_lens_num = len(all_lens) // world_size + start_idx = rank * local_lens_num + end_idx = start_idx + local_lens_num + local_lens = all_lens[start_idx: end_idx] + index = [slice(None)] * chunk.ndim + index[pad_dim] = slice(0, max(local_lens)) # for inner-mbs, not remove padding + chunk = chunk[tuple(index)] + return chunk + + +class _AllGatherDp(torch.autograd.Function): + """ + all gahter for dp for diff cat dim and padding dim + """ + @staticmethod + def forward(ctx, _input, cat_dim=0): + group = mpu.get_data_parallel_group() + world_size = mpu.get_data_parallel_world_size() + group_rank = torch.distributed.get_rank(group) + ctx.world_size = world_size + ctx.group = group + ctx.group_rank = group_rank + ctx.cat_dim = cat_dim + ctx.original_batch_size = _input.shape[cat_dim] + + + gathered = [torch.zeros_like(_input) for _ in range(world_size)] + dist.all_gather(gathered, _input, group=group) + output = torch.cat(gathered, dim=cat_dim).contiguous() + return output + + @staticmethod + def backward(ctx, grad_output): + world_size, group, group_rank, cat_dim, original_batch_size \ + = ctx.world_size, ctx.group, ctx.group_rank, ctx.cat_dim, ctx.original_batch_size, \ + + start = group_rank * original_batch_size + end = start + original_batch_size + + idx = [slice(None)] * grad_output.dim() + idx[cat_dim] = slice(start, end) + grad_input = grad_output[tuple(idx)] + + return grad_input, None diff --git a/pretrain_vlm.py b/pretrain_vlm.py index e6382182..b39f5039 100644 --- a/pretrain_vlm.py +++ b/pretrain_vlm.py @@ -18,6 +18,7 @@ from mindspeed_mm.models.vlm_model import VLMModel from mindspeed_mm.patchs import dummy_optimizer_patch from mindspeed_mm.training import pretrain from mindspeed_mm.utils.transformer_model_config import get_model_config +from mindspeed_mm.utils.hetero_parallel import change_parallel_state, apply_hetero_parallel_hooks mindspeed_args = get_mindspeed_args() if hasattr(mindspeed_args, "ai_framework") and mindspeed_args.ai_framework == "mindspore" and mindspeed_args.optimization_level >= 0: import mindspeed_mm.mindspore.mindspore_adaptor @@ -40,6 +41,10 @@ def model_provider(pre_process=True, post_process=True, modules=None): model = VLMModel(vlm_config) + if args.hetero_parallel: + print_rank_0("apply hetero parallel ...") + apply_hetero_parallel_hooks(model) + _apply_freezing(model, vlm_config) return model @@ -144,6 +149,9 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): """Build train, valid, and test datasets.""" args = get_args() data_config = args.mm.data + if args.hetero_parallel: + print_rank_0("change parallel state for data loader ...") + change_parallel_state("text_decoder") train_dataset = build_mm_dataset(data_config.dataset_param) train_dataloader = build_mm_dataloader(train_dataset, data_config.dataloader_param, process_group=mpu.get_data_parallel_group(), -- Gitee From acd738b7d158608e35b84a8d0168c713d0c5fefe Mon Sep 17 00:00:00 2001 From: cxiaolong 
<2845907121@qq.com> Date: Tue, 26 Aug 2025 10:55:15 +0800 Subject: [PATCH 2/8] add model.json --- examples/qwen2.5omni/model_7b.json | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/qwen2.5omni/model_7b.json b/examples/qwen2.5omni/model_7b.json index 0573c843..51519a22 100644 --- a/examples/qwen2.5omni/model_7b.json +++ b/examples/qwen2.5omni/model_7b.json @@ -52,7 +52,10 @@ "freeze": true, "layernorm_epsilon": 1e-06, "normalization": "RMSNorm" - } + }, + "TP":1, + "PP":1, + "CP":1 }, "audio_encoder": { "audio_encoder": { @@ -79,7 +82,10 @@ "n_window": 100, "scale_embedding": false, "output_dim": 3584 - } + }, + "TP":1, + "PP":1, + "CP":1 }, "text_decoder": { "model_id": "qwen2_5_omni_thinker", -- Gitee From a4ddce0611c178bb032387150765436e9f8548e7 Mon Sep 17 00:00:00 2001 From: cxiaolong <2845907121@qq.com> Date: Wed, 27 Aug 2025 10:42:36 +0800 Subject: [PATCH 3/8] fix codecheck --- examples/qwen2.5omni/model_7b.json | 12 +++--- mindspeed_mm/models/vlm_model.py | 9 ----- mindspeed_mm/utils/hetero_parallel.py | 57 ++++++++++++++------------- 3 files changed, 35 insertions(+), 43 deletions(-) diff --git a/examples/qwen2.5omni/model_7b.json b/examples/qwen2.5omni/model_7b.json index 51519a22..8079cca0 100644 --- a/examples/qwen2.5omni/model_7b.json +++ b/examples/qwen2.5omni/model_7b.json @@ -53,9 +53,9 @@ "layernorm_epsilon": 1e-06, "normalization": "RMSNorm" }, - "TP":1, - "PP":1, - "CP":1 + "tp":1, + "pp":1, + "cp":1 }, "audio_encoder": { "audio_encoder": { @@ -83,9 +83,9 @@ "scale_embedding": false, "output_dim": 3584 }, - "TP":1, - "PP":1, - "CP":1 + "tp":1, + "pp":1, + "cp":1 }, "text_decoder": { "model_id": "qwen2_5_omni_thinker", diff --git a/mindspeed_mm/models/vlm_model.py b/mindspeed_mm/models/vlm_model.py index 32eca2ee..99556aac 100644 --- a/mindspeed_mm/models/vlm_model.py +++ b/mindspeed_mm/models/vlm_model.py @@ -120,9 +120,6 @@ class VLMModel(MultiModalModule): print_rank_0(f'initial: image_encoder pp size is {self.pp_size}') print_rank_0(f'initial: image_encoder tp size is {mpu.get_tensor_model_parallel_world_size()}') print_rank_0(f'initial: image_encoder dp size is {mpu.get_data_parallel_world_size()}') - self.image_encoder_DATA_PARALLEL_WORLD_SIZE= mpu.get_data_parallel_world_size() - self.image_encoder_DATA_PARALLEL_GROUP= mpu.get_data_parallel_group() - self.image_encoder_DATA_PARALLEL_RANK= mpu.get_data_parallel_rank() if self.pp_size <= 1: return VisionModel( @@ -200,9 +197,6 @@ class VLMModel(MultiModalModule): print_rank_0(f'initial: audio_encoder pp size is {self.pp_size}') print_rank_0(f'initial: audio_encoder tp size is {mpu.get_tensor_model_parallel_world_size()}') print_rank_0(f'initial: audio_encoder dp size is {mpu.get_data_parallel_world_size()}') - self.audio_encoder_DATA_PARALLEL_WORLD_SIZE= mpu.get_data_parallel_world_size() - self.audio_encoder_DATA_PARALLEL_GROUP= mpu.get_data_parallel_group() - self.audio_encoder_DATA_PARALLEL_RANK= mpu.get_data_parallel_rank() if self.pp_size <= 1: return AudioModel( @@ -278,9 +272,6 @@ class VLMModel(MultiModalModule): print_rank_0(f'initial: text_decoder pp size is {self.pp_size}') print_rank_0(f'initial: text_decoder tp size is {mpu.get_tensor_model_parallel_world_size()}') print_rank_0(f'initial: text_decoder dp size is {mpu.get_data_parallel_world_size()}') - self.text_decoder_DATA_PARALLEL_WORLD_SIZE= mpu.get_data_parallel_world_size() - self.text_decoder_DATA_PARALLEL_GROUP= mpu.get_data_parallel_group() - self.text_decoder_DATA_PARALLEL_RANK= mpu.get_data_parallel_rank() 
if self.pp_size <= 1: return MMGPTModel( diff --git a/mindspeed_mm/utils/hetero_parallel.py b/mindspeed_mm/utils/hetero_parallel.py index c304c429..5a078d1e 100644 --- a/mindspeed_mm/utils/hetero_parallel.py +++ b/mindspeed_mm/utils/hetero_parallel.py @@ -8,9 +8,11 @@ import torch.nn.functional as F from megatron.training import get_args from megatron.core.parallel_state import initialize_model_parallel, is_initialized import megatron.core.parallel_state as mpu +from mindspeed_mm import training -_Parallel_States_Dict = {} +_ParallelStatesDict = {} +_HeteroParallelModules = ['image_encoder', 'audio_encoder', 'text_decoder'] def apply_hetero_parallel_hooks(model): @@ -28,8 +30,8 @@ def image_encoder_forward_pre_hook(module, input): pixel_values, image_grid_thw = input change_parallel_state('text_decoder') - pixel_values, _ = all_gather_DP_group(pixel_values, pad_dim=0, remove_padding=True) - image_grid_thw, _ = all_gather_DP_group(image_grid_thw) + pixel_values, _ = all_gather_dp_group(pixel_values, pad_dim=0, remove_padding=True) + image_grid_thw, _ = all_gather_dp_group(image_grid_thw) change_parallel_state('image_encoder') chunk_seq_lens = torch.stack( @@ -37,21 +39,21 @@ def image_encoder_forward_pre_hook(module, input): for chunk in torch.chunk(image_grid_thw, chunks=mpu.get_data_parallel_world_size(), dim=0 )]).tolist() - pixel_values = split_tensor_DP_group(pixel_values, pad_dim=0, chunk_seq_lens=chunk_seq_lens) # [B, S] + pixel_values = split_tensor_dp_group(pixel_values, pad_dim=0, chunk_seq_lens=chunk_seq_lens) # [B, S] - image_grid_thw = split_tensor_DP_group(image_grid_thw, split_dim=0) + image_grid_thw = split_tensor_dp_group(image_grid_thw, split_dim=0) return pixel_values, image_grid_thw def image_encoder_forward_hook(module, input, output): - output, all_lens = all_gather_DP_group(output, cat_dim=0, pad_dim=0, remove_padding=True) + output, all_lens = all_gather_dp_group(output, cat_dim=0, pad_dim=0, remove_padding=True) change_parallel_state('text_decoder') - all_lens = [sum(all_lens[i:i+len(all_lens)//mpu.get_data_parallel_world_size()]) + all_lens = [sum(all_lens[i: i + len(all_lens) // mpu.get_data_parallel_world_size()]) for i in - range(0, len(all_lens), len(all_lens)//mpu.get_data_parallel_world_size())] - output = split_tensor_DP_group(output, pad_dim=0, split_dim=0, chunk_seq_lens=all_lens) + range(0, len(all_lens), len(all_lens) // mpu.get_data_parallel_world_size())] + output = split_tensor_dp_group(output, pad_dim=0, split_dim=0, chunk_seq_lens=all_lens) return output @@ -60,22 +62,22 @@ def image_encoder_forward_hook(module, input, output): def audio_encoder_forward_pre_hook(module, input): input_features, feature_attention_mask = input change_parallel_state('text_decoder') - input_features, _ = all_gather_DP_group(input_features) - feature_attention_mask, _ = all_gather_DP_group(feature_attention_mask) + input_features, _ = all_gather_dp_group(input_features) + feature_attention_mask, _ = all_gather_dp_group(feature_attention_mask) change_parallel_state('audio_encoder') - input_features = split_tensor_DP_group(input_features) - feature_attention_mask = split_tensor_DP_group(feature_attention_mask) + input_features = split_tensor_dp_group(input_features) + feature_attention_mask = split_tensor_dp_group(feature_attention_mask) return input_features, feature_attention_mask def audio_encoder_forward_hook(module, input, output): - output, all_lens = all_gather_DP_group(output, pad_token_id=0.0, cat_dim=0, pad_dim=0, remove_padding=True) + output, all_lens = 
all_gather_dp_group(output, pad_token_id=0.0, cat_dim=0, pad_dim=0, remove_padding=True) change_parallel_state('text_decoder') - all_lens = [sum(all_lens[i:i+len(all_lens)//mpu.get_data_parallel_world_size()]) + all_lens = [sum(all_lens[i: i+len(all_lens) // mpu.get_data_parallel_world_size()]) for i in - range(0, len(all_lens), len(all_lens)//mpu.get_data_parallel_world_size())] - output = split_tensor_DP_group(output, pad_dim=0, split_dim=0, chunk_seq_lens=all_lens) + range(0, len(all_lens), len(all_lens) // mpu.get_data_parallel_world_size())] + output = split_tensor_dp_group(output, pad_dim=0, split_dim=0, chunk_seq_lens=all_lens) return output @@ -104,12 +106,12 @@ def initial_modules_mpu(reuse_module, args): args_dict = args.to_dict() if is_initialized: - _Parallel_States_Dict[reuse_module] = {} + _ParallelStatesDict[reuse_module] = {} state_snapshot = { k: v for k, v in vars((mpu)).items() if k.startswith('_') and not k.startswith('__') and not inspect.isfunction(v) } - _Parallel_States_Dict[reuse_module].update(state_snapshot) + _ParallelStatesDict[reuse_module].update(state_snapshot) parallel_configs = parallel_config_extract(args_dict) for parallel_config in parallel_configs: @@ -118,8 +120,8 @@ def initial_modules_mpu(reuse_module, args): CP = parallel_config[module]["CP"] PP = parallel_config[module]["PP"] - if module not in _Parallel_States_Dict: - _Parallel_States_Dict[module] = {} + if module not in _ParallelStatesDict: + _ParallelStatesDict[module] = {} mpu.destroy_model_parallel() initialize_model_parallel( tensor_model_parallel_size=TP, @@ -137,12 +139,12 @@ def initial_modules_mpu(reuse_module, args): k: v for k, v in vars((mpu)).items() if k.startswith('_') and not k.startswith('__') and not inspect.isfunction(v) } - _Parallel_States_Dict[module].update(state_snapshot) + _ParallelStatesDict[module].update(state_snapshot) def change_parallel_state(module): target_globals = vars(mpu) - source_globals = _Parallel_States_Dict[module] + source_globals = _ParallelStatesDict[module] for k, v in source_globals.items(): if k in target_globals: @@ -156,20 +158,19 @@ def initial_megatron_hetero_parallel_wrapper(fn): fn(*args, **kwargs) args = get_args() if args.hetero_parallel: - modules = ['image_encoder', 'audio_encoder', 'text_decoder'] vlm_config = deepcopy(args.mm.model) from pretrain_vlm import _configure_modules - _configure_modules(vlm_config, modules) + _configure_modules(vlm_config, _HeteroParallelModules) initial_modules_mpu(reuse_module='text_decoder', args=vlm_config) return return wrapper -from mindspeed_mm import training + training.initialize_megatron = \ initial_megatron_hetero_parallel_wrapper(training.initialize_megatron) -def all_gather_DP_group(tensor, +def all_gather_dp_group(tensor, pad_token_id=None, cat_dim=0, pad_dim=1, @@ -225,7 +226,7 @@ def all_gather_DP_group(tensor, return output, None -def split_tensor_DP_group(tensor, +def split_tensor_dp_group(tensor, split_dim=0, pad_dim=1, chunk_seq_lens=None, -- Gitee From 408d5160c372db1e82b6565eff6a2efe69cfceb0 Mon Sep 17 00:00:00 2001 From: cxiaolong <2845907121@qq.com> Date: Wed, 27 Aug 2025 15:08:44 +0800 Subject: [PATCH 4/8] fix codecheck --- mindspeed_mm/utils/hetero_parallel.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/mindspeed_mm/utils/hetero_parallel.py b/mindspeed_mm/utils/hetero_parallel.py index 5a078d1e..651b9636 100644 --- a/mindspeed_mm/utils/hetero_parallel.py +++ b/mindspeed_mm/utils/hetero_parallel.py @@ -5,7 +5,7 @@ from copy import deepcopy import 
torch import torch.distributed as dist import torch.nn.functional as F -from megatron.training import get_args +from megatron.training import get_args, print_rank_0 from megatron.core.parallel_state import initialize_model_parallel, is_initialized import megatron.core.parallel_state as mpu from mindspeed_mm import training @@ -74,7 +74,7 @@ def audio_encoder_forward_pre_hook(module, input): def audio_encoder_forward_hook(module, input, output): output, all_lens = all_gather_dp_group(output, pad_token_id=0.0, cat_dim=0, pad_dim=0, remove_padding=True) change_parallel_state('text_decoder') - all_lens = [sum(all_lens[i: i+len(all_lens) // mpu.get_data_parallel_world_size()]) + all_lens = [sum(all_lens[i: i + len(all_lens) // mpu.get_data_parallel_world_size()]) for i in range(0, len(all_lens), len(all_lens) // mpu.get_data_parallel_world_size())] output = split_tensor_dp_group(output, pad_dim=0, split_dim=0, chunk_seq_lens=all_lens) @@ -108,8 +108,8 @@ def initial_modules_mpu(reuse_module, args): if is_initialized: _ParallelStatesDict[reuse_module] = {} state_snapshot = { - k: v for k, v in vars((mpu)).items() - if k.startswith('_') and not k.startswith('__') and not inspect.isfunction(v) + k: v for k, v in vars((mpu)).items() + if k.startswith('_') and not k.startswith('__') and not inspect.isfunction(v) } _ParallelStatesDict[reuse_module].update(state_snapshot) @@ -152,7 +152,8 @@ def change_parallel_state(module): def initial_megatron_hetero_parallel_wrapper(fn): - print('initial_megatron_hetero_parallel_wrapper activated') + print_rank_0('initial_megatron_hetero_parallel_wrapper activated') + @wraps(fn) def wrapper(*args, **kwargs): fn(*args, **kwargs) @@ -209,7 +210,7 @@ def all_gather_dp_group(tensor, pad_dims[2 * (tensor.dim() - pad_dim) - 1] = pad_size tensor = F.pad(tensor, pad_dims, value=pad_token_id) - if tensor.requires_grad ==True: + if tensor.requires_grad: if remove_padding: raise NotImplementedError('tensors that require grad and need removing padding are not implemented') output = _AllGatherDp.apply(tensor, cat_dim) -- Gitee From 8f2ee338ab46507c74d28e43b891e42a4128b4e4 Mon Sep 17 00:00:00 2001 From: cxiaolong <2845907121@qq.com> Date: Wed, 27 Aug 2025 17:39:05 +0800 Subject: [PATCH 5/8] fix codecheck --- mindspeed_mm/utils/hetero_parallel.py | 44 +++++++++++++++------------ pretrain_vlm.py | 8 +---- 2 files changed, 25 insertions(+), 27 deletions(-) diff --git a/mindspeed_mm/utils/hetero_parallel.py b/mindspeed_mm/utils/hetero_parallel.py index 651b9636..56814c71 100644 --- a/mindspeed_mm/utils/hetero_parallel.py +++ b/mindspeed_mm/utils/hetero_parallel.py @@ -25,7 +25,6 @@ def apply_hetero_parallel_hooks(model): model.audio_encoder.register_forward_hook(audio_encoder_forward_hook) - def image_encoder_forward_pre_hook(module, input): pixel_values, image_grid_thw = input @@ -34,13 +33,12 @@ def image_encoder_forward_pre_hook(module, input): image_grid_thw, _ = all_gather_dp_group(image_grid_thw) change_parallel_state('image_encoder') - chunk_seq_lens = torch.stack( - [chunk.prod(dim=1).sum() - for chunk in - torch.chunk(image_grid_thw, chunks=mpu.get_data_parallel_world_size(), dim=0 - )]).tolist() - pixel_values = split_tensor_dp_group(pixel_values, pad_dim=0, chunk_seq_lens=chunk_seq_lens) # [B, S] + chunk_seq_lens = [] + for chunk in torch.chunk(image_grid_thw, chunks=mpu.get_data_parallel_world_size(), dim=0): + chunk_seq_lens.append(chunk.prod(dim=1).sum()) + chunk_seq_lens = torch.stack(chunk_seq_lens).tolist() + pixel_values = split_tensor_dp_group(pixel_values, 
pad_dim=0, chunk_seq_lens=chunk_seq_lens) # [B, S] image_grid_thw = split_tensor_dp_group(image_grid_thw, split_dim=0) return pixel_values, image_grid_thw @@ -50,9 +48,12 @@ def image_encoder_forward_hook(module, input, output): output, all_lens = all_gather_dp_group(output, cat_dim=0, pad_dim=0, remove_padding=True) change_parallel_state('text_decoder') - all_lens = [sum(all_lens[i: i + len(all_lens) // mpu.get_data_parallel_world_size()]) - for i in - range(0, len(all_lens), len(all_lens) // mpu.get_data_parallel_world_size())] + + all_lens = [] + for i in range(0, len(all_lens), len(all_lens) // mpu.get_data_parallel_world_size()): + len = sum(all_lens[i: i + len(all_lens) // mpu.get_data_parallel_world_size()]) + all_lens.append(len) + output = split_tensor_dp_group(output, pad_dim=0, split_dim=0, chunk_seq_lens=all_lens) @@ -74,9 +75,12 @@ def audio_encoder_forward_pre_hook(module, input): def audio_encoder_forward_hook(module, input, output): output, all_lens = all_gather_dp_group(output, pad_token_id=0.0, cat_dim=0, pad_dim=0, remove_padding=True) change_parallel_state('text_decoder') - all_lens = [sum(all_lens[i: i + len(all_lens) // mpu.get_data_parallel_world_size()]) - for i in - range(0, len(all_lens), len(all_lens) // mpu.get_data_parallel_world_size())] + all_lens = [] + + for i in range(0, len(all_lens), len(all_lens) // mpu.get_data_parallel_world_size()): + len = sum(all_lens[i: i + len(all_lens) // mpu.get_data_parallel_world_size()]) + all_lens.append(len) + output = split_tensor_dp_group(output, pad_dim=0, split_dim=0, chunk_seq_lens=all_lens) return output @@ -200,7 +204,7 @@ def all_gather_dp_group(tensor, all_lens = [torch.zeros_like(local_len) for _ in range(world_size)] dist.all_gather(all_lens, local_len, group=group) - all_lens = [l.item() for l in all_lens] + all_lens = [length.item() for length in all_lens] max_len = max(all_lens) pad_size = max_len - local_len @@ -219,7 +223,7 @@ def all_gather_dp_group(tensor, dist.all_gather(gathered, tensor, group=group) if remove_padding: - gathered = [g[:l] for g, l in zip(gathered, all_lens)] + gathered = [g[:length] for g, length in zip(gathered, all_lens)] output = torch.cat(gathered, dim=cat_dim).contiguous() if remove_padding: @@ -228,11 +232,11 @@ def all_gather_dp_group(tensor, def split_tensor_dp_group(tensor, - split_dim=0, - pad_dim=1, - chunk_seq_lens=None, - all_lens=None, - parallel_state=None): + split_dim=0, + pad_dim=1, + chunk_seq_lens=None, + all_lens=None, + parallel_state=None): """split tensors 暂时只支持bsh chunk_seq_lens: split tensor sliding chunk_seq_lens diff --git a/pretrain_vlm.py b/pretrain_vlm.py index 478d1a9a..c5a4e296 100644 --- a/pretrain_vlm.py +++ b/pretrain_vlm.py @@ -154,13 +154,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): if args.hetero_parallel: print_rank_0("change parallel state for data loader ...") change_parallel_state("text_decoder") - train_dataset = build_mm_dataset(data_config.dataset_param) - train_dataloader = build_mm_dataloader(train_dataset, data_config.dataloader_param, - process_group=mpu.get_data_parallel_group(), - dataset_param=data_config.dataset_param, - consumed_samples=args.consumed_train_samples, ) - train_dataloader, val_dataloader, test_dataloader = build_iterations(train_dataloader) - return train_dataloader, val_dataloader, test_dataloader + datasets = build_mm_dataset(data_config.dataset_param) build_dataloader = partial( build_mm_dataloader, -- Gitee From d372ed2aa683d92be52c9d17b051142c6297e659 Mon Sep 17 00:00:00 2001 From: 
cxiaolong <2845907121@qq.com> Date: Wed, 27 Aug 2025 18:13:31 +0800 Subject: [PATCH 6/8] fix codecheck --- mindspeed_mm/utils/hetero_parallel.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mindspeed_mm/utils/hetero_parallel.py b/mindspeed_mm/utils/hetero_parallel.py index 56814c71..d0cf5553 100644 --- a/mindspeed_mm/utils/hetero_parallel.py +++ b/mindspeed_mm/utils/hetero_parallel.py @@ -49,12 +49,12 @@ def image_encoder_forward_hook(module, input, output): change_parallel_state('text_decoder') - all_lens = [] + chunk_seq_lens = [] for i in range(0, len(all_lens), len(all_lens) // mpu.get_data_parallel_world_size()): len = sum(all_lens[i: i + len(all_lens) // mpu.get_data_parallel_world_size()]) - all_lens.append(len) + chunk_seq_lens.append(len) - output = split_tensor_dp_group(output, pad_dim=0, split_dim=0, chunk_seq_lens=all_lens) + output = split_tensor_dp_group(output, pad_dim=0, split_dim=0, chunk_seq_lens=chunk_seq_lens) return output @@ -75,13 +75,13 @@ def audio_encoder_forward_pre_hook(module, input): def audio_encoder_forward_hook(module, input, output): output, all_lens = all_gather_dp_group(output, pad_token_id=0.0, cat_dim=0, pad_dim=0, remove_padding=True) change_parallel_state('text_decoder') - all_lens = [] + chunk_seq_lens = [] for i in range(0, len(all_lens), len(all_lens) // mpu.get_data_parallel_world_size()): len = sum(all_lens[i: i + len(all_lens) // mpu.get_data_parallel_world_size()]) - all_lens.append(len) + chunk_seq_lens.append(len) - output = split_tensor_dp_group(output, pad_dim=0, split_dim=0, chunk_seq_lens=all_lens) + output = split_tensor_dp_group(output, pad_dim=0, split_dim=0, chunk_seq_lens=chunk_seq_lens) return output -- Gitee From eb97a9217ce41ea5cbe625b227dc16b6cb6b47a2 Mon Sep 17 00:00:00 2001 From: cxiaolong <2845907121@qq.com> Date: Wed, 27 Aug 2025 18:56:06 +0800 Subject: [PATCH 7/8] fix codecheck --- mindspeed_mm/utils/hetero_parallel.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mindspeed_mm/utils/hetero_parallel.py b/mindspeed_mm/utils/hetero_parallel.py index d0cf5553..85f2a885 100644 --- a/mindspeed_mm/utils/hetero_parallel.py +++ b/mindspeed_mm/utils/hetero_parallel.py @@ -50,8 +50,9 @@ def image_encoder_forward_hook(module, input, output): change_parallel_state('text_decoder') chunk_seq_lens = [] - for i in range(0, len(all_lens), len(all_lens) // mpu.get_data_parallel_world_size()): - len = sum(all_lens[i: i + len(all_lens) // mpu.get_data_parallel_world_size()]) + origin_len =len(all_lens) + for i in range(0, origin_len, origin_len // mpu.get_data_parallel_world_size()): + len = sum(all_lens[i: i + origin_len // mpu.get_data_parallel_world_size()]) chunk_seq_lens.append(len) output = split_tensor_dp_group(output, pad_dim=0, split_dim=0, chunk_seq_lens=chunk_seq_lens) @@ -77,8 +78,9 @@ def audio_encoder_forward_hook(module, input, output): change_parallel_state('text_decoder') chunk_seq_lens = [] - for i in range(0, len(all_lens), len(all_lens) // mpu.get_data_parallel_world_size()): - len = sum(all_lens[i: i + len(all_lens) // mpu.get_data_parallel_world_size()]) + origin_len = len(all_lens) + for i in range(0, origin_len, origin_len // mpu.get_data_parallel_world_size()): + len = sum(all_lens[i: i + origin_len // mpu.get_data_parallel_world_size()]) chunk_seq_lens.append(len) output = split_tensor_dp_group(output, pad_dim=0, split_dim=0, chunk_seq_lens=chunk_seq_lens) -- Gitee From 14e2e2ebc80a405152d458d6fa6507e957c29301 Mon Sep 17 00:00:00 2001 From: cxiaolong 
<2845907121@qq.com> Date: Wed, 27 Aug 2025 19:22:18 +0800 Subject: [PATCH 8/8] fix codecheck --- mindspeed_mm/utils/hetero_parallel.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mindspeed_mm/utils/hetero_parallel.py b/mindspeed_mm/utils/hetero_parallel.py index 85f2a885..463d4ed4 100644 --- a/mindspeed_mm/utils/hetero_parallel.py +++ b/mindspeed_mm/utils/hetero_parallel.py @@ -50,10 +50,10 @@ def image_encoder_forward_hook(module, input, output): change_parallel_state('text_decoder') chunk_seq_lens = [] - origin_len =len(all_lens) + origin_len = len(all_lens) for i in range(0, origin_len, origin_len // mpu.get_data_parallel_world_size()): - len = sum(all_lens[i: i + origin_len // mpu.get_data_parallel_world_size()]) - chunk_seq_lens.append(len) + length = sum(all_lens[i: i + origin_len // mpu.get_data_parallel_world_size()]) + chunk_seq_lens.append(length) output = split_tensor_dp_group(output, pad_dim=0, split_dim=0, chunk_seq_lens=chunk_seq_lens) @@ -80,8 +80,8 @@ def audio_encoder_forward_hook(module, input, output): chunk_seq_lens = [] origin_len = len(all_lens) for i in range(0, origin_len, origin_len // mpu.get_data_parallel_world_size()): - len = sum(all_lens[i: i + origin_len // mpu.get_data_parallel_world_size()]) - chunk_seq_lens.append(len) + length = sum(all_lens[i: i + origin_len // mpu.get_data_parallel_world_size()]) + chunk_seq_lens.append(length) output = split_tensor_dp_group(output, pad_dim=0, split_dim=0, chunk_seq_lens=chunk_seq_lens) -- Gitee
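The mechanism the series relies on throughout is change_parallel_state(): each submodule's parallel layout is initialized once, the module-level globals of megatron.core.parallel_state are snapshotted into a dict keyed by submodule name, and that snapshot is written back over the globals whenever the submodule runs. The sketch below illustrates only that snapshot/swap idea and is not the MindSpeed MM code itself: 'mpu' is a stand-in module rather than megatron.core.parallel_state, and snapshot_state() is a hypothetical helper mirroring the snapshot taken in initial_modules_mpu(), so the example runs without Megatron or a distributed launch.

# Minimal standalone sketch of the parallel-state snapshot/swap behind change_parallel_state().
# ASSUMPTION: mpu is a fake stand-in module and snapshot_state() is an illustrative helper;
# in the real code the globals are process groups created by initialize_model_parallel()
# under a distributed launch.
import inspect
import types

mpu = types.ModuleType("fake_parallel_state")   # stand-in for megatron.core.parallel_state
mpu._TENSOR_MODEL_PARALLEL_WORLD_SIZE = None    # private globals describing the current layout
mpu._DATA_PARALLEL_GROUP = None

_parallel_states = {}                           # submodule name -> snapshot of mpu's globals


def snapshot_state(name):
    # Record every private, non-function global of mpu (same filter the patch uses).
    _parallel_states[name] = {
        k: v for k, v in vars(mpu).items()
        if k.startswith('_') and not k.startswith('__') and not inspect.isfunction(v)
    }


def change_parallel_state(name):
    # Overwrite mpu's current globals with the snapshot taken for `name`.
    target = vars(mpu)
    for k, v in _parallel_states[name].items():
        if k in target:
            target[k] = v


# Pretend text_decoder was initialized with TP=2, then image_encoder re-initialized with TP=1.
mpu._TENSOR_MODEL_PARALLEL_WORLD_SIZE = 2
snapshot_state('text_decoder')
mpu._TENSOR_MODEL_PARALLEL_WORLD_SIZE = 1
snapshot_state('image_encoder')

# Swapping snapshots changes what the rest of the code "sees" as the parallel layout.
change_parallel_state('text_decoder')
assert mpu._TENSOR_MODEL_PARALLEL_WORLD_SIZE == 2
change_parallel_state('image_encoder')
assert mpu._TENSOR_MODEL_PARALLEL_WORLD_SIZE == 1

The same swap brackets the encoder forward hooks in hetero_parallel.py: inputs are gathered over the text_decoder data-parallel group, the state is switched to the encoder layout for the split and the forward pass, and the encoder outputs are gathered and re-split under the text_decoder layout before the decoder consumes them.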