diff --git a/profiler/msprof_analyze/cluster_analyse/common_func/excel_utils.py b/profiler/msprof_analyze/cluster_analyse/common_func/excel_utils.py index 3ca5308b5963ab9e408adf433711028c06a5fc4e..a39f03e0e3e000ce5bcec71d758f4ed24a548185 100644 --- a/profiler/msprof_analyze/cluster_analyse/common_func/excel_utils.py +++ b/profiler/msprof_analyze/cluster_analyse/common_func/excel_utils.py @@ -12,10 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os from typing import List, Dict, Optional import pandas as pd +from msprof_analyze.prof_common.path_manager import PathManager from msprof_analyze.prof_common.logger import get_logger @@ -53,24 +54,36 @@ class ExcelUtils: self.df = None self._formats_cache = {} - def create_excel_writer(self, output_file: str, + def create_excel_writer(self, + output_path: str, + file_name: str, df: pd.DataFrame, - column_format: Optional[Dict] = None, - header_format: Optional[Dict] = None, - sheet_name: str = 'Sheet1'): - """初始化Excel写入器并写入原始数据""" - self.writer = pd.ExcelWriter(output_file, engine='xlsxwriter') + sheet_name: str = 'Sheet1', + format_config: Optional[Dict[str, Dict]] = None) -> None: + """初始化Excel写入器并写入原始数据 + Args: + output_path: 输出目录路径 + file_name: 输出文件名 + df: 要写入的DataFrame数据 + sheet_name: 工作表名称 (可选,默认为'Sheet1') + format_config: 格式化配置字典 (可选) + - header: 标题行格式 (可选) + - column: 数据列格式 (可选) + """ + PathManager.check_path_writeable(output_path) + self.writer = pd.ExcelWriter(os.path.join(output_path, file_name), engine='xlsxwriter') self.workbook = self.writer.book self.worksheet = self.workbook.add_worksheet(sheet_name) self.df = df # 写入标题行 - header_fmt = self._get_format(header_format if header_format else self.DEFAULT_HEADER_FORMAT) + format_config = format_config or {} + header_fmt = self._get_format(format_config.get('header', self.DEFAULT_HEADER_FORMAT)) for col_idx, 
col_name in enumerate(df.columns): self.worksheet.write(0, col_idx, col_name, header_fmt) # 写入数据行 - default_fmt = self._get_format(column_format if column_format else self.DEFAULT_FORMAT) + default_fmt = self._get_format(format_config.get('column', self.DEFAULT_FORMAT)) for row_idx, row in df.iterrows(): for col_idx, col_name in enumerate(df.columns): self.worksheet.write(row_idx + 1, col_idx, row[col_name], default_fmt) diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/module_statistic/module_statistic.py b/profiler/msprof_analyze/cluster_analyse/recipes/module_statistic/module_statistic.py index 91a354f81ae0838e8e96688fa47e2a8e45057060..72601f0aaa6ea6aa21683b1b02765db47012237e 100644 --- a/profiler/msprof_analyze/cluster_analyse/recipes/module_statistic/module_statistic.py +++ b/profiler/msprof_analyze/cluster_analyse/recipes/module_statistic/module_statistic.py @@ -110,9 +110,8 @@ class ModuleStatistic(BaseRecipeAnalysis): logger.warning(f"No module analysis result for rank {rank_id}, skipping dump data") continue file_name = f"module_statistic_{rank_id}.xlsx" - file_path = os.path.join(self.output_path, file_name) try: - excel_utils.create_excel_writer(file_path, stat_df) + excel_utils.create_excel_writer(self.output_path, file_name, stat_df) excel_utils.merge_duplicate_cells(columns_to_merge) excel_utils.set_column_width(column_width_config) excel_utils.set_row_height(0, 27) # 标题行行高27 diff --git a/profiler/msprof_analyze/test/ut/cluster_analyse/common_func/test_excel_utils.py b/profiler/msprof_analyze/test/ut/cluster_analyse/common_func/test_excel_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fe5936fc2545ecdd5b9873659449c3bdaba9658c --- /dev/null +++ b/profiler/msprof_analyze/test/ut/cluster_analyse/common_func/test_excel_utils.py @@ -0,0 +1,174 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import unittest +from unittest.mock import patch +import pandas as pd + +from msprof_analyze.cluster_analyse.common_func.excel_utils import ExcelUtils + + +class TestExcelUtils(unittest.TestCase): + def setUp(self): + """在每个测试方法前执行""" + self.sample_df = pd.DataFrame({ + 'A': ['foo', 'foo', 'bar', 'bar', 'baz'], + 'B': [1, 1, 2, 2, 3], + 'C': ['x', 'x', 'y', 'y', 'z'] + }) + self.excel_utils = ExcelUtils() + self.tmp_dir = "test_tmp" # 实际使用时应该用 pytest 的 tmp_path + os.makedirs(self.tmp_dir, exist_ok=True) + + def tearDown(self): + """在每个测试方法后执行""" + self.excel_utils.clear() + # 清理临时文件 + for filename in os.listdir(self.tmp_dir): + file_path = os.path.join(self.tmp_dir, filename) + if os.path.isfile(file_path): + os.unlink(file_path) + os.rmdir(self.tmp_dir) + + def test_create_excel_writer_when_given_valid_df_then_excel_created(self): + """测试 create_excel_writer 能正确写入 DataFrame 到 Excel 文件。""" + file_name = 'test.xlsx' + expected_path = os.path.join(self.tmp_dir, file_name) + self.excel_utils.create_excel_writer(output_path=self.tmp_dir, file_name=file_name, df=self.sample_df) + self.excel_utils.save_and_close() + self.assertTrue(os.path.exists(expected_path)) + + def test_set_column_width_when_called_then_column_width_set(self): + """测试 set_column_width 能正确设置列宽。""" + file_name = 'test_col_width.xlsx' + expected_path = os.path.join(self.tmp_dir, file_name) + 
self.excel_utils.create_excel_writer(output_path=self.tmp_dir, file_name=file_name, df=self.sample_df) + self.excel_utils.set_column_width({'A': 20, 'B': 10, 'C': 15}) + self.excel_utils.save_and_close() + self.assertTrue(os.path.exists(expected_path)) + + def test_set_row_height_when_called_then_row_height_set(self): + """测试 set_row_height 能正确设置行高。""" + file_name = 'test_row_height.xlsx' + expected_path = os.path.join(self.tmp_dir, file_name) + self.excel_utils.create_excel_writer(output_path=self.tmp_dir, file_name=file_name, df=self.sample_df) + self.excel_utils.set_row_height(1, 30) + self.excel_utils.save_and_close() + self.assertTrue(os.path.exists(expected_path)) + + def test_freeze_panes_when_called_then_panes_frozen(self): + """测试 freeze_panes 能正确冻结窗格。""" + file_name = 'test_freeze.xlsx' + expected_path = os.path.join(self.tmp_dir, file_name) + self.excel_utils.create_excel_writer(output_path=self.tmp_dir, file_name=file_name, df=self.sample_df) + self.excel_utils.freeze_panes(1, 1) + self.excel_utils.save_and_close() + self.assertTrue(os.path.exists(expected_path)) + + def test_merge_duplicate_cells_when_duplicates_exist_then_cells_merged(self): + """测试 merge_duplicate_cells 能合并连续相同值的单元格。""" + file_name = 'test_merge.xlsx' + expected_path = os.path.join(self.tmp_dir, file_name) + self.excel_utils.create_excel_writer(output_path=self.tmp_dir, file_name=file_name, df=self.sample_df) + self.excel_utils.merge_duplicate_cells(['A']) + self.excel_utils.save_and_close() + self.assertTrue(os.path.exists(expected_path)) + + def test_clear_when_called_then_resources_released(self): + """测试 clear 能释放资源并允许实例复用。""" + file_name = 'test_clear.xlsx' + self.excel_utils.create_excel_writer(output_path=self.tmp_dir, file_name=file_name, df=self.sample_df) + self.excel_utils.clear() + self.assertIsNone(self.excel_utils.writer) + self.assertIsNone(self.excel_utils.workbook) + self.assertIsNone(self.excel_utils.worksheet) + self.assertIsNone(self.excel_utils.df) + 
self.assertEqual(self.excel_utils._formats_cache, {}) + + def test_get_format_when_called_multiple_times_then_cache_used(self): + """测试 _get_format 多次调用同一格式时使用缓存。""" + file_name = 'test_format.xlsx' + self.excel_utils.create_excel_writer(output_path=self.tmp_dir, file_name=file_name, df=self.sample_df) + fmt1 = self.excel_utils._get_format({'valign': 'vcenter'}) + fmt2 = self.excel_utils._get_format({'valign': 'vcenter'}) + self.assertIs(fmt1, fmt2) + + def test_set_column_width_when_worksheet_not_initialized_then_raises_exception(self): + """测试 set_column_width 在worksheet未初始化时抛出异常。""" + with self.assertRaises(Exception) as context: + self.excel_utils.set_column_width({'A': 20}) + self.assertIn("Worksheet has not been initialized", str(context.exception)) + + def test_set_row_height_when_worksheet_not_initialized_then_raises_exception(self): + """测试 set_row_height 在worksheet未初始化时抛出异常。""" + with self.assertRaises(Exception) as context: + self.excel_utils.set_row_height(1, 30) + self.assertIn("Worksheet not initialized", str(context.exception)) + + def test_freeze_panes_when_worksheet_not_initialized_then_raises_exception(self): + """测试 freeze_panes 在worksheet未初始化时抛出异常。""" + with self.assertRaises(Exception) as context: + self.excel_utils.freeze_panes(1, 1) + self.assertIn("Worksheet has not been initialized", str(context.exception)) + + def test_merge_duplicate_cells_when_worksheet_not_initialized_then_raises_exception(self): + """测试 merge_duplicate_cells 在worksheet未初始化时抛出异常。""" + with self.assertRaises(Exception) as context: + self.excel_utils.merge_duplicate_cells(['A']) + self.assertIn("Worksheet has not been initialized", str(context.exception)) + + @patch("msprof_analyze.cluster_analyse.common_func.excel_utils.logger") + def test_merge_duplicate_cells_when_invalid_column_then_warns_and_continues(self, mock_logger): + """测试 merge_duplicate_cells 在无效列名时发出警告但继续执行。""" + file_name = 'test_invalid_column.xlsx' + expected_path = os.path.join(self.tmp_dir, file_name) + 
self.excel_utils.create_excel_writer(output_path=self.tmp_dir, file_name=file_name, df=self.sample_df) + # 测试不存在的列名 + self.excel_utils.merge_duplicate_cells(['InvalidColumn']) + self.excel_utils.save_and_close() + self.assertTrue(os.path.exists(expected_path)) + self.assertTrue(mock_logger.warning.called) + + def test_create_excel_writer_when_invalid_file_path_then_raises_exception(self): + """测试 create_excel_writer 在无效文件路径时抛出异常。""" + invalid_dir = os.path.join(self.tmp_dir, 'not_exist_dir') # 目录不存在 + with self.assertRaises(RuntimeError): + self.excel_utils.create_excel_writer(output_path=invalid_dir, file_name='test.xlsx', df=self.sample_df) + + def test_create_excel_writer_when_empty_dataframe_then_creates_empty_excel(self): + """测试 create_excel_writer 处理空DataFrame的情况。""" + empty_df = pd.DataFrame() + file_name = 'test_empty.xlsx' + expected_path = os.path.join(self.tmp_dir, file_name) + self.excel_utils.create_excel_writer(output_path=self.tmp_dir, file_name=file_name, df=empty_df) + self.excel_utils.save_and_close() + self.assertTrue(os.path.exists(expected_path)) + + def test_clear_when_called_multiple_times_then_no_error(self): + """测试 clear 方法被多次调用时不会出错。""" + self.excel_utils.clear() + self.excel_utils.clear() # 再次调用 + self.assertIsNone(self.excel_utils.writer) + self.assertIsNone(self.excel_utils.workbook) + + def test_save_and_close_when_writer_not_initialized_then_no_error(self): + """测试 save_and_close 在writer未初始化时不会出错。""" + self.excel_utils.save_and_close() + self.assertIsNone(self.excel_utils.writer) + + def test_get_format_when_workbook_not_initialized_then_raises_exception(self): + """测试 _get_format 在workbook未初始化时抛出异常。""" + with self.assertRaises(AttributeError): + self.excel_utils._get_format({'valign': 'vcenter'}) \ No newline at end of file diff --git a/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_module_statistic.py b/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_module_statistic.py new file mode 100644 index 
0000000000000000000000000000000000000000..ad7e6ed1d8abec5c059ce6e023bed15e7833183d --- /dev/null +++ b/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_module_statistic.py @@ -0,0 +1,287 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import patch, MagicMock +import pandas as pd + +from msprof_analyze.cluster_analyse.recipes.module_statistic.module_statistic import ModuleStatistic, NodeType, TreeNode +from msprof_analyze.prof_common.constant import Constant + + +class TestModuleStatistic(unittest.TestCase): + def setUp(self): + self.params = { + Constant.COLLECTION_PATH: "/tmp", + Constant.DATA_MAP: {}, + Constant.RECIPE_NAME: "ModuleStatistic", + Constant.PARALLEL_MODE: "concurrent", + Constant.EXPORT_TYPE: Constant.DB, + Constant.CLUSTER_ANALYSIS_OUTPUT_PATH: "/tmp/output", + Constant.RANK_LIST: Constant.ALL, + } + self.analysis = ModuleStatistic(self.params) + + @patch("msprof_analyze.cluster_analyse.recipes.module_statistic.module_statistic.ModuleStatistic.mapper_func") + @patch("msprof_analyze.cluster_analyse.recipes.module_statistic.module_statistic.ModuleStatistic.save_db") + @patch("msprof_analyze.cluster_analyse.recipes.module_statistic.module_statistic.ModuleStatistic.save_csv") + def test_run_when_export_type_is_db_then_save_db_called(self, mock_save_csv, mock_save_db, mock_mapper_func): + """测试run方法,export_type为DB时应调用save_db""" 
+ self.analysis._export_type = Constant.DB + mock_mapper_func.return_value = [(0, pd.DataFrame({"a": [1]}))] + self.analysis.run(context=None) + mock_save_db.assert_called() + mock_save_csv.assert_not_called() + + @patch("msprof_analyze.cluster_analyse.recipes.module_statistic.module_statistic.ModuleStatistic.mapper_func") + @patch("msprof_analyze.cluster_analyse.recipes.module_statistic.module_statistic.ModuleStatistic.save_db") + @patch("msprof_analyze.cluster_analyse.recipes.module_statistic.module_statistic.ModuleStatistic.save_csv") + def test_run_when_export_type_is_excel_then_save_csv_called(self, mock_save_csv, mock_save_db, mock_mapper_func): + """测试run方法,export_type为EXCEL时应调用save_csv""" + self.analysis._export_type = Constant.EXCEL + mock_mapper_func.return_value = [(0, pd.DataFrame({"a": [1]}))] + self.analysis.run(context=None) + mock_save_csv.assert_called() + mock_save_db.assert_not_called() + + @patch("msprof_analyze.cluster_analyse.recipes.module_statistic.module_statistic.ModuleStatistic.mapper_func") + @patch("msprof_analyze.cluster_analyse.recipes.module_statistic.module_statistic.ModuleStatistic.save_db") + @patch("msprof_analyze.cluster_analyse.recipes.module_statistic.module_statistic.ModuleStatistic.save_csv") + @patch("msprof_analyze.cluster_analyse.recipes.module_statistic.module_statistic.logger") + def test_run_when_export_type_is_notebook_then_error(self, mock_logger, mock_save_csv, mock_save_db, + mock_mapper_func): + """测试run方法,export_type为notebook时应save_csv/save_db都不调用,logger有error打屏""" + self.analysis._export_type = Constant.NOTEBOOK + mock_mapper_func.return_value = [(0, pd.DataFrame({"a": [1]}))] + self.analysis.run(context=None) + self.assertTrue(mock_logger.error.called) + mock_save_csv.assert_not_called() + mock_save_db.assert_not_called() + + def test_reducer_func_when_mapper_res_is_empty_then_return_none(self): + """测试reducer_func,输入为空时应返回None""" + result = self.analysis.reducer_func([]) + self.assertIsNone(result) + + def 
test_reducer_func_when_mapper_res_has_valid_df_then_concat(self): + """测试reducer_func,输入有有效DataFrame时应返回拼接结果""" + df1 = pd.DataFrame({"a": [1]}) + df2 = pd.DataFrame({"a": [2]}) + res = [(0, df1), (1, df2)] + out = self.analysis.reducer_func(res) + self.assertEqual(len(out), 2) + self.assertIn("rankID", out.columns) + + def test_reducer_func_when_mapper_res_has_empty_df_then_skip_concat_this_df(self): + """测试reducer_func,输入有DataFrame Empty时返回的拼接结果不包含对应Rank""" + df1 = pd.DataFrame({"a": [1]}) + df2 = pd.DataFrame(columns=["a"]) + res = [(0, df1), (1, df2)] + out = self.analysis.reducer_func(res) + self.assertEqual(len(out), 1) + self.assertIn(0, out["rankID"].tolist()) + self.assertNotIn(1, out["rankID"].tolist()) + + @patch("msprof_analyze.cluster_analyse.recipes.module_statistic.module_statistic.logger") + def test_save_db_when_df_is_none_or_empty_then_warn(self, mock_logger): + """测试save_db,df为None或empty时应记录warning""" + self.analysis.save_db(None) + self.analysis.save_db(pd.DataFrame()) + self.assertTrue(mock_logger.warning.called) + + @patch("msprof_analyze.cluster_analyse.recipes.module_statistic.module_statistic.ExcelUtils") + @patch("msprof_analyze.cluster_analyse.recipes.module_statistic.module_statistic.logger") + def test_save_csv_when_stat_df_empty_then_warn(self, mock_logger, mock_excel_utils): + """测试save_csv,stat_df为空时应记录warning""" + mock_excel = MagicMock() + mock_excel_utils.return_value = mock_excel + self.analysis._output_path = "/tmp" + mapper_res = [(0, pd.DataFrame())] + self.analysis.save_csv(mapper_res) + self.assertTrue(mock_logger.warning.called) + + def test_format_stat_df_columns_when_export_type_db_then_rename(self): + """测试_format_stat_df_columns,export_type为DB时应重命名列""" + self.analysis._export_type = Constant.DB + stat_df = pd.DataFrame({ + 'module_parent': ['a'], + 'op_name': ['b'], + 'kernel_list': ['c'], + 'op_count': [1], + 'total_kernel_duration': [2], + 'avg_kernel_duration': [3] + }) + out = 
self.analysis._format_stat_df_columns(stat_df) + self.assertIn('parentModule', out.columns) + self.assertIn('opName', out.columns) + + def test_format_stat_df_columns_when_export_type_excel_then_rename(self): + """测试_format_stat_df_columns,export_type为EXCEL时应重命名列""" + self.analysis._export_type = Constant.EXCEL + stat_df = pd.DataFrame({ + 'module_parent': ['a'], + 'module': ['b'], + 'op_name': ['c'], + 'kernel_list': ['d'], + 'op_count': [1], + 'total_kernel_duration': [2], + 'avg_kernel_duration': [3] + }) + out = self.analysis._format_stat_df_columns(stat_df) + self.assertIn('Parent Module', out.columns) + self.assertIn('Module', out.columns) + + def test_build_node_tree_when_module_and_kernel_df_valid_then_tree_structure_correct(self): + """测试_build_node_tree,输入有效module_df和kernel_df时应正确构建树结构""" + module_df = pd.DataFrame([ + {"startNs": 0, "endNs": 100, "name": "mod1"}, + {"startNs": 10, "endNs": 90, "name": "mod2"} + ]) + kernel_df = pd.DataFrame([ + {"kernel_name": "k1", "kernel_ts": 20, "kernel_end": 30, "op_name": "op1", "op_ts": 15, "op_end": 40}, + {"kernel_name": "k2", "kernel_ts": 31, "kernel_end": 35, "op_name": "op1", "op_ts": 15, "op_end": 40}, + {"kernel_name": "k3", "kernel_ts": 50, "kernel_end": 60, "op_name": "op2", "op_ts": 45, "op_end": 70} + ]) + root = self.analysis._build_node_tree(module_df, kernel_df) + # 根节点应有1个子module节点mod1 + self.assertTrue(hasattr(root, 'children')) + root_child = [c.name for c in root.children] + self.assertEqual(['mod1'], root_child) + + # mod1节点,子节点为mod2 + mod1_node = root.children[0] + self.assertTrue(hasattr(mod1_node, 'children')) + mod1_child = [c.name for c in mod1_node.children] + self.assertEqual(['mod2'], mod1_child) + + # op节点为mod2的子节点 + mod2_node = mod1_node.children[0] + op_names = [c.name for c in mod2_node.children if c.node_type == NodeType.CPU_OP_EVENT] + self.assertEqual(['op1', 'op2'], op_names) + # op节点的children为kernel + op1_nodes = [c for c in mod2_node.children if c.node_type == 1 and c.name == 
'op1'] + self.assertGreater(len(op1_nodes), 0) + op1_kernels = [child.name for child in op1_nodes[0].children] + self.assertEqual(['k1', 'k2'], op1_kernels) + + def test_build_node_tree_when_no_nodes_then_return_none(self): + """测试_build_node_tree,module_df和kernel_df都为空时应返回None""" + module_df = pd.DataFrame(columns=["startNs", "endNs", "name"]) + kernel_df = pd.DataFrame(columns=["kernel_name", "kernel_ts", "kernel_end", "op_name", "op_ts", "op_end"]) + root = self.analysis._build_node_tree(module_df, kernel_df) + self.assertIsNone(root) + + def test_create_root_node_when_module_and_kernel_df_given_then_min_max_used(self): + """测试_create_root_node,能正确取全局min/max作为root的start/end""" + module_df = pd.DataFrame([{"startNs": 10, "endNs": 100, "name": "mod1"}]) + kernel_df = pd.DataFrame([ + {"kernel_name": "k1", "kernel_ts": 5, "kernel_end": 30, "op_name": "op1", "op_ts": 2, "op_end": 40} + ]) + root = self.analysis._create_root_node(module_df, kernel_df) + self.assertEqual(root.start, 2) + self.assertEqual(root.end, 100) + self.assertEqual(root.name, "") + + def test_flatten_tree_to_dataframe_when_tree_has_data_then_return_df(self): + """测试_flatten_tree_to_dataframe,树结构有数据时应返回DataFrame""" + # 构造简单树:root->module->op->kernel + root = TreeNode(0, 100, NodeType.MODULE_EVENT_NODE, "") + mod_parent = TreeNode(10, 90, NodeType.MODULE_EVENT_NODE, "mod_parent") + mod = TreeNode(15, 50, NodeType.MODULE_EVENT_NODE, "mod") + op = TreeNode(20, 30, NodeType.CPU_OP_EVENT, "op") + kernel = TreeNode(21, 29, NodeType.KERNEL_EVENT, "k") + op.add_child(kernel) + mod.add_child(op) + mod_parent.add_child(mod) + root.add_child(mod_parent) + df = self.analysis._flatten_tree_to_dataframe(root) + self.assertFalse(df.empty) + self.assertIn('module', df.columns) + self.assertIn('op_name', df.columns) + self.assertEqual(df.iloc[0]['kernel_list'], 'k') + self.assertEqual(df.iloc[0]['device_time'], 8) + self.assertEqual(df.iloc[0]['module_parent'], '/mod_parent') + + def 
test_flatten_tree_to_dataframe_when_tree_is_empty_then_return_empty_df(self): + """测试_flatten_tree_to_dataframe,树无有效数据时应返回空DataFrame""" + root = TreeNode(0, 100, NodeType.MODULE_EVENT_NODE, "root") + df = self.analysis._flatten_tree_to_dataframe(root) + self.assertTrue(df.empty) + + def test_aggregate_module_operator_stats_when_df_is_empty_then_return_empty(self): + """测试_aggregate_module_operator_stats,df为空时应返回空DataFrame""" + df = pd.DataFrame() + out = self.analysis._aggregate_module_operator_stats(df) + self.assertTrue(out.empty) + + def test_aggregate_module_operator_stats_when_df_valid_then_stat_df_shape_and_columns(self): + """测试_aggregate_module_operator_stats,验证聚合和分组逻辑""" + # 输入有4条数据能聚合成2个相同的seq + df1 = pd.DataFrame({ + 'module_parent': ['p', 'p', 'p', 'p'], + 'module': ['m', 'm', 'm', 'm'], + 'module_start': [0, 0, 20, 20], + 'module_end': [10, 10, 40, 40], + 'op_name': ['op1', 'op2', 'op1', 'op2'], + 'op_start': [1, 5, 25, 30], + 'op_end': [4, 8, 29, 33], + 'kernel_list': ['k1', 'k2', 'k1', 'k2'], + 'device_time': [2.0, 2.0, 3.0, 2.0] + }) + self.analysis._export_type = Constant.EXCEL + out1 = self.analysis._aggregate_module_operator_stats(df1) + self.assertEqual(len(out1), 2) + self.assertIn('Op Name', out1.columns) + self.assertIn('Total Kernel Duration(ns)', out1.columns) + self.assertIn('Avg Kernel Duration(ns)', out1.columns) + + op1_row = out1[out1['Op Name'] == 'op1'].iloc[0] + op2_row = out1[out1['Op Name'] == 'op2'].iloc[0] + self.assertEqual(op1_row['Total Kernel Duration(ns)'], 5.0) + self.assertEqual(op1_row['Op Count'], 2) + self.assertEqual(op2_row['Total Kernel Duration(ns)'], 4.0) + self.assertEqual(op2_row['Op Count'], 2) + + # 输入有4条数据不能聚合 + df2 = pd.DataFrame({ + 'module_parent': ['p', 'p', 'p', 'p'], + 'module': ['m', 'm', 'm', 'm'], + 'module_start': [0, 0, 20, 20], + 'module_end': [10, 10, 40, 40], + 'op_name': ['op1', 'op2', 'op1', 'op3'], + 'op_start': [1, 5, 25, 30], + 'op_end': [4, 8, 29, 33], + 'kernel_list': ['k1', 'k2', 'k1', 
'k3'], + 'device_time': [2.0, 2.0, 3.0, 2.0] + }) + expected_stat_df = pd.DataFrame({ + 'Parent Module': ['p', 'p', 'p', 'p'], + 'Module': ['m', 'm', 'm', 'm'], + 'Op Name': ['op1', 'op2', 'op1', 'op3'], + 'Kernel List': ['k1', 'k2', 'k1', 'k3'], + 'Total Kernel Duration(ns)': [2.0, 2.0, 3.0, 2.0], + 'Avg Kernel Duration(ns)': [2.0, 2.0, 3.0, 2.0], + 'Op Count': [1, 1, 1, 1] + }) + out2 = self.analysis._aggregate_module_operator_stats(df2) + self.assertEqual(len(out2), 4) + # 逐项对比out2和expected_stat_df内容 + self.assertEqual(list(out2.columns), list(expected_stat_df.columns)) + self.assertEqual(len(out2), len(expected_stat_df)) + for i in range(len(out2)): + for col in out2.columns: + self.assertEqual(out2.at[i, col], expected_stat_df.at[i, col], + msg=f"Row {i}, column '{col}' not equal: " + f"{out2.at[i, col]} != {expected_stat_df.at[i, col]}") \ No newline at end of file