From 40ff789befe250fdf6ce8f067a9ba7a9a86549b0 Mon Sep 17 00:00:00 2001 From: superxf Date: Fri, 10 Oct 2025 10:02:14 +0800 Subject: [PATCH 1/3] support 310p v0 --- .../python/cases_parallel/vllm_mf_qwen3_8b.py | 10 -- tests/st/python/test_cases_parallel.py | 20 ++- vllm_mindspore/__init__.py | 15 +++ vllm_mindspore/config.py | 5 - .../distributed/communication_op.py | 14 ++ vllm_mindspore/distributed/parallel_state.py | 113 ++++++++++++++++ .../model_executor/layers/linear.py | 9 ++ .../model_executor/layers/logits_processor.py | 124 ++++++++++++++---- .../quantization/smooth_quant_modelslim.py | 6 +- .../layers/vocab_parallel_embedding.py | 9 ++ .../model_loader/weight_utils.py | 16 +++ .../model_executor/models/model_base.py | 6 +- vllm_mindspore/platforms/ascend.py | 8 ++ vllm_mindspore/worker/cache_engine.py | 7 +- vllm_mindspore/worker/model_runner.py | 17 +-- 15 files changed, 316 insertions(+), 63 deletions(-) create mode 100644 vllm_mindspore/distributed/parallel_state.py diff --git a/tests/st/python/cases_parallel/vllm_mf_qwen3_8b.py b/tests/st/python/cases_parallel/vllm_mf_qwen3_8b.py index 3a66efa3..7f3fc37b 100644 --- a/tests/st/python/cases_parallel/vllm_mf_qwen3_8b.py +++ b/tests/st/python/cases_parallel/vllm_mf_qwen3_8b.py @@ -84,14 +84,4 @@ def test_mf_qwen3_v1(): env_vars["VLLM_USE_V1"] = "1" env_manager.setup_ai_environment(env_vars) run_mf_qwen3_networt() - env_manager.unset_all() - - -def test_mf_qwen3_v1_310p(): - """Test qwen3 8B using V1 LLMEngine in 310p.""" - env_vars["VLLM_USE_V1"] = "1" - # In 310p, INTERNAL BOOST will case unsupported kernel fusion - env_vars["MS_ENABLE_INTERNAL_BOOST"] = "off" - env_manager.setup_ai_environment(env_vars) - run_mf_qwen3_networt() env_manager.unset_all() \ No newline at end of file diff --git a/tests/st/python/test_cases_parallel.py b/tests/st/python/test_cases_parallel.py index 9a5a665f..bc24db0d 100644 --- a/tests/st/python/test_cases_parallel.py +++ b/tests/st/python/test_cases_parallel.py @@ -265,14 
+265,30 @@ def test_cases_parallel_level4_mcore2(): @pytest.mark.level0 @pytest.mark.platform_ascend310p @pytest.mark.env_single -def test_cases_parallel_310p_part0(): +def test_cases_parallel_310p_v0_part0(): """ Feature: test cases parallel in 310p. Description: test cases parallel. Expectation: Pass. """ cases = [ - (2, "cases_parallel/vllm_mf_qwen3_8b.py::test_mf_qwen3_v1_310p", + (2, "cases_parallel/vllm_mf_qwen3_8b.py::test_mf_qwen3_v0", + "vllm_mf_qwen3_8b_v0_310p_test_mf_qwen3.log"), + ] + run_tasks(cases) + + +@pytest.mark.level0 +@pytest.mark.platform_ascend310p +@pytest.mark.env_single +def test_cases_parallel_310p_v1_part0(): + """ + Feature: test cases parallel in 310p. + Description: test cases parallel. + Expectation: Pass. + """ + cases = [ + (2, "cases_parallel/vllm_mf_qwen3_8b.py::test_mf_qwen3_v1", "vllm_mf_qwen3_8b_v1_310p_test_mf_qwen3.log"), ] run_tasks(cases) diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index a2d0c6e0..83a25659 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -342,6 +342,21 @@ RejectionSampler._smallest_positive_value = _smallest_positive_value RejectionSampler._smallest_positive_value.__set_name__( RejectionSampler, "_smallest_positive_value") +import vllm.distributed.communication_op +import vllm.worker.worker_base +from vllm_mindspore.distributed.communication_op import ( + cpu_broadcast_tensor_dict) + +vllm.distributed.communication_op.broadcast_tensor_dict = ( + cpu_broadcast_tensor_dict) +vllm.worker.worker_base.broadcast_tensor_dict = cpu_broadcast_tensor_dict + +import vllm.distributed.parallel_state +from vllm_mindspore.distributed.parallel_state import gc_broadcast_tensor_dict + +vllm.distributed.parallel_state.GroupCoordinator.broadcast_tensor_dict = ( + gc_broadcast_tensor_dict) + ######### for multi-model from vllm_mindspore.inputs.registry import call_hf_processor from vllm.inputs.registry import InputProcessingContext diff --git a/vllm_mindspore/config.py 
b/vllm_mindspore/config.py index 7f93164b..88445fa0 100644 --- a/vllm_mindspore/config.py +++ b/vllm_mindspore/config.py @@ -36,8 +36,6 @@ from vllm.config import (_STR_DTYPE_TO_TORCH_DTYPE, CacheConfig, from vllm.logger import init_logger from vllm.utils import random_uuid -from vllm_mindspore.utils import is_310p - logger = init_logger(__name__) @@ -247,9 +245,6 @@ def _get_and_verify_dtype( else: raise ValueError(f"Unknown dtype: {dtype}") - if torch_dtype == torch.bfloat16 and is_310p(): - torch_dtype = torch.float16 - if torch_dtype != config_dtype: if torch_dtype == torch.float32: # Upcasting to float32 is allowed. diff --git a/vllm_mindspore/distributed/communication_op.py b/vllm_mindspore/distributed/communication_op.py index c933dc4a..35a1c1d8 100644 --- a/vllm_mindspore/distributed/communication_op.py +++ b/vllm_mindspore/distributed/communication_op.py @@ -21,11 +21,25 @@ Implement a unified communication interface for both graph and pynative mode. """ +from typing import Any, Optional, Union + +import torch from mindspore import nn, ops from vllm.distributed.parallel_state import ( get_tensor_model_parallel_world_size, get_tp_group) +def cpu_broadcast_tensor_dict(tensor_dict: Optional[dict[Any, + Union[torch.Tensor, + Any]]] = None, + src: int = 0): + if not torch.distributed.is_initialized(): + return tensor_dict + return get_tp_group().broadcast_tensor_dict(tensor_dict, + src, + group=get_tp_group().cpu_group) + + class ReduceFromModelParallelRegion(nn.Cell): "All reduce the input from the model parallel region." 
diff --git a/vllm_mindspore/distributed/parallel_state.py b/vllm_mindspore/distributed/parallel_state.py new file mode 100644 index 00000000..b13777ad --- /dev/null +++ b/vllm_mindspore/distributed/parallel_state.py @@ -0,0 +1,113 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Communication functions are adapted from +# https://github.com/vllm-project/vllm/blob/v0.8.3/vllm/distributed/communication_op.py +# +# Copyright 2025 Huawei Technologies Co., Ltd. +# Copyright 2024-2025 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Adaption for cpu broadcast.""" +from typing import Any, Optional, Union + +import torch +import torch.distributed +from torch.distributed import ProcessGroup +from vllm.distributed.parallel_state import TensorMetadata, _split_tensor_dict + +from vllm_mindspore.utils import is_310p + + +def gc_broadcast_tensor_dict( + self, + tensor_dict: Optional[dict[str, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None +) -> Optional[dict[str, Union[torch.Tensor, Any]]]: + """Broadcast the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. 
+ if (not torch.distributed.is_initialized() or self.world_size == 1): + return tensor_dict + + if not is_310p(): + group = self.device_group + metadata_group = self.cpu_group + assert src < self.world_size, f"Invalid src rank ({src})" + + rank_in_group = self.rank_in_group + if rank_in_group == src: + metadata_list: list[tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, + dict), (f"Expecting a dictionary, got {type(tensor_dict)}") + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.broadcast_object(metadata_list, src=src) + async_handles = [] + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=group, + async_op=True) + async_handles.append(handle) + for async_handle in async_handles: + async_handle.wait() + + else: + metadata_list = self.broadcast_object(None, src=src) + tensor_dict = {} + async_handles = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, + dtype=value.dtype, + device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. 
+ tensor_dict[key] = tensor + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=group, + async_op=True) + async_handles.append(handle) + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + for async_handle in async_handles: + async_handle.wait() + return tensor_dict diff --git a/vllm_mindspore/model_executor/layers/linear.py b/vllm_mindspore/model_executor/layers/linear.py index 76518fea..4a0c72f3 100644 --- a/vllm_mindspore/model_executor/layers/linear.py +++ b/vllm_mindspore/model_executor/layers/linear.py @@ -39,6 +39,7 @@ from vllm_mindspore.model_executor.layers.quantization.base_config import ( from vllm_mindspore.model_executor.model_loader.weight_utils import ( split_loaded_weight) from vllm_mindspore.model_executor.utils import set_weight_attrs +from vllm_mindspore.utils import FORMAT_TYPE, is_310p WEIGHT_LOADER_V2_SUPPORTED = [ "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod", @@ -82,6 +83,10 @@ class LinearMethodBase(QuantizeMethodBase): Expects create_weights to have been called before on the layer.""" raise NotImplementedError + def format_to_nz(self, param): + cast_weight = ops.auto_generate.format_cast(param, FORMAT_TYPE['nz']) + param.set_data(cast_weight) + class UnquantizedLinearMethod(LinearMethodBase): """Linear method without quantization.""" @@ -118,6 +123,10 @@ class UnquantizedLinearMethod(LinearMethodBase): x = x.view(output_shape) return x + def process_weights_after_loading(self, layer): + if is_310p(): + self.format_to_nz(layer.weight) + class LinearBase(nn.Cell): """Base linear layer. 
diff --git a/vllm_mindspore/model_executor/layers/logits_processor.py b/vllm_mindspore/model_executor/layers/logits_processor.py index ee8c8edc..55a2ce6b 100644 --- a/vllm_mindspore/model_executor/layers/logits_processor.py +++ b/vllm_mindspore/model_executor/layers/logits_processor.py @@ -23,12 +23,13 @@ from concurrent.futures import ThreadPoolExecutor from typing import Optional import vllm.envs as envs -from mindspore import Tensor, mint, nn -from vllm.config import current_platform -from vllm.distributed import (tensor_model_parallel_all_gather, - tensor_model_parallel_gather) +from mindspore import Tensor, jit, mint, nn +from vllm.config import get_current_vllm_config +from vllm.distributed import tensor_model_parallel_gather from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm_mindspore.distributed.communication_op import ( + AllGatherFromModelParallelRegion) from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -60,6 +61,9 @@ class LogitsProcessor(nn.Cell): scale: A scaling factor to apply to the logits. """ super().__init__() + vllm_config = get_current_vllm_config() + self.vllm_config = vllm_config + self.is_graph_mode = bool(not vllm_config.model_config.enforce_eager) self.scale = scale self.vocab_size = vocab_size # Whether the input is logits (default is hidden states). @@ -69,27 +73,104 @@ class LogitsProcessor(nn.Cell): # Soft cap the logits. Used in Gemma 2. self.soft_cap = soft_cap # Whether to use gather or all-gather to gather the logits. 
- self.use_all_gather = current_platform.use_all_gather() + self.use_all_gather = True - def construct( + if self.use_all_gather: + self.tensor_model_parallel_all_gather = \ + AllGatherFromModelParallelRegion() + self.lm_head = None + self.run_model = None + self.cached_input_info = {} + + def set_dynamic_inputs(self): + dyn_hidden_states = Tensor(shape=[None, None], + dtype=self.vllm_config.model_config.dtype) + + if self.cached_input_info["indices"] is None: + dyn_indices = None + else: + dyn_indices_shape = [ + None for _ in range(self.cached_input_info["indices"]["ndim"]) + ] + dyn_indices_dtype = self.cached_input_info["indices"]["dtype"] + dyn_indices = Tensor(shape=dyn_indices_shape, + dtype=dyn_indices_dtype) + + if self.cached_input_info["bias"] is None: + dyn_bias = None + else: + dyn_bias_shape = [ + None for _ in range(self.cached_input_info["bias"]["ndim"]) + ] + dyn_bias_dtype = self.cached_input_info["bias"]["dtype"] + dyn_bias = Tensor(shape=dyn_bias_shape, dtype=dyn_bias_dtype) + + self.set_inputs(dyn_hidden_states, dyn_indices, dyn_bias) + + def __call__( self, lm_head: VocabParallelEmbedding, hidden_states: Tensor, sampling_metadata: Optional[SamplingMetadata] = None, embedding_bias: Optional[Tensor] = None, + ) -> Optional[Tensor]: + if self.lm_head is None: + self.lm_head = lm_head + if self.run_model is None: + self.run_model = jit( + function=self.construct, + jit_level='O0') if self.is_graph_mode else self.construct + selected_token_indices = None + if sampling_metadata is not None: + selected_token_indices = sampling_metadata.selected_token_indices + dyn_indices_info = None if selected_token_indices is None else { + "ndim": selected_token_indices.ndim, + "dtype": selected_token_indices.dtype, + } + dyn_bias_info = None if embedding_bias is None else { + "ndim": embedding_bias.ndim, + "dtype": embedding_bias.dtype, + } + if self.cached_input_info != { + "indices": dyn_indices_info, + "bias": dyn_bias_info + }: + self.cached_input_info = { + 
"indices": dyn_indices_info, + "bias": dyn_bias_info, + } + self.set_dynamic_inputs() + + if selected_token_indices is not None and selected_token_indices.numel( + ) <= 0: + logits = mint.zeros((0, self.vocab_size), + dtype=hidden_states.dtype) + else: + logits = self.run_model(hidden_states, selected_token_indices, + embedding_bias) + + if sampling_metadata is not None and \ + sampling_metadata.seq_groups is not None: + logits = _apply_logits_processors(logits, sampling_metadata) + + return logits + + def construct( + self, + hidden_states: Tensor, + selected_token_indices: Optional[Tensor] = None, + embedding_bias: Optional[Tensor] = None, ) -> Optional[Tensor]: if self.logits_as_input: logits = hidden_states else: - if sampling_metadata is not None: - if sampling_metadata.selected_token_indices.numel() <= 0: - return mint.zeros((0, self.vocab_size), - dtype=hidden_states.dtype) - hidden_states = _prune_hidden_states(hidden_states, - sampling_metadata) + if selected_token_indices is not None: + hidden_states = mint.index_select(hidden_states, 0, + selected_token_indices) # Get the logits for the next tokens. - logits = self._get_logits(hidden_states, lm_head, embedding_bias) + logits = self._get_logits(hidden_states, self.lm_head, + embedding_bias) if logits is not None: if self.soft_cap is not None: logits = logits / self.soft_cap @@ -100,10 +181,6 @@ class LogitsProcessor(nn.Cell): logits *= self.scale # Apply logits processors (if any). - if sampling_metadata is not None and \ - sampling_metadata.seq_groups is not None: - logits = _apply_logits_processors(logits, sampling_metadata) - return logits def _get_logits( @@ -118,7 +195,7 @@ class LogitsProcessor(nn.Cell): bias=embedding_bias) if self.use_all_gather: # Gather is not supported for some devices such as NPUs. 
- logits = tensor_model_parallel_all_gather(logits) + logits = self.tensor_model_parallel_all_gather(logits) else: # None may be returned for rank > 0 logits = tensor_model_parallel_gather(logits) @@ -134,17 +211,6 @@ class LogitsProcessor(nn.Cell): return s -def _prune_hidden_states( - hidden_states: Tensor, - sampling_metadata: SamplingMetadata, -) -> Tensor: - indices = sampling_metadata.selected_token_indices - if indices is not None and indices.numel() > 0: - return mint.index_select(hidden_states, 0, - sampling_metadata.selected_token_indices) - return hidden_states - - def _apply_logits_processors( logits: Tensor, sampling_metadata: SamplingMetadata, diff --git a/vllm_mindspore/model_executor/layers/quantization/smooth_quant_modelslim.py b/vllm_mindspore/model_executor/layers/quantization/smooth_quant_modelslim.py index 841e1797..7cb17c5f 100644 --- a/vllm_mindspore/model_executor/layers/quantization/smooth_quant_modelslim.py +++ b/vllm_mindspore/model_executor/layers/quantization/smooth_quant_modelslim.py @@ -14,11 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import re from typing import Any, Optional import mindspore import numpy as np +import regex as re from mindspore import Parameter, Tensor, mint from mindspore.common.initializer import initializer from mindspore.ops.auto_generate import (DynamicQuantExt, GroupedMatmul, @@ -324,6 +324,8 @@ class A8W8LinearMethod(LinearMethodBase): layer.insert_param_to_cell("input_offset", input_offset) def process_weights_after_loading(self, layer: mindspore.nn.Cell) -> None: + if is_310p(): + self.format_to_nz(layer.weight) input_offset = np.array([0]) params_dtype = layer.params_dtype layer.input_offset = Parameter(Tensor(input_offset, @@ -498,6 +500,8 @@ class A8W8DYNLinearMethod(LinearMethodBase): layer.insert_param_to_cell("smooth_scale", smooth_scale) def process_weights_after_loading(self, layer: mindspore.nn.Cell) -> None: + if is_310p(): + self.format_to_nz(layer.weight) if self.is_2d_smooth_scale: smooth_scale = layer.smooth_scale.asnumpy().reshape(1, -1) layer.smooth_scale = Parameter(Tensor(smooth_scale, diff --git a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py index abd9cb24..8547586a 100644 --- a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py @@ -38,6 +38,7 @@ from vllm_mindspore.model_executor.layers.quantization.base_config import ( from vllm_mindspore.model_executor.model_loader.weight_utils import ( split_loaded_weight) from vllm_mindspore.model_executor.utils import set_weight_attrs +from vllm_mindspore.utils import FORMAT_TYPE, is_310p DEFAULT_VOCAB_PADDING_SIZE = 64 @@ -76,6 +77,14 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): def embedding(self, layer: nn.Cell, input_: Tensor) -> Tensor: return mint.index_select(layer.weight, 0, input_) + def format_to_nz(self, param): + cast_weight = ops.auto_generate.format_cast(param, FORMAT_TYPE['nz']) + param.set_data(cast_weight) + + 
def process_weights_after_loading(self, layer): + if is_310p(): + self.format_to_nz(layer.weight) + def get_masked_input_and_mask( input_: Tensor, diff --git a/vllm_mindspore/model_executor/model_loader/weight_utils.py b/vllm_mindspore/model_executor/model_loader/weight_utils.py index 6bf2dd4c..7a4ff7fd 100644 --- a/vllm_mindspore/model_executor/model_loader/weight_utils.py +++ b/vllm_mindspore/model_executor/model_loader/weight_utils.py @@ -22,12 +22,15 @@ from collections.abc import Generator from typing import Any import mindspore as ms +import numpy as np from mindspore import Parameter from safetensors import safe_open from tqdm.auto import tqdm from vllm.model_executor.model_loader.weight_utils import (_BAR_FORMAT, enable_tqdm) +from vllm_mindspore.utils import is_310p + def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): """ @@ -39,6 +42,10 @@ def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): """ if shard_dim is None: loaded_weight = loaded_weight[:] + loaded_weight = (loaded_weight.astype(np.float16) if + ((str(loaded_weight.dtype) == "float32" + or str(loaded_weight.dtype) == "bfloat16") + and is_310p()) else loaded_weight) return loaded_weight end_idx = start_idx + shard_size @@ -50,6 +57,11 @@ def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): loaded_weight = loaded_weight[:, :, start_idx:end_idx] else: raise ValueError("shard_dim:{} is not supported.".format(shard_dim)) + loaded_weight = (loaded_weight.astype(np.float16) if + ((str(loaded_weight.dtype) == "float32" + or str(loaded_weight.dtype) == "bfloat16") + and is_310p()) else loaded_weight) + return loaded_weight @@ -77,4 +89,8 @@ def safetensors_weights_iterator( def default_weight_loader(param: Parameter, loaded_weight: Any) -> None: """Default weight loader.""" loaded_weight = loaded_weight[:] + loaded_weight = (loaded_weight.astype(np.float16) if + ((str(loaded_weight.dtype) == "float32" + or str(loaded_weight.dtype) == 
"bfloat16") + and is_310p()) else loaded_weight) param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype)) diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index 524e31ce..602a8314 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -34,7 +34,8 @@ from vllm_mindspore.model_executor.models.attention_mask import ( LowerTriangularMask) from vllm_mindspore.model_executor.models.utils import is_use_ringmla from vllm_mindspore.model_executor.utils import set_model_context -from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE, create_kv_cache +from vllm_mindspore.utils import (STR_DTYPE_TO_MS_DTYPE, create_kv_cache, + is_310p) from vllm_mindspore.v1.attention.backends.ms_attn import MsAttentionMetadata @@ -446,7 +447,8 @@ class NativeModel(MsModelBase): block_size = self.cache_config.block_size num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) head_size = self.model_config.get_head_size() - kv_cache_shape = (None, block_size, num_kv_heads, head_size) + kv_cache_shape = (None, block_size, num_kv_heads * head_size) \ + if is_310p() else (None, block_size, num_kv_heads, head_size) kv_cache_dtype = (self.model_config.dtype if self.cache_config.cache_dtype == "auto" else diff --git a/vllm_mindspore/platforms/ascend.py b/vllm_mindspore/platforms/ascend.py index 8e39c3d1..3f86a551 100644 --- a/vllm_mindspore/platforms/ascend.py +++ b/vllm_mindspore/platforms/ascend.py @@ -26,6 +26,8 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.platforms.interface import Platform, PlatformEnum, _Backend +from vllm_mindspore.utils import is_310p + if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig else: @@ -180,3 +182,9 @@ class AscendPlatform(Platform): if ASCEND_QUANTIZATION_METHOD not in quant_action.choices: quant_action.choices.extend(ASCEND_QUANTIZATION_METHOD) 
logger.debug("--quantization support ascend/golden-stick.") + + @property + def supported_dtypes(self) -> list[torch.dtype]: + if is_310p(): + return [torch.float16, torch.float32] + return [torch.bfloat16, torch.float16, torch.float32] \ No newline at end of file diff --git a/vllm_mindspore/worker/cache_engine.py b/vllm_mindspore/worker/cache_engine.py index b57b8833..2ff204c6 100644 --- a/vllm_mindspore/worker/cache_engine.py +++ b/vllm_mindspore/worker/cache_engine.py @@ -26,13 +26,16 @@ from mindspore import mutable, mint from typing import List from vllm.logger import init_logger -from vllm_mindspore.utils import MsKVCache, get_valid_dtype +from vllm_mindspore.utils import MsKVCache, get_valid_dtype, create_kv_cache logger = init_logger(__name__) def create_block(shape, dtype, name=None, device=None): - blocks = mint.empty(shape, dtype=dtype, device=device) + if device == "CPU": + blocks = mint.empty(shape, dtype=dtype, device=device) + else: + blocks = create_kv_cache(shape, dtype) return blocks diff --git a/vllm_mindspore/worker/model_runner.py b/vllm_mindspore/worker/model_runner.py index 7fd89fc5..8f1aba25 100644 --- a/vllm_mindspore/worker/model_runner.py +++ b/vllm_mindspore/worker/model_runner.py @@ -28,7 +28,7 @@ from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams from vllm.sequence import SequenceGroupMetadata -from vllm_mindspore.utils import STR_DTYPE_TO_TENSOR_DTYPE +from vllm_mindspore.utils import STR_DTYPE_TO_TENSOR_DTYPE, create_kv_cache logger = init_logger(__name__) @@ -142,17 +142,10 @@ def _dummy_run(self, head_size = self.model_config.get_head_size() kv_shape = [0, block_size, num_kv_heads, head_size] kv_caches = mutable([ - mutable( - ( - mutable( - torch.tensor([], - dtype=kv_cache_dtype, - device=self.device).reshape(kv_shape)), - mutable( - torch.tensor([], - dtype=kv_cache_dtype, - device=self.device).reshape(kv_shape)), - )) for _ in range(num_layers) + mutable(( + 
mutable(create_kv_cache(kv_shape, kv_cache_dtype)), + mutable(create_kv_cache(kv_shape, kv_cache_dtype)), + )) for _ in range(num_layers) ]) finished_requests_ids = [seq.request_id for seq in seqs] model_input = self.prepare_model_input( -- Gitee From 004ae22bec6d0e36281815b382dd4129ced0dc3c Mon Sep 17 00:00:00 2001 From: superxf Date: Sat, 11 Oct 2025 18:02:42 +0800 Subject: [PATCH 2/3] add --- tests/st/python/test_cases_parallel.py | 14 +- vllm_mindspore/__init__.py | 14 +- .../distributed/communication_op.py | 14 - vllm_mindspore/distributed/parallel_state.py | 13 +- .../model_executor/layers/linear.py | 8 +- .../model_executor/layers/logits_processor.py | 243 ++++++++++++------ .../quantization/smooth_quant_modelslim.py | 6 +- .../layers/vocab_parallel_embedding.py | 8 +- .../model_executor/model_loader/utils.py | 31 +++ .../model_loader/weight_utils.py | 21 +- .../model_executor/models/model_base.py | 3 + vllm_mindspore/model_executor/models/qwen2.py | 8 +- vllm_mindspore/model_executor/models/qwen3.py | 9 +- vllm_mindspore/platforms/ascend.py | 1 + vllm_mindspore/utils.py | 18 ++ 15 files changed, 264 insertions(+), 147 deletions(-) diff --git a/tests/st/python/test_cases_parallel.py b/tests/st/python/test_cases_parallel.py index bc24db0d..5c8ceb69 100644 --- a/tests/st/python/test_cases_parallel.py +++ b/tests/st/python/test_cases_parallel.py @@ -265,15 +265,15 @@ def test_cases_parallel_level4_mcore2(): @pytest.mark.level0 @pytest.mark.platform_ascend310p @pytest.mark.env_single -def test_cases_parallel_310p_v0_part0(): +def test_cases_parallel_310p_v1_part0(): """ Feature: test cases parallel in 310p. Description: test cases parallel. Expectation: Pass. 
""" cases = [ - (2, "cases_parallel/vllm_mf_qwen3_8b.py::test_mf_qwen3_v0", - "vllm_mf_qwen3_8b_v0_310p_test_mf_qwen3.log"), + (2, "cases_parallel/vllm_mf_qwen3_8b.py::test_mf_qwen3_v1", + "vllm_mf_qwen3_8b_v1_310p_test_mf_qwen3.log"), ] run_tasks(cases) @@ -281,14 +281,14 @@ def test_cases_parallel_310p_v0_part0(): @pytest.mark.level0 @pytest.mark.platform_ascend310p @pytest.mark.env_single -def test_cases_parallel_310p_v1_part0(): +def test_cases_parallel_310p_v0_part0(): """ Feature: test cases parallel in 310p. Description: test cases parallel. Expectation: Pass. """ cases = [ - (2, "cases_parallel/vllm_mf_qwen3_8b.py::test_mf_qwen3_v1", - "vllm_mf_qwen3_8b_v1_310p_test_mf_qwen3.log"), + (2, "cases_parallel/vllm_mf_qwen3_8b.py::test_mf_qwen3_v0", + "vllm_mf_qwen3_8b_v0_310p_test_mf_qwen3.log"), ] - run_tasks(cases) + run_tasks(cases) \ No newline at end of file diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index 83a25659..4537193b 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -341,15 +341,13 @@ from vllm.model_executor.layers.rejection_sampler import RejectionSampler RejectionSampler._smallest_positive_value = _smallest_positive_value RejectionSampler._smallest_positive_value.__set_name__( RejectionSampler, "_smallest_positive_value") +from vllm_mindspore.model_executor.model_loader.utils import ( + process_weights_after_loading) -import vllm.distributed.communication_op -import vllm.worker.worker_base -from vllm_mindspore.distributed.communication_op import ( - cpu_broadcast_tensor_dict) - -vllm.distributed.communication_op.broadcast_tensor_dict = ( - cpu_broadcast_tensor_dict) -vllm.worker.worker_base.broadcast_tensor_dict = cpu_broadcast_tensor_dict +vllm.model_executor.model_loader.utils.process_weights_after_loading = ( + process_weights_after_loading) +vllm.model_executor.model_loader.base_loader.process_weights_after_loading = ( + process_weights_after_loading) import vllm.distributed.parallel_state 
from vllm_mindspore.distributed.parallel_state import gc_broadcast_tensor_dict diff --git a/vllm_mindspore/distributed/communication_op.py b/vllm_mindspore/distributed/communication_op.py index 35a1c1d8..c933dc4a 100644 --- a/vllm_mindspore/distributed/communication_op.py +++ b/vllm_mindspore/distributed/communication_op.py @@ -21,25 +21,11 @@ Implement a unified communication interface for both graph and pynative mode. """ -from typing import Any, Optional, Union - -import torch from mindspore import nn, ops from vllm.distributed.parallel_state import ( get_tensor_model_parallel_world_size, get_tp_group) -def cpu_broadcast_tensor_dict(tensor_dict: Optional[dict[Any, - Union[torch.Tensor, - Any]]] = None, - src: int = 0): - if not torch.distributed.is_initialized(): - return tensor_dict - return get_tp_group().broadcast_tensor_dict(tensor_dict, - src, - group=get_tp_group().cpu_group) - - class ReduceFromModelParallelRegion(nn.Cell): "All reduce the input from the model parallel region." diff --git a/vllm_mindspore/distributed/parallel_state.py b/vllm_mindspore/distributed/parallel_state.py index b13777ad..371511e2 100644 --- a/vllm_mindspore/distributed/parallel_state.py +++ b/vllm_mindspore/distributed/parallel_state.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # Communication functions are adapted from -# https://github.com/vllm-project/vllm/blob/v0.8.3/vllm/distributed/communication_op.py +# https://github.com/vllm-project/vllm/blob/v0.9.1/vllm/distributed/parallel_state.py # # Copyright 2025 Huawei Technologies Co., Ltd. # Copyright 2024-2025 The vLLM team. 
@@ -23,7 +23,8 @@ from typing import Any, Optional, Union import torch import torch.distributed from torch.distributed import ProcessGroup -from vllm.distributed.parallel_state import TensorMetadata, _split_tensor_dict +from vllm.distributed.parallel_state import (TensorMetadata, + _split_tensor_dict, get_tp_group) from vllm_mindspore.utils import is_310p @@ -36,14 +37,14 @@ def gc_broadcast_tensor_dict( metadata_group: Optional[ProcessGroup] = None ) -> Optional[dict[str, Union[torch.Tensor, Any]]]: """Broadcast the input tensor dictionary. - NOTE: `src` is the local rank of the source rank. - """ + NOTE: `src` is the local rank of the source rank. + """ # Bypass the function if we are using only 1 GPU. if (not torch.distributed.is_initialized() or self.world_size == 1): return tensor_dict - if not is_310p(): - group = self.device_group + # Due to 310p limitations, the broadcast is in cpu group + group = get_tp_group().cpu_group if is_310p() else self.device_group metadata_group = self.cpu_group assert src < self.world_size, f"Invalid src rank ({src})" diff --git a/vllm_mindspore/model_executor/layers/linear.py b/vllm_mindspore/model_executor/layers/linear.py index 4a0c72f3..983931ec 100644 --- a/vllm_mindspore/model_executor/layers/linear.py +++ b/vllm_mindspore/model_executor/layers/linear.py @@ -39,7 +39,7 @@ from vllm_mindspore.model_executor.layers.quantization.base_config import ( from vllm_mindspore.model_executor.model_loader.weight_utils import ( split_loaded_weight) from vllm_mindspore.model_executor.utils import set_weight_attrs -from vllm_mindspore.utils import FORMAT_TYPE, is_310p +from vllm_mindspore.utils import is_310p, set_weight_format_to_nz WEIGHT_LOADER_V2_SUPPORTED = [ "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod", @@ -83,10 +83,6 @@ class LinearMethodBase(QuantizeMethodBase): Expects create_weights to have been called before on the layer.""" raise NotImplementedError - def format_to_nz(self, param): - cast_weight = 
ops.auto_generate.format_cast(param, FORMAT_TYPE['nz']) - param.set_data(cast_weight) - class UnquantizedLinearMethod(LinearMethodBase): """Linear method without quantization.""" @@ -125,7 +121,7 @@ class UnquantizedLinearMethod(LinearMethodBase): def process_weights_after_loading(self, layer): if is_310p(): - self.format_to_nz(layer.weight) + set_weight_format_to_nz(layer.weight) class LinearBase(nn.Cell): diff --git a/vllm_mindspore/model_executor/layers/logits_processor.py b/vllm_mindspore/model_executor/layers/logits_processor.py index 55a2ce6b..ba300d86 100644 --- a/vllm_mindspore/model_executor/layers/logits_processor.py +++ b/vllm_mindspore/model_executor/layers/logits_processor.py @@ -24,8 +24,9 @@ from typing import Optional import vllm.envs as envs from mindspore import Tensor, jit, mint, nn -from vllm.config import get_current_vllm_config -from vllm.distributed import tensor_model_parallel_gather +from vllm.config import current_platform, get_current_vllm_config +from vllm.distributed import (tensor_model_parallel_all_gather, + tensor_model_parallel_gather) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm_mindspore.distributed.communication_op import ( @@ -61,9 +62,6 @@ class LogitsProcessor(nn.Cell): scale: A scaling factor to apply to the logits. """ super().__init__() - vllm_config = get_current_vllm_config() - self.vllm_config = vllm_config - self.is_graph_mode = bool(not vllm_config.model_config.enforce_eager) self.scale = scale self.vocab_size = vocab_size # Whether the input is logits (default is hidden states). @@ -73,11 +71,168 @@ class LogitsProcessor(nn.Cell): # Soft cap the logits. Used in Gemma 2. self.soft_cap = soft_cap # Whether to use gather or all-gather to gather the logits. 
- self.use_all_gather = True + self.use_all_gather = current_platform.use_all_gather() + + def construct( + self, + lm_head: VocabParallelEmbedding, + hidden_states: Tensor, + sampling_metadata: Optional[SamplingMetadata] = None, + embedding_bias: Optional[Tensor] = None, + ) -> Optional[Tensor]: + if self.logits_as_input: + logits = hidden_states + else: + if sampling_metadata is not None: + if sampling_metadata.selected_token_indices.numel() <= 0: + return mint.zeros((0, self.vocab_size), + dtype=hidden_states.dtype) + hidden_states = _prune_hidden_states(hidden_states, + sampling_metadata) + + # Get the logits for the next tokens. + logits = self._get_logits(hidden_states, lm_head, embedding_bias) + if logits is not None: + if self.soft_cap is not None: + logits = logits / self.soft_cap + logits = mint.tanh(logits) + logits = logits * self.soft_cap + + if self.scale != 1.0: + logits *= self.scale + + # Apply logits processors (if any). + if sampling_metadata is not None and \ + sampling_metadata.seq_groups is not None: + logits = _apply_logits_processors(logits, sampling_metadata) + return logits + + def _get_logits( + self, + hidden_states: Tensor, + lm_head: VocabParallelEmbedding, + embedding_bias: Optional[Tensor], + ) -> Optional[Tensor]: + # Get the logits for the next tokens. + logits = lm_head.quant_method.apply(lm_head, + hidden_states, + bias=embedding_bias) if self.use_all_gather: - self.tensor_model_parallel_all_gather = \ - AllGatherFromModelParallelRegion() + # Gather is not supported for some devices such as NPUs. + logits = tensor_model_parallel_all_gather(logits) + else: + # None may be returned for rank > 0 + logits = tensor_model_parallel_gather(logits) + # Remove paddings in vocab (if any). 
+ if logits is not None: + logits = logits[..., :self.org_vocab_size] + return logits + + def extra_repr(self) -> str: + s = f"vocab_size={self.vocab_size}" + s += f", forg_vocab_size={self.org_vocab_size}" + s += f", scale={self.scale}, logits_as_input={self.logits_as_input}" + return s + + +def _prune_hidden_states( + hidden_states: Tensor, + sampling_metadata: SamplingMetadata, +) -> Tensor: + indices = sampling_metadata.selected_token_indices + if indices is not None and indices.numel() > 0: + return mint.index_select(hidden_states, 0, + sampling_metadata.selected_token_indices) + return hidden_states + + +def _apply_logits_processors( + logits: Tensor, + sampling_metadata: SamplingMetadata, +) -> Tensor: + found_logits_processors = False + logits_processed = 0 + logits_row_ids_and_logits_row_futures = [] + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params + logits_processors = sampling_params.logits_processors + if logits_processors: + found_logits_processors = True + + for seq_id, logits_row_idx in zip(seq_ids, + seq_group.sample_indices): + logits_row = logits[logits_row_idx] + past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids + prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids + + if _logits_processor_threadpool is not None: + logits_row_ids_and_logits_row_futures.append( + (logits_row_idx, + _logits_processor_threadpool.submit( + _apply_logits_processors_single_seq, logits_row, + logits_processors, past_tokens_ids, + prompt_tokens_ids))) + else: + logits[logits_row_idx] = \ + _apply_logits_processors_single_seq( + logits_row, logits_processors, past_tokens_ids, + prompt_tokens_ids) + + logits_processed += len(seq_group.sample_indices) + len( + seq_group.prompt_logprob_indices) + + for logits_row_idx, future in logits_row_ids_and_logits_row_futures: + logits[logits_row_idx] = future.result() + + if found_logits_processors: + # verifies that no rows in logits 
were missed unexpectedly + assert logits_processed == logits.shape[0] + return logits + + +def _apply_logits_processors_single_seq(logits_row, logits_processors, + past_tokens_ids, + prompt_tokens_ids) -> Tensor: + for logits_processor in logits_processors: + parameters = inspect.signature(logits_processor).parameters + if len(parameters) == 3: + logits_row = logits_processor(prompt_tokens_ids, past_tokens_ids, + logits_row) + else: + logits_row = logits_processor(past_tokens_ids, logits_row) + return logits_row + + +class LogitsProcessor310P(LogitsProcessor): + """Process logits for 310P, running in graph mode for better performance. + + This layer does the following: + 1. Gather logits from model hidden_states. + 2. Scale logits if needed. + 3. Apply logits processors (if any). + """ + + def __init__( + self, + vocab_size: int, + org_vocab_size: Optional[int] = None, + scale: float = 1.0, + logits_as_input: bool = False, + soft_cap: Optional[float] = None, + ) -> None: + """ + Args: + scale: A scaling factor to apply to the logits. + """ + super().__init__(vocab_size, org_vocab_size, scale, logits_as_input, + soft_cap) + vllm_config = get_current_vllm_config() + self.vllm_config = vllm_config + self.is_graph_mode = bool(not vllm_config.model_config.enforce_eager) + self.tensor_model_parallel_all_gather = \ + AllGatherFromModelParallelRegion() self.lm_head = None self.run_model = None self.cached_input_info = {} @@ -193,77 +348,9 @@ class LogitsProcessor(nn.Cell): logits = lm_head.quant_method.apply(lm_head, hidden_states, bias=embedding_bias) - if self.use_all_gather: - # Gather is not supported for some devices such as NPUs. - logits = self.tensor_model_parallel_all_gather(logits) - else: - # None may be returned for rank > 0 - logits = tensor_model_parallel_gather(logits) + # For 310p, all gather has better performance. + logits = self.tensor_model_parallel_all_gather(logits) # Remove paddings in vocab (if any). 
if logits is not None: logits = logits[..., :self.org_vocab_size] return logits - - def extra_repr(self) -> str: - s = f"vocab_size={self.vocab_size}" - s += f", forg_vocab_size={self.org_vocab_size}" - s += f", scale={self.scale}, logits_as_input={self.logits_as_input}" - return s - - -def _apply_logits_processors( - logits: Tensor, - sampling_metadata: SamplingMetadata, -) -> Tensor: - found_logits_processors = False - logits_processed = 0 - logits_row_ids_and_logits_row_futures = [] - for seq_group in sampling_metadata.seq_groups: - seq_ids = seq_group.seq_ids - sampling_params = seq_group.sampling_params - logits_processors = sampling_params.logits_processors - if logits_processors: - found_logits_processors = True - - for seq_id, logits_row_idx in zip(seq_ids, - seq_group.sample_indices): - logits_row = logits[logits_row_idx] - past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids - prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids - - if _logits_processor_threadpool is not None: - logits_row_ids_and_logits_row_futures.append( - (logits_row_idx, - _logits_processor_threadpool.submit( - _apply_logits_processors_single_seq, logits_row, - logits_processors, past_tokens_ids, - prompt_tokens_ids))) - else: - logits[logits_row_idx] = \ - _apply_logits_processors_single_seq( - logits_row, logits_processors, past_tokens_ids, - prompt_tokens_ids) - - logits_processed += len(seq_group.sample_indices) + len( - seq_group.prompt_logprob_indices) - - for logits_row_idx, future in logits_row_ids_and_logits_row_futures: - logits[logits_row_idx] = future.result() - - if found_logits_processors: - # verifies that no rows in logits were missed unexpectedly - assert logits_processed == logits.shape[0] - return logits - - -def _apply_logits_processors_single_seq(logits_row, logits_processors, - past_tokens_ids, - prompt_tokens_ids) -> Tensor: - for logits_processor in logits_processors: - parameters = inspect.signature(logits_processor).parameters - if 
len(parameters) == 3: - logits_row = logits_processor(prompt_tokens_ids, past_tokens_ids, - logits_row) - else: - logits_row = logits_processor(past_tokens_ids, logits_row) - return logits_row diff --git a/vllm_mindspore/model_executor/layers/quantization/smooth_quant_modelslim.py b/vllm_mindspore/model_executor/layers/quantization/smooth_quant_modelslim.py index 7cb17c5f..8e1772dc 100644 --- a/vllm_mindspore/model_executor/layers/quantization/smooth_quant_modelslim.py +++ b/vllm_mindspore/model_executor/layers/quantization/smooth_quant_modelslim.py @@ -29,7 +29,7 @@ from vllm_mindspore.attention import Attention from vllm_mindspore.model_executor.layers.linear import ( LinearBase, LinearMethodBase, UnquantizedLinearMethod) from vllm_mindspore.model_executor.utils import set_weight_attrs -from vllm_mindspore.utils import is_310p +from vllm_mindspore.utils import is_310p, set_weight_format_to_nz from .attention import BaseKVCacheMethod, KVCacheInt8Method from .base_config import QuantizationConfig, QuantizeMethodBase @@ -325,7 +325,7 @@ class A8W8LinearMethod(LinearMethodBase): def process_weights_after_loading(self, layer: mindspore.nn.Cell) -> None: if is_310p(): - self.format_to_nz(layer.weight) + set_weight_format_to_nz(layer.weight) input_offset = np.array([0]) params_dtype = layer.params_dtype layer.input_offset = Parameter(Tensor(input_offset, @@ -501,7 +501,7 @@ class A8W8DYNLinearMethod(LinearMethodBase): def process_weights_after_loading(self, layer: mindspore.nn.Cell) -> None: if is_310p(): - self.format_to_nz(layer.weight) + set_weight_format_to_nz(layer.weight) if self.is_2d_smooth_scale: smooth_scale = layer.smooth_scale.asnumpy().reshape(1, -1) layer.smooth_scale = Parameter(Tensor(smooth_scale, diff --git a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py index 8547586a..70654aa0 100644 --- a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py +++ 
b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py @@ -38,7 +38,7 @@ from vllm_mindspore.model_executor.layers.quantization.base_config import ( from vllm_mindspore.model_executor.model_loader.weight_utils import ( split_loaded_weight) from vllm_mindspore.model_executor.utils import set_weight_attrs -from vllm_mindspore.utils import FORMAT_TYPE, is_310p +from vllm_mindspore.utils import is_310p, set_weight_format_to_nz DEFAULT_VOCAB_PADDING_SIZE = 64 @@ -77,13 +77,9 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): def embedding(self, layer: nn.Cell, input_: Tensor) -> Tensor: return mint.index_select(layer.weight, 0, input_) - def format_to_nz(self, param): - cast_weight = ops.auto_generate.format_cast(param, FORMAT_TYPE['nz']) - param.set_data(cast_weight) - def process_weights_after_loading(self, layer): if is_310p(): - self.format_to_nz(layer.weight) + set_weight_format_to_nz(layer.weight) def get_masked_input_and_mask( diff --git a/vllm_mindspore/model_executor/model_loader/utils.py b/vllm_mindspore/model_executor/model_loader/utils.py index fb155027..fbb5fcf1 100644 --- a/vllm_mindspore/model_executor/model_loader/utils.py +++ b/vllm_mindspore/model_executor/model_loader/utils.py @@ -20,10 +20,14 @@ """ utils for load model """ from mindspore import nn +from vllm.attention import Attention from vllm.config import ModelConfig, ModelImpl +from vllm.model_executor.layers.linear import QKVCrossParallelLinear from vllm.model_executor.model_loader.utils import logger from vllm.model_executor.models import ModelRegistry +from vllm_mindspore.model_executor.layers.quantization.base_config import ( + QuantizeMethodBase) from vllm_mindspore.model_executor.models.registry import ( AUTO_SELECT_FIXED_MODEL, MindSporeModelRegistry, mcore_support_list, mf_supported, mindone_supported) @@ -188,3 +192,30 @@ def get_ms_model_architecture( raise RecursionError("MindSpore unsupported reward model task now!") return model_cls, arch + + +def 
process_weights_after_loading(model, model_config, target_device) -> None: + for _, module in model.named_modules(): + if isinstance(module, QKVCrossParallelLinear): + # NOTE(Isotr0py): special case for cross QKV layer because + # q and kv proj aren't registered as submodules intentionally + module.process_weights_after_loading() + continue + quant_method = getattr(module, "quant_method", None) + if isinstance(quant_method, QuantizeMethodBase): + # When quant methods need to process weights after loading + # (for repacking, quantizing, etc), they expect parameters + # to be on the global target device. This scope is for the + # case where cpu offloading is used, where we will move the + # parameters onto device for processing and back off after. + quant_method.process_weights_after_loading(module) + + # Currently only used by MLA. + # NOTE: This intentionally happens after other modules so we can easily + # decompress the weights for MLA. + for _, module in model.named_modules(): + if isinstance(module, Attention) and \ + hasattr(module, "process_weights_after_loading"): + # TODO(lucas): see if there is a way to unify the signatures + # of process_weights_after_loading + module.process_weights_after_loading(model_config.dtype) \ No newline at end of file diff --git a/vllm_mindspore/model_executor/model_loader/weight_utils.py b/vllm_mindspore/model_executor/model_loader/weight_utils.py index 7a4ff7fd..d409e80a 100644 --- a/vllm_mindspore/model_executor/model_loader/weight_utils.py +++ b/vllm_mindspore/model_executor/model_loader/weight_utils.py @@ -22,14 +22,13 @@ from collections.abc import Generator from typing import Any import mindspore as ms -import numpy as np from mindspore import Parameter from safetensors import safe_open from tqdm.auto import tqdm from vllm.model_executor.model_loader.weight_utils import (_BAR_FORMAT, enable_tqdm) -from vllm_mindspore.utils import is_310p +from vllm_mindspore.utils import cast_weight_for_310p, is_310p def 
split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): @@ -42,10 +41,8 @@ def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): """ if shard_dim is None: loaded_weight = loaded_weight[:] - loaded_weight = (loaded_weight.astype(np.float16) if - ((str(loaded_weight.dtype) == "float32" - or str(loaded_weight.dtype) == "bfloat16") - and is_310p()) else loaded_weight) + if is_310p(): + loaded_weight = cast_weight_for_310p(loaded_weight) return loaded_weight end_idx = start_idx + shard_size @@ -57,10 +54,8 @@ def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): loaded_weight = loaded_weight[:, :, start_idx:end_idx] else: raise ValueError("shard_dim:{} is not supported.".format(shard_dim)) - loaded_weight = (loaded_weight.astype(np.float16) if - ((str(loaded_weight.dtype) == "float32" - or str(loaded_weight.dtype) == "bfloat16") - and is_310p()) else loaded_weight) + if is_310p(): + loaded_weight = cast_weight_for_310p(loaded_weight) return loaded_weight @@ -89,8 +84,6 @@ def safetensors_weights_iterator( def default_weight_loader(param: Parameter, loaded_weight: Any) -> None: """Default weight loader.""" loaded_weight = loaded_weight[:] - loaded_weight = (loaded_weight.astype(np.float16) if - ((str(loaded_weight.dtype) == "float32" - or str(loaded_weight.dtype) == "bfloat16") - and is_310p()) else loaded_weight) + if is_310p(): + loaded_weight = cast_weight_for_310p(loaded_weight) param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype)) diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index 602a8314..d5ba05e3 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -447,6 +447,9 @@ class NativeModel(MsModelBase): block_size = self.cache_config.block_size num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) head_size = self.model_config.get_head_size() + # 
FormatCast op is used to convert kv_cache to nz format for 310p. + # The kv shape is flattened to 3 dimensions, + # because FormatCast op only supports 3d input. kv_cache_shape = (None, block_size, num_kv_heads * head_size) \ if is_310p() else (None, block_size, num_kv_heads, head_size) diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py index db8e31c0..a1755372 100644 --- a/vllm_mindspore/model_executor/models/qwen2.py +++ b/vllm_mindspore/model_executor/models/qwen2.py @@ -46,12 +46,13 @@ from vllm.sequence import IntermediateTensors from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm_mindspore.attention import Attention +from vllm_mindspore.utils import is_310p from vllm_mindspore.model_executor.layers.activation import SiluAndMul from vllm_mindspore.model_executor.layers.layernorm import RMSNorm from vllm_mindspore.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm_mindspore.model_executor.layers.logits_processor import \ - LogitsProcessor + LogitsProcessor, LogitsProcessor310P from vllm_mindspore.model_executor.layers.rotary_embedding import get_rope from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) @@ -462,7 +463,10 @@ class Qwen2ForCausalLM(NativeModel, SupportsLoRA): else: self.lm_head = PPMissingLayer() - self.logits_processor = LogitsProcessor(self.config.vocab_size) + if is_310p(): + self.logits_processor = LogitsProcessor310P(config.vocab_size) + else: + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm_mindspore/model_executor/models/qwen3.py b/vllm_mindspore/model_executor/models/qwen3.py index 89a98b58..b464ee4c 100644 --- a/vllm_mindspore/model_executor/models/qwen3.py +++ b/vllm_mindspore/model_executor/models/qwen3.py @@ -58,12 
+58,13 @@ from vllm_mindspore.model_executor.layers.layernorm import RMSNorm from vllm_mindspore.model_executor.layers.linear import (QKVParallelLinear, RowParallelLinear) from vllm_mindspore.model_executor.layers.logits_processor import ( - LogitsProcessor) + LogitsProcessor, LogitsProcessor310P) from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead) from vllm_mindspore.model_executor.models.model_base import NativeModel from vllm_mindspore.model_executor.models.utils import (PPMissingLayer, maybe_prefix) +from vllm_mindspore.utils import is_310p from vllm_mindspore.model_executor.layers.rotary_embedding import ( # type: ignore[attr-defined] # isort: skip get_rope) @@ -297,8 +298,10 @@ class Qwen3ForCausalLM(NativeModel): prefix, "lm_head")) else: self.lm_head = PPMissingLayer() - - self.logits_processor = LogitsProcessor(config.vocab_size) + if is_310p(): + self.logits_processor = LogitsProcessor310P(config.vocab_size) + else: + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm_mindspore/platforms/ascend.py b/vllm_mindspore/platforms/ascend.py index 3f86a551..4954bd47 100644 --- a/vllm_mindspore/platforms/ascend.py +++ b/vllm_mindspore/platforms/ascend.py @@ -186,5 +186,6 @@ class AscendPlatform(Platform): @property def supported_dtypes(self) -> list[torch.dtype]: if is_310p(): + # bfloat16 is not supported on the 310p due to hardware limitations. return [torch.float16, torch.float32] return [torch.bfloat16, torch.float16, torch.float32] \ No newline at end of file diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py index f3a9adbe..7ea9ec95 100644 --- a/vllm_mindspore/utils.py +++ b/vllm_mindspore/utils.py @@ -83,6 +83,24 @@ def create_kv_cache(kv_shape, dtype): return ms.mint.zeros(kv_shape, dtype=dtype) +def cast_weight_for_310p(loaded_weight): + """ + Casts weights to float16 for 310p. 
+ + In non-quantized scenarios, the 310P hardware only supports float16 weights. + This function converts float32 or bfloat16 weights to float16. + """ + cast_weight = (loaded_weight.astype(np.float16) if + (str(loaded_weight.dtype) == "float32" or str( + loaded_weight.dtype) == "bfloat16") else loaded_weight) + return cast_weight + + +def set_weight_format_to_nz(param): + cast_weight = ms.ops.auto_generate.format_cast(param, FORMAT_TYPE['nz']) + param.set_data(cast_weight) + + def get_valid_dtype(dtype): if isinstance(dtype, str): dtype = STR_DTYPE_TO_MS_DTYPE[dtype] -- Gitee From 16f1ede99d88d392af8daa98bd498f9e746954c8 Mon Sep 17 00:00:00 2001 From: superxf Date: Wed, 15 Oct 2025 14:31:22 +0800 Subject: [PATCH 3/3] test --- tests/st/python/test_cases_parallel.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/st/python/test_cases_parallel.py b/tests/st/python/test_cases_parallel.py index 5c8ceb69..342057a0 100644 --- a/tests/st/python/test_cases_parallel.py +++ b/tests/st/python/test_cases_parallel.py @@ -287,6 +287,9 @@ def test_cases_parallel_310p_v0_part0(): Description: test cases parallel. Expectation: Pass. """ + print("env info") + for ky, va in os.environ.items(): + print(f"{ky} : {va}") cases = [ (2, "cases_parallel/vllm_mf_qwen3_8b.py::test_mf_qwen3_v0", "vllm_mf_qwen3_8b_v0_310p_test_mf_qwen3.log"), -- Gitee