diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py
index 58c78e076943e60e5477904cbf90a63437fc126b..6676b630f576b7c272089f512a639599f850428b 100644
--- a/vllm_mindspore/__init__.py
+++ b/vllm_mindspore/__init__.py
@@ -165,6 +165,8 @@ from vllm_mindspore.model_executor.models.registry import (
     MindSporeModelRegistry,
     _SUBPROCESS_COMMAND,
 )
+from vllm_mindspore.model_executor.layers.quantization import (
+    get_quantization_config)

 vllm.config.ModelRegistry = MindSporeModelRegistry

@@ -174,7 +176,7 @@ vllm.model_executor.models.ModelRegistry = MindSporeModelRegistry
 vllm.model_executor.models.registry._SUBPROCESS_COMMAND = _SUBPROCESS_COMMAND

 from vllm_mindspore.model_executor.model_loader.utils import (
-    get_ms_model_architecture, )
+    get_ms_model_architecture, ms_device_loading_context)

 # To patching the get_model_architecture, should import it first.
 from vllm.model_executor.model_loader import get_model_architecture  # noqa F401
@@ -185,6 +187,8 @@ vllm.model_executor.model_loader.utils.get_model_architecture = (
     get_ms_model_architecture)
 vllm.model_executor.model_loader.default_loader.get_model_architecture = (
     get_ms_model_architecture)
+vllm.model_executor.model_loader.utils.device_loading_context = (
+    ms_device_loading_context)

 from vllm_mindspore.model_executor.sampling_metadata import SamplingTensors

