From bda8799f977dbfa024d6ef91211487b055cf6bee Mon Sep 17 00:00:00 2001
From: fan-jibin
Date: Mon, 15 Sep 2025 22:27:48 +0800
Subject: [PATCH 1/5] update ms_custom_ops

---
 .jenkins/test/config/dependent_packages.yaml | 2 +-
 ms_custom_ops                                | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/.jenkins/test/config/dependent_packages.yaml b/.jenkins/test/config/dependent_packages.yaml
index d873cf81..44cc32dc 100644
--- a/.jenkins/test/config/dependent_packages.yaml
+++ b/.jenkins/test/config/dependent_packages.yaml
@@ -1,5 +1,5 @@
 mindspore:
-  'https://repo.mindspore.cn/mindspore/mindspore/version/202509/20250904/master_20250904010018_31badacf8c241fee892a5ae0e08c09720bfe4902_newest/'
+  'https://repo.mindspore.cn/mindspore/mindspore/version/202509/20250915/master_20250915160019_3868ae86368269272cb7950c6ce6d376f714e128_newest/'
 mindspore_gs:
   'https://repo.mindspore.cn/mindspore/golden-stick/version/202509/20250901/master_20250901221800_3e34fd43040b0c5d296e6bc1a82212deae3ee041_newest/'
diff --git a/ms_custom_ops b/ms_custom_ops
index 21ebc2f3..d6017d77 160000
--- a/ms_custom_ops
+++ b/ms_custom_ops
@@ -1 +1 @@
-Subproject commit 21ebc2f393ab61e9eba38cf19e563f5ce2eec7da
+Subproject commit d6017d77a1fd7dc40c42bef98cee1e63de24a4c8
-- 
Gitee

From c04f5baaa06e5a6a2590d37db40bdc2c714f511f Mon Sep 17 00:00:00 2001
From: huoxinyou
Date: Mon, 15 Sep 2025 09:39:28 +0800
Subject: [PATCH 2/5] 310p weight_loader need format_cast('nz')

---
 .../model_executor/layers/linear.py           | 25 +++++++++++++++++++
 .../layers/vocab_parallel_embedding.py        |  6 +++++
 .../model_loader/weight_utils.py              | 18 +++++++++++++
 .../model_executor/models/model_base.py       |  5 ++--
 4 files changed, 52 insertions(+), 2 deletions(-)

diff --git a/vllm_mindspore/model_executor/layers/linear.py b/vllm_mindspore/model_executor/layers/linear.py
index 76518fea..369ac118 100644
--- a/vllm_mindspore/model_executor/layers/linear.py
+++ b/vllm_mindspore/model_executor/layers/linear.py
@@ -20,6 +20,7 @@
 """ Linear methods for quantized linear layers. """
 
 from abc import abstractmethod
+from typing import Dict
 from typing import Optional, Union
 
 import mindspore as ms
