diff --git a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py index d13e4351c8252134be298c3172b01fe331e13ec9..ad935053fcf4c8732c6e422ec2c1fa77028cb68a 100644 --- a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py +++ b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py @@ -74,9 +74,9 @@ class MfModelBase(MsModelBase): self.mf_config.model.model_config.parallel_config.model_parallel = ( get_tensor_model_parallel_world_size()) self.mf_config.model.model_config.parallel_config.pipeline_stage = 1 - self.use_mla_op = \ - bool(vllm_config.additional_config - and vllm_config.additional_config.get('use_mla_op') == 1) + self.use_ringmla = vllm_config.model_config.quantization is not None \ + and vllm_config.parallel_config.tensor_parallel_size < 16 + self.is_chunked = False self._generate_model_config() if not hasattr(self, 'mf_model_config'): raise RuntimeError('mf_model_config not initialized') diff --git a/vllm_mindspore/model_executor/models/mf_models/mindformers.py b/vllm_mindspore/model_executor/models/mf_models/mindformers.py index e9676cb91d7279d6e7a095efbbf622931a08cfc2..bc85552675df4a8bc336f4a05587841f43459a75 100644 --- a/vllm_mindspore/model_executor/models/mf_models/mindformers.py +++ b/vllm_mindspore/model_executor/models/mf_models/mindformers.py @@ -1,382 +1,398 @@ -# SPDX-License-Identifier: Apache-2.0 - -# Copyright 2025 Huawei Technologies Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import os -from collections.abc import Iterable -from typing import Optional, Union - -import mindspore as ms -import numpy as np -from mindformers import AutoModel, PreTrainedModel -from mindformers.core.context import build_mf_context -from mindformers.parallel_core.process_group_config import ( - default_model_comm_pgs) -from mindformers.tools.utils import is_pynative -from mindspore import Tensor, mutable, ops -from mindspore.common.api import _pynative_executor -from mindspore.nn.utils import no_init_parameters -from vllm import envs -from vllm.config import VllmConfig, get_current_vllm_config -from vllm.distributed.parallel_state import get_dp_group, get_pp_group -from vllm.forward_context import get_forward_context -from vllm.logger import init_logger -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler -from vllm.model_executor.models.interfaces import SupportsPP -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors - -from vllm_mindspore.model_executor.models.attention_mask import ( - LowerTriangularMask, MLALowerTriangularMask) -from vllm_mindspore.model_executor.models.mf_models.config import gen_mf_config -from vllm_mindspore.model_executor.models.model_base import ( - AttentionWrapper, MLAAttentionWrapper, MsModelBase) -from vllm_mindspore.model_executor.models.utils import ( - make_empty_intermediate_tensors_factory) -from vllm_mindspore.utils import is_310p - -logger = init_logger(__name__) - - -class MindFormersForCausalLM(MsModelBase, SupportsPP): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: - super().__init__(vllm_config=vllm_config, prefix=prefix) - self.set_flags = False - self.model_config = vllm_config.model_config - self.lm_head_graph = None - - mf_config = gen_mf_config(vllm_config) - mf_config.load_checkpoint = self.get_model_path() - mf_config.pretrained_model_dir = self.get_model_path() - self.mf_config = mf_config - self.mla_config = self.mf_config.get('model', None).get( - 'model_config', None).get('multi_latent_attention', False) - - build_mf_context(self.mf_config) - - self.network, self.lm_head = self._create_network() - self.casual_mask = self._create_mask() - - self._set_dynamic_inputs() - - self.sampler = get_sampler() - self.set_modules({"model": self.network}) - - num_layers = self.model_config.get_num_layers(self.parallel_config) - self.kv_caches = self._create_kv_caches(num_layers) - compilation_config = get_current_vllm_config().compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - for i in range(num_layers): - compilation_config.static_forward_context[str( - i)] = self.kv_caches[i] - - self.make_empty_intermediate_tensors = \ - make_empty_intermediate_tensors_factory( - keys=["hidden_states"], - hidden_size=self.model_config.hf_config.hidden_size) - - self.cast = ops.Cast() - - def _set_dynamic_inputs(self): - self.network.set_dynamic_inputs() - dynamic_hidden_states = Tensor(shape=[None, None], - dtype=self.network.compute_dtype) - if get_pp_group().is_last_rank: - self.lm_head.set_inputs(dynamic_hidden_states) - - def _create_mask(self): - # Initial mask - mask_func = (MLALowerTriangularMask - if self.mla_config else LowerTriangularMask) - return mask_func(dtype=self.network.compute_dtype, - max_model_len=self.model_config.max_model_len) - - def _create_kv_caches(self, num_layers): - # Initial kv_caches - wrapper_func = (MLAAttentionWrapper - if 
self.mla_config else AttentionWrapper) - return [wrapper_func() for _ in range(num_layers)] - - def get_kvcache(self): - if not self.mla_config: - return super().get_kvcache() - - key_cache = [] - forward_context = get_forward_context() - for i in range(self.config.num_hidden_layers): - k_cache = self.kv_caches[i].kv_cache[ - forward_context.virtual_engine][0] - key_cache.append(k_cache) - return mutable(key_cache), None - - def _get_padding_index(self, q_seq_len): - """ - Calculate the padding index used in the mixed parallel scenario. - Case 1: When data_parallel_size equals 1, no padding operation - required, returns None. - Case 2: When data_parallel_size equals expert_parallel_size and - model_parallel equals 1, all_to_all communication is applied, - no padding operation required, returns None. - Case 3: In other DP enabled scenarios, calculate the corresponding - padding index based on the query sequence lengths processed - by each DP domain. - - e.g. DP2 TP4 MoE_EP2 - +------------------+------------------------+------------------------+ - | DP domain | DP0 | DP1 | - +------------------+------------------------+------------------------+ - | q_seq_len | 3 | 5 | - +------------------+------------------------+------------------------+ - | attn_padding_idx | [0,1,2,0,0,0,0,0] | [0,1,2,3,4,0,0,0] | - +------------------+------------------------+------------------------+ - |attn_unpadding_idx| [0,1,2,8,9,10,11,12] | - +------------------+------------------------+------------------------+ - | ffn_padding_idx | [0,1,2,0,0,0,0,0,3,4,5,6,7,0,0,0] | - +------------------+------------------------+------------------------+ - |ffn_unpadding_idx | [0,1,2] | [0,1,2,3,4] | - +------------------+------------------------+------------------------+ - - Args: - - q_seq_len (Tensor): query sequence lengths. - - Returns: - - attn_padding_idx (Tensor or None): Indices mapping positions in - attention output sequence to original token positions, used for - padding attention output to fixed size. - - attn_unpadding_idx (Tensor or None): Indices mapping valid tokens - in padded attention output sequence to their original positions, - used for removing padding in attention output. - - ffn_padding_idx (Tensor or None): Indices mapping positions in MoE - output sequence to flattened valid token positions, used for padding - MoE output to fixed size. - - ffn_unpadding_idx (Tensor or None): Indices mapping valid tokens in - padded MoE output sequence to their original positions, used for - removing padding in MoE output. - """ - dp_size = self.mf_config.parallel_config.data_parallel - tp_size = self.mf_config.parallel_config.model_parallel - ep_size = self.mf_config.parallel_config.expert_parallel - if dp_size == 1 or (dp_size == ep_size and tp_size == 1): - return None, None, None, None - - tokens_len_per_dp = q_seq_len.sum().reshape(-1) - tokens_len_per_dp = get_dp_group().all_gather(tokens_len_per_dp) - tokens_len_per_dp = tokens_len_per_dp.asnumpy() - - # Simultaneously satisfying the requirement of being divisible by - # tensor_parallel_size and greater than the maximum q_seq_len in all - # DP domains. 
- padding_size = ((tokens_len_per_dp.max() + tp_size - 1) // tp_size * - tp_size) - - dp_rank_id = get_dp_group().rank_in_group - attn_padding_idx = None - attn_unpadding_idx = None - ffn_padding_idx = None - ffn_unpadding_idx = None - last_arange_index = 0 - - for dp_rank, tokens_length in enumerate(tokens_len_per_dp): - arange_data = np.arange(0, int(tokens_length), dtype=np.int32) - if dp_rank == dp_rank_id: - ffn_unpadding_idx = arange_data - pad = np.zeros(padding_size - arange_data.shape[0], - dtype=np.int32) - attn_padding_idx = np.concatenate((arange_data, pad), axis=0) - if dp_rank == 0: - attn_unpadding_idx = arange_data - last_arange_index = arange_data[-1] - pad = np.zeros(padding_size - attn_unpadding_idx.shape[0], - dtype=np.int32) - ffn_padding_idx = np.concatenate((attn_unpadding_idx, pad), - axis=0) - else: - attn_offset_idx = arange_data + padding_size * dp_rank - attn_unpadding_idx = np.concatenate( - (attn_unpadding_idx, attn_offset_idx), axis=0) - ffn_offset_idx = arange_data + last_arange_index + 1 - last_arange_index = ffn_offset_idx[-1] - pad = np.zeros(padding_size - ffn_offset_idx.shape[0], - dtype=np.int32) - ffn_padding_idx = np.concatenate( - (ffn_padding_idx, ffn_offset_idx, pad), axis=0) - return (ms.from_numpy(attn_padding_idx), - ms.from_numpy(attn_unpadding_idx), - ms.from_numpy(ffn_padding_idx), - ms.from_numpy(ffn_unpadding_idx)) - - def update_padding_index_to_inputs(self, model_inputs, q_seq_lens): - """ - Update the model input and add the related parameters of padding_index. - """ - if (self.network.model_comm_pgs is not default_model_comm_pgs - and getattr(self.network.model_comm_pgs, 'dp', None) - and getattr(self.network.model_comm_pgs, 'moe_ep', None)): - - (attn_padding_idx, attn_unpadding_idx, ffn_padding_idx, - ffn_unpadding_idx) = self._get_padding_index(q_seq_lens) - - model_inputs["attn_padding_idx"] = attn_padding_idx - model_inputs["attn_unpadding_idx"] = attn_unpadding_idx - model_inputs["ffn_padding_idx"] = ffn_padding_idx - model_inputs["ffn_unpadding_idx"] = ffn_unpadding_idx - - return model_inputs - - def prepare_inputs(self, input_ids, positions): - - attn_metadata = get_forward_context().attn_metadata - # 0.9.1 attn_metadata[layer_name], don't have layer_name here - # so we just take one by default - if isinstance(attn_metadata, dict) and '1' in attn_metadata: - attn_metadata = attn_metadata['1'] - if attn_metadata is None: - attn_metadata = self._dummy_attention_metadata( - input_ids, positions) - key_cache, value_cache = self.get_kvcache() - if not envs.VLLM_USE_V1: - # V0 - seq_lens = attn_metadata.seq_lens - max_query_len = attn_metadata.max_query_len - # When Mutli-Step is enabled with Chunked-Prefill, prefills and - # decodes are scheduled together. In the first step, all the - # prefills turn into decodes and max_query_len will be 1. 
- if self.is_multi_step_chunked_prefill and max_query_len == 1: - query_lens = [1] * len(seq_lens) - else: - query_lens = attn_metadata.query_lens - - seq_lens_np = np.array(seq_lens, dtype=np.int32) - query_lens_np = np.array(query_lens, dtype=np.int32) - kv_cache_lens = seq_lens_np - query_lens_np - if attn_metadata.num_decode_tokens == 0 and kv_cache_lens.max( - ) == 0: - is_prefill = True - else: - is_prefill = False - context_lens_tensor = ms.from_numpy(kv_cache_lens) - else: - # V1 - is_prefill = attn_metadata.max_context_lens == 0 - query_lens_np = attn_metadata.q_seq_lens_np - seq_lens_np = attn_metadata.seq_lens_np - context_lens_tensor = attn_metadata.context_lens - - q_seq_lens = ms.Tensor(query_lens_np, dtype=ms.int32) - position_ids = ms.Tensor(positions, dtype=ms.int32) - attention_mask = self.casual_mask.gen_attention_mask( - is_prefill, position_ids, query_lens_np, seq_lens_np) - - model_inputs = {} - model_inputs["input_ids"] = input_ids.astype(ms.int32) - model_inputs["batch_valid_length"] = ms.from_numpy(seq_lens_np) - model_inputs["block_tables"] = attn_metadata.block_tables - model_inputs["slot_mapping"] = attn_metadata.slot_mapping - model_inputs["positions"] = position_ids - model_inputs["q_seq_lens"] = q_seq_lens - model_inputs["attention_mask"] = attention_mask - model_inputs["key_cache"] = key_cache - model_inputs["value_cache"] = value_cache - model_inputs["context_lens_tensor"] = context_lens_tensor - model_inputs = (self.update_padding_index_to_inputs( - model_inputs, q_seq_lens)) - - return model_inputs, is_prefill - - def forward(self, - input_ids: Tensor, - positions: Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[Tensor] = None, - **kwargs) -> Union[Tensor, IntermediateTensors]: - model_inputs, is_prefill = self.prepare_inputs(input_ids, positions) - model_inputs = self.update_model_inputs(model_inputs, **kwargs) - if intermediate_tensors is not None: - model_inputs["hidden_states"] = \ - intermediate_tensors["hidden_states"] - - if is_prefill: - self.network.phase = "prefill" - if not self.set_flags or is_pynative(): - self.network.add_flags_custom_mcore(is_prefill=True) - hidden_states = self.network(**model_inputs) - self.network.phase = "increment" - if not self.set_flags or is_pynative(): - self.network.add_flags_custom_mcore(is_prefill=False) - self.set_flags = True - else: - hidden_states = self.network(**model_inputs) - - if not get_pp_group().is_last_rank: - return IntermediateTensors({"hidden_states": hidden_states}) - return hidden_states - - def _create_network(self): - # Initial network - if self.model_config.enforce_eager: - os.environ['ENFORCE_EAGER'] = 'True' - with no_init_parameters(): # Delay initialization - network: PreTrainedModel = AutoModel.from_config(self.mf_config) - network.model.return_hidden_states = True - if get_pp_group().is_last_rank: - return network, network.model.output_layer - return network, None - - def update_model_inputs(self, model_inputs, **kwargs): - return model_inputs - - def compute_logits( - self, - hidden_states: Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[Tensor]: - if sampling_metadata is not None: - selected_token_indices = sampling_metadata.selected_token_indices - if (selected_token_indices is not None - and selected_token_indices.numel() <= 0): - logits = ms.mint.zeros( - (0, self.model_config.hf_config.vocab_size), - dtype=self.model_config.hf_config.torch_dtype) - return logits - else: - hidden_states = hidden_states.reshape( - (-1, 
hidden_states.shape[-1])) - hidden_states = hidden_states.index_select( - 0, selected_token_indices) - if is_310p(): - # To get better performance in 310p, the lm head should run - # in O0 mode to avoid transdata, 910 keep the original process. - if self.lm_head_graph is None: - self.lm_head_graph = ms.jit(function=self.lm_head, - jit_level="O0") - logits = self.lm_head_graph(hidden_states) - else: - logits = self.lm_head(hidden_states) - logits = logits.view(-1, logits.shape[-1]) - return logits - - def sample( - self, - logits: Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - _pynative_executor.sync() - return next_tokens - - def load_weights(self, weights: Iterable[tuple[str, Tensor]]): - self.network.load_weights(self.mf_config.load_checkpoint) - return None +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from collections.abc import Iterable +from typing import Optional, Union + +import mindspore as ms +import numpy as np +from mindformers import AutoModel, PreTrainedModel +from mindformers.core.context import build_mf_context +from mindformers.parallel_core.process_group_config import ( + default_model_comm_pgs) +from mindformers.tools.utils import is_pynative +from mindspore import Tensor, mutable, ops +from mindspore.common.api import _pynative_executor +from mindspore.nn.utils import no_init_parameters +from vllm import envs +from vllm.config import VllmConfig, get_current_vllm_config +from vllm.distributed.parallel_state import get_dp_group, get_pp_group +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.models.interfaces import SupportsPP +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from vllm_mindspore.model_executor.models.attention_mask import ( + LowerTriangularMask, MLALowerTriangularMask) +from vllm_mindspore.model_executor.models.mf_models.config import gen_mf_config +from vllm_mindspore.model_executor.models.model_base import ( + AttentionWrapper, MLAAttentionWrapper, MsModelBase) +from vllm_mindspore.model_executor.models.utils import ( + make_empty_intermediate_tensors_factory) +from vllm_mindspore.utils import is_310p + +logger = init_logger(__name__) + + +class MindFormersForCausalLM(MsModelBase, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__(vllm_config=vllm_config, prefix=prefix) + self.set_flags = False + self.model_config = vllm_config.model_config + self.lm_head_graph = None + + mf_config = gen_mf_config(vllm_config) + mf_config.load_checkpoint = self.get_model_path() + mf_config.pretrained_model_dir = self.get_model_path() + self.mf_config = mf_config + + self.use_ringmla = 
vllm_config.model_config.quantization is not None \ + and vllm_config.parallel_config.tensor_parallel_size < 16 + self.mla_config = self.mf_config.get('model', None).get( + 'model_config', None).get('multi_latent_attention', False) + self.is_chunked = False + + build_mf_context(self.mf_config) + + self.network, self.lm_head = self._create_network() + self.casual_mask = self._create_mask() + + self._set_dynamic_inputs() + + self.sampler = get_sampler() + self.set_modules({"model": self.network}) + + num_layers = self.model_config.get_num_layers(self.parallel_config) + self.kv_caches = self._create_kv_caches(num_layers) + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + for i in range(num_layers): + compilation_config.static_forward_context[str( + i)] = self.kv_caches[i] + + self.make_empty_intermediate_tensors = \ + make_empty_intermediate_tensors_factory( + keys=["hidden_states"], + hidden_size=self.model_config.hf_config.hidden_size) + + self.cast = ops.Cast() + + def _set_dynamic_inputs(self): + self.network.set_dynamic_inputs() + dynamic_hidden_states = Tensor(shape=[None, None], + dtype=self.network.compute_dtype) + if get_pp_group().is_last_rank: + self.lm_head.set_inputs(dynamic_hidden_states) + + def _create_mask(self): + # Initial mask + mask_func = (MLALowerTriangularMask + if self.mla_config else LowerTriangularMask) + return mask_func(dtype=self.network.compute_dtype, + max_model_len=self.model_config.max_model_len) + + def _create_kv_caches(self, num_layers): + # Initial kv_caches + wrapper_func = (MLAAttentionWrapper + if self.mla_config else AttentionWrapper) + return [wrapper_func() for _ in range(num_layers)] + + def get_kvcache(self): + if not self.mla_config: + return super().get_kvcache() + + key_cache = [] + rope_cache = [] + forward_context = get_forward_context() + key_cache = [ + self.kv_caches[i].kv_cache[forward_context.virtual_engine][0] + for i in range(self.config.num_hidden_layers) + ] + if not self.use_ringmla: + return mutable(key_cache), None + else: + # deepseek mla op need key cache and rope cache + rope_cache = [ + self.kv_caches[i].kv_cache[forward_context.virtual_engine][1] + for i in range(self.config.num_hidden_layers) + ] + return mutable(key_cache), mutable(rope_cache) + + def _get_padding_index(self, q_seq_len): + """ + Calculate the padding index used in the mixed parallel scenario. + Case 1: When data_parallel_size equals 1, no padding operation + required, returns None. + Case 2: When data_parallel_size equals expert_parallel_size and + model_parallel equals 1, all_to_all communication is applied, + no padding operation required, returns None. + Case 3: In other DP enabled scenarios, calculate the corresponding + padding index based on the query sequence lengths processed + by each DP domain. + + e.g. 
DP2 TP4 MoE_EP2 + +------------------+------------------------+------------------------+ + | DP domain | DP0 | DP1 | + +------------------+------------------------+------------------------+ + | q_seq_len | 3 | 5 | + +------------------+------------------------+------------------------+ + | attn_padding_idx | [0,1,2,0,0,0,0,0] | [0,1,2,3,4,0,0,0] | + +------------------+------------------------+------------------------+ + |attn_unpadding_idx| [0,1,2,8,9,10,11,12] | + +------------------+------------------------+------------------------+ + | ffn_padding_idx | [0,1,2,0,0,0,0,0,3,4,5,6,7,0,0,0] | + +------------------+------------------------+------------------------+ + |ffn_unpadding_idx | [0,1,2] | [0,1,2,3,4] | + +------------------+------------------------+------------------------+ + + Args: + - q_seq_len (Tensor): query sequence lengths. + + Returns: + - attn_padding_idx (Tensor or None): Indices mapping positions in + attention output sequence to original token positions, used for + padding attention output to fixed size. + - attn_unpadding_idx (Tensor or None): Indices mapping valid tokens + in padded attention output sequence to their original positions, + used for removing padding in attention output. + - ffn_padding_idx (Tensor or None): Indices mapping positions in MoE + output sequence to flattened valid token positions, used for padding + MoE output to fixed size. + - ffn_unpadding_idx (Tensor or None): Indices mapping valid tokens in + padded MoE output sequence to their original positions, used for + removing padding in MoE output. + """ + dp_size = self.mf_config.parallel_config.data_parallel + tp_size = self.mf_config.parallel_config.model_parallel + ep_size = self.mf_config.parallel_config.expert_parallel + if dp_size == 1 or (dp_size == ep_size and tp_size == 1): + return None, None, None, None + + tokens_len_per_dp = q_seq_len.sum().reshape(-1) + tokens_len_per_dp = get_dp_group().all_gather(tokens_len_per_dp) + tokens_len_per_dp = tokens_len_per_dp.asnumpy() + + # Simultaneously satisfying the requirement of being divisible by + # tensor_parallel_size and greater than the maximum q_seq_len in all + # DP domains. 
+ padding_size = ((tokens_len_per_dp.max() + tp_size - 1) // tp_size * + tp_size) + + dp_rank_id = get_dp_group().rank_in_group + attn_padding_idx = None + attn_unpadding_idx = None + ffn_padding_idx = None + ffn_unpadding_idx = None + last_arange_index = 0 + + for dp_rank, tokens_length in enumerate(tokens_len_per_dp): + arange_data = np.arange(0, int(tokens_length), dtype=np.int32) + if dp_rank == dp_rank_id: + ffn_unpadding_idx = arange_data + pad = np.zeros(padding_size - arange_data.shape[0], + dtype=np.int32) + attn_padding_idx = np.concatenate((arange_data, pad), axis=0) + if dp_rank == 0: + attn_unpadding_idx = arange_data + last_arange_index = arange_data[-1] + pad = np.zeros(padding_size - attn_unpadding_idx.shape[0], + dtype=np.int32) + ffn_padding_idx = np.concatenate((attn_unpadding_idx, pad), + axis=0) + else: + attn_offset_idx = arange_data + padding_size * dp_rank + attn_unpadding_idx = np.concatenate( + (attn_unpadding_idx, attn_offset_idx), axis=0) + ffn_offset_idx = arange_data + last_arange_index + 1 + last_arange_index = ffn_offset_idx[-1] + pad = np.zeros(padding_size - ffn_offset_idx.shape[0], + dtype=np.int32) + ffn_padding_idx = np.concatenate( + (ffn_padding_idx, ffn_offset_idx, pad), axis=0) + return (ms.from_numpy(attn_padding_idx), + ms.from_numpy(attn_unpadding_idx), + ms.from_numpy(ffn_padding_idx), + ms.from_numpy(ffn_unpadding_idx)) + + def update_padding_index_to_inputs(self, model_inputs, q_seq_lens): + """ + Update the model input and add the related parameters of padding_index. + """ + if (self.network.model_comm_pgs is not default_model_comm_pgs + and getattr(self.network.model_comm_pgs, 'dp', None) + and getattr(self.network.model_comm_pgs, 'moe_ep', None)): + + (attn_padding_idx, attn_unpadding_idx, ffn_padding_idx, + ffn_unpadding_idx) = self._get_padding_index(q_seq_lens) + + model_inputs["attn_padding_idx"] = attn_padding_idx + model_inputs["attn_unpadding_idx"] = attn_unpadding_idx + model_inputs["ffn_padding_idx"] = ffn_padding_idx + model_inputs["ffn_unpadding_idx"] = ffn_unpadding_idx + + return model_inputs + + def prepare_inputs(self, input_ids, positions): + + attn_metadata = get_forward_context().attn_metadata + # 0.9.1 attn_metadata[layer_name], don't have layer_name here + # so we just take one by default + if isinstance(attn_metadata, dict) and '1' in attn_metadata: + attn_metadata = attn_metadata['1'] + if attn_metadata is None: + attn_metadata = self._dummy_attention_metadata( + input_ids, positions) + key_cache, value_cache = self.get_kvcache() + if not envs.VLLM_USE_V1: + # V0 + seq_lens = attn_metadata.seq_lens + max_query_len = attn_metadata.max_query_len + # When Mutli-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes and max_query_len will be 1. 
+            if self.is_multi_step_chunked_prefill and max_query_len == 1:
+                query_lens = [1] * len(seq_lens)
+            else:
+                query_lens = attn_metadata.query_lens
+
+            seq_lens_np = np.array(seq_lens, dtype=np.int32)
+            query_lens_np = np.array(query_lens, dtype=np.int32)
+            kv_cache_lens = seq_lens_np - query_lens_np
+            is_prefill = kv_cache_lens.max() == 0
+            is_chunked = attn_metadata.num_decode_tokens == 0 and \
+                kv_cache_lens.max() > 0
+            context_lens_tensor = ms.from_numpy(kv_cache_lens)
+        else:
+            # V1
+            is_prefill = attn_metadata.max_context_lens == 0
+            logger.debug(
+                "context_lens: %s, num_prompt_tokens: %s",
+                attn_metadata.context_lens, attn_metadata.num_prompt_tokens)
+            # Chunked prefill: some request has computed fewer tokens than
+            # its full prompt length.
+            is_chunked = not is_prefill and \
+                (attn_metadata.context_lens -
+                 attn_metadata.num_prompt_tokens).min() < 0
+            query_lens_np = attn_metadata.q_seq_lens_np
+            seq_lens_np = attn_metadata.seq_lens_np
+            context_lens_tensor = attn_metadata.context_lens
+
+        q_seq_lens = ms.Tensor(query_lens_np, dtype=ms.int32)
+        position_ids = ms.Tensor(positions, dtype=ms.int32)
+        attention_mask = self.casual_mask.gen_attention_mask(
+            is_prefill, position_ids, query_lens_np, seq_lens_np)
+
+        model_inputs = {}
+        model_inputs["input_ids"] = input_ids.astype(ms.int32)
+        model_inputs["batch_valid_length"] = ms.from_numpy(seq_lens_np)
+        model_inputs["block_tables"] = attn_metadata.block_tables
+        model_inputs["slot_mapping"] = attn_metadata.slot_mapping
+        model_inputs["positions"] = position_ids
+        model_inputs["q_seq_lens"] = q_seq_lens
+        model_inputs["attention_mask"] = attention_mask
+        model_inputs["key_cache"] = key_cache
+        model_inputs["value_cache"] = value_cache
+        model_inputs["context_lens_tensor"] = context_lens_tensor
+        model_inputs = (self.update_padding_index_to_inputs(
+            model_inputs, q_seq_lens))
+
+        return model_inputs, is_prefill, is_chunked
+
+    def forward(self,
+                input_ids: Tensor,
+                positions: Tensor,
+                intermediate_tensors: Optional[IntermediateTensors] = None,
+                inputs_embeds: Optional[Tensor] = None,
+                **kwargs) -> Union[Tensor, IntermediateTensors]:
+        model_inputs, is_prefill, is_chunked = self.prepare_inputs(
+            input_ids, positions)
+        model_inputs = self.update_model_inputs(model_inputs, **kwargs)
+        if intermediate_tensors is not None:
+            model_inputs["hidden_states"] = \
+                intermediate_tensors["hidden_states"]
+
+        if is_prefill:
+            # With RingMLA, a chunked prefill step runs in its own phase.
+            self.network.phase = ("prefill" if
+                                  (not self.use_ringmla or not is_chunked)
+                                  else "chunked")
+            if not self.set_flags or is_pynative():
+                self.network.add_flags_custom_mcore(is_prefill=True)
+                self.network.add_flags_chunked(is_chunked=is_chunked)
+            self.is_chunked |= (self.use_ringmla and is_chunked)
+            hidden_states = self.network(**model_inputs)
+            self.network.phase = "increment"
+            if not self.set_flags or is_pynative():
+                self.network.add_flags_custom_mcore(is_prefill=False)
+                self.set_flags = (True if not self.use_ringmla
+                                  else self.is_chunked)
+        else:
+            hidden_states = self.network(**model_inputs)
+
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({"hidden_states": hidden_states})
+        return hidden_states
+
+    def _create_network(self):
+        # Initial network
+        if self.model_config.enforce_eager:
+            os.environ['ENFORCE_EAGER'] = 'True'
+        with no_init_parameters():  # Delay initialization
+            network: PreTrainedModel = AutoModel.from_config(self.mf_config)
+        network.model.return_hidden_states = True
+        if get_pp_group().is_last_rank:
+            return network, network.model.output_layer
+        return network, None
+
+    def update_model_inputs(self, model_inputs, **kwargs):
+        return model_inputs
+
+    def 
compute_logits( + self, + hidden_states: Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[Tensor]: + if sampling_metadata is not None: + selected_token_indices = sampling_metadata.selected_token_indices + if (selected_token_indices is not None + and selected_token_indices.numel() <= 0): + logits = ms.mint.zeros( + (0, self.model_config.hf_config.vocab_size), + dtype=self.model_config.hf_config.torch_dtype) + return logits + else: + hidden_states = hidden_states.reshape( + (-1, hidden_states.shape[-1])) + hidden_states = hidden_states.index_select( + 0, selected_token_indices) + if is_310p(): + # To get better performance in 310p, the lm head should run + # in O0 mode to avoid transdata, 910 keep the original process. + if self.lm_head_graph is None: + self.lm_head_graph = ms.jit(function=self.lm_head, + jit_level="O0") + logits = self.lm_head_graph(hidden_states) + else: + logits = self.lm_head(hidden_states) + logits = logits.view(-1, logits.shape[-1]) + return logits + + def sample( + self, + logits: Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + _pynative_executor.sync() + return next_tokens + + def load_weights(self, weights: Iterable[tuple[str, Tensor]]): + self.network.load_weights(self.mf_config.load_checkpoint) + return None diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index 4d35d968449d602115bdab34f01f9fcde2fcc091..3249b4475b53f0ad49647d865a72e15da1a644c5 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -70,10 +70,9 @@ class MLAAttentionWrapper(AttentionWrapper): def __init__(self): super().__init__() vllm_config = get_current_vllm_config() - self.use_mla_op = bool( - vllm_config.additional_config - and vllm_config.additional_config.get('use_mla_op') == 1) - if not self.use_mla_op: + self.use_ringmla = vllm_config.model_config.quantization is not None \ + and vllm_config.parallel_config.tensor_parallel_size < 16 + if not self.use_ringmla: self.kv_cache = [ ( ms.mint.zeros( @@ -88,9 +87,9 @@ class MLAAttentionWrapper(AttentionWrapper): 'qk_rope_head_dim', 0) # k_shape, r_shape used for mla_op k_shape = [*(self.kv_shape[0:-1]), kv_lora_rank - ] if self.use_mla_op else None + ] if self.use_ringmla else None r_shape = [*(self.kv_shape[0:-1]), qk_rope_head_dim - ] if self.use_mla_op else None + ] if self.use_ringmla else None self.kv_cache = [ (ms.mint.zeros(k_shape, dtype=vllm_config.model_config.dtype), ms.mint.zeros(r_shape, dtype=vllm_config.model_config.dtype)) @@ -300,7 +299,8 @@ class MsModelBase: # To enforce prefill and decode are both complied in warmup process. # So set max_context_lens to 0 for prefill and 1 for decode. max_context_lens=0 if not self.set_flags else 1, - query_start_loc=None) + query_start_loc=None, + num_prompt_tokens=seq_lengths) def prepare_base_inputs(self, input_ids, positions): attn_metadata = get_forward_context().attn_metadata diff --git a/vllm_mindspore/v1/attention/backends/ms_attn.py b/vllm_mindspore/v1/attention/backends/ms_attn.py index f014381b393154cf81ee466439777df0510f8fb1..194681db730c3c0bebd541f4215578a478b583e5 100644 --- a/vllm_mindspore/v1/attention/backends/ms_attn.py +++ b/vllm_mindspore/v1/attention/backends/ms_attn.py @@ -137,6 +137,7 @@ class MsAttentionMetadata: #block_table: torch.Tensor slot_mapping: ms.Tensor + num_prompt_tokens: ms.Tensor # For cascade attention. 
#use_cascade: bool #common_prefix_len: int @@ -201,6 +202,7 @@ class MsAttentionMetadataBuilder: max_context_lens = self.runner.input_batch.num_computed_tokens_cpu[: num_reqs].max( ) + num_prompt_tokens = ms.from_numpy(self.runner.input_batch.num_prompt_tokens[:num_reqs]) slot_mapping = ms.from_numpy( self.block_table.slot_mapping_np[:num_actual_tokens]) seq_lens_np = self.runner.seq_lens_np[:num_reqs] @@ -220,7 +222,8 @@ class MsAttentionMetadataBuilder: max_seq_len=max_seq_len, context_lens=context_lens, max_context_lens=max_context_lens, - query_start_loc=query_start_loc) + query_start_loc=query_start_loc, + num_prompt_tokens=num_prompt_tokens) return attn_metadata diff --git a/vllm_mindspore/v1/worker/gpu_model_runner.py b/vllm_mindspore/v1/worker/gpu_model_runner.py index b37b39bc1f718473c16cc86a3cbcb28fc73fc70a..85d51006499c2a19c960faf42f9da6f9b467a9a0 100644 --- a/vllm_mindspore/v1/worker/gpu_model_runner.py +++ b/vllm_mindspore/v1/worker/gpu_model_runner.py @@ -232,11 +232,9 @@ def _allocate_nz_kv_cache_tensors(self, kv_cache_config): for i, group in enumerate(kv_cache_config.kv_cache_groups) for layer_name in group.layer_names } - - use_mla_op = bool( - self.vllm_config.additional_config - and self.vllm_config.additional_config.get('use_mla_op') == 1) - if use_mla_op: + use_ringmla = self.vllm_config.model_config.quantization is not None \ + and self.vllm_config.parallel_config.tensor_parallel_size < 16 + if use_ringmla: logger.error("For 310p, mla kv cache not supported") raise NotImplementedError @@ -290,9 +288,8 @@ def _allocate_kv_cache_tensors(self, kv_cache_config): dtype = kv_cache_spec.dtype coef = 1 if use_mla else 2 # Determine whether deepseek use mla op - use_mla_op = bool( - self.vllm_config.additional_config - and self.vllm_config.additional_config.get('use_mla_op') == 1) + use_ringmla = self.vllm_config.model_config.quantization is not None \ + and self.vllm_config.parallel_config.tensor_parallel_size < 16 kv_lora_rank = getattr(self.vllm_config.model_config.hf_text_config, 'kv_lora_rank', 0) qk_rope_head_dim = getattr(self.vllm_config.model_config.hf_text_config, @@ -317,7 +314,7 @@ def _allocate_kv_cache_tensors(self, kv_cache_config): """ raw_tensors.extend( [mint.zeros(raw_tensor_shape, dtype=target_dtype)] - if not use_mla_op else [ + if not use_ringmla else [ mint.zeros(int(raw_tensor_shape * kv_lora_rank / (kv_lora_rank + qk_rope_head_dim)), dtype=target_dtype), @@ -354,9 +351,8 @@ def _reshape_kv_cache_tensors( corresponding memory buffer for KV cache. 
""" # Determine whether deepseek use mla op - use_mla_op = bool( - self.vllm_config.additional_config - and self.vllm_config.additional_config.get('use_mla_op') == 1) + use_ringmla = self.vllm_config.model_config.quantization is not None \ + and self.vllm_config.parallel_config.tensor_parallel_size < 16 kv_lora_rank = getattr(self.vllm_config.model_config.hf_text_config, 'kv_lora_rank', 0) qk_rope_head_dim = getattr(self.vllm_config.model_config.hf_text_config, @@ -371,7 +367,7 @@ def _reshape_kv_cache_tensors( dtype_size = get_dtype_size(target_dtype) num_blocks = \ (raw_tensor[0].numel() - if not use_mla_op else + if not use_ringmla else # deepseek mla op need key cache and rope cache (raw_tensor[0].numel() + raw_tensor[1].numel())) * \ coef * dtype_size // kv_cache_spec.page_size_bytes @@ -400,7 +396,7 @@ def _reshape_kv_cache_tensors( kv_cache_layer = [] for idx, kv_cache_raw_tensor in enumerate( kv_cache_raw_tensors[layer_name]): - if use_mla_op: + if use_ringmla: # deepseek mla op need key cache and rope cache cache_shape = [ *(kv_cache_shape[1:-1]),