@@ -214,6 +218,8 @@ from vllm_mindspore.model_executor.model_loader.weight_utils import (

 vllm.model_executor.model_loader.default_loader.safetensors_weights_iterator = (
     safetensors_weights_iterator)
+vllm.model_executor.model_loader.weight_utils.get_quantization_config = (
+    get_quantization_config)

 from vllm_mindspore.worker.worker import (_warm_up_model,
                                           wrapper_worker_bind_cpu)
@@ -280,14 +286,11 @@ vllm.executor.ray_distributed_executor.initialize_ray_cluster = (
 vllm.v1.utils.CoreEngineActorManager.__init__ = core_engine_actor_manager_init

 from .config import (_verify_quantization, _verify_args, vllm_config_post_init,
-                     vllm_config_get_quantization_config, model_post_init,
-                     _get_and_verify_dtype, stateless_init_dp_group,
-                     has_unfinished_dp)
+                     model_post_init, _get_and_verify_dtype,
+                     stateless_init_dp_group, has_unfinished_dp)

 vllm.config.ModelConfig._verify_quantization = _verify_quantization
 vllm.config.VllmConfig.__post_init__ = vllm_config_post_init
-vllm.config.VllmConfig._get_quantization_config = staticmethod(
-    vllm_config_get_quantization_config)
 vllm.config.SchedulerConfig._verify_args = _verify_args
 vllm.config.CompilationConfig.model_post_init = model_post_init
 vllm.config._get_and_verify_dtype = _get_and_verify_dtype
diff --git a/vllm_mindspore/config.py b/vllm_mindspore/config.py
index 7f93164b604795a65addf35cae9f69608b779e22..d072e35adfb2e323c071072453dd2810a38ea305 100644
--- a/vllm_mindspore/config.py
+++ b/vllm_mindspore/config.py
@@ -36,7 +36,7 @@ from vllm.config import (_STR_DTYPE_TO_TORCH_DTYPE, CacheConfig,
 from vllm.logger import init_logger
 from vllm.utils import random_uuid

-from vllm_mindspore.utils import is_310p
+from vllm_mindspore.utils import is_310p, is_native_model_backend

 logger = init_logger(__name__)

@@ -69,7 +69,8 @@ def vllm_config_post_init(self):
         self.prompt_adapter_config.verify_with_model_config(self.model_config)

     if self.quant_config is None and \
-            self.model_config is not None and self.load_config is not None:
+            self.model_config is not None and self.load_config is not None and \
+            is_native_model_backend():
         self.quant_config = VllmConfig._get_quantization_config(
             self.model_config, self.load_config)

diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py
index 4669352a301b746e75c5ab7527ce45d22c364bfd..479bb36eb30ba4dee85dc57f9cbb8b465e58a850 100644
--- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py
+++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py
@@ -31,15 +31,13 @@ from vllm.distributed import (get_dp_group, get_ep_group,
                               get_tp_group)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs

 from vllm_mindspore.model_executor.layers.fused_moe.config import (
     FusedMoEConfig, FusedMoEParallelConfig, MoeMode)
 from vllm_mindspore.model_executor.layers.fused_moe.fused_moe import (
     FusedExperts, fused_topk, grouped_topk)
-from vllm_mindspore.model_executor.layers.quantization.base_config import (
-    QuantizeMethodBase)
 from vllm_mindspore.model_executor.model_loader.weight_utils import (
     split_loaded_weight)

diff --git a/vllm_mindspore/model_executor/layers/linear.py b/vllm_mindspore/model_executor/layers/linear.py
index 9212190506bdba063cf9ce98bfe0a59b7e8f680f..81f0fe792a150c27392fdd1d1f0067dcfd3801cc 100644
--- a/vllm_mindspore/model_executor/layers/linear.py
+++ b/vllm_mindspore/model_executor/layers/linear.py
@@ -30,14 +30,14 @@ from vllm.config import get_current_vllm_config
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               split_tensor_along_last_dim)
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.utils import set_weight_attrs

 from vllm_mindspore.distributed.communication_op import (
     AllGatherFromModelParallelRegion, ReduceFromModelParallelRegion)
-from vllm_mindspore.model_executor.layers.quantization.base_config import (
-    QuantizationConfig, QuantizeMethodBase)
 from vllm_mindspore.model_executor.model_loader.weight_utils import (
     split_loaded_weight)
-from vllm_mindspore.model_executor.utils import set_weight_attrs

 WEIGHT_LOADER_V2_SUPPORTED = [
     "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
@@ -440,9 +440,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             loaded_weight = split_loaded_weight(loaded_weight, output_dim,
                                                 start_idx, shard_size)

-            assert loaded_weight.shape == (shard_size, param.shape[1])
             param[shard_offset:shard_offset +
-                  shard_size, :] = ms.from_numpy(loaded_weight)
+                  shard_size] = ms.from_numpy(loaded_weight)


 class QKVParallelLinear(ColumnParallelLinear):
@@ -699,6 +698,9 @@ class RowParallelLinear(LinearBase):
         return output, output_bias

     def weight_loader(self, param, loaded_weight):
+        if param.name.endswith("bias") and (self.tp_rank > 0
+                                            or self.skip_bias_add):
+            return
         tp_rank = get_tensor_model_parallel_rank()
         input_dim = getattr(param, "input_dim", None)
         shard_size = self.input_size_per_partition
diff --git a/vllm_mindspore/model_executor/layers/quantization/__init__.py b/vllm_mindspore/model_executor/layers/quantization/__init__.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..58c12532fb88577a18b7803572ea43539f4e0d61 100644
--- a/vllm_mindspore/model_executor/layers/quantization/__init__.py
+++ b/vllm_mindspore/model_executor/layers/quantization/__init__.py
@@ -0,0 +1,36 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/v0.9.1/vllm/model_executor/layers/quantization/__init__.py
+#
+#
+# Copyright 2025 Huawei Technologies Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+
+from vllm_mindspore.utils import is_native_model_backend
+
+QUANTIZATION_METHODS: list[str] = list()
+
+
+def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
+    if not is_native_model_backend():
+        return
+    if quantization not in QUANTIZATION_METHODS:
+        raise ValueError(f"Invalid quantization method: {quantization}.")
+
+    method_to_config: dict[str, type[QuantizationConfig]] = {}
+    return method_to_config[quantization]
diff --git a/vllm_mindspore/model_executor/layers/quantization/attention.py b/vllm_mindspore/model_executor/layers/quantization/attention.py
index 556b69e2a3e75fdce029ccc43ad07af738768608..6561a50a99d42fbc9ad57eafa761d98d5b38399f 100644
--- a/vllm_mindspore/model_executor/layers/quantization/attention.py
+++ b/vllm_mindspore/model_executor/layers/quantization/attention.py
@@ -31,11 +31,11 @@ from mindspore.ops.operations.nn_ops import (FlashAttentionScore,
                                              PromptFlashAttention)
 from vllm.attention.backends.abstract import AttentionType
 from vllm.config import CacheConfig
-
-from vllm_mindspore.model_executor.layers.quantization.base_config import (
+from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
-from vllm_mindspore.model_executor.utils import (get_model_context,
-                                                 set_weight_attrs)
+from vllm.model_executor.utils import set_weight_attrs
+
+from vllm_mindspore.model_executor.utils import get_model_context
 from vllm_mindspore.utils import is_310p


diff --git a/vllm_mindspore/model_executor/layers/quantization/base_config.py b/vllm_mindspore/model_executor/layers/quantization/base_config.py
deleted file mode 100644
index 37144a431aca255687d28958d835282980826b52..0000000000000000000000000000000000000000
--- a/vllm_mindspore/model_executor/layers/quantization/base_config.py
+++ /dev/null
@@ -1,157 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-
-# Adapted from
-# https://github.com/vllm-project/vllm/blob/v0.8.3/vllm/model_executor/layers/quantization/base_config.py
-#
-# Copyright 2025 Huawei Technologies Co., Ltd.
-# Copyright 2024-2025 The vLLM team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import inspect
-from abc import ABC, abstractmethod
-from typing import Any, Optional
-
-import mindspore as ms
-
-
-class QuantizeMethodBase(ABC):
-    """Base class for different quantized methods."""
-
-    @abstractmethod
-    def create_weights(self, layer: ms.nn.Cell, *weight_args,
-                       **extra_weight_attrs):
-        """Create weights for a layer.
-
-        The weights will be set as attributes of the layer."""
-        raise NotImplementedError
-
-    @abstractmethod
-    def apply(self, layer: ms.nn.Cell, *args, **kwargs) -> ms.Tensor:
-        """Apply the weights in layer to the input tensor.
-
-        Expects create_weights to have been called before on the layer."""
-        raise NotImplementedError
-
-    # Not required functions
-    def embedding(self, layer: ms.nn.Cell, *args, **kwargs) -> ms.Tensor:
-        """Gather embeddings in the layer based on indices in the input tensor.
-
-        Expects create_weights to have been called before on the layer."""
-        raise NotImplementedError
-
-    def process_weights_after_loading(self, layer: ms.nn.Cell) -> None:
-        """Process the weight after loading.
-
-        This can be used for example, to transpose weights for computation.
-        """
-        return
-
-
-class QuantizationConfig(ABC):
-    """Base class for quantization configs."""
-
-    def __init__(self):
-        super().__init__()
-        # mapping is updated by models as they initialize
-        self.packed_modules_mapping: dict[str, list[str]] = dict()
-
-    @abstractmethod
-    def get_name(self) -> str:
-        """Name of the quantization method."""
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_supported_act_dtypes(self):
-        """list of supported activation dtypes."""
-        raise NotImplementedError
-
-    @classmethod
-    @abstractmethod
-    def get_min_capability(cls) -> int:
-        """Minimum GPU capability to support the quantization method.
-
-        E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
-        This requirement is due to the custom CUDA kernels used by the
-        quantization method.
-        """
-        raise NotImplementedError
-
-    @staticmethod
-    @abstractmethod
-    def get_config_filenames() -> list[str]:
-        """list of filenames to search for in the model directory."""
-        raise NotImplementedError
-
-    @classmethod
-    @abstractmethod
-    def from_config(cls, config: dict[str, Any]) -> "QuantizationConfig":
-        """Create a config class from the model's quantization config."""
-        raise NotImplementedError
-
-    @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
-        """
-        Detects if this quantization method can support a given checkpoint
-        format by overriding the user specified quantization method --
-        this method should only be overwritten by subclasses in exceptional
-        circumstances
-        """
-        return None
-
-    @staticmethod
-    def get_from_keys(config: dict[str, Any], keys: list[str]) -> Any:
-        """Get a value from the model's quantization config."""
-        for key in keys:
-            if key in config:
-                return config[key]
-        raise ValueError(f"Cannot find any of {keys} in the model's "
-                         "quantization config.")
-
-    @staticmethod
-    def get_from_keys_or(config: dict[str, Any], keys: list[str],
-                         default: Any) -> Any:
-        """Get a optional value from the model's quantization config."""
-        try:
-            return QuantizationConfig.get_from_keys(config, keys)
-        except ValueError:
-            return default
-
-    @abstractmethod
-    def get_quant_method(self, layer: ms.nn.Cell,
-                         prefix: str) -> Optional[QuantizeMethodBase]:
-        """Get the quantize method to use for the quantized layer.
-
-        Args:
-            layer: The layer for the quant method.
-            prefix: The full name of the layer in the state dict
-        Returns:
-            The quantize method. None if the given layer doesn't support quant
-            method.
-        """
-        raise NotImplementedError
-
-
-def method_has_implemented_embedding(
-        method_class: type[QuantizeMethodBase]) -> bool:
-    """
-    Not all quant methods have embedding implemented, so we need to check that
-    it exists for our given method. We check this by making sure the function
-    has been changed from the base implementation.
-    """
-    base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding",
-                                            None)
-    class_embedding = inspect.getattr_static(method_class, "embedding", None)
-
-    return class_embedding is not None and class_embedding is not base_embedding
diff --git a/vllm_mindspore/model_executor/layers/quantization/smooth_quant_modelslim.py b/vllm_mindspore/model_executor/layers/quantization/smooth_quant_modelslim.py
index 841e179756d5f3461879f6722da8077e39574f34..26436873c35cf74f13fb4b32cd22945123e4cdec 100644
--- a/vllm_mindspore/model_executor/layers/quantization/smooth_quant_modelslim.py
+++ b/vllm_mindspore/model_executor/layers/quantization/smooth_quant_modelslim.py
@@ -14,16 +14,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import re
 from typing import Any, Optional

 import mindspore
 import numpy as np
+import regex as re
 from mindspore import Parameter, Tensor, mint
 from mindspore.common.initializer import initializer
 from mindspore.ops.auto_generate import (DynamicQuantExt, GroupedMatmul,
                                          GroupedMatmulV4, QuantBatchMatmul)
 from mindspore.ops.operations._infer_ops import QuantV2
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig, QuantizeMethodBase)

 from vllm_mindspore.attention import Attention
 from vllm_mindspore.model_executor.layers.linear import (
@@ -32,7 +34,6 @@ from vllm_mindspore.model_executor.utils import set_weight_attrs
 from vllm_mindspore.utils import is_310p

 from .attention import BaseKVCacheMethod, KVCacheInt8Method
-from .base_config import QuantizationConfig, QuantizeMethodBase


 class SmoothQuantModelSlimConfig(QuantizationConfig):
diff --git a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py
index f25c8059107c6b910d18bcb768f36a0081048718..9d22a11c57142f8ae1b3b400f0e20d64aeb9ab8b 100644
--- a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py
+++ b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py
@@ -29,15 +29,13 @@ from vllm.config import get_current_vllm_config
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
+from vllm.model_executor.utils import set_weight_attrs

 from vllm_mindspore.distributed.communication_op import (
     ReduceFromModelParallelRegion)
-from vllm_mindspore.model_executor.layers.quantization.base_config import (
-    QuantizeMethodBase, method_has_implemented_embedding)
 from vllm_mindspore.model_executor.model_loader.weight_utils import (
     split_loaded_weight)
-from vllm_mindspore.model_executor.utils import set_weight_attrs

 DEFAULT_VOCAB_PADDING_SIZE = 64

diff --git a/vllm_mindspore/model_executor/model_loader/utils.py b/vllm_mindspore/model_executor/model_loader/utils.py
index cd76b99a214d9842e44c23017026a0d64dd1637c..525dc1b88d765e3186f8508985e9c46caf88298c 100644
--- a/vllm_mindspore/model_executor/model_loader/utils.py
+++ b/vllm_mindspore/model_executor/model_loader/utils.py
@@ -20,6 +20,7 @@
 """ utils for load model """

 import os
+from contextlib import contextmanager

 from mindspore import nn
 from vllm.config import ModelConfig, ModelImpl
@@ -198,3 +199,13 @@
         raise RecursionError("MindSpore unsupported reward model task now!")

     return model_cls, arch
+
+
+@contextmanager
+def ms_device_loading_context(module, target_device):
+    if target_device != "cuda":
+        raise NotImplementedError(
+            f"vLLM-Mindspore Plugin only supports loading model on "
+            f"'cuda' device now, but got '{target_device}'.")
+    yield module
+    return
diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py
index db8e31c0c4764af35a7b487473e5a9874b786a55..13118438bf731840dc445bfc4a57b2ae929f35a1 100644
--- a/vllm_mindspore/model_executor/models/qwen2.py
+++ b/vllm_mindspore/model_executor/models/qwen2.py
@@ -399,6 +399,8 @@ class Qwen2Model(nn.Cell):
                 name = name.replace(weight_name, param_name)
                 if name in params_dict:
                     param = params_dict[name]
+                    if not getattr(param, "weight_loader", None):
+                        continue
                     weight_loader = param.weight_loader
                     weight_loader(param, loaded_weight, shard_id)
                     loaded_params.add(name)
diff --git a/vllm_mindspore/v1/worker/gpu_model_runner.py b/vllm_mindspore/v1/worker/gpu_model_runner.py
index b2b61af9f11e6ba045f96fd411f7664cfa1b9182..b49bd06f05d956634aa4b89e220428a3a6d7e362 100644
--- a/vllm_mindspore/v1/worker/gpu_model_runner.py
+++ b/vllm_mindspore/v1/worker/gpu_model_runner.py
@@ -298,10 +298,9 @@ def _allocate_nz_kv_cache_tensors_fa3(self, kv_cache_config):
     }
     # fa3 quant layer target_dtype is int8
    # no fa3 quant layer target_dtype is bfloat16
