From 6184efd1b9fec3ebca6e868e5c8ee1528f9f3a34 Mon Sep 17 00:00:00 2001
From: Hsshuai
Date: Wed, 5 Nov 2025 20:37:32 +0800
Subject: [PATCH] add support for predicting with train model

---
 mindformers/core/context/build_context.py    |  2 +-
 mindformers/core/context/validators.py       |  3 +-
 mindformers/generation/text_generator.py     | 33 +++++++++++--------
 .../models/qwen3/modeling_qwen3_train.py     |  2 +-
 .../qwen3_moe/modeling_qwen3_moe_train.py    |  2 +-
 .../base_models/gpt/gpt_model.py             | 14 +++++---
 .../parallel_core/utils/model_mixin.py       |  7 ++++
 mindformers/tools/register/template.py       |  4 +--
 mindformers/trainer/base_trainer.py          |  2 +-
 run_mindformer.py                            |  9 ++---
 10 files changed, 49 insertions(+), 29 deletions(-)

diff --git a/mindformers/core/context/build_context.py b/mindformers/core/context/build_context.py
index 89b0371f2..1a0b39790 100644
--- a/mindformers/core/context/build_context.py
+++ b/mindformers/core/context/build_context.py
@@ -306,7 +306,7 @@ class MFContextOperator(MFContextConfig):
         lccl_deterministic = os.getenv('LCCL_DETERMINISTIC')
         run_mode = getattr(self, 'run_mode') if hasattr(self, 'run_mode') else None
         if run_mode in (
-                RunMode.TRAIN.value, RunMode.FINETUNE.value
+                RunMode.TRAIN.value, RunMode.FINETUNE.value, RunMode.PREDICT_WITH_TRAIN_MODEL.value
         ) and self.train_precision_sync is not None:
             _, _ = self._call_ms_deterministic(self.train_precision_sync)
diff --git a/mindformers/core/context/validators.py b/mindformers/core/context/validators.py
index 891a4679c..d0ae956e5 100644
--- a/mindformers/core/context/validators.py
+++ b/mindformers/core/context/validators.py
@@ -25,6 +25,7 @@ class RunMode(Enum):
     TRAIN = 'train'
     PREDICT = 'predict'
     FINETUNE = 'finetune'
+    PREDICT_WITH_TRAIN_MODEL = 'predict_with_train_model'
     EVAL = 'eval'


@@ -41,7 +42,7 @@ def validate_mf_ctx_run_mode(config):
 def validate_ms_ctx_mode(config):
     """Validate mode in context."""
     mode = config.get_value('context.mode', 0)
-    if mode not in MODE.keys():
+    if mode not in MODE:
         raise ValueError(
             f'Invalid mode. Expected one of {MODE.keys()}, got {mode}'
         )
diff --git a/mindformers/generation/text_generator.py b/mindformers/generation/text_generator.py
index fc0f0f7cb..85abf848b 100644
--- a/mindformers/generation/text_generator.py
+++ b/mindformers/generation/text_generator.py
@@ -22,6 +22,12 @@ import time
 from typing import Optional, List, Union, Dict

 import numpy as np
+
+try:
+    import ms_custom_ops
+except ImportError:
+    ms_custom_ops = None
+
 import mindspore as ms
 from mindspore import mint, mutable, ops
 from mindspore.ops import functional as F
@@ -29,6 +35,11 @@ from mindspore.ops import operations as P
 import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor

+try:
+    from mindspore.common.api import _pynative_executor
+except ImportError:
+    _pynative_executor = None
+
 from mindformers.generation.beam_search import BeamSearchScorer
 from mindformers.generation.generation_config import GenerationConfig
 from mindformers.generation.logits_process import (LogitNormalization, LogitsProcessorList,
@@ -127,8 +138,7 @@ class GenerationMixin:
         tp_group_size = get_tp_world_size()

         # When kv heads < tp size, will replicate kv heads
-        if num_query_groups < tp_group_size:
-            num_query_groups = tp_group_size
+        num_query_groups = max(num_query_groups, tp_group_size)

         if hasattr(tansformer_config, 'kv_channels'):
             hidden_size_per_attention_head = getattr(tansformer_config, 'kv_channels')
@@ -166,7 +176,7 @@ class GenerationMixin:
         fa3_quant_layer = self.model.quant_config.fa3_quant_layer if self.model.quant_config else set()
         compute_dtype = tansformer_config.compute_dtype
         if fa3_quant and not use_ringmla:
-            raise ValueError(f'For fa3_quant, it is necessary to set use_ringmla to True.')
+            raise ValueError('For fa3_quant, it is necessary to set use_ringmla to True.')
         key_cache = []
         value_cache = []
         if is_310p():
@@ -175,7 +185,7 @@ class GenerationMixin:
             kv_cache_shape = (num_blocks, block_size, merge_dim)
         for num_layer in range(tansformer_config.num_layers):
             if fa3_quant:
-                import ms_custom_ops
                 k_cache_dtype = mstype.int8 if num_layer in fa3_quant_layer else compute_dtype
                 k_cache = mint.zeros(kv_cache_shape[:-2] + (tansformer_config.kv_lora_rank,), dtype=k_cache_dtype)
                 v_cache = mint.zeros(kv_cache_shape[:-2] + (tansformer_config.qk_pos_emb_head_dim,),
@@ -629,8 +638,7 @@ class GenerationMixin:
         encoder_output = None
         encoder_mask = None
         if self.config.is_encoder_decoder:
-            if target_length > self.config.max_decode_length:
-                target_length = self.config.max_decode_length
+            target_length = min(target_length, self.config.max_decode_length)
             logger.debug("target_length is: %s", target_length)

             # When do encoder and decoder prediction, the encoder can be cached
@@ -943,7 +951,7 @@ class GenerationMixin:
                 "`streamer` cannot be used with beam search yet. Make sure that `num_beams` is set to 1."
             )

-        if not use_legacy:
+        if not use_legacy and not hasattr(self, "is_train_model"):
             self._set_block_mgr(batch_size, self.config.seq_length)
             self._set_kv_cache()
             self._set_lower_triangle_mask()
@@ -1017,8 +1025,7 @@ class GenerationMixin:
         encoder_mask = None
         target_mask = None
         if self.config.is_encoder_decoder:
-            if generation_config.max_length > self.config.max_decode_length:
-                generation_config.max_length = self.config.max_decode_length
+            generation_config.max_length = min(generation_config.max_length, self.config.max_decode_length)
             logger.debug("max decode length is: %s", generation_config.max_length)

             # When do encoder and decoder prediction, the encoder can be cached
@@ -1058,7 +1065,7 @@ class GenerationMixin:
             self.detailed_latency.start_preprocess_timer()
             block_tables = None
             slot_mapping = None
-            if not use_legacy or generation_config.use_past:
+            if (not use_legacy or generation_config.use_past) and not hasattr(self, "is_train_model"):
                 if prefill:
                     if (use_legacy and self.is_pynative and self.config.is_dynamic):
                         max_input_length = len(origin_inputs[0])
@@ -1071,7 +1078,7 @@ class GenerationMixin:
                     block_tables, slot_mapping = self.block_mgr.assemble_pa_inc_inputs(valid_length_each_example,
                                                                                         is_finished)
             self.profile.start_profiling(valid_length_each_example[0] - input_ids_length)
-            if use_legacy:
+            if use_legacy or (hasattr(self, "is_train_model") and self.is_train_model):
                 infer_output, is_finished = self.infer(input_ids=input_ids,
                                                        valid_length_each_example=valid_length_each_example,
                                                        generation_config=generation_config,
@@ -1382,7 +1389,7 @@ class GenerationMixin:
                             prefill: bool = None,
                             **model_kwargs):
         """prepare inputs for mcore"""
-        model_inputs = dict()
+        model_inputs = {}
         seq_lens = np.array(valid_length_each_example)

         q_seq_lens = model_kwargs.get("q_seq_lens", None)
@@ -1668,7 +1675,6 @@ class GenerationMixin:
             is_finished, whether the sequence has completed its generation task.
         """
         if not self.is_pynative:
-            from mindspore.common.api import _pynative_executor
             _pynative_executor.set_async_for_graph(True)
         batch_size = input_ids.shape[0]
         target_list = [[] for _ in range(batch_size)]
@@ -1762,7 +1768,6 @@ class GenerationMixin:
         elif generation_config.generation_mode == GenerationMode.BEAM_SEARCH:
             raise ValueError("sampler method doesn't support BEAM_SEARCH. ")
         if not self.is_pynative:
-            from mindspore.common.api import _pynative_executor
             _pynative_executor.sync()
             _pynative_executor.set_async_for_graph(False)
         return target_list, next_probs_cache, next_logits_cache, is_finished
diff --git a/mindformers/models/qwen3/modeling_qwen3_train.py b/mindformers/models/qwen3/modeling_qwen3_train.py
index c6ee82060..98d83b092 100644
--- a/mindformers/models/qwen3/modeling_qwen3_train.py
+++ b/mindformers/models/qwen3/modeling_qwen3_train.py
@@ -24,7 +24,7 @@ from mindformers.models.qwen3.utils import Qwen3PreTrainedModel
 from .configuration_qwen3 import Qwen3Config


-class TrainingQwen3ForCausalLM(Qwen3PreTrainedModel, TrainModelMixin):
+class TrainingQwen3ForCausalLM(TrainModelMixin, Qwen3PreTrainedModel):
     """
     Provide qwen2 model infer through network.
diff --git a/mindformers/models/qwen3_moe/modeling_qwen3_moe_train.py b/mindformers/models/qwen3_moe/modeling_qwen3_moe_train.py
index 09d2e2a7a..331cffff8 100644
--- a/mindformers/models/qwen3_moe/modeling_qwen3_moe_train.py
+++ b/mindformers/models/qwen3_moe/modeling_qwen3_moe_train.py
@@ -29,7 +29,7 @@ from mindformers.parallel_core.utils.model_mixin import TrainModelMixin


 @MindFormerRegister.register(MindFormerModuleType.MODELS)
-class TrainingQwen3MoeForCausalLM(Qwen3MoePreTrainedModel, TrainModelMixin):
+class TrainingQwen3MoeForCausalLM(TrainModelMixin, Qwen3MoePreTrainedModel):
     """
     Provide qwen3_moe model infer through network.
diff --git a/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py b/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py
index 342d901d6..297416075 100644
--- a/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py
+++ b/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py
@@ -139,10 +139,16 @@ class PreprocessLabelsAndMasks(nn.Cell):
         """
         if loss_mask is None:
             loss_mask = self.cast(self.not_equal(input_ids, self.pad_token_id), dtype.float32)
-        label_mask = self.cast(self.not_equal(labels, self.ignore_token_id), dtype.float32)
-        loss_mask = self.mul(loss_mask, label_mask)
-        local_loss_mask = self.morphed_reshape_labels_and_masks(loss_mask)
-        local_labels = self.morphed_reshape_labels_and_masks(labels)
+
+        if labels is not None:
+            label_mask = self.cast(self.not_equal(labels, self.ignore_token_id), dtype.float32)
+            loss_mask = self.mul(loss_mask, label_mask)
+            local_loss_mask = self.morphed_reshape_labels_and_masks(loss_mask)
+            local_labels = self.morphed_reshape_labels_and_masks(labels)
+        else:
+            local_loss_mask = None
+            local_labels = None
+
         if self.use_attn_mask_compression:
             attention_mask = self.casual_mask()
         elif attention_mask is None:
diff --git a/mindformers/parallel_core/utils/model_mixin.py b/mindformers/parallel_core/utils/model_mixin.py
index 87bb49c2e..7a0bff698 100644
--- a/mindformers/parallel_core/utils/model_mixin.py
+++ b/mindformers/parallel_core/utils/model_mixin.py
@@ -14,6 +14,8 @@
 # ============================================================================
 """ModelMixin for train models"""

+from mindspore import Tensor
+import mindspore.common.dtype as mstype
 import numpy as np

 from mindformers.tools.logger import logger
@@ -125,6 +127,8 @@ class ModelMixin:

 class TrainModelMixin:
     """General interfaces for train models."""
+    is_train_model = True
+
     def concat_qkv_weight_megatron(self, wq_keys, wk_keys, wv_keys, qkv_weight_dict, condition, ms_weight_dict,
                                    head_dim, n_kv_heads, num_attention_heads):
         """
@@ -503,3 +507,6 @@ class TrainModelMixin:
             parameters, param_names, param_layers, logit_threshold,
             split_fn, merge_fn
         )
+
+    def prepare_inputs_for_predict_layout(self, input_ids, **kwargs):
+        return Tensor(input_ids, mstype.int32), None, None, None, None, None, None, None, None
diff --git a/mindformers/tools/register/template.py b/mindformers/tools/register/template.py
index 48f6770be..aaddbb43f 100644
--- a/mindformers/tools/register/template.py
+++ b/mindformers/tools/register/template.py
@@ -638,7 +638,7 @@ class ConfigTemplate:
         "metric"
     ]

-    _run_modes = ['train', 'eval', 'predict', 'finetune']
+    _run_modes = ['train', 'eval', 'predict', 'finetune', 'predict_with_train_model']

     @classmethod
     def apply_template(cls, config):
@@ -657,7 +657,7 @@ class ConfigTemplate:
         if run_mode not in cls._run_modes:
             logger.warning(f"The specified run_mode '{run_mode}' is invalid. Expected one of {cls._run_modes}.")
             template = cls.general_configs
-        elif run_mode in ['train', 'finetune']:
+        elif run_mode in ['train', 'finetune', 'predict_with_train_model']:
             template = cls._train_template(config.get("do_eval", False))
         elif run_mode == "predict":
             template = cls._predict_template()
diff --git a/mindformers/trainer/base_trainer.py b/mindformers/trainer/base_trainer.py
index 4da99944e..f35b5ff58 100644
--- a/mindformers/trainer/base_trainer.py
+++ b/mindformers/trainer/base_trainer.py
@@ -1675,7 +1675,7 @@ class BaseTrainer:

         model = Model(network)

-        if not config.use_legacy and config.load_checkpoint:
+        if not config.use_legacy and config.load_checkpoint and config.run_mode != "predict_with_train_model":
             if self.config.load_ckpt_format == 'safetensors':
                 network.load_weights(config.load_checkpoint)
             else:
diff --git a/run_mindformer.py b/run_mindformer.py
index b28ca84e2..62b22c3f5 100644
--- a/run_mindformer.py
+++ b/run_mindformer.py
@@ -70,7 +70,7 @@ def main(config):
         trainer.train()
     elif config.run_mode == 'eval':
         trainer.evaluate(eval_checkpoint=config.load_checkpoint)
-    elif config.run_mode == 'predict':
+    elif config.run_mode in ['predict', 'predict_with_train_model']:
         trainer.predict(predict_checkpoint=config.load_checkpoint, input_data=config.input_data,
                         batch_size=config.predict_batch_size, adapter_id=config.adapter_id)
@@ -272,15 +272,16 @@ if __name__ == "__main__":
         config_.profile = args_.profile
     if args_.options is not None:
         config_.merge_from_dict(args_.options)
-    if config_.run_mode not in ['train', 'eval', 'predict', 'finetune']:
-        raise TypeError(f"run status must be in {['train', 'eval', 'predict', 'finetune']}, but {config_.run_mode}")
+    if config_.run_mode not in ['train', 'eval', 'predict', 'finetune', 'predict_with_train_model']:
+        raise TypeError(f"run status must be in {['train', 'eval', 'predict', 'finetune', 'predict_with_train_model']}"
+                        f", but {config_.run_mode}")
     if args_.train_dataset_dir:
         config_.train_dataset.data_loader.dataset_dir = args_.train_dataset_dir
     if args_.eval_dataset_dir:
         config_.eval_dataset.data_loader.dataset_dir = args_.eval_dataset_dir
     if args_.do_sample is not None:
         config_.model.model_config.do_sample = args_.do_sample
-    if config_.run_mode == 'predict':
+    if config_.run_mode in ['predict', 'predict_with_train_model']:
         if args_.predict_data is None:
             logger.info("dataset by config is used as input_data.")
         if isinstance(args_.predict_data, list):
--
Gitee
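
A minimal usage sketch for the new mode (illustrative only: the yaml file name,
checkpoint path and prompt below are assumptions, not part of this patch; the
yaml keys and flag spellings are assumed from the existing run_mindformer.py
interface):

    # predict_with_train_model.yaml -- hypothetical fragment of a training-model config
    run_mode: 'predict_with_train_model'
    load_checkpoint: '/path/to/train_checkpoint'

    python run_mindformer.py --config predict_with_train_model.yaml --predict_data "hello"

With this run_mode, run_mindformer.py dispatches to trainer.predict(), template.py
applies the train config template, base_trainer.py skips the mcore load_weights
path, and text_generator.py bypasses the block manager / KV-cache setup, routing
generation through self.infer() on the training network, which TrainModelMixin
supports via is_train_model = True and prepare_inputs_for_predict_layout().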