diff --git a/tests/st/python/cases_parallel/vllm_qwen3_moe.py b/tests/st/python/cases_parallel/vllm_qwen3_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..214773ede556ae913f8aa80fb2d4804ce3734894 --- /dev/null +++ b/tests/st/python/cases_parallel/vllm_qwen3_moe.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2025 Huawei Technologites Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# isort:skip_file +"""test vllm qwen3 moe.""" +import os + +from tests.st.python import utils + + +def teardown_function(): + utils.cleanup_subprocesses() + + +env_manager = utils.EnvVarManager() +# def env +env_vars = { + "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"), + "HCCL_OP_EXPANSION_MODE": "AIV", + "VLLM_MS_MODEL_BACKEND": "Native", + "MS_ALLOC_CONF": "enable_vmm:True", + "LCCL_DETERMINISTIC": "1", + "HCCL_DETERMINISTIC": "true", + "ATB_MATMUL_SHUFFLE_K_ENABLE": "0", + "ATB_LLM_LCOC_ENABLE": "0", + "VLLM_USE_V1": "1", +} +# set env +env_manager.setup_ai_environment(env_vars) +import vllm_mindspore +from vllm import LLM, SamplingParams + + +def test_vllm_qwen3_30b_a3b(): + """ + test case qwen3-30B-A3B + """ + + # Sample prompts. + prompts = [ + "<|im_start|>user\n将文本分类为中性、负面或正面。 " + "\n文本:我认为这次假期还可以。 \n情感:" + "<|im_end|>\n<|im_start|>assistant\n", + ] + + # Create a sampling params object. + sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_k=1) + + # Create an LLM. + llm = LLM(model="/home/workspace/mindspore_dataset/weight/Qwen3-30B-A3B", + gpu_memory_utilization=0.9, + tensor_parallel_size=2, + max_model_len=4096) + # Generate texts from the prompts. + # The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params) + except_list = ['\n好的,我现在需要处理这个文本分类'] + # Print the outputs. 
+ for i, output in enumerate(outputs): + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + assert generated_text == except_list[ + i], f"Expected: {except_list[i]}, but got: {generated_text}" + + # unset env + env_manager.unset_all() diff --git a/tests/st/python/test_cases_parallel.py b/tests/st/python/test_cases_parallel.py index 4346978fc1fde4cd7f680dc1830c797ba15525f4..6dd1e47addfee50a6f367e85eb8b36990e232576 100644 --- a/tests/st/python/test_cases_parallel.py +++ b/tests/st/python/test_cases_parallel.py @@ -200,6 +200,8 @@ def test_cases_parallel_part7(): (1, "cases_parallel/vllm_qwen2_5_vl_7b_v1.py" "::test_qwen2_5_vl_7b_v1_video_infer", "vllm_qwen2_5_vl_7b_v1_video_infer.log"), + (2, "cases_parallel/vllm_qwen3_moe.py::test_vllm_qwen3_30b_a3b", + "test_vllm_qwen3_30b_a3b.log"), ] run_tasks(cases) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/__init__.py b/vllm_mindspore/model_executor/layers/fused_moe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..70d81228e661d06bad265789d80d8a90b9a1bb16 --- /dev/null +++ b/vllm_mindspore/model_executor/layers/fused_moe/__init__.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/__init__.py +# +# Copyright 2025 Huawei Technologies Co., Ltd. +# Copyright 2025 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from vllm_mindspore.model_executor.layers.fused_moe.layer import FusedMoE + +__all__ = ["FusedMoE"] diff --git a/vllm_mindspore/model_executor/layers/fused_moe/config.py b/vllm_mindspore/model_executor/layers/fused_moe/config.py new file mode 100644 index 0000000000000000000000000000000000000000..c34836daab7796591db28f24eac09ac4ad04e835 --- /dev/null +++ b/vllm_mindspore/model_executor/layers/fused_moe/config.py @@ -0,0 +1,226 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/config.py +# +# Copyright 2025 Huawei Technologies Co., Ltd. +# Copyright 2025 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass + +import mindspore as ms +import vllm.envs as envs +from vllm.config import ParallelConfig, get_current_vllm_config +from vllm.distributed import (get_dp_group, get_ep_group, + get_tensor_model_parallel_rank) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@dataclass +class FusedMoEParallelConfig: + tp_size: int + dp_size: int + ep_size: int + tp_rank: int + dp_rank: int + ep_rank: int + + use_ep: bool # whether to use EP or not + + @property + def use_all2all_kernels(self): + return self.dp_size > 1 and self.use_ep and self.tp_size == 1 + + @staticmethod + def make(tp_size_: int, dp_size_: int, + vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": + """ + Determine MoE parallel configuration. Based on the input tp_size_, + dp_size_, ep_size_ and vllm's parallel config, determine what + level's of parallelism to use in the fused moe layer. + + Args: + tp_size_ (int): tp_size passed into the FusedMoE constructor. + dp_size_ (int): dp_size passed into the FusedMoE constructor. + ep_size_ (int): ep_size passed into the FusedMoE constructor. + vllm_parallel_config (ParallelConfig): vllm's parallel config + object. + + Examples: + When there is no parallelism requested, i.e. tp_size_ = dp_size_ = 1, + we simply return the sizes unaltered and the ranks set to 0. + + Expert Parallelism is considered only when either dp_size_ or tp_size_ + is non trivial. + + When TP = 2, DP = 1 and EP = False, the configuration on different + devices, + - device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} // + legend : {size, rank} + - device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0} + - Comment : Tensors are sharded across 2 devices. + + When TP = 1, DP = 2 and EP = False, the configuration on different + devices, + - device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0} + - device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0} + - Comment: There are 2 engine instances and the tensors are sharded + across 2 decvices. + + When TP = 2, DP = 2 and EP = False, the configuration on different + devices, + - device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0} + - device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0} + - device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0} + - device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0} + - Comment: There are 2 engine instances and the tensors are sharded + across 4 devices. + + When, TP = 2, DP = 1 and EP = True, the configuration on different + devices, + - device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0} + - device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1} + - Comment: The experts are split between the 2 devices. + + When, TP = 1, DP = 2 and EP = True, the configuration on different + devices, + - device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0} + - device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1} + - Comment: There are 2 engine instances and the experts are split + between the 2 devices. + + When TP = 2, DP = 2 and EP = True, the configuration on different + devices, + - device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0} + - device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1} + - device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2} + - device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3} + - Comment: There are 2 engine instances and the experts are split + between the 4 devices. + """ + + def flatten_tp_across_dp(dp_rank: int): + tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank() + # There are actually dp_size_ * tp_size_ devices. Update tp_size + # and tp_rank so we shard across all devices. 
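+            # e.g. dp_size_ = 2, tp_size_ = 2: the flattened TP group has
+            # size 4, dp_rank 0 covers tp ranks 0..1 and dp_rank 1 covers
+            # tp ranks 2..3, matching the TP = {4, x} layout in the
+            # docstring above.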
+ tp_size = dp_size_ * tp_size_ + tp_rank = dp_rank * tp_size_ + tp_rank + return tp_size, tp_rank + + use_ep = (dp_size_ * tp_size_ > 1 + and vllm_parallel_config.enable_expert_parallel) + + dp_size = dp_size_ + dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0 + tp_size, tp_rank = flatten_tp_across_dp(dp_rank) + + if not use_ep: + return FusedMoEParallelConfig(tp_size=tp_size, + tp_rank=tp_rank, + dp_size=dp_size, + dp_rank=dp_rank, + ep_size=1, + ep_rank=0, + use_ep=False) + # DP + EP / TP + EP / DP + TP + EP + assert use_ep + + vllm_config = get_current_vllm_config() + # custom_ep_size is used for tp + ep parallel, + # which is not supported in original vllm. + if vllm_config.additional_config is not None and \ + vllm_config.additional_config.get("expert_parallel", None) \ + is not None: + custom_ep_size = int( + vllm_config.additional_config.get("expert_parallel", None)) + ep_size = custom_ep_size + tp_size = tp_size // custom_ep_size + tp_rank = tp_rank % tp_size + ep_rank = get_ep_group().rank_in_group // tp_size + return FusedMoEParallelConfig(tp_size=tp_size, + tp_rank=tp_rank, + dp_size=dp_size, + dp_rank=dp_rank, + ep_size=ep_size, + ep_rank=ep_rank, + use_ep=True) + else: + # In EP, each device owns a set of experts fully. + # There is no tensor parallel update tp_size, tp_rank, + # ep_size and ep_rank to reflect that. + ep_size = tp_size + ep_rank = tp_rank + return FusedMoEParallelConfig(tp_size=1, + tp_rank=0, + dp_size=dp_size, + dp_rank=dp_rank, + ep_size=ep_size, + ep_rank=ep_rank, + use_ep=True) + + +@dataclass +class FusedMoEConfig: + num_experts: int + experts_per_token: int + hidden_dim: int + + num_local_experts: int + moe_parallel_config: FusedMoEParallelConfig + + in_dtype: ms.dtype.Type # The activation type. + quant_dtype: ms.dtype.Type = None + + # TODO: add more quantization params, blocked, per-token, etc. + block_size: int = 128 + + max_num_tokens: int = envs.VLLM_FUSED_MOE_CHUNK_SIZE + + optim_tp_ep_gating_perf: bool = False + + def __post_init__(self): + if self.dp_size > 1: + logger.debug("Using FusedMoEConfig::max_num_tokens=%d", + self.max_num_tokens) + + @property + def tp_size(self): + return self.moe_parallel_config.tp_size + + @property + def dp_size(self): + return self.moe_parallel_config.dp_size + + @property + def ep_size(self): + return self.moe_parallel_config.ep_size + + @property + def tp_rank(self): + return self.moe_parallel_config.tp_rank + + @property + def dp_rank(self): + return self.moe_parallel_config.dp_rank + + @property + def ep_rank(self): + return self.moe_parallel_config.ep_rank + + @property + def use_ep(self): + return self.moe_parallel_config.use_ep diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..a5ca779047ddb5e612537ec3a6c1aff39c79af8a --- /dev/null +++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py @@ -0,0 +1,275 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/fused_moe.py +# +# Copyright 2025 Huawei Technologies Co., Ltd. +# Copyright 2025 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fused MoE kernel with MindSpore.""" + +from typing import Optional + +import mindspore as ms +import numpy as np +from mindspore import Tensor, mint, nn, ops +from mindspore.ops.auto_generate import (GroupedMatmulV4, MoeDistributeCombine, + MoeDistributeDispatch, + MoeInitRoutingV2, MoeTokenUnpermute) +from vllm.distributed.parallel_state import get_ep_group + +from vllm_mindspore.utils import is_910b + + +def fused_topk( + hidden_states: Tensor, + gating_output: Tensor, + topk: int, + renormalize: bool, + indices_type=None, +) -> tuple[Tensor, Tensor]: + score = mint.softmax(gating_output, dim=-1) + topk_weights, topk_ids = mint.topk(score, k=topk, dim=-1) + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + if indices_type is not None: + topk_ids = topk_ids.to(indices_type) + return topk_weights, topk_ids + + +def grouped_topk( + hidden_states: Tensor, + gating_output: Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[Tensor] = None +) -> tuple[Tensor, Tensor]: + raise NotImplementedError("grouped_topk is not implemented.") + + +class FusedExperts(nn.Cell): + + def __init__(self, moe_config): + super().__init__() + self.group_matmul_ops = GroupedMatmulV4() + self.moe_init_routing_op = MoeInitRoutingV2() + self.moe_token_unpermute = MoeTokenUnpermute() + + self.pure_tp = False + self.pure_ep = False + self.tp_ep = False + + self.experts_num = moe_config.num_experts + self.local_expert_num = moe_config.num_local_experts + self.ep_size = moe_config.moe_parallel_config.ep_size + self.ep_rank = moe_config.moe_parallel_config.ep_rank + self.dp_size = moe_config.moe_parallel_config.dp_size + self.optim_tp_ep_gating_perf = moe_config.optim_tp_ep_gating_perf + if self.ep_size > 1: + experts_num_map = [(self.experts_num // self.ep_size) + for _ in range(self.ep_size - 1)] + experts_num_map.append(self.experts_num - + ((self.experts_num // self.ep_size) * + (self.ep_size - 1))) + self.experts_num_map = experts_num_map + self.ep_group = get_ep_group().device_group._name + + # pure ep mode + if moe_config.moe_parallel_config.ep_size > 1 and \ + moe_config.moe_parallel_config.tp_size == 1: + self.pure_ep = True + + self.use_all2all_kernels = \ + moe_config.moe_parallel_config.use_all2all_kernels + + if self.use_all2all_kernels: + # some configuration for dispatch and combine + self.dispatch = MoeDistributeDispatch() + self.combine = MoeDistributeCombine() + self.dispatch_tp_world_size = 0 if is_910b() else 1 + self.dispatch_shared_expert_num = 0 if is_910b() else 1 + self.max_bs = 256 if is_910b() else 512 + self.max_bs *= self.ep_size + + # pure tp mode + elif moe_config.moe_parallel_config.ep_size == 1 and \ + moe_config.moe_parallel_config.tp_size >= 1: + self.pure_tp = True + # tp + ep mode + else: + self.tp_ep = True + experts_num_map_np = np.array(self.experts_num_map, dtype=np.int32) + experts_num_map_cu_np = np.cumsum(experts_num_map_np, + dtype=np.int32) + self.expert_start_index = 0 if self.ep_rank == 0 else int( + 
experts_num_map_cu_np[self.ep_rank - 1]) + + def construct(self, + hidden_states: Tensor, + w1: Tensor, + w2: Tensor, + topk_weights: Tensor, + topk_ids: Tensor, + activation: str = "silu", + global_num_experts: int = -1, + apply_router_weight_on_input: bool = False) -> Tensor: + + if self.pure_tp: + hidden_states = self.run_tp_moe(hidden_states, w1, w2, topk_ids, + topk_weights, activation, + global_num_experts, + apply_router_weight_on_input) + # ep_size > 1 : pure ep or tp + ep + elif self.pure_ep: + # pure ep + hidden_states = self.run_ep_moe(hidden_states, w1, w2, topk_ids, + topk_weights, activation, + global_num_experts, + apply_router_weight_on_input) + # tp_size > 1 : tp + ep + else: + hidden_states = self.run_tp_ep_moe(hidden_states, w1, w2, topk_ids, + topk_weights, activation, + global_num_experts, + apply_router_weight_on_input) + + return hidden_states + + def _gate_activation(self, gate, activation): + if activation == "silu": + return mint.nn.functional.silu(gate) + elif activation == "gelu": + return mint.nn.functional.gelu(gate) + else: + raise ValueError(f"Unsupported activation function: {activation}") + + def _group_matmul(self, hidden_states, weight, group_list): + return self.group_matmul_ops([hidden_states], [weight], + None, + None, + None, + None, + None, + None, + group_list, + split_item=3, + group_type=0, + group_list_type=1)[0] + + def _ffn(self, hidden_state, w1, w2, group_list, activation): + gate_hidden_out = self._group_matmul(hidden_state, w1, group_list) + gate, hidden = mint.split(gate_hidden_out, + (w1.shape[2] // 2, w1.shape[2] // 2), -1) + gate = self._gate_activation(gate, activation) + hidden = mint.mul(hidden, gate) + expert_output = self._group_matmul(hidden, w2, group_list) + expert_output = mint.nan_to_num(expert_output, 0, 0, 0) + return expert_output + + def run_tp_moe(self, hidden_states, w1, w2, topk_ids, topk_weights, + activation, global_num_experts, + apply_router_weight_on_input): + topk_weights = topk_weights.astype(hidden_states.dtype) + topk_ids = topk_ids.astype(ms.int32) + + sorted_input_tensor, unsort_map, group_list, _ = \ + self.moe_init_routing_op( + hidden_states, + topk_ids, + active_num=0, + expert_capacity=0, + expert_num=global_num_experts, + drop_pad_mode=0, + expert_tokens_count_or_cumsum_flag=2, + expert_tokens_before_capacity_flag=True) + + group_list = group_list.astype(ms.int64) + + expert_output = self._ffn(sorted_input_tensor, w1, w2, group_list, + activation) + + moe_output = self.moe_token_unpermute(permuted_tokens=expert_output, + sorted_indices=unsort_map, + probs=topk_weights, + padded_mode=False, + restore_shape=None) + return moe_output + + def run_tp_ep_moe(self, hidden_states, w1, w2, topk_ids, topk_weights, + activation, global_num_experts, + apply_router_weight_on_input): + topk_weights = topk_weights.astype(hidden_states.dtype) + topk_ids = topk_ids.astype(ms.int32) + + if self.dp_size > 1 or not self.optim_tp_ep_gating_perf: + topk_mask = topk_ids < self.expert_start_index + local_topk_ids = topk_ids - self.expert_start_index + local_topk_ids = local_topk_ids.astype(ms.int32) + # trick: if tp + ep moe, means ep_size > 1, + # and expert will be distributed across ep_size, + # so max(local_topk_ids) < self.experts_num - 1. + # It will allow ffn not compute the expert output, + # which are not assigned to this ep rank. 
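+            # e.g. 128 experts, ep_size = 4, ep_rank = 1: ids outside
+            # [32, 64) end up with a local id >= 32 (ids < 32 become 127),
+            # so their routing weights are zeroed below and they fall
+            # outside the truncated group_list.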
+ local_topk_ids = ops.masked_fill(local_topk_ids, topk_mask, + self.experts_num - 1) + else: + local_topk_ids = topk_ids + + weight_mask = local_topk_ids >= self.local_expert_num + topk_weights = ops.masked_fill(topk_weights, weight_mask, 0) + + sorted_input_tensor, unsort_map, group_list, _ = \ + self.moe_init_routing_op( + hidden_states, + local_topk_ids, + active_num=0, + expert_capacity=0, + expert_num=global_num_experts, + drop_pad_mode=0, + expert_tokens_count_or_cumsum_flag=2, + expert_tokens_before_capacity_flag=True) + + group_list = group_list[:self.local_expert_num] + group_list = group_list.astype(ms.int64) + expert_output = self._ffn(sorted_input_tensor, w1, w2, group_list, + activation) + moe_output = self.moe_token_unpermute(permuted_tokens=expert_output, + sorted_indices=unsort_map, + probs=topk_weights, + padded_mode=False, + restore_shape=None) + return moe_output + + def run_ep_moe(self, hidden_states, w1, w2, topk_ids, topk_weights, + activation, global_num_experts, + apply_router_weight_on_input): + topk_weights = topk_weights.astype(hidden_states.dtype) + topk_ids = topk_ids.astype(ms.int32) + + return self._ep_with_dispatch_combine(hidden_states, w1, w2, topk_ids, + topk_weights, activation, + global_num_experts, + apply_router_weight_on_input) + + def _ep_with_dispatch_combine(self, hidden_states, w1, w2, topk_ids, + topk_weights, activation, global_num_experts, + apply_router_weight_on_input): + """fused ops, moe feed forward with dispatch and combine.""" + # TODO: to implement ep parallel with dispatch and combine. + raise NotImplementedError("ep parallel with dispatch and combine " + "is not implemented.") diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py new file mode 100644 index 0000000000000000000000000000000000000000..ae6fe301b1ec875222f52747312609aa1a2ce45e --- /dev/null +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -0,0 +1,836 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/layer.py +# +# Copyright 2025 Huawei Technologies Co., Ltd. +# Copyright 2025 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Fused MoE layers with MindSpore.""" + +from abc import abstractmethod +from typing import Callable, Optional + +import numpy as np +import vllm.envs as envs +from mindspore import Parameter, Tensor, from_numpy, mint, nn, ops +from vllm.config import get_current_vllm_config +from vllm.distributed import (get_dp_group, get_ep_group, + get_tensor_model_parallel_world_size, + get_tp_group) +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs + +from vllm_mindspore.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, FusedMoEParallelConfig) +from vllm_mindspore.model_executor.layers.fused_moe.fused_moe import ( + FusedExperts, fused_topk, grouped_topk) +from vllm_mindspore.model_executor.layers.quantization.base_config import ( + QuantizeMethodBase) +from vllm_mindspore.model_executor.model_loader.weight_utils import ( + split_loaded_weight) + +logger = init_logger(__name__) + + +class FusedMoEMethodBase(QuantizeMethodBase): + + @abstractmethod + def create_weights(self, layer: nn.Cell, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype, **extra_weight_attrs): + raise NotImplementedError + + @abstractmethod + def apply( + self, + layer: nn.Cell, + x: Tensor, + router_logits: Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> Tensor: + raise NotImplementedError + + +class UnquantizedFusedMoEMethod(FusedMoEMethodBase, nn.Cell): + """MoE method without quantization.""" + + def __init__(self, moe: FusedMoEConfig): + super().__init__() + self.fused_experts = FusedExperts(moe) + self.moe = moe + + def create_weights(self, layer: nn.Cell, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype, **extra_weight_attrs): + # Fused gate_up_proj (column parallel) + # Transpose the weight to make it compatible with the GroupMatMul kernel + w13_weight = Parameter(mint.empty(num_experts, + hidden_size, + 2 * intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.insert_param_to_cell("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + # Mark the weight as transposed, so the weight loader can know it. 
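+        # The loaded checkpoint weight per expert is [out_features, in_features],
+        # while this parameter keeps [in_features, out_features], so the
+        # weight loader swaps the last two axes on load.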
+ set_weight_attrs(w13_weight, {"is_transposed": True}) + + # down_proj (row parallel) + w2_weight = Parameter(mint.empty(num_experts, + intermediate_size_per_partition, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.insert_param_to_cell("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + set_weight_attrs(w2_weight, {"is_transposed": True}) + + def apply( + self, + layer: nn.Cell, + x: Tensor, + router_logits: Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> Tensor: + return self.forward_npu( + x=x, + layer=layer, + router_logits=router_logits, + top_k=top_k, + renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + topk_group=topk_group, + num_expert_group=num_expert_group, + global_num_experts=global_num_experts, + expert_map=expert_map, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input) + + def forward_npu( + self, + layer: nn.Cell, + x: Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> Tensor: + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + indices_type=None) + + return self.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input) + + +def determine_expert_map(ep_size: int, ep_rank: int, global_num_experts: int): + """ + use numpy rather than tensor because tensor operation will be on NPU + with mindspore, which is slow. + """ + assert ep_size > 0 + if ep_size == 1: + return (global_num_experts, None) + + local_num_experts = global_num_experts // ep_size + + # Create a numpy array of size global_num_experts filled with -1 + expert_map = np.full((global_num_experts, ), -1, dtype=np.int32) + # Create an expert map for the local experts + if ep_rank < (ep_size - 1): + # Each non-last rank gets local_num_experts experts. + expert_map[ep_rank * local_num_experts: + (ep_rank + 1) * local_num_experts] = \ + np.arange(0, local_num_experts, dtype=np.int32) + else: + # All remaining experts are assigned to the last rank. 
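+        # e.g. global_num_experts = 10, ep_size = 4: ranks 0..2 own 2 experts
+        # each and the last rank owns the remaining 4.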
+ local_num_experts = global_num_experts - ep_rank * local_num_experts + expert_map[-local_num_experts:] = np.arange(0, + local_num_experts, + dtype=np.int32) + return (local_num_experts, expert_map) + + +class FusedMoE(nn.Cell): + """FusedMoE layer for MoE models. + + This layer contains both MergedColumnParallel weights (gate_up_proj / + w13) and RowParallelLinear weights (down_proj/ w2). + + Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We + copy that naming convention here and handle any remapping in the + load_weights function in each model implementation. + + Args: + num_experts: Number of experts in the model + top_k: Number of experts selected for each token + hidden_size: Input hidden state size of the transformer + intermediate_size: Intermediate size of the experts + params_dtype: Data type for the parameters. + reduce_results: Whether to all all_reduce on the output of the layer + renomalize: Whether to renormalize the logits in the fused_moe kernel + quant_config: Quantization configure. + """ + + def __init__( + self, + num_experts: int, # Global number of experts + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype=None, + reduce_results: bool = False, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + ep_size: Optional[int] = None, + dp_size: Optional[int] = None, + prefix: str = "", + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + *, + optim_tp_ep_gating_perf: bool = False, + ): + super().__init__() + + # TODO: to support apply_router_weight_on_input + if apply_router_weight_on_input: + raise NotImplementedError("apply_router_weight_on_input" + "is not supported yet") + + if params_dtype is None: + params_dtype = get_current_vllm_config().model_config.dtype + self.params_dtype = params_dtype + + vllm_config = get_current_vllm_config() + self.moe_parallel_config: FusedMoEParallelConfig = ( + FusedMoEParallelConfig.make( + tp_size_=(tp_size if tp_size is not None else + get_tensor_model_parallel_world_size()), + dp_size_=(dp_size if dp_size is not None else + get_dp_group().world_size), + vllm_parallel_config=vllm_config.parallel_config)) + + self.global_num_experts = num_experts + + # Determine expert maps + if self.use_ep: + self.local_num_experts, self.expert_map = determine_expert_map( + ep_size=self.ep_size, + ep_rank=self.ep_rank, + global_num_experts=self.global_num_experts) + else: + self.local_num_experts, self.expert_map = (self.global_num_experts, + None) + + # Determine the moe parallel mode. + # pure_tp means using tensor parallelism only, no expert parallelism. + self.pure_tp = False + self.tp_ep = False + self.pure_ep = False + + # self.ep_size == 1, means use tensor parallelism to compute moe. + if self.ep_size == 1: + self.pure_tp = True + # self.ep_size > 1, means use expert parallelism or + # expert parallelism mix tensor parallelism. 
+ else: + if self.tp_size == 1: + self.pure_ep = True + else: + self.tp_ep = True + + self.optim_tp_ep_gating_perf = \ + optim_tp_ep_gating_perf and self.tp_ep + + if self.ep_rank < (self.ep_size - 1): + self.expert_start_index = self.ep_rank * self.local_num_experts + self.expert_end_index = (self.ep_rank + 1) * self.local_num_experts + else: + self.expert_start_index = self.ep_rank * self.local_num_experts + self.expert_end_index = self.global_num_experts + + self.top_k = top_k + + assert intermediate_size % self.tp_size == 0 + self.hidden_size = hidden_size + self.intermediate_size_per_partition = intermediate_size // self.tp_size + self.reduce_results = reduce_results + self.renormalize = renormalize + self.use_grouped_topk = use_grouped_topk + if self.use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + self.num_expert_group = num_expert_group + self.topk_group = topk_group + self.custom_routing_function = custom_routing_function + self.scoring_func = scoring_func + self.e_score_correction_bias = e_score_correction_bias + self.apply_router_weight_on_input = apply_router_weight_on_input + self.activation = activation + + if self.scoring_func != "softmax" and not self.use_grouped_topk: + raise ValueError("Only softmax scoring function is supported for " + "non-grouped topk.") + + moe = FusedMoEConfig( + num_experts=self.global_num_experts, + experts_per_token=top_k, + hidden_dim=hidden_size, + num_local_experts=self.local_num_experts, + moe_parallel_config=self.moe_parallel_config, + # TODO (bnell): this needs to be fixed for quantized types. + in_dtype=params_dtype, + max_num_tokens=envs.VLLM_FUSED_MOE_CHUNK_SIZE, + optim_tp_ep_gating_perf=self.optim_tp_ep_gating_perf, + ) + self.moe_config = moe + self.quant_config = quant_config + + # Note: get_quant_method will look at the layer's local_num_experts + # for heuristic purposes, so it must be initialized first. + quant_method: Optional[QuantizeMethodBase] = None + + if quant_config is None: + quant_method = UnquantizedFusedMoEMethod(moe) + else: + quant_method = quant_config.get_quant_method(self, prefix) + + assert quant_method is not None + assert isinstance(quant_method, FusedMoEMethodBase) + self.quant_method = quant_method + + moe_quant_params = { + "num_experts": self.local_num_experts, + "hidden_size": hidden_size, + "intermediate_size_per_partition": + self.intermediate_size_per_partition, + "params_dtype": params_dtype, + "weight_loader": self.weight_loader, + } + # need full intermediate size pre-sharding for WNA16 act order + if (self.quant_method.__class__.__name__ + in ("GPTQMarlinMoEMethod", + "CompressedTensorsWNA16MarlinMoEMethod", + "CompressedTensorsWNA16MoEMethod")): + moe_quant_params["intermediate_size_full"] = intermediate_size + + self.quant_method.create_weights(layer=self, **moe_quant_params) + + # Initialize some communication ops and group. 
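+        # The DP-wide AllGather/AllReduce/ReduceScatter below are only needed
+        # when the MoE runs TP (pure TP or TP + EP) under a data-parallel
+        # context; pure EP relies on dispatch/combine instead.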
+ self.dp_group = get_dp_group().device_group._name + self.ep_group = get_ep_group().device_group._name + + self.tp_world_size = get_tensor_model_parallel_world_size() + self.tp_group = get_tp_group().device_group._name + self.all_reduce_from_tp_group = ops.AllReduce(group=self.tp_group) + + if (self.pure_tp or self.tp_ep) and self.dp_size > 1: + self.all_gather_from_dp_group = ops.AllGather(group=self.dp_group) + self.all_reduce_from_dp_group = ops.AllReduce(group=self.dp_group) + self.reduce_scatter_from_dp_group = ops.ReduceScatter( + group=self.dp_group) + + @property + def tp_size(self): + return self.moe_parallel_config.tp_size + + @property + def dp_size(self): + return self.moe_parallel_config.dp_size + + @property + def ep_size(self): + return self.moe_parallel_config.ep_size + + @property + def tp_rank(self): + return self.moe_parallel_config.tp_rank + + @property + def dp_rank(self): + return self.moe_parallel_config.dp_rank + + @property + def ep_rank(self): + return self.moe_parallel_config.ep_rank + + @property + def use_ep(self): + return self.moe_parallel_config.use_ep + + @property + def use_all2all_kernels(self): + return self.moe_parallel_config.use_all2all_kernels + + def _load_w13(self, param: Parameter, shard_dim: int, shard_id: str, + loaded_weight: Tensor, expert_id: int, tp_rank: int): + is_param_transpose = param.is_transposed \ + if hasattr(param, "is_transposed") else False + + # Index the loaded weight for tp sharding. + # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim + if is_param_transpose: + shard_size = param.shape[-1] // 2 + else: + shard_size = param.shape[-2] // 2 + + loaded_weight = split_loaded_weight(loaded_weight, shard_dim, + shard_size * tp_rank, shard_size) + + if is_param_transpose: + loaded_weight = from_numpy(loaded_weight.swapaxes(-1, -2)) + else: + loaded_weight = from_numpy(loaded_weight) + + # Narrow parameter and load. + # w1, gate_proj: Load into first logical weight of w13. + if shard_id == "w1": + if is_param_transpose: + param[expert_id, :, 0:shard_size] = loaded_weight + else: + param[expert_id, 0:shard_size, :] = loaded_weight + # w3, up_proj: Load into second logical weight of w13. + else: + assert shard_id == "w3" + if is_param_transpose: + param[expert_id, :, shard_size:shard_size * 2] = loaded_weight + else: + param[expert_id, shard_size:shard_size * 2, :] = loaded_weight + + def _load_w2(self, + param: Parameter, + shard_dim: int, + loaded_weight: Tensor, + tp_rank: int, + expert_id: int, + load_full: bool = False): + is_param_transpose = param.is_transposed \ + if hasattr(param, "is_transposed") else False + + # Index the loaded weight for tp sharding. + # down_proj: "RowParallel" so tp sharding on input_dim + # Narrow parameter and load. + if not load_full: + if is_param_transpose: + shard_size = param.shape[-2] + else: + shard_size = param.shape[-1] + loaded_weight = split_loaded_weight(loaded_weight, shard_dim, + shard_size * tp_rank, + shard_size) + + if is_param_transpose: + loaded_weight = from_numpy(loaded_weight.swapaxes(-1, -2)) + else: + loaded_weight = from_numpy(loaded_weight) + param[expert_id] = loaded_weight + # w2, down_proj: Load into only logical weight of w2. 
+ else: + if is_param_transpose: + loaded_weight = from_numpy(loaded_weight.swapaxes(-1, -2)) + else: + loaded_weight = from_numpy(loaded_weight) + param.set_data(loaded_weight) + + def _load_single_value(self, param: Parameter, loaded_weight: Tensor, + expert_id: int): + is_param_transpose = param.is_transposed \ + if hasattr(param, "is_transposed") else False + loaded_weight = loaded_weight[:] + if is_param_transpose: + loaded_weight = from_numpy(loaded_weight.swapaxes(-1, -2)) + else: + loaded_weight = from_numpy(loaded_weight) + param[expert_id] = from_numpy(loaded_weight) + + def _load_g_idx(self, shard_id: str, param: Parameter, shard_dim: int, + loaded_weight: Tensor, tp_rank: int, expert_id: int): + + if shard_id == "w2": + self._load_w2(shard_dim=shard_dim, + loaded_weight=loaded_weight, + param=param, + expert_id=expert_id, + tp_rank=tp_rank) + else: + assert shard_id in ("w1", "w3") + is_param_transpose = param.is_transposed \ + if hasattr(param, "is_transposed") else False + loaded_weight = loaded_weight[:] + if is_param_transpose: + loaded_weight = from_numpy(loaded_weight.swapaxes(-1, -2)) + else: + loaded_weight = from_numpy(loaded_weight) + param[expert_id] = from_numpy(loaded_weight) + + def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int: + if self.expert_map is None: + return expert_id + return self.expert_map[expert_id].item() + + def _load_model_weight_or_group_weight_scale(self, + shard_dim: int, + param: Parameter, + shard_id: str, + loaded_weight: Tensor, + tp_rank: int, + expert_id: int, + load_full_w2: bool = False): + """ + Load grouped weight scales for group quantization or model weights + :param shard_dim: dimension to shard + :param expert_data: parameter for a particular expert + :param shard_id: either w1, w2, or w3 + :param loaded_weight: checkpoint weight to load into the param + :param tp_rank: tensor parallel rank + :param load_full_w2: whether or not the w2 loaded should be sharded. + """ + if shard_id == "w2": + # In the case where we have actorder/g_idx, we do not partition the + # w2 scales, as indicated by `load_full` argument, for all tp cases + self._load_w2(shard_dim=shard_dim, + loaded_weight=loaded_weight, + param=param, + tp_rank=tp_rank, + expert_id=expert_id, + load_full=load_full_w2) + elif shard_id in ("w1", "w3"): + self._load_w13(shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + param=param, + expert_id=expert_id, + tp_rank=tp_rank) + + def weight_loader(self, param: Parameter, loaded_weight: Tensor, + weight_name: str, shard_id: str, expert_id: int) -> None: + + expert_id = self._map_global_expert_id_to_local_expert_id(expert_id) + if expert_id == -1: + return + + if shard_id not in ("w1", "w2", "w3"): + raise ValueError(f"shard_id must be ['w1','w2','w3'] but " + f"got {shard_id}.") + + # Fetch the dim to shard the parameter/loaded weight + # based on the shard id. This will be whatever + # dimension intermediate_size_per_partition is used. + SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0} + + shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id] + + # TODO: full_load will slow down the loading process, + # Support it when need it in the future. 
+ + # Case g_idx + if "g_idx" in weight_name: + self._load_g_idx(shard_dim=0, + shard_id=shard_id, + loaded_weight=loaded_weight, + param=param, + tp_rank=self.tp_rank, + expert_id=expert_id) + return + + # Case weight_shape + if "weight_shape" in weight_name: + # only required by compressed-tensors + self._load_single_value(param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + return + + # Case model weights + if "weight" in weight_name: + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + param=param, + expert_id=expert_id, + tp_rank=self.tp_rank) + return + + @staticmethod + def select_experts(hidden_states: Tensor, + router_logits: Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[Tensor] = None, + indices_type=None): + + # DeekSeekv2 uses grouped_top_k + if use_grouped_topk: + assert topk_group is not None + assert num_expert_group is not None + topk_weights, topk_ids = grouped_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + if indices_type is not None: + topk_ids = topk_ids.to(dtype=indices_type) + elif custom_routing_function is None: + topk_weights, topk_ids = fused_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + indices_type=indices_type, + ) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize) + if indices_type is not None: + topk_ids = topk_ids.to(dtype=indices_type) + + return topk_weights, topk_ids + + def must_reduce_shared_expert_outputs(self) -> bool: + # If use tp moe, there is a delay all reduce ops of routed experts. + # Therefore, shared experts in tensor parallelism can perform + # all-reduce operations together with routing experts + return not (self.pure_tp or self.tp_ep) + + def maybe_all_reduce_tensor_model_parallel(self, + final_hidden_states: Tensor): + """ + To all_reduce after routed expert and shared expert are added. 
+ """ + # Do delay allreduce If "must_reduce_shared_expert_outputs" return True + if self.pure_tp or self.tp_ep: + return self.all_reduce_from_tp_group(final_hidden_states) + return final_hidden_states + + def construct(self, + hidden_states: Tensor, + router_logits: Tensor, + dp_pad_index=None, + dp_unpad_index=None, + dp_pad_index_with_offset=None, + dp_unpad_index_total_with_offset=None): + if self.use_all2all_kernels: + return self.forward_impl_chunked(hidden_states, router_logits) + + return self.forward_impl(hidden_states, router_logits, dp_pad_index, + dp_unpad_index, dp_pad_index_with_offset, + dp_unpad_index_total_with_offset) + + def forward_impl(self, hidden_states: Tensor, router_logits: Tensor, + dp_pad_index, dp_unpad_index, + dp_pad_index_total_with_offset, + dp_unpad_index_total_with_offset): + """ + If dp_world_size == 4, dp_rank == 1, + tokens_num across dp is [1, 3, 4, 2], then + dp_pad_index = [0, 1, 2, 0] + dp_unpad_index = [0, 1, 2] + dp_pad_index_total_with_offset = \ + [0, 0, 0, 0, 1, 2, 3, 0, 4, 5, 6, 0, 7, 8, 0, 0] + dp_unpad_index_total_with_offset = \ + [0, 4, 5, 6, 8, 9, 10, 11, 12, 13] + """ + if (self.pure_tp or self.tp_ep) and self.dp_size > 1: + # TODO: replace AllGather with AllGatherV to eliminate padding + # ops.AllGather is not supported for uneven size tensor, + # so need to pad to same size. + hidden_buffer = mint.index_select(hidden_states, 0, dp_pad_index) + hidden_buffer = self.all_gather_from_dp_group(hidden_buffer) + + logit_buffer = mint.index_select(router_logits, 0, dp_pad_index) + logit_buffer = self.all_gather_from_dp_group(logit_buffer) + + hidden_states = mint.index_select( + hidden_buffer, 0, dp_unpad_index_total_with_offset) + router_logits = mint.index_select( + logit_buffer, 0, dp_unpad_index_total_with_offset) + + # Matrix multiply. + final_hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + global_num_experts=self.global_num_experts, + expert_map=self.expert_map, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + custom_routing_function=self.custom_routing_function, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias, + activation=self.activation, + apply_router_weight_on_input=self.apply_router_weight_on_input, + ) + + if (self.pure_tp or self.tp_ep) and self.dp_size > 1: + final_hidden_states = mint.index_select( + final_hidden_states, 0, dp_pad_index_total_with_offset) + final_hidden_states = final_hidden_states.reshape( + self.dp_size, -1, final_hidden_states.shape[-1]) + final_hidden_states = final_hidden_states.reshape( + -1, final_hidden_states.shape[-1]) + # TODO: replace ReudceScatter with ReduceScatterV + # to eliminate padding + final_hidden_states = self.reduce_scatter_from_dp_group( + final_hidden_states) + final_hidden_states = mint.index_select(final_hidden_states, 0, + dp_unpad_index) + + if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): + # Default set to False. (May have to add shared expert outputs.) + final_hidden_states = self.maybe_all_reduce_tensor_model_parallel( + final_hidden_states) + + return final_hidden_states + + def forward_impl_chunked(self, full_hidden_states: Tensor, + full_router_logits: Tensor): + # TODO: to implement chunked forward for FusedMoE. + # Chunked forward can solve the batch size limitation + # of the dispatch-combine kernel. 
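+        # For now the full batch is handled in a single quant_method.apply
+        # call below.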
+ + hidden_states = full_hidden_states + router_logits = full_router_logits + + # Matrix multiply. + final_hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + global_num_experts=self.global_num_experts, + expert_map=self.expert_map, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + custom_routing_function=self.custom_routing_function, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias, + activation=self.activation, + apply_router_weight_on_input=self.apply_router_weight_on_input, + ) + + return final_hidden_states + + @classmethod + def make_expert_params_mapping( + cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + num_experts: int) -> list[tuple[str, str, int, str]]: + + return [ + # the format is (param_name, weight_name, expert_id, shard_id) + ("experts.w13_" if weight_name + in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", + f"experts.{expert_id}.{weight_name}.", expert_id, shard_id) + for expert_id in range(num_experts) for shard_id, weight_name in [ + ("w1", ckpt_gate_proj_name), + ("w2", ckpt_down_proj_name), + ("w3", ckpt_up_proj_name), + ] + ] + + def extra_repr(self) -> str: + + s = ( + f"global_num_experts={self.global_num_experts}, " + f"local_num_experts={self.local_num_experts}, " + f"top_k={self.top_k}, " + f"intermediate_size_per_partition={self.intermediate_size_per_partition}, " # noqa: E501 + f"tp_size={self.tp_size},\n" + f"ep_size={self.ep_size}, " + f"reduce_results={self.reduce_results}, " + f"renormalize={self.renormalize}, " + f"use_grouped_topk={self.use_grouped_topk}") + + if self.use_grouped_topk: + s += f", num_expert_group={self.num_expert_group}, topk_group={self.topk_group}" # noqa: E501 + + s += f", scoring_func='{self.scoring_func}', activation='{self.activation}'" # noqa: E501 + + return s diff --git a/vllm_mindspore/model_executor/layers/linear.py b/vllm_mindspore/model_executor/layers/linear.py index 76518feaf84f072a916b1e5b7f08472dfa16211b..f1ef6a8e29c8a80cbfc61c2f8f8e907e8ccf93d8 100644 --- a/vllm_mindspore/model_executor/layers/linear.py +++ b/vllm_mindspore/model_executor/layers/linear.py @@ -162,6 +162,108 @@ class LinearBase(nn.Cell): raise NotImplementedError +class ReplicatedLinear(LinearBase): + """Replicated linear layer. + + Args: + input_size: input dimension of the linear layer. + output_size: output dimension of the linear layer. + bias: If true, add bias. + skip_bias_add: If true, skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + return_bias: If true, return bias together with outputs in forward pass. + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype=None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + optim_tp_ep_gating_perf: bool = False, + expert_start_index=None, + expert_end_index=None, + ): + super().__init__(input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix=prefix, + return_bias=return_bias) + + # All the linear layer supports quant method. 
+ assert self.quant_method is not None + self.quant_method.create_weights(self, + self.input_size, [self.output_size], + self.input_size, + self.output_size, + self.params_dtype, + weight_loader=self.weight_loader) + + if bias: + self.bias = Parameter( + mint.empty(self.output_size, dtype=self.params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.bias = None + + self.optim_tp_ep_gating_perf = optim_tp_ep_gating_perf + self.expert_start_index = expert_start_index + self.expert_end_index = expert_end_index + + def weight_loader(self, param: Parameter, loaded_weight: Tensor): + loaded_weight = loaded_weight[:] + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param.shape == loaded_weight.shape, ( + f"Tried to load weights of size {loaded_weight.size()}" + f"to a parameter of size {param.size()}") + + if self.optim_tp_ep_gating_perf: + if self.expert_start_index is None \ + or self.expert_end_index is None: + raise ValueError( + "If setting optim_tp_ep_gating_perf, expert_start_index " + "and expert_end_index must be set too.") + rearange_weight = [ + loaded_weight[self.expert_start_index:self.expert_end_index], + loaded_weight[:self.expert_start_index], + loaded_weight[self.expert_end_index:] + ] + loaded_weight = np.concatenate(rearange_weight, axis=0) + + param.set_data(ms.from_numpy(loaded_weight)) + + def construct( + self, + x: Tensor) -> Union[Tensor, tuple[Tensor, Optional[Parameter]]]: + bias = self.bias if not self.skip_bias_add else None + output = self.quant_method.apply(self, x, bias) + output_bias = self.bias if self.skip_bias_add else None + if not self.return_bias: + return output + return output, output_bias + + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", output_features={self.output_size}" + s += f", bias={self.bias is not None}" + return s + + class ColumnParallelLinear(LinearBase): """Linear layer with column parallelism. 
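For reference, a minimal sketch of the gating-weight rearrangement performed by ReplicatedLinear.weight_loader when optim_tp_ep_gating_perf is set (the sizes and values below are made up for illustration):

    import numpy as np

    # Hypothetical router/gate weight: one row per expert (8 experts, hidden size 4).
    loaded_weight = np.arange(8 * 4, dtype=np.float32).reshape(8, 4)
    expert_start_index, expert_end_index = 2, 4

    # Same reordering as in weight_loader above: rows for this rank's local
    # experts (2..3) move to the front; the remaining rows keep their order.
    rearranged = np.concatenate([
        loaded_weight[expert_start_index:expert_end_index],
        loaded_weight[:expert_start_index],
        loaded_weight[expert_end_index:],
    ], axis=0)

    # Local expert 0 now sits in row 0 of the gate weight.
    assert (rearranged[0] == loaded_weight[expert_start_index]).all()

With the router logits reordered this way, run_tp_ep_moe can treat the top-k ids directly as local expert ids, whereas the non-optimized path subtracts expert_start_index first.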
diff --git a/vllm_mindspore/model_executor/layers/rotary_embedding.py b/vllm_mindspore/model_executor/layers/rotary_embedding.py index 7bd3a43f35dda7f1e52de45dde04ec4c7bd4a3a9..ca071b461a7e5ba43327cb2b5d36b4b5c3676a57 100644 --- a/vllm_mindspore/model_executor/layers/rotary_embedding.py +++ b/vllm_mindspore/model_executor/layers/rotary_embedding.py @@ -77,6 +77,7 @@ class RotaryEmbedding(nn.Cell): base: float, is_neox_style: bool, dtype, + partial_rotary_factor: float = 1.0, ) -> None: super().__init__() self.head_size = head_size @@ -152,6 +153,7 @@ class InferRotaryEmbedding(nn.Cell): base: float, is_neox_style: bool, dtype, + partial_rotary_factor: float = 1.0, ) -> None: if not is_neox_style: raise NotImplementedError("InferRotaryEmbedding only support" @@ -164,6 +166,7 @@ class InferRotaryEmbedding(nn.Cell): self.base = base self.is_neox_style = is_neox_style self.dtype = dtype + self.partial_rotary_factor = partial_rotary_factor self.freqs_cos, self.freqs_sin = self._compute_cos_sin_cache() def _compute_inv_freq(self, base: Union[int, float]) -> Tensor: @@ -199,11 +202,29 @@ class InferRotaryEmbedding(nn.Cell): query = query.contiguous() key = key.contiguous() if get_model_context("is_prefill"): - return self.rotary_embedding_op(query, key, self.freqs_cos, - self.freqs_sin, batch_valid_length) + freqs_cos, freqs_sin = self.freqs_cos, self.freqs_sin + else: + freqs_cos = mint.index_select(self.freqs_cos, 0, positions) + freqs_sin = mint.index_select(self.freqs_sin, 0, positions) + if self.partial_rotary_factor < 1.0: + bs, _ = query.shape + query = query.reshape((bs, -1, self.head_size)) + key = key.reshape((bs, -1, self.head_size)) + q_rot, q_pass = query[..., :self.rotary_dim], query[..., self.rotary_dim:] + k_rot, k_pass = key[..., :self.rotary_dim], key[..., self.rotary_dim:] + q_rot = q_rot.reshape((bs, -1)) + k_rot = k_rot.reshape((bs, -1)) + q_rot = q_rot.contiguous() + k_rot = k_rot.contiguous() + q_rot, k_rot = self.rotary_embedding_op(q_rot, k_rot, freqs_cos, freqs_sin, batch_valid_length) + q_rot = q_rot.reshape((bs, -1, self.rotary_dim)) + k_rot = k_rot.reshape((bs, -1, self.rotary_dim)) + query = mint.cat((q_rot, q_pass), -1) + key = mint.cat((k_rot, k_pass), -1) + query = query.reshape((bs, -1)) + key = key.reshape((bs, -1)) + return query, key - freqs_cos = mint.index_select(self.freqs_cos, 0, positions) - freqs_sin = mint.index_select(self.freqs_sin, 0, positions) return self.rotary_embedding_op(query, key, freqs_cos, freqs_sin, batch_valid_length) @@ -757,6 +778,7 @@ def get_rope( base, is_neox_style, dtype, + partial_rotary_factor=partial_rotary_factor, ) else: scaling_type = rope_scaling["rope_type"] diff --git a/vllm_mindspore/model_executor/models/interfaces.py b/vllm_mindspore/model_executor/models/interfaces.py index ff5a5e368632dadc4367a0782b0faa3e147a6bdb..b88be6ab20f306232355a2d7d59c227fb0a433ad 100644 --- a/vllm_mindspore/model_executor/models/interfaces.py +++ b/vllm_mindspore/model_executor/models/interfaces.py @@ -108,3 +108,16 @@ class _SupportsLoRAType(Protocol): supported_lora_modules: list[str] embedding_modules: dict[str, str] embedding_padding_modules: list[str] + + +@runtime_checkable +class SupportesMoeDpTp(Protocol): + """ + The interface required for MoE models that MoE layer + supports TP parallel under DP Context. 
+ """ + support_moe_dp_tp = True + + +def supports_moe_dp_tp(model): + return getattr(model, "support_moe_dp_tp", False) diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index 524e31cec9cc2f35d1827ff5f94b795b60598c86..d3c20eb27d9a0b8fb9fa8d345af0fd39056d7287 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -26,12 +26,14 @@ from mindspore import Tensor, mutable, nn from mindspore.common import dtype as mstype from vllm.attention.backends.abstract import AttentionType from vllm.config import VllmConfig, get_current_vllm_config +from vllm.distributed import get_dp_group, get_ep_group from vllm.forward_context import get_forward_context from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm_mindspore.model_executor.models.attention_mask import ( LowerTriangularMask) +from vllm_mindspore.model_executor.models.interfaces import supports_moe_dp_tp from vllm_mindspore.model_executor.models.utils import is_use_ringmla from vllm_mindspore.model_executor.utils import set_model_context from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE, create_kv_cache @@ -384,6 +386,25 @@ class NativeModel(MsModelBase): self.prefill_graph = None self.decode_graph = None + self.supports_moe_dp_tp = supports_moe_dp_tp(self) + if self.supports_moe_dp_tp: + # some configure for MOE + custom_ep_size = \ + vllm_config.additional_config.get("expert_parallel", 1) \ + if vllm_config.additional_config is not None else 1 + if get_dp_group().world_size > 1 and \ + (not self.parallel_config.enable_expert_parallel or + get_ep_group().world_size // int(custom_ep_size) > 1): + # If moe running tensor parallel under data parallel context, + # the inputs need to be pad for moe allgather. 
+ self.moe_dp_pad = True + self.dp_group = get_dp_group().device_group._name + self.dp_cpu_group = get_dp_group().cpu_group._name + self.dp_world_size = get_dp_group().world_size + self.dp_rank = get_dp_group().rank_in_group + else: + self.moe_dp_pad = False + @property def ready_model(self) -> nn.Cell: if self.model is None: @@ -468,17 +489,85 @@ class NativeModel(MsModelBase): dyn_batch_valid_length = Tensor(shape=[None], dtype=mstype.int32) dyn_q_seq_lens = Tensor(shape=[None], dtype=mstype.int32) dyn_block_tables = Tensor(shape=[None, None], dtype=mstype.int32) - self.ready_model.set_inputs(dyn_input_ids, dyn_position_ids, - dyn_key_caches, dyn_value_caches, - dyn_slot_mapping, dynamic_attention_mask, - dyn_batch_valid_length, dyn_q_seq_lens, - dyn_block_tables, dyn_intermediate_tensors, - dyn_inputs_embeds) + + if self.supports_moe_dp_tp: + dyn_dp_pad_index = (Tensor(shape=[None], dtype=mstype.int32) + if self.moe_dp_pad else None) + dyn_dp_unpad_index = (Tensor(shape=[None], dtype=mstype.int32) + if self.moe_dp_pad else None) + dyn_dp_pad_index_with_offset = (Tensor( + shape=[None], dtype=mstype.int32) if self.moe_dp_pad else None) + dyn_dp_unpad_index_total_with_offset = (Tensor( + shape=[None], dtype=mstype.int32) if self.moe_dp_pad else None) + self.ready_model.set_inputs( + dyn_input_ids, dyn_position_ids, dyn_key_caches, + dyn_value_caches, dyn_slot_mapping, dynamic_attention_mask, + dyn_batch_valid_length, dyn_q_seq_lens, dyn_block_tables, + dyn_intermediate_tensors, dyn_inputs_embeds, dyn_dp_pad_index, + dyn_dp_unpad_index, dyn_dp_pad_index_with_offset, + dyn_dp_unpad_index_total_with_offset) + else: + self.ready_model.set_inputs( + dyn_input_ids, dyn_position_ids, dyn_key_caches, + dyn_value_caches, dyn_slot_mapping, dynamic_attention_mask, + dyn_batch_valid_length, dyn_q_seq_lens, dyn_block_tables, + dyn_intermediate_tensors, dyn_inputs_embeds) dynamic_hidden_states = Tensor(shape=[None, None], dtype=self.model_config.dtype) self.ready_lm_head.set_inputs(dynamic_hidden_states) + def prepare_moe_tp_ep_inputs(self): + """ + prepare moe tp + ep padding indices for models that support moe tp + ep. 
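+
+        Returns a tuple (dp_unpad_index, dp_pad_index,
+        dp_pad_index_total_with_offset, dp_unpad_index_total_with_offset)
+        of int32 tensors, or all None when no DP padding is needed.
+
+        Worked example (illustrative values): with dp_world_size=2,
+        cu_tokens_across_dp_cpu=[3, 8] (3 and 5 tokens per rank) and
+        max_tokens_across_dp_cpu=5, rank 0 gets
+        dp_unpad_index=[0, 1, 2], dp_pad_index=[0, 1, 2, 0, 0],
+        dp_pad_index_total_with_offset=[0, 1, 2, 0, 0, 3, 4, 5, 6, 7] and
+        dp_unpad_index_total_with_offset=[0, 1, 2, 5, 6, 7, 8, 9].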
+ """ + if self.moe_dp_pad: + dp_meta = get_forward_context().dp_metadata + token_num_total_cumsum = dp_meta.cu_tokens_across_dp_cpu + max_token_num = dp_meta.max_tokens_across_dp_cpu + token_num_total_cumsum = token_num_total_cumsum.numpy() + max_token_num = max_token_num.numpy() + + token_num_total = np.diff(token_num_total_cumsum, prepend=0) + total_pad_num = max_token_num - token_num_total + this_pad_num = total_pad_num[self.dp_rank] + + dp_unpad_index = np.arange(token_num_total[self.dp_rank]) + dp_pad_index = np.pad(dp_unpad_index, (0, this_pad_num)) + + dp_pad_index_total_with_offset = [ + np.pad( + np.arange( + 0 if rank == 0 else token_num_total_cumsum[rank - 1], + token_num_total_cumsum[rank]), + (0, total_pad_num[rank])) + for rank in range(self.dp_world_size) + ] + dp_pad_index_total_with_offset = \ + np.concatenate(dp_pad_index_total_with_offset, axis=0) + + dp_unpad_index_total_with_offset = [ + np.arange(token_num_total[rank]) + rank * max_token_num + for rank in range(self.dp_world_size) + ] + dp_unpad_index_total_with_offset = \ + np.concatenate(dp_unpad_index_total_with_offset, axis=0) + + dp_unpad_index = ms.from_numpy(dp_unpad_index.astype(np.int32)) + dp_pad_index = ms.from_numpy(dp_pad_index.astype(np.int32)) + dp_pad_index_total_with_offset = ms.from_numpy( + dp_pad_index_total_with_offset.astype(np.int32)) + dp_unpad_index_total_with_offset = ms.from_numpy( + dp_unpad_index_total_with_offset.astype(np.int32)) + else: + dp_unpad_index = None + dp_pad_index = None + dp_pad_index_total_with_offset = None + dp_unpad_index_total_with_offset = None + + return (dp_unpad_index, dp_pad_index, dp_pad_index_total_with_offset, + dp_unpad_index_total_with_offset) + def prepare_inputs(self, input_ids, positions, intermediate_tensors, inputs_embeds): model_inputs, is_prefill = self.prepare_base_inputs( @@ -499,6 +588,16 @@ class NativeModel(MsModelBase): new_model_inputs["intermediate_tensors"] = intermediate_tensors new_model_inputs["inputs_embeds"] = inputs_embeds + if getattr(self, "supports_moe_dp_tp", False): + dp_unpad_index, dp_pad_index, dp_pad_index_total_with_offset, \ + dp_unpad_index_total_with_offset = self.prepare_moe_tp_ep_inputs() + new_model_inputs["dp_unpad_index"] = dp_unpad_index + new_model_inputs["dp_pad_index"] = dp_pad_index + new_model_inputs["dp_pad_index_total_with_offset"] = \ + dp_pad_index_total_with_offset + new_model_inputs["dp_unpad_index_total_with_offset"] = \ + dp_unpad_index_total_with_offset + return new_model_inputs, is_prefill def exec_model(self, @@ -510,7 +609,9 @@ class NativeModel(MsModelBase): model_inputs, is_prefill = self.prepare_inputs(input_ids, positions, intermediate_tensors, inputs_embeds) - + if kwargs.get('cache_params', None) is not None: + # used for qwen3-next + model_inputs['cache_params'] = kwargs.get('cache_params') # for dummy_attention_metadata if is_prefill and not self.set_flags: self.set_flags = True diff --git a/vllm_mindspore/model_executor/models/qwen3_moe.py b/vllm_mindspore/model_executor/models/qwen3_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..ce469d27670207b5c653d42b802b8fc7232e44c9 --- /dev/null +++ b/vllm_mindspore/model_executor/models/qwen3_moe.py @@ -0,0 +1,583 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen3_moe.py +# +# Copyright 2025 Huawei Technologites Co., Ltd +# Copyright 2024 The Qwen team. 
+# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Qwen3MoE model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from typing import Any, Optional, Union + +from mindspore import Parameter, Tensor, mint, nn +from transformers import PretrainedConfig +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from vllm_mindspore.attention import Attention +from vllm_mindspore.model_executor.layers.activation import SiluAndMul +from vllm_mindspore.model_executor.layers.fused_moe import FusedMoE +from vllm_mindspore.model_executor.layers.layernorm import RMSNorm +from vllm_mindspore.model_executor.layers.linear import ( + MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, + RowParallelLinear) +from vllm_mindspore.model_executor.layers.logits_processor import ( + LogitsProcessor) +from vllm_mindspore.model_executor.layers.rotary_embedding import get_rope +from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm_mindspore.model_executor.model_loader.weight_utils import ( + default_weight_loader) +from vllm_mindspore.model_executor.models.interfaces import SupportesMoeDpTp +from vllm_mindspore.model_executor.models.model_base import NativeModel +from vllm_mindspore.model_executor.models.utils import ( + extract_layer_index, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) + +logger = init_logger(__name__) + + +class Qwen3MoeMLP(nn.Cell): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def construct(self, x, dp_pad_index=None,dp_unpad_index=None, + dp_pad_index_with_offset=None, dp_unpad_index_total_with_offset=None): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Qwen3MoeSparseMoeBlock(nn.Cell): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}.") + + self.experts = FusedMoE(num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + optim_tp_ep_gating_perf=True) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + optim_tp_ep_gating_perf=self.experts.optim_tp_ep_gating_perf, + expert_start_index=self.experts.expert_start_index, + expert_end_index=self.experts.expert_end_index, + ) + if config.shared_expert_intermediate_size > 0: + self.shared_expert = Qwen3MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.shared_expert_intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=self.experts.must_reduce_shared_expert_outputs( + ), + prefix=f"{prefix}.shared_expert", + ) + self.shared_expert_gate = ReplicatedLinear(config.hidden_size, + 1, + bias=False, + prefix=f"{prefix}.shared_expert_gate") + else: + self.shared_expert = None + + def construct(self, hidden_states: Tensor, dp_pad_index, dp_unpad_index, + dp_pad_index_with_offset, + dp_unpad_index_total_with_offset) -> Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. 
+ orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + + if self.shared_expert is not None: + shared_output = self.shared_expert(hidden_states) + logits, _ = self.shared_expert_gate(hidden_states) + shared_output = mint.sigmoid(logits) * shared_output + + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits, + dp_pad_index=dp_pad_index, + dp_unpad_index=dp_unpad_index, + dp_pad_index_with_offset=dp_pad_index_with_offset, + dp_unpad_index_total_with_offset=dp_unpad_index_total_with_offset) + + if self.shared_expert is not None: + final_hidden_states = final_hidden_states + shared_output + + if self.tp_size > 1: + final_hidden_states = \ + self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states) + + return final_hidden_states.view(orig_shape) + + +class Qwen3MoeAttention(nn.Cell): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.total_num_heads=config.num_attention_heads + self.total_num_kv_heads=config.num_key_value_heads + + tp_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
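+            # Each rank then holds a single replicated KV head:
+            # num_kv_heads = max(1, total_num_kv_heads // tp_size) below.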
+ assert tp_size % self.total_num_kv_heads == 0 + + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = getattr(config, 'head_dim', None) or \ + (self.hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.scaling = self.head_dim ** -0.5 + self.rope_theta = getattr(config, "rope_theta", 10000) + self.rope_scaling = getattr(config, "rope_scaling", None) + self.partial_rotary_factor = ( + getattr(config, "partial_rotary_factor", 1.0)) + self.max_position_embeddings = ( + getattr(config, "max_position_embeddings", 8192)) + self.qkv_bias=getattr(config, 'attention_bias', False) + + self.qkv_proj = QKVParallelLinear(self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=self.qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj") + + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=self.rope_scaling, + partial_rotary_factor=self.partial_rotary_factor, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def construct( + self, + positions: Tensor, + hidden_states: Tensor, + key_cache: Tensor, + value_cache: Tensor, + slot_mapping: Tensor, + attn_mask: Tensor, + batch_valid_length: Tensor, + q_seq_lens: Tensor, + block_tables: Tensor, + ) -> Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + # Add qk-norm + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, + self.head_dim) + q_by_head = self.q_norm(q_by_head) + q = q_by_head.view(q.shape) + + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, + self.head_dim) + k_by_head = self.k_norm(k_by_head) + k = k_by_head.view(k.shape) + q, k = self.rotary_emb(positions, q, k, batch_valid_length) + attn_output = self.attn(q, k, v, key_cache, value_cache, slot_mapping, + attn_mask, batch_valid_length, q_seq_lens, + block_tables) + output, _ = self.o_proj(attn_output) + return output + + +class Qwen3MoeDecoderLayer(nn.Cell): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.self_attn = Qwen3MoeAttention( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + # `mlp_only_layers` in the config. 
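+        # Layers listed in `mlp_only_layers`, and layers that do not fall on
+        # the `decoder_sparse_step` stride, use a dense Qwen3MoeMLP; the
+        # remaining layers (when config.num_experts > 0) use the sparse MoE
+        # block.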
+ layer_idx = extract_layer_index(prefix) + mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else + config.mlp_only_layers) + if (layer_idx not in mlp_only_layers) and ( + config.num_experts > 0 and + (layer_idx + 1) % config.decoder_sparse_step == 0): + self.mlp = Qwen3MoeSparseMoeBlock(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def construct( + self, + positions: Tensor, + hidden_states: Tensor, + key_cache: Tensor, + value_cache: Tensor, + slot_mapping: Tensor, + attn_mask: Tensor, + batch_valid_length: Tensor, + q_seq_lens: Tensor, + block_tables: Tensor, + residual: Optional[Tensor], + dp_pad_index: Optional[bool] = None, + dp_unpad_index: Optional[Tensor] = None, + dp_pad_index_with_offset: Optional[Tensor] = None, + dp_unpad_index_total_with_offset: Optional[Tensor] = None, + ) -> Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn(positions, hidden_states, key_cache, + value_cache, slot_mapping, attn_mask, + batch_valid_length, q_seq_lens, + block_tables) + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states, dp_pad_index, dp_unpad_index, + dp_pad_index_with_offset, + dp_unpad_index_total_with_offset) + return hidden_states, residual + + +class Qwen3MoeModel(nn.Cell): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.config = config + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + prefix=f"{prefix}.embed_tokens") + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Qwen3MoeDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: Tensor) -> Tensor: + return self.embed_tokens(input_ids) + + def construct( + self, + input_ids: Tensor, + positions: Tensor, + key_caches: list[Tensor], + value_caches: list[Tensor], + slot_mapping: Tensor, + attn_mask: Tensor, + batch_valid_length: Tensor, + q_seq_lens: Tensor, + block_tables: Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[Tensor] = None, + dp_pad_index: Optional[Tensor] = None, + dp_unpad_index: Optional[Tensor] = None, + dp_pad_index_total_with_offset: Optional[Tensor] = None, + dp_unpad_index_total_with_offset: Optional[Tensor] = None, + ) -> Union[Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if 
inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, hidden_states, key_caches[i - self.start_layer], + value_caches[i - self.start_layer], slot_mapping, attn_mask, + batch_valid_length, q_seq_lens, block_tables, residual, + dp_pad_index, dp_unpad_index, dp_pad_index_total_with_offset, + dp_unpad_index_total_with_offset) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, Tensor]], + params_dict: dict[str, Parameter]): + # return + stacked_params_mapping = [ + # the format is (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + # Params for weights, fp8 weight scales, fp8 activation scales + # the format is (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + + loaded_params: set[str] = set() + for name, loaded_weight in weights: + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + # Skip loading extra bias for GPTQ models. 
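+                    # The check below also skips expert weights that are not
+                    # materialised in this rank's params_dict.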
+ if ((name.endswith(".bias") or name.endswith("_bias")) + or name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Qwen3MoeForCausalLM(NativeModel, SupportesMoeDpTp): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = Qwen3MoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + self.common_preprocess(vllm_config, prefix) + + def get_input_embeddings(self, input_ids: Tensor) -> Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward(self, + input_ids: Tensor, + positions: Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[Tensor] = None, + **kwargs) -> Union[Tensor, IntermediateTensors]: + hidden_states = self.exec_model(input_ids, positions, + intermediate_tensors, inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, Tensor]]) -> set[str]: + params_dict = self.get_params_dict() + return self.model.load_weights(weights, params_dict) diff --git a/vllm_mindspore/model_executor/models/qwen3_next.py b/vllm_mindspore/model_executor/models/qwen3_next.py new file mode 100644 index 0000000000000000000000000000000000000000..a6240da889486a431b3cc4c85e00304cee604de6 --- /dev/null +++ b/vllm_mindspore/model_executor/models/qwen3_next.py @@ -0,0 +1,1063 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen3_moe.py +# +# Copyright 2025 Huawei Technologites Co., Ltd +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Qwen3MoE model compatible with HuggingFace weights.""" +import collections +import numpy as np +from collections.abc import Iterable +from itertools import repeat +from typing import Any, Optional, Union + +import mindspore as ms +from mindspore import nn, mint, ops, Tensor, Parameter +from mindspore import dtype as mstype +from mindspore.common.initializer import initializer, One +from transformers import PretrainedConfig +from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.models.interfaces import HasInnerState +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from vllm_mindspore.model_executor.layers.layernorm import RMSNorm +from vllm_mindspore.model_executor.layers.linear import ( + ColumnParallelLinear, RowParallelLinear, QKVParallelLinear) +from vllm_mindspore.model_executor.layers.logits_processor import ( + LogitsProcessor) +from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm_mindspore.model_executor.models.interfaces import SupportesMoeDpTp +from vllm_mindspore.model_executor.models.model_base import NativeModel +from vllm_mindspore.model_executor.models.mf_models.qwen3_next_cache import \ + Qwen3NextCacheManager, Qwen3NextCacheParams +from vllm_mindspore.model_executor.models.utils import ( + extract_layer_index, make_empty_intermediate_tensors_factory, + make_layers, maybe_prefix) +from vllm_mindspore.model_executor.models.qwen3_moe import Qwen3MoeAttention, \ + Qwen3MoeSparseMoeBlock, Qwen3MoeMLP, Qwen3MoeModel +from vllm_mindspore.model_executor.utils import get_model_context + +from vllm_mindspore.model_executor.models.triton_kernels import fused_gdn_gating, \ + fused_sigmoid_gating_delta_rule_update_npu, fused_recurrent_gated_delta_rule_fwd + +logger = init_logger(__name__) + + +FusedRMSNormGated = None +causal_conv1d_fn = None +causal_conv1d_update = None +chunk_gated_delta_rule = None +fused_recurrent_gated_delta_rule = None + +class Conv1d(nn.Cell): + r"""Applies a 1D convolution over an input signal composed of several input + planes. + + In the simplest case, the output value of the layer with input size + :math:`(N, C_{\text{in}}, L)` and output :math:`(N, C_{\text{out}}, L_{\text{out}})` can be + precisely described as: + + .. math:: + \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + + \sum_{k = 0}^{C_{in} - 1} \text{weight}(C_{\text{out}_j}, k) + \star \text{input}(N_i, k) + + where :math:`\star` is the valid `cross-correlation`_ operator, + :math:`N` is a batch size, :math:`C` denotes a number of channels, + :math:`L` is a length of signal sequence. 
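+
+    This MindSpore port implements the grouped 1D convolution with
+    ``ops.Conv2D`` by inserting a singleton spatial dimension in
+    ``construct`` (the ``expand_dims(2)`` calls on the input and weight).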
+ """ + __constants__ = ['stride', 'padding', 'dilation', 'groups', + 'padding_mode', 'output_padding', 'in_channels', + 'out_channels', 'kernel_size'] + __annotations__ = {'bias': Optional[Tensor]} + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode='zeros', + dtype=None + ) -> None: + super().__init__() + def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + _single = _ntuple(1) + + def _reverse_repeat_tuple(t, n): + r"""Reverse the order of `t` and repeat each element for `n` times. + + This can be used to translate padding arg used by Conv and Pooling modules + to the ones used by `F.pad`. + """ + return tuple(x for x in reversed(t) for _ in range(n)) + + kernel_size = _single(kernel_size) + stride = _single(stride) + padding = padding if isinstance(padding, str) else _single(padding) + dilation = _single(dilation) + output_padding = _single(0) + if groups <= 0: + raise ValueError('groups must be a positive integer') + if in_channels % groups != 0: + raise ValueError('in_channels must be divisible by groups') + if out_channels % groups != 0: + raise ValueError('out_channels must be divisible by groups') + valid_padding_strings = {'same', 'valid'} + if isinstance(padding, str): + if padding not in valid_padding_strings: + raise ValueError( + f"Invalid padding string {padding!r}, should be one of {valid_padding_strings}") + if padding == 'same' and any(s != 1 for s in stride): + raise ValueError("padding='same' is not supported for strided convolutions") + + valid_padding_modes = {'zeros', 'reflect', 'replicate', 'circular'} + if padding_mode not in valid_padding_modes: + raise ValueError(f"padding_mode must be one of {valid_padding_modes}, but got padding_mode='{padding_mode}'") + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.output_padding = output_padding + self.groups = groups + self.padding_mode = padding_mode + if isinstance(self.padding, str): + self._reversed_padding_repeated_twice = [0, 0] * len(kernel_size) + if padding == 'same': + for d, k, i in zip(dilation, kernel_size, + range(len(kernel_size) - 1, -1, -1)): + total_padding = d * (k - 1) + left_pad = total_padding // 2 + self._reversed_padding_repeated_twice[2 * i] = left_pad + self._reversed_padding_repeated_twice[2 * i + 1] = ( + total_padding - left_pad) + else: + self._reversed_padding_repeated_twice = _reverse_repeat_tuple(self.padding, 2) + self.weight = Parameter(mint.empty( + (out_channels, in_channels // groups, *kernel_size), dtype=dtype)) + + self.bias = None + if bias: + self.bias = Parameter(mint.empty(out_channels, dtype=dtype)) + + pad_mode = 'valid' + pad = padding + if isinstance(padding, tuple): + if padding[0] != 0: + pad_mode = 'pad' + pad = (0, 0, padding[0], padding[0]) + elif isinstance(padding, int): + if padding != 0: + pad_mode = 'pad' + pad = (0, 0) + (padding,) * 2 + if not isinstance(padding, (int, tuple)): + pad_mode = padding + pad = (0,) * 4 + + if self.padding_mode != 'zeros': + pad_mode = 'valid' + pad = (0,) * 4 + self.conv2d = ops.Conv2D(out_channel=self.out_channels, + kernel_size=(1,) + self.kernel_size, + mode=1, + pad_mode=pad_mode, + pad=pad, + stride=(1,) + self.stride, + dilation=(1,) + self.dilation, + group=self.groups) + + def construct(self, input): + if 
self.padding_mode != 'zeros': + input = mint.nn.functional.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode) + input = input.expand_dims(2) + output = self.conv2d(input, self.weight.expand_dims(2)) + + if self.bias is not None: + output = ops.bias_add(output, self.bias) + + output = output.squeeze(2) + return output + +class Qwen3NextRMSNorm(RMSNorm): + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + ): + super().__init__( + hidden_size=hidden_size, + eps=eps, + ) + self.cast = ops.Cast() + + def construct(self, x, residual: Optional[Tensor] = None): + """Forward of RMSNormGated.""" + original_type = x.dtype + x = self.cast(x, mstype.float32) + weight = 1.0 + self.weight.astype(mstype.float32) + if residual is not None: + residual = self.cast(residual, mstype.float32) + output, _, residual = self.add_rms_norm(x, residual, weight, self.eps) + return self.cast(output, original_type), self.cast(residual, original_type) + output = self.rms_norm(x, weight)[0] + return self.cast(output, original_type) + +class RMSNormGated(RMSNorm): + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + ): + super().__init__( + hidden_size=hidden_size, + eps=eps, + ) + + def construct(self, x, gate=None): + """Forward of RMSNormGated.""" + output = self.rms_norm(x, self.weight)[0] + output = output * mint.nn.functional.silu(gate) + return output + +def mint_causal_conv1d_update( + hidden_states, + conv_state, + weight, + bias=None, + activation=None, +): + """ + MindSpore mint implementation of mint_causal_conv1d_update function. + """ + _, hidden_size, seq_len = hidden_states.shape + state_len = conv_state.shape[-1] + + hidden_states_new = mint.cat([conv_state, hidden_states], dim=-1).to(weight.dtype) + # ops.assign(conv_state, hidden_states_new[:, :, -state_len:]) + conv_state = hidden_states_new[:, :, -state_len:] + # Use MindSpore's ops.conv1d function + # weight needs to be expanded for grouped convolution + weight_expanded = weight.unsqueeze(1) # Shape: (hidden_size, 1, kernel_size) + + out = ops.conv1d(hidden_states_new, weight_expanded, bias, groups=hidden_size, padding=0, stride=1) + + # Apply SiLU activation and extract last seq_len outputs + out = mint.nn.functional.silu(out[:, :, -seq_len:]) + out = out.to(hidden_states.dtype) + return out, conv_state + + +def mint_chunk_gated_delta_rule( + query, + key, + value, + g, + beta, + chunk_size=64, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False, +): + """ + MindSpore mint implementation of torch_chunk_gated_delta_rule function. 
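+
+    Used on the prefill path of the gated delta rule: the sequence is split
+    into ``chunk_size`` blocks, attention is computed inside each block and
+    a recurrent state is carried across blocks. ``g`` is the log decay and
+    ``beta`` the update gate; ``initial_state`` / ``output_final_state``
+    carry the recurrent state in and out of the function.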
+ """ + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = mint.nn.functional.normalize(query, p=2.0, dim=-1) + key = mint.nn.functional.normalize(key, p=2.0, dim=-1) + + # Convert tensors to float32 and transpose + query = mint.transpose(query, 1, 2).to(ms.float32) + key = mint.transpose(key, 1, 2).to(ms.float32) + value = mint.transpose(value, 1, 2).to(ms.float32) + beta = mint.transpose(beta, 1, 2).to(ms.float32) + g = mint.transpose(g, 1, 2).to(ms.float32) + + batch_size, sequence_length, num_heads, k_head_dim = key.shape + v_head_dim = value.shape[-1] + pad_size = (chunk_size - num_heads % chunk_size) % chunk_size + + # Padding operations - MindSpore pad format: flattened tuple for all dimensions + # For query/key/value: pad num_heads dimension (dim=2), format: (dim0_left, dim0_right, dim1_left, dim1_right, dim2_left, dim2_right, dim3_left, dim3_right) + query = mint.nn.functional.pad(query, (0, 0, 0, pad_size)) + key = mint.nn.functional.pad(key, (0, 0, 0, pad_size)) + value = mint.nn.functional.pad(value, (0, 0, 0, pad_size)) + beta = mint.nn.functional.pad(beta, (0, pad_size)) + g = mint.nn.functional.pad(g, (0, pad_size)) + + tot_heads = num_heads + pad_size + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + # reshape to chunks + query = query.reshape(query.shape[0], query.shape[1], -1, chunk_size, query.shape[-1]) + key = key.reshape(key.shape[0], key.shape[1], -1, chunk_size, key.shape[-1]) + value = value.reshape(value.shape[0], value.shape[1], -1, chunk_size, value.shape[-1]) + k_beta = k_beta.reshape(k_beta.shape[0], k_beta.shape[1], -1, chunk_size, k_beta.shape[-1]) + v_beta = v_beta.reshape(v_beta.shape[0], v_beta.shape[1], -1, chunk_size, v_beta.shape[-1]) + g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) + + # Create upper triangular mask + mask = mint.triu(mint.ones((chunk_size, chunk_size)), diagonal=0).astype(ms.bool_) + + # Chunk decay + g = ops.cumsum(g, axis=-1) + g_diff = g.unsqueeze(-1) - g.unsqueeze(-2) + decay_mask = ops.exp(mint.tril(g_diff)).astype(ms.float32) + decay_mask = mint.tril(decay_mask) + + # Attention computation + key_T = mint.transpose(key, -1, -2) + attn_base = mint.matmul(k_beta, key_T) * decay_mask + attn = -attn_base.masked_fill(mask, 0) + + # Sequential attention update + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn_update = mint.sum(row.unsqueeze(-1) * sub, dim=-2) + attn[..., i, :i] = row + attn_update + + # Add identity matrix + eye = mint.eye(chunk_size, dtype=attn.dtype) + attn = attn + eye + + value = mint.matmul(attn, v_beta) + g_exp = ops.exp(g).unsqueeze(-1) + k_cumdecay = mint.matmul(attn, k_beta * g_exp) + + # Initialize recurrent state + if initial_state is None: + last_recurrent_state = mint.zeros((batch_size, sequence_length, k_head_dim, v_head_dim)) + last_recurrent_state = last_recurrent_state.to(value.dtype) + else: + last_recurrent_state = initial_state.astype(value.dtype) + + core_attn_out = mint.zeros_like(value) + mask = mint.triu(mint.ones((chunk_size, chunk_size)), diagonal=1).astype(ms.bool_) + + # Process each chunk + for i in range(0, tot_heads // chunk_size): + q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] + + # Attention within chunk + q_k_T = mint.matmul(q_i, mint.transpose(k_i, -1, -2)) + attn = (q_k_T * decay_mask[:, :, i]).masked_fill_(mask, 0) + + # Cross-chunk interaction + v_prime = mint.matmul(k_cumdecay[:, :, i], 
last_recurrent_state) + v_new = v_i - v_prime + + g_exp_chunk = ops.exp(g[:, :, i, :, None]) + attn_inter = mint.matmul(q_i * g_exp_chunk, last_recurrent_state) + core_attn_out[:, :, i] = attn_inter + mint.matmul(attn, v_new) + + # Update recurrent state + g_last_exp = ops.exp(g[:, :, i, -1, None, None]) + g_diff_exp = ops.exp(g[:, :, i, -1, None] - g[:, :, i]) + k_weighted = k_i * g_diff_exp[..., None] + k_weighted_T = mint.transpose(k_weighted, -1, -2) + state_update = mint.matmul(k_weighted_T, v_new) + last_recurrent_state = last_recurrent_state * g_last_exp + state_update + + if not output_final_state: + last_recurrent_state = None + + # Reshape and slice output + core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1]) + core_attn_out = core_attn_out[:, :, :num_heads] + core_attn_out = mint.transpose(core_attn_out, 1, 2).to(initial_dtype) + + return core_attn_out, last_recurrent_state + + +def mint_recurrent_gated_delta_rule( + query, + key, + value, + g, + beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False, +): + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = mint.nn.functional.normalize(query, p=2.0, dim=-1) + key = mint.nn.functional.normalize(key, p=2.0, dim=-1) + + # (b, s, h, d) -> (b, h, s, d)? Torch code performs transpose(1,2) but uses (b, s, h, d) afterwards. + # Their inputs to this function are (b, h, s, d) so transpose(1,2) yields (b, s, h, d). + # We mirror that behavior: always end up with (b, s, h, d) + query = (mint.transpose(query, 1, 2)).astype(ms.float32) + key = (mint.transpose(key, 1, 2)).astype(ms.float32) + value = (mint.transpose(value, 1, 2)).astype(ms.float32) + beta = (mint.transpose(beta, 1, 2)).astype(ms.float32) + g = (mint.transpose(g, 1, 2)).astype(ms.float32) + + batch_size, sequence_length, num_heads, k_head_dim = key.shape + v_head_dim = value.shape[-1] + + scale = 1.0 / (query.shape[-1] ** 0.5) + query = query * scale + + core_attn_out_list = [] # per-head outputs of shape (b, s, v) + if initial_state is None: + last_recurrent_state = mint.zeros((batch_size, sequence_length, k_head_dim, v_head_dim), dtype=value.dtype) + else: + last_recurrent_state = initial_state.astype(value.dtype) + + for i in range(num_heads): + q_t = query[:, :, i] + k_t = key[:, :, i] + v_t = value[:, :, i] + g_t = mint.unsqueeze(mint.unsqueeze(mint.exp(g[:, :, i]), -1), -1) + beta_t = mint.unsqueeze(beta[:, :, i], -1) + + last_recurrent_state = last_recurrent_state * g_t + kv_mem = mint.sum(last_recurrent_state * mint.unsqueeze(k_t, -1), dim=-2) + delta = (v_t - kv_mem) * beta_t + last_recurrent_state = last_recurrent_state + mint.unsqueeze(k_t, -1) * mint.unsqueeze(delta, -2) + out_t = mint.sum(last_recurrent_state * mint.unsqueeze(q_t, -1), dim=-2) + core_attn_out_list.append(out_t) + + core_attn_out = mint.stack(core_attn_out_list, dim=2) # (b, s, h, v) + core_attn_out = mint.transpose(core_attn_out, 1, 2) # match torch: (b, h, s, v) + + if not output_final_state: + last_recurrent_state = None + + # cast back to input dtype + core_attn_out = core_attn_out.astype(initial_dtype) + return core_attn_out, last_recurrent_state + + +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states 
* attention_mask[:, :, None]).to(dtype) + + return hidden_states + + +class GDNAttention(nn.Cell): + def __init__(self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "",): + super().__init__() + self.hidden_size = config.hidden_size + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + + self.conv_kernel_size = config.linear_conv_kernel_dim + self.activation = config.hidden_act + self.layer_norm_epsilon = config.rms_norm_eps + # QKV + self.conv_dim = self.key_dim * 2 + self.value_dim + + self.conv1d = Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=False, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + padding=self.conv_kernel_size - 1, + dtype=get_current_vllm_config().model_config.dtype + ) + + # projection of the input hidden states + projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 + projection_size_ba = self.num_v_heads * 2 + self.in_proj_qkvz = ColumnParallelLinear( + self.hidden_size, + projection_size_qkvz, + gather_output=True, + bias=False, + prefix=f"{prefix}.in_proj_qkvz", + return_bias=False, + ) + self.in_proj_ba = ColumnParallelLinear( + self.hidden_size, + projection_size_ba, + gather_output=True, + bias=False, + prefix=f"{prefix}.in_proj_ba", + return_bias=False, + ) + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = Parameter(initializer(One(), (self.num_v_heads,)), name='dt_bias') + + A = Tensor(np.random.uniform(0, 16, (self.num_v_heads,)), dtype=mstype.float32) + self.A_log = Parameter(ops.log(A), name='A_log') + + if FusedRMSNormGated is None: + self.norm = RMSNormGated(hidden_size=self.head_v_dim, eps=self.layer_norm_epsilon) + else: + # FusedRMSNormGated is available but not implemented in this branch + raise NotImplementedError("Support for FusedRMSNormGated is not implemented.") + + self.out_proj = RowParallelLinear(self.value_dim, + self.hidden_size, + input_is_parallel=False, + bias=False, + prefix=f"{prefix}.out_proj", + return_bias=False) + + self.causal_conv1d_fn = causal_conv1d_fn + self.causal_conv1d_update = causal_conv1d_update or mint_causal_conv1d_update + self.chunk_gated_delta_rule = chunk_gated_delta_rule or mint_chunk_gated_delta_rule + self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or mint_recurrent_gated_delta_rule + + def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): + """ + Derives `query`, `key` and `value` tensors from `mixed_qkvz` and `mixed_ba`. 
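+
+        Returns (query, key, value, z, b, a): ``z`` is the gating input of
+        the gated RMSNorm applied to the attention output, ``b`` becomes
+        ``beta`` via a sigmoid, and ``a`` feeds the decay term ``g``.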
+ """ + # qkvz_dim = (2 * self.key_dim + 2 * self.value_dim) # // self.tp.size + # ba_dim = 2 * self.num_v_heads # // self.tp.size + # + # mixed_qkvz, mixed_ba = mint.split(mixed_qkvzba, [qkvz_dim, ba_dim], dim=-1) + new_tensor_shape_qkvz = mixed_qkvz.shape[:-1] + ( + self.num_k_heads, # // self.tp.size, + 2 * self.head_k_dim + 2 * self.head_v_dim * self.num_v_heads // self.num_k_heads, + ) + new_tensor_shape_ba = mixed_ba.shape[:-1] + (self.num_k_heads, # // self.tp.size, + 2 * self.num_v_heads // self.num_k_heads) + + mixed_qkvz = mixed_qkvz.reshape(*new_tensor_shape_qkvz) + mixed_ba = mixed_ba.reshape(*new_tensor_shape_ba) + split_arg_list_qkvz = [ + self.head_k_dim, + self.head_k_dim, + (self.num_v_heads // self.num_k_heads * self.head_v_dim), + (self.num_v_heads // self.num_k_heads * self.head_v_dim), + ] + split_arg_list_ba = [self.num_v_heads // self.num_k_heads, self.num_v_heads // self.num_k_heads] + query, key, value, z = mint.split(mixed_qkvz, split_arg_list_qkvz, dim=3) + b, a = mint.split(mixed_ba, split_arg_list_ba, dim=3) + # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] + value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim) + z = z.reshape(z.shape[0], z.shape[1], -1, self.head_v_dim) + b = b.reshape(b.shape[0], b.shape[1], self.num_v_heads) # // self.tp.size) + a = a.reshape(a.shape[0], a.shape[1], self.num_v_heads) # // self.tp.size) + return query, key, value, z, b, a + + def pack_to_bsh(self, hidden_states, seq_lengths): + batch_size = seq_lengths.shape[0] + max_seq_len = seq_lengths.max() + output = mint.zeros((batch_size, max_seq_len, self.hidden_size), dtype=hidden_states.dtype) + + token_idx = 0 + for batch_idx in range(batch_size): + seq_len = seq_lengths[batch_idx] + output[batch_idx, -seq_len:, :] = hidden_states[token_idx:token_idx + seq_len, :] + token_idx += seq_len + return output + + def pack_to_th(self, core_attn_out, seq_lengths): + output_hiden_states = [] + batch_size = seq_lengths.shape[0] + for batch_idx in range(batch_size): + # left_padding to adapt conv_state cache. [seq_len, num_heads, hidden_size] + hiden_states = core_attn_out[batch_idx, -seq_lengths[batch_idx]:, :] # [seq_len, hidden_size] + output_hiden_states.append(hiden_states) + + core_attn_out_th = mint.cat(output_hiden_states, dim=0) # [total_tokens, heads, head_dim] + return core_attn_out_th + + def construct( + self, + hidden_states, + cache_params=None, + has_previous_state=True, + cache_position=None, + attention_mask=None, + batch_valid_length=None + ): + hidden_states = self.pack_to_bsh(hidden_states, batch_valid_length) \ + if get_model_context("is_prefill") else hidden_states.unsqueeze(1) + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + + # Set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + + use_precomputed_states = ( + cache_params is not None + and has_previous_state + and seq_len == 1 + ) + + projected_states_qkvz = self.in_proj_qkvz(hidden_states) + projected_states_ba = self.in_proj_ba(hidden_states) + query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba) + query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)) + + mixed_qkv = mint.cat((query, key, value), dim=-1) + mixed_qkv = mixed_qkv.transpose(1, 2) + + if use_precomputed_states: + # getting projected states from cache + conv_state = cache_params.conv_state[cache_params.state_indices_tensor].transpose(1, 2) + # 2. 
Convolution sequence transformation + # NOTE: the conv state is updated in `causal_conv1d_update` + mixed_qkv, conv_state = self.causal_conv1d_update( + mixed_qkv, + conv_state, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + else: + conv_state = mint.nn.functional.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) + if self.causal_conv1d_fn is not None: + mixed_qkv = self.causal_conv1d_fn( + x=mixed_qkv, + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + seq_idx=None, + ) + else: + mixed_qkv = ops.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) + + if cache_params is not None: + cache_params.conv_state[cache_params.state_indices_tensor] = conv_state.transpose(1, 2) + mixed_qkv = mixed_qkv.transpose(1, 2) + query, key, value = mint.split( + mixed_qkv, + [ + self.key_dim, + self.key_dim, + self.value_dim, + ], + dim=-1, + ) + query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim) + key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim) + value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim) + + beta = mint.sigmoid(b) + # If the model is loaded in fp16, without the .float() here, A might be -inf + # g = -self.A_log.float().exp() * mint.nn.functional.softplus(a + self.dt_bias) + g = fused_gdn_gating(self.A_log * 1.0, a, self.dt_bias * 1.0) + if self.num_v_heads // self.num_k_heads > 1: + query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + + if not use_precomputed_states: + # beta = mint.sigmoid(b) + # g = fused_gdn_gating(self.A_log * 1.0, a, self.dt_bias * 1.0) + # if self.num_v_heads // self.num_k_heads > 1: + # query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + # key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=None, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + ) + + else: + recurrent_state = cache_params.temporal_state[cache_params.state_indices_tensor] + core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + ) + # ToDo: the precision of fused triton-ascend kernel is not right. + # core_attn_out, last_recurrent_state = fused_recurrent_gated_delta_rule_fwd( + # query.contiguous(), + # key.contiguous(), + # value.contiguous(), + # g.contiguous(), + # beta=beta.contiguous(), + # initial_state=cache_params.temporal_state, + # inplace_final_state=False, + # # ssm_state_indices=cache_params.state_indices_tensor, + # use_qk_l2norm_in_kernel=True, + # ) + + # ToDo: the precision of fused triton-ascend kernel is not right. 
+ # print("aaaaaaaaaaaaaaaaaaaaaaaaaaaaa, cache_params.temporal_state", cache_params.temporal_state) + # a = a.reshape(a.shape[0] * a.shape[1], self.num_v_heads) + # b = b.reshape(b.shape[0] * b.shape[1], self.num_v_heads) + # core_attn_out, last_recurrent_state = fused_sigmoid_gating_delta_rule_update_npu( + # self.A_log * 1.0, + # a.contiguous(), + # self.dt_bias * 1.0, + # 1.0, + # 20.0, + # query.contiguous(), + # key.contiguous(), + # value.contiguous(), + # b.contiguous(), + # cache_params.temporal_state * 1.0, + # cache_params.state_indices_tensor, + # use_qk_l2norm_in_kernel=True, + # ) + + + # Update cache + if cache_params is not None: + cache_params.temporal_state[cache_params.state_indices_tensor] = last_recurrent_state + + z_shape_og = z.shape + # reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(z_shape_og) + core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1) + + output = self.out_proj(core_attn_out) + output = self.pack_to_th(output, batch_valid_length) \ + if get_model_context("is_prefill") else output.squeeze(1) + return output + +class Qwen3NextAttention(Qwen3MoeAttention): + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config, cache_config, quant_config, prefix) + self.q_size = self.num_heads * self.head_dim * 2 + + self.qkv_proj = QKVParallelLinear(self.hidden_size, + self.head_dim, + self.total_num_heads * 2, + self.total_num_kv_heads, + bias=self.qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj") + + self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def construct( + self, + positions: Tensor, + hidden_states: Tensor, + key_cache: Tensor, + value_cache: Tensor, + slot_mapping: Tensor, + attn_mask: Tensor, + batch_valid_length: Tensor, + q_seq_lens: Tensor, + block_tables: Tensor, + ) -> Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + q = q.reshape(*q.shape[:-1], self.num_heads, -1) + q, g = mint.chunk(q, 2, dim=-1) + q = q.reshape(*hidden_states.shape[:-1], -1) + g = g.reshape(*hidden_states.shape[:-1], -1) + + # Add qk-norm + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, + self.head_dim) + q_by_head = self.q_norm(q_by_head) + q = q_by_head.view(q.shape) + + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, + self.head_dim) + k_by_head = self.k_norm(k_by_head) + k = k_by_head.view(k.shape) + q, k = self.rotary_emb(positions, q, k, batch_valid_length) + attn_output = self.attn(q, k, v, key_cache, value_cache, slot_mapping, + attn_mask, batch_valid_length, q_seq_lens, + block_tables) + + attn_output = attn_output.reshape(*attn_output.shape[:-1], -1).contiguous() + attn_output = attn_output * mint.sigmoid(g) + output, _ = self.o_proj(attn_output) + return output + +class Qwen3NextDecoderLayer(nn.Cell): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_type = config.layer_types[extract_layer_index(prefix)] + + if 
self.layer_type == "linear_attention": + self.linear_attn = GDNAttention( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.linear_attn" + ) + elif self.layer_type == "full_attention": + self.self_attn = Qwen3NextAttention( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + # `mlp_only_layers` in the config. + layer_idx = extract_layer_index(prefix) + mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else + config.mlp_only_layers) + + if (layer_idx not in mlp_only_layers) and ( + config.num_experts > 0 and + (layer_idx + 1) % config.decoder_sparse_step == 0): + self.mlp = Qwen3MoeSparseMoeBlock(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.input_layernorm = Qwen3NextRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3NextRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def construct( + self, + positions: Tensor, + hidden_states: Tensor, + key_cache: Tensor, + value_cache: Tensor, + slot_mapping: Tensor, + attn_mask: Tensor, + batch_valid_length: Tensor, + q_seq_lens: Tensor, + block_tables: Tensor, + residual: Optional[Tensor], + dp_pad_index: Optional[bool] = None, + dp_unpad_index: Optional[Tensor] = None, + dp_pad_index_with_offset: Optional[Tensor] = None, + dp_unpad_index_total_with_offset: Optional[Tensor] = None, + cache_params: Optional[Qwen3NextCacheParams] = None, + ) -> Tensor: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + if self.layer_type == "linear_attention": + hidden_states = self.linear_attn( + hidden_states, + cache_params=cache_params, + cache_position=positions, + batch_valid_length=batch_valid_length + ) + elif self.layer_type == "full_attention": + hidden_states = self.self_attn(positions, hidden_states, key_cache, + value_cache, slot_mapping, attn_mask, + batch_valid_length, q_seq_lens, + block_tables) + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + hidden_states = self.mlp(hidden_states, dp_pad_index, dp_unpad_index, + dp_pad_index_with_offset, + dp_unpad_index_total_with_offset) + return hidden_states, residual + + +class Qwen3NextModel(Qwen3MoeModel, nn.Cell): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Cell.__init__(self) + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.config = config + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + prefix=f"{prefix}.embed_tokens") + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Qwen3NextDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], 
config.hidden_size)) + + def construct( + self, + input_ids: Tensor, + positions: Tensor, + key_caches: list[Tensor], + value_caches: list[Tensor], + slot_mapping: Tensor, + attn_mask: Tensor, + batch_valid_length: Tensor, + q_seq_lens: Tensor, + block_tables: Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[Tensor] = None, + dp_pad_index: Optional[Tensor] = None, + dp_unpad_index: Optional[Tensor] = None, + dp_pad_index_total_with_offset: Optional[Tensor] = None, + dp_unpad_index_total_with_offset: Optional[Tensor] = None, + cache_params: Optional[Qwen3NextCacheParams] = None, + ) -> Union[Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + cache_param = cache_params.at_layer_idx(i) \ + if cache_params is not None else cache_params + hidden_states, residual = layer( + positions, hidden_states, key_caches[i - self.start_layer], + value_caches[i - self.start_layer], slot_mapping, attn_mask, + batch_valid_length, q_seq_lens, block_tables, residual, + dp_pad_index, dp_unpad_index, dp_pad_index_total_with_offset, + dp_unpad_index_total_with_offset, cache_param) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class Qwen3NextForCausalLM(NativeModel, SupportesMoeDpTp, HasInnerState): + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = Qwen3NextModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + self.common_preprocess(vllm_config, prefix) + self.qwen3_next_cache = Qwen3NextCacheManager(vllm_config, + self.num_layers) + + def get_input_embeddings(self, input_ids: Tensor) -> Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward(self, + input_ids: Tensor, + positions: Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[Tensor] = None, + **kwargs) -> Union[Tensor, IntermediateTensors]: + cache_params = self.qwen3_next_cache.current_run_tensors(**kwargs) + cache_params = cache_params if cache_params.conv_state is not None else None + hidden_states = self.exec_model(input_ids, positions, + intermediate_tensors, inputs_embeds, + cache_params=cache_params) + return hidden_states + + def compute_logits( + self, + hidden_states: Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def 
load_weights(self, weights: Iterable[tuple[str, Tensor]]) -> set[str]: + params_dict = self.get_params_dict() + return self.model.load_weights(weights, params_dict) diff --git a/vllm_mindspore/model_executor/models/registry.py b/vllm_mindspore/model_executor/models/registry.py index f3564e0f220f6499aaae8e0e906efabf3e0aec1d..f21cb5c8eb43d9444318b15254fa257ed08e9c38 100644 --- a/vllm_mindspore/model_executor/models/registry.py +++ b/vllm_mindspore/model_executor/models/registry.py @@ -59,6 +59,8 @@ _NATIVE_MODELS = { "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"), "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), + "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"), + "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"), } _MINDFORMERS_MODELS = { diff --git a/vllm_mindspore/model_executor/models/run.py b/vllm_mindspore/model_executor/models/run.py new file mode 100644 index 0000000000000000000000000000000000000000..83bea1f1f5743691b6f592436eae2e2457f6ca6b --- /dev/null +++ b/vllm_mindspore/model_executor/models/run.py @@ -0,0 +1,62 @@ +def prepare_inputs(input_ids, is_prefill): + ''' + :return: input_ids, batch_valid_len, slot_mapping, attention_mask, key_cache + ''' + + +class LowerTriangularMaskWithDynamic(nn.Cell): + + def __init__(self, seq_length): + self.lower_triangle_mask = Tensor( + np.triu(np.ones(shape=(128, 128), dtype=np.float16), 1) * mask_coeff, dtype=compute_type + ) + self.upper_triangle_mask = Tensor((1 - np.tril(np.ones(shape=(seq_length, seq_length)))).astype(np.bool_), + mstype.bool_) + self.gather = P.Gather() + + def construct(self, tokens=None, masks=None, seq_chunk=None): + """Forward process of the CausalMask""" + if not self.use_pfa: + return super().construct(tokens, masks, seq_chunk) + seq_len = self.shape(tokens)[1] + lower_triangle_mask = self.lower_triangle_mask + if self.is_dynamic: + lower_triangle_mask = self.slice(self.lower_triangle_mask, (0, 0), (seq_len, seq_len), (1, 1)) + attention_mask = self.expand_dim(lower_triangle_mask, 0) + attention_mask = self.sub(self.one, attention_mask) + attention_mask = self.expand_dim_post(attention_mask, 1) + attention_mask = self.cast(attention_mask, mstype.bool_) + return attention_mask + + def get_ifa_mask(self, batch_valid_length): + mask = self.gather(self.upper_triangle_mask, batch_valid_length, 0) + mask = self.expand_dim_post(mask, 1) + mask = self.expand_dim_post(mask, 1) + return mask + +def generate(input_ids): + kv_caches = [] + for i in range(config.layers): + kv_caches[layer_name] = torch.zeros(kv_cache_shape, dtype=dtype, device=self.device) + is_finished = False + seq_len = input_ids.shape[1] + model_inputs['input_ids'] = input_ids + model_inputs['batch_valid_len'] = ms.Tensor([seq_len], mstype.int32) + model_inputs['key_cache'] = kv_caches #[bs, max_model_len, hidden_size) + is_prefill = True + while not is_finished: + model_inputs['attention_mask'] = self.causal_mask(input_ids) if is_prefill else \ + self.causal_mask.get_ifa_mask(model_inputs['batch_valid_len']) + hidden_states = model(model_inputs)[0, -1, :] + # lmhead + logits = model.lm_head(hidden_states) + # postprocess + token_id = ops.argmax(logits) + + is_finished == token_id in generation_config.eos_token_id or \ + batch_valid_len + 1 == generation_config.max_len + + model_inputs['input_ids'] = token_id + model_inputs['batch_valid_len'] = batch_valid_len + 1 + model_inputs['slot_mapping'] = ms.Tensor(np.arange(model_inputs['batch_valid_len'].numpy())) + diff --git 
a/vllm_mindspore/model_executor/models/triton_kernels.py b/vllm_mindspore/model_executor/models/triton_kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..d8301e1ccd6ebfa9c7cbcdee0c99b402a921c7f0 --- /dev/null +++ b/vllm_mindspore/model_executor/models/triton_kernels.py @@ -0,0 +1,584 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +import functools + +import os +import triton +import triton.language as tl +from typing import Callable, Optional + +import mindspore as ms +from mindspore import Tensor, ops, mint + +# def custom_device_ctx(index: int): +# return torch.npu.device(index) + + +def input_guard(fn: Callable[..., Tensor]) -> Callable[..., Tensor]: + """ + A decorator to make sure all input tensors are contiguous and set the device based on input tensors. + """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + contiguous_args = ( + i if not isinstance(i, Tensor) else i.contiguous() for i in args + ) + contiguous_kwargs = { + k: (v if not isinstance(v, Tensor) else v.contiguous()) + for k, v in kwargs.items() + } + + # tensor = None + # for arg in args: + # if isinstance(arg, Tensor): + # tensor = arg + # break + # if tensor is None: + # for value in kwargs.values(): + # if isinstance(value, Tensor): + # tensor = value + # break + + # if tensor is not None: + # ctx = custom_device_ctx(tensor.device.index) + # else: + # ctx = contextlib.nullcontext() + # + # with ctx: + return fn(*contiguous_args, **contiguous_kwargs) + + return wrapper + +# contiguous = input_guard + +@triton.jit +def fused_gdn_gating_kernel( + g, + A_log, + a, + dt_bias, + seq_len, + NUM_HEADS: tl.constexpr, + beta: tl.constexpr, + threshold: tl.constexpr, + BLK_HEADS: tl.constexpr, +): + i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2) + head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS) + off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off + mask = head_off < NUM_HEADS + blk_A_log = tl.load(A_log + head_off, mask=mask) + blk_a = tl.load(a + off, mask=mask) + blk_bias = tl.load(dt_bias + head_off, mask=mask) + # If the model is loaded in fp16, without the .float() here, A might be -inf + x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) + softplus_x = tl.where(beta * x <= threshold, + (1 / beta) * tl.log(1 + tl.exp(beta * x)), x) + blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x + tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) + + +def fused_gdn_gating( + A_log, + a, + dt_bias, + beta: float = 1.0, + threshold: float = 20.0, +) -> ms.Tensor: + batch, seq_len, num_heads = a.shape + # seq_len = 1 + grid = (batch, seq_len, triton.cdiv(num_heads, 8)) + g = mint.empty_like(a, dtype=ms.float32) + fused_gdn_gating_kernel[grid](g, + A_log, + a, + dt_bias, + seq_len, + num_heads, + beta, + threshold, + 8, + num_warps=1) + return g + + +@triton.heuristics( + { + "USE_INITIAL_STATE": lambda args: args["h0_source"] is not None, + 
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.jit(do_not_specialize=["T"]) +def fused_sigmoid_gating_delta_rule_update_kernel( + A_log, + a, + dt_bias, + softplus_beta, + softplus_threshold, + q, + k, + v, + b, + o, + h0_source, + h0_indices, + cu_seqlens, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + """ + Fused kernel that combines sigmoid gating computation with recurrent delta rule update. + """ + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = i_hv // (HV // H) + + if IS_VARLEN: + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int64), + tl.load(cu_seqlens + i_n + 1).to(tl.int64), + ) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + p_q = q + (bos * H + i_h) * K + o_k + p_k = k + (bos * H + i_h) * K + o_k + p_v = v + (bos * HV + i_hv) * V + o_v + p_b = b + bos * HV + i_hv + p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + + # Gating computation pointers + p_A_log = A_log + i_hv + p_a = a + bos * HV + i_hv + p_dt_bias = dt_bias + i_hv + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + idx = tl.load(h0_indices + i_n) + # if idx >= 0: + tmp0 = tl.where(idx < 0, 0, idx) + p_h0 = ( + h0_source + + tmp0 * HV * K * V + i_hv * K * V + + o_k[:, None] * V+ o_v[None, :] + ) + temp1 = tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + temp2 = tl.zeros_like(temp1) + value0 = tl.where(idx < 0, temp2, temp1) + b_h += value0 # tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for i in range(0, T): + # Load inputs + b_q = tl.load(p_q + i * H * K, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k + i * H * K, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v + i * HV * V, mask=mask_v, other=0).to(tl.float32) + b_b = tl.load(p_b + i * HV).to(tl.float32) + + # Compute sigmoid gating + # Load gating parameters + b_A_log = tl.load(p_A_log).to(tl.float32) + b_a = tl.load(p_a + i * HV).to(tl.float32) + b_dt_bias = tl.load(p_dt_bias).to(tl.float32) + + # Compute g = -exp(A_log) * softplus(a + dt_bias) + x = b_a + b_dt_bias + beta_x = softplus_beta * x + # Apply softplus with numerical stability + softplus_x = tl.where( + beta_x <= softplus_threshold, + (1.0 / softplus_beta) * tl.log(1.0 + tl.exp(beta_x)), + x, + ) + b_g = -tl.exp(b_A_log) * softplus_x + + # Compute beta = sigmoid(b) + b_beta = 1.0 / (1.0 + tl.exp(-b_b)) + + # Apply L2 normalization if enabled + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6) + b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6) + + b_q = b_q * scale + + # Apply gating to hidden state: h *= exp(g) + b_h *= tl.exp(b_g) + + # Delta rule: v -= sum(h * k, dim=0) + b_v -= tl.sum(b_h * b_k[:, None], 0) + + # Apply beta gating: v *= beta + b_v *= b_beta + + # Update hidden state: h += k[:, None] * v[None, :] + b_h += b_k[:, None] * b_v[None, :] + + # Compute output: o = sum(h * q, dim=0) + b_o = tl.sum(b_h * b_q[:, None], 0) + tl.store(p_o + i * HV * V, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + # # Update pointers for next timestep + # p_q += H * K + # p_k += H * K + # p_o += HV * V + # p_v += HV * V + # p_b += HV 
+ # p_a += HV + + # Store final state back to h0_source with bounds checking + if USE_INITIAL_STATE: + idx = tl.load(h0_indices + i_n) + if idx >= 0: + p_h0 = ( + h0_source + + idx * HV * K * V + + i_hv * K * V + + o_k[:, None] * V + + o_v[None, :] + ) + tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h) + + +# @input_guard +def fused_sigmoid_gating_delta_rule_update_npu( + A_log, + a, + dt_bias, + softplus_beta, + softplus_threshold, + q, + k, + v, + b, + initial_state_source, + initial_state_indices, + scale = None, + use_qk_l2norm_in_kernel = False, + cu_seqlens = None, +): + """ + Fused triton implementation of sigmoid gating delta rule update. + This function uses a single fused kernel that combines both sigmoid gating computation + and the recurrent delta rule update for better performance. + """ + B, T, H, K, V = *k.shape, v.shape[-1] + HV = v.shape[2] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 3 + num_warps = 1 + + if scale is None: + scale = k.shape[-1] ** -0.5 + else: + assert scale > 0, "scale must be positive" + + target_shape = (NK,) + v.shape + o = mint.empty(target_shape, dtype=q.dtype) + #o = q.new_empty(NK, *v.shape) + grid = (NK, NV, N * HV) + + fused_sigmoid_gating_delta_rule_update_kernel[grid]( + A_log=A_log, + a=a, + dt_bias=dt_bias, + softplus_beta=softplus_beta, + softplus_threshold=softplus_threshold, + q=q, + k=k, + v=v, + b=b, + o=o, + h0_source=initial_state_source, + h0_indices=initial_state_indices, + cu_seqlens=cu_seqlens, + scale=scale, + T=T, + B=B, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.squeeze(0) + return o, initial_state_source[initial_state_indices] + + +@triton.heuristics({ + 'USE_INITIAL_STATE': + lambda args: args['h0'] is not None, + 'IS_VARLEN': + lambda args: args['cu_seqlens'] is not None, + "IS_CONTINUOUS_BATCHING": + lambda args: args['ssm_state_indices'] is not None, + "IS_SPEC_DECODING": + lambda args: args['num_accepted_tokens'] is not None, +}) +@triton.jit(do_not_specialize=['N', 'T']) +def fused_recurrent_gated_delta_rule_fwd_kernel( + q, + k, + v, + g, + beta, + o, + h0, + ht, + cu_seqlens, + ssm_state_indices, + num_accepted_tokens, + scale, + N: tl.constexpr, # num of sequences + T: tl.constexpr, # num of tokens + B: tl.constexpr, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + stride_init_state_token: tl.constexpr, + stride_final_state_token: tl.constexpr, + stride_indices_seq: tl.constexpr, + stride_indices_tok: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace + IS_BETA_HEADWISE: tl. 
+ constexpr, # whether beta is headwise vector or scalar, + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = i_hv // (HV // H) + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + if T == 0: + # no tokens to process for this sequence + return + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + # p_q = q + (bos * H + i_h) * K + o_k + # p_k = k + (bos * H + i_h) * K + o_k + # p_v = v + (bos * HV + i_hv) * V + o_v + # if IS_BETA_HEADWISE: + # p_beta = beta + (bos * HV + i_hv) * V + o_v + # else: + # p_beta = beta + bos * HV + i_hv + # p_g = g + bos * HV + i_hv + # p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + if IS_CONTINUOUS_BATCHING: + if IS_SPEC_DECODING: + i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 + else: + i_t = 0 + p_h0 = h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq + + i_t).to(tl.int64) * stride_init_state_token + else: + p_h0 = h0 + bos * HV * K * V + p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for i_t in range(0, T): + p_q = q + (bos * H + i_h) * K + o_k + H * K * i_t + p_k = k + (bos * H + i_h) * K + o_k + H * K * i_t + p_v = v + (bos * HV + i_hv) * V + o_v + HV * V * i_t + if IS_BETA_HEADWISE: + p_beta = beta + (bos * HV + i_hv) * V + o_v + HV * V * i_t + else: + p_beta = beta + bos * HV + i_hv + HV * i_t + p_g = g + bos * HV + i_hv + HV * i_t + p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + HV * V * i_t + + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_g = tl.load(p_g).to(tl.float32) + + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6) + b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6) + b_q = b_q * scale + # [BK, BV] + b_h *= tl.exp(b_g) + # b_h *= exp(b_g) + # [BV] + b_v -= tl.sum(b_h * b_k[:, None], 0) + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + b_v *= b_beta + # [BK, BV] + b_h += b_k[:, None] * b_v[None, :] + # [BV] + b_o = tl.sum(b_h * b_q[:, None], 0) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + # keep the states for multi-query tokens + if INPLACE_FINAL_STATE: + p_ht = ht + tl.load(ssm_state_indices + i_n * stride_indices_seq + + i_t).to(tl.int64) * stride_final_state_token + else: + p_ht = ht + (bos + i_t) * stride_final_state_token + p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + # p_q += H * K + # p_k += H * K + # p_o += HV * V + # p_v += HV * V + # p_g += HV + # p_beta += HV * (V if IS_BETA_HEADWISE else 1) + + +def fused_recurrent_gated_delta_rule_fwd( + q, + k, + v, + g, + beta, + scale = True, + initial_state = True, + inplace_final_state = True, + cu_seqlens = None, + ssm_state_indices = None, + num_accepted_tokens = None, + use_qk_l2norm_in_kernel = False, +): + B, T, H, 
K, V = *k.shape, v.shape[-1] + HV = v.shape[2] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 3 + num_warps = 1 + if scale is None: + scale = k.shape[-1] ** -0.5 + else: + assert scale > 0, "scale must be positive" + + o = mint.empty((NK, *v.shape), dtype=q.dtype) + # o = q.new_empty(NK, *v.shape) + if inplace_final_state: + final_state = initial_state + else: + # final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype) + final_state = mint.empty((T, HV, K, V), dtype=initial_state.dtype) + + stride_init_state_token = initial_state.stride(0) + stride_final_state_token = final_state.stride(0) + + if ssm_state_indices is None: + stride_indices_seq, stride_indices_tok = 1, 1 + elif ssm_state_indices.ndim == 1: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 + else: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() + + # print("N: ", N) + # print("T: ", T) + # print("B: ", B) + # print("H: ", H) + # print("HV: ", HV) + # print("K: ", K) + # print("V: ", V) + # print("BK: ", BK) + # print("BV: ", BV) + + grid = (NK, NV, N * HV) + fused_recurrent_gated_delta_rule_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + beta=beta, + o=o, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + scale=scale, + N=N, + T=T, + B=B, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + stride_init_state_token=stride_init_state_token, + stride_final_state_token=stride_final_state_token, + stride_indices_seq=stride_indices_seq, + stride_indices_tok=stride_indices_tok, + IS_BETA_HEADWISE=beta.ndim == v.ndim, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + INPLACE_FINAL_STATE=inplace_final_state, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.squeeze(0) + return o, final_state \ No newline at end of file diff --git a/vllm_mindspore/model_executor/models/utils.py b/vllm_mindspore/model_executor/models/utils.py index 5d9e56d3431257e81e07c47515b6084b5fa38f0a..ff269ed5643659eea108db811bcaf39bd8a3d80d 100644 --- a/vllm_mindspore/model_executor/models/utils.py +++ b/vllm_mindspore/model_executor/models/utils.py @@ -23,7 +23,7 @@ from dataclasses import dataclass, field from typing import Optional, Union import mindspore as ms -from mindspore import mint, ops +from mindspore import mint, nn, ops from vllm import envs from vllm.sequence import IntermediateTensors @@ -285,3 +285,34 @@ def is_use_ringmla(vllm_config, mf_config=None): and vllm_config.model_config.quantization is not None and vllm_config.parallel_config.tensor_parallel_size < 16) return use_ringmla + + +_model_to_pp_missing_layer_names: dict[int, list[str]] = {} + + +def get_pp_missing_layer_names(model: nn.Cell) -> list[str]: + """Get the names of the missing layers in a pipeline parallel model.""" + model_id = id(model) + if model_id in _model_to_pp_missing_layer_names: + return _model_to_pp_missing_layer_names[model_id] + + missing_layer_names = [] + for name, cell in model.cells_and_names(): + if isinstance(cell, PPMissingLayer): + # NOTE: the trailing dot is used to match the prefix of the layer. 
+ # without the dot, we could match a layer that is not missing, + # e.g., 'encoder.layer.1' would match 'encoder.layer.11' + missing_layer_names.append(name + '.') + _model_to_pp_missing_layer_names[model_id] = missing_layer_names + + return missing_layer_names + + +def is_pp_missing_parameter(name: str, model: nn.Cell) -> bool: + """Check if a parameter is missing in a pipeline parallel model.""" + if isinstance(model, PPMissingLayer): + return True + + return any( + name.startswith(missing_layer_name) + for missing_layer_name in get_pp_missing_layer_names(model)) diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py index f3a9adbef212300bfc831dd68c1b4bebc5378255..98f2af7bb6a22be53034d418b2bcd4a193c74cb0 100644 --- a/vllm_mindspore/utils.py +++ b/vllm_mindspore/utils.py @@ -303,6 +303,11 @@ def is_310p(): return device in ['310p', 'ascend310p'] +def is_910b(): + device = get_ascend_soc_version() + return '910b' in device.lower() + + def check_ready(): from mindspore import set_context diff --git a/vllm_mindspore/v1/worker/gpu_model_runner.py b/vllm_mindspore/v1/worker/gpu_model_runner.py index 64973e5b9dc0d41903cff85b216680afcc76287c..f23c1970613789a7d347e04d26d06436e123b87b 100644 --- a/vllm_mindspore/v1/worker/gpu_model_runner.py +++ b/vllm_mindspore/v1/worker/gpu_model_runner.py @@ -27,6 +27,7 @@ import torch from mindspore import Generator as msGenerator from mindspore import Tensor, mint, mutable, ops from vllm.attention import AttentionType +from vllm.config import get_layers_from_vllm_config from vllm.logger import init_logger from vllm.sampling_params import SamplingType from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, @@ -39,6 +40,7 @@ from vllm.v1.worker.utils import initialize_kv_cache_for_kv_sharing from vllm_mindspore.model_executor.layers.rotary_embedding import ( InferMRotaryEmbedding as MRotaryEmbedding) +from vllm_mindspore.model_executor.models.model_base import AttentionWrapper from vllm_mindspore.model_executor.models.utils import is_use_ringmla from vllm_mindspore.utils import (create_kv_cache, get_dtype_size, get_valid_dtype, is_310p) @@ -737,7 +739,6 @@ def wrapper_gpu_model_runner_execute_model(func): def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: - forward_ctx = self.vllm_config.compilation_config.static_forward_context block_size = self.vllm_config.cache_config.block_size use_mla = self.vllm_config.model_config.use_mla fa3_quant = self.vllm_config.quant_config.fa3_quant \ @@ -745,7 +746,9 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: fa3_quant_layer = self.vllm_config.quant_config.fa3_quant_layer \ if self.vllm_config.quant_config else set() kv_cache_spec: dict[str, KVCacheSpec] = {} - for layer_name, attn_module in forward_ctx.items(): + attn_layers = get_layers_from_vllm_config(self.vllm_config, + AttentionWrapper) + for layer_name, attn_module in attn_layers.items(): """ vllm-mindspore AttentionWrapper is not an Attention isinstance assert isinstance(attn_module, Attention)
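Editor's note (illustrative, not part of the patch): both Triton kernels in triton_kernels.py realize the same gated delta-rule recurrence; they differ mainly in whether the gate g and the mixing factor beta are computed inside the kernel (from A_log, a, dt_bias and b) or passed in precomputed, and in how the recurrent state is indexed. The NumPy sketch below is a minimal, unoptimized reference of that recurrence for the [B, T, H, K] / [B, T, HV, V] layouts used by fused_sigmoid_gating_delta_rule_update_npu. The function and helper names and the dense [B, HV, K, V] state layout are assumptions made for illustration; the real kernels address a pooled state through h0_indices / ssm_state_indices.

import numpy as np

def _softplus(x, beta=1.0, threshold=20.0):
    # Same stabilized softplus as the kernels: identity above the threshold.
    return np.where(beta * x <= threshold, np.log1p(np.exp(beta * x)) / beta, x)

def gated_delta_rule_reference(q, k, v, b, A_log, a, dt_bias, state,
                               scale=None, use_qk_l2norm=True):
    # q, k: [B, T, H, K]; v: [B, T, HV, V]; a, b: [B, T, HV]
    # A_log, dt_bias: [HV]; state: [B, HV, K, V], updated in place.
    B, T, H, K = q.shape
    HV, V = v.shape[2], v.shape[3]
    scale = K ** -0.5 if scale is None else scale
    o = np.zeros((B, T, HV, V), dtype=np.float32)
    for n in range(B):
        for hv in range(HV):
            h = hv // (HV // H)   # value head -> shared query/key head
            S = state[n, hv]      # [K, V] recurrent memory (view, updated in place)
            for t in range(T):
                g = -np.exp(A_log[hv]) * _softplus(a[n, t, hv] + dt_bias[hv])
                beta_t = 1.0 / (1.0 + np.exp(-b[n, t, hv]))   # sigmoid gate
                qt = q[n, t, h].astype(np.float32)
                kt = k[n, t, h].astype(np.float32)
                if use_qk_l2norm:
                    qt = qt / (np.linalg.norm(qt) + 1e-6)
                    kt = kt / (np.linalg.norm(kt) + 1e-6)
                S *= np.exp(g)                                        # decay the old memory
                vt = beta_t * (v[n, t, hv].astype(np.float32) - kt @ S)  # delta-rule correction
                S += np.outer(kt, vt)                                 # rank-1 state update
                o[n, t, hv] = (scale * qt) @ S                        # read out with the query
    return o

On small shapes a reference like this can serve as a correctness oracle when debugging the fused kernels, e.g. by comparing its output and final state against the kernel results within a loose float tolerance.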