diff --git a/profiler/msprof_analyze/advisor/advisor_backend/compute_advice/npu_slow_advice.py b/profiler/msprof_analyze/advisor/advisor_backend/compute_advice/npu_slow_advice.py index 5f2f123fb6867b6b8fc48e8049f6687c5c9369d9..2a2040e1712bb330bfeaaa5552fe8f959e6c912c 100644 --- a/profiler/msprof_analyze/advisor/advisor_backend/compute_advice/npu_slow_advice.py +++ b/profiler/msprof_analyze/advisor/advisor_backend/compute_advice/npu_slow_advice.py @@ -37,12 +37,11 @@ class NpuSlowAdvice(ComputeAdviceBase, ABC): @staticmethod def save_to_excel(data: pd.DataFrame, file_path: str) -> None: PathManager.check_path_writeable(os.path.dirname(file_path)) - writer = pd.ExcelWriter(file_path, engine="xlsxwriter", mode="w") - data.index.name = Constant.TITLE.INDEX - data.to_excel(writer, index=True, sheet_name=NpuSlowAdvice.OP_PERF_SHEET) - NpuSlowAdvice.color_sheet(data, writer.book, writer.sheets[NpuSlowAdvice.OP_PERF_SHEET]) - writer.sheets[NpuSlowAdvice.OP_PERF_SHEET].freeze_panes = "A2" - writer.close() + with pd.ExcelWriter(file_path, engine="xlsxwriter", mode="w") as writer: + data.index.name = Constant.TITLE.INDEX + data.to_excel(writer, index=True, sheet_name=NpuSlowAdvice.OP_PERF_SHEET) + NpuSlowAdvice.color_sheet(data, writer.book, writer.sheets[NpuSlowAdvice.OP_PERF_SHEET]) + writer.sheets[NpuSlowAdvice.OP_PERF_SHEET].freeze_panes = "A2" @staticmethod def color_sheet(data: pd.DataFrame, workbook, worksheet): @@ -80,7 +79,6 @@ class NpuSlowAdvice(ComputeAdviceBase, ABC): self.data = pd.read_csv(self.kernel_details_path, dtype={"Start Time(us)": str}) # 去除末尾的\t分隔符 self.data["Start Time(us)"] = self.data["Start Time(us)"].apply(lambda x: x[:-1]) - pool = multiprocessing.Pool(multiprocessing.cpu_count()) - result = pool.map(self.update_op_row, self.data.iterrows()) - pool.close() + with multiprocessing.Pool(multiprocessing.cpu_count()) as pool: + result = pool.map(self.update_op_row, self.data.iterrows()) self.data = pd.DataFrame(result) diff --git a/profiler/msprof_analyze/advisor/common/profiling/ge_info.py b/profiler/msprof_analyze/advisor/common/profiling/ge_info.py index f255684290e1935928ba741dec4cfdc55341cfe5..b9ef012a5ae0d504e5edde24c33e6ae2afdd7bf6 100644 --- a/profiler/msprof_analyze/advisor/common/profiling/ge_info.py +++ b/profiler/msprof_analyze/advisor/common/profiling/ge_info.py @@ -18,8 +18,7 @@ import logging import os from typing import Any, List -from sqlalchemy import text -from sqlalchemy.exc import SQLAlchemyError +from msprof_analyze.prof_common.db_manager import DBManager from msprof_analyze.advisor.dataset.profiling.db_manager import ConnectionManager from msprof_analyze.advisor.dataset.profiling.profiling_parser import ProfilingParser @@ -51,14 +50,11 @@ class GeInfo(ProfilingParser): check_path_valid(db_path) if not ConnectionManager.check_db_exists(db_path, [db_file]): return False - try: - conn = ConnectionManager(db_path, db_file) - except SQLAlchemyError as e: - logger.error("Database error: %s", e) - return False - if conn.check_table_exists(['TaskInfo']): - with conn().connect() as sql_conn: - self.op_state_info_list = sql_conn.execute(text("select op_name, op_state from TaskInfo")).fetchall() + conn, cursor = DBManager.create_connect_db(db_path) + if DBManager.judge_table_exists(cursor, 'TaskInfo'): + sql = "select op_name, op_state from TaskInfo" + self.op_state_info_list = DBManager.fetch_all_data(cursor, sql) + DBManager.destroy_db_connect(conn, cursor) return True def get_static_shape_operators(self) -> List[Any]: diff --git a/profiler/msprof_analyze/advisor/dataset/communication/communication_dataset.py b/profiler/msprof_analyze/advisor/dataset/communication/communication_dataset.py index eed7e299589f26c132c83638a579811d2faa1557..97587dc42a29489fe895fd1998aaaa6c5107e5f0 100644 --- a/profiler/msprof_analyze/advisor/dataset/communication/communication_dataset.py +++ b/profiler/msprof_analyze/advisor/dataset/communication/communication_dataset.py @@ -39,9 +39,6 @@ class CommunicationDataset(Dataset): def __init__(self, collection_path, data: dict, **kwargs) -> None: self.collection_path = collection_path - if not collection_path.endswith("ascend_pt") and not collection_path.endswith("ascend_ms"): - return - self.is_pta = collection_path.endswith("ascend_pt") self.communication_file = "" self.hccl_dict = defaultdict(list) self.step = kwargs.get("step") @@ -137,7 +134,8 @@ class CommunicationDataset(Dataset): if not DBManager.check_tables_in_db(self.communication_file, *expected_tables): logger.warning(f"Communication tables: {expected_tables} not found in {self.communication_file}") return False - export = CommunicationInfoExport(self.communication_file, self.is_pta) + is_pta = self.collection_path.endswith("ascend_pt") + export = CommunicationInfoExport(self.communication_file, is_pta) df = export.read_export_db() if TableConstant.STEP not in df.columns: df[TableConstant.STEP] = 'step' diff --git a/profiler/msprof_analyze/advisor/dataset/profiling/profiling_dataset.py b/profiler/msprof_analyze/advisor/dataset/profiling/profiling_dataset.py index ed2eb37b27b95ce957a6627834c6379ea32911fc..10c11091fc2450f67106460fe0132722bbbc13c3 100644 --- a/profiler/msprof_analyze/advisor/dataset/profiling/profiling_dataset.py +++ b/profiler/msprof_analyze/advisor/dataset/profiling/profiling_dataset.py @@ -35,6 +35,7 @@ logger = logging.getLogger() @singleton class ProfilingDataset(Dataset): + LEGAL_CLASS_NAME = ["OpSummary", "Msprof", "MsprofDB", "OpSummaryDB", "GeInfo", "TaskTime"] prof_type = "" def __init__(self, collection_path, data: dict, **kwargs) -> None: @@ -60,7 +61,11 @@ class ProfilingDataset(Dataset): # 避免重复构建kernel_details.csv, op_summary.csv的数据对象 continue file_pattern_list = self.current_version_pattern.get('file_attr').get(item) - data_class = globals()[self.current_version_pattern.get('class_attr').get(item)] + class_name = self.current_version_pattern.get('class_attr').get(item) + if class_name not in self.LEGAL_CLASS_NAME: + logger.error(f"Invalid class name for parse profiling data.") + continue + data_class = globals()[class_name] if not hasattr(data_class, "file_pattern_list"): continue setattr(data_class, "file_pattern_list", self.current_version_pattern.get('file_attr').get(item)) diff --git a/profiler/msprof_analyze/advisor/dataset/stack/db_stack_finder.py b/profiler/msprof_analyze/advisor/dataset/stack/db_stack_finder.py index a61bd8d6a18f65214e8a2b8b068f7d745f33c5d3..56c7b9a02a3d6bb0dff04b5396fb56a1cdf48f74 100644 --- a/profiler/msprof_analyze/advisor/dataset/stack/db_stack_finder.py +++ b/profiler/msprof_analyze/advisor/dataset/stack/db_stack_finder.py @@ -144,13 +144,13 @@ class DBStackFinder: if not self._is_db_contains_stack(): self.stack_map[name] = None return False + conn, cursor = None, None try: conn, cursor = DBManager.create_connect_db(self._db_path) if params: df = pd.read_sql(sql, conn, params=params) else: df = pd.read_sql(sql, conn) - DBManager.destroy_db_connect(conn, cursor) if df is None or df.empty: self.stack_map[name] = None return False @@ -160,3 +160,7 @@ class DBStackFinder: logger.error(f"Error loading API stack data: {e}") self.stack_map[name] = None return False + finally: + if conn and cursor: + DBManager.destroy_db_connect(conn, cursor) + diff --git a/profiler/msprof_analyze/advisor/dataset/timeline_event_dataset.py b/profiler/msprof_analyze/advisor/dataset/timeline_event_dataset.py index 016d0b8753d379adaffd25778b92ada977764477..512a7ae16354be0dea91f824e14241c4592438d2 100644 --- a/profiler/msprof_analyze/advisor/dataset/timeline_event_dataset.py +++ b/profiler/msprof_analyze/advisor/dataset/timeline_event_dataset.py @@ -156,7 +156,7 @@ class BaseTimelineEventDataset(Dataset): for event_type in collector.get_event_type(): df = db_helper.query_timeline_event(event_type) collector.add_op_from_db(df) - db_helper.destory_db_connection() + db_helper.destroy_db_connection() return True def parse_data_with_generator(self, func): diff --git a/profiler/msprof_analyze/advisor/dataset/timeline_op_collector/timeline_op_sql.py b/profiler/msprof_analyze/advisor/dataset/timeline_op_collector/timeline_op_sql.py index 2e820af7d2c1aa0365a44ceb5e1cb42559d27acb..e96c4c60934621a86c90edc0d2084814c69d127e 100644 --- a/profiler/msprof_analyze/advisor/dataset/timeline_op_collector/timeline_op_sql.py +++ b/profiler/msprof_analyze/advisor/dataset/timeline_op_collector/timeline_op_sql.py @@ -243,7 +243,7 @@ class TimelineDBHelper: self.init = bool(self.conn and self.curs) return self.init - def destory_db_connection(self): + def destroy_db_connection(self): DBManager.destroy_db_connect(self.conn, self.curs) self.init = False diff --git a/profiler/msprof_analyze/cluster_analyse/README.md b/profiler/msprof_analyze/cluster_analyse/README.md index 5f8e0bde30f086917d27a5f6ab6da2953b30ea1b..fb1e8588c5872599f8034141cc3aed6767a6c16c 100644 --- a/profiler/msprof_analyze/cluster_analyse/README.md +++ b/profiler/msprof_analyze/cluster_analyse/README.md @@ -399,11 +399,11 @@ msprof-analyze配置--mode参数时可分析并输出cluster_analysis.db交付 格式: -| 字段名 | 类型 | 含义 | -| ------ | ---- | ---- | -| GroupName | TEXT | 通信域,例如:10.170.22.98%enp67s0f5_60000_0_1708156014257149 | -| GroupId | TEXT | 通信域的hash值的后三位 | -| Ranks | TEXT | 该通信域的所有rank | +| 字段名 | 类型 | 含义 | +| ------ | ---- |------------------------------------------------| +| GroupName | TEXT | 通信域,例如:{ip}%enp67s0f5_60000_0_1708156014257149 | +| GroupId | TEXT | 通信域的hash值的后三位 | +| Ranks | TEXT | 该通信域的所有rank | #### HcclTopOpStats @@ -440,13 +440,13 @@ msprof-analyze配置--mode参数时可分析并输出cluster_analysis.db交付 格式: -| 字段名 | 类型 | 含义 | -| ------ | ---- | ---- | +| 字段名 | 类型 | 含义 | +| ------ | ---- |-----------------------------------------------------------------| | type | TEXT | 算子类型,包含collective和p2p, 其中算子名包含"send","recv","receive"的算子被认为是p2p | -| rank_set | TEXT | 通信域内包含的rank(global rank)| -| group_name | TEXT | 通信域的hash值,可映射成group_id | -| group_id | TEXT | hccl内部定义的通信域名字,例如:10.170.22.98%enp67s0f5_60000_0_1708156014257149 | -| pg_name | TEXT | 业务定义的通信域名字,例如:"dp","dp_cp","mp"等等 | +| rank_set | TEXT | 通信域内包含的rank(global rank) | +| group_name | TEXT | 通信域的hash值,可映射成group_id | +| group_id | TEXT | hccl内部定义的通信域名字,例如:{ip}%enp67s0f5_60000_0_1708156014257149 | +| pg_name | TEXT | 业务定义的通信域名字,例如:"dp","dp_cp","mp"等等 | ### communication_time_sum diff --git a/profiler/msprof_analyze/compare_tools/compare_backend/compare_bean/kernel_compare_bean.py b/profiler/msprof_analyze/compare_tools/compare_backend/compare_bean/kernel_compare_bean.py index 3f64f28ec6c62b2c61548b7d2f189339670aea06..b652cdc2086113a1db50d363f28a823820b9e6dc 100644 --- a/profiler/msprof_analyze/compare_tools/compare_backend/compare_bean/kernel_compare_bean.py +++ b/profiler/msprof_analyze/compare_tools/compare_backend/compare_bean/kernel_compare_bean.py @@ -27,7 +27,7 @@ class KernelCompareInfo: self._number = None self._max_dur = None self._min_dur = None - if not data_list: + if len(data_list) < 6: return self._kernel_type = data_list[0] self._input_shapes = data_list[1] diff --git a/profiler/msprof_analyze/compare_tools/compare_backend/view/excel_view.py b/profiler/msprof_analyze/compare_tools/compare_backend/view/excel_view.py index 6a094fdf3df8828f0a89e3333517459257440d19..57f2803e96ee0c7559c7f8aed6f2270f939a3896 100644 --- a/profiler/msprof_analyze/compare_tools/compare_backend/view/excel_view.py +++ b/profiler/msprof_analyze/compare_tools/compare_backend/view/excel_view.py @@ -29,8 +29,8 @@ class ExcelView(BaseView): self._args = args def generate_view(self): - workbook = Workbook(self._file_path) - for sheet_name, data in self._data_dict.items(): - WorkSheetCreator(workbook, sheet_name, data, self._args).create_sheet() - workbook.close() + with Workbook(self._file_path) as workbook: + for sheet_name, data in self._data_dict.items(): + WorkSheetCreator(workbook, sheet_name, data, self._args).create_sheet() os.chmod(self._file_path, Constant.FILE_AUTHORITY) + diff --git a/profiler/msprof_analyze/compare_tools/compare_backend/view/work_sheet_creator.py b/profiler/msprof_analyze/compare_tools/compare_backend/view/work_sheet_creator.py index b73f6df97e81886e6034eae0dcdbe4c180f7995c..d10e35f498e75d2cb36488cd319b5bcf39926030 100644 --- a/profiler/msprof_analyze/compare_tools/compare_backend/view/work_sheet_creator.py +++ b/profiler/msprof_analyze/compare_tools/compare_backend/view/work_sheet_creator.py @@ -44,7 +44,7 @@ class WorkSheetCreator: com_header_format = self._work_book.add_format(CellFormatType.YELLOW_BOLD) com_index_range = [-1, -1] overhead = self._data.get("overhead", []) - if overhead: + if len(overhead) >= 2: base_path = f"Base Profiling: {self._args.base_profiling_path}" self._work_sheet.merge_range(overhead[0], base_path, base_header_format) com_index_range = [self._col_ids.index(overhead[1].split(":")[0][0]), diff --git a/profiler/msprof_analyze/prof_common/database_service.py b/profiler/msprof_analyze/prof_common/database_service.py index 8cd4cdd2a1f414ba6d7945a22dea8fc6c312f85e..45df254c3d8e461067e68292ada6fb3da8d1a89d 100644 --- a/profiler/msprof_analyze/prof_common/database_service.py +++ b/profiler/msprof_analyze/prof_common/database_service.py @@ -98,10 +98,6 @@ class DatabaseService: result_data[table_name] = data except Exception as err: logger.error(err) - return result_data - try: - DBManager.destroy_db_connect(conn, cursor) - except Exception as err: - logger.error(err) - return result_data + break + DBManager.destroy_db_connect(conn, cursor) return result_data diff --git a/profiler/msprof_analyze/prof_common/file_manager.py b/profiler/msprof_analyze/prof_common/file_manager.py index 064eb3f039aea5d39dd44e8ca84aac45f68019b8..92a4a3ec677d68ead7e72670a371f111b7e0a45a 100644 --- a/profiler/msprof_analyze/prof_common/file_manager.py +++ b/profiler/msprof_analyze/prof_common/file_manager.py @@ -120,7 +120,7 @@ class FileManager: PathManager.check_path_writeable(os.path.dirname(file_path)) try: with os.fdopen( - os.open(file_path, os.O_WRONLY | os.O_CREAT, Constant.FILE_AUTHORITY), + os.open(file_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, Constant.FILE_AUTHORITY), 'w') as file: file.write(content) except Exception as e: diff --git a/profiler/msprof_analyze/test/ut/advisor/dataset/test_timeline_op_sql.py b/profiler/msprof_analyze/test/ut/advisor/dataset/test_timeline_op_sql.py index 6cc9cd0374e5f31351ac4f935f237859af30fe17..c94a963823f3fc2542811bebb8216d7fb5ed9cd1 100644 --- a/profiler/msprof_analyze/test/ut/advisor/dataset/test_timeline_op_sql.py +++ b/profiler/msprof_analyze/test/ut/advisor/dataset/test_timeline_op_sql.py @@ -131,6 +131,6 @@ class TestTimelineDBHelper(unittest.TestCase): self.db_helper.init = True self.db_helper.conn = MagicMock() self.db_helper.curs = MagicMock() - self.db_helper.destory_db_connection() + self.db_helper.destroy_db_connection() self.assertFalse(self.db_helper.init) diff --git a/profiler/msprof_analyze/test/ut/cluster_analyse/communication_group/test_base_communication_group.py b/profiler/msprof_analyze/test/ut/cluster_analyse/communication_group/test_base_communication_group.py index 558384480575a2d48dd17a11e563efc9c46df2e5..27906311e99ff0472d4ec5d7e0d0d29b36a17ec9 100644 --- a/profiler/msprof_analyze/test/ut/cluster_analyse/communication_group/test_base_communication_group.py +++ b/profiler/msprof_analyze/test/ut/cluster_analyse/communication_group/test_base_communication_group.py @@ -134,12 +134,12 @@ class TestBaseCommunicationGroup(unittest.TestCase): "world_size": 8 }, "parallel_group_info": { - "10.174.216.241%enp189s0f1_55000_0_1738895521183247": { + "100%enp189s0f1_55000_0_1738895521183247": { "group_name": "dp", "group_rank": 0, "global_ranks": [0, 2] }, - "10.174.216.241%enp189s0f1_55000_0_1738895507756334": { + "100%enp189s0f1_55000_0_1738895507756334": { "group_name": "pp", "group_rank": 0, "global_ranks": [0, 4] @@ -150,12 +150,12 @@ class TestBaseCommunicationGroup(unittest.TestCase): self.comm_group.read_parallel_group_info() expected_info = { - "10.174.216.241%enp189s0f1_55000_0_1738895521183247": { + "100%enp189s0f1_55000_0_1738895521183247": { "group_name": "dp", "group_rank": 0, "global_ranks": [0, 2] }, - "10.174.216.241%enp189s0f1_55000_0_1738895507756334": { + "100%enp189s0f1_55000_0_1738895507756334": { "group_name": "pp", "group_rank": 0, "global_ranks": [0, 4] @@ -169,12 +169,12 @@ class TestBaseCommunicationGroup(unittest.TestCase): self.comm_group.collective_group_dict = {"12809826787724806246": {0, 2}} self.comm_group.p2p_group_dict = {"9609979115979062393": {0, 4}} self.comm_group.parallel_group_info = { - "10.174.216.241%enp189s0f1_55000_0_1738895521183247": { + "100%enp189s0f1_55000_0_1738895521183247": { "group_name": "dp", "group_rank": 0, "global_ranks": [0, 2] }, - "10.174.216.241%enp189s0f1_55000_0_1738895507756334": { + "100%enp189s0f1_55000_0_1738895507756334": { "group_name": "pp", "group_rank": 0, "global_ranks": [0, 4]