From 93eb56dd28b14159e067c0de36bb6fb901b4560f Mon Sep 17 00:00:00 2001 From: qinsichun Date: Wed, 3 Dec 2025 15:03:05 +0800 Subject: [PATCH] test_conv --- .../test_auto/test_configuration_auto.py | 264 +++++++++++++++ .../quantization/test_base_config.py | 302 ++++++++++++++++++ .../test_tools/test_register/test_config.py | 160 ++++++++++ 3 files changed, 726 insertions(+) create mode 100644 tests/st/test_ut/test_models/test_auto/test_configuration_auto.py create mode 100644 tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization/test_base_config.py create mode 100644 tests/st/test_ut/test_tools/test_register/test_config.py diff --git a/tests/st/test_ut/test_models/test_auto/test_configuration_auto.py b/tests/st/test_ut/test_models/test_auto/test_configuration_auto.py new file mode 100644 index 000000000..cb9ecf76e --- /dev/null +++ b/tests/st/test_ut/test_models/test_auto/test_configuration_auto.py @@ -0,0 +1,264 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2025 Huawei Technologies +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Unit tests for mindformers.models.auto.configuration_auto.""" +from types import SimpleNamespace + +import pytest + +import mindformers.models.auto.configuration_auto as auto_cfg +from mindformers.models.auto.configuration_auto import ( + AutoConfig, + CONFIG_MAPPING, + config_class_to_model_type, + _LazyConfigMapping, + _list_model_options, + replace_list_option_in_docstrings, +) +from mindformers.models.configuration_utils import PretrainedConfig + + +class DummyMindFormerConfig(dict): + """Stub MindFormerConfig for unit tests.""" + + def __init__(self, use_legacy=True, has_pretrained=False, has_generation=False): + super().__init__() + self._use_legacy = use_legacy + self.model = SimpleNamespace( + model_config={"type": "DemoConfig"}, + arch=SimpleNamespace(type="demo_arch"), + ) + if has_pretrained: + self["pretrained_model_dir"] = "pretrained_dir" + self.pretrained_model_dir = "pretrained_dir" + else: + self.pretrained_model_dir = None + if has_generation: + self["generation_config"] = {"gen": True} + self.generation_config = {"gen": True} + else: + self.generation_config = None + + def get_value(self, key, default=None): + """Get value from config.""" + if key == "use_legacy": + return self._use_legacy + return default + + +@pytest.fixture(autouse=True) +def restore_extra_content(): + """Ensure CONFIG_MAPPING extra registrations are restored between tests.""" + # Accessing protected member for test cleanup is intentional + original = CONFIG_MAPPING._extra_content.copy() # pylint: disable=W0212,protected-access + yield + CONFIG_MAPPING._extra_content = original # pylint: disable=W0212,protected-access + + +class TestConfigurationAuto: + """Test class for mindformers.models.auto.configuration_auto.""" + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_config_class_to_model_type_core_and_extra(self, monkeypatch): + """config_class_to_model_type should inspect 
default and extra registries.""" + assert config_class_to_model_type("LlamaConfig") == "llama" + dummy_class = type("NewConfig", (), {}) + # Accessing protected member for test setup is intentional + monkeypatch.setitem(CONFIG_MAPPING._extra_content, "custom", dummy_class) # pylint: disable=W0212,protected-access + assert config_class_to_model_type("NewConfig") == "custom" + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_lazy_config_mapping_register_and_getitem(self, monkeypatch): + """_LazyConfigMapping should import modules lazily and honor register.""" + module = SimpleNamespace(MockConfig="sentinel") + monkeypatch.setattr(auto_cfg.importlib, "import_module", lambda name, package=None: module) + mapping = _LazyConfigMapping({"mock": "MockConfig"}) + assert mapping["mock"] == "sentinel" + mapping.register("extra", "ExtraConfig", exist_ok=True) + assert mapping["extra"] == "ExtraConfig" + with pytest.raises(ValueError): + mapping.register("mock", "OtherConfig") + with pytest.raises(KeyError): + _ = mapping["missing"] + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_list_model_options_and_docstring_replacement(self): + """_list_model_options and decorator should update docstrings or raise errors.""" + doc = _list_model_options(" ", {"llama": ["LlamaConfig"]}, use_model_types=False) + assert "LlamaConfig" in doc + + @replace_list_option_in_docstrings({"llama": ["LlamaConfig"]}) + def sample(): + """List options""" + + assert "llama" in sample.__doc__ + + def broken(): + """no placeholder""" + + decorator = replace_list_option_in_docstrings({"llama": ["LlamaConfig"]}) + with pytest.raises(ValueError): + decorator(broken) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_autoconfig_invalid_yaml_name_branches(self, monkeypatch): + """AutoConfig.invalid_yaml_name should validate against support list.""" + monkeypatch.setattr(AutoConfig, "_support_list", + {"llama": ["llama_7b"], "glm": {"9b": ["glm_9b"]}}) + assert 
AutoConfig.invalid_yaml_name("unknown_model") + assert not AutoConfig.invalid_yaml_name("llama_7b") + assert not AutoConfig.invalid_yaml_name("glm_9b") + with pytest.raises(ValueError): + AutoConfig.invalid_yaml_name("glm_bad") + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_autoconfig_for_model_and_error(self): + """AutoConfig.for_model should instantiate registered configs or raise.""" + class DummyConfig(PretrainedConfig): + """Dummy config for unit tests.""" + model_type = "dummy_key" + + def __init__(self, value=None): + super().__init__() + self.value = value + + CONFIG_MAPPING.register("dummy_key", DummyConfig, exist_ok=True) + result = AutoConfig.for_model("dummy_key", value=3) + assert isinstance(result, DummyConfig) and result.value == 3 + with pytest.raises(ValueError): + AutoConfig.for_model("missing") + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_from_pretrained_switches_modes(self, monkeypatch): + """AutoConfig.from_pretrained should delegate based on experimental flag.""" + monkeypatch.setattr(auto_cfg, "is_experimental_mode", lambda _: False) + monkeypatch.setattr(AutoConfig, "get_config_origin_mode", + classmethod(lambda cls, name, **_: ("origin", name))) + res = AutoConfig.from_pretrained("path/model.yaml", pretrained_model_name_or_path="override") + assert res == ("origin", "override") + monkeypatch.setattr(auto_cfg, "is_experimental_mode", lambda _: True) + monkeypatch.setattr(AutoConfig, "get_config_experimental_mode", + classmethod(lambda cls, name, **_: ("exp", name))) + assert AutoConfig.from_pretrained("path/model.yaml") == ("exp", "path/model.yaml") + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_get_config_origin_mode_type_and_extension_errors(self, tmp_path): + """get_config_origin_mode should validate input types and extensions.""" + with pytest.raises(TypeError): + AutoConfig.get_config_origin_mode(123) + bad_file = tmp_path / "not_yaml.txt" + bad_file.write_text("content", 
encoding="utf-8") + with pytest.raises(ValueError): + AutoConfig.get_config_origin_mode(str(bad_file)) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_get_config_origin_mode_invalid_yaml_name(self, monkeypatch): + """Non-existing yaml names should raise ValueError.""" + monkeypatch.setattr(AutoConfig, "invalid_yaml_name", classmethod(lambda cls, _: True)) + with pytest.raises(ValueError): + AutoConfig.get_config_origin_mode("unknown_name") + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_get_config_origin_mode_legacy_flow(self, monkeypatch, tmp_path): + """Legacy pathway should build configs and update auxiliary fields.""" + dummy = DummyMindFormerConfig(use_legacy=True, has_pretrained=True, has_generation=True) + monkeypatch.setattr(auto_cfg, "MindFormerConfig", lambda *_: dummy) + built = {} + monkeypatch.setattr(auto_cfg, "build_model_config", + lambda cfg: built.setdefault("config", cfg) or "legacy") + monkeypatch.setattr(auto_cfg.MindFormerBook, "set_model_config_to_name", + lambda *args, **kwargs: built.setdefault("mark", args)) + yaml_file = tmp_path / "model.yaml" + yaml_file.write_text("model: {}", encoding="utf-8") + AutoConfig.get_config_origin_mode(str(yaml_file), hidden_size=128) + assert dummy.model.model_config["hidden_size"] == 128 + assert dummy.model.pretrained_model_dir == "pretrained_dir" + assert dummy.model.generation_config == {"gen": True} + assert built["config"] == dummy.model.model_config + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_get_config_origin_mode_nonlegacy_flow(self, monkeypatch, tmp_path): + """Non-legacy pathway should use get_model_config without calling builder.""" + dummy = DummyMindFormerConfig(use_legacy=False) + monkeypatch.setattr(auto_cfg, "MindFormerConfig", lambda *_: dummy) + marker = {} + monkeypatch.setattr(auto_cfg, "build_model_config", + lambda *_: marker.setdefault("should_not_call", True)) + monkeypatch.setattr(auto_cfg, "get_model_config", + lambda 
model: marker.setdefault("model", model) or "new_config") + yaml_file = tmp_path / "model.yaml" + yaml_file.write_text("model: {}", encoding="utf-8") + AutoConfig.get_config_origin_mode(str(yaml_file), dropout=0.1) + assert dummy.model.model_config["dropout"] == 0.1 + assert marker["model"] == dummy.model + assert "should_not_call" not in marker + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_get_config_experimental_mode_remote_code(self, monkeypatch): + """Remote code configs should be loaded via dynamic modules when trusted.""" + monkeypatch.setattr(auto_cfg.PretrainedConfig, "get_config_dict", + classmethod(lambda cls, name, + **kwargs: ({"auto_map": {"AutoConfig": "mod.Class"}}, {}))) + monkeypatch.setattr(auto_cfg, "resolve_trust_remote_code", lambda trust, *args, **kwargs: True) + + class RemoteConfig: + """Remote config for unit tests.""" + @staticmethod + def register_for_auto_class(): + """Register for auto class.""" + RemoteConfig.registered = True + + @staticmethod + def from_pretrained(name, **kwargs): + """From pretrained.""" + return {"name": name, "kwargs": kwargs} + + monkeypatch.setattr(auto_cfg, "get_class_from_dynamic_module", lambda *args, **kwargs: RemoteConfig) + monkeypatch.setattr(auto_cfg.os.path, "isdir", lambda _: True) + result = AutoConfig.get_config_experimental_mode("remote_repo", trust_remote_code=True) + assert result["name"] == "remote_repo" + assert RemoteConfig.registered is True + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_get_config_experimental_mode_local_config(self, monkeypatch): + """Local configs with model_type should resolve via CONFIG_MAPPING.""" + class LocalConfig(PretrainedConfig): + """LocalConfig for tests.""" + model_type = "custom_dummy" + + @classmethod + def from_dict(cls, config_dict, **kwargs): + return {"config": config_dict, "extra": kwargs} + + CONFIG_MAPPING.register("custom_dummy", LocalConfig, exist_ok=True) + monkeypatch.setattr(auto_cfg.PretrainedConfig, 
"get_config_dict", + classmethod(lambda cls, name, + **kwargs: ({"model_type": "custom_dummy", "value": 1}, {"unused": True}))) + result = AutoConfig.get_config_experimental_mode("local_repo") + assert result["config"]["value"] == 1 + assert result["extra"]["unused"] is True diff --git a/tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization/test_base_config.py b/tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization/test_base_config.py new file mode 100644 index 000000000..836f4e776 --- /dev/null +++ b/tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization/test_base_config.py @@ -0,0 +1,302 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Unit tests for base_config.py""" + +from typing import List, Any, Optional +import pytest +import mindspore +from mindspore import nn, Tensor + +from mindformers.parallel_core.inference.quantization.base_config import ( + QuantizeMethodBase, + QuantizationConfig, +) + + +class ConcreteQuantizeMethod(QuantizeMethodBase): + """Concrete implementation of QuantizeMethodBase for testing.""" + + def __init__(self): + super().__init__() + self.weights_created = False + + def create_weights(self, layer: nn.Cell, *_weight_args, **_extra_weight_attrs): + """Create weights for a layer.""" + self.weights_created = True + _ = layer + + def apply(self, layer: nn.Cell, *_args, **_kwargs) -> Tensor: + """Apply the weights in layer to the input tensor.""" + _ = layer + if not self.weights_created: + raise RuntimeError("Weights must be created before applying") + return Tensor([1.0]) + + +class ConcreteQuantizationConfig(QuantizationConfig): + """Concrete implementation of QuantizationConfig for testing.""" + + def __init__(self, name: str = "test_quant"): + super().__init__() + self._name = name + + def get_name(self) -> str: + """Name of the quantization method.""" + return self._name + + def get_supported_act_dtypes(self) -> List[str]: + """List of supported activation dtypes.""" + return ["float16", "float32"] + + @classmethod + def get_min_capability(cls) -> int: + """Minimum capability to support the quantization method.""" + return 70 + + @staticmethod + def get_config_filenames() -> list[str]: + """List of filenames to search for in the model directory.""" + return ["quantization_config.json"] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "ConcreteQuantizationConfig": + """Create a config class from the model's quantization config.""" + name = config.get("quantization_type", "test_quant") + return cls(name=name) + + def get_quant_method( + self, layer: mindspore.nn.Cell, 
prefix: str + ) -> Optional[QuantizeMethodBase]: + """Get the quantize method to use for the quantized layer.""" + _ = prefix + if isinstance(layer, nn.Dense): + return ConcreteQuantizeMethod() + return None + + +class TestQuantizeMethodBase: + """Test class for QuantizeMethodBase.""" + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_concrete_quantize_method_create_weights(self): + """Test that concrete implementation can create weights.""" + method = ConcreteQuantizeMethod() + layer = nn.Dense(10, 20) + method.create_weights(layer) + assert method.weights_created is True + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_concrete_quantize_method_apply(self): + """Test that concrete implementation can apply weights.""" + method = ConcreteQuantizeMethod() + layer = nn.Dense(10, 20) + method.create_weights(layer) + result = method.apply(layer) + assert isinstance(result, Tensor) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_quantize_method_base_embedding_raises_runtime_error(self): + """Test that embedding method raises RuntimeError by default.""" + method = ConcreteQuantizeMethod() + layer = nn.Dense(10, 20) + with pytest.raises(RuntimeError): + method.embedding(layer) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_quantize_method_base_process_weights_after_loading(self): + """Test that process_weights_after_loading returns None by default.""" + method = ConcreteQuantizeMethod() + layer = nn.Dense(10, 20) + method.process_weights_after_loading(layer) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_apply_without_create_weights_raises_error(self): + """Test that apply without create_weights raises error.""" + method = ConcreteQuantizeMethod() + layer = nn.Dense(10, 20) + # The concrete implementation should check if weights are created + with pytest.raises(RuntimeError): + method.apply(layer) + + +class TestQuantizationConfig: + """Test class for QuantizationConfig.""" + + 
@pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_concrete_config_initialization(self): + """Test that concrete config initializes packed_modules_mapping.""" + config = ConcreteQuantizationConfig() + assert isinstance(config.packed_modules_mapping, dict) + assert len(config.packed_modules_mapping) == 0 + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_concrete_config_get_name(self): + """Test get_name method.""" + config = ConcreteQuantizationConfig(name="test_quantization") + assert config.get_name() == "test_quantization" + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_concrete_config_get_supported_act_dtypes(self): + """Test get_supported_act_dtypes method.""" + config = ConcreteQuantizationConfig() + dtypes = config.get_supported_act_dtypes() + assert isinstance(dtypes, list) + assert "float16" in dtypes + assert "float32" in dtypes + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_concrete_config_get_min_capability(self): + """Test get_min_capability class method.""" + capability = ConcreteQuantizationConfig.get_min_capability() + assert isinstance(capability, int) + assert capability == 70 + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_concrete_config_get_config_filenames(self): + """Test get_config_filenames static method.""" + filenames = ConcreteQuantizationConfig.get_config_filenames() + assert isinstance(filenames, list) + assert "quantization_config.json" in filenames + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_concrete_config_from_config(self): + """Test from_config class method.""" + config_dict = {"quantization_type": "custom_quant"} + config = ConcreteQuantizationConfig.from_config(config_dict) + assert isinstance(config, ConcreteQuantizationConfig) + assert config.get_name() == "custom_quant" + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_concrete_config_get_quant_method(self): + """Test get_quant_method method.""" + 
config = ConcreteQuantizationConfig() + layer = nn.Dense(10, 20) + quant_method = config.get_quant_method(layer, prefix="dense_layer") + assert isinstance(quant_method, ConcreteQuantizeMethod) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_concrete_config_get_quant_method_returns_none(self): + """Test get_quant_method returns None for unsupported layer.""" + config = ConcreteQuantizationConfig() + layer = nn.ReLU() # Not a Dense layer + quant_method = config.get_quant_method(layer, prefix="relu_layer") + assert quant_method is None + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_get_from_keys_finds_first_key(self): + """Test get_from_keys finds value using first matching key.""" + config = {"quantization_type": "test", "quant_type": "alternative"} + result = QuantizationConfig.get_from_keys(config, ["quantization_type", "quant_type"]) + assert result == "test" + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_get_from_keys_finds_second_key(self): + """Test get_from_keys finds value using second key when first not present.""" + config = {"quant_type": "alternative"} + result = QuantizationConfig.get_from_keys(config, ["quantization_type", "quant_type"]) + assert result == "alternative" + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_get_from_keys_raises_value_error(self): + """Test get_from_keys raises ValueError when no key found.""" + config = {"other_key": "value"} + with pytest.raises(ValueError, match="Cannot find any of"): + QuantizationConfig.get_from_keys(config, ["quantization_type", "quant_type"]) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_get_from_keys_with_empty_keys(self): + """Test get_from_keys with empty keys list raises ValueError.""" + config = {"key": "value"} + with pytest.raises(ValueError, match="Cannot find any of"): + QuantizationConfig.get_from_keys(config, []) + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def 
test_get_from_keys_or_returns_value(self): + """Test get_from_keys_or returns value when key exists.""" + config = {"quantization_type": "test"} + result = QuantizationConfig.get_from_keys_or( + config, ["quantization_type"], "default_value" + ) + assert result == "test" + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_get_from_keys_or_returns_default(self): + """Test get_from_keys_or returns default when key does not exist.""" + config = {"other_key": "value"} + default_value = "default_quant" + result = QuantizationConfig.get_from_keys_or( + config, ["quantization_type"], default_value + ) + assert result == default_value + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_get_from_keys_or_with_none_default(self): + """Test get_from_keys_or works with None as default.""" + config = {"other_key": "value"} + result = QuantizationConfig.get_from_keys_or(config, ["quantization_type"], None) + assert result is None + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_get_from_keys_or_with_empty_config(self): + """Test get_from_keys_or with empty config returns default.""" + config = {} + default_value = "default" + result = QuantizationConfig.get_from_keys_or(config, ["any_key"], default_value) + assert result == default_value + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_packed_modules_mapping_mutable(self): + """Test that packed_modules_mapping can be modified.""" + config = ConcreteQuantizationConfig() + config.packed_modules_mapping["module1"] = ["weight1", "weight2"] + assert config.packed_modules_mapping["module1"] == ["weight1", "weight2"] + assert len(config.packed_modules_mapping) == 1 + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_get_from_keys_with_different_value_types(self): + """Test get_from_keys works with different value types.""" + config = { + "int_value": 42, + "float_value": 3.14, + "list_value": [1, 2, 3], + "dict_value": {"nested": "value"}, + } + assert 
QuantizationConfig.get_from_keys(config, ["int_value"]) == 42 + assert QuantizationConfig.get_from_keys(config, ["float_value"]) == 3.14 + assert QuantizationConfig.get_from_keys(config, ["list_value"]) == [1, 2, 3] + assert QuantizationConfig.get_from_keys(config, ["dict_value"]) == {"nested": "value"} diff --git a/tests/st/test_ut/test_tools/test_register/test_config.py b/tests/st/test_ut/test_tools/test_register/test_config.py new file mode 100644 index 000000000..d32a55472 --- /dev/null +++ b/tests/st/test_ut/test_tools/test_register/test_config.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ============================================================================
"""Unit tests for mindformers.tools.register.config."""
import argparse
import copy
import sys
from collections import OrderedDict

