From 59d9f82a3294281e06711380868b899cf5475149 Mon Sep 17 00:00:00 2001
From: htwang
Date: Wed, 11 Jun 2025 14:23:58 +0800
Subject: [PATCH] add resharding

---
 .../rollout/vllm_adapter/patch/__init__.py    |  8 ++
 .../patch/qwen2_5_vl_visionmlp_patch.py       | 59 ++++++++++++
 .../resharding/vllm_weight_container.py       | 27 +++---
 .../workers/resharding/weight_adaptor.py      | 90 +++++++++++--------
 4 files changed, 134 insertions(+), 50 deletions(-)
 create mode 100644 mindspeed_rl/models/rollout/vllm_adapter/patch/__init__.py
 create mode 100644 mindspeed_rl/models/rollout/vllm_adapter/patch/qwen2_5_vl_visionmlp_patch.py

diff --git a/mindspeed_rl/models/rollout/vllm_adapter/patch/__init__.py b/mindspeed_rl/models/rollout/vllm_adapter/patch/__init__.py
new file mode 100644
index 00000000..a2225277
--- /dev/null
+++ b/mindspeed_rl/models/rollout/vllm_adapter/patch/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
+
+from mindspeed_rl.utils.utils import is_multimodal
+from .qwen2_5_vl_visionmlp_patch import replace_qwen2_5_visionmlp
+
+
+if is_multimodal():
+    replace_qwen2_5_visionmlp()
\ No newline at end of file
diff --git a/mindspeed_rl/models/rollout/vllm_adapter/patch/qwen2_5_vl_visionmlp_patch.py b/mindspeed_rl/models/rollout/vllm_adapter/patch/qwen2_5_vl_visionmlp_patch.py
new file mode 100644
index 00000000..1b5cd91d
--- /dev/null
+++ b/mindspeed_rl/models/rollout/vllm_adapter/patch/qwen2_5_vl_visionmlp_patch.py
@@ -0,0 +1,59 @@
+# coding=utf-8
+# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
+# Copyright 2025 The vLLM team.
+# Copyright 2025 The Qwen Team.
+# Copyright 2025 The HuggingFace Inc. team.
+
+
+from typing import Callable, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import vllm
+import vllm.model_executor
+from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
+from vllm.model_executor.layers.quantization import QuantizationConfig
+import vllm.model_executor.models
+import vllm.model_executor.models.qwen2_5_vl
+
+
+class Npu_Qwen2_5_VisionMLP(nn.Module):
+
+    def __init__(self,
+                 in_features: int,
+                 hidden_features: int,
+                 bias: bool = False,
+                 act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
+        super().__init__()
+        # Fuse gate_proj and up_proj into a single projection
+        self.gate_up_proj = ColumnParallelLinear(
+            in_features,
+            hidden_features * 2,  # output dim doubles after the fusion
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj"
+        )
+        self.down_proj = RowParallelLinear(
+            hidden_features,
+            in_features,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj"
+        )
+        self.act_fn = act_fn
+
+    def forward(self, x: torch.Tensor):
+        # The fused gate_up_proj produces two halves: x_gate and x_up
+        gate_up, _ = self.gate_up_proj(x)
+        x_gate, x_up = gate_up.chunk(2, dim=-1)  # split along the last dim
+        x_gate = self.act_fn(x_gate)
+        x_down, _ = self.down_proj(x_gate * x_up)
+        return x_down
+
+
+def replace_qwen2_5_visionmlp():
+    vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionMLP = Npu_Qwen2_5_VisionMLP
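
Note: fusing gate_proj and up_proj as above changes the expected checkpoint
layout, so a state dict that stores the two projections separately has to be
concatenated along the output dimension before it can be loaded into
gate_up_proj. A minimal sketch, ignoring tensor-parallel sharding; the tensor
names are illustrative, not the actual loader:

    import torch

    def fuse_gate_up(gate_w: torch.Tensor, up_w: torch.Tensor) -> torch.Tensor:
        # Linear weights are stored as [out_features, in_features], so stacking
        # along dim 0 matches the gate_up.chunk(2, dim=-1) split in forward().
        return torch.cat([gate_w, up_w], dim=0)

    # gate_w, up_w: [hidden_features, in_features] -> fused: [2 * hidden_features, in_features]
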
diff --git a/mindspeed_rl/workers/resharding/vllm_weight_container.py b/mindspeed_rl/workers/resharding/vllm_weight_container.py
index 45fd3364..17856314 100644
--- a/mindspeed_rl/workers/resharding/vllm_weight_container.py
+++ b/mindspeed_rl/workers/resharding/vllm_weight_container.py
@@ -34,6 +34,7 @@ from mindspeed_rl.workers.resharding.utils import get_tensor_parallel_partition_
     update_md5_by_rank, compute_md5, validate_md5, _build_infer_param_dict, get_tp_allgather_group, \
     get_tp_allgather_world_size, is_tensor_parallel_param, get_tp_group, is_fake_tp_param
 from mindspeed_rl.utils.loggers import Loggers
+from mindspeed_rl.utils.utils import is_multimodal
 
 logger = Loggers(__name__)
 
@@ -89,7 +90,7 @@ class MegatronStyleVllmWeightContainer:
         self._vpp_layer_list = self._build_vpp_layer_list(self._num_layer_list)  ## _noop_layers
         self._global2local_map = self._build_global2local_map(self._vpp_layer_list, self._vpp_size, self._noop_layers) if self._noop_layers is not None else None
-        
+
         # tp configs
         self._tp_size = self.parallel_state.get_tensor_model_parallel_world_size()
         self._tp_group = self.parallel_state.get_tensor_model_parallel_group()
@@ -136,7 +137,7 @@ class MegatronStyleVllmWeightContainer:
     def _validate_parallel_config(self):
         if self._infer_pp_size != 1:
             raise ValueError("infer_pp_size != 1 not supported yet")
-        
+
         if self._infer_ep_size % self._ep_size != 0:
             raise ValueError("The inference expert parallel size should be divisible by the training expert parallel size.")
         if self._ep_size > 1 and not self.moe_tp_extend_ep:
@@ -211,7 +212,7 @@ class MegatronStyleVllmWeightContainer:
                 start_layer += layers_in_vpp_rank
 
         return glb2local_map
-    
+
     def _unwrap_megatron_model(self, model):
         """
         Remove consecutive 'module.' prefixes from the model based on the state_dict's first key.
@@ -232,9 +233,13 @@ class MegatronStyleVllmWeightContainer:
         Return a list of buffers, and a reference dict megatron_param_name->buffer.
         """
         vllm_names = list(dict(self.vllm_model.named_parameters()).keys())  # weight names within each pp stage
+        if is_multimodal():
+            layers_num = [sum(num_layer_list) for num_layer_list in self._num_layer_list]
+        else:
+            layers_num = sum(self._num_layer_list)
         self.weight_names_per_pp = self.weight_adaptor.get_weight_names_per_pp(self._vpp_layer_list, vllm_names,
-                                                                               sum(self._num_layer_list), self._vpp_size, self._noop_layers)
-        
+                                                                               layers_num, self._vpp_size, self._noop_layers)
+
         self.weight_buffers = build_model_weight_buffer(self.vllm_model, self.weight_names_per_pp,
                                                         self.weight_adaptor.get_weight_buffer_meta
                                                         )
@@ -347,11 +352,11 @@ class MegatronStyleVllmWeightContainer:
         normal_layer_func = partial(self.weight_adaptor.global2local_layer, num_layer_list=self._vpp_layer_list,
                                     global2local_map=self._global2local_map)
         name_pairs = sorted(list(set([(name, vpp_rank, self.weight_adaptor.replace_name_i2t(normal_layer_func(name, vpp_rank=vpp_rank)))
                                       for vpp_rank, names_per_vpp in enumerate(weight_names_meta) for name in names_per_vpp])))
-        
+
         if self.enable_validate:
             self.origin_params_for_md5 = hashlib.md5()
             self.infer_params_for_md5 = [hashlib.md5() for _ in range(get_tp_allgather_world_size())]
-        
+
         # Check that the linear_fc1 and linear_fc2 weight shapes satisfy the expected relation (fc1 holds the gate and up projections, so it is twice the size of fc2); models that do not satisfy this are not supported.
         for _, vpp_rank, megatron_name in name_pairs:
             if not megatron_name.startswith("image_encoder") and megatron_name.endswith("linear_fc1.weight"):
@@ -394,7 +399,7 @@ class MegatronStyleVllmWeightContainer:
             self.weight_buffer_meta = self.weight_adaptor.get_weight_buffer_meta(self.vllm_model, combined_names_per_pp)
             self.experts_weight_buffer_meta = get_weight_buffer_meta_from_buffer(self.weight_buffer_meta)
             self.experts_memory_buffers = build_experts_memory_buffer(self.experts_weight_buffer_meta, self.experts_memory_expend_N)
-            
+
             # Step 2: move the matching weights from the weight buffers into the experts buffers
             if(cur_pp_rank == pp_rank):
                 weight_names = self.weight_names_per_pp[pp_rank]
@@ -403,7 +408,7 @@ class MegatronStyleVllmWeightContainer:
                 name_pairs = sorted(list(set([(name, vpp_rank, self.weight_adaptor.replace_name_i2t(normal_layer_func(name, vpp_rank=vpp_rank)))
                                               for vpp_rank, names_per_vpp in enumerate(weight_names_meta) for name in names_per_vpp])))
                 true_megatron_model = self._unwrap_megatron_model(self.megatron_model)
-                
+
                 megatron_params_dict = {}
                 # collect all weights of the current pp stage
                 for vpp_rank in range(self._vpp_size):
@@ -418,12 +423,12 @@ class MegatronStyleVllmWeightContainer:
 
             # Step 3: the subsequent operations can be reused
             global_src = dist.get_global_rank(group=self._pp_group, group_rank=cur_pp_rank)
-            
+
             # broadcast the expert weights (held in the experts memory buffers)
             for dtype, experts_memory_buffer in self.experts_memory_buffers.items():
                 dist.broadcast(tensor=experts_memory_buffer.data, src=global_src, group=self._pp_group, async_op=False)
                 pp_group_rank = self._rank // self.pp_group_size
-                
+
                 # look up the matching dtype
                 for name, tensor_indices_value in sorted(experts_memory_buffer.tensor_indices.items()):
                     shape = tensor_indices_value[1]  # expanded by the factor N
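
Note: with is_multimodal(), layers_num becomes a list of per-modality totals
instead of the old scalar, which is what the extended get_weight_names_per_pp
signature below consumes. A small illustration with made-up layer counts:

    # Hypothetical distribution over 4 pp ranks: 32 vision blocks up front,
    # 36 language model layers behind them.
    num_layer_list = [[16, 16, 0, 0],    # image encoder blocks per pp rank
                      [0, 0, 18, 18]]    # language model layers per pp rank
    layers_num = [sum(sub) for sub in num_layer_list]   # -> [32, 36]

    # Text-only models keep the scalar form:
    layers_num_dense = sum([8, 8, 8, 8])                # -> 32
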
diff --git a/mindspeed_rl/workers/resharding/weight_adaptor.py b/mindspeed_rl/workers/resharding/weight_adaptor.py
index a9285760..16118f2b 100644
--- a/mindspeed_rl/workers/resharding/weight_adaptor.py
+++ b/mindspeed_rl/workers/resharding/weight_adaptor.py
@@ -317,7 +317,7 @@ class Qwen2_5_VLWeightAdaptor(MegatronVLLMWeightAdaptor):
         ("image_encoder.projector.encoder.linear_fc1", "visual.merger.mlp.0"),
         ("image_encoder.projector.encoder.linear_fc2", "visual.merger.mlp.2"),
     ]
-    
+
     def replace_name_i2t(self, inference_name):
         weight_suffix = ""
         if inference_name.endswith(".weight"):
@@ -328,7 +328,7 @@ class Qwen2_5_VLWeightAdaptor(MegatronVLLMWeightAdaptor):
             base_name = inference_name[:-5]
         else:
             base_name = inference_name
-        
+
         for megatron_pattern, vllm_pattern in self.params_mapping:
             vllm_regex = vllm_pattern.replace("{layer_num}", r"(\d+)")
             match = re.match(f"^{vllm_regex}(.*)$", base_name)
@@ -336,46 +336,46 @@ class Qwen2_5_VLWeightAdaptor(MegatronVLLMWeightAdaptor):
                 groups = match.groups()
                 layer_nums = [g for g in groups[:-1] if g is not None and g.isdigit()]
                 extra_suffix = groups[-1] if groups and groups[-1] is not None else ""
-                
+
                 megatron_result = megatron_pattern
                 for layer_num in layer_nums:
                     megatron_result = megatron_result.replace("{layer_num}", layer_num, 1)
-                
+
                 return megatron_result + extra_suffix + weight_suffix
-        
+
         return inference_name
-    
+
     @staticmethod
     def _convert_global_to_local_index(name, layer_keyword, pp_layers):
         """
         Convert global layer index to local layer index for a given layer type.
-        
+
         Args:
             name: Weight name containing layer information
             layer_keyword: Layer type keyword ('blocks' for visual, 'layers' for language model)
             pp_layers: List of layer counts per pipeline parallel rank
-        
+
         Returns:
             Updated weight name with local layer index
         """
         split_name = name.split('.')
-        
+
         # Find the position of the layer keyword
         layer_keyword_idx = -1
         for i, name_part in enumerate(split_name):
             if name_part == layer_keyword:
                 layer_keyword_idx = i
                 break
-        
+
         if layer_keyword_idx == -1:
             return name
-        
+
         layer_num_idx = layer_keyword_idx + 1
         if len(split_name) < layer_num_idx + 1 or not split_name[layer_num_idx].isdigit():
             raise ValueError(f'Invalid {layer_keyword} name: {split_name}')
-        
+
         global_idx = int(split_name[layer_num_idx])
-        
+
         # Calculate the local index
         cumulative_layers = 0
         for layers_in_pp_rank in pp_layers:
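
Note: a worked example of the global-to-local conversion implemented above,
with hypothetical per-rank layer counts and an illustrative weight name:

    pp_layers = [8, 8, 8, 8]    # rank 0 owns layers 0-7, rank 1 owns 8-15, ...
    name = 'language_model.model.layers.17.mlp.down_proj.weight'
    local = Qwen2_5_VLWeightAdaptor._convert_global_to_local_index(name, 'layers', pp_layers)
    # Global layer 17 falls on pp rank 2 (which covers 16-23), so the local
    # index is 17 - 16 = 1:
    # -> 'language_model.model.layers.1.mlp.down_proj.weight'
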
@@ -386,39 +386,42 @@
                 split_name[layer_num_idx] = str(local_index)
                 return '.'.join(split_name)
             cumulative_layers += layers_in_pp_rank
-        
+
         raise ValueError(f'Could not map {layer_keyword} {global_idx} to a local index with distribution {pp_layers}')
 
     @staticmethod
-    def global2local_layer(name, num_layer_list):
+    def global2local_layer(name, num_layer_list, vpp_rank=0, global2local_map=None):
         """
         Transform layer names from global space to local space for Qwen2VL models.
         Supports both visual blocks and language model layers.
-        
+
         Args:
             name: Weight name to transform
             num_layer_list: [img_pp_layers, llm_pp_layers] distribution
+            vpp_rank: Virtual pipeline rank; only 0 is supported for multimodal models
+            global2local_map: Precomputed global-to-local map, accepted for interface compatibility
-        
+
         Returns:
             Transformed weight name with local layer indices
         """
+        if vpp_rank > 0:
+            raise NotImplementedError("VPP is not supported in multimodal models.")
+
         img_pp_layers, llm_pp_layers = num_layer_list
 
         if name.startswith('visual') and 'blocks' in name:
             return Qwen2_5_VLWeightAdaptor._convert_global_to_local_index(name, 'blocks', img_pp_layers)
         elif name.startswith('language_model') and 'layers' in name:
             return Qwen2_5_VLWeightAdaptor._convert_global_to_local_index(name, 'layers', llm_pp_layers)
-        
+
         return name
 
     @staticmethod
     def _categorize_weights(vllm_names):
         """
         Categorize weight names by their types for easier processing.
-        
+
         Args:
             vllm_names: List of vLLM weight names
-        
+
         Returns:
             Dictionary containing categorized weight names
         """
@@ -428,7 +431,7 @@
         visual_post_layer_weights = []
         lang_pre_layer_weights = []
         lang_post_layer_weights = []
-        
+
         for name in vllm_names:
             if name.startswith('visual'):
                 if 'blocks' not in name:
@@ -446,7 +449,7 @@
                     lang_post_layer_weights.append(name)
                 else:
                     lang_weights.append(name)
-        
+
         return {
             'visual_weights': visual_weights,
             'lang_weights': lang_weights,
@@ -460,10 +463,10 @@
     def _calculate_layer_ranges(pp_layers):
         """
         Calculate layer ranges for each pipeline parallel stage.
-        
+
        Args:
             pp_layers: List of layer counts per pipeline parallel rank
-        
+
         Returns:
             List of (start_layer, end_layer) tuples for each rank
         """
@@ -512,59 +515,68 @@
                     weight_names_per_pp[pp_rank].append(name)
 
     @staticmethod
-    def get_weight_names_per_pp(layer_list, vllm_names):
+    def get_weight_names_per_pp(layer_list, vllm_names, layers_num=None, vpp_size=0, noop_layers=None):
         """
         Get weight names for each pipeline parallel stage, optimized for Qwen2VL models.
-        
+
         Args:
             layer_list: [img_pp_layers, llm_pp_layers] distribution
             vllm_names: List of vLLM weight names
+            layers_num: Per-modality layer totals; derived from layer_list when omitted
+            vpp_size: Virtual pipeline parallel size; only vpp=1 is supported
+            noop_layers: Indices of no-op layers, accepted for interface compatibility
-        
+
         Returns:
             List of weight names for each pipeline parallel rank
         """
+        if not layers_num:
+            if vpp_size > 0:
+                raise ValueError(f"layers_num is required when vpp_size = {vpp_size}")
+            layers_num = [sum(sub_layer_list) for sub_layer_list in layer_list]
+        if vpp_size > 1:
+            raise NotImplementedError("VPP is not supported in multimodal models.")
+
         img_pp_layers, llm_pp_layers = layer_list
         pp_size = len(img_pp_layers)
-        
+
         weight_categories = Qwen2_5_VLWeightAdaptor._categorize_weights(vllm_names)
-        
+
         img_blocks_range = Qwen2_5_VLWeightAdaptor._calculate_layer_ranges(img_pp_layers)
         llm_layers_range = Qwen2_5_VLWeightAdaptor._calculate_layer_ranges(llm_pp_layers)
-        
+
         weight_names_per_pp = [[] for _ in range(pp_size)]
-        
+
         last_img_rank = Qwen2_5_VLWeightAdaptor._find_last_rank(img_pp_layers)
         first_llm_rank = Qwen2_5_VLWeightAdaptor._find_first_rank(llm_pp_layers)
         last_llm_rank = Qwen2_5_VLWeightAdaptor._find_last_rank(llm_pp_layers)
-        
+
         # Process visual weights
         for pp_rank in range(pp_size):
             start_layer, end_layer = img_blocks_range[pp_rank]
-            
+
             if start_layer == 0 and end_layer >= 0:
                 weight_names_per_pp[pp_rank].extend(weight_categories['visual_pre_layer_weights'])
-            
+
             if pp_rank == last_img_rank:
                 weight_names_per_pp[pp_rank].extend(weight_categories['visual_post_layer_weights'])
-        
+
         # Assign visual layer weights
         Qwen2_5_VLWeightAdaptor._assign_layer_weights(
             weight_names_per_pp, weight_categories['visual_weights'], img_blocks_range, 'blocks'
         )
-        
+
         # Process language model weights
         for pp_rank in range(pp_size):
             if pp_rank == first_llm_rank:
                 weight_names_per_pp[pp_rank].extend(weight_categories['lang_pre_layer_weights'])
-            
+
             if pp_rank == last_llm_rank:
                 weight_names_per_pp[pp_rank].extend(weight_categories['lang_post_layer_weights'])
-        
+
         # Assign language model layer weights
         Qwen2_5_VLWeightAdaptor._assign_layer_weights(
             weight_names_per_pp, weight_categories['lang_weights'], llm_layers_range, 'layers'
         )
-        
+
+        # Align with the vpp format; only vpp=1 is supported in the current version
+        weight_names_per_pp = [weight_names_per_pp]
         return weight_names_per_pp
-- 
Gitee
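
Note: how the pieces fit together at runtime, as a minimal sketch (module paths
as in this patch; the surrounding rollout initialization is elided and the
layer counts are hypothetical):

    # Importing the patch package applies the VisionMLP swap for multimodal
    # runs, so it must happen before vLLM builds the Qwen2.5-VL model.
    import mindspeed_rl.models.rollout.vllm_adapter.patch  # noqa: F401

    # get_weight_names_per_pp now returns a vpp-aligned nesting:
    # result[vpp_rank][pp_rank] -> list of weight names, with a single vpp stage.
    names = Qwen2_5_VLWeightAdaptor.get_weight_names_per_pp(
        layer_list=[[16, 16, 0, 0], [0, 0, 18, 18]],
        vllm_names=list(dict(vllm_model.named_parameters()).keys()),
    )
    assert len(names) == 1    # the outer dimension is the (single) vpp stage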