From 132db9c5a984572e5b2300ad167e1b4491ce8860 Mon Sep 17 00:00:00 2001 From: fanglanyue Date: Fri, 5 Sep 2025 11:18:37 +0800 Subject: [PATCH] rm ip address, check class name in ProfilingDataset --- .../advisor/dataset/profiling/profiling_dataset.py | 7 ++++++- profiler/msprof_analyze/docs/recipe_output_format.md | 4 ++-- .../test_base_communication_group.py | 12 ++++++------ 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/profiler/msprof_analyze/advisor/dataset/profiling/profiling_dataset.py b/profiler/msprof_analyze/advisor/dataset/profiling/profiling_dataset.py index ed2eb37b2..10c11091f 100644 --- a/profiler/msprof_analyze/advisor/dataset/profiling/profiling_dataset.py +++ b/profiler/msprof_analyze/advisor/dataset/profiling/profiling_dataset.py @@ -35,6 +35,7 @@ logger = logging.getLogger() @singleton class ProfilingDataset(Dataset): + LEGAL_CLASS_NAME = ["OpSummary", "Msprof", "MsprofDB", "OpSummaryDB", "GeInfo", "TaskTime"] prof_type = "" def __init__(self, collection_path, data: dict, **kwargs) -> None: @@ -60,7 +61,11 @@ class ProfilingDataset(Dataset): # 避免重复构建kernel_details.csv, op_summary.csv的数据对象 continue file_pattern_list = self.current_version_pattern.get('file_attr').get(item) - data_class = globals()[self.current_version_pattern.get('class_attr').get(item)] + class_name = self.current_version_pattern.get('class_attr').get(item) + if class_name not in self.LEGAL_CLASS_NAME: + logger.error(f"Invalid class name for parse profiling data.") + continue + data_class = globals()[class_name] if not hasattr(data_class, "file_pattern_list"): continue setattr(data_class, "file_pattern_list", self.current_version_pattern.get('file_attr').get(item)) diff --git a/profiler/msprof_analyze/docs/recipe_output_format.md b/profiler/msprof_analyze/docs/recipe_output_format.md index 626486340..8a067f929 100644 --- a/profiler/msprof_analyze/docs/recipe_output_format.md +++ b/profiler/msprof_analyze/docs/recipe_output_format.md @@ -261,7 +261,7 @@ O列:TP Index,指集群数据按照并行策略切分后所属TP组的索引 | 字段名 | 类型 | 含义 | | ------ | ---- | ---- | -| GroupName | TEXT | 通信域,例如:10.170.22.98%enp67s0f5_60000_0_1708156014257149 | +| GroupName | TEXT | 通信域,例如:{ip_address}%enp67s0f5_60000_0_1708156014257149 | | GroupId | TEXT | 通信域的hash值的后三位 | | Ranks | TEXT | 该通信域的所有rank | @@ -392,7 +392,7 @@ O列:TP Index,指集群数据按照并行策略切分后所属TP组的索引 | type | TEXT | 算子类型,包含collective和p2p, 其中算子名包含"send","recv","receive"的算子被认为是p2p | | rank_set | TEXT | 通信域内包含的rank(global rank)| | group_name | TEXT | 通信域的hash值,可映射成group_id | -| group_id | TEXT | hccl内部定义的通信域名字,例如:10.170.22.98%enp67s0f5_60000_0_1708156014257149 | +| group_id | TEXT | hccl内部定义的通信域名字,例如:{ip_address}%enp67s0f5_60000_0_1708156014257149 | | pg_name | TEXT | 业务定义的通信域名字,例如:"dp","dp_cp","mp"等等 | ### cluster_time_summary diff --git a/profiler/msprof_analyze/test/ut/cluster_analyse/communication_group/test_base_communication_group.py b/profiler/msprof_analyze/test/ut/cluster_analyse/communication_group/test_base_communication_group.py index 558384480..27906311e 100644 --- a/profiler/msprof_analyze/test/ut/cluster_analyse/communication_group/test_base_communication_group.py +++ b/profiler/msprof_analyze/test/ut/cluster_analyse/communication_group/test_base_communication_group.py @@ -134,12 +134,12 @@ class TestBaseCommunicationGroup(unittest.TestCase): "world_size": 8 }, "parallel_group_info": { - "10.174.216.241%enp189s0f1_55000_0_1738895521183247": { + "100%enp189s0f1_55000_0_1738895521183247": { "group_name": "dp", "group_rank": 0, "global_ranks": [0, 2] }, - "10.174.216.241%enp189s0f1_55000_0_1738895507756334": { + "100%enp189s0f1_55000_0_1738895507756334": { "group_name": "pp", "group_rank": 0, "global_ranks": [0, 4] @@ -150,12 +150,12 @@ class TestBaseCommunicationGroup(unittest.TestCase): self.comm_group.read_parallel_group_info() expected_info = { - "10.174.216.241%enp189s0f1_55000_0_1738895521183247": { + "100%enp189s0f1_55000_0_1738895521183247": { "group_name": "dp", "group_rank": 0, "global_ranks": [0, 2] }, - "10.174.216.241%enp189s0f1_55000_0_1738895507756334": { + "100%enp189s0f1_55000_0_1738895507756334": { "group_name": "pp", "group_rank": 0, "global_ranks": [0, 4] @@ -169,12 +169,12 @@ class TestBaseCommunicationGroup(unittest.TestCase): self.comm_group.collective_group_dict = {"12809826787724806246": {0, 2}} self.comm_group.p2p_group_dict = {"9609979115979062393": {0, 4}} self.comm_group.parallel_group_info = { - "10.174.216.241%enp189s0f1_55000_0_1738895521183247": { + "100%enp189s0f1_55000_0_1738895521183247": { "group_name": "dp", "group_rank": 0, "global_ranks": [0, 2] }, - "10.174.216.241%enp189s0f1_55000_0_1738895507756334": { + "100%enp189s0f1_55000_0_1738895507756334": { "group_name": "pp", "group_rank": 0, "global_ranks": [0, 4] -- Gitee