From 7db691b9f0ed910bf955eec39f34b521afff5971 Mon Sep 17 00:00:00 2001 From: Hsshuai Date: Thu, 27 Nov 2025 19:28:34 +0800 Subject: [PATCH] add testcase for trainer --- mindformers/trainer/trainer.py | 5 +- .../test_trainer/test_trainer_config.yaml | 183 ++++ .../test_trainer/test_trainer_methods.py | 980 ++++++++++++++++++ 3 files changed, 1166 insertions(+), 2 deletions(-) create mode 100644 tests/st/test_ut/test_trainer/test_trainer_config.yaml create mode 100644 tests/st/test_ut/test_trainer/test_trainer_methods.py diff --git a/mindformers/trainer/trainer.py b/mindformers/trainer/trainer.py index eeba09a76..4273d7074 100644 --- a/mindformers/trainer/trainer.py +++ b/mindformers/trainer/trainer.py @@ -1580,8 +1580,9 @@ def _reset_config_for_save(config: dict = None): """ if config is None: config = {} - config = copy.deepcopy(config) - reset_parallel_config(config) + else: + config = copy.deepcopy(config) + reset_parallel_config(config) config_dict = OrderedDict() diff --git a/tests/st/test_ut/test_trainer/test_trainer_config.yaml b/tests/st/test_ut/test_trainer/test_trainer_config.yaml new file mode 100644 index 000000000..089c314cd --- /dev/null +++ b/tests/st/test_ut/test_trainer/test_trainer_config.yaml @@ -0,0 +1,183 @@ +seed: 0 +run_mode: 'train' +output_dir: './output' +load_checkpoint: '' +src_strategy_path_or_dir: '' +auto_trans_ckpt: False +only_save_strategy: False +resume_training: False + +# context +context: + mode: 0 + device_target: "Ascend" + max_call_depth: 10000 + save_graphs: False + device_id: 0 + +# runner +runner_config: + epochs: 1 + batch_size: 2 + sink_mode: False + sink_size: 1 + +runner_wrapper: + type: MFTrainOneStepCell + scale_sense: + type: DynamicLossScaleUpdateCell + loss_scale_value: 4096 + scale_factor: 2 + scale_window: 1000 + use_clip_grad: True + +# parallel +use_parallel: False +parallel: + parallel_mode: 0 + gradients_mean: True + search_mode: "sharding_propagation" + enable_parallel_optimizer: False + +parallel_config: + data_parallel: 1 + model_parallel: 1 + pipeline_stage: 1 + use_seq_parallel: False + micro_batch_num: 1 + gradient_aggregation_group: 4 + +micro_batch_interleave_num: 1 + +# moe +moe_config: + expert_num: 1 + capacity_factor: 1.05 + aux_loss_factor: 0.05 + num_experts_chosen: 1 + +# recompute +recompute_config: + recompute: False + select_recompute: False + parallel_optimizer_comm_recompute: False + mp_comm_recompute: True + recompute_slice_activation: False + +# autotune +auto_tune: False +filepath_prefix: './autotune' +autotune_per_step: 10 + +# profile +profile: False +profile_start_step: 1 +profile_stop_step: 10 +init_start_profile: False +profile_communication: False +profile_memory: False +profile_output: './profile' +profiler_level: 0 +with_stack: False +data_simplification: True +mstx: False +profile_rank_ids: [0] +profile_pipeline: False + +# Trainer +trainer: + type: CausalLanguageModelingTrainer + model_name: 'common' + +# dataset +train_dataset: + data_loader: + type: MindDataset + dataset_dir: "" + shuffle: True + tokenizer: + type: PreTrainedTokenizerBase + +train_dataset_task: + type: CausalLanguageModelDataset + +eval_dataset: + data_loader: + type: MindDataset + dataset_dir: "" + shuffle: False + tokenizer: + type: PreTrainedTokenizerBase + +eval_dataset_task: + type: CausalLanguageModelDataset + +# model +model: + model_config: + type: PretrainedConfig + checkpoint_name_or_path: "" + max_decode_length: 1024 + batch_size: 2 + arch: + type: PreTrainedModel + +# lr schedule +lr_schedule: + type: polynomial + learning_rate: 0.0001 + lr_end: 0.00001 + warmup_steps: 0 + total_steps: -1 + +layer_scale: False +layer_decay: 0.65 + +# optimizer +optimizer: + type: AdamW + betas: [0.9, 0.95] + eps: 0.00000001 + weight_decay: 0.05 + +lr_scale: False +lr_scale_factor: 256 + +# callbacks +callbacks: + - type: MFLossMonitor + - type: CheckpointMonitor + prefix: "mindformers" + save_checkpoint_steps: 100 + integrated_save: True + async_save: False + +# metric +metric: + type: PerplexityMetric + +# processor +processor: + return_tensors: ms + tokenizer: + type: PreTrainedTokenizerBase + image_processor: + type: BaseImageProcessor + audio_processor: + type: BaseAudioProcessor + +# check for nan in loss and grad +check_for_nan_in_loss_and_grad: False + +# push to hub +push_to_hub: False +hub_model_id: null +hub_token: null +hub_private_repo: False + +# other configs +do_eval: False +use_legacy_format: True +load_ckpt_format: 'safetensors' +resume_by_last_timestamp_ckpt: False +use_checkpoint_health_monitor: False diff --git a/tests/st/test_ut/test_trainer/test_trainer_methods.py b/tests/st/test_ut/test_trainer/test_trainer_methods.py new file mode 100644 index 000000000..c3b9f294d --- /dev/null +++ b/tests/st/test_ut/test_trainer/test_trainer_methods.py @@ -0,0 +1,980 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test trainer methods.""" +import os +import tempfile +import unittest +from copy import deepcopy +from contextlib import ExitStack +from unittest.mock import patch, MagicMock +from collections import OrderedDict +import pytest +import yaml + +try: + from mindspore.train import Callback +except ImportError: + # Fallback for testing environments + class Callback: + """Dummy Callback class for testing.""" + +from mindformers import Trainer, MindFormerConfig +from mindformers.trainer.trainer import _reset_config_for_save, _save_config_to_yaml +from mindformers.models import PreTrainedTokenizerBase + + +# pylint: disable=W0212 +class TestTrainerCallbackMethods(unittest.TestCase): + """Test Trainer callback related methods.""" + + @classmethod + def setUpClass(cls): + cls._exit_stack = ExitStack() + cls.temp_dir = cls._exit_stack.enter_context(tempfile.TemporaryDirectory()) + # Load config from yaml file + config_path = os.path.join(os.path.dirname(__file__), 'test_trainer_config.yaml') + cls.config = MindFormerConfig(config_path) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_add_callback_with_class(self): + """test add_callback with callback class.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + + class DummyCallback(Callback): + """Dummy callback for testing.""" + + initial_len = len(trainer.callbacks) + trainer.add_callback(DummyCallback) + assert len(trainer.callbacks) == initial_len + 1 + assert isinstance(trainer.callbacks[-1], DummyCallback) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_add_callback_with_instance(self): + """test add_callback with callback instance.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + + class DummyCallback(Callback): + """Dummy callback for testing.""" + + callback_instance = DummyCallback() + initial_len = len(trainer.callbacks) + trainer.add_callback(callback_instance) + assert len(trainer.callbacks) == initial_len + 1 + assert trainer.callbacks[-1] == callback_instance + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_add_duplicate_callback(self): + """test add_callback with duplicate callback type.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + + class DummyCallback(Callback): + """Dummy callback for testing.""" + + trainer.add_callback(DummyCallback) + initial_len = len(trainer.callbacks) + # Adding duplicate should still add but warn + trainer.add_callback(DummyCallback) + assert len(trainer.callbacks) == initial_len + 1 + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_pop_callback_with_class(self): + """test pop_callback with callback class.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + + class DummyCallback(Callback): + """Dummy callback for testing.""" + + trainer.add_callback(DummyCallback) + initial_len = len(trainer.callbacks) + popped = trainer.pop_callback(DummyCallback) + assert isinstance(popped, DummyCallback) + assert len(trainer.callbacks) == initial_len - 1 + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_pop_callback_with_instance(self): + """test pop_callback with callback instance.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + + class DummyCallback(Callback): + """Dummy callback for testing.""" + + callback_instance = DummyCallback() + trainer.add_callback(callback_instance) + initial_len = len(trainer.callbacks) + popped = trainer.pop_callback(callback_instance) + assert popped == callback_instance + assert len(trainer.callbacks) == initial_len - 1 + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_pop_callback_not_found(self): + """test pop_callback when callback not in list.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + + class DummyCallback(Callback): + """Dummy callback for testing.""" + + result = trainer.pop_callback(DummyCallback) + assert result == DummyCallback + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_remove_callback_with_class(self): + """test remove_callback with callback class.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + + class DummyCallback(Callback): + """Dummy callback for testing.""" + + trainer.add_callback(DummyCallback) + initial_len = len(trainer.callbacks) + trainer.remove_callback(DummyCallback) + assert len(trainer.callbacks) == initial_len - 1 + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_remove_callback_with_instance(self): + """test remove_callback with callback instance.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + + class DummyCallback(Callback): + """Dummy callback for testing.""" + + callback_instance = DummyCallback() + trainer.add_callback(callback_instance) + initial_len = len(trainer.callbacks) + trainer.remove_callback(callback_instance) + assert len(trainer.callbacks) == initial_len - 1 + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_callback_list_property(self): + """test callback_list property.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + + class DummyCallback1(Callback): + """Dummy callback 1 for testing.""" + + class DummyCallback2(Callback): + """Dummy callback 2 for testing.""" + + trainer.add_callback(DummyCallback1) + trainer.add_callback(DummyCallback2) + callback_list = trainer.callback_list + assert isinstance(callback_list, str) + assert 'DummyCallback1' in callback_list + assert 'DummyCallback2' in callback_list + + @classmethod + def tearDownClass(cls): + """Clean up temporary directory.""" + if hasattr(cls, '_exit_stack'): + cls._exit_stack.close() + + +class TestTrainerParallelConfig(unittest.TestCase): + """Test Trainer parallel configuration methods.""" + + @classmethod + def setUpClass(cls): + cls._exit_stack = ExitStack() + cls.temp_dir = cls._exit_stack.enter_context(tempfile.TemporaryDirectory()) + config_path = os.path.join(os.path.dirname(__file__), 'test_trainer_config.yaml') + cls.config = MindFormerConfig(config_path) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_set_parallel_config_basic(self): + """test set_parallel_config with basic parameters.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + trainer.set_parallel_config( + data_parallel=2, + model_parallel=2, + pipeline_stage=1 + ) + assert trainer.config.parallel_config.data_parallel == 2 + assert trainer.config.parallel_config.model_parallel == 2 + assert trainer.config.parallel_config.pipeline_stage == 1 + assert trainer.reset_model is True + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_set_parallel_config_all_params(self): + """test set_parallel_config with all parameters.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + trainer.set_parallel_config( + data_parallel=2, + model_parallel=2, + context_parallel=1, + expert_parallel=1, + pipeline_stage=2, + micro_batch_interleave_num=2, + micro_batch_num=4, + use_seq_parallel=True, + optimizer_shard=True, + gradient_aggregation_group=8, + vocab_emb_dp=False + ) + assert trainer.config.parallel_config.data_parallel == 2 + assert trainer.config.parallel_config.model_parallel == 2 + assert trainer.config.parallel_config.context_parallel == 1 + assert trainer.config.parallel_config.expert_parallel == 1 + assert trainer.config.parallel_config.pipeline_stage == 2 + assert trainer.config.parallel_config.use_seq_parallel is True + assert trainer.config.parallel_config.optimizer_shard is True + assert trainer.config.parallel_config.micro_batch_num == 4 + assert trainer.config.parallel_config.vocab_emb_dp is False + assert trainer.config.parallel_config.gradient_aggregation_group == 8 + assert trainer.config.micro_batch_interleave_num == 2 + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_set_parallel_config_default_values(self): + """test set_parallel_config with default values.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + trainer.set_parallel_config() + assert trainer.config.parallel_config.data_parallel == 1 + assert trainer.config.parallel_config.model_parallel == 1 + assert trainer.config.parallel_config.pipeline_stage == 1 + assert trainer.config.parallel_config.use_seq_parallel is False + assert trainer.config.parallel_config.optimizer_shard is False + + @classmethod + def tearDownClass(cls): + """Clean up temporary directory.""" + if hasattr(cls, '_exit_stack'): + cls._exit_stack.close() + + +class TestTrainerRecomputeConfig(unittest.TestCase): + """Test Trainer recompute configuration methods.""" + + @classmethod + def setUpClass(cls): + cls._exit_stack = ExitStack() + cls.temp_dir = cls._exit_stack.enter_context(tempfile.TemporaryDirectory()) + config_path = os.path.join(os.path.dirname(__file__), 'test_trainer_config.yaml') + cls.config = MindFormerConfig(config_path) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_set_recompute_config_basic(self): + """test set_recompute_config with basic parameters.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + trainer.set_recompute_config(recompute=True) + assert trainer.config.recompute_config.recompute is True + assert trainer.reset_model is True + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_set_recompute_config_all_params(self): + """test set_recompute_config with all parameters.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + trainer.set_recompute_config( + recompute=True, + parallel_optimizer_comm_recompute=True, + select_recompute=True, + mp_comm_recompute=False, + recompute_slice_activation=True + ) + assert trainer.config.recompute_config.recompute is True + assert trainer.config.recompute_config.parallel_optimizer_comm_recompute is True + assert trainer.config.recompute_config.select_recompute is True + assert trainer.config.recompute_config.mp_comm_recompute is False + assert trainer.config.recompute_config.recompute_slice_activation is True + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_set_recompute_config_default_values(self): + """test set_recompute_config with default values.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + trainer.set_recompute_config() + assert trainer.config.recompute_config.recompute is False + assert trainer.config.recompute_config.parallel_optimizer_comm_recompute is False + assert trainer.config.recompute_config.select_recompute is False + assert trainer.config.recompute_config.mp_comm_recompute is True + assert trainer.config.recompute_config.recompute_slice_activation is False + + @classmethod + def tearDownClass(cls): + """Clean up temporary directory.""" + if hasattr(cls, '_exit_stack'): + cls._exit_stack.close() + + +class TestTrainerMoeConfig(unittest.TestCase): + """Test Trainer MoE configuration methods.""" + + @classmethod + def setUpClass(cls): + cls._exit_stack = ExitStack() + cls.temp_dir = cls._exit_stack.enter_context(tempfile.TemporaryDirectory()) + config_path = os.path.join(os.path.dirname(__file__), 'test_trainer_config.yaml') + cls.config = MindFormerConfig(config_path) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_set_moe_config_basic(self): + """test _set_moe_config with basic parameters.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + trainer._set_moe_config(expert_num=4) + assert trainer.config.moe_config.expert_num == 4 + assert trainer.reset_model is True + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_set_moe_config_all_params(self): + """test _set_moe_config with all parameters.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + trainer._set_moe_config( + expert_num=8, + capacity_factor=1.5, + aux_loss_factor=0.1, + num_experts_chosen=2, + expert_group_size=128, + group_wise_a2a=True, + comp_comm_parallel=True, + comp_comm_parallel_degree=4 + ) + assert trainer.config.moe_config.expert_num == 8 + assert trainer.config.moe_config.capacity_factor == 1.5 + assert trainer.config.moe_config.aux_loss_factor == 0.1 + assert trainer.config.moe_config.num_experts_chosen == 2 + assert trainer.config.moe_config.expert_group_size == 128 + assert trainer.config.moe_config.group_wise_a2a is True + assert trainer.config.moe_config.comp_comm_parallel is True + assert trainer.config.moe_config.comp_comm_parallel_degree == 4 + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_set_moe_config_default_values(self): + """test _set_moe_config with default values.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + trainer._set_moe_config() + assert trainer.config.moe_config.expert_num == 1 + assert trainer.config.moe_config.capacity_factor == 1.1 + assert trainer.config.moe_config.aux_loss_factor == 0.05 + assert trainer.config.moe_config.num_experts_chosen == 1 + assert trainer.config.moe_config.group_wise_a2a is False + assert trainer.config.moe_config.comp_comm_parallel is False + assert trainer.config.moe_config.comp_comm_parallel_degree == 2 + + @classmethod + def tearDownClass(cls): + """Clean up temporary directory.""" + if hasattr(cls, '_exit_stack'): + cls._exit_stack.close() + + +class TestTrainerDataloaderMethods(unittest.TestCase): + """Test Trainer dataloader related methods.""" + + @classmethod + def setUpClass(cls): + cls._exit_stack = ExitStack() + cls.temp_dir = cls._exit_stack.enter_context(tempfile.TemporaryDirectory()) + config_path = os.path.join(os.path.dirname(__file__), 'test_trainer_config.yaml') + cls.config = MindFormerConfig(config_path) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_get_train_dataloader(self): + """test get_train_dataloader method.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + # Mock the config to avoid actual dataloader building + with patch('mindformers.trainer.trainer.build_dataset_loader') as mock_build: + mock_build.return_value = MagicMock() + dataloader = trainer.get_train_dataloader() + assert dataloader is not None + mock_build.assert_called_once() + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_get_eval_dataloader(self): + """test get_eval_dataloader method.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + # Mock the config to avoid actual dataloader building + with patch('mindformers.trainer.trainer.build_dataset_loader') as mock_build: + mock_build.return_value = MagicMock() + dataloader = trainer.get_eval_dataloader() + assert dataloader is not None + mock_build.assert_called_once() + + @classmethod + def tearDownClass(cls): + """Clean up temporary directory.""" + if hasattr(cls, '_exit_stack'): + cls._exit_stack.close() + + +class TestTrainerCheckpointMethods(unittest.TestCase): + """Test Trainer checkpoint related methods.""" + + @classmethod + def setUpClass(cls): + cls.config_path = os.path.join(os.path.dirname(__file__), 'test_trainer_config.yaml') + cls.config = MindFormerConfig(cls.config_path) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_get_load_checkpoint_valid_path(self): + """test get_load_checkpoint with valid checkpoint path.""" + # Create a temporary checkpoint file + with tempfile.TemporaryDirectory() as temp_dir: + ckpt_path = os.path.join(temp_dir, 'test_checkpoint.ckpt') + with open(ckpt_path, 'w', encoding='utf-8') as f: + f.write('mock checkpoint') + os.stat(ckpt_path) + + result = Trainer.get_load_checkpoint(ckpt_path) + assert result == ckpt_path + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_get_load_checkpoint_none(self): + """test get_load_checkpoint with None.""" + result = Trainer.get_load_checkpoint(None) + assert result is None + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_get_load_checkpoint_invalid_path(self): + """test get_load_checkpoint with invalid path.""" + with pytest.raises(ValueError) as exc_info: + Trainer.get_load_checkpoint('/nonexistent/path/checkpoint.ckpt') + assert 'not existed' in str(exc_info.value) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_get_load_checkpoint_invalid_type(self): + """test get_load_checkpoint with invalid type.""" + with pytest.raises(TypeError) as exc_info: + Trainer.get_load_checkpoint(123) + assert 'should be a str' in str(exc_info.value) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_check_checkpoint_config_with_true(self): + """test _check_checkpoint_config with True.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + + with tempfile.TemporaryDirectory() as temp_dir: + trainer.config.output_dir = temp_dir + checkpoint_dir = os.path.join(temp_dir, 'checkpoint', 'rank_0') + os.makedirs(checkpoint_dir, exist_ok=True) + + with open(os.path.join(checkpoint_dir, "1.safetensors"), 'w', encoding='utf-8') as f: + f.write('mock') + os.stat(os.path.join(checkpoint_dir, "1.safetensors")) + + with open(os.path.join(checkpoint_dir, "2.safetensors"), 'w', encoding='utf-8') as f: + f.write('mock') + os.stat(os.path.join(checkpoint_dir, "2.safetensors")) + + last_checkpoint_path = os.path.join(checkpoint_dir, "3.safetensors") + with open(last_checkpoint_path, 'w', encoding='utf-8') as f: + f.write('mock') + os.stat(last_checkpoint_path) + + trainer._check_checkpoint_config(True) + assert trainer.config.model.model_config.checkpoint_name_or_path == last_checkpoint_path + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_check_checkpoint_config_with_path(self): + """test _check_checkpoint_config with checkpoint path.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + with tempfile.TemporaryDirectory() as temp_dir: + ckpt_path = os.path.join(temp_dir, 'test.ckpt') + with open(ckpt_path, 'w', encoding='utf-8') as f: + f.write('mock') + os.stat(ckpt_path) + + trainer._check_checkpoint_config(ckpt_path) + assert trainer.config.model.model_config.checkpoint_name_or_path == ckpt_path + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_check_checkpoint_config_with_none(self): + """test _check_checkpoint_config with None.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + trainer.default_checkpoint_name_or_path = 'default_ckpt' + trainer._check_checkpoint_config(None) + assert trainer.config.model.model_config.checkpoint_name_or_path == 'default_ckpt' + + +class TestTrainerConfigMethods(unittest.TestCase): + """Test Trainer config related methods.""" + + @classmethod + def setUpClass(cls): + cls._exit_stack = ExitStack() + cls.temp_dir = cls._exit_stack.enter_context(tempfile.TemporaryDirectory()) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_config_init_with_none(self): + """test _config_init with None args.""" + result = Trainer._config_init(None, None) + assert isinstance(result, MindFormerConfig) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_config_init_with_mindformer_config(self): + """test _config_init with MindFormerConfig.""" + config = MindFormerConfig() + result = Trainer._config_init(config, None) + assert result == config + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_config_init_with_yaml_path(self): + """test _config_init with yaml path.""" + # Create a temporary yaml file + yaml_path = os.path.join(self.temp_dir, 'test_config.yaml') + config_dict = { + 'model': {'model_config': {'type': 'test_model'}}, + 'trainer': {'type': 'test_trainer'} + } + with open(yaml_path, 'w', encoding='utf-8') as f: + yaml.dump(config_dict, f) + + result = Trainer._config_init(yaml_path, None) + assert isinstance(result, MindFormerConfig) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_config_init_with_invalid_yaml_path(self): + """test _config_init with invalid yaml path.""" + with pytest.raises(ValueError) as exc_info: + Trainer._config_init('/nonexistent/config.yaml', None) + assert 'must be exist' in str(exc_info.value) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_config_init_with_invalid_extension(self): + """test _config_init with invalid file extension.""" + txt_path = os.path.join(self.temp_dir, 'config.txt') + with open(txt_path, 'w', encoding='utf-8') as f: + f.write('test') + + with pytest.raises(ValueError) as exc_info: + Trainer._config_init(txt_path, None) + assert 'must be end with .yaml or .yml' in str(exc_info.value) + + @classmethod + def tearDownClass(cls): + """Clean up temporary directory.""" + if hasattr(cls, '_exit_stack'): + cls._exit_stack.close() + + +class TestTrainerInitMethods(unittest.TestCase): + """Test Trainer initialization related methods.""" + + @classmethod + def setUpClass(cls): + cls._exit_stack = ExitStack() + cls.temp_dir = cls._exit_stack.enter_context(tempfile.TemporaryDirectory()) + config_path = os.path.join(os.path.dirname(__file__), 'test_trainer_config.yaml') + cls.config = MindFormerConfig(config_path) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_init_tokenizer(self): + """test _init_tokenizer method.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + mock_tokenizer = MagicMock(spec=PreTrainedTokenizerBase) + trainer.tokenizer = mock_tokenizer + trainer._init_tokenizer() + # Verify tokenizer is set in config + assert trainer.config.train_dataset.tokenizer == mock_tokenizer + assert trainer.config.eval_dataset.tokenizer == mock_tokenizer + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_init_tokenizer_none(self): + """test _init_tokenizer with None tokenizer.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + trainer.tokenizer = None + trainer._init_tokenizer() + # Should not raise error + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_init_dataset_with_train_path(self): + """test _init_dataset with train dataset path.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + train_dir = os.path.join(self.temp_dir, 'train_data') + os.makedirs(train_dir, exist_ok=True) + trainer.train_dataset = train_dir + trainer._init_dataset() + assert trainer.config.train_dataset.data_loader.dataset_dir == train_dir + assert trainer.train_dataset is None + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_init_dataset_with_eval_path(self): + """test _init_dataset with eval dataset path.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + eval_dir = os.path.join(self.temp_dir, 'eval_data') + os.makedirs(eval_dir, exist_ok=True) + trainer.eval_dataset = eval_dir + trainer._init_dataset() + assert trainer.config.eval_dataset.data_loader.dataset_dir == eval_dir + assert trainer.eval_dataset is None + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_init_dataset_invalid_train_path(self): + """test _init_dataset with invalid train path.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + trainer.train_dataset = '/nonexistent/train/path' + with pytest.raises(ValueError) as exc_info: + trainer._init_dataset() + assert 'must be exist' in str(exc_info.value) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_init_dataset_invalid_eval_path(self): + """test _init_dataset with invalid eval path.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + trainer.eval_dataset = '/nonexistent/eval/path' + with pytest.raises(ValueError) as exc_info: + trainer._init_dataset() + assert 'must be exist' in str(exc_info.value) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_build_profile_cb_enabled(self): + """test _build_profile_cb when profile is enabled.""" + config = deepcopy(self.config) + config.profile = True + config.profile_start_step = 10 + config.profile_stop_step = 20 + config.init_start_profile = False + config.profile_communication = False + config.profile_memory = False + config.profile_output = self.temp_dir + config.profiler_level = 0 + config.with_stack = False + config.data_simplification = True + config.runner_config.sink_size = 1 + config.mstx = False + config.runner_config.sink_mode = False + trainer = Trainer(args=config, task='general', model_name='common') + + + trainer._build_profile_cb() + assert hasattr(trainer.config, 'profile_cb') + assert trainer.config.profile_cb is not None + assert trainer.config.auto_tune is False + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_build_profile_cb_disabled(self): + """test _build_profile_cb when profile is disabled.""" + config = deepcopy(self.config) + config.profile = False + trainer = Trainer(args=config, task='general', model_name='common') + trainer._build_profile_cb() + assert not hasattr(trainer.config, 'profile_cb') or trainer.config.profile_cb is None + + @classmethod + def tearDownClass(cls): + """Clean up temporary directory.""" + if hasattr(cls, '_exit_stack'): + cls._exit_stack.close() + + +class TestTrainerSaveMethods(unittest.TestCase): + """Test Trainer save related methods.""" + + @classmethod + def setUpClass(cls): + cls._exit_stack = ExitStack() + cls.temp_dir = cls._exit_stack.enter_context(tempfile.TemporaryDirectory()) + config_path = os.path.join(os.path.dirname(__file__), 'test_trainer_config.yaml') + cls.config = MindFormerConfig(config_path) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_save_config_to_yaml_function(self): + """test _save_config_to_yaml function.""" + save_path = os.path.join(self.temp_dir, 'test_config.yaml') + config_dict = {'model': {'type': 'test'}, 'trainer': {'type': 'test'}} + + _save_config_to_yaml(save_path, config_dict) + assert os.path.exists(save_path) + + # Verify content + with open(save_path, 'r', encoding='utf-8') as f: + loaded = yaml.safe_load(f) + assert loaded['model']['type'] == 'test' + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_save_model_method(self): + """test save_model method.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + save_dir = os.path.join(self.temp_dir, 'model_save') + os.makedirs(save_dir, exist_ok=True) + + with patch.object(trainer, '_save') as mock_save: + trainer.save_model(output_dir=save_dir) + mock_save.assert_called_once() + + @classmethod + def tearDownClass(cls): + """Clean up temporary directory.""" + if hasattr(cls, '_exit_stack'): + cls._exit_stack.close() + + +class TestTrainerCheckMethods(unittest.TestCase): + """Test Trainer check related methods.""" + + @classmethod + def setUpClass(cls): + cls._exit_stack = ExitStack() + cls.temp_dir = cls._exit_stack.enter_context(tempfile.TemporaryDirectory()) + config_path = os.path.join(os.path.dirname(__file__), 'test_trainer_config.yaml') + cls.config = MindFormerConfig(config_path) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_check_args_task_and_model_general_common(self): + """test _check_args_task_and_model with general task and common model.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + # Should not raise error + trainer._check_args_task_and_model() + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_check_args_task_and_model_invalid_model_name(self): + """test _check_args_task_and_model with invalid model name for task.""" + with pytest.raises(ValueError) as exc_info: + Trainer(task='text_generation', model_name='invalid_model') + assert 'not support in task' in str(exc_info.value) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_load_model_checkpoint_none(self): + """test _load_model_checkpoint with None checkpoint.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + trainer.config.model.model_config.checkpoint_name_or_path = None + trainer._load_model_checkpoint() + # Should not raise error + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_load_model_checkpoint_invalid_type(self): + """test _load_model_checkpoint with invalid type.""" + trainer = Trainer(args=self.config, task='general', model_name='common') + trainer.config.model.model_config.checkpoint_name_or_path = 123 + with pytest.raises(TypeError) as exc_info: + trainer._load_model_checkpoint() + assert 'type error' in str(exc_info.value) + + @classmethod + def tearDownClass(cls): + """Clean up temporary directory.""" + if hasattr(cls, '_exit_stack'): + cls._exit_stack.close() + + +class TestResetConfigForSave(unittest.TestCase): + """Test _reset_config_for_save function.""" + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_reset_config_for_save_with_model(self): + """test _reset_config_for_save with model config.""" + config = MindFormerConfig() + config.model = MindFormerConfig({'type': 'test_model'}) + result = _reset_config_for_save(config) + assert isinstance(result, OrderedDict) + assert 'model' in result + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_reset_config_for_save_with_processor(self): + """test _reset_config_for_save with processor config.""" + config = MindFormerConfig() + config.processor = MindFormerConfig({'type': 'test_processor'}) + result = _reset_config_for_save(config) + assert isinstance(result, OrderedDict) + assert 'processor' in result + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_reset_config_for_save_with_datasets(self): + """test _reset_config_for_save with dataset configs.""" + config = MindFormerConfig() + config.train_dataset = MindFormerConfig({'type': 'train_dataset'}) + config.train_dataset_task = MindFormerConfig({'type': 'train_task'}) + config.eval_dataset = MindFormerConfig({'type': 'eval_dataset'}) + config.eval_dataset_task = MindFormerConfig({'type': 'eval_task'}) + result = _reset_config_for_save(config) + assert 'train_dataset' in result + assert 'train_dataset_task' in result + assert 'eval_dataset' in result + assert 'eval_dataset_task' in result + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_reset_config_for_save_with_parallel_configs(self): + """test _reset_config_for_save with parallel configs.""" + config = MindFormerConfig() + config.context = MindFormerConfig({'mode': 'GRAPH_MODE'}) + config.parallel = MindFormerConfig({'parallel_mode': 'semi_auto_parallel'}) + config.moe_config = MindFormerConfig({'expert_num': 4}) + config.recompute_config = MindFormerConfig({'recompute': True}) + config.parallel_config = MindFormerConfig({'data_parallel': 2}) + result = _reset_config_for_save(config) + assert 'context' in result + assert 'parallel' in result + assert 'moe_config' in result + assert 'recompute_config' in result + assert 'parallel_config' in result + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_reset_config_for_save_with_optimizer_lr(self): + """test _reset_config_for_save with optimizer and lr configs.""" + config = MindFormerConfig() + config.optimizer = MindFormerConfig({'type': 'AdamW'}) + config.lr_schedule = MindFormerConfig({'type': 'cosine'}) + result = _reset_config_for_save(config) + assert 'optimizer' in result + assert 'lr_schedule' in result + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_reset_config_for_save_with_callbacks(self): + """test _reset_config_for_save with callbacks config.""" + config = MindFormerConfig() + config.callbacks = [{'type': 'CheckpointMonitor'}] + result = _reset_config_for_save(config) + assert 'callbacks' in result + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_reset_config_for_save_with_runner_config(self): + """test _reset_config_for_save with runner config.""" + config = MindFormerConfig() + config.runner_config = MindFormerConfig({'sink_mode': True}) + result = _reset_config_for_save(config) + assert 'runner_config' in result + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_reset_config_for_save_with_runner_wrapper(self): + """test _reset_config_for_save with runner wrapper config.""" + config = MindFormerConfig() + config.runner_wrapper = MindFormerConfig({'type': 'TrainOneStepCell'}) + result = _reset_config_for_save(config) + assert 'runner_wrapper' in result + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_reset_config_for_save_empty_config(self): + """test _reset_config_for_save with empty config.""" + config = MindFormerConfig() + result = _reset_config_for_save(config) + assert isinstance(result, OrderedDict) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_reset_config_for_save_none_config(self): + """test _reset_config_for_save with None.""" + result = _reset_config_for_save(None) + assert isinstance(result, OrderedDict) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_reset_config_for_save_ordered_dict(self): + """test _reset_config_for_save returns OrderedDict.""" + config = MindFormerConfig() + config.model = MindFormerConfig({'type': 'test'}) + config.optimizer = MindFormerConfig({'type': 'AdamW'}) + result = _reset_config_for_save(config) + assert isinstance(result, OrderedDict) + # Check order: model should come before optimizer + keys_list = list(result.keys()) + if 'model' in keys_list and 'optimizer' in keys_list: + assert keys_list.index('model') < keys_list.index('optimizer') -- Gitee