From 9cfd5fab14fc3456f024d0e34e0421a1c8d00aab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=8D=E9=87=91=E5=BC=88?= Date: Tue, 24 Jun 2025 07:09:34 +0800 Subject: [PATCH 1/2] Remove get_weight_buffer_meta_from_buffer --- mindspeed_rl/workers/resharding/memory_buffer.py | 9 --------- mindspeed_rl/workers/resharding/vllm_weight_container.py | 8 ++++---- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/mindspeed_rl/workers/resharding/memory_buffer.py b/mindspeed_rl/workers/resharding/memory_buffer.py index a8b00eb3..bd657335 100644 --- a/mindspeed_rl/workers/resharding/memory_buffer.py +++ b/mindspeed_rl/workers/resharding/memory_buffer.py @@ -85,15 +85,6 @@ def calc_padded_numel(shape: torch.Size, dtype: torch.dtype): return (numel + align_numel - 1) // align_numel * align_numel -# 构建EP增大的buffer———构造一个experts_weight_buffer_meta -def get_weight_buffer_meta_from_buffer(weight_buffer_meta) -> Dict[str, Dict]: - experts_weight_buffer_meta = {} - for name, meta_info in sorted(weight_buffer_meta.items()): - if "mlp.experts" in name: - experts_weight_buffer_meta[name] = meta_info - return experts_weight_buffer_meta - - def build_memory_buffer(weight_buffer_meta: Dict[str, Dict]) -> Dict[torch.dtype, MemoryBuffer]: """Build the memory buffer given weight_buffer_meta diff --git a/mindspeed_rl/workers/resharding/vllm_weight_container.py b/mindspeed_rl/workers/resharding/vllm_weight_container.py index 729a602e..fe749a1a 100644 --- a/mindspeed_rl/workers/resharding/vllm_weight_container.py +++ b/mindspeed_rl/workers/resharding/vllm_weight_container.py @@ -394,14 +394,14 @@ class MegatronStyleVllmWeightContainer: # 构造临时的experts_memory_buffers for cur_pp_rank in range(self._pp_size): pp_rank = self._pp_rank - from mindspeed_rl.workers.resharding.memory_buffer import build_experts_memory_buffer, get_weight_buffer_meta_from_buffer + from mindspeed_rl.workers.resharding.memory_buffer import build_experts_memory_buffer # Step1 在当前的PP_rank中,设置一个临时的exprts_buffer 
combined_names_per_pp = [] vpp_stages = self.weight_names_per_pp[cur_pp_rank] for weight_names_per_stage in vpp_stages: - combined_names_per_pp.extend(weight_names_per_stage) - self.weight_buffer_meta = self.weight_adaptor.get_weight_buffer_meta(self.vllm_model, combined_names_per_pp) - self.experts_weight_buffer_meta = get_weight_buffer_meta_from_buffer(self.weight_buffer_meta) + combined_names_per_pp.extend(filter(lambda name: "mlp.experts" in name, weight_names_per_stage)) + self.experts_weight_buffer_meta = self.weight_adaptor.get_weight_buffer_meta(self.vllm_model, + combined_names_per_pp) self.experts_memory_buffers = build_experts_memory_buffer(self.experts_weight_buffer_meta, self.experts_memory_expand_N) # Step2 将weights_buffer上对应的权重放到experts_buffer中 -- Gitee From 2a4a09afd1c5c7c196a5682d5645f0140ef5c1e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=8D=E9=87=91=E5=BC=88?= Date: Tue, 24 Jun 2025 11:01:54 +0800 Subject: [PATCH 2/2] Add support for irregular resharding handling --- .../vllm_adapter/megatron_weight_loaders.py | 4 +- .../workers/resharding/memory_buffer.py | 10 +++- .../workers/resharding/weight_adaptor.py | 56 ++++++++++++++++++- 3 files changed, 65 insertions(+), 5 deletions(-) diff --git a/mindspeed_rl/models/rollout/vllm_adapter/megatron_weight_loaders.py b/mindspeed_rl/models/rollout/vllm_adapter/megatron_weight_loaders.py index 0cff4882..8d4e7f96 100644 --- a/mindspeed_rl/models/rollout/vllm_adapter/megatron_weight_loaders.py +++ b/mindspeed_rl/models/rollout/vllm_adapter/megatron_weight_loaders.py @@ -190,6 +190,7 @@ def update_megatron_weight_loader(): RowParallelLinear, ReplicatedLinear) from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead, VocabParallelEmbedding + from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2MergedReplicatedLinear LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY = { ColumnParallelLinear: parallel_weight_loader, @@ -199,7 
+200,8 @@ def update_megatron_weight_loader(): VocabParallelEmbedding: parallel_weight_loader, ParallelLMHead: parallel_weight_loader, ReplicatedLinear: parallel_weight_loader, - FusedMoE: parallel_weight_loader + FusedMoE: parallel_weight_loader, + CustomDeepseekV2MergedReplicatedLinear: parallel_weight_loader, } for layer_class, weight_loader in LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY.items(): diff --git a/mindspeed_rl/workers/resharding/memory_buffer.py b/mindspeed_rl/workers/resharding/memory_buffer.py index bd657335..bd932881 100644 --- a/mindspeed_rl/workers/resharding/memory_buffer.py +++ b/mindspeed_rl/workers/resharding/memory_buffer.py @@ -16,7 +16,7 @@ This file contains utilities to manipulate torch memory buffers """ -from typing import Dict, List +from typing import Callable, Dict, List, Optional import torch from torch import nn @@ -52,9 +52,12 @@ class MemoryBuffer: buffer_tensor = buffer_tensor.view(shape) return buffer_tensor - def copy_by_name(self, param_name: str, param): + def copy_by_name(self, param_name: str, param, + extra_resharder: Optional[Callable] = None): """Copy buffer_tensor""" buffer_tensor = self.get_by_name(param_name) + if extra_resharder is not None: + param = extra_resharder(param, param_name) try: buffer_tensor = buffer_tensor.view(param.shape) except RuntimeError as err: @@ -206,7 +209,8 @@ class ModelWeightBuffer: def copy_by_name(self, weight_name: str, param): dtype = self.weight_buffer_meta[weight_name]['dtype'] - self.memory_buffers[dtype].copy_by_name(weight_name, param) + resharder = self.weight_buffer_meta[weight_name].get('resharder') + self.memory_buffers[dtype].copy_by_name(weight_name, param, resharder) def offload(self): for memory_buffer in self.memory_buffers.values(): diff --git a/mindspeed_rl/workers/resharding/weight_adaptor.py b/mindspeed_rl/workers/resharding/weight_adaptor.py index 91293c90..47af04a1 100644 --- a/mindspeed_rl/workers/resharding/weight_adaptor.py +++ 
b/mindspeed_rl/workers/resharding/weight_adaptor.py @@ -1,7 +1,13 @@ from abc import ABC, abstractmethod +from typing import Callable, Type import re import torch +from vllm.model_executor.layers.linear import ReplicatedLinear +# vllm_ascend patches must be imported first. +from vllm_ascend.patch import worker as _worker, platform as _platform +from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2MergedReplicatedLinear + class BaseWeightAdaptor(ABC): def __init__(self, model_config): @@ -232,6 +238,49 @@ class MegatronVLLMWeightAdaptor(BaseWeightAdaptor): return weight_names_per_vpp_combined + @staticmethod + def _module_lookup(model: torch.nn.Module, path: str) -> Type[torch.nn.Module]: + module = model + for part in path.split(".")[:-1]: + if part.isdigit(): + module = module[int(part)] + else: + module = getattr(module, part) + + return type(module) + + @staticmethod + def replicated_linear_resharder(param: torch.Tensor, name: str) -> torch.Tensor: + if "shared_experts" not in name: + return param + from vllm.distributed.parallel_state import _TP + + gathered_param = [torch.empty_like(param) for _ in range(_TP.world_size)] + torch.distributed.all_gather(gathered_param, param, group=_TP.device_group) + if "gate_up_proj" in name: + gate_lst = [] + up_lst = [] + for p in gathered_param: + gate, up = p.chunk(2, dim=0) + gate_lst.append(gate) + up_lst.append(up) + return torch.cat(gate_lst + up_lst, dim=0) + else: + return torch.cat(gathered_param, dim=1) + + @classmethod + def _get_extra_resharder( + cls, model: torch.nn.Module, name: str + ) -> Callable[[torch.Tensor, str], torch.Tensor]: + # 用于处理训推不规则切分带来的resharding问题,当前仅支持: + # + # - 共享专家在训练时作TP切分但推理不做TP切分 + EXTRA_RESHARDER_REGISTRY = { + CustomDeepseekV2MergedReplicatedLinear: cls.replicated_linear_resharder, + ReplicatedLinear: cls.replicated_linear_resharder, + } + return EXTRA_RESHARDER_REGISTRY.get(cls._module_lookup(model, name)) + class DeepSeekMVWeightAdaptor(MegatronVLLMWeightAdaptor): """ @@ 
-278,7 +327,12 @@ class DeepSeekMVWeightAdaptor(MegatronVLLMWeightAdaptor): elif 'q_a_proj' in name or 'q_proj' in name: continue else: - weight_buffer_meta[name] = {'shape': param.shape, 'dtype': param.dtype} + weight_buffer_meta[name] = { + 'shape': param.shape, + 'dtype': param.dtype, + 'resharder': super()._get_extra_resharder(model, name), + } + return weight_buffer_meta def convert_weight_name_meta(self, weight_names): -- Gitee