From acf40bf0db321c9c0af26818723599c5f873fb9b Mon Sep 17 00:00:00 2001 From: fanglanyue Date: Thu, 4 Sep 2025 14:33:03 +0800 Subject: [PATCH] bugfix communication_group_map --- .../cluster_analyse/recipes/base_recipe_analysis.py | 4 +++- .../communication_group_map/communication_group_map.py | 9 +++++++-- .../communication_matrix_sum/communication_matrix_sum.py | 2 ++ .../communication_time_sum/communication_time_sum.py | 2 +- profiler/msprof_analyze/prof_common/database_service.py | 2 +- 5 files changed, 14 insertions(+), 5 deletions(-) diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/base_recipe_analysis.py b/profiler/msprof_analyze/cluster_analyse/recipes/base_recipe_analysis.py index c0fe0b20b..0d663675b 100644 --- a/profiler/msprof_analyze/cluster_analyse/recipes/base_recipe_analysis.py +++ b/profiler/msprof_analyze/cluster_analyse/recipes/base_recipe_analysis.py @@ -114,6 +114,9 @@ class BaseRecipeAnalysis(ABC): ) def dump_data(self, data, file_name, table_name=None, index=True, custom_db_path=None): + if data is None: + logger.warning(f"No data to dump, skipping.") + return if not isinstance(data, pd.DataFrame): logger.error(f"Unknown dump data type: {type(data)}, expected pandas DataFrame") return @@ -275,7 +278,6 @@ class BaseRecipeAnalysis(ABC): return os.path.join(data_path, Constant.SINGLE_OUTPUT, "analysis.db") return "" - def _get_step_range(self, db_path): step_range = {} if self._step_id == Constant.VOID_STEP: diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/communication_group_map/communication_group_map.py b/profiler/msprof_analyze/cluster_analyse/recipes/communication_group_map/communication_group_map.py index 7990aefd4..8d51d79a9 100644 --- a/profiler/msprof_analyze/cluster_analyse/recipes/communication_group_map/communication_group_map.py +++ b/profiler/msprof_analyze/cluster_analyse/recipes/communication_group_map/communication_group_map.py @@ -57,6 +57,8 @@ class CommunicationGroupMap(BaseRecipeAnalysis): # concat and process all comm group comm_group_df_list = [df for df, _ in mapper_res] comm_group_combined_df = pd.concat(comm_group_df_list).drop_duplicates() + if comm_group_combined_df.empty: + return comm_group_combined_df = (comm_group_combined_df.groupby([TableConstant.TYPE, TableConstant.GROUP_NAME]) [TableConstant.RANK_ID].apply(lambda x: sorted(set(x))).reset_index()) comm_group_combined_df[TableConstant.RANK_SET] = (comm_group_combined_df[TableConstant.RANK_ID]. @@ -92,8 +94,10 @@ class CommunicationGroupMap(BaseRecipeAnalysis): analysis_data_service.add_table_for_query(Constant.TABLE_COMM_ANALYZER_TIME, [TableConstant.HCCL_OP_NAME, TableConstant.GROUP_NAME]) comm_time_res = analysis_data_service.query_data() - # process comm_time_df: group_name, type, rank_id comm_time_df = comm_time_res.get(Constant.TABLE_COMM_ANALYZER_TIME) + if comm_time_df is None or comm_time_df.empty: + return pd.DataFrame(), pd.DataFrame() + # process comm_time_df: group_name, type, rank_id comm_time_df[TableConstant.RANK_ID] = rank_id comm_time_df[TableConstant.TYPE] = (comm_time_df[TableConstant.HCCL_OP_NAME]. apply(lambda x: self.get_comm_type_from_op_name(x))) @@ -110,7 +114,8 @@ class CommunicationGroupMap(BaseRecipeAnalysis): # process parallel_info_df parallel_info_df = pd.DataFrame(columns=[TableConstant.GROUP_NAME, TableConstant.GROUP_ID, TableConstant.PG_NAME, self.GLOBAL_RANKS]) - if Constant.PARALLEL_GROUP_INFO not in meta_data_df[TableConstant.NAME].values: + if (meta_data_df is None or meta_data_df.empty or + Constant.PARALLEL_GROUP_INFO not in meta_data_df[TableConstant.NAME].values): return comm_time_df, parallel_info_df info_str = meta_data_df.loc[meta_data_df[TableConstant.NAME] == Constant.PARALLEL_GROUP_INFO, TableConstant.VALUE].values[0] diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/communication_matrix_sum/communication_matrix_sum.py b/profiler/msprof_analyze/cluster_analyse/recipes/communication_matrix_sum/communication_matrix_sum.py index 82a4f87e9..55e0daf6e 100644 --- a/profiler/msprof_analyze/cluster_analyse/recipes/communication_matrix_sum/communication_matrix_sum.py +++ b/profiler/msprof_analyze/cluster_analyse/recipes/communication_matrix_sum/communication_matrix_sum.py @@ -173,6 +173,8 @@ class CommMatrixSum(BaseRecipeAnalysis): for rank_data in mapper_res: rank_map.update(rank_data.get(self.RANK_MAP)) matrix_df = rank_data.get(self.MATRIX_DATA) + if matrix_df is None or matrix_df.empty: + continue filter_matrix_df = matrix_df[matrix_df["src_rank"] == matrix_df["dst_rank"]] grouped_matrix_df = filter_matrix_df[['group_name', 'src_rank']].drop_duplicates() grouped_matrix_df[Constant.RANK_ID] = rank_data.get(Constant.RANK_ID) diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/communication_time_sum/communication_time_sum.py b/profiler/msprof_analyze/cluster_analyse/recipes/communication_time_sum/communication_time_sum.py index fd6182a92..7484d4258 100644 --- a/profiler/msprof_analyze/cluster_analyse/recipes/communication_time_sum/communication_time_sum.py +++ b/profiler/msprof_analyze/cluster_analyse/recipes/communication_time_sum/communication_time_sum.py @@ -49,7 +49,7 @@ class CommunicationTimeSum(BaseRecipeAnalysis): def run(self, context): if not self.check_table_exist(self.TABLE_COMMUNICATION_GROUP_MAPPING): if not self.run_communication_group_map_recipe(context): - logger.error("Create CommunicationGroupMap table failed!") + logger.error("Create CommunicationGroupMap table failed! Skip CommunicationTimeSum.") return mapper_res = self.mapper_func(context) self.reducer_func(mapper_res) diff --git a/profiler/msprof_analyze/prof_common/database_service.py b/profiler/msprof_analyze/prof_common/database_service.py index 45df254c3..f1d6a7273 100644 --- a/profiler/msprof_analyze/prof_common/database_service.py +++ b/profiler/msprof_analyze/prof_common/database_service.py @@ -64,7 +64,7 @@ class DatabaseService: def query_data(self): result_data = {} - if not self._table_info: + if not self._table_info or not self._db_path: return result_data try: conn, cursor = DBManager.create_connect_db(self._db_path) -- Gitee