From 710b3d0893ed84f3e8bee4847b9ee5ed689cfac8 Mon Sep 17 00:00:00 2001
From: w00521005
Date: Wed, 12 Nov 2025 09:28:09 +0800
Subject: [PATCH] qwen3-moe eplb

Adapt OmniPlanner expert placement (EPLB) for qwen3-moe:
- config.yaml: max_moe_layer_num 58 -> 48, enable_dynamic on,
  normal_expert_ids 255 -> 127
- omni_planner.py / utils.py: dispatch to MindSpore mint/Tensor ops when
  ENABLE_MS_MODELS is set, keep the torch path otherwise
- npu_model_runner.py: default first_k_dense_replace to 0 when the model
  config does not provide it
- loader.py: stop setting the unused parall_config attribute
---
 omni/accelerators/placement/config.yaml       |  6 ++--
 .../placement/omni_placement/omni_planner.py  |  9 +++++-
 .../placement/omni_placement/utils.py         | 29 +++++++++++++++----
 omni/adaptors/vllm/worker/npu_model_runner.py |  2 +-
 omni/models/config_loader/loader.py           |  2 +-
 5 files changed, 37 insertions(+), 11 deletions(-)

diff --git a/omni/accelerators/placement/config.yaml b/omni/accelerators/placement/config.yaml
index 3c5dd12233..00ca5ba1cc 100644
--- a/omni/accelerators/placement/config.yaml
+++ b/omni/accelerators/placement/config.yaml
@@ -20,9 +20,9 @@ pattern_path: null
 
 
 # define max_layer_num as a constant 58 (for deepseek moe layer num 58)
-max_moe_layer_num: 58
+max_moe_layer_num: 48
 
-enable_dynamic: False
+enable_dynamic: True #False
 max_redundant_per_expert: 1 # 10
 max_redundant_per_rank: 0 # 1
 
@@ -36,7 +36,7 @@ max_top_k: 8
 
 # Adaptation for longcat zero expert
 enable_zero_expert: False
-normal_expert_ids: 255
+normal_expert_ids: 127
 
 # Optimizers
 Optimizers:
diff --git a/omni/accelerators/placement/omni_placement/omni_planner.py b/omni/accelerators/placement/omni_placement/omni_planner.py
index b998f063b7..949e15ff14 100644
--- a/omni/accelerators/placement/omni_placement/omni_planner.py
+++ b/omni/accelerators/placement/omni_placement/omni_planner.py
@@ -27,6 +27,10 @@ from pathlib import Path
 import time
 
 
+import omni.adaptors.vllm.envs as omni_envs
+if omni_envs.ENABLE_MS_MODELS:
+    from mindspore import mint
+
 class OmniPlannerMeta(type):
     """Metaclass to implement singleton pattern for OmniPlanner."""
     _instances = {}
@@ -336,7 +340,10 @@ class OmniPlanner(metaclass=OmniPlannerMeta):
         if not self.is_moe_layer(layer_idx_moe):
             return tokens, token_expert_ids, token_expert_scores
         if self.enable_rank_round_robin:
-            token_expert_ids = torch.nn.functional.embedding(token_expert_ids,expert_mapping).squeeze(-1)
+            if omni_envs.ENABLE_MS_MODELS:
+                token_expert_ids = mint.nn.functional.embedding(token_expert_ids, expert_mapping).squeeze(-1)
+            else:
+                token_expert_ids = torch.nn.functional.embedding(token_expert_ids,expert_mapping).squeeze(-1)
         else:
             batch_size = token_expert_ids.shape[0]
             token_expert_ids = expert_mapping[token_expert_ids, self.redundant_bias[:batch_size,] % self.num_redundant_per_expert[layer_idx_moe][token_expert_ids]]
diff --git a/omni/accelerators/placement/omni_placement/utils.py b/omni/accelerators/placement/omni_placement/utils.py
index 70f1001cf5..8f8afb8484 100644
--- a/omni/accelerators/placement/omni_placement/utils.py
+++ b/omni/accelerators/placement/omni_placement/utils.py
@@ -12,6 +12,13 @@
 from . import omni_placement
 from collections import defaultdict
 from pathlib import Path
+import omni.adaptors.vllm.envs as omni_envs
+if omni_envs.ENABLE_MS_MODELS:
+    from mindspore import Tensor
+    from mindspore.mint import unbind
+else:
+    from torch import Tensor, unbind
+
 BASE_DIR = Path(__file__).resolve().parent.parent
 DEFAULT_YAML_PATH = BASE_DIR / "config.yaml"
 DEFAULT_YAML_PATH = str(DEFAULT_YAML_PATH)
@@ -98,10 +105,10 @@ def convert_param_dict_to_list(param_dict,layer_func, layer_func_param={}):
 
     # Check that every element is a tensor and that shapes are consistent
     for t in tensor_list:
-        if not isinstance(t, torch.Tensor):
+        if not isinstance(t, Tensor):
             raise TypeError("All elements in tensor_list must be torch.Tensor")
 
-    unbound_tensors = [torch.unbind(t, dim=0) for t in tensor_list]
+    unbound_tensors = [unbind(t, dim=0) for t in tensor_list]
 
     return [list(group) for group in zip(*unbound_tensors)]
 
@@ -142,13 +149,25 @@
     if not isinstance(param_list, list):
         raise TypeError("param_list must be a list")
 
+    def get_element_size(tensor: Tensor):
+        if omni_envs.ENABLE_MS_MODELS:
+            return tensor.flatten()[-1].nbytes
+        else:
+            return tensor.element_size()
+
+    def get_tensor_dtype(tensor: Tensor):
+        if omni_envs.ENABLE_MS_MODELS:
+            return str(tensor.dtype).lower()
+        else:
+            return str(tensor.dtype)[len('torch.'):]
+
     def tensor_to_omni_tensor(tensor, tensor_name):
-        if not isinstance(tensor, torch.Tensor):
+        if not isinstance(tensor, Tensor):
             raise TypeError("All elements must be torch.Tensor")
         length = tensor.numel()
-        element_size = tensor.element_size()
+        element_size = get_element_size(tensor)
         address = tensor.data_ptr()
-        dtype = str(tensor.dtype)[len('torch.'):]
+        dtype = get_tensor_dtype(tensor)
         weight = omni_placement.Tensor(
             data_ptr=address,
             length=length,
diff --git a/omni/adaptors/vllm/worker/npu_model_runner.py b/omni/adaptors/vllm/worker/npu_model_runner.py
index 57ded6bfaf..254a5911ee 100644
--- a/omni/adaptors/vllm/worker/npu_model_runner.py
+++ b/omni/adaptors/vllm/worker/npu_model_runner.py
@@ -1223,7 +1223,7 @@ class NPUModelRunner(GPUModelRunner):
         if model_extra_config.task_config.enable_omni_placement:
             from omni.accelerators.placement.omni_placement.omni_planner import OmniPlanner
             first_k_dense_replace_names = ['num_dense_layers', 'first_k_dense_replace']
-            first_k_dense_replace = get_attr_by_names(self.model.config, first_k_dense_replace_names, 3)
+            first_k_dense_replace = get_attr_by_names(self.model.config, first_k_dense_replace_names, 0)
             param_dict = dict(self.model.named_parameters())
             self.planner = OmniPlanner()
             self.planner.init_dram_weights(param_dict, first_k_dense_replace=first_k_dense_replace)
diff --git a/omni/models/config_loader/loader.py b/omni/models/config_loader/loader.py
index 16a1d58f09..ede72f7897 100644
--- a/omni/models/config_loader/loader.py
+++ b/omni/models/config_loader/loader.py
@@ -309,7 +309,7 @@ def _init_model_extra_config(task_config) -> ModelExtraConfig:
         operator_opt_config = ModelOperatorOptConfig(**filter_dict_by_dataclass(ModelOperatorOptConfig, config_data['operator_optimizition_config']))
 
     setattr(model_extra_config, 'task_config', task_config)
-    setattr(model_extra_config, 'parall_config', parall_config)
+    # setattr(model_extra_config, 'parall_config', parall_config) #unused for qwen3-moe
     setattr(model_extra_config, 'operator_opt_config', operator_opt_config)
 
 model_extra_config = ModelExtraConfig()
-- 
Gitee
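
For reference, a minimal, self-contained sketch of the expert-id remapping that the omni_planner.py hunk dispatches between torch and MindSpore mint. It is illustrative only: expert_mapping is assumed here to be a (num_logical_experts, 1) int64 table mapping each logical expert id to a physical slot, and token_expert_ids is assumed to hold the top-k logical expert ids per token; none of the shapes or values below come from the repository.

# Illustrative sketch (torch path only); mirrors the embedding-based remap
# used when enable_rank_round_robin is set. All names/shapes are assumptions.
import torch

num_logical_experts = 8
# Assumed placement table: row i holds the physical slot for logical expert i.
expert_mapping = torch.arange(num_logical_experts, dtype=torch.int64).flip(0).unsqueeze(-1)  # (8, 1)

# Assumed router output: top-2 logical expert ids for 3 tokens.
token_expert_ids = torch.tensor([[0, 3], [7, 2], [5, 5]], dtype=torch.int64)

# F.embedding is a row gather: it selects row token_expert_ids[i, j] of
# expert_mapping, yielding shape (3, 2, 1); squeeze(-1) drops the trailing dim.
remapped = torch.nn.functional.embedding(token_expert_ids, expert_mapping).squeeze(-1)

# Equivalent plain indexing, shown only to make the gather explicit.
assert torch.equal(remapped, expert_mapping[token_expert_ids].squeeze(-1))
print(remapped)  # physical expert ids per token, shape (3, 2)

The patch adds the parallel mint.nn.functional.embedding branch so the same lookup runs when ENABLE_MS_MODELS selects the MindSpore backend.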