From 6df52fc560e090c72f8ed5e476cfd18457eb76dd Mon Sep 17 00:00:00 2001 From: liu lili Date: Fri, 24 Oct 2025 17:00:23 +0800 Subject: [PATCH] lll: add moveto batch_valid_lens to cpu cofnig --- .../inference/base_models/gpt/gpt_model.py | 41 ++++++++++++------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/mindformers/parallel_core/inference/base_models/gpt/gpt_model.py b/mindformers/parallel_core/inference/base_models/gpt/gpt_model.py index 647b73f25..71b6498f6 100644 --- a/mindformers/parallel_core/inference/base_models/gpt/gpt_model.py +++ b/mindformers/parallel_core/inference/base_models/gpt/gpt_model.py @@ -116,7 +116,7 @@ class GPTModel(nn.Cell): model_comm_pgs: Optional[ModelCommProcessGroups] = None, quant_config: Optional[QuantizationConfig] = None, ): - super(GPTModel, self).__init__() + super().__init__() self.check_support(fp16_lm_cross_entropy, rope_scaling) self.config = config self.quant_config = quant_config @@ -144,6 +144,7 @@ class GPTModel(nn.Cell): self.return_hidden_states = False # For serving, return hidden_states early and skip output_layer self.is_mtp_model = self.mtp_block_spec is not None self.is_pynative = is_pynative() + self.move_lens_to_cpu = True self.position_embedding_type = position_embedding_type if position_embedding_type != "none" else \ getattr(self.config, 'position_embedding_type') @@ -261,20 +262,29 @@ class GPTModel(nn.Cell): else: rotary_pos_cos, rotary_pos_sin = \ self.rotary_pos_emb.get_cos_sin_for_decode(positions) - if self.is_pynative: # ops.move_to not support pynative mode - batch_valid_length_cpu = batch_valid_length.move_to("CPU") - q_seq_lens_cpu = q_seq_lens.move_to("CPU") - context_lens_tensor_cpu = context_lens_tensor.move_to("CPU") + + # current aclgraph not support moveto in graph + # add moveto config for model to control + if self.move_lens_to_cpu: + if self.is_pynative: # ops.move_to not support pynative mode + batch_valid_length_cpu = batch_valid_length.move_to("CPU") + q_seq_lens_cpu = q_seq_lens.move_to("CPU") + context_lens_tensor_cpu = context_lens_tensor.move_to("CPU") + else: + batch_valid_length_cpu = ops.move_to(batch_valid_length, "CPU") + q_seq_lens_cpu = ops.move_to(q_seq_lens, "CPU") + context_lens_tensor_cpu = ops.move_to(context_lens_tensor, "CPU") + + # embedding contains the allreduce ops. Adding the depend ops ensures that the move_to ops is + # launched before the allreduce, reducing the sync waiting time when the move_to ops launched. + input_ids = self.depend(input_ids, q_seq_lens_cpu) + input_ids = self.depend(input_ids, batch_valid_length_cpu) + input_ids = self.depend(input_ids, context_lens_tensor_cpu) else: - batch_valid_length_cpu = ops.move_to(batch_valid_length, "CPU") - q_seq_lens_cpu = ops.move_to(q_seq_lens, "CPU") - context_lens_tensor_cpu = ops.move_to(context_lens_tensor, "CPU") - - # embedding contains the allreduce ops. Adding the depend ops ensures that the move_to ops is - # launched before the allreduce, reducing the sync waiting time when the move_to ops launched. - input_ids = self.depend(input_ids, q_seq_lens_cpu) - input_ids = self.depend(input_ids, batch_valid_length_cpu) - input_ids = self.depend(input_ids, context_lens_tensor_cpu) + batch_valid_length_cpu = None + q_seq_lens_cpu = None + context_lens_tensor_cpu = None + # Decoder embedding. if self.pre_process: decoder_input = self.cast(self.embedding(input_ids), self.compute_dtype) @@ -317,7 +327,7 @@ class GPTModel(nn.Cell): return logits def get_params_dict(self): - params_dict = dict() + params_dict = {} for _, module in self.modules_dict.items(): module_params = module.parameters_dict() for param_name, param in module_params.items(): @@ -401,6 +411,7 @@ class GPTModel(nn.Cell): if name in params_dict: if '.weight1' in name or '.weight2' in name: num_experts = self.config.num_moe_experts + weight = {} if '.weight1' in name: weight = loaded_weight[:].reshape(num_experts, self.config.hidden_size, -1) if '.weight2' in name: -- Gitee