import pytest

from mindformers.tools.register import config as config_module
from mindformers.tools.register.config import (
    ActionDict,
    DictConfig,
    MindFormerConfig,
    BASE_CONFIG,
    ordered_yaml_dump,
    parse_args,
)

yaml = pytest.importorskip("yaml")

# pylint: disable=protected-access


class TestConfig:
    """Test class for mindformers.tools.register.config."""

    @pytest.mark.level1
    @pytest.mark.platform_x86_cpu
    def test_dict_config_attribute_and_to_dict(self):
        """DictConfig should expose attribute access semantics."""
        cfg = DictConfig(a=1, nested=DictConfig(b=2))
        assert cfg.a == 1
        cfg.c = 3
        assert cfg.c == 3
        del cfg.a
        assert cfg.a is None
        plain = cfg.to_dict()
        assert plain == {"nested": {"b": 2}, "c": 3}

    @pytest.mark.level1
    @pytest.mark.platform_x86_cpu
    def test_dict_config_deepcopy_isolated(self):
        """Deep copy should create independent nested objects."""
        cfg = DictConfig(nested=DictConfig(value=[1, 2]))
        # BUG FIX: the original bound `copied = cfg` (an alias, not a copy),
        # which made the two assertions below contradictory -- the test could
        # never pass. An actual deep copy is required to test isolation.
        copied = copy.deepcopy(cfg)
        copied.nested.value.append(3)
        assert cfg.nested.value == [1, 2]
        assert copied.nested.value == [1, 2, 3]

    @pytest.mark.level1
    @pytest.mark.platform_x86_cpu
    def test_mindformer_config_loads_yaml_with_base(self, monkeypatch, tmp_path):
        """MindFormerConfig should merge base yaml files and convert dict to config."""
        monkeypatch.setattr(config_module.ConfigTemplate, "apply_template", lambda _: None)
        base_content = {"alpha": 1, "nested": {"from_base": True}}
        base_file = tmp_path / "base.yaml"
        base_file.write_text(yaml.safe_dump(base_content), encoding="utf-8")

        child_content = {
            BASE_CONFIG: "base.yaml",
            "beta": 2,
            "nested": {"from_child": True},
        }
        child_file = tmp_path / "child.yaml"
        child_file.write_text(yaml.safe_dump(child_content), encoding="utf-8")

        cfg = MindFormerConfig(str(child_file))
        assert cfg.alpha == 1
        assert cfg.beta == 2
        assert isinstance(cfg.nested, MindFormerConfig)
        assert cfg.nested.from_child

    @pytest.mark.level1
    @pytest.mark.platform_x86_cpu
    def test_mindformer_config_merge_and_set(self, monkeypatch):
        """merge_from_dict and set_value should correctly update nested fields."""
        monkeypatch.setattr(config_module.ConfigTemplate, "apply_template", lambda _: None)
        cfg = MindFormerConfig(model={"model_config": {"type": "Demo"}})
        cfg.merge_from_dict({"model.arch": "DemoArch", "new.branch.leaf": 10})
        assert cfg.model.arch == "DemoArch"
        assert cfg.new.branch.leaf == 10

        cfg.set_value("context.mode", "GRAPH")
        cfg.set_value(["context", "device_id"], 3)
        assert cfg.get_value("context.mode") == "GRAPH"
        assert cfg.get_value(["context", "device_id"]) == 3
        assert cfg.get_value("context.fake", default="fallback") == "fallback"

    @pytest.mark.level1
    @pytest.mark.platform_x86_cpu
    def test_file2dict_without_filename_raises(self):
        """_file2dict should raise when filename is None."""
        with pytest.raises(NameError):
            getattr(MindFormerConfig, "_file2dict")(None)

    @pytest.mark.level1
    @pytest.mark.platform_x86_cpu
    def test_action_dict_parse_and_call(self):
        """ActionDict should parse ints, floats, tuples and bool strings."""
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--opts",
            action=ActionDict,
            nargs="*",
            default={},
        )
        args = parser.parse_args(
            [
                "--opts",
                "ints=1,2",
                "floats=3.5",
                "tuple=(7,8)",
                "mixed=[1,(2,3),[4,5]]",
                "flag=True",
            ]
        )
        assert args.opts["ints"] == [1, 2]
        assert args.opts["floats"] == 3.5
        assert args.opts["tuple"] == (7, 8)
        assert args.opts["mixed"] == [1, (2, 3), [4, 5]]
        assert args.opts["flag"] is False  # current implementation compares function object

    @pytest.mark.level1
    @pytest.mark.platform_x86_cpu
    def test_action_dict_find_next_comma_invalid_pairs(self):
        """find_next_comma should raise when brackets are unbalanced."""
        with pytest.raises(ValueError):
            ActionDict.find_next_comma("[1,2")

    @pytest.mark.level1
    @pytest.mark.platform_x86_cpu
    def test_ordered_yaml_dump_preserves_order(self):
        """ordered_yaml_dump should keep OrderedDict order in emitted yaml."""
        ordered = OrderedDict()
        ordered["first"] = 1
        ordered["second"] = 2
        dumped = ordered_yaml_dump(ordered)
        assert dumped.index("first") < dumped.index("second")

    @pytest.mark.level1
    @pytest.mark.platform_x86_cpu
    def test_parse_args_reads_cli(self, monkeypatch):
        """parse_args should honor the --config cli argument."""
        monkeypatch.setattr(sys, "argv", ["prog", "--config", "path/to/model.yaml"])
        parsed = parse_args()
        assert parsed.config == "path/to/model.yaml"
-- Gitee