-    fa3_quant = self.vllm_config.quant_config.fa3_quant \
-        if self.vllm_config.quant_config else False
-    fa3_quant_layer = self.vllm_config.quant_config.fa3_quant_layer \
-        if self.vllm_config.quant_config else set()
+    fa3_quant = getattr(self.vllm_config.quant_config, "fa3_quant", False)
+    fa3_quant_layer: set[int] = getattr(self.vllm_config.quant_config,
+                                        "fa3_quant_layer", set())
     kv_lora_rank = getattr(self.vllm_config.model_config.hf_text_config,
                            'kv_lora_rank', 0)
     qk_rope_head_dim = getattr(self.vllm_config.model_config.hf_text_config,
@@ -742,10 +741,9 @@ def wrapper_gpu_model_runner_execute_model(func):
 def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
     block_size = self.vllm_config.cache_config.block_size
     use_mla = self.vllm_config.model_config.use_mla
-    fa3_quant = self.vllm_config.quant_config.fa3_quant \
-        if self.vllm_config.quant_config else False
-    fa3_quant_layer = self.vllm_config.quant_config.fa3_quant_layer \
-        if self.vllm_config.quant_config else set()
+    fa3_quant = getattr(self.vllm_config.quant_config, "fa3_quant", False)
+    fa3_quant_layer: set[int] = getattr(self.vllm_config.quant_config,
+                                        "fa3_quant_layer", set())
     kv_cache_spec: dict[str, KVCacheSpec] = {}
     attn_layers = get_layers_from_vllm_config(self.vllm_config,
                                               AttentionWrapper)