From c9f21c7f6632abc15304f909278a563d9d9f5e55 Mon Sep 17 00:00:00 2001 From: JingweiHuang Date: Tue, 25 Nov 2025 11:41:23 +0800 Subject: [PATCH] Add unit tests to context --- mindformers/core/context/validators.py | 18 +- .../test_context/test_build_context.py | 300 +++++++++- .../test_core/test_context/test_parallel.py | 521 ++++++++++++++++++ .../test_core/test_context/test_validators.py | 171 +++++- 4 files changed, 999 insertions(+), 11 deletions(-) create mode 100644 tests/st/test_ut/test_core/test_context/test_parallel.py diff --git a/mindformers/core/context/validators.py b/mindformers/core/context/validators.py index d0ae956e5..5ef5be17d 100644 --- a/mindformers/core/context/validators.py +++ b/mindformers/core/context/validators.py @@ -69,17 +69,23 @@ def validate_sink_size(config): def validate_precision_sync(config): - """Validate train_percision_sync and infer_percision_sync.""" + """ + Validate train_precision_sync and infer_precision_sync configuration values. + Args: + config (MindFormerConfig): Configuration object containing precision sync settings + Raises: + ValueError: If train_precision_sync or infer_precision_sync are not boolean values + """ train_precision_sync = config.get_value('train_precision_sync') - infer_percision_sync = config.get_value('train_precision_sync') + infer_precision_sync = config.get_value('infer_precision_sync') if train_precision_sync is not None and not isinstance( train_precision_sync, bool): raise ValueError( - f'train_percision_sync should be bool, got {train_precision_sync}') - if infer_percision_sync is not None and not isinstance( - infer_percision_sync, bool): + f'train_precision_sync should be bool, got {train_precision_sync}') + if infer_precision_sync is not None and not isinstance( + infer_precision_sync, bool): raise ValueError( - f'train_percision_sync should be bool, got {infer_percision_sync}') + f'infer_precision_sync should be bool, got {infer_precision_sync}') def validate_invalid_predict_mode(config): diff --git a/tests/st/test_ut/test_core/test_context/test_build_context.py b/tests/st/test_ut/test_core/test_context/test_build_context.py index e97d40af5..71595195f 100644 --- a/tests/st/test_ut/test_core/test_context/test_build_context.py +++ b/tests/st/test_ut/test_core/test_context/test_build_context.py @@ -15,20 +15,32 @@ """Test build_context.py""" import multiprocessing import os +from unittest.mock import patch + import pytest from mindformers.core.context import ( build_context, build_mf_context, + build_parallel_context, get_context, is_legacy_model, set_context, ) -from mindformers.core.context.build_context import MFContextOperator +from mindformers.core.context.build_context import ( + Context, + MFContextOperator, + MSContextOperator, + set_cpu_affinity, + set_ms_affinity, +) +from mindformers.tools.register import MindFormerConfig + def get_config_tpl(): return {'context': {'mode': 'PYNATIVE_MODE'}, 'parallel': {}} + def run_in_subprocess(func, *args): """Run testcase in subprocess and check it is successfully.""" process = multiprocessing.Process(target=func, args=args) @@ -65,6 +77,7 @@ def run_deterministic_setting( assert os.getenv('CUSTOM_MATMUL_SHUFFLE') == custom_matmul_shuffle_expect assert os.getenv('LCCL_DETERMINISTIC') == lccl_deterministic_expect + @pytest.mark.level1 @pytest.mark.platform_x86_cpu @pytest.mark.parametrize( @@ -97,9 +110,10 @@ def test_deterministic(mode, switch, hccl_deterministic_env, hccl_deterministic_expect, te_parallel_compiler_expect, custom_matmul_shuffle_expect, lccl_deterministic_expect) + @pytest.mark.level1 @pytest.mark.platform_x86_cpu -def test_build_mf_context(): +def test_mf_context_singleton(): """ Feature: Test whether MFContextOperator is a singleton. Description: The MFContextOperator instance created twice is the same object. @@ -110,6 +124,23 @@ def test_build_mf_context(): another_mf_ctx = build_mf_context(config_tpl) assert mf_ctx is another_mf_ctx + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +def test_context_singleton(): + """ + Feature: Test Context singleton pattern. + Description: Test that Context is a singleton. + Expectation: Multiple Context instances are the same object. + """ + def is_singleton_context(): + config_tpl = get_config_tpl() + ctx1 = build_context(config_tpl) + ctx2 = build_context(config_tpl) + assert ctx1 is ctx2 + run_in_subprocess(is_singleton_context) + + @pytest.mark.level1 @pytest.mark.platform_x86_cpu @pytest.mark.parametrize( @@ -126,7 +157,268 @@ def test_get_use_legacy(cfg, is_legacy_model_except): Expectation: The result of execution does not equal the expected result. """ build_mf_context(cfg) - assert is_legacy_model() == is_legacy_model_except - MFContextOperator.reset_instance() + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +def test_context_set_mf_ctx_run_mode(): + """ + Feature: Test Context.set_mf_ctx_run_mode method. + Description: Test setting run_mode with valid and invalid values. + Expectation: Valid run_mode is set, invalid run_mode raises ValueError. + """ + Context.reset_instance() + config_tpl = get_config_tpl() + ctx = build_context(config_tpl) + + # Test valid run_mode + ctx.set_mf_ctx_run_mode('train') + assert ctx.mf_ctx_opr.run_mode == 'train' + + # Test invalid run_mode + with pytest.raises(ValueError) as exc_info: + ctx.set_mf_ctx_run_mode('invalid_mode') + assert 'Invalid value' in str(exc_info.value) + + # Test None run_mode + ctx.set_mf_ctx_run_mode(None) + Context.reset_instance() + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +def test_ms_context_operator_set_save_graphs_path(): + """ + Feature: Test MSContextOperator._set_save_graphs_path. + Description: Test save_graphs_path setting. + Expectation: save_graphs_path is set when save_graphs is True. + """ + config = MindFormerConfig( + context={'save_graphs': True, 'save_graphs_path': '/tmp/graphs'}, + parallel={} + ) + operator = MSContextOperator(config) + assert operator.get_context('save_graphs_path') == '/tmp/graphs' + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +def test_ms_context_operator_predict_jit_config_o1(): + """ + Feature: Test MSContextOperator._set_predict_jit_config with O1. + Description: Test that O1 jit_level raises ValueError in predict mode. + Expectation: ValueError is raised. + """ + config = MindFormerConfig( + run_mode='predict', + context={'jit_level': 'O1'}, + parallel={} + ) + with pytest.raises(ValueError) as exc_info: + MSContextOperator(config) + assert 'O1 is not supported' in str(exc_info.value) + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +def test_ms_context_operator_predict_jit_config_o2_with_boost(): + """ + Feature: Test MSContextOperator._set_predict_jit_config with O2 and boost. + Description: Test that O2 with infer_boost=on raises ValueError. + Expectation: ValueError is raised. + """ + config = MindFormerConfig( + run_mode='predict', + context={'jit_level': 'O2', 'infer_boost': 'on'}, + parallel={} + ) + with pytest.raises(ValueError) as exc_info: + MSContextOperator(config) + assert 'infer_boost must set off' in str(exc_info.value) + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +def test_ms_context_operator_predict_jit_config_o2_without_boost(): + """ + Feature: Test MSContextOperator._set_predict_jit_config with O2 without boost. + Description: Test that O2 with infer_boost=off works. + Expectation: jit_config is set correctly. + """ + def is_ms_context_operator_predict_jit_config(): + config = MindFormerConfig( + run_mode='predict', + context={'jit_level': 'O2', 'infer_boost': 'off'}, + parallel={} + ) + operator = MSContextOperator(config) + assert operator.get_context("jit_level") == "O2" + assert operator.get_context("infer_boost") == "off" + run_in_subprocess(is_ms_context_operator_predict_jit_config) + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +def test_ms_context_operator_predict_jit_config_from_jit_config(): + """ + Feature: Test MSContextOperator._set_predict_jit_config with jit_config dict. + Description: Test that jit_config dict is used. + Expectation: jit_config values are taken from jit_config dict. + """ + config = MindFormerConfig( + run_mode='predict', + context={ + 'jit_level': 'O0', + 'infer_boost': 'on', + 'jit_config': {'jit_level': 'O2', 'infer_boost': 'off'} + }, + parallel={} + ) + operator = MSContextOperator(config) + assert operator.get_context("jit_level") == "O2" + assert operator.get_context("infer_boost") == "off" + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +def test_set_context_without_build(): + """ + Feature: Test set_context without building context first. + Description: Test that set_context raises RuntimeError when Context doesn't exist. + Expectation: RuntimeError is raised. + """ + Context.reset_instance() + with pytest.raises(RuntimeError) as exc_info: + set_context(run_mode='train') + assert 'Build a Context instance' in str(exc_info.value) + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +def test_get_context_without_build(): + """ + Feature: Test get_context without building context first. + Description: Test that get_context raises RuntimeError when Context doesn't exist. + Expectation: RuntimeError is raised. + """ + Context.reset_instance() + with pytest.raises(RuntimeError) as exc_info: + get_context('mode') + assert 'Build a Context instance' in str(exc_info.value) + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +def test_build_parallel_context(): + """ + Feature: Test build_parallel_context function. + Description: Test building parallel context. + Expectation: ParallelOperator is returned. + """ + config_tpl = get_config_tpl() + parallel_opr = build_parallel_context(config_tpl) + assert parallel_opr is not None + assert hasattr(parallel_opr, 'parallel_ctx') + assert hasattr(parallel_opr, 'parallel') + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +@patch('mindformers.core.context.build_context.get_real_local_rank') +@patch('mindformers.core.context.build_context.ms.runtime.set_cpu_affinity') +def test_set_ms_affinity_with_affinity_config(mock_set_affinity, mock_rank): + """ + Feature: Test set_ms_affinity with affinity_config. + Description: Verify affinity_config overrides affinity_cpu_list and passes module config. + Expectation: MindSpore set_cpu_affinity called with config values. + """ + mock_rank.return_value = 1 + affinity_config = { + 'device_1': { + 'affinity_cpu_list': [0, 1], + 'module_to_cpu_dict': {'module_a': [2, 3]} + } + } + set_ms_affinity(affinity_config, [4, 5]) + mock_set_affinity.assert_called_once_with( + True, + [0, 1], + {'module_a': [2, 3]} + ) + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +@patch('mindformers.core.context.build_context.get_real_local_rank') +@patch('mindformers.core.context.build_context.ms.runtime.set_cpu_affinity') +def test_set_ms_affinity_without_device_entry(mock_set_affinity, mock_rank): + """ + Feature: Test set_ms_affinity when device entry missing. + Description: Verify defaults are used when affinity_config lacks device info. + Expectation: MindSpore set_cpu_affinity called with None values. + """ + mock_rank.return_value = 0 + affinity_config = { + 'device_1': { + 'affinity_cpu_list': [4, 5], + 'module_to_cpu_dict': {'module_a': [6, 7]} + } + } + set_ms_affinity(affinity_config, None) + mock_set_affinity.assert_called_once_with( + True, + None, + None + ) + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +@patch('mindformers.core.context.build_context.get_cann_workqueue_cores', return_value=[0, 1]) +@patch('mindformers.core.context.build_context.psutil.Process') +@patch('mindformers.core.context.build_context.psutil.cpu_count', return_value=8) +@patch('mindformers.core.context.build_context.ds.config.set_numa_enable') +def test_set_cpu_affinity_bind_available_cpus(mock_set_numa, mock_cpu_count, + mock_process_cls, mock_get_cores, + monkeypatch): + """ + Feature: Test set_cpu_affinity binding behavior. + Description: Verify CPU affinity excludes CANN workqueue cores when available. + Expectation: Process cpu_affinity receives filtered CPU list. + """ + monkeypatch.setenv('CPU_AFFINITY', 'True') + process_mock = mock_process_cls.return_value + + set_cpu_affinity(rank_id=0, rank_size=2) + + mock_set_numa.assert_called_once_with(True) + mock_cpu_count.assert_called_once() + mock_get_cores.assert_called_once_with(0) + process_mock.cpu_affinity.assert_called_once_with([2, 3]) + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +@patch('mindformers.core.context.build_context.get_cann_workqueue_cores', return_value=[0, 1, 2, 3]) +@patch('mindformers.core.context.build_context.psutil.Process') +@patch('mindformers.core.context.build_context.psutil.cpu_count', return_value=8) +@patch('mindformers.core.context.build_context.ds.config.set_numa_enable') +def test_set_cpu_affinity_fallback_when_all_cores_taken(mock_set_numa, mock_cpu_count, + mock_process_cls, mock_get_cores, + monkeypatch): + """ + Feature: Test set_cpu_affinity fallback behavior. + Description: Verify original CPU list is used when CANN occupies all candidate cores. + Expectation: Process cpu_affinity receives unfiltered CPU list. + """ + monkeypatch.setenv('CPU_AFFINITY', 'True') + process_mock = mock_process_cls.return_value + + set_cpu_affinity(rank_id=0, rank_size=2) + + mock_set_numa.assert_called_once_with(True) + mock_cpu_count.assert_called_once() + mock_get_cores.assert_called_once_with(0) + process_mock.cpu_affinity.assert_called_once_with([0, 1, 2, 3]) diff --git a/tests/st/test_ut/test_core/test_context/test_parallel.py b/tests/st/test_ut/test_core/test_context/test_parallel.py new file mode 100644 index 000000000..163033eba --- /dev/null +++ b/tests/st/test_ut/test_core/test_context/test_parallel.py @@ -0,0 +1,521 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test parallel.py""" +# pylint: disable=protected-access +from unittest.mock import patch + +import pytest +import mindspore.context as ms_context + +from mindformers.core.context.parallel import ParallelOperator +from mindformers.tools.register import MindFormerConfig +from mindformers.modules.transformer.transformer import TransformerOpParallelConfig + + +@pytest.fixture(name="mock_config") +def fixture_mock_config(): + """Create a mock config for testing.""" + config = MindFormerConfig( + use_parallel=False, + parallel={}, + parallel_config={} + ) + return config + + +@pytest.fixture(name="mock_config_with_parallel") +def fixture_mock_config_with_parallel(): + """Create a mock config with parallel enabled.""" + config = MindFormerConfig( + use_parallel=True, + parallel={ + 'parallel_mode': 'semi_auto_parallel', + 'full_batch': True + }, + parallel_config={ + 'data_parallel': 2, + 'model_parallel': 2, + 'pipeline_stage': 1 + } + ) + return config + + +class TestParallelOperator: + """Test ParallelOperator class.""" + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_init_without_parallel(self, mock_config): + """ + Feature: Test ParallelOperator initialization without parallel. + Description: Test that ParallelOperator can be initialized with use_parallel=False. + Expectation: ParallelOperator is initialized successfully. + """ + operator = ParallelOperator(mock_config) + assert operator.parallel_ctx is not None + assert operator.parallel is not None + assert not hasattr(operator, 'config') + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_init_with_parallel(self, mock_config_with_parallel): + """ + Feature: Test ParallelOperator initialization with parallel. + Description: Test that ParallelOperator can be initialized with use_parallel=True. + Expectation: ParallelOperator is initialized successfully. + """ + operator = ParallelOperator(mock_config_with_parallel) + assert operator.parallel_ctx is not None + assert operator.parallel is not None + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_set_pipeline_stage(self): + """ + Feature: Test _set_pipeline_stage method. + Description: Test pipeline stage setting logic. + Expectation: Pipeline stage is set correctly. + """ + config = MindFormerConfig( + use_parallel=True, + parallel={ + 'auto_pipeline': False + }, + parallel_config={ + 'pipeline_stage': 2 + } + ) + operator = ParallelOperator(config) + assert operator.parallel.pipeline_stage == 2 + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_set_pipeline_stage_with_auto_pipeline(self): + """ + Feature: Test _set_pipeline_stage with auto_pipeline=True. + Description: Test that auto_pipeline raises ValueError. + Expectation: ValueError is raised when auto_pipeline is True. + """ + config = MindFormerConfig( + use_parallel=True, + parallel={ + 'auto_pipeline': True + }, + parallel_config={ + 'pipeline_stage': 2 + } + ) + with pytest.raises(ValueError) as exc_info: + ParallelOperator(config) + assert "Automatic pipeline stage is unavailable" in str(exc_info.value) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_set_pipeline_stage_single_stage(self): + """ + Feature: Test _set_pipeline_stage with single stage. + Description: Test pipeline stage with value 1. + Expectation: Pipeline stages is not set when final_stages <= 1. + """ + config = MindFormerConfig( + use_parallel=True, + parallel={ + 'auto_pipeline': False + }, + parallel_config={ + 'pipeline_stage': 1 + } + ) + operator = ParallelOperator(config) + assert operator.parallel.pipeline_stage == 1 + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_set_pipeline_stage_none(self): + """ + Feature: Test _set_pipeline_stage with None pipeline_stage. + Description: Test pipeline stage with None value. + Expectation: Pipeline stage defaults to 1. + """ + config = MindFormerConfig( + use_parallel=True, + parallel={ + 'auto_pipeline': False + }, + parallel_config={} + ) + operator = ParallelOperator(config) + assert operator.parallel.pipeline_stage == 1 + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_get_parallel_ctx_config_with_full_batch(self): + """ + Feature: Test _get_parallel_ctx_config with full_batch. + Description: Test that full_batch is set to False for non-parallel modes. + Expectation: full_batch is False when parallel_mode is not SEMI_AUTO_PARALLEL or AUTO_PARALLEL. + """ + config = MindFormerConfig( + use_parallel=False, + parallel={ + 'parallel_mode': 'stand_alone', + 'full_batch': True + }, + parallel_config={} + ) + operator = ParallelOperator(config) + assert operator.parallel_ctx.get('full_batch') is False + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_get_parallel_ctx_config_with_semi_auto_parallel(self): + """ + Feature: Test _get_parallel_ctx_config with SEMI_AUTO_PARALLEL. + Description: Test that full_batch is preserved for SEMI_AUTO_PARALLEL mode. + Expectation: full_batch remains True for SEMI_AUTO_PARALLEL mode. + """ + config = MindFormerConfig( + use_parallel=False, + parallel={ + 'parallel_mode': ms_context.ParallelMode.SEMI_AUTO_PARALLEL, + 'full_batch': True + }, + parallel_config={} + ) + operator = ParallelOperator(config) + assert operator.parallel_ctx.get('full_batch') is True + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_get_parallel_ctx_config_with_auto_parallel(self): + """ + Feature: Test _get_parallel_ctx_config with AUTO_PARALLEL. + Description: Test that full_batch is preserved for AUTO_PARALLEL mode. + Expectation: full_batch remains True for AUTO_PARALLEL mode. + """ + config = MindFormerConfig( + use_parallel=False, + parallel={ + 'parallel_mode': ms_context.ParallelMode.AUTO_PARALLEL, + 'full_batch': True + }, + parallel_config={} + ) + operator = ParallelOperator(config) + assert operator.parallel_ctx.get('full_batch') is True + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_get_parallel_ctx_config_without_full_batch(self): + """ + Feature: Test _get_parallel_ctx_config without full_batch. + Description: Test that function works without full_batch key. + Expectation: Function executes successfully. + """ + config = MindFormerConfig( + use_parallel=False, + parallel={ + 'parallel_mode': 'stand_alone' + }, + parallel_config={} + ) + operator = ParallelOperator(config) + assert 'full_batch' not in operator.parallel_ctx or operator.parallel_ctx.get('full_batch') is None + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_get_parallel_config_dict(self): + """ + Feature: Test _get_parallel_config with dict. + Description: Test that dict parallel_config is handled correctly. + Expectation: ParallelConfig is created from dict. + """ + config = MindFormerConfig( + use_parallel=False, + parallel={}, + parallel_config={ + 'data_parallel': 2, + 'model_parallel': 4 + } + ) + operator = ParallelOperator(config) + assert operator.parallel.data_parallel == 2 + assert operator.parallel.model_parallel == 4 + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_get_parallel_config_transformer_op(self): + """ + Feature: Test _get_parallel_config with TransformerOpParallelConfig. + Description: Test that TransformerOpParallelConfig is converted to dict. + Expectation: ParallelConfig is created from TransformerOpParallelConfig. + """ + transformer_config = TransformerOpParallelConfig( + data_parallel=2, + model_parallel=4 + ) + config = MindFormerConfig( + use_parallel=False, + parallel={}, + parallel_config=transformer_config + ) + operator = ParallelOperator(config) + assert operator.parallel.data_parallel == 2 + assert operator.parallel.model_parallel == 4 + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @patch('mindformers.core.context.parallel.init') + @patch('mindformers.core.context.parallel.get_group_size') + @patch('mindformers.core.context.parallel.get_rank') + @patch('mindformers.core.context.parallel.context.reset_auto_parallel_context') + @patch('mindformers.core.context.parallel.context.get_auto_parallel_context') + def test_init_communication_success( + self, mock_get_context, mock_reset_context, mock_get_rank, mock_get_group_size, mock_init): + """ + Feature: Test init_communication method success. + Description: Test that communication is initialized successfully. + Expectation: Communication is initialized and rank/device_num are returned. + """ + mock_get_rank.return_value = 0 + mock_get_group_size.return_value = 8 + mock_get_context.return_value = "semi_auto_parallel" + + config = MindFormerConfig( + use_parallel=True, + parallel={ + 'parallel_mode': 'semi_auto_parallel' + }, + parallel_config={} + ) + operator = ParallelOperator(config) + rank, device_num = operator.init_communication() + + assert rank == 0 + assert device_num == 8 + mock_init.assert_called_once() + mock_get_group_size.assert_called_once() + mock_reset_context.assert_called_once() + assert operator.parallel_ctx['device_num'] == 8 + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @patch('mindformers.core.context.parallel.init') + def test_init_communication_failure(self, mock_init): + """ + Feature: Test init_communication method failure. + Description: Test that communication initialization failure is handled. + Expectation: Exception is raised with appropriate error message. + """ + mock_init.side_effect = Exception("Communication init failed") + + config = MindFormerConfig( + use_parallel=True, + parallel={}, + parallel_config={} + ) + operator = ParallelOperator(config) + + with pytest.raises(Exception) as exc_info: + operator.init_communication() + assert "Communication init failed" in str(exc_info.value) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @patch('mindformers.core.context.parallel.context.set_auto_parallel_context') + def test_set_ms_auto_parallel_context_with_full_batch_false(self, mock_set_context): + """ + Feature: Test _set_ms_auto_parallel_context with full_batch=False. + Description: Test dataset_strategy conversion from list to tuple. + Expectation: dataset_strategy is converted to tuple when full_batch is False. + """ + parallel_ctx = { + 'full_batch': False, + 'dataset_strategy': [[1, 2], [1, 4]] + } + ParallelOperator._set_ms_auto_parallel_context(**parallel_ctx) + mock_set_context.assert_called_once() + call_args = mock_set_context.call_args[1] + assert isinstance(call_args['dataset_strategy'], tuple) + assert call_args['dataset_strategy'] == ((1, 2), (1, 4)) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @patch('mindformers.core.context.parallel.context.set_auto_parallel_context') + def test_set_ms_auto_parallel_context_with_non_list_dataset_strategy(self, mock_set_context): + """ + Feature: Test _set_ms_auto_parallel_context with non-list dataset_strategy. + Description: Test that non-list dataset_strategy is not converted. + Expectation: Non-list dataset_strategy remains unchanged. + """ + parallel_ctx = { + 'full_batch': False, + 'dataset_strategy': "data_parallel" + } + ParallelOperator._set_ms_auto_parallel_context(**parallel_ctx) + mock_set_context.assert_called_once() + call_args = mock_set_context.call_args[1] + assert call_args.get('dataset_strategy') is None + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @patch('mindformers.core.context.parallel.context.get_auto_parallel_context') + @patch('mindformers.core.context.parallel.set_algo_parameters') + @patch('mindformers.core.context.parallel._set_multi_subgraphs') + def test_set_ms_parallel_auto_parallel(self, mock_set_multi_subgraphs, + mock_set_algo_parameters, mock_get_context): + """ + Feature: Test _set_ms_parallel with auto_parallel mode. + Description: Test algo parameters for auto_parallel mode. + Expectation: elementwise_op_strategy_follow=False and fully_use_devices=False. + """ + mock_get_context.return_value = "auto_parallel" + ParallelOperator._set_ms_parallel() + mock_set_algo_parameters.assert_called_once_with( + elementwise_op_strategy_follow=False, + fully_use_devices=False + ) + mock_set_multi_subgraphs.assert_called_once() + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @patch('mindformers.core.context.parallel.context.get_auto_parallel_context') + @patch('mindformers.core.context.parallel.set_algo_parameters') + @patch('mindformers.core.context.parallel._set_multi_subgraphs') + def test_set_ms_parallel_semi_auto_parallel(self, mock_set_multi_subgraphs, + mock_set_algo_parameters, mock_get_context): + """ + Feature: Test _set_ms_parallel with semi_auto_parallel mode. + Description: Test algo parameters for semi_auto_parallel mode. + Expectation: elementwise_op_strategy_follow=True and fully_use_devices=True. + """ + mock_get_context.return_value = "semi_auto_parallel" + ParallelOperator._set_ms_parallel() + mock_set_algo_parameters.assert_called_once_with( + elementwise_op_strategy_follow=True, + fully_use_devices=True + ) + mock_set_multi_subgraphs.assert_called_once() + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @patch('mindformers.core.context.parallel.context.get_auto_parallel_context') + @patch('mindformers.core.context.parallel.initialize_model_parallel') + def test_set_manmul_parallel_stand_alone(self, mock_init_model_parallel, mock_get_context): + """ + Feature: Test _set_manmul_parallel with stand_alone mode. + Description: Test that model parallel is initialized for stand_alone mode. + Expectation: initialize_model_parallel is called with correct parameters. + """ + mock_get_context.return_value = "stand_alone" + config = MindFormerConfig( + use_parallel=False, + parallel={}, + parallel_config={ + 'model_parallel': 2, + 'data_parallel': 4, + 'pipeline_stage': 1, + 'expert_parallel': 1 + } + ) + operator = ParallelOperator(config) + operator._set_manmul_parallel() + + mock_init_model_parallel.assert_called_once_with( + tensor_model_parallel_size=2, + data_parallel_size=4, + pipeline_model_parallel_size=1, + expert_model_parallel_size=1, + order="tp-ep-pp-dp" + ) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @patch('mindformers.core.context.parallel.context.get_auto_parallel_context') + @patch('mindformers.core.context.parallel.initialize_model_parallel') + def test_set_manmul_parallel_not_stand_alone(self, mock_init_model_parallel, mock_get_context): + """ + Feature: Test _set_manmul_parallel with non-stand_alone mode. + Description: Test that model parallel is not initialized for non-stand_alone mode. + Expectation: initialize_model_parallel is not called. + """ + mock_get_context.return_value = "semi_auto_parallel" + config = MindFormerConfig( + use_parallel=False, + parallel={}, + parallel_config={} + ) + operator = ParallelOperator(config) + operator._set_manmul_parallel() + + mock_init_model_parallel.assert_not_called() + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @patch('mindformers.core.context.parallel.context.get_auto_parallel_context') + @patch('mindformers.core.context.parallel.initialize_model_parallel') + def test_set_manmul_parallel_with_all_parallel_types(self, mock_init_model_parallel, mock_get_context): + """ + Feature: Test _set_manmul_parallel with all parallel types. + Description: Test parallel strategy string generation. + Expectation: Parallel strategy string includes all parallel types > 1. + """ + mock_get_context.return_value = "stand_alone" + config = MindFormerConfig( + use_parallel=False, + parallel={}, + parallel_config={ + 'model_parallel': 2, + 'data_parallel': 4, + 'pipeline_stage': 2, + 'expert_parallel': 2 + } + ) + operator = ParallelOperator(config) + operator._set_manmul_parallel() + + mock_init_model_parallel.assert_called_once() + call_args = mock_init_model_parallel.call_args[1] + assert call_args['tensor_model_parallel_size'] == 2 + assert call_args['data_parallel_size'] == 4 + assert call_args['pipeline_model_parallel_size'] == 2 + assert call_args['expert_model_parallel_size'] == 2 + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + @patch('mindformers.core.context.parallel.context.get_auto_parallel_context') + @patch('mindformers.core.context.parallel.initialize_model_parallel') + def test_set_manmul_parallel_with_default_values(self, mock_init_model_parallel, mock_get_context): + """ + Feature: Test _set_manmul_parallel with default values. + Description: Test that default values are used when attributes are missing. + Expectation: Default values (1) are used for missing attributes. + """ + mock_get_context.return_value = "stand_alone" + config = MindFormerConfig( + use_parallel=False, + parallel={}, + parallel_config={} + ) + operator = ParallelOperator(config) + operator._set_manmul_parallel() + + mock_init_model_parallel.assert_called_once() + call_args = mock_init_model_parallel.call_args[1] + assert call_args['tensor_model_parallel_size'] == 1 + assert call_args['data_parallel_size'] == 1 + assert call_args['pipeline_model_parallel_size'] == 1 + assert call_args['expert_model_parallel_size'] == 1 diff --git a/tests/st/test_ut/test_core/test_context/test_validators.py b/tests/st/test_ut/test_core/test_context/test_validators.py index c91a1d525..4c71799f4 100644 --- a/tests/st/test_ut/test_core/test_context/test_validators.py +++ b/tests/st/test_ut/test_core/test_context/test_validators.py @@ -13,11 +13,21 @@ # limitations under the License. # ============================================================================ """Test the validators of context.""" +from unittest.mock import patch + import pytest -from mindformers.core.context.validators import validate_invalid_predict_mode +from mindformers.core.context.validators import ( + validate_invalid_predict_mode, + validate_ms_ctx_mode, + validate_mf_ctx_run_mode, + validate_parallel_mode, + validate_precision_sync, + validate_sink_size, +) from mindformers.tools import MindFormerConfig + @pytest.mark.level1 @pytest.mark.platform_x86_cpu @pytest.mark.parametrize( @@ -51,3 +61,162 @@ def test_validate_invalid_predict(use_past, use_flash_attention, expect_error): "Flash Attention is incompatible when use_past=False") else: assert validate_invalid_predict_mode(cfg) is None + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +@pytest.mark.parametrize( + 'run_mode, expect_error', ( + (None, False), + ('train', False), + ('predict', False), + ('finetune', False), + ('eval', False), + ('predict_with_train_model', False), + ('invalid_mode', True), + ) +) +def test_validate_mf_ctx_run_mode(run_mode, expect_error): + """ + Feature: Test validate_mf_ctx_run_mode function. + Description: Test run_mode validation with different values. + Expectation: Valid run_mode passes, invalid run_mode raises ValueError. + """ + config_ = {'run_mode': run_mode} + cfg = MindFormerConfig(**config_) + if expect_error: + with pytest.raises(ValueError) as exc_info: + validate_mf_ctx_run_mode(cfg) + assert 'Invalid run_mode' in str(exc_info.value) + else: + assert validate_mf_ctx_run_mode(cfg) is None + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +@pytest.mark.parametrize( + 'mode, expect_error', ( + ('GRAPH_MODE', False), + ('PYNATIVE_MODE', False), + (0, False), + (1, False), + ('INVALID_MODE', True), + (999, True), + ) +) +def test_validate_ms_ctx_mode(mode, expect_error): + """ + Feature: Test validate_ms_ctx_mode function. + Description: Test context.mode validation with different values. + Expectation: Valid mode passes, invalid mode raises ValueError. + """ + config_ = {'context': {'mode': mode}} + cfg = MindFormerConfig(**config_) + if expect_error: + with pytest.raises(ValueError) as exc_info: + validate_ms_ctx_mode(cfg) + assert 'Invalid mode' in str(exc_info.value) + else: + assert validate_ms_ctx_mode(cfg) is None + # Verify mode is set + assert cfg.get_value('context.mode') is not None + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +@pytest.mark.parametrize( + 'parallel_mode, expect_error', ( + ('DATA_PARALLEL', False), + ('SEMI_AUTO_PARALLEL', False), + ('AUTO_PARALLEL', False), + ('HYBRID_PARALLEL', False), + ('STAND_ALONE', False), + (0, False), + (1, False), + (2, False), + (3, False), + ('INVALID_MODE', True), + (999, True), + ) +) +def test_validate_parallel_mode(parallel_mode, expect_error): + """ + Feature: Test validate_parallel_mode function. + Description: Test parallel.parallel_mode validation with different values. + Expectation: Valid parallel_mode passes, invalid parallel_mode raises ValueError. + """ + config_ = {'parallel': {'parallel_mode': parallel_mode}} + cfg = MindFormerConfig(**config_) + if expect_error: + with pytest.raises(ValueError) as exc_info: + validate_parallel_mode(cfg) + assert 'Invalid parallel mode' in str(exc_info.value) + else: + assert validate_parallel_mode(cfg) is None + # Verify parallel_mode is set + assert cfg.get_value('parallel.parallel_mode') is not None + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +@patch('mindformers.core.context.validators.check_tft_valid') +@pytest.mark.parametrize( + 'tft_valid, sink_size, expect_error', ( + (False, 1, False), + (False, 2, False), + (True, 1, False), + (True, 2, True), + ) +) +def test_validate_sink_size(mock_check_tft, tft_valid, sink_size, expect_error): + """ + Feature: Test validate_sink_size function. + Description: Test sink_size validation when TFT is valid. + Expectation: sink_size must be 1 when TFT is valid. + """ + mock_check_tft.return_value = tft_valid + config_ = {'runner_config': {'sink_size': sink_size}} + cfg = MindFormerConfig(**config_) + if expect_error: + with pytest.raises(ValueError) as exc_info: + validate_sink_size(cfg) + assert 'sink_size should be 1' in str(exc_info.value) + else: + assert validate_sink_size(cfg) is None + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +@pytest.mark.parametrize( + 'train_precision_sync, infer_precision_sync, expect_error', ( + (None, None, False), + (True, None, False), + (False, None, False), + (None, True, False), + (None, False, False), + ('invalid', None, True), + (None, 'invalid', True), + (1, None, True), + (None, 0, True), + ) +) +def test_validate_precision_sync( + train_precision_sync, infer_precision_sync, expect_error + ): + """ + Feature: Test validate_precision_sync function. + Description: Test train_precision_sync and infer_precision_sync validation. + Expectation: Bool values pass, non-bool values raise ValueError. + """ + config_ = {} + if train_precision_sync is not None: + config_['train_precision_sync'] = train_precision_sync + if infer_precision_sync is not None: + config_['infer_precision_sync'] = infer_precision_sync + cfg = MindFormerConfig(**config_) + if expect_error: + with pytest.raises(ValueError) as exc_info: + validate_precision_sync(cfg) + assert 'should be bool' in str(exc_info.value) + else: + assert validate_precision_sync(cfg) is None -- Gitee