From 50487f6d5b7d08930f051e6a40fb5aee2cf94fe3 Mon Sep 17 00:00:00 2001 From: tangmengcheng <745274877@qq.com> Date: Thu, 24 Jul 2025 15:23:42 +0800 Subject: [PATCH] rl analysis recipe --- .../cluster_analyse/cluster_analysis.py | 4 +- .../data_preprocessor.py | 16 +- .../mindspore_data_preprocessor.py | 14 +- .../pytorch_data_preprocessor.py | 12 +- .../recipes/base_recipe_analysis.py | 3 + .../recipes/cluster_display.py | 188 ++++++++++++++- .../recipes/rl_analysis/__init__.py | 0 .../recipes/rl_analysis/rl_analysis.py | 223 ++++++++++++++++++ .../recipes/rl_analysis/stats.ipynb | 206 ++++++++++++++++ .../msprof_analyze/prof_common/constant.py | 3 + .../prof_exports/mstx_event_export.py | 44 ++++ 11 files changed, 699 insertions(+), 14 deletions(-) create mode 100644 profiler/msprof_analyze/cluster_analyse/recipes/rl_analysis/__init__.py create mode 100644 profiler/msprof_analyze/cluster_analyse/recipes/rl_analysis/rl_analysis.py create mode 100644 profiler/msprof_analyze/cluster_analyse/recipes/rl_analysis/stats.ipynb create mode 100644 profiler/msprof_analyze/prof_exports/mstx_event_export.py diff --git a/profiler/msprof_analyze/cluster_analyse/cluster_analysis.py b/profiler/msprof_analyze/cluster_analyse/cluster_analysis.py index 43e716efa..5aaab1671 100644 --- a/profiler/msprof_analyze/cluster_analyse/cluster_analysis.py +++ b/profiler/msprof_analyze/cluster_analyse/cluster_analysis.py @@ -76,10 +76,10 @@ class Interface: ascend_pt_dirs.append(os.path.join(root, dir_name)) if dir_name.endswith(self.ASCEND_MS): ascend_ms_dirs.append(os.path.join(root, dir_name)) - pytorch_processor = PytorchDataPreprocessor(ascend_pt_dirs) + pytorch_processor = PytorchDataPreprocessor(ascend_pt_dirs, self.analysis_mode) pt_data_map = pytorch_processor.get_data_map() data_type = pytorch_processor.get_data_type() - ms_data_map = MindsporeDataPreprocessor(ascend_ms_dirs).get_data_map() + ms_data_map = MindsporeDataPreprocessor(ascend_ms_dirs, self.analysis_mode).get_data_map() if 
pt_data_map and ms_data_map: logger.error("Can not analyze pytorch and mindspore meantime.") return [] diff --git a/profiler/msprof_analyze/cluster_analyse/cluster_data_preprocess/data_preprocessor.py b/profiler/msprof_analyze/cluster_analyse/cluster_data_preprocess/data_preprocessor.py index f2d38d9c1..9e2cadad9 100644 --- a/profiler/msprof_analyze/cluster_analyse/cluster_data_preprocess/data_preprocessor.py +++ b/profiler/msprof_analyze/cluster_analyse/cluster_data_preprocess/data_preprocessor.py @@ -15,13 +15,17 @@ import os from abc import abstractmethod +from msprof_analyze.prof_common.file_manager import FileManager + class DataPreprocessor: PROFILER_INFO_HEAD = 'profiler_info_' PROFILER_INFO_EXTENSION = '.json' + PROFILER_METADATA_JSON = 'profiler_metadata.json' - def __init__(self, path_list: list): + def __init__(self, path_list: list, analysis_mode: str): self.path_list = path_list + self.analysis_mode = analysis_mode self.data_map = {} @abstractmethod @@ -39,3 +43,13 @@ class DataPreprocessor: rank_id = -1 return rank_id return -1 + + def get_task_roll(self, dir_name: str) -> str: + files = os.listdir(dir_name) + for file_name in files: + if file_name == self.PROFILER_METADATA_JSON: + config = FileManager.read_json_file(os.path.join(dir_name, file_name)) + task_roll_str = config.get("roll") + if task_roll_str: + return task_roll_str + return "default_roll" diff --git a/profiler/msprof_analyze/cluster_analyse/cluster_data_preprocess/mindspore_data_preprocessor.py b/profiler/msprof_analyze/cluster_analyse/cluster_data_preprocess/mindspore_data_preprocessor.py index eaa14fb71..b3aef32c6 100644 --- a/profiler/msprof_analyze/cluster_analyse/cluster_data_preprocess/mindspore_data_preprocessor.py +++ b/profiler/msprof_analyze/cluster_analyse/cluster_data_preprocess/mindspore_data_preprocessor.py @@ -15,7 +15,7 @@ from collections import defaultdict from msprof_analyze.cluster_analyse.cluster_data_preprocess.data_preprocessor import DataPreprocessor - +from 
msprof_analyze.prof_common.constant import Constant from msprof_analyze.prof_common.logger import get_logger logger = get_logger() @@ -23,22 +23,26 @@ logger = get_logger() class MindsporeDataPreprocessor(DataPreprocessor): - def __init__(self, path_list: list): - super().__init__(path_list) + def __init__(self, path_list: list, analysis_mode: str = ""): + super().__init__(path_list, analysis_mode) def get_data_map(self) -> dict: rank_id_map = defaultdict(list) for dir_name in self.path_list: rank_id = self.get_rank_id(dir_name) + task_roll = self.get_task_roll(dir_name) if rank_id < 0: logger.error("fail to get rankid or rankid invalid.") continue + if self.analysis_mode in Constant.MUTI_TASK_RECIPES and task_roll: + rank_id_map[(task_roll, rank_id)].append(dir_name) + continue rank_id_map[rank_id].append(dir_name) try: - for (rank_id, dir_list) in rank_id_map.items(): + for (map_key, dir_list) in rank_id_map.items(): dir_list.sort(key=lambda x: x.split('_')[-3]) - self.data_map[rank_id] = dir_list[0] + self.data_map[map_key] = dir_list[0] except Exception as e: raise RuntimeError("Found invalid directory name!") from e return self.data_map diff --git a/profiler/msprof_analyze/cluster_analyse/cluster_data_preprocess/pytorch_data_preprocessor.py b/profiler/msprof_analyze/cluster_analyse/cluster_data_preprocess/pytorch_data_preprocessor.py index 09a46bc71..15f68a46e 100644 --- a/profiler/msprof_analyze/cluster_analyse/cluster_data_preprocess/pytorch_data_preprocessor.py +++ b/profiler/msprof_analyze/cluster_analyse/cluster_data_preprocess/pytorch_data_preprocessor.py @@ -25,14 +25,15 @@ logger = get_logger() class PytorchDataPreprocessor(DataPreprocessor): - def __init__(self, path_list: list): - super().__init__(path_list) + def __init__(self, path_list: list, analysis_mode: str = ""): + super().__init__(path_list, analysis_mode) self.data_type = set() def get_data_map(self) -> dict: rank_id_map = defaultdict(list) for dir_name in self.path_list: rank_id = 
self.get_rank_id(dir_name) + task_roll = self.get_task_roll(dir_name) if rank_id < 0: logger.error("fail to get rankid or rankid invalid.") continue @@ -46,12 +47,15 @@ class PytorchDataPreprocessor(DataPreprocessor): self.data_type.add(Constant.DB if Constant.DB in export_type else Constant.TEXT) else: self.data_type.add(export_type) + if self.analysis_mode in Constant.MUTI_TASK_RECIPES and task_roll: + rank_id_map[(task_roll, rank_id)].append(dir_name) + continue rank_id_map[rank_id].append(dir_name) try: - for (rank_id, dir_list) in rank_id_map.items(): + for (map_key, dir_list) in rank_id_map.items(): dir_list.sort(key=lambda x: x.split('_')[-3]) - self.data_map[rank_id] = dir_list[0] + self.data_map[map_key] = dir_list[0] except Exception as e: raise RuntimeError("Found invalid directory name!") from e return self.data_map diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/base_recipe_analysis.py b/profiler/msprof_analyze/cluster_analyse/recipes/base_recipe_analysis.py index 6a0273c8c..9db6f548f 100644 --- a/profiler/msprof_analyze/cluster_analyse/recipes/base_recipe_analysis.py +++ b/profiler/msprof_analyze/cluster_analyse/recipes/base_recipe_analysis.py @@ -237,6 +237,9 @@ class BaseRecipeAnalysis(ABC): logger.warning(f"Invalid Rank id : [{','.join(invalid_rank_id)}].") return db_paths + def _get_profiler_db_path(self, rank_id, data_path): + return os.path.join(data_path, Constant.SINGLE_OUTPUT, f"ascend_pytorch_profiler_{rank_id}.db") + def _mapper_func(self, data_map, analysis_class): """ Extract the profiling data required for cluster analysis from each device, and then aggregate the diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/cluster_display.py b/profiler/msprof_analyze/cluster_analyse/recipes/cluster_display.py index 7e8913948..6d7578482 100644 --- a/profiler/msprof_analyze/cluster_analyse/recipes/cluster_display.py +++ b/profiler/msprof_analyze/cluster_analyse/recipes/cluster_display.py @@ -15,13 +15,16 @@ import logging 
import math +from tqdm import tqdm import matplotlib.pyplot as plt +import matplotlib.patches as patches import numpy as np import pandas as pd import plotly.graph_objects as go +from matplotlib.colors import TABLEAU_COLORS from IPython.display import display, HTML from ipywidgets import Dropdown, fixed, interact -from msprof_analyze.cluster_analyse.common_func.utils import calculate_zscore +from msprof_analyze.cluster_analyse.common_func.utils import calculate_zscore logger = logging.getLogger("cluster_display") @@ -324,4 +327,185 @@ def display_transmittime_bar(slowlinkops_df, ratio_set=0.05, optype='hcom_allGat data_ttzscore = process_data(slowlinkops_df_f, 'opTypeRelatedRanksDataSize', 'transmitTime_Zscore', num_intervals) plot_data(data_tt, 'Transmit Time Distribution', 'Time (ns)') - plot_data(data_ttzscore, 'Z-Score of Transmit Time Distribution', 'Z-Score') \ No newline at end of file + plot_data(data_ttzscore, 'Z-Score of Transmit Time Distribution', 'Z-Score') + + +class RLAnalysisChart: + """ + A class for generating various charts and visualizations for RL analysis. + """ + + def __init__(self, events_df, selected_domains=None): + """ + Initialize the RLAnalysisChart with event data. 
+ + Parameters: + events_df: Event data DataFrame + selected_domains: List of domains to display, if None then show all (excluding 'default') + """ + # Validate input data + if events_df is None or events_df.empty: + logger.error("No event data available") + raise ValueError("events_df cannot be None or empty") + + self.events_df = events_df + + # Handle selected domains (exclude 'default' by default) + if selected_domains is None: + self.selected_domains = [domain for domain in events_df['domain'].unique() if domain != 'default'] + else: + self.selected_domains = [domain for domain in selected_domains] + + self.filtered_events_df = self.events_df[self.events_df['domain'].isin(self.selected_domains)].copy() + if self.filtered_events_df.empty: + logger.error("No events found for selected domains") + raise ValueError("No events found for selected domains") + + global_start_time = self.filtered_events_df['start_time_ms'].min() + self.filtered_events_df['normalized_start'] = self.filtered_events_df['start_time_ms'] - global_start_time + + # Initialize domain colors + self._initialize_domain_colors() + + def plot_timeline_thumbnail(self, + output_path=None, + show_chart=False, + chart_title="RL Timeline thumbnail", + domain_height=0.2, + dpi=96, + group_by_communication_group_rank=None): + """ + Plot timeline thumbnail chart. 
+ + Parameters: + output_path: Output file path, if None then not saved + show_chart: Whether to display the chart + chart_title: Title of the chart + domain_height: Height of each domain segment + dpi: DPI for saved image + group_by_communication_group_rank: Communication group to group data by + """ + # Prepare data with communication grouping + if group_by_communication_group_rank is not None: + data_df = self._group_data_by_communication_group_rank(group_by_communication_group_rank) + data_df['file_key'] = (data_df['roll'].astype(str) + "_" + data_df['extracted_group'].astype(str) + ) + else: + data_df = self.filtered_events_df.copy() + data_df['file_key'] = (data_df['roll'].astype(str) + ' rank_' + data_df['rank_id'].astype(str)) + + # Set font to Times New Roman + plt.rcParams['font.family'] = 'Times New Roman' + plt.rcParams['font.size'] = 12 + + # Create y-position mapping + file_keys = data_df['file_key'].unique() + + # For each file_key, determine the domains present and assign sub-rows + file_key_domains = {} + for key in file_keys: + domains_in_key = data_df[data_df['file_key'] == key]['domain'].unique() + file_key_domains[key] = {domain: i for i, domain in enumerate(sorted(domains_in_key))} + + # Create mapping from (file_key, domain) to y-position + y_mapping = {} + current_y = 0 + file_key_y_positions = {} + + for key in sorted(file_keys): + domains = file_key_domains[key] + for domain, sub_row in domains.items(): + y_mapping[(key, domain)] = current_y + sub_row * domain_height + file_key_y_positions[key] = current_y + len(domains) * domain_height / 2 + current_y += len(domains) * domain_height + current_y += domain_height * 0.5 + + # Calculate chart height + fig, ax = plt.subplots(figsize=(16, current_y)) + + # Add progress bar for event processing + for _, event in tqdm(data_df.iterrows(), total=len(data_df), desc="Generating timeline"): + key = (event['file_key'], event['domain']) + y_pos = y_mapping.get(key) + if y_pos is None: + 
logger.warning(f"Missing y-position mapping for key: {key}") + continue + + rect = patches.Rectangle( + (event['normalized_start'], y_pos), + event['duration_ms'], + height=domain_height, + facecolor=self.domain_colors.get(event['domain'], 'gray') + ) + ax.add_patch(rect) + + # Add legend + legend_elements = [patches.Patch(color=color, label=domain) + for domain, color in self.domain_colors.items()] + ax.legend(handles=legend_elements, loc='center left', bbox_to_anchor=(1, 0.5), fontsize=10) + + # Set axes + max_time = data_df['normalized_start'].max() + data_df['duration_ms'].max() + ax.set_xlim(0, max_time) + ax.set_ylim(0, current_y) + + # Set yticks + ax.set_yticks(list(file_key_y_positions.values())) + ax.set_yticklabels(list(file_key_y_positions.keys())) + + # Set labels and title + ax.set_xlabel("Time (milliseconds)", fontsize=12) + ax.set_title(chart_title, fontsize=16, fontweight='bold', pad=20) + + plt.tight_layout() + + # Save or display chart + if output_path: + plt.savefig(output_path, bbox_inches='tight', dpi=dpi) + logger.info(f"Chart saved to {output_path}") + + if show_chart: + plt.show() + else: + plt.close() + + def _initialize_domain_colors(self): + """Initialize color mapping for domains.""" + unique_domains = self.filtered_events_df['domain'].unique() + tableau_colors = list(TABLEAU_COLORS) + colors = tableau_colors[:max(len(unique_domains), len(tableau_colors) - 1)] + + self.domain_colors = {} + for i, domain in enumerate(unique_domains): + base_color = colors[i % len(colors)] + self.domain_colors[domain] = base_color + + def _group_data_by_communication_group_rank(self, group_by_communication_group_rank): + """ + Prepare data grouped by communication group rank. 
+ + Parameters: + group_by_communication_group_rank: Communication group to group data by + + Returns: + DataFrame with communication grouping applied + """ + def extract_group(group_str, prefix): + group_list = group_str.split(",") + for group in group_list: + if group.startswith(prefix): + return group + return None + + # Apply communication grouping + self.filtered_events_df['extracted_group'] = self.filtered_events_df['communication_group'].apply( + lambda x: extract_group(x, group_by_communication_group_rank) + ) + + grouped_df = self.filtered_events_df[self.filtered_events_df['extracted_group'].notna()].copy() + + if grouped_df.empty: + logger.warning(f"No events found for communication group: {group_by_communication_group_rank}") + return grouped_df + + return grouped_df diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/rl_analysis/__init__.py b/profiler/msprof_analyze/cluster_analyse/recipes/rl_analysis/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/rl_analysis/rl_analysis.py b/profiler/msprof_analyze/cluster_analyse/recipes/rl_analysis/rl_analysis.py new file mode 100644 index 000000000..a84609e27 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/rl_analysis/rl_analysis.py @@ -0,0 +1,223 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from typing import List, Dict, Optional + +import pandas as pd + +from msprof_analyze.cluster_analyse.recipes.base_recipe_analysis import BaseRecipeAnalysis +from msprof_analyze.prof_exports.mstx_event_export import MstxHostExport +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.logger import get_logger +from msprof_analyze.prof_common.file_manager import FileManager +from msprof_analyze.cluster_analyse.common_func.excel_utils import ExcelUtils +from msprof_analyze.cluster_analyse.recipes.cluster_display import RLAnalysisChart + +logger = get_logger() + + +class RlAnalysis(BaseRecipeAnalysis): + """Reinforcement Learning Profiling Data Analysis Recipe""" + + # db table name + TABLE_RL_EVENTS = "RlEvents" + + # communication rank domain name + COMMUNICATION_GROUP_DOMAIN = 'communication_group' + + # File names + EVENT_STATISTIC_FILE = "rl_event_statistic.xlsx" + EVENT_SUMMARY_FILE = "rl_event_summary.csv" + TIMELINE_THUMBNAIL_FILE = "rl_timeline_thumbnail.png" + + # Excel configuration + EXCEL_SHEET_NAME = "Rl Event Statistic" + EXCEL_MERGE_COLUMNS = ['Roll', 'Rank ID'] + EXCEL_COLUMN_WIDTHS = {'Roll': 30, 'Rank ID': 20, 'Domain': 50, 'Count': 20, + 'Total Time (ms)': 20, 'Max Time (ms)': 20, 'Min Time (ms)': 20, + 'Mean Time (ms)': 20} + + def __init__(self, params): + super().__init__(params) + self.events_summary: Optional[pd.DataFrame] = None + + @property + def base_dir(self) -> str: + return os.path.basename(os.path.dirname(__file__)) + + def run(self, context): + """Main entry point for the recipe""" + mapper_res = self.mapper_func(context) + self.reducer_func(mapper_res) + self.generate_rl_event_statistic() + self.generate_rl_event_summary() + self.generate_timeline_thumbnail() + if self._export_type == Constant.DB: + self.save_db() + elif self._export_type == Constant.NOTEBOOK: + self.save_notebook() + else: + logger.error(f"Unknown export type: {self._export_type}") + + def parse_rl_mstx_event(self,
profiler_db_path: str, rank_id: int, roll: str) -> List[Dict]: + """Parse MSTX events from database and extract non-default domain events""" + events = [] + mstx_export = MstxHostExport(profiler_db_path, self._recipe_name) + df = mstx_export.read_export_db() + + if df is None or df.empty: + logger.warning(f"Rank {rank_id}: No MSTX events found in database") + return events + + filtered_df = df[(df['domain'].notna())] + if filtered_df.empty: + logger.warning(f"Rank {rank_id}: No non-default domain events found") + return events + # Convert nanoseconds to milliseconds + ns_to_ms = Constant.NS_TO_US * Constant.US_TO_MS + filtered_df['start_time_ms'] = filtered_df['cann_start_ts'] / ns_to_ms + filtered_df['end_time_ms'] = filtered_df['cann_end_ts'] / ns_to_ms + filtered_df['duration_ms'] = filtered_df['end_time_ms'] - filtered_df['start_time_ms'] + + for row in filtered_df.itertuples(): + event_data = { + 'name': row.msg, + "roll": roll, + 'domain': row.domain, + 'start_time_ms': row.start_time_ms, + 'end_time_ms': row.end_time_ms, + 'duration_ms': row.duration_ms, + 'host_uid': row.host_uid, + 'device_id': row.device_id, + 'rank_id': rank_id, + 'tid': row.tid + } + events.append(event_data) + + return events + + def reducer_func(self, mapper_res): + """Process data collected from all ranks""" + # Remove None results + reduce_results = [result for result in mapper_res if result is not None] + + if not reduce_results: + logger.warning("No valid data collected from any rank") + return + + self.events_summary = [] + for events in reduce_results: + self.events_summary.extend(events) + + roll_rank_to_comm_groups = {} + for event in self.events_summary: + if event['domain'] == self.COMMUNICATION_GROUP_DOMAIN: + roll_rank_to_comm_groups.setdefault((event['roll'], event['rank_id']), set()).add(event["name"]) + + for event in self.events_summary: + groups_set = roll_rank_to_comm_groups.get((event['roll'], event['rank_id']), set()) + event["communication_group"] = 
",".join(groups_set) if groups_set else "" + + self.events_summary = pd.DataFrame(self.events_summary) + + def generate_rl_event_statistic(self): + """Generate both event summary and statistic data outputs""" + # Group by roll, rank_id, and domain to calculate statistics + statistic_data = [] + for (roll, rank_id, domain), group in self.events_summary.groupby(['roll', 'rank_id', 'domain']): + duration_stats = group['duration_ms'].agg(['count', 'sum', 'max', 'min', 'mean']) + + statistic_data.append({ + 'Roll': roll, + 'Rank ID': rank_id, + 'Domain': domain, + 'Count': int(duration_stats['count']), + 'Total Time (ms)': round(duration_stats['sum'], 3), + 'Max Time (ms)': round(duration_stats['max'], 3), + 'Min Time (ms)': round(duration_stats['min'], 3), + 'Mean Time (ms)': round(duration_stats['mean'], 3) + }) + + statistic_df = pd.DataFrame(statistic_data) + statistic_df = statistic_df.sort_values(['Roll', 'Rank ID']) + + excel_utils = ExcelUtils() + excel_utils.create_excel_writer(self._output_path, self.EVENT_STATISTIC_FILE, statistic_df) + excel_utils.merge_duplicate_cells(self.EXCEL_MERGE_COLUMNS) + excel_utils.set_column_width(self.EXCEL_COLUMN_WIDTHS) + excel_utils.save_and_close() + excel_utils.clear() + + def generate_rl_event_summary(self): + self.dump_data(self.events_summary, self.EVENT_SUMMARY_FILE, index=False) + + def generate_timeline_thumbnail(self): + output_path = os.path.join(self.output_path, self.TIMELINE_THUMBNAIL_FILE) + RLAnalysisChart(self.events_summary).plot_timeline_thumbnail(output_path=output_path) + + def save_notebook(self): + """Save results in notebook format""" + self.create_notebook("stats.ipynb") + self.add_helper_file("cluster_display.py") + + def save_db(self): + """Save to database format""" + if self.events_summary is not None: + self.dump_data( + self.events_summary, + Constant.DB_CLUSTER_COMMUNICATION_ANALYZER, + self.TABLE_RL_EVENTS, + index=False + ) + + def mapper_func(self, context): + return context.wait( + 
context.map( + self._mapper_func, + self._get_rank_db_with_roll(), + analysis_class=self._recipe_name + ) + ) + + def _mapper_func(self, data_map, analysis_class): + """Collect RL performance data from a single rank""" + profiler_db_path = data_map.get(Constant.PROFILER_DB_PATH) + rank_id = data_map.get(Constant.RANK_ID) + roll = data_map.get(Constant.ROLL) + + if not profiler_db_path: + logger.warning(f"Rank {rank_id}: profiler_db_path not found") + return None + + return self.parse_rl_mstx_event(profiler_db_path, rank_id, roll) + + def _get_rank_db_with_roll(self): + """Get database path information for all ranks""" + if self._rank_list != 'all': + logger.warning("RL analysis currently only supports processing all ranks") + rank_ids_with_roll = list(self._data_map.keys()) + db_paths = [] + for roll, rank_id in rank_ids_with_roll: + rank_path = self._data_map[(roll, rank_id)] + db_path_dict = {Constant.RANK_ID: rank_id, Constant.ROLL: roll, Constant.PROFILER_DB_PATH: ""} + profiler_db_path = self._get_profiler_db_path(rank_id, rank_path) + if os.path.exists(profiler_db_path): + db_path_dict[Constant.PROFILER_DB_PATH] = profiler_db_path + db_paths.append(db_path_dict) + else: + logger.warning(f"Profiler DB file not found, rank id: {rank_id}, db path: {profiler_db_path}.") + + return db_paths diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/rl_analysis/stats.ipynb b/profiler/msprof_analyze/cluster_analyse/recipes/rl_analysis/stats.ipynb new file mode 100644 index 000000000..fdbc72589 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/rl_analysis/stats.ipynb @@ -0,0 +1,206 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 强化学习事件数据分析\n", + "\n", + "集群场景强化学习事件数据分析\n", + "\n", + "主要包含以下分析内容:\n", + "1. mstx打点事件数据总览和统计\n", + "2. 按 \"强化学习计算阶段(Roll)- 全局rank\" 分组的时间线缩略图分析\n", + "3. 按 \"强化学习计算阶段(Roll)- 通讯域rank\" 分组的时间线缩略图分析" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. 
数据准备和配置" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import display, HTML\n", + "display(HTML(\"\"))\n", + "\n", + "import pandas as pd\n", + "import os\n", + "from cluster_display import RLAnalysisChart\n", + "\n", + "# 配置输出路径(默认为当前目录)\n", + "output_dir = \"./\" # 可以修改为其他路径\n", + "os.makedirs(output_dir, exist_ok=True)\n", + "\n", + "print(\"✅ 配置完成!\")\n", + "print(f\"输出目录: {os.path.abspath(output_dir)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 加载数据\n", + "rl_events_df = pd.read_csv(\"rl_event_summary.csv\")\n", + "print(\"✅ 数据加载完成!\")\n", + "print(f\"总共加载了 {len(rl_events_df)} 个mstx打点事件\")\n", + "print(f\"包含 {rl_events_df['domain'].nunique()} 种不同的mstx domain打点类型\")\n", + "\n", + "# 显示数据基本信息\n", + "print(\"\\n数据列信息:\")\n", + "print(rl_events_df.columns.tolist())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. 初始化RLAnalysisChart实例\n", + "\n", + "可以通过 `selected_domains` 参数选择需要展示的domain打点类型" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 查看所有可用的domain类型\n", + "available_domains = sorted(rl_events_df['domain'].unique())\n", + "print(\"可用的domain类型:\")\n", + "for i, domain in enumerate(available_domains, 1):\n", + " count = len(rl_events_df[rl_events_df['domain'] == domain])\n", + " print(f\"{i}. {domain} (共{count}个事件)\")\n", + "\n", + "# 初始化RLAnalysisChart实例\n", + "# 可以通过selected_domains参数指定要显示的domain类型\n", + "# 例如: selected_domains=['domain1', 'domain2']\n", + "rl_chart = RLAnalysisChart(rl_events_df)\n", + "print(\"\\n✅ RLAnalysisChart实例初始化完成!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. 
查看强化学习计算阶段(Roll)信息" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"数据中包含的强化学习计算阶段(Roll):\")\n", + "available_rolls = sorted(rl_events_df['roll'].unique())\n", + "for i, roll in enumerate(available_rolls, 1):\n", + " count = len(rl_events_df[rl_events_df['roll'] == roll])\n", + " print(f\"{i}. {roll} (共{count}个事件)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4. Timeline缩略图(按 \"强化学习计算阶段(Roll)- 全局rank\" 分组)\n", + "\n", + "生成按Roll分组的时间线图,展示不同roll上的事件执行时序,按照roll-rank进行分组打点时间,生成timeline缩略图" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 生成按Roll分组的timeline缩略图\n", + "output_path_roll = os.path.join(output_dir, \"rl_timeline_by_roll.png\")\n", + "\n", + "print(\"正在生成按Roll分组的timeline缩略图...\")\n", + "rl_chart.plot_timeline_thumbnail(\n", + " output_path=output_path_roll,\n", + " show_chart=True,\n", + ")\n", + "print(f\"✅ 按Roll分组的timeline缩略图生成完成!\")\n", + "print(f\"图片保存路径: {os.path.abspath(output_path_roll)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 5. Timeline缩略图(按 \"强化学习计算阶段(Roll)- 通讯域rank\" 分组)\n", + "\n", + "由于强化学习各个计算阶段可能存在数据交互,Timeline缩略图也提供了按照 \"强化学习计算阶段(Roll)- 通讯域rank\"对数据进行分组合并的功能\n", + "**使用方法:**\n", + "1. 可从[`dp`,`pp`, `tp`]中选择一个communication rank\n", + "2. 修改下面代码中的 `selected_comm_group` 变量\n", + "3. 
运行代码即可生成对应的timeline缩略图" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 选择要分析的communication group\n", + "selected_comm_group = \"dp\" # 可以修改为 \"pp\" 或 \"tp\"\n", + "\n", + "if selected_comm_group:\n", + " output_path_comm = os.path.join(output_dir, f\"rl_timeline_by_{selected_comm_group}.png\")\n", + "\n", + " print(f\"正在生成按照 {selected_comm_group} rank汇总的timeline缩略图...\")\n", + " rl_chart.plot_timeline_thumbnail(\n", + " output_path=output_path_comm,\n", + " show_chart=True,\n", + " group_by_communication_group_rank=selected_comm_group\n", + " )\n", + " print(f\"✅ {selected_comm_group} 的时间线图生成完成!\")\n", + " print(f\"图片保存路径: {os.path.abspath(output_path_comm)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 数据说明\n", + "\n", + "**图表说明:**\n", + "- **X轴**:时间(毫秒)\n", + "- **Y轴**:设备组和rank信息\n", + "- **颜色**:不同的domain类型\n", + "\n", + "**输出文件:**\n", + "- `rl_event_summary.csv`:原始事件数据\n", + "- `rl_analysis_statistics.xlsx`:统计信息Excel文件\n", + "- `rl_timeline_by_roll.png`:按Roll分组的时间线缩略图\n", + "- `rl_timeline_by_dp.png`:按dp分组的时间线缩略图\n", + "- `rl_timeline_by_pp.png`:按pp分组的时间线缩略图\n", + "- `rl_timeline_by_tp.png`:按tp分组的时间线缩略图\n", + "\n", + "**使用说明:**\n", + "1. 可以通过修改 `output_dir` 变量来指定输出目录\n", + "2. 可以通过修改 `selected_domains` 参数来选择要显示的domain类型\n", + "3. 
可以通过修改 `selected_comm_group` 变量来选择要分析的communication group" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/profiler/msprof_analyze/prof_common/constant.py b/profiler/msprof_analyze/prof_common/constant.py index b05961137..7b9c88e8b 100644 --- a/profiler/msprof_analyze/prof_common/constant.py +++ b/profiler/msprof_analyze/prof_common/constant.py @@ -205,6 +205,7 @@ class Constant(object): BLUE_COLOR = "00BFFF" LIGHT_BLUE_COLOR = "87CEFA" US_TO_MS = 1000 + NS_TO_US = 1000 KB_TO_MB = 1024 INVALID_VALUE = -1 MILLISECONDS_TO_SECONDS = 10 ** 3 @@ -469,10 +470,12 @@ class Constant(object): PARALLEL_MODE = "parallel_mode" MSPROF_ANALYZE_PATH = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) RECIPES_PATH = os.path.join(MSPROF_ANALYZE_PATH, 'cluster_analyse', 'recipes') + MUTI_TASK_RECIPES = ["rl_analysis"] CONCURRENT_MODE = "concurrent" PROFILER_DB_PATH = "profiler_db_path" ANALYSIS_DB_PATH = "analysis_db_path" + ROLL = "roll" RANK_LIST = "rank_list" EXPORT_TYPE = "export_type" EXTRA_ARGS = "args" diff --git a/profiler/msprof_analyze/prof_exports/mstx_event_export.py b/profiler/msprof_analyze/prof_exports/mstx_event_export.py new file mode 100644 index 000000000..5c5f6afd9 --- /dev/null +++ b/profiler/msprof_analyze/prof_exports/mstx_event_export.py @@ -0,0 +1,44 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from msprof_analyze.prof_exports.base_stats_export import BaseStatsExport + +HOST_MSTX_QUERY = ''' +SELECT + MSG_IDS.value AS "msg", + DOMAIN_IDS.value AS "domain", + MSTX_EVENTS.startNs AS "cann_start_ts", + MSTX_EVENTS.endNs AS "cann_end_ts", + MSTX_EVENTS.globalTid AS "tid", + (SELECT hostUid FROM HOST_INFO LIMIT 1) AS "host_uid", + (SELECT id FROM NPU_INFO LIMIT 1) AS "device_id" +FROM + MSTX_EVENTS +LEFT JOIN + STRING_IDS AS MSG_IDS + ON MSTX_EVENTS.message = MSG_IDS.id +LEFT JOIN + STRING_IDS AS DOMAIN_IDS + ON MSTX_EVENTS.domainId = DOMAIN_IDS.id +ORDER BY + MSTX_EVENTS.startNs + ''' + + +class MstxHostExport(BaseStatsExport): + + def __init__(self, db_path, recipe_name): + super().__init__(db_path, recipe_name) + self._query = HOST_MSTX_QUERY -- Gitee