diff --git a/OWNERS b/OWNERS
index 7b721dd643e3399d29ff649e2f76182de72421a4..2e949debf181a6e75fdb5b1e1e091ce7a39c7e69 100644
--- a/OWNERS
+++ b/OWNERS
@@ -13,35 +13,14 @@ approvers:
- kun_8
- binghamhuang
reviewers:
-- leo920320
-- wo-wenjie
-- ma-dongfang
-- wuyulong11
-- alysongirl
-- wangchao285
-- brightlyking
-- chenhao_1209
-- feng123www
-- zhang-mingyu-0813
-- snowflakephoenix
-- Seanesmhxocism
-- augboost
-- fanxiaotong1995
-- sunboquan
-- kun_8
-- Martin-M
-- ly-qianxiao
-- yang-minghai22
-- hu-xiao-bo
- lv-kaimeng
- litian_drinksnow
-- blian
-- cycoe
-- machj
-- zhengweifeng6
-- gong-siwei
-- uniteone
- binghamhuang
-- wjchuee
-- zhou-xianqi
-- stby11
\ No newline at end of file
+- wo-wenjie
+- ly-qianxiao
+- leo920320
+- sunboquan
+- stby
+- Seanesmhxocism
+- TAJh
+- czr9775
\ No newline at end of file
diff --git a/README.md b/README.md
index 014a4d59f07116c2519de1b2a69463484549642d..7c28869d371debd863409bd8dbe75baf586a5c7e 100644
--- a/README.md
+++ b/README.md
@@ -1,96 +1,84 @@
-# 变更通知
+# 🚨 重要通知
-原Ascend Training Tools工具更名为MindStudio Training Tools,MindStudio训练工具链。变更计划如下:
+**1. Ascend Training Tools 更名为 MindStudio Training Tools (mstt)。**
-1. 2024.06.25本代码仓名称变更为mstt。
-2. 2024.07.04 URL变更为[https://gitee.com/ascend/mstt](https://gitee.com/ascend/mstt),原始URL仍然可用,但建议使用新URL。
+**2. 本代码仓 URL 变更为 [https://gitee.com/ascend/mstt](https://gitee.com/ascend/mstt),原 URL 仍然可用(2024.07.04)。**
-# MindStudio Training Tools
+**3. 不再维护:[api_accuracy_checker](./debug/accuracy_tools/api_accuracy_checker/)(2024.09.30 下线)和 [ptdbg_ascend](./debug/accuracy_tools/ptdbg_ascend/)
+(2024.09.30 下线)**
-MindStudio Training Tools,MindStudio训练工具链。针对训练&大模型场景,提供端到端命令行&可视化调试调优工具,帮助用户快速提高模型开发效率。
+**相关目录 mstt/debug/accuracy_tools/api_accuracy_checker 和 mstt/debug/accuracy_tools/ptdbg_ascend 将于 2024.09.30 删除。新版本的预检和 ptdbg 已经合到 mstt/debug/accuracy_tools/msprobe 目录下。**
-## 模型训练迁移全流程
-
+---
-## 使用说明
+# 🧰 MindStudio Training Tools
-### [分析迁移工具](https://gitee.com/ascend/mstt/wikis/工具介绍/分析迁移工具/分析迁移工具介绍)
+
+
+
+
+## [分析迁移工具](https://gitee.com/ascend/mstt/wikis/工具介绍/分析迁移工具/分析迁移工具介绍)
1. [脚本分析工具](https://gitee.com/ascend/mstt/wikis/%E5%B7%A5%E5%85%B7%E4%BB%8B%E7%BB%8D/%E5%88%86%E6%9E%90%E8%BF%81%E7%A7%BB%E5%B7%A5%E5%85%B7/%E5%88%86%E6%9E%90%E5%B7%A5%E5%85%B7%E4%BD%BF%E7%94%A8%E6%8C%87%E5%AF%BC)
- 脚本分析工具提供分析脚本,帮助用户在执行迁移操作前,分析基于GPU平台的PyTorch训练脚本中算子、三方库套件、亲和API分析以及动态shape的支持情况。
+ 脚本分析工具可以帮助用户在执行迁移操作前,分析基于 GPU 平台的 PyTorch 训练脚本中算子、三方库套件、API 亲和性以及动态 shape 的支持情况。
2. [(推荐)自动迁移工具](https://gitee.com/ascend/mstt/wikis/%E5%B7%A5%E5%85%B7%E4%BB%8B%E7%BB%8D/%E5%88%86%E6%9E%90%E8%BF%81%E7%A7%BB%E5%B7%A5%E5%85%B7/%E8%87%AA%E5%8A%A8%E8%BF%81%E7%A7%BB%E5%B7%A5%E5%85%B7%E4%BD%BF%E7%94%A8%E6%8C%87%E5%AF%BC)
- 自动迁移只需在训练脚本中导入库代码即可完成模型脚本迁移,使用方式较简单,且修改内容最少。
+ 自动迁移工具只需在训练脚本中导入库代码即可完成模型脚本的迁移,使用方式简单,且修改内容少。
3. [脚本迁移工具](https://gitee.com/ascend/mstt/wikis/%E5%B7%A5%E5%85%B7%E4%BB%8B%E7%BB%8D/%E5%88%86%E6%9E%90%E8%BF%81%E7%A7%BB%E5%B7%A5%E5%85%B7/%E8%84%9A%E6%9C%AC%E8%BF%81%E7%A7%BB%E5%B7%A5%E5%85%B7%E4%BD%BF%E7%94%A8%E6%8C%87%E5%AF%BC)
- 脚本迁移工具提供后端命令行用于将GPU上训练的PyTorch脚本迁移至NPU上,得到新的训练脚本用于训练。
+ 脚本迁移工具通过后端命令行,将 GPU 上训练的 PyTorch 脚本迁移至 NPU 上,得到新的训练脚本用于训练。
4. [训推一体权重转换工具](https://gitee.com/Ascend/mstt/wikis/%E5%B7%A5%E5%85%B7%E4%BB%8B%E7%BB%8D/%E5%88%86%E6%9E%90%E8%BF%81%E7%A7%BB%E5%B7%A5%E5%85%B7/%E8%AE%AD%E6%8E%A8%E4%B8%80%E4%BD%93%E6%9D%83%E9%87%8D%E8%BD%AC%E6%8D%A2%E5%B7%A5%E5%85%B7%E4%BD%BF%E7%94%A8%E6%8C%87%E5%AF%BC)
- 训推一体权重转换工具,支持在GPU和NPU上训练好的模型转成加速推理支持的格式。
-
-### [精度工具](https://gitee.com/ascend/mstt/tree/master/debug/accuracy_tools)
+ 训推一体权重转换工具,支持在 GPU 和 NPU 上训练好的模型转成加速推理支持的格式。
-1. [api_accuracy_checker(Ascend模型精度预检工具)](https://gitee.com/ascend/mstt/tree/master/debug/accuracy_tools/api_accuracy_checker)
+## [精度工具](./debug/accuracy_tools/)
- 在昇腾NPU上扫描用户训练模型中所有API,进行API复现,给出精度情况的诊断和分析。
+[MindStudio Probe(msprobe,MindStudio 精度调试工具)](./debug/accuracy_tools/msprobe)。
-2. [ptdbg_ascend(PyTorch精度工具)](https://gitee.com/ascend/mstt/tree/master/debug/accuracy_tools/ptdbg_ascend)
+## [性能工具](./profiler)
- 进行PyTorch整网API粒度的数据dump、精度比对和溢出检测,从而定位PyTorch训练场景下的精度问题。
+1. [compare_tools(性能比对工具)](./profiler/compare_tools)
-### [性能工具](https://gitee.com/ascend/mstt/tree/master/profiler)
+ 提供 NPU 与 GPU 性能拆解功能以及算子、通信、内存性能的比对功能。
-1. [compare_tools(性能比对工具)](https://gitee.com/ascend/mstt/tree/master/profiler/compare_tools)
+2. [cluster_analyse(集群分析工具)](./profiler/cluster_analyse)
- 提供NPU与GPU性能拆解功能以及算子、通信、内存性能的比对功能。
+   提供多机多卡的集群分析能力(基于通信域的通信分析和迭代耗时分析),当前需要配合 MindStudio Insight 的集群分析功能使用。
-2. [cluster_analyse(集群分析工具)](https://gitee.com/ascend/mstt/tree/master/profiler/cluster_analyse)
+3. [advisor](./profiler/advisor)
- 提供多机多卡的集群分析能力(基于通信域的通信分析和迭代耗时分析), 当前需要配合Ascend Insight的集群分析功能使用。
+ 将 Ascend PyTorch Profiler 或者 msprof 采集的 PyTorch 场景性能数据进行分析,并输出性能调优建议。
-3. [affinity_cpu_bind (亲和性cpu绑核工具) ](https://gitee.com/ascend/mstt/tree/master/profiler/affinity_cpu_bind)
+4. [bind_core](./profiler/affinity_cpu_bind)
- 提供亲和性CPU绑核能力,改善host_bound调度问题。
+ 绑核脚本,支持非侵入修改工程代码,实现一键式绑核功能。
-### [Tensorboard](https://gitee.com/ascend/mstt/tree/master/plugins/tensorboard-plugins/tb_plugin)
+## [Tensorboard](./plugins/tensorboard-plugins/tb_plugin)
-Tensorboard支持NPU性能数据可视化插件PyTorch Profiler TensorBoard NPU Plugin。
+Tensorboard 支持 NPU 性能数据可视化插件 PyTorch Profiler TensorBoard NPU Plugin。
-支持将Ascend平台采集、解析的Pytorch Profiling数据可视化呈现,也兼容GPU数据采集、解析可视化。
+支持将 Ascend 平台采集、解析的 PyTorch Profiling 数据可视化呈现,也兼容 GPU 数据采集、解析可视化。
## 分支维护策略
-MindStudio Training Tools工具版本分支的维护阶段如下:
-
-| **状态** | **时间** | **说明** |
-| ------------------- | -------- | ------------------------------------------------ |
-| 计划 | 1—3 个月 | 计划特性 |
-| 开发 | 3个月 | 开发特性 |
-| 维护 | 6—12个月 | 合入所有已解决的问题并发布版本 |
-| 无维护 | 0—3 个月 | 合入所有已解决的问题,无专职维护人员,无版本发布 |
-| 生命周期终止(EOL) | N/A | 分支不再接受任何修改 |
-
-## 现有分支的维护状态
-
-MindStudio Training Tools分支版本号命名规则如下:
-
-mstt仓每年发布4个版本,每个版本都将对应一个分支;以v6.0为例,其将对应v6.0.RC1、v6.0.RC2、v6.0.RC3以及v6.0.0四个版本,在仓库中将存在与之对应的分支。
-
-| **分支** | **状态** | **发布日期** | **后续状态** | **EOL日期** |
-| ------------- | -------- | ------------ | ------------------------ | ----------- |
-| **v6.0.0** | 维护 | 2023/12/12 | 预计2024/12/12起无维护 | |
+1. MindStudio Training Tools 工具版本分支的维护阶段如下:
-## 参与贡献
+ | **状态** | **时间** | **说明** |
+ | ------------------- | -------- | ------------------------------------------------ |
+ | 计划 | 1—3 个月 | 计划特性 |
+ | 开发 | 3个月 | 开发特性 |
+ | 维护 | 6—12个月 | 合入所有已解决的问题并发布版本 |
+ | 无维护 | 0—3 个月 | 合入所有已解决的问题,无专职维护人员,无版本发布 |
+ | 生命周期终止(EOL) | N/A | 分支不再接受任何修改 |
-1. Fork 本仓库
-2. 新建 xxx 分支
-3. 提交代码
-4. 新建 Pull Request
+2. MindStudio Training Tools 分支版本号命名规则如下:
-## 版本过渡提示
+ mstt 仓每年发布 4 个版本,每个版本都将对应一个分支;以 v6.0 为例,其将对应 v6.0.RC1、v6.0.RC2、v6.0.RC3 以及 v6.0.0 四个版本,在仓库中将存在与之对应的分支。
-当前版本预检和ptdbg维护到2024/09/30,准备于2024/09/30下线,相关目录mstt/debug/accuracy_tools/api_accuracy_checker和mstt/debug/accuracy_tools/ptdbg_ascend将于2024/09/30删除。新版本的预检和ptdbg已经合到mstt/debug/accuracy_tools/atat目录下。
+ | **分支** | **状态** | **发布日期** | **后续状态** | **EOL日期** |
+ | ------------- | -------- | ------------ | ------------------------ | ----------- |
+ | **v6.0.0** | 维护 | 2023.12.12 | 预计 2024.12.12 起无维护 | |
diff --git a/debug/OWNERS b/debug/OWNERS
index 09121722c9d7147133c6f111cd10b279979ebdb3..36e09821f37f6abed2a9da211ff7c1ef447218b9 100644
--- a/debug/OWNERS
+++ b/debug/OWNERS
@@ -5,7 +5,13 @@ approvers:
- kun_8
- binghamhuang
- brightlyking
+- litian_drinksnow
reviewers:
- lv-kaimeng
-- litian_drinksnow
- binghamhuang
+- xiangsen2
+- TAJh
+- jiandaobao
+- pengxiaopeng1
+- zhengxinqian
+- louyujing
diff --git a/debug/accuracy_tools/MANIFEST.in b/debug/accuracy_tools/MANIFEST.in
index 7242c0c95627b56620b63c650dbbffbf8aaa2896..1af064685bcca89be479a465fd8b4f466047165f 100644
--- a/debug/accuracy_tools/MANIFEST.in
+++ b/debug/accuracy_tools/MANIFEST.in
@@ -1,2 +1,4 @@
-recursive-include atat/ *
-recursive-exclude atat/test *
\ No newline at end of file
+include README.md
+include LICENSE
+recursive-include msprobe *
+recursive-exclude msprobe/test *
\ No newline at end of file
diff --git a/debug/accuracy_tools/api_accuracy_checker/generate_op_script/operator_replication.template b/debug/accuracy_tools/api_accuracy_checker/generate_op_script/operator_replication.template
new file mode 100644
index 0000000000000000000000000000000000000000..7630839aa937c6d0419629b5e93c34b51b71f295
--- /dev/null
+++ b/debug/accuracy_tools/api_accuracy_checker/generate_op_script/operator_replication.template
@@ -0,0 +1,325 @@
+import json
+import os
+import math
+from enum import Enum, auto
+import torch
+try:
+ import torch_npu
+except ImportError:
+ pass
+
+
+TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
+TORCH_BOOL_TYPE = ["torch.bool"]
+TORCH_INT_TYPE = ["torch.uint8", "torch.int8", "torch.int16", "torch.short", "torch.int32", "torch.int",
+ "torch.int64", "torch.long"]
+TORCH_FLOAT_TYPE = ["torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.float",
+ "torch.float64", "torch.double"]
+TORCH_COMPLEX_TYPE = ["torch.complex32", "torch.chalf", "torch.complex64", "torch.cfloat", "torch.complex128", "torch.cdouble"]
+RAISE_PRECISION = {{
+ "torch.float16": torch.float32,
+ "torch.half": torch.float32,
+ "torch.bfloat16": torch.float32,
+ "torch.float32": torch.float64,
+ "torch.float": torch.float64
+}}
+
+
+class CompareStandard(Enum):
+ BINARY_EQUALITY_STANDARD = auto()
+ ABSOLUTE_THRESHOLD_STANDARD = auto()
+ ULP_ERROR_STANDARD = auto()
+ BENCHMARK_STANDARD = auto()
+
+
+def get_device():
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    elif "torch_npu" in globals() and torch_npu.npu.is_available():
+        device = torch.device("npu")
+    else:
+        raise Exception("Error: This device is not NPU or GPU!")
+ return device
+
+
+def generate_bool_tensor(low, high, shape):
+ low, high = int(low), int(high)
+ tensor = torch.randint(low, high + 1, shape)
+ bool_tensor = torch.gt(tensor, 0)
+ return bool_tensor
+
+
+def generate_numerical_tensor(low, high, shape, data_dtype):
+ if data_dtype in TORCH_FLOAT_TYPE:
+ scale = high - low
+ rand01 = torch.rand(shape, dtype=eval(data_dtype))
+ tensor = rand01 * scale + low
+ elif data_dtype in TORCH_INT_TYPE:
+ low, high = int(low), int(high)
+ tensor = torch.randint(low, high + 1, shape, dtype=eval(data_dtype))
+ else:
+ raise NotImplementedError(f"{{data_dtype}} is not supported!")
+ if torch.numel(tensor) == 0:
+ return tensor
+ tmp_tensor = tensor.reshape(-1)
+ tmp_tensor[0] = low
+ tmp_tensor[-1] = high
+ data = tmp_tensor.reshape(shape)
+ return data
+
+
+def generate_random_tensor(info):
+ low, high = info.get('Min'), info.get('Max')
+ data_dtype = info.get('dtype')
+ shape = tuple(info.get('shape'))
+ if data_dtype == "torch.bool":
+ data = generate_bool_tensor(low, high, shape)
+ else:
+ data = generate_numerical_tensor(low, high, shape, data_dtype)
+ return data
+
+
+def generate_real_tensor(data_path):
+ data_path = os.path.realpath(data_path)
+ data = torch.load(data_path)
+ return data
+
+
+def generate_data(info):
+ data_type = info.get("type")
+ data_path = info.get("datapath")
+ if data_type in TENSOR_DATA_LIST:
+ if data_path:
+ data = generate_real_tensor(data_path)
+ else:
+ data = generate_random_tensor(info)
+ else:
+ data = info.get("value")
+ return data
+
+
+def get_input():
+{args_element_assignment}
+ args_device = [{args_list_generator_device}]
+ args_bench = [{args_list_generator_bench}]
+{kwargs_value_assignment}
+ kwargs_device = {{{kwargs_dict_generator_device}}}
+ kwargs_bench = {{{kwargs_dict_generator_bench}}}
+ return args_device, kwargs_device, args_bench, kwargs_bench
+
+
+def exec_api_device(args, kwargs):
+ output_device = {api_type}.{api_name}(*args, **kwargs)
+ return output_device
+
+
+def exec_api_bench(args, kwargs):
+ output_bench = {api_type}.{api_name}(*args, **kwargs)
+ return output_bench
+
+
+def compute_inf_nan_proportion(inf_nan_mask, out_device, out_bench, abs_bench_with_eps, rtol):
+ out_bench = out_bench.to(out_device.dtype)
+    min_val = torch.finfo(out_device.dtype).min
+    max_val = torch.finfo(out_device.dtype).max
+    bench_clip = torch.clamp(out_bench, min=min_val, max=max_val)
+    device_clip = torch.clamp(out_device, min=min_val, max=max_val)
+ clipped_abs_ae = torch.abs(device_clip - bench_clip)
+ clipped_re = clipped_abs_ae / abs_bench_with_eps
+ pass_mask = torch.less_equal(clipped_re, rtol)
+ both_nan_mask = torch.logical_and(torch.isnan(out_device), torch.isnan(bench_clip))
+ pass_mask = torch.logical_or(pass_mask, both_nan_mask)
+ not_pass_mask = torch.logical_not(pass_mask)
+ not_pass_mask = torch.logical_and(not_pass_mask, inf_nan_mask)
+ inf_nan_err_cnt = torch.sum(not_pass_mask)
+ return 0 if torch.sum(inf_nan_mask) == 0 else inf_nan_err_cnt / torch.sum(inf_nan_mask)
+
+
+def compute_rmse(abs_err, normal_value_mask):
+ if torch.sum(normal_value_mask) == 0:
+ return 0
+ else:
+ masked_ae = torch.where(normal_value_mask, abs_err, 0)
+ mse = torch.sum(torch.square(masked_ae)) / torch.sum(normal_value_mask)
+ rmse = torch.sqrt(mse)
+ return rmse
+
+
+def compute_error_balance(out_device, out_bench):
+ larger_count = torch.sum(torch.greater(out_device - out_bench.to(out_device.dtype), 0))
+ smaller_count = torch.sum(torch.less(out_device - out_bench.to(out_device.dtype), 0))
+ total_count = torch.numel(out_bench)
+ error_balance = abs(larger_count - smaller_count) / total_count
+ return error_balance
+
+
+def compare_tensor(out_device, out_bench, api_name):
+    if out_device.shape != out_bench.shape:
+        print("ERROR: shapes of out_device and out_bench are not equal!")
+ return None
+ if torch.numel(out_bench) == 0:
+ print("Both out_device and out_bench have zero elements.")
+ return None
+ print(f"shape is {{out_bench.shape}}")
+ print(f"dtype of out_device is {{out_device.dtype}}")
+ print(f"dtype of out_bench is {{out_bench.dtype}}")
+ dtype_device = out_device.dtype
+ dtype_bench = out_bench.dtype
+ if str(dtype_device) in TORCH_FLOAT_TYPE and str(dtype_bench) in TORCH_FLOAT_TYPE \
+ or str(dtype_device) in TORCH_INT_TYPE and str(dtype_bench) in TORCH_INT_TYPE \
+ or str(dtype_device) in TORCH_BOOL_TYPE and str(dtype_bench) in TORCH_BOOL_TYPE:
+ out_device = out_device.to(torch.device("cpu"))
+ if str(dtype_device) in TORCH_BOOL_TYPE or str(dtype_device) in TORCH_INT_TYPE or compare_standard == CompareStandard.BINARY_EQUALITY_STANDARD:
+ print("compare standard: binary equality standard:")
+ error_number = torch.sum(out_device != out_bench).item()
+ error_rate = error_number / torch.numel(out_bench)
+ print(f"error rate is {{error_rate}}.")
+ else:
+ abs_err = torch.abs(out_device - out_bench)
+ abs_bench = torch.abs(out_bench)
+            if dtype_bench == torch.float32:
+                eps = 2 ** -23
+            elif dtype_bench == torch.float64:
+                eps = 2 ** -52
+ abs_bench_with_eps = abs_bench + eps
+ rel_err = torch.abs(abs_err / abs_bench_with_eps)
+ device_finite_mask = torch.isfinite(out_device)
+ bench_finite_mask = torch.isfinite(out_bench.to(dtype_device))
+ both_finite_mask = torch.logical_and(device_finite_mask, bench_finite_mask)
+ inf_nan_mask = torch.logical_not(both_finite_mask)
+ if compare_standard == CompareStandard.ABSOLUTE_THRESHOLD_STANDARD:
+ if dtype_device == torch.float16:
+ rtol, small_value, small_value_atol = 1.0e-3, 1.0e-3, 1.0e-5
+ elif dtype_device == torch.bfloat16:
+ rtol, small_value, small_value_atol = 4.0e-3, 1.0e-3, 1.0e-5
+ else:
+ rtol, small_value, small_value_atol = 1.0e-6, 1.0e-6, 1.0e-9
+ small_value_mask = torch.less_equal(abs_bench, small_value)
+ small_value_mask = torch.logical_and(small_value_mask, both_finite_mask)
+ normal_value_mask = torch.logical_and(both_finite_mask, torch.logical_not(small_value_mask))
+ inf_nan_proportion = compute_inf_nan_proportion(inf_nan_mask, out_device, out_bench, abs_bench_with_eps, rtol)
+ rel_err_mask = torch.greater(rel_err, rtol)
+ rel_err_mask = torch.logical_and(rel_err_mask, normal_value_mask)
+ if torch.sum(normal_value_mask) == 0:
+ rel_err_proportion = 0
+ else:
+ rel_err_proportion = torch.sum(rel_err_mask) / torch.sum(normal_value_mask)
+ abs_err_mask = torch.greater(abs_err, small_value_atol)
+ abs_err_mask = torch.logical_and(abs_err_mask, small_value_mask)
+ if torch.sum(small_value_mask) == 0:
+ abs_err_proportion = 0
+ else:
+ abs_err_proportion = torch.sum(abs_err_mask) / torch.sum(small_value_mask)
+ print("compare standard: absolute threshold standard")
+ print(f"relative error ratio is {{rel_err_proportion}}")
+ print(f"absolute error ratio is {{abs_err_proportion}}")
+ elif compare_standard == CompareStandard.ULP_ERROR_STANDARD:
+ if dtype_device == torch.float16:
+ min_eb, exponent_num = -14, 10
+ elif dtype_device == torch.bfloat16:
+ min_eb, exponent_num = -126, 7
+ else:
+ min_eb, exponent_num = -126, 23
+ eb = torch.where(abs_bench == 0, torch.zeros(out_bench.shape), torch.floor(torch.log2(abs_bench)))
+ eb = torch.maximum(eb, min_eb * torch.ones(out_bench.shape))
+ if dtype_device == torch.float32:
+ ulp_err = (out_device.to(torch.float64) - out_bench).to(torch.float64) * torch.exp2(-eb + exponent_num).to(torch.float64)
+ else:
+ ulp_err = (out_device.to(torch.float32) - out_bench).to(torch.float32) * torch.exp2(-eb + exponent_num).to(torch.float32)
+ ulp_err = torch.abs(ulp_err)
+ max_ulp_err = torch.max(ulp_err)
+ mean_ulp_err = torch.mean(ulp_err)
+ if dtype_device == torch.float32:
+ ulp_err_proportion = torch.sum(ulp_err > 32) / torch.numel(out_bench)
+ else:
+ ulp_err_proportion = torch.sum(ulp_err > 1) / torch.numel(out_bench)
+ print("compare standard: ulp error standard")
+ print(f"maximum ulp error is {{max_ulp_err}}")
+ print(f"mean ulp error is {{mean_ulp_err}}")
+ print(f"ulp error proportion is {{ulp_err_proportion}}")
+ else:
+ if dtype_device == torch.float16:
+ small_value, small_value_atol = 1.0e-3, 1.0e-5
+ elif dtype_device == torch.bfloat16:
+ small_value, small_value_atol = 1.0e-3, 1.0e-5
+ else:
+ small_value, small_value_atol = 1.0e-6, 1.0e-9
+ small_value_mask = torch.less_equal(abs_bench, small_value)
+ small_value_mask = torch.logical_and(small_value_mask, both_finite_mask)
+ normal_value_mask = torch.logical_and(both_finite_mask, torch.logical_not(small_value_mask))
+ abs_err_mask = torch.greater(abs_err, small_value_atol)
+ abs_err_mask = torch.logical_and(abs_err_mask, small_value_mask)
+ if torch.sum(small_value_mask) == 0:
+ small_value_err_proportion = 0
+ else:
+ small_value_err_proportion = torch.sum(abs_err_mask) / torch.sum(small_value_mask)
+ rel_err = torch.where(normal_value_mask, rel_err, -1 * torch.ones(out_device.shape))
+ if torch.max(rel_err) >= 0:
+ max_rel_err = torch.max(rel_err)
+ else:
+ max_rel_err = 0
+ if torch.sum(normal_value_mask) == 0:
+ mean_rel_err = 0
+ else:
+ mean_rel_err = torch.sum(torch.clamp(rel_err, min=0)) / torch.sum(normal_value_mask)
+ rmse = compute_rmse(abs_err, normal_value_mask)
+ error_balance = compute_error_balance(out_device, out_bench)
+ print("compare standard: benchmark standard")
+ print(f"small value error proportion is {{small_value_err_proportion}}")
+ print(f"maximum relative error is {{max_rel_err}}")
+ print(f"mean relative error is {{mean_rel_err}}")
+ print(f"root mean squared error is {{rmse}}")
+ print(f"error balance is {{error_balance}}")
+ else:
+ print(f"ERROR: out_device dtype is {{dtype_device}}, out_bench dtype is {{dtype_bench}}, not comparable.")
+ return None
+
+
+def compare_element(out_device, out_bench, api_name):
+    if type(out_device) != type(out_bench):
+        print("ERROR: out_device and out_bench are not of the same type!")
+ return None
+ if isinstance(out_bench, torch.Tensor):
+ print(f"data type: {{type(out_bench)}}")
+ compare_tensor(out_device, out_bench, api_name)
+ elif isinstance(out_bench, (bool, int, float, str)):
+ print(f"data type: {{type(out_bench)}}")
+        if out_device == out_bench:
+            print("PASS: out_device and out_bench are equal.")
+        else:
+            print("ERROR: out_device and out_bench are not equal!")
+ else:
+ print(f"ERROR: comparison of type {{type(out_bench)}} is not supported.")
+ return None
+
+
+def compare(out_device, out_bench, api_name):
+ print("Compare result:")
+    if type(out_device) != type(out_bench):
+        print("ERROR: out_device and out_bench are not of the same type!")
+ print("Compare finished.")
+ return None
+ if isinstance(out_bench, (list, tuple)):
+ print(f"data type: {{type(out_bench)}}")
+ if len(out_device) != len(out_bench):
+            print("ERROR: lengths of out_device and out_bench differ!")
+ print("Compare finished.")
+ return None
+ for index, _ in enumerate(out_bench):
+ print(f"index {{index}}:")
+ compare_element(out_device[index], out_bench[index], api_name)
+ else:
+ compare_element(out_device, out_bench, api_name)
+ print("Compare finished.")
+
+
+device = get_device()
+api_name = "{api_name}"
+compare_standard = {compare_standard}
+torch.manual_seed({random_seed})
+for i in range({iter_times}):
+ print(f"iter: {{i}}:")
+ args_device, kwargs_device, args_bench, kwargs_bench = get_input()
+ output_device = exec_api_device(args_device, kwargs_device)
+ output_bench = exec_api_bench(args_bench, kwargs_bench)
+ compare(output_device, output_bench, api_name)
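The doubled braces (`{{ }}`) throughout the template above escape literal Python braces, while single-brace fields such as `{api_name}` and `{iter_times}` are left for the generator to substitute. A minimal sketch of that rendering step, assuming `str.format`-style substitution (the generator itself is not part of this diff, so the mechanism and field values below are illustrative):

```python
# Sketch of rendering an operator_replication-style template.
# {{ }} survives .format() as a literal brace; {field} is substituted.
template = (
    'RAISE_PRECISION = {{"torch.float16": "torch.float32"}}\n'
    'api_name = "{api_name}"\n'
    'torch.manual_seed({random_seed})\n'
    'for i in range({iter_times}):\n'
    '    pass\n'
)
rendered = template.format(api_name="torch.add", random_seed=1234, iter_times=3)
print(rendered)
```

After substitution the output is a self-contained replay script, which is what the `generate_op_script` directory name suggests this template is for.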
diff --git a/debug/accuracy_tools/atat/mindspore/__init__.py b/debug/accuracy_tools/atat/mindspore/__init__.py
deleted file mode 100644
index bb3f93567542e93ff913edf3daabcd3aedb91ee3..0000000000000000000000000000000000000000
--- a/debug/accuracy_tools/atat/mindspore/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from atat.mindspore.debugger.precision_debugger import PrecisionDebugger
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/compare/test_acc_compare.py b/debug/accuracy_tools/atat/test/pytorch_ut/compare/test_acc_compare.py
deleted file mode 100644
index 5a82289a0003a332e41d9202b63e7bde4bc43c42..0000000000000000000000000000000000000000
--- a/debug/accuracy_tools/atat/test/pytorch_ut/compare/test_acc_compare.py
+++ /dev/null
@@ -1,17 +0,0 @@
-# coding=utf-8
-import unittest
-from atat.pytorch.compare.acc_compare import rename_api
-
-class TestUtilsMethods(unittest.TestCase):
-
- def test_rename_api(self):
- test_name_1 = "Distributed.broadcast.0.forward.input.0"
- expect_name_1 = "Distributed.broadcast.input.0"
- actual_name_1 = rename_api(test_name_1, "forward")
- self.assertEqual(actual_name_1, expect_name_1)
-
- test_name_2 = "Torch.sum.0.backward.output.0"
- expect_name_2 = "Torch.sum.output.0"
- actual_name_2 = rename_api(test_name_2, "backward")
- self.assertEqual(actual_name_2, expect_name_2)
-
\ No newline at end of file
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py b/debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py
deleted file mode 100644
index fa52fe0e1b05701cec9c8cdf41fe1586029c826e..0000000000000000000000000000000000000000
--- a/debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py
+++ /dev/null
@@ -1,38 +0,0 @@
-from unittest import TestCase
-from unittest.mock import patch, mock_open
-
-from atat.core.common.const import Const
-from atat.pytorch.pt_config import parse_json_config
-
-
-class TestPtConfig(TestCase):
- def test_parse_json_config(self):
- mock_json_data = {
- "task": "statistics",
- "dump_path": "./dump/",
- "rank": [],
- "step": [],
- "level": "L1",
- "seed": 1234,
- "statistics": {
- "scope": [],
- "list": [],
- "data_mode": ["all"],
- },
- "tensor": {
- "file_format": "npy"
- }
- }
- with patch("atat.pytorch.pt_config.os.path.join", return_value="/path/config.json"), \
- patch("atat.pytorch.pt_config.FileOpen", mock_open(read_data='')), \
- patch("atat.pytorch.pt_config.json.load", return_value=mock_json_data):
- common_config, task_config = parse_json_config(None, None)
- self.assertEqual(common_config.task, Const.STATISTICS)
- self.assertEqual(task_config.data_mode, ["all"])
-
- with patch("atat.pytorch.pt_config.os.path.join", return_value="/path/config.json"), \
- patch("atat.pytorch.pt_config.FileOpen", mock_open(read_data='')), \
- patch("atat.pytorch.pt_config.json.load", return_value=mock_json_data):
- common_config, task_config = parse_json_config(None, Const.TENSOR)
- self.assertEqual(common_config.task, Const.STATISTICS)
- self.assertEqual(task_config.file_format, "npy")
diff --git a/debug/accuracy_tools/graph_analyzer/README.md b/debug/accuracy_tools/graph_analyzer/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..cf0ad58c363d88e20787de5ea20d19df6da40074
--- /dev/null
+++ b/debug/accuracy_tools/graph_analyzer/README.md
@@ -0,0 +1,53 @@
+# Graph Analyzer
+
+#### 介绍
+图分析精度工具
+
+#### 软件架构
+软件架构说明
+
+
+#### 安装教程
+
+1. 下载源码
+```
+git clone https://gitee.com/ascend/mstt.git -b poc
+```
+2. pip安装
+```
+cd debug/accuracy_tools/graph_analyzer
+pip install .
+```
+
+#### 使用说明
+- ir图使用推荐:
+ - ir图推荐使用`anf_after_graph_build`图
+
+#### 功能说明
+**`<>`表示必选参数,`[]`表示可选参数**
+1. ir图结构分析
+使用方式:
+
+```
+graph_analyzer --ir <ir文件路径> [--output <输出目录>]
+```
+执行后,会自动分析ir文件,将ir文件分析后的结果输出到指定output目录下的struct.json,如果未指定output则默认为当前目录
+
+
+2. 数码关联
+数码关联是指数据和代码调用栈的关联,数据一般意义上指静态图`O0`,`O1`,`O2`下dump下来的数据
+目前支持:
+- [x] 全量[tensor(npy)]数据格式的数码关联
+- [x] 统计值[statistic]数据格式的数码关联
+- [x] 融合算子场景
+- [x] 支持超长算子名dump文件的自动解析
+- [x] 反向算子的正向绑定
+
+使用方式:
+
+```
+graph_analyzer --ir <ir文件路径> --data <dump数据路径> [--output <输出目录>]
+```
+
+- 如果是全量模式,则会把数据文件路径和代码调用栈的关联关系存到output路径下的mapping.csv中
+- 如果是统计值模式,则会把统计值csv中每个条目加上该条目对应的代码栈
diff --git a/debug/accuracy_tools/atat/__init__.py b/debug/accuracy_tools/graph_analyzer/graph_analyzer/__init__.py
similarity index 100%
rename from debug/accuracy_tools/atat/__init__.py
rename to debug/accuracy_tools/graph_analyzer/graph_analyzer/__init__.py
diff --git a/debug/accuracy_tools/graph_analyzer/graph_analyzer/bind.py b/debug/accuracy_tools/graph_analyzer/graph_analyzer/bind.py
new file mode 100644
index 0000000000000000000000000000000000000000..5958b618fa1c4eb46ce9d1785dc35794ebaeeb9b
--- /dev/null
+++ b/debug/accuracy_tools/graph_analyzer/graph_analyzer/bind.py
@@ -0,0 +1,159 @@
+import os
+import logging
+import glob
+from typing import Dict, List
+from pathlib import Path
+import pandas as pd
+from graph_analyzer.graph import GraphNode
+
+
+# 定义Trie节点
+class TrieNode:
+ def __init__(self):
+ self.children = {}
+ self.is_end_of_key = False
+ self.value = None
+
+# 定义Trie树
+class Trie:
+ def __init__(self):
+ self.root = TrieNode()
+
+ # 向Trie中插入一个键
+ def insert(self, key, value):
+ node = self.root
+ for char in key:
+ if char not in node.children:
+ node.children[char] = TrieNode()
+ node = node.children[char]
+ # 标记结束位置
+ node.is_end_of_key = True
+ node.value = value
+
+ # 在name字符串中查找所有匹配的键
+ def search_in_string(self, string):
+ matched_values = []
+ for i in range(len(string)):
+ node = self.root
+ j = i
+ # 从字符串的每个字符开始,逐字符查找匹配
+ while j < len(string) and string[j] in node.children:
+ node = node.children[string[j]]
+ if node.is_end_of_key:
+ matched_values.append(node.value)
+ j += 1
+ return matched_values
+
+# 定义匹配函数
+def match_codes(trie, name):
+ matched_nodes = trie.search_in_string(name)
+ matched_codes = ['\n'.join(ii.code_info) for ii in matched_nodes]
+ return '\n'.join(matched_codes)
+
+
+def match_names(trie, name):
+ matched_nodes = trie.search_in_string(name)
+ matched_names = [ii.scope for ii in matched_nodes]
+ return '\n'.join(matched_names)
+
+
+def complex_map(df, match_dict):
+    # 构建Trie树并插入所有键
+ trie = Trie()
+ for key, value in match_dict.items():
+ trie.insert(key, value)
+
+ df['Code Stack'] = df['Op Name'].apply(lambda name: match_codes(trie, name))
+ df['Scope Name'] = df['Op Name'].apply(lambda name: match_names(trie, name))
+ return df
+
+
+def find_npy_files(npy_path):
+ npy_files = []
+ # 检查当前路径是否是一个以 .npy 结尾的文件
+ if npy_path.endswith('npy') and os.path.isfile(npy_path):
+ npy_files.append(Path(npy_path).resolve())
+ return npy_files
+
+ npy_files = list(Path(npy_path).rglob('*.npy'))
+ return npy_files
+
+
+def write_to_csv(param: Dict, output_dir: str, append: bool):
+ # 打开CSV文件以写入模式
+ os.makedirs(output_dir, exist_ok=True)
+ file_name = os.path.join(output_dir, "code.csv")
+ data = [(name, res1, res2) for name, (res1, res2) in param.items()]
+ df = pd.DataFrame(data, columns=['File Path', 'Code Stacks', 'Scope Name'])
+ # 如果 append 为 True 并且文件已经存在,追加写入
+ if append and os.path.exists(file_name):
+ # 清洗数据,筛选掉空字符串
+ df = df[(df['Code Stacks'] != '') | (df['Scope Name'] != '')]
+ df.to_csv(file_name, mode='a', header=False, index=False)
+ # 否则,覆盖写入或者文件不存在时正常写入
+ else:
+ df.to_csv(file_name, mode='w', header=True, index=False)
+
+
+def find_statistic_files(directory):
+ if not os.path.isdir(directory):
+ return []
+ pattern = os.path.join(directory, '**', "statistic.csv")
+ statistic_files = list(glob.glob(pattern))
+ return statistic_files
+
+
+
+def bind_for_statistic(statistic_files: List[str], match_dict: Dict):
+ for statistic_file in statistic_files:
+ df = pd.read_csv(statistic_file)
+ df = complex_map(df, match_dict)
+ logging.info("Processing %s completed, code stack saved in %s", statistic_file, statistic_file)
+ df.to_csv(statistic_file, index=False)
+
+
+def bind_code_info_for_data(input_dir: str, nodes: Dict[str, GraphNode]) -> Dict[str, str]:
+ # 待重构后优化性能
+ match_dict = {}
+ for node in nodes.values():
+ # 屏蔽子图节点
+ if node.is_subgraph:
+ continue
+ # 获取规范化后的scope name
+ scope_name = node.scope.replace("/", "_")
+ match_dict[scope_name] = node
+ npy_files = find_npy_files(input_dir)
+
+ bind_result = {}
+ if not npy_files:
+ statistic_files = find_statistic_files(input_dir)
+ if statistic_files:
+ bind_for_statistic(statistic_files, match_dict)
+ return bind_result
+
+ for npy_file in npy_files:
+ directory, file_name = os.path.split(npy_file) # 拆分路径
+ name_without_ext = os.path.splitext(file_name)[0] # 提取文件名(去掉扩展名)
+ if '.' not in name_without_ext:
+            # 读取同目录下的 mapping.csv 文件
+ csv_file_path = os.path.join(directory, 'mapping.csv')
+ df = pd.read_csv(csv_file_path, header=None)
+
+            # 查找是否有与该npy文件匹配的条目
+ matching_row = df[df[0] == file_name] # 假设A列存储文件名
+ if not matching_row.empty:
+ corresponding_name = matching_row[1].values[0]
+ logging.info("The corresponding name in column B is: %s", corresponding_name)
+            else:
+                logging.info("No entry found for %s in mapping.csv.", file_name)
+                continue
+            name_without_ext = os.path.splitext(corresponding_name)[0]
+ npy_path = os.path.realpath(npy_file)
+ node_scope = name_without_ext.split(".")[1]
+ trie = Trie()
+ for key, value in match_dict.items():
+ trie.insert(key, value)
+ bind_code = match_codes(trie, node_scope)
+ bind_name = match_names(trie, node_scope)
+ bind_result[npy_path] = (bind_code, bind_name)
+ return bind_result
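The matching in `bind.py` hinges on `Trie.search_in_string`: every inserted key that occurs anywhere as a substring of the query is reported, which is how sanitized scope names are located inside long dump-file op names. A standalone sketch mirroring that logic (the keys and values here are illustrative, not real scope names):

```python
# Minimal reproduction of the substring-matching Trie used in bind.py.
class TrieNode:
    def __init__(self):
        self.children = {}
        self.is_end_of_key = False
        self.value = None


class Trie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, key, value):
        node = self.root
        for char in key:
            node = node.children.setdefault(char, TrieNode())
        node.is_end_of_key = True
        node.value = value

    def search_in_string(self, string):
        # Start a fresh walk from every position, so keys embedded
        # anywhere in the string are found.
        matched = []
        for i in range(len(string)):
            node = self.root
            j = i
            while j < len(string) and string[j] in node.children:
                node = node.children[string[j]]
                if node.is_end_of_key:
                    matched.append(node.value)
                j += 1
        return matched


trie = Trie()
trie.insert("Conv2d", "conv-node")
trie.insert("Dense", "dense-node")
matches = trie.search_in_string("network_Conv2d_Dense_grad")
```

This is linear in the query length times the longest key, which is why `bind_code_info_for_data` builds the Trie once per lookup batch rather than scanning every key against every op name.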
diff --git a/debug/accuracy_tools/graph_analyzer/graph_analyzer/graph.py b/debug/accuracy_tools/graph_analyzer/graph_analyzer/graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..564f56c861f847de5f82f33d2cc91295a0871862
--- /dev/null
+++ b/debug/accuracy_tools/graph_analyzer/graph_analyzer/graph.py
@@ -0,0 +1,117 @@
+from typing import List, Dict, Union
+from collections import defaultdict, deque
+
+class GraphNode:
+ def __init__(self, name: str, pos: int = -1, unique_name: str = "", operator_name: str = "", return_variable: str = "", return_value: str = "",
+ var_inputs: List[str] = None, has_constant_input: bool = False, unique_id: str="", scope: str = "", code_info: List[str] = None,
+ is_subgraph: bool = False, attrs: Union[Dict[str, str], List[str]] = None):
+ self.name = name
+ self.unique_name = unique_name
+ self.pos = pos
+ self.operator_name = operator_name
+ self.return_variable = return_variable
+ self.return_value = return_value
+ self.var_inputs = var_inputs if var_inputs else []
+ self.has_constant_input = has_constant_input
+ self.unique_id = unique_id
+ self.scope = scope
+ self.code_info = code_info if code_info else []
+ self.attrs = attrs if attrs else ({} if not is_subgraph else [])
+ self.nodes = {} # Internal nodes if this is a subgraph
+ self.predecessors = [] # Predecessor nodes
+ self.successors = [] # Successor nodes
+ self.is_subgraph = is_subgraph
+
+ def trace_back_ancestors(self, ancestors: List[str], visited: Dict[str, bool], parser) -> None:
+ if visited[self.unique_name]:
+ return
+ visited[self.unique_name] = True
+ ancestors.append(self.unique_name)
+ for predecessor in self.predecessors:
+ predecessor.trace_back_ancestors(ancestors, visited, parser)
+
+
+class Graph:
+ def __init__(self, nodes):
+ self.nodes = set(nodes.values())
+
+ def topological_sort(self):
+ # Build the in-degree table for every node
+ nodes = self.nodes
+ in_degree = {node: len(node.predecessors) for node in nodes}
+
+ # Seed the queue with all nodes whose in-degree is 0
+ queue = deque([node for node in nodes if in_degree[node] == 0])
+ topo_order = []
+
+ # Kahn's algorithm
+ while queue:
+ node = queue.popleft()
+ topo_order.append(node)
+
+ for successor in node.successors:
+ in_degree[successor] -= 1
+ if in_degree[successor] == 0:
+ queue.append(successor)
+
+ return topo_order
+
+ def find_independent_nodes(self, subset_nodes):
+ # Topologically sort the whole graph
+
+ topo_order = self.topological_sort()
+
+ # Keep the subset as a set for fast membership tests
+ subset_set = set(subset_nodes)
+
+ # Track which subset nodes have already been reached by another subset node
+ visited = set()
+
+ # Collect subset nodes that no other subset node depends on
+ independent_nodes = []
+
+ # Walk the nodes in topological order
+ for node in topo_order:
+ if node in subset_set:
+ # A subset node is independent if it has not been reached yet
+ if node not in visited:
+ independent_nodes.append(node)
+ # Mark all subset successors of this node as reached (depended on)
+ for successor in node.successors:
+ if successor in subset_set:
+ visited.add(successor)
+ return independent_nodes
+
+def find_boundary_nodes(nodes, domain_level):
+ domain_structure = defaultdict(lambda: {'boundary': {'upper': set(), 'lower': set()}, 'nodes': set()})
+
+ for node in nodes:
+ if node.scope.startswith("Gradient"):
+ continue
+ node_new_scope = node.scope.split('/')
+ if domain_level <= len(node_new_scope) - 1: # make sure the last scope level is not used
+ current_domain = '/'.join(node_new_scope[:domain_level])
+ domain_structure[current_domain]['nodes'].add(node)
+
+ for domain, data in domain_structure.items():
+ # Scan the nodes inside this domain for upper and lower boundary nodes
+ for node in data['nodes']:
+ if not node.operator_name.startswith("Prim"):
+ continue
+ node_scope = node.scope.split('/')
+ for succ in node.successors:
+ succ_scope = succ.scope.split('/')
+ if succ.scope.startswith("Gradient") or len(succ_scope) == 2:
+ continue
+ if (succ.operator_name != "Param" and succ.operator_name != "Constant") and node_scope[:domain_level] != succ_scope[:domain_level]:
+ data['boundary']['lower'].add(node.name)
+ for pred in node.predecessors:
+ pred_scope = pred.scope.split('/')
+ if (pred.operator_name != "Param" and pred.operator_name != "Constant") and node_scope[:domain_level] != pred_scope[:domain_level]:
+ data['boundary']['upper'].add(node.name)
+
+ # Recurse into sub-domains
+ sub_nodes = [node for node in data['nodes'] if len(node.scope) > domain_level]
+ if sub_nodes:
+ domain_structure[domain].update(find_boundary_nodes(sub_nodes, domain_level + 1))
+ return domain_structure
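The `topological_sort` method above is a textbook Kahn's algorithm. A minimal self-contained sketch of the same idea, using plain dicts instead of `GraphNode` objects (the node names are illustrative):

```python
from collections import deque

def topo_sort(successors):
    """Kahn's algorithm over an adjacency mapping {node: [successor, ...]}."""
    # Compute in-degrees from the successor lists.
    in_degree = {node: 0 for node in successors}
    for succs in successors.values():
        for succ in succs:
            in_degree[succ] += 1
    # Seed the queue with all nodes of in-degree 0.
    queue = deque(node for node, deg in in_degree.items() if deg == 0)
    order = []
    while queue:
        node = queue.popleft()
        order.append(node)
        for succ in successors[node]:
            in_degree[succ] -= 1
            if in_degree[succ] == 0:
                queue.append(succ)
    return order

# A diamond-shaped graph: a -> b, a -> c, b -> d, c -> d
print(topo_sort({"a": ["b", "c"], "b": ["d"], "c": ["d"], "d": []}))  # → ['a', 'b', 'c', 'd']
```

If the graph contains a cycle, the returned order is shorter than the node count, which is a cheap way to detect cyclic dependencies.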
diff --git a/debug/accuracy_tools/graph_analyzer/graph_analyzer/graph_parser.py b/debug/accuracy_tools/graph_analyzer/graph_analyzer/graph_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..376473ca0ea444df978bfa81c5a96ccb7c3ff55a
--- /dev/null
+++ b/debug/accuracy_tools/graph_analyzer/graph_analyzer/graph_parser.py
@@ -0,0 +1,214 @@
+import re
+import logging
+from typing import Tuple, List, Dict
+from graph_analyzer.graph import GraphNode
+
+
+class Parser:
+ def __init__(self):
+ self.nodes = {}
+ self.local_dict = {}
+ self.number_dict = {}
+
+ @staticmethod
+ def parse_subgraph_attributes(text: str, subgraph_node: GraphNode, start_pos: int, end_pos: int) -> None:
+ subgraph_attr_pattern = re.compile(r'subgraph attr:\s*(.*)', re.DOTALL)
+ match = subgraph_attr_pattern.search(text, start_pos, end_pos)
+ if match:
+ attrs = match.group(1).strip().split('\n')
+ if isinstance(subgraph_node.attrs, list):
+ subgraph_node.attrs.extend(attrs)
+
+ @staticmethod
+ def parse_graph_attributes(text: str, graph_node: GraphNode) -> None:
+ attr_pattern = re.compile(r'# Attrs:\s*(.*)', re.DOTALL)
+ match = attr_pattern.search(text, graph_node.pos)
+ if match:
+ attrs = match.group(1).strip().split('\n')
+ for attr in attrs:
+ if not attr: # if end line
+ break
+ key, value = attr.split(':')
+ if isinstance(graph_node.attrs, dict):
+ graph_node.attrs[key.strip()] = value.strip()
+
+ @staticmethod
+ def parse_code_info(text: str, start_pos: int, end_pos: int) -> List[str]:
+ code_info = []
+ code_info_pattern = re.compile(r'# .*', re.MULTILINE)
+ final_pos = end_pos if end_pos else len(text) - 1
+ lines = text[start_pos + 1:final_pos].split('\n')
+ for line in lines:
+ match = code_info_pattern.search(line)
+ if not match:
+ break
+ code_info.append(match.group(0).strip('# ').strip('/'))
+ return code_info
+
+ @staticmethod
+ def extract_bracket_content(text: str, start_pos: int) -> Tuple[str, int]:
+ stack = []
+ content = []
+ for i in range(start_pos, len(text)):
+ char = text[i]
+ if char == '(':
+ stack.append('(')
+ elif char == ')':
+ stack.pop()
+ if not stack:
+ content.append(char)
+ return ''.join(content), i
+ content.append(char)
+ raise ValueError("Mismatched parentheses")
+
+ @staticmethod
+ def find_matching_brace(text: str, start_pos: int) -> int:
+ stack = []
+ for i in range(start_pos, len(text)):
+ if text[i] == '{':
+ stack.append('{')
+ elif text[i] == '}' and stack: # ignore stray '}' before the first '{'
+ stack.pop()
+ if not stack:
+ return i
+ raise ValueError("Matching closing brace not found")
+
+ @staticmethod
+ def extract_constants(inputs_str: str) -> List[str]:
+ constant_pattern = re.compile(r'\b(\w+\(.*?\))')
+ constants = constant_pattern.findall(inputs_str)
+ return constants
+
+ def parse_func_graph(self, text: str) -> None:
+ func_graph_pattern = re.compile(r'# IR entry: @(\S+)')
+ matches = func_graph_pattern.finditer(text)
+ for match in matches:
+ func_name = match.group(1)
+ func_graph_info = GraphNode(name=func_name, pos=match.start(), is_subgraph=False)
+ self.nodes[func_name] = func_graph_info
+
+ def parse_nodes(self, text: str, subgraph_info: GraphNode) -> None:
+ node_pattern = re.compile(r'(%\d+)\((\S+)\)\s*=\s*(\S+)\(')
+ matches = list(node_pattern.finditer(text))
+ for i, match in enumerate(matches):
+ series_number = match.group(1)
+ variable_name = match.group(2)
+ operator_name = match.group(3)
+ unique_name = "&".join([series_number, variable_name])
+ self.local_dict[series_number] = unique_name
+
+ args_str, end_pos = self.__class__.extract_bracket_content(text, match.end() - 1)
+ inputs = re.findall(r'%\w+', args_str)
+ subgraph_inputs = re.findall(r'@\w+', args_str)
+ inputs += subgraph_inputs
+
+ constants = self.__class__.extract_constants(args_str)
+
+ scope_pattern = re.compile(r'# .*scope.*:\s*\((.*?)\)', re.IGNORECASE | re.MULTILINE)
+ scope_match = scope_pattern.search(text, end_pos)
+ scope = scope_match.group(1) if scope_match else ""
+
+ id_pattern = re.compile(r'.*cnode_primal_attrs:\s*\{.*\b(?:forward_unique_id|unique_id):\s*\"(\d+)\".*', re.IGNORECASE)
+ search_end = scope_match.start() if scope_match else len(text)
+ unique_id_match = id_pattern.search(text, end_pos, search_end)
+ unique_id = unique_id_match.group(1) if unique_id_match else None
+
+ if scope:
+ next_match = matches[i + 1].start() - 1 if i < len(matches) - 1 else None
+ code_info = self.__class__.parse_code_info(text, scope_match.end(), next_match)
+ else:
+ code_info = None
+
+ node_info = GraphNode(name=variable_name, unique_name=unique_name, operator_name=operator_name, var_inputs=inputs + constants, unique_id=unique_id, scope=scope, code_info=code_info)
+
+ if unique_id and scope and not scope.startswith("Gradients"):
+ self.number_dict[unique_id] = node_info
+
+ if subgraph_info:
+ subgraph_info.nodes[variable_name] = node_info # keyed by variable_name; whether unique_name is needed here is an open question
+
+ if unique_name not in self.nodes:
+ self.nodes[unique_name] = node_info
+
+ for const in constants:
+ if const not in self.nodes:
+ const_node = GraphNode(name=const, operator_name="Constant", var_inputs=[], has_constant_input=True)
+ self.nodes[const] = const_node
+ if subgraph_info:
+ subgraph_info.nodes[const] = const_node
+ self.local_dict[const] = const
+
+ for input_var in node_info.var_inputs:
+ if input_var in self.local_dict or input_var in self.nodes:
+ input_name = self.local_dict.get(input_var, input_var) # fall back to the original name if it has no alias
+ input_node = self.nodes.get(input_name, None)
+ if input_node:
+ node_info.predecessors.append(input_node)
+ input_node.successors.append(node_info)
+ else:
+ param_node = GraphNode(name=input_var, operator_name="Param", var_inputs=[], has_constant_input=False)
+ if not self.nodes.get(input_var, None):
+ self.nodes[input_var] = param_node
+ node_info.predecessors.append(param_node)
+ param_node.successors.append(node_info)
+
+ def extract_callees(self, text: str) -> None:
+ for node_info in self.nodes.values():
+ func_start_pos = node_info.pos
+ func_end_pos = text.find('}', func_start_pos)
+ func_text = text[func_start_pos:func_end_pos]
+ callee_pattern = re.compile(r'Partial\(@(\S+)\(')
+ callee_matches = callee_pattern.finditer(func_text)
+ for callee_match in callee_matches:
+ callee_name = callee_match.group(1)
+ if callee_name not in node_info.var_inputs:
+ node_info.var_inputs.append(callee_name)
+
+ def parse_subgraphs(self, text: str) -> None:
+ subgraph_pattern = re.compile(r'subgraph\s+@(\S+)(\([^\)]*\))?\s+.*\{')
+ matches = list(subgraph_pattern.finditer(text))
+ end_pos = 0
+ for match in matches:
+ last_pos = end_pos + 2
+ subgraph_name = match.group(1).split('(')[0]
+ start_pos = match.start()
+ end_pos = self.__class__.find_matching_brace(text, start_pos)
+ subgraph_text = text[start_pos:end_pos + 1]
+ attr_text = text[last_pos:start_pos]
+ subgraph_info = GraphNode(name=subgraph_name, pos=start_pos, is_subgraph=True)
+ self.nodes[subgraph_name] = subgraph_info
+ self.__class__.parse_subgraph_attributes(text, subgraph_info, last_pos, start_pos)
+ self.parse_nodes(subgraph_text, subgraph_info)
+ subgraph_info.end = end_pos
+ logging.info('Parsed subgraph: %s', subgraph_name)
+
+ def count_nodes(self) -> Tuple[int, int]:
+ total_nodes = len(self.nodes)
+ total_cnodes = sum(1 for node in self.nodes.values() if node.name.startswith('CNode'))
+ return total_nodes, total_cnodes
+
+ def create_backward_map(self):
+ for node in self.nodes.values():
+ if node.scope and node.scope.startswith("Gradients"):
+ related_forward_node = self.number_dict.get(node.unique_id, None)
+ if related_forward_node:
+ node.code_info = related_forward_node.code_info
+
+ def parse(self, text: str) -> None:
+ self.parse_func_graph(text)
+ self.parse_subgraphs(text)
+ self.parse_nodes(text, None)
+ self.extract_callees(text)
+ self.create_backward_map()
+
+ def get_nodes(self) -> Dict[str, GraphNode]:
+ return self.nodes
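The node regex used by `parse_nodes` can be exercised in isolation on a single synthetic IR line. The line below is a hypothetical example for illustration, not taken from a real MindSpore dump:

```python
import re

# Same pattern as Parser.parse_nodes: %<serial>(<variable>) = <operator>(
node_pattern = re.compile(r'(%\d+)\((\S+)\)\s*=\s*(\S+)\(')

line = '%42(out.5) = MatMul(%40, %41)'
match = node_pattern.search(line)
series_number, variable_name, operator_name = match.groups()
print(series_number, variable_name, operator_name)  # → %42 out.5 MatMul
```

The trailing `\(` in the pattern anchors the match to the operator's argument list, which is then consumed separately by `extract_bracket_content` with proper parenthesis balancing.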
diff --git a/debug/accuracy_tools/graph_analyzer/graph_analyzer/main.py b/debug/accuracy_tools/graph_analyzer/graph_analyzer/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..18df68f01670018fb989e3b35eaaa5658a14461a
--- /dev/null
+++ b/debug/accuracy_tools/graph_analyzer/graph_analyzer/main.py
@@ -0,0 +1,23 @@
+import argparse
+from graph_analyzer.processor import process
+
+
+def main():
+ parser = argparse.ArgumentParser(description="IR Parser")
+ parser.add_argument('--ir', type=str, required=True, help="Path to the graph file")
+ parser.add_argument('--data', type=str, required=False, default=None, help="Path to data dir")
+ parser.add_argument('--node-list', type=str, nargs='*', required=False, default=None, help="Error node list")
+ parser.add_argument('--output', type=str, required=False, default="./", help="Path to output dir")
+ parser.add_argument('--append', action='store_true', help="Whether to append to the CSV file if it exists")
+ args = parser.parse_args()
+
+ process(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/debug/accuracy_tools/graph_analyzer/graph_analyzer/processor.py b/debug/accuracy_tools/graph_analyzer/graph_analyzer/processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..370008da591df2566cfc8d37b8bda9dfe5b57bff
--- /dev/null
+++ b/debug/accuracy_tools/graph_analyzer/graph_analyzer/processor.py
@@ -0,0 +1,45 @@
+import os
+import stat
+import json
+from graph_analyzer.graph import Graph, find_boundary_nodes
+from graph_analyzer.graph_parser import Parser
+from graph_analyzer.bind import bind_code_info_for_data, write_to_csv
+
+def serialize_domain_structure(domain_structure):
+ serialized_structure = {}
+ for domain, data in domain_structure.items():
+ serialized_structure[domain] = {
+ 'boundary': {'upper': list(data['boundary']['upper']), 'lower': list(data['boundary']['lower'])},
+ 'nodes': [node.name for node in data['nodes']]
+ }
+ # Recurse into sub-domains, skipping the boundary and nodes entries
+ for key in data:
+ if key not in ['boundary', 'nodes', 'upper', 'lower']:
+ serialized_structure[domain][key] = serialize_domain_structure({key: data[key]})
+ return serialized_structure
+
+def process(args):
+ ir_file_path = args.ir
+ with open(ir_file_path, 'r', encoding='utf-8') as f:
+ input_text = f.read()
+
+ parser = Parser()
+ parser.parse(input_text)
+
+ nodes = parser.get_nodes()
+ graph = Graph(nodes)
+
+ if args.data:
+ bind_result = bind_code_info_for_data(args.data, nodes)
+ if bind_result:
+ # Forward the append flag to write_to_csv
+ write_to_csv(bind_result, args.output, args.append)
+
+ domain_structure = find_boundary_nodes(nodes.values(), 1)
+ output_structure = serialize_domain_structure(domain_structure)
+ output_file = os.path.join(args.output, "struct.json")
+
+ # Use os.open() so the file is created with restricted permissions
+ fd = os.open(output_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600) # 0o600: owner read/write only
+ with open(fd, "w") as f:
+ json.dump(output_structure, f, indent=4)
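The permission-restricted write at the end of `process` follows a common pattern: create the file via `os.open` so the 0o600 mode applies at creation time, then hand the descriptor to `open()`. A standalone sketch (the file name is illustrative):

```python
import json
import os
import tempfile

def dump_json_private(path, payload):
    """Write JSON with owner-only read/write permissions (0o600 on POSIX)."""
    fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    # open() takes ownership of fd and closes it when the `with` block exits.
    with open(fd, "w") as f:
        json.dump(payload, f, indent=4)

out = os.path.join(tempfile.mkdtemp(), "struct.json")
dump_json_private(out, {"domain": {"nodes": []}})
```

Passing the mode to `os.open` avoids the race of creating the file with default permissions and tightening them afterwards with `os.chmod` (on Windows the resulting mode differs, since only the read-only bit is honored).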
diff --git a/debug/accuracy_tools/graph_analyzer/setup.py b/debug/accuracy_tools/graph_analyzer/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b4fbb52fb0ef997eb0339fb49c6393fd33f9019
--- /dev/null
+++ b/debug/accuracy_tools/graph_analyzer/setup.py
@@ -0,0 +1,26 @@
+import os
+from setuptools import setup, find_packages
+
+setup(
+ name='graph_analyzer',
+ version='1.0.0',
+ packages=find_packages(),
+ install_requires=[
+ ],
+ entry_points={
+ 'console_scripts': [
+ 'graph_analyzer=graph_analyzer.main:main', # Allows running `graph_analyzer`
+ ],
+ },
+ author='TAJh',
+ description='Graph Analyzer used for graph analysis and dump data analysis',
+ long_description=open('README.md', encoding='utf-8').read() if os.path.exists('README.md') else '',
+ long_description_content_type='text/markdown',
+ url='https://gitee.com/tajh/graph_analyzer.git', # Replace with the correct URL
+ classifiers=[
+ 'Programming Language :: Python :: 3',
+ 'License :: OSI Approved :: MIT License',
+ 'Operating System :: OS Independent',
+ ],
+ python_requires='>=3.6',
+)
diff --git a/debug/accuracy_tools/kj600/README.md b/debug/accuracy_tools/monitor/README.md
similarity index 54%
rename from debug/accuracy_tools/kj600/README.md
rename to debug/accuracy_tools/monitor/README.md
index bd97acf6dc9b35b21d23d58ff1e6cb56a82e92fb..d9a1171f4ea9237931f704e5b4cf5b6a68642dbb 100644
--- a/debug/accuracy_tools/kj600/README.md
+++ b/debug/accuracy_tools/monitor/README.md
@@ -1,4 +1,4 @@
-# TensorProbe (codename:kj600) 模型训练状态监控工具
+# Monitor 模型训练状态监控工具
## 简介
@@ -10,7 +10,7 @@
| 依赖软件 |
|-------------|
-| torch |
+| torch>=2.0 |
| torch_npu |
| torchvision |
| tensorboard |
@@ -18,25 +18,126 @@
| sqlalchemy |
| pymysql |
-### 2. 安装 kj600
+### 2. 安装 monitor
-方式一:从 git 直接安装
+方式一:下载源码安装
```
-pip install git+https://gitee.com/xiangsen2/kj600.git
+git clone -b poc https://gitee.com/ascend/mstt.git
+cd mstt/debug/accuracy_tools/monitor
+pip install .
```
-方式二:下载源码安装
+## 快速上手
+### 梯度监控
+模型训练状态的异常通常会反映在loss和梯度上,通过对模型各个模块梯度的监控,可以帮助快速定位异常的第一现场。
+
+1. 输出目录
+监控结果写入tensorboard的event文件/csv中,设置输出路径(默认为`monitor_output`,通过环境变量配置)
+```bash
+export MONITOR_OUTPUT_DIR=/xxx/output_dir
```
-git clone https://gitee.com/xiangsen2/kj600.git
-cd kj600
-pip install .
+
+2. 在训练脚本中使能工具(Megatron-LM)
+
+```
+from monitor.module_hook import TrainerMon
+hooker = TrainerMon("./monitor_config.json", process_group=None, params_have_main_grad=True)
+
+model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
+ model_provider, model_type)
+# 模型、优化器初始化后使能工具
+
+hooker.monitor_gnorm_with_ad(
+ model, grad_acc_steps=args.global_batch_size//args.data_parallel_size//args.micro_batch_size, optimizer=optimizer, dp_group=mpu.get_data_parallel_group(), tp_group=mpu.get_tensor_model_parallel_group())
+
+
+# 可以在任意位置获取当前的梯度统计量, 不同调用位置不能保证reduce已完成
+reduced, unreduced = hooker.generate_wgrad_metrics()
+```
+
+
+| 字段名字 | 是否必选 | 解释 |
+| ------------------------------------------------------------ | -------- | -------- |
+|"grad_acc_steps"| 必选 |梯度累积的步数,当micro step=grad acc steps时,会触发反向hook获取模型梯度|
+|"optimizer"| 可选 |各种并行域reduce后的梯度在opt.step前获取,数据写入在step后进行。默认patch pytorch的优化器,传入其他优化器(如MegatronOptimizer)可以调整工具行为,如clip_grad发生在megatron的优化器中,pytorch的优化器之前。|
+|"dp_group"| 可选 |训练过程中的dp_group。dp域通信后,group内所有rank的梯度相同,落盘数据冗余。提供dp_group后,工具仅保留每个dp_group的第一个rank的梯度|
+|"tp_group"| 可选 |训练过程中的tp_group。tp域通信后,group内部分参数所有rank的梯度相同,落盘数据冗余。提供tp_group后,工具仅保留每个tp_group中冗余参数在第一个rank的梯度。当前适配Megatron core_v0.6.0, 通过权重属性`tensor_model_parallel`判断是否冗余|
+
+3. 在json文件中配置工具
+```
+{
+ "targets": {
+ "module": {},
+ "module.module.language_model.encoder.layers.0": {"input_grad":"tuple[1]:0", "output_grad":"tuple[2]:0"}
+ },
+ "print_struct": false, # 若不了解模型结构,可以打开print_struct打印模型结构
+ "module_ranks": [0,1,2,3], # 需要监控的rank
+ "wg_distribution": true,
+ "format": "csv", # 如果不需要落盘文件,设置为 "api"
+ "ops": ["norm", "min", "max", "mean"],
+ "eps": 1e-8,
+ "ndigits: 6
+}
+```
+
+4. 结果验证
+训练日志中通常会打屏一个训练步的grad norm。提供了脚本校验落盘数据和打屏信息的一致性。
+```bash
+python monitor/unittest/test_monitor.py -m monitor_output/Aug13_02-27-5 -l logs/train_gpt3_TP2_PP1_CP1_monitor.log -d 2 -t 2
+```
+`-m`指定落盘csv的路径前缀。`-l`指定训练日志。脚本通过关键词`grad norm: `匹配训练日志中的grad norm,根据实际情况修改。从落盘数据计算的grad norm和日志中的grad norm相对偏差超过1%,会有警告。`-d`、`--dp_size`声明data parallel size,`-t`、`--tp_size`声明tensor parallel size。
+示例输出:
+```txt
+rank 2 is duplicated in dp group
+rank 3 is duplicated in dp group
+grad norm is consistent between training log and reduced gradients monitored
+grad mean is consistent between unreduced grad and reduced grad monitored.
+```
+需要提供并行相关参数,具体参见:
+```bash
+python monitor/unittest/test_monitor.py -h
+```
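校验脚本的核心逻辑可以概括为:对每个 step,比较日志中的 grad norm 与由落盘数据重新计算出的 grad norm,相对偏差超过 1% 即告警。下面是一个简化示意(阈值与函数名为示意用途,并非脚本的实际实现):

```python
def check_grad_norm(logged, recomputed, rtol=0.01):
    """按 step 对齐两组 grad norm,相对偏差超过 rtol 时返回告警的 step 列表。"""
    warnings = []
    for step, (a, b) in enumerate(zip(logged, recomputed)):
        # 以日志值为基准计算相对偏差,分母加保护避免除零
        if abs(a - b) > rtol * max(abs(a), 1e-8):
            warnings.append(step)
    return warnings

print(check_grad_norm([10.0, 8.0, 6.0], [10.02, 8.5, 6.01]))  # → [1]
```

示例中 step 1 的偏差为 0.5 / 8.0 ≈ 6.25%,超过 1% 阈值,因此被标记。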
+### 梯度异常时序判断
+0. 训练前配置相关参数
+工具支持自动判断训练过程中的梯度异常,需要在配置文件中设置alert相关字段。`AnomalyTurbulence`会将当前数值与历史均值比较,如果相对偏差超过阈值,会在打屏信息中提示用户。如果打开`dump`选项,则会将异常梯度相关信息落盘,用于后续时序判断。
+```json
+ "alert": {
+ "rules": [{"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.5}}],
+ "dump": true
+ },
+```
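`AnomalyTurbulence` 的核心是把当前数值与历史均值做相对偏差比较。下面是该规则的一个简化示意(假设性实现,仅用于说明原理,与工具内部代码不一定一致):

```python
class AnomalyTurbulenceSketch:
    """相对偏差检测示意:当前值偏离历史均值超过 threshold 时报警。"""

    def __init__(self, threshold=0.5):
        self.threshold = threshold
        self.mean = None
        self.count = 0

    def check(self, value):
        triggered = False
        if self.mean is not None:
            # 相对偏差超过阈值(0.5 表示上浮或下浮 50%)即视为异常
            triggered = abs(value - self.mean) > self.threshold * abs(self.mean)
        # 增量式更新历史均值
        self.count += 1
        self.mean = value if self.mean is None else self.mean + (value - self.mean) / self.count
        return triggered

rule = AnomalyTurbulenceSketch(threshold=0.5)
print([rule.check(v) for v in [1.0, 1.1, 0.9, 5.0]])  # → [False, False, False, True]
```

前三个数值在历史均值 ±50% 范围内,最后的 5.0 明显偏离,触发报警;打开 `dump` 后此类异常点会被落盘供时序分析。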
+1. 实例化工具时传入流水线并行group
+```python
+hooker = TrainerMon("./monitor_config.json", process_group=mpu.get_pipeline_model_parallel_group(), params_have_main_grad=True)
+```
+照常开始训练
+
+2. 进入工具路径启动异常分析脚本:
+```shell
+cd monitor/
+python3 anomaly_analyse.py -d $MONITOR_OUTPUT_DIR/anomaly_detected
+```
+支持以下参数配置
+| 字段名字 | 解释 | 是否必选 |
+| ------ | -------- | -------- |
+|-d 或 --data_path| 指定梯度异常落盘文件夹,梯度监控功能输出,一般为$MONITOR_OUTPUT_DIR/anomaly_detected。|是 |
+|-o 或 --out_path| 排序后的异常落盘文件地址,默认在--data_path路径下落盘一个anomaly_analyse.json文件| 否 |
+|-k 或 --topk| 指定保留前topk个异常,默认为8| 否 |
+|-s 或 --step| 指定分析的step范围,默认为"[]"| 否 |
+
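`--topk` 的筛选逻辑可以用 `heapq.nsmallest` 概括:可选地按 step 过滤后,取排序最靠前的 K 个异常。简化示意(元组字段为假设,实际工具使用 GradAnomalyData 对象的排序规则):

```python
import heapq

def earliest_topk(anomalies, topk, step_list=None):
    """anomalies 为 (step, name) 元组;可选 step_list 过滤后取最早的 topk 个。"""
    if step_list:
        anomalies = [a for a in anomalies if a[0] in step_list]
    return heapq.nsmallest(topk, anomalies)

data = [(12, "layer1.weight"), (3, "embedding"), (7, "layer0.bias"), (3, "head")]
print(earliest_topk(data, 2))  # → [(3, 'embedding'), (3, 'head')]
```

`heapq.nsmallest` 在 topk 远小于异常总数时比整体排序更省内存,这与脚本中先过滤再取前 K 的流程一致。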
+## 已知问题
+- Megatron中使用流水线并行时,完成当前stage的计算并将output传递到下一个stage后,会调用`deallocate_output_tensor`释放output。当工具使能后,部分功能会给一些module注错反向hook,hook功能可能为output创建一个view副本,导致output内存无法释放。如果工具使能后出现如下报错,则需要跳过deallocate的步骤。在较新的megatron代码中,可以在`megatron/training/arguments.py`中将`kw_args['deallocate_pipeline_outputs']`设为False,或在`megatron/core/pipeline_parallel/schedules.py`中跳过`deallocate_output_tensor`的调用
+```bash
+File "~/Megatron-LM/megatron/core/pipeline_parallel/schedules.py", line 117, in deallocate_output_tensor
+ assert out._base is None, "counter-productive to free a view of another tensor."
+AssertionError: counter-productive to free a view of another tensor.
```
-# 快速上手
+## 详细配置
- 下面以Ascend/ModelLink训练框架为例,给出kj600工具的使用方法。
+ 下面以Ascend/ModelLink训练框架为例,给出monitor工具的使用方法。
1. 在ModelLink的根目录,创建json配置文件,如llama2_config.json,内容如下:
@@ -54,8 +155,10 @@ pip install .
"cc_distribution": {"enable":true, "cc_codeline":[]},
"alert": {
"rules": [{"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.5}}],
- "inform": {"recipient": "database", "connection_str": "mysql+pymysql://username:password@host:port/database"}
+ "inform": {"recipient": "database", "connection_str": "mysql+pymysql://username:password@host:port/database"},
+ "dump": true
},
+ "format": "tensorboard"
"ops": ["min", "max", "norm", "zeros", "id"],
"eps": 1e-8
}
@@ -78,8 +181,9 @@ pip install .
|"xy_distribution"| 可选 | 若为true则会监控指定module(targets中指定)的输入输出张量。 默认为false。|
|"mv_distribution"| 可选 | 若为true则会监控指定模块中的参数的优化器状态, 默认为false。需要在TrainerMon构造函数正确指定opt_ty. 目前只支持megatron的混合精度优化器以及megatron的分布式优化器。 Deepspeed的分布式优化器实现暂不支持。 |
|"wg_distribution"| 可选 | 若为true则会监控指定模块的参数梯度, 默认为false。 |
-|"alert"| 必选 | · "rules": 指定自动报警的异常检测机制及其相应的阈值。目前实现的异常检测是AnomalyTurbulence。 如果统计标量超出历史均值的指定浮动范围(threshold指定, 0.5意味着上浮或者下浮50%)则在控制台打印报警信息。
· "inform": 自动报警需要的配置,若想关闭自动报警删掉inform的配置即可。其中"recipient"指定自动报警的通知方式,可选值为"database"或"email",默认为"database"。
- 若"recipient"为"database",则需要指定"connection_str"字段,即数据库的连接URL,默认为{"recipient":"database", "connection_str": "mysql+pymysql://username:password@host:port/database"},若有特殊字符需要转义。
- 若"recipient"为"email",则需要指定"send_email_address"-发送方邮箱地址,"receive_email_address"-接收方邮箱地址,"send_email_username"-发送方邮箱用户名,"send_email_password"-发送方邮箱密码,"smtp_server"-发送方邮箱对应的SMTP服务器,"smtp_port"-发送方邮箱对应的SMTP端口号。默认为:
{"recipient":"email", send_email_address": "sender@huawei.com", "receive_email_address": "receiver@huawei.com", "send_email_username": "username", "send_email_password": "******", "smtp_server": "smtpscn.huawei.com", "smtp_port": "587"}|
-|"cc_distribution"| 可选 | 其中“enable”字段控制开关;需要监控通信算子时,务必尽量早地实例化`TrainerMon`, 因为监控通过劫持原始func后挂hook实现,部分加速库初始化时会保存原始function,避免监控失效。“cc_codeline”字段指定监控的代码行,如:`train.py\\[23\\]`,默认为空列表,不特别指定;"cc_pre_hook"字段控制是否监控通信前的数据; "cc_log_only"为true时,仅记录调用到的算子及其调用栈, 不监控通信的输入输出|
+|"alert"| 必选 | · "rules": 指定自动报警的异常检测机制及其相应的阈值。目前实现的异常检测是AnomalyTurbulence。 如果统计标量超出历史均值的指定浮动范围(threshold指定, 0.5意味着上浮或者下浮50%)则在控制台打印报警信息。
· "inform": 自动报警需要的配置,若想关闭自动报警删掉inform的配置即可。其中"recipient"指定自动报警的通知方式,可选值为"database"或"email",默认为"database"。
- 若"recipient"为"database",则需要指定"connection_str"字段,即数据库的连接URL,默认为{"recipient":"database", "connection_str": "mysql+pymysql://username:password@host:port/database"},若有特殊字符需要转义。
- 若"recipient"为"email",则需要指定"send_email_address"-发送方邮箱地址,"receive_email_address"-接收方邮箱地址,"send_email_username"-发送方邮箱用户名,"send_email_password"-发送方邮箱密码,"smtp_server"-发送方邮箱对应的SMTP服务器,"smtp_port"-发送方邮箱对应的SMTP端口号。默认为:
{"recipient":"email", send_email_address": "sender@h*****.com", "receive_email_address": "receiver@h*****.com", "send_email_username": "username", "send_email_password": "******", "smtp_server": "smtpscn.h*****.com", "smtp_port": "587"}|
+|"cc_distribution"| 可选 | 其中"enable"字段控制通信监控模块的开关;需要监控通信算子时,务必尽量早地实例化`TrainerMon`, 因为监控通过劫持原始func后挂hook实现,部分加速库初始化时会保存原始function,避免监控失效。"cc_codeline"字段指定监控的代码行,如:`train.py\\[23\\]`,默认为空列表,不特别指定;"cc_pre_hook"字段控制是否监控通信前的数据; 模块会在第二个optimize.step之前打印通信日志,包括通信api的调用栈、输入dtype、通信group。 "cc_log_only"为true时,仅打印日志,不监控通信的输入输出,并在打印后中断训练。可以根据通信日志设置"cc_codeline",规避与训练过程不相关的通信,比如一些时间、metrics的同步。|
+|"format"| 可选 | 数据落盘格式,默认为tensorboard,支持可选 "csv"。 |
|"ops"| 可选 |与ur_distribution、xy_distribution、mv_distribution、wg_distribution、mg_direction、cc_distribution配合,监控所选张量的min、max、norm、zeros值。其中,zeros代表监控所选张量的元素小于eps的比例,id代表监控所选的非张量本身,默认为[]。|
|"eps"| 可选 |若ops里包含"zeros"则需要配置,默认为1e-8。|
@@ -109,33 +213,36 @@ pip install .
}
```
-2. 在训练器中加入代码,开启kj600训练监控。
+2. 在训练器中加入代码,开启monitor训练监控。
例如在ModelLink/pretrain_gpt.py的model_provider GPTModel构造后加入以下代码, **注意优化器类型opt_ty** :
```
- from kj600.module_hook import TrainerMon
- hooker = TrainerMon("./llama2_config.json", params_have_main_grad=True, opt_ty="Megatron_DistributedOptimizer") # or opt_ty=Megatron_Float16OptimizerWithFloat16Params
+ from monitor.module_hook import TrainerMon
+ hooker = TrainerMon("./llama2_config.json", process_group=None, params_have_main_grad=True, opt_ty="Megatron_DistributedOptimizer") # or opt_ty=Megatron_Float16OptimizerWithFloat16Params
hooker.hook_modules(model=model, grad_acc_steps=args.global_batch_size//args.data_parallel_size//args.micro_batch_size)
```
params_have_main_grad: 若为True则参数权重梯度为main_grad,否则为grad,默认为True。
如果不是Megatron-LM的训练框架, 可以设置对应的梯度累积步数grad_acc_steps。
- 如果要监控混合精度优化器的动量和方差, 需要在混合精度优化器构造后加入如下代码。 目前只支持Megatron_DistributedOptimizer, 使用bf16或者fp16混合精度时开启分布式优化器。 或者Megatron_Float16OptimizerWithFloat16Params, 使用bf16或者fp16混合精度选项并且不开启分布式优化器。
+ 如果要监控优化器的动量和方差,需要在优化器构造后加入如下代码。 目前支持Megatron实现的优化器:
+ - Megatron_FP32OptimizerMon,普通优化器。
+ - Megatron_Float16OptimizerWithFloat16Params, 使用bf16或者fp16混合精度选项并且不开启分布式优化器。
+ - Megatron_DistributedOptimizer, 使用bf16或者fp16混合精度时开启分布式优化器。
```
model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
model_provider, model_type)
# 插入位置
- from kj600.module_hook import TrainerMon
+ from monitor.module_hook import TrainerMon
TrainerMon.set_wrapped_optimizer(optimizer)
```
3. 配置tensorboard写入的目录
```
- export KJ600_OUTPUT_DIR=/xxx/output_dir
+ export MONITOR_OUTPUT_DIR=/xxx/output_dir
```
4. 开始预训练,在日志中如果发现以下内容, 则说明指定的模块被成功监视。
@@ -148,7 +255,7 @@ pip install .
5. 训练过程中,打开tensorboard,可以查看训练的中间状态:
```
-tensorboard --logdir=$KJ600_OUTPUT_DIR
+tensorboard --logdir=$MONITOR_OUTPUT_DIR
```
之后,运行以下SSH命令来建立端口转发,可以在本地通过http://localhost:6006访问tensorboard:
@@ -171,6 +278,7 @@ TrainerMon.__init__(config_file_path, params_have_main_grad=True, opt_ty=None) -
| 参数 | 说明 | 是否必选 |
| ----- | -------------------- | -------- |
| config_file_path |自己写的json配置文件路径。 | 是 |
+| process_group | 传入ProcessGroup对象,用以确定pipeline并行不同rank异常间时序,megatron下通过core.parallel_state.get_pipeline_model_parallel_group()获得 | 否 |
| params_have_main_grad |权重是否使用main_grad,是就为True,否则为False。默认为True。 | 否 |
| opt_ty |优化器类型,有两个选项,Megatron_DistributedOptimizer:使用bf16或者fp16混合精度时开启分布式优化器;Megatron_Float16OptimizerWithFloat16Params:使用bf16或者fp16混合精度选项并且不开启分布式优化器,也适用于常规的adam优化器。如果使用的不是adam优化器,使用None。默认为None。 | 否 |
diff --git a/debug/accuracy_tools/kj600/img/cpu_info.png b/debug/accuracy_tools/monitor/img/cpu_info.png
similarity index 100%
rename from debug/accuracy_tools/kj600/img/cpu_info.png
rename to debug/accuracy_tools/monitor/img/cpu_info.png
diff --git a/debug/accuracy_tools/kj600/img/train.png b/debug/accuracy_tools/monitor/img/train.png
similarity index 100%
rename from debug/accuracy_tools/kj600/img/train.png
rename to debug/accuracy_tools/monitor/img/train.png
diff --git a/debug/accuracy_tools/kj600/img/train_with_kj600.png b/debug/accuracy_tools/monitor/img/train_with_kj600.png
similarity index 100%
rename from debug/accuracy_tools/kj600/img/train_with_kj600.png
rename to debug/accuracy_tools/monitor/img/train_with_kj600.png
diff --git a/debug/accuracy_tools/atat/mindspore/debugger/__init__.py b/debug/accuracy_tools/monitor/monitor/__init__.py
similarity index 100%
rename from debug/accuracy_tools/atat/mindspore/debugger/__init__.py
rename to debug/accuracy_tools/monitor/monitor/__init__.py
diff --git a/debug/accuracy_tools/monitor/monitor/anomaly_analyse.py b/debug/accuracy_tools/monitor/monitor/anomaly_analyse.py
new file mode 100644
index 0000000000000000000000000000000000000000..aad42a84cef30b735f4caecf9e6d9d7fb58394de
--- /dev/null
+++ b/debug/accuracy_tools/monitor/monitor/anomaly_analyse.py
@@ -0,0 +1,248 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+# Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import argparse
+import ast
+import fcntl
+import heapq
+import json
+import os
+from pathlib import Path
+import sys
+
+from monitor.utils import print_info_log, print_warn_log
+from monitor.anomaly_detect import GradAnomalyData
+from monitor.file_check import (
+ change_mode,
+ check_link,
+ FileCheckConst,
+ check_path_before_create,
+ FileChecker,
+ FileOpen,
+)
+
+ANOMALY_JSON = "anomaly.json"
+ANALYSE_JSON = "anomaly_analyse.json"
+
+class AnomalyDataWriter:
+ """
+ Writes anomaly data to JSON files on disk.
+ """
+
+ def __init__(self, dump_path, rank) -> None:
+ self.dump_path = dump_path
+ self.dump_rank_dir = os.path.join(self.dump_path, f"rank{rank}")
+ self.json_path = os.path.join(self.dump_rank_dir, ANOMALY_JSON)
+
+ @staticmethod
+ def get_anomaly_dict(anomalies):
+ """将GradAnomalyData列表转换为json"""
+ anomalies_json = {}
+ for anomaly in anomalies:
+ anomalies_json.update({anomaly.get_key(): anomaly.to_dict()})
+ return anomalies_json
+
+ @staticmethod
+ def update_data_in_single_json(json_path, anomalies_data):
+ with FileOpen(json_path, "w+") as f:
+ fcntl.flock(f, fcntl.LOCK_EX)
+ json.dump(anomalies_data, f, indent=1)
+ fcntl.flock(f, fcntl.LOCK_UN)
+ change_mode(json_path, FileCheckConst.DATA_FILE_AUTHORITY)
+
+ def init_detected_json(self):
+ """初始化落盘文件"""
+ check_path_before_create(self.dump_path)
+ if not os.path.exists(self.dump_path):
+ Path(self.dump_path).mkdir(
+ mode=FileCheckConst.DATA_DIR_AUTHORITY, exist_ok=True
+ )
+ file_check = FileChecker(self.dump_path, FileCheckConst.DIR)
+ file_check.common_check()
+
+ if not os.path.exists(self.dump_rank_dir):
+ Path(self.dump_rank_dir).mkdir(
+ FileCheckConst.DATA_DIR_AUTHORITY, parents=True, exist_ok=True
+ )
+
+ if os.path.exists(self.json_path):
+ file_check = FileChecker(
+ self.json_path, FileCheckConst.FILE, FileCheckConst.WRITE_ABLE
+ )
+ file_check.common_check()
+ print_warn_log(f"The existing file will be deleted: {self.json_path}.")
+ os.remove(self.json_path)
+ Path(self.json_path).touch()
+ change_mode(self.json_path, FileCheckConst.DATA_FILE_AUTHORITY)
+
+ def write_detected_json(self, anomalies):
+ """
+ 落盘异常数据
+ Args:
+ anomalies: GradAnomalyData对象列表
+ """
+ anomalies_json = self.get_anomaly_dict(anomalies)
+ print_info_log(f"{ANOMALY_JSON} is at {self.dump_rank_dir}.")
+ if Path(self.json_path).exists() and os.path.getsize(self.json_path) > 0:
+ with FileOpen(self.json_path, "r+") as f:
+ fcntl.flock(f, fcntl.LOCK_EX)
+ data_to_write = json.load(f)
+ fcntl.flock(f, fcntl.LOCK_UN)
+ else:
+ data_to_write = {}
+ data_to_write.update(anomalies_json)
+ self.update_data_in_single_json(self.json_path, data_to_write)
+
+
+class AnomalyDataLoader:
+ def __init__(self, data_path) -> None:
+ self.data_path = data_path
+
+ @staticmethod
+ def create_instances_from_dict(anomalies_dict: dict):
+ instances = []
+ for values in anomalies_dict.values():
+ try:
+ instances.append(GradAnomalyData(**values))
+ except KeyError as e:
+ print_warn_log(f"Missing key in anomaly data: {e}")
+ except ValueError as e:
+ print_warn_log(
+ f"Value error when creating a GradAnomalyData instance: {e}"
+ )
+ return instances
+
+ def get_anomalies_from_jsons(self):
+ """遍历文件夹,从rankK/anomaly.json中读取异常数据
+ return: anomalies: GradAnomalyData对象列表
+ """
+ anomalies = []
+ check_link(self.data_path)
+ for rank_dir in os.listdir(self.data_path):
+ rank_path = os.path.join(self.data_path, rank_dir)
+ if not os.path.isdir(rank_path):
+ continue
+ json_path = os.path.join(rank_path, ANOMALY_JSON)
+ if not os.path.exists(json_path):
+ continue
+ with FileOpen(json_path, "r+") as f:
+ fcntl.flock(f, fcntl.LOCK_EX)
+ data_anomalies = json.load(f)
+ fcntl.flock(f, fcntl.LOCK_UN)
+ instances = self.create_instances_from_dict(data_anomalies)
+ anomalies.extend(instances)
+ return anomalies
+
+
+class AnomalyAnalyse:
+ def __init__(self) -> None:
+ self.sorted_anomalies = []
+
+ def get_range_top_K(self, topk, step_list, anomalies):
+ """
+ 获取前topk个step_list范围内的异常。
+ """
+ if not step_list:
+ filtered_anomalies = anomalies
+ else:
+ filtered_anomalies = [
+ anomaly for anomaly in anomalies if anomaly.step in step_list
+ ]
+ if topk >= len(filtered_anomalies):
+ self.sorted_anomalies = sorted(filtered_anomalies)
+ else:
+ self.sorted_anomalies = list(heapq.nsmallest(topk, filtered_anomalies))
+ return self.sorted_anomalies
+
+ def rewrite_sorted_anomalies(self, output_path):
+ """
+        Write the sorted anomalies back to disk.
+ """
+ file_check = FileChecker(
+ output_path, FileCheckConst.DIR, FileCheckConst.WRITE_ABLE
+ )
+ file_check.common_check()
+
+ sorted_data = AnomalyDataWriter.get_anomaly_dict(self.sorted_anomalies)
+ print_info_log(f"{ANALYSE_JSON} is at {output_path}.")
+ json_path = os.path.join(output_path, ANALYSE_JSON)
+ if os.path.exists(json_path):
+ file_check = FileChecker(
+ json_path, FileCheckConst.FILE, FileCheckConst.WRITE_ABLE
+ )
+ file_check.common_check()
+ print_warn_log(f"The existing file will be deleted: {json_path}.")
+ os.remove(json_path)
+ Path(json_path).touch()
+ change_mode(json_path, FileCheckConst.DATA_FILE_AUTHORITY)
+ AnomalyDataWriter.update_data_in_single_json(json_path, sorted_data)
+
+
+def _get_parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-d", "--data_path", dest="data_path_dir", default="./", type=str,
+                        help="The anomaly detection result directory generated by the monitor tool.",
+ required=True,
+ )
+ parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
+                        help="The output path for the analyse task result.",
+ required=False,
+ )
+ parser.add_argument("-k", "--topk", dest="top_k_number", default=8, type=int,
+                        help="Top K number of earliest anomalies to report.",
+ required=False,
+ )
+    parser.add_argument("-s", "--step", dest="step_list", default="[]", type=str,
+                        help="Which steps to analyse, as a Python-style list, e.g. \"[1,2,3]\". "
+                             "The default must be a string because it is parsed with ast.literal_eval.",
+ required=False,
+ )
+ return parser.parse_args(sys.argv[1:])
+
+def _get_step_and_stop(args):
+ try:
+ step_list = ast.literal_eval(args.step_list)
+ if not isinstance(step_list, list):
+ raise ValueError(f"{args.step_list} is not a list")
+ except (ValueError, SyntaxError, RecursionError) as e:
+        raise Exception(
+            f"The step list {args.step_list} must be a resolvable Python list"
+        ) from e
+ if args.top_k_number <= 0:
+ raise Exception("The top k number must be greater than 0.")
+ return step_list, args.top_k_number
+
+def _anomaly_analyse():
+ args = _get_parse_args()
+ step_list, top_k_number = _get_step_and_stop(args)
+ loader = AnomalyDataLoader(args.data_path_dir)
+ anomalies = loader.get_anomalies_from_jsons()
+ analyser = AnomalyAnalyse()
+ top_anomalies = analyser.get_range_top_K(
+ top_k_number, step_list, anomalies
+ )
+ analyser.rewrite_sorted_anomalies(
+ args.out_path if args.out_path else args.data_path_dir
+ )
+
+ print_info_log(f"Top {top_k_number} anomalies are listed as follows:")
+ for index, anomaly in enumerate(top_anomalies):
+ print_info_log(f"{index}: {anomaly.message}")
+
+
+if __name__ == "__main__":
+ _anomaly_analyse()
+ print_info_log("Analyse task completed.")
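The top-k selection above relies on the dataclass ordering (`__lt__`) combined with `heapq.nsmallest`. A minimal standalone sketch of the same pattern, using a hypothetical `Anomaly` class rather than the tool's `GradAnomalyData`:

```python
import heapq
from dataclasses import dataclass, field

@dataclass(eq=True)
class Anomaly:
    step: int = 0
    call_id: int = 0
    message: str = field(default="", compare=False)

    def __lt__(self, other):
        # Earlier steps first; within a step, earlier calls first.
        if not isinstance(other, Anomaly):
            return NotImplemented
        if self.step != other.step:
            return self.step < other.step
        return self.call_id < other.call_id

anomalies = [Anomaly(3, 0, "c"), Anomaly(1, 2, "b"), Anomaly(1, 0, "a")]
# nsmallest uses __lt__, so "smallest" means "earliest" under this ordering.
top2 = heapq.nsmallest(2, anomalies)
```

`heapq.nsmallest` runs in O(n log k), which is why the code falls back to a full `sorted` only when `topk` covers the whole list.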
diff --git a/debug/accuracy_tools/kj600/kj600/anomaly_detect.py b/debug/accuracy_tools/monitor/monitor/anomaly_detect.py
similarity index 32%
rename from debug/accuracy_tools/kj600/kj600/anomaly_detect.py
rename to debug/accuracy_tools/monitor/monitor/anomaly_detect.py
index cbd7b6daa2f0d9b0a9b28016993e836ee07df72d..4a5d9baf9c578324c2e93cedb44ed66551fa813f 100644
--- a/debug/accuracy_tools/kj600/kj600/anomaly_detect.py
+++ b/debug/accuracy_tools/monitor/monitor/anomaly_detect.py
@@ -1,10 +1,16 @@
+import os
+import sys
import statistics as st
from abc import ABC
from typing import List
-import sys
-from torch.utils.tensorboard import SummaryWriter
from collections import defaultdict
-from kj600.utils import print_info_log
+from dataclasses import dataclass, field
+import pandas as pd
+from torch.utils.tensorboard import SummaryWriter
+from monitor.utils import print_info_log, check_file_valid_writable, make_file_safety, create_directory
+from monitor.const import Const
+from monitor.file_check import change_mode, FileCheckConst
+
class ScanRule(ABC):
def apply(self, history, cur):
@@ -59,15 +65,101 @@ class bcolors:
BOLD = '\033[1m'
UNDERLINE = '\033[4m'
-class SummaryWriterWithAD(SummaryWriter):
- def __init__(self, path, ad_rules, job_id, anomaly_inform=False):
- super().__init__(path)
+class AnomalyDataFactory(ABC):
+ def __init__(self, rank, pp_stage, group_mates):
+ super().__init__()
+ self.rank = rank
+ self.pp_stage = pp_stage
+ self.group_mates = group_mates
+ self.micro_step = 0
+ self.vpp_stage = 0
+ self.name2callid = {}
+
+ def set_call_id(self, name2callid):
+        """Update call_id, vpp_stage and related info from the current GradContext.
+ """
+ self.name2callid = name2callid
+
+ def create(self, tag_name, message, step):
+        """Create a GradAnomalyData instance when a check reports an anomaly.
+ """
+        param_name = tag_name.split('/')[0]
+        call_id = self.name2callid.get(param_name, -1)
+        if Const.vpp in param_name:
+            # str.lstrip strips a character set, not a prefix; slice off "vpp" instead.
+            vpp_stage = int(param_name[len(Const.vpp):].split(Const.vpp_sep)[0])
+        else:
+            vpp_stage = 0
+
+ return GradAnomalyData(
+ self.rank,
+ step,
+ self.micro_step,
+ self.pp_stage,
+            vpp_stage,  # the stage parsed from the tag name, not the stale attribute
+ call_id,
+ tag_name,
+ message,
+ self.group_mates
+ )
+
+@dataclass(eq=True)
+class GradAnomalyData:
+ rank: int = 0
+ step: int = 0
+ micro_step: int = 0
+ pp_stage: int = 0
+ vpp_stage: int = 0
+ call_id: int = 0
+ tag_name: str = field(default=None, compare=False)
+ message: str = field(default="", compare=False)
+ group_mates: list = field(default=None, compare=False)
+
+ def __lt__(self, other):
+ if not isinstance(other, GradAnomalyData):
+ return NotImplemented
+ if self.step != other.step:
+ return self.step < other.step
+ if self.micro_step != other.micro_step:
+ return self.micro_step < other.micro_step
+ if self.pp_stage != other.pp_stage:
+ return self.pp_stage > other.pp_stage
+ if self.vpp_stage != other.vpp_stage:
+ return self.vpp_stage > other.vpp_stage
+ if self.call_id != other.call_id:
+ return self.call_id < other.call_id
+ return False
+
+ def __le__(self, other):
+ if not isinstance(other, GradAnomalyData):
+ return NotImplemented
+ return self == other or self < other
+
+ def to_dict(self):
+ return self.__dict__
+
+ def get_key(self):
+        return ''.join(
+            (str(self.tag_name), "_step_", str(self.step), "_call_", str(self.call_id)))
+
+class BaseWriterWithAD:
+ def __init__(self, path, ad_rules, job_id, anomaly_inform=False, anomaly_factory=None, ndigits=6):
self.tag2scalars = defaultdict(list)
self.ad_rules = ad_rules
self.job_id = job_id
self.anomaly_inform = anomaly_inform
-
- def add_scalar(self, tag, scalar_value, global_step=None, walltime=None, new_style=False, double_precision=False):
+ self.anomaly_factory = anomaly_factory
+ self.anomalies = []
+ self.ndigits = ndigits
+
+ def get_anomalies(self):
+        """Return the list of anomalies detected so far.
+ """
+ return self.anomalies
+
+ def clear_anomalies(self):
+ self.anomalies.clear()
+
+ def add_scalar(self, tag, scalar_value, global_step=None):
new_avg = avg = scalar_value
if tag in self.tag2scalars:
N = len(self.tag2scalars[tag])
@@ -76,11 +168,64 @@ class SummaryWriterWithAD(SummaryWriter):
self.tag2scalars[tag].append((scalar_value, new_avg))
detected, rule_name = self._ad(scalar_value, history=avg)
if detected:
- print_info_log(f"{bcolors.WARNING}> Rule {rule_name} reports anomaly signal in {tag} at step {global_step}.{bcolors.ENDC}")
- exception_message = f"{bcolors.WARNING}> Rule {rule_name} reports anomaly signal in {tag} at step {global_step}.{bcolors.ENDC}"
+ exception_message = f"Rule {rule_name} reports anomaly signal in {tag} at step {global_step}."
+ print_info_log(f"{bcolors.WARNING}> {exception_message}{bcolors.ENDC}")
if self.anomaly_inform:
self.anomaly_inform.run(exception_message, self.job_id)
- return super().add_scalar(tag, scalar_value, global_step, walltime, new_style, double_precision)
-
+
+ if self.anomaly_factory:
+ self.anomalies.append(self.anomaly_factory.create(tag, exception_message, global_step))
+
def _ad(self, scalar_value, history):
return AnomalyScanner.scan(self.ad_rules, history, cur=scalar_value)
+
+
+class CSVWriterWithAD(BaseWriterWithAD):
+ def __init__(self, path, ad_rules, job_id, anomaly_inform=False, anomaly_factory=None, ndigits=6):
+ super().__init__(path, ad_rules, job_id, anomaly_inform, anomaly_factory, ndigits)
+
+ self.log_dir = path
+ create_directory(path)
+ self.context_dict = defaultdict(list)
+ self.header = []
+
+ def write_csv(self, prefix, step):
+ if len(self.context_dict) == 0:
+ return
+ filepath = os.path.join(self.log_dir, f'{prefix}_{step}.csv')
+ if not os.path.exists(filepath):
+ make_file_safety(filepath)
+ data_frame = pd.DataFrame(columns=self.header)
+ data_frame.to_csv(filepath, index=False)
+ change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
+
+ check_file_valid_writable(filepath)
+ new_data = []
+ for name, metric_value in self.context_dict.items():
+ if Const.vpp not in name:
+ new_data.append([name]+metric_value)
+ else:
+                # Slice, not lstrip: lstrip would strip a character set rather than the "vpp" prefix.
+                new_data.append(name[len(Const.vpp):].split(Const.vpp_sep) + metric_value)
+ new_data = pd.DataFrame(new_data)
+ new_data.to_csv(filepath, mode='a+', header=False, index=False)
+ self.context_dict = defaultdict(list)
+
+ def add_scalar(self, tag, scalar_value, global_step):
+ super().add_scalar(tag, scalar_value, global_step)
+
+ name = tag.split('/')[0]
+ self.context_dict[name].append(round(scalar_value, self.ndigits))
+
+ def close(self):
+ pass
+
+class SummaryWriterWithAD(SummaryWriter, BaseWriterWithAD):
+ def __init__(self, path, ad_rules, job_id, anomaly_inform=False, anomaly_factory=None, ndigits=6):
+ super(SummaryWriter, self).__init__(path, ad_rules, job_id, anomaly_inform, anomaly_factory, ndigits)
+ super().__init__(path)
+ change_mode(path, FileCheckConst.DATA_DIR_AUTHORITY)
+
+ def add_scalar(self, tag, scalar_value, global_step):
+ super(SummaryWriter, self).add_scalar(tag, scalar_value, global_step)
+ return super().add_scalar(tag, scalar_value, global_step)
+
\ No newline at end of file
diff --git a/debug/accuracy_tools/kj600/kj600/anomaly_inform.py b/debug/accuracy_tools/monitor/monitor/anomaly_inform.py
similarity index 95%
rename from debug/accuracy_tools/kj600/kj600/anomaly_inform.py
rename to debug/accuracy_tools/monitor/monitor/anomaly_inform.py
index 301ac769217943a36e5d4cbe06033c828e5c675e..fe0fdb3f105cf3f400eee91734e37048fbaa3ca0 100644
--- a/debug/accuracy_tools/kj600/kj600/anomaly_inform.py
+++ b/debug/accuracy_tools/monitor/monitor/anomaly_inform.py
@@ -1,18 +1,17 @@
import smtplib
from email.mime.text import MIMEText
-import sqlite3
from datetime import datetime, timedelta
-from kj600.database import Database, ExceptionMessage
+from monitor.database import Database, ExceptionMessage
# define class InformRegistry to get inform_sub_class
class AnomalyInformFactory:
@staticmethod
def create_informer(**kwargs):
- if kwargs['recipient'] == "database":
+ if kwargs.get('recipient') == "database":
return DatabaseInform(**kwargs)
- elif kwargs['recipient'] == "email":
+ elif kwargs.get('recipient') == "email":
return EmailInform(**kwargs)
else:
raise ValueError("Invalid recipient specified")
diff --git a/debug/accuracy_tools/kj600/kj600/config.json b/debug/accuracy_tools/monitor/monitor/config.json
similarity index 100%
rename from debug/accuracy_tools/kj600/kj600/config.json
rename to debug/accuracy_tools/monitor/monitor/config.json
diff --git a/debug/accuracy_tools/monitor/monitor/const.py b/debug/accuracy_tools/monitor/monitor/const.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4198a994228fb4e4a99ec8a3aaa5fe9b3d14dcd
--- /dev/null
+++ b/debug/accuracy_tools/monitor/monitor/const.py
@@ -0,0 +1,4 @@
+
+class Const:
+ vpp = "vpp"
+ vpp_sep = ':'
\ No newline at end of file
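Tag names of the form `vpp<stage><sep><param>` built from these constants are parsed in several places above. Since `str.lstrip` removes a leading *character set* rather than a prefix, slicing by the prefix length is the safer idiom. A small illustration under that assumed tag format (the `parse_vpp_tag` helper is hypothetical, not part of the tool):

```python
VPP = "vpp"
VPP_SEP = ":"

def parse_vpp_tag(name: str):
    # Split "vpp<stage>:<param>" into (stage, param); non-vpp names get stage 0.
    if name.startswith(VPP):
        stage_str, param = name[len(VPP):].split(VPP_SEP, 1)
        return int(stage_str), param
    return 0, name
```

By contrast, `name.lstrip(VPP)` would keep stripping any leading run of `'v'` and `'p'` characters past the prefix, which silently corrupts names whose payload happens to start with those letters.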
diff --git a/debug/accuracy_tools/kj600/kj600/database.py b/debug/accuracy_tools/monitor/monitor/database.py
similarity index 100%
rename from debug/accuracy_tools/kj600/kj600/database.py
rename to debug/accuracy_tools/monitor/monitor/database.py
diff --git a/debug/accuracy_tools/kj600/kj600/distributed/distributed_ops.yaml b/debug/accuracy_tools/monitor/monitor/distributed/distributed_ops.yaml
similarity index 100%
rename from debug/accuracy_tools/kj600/kj600/distributed/distributed_ops.yaml
rename to debug/accuracy_tools/monitor/monitor/distributed/distributed_ops.yaml
diff --git a/debug/accuracy_tools/kj600/kj600/distributed/stack_blacklist.yaml b/debug/accuracy_tools/monitor/monitor/distributed/stack_blacklist.yaml
similarity index 77%
rename from debug/accuracy_tools/kj600/kj600/distributed/stack_blacklist.yaml
rename to debug/accuracy_tools/monitor/monitor/distributed/stack_blacklist.yaml
index 00b0013619fcfa1445a8df18c3c7d16764fb4872..40692d6942c39dfd5bbe52d33df6de4f1238eab2 100644
--- a/debug/accuracy_tools/kj600/kj600/distributed/stack_blacklist.yaml
+++ b/debug/accuracy_tools/monitor/monitor/distributed/stack_blacklist.yaml
@@ -1,5 +1,5 @@
stack:
-- kj600/distributed
+- monitor/distributed
- site-packages/torch/nn/modules/module.py
- multiprocessing
- debugpy
\ No newline at end of file
diff --git a/debug/accuracy_tools/kj600/kj600/distributed/wrap_distributed.py b/debug/accuracy_tools/monitor/monitor/distributed/wrap_distributed.py
similarity index 96%
rename from debug/accuracy_tools/kj600/kj600/distributed/wrap_distributed.py
rename to debug/accuracy_tools/monitor/monitor/distributed/wrap_distributed.py
index fad007fe35c45b1eb09b8ed986bd9b9893528a75..3f5ebc727dfd306d127ac7605a8aff6e83182f54 100644
--- a/debug/accuracy_tools/kj600/kj600/distributed/wrap_distributed.py
+++ b/debug/accuracy_tools/monitor/monitor/distributed/wrap_distributed.py
@@ -1,7 +1,8 @@
import os
-import yaml
import re
import inspect
+
+import yaml
import torch
import torch.nn as nn
import torch.distributed as dist
@@ -77,6 +78,21 @@ class ApiRegistry:
else:
setattr(api_group, cc_api_name, cc_api_entry_func)
+ @staticmethod
+ def redirect_wait():
+ global ORIGIN_WAIT
+ global PENDING_ASYNC_CC_BY_HANDLE
+
+ def wrapped_wait(work):
+ def wrapped_wait(*args, **kwargs):
+ ORIGIN_WAIT(*args, **kwargs)
+ if args[0] in PENDING_ASYNC_CC_BY_HANDLE:
+ store_func = PENDING_ASYNC_CC_BY_HANDLE.pop(args[0])
+ store_func()
+ return wrapped_wait
+
+ dist.Work.wait = wrapped_wait(dist.Work)
+
def redirect_api(self):
self.set_api_attr(dist, self.distributed_attr_hooked)
self.set_api_attr(dist.distributed_c10d, self.distributed_attr_hooked)
@@ -92,19 +108,13 @@ class ApiRegistry:
for op_name in get_distributed_ops():
self.distributed_attr_hooked[op_name] = DistributedOPTemplate(op_name, pre_hooks, post_hooks)
- def redirect_wait(self):
- global ORIGIN_WAIT
- global PENDING_ASYNC_CC_BY_HANDLE
-
- def wrapped_wait(work):
- def wrapped_wait(*args, **kwargs):
- ORIGIN_WAIT(*args, **kwargs)
- if args[0] in PENDING_ASYNC_CC_BY_HANDLE:
- store_func = PENDING_ASYNC_CC_BY_HANDLE.pop(args[0])
- store_func()
- return wrapped_wait
- dist.Work.wait = wrapped_wait(dist.Work)
+def get_process_group(process_group):
+ return (
+ process_group
+ if isinstance(process_group, dist.ProcessGroup)
+ else dist.GroupMember.WORLD
+ )
def stack_filter(stack):
@@ -115,7 +125,7 @@ def stack_filter(stack):
def get_callstack():
callstack = []
- for (_, path, line, func, code, _) in inspect.stack():
+ for (_, path, line, func, _, _) in inspect.stack():
stack_line = f'{path}[{line}]'
if stack_filter(stack_line):
callstack.append(stack_line+' '+func)
diff --git a/debug/accuracy_tools/kj600/kj600/features.py b/debug/accuracy_tools/monitor/monitor/features.py
similarity index 94%
rename from debug/accuracy_tools/kj600/kj600/features.py
rename to debug/accuracy_tools/monitor/monitor/features.py
index 7810188f7d7df66dce4c489f18062f9381b95646..92c6959a7d56ed3264366bf12444029f00b4df25 100644
--- a/debug/accuracy_tools/kj600/kj600/features.py
+++ b/debug/accuracy_tools/monitor/monitor/features.py
@@ -1,6 +1,6 @@
import torch
from torch.autograd.functional import jacobian
-from kj600.utils import print_info_log
+from monitor.utils import print_info_log
@torch.no_grad()
@@ -11,6 +11,10 @@ def square_sum(x: torch.tensor):
def get_min(x: torch.tensor):
return torch.min(x)
+@torch.no_grad()
+def get_mean(x: torch.tensor):
+ return torch.mean(x)
+
@torch.no_grad()
def get_norm(x: torch.tensor):
return torch.norm(x, p=2)
diff --git a/debug/accuracy_tools/atat/core/common/file_check.py b/debug/accuracy_tools/monitor/monitor/file_check.py
similarity index 65%
rename from debug/accuracy_tools/atat/core/common/file_check.py
rename to debug/accuracy_tools/monitor/monitor/file_check.py
index 2df825aa35108fc08b9d886bf68f0ef3e2bc1533..af233ca2c8f01b76d912b753078a5cbf117049c6 100644
--- a/debug/accuracy_tools/atat/core/common/file_check.py
+++ b/debug/accuracy_tools/monitor/monitor/file_check.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
-# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@@ -17,9 +17,58 @@
import os
import re
-from atat.core.common.log import logger
-from atat.core.common.exceptions import FileCheckException
-from atat.core.common.const import FileCheckConst
+from monitor.utils import print_info_log
+
+
+class CodedException(Exception):
+ def __init__(self, code, error_info=""):
+ super().__init__()
+ self.code = code
+ self.error_info = self.err_strs.get(code) + error_info
+
+ def __str__(self):
+ return self.error_info
+
+
+class FileCheckException(CodedException):
+ INVALID_FILE_ERROR = 0
+ FILE_PERMISSION_ERROR = 1
+ SOFT_LINK_ERROR = 2
+ ILLEGAL_PATH_ERROR = 3
+ ILLEGAL_PARAM_ERROR = 4
+ FILE_TOO_LARGE_ERROR = 5
+
+ err_strs = {
+ SOFT_LINK_ERROR: "[monitor] 检测到软链接: ",
+ FILE_PERMISSION_ERROR: "[monitor] 文件权限错误: ",
+ INVALID_FILE_ERROR: "[monitor] 无效文件: ",
+ ILLEGAL_PATH_ERROR: "[monitor] 非法文件路径: ",
+ ILLEGAL_PARAM_ERROR: "[monitor] 非法打开方式: ",
+ FILE_TOO_LARGE_ERROR: "[monitor] 文件过大: ",
+ }
+
+
+class FileCheckConst:
+ """
+ Class for file check const
+ """
+
+ READ_ABLE = "read"
+ WRITE_ABLE = "write"
+ READ_WRITE_ABLE = "read and write"
+ DIRECTORY_LENGTH = 4096
+ FILE_NAME_LENGTH = 255
+ FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$"
+ FILE_PATTERN = r"^[a-zA-Z0-9_./-]+$"
+ JSON_SUFFIX = ".json"
+ MAX_JSON_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
+ DIR = "dir"
+ FILE = "file"
+ DATA_DIR_AUTHORITY = 0o750
+ DATA_FILE_AUTHORITY = 0o640
+ FILE_SIZE_DICT = {
+ JSON_SUFFIX: MAX_JSON_SIZE,
+ }
class FileChecker:
@@ -32,7 +81,10 @@ class FileChecker:
ability(str): FileCheckConst.WRITE_ABLE or FileCheckConst.READ_ABLE to set file has writability or readability
file_type(str): The correct file type for file
"""
- def __init__(self, file_path, path_type, ability=None, file_type=None, is_script=True):
+
+ def __init__(
+ self, file_path, path_type, ability=None, file_type=None, is_script=True
+ ):
self.file_path = file_path
self.path_type = self._check_path_type(path_type)
self.ability = ability
@@ -42,7 +94,9 @@ class FileChecker:
@staticmethod
def _check_path_type(path_type):
if path_type not in [FileCheckConst.DIR, FileCheckConst.FILE]:
- logger.error(f'The path_type must be {FileCheckConst.DIR} or {FileCheckConst.FILE}.')
+ print_info_log(
+ f"The path_type must be {FileCheckConst.DIR} or {FileCheckConst.FILE}."
+ )
raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR)
return path_type
@@ -82,11 +136,12 @@ class FileOpen:
file_path: The file or dictionary path to be opened.
mode(str): The file open mode
"""
+
SUPPORT_READ_MODE = ["r", "rb"]
SUPPORT_WRITE_MODE = ["w", "wb", "a", "ab"]
SUPPORT_READ_WRITE_MODE = ["r+", "rb+", "w+", "wb+", "a+", "ab+"]
- def __init__(self, file_path, mode, encoding='utf-8'):
+ def __init__(self, file_path, mode, encoding="utf-8"):
self.file_path = file_path
self.mode = mode
self.encoding = encoding
@@ -106,9 +161,13 @@ class FileOpen:
self._handle.close()
def check_file_path(self):
- support_mode = self.SUPPORT_READ_MODE + self.SUPPORT_WRITE_MODE + self.SUPPORT_READ_WRITE_MODE
+ support_mode = (
+ self.SUPPORT_READ_MODE
+ + self.SUPPORT_WRITE_MODE
+ + self.SUPPORT_READ_WRITE_MODE
+ )
if self.mode not in support_mode:
- logger.error("File open not support %s mode" % self.mode)
+ print_info_log("File open not support %s mode" % self.mode)
raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR)
check_link(self.file_path)
self.file_path = os.path.realpath(self.file_path)
@@ -135,66 +194,75 @@ class FileOpen:
def check_link(path):
abs_path = os.path.abspath(path)
if os.path.islink(abs_path):
- logger.error('The file path {} is a soft link.'.format(path))
+ print_info_log("The file path {} is a soft link.".format(path))
raise FileCheckException(FileCheckException.SOFT_LINK_ERROR)
def check_path_length(path, name_length=None):
- file_max_name_length = name_length if name_length else FileCheckConst.FILE_NAME_LENGTH
- if len(path) > FileCheckConst.DIRECTORY_LENGTH or \
- len(os.path.basename(path)) > file_max_name_length:
- logger.error('The file path length exceeds limit.')
+ file_max_name_length = (
+ name_length if name_length else FileCheckConst.FILE_NAME_LENGTH
+ )
+ if (
+ len(path) > FileCheckConst.DIRECTORY_LENGTH
+ or len(os.path.basename(path)) > file_max_name_length
+ ):
+ print_info_log("The file path length exceeds limit.")
raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
def check_path_exists(path):
if not os.path.exists(path):
- logger.error('The file path %s does not exist.' % path)
+ print_info_log("The file path %s does not exist." % path)
raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
def check_path_readability(path):
if not os.access(path, os.R_OK):
- logger.error('The file path %s is not readable.' % path)
+ print_info_log("The file path %s is not readable." % path)
raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
def check_path_writability(path):
if not os.access(path, os.W_OK):
- logger.error('The file path %s is not writable.' % path)
+ print_info_log("The file path %s is not writable." % path)
raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
def check_path_executable(path):
if not os.access(path, os.X_OK):
- logger.error('The file path %s is not executable.' % path)
+ print_info_log("The file path %s is not executable." % path)
raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
def check_other_user_writable(path):
st = os.stat(path)
if st.st_mode & 0o002:
- logger.error('The file path %s may be insecure because other users have write permissions. ' % path)
+ print_info_log(
+ "The file path %s may be insecure because other users have write permissions. "
+ % path
+ )
raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
def check_path_owner_consistent(path):
file_owner = os.stat(path).st_uid
if file_owner != os.getuid():
- logger.error('The file path %s may be insecure because is does not belong to you.' % path)
+ print_info_log(
+            "The file path %s may be insecure because it does not belong to you." % path
+ )
raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
def check_path_pattern_vaild(path):
if not re.match(FileCheckConst.FILE_VALID_PATTERN, path):
- logger.error('The file path %s contains special characters.' %(path))
+ print_info_log("The file path %s contains special characters." % (path))
raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
def check_file_size(file_path, max_size):
file_size = os.path.getsize(file_path)
if file_size >= max_size:
- logger.error(f'The size of file path {file_path} exceeds {max_size} bytes.')
+ print_info_log(f"The size of file path {file_path} exceeds {max_size} bytes.")
raise FileCheckException(FileCheckException.FILE_TOO_LARGE_ERROR)
@@ -209,45 +277,32 @@ def check_common_file_size(file_path):
def check_file_suffix(file_path, file_suffix):
if file_suffix:
if not file_path.endswith(file_suffix):
- logger.error(f"The {file_path} should be a {file_suffix} file!")
+ print_info_log(f"The {file_path} should be a {file_suffix} file!")
raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
def check_path_type(file_path, file_type):
if file_type == FileCheckConst.FILE:
if not os.path.isfile(file_path):
- logger.error(f"The {file_path} should be a file!")
+ print_info_log(f"The {file_path} should be a file!")
raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
if file_type == FileCheckConst.DIR:
if not os.path.isdir(file_path):
- logger.error(f"The {file_path} should be a dictionary!")
+            print_info_log(f"The {file_path} should be a directory!")
raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
-
-
-def create_directory(dir_path):
- """
- Function Description:
- creating a directory with specified permissions
- Parameter:
- dir_path: directory path
- Exception Description:
- when invalid data throw exception
- """
- dir_path = os.path.realpath(dir_path)
- try:
- os.makedirs(dir_path, mode=FileCheckConst.DATA_DIR_AUTHORITY, exist_ok=True)
- except OSError as ex:
- raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR,
- 'Failed to create {}. Please check the path permission or disk space .{}'.format(dir_path, str(ex))) from ex
-
+
def check_path_before_create(path):
if path_len_exceeds_limit(path):
- raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, 'The file path length exceeds limit.')
+ raise FileCheckException(
+ FileCheckException.ILLEGAL_PATH_ERROR, "The file path length exceeds limit."
+ )
if not re.match(FileCheckConst.FILE_PATTERN, os.path.realpath(path)):
- raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR,
- 'The file path {} contains special characters.'.format(path))
+ raise FileCheckException(
+ FileCheckException.ILLEGAL_PATH_ERROR,
+ "The file path {} contains special characters.".format(path),
+ )
def change_mode(path, mode):
@@ -256,10 +311,14 @@ def change_mode(path, mode):
try:
os.chmod(path, mode)
except PermissionError as ex:
- raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR,
- 'Failed to change {} authority. {}'.format(path, str(ex))) from ex
+ raise FileCheckException(
+ FileCheckException.FILE_PERMISSION_ERROR,
+ "Failed to change {} authority. {}".format(path, str(ex)),
+ ) from ex
def path_len_exceeds_limit(file_path):
- return len(os.path.realpath(file_path)) > FileCheckConst.DIRECTORY_LENGTH or \
- len(os.path.basename(file_path)) > FileCheckConst.FILE_NAME_LENGTH
\ No newline at end of file
+ return (
+ len(os.path.realpath(file_path)) > FileCheckConst.DIRECTORY_LENGTH
+ or len(os.path.basename(file_path)) > FileCheckConst.FILE_NAME_LENGTH
+ )
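`check_other_user_writable` above rejects paths whose mode has the other-write bit (`0o002`) set. That bit test can be exercised in isolation, assuming POSIX permission semantics:

```python
import os
import stat
import tempfile

def other_user_writable(path: str) -> bool:
    # True when the "other" class has write permission (mode bit 0o002).
    return bool(os.stat(path).st_mode & stat.S_IWOTH)

with tempfile.NamedTemporaryFile(delete=False) as f:
    path = f.name

os.chmod(path, 0o640)   # owner rw, group r, other none -> safe
safe = other_user_writable(path)
os.chmod(path, 0o646)   # other gets write -> would be flagged insecure
unsafe = other_user_writable(path)
os.remove(path)
```

`0o640` is exactly `FileCheckConst.DATA_FILE_AUTHORITY`, the mode the tool applies to its own output files via `change_mode`.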
diff --git a/debug/accuracy_tools/kj600/kj600/module_hook.py b/debug/accuracy_tools/monitor/monitor/module_hook.py
similarity index 49%
rename from debug/accuracy_tools/kj600/kj600/module_hook.py
rename to debug/accuracy_tools/monitor/monitor/module_hook.py
index 8043c5671c42675dc55944e4818bce8e9137b455..dca629cb161d4b2eb4e6e330d8dfe3d76c76775e 100644
--- a/debug/accuracy_tools/kj600/kj600/module_hook.py
+++ b/debug/accuracy_tools/monitor/monitor/module_hook.py
@@ -2,19 +2,46 @@ import os
import uuid
import json
from collections import defaultdict
+from functools import partial
from datetime import datetime
import torch
+
+# Compare the major version numerically; a lexicographic string compare would misorder e.g. "10.0" vs "2.0".
+torch_version_above_or_equal_2 = int(torch.__version__.split('.')[0]) >= 2
+if not torch_version_above_or_equal_2:
+    raise ValueError("msmonitor requires torch>=2.0")
+
import torch.distributed as dist
+from torch import Stream
from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
-from kj600.module_spec_verifier import get_config, validate_config_spec
-from kj600.optimizer_collect import MixPrecsionOptimizerMon, print_rank_0, OptimizerMonFactory, MegatronDistributedOptimizerMon
-from kj600.features import eff_rank, get_sign_matches
-from kj600.visualizer import HeatmapVisualizer
-from kj600.anomaly_detect import AnomalyScanner, SummaryWriterWithAD
-from kj600.anomaly_inform import AnomalyInformFactory
-from kj600.module_metric import get_metrics, write_metrics_tensorboard, get_summary_writer_tag_name, TensorMetrics
-from kj600.distributed.wrap_distributed import api_register, create_hooks, op_aggregate
-from kj600.utils import print_warn_log, print_info_log, get_param_struct
+from monitor.module_spec_verifier import validate_config_spec
+from monitor.optimizer_collect import OptimizerMon, print_rank_0, OptimizerMonFactory
+from monitor.features import eff_rank, get_sign_matches
+from monitor.visualizer import HeatmapVisualizer
+from monitor.anomaly_detect import AnomalyScanner, AnomalyDataFactory, SummaryWriterWithAD, CSVWriterWithAD, \
+ BaseWriterWithAD
+from monitor.anomaly_inform import AnomalyInformFactory
+from monitor.anomaly_analyse import AnomalyDataWriter
+from monitor.module_metric import get_metrics, write_metrics_tensorboard, write_metrics_csv, \
+ get_summary_writer_tag_name, TensorMetrics, squash_param_name
+from monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, get_process_group
+from monitor.utils import print_warn_log, print_info_log, print_error_log, get_param_struct
+from monitor.const import Const
+from monitor.file_check import FileOpen
+
+try:
+ import torch_npu
+except ImportError:
+ pass
+
+
+def param_is_not_tensor_parallel_duplicate(param, tp_group):
+ return (hasattr(param, 'tensor_model_parallel') and param.tensor_model_parallel) or (
+ torch.distributed.get_rank(group=tp_group) == 0
+ )
+
+
+def param_is_data_parallel_duplicate(dp_group):
+ return torch.distributed.get_rank(group=dp_group) != 0
class ModuleHookContext:
@@ -30,20 +57,18 @@ class ModuleHookContext:
self.focused_out_col = 0
self.ignore_in = False # no need to care when no key 'input' or 'input_grad' found
- def set_format_by_arg(self, key_name:str, target_config:dict):
+ def set_format_by_arg(self, key_name: str, target_config: dict):
if key_name in target_config[self.module_name]:
self.format_by_arg[key_name] = target_config[self.module_name][key_name]
elif key_name in ['input', 'input_grad']:
self.ignore_in = True
- else:
- raise KeyError(f"Missing key: {key_name} of {self.module_name} in config.json")
class OptimizerContext:
def __init__(self) -> None:
self.step = 0
self.param_effective_rank = defaultdict(float)
- self.param_mg_direction = defaultdict(float)
+ self.param_mg_direction = defaultdict(float)
self.param_adam_update = defaultdict()
self.param_adam_ratio = defaultdict()
self.param_weight_grad = defaultdict()
@@ -71,30 +96,50 @@ class CommunicationContext:
def aggregate(self):
self.data = self._agg(self.data)
-class TrainerMon:
+class GradContext:
+ def __init__(self) -> None:
+ self.pre = []
+ self.post = []
+ self.acc_metric = []
+ self.acc = {}
+ self.actv = defaultdict(dict)
+
+ def reset(self):
+ self.pre.clear()
+ self.post.clear()
+ self.acc_metric.clear()
+ self.acc.clear()
+ self.actv.clear()
+
+
+class TrainerMon:
tensor_metrics = TensorMetrics()
- # opt_ty: "Megatron_Float16OptimizerWithFloat16Params" or "Megatron_DistributedOptimizer"
- def __init__(self, config_file_path, params_have_main_grad=True, opt_ty=None) -> None:
+ def __init__(self, config_file_path, process_group=None, params_have_main_grad=True, opt_ty=None) -> None:
self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext)
self.optimizer_context = defaultdict(OptimizerContext)
self.cc_context = defaultdict(CommunicationContext)
+ self.grad_context = GradContext()
+ self.process_group = get_process_group(process_group)
self.params_have_main_grad = params_have_main_grad
- self.config = get_config(config_file_path)
+ with FileOpen(config_file_path, 'r') as f:
+ self.config = json.load(f)
self.module_rank_list = self.config.get("module_ranks", [])
+ self.format = self.config.get('format', 'tensorboard')
self.eps = self.config.get('eps', 1e-8)
self.ops = self.config.get('ops', [])
+ self.ndigits = self.config.get('ndigits', 6)
self.xy_distribution = self.config.get('xy_distribution', False)
if not self.xy_distribution:
print_rank_0("> module input/output input_grad/output_grad is not monitored. ")
-
- # backward hook cause megatron-lm pipeline parallel schedule assert exception.
+ # backward hooks cause a megatron-lm pipeline parallel schedule assert exception.
# TBD: backward hooks cause the output tensor to be a view of some base tensor. Root cause investigation pending.
- self.forward_only = self.config.get('forward_only', False)
- if self.forward_only:
+ self.forward_only = self.config.get('forward_only', False)
+ if self.forward_only:
print_rank_0("> only module forward is monitored. ")
+ self.backward_only = self.config.get('backward_only', False)
self.ur_distribution = self.config.get('ur_distribution', False)
if not self.ur_distribution:
@@ -120,28 +165,72 @@ class TrainerMon:
api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
api_register.redirect_api()
- alert_setting = self.config.get('alert', {"rules":[]})
+ alert_setting = self.config.get('alert', {"rules": []})
self.alert_rules = AnomalyScanner.load_rules(alert_setting["rules"])
-
- anomaly_inform = AnomalyInformFactory.create_informer(**alert_setting["inform"]) if "inform" in alert_setting else None
-
- self.optimizer_hooked = False
- output_base_dir = os.getenv('KJ600_OUTPUT_DIR', './kj600_output')
+ anomaly_inform = AnomalyInformFactory.create_informer(
+ **alert_setting["inform"]) if "inform" in alert_setting else None
+
+ output_base_dir = os.getenv('MONITOR_OUTPUT_DIR', './monitor_output')
cur_time = datetime.now().strftime('%b%d_%H-%M-%S')
unique_id = str(uuid.uuid4())[:8]
+
if dist.is_initialized():
- if (dist.get_rank() in self.module_rank_list) or len(self.module_rank_list) == 0:
- self.summary_writer = SummaryWriterWithAD(
- os.path.join(output_base_dir, f"{cur_time}-rank{dist.get_rank()}-{unique_id}"), self.alert_rules, unique_id, anomaly_inform)
+ rank = dist.get_rank()
+ tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-rank{rank}-{unique_id}")
+ pp_stage = dist.get_group_rank(self.process_group, rank)
+ group_mates = dist.get_process_group_ranks(self.process_group)
else:
- self.summary_writer = SummaryWriterWithAD(os.path.join(output_base_dir, f"{cur_time}-{unique_id}"), self.alert_rules, unique_id, anomaly_inform)
+ rank = 0
+ tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-{unique_id}")
+ pp_stage = 0
+ group_mates = [0]
+ self.rank = rank
+
+ # Initialize the AnomalyData factory
+ self.anomaly_data_factory = AnomalyDataFactory(rank, pp_stage, group_mates) if alert_setting.get('dump', False) else None
+
+ if self.format == 'tensorboard':
+ writer = SummaryWriterWithAD
+ self.write_metrics = write_metrics_tensorboard
+ elif self.format == 'csv':
+ writer = CSVWriterWithAD
+ self.write_metrics = write_metrics_csv
+ elif self.format == 'api':
+ writer = BaseWriterWithAD
+ self.write_metrics = write_metrics_tensorboard
+ else:
+ raise ValueError(f"Unsupported format: {self.format}. Expected one of: tensorboard, csv, api.")
+
+ if (rank in self.module_rank_list) or len(self.module_rank_list) == 0:
+
+ self.summary_writer = writer(
+ tensorboard_dir,
+ self.alert_rules,
+ unique_id,
+ anomaly_inform,
+ self.anomaly_data_factory,
+ self.ndigits
+ )
+ # Initialize the directory for anomaly-detected results
+ if self.anomaly_data_factory:
+ self.anomaly_data_writer = AnomalyDataWriter(
+ os.path.join(output_base_dir, "anomaly_detected"), rank)
+ self.anomaly_data_writer.init_detected_json()
+
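The if/elif chain above selects a writer backend by the `format` config value; the same dispatch can be phrased as a lookup table. A minimal sketch with stand-in classes (the real ones are `SummaryWriterWithAD`, `CSVWriterWithAD` and `BaseWriterWithAD`; the names below are illustrative only):

```python
# Stand-in writer classes for illustration.
class TensorboardWriter:
    pass

class CsvWriter:
    pass

class ApiWriter:
    pass

# Map each supported 'format' value to its writer class.
WRITERS = {
    'tensorboard': TensorboardWriter,
    'csv': CsvWriter,
    'api': ApiWriter,
}

def select_writer(fmt):
    # Fail loudly on an unknown format instead of leaving 'writer' undefined.
    try:
        return WRITERS[fmt]
    except KeyError:
        raise ValueError(f"Unsupported format: {fmt}") from None
```

A table keeps the set of supported formats in one place, which also makes the error message trivially exhaustive.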
# A HeatmapVisualizer instance is associated with an image
self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer)
self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer)
- self.micro_batch_number = 0
+ self.micro_batch_number = 1
+
+ self.weight_hooked = False
+ self.optimizer_hooked = False
+ self.param_registered = False
+ self.vpp = False
+ self.dp_group = None
+ self.tp_group = None
- self.param_name_list = []
self.param2name = defaultdict(str)
+ self.param_name_call_id = {}
+ self.call_id = 0
self.mix_precision_optimizer_mon = OptimizerMonFactory.create_optimizer_mon(opt_ty)
if opt_ty is None:
@@ -149,21 +238,25 @@ class TrainerMon:
raise Exception("ur_distribution cannot be enabled with unknown optimizer.")
if self.mv_distribution:
raise Exception("mv_distribution cannot be enabled with unknown optimizer.")
+ self.verbose = False
self.print_struct = self.config.get("print_struct", False)
+ if self.print_struct:
+ self.verbose = True
self.struct_printed = False
- self.module_struct = {}
+ self.module_struct = defaultdict(dict)
+
return
def __del__(self):
if hasattr(self, "summary_writer"):
self.summary_writer.close()
-
+
@staticmethod
def set_wrapped_optimizer(_wrapped_optimizer):
- MixPrecsionOptimizerMon.set_wrapped_optimizer(_wrapped_optimizer)
+ OptimizerMon.set_wrapped_optimizer(_wrapped_optimizer)
@staticmethod
- def adhoc_check(target_tensor:torch.tensor, module_name:str, tensor_name:str, rank_list, ops_list):
+ def adhoc_check(target_tensor: torch.Tensor, module_name: str, tensor_name: str, rank_list, ops_list):
rank = None
if dist.is_initialized():
rank = dist.get_rank()
@@ -171,60 +264,97 @@ class TrainerMon:
return
TrainerMon.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank)
- def hook_modules(self, model:torch.nn.Module, grad_acc_steps):
- # fwd=0, bkd=1
- # targets is module name list like ["xx.xxx1", "xxx.xxx2"] which can be obtained when first run.
- print_rank_0("> module names:")
- for name, _ in model.named_modules():
- print_rank_0(f"\t{name}")
- self.micro_batch_number = grad_acc_steps
+ @staticmethod
+ def build_tbtag_tensor_map(module_name, tag, tensor):
+ metrics = {}
+ rank = dist.get_rank() if dist.is_initialized() else None
+ key = get_summary_writer_tag_name(module_name, tag, rank)
+ if tensor is not None:
+ metrics[key] = tensor
+ return metrics
- if not self.module_rank_list or (dist.is_initialized() and dist.get_rank() in self.module_rank_list):
- targets = [x for x, _ in model.named_modules()] if self.print_struct else self.config['targets'].keys()
- hooked_count = self._hook_module(targets, model, fwd_or_bkd=0)
- print_rank_0(f"> {hooked_count} out of {len(self.config['targets'])} are monitored.")
- else:
+ @staticmethod
+ def generate_cc_metrics(cc_name, cc_tensor):
+ metrics = defaultdict(dict)
+ rank = dist.get_rank() if dist.is_initialized() else None
+ for op, tag2tensor in cc_tensor.data.items():
+ for tag, tensor in tag2tensor.items():
+ key = get_summary_writer_tag_name(cc_name, tag, rank)
+ metrics[op].update({key: tensor})
+ cc_tensor.reset()
+ return metrics
+
+ def hook_modules(self, model: torch.nn.Module, grad_acc_steps):
+ if self.module_rank_list and (self.rank not in self.module_rank_list):
return
+ if not isinstance(model, list):
+ model = [model]
+
+ self._register_param_name(model)
+
+ self.micro_batch_number = grad_acc_steps
+ for vpp_stage, model_chunk in enumerate(model):
+ vpp_stage = f'{vpp_stage}{Const.vpp_sep}' if self.vpp else ''
+ targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[
+ 'targets'].keys()
+ hooked_count = self._hook_module(targets, model_chunk, vpp_stage)
+ print_rank_0(f"> {hooked_count} out of {len(self.config['targets'])} are monitored.")
+
if not self.optimizer_hooked:
- self.optimizer_hooked = True
- print_rank_0("> parameter names:")
- for name, param in model.named_parameters():
- print_rank_0(f"\t{name}")
- for target_module, _ in self.config['targets'].items():
- if name.startswith(target_module): # name : language_model.encoder.layers.0.mlp.weight, target_module:language_model.encoder.layers.0
- self.param_name_list.append(name)
- self.param2name[param] = name
self.hook_optimizer()
return
- def build_tbtag_tensor_map(self, module_name, tag, tensor):
- metrics = {}
- rank = dist.get_rank() if dist.is_initialized() else None
- key = get_summary_writer_tag_name(module_name, tag, rank)
- if tensor is not None:
- metrics[key] = tensor
- return metrics
+ def generate_wgrad_metrics(self):
+ if not self.wg_distribution:
+ return {}, {}
+
+ unreduced = {}
+ if self.weight_hooked:
+ for metric_name in self.ops:
+ unreduced[metric_name] = get_metrics(metric_name, self.grad_context.acc, self.eps)
+ self.grad_context.acc_metric = [unreduced]
+
+ grad_dict = {}
+ for param, name in self.param2name.items():
+ if self.tp_group and not param_is_not_tensor_parallel_duplicate(param, self.tp_group):
+ continue
+ if self.dp_group and param_is_data_parallel_duplicate(self.dp_group):
+ continue
+ grad = param.main_grad if self.params_have_main_grad else param.grad
+ if grad is None:
+ print_warn_log(f"grad is None for {name}; something may have gone wrong.")
+ continue
+ key = get_summary_writer_tag_name(name, 'post_grad', self.rank)
+ grad_dict[key] = grad
+
+ reduced = {op: get_metrics(op, grad_dict, self.eps) for op in self.ops}
+ self.grad_context.post = [reduced]
+
+ return reduced, unreduced
+
+ def monitor_gnorm_with_ad(self, model, grad_acc_steps=1, optimizer=None, tp_group=None, dp_group=None):
+ print_info_log(f'grad acc steps {grad_acc_steps}')
+ self.hook_optimizer(optimizer)
+ self.micro_batch_number = grad_acc_steps
+ self.backward_only = True
+
+ self.dp_group = dp_group
+ self.tp_group = tp_group
+
+ self._register_param_name(model)
+ self._hook_weights()
+ self.hook_modules(model, grad_acc_steps)
def generate_param_metrics(self, tag, param_tensor):
metrics = {}
rank = dist.get_rank() if dist.is_initialized() else None
- for param, name in self.param2name.items():
+ for _, name in self.param2name.items():
key = get_summary_writer_tag_name(name, tag, rank)
if name not in param_tensor or param_tensor[name] is None:
continue
metrics[key] = param_tensor[name]
return metrics
-
- def generate_cc_metrics(self, cc_name, cc_tensor):
- metrics = defaultdict(dict)
- rank = dist.get_rank() if dist.is_initialized() else None
- for op, tag2tensor in cc_tensor.data.items():
- for tag, tensor in tag2tensor.items():
- key = get_summary_writer_tag_name(cc_name, tag, rank)
- metrics[op].update({key: tensor})
- cc_tensor.reset()
- return metrics
def write_adhoc_check(self, step):
TrainerMon.tensor_metrics.flush(self.summary_writer)
@@ -233,37 +363,47 @@ class TrainerMon:
if not self.xy_distribution:
return
for _, fwd_context in self.module_fwd_hook_context_by_module.items():
+ if len(fwd_context.actv) == 0:
+ continue
if not len(fwd_context.actv) == self.micro_batch_number:
- print_warn_log(f"fwd_context.actv not equal to micro_batch_number: {len(fwd_context.actv)}, {self.micro_batch_number}")
- for metric_name in self.ops:
- write_metrics_tensorboard(metric_name, self.summary_writer, fwd_context.actv, step)
+ print_warn_log(
+ f"fwd_context.actv not equal to micro_batch_number: {len(fwd_context.actv)}, {self.micro_batch_number}")
+ self.write_metrics(self.ops, self.summary_writer, fwd_context.actv, step, 'actv')
fwd_context.actv.clear()
- for _, bwd_context in self.module_bwd_hook_context_by_module.items():
- if not len(bwd_context.actvgrad) == self.micro_batch_number:
- print_warn_log(f"bwd_context.actvgrad not equal to micro_batch_number: {len(bwd_context.actvgrad)}, {self.micro_batch_number}")
- for metric_name in self.ops:
- write_metrics_tensorboard(metric_name, self.summary_writer, bwd_context.actvgrad, step)
- bwd_context.actvgrad.clear()
+ self.write_metrics(self.ops, self.summary_writer, [self.grad_context.actv], step, 'grad_actv')
- def hook_optimizer(self):
+ def write_grad_tb(self, step):
+ if not self.wg_distribution:
+ return
+
+ self.write_metrics(self.ops, self.summary_writer, self.grad_context.post, step, 'grad_reduced')
+ self.write_metrics(self.ops, self.summary_writer, self.grad_context.acc_metric, step, 'grad_unreduced')
+
+ def hook_optimizer(self, optimizer=None):
# in DDP by default use params_have_main_grad
def optimizer_pre_step_hook(optimizer, args, kwargs):
context = self.optimizer_context[optimizer]
- if self.print_struct and not all(value == {} for value in self.module_struct.values()) and not self.struct_printed:
+ if self.print_struct and not all(
+ value == {} for value in self.module_struct.values()) and not self.struct_printed:
self._smallest_rank_print("> module struct:")
self._smallest_rank_print(json.dumps(self.module_struct, indent=4))
if not self.cc_log_only:
raise Exception("exit after first step when print model struct")
if self.cc_log_only and context.step > 0:
self._smallest_rank_print("> Used communication ops and corresponding stack")
- self._smallest_rank_print(json.dumps({k:[i.split(';') for i in v] for k,v in self.cc_logged_stack.items()}, indent=4))
+ self._smallest_rank_print(
+ json.dumps({k: [i.split(';') for i in v] for k, v in self.cc_logged_stack.items()}, indent=4))
raise Exception("exit after first step when print cc stack")
-
- context.param_exp_avg, context.param_exp_avg_sq, context.param_adam_update, context.param_adam_ratio = self.mix_precision_optimizer_mon.fetch_mv(self,
- optimizer, self.param2name)
-
+ self.generate_wgrad_metrics()
+
+ mv_result = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name)
+ context.param_exp_avg = mv_result.exp_avg
+ context.param_exp_avg_sq = mv_result.exp_avg_sq
+ context.param_adam_update = mv_result.update
+ context.param_adam_ratio = mv_result.ratio
+
for param, name in self.param2name.items():
if "params_effrank" in self.config and name in self.config["params_effrank"]:
context.param_effective_rank[name] = eff_rank(param.detach())
@@ -271,9 +411,8 @@ class TrainerMon:
if grad is None:
print_warn_log(f"grad is None: {name}, maybe something wrong happened.")
continue
- if self.wg_distribution:
- context.param_weight_grad[name] = grad
- if self.mg_direction:
+
+ if self.mg_direction:
if context.step == 0:
same_direction_ratio = torch.tensor(1.)
else:
@@ -281,15 +420,11 @@ class TrainerMon:
context.param_mg_direction[name] = same_direction_ratio
tbtag_tensor_map = {}
- if self.wg_distribution:
- tbtag_tensor_map.update(self.generate_param_metrics('weight_grad', context.param_weight_grad))
if self.mv_distribution:
tbtag_tensor_map.update(self.generate_param_metrics('exp_avg', context.param_exp_avg))
tbtag_tensor_map.update(self.generate_param_metrics('exp_avg_sq', context.param_exp_avg_sq))
if self.mg_direction:
tbtag_tensor_map.update(self.generate_param_metrics('mg_direction', context.param_mg_direction))
- # if not tbtag_tensor_map:
- # return
metric_dict = {}
for metric_name in self.ops:
metric_dict[metric_name] = get_metrics(metric_name, tbtag_tensor_map, self.eps)
@@ -298,6 +433,7 @@ class TrainerMon:
cc_metrics = self.generate_cc_metrics(k, c)
for op, m in cc_metrics.items():
metric_dict[op].update(m)
+
if not metric_dict:
return
context.metric_list.append(metric_dict)
@@ -307,29 +443,57 @@ class TrainerMon:
context = self.optimizer_context[optimizer]
rank = dist.get_rank() if dist.is_initialized() else None
+ if self.anomaly_data_factory:
+ self.anomaly_data_factory.set_call_id(self.param_name_call_id)
self.write_xy_tb(context.step)
+ self.write_grad_tb(context.step)
self.write_adhoc_check(context.step)
if self.ur_distribution:
for param_name, _ in context.param_adam_update.items():
- self.update_heatmap_visualizer[param_name].visualize(get_summary_writer_tag_name(param_name, 'adam_update', rank), context.step, self.summary_writer)
+ self.update_heatmap_visualizer[param_name].visualize(
+ get_summary_writer_tag_name(param_name, 'adam_update', rank), context.step, self.summary_writer)
for param_name, _ in context.param_adam_ratio.items():
- self.ratio_heatmap_visualizer[param_name].visualize(get_summary_writer_tag_name(param_name, 'adam_ratio', rank), context.step, self.summary_writer)
+ self.ratio_heatmap_visualizer[param_name].visualize(
+ get_summary_writer_tag_name(param_name, 'adam_ratio', rank), context.step, self.summary_writer)
- for metric_name in self.ops:
- if not context.metric_list:
- break
- write_metrics_tensorboard(metric_name, self.summary_writer, context.metric_list, context.step)
+ if context.metric_list:
+ self.write_metrics(self.ops, self.summary_writer, context.metric_list, context.step, 'other')
context.metric_list.clear()
context.step += 1
+ self.grad_context.reset()
+ if self.anomaly_data_factory:
+ self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
+ self.summary_writer.clear_anomalies()
+ self.call_id = 0
+ self.param_name_call_id.clear()
+ return
+
+ def patch_step(func, optimizer):
+ def wrapper(*args, **kwargs):
+ optimizer_pre_step_hook(optimizer, args, kwargs)
+ out = func(*args, **kwargs)
+ optimizer_post_step_hook(optimizer, args, kwargs)
+ return out
+ return wrapper
+
+ if self.optimizer_hooked:
return
- if not self.module_rank_list or (dist.is_initialized() and dist.get_rank() in self.module_rank_list):
- register_optimizer_step_pre_hook(optimizer_pre_step_hook)
- register_optimizer_step_post_hook(optimizer_post_step_hook)
+
+ if optimizer:
+ optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
+
+ else:
+ if not self.module_rank_list or (dist.is_initialized() and dist.get_rank() in self.module_rank_list):
+ register_optimizer_step_pre_hook(optimizer_pre_step_hook)
+ register_optimizer_step_post_hook(optimizer_post_step_hook)
+ self.optimizer_hooked = True
return
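The `patch_step` closure above wraps the optimizer's class-level `step` with pre/post hooks. The same pattern in isolation, with a stand-in optimizer class and no torch dependency (all names below are illustrative):

```python
calls = []

class Optimizer:
    def step(self):
        calls.append('step')
        return 'loss-updated'

def optimizer_pre_step_hook(optimizer, args, kwargs):
    calls.append('pre')

def optimizer_post_step_hook(optimizer, args, kwargs):
    calls.append('post')

def patch_step(func, optimizer):
    # Wrap the original step so hooks run before and after each call.
    def wrapper(*args, **kwargs):
        optimizer_pre_step_hook(optimizer, args, kwargs)
        out = func(*args, **kwargs)
        optimizer_post_step_hook(optimizer, args, kwargs)
        return out
    return wrapper

opt = Optimizer()
# Patch at class level, as TrainerMon.hook_optimizer does.
Optimizer.step = patch_step(Optimizer.step, opt)
result = opt.step()
```

Note one consequence of patching `optimizer.__class__.step`: every instance of that class is patched, and the `optimizer` captured by the closure is the one passed in, even if another instance calls `step`.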
def _smallest_rank_print(self, msg):
+ if not self.verbose:
+ return
if dist.is_initialized():
if self.module_rank_list:
if dist.get_rank() == min(self.module_rank_list):
@@ -340,7 +504,34 @@ class TrainerMon:
else:
print_info_log(msg)
- def _hook_module(self, target_names, module: torch.nn.Module, fwd_or_bkd):
+ def _register_param_name(self, model):
+ if self.param_registered:
+ return
+ if not isinstance(model, list):
+ model = [model]
+
+ if len(model) > 1:
+ self.vpp = True
+ self._smallest_rank_print('vpp enabled')
+
+ for vpp_stage, model_chunk in enumerate(model):
+ prefix = f'{Const.vpp}{vpp_stage}{Const.vpp_sep}' if self.vpp else ''
+ for param_name, param in model_chunk.named_parameters():
+ name = prefix + squash_param_name(param_name)
+ for target in self.config['targets'].keys():
+ if param_name.startswith(target) and param.requires_grad:
+ self._smallest_rank_print(f'>> monitoring: {name}')
+ setattr(param, "zero_out_wgrad", True)
+ if name in self.param2name.values() or name == '':
+ print_error_log(f'duplicate name {name} for different params; current param is {param_name}. '
+ 'This may be an error in squash_param_name.')
+ raise Exception("params with the same name would be overwritten.")
+ self.param2name[param] = name
+ break
+
+ self.param_registered = True
+
+ def _hook_module(self, target_names, module: torch.nn.Module, vpp_stage=''):
if '_modules' not in module.__dict__:
# nothing to hook
return 0
@@ -351,15 +542,15 @@ class TrainerMon:
self.module_struct[context.module_name].update(
{"input": f"{get_param_struct(module_input)}", "output": f"{get_param_struct(module_output)}"})
return
- if not self.xy_distribution:
- return
if not context.format_by_arg:
context.set_format_by_arg('input', self.config['targets'])
context.set_format_by_arg('output', self.config['targets'])
if not context.verified:
if not context.ignore_in:
- context.focused_in_col = validate_config_spec(context.format_by_arg['input'], module_input, context.module_name, 'input')
- context.focused_out_col = validate_config_spec(context.format_by_arg['output'], module_output, context.module_name, 'output')
+ context.focused_in_col = validate_config_spec(context.format_by_arg['input'], module_input,
+ context.module_name, 'input')
+ context.focused_out_col = validate_config_spec(context.format_by_arg['output'], module_output,
+ context.module_name, 'output')
context.verified = True
# expect output be tensor type
tbtag_tensor_map = {}
@@ -387,32 +578,40 @@ class TrainerMon:
context: ModuleHookContext = self.module_bwd_hook_context_by_module[module]
if self.print_struct:
self.module_struct[context.module_name].update(
- {"input_grad": f"{get_param_struct(input_grad)}", "output_grad": f"{get_param_struct(output_grad)}"})
- return
- if not self.xy_distribution:
+ {"input_grad": f"{get_param_struct(input_grad)}",
+ "output_grad": f"{get_param_struct(output_grad)}"})
return
if not context.format_by_arg:
context.set_format_by_arg('input_grad', self.config['targets'])
context.set_format_by_arg('output_grad', self.config['targets'])
+ if not context.format_by_arg:
+ return
if not context.verified:
if not context.ignore_in:
- context.focused_in_col = validate_config_spec(context.format_by_arg['input_grad'], input_grad, context.module_name, 'input_grad')
- context.focused_out_col = validate_config_spec(context.format_by_arg['output_grad'], output_grad, context.module_name, 'output_grad')
+ context.focused_in_col = validate_config_spec(context.format_by_arg['input_grad'], input_grad,
+ context.module_name, 'input_grad')
+ context.focused_out_col = validate_config_spec(context.format_by_arg['output_grad'], output_grad,
+ context.module_name, 'output_grad')
context.verified = True
tbtag_tensor_map = {}
if not context.ignore_in:
cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
- tbtag_tensor_map.update(self.build_tbtag_tensor_map(context.module_name, 'input_grad', cared_input_grad))
+ tbtag_tensor_map.update(
+ self.build_tbtag_tensor_map(context.module_name + f'_{context.micro_step}', 'input_grad',
+ cared_input_grad))
cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col]
- tbtag_tensor_map.update(self.build_tbtag_tensor_map(context.module_name, 'output_grad', cared_output_grad))
- metric_dict = {}
- for metric_name in self.ops:
- metric_dict[metric_name] = get_metrics(metric_name, tbtag_tensor_map, self.eps)
+ tbtag_tensor_map.update(
+ self.build_tbtag_tensor_map(context.module_name + f'_{context.micro_step}', 'output_grad',
+ cared_output_grad))
+
if context.micro_step == 0 and context.actvgrad:
- print_warn_log(f"actvgrad context of {context.module_name} is not empty when first micro_step, maybe something wrong happened. Now clear it.")
+ print_warn_log(
+ f"actvgrad context of {context.module_name} is not empty at the first micro_step; something may have gone wrong. Clearing it now.")
context.actvgrad.clear()
- context.actvgrad.append(metric_dict)
+
+ for metric_name in self.ops:
+ self.grad_context.actv[metric_name].update(get_metrics(metric_name, tbtag_tensor_map, self.eps))
context.micro_step += 1
if context.micro_step == self.micro_batch_number:
@@ -420,15 +619,45 @@ class TrainerMon:
context.step += 1
return
+ if self.backward_only and self.forward_only:
+ print_warn_log('backward_only and forward_only cannot be enabled simultaneously')
+
hooked_count = 0
- for name, submodule in module.named_modules():
- self.module_struct[name] = {}
- if name in target_names:
- submodule.register_forward_hook(fwd_hook_fun)
- self.module_fwd_hook_context_by_module[submodule] = ModuleHookContext(name)
- if not self.forward_only:
- submodule.register_full_backward_hook(bwd_hook_fun)
- self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name)
- print_rank_0(f"> {name} is monitored successfully")
- hooked_count += 1
+ if self.xy_distribution or self.print_struct:
+ for module_name, submodule in module.named_modules():
+ name = vpp_stage + module_name
+ self.module_struct[name] = {}
+ if name in target_names or module_name in target_names:
+ if not self.backward_only:
+ submodule.register_forward_hook(fwd_hook_fun)
+ self.module_fwd_hook_context_by_module[submodule] = ModuleHookContext(name)
+ if not self.forward_only:
+ submodule.register_full_backward_hook(bwd_hook_fun)
+ self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name)
+ print_rank_0(f"> {name} is monitored successfully")
+ hooked_count += 1
return hooked_count
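`_hook_module` walks `named_modules()` and registers forward/backward hooks only on configured targets. A minimal standalone sketch of the registration loop (assuming torch is available; the model and target names below are illustrative):

```python
import torch

activations = []

def fwd_hook(module, module_input, module_output):
    # record the norm of each monitored module's output
    activations.append(module_output.detach().norm().item())

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
targets = {'0'}  # named_modules() names Sequential children '0', '1', ...
hooked_count = 0
for name, submodule in model.named_modules():
    if name in targets:
        submodule.register_forward_hook(fwd_hook)
        hooked_count += 1

model(torch.randn(2, 4))
```

Each forward pass through a hooked submodule fires the hook once, which is why the contexts above count micro-steps against `micro_batch_number`.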
+
+ def _hook_weights(self):
+ context = self.grad_context
+
+ @torch.no_grad
+ def param_hook(*args, context_dict, param, key, name):
+ param.micro_step += 1
+ self.param_name_call_id[name] = self.call_id
+ self.call_id += 1
+ if param.micro_step == self.micro_batch_number:
+ param.micro_step = 0
+ if self.params_have_main_grad:
+ context_dict[key] = param.main_grad.clone()
+ else:
+ context_dict[key] = param.grad.clone()
+
+ for param, name in self.param2name.items():
+ key = get_summary_writer_tag_name(name, 'acc_grad', self.rank)
+ setattr(param, 'micro_step', 0)
+ param_tmp = param.expand_as(param)
+ grad_acc = param_tmp.grad_fn.next_functions[0][0]
+ grad_acc.register_hook(partial(param_hook, context_dict=context.acc, param=param, key=key, name=name))
+
+ self.weight_hooked = True
diff --git a/debug/accuracy_tools/kj600/kj600/module_metric.py b/debug/accuracy_tools/monitor/monitor/module_metric.py
similarity index 65%
rename from debug/accuracy_tools/kj600/kj600/module_metric.py
rename to debug/accuracy_tools/monitor/monitor/module_metric.py
index e09536b072cf7953e6b6106420936416d4264d0e..681e9d90829a3333302a85b7867ff5c91cc51dcf 100644
--- a/debug/accuracy_tools/kj600/kj600/module_metric.py
+++ b/debug/accuracy_tools/monitor/monitor/module_metric.py
@@ -1,15 +1,26 @@
import math
+import re
import statistics
-from kj600.features import square_sum, get_max, get_min, get_zeros, get_nans, get_norm
+from monitor.features import square_sum, get_max, get_min, get_zeros, get_nans, get_norm, get_mean
def get_summary_writer_tag_name(module_or_param_name:str, tag:str, rank):
if rank is None:
return f"{module_or_param_name}/{tag}"
else:
- return f"{module_or_param_name}/{rank}/{tag}"
-
+ return f"{module_or_param_name}/rank{rank}/{tag}"
+
+def squash_param_name(param_name):
+ name = ''
+ for pattern in [r'(?<=layers\.)[\d]*.*', r'embeddings?\.(.*)', r'final.*', r'output.*', r'norm.*']:
+ match = re.findall(pattern, param_name)
+ if match:
+ name += match[0]
+ break
+ if name == '':
+ name = param_name
+ return name
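`squash_param_name` keeps only the distinguishing suffix of a Megatron-style parameter name. A behaviorally equivalent self-contained copy (with raw-string patterns) and the names it produces on a few illustrative inputs:

```python
import re

def squash_param_name(param_name):
    # Keep only the distinguishing suffix; fall back to the full name
    # when no pattern matches.
    for pattern in [r'(?<=layers\.)[\d]*.*', r'embeddings?\.(.*)',
                    r'final.*', r'output.*', r'norm.*']:
        match = re.findall(pattern, param_name)
        if match:
            return match[0]
    return param_name
```

Note that patterns with a capture group (`embeddings?\.(.*)`) make `re.findall` return the group's text, not the whole match, so `word_embeddings.weight` squashes all the way down to `weight`.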
# Registry of all metric implementation classes
config_metric_registry = {}
@@ -28,9 +39,9 @@ class TensorMetrics:
self.metrics = {} #tensor_tag --> []
self.cur_idx = {}
- fun_map = {"norm": get_norm, "max": get_max, "min": get_min}
+ fun_map = {"norm": get_norm, "max": get_max, "min": get_min, "mean": get_mean}
#get stats and insert into metrics dictionary
- def stat_insert(self, tensor, stat_ops, module_name, tensor_name, rank, eps=1e-8):
+ def stat_insert(self, tensor, stat_ops, module_name, tensor_name, rank):
prefix = get_summary_writer_tag_name(module_name, tensor_name, rank)
for stat_op in stat_ops:
y = TensorMetrics.fun_map[stat_op](tensor)
@@ -75,6 +86,19 @@ class MinMetric(Metric):
summary_writer.add_scalar(f'{key}_min', min_value, step)
+@register_config_metric("mean")
+class MeanMetric(Metric):
+ @staticmethod
+ def get_metric_value(tensor, eps):
+ return get_mean(tensor)
+
+ @staticmethod
+ def metric_tensorboard(metric_name, summary_writer, metric_value, step):
+ for key in metric_value[0][metric_name].keys():
+ mean_value = sum([item[metric_name][key].item() for item in metric_value]) / len(metric_value)
+ summary_writer.add_scalar(f'{key}_mean', mean_value, step)
+
+
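`@register_config_metric("mean")` above follows the standard decorator-registry pattern: the decorator factory records each metric class under its config name so `get_metrics` can look it up by string. A minimal sketch (whether the real registry stores classes or instances is an assumption here; with staticmethods either works):

```python
config_metric_registry = {}

def register_config_metric(name):
    # Class decorator: register the implementation under its config name.
    def decorator(cls):
        config_metric_registry[name] = cls
        return cls
    return decorator

@register_config_metric("mean")
class MeanMetric:
    @staticmethod
    def get_metric_value(values, eps):
        return sum(values) / (len(values) + eps)

metric = config_metric_registry["mean"]
value = metric.get_metric_value([1.0, 2.0, 3.0], 0.0)
```

Adding a new op then only requires defining a decorated class; no dispatch code elsewhere changes.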
@register_config_metric("max")
class MaxMetric(Metric):
@staticmethod
@@ -129,17 +153,16 @@ class NaNsMetric(Metric):
class IdentMetric(Metric):
@staticmethod
def get_metric_value(tensor, eps):
- if tensor.dim() != 0:
- return None
- return tensor
+ if tensor.dim() == 0:
+ return tensor
@staticmethod
- def metric_tensorboard(metric_name, summary_writer, metric_value, step): #metric_value is a dict, key is parameter name and value is a list of scalar tensor
+ def metric_tensorboard(metric_name, summary_writer, metric_value, context): #metric_value is a dict, key is parameter name and value is a list of scalar tensor
if len(metric_value) == 1:
for key, value in metric_value[0][metric_name].items():
if not value:
continue
- summary_writer.add_scalar(f'{key}_identical', value.item(), step)
+ summary_writer.add_scalar(f'{key}_identical', value.item(), context)
def get_metrics(metric_name, tag2tensor, eps):
@@ -150,9 +173,32 @@ def get_metrics(metric_name, tag2tensor, eps):
raise ValueError(f"Not supported this metric, expected metric: {config_metric_registry.keys()}, actual metric: {metric_name}") from e
-def write_metrics_tensorboard(metric_name, summary_writer, metric_value, step):
- try:
- fun_metric = config_metric_registry[metric_name]
- return fun_metric.metric_tensorboard(metric_name, summary_writer, metric_value, step)
- except KeyError as e:
- raise ValueError(f"Not supported this metric, expected metric: {config_metric_registry.keys()}, actual metric: {metric_name}") from e
+def write_metrics_tensorboard(ops, summary_writer, metric_value, step, prefix=''):
+ for metric_name in ops:
+ try:
+ fun_metric = config_metric_registry[metric_name]
+ fun_metric.metric_tensorboard(metric_name, summary_writer, metric_value, step)
+ except KeyError as e:
+ raise ValueError(f"Unsupported metric {metric_name}; expected one of: {config_metric_registry.keys()}") from e
+
+def write_metrics_csv(ops, summary_writer, metric_value, step, prefix=''):
+ for metric_name in ops:
+ try:
+ fun_metric = config_metric_registry[metric_name]
+ fun_metric.metric_tensorboard(metric_name, summary_writer, metric_value, step)
+
+ except KeyError as e:
+ raise ValueError(f"Unsupported metric {metric_name}; expected one of: {config_metric_registry.keys()}") from e
+
+ if not summary_writer.header:
+ if prefix in ['actv', 'grad_actv']:
+ summary_writer.header = ['param_name'] + ['input_'+op for op in ops] + ['output_'+op for op in ops]
+ else:
+ summary_writer.header = ['param_name'] + ops
+
+ for key in metric_value[0][ops[0]].keys():
+ if 'vpp' in key:
+ summary_writer.header.insert(0, 'vpp_stage')
+ break
+ summary_writer.write_csv(prefix, step)
+ summary_writer.header = []
diff --git a/debug/accuracy_tools/kj600/kj600/module_spec_verifier.py b/debug/accuracy_tools/monitor/monitor/module_spec_verifier.py
similarity index 91%
rename from debug/accuracy_tools/kj600/kj600/module_spec_verifier.py
rename to debug/accuracy_tools/monitor/monitor/module_spec_verifier.py
index 395aa82f17a87cdf742a8294e29ccb1c32081200..062d5c230ef9d438ab27b3e704b12f3828e7f076 100644
--- a/debug/accuracy_tools/kj600/kj600/module_spec_verifier.py
+++ b/debug/accuracy_tools/monitor/monitor/module_spec_verifier.py
@@ -2,15 +2,8 @@ import json
import re
import abc
import torch
-from kj600.utils import check_file_valid_readable
-def get_config(file_path='config.json'):
- check_file_valid_readable(file_path)
- with open(file_path, 'r') as file:
- config = json.load(file)
- return config
-
# Registry of all validator implementation classes
config_validator_registry = {}
@@ -40,7 +33,6 @@ class TensorValidator(ConfigValidator):
def validate(self, actual_data, module_name:str, data_type:str, pattern_match):
if not torch.is_tensor(actual_data):
raise ValueError(f"Format of {module_name} {data_type} does not match the required format 'tensor' in config.")
- return None
@register_config_validator
diff --git a/debug/accuracy_tools/kj600/kj600/optimizer_collect.py b/debug/accuracy_tools/monitor/monitor/optimizer_collect.py
similarity index 51%
rename from debug/accuracy_tools/kj600/kj600/optimizer_collect.py
rename to debug/accuracy_tools/monitor/monitor/optimizer_collect.py
index 285f17ca6dc6a00814b0847c7d203524d8a8caa6..7d2f488efad18d23ca48063f9e21193f24c9e05e 100644
--- a/debug/accuracy_tools/kj600/kj600/optimizer_collect.py
+++ b/debug/accuracy_tools/monitor/monitor/optimizer_collect.py
@@ -1,10 +1,12 @@
-from collections import defaultdict
+from abc import ABC, abstractmethod
+from collections import defaultdict, namedtuple
import torch
import torch.distributed as dist
-from kj600.visualizer import HeatmapVisualizer
+from monitor.utils import print_warn_log
-def print_rank_0(message, debug=False, force=False):
+
+def print_rank_0(message):
if dist.is_initialized():
if dist.get_rank() == 0:
print(message)
@@ -12,20 +14,29 @@ def print_rank_0(message, debug=False, force=False):
print(message)
-class MixPrecsionOptimizerMon:
+MVResult = namedtuple('MVResult', ("exp_avg", "exp_avg_sq", "update", "ratio"))
+
+
+class OptimizerMon(ABC):
wrapped_optimizer = None
+ @classmethod
+ def set_wrapped_optimizer(cls, wrapped_optimizer):
+ cls.wrapped_optimizer = wrapped_optimizer
+
+ @abstractmethod
+ def fetch_mv(self, monitor, torch_opt, params2name):
+ pass
+
+
+class MixPrecisionOptimizerMon(OptimizerMon):
def __init__(self) -> None:
self.fp16_to_fp32_param = {}
- @staticmethod
- def set_wrapped_optimizer(_wrapped_optimizer):
- MixPrecsionOptimizerMon.wrapped_optimizer = _wrapped_optimizer
-
# parameter tensors we want to monitor and their names are in params2name_dict
# base_optimizer is pytorch optimizer, wrapped_optimizer is a normal object with base_optimizer
def fetch_mv(self, monitor, torch_opt, params2name):
- mix_prec_opt = MixPrecsionOptimizerMon.wrapped_optimizer
+ mix_prec_opt = self.wrapped_optimizer
if not self.fp16_to_fp32_param and mix_prec_opt is not None:
for fp16_group, fp32_group in zip(mix_prec_opt.float16_groups, mix_prec_opt.fp32_from_float16_groups):
@@ -44,8 +55,12 @@ class MixPrecsionOptimizerMon:
param = self.fp16_to_fp32_param[param]
if param in torch_opt.state:
- exp_avg = torch_opt.state[param]["exp_avg"]
- exp_avg_sq = torch_opt.state[param]["exp_avg_sq"]
+ state_param = torch_opt.state.get(param, None)
+ exp_avg = state_param.get("exp_avg", None)
+ exp_avg_sq = state_param.get("exp_avg_sq", None)
+ if exp_avg is None or exp_avg_sq is None:
+                    print_warn_log(f"exp_avg or exp_avg_sq of {name} is None, something may have gone wrong.")
+ continue
if monitor.mv_distribution:
exp_avg_dict[name] = exp_avg
exp_avg_sq_dict[name] = exp_avg_sq
@@ -53,15 +68,15 @@ class MixPrecsionOptimizerMon:
exp_avg_dict[name] = exp_avg
if monitor.ur_distribution:
update_dict[name] = exp_avg / (torch.sqrt(exp_avg_sq) + torch_opt.defaults['eps'])
- ratio_dict[name] = exp_avg / torch.sqrt(exp_avg_sq)
+ ratio_dict[name] = (exp_avg / torch.sqrt(exp_avg_sq)).nan_to_num(0)
monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name])
monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
- return exp_avg_dict, exp_avg_sq_dict, update_dict, ratio_dict
+ return MVResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict)
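
The `.nan_to_num(0)` added to the ratio computation is what keeps not-yet-updated parameters (where both `exp_avg` and `exp_avg_sq` are zero) from producing NaN. A scalar sketch of that guard, mirroring only the behavior shown in this hunk (`ratio_with_nan_guard` is a hypothetical name, not part of the tool):

```python
import math

def ratio_with_nan_guard(exp_avg, exp_avg_sq):
    """Scalar analogue of ratio = (exp_avg / sqrt(exp_avg_sq)).nan_to_num(0)."""
    denom = math.sqrt(exp_avg_sq)
    if denom == 0.0:
        # the tensor version would yield nan for 0/0; nan_to_num(0) maps it to 0
        return 0.0 if exp_avg == 0.0 else math.copysign(math.inf, exp_avg)
    return exp_avg / denom

print(ratio_with_nan_guard(0.0, 0.0))   # 0.0 instead of nan
print(ratio_with_nan_guard(0.3, 0.09))  # ≈ 1.0
```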
-class MegatronDistributedOptimizerMon(MixPrecsionOptimizerMon):
+class MegatronDistributedOptimizerMon(MixPrecisionOptimizerMon):
def fetch_mv(self, monitor, torch_opt, params2name):
- mix_prec_opt = MixPrecsionOptimizerMon.wrapped_optimizer
+ mix_prec_opt = self.wrapped_optimizer
if not (hasattr(mix_prec_opt, "model_float16_groups") and hasattr(mix_prec_opt, "shard_fp32_from_float16_groups")):
raise Exception("megatron distributed optimizer should have model_float16_groups and shard_fp32_from_float16_groups, \
if not, please check megatron-lm version")
@@ -73,18 +88,48 @@ class MegatronDistributedOptimizerMon(MixPrecsionOptimizerMon):
return self._fetch_mv_in_adam(params2name, torch_opt, monitor)
-class DummyOptimizerMon(MixPrecsionOptimizerMon):
+class MegatronFP32OptimizerMon(OptimizerMon):
+ def fetch_mv(self, monitor, torch_opt, params2name):
+ exp_avg_dict = defaultdict(float)
+ exp_avg_sq_dict = defaultdict(float)
+ update_dict = defaultdict()
+ ratio_dict = defaultdict()
+
+ for param, name in params2name.items():
+ if param in torch_opt.state:
+ state_param = torch_opt.state.get(param, None)
+ exp_avg = state_param.get("exp_avg", None)
+ exp_avg_sq = state_param.get("exp_avg_sq", None)
+ if exp_avg is None or exp_avg_sq is None:
+                print_warn_log(f"exp_avg or exp_avg_sq of {name} is None, something may have gone wrong.")
+ continue
+ if monitor.mv_distribution:
+ exp_avg_dict[name] = exp_avg
+ exp_avg_sq_dict[name] = exp_avg_sq
+ if monitor.mg_direction:
+ exp_avg_dict[name] = exp_avg
+ if monitor.ur_distribution:
+ update_dict[name] = exp_avg / (torch.sqrt(exp_avg_sq) + torch_opt.defaults['eps'])
+ ratio_dict[name] = (exp_avg / torch.sqrt(exp_avg_sq)).nan_to_num(0)
+ monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name])
+ monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
+ return MVResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict)
+
+
+class DummyOptimizerMon(OptimizerMon):
def fetch_mv(self, monitor, torch_opt, params2name):
- return None, None, None, None
+ return MVResult(exp_avg=None, exp_avg_sq=None, update=None, ratio=None)
class OptimizerMonFactory:
@staticmethod
- def create_optimizer_mon(opt_ty:str):
+ def create_optimizer_mon(opt_ty: str):
if opt_ty == "Megatron_Float16OptimizerWithFloat16Params":
- return MixPrecsionOptimizerMon()
+ return MixPrecisionOptimizerMon()
if opt_ty == "Megatron_DistributedOptimizer":
return MegatronDistributedOptimizerMon()
+ if opt_ty == "Megatron_FP32Optimizer":
+ return MegatronFP32OptimizerMon()
if opt_ty is None or opt_ty == "unknown":
return DummyOptimizerMon()
-        raise Exception("opt_ty should be Megatron_Float16OptimizerWithFloat16Params or Megatron_DistributedOptimizer or None or unknown")
+        raise Exception("opt_ty should be Megatron_Float16OptimizerWithFloat16Params, Megatron_DistributedOptimizer, "
+                        "Megatron_FP32Optimizer, None or unknown")
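
The factory's if-chain can equivalently be written as a lookup table, which avoids forgetting to extend the error message each time a new optimizer type (such as `Megatron_FP32Optimizer` here) is added. A sketch of that alternative; the stub classes stand in for the real monitor classes:

```python
# Hypothetical table-driven form of OptimizerMonFactory.create_optimizer_mon.
class DummyOptimizerMon: ...
class MixPrecisionOptimizerMon: ...
class MegatronDistributedOptimizerMon: ...
class MegatronFP32OptimizerMon: ...

_REGISTRY = {
    "Megatron_Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon,
    "Megatron_DistributedOptimizer": MegatronDistributedOptimizerMon,
    "Megatron_FP32Optimizer": MegatronFP32OptimizerMon,
}

def create_optimizer_mon(opt_ty):
    if opt_ty is None or opt_ty == "unknown":
        return DummyOptimizerMon()
    try:
        return _REGISTRY[opt_ty]()
    except KeyError:
        raise Exception(f"unsupported opt_ty: {opt_ty}, "
                        f"expected one of {sorted(_REGISTRY)} or None/unknown") from None

print(type(create_optimizer_mon("Megatron_FP32Optimizer")).__name__)
```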
diff --git a/debug/accuracy_tools/monitor/monitor/unittest/test_monitor.py b/debug/accuracy_tools/monitor/monitor/unittest/test_monitor.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7017a6be5e01ad1b332c08e2332a47eee1de023
--- /dev/null
+++ b/debug/accuracy_tools/monitor/monitor/unittest/test_monitor.py
@@ -0,0 +1,145 @@
+import os
+import re
+import argparse
+from glob import glob
+
+import pandas as pd
+
+
+def parse_logfile(logfile):
+    grad_norm = []
+    with open(logfile) as f:
+        for line in f:
+            if 'consumed samples' in line:
+                grad_norm.append(float(re.findall(r'(?<=grad norm: )[\d.]*', line)[0]))
+    return grad_norm
+
+
+def parse_monitor_output(output_dir):
+    reduced = {}
+    unreduced = {}
+    for rank_dir in glob(output_dir + '*'):
+        rank = int(re.findall(r'(?<=rank)[\d]*', rank_dir)[0])
+        unreduced[rank] = []
+        reduced[rank] = []
+        for file in os.listdir(rank_dir):
+            df = pd.read_csv(os.path.join(rank_dir, file))
+            if '_unreduced_' in file:
+                unreduced[rank].append(df)
+            elif '_reduced_' in file:
+                reduced[rank].append(df)
+            else:
+                print(f'unexpected file {file} in {rank_dir}')
+    return reduced, unreduced
+
+def valid_reduce(reduced, unreduced, tp_size, dp_size, sequence_parallel):
+    world_size = len(reduced)
+    errors = []
+    for _, row in unreduced[0][0].iterrows():
+        param = row['param_name']
+        is_tp_duplicate = False
+        for step in range(2):
+            # sum reduced
+            reduced_mean = 0.
+            for rank in range(world_size):
+                if len(reduced[rank]) == 0:
+                    continue
+                df = reduced[rank][step]
+                value = list(df[df['param_name'] == param]['mean'])
+                if value == []:
+                    if step == 0:
+                        is_tp_duplicate = True
+                    continue
+                reduced_mean += value[0]
+
+            # sum unreduced
+            unreduced_mean = 0.
+            for rank in range(world_size):
+                df = unreduced[rank][step]
+                value = list(df[df['param_name'] == param]['mean'])
+                if value == []:
+                    continue
+                unreduced_mean += value[0]
+
+            unreduced_mean /= dp_size
+            if is_tp_duplicate and (not sequence_parallel or 'embedding' in param):
+                unreduced_mean /= tp_size
+            try:
+                assert_equal(unreduced_mean, reduced_mean)
+            except AssertionError as e:
+                errors.append([param, step, e, is_tp_duplicate])
+    if errors:
+        print(errors)
+    else:
+        print('grad mean is consistent between unreduced and reduced grads monitored.')
+
+
+
+def assert_equal(a, b):
+    if a == 0 and b == 0:
+        return
+    if b == 0:
+        rel_diff = abs(a)
+    elif a == 0:
+        rel_diff = abs(b)
+    else:
+        rel_diff = abs(a / b - 1)
+    assert rel_diff < 0.01, f'{a}, {b}, {rel_diff}'
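
The tolerance used by `assert_equal` can be exercised standalone. A sketch of the intended rule, assuming the both-zero early return is meant for `a == b == 0` only (`rel_close` is a hypothetical helper name):

```python
def rel_close(a, b, tol=0.01):
    """True when a and b agree within a 1% relative difference."""
    if a == 0 and b == 0:
        return True
    if a == 0 or b == 0:
        return abs(a + b) < tol  # one side zero: fall back to absolute size
    return abs(a / b - 1) < tol

print(rel_close(100.0, 100.5))  # 0.5% apart -> within tolerance
print(rel_close(100.0, 102.0))  # 2% apart -> rejected
```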
+
+
+def valid_total_norm(total_norm, reduced, duplicate_embedding):
+ steps = len(total_norm)
+ world_size = len(reduced)
+ errors = []
+ for step in range(steps):
+ calculated_norm = 0.
+ for rank in range(world_size):
+ if len(reduced[rank]) == 0:
+ if step == 0:
+ print(f'rank {rank} is duplicated in dp group')
+ continue
+ for index, row in reduced[rank][step].iterrows():
+ if duplicate_embedding and 'word_embedding' in row['param_name']:
+ continue
+ calculated_norm += row['norm']**2
+ try:
+ assert_equal(calculated_norm**0.5, total_norm[step])
+ except AssertionError as e:
+ errors.append([step, e])
+ if errors:
+ print('total norm errors: ', errors)
+ else:
+        print('grad norm is consistent between training log and reduced gradients monitored')
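
`valid_total_norm` relies on the identity that the global grad norm equals the square root of the summed squared per-parameter norms (after skipping shards duplicated across ranks). A minimal numeric check of that identity, with made-up norm values:

```python
import math

# per-parameter gradient norms gathered across ranks (illustrative values)
param_norms = [3.0, 4.0, 12.0]

# global norm over the concatenated gradients
global_norm = math.sqrt(sum(n ** 2 for n in param_norms))
print(global_norm)  # sqrt(9 + 16 + 144) = 13.0
```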
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--monitor_output', '-m', type=str, required=True, help='path prefix to the output of monitor e.g. monitor_output/Aug12_07-16')
+ parser.add_argument('--logfile', '-l', type=str, required=True, help='path to the training log file')
+ parser.add_argument('--tp_size', '-t', type=int, required=True, help='tp parallel size')
+ parser.add_argument('--dp_size', '-d', type=int, required=True, help='dp parallel size')
+ parser.add_argument('--pp_size', '-p', type=int, required=True, help='pp parallel size')
+ parser.add_argument('--untie_embeddings_and_output_weights', '-u', action="store_true", default=False, help='whether untie_embeddings_and_output_weights in pp parallel')
+ parser.add_argument('--sequence_parallel', '-s', action="store_true", default=False, help='whether sequence parallel is enabled. Add -s to store true')
+
+ args = parser.parse_args()
+
+    assert args.tp_size > 0, 'if tp not enabled, set tp_size = 1'
+    assert args.dp_size > 0, 'if dp not enabled, set dp_size = 1'
+    assert args.pp_size > 0, 'if pp not enabled, set pp_size = 1'
+
+ total_norm = parse_logfile(args.logfile)
+ reduced, unreduced = parse_monitor_output(args.monitor_output)
+
+ duplicate_embedding = not args.untie_embeddings_and_output_weights and args.pp_size > 1
+
+ valid_total_norm(total_norm, reduced, duplicate_embedding)
+ valid_reduce(reduced, unreduced, args.tp_size, args.dp_size, args.sequence_parallel)
\ No newline at end of file
diff --git a/debug/accuracy_tools/kj600/kj600/utils.py b/debug/accuracy_tools/monitor/monitor/utils.py
similarity index 73%
rename from debug/accuracy_tools/kj600/kj600/utils.py
rename to debug/accuracy_tools/monitor/monitor/utils.py
index 53d47d9988647202bdb711afde38b94b51899b5a..3aed6911c440da507a0b3ec8a02091308061e84b 100644
--- a/debug/accuracy_tools/kj600/kj600/utils.py
+++ b/debug/accuracy_tools/monitor/monitor/utils.py
@@ -107,4 +107,29 @@ def check_file_valid_readable(path):
def check_file_valid_writable(path):
check_file_valid(path)
check_path_writability(path)
-
\ No newline at end of file
+
+
+def make_file_safety(file_path: str, permission=0o640):
+ if os.path.islink(file_path):
+ raise RuntimeError(f"Invalid soft link path: {file_path}")
+ file_real_path = os.path.realpath(file_path)
+ if os.path.exists(file_real_path):
+ return
+ parent_path = os.path.dirname(file_real_path)
+ if not os.path.exists(parent_path):
+ os.makedirs(parent_path, mode=0o750, exist_ok=True)
+ if not os.access(parent_path, os.W_OK):
+ raise PermissionError(f"The path {parent_path} is not writable!")
+ try:
+ os.close(os.open(file_real_path, os.O_WRONLY | os.O_CREAT, permission))
+ except OSError as e:
+ raise RuntimeError("Can't create file: " + file_real_path) from e
+ os.chmod(file_real_path, permission)
+
+
+def create_directory(dir_path):
+ dir_path = os.path.realpath(dir_path)
+ try:
+ os.makedirs(dir_path, mode=0o750, exist_ok=True)
+ except OSError as ex:
+ raise RuntimeError("Failed to create directory. Please check the path permission or disk space.") from ex
\ No newline at end of file
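
The `make_file_safety` helper added above combines a symlink check, parent-directory creation, and an explicit permission mask. A condensed, self-contained usage sketch of that pattern (re-implemented here so it runs standalone; not an import of the real module):

```python
import os
import tempfile

def make_file_safety(file_path, permission=0o640):
    """Create file_path with restrictive permissions, rejecting symlinks."""
    if os.path.islink(file_path):
        raise RuntimeError(f"Invalid soft link path: {file_path}")
    file_real_path = os.path.realpath(file_path)
    if os.path.exists(file_real_path):
        return
    os.makedirs(os.path.dirname(file_real_path), mode=0o750, exist_ok=True)
    os.close(os.open(file_real_path, os.O_WRONLY | os.O_CREAT, permission))
    os.chmod(file_real_path, permission)  # chmod overrides any umask applied at open

with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "monitor_output", "stats.csv")
    make_file_safety(target)       # creates the parent dir and an empty 0o640 file
    print(os.path.isfile(target))
```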
diff --git a/debug/accuracy_tools/kj600/kj600/visualizer.py b/debug/accuracy_tools/monitor/monitor/visualizer.py
similarity index 97%
rename from debug/accuracy_tools/kj600/kj600/visualizer.py
rename to debug/accuracy_tools/monitor/monitor/visualizer.py
index e1929bfa3fb338b1cb66cda80a128e83176bfcbf..151f1ea1c451a5f27d250df591c4d00f64a1a34c 100644
--- a/debug/accuracy_tools/kj600/kj600/visualizer.py
+++ b/debug/accuracy_tools/monitor/monitor/visualizer.py
@@ -1,7 +1,7 @@
import torch
import numpy as np
import matplotlib.pyplot as plt
-from kj600.features import cal_histc
+from monitor.features import cal_histc
class HeatmapVisualizer:
diff --git a/debug/accuracy_tools/kj600/pyproject.toml b/debug/accuracy_tools/monitor/pyproject.toml
similarity index 78%
rename from debug/accuracy_tools/kj600/pyproject.toml
rename to debug/accuracy_tools/monitor/pyproject.toml
index 5df968563345dd07ed477ec73b967b63c6e812a6..5111fbbbd53a9e88bd52f12a61621d2dc3ed6203 100644
--- a/debug/accuracy_tools/kj600/pyproject.toml
+++ b/debug/accuracy_tools/monitor/pyproject.toml
@@ -3,11 +3,9 @@ requires = ["setuptools"]
build-backend = "setuptools.build_meta"
[project]
-name = "kj600"
+name = "monitor"
version = "0.0.1"
dependencies = [
- "torch",
- "torch_npu",
"torchvision",
"tensorboard",
"matplotlib",
@@ -16,4 +14,7 @@ dependencies = [
]
[tool.setuptools.packages]
-find = {} # Scan the project directory with the default parameters
\ No newline at end of file
+find = {} # Scan the project directory with the default parameters
+
+[tool.setuptools.package-data]
+monitor = ["distributed/*.yaml"]
\ No newline at end of file
diff --git "a/debug/accuracy_tools/kj600/\350\256\255\347\273\203\347\212\266\346\200\201\347\233\221\346\216\247\345\267\245\345\205\267\346\200\247\350\203\275\345\237\272\347\272\277.md" "b/debug/accuracy_tools/monitor/\350\256\255\347\273\203\347\212\266\346\200\201\347\233\221\346\216\247\345\267\245\345\205\267\346\200\247\350\203\275\345\237\272\347\272\277.md"
similarity index 100%
rename from "debug/accuracy_tools/kj600/\350\256\255\347\273\203\347\212\266\346\200\201\347\233\221\346\216\247\345\267\245\345\205\267\346\200\247\350\203\275\345\237\272\347\272\277.md"
rename to "debug/accuracy_tools/monitor/\350\256\255\347\273\203\347\212\266\346\200\201\347\233\221\346\216\247\345\267\245\345\205\267\346\200\247\350\203\275\345\237\272\347\272\277.md"
diff --git a/debug/accuracy_tools/atat/README.md b/debug/accuracy_tools/msprobe/README.md
similarity index 62%
rename from debug/accuracy_tools/atat/README.md
rename to debug/accuracy_tools/msprobe/README.md
index de7d74f8f29cef077bec8b665afcc537974df774..663dde17e1af734b7f3c499251a5cec91f848a36 100644
--- a/debug/accuracy_tools/atat/README.md
+++ b/debug/accuracy_tools/msprobe/README.md
@@ -1,11 +1,26 @@
# MindStudio精度调试工具
-MindStudio精度调试工具(ascend_training_accuracy_tools),简称atat,是MindStudio Training Tools工具链下精度调试部分的工具包。主要包括精度预检和精度比对等子工具,当前适配场景包括PyTorch和MindSpore。
+MindStudio精度调试工具(MindStudio Probe),简称msprobe,是MindStudio Training Tools工具链下精度调试部分的工具包。主要包括精度预检和精度比对等子工具,当前适配场景包括PyTorch和MindSpore。
## 工具安装
-精度工具合一软件包名称:`ascend_training_accuracy_tools-{version}-py3-none-any.whl`
+精度工具合一软件包名称:`mindstudio_probe-{version}-py3-none-any.whl`
+### pip安装
+ ```shell
+ pip install mindstudio-probe
+ ```
+使用`pip install mindstudio-probe==版本号`可安装指定版本的包。
+
+pip命令会自动安装最新的包及其配套依赖。
+
+提示如下信息则表示安装成功。
+
+```bash
+Successfully installed mindstudio_probe-{version}
+```
+
+### 下载whl包安装
1. 使用pip命令安装numpy、openpyxl、pandas、PyYAML、rich、torch、tqdm依赖。
若环境中已安装部分依赖,不需要重复安装。
@@ -16,6 +31,7 @@ MindStudio精度调试工具(ascend_training_accuracy_tools),简称atat,
| 版本 | 发布日期 | 支持PyTorch版本 | 下载链接 | 校验码 |
| ----- | ---------- | --------------- | ------------------------------------------------------------ | ------------------------------------------------------------ |
+ | 1.0.1 | 2024-07-25 | 2.0/2.1/2.2 | [mindstudio_probe-1.0.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.0/mindstudio_probe-1.0.1-py3-none-any.whl) | b699e224e4d4e3bcf9412c54fa858a1ee370f0d7a2bc69cb3f1273ac14a6dc82 |
| 1.0 | 2024-07-09 | 2.0/2.1/2.2 | [ascend_training_accuracy_tools-1.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/att/1.0/ascend_training_accuracy_tools-1.0-py3-none-any.whl) | 5016dfe886c5d340ec6f60a959673355855f313c91f100680da814efb49f8e81 |
| 0.0.3 | 2024-06-11 | 2.0/2.1/2.2 | [ascend_training_accuracy_tools-0.0.3-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/att/0.0/ascend_training_accuracy_tools-0.0.3-py3-none-any.whl) | f46d9714704859e2d67861a65bbb3c76b0a250cf6e238b978b5b959ab1fe125a |
| 0.0.2 | 2024-05-23 | 2.0/2.1/2.2 | [ascend_training_accuracy_tools-0.0.2-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/att/0.0/ascend_training_accuracy_tools-0.0.2-py3-none-any.whl) | 2e35809bde559e9c4d2f16a02ccde779ed9e436bb65fded0b7ebaf6ac2c88d93 |
@@ -43,25 +59,79 @@ MindStudio精度调试工具(ascend_training_accuracy_tools),简称atat,
4. 执行如下命令进行安装。
```bash
- pip3 install ./ascend_training_accuracy_tools-{version}-py3-none-any.whl
+ pip3 install ./mindstudio_probe-{version}-py3-none-any.whl
```
若为覆盖安装,请在命令行末尾增加“--force-reinstall”参数强制安装,例如:
```bash
- pip3 install ./ascend_training_accuracy_tools-{version}-py3-none-any.whl --force-reinstall
+ pip3 install ./mindstudio_probe-{version}-py3-none-any.whl --force-reinstall
```
提示如下信息则表示安装成功。
```bash
- Successfully installed ascend_training_accuracy_tools-{version}
+ Successfully installed mindstudio_probe-{version}
+ ```
+
+### 从源码安装
+1. 克隆或者下载项目源代码
+
+ ```shell
+ git clone https://gitee.com/ascend/mstt.git
+ cd debug/accuracy_tools
+ ```
+
+2. 安装setuptools和wheel
+
+ ```shell
+ pip install setuptools wheel
+ ```
+
+3. 安装msprobe
+
+ ```shell
+ python setup.py install
+ ```
+ 提示出现如下信息则表示源码安装成功。
+ ```shell
+ Finished processing dependencies for mindstudio-probe=={version}
```
+### 查看msprobe工具信息
+
+执行如下命令查看msprobe工具信息。
+
+```bash
+pip show mindstudio-probe
+```
+
+输出结果如下示例:
+
+```bash
+Name: mindstudio-probe
+Version: 1.0
+Summary: This is a pytorch precision comparison tools
+Home-page:
+Author:
+Author-email:
+License:
+Location: /home/xx/anaconda3/envs/pt21py38/lib/python3.8/site-packages
+Requires: numpy, openpyxl, pandas, pyyaml, rich, tqdm, wheel
+Required-by:
+```
+
+关键字段含义:
+
+- Name:工具名称。
+- Version:工具版本号。
+- Summary:工具概述。
+- Location:工具安装路径。
+- Requires:工具依赖。
## 工具使用
-安装atat工具后,可以按照如下思路选择合适的子工具进行精度调试:
+安装msprobe工具后,可以按照如下思路选择合适的子工具进行精度调试:
1. 判断框架场景。
@@ -107,32 +177,36 @@ MindStudio精度调试工具(ascend_training_accuracy_tools),简称atat,
MindSpore场景:暂不支持。
-上述流程中的工具均为atat工具的子工具,使用相同的命令行,格式如下:
+7. TorchAir场景的数据采集和精度比对。
+
+ 仅支持PyTorch场景,详见[TorchAir训练场景-整网算子精度比对](./pytorch/doc/torchair_compare.md)和[TorchAir训练场景Dump案例](./pytorch/doc/torchair_dump_sample.md)。
+
+上述流程中的工具均为msprobe工具的子工具,使用相同的命令行,格式如下:
精度预检工具
```bash
-atat -f run_ut [-h]
+msprobe -f run_ut [-h]
```
```bash
-atat -f multi_run_ut [-h]
+msprobe -f multi_run_ut [-h]
```
```bash
-atat -f api_precision_compare [-h]
+msprobe -f api_precision_compare [-h]
```
溢出解析工具
```bash
-atat -f run_overflow_check [-h]
+msprobe -f run_overflow_check [-h]
```
数据解析工具
```bash
-atat -f parse [-h]
+msprobe -f parse [-h]
```
| 参数 | 说明 |
diff --git a/debug/accuracy_tools/atat/mindspore/dump/__init__.py b/debug/accuracy_tools/msprobe/__init__.py
similarity index 100%
rename from debug/accuracy_tools/atat/mindspore/dump/__init__.py
rename to debug/accuracy_tools/msprobe/__init__.py
diff --git a/debug/accuracy_tools/atat/config/README.md b/debug/accuracy_tools/msprobe/config/README.md
similarity index 89%
rename from debug/accuracy_tools/atat/config/README.md
rename to debug/accuracy_tools/msprobe/config/README.md
index a998704993accefa6f167c6fac6ec18218cca211..7d11a3652539d457736e1c8d5be62cabb558422d 100644
--- a/debug/accuracy_tools/atat/config/README.md
+++ b/debug/accuracy_tools/msprobe/config/README.md
@@ -2,13 +2,38 @@
当前配置文件主要为PrecisionDebugger接口执行dump或无标杆比对操作时调用的配置,当PrecisionDebugger接口未指定该配置文件时,使用该文件的默认配置。配置文件详见[config.json](./config.json)。
+当在环境上安装msprobe工具后,config.json文件位置可通过如下方式查找:
+
+查找msprobe工具安装路径。
+
+```
+pip show mindstudio-probe
+```
+
+输出结果如下示例:
+
+```
+Name: mindstudio-probe
+Version: 1.0
+Summary: This is a pytorch precision comparison tools
+Home-page:
+Author:
+Author-email:
+License:
+Location: /home/xx/anaconda3/envs/pt21py38/lib/python3.8/site-packages
+Requires: numpy, openpyxl, pandas, pyyaml, rich, tqdm, wheel
+Required-by:
+```
+
+Location字段为msprobe工具的安装路径,则config.json文件位置为:/home/xx/anaconda3/envs/pt21py38/lib/python3.8/site-packages/msprobe/config。
+
## 参数说明
### **通用配置参数**
| 参数名 | 说明 | 是否必选 |
| ----------------- | ------------------------------------------------------------ | -------- |
-| task | dump的任务类型,str类型。可取值"free_benchmark"(无标杆比对,仅PyTorch场景支持)、"statistics"(仅dump API统计信息,默认值)、"tensor"(dump API统计信息和完全复刻整网的API运行情况的真实数据)、"overflow_check"(溢出检测)。配置示例:"task": "tensor"。根据task参数取值的不同,可以配置不同场景参数,详见:“**task配置为free_benchmark**”,“**task配置为statistics**”,“**task配置为tensor**”,“**task配置为overflow_check**”。 | 否 |
+| task | dump的任务类型,str类型。可取值:<br>"free_benchmark"(无标杆比对,仅PyTorch场景支持)。<br>"statistics"(仅dump API统计信息,默认值)。<br>"tensor"(dump API统计信息和完全复刻整网的API运行情况的真实数据)。<br>"overflow_check"(溢出检测,仅PyTorch和MindSpore静态图场景支持)。<br>"run_ut"(精度预检配置,仅PyTorch场景支持)。<br>配置示例:"task": "tensor"。<br>根据task参数取值的不同,可以配置不同场景参数,详见:“**task配置为free_benchmark**”,“**task配置为statistics**”,“**task配置为tensor**”,“**task配置为overflow_check**”,“**task配置为run_ut**”。 | 否 |
| dump_path | 设置dump数据目录路径,str类型。配置示例:"dump_path": "./dump_path"。MindSpore场景仅支持绝对路径。 | 是 |
| rank | 指定对某张卡上的数据进行dump,list[int]类型,默认未配置(表示dump所有卡的数据),应配置为大于等于0的整数,且须配置实际可用的Rank ID。配置示例:"rank": [1]。<br>对于PyTorch场景,Rank ID从0开始计数,最大取值为所有节点可用卡总数-1,若所配置的值大于实际训练所运行的卡的Rank ID,则dump数据为空,比如当前环境Rank ID为0到7,实际训练运行0到3卡,此时若配置Rank ID为4或不存在的10等其他值,此时dump数据为空。<br>对于MindSpore场景,所有节点的Rank ID均从0开始计数,最大取值为每个节点可用卡总数-1,config.json配置一次rank参数对所有节点同时生效。 | 否 |
| step | 指定dump某个step的数据,list[int]类型。默认未配置,表示dump所有step数据。dump特定step时,须指定为训练脚本中存在的step。step为list格式,可配置逐个step,例如:"step": [0,1,2]。 | 否 |
@@ -85,6 +110,18 @@ task配置为free_benchmark时,开启**无标杆比对**,在NPU环境下通
| overflow_nums | 控制溢出次数,int类型,仅PyTorch场景支持,表示第N次溢出时,停止训练,过程中检测到溢出API对应kernel数据均dump。配置示例:"overflow_nums": 3。默认为1,即检测到1次溢出,训练停止,配置为-1时,表示持续检测溢出直到训练结束。 | 否 |
| check_mode | MindSpore场景kernel级别的溢出检测,str类型,可取值"aicore"(开启AI Core的溢出检测)、"atomic"(开启Atomic的溢出检测)、"all"(开启AI Core和Atomic的溢出检测,默认值)。配置示例"check_mode": "aicore"。 | 否 |
+### task配置为run_ut
+
+仅PyTorch场景支持。
+
+| 参数名称 | 说明 | 是否必选 |
+| --------------- | ------------------------------------------------------------ | -------- |
+| white_list | API dump白名单,仅对指定的API进行dump。配置示例:"white_list": ["conv1d", "conv2d"]。默认未配置白名单,即dump全量API数据。 | 否 |
+| black_list | API dump黑名单,被指定的API不进行dump。配置示例:"black_list": ["conv1d", "conv2d"]。默认未配置黑名单,即dump全量API数据。 | 否 |
+| error_data_path | 配置保存精度未达标的API输入输出数据路径,默认为当前路径。配置示例"error_data_path": "./"。 | 否 |
+
+说明:white_list和black_list同时配置时,二者配置的API名单若无交集,则白名单生效,若API名单存在交集,则白名单排除的部分以及交集的API不进行dump。
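
The white_list/black_list precedence described above reduces to: the candidate set is the white list when given (otherwise all APIs), and anything black-listed is removed from it. A sketch of that resolution, assuming string API names (`resolve_dump_apis` is a hypothetical helper, not part of the tool):

```python
def resolve_dump_apis(all_apis, white_list=None, black_list=None):
    """Effective set of APIs to dump under the documented precedence rules."""
    candidates = set(white_list) if white_list else set(all_apis)
    blocked = set(black_list) if black_list else set()
    return candidates - blocked

apis = ["conv1d", "conv2d", "relu", "matmul"]
# neither list configured: every API is dumped
print(sorted(resolve_dump_apis(apis)))
# both configured with an intersection: conv2d (intersection) and relu
# (black-listed, outside the white list) are excluded; only conv1d remains
print(sorted(resolve_dump_apis(apis, white_list=["conv1d", "conv2d"],
                               black_list=["conv2d", "relu"])))
```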
+
## 配置示例
以下示例包含当前支持的所有场景可配置的完整参数。
@@ -180,6 +217,27 @@ task配置为free_benchmark时,开启**无标杆比对**,在NPU环境下通
}
```
+### PyTorch场景task配置为run_ut
+
+```json
+{
+ "task": "run_ut",
+ "dump_path": "/home/data_dump",
+ "rank": [],
+ "step": [],
+ "level": "L1",
+ "seed": 1234,
+ "is_deterministic": false,
+ "enable_dataloader": false,
+
+ "run_ut": {
+ "white_list": [],
+ "black_list": [],
+ "error_data_path": "./"
+ }
+}
+```
+
### MindSpore场景task配置为statistics
```json
@@ -394,4 +452,4 @@ train_loader = torch.utils.data.DataLoader(
关闭dropout:
-在使用from atat.pytorch import PrecisionDebugger后,工具会自动将torch.nn.functional.dropout、torch.nn.functional.dropout2d、torch.nn.functional.dropout3d、torch.nn.Dropout、torch.nn.Dropout2d、torch.nn.Dropout3d的接口参数p置为0。
+在使用from msprobe.pytorch import PrecisionDebugger后,工具会自动将torch.nn.functional.dropout、torch.nn.functional.dropout2d、torch.nn.functional.dropout3d、torch.nn.Dropout、torch.nn.Dropout2d、torch.nn.Dropout3d的接口参数p置为0。
diff --git a/debug/accuracy_tools/atat/config/config.json b/debug/accuracy_tools/msprobe/config/config.json
similarity index 80%
rename from debug/accuracy_tools/atat/config/config.json
rename to debug/accuracy_tools/msprobe/config/config.json
index 70a630a40af1fbd827d060d4783071415d2cb610..f3191ad985c95121ca314146ef96e14dafe7eb09 100644
--- a/debug/accuracy_tools/atat/config/config.json
+++ b/debug/accuracy_tools/msprobe/config/config.json
@@ -7,6 +7,7 @@
"seed": 1234,
"is_deterministic": false,
"enable_dataloader": false,
+ "enable_step_auto_dump": false,
"acl_config": "",
"tensor": {
"scope": [],
@@ -24,5 +25,10 @@
"overflow_check": {
"overflow_nums": 1,
"check_mode":"all"
+ },
+ "run_ut": {
+ "white_list": [],
+ "black_list": [],
+ "error_data_path": "./"
}
}
\ No newline at end of file
diff --git a/debug/accuracy_tools/atat/config/img/free_benchmark.png b/debug/accuracy_tools/msprobe/config/img/free_benchmark.png
similarity index 100%
rename from debug/accuracy_tools/atat/config/img/free_benchmark.png
rename to debug/accuracy_tools/msprobe/config/img/free_benchmark.png
diff --git a/debug/accuracy_tools/atat/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py
similarity index 70%
rename from debug/accuracy_tools/atat/core/common/const.py
rename to debug/accuracy_tools/msprobe/core/common/const.py
index dea829c3ffad125d8c2ac98d97215db37e895b7e..84c5ac1abe4fe44402caf3a1877d99f582626111 100644
--- a/debug/accuracy_tools/atat/core/common/const.py
+++ b/debug/accuracy_tools/msprobe/core/common/const.py
@@ -1,7 +1,9 @@
import os
import stat
+
import numpy as np
+
class Const:
"""
Class for const
@@ -15,6 +17,17 @@ class Const:
OFF = 'OFF'
BACKWARD = 'backward'
FORWARD = 'forward'
+ PRIMITIVE_PREFIX = 'Primitive'
+ DEFAULT_LIST = []
+ DEFAULT_PATH = './'
+ WHITE_LIST = 'white_list'
+ BLACK_LIST = 'black_list'
+ DUMP_TENSOR_DATA = 'dump_tensor_data'
+ NONE = None
+ THREE_SEGMENT = 3
+ FOUR_SEGMENT = 4
+ SIX_SEGMENT = 6
+ SEVEN_SEGMENT = 7
# dump mode
ALL = "all"
@@ -25,6 +38,8 @@ class Const:
API_LIST = "api_list"
API_STACK = "api_stack"
DUMP_MODE = [ALL, LIST, RANGE, STACK, ACL, API_LIST, API_STACK]
+ AUTO = "auto"
+ ONLINE_DUMP_MODE = [ALL, LIST, AUTO, OFF]
SUMMARY = "summary"
MD5 = "md5"
SUMMARY_MODE = [ALL, SUMMARY, MD5]
@@ -35,8 +50,10 @@ class Const:
PKL_SUFFIX = ".pkl"
NUMPY_SUFFIX = ".npy"
+ PT_SUFFIX = ".pt"
ONE_GB = 1073741824 # 1 * 1024 * 1024 * 1024
TEN_GB = 10737418240 # 10 * 1024 * 1024 * 1024
+ ONE_MB = 1048576 # 1 * 1024 * 1024
FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$'
DISTRIBUTED_PREFIX_LENGTH = 60
# env dump path
@@ -52,13 +69,20 @@ class Const:
ENV_ENABLE = "1"
ENV_DISABLE = "0"
MAX_SEED_VALUE = 4294967295 # 2**32 - 1
- TASK_LIST = ["tensor", "statistics", "overflow_check", "free_benchmark"]
- LEVEL_LIST = ["L0", "L1", "L2", "mix"]
STATISTICS = "statistics"
TENSOR = "tensor"
OVERFLOW_CHECK = "overflow_check"
FREE_BENCHMARK = "free_benchmark"
+ RUN_UT = "run_ut"
+ GRAD_PROBE = "grad_probe"
+ TASK_LIST = [TENSOR, STATISTICS, OVERFLOW_CHECK, FREE_BENCHMARK, RUN_UT, GRAD_PROBE]
+ LEVEL_L0 = "L0"
+ LEVEL_L1 = "L1"
+ LEVEL_L2 = "L2"
+ LEVEL_MIX = "mix"
+ LEVEL_LIST = [LEVEL_L0, LEVEL_L1, LEVEL_L2, LEVEL_MIX]
ATTR_NAME_PREFIX = "wrap_"
+ ATTR_NAME_PREFIX_LEN = len(ATTR_NAME_PREFIX)
KERNEL_DUMP = "kernel_dump"
DATA = "data"
PT_FRAMEWORK = "pytorch"
@@ -69,13 +93,17 @@ class Const:
BOOL_TYPE = [bool, np.uint8]
INT_TYPE = [np.int32, np.int64]
NPU = 'NPU'
+ NPU_LOWERCASE = 'npu'
+ CPU_LOWERCASE = 'cpu'
+ CUDA_LOWERCASE = 'cuda'
DISTRIBUTED = 'Distributed'
-
+
INPLACE_LIST = [
"broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter",
- "_reduce_scatter_base", "_all_gather_base", "send", "recv", "irecv", "isend", "all_to_all_single"
+ "_reduce_scatter_base", "_all_gather_base", "send", "recv", "irecv", "isend", "all_to_all_single", "all_to_all",
+ "all_gather_into_tensor", "reduce_scatter_tensor"
]
-
+
CONVERT = {
"int32_to_int64": ["torch.int32", "torch.int64"],
}
@@ -84,6 +112,7 @@ class Const:
"int32_to_int64": ["cross_entropy"]
}
+
class CompareConst:
"""
Class for compare module const
@@ -123,7 +152,12 @@ class CompareConst:
NPU_MD5 = "NPU MD5"
BENCH_MD5 = "BENCH MD5"
RESULT = "Result"
-
+ MAGNITUDE = 0.5
+ OP_NAME = "op_name"
+ INPUT_STRUCT = "input_struct"
+ OUTPUT_STRUCT = "output_struct"
+ SUMMARY = "summary"
+
COMPARE_RESULT_HEADER = [
NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR,
ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO,
@@ -160,6 +194,7 @@ class CompareConst:
WARNING = 'Warning'
ERROR = 'error'
SKIP = 'SKIP'
+ N_A = 'N/A'
BFLOAT16_MIN = -3.3895313892515355e+38
BFLOAT16_MAX = 3.3895313892515355e+38
BFLOAT16_EPS = 3.90625e-3 # 2 ** -8
@@ -167,6 +202,7 @@ class CompareConst:
# accuracy standards
COS_THRESHOLD = 0.99
MAX_ABS_ERR_THRESHOLD = 0.001
+ MAX_RELATIVE_ERR_THRESHOLD = 0.001
COS_MAX_THRESHOLD = 0.9
MAX_ABS_ERR_MAX_THRESHOLD = 1
ACCURACY_CHECK_YES = "Yes"
@@ -183,6 +219,10 @@ class CompareConst:
RED = "FFFF0000"
YELLOW = "FFFF00"
BLUE = "0000FF"
+
+ # run_ut const
+ MAX_TOKENS = 65536
+ SPECIAL_SPARSE_MOED = 4
# highlight rules const
OVERFLOW_LIST = ['nan\t', 'inf\t', '-inf\t', 'nan', 'inf', '-inf']
@@ -195,6 +235,20 @@ class CompareConst:
MAX_RELATIVE_OUT_RED = 0.5
MAX_RELATIVE_OUT_YELLOW = 0.1
MAX_RELATIVE_IN_YELLOW = 0.01
+ MS_GRAPH_BASE = {
+ NPU_NAME: None, BENCH_NAME: None, NPU_DTYPE: None, BENCH_DTYPE: None, NPU_SHAPE: None, BENCH_SHAPE: None,
+ NPU_MAX: None, NPU_MIN: None, NPU_MEAN: None, NPU_NORM: None, BENCH_MAX: None, BENCH_MIN: None,
+ BENCH_MEAN: None, BENCH_NORM: None, ACCURACY: '', ERROR_MESSAGE: ''
+ }
+ MS_GRAPH_NPY = {
+ COSINE: None, MAX_ABS_ERR: None, MAX_RELATIVE_ERR: None, ONE_THOUSANDTH_ERR_RATIO: None,
+ FIVE_THOUSANDTHS_ERR_RATIO: None
+ }
+ MS_GRAPH_STATISTIC = {
+ MAX_DIFF: None, MIN_DIFF: None, MEAN_DIFF: None, NORM_DIFF: None, MAX_RELATIVE_ERR: None,
+ MIN_RELATIVE_ERR: None, MEAN_RELATIVE_ERR: None, NORM_RELATIVE_ERR: None
+ }
+
class FileCheckConst:
"""
@@ -232,6 +286,7 @@ class FileCheckConst:
YAML_SUFFIX: MAX_YAML_SIZE
}
+
class OverflowConst:
"""
Class for Overflow
@@ -239,3 +294,44 @@ class OverflowConst:
OVERFLOW_DEBUG_MODE_ENABLE = "OVERFLOW_DEBUG_MODE_ENABLE"
OVERFLOW_ORIGINAL_MODE = 0
OVERFLOW_DEBUG_MODE = 1
+
+class MsCompareConst:
+ # api_info field
+ MINT = "Mint"
+ MINT_FUNCTIONAL = "MintFunctional"
+
+ TASK_FIELD = "task"
+ STATISTICS_TASK = "statistics"
+ TENSOR_TASK = "tensor"
+ DUMP_DATA_DIR_FIELD = "dump_data_dir"
+ DATA_FIELD = "data"
+
+    # detail_csv
+ DETAIL_CSV_API_NAME = "API Name"
+ DETAIL_CSV_BENCH_DTYPE = "Bench Dtype"
+ DETAIL_CSV_TESTED_DTYPE = "Tested Dtype"
+ DETAIL_CSV_SHAPE = "Shape"
+ DETAIL_CSV_PASS_STATUS = "Status"
+ DETAIL_CSV_MESSAGE = "Message"
+ DETAIL_CSV_FILE_NAME = "accuracy_checking_details"
+
+    # result_csv
+ RESULT_CSV_FORWARD_TEST_SUCCESS = "Forward Test Success"
+ RESULT_CSV_BACKWARD_TEST_SUCCESS = "Backward Test Success"
+ RESULT_CSV_FILE_NAME = "accuracy_checking_result"
+
+ EPSILON = 1e-8
+
+class MsgConst:
+ """
+ Class for log messages const
+ """
+ CLEAR_SYMBOL = "\033[K"
+ LEVEL = ["INFO", "WARNING", "ERROR"]
+ SPECIAL_CHAR = ["\n", "\r", "\u007F", "\b", "\f", "\t", "\u000B", "%08", "%0a", "%0b", "%0c", "%0d", "%7f"]
+
+
+class GraphMode:
+ NPY_MODE = "NPY_MODE"
+ STATISTIC_MODE = "STATISTIC_MODE"
+ ERROR_MODE = "ERROR_MODE"
diff --git a/debug/accuracy_tools/atat/core/common/exceptions.py b/debug/accuracy_tools/msprobe/core/common/exceptions.py
similarity index 52%
rename from debug/accuracy_tools/atat/core/common/exceptions.py
rename to debug/accuracy_tools/msprobe/core/common/exceptions.py
index f6b7c19ba32df57513562436c43518cf4a40a7cc..ea61f8cd58fe057ba6836dd1ed368d52adedeb18 100644
--- a/debug/accuracy_tools/atat/core/common/exceptions.py
+++ b/debug/accuracy_tools/msprobe/core/common/exceptions.py
@@ -1,19 +1,20 @@
class CodedException(Exception):
def __init__(self, code, error_info=''):
super().__init__()
+ self.code = code
self.error_info = self.err_strs.get(code) + error_info
def __str__(self):
return self.error_info
-class MsaccException(CodedException):
+class MsprobeException(CodedException):
INVALID_PARAM_ERROR = 0
OVERFLOW_NUMS_ERROR = 1
err_strs = {
- INVALID_PARAM_ERROR: "[msacc] 无效参数: ",
- OVERFLOW_NUMS_ERROR: "[msacc] 超过预设溢出次数 当前溢出次数:"
+ INVALID_PARAM_ERROR: "[msprobe] 无效参数: ",
+ OVERFLOW_NUMS_ERROR: "[msprobe] 超过预设溢出次数 当前溢出次数:"
}
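
The renamed exception classes all follow one pattern: a class-level code-to-message table plus the `code` attribute this hunk adds. A standalone sketch of that pattern (one subclass shown with a hypothetical English message; the real classes use the Chinese messages above):

```python
class CodedException(Exception):
    err_strs = {}

    def __init__(self, code, error_info=''):
        super().__init__()
        self.code = code
        self.error_info = self.err_strs.get(code) + error_info

    def __str__(self):
        return self.error_info


class MsprobeException(CodedException):
    INVALID_PARAM_ERROR = 0
    err_strs = {INVALID_PARAM_ERROR: "[msprobe] invalid parameter: "}


try:
    raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, "rank=-1")
except MsprobeException as e:
    print(e)       # [msprobe] invalid parameter: rank=-1
    print(e.code)  # 0
```

Storing `code` (rather than only the formatted string) lets callers branch on the failure kind without parsing messages.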
@@ -26,12 +27,12 @@ class FileCheckException(CodedException):
FILE_TOO_LARGE_ERROR = 5
err_strs = {
- SOFT_LINK_ERROR: "[msacc] 检测到软链接: ",
- FILE_PERMISSION_ERROR: "[msacc] 文件权限错误: ",
- INVALID_FILE_ERROR: "[msacc] 无效文件: ",
- ILLEGAL_PATH_ERROR: "[msacc] 非法文件路径: ",
- ILLEGAL_PARAM_ERROR: "[msacc] 非法打开方式: ",
- FILE_TOO_LARGE_ERROR: "[msacc] 文件过大: "
+ SOFT_LINK_ERROR: "[msprobe] 检测到软链接: ",
+ FILE_PERMISSION_ERROR: "[msprobe] 文件权限错误: ",
+ INVALID_FILE_ERROR: "[msprobe] 无效文件: ",
+ ILLEGAL_PATH_ERROR: "[msprobe] 非法文件路径: ",
+ ILLEGAL_PARAM_ERROR: "[msprobe] 非法打开方式: ",
+ FILE_TOO_LARGE_ERROR: "[msprobe] 文件过大: "
}
@@ -39,8 +40,8 @@ class ParseJsonException(CodedException):
UnexpectedNameStruct = 0
InvalidDumpJson = 1
err_strs = {
- UnexpectedNameStruct: "[msacc] Unexpected name in json: ",
- InvalidDumpJson: "[msacc] json格式不正确: ",
+ UnexpectedNameStruct: "[msprobe] Unexpected name in json: ",
+ InvalidDumpJson: "[msprobe] json格式不正确: ",
}
@@ -49,23 +50,23 @@ class ScopeException(CodedException):
InvalidScope = 1
ArgConflict = 2
err_strs = {
- InvalidApiStr: "[msacc] Invalid api_list: ",
- InvalidScope: "[msacc] Invalid scope: ",
- ArgConflict: "[msacc] Scope and api_list conflict: ",
+ InvalidApiStr: "[msprobe] Invalid api_list: ",
+ InvalidScope: "[msprobe] Invalid scope: ",
+ ArgConflict: "[msprobe] Scope and api_list conflict: ",
}
class RepairException(CodedException):
InvalidRepairType = 0
err_strs = {
- InvalidRepairType: "[msacc] Invalid repair_type: "
+ InvalidRepairType: "[msprobe] Invalid repair_type: "
}
class StepException(CodedException):
InvalidPostProcess = 0
err_strs = {
- InvalidPostProcess: "[msacc] 错误的step后处理配置: ",
+ InvalidPostProcess: "[msprobe] 错误的step后处理配置: ",
}
@@ -73,8 +74,8 @@ class FreeBenchmarkException(CodedException):
UnsupportedType = 0
InvalidGrad = 1
err_strs = {
- UnsupportedType: "[msacc] Free benchmark get unsupported type: ",
- InvalidGrad: "[msacc] Free benchmark gradient invalid: ",
+ UnsupportedType: "[msprobe] Free benchmark get unsupported type: ",
+ InvalidGrad: "[msprobe] Free benchmark gradient invalid: ",
}
diff --git a/debug/accuracy_tools/msprobe/core/common/file_utils.py b/debug/accuracy_tools/msprobe/core/common/file_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0976323b5a760c1e6d250923ab7bfdbc166a0bef
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/core/common/file_utils.py
@@ -0,0 +1,478 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+import csv
+import fcntl
+import os
+import json
+import re
+import shutil
+import yaml
+import numpy as np
+
+from msprobe.core.common.log import logger
+from msprobe.core.common.exceptions import FileCheckException
+from msprobe.core.common.const import FileCheckConst
+
+
+class FileChecker:
+ """
+ The class for checking a file or directory.
+
+ Attributes:
+ file_path: The file or directory path to be verified.
+ path_type: file or directory
+ ability(str): FileCheckConst.WRITE_ABLE or FileCheckConst.READ_ABLE, requiring the path to be writable or readable
+ file_type(str): The expected file suffix for the file
+ """
+
+ def __init__(self, file_path, path_type, ability=None, file_type=None, is_script=True):
+ self.file_path = file_path
+ self.path_type = self._check_path_type(path_type)
+ self.ability = ability
+ self.file_type = file_type
+ self.is_script = is_script
+
+ @staticmethod
+ def _check_path_type(path_type):
+ if path_type not in [FileCheckConst.DIR, FileCheckConst.FILE]:
+ logger.error(f'The path_type must be {FileCheckConst.DIR} or {FileCheckConst.FILE}.')
+ raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR)
+ return path_type
+
+ def common_check(self):
+ """
+ Function: verify basic file permissions: soft links, path length, existence, read/write permissions, file ownership, and special characters in the path
+ Note: validating the file suffix is not a generic operation; use a separate dedicated interface for it
+ """
+ check_path_exists(self.file_path)
+ check_link(self.file_path)
+ self.file_path = os.path.realpath(self.file_path)
+ check_path_length(self.file_path)
+ check_path_type(self.file_path, self.path_type)
+ self.check_path_ability()
+ if self.is_script:
+ check_path_owner_consistent(self.file_path)
+ check_path_pattern_vaild(self.file_path)
+ check_common_file_size(self.file_path)
+ check_file_suffix(self.file_path, self.file_type)
+ return self.file_path
+
+ def check_path_ability(self):
+ if self.ability == FileCheckConst.WRITE_ABLE:
+ check_path_writability(self.file_path)
+ if self.ability == FileCheckConst.READ_ABLE:
+ check_path_readability(self.file_path)
+ if self.ability == FileCheckConst.READ_WRITE_ABLE:
+ check_path_readability(self.file_path)
+ check_path_writability(self.file_path)
+
+
+class FileOpen:
+ """
+ The class for opening a file in a safe way.
+
+ Attributes:
+ file_path: The file path to be opened.
+ mode(str): The file open mode
+ """
+ SUPPORT_READ_MODE = ["r", "rb"]
+ SUPPORT_WRITE_MODE = ["w", "wb", "a", "ab"]
+ SUPPORT_READ_WRITE_MODE = ["r+", "rb+", "w+", "wb+", "a+", "ab+"]
+
+ def __init__(self, file_path, mode, encoding='utf-8'):
+ self.file_path = file_path
+ self.mode = mode
+ self.encoding = encoding
+ self._handle = None
+
+ def __enter__(self):
+ self.check_file_path()
+ binary_mode = "b"
+ if binary_mode not in self.mode:
+ self._handle = open(self.file_path, self.mode, encoding=self.encoding)
+ else:
+ self._handle = open(self.file_path, self.mode)
+ return self._handle
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if self._handle:
+ self._handle.close()
+
+ def check_file_path(self):
+ support_mode = self.SUPPORT_READ_MODE + self.SUPPORT_WRITE_MODE + self.SUPPORT_READ_WRITE_MODE
+ if self.mode not in support_mode:
+ logger.error("File open does not support %s mode" % self.mode)
+ raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR)
+ check_link(self.file_path)
+ self.file_path = os.path.realpath(self.file_path)
+ check_path_length(self.file_path)
+ self.check_ability_and_owner()
+ check_path_pattern_vaild(self.file_path)
+ if os.path.exists(self.file_path):
+ check_common_file_size(self.file_path)
+
+ def check_ability_and_owner(self):
+ if self.mode in self.SUPPORT_READ_MODE:
+ check_path_exists(self.file_path)
+ check_path_readability(self.file_path)
+ check_path_owner_consistent(self.file_path)
+ if self.mode in self.SUPPORT_WRITE_MODE and os.path.exists(self.file_path):
+ check_path_writability(self.file_path)
+ check_path_owner_consistent(self.file_path)
+ if self.mode in self.SUPPORT_READ_WRITE_MODE and os.path.exists(self.file_path):
+ check_path_readability(self.file_path)
+ check_path_writability(self.file_path)
+ check_path_owner_consistent(self.file_path)
+
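`FileOpen` above whitelists the open mode and only passes an encoding for text modes. The dispatch can be sketched standalone as a plain function (`safe_open` is a hypothetical stand-in, not the msprobe class, and it omits the permission and ownership checks):

```python
import os
import tempfile

SUPPORT_READ_MODE = ["r", "rb"]
SUPPORT_WRITE_MODE = ["w", "wb", "a", "ab"]
SUPPORT_READ_WRITE_MODE = ["r+", "rb+", "w+", "wb+", "a+", "ab+"]


def safe_open(file_path, mode, encoding="utf-8"):
    """Open a file only when the mode is whitelisted; binary modes take no encoding."""
    if mode not in SUPPORT_READ_MODE + SUPPORT_WRITE_MODE + SUPPORT_READ_WRITE_MODE:
        raise ValueError(f"File open does not support {mode} mode")
    real_path = os.path.realpath(file_path)  # resolve symlinks before opening
    if "b" in mode:
        return open(real_path, mode)
    return open(real_path, mode, encoding=encoding)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "demo.txt")
    with safe_open(path, "w") as f:
        f.write("hello")
    with safe_open(path, "rb") as f:
        payload = f.read()
```

Rejecting unknown modes up front means a typo like `"x"` fails loudly instead of creating a file with unexpected semantics.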
+
+def check_link(path):
+ abs_path = os.path.abspath(path)
+ if os.path.islink(abs_path):
+ logger.error('The file path {} is a soft link.'.format(path))
+ raise FileCheckException(FileCheckException.SOFT_LINK_ERROR)
+
+
+def check_path_length(path, name_length=None):
+ file_max_name_length = name_length if name_length else FileCheckConst.FILE_NAME_LENGTH
+ if len(path) > FileCheckConst.DIRECTORY_LENGTH or \
+ len(os.path.basename(path)) > file_max_name_length:
+ logger.error('The file path length exceeds limit.')
+ raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
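`check_path_length` applies two separate bounds: one on the whole path and one on the basename. The same predicate in isolation (the 4096/255 limits here are illustrative; msprobe reads the real values from `FileCheckConst`):

```python
import os

DIRECTORY_LENGTH = 4096   # illustrative total-path limit
FILE_NAME_LENGTH = 255    # illustrative basename limit


def path_too_long(path, name_length=None):
    """True when either the full path or the basename exceeds its limit."""
    max_name = name_length if name_length else FILE_NAME_LENGTH
    return len(path) > DIRECTORY_LENGTH or len(os.path.basename(path)) > max_name


short_ok = path_too_long("/tmp/short.txt")
long_name = path_too_long("/tmp/" + "a" * 300)
```

The optional `name_length` override mirrors the function above, where callers can tighten the basename bound for specific artifacts.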
+
+
+def check_path_exists(path):
+ if not os.path.exists(path):
+ logger.error('The file path %s does not exist.' % path)
+ raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
+
+
+def check_path_readability(path):
+ if not os.access(path, os.R_OK):
+ logger.error('The file path %s is not readable.' % path)
+ raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
+
+
+def check_path_writability(path):
+ if not os.access(path, os.W_OK):
+ logger.error('The file path %s is not writable.' % path)
+ raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
+
+
+def check_path_executable(path):
+ if not os.access(path, os.X_OK):
+ logger.error('The file path %s is not executable.' % path)
+ raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
+
+
+def check_other_user_writable(path):
+ st = os.stat(path)
+ if st.st_mode & 0o002:
+ logger.error('The file path %s may be insecure because other users have write permissions. ' % path)
+ raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
+
+
+def check_path_owner_consistent(path):
+ file_owner = os.stat(path).st_uid
+ if file_owner != os.getuid():
+ logger.error('The file path %s may be insecure because it does not belong to you.' % path)
+ raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
+
+
+def check_path_pattern_vaild(path):
+ if not re.match(FileCheckConst.FILE_VALID_PATTERN, path):
+ logger.error('The file path %s contains special characters.' % (path))
+ raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
+
+
+def check_file_size(file_path, max_size):
+ try:
+ file_size = os.path.getsize(file_path)
+ except OSError as os_error:
+ logger.error(f'Failed to open "{file_path}". {str(os_error)}')
+ raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) from os_error
+ if file_size >= max_size:
+ logger.error(f'The size ({file_size}) of {file_path} exceeds ({max_size}) bytes, tools not support.')
+ raise FileCheckException(FileCheckException.FILE_TOO_LARGE_ERROR)
+
+
+def check_common_file_size(file_path):
+ if os.path.isfile(file_path):
+ for suffix, max_size in FileCheckConst.FILE_SIZE_DICT.items():
+ if file_path.endswith(suffix):
+ check_file_size(file_path, max_size)
+ break
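`check_common_file_size` picks a size cap by file suffix and stops at the first match. A runnable sketch of that lookup (the suffixes and byte caps below are illustrative; msprobe takes them from `FileCheckConst.FILE_SIZE_DICT`):

```python
# Illustrative suffix-to-cap table, in bytes.
FILE_SIZE_DICT = {
    ".pkl": 1 * 1024 ** 3,
    ".npy": 10 * 1024 ** 3,
    ".json": 1 * 1024 ** 3,
}


def size_cap_for(file_path):
    """Return the cap for the first matching suffix, or None when nothing matches."""
    for suffix, max_size in FILE_SIZE_DICT.items():
        if file_path.endswith(suffix):
            return max_size
    return None
```

Because the loop breaks on the first match, files with no registered suffix are simply skipped rather than rejected.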
+
+
+def check_file_suffix(file_path, file_suffix):
+ if file_suffix:
+ if not file_path.endswith(file_suffix):
+ logger.error(f"The {file_path} should be a {file_suffix} file!")
+ raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
+
+
+def check_path_type(file_path, file_type):
+ if file_type == FileCheckConst.FILE:
+ if not os.path.isfile(file_path):
+ logger.error(f"The {file_path} should be a file!")
+ raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
+ if file_type == FileCheckConst.DIR:
+ if not os.path.isdir(file_path):
+ logger.error(f"The {file_path} should be a directory!")
+ raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
+
+
+def make_dir(dir_path):
+ dir_path = os.path.realpath(dir_path)
+ check_path_before_create(dir_path)
+ if os.path.isdir(dir_path):
+ return
+ try:
+ os.mkdir(dir_path, mode=FileCheckConst.DATA_DIR_AUTHORITY)
+ except OSError as ex:
+ raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR,
+ f"Failed to create {dir_path}. "
+ f"Please check the path permission or disk space. {str(ex)}") from ex
+ file_check = FileChecker(dir_path, FileCheckConst.DIR)
+ file_check.common_check()
+
+
+def create_directory(dir_path):
+ """
+ Function Description:
+ creating a safe directory with specified permissions
+ Parameter:
+ dir_path: directory path
+ Exception Description:
+ when invalid data throw exception
+ """
+ dir_path = os.path.realpath(dir_path)
+ check_path_before_create(dir_path)
+ parent_dir = os.path.dirname(dir_path)
+ if not os.path.isdir(parent_dir):
+ create_directory(parent_dir)
+ make_dir(dir_path)
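`create_directory` recurses up to the first existing parent and then creates each missing level with a restrictive mode. The recursion on its own, with an illustrative `0o750` standing in for `FileCheckConst.DATA_DIR_AUTHORITY` and the per-level safety checks omitted:

```python
import os
import tempfile

DATA_DIR_AUTHORITY = 0o750  # illustrative; subject to the process umask


def make_dirs_recursive(dir_path):
    dir_path = os.path.realpath(dir_path)
    if os.path.isdir(dir_path):
        return
    parent = os.path.dirname(dir_path)
    if not os.path.isdir(parent):
        make_dirs_recursive(parent)  # create missing ancestors first
    os.mkdir(dir_path, mode=DATA_DIR_AUTHORITY)


with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "a", "b", "c")
    make_dirs_recursive(target)
    created = os.path.isdir(target)
```

Recursing level by level (instead of a single `os.makedirs`) is what lets the real code run `FileChecker.common_check` on every directory it creates.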
+
+
+def check_path_before_create(path):
+ if path_len_exceeds_limit(path):
+ raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, 'The file path length exceeds limit.')
+
+ if not re.match(FileCheckConst.FILE_PATTERN, os.path.realpath(path)):
+ raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR,
+ 'The file path {} contains special characters.'.format(path))
+
+
+def check_file_or_directory_path(path, isdir=False):
+ """
+ Function Description:
+ check whether the path is valid
+ Parameter:
+ path: the path to check
+ isdir: the path is dir or file
+ Exception Description:
+ when invalid data throw exception
+ """
+ if isdir:
+ path_checker = FileChecker(path, FileCheckConst.DIR, FileCheckConst.WRITE_ABLE)
+ else:
+ path_checker = FileChecker(path, FileCheckConst.FILE, FileCheckConst.READ_ABLE)
+ path_checker.common_check()
+
+
+def change_mode(path, mode):
+ if not os.path.exists(path) or os.path.islink(path):
+ return
+ try:
+ os.chmod(path, mode)
+ except PermissionError as ex:
+ raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR,
+ 'Failed to change {} authority. {}'.format(path, str(ex))) from ex
+
+
+def path_len_exceeds_limit(file_path):
+ return len(os.path.realpath(file_path)) > FileCheckConst.DIRECTORY_LENGTH or \
+ len(os.path.basename(file_path)) > FileCheckConst.FILE_NAME_LENGTH
+
+
+def check_file_type(path):
+ """
+ Function Description:
+ determine if it is a file or a directory
+ Parameter:
+ path: path
+ Exception Description:
+ when neither a file nor a directory throw exception
+ """
+ if os.path.isdir(path):
+ return FileCheckConst.DIR
+ elif os.path.isfile(path):
+ return FileCheckConst.FILE
+ else:
+ logger.error('Neither a file nor a directory.')
+ raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
+
+
+def load_yaml(yaml_path):
+ path_checker = FileChecker(yaml_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.YAML_SUFFIX)
+ checked_path = path_checker.common_check()
+ try:
+ with FileOpen(checked_path, "r") as f:
+ yaml_data = yaml.safe_load(f)
+ except Exception as e:
+ logger.error(f"The yaml file failed to load. Please check the path: {checked_path}.")
+ raise RuntimeError(f"Load yaml file {checked_path} failed.") from e
+ return yaml_data
+
+
+def load_npy(filepath, enable_pickle=False):
+ check_file_or_directory_path(filepath)
+ try:
+ npy = np.load(filepath, allow_pickle=enable_pickle)
+ except Exception as e:
+ logger.error(f"The numpy file failed to load. Please check the path: {filepath}.")
+ raise RuntimeError(f"Load numpy file {filepath} failed.") from e
+ return npy
+
+
+def load_json(json_path):
+ try:
+ with FileOpen(json_path, "r") as f:
+ fcntl.flock(f, fcntl.LOCK_EX)
+ data = json.load(f)
+ fcntl.flock(f, fcntl.LOCK_UN)
+ except Exception as e:
+ logger.error(f'load json file "{os.path.basename(json_path)}" failed.')
+ raise RuntimeError(f"Load json file {json_path} failed.") from e
+ return data
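`load_json` and `save_json` wrap the file operation in an advisory `fcntl` lock so concurrent readers and writers of the same dump file do not interleave. The locking pattern on its own (Unix-only, since it relies on `fcntl`; the `try/finally` is added here so the lock is released even if parsing fails):

```python
import fcntl
import json
import os
import tempfile


def locked_load_json(json_path):
    with open(json_path, "r", encoding="utf-8") as f:
        fcntl.flock(f, fcntl.LOCK_EX)  # block until we hold the exclusive lock
        try:
            return json.load(f)
        finally:
            fcntl.flock(f, fcntl.LOCK_UN)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "dump.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump({"task": "statistics"}, f)
    data = locked_load_json(path)
```

Note that `flock` locks are advisory: they only coordinate processes that also take the lock, which is exactly the multi-rank dump scenario these helpers target.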
+
+
+def save_json(json_path, data, indent=None):
+ json_path = os.path.realpath(json_path)
+ check_path_before_create(json_path)
+ try:
+ with FileOpen(json_path, 'w') as f:
+ fcntl.flock(f, fcntl.LOCK_EX)
+ json.dump(data, f, indent=indent)
+ fcntl.flock(f, fcntl.LOCK_UN)
+ except Exception as e:
+ logger.error(f'Save json file "{os.path.basename(json_path)}" failed.')
+ raise RuntimeError(f"Save json file {json_path} failed.") from e
+ change_mode(json_path, FileCheckConst.DATA_FILE_AUTHORITY)
+
+
+def move_file(src_path, dst_path):
+ check_file_or_directory_path(src_path)
+ check_path_before_create(dst_path)
+ try:
+ shutil.move(src_path, dst_path)
+ except Exception as e:
+ logger.error(f"move file {src_path} to {dst_path} failed")
+ raise RuntimeError(f"move file {src_path} to {dst_path} failed") from e
+ change_mode(dst_path, FileCheckConst.DATA_FILE_AUTHORITY)
+
+
+def save_npy(data, filepath):
+ filepath = os.path.realpath(filepath)
+ check_path_before_create(filepath)
+ try:
+ np.save(filepath, data)
+ except Exception as e:
+ logger.error(f"The numpy file failed to save. Please check the path: {filepath}.")
+ raise RuntimeError(f"Save numpy file {filepath} failed.") from e
+ change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
+
+
+def save_npy_to_txt(data, dst_file='', align=0):
+ if os.path.exists(dst_file):
+ logger.info("Dst file %s exists, will not save new one.", dst_file)
+ return
+ shape = data.shape
+ data = data.flatten()
+ if align == 0:
+ align = 1 if len(shape) == 0 else shape[-1]
+ elif data.size % align != 0:
+ pad_array = np.zeros((align - data.size % align,))
+ data = np.append(data, pad_array)
+ check_path_before_create(dst_file)
+ try:
+ np.savetxt(dst_file, data.reshape((-1, align)), delimiter=' ', fmt='%g')
+ except Exception as e:
+ logger.error("An unexpected error occurred: %s when savetxt to %s" % (str(e), dst_file))
+ change_mode(dst_file, FileCheckConst.DATA_FILE_AUTHORITY)
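`save_npy_to_txt` pads the flattened tensor with zeros until its length is a multiple of `align`, then reshapes it into `align`-wide rows for `np.savetxt`. The padding arithmetic can be checked in pure Python, no numpy needed:

```python
def pad_length(size, align):
    """Zeros appended so size becomes a multiple of align (0 means no padding)."""
    if align == 0 or size % align == 0:
        return 0
    return align - size % align


def row_count(size, align):
    """Rows in the saved text file after padding."""
    return (size + pad_length(size, align)) // align
```

For example, a 10-element tensor written with `align=4` gains 2 padding zeros and is saved as 3 rows of 4 values.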
+
+
+def save_workbook(workbook, file_path):
+ """
+ Save the workbook to the specified file path.
+ workbook: the workbook object to save
+ file_path: the destination file path
+ """
+ file_path = os.path.realpath(file_path)
+ check_path_before_create(file_path)
+ try:
+ workbook.save(file_path)
+ except Exception as e:
+ logger.error(f'Save result file "{os.path.basename(file_path)}" failed')
+ raise RuntimeError(f"Save result file {file_path} failed.") from e
+ change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
+
+
+def write_csv(data, filepath, mode="a+"):
+ filepath = os.path.realpath(filepath)
+ check_path_before_create(filepath)
+ try:
+ with FileOpen(filepath, mode, encoding='utf-8-sig') as f:
+ writer = csv.writer(f)
+ writer.writerows(data)
+ except Exception as e:
+ logger.error(f'Save csv file "{os.path.basename(filepath)}" failed')
+ raise RuntimeError(f"Save csv file {filepath} failed.") from e
+ change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
+
+
+def remove_path(path):
+ if not os.path.exists(path):
+ return
+ try:
+ if os.path.islink(path) or os.path.isfile(path):
+ os.remove(path)
+ else:
+ shutil.rmtree(path)
+ except PermissionError as err:
+ logger.error("Failed to delete {}. Please check the permission.".format(path))
+ raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) from err
+ except Exception as e:
+ logger.error("Failed to delete {}. Please check.".format(path))
+ raise RuntimeError(f"Delete {path} failed.") from e
+
+
+def get_json_contents(file_path):
+ ops = get_file_content_bytes(file_path)
+ try:
+ json_obj = json.loads(ops)
+ except ValueError as error:
+ logger.error('Failed to load json.')
+ raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) from error
+ if not isinstance(json_obj, dict):
+ logger.error('Json file content is not a dictionary!')
+ raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
+ return json_obj
+
+
+def get_file_content_bytes(file):
+ with FileOpen(file, 'rb') as file_handle:
+ return file_handle.read()
diff --git a/debug/accuracy_tools/atat/core/common/log.py b/debug/accuracy_tools/msprobe/core/common/log.py
similarity index 100%
rename from debug/accuracy_tools/atat/core/common/log.py
rename to debug/accuracy_tools/msprobe/core/common/log.py
diff --git a/debug/accuracy_tools/atat/core/common/utils.py b/debug/accuracy_tools/msprobe/core/common/utils.py
similarity index 72%
rename from debug/accuracy_tools/atat/core/common/utils.py
rename to debug/accuracy_tools/msprobe/core/common/utils.py
index 088530f3c5c88e4e97cd1eda470b02a6a2176fdf..6795bd83f350e2084dc94991cd48c481f38586aa 100644
--- a/debug/accuracy_tools/atat/core/common/utils.py
+++ b/debug/accuracy_tools/msprobe/core/common/utils.py
@@ -15,20 +15,21 @@
# limitations under the License.
"""
import collections
+import fcntl
import os
import re
import shutil
-import stat
import subprocess
import time
import json
+import csv
from datetime import datetime, timezone
-from pathlib import Path
+import yaml
import numpy as np
-from atat.core.common.file_check import FileOpen, FileChecker
-from atat.core.common.const import Const, FileCheckConst, CompareConst, OverflowConst
-from atat.core.common.log import logger
+from msprobe.core.common.file_utils import FileOpen, FileChecker, change_mode
+from msprobe.core.common.const import Const, FileCheckConst, CompareConst
+from msprobe.core.common.log import logger
device = collections.namedtuple('device', ['type', 'index'])
@@ -60,6 +61,8 @@ class CompareException(Exception):
OVER_SIZE_FILE_ERROR = 18
INVALID_SUMMARY_MODE = 19
INVALID_TASK_ERROR = 20
+ DETACH_ERROR = 21
+
def __init__(self, code, error_info: str = ""):
super(CompareException, self).__init__()
@@ -70,21 +73,26 @@ class CompareException(Exception):
return self.error_info
-class DumpException(CompareException):
- pass
+class PathAddException(Exception):
+ """
+ Class for Path Add Exception
+ """
+ INVALID_FILE_ERROR = 0
+ PERMISSION_DENIED_ERROR = 1
+ UNEXPECTED_ERROR = 2
+ PACKAGE_VERSION_CHECK = 3
+ def __init__(self, code, error_info: str = ""):
+ super(PathAddException, self).__init__()
+ self.code = code
+ self.error_info = error_info
-def make_dump_path_if_not_exists(dump_path):
- if not os.path.exists(dump_path):
- try:
- Path(dump_path).mkdir(mode=0o750, exist_ok=True, parents=True)
- except OSError as ex:
- logger.error(
- 'Failed to create {}.Please check the path permission or disk space .{}'.format(dump_path, str(ex)))
- raise CompareException(CompareException.INVALID_PATH_ERROR) from ex
- else:
- if not os.path.isdir(dump_path):
- logger.error('{} already exists and is not a directory.'.format(dump_path))
+ def __str__(self):
+ return self.error_info
+
+
+class DumpException(CompareException):
+ pass
def check_mode_valid(mode, scope=None, api_list=None):
@@ -148,21 +156,24 @@ def check_summary_only_valid(summary_only):
return summary_only
-def check_compare_param(input_parma, output_path, stack_mode=False, summary_compare=False, md5_compare=False):
- if not (isinstance(input_parma, dict) and isinstance(output_path, str)):
+def check_compare_param(input_param, output_path, summary_compare=False, md5_compare=False):
+ if not (isinstance(input_param, dict) and isinstance(output_path, str)):
logger.error("Invalid input parameters")
raise CompareException(CompareException.INVALID_PARAM_ERROR)
- check_file_or_directory_path(input_parma.get("npu_json_path"), False)
- check_file_or_directory_path(input_parma.get("bench_json_path"), False)
- check_file_or_directory_path(input_parma.get("stack_json_path"), False)
+
+ check_file_or_directory_path(input_param.get("npu_json_path"), False)
+ check_file_or_directory_path(input_param.get("bench_json_path"), False)
+ check_file_or_directory_path(input_param.get("stack_json_path"), False)
if not summary_compare and not md5_compare:
- check_file_or_directory_path(input_parma.get("npu_dump_data_dir"), True)
- check_file_or_directory_path(input_parma.get("bench_dump_data_dir"), True)
+ check_file_or_directory_path(input_param.get("npu_dump_data_dir"), True)
+ check_file_or_directory_path(input_param.get("bench_dump_data_dir"), True)
check_file_or_directory_path(output_path, True)
- with FileOpen(input_parma.get("npu_json_path"), "r") as npu_json, \
- FileOpen(input_parma.get("bench_json_path"), "r") as bench_json, \
- FileOpen(input_parma.get("stack_json_path"), "r") as stack_json:
- check_json_file(input_parma, npu_json, bench_json, stack_json)
+
+ with FileOpen(input_param.get("npu_json_path"), "r") as npu_json, \
+ FileOpen(input_param.get("bench_json_path"), "r") as bench_json, \
+ FileOpen(input_param.get("stack_json_path"), "r") as stack_json:
+ check_json_file(input_param, npu_json, bench_json, stack_json)
+
def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False):
@@ -257,6 +268,17 @@ def remove_path(path):
raise CompareException(CompareException.INVALID_PATH_ERROR) from err
+def move_file(src_path, dst_path):
+ check_file_or_directory_path(src_path)
+ check_path_before_create(dst_path)
+ try:
+ shutil.move(src_path, dst_path)
+ except Exception as e:
+ logger.error(f"move file {src_path} to {dst_path} failed")
+ raise RuntimeError(f"move file {src_path} to {dst_path} failed") from e
+ change_mode(dst_path, FileCheckConst.DATA_FILE_AUTHORITY)
+
+
def get_dump_data_path(dump_dir):
"""
Function Description:
@@ -279,24 +301,6 @@ def get_dump_data_path(dump_dir):
return dump_data_path, file_is_exist
-def create_directory(dir_path):
- """
- Function Description:
- creating a directory with specified permissions
- Parameter:
- dir_path: directory path
- Exception Description:
- when invalid data throw exception
- """
- if not os.path.exists(dir_path):
- try:
- os.makedirs(dir_path, mode=0o700)
- except OSError as ex:
- logger.error(
- 'Failed to create {}.Please check the path permission or disk space .{}'.format(dir_path, str(ex)))
- raise CompareException(CompareException.INVALID_PATH_ERROR) from ex
-
-
def execute_command(cmd):
"""
Function Description:
@@ -318,15 +322,6 @@ def execute_command(cmd):
raise CompareException(CompareException.INVALID_DATA_ERROR)
-def save_numpy_data(file_path, data):
- """
- save_numpy_data
- """
- if not os.path.exists(os.path.dirname(file_path)):
- os.makedirs(os.path.dirname(file_path))
- np.save(file_path, data)
-
-
def parse_value_by_comma(value):
"""
parse value by comma, like '1,2,4,8'
@@ -472,14 +467,14 @@ def md5_find(data):
def task_dumppath_get(input_param):
- npu_json_path = input_param.get("npu_json_path", None)
- bench_json_path = input_param.get("bench_json_path", None)
- if not npu_json_path or not bench_json_path:
+ npu_path = input_param.get("npu_json_path", None)
+ bench_path = input_param.get("bench_json_path", None)
+ if not npu_path or not bench_path:
logger.error(f"Please check the json path is valid.")
raise CompareException(CompareException.INVALID_PATH_ERROR)
- with FileOpen(npu_json_path, 'r') as npu_f:
+ with FileOpen(npu_path, 'r') as npu_f:
npu_json_data = json.load(npu_f)
- with FileOpen(bench_json_path, 'r') as bench_f:
+ with FileOpen(bench_path, 'r') as bench_f:
bench_json_data = json.load(bench_f)
if npu_json_data['task'] != bench_json_data['task']:
logger.error(f"Please check the dump task is consistent.")
@@ -496,8 +491,8 @@ def task_dumppath_get(input_param):
else:
logger.error(f"Compare is not required for overflow_check or free_benchmark.")
raise CompareException(CompareException.INVALID_TASK_ERROR)
- input_param['npu_dump_data_dir'] = npu_json_data['dump_data_dir']
- input_param['bench_dump_data_dir'] = bench_json_data['dump_data_dir']
+ input_param['npu_dump_data_dir'] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA)
+ input_param['bench_dump_data_dir'] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA)
return summary_compare, md5_compare
@@ -514,3 +509,122 @@ def get_header_index(header_name, summary_compare=False):
def convert_tuple(data):
return data if isinstance(data, tuple) else (data, )
+
+
+def write_csv(data, filepath, mode="a+"):
+ exist = os.path.exists(filepath)
+ with FileOpen(filepath, mode, encoding='utf-8-sig') as f:
+ writer = csv.writer(f)
+ writer.writerows(data)
+ if not exist:
+ change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
+
+
+def load_npy(filepath, enable_pickle=False):
+ check_file_or_directory_path(filepath)
+ try:
+ npy = np.load(filepath, allow_pickle=enable_pickle)
+ except Exception as e:
+ logger.error(f"The numpy file failed to load. Please check the path: {filepath}.")
+ raise RuntimeError(f"Load numpy file {filepath} failed.") from e
+ return npy
+
+
+def save_npy(data, filepath):
+ filepath = os.path.realpath(filepath)
+ check_path_before_create(filepath)
+ try:
+ np.save(filepath, data)
+ except Exception as e:
+ logger.error(f"The numpy file failed to save. Please check the path: {filepath}.")
+ raise RuntimeError(f"Save numpy file {filepath} failed.") from e
+ change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
+
+def save_npy_to_txt(data, dst_file='', align=0):
+ if os.path.exists(dst_file):
+ logger.info("Dst file %s exists, will not save new one.", dst_file)
+ return
+ shape = data.shape
+ data = data.flatten()
+ if align == 0:
+ align = 1 if len(shape) == 0 else shape[-1]
+ elif data.size % align != 0:
+ pad_array = np.zeros((align - data.size % align,))
+ data = np.append(data, pad_array)
+ check_path_before_create(dst_file)
+ try:
+ np.savetxt(dst_file, data.reshape((-1, align)), delimiter=' ', fmt='%g')
+ except Exception as e:
+ logger.error("An unexpected error occurred: %s when savetxt to %s" % (str(e), dst_file))
+ change_mode(dst_file, FileCheckConst.DATA_FILE_AUTHORITY)
+
+def get_json_contents(file_path):
+ ops = get_file_content_bytes(file_path)
+ try:
+ json_obj = json.loads(ops)
+ except ValueError as error:
+ logger.error('Failed to load json.')
+ raise CompareException(CompareException.INVALID_FILE_ERROR) from error
+ if not isinstance(json_obj, dict):
+ logger.error('Json file content is not a dictionary!')
+ raise CompareException(CompareException.INVALID_FILE_ERROR)
+ return json_obj
+
+
+def get_file_content_bytes(file):
+ with FileOpen(file, 'rb') as file_handle:
+ return file_handle.read()
+
+
+def load_yaml(yaml_path):
+ path_checker = FileChecker(yaml_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.YAML_SUFFIX)
+ checked_path = path_checker.common_check()
+ try:
+ with FileOpen(checked_path, "r") as f:
+ yaml_data = yaml.safe_load(f)
+ except Exception as e:
+ logger.error(f"The yaml file failed to load. Please check the path: {checked_path}.")
+ raise RuntimeError(f"Load yaml file {checked_path} failed.") from e
+ return yaml_data
+
+
+def save_workbook(workbook, file_path):
+ """
+ Save the workbook to the specified file path.
+ workbook: the workbook object to save
+ file_path: the destination file path
+ """
+ file_path = os.path.realpath(file_path)
+ check_path_before_create(file_path)
+ try:
+ workbook.save(file_path)
+ except Exception as e:
+ logger.error(f'Save result file "{os.path.basename(file_path)}" failed')
+ raise CompareException(CompareException.WRITE_FILE_ERROR) from e
+ change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
+
+
+def load_json(json_path):
+ try:
+ with FileOpen(json_path, "r") as f:
+ fcntl.flock(f, fcntl.LOCK_EX)
+ data = json.load(f)
+ fcntl.flock(f, fcntl.LOCK_UN)
+ except Exception as e:
+ logger.error(f'load json file "{os.path.basename(json_path)}" failed.')
+ raise DumpException(DumpException.WRITE_FILE_ERROR) from e
+ return data
+
+
+def save_json(json_path, data, indent=None):
+ json_path = os.path.realpath(json_path)
+ check_path_before_create(json_path)
+ try:
+ with FileOpen(json_path, 'w') as f:
+ fcntl.flock(f, fcntl.LOCK_EX)
+ json.dump(data, f, indent=indent)
+ fcntl.flock(f, fcntl.LOCK_UN)
+ except Exception as e:
+ logger.error(f'Save json file "{os.path.basename(json_path)}" failed.')
+ raise DumpException(DumpException.WRITE_FILE_ERROR) from e
+ change_mode(json_path, FileCheckConst.DATA_FILE_AUTHORITY)
diff --git a/debug/accuracy_tools/atat/core/common_config.py b/debug/accuracy_tools/msprobe/core/common_config.py
similarity index 50%
rename from debug/accuracy_tools/atat/core/common_config.py
rename to debug/accuracy_tools/msprobe/core/common_config.py
index e256372ca877be1ca5474dd87c00decdf9d3c1c1..3c92d6e766f568ff012cc3358ec4d0d6064aa1d2 100644
--- a/debug/accuracy_tools/atat/core/common_config.py
+++ b/debug/accuracy_tools/msprobe/core/common_config.py
@@ -1,6 +1,6 @@
-from atat.core.common.const import Const
-from atat.core.common.log import logger
-from atat.core.common.exceptions import MsaccException
+from msprobe.core.common.const import Const
+from msprobe.core.common.log import logger
+from msprobe.core.common.exceptions import MsprobeException
class CommonConfig:
@@ -14,28 +14,35 @@ class CommonConfig:
self.acl_config = json_config.get('acl_config')
self.is_deterministic = json_config.get('is_deterministic', False)
self.enable_dataloader = json_config.get('enable_dataloader', False)
+ self.enable_step_auto_dump = json_config.get('enable_step_auto_dump', False)
self._check_config()
def _check_config(self):
if self.task and self.task not in Const.TASK_LIST:
- logger.error_log_with_exp(
- "task is invalid, it should be one of {}".format(Const.TASK_LIST), MsaccException(MsaccException.INVALID_PARAM_ERROR))
+ logger.error_log_with_exp("task is invalid, it should be one of {}".format(Const.TASK_LIST),
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
if self.rank is not None and not isinstance(self.rank, list):
- logger.error_log_with_exp("rank is invalid, it should be a list", MsaccException(MsaccException.INVALID_PARAM_ERROR))
+ logger.error_log_with_exp("rank is invalid, it should be a list",
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
if self.step is not None and not isinstance(self.step, list):
- logger.error_log_with_exp("step is invalid, it should be a list", MsaccException(MsaccException.INVALID_PARAM_ERROR))
+ logger.error_log_with_exp("step is invalid, it should be a list",
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
if self.level and self.level not in Const.LEVEL_LIST:
- logger.error_log_with_exp(
- "level is invalid, it should be one of {}".format(Const.LEVEL_LIST), MsaccException(MsaccException.INVALID_PARAM_ERROR))
+ logger.error_log_with_exp("level is invalid, it should be one of {}".format(Const.LEVEL_LIST),
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
if self.seed is not None and not isinstance(self.seed, int):
- logger.error_log_with_exp("seed is invalid, it should be an integer", MsaccException(MsaccException.INVALID_PARAM_ERROR))
+ logger.error_log_with_exp("seed is invalid, it should be an integer",
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
if not isinstance(self.is_deterministic, bool):
- logger.error_log_with_exp(
- "is_deterministic is invalid, it should be a boolean", MsaccException(MsaccException.INVALID_PARAM_ERROR))
+ logger.error_log_with_exp("is_deterministic is invalid, it should be a boolean",
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
if not isinstance(self.enable_dataloader, bool):
- logger.error_log_with_exp(
- "enable_dataloader is invalid, it should be a boolean", MsaccException(MsaccException.INVALID_PARAM_ERROR))
-
+ logger.error_log_with_exp("enable_dataloader is invalid, it should be a boolean",
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
+ if not isinstance(self.enable_step_auto_dump, bool):
+ logger.error_log_with_exp("enable_step_auto_dump is invalid, it should be a boolean",
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
+
class BaseConfig:
def __init__(self, json_config):
@@ -44,15 +51,17 @@ class BaseConfig:
self.data_mode = json_config.get('data_mode')
self.backward_input = json_config.get("backward_input")
self.file_format = json_config.get("file_format")
- self.summary_mode = json_config.get("summary_mode")
- self.overflow_num = json_config.get("overflow_num")
+ self.summary_mode = json_config.get("summary_mode")
+ self.overflow_nums = json_config.get("overflow_nums")
self.check_mode = json_config.get("check_mode")
def check_config(self):
if self.scope is not None and not isinstance(self.scope, list):
- logger.error_log_with_exp("scope is invalid, it should be a list", MsaccException(MsaccException.INVALID_PARAM_ERROR))
+ logger.error_log_with_exp("scope is invalid, it should be a list",
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
if self.list is not None and not isinstance(self.list, list):
- logger.error_log_with_exp("list is invalid, it should be a list", MsaccException(MsaccException.INVALID_PARAM_ERROR))
+ logger.error_log_with_exp("list is invalid, it should be a list",
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
if self.data_mode is not None and not isinstance(self.data_mode, list):
- logger.error_log_with_exp("data_mode is invalid, it should be a list", MsaccException(MsaccException.INVALID_PARAM_ERROR))
-
+ logger.error_log_with_exp("data_mode is invalid, it should be a list",
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
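The new `enable_step_auto_dump` switch follows the same validation pattern as the existing fields. A standalone sketch of the rules `_check_config` enforces; the `TASK_LIST`/`LEVEL_LIST` values here are illustrative placeholders, not msprobe's real constants:

```python
# Sketch of CommonConfig._check_config as a plain function.
# TASK_LIST / LEVEL_LIST values are assumptions for illustration only.
TASK_LIST = ["statistics", "tensor", "overflow_check", "free_benchmark"]
LEVEL_LIST = ["L0", "L1", "L2", "mix"]

def check_common_config(json_config):
    task = json_config.get("task")
    if task and task not in TASK_LIST:
        raise ValueError(f"task is invalid, it should be one of {TASK_LIST}")
    # rank and step must be lists when present
    for key in ("rank", "step"):
        value = json_config.get(key)
        if value is not None and not isinstance(value, list):
            raise ValueError(f"{key} is invalid, it should be a list")
    level = json_config.get("level")
    if level and level not in LEVEL_LIST:
        raise ValueError(f"level is invalid, it should be one of {LEVEL_LIST}")
    seed = json_config.get("seed")
    if seed is not None and not isinstance(seed, int):
        raise ValueError("seed is invalid, it should be an integer")
    # the three switches default to False and must stay boolean
    for key in ("is_deterministic", "enable_dataloader", "enable_step_auto_dump"):
        if not isinstance(json_config.get(key, False), bool):
            raise ValueError(f"{key} is invalid, it should be a boolean")
    return True

check_common_config({"task": "statistics", "rank": [0, 1], "step": [2], "level": "L1"})
```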
diff --git a/debug/accuracy_tools/msprobe/core/compare/utils.py b/debug/accuracy_tools/msprobe/core/compare/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cf02b953f6898e753139b3e0801f4dcf3799db1
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/core/compare/utils.py
@@ -0,0 +1,430 @@
+
+import os
+import re
+import numpy as np
+from msprobe.core.common.const import Const, CompareConst
+from msprobe.core.common.utils import CompareException, check_regex_prefix_format_valid, logger
+from msprobe.core.common.file_utils import check_file_or_directory_path
+
+
+def extract_json(dirname, stack_json=False):
+ json_path = ''
+ for fname in os.listdir(dirname):
+ if fname == "construct.json":
+ continue
+ full_path = os.path.join(dirname, fname)
+ if full_path.endswith('.json'):
+ json_path = full_path
+ if not stack_json and 'stack' not in json_path:
+ break
+ if stack_json and 'stack' in json_path:
+ break
+
+ # Guard against directories that contain no dump json file
+ if not json_path:
+ logger.error(f'No dump json file found in directory {dirname}.')
+ raise CompareException(CompareException.NO_DUMP_FILE_ERROR)
+ return json_path
+
+
+def check_and_return_dir_contents(dump_dir, prefix):
+ """
+ check the given dump dir and validate files in dump dir by using the given prefix patterns to build a
+ pattern: ^{prefix}(?:0|[0-9][1-9]*)?$
+
+ Args:
+ dump_dir (str): dump dir
+ prefix (str): prefix for the patterns, prefix should be less than 20 characters and alphanumeric/-/_ only
+
+ Returns:
+ content [list]: dir contents
+ Raises:
+ CompareException: invalid path
+ ValueError: prefix not match the patterns
+
+ """
+ check_regex_prefix_format_valid(prefix)
+ check_file_or_directory_path(dump_dir, True)
+ contents = os.listdir(dump_dir)
+ pattern = re.compile(rf'^{prefix}(?:0|[0-9][1-9]*)?$')
+ for name in contents:
+ if not pattern.match(name):
+ logger.error(
+ f"dump_dir contains '{name}'. Expected '{prefix}'. This name is not in the format of dump "
+ f"output. Please check and delete irrelevant files in {dump_dir} and try again."
+ )
+ raise CompareException(CompareException.INVALID_PATH_ERROR)
+ return contents
+
+
+def rename_api(npu_name, process):
+ npu_split = npu_name.split(process)
+ torch_func_index, in_out = npu_split[0], npu_split[1]
+ torch_func_split = torch_func_index.rsplit(Const.SEP, 2)
+ torch_func = str(torch_func_split[0]) + str(in_out)
+ return torch_func
+
+
+def read_op(op_data, op_name):
+ op_parsed_list = []  # avoid aliasing a shared Const.DEFAULT_LIST, which += below would mutate
+ if Const.FORWARD in op_name:
+ if Const.INPUT_ARGS in op_data:
+ input_item = op_data[Const.INPUT_ARGS]
+ input_parsed_list = op_item_parse(input_item, op_name + '.input', None)
+ op_parsed_list = input_parsed_list.copy()
+ input_parsed_list.clear()
+ if Const.INPUT_KWARGS in op_data:
+ kwargs_item = op_data[Const.INPUT_KWARGS]
+ if (isinstance(kwargs_item, dict) and "type" in kwargs_item) or isinstance(kwargs_item, list):
+ kwarg_parsed_list = op_item_parse(kwargs_item, op_name + '.input', None)
+ op_parsed_list += kwarg_parsed_list
+ kwarg_parsed_list.clear()
+ elif kwargs_item:
+ for kwarg in kwargs_item:
+ kwarg_parsed_list = op_item_parse(kwargs_item[kwarg], op_name + '.input.' + kwarg, None)
+ op_parsed_list += kwarg_parsed_list
+ kwarg_parsed_list.clear()
+ if Const.OUTPUT in op_data:
+ output_item = op_data[Const.OUTPUT]
+ output_parsed_list = op_item_parse(output_item, op_name + '.output', None)
+ op_parsed_list += output_parsed_list
+ output_parsed_list.clear()
+ if Const.BACKWARD in op_name:
+ if Const.INPUT in op_data:
+ input_item = op_data[Const.INPUT]
+ input_parsed_list = op_item_parse(input_item, op_name + '.input', None)
+ op_parsed_list = input_parsed_list.copy()
+ input_parsed_list.clear()
+ if Const.OUTPUT in op_data:
+ output_item = op_data[Const.OUTPUT]
+ output_parsed_list = op_item_parse(output_item, op_name + '.output', None)
+ op_parsed_list += output_parsed_list
+ output_parsed_list.clear()
+ return op_parsed_list
+
+
+def op_item_parse(item, op_name, index, item_list=None, top_bool=True):
+ if item_list is None:
+ item_list = []
+ if item is None or (isinstance(item, dict) and not item):
+ if not top_bool:
+ tmp = {'full_op_name': op_name + '.' + str(index), 'Max': None, 'Min': None, 'Mean': None, 'Norm': None,
+ 'dtype': None, 'shape': None, 'md5': None, 'data_name': '-1'}
+ else:
+ tmp = {'full_op_name': op_name + '.0', 'Max': None, 'Min': None, 'Mean': None, 'Norm': None, 'dtype': None,
+ 'shape': None, 'md5': None, 'data_name': '-1'}
+ item_list.append(tmp)
+ return item_list
+ if index is None:
+ if isinstance(item, dict):
+ full_op_name = op_name + '.0'
+ else:
+ full_op_name = op_name
+ else:
+ full_op_name = op_name + Const.SEP + str(index)
+ if isinstance(item, dict):
+ if 'type' not in item:
+ for kwarg in item:
+ kwarg_parsed_list = op_item_parse(item[kwarg], op_name + Const.SEP + kwarg, None)
+ item_list += kwarg_parsed_list
+ kwarg_parsed_list.clear()
+ elif 'dtype' in item:
+ parsed_item = item
+ parsed_item['full_op_name'] = full_op_name
+ item_list.append(parsed_item)
+ elif 'type' in item:
+ parsed_item = {}
+ if item['type'] == 'torch.Size':
+ parsed_item['full_op_name'] = full_op_name
+ parsed_item['dtype'] = 'torch.Size'
+ parsed_item['shape'] = str(item['value'])
+ parsed_item['md5'] = None
+ parsed_item['Max'] = None
+ parsed_item['Min'] = None
+ parsed_item['Mean'] = None
+ parsed_item['Norm'] = None
+ parsed_item['data_name'] = '-1'
+ item_list.append(parsed_item)
+ elif item['type'] == 'slice':
+ parsed_item['full_op_name'] = full_op_name
+ parsed_item['dtype'] = 'slice'
+ parsed_item['shape'] = str(np.shape(np.array(item['value'])))
+ parsed_item['md5'] = None
+ parsed_item['Max'] = None
+ parsed_item['Min'] = None
+ parsed_item['Mean'] = None
+ parsed_item['Norm'] = None
+ parsed_item['data_name'] = '-1'
+ item_list.append(parsed_item)
+ else:
+ parsed_item['full_op_name'] = full_op_name
+ parsed_item['dtype'] = str(type(item['value']))
+ parsed_item['shape'] = '[]'
+ parsed_item['md5'] = None
+ parsed_item['Max'] = item['value']
+ parsed_item['Min'] = item['value']
+ parsed_item['Mean'] = item['value']
+ parsed_item['Norm'] = item['value']
+ parsed_item['data_name'] = '-1'
+ item_list.append(parsed_item)
+ else:
+ resolve_api_special_parameters(item, full_op_name, item_list)
+ else:
+ for j, item_spec in enumerate(item):
+ op_item_parse(item_spec, full_op_name, j, item_list=item_list, top_bool=False)
+ return item_list
+
+
+def resolve_api_special_parameters(data_dict, full_op_name, item_list):
+ """
+ Function Description:
+ 解析下面格式的数据, 是api参数的一种特殊格式
+ {
+ "last_hidden_state": {
+ "type": "torch.Tensor",
+ "dtype": "torch.bfloat16",
+ ...
+ },
+ "loss": {
+ "type": "torch.Tensor",
+ "dtype": "torch.float32",
+ ...
+ }
+ }
+ Parameter:
+ data_dict: 字典格式的数据
+ full_op_name: 参数的全名字符串
+ item_list: 参数信息集合
+ """
+ for key, value in data_dict.items():
+ if isinstance(value, dict):
+ parsed_item = value
+ parts = full_op_name.split(Const.SEP)
+ parts.insert(-1, key)
+ full_op_name_new = ".".join(parts)
+ parsed_item['full_op_name'] = full_op_name_new
+ item_list.append(parsed_item)
+
+
+def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=False):
+ def get_accuracy_core(n_start, n_len, b_start, b_len, key):
+ min_len = min(n_len, b_len)
+ npu_stack_info = n_dict.get("stack_info", None)
+ bench_stack_info = b_dict.get("stack_info", None)
+ has_stack = npu_stack_info and bench_stack_info
+
+ all_mode_bool = not (summary_compare or md5_compare)
+ if all_mode_bool:
+ npu_data_name = n_dict.get("data_name", None)
+ bench_data_name = b_dict.get("data_name", None)
+
+ for index in range(min_len):
+
+ n_name = n_dict['op_name'][n_start + index]
+ b_name = b_dict['op_name'][b_start + index]
+ n_struct = n_dict[key][index]
+ b_struct = b_dict[key][index]
+ err_msg = ""
+ if md5_compare:
+ result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
+ n_struct[2], b_struct[2],
+ CompareConst.PASS if n_struct[2] == b_struct[2] else CompareConst.DIFF]
+ if has_stack and index == 0 and key == "input_struct":
+ result_item.extend(npu_stack_info)
+ else:
+ result_item.append(CompareConst.NONE)
+ result.append(result_item)
+ continue
+
+ if summary_compare:
+ result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
+ " ", " ", " ", " ", " ", " ", " ", " "]
+ else:
+ result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
+ " ", " ", " ", " ", " "]
+
+ npu_summary_data = n_dict.get("summary")[n_start + index]
+ result_item.extend(npu_summary_data)
+ bench_summary_data = b_dict.get("summary")[b_start + index]
+ result_item.extend(bench_summary_data)
+
+ if summary_compare:
+ start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF)
+ warning_flag = False
+ for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
+ if isinstance(npu_val, (float, int)) and isinstance(bench_val, (float, int)):
+ diff = npu_val - bench_val
+ if bench_val != 0:
+ relative = str(abs((diff / bench_val) * 100)) + '%'
+ else:
+ relative = "N/A"
+ result_item[start_idx + i] = diff
+ result_item[start_idx + i + 4] = relative
+ magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + 1e-10)
+ if magnitude_diff > 0.5:
+ warning_flag = True
+ else:
+ result_item[start_idx + i] = CompareConst.NONE
+ accuracy_check = CompareConst.WARNING if warning_flag else ""
+ err_msg += "Need double check api accuracy." if warning_flag else ""
+ for i in range(start_idx, len(result_item)):
+ if str(result_item[i]) in ('inf', '-inf', 'nan'):
+ result_item[i] = f'{result_item[i]}\t'
+
+ result_item.append(accuracy_check if summary_compare else CompareConst.ACCURACY_CHECK_YES)
+ result_item.append(err_msg)
+ if has_stack and index == 0 and key == "input_struct":
+ result_item.extend(npu_stack_info)
+ else:
+ result_item.append(CompareConst.NONE)
+ if all_mode_bool:
+ result_item.append(npu_data_name[n_start + index])
+
+ result.append(result_item)
+
+ if n_len > b_len:
+ for index in range(b_len, n_len):
+ n_name = n_dict['op_name'][n_start + index]
+ n_struct = n_dict[key][index]
+ if md5_compare:
+ result_item = [n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN,
+ n_struct[1], CompareConst.NAN, n_struct[2], CompareConst.NAN, CompareConst.NAN]
+ result.append(result_item)
+ continue
+ result_item = [n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN,
+ n_struct[1], CompareConst.NAN, " ", " ", " ", " ", " "]
+ summary_data = n_dict.get("summary")[n_start + index]
+ result_item.extend(summary_data)
+ summary_data = [CompareConst.NAN for _ in range(len(n_dict.get("summary")[0]))]
+ result_item.extend(summary_data)
+
+ err_msg = ""
+ result_item.append(CompareConst.ACCURACY_CHECK_YES)
+ result_item.append(err_msg)
+
+ if has_stack and index == 0 and key == "input_struct":
+ result_item.extend(npu_stack_info)
+ else:
+ result_item.append(CompareConst.NONE)
+ if all_mode_bool:
+ result_item.append(npu_data_name[n_start + index])
+
+ result.append(result_item)
+
+ n_num = len(n_dict['op_name'])
+ b_num = len(b_dict['op_name'])
+ n_num_input = len([name for name in n_dict['op_name'] if Const.INPUT in name])
+ b_num_input = len([name for name in b_dict['op_name'] if Const.INPUT in name])
+ n_num_kwarg = len([name for name in n_dict['op_name'] if 'kwarg' in name])
+ b_num_kwarg = len([name for name in b_dict['op_name'] if 'kwarg' in name])
+ n_num_output = n_num - n_num_input - n_num_kwarg
+ b_num_output = b_num - b_num_input - b_num_kwarg
+ get_accuracy_core(0, n_num_input, 0, b_num_input, 'input_struct')
+ get_accuracy_core(n_num_input, n_num_kwarg, b_num_input, b_num_kwarg, "kwargs_struct")
+ get_accuracy_core(n_num_input + n_num_kwarg, n_num_output, b_num_input + b_num_kwarg, b_num_output, 'output_struct')
+
+
+def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare):
+ index_out = 0
+ npu_stack_info = n_dict.get("stack_info", None)
+ bench_name, bench_type, bench_shape = CompareConst.N_A, CompareConst.N_A, CompareConst.N_A
+ err_msg = CompareConst.NO_BENCH
+ accuracy_check_res = CompareConst.N_A
+ for index, n_name in enumerate(n_dict["op_name"]):
+ if n_name.find("input") != -1:
+ n_struct = n_dict["input_struct"][index]
+ else:
+ n_struct = n_dict["output_struct"][index_out]
+ index_out += 1
+
+ result_item = [n_name, bench_name, n_struct[0], bench_type, n_struct[1], bench_shape]
+ if md5_compare:
+ result_item.extend([CompareConst.N_A] * 3)
+ if npu_stack_info and index == 0:
+ result_item.extend(npu_stack_info)
+ else:
+ result_item.append(CompareConst.NONE)
+ result.append(result_item)
+ continue
+ if summary_compare:
+ result_item.extend([CompareConst.N_A] * 8)
+ else:
+ result_item.extend([CompareConst.N_A] * 5)
+ npu_summary_data = n_dict.get("summary")[index]
+ result_item.extend(npu_summary_data)
+ bench_summary_data = [CompareConst.N_A] * 4
+ result_item.extend(bench_summary_data)
+ result_item.append(accuracy_check_res)
+ result_item.append(err_msg)
+ if npu_stack_info and index == 0:
+ result_item.extend(npu_stack_info)
+ else:
+ result_item.append(CompareConst.NONE)
+ if not md5_compare and not summary_compare and result_item[1] == CompareConst.N_A:
+ result_item.extend(["-1"])
+ result.append(result_item)
+
+
+def merge_tensor(tensor_list, summary_compare, md5_compare):
+ op_dict = {}
+ op_dict["op_name"] = []
+ op_dict["input_struct"] = []
+ op_dict["kwargs_struct"] = []
+ op_dict["output_struct"] = []
+ op_dict["summary"] = []
+ op_dict["stack_info"] = []
+
+ all_mode_bool = not (summary_compare or md5_compare)
+ if all_mode_bool:
+ op_dict["data_name"] = []
+
+ for tensor in tensor_list:
+ if len(tensor) == 2:
+ op_dict['stack_info'].append(tensor['full_info'])
+ break
+ op_dict["op_name"].append(tensor['full_op_name'])
+ if not md5_compare:
+ if tensor['full_op_name'].find("input") != -1:
+ op_dict["input_struct"].append((tensor['dtype'], tensor['shape']))
+ elif tensor['full_op_name'].find("kwarg") != -1:
+ op_dict["kwargs_struct"].append((tensor['dtype'], tensor['shape']))
+ elif tensor['full_op_name'].find("output") != -1:
+ op_dict["output_struct"].append((tensor['dtype'], tensor['shape']))
+ else:
+ if tensor['full_op_name'].find("input") != -1:
+ op_dict["input_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5']))
+ elif tensor['full_op_name'].find("kwarg") != -1:
+ op_dict["kwargs_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5']))
+ elif tensor['full_op_name'].find("output") != -1:
+ op_dict["output_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5']))
+
+ op_dict["summary"].append([tensor['Max'], tensor['Min'], tensor['Mean'], tensor['Norm']])
+
+ if all_mode_bool:
+ op_dict["data_name"].append(tensor['data_name'])
+
+ if not op_dict["kwargs_struct"]:
+ del op_dict["kwargs_struct"]
+ return op_dict if op_dict["op_name"] else {}
+
+
+def _compare_parser(parser):
+ parser.add_argument("-i", "--input_path", dest="input_path", type=str,
+ help=" The compare input path, a dict json.", required=True)
+ parser.add_argument("-o", "--output_path", dest="output_path", type=str,
+ help=" The compare task result out path.", required=True)
+ parser.add_argument("-s", "--stack_mode", dest="stack_mode", action="store_true",
+ help=" Whether to save stack info.", required=False)
+ parser.add_argument("-c", "--compare_only", dest="compare_only", action="store_true",
+ help=" Whether to give advisor.", required=False)
+ parser.add_argument("-f", "--fuzzy_match", dest="fuzzy_match", action="store_true",
+ help=" Whether to perform a fuzzy match on the api name.", required=False)
+ parser.add_argument("-cm", "--cell_mapping", dest="cell_mapping", type=str, nargs='?', const=True,
+ help=" The cell mapping file path.", required=False)
+ parser.add_argument("-am", "--api_mapping", dest="api_mapping", type=str, nargs='?', const=True,
+ help=" The api mapping file path.", required=False)
+
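The helpers in the new `compare/utils.py` can be exercised in isolation. A sketch of the rank-directory pattern from `check_and_return_dir_contents` and of `rename_api`, assuming `Const.SEP` is `"."`; note that the pattern as written also rejects names like `rank10` (the second digit class excludes 0), which may or may not be intended:

```python
import re

# The directory-name pattern built for prefix "rank".
pattern = re.compile(r'^rank(?:0|[0-9][1-9]*)?$')
assert pattern.match("rank")        # bare prefix is allowed
assert pattern.match("rank0")
assert pattern.match("rank12")
assert not pattern.match("rank10")  # quirk: trailing 0 is rejected by [1-9]*
assert not pattern.match("rank_tmp")  # stray files are rejected

# rename_api strips the call index from a dumped api name ("." assumed as Const.SEP).
def rename_api(npu_name, process):
    torch_func_index, in_out = npu_name.split(process)
    return torch_func_index.rsplit(".", 2)[0] + in_out

print(rename_api("Functional.conv2d.0.forward.input.0", "forward"))
# -> Functional.conv2d.input.0
```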
diff --git a/debug/accuracy_tools/atat/core/data_dump/data_collector.py b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py
similarity index 85%
rename from debug/accuracy_tools/atat/core/data_dump/data_collector.py
rename to debug/accuracy_tools/msprobe/core/data_dump/data_collector.py
index f6a9a70b138f1f58c777c5eacf1168b628de3f38..db437539afeb98050ce59aad87a1e79d98b84085 100644
--- a/debug/accuracy_tools/atat/core/data_dump/data_collector.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py
@@ -1,11 +1,10 @@
-
import os
-from atat.core.data_dump.scope import build_scope, ListScope
-from atat.core.data_dump.json_writer import DataWriter
-from atat.core.common.log import logger
-from atat.core.common.const import Const
-from atat.core.data_dump.data_processor.factory import DataProcessorFactory
+from msprobe.core.data_dump.scope import build_scope, ListScope
+from msprobe.core.data_dump.json_writer import DataWriter
+from msprobe.core.common.log import logger
+from msprobe.core.common.const import Const
+from msprobe.core.data_dump.data_processor.factory import DataProcessorFactory
def build_data_collector(config):
@@ -21,7 +20,8 @@ class DataCollector:
self.config = config
self.data_writer = DataWriter()
self.data_processor = DataProcessorFactory.create_processor(self.config, self.data_writer)
- self.module_processor = DataProcessorFactory.get_module_processor(self.config.framework) if self.config.framework == Const.PT_FRAMEWORK else None
+ self.module_processor = DataProcessorFactory.get_module_processor(self.config.framework) \
+ if self.config.framework == Const.PT_FRAMEWORK else None
self.module_count = {}
if self.config.task == Const.FREE_BENCHMARK:
self.scope = build_scope(ListScope, self.config.scope, self.config.list)
@@ -35,7 +35,7 @@ class DataCollector:
@property
def dump_file_path(self):
return self.data_writer.dump_file_path
-
+
@staticmethod
def check_scope_and_pid(scope, name, pid):
return (not scope or scope.check(name)) and pid == os.getpid()
@@ -43,10 +43,10 @@ class DataCollector:
@staticmethod
def is_inplace(module):
return getattr(module, "op_is_inplace", False)
-
+
def if_return_forward_new_output(self):
return self.data_processor.if_return_forward_new_output()
-
+
def get_forward_new_output(self):
return self.data_processor.get_forward_new_output()
@@ -88,8 +88,11 @@ class DataCollector:
else:
data_info = self.data_processor.analyze_forward_inplace(name, module_input_output)
if self.config.level == "L2":
- return
+ return
self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
+ if self.data_processor.stop_run():
+ self.handle_data(name, data_info, use_buffer=False)
+ raise Exception("[msprobe] exit")
self.handle_data(name, data_info)
def backward_data_collect(self, name, module, pid, module_input_output):
@@ -98,6 +101,9 @@ class DataCollector:
return
data_info = self.data_processor.analyze_backward(name, module, module_input_output)
+ if self.data_processor.stop_run():
+ self.handle_data(name, data_info, use_buffer=False)
+ raise Exception("[msprobe] exit")
self.handle_data(name, data_info)
def update_construct(self, name):
@@ -105,12 +111,15 @@ class DataCollector:
self.data_writer.update_construct({name: self.module_processor.api_parent_node})
self.data_writer.update_construct(self.module_processor.module_node)
- def handle_data(self, name, data_info):
+ def handle_data(self, name, data_info, use_buffer=True):
msg = f"msProbe is collecting data on {name}. "
if data_info:
msg = self.update_data(data_info, msg)
logger.info(msg)
- self.data_writer.flush_data_when_buffer_is_full()
+ if use_buffer:
+ self.data_writer.flush_data_when_buffer_is_full()
+ else:
+ self.write_json()
def module_count_func(self, name, name_template):
module_name = name.split(Const.SEP)[-3]
@@ -135,6 +144,6 @@ class DataCollector:
def update_dump_paths(self, *args):
self.data_writer.update_dump_paths(*args)
self.data_writer.initialize_json_file(task=self.config.task, level=self.config.level)
-
+
def update_iter(self, current_iter):
self.data_processor.update_iter(current_iter)
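The `stop_run` path added above writes collected data immediately instead of going through the buffer, so nothing is lost before the forced exit. A minimal sketch of that control flow; the class and method bodies here are illustrative stand-ins, not the real `DataCollector`/`DataWriter` API:

```python
# Sketch of the buffered-vs-immediate write path added to handle_data.
class TinyCollector:
    def __init__(self, buffer_limit=3):
        self.buffer = []
        self.buffer_limit = buffer_limit
        self.flushed = []  # stands in for data persisted to the json file

    def write_json(self):
        self.flushed.extend(self.buffer)
        self.buffer.clear()

    def handle_data(self, name, data_info, use_buffer=True):
        self.buffer.append((name, data_info))
        if use_buffer:
            # normal path: flush only when the buffer is full
            if len(self.buffer) >= self.buffer_limit:
                self.write_json()
        else:
            # stop_run path: persist immediately before raising the exit exception
            self.write_json()

collector = TinyCollector()
collector.handle_data("Functional.conv2d.0.forward", {"Max": 1.0})
collector.handle_data("Functional.relu.0.forward", {"Max": 2.0}, use_buffer=False)
print(len(collector.flushed))  # -> 2
```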
diff --git a/debug/accuracy_tools/atat/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
similarity index 87%
rename from debug/accuracy_tools/atat/core/data_dump/data_processor/base.py
rename to debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
index 208c053192c6da674ec9c7f522a0affdf1d091e2..2fbc86b5656c3bcfe14b2fe9fe6bb295451e9466 100644
--- a/debug/accuracy_tools/atat/core/data_dump/data_processor/base.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
@@ -3,9 +3,9 @@ import inspect
from dataclasses import dataclass
from typing import Tuple, Dict, Optional, Any
import numpy as np
-from atat.core.common.log import logger
-from atat.core.common.utils import convert_tuple
-from atat.core.common.const import Const
+from msprobe.core.common.log import logger
+from msprobe.core.common.utils import convert_tuple
+from msprobe.core.common.const import Const
@dataclass
@@ -35,11 +35,11 @@ class ModuleBackwardInputsOutputs:
@property
def grad_input_tuple(self):
return convert_tuple(self.grad_input)
-
+
@property
def grad_output_tuple(self):
- return convert_tuple(self.grad_output)
-
+ return convert_tuple(self.grad_output)
+
class TensorStatInfo:
def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None):
@@ -53,7 +53,7 @@ class BaseDataProcessor:
_recursive_key_stack = []
special_type = (np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_,
bool, int, float, str, slice)
-
+
def __init__(self, config, data_writer):
self.data_writer = data_writer
self.config = config
@@ -65,11 +65,11 @@ class BaseDataProcessor:
self.current_iter = 0
self._return_forward_new_output = False
self._forward_new_output = None
-
+
@property
def data_path(self):
return self.data_writer.dump_tensor_data_dir
-
+
@staticmethod
def analyze_api_call_stack(name):
stack_str = []
@@ -87,7 +87,7 @@ class BaseDataProcessor:
stack_str.append(stack_line)
stack_info_struct = {name: stack_str}
return stack_info_struct
-
+
@staticmethod
def _convert_numpy_to_builtin(arg):
type_mapping = {
@@ -103,26 +103,15 @@ class BaseDataProcessor:
if isinstance(arg, numpy_type):
return builtin_type(arg), type(arg).__name__
return arg, ''
-
+
@staticmethod
def _analyze_numpy(value, numpy_type):
return {"type": numpy_type, "value": value}
-
- @staticmethod
- def _analyze_builtin(arg):
- single_arg = {}
- if isinstance(arg, slice):
- single_arg.update({"type": "slice"})
- single_arg.update({"value": [arg.start, arg.stop, arg.step]})
- else:
- single_arg.update({"type": type(arg).__name__})
- single_arg.update({"value": arg})
- return single_arg
-
+
@classmethod
def get_special_types(cls):
return cls.special_type
-
+
@classmethod
def recursive_apply_transform(cls, args, transform):
if isinstance(args, cls.get_special_types()):
@@ -142,9 +131,11 @@ class BaseDataProcessor:
resutl_dict[k] = cls.recursive_apply_transform(arg, transform)
cls._recursive_key_stack.pop()
return resutl_dict
- else:
+ elif args is not None:
logger.warning(f"Data type {type(args)} is not supported.")
return None
+ else:
+ return None
def if_return_forward_new_output(self):
return self._return_forward_new_output
@@ -175,13 +166,14 @@ class BaseDataProcessor:
return (Const.ALL in self.config.data_mode or
forward_backward in self.config.data_mode or
input_output in self.config.data_mode)
-
- def analyze_pre_forward(self, name, module,module_input_output: ModuleForwardInputsOutputs):
+
+ def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
pass
-
+
def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
api_info_struct = {}
- if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT): # check whether data_mode contains forward or input
+ # check whether data_mode contains forward or input
+ if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):
api_info_struct[name] = {}
self.api_data_category = Const.INPUT
args_info_list = self.analyze_element(module_input_output.args_tuple)
@@ -190,13 +182,14 @@ class BaseDataProcessor:
kwargs_info_list = self.analyze_element(module_input_output.kwargs)
api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list
- if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT): # check whether data_mode contains forward or output
+ # check whether data_mode contains forward or output
+ if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
api_info_struct[name] = api_info_struct.get(name, {})
self.api_data_category = Const.OUTPUT
output_info_list = self.analyze_element(module_input_output.output_tuple)
api_info_struct[name][Const.OUTPUT] = output_info_list
return api_info_struct
-
+
def analyze_pre_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
api_info_struct = {}
if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):
@@ -208,7 +201,7 @@ class BaseDataProcessor:
kwargs_info_list = self.analyze_element(module_input_output.kwargs)
api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list
return api_info_struct
-
+
def analyze_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
concat_args = module_input_output.concat_args_and_kwargs()
api_info_struct = {}
@@ -218,26 +211,29 @@ class BaseDataProcessor:
output_info_list = self.analyze_element(concat_args)
api_info_struct[name][Const.OUTPUT] = output_info_list
return api_info_struct
-
+
def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
api_info_struct = {}
- if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT):
+ if self.is_dump_for_data_mode(Const.BACKWARD, Const.INPUT):
api_info_struct[name] = {}
- self.api_data_category = Const.OUTPUT
+ self.api_data_category = Const.INPUT
input_info_list = self.analyze_element(module_input_output.grad_input_tuple)
- api_info_struct[name][Const.GRAD_INPUT] = input_info_list
+ api_info_struct[name][Const.INPUT] = input_info_list
- if self.is_dump_for_data_mode(Const.BACKWARD, Const.INPUT):
+ if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT):
api_info_struct[name] = api_info_struct.get(name, {})
- self.api_data_category = Const.INPUT
+ self.api_data_category = Const.OUTPUT
output_info_list = self.analyze_element(module_input_output.grad_output_tuple)
- api_info_struct[name][Const.GRAD_OUTPUT] = output_info_list
+ api_info_struct[name][Const.OUTPUT] = output_info_list
return api_info_struct
def get_save_file_path(self, suffix):
- file_format = "pt" if self.config.framework == Const.PT_FRAMEWORK else "npy"
+ file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX
dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP +
- suffix + Const.SEP + file_format)
+ suffix + file_format)
file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name)
- return dump_data_name, file_path
\ No newline at end of file
+ return dump_data_name, file_path
+
+ def stop_run(self):
+ return False
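`get_save_file_path` now expects the suffix constants to carry their own dot (`Const.PT_SUFFIX`/`Const.NUMPY_SUFFIX`), which is why the extra `Const.SEP` before the format was dropped. A sketch of the resulting naming scheme, assuming `SEP == "."`, `PT_SUFFIX == ".pt"` and `NUMPY_SUFFIX == ".npy"`:

```python
import os

# Illustrative stand-in for BaseDataProcessor.get_save_file_path.
def get_save_file_path(api_name, category, suffix, framework, dump_dir):
    # suffix constants are assumed to include the leading dot
    file_format = ".pt" if framework == "pytorch" else ".npy"
    dump_data_name = api_name + "." + category + "." + suffix + file_format
    return dump_data_name, os.path.join(dump_dir, dump_data_name)

name, path = get_save_file_path("Tensor.add.0.forward", "input", "0", "pytorch", "/tmp/dump")
print(name)  # -> Tensor.add.0.forward.input.0.pt
```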
diff --git a/debug/accuracy_tools/atat/core/data_dump/data_processor/factory.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/factory.py
similarity index 89%
rename from debug/accuracy_tools/atat/core/data_dump/data_processor/factory.py
rename to debug/accuracy_tools/msprobe/core/data_dump/data_processor/factory.py
index bcc771f3684aa55a422691bd9dbfff31f07773dd..ad74acdeebaf9df257eca35cc670a6131805758c 100644
--- a/debug/accuracy_tools/atat/core/data_dump/data_processor/factory.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/factory.py
@@ -1,10 +1,10 @@
-from atat.core.common.const import Const
+from msprobe.core.common.const import Const
class DataProcessorFactory:
_data_processor = {}
_module_processor = {}
-
+
@classmethod
def register_processor(cls, framework, task, processor_class):
key = (framework, task)
@@ -13,7 +13,7 @@ class DataProcessorFactory:
@classmethod
def register_module_processor(cls, framework, processor_class):
cls._module_processor[framework] = processor_class
-
+
@classmethod
def get_module_processor(cls, framework):
processor_class = cls._module_processor.get(framework)
@@ -39,7 +39,7 @@ class DataProcessorFactory:
TensorDataProcessor as PytorchTensorDataProcessor,
OverflowCheckDataProcessor as PytorchOverflowCheckDataProcessor,
FreeBenchmarkDataProcessor as PytorchFreeBenchmarkDataProcessor,
- KernelDumpDataProcessor as PytorchKernelDumpDataProcessor
+ KernelDumpDataProcessor as PytorchKernelDumpDataProcessor
)
from ....pytorch.module_processer import ModuleProcesser
cls.register_processor(Const.PT_FRAMEWORK, Const.STATISTICS, PytorchStatisticsDataProcessor)
@@ -47,15 +47,13 @@ class DataProcessorFactory:
cls.register_processor(Const.PT_FRAMEWORK, Const.OVERFLOW_CHECK, PytorchOverflowCheckDataProcessor)
cls.register_processor(Const.PT_FRAMEWORK, Const.FREE_BENCHMARK, PytorchFreeBenchmarkDataProcessor)
cls.register_processor(Const.PT_FRAMEWORK, Const.KERNEL_DUMP, PytorchKernelDumpDataProcessor)
- cls.register_module_processor(Const.PT_FRAMEWORK, ModuleProcesser)
+ cls.register_module_processor(Const.PT_FRAMEWORK, ModuleProcesser)
elif framework == Const.MS_FRAMEWORK:
from .mindspore_processor import (
StatisticsDataProcessor as MindsporeStatisticsDataProcessor,
TensorDataProcessor as MindsporeTensorDataProcessor,
- OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor,
- FreeBenchmarkDataProcessor as MindsporeFreeBenchmarkDataProcessor
+ OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor
)
cls.register_processor(Const.MS_FRAMEWORK, Const.STATISTICS, MindsporeStatisticsDataProcessor)
cls.register_processor(Const.MS_FRAMEWORK, Const.TENSOR, MindsporeTensorDataProcessor)
cls.register_processor(Const.MS_FRAMEWORK, Const.OVERFLOW_CHECK, MindsporeOverflowCheckDataProcessor)
- cls.register_processor(Const.MS_FRAMEWORK, Const.FREE_BENCHMARK, MindsporeFreeBenchmarkDataProcessor)
\ No newline at end of file
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cfb27cda2afa8794943dceb23b4085f0305c50b
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py
@@ -0,0 +1,206 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import zlib
+import mindspore as ms
+from mindspore import ops
+import numpy as np
+
+from msprobe.core.common.const import Const
+from msprobe.core.data_dump.data_processor.base import (BaseDataProcessor, TensorStatInfo,
+ ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs)
+from msprobe.core.common.file_utils import path_len_exceeds_limit, change_mode, FileCheckConst
+from msprobe.mindspore.dump.hook_cell.wrap_functional import load_ops_functions
+from msprobe.mindspore.common.utils import convert_bf16_to_fp32
+from msprobe.mindspore.common.log import logger
+from msprobe.mindspore.dump.hook_cell.api_registry import api_register
+
+
+class MindsporeDataProcessor(BaseDataProcessor):
+ mindspore_special_type = tuple([ms.Tensor])
+ ops_func, mint_ops_func, _ = load_ops_functions()
+
+ def __init__(self, config, data_writer):
+ super().__init__(config, data_writer)
+ self.mindspore_object_key = {
+ "dtype": self.analyze_dtype_in_kwargs
+ }
+
+ @staticmethod
+ def get_md5_for_tensor(x):
+ x = convert_bf16_to_fp32(x)
+ tensor_bytes = x.asnumpy().tobytes()
+ crc32_hash = zlib.crc32(tensor_bytes)
+ return f"{crc32_hash:08x}"
+
+ @staticmethod
+ def analyze_dtype_in_kwargs(element):
+ return {"type": "mindspore.dtype", "value": str(element)}
+
+ @staticmethod
+ def _analyze_builtin(arg):
+ single_arg = {}
+ if isinstance(arg, slice):
+ single_arg.update({"type": "slice"})
+            # slice arguments may contain tensors; convert them to Python scalars so they can be JSON-serialized
+ values = [
+ value if not isinstance(value, ms.Tensor) else value.item()
+ for value in [arg.start, arg.stop, arg.step]
+ ]
+ single_arg.update({"value": values})
+ else:
+ single_arg.update({"type": type(arg).__name__})
+ single_arg.update({"value": arg})
+ return single_arg
+
+ @classmethod
+ def get_special_types(cls):
+ return super().get_special_types() + cls.mindspore_special_type
+
+ def get_stat_info(self, data):
+ tensor_stat = TensorStatInfo()
+ if data.numel() == 0:
+ return tensor_stat
+ elif data.dtype == ms.bool_:
+ tensor_stat.max = self.mint_ops_func["max"](data).item()
+ tensor_stat.min = self.mint_ops_func["min"](data).item()
+ elif not data.shape:
+ tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item()
+ elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
+ data_abs = np.abs(data.asnumpy())
+ tensor_stat.max = np.max(data_abs)
+ tensor_stat.min = np.min(data_abs)
+ tensor_stat.mean = np.mean(data_abs)
+ tensor_stat.norm = np.linalg.norm(data_abs)
+ else:
+ if data.dtype == ms.bfloat16 or not ops.is_floating_point(data):
+ data = data.to(ms.float32)
+ api_register.norm_inner_op_set_ori_func()
+ tensor_stat.max = self.mint_ops_func["max"](data).item()
+ tensor_stat.min = self.mint_ops_func["min"](data).item()
+ tensor_stat.mean = self.mint_ops_func["mean"](data).item()
+ tensor_stat.norm = self.ops_func["norm"](data).item()
+ api_register.norm_inner_op_set_hook_func()
+ return tensor_stat
+
+ def analyze_single_element(self, element, suffix_stack):
+ if suffix_stack and suffix_stack[-1] in self.mindspore_object_key:
+ return self.mindspore_object_key[suffix_stack[-1]](element)
+
+ converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
+ if converted_numpy is not element:
+ return self._analyze_numpy(converted_numpy, numpy_type)
+ if isinstance(element, ms.Tensor):
+ return self._analyze_tensor(element, Const.SEP.join(suffix_stack))
+
+ if isinstance(element, (bool, int, float, str, slice)):
+ return self._analyze_builtin(element)
+ return {}
+
+ def analyze_element(self, element):
+ return self.recursive_apply_transform(element, self.analyze_single_element)
+
+ def _analyze_tensor(self, tensor, suffix):
+ tensor_stat = self.get_stat_info(tensor)
+ tensor_json = {
+ 'type': 'mindspore.Tensor',
+ 'dtype': str(tensor.dtype),
+ 'shape': tensor.shape,
+ 'Max': tensor_stat.max,
+ 'Min': tensor_stat.min,
+ 'Mean': tensor_stat.mean,
+ 'Norm': tensor_stat.norm
+ }
+ if self.config.summary_mode == Const.MD5:
+ tensor_md5 = self.get_md5_for_tensor(tensor)
+ tensor_json.update({Const.MD5: tensor_md5})
+ return tensor_json
+
+
+class StatisticsDataProcessor(MindsporeDataProcessor):
+ pass
+
+
+class TensorDataProcessor(MindsporeDataProcessor):
+ def _analyze_tensor(self, tensor, suffix):
+ dump_data_name, file_path = self.get_save_file_path(suffix)
+ single_arg = super()._analyze_tensor(tensor, suffix)
+ single_arg.update({"data_name": dump_data_name})
+ if not path_len_exceeds_limit(file_path):
+ tensor = convert_bf16_to_fp32(tensor)
+ saved_tensor = tensor.asnumpy()
+ np.save(file_path, saved_tensor)
+ change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
+ else:
+ logger.warning(f'The file path {file_path} length exceeds limit.')
+ return single_arg
+
+
+class OverflowCheckDataProcessor(MindsporeDataProcessor):
+ __slots__ = ["cached_tensors_and_file_paths"]
+
+ def __init__(self, config, data_writer):
+ super().__init__(config, data_writer)
+ self.cached_tensors_and_file_paths = {}
+ self.real_overflow_dump_times = 0
+ self.overflow_nums = config.overflow_nums
+
+ def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
+ self.has_overflow = False
+ api_info_struct = super().analyze_forward(name, module, module_input_output)
+ self.maybe_save_overflow_data()
+ return api_info_struct if self.has_overflow else None
+
+ def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
+ self.has_overflow = False
+ api_info_struct = super().analyze_backward(name, module, module_input_output)
+ self.maybe_save_overflow_data()
+ return api_info_struct if self.has_overflow else None
+
+ def maybe_save_overflow_data(self):
+ if self.has_overflow:
+ for file_path, tensor in self.cached_tensors_and_file_paths.items():
+ tensor = convert_bf16_to_fp32(tensor)
+ np.save(file_path, tensor.asnumpy())
+ change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
+ self.real_overflow_dump_times += 1
+ self.cached_tensors_and_file_paths = {}
+
+ def stop_run(self):
+ if self.overflow_nums == -1:
+ return False
+ if self.real_overflow_dump_times >= self.overflow_nums:
+            logger.warning(f"[msprobe] The preset overflow count has been exceeded. Current overflow count: {self.real_overflow_dump_times}")
+ return True
+ return False
+
+ def _analyze_maybe_overflow_tensor(self, tensor_json):
+ if tensor_json['Max'] is None:
+ return
+ if np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']):
+ self.has_overflow = True
+ if np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min']):
+ self.has_overflow = True
+
+ def _analyze_tensor(self, tensor, suffix):
+ dump_data_name, file_path = self.get_save_file_path(suffix)
+ if not path_len_exceeds_limit(file_path):
+ self.cached_tensors_and_file_paths.update({file_path: tensor})
+ else:
+ logger.warning(f'The file path {file_path} length exceeds limit.')
+ single_arg = super()._analyze_tensor(tensor, suffix)
+ self._analyze_maybe_overflow_tensor(single_arg)
+ single_arg.update({"data_name": dump_data_name})
+ return single_arg
diff --git a/debug/accuracy_tools/atat/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py
similarity index 87%
rename from debug/accuracy_tools/atat/core/data_dump/data_processor/pytorch_processor.py
rename to debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py
index cf3c5ebe5864ff48d9f2442c020cb3ae99473b50..0986ccf826330f9783a5dde2cc1d07a87c2e57f4 100644
--- a/debug/accuracy_tools/atat/core/data_dump/data_processor/pytorch_processor.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py
@@ -5,18 +5,19 @@ from typing import List
import numpy as np
import torch
-from atat.core.common.exceptions import MsaccException
-from atat.core.common.file_check import path_len_exceeds_limit, change_mode
-from atat.core.common.log import logger
-from atat.core.common.const import Const, OverflowConst, FileCheckConst
-from atat.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \
+from msprobe.core.common.exceptions import MsprobeException
+from msprobe.core.common.file_utils import path_len_exceeds_limit, change_mode
+from msprobe.core.common.log import logger
+from msprobe.core.common.const import Const, OverflowConst, FileCheckConst
+from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \
ModuleForwardInputsOutputs, TensorStatInfo
-from atat.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow
+from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow
try:
import torch_npu
+ is_gpu = False
except ImportError:
- pass
+ is_gpu = True
class PytorchDataProcessor(BaseDataProcessor):
@@ -76,6 +77,42 @@ class PytorchDataProcessor(BaseDataProcessor):
tensor_stat.mean = torch._C._VariableFunctionsClass.mean(data_clone).item()
tensor_stat.norm = torch._C._VariableFunctionsClass.norm(data_clone).item()
return tensor_stat
+
+ @staticmethod
+ def handle_tensor_extremum_nan_inf(tensor, operator):
+ data_clone = tensor.detach()
+ data_nan = torch._C._VariableFunctionsClass.isnan(data_clone)
+ if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel():
+ return float('nan')
+ finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone)
+ if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0:
+ finite_values = data_clone[finite_mask]
+ return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \
+ torch._C._VariableFunctionsClass.min(finite_values).item()
+ else:
+ data_no_nan = data_clone[~data_nan]
+ return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \
+ torch._C._VariableFunctionsClass.min(data_no_nan).item()
+
+ @staticmethod
+ def _analyze_builtin(arg):
+ single_arg = {}
+ if isinstance(arg, slice):
+ single_arg.update({"type": "slice"})
+            # slice arguments may contain tensors; convert them to Python scalars so they can be JSON-serialized
+ values = [
+ value if not isinstance(value, torch.Tensor) else value.item()
+ for value in [arg.start, arg.stop, arg.step]
+ ]
+ single_arg.update({"value": values})
+ else:
+ single_arg.update({"type": type(arg).__name__})
+ single_arg.update({"value": arg})
+ return single_arg
+
+ @staticmethod
+ def _analyze_torch_size(arg):
+ return {"type": "torch.Size", "value": list(arg)}
@classmethod
def get_special_types(cls):
@@ -93,14 +130,11 @@ class PytorchDataProcessor(BaseDataProcessor):
return self._analyze_tensor(element, Const.SEP.join(suffix_stack))
if isinstance(element, (bool, int, float, str, slice)):
return self._analyze_builtin(element)
- return None
+ return {}
def analyze_element(self, element):
return self.recursive_apply_transform(element, self.analyze_single_element)
- def _analyze_torch_size(arg):
- return {"type": "torch.Size", "value": list(arg)}
-
def _analyze_tensor(self, tensor, suffix):
tensor_stat = self.get_stat_info(tensor)
tensor_json = {}
@@ -112,9 +146,17 @@ class PytorchDataProcessor(BaseDataProcessor):
tensor_json.update({"Mean": tensor_stat.mean})
tensor_json.update({"Norm": tensor_stat.norm})
tensor_json.update({"requires_grad": tensor.requires_grad})
- if self.config.summary_mode == "md5":
+
+ if tensor_stat.max is not None:
+ if np.isinf(tensor_stat.max) or np.isnan(tensor_stat.max):
+ tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "max")
+ if tensor_stat.min is not None:
+ if np.isinf(tensor_stat.min) or np.isnan(tensor_stat.min):
+ tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "min")
+
+ if self.config.summary_mode == Const.MD5:
tensor_md5 = self.get_md5_for_tensor(tensor)
- tensor_json.update({"md5": tensor_md5})
+ tensor_json.update({Const.MD5: tensor_md5})
return tensor_json
@@ -142,7 +184,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
super().__init__(config, data_writer)
self.cached_tensors_and_file_paths = {}
self.real_overflow_dump_times = 0
- self.overflow_nums = config.overflow_num
+ self.overflow_nums = config.overflow_nums
self.bits_for_overflow = 8
@staticmethod
@@ -150,21 +192,6 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
overflow_mode = os.getenv(OverflowConst.OVERFLOW_DEBUG_MODE_ENABLE, Const.ENV_DISABLE)
return overflow_mode == Const.ENV_ENABLE
- @staticmethod
- def handle_tensor_extremum_nan_inf(data_clone, operator):
- data_nan = torch._C._VariableFunctionsClass.isnan(data_clone)
- if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel():
- return float('nan')
- finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone)
- if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0:
- finite_values = data_clone[finite_mask]
- return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \
- torch._C._VariableFunctionsClass.min(finite_values).item()
- else:
- data_no_nan = data_clone[~data_nan]
- return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \
- torch._C._VariableFunctionsClass.min(data_no_nan).item()
-
def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
self.has_overflow = False
api_info_struct = super().analyze_forward(name, module, module_input_output)
@@ -190,7 +217,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
if self.overflow_nums == -1:
return
if self.real_overflow_dump_times >= self.overflow_nums:
- raise MsaccException(MsaccException.OVERFLOW_NUMS_ERROR, str(self.real_overflow_dump_times))
+ raise MsprobeException(MsprobeException.OVERFLOW_NUMS_ERROR, str(self.real_overflow_dump_times))
def check_overflow_npu(self):
if self.overflow_debug_mode_enalbe():
@@ -210,16 +237,13 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
else:
torch_npu._C._clear_overflow_npu()
- def _analyze_maybe_overflow_tensor(self, tensor_json, tensor):
- data_clone = tensor.detach()
- if hasattr(torch_npu._C, '_npu_is_support_inf_nan') and torch_npu._C._npu_is_support_inf_nan():
+ def _analyze_maybe_overflow_tensor(self, tensor_json):
+ if is_gpu or (hasattr(torch_npu._C, '_npu_is_support_inf_nan') and torch_npu._C._npu_is_support_inf_nan()):
if tensor_json['Max'] is None:
return
if np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']):
- tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(data_clone, "max")
self.has_overflow = True
if np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min']):
- tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(data_clone, "min")
self.has_overflow = True
else:
self.has_overflow = self.check_overflow_npu()
@@ -233,7 +257,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
else:
logger.warning(f'The file path {file_path} length exceeds limit.')
single_arg = super()._analyze_tensor(tensor, suffix)
- self._analyze_maybe_overflow_tensor(single_arg, tensor)
+ self._analyze_maybe_overflow_tensor(single_arg)
single_arg.update({"data_name": dump_data_name})
return single_arg
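The `handle_tensor_extremum_nan_inf` method moved above prefers the extremum of finite elements, falls back to the non-NaN elements (which are then all ±inf), and returns NaN only when every element is NaN. The same decision ladder in plain Python, sketched over a list rather than a torch tensor (the function name is a stand-in):

```python
import math


def extremum_except_nan_inf(values, operator="max"):
    # All-NaN tensor: nothing meaningful to report.
    if all(math.isnan(v) for v in values):
        return float("nan")
    pick = max if operator == "max" else min
    finite = [v for v in values if math.isfinite(v)]
    if finite:
        # Preferred path: extremum over finite values only.
        return pick(finite)
    # No finite values left: fall back to the non-NaN ones (all +/-inf).
    return pick(v for v in values if not math.isnan(v))
```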
diff --git a/debug/accuracy_tools/atat/core/data_dump/json_writer.py b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py
similarity index 85%
rename from debug/accuracy_tools/atat/core/data_dump/json_writer.py
rename to debug/accuracy_tools/msprobe/core/data_dump/json_writer.py
index 23f37b2342e9bde9d69bb65022a40911d5a54dc4..99cc5f3159ee038133e874109a78b2a7a85d9413 100644
--- a/debug/accuracy_tools/atat/core/data_dump/json_writer.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py
@@ -4,9 +4,9 @@ import fcntl
import json
from pathlib import Path
-from atat.core.common.file_check import change_mode
-from atat.core.common.log import logger
-from atat.core.common.const import Const, FileCheckConst
+from msprobe.core.common.file_utils import change_mode, FileOpen
+from msprobe.core.common.log import logger
+from msprobe.core.common.const import Const, FileCheckConst
class DataWriter:
@@ -30,20 +30,20 @@ class DataWriter:
return
is_exists = os.path.exists(file_path)
append = "a+" if is_exists else "w+"
- with os.fdopen(
- os.open(file_path, Const.WRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY), append, newline=""
- ) as csv_file:
+ with FileOpen(file_path, append) as csv_file:
spawn_writer = csv.writer(csv_file)
if not is_exists:
spawn_writer.writerow(result_header)
spawn_writer.writerows([result,])
+ is_new_file = not is_exists
+ if is_new_file:
+ change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
def initialize_json_file(self, **kwargs):
kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
- with os.fdopen(
- os.open(self.dump_file_path, Const.OVERWRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY), 'w'
- ) as f:
+ with FileOpen(self.dump_file_path, 'w') as f:
json.dump(kwargs, f)
+ change_mode(self.dump_file_path, FileCheckConst.DATA_FILE_AUTHORITY)
if os.path.exists(self.stack_file_path):
os.remove(self.stack_file_path)
@@ -83,7 +83,7 @@ class DataWriter:
def write_data_json(self, file_path):
logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ")
if Path(file_path).exists() and os.path.getsize(file_path) > 0:
- with open(file_path, "r+") as f:
+ with FileOpen(file_path, "r+") as f:
fcntl.flock(f, fcntl.LOCK_EX)
data_to_write = json.load(f)
fcntl.flock(f, fcntl.LOCK_UN)
@@ -91,7 +91,7 @@ class DataWriter:
self.init_json['data_path'] = self.dump_tensor_data_dir
data_to_write = self.init_json
data_to_write[Const.DATA].update(self.cache_data[Const.DATA])
- with open(file_path, 'w+') as f:
+ with FileOpen(file_path, 'w+') as f:
fcntl.flock(f, fcntl.LOCK_EX)
json.dump(data_to_write, f, indent=1)
fcntl.flock(f, fcntl.LOCK_UN)
@@ -99,13 +99,13 @@ class DataWriter:
self.cache_data[Const.DATA].clear()
def write_stack_info_json(self, file_path):
- with open(file_path, 'w+') as f:
+ with FileOpen(file_path, 'w+') as f:
fcntl.flock(f, fcntl.LOCK_EX)
json.dump(self.cache_stack, f, indent=1)
fcntl.flock(f, fcntl.LOCK_UN)
def write_construct_info_json(self, file_path):
- with open(file_path, 'w+') as f:
+ with FileOpen(file_path, 'w+') as f:
fcntl.flock(f, fcntl.LOCK_EX)
json.dump(self.cache_construct, f, indent=1)
fcntl.flock(f, fcntl.LOCK_UN)
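The `DataWriter` methods above serialize each JSON write under an exclusive `fcntl` lock so concurrent ranks cannot interleave partial writes. A self-contained sketch of that lock-write-unlock pattern, using a plain `open` in place of the repo's `FileOpen` wrapper (function name and sample payload are illustrative):

```python
import fcntl
import json
import os
import tempfile


def locked_json_dump(file_path, data):
    # Hold an exclusive advisory lock for the duration of the dump,
    # mirroring DataWriter.write_stack_info_json.
    with open(file_path, "w+") as f:
        fcntl.flock(f, fcntl.LOCK_EX)
        json.dump(data, f, indent=1)
        fcntl.flock(f, fcntl.LOCK_UN)


path = os.path.join(tempfile.mkdtemp(), "stack.json")
locked_json_dump(path, {"Functional.relu.0.forward": ["File a.py, line 10"]})
```

Note that `flock` locks are advisory: they only exclude other writers that also take the lock.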
diff --git a/debug/accuracy_tools/atat/core/data_dump/scope.py b/debug/accuracy_tools/msprobe/core/data_dump/scope.py
similarity index 98%
rename from debug/accuracy_tools/atat/core/data_dump/scope.py
rename to debug/accuracy_tools/msprobe/core/data_dump/scope.py
index e7114f343fe724ffdd40d4837a09b8417d03a1b0..1d74c3e461ac4b0005e4fdd40ae1fe2c12bb1c4e 100644
--- a/debug/accuracy_tools/atat/core/data_dump/scope.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/scope.py
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
-from atat.core.common.exceptions import ScopeException
-from atat.core.common.const import Const
+from msprobe.core.common.exceptions import ScopeException
+from msprobe.core.common.const import Const
def build_scope(scope_class, scope=None, api_list=None):
diff --git a/debug/accuracy_tools/msprobe/mindspore/__init__.py b/debug/accuracy_tools/msprobe/mindspore/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bf42d1e39feef616e6eb2cc296a099b0bddfd98
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/mindspore/__init__.py
@@ -0,0 +1 @@
+from msprobe.mindspore.debugger.precision_debugger import PrecisionDebugger
diff --git a/debug/accuracy_tools/msprobe/mindspore/common/log.py b/debug/accuracy_tools/msprobe/mindspore/common/log.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec027c750133ce4aabdac4ed914b4a5c50b2a2f1
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/mindspore/common/log.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import os
+import time
+import sys
+
+from msprobe.mindspore.common.utils import get_rank_if_initialized
+from msprobe.core.common.log import BaseLogger
+from msprobe.core.common.exceptions import DistributedNotInitializedError
+
+
+class MindsporeLogger(BaseLogger):
+ def __init__(self):
+ super().__init__()
+
+ def get_rank(self):
+ try:
+ current_rank = get_rank_if_initialized()
+ except DistributedNotInitializedError:
+ current_rank = None
+
+ return current_rank
+
+
+logger = MindsporeLogger()
\ No newline at end of file
diff --git a/debug/accuracy_tools/msprobe/mindspore/common/utils.py b/debug/accuracy_tools/msprobe/mindspore/common/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6abf0a1ee8870c6044d7cad95542cb89ebf81d46
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/mindspore/common/utils.py
@@ -0,0 +1,44 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import mindspore as ms
+from msprobe.core.common.exceptions import DistributedNotInitializedError
+
+
+def get_rank_if_initialized():
+ if ms.communication.GlobalComm.INITED:
+ return ms.communication.get_rank()
+ else:
+ raise DistributedNotInitializedError("mindspore distributed environment is not initialized")
+
+
+def convert_bf16_to_fp32(tensor):
+ if tensor.dtype == ms.bfloat16:
+ tensor = tensor.to(ms.float32)
+ return tensor
+
+
+class MsprobeStep(ms.train.Callback):
+
+ def __init__(self, debugger):
+ super(MsprobeStep, self).__init__()
+ self.debugger = debugger
+
+ def on_train_step_begin(self, run_context):
+ self.debugger.start()
+
+ def on_train_step_end(self, run_context):
+ self.debugger.stop()
+ self.debugger.step()
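`MsprobeStep` above wires the debugger into the training loop: `start()` at step begin, then `stop()` and `step()` at step end. A framework-free sketch of that hook pattern, with a hypothetical recording debugger in place of the real one (MindSpore's `ms.train.Callback` base is dropped so the sketch runs standalone):

```python
class RecordingDebugger:
    # Hypothetical stand-in that records the order of hook calls.
    def __init__(self):
        self.calls = []

    def start(self):
        self.calls.append("start")

    def stop(self):
        self.calls.append("stop")

    def step(self):
        self.calls.append("step")


class StepHook:
    """Mirrors MsprobeStep without the ms.train.Callback dependency."""

    def __init__(self, debugger):
        self.debugger = debugger

    def on_train_step_begin(self, run_context):
        self.debugger.start()

    def on_train_step_end(self, run_context):
        self.debugger.stop()
        self.debugger.step()


dbg = RecordingDebugger()
hook = StepHook(dbg)
for _ in range(2):  # two simulated training steps
    hook.on_train_step_begin(None)
    hook.on_train_step_end(None)
```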
diff --git a/debug/accuracy_tools/atat/mindspore/overflow_check/__init__.py b/debug/accuracy_tools/msprobe/mindspore/debugger/__init__.py
similarity index 100%
rename from debug/accuracy_tools/atat/mindspore/overflow_check/__init__.py
rename to debug/accuracy_tools/msprobe/mindspore/debugger/__init__.py
diff --git a/debug/accuracy_tools/atat/mindspore/debugger/debugger_config.py b/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py
similarity index 69%
rename from debug/accuracy_tools/atat/mindspore/debugger/debugger_config.py
rename to debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py
index 56a4b9bf758197d77ef04874f2865e2136d6f67c..23cb7294b8dc0d64e012f5fac2b863bdfe871bbe 100644
--- a/debug/accuracy_tools/atat/mindspore/debugger/debugger_config.py
+++ b/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py
@@ -1,13 +1,10 @@
import os
+from msprobe.core.common.utils import Const
+from msprobe.core.common.const import MsConst
-class DebuggerConfig:
- convert_map = {
- "L0": "cell",
- "L1": "api",
- "L2": 'kernel'
- }
+class DebuggerConfig:
def __init__(self, common_config, task_config):
self.dump_path = common_config.dump_path
self.task = common_config.task
@@ -15,18 +12,22 @@ class DebuggerConfig:
self.step = [] if not common_config.step else common_config.step
if not common_config.level:
common_config.level = "L1"
- self.level = DebuggerConfig.convert_map[common_config.level]
+ self.level = MsConst.TOOL_LEVEL_DICT.get(common_config.level, MsConst.API)
+ self.level_ori = common_config.level
self.list = [] if not task_config.list else task_config.list
- self.data_mode = [] if not task_config.data_mode else task_config.data_mode
+ self.scope = [] if not task_config.scope else task_config.scope
+ self.data_mode = [] if not task_config.data_mode else task_config.data_mode
self.file_format = task_config.file_format
+ self.overflow_nums = 1 if not task_config.overflow_nums else task_config.overflow_nums
self.check_mode = task_config.check_mode
-
+ self.framework = Const.MS_FRAMEWORK
+ self.summary_mode = task_config.summary_mode
self.check()
def check(self):
if not self.dump_path:
raise Exception("Dump path is empty.")
- if not os.path.isabs(self.dump_path):
+ if self.level_ori != "L1" and not os.path.isabs(self.dump_path):
raise Exception("Dump path must be absolute path.")
if not self.task:
self.task = "statistics"
diff --git a/debug/accuracy_tools/atat/mindspore/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py
similarity index 31%
rename from debug/accuracy_tools/atat/mindspore/debugger/precision_debugger.py
rename to debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py
index 0099074762f0746c1bd8341047f37b3e5fe08855..5475dc3586c35687fec63b51f265ac83c0d33a87 100644
--- a/debug/accuracy_tools/atat/mindspore/debugger/precision_debugger.py
+++ b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py
@@ -1,7 +1,12 @@
import os
-from atat.mindspore.ms_config import parse_json_config
-from atat.mindspore.debugger.debugger_config import DebuggerConfig
-from atat.mindspore.task_handler_factory import TaskHandlerFactory
+
+import mindspore as ms
+
+from msprobe.mindspore.service import Service
+from msprobe.mindspore.ms_config import parse_json_config
+from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
+from msprobe.mindspore.task_handler_factory import TaskHandlerFactory
+from msprobe.core.common.const import MsConst
class PrecisionDebugger:
@@ -12,6 +17,8 @@ class PrecisionDebugger:
cls._instance = super().__new__(cls)
cls._instance.initialized = False
cls._instance.config = None
+ cls.service = None
+ cls.first_start = False
return cls._instance
def __init__(self, config_path=None):
@@ -23,10 +30,46 @@ class PrecisionDebugger:
self.config = DebuggerConfig(common_config, task_config)
self.initialized = True
+ @staticmethod
+ def _get_execution_mode():
+ if ms.get_context("mode") == ms.GRAPH_MODE:
+ if ms.context.get_jit_config().get("jit_level") == "O2" or ms.get_context("jit_level") == "O2":
+ return MsConst.GRAPH_GE_MODE
+ else:
+ return MsConst.GRAPH_KBYK_MODE
+ else:
+ return MsConst.PYNATIVE_MODE
+
@classmethod
- def start(cls, target=None):
+ def start(cls):
instance = cls._instance
if not instance:
raise Exception("No instance of PrecisionDebugger found.")
- handler = TaskHandlerFactory.create(instance.config)
- handler.handle()
+
+ instance.config.execution_mode = instance._get_execution_mode()
+ if instance.config.execution_mode == MsConst.PYNATIVE_MODE and instance.config.level == MsConst.API:
+ if not instance.service:
+ instance.service = Service(instance.config)
+ instance.service.start()
+ else:
+ if not instance.first_start:
+ handler = TaskHandlerFactory.create(instance.config)
+ handler.handle()
+
+ instance.first_start = True
+
+ @classmethod
+ def stop(cls):
+ instance = cls._instance
+ if not instance:
+ raise Exception("PrecisionDebugger instance is not created.")
+ if instance.service:
+ instance.service.stop()
+
+ @classmethod
+ def step(cls):
+ instance = cls._instance
+ if not instance:
+ raise Exception("PrecisionDebugger instance is not created.")
+ if instance.service:
+ instance.service.step()
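`PrecisionDebugger` above keeps a single instance via `__new__`, so repeated construction in a script returns the same object and `start`/`stop`/`step` all act on shared state. A minimal sketch of that singleton idiom, with the config parsing omitted (class and attribute names are illustrative):

```python
class Singleton:
    _instance = None

    def __new__(cls, *args, **kwargs):
        # First call creates the instance; later calls reuse it,
        # as PrecisionDebugger.__new__ does.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.initialized = False
        return cls._instance

    def __init__(self, config_path=None):
        if self.initialized:
            return  # already configured; ignore later constructor calls
        self.config_path = config_path
        self.initialized = True


a = Singleton(config_path="./config.json")
b = Singleton()  # same object; config from the first call is kept
```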
diff --git a/debug/accuracy_tools/atat/mindspore/doc/dump.md b/debug/accuracy_tools/msprobe/mindspore/doc/dump.md
similarity index 81%
rename from debug/accuracy_tools/atat/mindspore/doc/dump.md
rename to debug/accuracy_tools/msprobe/mindspore/doc/dump.md
index 3321a4da12bcf7ff5f215c3ddfe8fe922159f750..425d0683a268ebdcaf54a4f70b5e448bb1233f3c 100644
--- a/debug/accuracy_tools/atat/mindspore/doc/dump.md
+++ b/debug/accuracy_tools/msprobe/mindspore/doc/dump.md
@@ -1,8 +1,8 @@
# **Accuracy Data Collection**
-The atat tool collects accuracy data mainly by adding dump interfaces to the training script and launching training.
+The msprobe tool collects accuracy data mainly by adding dump interfaces to the training script and launching training.
-Performing a dump requires installing the atat tool. See the "Tool Installation" section of [MindStudio Accuracy Debugging Tool](../../README.md).
+Performing a dump requires installing the msprobe tool. See the "Tool Installation" section of [MindStudio Accuracy Debugging Tool](../../README.md).
## Introduction to the dump API
@@ -12,7 +12,7 @@ The atat tool collects accuracy data mainly by adding dump interfaces
The detailed configuration of the dump operation is determined by loading a dump configuration file.
-This interface can be added at any point between "from atat.mindspore import PrecisionDebugger" and model initialization.
+This interface can be added at any point between "from msprobe.mindspore import PrecisionDebugger" and model initialization.
**Prototype**
@@ -43,7 +43,7 @@
## Sample Code
```Python
-from atat.mindspore import PrecisionDebugger
+from msprobe.mindspore import PrecisionDebugger
debugger = PrecisionDebugger(config_path="./config.json")
# Do not place the initialization flow above inside loop code
# The code below can also use PrecisionDebugger.start()
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/__init__.py b/debug/accuracy_tools/msprobe/mindspore/dump/__init__.py
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/__init__.py
rename to debug/accuracy_tools/msprobe/mindspore/dump/__init__.py
diff --git a/debug/accuracy_tools/atat/mindspore/dump/api_kbk_dump.py b/debug/accuracy_tools/msprobe/mindspore/dump/api_kbk_dump.py
similarity index 91%
rename from debug/accuracy_tools/atat/mindspore/dump/api_kbk_dump.py
rename to debug/accuracy_tools/msprobe/mindspore/dump/api_kbk_dump.py
index a53841189f5a52de74900b9ba4382e0746dfee3a..5c7af45d79060c00ce198f19a589d46bacf1f756 100644
--- a/debug/accuracy_tools/atat/mindspore/dump/api_kbk_dump.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/api_kbk_dump.py
@@ -1,9 +1,9 @@
import os
import json
-from atat.core.common.utils import make_dump_path_if_not_exists
-from atat.mindspore.debugger.debugger_config import DebuggerConfig
-from atat.core.common.log import logger
-from atat.core.common.file_check import FileOpen
+from msprobe.core.common.utils import make_dump_path_if_not_exists
+from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
+from msprobe.core.common.log import logger
+from msprobe.core.common.file_check import FileOpen
class ApiKbkDump:
diff --git a/debug/accuracy_tools/atat/mindspore/dump/dump_tool_factory.py b/debug/accuracy_tools/msprobe/mindspore/dump/dump_tool_factory.py
similarity index 82%
rename from debug/accuracy_tools/atat/mindspore/dump/dump_tool_factory.py
rename to debug/accuracy_tools/msprobe/mindspore/dump/dump_tool_factory.py
index ab534edc243dfd5f44688358fe4ca8edb6a8a12d..2c4579b0e75fe1573f387f696c3d9e4efd4945e3 100644
--- a/debug/accuracy_tools/atat/mindspore/dump/dump_tool_factory.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/dump_tool_factory.py
@@ -1,6 +1,6 @@
-from atat.mindspore.debugger.debugger_config import DebuggerConfig
-from atat.mindspore.dump.api_kbk_dump import ApiKbkDump
-from atat.mindspore.dump.kernel_graph_dump import KernelGraphDump
+from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
+from msprobe.mindspore.dump.api_kbk_dump import ApiKbkDump
+from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump
class DumpToolFactory:
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5508416fde0729d5a8cf31333332464783024c07
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
@@ -0,0 +1,104 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import mindspore as ms
+from msprobe.mindspore.dump.hook_cell.wrap_functional import get_functional_ops, setup_hooks, \
+ HOOKFunctionalOP, HOOKMintOP, HOOKMintNNFunctionalOP
+from msprobe.mindspore.dump.hook_cell.wrap_tensor import get_tensor_ops, wrap_tensor_ops_and_bind, HOOKTensor
+from msprobe.core.common.utils import Const
+
+
+class ApiRegistry:
+ def __init__(self):
+ self.tensor_ori_attr = {}
+ self.functional_ori_attr = {}
+ self.mint_ops_ori_attr = {}
+ self.mint_func_ops_ori_attr = {}
+ self.norm_inner_ops_ori_attr = {}
+
+ self.tensor_hook_attr = {}
+ self.functional_hook_attr = {}
+ self.mint_ops_hook_attr = {}
+ self.mint_func_ops_hook_attr = {}
+ self.norm_inner_ops_hook_attr = {}
+
+ self.norm_inner_ops = ["norm", "square", "sqrt", "is_complex"]
+
+ @staticmethod
+ def store_ori_attr(ori_api_group, api_list, api_ori_attr):
+ for api in api_list:
+ if Const.SEP in api:
+ sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
+ sub_module = getattr(ori_api_group, sub_module_name)
+ api_ori_attr[api] = getattr(sub_module, sub_op)
+ else:
+ api_ori_attr[api] = getattr(ori_api_group, api)
+
+ @staticmethod
+ def set_api_attr(api_group, attr_dict):
+ for api, api_attr in attr_dict.items():
+ if Const.SEP in api:
+ sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
+ sub_module = getattr(api_group, sub_module_name, None)
+ if sub_module is not None:
+ setattr(sub_module, sub_op, api_attr)
+ else:
+ setattr(api_group, api, api_attr)
+
+ def norm_inner_op_set_hook_func(self):
+ self.set_api_attr(ms.ops, self.norm_inner_ops_hook_attr)
+
+ def norm_inner_op_set_ori_func(self):
+ self.set_api_attr(ms.ops, self.norm_inner_ops_ori_attr)
+
+ def api_set_hook_func(self):
+ self.set_api_attr(ms.Tensor, self.tensor_hook_attr)
+ self.set_api_attr(ms.ops, self.functional_hook_attr)
+ self.set_api_attr(ms.mint, self.mint_ops_hook_attr)
+ self.set_api_attr(ms.mint.nn.functional, self.mint_func_ops_hook_attr)
+
+ def api_set_ori_func(self):
+ self.set_api_attr(ms.Tensor, self.tensor_ori_attr)
+ self.set_api_attr(ms.ops, self.functional_ori_attr)
+ self.set_api_attr(ms.mint, self.mint_ops_ori_attr)
+ self.set_api_attr(ms.mint.nn.functional, self.mint_func_ops_ori_attr)
+
+ def initialize_hook(self, hook):
+ self.store_ori_attr(ms.Tensor, get_tensor_ops(), self.tensor_ori_attr)
+ wrap_tensor_ops_and_bind(hook)
+ for attr_name in dir(HOOKTensor):
+ if attr_name.startswith(Const.ATTR_NAME_PREFIX):
+ self.tensor_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKTensor, attr_name)
+
+ functional_ops, mint_ops, mint_func_ops = get_functional_ops()
+ self.store_ori_attr(ms.ops, self.norm_inner_ops, self.norm_inner_ops_ori_attr)
+ self.store_ori_attr(ms.ops, functional_ops, self.functional_ori_attr)
+ self.store_ori_attr(ms.mint, mint_ops, self.mint_ops_ori_attr)
+ self.store_ori_attr(ms.mint.nn.functional, mint_func_ops, self.mint_func_ops_ori_attr)
+ setup_hooks(hook)
+ for attr_name in dir(HOOKFunctionalOP):
+ if attr_name.startswith(Const.ATTR_NAME_PREFIX):
+ self.functional_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKFunctionalOP, attr_name)
+ if attr_name[Const.ATTR_NAME_PREFIX_LEN:] in self.norm_inner_ops:
+ self.norm_inner_ops_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKFunctionalOP, attr_name)
+ for attr_name in dir(HOOKMintOP):
+ if attr_name.startswith(Const.ATTR_NAME_PREFIX):
+ self.mint_ops_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKMintOP, attr_name)
+ for attr_name in dir(HOOKMintNNFunctionalOP):
+ if attr_name.startswith(Const.ATTR_NAME_PREFIX):
+ self.mint_func_ops_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKMintNNFunctionalOP, attr_name)
+
+
+api_register = ApiRegistry()
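Outside the patch itself, the save/patch/restore pattern that `ApiRegistry.store_ori_attr` and `set_api_attr` apply to `ms.ops`/`ms.Tensor` can be illustrated with a minimal, self-contained sketch. The `SimpleNamespace` target and the hooked `add` below are illustrative stand-ins, not the real MindSpore namespaces:

```python
# Minimal sketch of the attribute save/patch/restore pattern used by
# ApiRegistry. SimpleNamespace stands in for a module like ms.ops.
from types import SimpleNamespace

def store_ori_attr(api_group, api_list, api_ori_attr):
    # Save the original callables so they can be restored later.
    for api in api_list:
        api_ori_attr[api] = getattr(api_group, api)

def set_api_attr(api_group, attr_dict):
    # Install (or restore) a set of attributes on the target namespace.
    for api, api_attr in attr_dict.items():
        setattr(api_group, api, api_attr)

ops = SimpleNamespace(add=lambda a, b: a + b)
ori, hooked = {}, {}
store_ori_attr(ops, ["add"], ori)
hooked["add"] = lambda a, b: ori["add"](a, b) * 10  # hooked stand-in
set_api_attr(ops, hooked)
assert ops.add(1, 2) == 30   # hooked version is active
set_api_attr(ops, ori)       # restore originals
assert ops.add(1, 2) == 3
```

The registry keeps both an `*_ori_attr` and a `*_hook_attr` dictionary per namespace for exactly this reason: hooks can be toggled on and off without losing the originals.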
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py
new file mode 100644
index 0000000000000000000000000000000000000000..57ed44111ca0bcdf4d8bafbf27a1373e55bb5480
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py
@@ -0,0 +1,53 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from collections import defaultdict
+
+from mindspore import nn
+from msprobe.core.common.const import Const
+
+
+class HOOKCell(nn.Cell):
+ cell_count = defaultdict(int)
+ g_stop_hook = False
+
+ def __init__(self, build_hook) -> None:
+ super(HOOKCell, self).__init__()
+ self.changed_status = False
+ self.input_kwargs = {}
+ self.prefix = ""
+ if not HOOKCell.g_stop_hook:
+ HOOKCell.g_stop_hook = True
+ self.changed_status = True
+ if hasattr(self, "prefix_op_name_"):
+ self.prefix = self.prefix_op_name_
+
+ HOOKCell.cell_count[self.prefix] += 1
+ self.prefix = self.prefix + str(HOOKCell.cell_count[self.prefix] - 1) + Const.SEP
+ forward_hook, backward_hook = build_hook(self.prefix)
+ self.register_forward_hook(forward_hook)
+ self.register_backward_hook(backward_hook)
+
+    # Override __call__ to toggle the global hook flag around the call.
+    def __call__(self, *args, **kwargs):
+        self.input_kwargs = kwargs
+        try:
+            out = super(HOOKCell, self).__call__(*args, **kwargs)
+        finally:
+            if self.changed_status:
+                self.changed_status = False
+                HOOKCell.g_stop_hook = False
+        return out
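The class-level `g_stop_hook` flag above is a re-entrancy guard: only the outermost `HOOKCell` call flips it, so ops invoked inside an already-hooked call are not hooked again. A standalone sketch of the idea (the `Guarded` class and `calls` list are illustrative, not part of msprobe):

```python
# Sketch of HOOKCell's class-level re-entrancy guard: hook bookkeeping
# runs only for the outermost call; nested calls see the flag set and
# skip it.
class Guarded:
    g_stop_hook = False
    calls = []

    def __call__(self, name, inner=None):
        changed = False
        if not Guarded.g_stop_hook:
            Guarded.g_stop_hook = True
            changed = True
            Guarded.calls.append(name)  # "hook" fires for outermost call only
        try:
            if inner:
                inner()  # simulate a nested wrapped op
        finally:
            if changed:  # only the call that set the flag clears it
                Guarded.g_stop_hook = False

g = Guarded()
g("outer", inner=lambda: g("nested"))
g("second")
assert Guarded.calls == ["outer", "second"]
```

The `try/finally` mirrors the patch: the flag is cleared even if the wrapped call raises, so one exception cannot permanently disable hooking.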
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..089f444b6181f0623c8029926c4808ab22ae27ca
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml
@@ -0,0 +1,925 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+# List of ops that register hooks
+
+
+ops:
+ - adaptive_avg_pool1d
+ - adaptive_avg_pool2d
+ - adaptive_avg_pool3d
+ - adaptive_max_pool1d
+ - adaptive_max_pool2d
+ - avg_pool1d
+ - avg_pool2d
+ - avg_pool3d
+ - batch_norm
+ - bias_add
+ - ctc_greedy_decoder
+ - conv1d
+ - conv2d
+ - conv3d
+ - deformable_conv2d
+ - dense
+ - dropout
+ - dropout1d
+ - dropout2d
+ - dropout3d
+ - flatten
+ - fold
+ - fractional_max_pool3d
+ - lp_pool1d
+ - lp_pool2d
+ - lrn
+ - max_pool2d
+ - max_pool3d
+ - max_unpool1d
+ - max_unpool2d
+ - max_unpool3d
+ - unfold
+ - binary_cross_entropy
+ - binary_cross_entropy_with_logits
+ - cosine_embedding_loss
+ - cross_entropy
+ - ctc_loss
+ - gaussian_nll_loss
+ - hinge_embedding_loss
+ - huber_loss
+ - kl_div
+ - l1_loss
+ - margin_ranking_loss
+ - mse_loss
+ - multi_margin_loss
+ - multilabel_margin_loss
+ - multilabel_soft_margin_loss
+ - nll_loss
+ - smooth_l1_loss
+ - triplet_margin_loss
+ - elu
+ - fast_gelu
+ - gelu
+ - glu
+ - gumbel_softmax
+ - hardshrink
+ - hardsigmoid
+ - hardswish
+ - hardtanh
+ - leaky_relu
+ - log_softmax
+ - logsigmoid
+ - mish
+ - prelu
+ - relu
+ - relu6
+ - celu
+ - rrelu
+ - selu
+ - sigmoid
+ - silu
+ - softmax
+ - softmin
+ - softshrink
+ - softsign
+ - tanh
+ - threshold
+ - cdist
+ - dist
+ - pdist
+ - choice_with_mask
+ - random_categorical
+ - log_uniform_candidate_sampler
+ - uniform_candidate_sampler
+ - affine_grid
+ - bounding_box_decode
+ - bounding_box_encode
+ - col2im
+ - check_valid
+ - crop_and_resize
+ - grid_sample
+ - interpolate
+ - iou
+ - pad
+ - padding
+ - pixel_shuffle
+ - pixel_unshuffle
+ - upsample
+ - abs
+ - absolute
+ - accumulate_n
+ - acos
+ - arccos
+ - acosh
+ - add
+ - addcdiv
+ - addcmul
+ - addmv
+ - addn
+ - angle
+ - arccosh
+ - arcsin
+ - arcsinh
+ - arctan
+ - arctanh
+ - arctan2
+ - asin
+ - asinh
+ - atan
+ - atan2
+ - atanh
+ - atleast_1d
+ - atleast_2d
+ - atleast_3d
+ - bessel_i0
+ - bessel_i0e
+ - bessel_i1
+ - bessel_i1e
+ - bessel_j0
+ - bessel_j1
+ - bessel_k0
+ - bessel_k0e
+ - bessel_k1
+ - bessel_k1e
+ - bessel_y0
+ - bessel_y1
+ - bitwise_and
+ - bitwise_left_shift
+ - bitwise_or
+ - bitwise_right_shift
+ - bitwise_xor
+ - ceil
+ - clamp
+ - clip
+ - combinations
+ - copysign
+ - cos
+ - cosh
+ - cosine_similarity
+ - cov
+ - diag_embed
+ - diff
+ - deg2rad
+ - digamma
+ - div
+ - divide
+ - erf
+ - erfc
+ - erfinv
+ - exp
+ - exp2
+ - expm1
+ - floor
+ - floor_div
+ - floor_mod
+ - float_power
+ - fmod
+ - frac
+ - gcd
+ - hypot
+ - igamma
+ - igammac
+ - imag
+ - i0
+ - inv
+ - invert
+ - lcm
+ - ldexp
+ - lerp
+ - log
+ - log2
+ - log10
+ - log1p
+ - logaddexp
+ - logaddexp2
+ - logical_and
+ - logical_not
+ - logical_or
+ - logical_xor
+ - logit
+ - mul
+ - multiply
+ - mvlgamma
+ - neg
+ - negative
+ - nextafter
+ - polar
+ - polygamma
+ - positive
+ - pow
+ - rad2deg
+ - ravel
+ - real
+ - reciprocal
+ - remainder
+ - rot90
+ - round
+ - rsqrt
+ - sgn
+ - sign
+ - signbit
+ - sin
+ - sinc
+ - sinh
+ - sqrt
+ - square
+ - sub
+ - subtract
+ - t
+ - tan
+ - tanhshrink
+ - trapz
+ - tril_indices
+ - triu_indices
+ - true_divide
+ - trunc
+ - truncate_div
+ - truncate_mod
+ - xdivy
+ - xlogy
+ - zeta
+ - all
+ - amax
+ - amin
+ - aminmax
+ - any
+ - argmax
+ - argmin
+ - cummax
+ - cummin
+ - cumprod
+ - cumsum
+ - fmax
+ - histc
+ - logsumexp
+ - max
+ - mean
+ - median
+ - min
+ - norm
+ - prod
+ - std
+ - std_mean
+ - var
+ - var_mean
+ - argsort
+ - approximate_equal
+ - equal
+ - ge
+ - greater
+ - greater_equal
+ - gt
+ - intopk
+ - isclose
+ - isfinite
+ - isinf
+ - isnan
+ - isneginf
+ - isposinf
+ - isreal
+ - is_complex
+ - le
+ - less
+ - less_equal
+ - lt
+ - maximum
+ - minimum
+ - msort
+ - ne
+ - not_equal
+ - searchsorted
+ - topk
+ - bmm
+ - addbmm
+ - addmm
+ - baddbmm
+ - addr
+ - adjoint
+ - cholesky
+ - cholesky_solve
+ - batch_dot
+ - dot
+ - eig
+ - inner
+ - inverse
+ - geqrf
+ - ger
+ - kron
+ - lu_solve
+ - lu_unpack
+ - matmul
+ - matrix_solve
+ - matrix_band_part
+ - matrix_diag
+ - matrix_diag_part
+ - matrix_set_diag
+ - mm
+ - mv
+ - outer
+ - orgqr
+ - ormqr
+ - pinv
+ - svd
+ - tensor_dot
+ - logdet
+ - slogdet
+ - qr
+ - trace
+ - bartlett_window
+ - blackman_window
+ - hamming_window
+ - hann_window
+ - kaiser_window
+ - eye
+ - fill
+ - full
+ - full_like
+ - linspace
+ - logspace
+ - one_hot
+ - arange
+ - range
+ - heaviside
+ - bernoulli
+ - gamma
+ - laplace
+ - multinomial
+ - multinomial_with_replacement
+ - rand
+ - rand_like
+ - randint
+ - randint_like
+ - randn
+ - randn_like
+ - random_gamma
+ - random_poisson
+ - randperm
+ - standard_laplace
+ - standard_normal
+ - uniform
+ - argwhere
+ - batch_to_space_nd
+ - bincount
+ - block_diag
+ - broadcast_to
+ - cat
+ - channel_shuffle
+ - chunk
+ - column_stack
+ - concat
+ - conj
+ - count_nonzero
+ - deepcopy
+ - diag
+ - diagflat
+ - diagonal
+ - dyn_shape
+ - dsplit
+ - dstack
+ - einsum
+ - expand
+ - expand_dims
+ - flip
+ - fliplr
+ - flipud
+ - gather_d
+ - gather_elements
+ - gather_nd
+ - hsplit
+ - hstack
+ - index_add
+ - index_fill
+ - index_select
+ - inplace_add
+ - inplace_index_add
+ - inplace_sub
+ - inplace_update
+ - masked_fill
+ - masked_select
+ - meshgrid
+ - moveaxis
+ - movedim
+ - narrow
+ - nan_to_num
+ - nansum
+ - normal
+ - nonzero
+ - population_count
+ - rank
+ - repeat_elements
+ - repeat_interleave
+ - reshape
+ - reverse
+ - reverse_sequence
+ - roll
+ - scatter
+ - scatter_nd
+ - select
+ - sequence_mask
+ - shuffle
+ - size
+ - slice
+ - sort
+ - space_to_batch_nd
+ - sparse_segment_mean
+ - split
+ - squeeze
+ - stack
+ - strided_slice
+ - sum
+ - swapaxes
+ - swapdims
+ - tensor_scatter_add
+ - tensor_scatter_div
+ - tensor_scatter_max
+ - tensor_scatter_min
+ - tensor_scatter_mul
+ - tensor_scatter_sub
+ - tensor_scatter_elements
+ - tensor_split
+ - tile
+ - tril
+ - triu
+ - transpose
+ - unbind
+ - unique
+ - unique_consecutive
+ - unique_with_pad
+ - unsorted_segment_max
+ - unsorted_segment_min
+ - unsorted_segment_prod
+ - unsorted_segment_sum
+ - unsqueeze
+ - unstack
+ - view_as_real
+ - vsplit
+ - vstack
+ - where
+ - cross
+ - renorm
+ - is_tensor
+ - scalar_cast
+ - scalar_to_tensor
+ - tuple_to_array
+ - clip_by_global_norm
+ - clip_by_value
+ - assign
+ - assign_add
+ - assign_sub
+ - scatter_add
+ - scatter_div
+ - scatter_max
+ - scatter_min
+ - scatter_mul
+ - scatter_nd_add
+ - scatter_nd_div
+ - scatter_nd_max
+ - scatter_nd_min
+ - scatter_nd_mul
+ - scatter_nd_sub
+ - scatter_update
+ - derivative
+ - jet
+
+tensor:
+ - __abs__
+ - __add__
+ - __and__
+ - __bool__
+ - __eq__
+ - __ge__
+ - __gt__
+ - __iadd__
+ - __ifloordiv__
+ - __imatmul__
+ - __imod__
+ - __imul__
+ - __isub__
+ - __le__
+ - __lt__
+ - __matmul__
+ - __mod__
+ - __mul__
+ - __ne__
+ - __neg__
+ - __or__
+ - __pow__
+ - __radd__
+ - __rmatmul__
+ - __rmod__
+ - __rmul__
+ - __rpow__
+ - __rsub__
+ - __sub__
+ - __truediv__
+ - __xor__
+ - abs
+ - absolute
+ - acos
+ - acosh
+ - add
+ - addbmm
+ - addcdiv
+ - addcmul
+ - addmm
+ - addmv
+ - addr
+ - all
+ - amax
+ - amin
+ - any
+ - arccos
+ - arccosh
+ - argmax
+ - angle
+ - arcsin
+ - arcsinh
+ - arctan
+ - arctanh
+ - argmin
+ - argsort
+ - asin
+ - asinh
+ - atan
+ - atan2
+ - atanh
+ - baddbmm
+ - bernoulli
+ - bincount
+ - bitwise_and
+ - bitwise_or
+ - bitwise_xor
+ - bmm
+ - bool
+ - broadcast_to
+ - ceil
+ - cholesky_solve
+ - cholesky
+ - clamp
+ - clip
+ - conj
+ - copysign
+ - cos
+ - cosh
+ - cross
+ - cummax
+ - cummin
+ - cumprod
+ - cumsum
+ - deg2rad
+ - diag
+ - diagflat
+ - diff
+ - digamma
+ - div
+ - divide
+ - equal
+ - erf
+ - erfc
+ - erfinv
+ - exp
+ - expand_as
+ - expm1
+ - flip
+ - fliplr
+ - flipud
+ - float_power
+ - floor
+ - fmod
+ - frac
+ - gather_elements
+ - ge
+ - geqrf
+ - ger
+ - greater
+ - greater_equal
+ - gt
+ - half
+ - hardshrink
+ - heaviside
+ - histc
+ - hypot
+ - i0
+ - igamma
+ - igammac
+ - imag
+ - index_add
+ - index_fill
+ - index_put
+ - index_select
+ - inner
+ - int
+ - inverse
+ - isclose
+ - isfinite
+ - isinf
+ - isnan
+ - is_complex
+ - is_signed
+ - isneginf
+ - isposinf
+ - isreal
+ - lcm
+ - ldexp
+ - le
+ - lerp
+ - less
+ - less_equal
+ - log
+ - log10
+ - log1p
+ - log2
+ - logaddexp
+ - logaddexp2
+ - logdet
+ - logical_and
+ - logical_not
+ - logical_or
+ - logical_xor
+ - logit
+ - logsumexp
+ - long
+ - lt
+ - masked_fill
+ - masked_scatter
+ - masked_select
+ - matmul
+ - max
+ - maximum
+ - mean
+ - median
+ - min
+ - minimum
+ - moveaxis
+ - movedim
+ - msort
+ - multinomial
+ - multiply
+ - mvlgamma
+ - nan_to_num
+ - nansum
+ - narrow
+ - ne
+ - neg
+ - negative
+ - nelement
+ - new_ones
+ - new_zeros
+ - nextafter
+ - norm
+ - nonzero
+ - not_equal
+ - ormqr
+ - permute
+ - pow
+ - prod
+ - qr
+ - ravel
+ - real
+ - reciprocal
+ - remainder
+ - renorm
+ - rad2deg
+ - tile
+ - repeat_interleave
+ - reshape
+ - round
+ - rot90
+ - rsqrt
+ - sum_to_size
+ - scatter
+ - sgn
+ - short
+ - sigmoid
+ - sign
+ - signbit
+ - sin
+ - sinc
+ - sinh
+ - slogdet
+ - sort
+ - split
+ - sqrt
+ - square
+ - squeeze
+ - std
+ - subtract
+ - svd
+ - swapaxes
+ - swapdims
+ - t
+ - take
+ - tan
+ - tanh
+ - trace
+ - to
+ - topk
+ - tril
+ - tensor_split
+ - transpose
+ - true_divide
+ - trunc
+ - unbind
+ - unique_consecutive
+ - unsqueeze
+ - var
+ - view
+ - where
+ - xlogy
+ - from_numpy
+ - copy
+ - diagonal
+ - flatten
+ - resize
+ - sum
+
+mint.ops:
+ - abs
+ - absolute_import
+ - add
+ - add_ex
+ - all
+ - any
+ - any_ex
+ - arange
+ - argmax
+ - avg_pool2d
+ - baddbmm
+ - baddbmm_ex
+ - batch_norm
+ - binary_cross_entropy_with_logits
+ - bitwise_and
+ - bitwise_or
+ - bitwise_xor
+ - bmm
+ - broadcast_to
+ - cat
+ - cat_ex
+ - ceil
+ - chunk
+ - clamp
+ - conv2d
+ - conv_transpose2d
+ - cos
+ - cross
+ - cummax
+ - cummin
+ - cumsum
+ - div
+ - divide
+ - dropout
+ - embedding
+ - eq
+ - erf
+ - erfinv
+ - exp
+ - flatten
+ - flip
+ - flip_ex
+ - fold
+ - full
+ - functional
+ - gather
+ - gelu
+ - greater
+ - grid_sample
+ - group_norm
+ - gt
+ - index_select
+ - interpolate
+ - isclose
+ - isfinite
+ - layer_norm
+ - le
+ - leaky_relu
+ - less
+ - less_equal
+ - linear
+ - linspace
+ - log
+ - logical_and
+ - logical_not
+ - logical_or
+ - lt
+ - masked_select
+ - matmul
+ - max
+ - max_pool2d
+ - maximum
+ - mean
+ - mean_ex
+ - min
+ - minimum
+ - mul
+ - ne
+ - neg
+ - negative
+ - nn
+ - nonzero
+ - normal
+ - one_hot
+ - ones
+ - ones_ex
+ - ones_like
+ - pad
+ - permute
+ - permute_ex
+ - pow
+ - prod
+ - reciprocal
+ - relu
+ - remainder
+ - repeat_interleave
+ - rsqrt
+ - scatter
+ - scatter_add
+ - searchsorted
+ - sigmoid
+ - silu
+ - sin
+ - softmax
+ - softplus
+ - sort
+ - split
+ - sqrt
+ - sqrt_ex
+ - square
+ - stack
+ - sub
+ - sub_ex
+ - sum
+ - tanh
+ - tile
+ - topk
+ - tril
+ - triu
+ - unfold
+ - unique
+ - where
+ - xlogy
+ - zeros
+ - zeros_ex
+ - zeros_like
+
+mint.nn:
+ - Dropout
+ - Embedding
+ - Fold
+ - LayerNorm
+ - Linear
+ - MaxPool2d
+ - Unfold
+ - Upsample
+
+mint.nn.functional:
+ - absolute_import
+ - avg_pool2d
+ - batch_norm
+ - batch_norm_ex
+ - bce_with_logits
+ - binary_cross_entropy_with_logits
+ - conv_transpose2d
+ - dense
+ - dropout
+ - embedding
+ - fold
+ - gelu
+ - grid_sample
+ - group_norm
+ - interpolate
+ - layer_norm
+ - leaky_relu
+ - linear
+ - max_pool2d
+ - max_pool2d_ex
+ - normal
+ - one_hot
+ - one_hot_ext
+ - pad
+ - relu
+ - sigmoid
+ - silu
+ - softmax
+ - softmax_ex
+ - softplus
+ - tanh
+ - unfold
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/wrap_functional.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/wrap_functional.py
new file mode 100644
index 0000000000000000000000000000000000000000..be3d1bd2545cbdc5d1b03044feed74841d5f2f91
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/wrap_functional.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import os
+import yaml
+import mindspore as ms
+from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
+from msprobe.core.common.utils import Const
+from msprobe.core.common.file_check import FileOpen
+
+
+cur_path = os.path.dirname(os.path.realpath(__file__))
+yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
+
+
+def load_ops_functions():
+ ops_func = {f: getattr(ms.ops, f) for f in dir(ms.ops)}
+ mint_ops_func = {f: getattr(ms.mint, f) for f in dir(ms.mint)}
+ mint_func_ops_func = {f: getattr(ms.mint.nn.functional, f) for f in dir(ms.mint.nn.functional)}
+ return ops_func, mint_ops_func, mint_func_ops_func
+
+
+def get_functional_ops():
+ ops_func, mint_ops_func, mint_func_ops_func = load_ops_functions()
+ with FileOpen(yaml_path, 'r') as f:
+ config = yaml.safe_load(f)
+ WrapFunctionalOps = config.get("ops")
+ WrapMintOps = config.get("mint.ops")
+ WrapMintFunctionalOps = config.get("mint.nn.functional")
+ return (
+ set(WrapFunctionalOps) & set(ops_func.keys()),
+ set(WrapMintOps) & set(mint_ops_func.keys()),
+ set(WrapMintFunctionalOps) & set(mint_func_ops_func.keys())
+ )
+
+
+class HOOKFunctionalOP(object):
+ pass
+
+
+class HOOKMintOP(object):
+ pass
+
+
+class HOOKMintNNFunctionalOP(object):
+ pass
+
+
+class FunctionalOPTemplate(HOOKCell):
+ def __init__(self, op_name, op_dict, prefix, hook):
+ self.op_name = op_name
+ self.op_func = op_dict[op_name]
+ self.prefix_op_name_ = prefix + str(op_name.split(Const.SEP)[-1]) + Const.SEP
+ super().__init__(hook)
+
+ def construct(self, *args, **kwargs):
+ if self.op_name.startswith('dropout'):
+ return args[0] if args else kwargs.get('input')
+ return self.op_func(*args, **kwargs)
+
+
+def wrap_functional_op(op_name, op_dict, prefix, hook):
+ def op_template(*args, **kwargs):
+ return FunctionalOPTemplate(op_name, op_dict, prefix, hook)(*args, **kwargs)
+ return op_template
+
+
+def wrap_functional_ops_and_bind(ops, op_dict, prefix, hook, hook_class):
+ for op_name in ops:
+ if callable(op_dict[op_name]):
+ setattr(hook_class, Const.ATTR_NAME_PREFIX + op_name, wrap_functional_op(op_name, op_dict, prefix, hook))
+
+
+def setup_hooks(hook):
+ functional_ops, mint_ops, mint_func_ops = get_functional_ops()
+ wrap_functional_ops_and_bind(
+ functional_ops, {f: getattr(ms.ops, f) for f in dir(ms.ops)}, "Functional.", hook, HOOKFunctionalOP)
+ wrap_functional_ops_and_bind(
+ mint_ops, {f: getattr(ms.mint, f) for f in dir(ms.mint)}, "Mint.", hook, HOOKMintOP)
+ wrap_functional_ops_and_bind(
+ mint_func_ops, {f: getattr(ms.mint.nn.functional, f) for f in dir(ms.mint.nn.functional)}, "MintFunctional.", hook, HOOKMintNNFunctionalOP)
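The wrap-and-bind machinery in `wrap_functional.py` can be reduced to a small standalone sketch: each op named in the YAML is replaced by a thin template that runs a hook and then delegates to the saved original. The names below (`wrap_op`, `recorded`) are illustrative stand-ins, not the real msprobe API:

```python
# Hedged sketch of the wrap-and-bind approach: a closure captures the
# original callable, runs a hook, then delegates.
import math

recorded = []

def wrap_op(op_name, op_dict, hook):
    original = op_dict[op_name]
    def op_template(*args, **kwargs):
        hook(op_name, args)            # e.g. dump inputs before the call
        return original(*args, **kwargs)
    return op_template

op_dict = {"sqrt": math.sqrt}
hooked = {name: wrap_op(name, op_dict, lambda n, a: recorded.append(n))
          for name in op_dict}
assert hooked["sqrt"](9.0) == 3.0
assert recorded == ["sqrt"]
```

Binding each `op_template` onto a holder class (`HOOKFunctionalOP` etc.) with the `Const.ATTR_NAME_PREFIX` naming convention is what lets `ApiRegistry` later copy the wrapped attributes onto `ms.ops` and friends.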
+
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/wrap_tensor.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/wrap_tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae6a9a979dd37851d5d33bc0edf6ee37095e1e7e
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/wrap_tensor.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import os
+import yaml
+import mindspore as ms
+
+from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
+from msprobe.core.common.utils import Const
+from msprobe.core.common.file_check import FileOpen
+
+cur_path = os.path.dirname(os.path.realpath(__file__))
+yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
+with FileOpen(yaml_path, 'r') as f:
+ WrapTensorOps = yaml.safe_load(f).get('tensor')
+
+TensorFunc = {}
+for f in dir(ms.Tensor):
+ TensorFunc[f] = getattr(ms.Tensor, f)
+
+
+def get_tensor_ops():
+    # WrapTensorOps is only read here, so no global declaration is needed.
+    _tensor_ops = dir(ms.Tensor)
+    return set(WrapTensorOps) & set(_tensor_ops)
+
+
+class HOOKTensor(object):
+ pass
+
+
+class TensorOPTemplate(HOOKCell):
+
+ def __init__(self, op_name, hook):
+ self.op_name_ = op_name
+ self.prefix_op_name_ = "Tensor." + str(op_name) + Const.SEP
+ super().__init__(hook)
+
+ def construct(self, *args, **kwargs):
+ return TensorFunc[str(self.op_name_)](*args, **kwargs)
+
+
+def wrap_tensor_op(op_name, hook):
+ def tensor_op_template(*args, **kwargs):
+ return TensorOPTemplate(op_name, hook)(*args, **kwargs)
+
+ return tensor_op_template
+
+
+def wrap_tensor_ops_and_bind(hook):
+ _tensor_ops = get_tensor_ops()
+ for op_name in _tensor_ops:
+ if callable(TensorFunc[op_name]):
+ setattr(HOOKTensor, Const.ATTR_NAME_PREFIX + str(op_name), wrap_tensor_op(op_name, hook))
diff --git a/debug/accuracy_tools/atat/mindspore/dump/kernel_graph_dump.py b/debug/accuracy_tools/msprobe/mindspore/dump/kernel_graph_dump.py
similarity index 92%
rename from debug/accuracy_tools/atat/mindspore/dump/kernel_graph_dump.py
rename to debug/accuracy_tools/msprobe/mindspore/dump/kernel_graph_dump.py
index 190e6bc4d5591f9ed5aa466c2be631fb224fc89b..8320ee0906458734b29b9b911351739fefb77163 100644
--- a/debug/accuracy_tools/atat/mindspore/dump/kernel_graph_dump.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/kernel_graph_dump.py
@@ -1,9 +1,9 @@
import os
import json
-from atat.core.common.utils import make_dump_path_if_not_exists
-from atat.mindspore.debugger.debugger_config import DebuggerConfig
-from atat.core.common.log import logger
-from atat.core.common.file_check import FileOpen
+from msprobe.core.common.utils import make_dump_path_if_not_exists
+from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
+from msprobe.core.common.log import logger
+from msprobe.core.common.file_check import FileOpen
class KernelGraphDump:
diff --git a/debug/accuracy_tools/atat/mindspore/ms_config.py b/debug/accuracy_tools/msprobe/mindspore/ms_config.py
similarity index 67%
rename from debug/accuracy_tools/atat/mindspore/ms_config.py
rename to debug/accuracy_tools/msprobe/mindspore/ms_config.py
index 02cead32f1f5fc2b00c47d75ac9d9950a3cd258d..c0ef6bb6c00aab426fd42a11c3bc2436440a4a6a 100644
--- a/debug/accuracy_tools/atat/mindspore/ms_config.py
+++ b/debug/accuracy_tools/msprobe/mindspore/ms_config.py
@@ -1,6 +1,7 @@
import json
-from atat.core.common_config import CommonConfig, BaseConfig
-from atat.core.common.file_check import FileOpen
+from msprobe.core.common_config import CommonConfig, BaseConfig
+from msprobe.core.common.file_check import FileOpen
+from msprobe.core.common.const import Const
class TensorConfig(BaseConfig):
@@ -31,39 +32,43 @@ class StatisticsConfig(BaseConfig):
if self.data_mode is not None and len(self.data_mode) > 0:
if len(self.data_mode) > 1 or self.data_mode[0] not in ["all", "input", "output"]:
raise Exception("data_mode must be all, input or output")
+ if self.summary_mode and self.summary_mode not in ["statistics", "md5"]:
+ raise Exception("summary_mode is invalid")
-class OverflowCheck(BaseConfig):
+class OverflowCheckConfig(BaseConfig):
def __init__(self, json_config):
super().__init__(json_config)
- self.file_format = None
- self.check_mode = json_config.get("check_mode")
+ self.data_mode = ["all"]
self._check_config()
def _check_config(self):
- if self.data_mode is not None and len(self.data_mode) > 0:
- if len(self.data_mode) > 1 or self.data_mode[0] not in ["all", "input", "output"]:
- raise Exception("data_mode must be all, input or output")
+ if self.overflow_nums is not None and not isinstance(self.overflow_nums, int):
+ raise Exception("overflow_nums is invalid, it should be an integer")
+ if self.overflow_nums is not None and self.overflow_nums != -1 and self.overflow_nums <= 0:
+            raise Exception("overflow_nums should be -1 or a positive integer")
if self.check_mode and self.check_mode not in ["all", "aicore", "atomic"]:
raise Exception("check_mode is invalid")
+TaskDict = {
+ Const.TENSOR: TensorConfig,
+ Const.STATISTICS: StatisticsConfig,
+ Const.OVERFLOW_CHECK: OverflowCheckConfig,
+}
+
+
def parse_common_config(json_config):
return CommonConfig(json_config)
def parse_task_config(task, json_config):
- task_map = json_config[task]
+ task_map = json_config.get(task)
if not task_map:
task_map = dict()
- if task == "tensor":
- return TensorConfig(task_map)
- elif task == "statistics":
- return StatisticsConfig(task_map)
- elif task == "overflow_check":
- return OverflowCheck(task_map)
- else:
+ if task not in TaskDict:
raise Exception("task is invalid.")
+ return TaskDict.get(task)(task_map)
def parse_json_config(json_file_path):
@@ -73,6 +78,6 @@ def parse_json_config(json_file_path):
json_config = json.load(file)
common_config = parse_common_config(json_config)
if not common_config.task:
- common_config.task = "statistics"
+ common_config.task = Const.STATISTICS
task_config = parse_task_config(common_config.task, json_config)
return common_config, task_config
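The `TaskDict` change above replaces an if/elif chain with a dispatch table mapping task names to config classes. A simplified, self-contained illustration of that pattern (class names and the `TASKS` dict are stand-ins for the patch's `TaskDict`):

```python
# Illustrative sketch of the task -> config-class dispatch introduced in
# ms_config.py; dict subclasses stand in for the real config classes.
class TensorConfig(dict):
    pass

class StatisticsConfig(dict):
    pass

TASKS = {"tensor": TensorConfig, "statistics": StatisticsConfig}

def parse_task_config(task, json_config):
    # Missing task sections fall back to an empty config, as in the patch.
    task_map = json_config.get(task) or {}
    if task not in TASKS:
        raise Exception("task is invalid.")
    return TASKS[task](task_map)

cfg = parse_task_config("statistics", {"statistics": {"summary_mode": "md5"}})
assert isinstance(cfg, StatisticsConfig)
assert cfg["summary_mode"] == "md5"
```

Adding a new task type then only requires registering one entry in the table rather than extending a branch chain.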
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/__init__.py b/debug/accuracy_tools/msprobe/mindspore/overflow_check/__init__.py
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/__init__.py
rename to debug/accuracy_tools/msprobe/mindspore/overflow_check/__init__.py
diff --git a/debug/accuracy_tools/atat/mindspore/overflow_check/kernel_graph_overflow_check.py b/debug/accuracy_tools/msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py
similarity index 89%
rename from debug/accuracy_tools/atat/mindspore/overflow_check/kernel_graph_overflow_check.py
rename to debug/accuracy_tools/msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py
index 7a677eb3c70c583e745785d3b8b988cbbe93e7dd..6640608735d98c17c1b544b58183224a1cd4ba55 100644
--- a/debug/accuracy_tools/atat/mindspore/overflow_check/kernel_graph_overflow_check.py
+++ b/debug/accuracy_tools/msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py
@@ -1,9 +1,9 @@
import os
import json
-from atat.core.common.utils import make_dump_path_if_not_exists
-from atat.mindspore.debugger.debugger_config import DebuggerConfig
-from atat.core.common.log import logger
-from atat.core.common.file_check import FileOpen
+from msprobe.core.common.utils import make_dump_path_if_not_exists
+from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
+from msprobe.core.common.log import logger
+from msprobe.core.common.file_check import FileOpen
class KernelGraphOverflowCheck:
diff --git a/debug/accuracy_tools/atat/mindspore/overflow_check/overflow_check_tool_factory.py b/debug/accuracy_tools/msprobe/mindspore/overflow_check/overflow_check_tool_factory.py
similarity index 81%
rename from debug/accuracy_tools/atat/mindspore/overflow_check/overflow_check_tool_factory.py
rename to debug/accuracy_tools/msprobe/mindspore/overflow_check/overflow_check_tool_factory.py
index fe53359be1ba1ecb73fb84138228415f68e1c2ce..d809c714211aa34f41588bba332b55ed808b5376 100644
--- a/debug/accuracy_tools/atat/mindspore/overflow_check/overflow_check_tool_factory.py
+++ b/debug/accuracy_tools/msprobe/mindspore/overflow_check/overflow_check_tool_factory.py
@@ -1,5 +1,5 @@
-from atat.mindspore.debugger.debugger_config import DebuggerConfig
-from atat.mindspore.overflow_check.kernel_graph_overflow_check import KernelGraphOverflowCheck
+from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
+from msprobe.mindspore.overflow_check.kernel_graph_overflow_check import KernelGraphOverflowCheck
class OverflowCheckToolFactory:
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
new file mode 100644
index 0000000000000000000000000000000000000000..50776aaf1097339e7c6d98944db7ddf2d2238c5f
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -0,0 +1,152 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import os
+import copy
+from pathlib import Path
+import functools
+from collections import defaultdict
+
+from msprobe.core.data_dump.data_collector import build_data_collector
+from msprobe.core.data_dump.scope import BaseScope
+from msprobe.mindspore.common.utils import get_rank_if_initialized
+from msprobe.core.common.file_check import FileChecker, FileCheckConst, check_path_before_create
+from msprobe.mindspore.common.log import logger
+from msprobe.core.common.utils import Const
+from msprobe.core.common.exceptions import DistributedNotInitializedError
+from msprobe.mindspore.dump.hook_cell.api_registry import api_register
+from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
+from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
+
+
+class Service:
+ def __init__(self, config):
+ self.model = None
+ self.config = copy.deepcopy(config)
+ self.config.level = self.config.level_ori
+ self.data_collector = build_data_collector(self.config)
+ self.switch = False
+ self.current_iter = 0
+ self.first_start = True
+ self.current_rank = None
+ self.dump_iter_dir = None
+ self.start_call = False
+
+ def build_hook(self, module_type, name):
+ def forward_hook(api_or_module_name, module, input, output):
+ self.data_collector.visit_and_clear_overflow_status(api_or_module_name)
+ if not self.switch:
+ return None
+ if self.data_collector:
+ module_input_output = ModuleForwardInputsOutputs(args=input, kwargs=module.input_kwargs, output=output)
+ self.data_collector.forward_data_collect(api_or_module_name, module, pid, module_input_output)
+ if self.data_collector.if_return_forward_new_output():
+ return self.data_collector.get_forward_new_output()
+ del module.input_kwargs
+ return output
+
+ def backward_hook(api_or_module_name, module, grad_input, grad_output):
+ self.data_collector.visit_and_clear_overflow_status(api_or_module_name)
+ if not self.switch:
+ return
+ if self.data_collector:
+ module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output)
+ self.data_collector.backward_data_collect(api_or_module_name, module, pid, module_input_output)
+
+ pid = os.getpid()
+ forward_name_template = name + Const.FORWARD
+ backward_name_template = name + Const.BACKWARD
+ forward_hook = functools.partial(forward_hook, forward_name_template)
+ backward_hook = functools.partial(backward_hook, backward_name_template)
+
+ def wrap_forward_hook(*args, **kwargs):
+ return forward_hook(*args, **kwargs)
+
+ def wrap_backward_hook(*args, **kwargs):
+ return backward_hook(*args, **kwargs)
+
+ return wrap_forward_hook, wrap_backward_hook
+
+ def step(self):
+ self.current_iter += 1
+ self.data_collector.update_iter(self.current_iter)
+ HOOKCell.cell_count = defaultdict(int)
+
+ def start(self, model=None):
+ self.model = model
+ self.start_call = True
+        logger.info("msprobe: debugger.start() called successfully")
+ if self.config.step and self.current_iter > max(self.config.step):
+ self.stop()
+ raise Exception("msprobe: exit after iteration {}".format(max(self.config.step)))
+ if self.config.step and self.current_iter not in self.config.step:
+ return
+ if self.first_start:
+ try:
+ self.current_rank = get_rank_if_initialized()
+ except DistributedNotInitializedError:
+ self.current_rank = None
+
+ if self.config.rank and self.current_rank not in self.config.rank:
+ return
+ self.register_hook_new()
+ self.first_start = False
+ self.switch = True
+ logger.info(f"Dump switch is turned on at step {self.current_iter}. ")
+ self.create_dirs()
+ logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")
+
+ def stop(self):
+        logger.info("msprobe: debugger.stop() called successfully. "
+                    "Call debugger.start() to turn the dump switch on again.")
+ if not self.start_call:
+            logger.error("msprobe: debugger.start() was not called in the current scope.")
+            raise Exception("debugger.start() was not called in the current scope.")
+ if self.config.step and self.current_iter not in self.config.step:
+ return
+ if self.config.rank and self.current_rank not in self.config.rank:
+ return
+ self.switch = False
+ self.start_call = False
+ self.data_collector.write_json()
+
+ def create_dirs(self):
+ check_path_before_create(self.config.dump_path)
+ if not os.path.exists(self.config.dump_path):
+ Path(self.config.dump_path).mkdir(mode=0o750, exist_ok=True)
+ file_check = FileChecker(self.config.dump_path, FileCheckConst.DIR)
+ file_check.common_check()
+ self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")
+ cur_rank = self.current_rank if self.current_rank is not None else ''
+ dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
+ if not os.path.exists(dump_dir):
+ Path(dump_dir).mkdir(mode=0o750, parents=True, exist_ok=True)
+ if self.config.task in self.data_collector.tasks_need_tensor_data:
+ dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
+ Path(dump_data_dir).mkdir(mode=0o750, exist_ok=True)
+ else:
+ dump_data_dir = None
+
+ dump_file_path = os.path.join(dump_dir, "dump.json")
+ stack_file_path = os.path.join(dump_dir, "stack.json")
+ construct_file_path = os.path.join(dump_dir, "construct.json")
+ self.data_collector.update_dump_paths(
+ dump_file_path, stack_file_path, construct_file_path, dump_data_dir, None)
+
+ def register_hook_new(self):
+ logger.info("The {} hook function is successfully mounted to the model.".format(self.config.task))
+ if self.config.level == "L1":
+ api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
+ api_register.api_set_hook_func()
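The new `service.py` pre-binds the per-API dump name into each hook with `functools.partial`, so frameworks that invoke hooks as `hook(module, ...)` still see the right label. A standalone sketch of that binding, with illustrative names (`events` stands in for the data collector; this is not the msprobe API):

```python
import functools

def build_hook(name):
    """Return forward/backward callbacks with the dump name pre-bound."""
    events = []  # stand-in for the data collector

    def forward_hook(api_name, module, inputs, output):
        events.append((api_name, "fwd"))
        return output

    def backward_hook(api_name, module, grad_in, grad_out):
        events.append((api_name, "bwd"))

    # Pre-bind the per-API name; callers only pass (module, inputs, output).
    fwd = functools.partial(forward_hook, name + ".forward")
    bwd = functools.partial(backward_hook, name + ".backward")
    return fwd, bwd, events

fwd, bwd, events = build_hook("Functional.relu")
fwd(None, (1,), 2)
bwd(None, (0.1,), (0.2,))
```

Binding the name at registration time avoids storing per-call state on the module itself.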
diff --git a/debug/accuracy_tools/atat/mindspore/task_handler_factory.py b/debug/accuracy_tools/msprobe/mindspore/task_handler_factory.py
similarity index 68%
rename from debug/accuracy_tools/atat/mindspore/task_handler_factory.py
rename to debug/accuracy_tools/msprobe/mindspore/task_handler_factory.py
index 4f80e4e89c92156762ea0e4c4ed3302cc5c31f5f..7b7e6fd889c775a4491e824c1f73e6021cb99350 100644
--- a/debug/accuracy_tools/atat/mindspore/task_handler_factory.py
+++ b/debug/accuracy_tools/msprobe/mindspore/task_handler_factory.py
@@ -1,6 +1,6 @@
-from atat.mindspore.debugger.debugger_config import DebuggerConfig
-from atat.mindspore.dump.dump_tool_factory import DumpToolFactory
-from atat.mindspore.overflow_check.overflow_check_tool_factory import OverflowCheckToolFactory
+from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
+from msprobe.mindspore.dump.dump_tool_factory import DumpToolFactory
+from msprobe.mindspore.overflow_check.overflow_check_tool_factory import OverflowCheckToolFactory
class TaskHandlerFactory:
diff --git a/debug/accuracy_tools/atat/atat.py b/debug/accuracy_tools/msprobe/msprobe.py
similarity index 74%
rename from debug/accuracy_tools/atat/atat.py
rename to debug/accuracy_tools/msprobe/msprobe.py
index 90f8215b102d4f120b879773797c4b4864b25d6b..9c96b193994410403e5be9bc592e76999f6bc8b2 100644
--- a/debug/accuracy_tools/atat/atat.py
+++ b/debug/accuracy_tools/msprobe/msprobe.py
@@ -15,19 +15,20 @@
import argparse
import sys
-from atat.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command
-from atat.pytorch.parse_tool.cli import parse as cli_parse
-from atat.pytorch.api_accuracy_checker.run_ut.multi_run_ut import prepare_config, run_parallel_ut
-from atat.pytorch.api_accuracy_checker.compare.api_precision_compare import _api_precision_compare_parser, \
+from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command
+from msprobe.pytorch.parse_tool.cli import parse as cli_parse
+from msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut import prepare_config, run_parallel_ut
+from msprobe.pytorch.api_accuracy_checker.compare.api_precision_compare import _api_precision_compare_parser, \
_api_precision_compare_command
-from atat.pytorch.api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_parser, \
+from msprobe.pytorch.api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_parser, \
_run_overflow_check_command
+from msprobe.pytorch.torchair_compare.torchair_compare_cli import torchair_compare_parser, torchair_compare_cli
def main():
parser = argparse.ArgumentParser(
formatter_class=argparse.RawDescriptionHelpFormatter,
- description="atat(ascend training accuracy tools), [Powered by MindStudio].\n"
+ description="msprobe(mindstudio probe), [Powered by MindStudio].\n"
"Providing one-site accuracy difference debugging toolkit for training on Ascend Devices.\n"
f"For any issue, refer README.md first",
)
@@ -46,6 +47,8 @@ def main():
help='Number of splits for parallel processing. Range: 1-64')
_api_precision_compare_parser(api_precision_compare_cmd_parser)
_run_overflow_check_parser(run_overflow_check_cmd_parser)
+ torchair_compare_cmd_parser = subparsers.add_parser("torchair_compare")
+ torchair_compare_parser(torchair_compare_cmd_parser)
if len(sys.argv) == 1:
parser.print_help()
sys.exit(0)
@@ -61,6 +64,8 @@ def main():
_api_precision_compare_command(args)
elif sys.argv[3] == "run_overflow_check":
_run_overflow_check_command(args)
+ elif sys.argv[3] == "torchair_compare":
+ torchair_compare_cli(args)
if __name__ == "__main__":
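The CLI wiring above follows the standard `argparse` sub-command pattern: one parser per subcommand registered via `add_subparsers`. A minimal sketch under assumed command names (not the exact msprobe flags):

```python
import argparse

def build_parser():
    """Assemble a parser with two illustrative subcommands."""
    parser = argparse.ArgumentParser(prog="msprobe-sketch")
    subparsers = parser.add_subparsers(dest="command")
    run_ut = subparsers.add_parser("run_ut")
    run_ut.add_argument("-j", "--jobs", type=int, default=1)
    compare = subparsers.add_parser("compare")
    compare.add_argument("--npu-csv", required=True)
    return parser

parser = build_parser()
args = parser.parse_args(["run_ut", "-j", "8"])
```

Dispatching on `args.command` (populated via `dest="command"`) is a sturdier alternative to indexing `sys.argv[3]` directly, since it does not depend on argument position.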
diff --git a/debug/accuracy_tools/atat/pytorch/__init__.py b/debug/accuracy_tools/msprobe/pytorch/__init__.py
similarity index 74%
rename from debug/accuracy_tools/atat/pytorch/__init__.py
rename to debug/accuracy_tools/msprobe/pytorch/__init__.py
index 482e850f7baa845bd831e0d4728e841661b9345b..58ab1ac35a3dcb25474c1da02f4d84f9de526346 100644
--- a/debug/accuracy_tools/atat/pytorch/__init__.py
+++ b/debug/accuracy_tools/msprobe/pytorch/__init__.py
@@ -2,3 +2,4 @@ from .debugger.precision_debugger import PrecisionDebugger
from .common.utils import seed_all
from .compare.acc_compare import compare
from .compare.distributed_compare import compare_distributed
+from .visualization.graph_service import compare_graph, build_graph
diff --git a/debug/accuracy_tools/atat/pytorch/advisor/advisor.py b/debug/accuracy_tools/msprobe/pytorch/advisor/advisor.py
similarity index 93%
rename from debug/accuracy_tools/atat/pytorch/advisor/advisor.py
rename to debug/accuracy_tools/msprobe/pytorch/advisor/advisor.py
index 43b3f40f97948808a987bd4211530cfca2cb025a..fe5bf1efb1f0e520c9cf77cd1cb808650e66e548 100644
--- a/debug/accuracy_tools/atat/pytorch/advisor/advisor.py
+++ b/debug/accuracy_tools/msprobe/pytorch/advisor/advisor.py
@@ -17,12 +17,12 @@
import os
-from atat.pytorch.advisor.advisor_result import AdvisorResult
-from atat.pytorch.advisor.advisor_const import AdvisorConst
-from atat.pytorch.common.log import logger
-from atat.core.common.utils import CompareException
-from atat.core.common.file_check import FileChecker
-from atat.core.common.const import Const, CompareConst, FileCheckConst
+from msprobe.pytorch.advisor.advisor_result import AdvisorResult
+from msprobe.pytorch.advisor.advisor_const import AdvisorConst
+from msprobe.pytorch.common.log import logger
+from msprobe.core.common.utils import CompareException
+from msprobe.core.common.file_utils import FileChecker
+from msprobe.core.common.const import Const, CompareConst, FileCheckConst
class Advisor:
"""
diff --git a/debug/accuracy_tools/atat/pytorch/advisor/advisor_const.py b/debug/accuracy_tools/msprobe/pytorch/advisor/advisor_const.py
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/advisor/advisor_const.py
rename to debug/accuracy_tools/msprobe/pytorch/advisor/advisor_const.py
diff --git a/debug/accuracy_tools/atat/pytorch/advisor/advisor_result.py b/debug/accuracy_tools/msprobe/pytorch/advisor/advisor_result.py
similarity index 90%
rename from debug/accuracy_tools/atat/pytorch/advisor/advisor_result.py
rename to debug/accuracy_tools/msprobe/pytorch/advisor/advisor_result.py
index a24fa2a1155501d91eb2462528f71824f091f318..58b76d3c8479321eb354e27c785a38d4ca3d8aaa 100644
--- a/debug/accuracy_tools/atat/pytorch/advisor/advisor_result.py
+++ b/debug/accuracy_tools/msprobe/pytorch/advisor/advisor_result.py
@@ -17,10 +17,10 @@
import os
import time
-from atat.pytorch.advisor.advisor_const import AdvisorConst
-from atat.pytorch.common.log import logger
-from atat.core.common.const import Const, FileCheckConst
-from atat.core.common.file_check import change_mode
+from msprobe.pytorch.advisor.advisor_const import AdvisorConst
+from msprobe.pytorch.common.log import logger
+from msprobe.core.common.const import Const, FileCheckConst
+from msprobe.core.common.file_utils import change_mode
class AdvisorResult:
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/.keep b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/.keep
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/.keep
rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/.keep
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/__init__.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/__init__.py
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/__init__.py
rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/__init__.py
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/.keep b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/.keep
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/.keep
rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/.keep
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/__init__.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/__init__.py
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/__init__.py
rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/__init__.py
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/config.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py
similarity index 58%
rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/config.py
rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py
index 0aceb691b2530bd8117396ef34c9cd4154c76716..ca6bb1627e736ded4050bc3fd0bab9f5d64a1d51 100644
--- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/config.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py
@@ -1,10 +1,8 @@
import os
import yaml
-from atat.pytorch.api_accuracy_checker.common.utils import check_file_or_directory_path
-from atat.pytorch.hook_module.utils import WrapFunctionalOps, WrapTensorOps, WrapTorchOps
-from atat.core.common.file_check import FileOpen
-
-WrapApi = set(WrapFunctionalOps) | set(WrapTensorOps) | set(WrapTorchOps)
+from msprobe.pytorch.api_accuracy_checker.common.utils import check_file_or_directory_path
+from msprobe.core.common.file_utils import FileOpen
+from msprobe.pytorch.pt_config import RunUTConfig
class Config:
@@ -14,9 +12,17 @@ class Config:
config = yaml.safe_load(file)
self.config = {key: self.validate(key, value) for key, value in config.items()}
- def validate(self, key, value):
+ def __getattr__(self, item):
+ return self.config[item]
+
+ def __str__(self):
+ return '\n'.join(f"{key}={value}" for key, value in self.config.items())
+
+ @staticmethod
+ def validate(key, value):
validators = {
'white_list': list,
+ 'black_list': list,
'error_data_path': str,
'precision': int
}
@@ -27,22 +33,13 @@ class Config:
if key == 'precision' and value < 0:
raise ValueError("precision must be greater than 0")
if key == 'white_list':
- if not isinstance(value, list):
- raise ValueError("white_list must be a list type")
- if not all(isinstance(i, str) for i in value):
- raise ValueError("All elements in white_list must be of str type")
- invalid_api = [i for i in value if i not in WrapApi]
- if invalid_api:
- raise ValueError(
- f"{', '.join(invalid_api)} is not in support_wrap_ops.yaml, please check the white_list")
+ RunUTConfig.check_filter_list_config(key, value)
+ if key == 'black_list':
+ RunUTConfig.check_filter_list_config(key, value)
+ if key == 'error_data_path':
+ RunUTConfig.check_error_data_path_config(value)
return value
- def __getattr__(self, item):
- return self.config[item]
-
- def __str__(self):
- return '\n'.join(f"{key}={value}" for key, value in self.config.items())
-
cur_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
yaml_path = os.path.join(cur_path, "config.yaml")
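`Config.validate` drives its type checks from a small validator table, with extra per-key rules layered on top. A hedged sketch of the same approach — the keys mirror the diff's `config.yaml`, but the error messages and rules are illustrative, not the exact msprobe checks:

```python
# Validator table: config key -> expected Python type.
VALIDATORS = {
    "white_list": list,
    "black_list": list,
    "error_data_path": str,
    "precision": int,
}

def validate(key, value):
    expected = VALIDATORS.get(key)
    if expected is None:
        raise KeyError(f"unknown config key: {key}")
    if not isinstance(value, expected):
        raise ValueError(f"{key} must be of {expected.__name__} type")
    # Per-key rule layered on top of the type check.
    if key == "precision" and value < 0:
        raise ValueError("precision must be a non-negative integer")
    return value

# Validate a whole config dict in one comprehension, as Config.__init__ does.
config = {k: validate(k, v)
          for k, v in {"white_list": [], "precision": 14}.items()}
```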
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py
similarity index 96%
rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py
rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py
index 9e1b02c0154760f468c8d603e4bf385777258434..7855a51e4b472fd3d22b5ad9ee2f7e4d4a0d39e5 100644
--- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py
@@ -28,10 +28,10 @@ except ImportError:
else:
IS_GPU = False
-from atat.pytorch.common.log import logger
-from atat.core.common.file_check import FileChecker, FileOpen, change_mode, create_directory
-from atat.core.common.const import Const, FileCheckConst
-from atat.core.common.utils import CompareException
+from msprobe.pytorch.common.log import logger
+from msprobe.core.common.file_utils import FileChecker, FileOpen, change_mode, create_directory
+from msprobe.core.common.const import Const, FileCheckConst
+from msprobe.core.common.utils import CompareException
class DumpException(CompareException):
@@ -166,6 +166,7 @@ def initialize_save_path(save_path, dir_name):
os.mkdir(data_path, mode=FileCheckConst.DATA_DIR_AUTHORITY)
data_path_checker = FileChecker(data_path, FileCheckConst.DIR)
data_path_checker.common_check()
+ return data_path
def write_pt(file_path, tensor):
diff --git a/debug/accuracy_tools/atat/pytorch/debugger/__init__.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/__init__.py
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/debugger/__init__.py
rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/__init__.py
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py
similarity index 98%
rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/algorithm.py
rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py
index 3982c167cca13617dc19332621b5ebc9fce12997..1bb19cc048e88c9353a19069031bb8acfdae05e9 100644
--- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/algorithm.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py
@@ -2,8 +2,8 @@
import torch
import numpy as np
-from atat.pytorch.api_accuracy_checker.compare.compare_utils import ULP_PARAMETERS
-from atat.core.common.const import CompareConst
+from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import ULP_PARAMETERS
+from msprobe.core.common.const import CompareConst
DEFAULT_THRESHOLD = 1
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py
similarity index 97%
rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py
rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py
index f73c83c4881a5c7b9664314a2051ba8ba770c0a1..d85abfe9ddaf24ce19ecfef965886db93e16761c 100644
--- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py
@@ -7,19 +7,19 @@ from collections import namedtuple
import torch
import pandas as pd
-from atat.pytorch.api_accuracy_checker.common.utils import write_csv
-from atat.pytorch.api_accuracy_checker.common.config import msCheckerConfig
-from atat.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECISION_COMPARE_RESULT_FILE_NAME, \
+from msprobe.pytorch.api_accuracy_checker.common.utils import write_csv
+from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
+from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECISION_COMPARE_RESULT_FILE_NAME, \
API_PRECISION_COMPARE_DETAILS_FILE_NAME, BENCHMARK_COMPARE_SUPPORT_LIST, API_PRECISION_COMPARE_UNSUPPORT_LIST, \
ApiPrecisionCompareColumn, AbsoluteStandardApi, BinaryStandardApi, ULPStandardApi, ThousandthStandardApi, \
BINARY_COMPARE_UNSUPPORT_LIST, ULP_COMPARE_SUPPORT_LIST, convert_str_to_float, CompareMessage, is_inf_or_nan, \
check_inf_or_nan
-from atat.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn
-from atat.pytorch.api_accuracy_checker.run_ut.run_ut import get_validated_result_csv_path
-from atat.core.common.file_check import FileChecker, change_mode, check_path_before_create, create_directory
-from atat.pytorch.common.log import logger
-from atat.core.common.utils import CompareException
-from atat.core.common.const import CompareConst, FileCheckConst
+from msprobe.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn
+from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import get_validated_result_csv_path
+from msprobe.core.common.file_utils import FileChecker, change_mode, check_path_before_create, create_directory
+from msprobe.pytorch.common.log import logger
+from msprobe.core.common.utils import CompareException
+from msprobe.core.common.const import CompareConst, FileCheckConst
CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path'])
BenchmarkInf_Nan_Consistency = namedtuple('BenchmarkInf_Nan_Consistency', ['small_value_inf_nan_consistency',
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml
rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml
rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py
similarity index 97%
rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py
rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py
index ca35c8ed5daccf2c7c371ea1ae625c1e2435d743..ee49588288efc0a33c086913cc5624059de82272 100644
--- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py
@@ -3,18 +3,18 @@ import os
from collections import namedtuple
import torch
import numpy as np
-from atat.pytorch.common.log import logger
-from atat.pytorch.api_accuracy_checker.common.utils import get_json_contents, write_csv
-from atat.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, \
+from msprobe.pytorch.common.log import logger
+from msprobe.pytorch.api_accuracy_checker.common.utils import get_json_contents, write_csv
+from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, \
DETAIL_TEST_ROWS, precision_configs, BENCHMARK_COMPARE_SUPPORT_LIST, AbsoluteStandardApi, BinaryStandardApi, \
ULPStandardApi, ThousandthStandardApi, apis_threshold
-from atat.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
-from atat.pytorch.api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance, get_max_rel_err, \
+from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
+from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance, get_max_rel_err, \
get_mean_rel_err, get_rel_err, get_abs_err, get_max_abs_err, get_rel_err_ratio, cosine_sim, get_rel_err_origin, \
get_small_value_err_ratio, get_finite_and_infinite_mask, get_small_value_mask, check_inf_nan_value, \
check_small_value, check_norm_value, get_abs_bench_with_eps, get_ulp_err
-from atat.pytorch.api_accuracy_checker.common.config import msCheckerConfig
-from atat.core.common.const import Const, CompareConst
+from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
+from msprobe.core.common.const import Const, CompareConst
ResultInfo = namedtuple('ResultInfo', ['full_api_name', 'fwd_success_status', 'bwd_success_status',
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_column.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_column.py
similarity index 98%
rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_column.py
rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_column.py
index 9867a76fadff6eb071236cfabda93189547124a3..fb6d5dcc0f1c8b67ec2b67e8b419e8407cdc8d6d 100644
--- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_column.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_column.py
@@ -1,4 +1,4 @@
-from atat.core.common.const import CompareConst
+from msprobe.core.common.const import CompareConst
class CompareColumn:
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py
similarity index 97%
rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_utils.py
rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py
index b7b32e41e47ac69798f0bf0c15c5cdb929574a11..f8450b64b56a24fde76f0fc5a7be46479b5059ea 100644
--- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_utils.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py
@@ -5,10 +5,10 @@ import math
import numpy as np
import torch
import yaml
-from atat.core.common.utils import CompareException
-from atat.core.common.const import Const
-from atat.pytorch.common.log import logger
-from atat.core.common.file_check import FileOpen
+from msprobe.core.common.utils import CompareException
+from msprobe.core.common.const import Const
+from msprobe.pytorch.common.log import logger
+from msprobe.core.common.file_utils import FileOpen
current_time = time.strftime("%Y%m%d%H%M%S")
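`compare_utils` stamps its result files with `time.strftime("%Y%m%d%H%M%S")` captured at import time. A small sketch of building such a timestamped filename (the prefix is illustrative):

```python
import time

def result_file_name(prefix, ts=None):
    """Build a timestamped CSV name, e.g. prefix_20240601120000.csv."""
    ts = ts or time.strftime("%Y%m%d%H%M%S")  # 14-digit local timestamp
    return f"{prefix}_{ts}.csv"

name = result_file_name("accuracy_checking_result", ts="20240601120000")
```

Capturing the timestamp once per run keeps the result and details files of one comparison paired under the same suffix.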
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/config.yaml b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml
similarity index 77%
rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/config.yaml
rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml
index 7f26c72aa3102fa658c44caf7f5ce789cb23c46d..2dac535dc0501f6e47f0cdcc48bd88e1d73ab0dd 100644
--- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/config.yaml
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml
@@ -1,4 +1,5 @@
white_list: []
+black_list: []
error_data_path: './'
precision: 14
\ No newline at end of file
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/.keep b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/.keep
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/.keep
rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/.keep
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/__init__.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/free_benchmark/common/__init__.py
rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py
similarity index 95%
rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/data_generate.py
rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py
index 97b2dcc7e404c399aca0838b57650c0c54933bf4..51477ddd46efe607b4407814c753d07a6f714ad9 100644
--- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/data_generate.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py
@@ -20,10 +20,12 @@ import math
import torch
import numpy
-from atat.pytorch.api_accuracy_checker.run_ut.run_ut_utils import hf_32_standard_api
-from atat.pytorch.api_accuracy_checker.common.utils import check_file_or_directory_path, check_object_type, get_full_data_path, CompareException
-from atat.pytorch.common.log import logger
-from atat.core.common.const import Const
+from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import hf_32_standard_api
+from msprobe.pytorch.api_accuracy_checker.common.utils import check_object_type, get_full_data_path, \
+ CompareException
+from msprobe.core.common.file_utils import FileChecker
+from msprobe.pytorch.common.log import logger
+from msprobe.core.common.const import Const, FileCheckConst
TORCH_TYPE = ["torch.device", "torch.dtype"]
TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
@@ -86,12 +88,13 @@ def gen_real_tensor(data_path, convert_type):
convert_type: convert ori_type to dist_type flag.
"""
data_path = os.path.realpath(data_path)
- check_file_or_directory_path(data_path)
+ data_path_checker = FileChecker(data_path, FileCheckConst.FILE, ability=FileCheckConst.READ_ABLE)
+ data_path = data_path_checker.common_check()
if not data_path.endswith('.pt') and not data_path.endswith('.npy'):
error_info = f"The file: {data_path} is not a pt or numpy file."
raise CompareException(CompareException.INVALID_FILE_ERROR, error_info)
if data_path.endswith('.pt'):
- data = torch.load(data_path).cpu()
+ data = torch.load(data_path, map_location=torch.device('cpu'))
else:
data_np = numpy.load(data_path)
data = torch.from_numpy(data_np)
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py
similarity index 86%
rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py
rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py
index d2ab9c1e952001f4551044f4395662dd25931e08..049f6e9de5f590d5eb78a9ff2614a570c8bb88cb 100644
--- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py
@@ -9,14 +9,14 @@ import threading
from collections import namedtuple
from itertools import cycle
from tqdm import tqdm
-from atat.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, get_validated_result_csv_path, \
+from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, get_validated_result_csv_path, \
get_validated_details_csv_path, preprocess_forward_content
-from atat.pytorch.api_accuracy_checker.compare.compare import Comparator
-from atat.pytorch.common import parse_json_info_forward_backward
-from atat.core.common.file_check import FileChecker, check_file_suffix, check_link, FileOpen, \
+from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
+from msprobe.pytorch.common import parse_json_info_forward_backward
+from msprobe.core.common.file_utils import FileChecker, check_file_suffix, check_link, FileOpen, \
check_path_before_create, create_directory
-from atat.pytorch.common.log import logger
-from atat.core.common.const import FileCheckConst
+from msprobe.pytorch.common.log import logger
+from msprobe.core.common.const import FileCheckConst
def split_json_file(input_file, num_splits, filter_api):
@@ -68,7 +68,7 @@ signal.signal(signal.SIGTERM, signal_handler)
ParallelUTConfig = namedtuple('ParallelUTConfig', ['api_files', 'out_path', 'num_splits',
'save_error_data_flag', 'jit_compile_flag', 'device_id',
- 'result_csv_path', 'total_items', 'real_data_path'])
+ 'result_csv_path', 'total_items', 'config_path'])
def run_parallel_ut(config):
@@ -90,7 +90,7 @@ def run_parallel_ut(config):
*(['-j'] if config.jit_compile_flag else []),
*(['-save_error_data'] if config.save_error_data_flag else []),
'-csv_path', config.result_csv_path,
- *(['-real_data_path', config.real_data_path] if config.real_data_path else [])
+ *(['-config', config.config_path] if config.config_path else [])
]
return cmd
@@ -110,14 +110,9 @@ def run_parallel_ut(config):
def update_progress_bar(progress_bar, result_csv_path):
while any(process.poll() is None for process in processes):
- try:
- with open(result_csv_path, 'r') as result_file:
- completed_items = len(result_file.readlines()) - 1
- progress_bar.update(completed_items - progress_bar.n)
- except FileNotFoundError:
- logger.warning(f"Result CSV file not found: {result_csv_path}.")
- except Exception as e:
- logger.error(f"An unexpected error occurred while reading result CSV: {e}")
+ with FileOpen(result_csv_path, 'r') as result_file:
+ completed_items = len(result_file.readlines()) - 1
+ progress_bar.update(completed_items - progress_bar.n)
time.sleep(1)
for api_info in config.api_files:
@@ -175,7 +170,7 @@ def prepare_config(args):
out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
out_path = out_path_checker.common_check()
split_files, total_items = split_json_file(api_info, args.num_splits, args.filter_api)
-
+ config_path = os.path.realpath(args.config_path) if args.config_path else None
result_csv_path = args.result_csv_path or os.path.join(out_path, f"accuracy_checking_result_{time.strftime('%Y%m%d%H%M%S')}.csv")
if not args.result_csv_path:
details_csv_path = os.path.join(out_path, f"accuracy_checking_details_{time.strftime('%Y%m%d%H%M%S')}.csv")
@@ -187,7 +182,7 @@ def prepare_config(args):
logger.info(f"UT task details will be saved in {details_csv_path}")
return ParallelUTConfig(split_files, out_path, args.num_splits, args.save_error_data,
args.jit_compile, args.device_id, result_csv_path,
- total_items, args.real_data_path)
+ total_items, config_path)
def main():
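The `multi_run_ut` change above replaces `-real_data_path` with a `-config` flag, assembled with the conditional-splat pattern `*(['-flag', value] if value else [])`. A minimal stdlib sketch of that pattern (the flag names below are illustrative, not the tool's exact CLI):

```python
def build_cmd(base, jit_compile=False, config_path=None):
    """Assemble a subprocess argv list, appending optional flags only when set.

    Mirrors the conditional-splat style used in run_parallel_ut.
    """
    return [
        *base,
        *(['-j'] if jit_compile else []),
        *(['-config', config_path] if config_path else []),
    ]

cmd = build_cmd(['python', 'run_ut.py'], jit_compile=True, config_path='cfg.json')
# cmd == ['python', 'run_ut.py', '-j', '-config', 'cfg.json']
```

When `config_path` is empty or `None`, the splat expands to nothing, so no dangling `-config` flag ever reaches the child process.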
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py
similarity index 85%
rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py
rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py
index c5834e9a8c738b4766612263ccb0d5cae24add9d..1b9b26f9c0e3b954c577b68d35d8c5e086660b73 100644
--- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py
@@ -10,10 +10,13 @@ else:
is_gpu = False
import torch
from tqdm import tqdm
-from atat.pytorch.api_accuracy_checker.run_ut.run_ut import exec_api, generate_device_params, get_api_info
-from atat.pytorch.api_accuracy_checker.common.utils import get_json_contents
-from atat.core.common.file_check import check_link
-from atat.pytorch.common.log import logger
+from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import exec_api, generate_device_params, get_api_info
+from msprobe.pytorch.api_accuracy_checker.common.utils import get_json_contents
+from msprobe.core.common.file_utils import check_link
+from msprobe.pytorch.common.log import logger
+from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
+from msprobe.core.common.const import Const
+
def check_tensor_overflow(x):
if isinstance(x, torch.Tensor) and x.numel() != 0 and x.dtype != torch.bool:
@@ -52,12 +55,12 @@ def check_data_overflow(x):
def run_overflow_check(forward_file):
logger.info("start UT test")
- forward_content = get_json_contents(forward_file)
+ forward_content, _, real_data_path = parse_json_info_forward_backward(forward_file)
for api_full_name, api_info_dict in tqdm(forward_content.items()):
try:
- run_torch_api(api_full_name, api_info_dict)
+ run_torch_api(api_full_name, api_info_dict, real_data_path)
except Exception as err:
- api_name = api_full_name.split("_", 1)[1].rsplit("_", 2)[0]
+ _, api_name, _ = api_full_name.split(Const.SEP)
if "not implemented for 'Half'" in str(err):
logger.warning(f"API {api_name} not support half tensor in CPU, please add {api_name} to CONVERT_API "
f"'fp16_to_fp32' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
@@ -68,11 +71,10 @@ def run_overflow_check(forward_file):
logger.error(f"Run {api_full_name} UT Error: %s" % str(err))
-def run_torch_api(api_full_name, api_info_dict):
+def run_torch_api(api_full_name, api_info_dict, real_data_path):
torch.npu.clear_npu_overflow_flag()
- api_type = api_full_name.split(".")[0]
- api_name = api_full_name.split(".", 1)[1].rsplit(".", 2)[0]
- args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path='')
+ api_type, api_name, _ = api_full_name.split(Const.SEP)
+ args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path)
if not need_grad:
logger.warning("%s function with out=... arguments don't support automatic differentiation, skip backward."
% api_full_name)
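The hunk above switches from positional `split("_", 1)`/`rsplit` slicing to a single `api_full_name.split(Const.SEP)` unpacking. A small sketch of the parsing it relies on, assuming `Const.SEP` is the dot separator (an assumption; check `msprobe.core.common.const`):

```python
SEP = "."  # assumption: msprobe's Const.SEP is "."

def parse_api_full_name(api_full_name):
    """Split 'ApiType.api_name.call_index' into its three parts.

    The bare three-way unpacking raises ValueError on malformed names,
    which is exactly what the tuple unpacking in run_torch_api relies on.
    """
    api_type, api_name, call_index = api_full_name.split(SEP)
    return api_type, api_name, call_index

parse_api_full_name("Tensor.add.0")  # → ("Tensor", "add", "0")
```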
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py
similarity index 85%
rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py
rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py
index cd83a95801ffecf5bc28f9908378eb09f945d6a5..2fb709127d893320988b64db634c92b1eda46d60 100644
--- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py
@@ -18,28 +18,32 @@ else:
import torch
from tqdm import tqdm
-from atat.pytorch.api_accuracy_checker.run_ut.run_ut_utils import Backward_Message, hf_32_standard_api
-from atat.pytorch.api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args
-from atat.pytorch.api_accuracy_checker.common.utils import get_json_contents, api_info_preprocess, \
+from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import Backward_Message, hf_32_standard_api
+from msprobe.pytorch.api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args
+from msprobe.pytorch.api_accuracy_checker.common.utils import get_json_contents, api_info_preprocess, \
initialize_save_path, UtDataProcessor
-from atat.pytorch.api_accuracy_checker.compare.compare import Comparator
-from atat.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
-from atat.pytorch.hook_module.wrap_tensor import TensorOPTemplate
-from atat.pytorch.hook_module.wrap_functional import FunctionalOPTemplate
-from atat.pytorch.hook_module.wrap_torch import TorchOPTemplate
-from atat.pytorch.api_accuracy_checker.common.config import msCheckerConfig
-from atat.pytorch.common.parse_json import parse_json_info_forward_backward
-from atat.core.common.file_check import FileOpen, FileChecker, \
+from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
+from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
+from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate
+from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate
+from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate
+from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate
+from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate
+from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
+from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
+from msprobe.core.common.file_utils import FileOpen, FileChecker, \
change_mode, check_file_suffix, check_link, check_path_before_create, create_directory
-from atat.pytorch.common.log import logger
-from atat.core.common.const import Const, FileCheckConst, CompareConst
+from msprobe.pytorch.common.log import logger
+from msprobe.pytorch.pt_config import parse_json_config
+from msprobe.core.common.const import Const, FileCheckConst, CompareConst
current_time = time.strftime("%Y%m%d%H%M%S")
UT_ERROR_DATA_DIR = 'ut_error_data' + current_time
RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path',
- 'save_error_data', 'is_continue_run_ut', 'real_data_path'])
+ 'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list',
+ 'black_list', 'error_data_path'])
not_backward_list = ['repeat_interleave']
not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'}
not_raise_dtype_set = {'type_as'}
@@ -76,6 +80,12 @@ def exec_api(api_type, api_name, args, kwargs):
if api_type == "Torch":
torch_api = TorchOPTemplate(api_name, str, False)
out = torch_api.forward(*args, **kwargs)
+ if api_type == "Aten":
+ torch_api = AtenOPTemplate(api_name, None, False)
+ out = torch_api.forward(*args, **kwargs)
+ if api_type == "NPU":
+ torch_api = NpuOPTemplate(api_name, None, False)
+ out = torch_api.forward(*args, **kwargs)
return out
@@ -176,8 +186,7 @@ def run_ut(config):
logger.info(f"UT task result will be saved in {config.result_csv_path}")
logger.info(f"UT task details will be saved in {config.details_csv_path}")
if config.save_error_data:
- error_data_path = os.path.abspath(os.path.join(msCheckerConfig.error_data_path, UT_ERROR_DATA_DIR))
- logger.info(f"UT task error_datas will be saved in {error_data_path}")
+ logger.info(f"UT task error_datas will be saved in {config.error_data_path}")
compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut)
with FileOpen(config.result_csv_path, 'r') as file:
csv_reader = csv.reader(file)
@@ -188,17 +197,17 @@ def run_ut(config):
continue
if is_unsupported_api(api_full_name): # TODO run_ut does not support to the npu fusion api and distributed api
continue
+ [_, api_name, _] = api_full_name.split(Const.SEP)
try:
- if msCheckerConfig.white_list:
- [_, api_name, _] = api_full_name.split(Const.SEP)
- if api_name not in set(msCheckerConfig.white_list):
- continue
+ if config.black_list and api_name in config.black_list:
+ continue
+ if config.white_list and api_name not in config.white_list:
+ continue
data_info = run_torch_api(api_full_name, config.real_data_path, config.backward_content, api_info_dict)
is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info)
if config.save_error_data:
- do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success)
+ do_save_error_data(api_full_name, data_info, config.error_data_path, is_fwd_success, is_bwd_success)
except Exception as err:
- [_, api_name, _] = api_full_name.split(Const.SEP)
if "expected scalar type Long" in str(err):
logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
@@ -227,16 +236,16 @@ def is_unsupported_api(api_name):
return flag
-def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success):
+def do_save_error_data(api_full_name, data_info, error_data_path, is_fwd_success, is_bwd_success):
if not is_fwd_success or not is_bwd_success:
- processor = UtDataProcessor(os.path.join(msCheckerConfig.error_data_path, UT_ERROR_DATA_DIR))
+ processor = UtDataProcessor(error_data_path)
for element in data_info.in_fwd_data_list:
processor.save_tensors_in_element(api_full_name + '.forward.input', element)
- processor.save_tensors_in_element(api_full_name + '.forward.output.bench', data_info.bench_out)
- processor.save_tensors_in_element(api_full_name + '.forward.output.device', data_info.device_out)
+ processor.save_tensors_in_element(api_full_name + '.forward.output.bench', data_info.bench_output)
+ processor.save_tensors_in_element(api_full_name + '.forward.output.device', data_info.device_output)
processor.save_tensors_in_element(api_full_name + '.backward.input', data_info.grad_in)
- processor.save_tensors_in_element(api_full_name + '.backward.output.bench', data_info.bench_grad_out)
- processor.save_tensors_in_element(api_full_name + '.backward.output.device', data_info.device_grad_out)
+ processor.save_tensors_in_element(api_full_name + '.backward.output.bench', data_info.bench_grad)
+ processor.save_tensors_in_element(api_full_name + '.backward.output.device', data_info.device_grad)
def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict):
@@ -273,7 +282,7 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
if need_backward:
if need_to_backward(grad_index, out):
- backward_args = backward_content[api_full_name].get("grad_output")
+ backward_args = backward_content[api_full_name].get("input")
grad = gen_args(backward_args, api_name, real_data_path=real_data_path)[0]
bench_grad, _ = generate_cpu_params(grad, {}, False, api_name)
bench_grad_out = run_backward(cpu_args, bench_grad, grad_index, out)
@@ -314,14 +323,14 @@ def run_backward(args, grad, grad_index, out):
return grad_out
-def initialize_save_error_data():
- error_data_path = msCheckerConfig.error_data_path
+def initialize_save_error_data(error_data_path):
check_path_before_create(error_data_path)
create_directory(error_data_path)
- error_data_path_checker = FileChecker(msCheckerConfig.error_data_path, FileCheckConst.DIR,
+ error_data_path_checker = FileChecker(error_data_path, FileCheckConst.DIR,
ability=FileCheckConst.WRITE_ABLE)
error_data_path = error_data_path_checker.common_check()
- initialize_save_path(error_data_path, UT_ERROR_DATA_DIR)
+    error_data_path = initialize_save_path(error_data_path, UT_ERROR_DATA_DIR)
+ return error_data_path
def get_validated_result_csv_path(result_csv_path, mode):
@@ -378,12 +387,10 @@ def _run_ut_parser(parser):
help=" The path of accuracy_checking_result_{timestamp}.csv, "
"when run ut is interrupted, enter the file path to continue run ut.",
required=False)
- parser.add_argument("-real_data_path", dest="real_data_path", nargs="?", const="", default="", type=str,
- help=" In real data mode, the root directory for storing real data "
- "must be configured.",
- required=False)
parser.add_argument("-f", "--filter_api", dest="filter_api", action="store_true",
help=" Whether to filter the api in the api_info_file.", required=False)
+ parser.add_argument("-config", "--config_path", dest="config_path", default="", type=str,
+ help=" The path of config.json", required=False)
def preprocess_forward_content(forward_content):
@@ -397,9 +404,9 @@ def preprocess_forward_content(forward_content):
if key not in arg_cache:
filtered_new_args = [
{k: v for k, v in arg.items() if k not in ['Max', 'Min']}
- for arg in value['args'] if isinstance(arg, dict)
+ for arg in value['input_args'] if isinstance(arg, dict)
]
- arg_cache[key] = (filtered_new_args, value['kwargs'])
+ arg_cache[key] = (filtered_new_args, value['input_kwargs'])
filtered_new_args, new_kwargs = arg_cache[key]
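`preprocess_forward_content` now reads the renamed `input_args`/`input_kwargs` fields and caches the filtered arguments per key. A stripped-down sketch of the Max/Min filtering step, with field names taken from the diff and the sample entry invented for illustration:

```python
def filter_args(value):
    """Drop the 'Max'/'Min' statistics from each dict-typed input arg,
    mirroring the dedup preprocessing in preprocess_forward_content."""
    return [
        {k: v for k, v in arg.items() if k not in ('Max', 'Min')}
        for arg in value['input_args'] if isinstance(arg, dict)
    ]

entry = {'input_args': [{'type': 'torch.Tensor', 'Max': 3.0, 'Min': -1.0}],
         'input_kwargs': {}}
filter_args(entry)  # → [{'type': 'torch.Tensor'}]
```

Stripping the per-call Max/Min statistics makes otherwise identical calls hash to the same cache key, which is what lets duplicate API invocations be deduplicated.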
@@ -464,14 +471,22 @@ def run_ut_command(args):
if args.result_csv_path:
result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result')
details_csv_path = get_validated_details_csv_path(result_csv_path)
+ white_list = msCheckerConfig.white_list
+ black_list = msCheckerConfig.black_list
+ error_data_path = msCheckerConfig.error_data_path
+ if args.config_path:
+ _, task_config = parse_json_config(args.config_path, Const.RUN_UT)
+ white_list = task_config.white_list
+ black_list = task_config.black_list
+ error_data_path = task_config.error_data_path
if save_error_data:
if args.result_csv_path:
time_info = result_csv_path.split('.')[0].split('_')[-1]
global UT_ERROR_DATA_DIR
UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
- initialize_save_error_data()
+ error_data_path = initialize_save_error_data(error_data_path)
run_ut_config = RunUTConfig(forward_content, backward_content, result_csv_path, details_csv_path, save_error_data,
- args.result_csv_path, real_data_path)
+ args.result_csv_path, real_data_path, set(white_list), set(black_list), error_data_path)
run_ut(run_ut_config)
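The new filtering in `run_ut` gives the black list priority over the white list: black-listed APIs are always skipped, and when a white list is set only its members run. A minimal standalone sketch of that gate (function name is illustrative):

```python
def should_run(api_name, white_list=frozenset(), black_list=frozenset()):
    """Return True if an API should be tested.

    Matches the ordering in run_ut: a black-list hit always skips; if a
    white list is non-empty, only listed APIs run.
    """
    if black_list and api_name in black_list:
        return False
    if white_list and api_name not in white_list:
        return False
    return True

should_run('add', white_list={'add'}, black_list={'add'})  # → False (black list wins)
```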
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py
rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json
rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json
diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/__init__.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb06867371c6234583cabd485bcaa3dd671cb00c
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/__init__.py
@@ -0,0 +1,15 @@
+import os
+from pkgutil import iter_modules
+from importlib import import_module
+
+"""
+gpu and cpu not implement benchmark function, supplementary benchmarking function implementation
+"""
+
+package_path = os.path.dirname(os.path.realpath(__file__))
+for _, module_name, _ in iter_modules([package_path]):
+ module = import_module(f"{__name__}.{module_name}")
+ for attr_name in dir(module):
+ attr = getattr(module, attr_name)
+ if callable(attr) and "npu_custom" not in attr_name:
+ globals()[attr_name] = attr
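The new `bench_functions/__init__.py` above scans its own package directory and re-exports every callable at package level. A self-contained demo of the same `pkgutil`/`importlib` hoisting pattern, run against a throwaway package built in a temp directory (the package and function names here are invented for the demo):

```python
import os
import sys
import tempfile
from importlib import import_module
from pkgutil import iter_modules

# Build a throwaway package on disk to scan (illustrative, not msprobe's layout).
pkg_dir = os.path.join(tempfile.mkdtemp(), "demo_pkg")
os.makedirs(pkg_dir)
open(os.path.join(pkg_dir, "__init__.py"), "w").close()
with open(os.path.join(pkg_dir, "mathy.py"), "w") as f:
    f.write("def double(x):\n    return 2 * x\n")
sys.path.insert(0, os.path.dirname(pkg_dir))

# Same hoisting loop as the __init__.py above: import each submodule and
# re-export its public callables at package level.
exported = {}
for _, module_name, _ in iter_modules([pkg_dir]):
    module = import_module(f"demo_pkg.{module_name}")
    for attr_name in dir(module):
        attr = getattr(module, attr_name)
        if callable(attr) and not attr_name.startswith("_"):
            exported[attr_name] = attr

exported["double"](21)  # → 42
```

This keeps the package's public surface in sync with its modules automatically, at the cost of importing every submodule at package-import time.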
diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/apply_adam_w.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/apply_adam_w.py
new file mode 100644
index 0000000000000000000000000000000000000000..caf21a604c68a61343c6dbafa2e8c604a2d9649f
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/apply_adam_w.py
@@ -0,0 +1,28 @@
+import torch
+
+
+def npu_apply_adam_w(beta1_power, beta2_power, lr, weight_decay,
+ beta1, beta2, eps, grad, max_grad_norm, amsgrad, maximize, out):
+ var, m, v = out
+ if amsgrad:
+ max_grad_norm = (torch.rand(var.shape) * 10.0 - 5.0).to(var.dtype)
+ beta1_power_out = beta1_power * beta1
+ beta2_power_out = beta2_power * beta2
+ var_t = var * (1 + (-lr * weight_decay))
+ gt = -grad if maximize else grad
+ m_out = m * beta1 - (beta1 + (-1)) * gt
+ v_out = v * beta2 - (beta2 + (-1)) * gt * gt
+
+ if amsgrad:
+ max_grad_norm_out = torch.max(max_grad_norm, v_out)
+ if (1 - beta2_power_out) == 0:
+ beta2_power_out -= eps
+ denom = torch.sqrt(torch.div(max_grad_norm_out, (1 - beta2_power_out))) + eps
+ else:
+        v_retained = torch.div(v_out, (1 - beta2_power_out))
+        denom = torch.sqrt(v_retained) + eps
+
+ if (1 - beta1_power_out) == 0:
+ beta1_power_out -= eps
+ var_out = var_t + torch.div(-lr * m_out, (1 - beta1_power_out)).div(denom)
+ return var_out.cpu(), m_out.cpu(), v_out.cpu()
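The moment updates in `npu_apply_adam_w` are written as `m * beta1 - (beta1 + (-1)) * g`, which is algebraically the textbook Adam/AdamW update `beta1 * m + (1 - beta1) * g`. A scalar check of that identity with plain floats (no torch needed):

```python
def m_update_as_written(m, beta1, g):
    # form used in npu_apply_adam_w above
    return m * beta1 - (beta1 + (-1)) * g

def m_update_standard(m, beta1, g):
    # textbook Adam/AdamW first-moment update
    return beta1 * m + (1 - beta1) * g

abs(m_update_as_written(0.5, 0.9, -2.0) - m_update_standard(0.5, 0.9, -2.0)) < 1e-12  # → True
```

The same rewriting applies to the second moment `v_out = v * beta2 - (beta2 + (-1)) * g * g`, i.e. `beta2 * v + (1 - beta2) * g * g`.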
diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/confusion_transpose.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/confusion_transpose.py
new file mode 100644
index 0000000000000000000000000000000000000000..627bf11b64f3e62b6b305b7f7559d0fa69056ba6
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/confusion_transpose.py
@@ -0,0 +1,19 @@
+def npu_confusion_transpose(data, perm, shape, transpose_first):
+ if transpose_first:
+ output = data.permute(*perm).contiguous().view(shape)
+ else:
+ output = data.view(shape).permute(*perm)
+ return output.cpu()
+
+
+def npu_confusion_transpose_backward(grad, perm, shape, transpose_first):
+ shape_cal = shape if transpose_first else [shape[perm_dim] for perm_dim in perm]
+ perm_cal = [0] * len(perm)
+ for i, perm_dim in enumerate(perm):
+ perm_cal[perm_dim] = i
+
+ if transpose_first:
+ result = grad.permute(*perm_cal).reshape(shape_cal)
+ else:
+ result = grad.reshape(shape_cal).permute(*perm_cal)
+ return result.cpu()
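`npu_confusion_transpose_backward` inverts the forward permutation by building `perm_cal` with `perm_cal[perm_dim] = i`. A small stdlib check that this loop really produces the inverse permutation:

```python
def inverse_perm(perm):
    """Invert a permutation, as the backward pass above does:
    if forward moved index i to position perm[i], inv undoes it."""
    inv = [0] * len(perm)
    for i, p in enumerate(perm):
        inv[p] = i
    return inv

perm = [2, 0, 1]
inv = inverse_perm(perm)            # → [1, 2, 0]
items = ['a', 'b', 'c']
permuted = [items[p] for p in perm]  # apply forward permutation
restored = [permuted[p] for p in inv]  # applying inv restores the original order
```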
diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/fast_gelu.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/fast_gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1a9ca080851b580e8d1554a75b2c6ee1aee5165
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/fast_gelu.py
@@ -0,0 +1,55 @@
+import torch
+
+
+def fast_gelu(input0):
+ attr = 1.702
+ const_0 = 0 - attr
+ const_1 = 1
+ const_2 = attr / 2
+
+ abs_x = torch.abs(input0)
+ mul_abs_x = abs_x * const_0
+ exp_abs_x = torch.exp(mul_abs_x)
+ div_down = exp_abs_x + const_1
+
+ pn_x = input0 - abs_x
+ mul_pn_x = pn_x * const_2
+ exp_pn_x = torch.exp(mul_pn_x)
+ div_up = input0 * exp_pn_x
+ div_down_rec = torch.reciprocal(div_down)
+ result = div_up * div_down_rec
+
+ return result.cpu()
+
+
+def npu_fast_gelu_backward(grad, input_x):
+ const_2 = 1.702
+ const_3 = 1.0
+ const_1 = 0.0 - const_2
+
+ # e^(-1.702x)
+ abs_x = torch.abs(input_x)
+ mul_abs_x = abs_x * const_1
+ exp_x = torch.exp(mul_abs_x)
+
+ # 1.702xe^(-1.702x)
+ add_2 = input_x * exp_x
+ add_2 = add_2 * const_2
+
+ # e^(1.702(x-|x|))
+ pn_x = input_x - abs_x
+ mul_pn_x = pn_x * const_2
+ exp_pn_x = torch.exp(mul_pn_x)
+
+ # e^(-1.702x) + 1.702xe^(-1.702x) + e^(1.702(x-|x|))
+ div_up = exp_x + add_2
+ div_up = div_up + exp_pn_x
+
+ # (e^(-1.702x)+1)^2
+ div_down_i = exp_x + const_3
+ div_down = div_down_i * div_down_i
+ div_down_rec = torch.reciprocal(div_down)
+ result_temp = div_up * div_down_rec
+ result = grad * result_temp
+
+ return result.cpu()
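The branch-free `fast_gelu` above is a numerically stable rewriting of `x * sigmoid(1.702 * x)`: for x >= 0 the exponent `x - |x|` vanishes, and for x < 0 both forms reduce to `x * e^{1.702x} / (e^{1.702x} + 1)`, so no large positive exponent is ever taken. A scalar sanity check with `math` only, transcribing the tensor code:

```python
import math

def fast_gelu_scalar(x, attr=1.702):
    """Scalar transcription of the fast_gelu tensor code above."""
    abs_x = abs(x)
    div_down = math.exp(-attr * abs_x) + 1.0
    div_up = x * math.exp((attr / 2.0) * (x - abs_x))
    return div_up / div_down

def sigmoid_form(x, attr=1.702):
    return x * (1.0 / (1.0 + math.exp(-attr * x)))

all(abs(fast_gelu_scalar(x) - sigmoid_form(x)) < 1e-9
    for x in (-30.0, -1.0, 0.0, 0.5, 30.0))  # → True
```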
diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/layer_norm_eval.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/layer_norm_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6949c079e2b7928f2f0b30d5fbdb305dc0ba535
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/layer_norm_eval.py
@@ -0,0 +1,6 @@
+import torch
+
+
+def npu_layer_norm_eval(data, normalized_shape):
+ result = torch.nn.functional.layer_norm(data, normalized_shape)
+ return result.cpu()
diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/linear.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..95db875edf6dea3103fcd6b0d8533a850d8edad8
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/linear.py
@@ -0,0 +1,12 @@
+import torch
+
+
+def npu_linear(x, weight, bias):
+ output = torch.nn.functional.linear(x, weight, bias)
+ return output.cpu()
+
+
+def npu_linear_backward(grad, input_data, weight):
+ input_grad = torch.matmul(grad, weight)
+ weight_grad = torch.matmul(grad.t(), input_data)
+ return input_grad.cpu(), weight_grad.cpu()
diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/matmul_backward.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/matmul_backward.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed1c746ec1618bbbec7cb49914a7e71012047d6b
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/matmul_backward.py
@@ -0,0 +1,48 @@
+import torch
+
+
+def matmul_backward(grad, self, other, mask):
+ grad_self, grad_other = None, None
+ dim_self = self.dim()
+ dim_other = other.dim()
+
+ size_grad = list(grad.size())
+ size_self = list(self.size())
+ size_other = list(other.size())
+ if dim_self == 1 and dim_other == 1:
+ grad_self = other.mul(grad) if mask[0] else grad_self
+ grad_other = self.mul(grad) if mask[1] else grad_other
+ elif dim_self == 2 and dim_other == 1:
+ grad_self = grad.unsqueeze(1).mm(other.unsqueeze(0)) if mask[0] else grad_self
+ grad_other = self.transpose(-1, -2).mm(grad.unsqueeze(1)).squeeze_(1) if mask[1] else grad_other
+ elif dim_self == 1 and dim_other == 2:
+ grad_self = grad.unsqueeze(0).mm(other.transpose(-1, -2)).squeeze_(0) if mask[0] else grad_self
+ grad_other = self.unsqueeze(1).mm(grad.unsqueeze(0)) if mask[1] else grad_other
+ elif dim_self >= 3 and (dim_other == 1 or dim_other == 2):
+ view_size = 1 if dim_other == 1 else size_grad[-1]
+ unfolded_grad = (grad.unsqueeze(-1) if dim_other == 1 else grad).contiguous().view(-1, view_size)
+ if mask[0]:
+ grad_self = unfolded_grad.mm(other.unsqueeze(0) if dim_other == 1 else other.transpose(-1, -2)) \
+ .view(size_self)
+ if mask[1]:
+ unfolded_self = self.contiguous().view([-1, size_self[-1]])
+ grad_other = unfolded_self.transpose(-1, -2).mm(unfolded_grad).view(size_other)
+ elif (dim_self == 1 or dim_self == 2) and dim_other >= 3:
+ view_size = 1 if dim_self == 1 else size_grad[-2]
+ unfolded_grad_T = grad.view([-1, view_size]) \
+ if dim_self == 1 else grad.transpose(-1, -2).contiguous().view([-1, view_size])
+ if mask[0]:
+ # create a 2D-matrix from other
+ unfolded_other_T = \
+ other.transpose(-1, -2).contiguous().view([-1, size_other[-2]]).transpose(-1, -2)
+ grad_self = unfolded_other_T.mm(unfolded_grad_T).transpose(-1, -2).view(size_self)
+ if mask[1]:
+ size_other_T = size_other[:-2]
+ size_other_T.extend(size_other[::-1][:2])
+ grad_other = \
+ unfolded_grad_T.mm(self.unsqueeze(0) if dim_self == 1 else self).view(size_other_T).transpose(-1, -2)
+ else:
+ grad_self = torch.matmul(grad, other.transpose(-1, -2)) if mask[0] else grad_self
+ grad_other = torch.matmul(self.transpose(-1, -2), grad) if mask[1] else grad_other
+
+ return grad_self.cpu(), grad_other.cpu()
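For the plain 2-D fallback branch of `matmul_backward`, the closed forms are `grad_self = grad @ other.T` and `grad_other = self.T @ grad`. A finite-difference spot check in pure Python on 2x2 matrices, using the loss `L = sum(g ⊙ (a @ b))` (the helper names and sample values are invented for the check):

```python
def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def transpose(a):
    return [list(row) for row in zip(*a)]

def loss(a, b, g):
    # L = sum(g ⊙ (a @ b)); analytically dL/da = g @ b^T
    c = matmul(a, b)
    return sum(g[i][j] * c[i][j] for i in range(2) for j in range(2))

a = [[0.5, -1.0], [2.0, 0.25]]
b = [[1.5, 0.5], [-0.75, 1.0]]
g = [[1.0, -2.0], [0.5, 3.0]]

analytic = matmul(g, transpose(b))  # grad wrt a, same formula as matmul_backward
eps = 1e-6
i, j = 0, 1
a2 = [row[:] for row in a]
a2[i][j] += eps
numeric = (loss(a2, b, g) - loss(a, b, g)) / eps
abs(numeric - analytic[i][j]) < 1e-4  # → True
```

Since the loss is linear in `a`, the finite difference matches the analytic gradient up to floating-point error; the batched branches above are the same identity applied per 2-D slice.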
diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/npu_fusion_attention.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/npu_fusion_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..63f1fa2a3b657bce4fc470edf75d8271c6666364
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/npu_fusion_attention.py
@@ -0,0 +1,421 @@
+import torch
+import numpy as np
+from einops import rearrange
+
+from msprobe.pytorch.common.utils import logger
+
+gtype = torch.float64  # float64 is required on ARM hosts; on x86, float32 is sufficient (float64 also works). ARM computation is slow, so x86 is recommended for s=8k scenarios.
+softmax_build_mode = "QKV" # "MAX_SUM"
+
+"""
+# 前向函数声明对比
+标杆实现:fusion_attention_forward: q, k, v, drop_mask, atten_mask, pse, scale, keep_prob
+融合算子:npu_fusion_attention_forward: query, key, value, head_num, input_layout, *, pse=None, padding_mask=None,
+ atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647,
+ next_tockens=2147483647, inner_precise=0, prefix=None, sparse_mode=0,
+ gen_mask_parallel=True, sync=False
+
+# 反向函数声明对比
+标杆实现:fusion_attention_backward: dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
+融合算子:npu_fusion_attention_backward: query, key, value, dy, head_num, input_layout, *, pse=None, padding_mask=None,
+ atten_mask=None, softmax_max=None, softmax_sum=None, softmax_in=None,
+ attention_in=None, scale_value=1.0, keep_prob=1.0, pre_tockens=2147483647,
+ next_tockens=2147483647, inner_precise=0, seed=0, offset=0,
+ numels=0, prefix=None, sparse_mode=0, gen_mask_parallel=True, sync=False
+"""
+
+
+def softmax_forward(x):
+    x_max = torch.max(x, dim=-1, keepdim=True)[0]
+ x_sub = x.sub(x_max)
+ y = torch.exp(x_sub)
+    x_sum = y.sum(dim=-1, keepdim=True)
+ res = y.div(x_sum)
+ return res, x_max, x_sum
+
+
+def softmax_grad(dp, softmax_res):
+ muls = dp * softmax_res
+    muls_r = muls.sum(dim=-1, keepdim=True)
+ sub_r = dp - muls_r
+ res = sub_r * softmax_res
+ return res
+
+
+def broadcast_kv(num_heads, num_kv_heads, kv_tensor, dtype):
+    if num_kv_heads == 0 or num_heads % num_kv_heads != 0:
+        raise ValueError("num_kv_heads must be non-zero and divide num_heads evenly.")
+
+ factor = num_heads // num_kv_heads
+ kv_shape = kv_tensor.shape
+ B = kv_shape[0]
+ S = kv_shape[2]
+ D = kv_shape[3]
+ kv_res = torch.zeros([B, num_heads, S, D]).to(dtype)
+ for i in range(num_heads):
+ j = i // factor
+ kv_res[:, i:i + 1, :, :] = kv_tensor[:, j:j + 1, :, :]
+ return kv_res
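The per-head copy loop in `broadcast_kv` expands grouped KV heads so that consecutive query heads share one KV head. As a standalone sketch (hypothetical shapes), the same expansion can be expressed with `torch.repeat_interleave`:

```python
import torch

B, n_heads, n_kv, S, D = 2, 8, 2, 4, 3
kv = torch.randn(B, n_kv, S, D)

# loop-based broadcast, mirroring broadcast_kv above: head i reads kv head i // factor
factor = n_heads // n_kv
out_loop = torch.zeros(B, n_heads, S, D)
for i in range(n_heads):
    out_loop[:, i] = kv[:, i // factor]

# vectorized equivalent: repeat each kv head `factor` times along the head dim
out_fast = kv.repeat_interleave(factor, dim=1)
broadcast_same = torch.equal(out_loop, out_fast)
```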
+
+
+def calculate_qk(q, k, atten_mask, pse, scale):
+ if pse is None or len(pse.shape) == 0:
+ qk = torch.matmul(q, k.permute(0, 1, 3, 2)).mul(scale)
+ else:
+ qk = (torch.matmul(q, k.permute(0, 1, 3, 2)) + pse).mul(scale)
+ if atten_mask is None or len(atten_mask.shape) == 0:
+ return qk
+ else:
+ qk = qk + atten_mask.bool() * (-40000.0) # -10000
+ return qk
+
+
+def fusion_attention_forward(q, k, v, drop_mask, atten_mask, pse, scale, keep_prob):
+ qk = calculate_qk(q, k, atten_mask, pse, scale)
+ softmax_res, softmax_max, softmax_sum = softmax_forward(qk)
+ if drop_mask is None or len(drop_mask.shape) == 0:
+ drop_res = softmax_res
+ else:
+ drop_res = softmax_res * drop_mask * (1.0 / keep_prob)
+ y = torch.matmul(drop_res, v)
+ return y, softmax_max, softmax_sum
+
+
+def fusion_attention_backward(dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob):
+ dp = torch.matmul(dx, v.permute(0, 1, 3, 2))
+ if drop_mask is None or len(drop_mask.shape) == 0:
+ drop_res = softmax_res.permute(0, 1, 3, 2)
+ dp_drop = dp
+ else:
+ drop_res = softmax_res.mul(drop_mask).mul(1.0 / keep_prob).permute(0, 1, 3, 2)
+ dp_drop = dp * drop_mask * (1.0 / keep_prob)
+ dv = torch.matmul(drop_res, dx)
+ softmax_grad_res = (softmax_grad(dp_drop, softmax_res) * scale)
+ dq = torch.matmul(softmax_grad_res, k)
+ dk = torch.matmul(softmax_grad_res.permute(0, 1, 3, 2), q)
+ return dq, dk, dv
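The golden forward path above (scaled QK^T, softmax, weighted sum over V) can be cross-checked against PyTorch's built-in fused reference. This is a standalone sketch assuming torch >= 2.0 (where `F.scaled_dot_product_attention` exists), no dropout, no mask, and the default 1/sqrt(D) scale:

```python
import math
import torch
import torch.nn.functional as F

B, N, S, D = 2, 4, 8, 16
q, k, v = (torch.randn(B, N, S, D) for _ in range(3))
scale = 1.0 / math.sqrt(D)

# golden path: scaled QK^T -> softmax -> weighted sum over V
qk = torch.matmul(q, k.transpose(-1, -2)) * scale
golden = torch.matmul(torch.softmax(qk, dim=-1), v)

# built-in fused reference (defaults to the same 1/sqrt(D) scaling)
fused = F.scaled_dot_product_attention(q, k, v)
attn_close = torch.allclose(golden, fused, atol=1e-4)
```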
+
+
+def parse_bsnd_args(query, key, head_num, input_layout):
+ supported_input_layout = ["BSH", "SBH", "BSND", "BNSD", "TND"]
+ B, S1, S2, N1, N2, D, H1, H2 = None, None, None, head_num, None, None, None, None
+
+ if not isinstance(input_layout, str) or input_layout not in supported_input_layout:
+ raise ValueError(f"Invalid input_layout arg which must be one of {supported_input_layout}.")
+
+ if input_layout == "TND":
+        raise ValueError(f"input_layout {input_layout} is not supported for now.")
+ try:
+ if input_layout == "BSH":
+ B, S1, H1 = query.shape
+ _, S2, H2 = key.shape
+ D = H1 // N1
+ N2 = H2 // D
+ elif input_layout == "SBH":
+ S1, B, H1 = query.shape
+ S2, _, H2 = key.shape
+ D = H1 // N1
+ N2 = H2 // D
+ elif input_layout == "BSND":
+ B, S1, N1, D = query.shape
+ _, S2, N2, _ = key.shape
+ H1 = N1 * D
+ H2 = N2 * D
+ elif input_layout == "BNSD":
+ B, N1, S1, D = query.shape
+ _, N2, S2, _ = key.shape
+ H1 = N1 * D
+ H2 = N2 * D
+ except Exception as e:
+ raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e
+
+ if D == 0:
+        raise ValueError("Value D must be non-zero.")
+ DTYPE = query.dtype
+ return B, S1, S2, N1, N2, D, H1, H2, DTYPE
+
+
+def convert_from_bnsd(_input, input_layout):
+ if input_layout == "BSH":
+ # (B,N,S,D)=>(B,S,N*D)
+ out = rearrange(_input, 'b n s d -> b s (n d)').contiguous()
+ elif input_layout == "SBH":
+ # (B,N,S,D)=>(S,B,N*D)
+ out = rearrange(_input, 'b n s d -> s b (n d)').contiguous()
+ elif input_layout == "BSND":
+ # (B,N,S,D)=>(B,S,N,D)
+ out = rearrange(_input, 'b n s d -> b s n d').contiguous()
+ elif input_layout == "TND":
+        raise ValueError(f"input_layout {input_layout} is not supported for now.")
+ else:
+ out = _input
+ return out
+
+
+def convert_to_bnsd(_input, n, input_layout):
+    # the default "BNSD" layout needs no conversion
+ if input_layout == "BSH":
+ # (B,S,N*D)=>(B,N,S,D)
+ out = rearrange(_input, 'b s (n d) -> b n s d', n=n)
+ elif input_layout == "SBH":
+ # (S,B,N*D)=>(B,N,S,D)
+ out = rearrange(_input, 's b (n d) -> b n s d', n=n)
+ elif input_layout == "BSND":
+ # (B,S,N,D)=>(B,N,S,D)
+ out = rearrange(_input, 'b s n d -> b n s d', n=n)
+ elif input_layout == "TND":
+        raise ValueError(f"input_layout {input_layout} is not supported for now.")
+ else:
+ out = _input
+ if out.dim() != 4:
+ raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
+ return out.to(gtype)
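The layout conversions above are pure axis rearrangements. As a standalone sketch (using plain `permute`/`reshape` to mirror what the einops patterns do, with hypothetical dimensions), the BNSD -> BSH -> BNSD round trip is lossless:

```python
import torch

B, N, S, D = 2, 3, 4, 5
x_bnsd = torch.randn(B, N, S, D)

# 'b n s d -> b s (n d)': move S before N, then fuse N and D into H
x_bsh = x_bnsd.permute(0, 2, 1, 3).reshape(B, S, N * D)

# 'b s (n d) -> b n s d': split H back into N and D, then restore the axis order
back = x_bsh.reshape(B, S, N, D).permute(0, 2, 1, 3)
roundtrip_ok = torch.equal(x_bnsd, back)
```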
+
+
+def generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tocken, next_tocken, dtype):
+ """
+    # When sparse_mode is 2, 3 or 4, the small-op-to-fused-op conversion applies this optimization;
+    # reversing it means unwinding the mask back to the original basic implementation
+ ===> atten_mask = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(dtype)
+ """
+ shape = [S1, S2]
+
+ if atten_mask is not None:
+        # If FA's input already contains atten_mask, it can be treated as the already-converted mask matrix;
+        # three special (sparse-matrix) scenarios must be restored in reverse
+ if sparse_mode == 2 or sparse_mode == 3 or sparse_mode == 4:
+ logger.info(f"S1: {S1}, S2:{S2}, atten_mask.shape:{atten_mask.shape}, atten_mask.dtype:{atten_mask.dtype}")
+
+ if atten_mask.dim() == 2 and atten_mask.shape[0] == 2048 and atten_mask.shape[1] == 2048:
+ if atten_mask.equal(torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(atten_mask.dtype)):
+ if sparse_mode == 2:
+ atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
+ elif sparse_mode == 3:
+ atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=S2 - S1 + 1))
+ elif sparse_mode == 4:
+ atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
+ atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
+ atten_mask = atten_mask_u + atten_mask_l
+                    logger.debug(f"Inverse-converted atten_mask {atten_mask.shape}")
+ return atten_mask.to(dtype)
+
+ return atten_mask.to(dtype)
+
+ if atten_mask is not None:
+ if atten_mask.dim() == 2:
+ if atten_mask.shape[0] != S1 or atten_mask.shape[1] != S2:
+ raise ValueError(f"Invalid atten_mask shape `SS` {atten_mask.shape}")
+ shape = [S1, S2]
+ elif atten_mask.dim() == 4:
+ if atten_mask.shape[1] == 1:
+ shape = [B, 1, S1, S2] if B != 1 else [1, 1, S1, S2]
+ else:
+ shape = [B, N1, S1, S2] if B != 1 else [1, N1, S1, S2]
+
+ if sparse_mode == 0:
+ atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
+ atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
+ atten_mask = atten_mask_u + atten_mask_l
+ elif sparse_mode == 1: # no sparse
+ atten_mask = torch.from_numpy(np.zeros(shape))
+ elif sparse_mode == 2:
+ atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
+ elif sparse_mode == 3:
+ atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=S2 - S1 + 1))
+ elif sparse_mode == 4:
+ atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
+ atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
+ atten_mask = atten_mask_u + atten_mask_l
+    # Note: sparse_mode=5 cannot occur here, because that mode requires atten_mask to be passed in with
+    # data format BNSS or B1SS, so FA's input can be assumed to already be the correct atten_mask
+ return atten_mask.to(dtype)
+
+
+def generate_kv(key, value, N1, N2):
+    # Adaptation for unequal N (by cdy)
+    if N1 != N2:
+ k_new = broadcast_kv(N1, N2, key, key.dtype)
+ v_new = broadcast_kv(N1, N2, value, value.dtype)
+ else:
+ k_new = key
+ v_new = value
+ return k_new, v_new
+
+
+def rebuid_softmax_by_qkv(q, k, atten_mask, pse, scale):
+ """
+ attention = softmax(QK^T/sqrt(d))V
+ softmax(x_i) = e^(x_i - x_max) / sum(e^(x_i - x_max))
+ """
+ logger.info("Using QKV to rebuild original softmax")
+ qk = calculate_qk(q, k, atten_mask, pse, scale)
+ softmax_res, x_max, x_sum = softmax_forward(qk)
+ return softmax_res
+
+
+def rebuild_softmax_by_max_sum(q, k, atten_mask, pse, scale, softmax_max, softmax_sum):
+ """
+ attention = softmax(QK^T/sqrt(d))V
+    softmax(x_i) = e^(x_i - x_max_i) / x_sum_i
+ """
+ logger.info("Using softmax_max and softmax_sum to rebuild original softmax")
+ qk = calculate_qk(q, k, atten_mask, pse, scale)
+ if softmax_max.shape[-1] == 0:
+ raise ValueError(f"softmax_max.shape[-1] must be non-zero, softmax_max.shape: {softmax_max.shape}")
+ repeat_dim = qk.shape[-1] // softmax_max.shape[-1]
+ softmax_res = torch.exp(qk.sub(softmax_max.repeat(1, 1, 1, repeat_dim))).div(
+ softmax_sum.repeat(1, 1, 1, repeat_dim))
+ return softmax_res
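The "MAX_SUM" rebuild above relies on softmax being exactly recoverable from each row's saved max and exp-sum. A standalone sketch of that identity:

```python
import torch

x = torch.randn(2, 3, 4, 8, dtype=torch.float64)

# per-row statistics as saved by the forward pass
x_max = x.max(dim=-1, keepdim=True)[0]
x_sum = torch.exp(x - x_max).sum(dim=-1, keepdim=True)

# rebuilding softmax from the saved row max and row sum, as above
rebuilt = torch.exp(x - x_max) / x_sum
direct = torch.softmax(x, dim=-1)
rebuild_match = torch.allclose(rebuilt, direct, atol=1e-10)
```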
+
+
+def npu_fusion_attention_forward_patch(*args, **kwargs):
+ # query, key, value, head_num, input_layout
+ if len(args) != 5:
+ raise ValueError(f"Unsupported npu_fusion_attention args {args}.")
+
+ B, S1, S2, N1, N2, D, H1, H2, DTYPE = parse_bsnd_args(args[0], args[1], args[3], args[4])
+ if N1 == N2 and S1 == S2:
+ logger.debug(f"running case : BNSD = {B}_{N1}_{S1}_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
+ else:
+ logger.debug(f"running case: BNSD = {B}_{N1}({N2})_{S1}({S2})_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
+ if not (N1 % N2 == 0 and N1 >= N2):
+        raise ValueError(f"N1 and N2 mismatch, please check: N1 = {N1}, N2 = {N2}.")
+
+ dims_kwargs = {"B": B, "S1": S1, "S2": S2, "N1": N1, "N2": N2,
+ "D": D, "H1": H1, "H2": H2, "DTYPE": DTYPE}
+
+ new_kwargs = {"keep_prob": 1,
+ "scale": kwargs.get("scale", 1 / (D ** 0.5)),
+ "sparse_mode": kwargs.get("sparse_mode", 0),
+ "prefix": kwargs.get("prefix"),
+ "pre_tockens": kwargs.get("pre_tockens", 2147483647),
+ "next_tockens": kwargs.get("next_tockens", 2147483647),
+ "pse": kwargs.get("pse"),
+ "padding_mask": kwargs.get("padding_mask"),
+ "atten_mask": kwargs.get("atten_mask")}
+
+ return args, dims_kwargs, new_kwargs
+
+
+def npu_fusion_attention_backward_patch(*args, **kwargs):
+ if len(args) != 6:
+ raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.")
+
+ B, S1, S2, N1, N2, D, H1, H2, DTYPE = parse_bsnd_args(args[0], args[1], args[4], args[5])
+ if N1 == N2 and S1 == S2:
+ logger.info(f"running case : BNSD = {B}_{N1}_{S1}_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
+ else:
+ logger.info(f"running case: BNSD = {B}_{N1}({N2})_{S1}({S2})_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
+ if not (N1 % N2 == 0 and N1 >= N2):
+        raise ValueError(f"N1 and N2 mismatch, please check: N1 = {N1}, N2 = {N2}.")
+
+ dims_kwargs = {"B": B, "S1": S1, "S2": S2, "N1": N1, "N2": N2,
+ "D": D, "H1": H1, "H2": H2, "DTYPE": DTYPE}
+
+ new_kwargs = {"keep_prob": 1,
+ "scale_value": kwargs.get("scale_value", 1 / (D ** 0.5)),
+ "sparse_mode": kwargs.get("sparse_mode", 0),
+ "prefix": kwargs.get("prefix"),
+ "pre_tockens": kwargs.get("pre_tockens", 2147483647),
+ "next_tockens": kwargs.get("next_tockens", 2147483647),
+ "pse": kwargs.get("pse"),
+ "padding_mask": kwargs.get("padding_mask"),
+ "softmax_max": kwargs.get("softmax_max"),
+ "softmax_sum": kwargs.get("softmax_sum"),
+ "softmax_in": kwargs.get("softmax_in"),
+ "attention_in": kwargs.get("attention_in"),
+ "seed": kwargs.get("seed", 0),
+ "offset": kwargs.get("offset", 0),
+ "numels": kwargs.get("numels", 0),
+ "atten_mask": kwargs.get("atten_mask")}
+
+ return args, dims_kwargs, new_kwargs
+
+
+def npu_fusion_attention(*args, **kwargs):
+ new_args, dims_kwargs, new_kwargs = npu_fusion_attention_forward_patch(*args, **kwargs)
+ query, key, value, input_layout = new_args[0], new_args[1], new_args[2], new_args[4]
+ N1 = dims_kwargs.get("N1")
+ N2 = dims_kwargs.get("N2")
+ S1 = dims_kwargs.get("S1")
+ S2 = dims_kwargs.get("S2")
+ B = dims_kwargs.get("B")
+ DTYPE = dims_kwargs.get("DTYPE")
+ atten_mask = new_kwargs.get("atten_mask")
+ keep_prob = new_kwargs.get("keep_prob")
+ sparse_mode = new_kwargs.get("sparse_mode")
+ pre_tockens = new_kwargs.get("pre_tockens")
+ next_tockens = new_kwargs.get("next_tockens")
+ pse = new_kwargs.get("pse")
+ scale = new_kwargs.get("scale")
+
+ atten_mask = generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tockens, next_tockens, DTYPE)
+ query = convert_to_bnsd(query, N1, input_layout)
+ key = convert_to_bnsd(key, N2, input_layout)
+ value = convert_to_bnsd(value, N2, input_layout)
+ k_new, v_new = generate_kv(key, value, N1, N2)
+ out_golden, softmax_max, softmax_sum = fusion_attention_forward(q=query, k=k_new, v=v_new,
+ drop_mask=None, atten_mask=atten_mask,
+ pse=pse, scale=scale,
+ keep_prob=keep_prob)
+ if out_golden.dim() == 5:
+ out_golden = out_golden.reshape(out_golden.size(0), out_golden.size(1) * out_golden.size(2), out_golden.size(3),
+ out_golden.size(4))
+ out_golden = convert_from_bnsd(out_golden, input_layout)
+
+ return out_golden.cpu(), softmax_max.repeat(1, 1, 1, 8).cpu(), softmax_sum.repeat(1, 1, 1, 8).cpu()
+
+
+def npu_fusion_attention_grad(*args, **kwargs):
+ # dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
+ new_args, dims_kwargs, new_kwargs = npu_fusion_attention_backward_patch(*args, **kwargs)
+ query, key, value, dx, input_layout = new_args[0], new_args[1], new_args[2], new_args[3], new_args[5]
+ N1 = dims_kwargs.get("N1")
+ N2 = dims_kwargs.get("N2")
+ S1 = dims_kwargs.get("S1")
+ S2 = dims_kwargs.get("S2")
+ B = dims_kwargs.get("B")
+ D = dims_kwargs.get("D")
+ DTYPE = dims_kwargs.get("DTYPE")
+ atten_mask = new_kwargs.get("atten_mask")
+ keep_prob = new_kwargs.get("keep_prob")
+ sparse_mode = new_kwargs.get("sparse_mode")
+ pre_tockens = new_kwargs.get("pre_tockens")
+ next_tockens = new_kwargs.get("next_tockens")
+ pse = new_kwargs.get("pse")
+ softmax_max = new_kwargs.get("softmax_max")
+ softmax_sum = new_kwargs.get("softmax_sum")
+ scale_value = new_kwargs.get("scale_value")
+
+ atten_mask = generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tockens, next_tockens, DTYPE)
+ query = convert_to_bnsd(query, N1, input_layout)
+ dx = convert_to_bnsd(dx, N1, input_layout)
+ key = convert_to_bnsd(key, N2, input_layout)
+ value = convert_to_bnsd(value, N2, input_layout)
+ k_new, v_new = generate_kv(key, value, N1, N2)
+
+ if softmax_build_mode == "QKV":
+ softmax_res = rebuid_softmax_by_qkv(query, k_new, atten_mask, pse, scale_value)
+ else:
+ softmax_res = rebuild_softmax_by_max_sum(query, k_new, atten_mask, pse, scale_value, softmax_max, softmax_sum)
+
+ dq, dk, dv = fusion_attention_backward(dx, query, k_new, v_new, softmax_res, None, pse, scale_value, keep_prob)
+
+    # Adaptation for unequal N (by cdy)
+    if N1 != N2:
+ if N2 == 0:
+ raise ValueError("dims_kwargs.N2 must be non-zero.")
+ G = int(N1 / N2)
+ dk = torch.sum(dk.reshape(B, N2, G, S2, D), dim=2, keepdim=True).reshape(B, N2, S2, D)
+ dv = torch.sum(dv.reshape(B, N2, G, S2, D), dim=2, keepdim=True).reshape(B, N2, S2, D)
+
+ if dq.dim() == 5:
+ dq = dq.reshape(dq.size(0), dq.size(1) * dq.size(2), dq.size(3), dq.size(4))
+ if dk.dim() == 5:
+ dk = dk.reshape(dk.size(0), dk.size(1) * dk.size(2), dk.size(3), dk.size(4))
+ if dv.dim() == 5:
+ dv = dv.reshape(dv.size(0), dv.size(1) * dv.size(2), dv.size(3), dv.size(4))
+
+ dq = convert_from_bnsd(dq, input_layout)
+ dk = convert_from_bnsd(dk, input_layout)
+ dv = convert_from_bnsd(dv, input_layout)
+
+ return dq.cpu(), dk.cpu(), dv.cpu()
diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/rms_norm.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/rms_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..e647312fdb221137e58f775d21e0605d21f98e07
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/rms_norm.py
@@ -0,0 +1,15 @@
+import torch
+
+
+def npu_rms_norm(x, gamma, epsilon=1e-5):
+    rstd = torch.rsqrt(torch.mean(torch.pow(x, 2), dim=-1, keepdim=True) + epsilon)
+ res = x * rstd * gamma
+ return res.cpu(), rstd.float().cpu()
+
+
+def npu_rms_norm_backward(grad, x, gamma, rstd):
+ mean_gy = (grad * x * gamma * rstd).mean(dim=-1, keepdim=True)
+ grad_x = (grad * gamma - x * rstd * mean_gy) * rstd
+ grad_gamma = x * grad * rstd
+ return grad_x.cpu(), grad_gamma.cpu()
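The closed-form `grad_x` in `npu_rms_norm_backward` can be checked against autograd. This is a standalone sketch (hypothetical shapes, float64 for tight tolerances) that recomputes the same formula and compares it with the gradient PyTorch derives:

```python
import torch

eps = 1e-5
x = torch.randn(2, 4, 8, dtype=torch.float64, requires_grad=True)
gamma = torch.randn(8, dtype=torch.float64)
grad = torch.randn(2, 4, 8, dtype=torch.float64)

# forward: y = x * rstd * gamma, with rstd = 1 / sqrt(mean(x^2) + eps)
rstd = torch.rsqrt((x * x).mean(dim=-1, keepdim=True) + eps)
y = x * rstd * gamma
y.backward(grad)

# closed-form grad_x, mirroring npu_rms_norm_backward above
rstd_d = rstd.detach()
mean_gy = (grad * x.detach() * gamma * rstd_d).mean(dim=-1, keepdim=True)
grad_x = (grad * gamma - x.detach() * rstd_d * mean_gy) * rstd_d
rms_grads_match = torch.allclose(x.grad, grad_x, atol=1e-10)
```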
+
diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/rotary_mul.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/rotary_mul.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e0fda5f73fced0913ff1fd755a68e0dfb6652fd
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/rotary_mul.py
@@ -0,0 +1,52 @@
+import torch
+
+
+def npu_rotary_mul(x, r1, r2):
+ x1, x2 = torch.chunk(x, 2, -1)
+ x_new = torch.cat((-x2, x1), dim=-1)
+ output = r1 * x + r2 * x_new
+ return output.cpu()
+
+
+def npu_rotary_mul_backward(dy_tensor, x, r1, r2):
+ x.requires_grad = True
+ r1.requires_grad = True
+ r2.requires_grad = True
+ # golden
+ x1, x2 = torch.chunk(x, 2, -1)
+ x_new = torch.cat((-x2, x1), dim=-1)
+ golden_tensor = r1 * x + r2 * x_new
+ golden_tensor.backward(dy_tensor)
+ r1_shape = r1.shape
+ r1_grad = torch.zeros(r1_shape).type(torch.float32)
+ r2_grad = torch.zeros(r1_shape).type(torch.float32)
+ x1, x2 = torch.chunk(x.float(), 2, -1)
+ x_new2 = torch.cat((-x2, x1), dim=-1)
+ x_shape = x.shape
+ h = x.float()
+ grad = dy_tensor.float()
+    # r1/r2 broadcast against x; each condition identifies which dims of r1 are size-1 (broadcast) dims.
+    # Note (a == 1 and b != 1) or (a == 1 and b == 1) simplifies to a == 1.
+    condition_1 = (r1_shape[0] == 1 and r1_shape[2] == 1 and
+                   r1_shape[1] == x_shape[1] and r1_shape[3] == x_shape[3])
+    condition_2 = (r1_shape[0] == 1 and r1_shape[1] == 1 and
+                   r1_shape[2] == x_shape[2] and r1_shape[3] == x_shape[3])
+    condition_3 = (r1_shape[1] == 1 and r1_shape[2] == 1 and
+                   r1_shape[0] == x_shape[0] and r1_shape[3] == x_shape[3])
+ if condition_1:
+ for i in range(x_shape[0]):
+ for j in range(x_shape[2]):
+ r2_grad[0, :, 0, :] += (x_new2[i, :, j, :] * grad[i, :, j, :])
+ r1_grad[0, :, 0, :] += (h[i, :, j, :] * grad[i, :, j, :])
+ elif condition_2:
+ for i in range(x_shape[0]):
+ for j in range(x_shape[1]):
+ r2_grad[0, 0, :, :] += (x_new2[i, j, :, :] * grad[i, j, :, :])
+ r1_grad[0, 0, :, :] += (h[i, j, :, :] * grad[i, j, :, :])
+ elif condition_3:
+ for i in range(x_shape[1]):
+ for j in range(x_shape[2]):
+ r2_grad[:, 0, 0, :] += (x_new2[:, i, j, :] * grad[:, i, j, :])
+ r1_grad[:, 0, 0, :] += (h[:, i, j, :] * grad[:, i, j, :])
+ return x.grad.cpu(), r1_grad.cpu(), r2_grad.cpu()
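The index loops above accumulate gradients for the broadcast `r1`/`r2` operands by hand. A standalone sketch (hypothetical shapes) showing the underlying rule they implement: the gradient of a size-1 broadcast dim is the upstream gradient summed over that dim, which autograd performs automatically:

```python
import torch

x = torch.randn(2, 3, 4, 6)
r1 = torch.randn(1, 3, 1, 6, requires_grad=True)   # broadcast over dims 0 and 2
g = torch.randn(2, 3, 4, 6)

out = r1 * x
out.backward(g)

# gradient of a broadcast operand = upstream grad summed over the broadcast dims
manual = (g * x).sum(dim=(0, 2), keepdim=True)
broadcast_grad_ok = torch.allclose(r1.grad, manual, atol=1e-5)
```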
diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/scaled_mask_softmax.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/scaled_mask_softmax.py
new file mode 100644
index 0000000000000000000000000000000000000000..8717aebaf902db8af37fe5d8ec5cb054fd83439b
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/scaled_mask_softmax.py
@@ -0,0 +1,26 @@
+import torch
+
+
+def npu_scaled_masked_softmax(x, mask, scale, fixed_triu_mask):
+ if fixed_triu_mask:
+        mask = torch.triu(torch.ones(mask.shape), diagonal=1).bool().to(mask.device)
+ dtype = x.dtype
+ x = (x * scale).masked_fill(mask, value=-10000)
+ x = x - torch.max(x, dim=-1, keepdims=True)[0]
+ x = torch.exp(x.float())
+ y = torch.div(x, torch.sum(x, dim=-1, keepdims=True))
+ return y.to(dtype).cpu()
+
+
+def npu_scaled_masked_softmax_backward(y_grad, y, mask, scale, fixed_triu_mask):
+ if fixed_triu_mask:
+        mask = torch.triu(torch.ones(mask.shape), diagonal=1).bool().to(mask.device)
+ dtype = y_grad.dtype
+ y_grad = y_grad.float()
+ y = y.float()
+ x_grad = y_grad * y
+ x_grad = y_grad - torch.sum(x_grad, dim=-1, keepdims=True)
+ x_grad = x_grad * y
+ x_grad = x_grad * scale
+ x_grad = x_grad.masked_fill(mask, value=0)
+ return x_grad.to(dtype).cpu()
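The forward pass above is scale, mask-fill with a large negative value, then a numerically stabilized softmax. A standalone sketch (hypothetical shapes) of the same pipeline, checking that masked positions vanish and rows stay normalized:

```python
import torch

x = torch.randn(2, 4, 4)
mask = torch.zeros(2, 4, 4, dtype=torch.bool)
mask[:, :, -1] = True            # suppress the last key position everywhere
scale = 0.5

# scale, mask with a large negative value, then softmax (as in the forward above)
y = torch.softmax((x * scale).masked_fill(mask, -10000.0), dim=-1)
masked_small = bool(y[..., -1].max() < 1e-4)
rows_normalized = torch.allclose(y.sum(dim=-1), torch.ones(2, 4), atol=1e-6)
```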
diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/swiglu.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/swiglu.py
new file mode 100644
index 0000000000000000000000000000000000000000..e03c975a50aaf37573645365d2e6228706070844
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/swiglu.py
@@ -0,0 +1,55 @@
+import torch
+
+
+def npu_swiglu(x, dim=-1):
+ tensor_dtype = x.dtype
+
+    in_tensors = torch.chunk(x, 2, dim=dim)
+    if tensor_dtype == torch.float32:
+        tensor_scalar = torch.sigmoid(torch.mul(in_tensors[0], 1.0))
+        output_data = torch.mul(torch.mul(tensor_scalar, in_tensors[0]), in_tensors[1])
+    else:
+        tensor_self_float = in_tensors[0].type(torch.float)
+        tensor_other_float = in_tensors[1].type(torch.float)
+ tensor_out_float = torch.nn.functional.silu(tensor_self_float).type(tensor_dtype).type(
+ torch.float32) * tensor_other_float
+ output_data = tensor_out_float.type(tensor_dtype)
+ return output_data.cpu()
+
+
+def npu_swiglu_backward(grad, x, dim=-1):
+ tensor_dtype = grad.dtype
+ in_tensors = torch.chunk(x, 2, dim=dim)
+ tensor_grad_out = grad
+
+ if tensor_dtype == torch.float16:
+ tensor_out1 = torch.mul(
+ torch.mul(in_tensors[1].type(torch.float32), swish_grad(1, in_tensors[0].type(torch.float32))),
+ tensor_grad_out.type(torch.float32)).type(torch.float16)
+ tensor_out2 = torch.mul(tensor_grad_out.type(torch.float32),
+ swish(1, in_tensors[0].type(torch.float32))).type(torch.float16)
+ output = torch.cat((tensor_out1, tensor_out2), dim)
+ elif tensor_dtype == torch.bfloat16:
+ tensor_self_float = in_tensors[0].type(torch.float)
+ tensor_other_float = in_tensors[1].type(torch.float)
+ tensor_gradout_float = tensor_grad_out.type(torch.float)
+
+ tensor_out1 = torch.mul(tensor_gradout_float, swish_grad(1.0, tensor_self_float)).type(torch.bfloat16).type(
+ torch.float32) * tensor_other_float
+ tensor_out2 = swish(1.0, tensor_self_float).type(torch.bfloat16).type(torch.float32) * tensor_gradout_float
+ tensor_out_float = torch.cat((tensor_out1, tensor_out2), dim=dim)
+ output = tensor_out_float.type(torch.bfloat16)
+ else:
+ tensor_out1 = torch.mul(torch.mul(in_tensors[1], swish_grad(1.0, in_tensors[0])), tensor_grad_out)
+ tensor_out2 = torch.mul(tensor_grad_out, swish(1.0, in_tensors[0]))
+ output = torch.cat((tensor_out1, tensor_out2), dim)
+ return output.cpu()
+
+
+def swish_grad(beta, x):
+ return torch.sigmoid(beta * x) + x * (1 - torch.sigmoid(beta * x)) * torch.sigmoid(beta * x) * beta
+
+
+def swish(beta, x):
+ return x * torch.sigmoid(beta * x)
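`swish_grad` is the analytic derivative of `swish(beta, x) = x * sigmoid(beta * x)`. A standalone sketch (the two helpers restated locally so the snippet is self-contained) verifying it against autograd for beta = 1:

```python
import torch

def swish(beta, x):
    return x * torch.sigmoid(beta * x)

def swish_grad(beta, x):
    # d/dx [x * sigmoid(beta * x)] = s + x * beta * s * (1 - s), with s = sigmoid(beta * x)
    s = torch.sigmoid(beta * x)
    return s + x * (1 - s) * s * beta

x = torch.randn(16, dtype=torch.float64, requires_grad=True)
swish(1.0, x).sum().backward()
swish_grad_match = torch.allclose(x.grad, swish_grad(1.0, x.detach()), atol=1e-10)
```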
+
diff --git a/debug/accuracy_tools/atat/pytorch/common/__init__.py b/debug/accuracy_tools/msprobe/pytorch/common/__init__.py
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/common/__init__.py
rename to debug/accuracy_tools/msprobe/pytorch/common/__init__.py
diff --git a/debug/accuracy_tools/atat/pytorch/common/compare_script.template b/debug/accuracy_tools/msprobe/pytorch/common/compare_script.template
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/common/compare_script.template
rename to debug/accuracy_tools/msprobe/pytorch/common/compare_script.template
diff --git a/debug/accuracy_tools/atat/pytorch/common/log.py b/debug/accuracy_tools/msprobe/pytorch/common/log.py
similarity index 81%
rename from debug/accuracy_tools/atat/pytorch/common/log.py
rename to debug/accuracy_tools/msprobe/pytorch/common/log.py
index e496e9b72ad449c24dd3f2a76a9a149d0f2eff1e..cea518fa47b977ad08b5fded3401a63bf3a29d03 100644
--- a/debug/accuracy_tools/atat/pytorch/common/log.py
+++ b/debug/accuracy_tools/msprobe/pytorch/common/log.py
@@ -1,9 +1,9 @@
import os
import time
import sys
-from atat.pytorch.common.utils import get_rank_if_initialized
-from atat.core.common.log import BaseLogger
-from atat.core.common.exceptions import DistributedNotInitializedError
+from msprobe.pytorch.common.utils import get_rank_if_initialized
+from msprobe.core.common.log import BaseLogger
+from msprobe.core.common.exceptions import DistributedNotInitializedError
class PyTorchLogger(BaseLogger):
diff --git a/debug/accuracy_tools/atat/pytorch/common/parse_json.py b/debug/accuracy_tools/msprobe/pytorch/common/parse_json.py
similarity index 89%
rename from debug/accuracy_tools/atat/pytorch/common/parse_json.py
rename to debug/accuracy_tools/msprobe/pytorch/common/parse_json.py
index a938f5f0da9ea465923157ad4131ce72bb92962f..89edd834cf67c8006f577f1d33fee32b3b6dc751 100644
--- a/debug/accuracy_tools/atat/pytorch/common/parse_json.py
+++ b/debug/accuracy_tools/msprobe/pytorch/common/parse_json.py
@@ -1,5 +1,7 @@
import json
-from atat.core.common.exceptions import ParseJsonException
+
+from msprobe.core.common.exceptions import ParseJsonException
+from msprobe.core.common.file_utils import FileOpen
def parse_json_info_forward_backward(json_path):
@@ -11,7 +13,7 @@ def parse_json_info_forward_backward(json_path):
api_name = '.'.join(name_struct[:-1])
return api_name
- with open(json_path, 'r') as f:
+ with FileOpen(json_path, 'r') as f:
dump_json = json.load(f)
real_data_path = dump_json.get("dump_data_dir")
diff --git a/debug/accuracy_tools/atat/pytorch/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/common/utils.py
similarity index 87%
rename from debug/accuracy_tools/atat/pytorch/common/utils.py
rename to debug/accuracy_tools/msprobe/pytorch/common/utils.py
index 4b413ac57502ad777bf68140fb57854a691811b5..181491488f91049148c51827387edc53d21d41cf 100644
--- a/debug/accuracy_tools/atat/pytorch/common/utils.py
+++ b/debug/accuracy_tools/msprobe/pytorch/common/utils.py
@@ -14,13 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
+import logging
import os
import random
import stat
import torch
+import torch.distributed as dist
import numpy as np
from functools import wraps
-from atat.core.common.exceptions import DistributedNotInitializedError
+from msprobe.core.common.exceptions import DistributedNotInitializedError
try:
import torch_npu
@@ -29,7 +31,6 @@ except ImportError:
else:
is_gpu = False
-
torch_without_guard_version_list = ['2.1', '2.2']
for version in torch_without_guard_version_list:
if torch.__version__.startswith(version):
@@ -222,3 +223,36 @@ class Const:
CONVERT_API = {
"int32_to_int64": ["cross_entropy"]
}
+
+
+def get_tensor_rank(in_feat, out_feat):
+ if dist.is_initialized():
+ return dist.get_rank()
+
+ def get_tensor_rank_single(x):
+ if isinstance(x, (list, tuple)):
+ if len(x) > 0:
+ return get_tensor_rank_single(x[0])
+ elif isinstance(x, torch.Tensor):
+ device = x.device
+ if device.type != 'cpu':
+ return device.index
+ return None
+
+ in_rank = get_tensor_rank_single(in_feat)
+ out_rank = get_tensor_rank_single(out_feat)
+    tensor_rank = in_rank if in_rank is not None else out_rank
+ return tensor_rank
+
+
+def _create_logger(level=logging.INFO):
+ logger_ = logging.getLogger()
+ logger_.setLevel(level)
+ ch = logging.StreamHandler()
+ ch.setLevel(level)
+ logger_.addHandler(ch)
+ return logger_
+
+
+log_level = logging.DEBUG if os.environ.get("API_ACCURACY_CHECK_LOG_LEVEL") == "1" else logging.INFO
+logger = _create_logger(log_level)
diff --git a/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py
similarity index 94%
rename from debug/accuracy_tools/atat/pytorch/compare/acc_compare.py
rename to debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py
index 061c9cdfca872bb76ad8ffa0e292fb24ce1a4ada..16533accb26680e341a71e8f40f30e2ed899301f 100644
--- a/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py
+++ b/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py
@@ -27,15 +27,17 @@ from openpyxl.styles import PatternFill
from collections import namedtuple
from dataclasses import dataclass
-from atat.pytorch.compare.match import graph_mapping
-from atat.pytorch.compare.highlight import HighlightRules, get_header_index
-from atat.pytorch.compare.npy_compare import compare_ops_apply, get_error_type, reshape_value, get_relative_err, get_error_message
-from atat.pytorch.advisor.advisor import Advisor
-from atat.pytorch.common.log import logger
-from atat.core.common.utils import check_compare_param, add_time_with_xlsx, CompareException, \
+from msprobe.pytorch.compare.match import graph_mapping
+from msprobe.pytorch.compare.highlight import HighlightRules, get_header_index
+from msprobe.pytorch.compare.npy_compare import compare_ops_apply, get_error_type, reshape_value, get_relative_err, \
+ get_error_message
+from msprobe.pytorch.advisor.advisor import Advisor
+from msprobe.pytorch.common.log import logger
+from msprobe.core.common.utils import check_compare_param, add_time_with_xlsx, CompareException, \
format_value, check_file_not_exists, check_configuration_param, task_dumppath_get
-from atat.core.common.file_check import FileChecker, change_mode, FileOpen, create_directory
-from atat.core.common.const import Const, CompareConst, FileCheckConst
+from msprobe.core.common.file_utils import FileChecker, change_mode, FileOpen, create_directory
+from msprobe.core.common.const import Const, CompareConst, FileCheckConst
+from msprobe.core.common.exceptions import FileCheckException
def check_graph_mode(a_op_name, b_op_name):
@@ -490,6 +492,10 @@ def compare_by_op(op_name, op_name_mapping_dict, input_parma):
error_file = error.filename
n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
error_flag = True
+ except FileCheckException:
+ error_file = data_name
+ n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
+ error_flag = True
n_value, b_value, error_flag = get_error_type(n_value, b_value, error_flag)
if not error_flag:
@@ -525,8 +531,10 @@ def handle_inf_nan(n_value, b_value):
return n_value, b_value
-def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_compare=False):
+def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_compare=False, md5_compare=False):
"""找到单个API中需要高亮的行"""
+ if md5_compare:
+ return
npu_max_index = get_header_index('NPU max', summary_compare)
bench_max_index = get_header_index('Bench max', summary_compare)
max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare)
@@ -582,7 +590,7 @@ def get_name_and_state(name):
return api_name, state
-def find_compare_result_error_rows(result_df, highlight_dict, summary_compare):
+def find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare):
"""将dataframe根据API分组,并找到有误差的算子用于高亮"""
result = result_df.values
start, input_num, output_num, end = 0, 0, 0, len(result_df)
@@ -600,7 +608,7 @@ def find_compare_result_error_rows(result_df, highlight_dict, summary_compare):
else:
output_num = num
find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict,
- summary_compare)
+ summary_compare, md5_compare)
num, last_api_name, last_state = 1, api_name, state
start += input_num + output_num
input_num, output_num = 1, 0
@@ -611,7 +619,7 @@ def find_compare_result_error_rows(result_df, highlight_dict, summary_compare):
input_num = num
else:
output_num = num
- find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict, summary_compare)
+ find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict, summary_compare, md5_compare)
def highlight_rows_xlsx(result_df, highlight_dict, file_path):
@@ -637,7 +645,11 @@ def highlight_rows_xlsx(result_df, highlight_dict, file_path):
elif (i - 2) in highlight_dict['yellow_rows']:
ws.cell(row=i, column=j).fill = PatternFill(start_color=CompareConst.YELLOW,
end_color=CompareConst.YELLOW, fill_type="solid")
- wb.save(file_path)
+ try:
+ wb.save(file_path)
+ except Exception as e:
+ logger.error('Save result file failed')
+ raise CompareException(CompareException.WRITE_FILE_ERROR) from e
change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
@@ -647,8 +659,8 @@ def compare(input_parma, output_path, stack_mode=False, auto_analyze=True,
summary_compare, md5_compare = task_dumppath_get(input_parma)
check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
create_directory(output_path)
- check_compare_param(input_parma, output_path, stack_mode, summary_compare, md5_compare)
- except CompareException as error:
+ check_compare_param(input_parma, output_path, summary_compare, md5_compare)
+ except (CompareException, FileCheckException) as error:
logger.error('Compare failed. Please check the arguments and do it again!')
sys.exit(error.code)
compare_core(input_parma, output_path, stack_mode=stack_mode,
@@ -696,7 +708,7 @@ def compare_core(input_parma, output_path, **kwargs):
if not md5_compare and not summary_compare:
result_df = _do_multi_process(input_parma, result_df)
- find_compare_result_error_rows(result_df, highlight_dict, summary_compare)
+ find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare)
highlight_rows_xlsx(result_df, highlight_dict, file_path)
if auto_analyze:
advisor = Advisor(result_df, output_path)
@@ -738,7 +750,7 @@ def parse(pkl_file, module_name_prefix):
logger.info(summary_info)
-def op_item_parse(item, op_name, index, item_list=[], top_bool=True):
+def op_item_parse(item, op_name, index, item_list=None, top_bool=True):
if item_list is None:
item_list = []
if item is None or (isinstance(item, dict) and not item):
@@ -756,9 +768,14 @@ def op_item_parse(item, op_name, index, item_list=[], top_bool=True):
else:
full_op_name = op_name
else:
- full_op_name = op_name + '.' + str(index)
+ full_op_name = op_name + Const.SEP + str(index)
if isinstance(item, dict):
- if 'dtype' in item:
+ if 'type' not in item:
+ for kwarg in item:
+ kwarg_parsed_list = op_item_parse(item[kwarg], op_name + Const.SEP + kwarg, None)
+ item_list += kwarg_parsed_list
+ kwarg_parsed_list.clear()
+ elif 'dtype' in item:
parsed_item = item
parsed_item['full_op_name'] = full_op_name
item_list.append(parsed_item)
@@ -800,8 +817,8 @@ def op_item_parse(item, op_name, index, item_list=[], top_bool=True):
else:
resolve_api_special_parameters(item, full_op_name, item_list)
else:
- for j in range(len(item)):
- op_item_parse(item[j], full_op_name, j, item_list=item_list, top_bool=False)
+ for j, item_spec in enumerate(item):
+ op_item_parse(item_spec, full_op_name, j, item_list=item_list, top_bool=False)
return item_list
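The signature change from `item_list=[]` to `item_list=None` fixes Python's classic mutable-default-argument pitfall: a default list is created once at function definition time and shared across all calls. A minimal illustration of the failure mode and the fix:

```python
def parse_buggy(item, item_list=[]):  # default list created once, at def time
    item_list.append(item)
    return item_list


def parse_fixed(item, item_list=None):  # create a fresh list per call instead
    if item_list is None:
        item_list = []
    item_list.append(item)
    return item_list


# The buggy version accumulates state across unrelated calls:
assert parse_buggy("a") == ["a"]
assert parse_buggy("b") == ["a", "b"]  # stale state from the first call

# The fixed version starts clean on each call:
assert parse_fixed("a") == ["a"]
assert parse_fixed("b") == ["b"]
```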
@@ -861,13 +878,13 @@ def read_op(op_data, op_name):
op_parsed_list += output_parsed_list
output_parsed_list.clear()
if 'backward' in op_name:
- if 'grad_input' in op_data:
- input_item = op_data['grad_input']
+ if 'input' in op_data:
+ input_item = op_data['input']
input_parsed_list = op_item_parse(input_item, op_name + '_input', None)
op_parsed_list = input_parsed_list.copy()
input_parsed_list.clear()
- if 'grad_output' in op_data:
- output_item = op_data['grad_output']
+ if 'output' in op_data:
+ output_item = op_data['output']
output_parsed_list = op_item_parse(output_item, op_name + '_output', None)
op_parsed_list += output_parsed_list
output_parsed_list.clear()
@@ -952,8 +969,9 @@ def compare_process(file_handles, stack_mode, fuzzy_match, summary_compare=False
if npu_ops_queue:
for npu_data in npu_ops_queue:
get_un_match_accuracy(result, npu_data, md5_compare, summary_compare)
+ result_to_csv(md5_compare, summary_compare, stack_mode, result)
- header = []
+def result_to_csv(md5_compare, summary_compare, stack_mode, result):
if md5_compare:
header = CompareConst.MD5_COMPARE_RESULT_HEADER[:]
elif summary_compare:
diff --git a/debug/accuracy_tools/atat/pytorch/compare/distributed_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py
similarity index 85%
rename from debug/accuracy_tools/atat/pytorch/compare/distributed_compare.py
rename to debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py
index b89adc1581e8b0cf76ca28c41dbf0e86738ebece..b2b3a6672c49f38ed8c27ef42f13b52bbdb051f5 100644
--- a/debug/accuracy_tools/atat/pytorch/compare/distributed_compare.py
+++ b/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py
@@ -17,11 +17,12 @@
import os
import sys
import re
-from atat.core.common.utils import CompareException, check_compare_param, \
+from msprobe.core.common.utils import CompareException, check_compare_param, \
check_configuration_param, task_dumppath_get, check_file_or_directory_path, check_regex_prefix_format_valid
-from atat.pytorch.compare.acc_compare import compare_core
-from atat.core.common.file_check import create_directory
-from atat.pytorch.common.log import logger
+from msprobe.pytorch.compare.acc_compare import compare_core
+from msprobe.core.common.file_utils import create_directory
+from msprobe.core.common.exceptions import FileCheckException
+from msprobe.pytorch.common.log import logger
def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
@@ -86,12 +87,11 @@ def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
'or use compare() api and manually match the ranks.')
raise CompareException(CompareException.INVALID_PATH_ERROR)
for nr, br in zip(npu_ranks, bench_ranks):
- n_dir = os.path.join(npu_dump_dir, nr)
- b_dir = os.path.join(bench_dump_dir, br)
- s_dir = b_dir
- npu_json_path = extract_json(n_dir, stack_json=False)
- bench_json_path = extract_json(b_dir, stack_json=False)
- stack_json_path = extract_json(s_dir, stack_json=True)
+ npu_data_dir = os.path.join(npu_dump_dir, nr)
+ bench_data_dir = os.path.join(bench_dump_dir, br)
+ npu_json_path = extract_json(npu_data_dir, stack_json=False)
+ bench_json_path = extract_json(bench_data_dir, stack_json=False)
+ stack_json_path = extract_json(npu_data_dir, stack_json=True)
dump_result_param = {
'npu_json_path': npu_json_path,
@@ -103,8 +103,8 @@ def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
summary_compare, md5_compare = task_dumppath_get(dump_result_param)
check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
create_directory(output_path)
- check_compare_param(dump_result_param, output_path, stack_mode=stack_mode, summary_compare=summary_compare)
- except CompareException as error:
+ check_compare_param(dump_result_param, output_path, summary_compare=summary_compare, md5_compare=md5_compare)
+ except (CompareException, FileCheckException) as error:
logger.error('Compare failed. Please check the arguments and do it again!')
sys.exit(error.code)
compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}', summary_compare=summary_compare,
diff --git a/debug/accuracy_tools/atat/pytorch/compare/highlight.py b/debug/accuracy_tools/msprobe/pytorch/compare/highlight.py
similarity index 97%
rename from debug/accuracy_tools/atat/pytorch/compare/highlight.py
rename to debug/accuracy_tools/msprobe/pytorch/compare/highlight.py
index 3a6898dedbb6910d0c1c9e55f80b55eb4fa0ed3c..82f0022f8b5d4a0c6472b749d4937bfe39ef8a86 100644
--- a/debug/accuracy_tools/atat/pytorch/compare/highlight.py
+++ b/debug/accuracy_tools/msprobe/pytorch/compare/highlight.py
@@ -1,8 +1,8 @@
import math
import abc
import numpy as np
-from atat.core.common.utils import get_header_index
-from atat.core.common.const import CompareConst
+from msprobe.core.common.utils import get_header_index
+from msprobe.core.common.const import CompareConst
class HighlightCheck(abc.ABC):
diff --git a/debug/accuracy_tools/atat/pytorch/compare/mapping.yaml b/debug/accuracy_tools/msprobe/pytorch/compare/mapping.yaml
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/compare/mapping.yaml
rename to debug/accuracy_tools/msprobe/pytorch/compare/mapping.yaml
diff --git a/debug/accuracy_tools/atat/pytorch/compare/match.py b/debug/accuracy_tools/msprobe/pytorch/compare/match.py
similarity index 91%
rename from debug/accuracy_tools/atat/pytorch/compare/match.py
rename to debug/accuracy_tools/msprobe/pytorch/compare/match.py
index 148fbb7d640b3fde5e3292508c29404e277cde84..ca335f7d8ec9bda9b594ff9cd2c9d5ed1f7a053f 100644
--- a/debug/accuracy_tools/atat/pytorch/compare/match.py
+++ b/debug/accuracy_tools/msprobe/pytorch/compare/match.py
@@ -1,7 +1,7 @@
import os
import yaml
-from atat.core.common.file_check import FileOpen
-from atat.core.common.utils import CompareException
+from msprobe.core.common.file_utils import FileOpen
+from msprobe.core.common.utils import CompareException
class AtenIrMapping():
diff --git a/debug/accuracy_tools/atat/pytorch/compare/npy_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/npy_compare.py
similarity index 98%
rename from debug/accuracy_tools/atat/pytorch/compare/npy_compare.py
rename to debug/accuracy_tools/msprobe/pytorch/compare/npy_compare.py
index 0cf4c6c00a0a671a6ee46dc6dacce48af6e67adf..5a0feb4cd4a63b6f2ab680c9e9a0f0e92b594e2e 100644
--- a/debug/accuracy_tools/atat/pytorch/compare/npy_compare.py
+++ b/debug/accuracy_tools/msprobe/pytorch/compare/npy_compare.py
@@ -1,8 +1,8 @@
import abc
import numpy as np
-from atat.core.common.utils import format_value
-from atat.core.common.const import Const, CompareConst
-from atat.pytorch.common.log import logger
+from msprobe.core.common.utils import format_value
+from msprobe.core.common.const import Const, CompareConst
+from msprobe.pytorch.common.log import logger
def handle_inf_nan(n_value, b_value):
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/__init__.py b/debug/accuracy_tools/msprobe/pytorch/debugger/__init__.py
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/__init__.py
rename to debug/accuracy_tools/msprobe/pytorch/debugger/__init__.py
diff --git a/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py b/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py
similarity index 89%
rename from debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py
rename to debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py
index 1ad69701e4167db3e2e9f61f7f725f024ec21d16..851a61d04b1442dac30fb93b8b3c888943d10a1f 100644
--- a/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py
+++ b/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py
@@ -1,6 +1,6 @@
-from atat.pytorch.common import seed_all
-from atat.pytorch.common.log import logger
-from atat.core.common.const import Const
+from msprobe.pytorch.common import seed_all
+from msprobe.pytorch.common.log import logger
+from msprobe.core.common.const import Const
class DebuggerConfig:
@@ -13,6 +13,7 @@ class DebuggerConfig:
self.seed = common_config.seed if common_config.seed else 1234
self.is_deterministic = common_config.is_deterministic
self.enable_dataloader = common_config.enable_dataloader
+ self.enable_step_auto_dump = common_config.enable_step_auto_dump
self.scope = task_config.scope if task_config.scope else []
self.list = task_config.list if task_config.list else []
self.data_mode = task_config.data_mode if task_config.data_mode else ["all"]
@@ -21,7 +22,7 @@ class DebuggerConfig:
self.acl_config = common_config.acl_config if common_config.acl_config else ""
self.is_forward_acl_dump = True
self.summary_mode = task_config.summary_mode if task_config.summary_mode else Const.STATISTICS
- self.overflow_num = task_config.overflow_num if task_config.overflow_num else 1
+ self.overflow_nums = task_config.overflow_nums if task_config.overflow_nums else 1
self.framework = Const.PT_FRAMEWORK
if self.task == Const.FREE_BENCHMARK:
@@ -46,9 +47,8 @@ class DebuggerConfig:
raise ValueError("backward_input must be configured when scope contains 'backward'")
if Const.BACKWARD in self.scope[0]:
self.is_forward_acl_dump = False
- for index in range(len(self.scope)):
- # Do this replace operation to let the acl backward dump can be done in forward hook.
- self.scope[index] = self.scope[index].replace(Const.BACKWARD, Const.FORWARD)
+ for index, scope_spec in enumerate(self.scope):
+ self.scope[index] = scope_spec.replace(Const.BACKWARD, Const.FORWARD)
self.backward_input[self.scope[index]] = self.backward_input_list[index]
seed_all(self.seed, self.is_deterministic)
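The loop above rewrites backward scope entries to their forward names so that the ACL backward dump can be triggered from a forward hook (as the removed comment noted), keeping the grad inputs keyed by the rewritten name. A standalone sketch of the rewrite, with hypothetical constants standing in for `Const.BACKWARD`/`Const.FORWARD`:

```python
# Hypothetical constants mirroring Const.BACKWARD / Const.FORWARD
BACKWARD, FORWARD = "backward", "forward"

scope = ["Tensor.add.0.backward", "Functional.relu.1.backward"]
backward_input_list = ["./add_grad.npy", "./relu_grad.npy"]
backward_input = {}

# Rename each backward scope entry to its forward counterpart and map the
# rewritten name to the corresponding saved gradient input file.
for index, scope_spec in enumerate(scope):
    scope[index] = scope_spec.replace(BACKWARD, FORWARD)
    backward_input[scope[index]] = backward_input_list[index]

assert scope == ["Tensor.add.0.forward", "Functional.relu.1.forward"]
assert backward_input["Tensor.add.0.forward"] == "./add_grad.npy"
```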
diff --git a/debug/accuracy_tools/atat/pytorch/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py
similarity index 49%
rename from debug/accuracy_tools/atat/pytorch/debugger/precision_debugger.py
rename to debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py
index 140d829bedc6fb1243820c0d0ccc84af42f01424..2d566800906cfce2eca2404211b7f6e261f44d5d 100644
--- a/debug/accuracy_tools/atat/pytorch/debugger/precision_debugger.py
+++ b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py
@@ -1,10 +1,10 @@
import torch
from torch.utils.data import dataloader
-from atat.pytorch.debugger.debugger_config import DebuggerConfig
-from atat.pytorch.service import Service
-from atat.pytorch.common.log import logger
-from atat.pytorch.pt_config import parse_json_config
-from atat.core.common.exceptions import MsaccException
+from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
+from msprobe.pytorch.service import Service
+from msprobe.pytorch.common.log import logger
+from msprobe.pytorch.pt_config import parse_json_config
+from msprobe.core.common.exceptions import MsprobeException
class PrecisionDebugger:
@@ -25,13 +25,17 @@ class PrecisionDebugger:
level=None,
model=None,
step=None,
+ enable_step_auto_dump=None
):
if not hasattr(self, "initialized"):
+ self.api_origin = False
self.initialized = True
self.model = self.check_model_valid(model)
common_config, task_config = parse_json_config(config_path, task)
if step:
common_config.step = step
+ if enable_step_auto_dump:
+ common_config.enable_step_auto_dump = enable_step_auto_dump
self.config = DebuggerConfig(
common_config, task_config, task, dump_path, level
)
@@ -41,11 +45,37 @@ class PrecisionDebugger:
if self.enable_dataloader:
logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.")
dataloader._BaseDataLoaderIter.__next__ = iter_tracer(dataloader._BaseDataLoaderIter.__next__)
+ self.enable_step_auto_dump = self.config.enable_step_auto_dump
+ if self.enable_step_auto_dump:
+ self.start_for_optimizer()
@property
def instance(self):
return self._instance
+ @staticmethod
+ def check_model_valid(model):
+ if not model or isinstance(model, torch.nn.Module):
+ return model
+ raise MsprobeException(
+ MsprobeException.INVALID_PARAM_ERROR, "model 参数必须是torch.nn.Module类型。"
+ )
+
+    # Non-intrusive dump enablement method
+ @classmethod
+ def start_for_optimizer(cls):
+ instance = cls._instance
+ if not instance:
+ raise Exception("No instance of PrecisionDebugger found.")
+ elif torch.__version__ < '2.0.0':
+            raise Exception("PyTorch versions earlier than 2.0.0 do not support optimizer hooks. \
+                Please turn off enable_step_auto_dump and use the start, stop and step methods of PrecisionDebugger instead.")
+ else:
+            logger.info_on_rank_0("enable_step_auto_dump is on; start()/stop()/step() will not take effect.")
+            logger.warning_on_rank_0("Customized optimizer iteration is not supported. Please use the start, stop and step methods when using a customized optimizer.")
+ instance.service.hook_optimizer(instance.model)
+ instance.service.start(instance.model)
+
@classmethod
def start(cls):
instance = cls._instance
@@ -53,8 +83,25 @@ class PrecisionDebugger:
raise Exception("No instance of PrecisionDebugger found.")
if instance.enable_dataloader:
logger.warning_on_rank_0("DataLoader is enabled, start() skipped.")
+ elif instance.enable_step_auto_dump:
+ logger.warning_on_rank_0("optimizer is enabled, start() skipped.")
else:
- instance.service.start(instance.model)
+ instance.service.start(instance.model, instance.api_origin)
+ instance.api_origin = False
+
+    # Marks the end of forward/backward dump for a specified code section; computation data after this point is ignored and cannot be dumped
+ @classmethod
+ def forward_backward_dump_end(cls):
+ instance = cls._instance
+ if not instance:
+ raise Exception("PrecisionDebugger instance is not created.")
+ if instance.enable_dataloader:
+ logger.warning_on_rank_0("DataLoader is enabled, forward_backward_dump_end() skipped.")
+ elif instance.enable_step_auto_dump:
+ logger.warning_on_rank_0("optimizer is enabled, forward_backward_dump_end() skipped.")
+ else:
+ instance.service.forward_backward_dump_end()
+ instance.api_origin = True
@classmethod
def stop(cls):
@@ -63,23 +110,20 @@ class PrecisionDebugger:
raise Exception("PrecisionDebugger instance is not created.")
if instance.enable_dataloader:
logger.warning_on_rank_0("DataLoader is enabled, stop() skipped.")
+ elif instance.enable_step_auto_dump:
+ logger.warning_on_rank_0("optimizer is enabled, stop() skipped.")
else:
instance.service.stop()
@classmethod
def step(cls):
- if not cls._instance:
+ instance = cls._instance
+ if not instance:
raise Exception("PrecisionDebugger instance is not created.")
- cls._instance.service.step()
-
- @staticmethod
- def check_model_valid(model):
- if not model or isinstance(model, torch.nn.Module):
- return model
- raise MsaccException(
- MsaccException.INVALID_PARAM_ERROR, "model 参数必须是torch.nn.Module类型。"
- )
-
+ elif instance.enable_step_auto_dump:
+ logger.warning_on_rank_0("optimizer is enabled, step() skipped.")
+ else:
+ instance.service.step()
def iter_tracer(func):
def func_wrapper(*args, **kwargs):
diff --git a/debug/accuracy_tools/atat/pytorch/doc/FAQ.md b/debug/accuracy_tools/msprobe/pytorch/doc/FAQ.md
similarity index 72%
rename from debug/accuracy_tools/atat/pytorch/doc/FAQ.md
rename to debug/accuracy_tools/msprobe/pytorch/doc/FAQ.md
index 19a434a194686e1d75c411785e052c7dc0b085f9..8d12a72928ee4d9977b1db05a72f2189b2edd3c1 100644
--- a/debug/accuracy_tools/atat/pytorch/doc/FAQ.md
+++ b/debug/accuracy_tools/msprobe/pytorch/doc/FAQ.md
@@ -22,15 +22,15 @@
6. After adding the pre-check tool, slicing operations raise `IndexError: too many indices for tensor of dimension x` or `TypeError: len() of a 0-d tensor`.
-   A: Comment out `- __getitem__` under Tensor: in the file mstt/debug/accuracy_tools/atat/pytorch/hook_module/support_wrap_ops.yaml; the tool will then skip dumping this API. For an API at a critical location that must be dumped, you can instead comment out the type check that raises the error, based on the error stack trace.
+   A: Comment out `- __getitem__` under Tensor: in the file mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml; the tool will then skip dumping this API. For an API at a critical location that must be dumped, you can instead comment out the type check that raises the error, based on the error stack trace.
7. After adding the pre-check tool, F.gelu raises a ValueError such as `activation_func must be F.gelu`.
-   A: Comment out `- gelu` under functional: in the file mstt/debug/accuracy_tools/atat/pytorch/hook_module/support_wrap_ops.yaml; the tool will then skip dumping this API. For an API at a critical location that must be dumped, you can instead comment out the type check that raises the error, based on the error stack trace.
+   A: Comment out `- gelu` under functional: in the file mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml; the tool will then skip dumping this API. For an API at a critical location that must be dumped, you can instead comment out the type check that raises the error, based on the error stack trace.
8. After adding the pre-check tool, errors related to the AsStrided operator or to compilation occur, such as `Failed to compile Op [AsStrided]`.
-   A: Comment out `- t` and `- transpose` under Tensor: in the file mstt/debug/accuracy_tools/atat/pytorch/hook_module/support_wrap_ops.yaml.
+   A: Comment out `- t` and `- transpose` under Tensor: in the file mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml.
9. Which operations do the Tensor magic methods correspond to?
@@ -75,7 +75,7 @@
### Dumping specified fused operators
-Targeted dump currently supports dumping the inputs and outputs of specified fused operators. They must be added to mstt/debug/accuracy_tools/atat/pytorch/hook_module/support_wrap_ops.yaml, for example the fused softmax operator invoked by the following code snippet:
+Targeted dump currently supports dumping the inputs and outputs of specified fused operators. They must be added to mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml, for example the fused softmax operator invoked by the following code snippet:
```
def npu_forward_fused_softmax(self, input_, mask):
@@ -111,7 +111,7 @@ torch版本和硬件差异属于正常情况。
**Symptom**
-When using the atat tool, an error is reported: error code: EI0006.
+When using the msprobe tool, an error is reported: error code: EI0006.
**Cause**
@@ -136,7 +136,7 @@ torch.npu.set_device('npu:0')
torch.npu.set_device(f'npu:{rank}')
```
-If this error occurs when running the accuracy comparison feature, try installing the latest version of atat.
+If this error occurs when running the accuracy comparison feature, try installing the latest version of msprobe.
### 4. Is dumped data with names like VF_lstm_99_forward_input.1.0.npy and VF_lstm_99_forward_input.1.1.npy normal?
@@ -147,7 +147,7 @@ torch.npu.set_device(f'npu:{rank}')
In the comparison script, set stack_mode=True, for example:
```
-from atat.pytorch import compare
+from msprobe.pytorch import compare
dump_result_param={
"npu_json_path": "./npu_dump/dump.json",
"bench_json_path": "./gpu_dump/dump.json",
@@ -174,20 +174,20 @@ compare(dump_result_param, output_path="./output", stack_mode=True)
### 9. Some APIs in the dump.json file have dtype float16, but the npy file read for the API shows dtype float32
-- When dumping data, the atat tool first moves the original data from NPU to CPU and then converts it to a numpy type. The NPU-to-CPU logic is consistent with GPU-to-CPU: in both cases the dtype may change from float16 to float32. If a dtype mismatch occurs, the dtype recorded in the pkl file is authoritative for the dumped data.
+- When dumping data, the msprobe tool first moves the original data from NPU to CPU and then converts it to a numpy type. The NPU-to-CPU logic is consistent with GPU-to-CPU: in both cases the dtype may change from float16 to float32. If a dtype mismatch occurs, the dtype recorded in the pkl file is authoritative for the dumped data.
-### 10. After using a dataloader, an exception is raised: Exception("atat: exit after iteration {}".format(max(self.config.step)))
+### 10. After using a dataloader, an exception is raised: Exception("msprobe: exit after iteration {}".format(max(self.config.step)))
- This is expected: the dataloader ends the program via raise, and the stack trace can be ignored.
-### 11. After adding the atat tool, slicing operations raise `IndexError: too many indices for tensor of dimension x` or `TypeError: len() of a 0-d tensor`.
+### 11. After adding the msprobe tool, slicing operations raise `IndexError: too many indices for tensor of dimension x` or `TypeError: len() of a 0-d tensor`.
-- Comment out `- __getitem__` under Tensor: in the file mstt/debug/accuracy_tools/atat/pytorch/hook_module/support_wrap_ops.yaml; the tool will then skip dumping this API. For an API at a critical location that must be dumped, you can instead comment out the type check that raises the error, based on the error stack trace.
+- Comment out `- __getitem__` under Tensor: in the file mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml; the tool will then skip dumping this API. For an API at a critical location that must be dumped, you can instead comment out the type check that raises the error, based on the error stack trace.
-### 12. After adding the atat tool, F.gelu raises a ValueError such as `activation_func must be F.gelu`.
+### 12. After adding the msprobe tool, F.gelu raises a ValueError such as `activation_func must be F.gelu`.
-- Comment out `- gelu` under functional: in the file mstt/debug/accuracy_tools/atat/pytorch/hook_module/support_wrap_ops.yaml; the tool will then skip dumping this API. For an API at a critical location that must be dumped, you can instead comment out the type check that raises the error, based on the error stack trace.
+- Comment out `- gelu` under functional: in the file mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml; the tool will then skip dumping this API. For an API at a critical location that must be dumped, you can instead comment out the type check that raises the error, based on the error stack trace.
-### 13. After adding the atat tool, errors related to the AsStrided operator or to compilation occur, such as `Failed to compile Op [AsStrided]`.
+### 13. After adding the msprobe tool, errors related to the AsStrided operator or to compilation occur, such as `Failed to compile Op [AsStrided]`.
-- Comment out `- t` and `- transpose` under Tensor: in the file mstt/debug/accuracy_tools/atat/pytorch/hook_module/support_wrap_ops.yaml.
+- Comment out `- t` and `- transpose` under Tensor: in the file mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml.
diff --git a/debug/accuracy_tools/atat/pytorch/doc/api_accuracy_checker.md b/debug/accuracy_tools/msprobe/pytorch/doc/api_accuracy_checker.md
similarity index 84%
rename from debug/accuracy_tools/atat/pytorch/doc/api_accuracy_checker.md
rename to debug/accuracy_tools/msprobe/pytorch/doc/api_accuracy_checker.md
index 0e45a4e83fb720411b07923e7888cff404148d8d..41b97098ae95394511a375bc9f17a2041de6a826 100644
--- a/debug/accuracy_tools/atat/pytorch/doc/api_accuracy_checker.md
+++ b/debug/accuracy_tools/msprobe/pytorch/doc/api_accuracy_checker.md
@@ -20,8 +20,8 @@
The accuracy pre-check workflow is as follows:
-1. Install the atat tool in both the NPU and GPU environments. See the "Tool Installation" section of [MindStudio Accuracy Debugging Tools](../../README.md).
-2. Add the atat dump interface PrecisionDebugger to the NPU training script to collect the data to be pre-checked. See [Accuracy Data Collection](./dump.md).
+1. Install the msprobe tool in both the NPU and GPU environments. See the "Tool Installation" section of [MindStudio Accuracy Debugging Tools](../../README.md).
+2. Add the msprobe dump interface PrecisionDebugger to the NPU training script to collect the data to be pre-checked. See [Accuracy Data Collection](./dump.md); note that level="L1" must be configured.
3. Copy the pre-check data dumped in the NPU environment to the GPU environment.
4. Run run_ut in both the NPU and GPU environments to generate the results used as input for the final api_precision_compare operation. See "**run_ut pre-check operation**".
5. Copy the `accuracy_checking_details_{timestamp}.csv` result files generated by run_ut on NPU and GPU to the same environment.
@@ -43,7 +43,7 @@ run_ut预检操作包括如下场景:
1. Feed the API information to the run_ut module to run accuracy checking and comparison:
```bash
- atat -f pytorch run_ut -api_info ./dump.json
+ msprobe -f pytorch run_ut -api_info ./dump.json
```
| Parameter | Description | Required |
@@ -51,20 +51,22 @@ run_ut预检操作包括如下场景:
| -api_info or --api_info_file | Specifies the API information file dump.json. | Yes |
| -save_error_data | Saves the input and output data of APIs that fail the accuracy check. | No |
| -o or --out_path | Specifies where run_ut results are saved; defaults to "./" (relative to where run_ut is launched). | No |
| -j or --jit_compile | Enables JIT compilation. | No |
| -d or --device | Specifies the Device ID, i.e. the card on which the UT code runs; defaults to 0. | No |
| -csv_path or --result_csv_path | Specifies the path of the `accuracy_checking_result_{timestamp}.csv` file generated when a run_ut run was interrupted. To resume from the interruption point, set this parameter to that file. See "**Resuming an interrupted check**". | Required when resuming an interrupted run_ut |
| -f or --filter_api | Filters out APIs whose parameters and structure are identical except for their maximum and minimum values. Suitable for large models with many repeated APIs. | No |
+   | -config or --config_path | Specifies the [config.json](https://gitee.com/ascend/mstt/tree/master/debug/accuracy_tools/msprobe/config) file with extra pre-check configuration (blacklist, whitelist, etc.); unset by default. For its options, see the [Configuration File Guide](https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/config/README.md#pytorch场景task配置为run_ut). | No |
run_ut produces two files: `accuracy_checking_result_{timestamp}.csv` and `accuracy_checking_details_{timestamp}.csv`. The former is at API granularity and indicates whether each API passed the test. Users are advised to review `accuracy_checking_result_{timestamp}.csv` first; for APIs that failed or are of particular interest, use the API name field to look up the pass status and comparison metrics of each output in `accuracy_checking_details_{timestamp}.csv`. For details, see "**Pre-check results**".
2. (Optional) To save the inputs and outputs of failed comparisons, append -save_error_data to the run_ut command, for example:
```bash
- atat -f pytorch run_ut -api_info ./dump.json -save_error_data
+ msprobe -f pytorch run_ut -api_info ./dump.json -save_error_data
```
-   By default the data is saved under './ut_error_data{timestamp}' (relative to where run_ut is launched). If needed, users can configure the save path by modifying the error_data_path parameter of the config.yaml file in the mstt/debug/accuracy_tools/api_accuracy_checker directory; see "config.yaml file description".
+   By default the data is saved under './ut_error_data{timestamp}' (relative to where run_ut is launched). If needed, users can configure the save path via the error_data_path parameter, which can be set in the [config.json](https://gitee.com/ascend/mstt/tree/master/debug/accuracy_tools/msprobe/config) file or the config.yaml file. The config.json file must be specified with the -config parameter when running run_ut; for config.yaml, see "**config.yaml file description**".
#### Running multi-threaded pre-checks with multi_run_ut.py
@@ -73,7 +75,7 @@ multi_run_ut.py脚本,可以并行执行多个run_ut操作,从而降低预
Example command:
```bash
-atat -f pytorch multi_run_ut -api_info ./dump.json -n 32 -d 0 1 2 3
+msprobe -f pytorch multi_run_ut -api_info ./dump.json -n 32 -d 0 1 2 3
```
| Parameter | Description | Required |
@@ -96,26 +98,68 @@ atat -f pytorch multi_run_ut -api_info ./dump.json -n 32 -d 0 1 2 3
Resuming an interrupted check is done with the following command:
```bash
-atat -f pytorch run_ut -api_info ./dump.json -csv_path /home/xxx/ut/accuracy_checking_result_{timestamp}.csv
+msprobe -f pytorch run_ut -api_info ./dump.json -csv_path /home/xxx/ut/accuracy_checking_result_{timestamp}.csv
```
-#### API pre-check whitelist
+#### API pre-check blacklist and whitelist
-The run_ut process supports an API pre-check whitelist, configured as follows:
+The run_ut process supports API pre-check blacklists and whitelists. Configure the black_list or white_list parameter in the following files to specify the API names to exclude from or include in the pre-check:
-Modify the white_list parameter of the config.yaml file in the mstt/debug/accuracy_tools/api_accuracy_checker directory to specify the APIs to pre-check; see "config.yaml file description".
+- Configure the [config.json](https://gitee.com/ascend/mstt/tree/master/debug/accuracy_tools/msprobe/config) file; it must be specified with the -config parameter when running run_ut.
+- Configure the config.yaml file; see "**config.yaml file description**".
+
+The config.json file takes precedence over the config.yaml file: when a config.json file is in use, the config.yaml settings do not take effect.
### config.yaml file description
-The config.yaml file controls features such as the whitelist for dump and run_ut operations via its parameters.
+The config.yaml file controls features such as the whitelist and blacklist for dump and run_ut operations via its parameters. The steps are as follows:
+
+1. Find the msprobe installation path.
+
+ ```bash
+ pip show mindstudio-probe
+ ```
+
+   Example output:
+
+ ```bash
+ Name: mindstudio-probe
+ Version: 1.0
+ Summary: This is a pytorch precision comparison tools
+ Home-page:
+ Author:
+ Author-email:
+ License:
+ Location: /home/xx/anaconda3/envs/pt21py38/lib/python3.8/site-packages
+ Requires: numpy, openpyxl, pandas, pyyaml, rich, tqdm, wheel
+ Required-by:
+ ```
+
+   The Location field is the msprobe installation path, so the config.yaml file is located at /home/xx/anaconda3/envs/pt21py38/lib/python3.8/site-packages/msprobe/pytorch/api_accuracy_checker/config.yaml
+
+2. Open the config.yaml file.
+
+ ```bash
+ vi /home/xx/anaconda3/envs/pt21py38/lib/python3.8/site-packages/msprobe/pytorch/api_accuracy_checker/config.yaml
+ ```
+
+3. Modify the config.yaml parameters.
+
+ ```yaml
+ white_list: []
+ black_list: []
+ error_data_path: './'
+ precision: 14
+ ```
-The file path is mstt/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/config.yaml
+   | Parameter | Description | Required |
+   | --------------- | ------------------------------------------------------------ | -------- |
+   | white_list | API dump whitelist; only the specified APIs are dumped. Example: white_list=["conv1d", "conv2d"]. By default no whitelist is configured, i.e. all API data is dumped. | No |
+   | black_list | API dump blacklist; the specified APIs are not dumped. Example: black_list=["conv1d", "conv2d"]. By default no blacklist is configured, i.e. all API data is dumped. | No |
+   | error_data_path | Path for saving the input and output data of APIs that fail the accuracy check. Example: "error_data_path": "./". Defaults to the current path. | No |
+   | precision | Number of decimal places for floating-point display; defaults to 14. | No |
-| Parameter | Description | Required |
-| --------------- | ------------------------------------------------------------ | -------- |
-| white_list | API dump whitelist, which specifies the APIs to dump; it can also serve directly as the pre-check API whitelist, see "**API pre-check whitelist**". Example: white_list=["conv1d", "conv2d"]. By default no whitelist is configured, i.e. all API data is dumped. | No |
-| error_data_path | Path for saving the input and output data of APIs that fail the accuracy check. | No |
-| precision | Number of decimal places for floating-point display; defaults to 14. | No |
+   Note: when white_list and black_list are both configured and the two API lists do not intersect, the whitelist takes effect; if they do intersect, the APIs in the intersection, as well as those excluded by the whitelist, are not dumped.
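The note on list precedence can be restated as set arithmetic. A hedged reading (a hypothetical helper, not the tool's actual implementation) is that an API is dumped only if it passes the whitelist (when one is set) and is not blacklisted:

```python
def effective_dump_apis(all_apis, white_list, black_list):
    # With no whitelist, every API is a candidate; otherwise only whitelisted
    # APIs are. Blacklisted APIs are always removed, so APIs in the
    # intersection of the two lists are not dumped.
    candidates = set(white_list) if white_list else set(all_apis)
    return candidates - set(black_list)


apis = ["conv1d", "conv2d", "gelu", "softmax"]
assert effective_dump_apis(apis, ["conv1d", "conv2d"], ["conv2d"]) == {"conv1d"}
assert effective_dump_apis(apis, [], ["gelu"]) == {"conv1d", "conv2d", "softmax"}
```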
## Pre-check results
@@ -203,7 +247,7 @@ API预检通过测试,则在`accuracy_checking_details_{timestamp}.csv`文件
Obtain the run_ut pre-check result files `accuracy_checking_details_{timestamp}.csv` from both the NPU and GPU environments, then run the following command to compare the NPU and GPU pre-check results:
```bash
-atat -f pytorch api_precision_compare -npu /home/xxx/npu/accuracy_checking_details_{timestamp}.csv -gpu /home/xxx/gpu/accuracy_checking_details_{timestamp}.csv -o /home/xxx/
+msprobe -f pytorch api_precision_compare -npu /home/xxx/npu/accuracy_checking_details_{timestamp}.csv -gpu /home/xxx/gpu/accuracy_checking_details_{timestamp}.csv -o /home/xxx/
```
| Parameter | Description | Required |
diff --git a/debug/accuracy_tools/atat/pytorch/doc/dump.md b/debug/accuracy_tools/msprobe/pytorch/doc/dump.md
similarity index 66%
rename from debug/accuracy_tools/atat/pytorch/doc/dump.md
rename to debug/accuracy_tools/msprobe/pytorch/doc/dump.md
index 1e401b4f5a22fc137872f8678591b9c1eca094b6..7e393cd1026f5881de98ad830ddcc29bc6946673 100644
--- a/debug/accuracy_tools/atat/pytorch/doc/dump.md
+++ b/debug/accuracy_tools/msprobe/pytorch/doc/dump.md
@@ -1,8 +1,8 @@
# **Accuracy Data Collection**
-The atat tool collects accuracy data mainly by adding dump interfaces to the training script and then launching training.
+The msprobe tool collects accuracy data mainly by adding dump interfaces to the training script and then launching training.
-Performing a dump requires installing the atat tool. See the "Tool Installation" section of [MindStudio Accuracy Debugging Tools](../../README.md).
+Performing a dump requires installing the msprobe tool. See the "Tool Installation" section of [MindStudio Accuracy Debugging Tools](../../README.md).
## dump interface overview
@@ -12,7 +12,7 @@ atat工具主要通过在训练脚本内添加dump接口并启动训练的方式
The detailed dump configuration is determined by loading a dump configuration file.
-This interface can be added anywhere between from atat.pytorch import PrecisionDebugger and model initialization.
+The PrecisionDebugger interface can be added anywhere after from msprobe.pytorch import PrecisionDebugger. For detailed usage, see "**Example code**" or "**model configuration code example**".
**Prototype**
@@ -20,7 +20,7 @@ atat工具主要通过在训练脚本内添加dump接口并启动训练的方式
PrecisionDebugger(config_path=None, task=None, dump_path=None, level=None, model=None, step=None)
```
-Note: except for config_path and model, all of the above parameters can also be configured in the [config.json](../../config) file. Parameters given here take precedence over the [config.json](../../config) configuration, but config.json supports more parameters; to dump accuracy data in more scenarios, configuring the [config.json](../../config) file is recommended.
+Note: except for config_path and model, all of the above parameters can also be configured in the [config.json](../../config) file. Parameters given here take precedence over the [config.json](../../config) configuration, but config.json supports more parameters; to dump accuracy data in more scenarios, configuring the [config.json](../../config) file is recommended. The config.json options are described in the [Configuration File Guide](https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/config/README.md).
**Parameters**
@@ -44,7 +44,7 @@ import torch
import torch.nn as nn
import torch_npu
import torch.nn.functional as F
-from atat.pytorch import PrecisionDebugger
+from msprobe.pytorch import PrecisionDebugger
torch.npu.set_device("npu:0")
# Define a simple network
@@ -77,9 +77,9 @@ if __name__ == "__main__"
**Function**
-Start function.
+Dump start function.
-Add it anywhere after model initialization.
+Add it after model initialization. It must be placed inside the for loop together with the stop function.
**Prototype**
@@ -93,9 +93,9 @@ debugger.start()
**Function**
-Stop function.
+Dump stop function.
-Add it anywhere after the **start** function.
+Add it anywhere after the **start** function. To dump backward data, it must come after the backward computation code (e.g. loss.backward).
**Prototype**
@@ -105,13 +105,33 @@ debugger.stop()
This is a class method; either debugger.stop() or PrecisionDebugger.stop() can be used.
+### forward_backward_dump_end function
+
+**Function**
+
+Dump stop function, used to dump the forward and backward data of a specified code section.
+
+Add it anywhere after the **start** function and before the backward computation code (e.g. loss.backward). It dumps the forward and backward data between the **start** function and this function; by adjusting the positions of the two calls, you can select the code block to dump.
+
+The **stop** function must still be added after the backward computation code (e.g. loss.backward); the code between this function and the **stop** function is then not dumped.
+
+For a usage example, see "**Example code > Extended example**".
+
+**Prototype**
+
+```Python
+forward_backward_dump_end()
+```
+
+This is a class method; either debugger.forward_backward_dump_end() or PrecisionDebugger.forward_backward_dump_end() can be used.
+
### step function
**Function**
End-of-step marker.
-Add it after the last **stop** function or at the end of a step.
+Add it after the last **stop** function or at the end of a step. It must be placed inside the for loop together with the start function.
**Prototype**
@@ -123,24 +143,57 @@ debugger.step()
## Example code
+### Basic usage
+
+The following example dumps the forward and backward data of the complete code.
+
```Python
-from atat.pytorch import PrecisionDebugger
+from msprobe.pytorch import PrecisionDebugger
+
+# Do not place the PrecisionDebugger initialization inside loop code
debugger = PrecisionDebugger(config_path="./config.json", dump_path="./dump_path")
-# Do not place the above initialization inside loop code
-# Model initialization
-# The code below can also use PrecisionDebugger.start() and PrecisionDebugger.stop()
-debugger.start()
+# Define and initialize the model, loss function, etc.
+# ...
-# Code snippet 1 to dump
+# Dataset iteration usually marks the start of model training
+for data, label in data_loader:
+    debugger.start() # Enable data dump
-debugger.stop()
-debugger.start()
+    # Logic executed by the model at each step
+ output = model(data)
+ #...
+ loss.backward()
+
+    debugger.stop() # Disable data dump
+    debugger.step() # End the dump for one step
+```
-# Code snippet 2 to dump
+### Extended example
-debugger.stop()
-debugger.step()
+The following example dumps the forward and backward data of a specified code block.
+
+```Python
+from msprobe.pytorch import PrecisionDebugger
+
+# Do not place the PrecisionDebugger initialization inside loop code
+debugger = PrecisionDebugger(config_path="./config.json", dump_path="./dump_path")
+
+# Define and initialize the model, loss function, etc.
+# ...
+
+# Dataset iteration usually marks the start of model training
+for data, label in data_loader:
+    debugger.start() # Enable data dump
+
+    # Logic executed by the model at each step
+ output = model(data)
+    debugger.forward_backward_dump_end() # Insert after start(); only the forward/backward data between start() and this call is dumped, data between this call and stop() is not dumped
+ #...
+ loss.backward()
+
+ debugger.stop() # 关闭数据dump
+ debugger.step() # 结束一个step的dump
```
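扩展示例中各代码段是否被dump,可用如下简化桩模型说明(仅示意 forward_backward_dump_end 的分段效果,忽略反向hook等实现细节,并非 msprobe 的真实逻辑):

```python
class StubDebugger:
    """简化桩:用开关标记当前代码段是否处于dump范围内。"""

    def __init__(self):
        self.dumping = False
        self.dumped = []

    def start(self):
        self.dumping = True

    def forward_backward_dump_end(self):
        # start到此处之间的代码被dump,此后到stop之间的代码不被dump
        self.dumping = False

    def stop(self):
        self.dumping = False

    def record(self, op_name):
        # 模拟对某段代码的数据采集
        if self.dumping:
            self.dumped.append(op_name)


debugger = StubDebugger()
debugger.start()
debugger.record("model_forward")      # 位于start之后,会被dump
debugger.forward_backward_dump_end()
debugger.record("extra_postprocess")  # 位于该函数与stop之间,不会被dump
debugger.stop()

print(debugger.dumped)
# 输出: ['model_forward']
```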
## dump结果文件介绍
@@ -193,7 +246,7 @@ pt文件保存的前缀和PyTorch对应关系如下:
## 工具支持的API列表
-atat工具维护固定的API支持列表,若需要删除或增加dump的API,可以在atat/pytorch/hook_module/support_wrap_ops.yaml文件内手动修改,如下示例:
+msprobe工具维护固定的API支持列表,若需要删除或增加dump的API,可以在msprobe/pytorch/hook_module/support_wrap_ops.yaml文件内手动修改,如下示例:
```Python
functional: # functional为算子类别,找到对应的类别,在该类别下按照下列格式删除或添加API
diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/BLOOM-7B_1.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/BLOOM-7B_1.png
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/doc/img/BLOOM-7B_1.png
rename to debug/accuracy_tools/msprobe/pytorch/doc/img/BLOOM-7B_1.png
diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/BLOOM-7B_2.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/BLOOM-7B_2.png
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/doc/img/BLOOM-7B_2.png
rename to debug/accuracy_tools/msprobe/pytorch/doc/img/BLOOM-7B_2.png
diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/BLOOM-7B_3.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/BLOOM-7B_3.png
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/doc/img/BLOOM-7B_3.png
rename to debug/accuracy_tools/msprobe/pytorch/doc/img/BLOOM-7B_3.png
diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/BLOOM-7B_4.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/BLOOM-7B_4.png
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/doc/img/BLOOM-7B_4.png
rename to debug/accuracy_tools/msprobe/pytorch/doc/img/BLOOM-7B_4.png
diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_1.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_1.png
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_1.png
rename to debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_1.png
diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_2.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_2.png
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_2.png
rename to debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_2.png
diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_3.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_3.png
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_3.png
rename to debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_3.png
diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_4.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_4.png
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_4.png
rename to debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_4.png
diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_5.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_5.png
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_5.png
rename to debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_5.png
diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_6.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_6.png
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_6.png
rename to debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_6.png
diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_7.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_7.png
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_7.png
rename to debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_7.png
diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_8.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_8.png
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_8.png
rename to debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_8.png
diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/YOLOV5S_1.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/YOLOV5S_1.png
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/doc/img/YOLOV5S_1.png
rename to debug/accuracy_tools/msprobe/pytorch/doc/img/YOLOV5S_1.png
diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/YOLOV5S_2.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/YOLOV5S_2.png
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/doc/img/YOLOV5S_2.png
rename to debug/accuracy_tools/msprobe/pytorch/doc/img/YOLOV5S_2.png
diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/accuracy_checking_details.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/accuracy_checking_details.png
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/doc/img/accuracy_checking_details.png
rename to debug/accuracy_tools/msprobe/pytorch/doc/img/accuracy_checking_details.png
diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/accuracy_checking_result.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/accuracy_checking_result.png
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/doc/img/accuracy_checking_result.png
rename to debug/accuracy_tools/msprobe/pytorch/doc/img/accuracy_checking_result.png
diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/api_precision_compare_details.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/api_precision_compare_details.png
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/doc/img/api_precision_compare_details.png
rename to debug/accuracy_tools/msprobe/pytorch/doc/img/api_precision_compare_details.png
diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/api_precision_compare_result.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/api_precision_compare_result.png
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/doc/img/api_precision_compare_result.png
rename to debug/accuracy_tools/msprobe/pytorch/doc/img/api_precision_compare_result.png
diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/auto_analyze_log.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/auto_analyze_log.png
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/doc/img/auto_analyze_log.png
rename to debug/accuracy_tools/msprobe/pytorch/doc/img/auto_analyze_log.png
diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/compare_result_pkl.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/compare_result_pkl.png
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/doc/img/compare_result_pkl.png
rename to debug/accuracy_tools/msprobe/pytorch/doc/img/compare_result_pkl.png
diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/compare_result_pkl_md5.png.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/doc/img/compare_result_pkl_md5.png.png
rename to debug/accuracy_tools/msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png
diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/cpu_info.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/cpu_info.png
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/doc/img/cpu_info.png
rename to debug/accuracy_tools/msprobe/pytorch/doc/img/cpu_info.png
diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/module_compare.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/module_compare.png
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/doc/img/module_compare.png
rename to debug/accuracy_tools/msprobe/pytorch/doc/img/module_compare.png
diff --git "a/debug/accuracy_tools/atat/pytorch/doc/atat\347\262\276\345\272\246\345\267\245\345\205\267\346\225\260\346\215\256dump\346\240\207\345\207\206\346\200\247\350\203\275\345\237\272\347\272\277\346\212\245\345\221\212.md" "b/debug/accuracy_tools/msprobe/pytorch/doc/msprobe\347\262\276\345\272\246\345\267\245\345\205\267\346\225\260\346\215\256dump\346\240\207\345\207\206\346\200\247\350\203\275\345\237\272\347\272\277\346\212\245\345\221\212.md"
similarity index 97%
rename from "debug/accuracy_tools/atat/pytorch/doc/atat\347\262\276\345\272\246\345\267\245\345\205\267\346\225\260\346\215\256dump\346\240\207\345\207\206\346\200\247\350\203\275\345\237\272\347\272\277\346\212\245\345\221\212.md"
rename to "debug/accuracy_tools/msprobe/pytorch/doc/msprobe\347\262\276\345\272\246\345\267\245\345\205\267\346\225\260\346\215\256dump\346\240\207\345\207\206\346\200\247\350\203\275\345\237\272\347\272\277\346\212\245\345\221\212.md"
index ed175ff30172a54d8d4868097599ab8518b45e4f..c9db3ae78d7d47330cf6cddcc66c741c77a63514 100644
--- "a/debug/accuracy_tools/atat/pytorch/doc/atat\347\262\276\345\272\246\345\267\245\345\205\267\346\225\260\346\215\256dump\346\240\207\345\207\206\346\200\247\350\203\275\345\237\272\347\272\277\346\212\245\345\221\212.md"
+++ "b/debug/accuracy_tools/msprobe/pytorch/doc/msprobe\347\262\276\345\272\246\345\267\245\345\205\267\346\225\260\346\215\256dump\346\240\207\345\207\206\346\200\247\350\203\275\345\237\272\347\272\277\346\212\245\345\221\212.md"
@@ -1,4 +1,4 @@
-# atat精度工具标准性能基线报告
+# msprobe精度工具标准性能基线报告
## 环境信息
@@ -16,7 +16,7 @@ CANN:8.0.T2
## 模型信息和性能基线
-大模型在使用atat工具dump数据时,建议先简化模型层数,减少dump数据量。
+大模型在使用msprobe工具dump数据时,建议先简化模型层数,减少dump数据量。
以下场景的性能基线测试数据均为多次测试后取平均值,因此实际运行时性能数据可能会根据环境状态稍有浮动。
diff --git a/debug/accuracy_tools/atat/pytorch/doc/parse_tool.md b/debug/accuracy_tools/msprobe/pytorch/doc/parse_tool.md
similarity index 98%
rename from debug/accuracy_tools/atat/pytorch/doc/parse_tool.md
rename to debug/accuracy_tools/msprobe/pytorch/doc/parse_tool.md
index 23000912910e8f95b4cb74c7983961918bd9a513..81efa10fa3ec4307603e24e9599c5a00367462d4 100644
--- a/debug/accuracy_tools/atat/pytorch/doc/parse_tool.md
+++ b/debug/accuracy_tools/msprobe/pytorch/doc/parse_tool.md
@@ -6,10 +6,10 @@
## 进入parse交互式界面
-安装atat工具后(详见《[MindStudio精度调试工具](../../README.md)》的“工具安装”章节),可以通过使用命令 **atat -f pytorch parse** 进入交互式界面,如下所示:
+安装msprobe工具后(详见《[MindStudio精度调试工具](../../README.md)》的“工具安装”章节),可以通过使用命令 **msprobe -f pytorch parse** 进入交互式界面,如下所示:
```bash
-atat -f pytorch parse
+msprobe -f pytorch parse
Parse >>>
```
@@ -23,7 +23,7 @@ Parse >>>
Ctrl+C可以退出parse交互式界面。不退出parse交互式界面若需要执行非该界面下的内置Shell命令,且命令与parse交互式界面命令冲突时,非该界面命令需要使用run命令,在相关命令前加上run前缀,如下示例:
```bash
-atat -f pytorch parse
+msprobe -f pytorch parse
Parse >>> run vim cli.py
Parse >>> vim cli.py
```
diff --git a/debug/accuracy_tools/atat/pytorch/doc/ptdbg_ascend_compare.md b/debug/accuracy_tools/msprobe/pytorch/doc/ptdbg_ascend_compare.md
similarity index 99%
rename from debug/accuracy_tools/atat/pytorch/doc/ptdbg_ascend_compare.md
rename to debug/accuracy_tools/msprobe/pytorch/doc/ptdbg_ascend_compare.md
index e3537594c4f8c9e277ca867172875e3e28c23113..4bd05c73e21c4491ee8286366a6b987a60ee69ae 100644
--- a/debug/accuracy_tools/atat/pytorch/doc/ptdbg_ascend_compare.md
+++ b/debug/accuracy_tools/msprobe/pytorch/doc/ptdbg_ascend_compare.md
@@ -44,7 +44,7 @@ compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs)
创建比对脚本,例如compare_distributed.py,拷贝如下代码,具体参数请根据实际环境修改。
```Python
-from atat.pytorch import *
+from msprobe.pytorch import *
compare_distributed('./npu_dump/step0', './gpu_dump/step0', './output')
```
@@ -77,7 +77,7 @@ compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_mat
单机单卡场景下创建比对脚本,例如compare.py,拷贝如下代码,具体参数请根据实际环境修改。
```Python
-from atat.pytorch import compare
+from msprobe.pytorch import compare
dump_result_param={
"npu_json_path": "./npu_dump/dump.json",
"bench_json_path": "./gpu_dump/dump.json",
@@ -96,7 +96,7 @@ compare(dump_result_param, output_path="./output", stack_mode=True)
以compare.py为例。
```Python
-from atat.pytorch import compare
+from msprobe.pytorch import compare
dump_result_param={
"npu_json_path": "./npu_dump/dump.json",
"bench_json_path": "./gpu_dump/dump.json",
diff --git a/debug/accuracy_tools/atat/pytorch/doc/ptdbg_ascend_overview.md b/debug/accuracy_tools/msprobe/pytorch/doc/ptdbg_ascend_overview.md
similarity index 81%
rename from debug/accuracy_tools/atat/pytorch/doc/ptdbg_ascend_overview.md
rename to debug/accuracy_tools/msprobe/pytorch/doc/ptdbg_ascend_overview.md
index 708d90b3487c47249c5f6a8b0f37671e8918e7e2..019451454877eddd3cf6e59cc1eef1c48fcf2a3c 100644
--- a/debug/accuracy_tools/atat/pytorch/doc/ptdbg_ascend_overview.md
+++ b/debug/accuracy_tools/msprobe/pytorch/doc/ptdbg_ascend_overview.md
@@ -4,7 +4,7 @@
在PyTorch训练网络,对同一模型或API调试过程中,遇到API相关的计算精度问题,定位时费时费力。
-atat的精度比对工具,用来进行PyTorch整网API粒度的数据dump、精度比对和溢出检测,从而定位PyTorch训练场景下的精度问题。
+msprobe的精度比对工具,用来进行PyTorch整网API粒度的数据dump、精度比对和溢出检测,从而定位PyTorch训练场景下的精度问题。
**使用场景**
@@ -42,17 +42,17 @@ atat的精度比对工具,用来进行PyTorch整网API粒度的数据dump、
1. 准备CPU或GPU训练工程。
-2. 在环境下安装atat工具。详见《[MindStudio精度调试工具](../../README.md)》的“工具安装”章节。
+2. 在环境下安装msprobe工具。详见《[MindStudio精度调试工具](../../README.md)》的“工具安装”章节。
-3. 在训练脚本内添加atat工具dump接口PrecisionDebugger采集标杆数据。详见《[精度数据采集](./dump.md)》。
+3. 在训练脚本内添加msprobe工具dump接口PrecisionDebugger采集标杆数据。详见《[精度数据采集](./dump.md)》。
4. 执行训练dump数据。
5. 将CPU或GPU训练工程迁移为NPU训练工程。详见《[PyTorch模型迁移调优指南](https://www.hiascend.com/document/detail/zh/Pytorch/60RC1/ptmoddevg/trainingmigrguide/PT_LMTMOG_0003.html)》。
-6. 在NPU环境下安装atat工具。详见《[MindStudio精度调试工具](../../README.md)》的“工具安装”章节。
+6. 在NPU环境下安装msprobe工具。详见《[MindStudio精度调试工具](../../README.md)》的“工具安装”章节。
-7. 在NPU训练脚本内添加atat工具dump接口PrecisionDebugger采集标杆数据。详见《[精度数据采集](./dump.md)》。
+7. 在NPU训练脚本内添加msprobe工具dump接口PrecisionDebugger采集标杆数据。详见《[精度数据采集](./dump.md)》。
8. NPU环境下执行训练dump数据。
diff --git a/debug/accuracy_tools/atat/pytorch/doc/ptdbg_ascend_quickstart.md b/debug/accuracy_tools/msprobe/pytorch/doc/ptdbg_ascend_quickstart.md
similarity index 94%
rename from debug/accuracy_tools/atat/pytorch/doc/ptdbg_ascend_quickstart.md
rename to debug/accuracy_tools/msprobe/pytorch/doc/ptdbg_ascend_quickstart.md
index c05302055687fdf6071befd7ff8ad77c9e32f2df..4b6ac9de2f075ad17ff8a594bc40df4c36692f5f 100644
--- a/debug/accuracy_tools/atat/pytorch/doc/ptdbg_ascend_quickstart.md
+++ b/debug/accuracy_tools/msprobe/pytorch/doc/ptdbg_ascend_quickstart.md
@@ -1,8 +1,8 @@
# **精度比对工具**
-本文主要介绍atat的精度比对工具的快速入门和场景化示例。
+本文主要介绍msprobe的精度比对工具的快速入门和场景化示例。
-本文介绍的操作需要安装atat工具,详见《[MindStudio精度调试工具](../../README.md)》的“工具安装”章节。
+本文介绍的操作需要安装msprobe工具,详见《[MindStudio精度调试工具](../../README.md)》的“工具安装”章节。
本文介绍的操作主要是精度数据dump和精度比对,详细操作指导可参考《[精度数据采集](./dump.md)》和《[CPU或GPU与NPU精度数据比对](./ptdbg_ascend.md)》。
@@ -51,12 +51,12 @@ PyTorch训练场景的精度问题分析建议参考以下思路进行精度比
}
```
-2. 在训练脚本内添加atat工具,dump整网数据。
+2. 在训练脚本内添加msprobe工具,dump整网数据。
分别dump CPU或GPU以及NPU数据,在PyTorch训练脚本插入dump接口,示例代码如下(下面以NPU为例,CPU或GPU dump基本相同):
```python
- from atat.pytorch import PrecisionDebugger
+ from msprobe.pytorch import PrecisionDebugger
debugger = PrecisionDebugger(config_path="./config.json", dump_path="./npu_dump")
# 请勿将以上初始化流程插入到循环代码中
@@ -82,7 +82,7 @@ PyTorch训练场景的精度问题分析建议参考以下思路进行精度比
创建并配置精度比对脚本,以创建compare.py为例,示例代码如下:
```python
- from atat.pytorch import compare
+ from msprobe.pytorch import compare
dump_result_param={
"npu_json_path": "./npu_dump/dump.json",
"bench_json_path": "./gpu_dump/dump.json",
@@ -140,10 +140,10 @@ python3 compare.py
}
```
-2. 在NPU训练脚本内添加atat工具,执行溢出检测dump。
+2. 在NPU训练脚本内添加msprobe工具,执行溢出检测dump。
```python
- from atat.pytorch import PrecisionDebugger
+ from msprobe.pytorch import PrecisionDebugger
debugger = PrecisionDebugger(config_path="./config.json", dump_path="./npu_dump")
# 请勿将以上初始化流程插入到循环代码中
@@ -171,7 +171,7 @@ python3 compare.py
溢出解析工具执行命令如下:
```bash
- atat -f pytorch run_overflow_check -api_info ./dump.json
+ msprobe -f pytorch run_overflow_check -api_info ./dump.json
```
反向过程溢出的API暂不支持精度预检功能。
@@ -200,7 +200,7 @@ python3 compare.py
1. 创建比对脚本,例如compare_distributed.py,拷贝如下代码。
```python
- from atat.pytorch import *
+ from msprobe.pytorch import *
compare_distributed('./npu_dump/step0', './gpu_dump/step0', './output')
```
@@ -219,7 +219,7 @@ python3 compare.py
多卡一般为多进程,须保证每个进程都正确调用PrecisionDebugger,或把PrecisionDebugger插入到import语句后,如:
```python
-from atat.pytorch import PrecisionDebugger
+from msprobe.pytorch import PrecisionDebugger
debugger = PrecisionDebugger(config_path="./config.json", dump_path="./npu_dump")
```
@@ -339,10 +339,10 @@ debugger = PrecisionDebugger(config_path="./config.json", dump_path="./npu_dump"
}
```
-2. 在训练脚本内添加atat工具,dump整网数据。
+2. 在训练脚本内添加msprobe工具,dump整网数据。
```python
- from atat.pytorch import PrecisionDebugger
+ from msprobe.pytorch import PrecisionDebugger
debugger = PrecisionDebugger(config_path="./config.json", dump_path="./npu_dump")
# 请勿将以上初始化流程插入到循环代码中
diff --git a/debug/accuracy_tools/atat/pytorch/doc/run_overflow_check.md b/debug/accuracy_tools/msprobe/pytorch/doc/run_overflow_check.md
similarity index 95%
rename from debug/accuracy_tools/atat/pytorch/doc/run_overflow_check.md
rename to debug/accuracy_tools/msprobe/pytorch/doc/run_overflow_check.md
index 1bdc4f354cfaf0bfbdf701baa7dfb05f3771e30b..b8c9c3b4c292886e2ef8229ec421a244ae38f92e 100644
--- a/debug/accuracy_tools/atat/pytorch/doc/run_overflow_check.md
+++ b/debug/accuracy_tools/msprobe/pytorch/doc/run_overflow_check.md
@@ -13,7 +13,7 @@
2. 执行溢出API解析操作。
```bash
- atat -f pytorch run_overflow_check -api_info ./dump.json
+ msprobe -f pytorch run_overflow_check -api_info ./dump.json
```
| 参数名称 | 说明 | 是否必选 |
diff --git a/debug/accuracy_tools/msprobe/pytorch/doc/torchair_compare.md b/debug/accuracy_tools/msprobe/pytorch/doc/torchair_compare.md
new file mode 100644
index 0000000000000000000000000000000000000000..a27ba388b030dbd534c6d79b160a8bb3b87bd5a8
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/doc/torchair_compare.md
@@ -0,0 +1,43 @@
+# TorchAir训练场景-整网算子精度比对
+
+TorchAir训练场景数据dump和比对能力继承自torchair推理场景。数据dump和比对基本资料可参考[基于torch图模式(torchair)推理场景](https://gitee.com/ascend/msit/blob/master/msit/docs/llm/TorchAir%E5%9C%BA%E6%99%AF-%E6%95%B4%E7%BD%91%E7%AE%97%E5%AD%90%E7%B2%BE%E5%BA%A6%E6%AF%94%E5%AF%B9.md)。
+
+## 工具安装
+
+1. 克隆mstt仓的poc分支,进入mstt/debug/accuracy_tools目录,执行python setup.py bdist_wheel构建poc分支的msprobe包。构建完成后,accuracy_tools目录下会生成dist文件夹,安装dist下生成的whl包即可。
+2. 命令行中输入add_torchair_compare_path,若提示torchair compare path added successfully,则说明可以使用图模式数据dump和比对能力。
+
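上述构建与安装步骤可概括为如下命令(仓库地址取自本仓URL,whl包名以dist目录下实际生成的文件为准,属示意性假设):

```shell
# 克隆mstt仓的poc分支并进入accuracy_tools目录
git clone -b poc https://gitee.com/ascend/mstt.git
cd mstt/debug/accuracy_tools
# 构建msprobe的whl安装包,产物位于dist目录下
python setup.py bdist_wheel
# 安装生成的whl包(包名以dist下实际生成的文件为准)
pip install dist/*.whl
# 验证图模式数据dump和比对能力是否可用
add_torchair_compare_path
```
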
+## 约束
+
+- 数据dump时,训练场景下torch.compile()不支持传入参数dynamic=True。
+- FX模式dump的训练脚本需要单独放在一个目录下。
+
+## 比对操作
+
+1. 数据dump
+
+ 用法见[基于torch图模式(torchair)推理场景](https://gitee.com/ascend/msit/blob/master/msit/docs/llm/TorchAir%E5%9C%BA%E6%99%AF-%E6%95%B4%E7%BD%91%E7%AE%97%E5%AD%90%E7%B2%BE%E5%BA%A6%E6%AF%94%E5%AF%B9.md)。可参考[TorchAir场景Dump案例](./torchair_dump_sample.md)。
+
+2. 精度比对
+
+ 比对通过命令行子命令torchair_compare执行,命令如下:
+
+ ```bash
+ msprobe -f pytorch torchair_compare --my-path [my dump data] --golden-path [gold dump data] --output-path [output path]
+ ```
+
+3. 查看比对结果,请参考《[PyTorch 场景的精度比对-精度比对结果分析](https://gitee.com/ascend/mstt/blob/e4206fa5268a96f991dc7a3444c6ce16040089ed/debug/accuracy_tools/msprobe/docs/10.accuracy_compare_PyTorch.md#3-精度比对结果分析)》章节。
+
+TorchAir训练场景支持如下两种比对场景:
+
+- GE融合模式(默认)dump数据与FX dump数据精度比对
+
+ ```bash
+ msprobe -f pytorch torchair_compare --my-path {dump_path}/dump_{time_stamp} --golden-path {dump_path}
+ ```
+
+- GE融合模式(默认)dump数据与GE关闭融合模式dump数据精度比对
+
+ ```bash
+ msprobe -f pytorch torchair_compare --my-path {dump_path}/dump_{time_stamp} --golden-path {dump_path}/dump_{time_stamp}
+ ```
diff --git a/debug/accuracy_tools/msprobe/pytorch/doc/torchair_dump_sample.md b/debug/accuracy_tools/msprobe/pytorch/doc/torchair_dump_sample.md
new file mode 100644
index 0000000000000000000000000000000000000000..b809113712e7a03cff7b9d55019b4576960a1f05
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/doc/torchair_dump_sample.md
@@ -0,0 +1,131 @@
+# TorchAir训练场景Dump案例
+
+## GE开启融合(默认)dump
+
+```python
+import os
+import numpy as np
+import torch, torch_npu, torchvision
+import torchair as tng
+from msit_llm.dump import torchair_dump
+import torch.optim as optim
+
+target_dtype = torch.float16
+model = torchvision.models.resnet50(pretrained=True).eval().to(target_dtype).npu()
+if not os.path.exists('aa_224_224.npy'):
+ np.save('aa_224_224.npy', np.random.uniform(size=[1, 3, 224, 224]))
+aa = torch.from_numpy(np.load('aa_224_224.npy')).to(target_dtype).npu()
+
+config = torchair_dump.get_ge_dump_config(dump_path="dump_gefusion_open_train")
+npu_backend = tng.get_npu_backend(compiler_config=config)
+model = torch.compile(model, backend=npu_backend)
+
+criterion = torch.nn.CrossEntropyLoss()
+optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
+
+if not os.path.exists('bb_1.pt'):
+ torch.save(torch.randint(0, 1000, (1,), dtype=torch.long), 'bb_1.pt')
+bb = torch.load('bb_1.pt').npu()
+
+optimizer.zero_grad()
+output = model(aa)
+loss = criterion(output, bb)
+print(loss)
+print("loss finished")
+loss.backward()
+print("backward finished")
+optimizer.step()
+# loss为非叶子张量,loss.grad恒为None,如需查看梯度请检查模型参数的.grad
+print("optimizer step finished")
+```
+
+## FX dump
+
+```python
+import os
+import numpy as np
+import torch, torch_npu, torchvision
+import torchair as tng
+from msit_llm.dump import torchair_dump
+import torch.optim as optim
+
+target_dtype = torch.float16
+model = torchvision.models.resnet50(pretrained=True).eval().to(target_dtype).npu()
+if not os.path.exists('../aa_224_224.npy'):
+ np.save('../aa_224_224.npy', np.random.uniform(size=[1, 3, 224, 224]))
+aa = torch.from_numpy(np.load('../aa_224_224.npy')).to(target_dtype).npu()
+
+config = torchair_dump.get_fx_dump_config()
+npu_backend = tng.get_npu_backend(compiler_config=config)
+model = torch.compile(model, backend=npu_backend)
+
+criterion = torch.nn.CrossEntropyLoss()
+optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
+
+if not os.path.exists('../bb_1.pt'):
+ torch.save(torch.randint(0, 1000, (1,), dtype=torch.long), '../bb_1.pt')
+bb = torch.load('../bb_1.pt').npu()
+
+optimizer.zero_grad()
+output = model(aa)
+loss = criterion(output, bb)
+print(loss)
+print("loss finished")
+loss.backward()
+print("backward finished")
+optimizer.step()
+# loss为非叶子张量,loss.grad恒为None,如需查看梯度请检查模型参数的.grad
+print("optimizer step finished")
+```
+
+## GE关闭融合模式dump
+
+```python
+import os
+import numpy as np
+import torch, torch_npu, torchvision
+import torchair as tng
+from msit_llm.dump import torchair_dump
+import torch.optim as optim
+
+target_dtype = torch.float16
+model = torchvision.models.resnet50(pretrained=True).eval().to(target_dtype).npu()
+if not os.path.exists('aa_224_224.npy'):
+ np.save('aa_224_224.npy', np.random.uniform(size=[1, 3, 224, 224]))
+aa = torch.from_numpy(np.load('aa_224_224.npy')).to(target_dtype).npu()
+
+config = torchair_dump.get_ge_dump_config(dump_path="dump_gefusion_close_train", fusion_switch_file='./fusion_switch.json')
+npu_backend = tng.get_npu_backend(compiler_config=config)
+model = torch.compile(model, backend=npu_backend)
+
+criterion = torch.nn.CrossEntropyLoss()
+optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
+
+if not os.path.exists('bb_1.pt'):
+ torch.save(torch.randint(0, 1000, (1,), dtype=torch.long), 'bb_1.pt')
+bb = torch.load('bb_1.pt').npu()
+
+optimizer.zero_grad()
+output = model(aa)
+loss = criterion(output, bb)
+print(loss)
+print("loss finished")
+loss.backward()
+print("backward finished")
+optimizer.step()
+# loss为非叶子张量,loss.grad恒为None,如需查看梯度请检查模型参数的.grad
+print("optimizer step finished")
+```
+
+新建fusion_switch.json文件,内容如下:
+
+```json
+{
+ "Switch": {
+ "GraphFusion": {
+ "ALL": "off"
+ },
+ "UBFusion": {
+ "ALL": "off"
+ }
+ }
+}
+```
+
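若希望在脚本中生成该配置文件,可用如下方式(与上面JSON内容等价,文件名沿用fusion_switch.json):

```python
import json

# 关闭Graph/UB融合的开关配置,与文档中的fusion_switch.json内容一致
fusion_switch = {
    "Switch": {
        "GraphFusion": {"ALL": "off"},
        "UBFusion": {"ALL": "off"},
    }
}

with open("fusion_switch.json", "w") as f:
    json.dump(fusion_switch, f, indent=4)

# 回读校验写入内容
with open("fusion_switch.json") as f:
    print(json.load(f)["Switch"]["GraphFusion"]["ALL"])
# 输出: off
```
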
diff --git "a/debug/accuracy_tools/atat/pytorch/doc/\345\234\250\347\272\277\347\262\276\345\272\246\346\257\224\345\257\271.md" "b/debug/accuracy_tools/msprobe/pytorch/doc/\345\234\250\347\272\277\347\262\276\345\272\246\346\257\224\345\257\271.md"
similarity index 95%
rename from "debug/accuracy_tools/atat/pytorch/doc/\345\234\250\347\272\277\347\262\276\345\272\246\346\257\224\345\257\271.md"
rename to "debug/accuracy_tools/msprobe/pytorch/doc/\345\234\250\347\272\277\347\262\276\345\272\246\346\257\224\345\257\271.md"
index b2e373feb6cf7cb0c63dcff592939567b52738b4..05bebaf0a22d8d5d7886ec9937b32b4755caf872 100644
--- "a/debug/accuracy_tools/atat/pytorch/doc/\345\234\250\347\272\277\347\262\276\345\272\246\346\257\224\345\257\271.md"
+++ "b/debug/accuracy_tools/msprobe/pytorch/doc/\345\234\250\347\272\277\347\262\276\345\272\246\346\257\224\345\257\271.md"
@@ -32,8 +32,8 @@ PyTorch NPU在线精度比对是ptdbg_ascend工具实现在PyTorch训练过程
1. 在NPU训练脚本中添加在线精度比对接口,示例如下:
```python
- from atat.pytorch.common.utils import seed_all
- from atat.pytorch.online_dispatch import PtdbgDispatch
+ from msprobe.pytorch.common.utils import seed_all
+ from msprobe.pytorch.online_dispatch import PtdbgDispatch
# 在main函数开始前固定随机数
seed_all()
@@ -74,12 +74,12 @@ PyTorch NPU在线精度比对是ptdbg_ascend工具实现在PyTorch训练过程
| process_num | 多进程并发数,默认为0。 | 否 |
| debug | debug信息打印,默认为False。 | 否 |
### dump数据存盘说明
-dump数据存盘目录名格式:`atat_tag_rankid_{timestamp}`。
+dump数据存盘目录名格式:`msprobe_tag_rankid_{timestamp}`。
子目录下包含1个比对结果csv文件以及cpu和npu的dump数据目录,npu目录下包含Aten IR在NPU上的输入输出的dump数据,由于CPU的输入是直接使用NPU的输入执行,因此cpu目录下只包含执行输出的dump数据。
```bash
-atat_rank4_20230911170521
+msprobe_rank4_20230911170521
├── compare_result_rank4_20230911170521.csv
├── cpu
│ ├── native_batch_norm_backward_10_output.0.npy
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/__init__.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/__init__.py
similarity index 43%
rename from debug/accuracy_tools/atat/pytorch/free_benchmark/__init__.py
rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/__init__.py
index b9d41330a87605af37f7b538e8d7bcf56f9725f1..d234898c0df158308b070e4a9147c9dff0b67c8d 100644
--- a/debug/accuracy_tools/atat/pytorch/free_benchmark/__init__.py
+++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/__init__.py
@@ -1,6 +1,6 @@
-from atat.core.common.log import logger
-from atat.core.common.exceptions import FreeBenchmarkException
-from atat.core.common.const import Const
+from msprobe.core.common.log import logger
+from msprobe.core.common.exceptions import FreeBenchmarkException
+from msprobe.core.common.const import Const
from .main import FreeBenchmarkCheck
from .common.params import UnequalRow
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/__init__.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/__init__.py
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/__init__.py
rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/__init__.py
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/constant.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/constant.py
similarity index 89%
rename from debug/accuracy_tools/atat/pytorch/free_benchmark/common/constant.py
rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/constant.py
index 36b7a6491580a0137ff89958b4f6883dccbea4fc..c5e93be138d24af8c18858db483e397527fb4092 100644
--- a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/constant.py
+++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/constant.py
@@ -2,8 +2,8 @@ from typing import Dict
import numpy as np
import torch
-from atat.pytorch.free_benchmark.common.enums import FuzzThreshold
-from atat.pytorch.free_benchmark.common.params import BenchmarkThd
+from msprobe.pytorch.free_benchmark.common.enums import FuzzThreshold
+from msprobe.pytorch.free_benchmark.common.params import BenchmarkThd
class CommonField:
@@ -52,6 +52,7 @@ class ThresholdConfig:
DTYPE_PER_THD = {
torch.float16: 1.002,
+ torch.bfloat16: 1.004,
torch.float32: 1.0002,
}
BENCHMARK_THD_DICT = {
@@ -60,6 +61,8 @@ class ThresholdConfig:
torch.bfloat16: BenchmarkThd(2**-8, 1.0, 2**-8, 1e-4),
}
+ TENSOR_SPLIT_MAX_CHUNK = 128
+
class PreheatConfig:
IF_PREHEAT = "if_preheat"
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/counter.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/counter.py
similarity index 97%
rename from debug/accuracy_tools/atat/pytorch/free_benchmark/common/counter.py
rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/counter.py
index 186b75c71aeaf71fc2adab7ec38c7f00f6b7fdb7..b2f8c81f3a4ea57d712e49b0b58fc77747797323 100644
--- a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/counter.py
+++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/counter.py
@@ -1,5 +1,5 @@
from collections import defaultdict
-from atat.pytorch.free_benchmark.common.constant import ThresholdConfig
+from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
class PreheatCounter:
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/enums.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/enums.py
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/free_benchmark/common/enums.py
rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/enums.py
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/params.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/params.py
similarity index 93%
rename from debug/accuracy_tools/atat/pytorch/free_benchmark/common/params.py
rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/params.py
index 440348d78c28d3f7cc816932ff12e83aa71915bc..bbfc245a635322f0cde6951663d4d76a009ee66a 100644
--- a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/params.py
+++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/params.py
@@ -2,13 +2,13 @@ from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
-from atat.pytorch.free_benchmark import logger
-from atat.pytorch.free_benchmark.common.enums import (
+from msprobe.pytorch.free_benchmark import logger
+from msprobe.pytorch.free_benchmark.common.enums import (
DeviceType,
FuzzLevel,
PerturbationMode,
)
-from atat.pytorch.free_benchmark.common.utils import Tools
+from msprobe.pytorch.free_benchmark.common.utils import Tools
@dataclass
@@ -78,7 +78,7 @@ def data_pre_deal(name, func, args, kwargs):
data_params.valid_input_index = index
if index == -1:
logger.warning_on_rank_0(
- f"[atat] Free benchmark: 无标杆工具不支持当前算子的输入类型 {name}."
+ f"[msprobe] Free benchmark: 无标杆工具不支持当前算子的输入类型 {name}."
)
return data_params
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py
similarity index 92%
rename from debug/accuracy_tools/atat/pytorch/free_benchmark/common/utils.py
rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py
index 24d25967635b3dcfd1da89e1f54d3282fa1181ed..631beeb85cbce50ba126f39be8c04347f189c50b 100644
--- a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/utils.py
+++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py
@@ -1,5 +1,5 @@
import torch
-from atat.pytorch.free_benchmark.common.enums import DeviceType
+from msprobe.pytorch.free_benchmark.common.enums import DeviceType
class Tools:
@@ -96,3 +96,7 @@ class TorchC:
add = torch._C._VariableFunctionsClass.add
bitwise_xor = torch._C._VariableFunctionsClass.bitwise_xor
clone = torch._C._VariableFunctionsClass.clone
+ clamp = torch._C._VariableFunctionsClass.clamp
+ tensor_split = torch._C._VariableFunctionsClass.tensor_split
+ stack = torch._C._VariableFunctionsClass.stack
+ reshape = torch._C._VariableFunctionsClass.reshape
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/grad_saver.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/grad_saver.py
similarity index 89%
rename from debug/accuracy_tools/atat/pytorch/free_benchmark/compare/grad_saver.py
rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/grad_saver.py
index 89ef9e4c9b4500953b0edc26f28f5b14e401ca50..6781a1c2fc4c7fd348d330a49352e2f6195e8a71 100644
--- a/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/grad_saver.py
+++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/grad_saver.py
@@ -1,10 +1,10 @@
import torch
-from atat.core.common.exceptions import FreeBenchmarkException
-from atat.pytorch.free_benchmark import logger
-from atat.pytorch.free_benchmark.common.constant import CommonField
-from atat.pytorch.free_benchmark.common.params import DataParams, HandlerParams
-from atat.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory
-from atat.pytorch.free_benchmark.result_handlers.handler_factory import (
+from msprobe.core.common.exceptions import FreeBenchmarkException
+from msprobe.pytorch.free_benchmark import logger
+from msprobe.pytorch.free_benchmark.common.constant import CommonField
+from msprobe.pytorch.free_benchmark.common.params import DataParams, HandlerParams
+from msprobe.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory
+from msprobe.pytorch.free_benchmark.result_handlers.handler_factory import (
FuzzHandlerFactory,
)
@@ -41,18 +41,18 @@ class GradSaver:
data_processor.update_unequal_rows(handler.get_unequal_rows())
except IndexError:
logger.warning_on_rank_0(
- f"[atat] Free benchmark: grad index out of range. api:{self.handler_params.api_name}."
+ f"[msprobe] Free benchmark: grad index out of range. api:{self.handler_params.api_name}."
f"index:{new_grad_index}, perturbation grad len {len(self.perturbed_grad_input)}"
)
return grad
except FreeBenchmarkException as e:
logger.warning_on_rank_0(
- f"[atat] Free benchmark: grad input check error: {e}"
+ f"[msprobe] Free benchmark: grad input check error: {e}"
)
return grad
except Exception as e:
logger.warning_on_rank_0(
- f"[atat] Free benchmark: grad compare error: {e}"
+ f"[msprobe] Free benchmark: grad compare error: {e}"
)
return grad
return grad
@@ -77,7 +77,7 @@ class GradSaver:
handler.handle(self.data_params)
except Exception as e:
logger.warning_on_rank_0(
- f"[atat] Free benchmark: compare two vjp failed: api:{self.handler_params.api_name}."
+ f"[msprobe] Free benchmark: compare two vjp failed: api:{self.handler_params.api_name}."
f"{e}"
)
# 在扰动前后输出对比后释放输出的引用
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/single_benchmark.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/single_benchmark.py
similarity index 89%
rename from debug/accuracy_tools/atat/pytorch/free_benchmark/compare/single_benchmark.py
rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/single_benchmark.py
index 85aa68f13b969e996407f5f64353de43b916e00f..59239fcd004fb3472d5c8f53305e692fec00adee 100644
--- a/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/single_benchmark.py
+++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/single_benchmark.py
@@ -1,9 +1,9 @@
import math
import torch
-from atat.pytorch.free_benchmark import logger
-from atat.pytorch.free_benchmark.common.constant import ThresholdConfig
-from atat.pytorch.free_benchmark.common.utils import TorchC
+from msprobe.pytorch.free_benchmark import logger
+from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
+from msprobe.pytorch.free_benchmark.common.utils import TorchC
class SingleCompare:
@@ -28,6 +28,14 @@ class SingleCompare:
tensor[inf_or_nan_mask] = 1
return tensor
+ @staticmethod
+ def compare_float_seq(actual, golden):
+ return math.isclose(actual, golden)
+
+ @staticmethod
+ def compare_other_seq(actual, golden):
+ return actual == golden
+
def compare_dict_seq(self, actual, golden):
if len(actual) != len(golden):
return False
@@ -61,7 +69,7 @@ class SingleCompare:
actual.dtype, ThresholdConfig.BENCHMARK_THD_DICT.get(torch.float32)
)
if self.filter_overflow(golden) > 0:
- logger.warning_on_rank_0("[atat] Free Benchmark: inf and nan"
+ logger.warning_on_rank_0("[msprobe] Free Benchmark: inf and nan"
"in golden tensor is not supported.")
return True
actual = self.replace_inf_or_nan(actual)
@@ -76,12 +84,6 @@ class SingleCompare:
return False
return True
- def compare_float_seq(self, actual, golden):
- return math.isclose(actual, golden)
-
- def compare_other_seq(self, actual, golden):
- return actual == golden
-
def _cal_compare_metrics(self, actual, golden):
diff_value = TorchC.subtract(actual, golden)
diff_abs = TorchC.abs(diff_value)
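Reviewer note: this hunk only moves `compare_float_seq` and `compare_other_seq` up in the class and marks them `@staticmethod` (they use no instance state). The behavioral point worth keeping in mind is why floats get `math.isclose` rather than `==`:

```python
import math

# Same helpers as in SingleCompare, reproduced standalone:
# floats are compared with a relative tolerance, everything else exactly.
def compare_float_seq(actual, golden):
    return math.isclose(actual, golden)

def compare_other_seq(actual, golden):
    return actual == golden

assert compare_float_seq(0.1 + 0.2, 0.3)   # passes despite rounding error
assert not (0.1 + 0.2 == 0.3)              # naive equality would fail here
assert compare_other_seq("npu", "npu")
```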
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/main.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/main.py
similarity index 81%
rename from debug/accuracy_tools/atat/pytorch/free_benchmark/main.py
rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/main.py
index 2ebc0a6db917334f492019fa694330d86a0f37e1..971776d1326409c8878849e7b09a4614ffbc16f5 100644
--- a/debug/accuracy_tools/atat/pytorch/free_benchmark/main.py
+++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/main.py
@@ -1,19 +1,19 @@
from abc import ABC
import torch
-from atat.core.common.const import Const
-from atat.pytorch.free_benchmark import logger
-from atat.pytorch.free_benchmark.common.constant import CommonField
-from atat.pytorch.free_benchmark.common.enums import (
+from msprobe.core.common.const import Const
+from msprobe.pytorch.free_benchmark import logger
+from msprobe.pytorch.free_benchmark.common.constant import CommonField
+from msprobe.pytorch.free_benchmark.common.enums import (
DeviceType,
FuzzLevel,
HandlerType,
PerturbationMode,
)
-from atat.pytorch.free_benchmark.common.params import data_pre_deal, make_handler_params
-from atat.pytorch.free_benchmark.compare.grad_saver import GradSaver
-from atat.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory
-from atat.pytorch.free_benchmark.result_handlers.handler_factory import (
+from msprobe.pytorch.free_benchmark.common.params import data_pre_deal, make_handler_params
+from msprobe.pytorch.free_benchmark.compare.grad_saver import GradSaver
+from msprobe.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory
+from msprobe.pytorch.free_benchmark.result_handlers.handler_factory import (
FuzzHandlerFactory,
)
@@ -81,7 +81,7 @@ class FreeBenchmarkCheck(ABC):
grad_saver = getattr(module, CommonField.GRADSAVER)
except AttributeError:
logger.warning_on_rank_0(
- f"[atat] Free benchmark: get grad saver failed. api_name:{name}"
+ f"[msprobe] Free benchmark: get grad saver failed. api_name:{name}"
)
return
@@ -97,6 +97,6 @@ class FreeBenchmarkCheck(ABC):
)
except Exception as e:
logger.warning_on_rank_0(
- f"[atat] Free benchmark: grad vjp calculate failed. api_name:{name} error: {e}"
+ f"[msprobe] Free benchmark: grad vjp calculate failed. api_name:{name} error: {e}"
)
return
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/__init__.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/__init__.py
rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/base_layer.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py
similarity index 78%
rename from debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/base_layer.py
rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py
index aa572fd8e8dc8b62493dfa1fecc587b934c83a99..f64a201d5efa007ff4ed848eafd2ab6db535a2f5 100644
--- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/base_layer.py
+++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any
-from atat.pytorch.free_benchmark.common.params import DataParams
+from msprobe.pytorch.free_benchmark.common.params import DataParams
class BaseLayer(ABC):
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/layer_factory.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py
similarity index 62%
rename from debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/layer_factory.py
rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py
index 0d09438ce04132c9c5c301d758dc06818805082e..0ea9107aa84c2633435fe616891f5386b17de423 100644
--- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/layer_factory.py
+++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py
@@ -1,15 +1,15 @@
-from atat.pytorch.free_benchmark import FreeBenchmarkException
-from atat.pytorch.free_benchmark.common.enums import DeviceType, PerturbationMode
-from atat.pytorch.free_benchmark.perturbed_layers.npu.improve_precision import (
+from msprobe.pytorch.free_benchmark import FreeBenchmarkException
+from msprobe.pytorch.free_benchmark.common.enums import DeviceType, PerturbationMode
+from msprobe.pytorch.free_benchmark.perturbed_layers.npu.improve_precision import (
ImprovePrecisionLayer,
)
-from atat.pytorch.free_benchmark.perturbed_layers.npu.add_noise import AddNoiseLayer
-from atat.pytorch.free_benchmark.perturbed_layers.npu.bit_noise import BitNoiseLayer
-from atat.pytorch.free_benchmark.perturbed_layers.npu.no_change import NoChangeLayer
-from atat.pytorch.free_benchmark.perturbed_layers.npu.change_value import (
+from msprobe.pytorch.free_benchmark.perturbed_layers.npu.add_noise import AddNoiseLayer
+from msprobe.pytorch.free_benchmark.perturbed_layers.npu.bit_noise import BitNoiseLayer
+from msprobe.pytorch.free_benchmark.perturbed_layers.npu.no_change import NoChangeLayer
+from msprobe.pytorch.free_benchmark.perturbed_layers.npu.change_value import (
ChangeValueLayer,
)
-from atat.pytorch.free_benchmark.perturbed_layers.run_cpu import CpuLayer
+from msprobe.pytorch.free_benchmark.perturbed_layers.run_cpu import CpuLayer
class LayerFactory:
diff --git a/debug/accuracy_tools/atat/pytorch/functional/__init__.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/functional/__init__.py
rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py
similarity index 78%
rename from debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py
rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py
index af8a93f7d4b9b06623b70c22e7fb5065305e84a0..a18ef1c51bd342c9b3ab5ffecf14c307e9be5527 100644
--- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py
+++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py
@@ -1,10 +1,10 @@
import torch
-from atat.pytorch.free_benchmark import logger
-from atat.pytorch.free_benchmark.common.constant import ThresholdConfig
-from atat.pytorch.free_benchmark.common.enums import PerturbationMode
-from atat.pytorch.free_benchmark.common.params import DataParams
-from atat.pytorch.free_benchmark.common.utils import TorchC
-from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
+from msprobe.pytorch.free_benchmark import logger
+from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
+from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
+from msprobe.pytorch.free_benchmark.common.params import DataParams
+from msprobe.pytorch.free_benchmark.common.utils import TorchC
+from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
NpuBaseLayer,
)
@@ -37,7 +37,7 @@ class AddNoiseLayer(NpuBaseLayer):
对输入添加扰动并返回
"""
logger.info_on_rank_0(
- f"[atat] Free benchmark: Perturbation is "
+ f"[msprobe] Free benchmark: Perturbation is "
f"{PerturbationMode.ADD_NOISE} of {self.api_name}."
)
params.perturbed_value = self.add_noise(params.args[params.valid_input_index])
@@ -60,13 +60,13 @@ class AddNoiseLayer(NpuBaseLayer):
"""
if not self.perturbed_value:
logger.warning_on_rank_0(
- f"[atat] Free Benchmark: For {self.api_name}, "
+ f"[msprobe] Free Benchmark: For {self.api_name}, "
f"dtype unsupported. Cancel perturbation."
)
return False
if tensor_obj.numel() == 0:
logger.warning_on_rank_0(
- f"[atat] Free benchmark: For {self.api_name}, tensor shape must > 0."
+ f"[msprobe] Free benchmark: For {self.api_name}, tensor shape must > 0."
f" Cancel adding noise."
)
return False
@@ -77,13 +77,13 @@ class AddNoiseLayer(NpuBaseLayer):
max_val = TorchC.max(TorchC.abs(tensor_obj)).item()
except Exception:
logger.warning_on_rank_0(
- f"[atat] Free Benchmark: For {self.api_name}, "
+ f"[msprobe] Free Benchmark: For {self.api_name}, "
f"when calculate maximun value, tensor is changed to float32."
)
max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()
if max_val < abs_tol:
logger.warning_on_rank_0(
- f"[atat] Free Benchmark: For {self.api_name}, "
+ f"[msprobe] Free Benchmark: For {self.api_name}, "
f"Maximun value is less than the minimun threshold. Cancel add noise."
)
return False
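Reviewer note: the `AddNoiseLayer` guards above cancel the perturbation when the tensor is empty or its maximum magnitude is below the dtype's absolute tolerance (noise of that scale would dominate the signal). A stdlib sketch of that guard logic (illustrative; the real code operates on torch tensors via `TorchC`):

```python
# Assumption: values is a flat list standing in for a tensor's elements,
# abs_tol for ThresholdConfig.ABS_TOL_VALUE_DICT[dtype].
def should_add_noise(values, abs_tol):
    if len(values) == 0:
        return False  # "tensor shape must > 0": nothing to perturb
    if max(abs(v) for v in values) < abs_tol:
        return False  # max magnitude below threshold: cancel add noise
    return True

assert not should_add_noise([], 1e-8)
assert not should_add_noise([1e-10], 1e-8)
assert should_add_noise([0.5, -2.0], 1e-8)
```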
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py
similarity index 80%
rename from debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py
rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py
index 40b99acf41105fa61792ef52e27cc7f2e6686ba7..45dea7b93a5c7628b24bf0470af10af355a7742f 100644
--- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py
+++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py
@@ -1,10 +1,10 @@
import torch
-from atat.pytorch.free_benchmark import logger
-from atat.pytorch.free_benchmark.common.constant import ThresholdConfig
-from atat.pytorch.free_benchmark.common.enums import PerturbationMode
-from atat.pytorch.free_benchmark.common.params import DataParams
-from atat.pytorch.free_benchmark.common.utils import TorchC
-from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
+from msprobe.pytorch.free_benchmark import logger
+from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
+from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
+from msprobe.pytorch.free_benchmark.common.params import DataParams
+from msprobe.pytorch.free_benchmark.common.utils import TorchC
+from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
NpuBaseLayer,
)
@@ -53,7 +53,7 @@ class BitNoiseLayer(NpuBaseLayer):
对输入添加扰动并返回
"""
logger.info_on_rank_0(
- f"[atat] Free benchmark: Perturbation is "
+ f"[msprobe] Free benchmark: Perturbation is "
f"{PerturbationMode.BIT_NOISE} of {self.api_name}."
)
params.perturbed_value = self.add_bit_noise(params.args[params.valid_input_index])
@@ -65,13 +65,13 @@ class BitNoiseLayer(NpuBaseLayer):
"""
if not self.bit_type:
logger.info_on_rank_0(
- f"[atat] Free Benchmark: For {self.api_name}, "
+ f"[msprobe] Free Benchmark: For {self.api_name}, "
f"dtype unsupported. Cancel perturbation."
)
return False
if tensor_obj.numel() == 0:
logger.warning_on_rank_0(
- f"[atat] Free benchmark: For {self.api_name}, tensor shape must > 0"
+ f"[msprobe] Free benchmark: For {self.api_name}, tensor shape must > 0"
f" Cancel adding noise."
)
return False
@@ -82,13 +82,13 @@ class BitNoiseLayer(NpuBaseLayer):
max_val = TorchC.max(TorchC.abs(tensor_obj)).item()
except Exception:
logger.warning_on_rank_0(
- f"[atat] Free Benchmark: For {self.api_name}, "
+ f"[msprobe] Free Benchmark: For {self.api_name}, "
f"when calculate maximun value, tensor is changed to float32."
)
max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()
if max_val < abs_tol:
logger.info_on_rank_0(
- f"[atat] Free Benchmark: For {self.api_name}, "
+ f"[msprobe] Free Benchmark: For {self.api_name}, "
f"Maximun value is less than the minimun threshold. Cancel add noise."
)
return False
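Reviewer note: `BitNoiseLayer` above (renamed only in this hunk) injects the smallest representable perturbation by flipping a low-order mantissa bit. A stdlib sketch of that idea on a single float32 value, via its raw bit pattern (illustrative; the real layer works on whole tensors with an integer view):

```python
import struct

def flip_low_bit(x: float) -> float:
    # Reinterpret the float32 bit pattern as uint32, flip the lowest
    # mantissa bit, and reinterpret back: a one-ULP perturbation.
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return struct.unpack("<f", struct.pack("<I", bits ^ 1))[0]

y = flip_low_bit(1.0)
assert y != 1.0 and abs(y - 1.0) < 1e-6  # changed, but only by ~1 ULP
```

Running the workload under such a perturbation and comparing outputs exposes operators that are unstable to last-bit input noise.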
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/change_value.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py
similarity index 81%
rename from debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/change_value.py
rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py
index b7a967e18b91ecc2d36c22afce49f72677bef565..91085d57a68b4841b2e04453c05c41a2903477c3 100644
--- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/change_value.py
+++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py
@@ -1,9 +1,9 @@
import torch
-from atat.pytorch.free_benchmark import logger
-from atat.pytorch.free_benchmark.common.enums import PerturbationMode
-from atat.pytorch.free_benchmark.common.params import DataParams
-from atat.pytorch.free_benchmark.common.utils import TorchC
-from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
+from msprobe.pytorch.free_benchmark import logger
+from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
+from msprobe.pytorch.free_benchmark.common.params import DataParams
+from msprobe.pytorch.free_benchmark.common.utils import TorchC
+from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
NpuBaseLayer,
)
@@ -44,7 +44,7 @@ class ChangeValueLayer(NpuBaseLayer):
对输入添加扰动并返回
"""
logger.info_on_rank_0(
- f"[atat] Free benchmark: Perturbation is "
+ f"[msprobe] Free benchmark: Perturbation is "
f"{PerturbationMode.CHANGE_VALUE} of {self.api_name}."
)
params.perturbed_value = self.change_value(params.args[params.valid_input_index])
@@ -56,7 +56,7 @@ class ChangeValueLayer(NpuBaseLayer):
"""
if tensor_obj.size(0) < 2:
logger.info_on_rank_0(
- f"[atat] Free Benchmark: For {self.api_name}, "
+ f"[msprobe] Free Benchmark: For {self.api_name}, "
f"size 0 must greater than 1. Cancel change value."
)
return False
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py
similarity index 83%
rename from debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py
rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py
index 03718e3c4d6c4c4ea28ae6eec5daddc02bcedb7d..ad6d8b8989d6983f81a9a2d58798a26d4ccc45c1 100644
--- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py
+++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py
@@ -1,10 +1,10 @@
import torch
-from atat.core.common.const import Const
-from atat.pytorch.free_benchmark import logger
-from atat.pytorch.free_benchmark.common.constant import CommonField
-from atat.pytorch.free_benchmark.common.enums import PerturbationMode
-from atat.pytorch.free_benchmark.common.params import DataParams
-from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
+from msprobe.core.common.const import Const
+from msprobe.pytorch.free_benchmark import logger
+from msprobe.pytorch.free_benchmark.common.constant import CommonField
+from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
+from msprobe.pytorch.free_benchmark.common.params import DataParams
+from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
NpuBaseLayer,
)
@@ -34,7 +34,7 @@ class ImprovePrecisionLayer(NpuBaseLayer):
def handle(self, params: DataParams) -> torch.Any:
logger.info_on_rank_0(
- f"[atat] Free benchmark: Perturbation is "
+ f"[msprobe] Free benchmark: Perturbation is "
f"{PerturbationMode.IMPROVE_PRECISION} of {self.api_name}."
)
new_args = self.improve_tensor_precision(params.args)
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/no_change.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py
similarity index 64%
rename from debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/no_change.py
rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py
index bb065385c690f937c702cadac5707b787489aee5..a69c56002a205a518a6929835591859f63b800ff 100644
--- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/no_change.py
+++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py
@@ -1,8 +1,8 @@
import torch
-from atat.pytorch.free_benchmark import logger
-from atat.pytorch.free_benchmark.common.enums import PerturbationMode
-from atat.pytorch.free_benchmark.common.params import DataParams
-from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
+from msprobe.pytorch.free_benchmark import logger
+from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
+from msprobe.pytorch.free_benchmark.common.params import DataParams
+from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
NpuBaseLayer,
)
@@ -21,7 +21,7 @@ class NoChangeLayer(NpuBaseLayer):
对输入添加扰动并返回
"""
logger.info_on_rank_0(
- f"[atat] Free benchmark: Perturbation is "
+ f"[msprobe] Free benchmark: Perturbation is "
f"{PerturbationMode.NO_CHANGE} of {self.api_name}."
)
params.perturbed_value = self.no_change(params.args[params.valid_input_index])
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py
similarity index 90%
rename from debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py
rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py
index 3784af0953022f1eb981ca26cf88765044f56f3f..1a859481475bc9963a5d3b96389061a257a1a759 100644
--- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py
+++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py
@@ -2,8 +2,8 @@ from abc import abstractmethod
from typing import Any
import torch
-from atat.pytorch.free_benchmark.common.params import DataParams
-from atat.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer
+from msprobe.pytorch.free_benchmark.common.params import DataParams
+from msprobe.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer
class NpuBaseLayer(BaseLayer):
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/run_cpu.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py
similarity index 52%
rename from debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/run_cpu.py
rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py
index 024958ffbe126b89ec15fa10b277d90af4ed3e45..d34ac976537d794a05255a32de8d54de2dbac5d3 100644
--- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/run_cpu.py
+++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py
@@ -1,9 +1,9 @@
import torch
-from atat.pytorch.free_benchmark import logger
-from atat.pytorch.free_benchmark.common.params import DataParams
-from atat.pytorch.free_benchmark.common.utils import Tools
-from atat.pytorch.free_benchmark.common.enums import DeviceType
-from atat.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer
+from msprobe.pytorch.free_benchmark import logger
+from msprobe.pytorch.free_benchmark.common.params import DataParams
+from msprobe.pytorch.free_benchmark.common.utils import Tools
+from msprobe.pytorch.free_benchmark.common.enums import DeviceType
+from msprobe.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer
class CpuLayer(BaseLayer):
@@ -11,7 +11,7 @@ class CpuLayer(BaseLayer):
def handle(self, params: DataParams) -> torch.Any:
logger.info_on_rank_0(
- f"[atat] Free benchmark: Perturbation is to_cpu of {self.api_name}."
+ f"[msprobe] Free benchmark: Perturbation is to_cpu of {self.api_name}."
)
new_args = Tools.convert_device_and_dtype(params.args, DeviceType.CPU, change_dtype=True)
new_kwargs = Tools.convert_device_and_dtype(params.kwargs, DeviceType.CPU, change_dtype=True)
diff --git a/debug/accuracy_tools/atat/pytorch/parse_tool/__init__.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/__init__.py
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/parse_tool/__init__.py
rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/__init__.py
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/base_handler.py
similarity index 64%
rename from debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py
rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/base_handler.py
index c57d7e390a0fad73a5e87d72ec72e434835618cf..e36f586735538ed2d51b3afaaec390724d9a4b02 100644
--- a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py
+++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/base_handler.py
@@ -1,22 +1,23 @@
import math
from abc import ABC, abstractmethod
from typing import Any, Optional, Tuple
+import numpy as np
import torch
-from atat.core.common.const import Const
-from atat.pytorch.free_benchmark import logger
-from atat.pytorch.free_benchmark.common.constant import ThresholdConfig
-from atat.pytorch.free_benchmark.common.enums import (
+from msprobe.core.common.const import Const
+from msprobe.pytorch.free_benchmark import logger
+from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
+from msprobe.pytorch.free_benchmark.common.enums import (
FuzzThreshold,
NormType,
PerturbationMode,
)
-from atat.pytorch.free_benchmark.common.params import (
+from msprobe.pytorch.free_benchmark.common.params import (
DataParams,
HandlerParams,
make_unequal_row,
)
-from atat.pytorch.free_benchmark.common.utils import Tools, TorchC
+from msprobe.pytorch.free_benchmark.common.utils import Tools, TorchC
class FuzzHandler(ABC):
@@ -34,15 +35,36 @@ class FuzzHandler(ABC):
origin_ouput = origin_ouput.values
perturbed_output = perturbed_output.values
if hasattr(perturbed_output, "dtype"):
- abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(perturbed_output.dtype)
+ abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(perturbed_output.dtype, FuzzThreshold.F32_THD)
else:
- abs_tol = FuzzThreshold.F32_THD.value
+ abs_tol = FuzzThreshold.F32_THD
return (
origin_ouput.to(perturbed_output.dtype).to(perturbed_output.device),
perturbed_output,
abs_tol,
)
+ @staticmethod
+ def tensor_split_for_error_calculate(origin_output, perturbed_output):
+ """
+ 对将投入误差值计算的扰动前后输出张量进行分块
+ :param origin_output: 原始输出
+ :param perturbed_output: 扰动后输出
+ :return origin_output_chunks: 切块后原始输出列表
+ :return perturbed_output_chunks: 切块后扰动后输出列表
+ """
+ single_output_mem = origin_output.element_size() * origin_output.nelement() / Const.ONE_MB
+ if single_output_mem == 0 or origin_output.ndim == 0:
+ return [origin_output], [perturbed_output]
+ # 张量大小和批数之间的关系:chunks_exp=math.log(M,2)-4, chunks=2**chunks_exp (M为对比张量数据大小[Mb])
+ chunks_exp = int(math.log(single_output_mem, 2)) - 4
+ chunks = 2 ** chunks_exp
+ chunks = max(chunks, 1)
+ chunks = min(chunks, ThresholdConfig.TENSOR_SPLIT_MAX_CHUNK)
+ origin_output_chunks = TorchC.tensor_split(TorchC.reshape(origin_output, (-1,)), chunks)
+ perturbed_output_chunks = TorchC.tensor_split(TorchC.reshape(perturbed_output, (-1,)), chunks)
+ return origin_output_chunks, perturbed_output_chunks
+
@staticmethod
def convert_overflow_ratio_to_consistent(ratio):
if math.isnan(ratio) or math.isinf(ratio):
@@ -61,36 +83,28 @@ class FuzzHandler(ABC):
self, origin_output, perturbed_output, norm_type, abs_tol
):
if norm_type == NormType.ENDLESS_NORM:
- return self.get_endless_norm(origin_output, perturbed_output, abs_tol)
+ return self.calculate_error(origin_output, perturbed_output, abs_tol)
return ThresholdConfig.COMP_CONSISTENT
- def get_endless_norm(self, origin_output, perturbed_output, abs_tol):
- ratio_tensor1 = TorchC.where(
- TorchC.gt(TorchC.abs(perturbed_output), abs_tol),
- TorchC.div(
- TorchC.abs(origin_output),
- TorchC.add(TorchC.abs(perturbed_output), abs_tol),
- ),
- 1,
- )
- ratio_tensor2 = TorchC.where(
- TorchC.gt(TorchC.abs(origin_output), abs_tol),
- TorchC.div(
- TorchC.abs(perturbed_output),
- TorchC.add(TorchC.abs(origin_output), abs_tol),
- ),
- 1,
- )
+ def calculate_error(self, origin_output, perturbed_output, abs_tol):
+ origin_output_chunks, perturbed_output_chunks = self.tensor_split_for_error_calculate(origin_output, perturbed_output)
+ norm1 = -np.inf
+ norm2 = -np.inf
+ norm3 = np.inf
+ for i, chunk_origin in enumerate(origin_output_chunks):
+ if chunk_origin.nelement() == 0:
+ break
+ chunk_perturbed = perturbed_output_chunks[i]
+ ratio_tensor1 = TorchC.where(TorchC.abs(chunk_perturbed) > abs_tol,
+ TorchC.div(TorchC.clamp(chunk_origin, min=abs_tol), TorchC.clamp(chunk_perturbed, min=abs_tol)), 1)
+ ratio_tensor2 = TorchC.where(TorchC.abs(chunk_origin) > abs_tol,
+ TorchC.div(TorchC.clamp(chunk_perturbed, min=abs_tol), TorchC.clamp(chunk_origin, min=abs_tol)), 1)
+ norm_values = TorchC.stack([TorchC.max(ratio_tensor1), TorchC.max(ratio_tensor2)])
+ max_ratio1, max_ratio2 = norm_values.tolist()
+ norm1 = max(norm1, self.convert_overflow_ratio_to_consistent(max_ratio1))
+ norm2 = max(norm2, self.convert_overflow_ratio_to_consistent(max_ratio2))
+ norm3 = min(norm3, self.convert_overflow_ratio_to_consistent(max_ratio1))
- norm1 = self.convert_overflow_ratio_to_consistent(
- TorchC.max(ratio_tensor1).item()
- )
- norm2 = self.convert_overflow_ratio_to_consistent(
- TorchC.max(ratio_tensor2).item()
- )
- norm3 = self.convert_overflow_ratio_to_consistent(
- TorchC.min(ratio_tensor1).item()
- )
if norm3 < 0:
ratio = ThresholdConfig.SYMBOL_FLIPPING
else:
@@ -104,7 +118,7 @@ class FuzzHandler(ABC):
)
except Exception as e:
logger.warning_on_rank_0(
- f"[atat] Free Benchmark: For {self.params.api_name}, "
+ f"[msprobe] Free Benchmark: For {self.params.api_name}, "
f"when computing ratio,"
f" y1 or y2 dtype is not supported {e}"
)
@@ -133,7 +147,7 @@ class FuzzHandler(ABC):
)
elif not isinstance(perturbed_output, torch.Tensor):
logger.warning_on_rank_0(
- f"[atat] Free Benchmark: For {self.params.api_name} "
+ f"[msprobe] Free Benchmark: For {self.params.api_name} "
f"The compare for output type {type(perturbed_output)} is not supported"
)
@@ -185,7 +199,7 @@ class FuzzHandler(ABC):
)
except Exception as e:
logger.warning_on_rank_0(
- f"[atat] Free Benchmark: For {self.params.api_name}, "
+ f"[msprobe] Free Benchmark: For {self.params.api_name}, "
f"when campare the result exception raise {e}"
)
return npu_consistent, max_fuzz_ratio
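Reviewer note: the new `tensor_split_for_error_calculate` bounds peak memory during the ratio computation by splitting both outputs into chunks, per the in-code comment `chunks_exp = math.log(M, 2) - 4`, `chunks = 2 ** chunks_exp` with M the tensor size in MB, clamped to `[1, ThresholdConfig.TENSOR_SPLIT_MAX_CHUNK]`. A numeric sketch of just the chunk-count rule (the cap value 128 is an assumption for illustration; the real cap lives in `ThresholdConfig`):

```python
import math

TENSOR_SPLIT_MAX_CHUNK = 128  # assumed cap, for illustration

def num_chunks(size_mb: float) -> int:
    # Mirrors the guard: zero-sized (or 0-dim) outputs are not split.
    if size_mb == 0:
        return 1
    chunks = 2 ** (int(math.log(size_mb, 2)) - 4)
    return max(1, min(chunks, TENSOR_SPLIT_MAX_CHUNK))

assert num_chunks(0) == 1      # degenerate output: single chunk
assert num_chunks(1) == 1      # small tensors stay whole
assert num_chunks(64) == 4     # 2 ** (6 - 4)
assert num_chunks(10**6) == 128  # very large tensors hit the cap
```

Each chunk pair is then flattened with `TorchC.reshape(..., (-1,))` and compared independently in `calculate_error`, with the running `max`/`min` of the per-chunk ratios reproducing the old whole-tensor norms.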
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/check_handler.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/check_handler.py
similarity index 68%
rename from debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/check_handler.py
rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/check_handler.py
index ed846803a180e3cca7b0a05f01f592dff07d5a44..c16284eb07beda10a38755dc54349c8835ada37a 100644
--- a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/check_handler.py
+++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/check_handler.py
@@ -1,11 +1,11 @@
from typing import Any
-from atat.pytorch.free_benchmark import logger
-from atat.pytorch.free_benchmark.common.enums import DeviceType
-from atat.pytorch.free_benchmark.common.params import DataParams, make_unequal_row
-from atat.pytorch.free_benchmark.common.utils import Tools
-from atat.pytorch.free_benchmark.compare.single_benchmark import SingleCompare
-from atat.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
+from msprobe.pytorch.free_benchmark import logger
+from msprobe.pytorch.free_benchmark.common.enums import DeviceType
+from msprobe.pytorch.free_benchmark.common.params import DataParams, make_unequal_row
+from msprobe.pytorch.free_benchmark.common.utils import Tools
+from msprobe.pytorch.free_benchmark.compare.single_benchmark import SingleCompare
+from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
class CheckerHandler(FuzzHandler):
@@ -33,7 +33,7 @@ class CheckerHandler(FuzzHandler):
self.other_compare(data_params)
except Exception as e:
logger.warning_on_rank_0(
- f"[atat] Free Benchmark: For {self.params.api_name}, "
+ f"[msprobe] Free Benchmark: For {self.params.api_name}, "
f"when campare the result exception raise {e}"
)
return data_params.original_result
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/fix_handler.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py
similarity index 60%
rename from debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/fix_handler.py
rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py
index fa5c6f37495323693175117040c9c2f7fa3c01c6..a1d90035e847abb26c0838635666ff0425853513 100644
--- a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/fix_handler.py
+++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py
@@ -1,9 +1,9 @@
from typing import Any
-from atat.pytorch.free_benchmark.common.params import DataParams
-from atat.pytorch.free_benchmark.common.utils import Tools
-from atat.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
-from atat.pytorch.free_benchmark import logger
+from msprobe.pytorch.free_benchmark.common.params import DataParams
+from msprobe.pytorch.free_benchmark.common.utils import Tools
+from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
+from msprobe.pytorch.free_benchmark import logger
class FixHandler(FuzzHandler):
@@ -18,7 +18,7 @@ class FixHandler(FuzzHandler):
)
except Exception as e:
logger.warning_on_rank_0(
- f"[atat] Free Benchmark: For {self.params.api_name} "
+ f"[msprobe] Free Benchmark: For {self.params.api_name} "
f"Fix output failed. "
)
return data_params.original_result
\ No newline at end of file
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/handler_factory.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py
similarity index 59%
rename from debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/handler_factory.py
rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py
index cff629854d9b2bd1413b273171ad4ce73493bbb0..5ee968c6a86728786f526660594fbb6de4ce18ee 100644
--- a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/handler_factory.py
+++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py
@@ -1,10 +1,10 @@
-from atat.pytorch.free_benchmark import FreeBenchmarkException
-from atat.pytorch.free_benchmark.common.constant import PreheatConfig
-from atat.pytorch.free_benchmark.common.enums import HandlerType
-from atat.pytorch.free_benchmark.common.params import HandlerParams
-from atat.pytorch.free_benchmark.result_handlers.check_handler import CheckerHandler
-from atat.pytorch.free_benchmark.result_handlers.preheat_handler import PreheatHandler
-from atat.pytorch.free_benchmark.result_handlers.fix_handler import FixHandler
+from msprobe.pytorch.free_benchmark import FreeBenchmarkException
+from msprobe.pytorch.free_benchmark.common.constant import PreheatConfig
+from msprobe.pytorch.free_benchmark.common.enums import HandlerType
+from msprobe.pytorch.free_benchmark.common.params import HandlerParams
+from msprobe.pytorch.free_benchmark.result_handlers.check_handler import CheckerHandler
+from msprobe.pytorch.free_benchmark.result_handlers.preheat_handler import PreheatHandler
+from msprobe.pytorch.free_benchmark.result_handlers.fix_handler import FixHandler
class FuzzHandlerFactory:
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/preheat_handler.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py
similarity index 88%
rename from debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/preheat_handler.py
rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py
index 033a6d4931f718fdd2b60b3ad502766ea3f06382..d78e4303620f3ca73522cc9188452ffce0de2b12 100644
--- a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/preheat_handler.py
+++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py
@@ -1,14 +1,14 @@
import math
from typing import Any
-from atat.pytorch.free_benchmark import logger
-from atat.pytorch.free_benchmark.common.constant import ThresholdConfig
-from atat.pytorch.free_benchmark.common.counter import preheat_counter
-from atat.pytorch.free_benchmark.common.enums import DeviceType
-from atat.pytorch.free_benchmark.common.params import DataParams, HandlerParams
-from atat.pytorch.free_benchmark.common.utils import Tools
-from atat.pytorch.free_benchmark.compare.single_benchmark import SingleCompare
-from atat.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
+from msprobe.pytorch.free_benchmark import logger
+from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
+from msprobe.pytorch.free_benchmark.common.counter import preheat_counter
+from msprobe.pytorch.free_benchmark.common.enums import DeviceType
+from msprobe.pytorch.free_benchmark.common.params import DataParams, HandlerParams
+from msprobe.pytorch.free_benchmark.common.utils import Tools
+from msprobe.pytorch.free_benchmark.compare.single_benchmark import SingleCompare
+from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
class PreheatHandler(FuzzHandler):
@@ -74,14 +74,14 @@ class PreheatHandler(FuzzHandler):
cpu_consistent = self.compare_npu_and_cpu(data_params)
except Exception as e:
logger.warning_on_rank_0(
- f"[atat] Free Benchmark: For {self.params.api_name}, "
+ f"[msprobe] Free Benchmark: For {self.params.api_name}, "
f"when campare to cpu exception raise {e}"
)
try:
first_dtype = Tools.get_first_tensor_dtype(data_params.original_result)
except RuntimeError:
logger.warning_on_rank_0(
- f"[atat] Free Benchmark: For {self.params.api_name}, "
+ f"[msprobe] Free Benchmark: For {self.params.api_name}, "
f"the output sequence does not contain tensors."
)
if preheat_counter.get_api_preheat(self.pure_name, str(first_dtype)):
@@ -96,7 +96,7 @@ class PreheatHandler(FuzzHandler):
if res:
total_count = preheat_counter.get_one_step_used_api(self.pure_name)
logger.info_on_rank_0(
- f"[atat] Free benchmark: preheat sample in step{self.params.step}"
+ f"[msprobe] Free benchmark: preheat sample in step{self.params.step}"
f"api_name {self.params.api_name}, "
f"curr_called_seq: {curr_called_seq}/{total_count}"
)
diff --git a/debug/accuracy_tools/msprobe/pytorch/function_factory.py b/debug/accuracy_tools/msprobe/pytorch/function_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2fd8bfd0cbbad21e688c44c6fbb8671ba942511
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/function_factory.py
@@ -0,0 +1,75 @@
+from msprobe.pytorch.common.utils import logger
+from msprobe.pytorch.bench_functions.apply_adam_w import npu_apply_adam_w
+from msprobe.pytorch.bench_functions.confusion_transpose import npu_confusion_transpose, \
+ npu_confusion_transpose_backward
+from msprobe.pytorch.bench_functions.fast_gelu import fast_gelu, npu_fast_gelu_backward
+from msprobe.pytorch.bench_functions.layer_norm_eval import npu_layer_norm_eval
+from msprobe.pytorch.bench_functions.linear import npu_linear, npu_linear_backward
+from msprobe.pytorch.bench_functions.matmul_backward import matmul_backward
+from msprobe.pytorch.bench_functions.npu_fusion_attention import npu_fusion_attention, npu_fusion_attention_grad
+from msprobe.pytorch.bench_functions.rms_norm import npu_rms_norm, npu_rms_norm_backward
+from msprobe.pytorch.bench_functions.rotary_mul import npu_rotary_mul, npu_rotary_mul_backward
+from msprobe.pytorch.bench_functions.scaled_mask_softmax import npu_scaled_masked_softmax, \
+ npu_scaled_masked_softmax_backward
+from msprobe.pytorch.bench_functions.swiglu import npu_swiglu, npu_swiglu_backward, swish_grad, swish
+
+
+class Register(dict):
+ def __init__(self, *args, **kwargs):
+ super(Register, self).__init__(*args, **kwargs)
+ self._dict = {}
+
+ def __call__(self, target_func_list):
+ for target in target_func_list:
+ self.register(target)
+ return
+
+ def __setitem__(self, key, value):
+ self._dict[key] = value
+
+ def __getitem__(self, key):
+ return self._dict[key]
+
+ def __contains__(self, key):
+ return key in self._dict
+
+ def __str__(self):
+ return str(self._dict)
+
+ def keys(self):
+ return self._dict.keys()
+
+ def values(self):
+ return self._dict.values()
+
+ def items(self):
+ return self._dict.items()
+
+ def register(self, target):
+
+ def add_register_item(key, value):
+ if key in self._dict:
+ logger.warning(f"{value.__name__} has been registered before, so we will overriden it.")
+ self[key] = value
+ return value
+
+ if callable(target):
+ return add_register_item(target.__name__, target)
+ else:
+ raise Exception(f"The func {target} is not callable.")
+
+
+# register for npu custom bench functions
+npu_custom_functions = Register()
+npu_custom_functions([
+ npu_apply_adam_w, npu_confusion_transpose, fast_gelu, npu_layer_norm_eval, npu_linear, npu_fusion_attention,
+ npu_rms_norm, npu_rotary_mul, npu_scaled_masked_softmax, npu_swiglu
+])
+
+# register for npu custom backward bench functions
+npu_custom_grad_functions = Register()
+npu_custom_grad_functions([
+ npu_confusion_transpose_backward, npu_fast_gelu_backward, npu_linear_backward, matmul_backward,
+ npu_fusion_attention_grad, npu_rms_norm_backward, npu_rotary_mul_backward, npu_scaled_masked_softmax_backward,
+ npu_swiglu_backward
+])
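The new `function_factory.py` keys its two registries by each callable's `__name__`. A minimal, dependency-free sketch of the `Register` pattern (the real class registers msprobe's NPU bench functions; `npu_rms_norm` here is a stand-in):

```python
# Minimal sketch of the Register pattern from function_factory.py:
# a dict-like registry keyed by each callable's __name__.
class Register(dict):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._dict = {}

    def __call__(self, target_func_list):
        # Bulk-register a list of callables.
        for target in target_func_list:
            self.register(target)

    def __setitem__(self, key, value):
        self._dict[key] = value

    def __getitem__(self, key):
        return self._dict[key]

    def __contains__(self, key):
        return key in self._dict

    def register(self, target):
        if not callable(target):
            raise TypeError(f"The func {target} is not callable.")
        self[target.__name__] = target  # a duplicate key is overridden
        return target


def npu_rms_norm(x):  # stand-in for a real bench function
    return x * 2


npu_custom_functions = Register()
npu_custom_functions([npu_rms_norm])
assert "npu_rms_norm" in npu_custom_functions
assert npu_custom_functions["npu_rms_norm"](3) == 6
```

Calling the registry instance with a list is what the two module-level blocks above do for the forward and backward bench functions.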
diff --git a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/__init__.py b/debug/accuracy_tools/msprobe/pytorch/functional/__init__.py
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/parse_tool/lib/__init__.py
rename to debug/accuracy_tools/msprobe/pytorch/functional/__init__.py
diff --git a/debug/accuracy_tools/atat/pytorch/functional/data_processor.py b/debug/accuracy_tools/msprobe/pytorch/functional/data_processor.py
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/functional/data_processor.py
rename to debug/accuracy_tools/msprobe/pytorch/functional/data_processor.py
diff --git a/debug/accuracy_tools/atat/pytorch/functional/dump_module.py b/debug/accuracy_tools/msprobe/pytorch/functional/dump_module.py
similarity index 73%
rename from debug/accuracy_tools/atat/pytorch/functional/dump_module.py
rename to debug/accuracy_tools/msprobe/pytorch/functional/dump_module.py
index 675fa2a1bfdfdef9b12bf99f2428e3236c86a906..efb95c3369f6cda2f883d70a86261e0232535f86 100644
--- a/debug/accuracy_tools/atat/pytorch/functional/dump_module.py
+++ b/debug/accuracy_tools/msprobe/pytorch/functional/dump_module.py
@@ -1,10 +1,10 @@
import torch.nn as nn
-from atat.pytorch.common.log import logger
-from atat.core.common.const import Const
-from atat.pytorch.hook_module.api_registry import api_register
-from atat.pytorch.debugger.precision_debugger import PrecisionDebugger
-from atat.core.common.exceptions import MsaccException
-from atat.core.data_dump.scope import BaseScope
+from msprobe.pytorch.common.log import logger
+from msprobe.core.common.const import Const
+from msprobe.pytorch.hook_module.api_registry import api_register
+from msprobe.pytorch.debugger.precision_debugger import PrecisionDebugger
+from msprobe.core.common.exceptions import MsprobeException
+from msprobe.core.data_dump.scope import BaseScope
module_count = {}
@@ -12,10 +12,10 @@ module_count = {}
def module_dump(module, dump_name):
if not isinstance(module, nn.Module):
logger.error("The parameter:module in module_dump is not a Module subclass.")
- raise MsaccException(MsaccException.INVALID_PARAM_ERROR)
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
if not isinstance(dump_name, str):
logger.error("The parameter:dump_name in module_dump is not a str type.")
- raise MsaccException(MsaccException.INVALID_PARAM_ERROR)
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
api_register.api_originality()
if dump_name not in module_count:
module_count[dump_name] = 0
diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/__init__.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/__init__.py
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/hook_module/__init__.py
rename to debug/accuracy_tools/msprobe/pytorch/hook_module/__init__.py
diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/api_registry.py
similarity index 91%
rename from debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py
rename to debug/accuracy_tools/msprobe/pytorch/hook_module/api_registry.py
index 3b971cc71ecaf1d229337f4acf5afab6b5b0f9db..f75201eafcda40c61b2c5c3da710d6cfb06719b8 100644
--- a/debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py
+++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/api_registry.py
@@ -18,15 +18,15 @@
import torch
import torch.distributed as dist
-from atat.pytorch.hook_module import wrap_torch, wrap_functional, wrap_tensor, wrap_vf, wrap_distributed, wrap_aten
-from atat.pytorch.hook_module.wrap_aten import get_aten_ops
-from atat.pytorch.hook_module.wrap_distributed import get_distributed_ops
-from atat.pytorch.hook_module.wrap_functional import get_functional_ops
-from atat.pytorch.hook_module.wrap_tensor import get_tensor_ops
-from atat.pytorch.hook_module.wrap_torch import get_torch_ops
-from atat.pytorch.hook_module.wrap_vf import get_vf_ops
-from atat.pytorch.common.utils import torch_without_guard_version, npu_distributed_api, is_gpu
-from atat.core.common.const import Const
+from msprobe.pytorch.hook_module import wrap_torch, wrap_functional, wrap_tensor, wrap_vf, wrap_distributed, wrap_aten
+from msprobe.pytorch.hook_module.wrap_aten import get_aten_ops
+from msprobe.pytorch.hook_module.wrap_distributed import get_distributed_ops
+from msprobe.pytorch.hook_module.wrap_functional import get_functional_ops
+from msprobe.pytorch.hook_module.wrap_tensor import get_tensor_ops
+from msprobe.pytorch.hook_module.wrap_torch import get_torch_ops
+from msprobe.pytorch.hook_module.wrap_vf import get_vf_ops
+from msprobe.pytorch.common.utils import torch_without_guard_version, npu_distributed_api, is_gpu
+from msprobe.core.common.const import Const
torch_version_above_2 = torch.__version__.split('+')[0] > '2.0'
diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/hook_module.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py
similarity index 97%
rename from debug/accuracy_tools/atat/pytorch/hook_module/hook_module.py
rename to debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py
index 57212b6e45c572db3d3a235035647bbb614a3cb3..ff6427e51e5c6bc6b715991979890f759ab955cf 100644
--- a/debug/accuracy_tools/atat/pytorch/hook_module/hook_module.py
+++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py
@@ -17,10 +17,12 @@
import functools
import threading
+
import torch
import torch.nn as nn
import torch.utils.hooks as full_hooks
-from atat.core.common.const import Const
+
+from msprobe.core.common.const import Const
class HOOKModule(nn.Module):
@@ -61,6 +63,10 @@ class HOOKModule(nn.Module):
HOOKModule.inner_stop_hook[self.current_thread] = False
return result
+ @classmethod
+ def reset_module_stats(cls):
+ cls.module_count = {}
+
def _call_func(self, *input, **kwargs):
full_backward_hooks, non_full_backward_hooks = [], []
if len(self._backward_hooks) > 0:
diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/support_wrap_ops.yaml b/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml
similarity index 99%
rename from debug/accuracy_tools/atat/pytorch/hook_module/support_wrap_ops.yaml
rename to debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml
index d64c577ff38b7b7e3478d59eb1754845973c7103..f68708e945ea5351c84be982250b1dde436f3ba1 100644
--- a/debug/accuracy_tools/atat/pytorch/hook_module/support_wrap_ops.yaml
+++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml
@@ -1873,4 +1873,5 @@ distributed:
- reduce_scatter
- _reduce_scatter_base
- _all_gather_base
- - all_to_all_single
\ No newline at end of file
+ - all_to_all_single
+ - all_to_all
\ No newline at end of file
diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/utils.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/utils.py
similarity index 95%
rename from debug/accuracy_tools/atat/pytorch/hook_module/utils.py
rename to debug/accuracy_tools/msprobe/pytorch/hook_module/utils.py
index e4ed157af6dcafc826eb74fcc40898bfdc835eac..d991445db091f74fbe9f5f7d5f68dd4bdcf4b868 100644
--- a/debug/accuracy_tools/atat/pytorch/hook_module/utils.py
+++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/utils.py
@@ -18,7 +18,7 @@
import os
import yaml
-from atat.core.common.file_check import FileOpen
+from msprobe.core.common.file_utils import FileOpen
cur_path = os.path.dirname(os.path.realpath(__file__))
yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_aten.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_aten.py
similarity index 71%
rename from debug/accuracy_tools/atat/pytorch/hook_module/wrap_aten.py
rename to debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_aten.py
index c5a3c6365d1841927c3acd2993ac2828092ae80b..a99d669beb227c0576d95991d5e89811d229c4c3 100644
--- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_aten.py
+++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_aten.py
@@ -20,16 +20,18 @@ import torch
import yaml
-from atat.pytorch.hook_module.hook_module import HOOKModule
-from atat.pytorch.common.utils import torch_device_guard
-from atat.core.common.const import Const
-from atat.core.common.file_check import FileOpen
-
+from msprobe.pytorch.hook_module.hook_module import HOOKModule
+from msprobe.pytorch.common.utils import torch_device_guard
+from msprobe.core.common.const import Const
+from msprobe.core.common.file_utils import FileOpen
+from msprobe.pytorch.function_factory import npu_custom_grad_functions
cur_path = os.path.dirname(os.path.realpath(__file__))
yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
with FileOpen(yaml_path, 'r') as f:
- WrapAtenOps = yaml.safe_load(f).get('aten')
+ Ops = yaml.safe_load(f)
+ WrapAtenOps = Ops.get('aten')
+ WhiteAtenOps = Ops.get('white_aten_ops', [])
aten_func = {}
@@ -48,7 +50,7 @@ class HOOKAtenOP(object):
class AtenOPTemplate(HOOKModule):
- def __init__(self, op, hook):
+ def __init__(self, op, hook, need_hook=True):
if isinstance(op, torch._ops.OpOverloadPacket):
op_name_ = op._qualified_op_name.split("::")[-1]
else:
@@ -58,10 +60,21 @@ class AtenOPTemplate(HOOKModule):
op_name_ = op_name_ + '.' + overload_name
self.op = op
self.prefix_op_name_ = "Aten" + Const.SEP + str(op_name_) + Const.SEP
- super().__init__(hook)
+ self.need_hook = need_hook
+ if self.need_hook:
+ super().__init__(hook)
@torch_device_guard
def forward(self, *args, **kwargs):
+ if isinstance(self.op, str):
+ if self.op in npu_custom_grad_functions:
+ return npu_custom_grad_functions[self.op](*args, **kwargs)
+ if self.op in WhiteAtenOps:
+ return eval(f"torch.ops.aten.{self.op}")(*args, **kwargs)
+ if self.op not in aten_func:
+ raise Exception(f"Skip op[{self.op}] accuracy check, because the op is not "
+ f"in dir(torch.ops.aten) and support yaml.")
+ return aten_func[self.op](*args, **kwargs)
return self.op(*args, **kwargs)
@@ -80,13 +93,13 @@ class AtenOPPacketTemplate():
else:
return attr
- def overloads(self):
- return self.opPacket.overloads()
-
@torch_device_guard
def __call__(self, *args, **kwargs):
return AtenOPTemplate(self.opPacket, self.hook)(*args, **kwargs)
+ def overloads(self):
+ return self.opPacket.overloads()
+
def wrap_aten_op(op, hook):
return AtenOPPacketTemplate(op, hook)
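The reworked `AtenOPTemplate.forward` adds a resolution order for ops passed in as strings. A sketch of just that dispatch logic, with plain dicts standing in for `npu_custom_grad_functions`, `WhiteAtenOps`, and `aten_func` (the sketch returns a marker string where the real code `eval`s `torch.ops.aten.<op>`):

```python
# Stand-in tables mirroring the module-level registries in wrap_aten.py.
npu_custom_grad_functions = {"matmul_backward": lambda *a: "custom bench"}
WhiteAtenOps = ["clone"]                       # yaml `white_aten_ops` entries
aten_func = {"add": lambda *a: "hooked aten"}  # wrapped ops from the yaml


def dispatch(op, *args):
    """Resolution order for string ops in AtenOPTemplate.forward (sketch)."""
    if op in npu_custom_grad_functions:        # 1. CPU bench implementation
        return npu_custom_grad_functions[op](*args)
    if op in WhiteAtenOps:                     # 2. white-listed: call the real
        return f"torch.ops.aten.{op}"          #    aten op (eval'd in the diff)
    if op not in aten_func:                    # 3. otherwise it must be wrapped
        raise Exception(f"Skip op[{op}] accuracy check")
    return aten_func[op](*args)


assert dispatch("matmul_backward") == "custom bench"
assert dispatch("add") == "hooked aten"
```

Unknown string ops fail fast, so an op missing from both `dir(torch.ops.aten)` and the yaml is skipped rather than silently miscompared.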
diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_distributed.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py
similarity index 91%
rename from debug/accuracy_tools/atat/pytorch/hook_module/wrap_distributed.py
rename to debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py
index e02189ac1bf2a5d2e1607f2625200eb4ee2ea6e8..54afecb9a663db8aa2db998247fca707a8aa2ced 100644
--- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_distributed.py
+++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py
@@ -20,10 +20,10 @@ from functools import wraps
import torch.distributed as dist
import yaml
-from atat.pytorch.hook_module.hook_module import HOOKModule
-from atat.pytorch.common.utils import torch_device_guard
-from atat.core.common.const import Const
-from atat.core.common.file_check import FileOpen
+from msprobe.pytorch.hook_module.hook_module import HOOKModule
+from msprobe.pytorch.common.utils import torch_device_guard
+from msprobe.core.common.const import Const
+from msprobe.core.common.file_utils import FileOpen
cur_path = os.path.dirname(os.path.realpath(__file__))
diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_functional.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_functional.py
similarity index 94%
rename from debug/accuracy_tools/atat/pytorch/hook_module/wrap_functional.py
rename to debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_functional.py
index fa97f5ee3106c62ab63d80e2b2ebe349494ce695..96d6986a0784cf98e42fd04d45a52f0be6bd39ed 100644
--- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_functional.py
+++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_functional.py
@@ -20,11 +20,11 @@ import os
import torch
import yaml
-from atat.pytorch.hook_module.hook_module import HOOKModule
-from atat.pytorch.common.utils import torch_device_guard
-from atat.core.common.const import Const
-from atat.pytorch.common.log import logger
-from atat.core.common.file_check import FileOpen
+from msprobe.pytorch.hook_module.hook_module import HOOKModule
+from msprobe.pytorch.common.utils import torch_device_guard
+from msprobe.core.common.const import Const
+from msprobe.pytorch.common.log import logger
+from msprobe.core.common.file_utils import FileOpen
def remove_dropout():
diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_npu_custom.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py
similarity index 71%
rename from debug/accuracy_tools/atat/pytorch/hook_module/wrap_npu_custom.py
rename to debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py
index 7d0882804f478d8657faeb5f90b0b718914d72e5..607b4ed3cd54a7e745fa1eca715b1ef8ba1e04b2 100644
--- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_npu_custom.py
+++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py
@@ -17,19 +17,26 @@
import os
import torch
-import torch_npu
import yaml
-from atat.pytorch.hook_module.hook_module import HOOKModule
-from atat.pytorch.common.utils import torch_device_guard, torch_without_guard_version
-from atat.core.common.const import Const
-from atat.core.common.file_check import FileOpen
+from msprobe.pytorch.hook_module.hook_module import HOOKModule
+from msprobe.pytorch.common.utils import torch_device_guard, torch_without_guard_version
+from msprobe.core.common.const import Const
+from msprobe.core.common.file_utils import FileOpen
+from msprobe.pytorch.function_factory import npu_custom_functions
cur_path = os.path.dirname(os.path.realpath(__file__))
yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
with FileOpen(yaml_path, 'r') as f:
WrapNpuOps = yaml.safe_load(f).get('torch_npu')
+try:
+ import torch_npu
+except ImportError:
+ is_gpu = True
+else:
+ is_gpu = False
+
def get_npu_ops():
global WrapNpuOps
@@ -46,13 +53,19 @@ class HOOKNpuOP(object):
class NpuOPTemplate(HOOKModule):
- def __init__(self, op_name, hook):
+ def __init__(self, op_name, hook, need_hook=True):
self.op_name_ = op_name
self.prefix_op_name_ = "NPU" + Const.SEP + str(op_name) + Const.SEP
- super().__init__(hook)
+ self.need_hook = need_hook
+ if need_hook:
+ super().__init__(hook)
@torch_device_guard
def forward(self, *args, **kwargs):
+ if not self.need_hook:
+ if self.op_name_ not in npu_custom_functions:
+                raise Exception(f'There is no bench function {self.op_name_}')
+ return npu_custom_functions[self.op_name_](*args, **kwargs)
if torch_without_guard_version:
return getattr(torch.ops.npu, str(self.op_name_))(*args, **kwargs)
else:
@@ -60,7 +73,6 @@ class NpuOPTemplate(HOOKModule):
def wrap_npu_op(op_name, hook):
-
def npu_op_template(*args, **kwargs):
return NpuOPTemplate(op_name, hook)(*args, **kwargs)
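`wrap_npu_custom.py` now guards the `torch_npu` import so the module also loads on GPU/CPU hosts. The same pattern, generalized into a helper (the helper name is illustrative, not part of the diff):

```python
import importlib


def module_missing(name):
    """True when `name` cannot be imported, e.g. torch_npu off Ascend hosts."""
    try:
        importlib.import_module(name)
    except ImportError:
        return True
    return False


# Mirrors the `is_gpu = True / False` flag set in wrap_npu_custom.py.
is_gpu = module_missing("torch_npu")
```

Deferring the import this way keeps a single code path at module load time and turns the hardware question into a boolean checked later.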
diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_tensor.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_tensor.py
similarity index 89%
rename from debug/accuracy_tools/atat/pytorch/hook_module/wrap_tensor.py
rename to debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_tensor.py
index 6fac18140238e016004c2329f0d8ff647045c0c0..aba6f86148967f0f5d0af4d9ac2a85150e79ff47 100644
--- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_tensor.py
+++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_tensor.py
@@ -20,10 +20,10 @@ import os
import torch
import yaml
-from atat.pytorch.hook_module.hook_module import HOOKModule
-from atat.pytorch.common.utils import torch_device_guard, parameter_adapter
-from atat.core.common.const import Const
-from atat.core.common.file_check import FileOpen
+from msprobe.pytorch.hook_module.hook_module import HOOKModule
+from msprobe.pytorch.common.utils import torch_device_guard, parameter_adapter
+from msprobe.core.common.const import Const
+from msprobe.core.common.file_utils import FileOpen
cur_path = os.path.dirname(os.path.realpath(__file__))
yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_torch.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_torch.py
similarity index 91%
rename from debug/accuracy_tools/atat/pytorch/hook_module/wrap_torch.py
rename to debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_torch.py
index f0bd01fe4624ccc1a50d1f3e1fb5d614b30a7a8d..3f9518b7f1a3b9532de0ad016b21246be5bd883d 100644
--- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_torch.py
+++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_torch.py
@@ -20,10 +20,10 @@ import os
import torch
import yaml
-from atat.pytorch.hook_module.hook_module import HOOKModule
-from atat.pytorch.common.utils import torch_device_guard
-from atat.core.common.const import Const
-from atat.core.common.file_check import FileOpen
+from msprobe.pytorch.hook_module.hook_module import HOOKModule
+from msprobe.pytorch.common.utils import torch_device_guard
+from msprobe.core.common.const import Const
+from msprobe.core.common.file_utils import FileOpen
cur_path = os.path.dirname(os.path.realpath(__file__))
yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_vf.py
similarity index 88%
rename from debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py
rename to debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_vf.py
index d4c570221d4b219be168b25dafd76264769da197..351820fd6c11f71b7465e1e5e456b582161ff462 100644
--- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py
+++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_vf.py
@@ -20,10 +20,10 @@ import os
import torch
import yaml
-from atat.pytorch.hook_module.hook_module import HOOKModule
-from atat.core.common.file_check import FileOpen
-from atat.pytorch.common.utils import torch_device_guard
-from atat.core.common.const import Const
+from msprobe.pytorch.hook_module.hook_module import HOOKModule
+from msprobe.core.common.file_utils import FileOpen
+from msprobe.pytorch.common.utils import torch_device_guard
+from msprobe.core.common.const import Const
cur_path = os.path.dirname(os.path.realpath(__file__))
yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
diff --git a/debug/accuracy_tools/atat/pytorch/module_processer.py b/debug/accuracy_tools/msprobe/pytorch/module_processer.py
similarity index 72%
rename from debug/accuracy_tools/atat/pytorch/module_processer.py
rename to debug/accuracy_tools/msprobe/pytorch/module_processer.py
index 8ce9140e32cbc9f7d2d62eede164e24127a40f34..3e9969d32d9147a6f26b7a6a4219364368a6116b 100644
--- a/debug/accuracy_tools/atat/pytorch/module_processer.py
+++ b/debug/accuracy_tools/msprobe/pytorch/module_processer.py
@@ -1,15 +1,17 @@
from functools import wraps
+
import torch
from torch.utils.hooks import BackwardHook
-from atat.core.common.const import Const
-from atat.core.data_dump.scope import ModuleRangeScope
+
+from msprobe.core.common.const import Const
+from msprobe.core.data_dump.scope import ModuleRangeScope
class ModuleProcesser:
+ module_count = {}
module_stack = []
api_parent_node = ""
module_node = {}
- current_module_name = ""
def __init__(self, scope):
if isinstance(scope, ModuleRangeScope):
@@ -19,15 +21,22 @@ class ModuleProcesser:
BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook)
BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook)
BackwardHook.setup_output_hook = ModuleProcesser.filter_tensor_and_tuple(BackwardHook.setup_output_hook)
- self.module_count = {}
@staticmethod
def filter_tensor_and_tuple(func):
@wraps(func)
def wrap_by_filter_tensor_and_tuple(*args, **kwargs):
-            # If setup_output_hook receives non-tensor data, later dump steps fail; the workaround was to not pass non-tensor data through
+            # If setup_output_hook receives non-tensor data, later dump steps fail; the fix is to walk the non-tensor object's attributes and hook its tensor-valued attributes
             # setup_output_hook is defined as setup_output_hook(self, args), so handle the second positional parameter, i.e. *args[1]
if not isinstance(args[1], (torch.Tensor, tuple)):
+ for item_str in dir(args[1]):
+ item = getattr(args[1], item_str)
+ # 处理tensor或者只包含tensor的元组
+ if isinstance(item, torch.Tensor) or \
+ (isinstance(item, tuple) and all(isinstance(x, torch.Tensor) for x in item)):
+ args_new = (args[0], item)
+ result = func(*args_new, **kwargs)
+ setattr(args[1], item_str, result)
return args[1]
return func(*args, **kwargs)
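The reworked `filter_tensor_and_tuple` no longer drops non-tensor outputs: it runs the hook setup on each tensor-valued attribute and writes the result back. A dependency-free sketch of that wrapper, using `int` as a stand-in for `torch.Tensor` and a toy hook function:

```python
Tensor = int  # dependency-free stand-in for torch.Tensor


def filter_tensor_and_tuple(func):
    """Sketch of the wrapper applied to BackwardHook.setup_output_hook."""
    def wrapped(module, output):
        if not isinstance(output, (Tensor, tuple)):
            # Hook every Tensor (or all-Tensor tuple) attribute, write it back.
            for name in dir(output):
                item = getattr(output, name)
                if isinstance(item, Tensor) or (
                    isinstance(item, tuple) and all(isinstance(x, Tensor) for x in item)
                ):
                    setattr(output, name, func(module, item))
            return output
        return func(module, output)
    return wrapped


class ModelOutput:  # arbitrary non-tensor output object
    def __init__(self):
        self.loss = 3
        self.tag = "step0"


hook = filter_tensor_and_tuple(lambda module, t: t * 10)  # toy hook: scale
out = hook(None, ModelOutput())
assert out.loss == 30 and out.tag == "step0"
```

Tensor outputs still take the original fast path; only attribute-bag outputs get the per-attribute treatment.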
@@ -55,11 +64,26 @@ class ModuleProcesser:
else:
return result
+ @staticmethod
+ def module_count_func(module_name):
+ if module_name not in ModuleProcesser.module_count:
+ ModuleProcesser.module_count[module_name] = 0
+ else:
+ ModuleProcesser.module_count[module_name] += 1
+ return ModuleProcesser.module_count[module_name]
+
+ @classmethod
+ def reset_module_stats(cls):
+ cls.module_count = {}
+ cls.module_stack = []
+ cls.api_parent_node = ""
+ cls.module_node = {}
+
def node_hook(self, name_prefix, start_or_stop, **kwargs):
def pre_hook(module, input, output=None):
try:
- index = self.module_count_func(name_prefix)
+ index = ModuleProcesser.module_count_func(name_prefix)
except IndexError as e:
index = None
pass
@@ -89,10 +113,3 @@ class ModuleProcesser:
return pre_hook
else:
return end_hook
-
- def module_count_func(self, module_name):
- if module_name not in self.module_count:
- self.module_count[module_name] = 0
- else:
- self.module_count[module_name] += 1
- return self.module_count[module_name]
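`module_processer.py` moves the per-name counter from the instance to the class and pairs it with `reset_module_stats`, so all processer instances share one count per step. A reduced sketch of just that counting logic:

```python
class ModuleProcesser:
    """Sketch of the class-level module counter from module_processer.py."""
    module_count = {}

    @staticmethod
    def module_count_func(module_name):
        # First call for a name yields index 0, then 1, 2, ...
        if module_name not in ModuleProcesser.module_count:
            ModuleProcesser.module_count[module_name] = 0
        else:
            ModuleProcesser.module_count[module_name] += 1
        return ModuleProcesser.module_count[module_name]

    @classmethod
    def reset_module_stats(cls):
        cls.module_count = {}


assert ModuleProcesser.module_count_func("conv") == 0
assert ModuleProcesser.module_count_func("conv") == 1
ModuleProcesser.reset_module_stats()
assert ModuleProcesser.module_count_func("conv") == 0
```

Keeping the state on the class is what lets `node_hook` call `ModuleProcesser.module_count_func(name_prefix)` without threading an instance through, and lets the debugger clear all counters between dump steps.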
diff --git a/debug/accuracy_tools/atat/pytorch/online_dispatch/__init__.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/__init__.py
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/online_dispatch/__init__.py
rename to debug/accuracy_tools/msprobe/pytorch/online_dispatch/__init__.py
diff --git a/debug/accuracy_tools/atat/pytorch/online_dispatch/compare.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/compare.py
similarity index 96%
rename from debug/accuracy_tools/atat/pytorch/online_dispatch/compare.py
rename to debug/accuracy_tools/msprobe/pytorch/online_dispatch/compare.py
index e6d55ca0614767338023673b6d0cfa5c7d990c0f..4e3d574cd84118dda2e32667be09d8695312a584 100644
--- a/debug/accuracy_tools/atat/pytorch/online_dispatch/compare.py
+++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/compare.py
@@ -6,12 +6,11 @@ import json
from collections import namedtuple
from rich.table import Table
from rich.console import Console
-from .single_compare import single_benchmark_compare_wrap
-from .utils import DispatchException
-from atat.core.common.const import CompareConst
-from atat.core.common.file_check import FileOpen
-from atat.pytorch.common.log import logger
-from atat.core.common.utils import CompareException
+from msprobe.pytorch.online_dispatch.single_compare import single_benchmark_compare_wrap
+from msprobe.core.common.const import CompareConst
+from msprobe.core.common.file_utils import FileOpen
+from msprobe.pytorch.common.log import logger
+from msprobe.core.common.utils import CompareException
ELEMENT_NUM_THRESHOLD = 100
ZERO_NUM_THRESHOLD = 0.1
@@ -228,7 +227,7 @@ class Comparator:
else:
is_bwd_success, bwd_compare_alg_results = True, None
if is_bwd_success and bwd_compare_alg_results is None:
- self.saver.record_results(ResultInfo(api_name, is_fwd_success, CompareConst.NA, fwd_compare_alg_results,
+ self.saver.record_results(ResultInfo(api_name, is_fwd_success, CompareConst.NAN, fwd_compare_alg_results,
bwd_compare_alg_results))
else:
self.saver.record_results(ResultInfo(api_name, is_fwd_success, is_bwd_success, fwd_compare_alg_results,
diff --git a/debug/accuracy_tools/atat/pytorch/online_dispatch/dispatch.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py
similarity index 94%
rename from debug/accuracy_tools/atat/pytorch/online_dispatch/dispatch.py
rename to debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py
index 7502d746acf38a52b427659b8d97f5f55a4ce96c..2251fa6cb6aabb0054598cbeb0354d3ac9552c7a 100644
--- a/debug/accuracy_tools/atat/pytorch/online_dispatch/dispatch.py
+++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py
@@ -16,14 +16,14 @@ except ImportError:
else:
is_npu = True
-from .dump_compare import dispatch_workflow, dispatch_multiprocess, error_call, TimeStatistics, \
+from msprobe.pytorch.online_dispatch.dump_compare import dispatch_workflow, dispatch_multiprocess, error_call, TimeStatistics, \
DispatchRunParam, DisPatchDataInfo
-from .utils import get_callstack, data_to_cpu, logger_debug, logger_error, logger_warn, logger_logo, get_sys_info, \
+from msprobe.pytorch.online_dispatch.utils import get_callstack, data_to_cpu, logger_debug, logger_error, logger_warn, logger_logo, get_sys_info, \
DispatchException
-from .compare import Comparator
-from atat.core.common.file_check import FileOpen
-from atat.core.common.utils import check_file_or_directory_path, check_path_before_create
-from atat.core.common.const import Const, CompareConst
+from msprobe.pytorch.online_dispatch.compare import Comparator
+from msprobe.core.common.file_utils import FileOpen
+from msprobe.core.common.utils import check_file_or_directory_path, check_path_before_create
+from msprobe.core.common.const import Const, CompareConst
current_time = time.strftime("%Y%m%d%H%M%S")
RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
@@ -209,9 +209,9 @@ class PtdbgDispatch(TorchDispatchMode):
time_now = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
if tag is None or not isinstance(tag, str):
logger_warn('There is not tag or the type of tag is not string.')
- dir_name = f'atat_rank{self.device_id}_{time_now}'
+ dir_name = f'msprobe_rank{self.device_id}_{time_now}'
else:
- dir_name = f'atat_{tag}_rank{self.device_id}_{time_now}'
+ dir_name = f'msprobe_{tag}_rank{self.device_id}_{time_now}'
return dir_name
def load_yaml_file(self, file_path):
diff --git a/debug/accuracy_tools/atat/pytorch/online_dispatch/dump_compare.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py
similarity index 95%
rename from debug/accuracy_tools/atat/pytorch/online_dispatch/dump_compare.py
rename to debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py
index cd7c5a3f282d17b4e064a4b00928861b2e32f737..5e8bf4f1117b0a86f67902df723ca08ad19db73e 100644
--- a/debug/accuracy_tools/atat/pytorch/online_dispatch/dump_compare.py
+++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py
@@ -5,11 +5,11 @@ from datetime import datetime, timezone
import pandas as pd
import torch
-from .utils import np_save_data, logger_debug, logger_error, logger_warn, logger_user, COLOR_RED, COLOR_GREEN, \
+from msprobe.pytorch.online_dispatch.utils import np_save_data, logger_debug, logger_error, logger_warn, logger_user, COLOR_RED, COLOR_GREEN, \
COLOR_RESET, CSV_COLUMN_NAME
-from atat.core.common.file_check import FileOpen, change_mode
-from atat.core.common.const import CompareConst, FileCheckConst, Const
-from atat.pytorch.common.log import logger
+from msprobe.core.common.file_utils import FileOpen, change_mode
+from msprobe.core.common.const import CompareConst, FileCheckConst, Const
+from msprobe.pytorch.common.log import logger
class DispatchRunParam:
def __init__(self, debug_flag, device_id, root_npu_path, root_cpu_path, process_num, comparator):
diff --git a/debug/accuracy_tools/atat/pytorch/online_dispatch/single_compare.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/single_compare.py
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/online_dispatch/single_compare.py
rename to debug/accuracy_tools/msprobe/pytorch/online_dispatch/single_compare.py
diff --git a/debug/accuracy_tools/atat/pytorch/online_dispatch/torch_ops_config.yaml b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/torch_ops_config.yaml
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/online_dispatch/torch_ops_config.yaml
rename to debug/accuracy_tools/msprobe/pytorch/online_dispatch/torch_ops_config.yaml
diff --git a/debug/accuracy_tools/atat/pytorch/online_dispatch/utils.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/utils.py
similarity index 97%
rename from debug/accuracy_tools/atat/pytorch/online_dispatch/utils.py
rename to debug/accuracy_tools/msprobe/pytorch/online_dispatch/utils.py
index f3fcffb6f26adbd2b2ff6b77da720f7ecdfcb0ca..c1d1e841a40137508c1fd09d617d57c83e9306a9 100644
--- a/debug/accuracy_tools/atat/pytorch/online_dispatch/utils.py
+++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/utils.py
@@ -12,8 +12,8 @@ except ImportError:
else:
pta_cpu_device = torch.device("cpu")
-from atat.core.common.const import CompareConst, FileCheckConst
-from atat.core.common.file_check import change_mode
+from msprobe.core.common.const import CompareConst, FileCheckConst
+from msprobe.core.common.file_utils import change_mode
cpu_device = torch._C.device("cpu")
COLOR_RED = '\033[31m'
diff --git a/debug/accuracy_tools/atat/pytorch/parse.py b/debug/accuracy_tools/msprobe/pytorch/parse.py
similarity index 50%
rename from debug/accuracy_tools/atat/pytorch/parse.py
rename to debug/accuracy_tools/msprobe/pytorch/parse.py
index 40792d0e0297a9b034f186e255193d6201517764..efd3d4a2ddb807e23ba346b6c80ef344d560050d 100644
--- a/debug/accuracy_tools/atat/pytorch/parse.py
+++ b/debug/accuracy_tools/msprobe/pytorch/parse.py
@@ -1,4 +1,4 @@
-from atat.pytorch.parse_tool import cli
+from msprobe.pytorch.parse_tool import cli
if __name__ == '__main__':
cli.parse()
diff --git a/debug/accuracy_tools/kj600/kj600/__init__.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/__init__.py
similarity index 100%
rename from debug/accuracy_tools/kj600/kj600/__init__.py
rename to debug/accuracy_tools/msprobe/pytorch/parse_tool/__init__.py
diff --git a/debug/accuracy_tools/atat/pytorch/parse_tool/cli.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/cli.py
similarity index 89%
rename from debug/accuracy_tools/atat/pytorch/parse_tool/cli.py
rename to debug/accuracy_tools/msprobe/pytorch/parse_tool/cli.py
index f59fbf13a8d3e2785611022cd7b5b9a2926ea008..500e8eef6846b1209817083992a5e3630381b7dc 100644
--- a/debug/accuracy_tools/atat/pytorch/parse_tool/cli.py
+++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/cli.py
@@ -14,8 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
-from atat.pytorch.parse_tool.lib.interactive_cli import InteractiveCli
-from atat.pytorch.common.log import logger
+from msprobe.pytorch.parse_tool.lib.interactive_cli import InteractiveCli
+from msprobe.pytorch.common.log import logger
def _run_interactive_cli(cli=None):
diff --git a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/__init__.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/compare.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/compare.py
similarity index 92%
rename from debug/accuracy_tools/atat/pytorch/parse_tool/lib/compare.py
rename to debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/compare.py
index dfc4529414cbe00307b36fc58e5d62d64c6fdf32..2b091c59e8caef33576a7f7c79efcdbc0987afbe 100644
--- a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/compare.py
+++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/compare.py
@@ -19,9 +19,9 @@ import os
import time
import numpy as np
from collections import namedtuple
-from atat.pytorch.parse_tool.lib.utils import Util
-from atat.pytorch.parse_tool.lib.config import Const
-from atat.pytorch.parse_tool.lib.parse_exception import ParseException
+from msprobe.pytorch.parse_tool.lib.utils import Util
+from msprobe.pytorch.parse_tool.lib.config import Const
+from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException
class Compare:
@@ -83,16 +83,17 @@ class Compare:
(left, right, save_txt, rl, al, diff_count) = args
if left is None or right is None:
raise ParseException("invalid input or output")
- try:
- left_data = np.load(left)
- right_data = np.load(right)
- except UnicodeError as e:
- self.log.error("%s %s" % ("UnicodeError", str(e)))
- self.log.warning("Please check the npy file")
- raise ParseException(ParseException.PARSE_UNICODE_ERROR) from e
- except IOError:
- self.log.error("Failed to load npy %s or %s." % (left, right))
- raise ParseException(ParseException.PARSE_LOAD_NPY_ERROR) from e
+ if self.util.check_path_valid(left) and self.util.check_path_valid(right):
+ try:
+ left_data = np.load(left)
+ right_data = np.load(right)
+ except UnicodeError as e:
+ self.log.error("%s %s" % ("UnicodeError", str(e)))
+ self.log.warning("Please check the npy file")
+ raise ParseException(ParseException.PARSE_UNICODE_ERROR) from e
+ except IOError as e:
+ self.log.error("Failed to load npy %s or %s." % (left, right))
+ raise ParseException(ParseException.PARSE_LOAD_NPY_ERROR) from e
# save to txt
if save_txt:
@@ -157,8 +158,10 @@ class Compare:
return res
def compare_npy(self, file, bench_file, output_path):
- data = np.load(file)
- bench_data = np.load(bench_file)
+ if self.util.check_path_valid(file):
+ data = np.load(file)
+ if self.util.check_path_valid(bench_file):
+ bench_data = np.load(bench_file)
shape, dtype = data.shape, data.dtype
bench_shape, bench_dtype = bench_data.shape, bench_data.dtype
filename = os.path.basename(file)
diff --git a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/config.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/config.py
similarity index 98%
rename from debug/accuracy_tools/atat/pytorch/parse_tool/lib/config.py
rename to debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/config.py
index a745ff46f08a28c39c989a5d8dce4ff5cf475ee5..a9a8b2b00e2703ae28d6beff114848de099f3900 100644
--- a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/config.py
+++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/config.py
@@ -33,7 +33,7 @@ class Const:
OFFLINE_DUMP_CONVERT_PATTERN = \
r"^([A-Za-z0-9_-]+)\.([A-Za-z0-9_-]+)\.([0-9]+)(\.[0-9]+)?\.([0-9]{1,255})" \
r"\.([a-z]+)\.([0-9]{1,255})(\.[x0-9]+)?\.npy$"
- NUMPY_PATTERN = r".*\.npy$"
+ NUMPY_PATTERN = r"^[\w\-]+\.npy$"
NPY_SUFFIX = ".npy"
PKL_SUFFIX = ".pkl"
DIRECTORY_LENGTH = 4096
diff --git a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/file_desc.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/file_desc.py
similarity index 100%
rename from debug/accuracy_tools/atat/pytorch/parse_tool/lib/file_desc.py
rename to debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/file_desc.py
diff --git a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/interactive_cli.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/interactive_cli.py
similarity index 93%
rename from debug/accuracy_tools/atat/pytorch/parse_tool/lib/interactive_cli.py
rename to debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/interactive_cli.py
index 12b07183fbc2e1c2ea630f05ac44deda744d4d01..1ea7dd30153e458b758dc0a79779b54a25fe8289 100644
--- a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/interactive_cli.py
+++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/interactive_cli.py
@@ -16,10 +16,10 @@
"""
import cmd
import argparse
-from atat.pytorch.parse_tool.lib.parse_tool import ParseTool
-from atat.pytorch.parse_tool.lib.utils import Util
-from atat.pytorch.parse_tool.lib.config import Const
-from atat.pytorch.parse_tool.lib.parse_exception import catch_exception
+from msprobe.pytorch.parse_tool.lib.parse_tool import ParseTool
+from msprobe.pytorch.parse_tool.lib.utils import Util
+from msprobe.pytorch.parse_tool.lib.config import Const
+from msprobe.pytorch.parse_tool.lib.parse_exception import catch_exception
class InteractiveCli(cmd.Cmd):
diff --git a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/parse_exception.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/parse_exception.py
similarity index 96%
rename from debug/accuracy_tools/atat/pytorch/parse_tool/lib/parse_exception.py
rename to debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/parse_exception.py
index 1177c51985dc82fba632898590762e38387603ab..7525230cedc7ff11d4112a55998c6414e8f09217 100644
--- a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/parse_exception.py
+++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/parse_exception.py
@@ -15,7 +15,7 @@
# limitations under the License.
"""
import logging
-from atat.core.common.exceptions import FileCheckException
+from msprobe.core.common.exceptions import FileCheckException
class ParseException(Exception):
diff --git a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/parse_tool.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/parse_tool.py
similarity index 95%
rename from debug/accuracy_tools/atat/pytorch/parse_tool/lib/parse_tool.py
rename to debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/parse_tool.py
index 3e02baa1272199a961fd550a0837d68100de8348..9a47dc54cf9e15d65eb5fb8d3bae54358777051c 100644
--- a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/parse_tool.py
+++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/parse_tool.py
@@ -18,11 +18,11 @@ import argparse
import os
from collections import namedtuple
-from atat.pytorch.parse_tool.lib.config import Const
-from atat.pytorch.parse_tool.lib.utils import Util
-from atat.pytorch.parse_tool.lib.compare import Compare
-from atat.pytorch.parse_tool.lib.visualization import Visualization
-from atat.pytorch.parse_tool.lib.parse_exception import catch_exception, ParseException
+from msprobe.pytorch.parse_tool.lib.config import Const
+from msprobe.pytorch.parse_tool.lib.utils import Util
+from msprobe.pytorch.parse_tool.lib.compare import Compare
+from msprobe.pytorch.parse_tool.lib.visualization import Visualization
+from msprobe.pytorch.parse_tool.lib.parse_exception import catch_exception, ParseException
class ParseTool:
diff --git a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/utils.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py
similarity index 94%
rename from debug/accuracy_tools/atat/pytorch/parse_tool/lib/utils.py
rename to debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py
index ce42d242ba29ef4dcf7f2af042242b1b9987de42..a8abec2d15634d08429e51899a83d4938c4a8869 100644
--- a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/utils.py
+++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py
@@ -25,15 +25,15 @@ import csv
import time
import numpy as np
from collections import namedtuple
-from atat.pytorch.parse_tool.lib.config import Const
-from atat.pytorch.parse_tool.lib.file_desc import DumpDecodeFileDesc, FileDesc
-from atat.pytorch.parse_tool.lib.parse_exception import ParseException
-from atat.core.common.file_check import change_mode, check_other_user_writable,\
+from msprobe.pytorch.parse_tool.lib.config import Const
+from msprobe.pytorch.parse_tool.lib.file_desc import DumpDecodeFileDesc, FileDesc
+from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException
+from msprobe.core.common.file_utils import change_mode, check_other_user_writable,\
check_path_executable, check_path_owner_consistent
-from atat.core.common.const import FileCheckConst
-from atat.core.common.file_check import FileOpen
-from atat.core.common.utils import check_file_or_directory_path
-from atat.pytorch.common.log import logger
+from msprobe.core.common.const import FileCheckConst
+from msprobe.core.common.file_utils import FileOpen
+from msprobe.core.common.utils import check_file_or_directory_path, check_path_before_create
+from msprobe.pytorch.common.log import logger
try:
@@ -73,16 +73,6 @@ class Util:
def path_strip(path):
return path.strip("'").strip('"')
- @staticmethod
- def _gen_npu_dump_convert_file_info(name, match, dir_path):
- return DumpDecodeFileDesc(name, dir_path, int(match.groups()[-4]), op_name=match.group(2),
- op_type=match.group(1), task_id=int(match.group(3)), anchor_type=match.groups()[-3],
- anchor_idx=int(match.groups()[-2]))
-
- @staticmethod
- def _gen_numpy_file_info(name, math, dir_path):
- return FileDesc(name, dir_path)
-
@staticmethod
def check_executable_file(path):
check_path_owner_consistent(path)
@@ -184,6 +174,16 @@ class Util:
def change_filemode_safe(self, path):
change_mode(path, FileCheckConst.DATA_FILE_AUTHORITY)
+ @staticmethod
+ def _gen_npu_dump_convert_file_info(name, match, dir_path):
+ return DumpDecodeFileDesc(name, dir_path, int(match.groups()[-4]), op_name=match.group(2),
+ op_type=match.group(1), task_id=int(match.group(3)), anchor_type=match.groups()[-3],
+ anchor_idx=int(match.groups()[-2]))
+
+ @staticmethod
+ def _gen_numpy_file_info(name, math, dir_path):
+ return FileDesc(name, dir_path)
+
def execute_command(self, cmd):
if not cmd:
self.log.error("Commond is None")
@@ -245,7 +245,11 @@ class Util:
elif data.size % align != 0:
pad_array = np.zeros((align - data.size % align,))
data = np.append(data, pad_array)
- np.savetxt(dst_file, data.reshape((-1, align)), delimiter=' ', fmt='%g')
+ check_path_before_create(dst_file)
+ try:
+ np.savetxt(dst_file, data.reshape((-1, align)), delimiter=' ', fmt='%g')
+ except Exception as e:
+ self.log.error("An unexpected error occurred: %s when savetxt to %s" % (str(e), dst_file))
change_mode(dst_file, FileCheckConst.DATA_FILE_AUTHORITY)
def list_convert_files(self, path, external_pattern=""):
diff --git a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/visualization.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/visualization.py
similarity index 94%
rename from debug/accuracy_tools/atat/pytorch/parse_tool/lib/visualization.py
rename to debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/visualization.py
index 3ef9878ae8213e5ece5813d8857544ce07b603d5..a10c7a447fd1cc5b18202bfb382190ed7a663ccc 100644
--- a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/visualization.py
+++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/visualization.py
@@ -17,10 +17,10 @@
import json
import numpy as np
-from atat.pytorch.parse_tool.lib.config import Const
-from atat.pytorch.parse_tool.lib.utils import Util
-from atat.pytorch.parse_tool.lib.parse_exception import ParseException
-from atat.core.common.file_check import FileOpen
+from msprobe.pytorch.parse_tool.lib.config import Const
+from msprobe.pytorch.parse_tool.lib.utils import Util
+from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException
+from msprobe.core.common.file_utils import FileOpen
class Visualization:
diff --git a/debug/accuracy_tools/atat/pytorch/pt_config.py b/debug/accuracy_tools/msprobe/pytorch/pt_config.py
similarity index 58%
rename from debug/accuracy_tools/atat/pytorch/pt_config.py
rename to debug/accuracy_tools/msprobe/pytorch/pt_config.py
index 0674b91b3410765ca7dcb9a6f38c6f03fa6e94f1..92145ee2c70fe8667726977236afecb4c14c4d3b 100644
--- a/debug/accuracy_tools/atat/pytorch/pt_config.py
+++ b/debug/accuracy_tools/msprobe/pytorch/pt_config.py
@@ -1,9 +1,10 @@
import json
import os
-from atat.core.common_config import CommonConfig, BaseConfig
-from atat.core.common.file_check import FileOpen
-from atat.core.common.const import Const
+from msprobe.core.common_config import CommonConfig, BaseConfig
+from msprobe.core.common.file_utils import FileOpen
+from msprobe.core.common.const import Const
+from msprobe.pytorch.hook_module.utils import WrapFunctionalOps, WrapTensorOps, WrapTorchOps
class TensorConfig(BaseConfig):
@@ -31,12 +32,12 @@ class StatisticsConfig(BaseConfig):
class OverflowCheckConfig(BaseConfig):
def __init__(self, json_config):
super().__init__(json_config)
- self.overflow_num = json_config.get("overflow_nums")
+ self.overflow_nums = json_config.get("overflow_nums")
self.check_mode = json_config.get("check_mode")
self.check_overflow_config()
def check_overflow_config(self):
- if self.overflow_num is not None and not isinstance(self.overflow_num, int):
+ if self.overflow_nums is not None and not isinstance(self.overflow_nums, int):
raise Exception("overflow_num is invalid")
if self.check_mode is not None and self.check_mode not in ["all", "aicore", "atomic"]:
raise Exception("check_mode is invalid")
@@ -61,20 +62,54 @@ class FreeBenchmarkCheckConfig(BaseConfig):
if self.preheat_step and self.preheat_step == 0:
raise Exception("preheat_step cannot be 0")
+
+class RunUTConfig(BaseConfig):
+ WrapApi = set(WrapFunctionalOps) | set(WrapTensorOps) | set(WrapTorchOps)
+ def __init__(self, json_config):
+ super().__init__(json_config)
+ self.white_list = json_config.get("white_list", Const.DEFAULT_LIST)
+ self.black_list = json_config.get("black_list", Const.DEFAULT_LIST)
+ self.error_data_path = json_config.get("error_data_path", Const.DEFAULT_PATH)
+ self.check_run_ut_config()
+
+ @classmethod
+ def check_filter_list_config(cls, key, filter_list):
+ if not isinstance(filter_list, list):
+ raise Exception("%s must be a list type" % key)
+ if not all(isinstance(item, str) for item in filter_list):
+ raise Exception("All elements in %s must be string type" % key)
+ invalid_api = [item for item in filter_list if item not in cls.WrapApi]
+ if invalid_api:
+ raise Exception("Invalid api in %s: %s" % (key, invalid_api))
+
+ @classmethod
+ def check_error_data_path_config(cls, error_data_path):
+ if not os.path.exists(error_data_path):
+ raise Exception("error_data_path: %s does not exist" % error_data_path)
+
+ def check_run_ut_config(self):
+ RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list)
+ RunUTConfig.check_filter_list_config(Const.BLACK_LIST, self.black_list)
+ RunUTConfig.check_error_data_path_config(self.error_data_path)
+
+
def parse_task_config(task, json_config):
default_dic = {}
if task == Const.TENSOR:
- config_dic = json_config.get(Const.TENSOR) if json_config.get(Const.TENSOR) else default_dic
+ config_dic = json_config.get(Const.TENSOR, default_dic)
return TensorConfig(config_dic)
elif task == Const.STATISTICS:
- config_dic = json_config.get(Const.STATISTICS) if json_config.get(Const.STATISTICS) else default_dic
+ config_dic = json_config.get(Const.STATISTICS, default_dic)
return StatisticsConfig(config_dic)
elif task == Const.OVERFLOW_CHECK:
- config_dic = json_config.get(Const.OVERFLOW_CHECK) if json_config.get(Const.OVERFLOW_CHECK) else default_dic
+ config_dic = json_config.get(Const.OVERFLOW_CHECK, default_dic)
return OverflowCheckConfig(config_dic)
elif task == Const.FREE_BENCHMARK:
- config_dic = json_config.get(Const.FREE_BENCHMARK) if json_config.get(Const.FREE_BENCHMARK) else default_dic
+ config_dic = json_config.get(Const.FREE_BENCHMARK, default_dic)
return FreeBenchmarkCheckConfig(config_dic)
+ elif task == Const.RUN_UT:
+ config_dic = json_config.get(Const.RUN_UT, default_dic)
+ return RunUTConfig(config_dic)
else:
return StatisticsConfig(default_dic)
diff --git a/debug/accuracy_tools/atat/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py
similarity index 77%
rename from debug/accuracy_tools/atat/pytorch/service.py
rename to debug/accuracy_tools/msprobe/pytorch/service.py
index d0b9c4d4b27bd13672e0529de6dd3349860e533f..3238d11b2b96689f0185d71d6c8f1c6a04e3f270 100644
--- a/debug/accuracy_tools/atat/pytorch/service.py
+++ b/debug/accuracy_tools/msprobe/pytorch/service.py
@@ -1,18 +1,23 @@
import functools
import os
from pathlib import Path
-
-from atat.pytorch.common.log import logger
-from atat.core.common.file_check import FileChecker, check_path_before_create
-from atat.core.common.const import Const, FileCheckConst
-from atat.core.common.exceptions import DistributedNotInitializedError, MsaccException
-from atat.core.data_dump.data_collector import build_data_collector
-from atat.core.data_dump.scope import BaseScope
-from atat.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
-from atat.pytorch.common.utils import get_rank_if_initialized
-from atat.pytorch.module_processer import ModuleProcesser
-from atat.pytorch.hook_module import remove_dropout
-from atat.pytorch.hook_module.api_registry import api_register
+import torch
+from packaging import version
+from msprobe.core.common.const import Const, FileCheckConst
+from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
+from msprobe.core.common.file_utils import FileChecker, check_path_before_create
+from msprobe.core.data_dump.data_collector import build_data_collector
+from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
+from msprobe.core.data_dump.scope import BaseScope
+from msprobe.pytorch.common.log import logger
+from msprobe.pytorch.common.utils import get_rank_if_initialized
+from msprobe.pytorch.hook_module import remove_dropout
+from msprobe.pytorch.hook_module.api_registry import api_register
+from msprobe.pytorch.hook_module.hook_module import HOOKModule
+from msprobe.pytorch.module_processer import ModuleProcesser
+
+if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
class Service:
@@ -27,6 +32,11 @@ class Service:
self.current_rank = None
self.dump_iter_dir = None
+ @staticmethod
+ def forward_backward_dump_end():
+ logger.info_on_rank_0("Data collection ends here.")
+ api_register.api_originality()
+
def build_hook(self, module_type, name):
def pre_hook(api_or_module_name, module, args, kwargs):
if module_type == BaseScope.Module_Type_Module:
@@ -62,7 +72,8 @@ class Service:
if not self.switch:
return
if self.data_collector:
- module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output)
+ # The grad_input obtained here is actually the output data of the backward pass, and grad_output is its input data, so the two are swapped when passed on
+ module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
self.data_collector.backward_data_collect(api_or_module_name, module, pid, module_input_output)
pid = os.getpid()
@@ -73,15 +84,30 @@ class Service:
backward_hook = functools.partial(backward_hook, backward_name_template)
return pre_forward_hook, forward_hook, backward_hook
+ def hook_optimizer(self, model):
+ def optimizer_pre_step_hook(optimizer, args, kwargs):
+ self.stop()
+ self.step()
+
+ def optimizer_post_step_hook(optimizer, args, kwargs):
+ self.start(model)
+
+
+ register_optimizer_step_pre_hook(optimizer_pre_step_hook)
+ register_optimizer_step_post_hook(optimizer_post_step_hook)
+
def step(self):
self.current_iter += 1
self.data_collector.update_iter(self.current_iter)
- def start(self, model):
+ ModuleProcesser.reset_module_stats()
+ HOOKModule.reset_module_stats()
+
+ def start(self, model, api_origin=False):
self.model = model
if self.config.step and self.current_iter > max(self.config.step):
self.stop()
- raise Exception("atat: exit after iteration {}".format(max(self.config.step)))
+ raise Exception("msprobe: exit after iteration {}".format(max(self.config.step)))
if self.config.step and self.current_iter not in self.config.step:
return
if self.first_start:
@@ -94,6 +120,8 @@ class Service:
return
self.register_hook_new()
self.first_start = False
+ if api_origin:
+ api_register.api_modularity()
self.switch = True
logger.info_on_rank_0(f"Dump switch is turned on at step {self.current_iter}. ")
if self.config.level != "L2":
@@ -138,7 +166,7 @@ class Service:
logger.info_on_rank_0("The {} hook function is successfully mounted to the model.".format(self.config.task))
if self.config.level in ["L0", "mix"]:
if self.model is None:
- logger.error_log_with_exp("The model is None.", MsaccException.INVALID_PARAM_ERROR)
+ logger.error_log_with_exp("The model is None.", MsprobeException.INVALID_PARAM_ERROR)
logger.info_on_rank_0("The init dump mode is enabled, and the module dump function will not be available")
for name, module in self.model.named_modules():
if module == self.model:
@@ -164,4 +192,4 @@ class Service:
api_register.api_modularity()
if Const.STATISTICS == self.config.task or Const.TENSOR == self.config.task:
- remove_dropout()
+ remove_dropout()
\ No newline at end of file
diff --git a/debug/accuracy_tools/msprobe/pytorch/torchair_compare/msit_path_add.py b/debug/accuracy_tools/msprobe/pytorch/torchair_compare/msit_path_add.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe49958cd8752bd45d6c5f9c198a49ea1c0996ba
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/torchair_compare/msit_path_add.py
@@ -0,0 +1,35 @@
+import os
+import site
+from msprobe.core.common.file_utils import change_mode
+from msprobe.core.common.const import FileCheckConst
+from msprobe.core.common.log import logger
+from msprobe.core.common.utils import PathAddException
+
+
+def create_symlink():
+ create_single_symlink(link_rel_path="msit_llm", source_rel_path="msprobe/msit/msit/components/llm/msit_llm")
+ create_single_symlink(link_rel_path="components", source_rel_path="msprobe/msit/msit/components")
+ logger.info("torchair compare path added successfully.")
+
+
+def create_single_symlink(link_rel_path, source_rel_path):
+ site_package_dir = site.getsitepackages()[0]
+ link_path = os.path.join(site_package_dir, link_rel_path)
+ if os.path.exists(link_path) or os.path.islink(link_path):
+ logger.warning("The msit and msit_llm packages are already installed. If you need accuracy comparison "
+ "capabilities in the torchair training scenario, please ensure that versions of the packages "
+ "support this feature or update to the latest version.")
+ raise PathAddException(PathAddException.PACKAGE_VERSION_CHECK)
+ try:
+ source_path = os.path.join(site_package_dir, source_rel_path)
+ os.symlink(source_path, link_path)
+ change_mode(link_path, FileCheckConst.DATA_DIR_AUTHORITY)
+ except FileNotFoundError as e:
+ logger.error("Source folder does not exist; cannot create the symlink.")
+ raise PathAddException(PathAddException.INVALID_FILE_ERROR) from e
+ except PermissionError as e:
+ logger.error("Permission denied when trying to create the symlink.")
+ raise PathAddException(PathAddException.PERMISSION_DENIED_ERROR) from e
+ except Exception as e:
+ logger.error(f"An unexpected error occurred: {e}")
+ raise PathAddException(PathAddException.UNEXPECTED_ERROR) from e
diff --git a/debug/accuracy_tools/msprobe/pytorch/torchair_compare/torchair_compare_cli.py b/debug/accuracy_tools/msprobe/pytorch/torchair_compare/torchair_compare_cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..c46e3f04456db0f63eb7c19ba94a5252f99f26b2
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/torchair_compare/torchair_compare_cli.py
@@ -0,0 +1,14 @@
+from msit_llm.compare.torchair_acc_cmp import acc_compare
+
+
+def torchair_compare_parser(parser):
+    parser.add_argument("--my-path", dest="my_path", type=str,
+                        help="The graph compare input path.", required=True)
+    parser.add_argument("--golden-path", dest="golden_path", type=str,
+                        help="The graph compare golden input path.", required=True)
+    parser.add_argument("--output-path", dest="output_path", type=str, default=".",
+                        help="The graph compare output path.", required=False)
+
+
+def torchair_compare_cli(args):
+ acc_compare(args.golden_path, args.my_path, args.output_path)
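As a quick sanity check of how these three flags parse, here is a standalone argparse sketch assuming the same flag names, dests, and defaults as `torchair_compare_parser` above (the `prog` name is illustrative):

```python
import argparse

def build_parser():
    # Mirrors the flags registered by torchair_compare_parser (assumption:
    # same names/defaults; the prog name is a placeholder).
    parser = argparse.ArgumentParser(prog="torchair_compare")
    parser.add_argument("--my-path", dest="my_path", type=str, required=True,
                        help="The graph compare input path.")
    parser.add_argument("--golden-path", dest="golden_path", type=str, required=True,
                        help="The graph compare golden input path.")
    parser.add_argument("--output-path", dest="output_path", type=str, default=".",
                        required=False, help="The graph compare output path.")
    return parser

# --output-path is optional and falls back to the current directory
args = build_parser().parse_args(["--my-path", "npu_dump", "--golden-path", "cpu_dump"])
```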
diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/__init__.py b/debug/accuracy_tools/msprobe/pytorch/visualization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/builder/__init__.py b/debug/accuracy_tools/msprobe/pytorch/visualization/builder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/builder/graph_builder.py b/debug/accuracy_tools/msprobe/pytorch/visualization/builder/graph_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..675a87a412cf078e940441e53512156e8082476e
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/visualization/builder/graph_builder.py
@@ -0,0 +1,152 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from msprobe.pytorch.visualization.graph.graph import Graph, BaseNode
+from msprobe.pytorch.visualization.graph.node_op import NodeOp
+from msprobe.pytorch.visualization.utils import load_json_file, load_data_json_file, save_json_file, GraphConst
+from msprobe.pytorch.visualization.builder.msprobe_adapter import get_input_output
+
+
+class GraphBuilder:
+ @staticmethod
+ def build(construct_path, data_path, model_name='DefaultModel'):
+ """
+        Public graph-building entry point of GraphBuilder.
+        Args:
+            construct_path: path to construct.json
+            data_path: path to dump.json
+            model_name: model name, supplied by the caller
+        Returns: a Graph object representing the model structure
+ """
+ construct_dict = load_json_file(construct_path)
+ data_dict = load_data_json_file(data_path)
+ graph = Graph(model_name)
+ GraphBuilder._init_nodes(graph, construct_dict, data_dict)
+ GraphBuilder._collect_apis_between_modules(graph)
+ return graph
+
+ @staticmethod
+ def to_json(filename, config):
+ """
+        Export the graph to a .vis file.
+ """
+ result = {}
+ if config.graph_b:
+ result[GraphConst.JSON_NPU_KEY] = config.graph_n.to_dict()
+ result[GraphConst.JSON_BENCH_KEY] = config.graph_b.to_dict()
+ else:
+ result = config.graph_n.to_dict()
+ if config.tool_tip:
+ result[GraphConst.JSON_TIP_KEY] = config.tool_tip
+ if config.node_colors:
+ result[GraphConst.COLORS] = config.node_colors
+ if config.micro_steps:
+ result[GraphConst.MICRO_STEPS] = config.micro_steps
+ save_json_file(filename, result)
+
+ @staticmethod
+ def _handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id):
+ """
+        If a backward node's parent is null, try to find the parent via the forward node of the same name.
+ """
+        # Match names ending with .backward. followed by one or more digits
+ backward_pattern = r"(\.backward\.)(\d+)$"
+ forward_pattern = r"(\.forward\.)(\d+)$"
+ if re.search(backward_pattern, subnode_id) and not upnode_id:
+ forward_upnode_id = construct_dict.get(re.sub(backward_pattern, r".forward.\2", subnode_id))
+ if forward_upnode_id:
+ new_upnode_id = re.sub(forward_pattern, r".backward.\2", forward_upnode_id)
+ if new_upnode_id in construct_dict:
+ return new_upnode_id
+ return upnode_id
+
+ @staticmethod
+ def _init_nodes(graph, construct_dict, data_dict):
+ for subnode_id, upnode_id in construct_dict.items():
+ upnode_id = GraphBuilder._handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id)
+ if upnode_id:
+ upnode_op = NodeOp.get_node_op(upnode_id)
+ upnode = GraphBuilder._create_or_get_node(graph, data_dict, upnode_op, upnode_id)
+ else:
+ upnode = graph.root
+ node_op = NodeOp.get_node_op(subnode_id)
+ GraphBuilder._create_or_get_node(graph, data_dict, node_op, subnode_id, upnode)
+
+ @staticmethod
+ def _create_or_get_node(graph, data_dict, op, name, upnode=None):
+ if name in graph.node_map:
+ node = graph.get_node(name)
+ else:
+ graph.add_node(op, name, upnode)
+ node = graph.get_node(name)
+        node_data = data_dict.get(name, {})
+        # Parse the input and output data from the dump
+        input_data, output_data = get_input_output(node_data, node.id)
+        # Update the node with the parsed data
+        node.set_input_output(input_data, output_data)
+        # Bind the parent node
+        node.add_upnode(upnode)
+ return node
+
+ @staticmethod
+ def _collect_apis_between_modules(graph):
+ """
+        When the graph is first expanded, its top-level nodes contain many modules and APIs. A large number of APIs stretches the graph and severely hurts readability, so consecutive APIs between modules are collected into a single node.
+        Args:
+            graph: the model structure
+
+        Returns: None
+ """
+ i = 0
+ output = []
+ node_list = graph.root.subnodes
+ while i < len(node_list):
+ current_node = node_list[i]
+
+            # The current node is an API; check whether more APIs follow
+ if current_node.op == NodeOp.function_api:
+ temp_nodes = [current_node]
+ i += 1
+ while i < len(node_list) and node_list[i].op == NodeOp.function_api:
+ temp_nodes.append(node_list[i])
+ i += 1
+
+                # Check whether there are at least two API nodes
+ if len(temp_nodes) >= 2:
+                    # Create a new node and put these API nodes into its subnodes
+ node_id = graph.add_node(NodeOp.api_collection, GraphConst.APIS_BETWEEN_MODULES,
+ id_accumulation=True)
+ api_collection_node = graph.get_node(node_id)
+ api_collection_node.subnodes = temp_nodes
+ output.append(api_collection_node)
+ else:
+                    # Fewer than two consecutive API nodes: append them to the output unchanged
+ output.extend(temp_nodes)
+ else:
+                # The current node is a module: append it to the output directly
+ output.append(current_node)
+ i += 1
+
+ graph.root.subnodes = output
+
+
+class GraphExportConfig:
+ def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None):
+ self.graph_n = graph_n
+ self.graph_b = graph_b
+ self.tool_tip = tool_tip
+ self.node_colors = node_colors
+ self.micro_steps = micro_steps
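The run-collecting loop in `_collect_apis_between_modules` can be illustrated with a hypothetical standalone sketch that works on plain strings instead of graph nodes (the names and the `min_run` threshold here are illustrative, not part of the real API):

```python
def collect_runs(kinds, min_run=2):
    # Collapse runs of >= min_run consecutive "api" entries into one list,
    # mirroring how consecutive API nodes are grouped under a collection node;
    # modules and short runs pass through unchanged.
    output, i = [], 0
    while i < len(kinds):
        if kinds[i] == "api":
            j = i
            while j < len(kinds) and kinds[j] == "api":
                j += 1
            run = kinds[i:j]
            output.append(run if len(run) >= min_run else run[0])
            i = j
        else:
            output.append(kinds[i])
            i += 1
    return output
```

A single trailing API is kept as-is, exactly like the `len(temp_nodes) >= 2` branch above.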
diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/builder/msprobe_adapter.py b/debug/accuracy_tools/msprobe/pytorch/visualization/builder/msprobe_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..92af6d67325c2b2ff2b4dbd34bf799b3219445de
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/visualization/builder/msprobe_adapter.py
@@ -0,0 +1,210 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from msprobe.pytorch.compare.acc_compare import read_op, merge_tensor, get_accuracy, _do_multi_process
+from msprobe.core.common.utils import task_dumppath_get
+from msprobe.pytorch.visualization.utils import GraphConst
+
+
+# Rules for mapping a node name to its NodeOp
+op_patterns = [
+    r'^(Module)',  # NodeOp.module
+    r'^(Tensor|Torch|Functional|NPU|VF|Distributed|Aten)'  # NodeOp.function_api
+]
+
+
+def get_compare_mode(dump_path_param):
+ """
+    Determine the compare mode: summary, MD5, or real data.
+    Args:
+        dump_path_param: parameters required by the acc_compare interface
+    Returns: 0 summary mode, 1 md5 mode, 2 real data mode
+ """
+ summary_compare, md5_compare = task_dumppath_get(dump_path_param)
+ if summary_compare:
+ compare_mode = GraphConst.SUMMARY_COMPARE
+ elif md5_compare:
+ compare_mode = GraphConst.MD5_COMPARE
+ else:
+ compare_mode = GraphConst.REAL_DATA_COMPARE
+ return compare_mode
+
+
+def run_real_data(dump_path_param, csv_path):
+ """
+    Run the real-data comparison with multiple processes.
+    Args:
+        dump_path_param: parameters required by the acc_compare interface
+        csv_path: path of the generated file
+ """
+ return _do_multi_process(dump_path_param, csv_path)
+
+
+def get_input_output(node_data, node_id):
+ """
+    Split the raw dump data of a node into input and output parts.
+    Args:
+        node_data: dump data belonging to a single node
+        node_id: node name
+ """
+ input_data = {}
+ output_data = {}
+ op_parsed_list = read_op(node_data, node_id)
+ for item in op_parsed_list:
+ full_op_name = item.get('full_op_name', '')
+ if not full_op_name:
+ continue
+ splits = full_op_name.split('.')
+ if len(splits) < GraphConst.OUTPUT_MIN_LEN:
+ continue
+ if GraphConst.OUTPUT in splits[GraphConst.OUTPUT_INDEX_TWO] and \
+ GraphConst.INPUT not in splits[GraphConst.OUTPUT_INDEX_THREE]:
+ output_data[full_op_name] = item
+ else:
+ input_data[full_op_name] = item
+ return input_data, output_data
+
+
+def compare_data(data_dict_list1, data_dict_list2):
+ """
+    Check whether two results of get_input_output are structurally identical; return True if they match.
+ """
+ if len(data_dict_list1) != len(data_dict_list2):
+ return False
+    # Key fields used to decide whether two nodes are equal
+ tag_keys = ['type', 'dtype', 'shape']
+ for key1, key2 in zip(data_dict_list1, data_dict_list2):
+ dict1 = data_dict_list1[key1]
+ dict2 = data_dict_list2[key2]
+ for tag_key in tag_keys:
+ tag_value1 = dict1.get(tag_key, None)
+ tag_value2 = dict2.get(tag_key, None)
+ if tag_value1 != tag_value2:
+ return False
+ return True
+
+
+def compare_mapping_data(data_dict_list1, data_dict_list2):
+ """
+    When node1 is mapped to node2, node1 may have more or fewer parameters than node2, some parameters may list their shape dimensions in a different order, and a null parameter on node1 may correspond to a non-null value on node2.
+    To keep the nodes comparable, only a weak check is performed: the shape dimension values of the parameters must match, regardless of order.
+ """
+ for x, y in zip(data_dict_list1.values(), data_dict_list2.values()):
+ x_shape = x.get('shape')
+ y_shape = y.get('shape')
+ if x_shape is None or y_shape is None:
+ continue
+ x_shape = sorted(x_shape) if isinstance(x_shape, list) else x_shape
+ y_shape = sorted(y_shape) if isinstance(y_shape, list) else y_shape
+ if x_shape != y_shape:
+ return False
+ return True
+
+
+def format_node_data(data_dict):
+ """
+    Format node data in batch for output.
+ """
+ del_list = ['requires_grad', 'data_name', 'full_op_name']
+ for _, value in data_dict.items():
+ if not isinstance(value, dict):
+ continue
+ for item in del_list:
+ if item in value:
+ del value[item]
+ _format_data(value)
+ return data_dict
+
+
+def compare_node(node_ids, data_dicts, stack_json_data, is_summary_compare, is_md5_compare):
+ """
+    Call get_accuracy in acc_compare.py to obtain the accuracy compare metrics.
+    Real-data compare mode cannot produce the metrics here; the multi-process compare interface must be called instead.
+    Returns: a list containing parameter info and compare metrics (except in real-data compare mode)
+ """
+ merge_n = _parse_node(node_ids[0], data_dicts[0], stack_json_data, is_summary_compare, is_md5_compare)
+ merge_b = _parse_node(node_ids[1], data_dicts[1], stack_json_data, is_summary_compare, is_md5_compare)
+ result = []
+ get_accuracy(result, merge_n, merge_b, is_summary_compare, is_md5_compare)
+ return result
+
+
+def _parse_node(node_id, data_dict, stack_json_data, is_summary_compare, is_md5_compare):
+ """
+    Convert a node so that it can be passed to get_accuracy in acc_compare.py.
+ """
+ op_parsed_list = read_op(data_dict.get(node_id, {}), node_id)
+ if node_id in stack_json_data:
+ op_parsed_list.append(
+ {'full_op_name': node_id, 'full_info': stack_json_data[node_id]})
+ else:
+ op_parsed_list.append({'full_op_name': node_id, 'full_info': None})
+ result = merge_tensor(op_parsed_list, is_summary_compare, is_md5_compare)
+ if not result:
+ result['op_name'] = []
+ return result
+
+
+def _format_decimal_string(s):
+ """
+    Use a regex to match substrings consisting of digits, a decimal point, and an optional percent sign.
+ """
+ pattern = re.compile(r'\d{1,20}\.\d{1,20}%?')
+ matches = pattern.findall(s)
+ for match in matches:
+ is_percent = match.endswith('%')
+ number_str = match.rstrip('%')
+ decimal_part = number_str.split('.')[1]
+        # Round if there are more than ROUND_TH decimal places
+ if len(decimal_part) > GraphConst.ROUND_TH:
+ number_float = float(number_str)
+ formatted_number = f"{number_float:.{GraphConst.ROUND_TH}f}"
+            # Restore the percent sign if the original value was a percentage
+ if is_percent:
+ formatted_number += '%'
+            # Replace the numeric part in the original string
+ s = s.replace(match, formatted_number)
+ return s
+
+
+def _format_data(data_dict):
+ """
+    Format values: keep 6 decimal places and handle some abnormal values.
+ """
+ pattern = r'^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)$'
+ none_num = 0
+ for key, value in data_dict.items():
+ if isinstance(value, str):
+            # Strip single quotes and replace None with null to avoid front-end parse errors
+ value = value.replace("'", "").replace(GraphConst.NONE, GraphConst.NULL)
+ value = _format_decimal_string(value)
+ elif value is None or value == ' ':
+ value = GraphConst.NULL
+        # Scientific notation such as 1.123123123123e-11 is formatted as 1.123123e-11
+ elif isinstance(value, float) and len(str(value)) < GraphConst.STR_MAX_LEN and re.match(pattern, str(value)):
+ value = "{:.6e}".format(value)
+ elif isinstance(value, float):
+ value = round(value, GraphConst.ROUND_TH)
+        # Inf reaches this branch and is converted to a string; this also acts as a fallback for other unexpected types
+ if not isinstance(value, (list, tuple, dict, str)):
+ value = str(value)
+ if value == GraphConst.NULL or key == GraphConst.ERROR_KEY:
+ none_num += 1
+ data_dict[key] = value
+    # If every value in the dict is null, keep only a single null
+ if none_num == len(data_dict):
+ data_dict.clear()
+ data_dict[GraphConst.VALUE] = GraphConst.NULL
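The decimal-trimming rule applied by `_format_decimal_string` can be sketched on its own. This is a minimal stand-alone version assuming a rounding threshold of 6, matching the comment above; `ROUND_TH` here is a local stand-in, not the real `GraphConst.ROUND_TH`:

```python
import re

ROUND_TH = 6  # assumption: mirrors GraphConst.ROUND_TH

def trim_decimals(s):
    # Find numbers (optionally ending in %) and round any with more than
    # ROUND_TH decimal places, preserving the percent sign.
    for match in re.findall(r'\d{1,20}\.\d{1,20}%?', s):
        number = match.rstrip('%')
        if len(number.split('.')[1]) > ROUND_TH:
            formatted = f"{float(number):.{ROUND_TH}f}"
            if match.endswith('%'):
                formatted += '%'
            s = s.replace(match, formatted)
    return s
```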
diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/__init__.py b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5ec6c8db4be6431352570cfeb3644a7da816f35
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py
@@ -0,0 +1,131 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from msprobe.pytorch.visualization.builder.msprobe_adapter import compare_node, get_compare_mode, run_real_data
+from msprobe.pytorch.visualization.utils import GraphConst, load_json_file, load_data_json_file, get_csv_df
+from msprobe.pytorch.visualization.graph.graph import Graph, NodeOp
+from msprobe.pytorch.visualization.graph.node_colors import NodeColors
+from msprobe.pytorch.visualization.compare.mode_adapter import ModeAdapter
+
+
+class GraphComparator:
+ def __init__(self, graphs, data_paths, stack_path, output_path, mapping_config=None):
+ self.graph_n = graphs[0]
+ self.graph_b = graphs[1]
+ self._parse_param(data_paths, stack_path, output_path)
+ self.mapping_config = mapping_config
+
+ def compare(self):
+ """
+        Compare entry point; call it separately after initialization. Results are written into graph_n.
+ """
+ self._compare_nodes(self.graph_n.root)
+ self._postcompare()
+
+ def add_compare_result_to_node(self, node, compare_result_list):
+ """
+        Attach the compare results to the node's input and output data.
+        Args:
+            node: the node
+            compare_result_list: a list containing parameter info and compare metrics (except in real-data compare mode)
+ """
+        # For real-data compare, stash the node first; the metrics are attached after the multi-process compare produces them
+ if self.ma.prepare_real_data(node):
+ return
+ compare_in_dict = {}
+ compare_out_dict = {}
+        # Separate the input and output compare data
+ for item in compare_result_list:
+ if not node.stack_info and node.id in item[0]:
+ node.stack_info = item[-1]
+ if 'output' in item[0]:
+ compare_out_dict[item[0]] = item
+ else:
+ compare_in_dict[item[0]] = item
+ precision_index, other_dict = (
+ self.ma.parse_result(node, [compare_in_dict, compare_out_dict]))
+ node.data[GraphConst.JSON_INDEX_KEY] = precision_index
+ node.data.update(other_dict)
+ if NodeColors.get_node_error_status(self.ma.compare_mode, precision_index):
+ self.ma.add_error_key(node.output_data)
+ node.get_suggestions()
+
+ def _parse_param(self, data_paths, stack_path, output_path):
+ self.dump_path_param = {
+ 'npu_json_path': data_paths[0],
+ 'bench_json_path': data_paths[1],
+ 'stack_json_path': stack_path,
+ 'is_print_compare_log': True
+ }
+ self.output_path = output_path
+ compare_mode = get_compare_mode(self.dump_path_param)
+ self.ma = ModeAdapter(compare_mode)
+ self.data_n_dict = load_data_json_file(data_paths[0])
+ self.data_b_dict = load_data_json_file(data_paths[1])
+ self.stack_json_data = load_json_file(stack_path)
+
+ def _postcompare(self):
+ self._handle_api_collection_index()
+ if not self.ma.is_real_data_compare():
+ return
+ df = get_csv_df(self.ma.is_md5_compare(), self.ma.is_summary_compare(), True, self.ma.csv_data)
+ df = run_real_data(self.dump_path_param, df)
+ compare_data_dict = {row[0]: row.tolist() for _, row in df.iterrows()}
+ for node in self.ma.compare_nodes:
+ precision_index, _ = self.ma.parse_result(node, [compare_data_dict])
+ node.data[GraphConst.JSON_INDEX_KEY] = precision_index
+ if NodeColors.get_node_error_status(self.ma.compare_mode, precision_index):
+ self.ma.add_error_key(node.output_data)
+ node.get_suggestions()
+
+ def _handle_api_collection_index(self):
+ """
+        An API collection uses the minimum metric among all APIs in the collection.
+ """
+ for node in self.graph_n.root.subnodes:
+ if node.op == NodeOp.api_collection:
+ precision_index = 1
+ for api in node.subnodes:
+ precision_index = min(precision_index, api.data.get(GraphConst.JSON_INDEX_KEY, 1))
+ node.data[GraphConst.JSON_INDEX_KEY] = precision_index
+
+ def _compare_nodes(self, node_n):
+ """
+        Recursively traverse the nodes of the NPU tree. If a node with the same name exists in the bench graph, check their ancestors and parameter info; if they match, run the accuracy comparison.
+        Pre-order traversal is used: when a node is compared, its predecessors have already been matched, which provides important information for later fuzzy matching.
+ """
+ if self.mapping_config:
+ node_b, ancestors_n, ancestors_b = Graph.mapping_match(node_n, self.graph_b, self.mapping_config)
+ if node_b:
+ ancestors_n.append(node_n.id)
+ ancestors_b.append(node_b.id)
+ node_n.matched_node_link = ancestors_b
+ node_b.matched_node_link = ancestors_n
+ else:
+ node_b, ancestors = Graph.match(self.graph_n, node_n, self.graph_b)
+ if node_b:
+ ancestors.append(node_b.id)
+ node_n.add_link(node_b, ancestors)
+ if node_b:
+            # Real-data compare only yields basic info here, without accuracy metrics; the multi-process compare interface must be called
+ compare_result_list = compare_node([node_n.id, node_b.id],
+ [self.data_n_dict, self.data_b_dict],
+ self.stack_json_data, self.ma.is_summary_compare(),
+ self.ma.is_md5_compare())
+ if compare_result_list:
+ self.ma.add_csv_data(compare_result_list)
+ self.add_compare_result_to_node(node_n, compare_result_list)
+ for subnode in node_n.subnodes:
+ self._compare_nodes(subnode)
diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/mode_adapter.py b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/mode_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fb16c8a789fc98522cfc4e62e8563a8c3a5a8bd
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/mode_adapter.py
@@ -0,0 +1,199 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from msprobe.core.common.const import CompareConst, Const
+from msprobe.pytorch.visualization.utils import ToolTip, GraphConst, str2float
+
+
+class ModeAdapter:
+ def __init__(self, compare_mode):
+ self.compare_mode = compare_mode
+ self.csv_data = []
+ self.compare_nodes = []
+
+ @staticmethod
+ def _add_md5_compare_data(node_data, compare_data_dict):
+ precision_index = GraphConst.MIN_INDEX_KEY
+ for key, value in node_data.items():
+ if not isinstance(value, dict):
+ continue
+ compare_data = compare_data_dict.get(key)
+ if compare_data:
+ headers = CompareConst.MD5_COMPARE_RESULT_HEADER
+ id_list = [headers.index(x) for x in GraphConst.MD5_INDEX_LIST]
+ ModeAdapter._match_data(value, compare_data, GraphConst.MD5_INDEX_LIST, id_list)
+                # Whether the md5 comparison passed
+ if value.get(CompareConst.RESULT) != CompareConst.PASS:
+ precision_index = GraphConst.MAX_INDEX_KEY
+ node_data[key] = value
+ return precision_index
+
+ @staticmethod
+ def _add_real_compare_data(node_data, compare_data_dict):
+ min_thousandth = float(1)
+ numbers = []
+ for key, value in node_data.items():
+ if not isinstance(value, dict):
+ continue
+ compare_data = compare_data_dict.get(key)
+ if compare_data:
+ headers = CompareConst.COMPARE_RESULT_HEADER
+ id_list = [headers.index(x) for x in GraphConst.REAL_DATA_INDEX_LIST]
+ ModeAdapter._match_data(value, compare_data, GraphConst.REAL_DATA_INDEX_LIST, id_list)
+                # Get the minimum one-thousandth error ratio among all inputs or outputs of the node
+ thousandth = value.get(CompareConst.ONE_THOUSANDTH_ERR_RATIO)
+                # May be None or a non-numeric string
+ try:
+ thousandth = float(thousandth)
+ except (ValueError, TypeError):
+ thousandth = None
+ if thousandth is not None:
+ numbers.append(thousandth)
+ node_data[key] = value
+        # Abnormal case: every one-thousandth error ratio is None
+ if not numbers:
+ min_thousandth = None
+ else:
+ min_thousandth = min(numbers + [min_thousandth])
+ return min_thousandth
+
+ @staticmethod
+    def _add_summary_compare_data(node_data, compare_data_dict):
+ max_relative_err = 0
+ for key, value in node_data.items():
+ if not isinstance(value, dict):
+ continue
+ compare_data = compare_data_dict.get(key)
+ if compare_data:
+                # Columns of the compare-result CSV
+ key_list = GraphConst.SUMMARY_INDEX_LIST
+ headers = CompareConst.SUMMARY_COMPARE_RESULT_HEADER
+ id_list = [headers.index(x) for x in key_list]
+ ModeAdapter._match_data(value, compare_data, key_list, id_list)
+                # A relative error above 0.5 suggests an accuracy issue; values below 1e-3 skip the relative-error comparison
+ for index, item in enumerate(key_list[4:]):
+ value_diff = value.get(key_list[index])
+ if isinstance(value_diff, float) and value_diff != 0 and abs(value_diff) < GraphConst.SMALL_VALUE:
+ value[item] = ToolTip.SMALL_VALUE_TIP.format(key_list[index])
+ continue
+ relative_err = str2float(value.get(item))
+ max_relative_err = max(max_relative_err, relative_err)
+ node_data[key] = value
+ max_relative_err = 1 if max_relative_err > 1 else max_relative_err
+ return max_relative_err
+
+ @staticmethod
+ def _match_data(data_dict, compare_data, key_list, id_list):
+ """
+        Bind the accuracy metrics to the node's input_data and output_data.
+ """
+ if len(key_list) != len(id_list):
+ return
+        for idx, key in zip(id_list, key_list):
+            data = compare_data[idx]
+ if data is not None and 'nan' not in str(data) and str(data) != ' ':
+ data_dict[key] = data
+ else:
+ data_dict[key] = 'null'
+
+ def parse_result(self, node, compare_data_dict):
+ """
+        Parse the results and return the precision_index together with any additional data.
+ """
+ other_dict = {}
+ if self.is_md5_compare():
+ precision_index_in = ModeAdapter._add_md5_compare_data(node.input_data, compare_data_dict[0])
+ precision_index_out = ModeAdapter._add_md5_compare_data(node.output_data, compare_data_dict[1])
+            # The node passes only if every input and output md5 comparison passes
+ precision_index = max(precision_index_in, precision_index_out)
+ other_result = CompareConst.PASS if precision_index == 1 else CompareConst.DIFF
+ other_dict[CompareConst.RESULT] = other_result
+ elif self.is_summary_compare():
+ precision_index_in = ModeAdapter._add_summary_compare_data(node.input_data, compare_data_dict[0])
+ precision_index_out = ModeAdapter._add_summary_compare_data(node.output_data, compare_data_dict[1])
+ precision_index = max(precision_index_in, precision_index_out)
+ else:
+ min_thousandth_in = ModeAdapter._add_real_compare_data(node.input_data, compare_data_dict[0])
+ min_thousandth_out = ModeAdapter._add_real_compare_data(node.output_data, compare_data_dict[0])
+ if min_thousandth_in is not None and min_thousandth_out is not None:
+ change_percentage = abs(min_thousandth_in - min_thousandth_out)
+ else:
+ change_percentage = 0
+ precision_index = GraphConst.MAX_INDEX_KEY \
+ if change_percentage > GraphConst.MAX_INDEX_KEY else change_percentage
+ return precision_index, other_dict
+
+ def prepare_real_data(self, node):
+ """
+        Stash node info for real-data compare mode.
+ """
+ if self.is_real_data_compare():
+ self.compare_nodes.append(node)
+ return True
+ return False
+
+ def is_summary_compare(self):
+ return self.compare_mode == GraphConst.SUMMARY_COMPARE
+
+ def is_md5_compare(self):
+ return self.compare_mode == GraphConst.MD5_COMPARE
+
+ def is_real_data_compare(self):
+ return self.compare_mode == GraphConst.REAL_DATA_COMPARE
+
+ def add_csv_data(self, compare_result_list):
+ if not self.is_real_data_compare():
+ return
+ self.csv_data.extend(compare_result_list)
+
+ def add_error_key(self, node_data):
+ """
+        Provide different error keys depending on the compare mode.
+ """
+ for key, value in node_data.items():
+ if not isinstance(value, dict):
+ continue
+ if self.is_summary_compare():
+ message = [CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR,
+ CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR]
+ elif self.is_real_data_compare():
+ message = [CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO]
+ else:
+            # Output optimization: no error keys in md5 mode
+ message = []
+ value[GraphConst.ERROR_KEY] = message
+ node_data[key] = value
+
+ def get_tool_tip(self):
+ """
+        Explain the meaning of the fields displayed in the front end.
+ """
+ if self.is_summary_compare():
+ tips = {
+ CompareConst.MAX_DIFF: ToolTip.MAX_DIFF,
+ CompareConst.MIN_DIFF: ToolTip.MIN_DIFF,
+ CompareConst.MEAN_DIFF: ToolTip.MEAN_DIFF,
+ CompareConst.NORM_DIFF: ToolTip.NORM_DIFF}
+ elif self.is_md5_compare():
+ tips = {Const.MD5: ToolTip.MD5}
+ else:
+ tips = {
+ CompareConst.ONE_THOUSANDTH_ERR_RATIO: ToolTip.ONE_THOUSANDTH_ERR_RATIO,
+ CompareConst.FIVE_THOUSANDTHS_ERR_RATIO: ToolTip.FIVE_THOUSANDTHS_ERR_RATIO,
+ CompareConst.COSINE: ToolTip.COSINE,
+ CompareConst.MAX_ABS_ERR: ToolTip.MAX_ABS_ERR,
+ CompareConst.MAX_RELATIVE_ERR: ToolTip.MAX_RELATIVE_ERR}
+ return json.dumps(tips)
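The column-binding pattern used by `ModeAdapter._match_data` — resolving header indices once, then copying selected cells with `'null'` substituted for missing or NaN values — can be sketched independently. Header and field names below are illustrative, not the real `CompareConst` columns:

```python
def bind_columns(data_dict, row, headers, wanted):
    # Resolve the index of each wanted column in the header row, then copy
    # the corresponding cells; unusable cells become the string 'null',
    # mirroring ModeAdapter._match_data.
    id_list = [headers.index(name) for name in wanted]
    for idx, name in zip(id_list, wanted):
        cell = row[idx]
        if cell is not None and 'nan' not in str(cell) and str(cell) != ' ':
            data_dict[name] = cell
        else:
            data_dict[name] = 'null'
    return data_dict
```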
diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/graph/__init__.py b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/graph/base_node.py b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/base_node.py
new file mode 100644
index 0000000000000000000000000000000000000000..87f020a956b24f4b5884e3e60b72513f6f39625f
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/base_node.py
@@ -0,0 +1,119 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from msprobe.pytorch.visualization.graph.node_op import NodeOp
+from msprobe.pytorch.visualization.utils import Suggestions, GraphConst
+from msprobe.pytorch.visualization.builder.msprobe_adapter import format_node_data, compare_data, compare_mapping_data
+
+
+class BaseNode:
+ def __init__(self, node_op, node_id, up_node=None):
+ self.op = node_op
+ self.id = node_id
+ self.data = {}
+ self.output_data = {}
+ self.input_data = {}
+ self.upnode = None
+ self.add_upnode(up_node)
+ self.subnodes = []
+ self.matched_node_link = []
+ self.suggestions = {}
+ self.stack_info = []
+ self.micro_step_id = None
+
+ def __str__(self):
+ info = f'id:\t{self.id}'
+ return info
+
+ def __eq__(self, other):
+ """
+        Decide whether two nodes can be matched, i.e. whether they are structurally identical.
+ """
+ if not compare_data(self.input_data, other.input_data):
+ return False
+ if not compare_data(self.output_data, other.output_data):
+ return False
+ return True
+
+ def compare_mapping_node(self, other):
+ if not compare_mapping_data(self.input_data, other.input_data):
+ return False
+ if not compare_mapping_data(self.output_data, other.output_data):
+ return False
+ return True
+
+ def get_suggestions(self):
+ """
+        Provide suggestions when the accuracy looks suspicious.
+ """
+ if self.op == NodeOp.module:
+ self.suggestions[GraphConst.SUGGEST_KEY] = Suggestions.Module
+ self.suggestions[Suggestions.DUMP] = Suggestions.DUMP_URL
+ elif self.op == NodeOp.function_api:
+ self.suggestions[GraphConst.SUGGEST_KEY] = Suggestions.API
+ self.suggestions[Suggestions.API_ACCURACY_CHECKER] = Suggestions.API_ACCURACY_CHECKER_URL
+
+ def set_input_output(self, input_data, output_data):
+ self.input_data = input_data
+ self.output_data = output_data
+
+ def add_upnode(self, node):
+ """
+        Bind the upnode, establishing a parent-child link between the two nodes.
+ """
+ if not node or node.id == self.id or self.upnode:
+ return
+ self.upnode = node
+ node.subnodes.append(self)
+
+ def add_link(self, node, ancestors):
+ """
+        Record matching data after the nodes are matched successfully.
+        Args:
+            node: the node matched with self
+            ancestors: ancestor info of the counterpart node
+ """
+ self.matched_node_link = ancestors
+ node.matched_node_link = ancestors
+
+ def to_dict(self):
+ """
+        Serialize the node for output.
+ """
+ result = {}
+ result['id'] = self.id
+ result['node_type'] = self.op.value
+ result['data'] = self.data
+ result['output_data'] = format_node_data(self.output_data)
+ result['input_data'] = format_node_data(self.input_data)
+ result['upnode'] = self.upnode.id if self.upnode else 'None'
+ result['subnodes'] = [node.id for node in self.subnodes]
+ result['matched_node_link'] = self.matched_node_link
+ result['suggestions'] = self.suggestions
+ result['stack_info'] = self.stack_info
+ if self.micro_step_id is not None:
+ result['micro_step_id'] = self.micro_step_id
+ return result
+
+ def get_ancestors(self):
+ """
+        Return a list of all ancestors of the node.
+ """
+ ancestors = []
+ current_node = self.upnode
+ while current_node:
+ ancestors.append(current_node.id)
+ current_node = current_node.upnode
+ return list(reversed(ancestors))
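The upnode/subnodes linkage that `BaseNode` maintains can be reduced to a minimal sketch. This is a stripped-down stand-in class, not the real `BaseNode`, keeping only the parent-binding and ancestor-walking logic:

```python
class Node:
    def __init__(self, node_id, up_node=None):
        self.id = node_id
        self.upnode = None
        self.subnodes = []
        self.add_upnode(up_node)

    def add_upnode(self, node):
        # Bind the parent once: skip if absent, self-referential,
        # or if a parent is already set (same guard as BaseNode).
        if not node or node.id == self.id or self.upnode:
            return
        self.upnode = node
        node.subnodes.append(self)

    def get_ancestors(self):
        # Walk the parent chain and return it root-first.
        ancestors, current = [], self.upnode
        while current:
            ancestors.append(current.id)
            current = current.upnode
        return list(reversed(ancestors))
```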
diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/graph/graph.py b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ca5c6811b496808c521f8bfd5df58adb4605c1e
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/graph.py
@@ -0,0 +1,167 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from msprobe.pytorch.visualization.graph.base_node import BaseNode
+from msprobe.pytorch.visualization.graph.node_op import NodeOp
+from msprobe.pytorch.visualization.utils import GraphConst
+from msprobe.core.common.log import logger
+from msprobe.core.common.const import Const
+
+
+class Graph:
+ def __init__(self, model_name):
+ self.node_map = {}
+ self.node_id_map = {}
+ self.add_node(NodeOp.module, model_name)
+ self.root = self.get_node(model_name)
+
+ def __str__(self):
+ infos = [str(self.node_map.get(node_id)) for node_id in self.node_map]
+ info = "\n".join(infos)
+ return info
+
+ @staticmethod
+ def match(graph_n, node_n, graph_b):
+ """
+ Given node n, match its counterpart in another graph. Precondition: its parent node has already been matched
+ Currently uses exact matching; some fuzzy matching logic may be added here later
+ Returns the matched node and the ancestor chain; returns None, [] if no match is found
+ """
+ if not node_n or node_n.id not in graph_b.node_map:
+ return None, []
+ node_b = graph_b.node_map.get(node_n.id)
+ if node_n != node_b:
+ return None, []
+ ancestors_n = node_n.get_ancestors()
+ ancestors_b = node_b.get_ancestors()
+ if ancestors_n != ancestors_b:
+ return None, []
+ return node_b, ancestors_n
+
+ @staticmethod
+ def mapping_match(node_n, graph_b, mapping_config):
+ """
+ Match the node according to the mapping configuration
+ """
+ node_b = graph_b.node_map.get(mapping_config.get_mapping_string(node_n.id))
+ if not node_b or not node_n.compare_mapping_node(node_b):
+ return None, [], []
+ ancestors_n = node_n.get_ancestors()
+ ancestors_b = node_b.get_ancestors()
+ return node_b, ancestors_n, ancestors_b
+
+ @staticmethod
+ def dfs(node, result):
+ info = node.to_dict()
+ result[node.id] = info
+ for subnode in node.subnodes:
+ Graph.dfs(subnode, result)
+
+ @staticmethod
+ def split_nodes_by_micro_step(nodes):
+ """
+ Split one step into micro steps by the numeric suffix of the Module name; nodes with the same suffix belong to the same micro step.
+ Non-Module nodes are assigned to the micro step of the preceding Module node.
+ """
+ result = {}
+ default_id = 0
+ result[default_id] = []
+
+ for node in nodes:
+ if node.op == NodeOp.module:
+ micro_step_id = node.id.split(Const.SEP)[-1]
+ try:
+ micro_step_id = int(micro_step_id)
+ except ValueError:
+ logger.warning(f'The node id suffix {micro_step_id} is not a number, micro steps cannot be split.')
+ micro_step_id = 0
+ if micro_step_id not in result:
+ default_id = micro_step_id
+ result[micro_step_id] = []
+ result[micro_step_id].append(node)
+ else:
+ result[default_id].append(node)
+ return result
+
+ def add_node(self, node_op, node_id, up_node=None, id_accumulation=False):
+ """
+ Add a node to the graph
+ Args:
+ node_op: type of the node to add
+ node_id: id of the node to add
+ up_node: parent node of the node
+ id_accumulation: whether to accumulate a counter for duplicate node_ids
+ """
+ if node_id in self.node_map:
+ if id_accumulation:
+ self.node_id_map[node_id] = 0
+ else:
+ return node_id
+ if id_accumulation:
+ if node_id in self.node_id_map:
+ self.node_id_map[node_id] += 1
+ else:
+ self.node_id_map[node_id] = 0
+ node_id = f'{node_id}.{self.node_id_map[node_id]}'
+ node = BaseNode(node_op, node_id, up_node)
+ self.node_map[node_id] = node
+ return node_id
+
+ def get_node(self, node_id):
+ """
+ Return the node, or None if it does not exist
+ """
+ return self.node_map.get(node_id, None)
+
+ def to_dict(self):
+ """
+ Serialize the graph for output
+ """
+ result = {}
+ result[GraphConst.JSON_ROOT_KEY] = self.root.id if self.root else 'None'
+ result[GraphConst.JSON_NODE_KEY] = {}
+ for node_id in self.node_map:
+ info = self.node_map.get(node_id).to_dict()
+ result[GraphConst.JSON_NODE_KEY][node_id] = info
+ return result
+
+ def paging_by_micro_step(self, graph_other=None):
+ """
+ Tag the top-level nodes of the graph with micro step ids for paginated display in the frontend, which helps manage large graphs
+ In compare scenarios, also update the micro step info of the matched nodes in graph_other
+ Args:
+ graph_other: optional, the other graph
+ Returns: the number of micro-step batches
+ """
+ batches_n = Graph.split_nodes_by_micro_step(self.root.subnodes)
+ for batch_number, nodes in batches_n.items():
+ for node in nodes:
+ node.micro_step_id = batch_number
+ # Update the micro_step_id of the matched node in graph_other
+ if graph_other and node.matched_node_link:
+ node_other = graph_other.get_node(node.matched_node_link[-1])
+ if node_other:
+ node_other.micro_step_id = batch_number
+ # Traverse all children under graph_other's root so that unmatched nodes also get a micro_step_id
+ if graph_other:
+ for node in graph_other.root.subnodes:
+ if node.micro_step_id is None:
+ try:
+ micro_step_id = int(node.id.split(Const.SEP)[-1])
+ except ValueError:
+ micro_step_id = 0
+ node.micro_step_id = micro_step_id
+ return len(batches_n)
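The suffix parsing that drives both split_nodes_by_micro_step and the unmatched-node fallback in paging_by_micro_step reduces to one small helper. A sketch, assuming Const.SEP is '.':

```python
def micro_step_of(node_id, sep="."):
    """Return the numeric suffix of a module id, or 0 if it has none."""
    suffix = node_id.split(sep)[-1]
    try:
        return int(suffix)
    except ValueError:
        # Non-numeric suffix: fall back to micro step 0, as the code above does
        return 0


print(micro_step_of("Module.layer1.forward.1"))  # 1
print(micro_step_of("Apis_Between_Modules"))     # 0
```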
diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/graph/node_colors.py b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/node_colors.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e07625140d284afac60ef1f1cf3b40e3a6f4ea7
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/node_colors.py
@@ -0,0 +1,95 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+from msprobe.pytorch.visualization.utils import GraphConst, ToolTip
+
+SUMMARY_DESCRIPTION = "此节点所有输入输出的统计量相对误差, 值越大代表测量值与标杆值的偏差越大, 相对误差计算方式:|(测量值-标杆值)/标杆值|"
+REAL_DATA_DESCRIPTION = (f"此节点所有输入的最小双千分之一和所有输出的最小双千分之一的差值的绝对值, 代表双千指标的变化情况, "
+ f"值越大代表测量值与标杆值的偏差越大, 双千分之一指标计算方式:{ToolTip.ONE_THOUSANDTH_ERR_RATIO}")
+MD5_DESCRIPTION_N = "与标杆相比, 此节点任意输入输出的md5值不同"
+MD5_DESCRIPTION_Y = "与标杆相比, 此节点所有输入输出的md5值相同"
+NOT_MATCHED = "比对过程中节点未匹配上"
+
+
+class NodeColors(Enum):
+ # The smaller the numeric suffix of an enum member, the lighter the color
+ # VALUE ranges are left-closed, right-open; two identical values denote a fixed value
+ YELLOW_1 = ("#FFFCF3", {
+ GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0, 0.2], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION},
+ GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0, 0.05], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION},
+ GraphConst.MD5_COMPARE: {GraphConst.VALUE: [1, 1], GraphConst.DESCRIPTION: MD5_DESCRIPTION_Y},
+ })
+ YELLOW_2 = ("#FFEDBE", {
+ GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0.2, 0.4], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION},
+ GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0.05, 0.1], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION}
+ })
+ ORANGE_1 = ("#FFDC7F", {
+ GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0.4, 0.6], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION},
+ GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0.1, 0.15], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION}
+ })
+ ORANGE_2 = ("#FFC62E", {
+ GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0.6, 0.8], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION},
+ GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0.15, 0.2], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION}
+ })
+ RED = ("#E32020", {
+ GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0.8, 1], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION},
+ GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0.2, 1], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION},
+ GraphConst.MD5_COMPARE: {GraphConst.VALUE: [0, 0], GraphConst.DESCRIPTION: MD5_DESCRIPTION_N},
+ })
+ GREY = ("#C7C7C7", {
+ GraphConst.VALUE: [], GraphConst.DESCRIPTION: NOT_MATCHED
+ })
+
+ def __init__(self, hex_value, mode_info):
+ self.hex_value = hex_value
+ self.mode_info = mode_info
+
+ @staticmethod
+ def get_node_colors(mode):
+ """
+ Get the color legend for a given compare mode
+ Args:
+ mode: compare mode
+ Returns: color legend
+ """
+ return {
+ color.hex_value: color.get_info_by_mode(mode) for color in NodeColors if color.get_info_by_mode(mode)
+ }
+
+ @staticmethod
+ def get_node_error_status(mode, value):
+ """
+ Check whether the accuracy compare metric exceeds the baseline value
+ Args:
+ mode: compare mode
+ value: accuracy compare metric
+ Returns: bool
+ """
+ info = NodeColors.ORANGE_1.get_info_by_mode(mode)
+ if info and GraphConst.VALUE in info:
+ value_range = info[GraphConst.VALUE]
+ return value > value_range[0]
+ return False
+
+ def get_info_by_mode(self, mode):
+ if isinstance(self.mode_info, dict):
+ # Check whether the info is mode-specific
+ if isinstance(next(iter(self.mode_info.values())), dict):
+ return self.mode_info.get(mode, {})
+ else:
+ # All modes share the same info
+ return self.mode_info
+ return {}
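The VALUE pairs above encode left-closed, right-open intervals, with two equal bounds meaning an exact value (as in the MD5 flags). A sketch of that interval check (hypothetical helper, not part of NodeColors):

```python
def in_range(value, bounds):
    """Left-closed, right-open interval test; [x, x] means exact equality."""
    lo, hi = bounds
    if lo == hi:
        # Fixed value, e.g. the MD5 match/mismatch flags above
        return value == lo
    return lo <= value < hi


print(in_range(0.3, [0.2, 0.4]))  # True: falls in the ORANGE_1 summary band
print(in_range(0.4, [0.2, 0.4]))  # False: the right bound is open
print(in_range(1, [1, 1]))        # True: fixed value
```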
diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/graph/node_op.py b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/node_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..9be4923c5d6913d3e4885fca34227f2db75e28d6
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/node_op.py
@@ -0,0 +1,38 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+import re
+from msprobe.pytorch.visualization.builder.msprobe_adapter import op_patterns
+
+
+class NodeOp(Enum):
+ module = 0
+ function_api = 1
+ api_collection = 9
+
+ @staticmethod
+ def get_node_op(node_name: str):
+ """
+ Parse the node kind from the string representing the node
+ """
+ for op in NodeOp:
+ index = op.value
+ if index < 0 or index >= len(op_patterns):
+ raise Exception("NodeOp and op_patterns in MsprobeAdapter do not match")
+ pattern = op_patterns[index]
+ if re.match(pattern, node_name):
+ return op
+ raise Exception(f"Cannot parse node_name {node_name} into NodeOp")
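get_node_op tries each pattern in enum order and returns on the first match. The same first-match scheme in isolation (the patterns below are illustrative stand-ins, not the real op_patterns from msprobe_adapter):

```python
import re

# Illustrative patterns only; the real ones live in msprobe_adapter.op_patterns
patterns = {
    "module": r"^Module\.",
    "function_api": r"^(Tensor|Torch|Functional|NPU|Aten|Distributed|VF)\.",
}


def classify(node_name):
    # Return the kind of the first pattern that matches the node name
    for kind, pattern in patterns.items():
        if re.match(pattern, node_name):
            return kind
    raise ValueError(f"Cannot parse node_name {node_name}")


print(classify("Module.layer1.Linear.forward.0"))  # module
print(classify("Tensor.add.0.forward"))            # function_api
```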
diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/graph_service.py b/debug/accuracy_tools/msprobe/pytorch/visualization/graph_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ab966a3f33214095cbb6e01872a0878e63b59b7
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/visualization/graph_service.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+from msprobe.pytorch.visualization.compare.graph_comparator import GraphComparator
+from msprobe.pytorch.visualization.utils import GraphConst
+from msprobe.pytorch.visualization.builder.graph_builder import GraphBuilder, GraphExportConfig
+from msprobe.core.common.log import logger
+from msprobe.pytorch.visualization.mapping_config import MappingConfig
+from msprobe.pytorch.visualization.graph.node_colors import NodeColors
+
+current_time = time.strftime("%Y%m%d%H%M%S")
+
+
+def compare_graph(dump_path_n, dump_path_b, out_path, model_name='Model', mapping_file=None):
+ logger.info('Start building model graphs...')
+ # Build a graph from each dump
+ construct_path_n = os.path.join(dump_path_n, GraphConst.CONSTRUCT_FILE)
+ construct_path_b = os.path.join(dump_path_b, GraphConst.CONSTRUCT_FILE)
+ data_path_n = os.path.join(dump_path_n, GraphConst.DUMP_FILE)
+ data_path_b = os.path.join(dump_path_b, GraphConst.DUMP_FILE)
+ graph_n = GraphBuilder.build(construct_path_n, data_path_n, model_name)
+ graph_b = GraphBuilder.build(construct_path_b, data_path_b, model_name)
+ logger.info('Model graphs built successfully, start comparing graphs...')
+ # Compare based on graph, stack and data
+ stack_path = os.path.join(dump_path_n, GraphConst.STACK_FILE)
+ graph_comparator = GraphComparator([graph_n, graph_b], [data_path_n, data_path_b], stack_path, out_path,
+ mapping_config=MappingConfig(mapping_file) if mapping_file else None)
+ graph_comparator.compare()
+ micro_steps = graph_n.paging_by_micro_step(graph_b)
+ output_path = os.path.join(out_path, f'compare_{current_time}.vis')
+ export_config = GraphExportConfig(graph_n, graph_b, graph_comparator.ma.get_tool_tip(),
+ NodeColors.get_node_colors(graph_comparator.ma.compare_mode), micro_steps)
+ GraphBuilder.to_json(output_path, export_config)
+ logger.info(f'Model graphs compared successfully, the result file is saved in {output_path}')
+
+
+def build_graph(dump_path, out_path, model_name='Model'):
+ logger.info('Start building model graph...')
+ construct_path = os.path.join(dump_path, GraphConst.CONSTRUCT_FILE)
+ data_path = os.path.join(dump_path, GraphConst.DUMP_FILE)
+ output_path = os.path.join(out_path, f'build_{current_time}.vis')
+ graph = GraphBuilder.build(construct_path, data_path, model_name)
+ micro_steps = graph.paging_by_micro_step()
+ GraphBuilder.to_json(output_path, GraphExportConfig(graph, micro_steps=micro_steps))
+ logger.info(f'Model graph built successfully, the result file is saved in {output_path}')
diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/mapping_config.py b/debug/accuracy_tools/msprobe/pytorch/visualization/mapping_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d986493513a345483bd6c32f808cca002b3cba5b
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/visualization/mapping_config.py
@@ -0,0 +1,77 @@
+import re
+import yaml
+from msprobe.core.common.file_utils import FileOpen
+from msprobe.core.common.const import Const
+from msprobe.pytorch.visualization.utils import GraphConst
+
+
+class MappingConfig:
+ MAX_STRING_LEN = 10000
+
+ def __init__(self, yaml_file):
+ with FileOpen(yaml_file, 'r') as file:
+ config = yaml.safe_load(file)
+ try:
+ self.config = {key: self.validate(key, value) for data in config for key, value in data.items()}
+ except Exception as e:
+ raise RuntimeError("Line of yaml contains content that is not '- key: value'.") from e
+ self.classify_config = self._classify_and_sort_keys()
+
+ @staticmethod
+ def validate(key, value):
+ if not isinstance(key, str):
+ raise ValueError(f"{key} must be a string.")
+ if not isinstance(value, str):
+ raise ValueError(f"{value} must be a string.")
+ return value
+
+ @staticmethod
+ def convert_to_regex(s):
+ """
+ Convert a string to a regular expression: '{}' is replaced with \d+ to match one or more digits, and '.*' is added at both ends to match any prefix and suffix
+ Args:
+ s: the string
+ Returns: the regular expression
+ """
+ escaped_pattern = re.escape(s)
+ pattern = re.sub(r'\\\{\\\}', r'\\d+', escaped_pattern)
+ pattern = f'.*{pattern}.*'
+ return pattern
+
+ @staticmethod
+ def _replace_parts(origin_string, mapping_key, mapping_value):
+ if GraphConst.BRACE in mapping_key:
+ parts = mapping_key.split(GraphConst.BRACE)
+ m_parts = mapping_value.split(GraphConst.BRACE)
+ return origin_string.replace(parts[0], m_parts[0]).replace(parts[1], m_parts[1])
+ else:
+ return origin_string.replace(mapping_key, mapping_value)
+
+ def get_mapping_string(self, origin_string: str):
+ if len(origin_string) > MappingConfig.MAX_STRING_LEN:
+ return origin_string
+ for category, items in self.classify_config.items():
+ if category in origin_string:
+ for key, value in items:
+ if re.match(MappingConfig.convert_to_regex(key), origin_string):
+ return MappingConfig._replace_parts(origin_string, key, value)
+ return origin_string
+
+ def _classify_and_sort_keys(self):
+ categorized_dict = {}
+ for key, value in self.config.items():
+ parts = key.split(Const.SEP)
+ # Use the first part as the category key
+ category_key = parts[0]
+
+ if category_key not in categorized_dict:
+ categorized_dict[category_key] = []
+
+ # Append the original key-value pair to its category
+ categorized_dict[category_key].append((key, value))
+
+ # Sort items in each category by the number of '.' in the key, descending, so that more specific keys are matched first
+ for category in categorized_dict:
+ categorized_dict[category].sort(key=lambda x: -x[0].count(Const.SEP))
+
+ return categorized_dict
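The '{}' placeholder handling in convert_to_regex can be exercised on its own: the mapping key is regex-escaped, the escaped '{}' becomes \d+, and '.*' at both ends allows any prefix and suffix. A self-contained restatement of that logic:

```python
import re


def convert_to_regex(s):
    # Escape regex metacharacters, then turn the escaped "{}" into \d+
    escaped = re.escape(s)
    pattern = re.sub(r"\\\{\\\}", r"\\d+", escaped)
    return f".*{pattern}.*"


regex = convert_to_regex("layers.{}.attention")
print(bool(re.match(regex, "model.layers.3.attention.forward")))  # True
print(bool(re.match(regex, "model.layers.x.attention")))          # False
```

Escaping first means literal dots in the key only match dots, while the numbered layer index stays flexible.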
diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py b/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a81f4f20922bb9fa58ce1d1e358ed1e0ce12c74
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py
@@ -0,0 +1,138 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from msprobe.core.common.file_utils import FileOpen
+from msprobe.core.common.const import CompareConst
+from msprobe.pytorch.compare.acc_compare import result_to_csv
+
+
+def load_json_file(file_path):
+ """
+ Load a json file
+ """
+ try:
+ with FileOpen(file_path, 'r') as f:
+ file_dict = json.load(f)
+ if not isinstance(file_dict, dict):
+ return {}
+ return file_dict
+ except json.JSONDecodeError:
+ return {}
+
+
+def load_data_json_file(file_path):
+ """
+ Load the data field from dump.json
+ """
+ return load_json_file(file_path).get(GraphConst.DATA_KEY, {})
+
+
+def save_json_file(file_path, data):
+ """
+ Save data to a json file
+ """
+ with FileOpen(file_path, 'w') as f:
+ f.write(json.dumps(data, indent=4))
+
+
+def get_csv_df(md5_compare, summary_compare, stack, csv_data):
+ """
+ Call the acc compare interface to write the csv
+ """
+ return result_to_csv(md5_compare, summary_compare, stack, csv_data)
+
+
+def str2float(percentage_str):
+ """
+ Convert a percentage string to a float
+ Args:
+ percentage_str: '0.00%', '23.4%'
+ Returns: float 0.00, 0.234
+ """
+ try:
+ percentage_str = percentage_str.strip('%')
+ return float(percentage_str) / 100
+ except ValueError:
+ return 0
+
+
+class ToolTip:
+ MAX_DIFF = 'NPU与标杆API统计信息比对,最大值的差值'
+ MIN_DIFF = 'NPU与标杆API统计信息比对,最小值的差值'
+ MEAN_DIFF = 'NPU与标杆API统计信息比对,平均值的差值'
+ NORM_DIFF = 'NPU与标杆API统计信息比对,2范数(平方根)的差值'
+ MD5 = '数据MD5信息,用于比较两个数据信息是否完全一致'
+ ONE_THOUSANDTH_ERR_RATIO = 'Tensor中的元素逐个与对应的标杆数据对比,相对误差小于千分之一的比例占总元素个数的比例,比例越接近1越好'
+ FIVE_THOUSANDTHS_ERR_RATIO = 'Tensor中的元素逐个与对应的标杆数据对比,相对误差小于千分之五的比例占总元素个数的比例,比例越接近1越好'
+ COSINE = '通过计算两个向量的余弦值来判断其相似度,数值越接近于1说明计算出的两个张量越相似,实际可接受阈值为大于0.99。在计算中可能会存在nan,主要由于可能会出现其中一个向量为0'
+ MAX_ABS_ERR = '当最大绝对误差越接近0表示其计算的误差越小,实际可接受阈值为小于0.001'
+ MAX_RELATIVE_ERR = '当最大相对误差越接近0表示其计算的误差越小。当dump数据中存在0或Nan时,比对结果中最大相对误差则出现inf或Nan的情况,属于正常现象'
+ SMALL_VALUE_TIP = '{} 小于1e-3,不计算相对误差'
+
+
+class Suggestions:
+ Module = '此模块精度比对结果疑似异常,请使用msprobe工具的数据采集功能对模块中的api进行dump比对'
+ API = '此api精度比对结果疑似异常,请使用msprobe工具的预检功能对api进行精度检测'
+ DUMP = 'msprobe工具的数据采集功能'
+ DUMP_URL = 'https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/pytorch/doc/dump.md'
+ API_ACCURACY_CHECKER = 'msprobe工具的预检功能'
+ API_ACCURACY_CHECKER_URL = 'https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/pytorch/doc/api_accuracy_checker.md'
+
+
+class GraphConst:
+ CONSTRUCT_FILE = 'construct.json'
+ DUMP_FILE = 'dump.json'
+ STACK_FILE = 'stack.json'
+ GRAPH_FILE = 'graph.vis'
+ ERROR_KEY = 'error_key'
+ SUMMARY_COMPARE = 0
+ MD5_COMPARE = 1
+ REAL_DATA_COMPARE = 2
+ JSON_NPU_KEY = 'NPU'
+ JSON_BENCH_KEY = 'Bench'
+ JSON_TIP_KEY = 'ToolTip'
+ JSON_ROOT_KEY = 'root'
+ JSON_NODE_KEY = 'node'
+ DATA_KEY = 'data'
+ REAL_DATA_TH = 0.1
+ MAX_RELATIVE_ERR_TH = 0.5
+ ROUND_TH = 6
+ JSON_INDEX_KEY = 'precision_index'
+ MAX_INDEX_KEY = 1
+ MIN_INDEX_KEY = 0
+ SUGGEST_KEY = 'text'
+ TAG_NA = 'na'
+ OUTPUT_INDEX_TWO = -2
+ OUTPUT_INDEX_THREE = -3
+ OUTPUT_MIN_LEN = 3
+ INPUT = 'input'
+ OUTPUT = 'output'
+ STR_MAX_LEN = 50
+ SMALL_VALUE = 1e-3
+ MD5_INDEX_LIST = [CompareConst.RESULT]
+ REAL_DATA_INDEX_LIST = [CompareConst.COSINE, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
+ CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO]
+ SUMMARY_INDEX_LIST = [CompareConst.MAX_DIFF, CompareConst.MIN_DIFF, CompareConst.MEAN_DIFF,
+ CompareConst.NORM_DIFF, CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR,
+ CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR]
+ APIS_BETWEEN_MODULES = 'Apis_Between_Modules'
+ NULL = 'null'
+ NONE = 'None'
+ VALUE = 'value'
+ BRACE = '{}'
+ DESCRIPTION = 'description'
+ COLORS = 'Colors'
+ MICRO_STEPS = 'MicroSteps'
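The percentage parsing in str2float above can be checked standalone (same logic, restated as a sketch):

```python
def str2float(percentage_str):
    """'23.4%' -> 0.234; malformed input falls back to 0."""
    try:
        return float(percentage_str.strip('%')) / 100
    except ValueError:
        return 0


print(str2float('0.00%'))  # 0.0
print(str2float('50%'))    # 0.5
print(str2float('oops'))   # 0
```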
diff --git a/debug/accuracy_tools/atat/test/core_ut/test_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py
similarity index 90%
rename from debug/accuracy_tools/atat/test/core_ut/test_utils.py
rename to debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py
index b3273358e43593e14f931a5675e9039c684d1c59..27ffe411586707e384a404edfa39b6ac5a013464 100644
--- a/debug/accuracy_tools/atat/test/core_ut/test_utils.py
+++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py
@@ -20,9 +20,9 @@ import uuid
from unittest import TestCase
from unittest.mock import patch, MagicMock, mock_open
-from atat.core.common.log import logger
-from atat.core.common.const import Const
-from atat.core.common.utils import (CompareException,
+from msprobe.core.common.log import logger
+from msprobe.core.common.const import Const
+from msprobe.core.common.utils import (CompareException,
check_seed_all,
check_inplace_op,
make_dump_path_if_not_exists,
@@ -41,7 +41,7 @@ from atat.core.common.utils import (CompareException,
check_regex_prefix_format_valid,
get_dump_data_path,
task_dumppath_get)
-from atat.core.common.file_check import FileCheckConst
+from msprobe.core.common.file_utils import FileCheckConst
class TestUtils(TestCase):
@@ -88,7 +88,7 @@ class TestUtils(TestCase):
raise OSError
if not os.path.exists(dirname):
- with patch("atat.core.common.utils.Path.mkdir", new=test_mkdir):
+ with patch("msprobe.core.common.utils.Path.mkdir", new=test_mkdir):
with self.assertRaises(CompareException) as context:
make_dump_path_if_not_exists(dirname)
self.assertEqual(context.exception.code, CompareException.INVALID_PATH_ERROR)
@@ -171,7 +171,7 @@ class TestUtils(TestCase):
file_path = os.path.realpath(__file__)
dirname = os.path.dirname(file_path)
- with patch("atat.core.common.utils.FileChecker", new=TestFileChecker):
+ with patch("msprobe.core.common.utils.FileChecker", new=TestFileChecker):
check_file_or_directory_path(file_path, isdir=False)
self.assertTrue(TestFileChecker.checked)
self.assertEqual(TestFileChecker.file_path, file_path)
@@ -179,7 +179,7 @@ class TestUtils(TestCase):
self.assertEqual(TestFileChecker.ability, FileCheckConst.READ_ABLE)
TestFileChecker.checked = False
- with patch("atat.core.common.utils.FileChecker", new=TestFileChecker):
+ with patch("msprobe.core.common.utils.FileChecker", new=TestFileChecker):
check_file_or_directory_path(dirname, isdir=True)
self.assertTrue(TestFileChecker.checked)
self.assertEqual(TestFileChecker.file_path, dirname)
@@ -216,9 +216,9 @@ class TestUtils(TestCase):
mock_check_file_or_directory_path = MagicMock()
mock_check_json_file = MagicMock()
- with patch("atat.core.common.utils.FileOpen", mock_open(read_data="")), \
- patch("atat.core.common.utils.check_json_file", new=mock_check_json_file), \
- patch("atat.core.common.utils.check_file_or_directory_path", new=mock_check_file_or_directory_path):
+ with patch("msprobe.core.common.utils.FileOpen", mock_open(read_data="")), \
+ patch("msprobe.core.common.utils.check_json_file", new=mock_check_json_file), \
+ patch("msprobe.core.common.utils.check_file_or_directory_path", new=mock_check_file_or_directory_path):
check_compare_param(params, "output_path")
check_compare_param(params, "output_path", summary_compare=False, md5_compare=True)
for i in range(len(call_args)):
@@ -261,7 +261,7 @@ class TestUtils(TestCase):
_check_json(handler, "test.json")
self.assertEqual(handler.string, "0_0")
- @patch("atat.core.common.utils._check_json")
+ @patch("msprobe.core.common.utils._check_json")
def test_check_json_file(self, _mock_check_json):
input_param = {
"npu_json_path": "npu_json_path",
@@ -275,7 +275,7 @@ class TestUtils(TestCase):
@patch.object(logger, "error")
def test_check_file_size(self, mock_error):
- with patch("atat.core.common.utils.os.path.getsize", return_value=120):
+ with patch("msprobe.core.common.utils.os.path.getsize", return_value=120):
with self.assertRaises(CompareException) as context:
check_file_size("input_file", 100)
self.assertEqual(context.exception.code, CompareException.INVALID_FILE_ERROR)
@@ -294,7 +294,7 @@ class TestUtils(TestCase):
self.assertEqual(str(context.exception), f"prefix contains invalid characters, "
f"prefix pattern {Const.REGEX_PREFIX_PATTERN}")
- @patch("atat.core.common.utils.check_file_or_directory_path")
+ @patch("msprobe.core.common.utils.check_file_or_directory_path")
def test_get_dump_data_path(self, mock_check_file_or_directory_path):
file_path = os.path.realpath(__file__)
dirname = os.path.dirname(file_path)
@@ -322,23 +322,23 @@ class TestUtils(TestCase):
mock_error.assert_called_with("Please check the json path is valid.")
input_param["npu_json_path"] = "npu_json_path"
- with patch("atat.core.common.utils.FileOpen", mock_open(read_data="")), \
- patch("atat.core.common.utils.json.load", return_value=npu_json):
+ with patch("msprobe.core.common.utils.FileOpen", mock_open(read_data="")), \
+ patch("msprobe.core.common.utils.json.load", return_value=npu_json):
summary_compare, md5_compare = task_dumppath_get(input_param)
self.assertFalse(summary_compare)
self.assertFalse(md5_compare)
npu_json["task"] = Const.STATISTICS
- with patch("atat.core.common.utils.FileOpen", mock_open(read_data="")), \
- patch("atat.core.common.utils.json.load", return_value=npu_json), \
- patch("atat.core.common.utils.md5_find", return_value=True):
+ with patch("msprobe.core.common.utils.FileOpen", mock_open(read_data="")), \
+ patch("msprobe.core.common.utils.json.load", return_value=npu_json), \
+ patch("msprobe.core.common.utils.md5_find", return_value=True):
summary_compare, md5_compare = task_dumppath_get(input_param)
self.assertFalse(summary_compare)
self.assertTrue(md5_compare)
npu_json["task"] = Const.OVERFLOW_CHECK
- with patch("atat.core.common.utils.FileOpen", mock_open(read_data="")), \
- patch("atat.core.common.utils.json.load", return_value=npu_json):
+ with patch("msprobe.core.common.utils.FileOpen", mock_open(read_data="")), \
+ patch("msprobe.core.common.utils.json.load", return_value=npu_json):
with self.assertRaises(CompareException) as context:
task_dumppath_get(input_param)
self.assertEqual(context.exception.code, CompareException.INVALID_TASK_ERROR)
diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py
new file mode 100644
index 0000000000000000000000000000000000000000..eedbe5be7e0360d7874439357419510cbde73b71
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py
@@ -0,0 +1,47 @@
+import unittest
+from unittest.mock import patch, mock_open, MagicMock
+
+from msprobe.core.common.utils import Const
+from msprobe.core.data_dump.data_collector import DataCollector
+from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
+from msprobe.pytorch.pt_config import parse_json_config
+
+
+class TestDataCollector(unittest.TestCase):
+ def setUp(self):
+ mock_json_data = {
+ "dump_path": "./ut_dump",
+ }
+ with patch("msprobe.pytorch.pt_config.FileOpen", mock_open(read_data='')), \
+ patch("msprobe.pytorch.pt_config.json.load", return_value=mock_json_data):
+ common_config, task_config = parse_json_config("./config.json", Const.STATISTICS)
+ config = DebuggerConfig(common_config, task_config, Const.STATISTICS, "./ut_dump", "L1")
+ self.data_collector = DataCollector(config)
+
+ def test_update_data(self):
+ self.data_collector.config.task = Const.OVERFLOW_CHECK
+ self.data_collector.data_processor.has_overflow = True
+ with patch("msprobe.core.data_dump.json_writer.DataWriter.update_data", return_value=None):
+ result1 = self.data_collector.update_data("test message", "test1:")
+ self.assertEqual(result1, "test1:Overflow detected.")
+
+ self.data_collector.data_processor.has_overflow = False
+ result2 = self.data_collector.update_data("test message", "test2:")
+ self.assertEqual(result2, "test2:No Overflow, OK.")
+
+ self.data_collector.config.task = Const.STATISTICS
+ self.data_collector.data_processor.has_overflow = True
+ with patch("msprobe.core.data_dump.json_writer.DataWriter.update_data", return_value=None):
+ result3 = self.data_collector.update_data("test message", "test3")
+ self.assertEqual(result3, "test3")
+
+ def test_pre_forward_data_collect(self):
+ self.data_collector.check_scope_and_pid = MagicMock(return_value=False)
+ self.data_collector.is_inplace = MagicMock(return_value=False)
+ self.data_collector.data_processor.analyze_pre_forward = MagicMock()
+ name = "TestModule.forward"
+ pid = 123
+
+ self.data_collector.pre_forward_data_collect(name, None, pid, None)
+ self.data_collector.check_scope_and_pid.assert_called_once_with(
+ self.data_collector.scope, "TestModule.backward", 123)
diff --git a/debug/accuracy_tools/atat/test/core_ut/data_dump/test_json_writer.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_json_writer.py
similarity index 97%
rename from debug/accuracy_tools/atat/test/core_ut/data_dump/test_json_writer.py
rename to debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_json_writer.py
index 867da001e610c7e006a9b7d4a027ecf4d6f43a69..4161377a6067c53fa4735a13641c5609bd0d06c9 100644
--- a/debug/accuracy_tools/atat/test/core_ut/data_dump/test_json_writer.py
+++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_json_writer.py
@@ -1,10 +1,10 @@
import unittest
-from atat.core.data_dump.json_writer import DataWriter
+from msprobe.core.data_dump.json_writer import DataWriter
import os
import csv
-from atat.core.common.file_check import FileOpen
-from atat.core.common import utils
+from msprobe.core.common.file_utils import FileOpen
+from msprobe.core.common import utils
from pathlib import Path
import json
diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_scope.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_scope.py
new file mode 100644
index 0000000000000000000000000000000000000000..1989fd0a95a5894b012aa916cdf44c625de27b1a
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_scope.py
@@ -0,0 +1,151 @@
+import unittest
+from unittest.mock import MagicMock
+
+from msprobe.core.common.exceptions import ScopeException
+from msprobe.core.data_dump.scope import (
+ build_scope,
+ build_range_scope_according_to_scope_name,
+ BaseScope,
+ ListScope,
+ RangeScope,
+ APIRangeScope,
+ ModuleRangeScope
+)
+
+
+class TestBuildScope(unittest.TestCase):
+ def test_build_scope(self):
+ scope_class = MagicMock()
+ result1 = build_scope(scope_class, None, None)
+ self.assertEqual(result1, None)
+
+ api_list = ['api1', 'api2']
+ result2 = build_scope(scope_class, None, api_list)
+ self.assertEqual(result2, scope_class.return_value)
+
+ def test_build_range_scope_according_to_scope_name(self):
+ result = build_range_scope_according_to_scope_name([], [])
+ self.assertIsInstance(result, APIRangeScope)
+
+
+class TestBaseScope(unittest.TestCase):
+ def test_rectify_args(self):
+ scope = []
+ api_list = "invalid_api_list"
+ with self.assertRaises(ScopeException) as context:
+ BaseScope.rectify_args(scope, api_list)
+ self.assertEqual(context.exception.code, ScopeException.InvalidApiStr)
+
+ api_list = [1, 2, 3]
+ with self.assertRaises(ScopeException) as context:
+ BaseScope.rectify_args(scope, api_list)
+ self.assertEqual(context.exception.code, ScopeException.InvalidApiStr)
+
+ scope = "module1"
+ api_list = []
+
+ expected_scope = ["module1"]
+ expected_api_list = []
+ result_scope, result_api_list = BaseScope.rectify_args(scope, api_list)
+ self.assertEqual(result_scope, expected_scope)
+ self.assertEqual(result_api_list, expected_api_list)
+
+ scope = 123
+ api_list = []
+ with self.assertRaises(ScopeException) as context:
+ BaseScope.rectify_args(scope, api_list)
+ self.assertEqual(context.exception.code, ScopeException.InvalidScope)
+
+ scope = ["module1", 2, "module3"]
+ api_list = []
+ with self.assertRaises(ScopeException) as context:
+ BaseScope.rectify_args(scope, api_list)
+ self.assertEqual(context.exception.code, ScopeException.InvalidScope)
+
+
+class TestListScope(unittest.TestCase):
+ def test_rectify_args(self):
+ scope = ["module1"]
+ api_list = ["api1"]
+ with self.assertRaises(ScopeException) as context:
+ ListScope.rectify_args(scope, api_list)
+ self.assertEqual(context.exception.code, ScopeException.ArgConflict)
+
+ def test_check(self):
+ list_scope = ListScope([], [])
+ module_name = "module1"
+ result = list_scope.check(module_name)
+ self.assertTrue(result)
+
+ list_scope = ListScope(["module1"], [])
+ module_name = "module1"
+ result = list_scope.check(module_name)
+ self.assertTrue(result)
+
+ list_scope = ListScope(["module1"], [])
+ module_name = "module2"
+ result = list_scope.check(module_name)
+ self.assertFalse(result)
+
+
+class TestRangeScope(unittest.TestCase):
+ def test_rectify_args(self):
+ scope = ["module1", "module2", "module3"]
+ with self.assertRaises(ScopeException) as context:
+ RangeScope.rectify_args(scope, [])
+ self.assertEqual(context.exception.code, ScopeException.InvalidScope)
+
+ scope = ["module1"]
+ expected_scope = ["module1", "module1"]
+ result_scope, result_api_list = RangeScope.rectify_args(scope, [])
+ self.assertEqual(result_scope, expected_scope)
+
+
+class TestAPIRangeScope(unittest.TestCase):
+ def test_check_scope_is_valid(self):
+ api_range_scope = APIRangeScope([], [])
+ result = api_range_scope.check_scope_is_valid()
+ self.assertTrue(result)
+
+ def test_check(self):
+ api_range_scope = APIRangeScope([], [])
+ api_name = "api1"
+ result = api_range_scope.check(api_name)
+ self.assertTrue(result)
+
+
+class TestModuleRangeScope(unittest.TestCase):
+ def test_check_scope_is_valid(self):
+ module_range_scope = ModuleRangeScope([], [])
+ result = module_range_scope.check_scope_is_valid()
+ self.assertTrue(result)
+
+ def test_begin_module(self):
+ module_range_scope = ModuleRangeScope(["module1", "module2"], [])
+ module_name = "module1"
+ module_range_scope.begin_module(module_name)
+ self.assertTrue(module_range_scope.in_scope)
+
+ module_range_scope = ModuleRangeScope(["module1", "module2"], [])
+ module_name = "module3"
+ module_range_scope.begin_module(module_name)
+ self.assertFalse(module_range_scope.in_scope)
+
+ def test_end_module(self):
+ module_range_scope = ModuleRangeScope(["module1", "module2"], [])
+ module_name = "module2"
+ module_range_scope.in_scope = True
+ module_range_scope.end_module(module_name)
+ self.assertFalse(module_range_scope.in_scope)
+
+ module_range_scope = ModuleRangeScope(["module1", "module2"], [])
+ module_name = "module3"
+ module_range_scope.in_scope = True
+ module_range_scope.end_module(module_name)
+ self.assertTrue(module_range_scope.in_scope)
+
+ def test_check(self):
+ module_range_scope = ModuleRangeScope([], [])
+ module_name = "module1"
+ result = module_range_scope.check(module_name)
+ self.assertTrue(result)
diff --git a/debug/accuracy_tools/atat/test/core_ut/test_common_config.py b/debug/accuracy_tools/msprobe/test/core_ut/test_common_config.py
similarity index 83%
rename from debug/accuracy_tools/atat/test/core_ut/test_common_config.py
rename to debug/accuracy_tools/msprobe/test/core_ut/test_common_config.py
index 00b17e1f1cfce0d28d4bd5c3a2130a2d89eb7409..8b2138a485b7ddac32b31ff3784df6299276b4ba 100644
--- a/debug/accuracy_tools/atat/test/core_ut/test_common_config.py
+++ b/debug/accuracy_tools/msprobe/test/core_ut/test_common_config.py
@@ -17,10 +17,10 @@
from unittest import TestCase
from unittest.mock import patch
-from atat.core.common.log import logger
-from atat.core.common.const import Const
-from atat.core.common.exceptions import MsaccException
-from atat.core.common_config import CommonConfig, BaseConfig
+from msprobe.core.common.log import logger
+from msprobe.core.common.const import Const
+from msprobe.core.common.exceptions import MsprobeException
+from msprobe.core.common_config import CommonConfig, BaseConfig
class TestCommonConfig(TestCase):
@@ -44,7 +44,7 @@ class TestCommonConfig(TestCase):
self.assertEqual(mock_error_log_with_exp.call_args[0][0],
"task is invalid, it should be one of {}".format(Const.TASK_LIST))
self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]),
- MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR))
+ MsprobeException.err_strs.get(MsprobeException.INVALID_PARAM_ERROR))
json_config.update({"task": Const.TENSOR})
json_config.update({"rank": 0})
@@ -52,7 +52,7 @@ class TestCommonConfig(TestCase):
self.assertEqual(mock_error_log_with_exp.call_args[0][0],
"rank is invalid, it should be a list")
self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]),
- MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR))
+ MsprobeException.err_strs.get(MsprobeException.INVALID_PARAM_ERROR))
json_config.update({"task": Const.TENSOR})
json_config.update({"rank": [0]})
@@ -61,7 +61,7 @@ class TestCommonConfig(TestCase):
self.assertEqual(mock_error_log_with_exp.call_args[0][0],
"step is invalid, it should be a list")
self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]),
- MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR))
+ MsprobeException.err_strs.get(MsprobeException.INVALID_PARAM_ERROR))
json_config.update({"task": Const.TENSOR})
json_config.update({"rank": [0]})
@@ -71,7 +71,7 @@ class TestCommonConfig(TestCase):
self.assertEqual(mock_error_log_with_exp.call_args[0][0],
"level is invalid, it should be one of {}".format(Const.LEVEL_LIST))
self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]),
- MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR))
+ MsprobeException.err_strs.get(MsprobeException.INVALID_PARAM_ERROR))
json_config.update({"task": Const.TENSOR})
json_config.update({"rank": [0]})
@@ -82,7 +82,7 @@ class TestCommonConfig(TestCase):
self.assertEqual(mock_error_log_with_exp.call_args[0][0],
"seed is invalid, it should be an integer")
self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]),
- MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR))
+ MsprobeException.err_strs.get(MsprobeException.INVALID_PARAM_ERROR))
json_config.update({"task": Const.TENSOR})
json_config.update({"rank": [0]})
@@ -94,7 +94,7 @@ class TestCommonConfig(TestCase):
self.assertEqual(mock_error_log_with_exp.call_args[0][0],
"is_deterministic is invalid, it should be a boolean")
self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]),
- MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR))
+ MsprobeException.err_strs.get(MsprobeException.INVALID_PARAM_ERROR))
json_config.update({"task": Const.TENSOR})
json_config.update({"rank": [0]})
@@ -107,7 +107,7 @@ class TestCommonConfig(TestCase):
self.assertEqual(mock_error_log_with_exp.call_args[0][0],
"enable_dataloader is invalid, it should be a boolean")
self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]),
- MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR))
+ MsprobeException.err_strs.get(MsprobeException.INVALID_PARAM_ERROR))
@patch.object(logger, "error_log_with_exp")
def test_base_config(self, mock_error_log_with_exp):
@@ -121,7 +121,7 @@ class TestCommonConfig(TestCase):
self.assertIsNone(base_config.backward_input)
self.assertIsNone(base_config.file_format)
self.assertIsNone(base_config.summary_mode)
- self.assertIsNone(base_config.overflow_num)
+ self.assertIsNone(base_config.overflow_nums)
self.assertIsNone(base_config.check_mode)
json_config.update({"scope": "Tensor_Add"})
@@ -130,7 +130,7 @@ class TestCommonConfig(TestCase):
self.assertEqual(mock_error_log_with_exp.call_args[0][0],
"scope is invalid, it should be a list")
self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]),
- MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR))
+ MsprobeException.err_strs.get(MsprobeException.INVALID_PARAM_ERROR))
json_config.update({"scope": ["Tensor_Add"]})
json_config.update({"list": "Tensor_Add"})
@@ -139,7 +139,7 @@ class TestCommonConfig(TestCase):
self.assertEqual(mock_error_log_with_exp.call_args[0][0],
"list is invalid, it should be a list")
self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]),
- MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR))
+ MsprobeException.err_strs.get(MsprobeException.INVALID_PARAM_ERROR))
json_config.update({"scope": ["Tensor_Add"]})
json_config.update({"list": ["Tensor_Add"]})
@@ -149,4 +149,4 @@ class TestCommonConfig(TestCase):
self.assertEqual(mock_error_log_with_exp.call_args[0][0],
"data_mode is invalid, it should be a list")
self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]),
- MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR))
+ MsprobeException.err_strs.get(MsprobeException.INVALID_PARAM_ERROR))
diff --git a/debug/accuracy_tools/atat/test/core_ut/test_file_check.py b/debug/accuracy_tools/msprobe/test/core_ut/test_file_check.py
similarity index 85%
rename from debug/accuracy_tools/atat/test/core_ut/test_file_check.py
rename to debug/accuracy_tools/msprobe/test/core_ut/test_file_check.py
index aa7882aa5906b8e472495ede42653f9f5584f573..c3f7836bf177c283b246e061d2df61f682175c38 100644
--- a/debug/accuracy_tools/atat/test/core_ut/test_file_check.py
+++ b/debug/accuracy_tools/msprobe/test/core_ut/test_file_check.py
@@ -19,10 +19,10 @@ import os
from unittest import TestCase
from unittest.mock import patch, MagicMock
-from atat.core.common.log import logger
-from atat.core.common.const import FileCheckConst
-from atat.core.common.exceptions import FileCheckException
-from atat.core.common.file_check import (check_link,
+from msprobe.core.common.log import logger
+from msprobe.core.common.const import FileCheckConst
+from msprobe.core.common.exceptions import FileCheckException
+from msprobe.core.common.file_utils import (check_link,
check_path_length,
check_path_exists,
check_path_readability,
@@ -40,7 +40,7 @@ from atat.core.common.file_check import (check_link,
class TestFileCheckUtil(TestCase):
@patch.object(logger, "error")
def test_check_link(self, mock_logger_error):
- with patch("atat.core.common.file_check.os.path.islink", return_value=True):
+ with patch("msprobe.core.common.file_utils.os.path.islink", return_value=True):
with self.assertRaises(FileCheckException) as context:
check_link("link_path")
self.assertEqual(str(context.exception),
@@ -72,7 +72,7 @@ class TestFileCheckUtil(TestCase):
@patch.object(logger, "error")
def test_check_path_exists(self, mock_logger_error):
- with patch("atat.core.common.file_check.os.path.exists", return_value=False):
+ with patch("msprobe.core.common.file_utils.os.path.exists", return_value=False):
with self.assertRaises(FileCheckException) as context:
check_path_exists("file_path")
self.assertEqual(str(context.exception),
@@ -82,7 +82,7 @@ class TestFileCheckUtil(TestCase):
@patch.object(logger, "error")
def test_check_path_readability(self, mock_logger_error):
path = "file_path"
- with patch("atat.core.common.file_check.os.access", return_value=False):
+ with patch("msprobe.core.common.file_utils.os.access", return_value=False):
with self.assertRaises(FileCheckException) as context:
check_path_readability(path)
self.assertEqual(str(context.exception),
@@ -91,14 +91,14 @@ class TestFileCheckUtil(TestCase):
mock_access = MagicMock()
mock_access.return_value = True
- with patch("atat.core.common.file_check.os.access", new=mock_access):
+ with patch("msprobe.core.common.file_utils.os.access", new=mock_access):
check_path_readability(path)
self.assertEqual(mock_access.call_args[0], (path, os.R_OK))
@patch.object(logger, "error")
def test_check_path_writability(self, mock_logger_error):
path = "file_path"
- with patch("atat.core.common.file_check.os.access", return_value=False):
+ with patch("msprobe.core.common.file_utils.os.access", return_value=False):
with self.assertRaises(FileCheckException) as context:
check_path_writability(path)
self.assertEqual(str(context.exception),
@@ -107,14 +107,14 @@ class TestFileCheckUtil(TestCase):
mock_access = MagicMock()
mock_access.return_value = True
- with patch("atat.core.common.file_check.os.access", new=mock_access):
+ with patch("msprobe.core.common.file_utils.os.access", new=mock_access):
check_path_writability(path)
self.assertEqual(mock_access.call_args[0], (path, os.W_OK))
@patch.object(logger, "error")
def test_check_path_executable(self, mock_logger_error):
path = "file_path"
- with patch("atat.core.common.file_check.os.access", return_value=False):
+ with patch("msprobe.core.common.file_utils.os.access", return_value=False):
with self.assertRaises(FileCheckException) as context:
check_path_executable(path)
self.assertEqual(str(context.exception),
@@ -123,7 +123,7 @@ class TestFileCheckUtil(TestCase):
mock_access = MagicMock()
mock_access.return_value = True
- with patch("atat.core.common.file_check.os.access", new=mock_access):
+ with patch("msprobe.core.common.file_utils.os.access", new=mock_access):
check_path_executable(path)
self.assertEqual(mock_access.call_args[0], (path, os.X_OK))
@@ -135,7 +135,7 @@ class TestFileCheckUtil(TestCase):
path = "file_path"
mock_stat = TestStat(0o002)
- with patch("atat.core.common.file_check.os.stat", return_value=mock_stat):
+ with patch("msprobe.core.common.file_utils.os.stat", return_value=mock_stat):
with self.assertRaises(FileCheckException) as context:
check_other_user_writable(path)
self.assertEqual(str(context.exception),
@@ -147,7 +147,7 @@ class TestFileCheckUtil(TestCase):
def test_check_path_owner_consistent(self, mock_logger_error):
file_path = os.path.realpath(__file__)
file_owner = os.stat(file_path).st_uid
- with patch("atat.core.common.file_check.os.getuid", return_value=file_owner+1):
+ with patch("msprobe.core.common.file_utils.os.getuid", return_value=file_owner+1):
with self.assertRaises(FileCheckException) as context:
check_path_owner_consistent(file_path)
self.assertEqual(str(context.exception),
@@ -160,7 +160,7 @@ class TestFileCheckUtil(TestCase):
path = "path"
mock_re_match = MagicMock()
mock_re_match.return_value = False
- with patch("atat.core.common.file_check.re.match", new=mock_re_match):
+ with patch("msprobe.core.common.file_utils.re.match", new=mock_re_match):
with self.assertRaises(FileCheckException) as context:
check_path_pattern_vaild(path)
self.assertEqual(str(context.exception),
@@ -181,8 +181,8 @@ class TestFileCheckUtil(TestCase):
def test_check_common_file_size(self):
mock_check_file_size = MagicMock()
- with patch("atat.core.common.file_check.os.path.isfile", return_value=True), \
- patch("atat.core.common.file_check.check_file_size", new=mock_check_file_size):
+ with patch("msprobe.core.common.file_utils.os.path.isfile", return_value=True), \
+ patch("msprobe.core.common.file_utils.check_file_size", new=mock_check_file_size):
for suffix, max_size in FileCheckConst.FILE_SIZE_DICT.items():
check_common_file_size(suffix)
mock_check_file_size.assert_called_with(suffix, max_size)
@@ -201,16 +201,16 @@ class TestFileCheckUtil(TestCase):
def test_check_path_type(self, mock_logger_error):
file_path = "file_path"
- with patch("atat.core.common.file_check.os.path.isfile", return_value=False), \
- patch("atat.core.common.file_check.os.path.isdir", return_value=True):
+ with patch("msprobe.core.common.file_utils.os.path.isfile", return_value=False), \
+ patch("msprobe.core.common.file_utils.os.path.isdir", return_value=True):
with self.assertRaises(FileCheckException) as context:
check_path_type(file_path, FileCheckConst.FILE)
self.assertEqual(str(context.exception),
FileCheckException.err_strs.get(FileCheckException.INVALID_FILE_ERROR))
mock_logger_error.assert_called_with(f"The {file_path} should be a file!")
- with patch("atat.core.common.file_check.os.path.isfile", return_value=True), \
- patch("atat.core.common.file_check.os.path.isdir", return_value=False):
+ with patch("msprobe.core.common.file_utils.os.path.isfile", return_value=True), \
+ patch("msprobe.core.common.file_utils.os.path.isdir", return_value=False):
with self.assertRaises(FileCheckException) as context:
check_path_type(file_path, FileCheckConst.DIR)
self.assertEqual(str(context.exception),
diff --git a/debug/accuracy_tools/atat/test/core_ut/test_log.py b/debug/accuracy_tools/msprobe/test/core_ut/test_log.py
similarity index 92%
rename from debug/accuracy_tools/atat/test/core_ut/test_log.py
rename to debug/accuracy_tools/msprobe/test/core_ut/test_log.py
index 6d7998d5ae068bf5044eb097b92cd5dbc6971e0a..1687c48d025a72a9237a561cc3bea0dd473f6fbd 100644
--- a/debug/accuracy_tools/atat/test/core_ut/test_log.py
+++ b/debug/accuracy_tools/msprobe/test/core_ut/test_log.py
@@ -17,11 +17,11 @@
from unittest import TestCase
from unittest.mock import patch, MagicMock
-from atat.core.common.log import BaseLogger, logger
+from msprobe.core.common.log import BaseLogger, logger
class TestLog(TestCase):
- @patch("atat.core.common.log.print")
+ @patch("msprobe.core.common.log.print")
def test__print_log(self, mock_print):
logger._print_log("level", "msg")
self.assertIn("[level] msg", mock_print.call_args[0][0])
@@ -75,7 +75,7 @@ class TestLog(TestCase):
@patch.object(BaseLogger, "get_rank")
def test_info_on_rank_0(self, mock_get_rank):
mock_print = MagicMock()
- with patch("atat.core.common.log.print", new=mock_print):
+ with patch("msprobe.core.common.log.print", new=mock_print):
mock_get_rank.return_value = 0
logger.info_on_rank_0("msg")
self.assertIn("[INFO] msg", mock_print.call_args[0][0])
@@ -87,7 +87,7 @@ class TestLog(TestCase):
@patch.object(BaseLogger, "get_rank")
def test_error_on_rank_0(self, mock_get_rank):
mock_print = MagicMock()
- with patch("atat.core.common.log.print", new=mock_print):
+ with patch("msprobe.core.common.log.print", new=mock_print):
mock_get_rank.return_value = 0
logger.error_on_rank_0("msg")
self.assertIn("[ERROR] msg", mock_print.call_args[0][0])
@@ -99,7 +99,7 @@ class TestLog(TestCase):
@patch.object(BaseLogger, "get_rank")
def test_warning_on_rank_0(self, mock_get_rank):
mock_print = MagicMock()
- with patch("atat.core.common.log.print", new=mock_print):
+ with patch("msprobe.core.common.log.print", new=mock_print):
mock_get_rank.return_value = 0
logger.warning_on_rank_0("msg")
self.assertIn("[WARNING] msg", mock_print.call_args[0][0])
diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_api_kbk_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_api_kbk_dump.py
similarity index 75%
rename from debug/accuracy_tools/atat/test/mindspore_ut/test_api_kbk_dump.py
rename to debug/accuracy_tools/msprobe/test/mindspore_ut/test_api_kbk_dump.py
index 47d60999b16a16d7593559c581354d4674438343..7411018ff08507f0ab867b6394aa1c08b5f26469 100644
--- a/debug/accuracy_tools/atat/test/mindspore_ut/test_api_kbk_dump.py
+++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_api_kbk_dump.py
@@ -19,9 +19,9 @@ import os
from unittest import TestCase
from unittest.mock import patch
-from atat.core.common_config import CommonConfig, BaseConfig
-from atat.mindspore.debugger.debugger_config import DebuggerConfig
-from atat.mindspore.dump.api_kbk_dump import ApiKbkDump
+from msprobe.core.common_config import CommonConfig, BaseConfig
+from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
+from msprobe.mindspore.dump.api_kbk_dump import ApiKbkDump
class TestApiKbkDump(TestCase):
@@ -42,10 +42,10 @@ class TestApiKbkDump(TestCase):
self.assertEqual(dumper.dump_json["common_dump_settings"]["iteration"], "0|2")
os.environ["MS_ACL_DUMP_CFG_PATH"] = "path"
- with patch("atat.mindspore.dump.api_kbk_dump.make_dump_path_if_not_exists"), \
- patch("atat.mindspore.dump.api_kbk_dump.FileOpen"), \
- patch("atat.mindspore.dump.api_kbk_dump.json.dump"), \
- patch("atat.mindspore.dump.api_kbk_dump.logger.info"):
+ with patch("msprobe.mindspore.dump.api_kbk_dump.make_dump_path_if_not_exists"), \
+ patch("msprobe.mindspore.dump.api_kbk_dump.FileOpen"), \
+ patch("msprobe.mindspore.dump.api_kbk_dump.json.dump"), \
+ patch("msprobe.mindspore.dump.api_kbk_dump.logger.info"):
dumper.handle()
self.assertEqual(os.environ.get("GRAPH_OP_RUN"), "1")
self.assertEqual(os.environ.get("MS_ACL_DUMP_CFG_PATH"), None)
diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_debugger_config.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_debugger_config.py
similarity index 87%
rename from debug/accuracy_tools/atat/test/mindspore_ut/test_debugger_config.py
rename to debug/accuracy_tools/msprobe/test/mindspore_ut/test_debugger_config.py
index 3bdf341c3979aa0f7e98cab468dcdcb090215137..5187d3951c0cbf2bdeb5db6f402933c8bf08e94d 100644
--- a/debug/accuracy_tools/atat/test/mindspore_ut/test_debugger_config.py
+++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_debugger_config.py
@@ -16,9 +16,9 @@
"""
from unittest import TestCase
-from atat.core.common.const import Const
-from atat.core.common_config import CommonConfig, BaseConfig
-from atat.mindspore.debugger.debugger_config import DebuggerConfig
+from msprobe.core.common.const import Const
+from msprobe.core.common_config import CommonConfig, BaseConfig
+from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
class TestDebuggerConfig(TestCase):
@@ -27,7 +27,7 @@ class TestDebuggerConfig(TestCase):
"dump_path": "/absolute_path",
"rank": [],
"step": [],
- "level": "L1"
+ "level": "L0"
}
common_config = CommonConfig(json_config)
task_config = BaseConfig(json_config)
diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_dump_tool_factory.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_dump_tool_factory.py
similarity index 89%
rename from debug/accuracy_tools/atat/test/mindspore_ut/test_dump_tool_factory.py
rename to debug/accuracy_tools/msprobe/test/mindspore_ut/test_dump_tool_factory.py
index f6626f551fec02438e39ed474374654201e6204c..fb88d7bbbf328b0b8a61b11d41808b756881510e 100644
--- a/debug/accuracy_tools/atat/test/mindspore_ut/test_dump_tool_factory.py
+++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_dump_tool_factory.py
@@ -16,9 +16,9 @@
"""
from unittest import TestCase
-from atat.core.common_config import CommonConfig, BaseConfig
-from atat.mindspore.debugger.debugger_config import DebuggerConfig
-from atat.mindspore.dump.dump_tool_factory import DumpToolFactory
+from msprobe.core.common_config import CommonConfig, BaseConfig
+from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
+from msprobe.mindspore.dump.dump_tool_factory import DumpToolFactory
class TestDumpToolFactory(TestCase):
diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_kernel_graph_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_graph_dump.py
similarity index 80%
rename from debug/accuracy_tools/atat/test/mindspore_ut/test_kernel_graph_dump.py
rename to debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_graph_dump.py
index 6c59521a17d57585170753f4d935b6109b920595..e691a2c7edde2feb1f2c1d60fbba275724bb9092 100644
--- a/debug/accuracy_tools/atat/test/mindspore_ut/test_kernel_graph_dump.py
+++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_graph_dump.py
@@ -19,9 +19,9 @@ import os
from unittest import TestCase
from unittest.mock import patch
-from atat.core.common_config import CommonConfig, BaseConfig
-from atat.mindspore.debugger.debugger_config import DebuggerConfig
-from atat.mindspore.dump.kernel_graph_dump import KernelGraphDump
+from msprobe.core.common_config import CommonConfig, BaseConfig
+from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
+from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump
class TestKernelGraphDump(TestCase):
@@ -45,10 +45,10 @@ class TestKernelGraphDump(TestCase):
self.assertEqual(dumper.dump_json["common_dump_settings"]["file_format"], "bin")
self.assertEqual(dumper.dump_json["common_dump_settings"]["input_output"], 2)
- with patch("atat.mindspore.dump.kernel_graph_dump.make_dump_path_if_not_exists"), \
- patch("atat.mindspore.dump.kernel_graph_dump.FileOpen"), \
- patch("atat.mindspore.dump.kernel_graph_dump.json.dump"), \
- patch("atat.mindspore.dump.kernel_graph_dump.logger.info"):
+ with patch("msprobe.mindspore.dump.kernel_graph_dump.make_dump_path_if_not_exists"), \
+ patch("msprobe.mindspore.dump.kernel_graph_dump.FileOpen"), \
+ patch("msprobe.mindspore.dump.kernel_graph_dump.json.dump"), \
+ patch("msprobe.mindspore.dump.kernel_graph_dump.logger.info"):
os.environ["GRAPH_OP_RUN"] = "1"
with self.assertRaises(Exception) as context:
diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_kernel_graph_overflow_check.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py
similarity index 76%
rename from debug/accuracy_tools/atat/test/mindspore_ut/test_kernel_graph_overflow_check.py
rename to debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py
index 101482458dc0a901e0066937ae6df9e9a23fe4fc..a93fab021ab59beff9016895a1748eb942274e30 100644
--- a/debug/accuracy_tools/atat/test/mindspore_ut/test_kernel_graph_overflow_check.py
+++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py
@@ -19,9 +19,9 @@ import os
from unittest import TestCase
from unittest.mock import patch
-from atat.core.common_config import CommonConfig, BaseConfig
-from atat.mindspore.debugger.debugger_config import DebuggerConfig
-from atat.mindspore.overflow_check.kernel_graph_overflow_check import KernelGraphOverflowCheck
+from msprobe.core.common_config import CommonConfig, BaseConfig
+from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
+from msprobe.mindspore.overflow_check.kernel_graph_overflow_check import KernelGraphOverflowCheck
class TestKernelGraphOverflowCheck(TestCase):
@@ -43,10 +43,10 @@ class TestKernelGraphOverflowCheck(TestCase):
self.assertEqual(checker.dump_json["common_dump_settings"]["op_debug_mode"], 2)
os.environ["MS_ACL_DUMP_CFG_PATH"] = "path"
- with patch("atat.mindspore.overflow_check.kernel_graph_overflow_check.make_dump_path_if_not_exists"), \
- patch("atat.mindspore.overflow_check.kernel_graph_overflow_check.FileOpen"), \
- patch("atat.mindspore.overflow_check.kernel_graph_overflow_check.json.dump"), \
- patch("atat.mindspore.overflow_check.kernel_graph_overflow_check.logger.info"):
+ with patch("msprobe.mindspore.overflow_check.kernel_graph_overflow_check.make_dump_path_if_not_exists"), \
+ patch("msprobe.mindspore.overflow_check.kernel_graph_overflow_check.FileOpen"), \
+ patch("msprobe.mindspore.overflow_check.kernel_graph_overflow_check.json.dump"), \
+ patch("msprobe.mindspore.overflow_check.kernel_graph_overflow_check.logger.info"):
os.environ["GRAPH_OP_RUN"] = "1"
with self.assertRaises(Exception) as context:
diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_config.py
similarity index 83%
rename from debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py
rename to debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_config.py
index 3dc3670128c1c6b74b0f67fe43b008df340dc858..30212d95e621bea516b888e7e61d042990a2c93a 100644
--- a/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py
+++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_config.py
@@ -17,9 +17,9 @@
from unittest import TestCase
from unittest.mock import patch, mock_open
-from atat.core.common.const import Const
-from atat.mindspore.ms_config import (parse_json_config, parse_task_config,
- TensorConfig, StatisticsConfig, OverflowCheck)
+from msprobe.core.common.const import Const
+from msprobe.mindspore.ms_config import (parse_json_config, parse_task_config,
+ TensorConfig, StatisticsConfig, OverflowCheckConfig)
class TestMsConfig(TestCase):
@@ -37,8 +37,8 @@ class TestMsConfig(TestCase):
"summary_mode": "statistics"
}
}
- with patch("atat.mindspore.ms_config.FileOpen", mock_open(read_data='')), \
- patch("atat.mindspore.ms_config.json.load", return_value=mock_json_data):
+ with patch("msprobe.mindspore.ms_config.FileOpen", mock_open(read_data='')), \
+ patch("msprobe.mindspore.ms_config.json.load", return_value=mock_json_data):
common_config, task_config = parse_json_config("./config.json")
self.assertEqual(common_config.task, Const.STATISTICS)
self.assertEqual(task_config.data_mode, ["all"])
@@ -62,7 +62,7 @@ class TestMsConfig(TestCase):
self.assertTrue(isinstance(task_config, StatisticsConfig))
task_config = parse_task_config("overflow_check", mock_json_config)
- self.assertTrue(isinstance(task_config, OverflowCheck))
+ self.assertTrue(isinstance(task_config, OverflowCheckConfig))
with self.assertRaises(Exception) as context:
parse_task_config("free_benchmark", mock_json_config)
diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_overflow_check_tool_factory.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py
similarity index 88%
rename from debug/accuracy_tools/atat/test/mindspore_ut/test_overflow_check_tool_factory.py
rename to debug/accuracy_tools/msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py
index 497fe1376abcff0607d6afd3f2b03d94f963bcd7..47da051d4fdd1d9b65ef8c6092a7b05a3e2263b6 100644
--- a/debug/accuracy_tools/atat/test/mindspore_ut/test_overflow_check_tool_factory.py
+++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py
@@ -16,9 +16,9 @@
"""
from unittest import TestCase
-from atat.core.common_config import CommonConfig, BaseConfig
-from atat.mindspore.debugger.debugger_config import DebuggerConfig
-from atat.mindspore.overflow_check.overflow_check_tool_factory import OverflowCheckToolFactory
+from msprobe.core.common_config import CommonConfig, BaseConfig
+from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
+from msprobe.mindspore.overflow_check.overflow_check_tool_factory import OverflowCheckToolFactory
class TestOverflowCheckToolFactory(TestCase):
diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_precision_debugger.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_precision_debugger.py
similarity index 79%
rename from debug/accuracy_tools/atat/test/mindspore_ut/test_precision_debugger.py
rename to debug/accuracy_tools/msprobe/test/mindspore_ut/test_precision_debugger.py
index 834a58e41a426d975ac97b0f757db5a0432a297f..425ed3040dcc829927a8f4cbb25024f1b567a48f 100644
--- a/debug/accuracy_tools/atat/test/mindspore_ut/test_precision_debugger.py
+++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_precision_debugger.py
@@ -17,9 +17,9 @@
from unittest import TestCase
from unittest.mock import patch
-from atat.core.common_config import CommonConfig, BaseConfig
-from atat.mindspore.debugger.debugger_config import DebuggerConfig
-from atat.mindspore.debugger.precision_debugger import PrecisionDebugger
+from msprobe.core.common_config import CommonConfig, BaseConfig
+from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
+from msprobe.mindspore.debugger.precision_debugger import PrecisionDebugger
class TestPrecisionDebugger(TestCase):
@@ -35,16 +35,16 @@ class TestPrecisionDebugger(TestCase):
"dump_path": "/absolute_path",
"rank": [],
"step": [],
- "level": "L1"
+ "level": "L0"
}
common_config = CommonConfig(json_config)
task_config = BaseConfig(json_config)
handler = Handler()
- with patch("atat.mindspore.debugger.precision_debugger.parse_json_config",
+ with patch("msprobe.mindspore.debugger.precision_debugger.parse_json_config",
return_value=[common_config, task_config]), \
- patch("atat.mindspore.debugger.precision_debugger.TaskHandlerFactory.create", return_value=handler):
+ patch("msprobe.mindspore.debugger.precision_debugger.TaskHandlerFactory.create", return_value=handler):
debugger = PrecisionDebugger()
debugger.start()
self.assertTrue(isinstance(debugger.config, DebuggerConfig))
diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_task_handler_factory.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_task_handler_factory.py
similarity index 82%
rename from debug/accuracy_tools/atat/test/mindspore_ut/test_task_handler_factory.py
rename to debug/accuracy_tools/msprobe/test/mindspore_ut/test_task_handler_factory.py
index 02cd9934cb1635b9cec68bcd4b54dd2753ea23e9..41be7b1db6c7d723aaeec1607f564ac3d772b404 100644
--- a/debug/accuracy_tools/atat/test/mindspore_ut/test_task_handler_factory.py
+++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_task_handler_factory.py
@@ -17,10 +17,10 @@
from unittest import TestCase
from unittest.mock import patch
-from atat.core.common_config import CommonConfig, BaseConfig
-from atat.mindspore.debugger.debugger_config import DebuggerConfig
-from atat.mindspore.dump.kernel_graph_dump import KernelGraphDump
-from atat.mindspore.task_handler_factory import TaskHandlerFactory
+from msprobe.core.common_config import CommonConfig, BaseConfig
+from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
+from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump
+from msprobe.mindspore.task_handler_factory import TaskHandlerFactory
class TestTaskHandlerFactory(TestCase):
@@ -47,7 +47,7 @@ class TestTaskHandlerFactory(TestCase):
handler = TaskHandlerFactory.create(config)
self.assertTrue(isinstance(handler, KernelGraphDump))
- with patch("atat.mindspore.task_handler_factory.TaskHandlerFactory.tasks", new=tasks):
+ with patch("msprobe.mindspore.task_handler_factory.TaskHandlerFactory.tasks", new=tasks):
with self.assertRaises(Exception) as context:
TaskHandlerFactory.create(config)
self.assertEqual(str(context.exception), "Can not find task handler")
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/advisor/test_advisor.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/advisor/test_advisor.py
similarity index 85%
rename from debug/accuracy_tools/atat/test/pytorch_ut/advisor/test_advisor.py
rename to debug/accuracy_tools/msprobe/test/pytorch_ut/advisor/test_advisor.py
index 78e5b489e7ad14f2965b813d733a30c8849b8a71..176b80068f70e60a06a6eed77b23b35e8b48a50d 100644
--- a/debug/accuracy_tools/atat/test/pytorch_ut/advisor/test_advisor.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/advisor/test_advisor.py
@@ -2,12 +2,13 @@ import difflib
import os
import shutil
import unittest
+import logging
from unittest.mock import patch
import pandas
-from atat.pytorch.advisor.advisor import Advisor
-from atat.pytorch.advisor.advisor_const import AdvisorConst
+from msprobe.pytorch.advisor.advisor import Advisor
+from msprobe.pytorch.advisor.advisor_const import AdvisorConst
class TestAdvisor(unittest.TestCase):
@@ -70,11 +71,11 @@ class TestAdvisor(unittest.TestCase):
output_content = out_file.read().splitlines()
result = list(difflib.unified_diff(standard_content, output_content, n=0))
if result:
- print('\n\n-------------------------------------------------------------------------', flush=True)
- print(f'[ERROR] {output_file.replace(self.output_path, "")} advisor summary are inconsistent.',
- flush=True)
- print('\n'.join(result), flush=True)
- print('-------------------------------------------------------------------------', flush=True)
+ logging.basicConfig(level=logging.INFO)
+ logging.info('\n\n-------------------------------------------------------------------------')
+ logging.error(f'[ERROR] {output_file.replace(self.output_path, "")} advisor summary are inconsistent.')
+ logging.error('\n'.join(result))
+ logging.info('\n\n-------------------------------------------------------------------------')
self.has_error = True
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py
similarity index 96%
rename from debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py
rename to debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py
index 16d0c0bc12738bb7ce129224cf124761503c31fd..56d100f0a1b570c0c0a1753db3c79aacea7b76ac 100644
--- a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py
@@ -1,12 +1,12 @@
import unittest
from unittest.mock import patch
-from atat.pytorch.api_accuracy_checker.common.utils import *
+from msprobe.pytorch.api_accuracy_checker.common.utils import *
class TestUtils(unittest.TestCase):
- @patch('atat.pytorch.api_accuracy_checker.common.utils.get_file_content_bytes')
+ @patch('msprobe.pytorch.api_accuracy_checker.common.utils.get_file_content_bytes')
def test_get_json_contents_should_raise_exception(self, mock_get_file_content_bytes):
mock_get_file_content_bytes.return_value = 'not a dict'
with self.assertRaises(CompareException) as ce:
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/common/test_config.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py
similarity index 92%
rename from debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/common/test_config.py
rename to debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py
index 066e74aa518dd7958a511bb47be92bce7ce5ac0b..35fc6164763e685d09e737e7f85bec33623ec111 100644
--- a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/common/test_config.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py
@@ -2,7 +2,7 @@ import unittest
import os
from unittest.mock import patch
-from atat.pytorch.api_accuracy_checker.common.config import Config
+from msprobe.pytorch.api_accuracy_checker.common.config import Config
class TestConfig(unittest.TestCase):
@@ -35,5 +35,5 @@ class TestConfig(unittest.TestCase):
validate_white_list = ['conv1d', 'max_pool1d', 'dropout', '__add__']
self.assertEqual(self.cfg.validate('white_list', validate_white_list), validate_white_list)
- with self.assertRaises(ValueError):
+ with self.assertRaises(Exception):
self.cfg.validate('white_list', ['invalid_api1', 'invalid_api2'])
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py
similarity index 98%
rename from debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py
rename to debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py
index 9604e7a681c869c41cdfa70b7b1b551ceed9604e..35a8b9f1fa52f689905986ca477c8a7077a084da 100644
--- a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py
@@ -2,7 +2,7 @@ import unittest
import numpy as np
-from atat.pytorch.api_accuracy_checker.compare import algorithm as alg
+from msprobe.pytorch.api_accuracy_checker.compare import algorithm as alg
class TestAlgorithmMethods(unittest.TestCase):
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py
similarity index 94%
rename from debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py
rename to debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py
index 7717d826577cbc3ade87c6188e1049c7455df905..540460d0896532bbe242c3eb30b3ee945bae9571 100644
--- a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py
@@ -2,14 +2,14 @@ import unittest
import pandas as pd
-from atat.pytorch.api_accuracy_checker.compare.api_precision_compare import (
+from msprobe.pytorch.api_accuracy_checker.compare.api_precision_compare import (
CompareConfig,
BenchmarkStandard,
check_csv_columns,
check_error_rate,
get_api_checker_result,
)
-from atat.core.common.const import CompareConst
+from msprobe.core.common.const import CompareConst
class TestApiPrecisionCompare(unittest.TestCase):
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py
similarity index 96%
rename from debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py
rename to debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py
index 2c97471c7ab2aec4a6db25d6d8dca3cd07001cda..e1e6d51de292cf4d8b617ab73db67ff4920bfac3 100644
--- a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py
@@ -7,9 +7,9 @@ import unittest
import numpy as np
import torch.nn.functional
-from atat.pytorch.api_accuracy_checker.compare.compare import Comparator
-from atat.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
-from atat.pytorch.api_accuracy_checker.run_ut.run_ut import UtDataInfo
+from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
+from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
+from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import UtDataInfo
current_time = time.strftime("%Y%m%d%H%M%S")
RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py
similarity index 68%
rename from debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py
rename to debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py
index ee25a25e74d18dd0cc5436747767aa2acbbab05e..782321868a8cbcae9ffed3b215ca068968b1c0ae 100644
--- a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py
@@ -1,6 +1,6 @@
import unittest
-from atat.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn
+from msprobe.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn
class TestCompareColumns(unittest.TestCase):
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py
similarity index 88%
rename from debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py
rename to debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py
index 93f3c2c73e14130f479198a32a6b4f2f5c64745d..ac9c974ea3ecf6a835ce448d754582d435548ed8 100644
--- a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py
@@ -2,8 +2,8 @@ import unittest
import numpy as np
-from atat.pytorch.api_accuracy_checker.common.utils import CompareException
-from atat.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, convert_str_to_float
+from msprobe.pytorch.api_accuracy_checker.common.utils import CompareException
+from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, convert_str_to_float
class TestCompareUtils(unittest.TestCase):
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json
similarity index 100%
rename from debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json
rename to debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json
similarity index 100%
rename from debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json
rename to debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py
similarity index 96%
rename from debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py
rename to debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py
index f47c71c984f7e03449c8b1bcee11d4e81e64f842..f664dad197f6bbaaed3d574f657552377b176dec 100644
--- a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py
@@ -3,8 +3,8 @@ import os
import unittest
import copy
-from atat.pytorch.api_accuracy_checker.run_ut.data_generate import *
-from atat.pytorch.api_accuracy_checker.common.utils import get_json_contents
+from msprobe.pytorch.api_accuracy_checker.run_ut.data_generate import *
+from msprobe.pytorch.api_accuracy_checker.common.utils import get_json_contents
base_dir = os.path.dirname(os.path.realpath(__file__))
forward_file = os.path.join(base_dir, "forward.json")
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py
similarity index 83%
rename from debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py
rename to debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py
index 6a9071f15ea982ba8b515c4e95683be586b39aef..27126cdddda215ea521b5090782cc1c85f07e5f0 100644
--- a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py
@@ -5,7 +5,7 @@ import logging
from unittest.mock import patch, mock_open, MagicMock
import json
import signal
-from atat.pytorch.api_accuracy_checker.run_ut.multi_run_ut import split_json_file, signal_handler, run_parallel_ut, \
+from msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut import split_json_file, signal_handler, run_parallel_ut, \
prepare_config, main, ParallelUTConfig
@@ -20,7 +20,7 @@ class TestMultiRunUT(unittest.TestCase):
{'key3': 'TRUE', 'key4': 'TRUE'}
]
- @patch('atat.pytorch.api_accuracy_checker.run_ut.multi_run_ut.FileOpen')
+ @patch('msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut.FileOpen')
def test_split_json_file(self, mock_FileOpen):
mock_FileOpen.return_value.__enter__.return_value = mock_open(read_data=self.test_json_content).return_value
num_splits = 2
@@ -48,7 +48,7 @@ class TestMultiRunUT(unittest.TestCase):
device_id=[0, 1],
result_csv_path='result.csv',
total_items=2,
- real_data_path=None
+ config_path=None
)
mock_file.side_effect = [
@@ -63,10 +63,10 @@ class TestMultiRunUT(unittest.TestCase):
@patch('os.remove')
@patch('os.path.realpath', side_effect=lambda x: x)
- @patch('atat.pytorch.api_accuracy_checker.run_ut.multi_run_ut.check_link')
- @patch('atat.pytorch.api_accuracy_checker.run_ut.multi_run_ut.check_file_suffix')
- @patch('atat.pytorch.api_accuracy_checker.run_ut.multi_run_ut.FileChecker')
- @patch('atat.pytorch.api_accuracy_checker.run_ut.multi_run_ut.split_json_file',
+ @patch('msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut.check_link')
+ @patch('msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut.check_file_suffix')
+ @patch('msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut.FileChecker')
+ @patch('msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut.split_json_file',
return_value=(['forward_split1.json', 'forward_split2.json'], 2))
def test_prepare_config(self, mock_split_json_file, mock_FileChecker, mock_check_file_suffix, mock_check_link,
mock_realpath, mock_remove):
@@ -81,7 +81,7 @@ class TestMultiRunUT(unittest.TestCase):
args.jit_compile = False
args.device_id = [0, 1]
args.result_csv_path = None
- args.real_data_path = None
+ args.config_path = None
config = prepare_config(args)
@@ -93,8 +93,8 @@ class TestMultiRunUT(unittest.TestCase):
@patch('argparse.ArgumentParser.parse_args')
- @patch('atat.pytorch.api_accuracy_checker.run_ut.multi_run_ut.prepare_config')
- @patch('atat.pytorch.api_accuracy_checker.run_ut.multi_run_ut.run_parallel_ut')
+ @patch('msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut.prepare_config')
+ @patch('msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut.run_parallel_ut')
def test_main(self, mock_run_parallel_ut, mock_prepare_config, mock_parse_args):
main()
mock_parse_args.assert_called()
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py
similarity index 95%
rename from debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py
rename to debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py
index 97dccd2b58f7005cdd8c3907ff150d9df7f9ff3d..bc643794ab692ff5d21bcf412450f845d01f662d 100644
--- a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py
@@ -4,8 +4,8 @@ import copy
import unittest
import torch
from unittest.mock import patch, DEFAULT
-from atat.pytorch.api_accuracy_checker.run_ut.run_ut import *
-from atat.pytorch.api_accuracy_checker.common.utils import get_json_contents
+from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import *
+from msprobe.pytorch.api_accuracy_checker.common.utils import get_json_contents
base_dir = os.path.dirname(os.path.realpath(__file__))
forward_file = os.path.join(base_dir, "forward.json")
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_acc_compare.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_acc_compare.py
new file mode 100644
index 0000000000000000000000000000000000000000..288e259c0aae104a62054af3813b7831ec7722f7
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_acc_compare.py
@@ -0,0 +1,267 @@
+# coding=utf-8
+import unittest
+import pandas as pd
+from msprobe.pytorch.compare import acc_compare as compare
+
+npu_dict = {'op_name': ['Functional_conv2d_0_forward_input.0', 'Functional_conv2d_0_forward_input.1',
+ 'Functional_conv2d_0_forward_input.2', 'Functional_conv2d_0_forward_output'],
+ 'input_struct': [('torch.float32', [1, 1, 28, 28]), ('torch.float32', [16, 1, 5, 5]),
+ ('torch.float32', [16])],
+ 'output_struct': [('torch.float32', [1, 16, 28, 28])],
+ 'summary': [[3.029174327850342, -2.926689624786377, -0.06619918346405029],
+ [0.19919930398464203, -0.19974489510059357, 0.006269412115216255],
+ [0.19734230637550354, -0.18177609145641327, 0.007903944700956345],
+ [2.1166646480560303, -2.190781354904175, -0.003579073818400502]], 'stack_info': []}
+
+bench_dict = {'op_name': ['Functional_conv2d_0_forward_input.0', 'Functional_conv2d_0_forward_input.1',
+ 'Functional_conv2d_0_forward_input.2', 'Functional_conv2d_0_forward_output'],
+ 'input_struct': [('torch.float32', [1, 1, 28, 28]), ('torch.float32', [16, 1, 5, 5]),
+ ('torch.float32', [16])],
+ 'output_struct': [('torch.float32', [1, 16, 28, 28])],
+ 'summary': [[3.029174327850342, -2.926689624786377, -0.06619918346405029],
+ [0.19919930398464203, -0.19974489510059357, 0.006269412115216255],
+ [0.19734230637550354, -0.18177609145641327, 0.007903944700956345],
+ [2.1166646480560303, -2.190781354904175, -0.003579073818400502]], 'stack_info': []}
+
+tensor_list = [
+ {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3], 'Max': 0.33033010363578796,
+ 'Min': -0.331031858921051,'Mean': -0.030964046716690063, 'Norm': 2.2533628940582275, 'requires_grad': True,
+ 'full_op_name': 'Tensor.add_.0.forward_input.0'},
+ {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3],
+ 'Max': 0.003992878366261721, 'Min': -0.008102823048830032, 'Mean': -0.0002002553956117481,
+ 'Norm': 0.02844562754034996, 'requires_grad': False, 'full_op_name': 'Tensor.add_.0.forward_input.1'},
+ {'full_op_name': 'Tensor.add_.0.forward_input.alpha.0', 'dtype': "", "shape": '[]', 'md5': None,
+ 'Max': -0.1, 'Min': -0.1, 'Mean': -0.1, 'Norm': -0.1, 'data_name': '-1'},
+ {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3],
+ 'Max': 0.33033010363578796, 'Min': -0.331031858921051, 'Mean': -0.030964046716690063,
+ 'Norm': 2.2533628940582275, 'requires_grad': True, 'full_op_name': 'Tensor.add_.0.forward_output.0'}
+]
+
+result_op_dict = {'op_name': ['Tensor.add_.0.forward_input.0', 'Tensor.add_.0.forward_input.1',
+ 'Tensor.add_.0.forward_input.alpha.0', 'Tensor.add_.0.forward_output.0'],
+ 'input_struct': [('torch.float32', [16, 1, 3, 3]), ('torch.float32', [16, 1, 3, 3]),
+ ("", '[]')],
+ 'output_struct': [('torch.float32', [16, 1, 3, 3])],
+ 'summary': [[0.33033010363578796, -0.331031858921051, -0.030964046716690063, 2.2533628940582275],
+ [0.003992878366261721, -0.008102823048830032, -0.0002002553956117481, 0.02844562754034996],
+ [-0.1, -0.1, -0.1, -0.1],
+ [0.33033010363578796, -0.331031858921051, -0.030964046716690063, 2.2533628940582275]],
+ 'stack_info': []}
+
+o_result = [
+ ['Functional_conv2d_0_forward_input.0', 'Functional_conv2d_0_forward_input.0', 'torch.float32', 'torch.float32',
+ [1, 1, 28, 28], [1, 1, 28, 28], 0.0, 0.0, 0.0, ' ', '0.0%', '0.0%', '0.0%', ' ', 3.029174327850342, -2.926689624786377,
+ -0.06619918346405029, 3.029174327850342, -2.926689624786377, -0.06619918346405029, '', '', 'None'],
+ ['Functional_conv2d_0_forward_input.1', 'Functional_conv2d_0_forward_input.1', 'torch.float32', 'torch.float32',
+ [16, 1, 5, 5], [16, 1, 5, 5], 0.0, 0.0, 0.0, ' ', '0.0%', '0.0%', '0.0%', ' ', 0.19919930398464203, -0.19974489510059357,
+ 0.006269412115216255, 0.19919930398464203, -0.19974489510059357, 0.006269412115216255, '', '', 'None'],
+ ['Functional_conv2d_0_forward_input.2', 'Functional_conv2d_0_forward_input.2', 'torch.float32', 'torch.float32',
+ [16], [16], 0.0, 0.0, 0.0, ' ', '0.0%', '0.0%', '0.0%', ' ', 0.19734230637550354, -0.18177609145641327, 0.007903944700956345,
+ 0.19734230637550354, -0.18177609145641327, 0.007903944700956345, '', '', 'None'],
+ ['Functional_conv2d_0_forward_output', 'Functional_conv2d_0_forward_output', 'torch.float32', 'torch.float32',
+ [1, 16, 28, 28], [1, 16, 28, 28], 0.0, 0.0, 0.0, ' ', '0.0%', '0.0%', '0.0%', ' ', 2.1166646480560303, -2.190781354904175,
+ -0.003579073818400502, 2.1166646480560303, -2.190781354904175, -0.003579073818400502, '', '', 'None']]
+
+npu_dict_aten = {'op_name': ['Aten__native_batch_norm_legit_functional.default_0_forward_input.0',
+ 'Aten__native_batch_norm_legit_functional.default_0_forward_input.1',
+ 'Aten__native_batch_norm_legit_functional.default_0_forward_input.2',
+ 'Aten__native_batch_norm_legit_functional.default_0_forward_input.3',
+ 'Aten__native_batch_norm_legit_functional.default_0_forward_input.4',
+ 'Aten__native_batch_norm_legit_functional.default_0_forward_output.0',
+ 'Aten__native_batch_norm_legit_functional.default_0_forward_output.1',
+ 'Aten__native_batch_norm_legit_functional.default_0_forward_output.2',
+ 'Aten__native_batch_norm_legit_functional.default_0_forward_output.3',
+ 'Aten__native_batch_norm_legit_functional.default_0_forward_output.4'],
+ 'input_struct': [('torch.float16', [256, 256, 14, 14]), ('torch.float32', [256]),
+ ('torch.float32', [256]), ('torch.float32', [256]), ('torch.float32', [256])],
+ 'output_struct': [('torch.float16', [256, 256, 14, 14]), ('torch.float32', [256]),
+ ('torch.float32', [256]), ('torch.float32', [256]), ('torch.float32', [256])],
+ 'summary': [[139.625, -127.5625, -0.0103607177734375],
+ [2.5276029109954834, -2.1788690090179443, -0.0008259844034910202],
+ [2.472219944000244, -2.845968723297119, -0.008756577968597412],
+ [2.763145923614502, -3.398397922515869, -0.052132632583379745],
+ [2.673110008239746, -3.149275064468384, 0.01613386906683445],
+ [13.5546875, -10.640625, -0.008758544921875],
+ [0.30550330877304077, -0.24485322833061218, -0.010361209511756897],
+ [623.9192504882812, 432.96826171875, 520.2276611328125],
+ [2.4797861576080322, -3.055997371673584, -0.04795549064874649],
+ [61.7945556640625, 42.59713363647461, 52.03831481933594]]}
+
+bench_dict_functional = {
+ 'op_name': ['Functional_batch_norm_0_forward_input.0', 'Functional_batch_norm_0_forward_input.1',
+ 'Functional_batch_norm_0_forward_input.2', 'Functional_batch_norm_0_forward_input.3',
+ 'Functional_batch_norm_0_forward_input.4', 'Functional_batch_norm_0_forward_output'],
+ 'input_struct': [('torch.float32', [256, 256, 14, 14]), ('torch.float32', [256]), ('torch.float32', [256]),
+ ('torch.float32', [256]), ('torch.float32', [256])],
+ 'output_struct': [('torch.float32', [256, 256, 14, 14])],
+ 'summary': [[3.061628818511963, -3.22507381439209, 3.634914173744619e-05],
+ [0.0005779837374575436, -0.0006301702815108001, 3.634906533989124e-06],
+ [0.9338104128837585, 0.9277191162109375, 0.930335283279419],
+ [1.0, 1.0, 1.0], [0.0, 0.0, 0.0],
+ [5.397906303405762, -5.796811580657959, 2.5283952709287405e-10]]
+}
+
+aten_result = [
+ ['Aten__native_batch_norm_legit_functional.default_0_forward_input.0', 'Functional_batch_norm_0_forward_input.0',
+ 'torch.float16', 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 136.56337118148804, -124.33742618560791,
+ -0.010397066915174946, ' ', '4460.480981749501%', '3855.335826136584%', '28603.33536971545%', ' ', 139.625,
+ -127.5625, -0.0103607177734375, 3.061628818511963, -3.22507381439209, 3.634914173744619e-05, 'Warning',
+ 'Need double check api accuracy.', 'None'],
+ ['Aten__native_batch_norm_legit_functional.default_0_forward_input.1', 'Functional_batch_norm_0_forward_input.1',
+ 'torch.float32', 'torch.float32', [256], [256], 2.527024927258026, -2.1782388387364335, -0.0008296193100250093,
+ ' ', '437213.84590749856%', '345658.76916858414%', '22823.676544842117%', ' ', 2.5276029109954834,
+ -2.1788690090179443, -0.0008259844034910202, 0.0005779837374575436, -0.0006301702815108001, 3.634906533989124e-06,
+ 'Warning', 'Need double check api accuracy.', 'None'],
+ ['Aten__native_batch_norm_legit_functional.default_0_forward_input.2', 'Functional_batch_norm_0_forward_input.2',
+ 'torch.float32', 'torch.float32', [256], [256], 1.5384095311164856, -3.7736878395080566, -0.9390918612480164, ' ',
+ '164.74538192025793%', '406.7705163736246%', '100.94122819224167%', ' ', 2.472219944000244, -2.845968723297119,
+ -0.008756577968597412, 0.9338104128837585, 0.9277191162109375, 0.930335283279419, 'Warning',
+ 'Need double check api accuracy.', 'None'],
+ ['Aten__native_batch_norm_legit_functional.default_0_forward_input.3', 'Functional_batch_norm_0_forward_input.3',
+ 'torch.float32', 'torch.float32', [256], [256], 1.763145923614502, -4.398397922515869, -1.0521326325833797, ' ',
+ '176.3145923614502%', '439.8397922515869%', '105.21326325833797%', ' ', 2.763145923614502, -3.398397922515869,
+ -0.052132632583379745, 1.0, 1.0, 1.0, 'Warning', 'Need double check api accuracy.', 'None'],
+ ['Aten__native_batch_norm_legit_functional.default_0_forward_input.4', 'Functional_batch_norm_0_forward_input.4',
+ 'torch.float32', 'torch.float32', [256], [256], 2.673110008239746, -3.149275064468384, 0.01613386906683445, ' ',
+ 'N/A', 'N/A', 'N/A', ' ', 2.673110008239746, -3.149275064468384, 0.01613386906683445, 0.0, 0.0, 0.0, 'Warning',
+ 'Need double check api accuracy.', 'None'],
+ ['Aten__native_batch_norm_legit_functional.default_0_forward_output.0', 'Functional_batch_norm_0_forward_output',
+ 'torch.float16', 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 8.156781196594238, -4.843813419342041,
+ -0.008758545174714527, ' ', '151.11009228611078%', '83.55995967687207%', '3464072756.115108%', ' ', 13.5546875,
+ -10.640625, -0.008758544921875, 5.397906303405762, -5.796811580657959, 2.5283952709287405e-10, 'Warning',
+ 'Need double check api accuracy.', 'None'],
+ ['Aten__native_batch_norm_legit_functional.default_0_forward_output.1', 'Nan', 'torch.float32', 'Nan', [256], 'Nan',
+ ' ', ' ', ' ', ' ', ' ', 0.30550330877304077, -0.24485322833061218, -0.010361209511756897, 'Nan', 'Nan', 'Nan',
+ 'Yes', '', 'None'],
+ ['Aten__native_batch_norm_legit_functional.default_0_forward_output.2', 'Nan', 'torch.float32', 'Nan', [256], 'Nan',
+ ' ', ' ', ' ', ' ', ' ', 623.9192504882812, 432.96826171875, 520.2276611328125, 'Nan', 'Nan', 'Nan',
+ 'Yes', '', 'None'],
+ ['Aten__native_batch_norm_legit_functional.default_0_forward_output.3', 'Nan', 'torch.float32', 'Nan', [256], 'Nan',
+ ' ', ' ', ' ', ' ', ' ', 2.4797861576080322, -3.055997371673584, -0.04795549064874649, 'Nan', 'Nan', 'Nan',
+ 'Yes', '', 'None'],
+ ['Aten__native_batch_norm_legit_functional.default_0_forward_output.4', 'Nan', 'torch.float32', 'Nan', [256], 'Nan',
+ ' ', ' ', ' ', ' ', ' ', 61.7945556640625, 42.59713363647461, 52.03831481933594, 'Nan', 'Nan', 'Nan',
+ 'Yes', '', 'None']]
+
+highlight_dict = {'red_rows': [], 'yellow_rows': []}
+
+num_0, num_1, num_2, num_3 = 0, 1, 2, 3
+summary_line_input = ['Functional_batch_norm_0_forward_input.0', 'Functional_batch_norm_0_forward_input.0',
+ 'torch.float16',
+ 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.01, 0, 0, 0, 1, 1, 1, 1, 1.01, 1, 1, 1,
+ 'Yes', '']
+summary_line_1 = ['Functional_batch_norm_0_forward_output.0', 'Functional_batch_norm_0_forward_output.0',
+ 'torch.float16',
+ 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 10, 0, 0, 0, 2, 0, 1, 1, 1, 1, 1, 1,
+ 'Warning', '']
+summary_line_2 = ['Functional_batch_norm_0_forward_output.1', 'Functional_batch_norm_0_forward_output.1',
+ 'torch.float16',
+ 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.02, 0, 0, 0, 0.12, 0, 1, 1, 0.1, 1, 1, 1,
+ 'Warning', '']
+summary_line_3 = ['Functional_batch_norm_0_forward_output.2', 'Functional_batch_norm_0_forward_output.2',
+ 'torch.float16',
+ 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0, 0, 0, 0, 2, 0, 1, 1, 1, 1, 1, 1,
+ 'Warning', '']
+line_input = ['Functional_batch_norm_0_forward_input.0', 'Functional_batch_norm_0_forward_input.0', 'torch.float16',
+ 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 1, 1, 1, 0.95, 1, 1, 1, 1, 1, 1.01, 1, 1, 1,
+ 'Yes', '']
+line_1 = ['Functional_batch_norm_0_forward_output.0', 'Functional_batch_norm_0_forward_output.0', 'torch.float16',
+ 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.8, 1, 1, 0.59, 1, 'nan', 0, 1, 1, 19, 1, 1, 1,
+ 'Warning', '']
+line_2 = ['Functional_batch_norm_0_forward_output.1', 'Functional_batch_norm_0_forward_output.1', 'torch.float16',
+ 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.9, 1, 1, 0.8, 1, 0, 0.12, 0, 1, 1, 0.1, 1, 1, 1,
+ 'Warning', '']
+line_3 = ['Functional_batch_norm_0_forward_output.2', 'Functional_batch_norm_0_forward_output.2', 'torch.float16',
+ 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.8, 1.1e+10, 1, 0.85, 1, 9, 0.12, 0, 1, 1, 0.1, 1,
+ 1, 1, 'Warning', '']
+
+op_data = {
+ 'input_args': [{'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3],
+ 'Max': 0.33033010363578796, 'Min': -0.331031858921051, 'Mean': -0.030964046716690063,
+ 'Norm': 2.2533628940582275, 'requires_grad': True},
+ {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3],
+ 'Max': 0.003992878366261721, 'Min': -0.008102823048830032, 'Mean': -0.0002002553956117481,
+ 'Norm': 0.02844562754034996, 'requires_grad': False}],
+ 'input_kwargs': {'alpha': {'type': 'float', 'value': -0.1}},
+ 'output': [{'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3],
+ 'Max': 0.33033010363578796, 'Min': -0.331031858921051, 'Mean': -0.030964046716690063,
+ 'Norm': 2.2533628940582275, 'requires_grad': True}]}
+
+op_name = "Tensor.add_0.0.forward"
+
+op_result = [
+ {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3],
+ 'Max': 0.33033010363578796, 'Min': -0.331031858921051, 'Mean': -0.030964046716690063,
+ 'Norm': 2.2533628940582275, 'requires_grad': True, 'full_op_name': 'Tensor.add_0.0.forward_input.0'},
+ {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3],
+ 'Max': 0.003992878366261721, 'Min': -0.008102823048830032, 'Mean': -0.0002002553956117481,
+ 'Norm': 0.02844562754034996, 'requires_grad': False, 'full_op_name': 'Tensor.add_0.0.forward_input.1'},
+ {'full_op_name': 'Tensor.add_0.0.forward_input.alpha.0', 'dtype': "", 'shape': '[]', 'md5': None,
+ 'Max': -0.1, 'Min': -0.1, 'Mean': -0.1, 'Norm': -0.1, 'data_name': '-1'},
+ {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3],
+ 'Max': 0.33033010363578796, 'Min': -0.331031858921051, 'Mean': -0.030964046716690063,
+ 'Norm': 2.2533628940582275, 'requires_grad': True, 'full_op_name': 'Tensor.add_0.0.forward_output.0'}]
+
+
+class TestUtilsMethods(unittest.TestCase):
+
+ def test_check_graph_mode(self):
+ op1 = "Aten"
+ op2 = "torch"
+ self.assertTrue(compare.check_graph_mode(op1, op2))
+ self.assertTrue(compare.check_graph_mode(op2, op1))
+ self.assertFalse(compare.check_graph_mode(op1, op1))
+ self.assertFalse(compare.check_graph_mode(op2, op2))
+
+ def test_check_op(self):
+ fuzzy_match = False
+ result = compare.check_op(npu_dict, bench_dict, fuzzy_match)
+ self.assertEqual(result, True)
+
+ def test_merge_tensor(self):
+ op_dict = compare.merge_tensor(tensor_list, True, False)
+ self.assertEqual(op_dict, result_op_dict)
+
+ def test_read_op(self):
+ result = compare.read_op(op_data, op_name)
+ self.assertEqual(result, op_result)
+
+ def test_match_op(self):
+ fuzzy_match = False
+ a, b = compare.match_op([npu_dict], [bench_dict], fuzzy_match)
+ self.assertEqual(a, 0)
+ self.assertEqual(b, 0)
+
+ def test_get_accuracy(self):
+ result = []
+ compare.get_accuracy(result, npu_dict, bench_dict, highlight_dict)
+ self.assertEqual(result, o_result)
+
+ def test_get_accuracy_graph_mode(self):
+ result = []
+ compare.get_accuracy(result, npu_dict_aten, bench_dict_functional, highlight_dict)
+ self.assertEqual(result, aten_result)
+
+ def test_find_error_rows(self):
+ summary_result = [summary_line_input, summary_line_1, summary_line_2, summary_line_3]
+ highlight_dict = {'red_rows': [], 'yellow_rows': []}
+ compare.find_error_rows(summary_result, 0, 1, highlight_dict, summary_compare=True)
+ self.assertEqual(highlight_dict, {'red_rows': [], 'yellow_rows': []})
+
+ def test_find_compare_result_error_rows(self):
+ result = [line_input, line_1, line_2, line_3]
+ result_df = pd.DataFrame(result)
+ highlight_dict = {'red_rows': [], 'yellow_rows': []}
+ compare.find_compare_result_error_rows(result_df, highlight_dict, False, False)
+ self.assertEqual(highlight_dict, {'red_rows': [num_1, num_3], 'yellow_rows': [num_2]})
+
+ def test_rename_api(self):
+ test_name_1 = "Distributed.broadcast.0.forward.input.0"
+ expect_name_1 = "Distributed.broadcast.input.0"
+ actual_name_1 = compare.rename_api(test_name_1, "forward")
+ self.assertEqual(actual_name_1, expect_name_1)
+
+ test_name_2 = "Torch.sum.0.backward.output.0"
+ expect_name_2 = "Torch.sum.output.0"
+ actual_name_2 = compare.rename_api(test_name_2, "backward")
+ self.assertEqual(actual_name_2, expect_name_2)
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_match.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_match.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac28e994e9c8e77f8ae675fec3322eaf64a64321
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_match.py
@@ -0,0 +1,20 @@
+# coding=utf-8
+import unittest
+from msprobe.pytorch.compare import match
+
+
+class TestMatch(unittest.TestCase):
+ def test_graph_mapping(self):
+ op1 = "Aten_convolution_1_forward_0.input.0"
+ op2 = "Torch_conv2d_0_forward_0.input.0"
+ op3 = "Torch_batch_norm_0_forward_0.input.0"
+ op4 = "Aten_convolution.default_1_forward_0.input.0"
+ op5 = "Aten_foo_1_forward_0.input.0"
+ self.assertTrue(match.graph_mapping.match(op1, op2))
+ self.assertTrue(match.graph_mapping.match(op2, op1))
+ self.assertTrue(match.graph_mapping.match(op4, op2))
+ self.assertTrue(match.graph_mapping.match(op2, op4))
+ self.assertFalse(match.graph_mapping.match(op1, op3))
+ self.assertFalse(match.graph_mapping.match(op3, op1))
+ self.assertFalse(match.graph_mapping.match(op5, op2))
+ self.assertFalse(match.graph_mapping.match(op2, op5))
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py
similarity index 94%
rename from debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py
rename to debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py
index 828d646c52f363e758a10939dbfcd98d942eb9f2..ad9eb5cd0ed5b2eaaff9745c8af9ced8dc1ab883 100644
--- a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py
@@ -1,10 +1,10 @@
from unittest import TestCase
import torch
-from atat.core.common.const import Const
-from atat.pytorch.free_benchmark.common.enums import DeviceType, PerturbationMode
-from atat.pytorch.free_benchmark.common.params import data_pre_deal
-from atat.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory
+from msprobe.core.common.const import Const
+from msprobe.pytorch.free_benchmark.common.enums import DeviceType, PerturbationMode
+from msprobe.pytorch.free_benchmark.common.params import data_pre_deal
+from msprobe.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory
class TestPerturbedLayer(TestCase):
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py
similarity index 91%
rename from debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py
rename to debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py
index d46e26e09488d481ac9e96a8b4002dd3b62446bd..399efeb42d7cd7e7e34dd472cd8a9d82c26a5b5e 100644
--- a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py
@@ -2,17 +2,17 @@ from abc import ABC
from unittest import TestCase
import torch
-from atat.core.common.const import Const
-from atat.pytorch.free_benchmark.common.constant import PreheatConfig, ThresholdConfig
-from atat.pytorch.free_benchmark.common.counter import preheat_counter
-from atat.pytorch.free_benchmark.common.enums import (
+from msprobe.core.common.const import Const
+from msprobe.pytorch.free_benchmark.common.constant import PreheatConfig, ThresholdConfig
+from msprobe.pytorch.free_benchmark.common.counter import preheat_counter
+from msprobe.pytorch.free_benchmark.common.enums import (
DeviceType,
FuzzLevel,
HandlerType,
PerturbationMode,
)
-from atat.pytorch.free_benchmark.common.params import DataParams, make_handler_params
-from atat.pytorch.free_benchmark.result_handlers.handler_factory import (
+from msprobe.pytorch.free_benchmark.common.params import DataParams, make_handler_params
+from msprobe.pytorch.free_benchmark.result_handlers.handler_factory import (
FuzzHandlerFactory,
)
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/test_main.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/test_main.py
similarity index 92%
rename from debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/test_main.py
rename to debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/test_main.py
index d326e993c07d66a1baf5ae785ed4b519624bb982..4498a2af7054edd89aa6fae6a057a489216794b6 100644
--- a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/test_main.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/test_main.py
@@ -4,10 +4,10 @@ from unittest import TestCase
import torch
import torch.nn as nn
-from atat.core.common.const import Const
-from atat.pytorch.free_benchmark import FreeBenchmarkCheck
-from atat.pytorch.free_benchmark.common.constant import CommonField, PreheatConfig
-from atat.pytorch.free_benchmark.common.enums import (
+from msprobe.core.common.const import Const
+from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck
+from msprobe.pytorch.free_benchmark.common.constant import CommonField, PreheatConfig
+from msprobe.pytorch.free_benchmark.common.enums import (
DeviceType,
FuzzLevel,
HandlerType,
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/functional/test_dump_module.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/functional/test_dump_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..d67adf2f91292391ff01d450bb5647524f6fc9c4
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/functional/test_dump_module.py
@@ -0,0 +1,15 @@
+import unittest
+
+import torch.nn as nn
+from msprobe.pytorch import PrecisionDebugger
+from msprobe.pytorch.functional.dump_module import module_dump, module_count
+
+
+class TestDumpModule(unittest.TestCase):
+ def setUp(self):
+ self.module = nn.Linear(in_features=8, out_features=4)
+
+ def test_module_dump(self):
+ PrecisionDebugger(dump_path="./dump")
+ module_dump(self.module, "TestModule")
+ self.assertTrue("TestModule" in module_count)
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_api_registry.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_api_registry.py
similarity index 91%
rename from debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_api_registry.py
rename to debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_api_registry.py
index c80e5dbed456fd7df9c574a14f96e16b886e8a3e..837ad23df76be2a012a7408dab4879847937f229 100644
--- a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_api_registry.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_api_registry.py
@@ -1,5 +1,5 @@
import unittest
-from atat.pytorch.hook_module.api_registry import ApiRegistry, torch_version_above_2, is_gpu
+from msprobe.pytorch.hook_module.api_registry import ApiRegistry, torch_version_above_2, is_gpu
class TestApiRegistry(unittest.TestCase):
@@ -43,7 +43,7 @@ class TestApiRegistry(unittest.TestCase):
import torch
import torch.distributed as dist
#import torch_npu  # torch_npu is not installed in the CI environment
- from atat.pytorch.hook_module.api_registry import torch_without_guard_version, npu_distributed_api, is_gpu, torch_version_above_2
+ from msprobe.pytorch.hook_module.api_registry import torch_without_guard_version, npu_distributed_api, is_gpu, torch_version_above_2
@@ -79,7 +79,7 @@ class TestApiRegistry(unittest.TestCase):
import torch
import torch.distributed as dist
#import torch_npu  # torch_npu is not installed in the CI environment
- from atat.pytorch.hook_module.api_registry import torch_without_guard_version, npu_distributed_api, is_gpu, torch_version_above_2
+ from msprobe.pytorch.hook_module.api_registry import torch_without_guard_version, npu_distributed_api, is_gpu, torch_version_above_2
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_hook_module.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_hook_module.py
similarity index 94%
rename from debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_hook_module.py
rename to debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_hook_module.py
index 646f64152267702d10eb0ff9aa2bce4e497a9f35..50783e5d736c024b03f20008ad6b72882eddcd87 100644
--- a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_hook_module.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_hook_module.py
@@ -1,7 +1,7 @@
import unittest
from unittest.mock import patch, Mock
-from atat.pytorch.hook_module.hook_module import HOOKModule
+from msprobe.pytorch.hook_module.hook_module import HOOKModule
class TestHookModule(unittest.TestCase):
def test_call_1(self):
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_aten.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py
similarity index 96%
rename from debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_aten.py
rename to debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py
index 92aee790ddd4d28e652cc063209028a1e4e5b3d1..4940b07cb0d8e9d2283db3daebc910a7fdcd6ce9 100644
--- a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_aten.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py
@@ -1,6 +1,6 @@
import unittest
import torch
-from atat.pytorch.hook_module.wrap_aten import AtenOPTemplate, AtenOPPacketTemplate
+from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate, AtenOPPacketTemplate
def hook(name):
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_distributed.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py
similarity index 95%
rename from debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_distributed.py
rename to debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py
index bd0501ef2fdfa90b00f9049c00aceb03e5225831..9a375e45bfcdc93ac36fb9d44a79f50fea7932d5 100644
--- a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_distributed.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py
@@ -1,6 +1,6 @@
import unittest
import torch.distributed as dist
-from atat.pytorch.hook_module.wrap_distributed import *
+from msprobe.pytorch.hook_module.wrap_distributed import *
class TestWrapDistributed(unittest.TestCase):
def hook(name, prefix):
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_functional.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py
similarity index 91%
rename from debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_functional.py
rename to debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py
index 232117498b5900b1d7d178208b81430f86b52eb4..f43b8ea6cb98dd1947811b7b0641439b225b51ec 100644
--- a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_functional.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py
@@ -1,6 +1,6 @@
import unittest
import torch
-from atat.pytorch.hook_module import wrap_functional as wf
+from msprobe.pytorch.hook_module import wrap_functional as wf
class TestWrapFunctional(unittest.TestCase):
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_tensor.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py
similarity index 88%
rename from debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_tensor.py
rename to debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py
index e027270540e6a5ebfc1d3963ac29bae87c9b1110..61f76b0ca0a59ee680ff40991fa9cba5e42d869d 100644
--- a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_tensor.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py
@@ -1,7 +1,7 @@
import unittest
import torch
import yaml
-from atat.pytorch.hook_module.wrap_tensor import get_tensor_ops, HOOKTensor, TensorOPTemplate, wrap_tensor_op, wrap_tensor_ops_and_bind
+from msprobe.pytorch.hook_module.wrap_tensor import get_tensor_ops, HOOKTensor, TensorOPTemplate, wrap_tensor_op, wrap_tensor_ops_and_bind
class TestWrapTensor(unittest.TestCase):
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_torch.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py
similarity index 96%
rename from debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_torch.py
rename to debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py
index 8817bc758ae4b9f76f194e8c191119f56f483ea9..e1a3e77983d80e7c0519e30afbb592311550e794 100644
--- a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_torch.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py
@@ -1,7 +1,7 @@
import unittest
import torch
import yaml
-from atat.pytorch.hook_module.wrap_torch import *
+from msprobe.pytorch.hook_module.wrap_torch import *
class TestWrapTorch(unittest.TestCase):
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_vf.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py
similarity index 82%
rename from debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_vf.py
rename to debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py
index 8d57fad6eb623569acfbbffe28dc9decb1ca30b6..98efb4bc5b8a30284fe820124e48af7f487d1c54 100644
--- a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_vf.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py
@@ -1,6 +1,6 @@
import unittest
import torch
-from atat.pytorch.hook_module import wrap_vf
+from msprobe.pytorch.hook_module import wrap_vf
class TestWrapVF(unittest.TestCase):
def setUp(self):
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_config.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..470390d77b2f233fb97f0916eed6d60c0b0c10ef
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_config.py
@@ -0,0 +1,84 @@
+from unittest import TestCase
+from unittest.mock import patch, mock_open
+
+from msprobe.core.common.const import Const
+from msprobe.pytorch.pt_config import parse_json_config, parse_task_config
+
+
+class TestPtConfig(TestCase):
+ def test_parse_json_config(self):
+ mock_json_data = {
+ "task": "statistics",
+ "dump_path": "./dump/",
+ "rank": [],
+ "step": [],
+ "level": "L1",
+ "seed": 1234,
+ "statistics": {
+ "scope": [],
+ "list": [],
+ "data_mode": ["all"],
+ },
+ "tensor": {
+ "file_format": "npy"
+ }
+ }
+ with patch("msprobe.pytorch.pt_config.os.path.join", return_value="/path/config.json"), \
+ patch("msprobe.pytorch.pt_config.FileOpen", mock_open(read_data='')), \
+ patch("msprobe.pytorch.pt_config.json.load", return_value=mock_json_data):
+ common_config, task_config = parse_json_config(None, None)
+ self.assertEqual(common_config.task, Const.STATISTICS)
+ self.assertEqual(task_config.data_mode, ["all"])
+
+ with patch("msprobe.pytorch.pt_config.os.path.join", return_value="/path/config.json"), \
+ patch("msprobe.pytorch.pt_config.FileOpen", mock_open(read_data='')), \
+ patch("msprobe.pytorch.pt_config.json.load", return_value=mock_json_data):
+ common_config, task_config = parse_json_config(None, Const.TENSOR)
+ self.assertEqual(common_config.task, Const.STATISTICS)
+ self.assertEqual(task_config.file_format, "npy")
+
+ def test_parse_task_config(self):
+ overflow_check_config = {
+ "overflow_check": {
+ "overflow_nums": 1,
+ "check_mode": "all"
+ }
+ }
+ result = parse_task_config(Const.OVERFLOW_CHECK, overflow_check_config)
+ self.assertEqual(result.overflow_nums, 1)
+ self.assertEqual(result.check_mode, "all")
+
+ free_benchmark_config = {
+ "free_benchmark": {
+ "scope": [],
+ "list": ["conv2d"],
+ "fuzz_device": "npu",
+ "pert_mode": "improve_precision",
+ "handler_type": "check",
+ "fuzz_level": "L1",
+ "fuzz_stage": "forward",
+ "if_preheat": False,
+ "preheat_step": 15,
+ "max_sample": 20
+ }
+ }
+ result = parse_task_config(Const.FREE_BENCHMARK, free_benchmark_config)
+ self.assertEqual(result.pert_mode, "improve_precision")
+ self.assertEqual(result.handler_type, "check")
+ self.assertEqual(result.preheat_step, 15)
+ self.assertEqual(result.max_sample, 20)
+
+ run_ut_config = {
+ "run_ut": {
+ "white_list": ["conv2d"],
+ "black_list": ["matmul"],
+ "error_data_path": '/home/dump_path'
+
+ }
+ }
+ with patch('os.path.exists', return_value=True) as mocked_exists:
+ result = parse_task_config(Const.RUN_UT, run_ut_config)
+ self.assertEqual(result.white_list, ["conv2d"])
+ self.assertEqual(result.black_list, ["matmul"])
+ self.assertEqual(result.error_data_path, '/home/dump_path')
+ mocked_exists.assert_called_once_with('/home/dump_path')
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_service.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..c09b6abcb693a048e360ed0d783a0c817c76b06f
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_service.py
@@ -0,0 +1,59 @@
+import unittest
+from unittest.mock import patch, mock_open
+
+import torch.nn as nn
+from msprobe.core.common.utils import Const
+from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
+from msprobe.pytorch.pt_config import parse_json_config
+from msprobe.pytorch.service import Service
+
+
+class TestService(unittest.TestCase):
+ def setUp(self):
+ mock_json_data = {
+ "dump_path": "./dump/",
+ }
+ with patch("msprobe.pytorch.pt_config.FileOpen", mock_open(read_data='')), \
+ patch("msprobe.pytorch.pt_config.json.load", return_value=mock_json_data):
+ common_config, task_config = parse_json_config("./config.json", Const.STATISTICS)
+ self.config = DebuggerConfig(common_config, task_config, Const.STATISTICS, "./ut_dump", "L1")
+ self.service = Service(self.config)
+
+ def test_start(self):
+ with patch("msprobe.pytorch.service.get_rank_if_initialized", return_value=0), \
+ patch("msprobe.pytorch.service.Service.create_dirs", return_value=None):
+ self.service.start(None)
+ self.assertEqual(self.service.current_rank, 0)
+
+ def test_stop_and_step(self):
+ with patch("msprobe.core.data_dump.data_collector.DataCollector.write_json", return_value=None):
+ self.service.stop()
+ self.assertFalse(self.service.switch)
+
+ self.service.step()
+ self.assertEqual(self.service.current_iter, 1)
+
+ def test_register_hook_new(self):
+ class TestModule(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+ self.linear = nn.Linear(in_features=8, out_features=4)
+
+ def forward(self, x):
+ x = self.linear(x)
+ return x
+
+ self.service.model = TestModule()
+ self.config.level = "L0"
+ with patch("msprobe.pytorch.service.logger.info_on_rank_0") as mock_logger, \
+ patch("msprobe.pytorch.service.remove_dropout", return_value=None):
+ self.service.register_hook_new()
+ self.assertEqual(mock_logger.call_count, 2)
+
+ def test_create_dirs(self):
+ with patch("msprobe.pytorch.service.Path.mkdir", return_value=None), \
+ patch("msprobe.core.common.file_check.FileChecker.common_check", return_value=None), \
+ patch("msprobe.core.data_dump.data_collector.DataCollector.update_dump_paths",
+ return_value=None):
+ self.service.create_dirs()
+ self.assertEqual(self.service.dump_iter_dir, "./ut_dump/step0")
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/builder/test_graph_builder.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/builder/test_graph_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..9433bc136395773e118810b6154d906670b633f4
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/builder/test_graph_builder.py
@@ -0,0 +1,99 @@
+import unittest
+from unittest.mock import MagicMock, patch
+from msprobe.pytorch.visualization.builder.graph_builder import GraphBuilder, Graph, BaseNode, NodeOp
+
+
+class TestGraphBuilder(unittest.TestCase):
+
+ def setUp(self):
+ self.construct_path = "step/rank/construct.json"
+ self.data_path = "step/rank/dump.json"
+ self.model_name = "TestModel"
+ self.graph = Graph(self.model_name)
+ self.construct_dict = {
+ "Tensor1": "Module1",
+ "Module1": None
+ }
+ self.data_dict = {
+ "Module1": {"data": "data for Module1"},
+ "Tensor1": {"data": "data for Tensor1"}
+ }
+
+ @patch('msprobe.pytorch.visualization.builder.graph_builder.load_json_file')
+ @patch('msprobe.pytorch.visualization.builder.graph_builder.load_data_json_file')
+ def test_build(self, mock_load_data_json_file, mock_load_json_file):
+ mock_load_data_json_file.return_value = self.data_dict
+ mock_load_json_file.return_value = self.construct_dict
+
+ graph = GraphBuilder.build(self.construct_path, self.data_path, self.model_name)
+ self.assertIsNotNone(graph)
+ self.assertIsInstance(graph, Graph)
+ self.assertEqual(len(graph.node_map), 3)
+
+ @patch('msprobe.pytorch.visualization.builder.graph_builder.save_json_file')
+ def test_to_json(self, mock_save_json_file):
+ GraphBuilder.to_json("step/rank/output.vis", self.graph)
+ mock_save_json_file.assert_called_once()
+
+ @patch('msprobe.pytorch.visualization.graph.node_op.NodeOp.get_node_op')
+ @patch('msprobe.pytorch.visualization.builder.msprobe_adapter.get_input_output', return_value=([], []))
+ def test__init_nodes(self, mock_get_input_output, mock_get_node_op):
+ GraphBuilder._init_nodes(self.graph, self.construct_dict, self.data_dict)
+ mock_get_node_op.assert_any_call("Tensor1")
+ mock_get_node_op.assert_any_call("Module1")
+ self.assertIs(self.graph.root, self.graph.get_node("TestModel"))
+
+ def test__create_or_get_node(self):
+ node_op = MagicMock()
+ data_dict = {"node1": {}}
+ node = GraphBuilder._create_or_get_node(self.graph, data_dict, node_op, "node1")
+ self.assertIn("node1", self.graph.node_map)
+ self.assertEqual(node.input_data, {})
+ self.assertEqual(node.output_data, {})
+
+ def test__handle_backward_upnode_missing(self):
+ construct_dict = {'Module.module.a.forward.0': 'Module.root.forward.0', 'Module.module.a.backward.0': None,
+ 'Module.root.forward.0': None, 'Module.root.backward.0': None,
+ 'Module.module.b.forward.0': 'Module.root.forward.0',
+ 'Module.module.b.backward.0': 'Module.root.backward.0', 'Module.module.c.backward.0': None}
+ node_id_a = GraphBuilder._handle_backward_upnode_missing(construct_dict, 'Module.module.a.backward.0', None)
+ self.assertEqual(node_id_a, 'Module.root.backward.0')
+ node_id_b = GraphBuilder._handle_backward_upnode_missing(construct_dict, 'Module.module.b.backward.0',
+ 'Module.root.backward.0')
+ self.assertEqual(node_id_b, 'Module.root.backward.0')
+ node_id_c = GraphBuilder._handle_backward_upnode_missing(construct_dict, 'Module.module.c.backward.0', None)
+ self.assertIsNone(node_id_c)
+
+ def test__collect_apis_between_modules_only_apis(self):
+ graph = Graph('TestNet')
+ graph.root.subnodes = [BaseNode(NodeOp.function_api, 'Tensor.a.0'), BaseNode(NodeOp.function_api, 'Tensor.b.0')]
+ GraphBuilder._collect_apis_between_modules(graph)
+ self.assertEqual(len(graph.root.subnodes), 1)
+ self.assertEqual(graph.root.subnodes[0].op, NodeOp.api_collection)
+ self.assertEqual(len(graph.root.subnodes[0].subnodes), 2)
+ self.assertEqual(graph.root.subnodes[0].id, 'Apis_Between_Modules.0')
+
+ def test__collect_apis_between_modules_mixed_nodes(self):
+ graph = Graph('TestNet')
+ graph.root.subnodes = [BaseNode(NodeOp.function_api, 'Tensor.a.0'), BaseNode(NodeOp.module, 'Module.a.0'),
+ BaseNode(NodeOp.module, 'Module.b.0'), BaseNode(NodeOp.function_api, 'Tensor.b.0'),
+ BaseNode(NodeOp.function_api, 'Tensor.c.0'), BaseNode(NodeOp.module, 'Module.a.1')]
+ GraphBuilder._collect_apis_between_modules(graph)
+ self.assertEqual(len(graph.root.subnodes), 5)
+ self.assertEqual(graph.root.subnodes[0].op, NodeOp.function_api)
+ self.assertEqual(graph.root.subnodes[1].op, NodeOp.module)
+ self.assertEqual(graph.root.subnodes[3].op, NodeOp.api_collection)
+ self.assertEqual(len(graph.root.subnodes[3].subnodes), 2)
+ self.assertEqual(graph.root.subnodes[3].id, 'Apis_Between_Modules.0')
+
+ def test__collect_apis_between_modules_only_modules(self):
+ graph = Graph('TestNet')
+ graph.root.subnodes = [BaseNode(NodeOp.module, 'Module.a.0'), BaseNode(NodeOp.module, 'Module.b.0'),
+ BaseNode(NodeOp.module, 'Module.a.1')]
+ GraphBuilder._collect_apis_between_modules(graph)
+ self.assertEqual(len(graph.root.subnodes), 3)
+ self.assertEqual(graph.root.subnodes[0].op, NodeOp.module)
+ self.assertEqual(graph.root.subnodes[1].op, NodeOp.module)
+ self.assertEqual(graph.root.subnodes[2].op, NodeOp.module)
+ self.assertEqual(len(graph.root.subnodes[0].subnodes), 0)
+ self.assertEqual(graph.root.subnodes[0].id, 'Module.a.0')
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/builder/test_msprobe_adapter.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/builder/test_msprobe_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f023128b8c7350b2bcefebe6007dc5ef46133e14
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/builder/test_msprobe_adapter.py
@@ -0,0 +1,105 @@
+import unittest
+import torch
+from unittest.mock import patch
+from msprobe.pytorch.visualization.builder.msprobe_adapter import (
+ get_compare_mode,
+ run_real_data,
+ get_input_output,
+ compare_data,
+ format_node_data,
+ compare_node,
+ _format_decimal_string,
+ _format_data,
+ compare_mapping_data
+)
+from msprobe.pytorch.visualization.utils import GraphConst
+
+
+class TestMsprobeAdapter(unittest.TestCase):
+ @patch('msprobe.pytorch.visualization.builder.msprobe_adapter.task_dumppath_get', return_value=(True, False))
+ def test_get_compare_mode_summary(self, mock_task_dumppath_get):
+ mode = get_compare_mode("dummy_param")
+ self.assertEqual(mode, GraphConst.SUMMARY_COMPARE)
+
+ @patch('msprobe.pytorch.visualization.builder.msprobe_adapter._do_multi_process')
+ def test_run_real_data(self, mock_do_multi_process):
+ run_real_data("dump_path", "csv_path")
+ mock_do_multi_process.assert_called_once_with("dump_path", "csv_path")
+
+ def test_get_input_output(self):
+ node_data = {
+ 'input_args': [{'type': 'torch.Tensor', 'dtype': 'torch.int64', 'shape': [5],
+ 'Max': 2049.0, 'Min': 0.0, 'Mean': 410.20001220703125, 'Norm': 2049.0009765625,
+ 'requires_grad': False, 'full_op_name': 'Distributed.broadcast.0.forward_input.0'},
+ {'type': 'int', 'value': 0}],
+ 'input_kwargs': {'group': None},
+ 'output': [{'type': 'torch.Tensor', 'dtype': 'torch.int64', 'shape': [5],
+ 'Max': 2049.0, 'Min': 0.0, 'Mean': 410.20001220703125, 'Norm': 2049.0009765625,
+ 'requires_grad': False, 'full_op_name': 'Distributed.broadcast.0.forward_output.0'},
+ {'type': 'int', 'value': 0}, None]
+ }
+ node_id = "Distributed.broadcast.0.forward"
+ input_data, output_data = get_input_output(node_data, node_id)
+ self.assertIn("Distributed.broadcast.0.forward_output.0", output_data)
+ self.assertIn("Distributed.broadcast.0.forward_input.0", input_data)
+
+ def test_compare_data(self):
+ data_dict_list1 = {'key1': {'type': 'Type1', 'dtype': 'DType1', 'shape': 'Shape1'}}
+ data_dict_list2 = {'key1': {'type': 'Type1', 'dtype': 'DType1', 'shape': 'Shape1'}}
+ data_dict_list3 = {'key1': {'type': 'Type2', 'dtype': 'DType1', 'shape': 'Shape1'}}
+ data_dict_list4 = {}
+ self.assertTrue(compare_data(data_dict_list1, data_dict_list2))
+ self.assertFalse(compare_data(data_dict_list1, data_dict_list3))
+ self.assertFalse(compare_data(data_dict_list1, data_dict_list4))
+
+ def test_format_node_data(self):
+ data_dict = {'node1': {'data_name': 'data1', 'full_op_name': 'op1'}}
+ result = format_node_data(data_dict)
+ self.assertNotIn('data_name', result['node1'])
+ self.assertNotIn('requires_grad', result['node1'])
+
+ @patch('msprobe.pytorch.visualization.builder.msprobe_adapter.get_accuracy')
+ def test_compare_node(self, mock_get_accuracy):
+ node_ids = ["node1", "node2"]
+ data_dicts = [{'node1': {"input_args": [], "input_kwargs": {}, "output": {}}},
+ {'node2': {"input_args": [], "input_kwargs": {}, "output": {}}}]
+ stack_json_data = {}
+ result = compare_node(node_ids, data_dicts, stack_json_data, False, False)
+ mock_get_accuracy.assert_called_once()
+ self.assertIsInstance(result, list)
+
+ def test__format_decimal_string(self):
+ s = "0.123456789%"
+ formatted_s = _format_decimal_string(s)
+ self.assertIn("0.123457%", formatted_s)
+ self.assertEqual('0.123457', _format_decimal_string('0.12345678'))
+ self.assertEqual('-1', _format_decimal_string('-1'))
+ self.assertEqual('0.0.25698548%', _format_decimal_string('0.0.25698548%'))
+
+ def test__format_data(self):
+ data_dict = {'value': 0.123456789, 'value1': None, 'value2': "", 'value3': 1.123123123123e-11,
+ 'value4': torch.inf, 'value5': -1}
+ _format_data(data_dict)
+ self.assertEqual(data_dict['value'], '0.123457')
+ self.assertEqual(data_dict['value1'], 'null')
+ self.assertEqual(data_dict['value2'], '')
+ self.assertEqual(data_dict['value3'], '1.123123e-11')
+ self.assertEqual(data_dict['value4'], 'inf')
+ self.assertEqual(data_dict['value5'], '-1')
+
+ all_none_dict = {'a': None, 'b': None, 'c': None, 'd': None, 'e': None}
+ _format_data(all_none_dict)
+ self.assertEqual({'value': 'null'}, all_none_dict)
+
+ def test_compare_mapping_data(self):
+ dict1 = {'a': {'shape': [1, 2, 3]}, 'b': {'shape': [1, 2, 3]}, 'c': {'shape': [1, 2, 3]}}
+ dict2 = {'a': {'shape': [1, 2, 3]}, 'b': {'shape': [1, 2, 3]}, 'c': {'shape': [1, 2, 3]}}
+ dict3 = {'a': {'shape': [1, 2, 3]}, 'b': {'shape': [1, 2, 3]}}
+ dict4 = {'a': {'shape': [2, 1, 3]}, 'b': {'shape': [1, 2, 3]}}
+ dict5 = {'a': {'shape': [2, 2, 3]}, 'b': {'shape': [1, 2, 3]}}
+ dict6 = {'a': {'type': 'str'}}
+ self.assertTrue(compare_mapping_data(dict1, dict2))
+ self.assertTrue(compare_mapping_data(dict1, dict3))
+ self.assertTrue(compare_mapping_data(dict1, dict4))
+ self.assertFalse(compare_mapping_data(dict1, dict5))
+ self.assertTrue(compare_mapping_data(dict1, dict6))
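The tests above (and especially `test_graph_comparator.py` below) stack several `@patch` decorators on one method. Decorators apply bottom-up, so the decorator closest to the method supplies the *first* mock argument, which is why the parameter lists appear reversed relative to the decorator order. A minimal self-contained sketch of that ordering, using stdlib targets rather than the msprobe modules:

```python
import os
import unittest
from unittest.mock import patch


class TestPatchOrdering(unittest.TestCase):
    # Decorators are applied bottom-up: the decorator nearest the method
    # provides the FIRST mock parameter.
    @patch('os.getcwd', return_value='/fake/cwd')    # -> mock_getcwd (second)
    @patch('os.path.exists', return_value=True)      # -> mock_exists (first)
    def test_order(self, mock_exists, mock_getcwd):
        # Both calls hit the mocks, not the real os functions.
        self.assertTrue(os.path.exists('no/such/path'))
        self.assertEqual(os.getcwd(), '/fake/cwd')
        mock_exists.assert_called_once_with('no/such/path')
```

Getting this ordering wrong silently hands each assertion the wrong mock, which is a common source of confusing `assert_called_once` failures.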
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_graph_comparator.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_graph_comparator.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb69ac7723d9bff42ef62aa93fd62131790b5fc2
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_graph_comparator.py
@@ -0,0 +1,142 @@
+import unittest
+from unittest.mock import patch, MagicMock
+from msprobe.pytorch.visualization.compare.graph_comparator import GraphComparator
+from msprobe.pytorch.visualization.graph.graph import Graph, BaseNode, NodeOp
+from msprobe.pytorch.visualization.utils import GraphConst
+
+
+class TestGraphComparator(unittest.TestCase):
+
+ def setUp(self):
+ self.graphs = [Graph("model1"), Graph("model2")]
+ self.data_paths = ["step1/rank/dump.json", "step2/rank/dump.json"]
+ self.stack_path = "step1/rank/stack.json"
+ self.output_path = "output/output.vis"
+
+ @patch('msprobe.pytorch.visualization.compare.graph_comparator.get_compare_mode')
+ @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_json_file')
+ @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_data_json_file')
+ def test__parse_param(self, mock_load_data_json_file, mock_load_json_file, mock_get_compare_mode):
+ mock_load_data_json_file.return_value = "data_dict"
+ mock_load_json_file.return_value = "construct_dict"
+ mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE
+ self.comparator = GraphComparator(self.graphs, self.data_paths, self.stack_path, self.output_path)
+ self.comparator._parse_param(self.data_paths, self.stack_path, self.output_path)
+
+ self.assertEqual(self.comparator.dump_path_param, {
+ 'npu_json_path': self.data_paths[0],
+ 'bench_json_path': self.data_paths[1],
+ 'stack_json_path': self.stack_path,
+ 'is_print_compare_log': True
+ })
+ self.assertEqual(self.comparator.output_path, self.output_path)
+
+ @patch('msprobe.pytorch.visualization.compare.graph_comparator.get_compare_mode')
+ @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_json_file')
+ @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_data_json_file')
+ def test_compare(self, mock_load_data_json_file, mock_load_json_file, mock_get_compare_mode):
+ mock_load_data_json_file.return_value = "data_dict"
+ mock_load_json_file.return_value = "construct_dict"
+ mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE
+ comparator = GraphComparator(self.graphs, self.data_paths, self.stack_path, self.output_path)
+ comparator._compare_nodes = MagicMock()
+ comparator._postcompare = MagicMock()
+
+ comparator.compare()
+
+ comparator._compare_nodes.assert_called_once()
+ comparator._postcompare.assert_called_once()
+
+ @patch('msprobe.pytorch.visualization.compare.graph_comparator.get_compare_mode')
+ @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_json_file')
+ @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_data_json_file')
+ def test_add_compare_result_to_node(self, mock_load_data_json_file, mock_load_json_file, mock_get_compare_mode):
+ mock_load_data_json_file.return_value = "data_dict"
+ mock_load_json_file.return_value = "construct_dict"
+ mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE
+ node = MagicMock()
+ compare_result_list = [("output1", "data1"), ("input1", "data2")]
+
+ comparator = GraphComparator(self.graphs, self.data_paths, self.stack_path, self.output_path)
+ comparator.ma = MagicMock()
+ comparator.ma.prepare_real_data.return_value = True
+
+ comparator.add_compare_result_to_node(node, compare_result_list)
+ comparator.ma.prepare_real_data.assert_called_once_with(node)
+ node.data.update.assert_not_called()
+
+ @patch('msprobe.pytorch.visualization.graph.node_colors.NodeColors.get_node_error_status')
+ @patch('msprobe.pytorch.visualization.utils.get_csv_df')
+ @patch('msprobe.pytorch.visualization.builder.msprobe_adapter.run_real_data')
+ @patch('msprobe.pytorch.visualization.compare.graph_comparator.get_compare_mode')
+ @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_json_file')
+ @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_data_json_file')
+ def test__postcompare(self, mock_load_data_json_file, mock_load_json_file, mock_get_compare_mode,
+ mock_run_real_data, mock_get_csv_df, mock_get_node_error_status):
+ mock_load_data_json_file.return_value = "data_dict"
+ mock_load_json_file.return_value = "construct_dict"
+ mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE
+ mock_df = MagicMock()
+ mock_df.iterrows = MagicMock(return_value=[(None, MagicMock())])
+ mock_run_real_data.return_value = mock_df
+ mock_get_csv_df.return_value = mock_df
+ mock_get_node_error_status.return_value = True
+ comparator = GraphComparator(self.graphs, self.data_paths, self.stack_path, self.output_path)
+ comparator.ma = MagicMock()
+ comparator.ma.is_real_data_compare.return_value = True
+ comparator._handle_api_collection_index = MagicMock()
+ comparator.ma.compare_nodes = [MagicMock()]
+ comparator.ma.parse_result = MagicMock(return_value=(0.9, None))
+
+ comparator._postcompare()
+
+ comparator._handle_api_collection_index.assert_called_once()
+ comparator.ma.add_error_key.assert_called()
+
+ @patch('msprobe.pytorch.visualization.compare.graph_comparator.get_compare_mode')
+ @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_json_file')
+ @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_data_json_file')
+ def test__handle_api_collection_index(self, mock_load_data_json_file, mock_load_json_file, mock_get_compare_mode):
+ mock_load_data_json_file.return_value = "data_dict"
+ mock_load_json_file.return_value = "construct_dict"
+ mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE
+ comparator = GraphComparator(self.graphs, self.data_paths, self.stack_path, self.output_path)
+ apis = BaseNode(NodeOp.api_collection, 'Apis_Between_Modules.0')
+ api1 = BaseNode(NodeOp.function_api, 'Tensor.a.0')
+ api1.data = {GraphConst.JSON_INDEX_KEY: 0.9}
+ api2 = BaseNode(NodeOp.function_api, 'Tensor.b.0')
+ api2.data = {GraphConst.JSON_INDEX_KEY: 0.6}
+ apis.subnodes = [api1, api2]
+ sub_nodes = [BaseNode(NodeOp.module, 'Module.a.0'), apis, BaseNode(NodeOp.module, 'Module.a.1')]
+ comparator.graph_n.root.subnodes = sub_nodes
+ comparator._handle_api_collection_index()
+ self.assertEqual(comparator.graph_n.root.subnodes[1].data.get(GraphConst.JSON_INDEX_KEY), 0.6)
+
+ @patch('msprobe.pytorch.visualization.builder.msprobe_adapter.compare_node')
+ @patch('msprobe.pytorch.visualization.graph.graph.Graph.match')
+ @patch('msprobe.pytorch.visualization.graph.graph.Graph.mapping_match')
+ @patch('msprobe.pytorch.visualization.compare.graph_comparator.get_compare_mode')
+ @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_json_file')
+ @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_data_json_file')
+ def test__compare_nodes(self, mock_load_data_json_file, mock_load_json_file, mock_get_compare_mode,
+ mock_mapping_match, mock_match, mock_compare_node):
+ node_n = BaseNode(NodeOp.function_api, 'Tensor.a.0')
+ node_b = BaseNode(NodeOp.function_api, 'Tensor.b.0')
+ mock_load_data_json_file.return_value = {}
+ mock_load_json_file.return_value = {}
+ mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE
+ mock_mapping_match.return_value = (node_b, [], [])
+ mock_compare_node.return_value = ['result']
+ comparator = GraphComparator(self.graphs, self.data_paths, self.stack_path, self.output_path)
+ comparator.mapping_config = True
+ comparator._compare_nodes(node_n)
+ self.assertEqual(node_n.matched_node_link, ['Tensor.b.0'])
+ self.assertEqual(node_b.matched_node_link, ['Tensor.a.0'])
+ comparator.mapping_config = False
+ node_n = BaseNode(NodeOp.function_api, 'Tensor.a.0')
+ node_b = BaseNode(NodeOp.function_api, 'Tensor.a.0')
+ mock_match.return_value = (node_b, [])
+ comparator._compare_nodes(node_n)
+ self.assertEqual(node_n.matched_node_link, ['Tensor.a.0'])
+ self.assertEqual(node_b.matched_node_link, ['Tensor.a.0'])
+
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_mode_adapter.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_mode_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..da76d8e0d57dc700d45a16f4ea79df1ff7ff1707
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_mode_adapter.py
@@ -0,0 +1,99 @@
+import json
+import unittest
+from unittest.mock import patch, MagicMock
+from msprobe.pytorch.visualization.compare.mode_adapter import ModeAdapter
+from msprobe.pytorch.visualization.graph.base_node import BaseNode, NodeOp
+from msprobe.pytorch.visualization.utils import GraphConst, ToolTip
+from msprobe.core.common.const import CompareConst
+
+
+class TestModeAdapter(unittest.TestCase):
+
+ def setUp(self):
+ self.node_op = NodeOp.module
+ self.node_id = "node_1"
+ self.node = BaseNode(self.node_op, self.node_id)
+ self.compare_mode = GraphConst.REAL_DATA_COMPARE
+ self.adapter = ModeAdapter(self.compare_mode)
+ self.compare_data_dict = [{}, {}]
+
+ def test_add_md5_compare_data(self):
+ node_data = {'md5_key': 'some_md5_value'}
+ compare_data_dict = {'md5_key': 'expected_md5_value'}
+ precision_index = ModeAdapter._add_md5_compare_data(node_data, compare_data_dict)
+ self.assertEqual(precision_index, 0)
+
+ @patch('msprobe.pytorch.visualization.compare.mode_adapter.ModeAdapter')
+ def test_parse_result(self, mock_mode_adapter):
+ mock_mode_adapter._add_summary_compare_data.return_value = 0.5
+ self.adapter.compare_mode = GraphConst.SUMMARY_COMPARE
+ precision_index, other_dict = self.adapter.parse_result(self.node, self.compare_data_dict)
+ self.assertEqual(precision_index, 0.5)
+ self.assertEqual(other_dict, {})
+
+ mock_mode_adapter._add_md5_compare_data.return_value = 1
+ self.adapter.compare_mode = GraphConst.MD5_COMPARE
+ precision_index, other_dict = self.adapter.parse_result(self.node, self.compare_data_dict)
+ self.assertEqual(precision_index, 1)
+ self.assertEqual(other_dict, {'Result': 'pass'})
+
+ mock_mode_adapter._add_real_compare_data.return_value = 0.6
+ self.adapter.compare_mode = GraphConst.REAL_DATA_COMPARE
+ precision_index, other_dict = self.adapter.parse_result(self.node, self.compare_data_dict)
+ self.assertEqual(precision_index, 0.0)
+ self.assertEqual(other_dict, {})
+
+ def test_prepare_real_data(self):
+ self.adapter.is_real_data_compare = MagicMock(return_value=True)
+ result = self.adapter.prepare_real_data(self.node)
+ self.assertTrue(result)
+
+ self.adapter.is_real_data_compare = MagicMock(return_value=False)
+ result = self.adapter.prepare_real_data(self.node)
+ self.assertFalse(result)
+
+ def test_compare_mode_methods(self):
+ self.adapter.compare_mode = GraphConst.SUMMARY_COMPARE
+ self.assertTrue(self.adapter.is_summary_compare())
+ self.assertFalse(self.adapter.is_md5_compare())
+ self.assertFalse(self.adapter.is_real_data_compare())
+
+ def test_add_csv_data(self):
+ compare_result_list = ['result1', 'result2']
+ self.adapter.add_csv_data(compare_result_list)
+ self.assertEqual(self.adapter.csv_data, compare_result_list)
+
+ def test_add_error_key(self):
+ node_data = {'key': {}}
+ self.adapter.compare_mode = GraphConst.REAL_DATA_COMPARE
+ self.adapter.add_error_key(node_data)
+ self.assertEqual(node_data['key'][GraphConst.ERROR_KEY],
+ [CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO])
+ node_data = {'key': {}}
+ self.adapter.compare_mode = GraphConst.SUMMARY_COMPARE
+ self.adapter.add_error_key(node_data)
+ self.assertEqual(node_data['key'][GraphConst.ERROR_KEY],
+ [CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR,
+ CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR])
+
+ def test_get_tool_tip(self):
+ self.adapter.compare_mode = GraphConst.MD5_COMPARE
+ tips = self.adapter.get_tool_tip()
+ self.assertEqual(tips, json.dumps({'md5': ToolTip.MD5}))
+
+ self.adapter.compare_mode = GraphConst.SUMMARY_COMPARE
+ tips = self.adapter.get_tool_tip()
+ self.assertEqual(tips, json.dumps({
+ CompareConst.MAX_DIFF: ToolTip.MAX_DIFF,
+ CompareConst.MIN_DIFF: ToolTip.MIN_DIFF,
+ CompareConst.MEAN_DIFF: ToolTip.MEAN_DIFF,
+ CompareConst.NORM_DIFF: ToolTip.NORM_DIFF}))
+
+ self.adapter.compare_mode = GraphConst.REAL_DATA_COMPARE
+ tips = self.adapter.get_tool_tip()
+ self.assertEqual(tips, json.dumps({
+ CompareConst.ONE_THOUSANDTH_ERR_RATIO: ToolTip.ONE_THOUSANDTH_ERR_RATIO,
+ CompareConst.FIVE_THOUSANDTHS_ERR_RATIO: ToolTip.FIVE_THOUSANDTHS_ERR_RATIO,
+ CompareConst.COSINE: ToolTip.COSINE,
+ CompareConst.MAX_ABS_ERR: ToolTip.MAX_ABS_ERR,
+ CompareConst.MAX_RELATIVE_ERR: ToolTip.MAX_RELATIVE_ERR}))
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_base_node.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_base_node.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f0f12582a36b21000416f2603143b87d0032c65
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_base_node.py
@@ -0,0 +1,76 @@
+import unittest
+from unittest.mock import patch
+from msprobe.pytorch.visualization.graph.base_node import BaseNode, NodeOp
+from msprobe.pytorch.visualization.utils import GraphConst
+
+
+class TestBaseNode(unittest.TestCase):
+
+ def setUp(self):
+ self.node_op = NodeOp.module
+ self.node_id = "node_1"
+ self.up_node = BaseNode(self.node_op, "up_node_1")
+ self.node = BaseNode(self.node_op, self.node_id, self.up_node)
+
+ def test_init_and_str(self):
+ self.assertEqual(self.node.op, self.node_op)
+ self.assertEqual(self.node.id, self.node_id)
+ self.assertEqual(str(self.node), 'id:\tnode_1')
+
+ def test_eq(self):
+ other_node = BaseNode(self.node_op, self.node_id, self.up_node)
+ self.assertEqual(self.node, other_node)
+
+ def test_get_suggestions(self):
+ self.node.get_suggestions()
+ self.assertIn(GraphConst.SUGGEST_KEY, self.node.suggestions)
+
+ node = BaseNode(NodeOp.function_api, "up_node_1")
+ node.get_suggestions()
+ self.assertIn(GraphConst.SUGGEST_KEY, node.suggestions)
+
+ def test_set_input_output(self):
+ input_data = {'input1': 'value1'}
+ output_data = {'output1': 'value2'}
+ self.node.set_input_output(input_data, output_data)
+ self.assertEqual(self.node.input_data, input_data)
+ self.assertEqual(self.node.output_data, output_data)
+
+ def test_add_upnode(self):
+ self.node = BaseNode(self.node_op, self.node_id)
+ new_up_node = BaseNode(self.node_op, "new_up_node_1")
+ self.node.add_upnode(new_up_node)
+ self.assertEqual(self.node.upnode, new_up_node)
+ self.assertIn(self.node, new_up_node.subnodes)
+
+ def test_add_link(self):
+ other_node = BaseNode(self.node_op, "other_node_1")
+ ancestors = ['a1', 'a2']
+ self.node.add_link(other_node, ancestors)
+ self.assertEqual(self.node.matched_node_link, ancestors)
+ self.assertEqual(other_node.matched_node_link, ancestors)
+
+ def test_to_dict(self):
+ expected_result = {
+ 'id': self.node_id,
+ 'node_type': self.node_op.value,
+ 'data': {},
+ 'output_data': {},
+ 'input_data': {},
+ 'upnode': self.up_node.id,
+ 'subnodes': [],
+ 'matched_node_link': [],
+ 'suggestions': {},
+ 'stack_info': []
+ }
+ self.assertEqual(self.node.to_dict(), expected_result)
+
+ def test_get_ancestors(self):
+ expected_ancestors = ['up_node_1']
+ self.assertEqual(self.node.get_ancestors(), expected_ancestors)
+
+ @patch('msprobe.pytorch.visualization.builder.msprobe_adapter.compare_mapping_data')
+ def test_compare_mapping_node(self, mock_compare_mapping_data):
+ mock_compare_mapping_data.return_value = True
+ result = self.node.compare_mapping_node(BaseNode(NodeOp.function_api, "up_node_1"))
+ self.assertTrue(result)
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_graph.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7a55bb36e2291caf14a5e27052c5600f4492d0c
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_graph.py
@@ -0,0 +1,102 @@
+import unittest
+from unittest.mock import MagicMock
+from msprobe.pytorch.visualization.graph.graph import Graph, NodeOp
+from msprobe.pytorch.visualization.graph.base_node import BaseNode
+from msprobe.pytorch.visualization.utils import GraphConst
+
+
+class TestGraph(unittest.TestCase):
+
+ def setUp(self):
+ self.graph = Graph("model_name")
+ self.node_id = "node_id"
+ self.node_op = NodeOp.module
+
+ def test_add_node_and_get_node(self):
+ self.graph.add_node(self.node_op, self.node_id)
+ node = self.graph.get_node(self.node_id)
+ self.assertIsNotNone(node)
+ self.assertIn(self.node_id, self.graph.node_map)
+
+ node_id = "api"
+ graph = Graph("model_name")
+        for _ in range(9):
+ graph.add_node(NodeOp.function_api, node_id, id_accumulation=True)
+ self.assertEqual(len(graph.node_map), 10)
+ self.assertIn("api.0", graph.node_map)
+ self.assertIn("api.8", graph.node_map)
+ self.assertNotIn("api", graph.node_map)
+
+ def test_to_dict(self):
+ self.graph.add_node(self.node_op, self.node_id)
+ result = self.graph.to_dict()
+ self.assertEqual(result[GraphConst.JSON_ROOT_KEY], "model_name")
+ self.assertIn(self.node_id, result[GraphConst.JSON_NODE_KEY])
+
+ def test_str(self):
+ self.graph.add_node(self.node_op, self.node_id)
+ expected_str = f'{self.node_id}'
+ self.assertIn(expected_str, str(self.graph))
+
+ def test_match(self):
+ graph_a = Graph("model_name_a")
+ graph_b = Graph("model_name_b")
+ node_a = BaseNode(self.node_op, self.node_id)
+ graph_a.add_node(NodeOp.module, "node_id_a")
+ graph_b.add_node(NodeOp.module, "node_id_b")
+ matched_node, ancestors = Graph.match(graph_a, node_a, graph_b)
+ self.assertIsNone(matched_node)
+ self.assertEqual(ancestors, [])
+
+ graph_b.add_node(NodeOp.module, "node_id_a")
+ graph_a.add_node(NodeOp.module, "node_id_a_1", graph_a.get_node("node_id_a"))
+        graph_b.add_node(NodeOp.module, "node_id_a_1", graph_b.get_node("node_id_a"))
+ matched_node, ancestors = Graph.match(graph_a, graph_a.get_node("node_id_a_1"), graph_b)
+ self.assertIsNotNone(matched_node)
+ self.assertEqual(ancestors, ['node_id_a'])
+
+ def test_dfs(self):
+ graph = Graph("model_name")
+ graph.add_node(NodeOp.module, "node_a")
+ graph.add_node(NodeOp.module, "node_b")
+ node_a = BaseNode(self.node_op, self.node_id)
+ result = {}
+ graph.dfs(node_a, result)
+ self.assertEqual(result, {'node_id': {'id': 'node_id', 'node_type': 0, 'data': {},
+ 'output_data': {}, 'input_data': {}, 'upnode': 'None', 'subnodes': [],
+ 'matched_node_link': [], 'suggestions': {}}})
+
+ def test_split_nodes_by_micro_step(self):
+ nodes = [BaseNode(NodeOp.module, 'a.0'), BaseNode(NodeOp.module, 'b.0'),
+ BaseNode(NodeOp.api_collection, 'apis.0'), BaseNode(NodeOp.module, 'a.1'),
+ BaseNode(NodeOp.module, 'b.1'), BaseNode(NodeOp.api_collection, 'apis.1')]
+ result = Graph.split_nodes_by_micro_step(nodes)
+ self.assertEqual(len(result), 2)
+ self.assertEqual(len(result[0]), 3)
+
+ def test_paging_by_micro_step(self):
+ nodes = [BaseNode(NodeOp.module, 'a.0'), BaseNode(NodeOp.module, 'b.0'),
+ BaseNode(NodeOp.api_collection, 'apis.0'), BaseNode(NodeOp.module, 'a.1'),
+ BaseNode(NodeOp.module, 'b.1'), BaseNode(NodeOp.api_collection, 'apis.1')]
+
+ graph = Graph('Model1')
+ graph.root.subnodes = nodes
+ graph_other = Graph('Model2')
+ graph_other.root.subnodes = nodes
+
+ result = graph.paging_by_micro_step(graph_other)
+ self.assertEqual(result, 2)
+ self.assertEqual(graph.root.subnodes[0].micro_step_id, 0)
+ self.assertEqual(graph_other.root.subnodes[0].micro_step_id, 0)
+
+ def test_mapping_match(self):
+ mapping_config = MagicMock()
+ graph_a = Graph("model_name_a")
+ graph_b = Graph("model_name_b")
+ graph_a.add_node(NodeOp.module, "a1", BaseNode(NodeOp.module, "root"))
+ graph_b.add_node(NodeOp.module, "b1", BaseNode(NodeOp.module, "root"))
+ mapping_config.get_mapping_string.return_value = "b1"
+ node_b, ancestors_n, ancestors_b = Graph.mapping_match(graph_a.get_node("a1"), graph_b, mapping_config)
+ self.assertIsNotNone(node_b)
+ self.assertEqual(ancestors_n, ["root"])
+ self.assertEqual(ancestors_b, ["root"])
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_node_colors.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_node_colors.py
new file mode 100644
index 0000000000000000000000000000000000000000..0be05586b150cfd39670535cc3015925d9e2f44e
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_node_colors.py
@@ -0,0 +1,67 @@
+import unittest
+from msprobe.pytorch.visualization.graph.node_colors import NodeColors, SUMMARY_DESCRIPTION, REAL_DATA_DESCRIPTION, \
+ NOT_MATCHED
+from msprobe.pytorch.visualization.utils import GraphConst
+
+
+class TestNodeColors(unittest.TestCase):
+
+ def test_get_info_by_mode(self):
+ node_yellow = NodeColors.YELLOW_1
+ summary_info = node_yellow.get_info_by_mode(GraphConst.SUMMARY_COMPARE)
+ self.assertEqual(summary_info[GraphConst.VALUE], [0, 0.2])
+ self.assertEqual(summary_info[GraphConst.DESCRIPTION], SUMMARY_DESCRIPTION)
+ node_grey = NodeColors.GREY
+ md5_info = node_grey.get_info_by_mode(GraphConst.MD5_COMPARE)
+ self.assertEqual(md5_info[GraphConst.VALUE], [])
+ self.assertEqual(md5_info[GraphConst.DESCRIPTION], NOT_MATCHED)
+ node_red = NodeColors.RED
+ real_info = node_red.get_info_by_mode(GraphConst.REAL_DATA_COMPARE)
+ self.assertEqual(real_info[GraphConst.VALUE], [0.2, 1])
+ self.assertEqual(real_info[GraphConst.DESCRIPTION], REAL_DATA_DESCRIPTION)
+ none_info = node_yellow.get_info_by_mode("non_existent_mode")
+ self.assertEqual(none_info, {})
+
+ def test_get_node_colors(self):
+        # Verify that all color entries are returned for the given mode
+ mode = GraphConst.SUMMARY_COMPARE
+ colors_info = NodeColors.get_node_colors(mode)
+ self.assertIn("#FFFCF3", colors_info)
+ self.assertIn("#FFEDBE", colors_info)
+ self.assertIn("#FFDC7F", colors_info)
+ self.assertIn("#FFC62E", colors_info)
+ self.assertIn("#E32020", colors_info)
+ self.assertIn("#C7C7C7", colors_info)
+
+        # The returned dict carries the correct description and value range
+ expected_value_range = [0, 0.2]
+ expected_description = "此节点所有输入输出的统计量相对误差, 值越大代表测量值与标杆值的偏差越大, 相对误差计算方式:|(测量值-标杆值)/标杆值|"
+ self.assertEqual(colors_info["#FFFCF3"][GraphConst.VALUE], expected_value_range)
+ self.assertEqual(colors_info["#FFFCF3"][GraphConst.DESCRIPTION], expected_description)
+
+ mode = GraphConst.MD5_COMPARE
+ colors_info = NodeColors.get_node_colors(mode)
+ self.assertIn("#FFFCF3", colors_info)
+ self.assertIn("#C7C7C7", colors_info)
+ self.assertNotIn("#FFDC7F", colors_info)
+
+ expected_value_range = [1, 1]
+ expected_description = "与标杆相比, 此节点所有输入输出的md5值相同"
+ self.assertEqual(colors_info["#FFFCF3"][GraphConst.VALUE], expected_value_range)
+ self.assertEqual(colors_info["#FFFCF3"][GraphConst.DESCRIPTION], expected_description)
+
+ def test_get_node_error_status(self):
+        # Verify the error-status threshold check
+ mode = GraphConst.SUMMARY_COMPARE
+ value0 = 0
+ value1 = 0.25
+ value2 = 0.55
+ value3 = 111
+ self.assertFalse(NodeColors.get_node_error_status(mode, value0))
+ self.assertFalse(NodeColors.get_node_error_status(mode, value1))
+ self.assertTrue(NodeColors.get_node_error_status(mode, value2))
+ self.assertTrue(NodeColors.get_node_error_status(mode, value3))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_node_op.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_node_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a340ac8b3c7144a9e07485c93e289a950eee8c7
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_node_op.py
@@ -0,0 +1,28 @@
+import unittest
+from msprobe.pytorch.visualization.graph.node_op import NodeOp
+
+
+class TestNodeOp(unittest.TestCase):
+
+ def test_get_node_op_valid(self):
+ node_name = "ModuleTest"
+ self.assertEqual(NodeOp.get_node_op(node_name), NodeOp.module)
+
+ def test_get_node_op_invalid(self):
+ node_name = "InvalidNodeName"
+ with self.assertRaises(Exception):
+ NodeOp.get_node_op(node_name)
+
+ def test_get_node_op_all(self):
+ test_cases = [
+ ("ModuleTest", NodeOp.module),
+ ("TensorTest", NodeOp.function_api),
+ ("TorchTest", NodeOp.function_api),
+ ("FunctionalTest", NodeOp.function_api),
+ ("NPUTest", NodeOp.function_api),
+ ("VFTest", NodeOp.function_api),
+ ("DistributedTest", NodeOp.function_api),
+ ("AtenTest", NodeOp.function_api)
+ ]
+ for node_name, expected_op in test_cases:
+ self.assertEqual(NodeOp.get_node_op(node_name), expected_op)
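The table-driven loop in `test_get_node_op_all` aborts at the first failing case. Wrapping each iteration in `self.subTest` (a hedged alternative to the loop above, not what the file currently does) reports every failing case independently; the classifier below is a stand-in for `NodeOp.get_node_op`:

```python
import unittest


class TestTableDriven(unittest.TestCase):
    def test_all_cases(self):
        # With subTest, each failing case is reported on its own instead
        # of the loop stopping at the first mismatch.
        cases = [("ModuleTest", "module"), ("TensorTest", "api"), ("AtenTest", "api")]
        for name, expected in cases:
            with self.subTest(name=name):
                # Hypothetical classifier standing in for NodeOp.get_node_op.
                kind = "module" if name.startswith("Module") else "api"
                self.assertEqual(kind, expected)
```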
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/mapping.yaml b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/mapping.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8b2f85ebf872aae4b3377842ac899824da5877f9
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/mapping.yaml
@@ -0,0 +1,2 @@
+- vision_model: "language_model.vision_encoder"
+- vision_projection: "language_model.projection"
\ No newline at end of file
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/test_mapping_config.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/test_mapping_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..010a4f686198ef1127299fcad1b9a5abf6505d6b
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/test_mapping_config.py
@@ -0,0 +1,52 @@
+import os
+import unittest
+from msprobe.pytorch.visualization.mapping_config import MappingConfig
+
+
+class TestMappingConfig(unittest.TestCase):
+
+ def setUp(self):
+ self.yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mapping.yaml")
+
+ def test_validate(self):
+ with self.assertRaises(ValueError):
+ MappingConfig.validate(123, "some value")
+ with self.assertRaises(ValueError):
+ MappingConfig.validate("some key", 456)
+ self.assertEqual(MappingConfig.validate("key", "value"), "value")
+
+ def test_convert_to_regex(self):
+ regex = MappingConfig.convert_to_regex("hello{world}")
+ self.assertEqual(regex, ".*hello\\{world\\}.*")
+
+ def test_replace_parts(self):
+ result = MappingConfig._replace_parts('hello world', 'world', 'everyone')
+ self.assertEqual(result, 'hello everyone')
+ result = MappingConfig._replace_parts('radio_model.layers.0.input_norm', 'radio_model.layers.{}.input_norm',
+ 'radio_model.transformer.layers.{}.input_layernorm')
+ self.assertEqual(result, 'radio_model.transformer.layers.0.input_layernorm')
+
+ def test_get_mapping_string(self):
+ mc = MappingConfig(self.yaml_path)
+ mc.classify_config = {
+ 'category1': [('category1.key1', 'replacement1')],
+ 'category2': [('category2.key1', 'replacement2')]
+ }
+ result = mc.get_mapping_string("some category1.key1 text")
+ self.assertEqual(result, "some replacement1 text")
+
+ def test_long_string(self):
+ long_string = "x" * (MappingConfig.MAX_STRING_LEN + 1)
+ mc = MappingConfig(self.yaml_path)
+ result = mc.get_mapping_string(long_string)
+ self.assertEqual(result, long_string)
+
+ def test__classify_and_sort_keys(self):
+ mc = MappingConfig(self.yaml_path)
+ result = mc._classify_and_sort_keys()
+ self.assertEqual(result, {'vision_model': [('vision_model', 'language_model.vision_encoder')],
+ 'vision_projection': [('vision_projection', 'language_model.projection')]})
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/test_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/test_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..56493235cd63d97b8a2beca0358c59ed16a154cf
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/test_utils.py
@@ -0,0 +1,27 @@
+import os
+import unittest
+from msprobe.pytorch.visualization.utils import (load_json_file, load_data_json_file, str2float)
+
+
+class TestUtils(unittest.TestCase):
+
+ def setUp(self):
+ self.yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mapping.yaml")
+
+ def test_load_json_file(self):
+ result = load_json_file(self.yaml_path)
+ self.assertEqual(result, {})
+
+ def test_load_data_json_file(self):
+ result = load_data_json_file(self.yaml_path)
+ self.assertEqual(result, {})
+
+ def test_str2float(self):
+ result = str2float('23.4%')
+ self.assertAlmostEqual(result, 0.234)
+ result = str2float('2.3.4%')
+ self.assertAlmostEqual(result, 0)
+
+
+if __name__ == '__main__':
+ unittest.main()
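`test_str2float` above uses `assertAlmostEqual` rather than `assertEqual` because the percentage-to-float conversion goes through binary floating point, where many decimal values are inexact. A quick stdlib-only illustration of why that matters:

```python
import unittest


class TestFloatCompare(unittest.TestCase):
    def test_almost_equal(self):
        # 0.1 + 0.2 evaluates to 0.30000000000000004 in IEEE-754 doubles,
        # so an exact comparison fails while assertAlmostEqual (which
        # rounds the difference to 7 decimal places by default) passes.
        self.assertNotEqual(0.1 + 0.2, 0.3)
        self.assertAlmostEqual(0.1 + 0.2, 0.3)
```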
diff --git a/debug/accuracy_tools/atat/test/resources/advisor.txt b/debug/accuracy_tools/msprobe/test/resources/advisor.txt
similarity index 100%
rename from debug/accuracy_tools/atat/test/resources/advisor.txt
rename to debug/accuracy_tools/msprobe/test/resources/advisor.txt
diff --git a/debug/accuracy_tools/atat/test/resources/compare_result_20230703104808.csv b/debug/accuracy_tools/msprobe/test/resources/compare_result_20230703104808.csv
similarity index 100%
rename from debug/accuracy_tools/atat/test/resources/compare_result_20230703104808.csv
rename to debug/accuracy_tools/msprobe/test/resources/compare_result_20230703104808.csv
diff --git a/debug/accuracy_tools/atat/test/resources/compare_result_without_accuracy.csv b/debug/accuracy_tools/msprobe/test/resources/compare_result_without_accuracy.csv
similarity index 100%
rename from debug/accuracy_tools/atat/test/resources/compare_result_without_accuracy.csv
rename to debug/accuracy_tools/msprobe/test/resources/compare_result_without_accuracy.csv
diff --git a/debug/accuracy_tools/atat/test/resources/config.yaml b/debug/accuracy_tools/msprobe/test/resources/config.yaml
similarity index 100%
rename from debug/accuracy_tools/atat/test/resources/config.yaml
rename to debug/accuracy_tools/msprobe/test/resources/config.yaml
diff --git a/debug/accuracy_tools/atat/test/resources/npu_test.pkl b/debug/accuracy_tools/msprobe/test/resources/npu_test.pkl
similarity index 100%
rename from debug/accuracy_tools/atat/test/resources/npu_test.pkl
rename to debug/accuracy_tools/msprobe/test/resources/npu_test.pkl
diff --git a/debug/accuracy_tools/atat/test/run_test.sh b/debug/accuracy_tools/msprobe/test/run_test.sh
similarity index 100%
rename from debug/accuracy_tools/atat/test/run_test.sh
rename to debug/accuracy_tools/msprobe/test/run_test.sh
diff --git a/debug/accuracy_tools/atat/test/run_ut.py b/debug/accuracy_tools/msprobe/test/run_ut.py
similarity index 97%
rename from debug/accuracy_tools/atat/test/run_ut.py
rename to debug/accuracy_tools/msprobe/test/run_ut.py
index 7c593c14abca82f39050276255316693a47c6fc9..8ea81ccca719952bdb8a6603b998902df94a53fb 100644
--- a/debug/accuracy_tools/atat/test/run_ut.py
+++ b/debug/accuracy_tools/msprobe/test/run_ut.py
@@ -3,7 +3,7 @@ import shutil
import subprocess
import sys
-from atat.core.common.log import logger
+from msprobe.core.common.log import logger
def run_ut():
diff --git a/debug/accuracy_tools/atat/test/test_module_processer.py b/debug/accuracy_tools/msprobe/test/test_module_processer.py
similarity index 95%
rename from debug/accuracy_tools/atat/test/test_module_processer.py
rename to debug/accuracy_tools/msprobe/test/test_module_processer.py
index 89ee299f66fc1c14e9fcb666bd6d21e0ca2f17a9..448c35f0554551884dc690a71aef6bc8141e9a39 100644
--- a/debug/accuracy_tools/atat/test/test_module_processer.py
+++ b/debug/accuracy_tools/msprobe/test/test_module_processer.py
@@ -1,6 +1,6 @@
import unittest
-from atat.pytorch.module_processer import ModuleProcesser
-from atat.pytorch.common.utils import Const
+from msprobe.pytorch.module_processer import ModuleProcesser
+from msprobe.pytorch.common.utils import Const
import torch
diff --git a/debug/accuracy_tools/setup.py b/debug/accuracy_tools/setup.py
index 3568e3a47c2fa2a42b016fb787683565e4c3ab0b..58b6881b41aa44098996a24877aeae43a66dcced 100644
--- a/debug/accuracy_tools/setup.py
+++ b/debug/accuracy_tools/setup.py
@@ -14,27 +14,54 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
+import subprocess
+import setuptools
+from setuptools import setup
-from setuptools import setup, find_packages
+
+repository_url = "https://gitee.com/ascend/msit.git"
+target_dir = "./msprobe/msit"
+branch_name = "master"
+
+subprocess.check_call(["git", "submodule", "add", "-b", branch_name, repository_url, target_dir])
+subprocess.check_call(["git", "submodule", "init"])
+subprocess.check_call(["git", "submodule", "update"])
+
+
+EXCLUDE_PKGS = [
+ "api_accuracy_checker*",
+ "grad_tool*",
+ "monitor*",
+ "ptdbg_ascend*",
+ "msprobe.test*",
+]
setup(
- name='ascend_training_accuracy_tools',
- version='1.0',
+ name='mindstudio-probe',
+ version='1.0.4',
    description='This is a PyTorch precision comparison tool',
    long_description='This is a PyTorch precision comparison tool, including ptdbg and api accuracy checker',
- packages=find_packages(),
+ packages=setuptools.find_namespace_packages(exclude=EXCLUDE_PKGS, include=["msprobe", "msprobe*"]),
install_requires=[
"wheel",
- "numpy",
- "pandas >= 1.3.5",
+ "einops",
+ "numpy < 2.0",
+ "pandas >= 1.3.5, < 2.1",
"pyyaml",
"rich",
"tqdm",
- "openpyxl"
+ "openpyxl",
+ "pyOpenSSL",
+ "twisted",
+ "matplotlib",
+ "protobuf <= 3.20.1",
+ "service_identity",
+ "onnx"
],
include_package_data=True,
ext_modules=[],
zip_safe=False,
entry_points={
- 'console_scripts' : ['atat=atat.atat:main'],
- },)
\ No newline at end of file
+ 'console_scripts': ['msprobe=msprobe.msprobe:main',
+ 'add_torchair_compare_path=msprobe.pytorch.torchair_compare.msit_path_add:create_symlink'],
+ },)
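The effect of `EXCLUDE_PKGS` with `find_namespace_packages` can be illustrated with a small glob check. This approximates setuptools' include/exclude pattern matching rather than reproducing its exact algorithm:

```python
from fnmatch import fnmatch

EXCLUDE_PKGS = [
    "api_accuracy_checker*", "grad_tool*", "monitor*",
    "ptdbg_ascend*", "msprobe.test*",
]

def is_packaged(pkg):
    # A package ships if it matches an include pattern and no exclude pattern.
    included = any(fnmatch(pkg, pat) for pat in ("msprobe", "msprobe*"))
    excluded = any(fnmatch(pkg, pat) for pat in EXCLUDE_PKGS)
    return included and not excluded

print(is_packaged("msprobe.core.common"))  # True
print(is_packaged("msprobe.test.run_ut"))  # False
```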
diff --git a/plugins/mindstudio-insight-plugins/CMakeLists.txt b/plugins/mindstudio-insight-plugins/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..48c8919553685570400ef939df31cfcf9eb7d120
--- /dev/null
+++ b/plugins/mindstudio-insight-plugins/CMakeLists.txt
@@ -0,0 +1,13 @@
+cmake_minimum_required(VERSION 3.20)
+project(MindStudio_Board)
+set(PROJECT_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR})
+include(${CMAKE_CURRENT_SOURCE_DIR}/plugin_core/cmake/mind_expression.cmake)
+add_subdirectory(proto)
+
+#if (${CMAKE_BUILD_TYPE} MATCHES "Debug")
+message(STATUS "Open tools generate")
+add_subdirectory(tools/httpServer)
+#endif ()
+# add_subdirectory(Scalar)
+add_subdirectory(Histogram)
+add_subdirectory(plugin_core)
\ No newline at end of file
diff --git a/plugins/mindstudio-insight-plugins/Histogram/CMakeLists.txt b/plugins/mindstudio-insight-plugins/Histogram/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8b1df7818cbf8dc8e82d6f6d9326782072f2cd1b
--- /dev/null
+++ b/plugins/mindstudio-insight-plugins/Histogram/CMakeLists.txt
@@ -0,0 +1,2 @@
+set(HISTOGRAM_PROJECT_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR})
+add_subdirectory(server)
\ No newline at end of file
diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/CMakeLists.txt b/plugins/mindstudio-insight-plugins/Histogram/server/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1d10f8d8a0674a039a224a5e546ebd8d842aa97a
--- /dev/null
+++ b/plugins/mindstudio-insight-plugins/Histogram/server/CMakeLists.txt
@@ -0,0 +1,45 @@
+cmake_minimum_required(VERSION 3.20)
+project(Histogram)
+
+set(CMAKE_BUILD_TYPE Debug CACHE STRING "Build type" FORCE)
+
+if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
+ if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3.0)
+ message(FATAL_ERROR "GCC version must be 7.3.0 and above, but found ${CMAKE_CXX_COMPILER_VERSION}")
+ elseif(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 11.4.0)
+ message(WARNING "GCC version ${CMAKE_CXX_COMPILER_VERSION} is greater than 11.4.0, may cause unknown problems.")
+ endif()
+endif()
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_C_STANDARD 11)
+
+set(HOME_DIR ${PROJECT_SOURCE_DIR})
+set(EXECUTABLE_OUTPUT_PATH ${PROJECT_ROOT_DIR}/output/plugins)
+set(LIBRARY_OUTPUT_PATH ${PROJECT_ROOT_DIR}/output/plugins)
+
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fstack-protector-all -D_FORTIFY_SOURCE=2 -O0 -g -ftrapv -fstack-protector-strong -fPIE -fPIC")
+set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fstack-protector-all -D_FORTIFY_SOURCE=2 -O0 -g -ftrapv -fstack-protector-strong -fPIE -fPIC")
+if (${CMAKE_BUILD_TYPE} MATCHES "Debug")
+ message(STATUS "Enable debug symbol table, change optimization level to 0")
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O0")
+ set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -g -O0")
+endif ()
+
+set(CMAKE_SKIP_RPATH true)
+if (CMAKE_SYSTEM_NAME MATCHES "Windows")
+ if ((NOT CMAKE_BUILD_TYPE MATCHES "Debug") AND (NOT CMAKE_BUILD_TYPE MATCHES "PROFILE"))
+ message(STATUS "Build type = ${CMAKE_BUILD_TYPE}, static = enable.")
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,--nxcompat -Wl,--dynamicbase -s -pie -Wincompatible-pointer-types")
+ set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wl,--nxcompat -Wl,--dynamicbase -s -pie -Wincompatible-pointer-types")
+ add_link_options(-static)
+ endif()
+else()
+ if ((NOT CMAKE_BUILD_TYPE MATCHES "Debug") AND (NOT CMAKE_BUILD_TYPE MATCHES "PROFILE"))
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -s -pie -Wl,-z,now")
+ set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -s -pie -Wl,-z,now")
+ endif()
+endif()
+
+
+add_subdirectory(src)
\ No newline at end of file
diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/CMakeLists.txt b/plugins/mindstudio-insight-plugins/Histogram/server/src/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b054d2e9d6eed28708488a957edb05a5e168e0a3
--- /dev/null
+++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/CMakeLists.txt
@@ -0,0 +1,84 @@
+set(SRC_HOME_DIR ${HOME_DIR}/src)
+aux_source_directory(defs HISTOGRAM_SRC_LIST)
+aux_source_directory(handler HISTOGRAM_SRC_LIST)
+aux_source_directory(histogramManager HISTOGRAM_SRC_LIST)
+aux_source_directory(histoParser HISTOGRAM_SRC_LIST)
+aux_source_directory(viewManager HISTOGRAM_SRC_LIST)
+aux_source_directory(plugin HISTOGRAM_SRC_LIST)
+aux_source_directory(${PROJECT_ROOT_DIR}/proto HISTOGRAM_SRC_LIST)
+
+set(LOG_SRC ${PROJECT_ROOT_DIR}/plugin_core/src/Logger.cpp)
+
+list(APPEND ${PROJECT_NAME}_SRC
+ ${PROTO_SRC}
+ ${HISTOGRAM_SRC_LIST}
+ ${LOG_SRC})
+include_directories(${SRC_HOME_DIR}
+ ${SRC_HOME_DIR}/parser
+ ${PROJECT_ROOT_DIR}/proto
+ ${PROJECT_ROOT_DIR}
+ ${SRC_HOME_DIR}/defs
+ ${SRC_HOME_DIR}/handler
+ ${SRC_HOME_DIR}/histoParser
+ ${SRC_HOME_DIR}/viewManager
+ ${SRC_HOME_DIR}/plugin
+
+)
+set(LIBRARY_OUTPUT_PATH ${LIBRARY_OUTPUT_PATH}/${PROJECT_NAME})
+add_library(${PROJECT_NAME} SHARED ${${PROJECT_NAME}_SRC}
+ utils/FileUtils.h
+ handler/GetHistoDataRequestHandler.h
+ handler/GetHistoDataRequestHandler.cpp
+ handler/GetNewFilesRequestHandler.h
+ handler/GetNewFilesRequestHandler.cpp
+ handler/AddImportFileRequestHandler.cpp
+ handler/AddImportFileRequestHandler.h
+ handler/ImportRequestHandler.h
+ viewManager/FileMonitor.h
+ histoParser/MindsporeParser.h
+ histoParser/MindsporeParser.cpp)
+target_include_directories(${PROJECT_NAME} PRIVATE ${${PROJECT_NAME}_H})
+target_include_directories(${PROJECT_NAME} PRIVATE ${PROJECT_ROOT_DIR}/rapidjson/include/rapidjson)
+if (${CMAKE_BUILD_TYPE} MATCHES "Debug")
+ message(STATUS "Open Debug Info")
+ target_compile_options(${PROJECT_NAME} PRIVATE -O0 -g)
+endif ()
+target_include_directories(${PROJECT_NAME} PRIVATE ${PROJECT_ROOT_DIR}/plugin_core/include)
+target_link_libraries(${PROJECT_NAME} PRIVATE ${PROJECT_ROOT_DIR}/output/lib/libmsinsight.so)
+target_link_libraries(${PROJECT_NAME} PRIVATE mindboard::protobuf)
+set(CMAKE_CXX_VISIBILITY_PRESET default)
+if (${CMAKE_SYSTEM_NAME} MATCHES "Linux")
+ target_link_libraries(${PROJECT_NAME} PRIVATE stdc++fs)
+endif ()
+
+#-------- test --------
+
+if (ENABLE_TESTCASES OR ENABLE_CPP_ST)
+ enable_testing()
+ aux_source_directory(test TEST_SRC)
+ aux_source_directory(${PROJECT_ROOT_DIR}/plugin_core/src TEST_SRC)
+ list(APPEND ${PROJECT_NAME}_TEST_SRC
+ ${${PROJECT_NAME}_SRC}
+ ${TEST_SRC})
+ list(APPEND ${PROJECT_NAME}_TEST_H
+ ${${PROJECT_NAME}_H})
+ add_executable(${PROJECT_NAME}_test ${${PROJECT_NAME}_TEST_SRC})
+ message(STATUS "Dir:${PROJECT_ROOT_DIR}")
+ target_include_directories(${PROJECT_NAME}_test PUBLIC ${${PROJECT_NAME}_TEST_H})
+ target_include_directories(${PROJECT_NAME}_test PUBLIC ${PROJECT_ROOT_DIR}/plugin_core/include)
+ target_include_directories(${PROJECT_NAME}_test PUBLIC ${PROJECT_ROOT_DIR})
+ target_include_directories(${PROJECT_NAME}_test PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/plugin)
+ # if (CMAKE_SYSTEM_NAME STREQUAL "Linux")
+ message(STATUS "Open Debug ${CMAKE_BUILD_TYPE}")
+ target_link_libraries(${PROJECT_NAME}_test PRIVATE stdc++fs)
+ # endif ()
+ target_link_libraries(${PROJECT_NAME}_test PRIVATE mindboard::protobuf)
+ target_link_libraries(${PROJECT_NAME}_test PUBLIC gtest_main gmock_main)
+ target_link_libraries(${PROJECT_NAME}_test PUBLIC ${PROJECT_ROOT_DIR}/output/lib/libmsinsight.so)
+ if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
+ target_compile_options(${PROJECT_NAME}_test PRIVATE -O0 -g)
+ endif ()
+
+ set_target_properties(${PROJECT_NAME}_test PROPERTIES EXCLUDE_FROM_ALL true)
+endif ()
\ No newline at end of file
diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/defs/HistoConceptDefs.h b/plugins/mindstudio-insight-plugins/Histogram/server/src/defs/HistoConceptDefs.h
new file mode 100644
index 0000000000000000000000000000000000000000..223d65333b0758357e56ccc563a2ea4ec0ce2a72
--- /dev/null
+++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/defs/HistoConceptDefs.h
@@ -0,0 +1,140 @@
+/*
+* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved.
+ */
+#ifndef CONCEPTDEFS_H
+#define CONCEPTDEFS_H
+#include <algorithm>
+#include <cstdint>
+#include <iterator>
+#include <numeric>
+#include <vector>
+
+#include "rapidjson.h"
+#include "document.h"
+#ifdef _WIN32
+#include <filesystem>
+namespace fs = std::filesystem;
+#else
+#include <experimental/filesystem>
+namespace fs = std::experimental::filesystem;
+#endif
+#include <random>
+
+#include "proto/event.pb.h"
+#include "proto/mindspore_summary.pb.h"
+
+namespace Insight::Histogram {
+ using value_t = rapidjson::Value;
+ using document_t = rapidjson::Document;
+
+ enum class ParseDataType {
+ MindSpore_Summary = 0,
+ TF_EVENT = 1,
+ Unknown = 2
+ };
+
+ enum ErrCode : int {
+ OK = 0,
+ INVALID_REQUEST_JSON,
+ REQUEST_INVALID_PARAM,
+ INVALID_PATH
+ };
+
+ // One line in the histogram: the y-axis bucket heights and the step it belongs to, converted from HistogramProto
+ struct HistogramLine {
+ int64_t step_{0};
+ std::vector<double> bucket_;
+ std::vector<double> bucketLimit_;
+ HistogramLine() = default;
+ HistogramLine(int64_t step, const tensorboard::HistogramProto& hisProto) {
+ step_ = step;
+ std::copy(hisProto.bucket().begin(), hisProto.bucket().end(), std::back_inserter(bucket_));
+ std::copy(hisProto.bucket_limit().begin(), hisProto.bucket_limit().end(),
+ std::back_inserter(bucketLimit_));
+ }
+
+ HistogramLine(int64_t step, const mindspore::irpb::Summary_Histogram& histogram) {
+ step_ = step;
+ for (auto& bucket : histogram.buckets()) {
+ bucketLimit_.push_back(bucket.left() + bucket.width());
+ bucket_.push_back(bucket.count());
+ }
+ }
+
+ value_t CreatHistoLineValue(document_t::AllocatorType& allocator) {
+ value_t histoGramLine(rapidjson::kObjectType);
+ // bucket_ data
+ value_t bucket(rapidjson::kArrayType);
+ for (const auto& count: bucket_) {
+ bucket.PushBack(count, allocator);
+ }
+ value_t bucketLimit(rapidjson::kArrayType);
+ for (const auto& limit: bucketLimit_) {
+ bucketLimit.PushBack(limit, allocator);
+ }
+ histoGramLine.AddMember("step", step_, allocator);
+ histoGramLine.AddMember("bucket", bucket, allocator);
+ histoGramLine.AddMember("bucketLimit", bucketLimit, allocator);
+ return histoGramLine;
+ }
+ };
+
+ struct HistogramGraph {
+ std::vector<HistogramLine> hsitogramLines_;
+ // Length of hsitogramLines_ at the previous getData request; 0 before the first request
+ uint64_t length{0};
+ // Cached downsampled data
+ std::vector<HistogramLine> downsampleCatchLines;
+ HistogramGraph() = default;
+ explicit HistogramGraph(const std::vector<HistogramLine>& hsitogramLines) : hsitogramLines_(hsitogramLines) {}
+ void AddValue(const HistogramLine& line) {
+ hsitogramLines_.push_back(line);
+ }
+ void MergeData(const HistogramGraph& graph) {
+ hsitogramLines_.insert(hsitogramLines_.begin(), graph.hsitogramLines_.begin(), graph.hsitogramLines_.end());
+ }
+
+ // Downsample the data
+ void SetdDownsampleCatchLines() {
+ constexpr size_t reservoirNum = 50;
+ // No new data since the last request: keep the existing cache
+ if (hsitogramLines_.size() <= length) {
+ return;
+ }
+ // Fewer points than the sample size: cache everything
+ if (reservoirNum > hsitogramLines_.size()) {
+ downsampleCatchLines = hsitogramLines_;
+ return;
+ }
+
+ std::vector<size_t> indices(hsitogramLines_.size() - 1);
+ std::iota(indices.begin(), indices.end(), size_t{0});
+
+ // Shuffle with the fixed seed 0 so the sample is deterministic
+ std::default_random_engine generator(0);
+ std::shuffle(indices.begin(), indices.end(), generator);
+
+ // Keep the first k - 1 indices, then always include the last element
+ indices.resize(reservoirNum - 1);
+ std::sort(indices.begin(), indices.end());
+ indices.push_back(hsitogramLines_.size() - 1);
+
+ // Rebuild the cache instead of appending duplicates to it
+ downsampleCatchLines.clear();
+ for (size_t idx : indices) {
+ downsampleCatchLines.push_back(hsitogramLines_[idx]);
+ }
+ }
+
+ void UpdateCatchData() {
+ // Refresh the cached downsampled data
+ SetdDownsampleCatchLines();
+ // Record the latest length
+ length = hsitogramLines_.size();
+ }
+
+ value_t CreatHistogramValue(document_t::AllocatorType& allocator) {
+ // A getData request has arrived, so refresh the downsampled cache first
+ UpdateCatchData();
+ // Convert to JSON
+ value_t hsitogramLinesValue(rapidjson::kArrayType);
+ for (auto& hsitogramLine: downsampleCatchLines) {
+ value_t histogram = hsitogramLine.CreatHistoLineValue(allocator);
+ hsitogramLinesValue.PushBack(histogram, allocator);
+ }
+ return hsitogramLinesValue;
+ }
+ };
+}
+#endif //CONCEPTDEFS_H
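The fixed-seed downsampling in `SetdDownsampleCatchLines` above has a simple contract: return at most 50 points, be deterministic across calls, and always retain the newest step. A Python sketch under those assumptions (not a byte-for-byte port; Python's `shuffle` differs from `std::shuffle`):

```python
import random

def downsample_indices(n, k=50):
    # Keep everything when there are no more than k points.
    if n <= k:
        return list(range(n))
    # Shuffle indices 0..n-2 with a fixed seed so the sample is deterministic,
    # keep the first k-1 of them, then always include the newest point (n-1).
    indices = list(range(n - 1))
    random.Random(0).shuffle(indices)
    indices = sorted(indices[:k - 1])
    indices.append(n - 1)
    return indices

picked = downsample_indices(200)
print(len(picked), picked[-1])  # 50 199
```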
diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/defs/RequestDef.h b/plugins/mindstudio-insight-plugins/Histogram/server/src/defs/RequestDef.h
new file mode 100644
index 0000000000000000000000000000000000000000..55a63a98021ddce904a247b27c4d9cabe83f62b8
--- /dev/null
+++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/defs/RequestDef.h
@@ -0,0 +1,99 @@
+/*
+* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved.
+ */
+#ifndef REQUEST_DEF_H
+#define REQUEST_DEF_H
+
+#include <map>
+#include <set>
+#include <string>
+#include <string_view>
+#include <vector>
+#include "rapidjson.h"
+#include "document.h"
+
+#include "HistoConceptDefs.h"
+
+namespace Insight::Histogram {
+ using document_t = rapidjson::Document;
+ using value_t = rapidjson::Value;
+ using size_type_t = rapidjson::SizeType;
+ struct ImportRequest {
+ std::vector<std::string> rootPathList;
+ ImportRequest() = default;
+ explicit ImportRequest(std::string_view request) {
+ std::string jsonString(request);
+ document_t document;
+ if (document.Parse(jsonString.c_str()).HasParseError()) {
+ return;
+ }
+ if (document.HasMember("rootPaths") && document["rootPaths"].IsArray()) {
+ const value_t& pathList = document["rootPaths"];
+ for (size_type_t i = 0; i < pathList.Size(); i++) {
+ rootPathList.emplace_back(pathList[i].GetString());
+ }
+ }
+ }
+ };
+
+ // Interface for fetching newly added files: returns the list of paths under the root path; takes no file-list parameter
+ struct GetNewFilePathRequest {
+ };
+
+ // Update the file-to-tag mapping; returns file-to-tags pairs, which always come from newly imported, not-yet-parsed files
+ struct AddImportFileRequest {
+ std::vector<std::string> filepathList; // the concrete files to import
+ AddImportFileRequest() = default;
+ explicit AddImportFileRequest(std::string_view request) {
+ std::string jsonString(request);
+ document_t document;
+ if (document.Parse(jsonString.c_str()).HasParseError()) {
+ return;
+ }
+ if (document.HasMember("filePathList") && document["filePathList"].IsArray()) {
+ const value_t& pathList = document["filePathList"];
+ for (size_type_t i = 0; i < pathList.Size(); i++) {
+ filepathList.emplace_back(pathList[i].GetString());
+ }
+ }
+ }
+ };
+
+ /*
+ * The frontend sends a JSON request in the following format:
+ * {
+ * "filePathToTags": [
+ * {"filePath": "/home/text/path1", "tag": "tagXXX1"},
+ * {"filePath": "/home/text/path2", "tag": "tagXXX2"}
+ * ]
+ * }
+ */
+ // Fetching a specific graph is scoped to a given tag of a given file; one file may have several
+ struct GetHistoDataRequest {
+ std::map<std::string, std::set<std::string>> filePathTotags;
+ GetHistoDataRequest() = default;
+ explicit GetHistoDataRequest(std::string_view request)
+ {
+ std::string jsonString(request);
+ document_t document;
+ if (document.Parse(jsonString.c_str()).HasParseError()) {
+ return;
+ }
+ if (document.HasMember("filePathToTags") && document["filePathToTags"].IsArray()) {
+ const value_t& pathList = document["filePathToTags"];
+ for (size_type_t i = 0; i < pathList.Size(); i++) {
+ const value_t& pathToTag = pathList[i];
+ if (!pathToTag.HasMember("filePath") || !pathToTag.HasMember("tag")) {
+ continue;
+ }
+ if (!pathToTag["filePath"].IsString() || !pathToTag["tag"].IsString()) {
+ continue;
+ }
+ std::string path = pathToTag["filePath"].GetString();
+ if (filePathTotags.find(path) == filePathTotags.end()) {
+ filePathTotags[path] = std::set<std::string>();
+ }
+ filePathTotags[path].insert(pathToTag["tag"].GetString());
+ }
+ }
+ }
+ };
+}
+#endif //REQUEST_DEF_H
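The `GetHistoDataRequest` parsing rules above (group requested tags by file path, silently skipping malformed entries) can be sketched in Python; the field names come from the C++ code, the rest is illustrative:

```python
import json

def parse_get_histo_request(raw):
    # Malformed JSON yields an empty request, as the C++ parser returns early.
    result = {}
    try:
        doc = json.loads(raw)
    except ValueError:
        return result
    entries = doc.get("filePathToTags")
    if not isinstance(entries, list):
        return result
    for entry in entries:
        if not isinstance(entry, dict):
            continue
        path, tag = entry.get("filePath"), entry.get("tag")
        # Skip entries missing either field or carrying non-string values.
        if isinstance(path, str) and isinstance(tag, str):
            result.setdefault(path, set()).add(tag)
    return result

raw = '{"filePathToTags": [{"filePath": "/home/text/path1", "tag": "tagXXX1"}]}'
print(parse_get_histo_request(raw))  # {'/home/text/path1': {'tagXXX1'}}
```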
diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/defs/ResponseDef.h b/plugins/mindstudio-insight-plugins/Histogram/server/src/defs/ResponseDef.h
new file mode 100644
index 0000000000000000000000000000000000000000..26cfca53d88d0310402a54ffd6d158d899d6d632
--- /dev/null
+++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/defs/ResponseDef.h
@@ -0,0 +1,141 @@
+/*
+* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved.
+*/
+
+#ifndef RESPONSE_DEF_H
+#define RESPONSE_DEF_H
+
+#include <map>
+#include <set>
+#include <string>
+#include <vector>
+#include "rapidjson.h"
+#include "document.h"
+#include "stringbuffer.h"
+#include "writer.h"
+
+#include "defs/HistoConceptDefs.h"
+
+namespace Insight::Histogram {
+
+using allocator_t = rapidjson::Document::AllocatorType;
+using document_t = rapidjson::Document;
+using value_t = rapidjson::Value;
+using size_type_t = rapidjson::SizeType;
+
+struct ResponseDef {
+ std::string message_;
+ bool result_{true};
+ int errCode_{0};
+ ResponseDef() = default;
+ ResponseDef(const std::string & message, bool result, int errCode) :
+ message_(message), result_(result), errCode_(errCode) {}
+
+ virtual value_t CreatDataValue(document_t::AllocatorType& allocator) {
+ value_t body(rapidjson::kObjectType);
+ return body;
+ }
+
+ document_t CreatBasicJson() {
+ document_t document;
+ document.SetObject();
+ document_t::AllocatorType& allocator = document.GetAllocator();
+ document.AddMember("errCode", errCode_, allocator);
+ document.AddMember("msg", value_t().SetString(message_.c_str(), allocator), allocator);
+ document.AddMember("result", result_, allocator);
+
+ value_t body = CreatDataValue(allocator);
+ document.AddMember("body", body, allocator);
+ return document;
+ }
+
+ std::string ToJsonString()
+ {
+ document_t document = CreatBasicJson();
+ // Serialize the JSON document to a string
+ rapidjson::StringBuffer buffer;
+ rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
+ document.Accept(writer);
+
+ std::string jsonString = buffer.GetString();
+ return jsonString;
+ }
+ virtual ~ResponseDef() = default;
+};
+
+struct ImportResponse : public ResponseDef {
+ ImportResponse() = default;
+ explicit ImportResponse(const std::map<std::string, std::set<std::string>>& filePathToTags) : filePathToTags_(filePathToTags) {}
+ ImportResponse(const std::string& message, bool result, int errCode, const std::map<std::string, std::set<std::string>>& filePathToTags) :
+ ResponseDef(message, result, errCode), filePathToTags_(filePathToTags) {}
+ std::map<std::string, std::set<std::string>> filePathToTags_;
+ // Convert the data to JSON and return it to the frontend
+ value_t CreatDataValue(document_t::AllocatorType& allocator) override
+ {
+ value_t data(rapidjson::kArrayType);
+ for (const auto& [filePath, tags] : filePathToTags_) {
+ // each filePath entry is a JSON object
+ value_t entry(rapidjson::kObjectType);
+ entry.AddMember("filePath", value_t().SetString(filePath.c_str(), allocator), allocator);
+ // tags is a set; create a JSON array object for it
+ value_t tagList(rapidjson::kArrayType);
+ for (const auto& tag : tags) {
+ tagList.PushBack(value_t().SetString(tag.c_str(), allocator), allocator);
+ }
+ entry.AddMember("tagList", tagList, allocator);
+ data.PushBack(entry, allocator);
+ }
+ value_t body(rapidjson::kObjectType);
+ body.AddMember("data", data, allocator);
+ return body;
+ }
+};
+
+// Fetch files newly added under the collection path; the request carries no parameters
+struct GetNewFilesResponse : public ResponseDef {
+ std::vector<std::string> filePaths_;
+ GetNewFilesResponse() = default;
+ explicit GetNewFilesResponse(const std::vector<std::string>& filePaths) : filePaths_(filePaths) {}
+ GetNewFilesResponse(const std::string& message, bool result, int errCode,
+ const std::vector<std::string>& filePaths) : ResponseDef(message, result, errCode), filePaths_(filePaths) {}
+ // Convert the data to JSON and return it to the frontend
+ value_t CreatDataValue(document_t::AllocatorType& allocator) override
+ {
+ value_t data(rapidjson::kArrayType);
+ for (const auto& filePath: filePaths_) {
+ data.PushBack(value_t().SetString(filePath.c_str(), allocator), allocator);
+ }
+ value_t body(rapidjson::kObjectType);
+ body.AddMember("data", data, allocator);
+ return body;
+ }
+};
+
+// Data is returned in file -> tag -> histogram form
+struct GetHistoDataResponse : public ResponseDef {
+ std::map<std::string, std::map<std::string, HistogramGraph>> filePathToInfo_;
+ GetHistoDataResponse() = default;
+ explicit GetHistoDataResponse(const std::map<std::string, std::map<std::string, HistogramGraph>>& filePathToInfo) :
+ filePathToInfo_(filePathToInfo) {}
+ GetHistoDataResponse(const std::string& message, bool result, int errCode,
+ const std::map<std::string, std::map<std::string, HistogramGraph>>& filePathToInfo) :
+ ResponseDef(message, result, errCode), filePathToInfo_(filePathToInfo) {}
+ // Convert the data to JSON and return it to the frontend
+ value_t CreatDataValue(document_t::AllocatorType& allocator) override
+ {
+ value_t data(rapidjson::kArrayType);
+ for (const auto& filePair: filePathToInfo_) {
+ for (const auto& tagPair: filePair.second) {
+ value_t filetagHisto(rapidjson::kObjectType);
+ filetagHisto.AddMember("filePath", value_t().SetString(filePair.first.c_str(), allocator), allocator);
+ filetagHisto.AddMember("tag", value_t().SetString(tagPair.first.c_str(), allocator), allocator);
+ HistogramGraph histogram = tagPair.second;
+ filetagHisto.AddMember("histogramGraph", histogram.CreatHistogramValue(allocator), allocator);
+
+ data.PushBack(filetagHisto, allocator);
+ }
+ }
+ return data;
+ }
+};
+
+}
+#endif //RESPONSE_DEF_H
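Every response defined above shares the envelope built by `CreatBasicJson`, with a handler-specific `body`. The shape, with hypothetical example values (the field names match the C++ code; the file path and tags are made up for illustration):

```python
import json

# Hypothetical ImportResponse payload as CreatBasicJson + CreatDataValue
# would emit it: errCode/msg/result envelope around a body.data array.
response = {
    "errCode": 0,
    "msg": "",
    "result": True,
    "body": {
        "data": [
            {"filePath": "/home/text/path1", "tagList": ["tagXXX1", "tagXXX2"]},
        ]
    },
}
print(json.dumps(response, indent=2))
```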
diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/AddImportFileRequestHandler.cpp b/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/AddImportFileRequestHandler.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..54cd3a84c7ba320defa2d6fa9931b04cd2454634
--- /dev/null
+++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/AddImportFileRequestHandler.cpp
@@ -0,0 +1,21 @@
+/*
+* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved.
+*/
+
+#include "AddImportFileRequestHandler.h"
+#include "histoParser/FileParser.h"
+#include "defs/ResponseDef.h"
+#include "utils/FileUtils.h"
+#include "defs/RequestDef.h"
+
+namespace Insight::Histogram::handler {
+using namespace Insight::Histogram::Parser;
+ bool AddImportFileRequestHandler::HandleRequest(std::string_view data, std::string &resultStr)
+ {
+ AddImportFileRequest req(data);
+ // Parse the files and obtain the file-to-tag mapping
+ ImportResponse rsp(GetTagsByFilePath(req.filepathList));
+ resultStr = rsp.ToJsonString();
+ return true;
+ }
+} // Histogram
diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/AddImportFileRequestHandler.h b/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/AddImportFileRequestHandler.h
new file mode 100644
index 0000000000000000000000000000000000000000..a25a59dd66115333134448892498fe300563ad3d
--- /dev/null
+++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/AddImportFileRequestHandler.h
@@ -0,0 +1,18 @@
+/*
+* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
+*/
+
+#ifndef UPDATEIMPORTFILEREQUESTHANDLER_H
+#define UPDATEIMPORTFILEREQUESTHANDLER_H
+
+#include <string>
+#include <string_view>
+#include "ImportRequestHandler.h"
+
+namespace Insight::Histogram::handler {
+ class AddImportFileRequestHandler : public ImportRequestHandler {
+ bool HandleRequest(std::string_view data, std::string &resultStr) override;
+ };
+} // Histogram
+#endif //UPDATEIMPORTFILEREQUESTHANDLER_H
diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/GetHistoDataRequestHandler.cpp b/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/GetHistoDataRequestHandler.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..b6c33a363daac6b3096edc364fb1d4258bc37a07
--- /dev/null
+++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/GetHistoDataRequestHandler.cpp
@@ -0,0 +1,65 @@
+/*
+* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved.
+*/
+
+#include "GetHistoDataRequestHandler.h"
+#include "histogramViewManager.h"
+#include "histoParser/ParserFactory.h"
+#include "histoParser/FileParser.h"
+#include "defs/ResponseDef.h"
+#include "utils/FileUtils.h"
+#include "RequestDef.h"
+
+namespace Insight::Histogram::handler {
+using namespace Insight::Histogram::Parser;
+
+ bool GetHistoDataRequestHandler::run(std::string_view data, std::string &result) {
+ return HandleRequest(data, result);
+ }
+
+ bool GetHistoDataRequestHandler::HandleRequest(std::string_view data, std::string &resultStr)
+ {
+ GetHistoDataRequest req(data);
+ std::map<std::string, std::set<std::string>> pathToTagsMap = req.filePathTotags;
+
+ HistogramViewManager& viewManager = HistogramViewManager::getInstance();
+
+ std::map<std::string, std::map<std::string, HistogramGraph>> filePathToInfo;
+ // Iterate over every file in the request
+ ParserFactory parserFactory;
+ for (auto & pathToTags : pathToTagsMap) {
+ std::string filePath = pathToTags.first;
+ if (pathToTags.second.empty()) {
+ // No tags requested under this file path, so skip it to save time
+ continue;
+ }
+ uint64_t offset = viewManager.GetReaderOffsetByFilePath(filePath);
+ std::shared_ptr<FileParser> parser = parserFactory.CreateFileParse(filePath);
+ if (!parser) {
+ continue;
+ }
+ // Parse new data starting from the offset where the last parse stopped
+ parser->ParseData(filePath, offset);
+ // Hand the latest data to the view manager, which appends it to the old data and refreshes it
+ viewManager.AppendNewFileInfo(filePath, offset, parser->GetTags(), parser->GetTagToHistoGraph());
+ // Fetch the merged data
+ std::map<std::string, HistogramGraph> newData = viewManager.GetHistoInfoByFilePath(filePath);
+ filePathToInfo[filePath] = FilterHistoDataByTag(pathToTags.second, newData);
+ }
+ GetHistoDataResponse rsp(filePathToInfo);
+ resultStr = rsp.ToJsonString();
+ return true;
+ }
+
+ // Return only the tag data the frontend asked for
+ std::map<std::string, HistogramGraph> GetHistoDataRequestHandler::FilterHistoDataByTag(std::set<std::string> tags, std::map<std::string, HistogramGraph> histoGraphs)
+ {
+ std::map result;
+ for (auto & tag : tags) {
+ if (histoGraphs.find(tag) != histoGraphs.end()) {
+ result[tag] = histoGraphs[tag];
+ }
+ }
+ return result;
+ }
+} // Histogram
diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/GetHistoDataRequestHandler.h b/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/GetHistoDataRequestHandler.h
new file mode 100644
index 0000000000000000000000000000000000000000..0c20a1ac00a159efb1f2e7e824b5a5e86c42ec86
--- /dev/null
+++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/GetHistoDataRequestHandler.h
@@ -0,0 +1,29 @@
+//
+// Created by admin on 2024/12/11.
+//
+
+#ifndef GET_HISTO_DATA_REQUEST_HANDLER_H
+#define GET_HISTO_DATA_REQUEST_HANDLER_H
+
+#include <map>
+#include <set>