@@ -39,6 +40,7 @@ from vllm_mindspore.model_executor.layers.quantization.base_config import (
 from vllm_mindspore.model_executor.model_loader.weight_utils import (
     split_loaded_weight)
 from vllm_mindspore.model_executor.utils import set_weight_attrs
+from vllm_mindspore.utils import FORMAT_TYPE, is_310p
 
 WEIGHT_LOADER_V2_SUPPORTED = [
     "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
@@ -157,10 +159,20 @@ class LinearBase(nn.Cell):
         self.quant_method = quant_config.get_quant_method(self,
                                                           prefix=prefix)
         self.return_bias = return_bias
+        self.param_load_counts: Dict[str, int] = {}
 
     def construct(self, x: Tensor) -> Tensor:
         raise NotImplementedError
 
+    def format_to_nz(self, param, merge_count=1):
+        current_count = self.param_load_counts.get(param.name, 0) + 1
+        self.param_load_counts[param.name] = current_count
+
+        if current_count == merge_count:
+            cast_weight = ops.auto_generate.format_cast(param, FORMAT_TYPE['nz'])
+            param.set_data(cast_weight)
+            del self.param_load_counts[param.name]
+
 
 class ColumnParallelLinear(LinearBase):
     """Linear layer with column parallelism.
@@ -281,6 +293,8 @@ class ColumnParallelLinear(LinearBase): assert param.shape == loaded_weight.shape param.set_data(ms.from_numpy(loaded_weight)) + if is_310p() and param.name.endswith("weight"): + self.format_to_nz(param, merge_count=1) class MergedColumnParallelLinear(ColumnParallelLinear): @@ -351,6 +365,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear): assert loaded_weight.shape == (shard_size, param.shape[1]) param[shard_offset:shard_offset + shard_size, :] = ms.from_numpy(loaded_weight) + if is_310p() and param.name.endswith("weight"): + loaded_shard_num = 2 # gating/hidden + self.format_to_nz(param, loaded_shard_num) class QKVParallelLinear(ColumnParallelLinear): @@ -459,6 +476,9 @@ class QKVParallelLinear(ColumnParallelLinear): assert loaded_weight.shape == param.shape param.set_data(loaded_weight) + if is_310p() and param.name.endswith("weight"): + loaded_shard_num = 3 # q/k/v + self.format_to_nz(param, loaded_shard_num) return assert loaded_shard_id in ["q", "k", "v"] @@ -487,6 +507,9 @@ class QKVParallelLinear(ColumnParallelLinear): if param.name.endswith("bias"): assert loaded_weight.shape == (shard_size, ) param[shard_offset:shard_offset + shard_size] = loaded_weight + if is_310p() and param.name.endswith("weight"): + loaded_shard_num = 3 # q/k/v + self.format_to_nz(param, loaded_shard_num) class RowParallelLinear(LinearBase): @@ -621,3 +644,5 @@ class RowParallelLinear(LinearBase): assert param.shape == loaded_weight.shape param.set_data(ms.from_numpy(loaded_weight)) + if is_310p() and param.name.endswith("weight"): + self.format_to_nz(param, merge_count=1) diff --git a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py index abd9cb24..607722b1 100644 --- a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py @@ -38,6 +38,7 @@ from vllm_mindspore.model_executor.layers.quantization.base_config import ( from vllm_mindspore.model_executor.model_loader.weight_utils import ( split_loaded_weight) from vllm_mindspore.model_executor.utils import set_weight_attrs +from vllm_mindspore.utils import is_310p, FORMAT_TYPE DEFAULT_VOCAB_PADDING_SIZE = 64 @@ -339,6 +340,7 @@ class VocabParallelEmbedding(nn.Cell): return output def weight_loader(self, param: Parameter, loaded_weight: Tensor): + need_nz = is_310p() and param.name == "weight" output_dim = getattr(param, "output_dim", None) get_tensor_model_parallel_rank() # If parameter does not have output dim, then it should @@ -352,6 +354,8 @@ class VocabParallelEmbedding(nn.Cell): f"'loaded_weight.shape', but got {param.data.shape} " f"and {loaded_weight.shape}") param.set_data(ms.from_numpy(loaded_weight)) + if need_nz: + param.set_data(ops.auto_generate.format_cast(param, FORMAT_TYPE['nz'])) return # Shard indexes for loading the weight @@ -368,6 +372,8 @@ class VocabParallelEmbedding(nn.Cell): param[:loaded_weight.shape[0]] = ms.from_numpy(loaded_weight) param[loaded_weight.shape[0]:] = 0 + if need_nz: + param.set_data(ops.auto_generate.format_cast(param, FORMAT_TYPE['nz'])) class ParallelLMHead(VocabParallelEmbedding): diff --git a/vllm_mindspore/model_executor/model_loader/weight_utils.py b/vllm_mindspore/model_executor/model_loader/weight_utils.py index 6bf2dd4c..9947387e 100644 --- a/vllm_mindspore/model_executor/model_loader/weight_utils.py +++ b/vllm_mindspore/model_executor/model_loader/weight_utils.py @@ -20,6 +20,7 @@ from collections.abc import Generator 
from typing import Any +import numpy as np import mindspore as ms from mindspore import Parameter @@ -27,6 +28,7 @@ from safetensors import safe_open from tqdm.auto import tqdm from vllm.model_executor.model_loader.weight_utils import (_BAR_FORMAT, enable_tqdm) +from vllm_mindspore.utils import is_310p def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): @@ -39,6 +41,11 @@ def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): """ if shard_dim is None: loaded_weight = loaded_weight[:] + loaded_weight = ( + loaded_weight.astype(np.float16) + if ((str(loaded_weight.dtype) == "float32" or str(loaded_weight.dtype) == "bfloat16") and is_310p()) + else loaded_weight + ) return loaded_weight end_idx = start_idx + shard_size @@ -50,6 +57,12 @@ def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): loaded_weight = loaded_weight[:, :, start_idx:end_idx] else: raise ValueError("shard_dim:{} is not supported.".format(shard_dim)) + loaded_weight = ( + loaded_weight.astype(np.float16) + if ((str(loaded_weight.dtype) == "float32" or str(loaded_weight.dtype) == "bfloat16") and is_310p()) + else loaded_weight + ) + return loaded_weight @@ -77,4 +90,9 @@ def safetensors_weights_iterator( def default_weight_loader(param: Parameter, loaded_weight: Any) -> None: """Default weight loader.""" loaded_weight = loaded_weight[:] + loaded_weight = ( + loaded_weight.astype(np.float16) + if ((str(loaded_weight.dtype) == "float32" or str(loaded_weight.dtype) == "bfloat16") and is_310p()) + else loaded_weight + ) param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype)) diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index 524e31ce..b5f9acb4 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -34,7 +34,7 @@ from vllm_mindspore.model_executor.models.attention_mask import ( LowerTriangularMask) from vllm_mindspore.model_executor.models.utils import is_use_ringmla from vllm_mindspore.model_executor.utils import set_model_context -from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE, create_kv_cache +from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE, create_kv_cache, is_310p from vllm_mindspore.v1.attention.backends.ms_attn import MsAttentionMetadata @@ -446,7 +446,8 @@ class NativeModel(MsModelBase): block_size = self.cache_config.block_size num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) head_size = self.model_config.get_head_size() - kv_cache_shape = (None, block_size, num_kv_heads, head_size) + kv_cache_shape = (None, block_size, num_kv_heads * head_size) \ + if is_310p() else (None, block_size, num_kv_heads, head_size) kv_cache_dtype = (self.model_config.dtype if self.cache_config.cache_dtype == "auto" else -- Gitee From 80828f57e3a11025d02310707c691a63a7982897 Mon Sep 17 00:00:00 2001 From: luolihao Date: Wed, 13 Aug 2025 14:55:00 +0800 Subject: [PATCH 3/5] fix bug and support lm_head to Nz --- .../distributed/communication_op.py | 31 ++++- .../model_executor/layers/logits_processor.py | 121 ++++++++++++++---- 2 files changed, 123 insertions(+), 29 deletions(-) diff --git a/vllm_mindspore/distributed/communication_op.py b/vllm_mindspore/distributed/communication_op.py index c933dc4a..c3646e48 100644 --- a/vllm_mindspore/distributed/communication_op.py +++ b/vllm_mindspore/distributed/communication_op.py @@ -21,9 +21,12 @@ Implement a unified communication interface for both graph and 
pynative mode. """ -from mindspore import nn, ops +from typing import Optional + +from mindspore import Tensor, mint, nn, ops from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_world_size, get_tp_group) + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + get_tp_group) class ReduceFromModelParallelRegion(nn.Cell): @@ -64,3 +67,27 @@ class AllGatherFromModelParallelRegion(nn.Cell): output = self.all_gather_into_tensor(input_) output = ops.swapaxes(output, 0, -1) return output + + +class GatherFromModelParallelRegion(nn.Cell): + "Gather the input from model parallel region and concatenate." + + def __init__(self): + super().__init__() + self.world_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + if self.world_size > 1: + self.tp_group = get_tp_group().device_group._name + + def construct(self, + input_: Tensor, + dst: int = 0, + dim: int = -1) -> Optional[Tensor]: + # Size and dimension. + if self.world_size == 1: + return input_ + output = ops.CollectiveGather(dest_rank=dst, + group=self.tp_group)(mint.transpose(input_, 0, dim)) + if self.tp_rank != dst: + return None + return mint.transpose(output, 0, dim) diff --git a/vllm_mindspore/model_executor/layers/logits_processor.py b/vllm_mindspore/model_executor/layers/logits_processor.py index ee8c8edc..6910804a 100644 --- a/vllm_mindspore/model_executor/layers/logits_processor.py +++ b/vllm_mindspore/model_executor/layers/logits_processor.py @@ -23,12 +23,14 @@ from concurrent.futures import ThreadPoolExecutor from typing import Optional import vllm.envs as envs -from mindspore import Tensor, mint, nn -from vllm.config import current_platform +from mindspore import Tensor, jit, mint, nn +from vllm.config import current_platform, get_current_vllm_config from vllm.distributed import (tensor_model_parallel_all_gather, tensor_model_parallel_gather) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm_mindspore.distributed.communication_op import ( + AllGatherFromModelParallelRegion, GatherFromModelParallelRegion) from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -60,6 +62,9 @@ class LogitsProcessor(nn.Cell): scale: A scaling factor to apply to the logits. """ super().__init__() + vllm_config = get_current_vllm_config() + self.vllm_config = vllm_config + self.is_graph_mode = bool(not vllm_config.model_config.enforce_eager) self.scale = scale self.vocab_size = vocab_size # Whether the input is logits (default is hidden states). @@ -71,25 +76,101 @@ class LogitsProcessor(nn.Cell): # Whether to use gather or all-gather to gather the logits. 
self.use_all_gather = current_platform.use_all_gather() + if self.use_all_gather: + self.tensor_model_parallel_all_gather = AllGatherFromModelParallelRegion() + else: + self.tensor_model_parallel_gather = GatherFromModelParallelRegion() + self.lm_head = None + self.run_model = None + self.cached_input_info = {} + + def set_dynamic_inputs(self): + dyn_hidden_states = Tensor(shape=[None, None], + dtype=self.vllm_config.model_config.dtype) + + if self.cached_input_info["indices"] is None: + dyn_indices = None + else: + dyn_indices_shape = [ + None for _ in range(self.cached_input_info["indices"]["ndim"]) + ] + dyn_indices_dtype = self.cached_input_info["indices"]["dtype"] + dyn_indices = Tensor(shape=dyn_indices_shape, + dtype=dyn_indices_dtype) + + if self.cached_input_info["bias"] is None: + dyn_bias = None + else: + dyn_bias_shape = [ + None for _ in range(self.cached_input_info["bias"]["ndim"]) + ] + dyn_bias_dtype = self.cached_input_info["bias"]["dtype"] + dyn_bias = Tensor(shape=dyn_bias_shape, dtype=dyn_bias_dtype) + + self.set_inputs(dyn_hidden_states, dyn_indices, dyn_bias) + + def __call__( + self, + lm_head: VocabParallelEmbedding, + hidden_states: Tensor, + sampling_metadata: Optional[SamplingMetadata] = None, + embedding_bias: Optional[Tensor] = None, + ) -> Optional[Tensor]: + if self.lm_head is None: + self.lm_head = lm_head + if self.run_model is None: + self.run_model = jit( + function=self.construct, + jit_level='O0') if self.is_graph_mode else self.construct + selected_token_indices = None + if sampling_metadata is not None: + selected_token_indices = sampling_metadata.selected_token_indices + dyn_indices_info = None if selected_token_indices is None else { + "ndim": selected_token_indices.ndim, + "dtype": selected_token_indices.dtype, + } + dyn_bias_info = None if embedding_bias is None else { + "ndim": embedding_bias.ndim, + "dtype": embedding_bias.dtype, + } + if self.cached_input_info != {"indices": dyn_indices_info, + "bias": dyn_bias_info}: + self.cached_input_info = { + "indices": dyn_indices_info, + "bias": dyn_bias_info, + } + self.set_dynamic_inputs() + + logits = self.run_model( + hidden_states, + selected_token_indices, + embedding_bias + ) + + if sampling_metadata is not None and \ + sampling_metadata.seq_groups is not None: + logits = _apply_logits_processors(logits, sampling_metadata) + + return logits + def construct( self, - lm_head: VocabParallelEmbedding, hidden_states: Tensor, - sampling_metadata: Optional[SamplingMetadata] = None, + selected_token_indices: Optional[Tensor] = None, embedding_bias: Optional[Tensor] = None, ) -> Optional[Tensor]: if self.logits_as_input: logits = hidden_states else: - if sampling_metadata is not None: - if sampling_metadata.selected_token_indices.numel() <= 0: - return mint.zeros((0, self.vocab_size), - dtype=hidden_states.dtype) - hidden_states = _prune_hidden_states(hidden_states, - sampling_metadata) + if selected_token_indices is not None: + if selected_token_indices.numel() <= 0: + return mint.zeros((0, self.vocab_size), dtype=hidden_states.dtype) + hidden_states = mint.index_select( + hidden_states, 0, selected_token_indices) # Get the logits for the next tokens. - logits = self._get_logits(hidden_states, lm_head, embedding_bias) + logits = self._get_logits( + hidden_states, self.lm_head, embedding_bias) if logits is not None: if self.soft_cap is not None: logits = logits / self.soft_cap @@ -100,9 +181,6 @@ class LogitsProcessor(nn.Cell): logits *= self.scale # Apply logits processors (if any). 
- if sampling_metadata is not None and \ - sampling_metadata.seq_groups is not None: - logits = _apply_logits_processors(logits, sampling_metadata) return logits @@ -118,10 +196,10 @@ class LogitsProcessor(nn.Cell): bias=embedding_bias) if self.use_all_gather: # Gather is not supported for some devices such as NPUs. - logits = tensor_model_parallel_all_gather(logits) + logits = self.tensor_model_parallel_all_gather(logits) else: # None may be returned for rank > 0 - logits = tensor_model_parallel_gather(logits) + logits = self.tensor_model_parallel_gather(logits) # Remove paddings in vocab (if any). if logits is not None: logits = logits[..., :self.org_vocab_size] @@ -134,17 +212,6 @@ class LogitsProcessor(nn.Cell): return s -def _prune_hidden_states( - hidden_states: Tensor, - sampling_metadata: SamplingMetadata, -) -> Tensor: - indices = sampling_metadata.selected_token_indices - if indices is not None and indices.numel() > 0: - return mint.index_select(hidden_states, 0, - sampling_metadata.selected_token_indices) - return hidden_states - - def _apply_logits_processors( logits: Tensor, sampling_metadata: SamplingMetadata, -- Gitee From 142f39f83942f02052f44123933f880f6998b4cc Mon Sep 17 00:00:00 2001 From: zhanzhan1 Date: Tue, 5 Aug 2025 19:40:26 +0800 Subject: [PATCH 4/5] lm_head support jit --- .../model_executor/layers/logits_processor.py | 47 ++++++++----------- 1 file changed, 19 insertions(+), 28 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/logits_processor.py b/vllm_mindspore/model_executor/layers/logits_processor.py index 6910804a..1b1770cc 100644 --- a/vllm_mindspore/model_executor/layers/logits_processor.py +++ b/vllm_mindspore/model_executor/layers/logits_processor.py @@ -75,7 +75,7 @@ class LogitsProcessor(nn.Cell): self.soft_cap = soft_cap # Whether to use gather or all-gather to gather the logits. 
self.use_all_gather = current_platform.use_all_gather() - + if self.use_all_gather: self.tensor_model_parallel_all_gather = AllGatherFromModelParallelRegion() else: @@ -85,36 +85,28 @@ class LogitsProcessor(nn.Cell): self.cached_input_info = {} def set_dynamic_inputs(self): - dyn_hidden_states = Tensor(shape=[None, None], - dtype=self.vllm_config.model_config.dtype) - - if self.cached_input_info["indices"] is None: - dyn_indices = None - else: - dyn_indices_shape = [ - None for _ in range(self.cached_input_info["indices"]["ndim"]) - ] - dyn_indices_dtype = self.cached_input_info["indices"]["dtype"] - dyn_indices = Tensor(shape=dyn_indices_shape, - dtype=dyn_indices_dtype) - - if self.cached_input_info["bias"] is None: - dyn_bias = None - else: - dyn_bias_shape = [ - None for _ in range(self.cached_input_info["bias"]["ndim"]) - ] - dyn_bias_dtype = self.cached_input_info["bias"]["dtype"] - dyn_bias = Tensor(shape=dyn_bias_shape, dtype=dyn_bias_dtype) + dyn_hidden_states = Tensor( + shape=[None, None], dtype=self.vllm_config.model_config.dtype) + + dyn_indices_shape = [None for _ in range( + self.cached_input_info["indices"]["ndim"])] + dyn_indices_dtype = self.cached_input_info["indices"]["dtype"] + dyn_indices = None if self.cached_input_info["indices"] is None else \ + Tensor(shape=dyn_indices_shape, dtype=dyn_indices_dtype) + + dyn_bias_shape = [None for _ in range( + self.cached_input_info["bias"]["ndim"])] + dyn_bias_dtype = self.cached_input_info["bias"]["dtype"] + dyn_bias = None if self.cached_input_info["bias"] is None else \ + Tensor(shape=dyn_bias_shape, dtype=dyn_bias_dtype) self.set_inputs(dyn_hidden_states, dyn_indices, dyn_bias) def __call__( - self, - lm_head: VocabParallelEmbedding, - hidden_states: Tensor, - sampling_metadata: Optional[SamplingMetadata] = None, - embedding_bias: Optional[Tensor] = None, + self, + hidden_states: Tensor, + selected_token_indices: Optional[Tensor] = None, + embedding_bias: Optional[Tensor] = None, ) -> Optional[Tensor]: if self.lm_head is None: self.lm_head = lm_head @@ -181,7 +173,6 @@ class LogitsProcessor(nn.Cell): logits *= self.scale # Apply logits processors (if any). - return logits def _get_logits( -- Gitee From 580b6b6d50761cc580c19b386cdada08e37a1545 Mon Sep 17 00:00:00 2001 From: huoxinyou Date: Mon, 15 Sep 2025 09:58:29 +0800 Subject: [PATCH 5/5] fixed lm_head bug --- .../distributed/communication_op.py | 31 +------ .../model_executor/layers/linear.py | 9 ++- .../model_executor/layers/logits_processor.py | 80 ++++++++++--------- .../layers/vocab_parallel_embedding.py | 8 +- .../model_loader/weight_utils.py | 30 ++++--- .../model_executor/models/model_base.py | 3 +- 6 files changed, 72 insertions(+), 89 deletions(-) diff --git a/vllm_mindspore/distributed/communication_op.py b/vllm_mindspore/distributed/communication_op.py index c3646e48..c933dc4a 100644 --- a/vllm_mindspore/distributed/communication_op.py +++ b/vllm_mindspore/distributed/communication_op.py @@ -21,12 +21,9 @@ Implement a unified communication interface for both graph and pynative mode. 
""" -from typing import Optional - -from mindspore import Tensor, mint, nn, ops +from mindspore import nn, ops from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - get_tp_group) + get_tensor_model_parallel_world_size, get_tp_group) class ReduceFromModelParallelRegion(nn.Cell): @@ -67,27 +64,3 @@ class AllGatherFromModelParallelRegion(nn.Cell): output = self.all_gather_into_tensor(input_) output = ops.swapaxes(output, 0, -1) return output - - -class GatherFromModelParallelRegion(nn.Cell): - "Gather the input from model parallel region and concatenate." - - def __init__(self): - super().__init__() - self.world_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - if self.world_size > 1: - self.tp_group = get_tp_group().device_group._name - - def construct(self, - input_: Tensor, - dst: int = 0, - dim: int = -1) -> Optional[Tensor]: - # Size and dimension. - if self.world_size == 1: - return input_ - output = ops.CollectiveGather(dest_rank=dst, - group=self.tp_group)(mint.transpose(input_, 0, dim)) - if self.tp_rank != dst: - return None - return mint.transpose(output, 0, dim) diff --git a/vllm_mindspore/model_executor/layers/linear.py b/vllm_mindspore/model_executor/layers/linear.py index 369ac118..5757f64d 100644 --- a/vllm_mindspore/model_executor/layers/linear.py +++ b/vllm_mindspore/model_executor/layers/linear.py @@ -169,7 +169,8 @@ class LinearBase(nn.Cell): self.param_load_counts[param.name] = current_count if current_count == merge_count: - cast_weight = ops.auto_generate.format_cast(param, FORMAT_TYPE['nz']) + cast_weight = ops.auto_generate.format_cast( + param, FORMAT_TYPE['nz']) param.set_data(cast_weight) del self.param_load_counts[param.name] @@ -366,7 +367,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): param[shard_offset:shard_offset + shard_size, :] = ms.from_numpy(loaded_weight) if is_310p() and param.name.endswith("weight"): - loaded_shard_num = 2 # gating/hidden + loaded_shard_num = 2 # gating/hidden self.format_to_nz(param, loaded_shard_num) @@ -477,7 +478,7 @@ class QKVParallelLinear(ColumnParallelLinear): assert loaded_weight.shape == param.shape param.set_data(loaded_weight) if is_310p() and param.name.endswith("weight"): - loaded_shard_num = 3 # q/k/v + loaded_shard_num = 3 # q/k/v self.format_to_nz(param, loaded_shard_num) return @@ -508,7 +509,7 @@ class QKVParallelLinear(ColumnParallelLinear): assert loaded_weight.shape == (shard_size, ) param[shard_offset:shard_offset + shard_size] = loaded_weight if is_310p() and param.name.endswith("weight"): - loaded_shard_num = 3 # q/k/v + loaded_shard_num = 3 # q/k/v self.format_to_nz(param, loaded_shard_num) diff --git a/vllm_mindspore/model_executor/layers/logits_processor.py b/vllm_mindspore/model_executor/layers/logits_processor.py index 1b1770cc..c14aac8f 100644 --- a/vllm_mindspore/model_executor/layers/logits_processor.py +++ b/vllm_mindspore/model_executor/layers/logits_processor.py @@ -25,12 +25,11 @@ from typing import Optional import vllm.envs as envs from mindspore import Tensor, jit, mint, nn from vllm.config import current_platform, get_current_vllm_config -from vllm.distributed import (tensor_model_parallel_all_gather, - tensor_model_parallel_gather) +from vllm.distributed import tensor_model_parallel_gather from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm_mindspore.distributed.communication_op import ( - AllGatherFromModelParallelRegion, 
GatherFromModelParallelRegion) + AllGatherFromModelParallelRegion) from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -75,37 +74,44 @@ class LogitsProcessor(nn.Cell): self.soft_cap = soft_cap # Whether to use gather or all-gather to gather the logits. self.use_all_gather = current_platform.use_all_gather() - + if self.use_all_gather: - self.tensor_model_parallel_all_gather = AllGatherFromModelParallelRegion() - else: - self.tensor_model_parallel_gather = GatherFromModelParallelRegion() + self.tensor_model_parallel_all_gather = \ + AllGatherFromModelParallelRegion() self.lm_head = None self.run_model = None self.cached_input_info = {} def set_dynamic_inputs(self): - dyn_hidden_states = Tensor( - shape=[None, None], dtype=self.vllm_config.model_config.dtype) - - dyn_indices_shape = [None for _ in range( - self.cached_input_info["indices"]["ndim"])] - dyn_indices_dtype = self.cached_input_info["indices"]["dtype"] - dyn_indices = None if self.cached_input_info["indices"] is None else \ - Tensor(shape=dyn_indices_shape, dtype=dyn_indices_dtype) - - dyn_bias_shape = [None for _ in range( - self.cached_input_info["bias"]["ndim"])] - dyn_bias_dtype = self.cached_input_info["bias"]["dtype"] - dyn_bias = None if self.cached_input_info["bias"] is None else \ - Tensor(shape=dyn_bias_shape, dtype=dyn_bias_dtype) + dyn_hidden_states = Tensor(shape=[None, None], + dtype=self.vllm_config.model_config.dtype) + + if self.cached_input_info["indices"] is None: + dyn_indices = None + else: + dyn_indices_shape = [ + None for _ in range(self.cached_input_info["indices"]["ndim"]) + ] + dyn_indices_dtype = self.cached_input_info["indices"]["dtype"] + dyn_indices = Tensor(shape=dyn_indices_shape, + dtype=dyn_indices_dtype) + + if self.cached_input_info["bias"] is None: + dyn_bias = None + else: + dyn_bias_shape = [ + None for _ in range(self.cached_input_info["bias"]["ndim"]) + ] + dyn_bias_dtype = self.cached_input_info["bias"]["dtype"] + dyn_bias = Tensor(shape=dyn_bias_shape, dtype=dyn_bias_dtype) self.set_inputs(dyn_hidden_states, dyn_indices, dyn_bias) def __call__( self, + lm_head: VocabParallelEmbedding, hidden_states: Tensor, - selected_token_indices: Optional[Tensor] = None, + sampling_metadata: Optional[SamplingMetadata] = None, embedding_bias: Optional[Tensor] = None, ) -> Optional[Tensor]: if self.lm_head is None: @@ -125,19 +131,23 @@ class LogitsProcessor(nn.Cell): "ndim": embedding_bias.ndim, "dtype": embedding_bias.dtype, } - if self.cached_input_info != {"indices": dyn_indices_info, - "bias": dyn_bias_info}: + if self.cached_input_info != { + "indices": dyn_indices_info, + "bias": dyn_bias_info + }: self.cached_input_info = { "indices": dyn_indices_info, "bias": dyn_bias_info, } self.set_dynamic_inputs() - logits = self.run_model( - hidden_states, - selected_token_indices, - embedding_bias - ) + if selected_token_indices is not None and selected_token_indices.numel( + ) <= 0: + logits = mint.zeros((0, self.vocab_size), + dtype=hidden_states.dtype) + else: + logits = self.run_model(hidden_states, selected_token_indices, + embedding_bias) if sampling_metadata is not None and \ sampling_metadata.seq_groups is not None: @@ -155,14 +165,12 @@ class LogitsProcessor(nn.Cell): logits = hidden_states else: if selected_token_indices is not None: - if selected_token_indices.numel() <= 0: - return mint.zeros((0, self.vocab_size), dtype=hidden_states.dtype) - hidden_states = mint.index_select( - hidden_states, 0, selected_token_indices) + hidden_states = 
mint.index_select(hidden_states, 0, + selected_token_indices) # Get the logits for the next tokens. - logits = self._get_logits( - hidden_states, self.lm_head, embedding_bias) + logits = self._get_logits(hidden_states, self.lm_head, + embedding_bias) if logits is not None: if self.soft_cap is not None: logits = logits / self.soft_cap @@ -190,7 +198,7 @@ class LogitsProcessor(nn.Cell): logits = self.tensor_model_parallel_all_gather(logits) else: # None may be returned for rank > 0 - logits = self.tensor_model_parallel_gather(logits) + logits = tensor_model_parallel_gather(logits) # Remove paddings in vocab (if any). if logits is not None: logits = logits[..., :self.org_vocab_size] diff --git a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py index 607722b1..39d259ff 100644 --- a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py @@ -38,7 +38,7 @@ from vllm_mindspore.model_executor.layers.quantization.base_config import ( from vllm_mindspore.model_executor.model_loader.weight_utils import ( split_loaded_weight) from vllm_mindspore.model_executor.utils import set_weight_attrs -from vllm_mindspore.utils import is_310p, FORMAT_TYPE +from vllm_mindspore.utils import FORMAT_TYPE, is_310p DEFAULT_VOCAB_PADDING_SIZE = 64 @@ -355,7 +355,8 @@ class VocabParallelEmbedding(nn.Cell): f"and {loaded_weight.shape}") param.set_data(ms.from_numpy(loaded_weight)) if need_nz: - param.set_data(ops.auto_generate.format_cast(param, FORMAT_TYPE['nz'])) + param.set_data( + ops.auto_generate.format_cast(param, FORMAT_TYPE['nz'])) return # Shard indexes for loading the weight @@ -373,7 +374,8 @@ class VocabParallelEmbedding(nn.Cell): param[:loaded_weight.shape[0]] = ms.from_numpy(loaded_weight) param[loaded_weight.shape[0]:] = 0 if need_nz: - param.set_data(ops.auto_generate.format_cast(param, FORMAT_TYPE['nz'])) + param.set_data( + ops.auto_generate.format_cast(param, FORMAT_TYPE['nz'])) class ParallelLMHead(VocabParallelEmbedding): diff --git a/vllm_mindspore/model_executor/model_loader/weight_utils.py b/vllm_mindspore/model_executor/model_loader/weight_utils.py index 9947387e..7a4ff7fd 100644 --- a/vllm_mindspore/model_executor/model_loader/weight_utils.py +++ b/vllm_mindspore/model_executor/model_loader/weight_utils.py @@ -20,14 +20,15 @@ from collections.abc import Generator from typing import Any -import numpy as np import mindspore as ms +import numpy as np from mindspore import Parameter from safetensors import safe_open from tqdm.auto import tqdm from vllm.model_executor.model_loader.weight_utils import (_BAR_FORMAT, enable_tqdm) + from vllm_mindspore.utils import is_310p @@ -41,11 +42,10 @@ def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): """ if shard_dim is None: loaded_weight = loaded_weight[:] - loaded_weight = ( - loaded_weight.astype(np.float16) - if ((str(loaded_weight.dtype) == "float32" or str(loaded_weight.dtype) == "bfloat16") and is_310p()) - else loaded_weight - ) + loaded_weight = (loaded_weight.astype(np.float16) if + ((str(loaded_weight.dtype) == "float32" + or str(loaded_weight.dtype) == "bfloat16") + and is_310p()) else loaded_weight) return loaded_weight end_idx = start_idx + shard_size @@ -57,11 +57,10 @@ def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): loaded_weight = loaded_weight[:, :, start_idx:end_idx] else: raise ValueError("shard_dim:{} is not 
supported.".format(shard_dim)) - loaded_weight = ( - loaded_weight.astype(np.float16) - if ((str(loaded_weight.dtype) == "float32" or str(loaded_weight.dtype) == "bfloat16") and is_310p()) - else loaded_weight - ) + loaded_weight = (loaded_weight.astype(np.float16) if + ((str(loaded_weight.dtype) == "float32" + or str(loaded_weight.dtype) == "bfloat16") + and is_310p()) else loaded_weight) return loaded_weight @@ -90,9 +89,8 @@ def safetensors_weights_iterator( def default_weight_loader(param: Parameter, loaded_weight: Any) -> None: """Default weight loader.""" loaded_weight = loaded_weight[:] - loaded_weight = ( - loaded_weight.astype(np.float16) - if ((str(loaded_weight.dtype) == "float32" or str(loaded_weight.dtype) == "bfloat16") and is_310p()) - else loaded_weight - ) + loaded_weight = (loaded_weight.astype(np.float16) if + ((str(loaded_weight.dtype) == "float32" + or str(loaded_weight.dtype) == "bfloat16") + and is_310p()) else loaded_weight) param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype)) diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index b5f9acb4..602a8314 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -34,7 +34,8 @@ from vllm_mindspore.model_executor.models.attention_mask import ( LowerTriangularMask) from vllm_mindspore.model_executor.models.utils import is_use_ringmla from vllm_mindspore.model_executor.utils import set_model_context -from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE, create_kv_cache, is_310p +from vllm_mindspore.utils import (STR_DTYPE_TO_MS_DTYPE, create_kv_cache, + is_310p) from vllm_mindspore.v1.attention.backends.ms_attn import MsAttentionMetadata -- Gitee