diff --git a/.jenkins/test/config/dependent_packages.yaml b/.jenkins/test/config/dependent_packages.yaml
index d873cf813eaff4f2123c987792974cdf8314f664..44cc32dc90f0f57911e5042ed0b108a561fa213d 100644
--- a/.jenkins/test/config/dependent_packages.yaml
+++ b/.jenkins/test/config/dependent_packages.yaml
@@ -1,5 +1,5 @@
 mindspore:
-  'https://repo.mindspore.cn/mindspore/mindspore/version/202509/20250904/master_20250904010018_31badacf8c241fee892a5ae0e08c09720bfe4902_newest/'
+  'https://repo.mindspore.cn/mindspore/mindspore/version/202509/20250915/master_20250915160019_3868ae86368269272cb7950c6ce6d376f714e128_newest/'
 mindspore_gs:
   'https://repo.mindspore.cn/mindspore/golden-stick/version/202509/20250901/master_20250901221800_3e34fd43040b0c5d296e6bc1a82212deae3ee041_newest/'
diff --git a/ms_custom_ops b/ms_custom_ops
index 21ebc2f393ab61e9eba38cf19e563f5ce2eec7da..d6017d77a1fd7dc40c42bef98cee1e63de24a4c8 160000
--- a/ms_custom_ops
+++ b/ms_custom_ops
@@ -1 +1 @@
-Subproject commit 21ebc2f393ab61e9eba38cf19e563f5ce2eec7da
+Subproject commit d6017d77a1fd7dc40c42bef98cee1e63de24a4c8
diff --git a/vllm_mindspore/model_executor/layers/linear.py b/vllm_mindspore/model_executor/layers/linear.py
index 76518feaf84f072a916b1e5b7f08472dfa16211b..5757f64d9f383da0b59bfcdca680e33c6cded8bb 100644
--- a/vllm_mindspore/model_executor/layers/linear.py
+++ b/vllm_mindspore/model_executor/layers/linear.py
@@ -20,6 +20,7 @@
 """ Linear methods for quantized linear layers. """
 
 from abc import abstractmethod
+from typing import Dict
 from typing import Optional, Union
 
 import mindspore as ms
@@ -39,6 +40,7 @@ from vllm_mindspore.model_executor.layers.quantization.base_config import (
 from vllm_mindspore.model_executor.model_loader.weight_utils import (
     split_loaded_weight)
 from vllm_mindspore.model_executor.utils import set_weight_attrs
+from vllm_mindspore.utils import FORMAT_TYPE, is_310p
 
 WEIGHT_LOADER_V2_SUPPORTED = [
     "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
@@ -157,10 +159,21 @@ class LinearBase(nn.Cell):
         self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
         self.return_bias = return_bias
+        self.param_load_counts: Dict[str, int] = {}
 
     def construct(self, x: Tensor) -> Tensor:
         raise NotImplementedError
 
+    def format_to_nz(self, param, merge_count=1):
+        current_count = self.param_load_counts.get(param.name, 0) + 1
+        self.param_load_counts[param.name] = current_count
+
+        if current_count == merge_count:
+            cast_weight = ops.auto_generate.format_cast(
+                param, FORMAT_TYPE['nz'])
+            param.set_data(cast_weight)
+            del self.param_load_counts[param.name]
+
 
 class ColumnParallelLinear(LinearBase):
     """Linear layer with column parallelism.
@@ -281,6 +294,8 @@ class ColumnParallelLinear(LinearBase):
         assert param.shape == loaded_weight.shape
         param.set_data(ms.from_numpy(loaded_weight))
+        if is_310p() and param.name.endswith("weight"):
+            self.format_to_nz(param, merge_count=1)
 
 
 class MergedColumnParallelLinear(ColumnParallelLinear):
@@ -351,6 +366,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         assert loaded_weight.shape == (shard_size, param.shape[1])
         param[shard_offset:shard_offset +
               shard_size, :] = ms.from_numpy(loaded_weight)
+        if is_310p() and param.name.endswith("weight"):
+            loaded_shard_num = 2  # gating/hidden
+            self.format_to_nz(param, loaded_shard_num)
 
 
 class QKVParallelLinear(ColumnParallelLinear):
@@ -459,6 +477,9 @@ class QKVParallelLinear(ColumnParallelLinear):
                 assert loaded_weight.shape == param.shape
                 param.set_data(loaded_weight)
+            if is_310p() and param.name.endswith("weight"):
+                loaded_shard_num = 3  # q/k/v
+                self.format_to_nz(param, loaded_shard_num)
             return
 
         assert loaded_shard_id in ["q", "k", "v"]
@@ -487,6 +508,9 @@ class QKVParallelLinear(ColumnParallelLinear):
         if param.name.endswith("bias"):
             assert loaded_weight.shape == (shard_size, )
             param[shard_offset:shard_offset + shard_size] = loaded_weight
+        if is_310p() and param.name.endswith("weight"):
+            loaded_shard_num = 3  # q/k/v
+            self.format_to_nz(param, loaded_shard_num)
 
 
 class RowParallelLinear(LinearBase):
@@ -621,3 +645,5 @@ class RowParallelLinear(LinearBase):
 
         assert param.shape == loaded_weight.shape
         param.set_data(ms.from_numpy(loaded_weight))
+        if is_310p() and param.name.endswith("weight"):
+            self.format_to_nz(param, merge_count=1)
diff --git a/vllm_mindspore/model_executor/layers/logits_processor.py b/vllm_mindspore/model_executor/layers/logits_processor.py
index ee8c8edcf7fd6dd9243283ef6d139fae5430267d..c14aac8f3437788a3eb75b91a4f0e361d7ca8c3d 100644
--- a/vllm_mindspore/model_executor/layers/logits_processor.py
+++ b/vllm_mindspore/model_executor/layers/logits_processor.py
@@ -23,12 +23,13 @@
 from concurrent.futures import ThreadPoolExecutor
 from typing import Optional
 
 import vllm.envs as envs
-from mindspore import Tensor, mint, nn
-from vllm.config import current_platform
-from vllm.distributed import (tensor_model_parallel_all_gather,
-                              tensor_model_parallel_gather)
+from mindspore import Tensor, jit, mint, nn
+from vllm.config import current_platform, get_current_vllm_config
+from vllm.distributed import tensor_model_parallel_gather
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 
+from vllm_mindspore.distributed.communication_op import (
+    AllGatherFromModelParallelRegion)
 from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -60,6 +61,9 @@ class LogitsProcessor(nn.Cell):
             scale: A scaling factor to apply to the logits.
         """
         super().__init__()
+        vllm_config = get_current_vllm_config()
+        self.vllm_config = vllm_config
+        self.is_graph_mode = bool(not vllm_config.model_config.enforce_eager)
         self.scale = scale
         self.vocab_size = vocab_size
         # Whether the input is logits (default is hidden states).
@@ -71,25 +75,102 @@ class LogitsProcessor(nn.Cell):
         # Whether to use gather or all-gather to gather the logits.
         self.use_all_gather = current_platform.use_all_gather()
 
-    def construct(
+        if self.use_all_gather:
+            self.tensor_model_parallel_all_gather = \
+                AllGatherFromModelParallelRegion()
+        self.lm_head = None
+        self.run_model = None
+        self.cached_input_info = {}
+
+    def set_dynamic_inputs(self):
+        dyn_hidden_states = Tensor(shape=[None, None],
+                                   dtype=self.vllm_config.model_config.dtype)
+
+        if self.cached_input_info["indices"] is None:
+            dyn_indices = None
+        else:
+            dyn_indices_shape = [
+                None for _ in range(self.cached_input_info["indices"]["ndim"])
+            ]
+            dyn_indices_dtype = self.cached_input_info["indices"]["dtype"]
+            dyn_indices = Tensor(shape=dyn_indices_shape,
+                                 dtype=dyn_indices_dtype)
+
+        if self.cached_input_info["bias"] is None:
+            dyn_bias = None
+        else:
+            dyn_bias_shape = [
+                None for _ in range(self.cached_input_info["bias"]["ndim"])
+            ]
+            dyn_bias_dtype = self.cached_input_info["bias"]["dtype"]
+            dyn_bias = Tensor(shape=dyn_bias_shape, dtype=dyn_bias_dtype)
+
+        self.set_inputs(dyn_hidden_states, dyn_indices, dyn_bias)
+
+    def __call__(
         self,
         lm_head: VocabParallelEmbedding,
         hidden_states: Tensor,
         sampling_metadata: Optional[SamplingMetadata] = None,
         embedding_bias: Optional[Tensor] = None,
+    ) -> Optional[Tensor]:
+        if self.lm_head is None:
+            self.lm_head = lm_head
+        if self.run_model is None:
+            self.run_model = jit(
+                function=self.construct,
+                jit_level='O0') if self.is_graph_mode else self.construct
+        selected_token_indices = None
+        if sampling_metadata is not None:
+            selected_token_indices = sampling_metadata.selected_token_indices
+        dyn_indices_info = None if selected_token_indices is None else {
+            "ndim": selected_token_indices.ndim,
+            "dtype": selected_token_indices.dtype,
+        }
+        dyn_bias_info = None if embedding_bias is None else {
+            "ndim": embedding_bias.ndim,
+            "dtype": embedding_bias.dtype,
+        }
+        if self.cached_input_info != {
+                "indices": dyn_indices_info,
+                "bias": dyn_bias_info
+        }:
+            self.cached_input_info = {
+                "indices": dyn_indices_info,
+                "bias": dyn_bias_info,
+            }
+            self.set_dynamic_inputs()
+
+        if selected_token_indices is not None and selected_token_indices.numel(
+        ) <= 0:
+            logits = mint.zeros((0, self.vocab_size),
+                                dtype=hidden_states.dtype)
+        else:
+            logits = self.run_model(hidden_states, selected_token_indices,
+                                    embedding_bias)
+
+        if sampling_metadata is not None and \
+            sampling_metadata.seq_groups is not None:
+            logits = _apply_logits_processors(logits, sampling_metadata)
+
+        return logits
+
+    def construct(
+        self,
+        hidden_states: Tensor,
+        selected_token_indices: Optional[Tensor] = None,
+        embedding_bias: Optional[Tensor] = None,
     ) -> Optional[Tensor]:
         if self.logits_as_input:
             logits = hidden_states
         else:
-            if sampling_metadata is not None:
-                if sampling_metadata.selected_token_indices.numel() <= 0:
-                    return mint.zeros((0, self.vocab_size),
-                                      dtype=hidden_states.dtype)
-            hidden_states = _prune_hidden_states(hidden_states,
-                                                 sampling_metadata)
+            if selected_token_indices is not None:
+                hidden_states = mint.index_select(hidden_states, 0,
+                                                  selected_token_indices)
 
             # Get the logits for the next tokens.
-            logits = self._get_logits(hidden_states, lm_head, embedding_bias)
+            logits = self._get_logits(hidden_states, self.lm_head,
+                                      embedding_bias)
         if logits is not None:
             if self.soft_cap is not None:
                 logits = logits / self.soft_cap
@@ -100,10 +181,6 @@ class LogitsProcessor(nn.Cell):
             logits *= self.scale
 
         # Apply logits processors (if any).
-        if sampling_metadata is not None and \
-            sampling_metadata.seq_groups is not None:
-            logits = _apply_logits_processors(logits, sampling_metadata)
-
         return logits
 
     def _get_logits(
@@ -118,7 +195,7 @@ class LogitsProcessor(nn.Cell):
                                            bias=embedding_bias)
         if self.use_all_gather:
             # Gather is not supported for some devices such as NPUs.
-            logits = tensor_model_parallel_all_gather(logits)
+            logits = self.tensor_model_parallel_all_gather(logits)
         else:
             # None may be returned for rank > 0
             logits = tensor_model_parallel_gather(logits)
@@ -134,17 +211,6 @@ class LogitsProcessor(nn.Cell):
         return s
 
 
-def _prune_hidden_states(
-    hidden_states: Tensor,
-    sampling_metadata: SamplingMetadata,
-) -> Tensor:
-    indices = sampling_metadata.selected_token_indices
-    if indices is not None and indices.numel() > 0:
-        return mint.index_select(hidden_states, 0,
-                                 sampling_metadata.selected_token_indices)
-    return hidden_states
-
-
 def _apply_logits_processors(
     logits: Tensor,
     sampling_metadata: SamplingMetadata,
diff --git a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py
index abd9cb24ddb3e017bed685fbb19ef788dffdd833..39d259ffc57704903dad29a0f0e1d2f1eecba9a9 100644
--- a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py
+++ b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py
@@ -38,6 +38,7 @@ from vllm_mindspore.model_executor.layers.quantization.base_config import (
 from vllm_mindspore.model_executor.model_loader.weight_utils import (
     split_loaded_weight)
 from vllm_mindspore.model_executor.utils import set_weight_attrs
+from vllm_mindspore.utils import FORMAT_TYPE, is_310p
 
 DEFAULT_VOCAB_PADDING_SIZE = 64
 
@@ -339,6 +340,7 @@ class VocabParallelEmbedding(nn.Cell):
         return output
 
     def weight_loader(self, param: Parameter, loaded_weight: Tensor):
+        need_nz = is_310p() and param.name == "weight"
         output_dim = getattr(param, "output_dim", None)
         get_tensor_model_parallel_rank()
         # If parameter does not have output dim, then it should
@@ -352,6 +354,9 @@ class VocabParallelEmbedding(nn.Cell):
                     f"'loaded_weight.shape', but got {param.data.shape} "
                     f"and {loaded_weight.shape}")
             param.set_data(ms.from_numpy(loaded_weight))
+            if need_nz:
+                param.set_data(
+                    ops.auto_generate.format_cast(param, FORMAT_TYPE['nz']))
             return
 
         # Shard indexes for loading the weight
@@ -368,6 +373,9 @@ class VocabParallelEmbedding(nn.Cell):
 
         param[:loaded_weight.shape[0]] = ms.from_numpy(loaded_weight)
         param[loaded_weight.shape[0]:] = 0
+        if need_nz:
+            param.set_data(
+                ops.auto_generate.format_cast(param, FORMAT_TYPE['nz']))
 
 
 class ParallelLMHead(VocabParallelEmbedding):
diff --git a/vllm_mindspore/model_executor/model_loader/weight_utils.py b/vllm_mindspore/model_executor/model_loader/weight_utils.py
index 6bf2dd4cd7bf7c66665f9f66737238dc0d9083d2..7a4ff7fdd114215b941adabc9e867ebd2aeb7b83 100644
--- a/vllm_mindspore/model_executor/model_loader/weight_utils.py
+++ b/vllm_mindspore/model_executor/model_loader/weight_utils.py
@@ -22,12 +22,15 @@ from collections.abc import Generator
 from typing import Any
 
 import mindspore as ms
+import numpy as np
 from mindspore import Parameter
 from safetensors import safe_open
 from tqdm.auto import tqdm
 from vllm.model_executor.model_loader.weight_utils import (_BAR_FORMAT,
                                                            enable_tqdm)
 
+from vllm_mindspore.utils import is_310p
+
 
 def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size):
     """
@@ -39,6 +42,10 @@ def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size):
     """
     if shard_dim is None:
         loaded_weight = loaded_weight[:]
+        loaded_weight = (loaded_weight.astype(np.float16) if
+                         ((str(loaded_weight.dtype) == "float32"
+                           or str(loaded_weight.dtype) == "bfloat16")
+                          and is_310p()) else loaded_weight)
         return loaded_weight
 
     end_idx = start_idx + shard_size
@@ -50,6 +57,11 @@ def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size):
         loaded_weight = loaded_weight[:, :, start_idx:end_idx]
     else:
         raise ValueError("shard_dim:{} is not supported.".format(shard_dim))
+    loaded_weight = (loaded_weight.astype(np.float16) if
+                     ((str(loaded_weight.dtype) == "float32"
+                       or str(loaded_weight.dtype) == "bfloat16")
+                      and is_310p()) else loaded_weight)
+
     return loaded_weight
 
 
@@ -77,4 +89,8 @@ def safetensors_weights_iterator(
 def default_weight_loader(param: Parameter, loaded_weight: Any) -> None:
     """Default weight loader."""
     loaded_weight = loaded_weight[:]
+    loaded_weight = (loaded_weight.astype(np.float16) if
+                     ((str(loaded_weight.dtype) == "float32"
+                       or str(loaded_weight.dtype) == "bfloat16")
+                      and is_310p()) else loaded_weight)
     param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype))
diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py
index 524e31cec9cc2f35d1827ff5f94b795b60598c86..602a8314d359ccd6409aaf9740060c06dd324e9a 100644
--- a/vllm_mindspore/model_executor/models/model_base.py
+++ b/vllm_mindspore/model_executor/models/model_base.py
@@ -34,7 +34,8 @@ from vllm_mindspore.model_executor.models.attention_mask import (
     LowerTriangularMask)
 from vllm_mindspore.model_executor.models.utils import is_use_ringmla
 from vllm_mindspore.model_executor.utils import set_model_context
-from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE, create_kv_cache
+from vllm_mindspore.utils import (STR_DTYPE_TO_MS_DTYPE, create_kv_cache,
+                                  is_310p)
 from vllm_mindspore.v1.attention.backends.ms_attn import MsAttentionMetadata
 
@@ -446,7 +447,8 @@ class NativeModel(MsModelBase):
         block_size = self.cache_config.block_size
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
         head_size = self.model_config.get_head_size()
-        kv_cache_shape = (None, block_size, num_kv_heads, head_size)
+        kv_cache_shape = (None, block_size, num_kv_heads * head_size) \
+            if is_310p() else (None, block_size, num_kv_heads, head_size)
 
         kv_cache_dtype = (self.model_config.dtype
                          if self.cache_config.cache_dtype == "auto" else
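For reference, the shard-counting idea behind LinearBase.format_to_nz in the linear.py hunks above can be exercised in isolation: the parameter is cast to the NZ layout only after its last shard has been written, with merge_count mirroring the patch (1 for plain column/row linears, 2 for merged gating/hidden, 3 for fused q/k/v). The sketch below is illustrative only and not part of the patch; fake_format_cast is a hypothetical stand-in for ops.auto_generate.format_cast(param, FORMAT_TYPE['nz']) and uses nothing beyond the Python standard library.

from typing import Dict


def fake_format_cast(data: list) -> list:
    """Hypothetical stand-in for the real NZ-format cast on 310P devices."""
    return list(reversed(data))


class ShardCountingCaster:
    """Cast a parameter to NZ format only after its last shard is loaded."""

    def __init__(self) -> None:
        self.param_load_counts: Dict[str, int] = {}

    def format_to_nz(self, name: str, param: list, merge_count: int = 1) -> list:
        current_count = self.param_load_counts.get(name, 0) + 1
        self.param_load_counts[name] = current_count
        if current_count == merge_count:
            # The final shard has arrived: cast once, then drop the counter.
            param = fake_format_cast(param)
            del self.param_load_counts[name]
        return param


caster = ShardCountingCaster()
qkv_weight = ["q_shard", "k_shard", "v_shard"]
for _ in range(3):  # q/k/v shards arrive one at a time, so merge_count=3
    qkv_weight = caster.format_to_nz("qkv.weight", qkv_weight, merge_count=3)
print(qkv_weight)  # cast exactly once, after the third shard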