From d3235745b7629cddd0d839e6d684af609b521c73 Mon Sep 17 00:00:00 2001 From: superxf Date: Fri, 10 Oct 2025 10:02:14 +0800 Subject: [PATCH 1/3] support 310p v0 --- vllm_mindspore/__init__.py | 13 ++ vllm_mindspore/config.py | 2 +- vllm_mindspore/distributed/parallel_state.py | 114 +++++++++++++ .../model_executor/layers/linear.py | 5 + .../model_executor/layers/logits_processor.py | 157 +++++++++++++++++- .../quantization/smooth_quant_modelslim.py | 8 +- .../layers/vocab_parallel_embedding.py | 5 + .../model_executor/model_loader/utils.py | 31 ++++ .../model_loader/weight_utils.py | 9 + .../model_executor/models/model_base.py | 9 +- vllm_mindspore/model_executor/models/qwen2.py | 8 +- vllm_mindspore/model_executor/models/qwen3.py | 9 +- vllm_mindspore/platforms/ascend.py | 9 + vllm_mindspore/utils.py | 18 ++ vllm_mindspore/worker/cache_engine.py | 7 +- vllm_mindspore/worker/model_runner.py | 17 +- 16 files changed, 395 insertions(+), 26 deletions(-) create mode 100644 vllm_mindspore/distributed/parallel_state.py diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index 58c78e07..1382967f 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -337,6 +337,19 @@ from vllm.model_executor.layers.rejection_sampler import RejectionSampler RejectionSampler._smallest_positive_value = _smallest_positive_value RejectionSampler._smallest_positive_value.__set_name__( RejectionSampler, "_smallest_positive_value") +from vllm_mindspore.model_executor.model_loader.utils import ( + process_weights_after_loading) + +vllm.model_executor.model_loader.utils.process_weights_after_loading = ( + process_weights_after_loading) +vllm.model_executor.model_loader.base_loader.process_weights_after_loading = ( + process_weights_after_loading) + +import vllm.distributed.parallel_state +from vllm_mindspore.distributed.parallel_state import gc_broadcast_tensor_dict + +vllm.distributed.parallel_state.GroupCoordinator.broadcast_tensor_dict = ( + gc_broadcast_tensor_dict) ######### for multi-model from vllm_mindspore.inputs.registry import call_hf_processor diff --git a/vllm_mindspore/config.py b/vllm_mindspore/config.py index 7f93164b..1f9e8e27 100644 --- a/vllm_mindspore/config.py +++ b/vllm_mindspore/config.py @@ -248,7 +248,7 @@ def _get_and_verify_dtype( raise ValueError(f"Unknown dtype: {dtype}") if torch_dtype == torch.bfloat16 and is_310p(): - torch_dtype = torch.float16 + raise ValueError("For 310p, bfloat16 type is not support") if torch_dtype != config_dtype: if torch_dtype == torch.float32: diff --git a/vllm_mindspore/distributed/parallel_state.py b/vllm_mindspore/distributed/parallel_state.py new file mode 100644 index 00000000..371511e2 --- /dev/null +++ b/vllm_mindspore/distributed/parallel_state.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Communication functions are adapted from +# https://github.com/vllm-project/vllm/blob/v0.9.1/vllm/distributed/parallel_state.py +# +# Copyright 2025 Huawei Technologies Co., Ltd. +# Copyright 2024-2025 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Adaption for cpu broadcast.""" +from typing import Any, Optional, Union + +import torch +import torch.distributed +from torch.distributed import ProcessGroup +from vllm.distributed.parallel_state import (TensorMetadata, + _split_tensor_dict, get_tp_group) + +from vllm_mindspore.utils import is_310p + + +def gc_broadcast_tensor_dict( + self, + tensor_dict: Optional[dict[str, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None +) -> Optional[dict[str, Union[torch.Tensor, Any]]]: + """Broadcast the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if (not torch.distributed.is_initialized() or self.world_size == 1): + return tensor_dict + + # Due to 310p limitations, the broadcast is in cpu group + group = get_tp_group().cpu_group if is_310p() else self.device_group + metadata_group = self.cpu_group + assert src < self.world_size, f"Invalid src rank ({src})" + + rank_in_group = self.rank_in_group + if rank_in_group == src: + metadata_list: list[tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, + dict), (f"Expecting a dictionary, got {type(tensor_dict)}") + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.broadcast_object(metadata_list, src=src) + async_handles = [] + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=group, + async_op=True) + async_handles.append(handle) + for async_handle in async_handles: + async_handle.wait() + + else: + metadata_list = self.broadcast_object(None, src=src) + tensor_dict = {} + async_handles = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, + dtype=value.dtype, + device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. 
+ tensor_dict[key] = tensor + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=group, + async_op=True) + async_handles.append(handle) + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + for async_handle in async_handles: + async_handle.wait() + return tensor_dict diff --git a/vllm_mindspore/model_executor/layers/linear.py b/vllm_mindspore/model_executor/layers/linear.py index 92121905..46ea2d7d 100644 --- a/vllm_mindspore/model_executor/layers/linear.py +++ b/vllm_mindspore/model_executor/layers/linear.py @@ -38,6 +38,7 @@ from vllm_mindspore.model_executor.layers.quantization.base_config import ( from vllm_mindspore.model_executor.model_loader.weight_utils import ( split_loaded_weight) from vllm_mindspore.model_executor.utils import set_weight_attrs +from vllm_mindspore.utils import is_310p, set_weight_format_to_nz WEIGHT_LOADER_V2_SUPPORTED = [ "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod", @@ -117,6 +118,10 @@ class UnquantizedLinearMethod(LinearMethodBase): x = x.view(output_shape) return x + def process_weights_after_loading(self, layer): + if is_310p(): + set_weight_format_to_nz(layer.weight) + class LinearBase(nn.Cell): """Base linear layer. diff --git a/vllm_mindspore/model_executor/layers/logits_processor.py b/vllm_mindspore/model_executor/layers/logits_processor.py index ee8c8edc..ba300d86 100644 --- a/vllm_mindspore/model_executor/layers/logits_processor.py +++ b/vllm_mindspore/model_executor/layers/logits_processor.py @@ -23,12 +23,14 @@ from concurrent.futures import ThreadPoolExecutor from typing import Optional import vllm.envs as envs -from mindspore import Tensor, mint, nn -from vllm.config import current_platform +from mindspore import Tensor, jit, mint, nn +from vllm.config import current_platform, get_current_vllm_config from vllm.distributed import (tensor_model_parallel_all_gather, tensor_model_parallel_gather) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm_mindspore.distributed.communication_op import ( + AllGatherFromModelParallelRegion) from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -201,3 +203,154 @@ def _apply_logits_processors_single_seq(logits_row, logits_processors, else: logits_row = logits_processor(past_tokens_ids, logits_row) return logits_row + + +class LogitsProcessor310P(LogitsProcessor): + """Process logits for 310P, running in graph mode for better performance. + + This layer does the following: + 1. Gather logits from model hidden_states. + 2. Scale logits if needed. + 3. Apply logits processors (if any). + """ + + def __init__( + self, + vocab_size: int, + org_vocab_size: Optional[int] = None, + scale: float = 1.0, + logits_as_input: bool = False, + soft_cap: Optional[float] = None, + ) -> None: + """ + Args: + scale: A scaling factor to apply to the logits. 
+ """ + super().__init__(vocab_size, org_vocab_size, scale, logits_as_input, + soft_cap) + vllm_config = get_current_vllm_config() + self.vllm_config = vllm_config + self.is_graph_mode = bool(not vllm_config.model_config.enforce_eager) + self.tensor_model_parallel_all_gather = \ + AllGatherFromModelParallelRegion() + self.lm_head = None + self.run_model = None + self.cached_input_info = {} + + def set_dynamic_inputs(self): + dyn_hidden_states = Tensor(shape=[None, None], + dtype=self.vllm_config.model_config.dtype) + + if self.cached_input_info["indices"] is None: + dyn_indices = None + else: + dyn_indices_shape = [ + None for _ in range(self.cached_input_info["indices"]["ndim"]) + ] + dyn_indices_dtype = self.cached_input_info["indices"]["dtype"] + dyn_indices = Tensor(shape=dyn_indices_shape, + dtype=dyn_indices_dtype) + + if self.cached_input_info["bias"] is None: + dyn_bias = None + else: + dyn_bias_shape = [ + None for _ in range(self.cached_input_info["bias"]["ndim"]) + ] + dyn_bias_dtype = self.cached_input_info["bias"]["dtype"] + dyn_bias = Tensor(shape=dyn_bias_shape, dtype=dyn_bias_dtype) + + self.set_inputs(dyn_hidden_states, dyn_indices, dyn_bias) + + def __call__( + self, + lm_head: VocabParallelEmbedding, + hidden_states: Tensor, + sampling_metadata: Optional[SamplingMetadata] = None, + embedding_bias: Optional[Tensor] = None, + ) -> Optional[Tensor]: + if self.lm_head is None: + self.lm_head = lm_head + if self.run_model is None: + self.run_model = jit( + function=self.construct, + jit_level='O0') if self.is_graph_mode else self.construct + selected_token_indices = None + if sampling_metadata is not None: + selected_token_indices = sampling_metadata.selected_token_indices + dyn_indices_info = None if selected_token_indices is None else { + "ndim": selected_token_indices.ndim, + "dtype": selected_token_indices.dtype, + } + dyn_bias_info = None if embedding_bias is None else { + "ndim": embedding_bias.ndim, + "dtype": embedding_bias.dtype, + } + if self.cached_input_info != { + "indices": dyn_indices_info, + "bias": dyn_bias_info + }: + self.cached_input_info = { + "indices": dyn_indices_info, + "bias": dyn_bias_info, + } + self.set_dynamic_inputs() + + if selected_token_indices is not None and selected_token_indices.numel( + ) <= 0: + logits = mint.zeros((0, self.vocab_size), + dtype=hidden_states.dtype) + else: + logits = self.run_model(hidden_states, selected_token_indices, + embedding_bias) + + if sampling_metadata is not None and \ + sampling_metadata.seq_groups is not None: + logits = _apply_logits_processors(logits, sampling_metadata) + + return logits + + def construct( + self, + hidden_states: Tensor, + selected_token_indices: Optional[Tensor] = None, + embedding_bias: Optional[Tensor] = None, + ) -> Optional[Tensor]: + if self.logits_as_input: + logits = hidden_states + else: + if selected_token_indices is not None: + hidden_states = mint.index_select(hidden_states, 0, + selected_token_indices) + + # Get the logits for the next tokens. + logits = self._get_logits(hidden_states, self.lm_head, + embedding_bias) + if logits is not None: + if self.soft_cap is not None: + logits = logits / self.soft_cap + logits = mint.tanh(logits) + logits = logits * self.soft_cap + + if self.scale != 1.0: + logits *= self.scale + + # Apply logits processors (if any). + return logits + + def _get_logits( + self, + hidden_states: Tensor, + lm_head: VocabParallelEmbedding, + embedding_bias: Optional[Tensor], + ) -> Optional[Tensor]: + # Get the logits for the next tokens. 
+ logits = lm_head.quant_method.apply(lm_head, + hidden_states, + bias=embedding_bias) + # For 310p, all gather has better performance. + logits = self.tensor_model_parallel_all_gather(logits) + # Remove paddings in vocab (if any). + if logits is not None: + logits = logits[..., :self.org_vocab_size] + return logits diff --git a/vllm_mindspore/model_executor/layers/quantization/smooth_quant_modelslim.py b/vllm_mindspore/model_executor/layers/quantization/smooth_quant_modelslim.py index 841e1797..8e1772dc 100644 --- a/vllm_mindspore/model_executor/layers/quantization/smooth_quant_modelslim.py +++ b/vllm_mindspore/model_executor/layers/quantization/smooth_quant_modelslim.py @@ -14,11 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import re from typing import Any, Optional import mindspore import numpy as np +import regex as re from mindspore import Parameter, Tensor, mint from mindspore.common.initializer import initializer from mindspore.ops.auto_generate import (DynamicQuantExt, GroupedMatmul, @@ -29,7 +29,7 @@ from vllm_mindspore.attention import Attention from vllm_mindspore.model_executor.layers.linear import ( LinearBase, LinearMethodBase, UnquantizedLinearMethod) from vllm_mindspore.model_executor.utils import set_weight_attrs -from vllm_mindspore.utils import is_310p +from vllm_mindspore.utils import is_310p, set_weight_format_to_nz from .attention import BaseKVCacheMethod, KVCacheInt8Method from .base_config import QuantizationConfig, QuantizeMethodBase @@ -324,6 +324,8 @@ class A8W8LinearMethod(LinearMethodBase): layer.insert_param_to_cell("input_offset", input_offset) def process_weights_after_loading(self, layer: mindspore.nn.Cell) -> None: + if is_310p(): + set_weight_format_to_nz(layer.weight) input_offset = np.array([0]) params_dtype = layer.params_dtype layer.input_offset = Parameter(Tensor(input_offset, @@ -498,6 +500,8 @@ class A8W8DYNLinearMethod(LinearMethodBase): layer.insert_param_to_cell("smooth_scale", smooth_scale) def process_weights_after_loading(self, layer: mindspore.nn.Cell) -> None: + if is_310p(): + set_weight_format_to_nz(layer.weight) if self.is_2d_smooth_scale: smooth_scale = layer.smooth_scale.asnumpy().reshape(1, -1) layer.smooth_scale = Parameter(Tensor(smooth_scale, diff --git a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py index f25c8059..8524d380 100644 --- a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py @@ -38,6 +38,7 @@ from vllm_mindspore.model_executor.layers.quantization.base_config import ( from vllm_mindspore.model_executor.model_loader.weight_utils import ( split_loaded_weight) from vllm_mindspore.model_executor.utils import set_weight_attrs +from vllm_mindspore.utils import is_310p, set_weight_format_to_nz DEFAULT_VOCAB_PADDING_SIZE = 64 @@ -76,6 +77,10 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): def embedding(self, layer: nn.Cell, input_: Tensor) -> Tensor: return mint.index_select(layer.weight, 0, input_) + def process_weights_after_loading(self, layer): + if is_310p(): + set_weight_format_to_nz(layer.weight) + def get_masked_input_and_mask( input_: Tensor, diff --git a/vllm_mindspore/model_executor/model_loader/utils.py b/vllm_mindspore/model_executor/model_loader/utils.py index cd76b99a..4d98894f 100644 --- a/vllm_mindspore/model_executor/model_loader/utils.py +++ 
b/vllm_mindspore/model_executor/model_loader/utils.py @@ -22,10 +22,14 @@ import os from mindspore import nn +from vllm.attention import Attention from vllm.config import ModelConfig, ModelImpl +from vllm.model_executor.layers.linear import QKVCrossParallelLinear from vllm.model_executor.model_loader.utils import logger from vllm.model_executor.models import ModelRegistry +from vllm_mindspore.model_executor.layers.quantization.base_config import ( + QuantizeMethodBase) from vllm_mindspore.model_executor.models.registry import ( AUTO_SELECT_FIXED_MODEL, MindSporeModelRegistry, mcore_support_list, mf_supported, mindone_supported) @@ -198,3 +202,30 @@ def get_ms_model_architecture( raise RecursionError("MindSpore unsupported reward model task now!") return model_cls, arch + + +def process_weights_after_loading(model, model_config, target_device) -> None: + for _, module in model.named_modules(): + if isinstance(module, QKVCrossParallelLinear): + # NOTE(Isotr0py): special case for cross QKV layer because + # q and kv proj aren't registered as submodules intentionally + module.process_weights_after_loading() + continue + quant_method = getattr(module, "quant_method", None) + if isinstance(quant_method, QuantizeMethodBase): + # When quant methods need to process weights after loading + # (for repacking, quantizing, etc), they expect parameters + # to be on the global target device. This scope is for the + # case where cpu offloading is used, where we will move the + # parameters onto device for processing and back off after. + quant_method.process_weights_after_loading(module) + + # Currently only used by MLA. + # NOTE: This intentionally happens after other modules so we can easily + # decompress the weights for MLA. + for _, module in model.named_modules(): + if isinstance(module, Attention) and \ + hasattr(module, "process_weights_after_loading"): + # TODO(lucas): see if there is a way to unify the signatures + # of process_weights_after_loading + module.process_weights_after_loading(model_config.dtype) \ No newline at end of file diff --git a/vllm_mindspore/model_executor/model_loader/weight_utils.py b/vllm_mindspore/model_executor/model_loader/weight_utils.py index 6bf2dd4c..d409e80a 100644 --- a/vllm_mindspore/model_executor/model_loader/weight_utils.py +++ b/vllm_mindspore/model_executor/model_loader/weight_utils.py @@ -28,6 +28,8 @@ from tqdm.auto import tqdm from vllm.model_executor.model_loader.weight_utils import (_BAR_FORMAT, enable_tqdm) +from vllm_mindspore.utils import cast_weight_for_310p, is_310p + def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): """ @@ -39,6 +41,8 @@ def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): """ if shard_dim is None: loaded_weight = loaded_weight[:] + if is_310p(): + loaded_weight = cast_weight_for_310p(loaded_weight) return loaded_weight end_idx = start_idx + shard_size @@ -50,6 +54,9 @@ def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): loaded_weight = loaded_weight[:, :, start_idx:end_idx] else: raise ValueError("shard_dim:{} is not supported.".format(shard_dim)) + if is_310p(): + loaded_weight = cast_weight_for_310p(loaded_weight) + return loaded_weight @@ -77,4 +84,6 @@ def safetensors_weights_iterator( def default_weight_loader(param: Parameter, loaded_weight: Any) -> None: """Default weight loader.""" loaded_weight = loaded_weight[:] + if is_310p(): + loaded_weight = cast_weight_for_310p(loaded_weight) param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype)) diff --git 
a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index 4c0d729e..31cbd374 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -38,7 +38,8 @@ from vllm_mindspore.model_executor.models.interfaces import ( from vllm_mindspore.model_executor.models.utils import (convert_pin, is_use_ringmla) from vllm_mindspore.model_executor.utils import set_model_context -from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE, create_kv_cache +from vllm_mindspore.utils import (STR_DTYPE_TO_MS_DTYPE, create_kv_cache, + is_310p) from vllm_mindspore.v1.attention.backends.ms_attn import MsAttentionMetadata @@ -472,7 +473,11 @@ class NativeModel(MsModelBase): block_size = self.cache_config.block_size num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) head_size = self.model_config.get_head_size() - kv_cache_shape = (None, block_size, num_kv_heads, head_size) + # FormatCast op is used to convert kv_cache to nz format for 310p. + # The kv shape is flattened to 3 dimensions, + # because FormatCast op only support 3d input. + kv_cache_shape = (None, block_size, num_kv_heads * head_size) \ + if is_310p() else (None, block_size, num_kv_heads, head_size) kv_cache_dtype = (self.model_config.dtype if self.cache_config.cache_dtype == "auto" else diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py index db8e31c0..a1755372 100644 --- a/vllm_mindspore/model_executor/models/qwen2.py +++ b/vllm_mindspore/model_executor/models/qwen2.py @@ -46,12 +46,13 @@ from vllm.sequence import IntermediateTensors from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm_mindspore.attention import Attention +from vllm_mindspore.utils import is_310p from vllm_mindspore.model_executor.layers.activation import SiluAndMul from vllm_mindspore.model_executor.layers.layernorm import RMSNorm from vllm_mindspore.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm_mindspore.model_executor.layers.logits_processor import \ - LogitsProcessor + LogitsProcessor, LogitsProcessor310P from vllm_mindspore.model_executor.layers.rotary_embedding import get_rope from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) @@ -462,7 +463,10 @@ class Qwen2ForCausalLM(NativeModel, SupportsLoRA): else: self.lm_head = PPMissingLayer() - self.logits_processor = LogitsProcessor(self.config.vocab_size) + if is_310p(): + self.logits_processor = LogitsProcessor310P(config.vocab_size) + else: + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm_mindspore/model_executor/models/qwen3.py b/vllm_mindspore/model_executor/models/qwen3.py index 89a98b58..b464ee4c 100644 --- a/vllm_mindspore/model_executor/models/qwen3.py +++ b/vllm_mindspore/model_executor/models/qwen3.py @@ -58,12 +58,13 @@ from vllm_mindspore.model_executor.layers.layernorm import RMSNorm from vllm_mindspore.model_executor.layers.linear import (QKVParallelLinear, RowParallelLinear) from vllm_mindspore.model_executor.layers.logits_processor import ( - LogitsProcessor) + LogitsProcessor, LogitsProcessor310P) from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead) from vllm_mindspore.model_executor.models.model_base import NativeModel from 
vllm_mindspore.model_executor.models.utils import (PPMissingLayer, maybe_prefix) +from vllm_mindspore.utils import is_310p from vllm_mindspore.model_executor.layers.rotary_embedding import ( # type: ignore[attr-defined] # isort: skip get_rope) @@ -297,8 +298,10 @@ class Qwen3ForCausalLM(NativeModel): prefix, "lm_head")) else: self.lm_head = PPMissingLayer() - - self.logits_processor = LogitsProcessor(config.vocab_size) + if is_310p(): + self.logits_processor = LogitsProcessor310P(config.vocab_size) + else: + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm_mindspore/platforms/ascend.py b/vllm_mindspore/platforms/ascend.py index 0df8f3e8..3141e198 100644 --- a/vllm_mindspore/platforms/ascend.py +++ b/vllm_mindspore/platforms/ascend.py @@ -26,6 +26,8 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.platforms.interface import Platform, PlatformEnum, _Backend +from vllm_mindspore.utils import is_310p + if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig else: @@ -181,3 +183,10 @@ class AscendPlatform(Platform): if ASCEND_QUANTIZATION_METHOD not in quant_action.choices: quant_action.choices.extend(ASCEND_QUANTIZATION_METHOD) logger.debug("--quantization support ascend/golden-stick.") + + @property + def supported_dtypes(self) -> list[torch.dtype]: + if is_310p(): + # bfloat16 is not supported on the 310p due to hardware limitations. + return [torch.float16, torch.float32] + return [torch.bfloat16, torch.float16, torch.float32] \ No newline at end of file diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py index 7cba47a2..edbbb6b2 100644 --- a/vllm_mindspore/utils.py +++ b/vllm_mindspore/utils.py @@ -93,6 +93,24 @@ def create_kv_cache(kv_shape, dtype, is_fa3_quant=False): return ms.mint.zeros(kv_shape, dtype=dtype) +def cast_weight_for_310p(loaded_weight): + """ + Casts weights to float16 for 310p. + + In non-quantized scenarios, the 310P hardware only supports float16 weights. + This function converts float32 or bfloat16 weights to float16. 
+ """ + cast_weight = (loaded_weight.astype(np.float16) if + (str(loaded_weight.dtype) == "float32" or str( + loaded_weight.dtype) == "bfloat16") else loaded_weight) + return cast_weight + + +def set_weight_format_to_nz(param): + cast_weight = ms.ops.auto_generate.format_cast(param, FORMAT_TYPE['nz']) + param.set_data(cast_weight) + + def get_valid_dtype(dtype): if isinstance(dtype, str): dtype = STR_DTYPE_TO_MS_DTYPE[dtype] diff --git a/vllm_mindspore/worker/cache_engine.py b/vllm_mindspore/worker/cache_engine.py index b57b8833..2ff204c6 100644 --- a/vllm_mindspore/worker/cache_engine.py +++ b/vllm_mindspore/worker/cache_engine.py @@ -26,13 +26,16 @@ from mindspore import mutable, mint from typing import List from vllm.logger import init_logger -from vllm_mindspore.utils import MsKVCache, get_valid_dtype +from vllm_mindspore.utils import MsKVCache, get_valid_dtype, create_kv_cache logger = init_logger(__name__) def create_block(shape, dtype, name=None, device=None): - blocks = mint.empty(shape, dtype=dtype, device=device) + if device == "CPU": + blocks = mint.empty(shape, dtype=dtype, device=device) + else: + blocks = create_kv_cache(shape, dtype) return blocks diff --git a/vllm_mindspore/worker/model_runner.py b/vllm_mindspore/worker/model_runner.py index 7fd89fc5..8f1aba25 100644 --- a/vllm_mindspore/worker/model_runner.py +++ b/vllm_mindspore/worker/model_runner.py @@ -28,7 +28,7 @@ from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams from vllm.sequence import SequenceGroupMetadata -from vllm_mindspore.utils import STR_DTYPE_TO_TENSOR_DTYPE +from vllm_mindspore.utils import STR_DTYPE_TO_TENSOR_DTYPE, create_kv_cache logger = init_logger(__name__) @@ -142,17 +142,10 @@ def _dummy_run(self, head_size = self.model_config.get_head_size() kv_shape = [0, block_size, num_kv_heads, head_size] kv_caches = mutable([ - mutable( - ( - mutable( - torch.tensor([], - dtype=kv_cache_dtype, - device=self.device).reshape(kv_shape)), - mutable( - torch.tensor([], - dtype=kv_cache_dtype, - device=self.device).reshape(kv_shape)), - )) for _ in range(num_layers) + mutable(( + mutable(create_kv_cache(kv_shape, kv_cache_dtype)), + mutable(create_kv_cache(kv_shape, kv_cache_dtype)), + )) for _ in range(num_layers) ]) finished_requests_ids = [seq.request_id for seq in seqs] model_input = self.prepare_model_input( -- Gitee From 5bb42eb0c229994ec555a318d70c3c3fa9c5bc72 Mon Sep 17 00:00:00 2001 From: huoxinyou Date: Fri, 7 Nov 2025 11:48:28 +0800 Subject: [PATCH 2/3] moe support 310p Signed-off-by: huoxinyou --- .../layers/fused_moe/fused_moe.py | 19 ++++++++++++++----- .../model_executor/layers/fused_moe/layer.py | 6 ++++++ .../layers/vocab_parallel_embedding.py | 2 +- .../model_executor/models/qwen3_moe.py | 11 +++++++++-- 4 files changed, 30 insertions(+), 8 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py index 695c481f..d7ab2412 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py @@ -31,7 +31,11 @@ from mindspore.ops.auto_generate import (GroupedMatmulV4, MoeDistributeCombine, from vllm.distributed.parallel_state import get_ep_group from vllm_mindspore.model_executor.layers.fused_moe.config import MoeMode -from vllm_mindspore.utils import is_910b +from vllm_mindspore.utils import is_910b, is_310p + + +def softmax_score_function(x): + return mint.softmax(x, dim=-1, dtype=ms.float32) def 
fused_topk( @@ -41,10 +45,14 @@ def fused_topk( renormalize: bool, indices_type=None, ) -> tuple[Tensor, Tensor]: - moe_topk_softmax = MoeGatingTopKSoftmax() - topk_weights, topk_ids, _ = moe_topk_softmax(gating_output, None, topk) + if is_310p(): + scores = softmax_score_function(gating_output) + topk_weights, topk_ids = mint.topk(scores, k=topk, dim=-1) + else: + moe_topk_softmax = MoeGatingTopKSoftmax() + topk_weights, topk_ids, _ = moe_topk_softmax(gating_output, None, topk) if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + topk_weights = mint.div(topk_weights, mint.add(mint.sum(topk_weights, -1, True), 1e-20)) if indices_type is not None: topk_ids = topk_ids.to(indices_type) @@ -171,7 +179,8 @@ class FusedExperts(nn.Cell): gate = self._gate_activation(gate, activation) hidden = mint.mul(hidden, gate) expert_output = self._group_matmul(hidden, w2, group_list) - expert_output = mint.nan_to_num(expert_output, 0, 0, 0) + if not is_310p(): + expert_output = mint.nan_to_num(expert_output, 0, 0, 0) return expert_output def run_tp(self, hidden_states, w1, w2, topk_ids, topk_weights, activation, diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index 4669352a..4acbe8af 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -42,6 +42,7 @@ from vllm_mindspore.model_executor.layers.quantization.base_config import ( QuantizeMethodBase) from vllm_mindspore.model_executor.model_loader.weight_utils import ( split_loaded_weight) +from vllm_mindspore.utils import is_310p, set_weight_format_to_nz logger = init_logger(__name__) @@ -109,6 +110,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, nn.Cell): set_weight_attrs(w2_weight, extra_weight_attrs) set_weight_attrs(w2_weight, {"is_transposed": True}) + def process_weights_after_loading(self, layer): + if is_310p(): + set_weight_format_to_nz(layer.w13_weight) + set_weight_format_to_nz(layer.w2_weight) + def apply( self, layer: nn.Cell, diff --git a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py index 8524d380..0c678ca1 100644 --- a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py @@ -78,7 +78,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): return mint.index_select(layer.weight, 0, input_) def process_weights_after_loading(self, layer): - if is_310p(): + if isinstance(layer, ParallelLMHead) and is_310p(): set_weight_format_to_nz(layer.weight) diff --git a/vllm_mindspore/model_executor/models/qwen3_moe.py b/vllm_mindspore/model_executor/models/qwen3_moe.py index 5c0260e4..746fe952 100644 --- a/vllm_mindspore/model_executor/models/qwen3_moe.py +++ b/vllm_mindspore/model_executor/models/qwen3_moe.py @@ -46,7 +46,7 @@ from vllm_mindspore.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm_mindspore.model_executor.layers.logits_processor import ( - LogitsProcessor) + LogitsProcessor, LogitsProcessor310P) from vllm_mindspore.model_executor.layers.rotary_embedding import get_rope from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) @@ -58,6 +58,7 @@ from vllm_mindspore.model_executor.models.model_base import NativeModel from 
vllm_mindspore.model_executor.models.utils import ( extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +from vllm_mindspore.utils import is_310p, cast_weight_for_310p logger = init_logger(__name__) @@ -449,6 +450,9 @@ class Qwen3MoeModel(nn.Cell): loaded_params: set[str] = set() for name, loaded_weight in weights: + if is_310p(): + loaded_weight = loaded_weight[:] + loaded_weight = cast_weight_for_310p(loaded_weight) for (param_name, weight_name, shard_id) in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: @@ -537,7 +541,10 @@ class Qwen3MoeForCausalLM(NativeModel, SupportesMoeDpTp, MixtureOfExperts): quant_config=quant_config) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight - self.logits_processor = LogitsProcessor(config.vocab_size) + if is_310p(): + self.logits_processor = LogitsProcessor310P(config.vocab_size) + else: + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) -- Gitee From da5b223871d885c8ba2205b3e55ac196a8558b70 Mon Sep 17 00:00:00 2001 From: huoxinyou Date: Mon, 10 Nov 2025 15:40:10 +0800 Subject: [PATCH 3/3] fix weight_loader --- .../model_executor/layers/fused_moe/layer.py | 6 +++--- vllm_mindspore/model_executor/layers/linear.py | 4 ++-- .../layers/vocab_parallel_embedding.py | 4 ++-- .../model_executor/model_loader/weight_utils.py | 13 +++++++++---- vllm_mindspore/model_executor/models/qwen2_5_vl.py | 3 ++- vllm_mindspore/model_executor/models/qwen3_moe.py | 5 +---- 6 files changed, 19 insertions(+), 16 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index 4acbe8af..05eaf438 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -41,7 +41,7 @@ from vllm_mindspore.model_executor.layers.fused_moe.fused_moe import ( from vllm_mindspore.model_executor.layers.quantization.base_config import ( QuantizeMethodBase) from vllm_mindspore.model_executor.model_loader.weight_utils import ( - split_loaded_weight) + split_loaded_weight, convert_loaded_weight) from vllm_mindspore.utils import is_310p, set_weight_format_to_nz logger = init_logger(__name__) @@ -514,7 +514,7 @@ class FusedMoE(nn.Cell): expert_id: int): is_param_transpose = param.is_transposed \ if hasattr(param, "is_transposed") else False - loaded_weight = loaded_weight[:] + loaded_weight = convert_loaded_weight(loaded_weight) if is_param_transpose: loaded_weight = from_numpy(loaded_weight.swapaxes(-1, -2)) else: @@ -534,7 +534,7 @@ class FusedMoE(nn.Cell): assert shard_id in ("w1", "w3") is_param_transpose = param.is_transposed \ if hasattr(param, "is_transposed") else False - loaded_weight = loaded_weight[:] + loaded_weight = convert_loaded_weight(loaded_weight) if is_param_transpose: loaded_weight = from_numpy(loaded_weight.swapaxes(-1, -2)) else: diff --git a/vllm_mindspore/model_executor/layers/linear.py b/vllm_mindspore/model_executor/layers/linear.py index 46ea2d7d..f53fc513 100644 --- a/vllm_mindspore/model_executor/layers/linear.py +++ b/vllm_mindspore/model_executor/layers/linear.py @@ -36,7 +36,7 @@ from vllm_mindspore.distributed.communication_op import ( from vllm_mindspore.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from 
vllm_mindspore.model_executor.model_loader.weight_utils import ( - split_loaded_weight) + split_loaded_weight, convert_loaded_weight) from vllm_mindspore.model_executor.utils import set_weight_attrs from vllm_mindspore.utils import is_310p, set_weight_format_to_nz @@ -219,7 +219,7 @@ class ReplicatedLinear(LinearBase): self.bias = None def weight_loader(self, param: Parameter, loaded_weight: Tensor): - loaded_weight = loaded_weight[:] + loaded_weight = convert_loaded_weight(loaded_weight) if len(loaded_weight.shape) == 0: loaded_weight = loaded_weight.reshape(1) diff --git a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py index 0c678ca1..ef78d5e4 100644 --- a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py @@ -36,7 +36,7 @@ from vllm_mindspore.distributed.communication_op import ( from vllm_mindspore.model_executor.layers.quantization.base_config import ( QuantizeMethodBase, method_has_implemented_embedding) from vllm_mindspore.model_executor.model_loader.weight_utils import ( - split_loaded_weight) + split_loaded_weight, convert_loaded_weight) from vllm_mindspore.model_executor.utils import set_weight_attrs from vllm_mindspore.utils import is_310p, set_weight_format_to_nz @@ -345,7 +345,7 @@ class VocabParallelEmbedding(nn.Cell): # If parameter does not have output dim, then it should # be copied onto all gpus (e.g. g_idx for act_order gptq). if output_dim is None: - loaded_weight = loaded_weight[:] + loaded_weight = convert_loaded_weight(loaded_weight) assert param.data.shape == loaded_weight.shape if param.data.shape != loaded_weight.shape: raise ValueError( diff --git a/vllm_mindspore/model_executor/model_loader/weight_utils.py b/vllm_mindspore/model_executor/model_loader/weight_utils.py index d409e80a..85f2e3b4 100644 --- a/vllm_mindspore/model_executor/model_loader/weight_utils.py +++ b/vllm_mindspore/model_executor/model_loader/weight_utils.py @@ -31,6 +31,13 @@ from vllm.model_executor.model_loader.weight_utils import (_BAR_FORMAT, from vllm_mindspore.utils import cast_weight_for_310p, is_310p +def convert_loaded_weight(loaded_weight): + """Get all loaded_weight value and dtype conversion on 310p""" + loaded_weight = loaded_weight[:] + if is_310p(): + loaded_weight = cast_weight_for_310p(loaded_weight) + return loaded_weight + def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): """ Read numpy slice data based on axis and slice range. 
@@ -40,9 +47,7 @@ def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): :shard_size: end slice index """ if shard_dim is None: - loaded_weight = loaded_weight[:] - if is_310p(): - loaded_weight = cast_weight_for_310p(loaded_weight) + loaded_weight = convert_loaded_weight(loaded_weight) return loaded_weight end_idx = start_idx + shard_size @@ -83,7 +88,7 @@ def safetensors_weights_iterator( def default_weight_loader(param: Parameter, loaded_weight: Any) -> None: """Default weight loader.""" - loaded_weight = loaded_weight[:] + loaded_weight = convert_loaded_weight(loaded_weight) if is_310p(): loaded_weight = cast_weight_for_310p(loaded_weight) param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype)) diff --git a/vllm_mindspore/model_executor/models/qwen2_5_vl.py b/vllm_mindspore/model_executor/models/qwen2_5_vl.py index 12589a9d..540bc41a 100644 --- a/vllm_mindspore/model_executor/models/qwen2_5_vl.py +++ b/vllm_mindspore/model_executor/models/qwen2_5_vl.py @@ -90,6 +90,7 @@ from vllm_mindspore.model_executor.models.model_base import NativeModel, \ AttentionWrapper from vllm_mindspore.model_executor.models.qwen2 import Qwen2Model from vllm_mindspore.model_executor.models.utils import PPMissingLayer +from vllm_mindspore.model_executor.model_loader.weight_utils import convert_loaded_weight from .interfaces import (SupportsMultiModal) from .utils import (WeightsMapper, maybe_prefix, merge_multimodal_embeddings) @@ -1059,7 +1060,7 @@ class Qwen2_5_VisionTransformer(nn.Cell): for name, loaded_weight in weights: param = params_dict[name] if name == "visual.patch_embed.proj.weight": - loaded_weight = loaded_weight[:] + loaded_weight = convert_loaded_weight(loaded_weight) loaded_weight = loaded_weight.reshape(loaded_weight.shape[0], -1) param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype)) diff --git a/vllm_mindspore/model_executor/models/qwen3_moe.py b/vllm_mindspore/model_executor/models/qwen3_moe.py index 746fe952..c7138633 100644 --- a/vllm_mindspore/model_executor/models/qwen3_moe.py +++ b/vllm_mindspore/model_executor/models/qwen3_moe.py @@ -58,7 +58,7 @@ from vllm_mindspore.model_executor.models.model_base import NativeModel from vllm_mindspore.model_executor.models.utils import ( extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -from vllm_mindspore.utils import is_310p, cast_weight_for_310p +from vllm_mindspore.utils import is_310p logger = init_logger(__name__) @@ -450,9 +450,6 @@ class Qwen3MoeModel(nn.Cell): loaded_params: set[str] = set() for name, loaded_weight in weights: - if is_310p(): - loaded_weight = loaded_weight[:] - loaded_weight = cast_weight_for_310p(loaded_weight) for (param_name, weight_name, shard_id) in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: -- Gitee
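
As a quick orientation for reviewers (not part of the patches above), the sketch below shows how the 310p weight-handling helpers introduced in patch 1 (is_310p, cast_weight_for_310p and set_weight_format_to_nz from vllm_mindspore.utils) are meant to compose during weight loading: the weight dtype is normalized to float16 at load time, and the resident device weight is converted to NZ format after loading. The load_linear_weight/finalize_linear_weight wrappers and the generic layer argument are illustrative names, not code from the patches.

    import numpy as np
    from mindspore import Parameter, Tensor

    from vllm_mindspore.utils import (cast_weight_for_310p, is_310p,
                                      set_weight_format_to_nz)


    def load_linear_weight(param: Parameter, loaded_weight: np.ndarray) -> None:
        # Load-time step: on 310p, float32/bfloat16 weights are cast to
        # float16 before being copied into the MindSpore parameter
        # (same flow as default_weight_loader in patch 1).
        if is_310p():
            loaded_weight = cast_weight_for_310p(loaded_weight)
        param.set_data(Tensor(loaded_weight, dtype=param.dtype))


    def finalize_linear_weight(layer) -> None:
        # Post-load step: on 310p, the loaded weight is converted to NZ
        # format, mirroring UnquantizedLinearMethod.process_weights_after_loading.
        if is_310p():
            set_weight_format_to_nz(layer.weight)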