From b6f0b26b2f3b99ef61e5d10b7c6f20ecb478eeaf Mon Sep 17 00:00:00 2001 From: huoxinyou Date: Mon, 15 Sep 2025 09:39:28 +0800 Subject: [PATCH 1/6] 310p weight_loader needs format_cast('nz') --- .../model_executor/layers/linear.py | 25 +++++++++++++++++++ .../layers/vocab_parallel_embedding.py | 6 +++++ .../model_loader/weight_utils.py | 18 +++++++++++++ .../model_executor/models/model_base.py | 5 ++-- 4 files changed, 52 insertions(+), 2 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/linear.py b/vllm_mindspore/model_executor/layers/linear.py index c81f0e32..28dabdd6 100644 --- a/vllm_mindspore/model_executor/layers/linear.py +++ b/vllm_mindspore/model_executor/layers/linear.py @@ -20,6 +20,7 @@ """ Linear methods for quantized linear layers. """ from abc import abstractmethod +from typing import Dict from typing import Optional, Union import mindspore as ms @@ -39,6 +40,7 @@ from vllm_mindspore.model_executor.layers.quantization.base_config import ( from vllm_mindspore.model_executor.model_loader.weight_utils import ( split_loaded_weight) from vllm_mindspore.model_executor.utils import set_weight_attrs +from vllm_mindspore.utils import FORMAT_TYPE, is_310p WEIGHT_LOADER_V2_SUPPORTED = [ "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod", @@ -157,10 +159,20 @@ class LinearBase(nn.Cell): self.quant_method = quant_config.get_quant_method(self, prefix=prefix) self.return_bias = return_bias + self.param_load_counts: Dict[str, int] = {} def construct(self, x: Tensor) -> Tensor: raise NotImplementedError + def format_to_nz(self, param, merge_count=1): + current_count = self.param_load_counts.get(param.name, 0) + 1 + self.param_load_counts[param.name] = current_count + + if current_count == merge_count: + cast_weight = ops.auto_generate.format_cast(param, FORMAT_TYPE['nz']) + param.set_data(cast_weight) + del self.param_load_counts[param.name] + class ColumnParallelLinear(LinearBase): """Linear layer with column parallelism. 
@@ -281,6 +293,8 @@ class ColumnParallelLinear(LinearBase): assert param.shape == loaded_weight.shape param.set_data(ms.from_numpy(loaded_weight)) + if is_310p() and param.name.endswith("weight"): + self.format_to_nz(param, merge_count=1) class MergedColumnParallelLinear(ColumnParallelLinear): @@ -351,6 +365,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear): assert loaded_weight.shape == (shard_size, param.shape[1]) param[shard_offset:shard_offset + shard_size, :] = ms.from_numpy(loaded_weight) + if is_310p() and param.name.endswith("weight"): + loaded_shard_num = 2 # gating/hidden + self.format_to_nz(param, loaded_shard_num) class QKVParallelLinear(ColumnParallelLinear): @@ -459,6 +476,9 @@ class QKVParallelLinear(ColumnParallelLinear): assert loaded_weight.shape == param.shape param.set_data(loaded_weight) + if is_310p() and param.name.endswith("weight"): + loaded_shard_num = 3 # q/k/v + self.format_to_nz(param, loaded_shard_num) return assert loaded_shard_id in ["q", "k", "v"] @@ -487,6 +507,9 @@ class QKVParallelLinear(ColumnParallelLinear): if param.name.endswith("bias"): assert loaded_weight.shape == (shard_size, ) param[shard_offset:shard_offset + shard_size] = loaded_weight + if is_310p() and param.name.endswith("weight"): + loaded_shard_num = 3 # q/k/v + self.format_to_nz(param, loaded_shard_num) class RowParallelLinear(LinearBase): @@ -620,3 +643,5 @@ class RowParallelLinear(LinearBase): assert param.shape == loaded_weight.shape param.set_data(ms.from_numpy(loaded_weight)) + if is_310p() and param.name.endswith("weight"): + self.format_to_nz(param, merge_count=1) diff --git a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py index d4143916..dbf8e0db 100644 --- a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py @@ -38,6 +38,7 @@ from vllm_mindspore.model_executor.layers.quantization.base_config import ( from vllm_mindspore.model_executor.model_loader.weight_utils import ( split_loaded_weight) from vllm_mindspore.model_executor.utils import set_weight_attrs +from vllm_mindspore.utils import is_310p, FORMAT_TYPE DEFAULT_VOCAB_PADDING_SIZE = 64 @@ -339,6 +340,7 @@ class VocabParallelEmbedding(nn.Cell): return output def weight_loader(self, param: Parameter, loaded_weight: Tensor): + need_nz = is_310p() and param.name == "weight" output_dim = getattr(param, "output_dim", None) get_tensor_model_parallel_rank() # If parameter does not have output dim, then it should @@ -352,6 +354,8 @@ class VocabParallelEmbedding(nn.Cell): f"'loaded_weight.shape', but got {param.data.shape} " f"and {loaded_weight.shape}") param.set_data(ms.from_numpy(loaded_weight)) + if need_nz: + param.set_data(ops.auto_generate.format_cast(param, FORMAT_TYPE['nz'])) return # Shard indexes for loading the weight @@ -368,6 +372,8 @@ class VocabParallelEmbedding(nn.Cell): param[:loaded_weight.shape[0]] = ms.from_numpy(loaded_weight) param[loaded_weight.shape[0]:] = 0 + if need_nz: + param.set_data(ops.auto_generate.format_cast(param, FORMAT_TYPE['nz'])) class ParallelLMHead(VocabParallelEmbedding): diff --git a/vllm_mindspore/model_executor/model_loader/weight_utils.py b/vllm_mindspore/model_executor/model_loader/weight_utils.py index 6bf2dd4c..9947387e 100644 --- a/vllm_mindspore/model_executor/model_loader/weight_utils.py +++ b/vllm_mindspore/model_executor/model_loader/weight_utils.py @@ -20,6 +20,7 @@ from collections.abc import Generator 
from typing import Any +import numpy as np import mindspore as ms from mindspore import Parameter @@ -27,6 +28,7 @@ from safetensors import safe_open from tqdm.auto import tqdm from vllm.model_executor.model_loader.weight_utils import (_BAR_FORMAT, enable_tqdm) +from vllm_mindspore.utils import is_310p def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): @@ -39,6 +41,11 @@ def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): """ if shard_dim is None: loaded_weight = loaded_weight[:] + loaded_weight = ( + loaded_weight.astype(np.float16) + if ((str(loaded_weight.dtype) == "float32" or str(loaded_weight.dtype) == "bfloat16") and is_310p()) + else loaded_weight + ) return loaded_weight end_idx = start_idx + shard_size @@ -50,6 +57,12 @@ def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): loaded_weight = loaded_weight[:, :, start_idx:end_idx] else: raise ValueError("shard_dim:{} is not supported.".format(shard_dim)) + loaded_weight = ( + loaded_weight.astype(np.float16) + if ((str(loaded_weight.dtype) == "float32" or str(loaded_weight.dtype) == "bfloat16") and is_310p()) + else loaded_weight + ) + return loaded_weight @@ -77,4 +90,9 @@ def safetensors_weights_iterator( def default_weight_loader(param: Parameter, loaded_weight: Any) -> None: """Default weight loader.""" loaded_weight = loaded_weight[:] + loaded_weight = ( + loaded_weight.astype(np.float16) + if ((str(loaded_weight.dtype) == "float32" or str(loaded_weight.dtype) == "bfloat16") and is_310p()) + else loaded_weight + ) param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype)) diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index f0e5621e..f1c616fa 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -34,7 +34,7 @@ from vllm_mindspore.model_executor.models.attention_mask import ( LowerTriangularMask) from vllm_mindspore.model_executor.models.utils import is_use_ringmla from vllm_mindspore.model_executor.utils import set_model_context -from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE, create_kv_cache +from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE, create_kv_cache, is_310p from vllm_mindspore.v1.attention.backends.ms_attn import MsAttentionMetadata @@ -445,7 +445,8 @@ class NativeModel(MsModelBase): block_size = self.cache_config.block_size num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) head_size = self.model_config.get_head_size() - kv_cache_shape = (None, block_size, num_kv_heads, head_size) + kv_cache_shape = (None, block_size, num_kv_heads * head_size) \ + if is_310p() else (None, block_size, num_kv_heads, head_size) kv_cache_dtype = (self.model_config.dtype if self.cache_config.cache_dtype == "auto" else -- Gitee From 6522a5ed90969648ddb0150962e70bf431c32ec5 Mon Sep 17 00:00:00 2001 From: luolihao Date: Wed, 13 Aug 2025 14:55:00 +0800 Subject: [PATCH 2/6] fix bug and support lm_head to Nz --- .../distributed/communication_op.py | 31 ++++- .../model_executor/layers/logits_processor.py | 121 ++++++++++++++---- 2 files changed, 123 insertions(+), 29 deletions(-) diff --git a/vllm_mindspore/distributed/communication_op.py b/vllm_mindspore/distributed/communication_op.py index c933dc4a..c3646e48 100644 --- a/vllm_mindspore/distributed/communication_op.py +++ b/vllm_mindspore/distributed/communication_op.py @@ -21,9 +21,12 @@ Implement a unified communication interface for both graph and 
pynative mode. """ -from mindspore import nn, ops +from typing import Optional + +from mindspore import Tensor, mint, nn, ops from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_world_size, get_tp_group) + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + get_tp_group) class ReduceFromModelParallelRegion(nn.Cell): @@ -64,3 +67,27 @@ class AllGatherFromModelParallelRegion(nn.Cell): output = self.all_gather_into_tensor(input_) output = ops.swapaxes(output, 0, -1) return output + + +class GatherFromModelParallelRegion(nn.Cell): + "Gather the input from model parallel region and concatenate." + + def __init__(self): + super().__init__() + self.world_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + if self.world_size > 1: + self.tp_group = get_tp_group().device_group._name + + def construct(self, + input_: Tensor, + dst: int = 0, + dim: int = -1) -> Optional[Tensor]: + # Size and dimension. + if self.world_size == 1: + return input_ + output = ops.CollectiveGather(dest_rank=dst, + group=self.tp_group)(mint.transpose(input_, 0, dim)) + if self.tp_rank != dst: + return None + return mint.transpose(output, 0, dim) diff --git a/vllm_mindspore/model_executor/layers/logits_processor.py b/vllm_mindspore/model_executor/layers/logits_processor.py index ee8c8edc..6910804a 100644 --- a/vllm_mindspore/model_executor/layers/logits_processor.py +++ b/vllm_mindspore/model_executor/layers/logits_processor.py @@ -23,12 +23,14 @@ from concurrent.futures import ThreadPoolExecutor from typing import Optional import vllm.envs as envs -from mindspore import Tensor, mint, nn -from vllm.config import current_platform +from mindspore import Tensor, jit, mint, nn +from vllm.config import current_platform, get_current_vllm_config from vllm.distributed import (tensor_model_parallel_all_gather, tensor_model_parallel_gather) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm_mindspore.distributed.communication_op import ( + AllGatherFromModelParallelRegion, GatherFromModelParallelRegion) from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -60,6 +62,9 @@ class LogitsProcessor(nn.Cell): scale: A scaling factor to apply to the logits. """ super().__init__() + vllm_config = get_current_vllm_config() + self.vllm_config = vllm_config + self.is_graph_mode = bool(not vllm_config.model_config.enforce_eager) self.scale = scale self.vocab_size = vocab_size # Whether the input is logits (default is hidden states). @@ -71,25 +76,101 @@ class LogitsProcessor(nn.Cell): # Whether to use gather or all-gather to gather the logits. 
self.use_all_gather = current_platform.use_all_gather() + if self.use_all_gather: + self.tensor_model_parallel_all_gather = AllGatherFromModelParallelRegion() + else: + self.tensor_model_parallel_gather = GatherFromModelParallelRegion() + self.lm_head = None + self.run_model = None + self.cached_input_info = {} + + def set_dynamic_inputs(self): + dyn_hidden_states = Tensor(shape=[None, None], + dtype=self.vllm_config.model_config.dtype) + + if self.cached_input_info["indices"] is None: + dyn_indices = None + else: + dyn_indices_shape = [ + None for _ in range(self.cached_input_info["indices"]["ndim"]) + ] + dyn_indices_dtype = self.cached_input_info["indices"]["dtype"] + dyn_indices = Tensor(shape=dyn_indices_shape, + dtype=dyn_indices_dtype) + + if self.cached_input_info["bias"] is None: + dyn_bias = None + else: + dyn_bias_shape = [ + None for _ in range(self.cached_input_info["bias"]["ndim"]) + ] + dyn_bias_dtype = self.cached_input_info["bias"]["dtype"] + dyn_bias = Tensor(shape=dyn_bias_shape, dtype=dyn_bias_dtype) + + self.set_inputs(dyn_hidden_states, dyn_indices, dyn_bias) + + def __call__( + self, + lm_head: VocabParallelEmbedding, + hidden_states: Tensor, + sampling_metadata: Optional[SamplingMetadata] = None, + embedding_bias: Optional[Tensor] = None, + ) -> Optional[Tensor]: + if self.lm_head is None: + self.lm_head = lm_head + if self.run_model is None: + self.run_model = jit( + function=self.construct, + jit_level='O0') if self.is_graph_mode else self.construct + selected_token_indices = None + if sampling_metadata is not None: + selected_token_indices = sampling_metadata.selected_token_indices + dyn_indices_info = None if selected_token_indices is None else { + "ndim": selected_token_indices.ndim, + "dtype": selected_token_indices.dtype, + } + dyn_bias_info = None if embedding_bias is None else { + "ndim": embedding_bias.ndim, + "dtype": embedding_bias.dtype, + } + if self.cached_input_info != {"indices": dyn_indices_info, + "bias": dyn_bias_info}: + self.cached_input_info = { + "indices": dyn_indices_info, + "bias": dyn_bias_info, + } + self.set_dynamic_inputs() + + logits = self.run_model( + hidden_states, + selected_token_indices, + embedding_bias + ) + + if sampling_metadata is not None and \ + sampling_metadata.seq_groups is not None: + logits = _apply_logits_processors(logits, sampling_metadata) + + return logits + def construct( self, - lm_head: VocabParallelEmbedding, hidden_states: Tensor, - sampling_metadata: Optional[SamplingMetadata] = None, + selected_token_indices: Optional[Tensor] = None, embedding_bias: Optional[Tensor] = None, ) -> Optional[Tensor]: if self.logits_as_input: logits = hidden_states else: - if sampling_metadata is not None: - if sampling_metadata.selected_token_indices.numel() <= 0: - return mint.zeros((0, self.vocab_size), - dtype=hidden_states.dtype) - hidden_states = _prune_hidden_states(hidden_states, - sampling_metadata) + if selected_token_indices is not None: + if selected_token_indices.numel() <= 0: + return mint.zeros((0, self.vocab_size), dtype=hidden_states.dtype) + hidden_states = mint.index_select( + hidden_states, 0, selected_token_indices) # Get the logits for the next tokens. - logits = self._get_logits(hidden_states, lm_head, embedding_bias) + logits = self._get_logits( + hidden_states, self.lm_head, embedding_bias) if logits is not None: if self.soft_cap is not None: logits = logits / self.soft_cap @@ -100,9 +181,6 @@ class LogitsProcessor(nn.Cell): logits *= self.scale # Apply logits processors (if any). 
- if sampling_metadata is not None and \ - sampling_metadata.seq_groups is not None: - logits = _apply_logits_processors(logits, sampling_metadata) return logits @@ -118,10 +196,10 @@ class LogitsProcessor(nn.Cell): bias=embedding_bias) if self.use_all_gather: # Gather is not supported for some devices such as NPUs. - logits = tensor_model_parallel_all_gather(logits) + logits = self.tensor_model_parallel_all_gather(logits) else: # None may be returned for rank > 0 - logits = tensor_model_parallel_gather(logits) + logits = self.tensor_model_parallel_gather(logits) # Remove paddings in vocab (if any). if logits is not None: logits = logits[..., :self.org_vocab_size] @@ -134,17 +212,6 @@ class LogitsProcessor(nn.Cell): return s -def _prune_hidden_states( - hidden_states: Tensor, - sampling_metadata: SamplingMetadata, -) -> Tensor: - indices = sampling_metadata.selected_token_indices - if indices is not None and indices.numel() > 0: - return mint.index_select(hidden_states, 0, - sampling_metadata.selected_token_indices) - return hidden_states - - def _apply_logits_processors( logits: Tensor, sampling_metadata: SamplingMetadata, -- Gitee From 251508dde304b7fc83dd99b4cd0e2460049c13fc Mon Sep 17 00:00:00 2001 From: zhanzhan1 Date: Tue, 5 Aug 2025 19:40:26 +0800 Subject: [PATCH 3/6] lm_head support jit --- .../model_executor/layers/logits_processor.py | 47 ++++++++----------- 1 file changed, 19 insertions(+), 28 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/logits_processor.py b/vllm_mindspore/model_executor/layers/logits_processor.py index 6910804a..1b1770cc 100644 --- a/vllm_mindspore/model_executor/layers/logits_processor.py +++ b/vllm_mindspore/model_executor/layers/logits_processor.py @@ -75,7 +75,7 @@ class LogitsProcessor(nn.Cell): self.soft_cap = soft_cap # Whether to use gather or all-gather to gather the logits. 
self.use_all_gather = current_platform.use_all_gather() - + if self.use_all_gather: self.tensor_model_parallel_all_gather = AllGatherFromModelParallelRegion() else: @@ -85,36 +85,28 @@ class LogitsProcessor(nn.Cell): self.cached_input_info = {} def set_dynamic_inputs(self): - dyn_hidden_states = Tensor(shape=[None, None], - dtype=self.vllm_config.model_config.dtype) - - if self.cached_input_info["indices"] is None: - dyn_indices = None - else: - dyn_indices_shape = [ - None for _ in range(self.cached_input_info["indices"]["ndim"]) - ] - dyn_indices_dtype = self.cached_input_info["indices"]["dtype"] - dyn_indices = Tensor(shape=dyn_indices_shape, - dtype=dyn_indices_dtype) - - if self.cached_input_info["bias"] is None: - dyn_bias = None - else: - dyn_bias_shape = [ - None for _ in range(self.cached_input_info["bias"]["ndim"]) - ] - dyn_bias_dtype = self.cached_input_info["bias"]["dtype"] - dyn_bias = Tensor(shape=dyn_bias_shape, dtype=dyn_bias_dtype) + dyn_hidden_states = Tensor( + shape=[None, None], dtype=self.vllm_config.model_config.dtype) + + dyn_indices_shape = [None for _ in range( + self.cached_input_info["indices"]["ndim"])] + dyn_indices_dtype = self.cached_input_info["indices"]["dtype"] + dyn_indices = None if self.cached_input_info["indices"] is None else \ + Tensor(shape=dyn_indices_shape, dtype=dyn_indices_dtype) + + dyn_bias_shape = [None for _ in range( + self.cached_input_info["bias"]["ndim"])] + dyn_bias_dtype = self.cached_input_info["bias"]["dtype"] + dyn_bias = None if self.cached_input_info["bias"] is None else \ + Tensor(shape=dyn_bias_shape, dtype=dyn_bias_dtype) self.set_inputs(dyn_hidden_states, dyn_indices, dyn_bias) def __call__( - self, - lm_head: VocabParallelEmbedding, - hidden_states: Tensor, - sampling_metadata: Optional[SamplingMetadata] = None, - embedding_bias: Optional[Tensor] = None, + self, + hidden_states: Tensor, + selected_token_indices: Optional[Tensor] = None, + embedding_bias: Optional[Tensor] = None, ) -> Optional[Tensor]: if self.lm_head is None: self.lm_head = lm_head @@ -181,7 +173,6 @@ class LogitsProcessor(nn.Cell): logits *= self.scale # Apply logits processors (if any). - return logits def _get_logits( -- Gitee From 332b5057bb2732dd8fb08cf092440382d4086c48 Mon Sep 17 00:00:00 2001 From: huoxinyou Date: Mon, 15 Sep 2025 09:58:29 +0800 Subject: [PATCH 4/6] fixed lm_head bug --- .../distributed/communication_op.py | 31 +------ .../model_executor/layers/linear.py | 9 ++- .../model_executor/layers/logits_processor.py | 80 ++++++++++--------- .../layers/vocab_parallel_embedding.py | 8 +- .../model_loader/weight_utils.py | 30 ++++--- .../model_executor/models/model_base.py | 3 +- 6 files changed, 72 insertions(+), 89 deletions(-) diff --git a/vllm_mindspore/distributed/communication_op.py b/vllm_mindspore/distributed/communication_op.py index c3646e48..c933dc4a 100644 --- a/vllm_mindspore/distributed/communication_op.py +++ b/vllm_mindspore/distributed/communication_op.py @@ -21,12 +21,9 @@ Implement a unified communication interface for both graph and pynative mode. 
""" -from typing import Optional - -from mindspore import Tensor, mint, nn, ops +from mindspore import nn, ops from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - get_tp_group) + get_tensor_model_parallel_world_size, get_tp_group) class ReduceFromModelParallelRegion(nn.Cell): @@ -67,27 +64,3 @@ class AllGatherFromModelParallelRegion(nn.Cell): output = self.all_gather_into_tensor(input_) output = ops.swapaxes(output, 0, -1) return output - - -class GatherFromModelParallelRegion(nn.Cell): - "Gather the input from model parallel region and concatenate." - - def __init__(self): - super().__init__() - self.world_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - if self.world_size > 1: - self.tp_group = get_tp_group().device_group._name - - def construct(self, - input_: Tensor, - dst: int = 0, - dim: int = -1) -> Optional[Tensor]: - # Size and dimension. - if self.world_size == 1: - return input_ - output = ops.CollectiveGather(dest_rank=dst, - group=self.tp_group)(mint.transpose(input_, 0, dim)) - if self.tp_rank != dst: - return None - return mint.transpose(output, 0, dim) diff --git a/vllm_mindspore/model_executor/layers/linear.py b/vllm_mindspore/model_executor/layers/linear.py index 28dabdd6..696890a4 100644 --- a/vllm_mindspore/model_executor/layers/linear.py +++ b/vllm_mindspore/model_executor/layers/linear.py @@ -169,7 +169,8 @@ class LinearBase(nn.Cell): self.param_load_counts[param.name] = current_count if current_count == merge_count: - cast_weight = ops.auto_generate.format_cast(param, FORMAT_TYPE['nz']) + cast_weight = ops.auto_generate.format_cast( + param, FORMAT_TYPE['nz']) param.set_data(cast_weight) del self.param_load_counts[param.name] @@ -366,7 +367,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): param[shard_offset:shard_offset + shard_size, :] = ms.from_numpy(loaded_weight) if is_310p() and param.name.endswith("weight"): - loaded_shard_num = 2 # gating/hidden + loaded_shard_num = 2 # gating/hidden self.format_to_nz(param, loaded_shard_num) @@ -477,7 +478,7 @@ class QKVParallelLinear(ColumnParallelLinear): assert loaded_weight.shape == param.shape param.set_data(loaded_weight) if is_310p() and param.name.endswith("weight"): - loaded_shard_num = 3 # q/k/v + loaded_shard_num = 3 # q/k/v self.format_to_nz(param, loaded_shard_num) return @@ -508,7 +509,7 @@ class QKVParallelLinear(ColumnParallelLinear): assert loaded_weight.shape == (shard_size, ) param[shard_offset:shard_offset + shard_size] = loaded_weight if is_310p() and param.name.endswith("weight"): - loaded_shard_num = 3 # q/k/v + loaded_shard_num = 3 # q/k/v self.format_to_nz(param, loaded_shard_num) diff --git a/vllm_mindspore/model_executor/layers/logits_processor.py b/vllm_mindspore/model_executor/layers/logits_processor.py index 1b1770cc..c14aac8f 100644 --- a/vllm_mindspore/model_executor/layers/logits_processor.py +++ b/vllm_mindspore/model_executor/layers/logits_processor.py @@ -25,12 +25,11 @@ from typing import Optional import vllm.envs as envs from mindspore import Tensor, jit, mint, nn from vllm.config import current_platform, get_current_vllm_config -from vllm.distributed import (tensor_model_parallel_all_gather, - tensor_model_parallel_gather) +from vllm.distributed import tensor_model_parallel_gather from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm_mindspore.distributed.communication_op import ( - AllGatherFromModelParallelRegion, 
GatherFromModelParallelRegion) + AllGatherFromModelParallelRegion) from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -75,37 +74,44 @@ class LogitsProcessor(nn.Cell): self.soft_cap = soft_cap # Whether to use gather or all-gather to gather the logits. self.use_all_gather = current_platform.use_all_gather() - + if self.use_all_gather: - self.tensor_model_parallel_all_gather = AllGatherFromModelParallelRegion() - else: - self.tensor_model_parallel_gather = GatherFromModelParallelRegion() + self.tensor_model_parallel_all_gather = \ + AllGatherFromModelParallelRegion() self.lm_head = None self.run_model = None self.cached_input_info = {} def set_dynamic_inputs(self): - dyn_hidden_states = Tensor( - shape=[None, None], dtype=self.vllm_config.model_config.dtype) - - dyn_indices_shape = [None for _ in range( - self.cached_input_info["indices"]["ndim"])] - dyn_indices_dtype = self.cached_input_info["indices"]["dtype"] - dyn_indices = None if self.cached_input_info["indices"] is None else \ - Tensor(shape=dyn_indices_shape, dtype=dyn_indices_dtype) - - dyn_bias_shape = [None for _ in range( - self.cached_input_info["bias"]["ndim"])] - dyn_bias_dtype = self.cached_input_info["bias"]["dtype"] - dyn_bias = None if self.cached_input_info["bias"] is None else \ - Tensor(shape=dyn_bias_shape, dtype=dyn_bias_dtype) + dyn_hidden_states = Tensor(shape=[None, None], + dtype=self.vllm_config.model_config.dtype) + + if self.cached_input_info["indices"] is None: + dyn_indices = None + else: + dyn_indices_shape = [ + None for _ in range(self.cached_input_info["indices"]["ndim"]) + ] + dyn_indices_dtype = self.cached_input_info["indices"]["dtype"] + dyn_indices = Tensor(shape=dyn_indices_shape, + dtype=dyn_indices_dtype) + + if self.cached_input_info["bias"] is None: + dyn_bias = None + else: + dyn_bias_shape = [ + None for _ in range(self.cached_input_info["bias"]["ndim"]) + ] + dyn_bias_dtype = self.cached_input_info["bias"]["dtype"] + dyn_bias = Tensor(shape=dyn_bias_shape, dtype=dyn_bias_dtype) self.set_inputs(dyn_hidden_states, dyn_indices, dyn_bias) def __call__( self, + lm_head: VocabParallelEmbedding, hidden_states: Tensor, - selected_token_indices: Optional[Tensor] = None, + sampling_metadata: Optional[SamplingMetadata] = None, embedding_bias: Optional[Tensor] = None, ) -> Optional[Tensor]: if self.lm_head is None: @@ -125,19 +131,23 @@ class LogitsProcessor(nn.Cell): "ndim": embedding_bias.ndim, "dtype": embedding_bias.dtype, } - if self.cached_input_info != {"indices": dyn_indices_info, - "bias": dyn_bias_info}: + if self.cached_input_info != { + "indices": dyn_indices_info, + "bias": dyn_bias_info + }: self.cached_input_info = { "indices": dyn_indices_info, "bias": dyn_bias_info, } self.set_dynamic_inputs() - logits = self.run_model( - hidden_states, - selected_token_indices, - embedding_bias - ) + if selected_token_indices is not None and selected_token_indices.numel( + ) <= 0: + logits = mint.zeros((0, self.vocab_size), + dtype=hidden_states.dtype) + else: + logits = self.run_model(hidden_states, selected_token_indices, + embedding_bias) if sampling_metadata is not None and \ sampling_metadata.seq_groups is not None: @@ -155,14 +165,12 @@ class LogitsProcessor(nn.Cell): logits = hidden_states else: if selected_token_indices is not None: - if selected_token_indices.numel() <= 0: - return mint.zeros((0, self.vocab_size), dtype=hidden_states.dtype) - hidden_states = mint.index_select( - hidden_states, 0, selected_token_indices) + hidden_states = 
mint.index_select(hidden_states, 0, + selected_token_indices) # Get the logits for the next tokens. - logits = self._get_logits( - hidden_states, self.lm_head, embedding_bias) + logits = self._get_logits(hidden_states, self.lm_head, + embedding_bias) if logits is not None: if self.soft_cap is not None: logits = logits / self.soft_cap @@ -190,7 +198,7 @@ class LogitsProcessor(nn.Cell): logits = self.tensor_model_parallel_all_gather(logits) else: # None may be returned for rank > 0 - logits = self.tensor_model_parallel_gather(logits) + logits = tensor_model_parallel_gather(logits) # Remove paddings in vocab (if any). if logits is not None: logits = logits[..., :self.org_vocab_size] diff --git a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py index dbf8e0db..9f8f80e3 100644 --- a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py @@ -38,7 +38,7 @@ from vllm_mindspore.model_executor.layers.quantization.base_config import ( from vllm_mindspore.model_executor.model_loader.weight_utils import ( split_loaded_weight) from vllm_mindspore.model_executor.utils import set_weight_attrs -from vllm_mindspore.utils import is_310p, FORMAT_TYPE +from vllm_mindspore.utils import FORMAT_TYPE, is_310p DEFAULT_VOCAB_PADDING_SIZE = 64 @@ -355,7 +355,8 @@ class VocabParallelEmbedding(nn.Cell): f"and {loaded_weight.shape}") param.set_data(ms.from_numpy(loaded_weight)) if need_nz: - param.set_data(ops.auto_generate.format_cast(param, FORMAT_TYPE['nz'])) + param.set_data( + ops.auto_generate.format_cast(param, FORMAT_TYPE['nz'])) return # Shard indexes for loading the weight @@ -373,7 +374,8 @@ class VocabParallelEmbedding(nn.Cell): param[:loaded_weight.shape[0]] = ms.from_numpy(loaded_weight) param[loaded_weight.shape[0]:] = 0 if need_nz: - param.set_data(ops.auto_generate.format_cast(param, FORMAT_TYPE['nz'])) + param.set_data( + ops.auto_generate.format_cast(param, FORMAT_TYPE['nz'])) class ParallelLMHead(VocabParallelEmbedding): diff --git a/vllm_mindspore/model_executor/model_loader/weight_utils.py b/vllm_mindspore/model_executor/model_loader/weight_utils.py index 9947387e..7a4ff7fd 100644 --- a/vllm_mindspore/model_executor/model_loader/weight_utils.py +++ b/vllm_mindspore/model_executor/model_loader/weight_utils.py @@ -20,14 +20,15 @@ from collections.abc import Generator from typing import Any -import numpy as np import mindspore as ms +import numpy as np from mindspore import Parameter from safetensors import safe_open from tqdm.auto import tqdm from vllm.model_executor.model_loader.weight_utils import (_BAR_FORMAT, enable_tqdm) + from vllm_mindspore.utils import is_310p @@ -41,11 +42,10 @@ def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): """ if shard_dim is None: loaded_weight = loaded_weight[:] - loaded_weight = ( - loaded_weight.astype(np.float16) - if ((str(loaded_weight.dtype) == "float32" or str(loaded_weight.dtype) == "bfloat16") and is_310p()) - else loaded_weight - ) + loaded_weight = (loaded_weight.astype(np.float16) if + ((str(loaded_weight.dtype) == "float32" + or str(loaded_weight.dtype) == "bfloat16") + and is_310p()) else loaded_weight) return loaded_weight end_idx = start_idx + shard_size @@ -57,11 +57,10 @@ def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): loaded_weight = loaded_weight[:, :, start_idx:end_idx] else: raise ValueError("shard_dim:{} is not 
supported.".format(shard_dim)) - loaded_weight = ( - loaded_weight.astype(np.float16) - if ((str(loaded_weight.dtype) == "float32" or str(loaded_weight.dtype) == "bfloat16") and is_310p()) - else loaded_weight - ) + loaded_weight = (loaded_weight.astype(np.float16) if + ((str(loaded_weight.dtype) == "float32" + or str(loaded_weight.dtype) == "bfloat16") + and is_310p()) else loaded_weight) return loaded_weight @@ -90,9 +89,8 @@ def safetensors_weights_iterator( def default_weight_loader(param: Parameter, loaded_weight: Any) -> None: """Default weight loader.""" loaded_weight = loaded_weight[:] - loaded_weight = ( - loaded_weight.astype(np.float16) - if ((str(loaded_weight.dtype) == "float32" or str(loaded_weight.dtype) == "bfloat16") and is_310p()) - else loaded_weight - ) + loaded_weight = (loaded_weight.astype(np.float16) if + ((str(loaded_weight.dtype) == "float32" + or str(loaded_weight.dtype) == "bfloat16") + and is_310p()) else loaded_weight) param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype)) diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index f1c616fa..cffb886f 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -34,7 +34,8 @@ from vllm_mindspore.model_executor.models.attention_mask import ( LowerTriangularMask) from vllm_mindspore.model_executor.models.utils import is_use_ringmla from vllm_mindspore.model_executor.utils import set_model_context -from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE, create_kv_cache, is_310p +from vllm_mindspore.utils import (STR_DTYPE_TO_MS_DTYPE, create_kv_cache, + is_310p) from vllm_mindspore.v1.attention.backends.ms_attn import MsAttentionMetadata -- Gitee From 3c7bbf05bfe2f4bb79d8948377b2584230de3883 Mon Sep 17 00:00:00 2001 From: huoxinyou Date: Tue, 16 Sep 2025 16:38:26 +0800 Subject: [PATCH 5/6] w8a8 310 --- vllm_mindspore/__init__.py | 40 ++++++ vllm_mindspore/engine/arg_utils.py | 123 +++++++++++++++++- .../model_executor/layers/linear.py | 33 +++-- .../layers/quantization/__init__.py | 49 +++++++ .../layers/quantization/base_config.py | 3 + .../model_loader/default_loader.py | 99 ++++++++++++++ .../model_loader/weight_utils.py | 109 +++++++++++++++- vllm_mindspore/model_executor/models/qwen2.py | 5 + 8 files changed, 443 insertions(+), 18 deletions(-) create mode 100644 vllm_mindspore/model_executor/model_loader/default_loader.py diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index dbd26f9b..100a61e7 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -551,6 +551,46 @@ sys.modules["vllm.entrypoints.openai.tool_parsers.deepseekv3_tool_parser"] = ( from vllm_mindspore.entrypoints.__main__ import ( patch_server_run_api_server_worker_proc, ) +from vllm_mindspore.model_executor.model_loader.utils import ( + process_weights_after_loading) + +vllm.model_executor.model_loader.utils.process_weights_after_loading = ( + process_weights_after_loading) +vllm.model_executor.model_loader.base_loader.process_weights_after_loading = ( + process_weights_after_loading) + +from vllm_mindspore.model_executor.layers.quantization import ( + get_quantization_config) + +vllm.model_executor.layers.quantization.get_quantization_config = ( + get_quantization_config) +vllm.config.get_quantization_config = get_quantization_config +vllm.model_executor.model_loader.weight_utils.get_quantization_config = ( + get_quantization_config) + +from 
vllm_mindspore.model_executor.model_loader.weight_utils import ( + get_quant_config) + +vllm.model_executor.model_loader.weight_utils.get_quant_config = ( + get_quant_config) +vllm.config.get_quant_config = get_quant_config + +from vllm_mindspore.model_executor.layers.quantization import ( + QuantizationMethods) + +vllm.model_executor.layers.quantization.QuantizationMethods = ( + QuantizationMethods) + +from vllm_mindspore.engine.arg_utils import get_kwargs + +vllm.engine.arg_utils.get_kwargs = get_kwargs + +from vllm_mindspore.model_executor.model_loader.default_loader import ( + _prepare_weights) +from vllm.model_executor.model_loader.default_loader import DefaultModelLoader + +DefaultModelLoader._prepare_weights = _prepare_weights + patch_server_run_api_server_worker_proc() from vllm_mindspore.model_executor.models.registry import _normalize_archs diff --git a/vllm_mindspore/engine/arg_utils.py b/vllm_mindspore/engine/arg_utils.py index f7a3d6aa..9803e158 100644 --- a/vllm_mindspore/engine/arg_utils.py +++ b/vllm_mindspore/engine/arg_utils.py @@ -19,19 +19,132 @@ # limitations under the License. """Adaption for arguments utils.""" +import argparse +import json import threading -from typing import get_args +from dataclasses import MISSING, fields, is_dataclass +from typing import Any, Literal, get_origin import torch import vllm.envs as envs -from vllm.config import (GuidedDecodingBackendV1, LoadFormat, ModelConfig, - ParallelConfig, SchedulerConfig) -from vllm.engine.arg_utils import (EngineArgs, _raise_or_fallback, - _warn_or_fallback) from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext logger = init_logger(__name__) +from pydantic import TypeAdapter, ValidationError +from vllm.config import (ConfigType, GuidedDecodingBackendV1, LoadFormat, + ModelConfig, ParallelConfig, SchedulerConfig) +from vllm.engine.arg_utils import (EngineArgs, TypeHint, _raise_or_fallback, + _warn_or_fallback, contains_type, get_args, + get_attr_docs, get_type, get_type_hints, + human_readable_int, is_not_builtin, + literal_to_kwargs, optional_type, + parse_type, union_dict_and_str) + +from vllm_mindspore.model_executor.layers.quantization import ( + QUANTIZATION_METHODS) + + +def get_kwargs(cls: ConfigType) -> dict[str, Any]: + cls_docs = get_attr_docs(cls) + kwargs = {} + for field in fields(cls): + type_hints: set[TypeHint] = get_type_hints(field.type) + + # If the field is a dataclass, we can use the model_validate_json + generator = (th for th in type_hints if is_dataclass(th)) + dataclass_cls = next(generator, None) + + # Get the default value of the field + if field.default is not MISSING: + default = field.default + elif field.default_factory is not MISSING: + default = field.default_factory() + + # Get the help text for the field + name = field.name + help = cls_docs[name].strip() + # Escape % for argparse + help = help.replace("%", "%%") + + # Initialise the kwargs dictionary for the field + kwargs[name] = {"default": default, "help": help} + + # Set other kwargs based on the type hints + json_tip = """\n\nShould either be a valid JSON string or JSON keys + passed individually. 
For example, the following sets of arguments are + equivalent:\n\n + - `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n + - `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n\n""" + if dataclass_cls is not None: + + def parse_dataclass(val: str, cls=dataclass_cls) -> Any: + try: + if hasattr(cls, "from_cli"): + return cls.from_cli(val) + return TypeAdapter(cls).validate_json(val) + except ValidationError as e: + raise argparse.ArgumentTypeError(repr(e)) from e + + kwargs[name]["type"] = parse_dataclass + kwargs[name]["help"] += json_tip + elif contains_type(type_hints, bool): + # Creates --no- and -- flags + kwargs[name]["action"] = argparse.BooleanOptionalAction + elif contains_type(type_hints, Literal): + kwargs[name].update(literal_to_kwargs(type_hints)) + elif contains_type(type_hints, tuple): + type_hint = get_type(type_hints, tuple) + types = get_args(type_hint) + tuple_type = types[0] + assert all(t is tuple_type for t in types if t is not Ellipsis), ( + "All non-Ellipsis tuple elements must be of the same " + f"type. Got {types}.") + kwargs[name]["type"] = tuple_type + kwargs[name]["nargs"] = "+" if Ellipsis in types else len(types) + elif contains_type(type_hints, list): + type_hint = get_type(type_hints, list) + types = get_args(type_hint) + assert len(types) == 1, ( + "List type must have exactly one type. Got " + f"{type_hint} with types {types}") + kwargs[name]["type"] = types[0] + kwargs[name]["nargs"] = "+" + elif contains_type(type_hints, int): + kwargs[name]["type"] = int + # Special case for large integers + if name in {"max_model_len", "max_num_batched_tokens"}: + kwargs[name]["type"] = human_readable_int + elif contains_type(type_hints, float): + kwargs[name]["type"] = float + elif (contains_type(type_hints, dict) + and (contains_type(type_hints, str) + or any(is_not_builtin(th) for th in type_hints))): + kwargs[name]["type"] = union_dict_and_str + elif contains_type(type_hints, dict): + kwargs[name]["type"] = parse_type(json.loads) + kwargs[name]["help"] += json_tip + elif (contains_type(type_hints, str) + or any(is_not_builtin(th) for th in type_hints)): + kwargs[name]["type"] = str + else: + raise ValueError( + f"Unsupported type {type_hints} for argument {name}.") + + # If the type hint was a sequence of literals, use the helper function + # to update the type and choices + if get_origin(kwargs[name].get("type")) is Literal: + kwargs[name].update(literal_to_kwargs({kwargs[name]["type"]})) + + # If None is in type_hints, make the argument optional. + # But not if it's a bool, argparse will handle this better. + if type(None) in type_hints and not contains_type(type_hints, bool): + kwargs[name]["type"] = optional_type(kwargs[name]["type"]) + if kwargs[name].get("choices"): + kwargs[name]["choices"].append("None") + if field.name == "quantization": + kwargs[name]["choices"] = QUANTIZATION_METHODS + return kwargs def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: diff --git a/vllm_mindspore/model_executor/layers/linear.py b/vllm_mindspore/model_executor/layers/linear.py index 696890a4..1589abab 100644 --- a/vllm_mindspore/model_executor/layers/linear.py +++ b/vllm_mindspore/model_executor/layers/linear.py @@ -18,7 +18,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """ Linear methods for quantized linear layers. 
""" - from abc import abstractmethod from ast import Dict from typing import Optional, Union @@ -354,21 +353,31 @@ class MergedColumnParallelLinear(ColumnParallelLinear): tp_size = get_tensor_model_parallel_world_size() shard_size = 0 shard_offset = 0 - if loaded_shard_id is not None: + if loaded_shard_id is None: + current_shard_offset = 0 + shard_offsets = [] + for i, output_size in enumerate(self.output_sizes): + shard_offsets.append((i, current_shard_offset, output_size)) + current_shard_offset += output_size + for shard_id, shard_offset, shard_size in shard_offsets: + loaded_weight_shard = split_loaded_weight( + loaded_weight, output_dim, shard_offset, shard_size) + self.weight_loader(param, loaded_weight_shard, shard_id) + else: assert loaded_shard_id < len(self.output_sizes) shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size shard_size = self.output_sizes[loaded_shard_id] // tp_size - start_idx = tp_rank * shard_size - loaded_weight = split_loaded_weight(loaded_weight, output_dim, - start_idx, shard_size) - - assert loaded_weight.shape == (shard_size, param.shape[1]) - param[shard_offset:shard_offset + - shard_size, :] = ms.from_numpy(loaded_weight) - if is_310p() and param.name.endswith("weight"): - loaded_shard_num = 2 # gating/hidden - self.format_to_nz(param, loaded_shard_num) + start_idx = tp_rank * shard_size + loaded_weight = split_loaded_weight(loaded_weight, output_dim, + start_idx, shard_size) + + assert loaded_weight.shape == (shard_size, param.shape[1]) + param[shard_offset:shard_offset + + shard_size, :] = ms.from_numpy(loaded_weight) + if is_310p() and param.name.endswith("weight"): + loaded_shard_num = 2 # gating/hidden + self.format_to_nz(param, loaded_shard_num) class QKVParallelLinear(ColumnParallelLinear): diff --git a/vllm_mindspore/model_executor/layers/quantization/__init__.py b/vllm_mindspore/model_executor/layers/quantization/__init__.py index e69de29b..6c9e2e41 100644 --- a/vllm_mindspore/model_executor/layers/quantization/__init__.py +++ b/vllm_mindspore/model_executor/layers/quantization/__init__.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from typing import Literal, get_args + +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + +QuantizationMethods = Literal["smoothquant"] +QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) + +# The customized quantization methods which will be added to this dict. 
+_CUSTOMIZED_METHOD_TO_QUANT_CONFIG = {} + + +def get_quantization_config(quantization: str) -> type[QuantizationConfig]: + if quantization not in QUANTIZATION_METHODS: + raise ValueError(f"Invalid quantization method: {quantization}") + + # lazy import to avoid triggering `torch.compile` too early + from .smooth_quant_modelslim import SmoothQuantModelSlimConfig + method_to_config: dict[str, type[QuantizationConfig]] = { + "smoothquant": SmoothQuantModelSlimConfig + } + # Update the `method_to_config` with customized quantization methods. + method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) + + return method_to_config[quantization] + + +__all__ = [ + "QuantizationConfig", "get_quantization_config", "QUANTIZATION_METHODS", + "QuantizationMethods" +] diff --git a/vllm_mindspore/model_executor/layers/quantization/base_config.py b/vllm_mindspore/model_executor/layers/quantization/base_config.py index 37144a43..5728702d 100644 --- a/vllm_mindspore/model_executor/layers/quantization/base_config.py +++ b/vllm_mindspore/model_executor/layers/quantization/base_config.py @@ -142,6 +142,9 @@ class QuantizationConfig(ABC): """ raise NotImplementedError + def get_cache_scale(self, name: str) -> Optional[str]: + return None + def method_has_implemented_embedding( method_class: type[QuantizeMethodBase]) -> bool: diff --git a/vllm_mindspore/model_executor/model_loader/default_loader.py b/vllm_mindspore/model_executor/model_loader/default_loader.py new file mode 100644 index 00000000..dbd6ea8b --- /dev/null +++ b/vllm_mindspore/model_executor/model_loader/default_loader.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: Apache-2.0 +import glob +import os +from typing import Optional + +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME +from vllm.config import LoadFormat +from vllm.distributed import get_tensor_model_parallel_rank +from vllm.model_executor.model_loader.weight_utils import ( + download_safetensors_index_file_from_hf, download_weights_from_hf, + filter_duplicate_safetensors_files, filter_files_not_needed_for_inference) + + +def _prepare_weights( + self, + model_name_or_path: str, + revision: Optional[str], + fall_back_to_pt: bool, + allow_patterns_overrides: Optional[list[str]], +) -> tuple[str, list[str], bool]: + """Prepare weights for the model. + + If the model is not local, it will be downloaded.""" + model_name_or_path = (self._maybe_download_from_modelscope( + model_name_or_path, revision) or model_name_or_path) + + is_local = os.path.isdir(model_name_or_path) + load_format = self.load_config.load_format + use_safetensors = False + index_file = SAFE_WEIGHTS_INDEX_NAME + # Some quantized models use .pt files for storing the weights. 
+ if load_format == LoadFormat.AUTO: + allow_patterns = ["*.safetensors", "*.bin"] + elif (load_format == LoadFormat.SAFETENSORS + or load_format == LoadFormat.FASTSAFETENSORS): + use_safetensors = True + allow_patterns = ["*.safetensors"] + elif load_format == LoadFormat.MISTRAL: + use_safetensors = True + allow_patterns = ["consolidated*.safetensors"] + index_file = "consolidated.safetensors.index.json" + elif load_format == LoadFormat.PT: + allow_patterns = ["*.pt"] + elif load_format == LoadFormat.NPCACHE: + allow_patterns = ["*.bin"] + else: + raise ValueError(f"Unknown load_format: {load_format}") + + if fall_back_to_pt: + allow_patterns += ["*.pt"] + + if allow_patterns_overrides is not None: + allow_patterns = allow_patterns_overrides + + if not is_local: + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + allow_patterns, + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + else: + hf_folder = model_name_or_path + hf_weights_files: list[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + if len(hf_weights_files) == 0: + tp_rank = get_tensor_model_parallel_rank() + hf_weights_files += glob.glob( + os.path.join(hf_folder, f"rank_{tp_rank}", pattern)) + if len(hf_weights_files) > 0: + if pattern == "*.safetensors": + use_safetensors = True + break + if use_safetensors: + # For models like Mistral-7B-Instruct-v0.3 + # there are both sharded safetensors files and a consolidated + # safetensors file. Using both breaks. + # Here, we download the `model.safetensors.index.json` and filter + # any files not found in the index. + if not is_local: + download_safetensors_index_file_from_hf( + model_name_or_path, + index_file, + self.load_config.download_dir, + revision, + ) + hf_weights_files = filter_duplicate_safetensors_files( + hf_weights_files, hf_folder, index_file) + else: + hf_weights_files = filter_files_not_needed_for_inference( + hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_folder, hf_weights_files, use_safetensors diff --git a/vllm_mindspore/model_executor/model_loader/weight_utils.py b/vllm_mindspore/model_executor/model_loader/weight_utils.py index 7a4ff7fd..4abd7ac8 100644 --- a/vllm_mindspore/model_executor/model_loader/weight_utils.py +++ b/vllm_mindspore/model_executor/model_loader/weight_utils.py @@ -17,15 +17,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +import glob +import json +import os from collections.abc import Generator from typing import Any +import huggingface_hub import mindspore as ms import numpy as np + +from huggingface_hub import snapshot_download from mindspore import Parameter from safetensors import safe_open from tqdm.auto import tqdm +from vllm.config import LoadConfig +from vllm.model_executor.model_loader.weight_utils import (DisabledTqdm, + get_lock) + +from vllm_mindspore.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm_mindspore.platforms.ascend import ModelConfig +from vllm_mindspore.utils import atlas_inference from vllm.model_executor.model_loader.weight_utils import (_BAR_FORMAT, enable_tqdm) @@ -94,3 +107,97 @@ def default_weight_loader(param: Parameter, loaded_weight: Any) -> None: or str(loaded_weight.dtype) == "bfloat16") and is_310p()) else loaded_weight) param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype)) + + +def get_quant_config(model_config: ModelConfig, + load_config: LoadConfig) -> QuantizationConfig: + + from vllm_mindspore.model_executor.layers.quantization import ( + get_quantization_config) + quant_cls = get_quantization_config(model_config.quantization) + + # GGUF doesn't have config file + if model_config.quantization == "gguf": + return quant_cls.from_config({}) + + # Read the quantization config from the HF model config, if available. + hf_quant_config = getattr(model_config.hf_config, "quantization_config", + None) + # some vision model may keep quantization_config in their text_config + hf_text_config = getattr(model_config.hf_config, "text_config", None) + if hf_quant_config is None and hf_text_config is not None: + hf_quant_config = getattr(hf_text_config, "quantization_config", None) + if hf_quant_config is None: + # compressed-tensors uses a compressions_config + hf_quant_config = getattr(model_config.hf_config, "compression_config", + None) + if hf_quant_config is not None: + if os.path.isdir(model_config.model): + quant_config_file = os.path.join( + model_config.model, + quant_cls.get_config_filenames()[0]) + with open(quant_config_file) as f: + quant_config = json.load(f) + return quant_cls.from_config(hf_quant_config | quant_config) + + # In case of bitsandbytes/QLoRA, get quant config from the adapter model. + if model_config.quantization == "bitsandbytes": + if (not load_config.model_loader_extra_config + or "qlora_adapter_name_or_path" + not in load_config.model_loader_extra_config): + return quant_cls.from_config({"adapter_name_or_path": ""}) + model_name_or_path = load_config.model_loader_extra_config[ + "qlora_adapter_name_or_path"] + + else: + model_name_or_path = model_config.model + is_local = os.path.isdir(model_name_or_path) + if not is_local: + # Download the config files. + with get_lock(model_name_or_path, load_config.download_dir): + hf_folder = snapshot_download( + model_name_or_path, + revision=model_config.revision, + allow_patterns="*.json", + cache_dir=load_config.download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + tqdm_class=DisabledTqdm, + ) + else: + hf_folder = model_name_or_path + + possible_config_filenames = quant_cls.get_config_filenames() + + # If the quantization config is not found, use the default config. 
+ if not possible_config_filenames: + return quant_cls() + + config_files = glob.glob(os.path.join(hf_folder, "*.json")) + + quant_config_files = [ + f for f in config_files if any( + f.endswith(x) for x in possible_config_filenames) + ] + if len(quant_config_files) == 0: + raise ValueError( + f"Cannot find the config file for {model_config.quantization}") + if len(quant_config_files) > 1: + raise ValueError( + f"Found multiple config files for {model_config.quantization}: " + f"{quant_config_files}") + + quant_config_file = quant_config_files[0] + with open(quant_config_file) as f: + config = json.load(f) + + if model_config.quantization == "bitsandbytes": + config["adapter_name_or_path"] = model_name_or_path + elif model_config.quantization == "modelopt": + if config["producer"]["name"] == "modelopt": + return quant_cls.from_config(config) + else: + raise ValueError( + f"Unsupported quantization config" + f" found for {model_config.quantization} in {f}.") + + return quant_cls.from_config(config) diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py index db8e31c0..9a901ba2 100644 --- a/vllm_mindspore/model_executor/models/qwen2.py +++ b/vllm_mindspore/model_executor/models/qwen2.py @@ -44,6 +44,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.interfaces import SupportsLoRA from vllm.sequence import IntermediateTensors from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.distributed import get_tensor_model_parallel_rank from vllm_mindspore.attention import Attention from vllm_mindspore.model_executor.layers.activation import SiluAndMul @@ -375,6 +376,10 @@ class Qwen2Model(nn.Cell): ] for name, loaded_weight in weights: + if get_tensor_model_parallel_rank( + ) > 0 and "o_proj.quant_bias" in name: + continue + if "rotary_emb.inv_freq" in name: continue if ("rotary_emb.cos_cached" in name -- Gitee From 470ddf7f2492aa7bb6d7366c72464b99b2e138e4 Mon Sep 17 00:00:00 2001 From: luolihao Date: Thu, 24 Jul 2025 19:08:26 +0800 Subject: [PATCH 6/6] support qwq w8a8sc --- .../layers/quantization/__init__.py | 9 +- .../quantization/sparse_quant_modelslim.py | 182 ++++++++++++++++++ vllm_mindspore/model_executor/models/qwen2.py | 50 ++++- 3 files changed, 238 insertions(+), 3 deletions(-) create mode 100644 vllm_mindspore/model_executor/layers/quantization/sparse_quant_modelslim.py diff --git a/vllm_mindspore/model_executor/layers/quantization/__init__.py b/vllm_mindspore/model_executor/layers/quantization/__init__.py index 6c9e2e41..3c6c2da9 100644 --- a/vllm_mindspore/model_executor/layers/quantization/__init__.py +++ b/vllm_mindspore/model_executor/layers/quantization/__init__.py @@ -21,7 +21,10 @@ from typing import Literal, get_args from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -QuantizationMethods = Literal["smoothquant"] +QuantizationMethods = Literal[ + "smoothquant", + "sparsequant" +] QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) # The customized quantization methods which will be added to this dict. 
@@ -34,8 +37,10 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: # lazy import to avoid triggering `torch.compile` too early from .smooth_quant_modelslim import SmoothQuantModelSlimConfig + from .sparse_quant_modelslim import SparseQuantModelSlimConfig method_to_config: dict[str, type[QuantizationConfig]] = { - "smoothquant": SmoothQuantModelSlimConfig + "smoothquant": SmoothQuantModelSlimConfig, + "sparsequant": SparseQuantModelSlimConfig } # Update the `method_to_config` with customized quantization methods. method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) diff --git a/vllm_mindspore/model_executor/layers/quantization/sparse_quant_modelslim.py b/vllm_mindspore/model_executor/layers/quantization/sparse_quant_modelslim.py new file mode 100644 index 00000000..f6ede5ed --- /dev/null +++ b/vllm_mindspore/model_executor/layers/quantization/sparse_quant_modelslim.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from typing import Any, Optional, Dict + +import torch +import numpy as np +import mindspore + +from mindspore.common.initializer import initializer +from mindspore import Parameter, ops, Tensor +from mindspore.ops.operations._infer_ops import QuantV2 +from mindspore.communication import get_rank +from vllm_mindspore.model_executor.layers.linear import LinearMethodBase, UnquantizedLinearMethod, LinearBase + +from .base_config import QuantizationConfig + + + +class SparseQuantModelSlimConfig(QuantizationConfig): + '''Config class for SparseQuant.''' + + def __init__( + self, + full_config: Dict[str, Any], + weight_bits: Optional[int] = 8, + group_size: Optional[int] = 1, + zero_point: Optional[bool] = True, + dynamic_quant: Optional[bool] = False, + kv_cache_bits: Optional[int] = 16, + modules_to_not_convert: Optional[list[str]] = None, + ) -> None: + super().__init__() + self.full_config = full_config + self.weight_bits = weight_bits + self.group_size = group_size + self.zero_point = zero_point + self.dynamic_quant = dynamic_quant + self.kv_cache_bits = kv_cache_bits + self.modules_to_not_convert = modules_to_not_convert or [] + + if self.weight_bits != 8: + raise ValueError( + "Currently, only 8-bit weight quantization is supported for " + f"A8W8SC, but got {self.weight_bits} bits.") + self.pack_factor = 8 // self.weight_bits + + def __repr__(self) -> str: + return (f"SparseConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"zero_point={self.zero_point}, " + f"modules_to_not_convert={self.modules_to_not_convert})") + + @staticmethod + def get_config_filenames() -> list[str]: + return [ + "quant_model_description.json" + ] + + @classmethod + def get_min_capability(cls) -> int: + """Minimum GPU capability to support the quantization method. + + E.g., 70 for Volta, 75 for Turing, 80 for Ampere. 
+        This requirement is due to the custom CUDA kernels used by the
+        quantization method.
+        """
+        return -1
+
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> "SparseQuantModelSlimConfig":
+        return cls(config)
+
+    def get_name(self) -> str:
+        return "SparseQuant"
+
+    def get_supported_act_dtypes(self) -> list[torch.dtype]:
+        return [torch.int8, torch.float16, torch.bfloat16]
+
+    def get_quant_method(self, layer: mindspore.nn.Cell,
+                         prefix: str) -> "QuantizeMethodBase":
+
+        rank_id = get_rank()
+        sparse_quant_description = self.full_config[f'rank_{rank_id}']
+        if isinstance(layer, LinearBase) and sparse_quant_description[f"{prefix}.weight"].lower() == "w8a8s":
+            compress_weight_size = sparse_quant_description[f"{prefix}.weight.shape"]
+            compress_index_size = sparse_quant_description[f"{prefix}.index.shape"]
+
+            return A8W8SCLinearMethod(self, compress_weight_size[0], compress_index_size[0])
+
+        return UnquantizedLinearMethod()
+
+
+class A8W8SCLinearMethod(LinearMethodBase):
+    '''Linear method for A8W8SCLinearMethod.'''
+
+    def __init__(self, quant_config: SparseQuantModelSlimConfig, compress_weight_size=None, compress_index_size=None):
+        self.quant_config = quant_config
+        self.compress_weight_size = compress_weight_size
+        self.compress_index_size = compress_index_size
+
+        self.quant = QuantV2()
+        self.linear_sparse = ops.auto_generate.QuantLinearSparse()
+
+    def create_weights(self,
+                       layer: mindspore.nn.Cell,
+                       input_size_per_partition: int,
+                       output_partition_sizes: list[int],
+                       input_size: int,
+                       output_size: int,
+                       params_dtype,
+                       is_group_mm=False,
+                       expert_num_per_partition=1,
+                       **extra_weight_attrs):
+        if input_size_per_partition % self.quant_config.group_size != 0:
+            raise ValueError(
+                "The input size is not aligned with the quantized "
+                "weight shape. This can be caused by too large "
+                "tensor parallel size.")
+
+        output_size_per_partition = sum(output_partition_sizes)
+        self.output_size_per_partition = output_size_per_partition
+        self.input_size_per_partition = input_size_per_partition
+        if output_size_per_partition % self.quant_config.pack_factor != 0:
+            raise ValueError(
+                "The output size is not aligned with the quantized "
+                "weight shape. This can be caused by too large "
+                "tensor parallel size.")
+
+        weight = Parameter(initializer('normal', (self.compress_weight_size), mindspore.int8), name="weight")
+        index = Parameter(initializer('normal', (self.compress_index_size), mindspore.int8), name="index")
+        deq_scale = Parameter(initializer('normal', (self.output_size_per_partition), mindspore.int64),
+                              name="deq_scale")
+        quant_bias = Parameter(initializer('zeros', (self.output_size_per_partition), mindspore.int32),
+                               name="quant_bias")
+        input_scale = Parameter(Tensor(np.ones(self.input_size_per_partition), mindspore.float16),
+                                name="input_scale")
+        input_offset = Parameter(Tensor(np.zeros(self.input_size_per_partition), mindspore.int8),
+                                 name="input_offset")
+
+        layer.insert_param_to_cell("weight", weight)
+        layer.insert_param_to_cell("index", index)
+        layer.insert_param_to_cell("deq_scale", deq_scale)
+        layer.insert_param_to_cell("quant_bias", quant_bias)
+        layer.insert_param_to_cell("input_scale", input_scale)
+        layer.insert_param_to_cell("input_offset", input_offset)
+
+    def apply(self,
+              layer: mindspore.nn.Cell,
+              x: mindspore.Tensor,
+              bias: mindspore.Parameter = None, group_list=None, cumsum_flag=False) -> mindspore.Tensor:
+        weight = layer.weight
+        index = layer.index
+        deq_scale = layer.deq_scale
+        quant_bias = layer.quant_bias
+        input_scale = layer.input_scale
+        input_offset = layer.input_offset
+
+        output_shape = x.shape[:-1] + (self.output_size_per_partition,)
+        x = x.reshape(-1, self.input_size_per_partition)
+
+        x = self.quant(x, input_scale, input_offset, False, "ROUND", mindspore.int8)
+        x = self.linear_sparse(x, weight, deq_scale, index, quant_bias)
+
+        x = x.reshape(output_shape)
+
+        return x
\ No newline at end of file
diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py
index 9a901ba2..e2ad9322 100644
--- a/vllm_mindspore/model_executor/models/qwen2.py
+++ b/vllm_mindspore/model_executor/models/qwen2.py
@@ -62,6 +62,7 @@ from vllm_mindspore.model_executor.models.model_base import (NativeModel)
 from vllm_mindspore.model_executor.models.utils import (
     PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers,
     maybe_prefix)
+from mindspore.communication.management import get_rank
 
 
 class Qwen2MLP(nn.Cell):
@@ -363,6 +364,50 @@ class Qwen2Model(nn.Cell):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    def load_split_weights(self, weights: Iterable[tuple[str, Tensor]],
+                           params_dict: dict[str, Parameter]):
+        weights_dict = dict(weights)
+
+        for name, loaded_weight in weights_dict.items():
+            if get_tensor_model_parallel_rank(
+            ) > 0 and "o_proj.quant_bias" in name:
+                continue
+
+            if name not in params_dict:
+                continue
+
+            param = params_dict[name]
+            param.set_data(loaded_weight.contiguous())
+
+        def adjust_weight(params_dict):
+            if not is_310p():
+                return
+
+            target_keywords = [
+                "qkv_proj.weight",
+                "o_proj.weight",
+                "gate_up_proj.weight",
+                "down_proj.weight",
+                # "lm_head.weight",
+            ]
+
+            rank_id = get_rank()
+            for name, param in params_dict.items():
+                if any(name.endswith(keyword) for keyword in target_keywords):
+                    weight_type = self.quant_config.full_config[f"rank_{rank_id}"][name]
+                    if weight_type.lower() == "w8a8s":
+                        # Compressed (w8a8s) weights do not need to be cast to NZ format.
+                        continue
+
+                    cast_weight = ops.auto_generate.format_cast(param, FORMAT_TYPE['nz'])
+                    ms.runtime.synchronize()
+                    param.set_data(cast_weight)
+
+        if is_310p():
+            ms.runtime.synchronize()
+            adjust_weight(params_dict)
+            ms.runtime.synchronize()
+
     def load_weights(self, weights: Iterable[tuple[str, Tensor]],
                      params_dict: dict[str, Parameter]):
         loaded_params: set[str] = set()
@@ -486,7 +531,10 @@ class Qwen2ForCausalLM(NativeModel, SupportsLoRA):
 
     def load_weights(self, weights: Iterable[tuple[str, Tensor]]) -> set[str]:
         params_dict = self.get_params_dict()
-        self.model.load_weights(weights, params_dict)
+        if self.vllm_config.model_config.quantization == "sparsequant":
+            self.model.load_split_weights(weights, params_dict)
+        else:
+            self.model.load_weights(weights, params_dict)
 
     def compute_logits(
         self,
--
Gitee
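The sparsequant path above keys everything off the per-rank entries in quant_model_description.json: SparseQuantModelSlimConfig.get_quant_method() looks up "<prefix>.weight" in full_config["rank_<id>"] and, for "w8a8s" layers, builds an A8W8SCLinearMethod from the recorded compressed-weight and index shapes. The sketch below is a minimal, MindSpore-free illustration of that dispatch under an assumed description layout; the layer prefixes and shape values are hypothetical and not taken from any real checkpoint.

# Illustrative sketch only: hypothetical per-rank description entries.
example_description = {
    "rank_0": {
        "model.layers.0.self_attn.qkv_proj.weight": "w8a8s",
        "model.layers.0.self_attn.qkv_proj.weight.shape": [123456],
        "model.layers.0.self_attn.qkv_proj.index.shape": [7890],
        "model.layers.0.mlp.down_proj.weight": "float16",
    }
}


def choose_linear_method(description, rank_id, prefix):
    """Mirror the dispatch in SparseQuantModelSlimConfig.get_quant_method():
    w8a8s layers take the compressed A8W8SC path, everything else falls back
    to the unquantized linear method."""
    per_rank = description[f"rank_{rank_id}"]
    if per_rank.get(f"{prefix}.weight", "").lower() == "w8a8s":
        weight_len = per_rank[f"{prefix}.weight.shape"][0]
        index_len = per_rank[f"{prefix}.index.shape"][0]
        return f"A8W8SCLinearMethod(weight_len={weight_len}, index_len={index_len})"
    return "UnquantizedLinearMethod()"


print(choose_linear_method(example_description, 0,
                           "model.layers.0.self_attn.qkv_proj"))
print(choose_linear_method(example_description, 0,
                           "model.layers.0.mlp.down_proj"))

Selecting quantization="sparsequant" in the model config then routes Qwen2ForCausalLM.load_weights() through load_split_weights(), which loads the per-rank pre-split checkpoint directly by parameter name and skips the NZ format cast for the already compressed w8a8s weights.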