From 00cc2d118deb8de1ad95f60b2b8f4dd110bb33ab Mon Sep 17 00:00:00 2001
From: zhang_xu_hao1230
Date: Sun, 28 Sep 2025 17:43:36 +0800
Subject: [PATCH] add convert_pin

Route every CPU-resident model input tensor through the new convert_pin
helper before it is handed to the network: convert_pin re-backs an
unpinned CPU tensor with pinned (page-locked) host memory and passes
device tensors, already-pinned tensors, and non-Tensor inputs through
unchanged.
---
 .../models/mf_models/mindformers.py           | 25 +++++++++++--------
 .../model_executor/models/model_base.py       | 18 +++++++------
 vllm_mindspore/model_executor/models/utils.py | 19 +++++++++++++-
 3 files changed, 42 insertions(+), 20 deletions(-)

diff --git a/vllm_mindspore/model_executor/models/mf_models/mindformers.py b/vllm_mindspore/model_executor/models/mf_models/mindformers.py
index 9b15a2a8..359c4a93 100644
--- a/vllm_mindspore/model_executor/models/mf_models/mindformers.py
+++ b/vllm_mindspore/model_executor/models/mf_models/mindformers.py
@@ -39,7 +39,7 @@ from vllm_mindspore.model_executor.models.mf_models.config import gen_mf_config
 from vllm_mindspore.model_executor.models.model_base import (
     AttentionWrapper, MLAAttentionWrapper, MsModelBase)
 from vllm_mindspore.model_executor.models.utils import (
-    is_use_ringmla, make_empty_intermediate_tensors_factory)
+    convert_pin, is_use_ringmla, make_empty_intermediate_tensors_factory)
 from vllm_mindspore.utils import is_310p
 
 logger = init_logger(__name__)
@@ -314,16 +314,18 @@ class MindFormersForCausalLM(MsModelBase, SupportsPP):
             is_prefill, position_ids, query_lens_np, seq_lens_np)
 
         model_inputs = {}
-        model_inputs["input_ids"] = input_ids.astype(ms.int32) * 1
-        model_inputs["batch_valid_length"] = ms.from_numpy(seq_lens_np)
-        model_inputs["block_tables"] = attn_metadata.block_tables * 1
-        model_inputs["slot_mapping"] = attn_metadata.slot_mapping
-        model_inputs["positions"] = position_ids
-        model_inputs["q_seq_lens"] = q_seq_lens
-        model_inputs["attention_mask"] = attention_mask
+        model_inputs["input_ids"] = convert_pin(input_ids.astype(ms.int32) * 1)
+        model_inputs["batch_valid_length"] = convert_pin(
+            ms.from_numpy(seq_lens_np))
+        model_inputs["block_tables"] = convert_pin(attn_metadata.block_tables *
+                                                   1)
+        model_inputs["slot_mapping"] = convert_pin(attn_metadata.slot_mapping)
+        model_inputs["positions"] = convert_pin(position_ids)
+        model_inputs["q_seq_lens"] = convert_pin(q_seq_lens)
+        model_inputs["attention_mask"] = convert_pin(attention_mask)
         model_inputs["key_cache"] = key_cache
         model_inputs["value_cache"] = value_cache
-        model_inputs["context_lens_tensor"] = context_lens_tensor
+        model_inputs["context_lens_tensor"] = convert_pin(context_lens_tensor)
 
         model_inputs = (self.update_padding_index_to_inputs(
             model_inputs, q_seq_lens))
@@ -353,10 +355,11 @@ class MindFormersForCausalLM(MsModelBase, SupportsPP):
         model_inputs = self.update_model_inputs(model_inputs, **kwargs)
         if intermediate_tensors is not None:
             model_inputs["hidden_states"] = \
-                intermediate_tensors["hidden_states"]
+                convert_pin(intermediate_tensors["hidden_states"])
         elif kwargs.get("previous_hidden_states") is not None:
             # used for deepseek-mtp
-            model_inputs["hidden_states"] = kwargs["previous_hidden_states"]
+            model_inputs["hidden_states"] = convert_pin(
+                kwargs["previous_hidden_states"])
 
         if is_prefill or is_ringmla_chunked:
             self.network.phase = \
diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py
index 47cfc601..99ec9912 100644
--- a/vllm_mindspore/model_executor/models/model_base.py
+++ b/vllm_mindspore/model_executor/models/model_base.py
@@ -32,7 +32,8 @@ from vllm.sequence import IntermediateTensors
 
 from vllm_mindspore.model_executor.models.attention_mask import (
     LowerTriangularMask)
-from vllm_mindspore.model_executor.models.utils import is_use_ringmla
+from vllm_mindspore.model_executor.models.utils import (convert_pin,
+                                                        is_use_ringmla)
 from vllm_mindspore.model_executor.utils import set_model_context
 from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE, create_kv_cache
 from vllm_mindspore.v1.attention.backends.ms_attn import MsAttentionMetadata
@@ -352,13 +353,14 @@ class MsModelBase:
             is_prefill, position_ids, query_lens_np, seq_lens_np)
 
         model_inputs = {}
-        model_inputs["input_ids"] = input_ids
-        model_inputs["batch_valid_length"] = ms.from_numpy(seq_lens_np)
-        model_inputs["block_tables"] = attn_metadata.block_tables
-        model_inputs["slot_mapping"] = attn_metadata.slot_mapping
-        model_inputs["position_ids"] = position_ids
-        model_inputs["q_seq_lens"] = q_seq_lens
-        model_inputs["attention_mask"] = attention_mask
+        model_inputs["input_ids"] = convert_pin(input_ids)
+        model_inputs["batch_valid_length"] = convert_pin(
+            ms.from_numpy(seq_lens_np))
+        model_inputs["block_tables"] = convert_pin(attn_metadata.block_tables)
+        model_inputs["slot_mapping"] = convert_pin(attn_metadata.slot_mapping)
+        model_inputs["position_ids"] = convert_pin(position_ids)
+        model_inputs["q_seq_lens"] = convert_pin(q_seq_lens)
+        model_inputs["attention_mask"] = convert_pin(attention_mask)
         model_inputs["key_cache"] = key_cache
         model_inputs["value_cache"] = value_cache
 
diff --git a/vllm_mindspore/model_executor/models/utils.py b/vllm_mindspore/model_executor/models/utils.py
index 5d9e56d3..2a4b3e51 100644
--- a/vllm_mindspore/model_executor/models/utils.py
+++ b/vllm_mindspore/model_executor/models/utils.py
@@ -23,7 +23,7 @@ from dataclasses import dataclass, field
 from typing import Optional, Union
 
 import mindspore as ms
-from mindspore import mint, ops
+from mindspore import Tensor, mint, ops
 from vllm import envs
 from vllm.sequence import IntermediateTensors
 
@@ -34,6 +34,23 @@ WeightsMapping = Mapping[str, Optional[str]]
 """If a key maps to a value of `None`, the corresponding weight is ignored."""
 
 
+def convert_pin(input_tensor):
+    """Convert a tensor to pinned memory if it is on CPU and not yet pinned.
+
+    Args:
+        input_tensor: Input tensor to convert.
+
+    Returns:
+        Tensor backed by pinned memory if applicable, otherwise the input.
+    """
+    if not isinstance(input_tensor, Tensor):
+        return input_tensor
+    if input_tensor._ms_device == "CPU" and not input_tensor.is_pinned():
+        input_pinned = input_tensor.pin_memory()
+        return input_pinned
+    return input_tensor
+
+
 @dataclass
 class WeightsMapper:
     """Maps the name of each weight if they match the following patterns."""
-- 
Gitee
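
A minimal usage sketch of the convert_pin helper added by this patch
(illustrative only, not part of the patch; the seq_lens_np array below is a
hypothetical stand-in, while ms.from_numpy(), Tensor.is_pinned(),
Tensor.pin_memory(), and the private _ms_device attribute are the APIs the
patch itself relies on):

    import numpy as np
    import mindspore as ms

    from vllm_mindspore.model_executor.models.utils import convert_pin

    # A host-side length array, standing in for the sequence-length values
    # that the patch converts with ms.from_numpy().
    seq_lens_np = np.array([4, 7], dtype=np.int32)

    # ms.from_numpy() yields a CPU-resident tensor; convert_pin() re-backs
    # it with pinned (page-locked) host memory. Device tensors, tensors
    # that are already pinned, and non-Tensor inputs pass through unchanged.
    batch_valid_length = convert_pin(ms.from_numpy(seq_lens_np))

Pinned host memory is the usual prerequisite for asynchronous host-to-device
copies, which is presumably why every CPU-side model input is routed through
convert_pin before being dispatched to the network.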