From 558497c868d320380d9310d8c28d2c815e674f0e Mon Sep 17 00:00:00 2001 From: moran Date: Tue, 11 Nov 2025 11:34:30 +0800 Subject: [PATCH 1/8] fix codecheck in new branch --- codecheck_toolkits/vllm_codecheck.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codecheck_toolkits/vllm_codecheck.sh b/codecheck_toolkits/vllm_codecheck.sh index 72891a68..5486f650 100644 --- a/codecheck_toolkits/vllm_codecheck.sh +++ b/codecheck_toolkits/vllm_codecheck.sh @@ -26,7 +26,7 @@ pip install -r codecheck_toolkits/requirements-lint.txt RUN_GATE=${1:-0} if [ "$RUN_GATE" -eq 0 ]; then - pre-commit run --from-ref origin/master --to-ref HEAD + pre-commit run --from-ref origin/br_infer_boom_1115 --to-ref HEAD else pre-commit run --all-files fi -- Gitee From b909d6ce0795e87ad3e1bfcfee79641dad0fb4a8 Mon Sep 17 00:00:00 2001 From: superxf Date: Fri, 10 Oct 2025 10:02:14 +0800 Subject: [PATCH 2/8] support 310p v0 --- vllm_mindspore/__init__.py | 16 ++ vllm_mindspore/config.py | 2 +- vllm_mindspore/distributed/parallel_state.py | 114 ++++++++++++ .../model_executor/layers/linear.py | 5 + .../model_executor/layers/logits_processor.py | 167 +++++++++++++++++- .../quantization/smooth_quant_modelslim.py | 8 +- .../layers/vocab_parallel_embedding.py | 5 + .../model_executor/model_loader/utils.py | 11 ++ .../model_loader/weight_utils.py | 9 + .../model_executor/models/model_base.py | 9 +- vllm_mindspore/platforms/ascend.py | 9 + vllm_mindspore/utils.py | 18 ++ vllm_mindspore/worker/cache_engine.py | 7 +- vllm_mindspore/worker/model_runner.py | 17 +- 14 files changed, 376 insertions(+), 21 deletions(-) create mode 100644 vllm_mindspore/distributed/parallel_state.py diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index 84c5d126..7129aa91 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -340,6 +340,22 @@ RejectionSampler._smallest_positive_value = _smallest_positive_value RejectionSampler._smallest_positive_value.__set_name__( RejectionSampler, "_smallest_positive_value") +from vllm_mindspore.model_executor.model_loader.utils import ( + ms_device_loading_context) + +vllm.model_executor.model_loader.utils.device_loading_context = ( + ms_device_loading_context) + +from vllm_mindspore.model_executor.layers.quantization.base_config import ( + QuantizeMethodBase) + +vllm.model_executor.model_loader.utils.QuantizeMethodBase = QuantizeMethodBase + +from vllm_mindspore.distributed.parallel_state import gc_broadcast_tensor_dict + +vllm.distributed.parallel_state.GroupCoordinator.broadcast_tensor_dict = ( + gc_broadcast_tensor_dict) + ######### for multi-model from vllm_mindspore.inputs.registry import call_hf_processor from vllm.inputs.registry import InputProcessingContext diff --git a/vllm_mindspore/config.py b/vllm_mindspore/config.py index 7f93164b..1f9e8e27 100644 --- a/vllm_mindspore/config.py +++ b/vllm_mindspore/config.py @@ -248,7 +248,7 @@ def _get_and_verify_dtype( raise ValueError(f"Unknown dtype: {dtype}") if torch_dtype == torch.bfloat16 and is_310p(): - torch_dtype = torch.float16 + raise ValueError("For 310p, bfloat16 type is not support") if torch_dtype != config_dtype: if torch_dtype == torch.float32: diff --git a/vllm_mindspore/distributed/parallel_state.py b/vllm_mindspore/distributed/parallel_state.py new file mode 100644 index 00000000..371511e2 --- /dev/null +++ b/vllm_mindspore/distributed/parallel_state.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Communication functions are adapted from +# 
https://github.com/vllm-project/vllm/blob/v0.9.1/vllm/distributed/parallel_state.py +# +# Copyright 2025 Huawei Technologies Co., Ltd. +# Copyright 2024-2025 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Adaption for cpu broadcast.""" +from typing import Any, Optional, Union + +import torch +import torch.distributed +from torch.distributed import ProcessGroup +from vllm.distributed.parallel_state import (TensorMetadata, + _split_tensor_dict, get_tp_group) + +from vllm_mindspore.utils import is_310p + + +def gc_broadcast_tensor_dict( + self, + tensor_dict: Optional[dict[str, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None +) -> Optional[dict[str, Union[torch.Tensor, Any]]]: + """Broadcast the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if (not torch.distributed.is_initialized() or self.world_size == 1): + return tensor_dict + + # Due to 310p limitations, the broadcast is in cpu group + group = get_tp_group().cpu_group if is_310p() else self.device_group + metadata_group = self.cpu_group + assert src < self.world_size, f"Invalid src rank ({src})" + + rank_in_group = self.rank_in_group + if rank_in_group == src: + metadata_list: list[tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, + dict), (f"Expecting a dictionary, got {type(tensor_dict)}") + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.broadcast_object(metadata_list, src=src) + async_handles = [] + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=group, + async_op=True) + async_handles.append(handle) + for async_handle in async_handles: + async_handle.wait() + + else: + metadata_list = self.broadcast_object(None, src=src) + tensor_dict = {} + async_handles = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, + dtype=value.dtype, + device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. 
+ tensor_dict[key] = tensor + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=group, + async_op=True) + async_handles.append(handle) + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + for async_handle in async_handles: + async_handle.wait() + return tensor_dict diff --git a/vllm_mindspore/model_executor/layers/linear.py b/vllm_mindspore/model_executor/layers/linear.py index 92121905..46ea2d7d 100644 --- a/vllm_mindspore/model_executor/layers/linear.py +++ b/vllm_mindspore/model_executor/layers/linear.py @@ -38,6 +38,7 @@ from vllm_mindspore.model_executor.layers.quantization.base_config import ( from vllm_mindspore.model_executor.model_loader.weight_utils import ( split_loaded_weight) from vllm_mindspore.model_executor.utils import set_weight_attrs +from vllm_mindspore.utils import is_310p, set_weight_format_to_nz WEIGHT_LOADER_V2_SUPPORTED = [ "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod", @@ -117,6 +118,10 @@ class UnquantizedLinearMethod(LinearMethodBase): x = x.view(output_shape) return x + def process_weights_after_loading(self, layer): + if is_310p(): + set_weight_format_to_nz(layer.weight) + class LinearBase(nn.Cell): """Base linear layer. diff --git a/vllm_mindspore/model_executor/layers/logits_processor.py b/vllm_mindspore/model_executor/layers/logits_processor.py index ee8c8edc..c4f6f22e 100644 --- a/vllm_mindspore/model_executor/layers/logits_processor.py +++ b/vllm_mindspore/model_executor/layers/logits_processor.py @@ -23,19 +23,24 @@ from concurrent.futures import ThreadPoolExecutor from typing import Optional import vllm.envs as envs -from mindspore import Tensor, mint, nn -from vllm.config import current_platform +from mindspore import Tensor, jit, mint, nn +from vllm.config import current_platform, get_current_vllm_config from vllm.distributed import (tensor_model_parallel_all_gather, tensor_model_parallel_gather) +from vllm.logger import init_logger from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm_mindspore.distributed.communication_op import ( + AllGatherFromModelParallelRegion) from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm_mindspore.utils import is_310p _logits_processor_threadpool: Optional[ThreadPoolExecutor] = None if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None: _logits_processor_threadpool = ThreadPoolExecutor( envs.VLLM_LOGITS_PROCESSOR_THREADS) +logger = init_logger(__name__) class LogitsProcessor(nn.Cell): @@ -47,6 +52,13 @@ class LogitsProcessor(nn.Cell): 3. Apply logits processors (if any). """ + def __new__(cls, *args, **kwargs): + if cls is LogitsProcessor and is_310p(): + logger.info( + "In 310p, use LogitsProcessorGraph to run in graph mode") + return LogitsProcessorGraph(*args, **kwargs) + return super().__new__(cls) + def __init__( self, vocab_size: int, @@ -201,3 +213,154 @@ def _apply_logits_processors_single_seq(logits_row, logits_processors, else: logits_row = logits_processor(past_tokens_ids, logits_row) return logits_row + + +class LogitsProcessorGraph(LogitsProcessor): + """Process logits for 310P, running in graph mode for better performance. + + This layer does the following: + 1. Gather logits from model hidden_states. + 2. Scale logits if needed. + 3. 
Apply logits processors (if any). + """ + + def __init__( + self, + vocab_size: int, + org_vocab_size: Optional[int] = None, + scale: float = 1.0, + logits_as_input: bool = False, + soft_cap: Optional[float] = None, + ) -> None: + """ + Args: + scale: A scaling factor to apply to the logits. + """ + super().__init__(vocab_size, org_vocab_size, scale, logits_as_input, + soft_cap) + vllm_config = get_current_vllm_config() + self.vllm_config = vllm_config + self.is_graph_mode = bool(not vllm_config.model_config.enforce_eager) + self.tensor_model_parallel_all_gather = \ + AllGatherFromModelParallelRegion() + self.lm_head = None + self.run_model = None + self.cached_input_info = {} + + def set_dynamic_inputs(self): + dyn_hidden_states = Tensor(shape=[None, None], + dtype=self.vllm_config.model_config.dtype) + + if self.cached_input_info["indices"] is None: + dyn_indices = None + else: + dyn_indices_shape = [ + None for _ in range(self.cached_input_info["indices"]["ndim"]) + ] + dyn_indices_dtype = self.cached_input_info["indices"]["dtype"] + dyn_indices = Tensor(shape=dyn_indices_shape, + dtype=dyn_indices_dtype) + + if self.cached_input_info["bias"] is None: + dyn_bias = None + else: + dyn_bias_shape = [ + None for _ in range(self.cached_input_info["bias"]["ndim"]) + ] + dyn_bias_dtype = self.cached_input_info["bias"]["dtype"] + dyn_bias = Tensor(shape=dyn_bias_shape, dtype=dyn_bias_dtype) + + self.set_inputs(dyn_hidden_states, dyn_indices, dyn_bias) + + def __call__( + self, + lm_head: VocabParallelEmbedding, + hidden_states: Tensor, + sampling_metadata: Optional[SamplingMetadata] = None, + embedding_bias: Optional[Tensor] = None, + ) -> Optional[Tensor]: + if self.lm_head is None: + self.lm_head = lm_head + if self.run_model is None: + self.run_model = jit( + function=self.construct, + jit_level='O0') if self.is_graph_mode else self.construct + selected_token_indices = None + if sampling_metadata is not None: + selected_token_indices = sampling_metadata.selected_token_indices + dyn_indices_info = None if selected_token_indices is None else { + "ndim": selected_token_indices.ndim, + "dtype": selected_token_indices.dtype, + } + dyn_bias_info = None if embedding_bias is None else { + "ndim": embedding_bias.ndim, + "dtype": embedding_bias.dtype, + } + if self.cached_input_info != { + "indices": dyn_indices_info, + "bias": dyn_bias_info + }: + self.cached_input_info = { + "indices": dyn_indices_info, + "bias": dyn_bias_info, + } + self.set_dynamic_inputs() + + if selected_token_indices is not None and selected_token_indices.numel( + ) <= 0: + logits = mint.zeros((0, self.vocab_size), + dtype=hidden_states.dtype) + else: + logits = self.run_model(hidden_states, selected_token_indices, + embedding_bias) + + if sampling_metadata is not None and \ + sampling_metadata.seq_groups is not None: + logits = _apply_logits_processors(logits, sampling_metadata) + + return logits + + def construct( + self, + hidden_states: Tensor, + selected_token_indices: Optional[Tensor] = None, + embedding_bias: Optional[Tensor] = None, + ) -> Optional[Tensor]: + if self.logits_as_input: + logits = hidden_states + else: + if selected_token_indices is not None: + hidden_states = mint.index_select(hidden_states, 0, + selected_token_indices) + + # Get the logits for the next tokens. 
+ logits = self._get_logits(hidden_states, self.lm_head, + embedding_bias) + if logits is not None: + if self.soft_cap is not None: + logits = logits / self.soft_cap + logits = mint.tanh(logits) + logits = logits * self.soft_cap + + if self.scale != 1.0: + logits *= self.scale + + # Apply logits processors (if any). + return logits + + def _get_logits( + self, + hidden_states: Tensor, + lm_head: VocabParallelEmbedding, + embedding_bias: Optional[Tensor], + ) -> Optional[Tensor]: + # Get the logits for the next tokens. + logits = lm_head.quant_method.apply(lm_head, + hidden_states, + bias=embedding_bias) + # For 310p, all gather has better performance. + logits = self.tensor_model_parallel_all_gather(logits) + # Remove paddings in vocab (if any). + if logits is not None: + logits = logits[..., :self.org_vocab_size] + return logits diff --git a/vllm_mindspore/model_executor/layers/quantization/smooth_quant_modelslim.py b/vllm_mindspore/model_executor/layers/quantization/smooth_quant_modelslim.py index 841e1797..8e1772dc 100644 --- a/vllm_mindspore/model_executor/layers/quantization/smooth_quant_modelslim.py +++ b/vllm_mindspore/model_executor/layers/quantization/smooth_quant_modelslim.py @@ -14,11 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import re from typing import Any, Optional import mindspore import numpy as np +import regex as re from mindspore import Parameter, Tensor, mint from mindspore.common.initializer import initializer from mindspore.ops.auto_generate import (DynamicQuantExt, GroupedMatmul, @@ -29,7 +29,7 @@ from vllm_mindspore.attention import Attention from vllm_mindspore.model_executor.layers.linear import ( LinearBase, LinearMethodBase, UnquantizedLinearMethod) from vllm_mindspore.model_executor.utils import set_weight_attrs -from vllm_mindspore.utils import is_310p +from vllm_mindspore.utils import is_310p, set_weight_format_to_nz from .attention import BaseKVCacheMethod, KVCacheInt8Method from .base_config import QuantizationConfig, QuantizeMethodBase @@ -324,6 +324,8 @@ class A8W8LinearMethod(LinearMethodBase): layer.insert_param_to_cell("input_offset", input_offset) def process_weights_after_loading(self, layer: mindspore.nn.Cell) -> None: + if is_310p(): + set_weight_format_to_nz(layer.weight) input_offset = np.array([0]) params_dtype = layer.params_dtype layer.input_offset = Parameter(Tensor(input_offset, @@ -498,6 +500,8 @@ class A8W8DYNLinearMethod(LinearMethodBase): layer.insert_param_to_cell("smooth_scale", smooth_scale) def process_weights_after_loading(self, layer: mindspore.nn.Cell) -> None: + if is_310p(): + set_weight_format_to_nz(layer.weight) if self.is_2d_smooth_scale: smooth_scale = layer.smooth_scale.asnumpy().reshape(1, -1) layer.smooth_scale = Parameter(Tensor(smooth_scale, diff --git a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py index f25c8059..0c678ca1 100644 --- a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py @@ -38,6 +38,7 @@ from vllm_mindspore.model_executor.layers.quantization.base_config import ( from vllm_mindspore.model_executor.model_loader.weight_utils import ( split_loaded_weight) from vllm_mindspore.model_executor.utils import set_weight_attrs +from vllm_mindspore.utils import is_310p, set_weight_format_to_nz DEFAULT_VOCAB_PADDING_SIZE = 64 @@ -76,6 +77,10 @@ class 
UnquantizedEmbeddingMethod(QuantizeMethodBase): def embedding(self, layer: nn.Cell, input_: Tensor) -> Tensor: return mint.index_select(layer.weight, 0, input_) + def process_weights_after_loading(self, layer): + if isinstance(layer, ParallelLMHead) and is_310p(): + set_weight_format_to_nz(layer.weight) + def get_masked_input_and_mask( input_: Tensor, diff --git a/vllm_mindspore/model_executor/model_loader/utils.py b/vllm_mindspore/model_executor/model_loader/utils.py index cd76b99a..571fe06c 100644 --- a/vllm_mindspore/model_executor/model_loader/utils.py +++ b/vllm_mindspore/model_executor/model_loader/utils.py @@ -20,6 +20,7 @@ """ utils for load model """ import os +from contextlib import contextmanager from mindspore import nn from vllm.config import ModelConfig, ModelImpl @@ -198,3 +199,13 @@ def get_ms_model_architecture( raise RecursionError("MindSpore unsupported reward model task now!") return model_cls, arch + + +@contextmanager +def ms_device_loading_context(module, target_device): + if target_device != "cuda": + raise NotImplementedError( + f"VLLM-Mindspore Plugin only supports loading model on " + f"'cuda' device now, but got '{target_device}'.") + yield module + return diff --git a/vllm_mindspore/model_executor/model_loader/weight_utils.py b/vllm_mindspore/model_executor/model_loader/weight_utils.py index 6bf2dd4c..d409e80a 100644 --- a/vllm_mindspore/model_executor/model_loader/weight_utils.py +++ b/vllm_mindspore/model_executor/model_loader/weight_utils.py @@ -28,6 +28,8 @@ from tqdm.auto import tqdm from vllm.model_executor.model_loader.weight_utils import (_BAR_FORMAT, enable_tqdm) +from vllm_mindspore.utils import cast_weight_for_310p, is_310p + def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): """ @@ -39,6 +41,8 @@ def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): """ if shard_dim is None: loaded_weight = loaded_weight[:] + if is_310p(): + loaded_weight = cast_weight_for_310p(loaded_weight) return loaded_weight end_idx = start_idx + shard_size @@ -50,6 +54,9 @@ def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): loaded_weight = loaded_weight[:, :, start_idx:end_idx] else: raise ValueError("shard_dim:{} is not supported.".format(shard_dim)) + if is_310p(): + loaded_weight = cast_weight_for_310p(loaded_weight) + return loaded_weight @@ -77,4 +84,6 @@ def safetensors_weights_iterator( def default_weight_loader(param: Parameter, loaded_weight: Any) -> None: """Default weight loader.""" loaded_weight = loaded_weight[:] + if is_310p(): + loaded_weight = cast_weight_for_310p(loaded_weight) param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype)) diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index 4c0d729e..31cbd374 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -38,7 +38,8 @@ from vllm_mindspore.model_executor.models.interfaces import ( from vllm_mindspore.model_executor.models.utils import (convert_pin, is_use_ringmla) from vllm_mindspore.model_executor.utils import set_model_context -from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE, create_kv_cache +from vllm_mindspore.utils import (STR_DTYPE_TO_MS_DTYPE, create_kv_cache, + is_310p) from vllm_mindspore.v1.attention.backends.ms_attn import MsAttentionMetadata @@ -472,7 +473,11 @@ class NativeModel(MsModelBase): block_size = self.cache_config.block_size num_kv_heads = 
self.model_config.get_num_kv_heads(self.parallel_config) head_size = self.model_config.get_head_size() - kv_cache_shape = (None, block_size, num_kv_heads, head_size) + # FormatCast op is used to convert kv_cache to nz format for 310p. + # The kv shape is flattened to 3 dimensions, + # because FormatCast op only support 3d input. + kv_cache_shape = (None, block_size, num_kv_heads * head_size) \ + if is_310p() else (None, block_size, num_kv_heads, head_size) kv_cache_dtype = (self.model_config.dtype if self.cache_config.cache_dtype == "auto" else diff --git a/vllm_mindspore/platforms/ascend.py b/vllm_mindspore/platforms/ascend.py index 06ec2f04..57d99732 100644 --- a/vllm_mindspore/platforms/ascend.py +++ b/vllm_mindspore/platforms/ascend.py @@ -26,6 +26,8 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.platforms.interface import Platform, PlatformEnum, _Backend +from vllm_mindspore.utils import is_310p + if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig else: @@ -191,3 +193,10 @@ class AscendPlatform(Platform): if ASCEND_QUANTIZATION_METHOD not in quant_action.choices: quant_action.choices.extend(ASCEND_QUANTIZATION_METHOD) logger.debug("--quantization support ascend/golden-stick.") + + @property + def supported_dtypes(self) -> list[torch.dtype]: + if is_310p(): + # bfloat16 is not supported on the 310p due to hardware limitations. + return [torch.float16, torch.float32] + return [torch.bfloat16, torch.float16, torch.float32] \ No newline at end of file diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py index 7cba47a2..edbbb6b2 100644 --- a/vllm_mindspore/utils.py +++ b/vllm_mindspore/utils.py @@ -93,6 +93,24 @@ def create_kv_cache(kv_shape, dtype, is_fa3_quant=False): return ms.mint.zeros(kv_shape, dtype=dtype) +def cast_weight_for_310p(loaded_weight): + """ + Casts weights to float16 for 310p. + + In non-quantized scenarios, the 310P hardware only supports float16 weights. + This function converts float32 or bfloat16 weights to float16. 
+ """ + cast_weight = (loaded_weight.astype(np.float16) if + (str(loaded_weight.dtype) == "float32" or str( + loaded_weight.dtype) == "bfloat16") else loaded_weight) + return cast_weight + + +def set_weight_format_to_nz(param): + cast_weight = ms.ops.auto_generate.format_cast(param, FORMAT_TYPE['nz']) + param.set_data(cast_weight) + + def get_valid_dtype(dtype): if isinstance(dtype, str): dtype = STR_DTYPE_TO_MS_DTYPE[dtype] diff --git a/vllm_mindspore/worker/cache_engine.py b/vllm_mindspore/worker/cache_engine.py index b57b8833..2ff204c6 100644 --- a/vllm_mindspore/worker/cache_engine.py +++ b/vllm_mindspore/worker/cache_engine.py @@ -26,13 +26,16 @@ from mindspore import mutable, mint from typing import List from vllm.logger import init_logger -from vllm_mindspore.utils import MsKVCache, get_valid_dtype +from vllm_mindspore.utils import MsKVCache, get_valid_dtype, create_kv_cache logger = init_logger(__name__) def create_block(shape, dtype, name=None, device=None): - blocks = mint.empty(shape, dtype=dtype, device=device) + if device == "CPU": + blocks = mint.empty(shape, dtype=dtype, device=device) + else: + blocks = create_kv_cache(shape, dtype) return blocks diff --git a/vllm_mindspore/worker/model_runner.py b/vllm_mindspore/worker/model_runner.py index 7fd89fc5..8f1aba25 100644 --- a/vllm_mindspore/worker/model_runner.py +++ b/vllm_mindspore/worker/model_runner.py @@ -28,7 +28,7 @@ from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams from vllm.sequence import SequenceGroupMetadata -from vllm_mindspore.utils import STR_DTYPE_TO_TENSOR_DTYPE +from vllm_mindspore.utils import STR_DTYPE_TO_TENSOR_DTYPE, create_kv_cache logger = init_logger(__name__) @@ -142,17 +142,10 @@ def _dummy_run(self, head_size = self.model_config.get_head_size() kv_shape = [0, block_size, num_kv_heads, head_size] kv_caches = mutable([ - mutable( - ( - mutable( - torch.tensor([], - dtype=kv_cache_dtype, - device=self.device).reshape(kv_shape)), - mutable( - torch.tensor([], - dtype=kv_cache_dtype, - device=self.device).reshape(kv_shape)), - )) for _ in range(num_layers) + mutable(( + mutable(create_kv_cache(kv_shape, kv_cache_dtype)), + mutable(create_kv_cache(kv_shape, kv_cache_dtype)), + )) for _ in range(num_layers) ]) finished_requests_ids = [seq.request_id for seq in seqs] model_input = self.prepare_model_input( -- Gitee From 47d53b5f1311edd1c6f64c5caedda6d854e1661d Mon Sep 17 00:00:00 2001 From: superxf Date: Tue, 11 Nov 2025 15:45:38 +0800 Subject: [PATCH 3/8] add maintainer --- CODEOWNERS | 80 +++++++++++++++++++++++++++--------------------------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/CODEOWNERS b/CODEOWNERS index 7dcf1731..c561c853 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -7,59 +7,59 @@ # The rules are applied from top to bottom, and the last matching rule will take effect. # If the file in PR does not match rule, the default maintainer erpim will be designated to conduct code review. -* @erpim -.* @erpim +* @zhengzuohe +.* @zhengzuohe # The maintainers of vllm_mindspore/. 
-/vllm_mindspore/ @erpim @tronzhang -/vllm_mindspore/attention/ @erpim @wang_shaocong -/vllm_mindspore/distributed/ @zlq2020 @tronzhang -/vllm_mindspore/engine/ @zlq2020 @tronzhang -/vllm_mindspore/entrypoints/ @panshaowu @zlq2020 -/vllm_mindspore/executor/ @zlq2020 @tronzhang -/vllm_mindspore/inputs/ @zlq2020 @tronzhang -/vllm_mindspore/lora/ @erpim @wang_shaocong -/vllm_mindspore/model_executor/ @zlq2020 @tronzhang -/vllm_mindspore/model_executor/layers/ @zlq2020 @tronzhang -/vllm_mindspore/model_executor/layers/quantization/ @erpim @hangangqiang -/vllm_mindspore/model_executor/model_loader/ @zlq2020 @tronzhang -/vllm_mindspore/model_executor/models/ @erpim @wang_shaocong -/vllm_mindspore/model_executor/models/mf_models/ @erpim @wang_shaocong -/vllm_mindspore/model_executor/models/mindone_models/ @erpim @wang_shaocong -/vllm_mindspore/multimodal/ @erpim @wang_shaocong -/vllm_mindspore/platforms/ @zlq2020 @tronzhang -/vllm_mindspore/v1/ @zlq2020 @tronzhang -/vllm_mindspore/v1/attention/ @erpim @wang_shaocong -/vllm_mindspore/v1/attention/backends/ @erpim @wang_shaocong -/vllm_mindspore/v1/core/ @zlq2020 @r1chardf1d0 -/vllm_mindspore/v1/engine/ @zlq2020 @tronzhang -/vllm_mindspore/v1/executor/ @zlq2020 @tronzhang -/vllm_mindspore/v1/sample/ @zlq2020 @erpim -/vllm_mindspore/v1/worker/ @zlq2020 @r1chardf1d0 -/vllm_mindspore/worker/ @zlq2020 @tronzhang +/vllm_mindspore/ @zhengzuohe +/vllm_mindspore/attention/ @zhengzuohe +/vllm_mindspore/distributed/ @zhengzuohe +/vllm_mindspore/engine/ @zhengzuohe +/vllm_mindspore/entrypoints/ @zhengzuohe +/vllm_mindspore/executor/ @zhengzuohe +/vllm_mindspore/inputs/ @zhengzuohe +/vllm_mindspore/lora/ @zhengzuohe +/vllm_mindspore/model_executor/ @zhengzuohe +/vllm_mindspore/model_executor/layers/ @zhengzuohe +/vllm_mindspore/model_executor/layers/quantization/ @zhengzuohe +/vllm_mindspore/model_executor/model_loader/ @zhengzuohe +/vllm_mindspore/model_executor/models/ @zhengzuohe +/vllm_mindspore/model_executor/models/mf_models/ @zhengzuohe +/vllm_mindspore/model_executor/models/mindone_models/ @zhengzuohe +/vllm_mindspore/multimodal/ @zhengzuohe +/vllm_mindspore/platforms/ @zhengzuohe +/vllm_mindspore/v1/ @zhengzuohe +/vllm_mindspore/v1/attention/ @zhengzuohe +/vllm_mindspore/v1/attention/backends/ @zhengzuohe +/vllm_mindspore/v1/core/ @zhengzuohe +/vllm_mindspore/v1/engine/ @zhengzuohe +/vllm_mindspore/v1/executor/ @zhengzuohe +/vllm_mindspore/v1/sample/ @zhengzuohe +/vllm_mindspore/v1/worker/ @zhengzuohe +/vllm_mindspore/worker/ @zhengzuohe # The maintainers of tests/. -/tests/ @erpim @wang_shaocong +/tests/ @zhengzuohe # The maintainers of examples/. -/examples/ @panshaowu +/examples/ @zhengzuohe # The maintainers of dashboard/. -/dashboard/ @r1chardf1d0 +/dashboard/ @zhengzuohe # The maintainers of csrc/. -/csrc/ @dayschan +/csrc/ @zhengzuohe # The maintainers of codecheck_toolkits/. -/codecheck_toolkits/ @erpim +/codecheck_toolkits/ @zhengzuohe # The maintainers of .gitee/. -/.gitee/ @TrHan +/.gitee/ @zhengzuohe # The maintainers of other code directories. 
-/requirements.txt @tronzhang -/setup.py @tronzhang -/install_depend_pkgs.sh @tronzhang -/build_image.sh @tronzhang -/Dockerfile @tronzhang -/LICENSE @tronzhang +/requirements.txt @zhengzuohe +/setup.py @zhengzuohe +/install_depend_pkgs.sh @zhengzuohe +/build_image.sh @zhengzuohe +/Dockerfile @zhengzuohe +/LICENSE @zhengzuohe -- Gitee From fcac8e2a95cc5e5bf8b28bc1c3d88ad7b426eb82 Mon Sep 17 00:00:00 2001 From: huangzhuo Date: Wed, 24 Sep 2025 09:46:41 +0800 Subject: [PATCH 4/8] Multi-LORA support Graph Mode --- vllm_mindspore/__init__.py | 5 +- vllm_mindspore/lora/layers.py | 315 ++++++++++-------- vllm_mindspore/lora/models.py | 7 +- .../lora/punica_wrapper/punica_npu.py | 97 +++++- vllm_mindspore/lora/utils.py | 15 + .../model_executor/models/model_base.py | 34 +- vllm_mindspore/model_executor/utils.py | 6 +- vllm_mindspore/platforms/ascend.py | 7 +- .../v1/attention/backends/ms_attn.py | 30 ++ vllm_mindspore/v1/worker/gpu_worker.py | 21 ++ vllm_mindspore/worker/model_runner.py | 16 +- vllm_mindspore/worker/worker.py | 30 +- 12 files changed, 427 insertions(+), 156 deletions(-) diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index 7129aa91..2e200672 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -127,10 +127,11 @@ vllm.utils.memory_profiling = ms_memory_profiling import vllm.lora.utils from vllm_mindspore.model_executor.layers.linear import LinearBase -from vllm_mindspore.lora.utils import _all_lora_classes +from vllm_mindspore.lora.utils import _all_lora_classes, replace_submodule vllm.lora.utils._all_lora_classes = _all_lora_classes vllm.lora.utils.LinearBase = LinearBase +vllm.lora.utils.replace_submodule = replace_submodule import vllm.lora.models from vllm_mindspore.lora.models import ( @@ -233,12 +234,14 @@ V0Worker.init_device = wrapper_worker_init_device(V0Worker.init_device) from vllm_mindspore.worker.model_runner import ( _get_cuda_graph_pad_size, _dummy_run, + profile_run, _get_supported_attention_backends, ) vllm.worker.model_runner.ModelInputForGPUBuilder._get_cuda_graph_pad_size = ( _get_cuda_graph_pad_size) vllm.worker.model_runner.GPUModelRunnerBase._dummy_run = _dummy_run +vllm.worker.model_runner.GPUModelRunnerBase.profile_run = profile_run import vllm.worker.multi_step_model_runner diff --git a/vllm_mindspore/lora/layers.py b/vllm_mindspore/lora/layers.py index d3f4b367..45d922dd 100644 --- a/vllm_mindspore/lora/layers.py +++ b/vllm_mindspore/lora/layers.py @@ -21,18 +21,18 @@ import math from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, Union, cast +from typing import TYPE_CHECKING, Optional, Union import mindspore as ms -from mindspore import mint +from mindspore import Parameter, Tensor, mint, ops +from mindspore.common.initializer import initializer from transformers import PretrainedConfig from vllm.adapter_commons.layers import AdapterMapping from vllm.config import LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, - tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce) + tensor_model_parallel_all_gather) from vllm.distributed.utils import divide from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.rotary_embedding import ( @@ -43,6 +43,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm_mindspore.model_executor.layers.linear import ( ColumnParallelLinear, LinearBase, MergedColumnParallelLinear, 
QKVParallelLinear, RowParallelLinear) +from vllm_mindspore.utils import FORMAT_TYPE, is_310p if TYPE_CHECKING: from vllm.lora.punica_wrapper import PunicaWrapperBase @@ -320,49 +321,42 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): self.output_size, self.tp_size)) else: raise NotImplementedError - - self.lora_a_stacked = tuple( - mint.zeros( - ( - max_loras, - 1, - lora_a_out_size, - self.input_size, - ), - dtype=lora_config.lora_dtype, - ) for _ in range(self.n_slices)) - self.lora_b_stacked = tuple( - mint.zeros( - ( - max_loras, - 1, - lora_b_out_size, - lora_config.max_lora_rank, - ), - dtype=lora_config.lora_dtype, - ) for _ in range(self.n_slices)) + self.lora_a_stacked = Parameter( + initializer('zeros', + (max_loras + 1, self.input_size, lora_a_out_size), + lora_config.lora_dtype)) + self.lora_b_stacked = Parameter( + initializer('zeros', + (max_loras + 1, lora_a_out_size, lora_b_out_size), + lora_config.lora_dtype)) if lora_config.bias_enabled: lora_bias_out_size = lora_b_out_size - self.lora_bias_stacked = tuple( - mint.zeros( - ( - max_loras, - 1, - lora_bias_out_size, - ), - dtype=lora_config.lora_dtype, - ) for _ in range(self.n_slices)) - self.output_slices = (self.lora_b_stacked[0].shape[2], ) + self.lora_bias_stacked = Parameter( + initializer('zeros', (max_loras + 1, lora_bias_out_size), + lora_config.lora_dtype)) + else: + self.lora_bias_stacked = None def reset_lora(self, index: int): - for s_index in range(self.n_slices): - self.lora_a_stacked[s_index][index] = 0 - self.lora_b_stacked[s_index][index] = 0 - if self.lora_config.bias_enabled: - # Make mypy happy - self.lora_bias_stacked = cast(tuple[ms.Tensor, ...], - self.lora_bias_stacked) - self.lora_bias_stacked[s_index][index] = 0 + tmp_lora_a = self.lora_a_stacked.value() + tmp_lora_b = self.lora_b_stacked.value() + tmp_lora_a[index + 1] = 0 + tmp_lora_b[index + 1] = 0 + if is_310p(): + tmp_lora_a = ops.auto_generate.format_cast(tmp_lora_a, + FORMAT_TYPE['nz']) + tmp_lora_b = ops.auto_generate.format_cast(tmp_lora_b, + FORMAT_TYPE['nz']) + self.lora_a_stacked.set_data(tmp_lora_a) + self.lora_b_stacked.set_data(tmp_lora_b) + if self.lora_bias_stacked: + # Make mypy happy + tmp_lora_bias = self.lora_bias_stacked.value() + tmp_lora_bias[index + 1] = 0 + tmp_lora_bias = ops.auto_generate.format_cast( + tmp_lora_bias, FORMAT_TYPE['nz']) + self.lora_bias_stacked.set_data( + Tensor(tmp_lora_bias, dtype=self.lora_bias_stacked.dtype)) def set_lora( self, @@ -376,8 +370,6 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers # store weights in a tuple of size 1. These two layers will # override this function. 
- assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) == - self.n_slices == 1) self.reset_lora(index) if self.tp_size > 1: @@ -386,28 +378,34 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): if lora_bias is not None: lora_bias = self.slice_bias(lora_bias) - self.lora_a_stacked[0][index, - 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( - lora_a.T, non_blocking=True) - self.lora_b_stacked[0][index, - 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( - lora_b.T, non_blocking=True) - if lora_bias is not None: - - self.lora_bias_stacked = cast(tuple[ms.Tensor, ...], - self.lora_bias_stacked) + tmp_lora_a = self.lora_a_stacked.value() + tmp_lora_b = self.lora_b_stacked.value() + tmp_lora_a[index + 1, :, :lora_a.shape[1]] = lora_a + tmp_lora_b[index + 1, :lora_b.shape[0], :] = lora_b + if is_310p(): + tmp_lora_a = ops.auto_generate.format_cast(tmp_lora_a, + FORMAT_TYPE['nz']) + tmp_lora_b = ops.auto_generate.format_cast(tmp_lora_b, + FORMAT_TYPE['nz']) + self.lora_a_stacked.set_data(tmp_lora_a) + self.lora_b_stacked.set_data(tmp_lora_b) + + if self.lora_bias_stacked is not None: assert len(self.lora_bias_stacked) - self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_( - lora_bias.T, non_blocking=True) + tmp_lora_bias = self.lora_bias_stacked.value() + tmp_lora_bias[index + 1] = lora_bias + if is_310p(): + tmp_lora_bias = ops.auto_generate.format_cast( + tmp_lora_bias, FORMAT_TYPE['nz']) + self.lora_bias_stacked.set_data(tmp_lora_bias) def apply(self, x: ms.Tensor, bias: Optional[ms.Tensor] = None) -> ms.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked, - self.lora_b_stacked, - self.lora_bias_stacked, 1.0, - self.output_slices) + output = self.punica_wrapper(output, x, self.lora_a_stacked, + self.lora_b_stacked, + self.lora_bias_stacked, 1.0) return output @@ -540,7 +538,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): model_config: Optional[PretrainedConfig] = None, ) -> None: """ - The main reason for overriding this function is to enhance code + The main reason for overriding this function is to enhance code maintainability. 
""" self.lora_config = lora_config @@ -548,36 +546,22 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): lora_a_output_size_per_partition = ( lora_config.max_lora_rank if not lora_config.fully_sharded_loras else divide(lora_config.max_lora_rank, self.tp_size)) - self.lora_a_stacked = tuple( - mint.zeros( - ( - max_loras, - 1, - lora_a_output_size_per_partition, - self.input_size, - ), - dtype=lora_config.lora_dtype, - ) for _ in range(self.n_slices)) - self.lora_b_stacked = tuple( - mint.zeros( - ( - max_loras, - 1, - output_size, - lora_config.max_lora_rank, - ), - dtype=lora_config.lora_dtype, - ) for output_size in self.output_slices) + output_size = sum(self.output_slices) + self.lora_a_stacked = Parameter( + initializer('zeros', + (max_loras + 1, self.input_size, + lora_a_output_size_per_partition * self.n_slices), + lora_config.lora_dtype)) + self.lora_b_stacked = Parameter( + initializer('zeros', + (max_loras + 1, lora_a_output_size_per_partition * + self.n_slices, output_size), lora_config.lora_dtype)) if lora_config.bias_enabled: - self.lora_bias_stacked = tuple( - mint.zeros( - ( - max_loras, - 1, - output_size, - ), - dtype=lora_config.lora_dtype, - ) for output_size in self.output_slices) + self.lora_bias_stacked = Parameter( + initializer('zeros', (max_loras + 1, output_size), + lora_config.lora_dtype)) + else: + self.lora_bias_stacked = None def slice_lora_a( self, lora_a: list[Union[ms.Tensor, @@ -619,26 +603,36 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): lora_b = self.slice_lora_b(lora_b) if lora_bias is not None: lora_bias = self.slice_bias(lora_bias) - - for i in range(self.n_slices): - if (lora_a_i := lora_a[i]) is not None: - self.lora_a_stacked[i][ - index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_( - lora_a_i.T, non_blocking=True) - if (lora_b_i := lora_b[i]) is not None: - self.lora_b_stacked[i][ - index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_( - lora_b_i.T, non_blocking=True) - - if lora_bias is not None: - self.lora_bias_stacked = cast(tuple[ms.Tensor, ...], - self.lora_bias_stacked) - for i in range(self.n_slices): - if (lora_bias_i := lora_bias[i]) is not None: - self.lora_bias_stacked[i][index, - 0, :lora_bias_i.shape[0]].copy_( - lora_bias_i.T, - non_blocking=True) + tmp_lora_a = self.lora_a_stacked.value() + tmp_lora_b = self.lora_b_stacked.value() + tmp_lora_a[index + 1, :, :lora_a[0].shape[1]] = lora_a[0] + tmp_lora_a[ + index + 1, :, + self.lora_config.max_lora_rank:self.lora_config.max_lora_rank + + lora_a[1].shape[1]] = lora_a[1] + tmp_lora_b[index + + 1, :lora_b[0].shape[0], :lora_b[0].shape[1]] = lora_b[0] + tmp_lora_b[ + index + 1, + self.lora_config.max_lora_rank:self.lora_config.max_lora_rank + + lora_b[1].shape[0], lora_b[0].shape[1]:lora_b[0].shape[1] + + lora_b[1].shape[1]] = lora_b[1] + if is_310p(): + tmp_lora_a = ops.auto_generate.format_cast(tmp_lora_a, + FORMAT_TYPE['nz']) + tmp_lora_b = ops.auto_generate.format_cast(tmp_lora_b, + FORMAT_TYPE['nz']) + self.lora_a_stacked.set_data(tmp_lora_a) + self.lora_b_stacked.set_data(tmp_lora_b) + if self.lora_bias_stacked is not None: + assert len(self.lora_bias_stacked) + lora_bias = ops.concat(lora_bias, axis=0) + tmp_lora_bias = self.lora_bias_stacked.value() + tmp_lora_bias[index + 1] = lora_bias + if is_310p(): + tmp_lora_bias = ops.auto_generate.format_cast( + tmp_lora_bias, FORMAT_TYPE['nz']) + self.lora_bias_stacked.set_data(tmp_lora_bias) @classmethod @_not_fully_sharded_can_replace @@ -757,17 +751,59 @@ class 
MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA): self.kv_shard_id, ) - def create_lora_weights( + def set_lora( self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, - ) -> None: - """ - The main reason for overloading this function is to handle inconsistent - weight dimensions in qkv lora. - """ - super().create_lora_weights(max_loras, lora_config, model_config) + index: int, + lora_a: ms.Tensor, + lora_b: ms.Tensor, + embeddings_tensor: Optional[ms.Tensor], + lora_bias: Optional[ms.Tensor] = None, + ): + self.reset_lora(index) + + if self.tp_size > 1: + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + if lora_bias is not None: + lora_bias = self.slice_bias(lora_bias) + + tmp_lora_a = self.lora_a_stacked.value() + tmp_lora_b = self.lora_b_stacked.value() + tmp_lora_a[index + 1, :, :lora_a[0].shape[1]] = lora_a[0] + tmp_lora_a[ + index + 1, :, + self.lora_config.max_lora_rank:self.lora_config.max_lora_rank + + lora_a[1].shape[1]] = lora_a[1] + tmp_lora_a[index + 1, :, self.lora_config.max_lora_rank * + 2:self.lora_config.max_lora_rank * 2 + + lora_a[2].shape[1]] = lora_a[2] + tmp_lora_b[index + + 1, :lora_b[0].shape[0], :lora_b[0].shape[1]] = lora_b[0] + tmp_lora_b[ + index + 1, + self.lora_config.max_lora_rank:self.lora_config.max_lora_rank + + lora_b[1].shape[0], lora_b[0].shape[1]:lora_b[0].shape[1] + + lora_b[1].shape[1]] = lora_b[1] + tmp_lora_b[index + 1, self.lora_config.max_lora_rank * + 2:self.lora_config.max_lora_rank * 2 + lora_b[2].shape[0], + lora_b[0].shape[1] + lora_b[1].shape[1]:] = lora_b[2] + if is_310p(): + tmp_lora_a = ops.auto_generate.format_cast(tmp_lora_a, + FORMAT_TYPE['nz']) + tmp_lora_b = ops.auto_generate.format_cast(tmp_lora_b, + FORMAT_TYPE['nz']) + self.lora_a_stacked.set_data(tmp_lora_a) + self.lora_b_stacked.set_data(tmp_lora_b) + + if self.lora_bias_stacked is not None: + assert len(self.lora_bias_stacked) + lora_bias = ops.concat(lora_bias, axis=0) + tmp_lora_bias = self.lora_bias_stacked.value() + tmp_lora_bias[index + 1] = lora_bias + if is_310p(): + tmp_lora_bias = ops.auto_generate.format_cast( + tmp_lora_bias, FORMAT_TYPE['nz']) + self.lora_bias_stacked.set_data(tmp_lora_bias) @classmethod @_not_fully_sharded_can_replace @@ -836,7 +872,8 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): # Matrix multiply. 
output_parallel = self.apply(input_parallel) if self.base_layer.reduce_results and self.base_layer.tp_size > 1: - output_ = tensor_model_parallel_all_reduce(output_parallel) + output_ = self.base_layer.tensor_model_parallel_all_reduce( + output_parallel) else: output_ = output_parallel @@ -1012,24 +1049,22 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): return None if self.sharded_to_full_mapping_gpu is not None: - """ - Reindex full logits tensor to ensure 1:1 mapping between - index and token_id - Example for: - org_vocab_size = 4 - added_vocab_size = 2 - pad_to_size = 8 - tp_size = 2 - - indices: [0, 1, 2, 3, 4, 5, 6, 7] - token_id: [0, 1, 4, -1, 2, 3, 5, -1] - - Therefore, the mapping is expected to be: - [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex, - we get: - indices: [0, 1, 2, 3, 4, 5, 6, 7] - token_id: [0, 1, 2, 3, 4, 5, -1, -1] - """ + # Reindex full logits tensor to ensure 1:1 mapping between + # index and token_id + # Example for: + # org_vocab_size: 4 # noqa: ERA001 + # added_vocab_size: 2 # noqa: ERA001 + # pad_to_size: 8 # noqa: ERA001 + # tp_size: 2 # noqa: ERA001 + + # indices: [0, 1, 2, 3, 4, 5, 6, 7] # noqa: ERA001 + # token_id: [0, 1, 4, -1, 2, 3, 5, -1] # noqa: ERA001 + + # Therefore, the mapping is expected to be: + # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex, + # we get: + # indices: [0, 1, 2, 3, 4, 5, 6, 7] # noqa: ERA001 + # token_id: [0, 1, 2, 3, 4, 5, -1, -1] # noqa: ERA001 logits = logits[:, self.sharded_to_full_mapping_gpu] lora_logits = mint.empty( diff --git a/vllm_mindspore/lora/models.py b/vllm_mindspore/lora/models.py index 7a687019..f433c5c0 100644 --- a/vllm_mindspore/lora/models.py +++ b/vllm_mindspore/lora/models.py @@ -23,6 +23,7 @@ import os from typing import Optional, Union import mindspore as ms +import numpy as np import safetensors.torch from mindspore import mint from vllm.lora.lora import LoRALayerWeights @@ -33,6 +34,7 @@ from vllm.model_executor.models.utils import WeightsMapper from vllm.utils import is_pin_memory_available from vllm_mindspore.lora.layers import BaseLayerWithLoRA +from vllm_mindspore.utils import is_310p _GLOBAL_LORA_ID = 0 @@ -198,7 +200,10 @@ def from_local_checkpoint( check_unexpected_modules(f) for module in f.keys(): # noqa # vllm-mindspore add numpy to tensor - tensors[module] = mint.Tensor(f.get_tensor(module)) + np_data = f.get_tensor(module) + if is_310p() and str(np_data.dtype) == "bfloat16": + np_data = np_data.astype(np.float32).astype(np.float16) + tensors[module] = mint.Tensor(np_data) elif os.path.isfile(lora_bin_file_path): # When a bin file is provided, we rely on config to find unexpected # modules. 
diff --git a/vllm_mindspore/lora/punica_wrapper/punica_npu.py b/vllm_mindspore/lora/punica_wrapper/punica_npu.py index 0a60baf2..70808f69 100644 --- a/vllm_mindspore/lora/punica_wrapper/punica_npu.py +++ b/vllm_mindspore/lora/punica_wrapper/punica_npu.py @@ -19,16 +19,25 @@ # isort: skip_file """Punica wrapper for NPU.""" -from typing import Callable +from typing import TYPE_CHECKING, Callable, Optional -from mindspore import mint +from mindspore import mint, nn, Parameter, ops, dtype from mindspore.common import dtype as mstype +from mindspore.common.initializer import initializer +from mindspore.ops.auto_generate import grouped_matmul_v4 from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase +from vllm_mindspore.model_executor.utils import (get_model_context, + set_model_context) from vllm_mindspore.lora.ops.torch_ops.lora_ops import ( bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand, sgmv_expand_slice, sgmv_shrink) +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.layers import LoRAMapping + from vllm.lora.models import LongContextLoRAContext + # The platforms that are compatible with the PyTorch-native implementation can # inherit this class @@ -357,3 +366,87 @@ class PunicaWrapperNPU(PunicaWrapperBase): self.sampler_indices, add_inputs=True) y.view_as(y_org) + + +class InferPunicaWrapperNPU(PunicaWrapperBase, nn.Cell): + """ + InferPunicaWrapperNPU is designed to manage and provide metadata for the + punica kernel. This function utilizes a grouped matrix multiplication + (groupedMatmul) approach to adapt to scenarios involving the concurrent + execution of multiple LoRA adapters, thereby enhancing the inference + performance for Multi-LoRA setups. + """ + + def __init__(self, max_num_batched_tokens, max_batches, device, **kwargs): + nn.Cell.__init__(self) + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, + device) + self.max_loras = kwargs["max_loras"] + self.group_list = Parameter(initializer("ones", self.max_loras + 1, + dtype.int64), + name="group_list") + + def sgmv_shrink( + self, + inputs, + lora_a_weights, + group_list, + scaling, + ): + outputs = grouped_matmul_v4([inputs], [lora_a_weights], + group_list=group_list, + split_item=3, + group_type=0, + group_list_type=1)[0] + return outputs + + def sgmv_expand_slice(self, y, inputs, lora_b_weights, group_list): + expand_outputs = grouped_matmul_v4([inputs], [lora_b_weights], + group_list=group_list, + split_item=3, + group_type=0, + group_list_type=1)[0] + outputs = ops.add(y, expand_outputs) + return outputs + + def update_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: list[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + **kwargs): + self._update_base_metadata(mapping, lora_index_to_id, max_loras, + vocab_size, extra_vocab_size, + long_lora_context) + # Update metadata required for prefill and decode operators. 
+ self._update_prefill_metadata(self.token_lora_indices) + if mapping.is_prefill: + self.is_prefill = True + else: + self.is_prefill = False + _, seq_len, lora_indices, _, _, _ = self.prefill_metadata + new_tensor = ops.zeros(self.max_loras + 1, dtype=self.group_list.dtype) + lora_indices = lora_indices + 1 + new_tensor[lora_indices] = seq_len + self.group_list.set_data(new_tensor.astype(dtype.int64)) + set_model_context("no_lora", self.no_lora) + + def construct(self, y, x, lora_a_stacked, lora_b_stacked, + lora_bias_stacked, scale): + if get_model_context("no_lora"): + return y + x = x.reshape(-1, x.shape[-1]) + orign_shape = y.shape + y = y.reshape(-1, y.shape[-1]) + if lora_bias_stacked is not None: + selected_loras_bias = lora_bias_stacked[self.token_lora_indices] + y = ops.add(y, selected_loras_bias) + shrink_outputs = self.sgmv_shrink(x, lora_a_stacked, self.group_list, + scale) + outputs = self.sgmv_expand_slice(y, shrink_outputs, lora_b_stacked, + self.group_list) + outputs = outputs.reshape(orign_shape) + return outputs diff --git a/vllm_mindspore/lora/utils.py b/vllm_mindspore/lora/utils.py index d9157467..8916a0dd 100644 --- a/vllm_mindspore/lora/utils.py +++ b/vllm_mindspore/lora/utils.py @@ -50,3 +50,18 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = { RowParallelLinearWithShardedLoRA, LinearScalingRotaryEmbeddingWithLoRA, } + + +def replace_submodule(model, module_name, new_module): + """Replace a submodule in a model with a new module.""" + parent = model.get_submodule(".".join(module_name.split(".")[:-1])) + target_name = module_name.split(".")[-1] + setattr(parent, target_name, new_module) + new_module.base_layer.weight.name = module_name + ".weight" + new_module.lora_a_stacked.name = module_name + ".lora_a_weight" + new_module.lora_b_stacked.name = module_name + ".lora_b_weight" + if new_module.base_layer.bias is not None: + new_module.base_layer.bias.name = module_name + ".bias" + if new_module.lora_bias_stacked is not None: + new_module.lora_bias_stacked.name = module_name + ".lora_bias" + return new_module diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index 31cbd374..b0019615 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -37,7 +37,8 @@ from vllm_mindspore.model_executor.models.interfaces import ( is_mixture_of_experts, supports_moe_dp_tp) from vllm_mindspore.model_executor.models.utils import (convert_pin, is_use_ringmla) -from vllm_mindspore.model_executor.utils import set_model_context +from vllm_mindspore.model_executor.utils import (get_model_context, + set_model_context) from vllm_mindspore.utils import (STR_DTYPE_TO_MS_DTYPE, create_kv_cache, is_310p) from vllm_mindspore.v1.attention.backends.ms_attn import MsAttentionMetadata @@ -383,12 +384,13 @@ class NativeModel(MsModelBase): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__(vllm_config=vllm_config, prefix=prefix) self.quant_config = vllm_config.quant_config - if vllm_config.lora_config is not None: - # native model lora only support pynative mode now - vllm_config.model_config.enforce_eager = True self.is_eager_mode = vllm_config.model_config.enforce_eager + if not self.is_eager_mode and vllm_config.lora_config is not None: + self.lora_prefill_graph = None + self.lora_decode_graph = None self.prefill_graph = None self.decode_graph = None + set_model_context("enforce_eager", self.is_eager_mode) if 
is_mixture_of_experts(self): ep_size = get_ep_group().world_size @@ -632,6 +634,30 @@ class NativeModel(MsModelBase): model_output = self.model(**model_inputs) return model_output + if not get_model_context("no_lora"): + # graph mode with lora + if is_prefill: + self.model.phase = "prefill_for_lora" + if self.lora_prefill_graph is None: + set_model_context("is_prefill", True) + self.model._set_jit_graph_name("prefill_for_lora") + self.set_model_inputs(input_ids, positions, + intermediate_tensors, inputs_embeds) + self.lora_prefill_graph = ms.jit(function=self.model, + jit_level="O0") + model_output = self.lora_prefill_graph(**model_inputs) + else: + self.model.phase = "increment_for_lora" + if self.lora_decode_graph is None: + set_model_context("is_prefill", False) + self.model._set_jit_graph_name("decode_for_lora") + self.set_model_inputs(input_ids, positions, + intermediate_tensors, inputs_embeds) + self.lora_decode_graph = ms.jit(function=self.model, + jit_level="O0") + model_output = self.lora_decode_graph(**model_inputs) + return model_output + # graph mode if is_prefill: self.model.phase = "prefill" diff --git a/vllm_mindspore/model_executor/utils.py b/vllm_mindspore/model_executor/utils.py index b7187afe..90383df3 100644 --- a/vllm_mindspore/model_executor/utils.py +++ b/vllm_mindspore/model_executor/utils.py @@ -33,7 +33,11 @@ def set_weight_attrs( setattr(weight, key, value) -_native_model_context = {"is_prefill": True} +_native_model_context = { + "is_prefill": True, + "no_lora": True, + "enforce_eager": False +} def set_model_context(key, value): diff --git a/vllm_mindspore/platforms/ascend.py b/vllm_mindspore/platforms/ascend.py index 57d99732..5b4f232a 100644 --- a/vllm_mindspore/platforms/ascend.py +++ b/vllm_mindspore/platforms/ascend.py @@ -33,6 +33,7 @@ if TYPE_CHECKING: else: ModelConfig = None VllmConfig = None +from vllm_mindspore.model_executor.utils import get_model_context logger = init_logger(__name__) @@ -72,7 +73,6 @@ class AscendPlatform(Platform): def check_and_update_config(cls, vllm_config: VllmConfig) -> None: parallel_config = vllm_config.parallel_config scheduler_config = vllm_config.scheduler_config - model_config = vllm_config.model_config if parallel_config.worker_cls == "auto": if scheduler_config.is_multi_step: @@ -168,7 +168,10 @@ class AscendPlatform(Platform): return True def get_punica_wrapper(cls) -> str: - return "vllm_mindspore.lora.punica_wrapper.punica_npu.PunicaWrapperNPU" + if get_model_context("enforce_eager"): + return "vllm_mindspore.lora.punica_wrapper.punica_npu.PunicaWrapperNPU" # noqa E501 + else: + return "vllm_mindspore.lora.punica_wrapper.punica_npu.InferPunicaWrapperNPU" # noqa E501 @classmethod def use_all_gather(cls) -> bool: diff --git a/vllm_mindspore/v1/attention/backends/ms_attn.py b/vllm_mindspore/v1/attention/backends/ms_attn.py index d3ff6c8b..1ee1a2ac 100644 --- a/vllm_mindspore/v1/attention/backends/ms_attn.py +++ b/vllm_mindspore/v1/attention/backends/ms_attn.py @@ -183,6 +183,36 @@ class MsAttentionMetadataBuilder: self.block_table = block_table def reorder_batch(self, input_batch, scheduler_output) -> bool: + if len(input_batch.lora_id_to_request_ids): + req_lora_ids = {} + lora_ids = [] + # sort by lora_id + for lora_id, requests in input_batch.lora_id_to_request_ids.items( + ): + for request in requests: + req_lora_ids[request] = lora_id + for req_id in input_batch._req_ids: + if req_id not in req_lora_ids: + lora_ids.append(-1) + else: + lora_ids.append(req_lora_ids[req_id]) + self.is_sort = not np.all( + 
np.array(lora_ids[:-1]) <= np.array(lora_ids[1:])) + if self.is_sort: + lora_id_sort = np.argsort(lora_ids) + cur_req_index = list(range(len(input_batch._req_ids))) + req_nums = len(cur_req_index) + + for i in range(req_nums): + while cur_req_index[i] != lora_id_sort[i]: + for j in range(i + 1, req_nums): + if cur_req_index[j] == lora_id_sort[i]: + cur_req_index[i], cur_req_index[ + j] = cur_req_index[j], cur_req_index[i] + # swap inputs information of the two requests + input_batch.swap_states(i, j) + break + return True return False def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, diff --git a/vllm_mindspore/v1/worker/gpu_worker.py b/vllm_mindspore/v1/worker/gpu_worker.py index 3e2e4537..5137e775 100644 --- a/vllm_mindspore/v1/worker/gpu_worker.py +++ b/vllm_mindspore/v1/worker/gpu_worker.py @@ -22,6 +22,8 @@ from vllm.distributed.parallel_state import get_pp_group from vllm.logger import init_logger +from vllm_mindspore.model_executor.utils import set_model_context + logger = init_logger(__name__) @@ -41,3 +43,22 @@ def compile_or_warm_up_model(self) -> None: self.model_runner._dummy_run(num_tokens=default_max_num_reqs)) else: self.model_runner._dummy_run(num_tokens=default_max_num_reqs) + + # Compile graph and warm up for base model when using lora. + # Optimize the mixed execution of LoRA and the base model by avoiding + # redundant graph recompilation, thus improving first-inference latency. + if self.model_runner.lora_config is not None: + set_model_context("no_lora", True) + self.model_runner.model.has_prefill_warmup = False + if get_pp_group().is_last_rank: + # prefill for base model + self.model_runner._dummy_sampler_run( + self.model_runner._dummy_run(num_tokens=default_max_num_reqs)) + # decode for base model + self.model_runner._dummy_sampler_run( + self.model_runner._dummy_run(num_tokens=default_max_num_reqs)) + else: + # prefill for base model + self.model_runner._dummy_run(num_tokens=default_max_num_reqs) + # decode for base model + self.model_runner._dummy_run(num_tokens=default_max_num_reqs) \ No newline at end of file diff --git a/vllm_mindspore/worker/model_runner.py b/vllm_mindspore/worker/model_runner.py index 8f1aba25..a1cd5f3a 100644 --- a/vllm_mindspore/worker/model_runner.py +++ b/vllm_mindspore/worker/model_runner.py @@ -43,9 +43,19 @@ def _get_cuda_graph_pad_size(self, return -1 +def profile_run(self) -> None: + max_num_batched_tokens = \ + self.scheduler_config.max_num_batched_tokens + max_num_seqs = self.scheduler_config.max_num_seqs + self._dummy_run(max_num_batched_tokens, max_num_seqs) + if self.lora_config: + self._dummy_run(max_num_batched_tokens, max_num_seqs, True) + + def _dummy_run(self, max_num_batched_tokens: int, - max_num_seqs: int = 1) -> None: + max_num_seqs: int = 1, + use_lora: bool = False) -> None: with self.set_in_profile_run(): # Enable top-k sampling to reflect the accurate memory usage. sampling_params = \ @@ -57,7 +67,7 @@ def _dummy_run(self, # passed in, which contains a lora from the lora warmup path. dummy_lora_requests: List[LoRARequest] = [] dummy_lora_requests_per_seq: List[LoRARequest] = [] - if self.lora_config: + if use_lora: assert self.lora_manager is not None with self.lora_manager.dummy_lora_cache(): for idx in range(self.lora_config.max_loras): @@ -164,7 +174,7 @@ def _dummy_run(self, self.execute_model(model_input, kv_caches, intermediate_tensors) torch.cuda.synchronize() - if self.lora_config: + if use_lora: # Remove dummy loras. 
assert self.lora_manager is not None self.remove_all_loras() diff --git a/vllm_mindspore/worker/worker.py b/vllm_mindspore/worker/worker.py index 66afa92c..0a9d60a9 100644 --- a/vllm_mindspore/worker/worker.py +++ b/vllm_mindspore/worker/worker.py @@ -18,10 +18,12 @@ import math import os import subprocess +from typing import List import psutil import torch from vllm.logger import init_logger +from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.sampling_params import SamplingParams from vllm.sequence import SequenceGroupMetadata @@ -165,7 +167,8 @@ def _prepare_input_for_warmup(model_config, model_runner, cache_engine, is_prefill, - is_mtp_model=False): + is_mtp_model=False, + use_lora=False): bs = 1 seq_len = model_runner.scheduler_config.max_num_batched_tokens \ if is_prefill else 1 @@ -175,6 +178,19 @@ def _prepare_input_for_warmup(model_config, i for i in range(math.ceil(seq_len / cache_engine.block_size)) ] + dummy_lora_requests: List[LoRARequest] = [] + if use_lora: + assert model_runner.lora_manager is not None + LORA_WARMUP_RANK = 8 + dummy_lora_request = LoRARequest( + lora_name="warmup_for_decode", + lora_int_id=1, + lora_path="/not/a/real/path", + ) + model_runner.lora_manager.add_dummy_lora(dummy_lora_request, + rank=LORA_WARMUP_RANK) + dummy_lora_requests = dummy_lora_request + # adapter multi modal warm up seq_data = dummy_data.seq_data if seq_len == 1: @@ -187,7 +203,7 @@ def _prepare_input_for_warmup(model_config, seq_data={idx: seq_data}, sampling_params=SamplingParams(), block_tables={idx: block_tables_num}, - lora_request=None, + lora_request=dummy_lora_requests if use_lora else None, multi_modal_data=None, multi_modal_placeholders=None, ) for idx in range(bs) @@ -236,6 +252,16 @@ def _warm_up_model(self) -> None: torch.cuda.synchronize() + if self.vllm_config.lora_config is not None: + # warmup for lora decode + model_input, _ = _prepare_input_for_warmup(self.model_config, + self.model_runner, + self.cache_engine[0], False, + False, True) + self.model_runner.execute_model(model_input, kv_cache, None) + self.model_runner.remove_all_loras() + torch.cuda.synchronize() + # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. 
set_random_seed(self.model_config.seed) -- Gitee From 48abcb697bdb1617ccf9b44f2daccbd3a582284c Mon Sep 17 00:00:00 2001 From: huoxinyou Date: Fri, 7 Nov 2025 11:48:28 +0800 Subject: [PATCH 5/8] moe support 310p Signed-off-by: huoxinyou --- .../layers/fused_moe/fused_moe.py | 20 ++++++++++++++----- .../model_executor/layers/fused_moe/layer.py | 12 ++++++++--- .../model_executor/layers/linear.py | 4 ++-- .../layers/vocab_parallel_embedding.py | 4 ++-- .../model_loader/weight_utils.py | 14 +++++++++---- .../model_executor/models/qwen2_5_vl.py | 4 +++- 6 files changed, 41 insertions(+), 17 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py index 695c481f..00e43110 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py @@ -31,7 +31,11 @@ from mindspore.ops.auto_generate import (GroupedMatmulV4, MoeDistributeCombine, from vllm.distributed.parallel_state import get_ep_group from vllm_mindspore.model_executor.layers.fused_moe.config import MoeMode -from vllm_mindspore.utils import is_910b +from vllm_mindspore.utils import is_310p, is_910b + + +def softmax_score_function(x): + return mint.softmax(x, dim=-1, dtype=ms.float32) def fused_topk( @@ -41,10 +45,15 @@ def fused_topk( renormalize: bool, indices_type=None, ) -> tuple[Tensor, Tensor]: - moe_topk_softmax = MoeGatingTopKSoftmax() - topk_weights, topk_ids, _ = moe_topk_softmax(gating_output, None, topk) + if is_310p(): + scores = softmax_score_function(gating_output) + topk_weights, topk_ids = mint.topk(scores, k=topk, dim=-1) + else: + moe_topk_softmax = MoeGatingTopKSoftmax() + topk_weights, topk_ids, _ = moe_topk_softmax(gating_output, None, topk) if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + topk_weights = mint.div( + topk_weights, mint.add(mint.sum(topk_weights, -1, True), 1e-20)) if indices_type is not None: topk_ids = topk_ids.to(indices_type) @@ -171,7 +180,8 @@ class FusedExperts(nn.Cell): gate = self._gate_activation(gate, activation) hidden = mint.mul(hidden, gate) expert_output = self._group_matmul(hidden, w2, group_list) - expert_output = mint.nan_to_num(expert_output, 0, 0, 0) + if not is_310p(): + expert_output = mint.nan_to_num(expert_output, 0, 0, 0) return expert_output def run_tp(self, hidden_states, w1, w2, topk_ids, topk_weights, activation, diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index 4669352a..5cebcd44 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -41,7 +41,8 @@ from vllm_mindspore.model_executor.layers.fused_moe.fused_moe import ( from vllm_mindspore.model_executor.layers.quantization.base_config import ( QuantizeMethodBase) from vllm_mindspore.model_executor.model_loader.weight_utils import ( - split_loaded_weight) + convert_loaded_weight, split_loaded_weight) +from vllm_mindspore.utils import is_310p, set_weight_format_to_nz logger = init_logger(__name__) @@ -109,6 +110,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, nn.Cell): set_weight_attrs(w2_weight, extra_weight_attrs) set_weight_attrs(w2_weight, {"is_transposed": True}) + def process_weights_after_loading(self, layer): + if is_310p(): + set_weight_format_to_nz(layer.w13_weight) + set_weight_format_to_nz(layer.w2_weight) + def apply( self, layer: nn.Cell, @@ 
-508,7 +514,7 @@ class FusedMoE(nn.Cell): expert_id: int): is_param_transpose = param.is_transposed \ if hasattr(param, "is_transposed") else False - loaded_weight = loaded_weight[:] + loaded_weight = convert_loaded_weight(loaded_weight) if is_param_transpose: loaded_weight = from_numpy(loaded_weight.swapaxes(-1, -2)) else: @@ -528,7 +534,7 @@ class FusedMoE(nn.Cell): assert shard_id in ("w1", "w3") is_param_transpose = param.is_transposed \ if hasattr(param, "is_transposed") else False - loaded_weight = loaded_weight[:] + loaded_weight = convert_loaded_weight(loaded_weight) if is_param_transpose: loaded_weight = from_numpy(loaded_weight.swapaxes(-1, -2)) else: diff --git a/vllm_mindspore/model_executor/layers/linear.py b/vllm_mindspore/model_executor/layers/linear.py index 46ea2d7d..b22c3d5d 100644 --- a/vllm_mindspore/model_executor/layers/linear.py +++ b/vllm_mindspore/model_executor/layers/linear.py @@ -36,7 +36,7 @@ from vllm_mindspore.distributed.communication_op import ( from vllm_mindspore.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm_mindspore.model_executor.model_loader.weight_utils import ( - split_loaded_weight) + convert_loaded_weight, split_loaded_weight) from vllm_mindspore.model_executor.utils import set_weight_attrs from vllm_mindspore.utils import is_310p, set_weight_format_to_nz @@ -219,7 +219,7 @@ class ReplicatedLinear(LinearBase): self.bias = None def weight_loader(self, param: Parameter, loaded_weight: Tensor): - loaded_weight = loaded_weight[:] + loaded_weight = convert_loaded_weight(loaded_weight) if len(loaded_weight.shape) == 0: loaded_weight = loaded_weight.reshape(1) diff --git a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py index 0c678ca1..7f02dca5 100644 --- a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py @@ -36,7 +36,7 @@ from vllm_mindspore.distributed.communication_op import ( from vllm_mindspore.model_executor.layers.quantization.base_config import ( QuantizeMethodBase, method_has_implemented_embedding) from vllm_mindspore.model_executor.model_loader.weight_utils import ( - split_loaded_weight) + convert_loaded_weight, split_loaded_weight) from vllm_mindspore.model_executor.utils import set_weight_attrs from vllm_mindspore.utils import is_310p, set_weight_format_to_nz @@ -345,7 +345,7 @@ class VocabParallelEmbedding(nn.Cell): # If parameter does not have output dim, then it should # be copied onto all gpus (e.g. g_idx for act_order gptq). 
if output_dim is None: - loaded_weight = loaded_weight[:] + loaded_weight = convert_loaded_weight(loaded_weight) assert param.data.shape == loaded_weight.shape if param.data.shape != loaded_weight.shape: raise ValueError( diff --git a/vllm_mindspore/model_executor/model_loader/weight_utils.py b/vllm_mindspore/model_executor/model_loader/weight_utils.py index d409e80a..47d2a67f 100644 --- a/vllm_mindspore/model_executor/model_loader/weight_utils.py +++ b/vllm_mindspore/model_executor/model_loader/weight_utils.py @@ -31,6 +31,14 @@ from vllm.model_executor.model_loader.weight_utils import (_BAR_FORMAT, from vllm_mindspore.utils import cast_weight_for_310p, is_310p +def convert_loaded_weight(loaded_weight): + """Get all loaded_weight value and dtype conversion on 310p""" + loaded_weight = loaded_weight[:] + if is_310p(): + loaded_weight = cast_weight_for_310p(loaded_weight) + return loaded_weight + + def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): """ Read numpy slice data based on axis and slice range. @@ -40,9 +48,7 @@ def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): :shard_size: end slice index """ if shard_dim is None: - loaded_weight = loaded_weight[:] - if is_310p(): - loaded_weight = cast_weight_for_310p(loaded_weight) + loaded_weight = convert_loaded_weight(loaded_weight) return loaded_weight end_idx = start_idx + shard_size @@ -83,7 +89,7 @@ def safetensors_weights_iterator( def default_weight_loader(param: Parameter, loaded_weight: Any) -> None: """Default weight loader.""" - loaded_weight = loaded_weight[:] + loaded_weight = convert_loaded_weight(loaded_weight) if is_310p(): loaded_weight = cast_weight_for_310p(loaded_weight) param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype)) diff --git a/vllm_mindspore/model_executor/models/qwen2_5_vl.py b/vllm_mindspore/model_executor/models/qwen2_5_vl.py index 12589a9d..6f408297 100644 --- a/vllm_mindspore/model_executor/models/qwen2_5_vl.py +++ b/vllm_mindspore/model_executor/models/qwen2_5_vl.py @@ -90,6 +90,8 @@ from vllm_mindspore.model_executor.models.model_base import NativeModel, \ AttentionWrapper from vllm_mindspore.model_executor.models.qwen2 import Qwen2Model from vllm_mindspore.model_executor.models.utils import PPMissingLayer +from vllm_mindspore.model_executor.model_loader.weight_utils import \ + convert_loaded_weight from .interfaces import (SupportsMultiModal) from .utils import (WeightsMapper, maybe_prefix, merge_multimodal_embeddings) @@ -1059,7 +1061,7 @@ class Qwen2_5_VisionTransformer(nn.Cell): for name, loaded_weight in weights: param = params_dict[name] if name == "visual.patch_embed.proj.weight": - loaded_weight = loaded_weight[:] + loaded_weight = convert_loaded_weight(loaded_weight) loaded_weight = loaded_weight.reshape(loaded_weight.shape[0], -1) param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype)) -- Gitee From 35b31a125eef2eb2659be3543a9cdd0dcde7d8c5 Mon Sep 17 00:00:00 2001 From: superxf Date: Wed, 12 Nov 2025 09:53:41 +0800 Subject: [PATCH 6/8] opt memory --- vllm_mindspore/model_executor/layers/linear.py | 18 ++++++++++-------- .../layers/vocab_parallel_embedding.py | 8 +++++--- vllm_mindspore/model_executor/models/qwen2.py | 17 ++++++++++------- vllm_mindspore/model_executor/models/qwen3.py | 18 +++++++++++------- .../model_executor/models/qwen3_moe.py | 12 +++++++----- 5 files changed, 43 insertions(+), 30 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/linear.py b/vllm_mindspore/model_executor/layers/linear.py index 
46ea2d7d..5c28b8c2 100644 --- a/vllm_mindspore/model_executor/layers/linear.py +++ b/vllm_mindspore/model_executor/layers/linear.py @@ -26,6 +26,7 @@ import mindspore as ms import numpy as np from mindspore import Parameter, Tensor, mint, nn, ops from mindspore._c_expression.typing import Type as MSDtype +from mindspore.common.initializer import initializer from vllm.config import get_current_vllm_config from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -92,14 +93,10 @@ class UnquantizedLinearMethod(LinearMethodBase): def create_weights(self, layer: nn.Cell, input_size_per_partition: int, output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype, **extra_weight_attrs): - weight = Parameter( - mint.zeros( - (int(sum(output_partition_sizes)), - int(input_size_per_partition)), - dtype=params_dtype, - ), - requires_grad=False, - ) + weight_shape = (int(sum(output_partition_sizes)), + int(input_size_per_partition)) + weight = Parameter(initializer("zeros", weight_shape, params_dtype), + requires_grad=False) self.input_size_per_partition = int(input_size_per_partition) self.output_size_per_partition = int(sum(output_partition_sizes)) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) @@ -227,6 +224,7 @@ class ReplicatedLinear(LinearBase): f"Tried to load weights of size {loaded_weight.size()}" f"to a parameter of size {param.size()}") + param.init_data() param.set_data(ms.from_numpy(loaded_weight)) def construct( @@ -354,6 +352,7 @@ class ColumnParallelLinear(LinearBase): return output, output_bias def weight_loader(self, param, loaded_weight): + param.init_data() tp_rank = get_tensor_model_parallel_rank() output_dim = getattr(param, "output_dim", None) shard_size = self.output_size_per_partition @@ -421,6 +420,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): param, loaded_weight, loaded_shard_id: Optional[int] = None): + param.init_data() output_dim = getattr(param, "output_dim", None) tp_rank = get_tensor_model_parallel_rank() tp_size = get_tensor_model_parallel_world_size() @@ -529,6 +529,7 @@ class QKVParallelLinear(ColumnParallelLinear): param, loaded_weight, loaded_shard_id: Optional[str] = None): + param.init_data() output_dim = getattr(param, "output_dim", None) tp_rank = get_tensor_model_parallel_rank() @@ -704,6 +705,7 @@ class RowParallelLinear(LinearBase): return output, output_bias def weight_loader(self, param, loaded_weight): + param.init_data() tp_rank = get_tensor_model_parallel_rank() input_dim = getattr(param, "input_dim", None) shard_size = self.input_size_per_partition diff --git a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py index 0c678ca1..b7351ae9 100644 --- a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py @@ -25,6 +25,7 @@ from typing import Optional import mindspore as ms from mindspore import Parameter, Tensor, mint, nn, ops from mindspore.common.dtype import typing +from mindspore.common.initializer import initializer from vllm.config import get_current_vllm_config from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -50,9 +51,9 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype, **extra_weight_attrs): """Create weights for embedding layer.""" - 
weight = Parameter(mint.zeros( - (sum(output_partition_sizes), input_size_per_partition), - dtype=params_dtype), + weight_shape = (int(sum(output_partition_sizes)), + int(input_size_per_partition)) + weight = Parameter(initializer("zeros", weight_shape, params_dtype), requires_grad=False) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) layer.insert_param_to_cell("weight", weight) @@ -340,6 +341,7 @@ class VocabParallelEmbedding(nn.Cell): return output def weight_loader(self, param: Parameter, loaded_weight: Tensor): + param.init_data() output_dim = getattr(param, "output_dim", None) get_tensor_model_parallel_rank() # If parameter does not have output dim, then it should diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py index db8e31c0..7b388c32 100644 --- a/vllm_mindspore/model_executor/models/qwen2.py +++ b/vllm_mindspore/model_executor/models/qwen2.py @@ -36,6 +36,7 @@ else: Qwen2Config = None from mindspore import Parameter, Tensor, mint, nn +from mindspore.nn.utils import no_init_parameters from vllm.attention.backends.abstract import AttentionType from vllm.config import CacheConfig, VllmConfig @@ -447,18 +448,20 @@ class Qwen2ForCausalLM(NativeModel, SupportsLoRA): self.lora_config = lora_config self.quant_config = quant_config - self.model = Qwen2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + with no_init_parameters(): + self.model = Qwen2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) if get_pp_group().is_last_rank: if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) + with no_init_parameters(): + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) else: self.lm_head = PPMissingLayer() diff --git a/vllm_mindspore/model_executor/models/qwen3.py b/vllm_mindspore/model_executor/models/qwen3.py index 89a98b58..6c5fba3f 100644 --- a/vllm_mindspore/model_executor/models/qwen3.py +++ b/vllm_mindspore/model_executor/models/qwen3.py @@ -44,6 +44,7 @@ from collections.abc import Iterable from typing import Any, Optional, Union from mindspore import Tensor, nn +from mindspore.nn.utils import no_init_parameters from transformers import Qwen3Config from vllm.attention import AttentionType from vllm.config import CacheConfig, VllmConfig @@ -283,18 +284,21 @@ class Qwen3ForCausalLM(NativeModel): self.lora_config = lora_config self.quant_config = quant_config - self.model = Qwen3Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + + with no_init_parameters(): + self.model = Qwen3Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) if get_pp_group().is_last_rank: if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) + with no_init_parameters(): + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) else: self.lm_head = PPMissingLayer() diff --git a/vllm_mindspore/model_executor/models/qwen3_moe.py b/vllm_mindspore/model_executor/models/qwen3_moe.py index 5c0260e4..545bbf8e 100644 --- 
a/vllm_mindspore/model_executor/models/qwen3_moe.py +++ b/vllm_mindspore/model_executor/models/qwen3_moe.py @@ -30,6 +30,7 @@ from collections.abc import Iterable from typing import Any, Optional, Union from mindspore import Parameter, Tensor, nn +from mindspore.nn.utils import no_init_parameters from transformers import PretrainedConfig from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -530,11 +531,12 @@ class Qwen3MoeForCausalLM(NativeModel, SupportesMoeDpTp, MixtureOfExperts): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = Qwen3MoeModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) + with no_init_parameters(): + self.model = Qwen3MoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) -- Gitee From f4b5bf5751e89dbb9f4077732eb04479098cb454 Mon Sep 17 00:00:00 2001 From: superxf Date: Tue, 2 Sep 2025 17:30:31 +0800 Subject: [PATCH 7/8] opt penalties --- vllm_mindspore/model_executor/layers/utils.py | 63 +++++++++++++------ vllm_mindspore/v1/worker/gpu_input_batch.py | 41 +++++++----- 2 files changed, 69 insertions(+), 35 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/utils.py b/vllm_mindspore/model_executor/layers/utils.py index 3081182f..47580978 100644 --- a/vllm_mindspore/model_executor/layers/utils.py +++ b/vllm_mindspore/model_executor/layers/utils.py @@ -36,6 +36,23 @@ def get_token_bin_counts_and_mask( return bin_counts, mask +def get_repetition_penalties_mask( + prompt_tokens: ms.Tensor, + output_tokens: ms.Tensor, + vocab_size: int, + num_seqs: int, +) -> ms.Tensor: + # Compute the bin counts for the tokens. + # vocab_size + 1 for padding. 
+ bin_counts = ms.mint.zeros((num_seqs, vocab_size + 1), dtype=ms.int64) + bin_counts.scatter_add_(1, prompt_tokens, ms.mint.ones_like(prompt_tokens)) + bin_counts.scatter_add_(1, output_tokens, ms.mint.ones_like(output_tokens)) + bin_counts = bin_counts[:, :vocab_size] + mask = bin_counts > 0 + + return mask + + def apply_penalties(logits: ms.Tensor, prompt_tokens_tensor: ms.Tensor, output_tokens_tensor: ms.Tensor, presence_penalties: ms.Tensor, @@ -57,26 +74,34 @@ def apply_penalties(logits: ms.Tensor, prompt_tokens_tensor: ms.Tensor, if logits.numel() <= 0: return logits num_seqs, vocab_size = logits.shape - _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor, - vocab_size, num_seqs) - output_bin_counts, output_mask = get_token_bin_counts_and_mask( - output_tokens_tensor, vocab_size, num_seqs) - # use 'broadcast_to' to replace 'tensor.repeat' to improve performance - # when tensor shape is (num,seqs, 1), then 'tensor.repeat(1, vocab_size)' - # is equal to 'broadcast_to(tensor, (num_seqs, vocab_size))' - repetition_penalties = ms.mint.broadcast_to( - repetition_penalties.unsqueeze(dim=1), (num_seqs, vocab_size)) + if repetition_penalties is not None: + mask = get_repetition_penalties_mask( + prompt_tokens_tensor, + output_tokens_tensor, + vocab_size, + num_seqs, + ) + # use 'broadcast_to' to replace 'tensor.repeat' to improve performance + # when tensor shape is (num,seqs, 1), 'tensor.repeat(1, vocab_size)' + # is equal to 'broadcast_to(tensor, (num_seqs, vocab_size))' + repetition_penalties = ms.mint.broadcast_to( + repetition_penalties.unsqueeze(dim=1), (num_seqs, vocab_size)) - # use out of place computation instead of inplace setitem to improve - # performance 'tensor[tensor > 0]' will result in setitem, which is slow. - mask = prompt_mask | output_mask - logits = ms.mint.where(mask & (logits > 0), logits / repetition_penalties, - logits) - logits = ms.mint.where(mask & (logits <= 0), logits * repetition_penalties, - logits) + # use out of place computation instead of inplace setitem to improve + # performance 'tensor[tensor > 0]' will result in setitem, + # which is slow. + logits = ms.mint.where(mask & (logits > 0), + logits / repetition_penalties, logits) + logits = ms.mint.where(mask & (logits <= 0), + logits * repetition_penalties, logits) # We follow the definition in OpenAI API. 
- # Refer to https://platform.openai.com/docs/api-reference/parameter-details - logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts - logits -= presence_penalties.unsqueeze(dim=1) * output_mask + # Refer to https://platform.openai.com/docs/api-reference/parameter-details\ + if frequency_penalties is not None or presence_penalties is not None: + output_bin_counts, output_mask = get_token_bin_counts_and_mask( + output_tokens_tensor, vocab_size, num_seqs) + if frequency_penalties is not None: + logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts + if presence_penalties is not None: + logits -= presence_penalties.unsqueeze(dim=1) * output_mask return logits diff --git a/vllm_mindspore/v1/worker/gpu_input_batch.py b/vllm_mindspore/v1/worker/gpu_input_batch.py index f23dc4df..0ed0cb12 100644 --- a/vllm_mindspore/v1/worker/gpu_input_batch.py +++ b/vllm_mindspore/v1/worker/gpu_input_batch.py @@ -44,24 +44,33 @@ def _make_sampling_metadata(self) -> SamplingMetadata: _copy_slice_from_np(self.top_k_cpu, self.top_k, num_reqs) if not self.no_min_p: _copy_slice_from_np(self.min_p_cpu, self.min_p, num_reqs) - + frequency_penalties = None + presence_penalties = None + repetition_penalties = None + prompt_token_ids = None if not self.no_penalties: # Since syncing these tensors is expensive only copy them # if necessary i.e. if there are requests which require # penalties to be applied during sampling. - _copy_slice_from_np(self.frequency_penalties_cpu, - self.frequency_penalties, num_reqs) - _copy_slice_from_np(self.presence_penalties_cpu, - self.presence_penalties, num_reqs) - _copy_slice_from_np(self.repetition_penalties_cpu, - self.repetition_penalties, num_reqs) - - # The prompt tokens are used only for applying penalties during - # the sampling process. Hence copy these tensors only when - # there are requests which need penalties to be applied. 
+ apply_freq = not np.all(self.frequency_penalties_cpu[:num_reqs] == 0.0) + apply_pres = not np.all(self.presence_penalties_cpu[:num_reqs] == 0.0) + apply_rep = not np.all(self.repetition_penalties_cpu[:num_reqs] == 1.0) prompt_token_ids = self._make_prompt_token_ids_tensor() - else: - prompt_token_ids = None + + if apply_freq: + _copy_slice_from_np(self.frequency_penalties_cpu, + self.frequency_penalties, num_reqs) + frequency_penalties = self.frequency_penalties[:num_reqs] + + if apply_pres: + _copy_slice_from_np(self.presence_penalties_cpu, + self.presence_penalties, num_reqs) + presence_penalties = self.presence_penalties[:num_reqs] + + if apply_rep: + _copy_slice_from_np(self.repetition_penalties_cpu, + self.repetition_penalties, num_reqs) + repetition_penalties = self.repetition_penalties[:num_reqs] allowed_token_ids_mask: Optional[Tensor] = None if not self.no_allowed_token_ids: @@ -82,9 +91,9 @@ def _make_sampling_metadata(self) -> SamplingMetadata: generators=self.generators, max_num_logprobs=self.max_num_logprobs, prompt_token_ids=prompt_token_ids, - frequency_penalties=self.frequency_penalties[:num_reqs], - presence_penalties=self.presence_penalties[:num_reqs], - repetition_penalties=self.repetition_penalties[:num_reqs], + frequency_penalties=frequency_penalties, + presence_penalties=presence_penalties, + repetition_penalties=repetition_penalties, output_token_ids=cast(list[list[int]], self.req_output_token_ids), min_tokens=self.min_tokens, no_penalties=self.no_penalties, -- Gitee From 7b183a587e33219ccb18353df3837191d3da967c Mon Sep 17 00:00:00 2001 From: huangzhuo Date: Fri, 26 Sep 2025 09:52:44 +0800 Subject: [PATCH 8/8] native support a8w8/sparsequant --- tests/st/python/test_ds_online.py | 2 +- vllm_mindspore/__init__.py | 37 ++++- .../model_executor/layers/logits_processor.py | 6 +- .../layers/quantization/__init__.py | 48 ++++++ .../layers/quantization/base_config.py | 3 + .../quantization/smooth_quant_modelslim.py | 144 +++++++++++++++++- .../model_loader/default_loader.py | 100 ++++++++++++ .../model_executor/model_loader/utils.py | 26 ++++ .../model_loader/weight_utils.py | 107 ++++++++++++- vllm_mindspore/model_executor/models/qwen2.py | 59 ++++++- vllm_mindspore/platforms/ascend.py | 2 +- 11 files changed, 523 insertions(+), 11 deletions(-) create mode 100644 vllm_mindspore/model_executor/model_loader/default_loader.py diff --git a/tests/st/python/test_ds_online.py b/tests/st/python/test_ds_online.py index ed342ffc..8b29f9ce 100644 --- a/tests/st/python/test_ds_online.py +++ b/tests/st/python/test_ds_online.py @@ -31,7 +31,7 @@ from tests.st.python.utils.env_var_manager import EnvVarManager env_manager = EnvVarManager() env_manager.setup_mindformers_environment() env_vars = { - "vLLM_MODEL_BACKEND": "MindFormers", + "VLLM_MS_MODEL_BACKEND": "MindFormers", "MS_ENABLE_LCCL": "off", "HCCL_OP_EXPANSION_MODE": "AIV", "ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3,4,5,6,7", diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index 2e200672..31c93080 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -17,11 +17,13 @@ # isort:skip_file +import os import sys import warnings import msadapter # noqa: F401 from vllm_mindspore.ray_patch import patch_ray +ms_backend = os.environ.get("VLLM_MS_MODEL_BACKEND") patch_ray() if "vllm" in sys.modules: @@ -291,8 +293,9 @@ from .config import (_verify_quantization, _verify_args, vllm_config_post_init, vllm.config.ModelConfig._verify_quantization = _verify_quantization vllm.config.VllmConfig.__post_init__ = 
vllm_config_post_init -vllm.config.VllmConfig._get_quantization_config = staticmethod( - vllm_config_get_quantization_config) +if ms_backend == "MindFormers": + vllm.config.VllmConfig._get_quantization_config = staticmethod( + vllm_config_get_quantization_config) vllm.config.SchedulerConfig._verify_args = _verify_args vllm.config.CompilationConfig.model_post_init = model_post_init vllm.config._get_and_verify_dtype = _get_and_verify_dtype @@ -570,6 +573,36 @@ sys.modules["vllm.entrypoints.openai.tool_parsers.deepseekv3_tool_parser"] = ( from vllm_mindspore.entrypoints.__main__ import ( patch_server_run_api_server_worker_proc, ) +from vllm_mindspore.model_executor.layers.quantization import ( + get_quantization_config, QuantizationMethods) +from vllm_mindspore.model_executor.model_loader.weight_utils import ( + get_quant_config) + +vllm.model_executor.layers.quantization.get_quantization_config = ( + get_quantization_config) +vllm.config.get_quantization_config = get_quantization_config +vllm.model_executor.model_loader.weight_utils.get_quantization_config = ( + get_quantization_config) +vllm.model_executor.layers.quantization.QuantizationMethods = ( + QuantizationMethods) +vllm.model_executor.model_loader.weight_utils.get_quant_config = ( + get_quant_config) +vllm.config.get_quant_config = get_quant_config + +from vllm_mindspore.model_executor.model_loader.utils import ( + process_weights_after_loading) + +vllm.model_executor.model_loader.utils.process_weights_after_loading = ( + process_weights_after_loading) +vllm.model_executor.model_loader.base_loader.process_weights_after_loading = ( + process_weights_after_loading) + +from vllm_mindspore.model_executor.model_loader.default_loader import ( + _prepare_weights) +from vllm.model_executor.model_loader.default_loader import DefaultModelLoader + +DefaultModelLoader._prepare_weights = _prepare_weights + patch_server_run_api_server_worker_proc() from vllm_mindspore.model_executor.models.registry import _normalize_archs diff --git a/vllm_mindspore/model_executor/layers/logits_processor.py b/vllm_mindspore/model_executor/layers/logits_processor.py index c4f6f22e..14955fc7 100644 --- a/vllm_mindspore/model_executor/layers/logits_processor.py +++ b/vllm_mindspore/model_executor/layers/logits_processor.py @@ -81,7 +81,11 @@ class LogitsProcessor(nn.Cell): # Soft cap the logits. Used in Gemma 2. self.soft_cap = soft_cap # Whether to use gather or all-gather to gather the logits. - self.use_all_gather = current_platform.use_all_gather() + + if not envs.VLLM_USE_V1: + self.use_all_gather = True + else: + self.use_all_gather = current_platform.use_all_gather() def construct( self, diff --git a/vllm_mindspore/model_executor/layers/quantization/__init__.py b/vllm_mindspore/model_executor/layers/quantization/__init__.py index e69de29b..b33258d9 100644 --- a/vllm_mindspore/model_executor/layers/quantization/__init__.py +++ b/vllm_mindspore/model_executor/layers/quantization/__init__.py @@ -0,0 +1,48 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Literal, get_args + +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + +QuantizationMethods = Literal["golden-stick", "ascend"] +QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) + +# The customized quantization methods which will be added to this dict. +_CUSTOMIZED_METHOD_TO_QUANT_CONFIG = {} + + +def get_quantization_config(quantization: str) -> type[QuantizationConfig]: + if quantization not in QUANTIZATION_METHODS: + raise ValueError(f"Invalid quantization method: {quantization}") + + # lazy import to avoid triggering `torch.compile` too early + from .smooth_quant_modelslim import SmoothQuantModelSlimConfig + method_to_config: dict[str, type[QuantizationConfig]] = { + "golden-stick": SmoothQuantModelSlimConfig, + "ascend": SmoothQuantModelSlimConfig + } + # Update the `method_to_config` with customized quantization methods. + method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) + + return method_to_config[quantization] + + +__all__ = [ + "QuantizationConfig", "get_quantization_config", "QUANTIZATION_METHODS", + "QuantizationMethods" +] \ No newline at end of file diff --git a/vllm_mindspore/model_executor/layers/quantization/base_config.py b/vllm_mindspore/model_executor/layers/quantization/base_config.py index 37144a43..5728702d 100644 --- a/vllm_mindspore/model_executor/layers/quantization/base_config.py +++ b/vllm_mindspore/model_executor/layers/quantization/base_config.py @@ -142,6 +142,9 @@ class QuantizationConfig(ABC): """ raise NotImplementedError + def get_cache_scale(self, name: str) -> Optional[str]: + return None + def method_has_implemented_embedding( method_class: type[QuantizeMethodBase]) -> bool: diff --git a/vllm_mindspore/model_executor/layers/quantization/smooth_quant_modelslim.py b/vllm_mindspore/model_executor/layers/quantization/smooth_quant_modelslim.py index 8e1772dc..56ad689c 100644 --- a/vllm_mindspore/model_executor/layers/quantization/smooth_quant_modelslim.py +++ b/vllm_mindspore/model_executor/layers/quantization/smooth_quant_modelslim.py @@ -19,15 +19,16 @@ from typing import Any, Optional import mindspore import numpy as np import regex as re -from mindspore import Parameter, Tensor, mint +from mindspore import Parameter, Tensor, mint, ops from mindspore.common.initializer import initializer +from mindspore.communication import get_rank from mindspore.ops.auto_generate import (DynamicQuantExt, GroupedMatmul, GroupedMatmulV4, QuantBatchMatmul) from mindspore.ops.operations._infer_ops import QuantV2 from vllm_mindspore.attention import Attention from vllm_mindspore.model_executor.layers.linear import ( - LinearBase, LinearMethodBase, UnquantizedLinearMethod) + LinearBase, LinearMethodBase, RowParallelLinear, UnquantizedLinearMethod) from vllm_mindspore.model_executor.utils import set_weight_attrs from vllm_mindspore.utils import is_310p, set_weight_format_to_nz @@ -49,12 +50,16 @@ class SmoothQuantModelSlimConfig(QuantizationConfig): modules_to_not_convert: Optional[list[str]] = None, ) -> None: super().__init__() + self.fa3_quant = False + self.fa3_quant_layer = set() + self.rank_id = get_rank() self.full_config = full_config self.weight_bits = weight_bits self.group_size = group_size self.zero_point = zero_point self.dynamic_quant = dynamic_quant self.kv_cache_bits = kv_cache_bits + self.sparse_quant = False self.modules_to_not_convert = modules_to_not_convert or 
[] if self.weight_bits != 8: @@ -87,7 +92,11 @@ class SmoothQuantModelSlimConfig(QuantizationConfig): @staticmethod def get_config_filenames() -> list[str]: - return ["quant_model_description.json"] + return [ + "quant_model_description.json", + "quant_model_description_w8a8.json", + "quant_model_description_w8a8s.json" + ] @classmethod def from_config(cls, config: dict[str, @@ -112,6 +121,18 @@ class SmoothQuantModelSlimConfig(QuantizationConfig): if quant_config and quant_config.lower() == 'w8a8_dyn': self.dynamic_quant = True return A8W8DYNLinearMethod(self) + if f'rank_{self.rank_id}' in self.full_config: + sparse_quant_description = self.full_config[ + f'rank_{self.rank_id}'] + if sparse_quant_description[f"{prefix}.weight"].lower( + ) == "w8a8s": + self.sparse_quant = True + compress_weight_size = sparse_quant_description[ + f"{prefix}.weight.shape"] + compress_index_size = sparse_quant_description[ + f"{prefix}.index.shape"] + return A8W8SCLinearMethod(self, compress_weight_size[0], + compress_index_size[0]) print(f"get_quant_method unmatched {layer.__class__.__name__}, " f"{quant_key,self.full_config.get(quant_key)}") @@ -539,3 +560,120 @@ class A8W8DYNLinearMethod(LinearMethodBase): qx = mint.add(qx, bias) qx = qx.reshape(output_shape) return qx + + +class A8W8SCLinearMethod(LinearMethodBase): + '''Linear method for A8W8SCLinearMethod.''' + + def __init__(self, + quant_config: SmoothQuantModelSlimConfig, + compress_weight_size=None, + compress_index_size=None): + self.quant_config = quant_config + self.quant = QuantV2() + self.compress_weight_size = compress_weight_size + self.compress_index_size = compress_index_size + self.linear_sparse = ops.auto_generate.QuantLinearSparse() + + def create_weights(self, + layer: mindspore.nn.Cell, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype, + is_group_mm=False, + expert_num_per_partition=1, + **extra_weight_attrs): + if input_size_per_partition % self.quant_config.group_size != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + + output_size_per_partition = sum(output_partition_sizes) + self.output_size_per_partition = output_size_per_partition + self.input_size_per_partition = input_size_per_partition + if output_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + "The output size is not aligned with the quantized " + "weight shape. 
This can be caused by too large " + "tensor parallel size.") + + weight = Parameter(initializer('normal', (self.compress_weight_size), + mindspore.int8), + name="weight") + index = Parameter(initializer('normal', (self.compress_index_size), + mindspore.int8), + name="index") + deq_scale = Parameter(initializer('normal', + (self.output_size_per_partition), + mindspore.int64), + name="deq_scale") + quant_bias = Parameter(initializer('zeros', + (self.output_size_per_partition), + mindspore.int32), + name="quant_bias") + input_scale = Parameter(Tensor(np.ones(self.input_size_per_partition), + mindspore.float16), + name="input_scale") + input_offset = Parameter(Tensor( + np.zeros(self.input_size_per_partition), mindspore.int8), + name="input_offset") + + layer.insert_param_to_cell("weight", weight) + layer.insert_param_to_cell("index", index) + layer.insert_param_to_cell("deq_scale", deq_scale) + layer.insert_param_to_cell("quant_bias", quant_bias) + layer.insert_param_to_cell("input_scale", input_scale) + layer.insert_param_to_cell("input_offset", input_offset) + + def process_weights_after_loading(self, layer: mindspore.nn.Cell) -> None: + input_scale = layer.input_scale.asnumpy() + input_offset = layer.input_offset.asnumpy() + if input_scale.shape == (1, ) and input_offset.shape == (1, ): + input_scale = np.full(shape=self.input_size_per_partition, + fill_value=input_scale[0]) + input_offset = np.full(shape=self.input_size_per_partition, + fill_value=input_offset[0]) + layer.input_scale = Parameter(Tensor(input_scale, + dtype=mindspore.float16), + name=layer.input_scale.name, + requires_grad=False) + layer.input_offset = Parameter(Tensor(input_offset, + dtype=mindspore.int8), + name=layer.input_offset.name, + requires_grad=False) + rank_id = get_rank() + if isinstance(layer, RowParallelLinear) and \ + layer.quant_bias is not None and\ + rank_id != 0: + quant_bias = np.zeros_like(layer.quant_bias, dtype=np.int32) + layer.quant_bias = Parameter(Tensor(quant_bias, + dtype=mindspore.int32), + name=layer.quant_bias.name, + requires_grad=False) + + def apply(self, + layer: mindspore.nn.Cell, + x: mindspore.Tensor, + bias: mindspore.Parameter = None, + group_list=None, + cumsum_flag=False) -> mindspore.Tensor: + weight = layer.weight + index = layer.index + deq_scale = layer.deq_scale + quant_bias = layer.quant_bias + input_scale = layer.input_scale + input_offset = layer.input_offset + + output_shape = x.shape[:-1] + (self.output_size_per_partition, ) + x = x.reshape(-1, self.input_size_per_partition) + + x = self.quant(x, input_scale, input_offset, False, "ROUND", + mindspore.int8) + x = self.linear_sparse(x, weight, deq_scale, index, quant_bias) + + x = x.reshape(output_shape) + + return x \ No newline at end of file diff --git a/vllm_mindspore/model_executor/model_loader/default_loader.py b/vllm_mindspore/model_executor/model_loader/default_loader.py new file mode 100644 index 00000000..5c694d76 --- /dev/null +++ b/vllm_mindspore/model_executor/model_loader/default_loader.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: Apache-2.0 +import glob +import os +from typing import Optional + +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME +from vllm.config import LoadFormat +from vllm.distributed import get_tensor_model_parallel_rank +from vllm.model_executor.model_loader.weight_utils import ( + download_safetensors_index_file_from_hf, download_weights_from_hf, + filter_duplicate_safetensors_files, filter_files_not_needed_for_inference) + + +def _prepare_weights( + self, + model_name_or_path: 
str, + revision: Optional[str], + fall_back_to_pt: bool, + allow_patterns_overrides: Optional[list[str]], +) -> tuple[str, list[str], bool]: + """Prepare weights for the model. + + If the model is not local, it will be downloaded.""" + model_name_or_path = (self._maybe_download_from_modelscope( + model_name_or_path, revision) or model_name_or_path) + + is_local = os.path.isdir(model_name_or_path) + load_format = self.load_config.load_format + use_safetensors = False + index_file = SAFE_WEIGHTS_INDEX_NAME + # Some quantized models use .pt files for storing the weights. + if load_format == LoadFormat.AUTO: + allow_patterns = ["*.safetensors", "*.bin"] + elif (load_format == LoadFormat.SAFETENSORS + or load_format == LoadFormat.FASTSAFETENSORS): + use_safetensors = True + allow_patterns = ["*.safetensors"] + elif load_format == LoadFormat.MISTRAL: + use_safetensors = True + allow_patterns = ["consolidated*.safetensors"] + index_file = "consolidated.safetensors.index.json" + elif load_format == LoadFormat.PT: + allow_patterns = ["*.pt"] + elif load_format == LoadFormat.NPCACHE: + allow_patterns = ["*.bin"] + else: + raise ValueError(f"Unknown load_format: {load_format}") + + if fall_back_to_pt: + allow_patterns += ["*.pt"] + + if allow_patterns_overrides is not None: + allow_patterns = allow_patterns_overrides + + if not is_local: + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + allow_patterns, + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + else: + hf_folder = model_name_or_path + hf_weights_files: list[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + # Supports loading offline-partitioned weights. + if len(hf_weights_files) == 0: + tp_rank = get_tensor_model_parallel_rank() + hf_weights_files += glob.glob( + os.path.join(hf_folder, f"rank_{tp_rank}", pattern)) + if len(hf_weights_files) > 0: + if pattern == "*.safetensors": + use_safetensors = True + break + if use_safetensors: + # For models like Mistral-7B-Instruct-v0.3 + # there are both sharded safetensors files and a consolidated + # safetensors file. Using both breaks. + # Here, we download the `model.safetensors.index.json` and filter + # any files not found in the index. 
+ if not is_local: + download_safetensors_index_file_from_hf( + model_name_or_path, + index_file, + self.load_config.download_dir, + revision, + ) + hf_weights_files = filter_duplicate_safetensors_files( + hf_weights_files, hf_folder, index_file) + else: + hf_weights_files = filter_files_not_needed_for_inference( + hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_folder, hf_weights_files, use_safetensors \ No newline at end of file diff --git a/vllm_mindspore/model_executor/model_loader/utils.py b/vllm_mindspore/model_executor/model_loader/utils.py index 571fe06c..beb299f3 100644 --- a/vllm_mindspore/model_executor/model_loader/utils.py +++ b/vllm_mindspore/model_executor/model_loader/utils.py @@ -23,10 +23,13 @@ import os from contextlib import contextmanager from mindspore import nn +from vllm.attention import Attention from vllm.config import ModelConfig, ModelImpl from vllm.model_executor.model_loader.utils import logger from vllm.model_executor.models import ModelRegistry +from vllm_mindspore.model_executor.layers.quantization.base_config import ( + QuantizeMethodBase) from vllm_mindspore.model_executor.models.registry import ( AUTO_SELECT_FIXED_MODEL, MindSporeModelRegistry, mcore_support_list, mf_supported, mindone_supported) @@ -209,3 +212,26 @@ def ms_device_loading_context(module, target_device): f"'cuda' device now, but got '{target_device}'.") yield module return + + +def process_weights_after_loading(model, model_config, target_device) -> None: + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if isinstance(quant_method, QuantizeMethodBase): + # # When quant methods need to process weights after loading + # # (for repacking, quantizing, etc), they expect parameters + # # to be on the global target device. This scope is for the + # # case where cpu offloading is used, where we will move the + # # parameters onto device for processing and back off after. + # with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) + + # Currently only used by MLA. + # NOTE: This intentionally happens after other modules so we can easily + # decompress the weights for MLA. + for _, module in model.named_modules(): + if isinstance(module, Attention) and \ + hasattr(module, "process_weights_after_loading"): + # TODO(lucas): see if there is a way to unify the signatures + # of process_weights_after_loading + module.process_weights_after_loading(model_config.dtype) diff --git a/vllm_mindspore/model_executor/model_loader/weight_utils.py b/vllm_mindspore/model_executor/model_loader/weight_utils.py index 47d2a67f..16809c9f 100644 --- a/vllm_mindspore/model_executor/model_loader/weight_utils.py +++ b/vllm_mindspore/model_executor/model_loader/weight_utils.py @@ -18,16 +18,27 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import glob +import json +import os from collections.abc import Generator from typing import Any +import huggingface_hub import mindspore as ms +from huggingface_hub import snapshot_download from mindspore import Parameter from safetensors import safe_open from tqdm.auto import tqdm +from vllm.config import LoadConfig from vllm.model_executor.model_loader.weight_utils import (_BAR_FORMAT, - enable_tqdm) + DisabledTqdm, + enable_tqdm, + get_lock) +from vllm_mindspore.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm_mindspore.platforms.ascend import ModelConfig from vllm_mindspore.utils import cast_weight_for_310p, is_310p @@ -39,6 +50,100 @@ def convert_loaded_weight(loaded_weight): return loaded_weight +def get_quant_config(model_config: ModelConfig, + load_config: LoadConfig) -> QuantizationConfig: + + from vllm_mindspore.model_executor.layers.quantization import ( + get_quantization_config) + quant_cls = get_quantization_config(model_config.quantization) + + # GGUF doesn't have config file + if model_config.quantization == "gguf": + return quant_cls.from_config({}) + + # Read the quantization config from the HF model config, if available. + hf_quant_config = getattr(model_config.hf_config, "quantization_config", + None) + # some vision model may keep quantization_config in their text_config + hf_text_config = getattr(model_config.hf_config, "text_config", None) + if hf_quant_config is None and hf_text_config is not None: + hf_quant_config = getattr(hf_text_config, "quantization_config", None) + if hf_quant_config is None: + # compressed-tensors uses a compressions_config + hf_quant_config = getattr(model_config.hf_config, "compression_config", + None) + if hf_quant_config is not None: + if os.path.isdir(model_config.model): + quant_config_file = os.path.join( + model_config.model, + quant_cls.get_config_filenames()[2]) + with open(quant_config_file) as f: + quant_config = json.load(f) + return quant_cls.from_config(hf_quant_config | quant_config) + + # In case of bitsandbytes/QLoRA, get quant config from the adapter model. + if model_config.quantization == "bitsandbytes": + if (not load_config.model_loader_extra_config + or "qlora_adapter_name_or_path" + not in load_config.model_loader_extra_config): + return quant_cls.from_config({"adapter_name_or_path": ""}) + model_name_or_path = load_config.model_loader_extra_config[ + "qlora_adapter_name_or_path"] + + else: + model_name_or_path = model_config.model + is_local = os.path.isdir(model_name_or_path) + if not is_local: + # Download the config files. + with get_lock(model_name_or_path, load_config.download_dir): + hf_folder = snapshot_download( + model_name_or_path, + revision=model_config.revision, + allow_patterns="*.json", + cache_dir=load_config.download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + tqdm_class=DisabledTqdm, + ) + else: + hf_folder = model_name_or_path + + possible_config_filenames = quant_cls.get_config_filenames() + + # If the quantization config is not found, use the default config. 
+ if not possible_config_filenames: + return quant_cls() + + config_files = glob.glob(os.path.join(hf_folder, "*.json")) + + quant_config_files = [ + f for f in config_files if any( + f.endswith(x) for x in possible_config_filenames) + ] + if len(quant_config_files) == 0: + raise ValueError( + f"Cannot find the config file for {model_config.quantization}") + if len(quant_config_files) > 1: + raise ValueError( + f"Found multiple config files for {model_config.quantization}: " + f"{quant_config_files}") + + quant_config_file = quant_config_files[0] + with open(quant_config_file) as f: + config = json.load(f) + + if model_config.quantization == "bitsandbytes": + config["adapter_name_or_path"] = model_name_or_path + elif model_config.quantization == "modelopt": + if config["producer"]["name"] == "modelopt": + return quant_cls.from_config(config) + else: + raise ValueError( + f"Unsupported quantization config" + f" found for {model_config.quantization} in {f}.") + + return quant_cls.from_config(config) + + def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): """ Read numpy slice data based on axis and slice range. diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py index 7b388c32..b4697ab0 100644 --- a/vllm_mindspore/model_executor/models/qwen2.py +++ b/vllm_mindspore/model_executor/models/qwen2.py @@ -35,12 +35,15 @@ if TYPE_CHECKING: else: Qwen2Config = None -from mindspore import Parameter, Tensor, mint, nn from mindspore.nn.utils import no_init_parameters +from mindspore import Parameter, Tensor, mint, nn, ops +from mindspore.communication.management import get_rank from vllm.attention.backends.abstract import AttentionType from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_world_size, + get_tensor_model_parallel_rank) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.interfaces import SupportsLoRA from vllm.sequence import IntermediateTensors @@ -62,6 +65,7 @@ from vllm_mindspore.model_executor.models.model_base import (NativeModel) from vllm_mindspore.model_executor.models.utils import ( PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +from vllm_mindspore.utils import is_310p, FORMAT_TYPE class Qwen2MLP(nn.Cell): @@ -363,8 +367,55 @@ class Qwen2Model(nn.Cell): hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + def load_sparsequant_weights(self, weights: Iterable[tuple[str, Tensor]], + params_dict: dict[str, Parameter]): + weights_dict = dict(weights) + + for name, loaded_weight in weights_dict.items(): + if get_tensor_model_parallel_rank( + ) > 0 and "o_proj.quant_bias" in name: + continue + + if name not in params_dict: + continue + + param = params_dict[name] + param.set_data(Tensor(loaded_weight[:]).contiguous()) + + def adjust_weight(params_dict): + if not is_310p(): + return + + target_keywords = [ + "qkv_proj.weight", + "o_proj.weight", + "gate_up_proj.weight", + "down_proj.weight", + "lm_head.weight", + ] + + rank_id = get_rank() + for name, param in params_dict.items(): + if any(name.endswith(keyword) for keyword in target_keywords): + weight_type = self.quant_config.full_config[ + f"rank_{rank_id}"][name] + if weight_type.lower() == "w8a8s": + # Compressed weights do not need to be + # converted to Nz format. 
+ continue + + cast_weight = ops.auto_generate.format_cast( + param, FORMAT_TYPE['nz']) + param.set_data(cast_weight) + + if is_310p(): + adjust_weight(params_dict) + def load_weights(self, weights: Iterable[tuple[str, Tensor]], params_dict: dict[str, Parameter]): + if self.quant_config is not None and self.quant_config.sparse_quant: + self.load_sparsequant_weights(weights, params_dict) + return loaded_params: set[str] = set() stacked_params_mapping = [ # (param_name, shard_name, shard_id) # noqa: ERA001 @@ -376,6 +427,10 @@ class Qwen2Model(nn.Cell): ] for name, loaded_weight in weights: + if get_tensor_model_parallel_rank( + ) > 0 and "o_proj.quant_bias" in name: + continue + if "rotary_emb.inv_freq" in name: continue if ("rotary_emb.cos_cached" in name diff --git a/vllm_mindspore/platforms/ascend.py b/vllm_mindspore/platforms/ascend.py index 5b4f232a..fee63ec5 100644 --- a/vllm_mindspore/platforms/ascend.py +++ b/vllm_mindspore/platforms/ascend.py @@ -202,4 +202,4 @@ class AscendPlatform(Platform): if is_310p(): # bfloat16 is not supported on the 310p due to hardware limitations. return [torch.float16, torch.float32] - return [torch.bfloat16, torch.float16, torch.float32] \ No newline at end of file + return [torch.bfloat16, torch.float16, torch.float32] -- Gitee
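For reference, the penalty gating introduced in [PATCH 7/8] (opt penalties) can be reproduced in isolation. The sketch below is a minimal numpy illustration, not the real InputBatch code: the buffer names are illustrative stand-ins for the CPU-side arrays read in _make_sampling_metadata. It shows the core idea of the patch: a penalty kind is only synced to device and passed on when at least one of the first num_reqs requests deviates from its neutral value (0.0 for frequency/presence, 1.0 for repetition); otherwise the corresponding SamplingMetadata field stays None and apply_penalties skips that branch entirely.

    import numpy as np

    # Illustrative CPU-side buffers; the real code reads
    # self.frequency_penalties_cpu etc. from the input batch.
    num_reqs = 4
    frequency_penalties_cpu = np.zeros(8, dtype=np.float32)
    presence_penalties_cpu = np.zeros(8, dtype=np.float32)
    repetition_penalties_cpu = np.ones(8, dtype=np.float32)
    repetition_penalties_cpu[2] = 1.1  # one request enables repetition penalty

    # A penalty kind is applied only when some active request uses a
    # non-default value, mirroring the checks added in the patch.
    apply_freq = not np.all(frequency_penalties_cpu[:num_reqs] == 0.0)
    apply_pres = not np.all(presence_penalties_cpu[:num_reqs] == 0.0)
    apply_rep = not np.all(repetition_penalties_cpu[:num_reqs] == 1.0)
    print(apply_freq, apply_pres, apply_rep)  # prints: False False True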