diff --git a/mindspeed_rl/models/rollout/vllm_adapter/megatron_weight_loaders.py b/mindspeed_rl/models/rollout/vllm_adapter/megatron_weight_loaders.py
index 3323579eb3debd927b2cb7396fca012da586cc21..5ace26b0daf4583dc64c8c11ad5f3c4df450ed14 100644
--- a/mindspeed_rl/models/rollout/vllm_adapter/megatron_weight_loaders.py
+++ b/mindspeed_rl/models/rollout/vllm_adapter/megatron_weight_loaders.py
@@ -65,6 +65,38 @@ def qwen_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module,
     return vllm_model
 
 
+def qwen_vl_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module,
+                                   infer_paralle_config: InferParallelConfig, hf_config: PretrainedConfig) -> nn.Module:
+    params_dict = dict(vllm_model.named_parameters())
+    vision_config = type('obj', (object,), {
+        'num_attention_heads': hf_config.vision_config.num_heads,
+        'num_key_value_heads': hf_config.vision_config.num_heads,
+    })
+
+    for name, loaded_weight in actor_weights.items():
+        if name not in params_dict.keys():
+            continue
+        if "qkv" in name:
+            if 'visual' in name:
+                if name.endswith('.bias'):
+                    q_weight, k_weight, v_weight = qkv_split_bias(loaded_weight, infer_paralle_config, vision_config)
+                    loaded_weight.copy_(torch.cat([q_weight, k_weight, v_weight], dim=0))
+                else:
+                    q_weight, k_weight, v_weight = qkv_split_weight(loaded_weight, infer_paralle_config, vision_config)
+                    loaded_weight.copy_(torch.cat([q_weight, k_weight, v_weight], dim=0))
+            else:
+                if name.endswith('.bias'):
+                    q_weight, k_weight, v_weight = qkv_split_bias(loaded_weight, infer_paralle_config, hf_config)
+                    loaded_weight.copy_(torch.cat([q_weight, k_weight, v_weight], dim=0))
+                else:
+                    q_weight, k_weight, v_weight = qkv_split_weight(loaded_weight, infer_paralle_config, hf_config)
+                    loaded_weight.copy_(torch.cat([q_weight, k_weight, v_weight], dim=0))
+
+        load_single_weight(params_dict, name, loaded_weight)
+
+    return vllm_model
+
+
 def deepseek_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module,
                                     infer_paralle_config: InferParallelConfig, hf_config: PretrainedConfig
                                     ) -> nn.Module:
@@ -189,4 +221,5 @@ MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY = {
     "DeepseekV3ForCausalLM": deepseek_megatron_weight_loader,
     "DeepseekV2ForCausalLM": deepseek_megatron_weight_loader,
     "CustomDeepseekV3ForCausalLM": deepseek_megatron_weight_loader,
+    "Qwen2_5_VLForConditionalGeneration": qwen_vl_megatron_weight_loader
 }
diff --git a/mindspeed_rl/workers/resharding/vllm_weight_container.py b/mindspeed_rl/workers/resharding/vllm_weight_container.py
index 15664ee375f9f4b5feab383a7631b70b5d9ae096..45fd3364b6320efd0bc1b865d8675dbd97c519f2 100644
--- a/mindspeed_rl/workers/resharding/vllm_weight_container.py
+++ b/mindspeed_rl/workers/resharding/vllm_weight_container.py
@@ -19,6 +19,7 @@ This file contains a Megatron style Hybrid Model that shares the weights of the
 import hashlib
 import re
 from functools import partial
+from typing import List
 
 import torch
 import torch.distributed as dist
@@ -180,6 +181,9 @@ class MegatronStyleVllmWeightContainer:
 
     def _build_num_layer_list(self, num_layer_list):
         if num_layer_list:
+            # multimodal num_layer_list is a list of lists, including ViT and LLM layers
+            if isinstance(num_layer_list[0], List):
+                return num_layer_list
             return [int(num_layers) for num_layers in num_layer_list.split(',')]
         if self._num_hidden_layers % self._pp_size != 0:
             raise ValueError("num_layers % pp_size == 0, please specify num_layer_list")
@@ -350,7 +354,7 @@ class MegatronStyleVllmWeightContainer:
 
         # Check that the linear_fc1 and linear_fc2 weight shapes satisfy the expected relationship (fc1 holds the gate and up-projection parameters, so it is twice the size of fc2). Models that do not meet this condition are not supported.
         for _, vpp_rank, megatron_name in name_pairs:
-            if megatron_name.endswith("linear_fc1.weight"):
+            if not megatron_name.startswith("image_encoder") and megatron_name.endswith("linear_fc1.weight"):
                 fc2_name = megatron_name.replace("linear_fc1", "linear_fc2")
                 megatron_param_fc1 = dict(true_megatron_model[vpp_rank].named_parameters())[megatron_name]
                 megatron_param_fc2 = dict(true_megatron_model[vpp_rank].named_parameters())[fc2_name]
@@ -563,7 +567,7 @@ class MegatronStyleVllmWeightContainer:
 
         we can throw an error to force user disable TP HybridEngine.
         """
-        if "linear_fc1.weight" in name:
+        if 'projector' not in name and 'linear_fc1' in name:
             # if the tensor is gate and proj
             gate_lst = []
             up_lst = []
diff --git a/mindspeed_rl/workers/resharding/weight_adaptor.py b/mindspeed_rl/workers/resharding/weight_adaptor.py
index 9a333a94f63c5ac6ca14966370cce5ab0cdb631f..a92857606af268fbad6e16f6502f82a9d427dbde 100644
--- a/mindspeed_rl/workers/resharding/weight_adaptor.py
+++ b/mindspeed_rl/workers/resharding/weight_adaptor.py
@@ -293,11 +293,287 @@ class QwenMVWeightAdaptor(MegatronVLLMWeightAdaptor):
         super(QwenMVWeightAdaptor, self).__init__(model_config)
 
 
+class Qwen2_5_VLWeightAdaptor(MegatronVLLMWeightAdaptor):
+    def __init__(self, model_config):
+        super(Qwen2_5_VLWeightAdaptor, self).__init__(model_config)
+        self.params_mapping = [
+            ("text_decoder.embedding.word_embeddings", "language_model.model.embed_tokens"),
+            ("text_decoder.decoder.layers.{layer_num}.self_attention.linear_qkv", "language_model.model.layers.{layer_num}.self_attn.qkv_proj"),
+            ("text_decoder.decoder.layers.{layer_num}.self_attention.linear_proj", "language_model.model.layers.{layer_num}.self_attn.o_proj"),
+            ("text_decoder.decoder.layers.{layer_num}.input_layernorm", "language_model.model.layers.{layer_num}.input_layernorm"),
+            ("text_decoder.decoder.layers.{layer_num}.pre_mlp_layernorm", "language_model.model.layers.{layer_num}.post_attention_layernorm"),
+            ("text_decoder.decoder.layers.{layer_num}.mlp.linear_fc1", "language_model.model.layers.{layer_num}.mlp.gate_up_proj"),
+            ("text_decoder.decoder.layers.{layer_num}.mlp.linear_fc2", "language_model.model.layers.{layer_num}.mlp.down_proj"),
+            ("text_decoder.decoder.final_layernorm", "language_model.model.norm"),
+            ("text_decoder.output_layer", "language_model.lm_head"),
+            ("image_encoder.encoder.patch_embed.proj", "visual.patch_embed.proj"),
+            ("image_encoder.encoder.blocks.layers.{layer_num}.self_attention.linear_qkv", "visual.blocks.{layer_num}.attn.qkv"),
+            ("image_encoder.encoder.blocks.layers.{layer_num}.self_attention.linear_proj", "visual.blocks.{layer_num}.attn.proj"),
+            ("image_encoder.encoder.blocks.layers.{layer_num}.input_layernorm", "visual.blocks.{layer_num}.norm1"),
+            ("image_encoder.encoder.blocks.layers.{layer_num}.pre_mlp_layernorm", "visual.blocks.{layer_num}.norm2"),
+            ("image_encoder.encoder.blocks.layers.{layer_num}.mlp.linear_fc1", "visual.blocks.{layer_num}.mlp.gate_up_proj"),
+            ("image_encoder.encoder.blocks.layers.{layer_num}.mlp.linear_fc2", "visual.blocks.{layer_num}.mlp.down_proj"),
+            ("image_encoder.projector.layernorm", "visual.merger.ln_q"),
+            ("image_encoder.projector.encoder.linear_fc1", "visual.merger.mlp.0"),
+            ("image_encoder.projector.encoder.linear_fc2", "visual.merger.mlp.2"),
+        ]
+
+    def replace_name_i2t(self, inference_name):
+        weight_suffix = ""
+        if inference_name.endswith(".weight"):
+            weight_suffix = ".weight"
+            base_name = inference_name[:-7]
+        elif inference_name.endswith(".bias"):
+            weight_suffix = ".bias"
+            base_name = inference_name[:-5]
+        else:
+            base_name = inference_name
+
+        for megatron_pattern, vllm_pattern in self.params_mapping:
+            vllm_regex = vllm_pattern.replace("{layer_num}", r"(\d+)")
+            match = re.match(f"^{vllm_regex}(.*)$", base_name)
+            if match:
+                groups = match.groups()
+                layer_nums = [g for g in groups[:-1] if g is not None and g.isdigit()]
+                extra_suffix = groups[-1] if groups and groups[-1] is not None else ""
+
+                megatron_result = megatron_pattern
+                for layer_num in layer_nums:
+                    megatron_result = megatron_result.replace("{layer_num}", layer_num, 1)
+
+                return megatron_result + extra_suffix + weight_suffix
+
+        return inference_name
+
+    @staticmethod
+    def _convert_global_to_local_index(name, layer_keyword, pp_layers):
+        """
+        Convert global layer index to local layer index for a given layer type.
+
+        Args:
+            name: Weight name containing layer information
+            layer_keyword: Layer type keyword ('blocks' for visual, 'layers' for language model)
+            pp_layers: List of layer counts per pipeline parallel rank
+
+        Returns:
+            Updated weight name with local layer index
+        """
+        split_name = name.split('.')
+
+        # Find the position of layer keyword
+        layer_keyword_idx = -1
+        for i, name_part in enumerate(split_name):
+            if name_part == layer_keyword:
+                layer_keyword_idx = i
+                break
+
+        if layer_keyword_idx == -1:
+            return name
+
+        layer_num_idx = layer_keyword_idx + 1
+        if len(split_name) < layer_num_idx + 1 or not split_name[layer_num_idx].isdigit():
+            raise ValueError(f'Invalid {layer_keyword} name: {split_name}')
+
+        global_idx = int(split_name[layer_num_idx])
+
+        # Calculate local index
+        cumulative_layers = 0
+        for layers_in_pp_rank in pp_layers:
+            if layers_in_pp_rank == 0:
+                continue
+            if cumulative_layers <= global_idx < cumulative_layers + layers_in_pp_rank:
+                local_index = global_idx - cumulative_layers
+                split_name[layer_num_idx] = str(local_index)
+                return '.'.join(split_name)
+            cumulative_layers += layers_in_pp_rank
+
+        raise ValueError(f'Could not map {layer_keyword} {global_idx} to a local index with distribution {pp_layers}')
+
+    @staticmethod
+    def global2local_layer(name, num_layer_list):
+        """
+        Transform layer names from global space to local space for Qwen2.5-VL models.
+        Supports both visual blocks and language model layers.
+
+        Args:
+            name: Weight name to transform
+            num_layer_list: [img_pp_layers, llm_pp_layers] distribution
+
+        Returns:
+            Transformed weight name with local layer indices
+        """
+        img_pp_layers, llm_pp_layers = num_layer_list
+
+        if name.startswith('visual') and 'blocks' in name:
+            return Qwen2_5_VLWeightAdaptor._convert_global_to_local_index(name, 'blocks', img_pp_layers)
+        elif name.startswith('language_model') and 'layers' in name:
+            return Qwen2_5_VLWeightAdaptor._convert_global_to_local_index(name, 'layers', llm_pp_layers)
+
+        return name
+
+    @staticmethod
+    def _categorize_weights(vllm_names):
+        """
+        Categorize weight names by their types for easier processing.
+
+        Args:
+            vllm_names: List of vLLM weight names
+
+        Returns:
+            Dictionary containing categorized weight names
+        """
+        visual_weights = []
+        lang_weights = []
+        visual_pre_layer_weights = []
+        visual_post_layer_weights = []
+        lang_pre_layer_weights = []
+        lang_post_layer_weights = []
+
+        for name in vllm_names:
+            if name.startswith('visual'):
+                if 'blocks' not in name:
+                    if 'patch_embed' in name:
+                        visual_pre_layer_weights.append(name)
+                    elif 'merger' in name:
+                        visual_post_layer_weights.append(name)
+                else:
+                    visual_weights.append(name)
+            elif name.startswith('language_model'):
+                if 'layers' not in name:
+                    if 'embed_tokens' in name:
+                        lang_pre_layer_weights.append(name)
+                    else:
+                        lang_post_layer_weights.append(name)
+                else:
+                    lang_weights.append(name)
+
+        return {
+            'visual_weights': visual_weights,
+            'lang_weights': lang_weights,
+            'visual_pre_layer_weights': visual_pre_layer_weights,
+            'visual_post_layer_weights': visual_post_layer_weights,
+            'lang_pre_layer_weights': lang_pre_layer_weights,
+            'lang_post_layer_weights': lang_post_layer_weights
+        }
+
+    @staticmethod
+    def _calculate_layer_ranges(pp_layers):
+        """
+        Calculate layer ranges for each pipeline parallel stage.
+
+        Args:
+            pp_layers: List of layer counts per pipeline parallel rank
+
+        Returns:
+            List of (start_layer, end_layer) tuples for each rank
+        """
+        layer_ranges = []
+        start_layer = 0
+        for layers_in_pp_rank in pp_layers:
+            if layers_in_pp_rank > 0:
+                layer_ranges.append((start_layer, start_layer + layers_in_pp_rank - 1))
+                start_layer += layers_in_pp_rank
+            else:
+                layer_ranges.append((-1, -1))
+        return layer_ranges
+
+    @staticmethod
+    def _find_last_rank(pp_layers):
+        """
+        Find the last pipeline parallel rank that has non-zero layers.
+        """
+        for i in range(len(pp_layers) - 1, -1, -1):
+            if pp_layers[i] > 0:
+                return i
+        return -1
+
+    @staticmethod
+    def _find_first_rank(pp_layers):
+        """
+        Find the first pipeline parallel rank that has non-zero layers.
+        """
+        for i, layers in enumerate(pp_layers):
+            if layers > 0:
+                return i
+        return -1
+
+    @staticmethod
+    def _assign_layer_weights(weight_names_per_pp, weights, layer_ranges, layer_keyword):
+        """
+        Assign layer weights to their corresponding pipeline parallel stages.
+        """
+        for pp_rank, (start_layer, end_layer) in enumerate(layer_ranges):
+            if start_layer >= 0 and end_layer >= 0:
+                for name in weights:
+                    match = re.match(rf'.*\.{layer_keyword}\.(\d+)', name)
+                    if match:
+                        layer_num = int(match.group(1))
+                        if start_layer <= layer_num <= end_layer:
+                            weight_names_per_pp[pp_rank].append(name)
+
+    @staticmethod
+    def get_weight_names_per_pp(layer_list, vllm_names):
+        """
+        Get weight names for each pipeline parallel stage, tailored to Qwen2.5-VL models.
+
+        Args:
+            layer_list: [img_pp_layers, llm_pp_layers] distribution
+            vllm_names: List of vLLM weight names
+
+        Returns:
+            List of weight names for each pipeline parallel rank
+        """
+        img_pp_layers, llm_pp_layers = layer_list
+        pp_size = len(img_pp_layers)
+
+        weight_categories = Qwen2_5_VLWeightAdaptor._categorize_weights(vllm_names)
+
+        img_blocks_range = Qwen2_5_VLWeightAdaptor._calculate_layer_ranges(img_pp_layers)
+        llm_layers_range = Qwen2_5_VLWeightAdaptor._calculate_layer_ranges(llm_pp_layers)
+
+        weight_names_per_pp = [[] for _ in range(pp_size)]
+
+        last_img_rank = Qwen2_5_VLWeightAdaptor._find_last_rank(img_pp_layers)
+        first_llm_rank = Qwen2_5_VLWeightAdaptor._find_first_rank(llm_pp_layers)
+        last_llm_rank = Qwen2_5_VLWeightAdaptor._find_last_rank(llm_pp_layers)
+
+        # Process visual weights
+        for pp_rank in range(pp_size):
+            start_layer, end_layer = img_blocks_range[pp_rank]
+
+            if start_layer == 0 and end_layer >= 0:
+                weight_names_per_pp[pp_rank].extend(weight_categories['visual_pre_layer_weights'])
+
+            if pp_rank == last_img_rank:
+                weight_names_per_pp[pp_rank].extend(weight_categories['visual_post_layer_weights'])
+
+        # Assign visual layer weights
+        Qwen2_5_VLWeightAdaptor._assign_layer_weights(
+            weight_names_per_pp, weight_categories['visual_weights'], img_blocks_range, 'blocks'
+        )
+
+        # Process language model weights
+        for pp_rank in range(pp_size):
+            if pp_rank == first_llm_rank:
+                weight_names_per_pp[pp_rank].extend(weight_categories['lang_pre_layer_weights'])
+
+            if pp_rank == last_llm_rank:
+                weight_names_per_pp[pp_rank].extend(weight_categories['lang_post_layer_weights'])
+
+        # Assign language model layer weights
+        Qwen2_5_VLWeightAdaptor._assign_layer_weights(
+            weight_names_per_pp, weight_categories['lang_weights'], llm_layers_range, 'layers'
+        )
+
+        return weight_names_per_pp
+
+
 WEIGHT_ADAPTOR_REGISTRY = {
     "Qwen2ForCausalLM": QwenMVWeightAdaptor,
     "DeepseekV3ForCausalLM": DeepSeekMVWeightAdaptor,
     "DeepseekV2ForCausalLM": DeepSeekMVWeightAdaptor,
     "CustomDeepseekV3ForCausalLM": DeepSeekMVWeightAdaptor,
+    "Qwen2_5_VLForConditionalGeneration": Qwen2_5_VLWeightAdaptor,
 }
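
Note (illustrative, not part of the patch): the sketch below shows how the new Qwen2_5_VLWeightAdaptor.get_weight_names_per_pp partitions vLLM weight names across pipeline stages. The layer split and the parameter names are hypothetical, abbreviated examples; the import path follows the file added in this diff.

# Hypothetical split: 32 ViT blocks on PP ranks 0-1, 36 LLM layers on ranks 2-3.
from mindspeed_rl.workers.resharding.weight_adaptor import Qwen2_5_VLWeightAdaptor

layer_list = [[16, 16, 0, 0], [0, 0, 18, 18]]  # [img_pp_layers, llm_pp_layers]
vllm_names = [
    "visual.patch_embed.proj.weight",
    "visual.blocks.0.attn.qkv.weight",
    "visual.blocks.31.mlp.down_proj.weight",
    "visual.merger.mlp.0.weight",
    "language_model.model.embed_tokens.weight",
    "language_model.model.layers.0.self_attn.qkv_proj.weight",
    "language_model.model.layers.35.mlp.down_proj.weight",
    "language_model.model.norm.weight",
    "language_model.lm_head.weight",
]

for rank, names in enumerate(Qwen2_5_VLWeightAdaptor.get_weight_names_per_pp(layer_list, vllm_names)):
    print(rank, names)
# Expected grouping: rank 0 -> patch_embed plus visual blocks 0-15; rank 1 -> visual blocks 16-31 plus merger;
# rank 2 -> embed_tokens plus LLM layers 0-17; rank 3 -> LLM layers 18-35, final norm, and lm_head.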