diff --git a/mindformers/generation/text_generator.py b/mindformers/generation/text_generator.py
index b525c8a2e2ab33ee1db653f1acbe51b5fda01854..fc0f0f7cb47a6ea6adda88a879dc3b1fdc433dbb 100644
--- a/mindformers/generation/text_generator.py
+++ b/mindformers/generation/text_generator.py
@@ -16,6 +16,7 @@
 """
 For text generation
 """
+import os
 import copy
 import time
 from typing import Optional, List, Union, Dict
@@ -40,7 +41,7 @@ from mindformers.version_control import is_310p
 from mindformers.models.utils import format_type
 from mindformers.models.tokenization_utils import PreTrainedTokenizer
 from mindformers.generation.streamers import BaseStreamer
-from mindformers.generation.utils import softmax_with_threads, topk, GenerateOutput, InferOutput
+from mindformers.generation.utils import softmax_with_threads, topk, GenerateOutput, InferOutput, convert_pin
 from mindformers.modules.block_tables import BlockTables
 from mindformers.tools.logger import logger
 from mindformers.tools.utils import is_pynative
@@ -463,6 +464,7 @@ class GenerationMixin:
             if need_flatten:
                 model_inputs["input_ids"] = model_inputs["input_ids"].reshape(-1)
             model_inputs["batch_valid_length"] = Tensor.from_numpy(model_inputs["batch_valid_length"])
+            model_inputs = self.convert_pin_model_inputs(model_inputs)
             # pylint: disable=E1102
             res = self(
                 **model_inputs,
@@ -482,6 +484,7 @@ class GenerationMixin:
             self.slice_incremental_inputs(model_inputs, current_index, need_flatten)
             self.detailed_latency.start_predict_timer()
             model_inputs["batch_valid_length"] = Tensor.from_numpy(model_inputs["batch_valid_length"])
+            model_inputs = self.convert_pin_model_inputs(model_inputs)
             # pylint: disable=E1102
             res = self(
                 **model_inputs,
@@ -1843,3 +1846,11 @@ class GenerationMixin:
         response = tokenizer.decode(output_ids, skip_special_tokens=True)
         history.append({"role": assistant_role_name, "content": response})
         return response, history
+
+    def convert_pin_model_inputs(self, model_inputs):
+        if os.environ.get("EXPERIMENTAL_KERNEL_LAUNCH_GROUP", None):
+            model_inputs["input_ids"] = convert_pin(model_inputs["input_ids"])
+            model_inputs["batch_valid_length"] = convert_pin(model_inputs["batch_valid_length"])
+            model_inputs["block_tables"] = convert_pin(model_inputs["block_tables"])
+            model_inputs["slot_mapping"] = convert_pin(model_inputs["slot_mapping"])
+        return model_inputs
diff --git a/mindformers/generation/utils.py b/mindformers/generation/utils.py
index c844886d44d711d999c087bd4eb418657a628132..38aa88071086550d5073271fb37577c55bfb4921 100644
--- a/mindformers/generation/utils.py
+++ b/mindformers/generation/utils.py
@@ -19,6 +19,10 @@ from threading import Thread
 from typing import Optional
 import numpy as np
 
+from mindspore.common.tensor import Tensor
+
+from mindformers.version_control import check_pin_memory_interface_support
+
 
 def log_softmax(x, axis=None):
     """
@@ -156,3 +160,20 @@ class InferOutput(UserDict):
             probs=self.probs,
             logits=self.logits
         )
+
+
+def convert_pin(input_tensor):
+    """Convert tensor to pinned memory if it's on CPU and not already pinned.
+
+    Args:
+        input_tensor: Input tensor to convert
+
+    Returns:
+        Tensor with pinned memory if applicable, otherwise original tensor
+    """
+    if not isinstance(input_tensor, Tensor):
+        return input_tensor
+    if input_tensor.device == "CPU" and not input_tensor.is_pinned() and check_pin_memory_interface_support():
+        input_tensor_pinned = input_tensor.pin_memory()
+        return input_tensor_pinned
+    return input_tensor
diff --git a/mindformers/version_control.py b/mindformers/version_control.py
index 212424121abff8aea871a45ac422574775931905..ef7f463b268cf787b6be94ef3e25278a60892cd4 100644
--- a/mindformers/version_control.py
+++ b/mindformers/version_control.py
@@ -321,6 +321,11 @@ def check_safetensors_addition_param_support():
     return is_version_ge(ms.__version__, "2.6.0")
 
 
+def check_pin_memory_interface_support():
+    """check mindspore version if support pin_memory() interface."""
+    return is_version_ge(ms.__version__, "2.7.1")
+
+
 def set_ms_deterministic(deterministic):
     """Set deterministic computing through mindspore."""
     logger.debug("The version of MindSpore is %s, "
diff --git a/research/llama3_1/llama.py b/research/llama3_1/llama.py
index 39d478b02ee077657f5598e6dd4491e6217e1f9a..42dee2d2bf210dd09073a04f6e9fc0843a3cb9ea 100644
--- a/research/llama3_1/llama.py
+++ b/research/llama3_1/llama.py
@@ -24,6 +24,7 @@ from multiprocessing.synchronize import Condition
 from safetensors import safe_open
 import numpy as np
 
+import mindspore as ms
 import mindspore.common.dtype as mstype
 from mindspore import Tensor, ops, mint, mutable
 from mindspore.communication._comm_helper import _is_initialized as mindspore_comm_has_init
@@ -38,6 +39,7 @@ from mindformers.tools.register.register import MindFormerModuleType, MindFormer
 from mindformers.tools.utils import get_predict_run_mode
 from mindformers.tools.logger import logger
 from mindformers.models.utils import jit
+from mindformers.generation.utils import convert_pin
 from research.llama3_1.infer.layers import ColumnParallelLinear
 from research.llama3_1.infer.transformer import ParallelTransformer
 from research.llama3_1.utils import convert_model_config
@@ -138,13 +140,13 @@ class ParallelLlamaForCausalLM(LlamaPreTrainedModel):
                                                          slot_mapping, model_inputs,)
 
         position_ids = batch_valid_length.astype(np.int32) - 1
-        model_inputs["position_ids"] = Tensor.from_numpy(position_ids.reshape((-1,)))
+        model_inputs["position_ids"] = ms.Tensor(position_ids, dtype=ms.int32).reshape(-1)
 
         if not prefill:
             q_seq_lens = np.ones(batch_valid_length.shape, dtype=np.int32).reshape(-1)
         else:
             q_seq_lens = batch_valid_length.astype(np.int32).reshape(-1)
-        model_inputs["q_seq_lens"] = Tensor.from_numpy(q_seq_lens)
+        model_inputs["q_seq_lens"] = convert_pin(Tensor.from_numpy(q_seq_lens))
 
         model_inputs["attention_mask"] = self.model.casual_mask.gen_attention_mask(prefill)
         model_inputs["need_flatten"] = True
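
Usage sketch (not part of the patch): the snippet below illustrates how the new convert_pin helper is expected to behave, assuming this patch is applied and a MindSpore build new enough for Tensor.pin_memory()/Tensor.is_pinned() (>= 2.7.1 per check_pin_memory_interface_support). On older versions, non-CPU tensors, or non-Tensor inputs it simply returns its argument unchanged. In GenerationMixin the conversion is applied to input_ids, batch_valid_length, block_tables and slot_mapping only when the EXPERIMENTAL_KERNEL_LAUNCH_GROUP environment variable is set, so that host-side inputs sit in pinned memory, which typically speeds up host-to-device transfer.

import numpy as np
from mindspore import Tensor

from mindformers.generation.utils import convert_pin  # helper added by this patch

# A host-side input, as produced by Tensor.from_numpy() in the generation path.
batch_valid_length = Tensor.from_numpy(np.array([16, 16], dtype=np.int32))

# Returns a pinned copy when pin_memory() is supported and the tensor is an
# unpinned CPU tensor; otherwise the original tensor is returned as-is.
batch_valid_length = convert_pin(batch_valid_length)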