From b446a39fe7e3b301e9c37c7daf58ac1eca6e851c Mon Sep 17 00:00:00 2001 From: yiyison Date: Sat, 29 Nov 2025 09:25:11 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0load=5Fcheckpoint=5Futils.py?= =?UTF-8?q?=E4=BB=A5=E5=8F=8Arun=5Fcheck.py=E7=9A=84=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mindformers/utils/load_checkpoint_utils.py | 10 +- tests/st/test_run_check.py | 15 + .../test_safetensors/test_checkpoint_utils.py | 882 ++++++++++++++++-- 3 files changed, 850 insertions(+), 57 deletions(-) create mode 100644 tests/st/test_run_check.py diff --git a/mindformers/utils/load_checkpoint_utils.py b/mindformers/utils/load_checkpoint_utils.py index 2233f5532..420915781 100644 --- a/mindformers/utils/load_checkpoint_utils.py +++ b/mindformers/utils/load_checkpoint_utils.py @@ -65,7 +65,7 @@ def _get_origin_network(network): """recursive find if cells which have function """ if 'convert_name' in dir(network): return network, True - #DFS for network + # DFS for network for cell in list(network.cells()): network, find_cell = _get_origin_network(cell) if find_cell: @@ -314,7 +314,7 @@ def load_checkpoint_with_safetensors(config, model, network, input_data, do_eval if config.resume_training or (config.get('remove_redundancy', False) and not do_predict): # pylint: disable=W0212 network = model._train_network - #build model + # build model if config.use_parallel: compile_model( model=model, @@ -325,7 +325,7 @@ def load_checkpoint_with_safetensors(config, model, network, input_data, do_eval sink_size=config.runner_config.sink_size, do_eval=do_eval, do_predict=do_predict ) - #wait generate all rank strategy files + # wait generate all rank strategy files barrier() # only execute qkv concat check on the main rank in predict mode @@ -337,7 +337,7 @@ def load_checkpoint_with_safetensors(config, model, network, input_data, do_eval barrier() process_for_stand_alone_mode(config, network, strategy_path) - #merge dst strategy + # merge dst strategy strategy_path = get_merged_dst_strategy_path(config, strategy_path) load_safetensors_checkpoint(config, load_checkpoint_files, network, strategy_path, load_checkpoint, optimizer) @@ -457,7 +457,7 @@ def load_safetensors_checkpoint(config, load_checkpoint_files, network, strategy format=config.load_ckpt_format )) if not config.model.model_config.get("qkv_concat", False) \ - and is_hf_safetensors_dir(load_ckpt_path, origin_network): + and is_hf_safetensors_dir(load_ckpt_path, origin_network): logger.info("......HuggingFace weights convert name......") params_dict = origin_network.convert_weight_dict(params_dict, model_config=config.model.model_config) if optimizer and config.resume_training: diff --git a/tests/st/test_run_check.py b/tests/st/test_run_check.py new file mode 100644 index 000000000..ff36e2290 --- /dev/null +++ b/tests/st/test_run_check.py @@ -0,0 +1,15 @@ +"""Test for run_check function""" +import pytest +from mindformers import run_check + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_run_check(): + """ + Feature: Test run_check function + Description: Call run_check to check if MindSpore, MindFormers, CANN and driver versions are compatible + Expectation: No exceptions raised, all checks pass + """ + run_check() diff --git a/tests/st/test_safetensors/test_checkpoint_utils.py b/tests/st/test_safetensors/test_checkpoint_utils.py index fd8744104..f77746d4d 100644 --- a/tests/st/test_safetensors/test_checkpoint_utils.py +++ b/tests/st/test_safetensors/test_checkpoint_utils.py @@ -13,17 +13,95 @@ # limitations under the License. # ============================================================================ """test for load_checkpoint_utils.""" +# pylint: disable=W0621 +import tempfile from unittest.mock import patch, MagicMock import pytest +import numpy as np +from mindspore import Parameter from mindformers.tools.register import MindFormerConfig from mindformers.checkpoint.utils import compile_model -from mindformers.utils.load_checkpoint_utils import CkptFormat, _get_checkpoint_mode, CheckpointFileMode, \ - _check_checkpoint_path +from mindformers.models.modeling_utils import PreTrainedModel +from mindformers.utils.load_checkpoint_utils import ( + CkptFormat, _get_checkpoint_mode, CheckpointFileMode, _check_checkpoint_path, + extract_suffix, get_last_checkpoint, validate_config_with_file_mode, + update_global_step, unify_safetensors, _revise_remove_redundancy_with_file, + _get_origin_network, get_load_path_after_hf_convert, _get_src_strategy, + _get_src_file_suffix, _get_src_file, load_safetensors_checkpoint, + process_hf_checkpoint, validate_qkv_concat, get_merged_src_strategy_path, + get_merged_dst_strategy_path, process_for_stand_alone_mode, + load_checkpoint_with_safetensors +) + + +@pytest.fixture +def mock_config(): + """Create a mock config with default values""" + + class MockConfig: + """Mock configuration class for testing""" + + def __init__(self): + self.load_checkpoint = "/path/to/checkpoint" + self.load_ckpt_format = "safetensors" + self.use_parallel = False + self.auto_trans_ckpt = False + self.resume_training = None + self.remove_redundancy = False + self.output_dir = "/output" + self.src_strategy_path_or_dir = None + self.load_ckpt_async = False + self.context = type('', (), {})() + self.context.mode = "GRAPH_MODE" + self.runner_config = type('', (), {})() + self.runner_config.sink_mode = True + self.runner_config.epochs = 1 + self.runner_config.sink_size = 1 + self.runner_config.step_scale = 2.0 + self.model = type('', (), {})() + self.model.model_config = {} + self.parallel = type('', (), {})() + self.parallel.parallel_mode = "DATA_PARALLEL" + + def get(self, key, default=None): + return getattr(self, key, default) + + return MockConfig() + + +@pytest.fixture +def mock_network(): + """Create a mock network""" + mock_net = MagicMock() + mock_net.cells.return_value = [] + return mock_net + + +@pytest.fixture +def mock_model(): + """Create a mock model""" + mock_mod = MagicMock() + mock_mod.config = MagicMock() + mock_mod.config.model_type = "test_model" + return mock_mod + + +@pytest.fixture +def mock_file(): + """Create a mock file""" + mock_f = MagicMock() + mock_f.metadata.return_value = None + return mock_f + class TestCommonCheckpointMethod: """A test class for testing common methods""" + + @pytest.mark.level0 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard def test_support_type(self): """test CkptFormat support type""" # run the test @@ -32,25 +110,709 @@ class TestCommonCheckpointMethod: # verify the results assert result == ['ckpt', 'safetensors'] + @pytest.mark.level0 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard def test_check_checkpoint_path_with_non_string_pathlike(self): """test check checkpoint path with non string pathlike""" path = 123 with pytest.raises(ValueError, - match=r"config.load_checkpoint must be a str, but got 123 as type ."): + match=r"config.load_checkpoint must be a `str`, but got `123` as type ``."): _check_checkpoint_path(path) + @pytest.mark.level0 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard def test_check_checkpoint_path_with_nonexistent_path(self): """test check checkpoint path with nonexistent path""" path = 'NoneExistPath' - with pytest.raises(FileNotFoundError, match=r"config.load_checkpoint NoneExistPath does not exist."): + with pytest.raises(FileNotFoundError, match=r"config.load_checkpoint `NoneExistPath` does not exist."): _check_checkpoint_path(path) + @pytest.mark.level0 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_check_checkpoint_path_with_valid_path(self): + """test check checkpoint path with valid path""" + # create a temporary directory for testing + with tempfile.TemporaryDirectory() as tmpdir: + # test with directory path + result = _check_checkpoint_path(tmpdir) + assert result == tmpdir + + # test with directory path ending with slash + result = _check_checkpoint_path(tmpdir + '/') + assert result == tmpdir + + @pytest.mark.level0 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + @pytest.mark.parametrize( + "file_path, expected", + [ + # test pattern 1: {prefix}_rank_{rank_id}-{epoch}_{step}.safetensors + ("model_rank_0-10_200.safetensors", "-10_200"), + # test pattern 2: {prefix}_rank_{rank_id}_{task_id}-{epoch}_{step}.safetensors + ("model_rank_0_1-10_200.safetensors", "_1-10_200"), + # test with invalid pattern + ("invalid_filename.safetensors", "invalid_filename") + ] + ) + def test_extract_suffix(self, file_path, expected): + """test extract_suffix function""" + result = extract_suffix(file_path) + assert result == expected + + @pytest.mark.level0 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_get_last_checkpoint(self): + """test get_last_checkpoint function""" + # setup mocks using context managers + with patch('os.path.isdir') as mock_isdir, \ + patch('os.path.exists') as mock_exists, \ + patch('os.listdir') as mock_listdir, \ + patch('os.path.getmtime') as mock_getmtime: + # setup mock return values + mock_isdir.return_value = True + mock_exists.return_value = True + mock_listdir.return_value = ["model_0.ckpt", "model_1.ckpt", "model_2.ckpt"] + mock_getmtime.side_effect = lambda x: { + "/test/model_0.ckpt": 100, + "/test/model_1.ckpt": 200, + "/test/model_2.ckpt": 300 + }[x] + + # test with valid directory + result = get_last_checkpoint("/test", "ckpt") + assert result == "/test/model_2.ckpt" + + # test with no checkpoint files + mock_listdir.return_value = ["other_file.txt"] + result = get_last_checkpoint("/test", "ckpt") + assert result is None + + # test with invalid directory + mock_isdir.return_value = False + with pytest.raises(NotADirectoryError): + get_last_checkpoint("/invalid/dir", "ckpt") + + @pytest.mark.level0 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + @pytest.mark.parametrize( + "file_mode, use_parallel, auto_trans_ckpt, expected_exception", + [ + # test single checkpoint file mode with parallel + (CheckpointFileMode.SINGLE_CHECKPOINT_FILE.value, True, False, ValueError), + # test multi checkpoint file mode with parallel but no auto_trans_ckpt + (CheckpointFileMode.MULTI_CHECKPOINT_FILE.value, True, False, ValueError), + # test multi checkpoint file with rank id mode without parallel + (CheckpointFileMode.MULTI_CHECKPOINT_FILE_WITH_RANK_ID.value, False, False, ValueError), + # test invalid mode + ("invalid_mode", False, False, ValueError), + # test valid cases - no exception expected + (CheckpointFileMode.SINGLE_CHECKPOINT_FILE.value, False, False, None), + (CheckpointFileMode.MULTI_CHECKPOINT_FILE.value, True, True, None), + (CheckpointFileMode.MULTI_CHECKPOINT_FILE_WITH_RANK_ID.value, True, False, None) + ] + ) + def test_validate_config_with_file_mode(self, file_mode, use_parallel, auto_trans_ckpt, expected_exception): + """test validate_config_with_file_mode function""" + if expected_exception: + with pytest.raises(expected_exception): + validate_config_with_file_mode(file_mode, use_parallel, auto_trans_ckpt) + else: + validate_config_with_file_mode(file_mode, use_parallel, auto_trans_ckpt) + + @pytest.mark.level0 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + @pytest.mark.parametrize( + "step_scale, initial_global_step, expected_global_step, expected_in_dict", + [ + (2.0, 100, 200, True), + (None, 100, 100, True), + (2.0, None, None, False) + ] + ) + def test_update_global_step(self, step_scale, initial_global_step, expected_global_step, expected_in_dict): + """test update_global_step function""" + # setup config + config = type('', (), {})() + config.runner_config = type('', (), {})() + config.runner_config.step_scale = step_scale + + # setup hyper_param_dict + hyper_param_dict = {} + if initial_global_step is not None: + hyper_param_dict["global_step"] = Parameter(np.array(initial_global_step, dtype=np.int32)) + + # test update_global_step + update_global_step(config, hyper_param_dict) + + # verify the results + if expected_in_dict: + assert "global_step" in hyper_param_dict + assert hyper_param_dict["global_step"].asnumpy() == expected_global_step + else: + assert "global_step" not in hyper_param_dict + + @pytest.mark.level0 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_unify_safetensors(self): + """test unify_safetensors function""" + # setup mocks using context managers + with patch('mindformers.utils.load_checkpoint_utils.is_main_rank') as mock_is_main_rank, \ + patch('mindformers.utils.load_checkpoint_utils.barrier') as mock_barrier, \ + patch('mindspore.unified_safetensors') as mock_unified_safetensors: + # test when is_main_rank is True + mock_is_main_rank.return_value = True + unify_safetensors("/src/checkpoint", "/src/strategy", "/dst/unified", True, "-10_200", False) + mock_unified_safetensors.assert_called_once() + mock_barrier.assert_called_once() + + # test when is_main_rank is False + mock_is_main_rank.return_value = False + mock_barrier.reset_mock() + unify_safetensors("/src/checkpoint", "/src/strategy", "/dst/unified", True, "-10_200", False) + mock_unified_safetensors.assert_called_once() # should not be called again + mock_barrier.assert_called_once() + + # test without parallel + mock_is_main_rank.return_value = True + mock_barrier.reset_mock() + unify_safetensors("/src/checkpoint", "/src/strategy", "/dst/unified", False, "-10_200", False) + mock_barrier.assert_not_called() + + @pytest.mark.level0 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + @pytest.mark.parametrize( + "config_remove_redundancy, metadata, expected_result", + [ + # test with metadata remove_redundancy=True and config remove_redundancy=False + (False, {"remove_redundancy": "True"}, True), + # test with metadata remove_redundancy=False and config remove_redundancy=True + (True, {"remove_redundancy": "False"}, False), + # test with matching metadata and config + (True, {"remove_redundancy": "True"}, True), + # test with no metadata + (True, None, True), + # test with metadata but no remove_redundancy key + (True, {"other_key": "value"}, True) + ] + ) + def test__revise_remove_redundancy_with_file(self, config_remove_redundancy, metadata, expected_result, mock_file): + """test _revise_remove_redundancy_with_file function""" + mock_file.metadata.return_value = metadata + result = _revise_remove_redundancy_with_file(config_remove_redundancy, mock_file) + assert result == expected_result + + @pytest.mark.level0 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + @pytest.mark.parametrize( + "network_has_convert_name, child_has_convert_name, expected_found", + [ + # test with network that has convert_name + (True, False, True), + # test with nested network where child has convert_name + (False, True, True), + # test with network that doesn't have convert_name and no children with it + (False, False, False) + ] + ) + def test__get_origin_network(self, network_has_convert_name, child_has_convert_name, expected_found): + """test _get_origin_network function""" + # setup mocks using context managers + with patch('mindformers.utils.load_checkpoint_utils.logger'): + if network_has_convert_name: + # create a mock network with convert_name attribute + mock_network = MagicMock() + mock_network.convert_name = MagicMock() + # Return empty list for cells() to avoid recursion + mock_network.cells.return_value = [] + else: + if child_has_convert_name: + # create a mock network without convert_name but with a child that has it + mock_child = MagicMock() + mock_child.convert_name = MagicMock() + # Return empty list for cells() to avoid further recursion + mock_child.cells.return_value = [] + + # Create a network that returns the child directly when cells() is called + mock_network = MagicMock() + mock_network.cells.return_value = [mock_child] + else: + # create a mock network without convert_name and no children with it + mock_network = MagicMock() + mock_network.cells.return_value = [] + + # Remove convert_name attribute to simulate network without it + if hasattr(mock_network, 'convert_name'): + delattr(mock_network, 'convert_name') + + # run the test + _, found = _get_origin_network(mock_network) + assert found == expected_found + + @pytest.mark.level0 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_get_load_path_after_hf_convert(self, mock_config, mock_network): + """test get_load_path_after_hf_convert function""" + # setup mocks using context managers + with patch('mindformers.utils.load_checkpoint_utils.is_hf_safetensors_dir') as mock_is_hf_safetensors_dir, \ + patch('mindformers.utils.load_checkpoint_utils.' + 'check_safetensors_addition_param_support') as mock_check_support: + # test when not hf safetensors + mock_is_hf_safetensors_dir.return_value = False + result = get_load_path_after_hf_convert(mock_config, mock_network) + assert result == "/path/to/checkpoint" + + # test when hf safetensors but not qkv_concat and not supported + mock_is_hf_safetensors_dir.return_value = True + mock_check_support.return_value = False + mock_config.model.model_config = {"qkv_concat": False} + + with patch('mindformers.utils.load_checkpoint_utils.process_hf_checkpoint', + return_value="/path/to/converted"): + with patch('mindformers.utils.load_checkpoint_utils.barrier'): + result = get_load_path_after_hf_convert(mock_config, mock_network) + assert result == "/path/to/converted" + + @pytest.mark.level0 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test__get_src_strategy(self, mock_config): + """test _get_src_strategy function""" + # setup mocks using context managers + with patch('os.path.isfile') as mock_isfile, \ + patch('os.path.isdir') as mock_isdir, \ + patch('os.path.join') as mock_join, \ + patch('os.path.exists') as mock_exists, \ + patch('os.path.dirname') as mock_dirname, \ + patch('mindformers.utils.load_checkpoint_utils.logger'): + # Test case 1: input_src_strategy is provided + mock_config.load_checkpoint = "/test/checkpoint.ckpt" + mock_config.src_strategy_path_or_dir = "/input/strategy" + mock_isdir.return_value = True + result = _get_src_strategy(mock_config) + assert result == "/input/strategy" + + # Test case 2: no strategy dir exists + mock_config.src_strategy_path_or_dir = None + mock_isfile.return_value = True + mock_exists.return_value = False + + with pytest.raises( + ValueError, + match="when use checkpoint after train/finetune, src_strategy_path_or_dir should be set" + ): + _get_src_strategy(mock_config) + + # Test case 3: config.load_checkpoint is a directory and strategy dir exists + mock_isfile.return_value = False + mock_exists.return_value = True + + # Setup mock_dirname to return a valid parent directory + mock_dirname.return_value = "/test" + + # Setup mock_join to return a valid path + mock_join.return_value = "/test/strategy" + + mock_config.load_checkpoint = "/test/checkpoint_dir" + mock_config.src_strategy_path_or_dir = None + + result = _get_src_strategy(mock_config) + assert result == "/test/strategy" + + @pytest.mark.level0 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test__get_src_file_suffix(self, mock_config): + """test _get_src_file_suffix function""" + # setup mocks using context managers + with patch('mindformers.utils.load_checkpoint_utils.is_main_rank') as mock_is_main_rank, \ + patch('mindformers.utils.load_checkpoint_utils.get_last_checkpoint') as mock_get_last_checkpoint, \ + patch('os.path.isfile') as mock_isfile, \ + patch('os.path.isdir') as mock_isdir: + # test when is_main_rank is True and resume_training is string + mock_is_main_rank.return_value = True + mock_config.resume_training = "checkpoint-10_200.safetensors" + mock_config.load_checkpoint = "/path/to/checkpoint" + mock_config.load_ckpt_format = "safetensors" + + with patch('mindformers.utils.load_checkpoint_utils.extract_suffix', return_value="-10_200"): + result = _get_src_file_suffix(mock_config) + assert result == ("/path/to/checkpoint", "-10_200") + + # test when is_main_rank is True and load_checkpoint is file + mock_isfile.return_value = True + mock_isdir.return_value = False + mock_config.resume_training = None + mock_config.load_checkpoint = "/path/to/rank_0/checkpoint-10_200.safetensors" + + with patch('mindformers.utils.load_checkpoint_utils.extract_suffix', return_value="-10_200"): + result = _get_src_file_suffix(mock_config) + assert result == ("/path/to", "-10_200") + + # test when is_main_rank is True and load_checkpoint is dir + mock_isfile.return_value = False + mock_isdir.return_value = True + mock_config.load_checkpoint = "/path/to/checkpoint" + mock_get_last_checkpoint.return_value = "/path/to/checkpoint/rank_0/checkpoint-10_200.safetensors" + + with patch('mindformers.utils.load_checkpoint_utils.extract_suffix', return_value="-10_200"): + result = _get_src_file_suffix(mock_config) + assert result == ("/path/to/checkpoint", "-10_200") + + # test when is_main_rank is False + mock_is_main_rank.return_value = False + result = _get_src_file_suffix(mock_config) + assert result == (None, None) + + @pytest.mark.level0 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test__get_src_file(self): + """test _get_src_file function""" + # setup mocks using context managers + with patch('os.path.exists') as mock_exists, \ + patch('os.path.join') as mock_join, \ + patch('mindformers.utils.load_checkpoint_utils.get_real_rank') as mock_get_real_rank, \ + patch('mindformers.utils.load_checkpoint_utils.get_last_checkpoint') as mock_get_last_checkpoint: + # test with checkpoint_name provided + mock_get_real_rank.return_value = 0 + mock_join.return_value = "/test/rank_0/checkpoint.ckpt" + mock_exists.return_value = True + + result = _get_src_file("/test", "checkpoint.ckpt", "ckpt") + assert result == "/test/rank_0/checkpoint.ckpt" + + # test without checkpoint_name + mock_get_last_checkpoint.return_value = "/test/rank_0/last_checkpoint.ckpt" + result = _get_src_file("/test", None, "ckpt") + assert result == "/test/rank_0/last_checkpoint.ckpt" + + # test with non-existent file + mock_exists.return_value = False + with pytest.raises(FileNotFoundError): + _get_src_file("/test", "non_existent.ckpt", "ckpt") + + @pytest.mark.level0 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_load_safetensors_checkpoint(self, mock_config, mock_network, mock_file): + """test load_safetensors_checkpoint function""" + # Setup mocks using context managers + with patch('mindformers.utils.load_checkpoint_utils._get_origin_network') as mock_get_origin_network, \ + patch('mindformers.utils.load_checkpoint_utils.ms') as mock_ms, \ + patch('mindformers.utils.load_checkpoint_utils.logger'), \ + patch('mindformers.utils.load_checkpoint_utils.safe_open') as mock_safe_open, \ + patch('mindformers.utils.load_checkpoint_utils.is_hf_safetensors_dir') as mock_is_hf_safetensors_dir: + # Setup mock return values + mock_get_origin_network.return_value = (MagicMock(), False) + mock_ms.load_checkpoint.return_value = {"param1": MagicMock()} + mock_is_hf_safetensors_dir.return_value = False + + # Mock the safe_open context manager + mock_safe_open.return_value.__enter__.return_value = mock_file + + strategy_path = "/path/to/strategy" + load_ckpt_path = "/path/to/checkpoint" + optimizer = None + + load_safetensors_checkpoint(mock_config, ["/path/to/checkpoint.safetensors"], mock_network, strategy_path, + load_ckpt_path, + optimizer) + mock_ms.load_param_into_net.assert_called_once() + + @pytest.mark.level0 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_process_hf_checkpoint(self, mock_model, tmp_path): + """test process_hf_checkpoint function""" + # setup mocks using context managers + with patch('mindformers.utils.load_checkpoint_utils.is_main_rank') as mock_is_main_rank, \ + patch('mindformers.utils.load_checkpoint_utils.barrier_world') as mock_barrier_world, \ + patch('mindformers.utils.load_checkpoint_utils.Process') as mock_process: + # test when is_main_rank is True + mock_is_main_rank.return_value = True + mock_process_instance = MagicMock() + mock_process_instance.exitcode = 0 + mock_process.return_value = mock_process_instance + + # Use tmp_path for output and input paths + output_dir = tmp_path / "output" / "dir" + input_checkpoint = tmp_path / "input" / "checkpoint" + # Create input directory + input_checkpoint.parent.mkdir(parents=True, exist_ok=True) + + result = process_hf_checkpoint(mock_model, str(output_dir), str(input_checkpoint)) + expected_path = str(output_dir / "test_model_ms_converted_weight") + assert result == expected_path + mock_process_instance.start.assert_called_once() + mock_process_instance.join.assert_called_once() + mock_barrier_world.assert_called_once() + + # Reset mocks for next test case + mock_process.reset_mock() + mock_process_instance = MagicMock() + mock_process_instance.exitcode = 1 + mock_process.return_value = mock_process_instance + + # test when process exits with error + with pytest.raises(RuntimeError, match="convert HuggingFace weight failed."): + process_hf_checkpoint(mock_model, str(output_dir), str(input_checkpoint)) + + # Reset mocks for next test case + mock_process.reset_mock() + mock_process_instance = MagicMock() + mock_process_instance.exitcode = 0 + mock_process.return_value = mock_process_instance + + # test when is_main_rank is False + mock_is_main_rank.return_value = False + process_hf_checkpoint(mock_model, str(output_dir), str(input_checkpoint)) + mock_process_instance.start.assert_not_called() + + @pytest.mark.level0 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + @pytest.mark.parametrize( + "model, qkv_concat_config, check_safetensors_key_return, " + "has_concat_keys, expected_exception, should_log_warning", + [ + # test with non-PreTrainedModel + ("not_a_model", False, False, False, None, True), + # test with PreTrainedModel but no concat keys + (MagicMock(spec=PreTrainedModel), False, False, False, None, False), + # Test case where check_safetensors_key returns True and qkv_concat_config is True + (MagicMock(spec=PreTrainedModel), True, True, True, None, False), + # Test case where check_safetensors_key returns False and qkv_concat_config is True + (MagicMock(spec=PreTrainedModel), True, False, True, ValueError, False), + # Test case where check_safetensors_key returns True and qkv_concat_config is False + (MagicMock(spec=PreTrainedModel), False, True, True, ValueError, False) + ] + ) + def test_validate_qkv_concat(self, model, qkv_concat_config, + check_safetensors_key_return, has_concat_keys, expected_exception, should_log_warning): + """test validate_qkv_concat function""" + # Setup mocks using context managers + with patch('mindformers.utils.load_checkpoint_utils.logger') as mock_logger, \ + patch('mindformers.utils.load_checkpoint_utils.check_safetensors_key') as mock_check_safetensors_key: + + # Setup mock behavior + mock_check_safetensors_key.return_value = check_safetensors_key_return + + # If it's a PreTrainedModel, set up obtain_qkv_ffn_concat_keys + if hasattr(model, 'obtain_qkv_ffn_concat_keys'): + model.obtain_qkv_ffn_concat_keys.return_value = ["qkv_concat_key"] if has_concat_keys else None + + # Run the test and check results + if expected_exception: + with pytest.raises(expected_exception, match="The qkv concat check failed!"): + validate_qkv_concat(model, qkv_concat_config, "/path/to/checkpoint") + else: + validate_qkv_concat(model, qkv_concat_config, "/path/to/checkpoint") + + # Check if warning was logged when expected + if should_log_warning: + mock_logger.warning.assert_called_once() + else: + mock_logger.warning.assert_not_called() + + @pytest.mark.level0 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_get_merged_src_strategy_path(self, mock_config): + """test get_merged_src_strategy_path function""" + # setup mocks using context managers + with patch('mindformers.utils.load_checkpoint_utils.is_main_rank') as mock_is_main_rank, \ + patch('mindformers.utils.load_checkpoint_utils.barrier') as mock_barrier, \ + patch('mindformers.utils.load_checkpoint_utils._get_src_strategy') as mock_get_src_strategy, \ + patch('mindformers.utils.load_checkpoint_utils.ms.merge_pipeline_strategys') as mock_merge_strategys, \ + patch('os.makedirs'): + # test when is_main_rank is True + mock_is_main_rank.return_value = True + mock_get_src_strategy.return_value = "/input/strategy" + + result = get_merged_src_strategy_path(mock_config) + assert result == "/output/merged_strategy/src_strategy.ckpt" + mock_merge_strategys.assert_called_once() + mock_barrier.assert_called_once() + + # test when is_main_rank is False + mock_is_main_rank.return_value = False + mock_barrier.reset_mock() + result = get_merged_src_strategy_path(mock_config) + mock_merge_strategys.assert_called_once() # should not be called again + mock_barrier.assert_called_once() + + @pytest.mark.level0 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_get_merged_dst_strategy_path(self, mock_config): + """test get_merged_dst_strategy_path function""" + # setup mocks using context managers + with patch('mindformers.utils.load_checkpoint_utils.is_main_rank') as mock_is_main_rank, \ + patch('mindformers.utils.load_checkpoint_utils.barrier') as mock_barrier, \ + patch('mindformers.utils.load_checkpoint_utils.ms.merge_pipeline_strategys') as mock_merge_strategys, \ + patch('os.makedirs'): + # test with use_parallel=True, auto_trans_ckpt=True, not stand_alone + mock_is_main_rank.return_value = True + + mock_config.use_parallel = True + mock_config.auto_trans_ckpt = True + mock_config.parallel.parallel_mode = "DATA_PARALLEL" + + strategy_path = "/path/to/strategy.ckpt" + + result = get_merged_dst_strategy_path(mock_config, strategy_path) + assert result == "/output/merged_strategy/dst_strategy.ckpt" + mock_merge_strategys.assert_called_once() + mock_barrier.assert_called_once() + + # test with stand_alone mode + mock_config.parallel.parallel_mode = "STAND_ALONE" + result = get_merged_dst_strategy_path(mock_config, strategy_path) + assert result == "/path/to/strategy.ckpt" + + # test with use_parallel=False + mock_config.use_parallel = False + result = get_merged_dst_strategy_path(mock_config, strategy_path) + assert result == "/path/to/strategy.ckpt" + + @pytest.mark.level0 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_process_for_stand_alone_mode(self, mock_config, mock_network): + """test process_for_stand_alone_mode function""" + strategy_path = "/path/to/strategy.ckpt" + + # setup mocks using context managers + with patch('mindformers.utils.load_checkpoint_utils._pynative_executor'), \ + patch('mindformers.utils.load_checkpoint_utils.is_main_rank') as mock_is_main_rank, \ + patch('mindformers.utils.load_checkpoint_utils.barrier') as mock_barrier, \ + patch('mindformers.utils.load_checkpoint_utils.generate_state_dict') as mock_generate_state_dict, \ + patch('mindformers.utils.load_checkpoint_utils.save_strategy_file') as mock_save_strategy_file, \ + patch('os.makedirs') as mock_makedirs, \ + patch('shutil.rmtree') as mock_rmtree, \ + patch('os.path.exists') as mock_exists: + # test with stand_alone mode + mock_is_main_rank.return_value = True + mock_exists.return_value = True + mock_config.parallel.parallel_mode = "STAND_ALONE" + mock_config.use_parallel = True + + process_for_stand_alone_mode(mock_config, mock_network, strategy_path) + mock_rmtree.assert_called_once() + mock_makedirs.assert_called_once() + mock_generate_state_dict.assert_called_once() + mock_save_strategy_file.assert_called_once() + mock_barrier.assert_called() + + # Reset mocks for next test case + mock_barrier.reset_mock() + mock_rmtree.reset_mock() + mock_makedirs.reset_mock() + mock_generate_state_dict.reset_mock() + mock_save_strategy_file.reset_mock() + + # test when strategy dir doesn't exist + mock_exists.return_value = False + process_for_stand_alone_mode(mock_config, mock_network, strategy_path) + mock_rmtree.assert_not_called() + + # Reset mocks for next test case + mock_barrier.reset_mock() + mock_rmtree.reset_mock() + mock_makedirs.reset_mock() + mock_generate_state_dict.reset_mock() + mock_save_strategy_file.reset_mock() + + # test when not stand_alone mode + mock_config.parallel.parallel_mode = "DATA_PARALLEL" + process_for_stand_alone_mode(mock_config, mock_network, strategy_path) + mock_rmtree.assert_not_called() + mock_barrier.assert_not_called() + + @pytest.mark.level0 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_load_checkpoint_with_safetensors(self, mock_config, mock_model, mock_network): + """test load_checkpoint_with_safetensors function""" + # setup mocks using context managers + with patch('mindformers.utils.load_checkpoint_utils._check_checkpoint_path') as mock_check_checkpoint_path, \ + patch('mindformers.utils.load_checkpoint_utils._get_checkpoint_mode') as mock_get_checkpoint_mode, \ + patch('mindformers.utils.load_checkpoint_utils.' + 'validate_config_with_file_mode') as mock_validate_config_with_file_mode, \ + patch('mindformers.utils.load_checkpoint_utils.compile_model') as mock_compile_model, \ + patch('mindformers.utils.load_checkpoint_utils.validate_qkv_concat'), \ + patch('mindformers.utils.load_checkpoint_utils.process_for_stand_alone_mode'), \ + patch('mindformers.utils.load_checkpoint_utils.' + 'get_merged_dst_strategy_path') as mock_get_merged_dst_strategy_path, \ + patch('mindformers.utils.load_checkpoint_utils.' + 'load_safetensors_checkpoint') as mock_load_safetensors_checkpoint, \ + patch('mindformers.utils.load_checkpoint_utils.logger'), \ + patch('mindformers.utils.load_checkpoint_utils.barrier'): + # setup mocks return values + mock_check_checkpoint_path.return_value = "/valid/checkpoint" + mock_get_checkpoint_mode.return_value = CheckpointFileMode.SINGLE_CHECKPOINT_FILE.value + mock_get_merged_dst_strategy_path.return_value = "/path/to/merged/strategy" + + # setup input_data and optimizer + input_data = MagicMock() + optimizer = None + + # test with do_eval=True + load_checkpoint_with_safetensors(mock_config, mock_model, mock_network, input_data, do_eval=True, + do_predict=False, + optimizer=optimizer) + mock_check_checkpoint_path.assert_called_once() + mock_get_checkpoint_mode.assert_called_once() + mock_validate_config_with_file_mode.assert_called_once() + mock_load_safetensors_checkpoint.assert_called_once() + + # test with do_predict=True + mock_load_safetensors_checkpoint.reset_mock() + load_checkpoint_with_safetensors(mock_config, mock_model, mock_network, input_data, do_eval=False, + do_predict=True, + optimizer=optimizer) + mock_load_safetensors_checkpoint.assert_called_once() + + # test with use_parallel=True + mock_config.use_parallel = True + mock_load_safetensors_checkpoint.reset_mock() + mock_compile_model.reset_mock() + load_checkpoint_with_safetensors(mock_config, mock_model, mock_network, input_data, do_eval=False, + do_predict=False, + optimizer=optimizer) + mock_compile_model.assert_called_once() + mock_load_safetensors_checkpoint.assert_called_once() + + # test with resume_training=True + mock_config.resume_training = True + # Access protected member for testing purposes + # pylint: disable=W0212 + mock_model._train_network = MagicMock() + mock_load_safetensors_checkpoint.reset_mock() + load_checkpoint_with_safetensors(mock_config, mock_model, mock_network, input_data, do_eval=False, + do_predict=False, + optimizer=optimizer) + mock_load_safetensors_checkpoint.assert_called_once() class TestBuildModel: """A test class for testing build_model""" runner_config = {'sink_mode': True, 'epochs': 1, 'sink_size': 1} - config = {'runner_config': runner_config} + config = { + 'runner_config': runner_config, + 'context': {'mode': 0} # Add context.mode to fix AttributeError + } model = MagicMock() dataset = MagicMock() @@ -124,59 +886,75 @@ class TestBuildModel: class TestGetCheckpointMode: """A test class for testing get_checkpoint_mode""" - @patch('os.path.isfile') - @patch('os.path.isdir') - def test_single_checkpoint_file(self, mock_isdir, mock_isfile): + + @pytest.mark.level0 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_single_checkpoint_file(self): """test single checkpoint file""" - mock_isfile.return_value = True - mock_isdir.return_value = False - config = type('', (), {})() - config.load_checkpoint = '/test/checkpoint_file.safetensors' - assert _get_checkpoint_mode(config) == CheckpointFileMode.SINGLE_CHECKPOINT_FILE.value + with patch('os.path.isfile') as mock_isfile, \ + patch('os.path.isdir') as mock_isdir: + mock_isfile.return_value = True + mock_isdir.return_value = False + config = type('', (), {})() + config.load_checkpoint = '/test/checkpoint_file.safetensors' + assert _get_checkpoint_mode(config) == CheckpointFileMode.SINGLE_CHECKPOINT_FILE.value - @patch('os.path.isfile') - @patch('os.path.isdir') - def test_multi_checkpoint_file_with_rank_id(self, mock_isdir, mock_isfile): + @pytest.mark.level0 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_multi_checkpoint_file_with_rank_id(self): """test multi checkpoint file with rank id""" - mock_isfile.return_value = False - mock_isdir.return_value = True - with patch('os.listdir', return_value=['rank_0']): - config = type('', (), {})() - config.load_checkpoint = '/test/checkpoint_dir/' - assert _get_checkpoint_mode(config) == CheckpointFileMode.MULTI_CHECKPOINT_FILE_WITH_RANK_ID.value + with patch('os.path.isfile') as mock_isfile, \ + patch('os.path.isdir') as mock_isdir: + mock_isfile.return_value = False + mock_isdir.return_value = True + with patch('os.listdir', return_value=['rank_0']): + config = type('', (), {})() + config.load_checkpoint = '/test/checkpoint_dir/' + assert _get_checkpoint_mode(config) == CheckpointFileMode.MULTI_CHECKPOINT_FILE_WITH_RANK_ID.value - @patch('os.path.isfile') - @patch('os.path.isdir') - def test_multi_checkpoint_file(self, mock_isdir, mock_isfile): + @pytest.mark.level0 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_multi_checkpoint_file(self): """ test multi checkpoint file""" - mock_isfile.return_value = False - mock_isdir.return_value = True - with patch('os.listdir', return_value=['checkpoint.safetensors']): - config = type('', (), {})() - config.load_checkpoint = '/test/checkpoint_dir/' - config.load_ckpt_format = '.safetensors' - assert _get_checkpoint_mode(config) == CheckpointFileMode.MULTI_CHECKPOINT_FILE.value + with patch('os.path.isfile') as mock_isfile, \ + patch('os.path.isdir') as mock_isdir: + mock_isfile.return_value = False + mock_isdir.return_value = True + with patch('os.listdir', return_value=['checkpoint.safetensors']): + config = type('', (), {})() + config.load_checkpoint = '/test/checkpoint_dir/' + config.load_ckpt_format = '.safetensors' + assert _get_checkpoint_mode(config) == CheckpointFileMode.MULTI_CHECKPOINT_FILE.value - @patch('os.path.isfile') - @patch('os.path.isdir') - def test_invalid_path(self, mock_isdir, mock_isfile): + @pytest.mark.level0 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_invalid_path(self): """test invalid path""" - mock_isfile.return_value = False - mock_isdir.return_value = False - config = type('', (), {})() - config.load_checkpoint = 'invalid_path' - with pytest.raises(ValueError, match="Provided path is neither a file nor a directory."): - _get_checkpoint_mode(config) - - @patch('os.path.isfile') - @patch('os.path.isdir') - def test_no_valid_checkpoint_files(self, mock_isdir, mock_isfile): - """test no valid checkpoint files""" - mock_isfile.return_value = False - mock_isdir.return_value = True - with patch('os.listdir', return_value=['not_a_checkpoint_file']): + with patch('os.path.isfile') as mock_isfile, \ + patch('os.path.isdir') as mock_isdir: + mock_isfile.return_value = False + mock_isdir.return_value = False config = type('', (), {})() - config.load_checkpoint = '/test/checkpoint_dir/' - config.load_ckpt_format = '.safetensors' - with pytest.raises(ValueError, match="not support mode: no valid checkpoint files found"): + config.load_checkpoint = 'invalid_path' + with pytest.raises(ValueError, match="Provided path is neither a file nor a directory."): _get_checkpoint_mode(config) + + @pytest.mark.level0 + @pytest.mark.platform_x86_cpu + @pytest.mark.env_onecard + def test_no_valid_checkpoint_files(self): + """test no valid checkpoint files""" + with patch('os.path.isfile') as mock_isfile, \ + patch('os.path.isdir') as mock_isdir: + mock_isfile.return_value = False + mock_isdir.return_value = True + with patch('os.listdir', return_value=['not_a_checkpoint_file']): + config = type('', (), {})() + config.load_checkpoint = '/test/checkpoint_dir/' + config.load_ckpt_format = '.safetensors' + with pytest.raises(ValueError, match="not support mode: no valid checkpoint files found"): + _get_checkpoint_mode(config) -- Gitee