From 29658bcad48dae12f914f7e5c14ca98fee74e09f Mon Sep 17 00:00:00 2001 From: kai-ma Date: Thu, 24 Jul 2025 17:07:31 +0800 Subject: [PATCH] add core/base,comp,config,dump --- .../UT/core_ut/base_ut/test_dump_actuator.py | 101 +++++++ .../UT/core_ut/base_ut/test_dump_dumper.py | 49 ++++ .../UT/core_ut/base_ut/test_dump_writer.py | 104 +++++++ .../test_dumper_offline_model.py | 26 ++ .../test_validate_params.py | 269 ++++++++++++++++++ .../UT/core_ut/dump_ut/test_onnx_model.py | 90 ++++++ .../test/UT/core_ut/dump_ut/test_tf_model.py | 234 +++++++++++++++ 7 files changed, 873 insertions(+) create mode 100644 msprobe/test/UT/core_ut/base_ut/test_dump_actuator.py create mode 100644 msprobe/test/UT/core_ut/base_ut/test_dump_dumper.py create mode 100644 msprobe/test/UT/core_ut/base_ut/test_dump_writer.py create mode 100644 msprobe/test/UT/core_ut/components_ut/test_dumper_offline_model.py create mode 100644 msprobe/test/UT/core_ut/config_initiator_ut/test_validate_params.py create mode 100644 msprobe/test/UT/core_ut/dump_ut/test_onnx_model.py create mode 100644 msprobe/test/UT/core_ut/dump_ut/test_tf_model.py diff --git a/msprobe/test/UT/core_ut/base_ut/test_dump_actuator.py b/msprobe/test/UT/core_ut/base_ut/test_dump_actuator.py new file mode 100644 index 000000000..8a061e557 --- /dev/null +++ b/msprobe/test/UT/core_ut/base_ut/test_dump_actuator.py @@ -0,0 +1,101 @@ +import unittest +from unittest import mock + +import numpy as np + +from msprobe.core.base import OfflineModelActuator +from msprobe.utils.exceptions import MsprobeException + + +class TestOfflineModelActuator(unittest.TestCase): + + def setUp(self): + self.model_path = "fake_model_path" + self.input_shape = {"input_tensor": [1, 3, 224, 224]} + self.input_path = ["fake_input.npy"] + self.actuator = OfflineModelActuator(self.model_path, self.input_shape, self.input_path) + + @mock.patch("msprobe.core.base.dump_actuator.get_tf_type2dtype_map", return_value={}) + def test_tensor2numpy_for_type_valid(self, _): + dtype = self.actuator._tensor2numpy_for_type("tensor(float)") + self.assertEqual(dtype, np.float32) + + def test_tensor2numpy_for_type_invalid(self): + with self.assertRaises(MsprobeException): + self.actuator._tensor2numpy_for_type("invalid_type") + + def test_is_dynamic_shape(self): + self.assertTrue(self.actuator._is_dynamic_shape([None, 224, 224])) + self.assertFalse(self.actuator._is_dynamic_shape([1, 3, 224])) + + def test_check_input_shape_valid(self): + try: + self.actuator._check_input_shape("input", [1, 3, 224], [1, 3, 224]) + except MsprobeException: + self.fail("Unexpected MsprobeException raised") + + def test_check_input_shape_invalid_len(self): + with self.assertRaises(MsprobeException): + self.actuator._check_input_shape("input", [1, 3, 224], [1, 3]) + + def test_check_input_shape_mismatch(self): + with self.assertRaises(MsprobeException): + self.actuator._check_input_shape("input", [1, 3, 224], [1, 1, 224]) + + def test_check_input_shape_missing(self): + with self.assertRaises(MsprobeException): + self.actuator._check_input_shape("input", [1, 3, 224], None) + + def test_process_tensor_shape_static(self): + tensor_info = self.actuator.process_tensor_shape("input_tensor", "tensor(float)", [1, 3, 224, 224]) + self.assertEqual(tensor_info[0]["shape"], [1, 3, 224, 224]) + + def test_process_tensor_shape_missing_input_shape(self): + actuator = OfflineModelActuator(self.model_path, {}, self.input_path) + with self.assertRaises(MsprobeException): + actuator.process_tensor_shape("input_tensor", "tensor(float)", [None, 3, 224, 224]) + + @mock.patch("msprobe.core.base.dump_actuator.load_npy", return_value=np.ones((1, 3, 224, 224))) + @mock.patch("msprobe.core.base.dump_actuator.logger") + def test_read_input_data_npy(self, mock_logger, mock_load_npy): + input_map = self.actuator._read_input_data(["file.npy"], ["input_tensor"], [[1, 3, 224, 224]], [np.float32]) + self.assertIn("input_tensor", input_map) + mock_load_npy.assert_called_once() + + @mock.patch("msprobe.core.base.dump_actuator.load_bin_data", return_value=np.ones((1, 3, 224, 224))) + def test_read_input_data_bin(self, mock_load_bin): + input_map = self.actuator._read_input_data(["file.bin"], ["input_tensor"], [[1, 3, 224, 224]], [np.float32]) + self.assertIn("input_tensor", input_map) + mock_load_bin.assert_called_once() + + @mock.patch("msprobe.core.base.dump_actuator.save_npy") + @mock.patch("msprobe.core.base.dump_actuator.logger") + def test_generate_random_input_data(self, mock_logger, mock_save_npy): + inputs = self.actuator._generate_random_input_data( + "mock_dir", ["input_tensor"], [[1, 3, 224, 224]], [np.float32] + ) + self.assertIn("input_tensor", inputs) + mock_save_npy.assert_called_once() + + @mock.patch.object( + OfflineModelActuator, "_generate_random_input_data", return_value={"input_tensor": np.zeros((1, 3, 224, 224))} + ) + def test_get_inputs_data_random(self, mock_generate): + actuator = OfflineModelActuator(self.model_path, self.input_shape, "") + actuator.dir_pool = mock.Mock() + actuator.dir_pool.make_input_dir.return_value = None + actuator.dir_pool.get_input_dir.return_value = "mock_dir" + + data = actuator.get_inputs_data([{"name": "input_tensor", "shape": [1, 3, 224, 224], "type": "tensor(float)"}]) + self.assertIn("input_tensor", data) + mock_generate.assert_called_once() + + @mock.patch.object( + OfflineModelActuator, "_read_input_data", return_value={"input_tensor": np.zeros((1, 3, 224, 224))} + ) + def test_get_inputs_data_from_file(self, mock_read_input): + data = self.actuator.get_inputs_data( + [{"name": "input_tensor", "shape": [1, 3, 224, 224], "type": "tensor(float)"}] + ) + self.assertIn("input_tensor", data) + mock_read_input.assert_called_once() diff --git a/msprobe/test/UT/core_ut/base_ut/test_dump_dumper.py b/msprobe/test/UT/core_ut/base_ut/test_dump_dumper.py new file mode 100644 index 000000000..a519ef910 --- /dev/null +++ b/msprobe/test/UT/core_ut/base_ut/test_dump_dumper.py @@ -0,0 +1,49 @@ +import unittest +from unittest import mock + +from msprobe.core.base import BaseDumper +from msprobe.utils.exceptions import MsprobeException + + +class DummyDumper(BaseDumper): + def register_hook(self): + pass + + +class DummyNode: + def __init__(self, name): + self.name = name + + +class TestBaseDumper(unittest.TestCase): + + def setUp(self): + self.dumper = DummyDumper(data_mode="test") + + def test_through_nodes_with_string(self): + nodes = ["input1", "input2"] + data_map = {"input1": 123, "input2": 456} + results = list(self.dumper.through_nodes(nodes, "nodeX", "input", data_map)) + self.assertEqual(results[0], ["nodeX", "input", "input1", 0, 123]) + self.assertEqual(results[1], ["nodeX", "input", "input2", 1, 456]) + + def test_through_nodes_with_named_objects(self): + nodes = [DummyNode("tensorA"), DummyNode("tensorB")] + data_map = {"tensorA": "A_data", "tensorB": "B_data"} + results = list(self.dumper.through_nodes(nodes, "MyNode", "output", data_map)) + self.assertEqual(results[0], ["MyNode", "output", "tensorA", 0, "A_data"]) + self.assertEqual(results[1], ["MyNode", "output", "tensorB", 1, "B_data"]) + + def test_through_nodes_with_invalid_type(self): + nodes = [123] # Not a str and has no 'name' + with self.assertRaises(MsprobeException): + list(self.dumper.through_nodes(nodes, "BrokenNode", "input", {})) + + @mock.patch("msprobe.core.base.dump_dumper.release") + def test_release_hook_calls_release(self, mock_release): + self.dumper.handler = ["h1", "h2", "h3"] + self.dumper.release_hook() + self.assertEqual(mock_release.call_count, 3) + mock_release.assert_any_call("h1") + mock_release.assert_any_call("h2") + mock_release.assert_any_call("h3") diff --git a/msprobe/test/UT/core_ut/base_ut/test_dump_writer.py b/msprobe/test/UT/core_ut/base_ut/test_dump_writer.py new file mode 100644 index 000000000..d27c47dc8 --- /dev/null +++ b/msprobe/test/UT/core_ut/base_ut/test_dump_writer.py @@ -0,0 +1,104 @@ +import unittest +from unittest.mock import MagicMock, patch + +import numpy as np + +from msprobe.core.base.dump_writer import RankDirFile, SaveBinTensor, SaveNpyTensor, SaveTensorStrategy +from msprobe.utils.exceptions import MsprobeException + + +# Dummy subclass for abstract RankDirFile +class DummyRankDirFile(RankDirFile): + def __init__(self, buffer_size): + super().__init__(buffer_size) + self._save_called = 0 + + def _save(self): + self._save_called += 1 + + +# Dummy subclass for SaveTensorStrategy (abstract) +class DummySaveTensorStrategy(SaveTensorStrategy): + def __init__(self): + super().__init__() + self.suffix = ".mock" + + def _save(self, data): + self._last_saved_data = data + + +class TestRankDirFile(unittest.TestCase): + + def test_add_rank_dir_and_cover_below_buffer(self): + dumper = DummyRankDirFile(buffer_size=10000) + dumper.add_rank_dir("/mock/rank") + dumper.cover("short string") + self.assertEqual(dumper._save_called, 0) + + def test_cover_triggers_save_when_buffer_exceeds(self): + dumper = DummyRankDirFile(buffer_size=10) + dumper.cover("this is a long string") + self.assertEqual(dumper._save_called, 1) + + def test_clear_cache_calls_save_when_nonzero(self): + dumper = DummyRankDirFile(buffer_size=10) + dumper.cache_file_size = 100 + dumper.cache_file = {"mock": "data"} + dumper.clear_cache() + self.assertEqual(dumper._save_called, 1) + self.assertEqual(dumper.cache_file_size, 0) + self.assertEqual(dumper.cache_file, {}) + + +class TestSaveTensorStrategy(unittest.TestCase): + + @patch("msprobe.core.base.dump_writer.get_current_timestamp", return_value="1234567890") + @patch("msprobe.core.base.dump_writer.get_valid_name", side_effect=lambda x: f"valid_{x}") + @patch("msprobe.core.base.dump_writer.SafePath.check") + @patch("msprobe.core.base.dump_writer.join_path", return_value="/mock/path") + def test_save_tensor_data_calls_save(self, mock_join, mock_check, mock_valid_name, mock_timestamp): + saver = DummySaveTensorStrategy() + saver.add_tensor_dir("/mock/dir") + saver._save = MagicMock() + + saver.save_tensor_data("node", "arg", "data123") + + self.assertTrue(saver._save.called) + self.assertEqual(saver.tensor_path, mock_check.return_value) + mock_check.assert_called_once_with(path_exist=False) + + +class TestSaveNpyTensor(unittest.TestCase): + + @patch("msprobe.core.base.dump_writer.save_npy") + def test_save_calls_save_npy(self, mock_save_npy): + saver = SaveNpyTensor() + saver.tensor_path = "/mock/file.npy" + data = np.array([1, 2, 3]) + saver._save(data) + mock_save_npy.assert_called_once_with(data, "/mock/file.npy") + + +class TestSaveBinTensor(unittest.TestCase): + + @patch("msprobe.core.base.dump_writer.save_bin_from_ndarray") + def test_save_ndarray(self, mock_save_nd): + saver = SaveBinTensor() + saver.tensor_path = "/mock/file.bin" + arr = np.array([1, 2, 3], dtype=np.float32) + saver._save(arr) + mock_save_nd.assert_called_once_with(arr, "/mock/file.bin") + + @patch("msprobe.core.base.dump_writer.save_bin_from_bytes") + def test_save_bytes(self, mock_save_bytes): + saver = SaveBinTensor() + saver.tensor_path = "/mock/file.bin" + byte_data = b"\x00\x01" + saver._save(byte_data) + mock_save_bytes.assert_called_once_with(byte_data, "/mock/file.bin") + + def test_save_invalid_type_raises_exception(self): + saver = SaveBinTensor() + saver.tensor_path = "/mock/file.bin" + with self.assertRaises(MsprobeException): + saver._save("invalid string type") diff --git a/msprobe/test/UT/core_ut/components_ut/test_dumper_offline_model.py b/msprobe/test/UT/core_ut/components_ut/test_dumper_offline_model.py new file mode 100644 index 000000000..07e71f908 --- /dev/null +++ b/msprobe/test/UT/core_ut/components_ut/test_dumper_offline_model.py @@ -0,0 +1,26 @@ +import unittest +from unittest.mock import patch + +from msprobe.base import BaseComponent +from msprobe.core.components.dumper_offline_model import OfflineModelActuatorComp + + +class TestOfflineModelActuatorComp(unittest.TestCase): + def test_inheritance(self): + self.assertTrue(issubclass(OfflineModelActuatorComp, BaseComponent)) + + @patch.object(BaseComponent, "__init__", return_value=None) + def test_init_default_priority(self, mock_base_init): + OfflineModelActuatorComp() + mock_base_init.assert_called_once_with(100) + + @patch.object(BaseComponent, "__init__", return_value=None) + def test_init_custom_priority(self, mock_base_init): + custom_priority = 200 + OfflineModelActuatorComp(priority=custom_priority) + mock_base_init.assert_called_once_with(custom_priority) + + @patch.object(BaseComponent, "__init__", return_value=None) + def test_instance_type(self, mock_base_init): + instance = OfflineModelActuatorComp() + self.assertIsInstance(instance, BaseComponent) diff --git a/msprobe/test/UT/core_ut/config_initiator_ut/test_validate_params.py b/msprobe/test/UT/core_ut/config_initiator_ut/test_validate_params.py new file mode 100644 index 000000000..d0b0975ea --- /dev/null +++ b/msprobe/test/UT/core_ut/config_initiator_ut/test_validate_params.py @@ -0,0 +1,269 @@ +import unittest +from unittest.mock import MagicMock, patch + +from msprobe.core.config_initiator.validate_params import ( + OfflineModelInput, + valid_data_mode, + valid_device, + valid_dump_extra, + valid_dump_ge_graph, + valid_dump_graph_level, + valid_dump_path, + valid_fusion_switch_file, + valid_input, + valid_list, + valid_onnx_fusion_switch, + valid_op_id, + valid_saved_model_signature, + valid_saved_model_tag, +) +from msprobe.utils.exceptions import MsprobeException + + +class TestValidators(unittest.TestCase): + @patch("msprobe.core.config_initiator.validate_params.SafePath") + def test_valid_dump_path(self, mock_msit_path): + mock_instance = MagicMock() + mock_instance.check.return_value = "valid" + mock_msit_path.return_value = mock_instance + result = valid_dump_path("some/path") + self.assertEqual(result, "valid") + mock_msit_path.assert_called_once() + + def test_valid_list_invalid_dict_key(self): + value = ({"bad": [1]}, ["allowed"]) + with self.assertRaises(MsprobeException): + valid_list(value) + + def test_valid_list_invalid_list(self): + value = ({"level1": 12}, ["level1", "level2"]) + with self.assertRaises(MsprobeException): + valid_list(value) + + def test_invalid_list(self): + value = (12, ["level1", "level2"]) + with self.assertRaises(MsprobeException): + valid_list(value) + + def test_valid_data_mode_none(self): + result = valid_data_mode([]) + self.assertEqual(result, []) + + @patch("msprobe.core.config_initiator.validate_params.DumpConst.ALL_DATA_MODE", ["mode1"]) + def test_valid_data_mode_valid(self): + result = valid_data_mode(["mode1"]) + self.assertEqual(result, ["mode1"]) + + @patch("msprobe.core.config_initiator.validate_params.DumpConst.ALL_DATA_MODE", ["mode1"]) + def test_valid_data_mode_invalid(self): + with self.assertRaises(MsprobeException): + valid_data_mode(["invalid"]) + + @patch("msprobe.core.config_initiator.validate_params.DumpConst.ALL_DATA_MODE", ["mode1"]) + def test_invalid_data_mode(self): + with self.assertRaises(MsprobeException): + valid_data_mode(12) + + @patch("msprobe.core.config_initiator.validate_params.DumpConst.ALL_DATA_MODE", ["mode1", "mode2"]) + def test_invalid_data_mode_more_element(self): + valid_data_mode(["mode1", "mode2"]) + + @patch("msprobe.core.config_initiator.validate_params.DumpConst.ALL_DUMP_EXTRA", ["extra1"]) + def valid_dump_extra_none(self): + self.assertIsNone(valid_dump_extra(None)) + result = valid_dump_extra([]) + self.assertEqual(result, []) + + @patch("msprobe.core.config_initiator.validate_params.DumpConst.ALL_DUMP_EXTRA", ["extra1"]) + def test_valid_dump_extra_valid(self): + result = valid_dump_extra(["extra1"]) + self.assertEqual(result, ["extra1"]) + + @patch("msprobe.core.config_initiator.validate_params.DumpConst.ALL_DUMP_EXTRA", ["extra1"]) + def test_valid_dump_extra_invalid(self): + with self.assertRaises(MsprobeException): + valid_dump_extra(123) + with self.assertRaises(MsprobeException): + valid_dump_extra(["bad"]) + + def test_valid_op_id_none(self): + result = valid_op_id("") + self.assertEqual(result, "") + + def test_valid_op_id(self): + valid_list = [1, "3_1", "4_2_3"] + result = valid_op_id(valid_list) + self.assertEqual(result, valid_list) + + def test_valid_op_id_invalid_element_format(self): + with self.assertRaises(MsprobeException): + valid_op_id(12) + with self.assertRaises(MsprobeException): + valid_op_id([["invalid"]]) + + def test_valid_dump_ge_graph(self): + self.assertIsNone(valid_dump_ge_graph(None)) + with self.assertRaises(MsprobeException): + valid_dump_ge_graph(123) + with self.assertRaises(MsprobeException): + valid_dump_ge_graph("8") + self.assertEqual(valid_dump_ge_graph("2"), "2") + + def test_valid_dump_graph_level(self): + self.assertIsNone(valid_dump_graph_level(None)) + with self.assertRaises(MsprobeException): + valid_dump_graph_level(123) + with self.assertRaises(MsprobeException): + valid_dump_graph_level("8") + self.assertEqual(valid_dump_graph_level("2"), "2") + + @patch("msprobe.core.config_initiator.validate_params.SafePath") + def test_valid_fusion_switch_file(self, mock_msit_path): + self.assertIsNone(valid_fusion_switch_file(None)) + mock_instance = MagicMock() + mock_instance.check.return_value = "valid" + mock_msit_path.return_value = mock_instance + result = valid_fusion_switch_file("some/path") + self.assertEqual(result, "valid") + mock_msit_path.assert_called_once() + + def test_valid_device(self): + self.assertIsNone(valid_device(None)) + self.assertEqual(valid_device("cpu"), "cpu") + with self.assertRaises(MsprobeException): + valid_device("gpu") + with self.assertRaises(MsprobeException): + valid_device(123) + + def test_valid_onnx_fusion_switch(self): + self.assertIsNone(valid_onnx_fusion_switch(None)) + self.assertTrue(valid_onnx_fusion_switch(True)) + with self.assertRaises(MsprobeException): + valid_onnx_fusion_switch(123) + + def test_valid_saved_model_tag(self): + self.assertIsNone(valid_saved_model_tag(None)) + with self.assertRaises(MsprobeException): + valid_saved_model_tag(123) + with self.assertRaises(MsprobeException): + valid_saved_model_tag(["%qsc/"]) + self.assertEqual(valid_saved_model_tag(["qazx"]), ["qazx"]) + + +class TestValidInputAndOfflineModelInput(unittest.TestCase): + def test_valid_input_none(self): + self.assertIsNone(valid_input(None)) + + @patch("msprobe.core.config_initiator.validate_params.OfflineModelInput") + def test_valid_input_calls_parse(self, mock_input_cls): + mock_parser = MagicMock() + mock_input_cls.return_value = mock_parser + valid_input([{"name": "x"}]) + mock_parser.parse.assert_called_once() + + def test_check_form_not_list(self): + with self.assertRaisesRegex(MsprobeException, "The input must be a list."): + OfflineModelInput("invalid") + + def test_check_form_element_not_dict(self): + with self.assertRaisesRegex(MsprobeException, "Each element in the input must be a dictionary."): + OfflineModelInput([1, 2]) + + def test_check_name_missing(self): + with self.assertRaisesRegex(MsprobeException, "Each input must have a name."): + OfflineModelInput([{}])._check_name({}) + + def test_check_input_shape_invalid_type(self): + with self.assertRaisesRegex(MsprobeException, "must be a list"): + OfflineModelInput([{}])._check_input_shape({"shape": "not_list"}, "input1") + + def test_check_input_shape_element_not_int(self): + with self.assertRaisesRegex(MsprobeException, "Expected int type"): + OfflineModelInput([{}])._check_input_shape({"shape": [1, "a"]}, "input1") + + @patch("msprobe.core.config_initiator.validate_params.SafePath") + def test_check_input_path_invalid_type(self, mock_path): + with self.assertRaisesRegex(MsprobeException, "must be a string"): + OfflineModelInput([{}])._check_input_path({"path": 123}, "input1") + + @patch("msprobe.core.config_initiator.validate_params.SafePath") + def test_check_input_path_invalid_suffix(self, mock_path): + with self.assertRaisesRegex(MsprobeException, "can only accept .npy or .bin"): + OfflineModelInput([{}])._check_input_path({"path": "file.txt"}, "input1") + + @patch("msprobe.core.config_initiator.validate_params.SafePath") + def test_check_input_path_valid(self, mock_path): + mock_check = mock_path.return_value.check + mock_check.return_value = True + self.assertIsNone(OfflineModelInput([{}])._check_input_path({"path": "input.npy"}, "input1")) + + @patch("msprobe.core.config_initiator.validate_params.parse_hyphen") + def test_parse_shape_range_for_str_hyphen(self, mock_parse): + mock_parse.return_value = [1, 2] + result = OfflineModelInput([{}])._parse_shape_range_for_str("1-2") + self.assertEqual(result, [1, 2]) + + def test_parse_shape_range_for_str_comma_valid(self): + result = OfflineModelInput([{}])._parse_shape_range_for_str("2,3") + self.assertEqual(result, [2, 3]) + + def test_parse_shape_range_for_str_invalid_format(self): + with self.assertRaisesRegex(MsprobeException, "can only contain hyphen"): + OfflineModelInput([{}])._parse_shape_range_for_str("wrong") + + @patch("msprobe.core.config_initiator.validate_params.check_int_border") + @patch("msprobe.core.config_initiator.validate_params.OfflineModelInput._parse_shape_range_for_str") + def test_parse_dyn_shape_range_mixed(self, mock_parse, mock_check): + mock_parse.return_value = [1, 2] + input_obj = OfflineModelInput([{}]) + result = input_obj._parse_dyn_shape_range(["1-2", 3], "input1") + self.assertIsInstance(result, list) + + def test_parse_dyn_shape_range_invalid_type(self): + with self.assertRaisesRegex(MsprobeException, "must be a list"): + OfflineModelInput([{}])._parse_dyn_shape_range("not_list", "input1") + + def test_parse_dyn_shape_range_element_type_error(self): + with self.assertRaisesRegex(MsprobeException, "support only string and integers"): + OfflineModelInput([{}])._parse_dyn_shape_range([1.5], "input1") + + @patch("msprobe.core.config_initiator.validate_params.logger") + @patch.object(OfflineModelInput, "_parse_dyn_shape_range") + def test_check_dyn_shape_with_path(self, mock_parse_dym, mock_logger): + mock_parse_dym.return_value = [[1, 2], [3, 4]] + input_obj = OfflineModelInput([{}]) + result = input_obj._check_dyn_shape({"name": "x", "dyn_shape": ["1-2"], "path": "input.npy"}, "x") + self.assertEqual(result["path"], "") + self.assertEqual(result["shape"], []) + + def test_draw_shape_and_path_static(self): + input_obj = OfflineModelInput([{}]) + input_obj.is_need_expand_shape = False + shapes, paths = input_obj._draw_shape_and_path([{"name": "x", "shape": [1, 2], "path": "x.npy"}]) + self.assertEqual(shapes["x"], [1, 2]) + self.assertEqual(paths, ["x.npy"]) + + def test_draw_shape_and_path_dynamic_valid(self): + input_obj = OfflineModelInput([{}]) + input_obj.is_need_expand_shape = True + input_data = [{"name": "x", "dyn_shape": [[1], [2]]}, {"name": "y", "dyn_shape": [[3], [4]]}] + shapes, paths = input_obj._draw_shape_and_path(input_data) + self.assertEqual(len(shapes), 2) + self.assertIsNone(paths) + + def test_draw_shape_and_path_dynamic_invalid(self): + input_obj = OfflineModelInput([{}]) + input_obj.is_need_expand_shape = True + input_data = [{"name": "x", "dyn_shape": [[1], [2], [3]]}, {"name": "y", "dyn_shape": [[4], [5]]}] + with self.assertRaisesRegex(MsprobeException, "same expanded dynamic shape length"): + input_obj._draw_shape_and_path(input_data) + + @patch.object(OfflineModelInput, "_check_name", return_value="x") + @patch.object(OfflineModelInput, "_check_input_shape") + @patch.object(OfflineModelInput, "_check_input_path") + @patch.object(OfflineModelInput, "_check_dyn_shape", side_effect=lambda x, y: x) + @patch.object(OfflineModelInput, "_draw_shape_and_path", return_value=({}, [])) + def test_parse_calls_all(self, mock_draw, mock_dym, mock_path, mock_shape, mock_name): + input_obj = OfflineModelInput([{"name": "x"}]) + result = input_obj.parse() + self.assertEqual(result, ({}, [])) diff --git a/msprobe/test/UT/core_ut/dump_ut/test_onnx_model.py b/msprobe/test/UT/core_ut/dump_ut/test_onnx_model.py new file mode 100644 index 000000000..4c0a2b574 --- /dev/null +++ b/msprobe/test/UT/core_ut/dump_ut/test_onnx_model.py @@ -0,0 +1,90 @@ +import unittest +from unittest.mock import MagicMock, patch + +from msprobe.core.dump import OnnxModelActuator +from msprobe.utils.exceptions import MsprobeException + + +class TestOnnxModelActuator(unittest.TestCase): + + @patch("msprobe.core.dump.onnx_model.load_onnx_model") + @patch("msprobe.core.dump.onnx_model.load_onnx_session") + def test_load_model_success(self, mock_load_session, mock_load_model): + actuator = OnnxModelActuator("model.onnx", None, None) + actuator.kwargs = {"onnx_fusion_switch": False} + actuator.load_model() + mock_load_model.assert_called_once_with("model.onnx") + mock_load_session.assert_called_once_with("model.onnx", False) + + @patch("msprobe.core.dump.onnx_model.load_onnx_session") + def test_infer_success(self, mock_load_session): + mock_session = MagicMock() + mock_session.get_outputs.return_value = [MagicMock(name="out1"), MagicMock(name="out2")] + mock_session.run.return_value = ["result"] + mock_load_session.return_value = mock_session + result = OnnxModelActuator.infer("model.onnx", {"input": "data"}) + self.assertEqual(result, ["result"]) + mock_session.run.assert_called_once() + + @patch("msprobe.core.dump.onnx_model.load_onnx_session") + def test_infer_failure_raises_exception(self, mock_load_session): + mock_session = MagicMock() + mock_session.get_outputs.return_value = [MagicMock(name="out")] + mock_session.run.side_effect = RuntimeError("Failed") + mock_load_session.return_value = mock_session + with self.assertRaises(MsprobeException): + OnnxModelActuator.infer("model.onnx", {"input": "data"}) + + @patch("msprobe.core.dump.onnx_model.logger") + def test_get_input_tensor_info(self, mock_logger): + actuator = OnnxModelActuator("model.onnx", None, None) + input_mock = MagicMock() + input_mock.name = "input" + input_mock.type = "float32" + input_mock.shape = [1, 3, 224, 224] + actuator.model_session = MagicMock() + actuator.model_session.get_inputs.return_value = [input_mock] + actuator.process_tensor_shape = MagicMock(return_value=[("input", "float32", (1, 3, 224, 224))]) + result = actuator.get_input_tensor_info() + self.assertEqual(result, [("input", "float32", (1, 3, 224, 224))]) + mock_logger.info.assert_called() + + @patch("msprobe.core.dump.onnx_model.save_onnx_model") + @patch("msprobe.core.dump.onnx_model.convert_bytes", return_value="1MB") + @patch("msprobe.core.dump.onnx_model.is_file", return_value=False) + @patch("msprobe.core.dump.onnx_model.join_path", return_value="/mock/dir/infer_model.onnx") + @patch("msprobe.core.dump.onnx_model.get_basename_from_path", return_value="base_model.onnx") + @patch("msprobe.core.dump.onnx_model.dependent.get") + @patch("msprobe.core.dump.onnx_model.logger") + def test_export_uninfer_model( + self, mock_logger, mock_get_dep, mock_basename, mock_join_path, mock_is_file, mock_convert, mock_save_model + ): + # Setup mock ONNX graph structure + onnx_mock = MagicMock() + onnx_mock.ValueInfoProto.side_effect = lambda name: f"vi({name})" + mock_get_dep.return_value = onnx_mock + actuator = OnnxModelActuator("model.onnx", None, None) + actuator.dir_pool = MagicMock() + actuator.dir_pool.get_model_dir.return_value = "/mock/dir" + node1 = MagicMock() + node1.output = ["out1", "out2"] + actuator.origin_model = MagicMock() + actuator.origin_model.graph.node = [node1] + actuator.origin_model.graph.output = [] + actuator.origin_model.ByteSize.return_value = 123456 + uninfer_path = actuator.export_uninfer_model() + self.assertEqual(uninfer_path, "/mock/dir/infer_model.onnx") + self.assertEqual(len(actuator.origin_model.graph.output), 2) + mock_save_model.assert_called_once() + mock_logger.info.assert_any_call("The size of the modified ONNX model to be saved is 1MB.") + + @patch("msprobe.core.dump.onnx_model.is_file", return_value=True) + @patch("msprobe.core.dump.onnx_model.join_path", return_value="/mock/dir/infer_model.onnx") + @patch("msprobe.core.dump.onnx_model.get_basename_from_path", return_value="base_model.onnx") + def test_export_uninfer_model_skips_if_file_exists(self, mock_basename, mock_join, mock_is_file): + actuator = OnnxModelActuator("model.onnx", None, None) + actuator.dir_pool = MagicMock() + actuator.dir_pool.get_model_dir.return_value = "/mock/dir" + actuator.origin_model = MagicMock() + uninfer_path = actuator.export_uninfer_model() + self.assertEqual(uninfer_path, "/mock/dir/infer_model.onnx") diff --git a/msprobe/test/UT/core_ut/dump_ut/test_tf_model.py b/msprobe/test/UT/core_ut/dump_ut/test_tf_model.py new file mode 100644 index 000000000..0afce08bc --- /dev/null +++ b/msprobe/test/UT/core_ut/dump_ut/test_tf_model.py @@ -0,0 +1,234 @@ +import unittest +from unittest.mock import MagicMock, patch + +import numpy as np + +from msprobe.core.dump.tf_model import ( + FrozenGraphActuator, + FrozenGraphActuatorCPU, + FrozenGraphActuatorNPU, +) +from msprobe.utils.exceptions import MsprobeException + + +class TestFrozenGraphActuator(unittest.TestCase): + + @patch("msprobe.core.dump.tf_model.dependent.get_tensorflow") + def setUp(self, mock_get_tf): + mock_tf = MagicMock() + mock_rewriter_config = MagicMock() + mock_get_tf.return_value = (mock_tf, mock_rewriter_config, None) + self.actuator = FrozenGraphActuator( + model_path="fake_model.pb", input_shape=(1, 224, 224, 3), input_path="input.npy" + ) + self.tf = mock_tf + self.actuator.tf = mock_tf + + @patch("msprobe.core.dump.tf_model.dependent.get_tensorflow") + def test_import_tf_success(self, mock_get_tf): + mock_tf = MagicMock() + mock_rewriter = MagicMock() + mock_get_tf.return_value = (mock_tf, mock_rewriter, "extra") + tf, rewriter = FrozenGraphActuator._import_tf() + self.assertEqual(tf, mock_tf) + self.assertEqual(rewriter, mock_rewriter) + mock_tf.compat.v1.disable_eager_execution.assert_called_once() + + @patch("msprobe.core.dump.tf_model.dependent.get_tensorflow") + def test_import_tf_none(self, mock_get_tf): + mock_get_tf.return_value = (None, None, None) + tf, rewriter = FrozenGraphActuator._import_tf() + self.assertIsNone(tf) + self.assertIsNone(rewriter) + + @patch("msprobe.core.dump.tf_model.load_pb_frozen_graph_model") + def test_load_model(self, mock_load_pb): + mock_graph_def = MagicMock() + mock_load_pb.return_value = mock_graph_def + self.actuator.load_model() + mock_load_pb.assert_called_once_with("fake_model.pb") + self.assertEqual(self.actuator.graph_def, mock_graph_def) + + def test_get_tensor_name(self): + name = FrozenGraphActuator._get_tensor_name("input:0") + self.assertEqual(name, "input") + name = FrozenGraphActuator._get_tensor_name("no_colon") + self.assertEqual(name, "no_colon") + + def test_tf_shape_to_list(self): + mock_shape = MagicMock() + dim1 = MagicMock(size=1) + dim2 = MagicMock(size=-1) + dim3 = MagicMock(size=3) + mock_shape.dim = [dim1, dim2, dim3] + result = FrozenGraphActuator._tf_shape_to_list(mock_shape) + self.assertEqual(result, [1, None, 3]) + + def test_get_input_tensor_info(self): + mock_dtype = MagicMock() + mock_dtype.type = 1 + mock_tensor_shape = MagicMock() + mock_tensor_shape.dim = [MagicMock(size=1), MagicMock(size=224)] + node = MagicMock() + node.name = "input_node" + node.op = "Placeholder" + node.attr = {"dtype": mock_dtype, "shape": MagicMock(shape=mock_tensor_shape)} + self.actuator.graph_def = MagicMock() + self.actuator.graph_def.node = [node] + self.actuator.tf.dtypes.as_dtype.return_value = "float32" + self.actuator.process_tensor_shape = MagicMock( + return_value=[{"name": "input_node", "shape": [1, 224], "type": "float32"}] + ) + result = self.actuator.get_input_tensor_info() + self.assertEqual(len(result), 1) + self.assertIn("input_node", self.actuator.all_node_names) + + def test_close_session(self): + mock_sess = MagicMock() + self.actuator.sess = mock_sess + self.actuator.close() + mock_sess.close.assert_called_once() + self.assertIsNone(self.actuator.sess) + + def test_close_session_no_attr(self): + self.actuator.sess = None + try: + self.actuator.close() + except Exception as e: + self.fail(f"close() raised an exception unexpectedly: {e}") + + def test_get_tf_ops_success(self): + self.actuator.all_node_names = ["input"] + mock_graph = MagicMock() + tensor = MagicMock() + mock_graph.get_tensor_by_name.return_value = tensor + + self.actuator.sess = MagicMock() + self.actuator.sess.graph = mock_graph + + ops = self.actuator._get_tf_ops() + self.assertEqual(len(ops), 1) + self.assertEqual(ops[0], tensor) + + def test_get_tf_ops_failure(self): + self.actuator.all_node_names = ["bad_node"] + self.actuator.sess = MagicMock() + self.actuator.sess.graph.get_tensor_by_name.side_effect = Exception("fail") + + with self.assertRaises(MsprobeException): + self.actuator._get_tf_ops() + + def test_build_feed_success(self): + tensor = MagicMock() + input_map = {"input": np.ones((1, 224, 224, 3))} + + self.actuator.sess = MagicMock() + self.actuator.sess.graph.get_tensor_by_name.return_value = tensor + + feed_dict = self.actuator._build_feed(input_map) + self.assertEqual(feed_dict[tensor].shape, (1, 224, 224, 3)) + + def test_build_feed_failure(self): + input_map = {"bad_input": np.zeros((1,))} + + self.actuator.sess = MagicMock() + self.actuator.sess.graph.get_tensor_by_name.side_effect = Exception("fail") + + with self.assertRaises(MsprobeException): + self.actuator._build_feed(input_map) + + def test_infer_success(self): + mock_sess = MagicMock() + mock_sess.run.return_value = ["result"] + self.actuator._open_session = MagicMock(return_value=mock_sess) + self.actuator._renew_all_node_names = MagicMock() + self.actuator._get_tf_ops = MagicMock(return_value=["fake_op"]) + self.actuator._build_feed = MagicMock(return_value={"input": "fake_data"}) + self.actuator.close = MagicMock() + result = self.actuator.infer({"input": "data"}) + self.assertEqual(result, ["result"]) + mock_sess.run.assert_called_once() + + def test_infer_failure(self): + mock_sess = MagicMock() + mock_sess.run.side_effect = RuntimeError("bad inference") + self.actuator._open_session = MagicMock(return_value=mock_sess) + self.actuator._renew_all_node_names = MagicMock() + self.actuator._get_tf_ops = MagicMock(return_value=["fake_op"]) + self.actuator._build_feed = MagicMock(return_value={"input": "fake_data"}) + self.actuator.close = MagicMock() + with self.assertRaises(MsprobeException) as context: + self.actuator.infer({"input": "data"}) + self.assertIn("input shape or data", str(context.exception)) + + +class TestFrozenGraphActuatorCPU(unittest.TestCase): + @patch("msprobe.core.dump.tf_model.FrozenGraphActuator._import_tf") + def test_open_session(self, mock_import_tf): + mock_tf = MagicMock() + mock_tf.compat.v1.Session.return_value = "mock_session" + mock_import_tf.return_value = (mock_tf, MagicMock()) + actuator = FrozenGraphActuatorCPU("model", {}, "input") + session = actuator._open_session() + self.assertEqual(session, "mock_session") + + +class TestFrozenGraphActuatorNPU(unittest.TestCase): + def setUp(self): + self.kwargs = {"data_mode": ["all"], "fsf": "mock_fusion_file.txt"} + self.actuator = FrozenGraphActuatorNPU("mock_model.pb", None, None, **self.kwargs) + self.actuator.tf = MagicMock() + self.actuator.rewriter_config = MagicMock() + self.actuator.dir_pool = MagicMock() + self.actuator.dir_pool.get_model_dir.return_value = "/mock/model_dir" + self.actuator.dir_pool.get_rank_dir.return_value = "/mock/rank_dir" + + @patch("msprobe.core.dump.tf_model.glob") + @patch("msprobe.core.dump.tf_model.get_name_and_ext", return_value=("mock_model", ".txt")) + @patch("msprobe.core.dump.tf_model.join_path", side_effect=lambda *args: "/".join(args)) + @patch("msprobe.core.dump.tf_model.cann.model2json") + def test_convert_txt2json_success(self, mock_model2json, mock_join, mock_get_name, mock_glob): + mock_glob.return_value = ["/mock/model_dir/subdir/mock_Build.txt"] + self.actuator.convert_txt2json() + mock_model2json.assert_called_once_with( + "/mock/model_dir/subdir/mock_Build.txt", "/mock/model_dir/mock_model.json" + ) + + @patch("msprobe.core.dump.tf_model.glob", return_value=[]) + def test_convert_txt2json_not_found(self, mock_glob): + with self.assertRaises(MsprobeException) as cm: + self.actuator.convert_txt2json() + self.assertIn("No TXT format graph file found", str(cm.exception)) + + @patch("msprobe.core.dump.tf_model.logger") + @patch("msprobe.core.dump.tf_model.dependent.get") + def test_open_session_with_fusion_file(self, mock_dependent_get, mock_logger): + mock_session_cls = MagicMock() + mock_npu_device = MagicMock() + mock_npu_device.compat.enable_v1 = MagicMock() + mock_dependent_get.return_value = mock_npu_device + + config_proto_mock = MagicMock() + self.actuator.tf.compat.v1.ConfigProto.return_value = config_proto_mock + self.actuator.tf.compat.v1.Session.return_value = "mock_session" + + custom_op = MagicMock() + config_proto_mock.graph_options.rewrite_options.custom_optimizers.add.return_value = custom_op + + session = self.actuator._open_session() + + self.assertEqual(session, "mock_session") + mock_npu_device.compat.enable_v1.assert_called_once() + mock_logger.info.assert_called_with("Fusion switch settings read from mock_fusion_file.txt.") + + @patch("msprobe.core.dump.tf_model.dependent.get", return_value=None) + def test_open_session_npu_device_not_installed(self, mock_dependent_get): + with self.assertRaises(MsprobeException) as cm: + self.actuator._open_session() + self.assertIn("npu_device is properly installed", str(cm.exception)) + + @patch("msprobe.core.dump.tf_model.get_net_output_nodes_from_graph_def", return_value=["node1", "node2"]) + def test_renew_all_node_names(self, mock_get_nodes): + self.actuator.graph_def = MagicMock() + self.actuator._renew_all_node_names() + self.assertEqual(self.actuator.all_node_names, ["node1", "node2"]) -- Gitee