From db999cd6184d9068fb013fd4da5ee849f7485e2c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=BD=8D=E9=87=91=E5=BC=88?=
Date: Tue, 24 Jun 2025 07:09:34 +0800
Subject: [PATCH 1/3] Remove get_weight_buffer_meta_from_buffer

---
 mindspeed_rl/workers/resharding/memory_buffer.py         | 9 ---------
 mindspeed_rl/workers/resharding/vllm_weight_container.py | 8 ++++----
 2 files changed, 4 insertions(+), 13 deletions(-)

diff --git a/mindspeed_rl/workers/resharding/memory_buffer.py b/mindspeed_rl/workers/resharding/memory_buffer.py
index a8b00eb3..bd657335 100644
--- a/mindspeed_rl/workers/resharding/memory_buffer.py
+++ b/mindspeed_rl/workers/resharding/memory_buffer.py
@@ -85,15 +85,6 @@ def calc_padded_numel(shape: torch.Size, dtype: torch.dtype):
     return (numel + align_numel - 1) // align_numel * align_numel
 
 
-# 构建EP增大的buffer———构造一个experts_weight_buffer_meta
-def get_weight_buffer_meta_from_buffer(weight_buffer_meta) -> Dict[str, Dict]:
-    experts_weight_buffer_meta = {}
-    for name, meta_info in sorted(weight_buffer_meta.items()):
-        if "mlp.experts" in name:
-            experts_weight_buffer_meta[name] = meta_info
-    return experts_weight_buffer_meta
-
-
 def build_memory_buffer(weight_buffer_meta: Dict[str, Dict]) -> Dict[torch.dtype, MemoryBuffer]:
     """Build the memory buffer given weight_buffer_meta
 
diff --git a/mindspeed_rl/workers/resharding/vllm_weight_container.py b/mindspeed_rl/workers/resharding/vllm_weight_container.py
index 0f8d3ff6..21de31ca 100644
--- a/mindspeed_rl/workers/resharding/vllm_weight_container.py
+++ b/mindspeed_rl/workers/resharding/vllm_weight_container.py
@@ -391,14 +391,14 @@ class MegatronStyleVllmWeightContainer:
 
         # 构造临时的experts_memory_buffers
         for cur_pp_rank in range(self._pp_size):
             pp_rank = self._pp_rank
-            from mindspeed_rl.workers.resharding.memory_buffer import build_experts_memory_buffer, get_weight_buffer_meta_from_buffer
+            from mindspeed_rl.workers.resharding.memory_buffer import build_experts_memory_buffer
             # Step1 在当前的PP_rank中,设置一个临时的exprts_buffer
             combined_names_per_pp = []
             vpp_stages = self.weight_names_per_pp[cur_pp_rank]
             for weight_names_per_stage in vpp_stages:
-                combined_names_per_pp.extend(weight_names_per_stage)
-            self.weight_buffer_meta = self.weight_adaptor.get_weight_buffer_meta(self.vllm_model, combined_names_per_pp)
-            self.experts_weight_buffer_meta = get_weight_buffer_meta_from_buffer(self.weight_buffer_meta)
+                combined_names_per_pp.extend(filter(lambda name: "mlp.experts" in name, weight_names_per_stage))
+            self.experts_weight_buffer_meta = self.weight_adaptor.get_weight_buffer_meta(self.vllm_model,
+                                                                                         combined_names_per_pp)
             self.experts_memory_buffers = build_experts_memory_buffer(self.experts_weight_buffer_meta, self.experts_memory_expand_N)
             # Step2 将weights_buffer上对应的权重放到experts_buffer中
-- 
Gitee

From 2919199f484157a342c7a84db8a2c223cd18fe10 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=BD=8D=E9=87=91=E5=BC=88?=
Date: Tue, 24 Jun 2025 11:01:54 +0800
Subject: [PATCH 2/3] Add support for irregular resharding handling

---
 .../vllm_adapter/megatron_weight_loaders.py       |  4 +-
 mindspeed_rl/workers/resharding/memory_buffer.py  | 10 +++-
 mindspeed_rl/workers/resharding/weight_adaptor.py | 57 ++++++++++++++++++-
 3 files changed, 66 insertions(+), 5 deletions(-)

diff --git a/mindspeed_rl/models/rollout/vllm_adapter/megatron_weight_loaders.py b/mindspeed_rl/models/rollout/vllm_adapter/megatron_weight_loaders.py
index 5ace26b0..740f60ad 100644
--- a/mindspeed_rl/models/rollout/vllm_adapter/megatron_weight_loaders.py
+++ b/mindspeed_rl/models/rollout/vllm_adapter/megatron_weight_loaders.py
@@ -177,6 +177,7 @@ def update_megatron_weight_loader():
                                                    RowParallelLinear, ReplicatedLinear)
     from vllm.model_executor.layers.fused_moe.layer import FusedMoE
     from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead, VocabParallelEmbedding
+    from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2MergedReplicatedLinear
 
     LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY = {
         ColumnParallelLinear: parallel_weight_loader,
@@ -186,7 +187,8 @@ def update_megatron_weight_loader():
         VocabParallelEmbedding: parallel_weight_loader,
         ParallelLMHead: parallel_weight_loader,
         ReplicatedLinear: parallel_weight_loader,
-        FusedMoE: parallel_weight_loader
+        FusedMoE: parallel_weight_loader,
+        CustomDeepseekV2MergedReplicatedLinear: parallel_weight_loader,
     }
 
     for layer_class, weight_loader in LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY.items():
diff --git a/mindspeed_rl/workers/resharding/memory_buffer.py b/mindspeed_rl/workers/resharding/memory_buffer.py
index bd657335..bd932881 100644
--- a/mindspeed_rl/workers/resharding/memory_buffer.py
+++ b/mindspeed_rl/workers/resharding/memory_buffer.py
@@ -16,7 +16,7 @@
 This file contains utilities to manipulate torch memory buffers
 """
 
-from typing import Dict, List
+from typing import Callable, Dict, List, Optional
 
 import torch
 from torch import nn
@@ -52,9 +52,12 @@ class MemoryBuffer:
         buffer_tensor = buffer_tensor.view(shape)
         return buffer_tensor
 
-    def copy_by_name(self, param_name: str, param):
+    def copy_by_name(self, param_name: str, param,
+                     extra_resharder: Optional[Callable] = None):
         """Copy buffer_tensor"""
         buffer_tensor = self.get_by_name(param_name)
+        if extra_resharder is not None:
+            param = extra_resharder(param, param_name)
         try:
             buffer_tensor = buffer_tensor.view(param.shape)
         except RuntimeError as err:
@@ -206,7 +209,8 @@ class ModelWeightBuffer:
 
     def copy_by_name(self, weight_name: str, param):
         dtype = self.weight_buffer_meta[weight_name]['dtype']
-        self.memory_buffers[dtype].copy_by_name(weight_name, param)
+        resharder = self.weight_buffer_meta[weight_name].get('resharder')
+        self.memory_buffers[dtype].copy_by_name(weight_name, param, resharder)
 
     def offload(self):
         for memory_buffer in self.memory_buffers.values():
diff --git a/mindspeed_rl/workers/resharding/weight_adaptor.py b/mindspeed_rl/workers/resharding/weight_adaptor.py
index d797d3ea..71d8b54b 100644
--- a/mindspeed_rl/workers/resharding/weight_adaptor.py
+++ b/mindspeed_rl/workers/resharding/weight_adaptor.py
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from typing import Callable, Type
 import re
 
 import torch
@@ -232,8 +233,56 @@ class MegatronVLLMWeightAdaptor(BaseWeightAdaptor):
 
         return weight_names_per_vpp_combined
 
+    @staticmethod
+    def _module_lookup(model: torch.nn.Module, path: str) -> Type[torch.nn.Module]:
+        module = model
+        for part in path.split(".")[:-1]:
+            if part.isdigit():
+                module = module[int(part)]
+            else:
+                module = getattr(module, part)
+
+        return type(module)
+
 
 class DeepSeekMVWeightAdaptor(MegatronVLLMWeightAdaptor):
+    @staticmethod
+    def replicated_linear_resharder(param: torch.Tensor, name: str) -> torch.Tensor:
+        if "shared_experts" not in name:
+            return param
+        from vllm.distributed.parallel_state import _TP
+
+        gathered_param = [torch.empty_like(param) for _ in range(_TP.world_size)]
+        torch.distributed.all_gather(gathered_param, param, group=_TP.device_group)
+        if "gate_up_proj" in name:
+            gate_lst = []
+            up_lst = []
+            for p in gathered_param:
+                gate, up = p.chunk(2, dim=0)
+                gate_lst.append(gate)
+                up_lst.append(up)
+            return torch.cat(gate_lst + up_lst, dim=0)
+        else:
+            return torch.cat(gathered_param, dim=1)
+
+    @classmethod
+    def _get_extra_resharder(
+            cls, model: torch.nn.Module, name: str
+    ) -> Callable[[torch.Tensor, str], torch.Tensor]:
+        from vllm.model_executor.layers.linear import ReplicatedLinear
+        # vllm_ascend patches must be imported first.
+        from vllm_ascend.patch import worker as _worker, platform as _platform
+        from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2MergedReplicatedLinear
+
+        # Handles resharding issues caused by irregular sharding between training and inference. Currently only supports:
+        #
+        # - shared experts that are TP-sharded during training but not TP-sharded during inference
+        EXTRA_RESHARDER_REGISTRY = {
+            CustomDeepseekV2MergedReplicatedLinear: cls.replicated_linear_resharder,
+            ReplicatedLinear: cls.replicated_linear_resharder,
+        }
+        return EXTRA_RESHARDER_REGISTRY.get(super()._module_lookup(model, name))
+
     """
     Megatron-vLLM WeightAdaptor for DeepSeek model architectures.
     """
@@ -278,7 +327,13 @@ class DeepSeekMVWeightAdaptor(MegatronVLLMWeightAdaptor):
             elif 'q_a_proj' in name or 'q_proj' in name:
                 continue
             else:
-                weight_buffer_meta[name] = {'shape': param.shape, 'dtype': param.dtype}
+                weight_buffer_meta[name] = {
+                    'shape': param.shape,
+                    'dtype': param.dtype,
+                    'resharder':
+                        DeepSeekMVWeightAdaptor._get_extra_resharder(model, name),
+                }
+
         return weight_buffer_meta
 
     def convert_weight_name_meta(self, weight_names):
-- 
Gitee

From 9684e2e9e720011ba4d4343329bdd10a06b6eeaf Mon Sep 17 00:00:00 2001
From: p00465316
Date: Sat, 14 Jun 2025 17:52:58 +0800
Subject: [PATCH 3/3] Add vLLM config options to improve performance

---
 mindspeed_rl/config_cls/generate_config.py  | 21 ++++++++++++++++++---
 mindspeed_rl/models/rollout/vllm_engine.py  | 18 ++++++++++++++++--
 mindspeed_rl/workers/actor_hybrid_worker.py | 10 +++++++++-
 3 files changed, 43 insertions(+), 6 deletions(-)

diff --git a/mindspeed_rl/config_cls/generate_config.py b/mindspeed_rl/config_cls/generate_config.py
index 1c83eb94..a274086c 100644
--- a/mindspeed_rl/config_cls/generate_config.py
+++ b/mindspeed_rl/config_cls/generate_config.py
@@ -22,12 +22,20 @@ class GenerateConfig(BaseConfig):
         max_num_seqs: Maximum number of sequences to process simultaneously. Default is 256.
 
         max_model_len: Maximum model length (in tokens). Default is 2048.
+        max_num_batched_tokens: The maximum number of tokens the model can run in a single batch. Default is 2048.
 
         dtype: Data type for model weights. Default is "bfloat16".
         gpu_memory_utilization: GPU memory utilization factor. Default is 0.5.
-        enforce_eager: Whether to always use eager-mode PyTorch. If True, we will disable ACL graph and always execute the model in eager mode.
-                       If False, we will use ACL graph and eager execution in hybrid for maximal performance and flexibility.
-
+        enforce_eager: Whether to always use eager-mode PyTorch. If True, we will disable ACL graph and always execute the model in eager mode.
+                       If False, we will use ACL graph and eager execution in hybrid for maximal performance and flexibility.
+        torchair_graph: Whether to enable TorchAir graph optimization. If True, uses accelerated computational graph optimizations.
+        chunked_prefill_enabled: Whether to split long-sequence prefill operations into chunks. If True, processes sequences in segments to reduce peak memory usage;
+                                 If False, processes entire sequences at once (faster but requires more memory).
+        enable_expert_parallel: Whether to enable expert parallel computation for Mixture-of-Experts (MoE) layers.
+        enable_multistream_mla: Whether to put the vector ops of MLA on a separate stream.
+        enable_multistream_moe: Whether to compute the shared experts of MoE layers on a separate stream.
+        enable_view_optimize: Whether to enable TorchAir view optimization.
+        enable_kv_nz: Whether to enable the NZ layout for the KV cache.
         sampling_config: Configuration for text generation sampling. Default values are set for various sampling parameters.
         - num_completions: The number of independent completions to generate for each input prompt. Default is 1.
         - logprobs: The number of top tokens to return log probabilities for. Default is 1.
@@ -81,6 +89,13 @@ class GenerateConfig(BaseConfig):
         self.enable_prefix_caching = False
         self.num_scheduler_steps = 1
         self.enforce_eager = True
+        self.torchair_graph = False
+        self.chunked_prefill_enabled = False
+        self.enable_expert_parallel = False
+        self.enable_multistream_mla = False
+        self.enable_multistream_moe = False
+        self.enable_view_optimize = True
+        self.enable_kv_nz = False
 
         # 采样配置的默认值,用于生成文本时的采样策略设置
         self.sampling_config = {
diff --git a/mindspeed_rl/models/rollout/vllm_engine.py b/mindspeed_rl/models/rollout/vllm_engine.py
index 29525e75..67e53f1d 100644
--- a/mindspeed_rl/models/rollout/vllm_engine.py
+++ b/mindspeed_rl/models/rollout/vllm_engine.py
@@ -69,6 +69,13 @@ class VLLMInferEngine(BaseInferEngine):
             enforce_eager: bool = False,
             limit_mm_image_per_prompt: int = 1,
             limit_mm_video_per_prompt: int = 0,
+            enable_expert_parallel: bool = False,
+            torchair_graph: bool = False,
+            chunked_prefill_enabled: bool = False,
+            enable_multistream_mla: bool = False,
+            enable_multistream_moe: bool = False,
+            enable_view_optimize: bool = True,
+            enable_kv_nz: bool = False,
             **kwargs
     ):
         """
@@ -191,11 +198,18 @@ class VLLMInferEngine(BaseInferEngine):
                 max_model_len=max_model_len,
                 seed=self.sampling_params.seed,
                 limit_mm_per_prompt=limit_mm_per_prompt_dict,
+                enable_expert_parallel=enable_expert_parallel,
                 additional_config={
                     'expert_tensor_parallel_size': infer_expert_tensor_parallel_size,
-                    'enable_graph_mode': int(os.environ.get('VLLM_ENABLE_GRAPH_MODE', '0')),
                     'ascend_scheduler_config': {},
-                }
+                    'torchair_graph_config': {
+                        'enabled': torchair_graph,
+                        'enable_multistream_mla': enable_multistream_mla,
+                        'enable_multistream_moe': enable_multistream_moe,
+                        'enable_view_optimize': enable_view_optimize,
+                        'enable_kv_nz': enable_kv_nz,
+                    },
+                },
             )
 
             self.model = self.llm.llm_engine.model_executor.driver_worker.worker.model_runner.get_model()
diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py
index 36e8ad1e..5391e922 100644
--- a/mindspeed_rl/workers/actor_hybrid_worker.py
+++ b/mindspeed_rl/workers/actor_hybrid_worker.py
@@ -478,6 +478,14 @@ class ActorHybridWorkerBase(BaseWorker):
             enforce_eager=self.generate_config.enforce_eager,
             limit_mm_image_per_prompt=self.generate_config.limit_mm_image_per_prompt,
-            limit_mm_video_per_prompt=self.generate_config.limit_mm_video_per_prompt
+            limit_mm_video_per_prompt=self.generate_config.limit_mm_video_per_prompt,
+            chunked_prefill_enabled=self.generate_config.chunked_prefill_enabled,
+            torchair_graph=self.generate_config.torchair_graph,
+            enable_expert_parallel=self.generate_config.enable_expert_parallel,
+            max_num_batched_tokens=self.generate_config.max_num_batched_tokens,
+            enable_multistream_mla=self.generate_config.enable_multistream_mla,
+            enable_multistream_moe=self.generate_config.enable_multistream_moe,
+            enable_view_optimize=self.generate_config.enable_view_optimize,
+            enable_kv_nz=self.generate_config.enable_kv_nz,
         )
 
         return rollout
-- 
Gitee
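
Editor's note (appended after the series; not part of any patch): PATCH 2/3 threads an optional per-weight 'resharder' callable from weight_buffer_meta through ModelWeightBuffer.copy_by_name into MemoryBuffer.copy_by_name, where it rewrites the training-side tensor before that tensor is viewed into the flat buffer. The sketch below imitates that plumbing with plain dictionaries and CPU tensors; the weight names, shapes, and the doubling resharder are made-up stand-ins for illustration, not MindSpeed-RL APIs.

# resharder_hook_demo.py -- illustrative sketch of the hook wired up in PATCH 2/3.
import torch


def shared_expert_resharder(param, name):
    # Stand-in for DeepSeekMVWeightAdaptor.replicated_linear_resharder:
    # rewrite the training-side tensor before it is copied into the buffer.
    if "shared_experts" in name:
        return torch.cat([param, param], dim=0)
    return param


# Per-weight metadata in the shape produced by get_weight_buffer_meta in the
# patch: shape/dtype of the inference-side tensor plus an optional hook.
weight_buffer_meta = {
    "layers.0.mlp.shared_experts.down_proj.weight": {
        "shape": torch.Size([4, 3]), "dtype": torch.float32,
        "resharder": shared_expert_resharder,
    },
    "layers.0.self_attn.o_proj.weight": {
        "shape": torch.Size([2, 3]), "dtype": torch.float32,
        "resharder": None,
    },
}

# One preallocated destination tensor per weight stands in for the flat
# MemoryBuffer views used by the real code.
buffers = {name: torch.empty(meta["shape"], dtype=meta["dtype"])
           for name, meta in weight_buffer_meta.items()}


def copy_by_name(name, param):
    """Mimics ModelWeightBuffer.copy_by_name -> MemoryBuffer.copy_by_name."""
    resharder = weight_buffer_meta[name].get("resharder")
    if resharder is not None:
        param = resharder(param, name)  # the hook runs before the view/copy
    buffers[name].view(param.shape).copy_(param)


copy_by_name("layers.0.mlp.shared_experts.down_proj.weight", torch.ones(2, 3))
copy_by_name("layers.0.self_attn.o_proj.weight", torch.zeros(2, 3))
print(buffers["layers.0.mlp.shared_experts.down_proj.weight"].shape)  # torch.Size([4, 3])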
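
The only non-trivial resharder registered in PATCH 2/3 handles the shared experts' merged gate/up projection: each tensor-parallel rank holds [gate_i; up_i] stacked along dim 0, and after the all-gather the replicated weight has to be reassembled as all gate shards followed by all up shards. A minimal single-process sketch of that chunk/cat step follows; the shapes are fabricated and the shard list simply stands in for the result of the real all-gather.

# merge_gate_up_demo.py -- illustrative only, mirrors the gate/up recombination.
import torch


def merge_gate_up_shards(shards):
    """Recombine TP shards of a merged gate/up projection weight.

    Each shard is [gate_i; up_i] stacked on dim 0; the replicated weight
    expected by the non-TP inference layer is
    [gate_0 .. gate_{n-1}; up_0 .. up_{n-1}].
    """
    gate_parts, up_parts = [], []
    for shard in shards:
        gate, up = shard.chunk(2, dim=0)
        gate_parts.append(gate)
        up_parts.append(up)
    return torch.cat(gate_parts + up_parts, dim=0)


if __name__ == "__main__":
    # Two fake TP shards, 4 rows each (2 gate rows followed by 2 up rows),
    # hidden size 3. In the real resharder these come from an all-gather.
    shards = [torch.arange(12, dtype=torch.float32).reshape(4, 3) + 100 * rank
              for rank in range(2)]
    merged = merge_gate_up_shards(shards)
    assert merged.shape == (8, 3)
    # Rows 0-3 are all gate rows (rank 0 then rank 1), rows 4-7 all up rows.
    print(merged)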