From 488926e6ab57c68034d2343696ddaaed88c62883 Mon Sep 17 00:00:00 2001 From: Yule100 Date: Thu, 9 Oct 2025 15:21:03 +0800 Subject: [PATCH] bugfix pin_mem_1009 --- mindformers/generation/text_generator.py | 13 ++++++++++++- mindformers/generation/utils.py | 21 +++++++++++++++++++++ mindformers/version_control.py | 5 +++++ research/llama3_1/llama.py | 6 ++++-- 4 files changed, 42 insertions(+), 3 deletions(-) diff --git a/mindformers/generation/text_generator.py b/mindformers/generation/text_generator.py index b525c8a2e..fc0f0f7cb 100644 --- a/mindformers/generation/text_generator.py +++ b/mindformers/generation/text_generator.py @@ -16,6 +16,7 @@ """ For text generation """ +import os import copy import time from typing import Optional, List, Union, Dict @@ -40,7 +41,7 @@ from mindformers.version_control import is_310p from mindformers.models.utils import format_type from mindformers.models.tokenization_utils import PreTrainedTokenizer from mindformers.generation.streamers import BaseStreamer -from mindformers.generation.utils import softmax_with_threads, topk, GenerateOutput, InferOutput +from mindformers.generation.utils import softmax_with_threads, topk, GenerateOutput, InferOutput, convert_pin from mindformers.modules.block_tables import BlockTables from mindformers.tools.logger import logger from mindformers.tools.utils import is_pynative @@ -463,6 +464,7 @@ class GenerationMixin: if need_flatten: model_inputs["input_ids"] = model_inputs["input_ids"].reshape(-1) model_inputs["batch_valid_length"] = Tensor.from_numpy(model_inputs["batch_valid_length"]) + model_inputs = self.convert_pin_model_inputs(model_inputs) # pylint: disable=E1102 res = self( **model_inputs, @@ -482,6 +484,7 @@ class GenerationMixin: self.slice_incremental_inputs(model_inputs, current_index, need_flatten) self.detailed_latency.start_predict_timer() model_inputs["batch_valid_length"] = Tensor.from_numpy(model_inputs["batch_valid_length"]) + model_inputs = 
self.convert_pin_model_inputs(model_inputs) # pylint: disable=E1102 res = self( **model_inputs, @@ -1843,3 +1846,11 @@ class GenerationMixin: response = tokenizer.decode(output_ids, skip_special_tokens=True) history.append({"role": assistant_role_name, "content": response}) return response, history + + def convert_pin_model_inputs(self, model_inputs): + if os.environ.get("EXPERIMENTAL_KERNEL_LAUNCH_GROUP", None): + model_inputs["input_ids"] = convert_pin(model_inputs["input_ids"]) + model_inputs["batch_valid_length"] = convert_pin(model_inputs["batch_valid_length"]) + model_inputs["block_tables"] = convert_pin(model_inputs["block_tables"]) + model_inputs["slot_mapping"] = convert_pin(model_inputs["slot_mapping"]) + return model_inputs diff --git a/mindformers/generation/utils.py b/mindformers/generation/utils.py index c844886d4..38aa88071 100644 --- a/mindformers/generation/utils.py +++ b/mindformers/generation/utils.py @@ -19,6 +19,10 @@ from threading import Thread from typing import Optional import numpy as np +from mindspore.common.tensor import Tensor + +from mindformers.version_control import check_pin_memory_interface_support + def log_softmax(x, axis=None): """ @@ -156,3 +160,20 @@ class InferOutput(UserDict): probs=self.probs, logits=self.logits ) + + +def convert_pin(input_tensor): + """Convert tensor to pinned memory if it's on CPU and not already pinned. 
+ + Args: + input_tensor: Input to convert; non-Tensor values are returned unchanged + + Returns: + Result of pin_memory() when pinning applies and the MindSpore version supports it, otherwise the input itself + """ + if not isinstance(input_tensor, Tensor): + return input_tensor + if check_pin_memory_interface_support() and input_tensor.device == "CPU" and not input_tensor.is_pinned(): + input_tensor_pinned = input_tensor.pin_memory() + return input_tensor_pinned + return input_tensor diff --git a/mindformers/version_control.py b/mindformers/version_control.py index 212424121..ef7f463b2 100644 --- a/mindformers/version_control.py +++ b/mindformers/version_control.py @@ -321,6 +321,11 @@ def check_safetensors_addition_param_support(): return is_version_ge(ms.__version__, "2.6.0") +def check_pin_memory_interface_support(): + """Check whether the installed MindSpore version supports the pin_memory() interface.""" + return is_version_ge(ms.__version__, "2.7.1") + + def set_ms_deterministic(deterministic): """Set deterministic computing through mindspore.""" logger.debug("The version of MindSpore is %s, " diff --git a/research/llama3_1/llama.py b/research/llama3_1/llama.py index 39d478b02..42dee2d2b 100644 --- a/research/llama3_1/llama.py +++ b/research/llama3_1/llama.py @@ -24,6 +24,7 @@ from multiprocessing.synchronize import Condition from safetensors import safe_open import numpy as np +import mindspore as ms import mindspore.common.dtype as mstype from mindspore import Tensor, ops, mint, mutable from mindspore.communication._comm_helper import _is_initialized as mindspore_comm_has_init @@ -38,6 +39,7 @@ from mindformers.tools.register.register import MindFormerModuleType, MindFormer from mindformers.tools.utils import get_predict_run_mode from mindformers.tools.logger import logger from mindformers.models.utils import jit +from mindformers.generation.utils import convert_pin from research.llama3_1.infer.layers import ColumnParallelLinear from research.llama3_1.infer.transformer import ParallelTransformer from research.llama3_1.utils import convert_model_config @@ 
-138,13 +140,13 @@ class ParallelLlamaForCausalLM(LlamaPreTrainedModel): slot_mapping, model_inputs,) position_ids = batch_valid_length.astype(np.int32) - 1 - model_inputs["position_ids"] = Tensor.from_numpy(position_ids.reshape((-1,))) + model_inputs["position_ids"] = ms.Tensor(position_ids, dtype=ms.int32).reshape(-1) if not prefill: q_seq_lens = np.ones(batch_valid_length.shape, dtype=np.int32).reshape(-1) else: q_seq_lens = batch_valid_length.astype(np.int32).reshape(-1) - model_inputs["q_seq_lens"] = Tensor.from_numpy(q_seq_lens) + model_inputs["q_seq_lens"] = convert_pin(Tensor.from_numpy(q_seq_lens)) model_inputs["attention_mask"] = self.model.casual_mask.gen_attention_mask(prefill) model_inputs["need_flatten"] = True -- Gitee