diff --git a/.gitee/PULL_REQUEST_TEMPLATE.zh-CN.md b/.gitee/PULL_REQUEST_TEMPLATE.zh-CN.md index fc9e09f35030f71a8b23b5bc9fe86b120820b8bc..56c30e022ed179d7b318795ce5ec142dfa613d6e 100644 --- a/.gitee/PULL_REQUEST_TEMPLATE.zh-CN.md +++ b/.gitee/PULL_REQUEST_TEMPLATE.zh-CN.md @@ -17,6 +17,11 @@ --- +## 3. 分支合并要求 +- [ ] **代码合并**(请确保将 master 分支的最新代码同步合并至 poc 分支及 pre-research 分支,同时保证 poc 分支的代码也已正确合并到 pre-research 分支。) + +--- + ## 3. 代码检视 - **要求:** - 合入代码超过 200 行,需三人以上会议检视。 @@ -33,14 +38,13 @@ ## 4. 安全自检 ### Python、C++: -- [ ] **对外接口新增/删除/变更,需要更新外部输入表格** -- [ ] **不允许私有的文件操作,需要使用公共函数** -- [ ] **数组使用需要校验越界场景** -- [ ] **对正则表达式做 ReDos 校验** -- [ ] **对除法做除零校验** -- [ ] **充分进行接口返回值异常情况的校验** -- [ ] **充分进行接口输入值异常情况的校验** -- [ ] **日志不要暴露代码细节和敏感信息** +- [ ] **对外接口新增/删除/变更后,资料要同步新增/删除/变更,新增接口入参校验参考外部输入表格** +- [ ] **不允许私有的文件操作,需要使用公共模块的安全函数** +- [ ] **任务结束后需要删除临时文件,同时需要考虑任务失败后,临时文件没有残留** +- [ ] **数组访问需要校验越界场景,对除法需要做除零校验** +- [ ] **需要对递归方法做递归深度校验,正则表达式必须做 ReDoS 校验** +- [ ] **需要充分进行接口输入和返回值异常情况的校验** +- [ ] **日志打印不要出现拼写或语法错误,不要暴露代码细节和敏感信息** ### C++: - [ ] **指针使用前需要判空** diff --git a/.gitmodules b/.gitmodules index b08433f072bf89f62edf88b3aff40d24c1040ea8..0c8727a91869b9afe6f5c50ff759ecb5fb45988c 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ -[submodule "dynolog_npu/third_party/dynolog"] - path = dynolog_npu/third_party/dynolog +[submodule "msmonitor/third_party/dynolog"] + path = msmonitor/third_party/dynolog url = https://github.com/facebookincubator/dynolog.git diff --git "a/.\346\234\254\351\241\271\347\233\256\345\267\262\347\273\217\346\255\243\345\274\217\350\277\201\347\247\273\350\207\263 Gitcode \345\271\263\345\217\260/README.md" "b/.\346\234\254\351\241\271\347\233\256\345\267\262\347\273\217\346\255\243\345\274\217\350\277\201\347\247\273\350\207\263 Gitcode \345\271\263\345\217\260/README.md" new file mode 100644 index 0000000000000000000000000000000000000000..5bb8fadbf19ad1e7c299766c9ef99c9e6583a43c --- /dev/null +++ "b/.\346\234\254\351\241\271\347\233\256\345\267\262\347\273\217\346\255\243\345\274\217\350\277\201\347\247\273\350\207\263 Gitcode \345\271\263\345\217\260/README.md" @@ -0,0 +1 @@ +# 通知: 本项目已经正式迁移至 [Gitcode](https://gitcode.com/Ascend) 平台 diff --git a/OWNERS b/OWNERS index 415d737ed907c577bc61e71c2839a485395b899c..1b8f63546de38bc966852e4e1b318ad68a1161af 100644 --- a/OWNERS +++ b/OWNERS @@ -10,6 +10,7 @@ approvers: - ly-qianxiao - blian - kun_8 +- uniteone reviewers: - lv-kaimeng - wo-wenjie @@ -20,4 +21,7 @@ reviewers: - TAJh - czr9775 - kali20gakki -- wjchuee \ No newline at end of file +- wjchuee +- chenhao_1209 +- feng123www +- uniteone \ No newline at end of file diff --git a/README.md b/README.md index 5ae0bf742fced7ed86452d03d013670cc3528316..881f0298d246ce7e740b7fb829d4d6ce49c93281 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,4 @@ +# 通知: 本项目已经正式迁移至 [Gitcode](https://gitcode.com/Ascend) 平台 # 🚨 重要通知 **1. Ascend Training Tools 更名为 MindStudio Training Tools (mstt)。** @@ -12,47 +13,36 @@ ![Commit Activity](https://img.shields.io/badge/commit%20activity-high-red) ![License: Apache 2.0](https://img.shields.io/badge/license-Apache%202.0-blue) -## [分析迁移工具](https://gitee.com/ascend/mstt/wikis/工具介绍/分析迁移工具/分析迁移工具介绍) +## [模型训练开发全流程](https://www.hiascend.com/software/mindstudio/training) -1. [脚本分析工具](https://gitee.com/ascend/mstt/wikis/%E5%B7%A5%E5%85%B7%E4%BB%8B%E7%BB%8D/%E5%88%86%E6%9E%90%E8%BF%81%E7%A7%BB%E5%B7%A5%E5%85%B7/%E5%88%86%E6%9E%90%E5%B7%A5%E5%85%B7%E4%BD%BF%E7%94%A8%E6%8C%87%E5%AF%BC) +mstt包括精度工具(msprobe)和性能工具(msprof-analyze),分析迁移工具请参见[昇腾社区](https://www.hiascend.com/software/mindstudio/training)。 - 脚本分析工具可以帮助用户在执行迁移操作前,分析基于 GPU 平台的 PyTorch 训练脚本中算子、三方库套件、API 亲和性以及动态 shape 的支持情况。 +![training_process](debug/resources/training_process.png) -2. [(推荐)自动迁移工具](https://gitee.com/ascend/mstt/wikis/%E5%B7%A5%E5%85%B7%E4%BB%8B%E7%BB%8D/%E5%88%86%E6%9E%90%E8%BF%81%E7%A7%BB%E5%B7%A5%E5%85%B7/%E8%87%AA%E5%8A%A8%E8%BF%81%E7%A7%BB%E5%B7%A5%E5%85%B7%E4%BD%BF%E7%94%A8%E6%8C%87%E5%AF%BC) - - 自动迁移工具只需在训练脚本中导入库代码即可完成模型脚本的迁移,使用方式简单,且修改内容少。 - -3. [脚本迁移工具](https://gitee.com/ascend/mstt/wikis/%E5%B7%A5%E5%85%B7%E4%BB%8B%E7%BB%8D/%E5%88%86%E6%9E%90%E8%BF%81%E7%A7%BB%E5%B7%A5%E5%85%B7/%E8%84%9A%E6%9C%AC%E8%BF%81%E7%A7%BB%E5%B7%A5%E5%85%B7%E4%BD%BF%E7%94%A8%E6%8C%87%E5%AF%BC) - - 脚本迁移工具通过后端命令行,将 GPU 上训练的 PyTorch 脚本迁移至 NPU 上,得到新的训练脚本用于训练。 +# 使用说明 ## [精度工具](./debug/accuracy_tools/) [MindStudio Probe(msprobe,MindStudio 精度调试工具)](./debug/accuracy_tools/msprobe)。 -## [性能工具](./profiler/msprof_analyze) - -1. [compare_tools(性能比对工具)](./profiler/msprof_analyze/compare_tools) +## [性能工具](./profiler) - 提供 NPU 与 GPU 性能拆解功能以及算子、通信、内存性能的比对功能。 +[msprof-analyze(MindStudio Profiler Analyze 性能分析工具)](./profiler/msprof_analyze/) -2. [cluster_analyse(集群分析工具)](./profiler/msprof_analyze/cluster_analyse) +基于采集的性能数据进行分析,识别AI作业中的性能瓶颈。 - 提供多机多卡的集群分析能力(基于通信域的通信分析和迭代耗时分析), 当前需要配合 MindStudio Insight 的集群分析功能使用。 - -3. [advisor](./profiler/msprof_analyze/advisor) - - 将 Ascend PyTorch Profiler 或者 msprof 采集的 PyTorch 场景性能数据进行分析,并输出性能调优建议。 - -4. [bind_core](./profiler/affinity_cpu_bind) +[bind_core](./profiler/affinity_cpu_bind) 绑核脚本,支持非侵入修改工程代码,实现一键式绑核功能。 -## [Tensorboard](./plugins/tensorboard-plugins/tb_plugin) +[msMonitor](./msmonitor) + + MindStudio一站式在线监控工具。 -Tensorboard 支持 NPU 性能数据可视化插件 PyTorch Profiler TensorBoard NPU Plugin。 +## [Tensorboard](./plugins/tensorboard-plugins/tb_graph_ascend) +Tensorboard 支持模型结构进行分级可视化展示的插件 tb-graph-ascend。 -支持将 Ascend 平台采集、解析的 PyTorch Profiling 数据可视化呈现,也兼容 GPU 数据采集、解析可视化。 +可将模型的层级关系、精度数据进行可视化,并支持将调试模型和标杆模型进行分视图展示和关联比对,方便用户快速定位精度问题。 ## 分支维护策略 diff --git a/debug/OWNERS b/debug/OWNERS index 0bda9243569f0b6bcd0ce761d7817d512b487ddd..99d0920d89c96d3031e39b28fc3f059510ad3d46 100644 --- a/debug/OWNERS +++ b/debug/OWNERS @@ -4,13 +4,16 @@ approvers: - wangchao285 - kun_8 - brightlyking +- wqc01202410 +- shawnzhu1 +- pengxiaopeng1 +- zderry reviewers: - lv-kaimeng - TAJh - jiandaobao -- pengxiaopeng1 - zhengxinqian - louyujing - yang_chen_2001_02_14 -- shawnzhu1 -- wqc01202410 +- li-changwei4 +- qiangge123a diff --git a/debug/accuracy_tools/cmake/Findgtest.cmake b/debug/accuracy_tools/cmake/Findgtest.cmake index dbfe76abcc9b5d3c2f61642cc8c6e270fc441a0f..d4dd8d8895466d3367dff2032a7de03c829e3dc6 100644 --- a/debug/accuracy_tools/cmake/Findgtest.cmake +++ b/debug/accuracy_tools/cmake/Findgtest.cmake @@ -1,7 +1,6 @@ set(PACKAGE_VERSION 1.12.1) set(PKG_NAME gtest) -set(URL "https://gitee.com/mirrors/googletest/repository/archive/release-1.12.1.tar.gz") set(SHA256_VALUE "81964fe578e9bd7c94dfdb09c8e4d6e6759e19967e397dbea48d1c10e45d0df2") set(DOWNLOAD_PATH "$ENV{PROJECT_ROOT_PATH}/third_party") set(DIR_NAME "${DOWNLOAD_PATH}/googletest-release-1.12.1") @@ -9,7 +8,6 @@ set(DIR_NAME "${DOWNLOAD_PATH}/googletest-release-1.12.1") if (NOT ${PKG_NAME}_FOUND) download_opensource_pkg(${PKG_NAME} - URL ${URL} SHA256 ${SHA256_VALUE} DOWNLOAD_PATH ${DOWNLOAD_PATH} ) diff --git a/debug/accuracy_tools/cmake/Findmockcpp.cmake b/debug/accuracy_tools/cmake/Findmockcpp.cmake index c360702c187bfdef553a6b67344ea132a18373f6..73b1729aa5bec968c3e127560db981885c80ba83 100644 --- a/debug/accuracy_tools/cmake/Findmockcpp.cmake +++ b/debug/accuracy_tools/cmake/Findmockcpp.cmake @@ -1,7 +1,6 @@ set(PACKAGE_VERSION 2.7) set(PKG_NAME mockcpp) -set(URL "https://gitee.com/sinojelly/mockcpp/repository/archive/v2.7.zip") set(SHA256_VALUE "0dc7111c5be9785d0550ed3b68db7e12fd5d7802b7bc6548c52ac7b9e727fcc1") set(DOWNLOAD_PATH "$ENV{PROJECT_ROOT_PATH}/third_party") set(DIR_NAME "${DOWNLOAD_PATH}/mockcpp-v2.7") @@ -9,7 +8,6 @@ set(DIR_NAME "${DOWNLOAD_PATH}/mockcpp-v2.7") if (NOT ${PKG_NAME}_FOUND) download_opensource_pkg(${PKG_NAME} - URL ${URL} SHA256 ${SHA256_VALUE} DOWNLOAD_PATH ${DOWNLOAD_PATH} ) diff --git a/debug/accuracy_tools/cmake/Findnlohmannjson.cmake b/debug/accuracy_tools/cmake/Findnlohmannjson.cmake index 0f85cc00a0d30a3896a8f47cac95911929070e33..7acac96ca3ff8025745a6eeddbdf568e453a58f1 100644 --- a/debug/accuracy_tools/cmake/Findnlohmannjson.cmake +++ b/debug/accuracy_tools/cmake/Findnlohmannjson.cmake @@ -1,7 +1,6 @@ set(PACKAGE_VERSION 3.10.1) set(PKG_NAME nlohmannjson) -set(URL "https://gitee.com/mirrors/JSON-for-Modern-CPP/repository/archive/v3.10.1.zip") set(SHA256_VALUE "5c7d0a0542431fef628f8dc4c34fd022fe8747ccb577012d58f38672d8747e0d") set(DOWNLOAD_PATH "$ENV{PROJECT_ROOT_PATH}/third_party") set(DIR_NAME "${DOWNLOAD_PATH}/JSON-for-Modern-CPP-v3.10.1") @@ -9,7 +8,6 @@ set(DIR_NAME "${DOWNLOAD_PATH}/JSON-for-Modern-CPP-v3.10.1") if (NOT ${PKG_NAME}_FOUND) download_opensource_pkg(${PKG_NAME} - URL ${URL} SHA256 ${SHA256_VALUE} DOWNLOAD_PATH ${DOWNLOAD_PATH} ) diff --git a/debug/accuracy_tools/cmake/Findopenssl.cmake b/debug/accuracy_tools/cmake/Findopenssl.cmake index d361095242917df8accbb81a51de65c5ca5ac980..cc33bfc5902aa4c1651029789f04c8a4d2dc10bf 100644 --- a/debug/accuracy_tools/cmake/Findopenssl.cmake +++ b/debug/accuracy_tools/cmake/Findopenssl.cmake @@ -1,7 +1,6 @@ set(PACKAGE_VERSION 1.1.1) set(PKG_NAME openssl) -set(URL "https://gitee.com/mirrors/openssl/repository/archive/OpenSSL_1_1_1k.tar.gz") set(SHA256_VALUE "b92f9d3d12043c02860e5e602e50a73ed21a69947bcc74d391f41148e9f6aa95") set(DOWNLOAD_PATH "$ENV{PROJECT_ROOT_PATH}/third_party") set(DIR_NAME "${DOWNLOAD_PATH}/openssl-OpenSSL_1_1_1k") @@ -23,7 +22,6 @@ endif() endif() download_opensource_pkg(${PKG_NAME} - URL ${URL} SHA256 ${SHA256_VALUE} DOWNLOAD_PATH ${DOWNLOAD_PATH} ) diff --git a/debug/accuracy_tools/cmake/Findprotobuf.cmake b/debug/accuracy_tools/cmake/Findprotobuf.cmake index 4d70515e980f7a921447250fe58400f600419e4c..62c1fe7fbbebc6e0d76fec309a0154d5b102d3aa 100644 --- a/debug/accuracy_tools/cmake/Findprotobuf.cmake +++ b/debug/accuracy_tools/cmake/Findprotobuf.cmake @@ -1,10 +1,9 @@ -set(PACKAGE_VERSION 3.13.0) +set(PACKAGE_VERSION 3.15.0) set(PKG_NAME protobuf) -set(URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz") -set(SHA256_VALUE "ab9b39e7053a6fb06b01bf75fb6ec6a71a1ada5a5f8e2446f927336e97b9e7bb") +set(SHA256_VALUE "a1ce078c369f46a3277fdc7ce462ac73cb7cb0edec8bc9d90d23fdb34491c575") set(DOWNLOAD_PATH "$ENV{PROJECT_ROOT_PATH}/third_party") -set(DIR_NAME "${DOWNLOAD_PATH}/protobuf_source-v3.13.0") +set(DIR_NAME "${DOWNLOAD_PATH}/protobuf_source-v3.15.0") if (NOT ${PKG_NAME}_FOUND) @@ -32,7 +31,6 @@ endif() endif() download_opensource_pkg(${PKG_NAME} - URL ${URL} SHA256 ${SHA256_VALUE} DOWNLOAD_PATH ${DOWNLOAD_PATH} ) diff --git a/debug/accuracy_tools/cmake/Findre2.cmake b/debug/accuracy_tools/cmake/Findre2.cmake new file mode 100644 index 0000000000000000000000000000000000000000..bca1df0bac554b59dd17649ed9dc2687c49ff389 --- /dev/null +++ b/debug/accuracy_tools/cmake/Findre2.cmake @@ -0,0 +1,58 @@ +set(PKG_NAME re2) +set(SHA256_VALUE "7268e1b4254d9ffa5ccf010fee954150dbb788fd9705234442e7d9f0ee5a42d3") +set(DOWNLOAD_PATH "$ENV{PROJECT_ROOT_PATH}/third_party") +set(DIR_NAME "${DOWNLOAD_PATH}/re2-2019-12-01") +set(BUILD_DIR "${DIR_NAME}/build") +file(MAKE_DIRECTORY "${BUILD_DIR}") +set(BUILD_DEPENDENCY_PATH "$ENV{PROJECT_ROOT_PATH}/build_dependency/${PKG_NAME}") + +if (NOT ${PKG_NAME}_FOUND) + +file(GLOB RE2_INCLUDE "${BUILD_DEPENDENCY_PATH}/include/${PKG_NAME}/re2.h") +file(GLOB_RECURSE RE2_LIB "${BUILD_DEPENDENCY_PATH}/*libre2.a") +if (RE2_INCLUDE AND RE2_LIB) + include_directories(${BUILD_DEPENDENCY_PATH}/include) + set(${PKG_NAME}_LIBRARIES "${RE2_LIB}") + set(${PKG_NAME}_FOUND TRUE) + return() +endif() + +download_opensource_pkg(${PKG_NAME} + SHA256 ${SHA256_VALUE} + DOWNLOAD_PATH ${DOWNLOAD_PATH} +) + +execute_process( + WORKING_DIRECTORY ${BUILD_DIR} + COMMAND cmake -DCMAKE_INSTALL_PREFIX=${BUILD_DEPENDENCY_PATH} -DCMAKE_C_FLAGS=-fPIC -DCMAKE_CXX_FLAGS=-fPIC .. + RESULT_VARIABLE RESULT +) +if (NOT RESULT EQUAL 0) + message(FATAL_ERROR "Failed to build re2. ${RESULT}") +endif() + +execute_process( + WORKING_DIRECTORY ${BUILD_DIR} + COMMAND make -j16 + RESULT_VARIABLE RESULT +) +if (NOT RESULT EQUAL 0) + message(FATAL_ERROR "Failed to build re2. ${RESULT}") +endif() + +execute_process( + WORKING_DIRECTORY ${BUILD_DIR} + COMMAND make install +) + +file(GLOB RE2_INCLUDE "${BUILD_DEPENDENCY_PATH}/include/${PKG_NAME}/re2.h") +file(GLOB_RECURSE RE2_LIB "${BUILD_DEPENDENCY_PATH}/*libre2.a") +if (NOT RE2_INCLUDE OR NOT RE2_LIB) + message(FATAL_ERROR "Failed to build re2.") +endif() + +include_directories(${BUILD_DEPENDENCY_PATH}/include) +set(${PKG_NAME}_LIBRARIES "${RE2_LIB}") +set(${PKG_NAME}_FOUND TRUE) + +endif() diff --git a/debug/accuracy_tools/cmake/config.ini b/debug/accuracy_tools/cmake/config.ini new file mode 100644 index 0000000000000000000000000000000000000000..81b9ee5b4741952b02d09caa4b5b087f74194a88 --- /dev/null +++ b/debug/accuracy_tools/cmake/config.ini @@ -0,0 +1,17 @@ +[gtest] +url = https://gitee.com/mirrors/googletest/repository/archive/release-1.12.1.tar.gz + +[mockcpp] +url = https://gitee.com/sinojelly/mockcpp/repository/archive/v2.7.zip + +[nlohmannjson] +url = https://gitee.com/mirrors/JSON-for-Modern-CPP/repository/archive/v3.10.1.zip + +[openssl] +url = https://gitee.com/mirrors/openssl/repository/archive/OpenSSL_1_1_1k.tar.gz + +[protobuf] +url = https://gitee.com/mirrors/protobuf_source/repository/archive/v3.15.0.tar.gz + +[re2] +url = https://gitee.com/mirrors/re2/repository/archive/2019-12-01.tar.gz \ No newline at end of file diff --git a/debug/accuracy_tools/cmake/download_opensource.sh b/debug/accuracy_tools/cmake/download_opensource.sh index 725e971621434c32d9954c80b9efe234502eefcc..671dc218bb135a39ffc8937777815d84df654187 100644 --- a/debug/accuracy_tools/cmake/download_opensource.sh +++ b/debug/accuracy_tools/cmake/download_opensource.sh @@ -1,11 +1,11 @@ #!/bin/bash if [ "$#" -lt 2 ]; then - echo "Usage: $0 [ ] [ ]" + echo "Usage: $0 [ ] [ ]" exit 1 fi -url=$1 +pkg_name=$1 path=$2 if [ "$#" -ge 3 ]; then @@ -15,6 +15,16 @@ if [ "$#" -ge 4 ]; then tag=$4 fi +url=$(awk -F " = " '/\['${pkg_name}'\]/{a=1}a==1&&$1~/url/{print $2;exit}' config.ini) +lib_path=$MSTT_LIB_PATH +if [ -n "$lib_path" ]; then + url=${lib_path}$(echo $url | awk -F '/' -v OFS='/' '{print $5,$8}') +fi +if [[ ! $url = https* ]]; then + echo "The URL of $pkg_name is illegal." + exit 1 +fi + echo "Start to download ${url}..." if [ ! -d "$path" ]; then diff --git a/debug/accuracy_tools/cmake/utils.cmake b/debug/accuracy_tools/cmake/utils.cmake index e3e963d63e99da4e0bb1fd2973051278feb04435..738afff874f37bea442c33f6cf607a21bdd6cbe7 100644 --- a/debug/accuracy_tools/cmake/utils.cmake +++ b/debug/accuracy_tools/cmake/utils.cmake @@ -2,13 +2,10 @@ function(download_opensource_pkg pkg_name) message("start to download ${pkg_name}...") set(options) - set(oneValueArgs URL SHA256 GIT_TAG DOWNLOAD_PATH DIR_NAME BUILD_CMD) + set(oneValueArgs SHA256 GIT_TAG DOWNLOAD_PATH DIR_NAME BUILD_CMD) set(multiValueArgs PATCHES) cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - if (NOT PKG_URL) - message(FATAL_ERROR "${pkg_name} need URL.") - endif() if (NOT PKG_DOWNLOAD_PATH) set(PKG_DOWNLOAD_PATH "${CMAKE_SOURCE_DIR}/../third_party") endif() @@ -16,7 +13,7 @@ function(download_opensource_pkg pkg_name) execute_process( WORKING_DIRECTORY $ENV{PROJECT_ROOT_PATH}/cmake - COMMAND bash download_opensource.sh ${PKG_URL} ${PKG_DOWNLOAD_PATH} ${PKG_SHA256} ${PKG_GIT_TAG} + COMMAND bash download_opensource.sh ${pkg_name} ${PKG_DOWNLOAD_PATH} ${PKG_SHA256} ${PKG_GIT_TAG} RESULT_VARIABLE RESULT ) if (NOT RESULT EQUAL 0) diff --git a/debug/accuracy_tools/msprobe/README.md b/debug/accuracy_tools/msprobe/README.md index 0e68d1f8d9bdaba93a2f65220f85d08eb45f8586..7b612fbf3e2277e145b6bd471c809853ab541c45 100644 --- a/debug/accuracy_tools/msprobe/README.md +++ b/debug/accuracy_tools/msprobe/README.md @@ -44,6 +44,7 @@ export MSPROBE_LOG_LEVEL={x} - msprobe支持AscendPyTorch 1.11.0或更高版本,支持的PyTorch和CANN以及PyTorch和python软件版本配套关系请参见《[Ascend Extension for PyTorch插件](https://gitee.com/ascend/pytorch)》。 - msprobe支持MindSpore 2.4.0或更高版本,支持的MindSpore和CANN以及MindSpore和python软件版本配套关系请参见《[MindSpore版本发布列表](https://www.mindspore.cn/versions)》。 +- msprobe支持MSAdapter 2.1.0。 - msprobe支持的固件驱动版本与配套CANN软件支持的固件驱动版本相同,开发者可通过“[昇腾社区-固件与驱动](https://gitee.com/link?target=https%3A%2F%2Fwww.hiascend.com%2Fhardware%2Ffirmware-drivers%2Fcommunity%3Fproduct%3D2%26model%3D28%26cann%3D8.0.RC3.alpha003%26driver%3D1.0.25.alpha)”页面根据产品型号与CANN软件版本获取配套的固件与驱动。 @@ -53,7 +54,9 @@ export MSPROBE_LOG_LEVEL={x} **2. 工具读写的所有路径,如config_path、dump_path等,只允许包含大小写字母、数字、下划线、斜杠、点和短横线。** -## ⚙️ [安装](./docs/01.installation.md) +## ⚙️ 安装 + +请参见[安装指导说明](./docs/01.installation.md)。 ## 🌟 新版本特性 @@ -69,35 +72,37 @@ export MSPROBE_LOG_LEVEL={x} ### 1 数据采集 -msprobe 通过在训练脚本中添加 PrecisionDebugger 接口的方式对 API 执行精度数据 dump 操作,对应 config.json 中的 task 为 statistics 或 tensor。 +msprobe 通过在训练脚本中添加 PrecisionDebugger 接口的方式对 API 执行精度数据 dump 操作。对应 config.json 中的 "statistics" 或 "tensor" task。 [PyTorch 场景的数据采集](./docs/05.data_dump_PyTorch.md) [MindSpore 场景的数据采集](./docs/06.data_dump_MindSpore.md) +[MSAdapter 场景的数据采集](./docs/29.data_dump_MSAdapter.md) + ### 2 精度预检 -精度预检旨在昇腾 NPU 上扫描训练模型中的所有 API 进行 API 复现,给出精度情况的诊断和分析。对应 config.json 中的 task 为 run_ut。 +精度预检旨在昇腾 NPU 上扫描训练模型中的所有 API 进行 API 复现,给出精度情况的诊断和分析。对应 config.json 中的 "run_ut" task。 -PyTorch 场景的[离线预检](./docs/07.accuracy_checker_PyTorch.md)和[在线预检](./docs/08.accuracy_checker_online_PyTorch.md) +PyTorch 场景的[离线预检](./docs/07.accuracy_checker_PyTorch.md) MindSpore 动态图场景的[离线预检](./docs/09.accuracy_checker_MindSpore.md) -### 3 精度比对 +### 3 分级可视化构图比对 -该功能进行 PyTorch 整网 API 粒度的数据 dump、精度比对,进而定位训练场景下的精度问题。 +该功能将msprobe工具dump的精度数据进行解析,还原模型图结构,实现模型各个层级的精度数据比对,方便用户理解模型结构、分析精度问题。 -[PyTorch 场景的精度比对](./docs/10.accuracy_compare_PyTorch.md) +[PyTorch 场景的分级可视化构图比对](./docs/21.visualization_PyTorch.md) -[MindSpore 场景的精度比对](./docs/11.accuracy_compare_MindSpore.md) +[MindSpore 场景的分级可视化构图比对](./docs/22.visualization_MindSpore.md) -### 4 溢出检测与解析 +### 4 精度比对 -溢出检测与解析是在执行精度数据 dump 时,判断是否存在输入正常但输出存在溢出的 API,从而判断是否为正常溢出。对应 config.json 中的 overflow_check。 +该功能进行 PyTorch 整网 API 粒度的数据 dump、精度比对,进而定位训练场景下的精度问题。 -[PyTorch 场景的溢出检测与解析](./docs/12.overflow_check_PyTorch.md) +[PyTorch 场景的精度比对](./docs/10.accuracy_compare_PyTorch.md) -[MindSpore 场景的溢出检测与解析](./docs/13.overflow_check_MindSpore.md) +[MindSpore 场景的精度比对](./docs/11.accuracy_compare_MindSpore.md) ### 5 数据解析 @@ -129,26 +134,57 @@ MindSpore 动态图场景的[离线预检](./docs/09.accuracy_checker_MindSpore. [兼容 PyTorch 和 MindSpore 框架的训练状态监控](./docs/19.monitor.md) -### 10 分级可视化构图比对 +### 10 单算子API自动生成脚本 -该功能将msprobe工具dump的精度数据进行解析,还原模型图结构,实现模型各个层级的精度数据比对,方便用户理解模型结构、分析精度问题。 +该功能将msprobe工具dump的精度数据进行解析,自动生成单API脚本,用于复现整网中出现的算子问题,降低用户复现问题的成本,供开发分析算子问题。 -[PyTorch 场景的分级可视化构图比对](./docs/21.visualization_PyTorch.md) +[PyTorch 单算子API自动生成脚本](./docs/23.generate_operator_PyTorch.md) -[MindSpore 场景的分级可视化构图比对](./docs/22.visualization_MindSpore.md) +[MindSpore 单算子API自动生成脚本](./docs/33.generate_operator_MindSpore.md) +### 11 数码关联 -### 11 单算子API自动生成脚本 +该功能只支持 MindSpore 静态图场景,用于将IR图与dump数据进行关联,获取dump数据和代码调用栈的关联关系。 -该功能将msprobe工具dump的精度数据进行解析,自动生成单API脚本,用于复现整网中出现的算子问题,降低用户复现问题的成本,供开发分析算子问题。 +[MindSpore 场景的数码关联](./docs/24.code_mapping_Mindspore.md) -[PyTorch 单算子API自动生成脚本](./docs/23.generate_operator_PyTorch.md) +### 12 溢出检测与解析 -### 12 数码关联 +溢出检测用于采集溢出 API 或 模块的精度数据,而溢出解析则是通过对溢出数据的分析,进一步判断是否为正常溢出。对应 config.json 中的 "overflow_check" task。 +推荐直接使用[数据采集](#1-数据采集)功能采集统计量信息,检测溢出问题。 -该功能只支持 MindSpore 静态图场景,用于将IR图与dump数据进行关联,获取dump数据和代码调用栈的关联关系。 +[PyTorch 场景的溢出检测与解析](./docs/12.overflow_check_PyTorch.md) + +[MindSpore 场景的溢出检测](./docs/13.overflow_check_MindSpore.md) + +[MSAdapter 场景的溢出检测](./docs/30.overflow_check_MSAdapter.md) + +### 13 训练检查 + +该工具主要包括: + +训练前或精度比对前,对比两个环境下可能影响训练精度的配置差异。 + +[训练前配置检查](./docs/31.config_check.md) + +训练过程中或结束后,比较两个不同的checkpoint,评估模型相似度。 + +[checkpoint比对](./docs/32.ckpt_compare.md) + +### 14 强化学习数据采集 + +主要能力: + +灵活采集强化学习中重要关键过程数据,并支持比对。 + +[强化学习数据采集](./docs/34.RL_collect.md) + +### 15 整网首个溢出节点分析 + +多rank场景下通过dump数据找到首个出现Nan或Inf的节点。 + +[PyTorch 场景整网首个溢出节点分析](./docs/35.nan_analyze.md) -[MindSpore 场景的数码关联](./docs/24.code_mapping_Mindspore.md) ## 📑 补充材料 diff --git a/debug/accuracy_tools/msprobe/ccsrc/CMakeLists.txt b/debug/accuracy_tools/msprobe/ccsrc/CMakeLists.txt index 2579a3a0e785c0e0ca384b4d52118a5d828249f8..57313609f906d47f0ee04eb0a42ec3f7eedf4c35 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/CMakeLists.txt +++ b/debug/accuracy_tools/msprobe/ccsrc/CMakeLists.txt @@ -8,6 +8,7 @@ find_package(cpython MODULE REQUIRED) find_package(openssl MODULE REQUIRED) find_package(nlohmannjson MODULE REQUIRED) find_package(protobuf MODULE REQUIRED) +find_package(re2 MODULE REQUIRED) if (DEFINED CANN_PATH AND NOT "${CANN_PATH}" STREQUAL "") file(GLOB_RECURSE DUMP_DATA_PROTOS "${CANN_PATH}/**/dump_data.proto") @@ -26,6 +27,8 @@ compile_protobuf_file( ${PROTO_SRC} ) +set(CMAKE_SKIP_RPATH TRUE) + add_library(_msprobe_c SHARED) target_compile_options(_msprobe_c PRIVATE "-Wall") @@ -33,8 +36,9 @@ target_compile_options(_msprobe_c PRIVATE "-fPIC") target_compile_options(_msprobe_c PRIVATE "-fstack-protector-all") target_compile_options(_msprobe_c PRIVATE "-ftrapv") target_compile_options(_msprobe_c PRIVATE "-fstack-check") +target_compile_options(_msprobe_c PRIVATE "-D_FORTIFY_SOURCE=2") -target_link_options(_msprobe_c PRIVATE "-Wl,-z,relor") +target_link_options(_msprobe_c PRIVATE "-Wl,-z,relro") target_link_options(_msprobe_c PRIVATE "-Wl,-z,now") target_link_options(_msprobe_c PRIVATE "-Wl,-z,noexecstack") @@ -43,6 +47,7 @@ target_link_libraries(_msprobe_c PUBLIC pthread) target_link_libraries(_msprobe_c PUBLIC ${cpython_LIBRARIES}) target_link_libraries(_msprobe_c PUBLIC ${openssl_LIBRARIES}) target_link_libraries(_msprobe_c PUBLIC ${protobuf_LIBRARIES}) +target_link_libraries(_msprobe_c PUBLIC ${re2_LIBRARIES}) if(DEFINED BUILD_TYPE AND "${BUILD_TYPE}" STREQUAL "debug") target_compile_options(_msprobe_c PRIVATE "-O0") @@ -50,6 +55,7 @@ if(DEFINED BUILD_TYPE AND "${BUILD_TYPE}" STREQUAL "debug") target_compile_definitions(_msprobe_c PRIVATE __DEBUG__) else() target_compile_options(_msprobe_c PRIVATE "-O2") + target_link_options(_msprobe_c PRIVATE "-s") endif() target_include_directories(_msprobe_c PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/debug/accuracy_tools/msprobe/ccsrc/base/DebuggerConfig.cpp b/debug/accuracy_tools/msprobe/ccsrc/base/DebuggerConfig.cpp index 9f61e03a31f6d4dfa2ca0b258d589bbcd29356fa..e79a951614784e8dc31d897ce496dedc247f3603 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/base/DebuggerConfig.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/base/DebuggerConfig.cpp @@ -1,5 +1,5 @@ /* - * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Copyright (C) 2024-2025. Huawei Technologies Co., Ltd. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,21 +16,23 @@ #include #include +#include #include #include +#include -#include "include/ErrorCode.hpp" -#include "include/Macro.hpp" -#include "utils/FileUtils.hpp" -#include "base/ErrorInfos.hpp" -#include "DebuggerConfigFieldMap.hpp" -#include "DebuggerConfig.hpp" +#include "include/ErrorCode.h" +#include "include/Macro.h" +#include "utils/FileUtils.h" +#include "base/ErrorInfosManager.h" +#include "DebuggerConfigFieldMap.h" +#include "DebuggerConfig.h" namespace MindStudioDebugger { template DebuggerErrno ParseJsonBaseObj2Var(const nlohmann::json& content, const std::string& field, T& output, - bool mandatory=false) + bool mandatory = false) { nlohmann::json::const_iterator iter = content.find(field); if (iter == content.end()) { @@ -51,8 +53,12 @@ DebuggerErrno ParseJsonBaseObj2Var(const nlohmann::json& content, const std::str } template -DebuggerErrno ParseJsonStringAndTrans(const nlohmann::json& content, const std::string& field, - const std::map& enum2name, T& output, bool mandatory=false) { +DebuggerErrno ParseJsonStringAndTrans(const nlohmann::json& content, + const std::string& field, + const std::map& enum2name, + T& output, + bool mandatory = false) +{ DebuggerErrno ret; std::string value; @@ -66,7 +72,7 @@ DebuggerErrno ParseJsonStringAndTrans(const nlohmann::json& content, const std:: } int32_t enumId = GetEnumIdFromName(enum2name, value); - if (enumId == debuggerInvalidEnum) { + if (enumId == DEBUGGER_INVALID_ENUM) { return DebuggerErrno::ERROR_UNKNOWN_VALUE; } @@ -93,19 +99,21 @@ DebuggerErrno ParseJsonStringAndTrans(const nlohmann::json& content, const std:: static bool DebuggerCfgParseUIntRangeGetBorder(const std::string& exp, uint32_t& left, uint32_t& right) { if (std::count(exp.begin(), exp.end(), '-') != 1) { - LOG_ERROR(DebuggerErrno::ERROR_INVALID_FORMAT, "When using a range expression, it should be formatted as \"a-b\"."); + LOG_ERROR(DebuggerErrno::ERROR_INVALID_FORMAT, + "When using a range expression, it should be formatted as \"a-b\"."); return false; } std::istringstream iss(exp); char dash; iss >> left >> dash >> right; if (iss.fail() || dash != '-') { - LOG_ERROR(DebuggerErrno::ERROR_INVALID_FORMAT, "When using a range expression, it should be formatted as \"a-b\"."); + LOG_ERROR(DebuggerErrno::ERROR_INVALID_FORMAT, + "When using a range expression, it should be formatted as \"a-b\"."); return false; } if (left >= right) { LOG_ERROR(DebuggerErrno::ERROR_INVALID_FORMAT, - "When using a range expression, the left border should be smaller than the right."); + "When using a range expression, the left border should be smaller than the right."); return false; } return true; @@ -135,12 +143,18 @@ void DebuggerCfgParseUIntRange(const nlohmann::json& content, const std::string& realLen++; } else if (element.is_string()) { std::string exp = element.get(); - uint32_t begin, end; + uint32_t begin; + uint32_t end; if (!DebuggerCfgParseUIntRangeGetBorder(exp, begin, end)) { LOG_ERROR(DebuggerErrno::ERROR_INVALID_FORMAT, "Failed to parse " + name + "."); return; } - realLen += (end - begin + 1); + uint32_t rangeSize = end - begin; + if (realLen > UINT32_MAX - (rangeSize + 1)) { + LOG_ERROR(DebuggerErrno::ERROR_VALUE_OVERFLOW, name + " size exceeds limit"); + return; + } + realLen += (rangeSize + 1); buf.emplace_back(std::make_pair(begin, end)); } } @@ -148,7 +162,7 @@ void DebuggerCfgParseUIntRange(const nlohmann::json& content, const std::string& constexpr uint32_t maxEleNum = 65536; if (realLen > maxEleNum) { LOG_ERROR(DebuggerErrno::ERROR_INVALID_FORMAT, - "When using a range expression in " + name + ", maximum of 65536 elements can be expressed."); + "When using a range expression in " + name + ", maximum of 65536 elements can be expressed."); return; } @@ -170,9 +184,9 @@ void CommonCfgParseTasks(const nlohmann::json& content, std::vector(content, kTask, taskName, true); + ret = ParseJsonBaseObj2Var(content, TASK, taskName, true); if (ret == DebuggerErrno::ERROR_FIELD_NOT_EXISTS) { - ret = ParseJsonBaseObj2Var>(content, kTasks, taskNameList, true); + ret = ParseJsonBaseObj2Var>(content, TASKS, taskNameList, true); } else { taskNameList.emplace_back(taskName); } @@ -183,8 +197,8 @@ void CommonCfgParseTasks(const nlohmann::json& content, std::vector& expressions) { for (auto& expression : expressions) { size_t len = expression.size(); - if (strncmp(expression.c_str(), kRegexPrefix, kRegexPrefixLen) == 0 && - strncmp(expression.c_str() + (len - kRegexSuffixLen), kRegexSuffix, kRegexSuffixLen) == 0) { - /* name-regex(xxx)表示正则表达式*/ - regexList.emplace_back(expression.substr(kRegexPrefixLen, len - kRegexPrefixLen - kRegexSuffixLen)); + if (strncmp(expression.c_str(), REGEX_PREFIX, REGEX_PREFIX_LEN) == 0 && + strncmp(expression.c_str() + (len - REGEX_SUFFIX_LEN), REGEX_SUFFIX, REGEX_SUFFIX_LEN) == 0) { + /* name-regex(xxx)表示正则表达式 */ + regexList.emplace_back(expression.substr(REGEX_INDEX, len - REGEX_INDEX)); } else { /* 否则认为是full scope name */ fullNameList.emplace_back(expression); @@ -219,15 +234,19 @@ std::vector KernelListMatcher::GenRealKernelList(const char** fullK { std::vector output; /* 返回空列表表示全部dump,返回一个空字符串表示没有匹配上的,都不dump */ - if (this->empty() || fullKernelList == nullptr) { + if (this->Empty() || fullKernelList == nullptr) { return output; } output = fullNameList; - for (auto& reg : regexList) { - for (const char** ss = fullKernelList; *ss != nullptr; ++ss) { - if (std::regex_search(*ss, reg)) { - output.emplace_back(*ss); + for (auto& pattern : regexList) { + re2::RE2 reg(pattern, re2::RE2::Quiet); + if (reg.ok()) { + for (const char** ss = fullKernelList; *ss != nullptr; ++ss) { + std::string ret; + if (re2::RE2::FullMatch(*ss, reg, &ret)) { + output.emplace_back(*ss); + } } } } @@ -247,34 +266,38 @@ void CommonCfg::Parse(const nlohmann::json& content) return; } - PARSE_OPTIONAL_FIELD_CHECK_RET(content, kOutputPath, outputPath); + PARSE_OPTIONAL_FIELD_CHECK_RET(content, OUTPUT_PATH, outputPath); outputPath = FileUtils::GetAbsPath(outputPath); - DebuggerCfgParseUIntRange(content, kRank, rank); - DebuggerCfgParseUIntRange(content, kStep, step); - PARSE_OPTIONAL_FIELD_TRANS_CHECK_RET(content, kLevel, DebuggerLevelEnum2Name, level); - PARSE_OPTIONAL_FIELD_CHECK_RET(content, kSeed, seed); - PARSE_OPTIONAL_FIELD_CHECK_RET(content, kIsDeterministic, isDeterministic); - PARSE_OPTIONAL_FIELD_CHECK_RET(content, kEnableDataloader, enableDataloader); - PARSE_OPTIONAL_FIELD_CHECK_RET(content, kAclConfig, aclConfig); + DebuggerCfgParseUIntRange(content, RANK, rank); + DebuggerCfgParseUIntRange(content, STEP, step); + PARSE_OPTIONAL_FIELD_TRANS_CHECK_RET(content, LEVEL, DEBUGGER_LEVEL_ENUM_2_NAME, level); + PARSE_OPTIONAL_FIELD_CHECK_RET(content, SEED, seed); + PARSE_OPTIONAL_FIELD_CHECK_RET(content, IS_DETERMINISTIC, isDeterministic); + PARSE_OPTIONAL_FIELD_CHECK_RET(content, ENABLE_DATALOADER, enableDataloader); + PARSE_OPTIONAL_FIELD_CHECK_RET(content, ACL_CONFIG, aclConfig); } void DebuggerCfgParseDataMode(const nlohmann::json& content, DebuggerDataDirection& direction, DebuggerDataInOut& inout) { std::vector buf; - bool fw, bw, in, out, all; + bool fw; + bool bw; + bool in; + bool out; + bool all; direction = DebuggerDataDirection::DIRECTION_BOTH; inout = DebuggerDataInOut::INOUT_BOTH; - PARSE_OPTIONAL_FIELD_CHECK_RET(content, kDataMode, buf); - all = static_cast(std::find(buf.begin(), buf.end(), kDataModeAll) != buf.end()); + PARSE_OPTIONAL_FIELD_CHECK_RET(content, DATA_MODE, buf); + all = static_cast(std::find(buf.begin(), buf.end(), DATA_MODE_ALL) != buf.end()); if (buf.empty() || all) { return; } - fw = static_cast(std::find(buf.begin(), buf.end(), kDirectionForward) != buf.end()); - bw = static_cast(std::find(buf.begin(), buf.end(), kDirectionBackward) != buf.end()); - in = static_cast(std::find(buf.begin(), buf.end(), kInOutInput) != buf.end()); - out = static_cast(std::find(buf.begin(), buf.end(), kInOutOutput) != buf.end()); + fw = static_cast(std::find(buf.begin(), buf.end(), DIRECTION_FORWARD) != buf.end()); + bw = static_cast(std::find(buf.begin(), buf.end(), DIRECTION_BACKWARD) != buf.end()); + in = static_cast(std::find(buf.begin(), buf.end(), INOUT_INPUT) != buf.end()); + out = static_cast(std::find(buf.begin(), buf.end(), INOUT_OUTPUT) != buf.end()); /* 互补项都配或都不配都表示both,因此关注不同的场景就行 */ if (fw != bw) { @@ -298,18 +321,18 @@ void StatisticsCfgParseSummary(const nlohmann::json& content, std::vector modeListName; /* 若无该字段,认为是statistic,因此这里给mode设个默认值 */ - ret = ParseJsonBaseObj2Var(content, kSummaryMode, mode); + ret = ParseJsonBaseObj2Var(content, SUMMARY_MODE, mode); if (ret == DebuggerErrno::OK) { - if (mode == kStatistics) { + if (mode == STATISTICS) { summaryOption.push_back(DebuggerSummaryOption::MAX); summaryOption.push_back(DebuggerSummaryOption::MIN); summaryOption.push_back(DebuggerSummaryOption::MEAN); summaryOption.push_back(DebuggerSummaryOption::L2NORM); - } else if (mode == kMd5) { + } else if (mode == MD5) { summaryOption.push_back(DebuggerSummaryOption::MD5); } else { LOG_ERROR(DebuggerErrno::ERROR_UNKNOWN_VALUE, "Summary mode " + mode + " is unknown."); @@ -317,7 +340,7 @@ void StatisticsCfgParseSummary(const nlohmann::json& content, std::vector>(content, kSummaryMode, modeListName); + ret = ParseJsonBaseObj2Var>(content, SUMMARY_MODE, modeListName); if (ret != DebuggerErrno::OK) { LOG_ERROR(ret, "Value of field summary_mode should be string or list."); return; @@ -333,8 +356,8 @@ void StatisticsCfgParseSummary(const nlohmann::json& content, std::vector filter; - PARSE_OPTIONAL_FIELD_CHECK_RET(content, kScope, scope); - PARSE_OPTIONAL_FIELD_CHECK_RET(content, kList, filter); + PARSE_OPTIONAL_FIELD_CHECK_RET(content, SCOPE, scope); + PARSE_OPTIONAL_FIELD_CHECK_RET(content, LIST, filter); filter.erase(std::remove_if(filter.begin(), filter.end(), [](const std::string& s) { return s.find_first_not_of(' ') == std::string::npos; }), - filter.end()); + filter.end()); list = std::move(filter); if (DebuggerConfig::GetInstance().GetDebugLevel() == DebuggerLevel::L2) { matcher.Parse(list); @@ -363,24 +386,24 @@ void StatisticsCfg::Parse(const nlohmann::json& content) void DumpTensorCfg::Parse(const nlohmann::json& content) { std::vector filter; - PARSE_OPTIONAL_FIELD_CHECK_RET(content, kScope, scope); - PARSE_OPTIONAL_FIELD_CHECK_RET(content, kList, filter); + PARSE_OPTIONAL_FIELD_CHECK_RET(content, SCOPE, scope); + PARSE_OPTIONAL_FIELD_CHECK_RET(content, LIST, filter); filter.erase(std::remove_if(filter.begin(), filter.end(), [](const std::string& s) { return s.find_first_not_of(' ') == std::string::npos; }), - filter.end()); + filter.end()); list = std::move(filter); if (DebuggerConfig::GetInstance().GetDebugLevel() == DebuggerLevel::L2) { matcher.Parse(list); } DebuggerCfgParseDataMode(content, direction, inout); - PARSE_OPTIONAL_FIELD_TRANS_CHECK_RET(content, kFileFormat, DumpFileFormatEnum2Name, fileFormat); - PARSE_OPTIONAL_FIELD_CHECK_RET(content, kBackwardInput, backwardInput); + PARSE_OPTIONAL_FIELD_TRANS_CHECK_RET(content, FILE_FORMAT, DUMP_FILE_FORMAT_ENUM_2_NAME, fileFormat); + PARSE_OPTIONAL_FIELD_CHECK_RET(content, BACKWARD_INPUT, backwardInput); } void OverflowCheckCfg::Parse(const nlohmann::json& content) { - PARSE_OPTIONAL_FIELD_CHECK_RET(content, kOverflowNums, overflowNums); - PARSE_OPTIONAL_FIELD_TRANS_CHECK_RET(content, kCheckMode, OpCheckLevelEnum2Name, checkMode); + PARSE_OPTIONAL_FIELD_CHECK_RET(content, OVERFLOW_NUMS, overflowNums); + PARSE_OPTIONAL_FIELD_TRANS_CHECK_RET(content, CHECK_MODE, OP_CHECK_LEVEL_ENUM_2_NAME, checkMode); } void DebuggerConfig::Reset() @@ -419,14 +442,14 @@ void DebuggerConfig::Parse() iter = content.find(name); \ if (iter != content.end()) { \ member = std::make_shared(); \ - member->Parse(*(iter)); \ + ((member)->Parse(*(iter))); \ } \ } \ } while (0) - PARSE_SUBTASK_CONFIG(DebuggerTaskType::TASK_DUMP_STATISTICS, kTaskStatistics, statisticCfg, StatisticsCfg); - PARSE_SUBTASK_CONFIG(DebuggerTaskType::TASK_DUMP_TENSOR, kTaskDumpTensor, dumpTensorCfg, DumpTensorCfg); - PARSE_SUBTASK_CONFIG(DebuggerTaskType::TASK_OVERFLOW_CHECK, kTaskOverflowCheck, overflowCheckCfg, OverflowCheckCfg); + PARSE_SUBTASK_CONFIG(DebuggerTaskType::TASK_DUMP_STATISTICS, TASK_STATISTICS, statisticCfg, StatisticsCfg); + PARSE_SUBTASK_CONFIG(DebuggerTaskType::TASK_DUMP_TENSOR, TASK_DUMP_TENSOR, dumpTensorCfg, DumpTensorCfg); + PARSE_SUBTASK_CONFIG(DebuggerTaskType::TASK_OVERFLOW_CHECK, TASK_OVERFLOW_CHECK, overflowCheckCfg, OverflowCheckCfg); #undef PARSE_SUBTASK_CONFIG return; @@ -451,8 +474,8 @@ int32_t DebuggerConfig::LoadConfig(const std::string& framework, const std::stri return -1; } - int32_t enumId = GetEnumIdFromName(FrameworkEnum2Name, framework); - if (enumId == debuggerInvalidEnum) { + int32_t enumId = GetEnumIdFromName(FRAMEWORK_ENUM_2_NAME, framework); + if (enumId == DEBUGGER_INVALID_ENUM) { LOG_ERROR(DebuggerErrno::ERROR_UNKNOWN_VALUE, "Unknown framework " + framework + "."); return -1; } diff --git a/debug/accuracy_tools/msprobe/ccsrc/base/DebuggerConfig.hpp b/debug/accuracy_tools/msprobe/ccsrc/base/DebuggerConfig.h similarity index 88% rename from debug/accuracy_tools/msprobe/ccsrc/base/DebuggerConfig.hpp rename to debug/accuracy_tools/msprobe/ccsrc/base/DebuggerConfig.h index 15ea9e6fda47c0380d9718f135a1baf0658788eb..bdae11cb721837d05440162436c05d8306a15a55 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/base/DebuggerConfig.hpp +++ b/debug/accuracy_tools/msprobe/ccsrc/base/DebuggerConfig.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Copyright (C) 2024-2025. Huawei Technologies Co., Ltd. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,8 @@ * limitations under the License. */ -#pragma once +#ifndef DEBUGGERCONFIG_H +#define DEBUGGERCONFIG_H #include #include @@ -22,15 +23,14 @@ #include #include #include -#include #include #include -#include "include/Macro.hpp" +#include "include/Macro.h" namespace MindStudioDebugger { -constexpr int debuggerInvalidEnum = -1; +constexpr int DEBUGGER_INVALID_ENUM = -1; enum class DebuggerFramework { FRAMEWORK_PYTORCH, @@ -47,7 +47,7 @@ enum class DebuggerTaskType { TASK_RUN_UT, TASK_GRAD_PROBE, - TASK_BUTT = debuggerInvalidEnum, + TASK_BUTT = DEBUGGER_INVALID_ENUM, }; enum class DebuggerDevType { @@ -55,7 +55,7 @@ enum class DebuggerDevType { DEVICE_TYPE_GPU, DEVICE_TYPE_CPU, - DEVICE_TYPE_BUTT = debuggerInvalidEnum, + DEVICE_TYPE_BUTT = DEBUGGER_INVALID_ENUM, }; enum class DebuggerLevel { @@ -64,7 +64,7 @@ enum class DebuggerLevel { L2, MIX, - LEVEL_BUTT = debuggerInvalidEnum, + LEVEL_BUTT = DEBUGGER_INVALID_ENUM, }; enum class DebuggerDataDirection { @@ -72,7 +72,7 @@ enum class DebuggerDataDirection { DIRECTION_BACKWARD, DIRECTION_BOTH, - DIRECTION_BUTT = debuggerInvalidEnum, + DIRECTION_BUTT = DEBUGGER_INVALID_ENUM, }; enum class DebuggerDataInOut { @@ -80,14 +80,14 @@ enum class DebuggerDataInOut { INOUT_OUTPUT, INOUT_BOTH, - INOUT_BUTT = debuggerInvalidEnum, + INOUT_BUTT = DEBUGGER_INVALID_ENUM, }; enum class DebuggerDumpFileFormat { FILE_FORMAT_BIN, FILE_FORMAT_NPY, - FILE_FORMAT_BUTT = debuggerInvalidEnum, + FILE_FORMAT_BUTT = DEBUGGER_INVALID_ENUM, }; enum class DebuggerOpCheckLevel { @@ -95,7 +95,7 @@ enum class DebuggerOpCheckLevel { CHECK_LEVEL_ATOMIC, CHECK_LEVEL_ALL, - CHECK_LEVEL_BUTT = debuggerInvalidEnum, + CHECK_LEVEL_BUTT = DEBUGGER_INVALID_ENUM, }; enum class DebuggerSummaryOption { @@ -108,7 +108,7 @@ enum class DebuggerSummaryOption { POS_INF_CNT, MD5, - SUMMARY_BUTT = debuggerInvalidEnum, + SUMMARY_BUTT = DEBUGGER_INVALID_ENUM, }; class KernelListMatcher { @@ -119,12 +119,12 @@ public: void Parse(const std::vector& expressions); std::vector GenRealKernelList(const char** fullKernelList) const; - inline bool empty() const {return fullNameList.empty() && regexList.empty();} - inline bool needAllKernels() const {return !regexList.empty();} + inline bool Empty() const {return fullNameList.empty() && regexList.empty();} + inline bool NeedAllKernels() const {return !regexList.empty();} private: std::vector fullNameList; - std::vector regexList; + std::vector regexList; }; /* 说明:config类作为基础的配置解析查询类,对外应该是只读的,外部仅能通过Parse接口解析配置文件,而不应该直接修改配置字段,此处用以下方式防止外部误操作 @@ -199,7 +199,7 @@ public: OverflowCheckCfg() = default; ~OverflowCheckCfg() = default; - uint32_t overflowNums{1}; + int32_t overflowNums{1}; DebuggerOpCheckLevel checkMode{DebuggerOpCheckLevel::CHECK_LEVEL_ALL}; private: @@ -208,11 +208,11 @@ private: class DebuggerConfig { - public: - static DebuggerConfig& GetInstance() { - static DebuggerConfig instance_; - return instance_; + static DebuggerConfig& GetInstance() + { + static DebuggerConfig configInstance; + return configInstance; } int32_t LoadConfig(const std::string& framework, const std::string& cfgFilePath); @@ -262,4 +262,6 @@ private: std::shared_ptr overflowCheckCfg{nullptr}; }; -} \ No newline at end of file +} + +#endif \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/ccsrc/base/DebuggerConfigFieldMap.hpp b/debug/accuracy_tools/msprobe/ccsrc/base/DebuggerConfigFieldMap.h similarity index 30% rename from debug/accuracy_tools/msprobe/ccsrc/base/DebuggerConfigFieldMap.hpp rename to debug/accuracy_tools/msprobe/ccsrc/base/DebuggerConfigFieldMap.h index 8ebef4206b42b702712edccc5b19d9611370c63b..95954ecd417275c6e38fc37f01a6f8bb18c939e4 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/base/DebuggerConfigFieldMap.hpp +++ b/debug/accuracy_tools/msprobe/ccsrc/base/DebuggerConfigFieldMap.h @@ -19,129 +19,129 @@ #include #include -#include "DebuggerConfig.hpp" +#include "DebuggerConfig.h" namespace MindStudioDebugger { -constexpr const char* kFramework = "framework"; -constexpr const char* kFrameworkPyTorch = "PyTorch"; -constexpr const char* kFrameworkMindSpore = "MindSpore"; - -constexpr const char* kTaskStatistics = "statistics"; -constexpr const char* kTaskDumpTensor = "tensor"; -constexpr const char* kTaskOverflowCheck = "overflow_check"; -constexpr const char* kFreeBenchmark = "free_benchmark"; -constexpr const char* kRunUT = "run_ut"; -constexpr const char* kGradProbe = "grad_probe"; - -constexpr const char* kLevel0 = "L0"; -constexpr const char* kLevel1 = "L1"; -constexpr const char* kLevel2 = "L2"; -constexpr const char* kLevelMix = "mix"; - -constexpr const char* kDirectionForward = "forward"; -constexpr const char* kDirectionBackward = "backward"; -constexpr const char* kDirectionBoth = "both"; -constexpr const char* kInOutInput = "input"; -constexpr const char* kInOutOutput = "output"; -constexpr const char* kInOutBoth = "both"; -constexpr const char* kDataModeAll = "all"; - -constexpr const char* kFreeBenchmarkHandlerCheck = "check"; -constexpr const char* kFreeBenchmarkHandlerFix = "fix"; - -constexpr const char* kDumpFileFormatBin = "bin"; -constexpr const char* kDumpFileFormatNpy = "npy"; - -constexpr const char* kOpCheckLevelAiCore = "aicore"; -constexpr const char* kOpCheckLevelAtomic = "atomic"; -constexpr const char* kOpCheckLevelAll = "all"; - -constexpr const char* kTask = "task"; -constexpr const char* kTasks = "tasks"; -constexpr const char* kOutputPath = "dump_path"; -constexpr const char* kRank = "rank"; -constexpr const char* kStep = "step"; -constexpr const char* kLevel = "level"; -constexpr const char* kSeed = "seed"; -constexpr const char* kIsDeterministic = "is_deterministic"; -constexpr const char* kEnableDataloader = "enable_dataloader"; -constexpr const char* kAclConfig = "acl_config"; - -constexpr const char* kScope = "scope"; -constexpr const char* kList = "list"; - -constexpr const char* kDataMode = "data_mode"; -constexpr const char* kSummaryMode = "summary_mode"; -constexpr const char* kFileFormat = "file_format"; -constexpr const char* kOverflowNums = "overflow_nums"; -constexpr const char* kCheckMode = "check_mode"; -constexpr const char* kBackwardInput = "backward_input"; - -constexpr const char* kStatistics = "statistics"; -constexpr const char* kMd5 = "md5"; -constexpr const char* kMax = "max"; -constexpr const char* kMin = "min"; -constexpr const char* kMean = "mean"; -constexpr const char* kL2Norm = "l2norm"; -constexpr const char* kNanCount = "nan count"; -constexpr const char* kNegativeInfCount = "negative inf count"; -constexpr const char* kPositiveInfCount = "positive inf count"; - -const std::map FrameworkEnum2Name = { - {static_cast(DebuggerFramework::FRAMEWORK_PYTORCH), kFrameworkPyTorch}, - {static_cast(DebuggerFramework::FRAMEWORK_MINDSPORE), kFrameworkMindSpore}, +constexpr const char* FRAMEWORK = "framework"; +constexpr const char* FRAMEWORK_PYTORCH = "PyTorch"; +constexpr const char* FRAMEWORK_MINDSPORE = "MindSpore"; + +constexpr const char* TASK_STATISTICS = "statistics"; +constexpr const char* TASK_DUMP_TENSOR = "tensor"; +constexpr const char* TASK_OVERFLOW_CHECK = "overflow_check"; +constexpr const char* TASK_FREE_BENCHMARK = "free_benchmark"; +constexpr const char* TASK_RUN_UT = "run_ut"; +constexpr const char* TASK_GRAD_PROBE = "grad_probe"; + +constexpr const char* LEVEL0 = "L0"; +constexpr const char* LEVEL1 = "L1"; +constexpr const char* LEVEL2 = "L2"; +constexpr const char* LEVEL_MIX = "mix"; + +constexpr const char* DIRECTION_FORWARD = "forward"; +constexpr const char* DIRECTION_BACKWARD = "backward"; +constexpr const char* DIRECTION_BOTH = "both"; +constexpr const char* INOUT_INPUT = "input"; +constexpr const char* INOUT_OUTPUT = "output"; +constexpr const char* INOUT_BOTH = "both"; +constexpr const char* DATA_MODE_ALL = "all"; + +constexpr const char* FREE_BENCHMARK_HANDLER_CHECK = "check"; +constexpr const char* FREE_BENCHMARK_HANDLER_FIX = "fix"; + +constexpr const char* DUMP_FILE_FORMAT_BIN = "bin"; +constexpr const char* DUMP_FILE_FORMAT_NPY = "npy"; + +constexpr const char* OP_CHECK_LEVEL_AICORE = "aicore"; +constexpr const char* OP_CHECK_LEVEL_ATOMIC = "atomic"; +constexpr const char* OP_CHECK_LEVEL_ALL = "all"; + +constexpr const char* TASK = "task"; +constexpr const char* TASKS = "tasks"; +constexpr const char* OUTPUT_PATH = "dump_path"; +constexpr const char* RANK = "rank"; +constexpr const char* STEP = "step"; +constexpr const char* LEVEL = "level"; +constexpr const char* SEED = "seed"; +constexpr const char* IS_DETERMINISTIC = "is_deterministic"; +constexpr const char* ENABLE_DATALOADER = "enable_dataloader"; +constexpr const char* ACL_CONFIG = "acl_config"; + +constexpr const char* SCOPE = "scope"; +constexpr const char* LIST = "list"; + +constexpr const char* DATA_MODE = "data_mode"; +constexpr const char* SUMMARY_MODE = "summary_mode"; +constexpr const char* FILE_FORMAT = "file_format"; +constexpr const char* OVERFLOW_NUMS = "overflow_nums"; +constexpr const char* CHECK_MODE = "check_mode"; +constexpr const char* BACKWARD_INPUT = "backward_input"; + +constexpr const char* STATISTICS = "statistics"; +constexpr const char* MD5 = "md5"; +constexpr const char* MAX = "max"; +constexpr const char* MIN = "min"; +constexpr const char* MEAN = "mean"; +constexpr const char* L2_NORM = "l2norm"; +constexpr const char* NAN_COUNT = "nan count"; +constexpr const char* NEGATIVE_INF_COUNT = "negative inf count"; +constexpr const char* POSITIVE_INF_COUNT = "positive inf count"; + +const std::map FRAMEWORK_ENUM_2_NAME = { + {static_cast(DebuggerFramework::FRAMEWORK_PYTORCH), FRAMEWORK_PYTORCH}, + {static_cast(DebuggerFramework::FRAMEWORK_MINDSPORE), FRAMEWORK_MINDSPORE}, }; -const std::map TaskTypeEnum2Name = { - {static_cast(DebuggerTaskType::TASK_DUMP_TENSOR), kTaskDumpTensor}, - {static_cast(DebuggerTaskType::TASK_DUMP_STATISTICS), kTaskStatistics}, - {static_cast(DebuggerTaskType::TASK_OVERFLOW_CHECK), kTaskOverflowCheck}, - {static_cast(DebuggerTaskType::TASK_FREE_BENCHMARK), kFreeBenchmark}, - {static_cast(DebuggerTaskType::TASK_RUN_UT), kRunUT}, - {static_cast(DebuggerTaskType::TASK_GRAD_PROBE), kGradProbe}, +const std::map TASK_TYPE_ENUM_2_NAME = { + {static_cast(DebuggerTaskType::TASK_DUMP_TENSOR), TASK_DUMP_TENSOR}, + {static_cast(DebuggerTaskType::TASK_DUMP_STATISTICS), TASK_STATISTICS}, + {static_cast(DebuggerTaskType::TASK_OVERFLOW_CHECK), TASK_OVERFLOW_CHECK}, + {static_cast(DebuggerTaskType::TASK_FREE_BENCHMARK), TASK_FREE_BENCHMARK}, + {static_cast(DebuggerTaskType::TASK_RUN_UT), TASK_RUN_UT}, + {static_cast(DebuggerTaskType::TASK_GRAD_PROBE), TASK_GRAD_PROBE}, }; -const std::map DebuggerLevelEnum2Name = { - {static_cast(DebuggerLevel::L0), kLevel0}, - {static_cast(DebuggerLevel::L1), kLevel1}, - {static_cast(DebuggerLevel::L2), kLevel2}, - {static_cast(DebuggerLevel::MIX), kLevelMix}, +const std::map DEBUGGER_LEVEL_ENUM_2_NAME = { + {static_cast(DebuggerLevel::L0), LEVEL0}, + {static_cast(DebuggerLevel::L1), LEVEL0}, + {static_cast(DebuggerLevel::L2), LEVEL2}, + {static_cast(DebuggerLevel::MIX), LEVEL_MIX}, }; -const std::map DataDirectionEnum2Name = { - {static_cast(DebuggerDataDirection::DIRECTION_FORWARD), kDirectionForward}, - {static_cast(DebuggerDataDirection::DIRECTION_BACKWARD), kDirectionBackward}, - {static_cast(DebuggerDataDirection::DIRECTION_BOTH), kDirectionBoth}, +const std::map DATA_DIRECTION_ENUM_2_NAME = { + {static_cast(DebuggerDataDirection::DIRECTION_FORWARD), DIRECTION_FORWARD}, + {static_cast(DebuggerDataDirection::DIRECTION_BACKWARD), DIRECTION_BACKWARD}, + {static_cast(DebuggerDataDirection::DIRECTION_BOTH), DIRECTION_BOTH}, }; -const std::map DataInOutEnum2Name = { - {static_cast(DebuggerDataInOut::INOUT_INPUT), kInOutInput}, - {static_cast(DebuggerDataInOut::INOUT_OUTPUT), kInOutOutput}, - {static_cast(DebuggerDataInOut::INOUT_BOTH), kInOutBoth}, +const std::map DATA_INOUT_ENUM_2_NAME = { + {static_cast(DebuggerDataInOut::INOUT_INPUT), INOUT_INPUT}, + {static_cast(DebuggerDataInOut::INOUT_OUTPUT), INOUT_OUTPUT}, + {static_cast(DebuggerDataInOut::INOUT_BOTH), INOUT_BOTH}, }; -const std::map DumpFileFormatEnum2Name = { - {static_cast(DebuggerDumpFileFormat::FILE_FORMAT_BIN), kDumpFileFormatBin}, - {static_cast(DebuggerDumpFileFormat::FILE_FORMAT_NPY), kDumpFileFormatNpy}, +const std::map DUMP_FILE_FORMAT_ENUM_2_NAME = { + {static_cast(DebuggerDumpFileFormat::FILE_FORMAT_BIN), DUMP_FILE_FORMAT_BIN}, + {static_cast(DebuggerDumpFileFormat::FILE_FORMAT_NPY), DUMP_FILE_FORMAT_NPY}, }; -const std::map OpCheckLevelEnum2Name = { - {static_cast(DebuggerOpCheckLevel::CHECK_LEVEL_AICORE), kOpCheckLevelAiCore}, - {static_cast(DebuggerOpCheckLevel::CHECK_LEVEL_ATOMIC), kOpCheckLevelAtomic}, - {static_cast(DebuggerOpCheckLevel::CHECK_LEVEL_ALL), kOpCheckLevelAll}, +const std::map OP_CHECK_LEVEL_ENUM_2_NAME = { + {static_cast(DebuggerOpCheckLevel::CHECK_LEVEL_AICORE), OP_CHECK_LEVEL_AICORE}, + {static_cast(DebuggerOpCheckLevel::CHECK_LEVEL_ATOMIC), OP_CHECK_LEVEL_ATOMIC}, + {static_cast(DebuggerOpCheckLevel::CHECK_LEVEL_ALL), OP_CHECK_LEVEL_ALL}, }; -const std::map SummaryOptionEnum2Name = { - {static_cast(DebuggerSummaryOption::MAX), kMax}, - {static_cast(DebuggerSummaryOption::MIN), kMin}, - {static_cast(DebuggerSummaryOption::MEAN), kMean}, - {static_cast(DebuggerSummaryOption::NAN_CNT), kNanCount}, - {static_cast(DebuggerSummaryOption::NEG_INF_CNT), kNegativeInfCount}, - {static_cast(DebuggerSummaryOption::POS_INF_CNT), kPositiveInfCount}, - {static_cast(DebuggerSummaryOption::L2NORM), kL2Norm}, +const std::map SUMMARY_OPTION_ENUM_2_NAME = { + {static_cast(DebuggerSummaryOption::MAX), MAX}, + {static_cast(DebuggerSummaryOption::MIN), MIN}, + {static_cast(DebuggerSummaryOption::MEAN), MEAN}, + {static_cast(DebuggerSummaryOption::NAN_CNT), NAN_COUNT}, + {static_cast(DebuggerSummaryOption::NEG_INF_CNT), NEGATIVE_INF_COUNT}, + {static_cast(DebuggerSummaryOption::POS_INF_CNT), POSITIVE_INF_COUNT}, + {static_cast(DebuggerSummaryOption::L2NORM), L2_NORM}, - {static_cast(DebuggerSummaryOption::MD5), kMd5}, + {static_cast(DebuggerSummaryOption::MD5), MD5}, }; inline int32_t GetEnumIdFromName(const std::map& enum2name, const std::string& name) @@ -151,7 +151,7 @@ inline int32_t GetEnumIdFromName(const std::map& enum2name return iter->first; } } - return debuggerInvalidEnum; + return DEBUGGER_INVALID_ENUM; } inline std::string GetNameFromEnumId(const std::map& enum2name, int32_t id) diff --git a/debug/accuracy_tools/msprobe/ccsrc/base/Environment.cpp b/debug/accuracy_tools/msprobe/ccsrc/base/Environment.cpp index 3a31e03cf898901767e3c658b993edc14b76e35a..58d89561a0c2d19f760d87435d1c1c187aeb6bc7 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/base/Environment.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/base/Environment.cpp @@ -1,5 +1,5 @@ /* - * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Copyright (C) 2024-2025. Huawei Technologies Co., Ltd. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,14 +14,16 @@ * limitations under the License. */ -#include "utils/CPythonUtils.hpp" -#include "DebuggerConfig.hpp" -#include "Environment.hpp" +#include + +#include "utils/CPythonUtils.h" +#include "DebuggerConfig.h" +#include "Environment.h" namespace MindStudioDebugger { namespace Environment { -static int32_t GetRankID_PT() +static int32_t GetPTRankID() { /* if torch.distributed.is_initialized(): * return torch.distributed.get_rank() @@ -48,10 +50,10 @@ static int32_t GetRankID_PT() return id; } -static int32_t GetRankID_MS() +static int32_t GetMSRankID() { - constexpr const char* kRankId = "RANK_ID"; - const char* rankIdEnv = getenv(kRankId); + constexpr const char* RANK_ID = "RANK_ID"; + const char* rankIdEnv = getenv(RANK_ID); if (rankIdEnv == nullptr) { return -1; } @@ -78,9 +80,9 @@ int32_t GetRankID() } if (DebuggerConfig::GetInstance().GetFramework() == DebuggerFramework::FRAMEWORK_PYTORCH) { - id = GetRankID_PT(); + id = GetPTRankID(); } else if (DebuggerConfig::GetInstance().GetFramework() == DebuggerFramework::FRAMEWORK_MINDSPORE) { - id = GetRankID_MS(); + id = GetMSRankID(); } return id; diff --git a/debug/accuracy_tools/msprobe/ccsrc/base/Environment.hpp b/debug/accuracy_tools/msprobe/ccsrc/base/Environment.h similarity index 100% rename from debug/accuracy_tools/msprobe/ccsrc/base/Environment.hpp rename to debug/accuracy_tools/msprobe/ccsrc/base/Environment.h diff --git a/debug/accuracy_tools/msprobe/ccsrc/base/ErrorInfos.cpp b/debug/accuracy_tools/msprobe/ccsrc/base/ErrorInfosManager.cpp similarity index 89% rename from debug/accuracy_tools/msprobe/ccsrc/base/ErrorInfos.cpp rename to debug/accuracy_tools/msprobe/ccsrc/base/ErrorInfosManager.cpp index b07554a9fe10609ab4fa03357877b2f7630bd55e..755be22eac060c150aa9bdd508888ae2879a5d90 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/base/ErrorInfos.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/base/ErrorInfosManager.cpp @@ -22,13 +22,12 @@ #include #include -#include "utils/FileUtils.hpp" -#include "ErrorInfos.hpp" +#include "utils/FileUtils.h" +#include "ErrorInfosManager.h" namespace MindStudioDebugger { -static std::mutex errInfoMtx; -static std::ofstream logOfs; +static std::mutex g_errInfoMtx; DebuggerErrLevel ErrorInfosManager::topLevel = DebuggerErrLevel::LEVEL_NONE; DebuggerErrLevel ErrorInfosManager::threshold = DebuggerErrLevel::LEVEL_INFO; @@ -84,8 +83,8 @@ void ErrorInfosManager::LogErrorInfo(DebuggerErrLevel level, DebuggerErrno errId return; } - std::lock_guard lk(errInfoMtx); - std::ostream& output = logOfs.is_open() ? logOfs : std::cout; + std::lock_guard lk(g_errInfoMtx); + std::ostream& output = std::cout; output << "[" << ErrorLevelString[level] << "]"; if (errId != DebuggerErrno::NONE) { output << "[" << ErrnoString[errId] << "]"; @@ -101,26 +100,12 @@ void ErrorInfosManager::LogErrorInfo(DebuggerErrLevel level, DebuggerErrno errId DebuggerErrLevel ErrorInfosManager::GetTopErrLevelInDuration() { - std::lock_guard lk(errInfoMtx); + std::lock_guard lk(g_errInfoMtx); DebuggerErrLevel ret = topLevel; topLevel = DebuggerErrLevel::LEVEL_NONE; return ret; } -void ErrorInfosManager::SetLogPath(const std::string& path) -{ - std::lock_guard lk(errInfoMtx); - if (logOfs.is_open()) { - logOfs.close(); - } - - if (path.empty()) { - return; - } - - FileUtils::OpenFile(path, logOfs); -} - __attribute__((constructor)) void InitDebuggerThreshold() { const char* msprobeLogLevelEnv = getenv("MSPROBE_LOG_LEVEL"); diff --git a/debug/accuracy_tools/msprobe/ccsrc/base/ErrorInfos.hpp b/debug/accuracy_tools/msprobe/ccsrc/base/ErrorInfosManager.h similarity index 96% rename from debug/accuracy_tools/msprobe/ccsrc/base/ErrorInfos.hpp rename to debug/accuracy_tools/msprobe/ccsrc/base/ErrorInfosManager.h index 6c740a6a36cfd7692b793dfa7625789771731289..62d1a1e8902da59ebeef90e7c1fd2dd4ce188f21 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/base/ErrorInfos.hpp +++ b/debug/accuracy_tools/msprobe/ccsrc/base/ErrorInfosManager.h @@ -18,7 +18,7 @@ #include #include -#include "include/ErrorCode.hpp" +#include "include/ErrorCode.h" namespace MindStudioDebugger { @@ -35,14 +35,14 @@ class ErrorInfosManager { public: static void LogErrorInfo(DebuggerErrLevel level, DebuggerErrno errId, const std::string& info); static DebuggerErrLevel GetTopErrLevelInDuration(); - static void SetLogPath(const std::string& path); static void SetLogThreshold(DebuggerErrLevel t) { threshold = t; } private: static DebuggerErrLevel topLevel; static DebuggerErrLevel threshold; }; -inline void CleanErrorInfoCache() { +inline void CleanErrorInfoCache() +{ ErrorInfosManager::GetTopErrLevelInDuration(); } diff --git a/debug/accuracy_tools/msprobe/ccsrc/core/AclDumpDataProcessor.cpp b/debug/accuracy_tools/msprobe/ccsrc/core/AclDumpDataProcessor.cpp index 0fe3443fa1f9286fe77c710c955d543d94c4b3a4..426913da79142993586b09afd272fd66562d10df 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/core/AclDumpDataProcessor.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/core/AclDumpDataProcessor.cpp @@ -25,54 +25,61 @@ #include #include -#include "include/Macro.hpp" -#include "utils/FileUtils.hpp" -#include "utils/FileOperation.hpp" -#include "utils/DataUtils.hpp" -#include "utils/MathUtils.hpp" -#include "core/AclTensor.hpp" -#include "base/ErrorInfos.hpp" +#include "include/Macro.h" +#include "utils/FileUtils.h" +#include "utils/FileOperation.h" +#include "utils/DataUtils.h" +#include "utils/MathUtils.h" +#include "core/AclTensor.h" +#include "base/ErrorInfosManager.h" #include "proto/AclDumpMsg.pb.h" -#include "AclDumpDataProcessor.hpp" +#include "AclDumpDataProcessor.h" namespace MindStudioDebugger { namespace AclDumpMsg = toolkit::dumpdata; -constexpr size_t kDhaAtomicAddInfoSize = 128; -constexpr size_t kL2AtomicAddInfoSize = 128; -constexpr size_t kAiCoreInfoSize = 256; -constexpr size_t kDhaAtomicAddStatusSize = 256; -constexpr size_t kL2AtomicAddStatusSize = 256; -constexpr size_t kUint64Size = sizeof(uint64_t); -constexpr const char* debugFileSign = "Opdebug.Node_OpDebug."; - -constexpr const char* kStatsHeaderInout = "Input/Output"; -constexpr const char* kStatsHeaderId = "Index"; -constexpr const char* kStatsHeaderDataSize = "Data Size"; -constexpr const char* kStatsHeaderDataType = "Data Type"; -constexpr const char* kStatsHeaderFormat = "Format"; -constexpr const char* kStatsHeaderShape = "Shape"; -constexpr const char* kStatsHeaderMax = "Max Value"; -constexpr const char* kStatsHeaderMin = "Min Value"; -constexpr const char* kStatsHeaderAvg = "Avg Value"; -constexpr const char* kStatsHeaderL2Norm = "L2 Norm Value"; -constexpr const char* kStatsHeaderMD5 = "MD5 Value"; -constexpr const char* kStatsHeaderNan = "Nan Count"; -constexpr const char* kStatsHeaderNegInf = "Negative Inf Count"; -constexpr const char* kStatsHeaderPosInf = "Positive Inf Count"; -constexpr const char* kRankId = "RANK_ID"; -constexpr const char* kDigitalNumbers = "0123456789"; - -static const std::map summaryOptionHeaderStrMap = { - {DebuggerSummaryOption::MAX, kStatsHeaderMax}, - {DebuggerSummaryOption::MIN, kStatsHeaderMin}, - {DebuggerSummaryOption::MEAN, kStatsHeaderAvg}, - {DebuggerSummaryOption::L2NORM, kStatsHeaderL2Norm}, - {DebuggerSummaryOption::NAN_CNT, kStatsHeaderNan}, - {DebuggerSummaryOption::NEG_INF_CNT, kStatsHeaderNegInf}, - {DebuggerSummaryOption::POS_INF_CNT, kStatsHeaderPosInf}, - {DebuggerSummaryOption::MD5, kStatsHeaderMD5}, +constexpr size_t DHA_ATOMIC_ADD_INFO_SIZE = 128; +constexpr size_t L2_ATOMIC_ADD_INFO_SIZE = 128; +constexpr size_t AICORE_INFO_SIZE = 256; +constexpr size_t DHA_ATOMIC_ADD_STATUS_SIZE = 256; +constexpr size_t L2_ATOMIC_ADD_STATUS_SIZE = 256; +constexpr size_t UINT64_SIZE = sizeof(uint64_t); +constexpr const char* DEBUG_FILE_SIGN = "Opdebug.Node_OpDebug."; + +constexpr const char* STATS_HEADER_INOUT = "Input/Output"; +constexpr const char* STATS_HEADER_ID = "Index"; +constexpr const char* STATS_HEADER_DATA_SIZE = "Data Size"; +constexpr const char* STATS_HEADER_DATA_TYPE = "Data Type"; +constexpr const char* STATS_HEADER_FORMAT = "Format"; +constexpr const char* STATS_HEADER_SHAPE = "Shape"; +constexpr const char* STATS_HEADER_MAX = "Max Value"; +constexpr const char* STATS_HEADER_MIN = "Min Value"; +constexpr const char* STATS_HEADER_AVG = "Avg Value"; +constexpr const char* STATS_HEADER_L2NORM = "l2norm"; +constexpr const char* STATS_CSV_HEADER_L2NORM = "L2Norm Value"; +constexpr const char* STATS_HEADER_MD5 = "MD5 Value"; +constexpr const char* STATS_HEADER_NAN = "Nan Count"; +constexpr const char* STATS_CSV_HEADER_NAN = "NaN Count"; +constexpr const char* STATS_HEADER_NEG_INF = "Negative Inf Count"; +constexpr const char* STATS_HEADER_POS_INF = "Positive Inf Count"; +constexpr const char* RANK_ID = "RANK_ID"; +constexpr const char* DIGITAL_NUMBERS = "0123456789"; + +static const std::map> SUMMARY_OPTION_HEADER_STR_MAP = { + {DebuggerSummaryOption::MAX, {STATS_HEADER_MAX, STATS_HEADER_MAX}}, + {DebuggerSummaryOption::MIN, {STATS_HEADER_MIN, STATS_HEADER_MIN}}, + {DebuggerSummaryOption::MEAN, {STATS_HEADER_AVG, STATS_HEADER_AVG}}, + {DebuggerSummaryOption::L2NORM, {STATS_HEADER_L2NORM, STATS_CSV_HEADER_L2NORM}}, + {DebuggerSummaryOption::NAN_CNT, {STATS_HEADER_NAN, STATS_CSV_HEADER_NAN}}, + {DebuggerSummaryOption::NEG_INF_CNT, {STATS_HEADER_NEG_INF, STATS_HEADER_NEG_INF}}, + {DebuggerSummaryOption::POS_INF_CNT, {STATS_HEADER_POS_INF, STATS_HEADER_POS_INF}}, + {DebuggerSummaryOption::MD5, {STATS_HEADER_MD5, STATS_HEADER_MD5}}, +}; + +const static std::map kDtypeTransMap = { + {AclDtype::DT_BF16, AclDtype::DT_FLOAT}, + {AclDtype::DT_INT4, AclDtype::DT_INT8}, }; class AclTensorStats { @@ -84,7 +91,7 @@ public: std::string GetCsvHeader() const; std::string GetCsvValue() const; std::string GetPath() const {return path;} - bool empty() const {return stats.empty();}; + bool Empty() const {return stats.empty();}; static AclTensorStats CalTensorSummary(const AclTensorInfo& tensor, const std::vector& opt); static AclTensorStats ParseTensorSummary(const std::string& dumpPath, const std::string& input); @@ -107,13 +114,13 @@ private: void ParseInfoFromDumpPath(const std::string& dumpPath); std::string& operator[](DebuggerSummaryOption opt) { return stats[opt]; } - static constexpr const size_t bufferLen = 1024; + static constexpr const size_t BUFFER_LEN = 1024; }; void AclTensorStats::ParseInfoFromDumpPath(const std::string& dumpPath) { std::string filename; - if (FileUtils::GetFileSuffix(filename) == "csv") { + if (FileUtils::GetFileSuffix(dumpPath) == "csv") { filename = FileUtils::GetFileBaseName(dumpPath); } else { filename = FileUtils::GetFileName(dumpPath); @@ -152,7 +159,8 @@ AclTensorStats::AclTensorStats(const AclTensorInfo& tensor, const std::map& opt) +AclTensorStats AclTensorStats::CalTensorSummary(const AclTensorInfo& tensor, + const std::vector& opt) { DEBUG_FUNC_TRACE(); std::map summary; @@ -167,10 +175,10 @@ AclTensorStats AclTensorStats::CalTensorSummary(const AclTensorInfo& tensor, con static std::map ParseTensorSummaryHeaderOrder(const std::vector& segs) { std::map ret; - for (uint32_t pos = 0; pos < segs.size(); ++pos) { + for (size_t pos = 0; pos < segs.size(); ++pos) { const std::string& opt = segs[pos]; - for (auto it = summaryOptionHeaderStrMap.begin(); it != summaryOptionHeaderStrMap.end(); ++it) { - if (opt == it->second) { + for (auto it = SUMMARY_OPTION_HEADER_STR_MAP.begin(); it != SUMMARY_OPTION_HEADER_STR_MAP.end(); ++it) { + if (opt == it->second.first) { ret[pos] = it->first; break; } @@ -181,14 +189,14 @@ static std::map ParseTensorSummaryHeaderOrder(c AclTensorStats AclTensorStats::ParseTensorSummary(const std::string& dumpPath, const std::string& input) { - constexpr const uint32_t optPosBase = 7; + constexpr const size_t optPosBase = 7; static std::map order; static uint32_t headerLen = 0; std::vector segs = FileUtils::SplitPath(input, ','); /* device计算统计量场景,各个kernel的统计项的顺序是相同的,只要计算一次即可 */ if (order.empty()) { - if (segs.size() <= optPosBase || segs[0] != kStatsHeaderInout) { + if (segs.size() <= optPosBase || segs[0] != STATS_HEADER_INOUT) { LOG_WARNING(DebuggerErrno::ERROR_INVALID_FORMAT, "Summary data miss header, some data may lose."); return AclTensorStats(); } @@ -204,7 +212,7 @@ AclTensorStats AclTensorStats::ParseTensorSummary(const std::string& dumpPath, c } /* 不重复解析header行 */ - if (segs[0] == kStatsHeaderInout) { + if (segs[0] == STATS_HEADER_INOUT) { return AclTensorStats(); } @@ -229,11 +237,11 @@ std::string AclTensorStats::GetCsvHeader() const return std::string(); } std::string ret; - ret.reserve(bufferLen); + ret.reserve(BUFFER_LEN); ret.append("Op Type,Op Name,Task ID,Stream ID,Timestamp,Input/Output,Slot,Data Size,Data Type,Format,Shape"); for (auto it = stats.begin(); it != stats.end(); it++) { ret.append(","); - ret.append(summaryOptionHeaderStrMap.at(it->first)); + ret.append(SUMMARY_OPTION_HEADER_STR_MAP.at(it->first).second); } ret.append("\n"); @@ -247,7 +255,7 @@ std::string AclTensorStats::GetCsvValue() const } std::string ret; - ret.reserve(bufferLen); + ret.reserve(BUFFER_LEN); ret.append(opType).append(",").append(opName).append(",").append(taskID).append(",").append(streamID).append(",") \ .append(timestamp).append(",").append(inout).append(",").append(slot).append(",") .append(dataSize) \ .append(",").append(dataType).append(",").append(format).append(",").append(shape); @@ -275,7 +283,7 @@ std::string AclDumpDataProcessor::ToString() const std::to_string(totalLen) + ")"; } -DebuggerErrno AclDumpDataProcessor::PushData(const acldumpChunk *chunk) +DebuggerErrno AclDumpDataProcessor::PushData(const AclDumpChunk *chunk) { DEBUG_FUNC_TRACE(); if (completed) { @@ -290,8 +298,15 @@ DebuggerErrno AclDumpDataProcessor::PushData(const acldumpChunk *chunk) } size_t len = chunk->bufLen; + if (len == 0) { + LOG_ERROR(DebuggerErrno::ERROR_INVALID_VALUE, ToString() + ": invalid value(cached size " + + std::to_string(totalLen) + ", receiving size " + std::to_string(len) + ")."); + errorOccurred = true; + return DebuggerErrno::ERROR_INVALID_VALUE; + } + /* 防止正负翻转 */ - if (SIZE_MAX - len < totalLen || totalLen + len > kMaxDataLen || len == 0) { + if (SIZE_MAX - len < totalLen || totalLen + len > MAX_DATA_LEN) { LOG_ERROR(DebuggerErrno::ERROR_BUFFER_OVERFLOW, ToString() + ": buffer overflow(cached size " + std::to_string(totalLen) + ", receiving size " + std::to_string(len) + ")."); errorOccurred = true; @@ -306,7 +321,10 @@ DebuggerErrno AclDumpDataProcessor::PushData(const acldumpChunk *chunk) return DebuggerErrno::ERROR_NO_MEMORY; } - if (memcpy(p->data(), chunk->dataBuf, len) == nullptr) { + /* vector p根据chunk->dataBuf的长度,即len,申请创建,所以无需校验空间大小 */ + try { + std::copy(chunk->dataBuf, chunk->dataBuf + len, p->begin()); + } catch (const std::exception& e) { LOG_ERROR(DebuggerErrno::ERROR_SYSCALL_FAILED, ToString() + ": Failed to copy data;"); delete p; errorOccurred = true; @@ -354,9 +372,11 @@ DebuggerErrno AclDumpDataProcessor::ConcatenateData() } size_t offset = 0; - uint8_t* msg = p->data(); while (!buffer.empty()) { - if (memcpy(msg + offset, buffer.front()->data(), buffer.front()->size()) == nullptr) { + /* vector p根据buffer里所有vector的总长度,即totalLen,申请创建,所以无需校验空间大小 */ + try { + std::copy(buffer.front()->begin(), buffer.front()->end(), p->begin() + offset); + } catch (const std::exception& e) { delete p; LOG_ERROR(DebuggerErrno::ERROR_SYSCALL_FAILED, "Data processor(" + dumpPath + "): Failed to copy."); return DebuggerErrno::ERROR_SYSCALL_FAILED; @@ -398,17 +418,17 @@ static nlohmann::json ParseOverflowInfo(const uint8_t* data) DEBUG_FUNC_TRACE(); uint32_t index = 0; nlohmann::json overflowInfo; - uint64_t modelId = DataUtils::UnpackUint64Value_Le(data); - index += kUint64Size; - uint64_t streamId = DataUtils::UnpackUint64Value_Le(data + index); - index += kUint64Size; - uint64_t taskId = DataUtils::UnpackUint64Value_Le(data + index); - index += kUint64Size; - uint64_t taskType = DataUtils::UnpackUint64Value_Le(data + index); - index += kUint64Size; - uint64_t pcStart = DataUtils::UnpackUint64Value_Le(data + index); - index += kUint64Size; - uint64_t paraBase = DataUtils::UnpackUint64Value_Le(data + index); + uint64_t modelId = DataUtils::UnpackUint64ValueLe(data); + index += UINT64_SIZE; + uint64_t streamId = DataUtils::UnpackUint64ValueLe(data + index); + index += UINT64_SIZE; + uint64_t taskId = DataUtils::UnpackUint64ValueLe(data + index); + index += UINT64_SIZE; + uint64_t taskType = DataUtils::UnpackUint64ValueLe(data + index); + index += UINT64_SIZE; + uint64_t pcStart = DataUtils::UnpackUint64ValueLe(data + index); + index += UINT64_SIZE; + uint64_t paraBase = DataUtils::UnpackUint64ValueLe(data + index); overflowInfo["model_id"] = modelId; overflowInfo["stream_id"] = streamId; @@ -424,30 +444,30 @@ static DebuggerErrno DumpOpDebugDataToDisk(const std::string& dumpPath, AclDumpM { DEBUG_FUNC_TRACE(); std::string outPath = dumpPath + ".output."; - uint32_t num = dumpData.output().size(); + uint32_t num = static_cast(dumpData.output().size()); for (uint32_t slot = 0; slot < num; slot++) { uint32_t offset = 0; // parse DHA Atomic Add info nlohmann::json dhaAtomicAddInfo = ParseOverflowInfo(data + offset); - offset += kDhaAtomicAddInfoSize; + offset += DHA_ATOMIC_ADD_INFO_SIZE; // parse L2 Atomic Add info nlohmann::json l2AtomicAddInfo = ParseOverflowInfo(data + offset); - offset += kL2AtomicAddInfoSize; + offset += L2_ATOMIC_ADD_INFO_SIZE; // parse AICore info nlohmann::json aiCoreInfo = ParseOverflowInfo(data + offset); - offset += kAiCoreInfoSize; + offset += AICORE_INFO_SIZE; // parse DHA Atomic Add status - dhaAtomicAddInfo["status"] = DataUtils::UnpackUint64Value_Le(data + offset); - offset += kDhaAtomicAddStatusSize; + dhaAtomicAddInfo["status"] = DataUtils::UnpackUint64ValueLe(data + offset); + offset += DHA_ATOMIC_ADD_STATUS_SIZE; // parse L2 Atomic Add status - l2AtomicAddInfo["status"] = DataUtils::UnpackUint64Value_Le(data + offset); - offset += kL2AtomicAddStatusSize; + l2AtomicAddInfo["status"] = DataUtils::UnpackUint64ValueLe(data + offset); + offset += L2_ATOMIC_ADD_STATUS_SIZE; // parse AICore status - uint64_t kernelCode = DataUtils::UnpackUint64Value_Le(data + offset); - offset += kUint64Size; - uint64_t blockIdx = DataUtils::UnpackUint64Value_Le(data + offset); - offset += kUint64Size; - uint64_t status = DataUtils::UnpackUint64Value_Le(data + offset); + uint64_t kernelCode = DataUtils::UnpackUint64ValueLe(data + offset); + offset += UINT64_SIZE; + uint64_t blockIdx = DataUtils::UnpackUint64ValueLe(data + offset); + offset += UINT64_SIZE; + uint64_t status = DataUtils::UnpackUint64ValueLe(data + offset); aiCoreInfo["kernel_code"] = DataUtils::U64ToHexString(kernelCode); aiCoreInfo["block_idx"] = blockIdx; aiCoreInfo["status"] = status; @@ -523,8 +543,11 @@ static std::string MappingFilePath(const std::string& originPath) return std::string(); } - DebuggerErrno ret; - FileUtils::CreateDir(dir); + DebuggerErrno ret = FileUtils::CreateDir(dir); + if (ret != DebuggerErrno::OK) { + LOG_ERROR(DebuggerErrno::ERROR, "Failed to create directory " + dir + "."); + return std::string(); + } std::ofstream ofs; constexpr const char* mapFileName = "mapping.csv"; @@ -563,7 +586,8 @@ static DebuggerErrno StandardizedDumpPath(std::string& originPath) return DebuggerErrno::OK; } -static std::string GenDataPath(const std::string& path) { +static std::string GenDataPath(const std::string& path) +{ LOG_DEBUG("Original acl data path is " + path); std::string outputPath = DebuggerConfig::GetInstance().GetOutputPath(); std::string dataPath; @@ -585,7 +609,8 @@ static std::string GenDataPath(const std::string& path) { } /* * ACL 接口返回数据的路径格式如下 - * {dump_path}/rank_{rank_id}/{time stamp}/step_{step_id}/{time}/{device_id}/{model_name}/{model_id}/{iteration_id}/{data name} + * {dump_path}/rank_{rank_id}/{time stamp}/step_{step_id}/{time} + /{device_id}/{model_name}/{model_id}/{iteration_id}/{data name} * items[0] 表示 rank_{rank_id} * items[1] 表示 {time stamp} * items[2] 表示 step_{step_id} @@ -603,7 +628,7 @@ static std::string GenDataPath(const std::string& path) { inline std::string GetTensorInfoSuffix(AclTensorInfo& tensor) { return "." + tensor.inout + "." + std::to_string(tensor.slot) + - "." + DataUtils::GetFormatString(tensor.hostFmt) + "." + DataUtils::GetDTypeString(tensor.dtype); + "." + DataUtils::GetFormatString(tensor.hostFmt) + "." + DataUtils::GetDTypeString(tensor.oriDtype); } static DebuggerErrno DumpOneAclTensorFmtBin(AclTensorInfo& tensor) @@ -640,17 +665,20 @@ static DebuggerErrno DumpOneAclTensorFmtNpy(AclTensorInfo& tensor) return DebuggerErrno::OK; } - if (tensor.dtype == AclDtype::DT_BF16) { - ret = AclTensor::TransDtype(tensor, AclDtype::DT_FLOAT); + auto it = kDtypeTransMap.find(tensor.dtype); + if (it != kDtypeTransMap.end()) { + AclDtype dstDtype = it->second; + ret = AclTensor::TransDtype(tensor, dstDtype); if (ret != DebuggerErrno::OK) { - LOG_ERROR(ret, tensor + ": Failed to transform dtype from bf16 to fp32."); + LOG_ERROR(ret, tensor + ": Failed to transform dtype from " + + DataUtils::GetDTypeString(it->first) + " to " + + DataUtils::GetDTypeString(it->second)+ "."); return ret; } } // dump_path: dump_dir/op_type.op_name.task_id.stream_id.timestamp std::string dumpPathSlot = tensor.dumpPath + GetTensorInfoSuffix(tensor) + "." + NPY_SUFFIX; - if (StandardizedDumpPath(dumpPathSlot) != DebuggerErrno::OK) { LOG_ERROR(DebuggerErrno::ERROR, "Failed to standardize path " + dumpPathSlot + "."); return DebuggerErrno::ERROR; @@ -676,7 +704,7 @@ static DebuggerErrno DumpOneAclTensorFmtNpy(AclTensorInfo& tensor) static DebuggerErrno WriteOneTensorStatToDisk(const AclTensorStats& stat) { DEBUG_FUNC_TRACE(); - if (stat.empty()) { + if (stat.Empty()) { return DebuggerErrno::OK; } @@ -684,7 +712,7 @@ static DebuggerErrno WriteOneTensorStatToDisk(const AclTensorStats& stat) /* 此处防止多进程间竞争,使用文件锁,故使用C风格接口 */ uint32_t retry = 100; uint32_t interval = 10; - if (FileUtils::IsPathExist(dumpfile) && !FileUtils::IsRegularFile(dumpfile)) { + if (FileUtils::CheckFileBeforeCreateOrWrite(dumpfile, true) != DebuggerErrno::OK) { LOG_ERROR(DebuggerErrno::ERROR_FILE_ALREADY_EXISTS, "File " + dumpfile + " exists and has invalid format."); return DebuggerErrno::ERROR_FILE_ALREADY_EXISTS; } @@ -703,8 +731,9 @@ static DebuggerErrno WriteOneTensorStatToDisk(const AclTensorStats& stat) std::this_thread::sleep_for(std::chrono::milliseconds(interval)); } - if (i >= retry) { + if (i == retry) { LOG_ERROR(DebuggerErrno::ERROR_SYSCALL_FAILED, "Failed to occupy file " + dumpfile); + close(fd); return DebuggerErrno::ERROR_SYSCALL_FAILED; } @@ -736,7 +765,9 @@ static DebuggerErrno DumpOneAclTensor(AclTensorInfo& tensor, std::vector #include #include -#include "include/ErrorCode.hpp" -#include "base/DebuggerConfig.hpp" -#include "third_party/ACL/AclApi.hpp" +#include "include/ErrorCode.h" +#include "base/DebuggerConfig.h" +#include "third_party/ACL/AclApi.h" namespace MindStudioDebugger { -constexpr size_t kMaxDataLen = 4ULL * 1024 * 1024 * 1024; +constexpr size_t MAX_DATA_LEN = 4ULL * 1024 * 1024 * 1024; class AclDumpDataProcessor { public: - AclDumpDataProcessor(const std::string& path, const std::vector& opts) : - dumpPath{path}, hostAnalysisOpts{opts} {}; + AclDumpDataProcessor(const std::string& path, const std::vector& opts) + : dumpPath{path}, hostAnalysisOpts{opts} {}; ~AclDumpDataProcessor(); bool IsCompleted() const {return completed;} bool ErrorOccurred() const {return errorOccurred;} - DebuggerErrno PushData(const acldumpChunk *chunk); + DebuggerErrno PushData(const AclDumpChunk *chunk); DebuggerErrno DumpToDisk(); std::string ToString() const; @@ -57,3 +58,5 @@ private: } +#endif + diff --git a/debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.cpp b/debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.cpp index 80769d7fc5fbc9d36115a544e05dd00f2a7541c3..e015386784feec471a0cac538f8a685ecf33f44e 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.cpp @@ -19,51 +19,51 @@ #include #include -#include "include/Macro.hpp" -#include "utils/FileUtils.hpp" -#include "utils/FileOperation.hpp" -#include "third_party/ACL/AclApi.hpp" -#include "base/Environment.hpp" -#include "base/ErrorInfos.hpp" -#include "AclDumper.hpp" +#include "include/Macro.h" +#include "utils/FileUtils.h" +#include "utils/FileOperation.h" +#include "third_party/ACL/AclApi.h" +#include "base/Environment.h" +#include "base/ErrorInfosManager.h" +#include "AclDumper.h" namespace MindStudioDebugger { -constexpr const char* kAclDumpScene = "dump_scene"; -constexpr const char* kSceneNormal = "normal"; -constexpr const char* kSceneException ="lite_exception"; +constexpr const char* ACL_DUMP_SCENE = "dump_scene"; +constexpr const char* SCENE_NORMAL = "normal"; +constexpr const char* SCENE_EXCEPTION = "lite_exception"; -constexpr const char* kAclDumpPath = "dump_path"; -constexpr const char* kAclDumpStep = "dump_step"; +constexpr const char* ACL_DUMP_PATH = "dump_path"; +constexpr const char* ACL_DUMP_STEP = "dump_step"; -constexpr const char* kAclDumpList = "dump_list"; -constexpr const char* kAclDumpLayer = "layer"; -constexpr const char* kAclDumpModel = "model_name"; +constexpr const char* ACL_DUMP_LIST = "dump_list"; +constexpr const char* ACL_DUMP_LAYER = "layer"; +constexpr const char* ACL_DUMP_MODEL_NAME = "model_name"; -constexpr const char* kAclDumpMode = "dump_mode"; -constexpr const char* kAclModeInput = "input"; -constexpr const char* kAclModeOutput = "output"; -constexpr const char* kAclModeAll = "all"; +constexpr const char* ACL_DUMP_MODE = "dump_mode"; +constexpr const char* ACL_MODE_INPUT = "input"; +constexpr const char* ACL_MODE_OUTPUT = "output"; +constexpr const char* ACL_MODE_ALL = "all"; -constexpr const char* kAclDumpOpSwitch = "dump_op_switch"; -constexpr const char* kAclDumpDebug = "dump_debug"; -constexpr const char* kAclSwitchOn = "on"; -constexpr const char* kAclSwitchOff = "off"; +constexpr const char* DUMP_OP_SWITCH = "dump_op_switch"; +constexpr const char* ACL_DUMP_DEBUG = "dump_debug"; +constexpr const char* ACL_SWITCH_ON = "on"; +constexpr const char* ACL_SWITCH_OFF = "off"; -constexpr const char* kAclDumpData = "dump_data"; -constexpr const char* kAclDumpTensor = "tensor"; -constexpr const char* kAclDumpStats = "stats"; +constexpr const char* ACL_DUMP_DATA = "dump_data"; +constexpr const char* ACL_DUMP_TENSOR = "tensor"; +constexpr const char* ACL_DUMP_STATS = "stats"; -constexpr const char* kAclDumpStatsOpt = "dump_stats"; -constexpr const char* kAclDumpStatsMax = "Max"; -constexpr const char* kAclDumpStatsMin = "Min"; -constexpr const char* kAclDumpStatsAvg = "Avg"; -constexpr const char* kAclDumpStatsNorn = "L2norm"; -constexpr const char* kAclDumpStatsNan = "Nan"; -constexpr const char* kAclDumpStatsNegInf = "Negative Inf"; -constexpr const char* kAclDumpStatsPosInf = "Positive Inf"; +constexpr const char* ACL_DUMP_STATS_OPT = "dump_stats"; +constexpr const char* ACL_DUMP_STATS_MAX = "Max"; +constexpr const char* ACL_DUMP_STATS_MIN = "Min"; +constexpr const char* ACL_DUMP_STATS_AVG = "Avg"; +constexpr const char* ACL_DUMP_STATS_NORM = "L2norm"; +constexpr const char* ACL_DUMP_STATS_NAN = "Nan"; +constexpr const char* ACL_DUMP_STATS_NEG_INF = "Negative Inf"; +constexpr const char* ACL_DUMP_STATS_POS_INF = "Positive Inf"; -constexpr const size_t kProcessorNumMax = 100; +constexpr const size_t PROCESSOR_NUM_MAX = 100; inline std::string GenAclJsonPath(const std::string& dumpPath, uint32_t rank) { @@ -74,14 +74,14 @@ inline std::string GenAclJsonPath(const std::string& dumpPath, uint32_t rank) static std::string GenDumpInoutString(DebuggerDataInOut mode) { static std::map dumpModeMap = { - {DebuggerDataInOut::INOUT_INPUT, kAclModeInput}, - {DebuggerDataInOut::INOUT_OUTPUT, kAclModeOutput}, - {DebuggerDataInOut::INOUT_BOTH, kAclModeAll}, + {DebuggerDataInOut::INOUT_INPUT, ACL_MODE_INPUT}, + {DebuggerDataInOut::INOUT_OUTPUT, ACL_MODE_OUTPUT}, + {DebuggerDataInOut::INOUT_BOTH, ACL_MODE_ALL}, }; auto it = dumpModeMap.find(mode); if (it == dumpModeMap.end()) { - return kAclModeAll; + return ACL_MODE_ALL; } else { return it->second; } @@ -90,13 +90,13 @@ static std::string GenDumpInoutString(DebuggerDataInOut mode) static std::vector GenStatsOptions(const std::vector& options) { static std::map summaryOptMap = { - {DebuggerSummaryOption::MAX, kAclDumpStatsMax}, - {DebuggerSummaryOption::MIN, kAclDumpStatsMin}, - {DebuggerSummaryOption::MEAN, kAclDumpStatsAvg}, - {DebuggerSummaryOption::L2NORM, kAclDumpStatsNorn}, - {DebuggerSummaryOption::NAN_CNT, kAclDumpStatsNan}, - {DebuggerSummaryOption::NEG_INF_CNT, kAclDumpStatsNegInf}, - {DebuggerSummaryOption::POS_INF_CNT, kAclDumpStatsPosInf}, + {DebuggerSummaryOption::MAX, ACL_DUMP_STATS_MAX}, + {DebuggerSummaryOption::MIN, ACL_DUMP_STATS_MIN}, + {DebuggerSummaryOption::MEAN, ACL_DUMP_STATS_AVG}, + {DebuggerSummaryOption::L2NORM, ACL_DUMP_STATS_NORM}, + {DebuggerSummaryOption::NAN_CNT, ACL_DUMP_STATS_NAN}, + {DebuggerSummaryOption::NEG_INF_CNT, ACL_DUMP_STATS_NEG_INF}, + {DebuggerSummaryOption::POS_INF_CNT, ACL_DUMP_STATS_POS_INF}, }; std::vector output; @@ -151,6 +151,26 @@ bool AclDumper::IsCfgEnableAclDumper() ELE_IN_VECTOR(tasks, DebuggerTaskType::TASK_OVERFLOW_CHECK)); } +bool AclDumper::IsOverflowCompleted() +{ + return overflowNums != -1 && realOverflowNums > overflowNums; +} + +void AclDumper::CountOverflowNumbers(const AclDumpChunk* chunk) +{ + if (IsOverflowCompleted() || !isOverflowDump || !chunk->isLastChunk) { + return; + } + const std::string fileName = chunk->fileName; + auto separator = fileName.rfind("/"); + auto fileBaseName = fileName.substr(separator + 1); + if (fileBaseName.rfind("Opdebug.Node_OpDebug.") == 0) { + // count according to the first file: Node_OpDebug + realOverflowNums++; + } + return; +} + std::string AclDumper::GetDumpPath(uint32_t curStep) const { if (!initialized || foreDumpPath.empty()) { @@ -174,19 +194,19 @@ DebuggerErrno AclDumper::AclDumpGenTensorJson(std::shared_ptrinout); - aclDumpJson[kAclDumpData] = kAclDumpTensor; - aclDumpJson[kAclDumpList] = nlohmann::json::array(); - aclDumpJson[kAclDumpOpSwitch] = kAclSwitchOn; + aclDumpJson[ACL_DUMP_PATH] = fullDumpPath; + aclDumpJson[ACL_DUMP_MODE] = GenDumpInoutString(dumpTensorCfg->inout); + aclDumpJson[ACL_DUMP_DATA] = ACL_DUMP_TENSOR; + aclDumpJson[ACL_DUMP_LIST] = nlohmann::json::array(); + aclDumpJson[DUMP_OP_SWITCH] = ACL_SWITCH_ON; if (!needDump) { /* 这里沿用mindspore框架的方案,用一个大数0x7FFFFFFF表示不需要dump;这个方案非常奇怪,后续可以看下能否优化 */ - aclDumpJson[kAclDumpStep] = std::to_string(INT_MAX); + aclDumpJson[ACL_DUMP_STEP] = std::to_string(INT_MAX); } else { std::vector kernelsList = dumpTensorCfg->matcher.GenRealKernelList(kernels); if (!kernelsList.empty()) { - aclDumpJson[kAclDumpList].push_back({{kAclDumpLayer, kernelsList}}); + aclDumpJson[ACL_DUMP_LIST].push_back({{ACL_DUMP_LAYER, kernelsList}}); } } @@ -210,25 +230,25 @@ DebuggerErrno AclDumper::AclDumpGenStatJson(std::shared_ptr fullDumpPath = dumpPath; } - aclDumpJson[kAclDumpPath] = fullDumpPath; - aclDumpJson[kAclDumpMode] = GenDumpInoutString(statisticsCfg->inout); - aclDumpJson[kAclDumpList] = nlohmann::json::array(); - aclDumpJson[kAclDumpOpSwitch] = kAclSwitchOn; + aclDumpJson[ACL_DUMP_PATH] = fullDumpPath; + aclDumpJson[ACL_DUMP_MODE] = GenDumpInoutString(statisticsCfg->inout); + aclDumpJson[ACL_DUMP_LIST] = nlohmann::json::array(); + aclDumpJson[DUMP_OP_SWITCH] = ACL_SWITCH_ON; /* 如果需要host侧分析,下给acl的任务还是dump tensor,然后在host侧转成统计量 */ if (!hostAnalysisOpt.empty()) { - aclDumpJson[kAclDumpData] = kAclDumpTensor; + aclDumpJson[ACL_DUMP_DATA] = ACL_DUMP_TENSOR; } else { - aclDumpJson[kAclDumpData] = kAclDumpStats; - aclDumpJson[kAclDumpStatsOpt] = GenStatsOptions(statisticsCfg->summaryOption); + aclDumpJson[ACL_DUMP_DATA] = ACL_DUMP_STATS; + aclDumpJson[ACL_DUMP_STATS_OPT] = GenStatsOptions(statisticsCfg->summaryOption); } if (!needDump) { - aclDumpJson[kAclDumpStep] = std::to_string(INT_MAX); + aclDumpJson[ACL_DUMP_STEP] = std::to_string(INT_MAX); } else { std::vector kernelsList = statisticsCfg->matcher.GenRealKernelList(kernels); - if (!kernelsList.empty()){ - aclDumpJson[kAclDumpList].push_back({{kAclDumpLayer, kernelsList}}); + if (!kernelsList.empty()) { + aclDumpJson[ACL_DUMP_LIST].push_back({{ACL_DUMP_LAYER, kernelsList}}); } } @@ -257,10 +277,10 @@ DebuggerErrno AclDumper::AclDumpGenOverflowJson(std::shared_ptrfileName); auto it = dataProcessors.find(dumpPath); if (it == dataProcessors.end()) { - if (dataProcessors.size() > kProcessorNumMax) { + if (dataProcessors.size() > PROCESSOR_NUM_MAX) { LOG_ERROR(DebuggerErrno::ERROR_BUFFER_OVERFLOW, "The number of processors has reached the upper limit."); return; } @@ -404,7 +429,7 @@ void AclDumper::SetDump(uint32_t rank, uint32_t curStep, ExtArgs& args) if (!initialized) { ret = Initialize(); - if(ret != DebuggerErrno::OK) { + if (ret != DebuggerErrno::OK) { LOG_ERROR(ret, "AclDumper initialization failed."); return; } @@ -424,6 +449,8 @@ void AclDumper::SetDump(uint32_t rank, uint32_t curStep, ExtArgs& args) ret = AclDumpGenStatJson(statisticsCfg, rank, curStep, kernels); } else if (overflowCheckCfg != nullptr) { ret = AclDumpGenOverflowJson(overflowCheckCfg, rank, curStep); + overflowNums = overflowCheckCfg->overflowNums; + isOverflowDump = true; } if (ret != DebuggerErrno::OK) { @@ -431,8 +458,7 @@ void AclDumper::SetDump(uint32_t rank, uint32_t curStep, ExtArgs& args) return; } - aclError aclRet; - aclRet = CALL_ACL_API(aclmdlInitDump); + aclError aclRet = CALL_ACL_API(AclmdlInitDump); if (aclRet != ACL_SUCCESS) { LOG_ERROR(DebuggerErrno::ERROR_EXTERNAL_API_ERROR, "Failed to init acldump(" + std::to_string(aclRet) + ")."); @@ -440,7 +466,7 @@ void AclDumper::SetDump(uint32_t rank, uint32_t curStep, ExtArgs& args) } const std::string& dumpPath = DebuggerConfig::GetInstance().GetOutputPath(); - aclRet = CALL_ACL_API(aclmdlSetDump, GenAclJsonPath(dumpPath, rank).c_str()); + aclRet = CALL_ACL_API(AclmdlSetDump, GenAclJsonPath(dumpPath, rank).c_str()); if (aclRet != ACL_SUCCESS) { LOG_ERROR(DebuggerErrno::ERROR_EXTERNAL_API_ERROR, "Failed to enable acldump(" + std::to_string(aclRet) + ")."); @@ -458,51 +484,53 @@ void AclDumper::FinalizeDump(ExtArgs& args) return; } - CALL_ACL_API(aclrtSynchronizeDevice); - aclError aclRet = CALL_ACL_API(aclmdlFinalizeDump); + CALL_ACL_API(AclrtSynchronizeDevice); + aclError aclRet = CALL_ACL_API(AclmdlFinalizeDump); if (aclRet != ACL_SUCCESS) { LOG_ERROR(DebuggerErrno::ERROR_EXTERNAL_API_ERROR, "Failed to finalize acldump(" + std::to_string(aclRet) + ")."); - } aclDumpHasSet = false; } -void KernelInitDump() { - if (AscendCLApi::LoadAclApi() != DebuggerErrno::OK) { - return; - } +void KernelInitDump() +{ + if (AscendCLApi::LoadAclApi() != DebuggerErrno::OK) { + return; + } - DebuggerErrno ret = InitAcl(); - if (ret != DebuggerErrno::OK) { - LOG_ERROR(ret, "Failed to call InitAcl."); - return; - } - auto aclRet = CALL_ACL_API(aclmdlInitDump); - if (aclRet != ACL_SUCCESS) { + DebuggerErrno ret = InitAcl(); + if (ret != DebuggerErrno::OK) { + LOG_ERROR(ret, "Failed to call InitAcl."); + return; + } + auto aclRet = CALL_ACL_API(AclmdlInitDump); + if (aclRet != ACL_SUCCESS) { LOG_ERROR(DebuggerErrno::ERROR_EXTERNAL_API_ERROR, "Failed to init acldump(" + std::to_string(aclRet) + ")."); return; - } + } } -void KernelSetDump(const std::string &filePath) { - std::string dumpPath = FileUtils::GetAbsPath(filePath); - auto aclRet = CALL_ACL_API(aclmdlSetDump, dumpPath.c_str()); - if (aclRet != ACL_SUCCESS) { +void KernelSetDump(const std::string &filePath) +{ + std::string dumpPath = FileUtils::GetAbsPath(filePath); + auto aclRet = CALL_ACL_API(AclmdlSetDump, dumpPath.c_str()); + if (aclRet != ACL_SUCCESS) { LOG_ERROR(DebuggerErrno::ERROR_EXTERNAL_API_ERROR, "Failed to enable acldump(" + std::to_string(aclRet) + ")."); return; - } + } } -void KernelFinalizeDump() { - CALL_ACL_API(aclrtSynchronizeDevice); - auto aclRet = CALL_ACL_API(aclmdlFinalizeDump); - if (aclRet != ACL_SUCCESS) { +void KernelFinalizeDump() +{ + CALL_ACL_API(AclrtSynchronizeDevice); + auto aclRet = CALL_ACL_API(AclmdlFinalizeDump); + if (aclRet != ACL_SUCCESS) { LOG_ERROR(DebuggerErrno::ERROR_EXTERNAL_API_ERROR, "Failed to finalize acldump(" + std::to_string(aclRet) + ")."); - } + } } } \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.hpp b/debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.h similarity index 82% rename from debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.hpp rename to debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.h index dcfad5fafcabdf944e1d4b0b0a3cd77251ce047d..8407394e29e6a1af5ba5db9df91ee4812b670c74 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.hpp +++ b/debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.h @@ -21,17 +21,18 @@ #include #include -#include "include/ExtArgs.hpp" -#include "base/DebuggerConfig.hpp" -#include "AclDumpDataProcessor.hpp" +#include "include/ExtArgs.h" +#include "base/DebuggerConfig.h" +#include "AclDumpDataProcessor.h" namespace MindStudioDebugger { class AclDumper { public: - static AclDumper& GetInstance() { - static AclDumper instance_; - return instance_; + static AclDumper& GetInstance() + { + static AclDumper dumperInstance; + return dumperInstance; } static bool IsIterNeedDump(uint32_t iterId); @@ -39,7 +40,7 @@ public: void SetDump(uint32_t rank, uint32_t curStep, ExtArgs& args); void FinalizeDump(ExtArgs& args); - void OnAclDumpCallBack(const acldumpChunk* chunk, int32_t len); + void OnAclDumpCallBack(const AclDumpChunk* chunk, int32_t len); std::string GetDumpPath(uint32_t curStep) const; @@ -58,11 +59,17 @@ private: uint32_t curStep, const char** kernels); DebuggerErrno AclDumpGenOverflowJson(std::shared_ptr overflowCfg, uint32_t rank, uint32_t curStep); + void CountOverflowNumbers(const AclDumpChunk* chunk); + bool IsOverflowCompleted(); + bool initialized{false}; bool aclDumpHasSet{false}; std::string foreDumpPath; std::vector hostAnalysisOpt; std::map> dataProcessors; + bool isOverflowDump{false}; + int32_t overflowNums{1}; + int32_t realOverflowNums{0}; }; void KernelInitDump(); diff --git a/debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.cpp b/debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.cpp index 45adff4962156f87f52c17166bc3b381f07f2978..5ae66a06ad82abf7dd0bebfcdf9fe0dbba6750b5 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.cpp @@ -22,10 +22,10 @@ #include #include -#include "utils/DataUtils.hpp" -#include "utils/MathUtils.hpp" -#include "base/ErrorInfos.hpp" -#include "AclTensor.hpp" +#include "utils/DataUtils.h" +#include "utils/MathUtils.h" +#include "base/ErrorInfosManager.h" +#include "AclTensor.h" namespace MindStudioDebugger { namespace AclDumpMsg = toolkit::dumpdata; @@ -33,21 +33,21 @@ namespace AclTensor { using namespace MathUtils; -constexpr int64_t kCubeSize = 16; -constexpr int64_t kCube16 = kCubeSize; -constexpr int64_t kCube32 = 32; -constexpr int64_t kCube64 = 64; -constexpr int64_t kCubeSize_C04 = 4; - -constexpr size_t hwH = 1; -constexpr size_t hwW = 2; -constexpr size_t fnzW1 = 4; -constexpr size_t fnzH1 = 3; -constexpr size_t fnzH0 = 2; -constexpr size_t fnzW0 = 1; -constexpr size_t fzN0 = 1; -constexpr size_t fzNi = 2; -constexpr size_t fzC0 = 3; +constexpr int64_t CUBE_SIZE = 16; +constexpr int64_t CUBE_16 = CUBE_SIZE; +constexpr int64_t CUBE_32 = 32; +constexpr int64_t CUBE_64 = 64; +constexpr int64_t CUBE_SIZE_C04 = 4; + +constexpr size_t HW_H = 1; +constexpr size_t HW_W = 2; +constexpr size_t FNZ_W1 = 4; +constexpr size_t FNZ_H1 = 3; +constexpr size_t FNZ_H0 = 2; +constexpr size_t FNZ_W0 = 1; +constexpr size_t FZ_N0 = 1; +constexpr size_t FZ_NI = 2; +constexpr size_t FZ_C0 = 3; using TensorTransFunc = DebuggerErrno (*)(AclTensorInfo &); @@ -94,21 +94,20 @@ const static std::unordered_set kSupportedFormat = { AclFormat::FORMAT_DHWNC, AclFormat::FORMAT_NDC1HWC0, AclFormat::FORMAT_FRACTAL_Z_3D, - AclFormat::FORMAT_C1HWNCoC0, + AclFormat::FORMAT_C1HWNCOC0, AclFormat::FORMAT_FRACTAL_NZ, AclFormat::FORMAT_FRACTAL_ZN_LSTM, AclFormat::FORMAT_NCL, }; const static std::map, TensorTransFunc> formatTransFuncMap = { - /* {{from, to}, function} */ {{AclFormat::FORMAT_HWCN, AclFormat::FORMAT_NCHW}, nullptr}, {{AclFormat::FORMAT_NHWC, AclFormat::FORMAT_NCHW}, nullptr}, {{AclFormat::FORMAT_FRACTAL_Z, AclFormat::FORMAT_NCHW}, FRAC_Z_TO_NCHW}, {{AclFormat::FORMAT_FRACTAL_NZ, AclFormat::FORMAT_NCHW}, FRAC_NZ_TO_NCHW}, {{AclFormat::FORMAT_NC1HWC0, AclFormat::FORMAT_NCHW}, NC1HWC0_TO_NCHW}, {{AclFormat::FORMAT_NDC1HWC0, AclFormat::FORMAT_NCHW}, NDC1HWC0_TO_NCDHW}, - {{AclFormat::FORMAT_C1HWNCoC0, AclFormat::FORMAT_NCHW}, C1HWNCoC0_TO_NCHW}, + {{AclFormat::FORMAT_C1HWNCOC0, AclFormat::FORMAT_NCHW}, C1HWNCoC0_TO_NCHW}, {{AclFormat::FORMAT_NC1HWC0_C04, AclFormat::FORMAT_NCHW}, NC1HWC0_C04_TO_NCHW}, {{AclFormat::FORMAT_FRACTAL_Z_3D, AclFormat::FORMAT_NCHW}, FRAC_Z3D_TO_NCDHW}, }; @@ -164,7 +163,8 @@ const static std::unordered_map formatTrans {AclDumpMsg::OutputFormat::FORMAT_NC1HWC0_C04, AclFormat::FORMAT_NC1HWC0_C04}, {AclDumpMsg::OutputFormat::FORMAT_FRACTAL_Z_C04, AclFormat::FORMAT_FRACTAL_Z_C04}, {AclDumpMsg::OutputFormat::FORMAT_CHWN, AclFormat::FORMAT_CHWN}, - {AclDumpMsg::OutputFormat::FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, AclFormat::FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS}, + {AclDumpMsg::OutputFormat::FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, + AclFormat::FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS}, {AclDumpMsg::OutputFormat::FORMAT_HWCN, AclFormat::FORMAT_HWCN}, {AclDumpMsg::OutputFormat::FORMAT_NC1KHKWHWC0, AclFormat::FORMAT_NC1KHKWHWC0}, {AclDumpMsg::OutputFormat::FORMAT_BN_WEIGHT, AclFormat::FORMAT_BN_WEIGHT}, @@ -174,7 +174,7 @@ const static std::unordered_map formatTrans {AclDumpMsg::OutputFormat::FORMAT_HASHTABLE_LOOKUP_VALUE, AclFormat::FORMAT_HASHTABLE_LOOKUP_VALUE}, {AclDumpMsg::OutputFormat::FORMAT_HASHTABLE_LOOKUP_OUTPUT, AclFormat::FORMAT_HASHTABLE_LOOKUP_OUTPUT}, {AclDumpMsg::OutputFormat::FORMAT_HASHTABLE_LOOKUP_HITS, AclFormat::FORMAT_HASHTABLE_LOOKUP_HITS}, - {AclDumpMsg::OutputFormat::FORMAT_C1HWNCoC0, AclFormat::FORMAT_C1HWNCoC0}, + {AclDumpMsg::OutputFormat::FORMAT_C1HWNCoC0, AclFormat::FORMAT_C1HWNCOC0}, {AclDumpMsg::OutputFormat::FORMAT_MD, AclFormat::FORMAT_MD}, {AclDumpMsg::OutputFormat::FORMAT_NDHWC, AclFormat::FORMAT_NDHWC}, {AclDumpMsg::OutputFormat::FORMAT_FRACTAL_ZZ, AclFormat::FORMAT_FRACTAL_ZZ}, @@ -201,20 +201,20 @@ const static std::unordered_map formatTrans {AclDumpMsg::OutputFormat::FORMAT_C1HWC0, AclFormat::FORMAT_C1HWC0}, }; -enum kAxis4D : int { kN = 0, kC, kH, kW, kNchwDims }; +enum Axis4D : int { AXIS_N = 0, AXIS_C, AXIS_H, AXIS_W, NCHW_DIMS }; enum Axis5D : int { - N_ncdhw = 0, - C_ncdhw, - D_ncdhw, - H_ncdhw, - W_ncdhw, - kNcdhw, - N_ndc1hwc0 = 0, - D_ndc1hwc0, - C1_ndc1hwc0, - H_ndc1hwc0, - W_ndc1hwc0, - C0_ndc1hwc0 + N_NCDHW, + C_NCDHW, + D_NCDHW, + H_NCDHW, + W_NCDHW, + NCDHW, + N_NDC1HWC0, + D_NDC1HWC0, + C1_NDC1HWC0, + H_NDC1HWC0, + W_NDC1HWC0, + C0_NDC1HWC0 }; static inline AclDtype transAclDtype2MS(AclDumpMsg::OutputDataType dt) @@ -235,7 +235,8 @@ static inline AclFormat transAclFormat2MS(AclDumpMsg::OutputFormat fmt) return AclFormat::FORMAT_MAX; } -static size_t EleNumOfTensor(const AclTensorInfo& tensor, bool host = true) { +static size_t EleNumOfTensor(const AclTensorInfo& tensor, bool host = true) +{ size_t num = 1; const AclShape& shape = host ? tensor.hostShape : tensor.deviceShape; for (auto dim : shape) { @@ -244,23 +245,26 @@ static size_t EleNumOfTensor(const AclTensorInfo& tensor, bool host = true) { return 0; } - if (SIZE_MAX / dim < num) { + if (SIZE_MAX / static_cast(dim) < static_cast(num)) { throw std::out_of_range(tensor + ": Count of element over size_t."); } num *= static_cast(dim); } - return num; + return num; } -static inline size_t SizeOfAclDType(const AclTensorInfo& tensor) { +static inline size_t SizeOfAclDType(const AclTensorInfo& tensor) +{ return DataUtils::SizeOfDType(tensor.dtype); } -static inline size_t SizeOfAclDType(const AclDtype& dtype) { +static inline size_t SizeOfAclDType(const AclDtype& dtype) +{ return DataUtils::SizeOfDType(dtype); } -size_t SizeOfTensor(const AclTensorInfo& tensor, bool host) { +size_t SizeOfTensor(const AclTensorInfo& tensor, bool host) +{ size_t num = EleNumOfTensor(tensor, host); size_t eleSize = SizeOfAclDType(tensor); if (num != 0 && SIZE_MAX / num < eleSize) { @@ -269,16 +273,17 @@ size_t SizeOfTensor(const AclTensorInfo& tensor, bool host) { return num * eleSize; } -static inline int64_t GetCubeSizeByType(const AclDtype& dtype) { +static inline int64_t GetCubeSizeByType(const AclDtype& dtype) +{ if (dtype == AclDtype::DT_UINT8 || dtype == AclDtype::DT_INT8) { - return kCube32; + return CUBE_32; } if (dtype == AclDtype::DT_INT4) { - return kCube64; + return CUBE_64; } - return kCube16; + return CUBE_16; } static inline void AssertDim(const AclShape& shape, size_t dim) @@ -291,7 +296,14 @@ static inline void AssertDim(const AclShape& shape, size_t dim) static inline void AssertConsis(const AclTensorInfo& tensor) { - if (EleNumOfTensor(tensor, false) * SizeOfAclDType(tensor) != tensor.dataSize) { + size_t tensorSize = EleNumOfTensor(tensor, false) * SizeOfAclDType(tensor); + // Processing dtype whose size < 1 + // The ele num of quantization type(qint4*2) in MindSpore must be even. + size_t int4_size_factor = 2; + if (tensor.dtype == AclDtype::DT_INT4) { + tensorSize = EleNumOfTensor(tensor, false) / int4_size_factor; + } + if (tensorSize != tensor.dataSize) { throw std::runtime_error(tensor + ": The internal data of Tensor is inconsistent."); } } @@ -321,8 +333,8 @@ AclTensorInfo ParseAttrsFromDumpData(const std::string& dumpPath, const uint8_t* for (auto d : tensor.original_shape().dim()) { if (d > INT64_MAX) { LOG_WARNING(DebuggerErrno::ERROR_VALUE_OVERFLOW, - "The value(" + std::to_string(d) + ") exceeds the max value of int64_t, " + - "this maybe caused by the unfixed shape operaters."); + "The value(" + std::to_string(d) + ") exceeds the max value of int64_t, " + + "this maybe caused by the unfixed shape operaters."); hShape.clear(); break; } @@ -331,7 +343,7 @@ AclTensorInfo ParseAttrsFromDumpData(const std::string& dumpPath, const uint8_t* // convert format to host format. It can be either NCHW or ND (non 4-dimemsions). AclFormat hFmt; - if (hShape.size() == kDim4) { + if (hShape.size() == DIM_4) { hFmt = AclFormat::FORMAT_NCHW; } else if (hShape.empty()) { hFmt = dFmt; @@ -343,7 +355,8 @@ AclTensorInfo ParseAttrsFromDumpData(const std::string& dumpPath, const uint8_t* } int32_t subFormat = tensor.sub_format(); - return AclTensorInfo{dumpPath, data, dtype, dFmt, hFmt, dShape, hShape, dataSize, subFormat, io, slot, dumpOriginData}; + return AclTensorInfo{dumpPath, data, dtype, dtype, dFmt, hFmt, + dShape, hShape, dataSize, subFormat, io, slot, dumpOriginData}; } template AclTensorInfo ParseAttrsFromDumpData( @@ -360,14 +373,14 @@ static inline void AllocTensorTransBuf(AclTensorInfo& tensor) static DebuggerErrno FRAC_Z_TO_NCHW_WITH_GROUPS(AclTensorInfo& tensor) { - AssertDim(tensor.hostShape, kDim4); + AssertDim(tensor.hostShape, DIM_4); AssertConsis(tensor); AllocTensorTransBuf(tensor); - auto nDim = tensor.hostShape[kN]; - auto cDim = tensor.hostShape[kC]; - auto hDim = tensor.hostShape[kH]; - auto wDim = tensor.hostShape[kW]; + auto nDim = tensor.hostShape[AXIS_N]; + auto cDim = tensor.hostShape[AXIS_C]; + auto hDim = tensor.hostShape[AXIS_H]; + auto wDim = tensor.hostShape[AXIS_W]; auto groups = tensor.subFormat; auto cinOri = cDim; auto coutOri = nDim / groups; @@ -378,7 +391,7 @@ static DebuggerErrno FRAC_Z_TO_NCHW_WITH_GROUPS(AclTensorInfo& tensor) } auto cubeK = GetCubeSizeByType(tensor.dtype); - auto eMult = std::min(Lcm(Lcm(cinOri, cubeK) / cinOri, Lcm(coutOri, kCubeSize) / cinOri), + auto eMult = std::min(Lcm(Lcm(cinOri, cubeK) / cinOri, Lcm(coutOri, CUBE_SIZE) / cinOri), static_cast(groups)); if (eMult == 0) { LOG_WARNING(DebuggerErrno::ERROR_INVALID_VALUE, @@ -387,11 +400,12 @@ static DebuggerErrno FRAC_Z_TO_NCHW_WITH_GROUPS(AclTensorInfo& tensor) } auto cinOpt = AlignCeil(eMult * cinOri, cubeK); - auto coutOpt = AlignCeil(eMult * coutOri, kCubeSize); + auto coutOpt = AlignCeil(eMult * coutOri, CUBE_SIZE); auto c1Dim = cinOpt / cubeK; const uint8_t* src = tensor.aclData; - uint8_t* dst = tensor.transBuf.data(); - auto dtypeSize = SizeOfAclDType(tensor); + auto dst = tensor.transBuf.begin(); + int64_t dtypeSize = static_cast(SizeOfAclDType(tensor)); + int64_t dstSize = static_cast(tensor.transBuf.size()); for (int64_t g = 0; g < groups; ++g) { for (int64_t c = 0; c < cDim; ++c) { @@ -407,8 +421,13 @@ static DebuggerErrno FRAC_Z_TO_NCHW_WITH_GROUPS(AclTensorInfo& tensor) (dstCi / cubeK) * hDim * wDim * coutOpt * cubeK + h * wDim * coutOpt * cubeK + w * coutOpt * cubeK + dstCo * cubeK + temporary; int64_t hstIdx = srcCo * cDim * hDim * wDim + c * hDim * wDim + h * wDim + w; - /* 此处由偏移计算逻辑保障不会越界读写 */ - std::memcpy(dst + hstIdx * dtypeSize, src + devIdx * dtypeSize, dtypeSize); + int64_t devOffset = devIdx * dtypeSize; + int64_t hstOffset = hstIdx * dtypeSize; + if (hstOffset + dtypeSize > dstSize) { + return DebuggerErrno::ERROR_INVALID_VALUE; + } + std::copy(src + devOffset, src + devOffset + dtypeSize, + dst + hstOffset); } } } @@ -423,17 +442,17 @@ static DebuggerErrno FRAC_Z_TO_NCHW(AclTensorInfo& tensor) return FRAC_Z_TO_NCHW_WITH_GROUPS(tensor); } - AssertDim(tensor.hostShape, kDim4); + AssertDim(tensor.hostShape, DIM_4); AssertConsis(tensor); AllocTensorTransBuf(tensor); - auto n0 = tensor.deviceShape.at(fzN0); - auto ni = tensor.deviceShape.at(fzNi); - auto c0 = tensor.deviceShape.at(fzC0); - auto n = tensor.hostShape[kN]; - auto c = tensor.hostShape[kC]; - auto h = tensor.hostShape[kH]; - auto w = tensor.hostShape[kW]; + auto n0 = tensor.deviceShape.at(FZ_N0); + auto ni = tensor.deviceShape.at(FZ_NI); + auto c0 = tensor.deviceShape.at(FZ_C0); + auto n = tensor.hostShape[AXIS_N]; + auto c = tensor.hostShape[AXIS_C]; + auto h = tensor.hostShape[AXIS_H]; + auto w = tensor.hostShape[AXIS_W]; auto nc = ni * n0; auto ncc0 = nc * c0; auto wncc0 = w * ncc0; @@ -446,8 +465,9 @@ static DebuggerErrno FRAC_Z_TO_NCHW(AclTensorInfo& tensor) } const uint8_t* src = tensor.aclData; - uint8_t* dst = tensor.transBuf.data(); - auto dtypeSize = SizeOfAclDType(tensor); + auto dst = tensor.transBuf.begin(); + int64_t dtypeSize = static_cast(SizeOfAclDType(tensor)); + int64_t dstSize = static_cast(tensor.transBuf.size()); for (int64_t nIdx = 0; nIdx < n; nIdx++) { int64_t nHeadAddr = nIdx * chw; for (int64_t cIdx = 0; cIdx < c; cIdx++) { @@ -460,8 +480,13 @@ static DebuggerErrno FRAC_Z_TO_NCHW(AclTensorInfo& tensor) auto c0Idx = cIdx % c0; auto ncIdx = nIdx; auto srcIdx = c1Idx * hwncc0 + hIdx * wncc0 + wIdx * ncc0 + ncIdx * c0 + c0Idx; - /* 此处由偏移计算逻辑保障不会越界读写 */ - std::memcpy(dst + dstIdx * dtypeSize, src + srcIdx * dtypeSize, dtypeSize); + auto dstOffset = dstIdx * dtypeSize; + auto srcOffset = srcIdx * dtypeSize; + if (dstOffset + dtypeSize > dstSize) { + return DebuggerErrno::ERROR_INVALID_VALUE; + } + std::copy(src + srcOffset, src + srcOffset + dtypeSize, + dst + dstOffset); } } } @@ -471,7 +496,7 @@ static DebuggerErrno FRAC_Z_TO_NCHW(AclTensorInfo& tensor) static void TransShapeToHwNz(const AclShape &hostShape, AclShape& hwShape) { - if (hostShape.size() == kDim1) { + if (hostShape.size() == DIM_1) { hwShape.push_back(1); hwShape.push_back(1); hwShape.push_back(hostShape[0]); @@ -479,12 +504,12 @@ static void TransShapeToHwNz(const AclShape &hostShape, AclShape& hwShape) } auto size = hostShape.size(); int64_t times = 1; - for (size_t i = 0; i != size - kDim2; i++) { + for (size_t i = 0; i != size - DIM_2; i++) { times *= hostShape[i]; } hwShape.push_back(times); - hwShape.push_back(hostShape[size - kDim2]); - hwShape.push_back(hostShape[size - kDim1]); + hwShape.push_back(hostShape[size - DIM_2]); + hwShape.push_back(hostShape[size - DIM_1]); } static DebuggerErrno FRAC_NZ_TO_NCHW(AclTensorInfo& tensor) @@ -495,27 +520,32 @@ static DebuggerErrno FRAC_NZ_TO_NCHW(AclTensorInfo& tensor) AclShape hwShape; TransShapeToHwNz(tensor.hostShape, hwShape); auto times = hwShape.at(0); - auto h = hwShape.at(hwH); - auto w = hwShape.at(hwW); + auto h = hwShape.at(HW_H); + auto w = hwShape.at(HW_W); auto hw = h * w; auto shapeSize = tensor.deviceShape.size(); - if (shapeSize < kDim4) { + if (shapeSize < DIM_4) { LOG_WARNING(DebuggerErrno::ERROR_INVALID_VALUE, tensor + ": Invalid shape size."); return DebuggerErrno::ERROR_INVALID_VALUE; } - auto w1 = tensor.deviceShape[shapeSize - fnzW1]; - auto h1 = tensor.deviceShape[shapeSize - fnzH1]; - auto h0 = tensor.deviceShape[shapeSize - fnzH0]; - auto w0 = tensor.deviceShape[shapeSize - fnzW0]; + auto w1 = tensor.deviceShape[shapeSize - FNZ_W1]; + auto h1 = tensor.deviceShape[shapeSize - FNZ_H1]; + auto h0 = tensor.deviceShape[shapeSize - FNZ_H0]; + auto w0 = tensor.deviceShape[shapeSize - FNZ_W0]; auto h1h0w0 = h1 * h0 * w0; auto w1h1h0w0 = w1 * h1h0w0; + if (w0 == 0) { + LOG_WARNING(DebuggerErrno::ERROR_INVALID_VALUE, tensor + ": Invalid shape size."); + return DebuggerErrno::ERROR_INVALID_VALUE; + } auto numW1 = w / w0; const uint8_t* src = tensor.aclData; - uint8_t* dst = tensor.transBuf.data(); - auto dtypeSize = SizeOfAclDType(tensor); + auto dst = tensor.transBuf.begin(); + int64_t dtypeSize = static_cast(SizeOfAclDType(tensor)); + int64_t dstSize = static_cast(tensor.transBuf.size()); for (int64_t timesIdx = 0; timesIdx < times; timesIdx++) { auto timesHead = timesIdx * w1h1h0w0; @@ -527,8 +557,13 @@ static DebuggerErrno FRAC_NZ_TO_NCHW(AclTensorInfo& tensor) for (int64_t i = 0; i < w0; ++i) { int64_t srcIdx = h1h0Head + w1Idx * h1h0w0 + i; int64_t dstIdx = srcHHead + w1Idx * w0 + i; - /* 此处由偏移计算逻辑保障不会越界读写 */ - std::memcpy(dst + dstIdx * dtypeSize, src + srcIdx * dtypeSize, dtypeSize); + int64_t dstOffset = dstIdx * dtypeSize; + int64_t srcOffset = srcIdx * dtypeSize; + if (dstOffset + dtypeSize > dstSize) { + return DebuggerErrno::ERROR_INVALID_VALUE; + } + std::copy(src + srcOffset, src + srcOffset + dtypeSize, + dst + dstOffset); } } auto w1Head = numW1 * w0; @@ -536,8 +571,12 @@ static DebuggerErrno FRAC_NZ_TO_NCHW(AclTensorInfo& tensor) auto srcWIdx = w1Head + w0Idx; int64_t srcIdx = h1h0Head + numW1 * h1h0w0 + w0Idx; int64_t dstIdx = srcHHead + srcWIdx; - /* 此处由偏移计算逻辑保障不会越界读写 */ - std::memcpy(dst + dstIdx * dtypeSize, src + srcIdx * dtypeSize, dtypeSize); + int64_t dstOffset = dstIdx * dtypeSize; + int64_t srcOffset = srcIdx * dtypeSize; + if (dstOffset + dtypeSize > dstSize) { + return DebuggerErrno::ERROR_INVALID_VALUE; + } + std::copy(src + srcOffset, src + srcOffset + dtypeSize, dst + dstOffset); } } } @@ -546,16 +585,20 @@ static DebuggerErrno FRAC_NZ_TO_NCHW(AclTensorInfo& tensor) static DebuggerErrno NC1HWC0_TO_NCHW(AclTensorInfo& tensor) { - AssertDim(tensor.hostShape, kDim4); + AssertDim(tensor.hostShape, DIM_4); AssertConsis(tensor); AllocTensorTransBuf(tensor); - auto n = tensor.hostShape[kN]; - auto c = tensor.hostShape[kC]; - auto h = tensor.hostShape[kH]; - auto w = tensor.hostShape[kW]; - auto c1 = tensor.deviceShape[kDim1]; - auto c0 = tensor.deviceShape[kDim4]; + auto n = tensor.hostShape[AXIS_N]; + auto c = tensor.hostShape[AXIS_C]; + auto h = tensor.hostShape[AXIS_H]; + auto w = tensor.hostShape[AXIS_W]; + auto c1 = tensor.deviceShape[DIM_1]; + auto c0 = tensor.deviceShape[DIM_4]; + if (c0 == 0) { + LOG_WARNING(DebuggerErrno::ERROR_INVALID_VALUE, tensor + ": Invalid shape size."); + return DebuggerErrno::ERROR_INVALID_VALUE; + } auto hw = h * w; auto chw = c * hw; @@ -564,8 +607,9 @@ static DebuggerErrno NC1HWC0_TO_NCHW(AclTensorInfo& tensor) auto c1hwc0 = c1 * hwc0; const uint8_t* src = tensor.aclData; - uint8_t* dst = tensor.transBuf.data(); - auto dtypeSize = SizeOfAclDType(tensor); + auto dst = tensor.transBuf.begin(); + int64_t dtypeSize = static_cast(SizeOfAclDType(tensor)); + int64_t dstSize = static_cast(tensor.transBuf.size()); for (int64_t nIndex = 0; nIndex < n; nIndex++) { int64_t nHeadAddr = nIndex * chw; for (int64_t cIndex = 0; cIndex < c; cIndex++) { @@ -577,8 +621,13 @@ static DebuggerErrno NC1HWC0_TO_NCHW(AclTensorInfo& tensor) int64_t c1Index = cIndex / c0; int64_t c0Index = cIndex % c0; int64_t srcIdx = nIndex * c1hwc0 + c1Index * hwc0 + hIndex * wc0 + wIndex * c0 + c0Index; - /* 此处由偏移计算逻辑保障不会越界读写 */ - std::memcpy(dst + dstIdx * dtypeSize, src + srcIdx * dtypeSize, dtypeSize); + int64_t dstOffset = dstIdx * dtypeSize; + int64_t srcOffset = srcIdx * dtypeSize; + if (dstOffset + dtypeSize > dstSize) { + return DebuggerErrno::ERROR_INVALID_VALUE; + } + std::copy(src + srcOffset, src + srcOffset + dtypeSize, + dst + dstOffset); } } } @@ -588,17 +637,21 @@ static DebuggerErrno NC1HWC0_TO_NCHW(AclTensorInfo& tensor) static DebuggerErrno NDC1HWC0_TO_NCDHW(AclTensorInfo& tensor) { - AssertDim(tensor.hostShape, kDim5); + AssertDim(tensor.hostShape, DIM_5); AssertConsis(tensor); AllocTensorTransBuf(tensor); - auto n = tensor.hostShape[N_ncdhw]; - auto c = tensor.hostShape[C_ncdhw]; - auto d = tensor.hostShape[D_ncdhw]; - auto h = tensor.hostShape[H_ncdhw]; - auto w = tensor.hostShape[W_ncdhw]; - auto c1 = tensor.deviceShape[C1_ndc1hwc0]; - auto c0 = tensor.deviceShape[C0_ndc1hwc0]; + auto n = tensor.hostShape[N_NCDHW]; + auto c = tensor.hostShape[C_NCDHW]; + auto d = tensor.hostShape[D_NCDHW]; + auto h = tensor.hostShape[H_NCDHW]; + auto w = tensor.hostShape[W_NCDHW]; + auto c1 = tensor.deviceShape[C1_NDC1HWC0]; + auto c0 = tensor.deviceShape[C0_NDC1HWC0]; + if (c0 == 0) { + LOG_WARNING(DebuggerErrno::ERROR_INVALID_VALUE, tensor + ": Invalid shape size."); + return DebuggerErrno::ERROR_INVALID_VALUE; + } const int64_t cdhw = c * d * h * w; const int64_t dhw = d * h * w; @@ -609,8 +662,9 @@ static DebuggerErrno NDC1HWC0_TO_NCDHW(AclTensorInfo& tensor) const int64_t wc0 = w * c0; const uint8_t* src = tensor.aclData; - uint8_t* dst = tensor.transBuf.data(); - auto dtypeSize = SizeOfAclDType(tensor); + auto dst = tensor.transBuf.begin(); + int64_t dtypeSize = static_cast(SizeOfAclDType(tensor)); + int64_t dstSize = static_cast(tensor.transBuf.size()); for (int64_t nIndex = 0; nIndex < n; nIndex++) { int64_t nHead = nIndex * cdhw; for (int64_t cIndex = 0; cIndex < c; cIndex++) { @@ -625,8 +679,13 @@ static DebuggerErrno NDC1HWC0_TO_NCDHW(AclTensorInfo& tensor) int64_t c0Index = cIndex % c0; auto srcIdx = nIndex * dc1hwc0 + dIndex * c1hwc0 + c1Index * hwc0 + hIndex * wc0 + wIndex * c0 + c0Index; - /* 此处由偏移计算逻辑保障不会越界读写 */ - std::memcpy(dst + dstIdx * dtypeSize, src + srcIdx * dtypeSize, dtypeSize); + int64_t dstOffset = dstIdx * dtypeSize; + int64_t srcOffset = srcIdx * dtypeSize; + if (dstOffset + dtypeSize > dstSize) { + return DebuggerErrno::ERROR_INVALID_VALUE; + } + std::copy(src + srcOffset, src + srcOffset + dtypeSize, + dst + dstOffset); } } } @@ -637,14 +696,14 @@ static DebuggerErrno NDC1HWC0_TO_NCDHW(AclTensorInfo& tensor) static DebuggerErrno C1HWNCoC0_TO_NCHW(AclTensorInfo& tensor) { - AssertDim(tensor.hostShape, kDim4); + AssertDim(tensor.hostShape, DIM_4); AssertConsis(tensor); AllocTensorTransBuf(tensor); - auto n = tensor.hostShape[kN]; - auto c = tensor.hostShape[kC]; - auto h = tensor.hostShape[kH]; - auto w = tensor.hostShape[kW]; + auto n = tensor.hostShape[AXIS_N]; + auto c = tensor.hostShape[AXIS_C]; + auto h = tensor.hostShape[AXIS_H]; + auto w = tensor.hostShape[AXIS_W]; const int coIdx = 4; const int c0Idx = 5; auto co = tensor.deviceShape[coIdx]; @@ -652,8 +711,9 @@ static DebuggerErrno C1HWNCoC0_TO_NCHW(AclTensorInfo& tensor) auto cubeK = GetCubeSizeByType(tensor.dtype); const uint8_t* src = tensor.aclData; - uint8_t* dst = tensor.transBuf.data(); - auto dtypeSize = SizeOfAclDType(tensor); + auto dst = tensor.transBuf.begin(); + int64_t dtypeSize = static_cast(SizeOfAclDType(tensor)); + int64_t dstSize = static_cast(tensor.transBuf.size()); for (int64_t nIndex = 0; nIndex < n; nIndex++) { for (int64_t cIndex = 0; cIndex < c; cIndex++) { for (int64_t hIndex = 0; hIndex < h; hIndex++) { @@ -664,8 +724,13 @@ static DebuggerErrno C1HWNCoC0_TO_NCHW(AclTensorInfo& tensor) int64_t coIndex = c0Index; int64_t srcIdx = c1Index * h * w * n * co * c0 + hIndex * w * n * co * c0 + wIndex * n * co * c0 + nIndex * co * c0 + coIndex * c0 + c0Index; - /* 此处由偏移计算逻辑保障不会越界读写 */ - std::memcpy(dst + dstIdx * dtypeSize, src + srcIdx * dtypeSize, dtypeSize); + int64_t dstOffset = dstIdx * dtypeSize; + int64_t srcOffset = srcIdx * dtypeSize; + if (dstOffset + dtypeSize > dstSize) { + return DebuggerErrno::ERROR_INVALID_VALUE; + } + std::copy(src + srcOffset, src + srcOffset + dtypeSize, + dst + dstOffset); } } } @@ -680,17 +745,21 @@ static DebuggerErrno NC1HWC0_C04_TO_NCHW(AclTensorInfo& tensor) static DebuggerErrno FRAC_Z3D_TO_NCDHW(AclTensorInfo& tensor) { - AssertDim(tensor.hostShape, kDim5); + AssertDim(tensor.hostShape, DIM_5); AssertConsis(tensor); AllocTensorTransBuf(tensor); - auto n = tensor.hostShape[N_ncdhw]; - auto c = tensor.hostShape[C_ncdhw]; - auto d = tensor.hostShape[D_ncdhw]; - auto h = tensor.hostShape[H_ncdhw]; - auto w = tensor.hostShape[W_ncdhw]; - constexpr int kFZ3D_C0 = 3; - auto c0 = tensor.deviceShape[kFZ3D_C0]; + auto n = tensor.hostShape[N_NCDHW]; + auto c = tensor.hostShape[C_NCDHW]; + auto d = tensor.hostShape[D_NCDHW]; + auto h = tensor.hostShape[H_NCDHW]; + auto w = tensor.hostShape[W_NCDHW]; + constexpr int FZ3D_C0 = 3; + auto c0 = tensor.deviceShape[FZ3D_C0]; + if (c0 == 0) { + LOG_WARNING(DebuggerErrno::ERROR_INVALID_VALUE, tensor + ": Invalid shape size."); + return DebuggerErrno::ERROR_INVALID_VALUE; + } auto cube_k = GetCubeSizeByType(tensor.dtype); auto c1 = DivCeil(c, cube_k); constexpr int64_t kNiSize = 16; @@ -704,8 +773,9 @@ static DebuggerErrno FRAC_Z3D_TO_NCDHW(AclTensorInfo& tensor) auto cdhw = c * dhw; const uint8_t* src = tensor.aclData; - uint8_t* dst = tensor.transBuf.data(); - auto dtypeSize = SizeOfAclDType(tensor); + auto dst = tensor.transBuf.begin(); + int64_t dtypeSize = static_cast(SizeOfAclDType(tensor)); + int64_t dstSize = static_cast(tensor.transBuf.size()); for (int64_t nIdx = 0; nIdx < n; nIdx++) { int64_t nHead = nIdx * cdhw; for (int64_t cIdx = 0; cIdx < c; cIdx++) { @@ -721,8 +791,13 @@ static DebuggerErrno FRAC_Z3D_TO_NCDHW(AclTensorInfo& tensor) int64_t ncIdx = nIdx; int64_t srcIdx = dIdx * c1hwn1n0c0 + c1I * c1hwn1n0c0 + hIdx * wn1n0c0 + wI * n1n0c0 + ncIdx * c0 + c0I; - /* 此处由偏移计算逻辑保障不会越界读写 */ - std::memcpy(dst + dstIdx * dtypeSize, src + srcIdx * dtypeSize, dtypeSize); + int64_t dstOffset = dstIdx * dtypeSize; + int64_t srcOffset = srcIdx * dtypeSize; + if (dstOffset + dtypeSize > dstSize) { + return DebuggerErrno::ERROR_INVALID_VALUE; + } + std::copy(src + srcOffset, src + srcOffset + dtypeSize, + dst + dstOffset); } } } @@ -749,11 +824,11 @@ DebuggerErrno TransFormatD2H(AclTensorInfo& tensor) } } -static void TransBf16ToFp32(const uint8_t* input, size_t num, uint8_t* output, size_t bufferSize) +static DebuggerErrno TransBf16ToFp32(const uint8_t* input, size_t num, uint8_t* output, size_t bufferSize) { if (bufferSize < num * sizeof(float)) { LOG_ERROR(DebuggerErrno::ERROR_BUFFER_OVERFLOW, "Insufficient space for converting data from bf16 to fp32."); - return; + return DebuggerErrno::ERROR_BUFFER_OVERFLOW; } const DataUtils::BFloat16* in = reinterpret_cast(input); float* out = reinterpret_cast(output); @@ -761,36 +836,100 @@ static void TransBf16ToFp32(const uint8_t* input, size_t num, uint8_t* output, s for (size_t i = 0; i < num; i++) { out[i] = static_cast(in[i]); } + return DebuggerErrno::OK; } -DebuggerErrno TransDtype(AclTensorInfo& tensor, AclDtype to) +static DebuggerErrno TransInt4ToInt8(const uint8_t* input, + size_t elemNums, + uint8_t* output, + size_t bufferSize) { + // 输出缓冲区要能容纳 elemNums 个 int8_t + if (bufferSize < elemNums * sizeof(int8_t)) { + LOG_ERROR(DebuggerErrno::ERROR_BUFFER_OVERFLOW, + "Insufficient space for converting data from int4 to int8."); + return DebuggerErrno::ERROR_BUFFER_OVERFLOW; + } + + const uint8_t* srcData = input; // 原始数据按字节读取 + int8_t* dstData = reinterpret_cast(output); + size_t inputLength = elemNums / 2; + + const int8_t maxValue = 7; + const int8_t minValue = -8; + const uint8_t signBitMask = 0x08; + const int signBitShift = 3; + + for (size_t i = 0; i < inputLength; ++i) { + uint8_t byte = srcData[i]; + + // —— 低 4 位 —— + uint8_t u = byte & 0x0F; // 在无符号变量上做 AND + uint8_t sign = (u & signBitMask) >> signBitShift; + if (sign) { + u |= 0xF0; // 在无符号变量上做 OR + } + // 转回有符号并检查范围 + int8_t t = static_cast(u); + if (t < minValue || t > maxValue) { + LOG_ERROR(DebuggerErrno::ERROR_INVALID_VALUE, + "Invalid int4 value (low nibble)."); + } + *dstData++ = t; - const static std::set> kSupportedDtypeTrans = { - {AclDtype::DT_BF16, AclDtype::DT_FLOAT}, - }; + // —— 高 4 位 —— + u = (byte >> 4) & 0x0F; // 无符号右移后截低 4 位 + sign = (u & signBitMask) >> signBitShift; + if (sign) { + u |= 0xF0; + } + t = static_cast(u); + if (t < minValue || t > maxValue) { + LOG_ERROR(DebuggerErrno::ERROR_INVALID_VALUE, + "Invalid int4 value (high nibble)."); + } + *dstData++ = t; + } + + return DebuggerErrno::OK; +} +DebuggerErrno TransDtype(AclTensorInfo& tensor, AclDtype to) +{ if (tensor.dtype == to) { return DebuggerErrno::OK; } - if (kSupportedDtypeTrans.find({tensor.dtype, to}) == kSupportedDtypeTrans.end()) { - return DebuggerErrno::ERROR_UNKNOWN_TRANS; - } - + tensor.oriDtype = tensor.dtype; std::vector buffer; - AssertConsis(tensor); + try { + AssertConsis(tensor); + } catch (const std::runtime_error& e) { + LOG_ERROR(DebuggerErrno::ERROR_INVALID_OPERATION, e.what()); + return DebuggerErrno::ERROR_INVALID_OPERATION; + } size_t bufferSize = EleNumOfTensor(tensor) * SizeOfAclDType(to); - buffer.reserve(bufferSize); + buffer.resize(bufferSize); const uint8_t* input = tensor.transBuf.empty() ? tensor.aclData : tensor.transBuf.data(); uint8_t* output = buffer.data(); + DebuggerErrno ret; - /* 目前仅支持bf16->fp32,若有通用转换需求再用更泛化的方式重写 */ if (tensor.dtype == AclDtype::DT_BF16 && to == AclDtype::DT_FLOAT) { - TransBf16ToFp32(input, EleNumOfTensor(tensor), output, bufferSize); + ret = TransBf16ToFp32(input, EleNumOfTensor(tensor), output, bufferSize); + } else if (tensor.dtype == AclDtype::DT_INT4 && to == AclDtype::DT_INT8) { + ret = TransInt4ToInt8(input, EleNumOfTensor(tensor), output, bufferSize); + } else { + LOG_ERROR(DebuggerErrno::ERROR_UNKNOWN_TRANS, tensor + ": Trans " + DataUtils::GetDTypeString(tensor.dtype) + + " to " + DataUtils::GetDTypeString(to) + " is not supported."); + return DebuggerErrno::ERROR_UNKNOWN_TRANS; + } + + if (ret != DebuggerErrno::OK) { + return ret; } tensor.transBuf = std::move(buffer); + tensor.dtype = to; return DebuggerErrno::OK; } diff --git a/debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.hpp b/debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.h similarity index 76% rename from debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.hpp rename to debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.h index 8b5ba5b06d935d5aaa2dff35e921b9072db6aa1a..8d89984ddecd522918126da0b057e0b98651b1ca 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.hpp +++ b/debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.h @@ -19,9 +19,9 @@ #include #include -#include "include/ErrorCode.hpp" +#include "include/ErrorCode.h" #include "proto/AclDumpMsg.pb.h" -#include "utils/DataUtils.hpp" +#include "utils/DataUtils.h" namespace MindStudioDebugger { @@ -29,17 +29,18 @@ using AclShape = DataUtils::TensorShape; using AclDtype = DataUtils::DataType; using AclFormat = DataUtils::TensorFormat; -constexpr uint8_t kDim1 = 1; -constexpr uint8_t kDim2 = 2; -constexpr uint8_t kDim3 = 3; -constexpr uint8_t kDim4 = 4; -constexpr uint8_t kDim5 = 5; -constexpr uint8_t kDim6 = 6; +constexpr uint8_t DIM_1 = 1; +constexpr uint8_t DIM_2 = 2; +constexpr uint8_t DIM_3 = 3; +constexpr uint8_t DIM_4 = 4; +constexpr uint8_t DIM_5 = 5; +constexpr uint8_t DIM_6 = 6; struct AclTensorInfo { std::string dumpPath; const uint8_t* aclData; AclDtype dtype; + AclDtype oriDtype; AclFormat deviceFmt; AclFormat hostFmt; AclShape deviceShape; @@ -51,26 +52,30 @@ struct AclTensorInfo { bool dumpOriginData; std::vector transBuf; - std::string ToString() const { - return "AclTensor(path=" + dumpPath + ",dtype=" + std::to_string(dtype) + ",inout=" + inout + ")"; + std::string ToString() const + { + return "AclTensor(path=" + dumpPath + ",dtype=" + DataUtils::GetDTypeString(dtype) + ",inout=" + inout + ")"; } }; -inline std::string operator+(const std::string& s, const AclTensorInfo& tensor) { +inline std::string operator+(const std::string& s, const AclTensorInfo& tensor) +{ return s + tensor.ToString(); } -inline std::string operator+(const AclTensorInfo& tensor, const std::string& s) { +inline std::string operator+(const AclTensorInfo& tensor, const std::string& s) +{ return tensor.ToString() + s; } namespace AclTensor { -size_t SizeOfTensor(const AclTensorInfo& tensor, bool host=true); +size_t SizeOfTensor(const AclTensorInfo& tensor, bool host = true); template AclTensorInfo ParseAttrsFromDumpData(const std::string &dumpPath, const uint8_t* data, const T& tensor, const std::string& io, uint32_t slot); DebuggerErrno TransFormatD2H(AclTensorInfo& tensor); DebuggerErrno TransDtype(AclTensorInfo& tensor, AclDtype to); +bool IsDtypeSupportTrans(AclDtype dtype); } } diff --git a/debug/accuracy_tools/msprobe/ccsrc/core/PrecisionDebugger.cpp b/debug/accuracy_tools/msprobe/ccsrc/core/PrecisionDebugger.cpp index d4d74f1962222558c88c576b8ffbd8c474e152f2..6b51f6f28cee382e4b2928936387957d88f9f427 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/core/PrecisionDebugger.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/core/PrecisionDebugger.cpp @@ -16,10 +16,11 @@ #include -#include "base/ErrorInfos.hpp" -#include "base/DebuggerConfig.hpp" -#include "third_party/ACL/AclApi.hpp" -#include "PrecisionDebugger.hpp" +#include "base/ErrorInfosManager.h" +#include "base/DebuggerConfig.h" +#include "third_party/ACL/AclApi.h" +#include "core/mindspore/MSAclDumper.h" +#include "PrecisionDebugger.h" namespace MindStudioDebugger { @@ -83,12 +84,12 @@ int32_t PrecisionDebugger::Initialize(const std::string& framework, const std::s return ret; } - if(AscendCLApi::LoadAclApi() != DebuggerErrno::OK) { + if (AscendCLApi::LoadAclApi() != DebuggerErrno::OK) { return -1; } const DebuggerConfig& cfg = DebuggerConfig::GetInstance(); - for (auto iter = subDebuggers.begin(); iter != subDebuggers.end(); ) { + for (auto iter = subDebuggers.begin(); iter != subDebuggers.end();) { if (!(*iter)->Condition(cfg)) { iter = subDebuggers.erase(iter); } else { @@ -124,7 +125,7 @@ void PrecisionDebugger::Stop() } enable = false; - CALL_ACL_API(aclrtSynchronizeDevice); + CALL_ACL_API(AclrtSynchronizeDevice); for (auto task : subDebuggers) { task->OnStop(); @@ -133,25 +134,7 @@ void PrecisionDebugger::Stop() void PrecisionDebugger::Step() { - return Step(1); -} - -void PrecisionDebugger::Step(uint32_t step) -{ - DEBUG_FUNC_TRACE(); - if (!initialized) { - return; - } - - if (step > UINT32_MAX - curStep) { - throw std::runtime_error("Step over upper limit(4294967295)."); - } - curStep += step; - CALL_ACL_API(aclrtSynchronizeDevice); - - for (auto task : subDebuggers) { - task->OnStep(curStep); - } + MSAclDumper::GetInstance().Step(); } } \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/ccsrc/core/PrecisionDebugger.hpp b/debug/accuracy_tools/msprobe/ccsrc/core/PrecisionDebugger.h similarity index 92% rename from debug/accuracy_tools/msprobe/ccsrc/core/PrecisionDebugger.hpp rename to debug/accuracy_tools/msprobe/ccsrc/core/PrecisionDebugger.h index fbc22c016c40285a90a3de5989684098639256c9..311d84d8c2afb166e0041a4f3e799a5dfeb9583a 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/core/PrecisionDebugger.hpp +++ b/debug/accuracy_tools/msprobe/ccsrc/core/PrecisionDebugger.h @@ -19,7 +19,7 @@ #include #include -#include "base/DebuggerConfig.hpp" +#include "base/DebuggerConfig.h" namespace MindStudioDebugger { @@ -43,9 +43,10 @@ protected: class PrecisionDebugger { public: - static PrecisionDebugger& GetInstance() { - static PrecisionDebugger instance_; - return instance_; + static PrecisionDebugger& GetInstance() + { + static PrecisionDebugger debuggerInstance; + return debuggerInstance; } int32_t Initialize(const std::string& framework, const std::string& cfgFile); diff --git a/debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MSAclDumper.cpp b/debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MSAclDumper.cpp index 2d80ed3ce1ab11ee5ddf9bad18583a6813f32529..27f48412c690bbb5dafa0fdd31565f136718ab45 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MSAclDumper.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MSAclDumper.cpp @@ -16,15 +16,15 @@ #include -#include "base/ErrorInfos.hpp" -#include "base/DebuggerConfig.hpp" -#include "base/Environment.hpp" -#include "core/AclDumper.hpp" -#include "MSAclDumper.hpp" +#include "base/ErrorInfosManager.h" +#include "base/DebuggerConfig.h" +#include "base/Environment.h" +#include "core/AclDumper.h" +#include "MSAclDumper.h" namespace MindStudioDebugger { -void MSAclDumper::OnStepBegin(uint32_t device, uint32_t curStep, ExtArgs& args) +void MSAclDumper::OnStepBegin(uint32_t device, ExtArgs& args) { DEBUG_FUNC_TRACE(); if (!PrecisionDebugger::GetInstance().IsEnable()) { @@ -41,7 +41,7 @@ void MSAclDumper::OnStepBegin(uint32_t device, uint32_t curStep, ExtArgs& args) rank = static_cast(device); } - AclDumper::GetInstance().SetDump(rank, curStep, args); + AclDumper::GetInstance().SetDump(rank, msprobeStep, args); return; } @@ -51,6 +51,11 @@ void MSAclDumper::OnStepEnd(ExtArgs& args) AclDumper::GetInstance().FinalizeDump(args); } +void MSAclDumper::Step() +{ + msprobeStep++; +} + __attribute__((constructor)) void RegisterMSAclDumper() { MSAclDumper::GetInstance().Register(); diff --git a/debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MSAclDumper.hpp b/debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MSAclDumper.h similarity index 79% rename from debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MSAclDumper.hpp rename to debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MSAclDumper.h index cd09bf51af0dac67065d51b8ce60c20f011cd585..3baac1cc6b517f480cdbe891c0f0ba639477e3a8 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MSAclDumper.hpp +++ b/debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MSAclDumper.h @@ -18,26 +18,29 @@ #include -#include "include/ExtArgs.hpp" -#include "core/PrecisionDebugger.hpp" +#include "include/ExtArgs.h" +#include "core/PrecisionDebugger.h" namespace MindStudioDebugger { class MSAclDumper : public PrecisionDbgTaskBase { public: - static MSAclDumper& GetInstance() { - static MSAclDumper instance_; - return instance_; + static MSAclDumper& GetInstance() + { + static MSAclDumper dumperInstance; + return dumperInstance; } std::string Name() const override {return "MindSpore AclDumper";} - bool Condition(const DebuggerConfig& cfg) const override { + bool Condition(const DebuggerConfig& cfg) const override + { return cfg.GetFramework() == DebuggerFramework::FRAMEWORK_MINDSPORE && cfg.GetDebugLevel() == DebuggerLevel::L2; } - void OnStepBegin(uint32_t device, uint32_t curStep, ExtArgs& args); + void OnStepBegin(uint32_t device, ExtArgs& args); void OnStepEnd(ExtArgs& args); + void Step(); private: MSAclDumper() = default; @@ -46,6 +49,7 @@ private: MSAclDumper& operator=(const MSAclDumper &obj) = delete; explicit MSAclDumper(MSAclDumper &&obj) = delete; MSAclDumper& operator=(MSAclDumper &&obj) = delete; + uint32_t msprobeStep{0}; }; } \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MindSporeTrigger.cpp b/debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MindSporeTrigger.cpp index 631ea7c4acf4666b911a3bb5f28a3c6cc4fe0d54..031b718737e9af877148b6e8c92192bd0c92fb47 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MindSporeTrigger.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MindSporeTrigger.cpp @@ -14,21 +14,22 @@ * limitations under the License. */ -#include "include/Macro.hpp" -#include "base/ErrorInfos.hpp" -#include "MindSporeTrigger.hpp" -#include "MSAclDumper.hpp" +#include "include/Macro.h" +#include "base/ErrorInfosManager.h" +#include "MSAclDumper.h" +#include "MindSporeTrigger.h" namespace MindStudioDebugger { bool MindSporeTrigger::stepBeginFlag = false; -void MindSporeTrigger::TriggerOnStepBegin(uint32_t device, uint32_t curStep, ExtArgs& args) +void MindSporeTrigger::TriggerOnStepBegin(uint32_t device, uint32_t /* curStep */, ExtArgs& args) { DEBUG_FUNC_TRACE(); CleanErrorInfoCache(); + + MSAclDumper::GetInstance().OnStepBegin(device, args); - MSAclDumper::GetInstance().OnStepBegin(device, curStep, args); stepBeginFlag = true; CleanErrorInfoCache(); diff --git a/debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MindSporeTrigger.hpp b/debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MindSporeTrigger.h similarity index 97% rename from debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MindSporeTrigger.hpp rename to debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MindSporeTrigger.h index 022e5d7d4c14a9771681840b967b2ec3aebb811b..d5048925bf58a1e4414b2983d796e598ac56c17b 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MindSporeTrigger.hpp +++ b/debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MindSporeTrigger.h @@ -18,7 +18,7 @@ #include -#include "include/ExtArgs.hpp" +#include "include/ExtArgs.h" namespace MindStudioDebugger { diff --git a/debug/accuracy_tools/msprobe/ccsrc/if/mindspore/MindSporeDbgHook.cpp b/debug/accuracy_tools/msprobe/ccsrc/if/mindspore/MindSporeDbgHook.cpp index 42f3a2e5b61d5da021b2ef7da4a7b88c6dc2abbb..2d744282d4eb2e741ae0e4afa7081a1a65738d61 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/if/mindspore/MindSporeDbgHook.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/if/mindspore/MindSporeDbgHook.cpp @@ -19,9 +19,9 @@ #include #include -#include "include/Macro.hpp" -#include "include/ExtArgs.hpp" -#include "core/mindspore/MindSporeTrigger.hpp" +#include "include/Macro.h" +#include "include/ExtArgs.h" +#include "core/mindspore/MindSporeTrigger.h" EXPORT_SYMBOL void MS_DbgOnStepBegin(uint32_t device, int32_t curStep, std::map exts) @@ -34,8 +34,11 @@ EXPORT_SYMBOL void MS_DbgOnStepBegin(uint32_t device, int32_t curStep, } /* mindspore使用了_GLIBCXX_USE_CXX11_ABI=0,为了解决CXX版本兼容问题,此处将string转char*使用 */ if (ext.first == static_cast(MindStudioDebugger::MindStudioExtensionArgs::ALL_KERNEL_NAMES)) { + if (ext.second == nullptr) { + continue; + } std::vector* ss = reinterpret_cast*>(ext.second); - strBuf = new const char*[(*ss).size() + 1]; + strBuf = new const char* [(*ss).size() + 1]; strBuf[(*ss).size()] = nullptr; size_t i = 0; for (std::string& s : *ss) { @@ -66,6 +69,4 @@ EXPORT_SYMBOL void MS_DbgOnStepEnd(std::map& exts) args[static_cast(ext.first)] = ext.second; } return MindStudioDebugger::MindSporeTrigger::TriggerOnStepEnd(args); -} - - +} \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/ccsrc/if/python/ACLDump.cpp b/debug/accuracy_tools/msprobe/ccsrc/if/python/ACLDump.cpp index 1c380ed3f505795eb622f7f401558f72a54db557..9723f72686a3cc4386927f3f00ef9d1eaffeb9ad 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/if/python/ACLDump.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/if/python/ACLDump.cpp @@ -18,37 +18,40 @@ #include #include -#include "base/ErrorInfos.hpp" -#include "core/AclDumper.hpp" -#include "utils/CPythonUtils.hpp" +#include "base/ErrorInfosManager.h" +#include "core/AclDumper.h" +#include "utils/CPythonUtils.h" namespace MindStudioDebugger { -static PyObject *CPythonKernelInitDump(PyObject *module, PyObject *args) { - PyGILState_STATE gstate = PyGILState_Ensure(); - KernelInitDump(); - PyGILState_Release(gstate); - Py_RETURN_NONE; +static PyObject *CPythonKernelInitDump(PyObject *module, PyObject *args) +{ + PyGILState_STATE gstate = PyGILState_Ensure(); + KernelInitDump(); + PyGILState_Release(gstate); + Py_RETURN_NONE; } -static PyObject *CPythonKernelSetDump(PyObject *module, PyObject *args) { - const char *path; - if (!PyArg_ParseTuple(args, "s", &path)) { +static PyObject *CPythonKernelSetDump(PyObject *module, PyObject *args) +{ + const char *path; + if (!PyArg_ParseTuple(args, "s", &path)) { LOG_ERROR(DebuggerErrno::ERROR_INVALID_VALUE, "npu set dump error, cfg_file must string"); return nullptr; - } - PyGILState_STATE gstate = PyGILState_Ensure(); - KernelSetDump(std::string(path)); - PyGILState_Release(gstate); - Py_RETURN_NONE; + } + PyGILState_STATE gstate = PyGILState_Ensure(); + KernelSetDump(std::string(path)); + PyGILState_Release(gstate); + Py_RETURN_NONE; } -static PyObject *CPythonKernelFinalizeDump(PyObject *module, PyObject *args) { - PyGILState_STATE gstate = PyGILState_Ensure(); - KernelFinalizeDump(); - PyGILState_Release(gstate); - Py_RETURN_NONE; +static PyObject *CPythonKernelFinalizeDump(PyObject *module, PyObject *args) +{ + PyGILState_STATE gstate = PyGILState_Ensure(); + KernelFinalizeDump(); + PyGILState_Release(gstate); + Py_RETURN_NONE; } static PyMethodDef DumpMethods[] = { diff --git a/debug/accuracy_tools/msprobe/ccsrc/if/python/ACLDump.hpp b/debug/accuracy_tools/msprobe/ccsrc/if/python/ACLDump.h similarity index 100% rename from debug/accuracy_tools/msprobe/ccsrc/if/python/ACLDump.hpp rename to debug/accuracy_tools/msprobe/ccsrc/if/python/ACLDump.h diff --git a/debug/accuracy_tools/msprobe/ccsrc/if/python/CPythonAgent.cpp b/debug/accuracy_tools/msprobe/ccsrc/if/python/CPythonAgent.cpp index 4b8fc03491e2c0792c3c707c272e7b587d60c7ad..e41243aa8d3c27b92c275dcd098e983083328d8e 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/if/python/CPythonAgent.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/if/python/CPythonAgent.cpp @@ -18,7 +18,7 @@ #include #include -#include "utils/CPythonUtils.hpp" +#include "utils/CPythonUtils.h" namespace MindStudioDebugger { @@ -29,8 +29,12 @@ PyDoc_STRVAR(CPythonAgentModuleDoc, static PyObject* CPythonAgentRegister(PyObject *module, PyObject *args) { + if (args == nullptr || !PyTuple_Check(args)) { + PyErr_SetString(PyExc_TypeError, "Expect a tuple."); + Py_RETURN_NONE; + } /* 预期2个参数,name和obj */ - if (args == nullptr || PyTuple_GET_SIZE(args) != 2) { + if (PyTuple_GET_SIZE(args) != 2) { PyErr_SetString(PyExc_TypeError, "\'register_context\' expects 2 arguments."); Py_RETURN_NONE; } @@ -56,7 +60,7 @@ static PyObject* CPythonAgentRegister(PyObject *module, PyObject *args) static PyObject* CPythonAgentUnRegister(PyObject *module, PyObject *obj) { CPythonUtils::PythonStringObject name(obj); - if(name.IsNone()) { + if (name.IsNone()) { PyErr_SetString(PyExc_TypeError, "\"name\" should be a string."); Py_RETURN_NONE; } @@ -68,7 +72,7 @@ static PyObject* CPythonAgentUnRegister(PyObject *module, PyObject *obj) static PyObject* CPythonAgentGetContext(PyObject *module, PyObject *obj) { CPythonUtils::PythonStringObject name(obj); - if(name.IsNone()) { + if (name.IsNone()) { PyErr_SetString(PyExc_TypeError, "\"name\" should be a string."); Py_RETURN_NONE; } diff --git a/debug/accuracy_tools/msprobe/ccsrc/if/python/CPythonAgent.hpp b/debug/accuracy_tools/msprobe/ccsrc/if/python/CPythonAgent.h similarity index 100% rename from debug/accuracy_tools/msprobe/ccsrc/if/python/CPythonAgent.hpp rename to debug/accuracy_tools/msprobe/ccsrc/if/python/CPythonAgent.h diff --git a/debug/accuracy_tools/msprobe/ccsrc/if/python/MsProbeIfPython.cpp b/debug/accuracy_tools/msprobe/ccsrc/if/python/MsProbeIfPython.cpp index a18c54a146f7d676d6b3c7f760e50f9e7eebe56c..58a60538e0989ac1a197f77f14fc67b9593ecd40 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/if/python/MsProbeIfPython.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/if/python/MsProbeIfPython.cpp @@ -16,9 +16,9 @@ #include -#include "PrecisionDebuggerIfPython.hpp" -#include "CPythonAgent.hpp" -#include "ACLDump.hpp" +#include "PrecisionDebuggerIfPython.h" +#include "CPythonAgent.h" +#include "ACLDump.h" namespace MindStudioDebugger { @@ -27,7 +27,7 @@ PyDoc_STRVAR(MsProbeCModuleDoc, class _PrecisionDebugger: PrecisionDebugger in CXX \n\ class _DebuggerConfig: Configuration data of PrecisionDebugger \n\ class CPythonAgent: Used for front-end and back-end code interactions \n\ - \n\ + \n\ ..."); static struct PyModuleDef g_MsProbeCModule = { @@ -58,7 +58,6 @@ PyMODINIT_FUNC PyInit__msprobe_c(void) Py_DECREF(m); return nullptr; } - Py_INCREF(precisionDebugger); PyObject* cpyAgent = MindStudioDebugger::GetCPythonAgentModule(); if (cpyAgent == nullptr) { @@ -71,11 +70,17 @@ PyMODINIT_FUNC PyInit__msprobe_c(void) Py_DECREF(m); return nullptr; } - Py_INCREF(cpyAgent); PyMethodDef* dumpmethods = MindStudioDebugger::GetDumpMethods(); for (PyMethodDef* method = dumpmethods; method->ml_name != nullptr; ++method) { - if (PyModule_AddObject(m, method->ml_name, PyCFunction_New(method, nullptr)) < 0) { + PyObject* func = PyCFunction_New(method, nullptr); + if (func == nullptr) { + PyErr_SetString(PyExc_ImportError, "Failed to create dump method."); + Py_DECREF(m); + return nullptr; + } + if (PyModule_AddObject(m, method->ml_name, func) < 0) { + Py_DECREF(func); // 释放未被模块接管的方法对象 PyErr_SetString(PyExc_ImportError, "Failed to bind dump method."); Py_DECREF(m); return nullptr; diff --git a/debug/accuracy_tools/msprobe/ccsrc/if/python/PrecisionDebuggerIfPython.cpp b/debug/accuracy_tools/msprobe/ccsrc/if/python/PrecisionDebuggerIfPython.cpp index da1cf3cf1c5d4c8894d0b12b5518657b5928a8d6..23e41db019dc8da8acce847245ca4e6bc41be67d 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/if/python/PrecisionDebuggerIfPython.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/if/python/PrecisionDebuggerIfPython.cpp @@ -18,8 +18,8 @@ #include #include -#include "utils/CPythonUtils.hpp" -#include "core/PrecisionDebugger.hpp" +#include "utils/CPythonUtils.h" +#include "core/PrecisionDebugger.h" namespace MindStudioDebugger { @@ -53,7 +53,6 @@ static int InitPrecisionDebugger(PyObject *self, PyObject *args, PyObject *kws) CPythonUtils::PythonDictObject kwArgs(kws); std::string framework = kwArgs.GetItem("framework"); std::string cfgFile = kwArgs.GetItem("config_path"); - if (PrecisionDebugger::GetInstance().Initialize(framework, cfgFile) != 0) { PyErr_SetString(PyExc_RuntimeError, "Failed to load config, read log for more details."); return -1; @@ -99,20 +98,9 @@ static PyObject* PrecisionDebuggerStop(PyObject *self) Py_RETURN_NONE; } -static PyObject* PrecisionDebuggerStep(PyObject *self, PyObject *args) +static PyObject* PrecisionDebuggerStep(PyObject *self) { - if (args == nullptr || PyTuple_GET_SIZE(args) == 0) { - PrecisionDebugger::GetInstance().Step(); - Py_RETURN_NONE; - } - - PyObject* increment = PyTuple_GetItem(args, 0); - if (!PyLong_Check(increment)) { - PyErr_SetString(PyExc_TypeError, "\'step\' should be a int."); - Py_RETURN_NONE; - } - - PrecisionDebugger::GetInstance().Step(PyLong_AsUnsignedLong(increment)); + PrecisionDebugger::GetInstance().Step(); Py_RETURN_NONE; } @@ -126,7 +114,7 @@ PyDoc_STRVAR(StepDoc, static PyMethodDef PrecisionDebuggerMethods[] = { {"start", reinterpret_cast(PrecisionDebuggerStart), METH_NOARGS, StartDoc}, {"stop", reinterpret_cast(PrecisionDebuggerStop), METH_NOARGS, StopDoc}, - {"step", reinterpret_cast(PrecisionDebuggerStep), METH_VARARGS, StepDoc}, + {"step", reinterpret_cast(PrecisionDebuggerStep), METH_NOARGS, StepDoc}, {nullptr, nullptr, 0, nullptr} }; @@ -184,5 +172,4 @@ PyTypeObject* GetPyPrecisionDebuggerType() } return &PyPrecisionDebuggerType; } - } \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/ccsrc/if/python/PrecisionDebuggerIfPython.hpp b/debug/accuracy_tools/msprobe/ccsrc/if/python/PrecisionDebuggerIfPython.h similarity index 100% rename from debug/accuracy_tools/msprobe/ccsrc/if/python/PrecisionDebuggerIfPython.hpp rename to debug/accuracy_tools/msprobe/ccsrc/if/python/PrecisionDebuggerIfPython.h diff --git a/debug/accuracy_tools/msprobe/ccsrc/include/ErrorCode.hpp b/debug/accuracy_tools/msprobe/ccsrc/include/ErrorCode.h similarity index 100% rename from debug/accuracy_tools/msprobe/ccsrc/include/ErrorCode.hpp rename to debug/accuracy_tools/msprobe/ccsrc/include/ErrorCode.h diff --git a/debug/accuracy_tools/msprobe/ccsrc/include/ExtArgs.hpp b/debug/accuracy_tools/msprobe/ccsrc/include/ExtArgs.h similarity index 100% rename from debug/accuracy_tools/msprobe/ccsrc/include/ExtArgs.hpp rename to debug/accuracy_tools/msprobe/ccsrc/include/ExtArgs.h diff --git a/debug/accuracy_tools/msprobe/ccsrc/include/Macro.hpp b/debug/accuracy_tools/msprobe/ccsrc/include/Macro.h similarity index 100% rename from debug/accuracy_tools/msprobe/ccsrc/include/Macro.hpp rename to debug/accuracy_tools/msprobe/ccsrc/include/Macro.h diff --git a/debug/accuracy_tools/msprobe/ccsrc/third_party/ACL/AclApi.cpp b/debug/accuracy_tools/msprobe/ccsrc/third_party/ACL/AclApi.cpp index 1636c6998d9096b62e9a7f281c7e5ac1b4de4818..49b23434712967e89c9b047d7fe3fa58fe146977 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/third_party/ACL/AclApi.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/third_party/ACL/AclApi.cpp @@ -18,30 +18,40 @@ #include #include -#include "base/ErrorInfos.hpp" -#include "AclApi.hpp" +#include "base/ErrorInfosManager.h" +#include "AclApi.h" namespace MindStudioDebugger { namespace AscendCLApi { using namespace MindStudioDebugger; -constexpr const char* kLibAscendclName = "libascendcl.so"; -constexpr const char* kLibMSAscendName = "libmindspore_ascend.so.2"; - -using aclInitFuncType = aclError (*)(const char *); -using aclmdlInitDumpFuncType = aclError (*)(); -using aclmdlSetDumpFuncType = aclError (*)(const char *); -using aclmdlFinalizeDumpFuncType = aclError (*)(); -using acldumpRegCallbackFuncType = aclError (*)(AclDumpCallbackFuncType, int32_t); -using aclrtSynchronizeDeviceFuncType = aclError (*)(); - -static aclInitFuncType aclInitFunc = nullptr; -static aclmdlInitDumpFuncType aclmdlInitDumpFunc = nullptr; -static aclmdlSetDumpFuncType aclmdlSetDumpFunc = nullptr; -static aclmdlFinalizeDumpFuncType aclmdlFinalizeDumpFunc = nullptr; -static acldumpRegCallbackFuncType acldumpRegCallbackFunc = nullptr; -static aclrtSynchronizeDeviceFuncType aclrtSynchronizeDeviceFunc = nullptr; +constexpr const char* LIB_ASCEND_CL_NAME = "libascendcl.so"; +constexpr const char* LIB_MS_ASCEND_NAME = "libmindspore_ascend.so.2"; +constexpr const char* LIB_ASCEND_DUMP_NAME = "libascend_dump.so"; + +using AclInitFuncType = aclError (*)(const char *); +using AclmdlInitDumpFuncType = aclError (*)(); +using AclmdlSetDumpFuncType = aclError (*)(const char *); +using AclmdlFinalizeDumpFuncType = aclError (*)(); +using AcldumpRegCallbackFuncType = aclError (*)(AclDumpCallbackFuncType, int32_t); +using AclrtSynchronizeDeviceFuncType = aclError (*)(); + +static AclInitFuncType g_aclInitFunc = nullptr; +static AclmdlInitDumpFuncType g_aclmdlInitDumpFunc = nullptr; +static AclmdlSetDumpFuncType g_aclmdlSetDumpFunc = nullptr; +static AclmdlFinalizeDumpFuncType g_aclmdlFinalizeDumpFunc = nullptr; +static AcldumpRegCallbackFuncType g_acldumpRegCallbackFunc = nullptr; +static AcldumpRegCallbackFuncType g_acldumpRegCallbackFuncInSo = nullptr; +static AclrtSynchronizeDeviceFuncType g_aclrtSynchronizeDeviceFunc = nullptr; + +static const std::map functionMap = { + {"aclInit", reinterpret_cast(&g_aclInitFunc)}, + {"aclmdlInitDump", reinterpret_cast(&g_aclmdlInitDumpFunc)}, + {"aclmdlSetDump", reinterpret_cast(&g_aclmdlSetDumpFunc)}, + {"aclmdlFinalizeDump", reinterpret_cast(&g_aclmdlFinalizeDumpFunc)}, + {"aclrtSynchronizeDevice", reinterpret_cast(&g_aclrtSynchronizeDeviceFunc)}, +}; DebuggerErrno LoadAclApi() { @@ -52,25 +62,15 @@ DebuggerErrno LoadAclApi() return DebuggerErrno::OK; } - hLibAscendcl = dlopen(kLibAscendclName, RTLD_LAZY); + hLibAscendcl = dlopen(LIB_ASCEND_CL_NAME, RTLD_LAZY | RTLD_NOLOAD); if (hLibAscendcl == nullptr) { LOG_ERROR(DebuggerErrno::ERROR_DEPENDENCY_NOT_FIND, "Failed to search libascendcl.so." + std::string(dlerror())); return DebuggerErrno::ERROR_DEPENDENCY_NOT_FIND; } - static const std::map functionMap = { - {"aclInit", reinterpret_cast(&aclInitFunc)}, - {"aclmdlInitDump", reinterpret_cast(&aclmdlInitDumpFunc)}, - {"aclmdlSetDump", reinterpret_cast(&aclmdlSetDumpFunc)}, - {"aclmdlFinalizeDump", reinterpret_cast(&aclmdlFinalizeDumpFunc)}, - {"aclrtSynchronizeDevice", reinterpret_cast(&aclrtSynchronizeDeviceFunc)}, - }; - for (auto& iter : functionMap) { - if (*(iter.second) != nullptr) { - continue; - } + if (*(iter.second) != nullptr) { continue; } *(iter.second) = dlsym(hLibAscendcl, iter.first); if (*(iter.second) == nullptr) { LOG_ERROR(DebuggerErrno::ERROR_DEPENDENCY_NOT_FIND, "Failed to load function " + @@ -83,74 +83,96 @@ DebuggerErrno LoadAclApi() } /* 规避adump的bug,mindspore场景优先使用libmindspore_ascend.so中的符号 */ - void* handler = dlopen(kLibMSAscendName, RTLD_LAZY); - std::string libName = kLibMSAscendName; + void* handler = dlopen(LIB_MS_ASCEND_NAME, RTLD_LAZY | RTLD_NOLOAD); + std::string libName = LIB_MS_ASCEND_NAME; if (handler == nullptr) { handler = hLibAscendcl; - libName = kLibAscendclName; + libName = LIB_ASCEND_CL_NAME; } - acldumpRegCallbackFunc = reinterpret_cast(dlsym(handler, "acldumpRegCallback")); - if (acldumpRegCallbackFunc == nullptr) { - LOG_ERROR(DebuggerErrno::ERROR_DEPENDENCY_NOT_FIND, "Failed to load function acldumpRegCallback from " + - libName + "."); + g_acldumpRegCallbackFunc = reinterpret_cast(dlsym(handler, "acldumpRegCallback")); + if (g_acldumpRegCallbackFunc == nullptr) { + LOG_WARNING(DebuggerErrno::ERROR_DEPENDENCY_NOT_FIND, "Failed to load function acldumpRegCallback from " + + libName + "."); } - LOG_DEBUG("Load function acldumpRegCallback from " + libName); - - if (handler != hLibAscendcl) { - dlclose(handler); + LOG_DEBUG("Load function acldumpRegCallback from " + libName + "."); + + if (handler != hLibAscendcl) { dlclose(handler); } + + void* dumpHandler = dlopen(LIB_ASCEND_DUMP_NAME, RTLD_LAZY | RTLD_NOLOAD); + if (dumpHandler == nullptr) { + LOG_WARNING(DebuggerErrno::ERROR_DEPENDENCY_NOT_FIND, "Failed to load libascend_dump.so."); + } else { + g_acldumpRegCallbackFuncInSo = reinterpret_cast(dlsym(dumpHandler, "acldumpRegCallback")); + if (g_acldumpRegCallbackFuncInSo == nullptr) { + LOG_WARNING(DebuggerErrno::ERROR_DEPENDENCY_NOT_FIND, + "Failed to load function acldumpRegCallback from libascend_dump.so."); + } + LOG_DEBUG("Load function acldumpRegCallback from libascend_dump.so."); + dlclose(dumpHandler); } return DebuggerErrno::OK; } -aclError ACLAPI_aclInit(const char* cfg) +aclError AclApiAclInit(const char* cfg) { - if (aclInitFunc == nullptr) { + if (g_aclInitFunc == nullptr) { throw std::runtime_error("API aclInit does not have a definition."); } - return aclInitFunc(cfg); + return g_aclInitFunc(cfg); } -aclError ACLAPI_aclmdlInitDump() +aclError AclApiAclmdlInitDump() { - if (aclmdlInitDumpFunc == nullptr) { + if (g_aclmdlInitDumpFunc == nullptr) { throw std::runtime_error("API aclmdlInitDump does not have a definition."); } - return aclmdlInitDumpFunc(); + return g_aclmdlInitDumpFunc(); } -aclError ACLAPI_aclmdlSetDump(const char* cfg) +aclError AclApiAclmdlSetDump(const char* cfg) { - if (aclmdlSetDumpFunc == nullptr) { + if (g_aclmdlSetDumpFunc == nullptr) { throw std::runtime_error("API aclmdlSetDump does not have a definition."); } - return aclmdlSetDumpFunc(cfg); + return g_aclmdlSetDumpFunc(cfg); } -aclError ACLAPI_aclmdlFinalizeDump() +aclError AclApiAclmdlFinalizeDump() { - if (aclmdlFinalizeDumpFunc == nullptr) { + if (g_aclmdlFinalizeDumpFunc == nullptr) { throw std::runtime_error("API aclmdlFinalizeDump does not have a definition."); } - return aclmdlFinalizeDumpFunc(); + return g_aclmdlFinalizeDumpFunc(); } -aclError ACLAPI_acldumpRegCallback(AclDumpCallbackFuncType messageCallback, int32_t flag) +aclError AclApiAcldumpRegCallback(AclDumpCallbackFuncType messageCallback, int32_t flag) { - if (acldumpRegCallbackFunc == nullptr) { + if (g_acldumpRegCallbackFunc == nullptr && g_acldumpRegCallbackFuncInSo == nullptr) { throw std::runtime_error("API acldumpRegCallback does not have a definition."); } - return acldumpRegCallbackFunc(messageCallback, flag); + aclError staticAclRet = -1; + aclError dynamicAclRet = -1; + if (g_acldumpRegCallbackFunc != nullptr) { + staticAclRet = g_acldumpRegCallbackFunc(messageCallback, flag); + } + if (g_acldumpRegCallbackFuncInSo != nullptr) { + dynamicAclRet = g_acldumpRegCallbackFuncInSo(messageCallback, flag); + } + if (staticAclRet != ACL_SUCCESS && dynamicAclRet != ACL_SUCCESS) { + return dynamicAclRet; + } + return ACL_SUCCESS; } -aclError ACLAPI_aclrtSynchronizeDevice() +aclError AclApiAclrtSynchronizeDevice() { - if (aclrtSynchronizeDeviceFunc == nullptr) { + if (g_aclrtSynchronizeDeviceFunc == nullptr) { throw std::runtime_error("API aclrtSynchronizeDevice does not have a definition."); } - return aclrtSynchronizeDeviceFunc(); + return g_aclrtSynchronizeDeviceFunc(); } -} +} } diff --git a/debug/accuracy_tools/msprobe/ccsrc/third_party/ACL/AclApi.hpp b/debug/accuracy_tools/msprobe/ccsrc/third_party/ACL/AclApi.h similarity index 78% rename from debug/accuracy_tools/msprobe/ccsrc/third_party/ACL/AclApi.hpp rename to debug/accuracy_tools/msprobe/ccsrc/third_party/ACL/AclApi.h index 731ae2e2caacaa345605ec572c8dcd6dba091488..25804b84c9d8b8e432b71026753cc2f5929b88b0 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/third_party/ACL/AclApi.hpp +++ b/debug/accuracy_tools/msprobe/ccsrc/third_party/ACL/AclApi.h @@ -18,25 +18,23 @@ #include -#include "include/ErrorCode.hpp" +#include "include/ErrorCode.h" extern "C" { - -typedef int aclError; +using aclError = int; constexpr int ACL_SUCCESS = 0; constexpr int ACL_ERROR_NONE = 0; constexpr int ACL_ERROR_REPEAT_INITIALIZE = 100002; #define ACL_DUMP_MAX_FILE_PATH_LENGTH 4096 -typedef struct acldumpChunk { +typedef struct AclDumpChunk { char fileName[ACL_DUMP_MAX_FILE_PATH_LENGTH]; // 待落盘的Dump数据文件名,ACL_DUMP_MAX_FILE_PATH_LENGTH表示文件名最大长度,当前为4096 uint32_t bufLen; // dataBuf数据长度,单位Byte uint32_t isLastChunk; // 标识Dump数据是否为最后一个分片,0表示不是最后一个分片,1表示最后一个分片 int64_t offset; // Dump数据文件内容的偏移,其中-1表示文件追加内容 int32_t flag; // 预留Dump数据标识,当前数据无标识 uint8_t dataBuf[0]; // Dump数据的内存地址 -} acldumpChunk; - +} AclDumpChunk; } namespace MindStudioDebugger { @@ -44,16 +42,16 @@ namespace AscendCLApi { DebuggerErrno LoadAclApi(); -using AclDumpCallbackFuncType = int32_t (*)(const acldumpChunk*, int32_t); -aclError ACLAPI_aclInit(const char* cfg); -aclError ACLAPI_aclmdlInitDump(); -aclError ACLAPI_aclmdlSetDump(const char* cfg); -aclError ACLAPI_aclmdlFinalizeDump(); -aclError ACLAPI_acldumpRegCallback(AclDumpCallbackFuncType messageCallback, int32_t flag); +using AclDumpCallbackFuncType = int32_t (*)(const AclDumpChunk*, int32_t); +aclError AclApiAclInit(const char* cfg); +aclError AclApiAclmdlInitDump(); +aclError AclApiAclmdlSetDump(const char* cfg); +aclError AclApiAclmdlFinalizeDump(); +aclError AclApiAcldumpRegCallback(AclDumpCallbackFuncType messageCallback, int32_t flag); -aclError ACLAPI_aclrtSynchronizeDevice(); +aclError AclApiAclrtSynchronizeDevice(); -#define CALL_ACL_API(func, ...) MindStudioDebugger::AscendCLApi::ACLAPI_##func(__VA_ARGS__) +#define CALL_ACL_API(func, ...) MindStudioDebugger::AscendCLApi::AclApi##func(__VA_ARGS__) } } diff --git a/debug/accuracy_tools/msprobe/ccsrc/utils/CPythonUtils.cpp b/debug/accuracy_tools/msprobe/ccsrc/utils/CPythonUtils.cpp index fd944f62db4ff728d1aa2c5d1d5ff818bd5dcf62..1122a348311697ee199fd11607be567db35eb5da 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/utils/CPythonUtils.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/utils/CPythonUtils.cpp @@ -18,7 +18,7 @@ #include #include -#include "CPythonUtils.hpp" +#include "CPythonUtils.h" namespace MindStudioDebugger { namespace CPythonUtils { @@ -77,7 +77,6 @@ PythonObject PythonObject::From(const uint32_t& input) PythonObject PythonObject::From(const double& input) { return PythonNumberObject::From(input); - } PythonObject PythonObject::From(const std::string& input) { @@ -108,7 +107,7 @@ int32_t PythonObject::To(uint32_t& output) const if (!PyLong_Check(ptr)) { return -1; } - output = static_cast(PyLong_AsUnsignedLong(ptr)); + output = static_cast(PyLong_AsUnsignedLong(ptr)); return 0; } @@ -155,7 +154,7 @@ PythonObject PythonObject::Get(const std::string& name, bool ignore) const return ret; } -PythonObject PythonObject::Call(bool ignore) +PythonObject PythonObject::Call(bool ignore) noexcept { if (!PyCallable_Check(ptr)) { if (!ignore) { @@ -173,7 +172,7 @@ PythonObject PythonObject::Call(bool ignore) return ret; } -PythonObject PythonObject::Call(PythonTupleObject& args, bool ignore) +PythonObject PythonObject::Call(PythonTupleObject& args, bool ignore) noexcept { if (!PyCallable_Check(ptr)) { if (!ignore) { @@ -182,7 +181,7 @@ PythonObject PythonObject::Call(PythonTupleObject& args, bool ignore) return PythonObject(); } - PyObject* o = PyObject_CallObject(ptr, args.IsNone() ? nullptr : args); + PyObject* o = PyObject_CallObject(ptr, args.IsNone() ? nullptr : reinterpret_cast(&args)); if (o == nullptr && ignore) { PyErr_Clear(); } @@ -191,7 +190,7 @@ PythonObject PythonObject::Call(PythonTupleObject& args, bool ignore) return ret; } -PythonObject PythonObject::Call(PythonTupleObject& args, PythonDictObject& kwargs, bool ignore) +PythonObject PythonObject::Call(PythonTupleObject& args, PythonDictObject& kwargs, bool ignore) noexcept { if (!PyCallable_Check(ptr)) { if (!ignore) { @@ -203,7 +202,7 @@ PythonObject PythonObject::Call(PythonTupleObject& args, PythonDictObject& kwarg if (args.IsNone() || kwargs.IsNone()) { if (!ignore) { PyErr_SetString(PyExc_TypeError, "Call python object with invalid parameters."); - } + } return PythonObject(); } @@ -227,10 +226,9 @@ PythonObject PythonObject::GetGlobal(const std::string& name, bool ignore) } return PythonObject(PyDict_GetItemString(globals, name.c_str())); - } -PythonObject PythonObject::Import(const std::string& name, bool ignore) +PythonObject PythonObject::Import(const std::string& name, bool ignore) noexcept { PyObject* m = PyImport_ImportModule(name.c_str()); if (m == nullptr) { @@ -483,7 +481,7 @@ PythonTupleObject::PythonTupleObject() : PythonObject() PythonTupleObject::PythonTupleObject(PyObject* o) : PythonObject() { - if (!PyTuple_Check(o)) { + if (!o || !PyTuple_Check(o)) { return; } diff --git a/debug/accuracy_tools/msprobe/ccsrc/utils/CPythonUtils.hpp b/debug/accuracy_tools/msprobe/ccsrc/utils/CPythonUtils.h similarity index 88% rename from debug/accuracy_tools/msprobe/ccsrc/utils/CPythonUtils.hpp rename to debug/accuracy_tools/msprobe/ccsrc/utils/CPythonUtils.h index 40ebcb1dafd505fd7dfa3bda1c2c1609cb60297a..fdf2a4236a1a960f4ab48e13178fa4e86293c7a0 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/utils/CPythonUtils.hpp +++ b/debug/accuracy_tools/msprobe/ccsrc/utils/CPythonUtils.h @@ -40,14 +40,14 @@ namespace CPythonUtils { * | tuple | PythonTupleObject | * | dict | PythonDictObject | * ------------------------------------------- - * + * * 创建对象的方式: * 1、通过原生PyObject*类型创建,PythonObject生命周期内会持有原生对象的一个引用 * 2、通过From方法从c++对象创建 * 3、通过GetGlobal、Import等方法从解释器上下文获取 * 4、通过GetRegisteredPyObj获取到上下文的python对象 * 5、通过已有PythonObject对象的Get、GetItem等方法获取子对象 - * + * * 对象转换: * 1、对于转换成PyObject*、bool、string的场景,支持隐式转换 * 2、对于非通用类型转换,调用To方法,返回0表示成功 @@ -56,7 +56,7 @@ namespace CPythonUtils { * python维度支持bool()的都可以转bool(即并非只有bool类型支持转换,下同) * 支持str()的都可以转string * 可迭代对象(且元素支持转换)都可以转vector - * + * * 对象传递: * 1、子类可以安全传递或拷贝给PythonObject对象 * 2、PythonObject传给子类时,若类型匹配,可以安全转递,否则会转为None @@ -81,29 +81,33 @@ PythonObject GetRegisteredPyObj(const std::string& name); class PythonObject { public: - PythonObject() { + PythonObject() + { Py_INCREF(Py_None); ptr = Py_None; } - PythonObject(PyObject* o) : ptr(o) { + PythonObject(PyObject* o) : ptr(o) + { if (ptr == nullptr) { ptr = Py_None; } Py_XINCREF(ptr); } - ~PythonObject() { + ~PythonObject() + { Py_XDECREF(ptr); } explicit PythonObject(const PythonObject &obj) : PythonObject(static_cast(obj)) {} - PythonObject& operator=(const PythonObject &obj) { + PythonObject& operator=(const PythonObject &obj) + { SetPtr(static_cast(obj)); return *this; } /* 获取全局对象 */ - static PythonObject GetGlobal(const std::string& name, bool ignore=true); + static PythonObject GetGlobal(const std::string& name, bool ignore = true); /* 获取模块对象;若其还未加载至缓存,则加载一遍 */ - static PythonObject Import(const std::string& name, bool ignore=true); + static PythonObject Import(const std::string& name, bool ignore = true) noexcept; /* From/To转换,统一放一份在基类,用于遍历迭代器等场景 */ static PythonObject From(const PythonObject& input); @@ -136,17 +140,19 @@ public: bool IsCallable() const {return PyCallable_Check(ptr);} /* 用于调用可调用对象,相当于python代码中的obj(),为了简单只实现了args+kwargs参数形式 */ - PythonObject Call(bool ignore=true); - PythonObject Call(PythonTupleObject& args, bool ignore=true); - PythonObject Call(PythonTupleObject& args, PythonDictObject& kwargs, bool ignore=true); + PythonObject Call(bool ignore = true) noexcept; + PythonObject Call(PythonTupleObject& args, bool ignore = true) noexcept; + PythonObject Call(PythonTupleObject& args, PythonDictObject& kwargs, bool ignore = true) noexcept; /* 用于获取对象属性,相当于python代码中的obj.xx */ - PythonObject Get(const std::string& name, bool ignore=true) const; - PythonObject& NewRef() { + PythonObject Get(const std::string& name, bool ignore = true) const; + PythonObject& NewRef() + { Py_XINCREF(ptr); return *this; } - std::string ToString() const { + std::string ToString() const + { std::string ret; if (To(ret) == 0) { return ret; @@ -156,21 +162,24 @@ public: operator PyObject*() const {return ptr;} operator bool() const {return static_cast(PyObject_IsTrue(ptr));} - operator std::string() const { + operator std::string() const + { return ToString(); } - PythonObject operator()(bool ignore=true) {return Call(ignore);} - PythonObject operator()(PythonTupleObject& args, bool ignore=true) {return Call(args, ignore);} - PythonObject operator()(PythonTupleObject& args, PythonDictObject& kwargs, bool ignore=true) { + PythonObject operator()(bool ignore = true) {return Call(ignore);} + PythonObject operator()(PythonTupleObject& args, bool ignore = true) {return Call(args, ignore);} + PythonObject operator()(PythonTupleObject& args, PythonDictObject& kwargs, bool ignore = true) + { return Call(args, kwargs, ignore); } protected: - void SetPtr(PyObject* o) { + void SetPtr(PyObject* o) + { Py_XDECREF(ptr); if (o == nullptr) { o = Py_None; - } + } Py_INCREF(o); ptr = o; } @@ -185,7 +194,7 @@ private: class PythonNumberObject : public PythonObject { public: PythonNumberObject(); - PythonNumberObject(PyObject* o); + explicit PythonNumberObject(PyObject* o); static PythonNumberObject From(const int32_t& input); static PythonNumberObject From(const uint32_t& input); @@ -195,7 +204,7 @@ public: class PythonStringObject : public PythonObject { public: PythonStringObject(); - PythonStringObject(PyObject* o); + explicit PythonStringObject(PyObject* o); static PythonStringObject From(const std::string& input); static PythonStringObject From(const char* input); @@ -204,7 +213,7 @@ public: class PythonBoolObject : public PythonObject { public: PythonBoolObject(); - PythonBoolObject(PyObject* o); + explicit PythonBoolObject(PyObject* o); static PythonBoolObject From(const bool& input); }; @@ -213,46 +222,46 @@ class PythonListObject : public PythonObject { public: PythonListObject(); explicit PythonListObject(size_t size); - PythonListObject(PyObject* o); + explicit PythonListObject(PyObject* o); template static PythonListObject From(const std::vector& input); size_t Size() const; template - PythonListObject& Append(T value, bool ignore=true); - PythonObject GetItem(size_t pos, bool ignore=true); - PythonListObject& SetItem(size_t pos, PythonObject& item, bool ignore=true); - PythonListObject& Insert(int64_t pos, PythonObject& item, bool ignore=true); - PythonTupleObject ToTuple(bool ignore=true); + PythonListObject& Append(T value, bool ignore = true); + PythonObject GetItem(size_t pos, bool ignore = true); + PythonListObject& SetItem(size_t pos, PythonObject& item, bool ignore = true); + PythonListObject& Insert(int64_t pos, PythonObject& item, bool ignore = true); + PythonTupleObject ToTuple(bool ignore = true); }; class PythonTupleObject : public PythonObject { public: PythonTupleObject(); - PythonTupleObject(PyObject* o); + explicit PythonTupleObject(PyObject* o); template static PythonTupleObject From(const std::vector& input); size_t Size() const; - PythonObject GetItem(size_t pos, bool ignore=true); + PythonObject GetItem(size_t pos, bool ignore = true); }; class PythonDictObject : public PythonObject { public: PythonDictObject(); - PythonDictObject(PyObject* o); + explicit PythonDictObject(PyObject* o); template static PythonDictObject From(const std::map& input); template - PythonDictObject& Add(T1 key, T2 value, bool ignore=true); + PythonDictObject& Add(T1 key, T2 value, bool ignore = true); template - PythonDictObject& Delete(T key, bool ignore=true); + PythonDictObject& Delete(T key, bool ignore = true); template - PythonObject GetItem(T key, bool ignore=true); + PythonObject GetItem(T key, bool ignore = true); }; /**************************************************************************************************/ @@ -282,7 +291,9 @@ int32_t PythonObject::To(std::vector& output) const while ((item = PyIter_Next(iter)) != nullptr) { T tmp; if (PythonObject(item).To(tmp) != 0) { - goto error; + Py_DECREF(item); + Py_DECREF(iter); + return -1; } output.emplace_back(tmp); Py_DECREF(item); @@ -290,10 +301,6 @@ int32_t PythonObject::To(std::vector& output) const Py_DECREF(iter); return 0; -error: - Py_DECREF(item); - Py_DECREF(iter); - return -1; } template diff --git a/debug/accuracy_tools/msprobe/ccsrc/utils/DataUtils.cpp b/debug/accuracy_tools/msprobe/ccsrc/utils/DataUtils.cpp index c2d7df85294f7c96f0fe1a1b9458dfd2ad2e502c..4918ade610d5882e98d61e4a4cd3bba0fd26b61d 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/utils/DataUtils.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/utils/DataUtils.cpp @@ -21,19 +21,21 @@ #include #include -#include "DataUtils.hpp" +#include "DataUtils.h" namespace MindStudioDebugger { namespace DataUtils { -int64_t SizeToS64(size_t v) { +int64_t SizeToS64(size_t v) +{ if (v > static_cast(INT64_MAX)) { throw std::runtime_error("Value " + std::to_string(v) + "exceeds the maximum value of int64."); } return static_cast(v); } -std::string U64ToHexString(uint64_t v) { +std::string U64ToHexString(uint64_t v) +{ std::stringstream ss; ss << "0x" << std::hex << std::uppercase << v; return std::move(ss.str()); @@ -42,28 +44,33 @@ std::string U64ToHexString(uint64_t v) { BFloat16::BFloat16(float f32) { if (std::isnan(f32)) { - value_ = BFloat16::nan_value; + value_ = BFloat16::NAN_VALUE; } else { + constexpr uint8_t offsetSize = 16; union { - uint32_t U32; - float F32; + uint32_t u32Value; + float f32Value; }; - F32 = f32; - uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); - value_ = static_cast((U32 + rounding_bias) >> 16); + f32Value = f32; + uint32_t rounding_bias = ((u32Value >> offsetSize) & 1) + UINT32_C(0x7FFF); + value_ = static_cast((u32Value + rounding_bias) >> offsetSize); } } BFloat16::operator float() const { - float f32 = 0; - uint32_t tmp = value_; - tmp <<= 16; - std::memcpy(&f32, &tmp, sizeof(f32)); + /* 为了兼容性,不要用c++20的bit_cast */ + constexpr uint8_t offsetSize = 16; + union { + float f32; + uint32_t ui32; + }; + ui32 = static_cast(value_); + ui32 <<= offsetSize; // 将ui32左移16位 return f32; } -const static std::unordered_map kTypeSizeMap = { +constexpr std::pair TYPE_SIZE_ARRAY[] = { {DataType::DT_BOOL, 1}, {DataType::DT_INT8, 1}, {DataType::DT_UINT8, 1}, @@ -83,15 +90,16 @@ const static std::unordered_map kTypeSizeMap = { size_t SizeOfDType(DataType type) { - auto it = kTypeSizeMap.find(type); - if (it == kTypeSizeMap.end()) { - return 0; + for (const auto& pair : TYPE_SIZE_ARRAY) { + if (pair.first == type) { + return pair.second; + } } - return it->second; + return 0; } -constexpr auto kOpDType_UNKNOWN = "UNKNOWN"; -const static std::unordered_map kDDTypeToStringMap = { +constexpr auto OP_DTYPE_UNKNOWN = "UNKNOWN"; +const std::pair DTYPE_TO_STRING_ARRAY[] = { {DataType::DT_UNDEFINED, "UNDEFINED"}, {DataType::DT_FLOAT, "FLOAT"}, {DataType::DT_FLOAT16, "FLOAT16"}, @@ -128,15 +136,16 @@ const static std::unordered_map kDDTypeToStringMap = { std::string GetDTypeString(DataType dtype) { - auto it = kDDTypeToStringMap.find(dtype); - if (it != kDDTypeToStringMap.end()) { - return it->second; + for (const auto& pair : DTYPE_TO_STRING_ARRAY) { + if (pair.first == dtype) { + return std::string(pair.second); + } } - return kOpDType_UNKNOWN; + return OP_DTYPE_UNKNOWN; } -constexpr auto kOpFormat_UNKNOWN = "UNKNOWN"; -const static std::unordered_map kFormatToStringMap = { +constexpr auto OP_FORMAT_UNKNOWN = "UNKNOWN"; +const std::pair FORMAT_TO_STRING_ARRAY[] = { {TensorFormat::FORMAT_NCHW, "NCHW"}, {TensorFormat::FORMAT_NHWC, "NHWC"}, {TensorFormat::FORMAT_ND, "ND"}, @@ -162,7 +171,7 @@ const static std::unordered_map kFormatToStringMap = {TensorFormat::FORMAT_HASHTABLE_LOOKUP_VALUE, "HASHTABLE_LOOKUP_VALUE"}, {TensorFormat::FORMAT_HASHTABLE_LOOKUP_OUTPUT, "HASHTABLE_LOOKUP_OUTPUT"}, {TensorFormat::FORMAT_HASHTABLE_LOOKUP_HITS, "HASHTABLE_LOOKUP_HITS"}, - {TensorFormat::FORMAT_C1HWNCoC0, "C1HWNCoC0"}, + {TensorFormat::FORMAT_C1HWNCOC0, "C1HWNCoC0"}, {TensorFormat::FORMAT_MD, "MD"}, {TensorFormat::FORMAT_NDHWC, "NDHWC"}, {TensorFormat::FORMAT_FRACTAL_ZZ, "FRACTAL_ZZ"}, @@ -191,11 +200,12 @@ const static std::unordered_map kFormatToStringMap = std::string GetFormatString(TensorFormat fmt) { - auto it = kFormatToStringMap.find(fmt); - if (it != kFormatToStringMap.end()) { - return it->second; + for (const auto& pair : FORMAT_TO_STRING_ARRAY) { + if (pair.first == fmt) { + return std::string(pair.second); + } } - return kOpFormat_UNKNOWN; + return OP_FORMAT_UNKNOWN; } std::string GetShapeString(const TensorShape& shape) diff --git a/debug/accuracy_tools/msprobe/ccsrc/utils/DataUtils.hpp b/debug/accuracy_tools/msprobe/ccsrc/utils/DataUtils.h similarity index 89% rename from debug/accuracy_tools/msprobe/ccsrc/utils/DataUtils.hpp rename to debug/accuracy_tools/msprobe/ccsrc/utils/DataUtils.h index f58e15a8c77719f62ddeef8ebbcd25a5b5ebf624..f1cd9239ceaf27a2a52a7ee1275b27231c7ff7cb 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/utils/DataUtils.hpp +++ b/debug/accuracy_tools/msprobe/ccsrc/utils/DataUtils.h @@ -14,7 +14,8 @@ * limitations under the License. */ -#pragma once +#ifndef DATAUTILS_H +#define DATAUTILS_H #include #include @@ -24,11 +25,11 @@ namespace MindStudioDebugger { namespace DataUtils { -inline uint64_t UnpackUint64Value_Le(const void* data) +inline uint64_t UnpackUint64ValueLe(const void* data) { return le64toh(*reinterpret_cast(data)); } -inline uint64_t UnpackUint64Value_Be(const void* data) +inline uint64_t UnpackUint64ValueBe(const void* data) { return be64toh(*reinterpret_cast(data)); } @@ -38,11 +39,11 @@ std::string U64ToHexString(uint64_t v); class BFloat16 { public: - static constexpr uint16_t value_mask = 0x7fff; - static constexpr uint16_t inf_value = 0x7f80; - static constexpr uint16_t nan_value = 0x7fc0; - static constexpr uint16_t true_value = 0x3c00; - static constexpr uint32_t f32_inf_value = 0x7f800000; + static constexpr uint16_t VALUE_MASK = 0x7fff; + static constexpr uint16_t INF_VALUE = 0x7f80; + static constexpr uint16_t NAN_VALUE = 0x7fc0; + static constexpr uint16_t TRUE_VALUE = 0x3c00; + static constexpr uint32_t F32_INF_VALUE = 0x7f800000; BFloat16() = default; ~BFloat16() = default; @@ -51,7 +52,7 @@ public: BFloat16 &operator=(const BFloat16 &other) noexcept = default; BFloat16 &operator=(BFloat16 &&other) noexcept = default; - explicit BFloat16(float f); + explicit BFloat16(float f32); explicit operator float() const; BFloat16 operator+(const BFloat16& other) const { return BFloat16(static_cast(*this) + static_cast(other)); } @@ -131,7 +132,7 @@ enum TensorFormat : int { FORMAT_HASHTABLE_LOOKUP_VALUE = 22, FORMAT_HASHTABLE_LOOKUP_OUTPUT = 23, FORMAT_HASHTABLE_LOOKUP_HITS = 24, - FORMAT_C1HWNCoC0 = 25, + FORMAT_C1HWNCOC0 = 25, FORMAT_MD = 26, FORMAT_NDHWC = 27, FORMAT_FRACTAL_ZZ = 28, @@ -166,4 +167,6 @@ std::string GetFormatString(TensorFormat fmt); std::string GetShapeString(const TensorShape& shape); } -} \ No newline at end of file +} + +#endif \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/ccsrc/utils/FileOperation.cpp b/debug/accuracy_tools/msprobe/ccsrc/utils/FileOperation.cpp index 7f025e568abdfe95830902d1e72bdb77300f7de5..029f463ea9c0f479beb12f82fa7b2b289fc44338 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/utils/FileOperation.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/utils/FileOperation.cpp @@ -18,9 +18,9 @@ #include #include -#include "FileUtils.hpp" -#include "DataUtils.hpp" -#include "FileOperation.hpp" +#include "FileUtils.h" +#include "DataUtils.h" +#include "FileOperation.h" namespace MindStudioDebugger { namespace FileOperation { @@ -34,7 +34,8 @@ struct NpyDtypeDescr { char type; size_t length; - std::string str() const { + std::string Str() const + { std::ostringstream buffer; buffer << "\'" << byteorder << type << length << "\'"; return buffer.str(); @@ -42,9 +43,9 @@ struct NpyDtypeDescr { }; // npy file header start information -constexpr char kNpyMagicPrefix[] = "\x93NUMPY"; -constexpr size_t kNpyMagicLen = sizeof(kNpyMagicPrefix) - 1; -constexpr size_t kNpyArrayAlign = 64; +constexpr char NPY_MAGIC_PREFIX[] = "\x93NUMPY"; +constexpr size_t NPY_MAGIC_LEN = sizeof(NPY_MAGIC_PREFIX) - 1; +constexpr size_t NPY_ARRAY_ALIGN = 64; static const std::unordered_map npyTypeDescMap = { {DataType::DT_BOOL, NpyDtypeDescr{'|', 'b', 1}}, {DataType::DT_INT8, NpyDtypeDescr{'|', 'i', 1}}, {DataType::DT_INT16, NpyDtypeDescr{'<', 'i', 2}}, {DataType::DT_INT32, NpyDtypeDescr{'<', 'i', 4}}, @@ -90,7 +91,8 @@ inline static std::string NpyTransShapeToStr(const DataUtils::TensorShape &shape return buffer.str(); } -inline static std::vector NpyLen2Bytes(size_t length, size_t lengthLen) { +inline static std::vector NpyLen2Bytes(size_t length, size_t lengthLen) +{ std::vector buff; lengthLen = std::min(lengthLen, static_cast(sizeof(length))); for (size_t i = 0; i < lengthLen; i++) { @@ -100,7 +102,8 @@ inline static std::vector NpyLen2Bytes(size_t length, size_t lengthLen) { return buff; } -static std::string GenerateNpyHeader(const DataUtils::TensorShape &shape, DataUtils::DataType dt, bool fortranOrder=false) +static std::string GenerateNpyHeader(const DataUtils::TensorShape &shape, + DataUtils::DataType dt, bool fortranOrder = false) { auto typeDesc = npyTypeDescMap.find(dt); if (typeDesc == npyTypeDescMap.end()) { @@ -111,7 +114,7 @@ static std::string GenerateNpyHeader(const DataUtils::TensorShape &shape, DataUt std::string fortranOrderStr = fortranOrder ? "True" : "False" ; buffer << "{"; - buffer << "'descr': " << typeDesc->second.str() << ", "; + buffer << "'descr': " << typeDesc->second.Str() << ", "; buffer << "'fortran_order': " << fortranOrderStr << ", "; buffer << "'shape': " << NpyTransShapeToStr(shape) << ", "; buffer << "}"; @@ -125,19 +128,19 @@ static std::string GenerateNpyHeader(const DataUtils::TensorShape &shape, DataUt constexpr const size_t lengthLenV2 = 4; size_t lengthLen = lengthLenV1; - size_t totalLen = kNpyMagicLen + versionLen + lengthLen + headerLen + 1; + size_t totalLen = NPY_MAGIC_LEN + versionLen + lengthLen + headerLen + 1; if (totalLen > maxLen) { version = {2, 0}; lengthLen = lengthLenV2; - totalLen = kNpyMagicLen + versionLen + lengthLen + headerLen + 1; + totalLen = NPY_MAGIC_LEN + versionLen + lengthLen + headerLen + 1; } - const size_t padLen = kNpyArrayAlign - totalLen % kNpyArrayAlign; + const size_t padLen = NPY_ARRAY_ALIGN - totalLen % NPY_ARRAY_ALIGN; const size_t paddingHeaderLen = headerLen + padLen + 1; const std::string padding(padLen, ' '); std::vector lengthBytes = NpyLen2Bytes(paddingHeaderLen, lengthLen); std::ostringstream out; - out.write(kNpyMagicPrefix, DataUtils::SizeToS64(kNpyMagicLen)); + out.write(NPY_MAGIC_PREFIX, DataUtils::SizeToS64(NPY_MAGIC_LEN)); out.put(version.first); out.put(version.second); out.write(lengthBytes.data(), DataUtils::SizeToS64(lengthBytes.size())); diff --git a/debug/accuracy_tools/msprobe/ccsrc/utils/FileOperation.hpp b/debug/accuracy_tools/msprobe/ccsrc/utils/FileOperation.h similarity index 95% rename from debug/accuracy_tools/msprobe/ccsrc/utils/FileOperation.hpp rename to debug/accuracy_tools/msprobe/ccsrc/utils/FileOperation.h index 3f89263ae3621d33f5bbc8a67e86887d8063067e..1560a1a6dba353f2e0122a639e46fa4c87195bba 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/utils/FileOperation.hpp +++ b/debug/accuracy_tools/msprobe/ccsrc/utils/FileOperation.h @@ -18,8 +18,8 @@ #include -#include "include/ErrorCode.hpp" -#include "DataUtils.hpp" +#include "include/ErrorCode.h" +#include "DataUtils.h" namespace MindStudioDebugger { diff --git a/debug/accuracy_tools/msprobe/ccsrc/utils/FileUtils.cpp b/debug/accuracy_tools/msprobe/ccsrc/utils/FileUtils.cpp index 246f899690ccd0e306f5b6b550870406086430cc..9b1cdb75d7a0b7e2e643ab7ab7280008784ff333 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/utils/FileUtils.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/utils/FileUtils.cpp @@ -27,8 +27,8 @@ #include #include -#include "include/ErrorCode.hpp" -#include "FileUtils.hpp" +#include "include/ErrorCode.h" +#include "FileUtils.h" /* 部分环境上c++版本比较老,这里不用filesystem库实现 */ @@ -38,7 +38,8 @@ namespace FileUtils { using namespace MindStudioDebugger; /********************* 基础检查函数库,不做过多校验,路径有效性由调用者保证 ******************/ -bool IsPathExist(const std::string& path) { +bool IsPathExist(const std::string& path) +{ struct stat buffer; return (stat(path.c_str(), &buffer) == 0); } @@ -60,7 +61,7 @@ static std::string GetFullPath(const std::string &originPath) } cwd = cwdBuf; - std::string fullPath = std::move(cwd + pathSeparator + originPath); + std::string fullPath = std::move(cwd + PATH_SEPARATOR + originPath); return fullPath; } @@ -84,8 +85,9 @@ std::vector SplitPath(const std::string &path, char separator) return tokens; } -std::string GetAbsPath(const std::string &originPath) { - std::string fullPath = GetFullPath(originPath); +std::string GetAbsPath(const std::string &originpath) +{ + std::string fullPath = GetFullPath(originpath); if (fullPath.empty()) { return ""; } @@ -118,7 +120,8 @@ std::string GetAbsPath(const std::string &originPath) { return resolvedPath; } -bool IsDir(const std::string& path) { +bool IsDir(const std::string& path) +{ struct stat buffer; if (stat(path.c_str(), &buffer) == 0) { return (buffer.st_mode & S_IFDIR) != 0; @@ -126,15 +129,17 @@ bool IsDir(const std::string& path) { return false; } -bool IsRegularFile(const std::string& path) { - struct stat path_stat; - if (stat(path.c_str(), &path_stat) == 0) { - return S_ISREG(path_stat.st_mode); +bool IsRegularFile(const std::string& path) +{ + struct stat pathStat; + if (stat(path.c_str(), &pathStat) == 0) { + return S_ISREG(pathStat.st_mode); } return false; } -bool IsFileSymbolLink(const std::string& path) { +bool IsFileSymbolLink(const std::string& path) +{ struct stat buffer; if (lstat(path.c_str(), &buffer) == 0) { if (S_ISLNK(buffer.st_mode)) { @@ -144,7 +149,8 @@ bool IsFileSymbolLink(const std::string& path) { return false; } -bool IsPathCharactersValid(const std::string& path) { +bool IsPathCharactersValid(const std::string& path) +{ for (const char& ch : path) { if (!std::isalnum(ch) && ch != '_' && ch != '.' && ch != ':' && ch != '/' && ch != '-') { return false; @@ -243,14 +249,15 @@ bool IsPathLengthLegal(const std::string& path) bool IsPathDepthValid(const std::string& path) { - return std::count(path.begin(), path.end(), pathSeparator) <= PATH_DEPTH_MAX; + auto depth = static_cast(std::count(path.begin(), path.end(), PATH_SEPARATOR)); + return depth <= PATH_DEPTH_MAX; } bool IsFileOwner(const std::string& path) { - struct stat file_stat; - if (stat(path.c_str(), &file_stat) == 0) { - if (file_stat.st_uid == getuid()) { + struct stat fileStat; + if (stat(path.c_str(), &fileStat) == 0) { + if (fileStat.st_uid == getuid()) { return true; } } @@ -258,7 +265,8 @@ bool IsFileOwner(const std::string& path) } /****************** 文件操作函数库,会对入参做基本检查 ************************/ -DebuggerErrno DeleteFile(const std::string &path) { +DebuggerErrno DeleteFile(const std::string &path) +{ if (!IsPathExist(path)) { return DebuggerErrno::OK; } @@ -306,7 +314,6 @@ static DebuggerErrno DeleteDirRec(const std::string &path, uint32_t depth) closedir(dir); return DebuggerErrno::ERROR_ILLEGAL_FILE_TYPE; } - } closedir(dir); @@ -321,7 +328,8 @@ static DebuggerErrno DeleteDirRec(const std::string &path, uint32_t depth) return DebuggerErrno::OK; } -DebuggerErrno DeleteDir(const std::string &path, bool recursion) { +DebuggerErrno DeleteDir(const std::string &path, bool recursion) +{ if (!IsPathExist(path)) { return DebuggerErrno::OK; } @@ -340,7 +348,8 @@ DebuggerErrno DeleteDir(const std::string &path, bool recursion) { return DebuggerErrno::OK; } -static DebuggerErrno CreateDirAux(const std::string& path, bool recursion, mode_t mode) { +static DebuggerErrno CreateDirAux(const std::string& path, bool recursion, mode_t mode) +{ std::string parent = GetParentDir(path); DebuggerErrno ret; @@ -404,16 +413,17 @@ DebuggerErrno Chmod(const std::string& path, const mode_t& mode) return chmod(absPath.c_str(), mode) == 0 ? DebuggerErrno::OK : DebuggerErrno::ERROR_SYSCALL_FAILED; } -DebuggerErrno GetFileSize(const std::string &path, size_t& size) { - struct stat path_stat; - if (stat(path.c_str(), &path_stat) != 0) { +DebuggerErrno GetFileSize(const std::string &path, size_t& size) +{ + struct stat pathStat; + if (stat(path.c_str(), &pathStat) != 0) { return DebuggerErrno::ERROR_FILE_NOT_EXISTS; } - if (!S_ISREG(path_stat.st_mode)) { + if (!S_ISREG(pathStat.st_mode)) { return DebuggerErrno::ERROR_ILLEGAL_FILE_TYPE; } - size = static_cast(path_stat.st_size); + size = static_cast(pathStat.st_size); return DebuggerErrno::OK; } @@ -455,8 +465,8 @@ DebuggerErrno OpenFile(const std::string& path, std::ofstream& ofs, std::ios::op } } - if (!IsPathExist(path)) { - int fd = open(path.c_str(), O_CREAT | O_WRONLY, permission); + if (!IsPathExist(realPath)) { + int fd = open(realPath.c_str(), O_CREAT | O_WRONLY, permission); if (fd < 0) { return DebuggerErrno::ERROR_FAILED_TO_OPEN_FILE; } @@ -600,63 +610,5 @@ DebuggerErrno CheckFileBeforeCreateOrWrite(const std::string &path, bool overwri } return DebuggerErrno::OK; } - -/* 其他文件操作工具 */ -static DebuggerErrno ListAllAux(const std::string &path, std::vector& output, uint32_t depth) -{ - if (depth > PATH_DEPTH_MAX) { - return DebuggerErrno::ERROR_PATH_TOO_DEEP; - } - - DIR* dir = opendir(path.c_str()); - if (dir == nullptr) { - return DebuggerErrno::ERROR_FAILED_TO_OPEN_FILE; - } - - DebuggerErrno ret = DebuggerErrno::OK; - size_t max = output.capacity(); - size_t num = output.size(); - if (num >= max) { - return DebuggerErrno::OK; - } - - struct dirent* entry = nullptr; - while ((entry = readdir(dir)) != nullptr) { - if (strcmp(entry->d_name, ".") == 0 || (strcmp(entry->d_name, "..") == 0)) { - continue; - } - std::string entryPath = path + "/" + entry->d_name; - if (entry->d_type == DT_DIR) { - ret = ListAllAux(entryPath, output, depth + 1); - if (ret != DebuggerErrno::OK) { - closedir(dir); - return ret; - } - } else if (entry->d_type == DT_REG) { - output.emplace_back(entryPath); - if (++num >= max) { - break; - } - } - } - closedir(dir); - return DebuggerErrno::OK; -} - -std::vector ListAll(const std::string &path, size_t max) -{ - std::vector ret; - std::string realPath = GetAbsPath(path); - if (CheckDirCommon(realPath) != DebuggerErrno::OK) { - return ret; - } - ret.reserve(max); - - uint32_t depth = std::count(realPath.begin(), realPath.end(), pathSeparator); - ListAllAux(realPath, ret, depth); - ret.resize(ret.size()); - return ret; -} - } } \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/ccsrc/utils/FileUtils.hpp b/debug/accuracy_tools/msprobe/ccsrc/utils/FileUtils.h similarity index 82% rename from debug/accuracy_tools/msprobe/ccsrc/utils/FileUtils.hpp rename to debug/accuracy_tools/msprobe/ccsrc/utils/FileUtils.h index 70b47137fc40fd7fb73be11ddb8d3551550e2b8d..61ae47f7c8389718adb210635631cf916418a1c2 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/utils/FileUtils.hpp +++ b/debug/accuracy_tools/msprobe/ccsrc/utils/FileUtils.h @@ -23,11 +23,11 @@ #include #include -#include "include/ErrorCode.hpp" +#include "include/ErrorCode.h" namespace MindStudioDebugger { -constexpr const char pathSeparator = '/'; +constexpr const char PATH_SEPARATOR = '/'; constexpr const uint32_t FULL_PATH_LENGTH_MAX = 4096; constexpr const uint32_t FILE_NAME_LENGTH_MAX = 255; constexpr const uint32_t PATH_DEPTH_MAX = 32; @@ -64,8 +64,8 @@ constexpr const uint32_t FILE_NAME_MAX = 255; /* 基础检查函数库,不做过多校验,路径有效性由调用者保证 */ bool IsPathExist(const std::string& path); -std::vector SplitPath(const std::string &path, char separator=pathSeparator); -std::string GetAbsPath(const std::string &path); +std::vector SplitPath(const std::string &path, char separator = PATH_SEPARATOR); +std::string GetAbsPath(const std::string &originpath); bool IsDir(const std::string& path); bool IsRegularFile(const std::string& path); bool IsFileSymbolLink(const std::string& path); @@ -85,23 +85,19 @@ bool IsFileOwner(const std::string& path); /* 文件操作函数库,会对入参做基本检查 */ DebuggerErrno DeleteFile(const std::string &path); -DebuggerErrno DeleteDir(const std::string &path, bool recursion=false); -DebuggerErrno CreateDir(const std::string &path, bool recursion=false, mode_t mode=NORMAL_DIR_MODE_DEFAULT); +DebuggerErrno DeleteDir(const std::string &path, bool recursion = false); +DebuggerErrno CreateDir(const std::string &path, bool recursion = false, mode_t mode = NORMAL_DIR_MODE_DEFAULT); DebuggerErrno Chmod(const std::string& path, const mode_t& mode); DebuggerErrno GetFileSize(const std::string &path, size_t& size); -DebuggerErrno OpenFile(const std::string& path, std::ifstream& ifs, std::ios::openmode mode=std::ios::in); -DebuggerErrno OpenFile(const std::string& path, std::ofstream& ofs, std::ios::openmode mode=std::ios::out, - mode_t permission=NORMAL_FILE_MODE_DEFAULT); +DebuggerErrno OpenFile(const std::string& path, std::ifstream& ifs, std::ios::openmode mode = std::ios::in); +DebuggerErrno OpenFile(const std::string& path, std::ofstream& ofs, std::ios::openmode mode = std::ios::out, + mode_t permission = NORMAL_FILE_MODE_DEFAULT); /* 通用检查函数 */ DebuggerErrno CheckFileSuffixAndSize(const std::string &path, FileType type); DebuggerErrno CheckDirCommon(const std::string &path); -DebuggerErrno CheckFileBeforeRead(const std::string &path, const std::string& authority="r", - FileType type=FileType::COMMON); -DebuggerErrno CheckFileBeforeCreateOrWrite(const std::string &path, bool overwrite=false); - -/* 其他文件操作工具 */ -std::vector ListAll(const std::string &path, size_t max = 1024); - +DebuggerErrno CheckFileBeforeRead(const std::string &path, const std::string& authority = "r", + FileType type = FileType::COMMON); +DebuggerErrno CheckFileBeforeCreateOrWrite(const std::string &path, bool overwrite = false); } } \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/ccsrc/utils/MathUtils.cpp b/debug/accuracy_tools/msprobe/ccsrc/utils/MathUtils.cpp index 27111d60c9f86f2ae9b2b2a00b804ab886917755..3618659c120591dbba489b3f101be39a2f2302e8 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/utils/MathUtils.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/utils/MathUtils.cpp @@ -62,19 +62,27 @@ std::string RandomString(uint32_t len, char min, char max) std::string CalculateMD5(const uint8_t* data, size_t length) { MD5_CTX md5ctx; + /* + * 不用于数据加密,不用于文件完整性校验,不用于密码存储,不用于数据唯一性检查 + * 只用于Tensor的统计信息呈现,不涉及数据安全 + */ MD5_Init(&md5ctx); MD5_Update(&md5ctx, data, length); unsigned char digest[MD5_DIGEST_LENGTH]; + /* + * 不用于数据加密,不用于文件完整性校验,不用于密码存储,不用于数据唯一性检查 + * 只用于Tensor的统计信息呈现,不涉及数据安全 + */ MD5_Final(digest, &md5ctx); - static const char hexchar[] = "0123456789abcdef"; + static const char HEX_CHAR[] = "0123456789abcdef"; constexpr const uint8_t hexbase = 16; constexpr const size_t byteToStrWidth = 2; char md5string[MD5_DIGEST_LENGTH * byteToStrWidth + 1]; for (int i = 0; i < MD5_DIGEST_LENGTH; i++) { - md5string[i * byteToStrWidth] = hexchar[digest[i] / hexbase]; - md5string[i * byteToStrWidth + 1] = hexchar[digest[i] % hexbase]; + md5string[i * byteToStrWidth] = HEX_CHAR[digest[i] / hexbase]; + md5string[i * byteToStrWidth + 1] = HEX_CHAR[digest[i] % hexbase]; } md5string[sizeof(md5string) - 1] = '\0'; diff --git a/debug/accuracy_tools/msprobe/ccsrc/utils/MathUtils.hpp b/debug/accuracy_tools/msprobe/ccsrc/utils/MathUtils.h similarity index 88% rename from debug/accuracy_tools/msprobe/ccsrc/utils/MathUtils.hpp rename to debug/accuracy_tools/msprobe/ccsrc/utils/MathUtils.h index 141471ac8ce284ac1a7ab4b6db59f5d0da9a9fe2..7838b1f35d9380ecc918e98ba559ebd91a475a38 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/utils/MathUtils.hpp +++ b/debug/accuracy_tools/msprobe/ccsrc/utils/MathUtils.h @@ -23,7 +23,8 @@ namespace MindStudioDebugger { namespace MathUtils { template -T Gcd(T a, T b) { +T Gcd(T a, T b) +{ if (a == 0 || b == 0) { return 0; } @@ -37,7 +38,8 @@ T Gcd(T a, T b) { } template -T Lcm(T a, T b) { +T Lcm(T a, T b) +{ if (a == 0 || b == 0) { return 0; } @@ -46,7 +48,8 @@ T Lcm(T a, T b) { } template -T DivCeil(T v, T divisor) { +T DivCeil(T v, T divisor) +{ if (divisor == 0) { return 0; } @@ -56,13 +59,13 @@ T DivCeil(T v, T divisor) { template T AlignCeil(T v, T block) { - return DivCeil(v, block) * block; + return DivCeil(v, block) * block; } float Random(); float Random(float floor, float ceil); int32_t RandomInt(int32_t floor, int32_t ceil); -std::string RandomString(uint32_t len, char min=' ', char max='~'); +std::string RandomString(uint32_t len, char min = ' ', char max = '~'); std::string CalculateMD5(const uint8_t* data, size_t length); diff --git a/debug/accuracy_tools/msprobe/core/__init__.py b/debug/accuracy_tools/msprobe/core/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..eb80022b66670467408474bec6f5f46e48ff29b2 100644 --- a/debug/accuracy_tools/msprobe/core/__init__.py +++ b/debug/accuracy_tools/msprobe/core/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprobe.core.single_save.single_saver import SingleSave +from msprobe.core.single_save.single_comparator import SingleComparator diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index d9623b807121ea129484a535fe8a9e2293e662f3..1aaf7d055fdcaab86028d8c7c2d19317ea050fa2 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -24,6 +24,8 @@ class Const: Class for const """ TOOL_NAME = "msprobe" + MD5_INDEX = "md5_index" + MD5 = "md5" ipv4_pattern = "([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])(\.([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])){3}$" SEP = "." @@ -51,7 +53,13 @@ class Const: FOUR_SEGMENT = 4 SIX_SEGMENT = 6 SEVEN_SEGMENT = 7 - MAX_DEPTH = 10 + + MAX_DEPTH = 400 + CPU_QUARTER = 4 + DUMP_MAX_DEPTH = 400 + + EXTERN_INPUT_LIST_MAX_LEN = 100 + MAX_PROCESS_NUM = 128 # dump mode ALL = "all" @@ -66,8 +74,9 @@ class Const: ONLINE_DUMP_MODE = [ALL, LIST, AUTO, OFF] SUMMARY = "summary" MD5 = "md5" + HASH = "hash" VALUE = "value" - SUMMARY_MODE = [ALL, SUMMARY, MD5] + SUMMARY_MODE = ["statistics", "md5"] WRITE_FLAGS = os.O_WRONLY | os.O_CREAT WRITE_MODES = stat.S_IWUSR | stat.S_IRUSR @@ -77,6 +86,8 @@ class Const: NUMPY_SUFFIX = ".npy" NUMPY_PATTERN = "*.npy" PT_SUFFIX = ".pt" + PY_SUFFIX = ".py" + INIT_PY = "init.py" ONE_GB = 1073741824 # 1 * 1024 * 1024 * 1024 TEN_GB = 10737418240 # 10 * 1024 * 1024 * 1024 ONE_MB = 1048576 # 1 * 1024 * 1024 @@ -92,6 +103,7 @@ class Const: GRAD_OUTPUT = 'grad_output' PARAMS = 'parameters' PARAMS_GRAD = 'parameters_grad' + DEBUG = 'debug' START = "start" STOP = "stop" ENV_ENABLE = "1" @@ -104,9 +116,13 @@ class Const: RUN_UT = "run_ut" GRAD_PROBE = "grad_probe" STRUCTURE = "structure" - TASK_LIST = [TENSOR, STATISTICS, OVERFLOW_CHECK, FREE_BENCHMARK, RUN_UT, GRAD_PROBE, STRUCTURE] + EXCEPTION_DUMP = "exception_dump" + DUMP_PRECISION_HIGH = "high" + DUMP_PRECISION_LOW = "low" + TASK_LIST = [TENSOR, STATISTICS, OVERFLOW_CHECK, FREE_BENCHMARK, RUN_UT, GRAD_PROBE, STRUCTURE, EXCEPTION_DUMP] DUMP_DATA_COLLECTION_LIST = [STATISTICS, TENSOR, STRUCTURE] DUMP_DATA_MODE_LIST = [ALL, INPUT, OUTPUT, FORWARD, BACKWARD] + DUMP_PRECISION_LIST = [DUMP_PRECISION_LOW, DUMP_PRECISION_HIGH] LEVEL_L0 = "L0" LEVEL_L1 = "L1" LEVEL_L2 = "L2" @@ -129,6 +145,7 @@ class Const: NPU = 'NPU' NPU_LOWERCASE = 'npu' CPU_LOWERCASE = 'cpu' + GPU_LOWERCASE = 'gpu' CUDA_LOWERCASE = 'cuda' DEVICE = 'device' DISTRIBUTED = 'Distributed' @@ -137,6 +154,10 @@ class Const: MODULE_PREFIX = ["Module", "Cell"] FORWARD_NAME_SUFFIX = ".forward" + DUMP_JSON_FILE = "dump_json_file" + DEBUG_JSON_FILE = "debug_json_file" + STACK_JSON_FILE = "stack_json_file" + # struct json param ORIGIN_DATA = "origin_data" SCOPE = "scope" @@ -167,6 +188,10 @@ class Const: TOP_LAYER = "TopLayer" CELL = "Cell" MODULE = "Module" + API = "api" + PYNATIVE_MODE = "pynative" + PYNATIVE_GRAPH_MODE = "pynative_graph" + FRAME_FILE_LIST = ["site-packages/torch", "package/torch", "site-packages/mindspore", "package/mindspore"] INPLACE_LIST = [ "broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter", @@ -188,7 +213,11 @@ class Const: FILL_CHAR_NUMS = 50 TOOL_ENDS_SUCCESSFULLY = f"{TOOL_NAME} ends successfully." + WITHOUT_CALL_STACK = "The call stack retrieval failed." + STACK_FILTER_KEYWORDS = ["msprobe/core", "msprobe/pytorch", "msprobe/mindspore"] + CALL_STACK_FLAG = "data_dump/api_registry" + NEW_STACK_FLAG = "0" STEP = "step" RANK = "rank" @@ -206,12 +235,20 @@ class Const: TORCH_FLOAT32 = "torch.float32" TORCH_BFLOAT16 = "torch.bfloat16" + TYPE = 'type' DTYPE = 'dtype' SHAPE = 'shape' + STACK_INFO = 'stack_info' MAX = 'Max' MIN = 'Min' MEAN = 'Mean' NORM = 'Norm' + DATA_NAME = 'data_name' + STATE = 'state' + REQ_GRAD = 'requires_grad' + API_ORIGIN_NAME = 'api_origin_name' + TENSOR_STAT_INDEX = 'tensor_stat_index' + SUMMARY_METRICS_LIST = [MAX, MIN, MEAN, NORM] CODE_STACK = 'Code Stack' OP_NAME = 'Op Name' @@ -223,6 +260,10 @@ class Const: # 分隔符常量 SCOPE_SEPARATOR = "/" REPLACEMENT_CHARACTER = "_" + PIPE_SEPARATOR = "|" + + FORWARD_PATTERN = SEP + FORWARD + SEP + BACKWARD_PATTERN = SEP + BACKWARD + SEP OPTIMIZER = "optimizer" CLIP_GRAD = "clip_grad" @@ -230,12 +271,147 @@ class Const: TENSOR_STAT_LEN = 2 + TENSOR_TYPE = "torch.Tensor" + DTENSOR_TYPE = "torch.distributed.tensor.DTensor" + FAKE_TENSOR_TYPE = "torch._subclasses.fake_tensor.FakeTensor" + + SUPPORT_API_FILE_NAME = "support_wrap_ops.yaml" + + API_ATTR_LIST = ["__name__", "default"] + + PT_API_TYPE_FUNCTIONAL = "functional" + PT_API_TYPE_TENSOR = "tensor" + PT_API_TYPE_TORCH = "torch" + PT_API_TYPE_VF = "_VF" + PT_API_TYPE_NPU = "torch_npu" + PT_API_TYPE_ATEN = "aten" + PT_API_TYPE_DIST = "distributed" + PT_API_TYPE_NPU_DIST = "npu_distributed" + PT_API_TYPE_MINDSPEED = "mindspeed" + + MS_API_TYPE_OPS = "ops" + MS_API_TYPE_TENSOR = "tensor" + MS_API_TYPE_STUB_TENSOR = "stubtensor" + MS_API_TYPE_MINT = "mint.ops" + MS_API_TYPE_MINT_FUNC = "mint.nn.functional" + MS_API_TYPE_COM = "communication.comm_func" + MS_API_TYPE_MINT_DIST = "mint.distributed" + + FUNCTIONAL_API_TYPE_PREFIX = "Functional" + TENSOR_API_TYPE_PREFIX = "Tensor" + DIST_API_TYPE_PREFIX = "Distributed" + + TORCH_API_TYPE_PREFIX = "Torch" + NPU_API_TYPE_PREFIX = "NPU" + ATEN_API_TYPE_PREFIX = "Aten" + VF_API_TYPE_PREFIX = "VF" + MINDSPEED_API_TYPE_PREFIX = "MindSpeed" + + MINT_API_TYPE_PREFIX = "Mint" + MINT_FUNC_API_TYPE_PREFIX = "MintFunctional" + MINT_DIST_API_TYPE_PREFIX = "MintDistributed" + + SUPPORT_API_DICT_KEY_MAP = { + PT_FRAMEWORK: { + PT_API_TYPE_FUNCTIONAL: PT_API_TYPE_FUNCTIONAL, + PT_API_TYPE_TENSOR: PT_API_TYPE_TENSOR, + PT_API_TYPE_TORCH: PT_API_TYPE_TORCH, + PT_API_TYPE_VF: PT_API_TYPE_VF, + PT_API_TYPE_NPU: PT_API_TYPE_NPU, + PT_API_TYPE_ATEN: PT_API_TYPE_ATEN, + PT_API_TYPE_DIST: PT_API_TYPE_DIST, + PT_API_TYPE_NPU_DIST: PT_API_TYPE_NPU_DIST, + PT_API_TYPE_MINDSPEED: PT_API_TYPE_MINDSPEED + }, + MS_FRAMEWORK: { + MS_API_TYPE_OPS: MS_API_TYPE_OPS, + MS_API_TYPE_TENSOR: MS_API_TYPE_TENSOR, + MS_API_TYPE_STUB_TENSOR: MS_API_TYPE_TENSOR, + MS_API_TYPE_MINT: MS_API_TYPE_MINT, + MS_API_TYPE_MINT_FUNC: MS_API_TYPE_MINT_FUNC, + MS_API_TYPE_COM: MS_API_TYPE_COM, + MS_API_TYPE_MINT_DIST: MS_API_TYPE_MINT_DIST + }, + MT_FRAMEWORK: { + PT_API_TYPE_FUNCTIONAL: PT_API_TYPE_FUNCTIONAL, + PT_API_TYPE_TENSOR: PT_API_TYPE_TENSOR, + PT_API_TYPE_TORCH: PT_API_TYPE_TORCH, + PT_API_TYPE_NPU: PT_API_TYPE_NPU, + PT_API_TYPE_DIST: PT_API_TYPE_DIST + } + } + + API_DATA_PREFIX = { + PT_FRAMEWORK: { + PT_API_TYPE_FUNCTIONAL: FUNCTIONAL_API_TYPE_PREFIX, + PT_API_TYPE_TENSOR: TENSOR_API_TYPE_PREFIX, + PT_API_TYPE_TORCH: TORCH_API_TYPE_PREFIX, + PT_API_TYPE_VF: VF_API_TYPE_PREFIX, + PT_API_TYPE_NPU: NPU_API_TYPE_PREFIX, + PT_API_TYPE_ATEN: ATEN_API_TYPE_PREFIX, + PT_API_TYPE_DIST: DIST_API_TYPE_PREFIX, + PT_API_TYPE_NPU_DIST: DIST_API_TYPE_PREFIX, + PT_API_TYPE_MINDSPEED: MINDSPEED_API_TYPE_PREFIX + }, + MS_FRAMEWORK: { + MS_API_TYPE_OPS: FUNCTIONAL_API_TYPE_PREFIX, + MS_API_TYPE_TENSOR: TENSOR_API_TYPE_PREFIX, + MS_API_TYPE_STUB_TENSOR: TENSOR_API_TYPE_PREFIX, + MS_API_TYPE_MINT: MINT_API_TYPE_PREFIX, + MS_API_TYPE_MINT_FUNC: MINT_FUNC_API_TYPE_PREFIX, + MS_API_TYPE_COM: DIST_API_TYPE_PREFIX, + MS_API_TYPE_MINT_DIST: MINT_DIST_API_TYPE_PREFIX + }, + MT_FRAMEWORK: { + PT_API_TYPE_FUNCTIONAL: FUNCTIONAL_API_TYPE_PREFIX, + PT_API_TYPE_TENSOR: TENSOR_API_TYPE_PREFIX, + PT_API_TYPE_TORCH: TORCH_API_TYPE_PREFIX, + PT_API_TYPE_NPU: NPU_API_TYPE_PREFIX, + PT_API_TYPE_DIST: DIST_API_TYPE_PREFIX + } + } + + def _fused_adamw_( + self, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + *, + lr, + beta1, + beta2, + weight_decay, + eps, + amsgrad, + maximize, + grad_scale=None, + found_inf=None + ): + pass + + API_WITH_SELF_ARG = { + 'Torch._fused_adamw_': _fused_adamw_ + } + + ASCEND = "ASCEND" + MATCH_MODE_NAME = "pure name" + MATCH_MODE_MAPPING = "mapping" + MATCH_MODE_SIMILARITY = "similarity" + CONFIG_CHECK_PASS = "pass" + CONFIG_CHECK_WARNING = "warning" + CONFIG_CHECK_ERROR = "error" + + MIX_DUMP_NAMES = {'graph', 'pynative'} + class CompareConst: """ Class for compare module const """ SPACE = " " + NAME = "Name" # compare result column name NPU_NAME = "NPU Name" BENCH_NAME = "Bench Name" @@ -243,6 +419,8 @@ class CompareConst: BENCH_DTYPE = "Bench Dtype" NPU_SHAPE = "NPU Tensor Shape" BENCH_SHAPE = "Bench Tensor Shape" + NPU_CSV_FILE = "NPU CSV File" + BENCH_CSV_FILE = "Bench CSV File" NPU_MAX = "NPU max" NPU_MIN = "NPU min" NPU_MEAN = "NPU mean" @@ -256,11 +434,15 @@ class CompareConst: MEAN_DIFF = "Mean diff" NORM_DIFF = "L2norm diff" COSINE = "Cosine" + EUC_DIST = "EucDist" MAX_ABS_ERR = "MaxAbsErr" MAX_RELATIVE_ERR = "MaxRelativeErr" MIN_RELATIVE_ERR = "MinRelativeErr" MEAN_RELATIVE_ERR = "MeanRelativeErr" NORM_RELATIVE_ERR = "NormRelativeErr" + REQ_GRAD_CONSIST = "Requires_grad Consistent" + NPU_REQ_GRAD = "NPU Requires_grad" + BENCH_REQ_GRAD = "Bench Requires_grad" ACCURACY = "Accuracy Reached or Not" STACK = "NPU_Stack_Info" DATA_NAME = "Data_name" @@ -278,10 +460,11 @@ class CompareConst: OUTPUT_STRUCT = "output_struct" PARAMS_STRUCT = "params_struct" PARAMS_GRAD_STRUCT = "params_grad_struct" + DEBUG_STRUCT = "debug_struct" SUMMARY = "summary" COMPARE_RESULT = "compare_result" COMPARE_MESSAGE = "compare_message" - MAX_EXCEL_LENGTH = 1048576 + MAX_EXCEL_LENGTH = 1048500 YES = "Yes" NO = "No" STATISTICS_INDICATOR_NUM = 4 @@ -329,21 +512,21 @@ class CompareConst: ULP_ERR_STATUS = "ulp_err_status" - COMPARE_RESULT_HEADER = [ - NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR, - ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO, - NPU_MAX, NPU_MIN, NPU_MEAN, NPU_NORM, BENCH_MAX, BENCH_MIN, BENCH_MEAN, BENCH_NORM, ACCURACY, ERROR_MESSAGE - ] + ALL_COMPARE_INDEX = [COSINE, EUC_DIST, MAX_ABS_ERR, MAX_RELATIVE_ERR, + ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO] + SUMMARY_COMPARE_INDEX = [MAX_DIFF, MIN_DIFF, MEAN_DIFF, NORM_DIFF, + MAX_RELATIVE_ERR, MIN_RELATIVE_ERR, MEAN_RELATIVE_ERR, NORM_RELATIVE_ERR] + MD5_COMPARE_INDEX = [RESULT] - SUMMARY_COMPARE_RESULT_HEADER = [ - NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, MAX_DIFF, MIN_DIFF, MEAN_DIFF, NORM_DIFF, - MAX_RELATIVE_ERR, MIN_RELATIVE_ERR, MEAN_RELATIVE_ERR, NORM_RELATIVE_ERR, - NPU_MAX, NPU_MIN, NPU_MEAN, NPU_NORM, BENCH_MAX, BENCH_MIN, BENCH_MEAN, BENCH_NORM, RESULT, ERROR_MESSAGE - ] + BASIC_INFO = [NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, NPU_REQ_GRAD, BENCH_REQ_GRAD] + SUMMARY_INFO = [NPU_MAX, NPU_MIN, NPU_MEAN, NPU_NORM, BENCH_MAX, BENCH_MIN, BENCH_MEAN, BENCH_NORM] - MD5_COMPARE_RESULT_HEADER = [ - NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, NPU_MD5, BENCH_MD5, RESULT - ] + COMPARE_RESULT_HEADER = BASIC_INFO + ALL_COMPARE_INDEX + SUMMARY_INFO + [REQ_GRAD_CONSIST, ACCURACY, ERROR_MESSAGE] + + SUMMARY_COMPARE_RESULT_HEADER = BASIC_INFO + SUMMARY_COMPARE_INDEX + SUMMARY_INFO + [REQ_GRAD_CONSIST, RESULT, + ERROR_MESSAGE] + + MD5_COMPARE_RESULT_HEADER = BASIC_INFO + [NPU_MD5, BENCH_MD5, REQ_GRAD_CONSIST] + MD5_COMPARE_INDEX COMPARE_RESULT_HEADER_STACK = COMPARE_RESULT_HEADER + [STACK] @@ -357,18 +540,11 @@ class CompareConst: Const.MD5: MD5_COMPARE_RESULT_HEADER } - ALL_COMPARE_INDEX = [COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR, ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO] - SUMMARY_COMPARE_INDEX = [MAX_DIFF, MIN_DIFF, MEAN_DIFF, NORM_DIFF, - MAX_RELATIVE_ERR, MIN_RELATIVE_ERR, MEAN_RELATIVE_ERR, NORM_RELATIVE_ERR] - # dtype match - MS_TYPE = [ - [Const.FLOAT16, Const.FLOAT32], [Const.FLOAT32, Const.FLOAT16], - [Const.FLOAT16, Const.BFLOAT16], [Const.BFLOAT16, Const.FLOAT16] - ] - TORCH_TYPE = [ - [Const.TORCH_FLOAT16, Const.TORCH_FLOAT32], [Const.TORCH_FLOAT32, Const.TORCH_FLOAT16], - [Const.TORCH_FLOAT16, Const.TORCH_BFLOAT16], [Const.TORCH_BFLOAT16, Const.TORCH_FLOAT16] + + DTYPE_MATCH_GROUPS = [ + {Const.FLOAT16, Const.FLOAT32, Const.BFLOAT16}, + {Const.TORCH_FLOAT16, Const.TORCH_FLOAT32, Const.TORCH_BFLOAT16} ] # read_op @@ -386,16 +562,10 @@ class CompareConst: Const.KWARGS: INPUT_STRUCT, Const.OUTPUT: OUTPUT_STRUCT, Const.PARAMS: PARAMS_STRUCT, - Const.PARAMS_GRAD: PARAMS_GRAD_STRUCT + Const.PARAMS_GRAD: PARAMS_GRAD_STRUCT, + Const.DEBUG: DEBUG_STRUCT } - STRUCT_COMPARE_KEY = [ - INPUT_STRUCT, - OUTPUT_STRUCT, - PARAMS_STRUCT, - PARAMS_GRAD_STRUCT - ] - # compare standard HUNDRED_RATIO_THRESHOLD = 0.01 THOUSAND_RATIO_THRESHOLD = 0.001 @@ -406,6 +576,8 @@ class CompareConst: ULP_FLOAT16_THRESHOLD = 1 # compare result data + NO_REAL_DATA = 'No real data' + API_UNMATCH = 'api unmatched' READ_NONE = 'No data' NONE = 'None' SHAPE_UNMATCH = 'shape unmatched' @@ -467,22 +639,47 @@ class CompareConst: BENCH_MEAN: None, BENCH_NORM: None, ACCURACY: '', ERROR_MESSAGE: '' } MS_GRAPH_NPY = { - COSINE: None, MAX_ABS_ERR: None, MAX_RELATIVE_ERR: None, ONE_THOUSANDTH_ERR_RATIO: None, + COSINE: None, EUC_DIST: None, MAX_ABS_ERR: None, MAX_RELATIVE_ERR: None, ONE_THOUSANDTH_ERR_RATIO: None, FIVE_THOUSANDTHS_ERR_RATIO: None } MS_GRAPH_STATISTIC = { MAX_DIFF: None, MIN_DIFF: None, MEAN_DIFF: None, NORM_DIFF: None, MAX_RELATIVE_ERR: None, MIN_RELATIVE_ERR: None, MEAN_RELATIVE_ERR: None, NORM_RELATIVE_ERR: None } + MS_GRAPH_CSV = { + NPU_CSV_FILE: None, BENCH_CSV_FILE: None + } + + API_MAPPING_KEYS_TO_COMPARE = [ + ('ms_args', 'pt_args'), + ('ms_outputs', 'pt_outputs'), + ('ms_parameters', 'pt_parameters'), + ('ms_parameters_grad', 'pt_parameters_grad') + ] + INPUT_PATTERN = Const.SEP + Const.INPUT + Const.SEP KWARGS_PATTERN = Const.SEP + Const.KWARGS + Const.SEP OUTPUT_PATTERN = Const.SEP + Const.OUTPUT + Const.SEP PARAMS_PATTERN = Const.SEP + Const.PARAMS + Const.SEP PARAMS_GRAD_PATTERN = Const.SEP + Const.PARAMS_GRAD + Const.SEP - COMPARE_KEY = 'compare_key' - COMPARE_SHAPE = 'compare_shape' + + CMP_KEY = 'compare_key' + CMP_SHAPE = 'compare_shape' + + OP_NAME_X = 'op_name_x' + MATCH_RESULT_COLUMNS = [ + OP_NAME_X, 'dtype_x', 'shape_x', 'summary_x', 'stack_info_x', 'state_x', 'api_origin_name_x', + 'requires_grad_x', 'data_name_x', + CMP_KEY, CMP_SHAPE, + 'op_name_y', 'dtype_y', 'shape_y', 'summary_y', 'stack_info_y', 'state_y', 'api_origin_name_y', + 'requires_grad_y', 'data_name_y' + ] + INTERNAL_API_MAPPING_FILE = 'ms_to_pt_api.yaml' UNREADABLE = 'unreadable data' + NPU_DUMP_DATA_DIR = 'npu_dump_data_dir' + BENCH_DUMP_DATA_DIR = 'bench_dump_data_dir' + NO_REAL_DATA_FLAG = '-1' class FileCheckConst: @@ -504,6 +701,10 @@ class FileCheckConst: XLSX_SUFFIX = ".xlsx" YAML_SUFFIX = ".yaml" IR_SUFFIX = ".ir" + ZIP_SUFFIX = ".zip" + SHELL_SUFFIX = ".sh" + LOG_SUFFIX = ".log" + DB_SUFFIX = '.db' MAX_PKL_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 MAX_NUMPY_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024 MAX_JSON_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 @@ -512,7 +713,12 @@ class FileCheckConst: MAX_XLSX_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 MAX_YAML_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 MAX_IR_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 + MAX_ZIP_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024 + MAX_FILE_IN_ZIP_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 + MAX_FILE_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 COMMOM_FILE_SIZE = 1048576 # 1 * 1024 * 1024 + MAX_LOG_SIZE = 10737418240 # 1 * 1024 * 1024 * 1024 + MAX_DB_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024 DIR = "dir" FILE = "file" DATA_DIR_AUTHORITY = 0o750 @@ -525,7 +731,10 @@ class FileCheckConst: CSV_SUFFIX: MAX_CSV_SIZE, XLSX_SUFFIX: MAX_XLSX_SIZE, YAML_SUFFIX: MAX_YAML_SIZE, - IR_SUFFIX: MAX_IR_SIZE + IR_SUFFIX: MAX_IR_SIZE, + ZIP_SUFFIX: MAX_ZIP_SIZE, + LOG_SUFFIX: MAX_LOG_SIZE, + DB_SUFFIX: MAX_DB_SIZE } CSV_BLACK_LIST = r'^[+-=%@\+\-=%@]|;[+-=%@\+\-=%@]' @@ -538,61 +747,6 @@ class OverflowConst: OVERFLOW_DEBUG_MODE = 1 -class MsCompareConst: - # api_info field - MINT = "Mint" - MINT_FUNCTIONAL = "MintFunctional" - TENSOR_API = "Tensor" - - API_NAME_STR_LENGTH = 4 - MAX_RECURSION_DEPTH = 20 - - # Mindtorch api_info field - MINDTORCH_TENSOR = "Tensor" - MINDTORCH = "Torch" - MINDTORCH_FUNC = "Functional" - MINDTORCH_NPU = "NPU" - MINDTORCH_DIST = "Distributed" - - - - MT_VALID_API_TYPES = [ - MINDTORCH, MINDTORCH_FUNC, MINDTORCH_TENSOR - ] - - TASK_FIELD = "task" - STATISTICS_TASK = "statistics" - FRAMEWORK = "framework" - TENSOR_TASK = "tensor" - DUMP_DATA_DIR_FIELD = "dump_data_dir" - DATA_FIELD = "data" - - # supported api yaml - SUPPORTED_API_LIST_FILE = "checker_support_api.yaml" - SUPPORTED_TENSOR_LIST_KEY = "tensor" - - # detail_csv - DETAIL_CSV_API_NAME = "API Name" - DETAIL_CSV_BENCH_DTYPE = "Bench Dtype" - DETAIL_CSV_TESTED_DTYPE = "Tested Dtype" - DETAIL_CSV_SHAPE = "Shape" - DETAIL_CSV_PASS_STATUS = "Status" - DETAIL_CSV_MESSAGE = "Message" - DETAIL_CSV_FILE_NAME = "accuracy_checking_details" - - # result_csv - RESULT_CSV_FORWARD_TEST_SUCCESS = "Forward Test Success" - RESULT_CSV_BACKWARD_TEST_SUCCESS = "Backward Test Success" - RESULT_CSV_FILE_NAME = "accuracy_checking_result" - - EPSILON = 1e-8 - - class ProcessStatus: - SUCCESS = "success" - API_NOT_FOUND = "api_not_found" - EXCEPTION_SKIP = "exception_skip" - - class MsgConst: """ Class for log messages const @@ -629,7 +783,21 @@ class MonitorConst: """ Class for monitor const """ - OP_LIST = ["norm", "min", "max", "zeros", "nans", "id", "mean"] + + # monitor config set default values + DEFAULT_GRAD_ACC_STEPS = 1 + DEFAULT_START_ITERATION = 0 + DEFAULT_START_STEP = 0 + DEFAULT_MAX_COLLECT_TIMES = 1e8 + DEFAULT_MIN_COLLECT_TIMES = 0 + DEFAULT_STEP_INTERVAL = 1 + + OP_LIST = ["norm", "min", "max", "zeros", "nans", "id", "mean", "shape", "dtype"] + OP_MONVIS_SUPPORTED = [ + "norm", "min", "max", "zeros", "nans", "mean", + "entropy", "softmax_max", "sr", "kernel_norm", "std_x", "jacobian", + "proxy", "token_similarity" + ] MONITOR_OUTPUT_DIR = "MONITOR_OUTPUT_DIR" DEFAULT_MONITOR_OUTPUT_DIR = "./monitor_output" DATABASE = "database" @@ -641,7 +809,9 @@ class MonitorConst: "DeepSpeedZeroOptimizer_Stage3" ) DEEPSPEED_ZERO_OPT_FILTER = "DeepSpeedZeroOptimizer" - RULE_NAME = ['AnomalyTurbulence'] + RULE_NAME = ['AnomalyTurbulence', 'AnomalyNan'] + L2_HOOKS = ["linear_hook", "attention_hook"] + SA_ORDERS = ["s,b,h,d", "b,s,h,d"] SLICE_SIZE = 20480 # used for name @@ -658,15 +828,16 @@ class MonitorConst: ACTVGRAD = "actv_grad" POST_GRAD = "post_grad" PRE_GRAD = "pre_grad" + PRE_PARAM = "param_origin" + POST_PARAM = "param_updated" ACC_GRAD = "acc_grad" PREFIX_POST = "post" PREFIX_PRE = "pre" EXP_AVG = "exp_avg" EXP_AVG_SQ = "exp_avg_sq" - PARAM = "param" CSV_HEADER = ["vpp_stage", "name", "step"] - CSV_HEADER_XY = ["vpp_stage", "name", "step", "micro_step"] + CSV_HEADER_MICRO_STEP = ["vpp_stage", "name", "step", "micro_step"] OUTPUT_DIR_PATTERN = r"([\w-]{0,20})-rank(\d{1,5})-" ANOMALY_JSON = "anomaly.json" ANALYSE_JSON = "anomaly_analyse.json" @@ -674,3 +845,29 @@ class MonitorConst: CSV = "csv" API = "api" HEADER_NAME = 'name' + MAX_NDIGITS = 20 + + DEFAULT_STAGE = -1 + FORWARD_STAGE = 0 + BACKWARD_STAGE = 1 + OPTIMIZER_STAGE = 2 + FORWARD_KEY = [ACTV] + BACKWARD_KEY = [ACTVGRAD, PRE_GRAD, POST_GRAD, ACC_GRAD] + OPTIMIZER_KEY = [EXP_AVG, EXP_AVG_SQ] + + TRAIN_STAGE = {} + for key in FORWARD_KEY: + TRAIN_STAGE[key] = FORWARD_STAGE + for key in BACKWARD_KEY: + TRAIN_STAGE[key] = BACKWARD_STAGE + for key in OPTIMIZER_KEY: + TRAIN_STAGE[key] = OPTIMIZER_STAGE + + # csv2db + DEFAULT_INT_VALUE = 0 + MAX_PROCESS_NUM = 128 + CSV_FILE_PATTERN = r"_(\d+)-(\d+)\.csv" + BATCH_SIZE = 10000 + MAX_PARTITION = 10_000_000 + MIN_PARTITION = 10 + \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/core/common/db_manager.py b/debug/accuracy_tools/msprobe/core/common/db_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..4bb7540d87f69745e58cce5be3e1d84f8129e88b --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/common/db_manager.py @@ -0,0 +1,225 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sqlite3 +from typing import List, Tuple, Dict, Any +from functools import wraps + +from msprobe.pytorch.common.log import logger +from msprobe.core.common.file_utils import check_path_before_create, change_mode +from msprobe.core.common.const import FileCheckConst + + +def _db_operation(func): + """数据库操作装饰器,自动管理连接""" + @wraps(func) + def wrapper(self, *args, **kwargs): + conn, curs = None, None + try: + conn, curs = self._get_connection() + result = func(self, conn, curs, *args, **kwargs) + return result # 显式返回正常结果 + + except sqlite3.Error as err: + logger.error(f"Database operation failed: {err}") + if conn: + conn.rollback() + return None # 显式返回错误情况下的None + + finally: + self._release_connection(conn, curs) + return wrapper + + +class DBManager: + """ + 数据库管理类,封装常用数据库操作 + """ + + DEFAULT_FETCH_SIZE = 10000 + DEFAULT_INSERT_SIZE = 10000 + MAX_ROW_COUNT = 100000000 + + def __init__(self, db_path: str): + """ + 初始化DBManager + :param db_path: 数据库文件路径 + :param table_config: 表配置对象 + """ + self.db_path = db_path + + @staticmethod + def _get_where_sql(where_list): + if not where_list: + return "", tuple() + + where_clauses = [] + where_values = [] + if where_list: + for col, val in where_list.items(): + where_clauses.append(f"{col} = ?") + where_values.append(val) + if where_clauses: + where_sql = " WHERE " + " AND ".join(where_clauses) + return where_sql, tuple(where_values) + + @_db_operation + def insert_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, + table_name: str, data: List[Tuple], key_list: List[str] = None) -> int: + """ + 批量插入数据 + :param table_name: 表名 + :param data: 要插入的数据列表 + :param batch_size: 每批插入的大小 + :return: 插入的行数 + """ + if not data: + return 0 + columns = len(data[0]) + if key_list and columns != len(key_list): + raise ValueError( + f"When inserting into table {table_name}, the length of key list ({key_list})" + f"does not match the data({columns}).") + + batch_size = self.DEFAULT_INSERT_SIZE + placeholders = ", ".join(["?"] * columns) + if key_list: + keys = ", ".join(key_list) + sql = f"INSERT OR IGNORE INTO {table_name} ({keys}) VALUES ({placeholders})" + else: + sql = f"INSERT OR IGNORE INTO {table_name} VALUES ({placeholders})" + + inserted_rows = 0 + for i in range(0, len(data), batch_size): + batch = data[i:i + batch_size] + curs.executemany(sql, batch) + inserted_rows += curs.rowcount + + conn.commit() + return inserted_rows + + @_db_operation + def select_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, + table_name: str, + columns: List[str] = None, + where: dict = None) -> List[Dict]: + """ + 查询数据 + :param table_name: 表名 + :param columns: 要查询的列 + :param where: WHERE条件 + :return: 查询结果列表(字典形式) + """ + + if not columns: + raise ValueError("columns parameter cannot be empty, specify columns to select (e.g. ['id', 'name'])") + if not isinstance(columns, list) or not all(isinstance(col, str) for col in columns): + raise TypeError("columns must be a list of strings (e.g. ['id', 'name'])") + + cols = ", ".join(columns) + sql = f"SELECT {cols} FROM {table_name}" + + where_sql, where_parems = self._get_where_sql(where) + curs.execute(sql + where_sql, where_parems) + + return [dict(row) for row in curs.fetchall()] + + @_db_operation + def update_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, + table_name: str, updates: Dict[str, Any], + where: dict = None) -> int: + """ + 更新数据 + :param table_name: 表名 + :param updates: 要更新的字段和值 + :param where: WHERE条件 + :param where_params: WHERE条件参数 + :return: 影响的行数 + """ + set_clause = ", ".join([f"{k} = ?" for k in updates.keys()]) + sql = f"UPDATE {table_name} SET {set_clause}" + + params = tuple(updates.values()) + + where_sql, where_parems = self._get_where_sql(where) + + curs.execute(sql + where_sql, params + where_parems) + conn.commit() + return curs.rowcount + + @_db_operation + def execute_sql(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, + sql: str, params: Tuple = None) -> List[Dict]: + """ + 执行自定义SQL查询 + :param sql: SQL语句 + :param params: 参数 + :return: 查询结果 + """ + curs.execute(sql, params or ()) + if sql.strip().upper().startswith("SELECT"): + return [dict(row) for row in curs.fetchall()] + conn.commit() + return [] + + def table_exists(self, table_name: str) -> bool: + """ + :param table_name: 表名 + :return: 查询结果 + """ + result = self.select_data( + table_name="sqlite_master", + columns=["name"], + where={"type": "table", "name": table_name} + ) + return len(result) > 0 + + @_db_operation + def execute_multi_sql(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, + sql_commands: List[str]) -> List[List[Dict]]: + """ + 批量执行多个SQL语句 + :param sql_commands: [sql1, sql2, ...] + :return: 每个SELECT语句的结果列表 + """ + results = [] + for sql in sql_commands: + curs.execute(sql) + if sql.strip().upper().startswith("SELECT"): + results.append([dict(row) for row in curs.fetchall()]) + conn.commit() + return results + + def _get_connection(self) -> Tuple[sqlite3.Connection, sqlite3.Cursor]: + """获取数据库连接和游标""" + check_path_before_create(self.db_path) + try: + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row # 使用Row工厂获取字典形式的结果 + curs = conn.cursor() + return conn, curs + except sqlite3.Error as err: + logger.error(f"Database connection failed: {err}") + raise + + def _release_connection(self, conn: sqlite3.Connection, curs: sqlite3.Cursor) -> None: + """释放数据库连接""" + try: + if curs is not None: + curs.close() + if conn is not None: + conn.close() + except sqlite3.Error as err: + logger.error(f"Failed to release database connection: {err}") + change_mode(self.db_path, FileCheckConst.DATA_FILE_AUTHORITY) diff --git a/debug/accuracy_tools/msprobe/core/common/decorator.py b/debug/accuracy_tools/msprobe/core/common/decorator.py new file mode 100644 index 0000000000000000000000000000000000000000..d3710002bcc281be2fd0f19fc7abda1af35ec936 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/common/decorator.py @@ -0,0 +1,50 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +from functools import wraps + +from msprobe.core.common.const import Const +from msprobe.core.common.exceptions import MsprobeException +from msprobe.core.common.log import logger + +# 记录工具函数递归的深度 +recursion_depth = defaultdict(int) + + +def recursion_depth_decorator(func_info, max_depth=Const.MAX_DEPTH): + """装饰一个函数,当函数递归调用超过限制时,抛出异常并打印函数信息。""" + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + func_id = id(func) + recursion_depth[func_id] += 1 + if recursion_depth[func_id] > max_depth: + msg = f"call {func_info} exceeds the recursion limit." + logger.error_log_with_exp( + msg, + MsprobeException( + MsprobeException.RECURSION_LIMIT_ERROR, msg + ), + ) + try: + result = func(*args, **kwargs) + finally: + recursion_depth[func_id] -= 1 + return result + + return wrapper + + return decorator diff --git a/debug/accuracy_tools/msprobe/core/common/exceptions.py b/debug/accuracy_tools/msprobe/core/common/exceptions.py index d71d30224b677fb19361f62de0ee25b2d32d389f..f4bb39db4c9534a6fe553d75e396dc5b91d71f33 100644 --- a/debug/accuracy_tools/msprobe/core/common/exceptions.py +++ b/debug/accuracy_tools/msprobe/core/common/exceptions.py @@ -21,19 +21,21 @@ class CodedException(Exception): def __str__(self): return self.error_info - - + + class MsprobeException(CodedException): INVALID_PARAM_ERROR = 0 OVERFLOW_NUMS_ERROR = 1 RECURSION_LIMIT_ERROR = 2 INTERFACE_USAGE_ERROR = 3 + UNSUPPORTED_TYPE_ERROR = 4 err_strs = { INVALID_PARAM_ERROR: "[msprobe] 无效参数:", OVERFLOW_NUMS_ERROR: "[msprobe] 超过预设溢出次数 当前溢出次数:", RECURSION_LIMIT_ERROR: "[msprobe] 递归调用超过限制:", - INTERFACE_USAGE_ERROR: "[msprobe] Invalid interface usage: " + INTERFACE_USAGE_ERROR: "[msprobe] Invalid interface usage: ", + UNSUPPORTED_TYPE_ERROR: "[msprobe] Unsupported type: " } diff --git a/debug/accuracy_tools/msprobe/core/common/file_utils.py b/debug/accuracy_tools/msprobe/core/common/file_utils.py index fdc626ca6a1a90e9060cefa237f9d5d8d7e42844..1b848791f1670f92bf72516085402099462a9a44 100644 --- a/debug/accuracy_tools/msprobe/core/common/file_utils.py +++ b/debug/accuracy_tools/msprobe/core/common/file_utils.py @@ -12,23 +12,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import atexit import csv import fcntl +import io import os +import pickle +from multiprocessing import shared_memory import stat import json import re import shutil -from datetime import datetime, timezone -from dateutil import parser +import sys +import zipfile +import multiprocessing import yaml import numpy as np import pandas as pd +from msprobe.core.common.decorator import recursion_depth_decorator from msprobe.core.common.log import logger from msprobe.core.common.exceptions import FileCheckException -from msprobe.core.common.const import FileCheckConst +from msprobe.core.common.const import FileCheckConst, CompareConst +from msprobe.core.common.global_lock import global_lock, is_main_process + +proc_lock = multiprocessing.Lock() class FileChecker: @@ -164,6 +172,12 @@ def check_path_exists(path): if not os.path.exists(path): logger.error('The file path %s does not exist.' % path) raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) + + +def check_path_not_exists(path): + if os.path.exists(path): + logger.error('The file path %s already exist.' % path) + raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) def check_path_readability(path): @@ -266,6 +280,7 @@ def make_dir(dir_path): file_check.common_check() +@recursion_depth_decorator('msprobe.core.common.file_utils.create_directory', max_depth=16) def create_directory(dir_path): """ Function Description: @@ -297,12 +312,13 @@ def check_path_before_create(path): def check_dirpath_before_read(path): path = os.path.realpath(path) dirpath = os.path.dirname(path) - if check_others_writable(dirpath): - logger.warning(f"The directory is writable by others: {dirpath}.") - try: - check_path_owner_consistent(dirpath) - except FileCheckException: - logger.warning(f"The directory {dirpath} is not yours.") + if dedup_log('check_dirpath_before_read', dirpath): + if check_others_writable(dirpath): + logger.warning(f"The directory is writable by others: {dirpath}.") + try: + check_path_owner_consistent(dirpath) + except FileCheckException: + logger.warning(f"The directory {dirpath} is not yours.") def check_file_or_directory_path(path, isdir=False): @@ -332,6 +348,23 @@ def change_mode(path, mode): 'Failed to change {} authority. {}'.format(path, str(ex))) from ex +@recursion_depth_decorator('msprobe.core.common.file_utils.recursive_chmod') +def recursive_chmod(path): + """ + 递归地修改目录及其子目录和文件的权限,文件修改为640,路径修改为750 + + :param path: 要修改权限的目录路径 + """ + for _, dirs, files in os.walk(path): + for file_name in files: + file_path = os.path.join(path, file_name) + change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY) + for dir_name in dirs: + dir_path = os.path.join(path, dir_name) + change_mode(dir_path, FileCheckConst.DATA_DIR_AUTHORITY) + recursive_chmod(dir_path) + + def path_len_exceeds_limit(file_path): return len(os.path.realpath(file_path)) > FileCheckConst.DIRECTORY_LENGTH or \ len(os.path.basename(file_path)) > FileCheckConst.FILE_NAME_LENGTH @@ -351,7 +384,7 @@ def check_file_type(path): elif os.path.isfile(path): return FileCheckConst.FILE else: - logger.error(f'{path} does not exist, please check!') + logger.error(f'path does not exist, please check!') raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) @@ -427,6 +460,43 @@ def save_excel(path, data): return "list" raise ValueError("Data must be a DataFrame or a list of (DataFrame, sheet_name) pairs.") + def check_value_is_valid(value: str) -> bool: + if not isinstance(value, str): + return True + try: + # -1.00 or +1.00 should be considered as digit numbers + float(value) + except ValueError: + # otherwise, they will be considered as formular injections + return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value)) + return True + + def malicious_check(df): + for row_name in df.index: + if not check_value_is_valid(row_name): + raise RuntimeError(f"Malicious value [{row_name}] not allowed to be written into the excel: {path}.") + + for col_name in df.columns: + if not check_value_is_valid(col_name): + raise RuntimeError(f"Malicious value [{col_name}] not allowed to be written into the excel: {path}.") + + for _, row in df.iterrows(): + for _, value in row.items(): + if not check_value_is_valid(value): + raise RuntimeError(f"Malicious value [{value}] not allowed to be written into the excel: {path}.") + + def save_in_slice(df, base_name): + malicious_check(df) + df_length = len(df) + if df_length < CompareConst.MAX_EXCEL_LENGTH: + df.to_excel(writer, sheet_name=base_name if base_name else 'Sheet1', index=False) + else: + slice_num = (df_length + CompareConst.MAX_EXCEL_LENGTH - 1) // CompareConst.MAX_EXCEL_LENGTH + slice_size = (df_length + slice_num - 1) // slice_num + for i in range(slice_num): + df.iloc[i * slice_size: min((i + 1) * slice_size, df_length)] \ + .to_excel(writer, sheet_name=f'{base_name}_part_{i}' if base_name else f'part_{i}', index=False) + check_path_before_create(path) path = os.path.realpath(path) @@ -434,18 +504,27 @@ def save_excel(path, data): data_type = validate_data(data) try: - if data_type == "single": - data.to_excel(path, index=False) - elif data_type == "list": - with pd.ExcelWriter(path) as writer: + with pd.ExcelWriter(path) as writer: + if data_type == "single": + save_in_slice(data, None) + elif data_type == "list": for data_df, sheet_name in data: - data_df.to_excel(writer, sheet_name=sheet_name, index=False) + save_in_slice(data_df, sheet_name) except Exception as e: logger.error(f'Save excel file "{os.path.basename(path)}" failed.') raise RuntimeError(f"Save excel file {path} failed.") from e change_mode(path, FileCheckConst.DATA_FILE_AUTHORITY) +def move_directory(src_path, dst_path): + check_file_or_directory_path(src_path, isdir=True) + check_path_before_create(dst_path) + try: + shutil.move(src_path, dst_path) + except Exception as e: + logger.error(f"move directory {src_path} to {dst_path} failed") + raise RuntimeError(f"move directory {src_path} to {dst_path} failed") from e + change_mode(dst_path, FileCheckConst.DATA_DIR_AUTHORITY) def move_file(src_path, dst_path): @@ -511,7 +590,7 @@ def write_csv(data, filepath, mode="a+", malicious_check=False): if not isinstance(value, str): return True try: - # -1.00 or +1.00 should be consdiered as digit numbers + # -1.00 or +1.00 should be considered as digit numbers float(value) except ValueError: # otherwise, they will be considered as formular injections @@ -557,7 +636,7 @@ def write_df_to_csv(data, filepath, mode="w", header=True, malicious_check=False if not isinstance(value, str): return True try: - # -1.00 or +1.00 should be consdiered as digit numbers + # -1.00 or +1.00 should be considered as digit numbers float(value) except ValueError: # otherwise, they will be considered as formular injections @@ -588,8 +667,11 @@ def write_df_to_csv(data, filepath, mode="w", header=True, malicious_check=False def remove_path(path): if not os.path.exists(path): return + if os.path.islink(path): + logger.error(f"Failed to delete {path}, it is a symbolic link.") + raise RuntimeError("Delete file or directory failed.") try: - if os.path.islink(path) or os.path.isfile(path): + if os.path.isfile(path): os.remove(path) else: shutil.rmtree(path) @@ -598,7 +680,7 @@ def remove_path(path): raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) from err except Exception as e: logger.error("Failed to delete {}. Please check.".format(path)) - raise RuntimeError(f"Delete {path} failed.") from e + raise RuntimeError("Delete file or directory failed.") from e def get_json_contents(file_path): @@ -632,42 +714,252 @@ def os_walk_for_files(path, depth): return res -def check_crt_valid(pem_path): +def check_zip_file(zip_file_path): + with zipfile.ZipFile(zip_file_path, 'r') as zip_file: + total_size = 0 + if len(zip_file.infolist()) > FileCheckConst.MAX_FILE_IN_ZIP_SIZE: + raise ValueError(f"Too many files in {os.path.basename(zip_file_path)}") + for file_info in zip_file.infolist(): + if file_info.file_size > FileCheckConst.MAX_FILE_SIZE: + raise ValueError(f"File {file_info.filename} is too large to extract") + + total_size += file_info.file_size + if total_size > FileCheckConst.MAX_ZIP_SIZE: + raise ValueError(f"Total extracted size exceeds the limit of {FileCheckConst.MAX_ZIP_SIZE} bytes") + + +def read_xlsx(file_path, sheet_name=None): + check_file_or_directory_path(file_path) + check_zip_file(file_path) + try: + if sheet_name: + result_df = pd.read_excel(file_path, keep_default_na=False, sheet_name=sheet_name) + else: + result_df = pd.read_excel(file_path, keep_default_na=False) + except Exception as e: + logger.error(f"The xlsx file failed to load. Please check the path: {file_path}.") + raise RuntimeError(f"Read xlsx file {file_path} failed.") from e + return result_df + + +def create_file_with_list(result_list, filepath): + check_path_before_create(filepath) + filepath = os.path.realpath(filepath) + try: + with FileOpen(filepath, 'w', encoding='utf-8') as file: + fcntl.flock(file, fcntl.LOCK_EX) + for item in result_list: + file.write(item + '\n') + fcntl.flock(file, fcntl.LOCK_UN) + except Exception as e: + logger.error(f'Save list to file "{os.path.basename(filepath)}" failed.') + raise RuntimeError(f"Save list to file {os.path.basename(filepath)} failed.") from e + change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY) + + +def create_file_with_content(data, filepath): + check_path_before_create(filepath) + filepath = os.path.realpath(filepath) + try: + with FileOpen(filepath, 'w', encoding='utf-8') as file: + fcntl.flock(file, fcntl.LOCK_EX) + file.write(data) + fcntl.flock(file, fcntl.LOCK_UN) + except Exception as e: + logger.error(f'Save content to file "{os.path.basename(filepath)}" failed.') + raise RuntimeError(f"Save content to file {os.path.basename(filepath)} failed.") from e + change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY) + + +def check_file_whether_exist_or_not(filepath): + if os.path.exists(filepath): + check_file_or_directory_path(filepath) + else: + check_path_before_create(filepath) + + +def add_file_to_zip(zip_file_path, file_path, arc_path=None): + """ + Add a file to a ZIP archive, if zip does not exist, create one. + + :param zip_file_path: Path to the ZIP archive + :param file_path: Path to the file to add + :param arc_path: Optional path inside the ZIP archive where the file should be added """ - Check the validity of the SSL certificate. + check_file_or_directory_path(file_path) + check_file_suffix(zip_file_path, FileCheckConst.ZIP_SUFFIX) + check_file_whether_exist_or_not(zip_file_path) + check_file_size(file_path, FileCheckConst.MAX_FILE_IN_ZIP_SIZE) + zip_size = os.path.getsize(zip_file_path) if os.path.exists(zip_file_path) else 0 + if zip_size + os.path.getsize(file_path) > FileCheckConst.MAX_ZIP_SIZE: + raise RuntimeError(f"ZIP file size exceeds the limit of {FileCheckConst.MAX_ZIP_SIZE} bytes") + try: + proc_lock.acquire() + with zipfile.ZipFile(zip_file_path, 'a') as zip_file: + zip_file.write(file_path, arc_path) + except Exception as e: + logger.error(f'add file to zip "{os.path.basename(zip_file_path)}" failed.') + raise RuntimeError(f"add file to zip {os.path.basename(zip_file_path)} failed.") from e + finally: + proc_lock.release() + change_mode(zip_file_path, FileCheckConst.DATA_FILE_AUTHORITY) - Load the SSL certificate from the specified path, parse and check its validity period. - If the certificate is expired or invalid, raise a RuntimeError. - Parameters: - pem_path (str): The file path of the SSL certificate. +def create_file_in_zip(zip_file_path, file_name, content): + """ + Create a file with content inside a ZIP archive. - Raises: - RuntimeError: If the SSL certificate is invalid or expired. + :param zip_file_path: Path to the ZIP archive + :param file_name: Name of the file to create + :param content: Content to write to the file """ - import OpenSSL + check_file_suffix(zip_file_path, FileCheckConst.ZIP_SUFFIX) + check_file_whether_exist_or_not(zip_file_path) + zip_size = os.path.getsize(zip_file_path) if os.path.exists(zip_file_path) else 0 + if zip_size + sys.getsizeof(content) > FileCheckConst.MAX_ZIP_SIZE: + raise RuntimeError(f"ZIP file size exceeds the limit of {FileCheckConst.MAX_ZIP_SIZE} bytes") try: - with FileOpen(pem_path, "r") as f: - pem_data = f.read() - cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pem_data) - pem_start = parser.parse(cert.get_notBefore().decode("UTF-8")) - pem_end = parser.parse(cert.get_notAfter().decode("UTF-8")) - logger.info(f"The SSL certificate passes the verification and the validity period " - f"starts from {pem_start} ends at {pem_end}.") + with open(zip_file_path, 'a+') as f: # 必须用 'a+' 模式才能 flock + # 2. 获取排他锁(阻塞直到成功) + fcntl.flock(f, fcntl.LOCK_EX) # LOCK_EX: 独占锁 + with zipfile.ZipFile(zip_file_path, 'a') as zip_file: + zip_info = zipfile.ZipInfo(file_name) + zip_info.compress_type = zipfile.ZIP_DEFLATED + zip_file.writestr(zip_info, content) + fcntl.flock(f, fcntl.LOCK_UN) except Exception as e: - logger.error("Failed to parse the SSL certificate. Check the certificate.") - raise RuntimeError(f"The SSL certificate is invalid, {pem_path}") from e + logger.error(f'Save content to file "{os.path.basename(zip_file_path)}" failed.') + raise RuntimeError(f"Save content to file {os.path.basename(zip_file_path)} failed.") from e + change_mode(zip_file_path, FileCheckConst.DATA_FILE_AUTHORITY) - now_utc = datetime.now(tz=timezone.utc) - if cert.has_expired() or not (pem_start <= now_utc <= pem_end): - raise RuntimeError(f"The SSL certificate has expired and needs to be replaced, {pem_path}") +def extract_zip(zip_file_path, extract_dir): + """ + Extract the contents of a ZIP archive to a specified directory. -def read_xlsx(file_path): - check_file_or_directory_path(file_path) + :param zip_file_path: Path to the ZIP archive + :param extract_dir: Directory to extract the contents to + """ + check_file_suffix(zip_file_path, FileCheckConst.ZIP_SUFFIX) + check_file_or_directory_path(zip_file_path) + create_directory(extract_dir) try: - result_df = pd.read_excel(file_path, keep_default_na=False) + proc_lock.acquire() + check_zip_file(zip_file_path) except Exception as e: - logger.error(f"The xlsx file failed to load. Please check the path: {file_path}.") - raise RuntimeError(f"Read xlsx file {file_path} failed.") from e - return result_df + logger.error(f'Save content to file "{os.path.basename(zip_file_path)}" failed.') + raise RuntimeError(f"Save content to file {os.path.basename(zip_file_path)} failed.") from e + finally: + proc_lock.release() + try: + with zipfile.ZipFile(zip_file_path, 'r') as zip_file: + zip_file.extractall(extract_dir) + except Exception as e: + raise RuntimeError(f"extract zip file {os.path.basename(zip_file_path)} failed") from e + recursive_chmod(extract_dir) + + +def split_zip_file_path(zip_file_path): + check_file_suffix(zip_file_path, FileCheckConst.ZIP_SUFFIX) + zip_file_path = os.path.realpath(zip_file_path) + return os.path.dirname(zip_file_path), os.path.basename(zip_file_path) + + +def dedup_log(func_name, filter_name): + with SharedDict() as shared_dict: + exist_names = shared_dict.get(func_name, set()) + if filter_name in exist_names: + return False + exist_names.add(filter_name) + shared_dict[func_name] = exist_names + return True + + +class SharedDict: + def __init__(self): + self._changed = False + self._dict = None + self._shm = None + + def __enter__(self): + self._load_shared_memory() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + try: + if self._changed: + data = pickle.dumps(self._dict) + global_lock.acquire() + try: + self._shm.buf[0:len(data)] = bytearray(data) + finally: + global_lock.release() + self._shm.close() + except FileNotFoundError: + name = self.get_shared_memory_name() + logger.debug(f'close shared memory {name} failed, shared memory has already been destroyed.') + + def __setitem__(self, key, value): + self._dict[key] = value + self._changed = True + + def __contains__(self, item): + return item in self._dict + + @classmethod + def destroy_shared_memory(cls): + if is_main_process(): + name = cls.get_shared_memory_name() + try: + shm = shared_memory.SharedMemory(create=False, name=name) + shm.close() + shm.unlink() + logger.debug(f'destroy shared memory, name: {name}') + except FileNotFoundError: + logger.debug(f'destroy shared memory {name} failed, shared memory has already been destroyed.') + + @classmethod + def get_shared_memory_name(cls): + if is_main_process(): + return f'shared_memory_{os.getpid()}' + return f'shared_memory_{os.getppid()}' + + def get(self, key, default=None): + return self._dict.get(key, default) + + def _load_shared_memory(self): + name = self.get_shared_memory_name() + try: + self._shm = shared_memory.SharedMemory(create=False, name=name) + except FileNotFoundError: + try: + # 共享内存空间增加至5M + self._shm = shared_memory.SharedMemory(create=True, name=name, size=1024 * 1024 * 5) + data = pickle.dumps({}) + self._shm.buf[0:len(data)] = bytearray(data) + logger.debug(f'create shared memory, name: {name}') + except FileExistsError: + self._shm = shared_memory.SharedMemory(create=False, name=name) + self._safe_load() + + def _safe_load(self): + with io.BytesIO(self._shm.buf[:]) as buff: + try: + self._dict = SafeUnpickler(buff).load() + except Exception as e: + logger.debug(f'shared dict is unreadable, reason: {e}, create new dict.') + self._dict = {} + self._shm.buf[:] = bytearray(b'\x00' * len(self._shm.buf)) # 清空内存 + self._changed = True + + +class SafeUnpickler(pickle.Unpickler): + WHITELIST = {'builtins': {'str', 'bool', 'int', 'float', 'list', 'set', 'dict'}} + + def find_class(self, module, name): + if module in self.WHITELIST and name in self.WHITELIST[module]: + return super().find_class(module, name) + raise pickle.PicklingError(f'Unpickling {module}.{name} is illegal!') + + +atexit.register(SharedDict.destroy_shared_memory) diff --git a/debug/accuracy_tools/msprobe/core/common/framework_adapter.py b/debug/accuracy_tools/msprobe/core/common/framework_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..5302b96aad7af62142deda1b55be28629f8e847d --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/common/framework_adapter.py @@ -0,0 +1,169 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.import functools +import functools +from msprobe.core.common.const import Const +from msprobe.core.common.file_utils import check_file_or_directory_path +from msprobe.core.common.file_utils import save_npy + + +class FrameworkDescriptor: + def __get__(self, instance, owner): + if owner._framework is None: + owner.import_framework() + return owner._framework + + +class FmkAdp: + fmk = Const.PT_FRAMEWORK + supported_fmk = [Const.PT_FRAMEWORK, Const.MS_FRAMEWORK] + supported_dtype_list = ["bfloat16", "float16", "float32", "float64"] + _framework = None + framework = FrameworkDescriptor() + + @classmethod + def import_framework(cls): + if cls.fmk == Const.PT_FRAMEWORK: + import torch + cls._framework = torch + elif cls.fmk == Const.MS_FRAMEWORK: + import mindspore + cls._framework = mindspore + else: + raise Exception(f"init framework adapter error, not in {cls.supported_fmk}") + + @classmethod + def set_fmk(cls, fmk=Const.PT_FRAMEWORK): + if fmk not in cls.supported_fmk: + raise Exception(f"init framework adapter error, not in {cls.supported_fmk}") + cls.fmk = fmk + cls._framework = None # 重置框架,以便下次访问时重新导入 + + @classmethod + def get_rank(cls): + if cls.fmk == Const.PT_FRAMEWORK: + return cls.framework.distributed.get_rank() + return cls.framework.communication.get_rank() + + @classmethod + def get_rank_id(cls): + if cls.is_initialized(): + return cls.get_rank() + return 0 + + @classmethod + def is_initialized(cls): + if cls.fmk == Const.PT_FRAMEWORK: + return cls.framework.distributed.is_initialized() + return cls.framework.communication.GlobalComm.INITED + + @classmethod + def is_nn_module(cls, module): + if cls.fmk == Const.PT_FRAMEWORK: + return isinstance(module, cls.framework.nn.Module) + return isinstance(module, cls.framework.nn.Cell) + + @classmethod + def is_tensor(cls, tensor): + if cls.fmk == Const.PT_FRAMEWORK: + return isinstance(tensor, cls.framework.Tensor) + return isinstance(tensor, cls.framework.Tensor) + + @classmethod + def process_tensor(cls, tensor, func): + if cls.fmk == Const.PT_FRAMEWORK: + if not tensor.is_floating_point() or tensor.dtype == cls.framework.float64: + tensor = tensor.float() + return float(func(tensor)) + return float(func(tensor).asnumpy()) + + @classmethod + def tensor_max(cls, tensor): + return cls.process_tensor(tensor, lambda x: x.max()) + + @classmethod + def tensor_min(cls, tensor): + return cls.process_tensor(tensor, lambda x: x.min()) + + @classmethod + def tensor_mean(cls, tensor): + return cls.process_tensor(tensor, lambda x: x.mean()) + + @classmethod + def tensor_norm(cls, tensor): + return cls.process_tensor(tensor, lambda x: x.norm()) + + @classmethod + def save_tensor(cls, tensor, filepath): + if cls.fmk == Const.PT_FRAMEWORK: + tensor_npy = tensor.cpu().detach().float().numpy() + else: + tensor_npy = tensor.asnumpy() + save_npy(tensor_npy, filepath) + + @classmethod + def dtype(cls, dtype_str): + if dtype_str not in cls.supported_dtype_list: + raise Exception(f"{dtype_str} is not supported by adapter, not in {cls.supported_dtype_list}") + return getattr(cls.framework, dtype_str) + + @classmethod + def named_parameters(cls, module): + if cls.fmk == Const.PT_FRAMEWORK: + if not isinstance(module, cls.framework.nn.Module): + raise Exception(f"{module} is not a torch.nn.Module") + return module.named_parameters() + if not isinstance(module, cls.framework.nn.Cell): + raise Exception(f"{module} is not a mindspore.nn.Cell") + return module.parameters_and_names() + + @classmethod + def register_forward_pre_hook(cls, module, hook, with_kwargs=False): + if cls.fmk == Const.PT_FRAMEWORK: + if not isinstance(module, cls.framework.nn.Module): + raise Exception(f"{module} is not a torch.nn.Module") + module.register_forward_pre_hook(hook, with_kwargs=with_kwargs) + else: + if not isinstance(module, cls.framework.nn.Cell): + raise Exception(f"{module} is not a mindspore.nn.Cell") + original_construct = module.construct + + @functools.wraps(original_construct) + def new_construct(*args, **kwargs): + if with_kwargs: + hook(module, args, kwargs) + else: + hook(module, args) + return original_construct(*args, **kwargs) + + module.construct = new_construct + + @classmethod + def load_checkpoint(cls, path, to_cpu=True, weights_only=True): + check_file_or_directory_path(path) + if cls.fmk == Const.PT_FRAMEWORK: + try: + if to_cpu: + return cls.framework.load(path, map_location=cls.framework.device("cpu"), weights_only=weights_only) + else: + return cls.framework.load(path, weights_only=weights_only) + except Exception as e: + raise RuntimeError(f"load pt file {path} failed: {e}") from e + return mindspore.load_checkpoint(path) + + @classmethod + def asnumpy(cls, tensor): + if cls.fmk == Const.PT_FRAMEWORK: + return tensor.float().numpy() + return tensor.float().asnumpy() \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/core/common/global_lock.py b/debug/accuracy_tools/msprobe/core/common/global_lock.py new file mode 100644 index 0000000000000000000000000000000000000000..2090f009ea5a78a7c5fbda61c12b6c0a842b7d25 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/common/global_lock.py @@ -0,0 +1,86 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import multiprocessing +from multiprocessing.shared_memory import SharedMemory +import random +import time +import atexit +import os + +from msprobe.core.common.log import logger + + +def is_main_process(): + return multiprocessing.current_process().name == 'MainProcess' + + +class GlobalLock: + def __init__(self): + self.name = self.get_lock_name() + try: + self._shm = SharedMemory(create=False, name=self.name) + time.sleep(random.randint(0, 500) / 10000) # 等待随机时长以避免同时获得锁 + except FileNotFoundError: + try: + self._shm = SharedMemory(create=True, name=self.name, size=1) + self._shm.buf[0] = 0 + logger.debug(f'{self.name} is created.') + except FileExistsError: + self.__init__() + + @classmethod + def get_lock_name(cls): + if is_main_process(): + return f'global_lock_{os.getpid()}' + return f'global_lock_{os.getppid()}' + + @classmethod + def is_lock_exist(cls): + try: + SharedMemory(create=False, name=cls.get_lock_name()).close() + return True + except FileNotFoundError: + return False + + def cleanup(self): + self._shm.close() + if is_main_process(): + try: + self._shm.unlink() + logger.debug(f'{self.name} is unlinked.') + except FileNotFoundError: + logger.warning(f'{self.name} has already been unlinked.') + + def acquire(self, timeout=180): + """ + acquire global lock, default timeout is 3 minutes. + + :param float timeout: timeout(seconds), default value is 180. + """ + start = time.time() + while time.time() - start < timeout: + if self._shm.buf[0] == 0: + self._shm.buf[0] = 1 + return + time.sleep(random.randint(10, 500) / 10000) # 自旋,等待1-50ms + self._shm.buf[0] = 1 + + def release(self): + self._shm.buf[0] = 0 + + +global_lock = GlobalLock() +atexit.register(global_lock.cleanup) diff --git a/debug/accuracy_tools/msprobe/core/common/log.py b/debug/accuracy_tools/msprobe/core/common/log.py index f20d25d991ef2d3da1307336e4aa05ec3bc87d86..4ce19e4961c62cad736ad90072de12cb44b4fb95 100644 --- a/debug/accuracy_tools/msprobe/core/common/log.py +++ b/debug/accuracy_tools/msprobe/core/common/log.py @@ -89,6 +89,13 @@ class BaseLogger: self.error(msg) raise exception + def warning_log_with_exp(self, msg, exception): + """ + 打印警告日志并抛出指定异常 + """ + self.warning(msg) + raise exception + def _print_log(self, level, msg, end='\n'): current_rank = self.get_rank() current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) diff --git a/debug/accuracy_tools/msprobe/core/common/parallel_state.py b/debug/accuracy_tools/msprobe/core/common/parallel_state.py new file mode 100644 index 0000000000000000000000000000000000000000..d4e48b53c309ccea6c5f3c430f474f957bf5171f --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/common/parallel_state.py @@ -0,0 +1,193 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +from msprobe.core.common.log import logger +from msprobe.core.common.exceptions import MsprobeException + + +class RankGroupGenerator(object): + def __init__(self, tensor_parallel: int, expert_parallel: int, data_parallel: int, + pipeline_parallel: int, context_parallel: int, order: str) -> None: + self.tensor_parallel = tensor_parallel + self.expert_parallel = expert_parallel + self.data_parallel = data_parallel + self.pipeline_parallel = pipeline_parallel + self.context_parallel = context_parallel + self.total_size = tensor_parallel * data_parallel * pipeline_parallel * context_parallel + + self.parallel_sizes = { + "tp": self.tensor_parallel, + "pp": self.pipeline_parallel, + "dp": self.data_parallel, + "ep": self.expert_parallel, + "cp": self.context_parallel, + } + self.original_order = order + normalized_order = order.lower() + + # 检查ep和dp是否相邻 + if 'ep' in normalized_order: + if 'ep-dp' not in normalized_order and 'dp-ep' not in normalized_order: + logger.error(f"The ep and dp must be adjacent in order ({self.original_order}).") + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + + # 检查所有非1的并行维度是否都在order中 + for name in self.parallel_sizes.keys(): + size = self.parallel_sizes[name] + if name not in normalized_order: + if size != 1: + logger.error(f"The parallel size ({name}) is ({size}), " + f"but it's not specified in order ({self.original_order}).") + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + else: + normalized_order += '-' + name + + self.order_with_ep = normalized_order + self.order_without_ep = '-'.join([item for item in normalized_order.split('-') if item != 'ep']) + + self.size_list_with_ep = [] + self.size_list_without_ep = [] + + for item in normalized_order.split('-'): + if item == 'dp': + self.size_list_with_ep.append(self.data_parallel // self.expert_parallel) + self.size_list_without_ep.append(self.data_parallel) + elif item == 'ep': + self.size_list_with_ep.append(self.expert_parallel) + else: + self.size_list_with_ep.append(self.parallel_sizes[item]) + self.size_list_without_ep.append(self.parallel_sizes[item]) + + @staticmethod + def create_mask(order_str: str, target_tokens: str) -> List[bool]: + order_elements = order_str.split('-') + target_elements = target_tokens.split('-') + mask = [False] * len(order_elements) + for token in target_elements: + mask[order_elements.index(token)] = True + return mask + + @staticmethod + def create_masked_rank_groups( + total_size: int, + parallel_dims: List[int], + mask: List[bool], + ) -> List[List[int]]: + def compute_prefix_products(dimensions: List[int], initial: int = 1) -> List[int]: + products = [initial] + current = initial + for dim in dimensions: + current *= dim + products.append(current) + return products + + def calculate_inner_product(a: List[int], b: List[int]) -> int: + return sum(x * y for x, y in zip(a, b)) + + def decompose_index(index: int, shape: List[int], strides: List[int] = None) -> List[int]: + if strides is None: + strides = compute_prefix_products(shape) + indices = [(index // stride) % dim for dim, stride in zip(shape, strides)] + + # 验证分解是否正确 + if calculate_inner_product(indices, strides[:-1]) != index: + error_msg = f"The index {index} with shape {shape} doesn't match decomposed indices {indices}." + logger.error(error_msg) + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + + return indices + + # 分离被掩码和未被掩码的维度 + masked_dims = [dim for dim, is_masked in zip(parallel_dims, mask) if is_masked] + unmasked_dims = [dim for dim, is_masked in zip(parallel_dims, mask) if not is_masked] + + # 计算全局、掩码和未掩码的步长 + global_strides = compute_prefix_products(parallel_dims) + masked_strides = [stride for stride, is_masked in zip(global_strides, mask) if is_masked] + unmasked_strides = [stride for stride, is_masked in zip(global_strides, mask) if not is_masked] + + # 计算组大小和组数 + group_dim = compute_prefix_products(masked_dims)[-1] + group_count = total_size // group_dim + + # 生成所有组的rank + rank_groups = [] + for group_idx in range(group_count): + decomposed_group = decompose_index(group_idx, unmasked_dims) + current_group = [] + for in_group_idx in range(group_dim): + decomposed_rank = decompose_index(in_group_idx, masked_dims) + rank_value = (calculate_inner_product(decomposed_rank, masked_strides) + + calculate_inner_product(decomposed_group, unmasked_strides)) + current_group.append(rank_value) + rank_groups.append(current_group) + + return rank_groups + + def generate_ranks(self, token: str, separate_ep: bool = False) -> List[List[int]]: + if separate_ep: + parallel_dims = self.size_list_with_ep + current_order = self.order_with_ep + else: + parallel_dims = self.size_list_without_ep + current_order = self.order_without_ep + + mask = self.create_mask(current_order, token) + return self.create_masked_rank_groups(self.total_size, parallel_dims, mask) + + def generate_all_ranks(self) -> dict: + result = {} + for token in ["dp", "pp", "tp"]: + result[token] = self.generate_ranks(token) + result[f"{token}_size"] = self.parallel_sizes[token] + return result + + +def get_tp_pp_default_groups( + total_world_size: int, + tensor_parallel_size: int = 1, + pipeline_parallel_size: int = 1, + order: str = "tp-cp-ep-dp-pp", +) -> tuple: + context_parallel_size = 1 + expert_parallel_size = 1 + + # 检查world_size是否可被各并行维度的乘积整除 + product = tensor_parallel_size * pipeline_parallel_size * context_parallel_size + if total_world_size % product != 0: + logger.error(f"The world size ({total_world_size}) is not divisible by " + f"{tensor_parallel_size} x {pipeline_parallel_size} x {context_parallel_size}.") + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + + data_parallel_size = total_world_size // product + + # 检查数据并行是否可被专家并行整除 + if data_parallel_size % expert_parallel_size != 0: + logger.error(f"The data parallel size ({data_parallel_size}) is not divisible by expert parallel size.") + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + + # 生成rank组 + rank_creator = RankGroupGenerator( + tensor_parallel=tensor_parallel_size, + expert_parallel=expert_parallel_size, + data_parallel=data_parallel_size, + pipeline_parallel=pipeline_parallel_size, + context_parallel=context_parallel_size, + order=order, + ) + + return rank_creator.generate_ranks('tp'), rank_creator.generate_ranks('pp') diff --git a/debug/accuracy_tools/msprobe/mindspore/runtime.py b/debug/accuracy_tools/msprobe/core/common/runtime.py similarity index 77% rename from debug/accuracy_tools/msprobe/mindspore/runtime.py rename to debug/accuracy_tools/msprobe/core/common/runtime.py index 0191a484cbc096b2e211b22b5abce147eac23b97..b905c5470e89c6f8dd3856c9b1ae22f78a22c8fe 100644 --- a/debug/accuracy_tools/msprobe/mindspore/runtime.py +++ b/debug/accuracy_tools/msprobe/core/common/runtime.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,7 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from msprobe.core.common.const import Const + + class Runtime: step_count: int = 0 rank_id: int = -1 is_running: bool = False + run_mode: str = Const.PYNATIVE_MODE + current_iter: int = 0 + current_rank: None diff --git a/debug/accuracy_tools/msprobe/core/common/utils.py b/debug/accuracy_tools/msprobe/core/common/utils.py index c06b5b64927bf47da1573df3b1d4db34dfa24cb1..83455b8e3a13b95428e0e878c8de331f6d2b0b4e 100644 --- a/debug/accuracy_tools/msprobe/core/common/utils.py +++ b/debug/accuracy_tools/msprobe/core/common/utils.py @@ -14,24 +14,29 @@ # limitations under the License. import collections +import functools +import inspect import os import re -import subprocess +import threading import time -from collections import defaultdict +from collections import OrderedDict from datetime import datetime, timezone -from functools import wraps import numpy as np -from msprobe.core.common.file_utils import (FileOpen, check_file_or_directory_path, load_json) from msprobe.core.common.const import Const, CompareConst -from msprobe.core.common.log import logger +from msprobe.core.common.decorator import recursion_depth_decorator from msprobe.core.common.exceptions import MsprobeException - +from msprobe.core.common.file_utils import (FileOpen, check_file_or_directory_path, load_json) +from msprobe.core.common.log import logger device = collections.namedtuple('device', ['type', 'index']) prefixes = ['api_stack', 'list', 'range', 'acl'] +file_suffix_to_file_type = { + "dump.json": Const.DUMP_JSON_FILE, + "debug.json": Const.DEBUG_JSON_FILE, +} class MsprobeBaseException(Exception): @@ -75,6 +80,10 @@ class MsprobeBaseException(Exception): MERGE_COMPARE_RESULT_ERROR = 33 NAMES_STRUCTS_MATCH_ERROR = 34 INVALID_STATE_ERROR = 35 + INVALID_API_NAME_ERROR = 36 + CROSS_FRAME_ERROR = 37 + MISSING_THRESHOLD_ERROR = 38 + WRONG_THRESHOLD_ERROR = 38 def __init__(self, code, error_info: str = ""): super(MsprobeBaseException, self).__init__() @@ -106,6 +115,82 @@ class DumpException(MsprobeBaseException): return f"Dump Error Code {self.code}: {self.error_info}" +class ThreadSafe: + """ + 线程安全控制工具类,提供三种使用方式: + 1.上下文管理器:with ThreadSafe() + 2.主动加锁与释放锁:ThreadSafe.acquire()/ThreadSafe.release() + 3.方法装饰器:@ThreadSafe.synchronized + """ + _lock = threading.RLock() + + def __enter__(self): + self.__class__._lock.acquire() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.__class__._lock.release() + + @classmethod + def acquire(cls): + cls._lock.acquire() + + @classmethod + def release(cls): + cls._lock.release() + + @classmethod + def synchronized(cls, func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + with cls._lock: + return func(*args, **kwargs) + + return wrapper + + +class ModuleQueue: + def __init__(self): + self.queue = OrderedDict() + + def add_name(self, name): + self.queue[name] = True + + def remove_name(self, name): + if name in self.queue: + del self.queue[name] + + def find_last(self, name): + """ + 在队列中找到当前 Module/Cell 的父节点名称并返回,若找不到则返回None + + Args: + name: 需要寻找父节点的 Module/Cell 的名称 + + Returns: + 返回父节点名称,找不到则返回None + + Examples: + 父节点名称格式: Module.module1.module1.forward.0 + 子节点名称格式: Module.module1.module2.Module2.forward.0 + 匹配关系: Module/Cell 的名称总能被点(.)分割符分成5个部分及以上,子节点截断后4个点和父节点截断后3个点的前缀名称是匹配的 + """ + child_parts = name.split('.') + if len(child_parts) < 5: + return None + child_name_prefix = '.'.join(child_parts[:-4]) + if child_name_prefix in Const.MODULE_PREFIX: + return None + + for parent_name in reversed(self.queue): + parent_parts = parent_name.split('.') + if len(parent_parts) < 5: + return None + parent_name_prefix = '.'.join(parent_parts[:-3]) + if parent_name_prefix == child_name_prefix: + return parent_name + return None + + def is_json_file(file_path): if isinstance(file_path, str) and file_path.lower().endswith('.json'): return True @@ -148,14 +233,6 @@ def check_compare_param(input_param, output_path, dump_mode, stack_mode): _check_json(stack_json, input_param.get("stack_json_path")) -def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False, is_print_compare_log=True): - arg_list = [stack_mode, auto_analyze, fuzzy_match, is_print_compare_log] - for arg in arg_list: - if not isinstance(arg, bool): - logger.error(f"Invalid input parameter, {arg} which should be only bool type.") - raise CompareException(CompareException.INVALID_PARAM_ERROR) - - def _check_json(json_file_handle, file_name): tensor_line = json_file_handle.readline() if not tensor_line: @@ -191,27 +268,6 @@ def check_regex_prefix_format_valid(prefix): raise ValueError(f"prefix contains invalid characters, prefix pattern {Const.REGEX_PREFIX_PATTERN}") -def execute_command(cmd): - """ - Function Description: - run the following command - Parameter: - cmd: command - Exception Description: - when invalid command throw exception - """ - logger.info('Execute command:%s' % cmd) - process = subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - while process.poll() is None: - line = process.stdout.readline() - line = line.strip() - if line: - logger.info(line) - if process.returncode != 0: - logger.error('Failed to execute command:%s' % " ".join(cmd)) - raise CompareException(CompareException.INVALID_DATA_ERROR) - - def add_time_as_suffix(name): return '{}_{}.csv'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))) @@ -220,6 +276,10 @@ def add_time_with_xlsx(name): return '{}_{}.xlsx'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))) +def add_time_with_json(name): + return '{}_{}.json'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))) + + def add_time_with_yaml(name): return '{}_{}.yaml'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))) @@ -232,21 +292,41 @@ def format_value(value): return float('{:.12f}'.format(value)) -def md5_find(data): - for key_op in data: - for api_info in data[key_op]: - if isinstance(data[key_op][api_info], list): - for data_detail in data[key_op][api_info]: - if data_detail and 'md5' in data_detail: - return True - if isinstance(data[key_op][api_info], bool): - continue - elif data[key_op][api_info] and 'md5' in data[key_op][api_info]: +@recursion_depth_decorator('msprobe.core.common.utils.md5_find', max_depth=Const.DUMP_MAX_DEPTH) +def md5_find(data, json_type=Const.DUMP_JSON_FILE): + if json_type == Const.DUMP_JSON_FILE: + for key_op in data: + for api_info in data[key_op]: + if isinstance(data[key_op][api_info], list): + for data_detail in data[key_op][api_info]: + if data_detail and Const.MD5 in data_detail: + return True + if isinstance(data[key_op][api_info], bool): + continue + elif data[key_op][api_info] and Const.MD5 in data[key_op][api_info]: + return True + elif json_type == Const.DEBUG_JSON_FILE: + if isinstance(data, dict): + if Const.MD5 in data: return True + else: + for _, data_info in data.items(): + if md5_find(data_info, Const.DEBUG_JSON_FILE): + return True + elif isinstance(data, list): + for data_info in data: + if md5_find(data_info, Const.DEBUG_JSON_FILE): + return True + else: + return False return False def detect_framework_by_dump_json(file_path): + json_data = load_json(file_path) + framework = json_data.get("framework", None) + if framework in [Const.PT_FRAMEWORK, Const.MS_FRAMEWORK]: + return framework pattern_ms = r'"type":\s*"mindspore' pattern_pt = r'"type":\s*"torch' with FileOpen(file_path, 'r') as file: @@ -276,13 +356,41 @@ def get_stack_construct_by_dump_json_path(dump_json_path): def set_dump_path(input_param): npu_path = input_param.get("npu_json_path", None) bench_path = input_param.get("bench_json_path", None) - npu_path_valid = npu_path is not None and npu_path.endswith("dump.json") - bench_path_valid = bench_path is not None and bench_path.endswith("dump.json") - if not npu_path_valid or not bench_path_valid: - logger.error(f"Please check the json path is valid. npu_path: {npu_path}, bench_path: {bench_path}") + dump_json_path_valid = npu_path is not None and npu_path.endswith("dump.json") and \ + bench_path is not None and bench_path.endswith("dump.json") + debug_json_path_valid = npu_path is not None and npu_path.endswith("debug.json") and \ + bench_path is not None and bench_path.endswith("debug.json") + if not dump_json_path_valid and not debug_json_path_valid: + logger.error(f"Please check the json path is valid and ensure that neither npu_path nor bench_path is None.") + raise CompareException(CompareException.INVALID_PATH_ERROR) + input_param[CompareConst.NPU_DUMP_DATA_DIR] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA) + input_param[CompareConst.BENCH_DUMP_DATA_DIR] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA) + + +def get_file_type(file_path): + if not isinstance(file_path, str): + logger.error("get_file_type failed, check the type of file_path.") + raise CompareException(CompareException.INVALID_PATH_ERROR) + file_type = file_suffix_to_file_type.get(file_path.split(Const.SCOPE_SEPARATOR)[-1]) + if file_type is None: + logger.error("get_file_type failed, file_path is neither dump.json nor debug.json.") raise CompareException(CompareException.INVALID_PATH_ERROR) - input_param['npu_dump_data_dir'] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA) - input_param['bench_dump_data_dir'] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA) + return file_type + + +def check_dump_json_key(json_data, device_type): + task = json_data.get('task', None) + if not task: + logger.error(f"Task for {device_type} is empty, please check.") + raise CompareException(CompareException.INVALID_TASK_ERROR) + if 'data' not in json_data: + logger.error(f"Missing 'data' in dump.json, please check dump.json of {device_type}.") + raise CompareException(CompareException.INVALID_DATA_ERROR) + api_data = json_data.get('data') + if not isinstance(api_data, dict): + logger.error(f"Invalid type for 'data': expected a dict. Please check dump.json of {device_type}.") + raise CompareException(CompareException.INVALID_DATA_ERROR) + return task, api_data def get_dump_mode(input_param): @@ -290,13 +398,10 @@ def get_dump_mode(input_param): bench_path = input_param.get("bench_json_path", None) npu_json_data = load_json(npu_path) bench_json_data = load_json(bench_path) + json_type = get_file_type(file_path=npu_path) - npu_task = npu_json_data.get('task', None) - bench_task = bench_json_data.get('task', None) - - if not npu_task or not bench_task: - logger.error(f"Please check the dump task is correct, npu's task is {npu_task}, bench's task is {bench_task}.") - raise CompareException(CompareException.INVALID_TASK_ERROR) + npu_task, npu_api_data = check_dump_json_key(npu_json_data, 'npu') + bench_task, bench_api_data = check_dump_json_key(bench_json_data, 'bench') if npu_task != bench_task: logger.error(f"Please check the dump task is consistent.") @@ -309,8 +414,8 @@ def get_dump_mode(input_param): return Const.STRUCTURE if npu_task == Const.STATISTICS: - npu_md5_compare = md5_find(npu_json_data['data']) - bench_md5_compare = md5_find(bench_json_data['data']) + npu_md5_compare = md5_find(npu_api_data, json_type) + bench_md5_compare = md5_find(bench_api_data, json_type) if npu_md5_compare == bench_md5_compare: return Const.MD5 if npu_md5_compare else Const.SUMMARY else: @@ -424,6 +529,37 @@ def get_real_step_or_rank(step_or_rank_input, obj): return real_step_or_rank +def check_init_step(step): + if not is_int(step): + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, + f"{step} must be an integer") + if not step >= 0: + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, + f"{step} must be greater than or equal to 0") + + +def check_token_range(token_range): + if token_range is None: + return + if not isinstance(token_range, (list, tuple)): + logger.error("Token_range must be a list or tuple.") + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + if len(token_range) != 2: + logger.error("Token_range must contains exactly 2 elements.") + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + + start, end = token_range + if not is_int(start) or not is_int(end): + logger.error("Start and end in token_range must be integer.") + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + if start > end: + logger.error("Start in token_range must less than the end.") + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + if start < 0: + logger.error("Start in token_range must >= 0.") + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + + def check_seed_all(seed, mode, rm_dropout): if is_int(seed): if seed < 0 or seed > Const.MAX_SEED_VALUE: @@ -467,36 +603,6 @@ def safe_get_value(container, index, container_name, key=None): raise MsprobeBaseException(MsprobeBaseException.INVALID_OBJECT_TYPE_ERROR) from e -# 记录工具函数递归的深度 -recursion_depth = defaultdict(int) - - -# 装饰一个函数,当函数递归调用超过限制时,抛出异常并打印函数信息。 -def recursion_depth_decorator(func_info): - def decorator(func): - @wraps(func) - def wrapper(*args, **kwargs): - func_id = id(func) - recursion_depth[func_id] += 1 - if recursion_depth[func_id] > Const.MAX_DEPTH: - msg = f"call {func_info} exceeds the recursion limit." - logger.error_log_with_exp( - msg, - MsprobeException( - MsprobeException.RECURSION_LIMIT_ERROR, msg - ), - ) - try: - result = func(*args, **kwargs) - finally: - recursion_depth[func_id] -= 1 - return result - - return wrapper - - return decorator - - def check_str_param(param): if not re.match(Const.REGEX_PREFIX_PATTERN, param): logger.error('The parameter {} contains special characters.'.format(param)) @@ -509,4 +615,85 @@ class DumpPathAggregation: construct_file_path = None dump_tensor_data_dir = None free_benchmark_file_path = None - debug_file_path = None \ No newline at end of file + debug_file_path = None + + +def is_save_variable_valid(variable, valid_special_types, depth=0): + if depth > Const.DUMP_MAX_DEPTH: + return False + if isinstance(variable, valid_special_types): + return True + elif isinstance(variable, (list, tuple)): + return all(is_save_variable_valid(item, valid_special_types, depth + 1) for item in variable) + elif isinstance(variable, dict): + return all(isinstance(key, str) and is_save_variable_valid(value, valid_special_types, depth + 1) + for key, value in variable.items()) + else: + return False + + +def replace_last_occurrence(text, old, new): + if text is None: + return text + index = text.rfind(old) + if index != -1: + return text[:index] + text[index:].replace(old, new, 1) + return text + + +def load_stack_json(stack_path): + stack_dict = load_json(stack_path) + + if not isinstance(stack_dict, dict): + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, + "The format of the stack.json is incorrect, the outermost layer of stack.json should be a dict type." + ) + + if not stack_dict.get(Const.NEW_STACK_FLAG): + return stack_dict + + new_stack_dict = {} + for stack_info in stack_dict.values(): + if not isinstance(stack_info, list) or len(stack_info) != 2: + continue + + api_list, stack_str = stack_info + if not isinstance(api_list, list): + continue + + for api_name in api_list: + new_stack_dict.update({api_name: stack_str}) + return new_stack_dict + + +def analyze_api_call_stack(name): + try: + api_stack = inspect.stack()[2:] + except Exception as e: + logger.warning(f"The call stack of {name} failed to retrieve, {e}.") + api_stack = None + stack_str = [] + if api_stack: + for (_, path, line, func, code, _) in api_stack: + if not code: + continue + stack_line = f"File {path}, line {str(line)}, in {func}, \n {code[0].strip()} \n" + stack_str.append(stack_line) + else: + stack_str.append(Const.WITHOUT_CALL_STACK) + return "".join(stack_str) + + +def check_extern_input_list(input_list): + if not isinstance(input_list, list): + raise Exception("input is not a list") + if len(input_list) > Const.EXTERN_INPUT_LIST_MAX_LEN: + raise Exception(f"input list exceed max length {Const.EXTERN_INPUT_LIST_MAX_LEN}") + + +def check_process_num(process_num): + if not is_int(process_num) or process_num <= 0: + raise ValueError(f"process_num({process_num}) is not a positive integer") + if process_num > Const.MAX_PROCESS_NUM: + raise ValueError(f"The maximum supported process_num is {Const.MAX_PROCESS_NUM}, current value: {process_num}.") diff --git a/debug/accuracy_tools/msprobe/core/common_config.py b/debug/accuracy_tools/msprobe/core/common_config.py index b9a717c0c52f11e52ac055e3cfe6a0e77fe7e44c..56a8a9afe704cd5e79d5e8f0887538c93a61f51e 100644 --- a/debug/accuracy_tools/msprobe/core/common_config.py +++ b/debug/accuracy_tools/msprobe/core/common_config.py @@ -13,7 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from msprobe.core.common.const import Const, FileCheckConst +import re + +from msprobe.core.common.const import Const from msprobe.core.common.log import logger from msprobe.core.common.exceptions import MsprobeException from msprobe.core.common.utils import get_real_step_or_rank @@ -28,6 +30,7 @@ class CommonConfig: self.level = json_config.get('level') self.enable_dataloader = json_config.get('enable_dataloader', False) self.async_dump = json_config.get("async_dump", False) + self.precision = json_config.get("precision", Const.DUMP_PRECISION_LOW) self._check_config() def _check_config(self): @@ -49,6 +52,10 @@ class CommonConfig: elif self.async_dump: logger.warning("async_dump is True, it may cause OOM when dumping large tensor.") + if self.precision not in Const.DUMP_PRECISION_LIST: + logger.error_log_with_exp("precision is invalid, it should be one of {}".format(Const.DUMP_PRECISION_LIST), + MsprobeException(MsprobeException.INVALID_PARAM_ERROR)) + class BaseConfig: def __init__(self, json_config): @@ -67,6 +74,7 @@ class BaseConfig: self.if_preheat = json_config.get("if_preheat") self.preheat_step = json_config.get("preheat_step") self.max_sample = json_config.get("max_sample") + self.is_regex_valid = True @staticmethod def _check_str_list_config(config_item, config_name): @@ -83,6 +91,7 @@ class BaseConfig: self._check_str_list_config(self.scope, "scope") self._check_str_list_config(self.list, "list") self._check_data_mode() + self._check_regex_in_list() def _check_data_mode(self): if self.data_mode is not None: @@ -111,3 +120,20 @@ class BaseConfig: f"The element '{mode}' of data_mode {self.data_mode} is not in {Const.DUMP_DATA_MODE_LIST}.", MsprobeException(MsprobeException.INVALID_PARAM_ERROR) ) + + def _check_summary_mode(self): + if self.summary_mode and self.summary_mode not in Const.SUMMARY_MODE: + logger.error_log_with_exp( + f"summary_mode is invalid, summary_mode is not in {Const.SUMMARY_MODE}.", + MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + ) + + def _check_regex_in_list(self): + if self.list: + for name in self.list: + if name.startswith('name-regex(') and name.endswith(')'): + try: + re.compile(name[len('name-regex('):-1]) + except re.error: + self.is_regex_valid = False + break diff --git a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py index 55229d72657c67428186bcb233371e3b9eee73e0..21869ba870fac0ce02f9dabe32b59b03a5f70ea7 100644 --- a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py +++ b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py @@ -13,537 +13,783 @@ # See the License for the specific language governing permissions and # limitations under the License. -import multiprocessing import os import re -from copy import deepcopy +from dataclasses import dataclass +from collections import defaultdict +import numpy as np import pandas as pd from tqdm import tqdm from msprobe.core.advisor.advisor import Advisor from msprobe.core.common.const import CompareConst, Const from msprobe.core.common.exceptions import FileCheckException -from msprobe.core.common.file_utils import load_json, remove_path +from msprobe.core.common.file_utils import load_json, remove_path, create_directory, save_excel, save_json from msprobe.core.common.log import logger -from msprobe.core.common.utils import CompareException, add_time_with_xlsx, check_op_str_pattern_valid, safe_get_value -from msprobe.core.compare.check import check_dump_json_str, check_graph_mode, check_stack_json_str, \ - check_struct_match, fuzzy_check_op -from msprobe.core.compare.highlight import find_compare_result_error_rows, highlight_rows_xlsx -from msprobe.core.compare.multiprocessing_compute import ComparisonResult, _handle_multi_process, _save_cmp_result -from msprobe.core.compare.npy_compare import compare_ops_apply, get_error_flag_and_msg -from msprobe.core.compare.utils import get_accuracy, get_rela_diff_summary_mode, get_un_match_accuracy, merge_tensor, \ - print_compare_ends_info, read_op, get_name_and_state, reorder_op_x_list +from msprobe.core.common.utils import CompareException, add_time_with_xlsx, check_op_str_pattern_valid, \ + set_dump_path, get_dump_mode, check_compare_param, load_stack_json, get_file_type, add_time_with_json +from msprobe.core.compare.check import check_dump_json_str, check_stack_json_str, cross_dtype_mapping, \ + check_configuration_param +from msprobe.core.compare.utils import merge_tensor, print_compare_ends_info, read_op, set_stack_json_path, \ + reorder_index +from msprobe.core.compare.config import ModeConfig, MappingConfig, MappingDict +from msprobe.core.compare.multiprocessing_compute import CompareRealData +from msprobe.core.compare.highlight import HighLight +from msprobe.core.compare.diff_analyze.first_diff_analyze import FirstDiffAnalyze + + +@dataclass +class ComparisonConfig: + dump_mode: str + stack_mode: bool + auto_analyze: bool + fuzzy_match: bool + highlight: bool + data_mapping: dict + suffix: str + cell_mapping: dict + api_mapping: dict + layer_mapping: dict + compared_file_type: str + first_diff_analyze: bool + is_print_compare_log: bool -class ModeConfig: - def __init__(self, stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=None): - self.stack_mode = stack_mode - self.auto_analyze = auto_analyze - self.fuzzy_match = fuzzy_match - self.dump_mode = dump_mode +class Comparator: + def __init__(self, file_reader, mode_config: ModeConfig, mapping_config: MappingConfig, is_cross_framework=False): + self.file_reader = file_reader + self.mode_config = mode_config + self.mapping_config = mapping_config + self.cross_frame = is_cross_framework + self.mapping_dict = MappingDict(mapping_config) + + def process_output_file(self, output_path, suffix, compared_file_type): + file_name_prefix_mapping = { + Const.DUMP_JSON_FILE: "compare_result", + Const.DEBUG_JSON_FILE: "debug_compare_result" + } + file_name_prefix = file_name_prefix_mapping.get(compared_file_type, "compare_result") + if self.mode_config.first_diff_analyze: + file_name = add_time_with_json("compare_result" + suffix) + else: + file_name = add_time_with_xlsx(file_name_prefix + suffix) + file_path = os.path.join(os.path.realpath(output_path), file_name) + if os.path.exists(file_path): + logger.warning(f"{file_path} will be deleted.") + remove_path(file_path) + return file_path + def compare_core(self, input_param, output_path, **kwargs): + """ + Compares data from multiple JSON files and generates a comparison report. -class Comparator: - def __init__(self, mode_config: ModeConfig): - self.stack_mode = mode_config.stack_mode - self.auto_analyze = mode_config.auto_analyze - self.fuzzy_match = mode_config.fuzzy_match - self.dump_mode = mode_config.dump_mode + Args: + input_param (dict): A dictionary containing paths to JSON files ("npu_path", "bench_path", + "stack_path"). + output_path (str): The path where the output Excel report will be saved. + **kwargs: Additional keyword arguments including: + - stack_mode (bool, optional): Enables stack mode comparison. Defaults to False. + - auto_analyze (bool, optional): If True, triggers automatic analysis after comparison. Defaults to True. + - suffix (str, optional): Suffix to append to the output file name. Defaults to ''. + - fuzzy_match (bool, optional): Enables fuzzy matching during comparison. Defaults to False. + - dump_mode (str): ALL, SUMMARY, MD5. - @staticmethod - def get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all, bench_ops_all, *args): - npu_struct = npu_ops_all.get(ms_op_name).get('struct', []) - bench_struct = bench_ops_all.get(bench_op_name).get('struct', []) + Returns: + """ + logger.info("Please check whether the input data belongs to you. If not, there may be security risks.") + + # get kwargs or set default value + suffix = kwargs.get('suffix', '') + rank = suffix[1:] - if len(npu_struct) < 3 or len(bench_struct) < 3: - logger.error(f"The length of npu_struct and bench_struct must be >= 3, " - f"but got npu_struct={len(npu_struct)} and bench_struct={len(bench_struct)}. Please check!") - raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) + # process output file + file_path = self.process_output_file(output_path, suffix, self.mode_config.compared_file_type) + + # initialize the compare result table and compare general data(name, dtype, shape, statistics/md5, etc.) + npu_json = input_param.get("npu_json_path") + bench_json = input_param.get("bench_json_path") + stack_json = input_param.get("stack_json_path") + parse_data = ParseData(self.mode_config, rank) # load and parse json data + npu_df, bench_df = parse_data.parse([npu_json, bench_json, stack_json]) + result_df = self.compare_statistics(npu_df, bench_df) + if not result_df.values.tolist(): + logger.warning("Can`t match any op. No compare result file generated.") + return - result_item = [ms_op_name, bench_op_name, npu_struct[0], bench_struct[0], - npu_struct[1], bench_struct[1], npu_struct[2], bench_struct[2], - CompareConst.PASS if npu_struct[2] == bench_struct[2] else CompareConst.DIFF] + if self.mode_config.first_diff_analyze: + # add P2POp additional info from npu_df and bench_df to result_df + result_df['NPU P2POp op'] = npu_df['op'] + result_df['Bench P2POp op'] = bench_df['op'] + result_df['NPU P2POp peer'] = npu_df['peer'] + result_df['Bench P2POp peer'] = bench_df['peer'] + + first_diff_analyze = FirstDiffAnalyze(self.mode_config, rank) + check_result = first_diff_analyze.check(result_df) + save_json(file_path, check_result, indent=4) + logger.info(f"Saving json file to disk: {file_path}") + return - if len(args) >= 2 and args[0]: - result_item.extend(args[1]) + # compare real data + if self.mode_config.dump_mode == Const.ALL: + compare_real_data = CompareRealData(self.file_reader, self.mode_config, self.cross_frame) + result_df = compare_real_data.do_multi_process(input_param, result_df) + + # save result excel file + logger.info(f'Saving result excel file in progress. The file path is: {file_path}.') + if self.mode_config.highlight and len(result_df) <= CompareConst.MAX_EXCEL_LENGTH: + # highlight if not too long + highlight_dict = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []} + highlight = HighLight(self.mode_config, rank) + if self.mode_config.compared_file_type == Const.DUMP_JSON_FILE: + highlight.find_compare_result_error_rows(result_df, highlight_dict) + result_df.drop(columns=['state', 'api_origin_name'], inplace=True) # 删除中间数据,两列不落盘 + highlight.highlight_rows_xlsx(result_df, highlight_dict, file_path) else: - result_item.append(CompareConst.NONE) - return result_item + # fallback to simple save without highlight + result_df.drop(columns=['state', 'api_origin_name'], inplace=True) # 删除中间数据,两列不落盘 + save_excel(file_path, result_df) - @staticmethod - def calculate_summary_data(npu_summary_data, bench_summary_data, result_item): - err_msg = "" - result_item, accuracy_check, err_msg = get_rela_diff_summary_mode(result_item, npu_summary_data, - bench_summary_data, err_msg) - result_item.append(accuracy_check) - result_item.append(err_msg) + # output compare analysis suggestions + if self.mode_config.auto_analyze: + advisor = Advisor(result_df, output_path, suffix) + advisor.analysis() - @staticmethod - def _generate_na_data(ops_all): - if not ops_all: - return {} - key = next(iter(ops_all)) - value = deepcopy(ops_all[key]) - for k, v in value.items(): - if isinstance(v, tuple): - value[k] = tuple(CompareConst.N_A for _ in range(len(v))) - elif isinstance(v, list): - value[k] = [CompareConst.N_A] * len(v) - else: - value[k] = CompareConst.N_A - return value + print_compare_ends_info() + + def compare_statistics(self, npu_df, bench_df): + npu_df[[Const.DTYPE, Const.SHAPE]] = npu_df[[Const.DTYPE, Const.SHAPE]].astype(str) + bench_df[[Const.DTYPE, Const.SHAPE]] = bench_df[[Const.DTYPE, Const.SHAPE]].astype(str) + + # create new columns for compare op_name and shape + # process npu_df's COMPARE_KEY whether same or different framework + process_df = ProcessDf(self.mode_config, self.mapping_config, self.mapping_dict) + npu_df, bench_df = process_df.process_compare_key_and_shape(npu_df, bench_df) + + # match npu and bench, match_result contains both npu_info and bench_info + match = Match(self.mode_config, self.mapping_config, self.cross_frame) + match_result = match.match_api_infos(npu_df, bench_df) + # 筛选出npu_name存在的行并填充筛选出行中的缺失值为N/A + match_result = match_result[match_result['op_name_x'].notna()].fillna(CompareConst.N_A) + bench_columns = [i + '_y' for i in bench_df.columns] + match_result.loc[~match.gen_dtype_condition(match_result), bench_columns] = CompareConst.N_A + + # organize compare result table by renaming columns + if self.mode_config.dump_mode == Const.ALL and self.mode_config.first_diff_analyze: + self.mode_config.dump_mode = Const.SUMMARY + create_table = CreateTable(self.mode_config) + result_df, header = create_table.make_result_df(match_result) + + # calculate statistics diff + calc_stats_diff = CalcStatsDiff(self.mode_config) + return calc_stats_diff.calc_accuracy(result_df, header) + + +class ParseData: + def __init__(self, mode_config: ModeConfig, rank): + self.mode_config = mode_config + self.rank = rank + + def parse(self, file_list): + npu_json_path, bench_json_path, stack_json_path = file_list + npu_json_data = load_json(npu_json_path) + bench_json_data = load_json(bench_json_path) + stack_json_data = load_stack_json(stack_json_path) if self.mode_config.stack_mode else None + + # parse json data and generate df + npu_df = self.gen_data_df(npu_json_data, stack_json_data, 'NPU') + bench_df = self.gen_data_df(bench_json_data, stack_json_data, 'Bench') + + return npu_df, bench_df + + def gen_data_df(self, data_json, stack_json_data, device: str): + result = { + CompareConst.OP_NAME: [], + Const.DTYPE: [], + Const.SHAPE: [], + Const.SUMMARY: [], + Const.STACK_INFO: [], + Const.STATE: [], + Const.API_ORIGIN_NAME: [], + Const.REQ_GRAD: [] + } + if self.mode_config.dump_mode == Const.ALL: + result[Const.DATA_NAME] = [] + elif self.mode_config.dump_mode == Const.MD5: + result[Const.MD5] = [] + + apis_data = data_json.get('data', None) + if not apis_data: + logger.warning('No APIs found in dump.json.') + return pd.DataFrame(result) + + api_nums = len(apis_data) + default_bar_desc = f'{device} API/Module Read Progress' + bar_desc_add_rank = f'[{self.rank}]' + default_bar_desc if self.rank else default_bar_desc + progress_bar = tqdm(total=api_nums, desc=bar_desc_add_rank, unit="api/module", ncols=100) + + # 从json中循环解析API数据,遍历所有API + for data_name in apis_data: + check_op_str_pattern_valid(data_name) + op_parsed_list = self.gen_merge_list(data_json, data_name, stack_json_data) + if not op_parsed_list: + continue + reordered_index_list = reorder_index(op_parsed_list) + for i, index in enumerate(reordered_index_list): + op_item = op_parsed_list[index] + + # common key + result[CompareConst.OP_NAME].append(op_item.get('full_op_name')) + result[Const.DTYPE].append(op_item.get(Const.DTYPE)) + result[Const.SHAPE].append(op_item.get(Const.SHAPE)) + result[Const.STATE].append(op_item.get(Const.STATE)) + result[Const.REQ_GRAD].append(op_item.get(Const.REQ_GRAD)) + result[Const.API_ORIGIN_NAME].append(data_name) + summary_data = [ + str(op_item.get(key)) if op_item.get(key) is None else op_item.get(key) + for key in Const.SUMMARY_METRICS_LIST + ] + result[Const.SUMMARY].append(summary_data) - def make_result_table(self, result): - header = CompareConst.HEAD_OF_COMPARE_MODE[self.dump_mode][:] + # dump_mode differ key + if self.mode_config.dump_mode == Const.MD5: + result[Const.MD5].append(op_parsed_list[index].get(Const.MD5)) + if self.mode_config.dump_mode == Const.ALL: + result[Const.DATA_NAME].append(op_item.get(Const.DATA_NAME)) - if self.stack_mode: - header.append(CompareConst.STACK) - if self.dump_mode == Const.ALL: - header.append(CompareConst.DATA_NAME) - else: - if self.dump_mode == Const.ALL: - for row in result: - del row[-2] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,真实数据时为倒数第2列 - header.append(CompareConst.DATA_NAME) - else: - for row in result: - del row[-1] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,非真实数据时为倒数第1列 - result_df = pd.DataFrame(result, columns=header, dtype='object') - return result_df + # mode_config stack_mode addition key + if i == 0 and self.mode_config.stack_mode: + result[Const.STACK_INFO].append(op_parsed_list[-1].get('full_info')) + else: + result[Const.STACK_INFO].append(None) + + # mode_config first_diff_analyze addition key + if self.mode_config.first_diff_analyze: + result.setdefault('op', []).append(op_item.get('op', str(None))) + result.setdefault('peer', []).append(op_item.get('peer', str(None))) + + progress_bar.update(1) + progress_bar.close() + return pd.DataFrame(result) def gen_merge_list(self, json_data, op_name, stack_json_data): op_data = json_data['data'][op_name] - check_dump_json_str(op_data, op_name) + if self.mode_config.compared_file_type == Const.DUMP_JSON_FILE: + check_dump_json_str(op_data, op_name) op_parsed_list = read_op(op_data, op_name) - if self.stack_mode: + if self.mode_config.stack_mode: stack_info = stack_json_data.get(op_name) if stack_info is not None: check_stack_json_str(stack_info, op_name) - # append only when stack_mode is True, - op_parsed_list.append({ - 'full_op_name': op_name, - 'full_info': stack_info - }) - - merge_list = merge_tensor(op_parsed_list, self.dump_mode) - return merge_list - - def check_op(self, npu_dict, bench_dict): - npu_op_name = npu_dict[CompareConst.OP_NAME] - bench_op_name = bench_dict[CompareConst.OP_NAME] - graph_mode = check_graph_mode(safe_get_value(npu_op_name, 0, "npu_op_name"), - safe_get_value(bench_op_name, 0, "bench_op_name")) - - frame_name = getattr(self, "frame_name") - if frame_name == "PTComparator": - from msprobe.pytorch.compare.match import graph_mapping - if graph_mode: - return graph_mapping.match(npu_op_name[0], bench_op_name[0]) - struct_match = check_struct_match(npu_dict, bench_dict) - if not self.fuzzy_match: - name_match = npu_op_name == bench_op_name - return name_match and struct_match - try: - name_match = fuzzy_check_op(npu_op_name, bench_op_name) - except Exception as err: - logger.warning("%s and %s can not fuzzy match." % (npu_op_name, bench_op_name)) - name_match = False - return name_match and struct_match + else: + stack_info = None + # always add stack_info whether stack_mode is True + op_parsed_list.append({ + 'full_op_name': op_name, + 'full_info': stack_info + }) + return op_parsed_list - def match_op(self, npu_queue, bench_queue): - for b_index, b_op in enumerate(bench_queue[0: -1]): - if self.check_op(npu_queue[-1], b_op): - return len(npu_queue) - 1, b_index - if self.check_op(npu_queue[-1], bench_queue[-1]): - return len(npu_queue) - 1, len(bench_queue) - 1 - for n_index, n_op in enumerate(npu_queue[0: -1]): - if self.check_op(n_op, bench_queue[-1]): - return n_index, len(bench_queue) - 1 - return -1, -1 - def compare_process(self, file_lists): - npu_json_path, bench_json_path, stack_json_path = file_lists - npu_json_data = load_json(npu_json_path) - bench_json_data = load_json(bench_json_path) - stack_json_data = load_json(stack_json_path) if self.stack_mode else None +class ProcessDf: + def __init__(self, mode_config: ModeConfig, mapping_config: MappingConfig, mapping_dict: MappingDict): + self.mode_config = mode_config + self.mapping_config = mapping_config + self.mapping_dict = mapping_dict - if self.fuzzy_match: - logger.warning("This task uses fuzzy matching, which may affect the accuracy of the comparison.") + @staticmethod + def get_api_name(api_list): + try: + api_name = api_list[0] + Const.SEP + api_list[1] + except IndexError as error: + logger.error('Failed to retrieve API name, please check if the dump data is reasonable') + raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error + return api_name + + def process_compare_key_and_shape(self, npu_df, bench_df): + npu_df = self.assign_npu_df_compare_key(npu_df, bench_df) + npu_df[CompareConst.CMP_SHAPE] = npu_df[Const.SHAPE] + bench_df[CompareConst.CMP_KEY] = bench_df[CompareConst.OP_NAME] + bench_df[CompareConst.CMP_SHAPE] = bench_df[Const.SHAPE] + return npu_df, bench_df + + def assign_npu_df_compare_key(self, npu_df, bench_df): + """ + 处理 npu_df 的 COMPARE_KEY 赋值逻辑 - npu_ops_queue = [] - bench_ops_queue = [] - result = [] + :param npu_df: DataFrame,NPU 对比数据 + :param bench_df: DataFrame,Bench 对比数据 + :return: compare_key(name)处理后的 npu_df + """ + # 处理api_mapping映射 + if self.mapping_config.api_mapping: + # 如果用户不传api_mapping.yaml,先使用内置api_mapping.yaml替换npu_op_name + npu_df[CompareConst.CMP_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_internal_api_mapping) + # 如果用户传入api_mapping.yaml,再使用传入api_mapping.yaml进一步替换npu_op_name + if isinstance(self.mapping_config.api_mapping, str): + self.modify_compare_data_with_user_mapping(npu_df, bench_df) + # 处理cell_mapping映射 + elif self.mapping_config.cell_mapping: + npu_df[CompareConst.CMP_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_cell_mapping) + # 处理data_mapping映射 + elif self.mapping_config.data_mapping: + npu_df[CompareConst.CMP_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_data_mapping) + else: + npu_df[CompareConst.CMP_KEY] = npu_df[CompareConst.OP_NAME] + return npu_df + + def process_internal_api_mapping(self, npu_op_name): + # get api name & class name from op_name + ms_api_name = self.get_api_name(npu_op_name.split(Const.SEP)) + class_name = ms_api_name.split(Const.SEP)[0] + if class_name == "Mint": + return npu_op_name.replace("Mint", "Torch") + elif class_name == "MintFunctional": + return npu_op_name.replace("MintFunctional", "Functional") + elif self.mapping_dict.ms_to_pt_mapping.get(ms_api_name): + return npu_op_name.replace(ms_api_name, self.mapping_dict.ms_to_pt_mapping.get(ms_api_name)) + else: + return npu_op_name + + def modify_compare_data_with_user_mapping(self, npu_df, bench_df): + def remove_prefix(string, prefix): + if string.startswith(prefix): + return string[len(prefix):] + return string + + def gen_input_compare_key(pattern, term): + is_unmatched = True + for i, prefix in enumerate(mapping_dict.get(f'ms_{term}')): + if remove_prefix(op_name, api_origin_name + pattern) == str(prefix): + npu_df.loc[index, CompareConst.CMP_KEY] = ( + op_name.replace(pattern + str(prefix), pattern + str(mapping_dict.get(f'pt_{term}')[i]))) + is_unmatched = False + return is_unmatched + + ms_api_indices_dict = self.get_api_indices_dict(npu_df) + pt_api_indices_dict = self.get_api_indices_dict(bench_df) + + for mapping_dict in self.mapping_dict.api_mapping_dict: + all_length_equal = True + for k1, k2 in CompareConst.API_MAPPING_KEYS_TO_COMPARE: + if len(mapping_dict.get(k1, [])) != len(mapping_dict.get(k2, [])): + all_length_equal = False + if not all_length_equal: + logger.warning('The user-defined mapping table is incorrect,\ + make sure that the number of parameters is equal') + continue - ops_npu_iter = iter(npu_json_data['data']) - ops_bench_iter = iter(bench_json_data['data']) - read_err_npu = True - read_err_bench = True - last_npu_ops_len = 0 - last_bench_ops_len = 0 + ms_api, pt_api = mapping_dict.get('ms_api'), mapping_dict.get('pt_api') + if ms_api not in ms_api_indices_dict or pt_api not in pt_api_indices_dict: + continue + for index in ms_api_indices_dict.get(ms_api): + op_name = npu_df.loc[index, CompareConst.OP_NAME].replace(ms_api, pt_api, 1) + state = npu_df.loc[index, Const.STATE] + api_origin_name = npu_df.loc[index, Const.API_ORIGIN_NAME].replace(ms_api, pt_api, 1) + if state == Const.INPUT: + is_abandoned = gen_input_compare_key(CompareConst.INPUT_PATTERN, 'args') + elif state == Const.KWARGS: + is_abandoned = gen_input_compare_key(CompareConst.KWARGS_PATTERN, 'args') + elif state == Const.OUTPUT: + is_abandoned = gen_input_compare_key(CompareConst.OUTPUT_PATTERN, 'output') + elif state == Const.PARAMS: + is_abandoned = gen_input_compare_key(CompareConst.PARAMS_PATTERN, 'parameters') + elif state == Const.PARAMS_GRAD: + is_abandoned = gen_input_compare_key(CompareConst.PARAMS_GRAD_PATTERN, 'parameters_grad') + else: + logger.error(f'Excepted op_name: {op_name}') + raise CompareException(CompareException.INVALID_DATA_ERROR) + if is_abandoned: + npu_df.loc[index, CompareConst.CMP_KEY] = op_name + 'abandoned' - npu_api_nums = len(npu_json_data['data']) - progress_bar = tqdm(total=npu_api_nums, desc="API/Module Read Progress", unit="item", ncols=100) + def get_api_indices_dict(self, op_name_df): + """ + 生成多个api对应的各自的所有的input、output等的index的键值对字典 + 示例: + {'Functional.conv2d': [0, 1, 2, 3], + 'Functional.batch_norm': [4, 5, 6, 7, 8] + } + """ + api_indices_dict = defaultdict(list) + for op_index, name in enumerate(op_name_df[CompareConst.OP_NAME]): + api_name = self.get_api_name(name.split(Const.SEP)) + api_indices_dict[api_name].append(op_index) + return api_indices_dict + + def process_cell_mapping(self, npu_op_name): + if not npu_op_name: + return CompareConst.N_A + param_grad_flag = Const.PARAMS_GRAD in npu_op_name.split(Const.SEP) + if not param_grad_flag and not re.search(Const.REGEX_FORWARD_BACKWARD, npu_op_name): + return CompareConst.N_A + npu_op_name = npu_op_name.replace("Cell", "Module", 1) + if self.mapping_dict.cell_mapping_dict: + # get cell name & class name from op_name + # Cell.fc1.Dense.forward.0.input.0 + cell_name = re.split(r'\.(?:forward|backward|parameters_grad)\.', npu_op_name.split(Const.SEP, 1)[-1])[0] + if cell_name in self.mapping_dict.cell_mapping_dict: + npu_op_name = npu_op_name.replace(cell_name, self.mapping_dict.cell_mapping_dict[cell_name], 1) + return npu_op_name + + def process_data_mapping(self, npu_op_name): + return self.mapping_dict.data_mapping_dict.get(npu_op_name, npu_op_name) + + +class Match: + def __init__(self, mode_config: ModeConfig, mapping_config: MappingConfig, cross_frame): + self.mode_config = mode_config + self.mapping_config = mapping_config + self.cross_frame = cross_frame - while True: - if not read_err_npu and not read_err_bench: - break - try: - last_npu_ops_len = len(npu_ops_queue) - op_name_npu = next(ops_npu_iter) - check_op_str_pattern_valid(op_name_npu) - npu_merge_list = self.gen_merge_list(npu_json_data, op_name_npu, stack_json_data) - if npu_merge_list: - npu_ops_queue.append(npu_merge_list) - except StopIteration: - read_err_npu = False - try: - last_bench_ops_len = len(bench_ops_queue) - op_name_bench = next(ops_bench_iter) - check_op_str_pattern_valid(op_name_bench) - bench_merge_list = self.gen_merge_list(bench_json_data, op_name_bench, stack_json_data) - if bench_merge_list: - bench_ops_queue.append(bench_merge_list) - except StopIteration: - read_err_bench = False + @staticmethod + def put_unmatched_in_table(match_result, npu_op_item): + npu_columns = npu_op_item.index.tolist()[:-2] + bench_columns = [name + '_y' for name in npu_columns] + na_series = pd.Series([CompareConst.N_A] * len(bench_columns), index=bench_columns) + new_result_item = pd.concat([npu_op_item, na_series]).to_frame().T + new_result_item.columns = CompareConst.MATCH_RESULT_COLUMNS + match_result = pd.concat([match_result, new_result_item]) + return match_result - progress_bar.update(1) + @staticmethod + def put_matched_in_table(match_result, npu_op_item, bench_op_item): + head_len = len(CompareConst.MATCH_RESULT_COLUMNS) + new_result_item = pd.concat([npu_op_item, bench_op_item]).head(head_len).to_frame().T + new_result_item.columns = CompareConst.MATCH_RESULT_COLUMNS + match_result = pd.concat([match_result, new_result_item]) + return match_result - # merge all boolean expressions - both_empty = not npu_ops_queue and not bench_ops_queue - no_change = (len(npu_ops_queue) == last_npu_ops_len) and (len(bench_ops_queue) == last_bench_ops_len) - if both_empty or no_change: - continue + @staticmethod + def rename_api(op_name): + """ + 原api: {api_type}.{api_name}.{API调用次数}.{前向反向}.{input/output}.{参数序号} + rename后: {api_type}.{api_name}.{前向反向}.{input/output}.{参数序号} + """ + if Const.FORWARD not in op_name and Const.BACKWARD not in op_name: + return op_name + process = Const.FORWARD if Const.FORWARD in op_name else Const.BACKWARD + name_split = op_name.split(process) + try: + torch_func_index, in_out = name_split[0], name_split[1] + except IndexError as error: + logger.error(f'{op_name} can not be split with {process}, please check!') + raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error + torch_func_split = torch_func_index.rsplit(Const.SEP, 2) + torch_func = str(torch_func_split[0]) + Const.SEP + process + str(in_out) + return torch_func + + def check_op_item(self, npu_op_item, bench_op_item): + name_match = self.rename_api(npu_op_item[CompareConst.CMP_KEY]) == self.rename_api( + bench_op_item[CompareConst.CMP_KEY]) + shape_match = npu_op_item[CompareConst.CMP_SHAPE] == bench_op_item[CompareConst.CMP_SHAPE] + if name_match and shape_match: + return True + else: + npu_op_name = npu_op_item[CompareConst.OP_NAME] + bench_op_name = bench_op_item[CompareConst.OP_NAME] + check_op_str_pattern_valid(npu_op_name) + check_op_str_pattern_valid(bench_op_name) + logger.warning(f"{npu_op_name} and {bench_op_name} can not fuzzy match") + return False + + def match_api_infos(self, npu_df, bench_df): + """ + 正常匹配和模糊匹配 + """ + if self.mapping_config.data_mapping: + match_result = pd.merge(npu_df, bench_df, on=[CompareConst.CMP_KEY], how='left') + + # reorder match_result by op_name of npu + op_name_order = npu_df[CompareConst.OP_NAME].tolist() + match_result[CompareConst.OP_NAME_X] = pd.Categorical(match_result[CompareConst.OP_NAME_X], + categories=op_name_order, ordered=True) + match_result = match_result.sort_values(CompareConst.OP_NAME_X).reset_index(drop=True) + match_result[CompareConst.OP_NAME_X] = match_result[CompareConst.OP_NAME_X].astype('object') + elif not self.mode_config.fuzzy_match: + match_result = pd.merge(npu_df, bench_df, on=[CompareConst.CMP_KEY, CompareConst.CMP_SHAPE], + how='outer') + else: + match_result = self.process_fuzzy_match(npu_df, bench_df) + return match_result - # APIs in NPU and Bench models unconsistent judgment + def process_fuzzy_match(self, npu_df, bench_df): + """ + 模糊匹配通过循环方式匹配api + """ + npu_ops_queue = [] + bench_ops_queue = [] + match_result = pd.DataFrame(columns=CompareConst.MATCH_RESULT_COLUMNS) + + max_len = max(len(npu_df), len(bench_df)) + min_len = min(len(npu_df), len(bench_df)) + for i in range(max_len): + if i < min_len: + npu_ops_queue.append(npu_df.iloc[i]) + bench_ops_queue.append(bench_df.iloc[i]) + else: + try: + npu_ops_queue.append(npu_df.iloc[i]) + except IndexError: + pass + try: + bench_ops_queue.append(bench_df.iloc[i]) + except IndexError: + pass + + # 如果append之后queue状态不一致,则判断结束 if bool(npu_ops_queue) ^ bool(bench_ops_queue): - logger.info("Please check whether the number and calls of APIs in NPU and Bench models are consistent.") break - n_match_point, b_match_point = self.match_op(npu_ops_queue, bench_ops_queue) + npu_match_point, bench_match_point = self.match_op(npu_ops_queue, bench_ops_queue) - # 如果没有匹配到,数据放到队列中,跳过,直到后面匹配到,把匹配之前的api放到不匹配中 - if n_match_point == -1 and b_match_point == -1: + # 如果没有匹配到,数据放到队列中,跳过。直到后面匹配到,把匹配之前的api放到不匹配中 + if npu_match_point == -1 and bench_match_point == -1: continue - n_match_data = npu_ops_queue[n_match_point] - b_match_data = bench_ops_queue[b_match_point] - un_match_data = npu_ops_queue[0: n_match_point] - for npu_data in un_match_data: - get_un_match_accuracy(result, npu_data, self.dump_mode) - get_accuracy(result, n_match_data, b_match_data, self.dump_mode) - del npu_ops_queue[0: n_match_point + 1] - del bench_ops_queue[0: b_match_point + 1] - progress_bar.close() - if npu_ops_queue: - for npu_data in npu_ops_queue: - get_un_match_accuracy(result, npu_data, self.dump_mode) - - result_df = self.make_result_table(result) - return result_df - - def merge_data(self, json_data, stack_json_data): - ops_all = {} - for op_name in json_data.get('data', {}): - merge_list = self.gen_merge_list(json_data, op_name, stack_json_data) - if merge_list: - struct_to_index_mapping = { - CompareConst.INPUT_STRUCT: 0, - CompareConst.OUTPUT_STRUCT: 0, - CompareConst.PARAMS_STRUCT: 0, - CompareConst.PARAMS_GRAD_STRUCT: 0 - } - - op_name_list = merge_list.get(CompareConst.OP_NAME) - summary_list = merge_list.get(Const.SUMMARY) - data_name_list = merge_list.get('data_name') - op_name_reorder, summary_reorder, data_name_reorder = reorder_op_x_list(op_name_list, - summary_list, - data_name_list) - for index, op_full_name in enumerate(op_name_reorder): - data_name = data_name_reorder[index] if data_name_reorder else None - - _, state = get_name_and_state(op_full_name) - struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state) - if not struct_key: - continue - ops_all[op_full_name] = { - CompareConst.STRUCT: safe_get_value(merge_list, struct_to_index_mapping.get(struct_key), - "merge_list", key=struct_key), - CompareConst.SUMMARY: safe_get_value(summary_reorder, index, "summary_reorder"), - 'data_name': data_name, - 'stack_info': merge_list.get('stack_info') - } - struct_to_index_mapping[struct_key] += 1 - return ops_all - - def get_accuracy(self, npu_ops_all, bench_ops_all): - result = [] - bench_ops_all[CompareConst.N_A] = self._generate_na_data(bench_ops_all) - for ms_op_name, bench_op_name in self.data_mapping_dict.items(): - if ms_op_name in npu_ops_all and bench_op_name in bench_ops_all: - npu_stack_info = npu_ops_all.get(ms_op_name).get("stack_info", None) - bench_stack_info = bench_ops_all.get(bench_op_name).get("stack_info", None) - has_stack = npu_stack_info and bench_stack_info - if self.dump_mode == Const.MD5: - result.append(self.get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all, - bench_ops_all, has_stack, npu_stack_info)) - continue - - npu_struct = npu_ops_all.get(ms_op_name).get('struct', []) - bench_struct = bench_ops_all.get(bench_op_name).get('struct', []) - - if len(npu_struct) < 2 or len(bench_struct) < 2: - logger.error( - f"The length of npu_struct and bench_struct must be >= 2, " - f"but got npu_struct={len(npu_struct)} and bench_struct={len(bench_struct)}. " - f"Please check!" - ) - raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) - - base_result_item = [ - ms_op_name, bench_op_name, - npu_struct[0], - bench_struct[0], - npu_struct[1], - bench_struct[1] - ] + npu_op_item = npu_ops_queue[npu_match_point] + bench_op_item = bench_ops_queue[bench_match_point] + unmatched_data = npu_ops_queue[0: npu_match_point] + for op_item in unmatched_data: + match_result = self.put_unmatched_in_table(match_result, op_item) + match_result = self.put_matched_in_table(match_result, npu_op_item, bench_op_item) + del npu_ops_queue[0: npu_match_point + 1] + del bench_ops_queue[0: bench_match_point + 1] - if self.dump_mode == Const.SUMMARY: - result_item = base_result_item + [" "] * 8 - else: - result_item = base_result_item + [" "] * 5 - - npu_summary_data = npu_ops_all.get(ms_op_name).get("summary") - result_item.extend(npu_summary_data) - bench_summary_data = bench_ops_all.get(bench_op_name).get("summary") - result_item.extend(bench_summary_data) - if self.dump_mode == Const.SUMMARY: - self.calculate_summary_data(npu_summary_data, bench_summary_data, result_item) - else: - result_item.append(CompareConst.ACCURACY_CHECK_YES) - result_item.append("") - if has_stack: - result_item.extend(npu_stack_info) - else: - result_item.append(CompareConst.NONE) - if self.dump_mode == Const.ALL: - result_item.append(npu_ops_all.get(ms_op_name).get("data_name", None)) - result.append(result_item) - elif ms_op_name not in npu_ops_all: - logger.warning(f'Can not find npu op name : `{ms_op_name}` in npu dump json file.') - elif bench_op_name not in npu_ops_all: - logger.warning(f'Can not find bench op name : `{bench_op_name}` in bench dump json file.') - return result + if npu_ops_queue: + for op_item in npu_ops_queue: + match_result = self.put_unmatched_in_table(match_result, op_item) - def compare_process_custom(self, file_lists): - npu_json_path, bench_json_path, stack_json_path = file_lists - npu_json_data = load_json(npu_json_path) - bench_json_data = load_json(bench_json_path) - stack_json_data = load_json(stack_json_path) if self.stack_mode else None - npu_ops_all = self.merge_data(npu_json_data, stack_json_data) - bench_ops_all = self.merge_data(bench_json_data, stack_json_data) + match_result.reset_index(drop=True, inplace=True) + return match_result - result = self.get_accuracy(npu_ops_all, bench_ops_all) - result_df = self.make_result_table(result) - return result_df + def match_op(self, npu_queue, bench_queue): + for b_index, b_op in enumerate(bench_queue[0: -1]): + if self.check_op_item(npu_queue[-1], b_op): + return len(npu_queue) - 1, b_index + if self.check_op_item(npu_queue[-1], bench_queue[-1]): + return len(npu_queue) - 1, len(bench_queue) - 1 + for n_index, n_op in enumerate(npu_queue[0: -1]): + if self.check_op_item(n_op, bench_queue[-1]): + return n_index, len(bench_queue) - 1 + return -1, -1 - def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param, bench_data): + def gen_dtype_condition(self, match_result): """ - :param npu_op_name: excel中的NPU_Name,例如:MintFunctional.conv2d.0.forward.input.3.0 - :param bench_op_name: excel中的Bench_Name,例如:Functional.conv2d.0.forward.input.3.0 - :param op_name_mapping_dict: op_name和npy或pt文件的映射关系 - :param input_param: npu_json_path/bench_json_path/stack_json_path等参数 - :param bench_data: bench的dump数据中"data"字段 - :return: result_list,包含余弦相似度、最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率和错误信息 - 用于读取excel中的NPU_Name和Bench_Name,根据映射关系找到npy或pt文件,然后读取文件中的数据进行比较,计算余弦相似度、 - 最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率并生成错误信息 + dtype匹配条件为npu、bench的dtype一致或属于规定的映射关系 """ - npu_bench_name_list = op_name_mapping_dict[npu_op_name] - data_name = safe_get_value(npu_bench_name_list, 1, "npu_bench_name_list") - error_file, relative_err, error_flag = None, None, False - bench_data_name = get_bench_data_name(bench_op_name, bench_data) - if data_name == '-1' or data_name == -1: # 没有真实数据路径 - n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE - error_flag = True - elif not bench_data_name: - n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True - error_file = 'no_bench_data' - else: - try: - read_npy_data = getattr(self, "read_npy_data") - frame_name = getattr(self, "frame_name") - if frame_name == "MSComparator": - n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.NUMPY_SUFFIX) - if self.cross_frame: - b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name, - load_pt_file=True) - else: - b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name) - else: - n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.PT_SUFFIX) - b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name) - except IOError as error: - error_file = error.filename - n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE - error_flag = True - except (FileCheckException, CompareException): - error_file = data_name - n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE - error_flag = True - - # 通过n_value, b_value同时得到错误标志和错误信息 - n_value, b_value, error_flag, err_msg = get_error_flag_and_msg(n_value, b_value, - error_flag=error_flag, error_file=error_file) - - result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg) - - if self.fuzzy_match and npu_op_name != bench_op_name and bench_op_name != CompareConst.N_A: - err_msg += " Fuzzy matching data, the comparison accuracy may be affected." - result_list.append(err_msg) - return result_list + # 如果使用了data_mapping,不校验dtype,返回全True的DataFrame + if self.mapping_config.data_mapping: + return pd.Series(True, index=match_result.index) + + npu_dtype = match_result['dtype_x'] + bench_dtype = match_result['dtype_y'] + npu_dtype = self.process_cross_frame_dtype(npu_dtype) + bench_dtype = self.process_cross_frame_dtype(bench_dtype) + + equal_condition = npu_dtype == bench_dtype + match_condition = ( + (npu_dtype.isin(CompareConst.DTYPE_MATCH_GROUPS[0]) & bench_dtype.isin( + CompareConst.DTYPE_MATCH_GROUPS[0])) | + (npu_dtype.isin(CompareConst.DTYPE_MATCH_GROUPS[1]) & bench_dtype.isin( + CompareConst.DTYPE_MATCH_GROUPS[1])) + ) + return equal_condition | match_condition - def compare_core(self, input_param, output_path, **kwargs): - """ - Compares data from multiple JSON files and generates a comparison report. + def process_cross_frame_dtype(self, dtype): + if self.cross_frame: + dtype = dtype.map(cross_dtype_mapping).fillna(dtype) + return dtype - Args: - input_param (dict): A dictionary containing paths to JSON files ("npu_path", "bench_path", - "stack_path"). - output_path (str): The path where the output Excel report will be saved. - **kwargs: Additional keyword arguments including: - - stack_mode (bool, optional): Enables stack mode comparison. Defaults to False. - - auto_analyze (bool, optional): If True, triggers automatic analysis after comparison. Defaults to True. - - suffix (str, optional): Suffix to append to the output file name. Defaults to ''. - - fuzzy_match (bool, optional): Enables fuzzy matching during comparison. Defaults to False. - - dump_mode (str): ALL, SUMMARY, MD5. - Returns: - """ - # get kwargs or set default value - suffix = kwargs.get('suffix', '') +class CreateTable: + def __init__(self, mode_config: ModeConfig): + self.mode_config = mode_config - logger.info("Please check whether the input data belongs to you. If not, there may be security risks.") - file_name = add_time_with_xlsx("compare_result" + suffix) - file_path = os.path.join(os.path.realpath(output_path), file_name) - remove_path(file_path) - highlight_dict = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []} + @staticmethod + def process_data_name(result): + result['data_name_x'] = result.apply(lambda row: [row['data_name_x'], row['data_name_y']], axis=1) + return result - npu_json = input_param.get("npu_json_path") - bench_json = input_param.get("bench_json_path") - stack_json = input_param.get("stack_json_path") - if self.data_mapping: - result_df = self.compare_process_custom([npu_json, bench_json, stack_json]) - else: - result_df = self.compare_process([npu_json, bench_json, stack_json]) + @staticmethod + def set_summary(summary): + if summary == CompareConst.N_A: + return [CompareConst.N_A] * 4 # 4为统计值个数 + summary_list = [] + for i in summary: + if str(i).lower() == 'nan': + summary_list.append(CompareConst.NAN) + else: + summary_list.append(i) + return summary_list - if not result_df.values.tolist(): - logger.warning("Can`t match any op.") - return + def make_result_df(self, result): + # get header + header = CompareConst.HEAD_OF_COMPARE_MODE[self.mode_config.dump_mode][:] + if self.mode_config.stack_mode: + header.append(CompareConst.STACK) + if self.mode_config.dump_mode == Const.ALL: + header.append(CompareConst.DATA_NAME) + result = self.process_data_name(result) + + # rename match_result columns + result.rename(columns={'op_name_x': CompareConst.NPU_NAME, + 'op_name_y': CompareConst.BENCH_NAME, + 'dtype_x': CompareConst.NPU_DTYPE, + 'dtype_y': CompareConst.BENCH_DTYPE, + 'shape_x': CompareConst.NPU_SHAPE, + 'shape_y': CompareConst.BENCH_SHAPE, + 'md5_x': CompareConst.NPU_MD5, + 'md5_y': CompareConst.BENCH_MD5, + 'data_name_x': CompareConst.DATA_NAME, + 'stack_info_x': CompareConst.STACK, + 'state_x': Const.STATE, + 'api_origin_name_x': Const.API_ORIGIN_NAME, + 'requires_grad_x': CompareConst.NPU_REQ_GRAD, + 'requires_grad_y': CompareConst.BENCH_REQ_GRAD + }, + inplace=True) + + # process summary data + npu_summary = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM] + bench_summary = [CompareConst.BENCH_MAX, CompareConst.BENCH_MIN, CompareConst.BENCH_MEAN, + CompareConst.BENCH_NORM] + # process requires_grad + result[CompareConst.REQ_GRAD_CONSIST] = result[CompareConst.NPU_REQ_GRAD] == result[CompareConst.BENCH_REQ_GRAD] + + if result.empty: + result[npu_summary] = pd.DataFrame(columns=npu_summary) + result[bench_summary] = pd.DataFrame(columns=bench_summary) + else: + result[npu_summary] = result['summary_x'].apply(self.set_summary).tolist() + result[bench_summary] = result['summary_y'].apply(self.set_summary).tolist() - if self.dump_mode == Const.ALL: - result_df = self.do_multi_process(input_param, result_df) + header.extend([Const.STATE, Const.API_ORIGIN_NAME]) + result_df = pd.DataFrame(columns=header) + for h in header: + if h in result.columns: + result_df[h] = result[h] + return result_df, header - find_compare_result_error_rows(result_df, highlight_dict, self.dump_mode) - highlight_rows_xlsx(result_df, highlight_dict, file_path) - if self.auto_analyze: - advisor = Advisor(result_df, output_path, suffix) - advisor.analysis() +class CalcStatsDiff: + def __init__(self, mode_config: ModeConfig): + self.mode_config = mode_config - print_compare_ends_info() + @staticmethod + def type_check(val): + """ + 检查是否为数值或字符串形式的nan, 如果是返回True + """ + check_series = pd.Series(False, index=val.index) + val_str = val.astype(str) + check_series[pd.to_numeric(val_str, errors='coerce').notna() | val_str.str.lower().eq('nan')] = True + return check_series - def compare_ops(self, idx, dump_path_dict, result_df, lock, input_param): - cos_result = [] - max_err_result = [] - max_relative_err_result = [] - err_mess = [] - one_thousand_err_ratio_result = [] - five_thousand_err_ratio_result = [] - is_print_compare_log = input_param.get("is_print_compare_log") - bench_data = load_json(input_param.get("bench_json_path")).get('data') - for i in range(len(result_df)): - npu_op_name = result_df.iloc[i, 0] - bench_op_name = result_df.iloc[i, 1] - if is_print_compare_log: - logger.info("start compare: {}".format(npu_op_name)) - - cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg = \ - self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param, bench_data) - - if is_print_compare_log: - logger.info( - "[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, \ - one_thousand_err_ratio {}, " - "five_thousand_err_ratio {}".format(npu_op_name, cos_sim, max_abs_err, max_relative_err, - err_msg, one_thousand_err_ratio, five_thousand_err_ratio)) - cos_result.append(cos_sim) - max_err_result.append(max_abs_err) - max_relative_err_result.append(max_relative_err) - err_mess.append(err_msg) - one_thousand_err_ratio_result.append(one_thousand_err_ratio) - five_thousand_err_ratio_result.append(five_thousand_err_ratio) - - cr = ComparisonResult( - cos_result=cos_result, - max_err_result=max_err_result, - max_relative_err_result=max_relative_err_result, - err_msgs=err_mess, - one_thousand_err_ratio_result=one_thousand_err_ratio_result, - five_thousand_err_ratio_result=five_thousand_err_ratio_result + @staticmethod + def get_number(val): + return pd.to_numeric(val.astype(str), errors='coerce') + + def calc_summary_diff(self, result_df, cond_no_bench, stats_index: str): + npu_val = result_df['NPU ' + stats_index] + bench_val = result_df['Bench ' + stats_index] + diff_name = stats_index.capitalize() + ' diff' + rel_err_name = ('norm' if stats_index == 'l2norm' else stats_index).capitalize() + 'RelativeErr' + + # npu、bench中统计量均为数字或nan + cond_num_nan = self.type_check(npu_val) & self.type_check(bench_val) + + # 如果统计量不是数字或nan,就赋值统计量差异为N/A + result_df.loc[~cond_num_nan, [diff_name, rel_err_name]] = CompareConst.N_A + cond_valid_stat = ~cond_no_bench & cond_num_nan # 有效统计条件:bench_name不是N/A,并且NPU和bench的统计量都是数字或nan + result_df.loc[cond_valid_stat, diff_name] = self.get_number(npu_val) - self.get_number(bench_val) + + cond_diff_nan = result_df[diff_name].isna() # 统计量差异是nan + cond_nan_diff = cond_valid_stat & cond_diff_nan + result_df.loc[cond_nan_diff, [diff_name, rel_err_name]] = CompareConst.NAN + + cond_not_nan_diff = cond_valid_stat & ~cond_diff_nan + condition_pt_zero = self.get_number(bench_val) == 0 + result_df.loc[cond_not_nan_diff & condition_pt_zero, rel_err_name] = CompareConst.N_A + + # 相对误差转成百分比字符串 + cond_ref_err = cond_not_nan_diff & ~condition_pt_zero + result_df.loc[cond_ref_err, rel_err_name] = ( + result_df.loc[cond_ref_err, diff_name] / bench_val[cond_ref_err].astype(float) * 100) + result_df.loc[cond_ref_err, rel_err_name] = (result_df.loc[cond_ref_err, rel_err_name].abs().astype(str) + '%') + + magnitude = self.get_number(result_df[diff_name]).abs() / (pd.Series( + np.maximum(self.get_number(npu_val), self.get_number(bench_val))).abs() + CompareConst.EPSILON) + return magnitude > CompareConst.MAGNITUDE + + def calc_accuracy(self, result_df, header): + # bench name N/A represents no bench data, err_msg adds "No bench data matched." + condition_no_bench = result_df[CompareConst.BENCH_NAME] == CompareConst.N_A + result_df[condition_no_bench] = result_df[condition_no_bench].fillna(CompareConst.N_A) + result_df.loc[condition_no_bench, CompareConst.ERROR_MESSAGE] = CompareConst.NO_BENCH + condition_req_grad_consist = result_df[CompareConst.NPU_REQ_GRAD] == result_df[CompareConst.BENCH_REQ_GRAD] + + if self.mode_config.dump_mode == Const.MD5: + condition_md5_equal = result_df[CompareConst.NPU_MD5] == result_df[CompareConst.BENCH_MD5] + result_df.loc[condition_md5_equal, CompareConst.RESULT] = CompareConst.PASS + result_df.loc[~condition_md5_equal & ~condition_no_bench, CompareConst.RESULT] = CompareConst.DIFF + elif self.mode_config.first_diff_analyze or self.mode_config.dump_mode == Const.SUMMARY: + warning_list = [ + self.calc_summary_diff(result_df, condition_no_bench, stats_index) + for stats_index in ['max', 'min', 'mean', 'l2norm'] + ] + warning_flag = pd.DataFrame(warning_list).any() + result_df.loc[~condition_no_bench, [CompareConst.RESULT, CompareConst.ERROR_MESSAGE]] = '' + result_df.loc[warning_flag, CompareConst.RESULT] = CompareConst.WARNING + result_df.loc[warning_flag, CompareConst.ERROR_MESSAGE] = 'Need double check api accuracy. ' + result_df.loc[~condition_req_grad_consist, CompareConst.ERROR_MESSAGE] += 'Requires_grad inconsistent. ' + else: + fill_cols = [CompareConst.COSINE, CompareConst.EUC_DIST, + CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR, + CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO, + CompareConst.ERROR_MESSAGE] + result_df.loc[~condition_no_bench, fill_cols] = '' # 默认填充'', df默认省缺值为nan,不便后续处理,容易出现意外情况 + result_df.loc[~condition_no_bench, CompareConst.ACCURACY] = CompareConst.ACCURACY_CHECK_YES + result_df.loc[~condition_req_grad_consist, CompareConst.ERROR_MESSAGE] = 'Requires_grad inconsistent. ' + + return result_df[header] + + +def setup_comparison(input_param, output_path, **kwargs) -> ComparisonConfig: + """公共的前置处理逻辑,返回封装后的 ComparisonConfig 对象""" + try: + config = ComparisonConfig( + dump_mode='', + stack_mode=False, + auto_analyze=kwargs.get('auto_analyze', True), + fuzzy_match=kwargs.get('fuzzy_match', False), + highlight=kwargs.get('highlight', False), + data_mapping=kwargs.get('data_mapping', {}), + suffix=kwargs.get('suffix', ''), + cell_mapping=kwargs.get('cell_mapping', {}), + api_mapping=kwargs.get('api_mapping', {}), + layer_mapping=kwargs.get('layer_mapping', {}), + first_diff_analyze=kwargs.get('first_diff_analyze', False), + compared_file_type='', + is_print_compare_log=input_param.get('is_print_compare_log', True) ) - return _save_cmp_result(idx, cr, result_df, lock) + set_dump_path(input_param) + config.dump_mode = get_dump_mode(input_param) + config.compared_file_type = get_file_type(input_param.get("npu_json_path", None)) - def do_multi_process(self, input_parma, result_df): - try: - result_df = _handle_multi_process(self.compare_ops, input_parma, result_df, - multiprocessing.Manager().RLock()) - return result_df - except ValueError as e: - logger.error('result dataframe is not found.') - raise CompareException(CompareException.INVALID_DATA_ERROR) from e - - -def get_bench_data_name(bench_op_name, bench_data): - bench_name_list = re.split(r'\.(input|output|kwargs|parameters|parameters_grad)\.', bench_op_name) - if len(bench_name_list) > 1 and bench_name_list[1] == Const.PARAMS_GRAD: - bench_data_bundle = bench_data.get(bench_name_list[0] + Const.SEP + bench_name_list[1], {}) - else: - bench_data_bundle = bench_data.get(bench_name_list[0], {}) - if not bench_data_bundle or len(bench_name_list) < 3: - return None - layers = bench_name_list[2].split(Const.SEP) - - def _get(key, container): - if isinstance(container, dict): - return container.get(key) - if isinstance(container, list): - try: - return container[int(key)] - except (ValueError, IndexError): - return None - return None - - def get_by_layer(container, params_grad=False): - data = container - # dump.json中parameters_grad的结构为key:[{}], 如果存在key,有且只有一个列表元素,而op_name中只命名到了key,因此加'0' - if params_grad: - layers.append('0') - for layer in layers: - data = _get(layer, data) - return _get(CompareConst.DATA_NAME.lower(), data) - - if Const.INPUT == bench_name_list[1]: - return get_by_layer(bench_data_bundle.get(Const.INPUT, bench_data_bundle.get(Const.INPUT_ARGS))) - elif Const.KWARGS == bench_name_list[1]: - return get_by_layer(bench_data_bundle.get(Const.INPUT_KWARGS)) - elif Const.OUTPUT == bench_name_list[1]: - return get_by_layer(bench_data_bundle.get(Const.OUTPUT)) - elif Const.PARAMS == bench_name_list[1]: - return get_by_layer(bench_data_bundle.get(Const.PARAMS)) - elif Const.PARAMS_GRAD == bench_name_list[1]: - return get_by_layer(bench_data_bundle, params_grad=True) - else: - return None + # set stack_mode and set "stack_json_path" in input_param + if 'stack_json_path' in input_param: + config.stack_mode = kwargs.get('stack_mode', False) + else: + config.stack_mode = set_stack_json_path(input_param) + + check_configuration_param(config) + create_directory(output_path) + check_compare_param(input_param, output_path, config.dump_mode, config.stack_mode) + + return config + + except (CompareException, FileCheckException) as error: + logger.error('Compare failed. Please check the arguments and do it again!') + raise CompareException(error.code) from error diff --git a/debug/accuracy_tools/msprobe/core/compare/check.py b/debug/accuracy_tools/msprobe/core/compare/check.py index 653823e20b29b14b6e7ede929f3bd2865bffaa18..78fb0d35503d8d8012d33e197221a3a66d5200a3 100644 --- a/debug/accuracy_tools/msprobe/core/compare/check.py +++ b/debug/accuracy_tools/msprobe/core/compare/check.py @@ -13,118 +13,49 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + from msprobe.core.common.log import logger -from msprobe.core.compare.utils import rename_api from msprobe.core.common.utils import check_op_str_pattern_valid, CompareException -from msprobe.core.common.const import CompareConst, Const - -dtype_mapping = { - "Int8": "torch.int8", - "UInt8": "torch.uint8", - "Int16": "torch.int16", - "UInt16": "torch.uint16", - "Int32": "torch.int32", - "UInt32": "torch.uint32", - "Int64": "torch.int64", - "UInt64": "torch.uint64", - "Float16": "torch.float16", - "Float32": "torch.float32", - "Float64": "torch.float64", - "Bool": "torch.bool", - "BFloat16": "torch.bfloat16", - "Complex64": "torch.complex64", - "Complex128": "torch.complex128" +from msprobe.core.common.const import Const + +cross_dtype_mapping = { + "Int8": "int", + "torch.int8": "int", + "UInt8": "int", + "torch.uint8": "int", + "Int16": "int", + "torch.int16": "int", + "UInt16": "int", + "torch.uint16": "int", + "Int32": "int", + "torch.int32": "int", + "UInt32": "int", + "torch.uint32": "int", + "Int64": "int", + "torch.int64": "int", + "UInt64": "int", + "torch.uint64": "int", + + "Float16": "float", + "torch.float16": "float", + "Float32": "float", + "torch.float32": "float", + "Float64": "float", + "torch.float64": "float", + "BFloat16": "float", + "torch.bfloat16": "float", + + "Bool": "bool", + "torch.bool": "bool", + + "Complex64": "complex", + "torch.complex64": "complex", + "Complex128": "complex", + "torch.complex128": "complex", } -def compare_op_dict_struct(npu_dict, bench_dict): - return all(npu_dict.get(key) == bench_dict.get(key) for key in CompareConst.STRUCT_COMPARE_KEY) - - -def check_struct_match(npu_dict, bench_dict): - is_match = compare_op_dict_struct(npu_dict, bench_dict) - if not is_match: - struct_match_list = [] - try: - for i, key in enumerate(CompareConst.STRUCT_COMPARE_KEY): - # 首先额外检查input_struct是否空,input_struct不可能为空 - if i == 0 and (not npu_dict.get(key, []) or not bench_dict.get(key, [])): - return False - struct_match_list.append(check_type_shape_match(npu_dict.get(key, []), bench_dict.get(key, []))) - except CompareException as error: - err_msg = f'index out of bounds error occurs in npu or bench api, please check!\n' \ - f'npu_dict: {npu_dict}' \ - f'bench_dict: {bench_dict}' - logger.error(err_msg) - raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error - is_match = all(struct_match_list) - return is_match - - -def check_type_shape_match(npu_struct, bench_struct): - """ - further check dtypes with a dtype mapping list when dtypes are not entirely consistent. - """ - if len(npu_struct) != len(bench_struct): - return False - if not npu_struct and not bench_struct: - return True - - struct_match = False - for npu_type_shape, bench_type_shape in zip(npu_struct, bench_struct): - try: - npu_type = npu_type_shape[0] - npu_shape = npu_type_shape[1] - bench_type = bench_type_shape[0] - bench_shape = bench_type_shape[1] - except IndexError as error: - logger.error(f'length of npu_type_shape: {npu_type_shape} and bench_type_shape: {bench_type_shape} ' - f'should both be 2, please check!') - raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error - shape_match = npu_shape == bench_shape - type_match = npu_type == bench_type - if not type_match: - if ([npu_type, bench_type] in CompareConst.MS_TYPE) or ([npu_type, bench_type] in CompareConst.TORCH_TYPE): - type_match = True - else: - type_match = False - struct_match = shape_match and type_match - if not struct_match: - return False - return struct_match - - -def check_graph_mode(a_op_name, b_op_name): - if Const.ATEN in a_op_name and Const.ATEN not in b_op_name: - return True - if Const.ATEN not in a_op_name and Const.ATEN in b_op_name: - return True - return False - - -def fuzzy_check_op(npu_name_list, bench_name_list): - # 先检查api里的item长度是否相等,如果不是parameters_grad, 必然有input或者output,长度不可能为0 - # 如果是parameters_grad, "parameters_grad"字段的字典不会是空字典,因此len>=1 - if len(npu_name_list) == 0 or len(bench_name_list) == 0 or len(npu_name_list) != len(bench_name_list): - return False - is_match = True - for npu_name, bench_name in zip(npu_name_list, bench_name_list): - is_match = fuzzy_check_name(npu_name, bench_name) - if not is_match: - break - return is_match - - -def fuzzy_check_name(npu_name, bench_name): - if Const.FORWARD in npu_name and Const.FORWARD in bench_name: - is_match = rename_api(npu_name, Const.FORWARD) == rename_api(bench_name, Const.FORWARD) - elif Const.BACKWARD in npu_name and Const.BACKWARD in bench_name: - is_match = rename_api(npu_name, Const.BACKWARD) == rename_api(bench_name, Const.BACKWARD) - else: - is_match = npu_name == bench_name - return is_match - - def check_dump_json_str(op_data, op_name): input_list = op_data.get(Const.INPUT_ARGS, None) if op_data.get(Const.INPUT_ARGS, None) else op_data.get( Const.INPUT, None) @@ -177,3 +108,14 @@ def check_stack_json_str(stack_info, op_name): else: logger.error(f"Expected stack_info to be a list, but got {type(stack_info).__name__} for '{op_name}'") raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR) + + +def check_configuration_param(config): + arg_list = [config.stack_mode, config.auto_analyze, config.fuzzy_match, + config.highlight, config.first_diff_analyze, config.is_print_compare_log] + arg_names = ['stack_mode', 'auto_analyze', 'fuzzy_match', + 'highlight', 'first_diff_analyze', 'is_print_compare_log'] + for arg, name in zip(arg_list, arg_names): + if not isinstance(arg, bool): + logger.error(f"Invalid input parameter, {name} which should be only bool type.") + raise CompareException(CompareException.INVALID_PARAM_ERROR) diff --git a/debug/accuracy_tools/msprobe/core/compare/compare_cli.py b/debug/accuracy_tools/msprobe/core/compare/compare_cli.py index 7df7315043cb57b057871a7d12f5aa63cf927c74..991978250b133710ce893118403fc86727b354a2 100644 --- a/debug/accuracy_tools/msprobe/core/compare/compare_cli.py +++ b/debug/accuracy_tools/msprobe/core/compare/compare_cli.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,39 +13,59 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json -from msprobe.core.common.file_utils import check_file_type, load_json +import os + +from msprobe.core.common.file_utils import check_file_type, load_json, check_file_or_directory_path from msprobe.core.common.const import FileCheckConst, Const from msprobe.core.common.utils import CompareException from msprobe.core.common.log import logger +from msprobe.core.compare.utils import get_paired_dirs + +def compare_cli(args, depth=1): + if depth > 2: + logger.error("Recursive compare error, depth exceeds 2.") + raise CompareException(CompareException.RECURSION_LIMIT_ERROR) + + if isinstance(args.input_path, dict): # special for dyn-graph mix compare + input_param = args.input_path + else: + input_param = load_json(args.input_path) + if not isinstance(input_param, dict): + logger.error("input_param should be dict, please check!") + raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR) -def compare_cli(args): - input_param = load_json(args.input_path) npu_path = input_param.get("npu_path", None) bench_path = input_param.get("bench_path", None) if not npu_path: - logger.error(f"Missing npu_path in configuration file {args.input_path}, please check!") + logger.error(f"Missing npu_path in input configuration file, please check!") raise CompareException(CompareException.INVALID_PATH_ERROR) if not bench_path: - logger.error(f"Missing bench_path in configuration file {args.input_path}, please check!") + logger.error(f"Missing bench_path in input configuration file, please check!") raise CompareException(CompareException.INVALID_PATH_ERROR) + frame_name = args.framework auto_analyze = not args.compare_only + if frame_name == Const.PT_FRAMEWORK: from msprobe.pytorch.compare.pt_compare import compare from msprobe.pytorch.compare.distributed_compare import compare_distributed else: from msprobe.mindspore.compare.ms_compare import ms_compare from msprobe.mindspore.compare.distributed_compare import ms_compare_distributed, ms_graph_compare + from msprobe.mindspore.compare.common_dir_compare import common_dir_compare common_kwargs = { "auto_analyze": auto_analyze, "fuzzy_match": args.fuzzy_match, + "highlight": args.highlight, "data_mapping": args.data_mapping, + "diff_analyze": args.diff_analyze } if check_file_type(npu_path) == FileCheckConst.FILE and check_file_type(bench_path) == FileCheckConst.FILE: + check_file_or_directory_path(npu_path) + check_file_or_directory_path(bench_path) input_param["npu_json_path"] = input_param.pop("npu_path") input_param["bench_json_path"] = input_param.pop("bench_path") if "stack_path" not in input_param: @@ -67,6 +87,14 @@ def compare_cli(args): } ms_compare(input_param, args.output_path, **kwargs) elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR: + check_file_or_directory_path(npu_path, isdir=True) + check_file_or_directory_path(bench_path, isdir=True) + + if depth == 1: + mix_compare_success = mix_compare(args, input_param, depth) + if mix_compare_success: + return + kwargs = { **common_kwargs, "stack_mode": args.stack_mode, @@ -78,6 +106,17 @@ def compare_cli(args): if input_param.get("rank_id") is not None: ms_graph_compare(input_param, args.output_path) return + common = input_param.get("common", False) + if isinstance(common, bool) and common: + common_dir_compare(input_param, args.output_path) + return + + if common_kwargs.get('diff_analyze', False): + logger.info("Start finding first diff node......") + from msprobe.core.compare.find_first.analyzer import DiffAnalyzer + DiffAnalyzer(npu_path, bench_path, args.output_path, frame_name).analyze() + return + if frame_name == Const.PT_FRAMEWORK: compare_distributed(npu_path, bench_path, args.output_path, **kwargs) else: @@ -85,3 +124,34 @@ def compare_cli(args): else: logger.error("The npu_path and bench_path need to be of the same type.") raise CompareException(CompareException.INVALID_COMPARE_MODE) + + +def mix_compare(args, input_param, depth): + npu_path = input_param.get("npu_path", None) + bench_path = input_param.get("bench_path", None) + + npu_bench_same_dirs_set = set(get_paired_dirs(npu_path, bench_path)) + compare_cross_set = npu_bench_same_dirs_set & Const.MIX_DUMP_NAMES + + if compare_cross_set: + logger.info("Start mix compare.") + origin_output = args.output_path + + for folder_name in list(compare_cross_set): + new_npu_path = os.path.join(npu_path, folder_name) + new_bench_path = os.path.join(bench_path, folder_name) + paired_steps = get_paired_dirs(new_npu_path, new_bench_path) + + for step_name in paired_steps: + logger.info(f"[mix compare] Start comparing {folder_name}/{step_name}") + npu_dir = os.path.join(new_npu_path, step_name) + bench_dir = os.path.join(new_bench_path, step_name) + args.input_path = { + "npu_path": npu_dir, + "bench_path": bench_dir, + "is_print_compare_log": input_param.get("is_print_compare_log", True) + } + args.output_path = os.path.join(origin_output, folder_name, step_name) + compare_cli(args, depth + 1) + return True + return False diff --git a/debug/accuracy_tools/msprobe/core/compare/config.py b/debug/accuracy_tools/msprobe/core/compare/config.py new file mode 100644 index 0000000000000000000000000000000000000000..8743aab8d781028cb620617876f18914ac322b55 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/compare/config.py @@ -0,0 +1,74 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from msprobe.core.common.const import Const, CompareConst +from msprobe.core.common.file_utils import load_yaml + + +class ModeConfig: + def __init__(self, **kwargs): + self.stack_mode = kwargs.get('stack_mode', False) + self.auto_analyze = kwargs.get('auto_analyze', True) + self.fuzzy_match = kwargs.get('fuzzy_match', False) + self.highlight = kwargs.get('highlight', False) + self.dump_mode = kwargs.get('dump_mode', Const.SUMMARY) + self.first_diff_analyze = kwargs.get('first_diff_analyze', False) + self.diff_analyze = kwargs.get('diff_analyze', False) + self.compared_file_type = kwargs.get('compared_file_type', Const.DUMP_JSON_FILE) + + +class MappingConfig: + def __init__(self, cell_mapping=None, api_mapping=None, data_mapping=None): + self.cell_mapping = cell_mapping + self.api_mapping = api_mapping + self.data_mapping = data_mapping + + +class MappingDict: + def __init__(self, mapping_config: MappingConfig): + self.cell_mapping_dict = self.load_mapping_file(mapping_config.cell_mapping) + self.api_mapping_dict = self.load_mapping_file(mapping_config.api_mapping) + if mapping_config.api_mapping is not None: + self.ms_to_pt_mapping = self.load_internal_api() + self.data_mapping_dict = self.init_data_mapping(mapping_config.data_mapping) + + @staticmethod + def load_internal_api(): + cur_path = os.path.dirname(os.path.realpath(__file__)) + yaml_path = os.path.abspath(os.path.join(cur_path, CompareConst.INTERNAL_API_MAPPING_FILE)) + return load_yaml(yaml_path) + + @staticmethod + def load_mapping_file(mapping_file): + if isinstance(mapping_file, str): + mapping_dict = load_yaml(mapping_file) + else: + mapping_dict = {} + return mapping_dict + + def init_data_mapping(self, data_mapping): + """ + 初始化data_mapping_dict + """ + if isinstance(data_mapping, str) or data_mapping is None: + data_mapping_dict = self.load_mapping_file(data_mapping) + elif isinstance(data_mapping, dict): + data_mapping_dict = data_mapping + else: + raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got " + f"{type(data_mapping)}") + return data_mapping_dict diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py b/debug/accuracy_tools/msprobe/core/compare/diff_analyze/__init__.py similarity index 100% rename from debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py rename to debug/accuracy_tools/msprobe/core/compare/diff_analyze/__init__.py diff --git a/debug/accuracy_tools/msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml b/debug/accuracy_tools/msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4b35d10bfb99f00a93e2fd6ad69112c6a40efce1 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml @@ -0,0 +1,14 @@ +compare_metrics: + - MaxRelativeErr + - MinRelativeErr + - MeanRelativeErr + - NormRelativeErr + +MaxRelativeErr: + - 0.5 +MinRelativeErr: + - 0.5 +MeanRelativeErr: + - 0.5 +NormRelativeErr: + - 0.5 \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/core/compare/diff_analyze/first_diff_analyze.py b/debug/accuracy_tools/msprobe/core/compare/diff_analyze/first_diff_analyze.py new file mode 100644 index 0000000000000000000000000000000000000000..924312a2bc892b1d70e771cfabad5292ff486ebd --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/compare/diff_analyze/first_diff_analyze.py @@ -0,0 +1,123 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from tqdm import tqdm + +from msprobe.core.common.const import Const, CompareConst +from msprobe.core.common.utils import logger, CompareException +from msprobe.core.common.file_utils import load_yaml +from msprobe.core.compare.config import ModeConfig +from msprobe.core.compare.utils import gen_api_batches + + +cur_dir = os.path.dirname(os.path.realpath(__file__)) +diff_threshold_yaml_path = os.path.join(cur_dir, 'diff_analyze_threshold.yaml') +thresholds = load_yaml(diff_threshold_yaml_path) +cmp_metrics = thresholds.get('compare_metrics') + + +class FirstDiffAnalyze: + def __init__(self, mode_config: ModeConfig, rank): + self.mode_config = mode_config + self.rank = rank + + @staticmethod + def single_metric_diff_check(cmp_metric, metric_value): + threshold = thresholds.get(cmp_metric, None) + if threshold is None: + logger.error(f"Check diff or {cmp_metric} need to configure the threshold. " + f"Please configure it in 'diff_analyze_threshold.yaml'.") + raise CompareException(CompareException.MISSING_THRESHOLD_ERROR) + if not isinstance(threshold, list) or len(threshold) != 1: + logger.error(f"{cmp_metric} threshold configure wrong. Please check.") + raise CompareException(CompareException.WRONG_THRESHOLD_ERROR) + if isinstance(metric_value, str) and metric_value.endswith('%'): + metric_value_float = float(metric_value[:-1]) / 100 + if metric_value_float > threshold[0]: + return True + return False + + def single_api_check(self, result_slice, header): + """ + 单个api差异检查 + + :param result_slice: 数据切片 + :param header: 列名列表 + :return: {'is_same': bool, 'op_items': list[dict]} + """ + single_check_result = { + 'is_same': True, + 'op_items': [] + } + + column_indices = {name: idx for idx, name in enumerate(header)} + + for line in result_slice: + op_item = { + column_name: line[column_indices[column_name]] + for column_name in header + } + single_check_result['op_items'].append(op_item) + + # set is_same + if self.mode_config.dump_mode == Const.MD5: + if line[column_indices[CompareConst.RESULT]] == CompareConst.DIFF: + single_check_result['is_same'] = False + else: + for cmp_metric in cmp_metrics: + metric_value = line[column_indices[cmp_metric]] + if self.single_metric_diff_check(cmp_metric, metric_value): + single_check_result['is_same'] = False + break + return single_check_result + + def check(self, result_df): + """ + 比对后循环遍历api检查差异 + example: + { + 'Functional.conv2d.0.forward': { + 'is_same': true, + 'op_items': [ + { + 'NPU name': 'Functional.conv2d.0.forward.input.0', + 'Bench name': 'Functional.conv2d.0.forward.input.0', + 'xxx': 1, + 'NormRelativeErr': 2, + 'yyy': 3, + ... + } + ] + } + } + """ + result = result_df.values + header = result_df.columns.tolist() + + api_batches = gen_api_batches(result, header) + + check_result = {} + + default_bar_desc = 'API/Module diff check Progress' + bar_desc_add_rank = f'[{self.rank}]' + default_bar_desc if self.rank else default_bar_desc + with tqdm(total=len(api_batches), desc=bar_desc_add_rank, unit="api/module", ncols=100) as progress_bar: + for api_batch in api_batches: + result_slice = result[api_batch.start: api_batch.params_grad_end_index] + check_result[api_batch.api_name] = self.single_api_check(result_slice, header) + progress_bar.update(1) + + return check_result diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/unittest/__init__.py b/debug/accuracy_tools/msprobe/core/compare/find_first/__init__.py similarity index 100% rename from debug/accuracy_tools/msprobe/pytorch/monitor/unittest/__init__.py rename to debug/accuracy_tools/msprobe/core/compare/find_first/__init__.py diff --git a/debug/accuracy_tools/msprobe/core/compare/find_first/analyzer.py b/debug/accuracy_tools/msprobe/core/compare/find_first/analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..0036fba964443960ca8ad6dbfe9826de03ca37d7 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/compare/find_first/analyzer.py @@ -0,0 +1,274 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from collections import defaultdict +import os +from itertools import dropwhile, chain + +from msprobe.core.common import const +from msprobe.core.common.file_utils import check_file_or_directory_path, save_json, make_dir +from msprobe.core.common.log import logger +from msprobe.core.common.const import Const +from msprobe.core.compare.find_first.data_processor import DataProcessor +from msprobe.core.compare.find_first.utils import (RankPath, FileCache, is_communication_op, is_ignore_op, + DiffAnalyseConst, analyze_diff_in_group) +from msprobe.core.compare.find_first.graph import DataNode, CommunicationNode + + +class DiffAnalyzer: + def __init__(self, npu_path, bench_path, output_path, data_frame=Const.PT_FRAMEWORK): + self._bench_path = bench_path + self._npu_path = npu_path + self._output_path = output_path + self.pre_processor = DataProcessor(data_frame) + self._paths = {} + self._diff_nodes = [] # 记录所有异常节点 + self._cache = FileCache() + self._first_comm_nodes = {} # 记录各rank下首个通信节点的node_id + self._after_comm_diffs = {} # 记录各rank下发生在通信节点之后的异常计算节点 + self._rank_comm_nodes_dict = {} # 记录各rank的通信节点 + + def analyze(self): + self._pre_process() + for analyze_func in [self._pre_analyze, self._analyze, self._post_analyze]: + analyze_func() + if self._diff_nodes: + self._gen_analyze_info() + return + logger.info('Cannot find any diff node, no need to generate analyze file.') + + def _pre_process(self): + self.pre_processor.process(self._npu_path, self._bench_path, self._output_path) + self._resolve_input_path(self._output_path) + logger.info("Pre Process completed.") + + """ + 这里需要生成stack,但是直接用dict中自带就行,在op_items.NPU_Stack_Info中 + """ + def _resolve_input_path(self, result_input_path): + contents = os.listdir(result_input_path) + rank_paths = {} + + for path in contents: + # 检查文件名是否符合compare_result_rank{rank_id}_{timestamp}.json格式 + if not path.startswith('compare_result_rank'): + continue + if not path.endswith('.json'): + continue + + # 从文件名中提取rank_id + try: + path_ele_list = path.split('_') + if len(path_ele_list) <= 2: + continue + rank_part = path_ele_list[2] + if not rank_part.startswith('rank'): + continue + rank_str = rank_part.strip('rank') # 去掉'rank'前缀 + rank = int(rank_str) if rank_str else 0 + except (IndexError, ValueError): + continue + + # 构建完整的json文件路径 + dump_path = os.path.join(result_input_path, path) + rank_paths[rank] = RankPath(rank, dump_path) + + # 按照rank id排序后添加到self._paths中 + for rank in sorted(rank_paths.keys()): + self._paths[rank] = rank_paths[rank] + + def _pre_analyze(self): + logger.info('Start searching diff node before communication.') + for path in self._paths.values(): + dump_data = self._cache.load_json(path.dump_path) + if not dump_data: + logger.warning(f'Rank {path.rank} has no dump data!') + continue + for op_name, op_data in dump_data.items(): + if is_communication_op(op_name): + self._first_comm_nodes[path.rank] = op_name + break + data_node = DataNode(op_name, path.rank, op_data) + if data_node.is_diff: + self._diff_nodes.append(data_node) + break + + def _analyze(self): + logger.info('Start searching diff node during communication.') + self._rank_comm_nodes_dict = {rank: self._analyze_comm_nodes(rank) for rank in self._paths} + self._connect_comm_nodes() + self._pruning() + self._search_first_diff() + + def _post_analyze(self): + logger.info('Start searching diff node after communication.') + for nodes in self._after_comm_diffs.values(): + if nodes: + self._diff_nodes.append(nodes[0]) + + def _connect_comm_nodes(self): + searched_ranks = set() + for rank, nodes in list(self._rank_comm_nodes_dict.items())[:-1]: + searched_ranks.add(rank) + seen_nodes = set() + for cur_node in nodes.values(): + conn_info = cur_node.find_connected_nodes() + if not conn_info.get('ranks'): + conn_info['ranks'] = self._rank_comm_nodes_dict.keys() + if not self._find_connection(conn_info, cur_node, searched_ranks, seen_nodes): + logger.debug(f'Cannot find connected communication node for "{cur_node.node_id}".') + + def _find_connection(self, conn_info, cur_node, searched_ranks, seen_nodes): + def connect(search_node): + seen_nodes.add(search_node.node_id) + if search_node.type == DiffAnalyseConst.DST: + cur_node.add_dst(search_node) + elif search_node.type == DiffAnalyseConst.SRC: + search_node.layer = cur_node.layer + search_node.add_dst(cur_node) + else: + cur_node.add_link(search_node) + + found = cur_node.connected + for connected_rank in conn_info['ranks']: + if connected_rank in searched_ranks: + continue + tar_id_prefix = f'{connected_rank}.{conn_info["api"]}' + for search_id, search_node in self._rank_comm_nodes_dict[connected_rank].items(): + if search_id in seen_nodes: + continue + if not (search_id.startswith(tar_id_prefix) and search_node.type == conn_info.get('type')): + continue + search_conn_ranks = search_node.find_connected_nodes().get('ranks') + if ((not search_conn_ranks and search_node.api not in DiffAnalyseConst.DIRECTED_API) or + cur_node.rank in search_conn_ranks): # 有些无向通信算子没有填ProcessGroup,默认连接所有rank + connect(search_node) + found = True + break + return found + + def _analyze_comm_nodes(self, rank): + path = self._paths[rank] + data = self._cache.load_json(path.dump_path) + communication_nodes = {} + if rank not in self._first_comm_nodes: # 此rank没有通信节点 + return communication_nodes + last_node_id = None # 记录上一个通信节点的node_id + compute_ops = [] # 记录两个通信节点之间的计算节点 + sub_layer = 0 # 记录两个通信算子之间异常计算节点的调用序数 + for op_name in dropwhile(lambda k: k != self._first_comm_nodes[rank], data): + node_id = f'{rank}.{op_name}' + op_data = data[op_name] + if is_communication_op(op_name): + comm_node = CommunicationNode(node_id, rank, DataNode(op_name, rank, op_data, sub_layer=sub_layer), + compute_ops=compute_ops) + if last_node_id: + communication_nodes.get(last_node_id).add_next(comm_node) + communication_nodes[node_id] = comm_node + last_node_id = node_id + compute_ops = [] + sub_layer = 0 + elif not is_ignore_op(op_name): + data_node = DataNode(op_name, rank, op_data, sub_layer=sub_layer) + if data_node.is_diff: + compute_ops.append(data_node) + sub_layer += 1 + if compute_ops: + self._after_comm_diffs[rank] = compute_ops + return communication_nodes + + def _pruning(self): + deleted_node_id = [] + for nodes in self._rank_comm_nodes_dict.values(): + for node_id in list(nodes.keys()): + node = nodes[node_id] + if node.is_diff or node.compute_ops: + continue + deleted_node_id.append(node_id) + node.delete() + del nodes[node_id] + logger.debug(f'After pruning, following nodes are removed: [{", ".join(deleted_node_id)}]') + + def _search_first_diff(self): + nodes_queues = [] + for comm_nodes in self._rank_comm_nodes_dict.values(): + nodes_queues.append(sorted(list(comm_nodes.values()), key=lambda x: x.layer)) + seen_nodes = set() + + def get_next_node(node_list): + while node_list: + next_node = node_list.pop(0) + if next_node.node_id not in seen_nodes: + return next_node + return None + + def find_all_members(ori_node): + ids = get_relative_ids(ori_node) + id_queue = list(chain(*[get_relative_ids(self._get_node_by_id(n_id)).difference(ids) for n_id in ids])) + while id_queue: + new_id = id_queue.pop(0) + ids.add(new_id) + id_queue.extend(get_relative_ids(self._get_node_by_id(new_id)).difference(ids)) + return ids + + def get_relative_ids(ori_node): + if not ori_node: + return set() + return ({ori_node.node_id} | ori_node.link_nodes.keys() | ori_node.src_nodes.keys() | + ori_node.dst_nodes.keys()) + + while any(nodes_queues): + groups = [] + all_ids_in_groups = set() + for nodes in nodes_queues: + node = get_next_node(nodes) + if not node: + continue + if not groups or node.node_id not in all_ids_in_groups: + new_group = find_all_members(node) + groups.append(new_group) + all_ids_in_groups.update(new_group) + for group in groups: + seen_nodes.update(group) + self._diff_nodes.extend(analyze_diff_in_group([self._get_node_by_id(n_id) for n_id in group])) + if self._diff_nodes: + # 找出所有layer和sub_layer最小的节点 + min_layer_sublayer = min((x.layer, x.sub_layer) for x in self._diff_nodes) + self._diff_nodes = [ + node + for node in self._diff_nodes + if (node.layer, node.sub_layer) == min_layer_sublayer + ] + return + + def _get_node_by_id(self, node_id): + splits = node_id.split(Const.SEP, 1) + if len(splits) < 2 or not splits[0].isdigit(): + logger.error(f'invalid node_id {node_id}') + raise RuntimeError(f'invalid node_id {node_id}') + rank = int(splits[0]) + return self._rank_comm_nodes_dict.get(rank, {}).get(node_id) + + def _gen_analyze_info(self): + if not os.path.exists(self._output_path): + make_dir(self._output_path) + file_name = f'diff_analyze_{time.time_ns()}.json' + result_file = os.path.join(self._output_path, file_name) + result_content = defaultdict(list) + for node in self._diff_nodes: + result_content[f'rank_{node.rank}'].append(node.gen_node_info(self._paths[node.rank])) + save_json(result_file, result_content, 2) + logger.info(f"The analyze result is saved in: {result_file}") diff --git a/debug/accuracy_tools/msprobe/core/compare/find_first/data_processor.py b/debug/accuracy_tools/msprobe/core/compare/find_first/data_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..31d1e3433f49b938ecc7c253d3e655ec507a043c --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/compare/find_first/data_processor.py @@ -0,0 +1,35 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from msprobe.core.common.const import Const +from msprobe.core.common.log import logger + + +class DataProcessor: + def __init__(self, data_frame): + self.data_frame = data_frame + if self.data_frame == Const.PT_FRAMEWORK: + from msprobe.pytorch.compare.distributed_compare import compare_distributed + self.process_func = compare_distributed + elif self.data_frame == Const.MS_FRAMEWORK: + from msprobe.mindspore.compare.distributed_compare import ms_compare_distributed + self.process_func = ms_compare_distributed + else: + raise ValueError(f"Unsupported data_frame: {self.data_frame}") + + def process(self, npu_path, bench_path, output_path): + logger.info("Start comparing data ......") + return self.process_func(npu_path, bench_path, output_path, first_diff_analyze=True) diff --git a/debug/accuracy_tools/msprobe/core/compare/find_first/graph.py b/debug/accuracy_tools/msprobe/core/compare/find_first/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..5043ec90dc529b75b576a558be39a3b903701a45 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/compare/find_first/graph.py @@ -0,0 +1,180 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from msprobe.core.common.const import Const +from msprobe.core.common.log import logger +from msprobe.core.common.const import CompareConst +from msprobe.core.compare.find_first.utils import RankPath, DiffAnalyseConst + + +@dataclass +class DataNode: + op_name: str + rank: int + inputs: dict + outputs: dict + op_data: list + layer: int = 0 # 和communication_node的layer保持一致 + sub_layer: int = 0 # 调用顺序,越小表示越先调用 + + def __init__(self, op_name, rank, op_data, **kwargs): + self.op_name = op_name + self.rank = rank + self.stack = None + self.inputs = {} + self.outputs = {} + self.is_diff = False + self.parse_data(op_data) + self.sub_layer = kwargs.get('sub_layer', 0) + + def find_stack(self): + for item in self.stack: + if len(item) >= 2 and self.op_name in item[0]: + return item[1] + return {} + + def parse_data(self, op_data): + self.is_diff = not op_data.get("is_same", True) + self.op_data = op_data.get("op_items") # 这里拿到的是比对column,是一个list,有若干行 + metrics = {} + for cmp_data in self.op_data: + name = cmp_data.get(CompareConst.NPU_NAME) + if CompareConst.NPU_MAX in cmp_data: + metrics = {CompareConst.NPU_MAX: cmp_data.get(CompareConst.NPU_MAX), + CompareConst.NPU_MIN: cmp_data.get(CompareConst.NPU_MIN), + CompareConst.NPU_MEAN: cmp_data.get(CompareConst.NPU_MEAN), + CompareConst.NPU_NORM: cmp_data.get(CompareConst.NPU_NORM)} + elif CompareConst.NPU_MD5 in cmp_data: + metrics = {CompareConst.NPU_MD5: cmp_data.get(CompareConst.NPU_MD5)} + + if cmp_data.get(CompareConst.STACK) != CompareConst.N_A and not self.stack: + self.stack = cmp_data.get(CompareConst.STACK) + if Const.INPUT in name: + self.inputs[name] = metrics + elif Const.OUTPUT in name: + self.outputs[name] = metrics + + def gen_node_info(self, path: RankPath): + data_info_list = {Const.INPUT: self.inputs, Const.OUTPUT: self.outputs} + return {'op_name': self.op_name, + 'data_info': data_info_list, + 'stack_info': self.stack} + + +class CommunicationNode: + def __init__(self, node_id, rank, data: DataNode, layer=0, **kwargs): + self.node_id = node_id + self.rank = rank + self.data = data + self.is_diff = data.is_diff + self.layer = layer + op_name_split = self.data.op_name.split(Const.SEP) + if len(op_name_split) < 4: + logger.error(f'invalid op_name: {self.data.op_name}') + raise RuntimeError(f'invalid op_name: {self.data.op_name}') + self.api = op_name_split[1] + self.call_cnt = op_name_split[2] + self.pre_node = kwargs.get('pre_node') + self.link_nodes = kwargs.get('link_nodes', {}) + self.dst_nodes = kwargs.get('dst_nodes', {}) + self.src_nodes = kwargs.get('src_nodes', {}) + self.next_nodes = kwargs.get('next_nodes', {}) + self.compute_ops = kwargs.get('compute_ops', []) + self.type = self._resolve_type() + self.connected = False + + def add_next(self, node): + self.next_nodes[node.node_id] = node + node.pre_node = self + node.layer = self.layer + 1 + node.data.layer = node.layer + + def add_link(self, node): + self.link_nodes[node.node_id] = node + node.link_nodes[self.node_id] = self + node.layer = self.layer + node.data.layer = node.layer + self.connected = True + node.connected = True + + def add_dst(self, node): + self.dst_nodes[node.node_id] = node + node.src_nodes[self.node_id] = self + node.layer = self.layer + node.data.layer = node.layer + self.connected = True + node.connected = True + + def delete(self): + for node in self.next_nodes.values(): + node.pre_node = None + for node in self.dst_nodes.values(): + if node.src_nodes: + node.src_nodes.pop(self.node_id) + for node in self.src_nodes.values(): + if node.dst_nodes: + node.dst_nodes.pop(self.node_id) + for node in self.link_nodes.values(): + if node.link_nodes: + node.link_nodes.pop(self.node_id) + if self.pre_node: + if self.pre_node.next_nodes: + self.pre_node.next_nodes.pop(self.node_id) + + def find_connected_nodes(self): + """ + 根据 api/类型/入参/调用次数 确定相连接的node的op_name + """ + tar_api = DiffAnalyseConst.P2P_API_MAPPING.get(self.api, self.api) + ranks = set() + # 遍历DST和SRC相关的input,获取对应的rank值 + # 遍历inputs获取所有rank值 + for input_name, v in self.data.inputs.items(): + # 检查key是否包含DST/SRC相关标识 + target_types = [DiffAnalyseConst.DST, DiffAnalyseConst.DST_GROUP, + DiffAnalyseConst.SRC, DiffAnalyseConst.SRC_GROUP] + if any(keyword in input_name for keyword in target_types): + # 优先使用MD5值,如果没有则使用NPU_MAX值 + rank_val = 0 + if CompareConst.NPU_MD5 in v: + rank_val = int(v.get(CompareConst.NPU_MD5, 0)) + else: + rank_val = int(v.get(CompareConst.NPU_MAX, 0)) + if rank_val: + ranks.add(rank_val) + elif input_name.endswith('.group'): + # 优先使用MD5值,如果没有则使用NPU_MAX值 + val = v.get(CompareConst.NPU_MD5) if CompareConst.NPU_MD5 in v else v.get(CompareConst.NPU_MAX) + if val and val.startswith('[') and val.endswith(']'): + val = [int(part) for part in val.strip('[]').split(',')] + ranks.update(val) + + return {'ranks': ranks, 'api': f'Distributed.{tar_api}', + 'type': DiffAnalyseConst.OPPOSITE_DIR.get(self.type, DiffAnalyseConst.LINK)} + + + def _resolve_type(self): + # 遍历SRC和DST相关的输入,根据rank值判断节点类型 + for prefix, node_type in [(DiffAnalyseConst.SRC, DiffAnalyseConst.SRC), + (DiffAnalyseConst.DST, DiffAnalyseConst.DST)]: + for k, v in self.data.inputs.items(): + if prefix in k or f"{prefix}_GROUP" in k: + # 优先使用MD5值,如果没有则使用NPU_MAX值 + compare_val = v.get(CompareConst.NPU_MD5) if CompareConst.NPU_MD5 in v \ + else v.get(CompareConst.NPU_MAX) + return node_type if compare_val == self.rank \ + else DiffAnalyseConst.OPPOSITE_DIR[node_type] + return DiffAnalyseConst.LINK diff --git a/debug/accuracy_tools/msprobe/core/compare/find_first/utils.py b/debug/accuracy_tools/msprobe/core/compare/find_first/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3c4f940dd963f3ac905135c7931a81ae1680a8eb --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/compare/find_first/utils.py @@ -0,0 +1,188 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict +from dataclasses import dataclass +import sys +import time +import psutil + +from msprobe.core.common.file_utils import check_file_or_directory_path, load_json +from msprobe.core.common.const import Const + + +@dataclass +class RankPath: + rank: int + dump_path: str + + def __init__(self, rank, dump_path): + self.rank = rank + check_file_or_directory_path(dump_path) + self.dump_path = dump_path + + +class FileCache: + """ + lazy load file + """ + _instance = None + + def __new__(cls, *args, **kwargs): + if not cls._instance: + cls._instance = super().__new__(cls, *args, **kwargs) + return cls._instance + + def __init__(self): + self._max_memory_usage = psutil.virtual_memory().available / 4 # 最大占用当前可用内存空间的1/4 + self._cache = OrderedDict() + self._access_cnt = {} + self._access_time = {} + self._size = {} + + @staticmethod + def _sizeof(obj): + seen = set() + objs = [obj] + size = 0 + while objs: + obj = objs.pop() + obj_id = id(obj) + if obj_id in seen: + continue + seen.add(obj_id) + size += sys.getsizeof(obj) + if isinstance(obj, dict): + objs.extend(obj.keys()) + objs.extend(obj.values()) + elif isinstance(obj, (list, tuple, set, frozenset)): + objs.extend(obj) + return size + + def load_json(self, json_path): + if json_path in self._cache: + self._access_cnt[json_path] += 1 + self._access_time[json_path] = time.monotonic() + self._cache.move_to_end(json_path) + return self._cache[json_path] + self._cleanup() + return self._load(json_path) + + def _load(self, json_path): + data = load_json(json_path) + self._add_to_cache(json_path, data) + return data + + def _add_to_cache(self, key, data): + if key in self._cache: + self._cache.move_to_end(key) + else: + self._cache[key] = data + self._access_cnt[key] = 0 + self._access_time[key] = time.monotonic() + self._size[key] = self._sizeof(data) + + def _calc_cache_size(self): + return sys.getsizeof(self._cache) + sum(self._size.values()) + + def _cleanup(self): + while self._calc_cache_size() > self._max_memory_usage and self._cache: + least_frequent_key = min(self._access_cnt.keys(), key=lambda k: self._access_cnt[k]) + least_recent_key = min(self._access_time.keys(), key=lambda k: self._access_time[k]) + largest_key = max(self._cache.keys(), key=lambda k: self._size[k]) + key_to_rm = min([least_frequent_key, least_recent_key, largest_key], + key=lambda k: (self._access_cnt[k], self._access_time[k], -self._size[k])) + del self._cache[key_to_rm] + del self._access_cnt[key_to_rm] + del self._access_time[key_to_rm] + del self._size[key_to_rm] + + +def is_communication_op(op_name): + # 定义通信算子的关键字,覆盖各种通信操作,如all_reduce, send, broadcast等 + # 从wrap文件中读取,先硬编码在文件中 + return (op_name.startswith((Const.DISTRIBUTED, Const.MINT_DIST_API_TYPE_PREFIX, Const.MS_API_TYPE_COM)) and + any(keyword in op_name for keyword in DiffAnalyseConst.COMMUNICATION_KEYWORDS)) + + +def is_ignore_op(op_name): + ignore_keywords = [ + 'Torch.empty', + 'Torch.fill' + ] + return any(keyword in op_name for keyword in ignore_keywords) + + +class DiffAnalyseConst: + COMMUNICATION_KEYWORDS = { + 'send', # send 算子 + 'recv', # recv 算子 + 'broadcast', # broadcast 算子 + 'all_reduce', # all_reduce 算子 + 'reduce', # reduce 算子 + 'all_gather', # all_gather 算子 + 'gather', # gather 算子 + 'isend', # isend 算子 + 'irecv', # irecv 算子 + 'scatter', # scatter 算子 + 'reduce_scatter', # reduce_scatter 算子 + '_reduce_scatter_base', # _reduce_scatter_base 算子 + '_all_gather_base', # _all_gather_base 算子 + 'all_to_all_single', # all_to_all_single 算子 + 'all_to_all', # all_to_all 算子 + 'all_gather_into_tensor', # all_gather_into_tensor 算子 + 'reduce_scatter_tensor', # reduce_scatter_tensor 算子 + 'send_object_list', # send_object_list 算子 + 'recv_object_list' # recv_object_list 算子 + } + P2P_API_MAPPING = {'send': 'recv', 'recv': 'send', 'isend': 'irecv', 'irecv': 'isend', + 'send_object_list': 'recv_object_list', 'recv_object_list': 'send_object_list'} + SRC = 'src' + DST = 'dst' + SRC_GROUP = 'group_src' + DST_GROUP = 'group_dst' + LINK = 'link' + DIRECTED_API = {'send': DST, 'recv': SRC, 'isend': DST, 'irecv': SRC, 'broadcast': SRC, 'scatter': SRC, + 'gather': DST, 'send_object_list': DST, 'recv_object_list': SRC} + OPPOSITE_DIR = {SRC: DST, DST: SRC} + DUMP_FILE = "dump.json" + CONSTRUCT_FILE = "construct.json" + STACK_FILE = "stack.json" + + +def analyze_diff_in_group(nodes_group): + diff_nodes = [] + + def get_compute_ops_from_comm_nodes(comm_nodes): + for comm_node in comm_nodes: + for op_node in comm_node.compute_ops: + op_node.layer = comm_node.layer + diff_nodes.append(op_node) + + def get_comm_ops(comm_nodes): + for node in comm_nodes: + node.data.layer = node.layer + diff_nodes.append(node.data) + + # 先看src或link中input是否有异常 + src_list = list(filter(lambda node: node.type in [DiffAnalyseConst.SRC, DiffAnalyseConst.LINK], nodes_group)) + input_diff_nodes = list(filter(lambda node: node.is_diff, src_list)) + # 如果有异常回溯计算节点找到异常来源 + # 使用cpu模拟节点进行计算,查看结果是否有问题。需要对所有计算节点录入/映射,暂不实现。 + get_compute_ops_from_comm_nodes(input_diff_nodes) + # 筛选入参没问题但出参有问题的通信节点 + output_diff_nodes = list(filter(lambda node: node.data.is_diff, nodes_group)) + get_comm_ops(output_diff_nodes) + return diff_nodes \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/core/compare/highlight.py b/debug/accuracy_tools/msprobe/core/compare/highlight.py index cf3e1c4c03e9553f5566870b7c5ebe2d890e9774..6e5aaa232df114b3554d61ac564a1a7072054feb 100644 --- a/debug/accuracy_tools/msprobe/core/compare/highlight.py +++ b/debug/accuracy_tools/msprobe/core/compare/highlight.py @@ -16,10 +16,8 @@ import abc import math import multiprocessing -import re from collections import namedtuple -import numpy as np import openpyxl from openpyxl.styles import PatternFill from openpyxl.utils.dataframe import dataframe_to_rows @@ -28,14 +26,9 @@ from tqdm import tqdm from msprobe.core.common.const import CompareConst, Const from msprobe.core.common.file_utils import save_workbook from msprobe.core.common.log import logger -from msprobe.core.common.utils import get_header_index, safe_get_value -from msprobe.core.compare.utils import table_value_is_valid, get_name_and_state, CompareException - - -class HighlightCheck(abc.ABC): - @abc.abstractmethod - def apply(self, info, color_columns, dump_mode): - raise NotImplementedError +from msprobe.core.common.utils import get_header_index +from msprobe.core.compare.utils import table_value_is_valid, gen_api_batches +from msprobe.core.compare.config import ModeConfig def add_highlight_row_info(color_list, num, highlight_err_msg): @@ -46,6 +39,12 @@ def add_highlight_row_info(color_list, num, highlight_err_msg): color_list.append((num, [highlight_err_msg])) +class HighlightCheck(abc.ABC): + @abc.abstractmethod + def apply(self, info, color_columns, dump_mode): + raise NotImplementedError + + class CheckOrderMagnitude(HighlightCheck): """检查Max diff的数量级差异""" @@ -53,10 +52,12 @@ class CheckOrderMagnitude(HighlightCheck): api_in, api_out, num = info max_diff_index = get_header_index(CompareConst.MAX_DIFF if dump_mode == Const.SUMMARY else CompareConst.MAX_ABS_ERR, dump_mode) - if abs(api_in[max_diff_index]) > abs(api_out[max_diff_index]): + max_diff_in = abs(api_in[max_diff_index]) + max_diff_out = abs(api_out[max_diff_index]) + if max_diff_in > max_diff_out or (max_diff_in <= 1 or max_diff_out <= 1): return - in_order = 0 if abs(api_in[max_diff_index]) < 1 else math.log10(abs(api_in[max_diff_index])) - out_order = 0 if abs(api_out[max_diff_index]) < 1 else math.log10(abs(api_out[max_diff_index])) + in_order = 0 if max_diff_in < 1 else math.log10(max_diff_in) + out_order = 0 if max_diff_out < 1 else math.log10(max_diff_out) if out_order - in_order >= CompareConst.ORDER_MAGNITUDE_DIFF_YELLOW: add_highlight_row_info(color_columns.yellow, num, "maximum absolute error of both input/parameters and output exceed 1, " @@ -75,12 +76,12 @@ class CheckOneThousandErrorRatio(HighlightCheck): if (api_in[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_IN_RED and api_out[one_thousand_index] < CompareConst.ONE_THOUSAND_ERROR_OUT_RED): add_highlight_row_info(color_columns.red, num, - "The input/parameters's one thousandth err ratio exceeds 0.9, " + "The input/parameter's one thousandth err ratio exceeds 0.9, " "while the output's is below 0.6") elif api_in[one_thousand_index] - api_out[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_DIFF_YELLOW: add_highlight_row_info(color_columns.yellow, num, "The output's one thousandth err ratio decreases by more than 0.1 " - "compared to the input/parameters's") + "compared to the input/parameter's") class CheckCosineSimilarity(HighlightCheck): @@ -94,30 +95,38 @@ class CheckCosineSimilarity(HighlightCheck): if api_in[cosine_index] - api_out[cosine_index] > CompareConst.COSINE_DIFF_YELLOW: add_highlight_row_info(color_columns.yellow, num, "The output's cosine decreases by more than 0.1 " - "compared to the input/parameters's") + "compared to the input/parameter's") class CheckMaxRelativeDiff(HighlightCheck): """检查最大相对差异""" def apply(self, info, color_columns, dump_mode): + def get_number(data): + """统计量相对值如果为正常百分数据,str格式并以%结尾""" + if isinstance(data, str) and data.endswith("%"): + return float(data[:-1]) / 100 + return data + api_in, api_out, num = info - max_diff_index = get_header_index(CompareConst.MAX_DIFF, dump_mode) - bench_max_index = get_header_index(CompareConst.BENCH_MAX, dump_mode) - input_max_relative_diff = np.abs( - np.divide(api_in[max_diff_index], max(Const.FLOAT_EPSILON, api_in[bench_max_index]))) - output_max_relative_diff = np.abs( - np.divide(api_out[max_diff_index], max(Const.FLOAT_EPSILON, api_out[bench_max_index]))) - if not isinstance(input_max_relative_diff, (float, int)) or not isinstance(output_max_relative_diff, - (float, int)): + max_rel_diff = get_header_index(CompareConst.MAX_RELATIVE_ERR, dump_mode) + input_max_relative_diff = api_in[max_rel_diff] # 内部数据,长度总是和表头一致,不会越界 + output_max_relative_diff = api_out[max_rel_diff] + input_max_relative_diff = get_number(input_max_relative_diff) + output_max_relative_diff = get_number(output_max_relative_diff) + + if not isinstance(output_max_relative_diff, (float, int)): return if output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_RED: add_highlight_row_info(color_columns.red, num, "maximum relative error exceeds 0.5") - elif (output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_YELLOW and - input_max_relative_diff < CompareConst.MAX_RELATIVE_IN_YELLOW): + + if not isinstance(input_max_relative_diff, (float, int)): + return + if (output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_YELLOW and + input_max_relative_diff < CompareConst.MAX_RELATIVE_IN_YELLOW): add_highlight_row_info(color_columns.yellow, num, "The output's maximum relative error exceeds 0.1, " - "while the input/parameters's is below 0.01") + "while the input/parameter's is below 0.01") class CheckOverflow(HighlightCheck): @@ -138,278 +147,237 @@ class CheckOverflow(HighlightCheck): add_highlight_row_info(color_columns.red, num, "maximum absolute error exceeds 1e+10") +class CheckReqGradConsist(HighlightCheck): + """检查requires_grad是否一致""" + + def apply(self, info, color_columns, dump_mode): + line, num = info + req_grad_consist_index = get_header_index(CompareConst.REQ_GRAD_CONSIST, dump_mode) + if not line[req_grad_consist_index]: + add_highlight_row_info(color_columns.yellow, num, "requires_grad is inconsistent") + + class HighlightRules: """高亮规则集合,用于检查API的误差""" # 适用于每行的规则 basic_rules = { "check_overflow": CheckOverflow() } + consist_rules = { + "check_req_grad_consist": CheckReqGradConsist() + } # 用于比较输入和输出的规则 + # 真实数据检查规则 compare_rules = { "check_order_magnitude": CheckOrderMagnitude(), "check_one_thousand_error": CheckOneThousandErrorRatio(), "check_cosine_similarity": CheckCosineSimilarity() } + # 统计量数据检查规则 summary_compare_rules = { "check_order_magnitude": CheckOrderMagnitude(), "check_max_relative_diff": CheckMaxRelativeDiff(), } -def check_indices_numeric(api_items, indices: list): - """检查指定索引处的值是否都为数字类型(int 或 float)""" - return all(isinstance(api_items[i], (float, int)) for i in indices) - - -def apply_comparison_rules(api_info, dump_mode, color_columns): - """output与input/params的比较""" - if dump_mode == Const.SUMMARY: - for rule in HighlightRules.summary_compare_rules.values(): - rule.apply(api_info, color_columns, dump_mode) - else: - for rule in HighlightRules.compare_rules.values(): - rule.apply(api_info, color_columns, dump_mode) - - -def find_error_rows(result, api_batch, highlight_dict, dump_mode): - """找到单个API中需要高亮的行""" - if dump_mode == Const.MD5: - return - npu_max_index = get_header_index(CompareConst.NPU_MAX, dump_mode) - bench_max_index = get_header_index(CompareConst.BENCH_MAX, dump_mode) - max_diff_index = get_header_index(CompareConst.MAX_DIFF if dump_mode == Const.SUMMARY - else CompareConst.MAX_ABS_ERR, dump_mode) - - red_lines, yellow_lines = [], [] - LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer']) - ApiInfo = namedtuple('ApiInfo', ['api_input', 'api_output', 'num_pointer']) - ColorColumns = namedtuple('ColorColumns', ['red', 'yellow']) - color_columns = ColorColumns(red=red_lines, yellow=yellow_lines) - - api_batch_start = api_batch.start # result_df的input起始全局索引 - api_batch_params_end_index = api_batch.params_end_index # result_df的params结束全局索引 + 1 - api_batch_output_end_index = api_batch.output_end_index # result_df的output结束全局索引 + 1 - api_batch_params_slice_index_local = api_batch_params_end_index - api_batch_start # result的params结束局部切片索引 - api_batch_output_slice_index_local = api_batch_output_end_index - api_batch_start # result的output结束局部切片索引 - - # 对单行API的输入或输出进行误差判断 - for i, line in enumerate(result): - index = api_batch_start + i - line_info = LineInfo(line_data=line, num_pointer=index) - for rule in HighlightRules.basic_rules.values(): - rule.apply(line_info, color_columns, dump_mode) - - # 对API的输出与输入比较,进行误差判断 - for n, api_out in enumerate(result[api_batch_params_slice_index_local: api_batch_output_slice_index_local]): - index = api_batch_start + api_batch_params_slice_index_local + n - # 单行检查只有溢出检查(红色),如果已经溢出,不进一步检查 - if index in red_lines: - continue - if not check_indices_numeric(api_out, [npu_max_index, bench_max_index, max_diff_index]): - continue - - # input/parameters的比较检查, 这里api_in包括input、parameters - for _, api_in in enumerate(result[0: api_batch_params_slice_index_local]): - if not check_indices_numeric(api_in, [npu_max_index, bench_max_index, max_diff_index]): - continue - api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=index) - apply_comparison_rules(api_info, dump_mode, color_columns) - - red_lines_num_set = {x[0] for x in red_lines} - yellow_lines_num_set = {x[0] for x in yellow_lines} - highlight_dict.get('red_rows', set()).update(red_lines_num_set) - highlight_dict.get('yellow_rows', set()).update(yellow_lines_num_set - red_lines_num_set) - highlight_dict.get('red_lines', []).extend(red_lines) - highlight_dict.get('yellow_lines', []).extend(yellow_lines) - - -class ApiBatch: - def __init__(self, api_name: str, start: int): - self.api_name = api_name - self.start = start - self.input_len = 1 # input的数量 - self.params_end_index = start + 1 # params的结束index - self.output_end_index = start + 1 # output的结束index - self.params_grad_end_index = start + 1 # params_grad的结束index - # 内部state的标志("input", "output", "parameters", "parameters_grad"), - # 用于控制计算input_len, output_end_index, params_end_index, self.params_grad_end_index - self._state = Const.INPUT # api_batch初始化为input - - def set_state(self, state: str): - """设置当前状态""" - if state in {Const.INPUT, Const.OUTPUT, Const.KWARGS, Const.PARAMS, Const.PARAMS_GRAD}: - self._state = state - else: - raise ValueError(f"Invalid state: {state}") - - def increment(self, state: str): - self.set_state(state) - if self._state == Const.INPUT or self._state == Const.KWARGS: - self.input_len += 1 - self.params_end_index += 1 - self.output_end_index += 1 - if self._state == Const.PARAMS: - self.params_end_index += 1 - self.output_end_index += 1 - if self._state == Const.OUTPUT: - self.output_end_index += 1 - self.params_grad_end_index += 1 - - -def api_batches_update(api_batches, api_name, state, index): - """ - 当一个api的所有item更新完后,input, output的索引范围: - input: [start: start+input_len] - output: [start+input_len: output_end_index] - params: [output_end_index: params_end_index] - """ - if not api_batches: - api_batches.append(ApiBatch(api_name, index)) - else: - api_batch = api_batches[-1] - if api_batch.api_name == api_name or ( - not re.search(Const.REGEX_FORWARD_BACKWARD, api_name) and api_name in api_batch.api_name): - try: - api_batch.increment(state) - except ValueError as e: - logger.error(f"api_batch: {api_batch} with invalid state, please check! {e}") - raise CompareException(CompareException.INVALID_STATE_ERROR) from e - else: - api_batches.append(ApiBatch(api_name, index)) - - -def find_compare_result_error_rows(result_df, highlight_dict, dump_mode): - """将dataframe根据API分组,并找到有误差的算子用于高亮""" - result = result_df.values - api_batches = [] - for i, res_i in enumerate(result): - api_full_name = safe_get_value(res_i, 0, "res_i") - api_name, state = get_name_and_state(api_full_name) - api_batches_update(api_batches, api_name, state, i) - with tqdm(total=len(api_batches), desc="API/Module Analyse Progress", unit="item", ncols=100) as progress_bar: - for api_batch in api_batches: - find_error_rows(result[api_batch.start: api_batch.params_grad_end_index], api_batch, highlight_dict, - dump_mode) - progress_bar.update(1) - - -def value_check(value, api_name=None, i=None, result_df_columns=None): - if not table_value_is_valid(value): - if result_df_columns: - logger.error(f"Malicious value [{value}] at api_name [{api_name}], column [{result_df_columns[i]}], " - f"is not allowed to be written into the compare result xlsx.") - else: - logger.error(f"Malicious value [{value}] is not allowed to be written into the compare result xlsx.") - - -def df_malicious_value_check(df_chunk, result_df_columns): - for row in df_chunk.itertuples(index=False): - api_name = row[0] - for i, value in enumerate(row): - value_check(value, api_name, i, result_df_columns) +class HighLight: + def __init__(self, mode_config: ModeConfig, rank): + self.mode_config = mode_config + self.rank = rank + @staticmethod + def check_indices_numeric(api_items, indices: list): + """检查指定索引处的值是否都为数字类型(int 或 float)""" + return all(isinstance(api_items[i], (float, int)) for i in indices) -def handle_multi_process_malicious_value_check(func, result_df): - result_total_nums = len(result_df) - process_num = int((multiprocessing.cpu_count() + 1) / 2) - - if result_total_nums <= process_num: - process_num = 1 - chunks = [result_df] - else: - chunk_size = result_total_nums // process_num - chunks = [result_df.iloc[i: i + chunk_size] for i in range(0, result_total_nums, chunk_size)] - - pool = multiprocessing.Pool(process_num) - - def err_call(args): - logger.error("Multiprocessing malicious value check failed! Reason: {}".format(args)) - try: - pool.terminate() - except OSError: - logger.error("Pool terminate failed") - - result_df_columns = result_df.columns.tolist() - for column in result_df_columns: - value_check(column) - for df_chunk in chunks: - pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call) - - pool.close() - pool.join() - - -def compare_result_df_convert(value): - if not isinstance(value, (float, int)) or isinstance(value, bool): # bool类型或者非数字类型转str - value = f"{str(value)}\t" if str(value) in ("inf", "-inf", "nan") else str(value) - if isinstance(value, float): - value = f"{str(value)}\t" if str(value) in ("inf", "-inf", "nan") else value - return value - - -def highlight_rows_xlsx(result_df, highlight_dict, file_path): - """Write and highlight results in Excel""" - - update_highlight_err_msg(result_df, highlight_dict) # add highlight err_msg - - wb = openpyxl.Workbook() - ws = wb.active - - # write header - logger.info('Initializing Excel file.') - - handle_multi_process_malicious_value_check(df_malicious_value_check, result_df) - - result_df_convert = result_df.applymap(compare_result_df_convert) - - for row in dataframe_to_rows(result_df_convert, index=False, header=True): - ws.append(row) - - # 对可疑数据标色 - logger.info('Coloring Excel in progress.') - col_len = len(result_df.columns) - red_fill = PatternFill( - start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid" - ) - yellow_fill = PatternFill( - start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid", - ) - for i in highlight_dict.get("red_rows", []): - for j in range(1, col_len + 1): - ws.cell(row=i + 2, column=j).fill = red_fill # 2因为ws.cell中的row或column需要>=1,数据从第2行开始 - for i in highlight_dict.get("yellow_rows", []): - for j in range(1, col_len + 1): - ws.cell(row=i + 2, column=j).fill = yellow_fill - - logger.info('Saving Excel file to disk: %s' % file_path) - save_workbook(wb, file_path) - - -def update_highlight_err_msg(result_df, highlight_dict): - if result_df.shape[1] <= 1: - return - - if CompareConst.NPU_MD5 in result_df.columns: - return + @staticmethod + def update_highlight_err_msg(result_df, highlight_dict): + if result_df.shape[1] <= 1: + return - err_msg = result_df.get(CompareConst.ERROR_MESSAGE) - red_lines_num_set = highlight_dict.get('red_rows') + if CompareConst.NPU_MD5 in result_df.columns: + return - for color in ['red', 'yellow']: - line_key = f'{color}_lines' - lines = highlight_dict.get(line_key, []) - for line_index, messages in lines: - if color == 'yellow' and line_index in red_lines_num_set: - continue # 如果是 yellow 行,且已被 red 行覆盖,跳过 + err_msg = result_df.get(CompareConst.ERROR_MESSAGE).copy() + red_lines_num_set = highlight_dict.get('red_rows') + + for color in ['red', 'yellow']: + line_key = f'{color}_lines' + lines = highlight_dict.get(line_key, []) + for line_index, messages in lines: + if color == 'yellow' and line_index in red_lines_num_set: + continue # 如果是 yellow 行,且已被 red 行覆盖,跳过 + + for msg in messages: + if err_msg[line_index] == '': + err_msg[line_index] = msg + else: + err_msg[line_index] += '\n' + msg + + if color == 'red': + red_lines_num_set.add(line_index) + + result_df[CompareConst.ERROR_MESSAGE] = err_msg + + @staticmethod + def compare_result_df_convert(value): + if not isinstance(value, (float, int)) or isinstance(value, bool): # bool类型或者非数字类型转str + value = f"{str(value)}\t" if str(value) in ("inf", "-inf", "nan") else str(value) + if isinstance(value, float): + value = f"{str(value)}\t" if str(value) in ("inf", "-inf", "nan") else value + return value + + @staticmethod + def value_check(value, api_name=None, i=None, result_df_columns=None): + if not table_value_is_valid(value): + if result_df_columns: + logger.error(f"Malicious value [{value}] at api_name [{api_name}], column [{result_df_columns[i]}], " + f"is not allowed to be written into the compare result xlsx.") + else: + logger.error(f"Malicious value [{value}] is not allowed to be written into the compare result xlsx.") + + def find_compare_result_error_rows(self, result_df, highlight_dict): + """将dataframe根据API分组,并找到有误差的算子用于高亮""" + result = result_df.values + header = result_df.columns.tolist() + api_batches = gen_api_batches(result, header) + default_bar_desc = 'API/Module Analyse Progress' + bar_desc_add_rank = f'[{self.rank}]' + default_bar_desc if self.rank else default_bar_desc + with tqdm(total=len(api_batches), desc=bar_desc_add_rank, unit="item", ncols=100) as progress_bar: + for api_batch in api_batches: + self.find_error_rows(result[api_batch.start: api_batch.params_grad_end_index], api_batch, + highlight_dict) + progress_bar.update(1) + + def find_error_rows(self, result, api_batch, highlight_dict): + """找到单个API中需要高亮的行""" + if self.mode_config.dump_mode == Const.MD5: + return + npu_max_index = get_header_index(CompareConst.NPU_MAX, self.mode_config.dump_mode) + bench_max_index = get_header_index(CompareConst.BENCH_MAX, self.mode_config.dump_mode) + max_diff_index = get_header_index(CompareConst.MAX_DIFF if self.mode_config.dump_mode == Const.SUMMARY + else CompareConst.MAX_ABS_ERR, self.mode_config.dump_mode) + + red_lines, yellow_lines = [], [] + LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer']) + ApiInfo = namedtuple('ApiInfo', ['api_input', 'api_output', 'num_pointer']) + ColorColumns = namedtuple('ColorColumns', ['red', 'yellow']) + color_columns = ColorColumns(red=red_lines, yellow=yellow_lines) + + api_batch_start = api_batch.start # result_df的input起始全局索引 + api_batch_params_end_index = api_batch.params_end_index # result_df的params结束全局索引 + 1 + api_batch_output_end_index = api_batch.output_end_index # result_df的output结束全局索引 + 1 + api_batch_params_slice_index_local = api_batch_params_end_index - api_batch_start # result的params结束局部切片索引 + api_batch_output_slice_index_local = api_batch_output_end_index - api_batch_start # result的output结束局部切片索引 + + # 对单行API的输入或输出进行误差判断 + for i, line in enumerate(result): + index = api_batch_start + i + line_info = LineInfo(line_data=line, num_pointer=index) + for rule in HighlightRules.basic_rules.values(): + rule.apply(line_info, color_columns, self.mode_config.dump_mode) + + # 对API的输出与输入比较,进行误差判断 + for n, api_out in enumerate(result[api_batch_params_slice_index_local: api_batch_output_slice_index_local]): + index = api_batch_start + api_batch_params_slice_index_local + n + # 单行检查只有溢出检查(红色),如果已经溢出,不进一步检查 + if index in red_lines: + continue + if not self.check_indices_numeric(api_out, [npu_max_index, bench_max_index, max_diff_index]): + continue - for msg in messages: - if err_msg[line_index] == '': - err_msg[line_index] = msg - else: - err_msg[line_index] += '\n' + msg + # input/parameters的比较检查, 这里api_in包括input、parameters + for api_in in result[0: api_batch_params_slice_index_local]: + if not self.check_indices_numeric(api_in, [npu_max_index, bench_max_index, max_diff_index]): + continue + api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=index) + self.apply_comparison_rules(api_info, color_columns) + + # 对单行API的输入或输出进行requires_grad是否一致判断 + for i, line in enumerate(result): + index = api_batch_start + i + line_info = LineInfo(line_data=line, num_pointer=index) + for rule in HighlightRules.consist_rules.values(): + rule.apply(line_info, color_columns, self.mode_config.dump_mode) + + red_lines_num_set = {x[0] for x in red_lines} + yellow_lines_num_set = {x[0] for x in yellow_lines} + highlight_dict.get('red_rows', set()).update(red_lines_num_set) + highlight_dict.get('yellow_rows', set()).update(yellow_lines_num_set - red_lines_num_set) + highlight_dict.get('red_lines', []).extend(red_lines) + highlight_dict.get('yellow_lines', []).extend(yellow_lines) + + def apply_comparison_rules(self, api_info, color_columns): + """output与input/params的比较""" + if self.mode_config.dump_mode == Const.SUMMARY: + for rule in HighlightRules.summary_compare_rules.values(): + rule.apply(api_info, color_columns, self.mode_config.dump_mode) + else: + for rule in HighlightRules.compare_rules.values(): + rule.apply(api_info, color_columns, self.mode_config.dump_mode) + + def highlight_rows_xlsx(self, result_df, highlight_dict, file_path): + """Write and highlight results in Excel""" + + self.update_highlight_err_msg(result_df, highlight_dict) # add highlight err_msg + + self.df_malicious_value_check(result_df) + + wb = openpyxl.Workbook() + ws = wb.active + result_df_convert = result_df.applymap(self.compare_result_df_convert) + for row in dataframe_to_rows(result_df_convert, index=False, header=True): + ws.append(row) + + # 对可疑数据标色 + logger.info('Coloring Excel in progress.') + red_fill = PatternFill(start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid") + yellow_fill = PatternFill(start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid") + col_len = len(result_df.columns) + for i in highlight_dict.get("red_rows", []): + for j in range(1, col_len + 1): + ws.cell(row=i + 2, column=j).fill = red_fill # 2因为ws.cell中的row或column需要>=1,数据从第2行开始 + for i in highlight_dict.get("yellow_rows", []): + for j in range(1, col_len + 1): + ws.cell(row=i + 2, column=j).fill = yellow_fill + + save_workbook(wb, file_path) + + def handle_multi_process_malicious_value_check(self, func, result_df): + result_total_nums = len(result_df) + process_num = int((multiprocessing.cpu_count() + 1) / 2) + + if result_total_nums <= process_num: + process_num = 1 + chunks = [result_df] + else: + chunk_size = result_total_nums // process_num + chunks = [result_df.iloc[i: i + chunk_size] for i in range(0, result_total_nums, chunk_size)] - if color == 'red': - red_lines_num_set.add(line_index) + pool = multiprocessing.Pool(process_num) - result_df[CompareConst.ERROR_MESSAGE] = err_msg + def err_call(args): + logger.error("Multiprocessing malicious value check failed! Reason: {}".format(args)) + try: + pool.close() + except OSError: + logger.error("Pool terminate failed") + + result_df_columns = result_df.columns.tolist() + for column in result_df_columns: + self.value_check(column) + for df_chunk in chunks: + pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call) + + pool.close() + pool.join() + + def df_malicious_value_check(self, result_df): + result_df_columns = result_df.columns.tolist() + for column in result_df_columns: + self.value_check(column) + for row in result_df.itertuples(index=False): + api_name = row[0] + for i, value in enumerate(row): + self.value_check(value, api_name, i, result_df_columns) diff --git a/debug/accuracy_tools/msprobe/core/compare/layer_mapping/layer_mapping.py b/debug/accuracy_tools/msprobe/core/compare/layer_mapping/layer_mapping.py index d0f19462ee1ccf4d72c69885c18174cec32df056..fd45ef1488d881ebc062e3589ec29c947d59c750 100644 --- a/debug/accuracy_tools/msprobe/core/compare/layer_mapping/layer_mapping.py +++ b/debug/accuracy_tools/msprobe/core/compare/layer_mapping/layer_mapping.py @@ -18,12 +18,12 @@ from collections import defaultdict from msprobe.core.common.const import CompareConst, Const from msprobe.core.common.file_utils import load_json, load_yaml, save_yaml -from msprobe.core.common.utils import (add_time_with_yaml, - detect_framework_by_dump_json, - get_stack_construct_by_dump_json_path) +from msprobe.core.common.utils import add_time_with_yaml, detect_framework_by_dump_json, \ + get_stack_construct_by_dump_json_path, CompareException from msprobe.core.compare.layer_mapping.data_scope_parser import get_dump_data_items from msprobe.core.compare.utils import read_op, reorder_op_name_list - +from msprobe.core.common.decorator import recursion_depth_decorator +from msprobe.core.common.log import logger class LayerTrie: @@ -63,7 +63,11 @@ class LayerTrie: node = node.children[name] if index >= len(node.data_items[state]): return default_value - return node.data_items[state][index] + if node.data_items[state]: + return node.data_items[state][index] + else: + logger.error(f"node.data_items of state:{state} is empty, please check.") + raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) def save_to_yaml(self, output_path): result = {f"{self.type_name} @ {self}": self.convert_to_dict(self)} @@ -71,6 +75,7 @@ class LayerTrie: file_path = os.path.join(os.path.realpath(output_path), file_name) save_yaml(file_path, result) + @recursion_depth_decorator("LayerMapping: LayerTrie.convert_to_dict", max_depth=100) def convert_to_dict(self, node): result = {} result["data_item"] = {st: [dt.data_name for dt in dts] for st, dts in node.data_items.items()} @@ -163,6 +168,8 @@ def preprocess_layer_mapping(mapping): for key, value in name_map.items(): key_list = key.split('.') prefix = key_list[0] # 取前缀 + value_list = value.split('(') + value = value_list[0] # 取前缀 key_len = len(key_list) if prefix not in final_mapping[type_name]: final_mapping[type_name][prefix] = [] @@ -205,7 +212,8 @@ def generate_data_mapping(npu_json_path, bench_json_path, api_mapping, output_pa def read_full_op_names(data, op_name): op_parsed_list = read_op(data.get(op_name, {}), op_name) full_op_names = [op_parsed.get('full_op_name') for op_parsed in op_parsed_list] - return full_op_names + states = [op_parsed.get(Const.STATE) for op_parsed in op_parsed_list] + return full_op_names, states def generate_op_data_mapping(npu_op_name, npu_full_op_names, bench_op_name, bench_full_op_names): suffix_to_full_op_name = {} @@ -225,10 +233,10 @@ def generate_data_mapping(npu_json_path, bench_json_path, api_mapping, output_pa for npu_op_name, bench_op_name in api_mapping.items(): if not npu_op_name: continue - npu_full_op_names = read_full_op_names(npu_data, npu_op_name) - bench_full_op_names = read_full_op_names(bench_data, bench_op_name) - npu_full_op_names_reorder = reorder_op_name_list(npu_full_op_names) - bench_full_op_names_reorder = reorder_op_name_list(bench_full_op_names) + npu_full_op_names, npu_states = read_full_op_names(npu_data, npu_op_name) + bench_full_op_names, bench_states = read_full_op_names(bench_data, bench_op_name) + npu_full_op_names_reorder, _ = reorder_op_name_list(npu_full_op_names, npu_states) + bench_full_op_names_reorder, _ = reorder_op_name_list(bench_full_op_names, bench_states) mapping = generate_op_data_mapping(npu_op_name, npu_full_op_names_reorder, bench_op_name, bench_full_op_names_reorder) data_mapping.update(mapping) diff --git a/debug/accuracy_tools/msprobe/core/compare/merge_result/merge_result.py b/debug/accuracy_tools/msprobe/core/compare/merge_result/merge_result.py index b605bd59fca0b2b3a510a7a686caa94383488bd2..97ddc7ba22e4e8a122e5006f1f6acbe6712ab004 100644 --- a/debug/accuracy_tools/msprobe/core/compare/merge_result/merge_result.py +++ b/debug/accuracy_tools/msprobe/core/compare/merge_result/merge_result.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,7 +21,8 @@ from functools import partial import pandas as pd from tqdm import tqdm -from msprobe.core.common.file_utils import load_yaml, logger, FileChecker, save_excel, read_xlsx, create_directory +from msprobe.core.common.file_utils import load_yaml, logger, FileChecker, save_excel, read_xlsx, create_directory, \ + remove_path from msprobe.core.common.const import FileCheckConst, Const, CompareConst from msprobe.core.common.utils import CompareException, add_time_with_xlsx from msprobe.core.compare.utils import table_value_is_valid @@ -32,8 +33,8 @@ def check_compare_result_name(file_name): """ check whether the compare result name is as expected """ - single_rank_pattern = r"^compare_result_rank-rank_\d{14}.xlsx$" - multi_ranks_pattern = r"^compare_result_rank(\d+)-rank\1_\d{14}.xlsx$" + single_rank_pattern = r"^compare_result_(rank|rank-rank)_\d{14}\.xlsx$" + multi_ranks_pattern = r"^compare_result_rank(\d+)(?:-rank\1)?_\d{14}\.xlsx$" if re.match(multi_ranks_pattern, file_name): return True if re.match(single_rank_pattern, file_name): @@ -47,7 +48,7 @@ def reorder_path(compare_result_path_list): """ reorder compare results by rank num """ - rank_pattern = r"compare_result_rank(\d+)-rank" + rank_pattern = r"compare_result_rank(\d+)" reorder_path_list = sorted( compare_result_path_list, key=lambda path: int(re.search(rank_pattern, os.path.basename(path)).group(1)) @@ -63,6 +64,7 @@ def get_result_path(input_dir): for f in os.listdir(input_dir) if f.endswith(FileCheckConst.XLSX_SUFFIX)] filt_compare_result_path_list = [] for file_path in compare_result_path_list: + FileChecker(file_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check() file_name = os.path.basename(file_path) if check_compare_result_name(file_name): compare_result_path_checker = FileChecker(file_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE) @@ -107,8 +109,8 @@ def check_index_dump_mode_consistent(dump_mode, rank_num): return [] dump_mode_compare_index_map = { - Const.ALL: CompareConst.ALL_COMPARE_INDEX, - Const.SUMMARY: CompareConst.SUMMARY_COMPARE_INDEX + Const.ALL: CompareConst.ALL_COMPARE_INDEX + [CompareConst.REQ_GRAD_CONSIST], + Const.SUMMARY: CompareConst.SUMMARY_COMPARE_INDEX + [CompareConst.REQ_GRAD_CONSIST] } valid_compare_index = dump_mode_compare_index_map.get(dump_mode) @@ -194,7 +196,7 @@ def result_process(compare_result_path_list, api_list): compare_index_dict = {} result_df = read_xlsx(compare_result_path) - rank_pattern = r"compare_result_rank(\d+)-rank" + rank_pattern = r"compare_result_rank(\d+)" rank_num = int(re.search(rank_pattern, os.path.basename(compare_result_path)).group(1)) logger.info(f"Parsing rank{rank_num} compare result...") if not result_df.empty: @@ -236,7 +238,7 @@ def handle_multi_process(func, func_args, lock): def err_call(args): logger.error('Multiprocess merge result failed! Reason: {}'.format(args)) try: - pool.terminate() + pool.close() except OSError: logger.error("Pool terminate failed") @@ -329,6 +331,10 @@ def generate_merge_result(all_compare_index_dict_list, all_rank_num_list, all_co for i, df in enumerate(merge_df_list): # merge_df_list中df与compare_index_list中compare_index一一对应 final_result_df_list.append((df, compare_index_list[i])) + + if os.path.exists(output_path): + logger.warning(f"{output_path} will be deleted.") + remove_path(output_path) save_excel(output_path, final_result_df_list) logger.info(f"The compare results of the multi-ranks are merged and saved in: {output_path}.") diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/ms_to_pt_api.yaml b/debug/accuracy_tools/msprobe/core/compare/ms_to_pt_api.yaml similarity index 100% rename from debug/accuracy_tools/msprobe/mindspore/compare/ms_to_pt_api.yaml rename to debug/accuracy_tools/msprobe/core/compare/ms_to_pt_api.yaml diff --git a/debug/accuracy_tools/msprobe/core/compare/multiprocessing_compute.py b/debug/accuracy_tools/msprobe/core/compare/multiprocessing_compute.py index c2c1461e452f9d2c7f4e0e2803dfe51be2a132c0..12b253e16a0bd68e5c8ab3ae13ec832aa374dcae 100644 --- a/debug/accuracy_tools/msprobe/core/compare/multiprocessing_compute.py +++ b/debug/accuracy_tools/msprobe/core/compare/multiprocessing_compute.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,51 +15,28 @@ import multiprocessing from dataclasses import dataclass +from functools import partial + import pandas as pd from tqdm import tqdm + from msprobe.core.common.log import logger from msprobe.core.common.utils import CompareException from msprobe.core.common.const import CompareConst +from msprobe.core.common.exceptions import FileCheckException +from msprobe.core.compare.npy_compare import compare_ops_apply, get_error_flag_and_msg +from msprobe.core.compare.config import ModeConfig -def _handle_multi_process(func, input_parma, result_df, lock): - process_num = max(int((multiprocessing.cpu_count() + 1) // 4), 1) - op_name_mapping_dict = read_dump_data(result_df) - - df_chunk_size = len(result_df) // process_num - if df_chunk_size > 0: - df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)] - else: - df_chunks = [result_df] - - results = [] - pool = multiprocessing.Pool(process_num) - - def err_call(args): - logger.error('multiprocess compare failed! Reason: {}'.format(args)) - try: - pool.terminate() - except OSError as e: - logger.error("pool terminate failed") - - progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100) - - def update_progress(size, progress_lock): - with progress_lock: - progress_bar.update(size) - - for process_idx, df_chunk in enumerate(df_chunks): - idx = df_chunk_size * process_idx - chunk_size = len(df_chunk) - result = pool.apply_async(func, - args=(idx, op_name_mapping_dict, df_chunk, lock, input_parma), - error_callback=err_call, - callback=update_progress(chunk_size, lock)) - results.append(result) - final_results = [r.get() for r in results] - pool.close() - pool.join() - return pd.concat(final_results, ignore_index=True) +@dataclass +class ComparisonResult: + cos_result: list + euc_dist_result: list + max_err_result: list + max_relative_err_result: list + one_thousand_err_ratio_result: list + five_thousand_err_ratio_result: list + err_msgs: list def _ms_graph_handle_multi_process(func, result_df, mode): @@ -76,9 +53,9 @@ def _ms_graph_handle_multi_process(func, result_df, mode): def err_call(args): logger.error('multiprocess compare failed! Reason: {}'.format(args)) try: - pool.terminate() + pool.close() except OSError as e: - logger.error("pool terminate failed") + logger.error(f'pool terminate failed: {str(e)}') for df_chunk in df_chunks: result = pool.apply_async(func, args=(df_chunk, mode), error_callback=err_call) @@ -89,72 +66,6 @@ def _ms_graph_handle_multi_process(func, result_df, mode): return pd.concat(final_results, ignore_index=True) -def read_dump_data(result_df): - try: - npu_dump_name_list = result_df.iloc[0:, 0].tolist() - npu_dump_tensor_list = result_df.iloc[0:, -1].tolist() - op_name_mapping_dict = {} - for index, _ in enumerate(npu_dump_name_list): - npu_dump_name = npu_dump_name_list[index] - npu_dump_tensor = npu_dump_tensor_list[index] - op_name_mapping_dict[npu_dump_name] = [npu_dump_tensor, npu_dump_tensor] - return op_name_mapping_dict - except ValueError as e: - logger.error('result dataframe is not found.') - raise CompareException(CompareException.INVALID_DATA_ERROR) from e - except IndexError as e: - logger.error('result dataframe elements can not be access.') - raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e - - -@dataclass -class ComparisonResult: - cos_result: list - max_err_result: list - max_relative_err_result: list - err_msgs: list - one_thousand_err_ratio_result: list - five_thousand_err_ratio_result: list - - -def _save_cmp_result(offset, result: ComparisonResult, result_df, lock): - """ - Save comparison results into the result DataFrame with thread safety. - Args: - offset: offset for index - result: data struct of ComparisonResult - result_df: result of DataFrame - lock: thread lock - - Returns: - comparison results in DataFrame - """ - - lock.acquire() - try: - for i, _ in enumerate(result.cos_result): - process_index = i + offset - result_df.loc[process_index, CompareConst.COSINE] = result.cos_result[i] - result_df.loc[process_index, CompareConst.MAX_ABS_ERR] = result.max_err_result[i] - result_df.loc[process_index, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[i] - result_df.loc[process_index, CompareConst.ERROR_MESSAGE] = result.err_msgs[i] - result_df.loc[process_index, CompareConst.ACCURACY] = ( - check_accuracy(result.cos_result[i], result.max_err_result[i])) - result_df.loc[process_index, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = ( - result.one_thousand_err_ratio_result)[i] - result_df.loc[process_index, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = ( - result.five_thousand_err_ratio_result)[i] - return result_df - except ValueError as e: - logger.error('result dataframe is not found.') - raise CompareException(CompareException.INVALID_DATA_ERROR) from e - except IndexError as e: - logger.error('result dataframe elements can not be access.') - raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e - finally: - lock.release() - - def check_accuracy(cos, max_abs_err): if cos == CompareConst.SHAPE_UNMATCH: return CompareConst.ACCURACY_CHECK_UNMATCH @@ -172,3 +83,222 @@ def check_accuracy(cos, max_abs_err): if cos < CompareConst.COS_MAX_THRESHOLD or max_abs_err > CompareConst.MAX_ABS_ERR_MAX_THRESHOLD: return CompareConst.ACCURACY_CHECK_NO return CompareConst.ACCURACY_CHECK_YES + + +class CompareRealData: + def __init__(self, file_reader, mode_config: ModeConfig, cross_frame): + self.file_reader = file_reader + self.mode_config = mode_config + self.cross_frame = cross_frame + + @staticmethod + def read_dump_data(result_df): + try: + npu_dump_name_list = result_df.loc[0:, CompareConst.NPU_NAME].tolist() + dump_tensor_pair_list = result_df.loc[0:, CompareConst.DATA_NAME].tolist() + op_name_mapping_dict = {} + for index, npu_dump_name in enumerate(npu_dump_name_list): + dump_tensor_pair = dump_tensor_pair_list[index] + op_name_mapping_dict[npu_dump_name] = dump_tensor_pair + return op_name_mapping_dict + except ValueError as e: + logger.error('result dataframe is not found.') + raise CompareException(CompareException.INVALID_DATA_ERROR) from e + except KeyError as e: + logger.error('result dataframe elements can not be access.') + raise CompareException(CompareException.INVALID_KEY_ERROR) from e + + @staticmethod + def _save_cmp_result(offset, result: ComparisonResult, result_df, lock): + """ + Save comparison results into the result DataFrame with thread safety. + Args: + offset: offset for index + result: data struct of ComparisonResult + result_df: result of DataFrame + lock: thread lock + + Returns: + comparison results in DataFrame + """ + + lock.acquire() + try: + for i, cos_item in enumerate(result.cos_result): + process_index = i + offset + result_df.loc[process_index, CompareConst.COSINE] = cos_item + result_df.loc[process_index, CompareConst.EUC_DIST] = result.euc_dist_result[i] + result_df.loc[process_index, CompareConst.MAX_ABS_ERR] = result.max_err_result[i] + result_df.loc[process_index, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[i] + result_df.loc[process_index, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = ( + result.one_thousand_err_ratio_result)[i] + result_df.loc[process_index, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = ( + result.five_thousand_err_ratio_result)[i] + result_df.loc[process_index, CompareConst.ACCURACY] = ( + check_accuracy(result.cos_result[i], result.max_err_result[i])) + result_df.loc[process_index, CompareConst.ERROR_MESSAGE] += result.err_msgs[i] + return result_df + except ValueError as e: + logger.error('result dataframe is not found.') + raise CompareException(CompareException.INVALID_DATA_ERROR) from e + except IndexError as e: + logger.error('result dataframe elements can not be access.') + raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e + finally: + lock.release() + + def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param): + """ + :param npu_op_name: excel中的NPU_Name,例如:MintFunctional.conv2d.0.forward.input.3.0 + :param bench_op_name: excel中的Bench_Name,例如:Functional.conv2d.0.forward.input.3.0 + :param op_name_mapping_dict: op_name和npy或pt文件的映射关系 + :param input_param: npu_json_path/bench_json_path/stack_json_path等参数 + :return: result_list,包含余弦相似度、最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率和错误信息 + 用于读取excel中的NPU_Name和Bench_Name,根据映射关系找到npy或pt文件,然后读取文件中的数据进行比较,计算余弦相似度、欧式距离 + 最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率并生成错误信息 + """ + relative_err, error_flag, err_msg = None, False, None + + data_name_pair = op_name_mapping_dict.get(npu_op_name) + npu_data_name = data_name_pair[0] + bench_data_name = data_name_pair[1] + + error_file = data_name_pair + + if str(npu_data_name) == CompareConst.NO_REAL_DATA_FLAG: # 没有npu真实数据 + n_value, b_value, error_flag = CompareConst.NO_REAL_DATA, CompareConst.NO_REAL_DATA, True + err_msg = "NPU does not have data file." + elif str(bench_data_name) == CompareConst.NO_REAL_DATA_FLAG: # 没有bench真实数据 + n_value, b_value, error_flag = CompareConst.NO_REAL_DATA, CompareConst.NO_REAL_DATA, True + err_msg = "Bench does not have data file." + elif str(bench_data_name) == CompareConst.N_A: # bench没匹配 + n_value, b_value, error_flag = CompareConst.API_UNMATCH, CompareConst.API_UNMATCH, True + err_msg = "Bench api/module unmatched." + else: + npu_dir = input_param.get(CompareConst.NPU_DUMP_DATA_DIR) + bench_dir = input_param.get(CompareConst.BENCH_DUMP_DATA_DIR) + try: + n_value, b_value = self.file_reader(npu_dir, npu_data_name, bench_dir, bench_data_name, + self.cross_frame) + except IOError as error: + error_file = error.filename + n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE + error_flag = True + except (FileCheckException, CompareException): + error_file = data_name_pair + n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE + error_flag = True + + # 通过n_value, b_value同时得到错误标志和错误信息 + if not err_msg: + n_value, b_value, error_flag, err_msg = get_error_flag_and_msg(n_value, b_value, error_flag=error_flag, + error_file=error_file) + + result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg) + + if self.mode_config.fuzzy_match and npu_op_name != bench_op_name and bench_op_name != CompareConst.N_A: + err_msg += " Fuzzy matching data, the comparison accuracy may be affected." + result_list.append(err_msg) + return result_list + + def compare_ops(self, idx, dump_path_dict, result_df, lock, input_param): + cos_result = [] + euc_dist_result = [] + max_err_result = [] + max_relative_err_result = [] + one_thousand_err_ratio_result = [] + five_thousand_err_ratio_result = [] + err_mess = [] + + is_print_compare_log = input_param.get("is_print_compare_log") + + for i in range(len(result_df)): + npu_op_name = result_df.iloc[i, 0] + bench_op_name = result_df.iloc[i, 1] + if is_print_compare_log: + logger.info("start compare: {}".format(npu_op_name)) + + cos_sim, euc_dist, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg \ + = self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param) + + if is_print_compare_log: + if "does not have data file" in err_msg: + logger.info(f"[{npu_op_name}] Compare result: {err_msg} ") + elif "Bench api/module unmatched" in err_msg: + logger.info(f"[{npu_op_name}] Compare result: {err_msg} ") + else: + logger.info( + f"[{npu_op_name}] Compare result: cosine {cos_sim}, euc_dist {euc_dist}, " + f"max_abs_err {max_abs_err}, max_relative_err {max_relative_err}, " + f"one_thousand_err_ratio {one_thousand_err_ratio}, " + f"five_thousand_err_ratio {five_thousand_err_ratio}, {err_msg}" + ) + cos_result.append(cos_sim) + euc_dist_result.append(euc_dist) + max_err_result.append(max_abs_err) + max_relative_err_result.append(max_relative_err) + one_thousand_err_ratio_result.append(one_thousand_err_ratio) + five_thousand_err_ratio_result.append(five_thousand_err_ratio) + err_mess.append(err_msg) + + cr = ComparisonResult( + cos_result=cos_result, + euc_dist_result=euc_dist_result, + max_err_result=max_err_result, + max_relative_err_result=max_relative_err_result, + one_thousand_err_ratio_result=one_thousand_err_ratio_result, + five_thousand_err_ratio_result=five_thousand_err_ratio_result, + err_msgs=err_mess + ) + + return self._save_cmp_result(idx, cr, result_df, lock) + + def do_multi_process(self, input_param, result_df): + try: + result_df = self._handle_multi_process(self.compare_ops, input_param, result_df, + multiprocessing.Manager().RLock()) + return result_df + except ValueError as e: + logger.error('result dataframe is not found.') + raise CompareException(CompareException.INVALID_DATA_ERROR) from e + + def _handle_multi_process(self, func, input_param, result_df, lock): + process_num = max(int((multiprocessing.cpu_count() + 1) // 4), 1) + op_name_mapping_dict = self.read_dump_data(result_df) + + df_chunk_size = len(result_df) // process_num + if df_chunk_size > 0: + df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)] + else: + df_chunks = [result_df] + + results = [] + pool = multiprocessing.Pool(process_num) + + def err_call(args): + logger.error('multiprocess compare failed! Reason: {}'.format(args)) + try: + pool.close() + except OSError: + logger.error("pool terminate failed") + + progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100) + + def update_progress(size, progress_lock, extra_param=None): + with progress_lock: + progress_bar.update(size) + + for process_idx, df_chunk in enumerate(df_chunks): + idx = df_chunk_size * process_idx + chunk_size = len(df_chunk) + result = pool.apply_async(func, + args=(idx, op_name_mapping_dict, df_chunk, lock, input_param), + error_callback=err_call, + callback=partial(update_progress, chunk_size, lock) + ) + results.append(result) + + final_results = [r.get() for r in results] + pool.close() + pool.join() + return pd.concat(final_results, ignore_index=True) diff --git a/debug/accuracy_tools/msprobe/core/compare/npy_compare.py b/debug/accuracy_tools/msprobe/core/compare/npy_compare.py index c551985780cb9b56e32573727f9bf88f274da24e..b0df6017ffd22bb5b40f25f76dead9f8d335e577 100644 --- a/debug/accuracy_tools/msprobe/core/compare/npy_compare.py +++ b/debug/accuracy_tools/msprobe/core/compare/npy_compare.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -56,13 +56,7 @@ def get_error_flag_and_msg(n_value, b_value, error_flag=False, error_file=None): """判断数据是否有异常并返回异常的n_value, b_value,同时返回error_flag和error_msg""" err_msg = "" if error_flag: - if error_file == "no_bench_data": - err_msg = "Bench does not have data file." - elif error_file: - err_msg = f"Dump file: {error_file} not found." - else: - err_msg = CompareConst.NO_BENCH - error_flag = True + err_msg = f"Dump file: {error_file} not found or read failed." return CompareConst.READ_NONE, CompareConst.READ_NONE, error_flag, err_msg if n_value.size == 0: # 判断读取到的数据是否为空 @@ -70,7 +64,7 @@ def get_error_flag_and_msg(n_value, b_value, error_flag=False, error_file=None): error_flag = True return CompareConst.NONE, CompareConst.NONE, error_flag, err_msg if not n_value.shape: # 判断数据是否为0维张量 - err_msg = (f"This is type of 0-d tensor, can not calculate '{CompareConst.COSINE}', " + err_msg = (f"This is type of 0-d tensor, can not calculate '{CompareConst.COSINE}', '{CompareConst.EUC_DIST}', " f"'{CompareConst.ONE_THOUSANDTH_ERR_RATIO}' and '{CompareConst.FIVE_THOUSANDTHS_ERR_RATIO}'. ") error_flag = False # 0-d tensor 最大绝对误差、最大相对误差仍然支持计算,因此error_flag设置为False,不做统一处理 return n_value, b_value, error_flag, err_msg @@ -168,16 +162,19 @@ def statistics_data_check(result_dict): class TensorComparisonBasic(abc.ABC): """NPU和bench中npy数据的比较模板""" + @abc.abstractmethod - def apply(self, n_value, b_value, relative_err): + def apply(self, n_value, b_value, relative_err, err_msg): raise NotImplementedError def get_relative_err(n_value, b_value): """计算相对误差""" with np.errstate(divide='ignore', invalid='ignore'): + if n_value.dtype not in CompareConst.FLOAT_TYPE: + n_value = n_value.astype(float) if b_value.dtype not in CompareConst.FLOAT_TYPE: - n_value, b_value = n_value.astype(float), b_value.astype(float) + b_value = b_value.astype(float) n_value_copy = n_value.copy() b_value_copy = b_value.copy() @@ -190,6 +187,7 @@ def get_relative_err(n_value, b_value): class GetCosineSimilarity(TensorComparisonBasic): """计算cosine相似度""" + @staticmethod def correct_data(result): if result == CompareConst.NAN: @@ -198,9 +196,9 @@ class GetCosineSimilarity(TensorComparisonBasic): return round(float(result), 6) return result - def apply(self, n_value, b_value, relative_err): - if not n_value.shape: - return CompareConst.UNSUPPORTED, "" + def apply(self, n_value, b_value, relative_err, err_msg): + if "This is type of 0-d tensor" in err_msg: + return CompareConst.UNSUPPORTED, err_msg with np.errstate(divide="ignore", invalid="ignore"): if len(n_value) == 1: @@ -224,9 +222,22 @@ class GetCosineSimilarity(TensorComparisonBasic): return result, "" +class GetEuclideanDistance(TensorComparisonBasic): + """计算欧式距离""" + + def apply(self, n_value, b_value, relative_err, err_msg): + if "This is type of 0-d tensor" in err_msg: + return CompareConst.UNSUPPORTED, err_msg + + distance = np.linalg.norm(n_value - b_value, ord=2) + + return distance, "" + + class GetMaxAbsErr(TensorComparisonBasic): """计算最大绝对误差""" - def apply(self, n_value, b_value, relative_err): + + def apply(self, n_value, b_value, relative_err, err_msg): temp_res = n_value - b_value max_value = np.max(np.abs(temp_res)) if np.isnan(max_value): @@ -237,7 +248,8 @@ class GetMaxAbsErr(TensorComparisonBasic): class GetMaxRelativeErr(TensorComparisonBasic): """计算最大相对误差""" - def apply(self, n_value, b_value, relative_err): + + def apply(self, n_value, b_value, relative_err, err_msg): max_relative_err = np.max(np.abs(relative_err)) if np.isnan(max_relative_err): msg = "Cannot compare by MaxRelativeError, the data contains nan/inf/-inf in dump data." @@ -247,12 +259,13 @@ class GetMaxRelativeErr(TensorComparisonBasic): class GetErrRatio(TensorComparisonBasic): """计算相对误差小于指定阈值(千分之一、千分之五)的比例""" + def __init__(self, threshold): self.threshold = threshold - def apply(self, n_value, b_value, relative_err): - if not n_value.shape: - return CompareConst.UNSUPPORTED, "" + def apply(self, n_value, b_value, relative_err, err_msg): + if "This is type of 0-d tensor" in err_msg: + return CompareConst.UNSUPPORTED, err_msg if not np.size(relative_err): return CompareConst.NAN, "" @@ -264,6 +277,7 @@ class GetErrRatio(TensorComparisonBasic): class CompareOps: compare_ops = { "cosine_similarity": GetCosineSimilarity(), + "euclidean_distance": GetEuclideanDistance(), "max_abs_error": GetMaxAbsErr(), "max_relative_error": GetMaxRelativeErr(), "one_thousand_err_ratio": GetErrRatio(CompareConst.THOUSAND_RATIO_THRESHOLD), @@ -272,10 +286,9 @@ class CompareOps: def error_value_process(n_value): - if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE: + if n_value in [CompareConst.READ_NONE, CompareConst.UNREADABLE, CompareConst.NONE, + CompareConst.NO_REAL_DATA, CompareConst.API_UNMATCH]: return CompareConst.UNSUPPORTED, "" - if n_value == CompareConst.NONE: - return 0, "" if n_value == CompareConst.SHAPE_UNMATCH: return CompareConst.SHAPE_UNMATCH, "" if n_value == CompareConst.NAN: @@ -288,14 +301,14 @@ def compare_ops_apply(n_value, b_value, error_flag, err_msg): if error_flag: result, msg = error_value_process(n_value) result_list = [result] * len(CompareOps.compare_ops) - err_msg += msg * len(CompareOps.compare_ops) + err_msg += msg return result_list, err_msg relative_err = get_relative_err(n_value, b_value) n_value, b_value = reshape_value(n_value, b_value) for op in CompareOps.compare_ops.values(): - result, msg = op.apply(n_value, b_value, relative_err) + result, msg = op.apply(n_value, b_value, relative_err, err_msg) result_list.append(result) err_msg += msg return result_list, err_msg diff --git a/debug/accuracy_tools/msprobe/core/compare/utils.py b/debug/accuracy_tools/msprobe/core/compare/utils.py index a2edf57e5bb91400675fe01734ea7fbf0e1df893..5503bdd3feec45e7e96bd4dfbafb4ac110cf7e68 100644 --- a/debug/accuracy_tools/msprobe/core/compare/utils.py +++ b/debug/accuracy_tools/msprobe/core/compare/utils.py @@ -18,35 +18,48 @@ import re import math import zlib from dataclasses import dataclass +import multiprocessing import numpy as np +import pandas as pd from msprobe.core.common.const import Const, CompareConst, FileCheckConst from msprobe.core.common.utils import CompareException, check_regex_prefix_format_valid, logger, safe_get_value -from msprobe.core.common.file_utils import check_file_or_directory_path +from msprobe.core.common.file_utils import check_file_or_directory_path, load_json +json_file_mapping = { + Const.DUMP_JSON_FILE: "dump.json", + Const.DEBUG_JSON_FILE: "debug.json", + Const.STACK_JSON_FILE: "stack.json" +} -def extract_json(dirname, stack_json=False): + +def extract_json(dirname, json_file_type): json_path = '' for filename in os.listdir(dirname): - target_file_name = 'stack.json' if stack_json else 'dump.json' + target_file_name = json_file_mapping.get(json_file_type) + if target_file_name is None: + logger.error(f'extract_json failed, invalid json_file_type: {json_file_type}.') + raise CompareException(CompareException.INVALID_KEY_ERROR) if filename == target_file_name: json_path = os.path.join(dirname, filename) break # Provide robustness on invalid directory inputs if not json_path: - if stack_json: + if json_file_type == Const.STACK_JSON_FILE: logger.warning(f'stack.json is not found in dump dir {dirname}.') - else: + elif json_file_type == Const.DUMP_JSON_FILE: logger.error(f'dump.json is not found in dump dir {dirname}.') - raise CompareException(CompareException.NO_DUMP_FILE_ERROR) + elif json_file_type == Const.DEBUG_JSON_FILE: + logger.warning(f'debug.json is not found in dump dir {dirname}.') + return json_path def set_stack_json_path(input_param): npu_data_dir = os.path.dirname(input_param.get("npu_json_path")) - stack_path = extract_json(npu_data_dir, stack_json=True) + stack_path = extract_json(npu_data_dir, json_file_type=Const.STACK_JSON_FILE) input_param["stack_json_path"] = stack_path if stack_path else None return bool(stack_path) @@ -81,46 +94,40 @@ def check_and_return_dir_contents(dump_dir, prefix): return contents -def rename_api(npu_name, process): - """ - 原api: {api_type}.{api_name}.{API调用次数}.{前向反向}.{input/output}.{参数序号} - rename后: {api_type}.{api_name}.{input/output}.{参数序号} - """ - npu_split = npu_name.split(process) - try: - torch_func_index, in_out = npu_split[0], npu_split[1] - except IndexError as error: - logger.error(f'{npu_name} can not be split with {process}, please check!') - raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error - torch_func_split = torch_func_index.rsplit(Const.SEP, 2) - torch_func = str(torch_func_split[0]) + str(in_out) - return torch_func - - def read_op(op_data, op_name): - if Const.PARAMS_GRAD in op_name.split(Const.SEP): - op_parsed_list = op_item_parse(op_data, op_name) + if not isinstance(op_name, str): + logger.error(f"api name error: {op_name} is not a string, please check.") + raise CompareException(CompareException.INVALID_API_NAME_ERROR) + split_name = op_name.split(Const.SEP) + if split_name[-1] == Const.DEBUG: + op_parsed_list = op_item_parse(op_data, op_name, Const.DEBUG) + elif split_name[-1] == Const.PARAMS_GRAD: + op_parsed_list = op_item_parse(op_data, op_name, Const.PARAMS_GRAD) else: op_parsed_list = [] for name in CompareConst.IO_NAME_MAPPING: if name in op_data: - op_parsed_list.extend(op_item_parse(op_data[name], op_name + CompareConst.IO_NAME_MAPPING[name])) + op_parsed_list.extend(op_item_parse(op_data[name], op_name + CompareConst.IO_NAME_MAPPING[name], name)) return op_parsed_list -def op_item_parse(op_data, op_name: str, depth: int = 0) -> list: +def op_item_parse(op_data, op_name: str, state: str, depth: int = 0) -> list: + if state == Const.INPUT_ARGS or state == Const.INPUT_KWARGS: + state = Const.INPUT default_item = { 'full_op_name': op_name, - 'type': None, - 'Max': None, - 'Min': None, - 'Mean': None, - 'Norm': None, - 'dtype': None, - 'shape': None, - 'md5': None, - 'value': None, - 'data_name': '-1' + Const.TYPE: None, + Const.MAX: None, + Const.MIN: None, + Const.MEAN: None, + Const.NORM: None, + Const.DTYPE: None, + Const.SHAPE: None, + Const.MD5: None, + Const.VALUE: None, + Const.DATA_NAME: '-1', + Const.STATE: state, + Const.REQ_GRAD: None } if depth > Const.MAX_DEPTH: @@ -136,33 +143,53 @@ def op_item_parse(op_data, op_name: str, depth: int = 0) -> list: if isinstance(op_data, list): for i, data in enumerate(op_data): if Const.PARAMS_GRAD not in op_name.split(Const.SEP): - item_list.extend(op_item_parse(data, op_name + Const.SEP + str(i), depth + 1)) + item_list.extend(op_item_parse(data, op_name + Const.SEP + str(i), state, depth + 1)) else: - item_list.extend(op_item_parse(data, op_name, depth + 1)) + item_list.extend(op_item_parse(data, op_name, state, depth + 1)) elif isinstance(op_data, dict): + if is_p2pop_leaf_data(op_data): + p2pop_item = {} + for key in ['class_type', 'op', 'peer', 'tag', 'group_id']: + p2pop_item[key] = op_data.get(key) + op_data = op_data.get('tensor') + if isinstance(op_data, dict): + op_item = gen_op_item(op_data, op_name, state) + else: + op_item = default_item + op_item.update(p2pop_item) + return [op_item] if is_leaf_data(op_data): - return [gen_op_item(op_data, op_name)] + return [gen_op_item(op_data, op_name, state)] for sub_name, sub_data in op_data.items(): - item_list.extend(op_item_parse(sub_data, op_name + Const.SEP + str(sub_name), depth + 1)) + item_list.extend(op_item_parse(sub_data, op_name + Const.SEP + str(sub_name), state, depth + 1)) return item_list +def is_p2pop_leaf_data(op_data): + return op_data.get('class_type') == 'torch.distributed.P2POp' + + def is_leaf_data(op_data): return 'type' in op_data and isinstance(op_data['type'], str) -def gen_op_item(op_data, op_name): +def gen_op_item(op_data, op_name, state): op_item = {} - op_item.update(op_data) - data_name = op_data.get('data_name') if op_data.get('data_name') else '-1' # 如果是""也返回-1 - op_item['data_name'] = data_name + op_item.update({key: str(value) if isinstance(value, bool) else value for key, value in op_data.items()}) + data_name = op_data.get(Const.DATA_NAME) if op_data.get(Const.DATA_NAME) else '-1' # 如果是""也返回-1 + op_item[Const.DATA_NAME] = data_name op_item['full_op_name'] = data_name.rsplit(Const.SEP, 1)[0] if data_name != '-1' else op_name + op_item[Const.STATE] = state + if Const.REQ_GRAD not in op_item: + op_item[Const.REQ_GRAD] = None - params = ['Max', 'Min', 'Mean', 'Norm'] + # 补齐统计量字段 + params = [Const.MAX, Const.MIN, Const.MEAN, Const.NORM] for i in params: if i not in op_item: op_item[i] = None + # special cases if not op_item.get('dtype'): if op_item.get('type') == 'torch.Size': op_item['dtype'] = op_data.get('type') @@ -175,11 +202,18 @@ def gen_op_item(op_data, op_name): op_item['shape'] = '[]' for i in params: op_item[i] = op_data.get('value') + elif op_name.split(Const.SEP)[-1] in ['src', 'dst', 'group_src', 'group_dst']: + op_item['dtype'] = op_data.get('type') + op_item['shape'] = '[]' + for i in params: + op_item[i] = str(op_data.get('value')) + op_item['md5'] = str(op_data.get('value')) elif op_item.get('type') == 'torch.ProcessGroup': op_item['dtype'] = op_data.get('type') op_item['shape'] = '[]' for i in params: op_item[i] = str(op_data.get('group_ranks')) + op_item['md5'] = str(op_data.get('group_ranks')) else: op_item['dtype'] = str(type(op_data.get('value'))) op_item['shape'] = '[]' @@ -191,35 +225,189 @@ def gen_op_item(op_data, op_name): return op_item -def resolve_api_special_parameters(data_dict, full_op_name, item_list): +@dataclass +class ApiItemInfo: + name: str + struct: tuple + stack_info: list + + +def merge_tensor(tensor_list, dump_mode): + keys = [ + CompareConst.OP_NAME, + CompareConst.INPUT_STRUCT, + CompareConst.KWARGS_STRUCT, + CompareConst.OUTPUT_STRUCT, + CompareConst.PARAMS_STRUCT, + CompareConst.PARAMS_GRAD_STRUCT, + CompareConst.DEBUG_STRUCT, + Const.SUMMARY, + Const.STACK_INFO, + Const.STATE, + Const.REQ_GRAD + ] + op_dict = {key: [] for key in keys} + + if dump_mode == Const.ALL: + op_dict[Const.DATA_NAME] = [] + + for tensor in tensor_list: + # A dict(len=2) with 'full_op_name' and 'full_info' is added to the tensor only if self.stack_mode is True + if len(tensor) == 2: + op_dict[Const.STACK_INFO].append(tensor.get('full_info')) + break + + op_dict[CompareConst.OP_NAME].append(tensor.get('full_op_name')) + state = tensor.get(Const.STATE) + op_dict[Const.STATE].append(state) + op_dict[Const.REQ_GRAD].append(tensor.get(Const.REQ_GRAD)) + + struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state) + if not struct_key: + continue + if dump_mode == Const.MD5: + op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE], tensor[Const.MD5])) + else: + op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE])) + + # 当统计量为None时,转成字符串None,避免后续操作list放到pd中时None被默认转成NaN + op_dict[Const.SUMMARY].append( + [str(tensor[key]) if tensor[key] is None else tensor[key] for key in Const.SUMMARY_METRICS_LIST]) + + if dump_mode == Const.ALL: + op_dict[Const.DATA_NAME].append(tensor.get(Const.DATA_NAME)) + + if not op_dict[CompareConst.KWARGS_STRUCT]: + del op_dict[CompareConst.KWARGS_STRUCT] + return op_dict if op_dict[CompareConst.OP_NAME] else {} + + +def print_compare_ends_info(): + total_len = len(CompareConst.COMPARE_ENDS_SUCCESSFULLY) + Const.FILL_CHAR_NUMS + logger.info('*' * total_len) + logger.info(f"*{CompareConst.COMPARE_ENDS_SUCCESSFULLY.center(total_len - 2)}*") + logger.info('*' * total_len) + + +def table_value_is_valid(value: str) -> bool: + if not isinstance(value, str): + return True + try: + # -1.00 or +1.00 should be considered as digit numbers + float(value) + except ValueError: + # otherwise, they will be considered as formular injections + return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value)) + return True + + +class ApiBatch: + def __init__(self, api_name: str, start: int): + self.api_name = api_name + self.start = start + self.input_len = 1 # input的数量 + self.params_end_index = start + 1 # params的结束index + self.output_end_index = start + 1 # output的结束index + self.params_grad_end_index = start + 1 # params_grad的结束index + # 内部state的标志("input", "output", "parameters", "parameters_grad"), + # 用于控制计算input_len, output_end_index, params_end_index, self.params_grad_end_index + self._state = Const.INPUT # api_batch初始化为input + + def set_state(self, state: str): + """设置当前状态""" + if state in {Const.INPUT, Const.OUTPUT, Const.KWARGS, Const.PARAMS, Const.PARAMS_GRAD}: + self._state = state + else: + raise ValueError(f"Invalid state: {state}") + + def increment(self, state: str): + self.set_state(state) + if self._state == Const.INPUT or self._state == Const.KWARGS: + self.input_len += 1 + self.params_end_index += 1 + self.output_end_index += 1 + if self._state == Const.PARAMS: + self.params_end_index += 1 + self.output_end_index += 1 + if self._state == Const.OUTPUT: + self.output_end_index += 1 + self.params_grad_end_index += 1 + + +def api_batches_update(api_batches, api_name, state, index): """ - Function Description: - 解析下面格式的数据, 是api参数的一种特殊格式 - { - "last_hidden_state": { - "type": "torch.Tensor", - "dtype": "torch.bfloat16", - ... - }, - "loss": { - "type": "torch.Tensor", - "dtype": "torch.float32", - ... - } - } - Parameter: - data_dict: 字典格式的数据 - full_op_name: 参数的全名字符串 - item_list: 参数信息集合 + 当一个api的所有item更新完后,input, output的索引范围: + input: [start: start+input_len] + output: [start+input_len: output_end_index] + params: [output_end_index: params_end_index] """ - for key, value in data_dict.items(): - if isinstance(value, dict): - parsed_item = value - parts = full_op_name.split(Const.SEP) - parts.insert(-1, key) - full_op_name_new = ".".join(parts) - parsed_item['full_op_name'] = full_op_name_new - item_list.append(parsed_item) + if not api_batches: + api_batches.append(ApiBatch(api_name, index)) + else: + api_batch = api_batches[-1] + if api_batch.api_name == api_name or ( + not re.search(Const.REGEX_FORWARD_BACKWARD, api_name) and api_name in api_batch.api_name): + try: + api_batch.increment(state) + except ValueError as e: + logger.error(f"api_batch: {api_batch} with invalid state, please check! {e}") + raise CompareException(CompareException.INVALID_STATE_ERROR) from e + else: + api_batches.append(ApiBatch(api_name, index)) + + +def reorder_index(op_parsed_list): + """ + 对单个api解析的op_items的index进行重排,将parameter的index放到output前面,返回新的重排后的index列表,op_parsed_list不变 + """ + index_param = [] + index_output = [] + index_param_grad = [] + index_other = [] + for i, op_item in enumerate(op_parsed_list[:-1]): + state = op_item.get(Const.STATE) + if state == Const.PARAMS: + index_param.append(i) + elif state == Const.OUTPUT: + index_output.append(i) + elif state == Const.PARAMS_GRAD: + index_param_grad.append(i) + else: + index_other.append(i) + # 合并others, parameters, 和output,确保parameters排在output前面 + reordered_index_list = index_other + index_param + index_output + index_param_grad + return reordered_index_list + + +def reorder_op_name_list(op_name_list, state_list): + if not op_name_list: + return op_name_list, state_list + + parameters = [] + output = [] + parameters_grad = [] + others = [] + parameters_s = [] + output_s = [] + parameters_grad_s = [] + others_s = [] + for op_name, state in zip(op_name_list, state_list): + if state == Const.PARAMS: + parameters.append(op_name) + parameters_s.append(state) + elif state == Const.OUTPUT: + output.append(op_name) + output_s.append(state) + elif state == Const.PARAMS_GRAD: + parameters_grad.append(op_name) + parameters_grad_s.append(state) + else: + others.append(op_name) + others_s.append(state) + # 合并others, parameters, 和output,确保parameters排在output前面 + op_name_reorder = others + parameters + output + parameters_grad + state_reorder = others_s + parameters_s + output_s + parameters_grad_s + return op_name_reorder, state_reorder def process_summary_data(summary_data): @@ -273,21 +461,26 @@ def stack_column_process(result_item, has_stack, index, key, npu_stack_info): return result_item -def result_item_init(n_info, b_info, dump_mode): +def result_item_init(n_info, b_info, requires_grad_pair, dump_mode): n_len = len(n_info.struct) b_len = len(b_info.struct) + # requires_grad_pair内部创建,固定两个元素 + n_requires_grad = requires_grad_pair[0] + b_requires_grad = requires_grad_pair[1] + req_grad_consist = n_requires_grad == b_requires_grad struct_long_enough = (n_len > 2 and b_len > 2) if dump_mode == Const.MD5 else (n_len > 1 and b_len > 1) if struct_long_enough: result_item = [ - n_info.name, b_info.name, n_info.struct[0], b_info.struct[0], n_info.struct[1], b_info.struct[1] + n_info.name, b_info.name, n_info.struct[0], b_info.struct[0], n_info.struct[1], b_info.struct[1], + n_requires_grad, b_requires_grad ] if dump_mode == Const.MD5: md5_compare_result = CompareConst.PASS if n_info.struct[2] == b_info.struct[2] else CompareConst.DIFF - result_item.extend([n_info.struct[2], b_info.struct[2], md5_compare_result]) + result_item.extend([n_info.struct[2], b_info.struct[2], req_grad_consist, md5_compare_result]) elif dump_mode == Const.SUMMARY: - result_item.extend([" "] * 8) + result_item.extend([" "] * 8) # 8个统计量数据情况的比对指标 else: - result_item.extend([" "] * 5) + result_item.extend([" "] * 6) # 6个真实数据情况的比对指标 else: err_msg = "index out of bounds error will occur in result_item_init, please check!\n" \ f"npu_info_struct is {n_info.struct}\n" \ @@ -321,19 +514,23 @@ def get_accuracy(result, n_dict, b_dict, dump_mode): has_stack = npu_stack_info and bench_stack_info if dump_mode == Const.ALL: - npu_data_name = n_dict.get("data_name", None) - bench_data_name = b_dict.get("data_name", None) + npu_data_name_list = n_dict.get("data_name", None) + bench_data_name_list = b_dict.get("data_name", None) for index in range(min_len): n_name = safe_get_value(n_dict, n_start + index, "n_dict", key="op_name") b_name = safe_get_value(b_dict, b_start + index, "b_dict", key="op_name") n_struct = safe_get_value(n_dict, index, "n_dict", key=key) b_struct = safe_get_value(b_dict, index, "b_dict", key=key) + n_requires_grad = safe_get_value(n_dict, n_start + index, "n_dict", key='requires_grad') + b_requires_grad = safe_get_value(b_dict, b_start + index, "b_dict", key='requires_grad') + requires_grad_pair = [n_requires_grad, b_requires_grad] + req_grad_consist = n_requires_grad == b_requires_grad err_msg = "" npu_info = ApiItemInfo(n_name, n_struct, npu_stack_info) bench_info = ApiItemInfo(b_name, b_struct, bench_stack_info) - result_item = result_item_init(npu_info, bench_info, dump_mode) + result_item = result_item_init(npu_info, bench_info, requires_grad_pair, dump_mode) if dump_mode == Const.MD5: result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info) @@ -349,34 +546,45 @@ def get_accuracy(result, n_dict, b_dict, dump_mode): result_item, accuracy_check, err_msg = get_rela_diff_summary_mode(result_item, npu_summary_data, bench_summary_data, err_msg) + result_item.append(req_grad_consist) + err_msg += "Requires_grad inconsistent." if not req_grad_consist else "" result_item.append(accuracy_check if dump_mode == Const.SUMMARY else CompareConst.ACCURACY_CHECK_YES) result_item.append(err_msg) result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info) if dump_mode == Const.ALL: - result_item.append(safe_get_value(npu_data_name, n_start + index, "npu_data_name")) + npu_data_name = safe_get_value(npu_data_name_list, n_start + index, "npu_data_name_list") + bench_data_name = safe_get_value(bench_data_name_list, b_start + index, "bench_data_name_list") + result_item.append([npu_data_name, bench_data_name]) result.append(result_item) if n_len > b_len: for index in range(b_len, n_len): try: - n_name = n_dict['op_name'][n_start + index] - n_struct = n_dict[key][index] + n_name = safe_get_value(n_dict, n_start + index, "n_dict", key="op_name") + n_struct = safe_get_value(n_dict, index, "n_dict", key=key) + n_requires_grad = safe_get_value(n_dict, n_start + index, "n_dict", key='requires_grad') + if dump_mode == Const.MD5: result_item = [ n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN, - n_struct[2], CompareConst.NAN, CompareConst.NAN + n_requires_grad, CompareConst.NAN, + n_struct[2], CompareConst.NAN, + False, + CompareConst.NAN ] result.append(result_item) continue result_item = [ n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN, - " ", " ", " ", " ", " " + n_requires_grad, CompareConst.NAN, + " ", " ", " ", " ", " ", " " ] summary_data = n_dict.get(CompareConst.SUMMARY)[n_start + index] result_item.extend(summary_data) summary_data = [CompareConst.NAN for _ in range(len(n_dict.get(CompareConst.SUMMARY)[0]))] result_item.extend(summary_data) + result_item.append(False) except IndexError as e: err_msg = "index out of bounds error occurs, please check!\n" \ f"n_dict is {n_dict}" @@ -388,12 +596,13 @@ def get_accuracy(result, n_dict, b_dict, dump_mode): result_item.append(err_msg) result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info) if dump_mode == Const.ALL: - result_item.append(safe_get_value(npu_data_name, n_start + index, "npu_data_name")) + npu_data_name = safe_get_value(npu_data_name_list, n_start + index, "npu_data_name_list") + result_item.append([npu_data_name, "-1"]) result.append(result_item) - n_num, n_num_input, n_num_output, n_num_params, n_num_params_grad = count_struct(n_dict) - b_num, b_num_input, b_num_output, b_num_params, b_num_params_grad = count_struct(b_dict) + _, n_num_input, n_num_output, n_num_params, n_num_params_grad = count_struct(n_dict) + _, b_num_input, b_num_output, b_num_params, b_num_params_grad = count_struct(b_dict) get_accuracy_core(0, n_num_input, 0, b_num_input, CompareConst.INPUT_STRUCT) get_accuracy_core(n_num_input + n_num_output, n_num_params, b_num_input + b_num_output, b_num_params, @@ -404,197 +613,40 @@ def get_accuracy(result, n_dict, b_dict, dump_mode): CompareConst.PARAMS_GRAD_STRUCT) -def append_stack_info(result_item, npu_stack_info, index): - """添加堆栈信息到 result_item""" - if npu_stack_info and index == 0: - result_item.extend(npu_stack_info) - else: - result_item.append(CompareConst.NONE) - +def make_result_table(result, dump_mode, stack_mode): + header = CompareConst.HEAD_OF_COMPARE_MODE[dump_mode][:] -def get_un_match_accuracy(result, n_dict, dump_mode): - npu_stack_info = n_dict.get("stack_info", None) - bench_name, bench_type, bench_shape = CompareConst.N_A, CompareConst.N_A, CompareConst.N_A - - struct_to_index_mapping = { - CompareConst.INPUT_STRUCT: 0, - CompareConst.OUTPUT_STRUCT: 0, - CompareConst.PARAMS_STRUCT: 0, - CompareConst.PARAMS_GRAD_STRUCT: 0 - } - - op_name_list = n_dict.get(CompareConst.OP_NAME) - summary_list = n_dict.get(Const.SUMMARY) - data_name_list = n_dict.get('data_name') - op_name_reorder, summary_reorder, _ = reorder_op_x_list(op_name_list, - summary_list, - data_name_list) - for index, n_name in enumerate(op_name_reorder): - _, state = get_name_and_state(n_name) - struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state) - if not struct_key: - continue - n_struct = safe_get_value(n_dict, struct_to_index_mapping.get(struct_key), "n_dict", key=struct_key) - struct_to_index_mapping[struct_key] += 1 - - try: - result_item = [n_name, bench_name, n_struct[0], bench_type, n_struct[1], bench_shape] - except IndexError as e: - err_msg = "index out of bounds error occurs, please check!\n" \ - f"op_name of n_dict is {n_dict['op_name']}\n" \ - f"input_struct of n_dict is {n_dict[CompareConst.INPUT_STRUCT]}\n" \ - f"output_struct of n_dict is {n_dict[CompareConst.OUTPUT_STRUCT]}" - logger.error(err_msg) - raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e - - if dump_mode == Const.MD5: - result_item.extend([CompareConst.N_A] * 3) - append_stack_info(result_item, npu_stack_info, index) - result.append(result_item) - continue - if dump_mode == Const.SUMMARY: - result_item.extend([CompareConst.N_A] * 8) + if stack_mode: + header.append(CompareConst.STACK) if dump_mode == Const.ALL: - result_item.extend([CompareConst.N_A] * 5) - - npu_summary_data = safe_get_value(summary_reorder, index, "summary_reorder") - bench_summary_data = [CompareConst.N_A] * 4 - result_item.extend(npu_summary_data) - result_item.extend(bench_summary_data) - err_msg = CompareConst.NO_BENCH - accuracy_check_res = CompareConst.N_A - result_item.append(accuracy_check_res) - result_item.append(err_msg) - append_stack_info(result_item, npu_stack_info, index) - if dump_mode == Const.ALL and result_item[1] == CompareConst.N_A: - result_item.extend(["-1"]) - result.append(result_item) - - -def merge_tensor(tensor_list, dump_mode): - op_dict = {} - op_dict["op_name"] = [] - op_dict[CompareConst.INPUT_STRUCT] = [] - op_dict[CompareConst.KWARGS_STRUCT] = [] - op_dict[CompareConst.OUTPUT_STRUCT] = [] - op_dict[CompareConst.PARAMS_STRUCT] = [] - op_dict[CompareConst.PARAMS_GRAD_STRUCT] = [] - op_dict[Const.SUMMARY] = [] - op_dict["stack_info"] = [] - - if dump_mode == Const.ALL: - op_dict["data_name"] = [] - - for tensor in tensor_list: - # A dict(len=2) with 'full_op_name' and 'full_info' is added to the tensor only if self.stack_mode is True - if len(tensor) == 2: - op_dict['stack_info'].append(tensor['full_info']) - break - - op_dict["op_name"].append(tensor['full_op_name']) - - _, state = get_name_and_state(tensor['full_op_name']) - struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state) - if not struct_key: - continue - if dump_mode == Const.MD5: - op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE], tensor[Const.MD5])) - else: - op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE])) - op_dict[Const.SUMMARY].append([tensor[Const.MAX], tensor[Const.MIN], tensor[Const.MEAN], tensor[Const.NORM]]) - + header.append(CompareConst.DATA_NAME) + else: if dump_mode == Const.ALL: - op_dict["data_name"].append(tensor['data_name']) - - if not op_dict[CompareConst.KWARGS_STRUCT]: - del op_dict[CompareConst.KWARGS_STRUCT] - return op_dict if op_dict["op_name"] else {} - - -def print_compare_ends_info(): - total_len = len(CompareConst.COMPARE_ENDS_SUCCESSFULLY) + Const.FILL_CHAR_NUMS - logger.info('*' * total_len) - logger.info(f"*{CompareConst.COMPARE_ENDS_SUCCESSFULLY.center(total_len - 2)}*") - logger.info('*' * total_len) - - -def table_value_is_valid(value: str) -> bool: - if not isinstance(value, str): - return True - try: - # -1.00 or +1.00 should be consdiered as digit numbers - float(value) - except ValueError: - # otherwise, they will be considered as formular injections - return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value)) - return True - - -def get_name_and_state(name): - """ - Get api/module name and state - example: - name = 'conv2d.forward.1.input.0' - return: ('conv2d.forward.1.', 'input') - - name = 'Functional.pad.0.backward.output.0' - return: ('Functional.pad.0.backward.', 'output') - - state type: input, output, kwargs, parameters, parameters_grad - """ - if Const.PARAMS_GRAD in name.split(Const.SEP): - return name.split(Const.PARAMS_GRAD)[0], Const.PARAMS_GRAD - - split = re.split(Const.REGEX_FORWARD_BACKWARD, name) - api = f'{split[0]}.{split[1]}.' - state_str = split[2] - match = re.match(r'^(\d+\.)?(input|output|kwargs|parameters)\..+$', state_str) - if not match: - raise CompareException(f'Invalid name string: {name}') - if match.group(1): - api = f'{api}{match.group(1)}' - state = match.group(2) - return api, state - - -def reorder_op_name_list(op_name_list): - if not op_name_list: - return op_name_list - - parameters = [] - output = [] - parameters_grad = [] - others = [] - for x in op_name_list: - state = get_name_and_state(x)[1] - if state == Const.PARAMS: - parameters.append(x) - elif state == Const.OUTPUT: - output.append(x) - elif state == Const.PARAMS_GRAD: - parameters_grad.append(x) + for row in result: + del row[-2] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,真实数据时为倒数第2列 + header.append(CompareConst.DATA_NAME) else: - others.append(x) - # 合并others, parameters, 和output,确保parameters排在output前面 - op_name_reorder = others + parameters + output + parameters_grad - return op_name_reorder + for row in result: + del row[-1] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,非真实数据时为倒数第1列 + result_df = pd.DataFrame(result, columns=header, dtype='object') + return result_df -def reorder_op_x_list(op_name_list, summary_list, data_name_list): - """对op_name, summary, data_name重新排序,把parameters放到input后output前,data_name由于统计量比对时,为None,单独处理""" - if not op_name_list or not summary_list: - return op_name_list, summary_list, data_name_list +def gen_api_batches(result: np.ndarray, header: list): + api_name_index = header.index(Const.API_ORIGIN_NAME) + state_name_index = header.index(Const.STATE) + api_batches = [] + for i, res_i in enumerate(result): + api_name = safe_get_value(res_i, api_name_index, "res_i") + state = safe_get_value(res_i, state_name_index, "res_i") + api_batches_update(api_batches, api_name, state, i) + return api_batches - index_map = {name: index for index, name in enumerate(op_name_list)} - op_name_reorder = reorder_op_name_list(op_name_list) - summary_reorder = [summary_list[index_map.get(name)] for name in op_name_reorder] - if data_name_list: - data_name_reorder = [data_name_list[index_map.get(name)] for name in op_name_reorder] - else: - data_name_reorder = data_name_list - - return op_name_reorder, summary_reorder, data_name_reorder +def get_paired_dirs(npu_path, bench_path): + npu_dirs = set(os.listdir(npu_path)) + bench_dirs = set(os.listdir(bench_path)) + return list(npu_dirs & bench_dirs) def _compare_parser(parser): @@ -609,6 +661,8 @@ def _compare_parser(parser): help=" Whether to give advisor.", required=False) parser.add_argument("-f", "--fuzzy_match", dest="fuzzy_match", action="store_true", help=" Whether to perform a fuzzy match on the api name.", required=False) + parser.add_argument("-hl", "--highlight", dest="highlight", action="store_true", + help=" Whether to set result highlighting.", required=False) parser.add_argument("-cm", "--cell_mapping", dest="cell_mapping", type=str, nargs='?', const=True, help=" The cell mapping file path.", required=False) parser.add_argument("-am", "--api_mapping", dest="api_mapping", type=str, nargs='?', const=True, @@ -617,3 +671,134 @@ def _compare_parser(parser): help=" The data mapping file path.", required=False) parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str, nargs='?', const=True, help=" The layer mapping file path.", required=False) + parser.add_argument("-da", "--diff_analyze", dest="diff_analyze", action="store_true", + help=" Whether to perform a diff analyze on the api name.", required=False) + + +def get_sorted_ranks(npu_dump_dir, bench_dump_dir): + """ + get the ranks and match by order + """ + unsorted_npu_ranks = check_and_return_dir_contents(npu_dump_dir, 'rank') + unsorted_bench_ranks = check_and_return_dir_contents(bench_dump_dir, 'rank') + # 正则匹配已经校验rank后面必是数字,或者无数字的rank + npu_ranks = sorted(unsorted_npu_ranks, key=lambda x: int(x[4:]) if len(x) > 4 else -1) # 前四个字符都是rank,后面是卡号 + bench_ranks = sorted(unsorted_bench_ranks, key=lambda x: int(x[4:]) if len(x) > 4 else -1) + if len(npu_ranks) != len(bench_ranks): + logger.error('The number of ranks in the two runs are different. ' + 'Unable to match the ranks. Please use another folder to compare ' + 'or use compare() api and manually match the ranks.') + raise CompareException(CompareException.INVALID_PATH_ERROR) + return npu_ranks, bench_ranks + + +def multi_statistics_compare(func, func_args): + def err_call(args): + logger.error(f'Multiprocess statistics compare failed! Reason: {args}') + try: + pool.close() + except OSError: + logger.error("Pool terminate failed") + + compare_func, input_param_nr_list, output_path, kwargs = func_args + + param_num = len(input_param_nr_list) + process_num = max(int((multiprocessing.cpu_count() + 1) // 4), 1) + if param_num <= process_num: + process_num = param_num + chunks = [[input_param_nr] for input_param_nr in input_param_nr_list] + else: + chunk_size = param_num // process_num + remainder = param_num % process_num + chunks = [input_param_nr_list[i:i + chunk_size] for i in range(0, param_num - remainder, chunk_size)] + for i in range(remainder): + chunks[i].append(input_param_nr_list[param_num - remainder + i]) + + pool = multiprocessing.Pool(process_num) + for chunk in chunks: + pool.apply_async(func, args=(compare_func, chunk, output_path, kwargs), error_callback=err_call) + pool.close() + pool.join() + + +def mp_logger_init(ranks_str): + """ + 多进程比对需要对logger进行wrap和patch,在日志前加上卡号信息,从而实现不同进程日志的隔离 + """ + + def wrap_logger(fn): + def inner(msg, *args, **kwargs): + return fn(ranks_str + msg, *args, **kwargs) + return inner + + logger.info = wrap_logger(logger.info) + logger.warning = wrap_logger(logger.warning) + logger.error = wrap_logger(logger.error) + + +def multi_ranks_compare(compare_func, input_param_nr_list, output_path, kwargs): + """ + 将多卡数据分成多进程后,单进程内可能还有多张卡的数据,因此还需要多次比对 + """ + rank_list = [input_param_nr[1] for input_param_nr in input_param_nr_list] # input_param_nr内部数据结构,2元素tuple + ranks_str = f"[{' '.join(rank_list)}]" + mp_logger_init(ranks_str) + for input_param_nr in input_param_nr_list: + input_param, nr = input_param_nr + compare_entry(compare_func, input_param, output_path, nr, kwargs) + + +def compare_entry(compare_func, input_param, output_path, nr, kwargs): + try: + compare_func(input_param=input_param, output_path=output_path, suffix=f'_{nr}', **kwargs) + except CompareException as e: + if e.code == CompareException.INVALID_DATA_ERROR: + logger.error(f"Invalid or missing 'data' in dump.json. Skipping {nr} comparison.") + if e.code == CompareException.INVALID_TASK_ERROR: + logger.error(f"Invalid or missing 'task' in dump.json. Skipping {nr} comparison.") + + +def compare_distributed_inner(npu_dump_dir, bench_dump_dir, output_path, compare_func, **kwargs): + def extract_compare_param(_file_type): + npu_data_dir = os.path.join(npu_dump_dir, nr) + bench_data_dir = os.path.join(bench_dump_dir, br) + npu_path = extract_json(npu_data_dir, _file_type) + bench_path = extract_json(bench_data_dir, _file_type) + if npu_path == "" or bench_path == "": + logger.debug(f'Did not find paired {_file_type} in {nr} and {br}, skip comparing.') + return {}, True + _input_param = { + 'npu_json_path': npu_path, + 'bench_json_path': bench_path, + 'is_print_compare_log': kwargs.get('is_print_compare_log', True) + } + return _input_param, False + + if kwargs.get('suffix'): + logger.error("Argument 'suffix' is not supported for compare_distributed.") + raise CompareException(CompareException.INVALID_PARAM_ERROR) + + npu_ranks, bench_ranks = get_sorted_ranks(npu_dump_dir, bench_dump_dir) + + # 统计量、md5比对 + pre_check_dump_path = os.path.join(npu_dump_dir, npu_ranks[0], 'dump.json') if npu_ranks else '' + if not pre_check_dump_path: + return + dump_data = load_json(pre_check_dump_path) + if dump_data.get('task') == Const.STATISTICS: + # dump数据为统计量或md5时,多进程加速比对 + input_param_nr_list = [] + for nr, br in zip(npu_ranks, bench_ranks): + input_param, skip = extract_compare_param(Const.DUMP_JSON_FILE) + if not skip: + input_param_nr_list.append((input_param, nr)) + func_args = (compare_func, input_param_nr_list, output_path, kwargs) + multi_statistics_compare(multi_ranks_compare, func_args) + return + + # 真实数据比对 + for nr, br in zip(npu_ranks, bench_ranks): + for file_type in [Const.DUMP_JSON_FILE, Const.DEBUG_JSON_FILE]: + input_param, skip = extract_compare_param(file_type) + if not skip: + compare_entry(compare_func, input_param, output_path, nr, kwargs) diff --git a/debug/accuracy_tools/msprobe/core/config_check/__init__.py b/debug/accuracy_tools/msprobe/core/config_check/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..621122ffa00ba40a868853ccb46ff582c3e5fdda --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/config_check/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import msprobe.core.config_check.checkers +from msprobe.core.config_check.config_checker import ConfigChecker diff --git a/debug/accuracy_tools/msprobe/pytorch/dump/kernel_dump/kernel_config.py b/debug/accuracy_tools/msprobe/core/config_check/checkers/__init__.py similarity index 52% rename from debug/accuracy_tools/msprobe/pytorch/dump/kernel_dump/kernel_config.py rename to debug/accuracy_tools/msprobe/core/config_check/checkers/__init__.py index 48d0918ca68d7f429cc97fc64c5ba7d7f884960b..9b9024b862f1f60655d2f71a47ab401546a86076 100644 --- a/debug/accuracy_tools/msprobe/pytorch/dump/kernel_dump/kernel_config.py +++ b/debug/accuracy_tools/msprobe/core/config_check/checkers/__init__.py @@ -13,21 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os +__all__ = ['BaseChecker', 'apply_patches'] -from msprobe.core.common.file_utils import save_json +import msprobe.core.config_check.checkers.env_args_checker +import msprobe.core.config_check.checkers.pip_checker +import msprobe.core.config_check.checkers.dataset_checker +import msprobe.core.config_check.checkers.weights_checker +import msprobe.core.config_check.checkers.hyperparameter_checker +import msprobe.core.config_check.checkers.random_checker - -def create_kernel_config_json(dump_path, cur_rank): - kernel_config_name = "kernel_config.json" if cur_rank == '' else f"kernel_config_{cur_rank}.json" - kernel_config_path = os.path.join(dump_path, kernel_config_name) - config_info = { - "dump": { - "dump_list": [], - "dump_path": dump_path, - "dump_mode": "all", - "dump_op_switch": "on" - } - } - save_json(kernel_config_path, config_info, indent=4) - return kernel_config_path +from msprobe.core.config_check.checkers.base_checker import BaseChecker diff --git a/debug/accuracy_tools/msprobe/core/config_check/checkers/base_checker.py b/debug/accuracy_tools/msprobe/core/config_check/checkers/base_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..1e6e002edf4830633b791650c73943000c631bba --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/config_check/checkers/base_checker.py @@ -0,0 +1,62 @@ +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from msprobe.core.common.file_utils import check_path_pattern_valid +from msprobe.core.common.framework_adapter import FmkAdp +from msprobe.core.common.const import FileCheckConst + + +class PackInput: + + def __init__(self, output_zip_path, model, shell_path): + self.output_zip_path = output_zip_path + self.shell_path = shell_path + self.model = model[0] if isinstance(model, list) and len(model) > 0 else model + self.check_input_params() + + def check_input_params(self): + if self.model and not FmkAdp.is_nn_module(self.model): + raise Exception(f"model is not torch.nn.Module/mindspore.nn.Cell or module list.") + if not isinstance(self.output_zip_path, str) or not self.output_zip_path.endswith(FileCheckConst.ZIP_SUFFIX): + raise Exception(f"output zip path must be a string and ends with '.zip'") + check_path_pattern_valid(self.output_zip_path) + + +class BaseChecker: + input_needed = None + target_name_in_zip = None + multi_rank = False + + @staticmethod + def pack(pack_input): + pass + + @staticmethod + def compare(bench_dir, cmp_dir, output_path, fmk): + pass + + @staticmethod + def apply_patches(fmk): + pass + + @classmethod + def compare_ex(cls, bench_dir, cmp_dir, output_path, fmk): + bench_filepath = os.path.join(bench_dir, cls.target_name_in_zip) + cmp_filepath = os.path.join(cmp_dir, cls.target_name_in_zip) + if not os.path.exists(bench_filepath) or not os.path.exists(cmp_filepath): + return None, None, None + return cls.compare(bench_dir, cmp_dir, output_path, fmk) diff --git a/debug/accuracy_tools/msprobe/core/config_check/checkers/dataset_checker.py b/debug/accuracy_tools/msprobe/core/config_check/checkers/dataset_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..84af2b29aba7eff13f94eb6ab939058053b3d01e --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/config_check/checkers/dataset_checker.py @@ -0,0 +1,139 @@ +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +import pandas as pd +from msprobe.core.common.file_utils import create_file_in_zip, load_json +from msprobe.core.config_check.checkers.base_checker import BaseChecker +from msprobe.core.config_check.config_checker import register_checker_item, register_pre_forward_fun_list +from msprobe.core.config_check.utils.utils import config_checking_print, get_tensor_features +from msprobe.core.common.decorator import recursion_depth_decorator +from msprobe.core.common.framework_adapter import FmkAdp +from msprobe.core.common.const import Const + + +@recursion_depth_decorator("config_check: process_obj") +def process_obj(obj): + if FmkAdp.is_tensor(obj): + return get_tensor_features(obj) + elif isinstance(obj, (tuple, list)): + return {i: process_obj(x) for i, x in enumerate(obj)} + elif isinstance(obj, dict): + return {k: process_obj(v) for k, v in obj.items()} + else: + return "" + + +def parse_args_and_kargs(args, kwargs): + processed_args = process_obj(args) + processed_kargs = process_obj(kwargs) + + return { + 'args': processed_args, + 'kwargs': processed_kargs + } + + +@recursion_depth_decorator("config_check: compare_dataset_dicts") +def compare_dataset_dicts(dict1, dict2, tag=''): + results = [] + # 处理 dict1 中的键 + for key in dict1: + new_tag = f"{tag}.{key}" if tag else key + if key not in dict2: + result = {'tag': new_tag, 'equal': False, 'status': 'delete'} + results.append(result) + continue + value1 = dict1[key] + value2 = dict2[key] + if not isinstance(value1, dict): + continue + if set(value1.keys()) == {'max', 'min', 'mean', 'norm'}: + equal = value1 == value2 + relative_diffs = { + f"{k}_relative_diff": (abs(value1[k] - value2[k]) / value1[k]) if value1[k] != 0 else None + for k in ['max', 'min', 'mean', 'norm'] + } + result = {'tag': new_tag, 'equal': equal, 'status': 'unchanged'} + result.update(relative_diffs) + results.append(result) + else: + results.extend(compare_dataset_dicts(value1, value2, new_tag)) + # 处理 dict2 中独有的键 + for key in dict2: + if key not in dict1: + new_tag = f"{tag}.{key}" if tag else key + result = {'tag': new_tag, 'equal': False, 'status': 'added'} + results.append(result) + return results + + +def compare_dataset(bench_dir, cmp_dir): + all_results = [] + for step in os.listdir(bench_dir): + step_path_bench = os.path.join(bench_dir, step) + if not os.path.isdir(step_path_bench): + continue + step_path_cmp = os.path.join(cmp_dir, step) + for rank in os.listdir(step_path_bench): + rank_path_bench = os.path.join(step_path_bench, rank, 'dataset.json') + rank_path_cmp = os.path.join(step_path_cmp, rank, 'dataset.json') + if not os.path.isfile(rank_path_bench) or not os.path.isfile(rank_path_cmp): + continue + + dict1 = load_json(rank_path_bench) + dict2 = load_json(rank_path_cmp) + results = compare_dataset_dicts(dict1, dict2) + for result in results: + result['step'] = int(step.replace("step", "")) + result['rank'] = int(rank.replace("rank", "")) + all_results.extend(results) + + df = pd.DataFrame(all_results, columns=DatasetChecker.result_header) + df = df.sort_values(by=['step', 'rank'], ascending=[True, True]) + return df + + +@register_checker_item("dataset") +class DatasetChecker(BaseChecker): + input_needed = "model" + multi_rank = True + + target_name_in_zip = "dataset" + result_header = ['step', 'rank', 'tag', 'equal', 'max_relative_diff', + 'min_relative_diff', 'mean_relative_diff', 'norm_relative_diff'] + + @staticmethod + def pack(pack_input): + output_zip_path = pack_input.output_zip_path + + def collect_input(model, args, kwargs, step): + features = parse_args_and_kargs(args, kwargs) + dataset_filepath = os.path.join(DatasetChecker.target_name_in_zip, + f"step{step}", f"rank{FmkAdp.get_rank_id()}", "dataset.json") + create_file_in_zip(output_zip_path, dataset_filepath, json.dumps(features, indent=4)) + config_checking_print(f"add first dataset input features to zip") + + register_pre_forward_fun_list(collect_input) + + @staticmethod + def compare(bench_dir, cmp_dir, output_path, fmk): + bench_dataset_pack_path = os.path.join(bench_dir, DatasetChecker.target_name_in_zip) + cmp_dataset_pack_path = os.path.join(cmp_dir, DatasetChecker.target_name_in_zip) + + df = compare_dataset(bench_dataset_pack_path, cmp_dataset_pack_path) + pass_check = Const.CONFIG_CHECK_PASS if False not in df['equal'].values else Const.CONFIG_CHECK_ERROR + return DatasetChecker.target_name_in_zip, pass_check, df diff --git a/debug/accuracy_tools/msprobe/core/config_check/checkers/env_args_checker.py b/debug/accuracy_tools/msprobe/core/config_check/checkers/env_args_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..513a9f3b69c24225e4ba7a54f364104dce079dff --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/config_check/checkers/env_args_checker.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json + +import pandas as pd + +from msprobe.core.common.file_utils import load_json, load_yaml, create_file_with_content, create_file_in_zip +from msprobe.core.config_check.checkers.base_checker import BaseChecker +from msprobe.core.config_check.config_checker import register_checker_item +from msprobe.core.config_check.utils.utils import config_checking_print, process_pass_check +from msprobe.core.common.const import Const + + +dirpath = os.path.dirname(__file__) +env_yaml_path = os.path.join(dirpath, "../resource/env.yaml") + + +def collect_env_data(): + result = {} + for key, value in os.environ.items(): + result[key] = value + return result + + +def get_device_type(env_json): + for key in env_json.keys(): + if Const.ASCEND in key: + return Const.NPU_LOWERCASE + return Const.GPU_LOWERCASE + + +def compare_env_data(npu_path, bench_path): + necessary_env = load_yaml(env_yaml_path) + cmp_data = load_json(npu_path) + cmp_type = get_device_type(cmp_data) + bench_data = load_json(bench_path) + bench_type = get_device_type(bench_data) + data = [] + for _, value in necessary_env.items(): + cmp_env = value.get(cmp_type) + bench_env = value.get(bench_type) + if not bench_env and not cmp_env: + continue + elif cmp_env: + cmp_env_name = cmp_env["name"] + cmp_value = cmp_data.get(cmp_env_name, value[cmp_type]["default_value"]) + if not bench_env: + data.append(["only cmp has this env", cmp_env["name"], "", cmp_value, Const.CONFIG_CHECK_WARNING]) + continue + bench_env_name = bench_env["name"] + bench_value = bench_data.get(bench_env_name, value[bench_type]["default_value"]) + if cmp_value != bench_value: + data.append([bench_env_name, cmp_env_name, bench_value, cmp_value, Const.CONFIG_CHECK_ERROR]) + else: + bench_env_name = bench_env["name"] + bench_value = bench_data.get(bench_env_name) if bench_data.get(bench_env_name) else value[bench_type][ + "default_value"] + data.append([bench_env_name, "only bench has this env", bench_value, "", Const.CONFIG_CHECK_WARNING]) + df = pd.DataFrame(data, columns=EnvArgsChecker.result_header) + return df + + +@register_checker_item("env") +class EnvArgsChecker(BaseChecker): + + target_name_in_zip = "env" + result_header = ["bench_env_name", "cmp_env_name", "bench_value", "cmp_value", "level"] + + @staticmethod + def pack(pack_input): + output_zip_path = pack_input.output_zip_path + env_args_dict = collect_env_data() + create_file_in_zip(output_zip_path, EnvArgsChecker.target_name_in_zip, json.dumps(env_args_dict, indent=4)) + config_checking_print(f"add env args to zip") + + @staticmethod + def compare(bench_dir, cmp_dir, output_path, fmk): + bench_env_data = os.path.join(bench_dir, EnvArgsChecker.target_name_in_zip) + cmp_env_data = os.path.join(cmp_dir, EnvArgsChecker.target_name_in_zip) + df = compare_env_data(bench_env_data, cmp_env_data) + pass_check = process_pass_check(df['level'].values) + return EnvArgsChecker.target_name_in_zip, pass_check, df diff --git a/debug/accuracy_tools/msprobe/core/config_check/checkers/hyperparameter_checker.py b/debug/accuracy_tools/msprobe/core/config_check/checkers/hyperparameter_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..fd6a5388e0c0b7489ab6acd73dde0782077eb909 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/config_check/checkers/hyperparameter_checker.py @@ -0,0 +1,191 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +from difflib import SequenceMatcher + +from typing import Union, List, Dict, Any +import pandas as pd + +from msprobe.core.common.utils import check_extern_input_list +from msprobe.core.config_check.checkers.base_checker import BaseChecker +from msprobe.core.config_check.config_checker import register_checker_item +from msprobe.core.config_check.utils.utils import compare_dict, config_checking_print, update_dict, process_pass_check +from msprobe.core.config_check.utils.hyperparameter_parser import ParserFactory +from msprobe.core.common.file_utils import (check_file_or_directory_path, create_file_in_zip, load_json, + load_yaml) +from msprobe.core.common.const import Const + + +dirpath = os.path.dirname(__file__) +hyperparameters_path = os.path.join(dirpath, "../resource/hyperparameter.yaml") +parameter_name_mapping = load_yaml(os.path.realpath(hyperparameters_path)) +hyperparameters_dict = {} + + +def refine_json_keys(json_dcit): + new_dict = {} + for key in json_dcit.keys(): + new_key = key.split(Const.SEP)[-1].replace("-", "_") + new_dict[new_key] = key + return new_dict + + +def to_str_if_number(value): + if isinstance(value, (int, float)): + return str(value) + return value + + +@register_checker_item("hyperparameter") +class HyperparameterChecker(BaseChecker): + target_name_in_zip = "hyperparameters" + result_header = ["file_name", "bench_para", "cmp_para", "bench_value", "cmp_value", "matched_with", "level"] + hyperparameters_file_list = ["hyperparameters_static.json", "hyperparameters_dynamic.json"] + + @staticmethod + def pack(pack_input): + shell_path = pack_input.shell_path + output_zip_path = pack_input.output_zip_path + + if shell_path: + check_extern_input_list(shell_path) + + hyperparameters = {} + parser_factory = ParserFactory() + for script_path in shell_path: + if os.path.isfile(script_path): + check_file_or_directory_path(script_path) + parser = parser_factory.get_parser(os.path.splitext(script_path)[1]) + update_dict(hyperparameters, parser.run(os.path.realpath(script_path))) + else: + config_checking_print(f"Warning: Script path {script_path} is not a file.") + if hyperparameters: + create_file_in_zip(output_zip_path, + os.path.join(HyperparameterChecker.target_name_in_zip, + HyperparameterChecker.hyperparameters_file_list[0]), + json.dumps(hyperparameters, indent=4)) + config_checking_print(f"add static hyperparameters args to zip") + else: + config_checking_print(f"Warning: Failed to extract hyperparameters from script {shell_path}") + if hyperparameters_dict: + create_file_in_zip(output_zip_path, + os.path.join(HyperparameterChecker.target_name_in_zip, + HyperparameterChecker.hyperparameters_file_list[1]), + json.dumps(vars(hyperparameters_dict), default=lambda x: None, indent=4)) + config_checking_print(f"add dynamic hyperparameters args to zip") + + @staticmethod + def compare(bench_dir, cmp_dir, output_path, fmk): + all_diffs = [] + for file_name in HyperparameterChecker.hyperparameters_file_list: + bench_model_dir = os.path.join(bench_dir, HyperparameterChecker.target_name_in_zip, file_name) + cmp_model_dir = os.path.join(cmp_dir, HyperparameterChecker.target_name_in_zip, file_name) + if os.path.isfile(bench_model_dir) and os.path.isfile(cmp_model_dir): + bench_hyperparameters = load_json(bench_model_dir) + cmp_hyperparameters = load_json(cmp_model_dir) + all_diffs.extend( + HyperparameterChecker.compare_param(bench_hyperparameters, cmp_hyperparameters, file_name)) + df = pd.DataFrame(all_diffs, columns=HyperparameterChecker.result_header) + pass_check = process_pass_check(df["level"].values) + return HyperparameterChecker.target_name_in_zip, pass_check, df + + @staticmethod + def compare_param(bench_params, cmp_params, file_name): + all_diffs = [] + bench_params_refined = refine_json_keys(bench_params) + cmp_params_refined = refine_json_keys(cmp_params) + + for bench_param_name in bench_params_refined.keys(): + matched_cmp_param_name, matched_with = HyperparameterChecker._fuzzy_match_parameter(bench_param_name, + cmp_params_refined) + matched_cmp_param_name = cmp_params_refined.get(matched_cmp_param_name) + bench_param_name = bench_params_refined.get(bench_param_name) + bench_param_value = to_str_if_number(bench_params[bench_param_name]) + if matched_cmp_param_name: + cmp_param_value = to_str_if_number(cmp_params[matched_cmp_param_name]) + if bench_param_value != cmp_param_value: + all_diffs.append( + [file_name, bench_param_name, matched_cmp_param_name, bench_param_value, cmp_param_value, + matched_with, Const.CONFIG_CHECK_ERROR]) + del cmp_params[matched_cmp_param_name] + else: + all_diffs.append( + [file_name, bench_param_name, "Only in benchmark", bench_param_value, "", "", + Const.CONFIG_CHECK_WARNING]) + for cmp_param_name, cmp_param_value in cmp_params.items(): + all_diffs.append( + [file_name, "Only in comparison", cmp_param_name, "", cmp_param_value, "", Const.CONFIG_CHECK_WARNING]) + all_diffs.sort() + return all_diffs + + @staticmethod + def apply_patches(fmk): + try: + from megatron import training + + def collect_hyperparameter_wrapper(func): + def wrapper(*args, **kwargs): + global hyperparameters_dict + result = func(*args, **kwargs) + if not hyperparameters_dict: + hyperparameters_dict = result + return result + return wrapper + training.get_args = collect_hyperparameter_wrapper(training.get_args) + except ImportError: + config_checking_print("No megatron find.") + except Exception as e: + config_checking_print(f"Patch megatron method failed, detail:{str(e)}") + + @staticmethod + def _fuzzy_match_parameter(param_name: str, available_params: Dict[str, Any]): + """ + Fuzzy matches a parameter name against available parameter names using predefined + mappings and string similarity. + """ + if param_name in available_params: + return param_name, Const.MATCH_MODE_NAME + + canonical_name = None + for standard_name, aliases in parameter_name_mapping.items(): + if param_name == standard_name or param_name in aliases: + canonical_name = standard_name + break + + if canonical_name: + if canonical_name in available_params: + return canonical_name, Const.MATCH_MODE_MAPPING + for alias in parameter_name_mapping[canonical_name]: + if alias in available_params: + config_checking_print( + f"Matched '{param_name}' to alias '{alias}' via canonical name '{canonical_name}'") + return alias, Const.MATCH_MODE_MAPPING + + best_match_name = None + best_match_ratio = 0.8 + for available_param_name in available_params: + ratio = SequenceMatcher(None, param_name.lower(), available_param_name.lower()).ratio() + if ratio > best_match_ratio: + best_match_ratio = ratio + best_match_name = available_param_name + + if best_match_name: + config_checking_print( + f"Fuzzy matched parameter '{param_name}' to '{best_match_name}' (similarity: {best_match_ratio:.2f})") + return best_match_name, f"{Const.MATCH_MODE_SIMILARITY}:{best_match_ratio}" + + return None, None diff --git a/debug/accuracy_tools/msprobe/core/config_check/checkers/pip_checker.py b/debug/accuracy_tools/msprobe/core/config_check/checkers/pip_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..0795ad6bc02d1418e2145a73b31cec219ff7dee4 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/config_check/checkers/pip_checker.py @@ -0,0 +1,91 @@ +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pandas as pd +try: + import importlib.metadata as metadata +except ImportError: + import importlib_metadata as metadata + +from msprobe.core.common.file_utils import load_yaml, create_file_in_zip +from msprobe.core.config_check.checkers.base_checker import BaseChecker +from msprobe.core.config_check.config_checker import register_checker_item +from msprobe.core.config_check.utils.utils import config_checking_print, process_pass_check +from msprobe.core.common.file_utils import FileOpen, save_excel +from msprobe.core.common.const import Const + +dirpath = os.path.dirname(__file__) +depend_path = os.path.join(dirpath, "../resource/dependency.yaml") + + +def load_pip_txt(file_path): + output_dir = {} + with FileOpen(file_path, 'r', encoding='utf-8') as file: + lines = file.readlines() + for line in lines: + info_list = line.strip().split("=") + output_dir[info_list[0]] = "" if len(info_list) != 2 else info_list[1] + return output_dir + + +def collect_pip_data(): + result = "" + packages = metadata.distributions() + for pkg in packages: + if pkg.metadata: + result += f"{pkg.metadata.get('Name')}={pkg.version}\n" + return result + + +def compare_pip_data(bench_pip_path, cmp_pip_path, fmk): + necessary_dependency = load_yaml(depend_path)["dependency"] + necessary_dependency.append(fmk) + bench_data = load_pip_txt(bench_pip_path) + cmp_data = load_pip_txt(cmp_pip_path) + data = [] + for package in necessary_dependency: + bench_version = bench_data.get(package) + cmp_version = cmp_data.get(package) + + if bench_version != cmp_version: + data.append([package, bench_version if bench_version else 'None', + cmp_version if cmp_version else 'None', + Const.CONFIG_CHECK_ERROR]) + + df = pd.DataFrame(data, columns=PipPackageChecker.result_header) + return df + + +@register_checker_item("pip") +class PipPackageChecker(BaseChecker): + + target_name_in_zip = "pip" + result_header = ['package', 'bench version', 'cmp version', 'level'] + + @staticmethod + def pack(pack_input): + output_zip_path = pack_input.output_zip_path + pip_data = collect_pip_data() + create_file_in_zip(output_zip_path, PipPackageChecker.target_name_in_zip, pip_data) + config_checking_print(f"add pip info to zip") + + @staticmethod + def compare(bench_dir, cmp_dir, output_path, fmk): + bench_pip_path = os.path.join(bench_dir, PipPackageChecker.target_name_in_zip) + cmp_pip_path = os.path.join(cmp_dir, PipPackageChecker.target_name_in_zip) + df = compare_pip_data(bench_pip_path, cmp_pip_path, fmk) + pass_check = process_pass_check(df['level'].values) + return PipPackageChecker.target_name_in_zip, pass_check, df diff --git a/debug/accuracy_tools/msprobe/core/config_check/checkers/random_checker.py b/debug/accuracy_tools/msprobe/core/config_check/checkers/random_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..00408e384cf4b34a0a2d000275a13d4f614acc68 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/config_check/checkers/random_checker.py @@ -0,0 +1,367 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import random +from functools import wraps +from typing import Callable, List, Dict, Tuple, Optional +import inspect +import os +import json +from collections import defaultdict +import difflib + +import numpy as np +import pandas as pd +from msprobe.core.config_check.config_checker import register_checker_item, register_pre_forward_fun_list +from msprobe.core.common.file_utils import create_file_in_zip, load_json +from msprobe.core.config_check.checkers.base_checker import BaseChecker +from msprobe.core.config_check.utils.utils import config_checking_print +from msprobe.core.common.framework_adapter import FmkAdp +from msprobe.core.common.const import Const +from msprobe.core.common.log import logger + + +# 数据结构:{随机操作名字: [{count: 调用次数, stack: 调用栈列表}]} +random_op_stats = defaultdict(list) + + +def get_call_stack(frame) -> List[str]: + """获取详细的调用栈信息,每个元素包含完整路径、行号、函数名和代码行""" + stack = [] + current_frame = frame.f_back # 跳过当前函数 + + while current_frame: + frame_info = inspect.getframeinfo(current_frame) + filename = os.path.abspath(frame_info.filename) + code_line = frame_info.code_context[0].strip() if frame_info.code_context else "" + + # 格式化为详细的栈帧信息 + stack_entry = f"File {filename}, line {frame_info.lineno}, in {frame_info.function}, {code_line}" + stack.append(stack_entry) + + current_frame = current_frame.f_back + + # 反转堆栈以显示正确的调用顺序(栈底到栈顶) + return stack[::-1] + + +def track_random_call(func: Callable, name: str): + """记录随机函数的调用信息""" + @wraps(func) + def wrapper(*args, **kwargs): + frame = inspect.currentframe() + stack = get_call_stack(frame) + + # 更新调用统计:操作名 -> [{count: 次数, stack: 调用栈列表}] + # 检查是否已有相同调用栈的记录 + for entry in random_op_stats[name]: + if entry['stack'] == stack: + entry['count'] += 1 + break + else: + # 新增调用栈记录 + random_op_stats[name].append({'count': 1, 'stack': stack}) + + try: + result = func(*args, **kwargs) + return result + except Exception as e: + raise e + finally: + del frame + + return wrapper + + +def load_stats_files(directory: str) -> Dict[str, Dict[str, List[Dict]]]: + """加载目录下所有统计文件并按rank组织数据""" + rank_data = {} + for file in os.listdir(directory): + file_path = os.path.join(directory, file) + if file.startswith('rank') and file.endswith('.json'): + rank = os.path.basename(file.split('.')[0])[4:] + if not rank or not rank.isdigit(): + logger.error(f"extract rank id from {file} failed") + raise ValueError + + # 加载并存储数据 + data = load_json(file_path) + rank_data[int(rank)] = data + + return rank_data + + +def stack_match(stack1: List[str], stack2: List[str], threshold: float = 0.8) -> bool: + """ + 比较两个调用栈是否相似,同时考虑路径、函数名和代码行(各占1/3),每一层的相似度阈值需要达到0.8 + + 参数: + - stack1: 第一个调用栈列表 + - stack2: 第二个调用栈列表 + - threshold: 相似度阈值,默认0.8 + + 返回: + - 两个调用栈是否相似的布尔值 + """ + if len(stack1) != len(stack2): + return False + + for frame1, frame2 in zip(stack1, stack2): + # 提取路径、函数名和代码行 + path1, func1, code1 = _parse_frame(frame1) + path2, func2, code2 = _parse_frame(frame2) + + # 计算相似度得分 (路径、函数名、代码行各占1/3权重) + path_score = _compare_path(path1, path2) + func_score = 1.0 if func1 == func2 else 0.0 + # 代码相似度 + code_score = difflib.SequenceMatcher(None, code1, code2).ratio() + + frame_score = (path_score + func_score + code_score) / 3.0 + if frame_score < threshold: + return False + + return True + + +def _parse_frame(frame: str) -> Tuple[str, str, str]: + """ + 解析栈帧字符串,提取路径、函数名和代码行 + + 参数: + - frame: 栈帧字符串。格式为"File {path}, line {line}, in {func}, {code}" + + 返回: + - path, func, code + """ + path = func = code = '' + stack_info = frame.split(' ') + if len(stack_info) > 6: + path = stack_info[1][:-1] + func = stack_info[5][:-1] + code = ' '.join(stack_info[6:]) + return path, func, code + + +def _compare_path(path1: str, path2: str) -> float: + """比较两个路径的相似度,只考虑文件名""" + if not path1 or not path2: + return 0.0 + + # 提取文件名(忽略目录路径) + file1 = os.path.basename(path1) + file2 = os.path.basename(path2) + + return 1.0 if file1 == file2 else 0.0 + + +def find_matching_stack(bench_stack: List[str], cmp_stacks: List[Dict]) -> Optional[Dict]: + """ + 查找匹配的调用栈 + + 参数: + - bench_stack: 基准侧的调用栈列表 + - cmp_stacks: 比较侧的调用栈条目列表,每个条目是{'count': 次数, 'stack': 调用栈列表} + + 返回: + - 匹配的调用栈条目或None + """ + for cmp_entry in cmp_stacks: + if stack_match(cmp_entry['stack'], bench_stack): + return cmp_entry + + return None + + +def stack_list_to_string(stack_list): + """ + 将调用栈列表转换为换行分隔的字符串 + 如果输入是特殊标记(如"no match stack"),则直接返回 + """ + if isinstance(stack_list, list): + return '\n'.join(stack_list) + return stack_list + + +def compare_random_calls(bench_dir: str = 'bench', cmp_dir: str = 'cmp') -> pd.DataFrame: + """比较两个目录下的随机调用栈统计,生成详细比对结果""" + bench_rank_data = load_stats_files(bench_dir) + cmp_rank_data = load_stats_files(cmp_dir) + + # 获取所有rank + all_ranks = sorted(set(bench_rank_data.keys()) | set(cmp_rank_data.keys())) + + results = [] + + for rank in all_ranks: + bench_data = bench_rank_data.get(rank, {}) + cmp_data = cmp_rank_data.get(rank, {}) + + # 获取所有操作 + all_ops = set(bench_data.keys()) | set(cmp_data.keys()) + + for op in all_ops: + bench_stacks = bench_data.get(op, []) + cmp_stacks = cmp_data.get(op, []) + + # 处理bench侧的每个调用栈 + for bench_entry in bench_stacks: + bench_stack = bench_entry['stack'] + bench_count = bench_entry['count'] + + # 查找匹配的cmp侧调用栈 + cmp_entry = find_matching_stack(bench_stack, cmp_stacks) + + if cmp_entry: + cmp_count = cmp_entry['count'] + check_result = bench_count == cmp_count + results.append([op, rank, bench_stack, cmp_entry['stack'], bench_count, cmp_count, check_result]) + else: + # 没有匹配的调用栈 + results.append([op, rank, bench_stack, "no match stack", bench_count, 0, False]) + + # 处理cmp侧中没有在bench侧出现的调用栈 + for cmp_entry in cmp_stacks: + cmp_stack = cmp_entry['stack'] + # 检查是否已经在上面处理过 + if not any(stack_match(bench_entry['stack'], cmp_stack) for bench_entry in bench_stacks): + results.append([op, rank, "no match stack", cmp_stack, 0, cmp_entry['count'], False]) + + # 创建DataFrame + df = pd.DataFrame(results, columns=RandomChecker.result_header) + + # 应用转换函数 + df['bench_stack'] = df['bench_stack'].apply(stack_list_to_string) + df['cmp_stack'] = df['cmp_stack'].apply(stack_list_to_string) + + return df + + +def torch_patchs(): + """补丁Torch随机函数""" + import torch + torch_patches = { + 'rand': torch.rand, + 'randint': torch.randint, + 'randn': torch.randn, + 'rand_like': torch.rand_like, + 'randint_like': torch.randint_like, + 'randn_like': torch.randn_like, + 'manual_seed': torch.manual_seed + } + for name, func in torch_patches.items(): + setattr(torch, name, track_random_call(func, f"torch.{name}")) + + tensor_patches = { + 'exponential_': torch.Tensor.exponential_, + 'geometric_': torch.Tensor.geometric_, + 'log_normal_': torch.Tensor.log_normal_, + 'cauchy_': torch.Tensor.cauchy_ + } + for name, func in tensor_patches.items(): + setattr(torch.Tensor, name, track_random_call(func, f"torch.Tensor.{name}")) + + +def mindspore_patchs(): + """补丁MindSpore随机函数""" + import mindspore + + mindspore_ops_patches = { + 'rand': mindspore.ops.rand, + 'randint': mindspore.ops.randint, + 'randn': mindspore.ops.randn + } + for name, func in mindspore_ops_patches.items(): + setattr(mindspore.ops, name, track_random_call(func, f"mindspore.ops.{name}")) + + mindspore_patches = { + 'manual_seed': mindspore.set_seed + } + for name, func in mindspore_patches.items(): + setattr(mindspore, name, track_random_call(func, f"mindspore.{name}")) + + +@register_checker_item("random") +class RandomChecker(BaseChecker): + input_needed = None + target_name_in_zip = "random" + result_header = ['op', 'rank', 'bench_stack', 'cmp_stack', 'bench_count', 'cmp_count', 'check_result'] + write_once = False + + @staticmethod + def pack(pack_input): + """打包随机调用统计到zip文件""" + output_zip_path = pack_input.output_zip_path + + def collect_input(model, args, kwargs, step): + if RandomChecker.write_once: + return + + random_stats_dir = os.path.join(RandomChecker.target_name_in_zip) + stats_filepath = os.path.join(random_stats_dir, f"rank{FmkAdp.get_rank_id()}.json") + + # 转换为JSON格式:{操作名: [{count: 次数, stack: 调用栈列表}]} + stats_json = {} + for op_name, entries in random_op_stats.items(): + stats_json[op_name] = entries + + create_file_in_zip(output_zip_path, stats_filepath, json.dumps(stats_json, indent=4)) + config_checking_print(f"已将随机调用统计打包到: {stats_filepath}") + RandomChecker.write_once = True + + register_pre_forward_fun_list(collect_input) + + @staticmethod + def compare(bench_dir, cmp_dir, output_path, fmk): + """比较两组随机调用统计""" + bench_stats_path = os.path.join(bench_dir, RandomChecker.target_name_in_zip) + cmp_stats_path = os.path.join(cmp_dir, RandomChecker.target_name_in_zip) + + df = compare_random_calls(bench_stats_path, cmp_stats_path) + pass_check = Const.CONFIG_CHECK_PASS if False not in df['check_result'].values else Const.CONFIG_CHECK_ERROR + + return RandomChecker.target_name_in_zip, pass_check, df + + @staticmethod + def apply_patches(fmk=Const.PT_FRAMEWORK): + """应用随机函数补丁""" + # 补丁Python random模块 + random_patches = { + 'random': random.random, + 'randint': random.randint, + 'uniform': random.uniform, + 'choice': random.choice + } + for name, func in random_patches.items(): + setattr(random, name, track_random_call(func, f"random.{name}")) + + # 补丁Numpy随机函数 + np_random_patches = { + 'rand': np.random.rand, + 'randint': np.random.randint, + 'choice': np.random.choice, + 'normal': np.random.normal + } + for name, func in np_random_patches.items(): + setattr(np.random, name, track_random_call(func, f"np.random.{name}")) + + # 补丁框架特定随机函数 + if fmk == Const.PT_FRAMEWORK: + torch_patchs() + elif fmk == Const.MS_FRAMEWORK: + mindspore_patchs() + else: + raise Exception(f"不支持的框架: {fmk}, 支持的框架: {FmkAdp.supported_fmk}") diff --git a/debug/accuracy_tools/msprobe/core/config_check/checkers/weights_checker.py b/debug/accuracy_tools/msprobe/core/config_check/checkers/weights_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..32716ea2e86ab79b290b48f8f96a4e1fe862f4b5 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/config_check/checkers/weights_checker.py @@ -0,0 +1,148 @@ +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +import pandas as pd + +from msprobe.core.common.file_utils import create_file_in_zip, os_walk_for_files, load_json +from msprobe.core.config_check.checkers.base_checker import BaseChecker +from msprobe.core.config_check.config_checker import register_checker_item, register_pre_forward_fun_list +from msprobe.core.config_check.utils.utils import config_checking_print, get_tensor_features +from msprobe.core.common.framework_adapter import FmkAdp +from msprobe.core.common.const import Const + + +def collect_weights_data(model): + weights_data = {} + for name, param in FmkAdp.named_parameters(model): + if param.dtype != FmkAdp.dtype("float32"): + param = param.float() + weights_data[name] = get_tensor_features(param) + return weights_data + + +def compare_weight_file(bench_file, cmp_file): + bench_data = load_json(bench_file) + cmp_data = load_json(cmp_file) + + results = [] + for weight_name in set(bench_data.keys()) | set(cmp_data.keys()): + result = { + "weight_name": weight_name, + "equal": None, + "max_relative_diff": None, + "min_relative_diff": None, + "mean_relative_diff": None, + "norm_relative_diff": None + } + + if weight_name not in bench_data: + result["equal"] = "only cmp have" + results.append(result) + continue + + if weight_name not in cmp_data: + result["equal"] = "only bench have" + results.append(result) + continue + + bench_vals = bench_data[weight_name] + cmp_vals = cmp_data[weight_name] + keys = ["max", "min", "mean", "norm"] + equal = all([bench_vals[k] == cmp_vals[k] for k in keys]) + result["equal"] = equal + + for key in keys: + diff_key = f"{key}_relative_diff" + result[diff_key] = (abs(bench_vals[key] - cmp_vals[key]) / bench_vals[key]) \ + if bench_vals[key] != 0 else None + + results.append(result) + + return results + + +def compare_weight(bench_dir, cmp_dir): + all_results = [] + bench_files_info = os_walk_for_files(bench_dir, 10) + for info in bench_files_info: + if not info["file"].endswith('.json'): + continue + bench_file = os.path.join(info["root"], info["file"]) + relative_path = os.path.relpath(info["root"], bench_dir) + cmp_root = os.path.join(cmp_dir, relative_path) + cmp_file = os.path.join(cmp_root, info["file"]) + + path_list = relative_path.split(os.sep) + if len(path_list) < 2: + raise Exception("Can not compare weights because the extracted file has been corrupted!") + step = int(path_list[0].replace("step", "")) + rank = int(path_list[1].replace("rank", "")) + + if not os.path.exists(cmp_file): + bench_data = load_json(bench_file) + for weight_name in bench_data.keys(): + result = { + "step": step, + "rank": rank, + "weight_name": weight_name, + "equal": "only bench have", + "max_relative_diff": None, + "min_relative_diff": None, + "mean_relative_diff": None, + "norm_relative_diff": None + } + all_results.append(result) + else: + results = compare_weight_file(bench_file, cmp_file) + for res in results: + res["step"] = step + res["rank"] = rank + all_results.append(res) + + df = pd.DataFrame(all_results, columns=WeightsChecker.result_header) + df = df.sort_values(by=['step', 'rank'], ascending=[True, True]) + return df + + +@register_checker_item("weights") +class WeightsChecker(BaseChecker): + input_needed = "model" + multi_rank = True + + target_name_in_zip = "weights" + result_header = ["step", "rank", "weight_name", "equal", "max_relative_diff", + "min_relative_diff", "mean_relative_diff", "norm_relative_diff"] + + @staticmethod + def pack(pack_input): + output_zip_path = pack_input.output_zip_path + + def collect_weights(model, args, kwargs, step): + weights_data_dict = collect_weights_data(model) + weights_data_filepath = os.path.join(WeightsChecker.target_name_in_zip, + f"step{step}", f"rank{FmkAdp.get_rank_id()}", "weight.json") + create_file_in_zip(output_zip_path, weights_data_filepath, json.dumps(weights_data_dict, indent=4)) + config_checking_print(f"add weights info to zip") + register_pre_forward_fun_list(collect_weights) + + @staticmethod + def compare(bench_dir, cmp_dir, output_path, fmk): + bench_weight_pack_path = os.path.join(bench_dir, WeightsChecker.target_name_in_zip) + cmp_weight_pack_path = os.path.join(cmp_dir, WeightsChecker.target_name_in_zip) + df = compare_weight(bench_weight_pack_path, cmp_weight_pack_path) + pass_check = Const.CONFIG_CHECK_PASS if False not in df['equal'].values else Const.CONFIG_CHECK_ERROR + return WeightsChecker.target_name_in_zip, pass_check, df diff --git a/debug/accuracy_tools/msprobe/core/config_check/ckpt_compare/ckpt_comparator.py b/debug/accuracy_tools/msprobe/core/config_check/ckpt_compare/ckpt_comparator.py new file mode 100644 index 0000000000000000000000000000000000000000..42e2dcd53924fea8e53b1cc1357e0839d2bddcff --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/config_check/ckpt_compare/ckpt_comparator.py @@ -0,0 +1,77 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict +from tqdm import tqdm + +from msprobe.core.common.file_utils import save_json, check_path_before_create, check_path_not_exists, \ + check_file_or_directory_path +from msprobe.core.common.log import logger +from msprobe.core.config_check.ckpt_compare.megatron_loader import load_megatron_weights +from msprobe.core.config_check.ckpt_compare.metrics import METRIC_FUNC + + + +def compare_checkpoints(ckpt_path1, ckpt_path2, output_path) -> Dict: + """Compare weights between two checkpoints using cosine similarity and L2 distance. + + Args: + ckpt_path1 (str): Path to first checkpoint directory + ckpt_path2 (str): Path to second checkpoint directory + output_path (str): Path to save comparison results JSON file + + Returns: + Dict: Dictionary containing comparison metrics for each parameter. The dictionary has the following structure: + { + "param_name": { + "cosine_similarity": float, # Cosine similarity between parameter tensors + "l2_distance": float, # L2 distance between parameter tensors + "shape": List[int] # Shape of the parameter tensors + }, + ... + } + """ + + # Load both checkpoints + check_file_or_directory_path(ckpt_path1, isdir=True) + check_file_or_directory_path(ckpt_path2, isdir=True) + check_path_before_create(output_path) + check_path_not_exists(output_path) + weights1 = load_megatron_weights(ckpt_path1) + weights2 = load_megatron_weights(ckpt_path2) + + # Initialize results dictionary + results = {} + + # Compare weights with matching keys + common = set(weights1) & set(weights2) + logger.warning(f'Parameters not in ckpt2: {set(weights1) - set(weights2)}') + logger.warning(f'Parameters not in ckpt1: {set(weights2) - set(weights1)}') + for key in tqdm(common): + tensor1 = weights1[key] + tensor2 = weights2[key] + + results[key] = {} + for metric, func in METRIC_FUNC.items(): + try: + results[key][metric] = func(tensor1, tensor2) + except Exception as e: + results[key][metric] = 'error' + logger.warning(f'Error when calculate {metric} for reason: {e}') + + # Write results to JSON file + save_json(output_path, results, indent=4) + logger.info(f"Comparison results written to {output_path}") + return results diff --git a/debug/accuracy_tools/msprobe/core/config_check/ckpt_compare/megatron_loader.py b/debug/accuracy_tools/msprobe/core/config_check/ckpt_compare/megatron_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..5a03b4d2542c98736745d995d66a5869beb0bf13 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/config_check/ckpt_compare/megatron_loader.py @@ -0,0 +1,304 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +from collections import defaultdict +from typing import Dict +import numpy as np +from msprobe.core.common.log import logger +from msprobe.core.common.decorator import recursion_depth_decorator +from msprobe.core.common.const import Const +from msprobe.core.common.file_utils import FileOpen, load_yaml +from msprobe.core.common.framework_adapter import FmkAdp + +# both weights and bias are partitioned in column parallel +COLUMN_PARALLEL_PARAMS = ['linear_qkv', 'linear_fc1', 'word_embeddings.weight', 'output_layer.weight'] +# only weights are partitioned in column parallel +ROW_PARALLEL_PARAMS = ['linear_fc2.weight', 'linear_proj.weight'] +ARGS = 'args' +LAYER_IDX_PATTERN = re.compile('layers\.(\d+)\.') +EXPERT_IDX_PATTERN = re.compile('experts\.(\d+)\.') +ITER_DIR_PATTERN = re.compile('iter_([\d]{7})') + + +@recursion_depth_decorator('') +def _get_parameter(weights, prefix=''): + for k, v in weights.items(): + name = Const.SEP.join([prefix, k]).strip(Const.SEP) + if isinstance(v, dict): + yield from _get_parameter(v, prefix=name) + elif FmkAdp.is_tensor(v): + yield name, FmkAdp.asnumpy(v) + + +def _map_to_mcore_local_names(param_name: str) -> str: + """Map parameter names to mcore + local transformer implementation names.""" + mcore_local_map = load_yaml(os.path.join(os.path.dirname(__file__), 'name_mapping.yaml')) + for other_name, mcore_local_name in mcore_local_map.items(): + param_name = param_name.replace(other_name, mcore_local_name) + + return param_name + + +def _parse_real_layer_idx(param_name, num_layers_per_stage, pp_size, pp_rank): + """Map local (virtual) pipeline stage layer index to global layer index. + + For virtual pipeline parallel, each pipeline stage is further divided into virtual stages. + The global layer index needs to account for both pipeline stage and virtual stage. + + Args: + param_name (str): Parameter name containing layer index: layers.x./ + num_layers_per_stage (int): Number of layers per pipeline stage + pp_size (int): Pipeline parallel size + + Returns: + int: Global layer index accounting for both pipeline and virtual pipeline stages + """ + # Extract local layer index from parameter name + layer_match = re.search(LAYER_IDX_PATTERN, param_name) + param_name, vpp_stage = param_name.split(Const.SCOPE_SEPARATOR) + if not layer_match: + return param_name + + local_layer_idx = int(layer_match.group(1)) + vpp_stage = int(vpp_stage) + + # Calculate global layer index based on pipeline stage and virtual stage + real_layer_idx = local_layer_idx + (pp_size * vpp_stage + pp_rank) * num_layers_per_stage + + return param_name.replace(f'layers.{local_layer_idx}', f'layers.{real_layer_idx}') + + +def _parse_real_expert_idx(param_name, num_experts_per_rank, exp_rank): + """Map local expert index to global expert index. TODO: shared expert + + For expert parallel, experts are distributed across ranks. This function maps + the local expert index on a rank to its global index across all ranks. + + Args: + param_name (str): Parameter name containing local expert index + num_experts_per_rank (int): Number of experts on each rank + exp_rank (int): Expert parallel rank + + Returns: + str: Parameter name with local expert index replaced by global expert index + """ + # Extract local layer index from parameter name + expert_match = re.search(EXPERT_IDX_PATTERN, param_name) + if not expert_match: + return param_name + + local_expert_idx = int(expert_match.group(1)) + # Calculate global layer index based on pipeline stage and virtual stage + real_experts_idx = local_expert_idx + exp_rank * num_experts_per_rank + + return param_name.replace(f'experts.{local_expert_idx}', f'experts.{real_experts_idx}') + + +def _consolidate_tp_weights(weights: Dict) -> Dict: + """Consolidate weights from different tensor parallel ranks into combined tensors. + + Args: + weights: Dictionary of weights with rank information in keys + + Returns: + Dict: Consolidated weights without rank information + """ + consolidated = {} + for key, tensors in weights.items(): + if any([name in key for name in COLUMN_PARALLEL_PARAMS]): + # Column parallel - concatenate along input dimension (dim 0) + combined = np.concatenate(tensors, axis=0) + elif any([name in key for name in ROW_PARALLEL_PARAMS]): + # Row parallel - concatenate along output dimension (dim 1) + combined = np.concatenate(tensors, axis=1) + else: + # For other params, verify identical and use first + if not all(np.allclose(tensors[0], t) for t in tensors[1:]): + logger.warning(f"Inconsistent values for {key} across TP ranks") + combined = tensors[0] + + consolidated[key] = combined + return consolidated + + +def _parse_num_layers_per_stage(tp_partition): + match = [re.findall(LAYER_IDX_PATTERN, key) for key in tp_partition.keys()] + layer_idx = [int(i[0]) for i in match if i] + if not layer_idx: + return 1 + num_layers_per_pipeline_stage = max(layer_idx) + 1 + + return num_layers_per_pipeline_stage + + +def parse_parallel_size(checkpoint_dir: str): + """Parse tensor, pipeline and expert parallel sizes from checkpoint filenames. + + Args: + checkpoint_dir (str): Directory containing checkpoint files + + Returns: + Namespace + """ + # Find all rank directories + rank_dirs = [d for d in os.listdir(checkpoint_dir) if d.startswith('mp_rank_')] + + if not rank_dirs: + raise ValueError(f"No checkpoint rank directories found in {checkpoint_dir}") + + ckpt = FmkAdp.load_checkpoint( + os.path.join(checkpoint_dir, rank_dirs[0], 'model_optim_rng.pt'), + to_cpu=True, + weights_only=False) + args = ckpt[ARGS] + return ( + args.tensor_model_parallel_size, + args.pipeline_model_parallel_size, + args.expert_model_parallel_size, + args.num_experts + ) + + +def parse_iteration(checkpoint_path: str) -> Dict: + """ + Parse the checkpoint iteration directory from a given checkpoint path. + + If the path is a top-level checkpoint directory, this function reads the + 'latest_checkpointed_iteration.txt' file to determine the latest iteration. + If the path is already an iteration directory (e.g., 'iter_0000005'), it extracts + the iteration number from the path. + + Args: + checkpoint_path (str): Path to the checkpoint directory or iteration directory. + + Returns: + str: The full path to the checkpoint directory for the determined iteration. + + Raises: + ValueError: If the checkpoint directory for the determined iteration does not exist. + """ + iteration = None + tracker_file = os.path.join(checkpoint_path, "latest_checkpointed_iteration.txt") + if os.path.exists(tracker_file): + with FileOpen(tracker_file, 'r') as f: + latest_iteration = f.read().strip() + if latest_iteration != 'release': + try: + iteration = int(latest_iteration) + except Exception: + logger.warning( + f"The latest_checkpointed_iteration is supposed to be `release` or an int. \ + But {latest_iteration} is found." + ) + checkpoint_path = os.path.join(checkpoint_path, f'iter_{iteration:07d}') + else: + match = re.findall(ITER_DIR_PATTERN, checkpoint_path) + if match: + iteration = int(match[0]) + + # Checkpoint directory for this iteration + logger.info(f"Loaded checkpoint from iteration {iteration}") + + if not os.path.exists(checkpoint_path): + raise ValueError(f"Checkpoint directory not found: {checkpoint_path}") + + return checkpoint_path + + +def get_weights_from_state_dict(state_dict): + weights = {} + vpp_stage = 0 + if 'model' in state_dict: + model_weights = state_dict['model'] + + for key, value in _get_parameter(model_weights): + key = _map_to_mcore_local_names(key) + weights[f"{key}{Const.SCOPE_SEPARATOR}{vpp_stage}"] = value + + elif 'model0' in state_dict: + #vpp enabled + while f'model{vpp_stage}' in state_dict: + model_weights = state_dict[f'model{vpp_stage}'] + for key, value in _get_parameter(model_weights): + key = _map_to_mcore_local_names(key) + weights[f"{key}{Const.SCOPE_SEPARATOR}{vpp_stage}"] = value + vpp_stage += 1 + return weights + + +def load_megatron_weights(checkpoint_path: str) -> Dict: + """Load Megatron parallel checkpoint weights into a single dictionary. + + Args: + checkpoint_path (str): Base checkpoint directory path + + Returns: + combined_weights: Dict with weights from all ranks, keys include rank info + """ + try: + import megatron + except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megatron', which is required to load a megatron ckpt") from e + + # Find latest iteration if not specified + checkpoint_path = parse_iteration(checkpoint_path) + + # Parse parallel sizes from checkpoint directory structure + tp_size, pp_size, exp_size, num_experts = parse_parallel_size(checkpoint_path) + combined_weights = {} + + # Load checkpoints from all ranks + for exp_rank in range(exp_size): + num_layers_per_pipeline_stage = 0 + for pp_rank in range(pp_size): + tp_partition = defaultdict(list) + for tp_rank in range(tp_size): + # Construct checkpoint path based on parallel ranks + if pp_size > 1: + rank_dir = f'mp_rank_{tp_rank:02d}_{pp_rank:03d}' + else: + rank_dir = f'mp_rank_{tp_rank:02d}' + + if exp_size > 1: + rank_dir = f'{rank_dir}_{exp_rank:03d}' + + ckpt_file = os.path.join(checkpoint_path, rank_dir, 'model_optim_rng.pt') + try: + state_dict = FmkAdp.load_checkpoint(ckpt_file, to_cpu=True, weights_only=False) + partition = get_weights_from_state_dict(state_dict) + for key, weight in partition.items(): + tp_partition[key].append(weight) + + except Exception as load_error: + logger.warning(f"Error loading {ckpt_file}: {load_error}") + + if not tp_partition: + raise ValueError('No state loaded.') + + if not num_layers_per_pipeline_stage: + num_layers_per_pipeline_stage = _parse_num_layers_per_stage(tp_partition) + + consolidated_weight = _consolidate_tp_weights(tp_partition) + for key, value in consolidated_weight.items(): + key = _parse_real_layer_idx(key, num_layers_per_pipeline_stage, pp_size, pp_rank) + if num_experts: + key = _parse_real_expert_idx(key, num_experts // exp_size, exp_rank) + combined_weights[key] = value + + logger.info(f"Found {len(combined_weights)} total parameters across all ranks") + + return combined_weights diff --git a/debug/accuracy_tools/msprobe/core/config_check/ckpt_compare/metrics.py b/debug/accuracy_tools/msprobe/core/config_check/ckpt_compare/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..2e9e1324b33c570033fa4fc29a6a32dff73b64de --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/config_check/ckpt_compare/metrics.py @@ -0,0 +1,83 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +from msprobe.core.common.log import logger +from msprobe.core.compare.npy_compare import CompareOps + + + +def in_different_shape(a, b): + if a.shape != b.shape: + logger.warning(f"a, b are in different shape. a: {a.shape}, b: {b.shape}") + return True + return False + + +def l2_distance(a, b): + if a is None or b is None: + return None + if in_different_shape(a, b): + return None + return np.linalg.norm(a - b).item() + + +def cos_sim(a, b): + if a is None or b is None: + return None + + if in_different_shape(a, b): + return None + if a.ndim > 0: + a = a.flatten().squeeze() + b = b.flatten().squeeze() + + num = a.dot(b) + a_norm = np.linalg.norm(a) + b_norm = np.linalg.norm(b) + + if a_norm == 0 and b_norm == 0: + return 1. + if a_norm == 0 or b_norm == 0: + logger.warning(f'One tensor norm is zero.') + return None + + sim = num / (a_norm * b_norm) + + return sim.item() + + +def numel(a, b): + n1 = a.size + n2 = b.size + if n1 != n2: + logger.warning('parameters have different number of element') + return (n1, n2) + return n1 + + +def shape(a, b): + if in_different_shape(a, b): + return [list(a.shape), list(b.shape)] + return list(a.shape) + + +METRIC_FUNC = { + 'l2': l2_distance, + 'cos': cos_sim, + 'numel': numel, + 'shape': shape + } diff --git a/debug/accuracy_tools/msprobe/core/config_check/ckpt_compare/name_mapping.yaml b/debug/accuracy_tools/msprobe/core/config_check/ckpt_compare/name_mapping.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0caecc53a73b108939435867fe1b6e614bd91812 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/config_check/ckpt_compare/name_mapping.yaml @@ -0,0 +1,12 @@ +self_attention.linear_qkv.layer_norm_: input_layernorm. +language_model.: '' +encoder: decoder +.input_norm.: .input_layernorm. +query_key_value: linear_qkv +.dense.: .linear_proj. +post_attention_norm: pre_mlp_layernorm +dense_h_to_4h: linear_fc1 +dense_4h_to_h: linear_fc2 +mlp.local_experts: mlp.experts.local_experts +final_norm: final_layernorm +word_embeddings_for_head: output_layer diff --git a/debug/accuracy_tools/msprobe/core/config_check/config_check_cli.py b/debug/accuracy_tools/msprobe/core/config_check/config_check_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..d7ea45ff0890297702cc91c6c9a9c464e5a6c263 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/config_check/config_check_cli.py @@ -0,0 +1,51 @@ +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprobe.core.config_check.config_checker import ConfigChecker +from msprobe.core.config_check.ckpt_compare.ckpt_comparator import compare_checkpoints +from msprobe.core.common.log import logger + + +def pack(shell_path, output_path, framework): + ConfigChecker(shell_path=shell_path, output_zip_path=output_path, fmk=framework) + + +def compare(bench_zip_path, cmp_zip_path, output_path, framework): + ConfigChecker.compare(bench_zip_path, cmp_zip_path, output_path, framework) + + +def _config_checking_parser(parser): + parser.add_argument('-d', '--dump', nargs='*', help='Collect the train config into a zip file') + parser.add_argument('-c', '--compare', nargs=2, help='Compare two zip files or checkpoints') + parser.add_argument('-o', '--output', help='output path, default is ./config_check_result') + + +def _run_config_checking_command(args): + if args.dump is not None: + output_dirpath = args.output if args.output else "./config_check_pack.zip" + pack(args.dump, output_dirpath, args.framework) + elif args.compare: + if args.compare[0].endswith('zip'): + logger.info('The input paths is zip files, comparing packed config.') + output_dirpath = args.output if args.output else "./config_check_result" + compare(args.compare[0], args.compare[1], output_dirpath, args.framework) + else: + logger.info('Comparing model checkpoint.') + output_dirpath = args.output if args.output else "./ckpt_similarity.json" + compare_checkpoints(args.compare[0], args.compare[1], output_dirpath) + + else: + logger.error("The param is not correct, you need to give '-d' for dump or '-c' for compare.") + raise Exception("The param is not correct, you need to give '-d' for dump or '-c' for compare.") diff --git a/debug/accuracy_tools/msprobe/core/config_check/config_checker.py b/debug/accuracy_tools/msprobe/core/config_check/config_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..537bbda1a415a48fb92d95b43a77642f4cc9e192 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/config_check/config_checker.py @@ -0,0 +1,99 @@ +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil + +import pandas as pd + +from msprobe.core.common.file_utils import save_excel, split_zip_file_path, \ + create_directory, extract_zip +from msprobe.core.common.framework_adapter import FmkAdp +from msprobe.core.config_check.checkers.base_checker import PackInput +from msprobe.core.config_check.utils.utils import config_checking_print +from msprobe.core.common.const import Const + + +class ConfigChecker: + checkers = {} + pre_forward_fun_list = [] + result_filename = "result.xlsx" + result_header = ["filename", "pass_check"] + step = 0 + + def __init__(self, model=None, shell_path=None, output_zip_path="./config_check_pack.zip", fmk="pytorch"): + FmkAdp.set_fmk(fmk) + self.pack_input = PackInput(output_zip_path, model, shell_path) + file_path, file_name = split_zip_file_path(self.pack_input.output_zip_path) + if not os.path.exists(file_path): + create_directory(file_path) + self.pack() + + @staticmethod + def compare(bench_zip_path, cmp_zip_path, output_path, fmk=Const.PT_FRAMEWORK): + create_directory(output_path) + bench_dir = os.path.join(output_path, "bench") + cmp_dir = os.path.join(output_path, "cmp") + extract_zip(bench_zip_path, bench_dir) + config_checking_print(f"extract zip file {bench_zip_path} to {bench_dir}") + extract_zip(cmp_zip_path, cmp_dir) + config_checking_print(f"extract zip file {cmp_zip_path} to {cmp_dir}") + + result = [] + summary_result = [] + for checker in ConfigChecker.checkers.values(): + checker_name, pass_check, df = checker.compare_ex(bench_dir, cmp_dir, output_path, fmk) + if checker_name: + summary_result.append([checker_name, pass_check]) + if df is not None: + result.append((df, checker_name)) + summary_result_df = pd.DataFrame(summary_result, columns=ConfigChecker.result_header) + result.insert(0, (summary_result_df, "summary")) + save_excel(os.path.join(output_path, ConfigChecker.result_filename), result) + config_checking_print(f"config checking result save to {os.path.realpath(output_path)}") + + @staticmethod + def apply_patches(fmk=Const.PT_FRAMEWORK): + for checker in ConfigChecker.checkers.values(): + checker.apply_patches(fmk) + + def pack(self): + config_checking_print(f"pack result zip path {os.path.realpath(self.pack_input.output_zip_path)}") + + def hook(model, args, kwargs): + for collect_func in self.pre_forward_fun_list: + collect_func(model, args, kwargs, ConfigChecker.step) + ConfigChecker.step += 1 + + if self.pack_input.model: + FmkAdp.register_forward_pre_hook(self.pack_input.model, hook, with_kwargs=True) + for checker in ConfigChecker.checkers.values(): + if checker.input_needed and not getattr(self.pack_input, checker.input_needed): + continue + if FmkAdp.is_initialized() and FmkAdp.get_rank() != 0 and not checker.multi_rank: + continue + checker.pack(self.pack_input) + + +def register_checker_item(key, cls=None): + if cls is None: + # 无参数时,返回装饰器函数 + return lambda cls: register_checker_item(key, cls) + ConfigChecker.checkers[key] = cls + return cls + + +def register_pre_forward_fun_list(func): + ConfigChecker.pre_forward_fun_list.append(func) diff --git a/debug/accuracy_tools/msprobe/pytorch/parse.py b/debug/accuracy_tools/msprobe/core/config_check/resource/dependency.yaml similarity index 87% rename from debug/accuracy_tools/msprobe/pytorch/parse.py rename to debug/accuracy_tools/msprobe/core/config_check/resource/dependency.yaml index 3dfd88f03d1b944f6943a58ce860c7de9c4a3424..02c0b565bf59b1b220f16ae17a47f5f4d5b13c1f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/parse.py +++ b/debug/accuracy_tools/msprobe/core/config_check/resource/dependency.yaml @@ -13,7 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from msprobe.pytorch.parse_tool import cli - -if __name__ == '__main__': - cli.parse() +dependency: + - transformers + - deepspeed + - megatron + - numpy + - datasets + - peft \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/core/config_check/resource/env.yaml b/debug/accuracy_tools/msprobe/core/config_check/resource/env.yaml new file mode 100644 index 0000000000000000000000000000000000000000..87d663b9d94976c24feb88b181b3ead98905eb5a --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/config_check/resource/env.yaml @@ -0,0 +1,57 @@ +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +HCCL_DETERMINISTIC: + npu: + name: HCCL_DETERMINISTIC + default_value: False + gpu: + name: NCCL_DETERMINISTIC + default_value: False + +HCCL_ALGO: + npu: + name: HCCL_ALGO + default_value: None + gpu: + name: NCCL_ALGO + default_value: None + +HCCL_INTRA_ROCE_ENABLE: + npu: + name: HCCL_INTRA_ROCE_ENABLE + default_value: 0 + + +HCCL_INTRA_PICE_ENABLE: + npu: + name: HCCL_INTRA_ROCE_ENABLE + default_value: 1 + +ASCEND_LAUNCH_BLOCKING: + npu: + name: ASCEND_LAUNCH_BLOCKING + default_value: 0 + gpu: + name: CUDA_LAUNCH_BLOCKING + default_value: 0 + +ASCEND_RT_VISIBLE_DEVICES: + npu: + name: ASCEND_RT_VISIBLE_DEVICES + default_value: None + gpu: + name: CUDA_VISIBLE_DEVICES + default_value: None \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/core/config_check/resource/hyperparameter.yaml b/debug/accuracy_tools/msprobe/core/config_check/resource/hyperparameter.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4ec150331ed798882e8abbb6cb826af8f8109a91 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/config_check/resource/hyperparameter.yaml @@ -0,0 +1,31 @@ +learning_rate: + - lr + - learningrate + +batch_size: + - batch + - bs + - batch_size_per_gpu + +epochs: + - num_epochs + - max_epochs + - epoch + +weight_decay: + - wd + - weightdecay + +dropout_rate: + - dropout + - drop_rate + +compute_dtype: + - bf16 + - fp32 + +residual_dtype: + - fp32_residual_connection + +softmax_compute_dtype: + - attention_softmax_in_fp32 diff --git a/debug/accuracy_tools/msprobe/core/config_check/utils/hyperparameter_parser.py b/debug/accuracy_tools/msprobe/core/config_check/utils/hyperparameter_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..9e02cf5c92885215fede197d5e98fa738d243741 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/config_check/utils/hyperparameter_parser.py @@ -0,0 +1,119 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from abc import ABC, abstractmethod + +from msprobe.core.config_check.utils.utils import config_checking_print +from msprobe.core.common.file_utils import FileOpen, load_yaml +from msprobe.core.common.const import Const, FileCheckConst + + +class Parser(ABC): + @abstractmethod + def parse(self, file_path: str) -> dict: + pass + + def run(self, file_path: str) -> dict: + """ + 统一对外调用接口 + :param file_path: 需解析的文件路径 + :return: + """ + try: + result = self.parse(file_path) + except Exception as exc: + config_checking_print(f"{self.__class__} parsing error, skip file path: {file_path}, error: {exc}") + result = {} + return result + + +class ShellParser(Parser): + def parse(self, file_path: str) -> dict: + """ + Extracts arguments from bash script used to run a model training. + """ + hyperparameters = {} + script_content_list = [] + with FileOpen(file_path, 'r') as file: + for line in file: + stripped_line = line.lstrip() + if not stripped_line.startswith('#'): + line = line.split('#')[0].rstrip() + '\n' + if line.strip(): + script_content_list.append(line) + script_content = ''.join(script_content_list) + + command_line = re.search(r'msrun\s[^|]*|torchrun\s[^|]*|python\d? -m torch.distributed.launch\s[^|]*', + script_content, + re.DOTALL) + if command_line: + command_line = command_line.group() + + blocks = re.findall(r'([a-zA-Z0-9_]{1,20}_ARGS)="(.*?)"', script_content, re.DOTALL) + block_contents = {} + for block_name, block_content in blocks: + block_content = block_content.replace('\n', ' ') + block_contents[block_name] = block_content + command_line = command_line.replace(f"${block_name}", block_content) + + matches = re.findall(r'--([\w-]+)(?:\s+([^\s\\]+))?', command_line) + for match in matches: + key, value = match + args_key = re.match(r'\$\{?(\w+)}?', value) + if args_key: + env_vars = re.findall(rf'{args_key.group(1)}=\s*(.+)', script_content) + if env_vars: + value = env_vars[-1] + hyperparameters[key] = value if value else True + + return hyperparameters + + +class YamlParser(Parser): + hyperparameters = {} + + def parse(self, file_path: str) -> dict: + ori_hyper = load_yaml(file_path) + self.recursive_parse_parameters(ori_hyper, "") + return self.hyperparameters + + def recursive_parse_parameters(self, parameters, prefix): + if isinstance(parameters, dict): + for key, value in parameters.items(): + new_prefix = prefix + Const.SEP + key if prefix else key + self.recursive_parse_parameters(value, new_prefix) + elif isinstance(parameters, list): + if all(isinstance(x, (int, float, str, bool, list))for x in parameters): + self.hyperparameters.update({prefix: parameters}) + else: + for idx, value in enumerate(parameters): + new_prefix = prefix + Const.SEP + str(idx) if prefix else str(idx) + self.recursive_parse_parameters(value, new_prefix) + elif isinstance(parameters, (int, float, str, bool)): + self.hyperparameters.update({prefix: parameters}) + + +class ParserFactory: + __ParserDict = { + FileCheckConst.SHELL_SUFFIX: ShellParser(), + FileCheckConst.YAML_SUFFIX: YamlParser() + } + + def get_parser(self, file_type: str) -> Parser: + parser = self.__ParserDict.get(file_type, None) + if not parser: + raise ValueError(f'Invalid parser type: {file_type}') + return parser diff --git a/debug/accuracy_tools/msprobe/core/config_check/utils/utils.py b/debug/accuracy_tools/msprobe/core/config_check/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eedcc34cf5ec6753b86828f9ef688df3679dac60 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/config_check/utils/utils.py @@ -0,0 +1,117 @@ +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import hashlib + +from msprobe.core.common.framework_adapter import FmkAdp +from msprobe.core.common.log import logger +from msprobe.core.common.const import Const + + +def merge_keys(dir_0, dir_1): + output_list = list(dir_0.keys()) + output_list.extend(list(dir_1.keys())) + return set(output_list) + + +def compare_dict(bench_dict, cmp_dict): + result = [] + for key in set(bench_dict.keys()) | set(cmp_dict.keys()): + if key in bench_dict and key in cmp_dict: + if bench_dict[key] != cmp_dict[key]: + result.append(f"{key}: {bench_dict[key]} -> {cmp_dict[key]}") + elif key in bench_dict: + result.append(f"{key}: [deleted] -> {bench_dict[key]}") + else: + result.append(f"{key}: [added] -> {cmp_dict[key]}") + return result + + +def config_checking_print(msg): + logger.info(f"[config checking log] {msg}") + + +def tensor_to_hash(tensor): + """Compute the hash value of a tensor""" + tensor_bytes = tensor.clone().detach().cpu().numpy().tobytes() + return bytes_hash(tensor_bytes) + + +def get_tensor_features(tensor): + features = { + "max": FmkAdp.tensor_max(tensor), + "min": FmkAdp.tensor_min(tensor), + "mean": FmkAdp.tensor_mean(tensor), + "norm": FmkAdp.tensor_norm(tensor), + } + + return features + + +def compare_dicts(dict1, dict2, path=''): + deleted = [] + added = [] + changed = [] + result = {} + + for key in dict1: + if key not in dict2: + deleted.append(f"[Deleted]: {path + key}") + result[key] = "[deleted]" + else: + if isinstance(dict1[key], dict) and isinstance(dict2[key], dict): + sub_deleted, sub_added, sub_changed, sub_result = compare_dicts( + dict1[key], dict2[key], path + key + '/') + deleted.extend(sub_deleted) + added.extend(sub_added) + changed.extend(sub_changed) + if sub_result: + result[key] = sub_result + elif dict1[key] != dict2[key]: + changed.append(f"[Changed]: {path + key} : {dict1[key]} -> {dict2[key]}") + result[key] = f"[changed]: {dict1[key]} -> {dict2[key]}" + for key in dict2: + if key not in dict1: + added.append(f"[Added]: {path + key}") + result[key] = "[added]" + return deleted, added, changed, result + + +def bytes_hash(obj: bytes): + hex_dig = hashlib.sha256(obj).hexdigest() + short_hash = int(hex_dig, 16) % (2 ** 16) + return short_hash + + +def update_dict(ori_dict, new_dict): + for key, value in new_dict.items(): + if key in ori_dict and ori_dict[key] != value: + if "values" in ori_dict.keys(): + ori_dict[key]["values"].append(new_dict[key]) + else: + ori_dict[key] = {"description": "duplicate_value", "values": [ori_dict[key], new_dict[key]]} + else: + ori_dict[key] = value + + +def process_pass_check(data): + if Const.CONFIG_CHECK_ERROR in data: + return Const.CONFIG_CHECK_ERROR + elif Const.CONFIG_CHECK_WARNING in data: + return Const.CONFIG_CHECK_WARNING + else: + return Const.CONFIG_CHECK_PASS diff --git a/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py b/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..cc90aaa94ddbcd539919385b58ac4253b0120e09 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py @@ -0,0 +1,258 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Dict, Any, Optional, Callable, Union, List, Tuple + +from msprobe.core.common.const import Const +from msprobe.core.common.file_utils import load_yaml +from msprobe.core.common.log import logger + + +def _get_attr(module, attr_name): + if Const.SEP in attr_name: + sub_module_name, sub_attr = attr_name.rsplit(Const.SEP, 1) + sub_module = getattr(module, sub_module_name, None) + attr = getattr(sub_module, sub_attr, None) + else: + attr = getattr(module, attr_name, None) + return attr + + +class ApiWrapper: + def __init__( + self, api_types: Dict[str, Dict[str, Any]], + api_list_paths: Union[str, List[str], Tuple[str]], + blacklist: Union[List[str], Tuple[str]] = None + ): + self.api_types = api_types + if not isinstance(api_list_paths, (list, tuple)): + api_list_paths = [api_list_paths] * len(self.api_types) + elif len(api_list_paths) != len(self.api_types): + raise RuntimeError("The number of api_list_paths must be equal to the number of frameworks in 'api_types', " + "when api_list_paths is a list or tuple.") + self.api_list_paths = api_list_paths + self.blacklist = blacklist if blacklist else [] + self.api_names = self._get_api_names() + self.wrapped_api_functions = dict() + + @staticmethod + def deal_with_self_kwargs(api_name, api_func, args, kwargs): + if kwargs and 'self' in kwargs: + func_params = None + try: + func_params = inspect.signature(api_func).parameters + except Exception: + if api_name in Const.API_WITH_SELF_ARG: + func_params = inspect.signature(Const.API_WITH_SELF_ARG.get(api_name)).parameters + if func_params is None: + return False, args, kwargs + + for name, param in func_params.items(): + if name == 'self' and param.kind == inspect.Parameter.KEYWORD_ONLY: + return False, args, kwargs + args_ = list(args) + names_and_values = [] + self_index = 0 + for i, item in enumerate(func_params.items()): + names_and_values.append((item[0], item[1].default)) + if item[0] == 'self': + self_index = i + break + for i in range(len(args), self_index + 1): + if names_and_values[i][0] in kwargs: + args_.append(kwargs.pop(names_and_values[i][0])) + else: + args_.append(names_and_values[i][1]) + args = tuple(args_) + + return True, args, kwargs + + def wrap_api_func(self, api_name, api_func, prefix, hook_build_func, api_template): + api_instance = api_template(api_name, api_func, prefix, hook_build_func) + + def api_function(*args, **kwargs): + api_name_with_prefix = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) + enable_wrap, args, kwargs = self.deal_with_self_kwargs(api_name_with_prefix, api_func, args, kwargs) + if not enable_wrap: + logger.warning(f'Cannot collect precision data of {api_name_with_prefix}. ' + 'It may be fixed by passing the value of "self" ' + 'as a positional argument instead of a keyword argument. ') + return api_func(*args, **kwargs) + return api_instance(*args, **kwargs) + + for attr_name in Const.API_ATTR_LIST: + if hasattr(api_func, attr_name): + attr_value = getattr(api_func, attr_name) + setattr(api_function, attr_name, attr_value) + + return api_function + + def wrap_api( + self, api_templates, hook_build_func: Optional[Callable] + ): + api_types_num = sum([len(v) for v in self.api_types.values()]) + if not isinstance(api_templates, (list, tuple)): + api_templates = [api_templates] * api_types_num + elif len(api_templates) != api_types_num: + raise RuntimeError("The number of api_templates must be equal to the number of api_types, " + "when api_templates is a list or tuple.") + + self.wrapped_api_functions.clear() + index = 0 + for framework, api_types in self.api_types.items(): + wrapped_functions_in_framework = dict() + for api_type, api_modules in api_types.items(): + wrapped_functions = dict() + name_prefix = Const.API_DATA_PREFIX.get(framework, {}).get(api_type, "API") + api_template = api_templates[index] + index += 1 + for api_name in self.api_names.get(framework, {}).get(api_type, []): + ori_api = None + for module in api_modules[0]: + ori_api = ori_api or _get_attr(module, api_name) + if callable(ori_api): + wrapped_functions[api_name] = self.wrap_api_func( + api_name, + ori_api, + name_prefix, + hook_build_func, + api_template + ) + wrapped_functions_in_framework[api_type] = wrapped_functions + self.wrapped_api_functions[framework] = wrapped_functions_in_framework + return self.wrapped_api_functions + + def _get_api_names(self): + api_names = dict() + + for index, framework in enumerate(self.api_types.keys()): + api_list = load_yaml(self.api_list_paths[index]) + valid_names = dict() + for api_type, api_modules in self.api_types.get(framework, {}).items(): + key_in_file = Const.SUPPORT_API_DICT_KEY_MAP.get(framework, {}).get(api_type) + api_from_file = api_list.get(key_in_file, []) + names = set() + for api_name in api_from_file: + if f'{key_in_file}.{api_name}' in self.blacklist: + continue + target_attr = api_name + for module in api_modules[0]: + if Const.SEP in api_name: + sub_module_name, target_attr = api_name.rsplit(Const.SEP, 1) + target_module = getattr(module, sub_module_name, None) + else: + target_module = module + if target_module and target_attr in dir(target_module): + names.add(api_name) + valid_names[api_type] = names + api_names[framework] = valid_names + + return api_names + + +class ApiRegistry: + """ + Base class for api registry. + """ + + def __init__(self, api_types, inner_used_api, supported_api_list_path, api_templates, blacklist=None): + self.ori_api_attr = dict() + self.wrapped_api_attr = dict() + self.inner_used_ori_attr = dict() + self.inner_used_wrapped_attr = dict() + self.api_types = api_types + self.inner_used_api = inner_used_api + self.supported_api_list_path = supported_api_list_path + self.api_templates = api_templates + self.blacklist = blacklist if blacklist else [] + self.all_api_registered = False + + @staticmethod + def store_ori_attr(ori_api_groups, api_list, api_ori_attr): + for api in api_list: + ori_api = None + for ori_api_group in ori_api_groups: + ori_api = ori_api or _get_attr(ori_api_group, api) + api_ori_attr[api] = ori_api + + @staticmethod + def set_api_attr(api_group, attr_dict): + for api, api_attr in attr_dict.items(): + if Const.SEP in api: + sub_module_name, sub_op = api.rsplit(Const.SEP, 1) + sub_module = getattr(api_group, sub_module_name, None) + if sub_module is not None: + setattr(sub_module, sub_op, api_attr) + else: + setattr(api_group, api, api_attr) + + @staticmethod + def register_custom_api(module, api_name, api_prefix, hook_build_func, api_template): + def wrap_api_func(api_name, api_func, prefix, hook_build_func, api_template): + def api_function(*args, **kwargs): + return api_template(api_name, api_func, prefix, hook_build_func)(*args, **kwargs) + + api_function.__name__ = api_name + return api_function + + setattr(module, api_name, + wrap_api_func(api_name, getattr(module, api_name), api_prefix, hook_build_func, api_template)) + + def register_all_api(self): + self.all_api_registered = True + for framework, api_types in self.api_types.items(): + for api_type, api_modules in api_types.items(): + api_type_with_framework = framework + Const.SEP + api_type + for module in api_modules[1]: + self.set_api_attr(module, self.wrapped_api_attr.get(api_type_with_framework, {})) + + def register_inner_used_api(self): + for api_type in self.inner_used_api.keys(): + self.set_api_attr(self.inner_used_api.get(api_type)[0], self.inner_used_wrapped_attr.get(api_type, {})) + + def restore_all_api(self): + self.all_api_registered = False + for framework, api_types in self.api_types.items(): + for api_type, api_modules in api_types.items(): + api_type_with_framework = framework + Const.SEP + api_type + for module in api_modules[1]: + self.set_api_attr(module, self.ori_api_attr.get(api_type_with_framework, {})) + + def restore_inner_used_api(self): + for api_type in self.inner_used_api.keys(): + self.set_api_attr(self.inner_used_api.get(api_type)[0], self.inner_used_ori_attr.get(api_type, {})) + + def initialize_hook(self, hook_build_func): + api_wrapper = ApiWrapper(self.api_types, self.supported_api_list_path, self.blacklist) + wrapped_api_functions = api_wrapper.wrap_api(self.api_templates, hook_build_func) + + for framework, api_types in self.api_types.items(): + for api_type, api_modules in api_types.items(): + ori_attr = dict() + self.store_ori_attr(api_modules[0], api_wrapper.api_names.get(framework).get(api_type), ori_attr) + api_type_with_framework = framework + Const.SEP + api_type + self.ori_api_attr[api_type_with_framework] = ori_attr + self.wrapped_api_attr[api_type_with_framework] = wrapped_api_functions.get(framework).get(api_type) + + for inner_used_api_type, inner_used_api_list in self.inner_used_api.items(): + ori_attr = dict() + wrapped_attr = dict() + for api_name in inner_used_api_list[1:]: + if self.ori_api_attr.get(inner_used_api_type, {}).get(api_name): + ori_attr[api_name] = self.ori_api_attr.get(inner_used_api_type).get(api_name) + wrapped_attr[api_name] = self.wrapped_api_attr.get(inner_used_api_type).get(api_name) + self.inner_used_ori_attr[inner_used_api_type] = ori_attr + self.inner_used_wrapped_attr[inner_used_api_type] = wrapped_attr diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py index 20e4489f89e4bd345595e6a1db1e39ab427d4908..b180030ca872a1824494a57ccab2ff64fd46fb5f 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py @@ -15,6 +15,8 @@ import atexit import os +import threading +import traceback from msprobe.core.data_dump.scope import ScopeFactory from msprobe.core.data_dump.json_writer import DataWriter @@ -39,9 +41,10 @@ class DataCollector: self.module_count = {} self.scope = ScopeFactory(self.config).build_scope() self.backward_module_names = {} + self.params_grad_record = {} self.optimizer_status = "" self.optimizer_status_first_start = {Const.OPTIMIZER: True, Const.CLIP_GRAD: True} - atexit.register(self.write_json) + atexit.register(self.write_json_at_exit) @property def dump_data_dir(self): @@ -78,6 +81,11 @@ class DataCollector: def write_json(self): self.data_writer.write_json() + def write_json_at_exit(self): + if self.config.async_dump and self.config.task == Const.TENSOR: + self.data_processor.dump_async_data() + self.data_writer.write_json() + def update_data(self, name, data_info): msg = f"msprobe is collecting data on {name}." if self.config.task == Const.OVERFLOW_CHECK: @@ -89,88 +97,174 @@ class DataCollector: logger.debug(msg) self.data_writer.update_data(data_info) - def forward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None): - if self.config.task == Const.FREE_BENCHMARK: - backward_name = name.replace(Const.FORWARD, Const.BACKWARD) - if self.check_scope_and_pid(self.scope, backward_name, pid): - self.data_processor.analyze_forward_input(backward_name, module, module_input_output) - return - - if not self.check_scope_and_pid(self.scope, name, pid): - return + def call_stack_collect(self, name): + stack_info = self.data_processor.analyze_api_call_stack(name) + self.data_writer.update_stack(name, stack_info) - data_info = {} - if self.config.task != Const.STRUCTURE: - data_info = self.data_processor.analyze_forward_input(name, module, module_input_output) - self.set_is_recomputable(data_info, is_recompute) - if self.config.level == Const.LEVEL_L2: - return - self.handle_data(name, data_info, flush=self.data_processor.is_terminated) + def forward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None): + try: + + if self.config.task == Const.FREE_BENCHMARK: + backward_name = name.replace(Const.FORWARD, Const.BACKWARD) + if self.check_scope_and_pid(self.scope, backward_name, pid): + self.data_processor.analyze_forward_input(backward_name, module, module_input_output) + return + + if not self.check_scope_and_pid(self.scope, name, pid): + return + + data_info = {} + if self.config.task != Const.STRUCTURE: + data_info = self.data_processor.analyze_forward_input(name, module, module_input_output) + self.set_is_recomputable(data_info, is_recompute) + if self.config.level == Const.LEVEL_L2: + return + self.call_stack_collect(name) + self.handle_data(name, data_info, flush=self.data_processor.is_terminated) + + except Exception as e: + # 取异常类名作为“类型”做去重 + error_type = type(e).__name__ + tb = traceback.format_exc() + self.data_writer.write_error_log( + f"[ERROR] forward_input_data_collect failed: name={name}, pid={pid}\n{tb}", + error_type=error_type + ) def forward_output_data_collect(self, name, module, pid, module_input_output, is_recompute=None): - self.update_construct(name) - if not self.check_scope_and_pid(self.scope, name, pid): - return - - data_info = {} - if self.config.task != Const.STRUCTURE: - data_info = self.data_processor.analyze_forward_output(name, module, module_input_output) - self.set_is_recomputable(data_info, is_recompute) - if self.config.level == Const.LEVEL_L2: - return - self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name)) - self.handle_data(name, data_info, flush=self.data_processor.is_terminated) + try: + + self.update_construct(name) + if not self.check_scope_and_pid(self.scope, name, pid): + return + + data_info = {} + if self.config.task != Const.STRUCTURE: + data_info = self.data_processor.analyze_forward_output(name, module, module_input_output) + self.set_is_recomputable(data_info, is_recompute) + if self.config.level == Const.LEVEL_L2: + return + self.handle_data(name, data_info, flush=self.data_processor.is_terminated) + + except Exception as e: + # 取异常类名作为“类型”做去重 + error_type = type(e).__name__ + tb = traceback.format_exc() + self.data_writer.write_error_log( + f"[ERROR] forward_output_data_collect failed: name={name}, pid={pid}\n{tb}", + error_type=error_type + ) + + def forward_data_collect_only_tensor(self, name, module, pid, module_input_output): + try: + if not self.check_scope_and_pid(self.scope, name, pid): + return + self.data_processor.analyze_forward(name, module, module_input_output) + + except Exception as e: + # 取异常类名作为“类型”做去重 + error_type = type(e).__name__ + tb = traceback.format_exc() + self.data_writer.write_error_log( + f"[ERROR] forward_data_collect_only_tensor failed: name={name}, pid={pid}\n{tb}", + error_type=error_type + ) def forward_data_collect(self, name, module, pid, module_input_output, is_recompute=None): - self.update_construct(name) - if not self.check_scope_and_pid(self.scope, name, pid): - return - - data_info = {} - if self.config.task != Const.STRUCTURE: - data_info = self.data_processor.analyze_forward(name, module, module_input_output) - self.set_is_recomputable(data_info, is_recompute) - self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name)) - self.handle_data(name, data_info, flush=self.data_processor.is_terminated) + try: + + self.update_construct(name) + if not self.check_scope_and_pid(self.scope, name, pid): + return + data_info = {} + if self.config.task != Const.STRUCTURE: + data_info = self.data_processor.analyze_forward(name, module, module_input_output) + self.set_is_recomputable(data_info, is_recompute) + self.call_stack_collect(name) + self.handle_data(name, data_info, flush=self.data_processor.is_terminated) + + except Exception as e: + error_type = type(e).__name__ + tb = traceback.format_exc() + self.data_writer.write_error_log( + f"[ERROR] forward_data_collect failed: name={name}, pid={pid}\n{tb}", + error_type=error_type + ) + + def backward_data_collect_only_tensor(self, name, module, pid, module_input_output, is_recompute=None): + try: + if not self.check_scope_and_pid(self.scope, name, pid): + return + self.data_processor.analyze_backward(name, module, module_input_output) + + except Exception as e: + error_type = type(e).__name__ + tb = traceback.format_exc() + self.data_writer.write_error_log( + f"[ERROR] backward_data_collect_only_tensor failed: name={name}, pid={pid}\n{tb}", + error_type=error_type + ) def backward_data_collect(self, name, module, pid, module_input_output, is_recompute=None): - self.update_construct(name) - if not self.check_scope_and_pid(self.scope, name, pid): - return - - data_info = {} - if self.config.task != Const.STRUCTURE: - data_info = self.data_processor.analyze_backward(name, module, module_input_output) - if self.config.level == Const.LEVEL_L2: - return - # 获取执行反向的模块名称 - if data_info and name.split(Const.SEP)[0] in Const.MODULE_PREFIX: - module_name = name.rsplit(Const.SEP, 2)[0] - # 将模块名称加入到反向模块名称集合中,用于梯度收集时判断是否需要收集梯度 - self.backward_module_names[module_name] = True - self.handle_data(name, data_info, flush=self.data_processor.is_terminated) + try: + self.update_construct(name) + if not self.check_scope_and_pid(self.scope, name, pid): + return + data_info = {} + if self.config.task != Const.STRUCTURE: + data_info = self.data_processor.analyze_backward(name, module, module_input_output) + if self.config.level == Const.LEVEL_L2: + return + if data_info and name.split(Const.SEP)[0] in Const.MODULE_PREFIX: + module_name = name.rsplit(Const.SEP, 2)[0] + self.backward_module_names[module_name] = True + self.handle_data(name, data_info, flush=self.data_processor.is_terminated) + + except Exception as e: + error_type = type(e).__name__ + tb = traceback.format_exc() + self.data_writer.write_error_log( + f"[ERROR] backward_data_collect failed: name={name}, pid={pid}\n{tb}", + error_type=error_type + ) def backward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None): - self.update_construct(name) - if not self.check_scope_and_pid(self.scope, name, pid): - return - - data_info = {} - if self.config.task != Const.STRUCTURE: - data_info = self.data_processor.analyze_backward_input(name, module, module_input_output) - self.set_is_recomputable(data_info, is_recompute) - self.handle_data(name, data_info) + try: + self.update_construct(name) + if not self.check_scope_and_pid(self.scope, name, pid): + return + data_info = {} + if self.config.task != Const.STRUCTURE: + data_info = self.data_processor.analyze_backward_input(name, module, module_input_output) + self.set_is_recomputable(data_info, is_recompute) + self.handle_data(name, data_info) + + except Exception as e: + error_type = type(e).__name__ + tb = traceback.format_exc() + self.data_writer.write_error_log( + f"[ERROR] backward_input_data_collect failed: name={name}, pid={pid}\n{tb}", + error_type=error_type + ) def backward_output_data_collect(self, name, module, pid, module_input_output, is_recompute=None): - self.update_construct(name) - if not self.check_scope_and_pid(self.scope, name, pid): - return - - data_info = {} - if self.config.task != Const.STRUCTURE: - data_info = self.data_processor.analyze_backward_output(name, module, module_input_output) - self.set_is_recomputable(data_info, is_recompute) - self.handle_data(name, data_info) + try: + self.update_construct(name) + if not self.check_scope_and_pid(self.scope, name, pid): + return + data_info = {} + if self.config.task != Const.STRUCTURE: + data_info = self.data_processor.analyze_backward_output(name, module, module_input_output) + self.set_is_recomputable(data_info, is_recompute) + self.handle_data(name, data_info) + + except Exception as e: + error_type = type(e).__name__ + tb = traceback.format_exc() + self.data_writer.write_error_log( + f"[ERROR] backward_output_data_collect failed: name={name}, pid={pid}\n{tb}", + error_type=error_type + ) def update_construct(self, name): if self.config.level not in DataCollector.level_without_construct: @@ -180,7 +274,12 @@ class DataCollector: self.optimizer_status_first_start[self.optimizer_status] = False self.data_writer.update_construct({name: self.optimizer_status}) else: - self.data_writer.update_construct({name: self.module_processor.api_parent_node}) + if self.config.level == Const.LEVEL_MIX and \ + not (name.startswith(Const.MODULE) or name.startswith(Const.CELL)): + self.data_writer.update_construct( + {name: self.module_processor.api_parent_node.get(threading.get_ident())} + ) + self.data_writer.update_construct(self.module_processor.module_node) def handle_data(self, name, data_info, flush=False): @@ -203,28 +302,58 @@ class DataCollector: self.data_processor.update_iter(current_iter) def params_data_collect(self, name, param_name, pid, data): - grad_name = name + Const.SEP + Const.PARAMS_GRAD - # 校验scope和pid,以及当前name是否有过反向计算 - if not self.check_scope_and_pid(self.scope, name, pid) and not self.backward_module_names.get(name): - # 如果没有反向计算,则需要清除之前占位写入的grad数据 - if self.data_writer.cache_data.get("data"): - self.data_writer.cache_data.get("data").pop(grad_name, None) - return - data_info = self.data_processor.analyze_params(grad_name, param_name, data) - self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated) - - def fill_stack_tensor_data(self): - self.data_writer.fill_stack_tensor_data() + try: + grad_name = name + Const.SEP + Const.PARAMS_GRAD + self.update_api_or_module_name(grad_name) + if not self.check_scope_and_pid(self.scope, name, pid) and not self.backward_module_names.get(name): + if self.data_writer.cache_data.get("data"): + self.data_writer.cache_data.get("data").pop(grad_name, None) + self.params_grad_record[grad_name] = False + return + data_info = self.data_processor.analyze_params(grad_name, param_name, data) + self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated) + self.params_grad_record[grad_name] = False + except Exception as e: + error_type = type(e).__name__ + tb = traceback.format_exc() + self.data_writer.write_error_log( + f"[ERROR] params_data_collect failed: " + f"name={name}, param_name={param_name}, pid={pid}\n{tb}", + error_type=error_type + ) + + def params_data_collect_in_bw_hook(self, params_dict, name): + try: + if not params_dict: + return + ori_name = name.rsplit(Const.SEP, 2)[0] + for param_name, param in params_dict.items(): + grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD + self.update_api_or_module_name(grad_name) + if self.params_grad_record.get(grad_name, False): + grad = param.grad if hasattr(param, "grad") else None + data_info = self.data_processor.analyze_params(grad_name, param_name, grad) + self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated) + except Exception as e: + error_type = type(e).__name__ + tb = traceback.format_exc() + self.data_writer.write_error_log( + f"[ERROR] params_data_collect_in_bw_hook failed: " + f"name={name}", + error_type=error_type + ) def debug_data_collect_forward(self, variable, name_with_count): - data_info = self.data_processor.analyze_debug_forward(variable, name_with_count) - self.data_writer.update_debug({name_with_count: data_info}) + name_with_count_category = name_with_count + Const.SEP + Const.DEBUG + self.data_writer.update_debug({name_with_count_category: data_info}) def debug_data_collect_backward(self, variable, grad_name_with_count): # prepare all None nested data structure all_none_data_info = self.data_processor.analyze_element_to_all_none(variable) - self.data_writer.update_debug({grad_name_with_count: all_none_data_info}) + grad_name_with_count_category = grad_name_with_count + Const.SEP + Const.DEBUG + self.data_writer.update_debug({grad_name_with_count_category: all_none_data_info}) # register tensor backward hook - self.data_processor.analyze_debug_backward(variable, grad_name_with_count, self.data_writer.cache_debug['data']) + self.data_processor.analyze_debug_backward(variable, grad_name_with_count_category, + self.data_writer.cache_debug['data']) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py index 775a80b2418ef356867228b4ca09fad8c86cce25..c43dc9deee7b7d531e725a9b92ae8985cbaf4cad 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,17 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import inspect import os from dataclasses import dataclass, is_dataclass -from typing import Tuple, Dict, Optional, Any from functools import partial -import copy -from typing import Union +from typing import Tuple, Dict, Optional, Any, Union import numpy as np from msprobe.core.common.const import Const +from msprobe.core.common.file_utils import save_npy from msprobe.core.common.log import logger from msprobe.core.common.utils import convert_tuple, CompareException @@ -79,25 +79,23 @@ class ModuleBackwardOutputs: class TensorStatInfo: - def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None, stack_tensor_stat=None): + def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None): self.max = max_val self.min = min_val self.mean = mean_val self.norm = norm_val - self.stack_tensor_stat = stack_tensor_stat class BaseDataProcessor: _recursive_key_stack = [] - special_type = ( - np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_, np.ndarray, - bool, int, float, str, slice, - type(Ellipsis) - ) + builtin_type = (bool, int, float, str, slice, type(Ellipsis)) + np_type = (np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_, np.ndarray) def __init__(self, config, data_writer): self.data_writer = data_writer self.config = config + if self.data_writer is not None: + self.data_writer.config = config self.api_info_struct = {} self.stack_info_struct = {} self.current_api_or_module_name = None @@ -120,7 +118,10 @@ class BaseDataProcessor: @staticmethod def analyze_api_call_stack(name): try: - api_stack = inspect.stack()[5:] + if name.startswith("Primitive"): + api_stack = inspect.stack()[4:] + else: + api_stack = inspect.stack()[5:] except Exception as e: logger.warning(f"The call stack of <{name}> failed to retrieve, {e}.") api_stack = None @@ -129,12 +130,14 @@ class BaseDataProcessor: for (_, path, line, func, code, _) in api_stack: if not code: continue + if any(filter_path in path for filter_path in Const.STACK_FILTER_KEYWORDS) and \ + Const.CALL_STACK_FLAG not in path: + continue stack_line = f"File {path}, line {str(line)}, in {func}, \n {code[0].strip()}" stack_str.append(stack_line) else: stack_str.append(Const.WITHOUT_CALL_STACK) - stack_info_struct = {name: stack_str} - return stack_info_struct + return tuple(stack_str) @staticmethod def transfer_type(data): @@ -178,20 +181,8 @@ class BaseDataProcessor: "invalid data_structure type or invalid index") @staticmethod - def _convert_numpy_to_builtin(arg): - type_mapping = { - np.integer: int, - np.floating: float, - np.bool_: bool, - np.complexfloating: complex, - np.str_: str, - np.byte: bytes, - np.unicode_: str - } - for numpy_type, builtin_type in type_mapping.items(): - if isinstance(arg, numpy_type): - return builtin_type(arg), type(arg).__name__ - return arg, '' + def is_distributed_op(module): + return getattr(module, "op_is_distributed", False) @staticmethod def _analyze_builtin(arg): @@ -217,21 +208,40 @@ class BaseDataProcessor: return single_arg @staticmethod - def _analyze_numpy(ndarray, numpy_type): + def _analyze_numpy(arg): + return {"type": type(arg).__name__, "value": arg.item()} + + @staticmethod + def _analyze_ndarray(ndarray, _): ndarray_json = {} ndarray_json.update({'type': 'numpy.ndarray'}) ndarray_json.update({'dtype': str(ndarray.dtype)}) ndarray_json.update({'shape': ndarray.shape}) - if ndarray.size > 0: - ndarray_json.update({"Max": np.max(ndarray).item()}) - ndarray_json.update({"Min": np.min(ndarray).item()}) - ndarray_json.update({"Mean": np.mean(ndarray).item()}) - ndarray_json.update({"Norm": np.linalg.norm(ndarray).item()}) - else: - ndarray_json.update({"Max": None}) - ndarray_json.update({"Min": None}) - ndarray_json.update({"Mean": None}) - ndarray_json.update({"Norm": None}) + + # 先初始化默认值 + stats = { + "Max": None, + "Min": None, + "Mean": None, + "Norm": None + } + + try: + # 只有非空时才尝试计算 + if ndarray.size > 0: + stats = { + "Max": np.max(ndarray).item(), + "Min": np.min(ndarray).item(), + "Mean": np.mean(ndarray).item(), + "Norm": np.linalg.norm(ndarray).item() + } + except Exception as e: + # 决定打印内容或切片 + logger.warning(f"Error analyzing ndarray stats: {e}") + + # 最后一次性更新 + ndarray_json.update(stats) + return ndarray_json @staticmethod @@ -248,12 +258,12 @@ class BaseDataProcessor: @classmethod def get_special_types(cls): - return cls.special_type + return cls.builtin_type + cls.np_type @classmethod def recursive_apply_transform(cls, args, transform, depth=0) -> Union[dict, list, None]: - if depth > Const.MAX_DEPTH: - logger.error(f"The maximum depth of recursive transform, {Const.MAX_DEPTH} is reached.") + if depth > Const.DUMP_MAX_DEPTH: + logger.error(f"The maximum depth of recursive transform, {Const.DUMP_MAX_DEPTH} is reached.") raise CompareException(CompareException.RECURSION_LIMIT_ERROR) if isinstance(args, cls.get_special_types()): arg_transform = transform(args, cls._recursive_key_stack) @@ -303,6 +313,7 @@ class BaseDataProcessor: def real_hook_fn(grad): return wrap_hook_fn(grad) + element.register_hook(real_hook_fn) def if_return_forward_new_output(self): @@ -350,6 +361,8 @@ class BaseDataProcessor: return api_info_struct def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs): + if self.is_distributed_op(module): + module_input_output.update_output_with_args_and_kwargs() api_info_struct = {} # check whether data_mode contains forward or input if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT): @@ -427,6 +440,7 @@ class BaseDataProcessor: api_info_struct = {} self.save_name = name + Const.SEP + param_name data_info = self.analyze_element(grad) + self.save_name = None grad_info_dict = {param_name: [data_info]} api_info_struct[name] = grad_info_dict return api_info_struct @@ -435,10 +449,10 @@ class BaseDataProcessor: file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX if self.save_name is not None: dump_data_name = (self.save_name + file_format) - self.save_name = None else: - dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP + - suffix + file_format) + suffix_with_seq = (Const.SEP + suffix) if suffix else "" + dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + suffix_with_seq + + file_format) file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name) return dump_data_name, file_path @@ -447,23 +461,32 @@ class BaseDataProcessor: def analyze_debug_forward(self, variable, name_with_count): self.current_api_or_module_name = name_with_count - self.api_data_category = Const.TENSOR - # these two attributes are used to construct tensor file name {name_with_count}.tensor.{indexes}.npy/pt + self.api_data_category = Const.DEBUG + # these two attributes are used to construct tensor file name {name_with_count}.debug.{indexes}.npy/pt data_info = self.analyze_element(variable) return data_info - def analyze_debug_backward(self, variable, grad_name_with_count, nested_data_structure): + def analyze_debug_backward(self, variable, grad_name_with_count_category, nested_data_structure): def hook_fn(grad, indexes): suffix = Const.SEP.join([str(index) for index in indexes]) - self.save_name = grad_name_with_count + Const.SEP + Const.TENSOR + Const.SEP + suffix + suffix_with_sep = (Const.SEP + suffix) if suffix else "" + self.save_name = grad_name_with_count_category + suffix_with_sep grad_data_info = self.analyze_element(grad) self.save_name = None - full_index = [grad_name_with_count] + indexes + full_index = [grad_name_with_count_category] + indexes try: self.set_value_into_nested_structure(nested_data_structure, full_index, grad_data_info) except (ValueError, IndexError) as e: - logger.warning(f"error occured while recording statistics of {grad_name_with_count} variable, " - f"skip current recording, detailed infomation: {e}") + logger.warning(f"error occurred while recording statistics of {grad_name_with_count_category} variable," + f"skip current recording, detailed information: {e}") return grad + wrap_register_hook_single_element = partial(self.register_hook_single_element, hook_fn=hook_fn) - self.recursive_apply_transform(variable, wrap_register_hook_single_element) \ No newline at end of file + self.recursive_apply_transform(variable, wrap_register_hook_single_element) + + def _analyze_and_save_ndarray(self, ndarray, suffix): + dump_data_name, file_path = self.get_save_file_path(suffix) + save_npy(ndarray, file_path) + ndarray_json = BaseDataProcessor._analyze_ndarray(ndarray, suffix) + ndarray_json.update({"data_name": dump_data_name}) + return ndarray_json diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py index 8c4542a1917b76809aad21971e148ec17bd6045e..3a4d096667ece74f2f5c4ebef96f54f1e974af85 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py @@ -13,20 +13,24 @@ # limitations under the License. # ============================================================================ +import os import zlib +from concurrent.futures import ThreadPoolExecutor import mindspore as ms from mindspore import mint, ops, hal +from mindspore.mint import distributed from mindspore._c_expression.typing import Number import numpy as np from msprobe.core.common.const import Const from msprobe.core.data_dump.data_processor.base import (BaseDataProcessor, TensorStatInfo, ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs) -from msprobe.core.common.file_utils import path_len_exceeds_limit, save_npy +from msprobe.core.common.file_utils import path_len_exceeds_limit from msprobe.mindspore.common.utils import convert_bf16_to_fp32, save_tensor_as_npy from msprobe.mindspore.common.log import logger -from msprobe.mindspore.dump.hook_cell.api_registry import api_register +from msprobe.mindspore.dump.hook_cell.api_register import get_api_register +from msprobe.mindspore.common.utils import is_mindtorch has_adump = True try: @@ -34,9 +38,15 @@ try: except ImportError: has_adump = False +if is_mindtorch(): + from torch import distributed as dist + class MindsporeDataProcessor(BaseDataProcessor): - mindspore_special_type = tuple([ms.Tensor, Number]) + if is_mindtorch(): + mindspore_special_type = tuple([ms.Tensor, Number, distributed.P2POp, dist.ProcessGroup]) + else: + mindspore_special_type = tuple([ms.Tensor, Number, distributed.P2POp]) def __init__(self, config, data_writer): super().__init__(config, data_writer) @@ -44,6 +54,12 @@ class MindsporeDataProcessor(BaseDataProcessor): "dtype": self.analyze_dtype_in_kwargs } self._async_dump_cache = {} + self.api_register = get_api_register() + self._crc_executor = ThreadPoolExecutor(max_workers=os.cpu_count() // 2) + + @staticmethod + def compute_crc32_bytes(tensor_bytes): + return f"{zlib.crc32(tensor_bytes):08x}" @staticmethod def get_md5_for_tensor(x): @@ -57,100 +73,107 @@ class MindsporeDataProcessor(BaseDataProcessor): return {"type": "mindspore.dtype", "value": str(element)} @staticmethod - def get_stat_info_sync(data): - tensor_stat = TensorStatInfo() - if data.dtype == ms.bool_: - data_np = data.asnumpy() - tensor_stat.max = np.max(data_np).item() - tensor_stat.min = np.min(data_np).item() - elif not data.shape: - tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item() - elif data.dtype == ms.complex64 or data.dtype == ms.complex128: - data_abs = np.abs(data.asnumpy()) - tensor_stat.max = np.max(data_abs).item() - tensor_stat.min = np.min(data_abs).item() - tensor_stat.mean = np.mean(data_abs).item() - tensor_stat.norm = np.linalg.norm(data_abs).item() - else: - if not ops.is_floating_point(data) or data.dtype == ms.float64: - data = data.to(ms.float32) - api_register.norm_inner_op_set_ori_func() - get_max_value = api_register.mint_ops_ori_attr.get("max", mint.max) - get_min_value = api_register.mint_ops_ori_attr.get("min", mint.min) - get_mean_value = api_register.mint_ops_ori_attr.get("mean", mint.mean) - if hasattr(mint, "norm"): - get_norm_value = api_register.mint_ops_ori_attr.get("norm", mint.norm) - else: - get_norm_value = api_register.functional_ori_attr.get("norm", ops.norm) - tensor_stat.max = get_max_value(data).item() - tensor_stat.min = get_min_value(data).item() - tensor_stat.mean = get_mean_value(data).item() - tensor_stat.norm = get_norm_value(data).item() - api_register.norm_inner_op_set_hook_func() - return tensor_stat + def is_hookable_element(element): + return hasattr(element, "register_hook") and callable(element.register_hook) @staticmethod - def get_stat_info_async(data): - tensor_stat = TensorStatInfo() - stack_method = api_register.functional_ori_attr.get("stack", ms.ops.stack) - if data.dtype == ms.complex64 or data.dtype == ms.complex128: - logger.warning("Async dump do not support complex data!") - return tensor_stat - elif data.dtype == ms.bool_: - tensor_stat.stack_tensor_stat = (["Max", "Min"], stack_method([data.any(), data.all()])) - elif not data.shape: - tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], stack_method([data, data, data, data])) - else: - if not ops.is_floating_point(data) or data.dtype == ms.float64: - data = data.to(ms.float32) - api_register.norm_inner_op_set_ori_func() - get_max_value = api_register.mint_ops_ori_attr.get("max", mint.max) - get_min_value = api_register.mint_ops_ori_attr.get("min", mint.min) - get_mean_value = api_register.mint_ops_ori_attr.get("mean", mint.mean) - if hasattr(mint, "norm"): - get_norm_value = api_register.mint_ops_ori_attr.get("norm", mint.norm) - else: - get_norm_value = api_register.functional_ori_attr.get("norm", ops.norm) - tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], stack_method( - [get_max_value(data), get_min_value(data), get_mean_value(data), get_norm_value(data)])) - api_register.norm_inner_op_set_hook_func() - return tensor_stat + def process_group_hash(arg): + group_ranks = distributed.get_process_group_ranks(arg) + group_ranks_hash = zlib.crc32(str(group_ranks).encode('utf-8')) + return f"{group_ranks_hash:08x}" @staticmethod - def is_hookable_element(element): - return hasattr(element, "register_hook") and callable(element.register_hook) + def _analyze_process_group(arg): + group_info = {"type": "mindspore.ProcessGroup"} + try: + group_ranks = dist.get_process_group_ranks(arg) + group_info.update({"group_ranks": group_ranks}) + group_ranks_hash = zlib.crc32(str(group_ranks).encode('utf-8')) + group_id = f"{group_ranks_hash:08x}" + group_info.update({"group_id": group_id}) + except Exception as e: + logger.warning(f"Failed to get process group ranks info with error info: {e}.") + return group_info @classmethod def get_special_types(cls): return super().get_special_types() + cls.mindspore_special_type + def dump_async_data(self): + for file_path, tensor in self._async_dump_cache.items(): + save_tensor_as_npy(tensor, file_path) + self._async_dump_cache.clear() + def get_stat_info(self, data): + self.api_register.restore_inner_used_api() tensor_stat = TensorStatInfo() if data.numel() == 0: - return tensor_stat - else: + pass + elif data.dtype == ms.bool_: + if self.config.async_dump: + tensor_stat.max = mint.any(data) + tensor_stat.min = mint.all(data) + else: + data_np = data.asnumpy() + tensor_stat.max = np.max(data_np).item() + tensor_stat.min = np.min(data_np).item() + elif not data.shape: + tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.copy() + elif data.dtype == ms.complex64 or data.dtype == ms.complex128: if self.config.async_dump: - return MindsporeDataProcessor.get_stat_info_async(data) + logger.warning("Async dump do not support complex data!") else: - return MindsporeDataProcessor.get_stat_info_sync(data) + data_abs = np.abs(data.asnumpy()) + tensor_stat.max = np.max(data_abs).item() + tensor_stat.min = np.min(data_abs).item() + tensor_stat.mean = np.mean(data_abs).item() + tensor_stat.norm = np.linalg.norm(data_abs).item() + else: + if self.config.precision == Const.DUMP_PRECISION_HIGH or not ops.is_floating_point( + data) or data.dtype == ms.float64: + data = data.to(ms.float32) + get_norm_value = mint.norm if hasattr(mint, "norm") else ops.norm + tensor_stat.max = mint.max(data) + tensor_stat.min = mint.min(data) + tensor_stat.mean = mint.mean(data) + tensor_stat.norm = get_norm_value(data) + self.api_register.register_inner_used_api() + return tensor_stat def analyze_single_element(self, element, suffix_stack): if suffix_stack and suffix_stack[-1] in self.mindspore_object_key: return self.mindspore_object_key[suffix_stack[-1]](element) - converted_numpy, numpy_type = self._convert_numpy_to_builtin(element) - if converted_numpy is not element: - return {"type": numpy_type, "value": converted_numpy} - if isinstance(element, Number): - return self.analyze_dtype_in_kwargs(element) - if isinstance(element, ms.Tensor): - return self._analyze_tensor(element, Const.SEP.join([str(suffix) for suffix in suffix_stack])) - if isinstance(element, np.ndarray): - return self._analyze_numpy(element, Const.SEP.join([str(suffix) for suffix in suffix_stack])) - if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))): - return self._analyze_builtin(element) + suffix_str = Const.SEP.join(str(s) for s in suffix_stack) + type_analyzer = [ + (MindsporeDataProcessor.builtin_type, self._analyze_builtin), + (ms.Tensor, lambda e: self._analyze_tensor(e, suffix_str)), + (Number, self.analyze_dtype_in_kwargs), + (MindsporeDataProcessor.np_type[:-1], self._analyze_numpy), + (np.ndarray, lambda e: self._analyze_ndarray(e, suffix_str)), + (distributed.P2POp, lambda e: self._analyze_p2pop(e, suffix_str)) + ] + if is_mindtorch(): + type_analyzer.append((dist.ProcessGroup, self._analyze_process_group)) + for type_key, analyze_fn in type_analyzer: + if isinstance(element, type_key): + return analyze_fn(element) return {} + def _analyze_p2pop(self, arg, suffix): + p2pop_info = {"class_type": "mindspore.mint.distributed.P2POp"} + try: + tensor_info = self._analyze_tensor(arg.tensor, suffix) + p2pop_info.update({"tensor": tensor_info}) + p2pop_info.update({"op": arg.op}) + p2pop_info.update({"peer": arg.peer}) + p2pop_info.update({"tag": arg.tag}) + group_id = self.process_group_hash(arg.group) if arg.group else None + p2pop_info.update({"group_id": group_id}) + except Exception as e: + logger.warning(f"Failed to parse the P2POp content with error info: {e}.") + return p2pop_info + def _analyze_tensor(self, tensor, suffix): tensor_stat = self.get_stat_info(tensor) tensor_json = { @@ -159,45 +182,64 @@ class MindsporeDataProcessor(BaseDataProcessor): 'shape': tensor.shape } - if tensor_stat.stack_tensor_stat is None: - tensor_json.update({'Max': self.transfer_type(tensor_stat.max)}) - tensor_json.update({'Min': self.transfer_type(tensor_stat.min)}) - tensor_json.update({'Mean': self.transfer_type(tensor_stat.mean)}) - tensor_json.update({'Norm': self.transfer_type(tensor_stat.norm)}) - else: - tensor_json.update({'tensor_stat': tensor_stat.stack_tensor_stat}) - if self.config.summary_mode == Const.MD5 and not self.config.async_dump: - tensor_md5 = self.get_md5_for_tensor(tensor) - tensor_json.update({Const.MD5: tensor_md5}) - return tensor_json + # 将统计值存入全局 buffer,并返回占位索引 + stat_values = [ + tensor_stat.max, + tensor_stat.min, + tensor_stat.mean, + tensor_stat.norm + ] + placeholder_index = self.data_writer.append_stat_to_buffer(stat_values) -class StatisticsDataProcessor(MindsporeDataProcessor): - pass + tensor_json.update({Const.TENSOR_STAT_INDEX: placeholder_index}) + if self.config.summary_mode == Const.MD5 and not self.config.async_dump: + tensor = convert_bf16_to_fp32(tensor) + # 拷贝并搬到 CPU + tensor_bytes = tensor.asnumpy() -class TensorDataProcessor(MindsporeDataProcessor): - def dump_async_data(self): - for file_path, tensor in self._async_dump_cache.items(): - save_tensor_as_npy(tensor, file_path) - self._async_dump_cache.clear() + future = self._crc_executor.submit( + MindsporeDataProcessor.compute_crc32_bytes, + tensor_bytes + ) - def _analyze_tensor(self, tensor, suffix): + crc_placeholder = self.data_writer.append_crc32_to_buffer(future) + tensor_json[Const.MD5_INDEX] = crc_placeholder + + return tensor_json + + def _analyze_and_save_tensor(self, tensor, suffix): dump_data_name, file_path = self.get_save_file_path(suffix) - single_arg = super()._analyze_tensor(tensor, suffix) + single_arg = MindsporeDataProcessor._analyze_tensor(self, tensor, suffix) single_arg.update({"data_name": dump_data_name}) if self.config.async_dump: self._async_dump_cache[file_path] = tensor.copy() else: save_tensor_as_npy(tensor, file_path) return single_arg - - def _analyze_numpy(self, ndarray, suffix): - dump_data_name, file_path = self.get_save_file_path(suffix) - save_npy(ndarray, file_path) - ndarray_json = super()._analyze_numpy(ndarray, suffix) - ndarray_json.update({"data_name": dump_data_name}) - return ndarray_json + + +class StatisticsDataProcessor(MindsporeDataProcessor): + def _analyze_tensor(self, tensor, suffix): + if any(item in self.current_api_or_module_name for item in self.config.tensor_list): + return self._analyze_and_save_tensor(tensor, suffix) + else: + return super()._analyze_tensor(tensor, suffix) + + def _analyze_ndarray(self, ndarray, suffix): + if any(item in self.current_api_or_module_name for item in self.config.tensor_list): + return self._analyze_and_save_ndarray(ndarray, suffix) + else: + return super()._analyze_ndarray(ndarray, suffix) + + +class TensorDataProcessor(MindsporeDataProcessor): + def _analyze_tensor(self, tensor, suffix): + return self._analyze_and_save_tensor(tensor, suffix) + + def _analyze_ndarray(self, ndarray, suffix): + return self._analyze_and_save_ndarray(ndarray, suffix) class OverflowCheckDataProcessor(MindsporeDataProcessor): @@ -262,11 +304,26 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor): self.cached_tensors_and_file_paths = {} def _analyze_maybe_overflow_tensor(self, tensor_json): - if tensor_json['Max'] is None: + tensor_stat_index = tensor_json.get(Const.TENSOR_STAT_INDEX) + if tensor_stat_index is None: + logger.warning("tensor_stat_index does not exist in tensor_json.") + return + max_tensor = self.data_writer.get_buffer_values_max(tensor_stat_index) + min_tensor = self.data_writer.get_buffer_values_min(tensor_stat_index) + if max_tensor is None or min_tensor is None: return - if np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']): + + def check_inf_nan(value): + # Use .item() if it's a tensor-like structure + if hasattr(value, "item"): + value = value.item() + return np.isinf(value) or np.isnan(value) + + if check_inf_nan(max_tensor): self.has_overflow = True - if np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min']): + return + + if check_inf_nan(min_tensor): self.has_overflow = True def _analyze_tensor(self, tensor, suffix): diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py index 64253aa4260cab608e5ca84a5d006b28b94a33ab..9bef9ad2d84647df6aaa27256c71dccfe7732961 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py @@ -13,10 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import hashlib +import os import zlib +import ctypes +from collections.abc import Iterable from dataclasses import asdict from typing import List +from concurrent.futures import ThreadPoolExecutor import numpy as np import torch @@ -24,14 +27,15 @@ from torch import distributed as dist from torch.distributed.distributed_c10d import _get_default_group from msprobe.core.common.const import Const +from msprobe.core.common.decorator import recursion_depth_decorator +from msprobe.core.common.exceptions import MsprobeException from msprobe.core.common.file_utils import path_len_exceeds_limit from msprobe.core.common.log import logger -from msprobe.core.common.utils import convert_tuple +from msprobe.core.common.utils import convert_tuple, is_int from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \ ModuleForwardInputsOutputs, TensorStatInfo -from msprobe.pytorch.common.utils import save_pt, load_pt +from msprobe.pytorch.common.utils import save_pt from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow -from msprobe.core.common.utils import recursion_depth_decorator is_gpu = False try: @@ -40,6 +44,57 @@ except ImportError: is_gpu = True +class TensorHandler: + def __init__(self): + self.has_dtensor = hasattr(dist, "tensor") and hasattr(dist.tensor, "DTensor") + self.has_fake_tensor = hasattr(torch, "_subclasses") and hasattr(torch._subclasses, "fake_tensor") + + def is_dtensor(self, tensor): + return self.has_dtensor and isinstance(tensor, torch.distributed.tensor.DTensor) + + def is_fake_tensor(self, tensor): + return self.has_fake_tensor and isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor) + + def is_empty_data(self, tensor): + return tensor.is_meta or self.is_fake_tensor(tensor) + + def convert_common_tensor(self, tensor): + if self.is_dtensor(tensor): + return tensor.to_local() + if self.is_fake_tensor(tensor): + logger.debug("FakeTensor cannot be converted to torch.Tensor type.") + return tensor + return tensor + + def get_tensor_type(self, tensor): + if self.is_dtensor(tensor): + return Const.DTENSOR_TYPE + if self.is_fake_tensor(tensor): + return Const.FAKE_TENSOR_TYPE + return Const.TENSOR_TYPE + + def get_dtensor_info(self, tensor): + dtensor_info = {} + if not self.is_dtensor(tensor): + return dtensor_info + if hasattr(tensor, "device_mesh") and tensor.device_mesh: + dtensor_info.update({"device_mesh": tensor.device_mesh.mesh.tolist()}) + + placements = [] + if hasattr(tensor, "placements") and isinstance(tensor.placements, Iterable): + for placement in tensor.placements: + if placement.is_shard() and is_int(placement.dim): + placements.append({"Shard": {"dim": placement.dim}}) + continue + if placement.is_replicate(): + placements.append({"Replicate": {}}) + continue + if placement.is_partial() and isinstance(placement.reduce_op, str): + placements.append({"Partial": {"reduce_op": placement.reduce_op}}) + dtensor_info.update({"placements": placements}) + return dtensor_info + + class PytorchDataProcessor(BaseDataProcessor): pytorch_special_type = ( torch.device, @@ -65,6 +120,8 @@ class PytorchDataProcessor(BaseDataProcessor): "dtype": self.analyze_dtype_in_kwargs } self._async_dump_cache = {} + self.tensor_handler = TensorHandler() + self._crc_executor = ThreadPoolExecutor(max_workers=os.cpu_count() // 2) @staticmethod def get_md5_for_tensor(x): @@ -74,109 +131,89 @@ class PytorchDataProcessor(BaseDataProcessor): crc32_hash = zlib.crc32(tensor_bytes) return f"{crc32_hash:08x}" + @staticmethod + def tensor_bytes_view_cpu(t: torch.Tensor): + """ + 返回 t 在当前 dtype 下的原始字节视图(优先零拷贝)。 + 需保证:t 已在 CPU 且是 contiguous。 + 可能返回 memoryview 或 bytes(兜底拷贝)或者 转为numpy,均可被 zlib.crc32 接受。 + """ + + nbytes = t.numel() * t.element_size() + byte_offset = t.storage_offset() * t.element_size() + + if nbytes == 0: + return memoryview(b"") + + storage = t.untyped_storage() + + # ctypes 指针构造 memoryview(零拷贝 FFI) + try: + addr = storage.data_ptr() + byte_offset + buf = (ctypes.c_ubyte * nbytes).from_address(addr) + mv3 = memoryview(buf) + + return mv3 + except Exception as e1: + logger.warning(f"path_A_failed: {e1}.") + + try: + data = ctypes.string_at(storage.data_ptr() + byte_offset, nbytes) + + return data # bytes 也可直接用于 zlib.crc32 + except Exception as e2: + logger.warning(f"path_B_failed: {e2}.") + + try: + if t.dtype == torch.bfloat16: + t = t.float() + data = t.numpy() + + return data + except Exception as e3: + logger.warning(f"path_C_failed: {e3}.") + return memoryview(b"") + + @staticmethod + def compute_crc32_from_tensor(t: torch.Tensor) -> str: + """ + 直接对 Tensor 原始字节做 CRC32。 + : + - "raw": 保持 bfloat16 原始 16bit 字节(推荐,避免升精/增容) + """ + + # 取得字节视图(含多级回退),然后做 CRC + mv = PytorchDataProcessor.tensor_bytes_view_cpu(t) + + crc = zlib.crc32(mv) + + return f"{crc:08x}" + @staticmethod def analyze_device_in_kwargs(element): single_arg = {} single_arg.update({'type': "torch.device"}) - if not isinstance(element, str): + if isinstance(element, (int, str)): + single_arg.update({"value": element}) + elif isinstance(element, torch.device): if hasattr(element, "index"): device_value = element.type + ":" + str(element.index) else: device_value = element.type single_arg.update({"value": device_value}) else: - single_arg.update({"value": element}) + logger.debug(f"Device type {type(element)} is not supported.") return single_arg @staticmethod def analyze_dtype_in_kwargs(element): return {"type": "torch.dtype", "value": str(element)} - @staticmethod - def get_stat_info_async(data): - tensor_stat = TensorStatInfo() - if torch.is_complex(data): - logger.warning("Async dump do not support complex data!") - return tensor_stat - elif data.dtype == torch.bool: - tensor_stat.stack_tensor_stat = (["Max", "Min"], torch.stack( - [torch.any(data), torch.all(data)])) - elif not data.shape: - tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], torch.stack([data, data, data, data])) - else: - if not data.is_floating_point() or data.dtype == torch.float64: - data = data.float() - tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], torch.stack([ - torch.max(data), - torch.min(data), - torch.mean(data), - torch.norm(data) - ])) - return tensor_stat - - @staticmethod - def get_stat_info_sync(data): - tensor_stat = TensorStatInfo() - if torch.is_complex(data): - data_np = data.cpu().numpy() - data_abs = np.abs(data_np) - tensor_stat.max = np.max(data_abs).item() - tensor_stat.min = np.min(data_abs).item() - tensor_stat.mean = np.mean(data_abs).item() - elif data.dtype == torch.bool: - tensor_stat.max = torch.any(data).item() - tensor_stat.min = torch.all(data).item() - elif not data.shape: - tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item() - else: - if not data.is_floating_point() or data.dtype == torch.float64: - data = data.float() - tensor_stat.max = torch.max(data).item() - tensor_stat.min = torch.min(data).item() - tensor_stat.mean = torch.mean(data).item() - tensor_stat.norm = torch.norm(data).item() - return tensor_stat - - @staticmethod - def get_stat_info(data, async_dump=False): - tensor_stat = TensorStatInfo() - if data.is_meta: - return tensor_stat - data_clone = data.detach() - if data_clone.numel() == 0: - return tensor_stat - else: - if data_clone.device.type == Const.CPU_LOWERCASE or not async_dump: - return PytorchDataProcessor.get_stat_info_sync(data_clone) - else: - return PytorchDataProcessor.get_stat_info_async(data_clone) - - @staticmethod - def handle_tensor_extremum_nan_inf(tensor, operator): - data_clone = tensor.detach() - data_nan = torch.isnan(data_clone) - if int(torch.sum(data_nan)) == data_clone.numel(): - return float('nan') - - finite_mask = torch.isfinite(data_clone) - if int(torch.sum(finite_mask)) > 0: - finite_values = data_clone[finite_mask] - return torch.max(finite_values).item() if operator == 'max' else \ - torch.min(finite_values).item() - else: - data_no_nan = data_clone[~data_nan] - return torch.max(data_no_nan).item() if operator == 'max' else \ - torch.min(data_no_nan).item() - @staticmethod def process_group_hash(arg): group_ranks = dist.get_process_group_ranks(arg) - group_ranks_hash = hashlib.md5(str(group_ranks).encode('utf-8')).hexdigest() - return group_ranks_hash - - @staticmethod - def is_distributed_op(module): - return getattr(module, "op_is_distributed", False) + group_ranks_hash = zlib.crc32(str(group_ranks).encode('utf-8')) + return f"{group_ranks_hash:08x}" @staticmethod def is_hookable_element(element): @@ -185,7 +222,7 @@ class PytorchDataProcessor(BaseDataProcessor): @staticmethod def _analyze_torch_size(arg): - return {"type": "torch.Size", "value": list(arg)} + return {"type": "torch.Size", "value": [int(x) for x in list(arg)]} @staticmethod def _analyze_memory_format(arg): @@ -218,39 +255,67 @@ class PytorchDataProcessor(BaseDataProcessor): def get_special_types(cls): return super().get_special_types() + cls.pytorch_special_type + def get_stat_info(self, data, async_dump=False, precision=Const.DUMP_PRECISION_LOW): + tensor_stat = TensorStatInfo() + if self.tensor_handler.is_empty_data(data): + return tensor_stat + data_clone = data.detach() + if not data_clone.numel() or not data_clone.data_ptr(): + return tensor_stat + if torch.is_complex(data_clone): + if async_dump: + logger.warning("Async dump do not support complex data!") + return tensor_stat + data_np = data_clone.cpu().numpy() + data_abs = np.abs(data_np) + tensor_stat.max = np.max(data_abs).item() + tensor_stat.min = np.min(data_abs).item() + tensor_stat.mean = np.mean(data_abs).item() + elif data_clone.dtype == torch.bool: + tensor_stat.max = torch.any(data_clone) + tensor_stat.min = torch.all(data_clone) + elif not data_clone.shape: + tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data_clone.clone() + else: + if (precision == Const.DUMP_PRECISION_HIGH or data_clone.dtype == torch.float64 + or not data_clone.is_floating_point()): + data_clone = data_clone.float() + tensor_stat.max = torch.max(data_clone) + tensor_stat.min = torch.min(data_clone) + tensor_stat.mean = torch.mean(data_clone) + tensor_stat.norm = torch.norm(data_clone) + return tensor_stat + + def dump_async_data(self): + for file_path, tensor in self._async_dump_cache.items(): + save_pt(tensor.contiguous(), file_path) + self._async_dump_cache.clear() + def analyze_single_element(self, element, suffix_stack): if suffix_stack and suffix_stack[-1] in self.torch_object_key: return self.torch_object_key[suffix_stack[-1]](element) - if isinstance(element, torch.Size): - return self._analyze_torch_size(element) - if isinstance(element, torch.memory_format): - return self._analyze_memory_format(element) - if isinstance(element, dist.ProcessGroup): - return self._analyze_process_group(element) - if isinstance(element, dist.P2POp): - return self._analyze_p2pop(element) - if isinstance(element, dist.ReduceOp): - return self._analyze_reduce_op(element) - converted_numpy, numpy_type = self._convert_numpy_to_builtin(element) - if converted_numpy is not element: - return {"type": numpy_type, "value": converted_numpy} - if isinstance(element, torch.Tensor): - return self._analyze_tensor(element, Const.SEP.join([str(suffix) for suffix in suffix_stack])) - if isinstance(element, np.ndarray): - return self._analyze_numpy(element, Const.SEP.join([str(suffix) for suffix in suffix_stack])) - if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))): - return self._analyze_builtin(element) - return {} - def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs): - if self.is_distributed_op(module): - module_input_output.update_output_with_args_and_kwargs() - return super().analyze_forward_output(name, module, module_input_output) + suffix_str = Const.SEP.join(str(s) for s in suffix_stack) + type_analyzer = [ + (PytorchDataProcessor.builtin_type, self._analyze_builtin), + (torch.Size, self._analyze_torch_size), + (torch.Tensor, lambda e: self._analyze_tensor(e, suffix_str)), + (torch.memory_format, self._analyze_memory_format), + (dist.ProcessGroup, self._analyze_process_group), + (dist.P2POp, lambda e: self._analyze_p2pop(e, suffix_str)), + (dist.ReduceOp, self._analyze_reduce_op), + (PytorchDataProcessor.np_type[:-1], self._analyze_numpy), + (np.ndarray, lambda e: self._analyze_ndarray(e, suffix_str)), + ] + for type_key, analyze_fn in type_analyzer: + if isinstance(element, type_key): + return analyze_fn(element) + return {} - def _analyze_p2pop(self, arg): + def _analyze_p2pop(self, arg, suffix): p2pop_info = {"class_type": "torch.distributed.P2POp"} try: - tensor_info = self._analyze_tensor(arg.tensor, []) + tensor_info = self._analyze_tensor(arg.tensor, suffix) p2pop_info.update({"tensor": tensor_info}) p2pop_info.update({"op": arg.op.__name__}) p2pop_info.update({"peer": arg.peer}) @@ -263,47 +328,70 @@ class PytorchDataProcessor(BaseDataProcessor): return p2pop_info def _analyze_tensor(self, tensor, suffix): - tensor_stat = self.get_stat_info(tensor, self.config.async_dump) + common_tensor = self.tensor_handler.convert_common_tensor(tensor) + tensor_stat = self.get_stat_info(common_tensor, self.config.async_dump, self.config.precision) tensor_json = {} - tensor_json.update({'type': 'torch.Tensor'}) - tensor_json.update({'dtype': str(tensor.dtype)}) - tensor_json.update({"shape": tensor.shape}) - if tensor_stat.stack_tensor_stat is None: - tensor_json.update({"Max": tensor_stat.max}) - tensor_json.update({"Min": tensor_stat.min}) - tensor_json.update({"Mean": tensor_stat.mean}) - tensor_json.update({"Norm": tensor_stat.norm}) - tensor_json.update({"requires_grad": tensor.requires_grad}) - if tensor_stat.max is not None: - if np.isinf(tensor_stat.max) or np.isnan(tensor_stat.max): - tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "max") - if tensor_stat.min is not None: - if np.isinf(tensor_stat.min) or np.isnan(tensor_stat.min): - tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "min") - - else: - tensor_json.update({"requires_grad": tensor.requires_grad}) - tensor_json.update({"tensor_stat": tensor_stat.stack_tensor_stat}) + tensor_json.update({'type': self.tensor_handler.get_tensor_type(tensor)}) + tensor_json.update({'dtype': str(common_tensor.dtype)}) + tensor_json.update({"shape": common_tensor.shape}) + + stat_values = [ + tensor_stat.max, + tensor_stat.min, + tensor_stat.mean, + tensor_stat.norm + ] + placeholder_index = self.data_writer.append_stat_to_buffer(stat_values) + + tensor_json.update({Const.TENSOR_STAT_INDEX: placeholder_index}) + tensor_json.update({"requires_grad": tensor.requires_grad}) + if self.tensor_handler.is_dtensor(tensor): + dtensor_info = self.tensor_handler.get_dtensor_info(tensor) + tensor_json.update(dtensor_info) if self.config.summary_mode == Const.MD5 and not self.config.async_dump: - tensor_md5 = self.get_md5_for_tensor(tensor) - tensor_json.update({Const.MD5: tensor_md5}) + tensor_md5 = None + if not self.tensor_handler.is_empty_data(tensor): + t_cpu = common_tensor + + # 根据设备类型做同步,确保数据已准备好 + if t_cpu.device.type == "cuda": + t_cpu = t_cpu.to("cpu", non_blocking=True) + torch.cuda.synchronize() + # 先异步搬运再进行同步可以显著提升性能 + elif t_cpu.device.type == "npu": + t_cpu = t_cpu.to("cpu", non_blocking=True) + torch.npu.synchronize() + + t_cpu = t_cpu.detach() + if not t_cpu.is_contiguous(): + t_cpu = t_cpu.contiguous() + + future = self._crc_executor.submit( + PytorchDataProcessor.compute_crc32_from_tensor, + t_cpu + ) + + crc_placeholder = self.data_writer.append_crc32_to_buffer(future) + tensor_json[Const.MD5_INDEX] = crc_placeholder + else: + logger.debug( + "Calculating the md5 value of fake tensor or meta tensor is not supported, " + f"the current api/module name is {self.current_api_or_module_name}." + ) + tensor_json.update({Const.MD5: tensor_md5}) return tensor_json - -class StatisticsDataProcessor(PytorchDataProcessor): - pass - - -class TensorDataProcessor(PytorchDataProcessor): - def dump_async_data(self): - for file_path, tensor in self._async_dump_cache.items(): - save_pt(tensor.contiguous(), file_path) - self._async_dump_cache.clear() - - def _analyze_tensor(self, tensor, suffix): + def _analyze_and_save_tensor(self, tensor, suffix): dump_data_name, file_path = self.get_save_file_path(suffix) - single_arg = super()._analyze_tensor(tensor, suffix) + single_arg = PytorchDataProcessor._analyze_tensor(self, tensor, suffix) + if self.tensor_handler.is_empty_data(tensor) or tensor.untyped_storage().data_ptr() == 0: + logger.debug( + "Collecting real data of fake tensor or meta tensor is not supported or data_ptr is 0, " + f"the current api/module name is {self.current_api_or_module_name}." + ) + return single_arg + single_arg.update({"data_name": dump_data_name}) if self.config.async_dump: self._async_dump_cache[file_path] = tensor.clone().detach() @@ -311,15 +399,37 @@ class TensorDataProcessor(PytorchDataProcessor): saved_tensor = tensor.clone().contiguous().detach() save_pt(saved_tensor, file_path) return single_arg - - def _analyze_numpy(self, ndarray, suffix): + + def _analyze_and_save_ndarray(self, ndarray, suffix): dump_data_name, file_path = self.get_save_file_path(suffix) save_pt(torch.tensor(ndarray), file_path) - ndarray_json = super()._analyze_numpy(ndarray, suffix) + ndarray_json = PytorchDataProcessor._analyze_ndarray(ndarray, suffix) ndarray_json.update({"data_name": dump_data_name}) return ndarray_json +class StatisticsDataProcessor(PytorchDataProcessor): + def _analyze_tensor(self, tensor, suffix): + if any(item in self.current_api_or_module_name for item in self.config.tensor_list): + return self._analyze_and_save_tensor(tensor, suffix) + else: + return super()._analyze_tensor(tensor, suffix) + + def _analyze_ndarray(self, ndarray, suffix): + if any(item in self.current_api_or_module_name for item in self.config.tensor_list): + return self._analyze_and_save_ndarray(ndarray, suffix) + else: + return super()._analyze_ndarray(ndarray, suffix) + + +class TensorDataProcessor(PytorchDataProcessor): + def _analyze_tensor(self, tensor, suffix): + return self._analyze_and_save_tensor(tensor, suffix) + + def _analyze_ndarray(self, ndarray, suffix): + return self._analyze_and_save_ndarray(ndarray, suffix) + + class OverflowCheckDataProcessor(PytorchDataProcessor): __slots__ = ["cached_tensors_and_file_paths"] @@ -383,7 +493,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor): self._analyze_maybe_overflow_flag() if self.has_overflow: for file_path, tensor in self.cached_tensors_and_file_paths.items(): - save_pt(tensor, file_path) + save_pt(tensor.clone().contiguous().detach(), file_path) self.real_overflow_nums += 1 if self.overflow_nums != -1 and self.real_overflow_nums >= self.overflow_nums: logger.info(f"[{Const.TOOL_NAME}] Reached the preset overflow times, " @@ -409,10 +519,22 @@ class OverflowCheckDataProcessor(PytorchDataProcessor): raise RuntimeError(f"overflow check failed") from e def _analyze_maybe_overflow_tensor(self, tensor_json): - if tensor_json['Max'] is None or tensor_json['Min'] is None: + tensor_stat_index = tensor_json.get(Const.TENSOR_STAT_INDEX) + if tensor_stat_index is None: + logger.warning("tensor_stat_index does not exist in tensor_json.") return - self.has_overflow = np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']) or \ - np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min']) + max_tensor = self.data_writer.get_buffer_values_max(tensor_stat_index) + min_tensor = self.data_writer.get_buffer_values_min(tensor_stat_index) + + if max_tensor is None or min_tensor is None: + return + + if torch.isinf(max_tensor) or torch.isnan(max_tensor): + self.has_overflow = True + return + + if torch.isinf(min_tensor) or torch.isnan(min_tensor): + self.has_overflow = True def _analyze_tensor(self, tensor, suffix): dump_data_name, file_path = self.get_save_file_path(suffix) @@ -508,11 +630,13 @@ class KernelDumpDataProcessor(PytorchDataProcessor): return if self.config.is_backward_kernel_dump: - self.forward_args = self.clone_and_detach_tensor(module_input_output.args) - self.forward_kwargs = self.clone_and_detach_tensor(module_input_output.kwargs) try: + self.forward_args = self.clone_and_detach_tensor(module_input_output.args) + self.forward_kwargs = self.clone_and_detach_tensor(module_input_output.kwargs) output = module.forward(*self.forward_args, **self.forward_kwargs) - except Exception: + except Exception as e: + if isinstance(e, MsprobeException): + logger.warning(str(e)) self._print_unsupported_log(name) self.enable_kernel_dump = False return @@ -554,7 +678,10 @@ class KernelDumpDataProcessor(PytorchDataProcessor): self.stop_kernel_dump() logger.info(f"The kernel data of {name} is dumped successfully.") - @recursion_depth_decorator("KernelDump: KernelDumpDataProcessor.clone_and_detach_tensor") + @recursion_depth_decorator( + "KernelDump: KernelDumpDataProcessor.clone_and_detach_tensor", + max_depth=Const.DUMP_MAX_DEPTH + ) def clone_and_detach_tensor(self, input_params): if isinstance(input_params, torch.Tensor): if input_params.requires_grad: diff --git a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py index b1e26d16f9741765c1c9600a64efb112aa0f42d7..79c6202bca61abe204f847aa4c05b0cbf2be1f16 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,12 +16,17 @@ import csv import os import copy -import numpy as np +import threading +import traceback +from datetime import datetime, timezone, timedelta +import concurrent from msprobe.core.common.const import Const, FileCheckConst -from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, load_json +from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, load_json, check_path_before_create from msprobe.core.common.log import logger -from msprobe.core.common.exceptions import MsprobeException +from msprobe.core.common.decorator import recursion_depth_decorator + +lock = threading.Lock() class DataWriter: @@ -33,11 +38,17 @@ class DataWriter: self.free_benchmark_file_path = None self.dump_tensor_data_dir = None self.debug_file_path = None + self.dump_error_info_path = None self.flush_size = 1000 + self.larger_flush_size = 20000 self.cache_data = {} self.cache_stack = {} self.cache_construct = {} self.cache_debug = {} + self.stat_stack_list = [] + self._error_log_initialized = False + self._cache_logged_error_types = set() + self.crc32_stack_list = [] @staticmethod def write_data_to_csv(result: list, result_header: tuple, file_path: str): @@ -54,13 +65,93 @@ class DataWriter: if is_new_file: change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY) + @recursion_depth_decorator("JsonWriter: DataWriter._replace_crc32_placeholders") + def _replace_crc32_placeholders(self, data, crc32_results): + """ + 遍历 JSON 结构,将所有 md5_index 占位符替换成真实的 CRC32 + """ + if isinstance(data, dict): + for k, v in list(data.items()): + if k == Const.MD5_INDEX and isinstance(v, int): + idx = v + # 防越界 + crc = crc32_results[idx] if idx < len(crc32_results) else None + # 删除占位符,改成真实字段 + del data[k] + data[Const.MD5] = crc + else: + self._replace_crc32_placeholders(v, crc32_results) + elif isinstance(data, (list, tuple)): + for item in data: + self._replace_crc32_placeholders(item, crc32_results) + + @recursion_depth_decorator("JsonWriter: DataWriter._replace_stat_placeholders") + def _replace_stat_placeholders(self, data, stat_result): + if isinstance(data, dict): + keys = list(data.keys()) # 获取当前所有键 + for key in keys: # 递归所有变量 + value = data[key] + if key == Const.TENSOR_STAT_INDEX and isinstance(value, int): + if value >= 0: + idx = value + else: + return + stat_values = stat_result[idx] if idx < len(stat_result) else [None] * 4 + + new_entries = { + Const.TYPE: data["type"], + Const.DTYPE: data["dtype"], + Const.SHAPE: data["shape"], + Const.MAX: stat_values[0], + Const.MIN: stat_values[1], + Const.MEAN: stat_values[2], + Const.NORM: stat_values[3], + } + del data[key] + + # 重构字典顺序 + updated_dict = {} + # 通过插入排序后字段保证字段写入json的有序 + updated_dict.update(new_entries) + # 遍历原字典其他字段(排除已删除的tensor_stat_index) + for k in data: + if k not in new_entries: + updated_dict[k] = data[k] + data.clear() + data.update(updated_dict) + else: + self._replace_stat_placeholders(value, stat_result) + elif isinstance(data, (list, tuple)): + for item in data: + self._replace_stat_placeholders(item, stat_result) + def reset_cache(self): self.cache_data = {} self.cache_stack = {} self.cache_construct = {} + self.cache_debug = {} + self._cache_logged_error_types = set() + + def append_crc32_to_buffer(self, future: concurrent.futures.Future) -> int: + """ + 把一个计算 CRC32 的 Future 放入队列,返回占位符索引 + """ + idx = len(self.crc32_stack_list) + self.crc32_stack_list.append(future) + return idx + + def flush_crc32_stack(self): + """ + 等待所有 CRC32 计算完成,返回结果列表 + """ + if not self.crc32_stack_list: + return [] + results = [f.result() for f in self.crc32_stack_list] + self.crc32_stack_list = [] + return results def initialize_json_file(self, **kwargs): - if self.debug_file_path and not self.cache_debug: + if kwargs["level"] == Const.LEVEL_DEBUG and not self.cache_debug: # debug level case only create debug.json debug_dict = copy.deepcopy(kwargs) debug_dict.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}}) @@ -83,42 +174,102 @@ class DataWriter: self.dump_tensor_data_dir = dump_path_aggregation.dump_tensor_data_dir self.free_benchmark_file_path = dump_path_aggregation.free_benchmark_file_path self.debug_file_path = dump_path_aggregation.debug_file_path + self.dump_error_info_path = dump_path_aggregation.dump_error_info_path def flush_data_periodically(self): dump_data = self.cache_data.get(Const.DATA) - if dump_data and isinstance(dump_data, dict) and len(dump_data) % self.flush_size == 0: - self.write_json() - def update_data(self, new_data): - if not isinstance(new_data, dict) or len(new_data.keys()) != 1: - logger.warning(f"The data info({new_data}) should be a dict with only one outer key.") - return - dump_data = self.cache_data.get(Const.DATA) - if not isinstance(dump_data, dict): - logger.warning(f"The dump data({dump_data}) should be a dict.") + if not dump_data or not isinstance(dump_data, dict): return - key = next(iter(new_data.keys())) - if key in dump_data: - dump_data.get(key).update(new_data.get(key)) + length = len(dump_data) + + # 1) 先取到 config(如果没有,就拿 None) + cfg = getattr(self, "config", None) + # 2) 再取 summary_mode(如果 cfg 是 None 或者没 summary_mode,就拿 None) + summary_mode = getattr(cfg, "summary_mode", None) + + if summary_mode == Const.MD5: + threshold = self.flush_size else: - dump_data.update(new_data) + threshold = self.flush_size if length < self.larger_flush_size else self.larger_flush_size + + if length % threshold == 0: + self.write_json() + + def write_error_log(self, message: str, error_type: str): + """ + 写错误日志: + - 第一次调用时以 'w' 模式清空文件,之后都用 'a' 模式追加 + - 添加时间戳 + - 在 message 后写入当前的调用栈(方便追踪日志来源) + """ + # 如果同类型错误已经记录过,跳过 + if error_type in self._cache_logged_error_types: + return + # 否则添加到已记录集合,并继续写日志 + self._cache_logged_error_types.add(error_type) + + try: + mode = "w" if not self._error_log_initialized else "a" + self._error_log_initialized = True + + check_path_before_create(self.dump_error_info_path) + + with FileOpen(self.dump_error_info_path, mode) as f: + cst_timezone = timezone(timedelta(hours=8), name="CST") + timestamp = datetime.now(cst_timezone).strftime("%Y-%m-%d %H:%M:%S %z") + f.write(f"[{timestamp}] {message}\n") + f.write("Call stack (most recent call last):\n") + + f.write("".join(traceback.format_stack()[:-1])) # 去掉自己这一层 + f.write("\n") + except Exception as e: + # 如果连写日志都失败了,就打印到 stderr + logger.warning(f"[FallbackError] Failed to write error log: {e}") + + def update_data(self, new_data): + with lock: + if not isinstance(new_data, dict) or len(new_data.keys()) != 1: + logger.warning(f"The data info({new_data}) should be a dict with only one outer key.") + return + dump_data = self.cache_data.get(Const.DATA) + if not isinstance(dump_data, dict): + logger.warning(f"The dump data({dump_data}) should be a dict.") + return + + key = next(iter(new_data.keys())) + if key in dump_data: + dump_data.get(key).update(new_data.get(key)) + else: + dump_data.update(new_data) - def update_stack(self, new_data): - self.cache_stack.update(new_data) + def update_stack(self, name, stack_data): + with lock: + api_list = self.cache_stack.get(stack_data) + if api_list is None: + self.cache_stack.update({stack_data: [name]}) + else: + api_list.append(name) def update_construct(self, new_data): - self.cache_construct.update(new_data) + with lock: + self.cache_construct.update(new_data) def update_debug(self, new_data): - self.cache_debug['data'].update(new_data) + with lock: + self.cache_debug['data'].update(new_data) def write_data_json(self, file_path): logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ") save_json(file_path, self.cache_data, indent=1) def write_stack_info_json(self, file_path): - save_json(file_path, self.cache_stack, indent=1) + num, new_cache_stack = 0, {} + for key, value in self.cache_stack.items(): + new_cache_stack[num] = [value, key] + num += 1 + save_json(file_path, new_cache_stack, indent=1) def write_construct_info_json(self, file_path): save_json(file_path, self.cache_construct, indent=1) @@ -126,38 +277,70 @@ class DataWriter: def write_debug_info_json(self, file_path): save_json(file_path, self.cache_debug, indent=1) + def append_stat_to_buffer(self, stat_vector): + """ + 直接使用 Python list 存储 stat_vector, + 将 stat_vector 存入 self.stat_stack_list 的方式 + """ + self.stat_stack_list.append(stat_vector) + return len(self.stat_stack_list) - 1 + + def get_buffer_values_max(self, index): + if 0 <= index < len(self.stat_stack_list) and len(self.stat_stack_list[index]) >= 1: + return self.stat_stack_list[index][0] + else: + logger.warning(f"stat_stack_list[{index}] The internal data is incomplete," + f" and the maximum value cannot be obtained.") + return None + + def get_buffer_values_min(self, index): + if 0 <= index < len(self.stat_stack_list) and len(self.stat_stack_list[index]) >= 1: + return self.stat_stack_list[index][1] + else: + logger.warning(f"stat_stack_list[{index}] Internal data is incomplete" + f" and minimum values cannot be obtained.") + return None + + def flush_stat_stack(self): + """ + 在 flush 阶段,将所有存储的统计值从设备搬到 CPU, + 这里返回一个列表,每个元素是 [Max, Min, Mean, Norm] 的数值列表 + """ + if not self.stat_stack_list: + return [] + result = [ + [ + x.item() if hasattr(x, "item") else x + for x in stat_values + ] + for stat_values in self.stat_stack_list + ] + self.stat_stack_list = [] + return result + def write_json(self): - if self.cache_data: - self.write_data_json(self.dump_file_path) - if self.cache_stack: - self.write_stack_info_json(self.stack_file_path) - if self.cache_construct: - self.write_construct_info_json(self.construct_file_path) - if self.cache_debug: - self.write_debug_info_json(self.debug_file_path) - - def fill_stack_tensor_data(self): - self.process_stat_data_recursive(self.cache_data) - - def process_stat_data_recursive(self, data, depth=0): - if depth > Const.MAX_DEPTH: - logger.error(f"The maximum depth of recursive process stat data, {Const.MAX_DEPTH} is reached.") - raise MsprobeException(MsprobeException.RECURSION_LIMIT_ERROR) - if isinstance(data, dict): - if "tensor_stat" in data.keys(): - tensor_stat = data["tensor_stat"] - if len(tensor_stat) != Const.TENSOR_STAT_LEN or len(tensor_stat[0]) != len(tensor_stat[1]): - logger.warning("Some bad data in async dump") - else: - tensor_stat_index, tensor_stat_data = tensor_stat[0], tensor_stat[1] - if hasattr(tensor_stat_data, "device") and tensor_stat_data.device != Const.CPU_LOWERCASE: - tensor_stat_data = tensor_stat_data.cpu() - for index, stat in zip(tensor_stat_index, tensor_stat_data): - data.update({index: stat.item()}) - del data["tensor_stat"] - else: - for key in data.keys(): - self.process_stat_data_recursive(data[key], depth + 1) - elif isinstance(data, (list, tuple)): - for i in data: - self.process_stat_data_recursive(i, depth + 1) \ No newline at end of file + with lock: + # 在写 JSON 前,统一获取统计值 + stat_result = self.flush_stat_stack() + # 遍历 cache_data,将占位符替换为最终统计值 + if stat_result: + self._replace_stat_placeholders(self.cache_data, stat_result) + if self.cache_debug: + self._replace_stat_placeholders(self.cache_debug, stat_result) + + # 2) 再 flush CRC32 + crc32_result = self.flush_crc32_stack() + if crc32_result: + self._replace_crc32_placeholders(self.cache_data, crc32_result) + if self.cache_debug: + self._replace_crc32_placeholders(self.cache_debug, crc32_result) + + if self.cache_data: + self.write_data_json(self.dump_file_path) + if self.cache_stack: + self.write_stack_info_json(self.stack_file_path) + if self.cache_construct: + self.write_construct_info_json(self.construct_file_path) + if self.cache_debug: + self.write_debug_info_json(self.debug_file_path) + diff --git a/debug/accuracy_tools/msprobe/core/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/core/debugger/precision_debugger.py new file mode 100644 index 0000000000000000000000000000000000000000..9cfbd76c198a64503e63bbcf04e86af6cbf7bb6d --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/debugger/precision_debugger.py @@ -0,0 +1,149 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from msprobe.core.common.const import Const, FileCheckConst, MsgConst +from msprobe.core.common.exceptions import MsprobeException +from msprobe.core.common.file_utils import FileChecker, load_json +from msprobe.core.common.utils import get_real_step_or_rank, check_init_step, ThreadSafe +from msprobe.core.common_config import CommonConfig + + +class BasePrecisionDebugger: + _instance = None + tasks_not_need_debugger = [Const.GRAD_PROBE] + + def __new__(cls, *args, **kwargs): + if not cls._instance: + with ThreadSafe(): + if not cls._instance: + cls._instance = super(BasePrecisionDebugger, cls).__new__(cls) + cls._instance.config = None + cls._instance.initialized = False + cls.service = None + cls.first_start = False + return cls._instance + + def __init__( + self, + config_path=None, + task=None, + dump_path=None, + level=None, + step=None + ): + if self.initialized: + return + self.initialized = True + self._check_input_params(config_path, task, dump_path, level) + self.common_config, self.task_config = self._parse_config_path(config_path, task) + self.task = self.common_config.task + if step is not None: + self.common_config.step = get_real_step_or_rank(step, Const.STEP) + + @staticmethod + def _check_input_params(config_path, task, dump_path, level): + if not config_path: + config_path = os.path.join(os.path.dirname(__file__), "../../config.json") + if config_path is not None: + if not isinstance(config_path, str): + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, f"config_path must be a string") + file_checker = FileChecker( + file_path=config_path, path_type=FileCheckConst.FILE, file_type=FileCheckConst.JSON_SUFFIX) + file_checker.common_check() + + if task is not None and task not in Const.TASK_LIST: + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, f"task must be one of {Const.TASK_LIST}") + + if dump_path is not None: + if not isinstance(dump_path, str): + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, f"dump_path must be a string") + + if level is not None and level not in Const.LEVEL_LIST: + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}") + + @staticmethod + def _get_task_config(task, json_config): + raise NotImplementedError("Subclass must implement _get_task_config") + + @classmethod + @ThreadSafe.synchronized + def forward_backward_dump_end(cls): + instance = cls._instance + instance.stop() + + @classmethod + @ThreadSafe.synchronized + def set_init_step(cls, step): + instance = cls._instance + if not instance: + raise Exception(MsgConst.NOT_CREATED_INSTANCE) + check_init_step(step) + instance.service.init_step = step + instance.service.loop = 0 + + @classmethod + @ThreadSafe.synchronized + def register_custom_api(cls, module, api, api_prefix=None): + if not api_prefix: + api_prefix = getattr(module, "__name__", "Custom") + if not isinstance(api_prefix, str): + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, "api_prefix must be string") + if not hasattr(module, api): + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, f"module {str(module)} does not have {api}") + instance = cls._instance + if not instance: + raise Exception(MsgConst.NOT_CREATED_INSTANCE) + instance.service.register_custom_api(module, api, api_prefix) + + @classmethod + @ThreadSafe.synchronized + def restore_custom_api(cls, module, api): + if not hasattr(module, api): + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, f"module {str(module)} does not have {api}") + instance = cls._instance + if not instance: + raise Exception(MsgConst.NOT_CREATED_INSTANCE) + instance.service.restore_custom_api(module, api) + + @classmethod + def _get_instance(cls): + instance = cls._instance + if not instance: + raise Exception(MsgConst.NOT_CREATED_INSTANCE) + if instance.task in BasePrecisionDebugger.tasks_not_need_debugger: + instance = None + return instance + + def _parse_config_path(self, json_file_path, task): + if not json_file_path: + json_file_path = os.path.join(os.path.dirname(__file__), "../../config.json") + json_config = load_json(json_file_path) + common_config = CommonConfig(json_config) + if task: + task_config = self._get_task_config(task, json_config) + else: + if not common_config.task: + common_config.task = Const.STATISTICS + task_config = self._get_task_config(common_config.task, json_config) + return common_config, task_config diff --git a/debug/accuracy_tools/msprobe/core/grad_probe/constant.py b/debug/accuracy_tools/msprobe/core/grad_probe/constant.py index 22a8b6c13411b68a6566d0686062f8c74cb27196..5d9c72a6f2d60203b0d9ba716e867e39ee22d807 100644 --- a/debug/accuracy_tools/msprobe/core/grad_probe/constant.py +++ b/debug/accuracy_tools/msprobe/core/grad_probe/constant.py @@ -31,6 +31,7 @@ class GradConst: STEP = "step" BOUNDS = "bounds" OUTPUT_PATH = "output_path" + TIME_STAMP = "time_stamp" # level const LEVEL = "level" @@ -51,7 +52,7 @@ class GradConst: BOUNDS_MINIMUM = -2**63 BOUNDS_MAXIMUM = 2**63 - 1 - # file safty + # file safety DATA_DIR_AUTHORITY = 0o750 DATA_FILE_AUTHORITY = 0o640 DIRECTORY_LENGTH = 4096 diff --git a/debug/accuracy_tools/msprobe/core/grad_probe/grad_compare.py b/debug/accuracy_tools/msprobe/core/grad_probe/grad_compare.py index 4f2b25bd28dfe330a8716695278ab8c64222c4b6..f50fc0f4e381db0e4069ef99b5c70b593f1580d0 100644 --- a/debug/accuracy_tools/msprobe/core/grad_probe/grad_compare.py +++ b/debug/accuracy_tools/msprobe/core/grad_probe/grad_compare.py @@ -112,7 +112,7 @@ class GradComparator: result.append([key] + value) result_csv_path = os.path.join(output_dir, "similarities.csv") if os.path.exists(result_csv_path): - logger.warning(f"{result_csv_path} will be recoverd") + logger.warning(f"{result_csv_path} will be deleted") remove_path(result_csv_path) write_csv(result, result_csv_path) @@ -121,7 +121,7 @@ class GradComparator: similarities = {} logger.info(f"{len(steps)} steps will be compared") grad_weight_order = cls._get_grad_weight_order(path1, path2) - for step in tqdm(steps, desc="culculate similarities (by step)"): + for step in tqdm(steps, desc="calculate similarities (by step)"): grad_files = cls._get_matched_grad_files(path1, path2, step) same_count_summary = 0 total_count_summary = 0 diff --git a/debug/accuracy_tools/msprobe/core/grad_probe/utils.py b/debug/accuracy_tools/msprobe/core/grad_probe/utils.py index de3e4156acc74f135120e06116b5894a0e9ed09e..468367a54a8bf4926edd5a8f25cefaa5890ec40c 100644 --- a/debug/accuracy_tools/msprobe/core/grad_probe/utils.py +++ b/debug/accuracy_tools/msprobe/core/grad_probe/utils.py @@ -82,7 +82,7 @@ class ListCache(list): if len(self) == 0: return if not self._output_file: - logger.warning("dumpfile path is not setted") + logger.warning("dumpfile path is not set.") write_csv(self, self._output_file) logger.info(f"write {len(self)} items to {self._output_file}.") self.clear() diff --git a/debug/accuracy_tools/msprobe/core/hook_manager.py b/debug/accuracy_tools/msprobe/core/hook_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..e0af279786a7ed2594daff63e615d57fb002240a --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/hook_manager.py @@ -0,0 +1,310 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import threading +from abc import ABC, abstractmethod +from collections import defaultdict + +from msprobe.core.common.runtime import Runtime +from msprobe.core.common.utils import Const, ThreadSafe +from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs) + + +class HookSet: + def __init__( + self, + forward_pre_hook=None, + forward_hook=None, + backward_pre_hook=None, + backward_hook=None, + distributed_forward_hook=None + ): + self.forward_pre_hook = forward_pre_hook + self.forward_hook = forward_hook + self.backward_pre_hook = backward_pre_hook + self.backward_hook = backward_hook + self.distributed_forward_hook = distributed_forward_hook + + +class BaseHookManager(ABC): + inner_switch = defaultdict(bool) + inner_api_count = defaultdict(int) + hook_handle_dict = {} + params_grad_info = {} + + def __init__(self, data_collector, config): + self.data_collector = data_collector + self.config = config + + @property + def _pid(self): + return os.getpid() + + @property + @abstractmethod + def _is_recompute(self): + pass + + @staticmethod + def reset_status(): + BaseHookManager.inner_switch = defaultdict(bool) + BaseHookManager.inner_api_count = defaultdict(int) + BaseHookManager.hook_handle_dict.clear() + BaseHookManager.params_grad_info.clear() + + @staticmethod + def _clear_input_kwargs(module, tid): + if hasattr(module, 'msprobe_input_kwargs') and tid in module.msprobe_input_kwargs: + del module.msprobe_input_kwargs[tid] + + @staticmethod + @abstractmethod + def _no_grad_context(): + pass + + @staticmethod + @abstractmethod + def _add_count(name): + pass + + @staticmethod + @abstractmethod + def _get_count(name): + pass + + @staticmethod + @abstractmethod + def _process_kwargs_and_output(module, tid, hook_type, kwargs_or_output, output_or_kwargs): + pass + + @abstractmethod + def build_hook(self): + pass + + @abstractmethod + def _register_forward_hook(self, module, api_name): + pass + + @abstractmethod + def _register_backward_hook(self, module, full_backward_name, args): + pass + + @abstractmethod + def _register_backward_pre_hook(self, module, full_backward_name, output): + pass + + @abstractmethod + def _get_params_dict(self, module): + pass + + @abstractmethod + def _need_exchange(self, module): + pass + + def _register_param_hook(self, name, module, params_dict): + ori_name = name.rsplit(Const.SEP, 2)[0] + grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD + # 首次执行前向hook时,添加params_grad_name属性,并注册参数hook + setattr(module, 'params_grad_name', grad_name) + # data_mode为forward时,不注册参数hook + if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode): + for param_name, param in params_dict.items(): + if param.requires_grad: + name = ori_name + Const.SEP + param_name + old_handle = BaseHookManager.hook_handle_dict.get(name) + if old_handle and hasattr(old_handle, "remove"): + old_handle.remove() + handle = param.register_hook(self._build_grad_hook(ori_name, param_name)) + BaseHookManager.hook_handle_dict[name] = handle + + def _init_params_grad_info(self, module, params_dict): + ''' + 初始化参数梯度信息, 在前向hook结束后, 将参数梯度信息写入cache_data中用于占位 + ''' + if not params_dict: + return + if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode): + grad_name = module.params_grad_name if hasattr(module, 'params_grad_name') else None + # 判断是否已经在cache_data中进行了占位, 若没有则先写入cache_data中 + if not BaseHookManager.params_grad_info.get(grad_name): + data_info = {grad_name: {key: [None] for key, value in params_dict.items() if value.requires_grad}} + # 当模块中的参数有requires_grad属性为True时,才会进行梯度计算,此时才需要占位 + if data_info.get(grad_name): + # 将grad_name的data_info先写入cache_data中, 梯度计算后再更新 + self.data_collector.handle_data(grad_name, data_info, + flush=self.data_collector.data_processor.is_terminated) + self.data_collector.params_grad_record[grad_name] = True + # 记录当前模块的参数梯度信息已占位 + BaseHookManager.params_grad_info[grad_name] = True + + def _should_execute_hook(self, hook_type, tid, is_forward=True): + is_api_hook = hook_type == Const.API + if BaseHookManager.inner_switch[tid]: + return False + if not is_api_hook and not Runtime.is_running: + return False + elif is_api_hook and is_forward and not Runtime.is_running: + return False + if not self.data_collector or self.data_collector.data_processor.is_terminated: + return False + return True + + def _build_grad_hook(self, ori_name, param_name): + def hook_fn(grad): + tid = threading.get_ident() + if not self._should_execute_hook(Const.MODULE, tid): + return + with ThreadSafe(): + BaseHookManager.inner_switch[tid] = True + self.data_collector.params_data_collect(ori_name, param_name, self._pid, grad) + BaseHookManager.inner_switch[tid] = False + return + + return hook_fn + + def _build_forward_pre_hook(self, hook_type, api_name): + def forward_pre_hook(module, args, kwargs=None): + if hook_type == Const.MODULE: + return None + + tid = threading.get_ident() + if not self._should_execute_hook(hook_type, tid): + return None + + with ThreadSafe(): + self._register_forward_hook(module, api_name) + BaseHookManager.inner_api_count[tid] += 1 + if BaseHookManager.inner_api_count[tid] != 1: + return None + + full_forward_name = api_name + str(self._get_count(api_name)) + Const.SEP + Const.FORWARD + full_backward_name = api_name + str(self._get_count(api_name)) + Const.SEP + Const.BACKWARD + module.full_forward_name = full_forward_name + if kwargs is None: + kwargs = module.msprobe_input_kwargs.get(tid, {}) if hasattr(module, 'msprobe_input_kwargs') else {} + BaseHookManager.inner_switch[tid] = True + module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None) + + args = self._register_backward_hook(module, full_backward_name, args) + with self._no_grad_context(): + self.data_collector.update_api_or_module_name(full_forward_name) + self.data_collector.forward_input_data_collect( + full_forward_name, + module, + self._pid, + module_input_output, + self._is_recompute + ) + BaseHookManager.inner_switch[tid] = False + return args + + return forward_pre_hook + + def _build_forward_hook(self, hook_type, api_name): + def forward_hook(module, args, kwargs_or_output, output_or_kwargs=None): + tid = threading.get_ident() + if not self._should_execute_hook(hook_type, tid): + self._clear_input_kwargs(module, tid) + return None + + with ThreadSafe(): + if hook_type == Const.API: + if BaseHookManager.inner_api_count[tid] != 1: + if BaseHookManager.inner_api_count[tid] > 1: + BaseHookManager.inner_api_count[tid] -= 1 + self._clear_input_kwargs(module, tid) + return None + + kwargs, output = self._process_kwargs_and_output( + module, + tid, + hook_type, + kwargs_or_output, + output_or_kwargs + ) + BaseHookManager.inner_switch[tid] = True + module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output) + if hook_type == Const.API: + full_forward_name = api_name + str(self._get_count(api_name)) + Const.SEP + Const.FORWARD + full_backward_name = api_name + str(self._get_count(api_name)) + Const.SEP + Const.BACKWARD + output = self._register_backward_pre_hook(module, full_backward_name, output) + + with self._no_grad_context(): + if hook_type == Const.MODULE: + params_dict = self._get_params_dict(module) + setattr(module_input_output, Const.PARAMS, params_dict) + if params_dict: + self._register_param_hook(api_name, module, params_dict) + self.data_collector.update_api_or_module_name(api_name) + self.data_collector.forward_data_collect( + api_name, + module, + self._pid, + module_input_output, + self._is_recompute + ) + self._init_params_grad_info(module, params_dict) + else: + self.data_collector.update_api_or_module_name(full_forward_name) + self.data_collector.forward_output_data_collect( + full_forward_name, + module, + self._pid, + module_input_output, + self._is_recompute + ) + self._add_count(api_name) + BaseHookManager.inner_api_count[tid] -= 1 + self._clear_input_kwargs(module, tid) + + if self.data_collector.if_return_forward_new_output(): + forward_new_output = self.data_collector.get_forward_new_output() + BaseHookManager.inner_switch[tid] = False + return forward_new_output + + BaseHookManager.inner_switch[tid] = False + return output + + return forward_hook + + def _build_backward_hook(self, hook_type, full_name): + def backward_hook(module, grad_input, grad_output): + tid = threading.get_ident() + if not self._should_execute_hook(hook_type, tid, is_forward=False): + return + + with ThreadSafe(): + BaseHookManager.inner_switch[tid] = True + self.data_collector.update_api_or_module_name(full_name) + + need_exchange = self._need_exchange(module) if hook_type == Const.MODULE else True + if need_exchange: + module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input) + else: + module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output) + self.data_collector.backward_data_collect( + full_name, + module, + self._pid, + module_input_output, + self._is_recompute + ) + if hook_type == Const.MODULE: + params_dict = self._get_params_dict(module) + self.data_collector.params_data_collect_in_bw_hook(params_dict, full_name) + BaseHookManager.inner_switch[tid] = False + + return backward_hook diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/kernel_dump/kernel_config.py b/debug/accuracy_tools/msprobe/core/kernel_dump/kernel_config.py similarity index 100% rename from debug/accuracy_tools/msprobe/mindspore/dump/kernel_dump/kernel_config.py rename to debug/accuracy_tools/msprobe/core/kernel_dump/kernel_config.py diff --git a/debug/accuracy_tools/msprobe/core/monitor/__init__.py b/debug/accuracy_tools/msprobe/core/monitor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_analyse.py b/debug/accuracy_tools/msprobe/core/monitor/anomaly_processor.py similarity index 49% rename from debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_analyse.py rename to debug/accuracy_tools/msprobe/core/monitor/anomaly_processor.py index 9a0b71e8a5791bc216c82737d1d4f4a482abceb9..8c50ad761682c05533d525acaa39e6f830cc4e48 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_analyse.py +++ b/debug/accuracy_tools/msprobe/core/monitor/anomaly_processor.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,18 +12,205 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import os import sys +import math import argparse import ast import heapq +from abc import ABC +from dataclasses import dataclass, field +from typing import List -from msprobe.pytorch.common.log import logger from msprobe.core.common.const import MonitorConst -from msprobe.core.common.file_utils import check_path_before_create, save_json, create_directory, remove_path, \ +from msprobe.core.common.log import logger +from msprobe.core.common.file_utils import save_json, create_directory, remove_path, \ check_file_or_directory_path, load_json -from msprobe.pytorch.monitor.anomaly_detect import GradAnomalyData + + +class ScanRule(ABC): + name = "ScanRule" + + def apply(self, cur, history=None): + raise NotImplementedError("abstract method apply is not implemented") + + +class AnomalyTurbulence(ScanRule): + name = "AnomalyTurbulence" + + def __init__(self, threshold) -> None: + self.threshold = threshold + + def apply(self, cur, history=None): + """ + :param cur: float, current metric value + :param history: float, history weighted average + :return: bool, whether the current value deviates from the historical average value of current metric + """ + up_bound = history * (1 + self.threshold) + return abs(cur) > up_bound + + +class AnomalyNan(ScanRule): + name = "AnomalyNan" + + def __init__(self, threshold=None) -> None: + self.threshold = threshold + + def apply(self, cur, history=None): + return math.isnan(cur) or (self.threshold is not None and abs(cur) > self.threshold) + + +class AnomalyScanner: + + @staticmethod + def load_rules(specs: List[dict]): + """ + specs: [{"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.5}}] + """ + if specs is None: + return [] + alert_rules = [] + for spec in specs: + # 使用get方法获取键值,如果键不存在则返回None + rule_cls_name = spec.get("rule_name") + rule_args = spec.get("args") + + # 检查必要的键是否存在 + if rule_cls_name is None or (rule_cls_name == "AnomalyTurbulence" and rule_args is None): + logger.warning(f"Spec is missing required keys: {spec}") + continue + + cur_module = sys.modules.get(__name__) + try: + rule_cls = getattr(cur_module, rule_cls_name) + except AttributeError: + logger.error(f"Rule class '{rule_cls_name}' not found in the current module.") + continue + + try: + rule_instance = rule_cls(**rule_args) if rule_args is not None else rule_cls() + alert_rules.append(rule_instance) + except Exception as e: + logger.error(f"Error creating instance of rule '{rule_cls_name}': {e}") + continue + + return alert_rules + + @staticmethod + def scan(scan_rules: List[ScanRule], history, cur): + anomaly = False + for rule in scan_rules: + anomaly = rule.apply(cur, history=history) + if anomaly: + return anomaly, rule.name + return anomaly, None + + +class AnomalyDataFactory(ABC): + def __init__(self, rank, pp_stage, group_mates): + super().__init__() + self.rank = rank + self.pp_stage = pp_stage + self.group_mates = group_mates + self.micro_step = 0 + self.name2callid = {} + + def set_call_id(self, name2callid): + """根据当前GradContext信息更新call_id vpp_stage等信息 + """ + self.name2callid = name2callid + + def create(self, tag, message, step): + """如果检查出异常, 调用当前接口生成GradAnomalyData实例 + tag (tuple): metric tag ('0:1.post_attention_norm.weight/rank0/pre_grad', 'min') + message (str): anomaly detect message + step (int): training step + """ + if not isinstance(tag, tuple) or len(tag) != 2: + raise ValueError("tag must be a tuple with length 2") + tag_name = tag[0] + param_name = tag_name.split('/')[0] + call_id = self.name2callid.get(tag_name, -1) + if MonitorConst.NAME_SEP in param_name: + vpp_stage = int(param_name.split(MonitorConst.NAME_SEP)[0]) + else: + vpp_stage = 0 + + return GradAnomalyData( + self.rank, + step, + self.micro_step, + self.pp_stage, + vpp_stage, + call_id, + tag_name, + message, + self.group_mates + ) + + +@dataclass(eq=True) +class GradAnomalyData: + rank: int = 0 + step: int = 0 + micro_step: int = 0 + pp_stage: int = 0 + vpp_stage: int = 0 + call_id: int = 0 + tag_name: str = field(default=None, compare=False) + message: str = field(default="", compare=False) + group_mates: list = field(default=None, compare=False) + + def __lt__(self, other): + """ + 自定义比较函数,用于确定 GradAnomalyData 实例之间的顺序。 + 比较规则为: + step 和 micro_step 值越小优先级越高; + vpp 和 pp 在前向阶段值越小优先级越高,在非前向阶段值越大优先级越高; + call_id 值越小优先级越高。 + """ + if not isinstance(other, GradAnomalyData): + return NotImplemented + + self_train_stage = self.get_train_stage(self.tag_name) + other_train_stage = self.get_train_stage(other.tag_name) + + def vpp_pp_comparator(anomaly): + """ + Determine the priority rule for vpp and pp based on train stage + Forward stage prefers smaller vpp and pp + Other stages prefer larger vpp and pp + """ + if self_train_stage == MonitorConst.FORWARD_STAGE: + return anomaly.vpp_stage, anomaly.pp_stage + else: + return -anomaly.vpp_stage, -anomaly.pp_stage + + self_cmp = [self.step, self.micro_step, self_train_stage, *vpp_pp_comparator(self), self.call_id] + other_cmp = [other.step, other.micro_step, other_train_stage, *vpp_pp_comparator(other), other.call_id] + return self_cmp < other_cmp + + def __le__(self, other): + if not isinstance(other, GradAnomalyData): + return NotImplemented + return self == other or self < other + + @staticmethod + def get_train_stage(tag_name): + """ + :param tag_name: "0:fc2.input:0/rank0/actv", "0:fc1.weight/rank0/post_grad", "0:fc2.weight/rank0/exp_avg_sq" + :return: int, if forward return 0; if backward return 1; if optimizer return 2 + """ + key_ = tag_name.split("/")[-1] + return MonitorConst.TRAIN_STAGE.get(key_, MonitorConst.DEFAULT_STAGE) + + def to_dict(self): + return self.__dict__ + + def get_key(self): + # 0:1.self_attention.core_attention_flash_0/rank0/input_grad + return ''.join([str(self.tag_name), "_step_", str(self.step), "_call_", str(self.call_id)]) class AnomalyDataWriter: @@ -46,12 +233,7 @@ class AnomalyDataWriter: def init_detected_json(self): """初始化落盘文件""" - check_path_before_create(self.dump_path) - if not os.path.exists(self.dump_path): - create_directory(self.dump_path) - - if not os.path.exists(self.dump_rank_dir): - create_directory(self.dump_rank_dir) + create_directory(self.dump_rank_dir) if os.path.exists(self.json_path): check_file_or_directory_path(self.json_path, isdir=False) @@ -66,11 +248,12 @@ class AnomalyDataWriter: anomalies: GradAnomalyData对象列表 """ anomalies_json = self.get_anomaly_dict(anomalies) - logger.info(f"{MonitorConst.ANOMALY_JSON} is at {self.dump_rank_dir}.") + if anomalies_json: + logger.info(f"{MonitorConst.ANOMALY_JSON} is at {self.dump_rank_dir}.") - data_to_write = load_json(self.json_path) if os.path.exists(self.json_path) else {} - data_to_write.update(anomalies_json) - save_json(self.json_path, data_to_write, indent=1) + data_to_write = load_json(self.json_path) if os.path.exists(self.json_path) else {} + data_to_write.update(anomalies_json) + save_json(self.json_path, data_to_write, indent=1) class AnomalyDataLoader: @@ -145,27 +328,6 @@ class AnomalyAnalyse: save_json(json_path, sorted_data, indent=1) -def _get_parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("-d", "--data_path", dest="data_path_dir", default="./", type=str, - help=" The anomaly detect result dictionary: generate from monitor tool.", - required=True, - ) - parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str, - help=" The analyse task result out path.", - required=False, - ) - parser.add_argument("-k", "--topk", dest="top_k_number", default=8, type=int, - help=" Top K number of earliest anomalies.", - required=False, - ) - parser.add_argument("-s", "--step", dest="step_list", default="[]", type=str, - help=" Analyse which steps.", - required=False, - ) - return parser.parse_args(sys.argv[1:]) - - def _get_step_and_stop(args): try: step_list = ast.literal_eval(args.step_list) @@ -196,6 +358,27 @@ def _anomaly_analyse(): logger.info(f"{index}: {anomaly.message}") +def _get_parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("-d", "--data_path", dest="data_path_dir", default="./", type=str, + help=" The anomaly detect result dictionary: generate from monitor tool.", + required=True, + ) + parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str, + help=" The analyse task result out path.", + required=False, + ) + parser.add_argument("-k", "--topk", dest="top_k_number", default=8, type=int, + help=" Top K number of earliest anomalies.", + required=False, + ) + parser.add_argument("-s", "--step", dest="step_list", default="[]", type=str, + help=" Analyse which steps.", + required=False, + ) + return parser.parse_args(sys.argv[1:]) + + if __name__ == "__main__": _anomaly_analyse() logger.info("Analyse task completed.") diff --git a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py new file mode 100644 index 0000000000000000000000000000000000000000..ef8d4e26c34b75be06241b6f04204a9f6739da19 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py @@ -0,0 +1,361 @@ +# Copyright (c) 2025-2026, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import os +import re +from collections import OrderedDict, defaultdict +from concurrent.futures import ProcessPoolExecutor, as_completed +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import pytz +from msprobe.core.common.const import MonitorConst +from msprobe.core.common.file_utils import (create_directory, read_csv, + recursive_chmod, remove_path) +from msprobe.core.common.log import logger +from msprobe.core.common.utils import is_int +from msprobe.core.monitor.db_utils import MonitorDB, update_ordered_dict +from msprobe.core.monitor.utils import get_target_output_dir +from tqdm import tqdm + +# Constants +all_data_type_list = [ + "actv", "actv_grad", "exp_avg", "exp_avg_sq", + "grad_unreduced", "grad_reduced", "param_origin", "param_updated", "other" +] + + + +@dataclass +class CSV2DBConfig: + """Configuration for CSV to database conversion""" + monitor_path: str + time_start: Optional[str] = None + time_end: Optional[str] = None + process_num: int = 1 + data_type_list: Optional[List[str]] = None + output_dirpath: Optional[str] = None + step_partition: int = 500 + + +def validate_process_num(process_num: int) -> None: + """Validate process number parameter""" + if not is_int(process_num) or process_num <= 0: + raise ValueError("process_num must be a positive integer") + if process_num > MonitorConst.MAX_PROCESS_NUM: + raise ValueError(f"Maximum supported process_num is {MonitorConst.MAX_PROCESS_NUM}") + + +def validate_step_partition(step_partition: int) -> None: + if not isinstance(step_partition, int): + raise TypeError("step_partition must be integer") + if not MonitorConst.MIN_PARTITION <= step_partition <= MonitorConst.MAX_PARTITION: + raise ValueError( + f"step_partition must be between {MonitorConst.MIN_PARTITION} ", + f"and {MonitorConst.MAX_PARTITION}, got {step_partition}" + ) + + +def validate_data_type_list(data_type_list: Optional[List[str]]) -> None: + """Validate data type list parameter""" + if data_type_list is None or not data_type_list: + logger.info(f"Using default data types: {all_data_type_list}") + return + + if not isinstance(data_type_list, list): + raise ValueError("data_type_list must be a list") + + invalid_types = [t for t in data_type_list if t not in all_data_type_list] + if invalid_types: + raise ValueError(f"Unsupported data types: {invalid_types}") + + +def get_info_from_filename(file_name, metric_list=None): + metric_name = "_".join(file_name.split('_')[:-1]) + if metric_list and metric_name not in metric_list: + return "", 0, 0 + match = re.match(f"{metric_name}{MonitorConst.CSV_FILE_PATTERN}", file_name) + if not match: + return "", 0, 0 + step_start, step_end = match.groups() + return metric_name, step_start, step_end + + +def _pre_scan_single_rank(rank: int, files: List[str]) -> Dict: + """Pre-scan files for a single rank to collect metadata""" + metrics = set() + min_step = None + max_step = 0 + metric_stats = defaultdict(set) + targets = OrderedDict() + + for file_path in files: + file_name = os.path.basename(file_path) + metric_name, step_start, step_end = get_info_from_filename(file_name) + if not metric_name: + continue + step_start, step_end = int(step_start), int(step_end) + + metrics.add(metric_name) + min_step = min( + step_start if min_step is None else min_step, step_start) + max_step = max(max_step, step_end) + + data = read_csv(file_path) + stats = [k for k in data.keys() if k in MonitorConst.OP_MONVIS_SUPPORTED] + metric_stats[metric_name].update(stats) + + for row_id, row in data.iterrows(): + try: + name = row[MonitorConst.HEADER_NAME] + vpp_stage = int(row['vpp_stage']) + micro_step = int(row.get('micro_step', MonitorConst.DEFAULT_INT_VALUE)) + except (ValueError, KeyError) as e: + logger.warning( + f"CSV conversion failed | file={file_path}:{row_id+2} | error={str(e)}") + continue + target = (name, vpp_stage, micro_step) + if target not in targets: + targets[target] = None + + return { + 'max_rank': int(rank), + 'metrics': metrics, + 'min_step': min_step, + 'max_step': max_step, + 'metric_stats': metric_stats, + 'targets': list(targets.keys()) + } + + +def _pre_scan(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list: List[str], workers: int = 1): + """Pre-scan all targets, metrics, and statistics""" + logger.info("Scanning dimensions...") + rank_files = defaultdict(list) + + # Collect files for each rank + for rank, dir_path in data_dirs.items(): + files = os.listdir(dir_path) + for file in files: + metric_name, _, _ = get_info_from_filename( + file, metric_list=data_type_list) + if not metric_name: + continue + rank_files[rank].append(os.path.join(dir_path, file)) + + # Parallel pre-scan + with ProcessPoolExecutor(max_workers=workers) as executor: + futures = { + executor.submit(_pre_scan_single_rank, rank, files): rank + for rank, files in rank_files.items() + } + + results = [] + with tqdm(total=len(futures), desc="Pre-scanning ranks") as pbar: + for future in as_completed(futures): + rank = futures[future] + try: + result = future.result() + results.append(result) + except Exception as e: + logger.error( + f"Error pre-scanning rank {rank}: {str(e)}") + pbar.update(1) + + # Aggregate results + targets = OrderedDict() + metrics = set() + min_step = None + max_step = 0 + max_rank = 0 + metric_stats = defaultdict(set) + + for rank_result in results: + max_rank = max(max_rank, rank_result['max_rank']) + metrics.update(rank_result['metrics']) + min_step = min( + min_step if min_step is not None else rank_result['min_step'], + rank_result['min_step'] + ) + max_step = max(max_step, rank_result['max_step']) + + for metric, stats in rank_result['metric_stats'].items(): + metric_stats[metric].update(stats) + + targets = update_ordered_dict(targets, rank_result['targets']) + + monitor_db.insert_dimensions( + targets, metrics, metric_stats, min_step=min_step, max_step=max_step) + monitor_db.update_global_stats( + max_rank=max_rank, min_step=min_step, max_step=max_step) + return rank_files + + +def process_single_rank( + task: Tuple[int, List[str]], + metric_id_dict: Dict[str, Tuple[int, List[str]]], + target_dict: Dict[Tuple[str, int, int], int], + step_partition_size: int, + db_path: str +) -> int: + """Process data import for a single rank""" + rank, files = task + db = MonitorDB(db_path, step_partition_size=step_partition_size) + total_inserted = 0 + table_batches = defaultdict(list) + + for file in files: + filename = os.path.basename(file) + metric_name, _, _ = get_info_from_filename(filename) + if not metric_name: + continue + metric_info = metric_id_dict.get(metric_name) + if not metric_info: + continue + + metric_id, stats = metric_info + + for row_id, row in read_csv(file).iterrows(): + try: + # Parse row data + name = row.get(MonitorConst.HEADER_NAME) + vpp_stage = int(row['vpp_stage']) + micro_step = int(row.get('micro_step', MonitorConst.DEFAULT_INT_VALUE)) + target_id = target_dict.get((name, vpp_stage, micro_step)) + if not target_id: + continue + + step = int(row['step']) + table_name, _, _ = db.get_metric_table_name(metric_id, step) + # Prepare row data + row_data = [rank, step, target_id] + row_data.extend( + float(row[stat]) if stat in row else None + for stat in stats + ) + except (ValueError, KeyError) as e: + logger.error( + f"CSV conversion failed | file={file}:{row_id+2} | error={str(e)}") + continue + + table_batches[table_name].append(tuple(row_data)) + # Batch insert when threshold reached + if len(table_batches[table_name]) >= MonitorConst.BATCH_SIZE: + inserted = db.insert_rows( + table_name, table_batches[table_name]) + if inserted is not None: + total_inserted += inserted + table_batches[table_name] = [] + + # Insert remaining data + for table_name, batch in table_batches.items(): + if batch: + inserted = db.insert_rows(table_name, batch) + if inserted is not None: + total_inserted += inserted + + logger.info(f"Rank {rank} inserted {total_inserted} rows") + return total_inserted + + +def import_data(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list: List[str], workers: int = 4) -> bool: + """Main method to import data into database""" + # 1. Pre-scan to get rank tasks + monitor_db.init_schema() + rank_tasks = _pre_scan(monitor_db, data_dirs, data_type_list, workers) + if not rank_tasks: + logger.error("No valid data files found during pre-scan") + return False + + # 2. Get metric and target mappings + try: + metric_id_dict = monitor_db.get_metric_mapping() + target_dict = monitor_db.get_target_mapping() + except Exception as e: + logger.error(f"Failed to get database mappings: {str(e)}") + return False + + # 3. Process data for each rank in parallel + total_files = sum(len(files) for files in rank_tasks.values()) + logger.info(f"Starting data import for {len(rank_tasks)} ranks," + f"{total_files} files..." + ) + all_succeeded = True + with ProcessPoolExecutor(max_workers=workers) as executor: + futures = { + executor.submit( + process_single_rank, + (rank, files), + metric_id_dict, + target_dict, + monitor_db.step_partition_size, + monitor_db.db_path): rank + for rank, files in rank_tasks.items() + } + + with tqdm(as_completed(futures), total=len(futures), desc="Import progress") as pbar: + for future in pbar: + rank = futures[future] + try: + inserted = future.result() + pbar.set_postfix_str( + f"Rank {rank}: inserted {inserted} rows") + except Exception as e: + logger.error( + f"Failed to process Rank {rank}: {str(e)}") + all_succeeded = False + return all_succeeded + + +def csv2db(config: CSV2DBConfig) -> None: + """Main function to convert CSV files to database""" + validate_process_num(config.process_num) + validate_step_partition(config.step_partition) + validate_data_type_list(config.data_type_list) + + target_output_dirs = get_target_output_dir( + config.monitor_path, config.time_start, config.time_end) + + if config.output_dirpath is None: + local_tz = pytz.timezone("Asia/Shanghai") + cur_time = datetime.datetime.now(local_tz).strftime("%b%d_%H-%M-%S") + config.output_dirpath = os.path.join( + config.monitor_path, f"{cur_time}-csv2db") + + create_directory(config.output_dirpath) + db_path = os.path.join(config.output_dirpath, "monitor_metrics.db") + + if os.path.exists(db_path): + remove_path(db_path) + logger.warning(f"Existing path {db_path} will be recovered") + + db = MonitorDB(db_path, step_partition_size=config.step_partition) + + result = import_data( + db, + target_output_dirs, + config.data_type_list if config.data_type_list else all_data_type_list, + workers=config.process_num + ) + recursive_chmod(config.output_dirpath) + if result: + logger.info( + f"Data import completed. Output saved to: {config.output_dirpath}") + else: + logger.warning( + f"Data import may be incomplete. Output directory: {config.output_dirpath} " + f"(Some records might have failed)" + ) diff --git a/debug/accuracy_tools/msprobe/core/monitor/db_utils.py b/debug/accuracy_tools/msprobe/core/monitor/db_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cf623311b1514d45bd1edc17f060763031f237a7 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/monitor/db_utils.py @@ -0,0 +1,278 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import OrderedDict +from collections.abc import Iterable +from typing import Dict, List, Optional, Set, Tuple + +from msprobe.core.common.const import MonitorConst +from msprobe.core.common.db_manager import DBManager + + +def update_ordered_dict(main_dict: OrderedDict, new_list: List) -> OrderedDict: + """Update ordered dictionary with new items""" + for item in new_list: + if item not in main_dict: + main_dict[item] = None + return main_dict + + +def get_ordered_stats(stats: Iterable) -> List[str]: + """Get statistics in predefined order""" + if not isinstance(stats, Iterable): + return [] + return [stat for stat in MonitorConst.OP_MONVIS_SUPPORTED if stat in stats] + + +class MonitorSql: + """数据库表参数类""" + + @staticmethod + def create_monitoring_targets_table(): + """监控目标表""" + return """ + CREATE TABLE IF NOT EXISTS monitoring_targets ( + target_id INTEGER PRIMARY KEY AUTOINCREMENT, + target_name TEXT NOT NULL, + vpp_stage INTEGER NOT NULL, + micro_step INTEGER NOT NULL DEFAULT 0, + UNIQUE(target_name, vpp_stage, micro_step) + )""" + + @staticmethod + def create_monitoring_metrics_table(): + """监控指标表""" + return """ + CREATE TABLE IF NOT EXISTS monitoring_metrics ( + metric_id INTEGER PRIMARY KEY AUTOINCREMENT, + metric_name TEXT UNIQUE NOT NULL + )""" + + @staticmethod + def get_metric_mapping_sql(): + return """ + SELECT m.metric_id, m.metric_name, GROUP_CONCAT(ms.stat_name) as stats + FROM monitoring_metrics m + LEFT JOIN metric_stats ms ON m.metric_id = ms.metric_id + GROUP BY m.metric_id + """ + + @staticmethod + def create_metric_stats_table(): + """指标统计表""" + return """ + CREATE TABLE IF NOT EXISTS metric_stats ( + metric_id INTEGER NOT NULL, + stat_name TEXT NOT NULL, + PRIMARY KEY (metric_id, stat_name), + FOREIGN KEY (metric_id) REFERENCES monitoring_metrics(metric_id) + ) WITHOUT ROWID""" + + @staticmethod + def create_global_stat_table(): + return """ + CREATE TABLE IF NOT EXISTS global_stats ( + stat_name TEXT PRIMARY KEY, + stat_value INTEGER NOT NULL + ) WITHOUT ROWID""" + + @classmethod + def get_table_definition(cls, table_name=""): + """ + 获取表定义SQL + :param table_name: 表名 + :return: 建表SQL语句 + :raises ValueError: 当表名不存在时 + """ + table_creators = { + "monitoring_targets": cls.create_monitoring_targets_table, + "monitoring_metrics": cls.create_monitoring_metrics_table, + "metric_stats": cls.create_metric_stats_table, + "global_stats": cls.create_global_stat_table, + } + if not table_name: + return [table_creators.get(table, lambda x: "")() for table in table_creators] + if table_name not in table_creators: + raise ValueError(f"Unsupported table name: {table_name}") + return table_creators[table_name]() + + @classmethod + def get_metric_table_definition(cls, table_name, stats, patition=None): + stat_columns = [f"{stat} REAL DEFAULT NULL" for stat in stats] + if patition and len(patition) == 2: + partition_start_step, partition_end_step = patition + step_column = f"""step INTEGER NOT NULL CHECK(step BETWEEN {partition_start_step} + AND {partition_end_step}),""" + else: + step_column = "step INTEGER NOT NULL" + create_sql = f""" + CREATE TABLE {table_name} ( + rank INTEGER NOT NULL, + {step_column} + target_id INTEGER NOT NULL, + {', '.join(stat_columns)}, + PRIMARY KEY (rank, step, target_id), + FOREIGN KEY (target_id) REFERENCES monitoring_targets(target_id) + ) WITHOUT ROWID + """ + return create_sql + + +class MonitorDB: + """Main class for monitoring database operations""" + + def __init__(self, db_path: str, step_partition_size: int = 500): + self.db_path = db_path + self.db_manager = DBManager(db_path) + self.step_partition_size = step_partition_size + + def get_metric_table_name(self, metric_id: int, step: int) -> str: + """Generate metric table name""" + step_start = ( + step // self.step_partition_size) * self.step_partition_size + step_end = step_start + self.step_partition_size - 1 + return f"metric_{metric_id}_step_{step_start}_{step_end}", step_start, step_end + + def init_schema(self) -> None: + """Initialize database schema""" + self.db_manager.execute_multi_sql(MonitorSql.get_table_definition()) + + # Insert initial global stats + global_stats = [ + ('max_rank', 0), + ('min_step', 0), + ('max_step', 0), + ('step_partition_size', self.step_partition_size) + ] + self.db_manager.insert_data("global_stats", global_stats) + + def insert_dimensions( + self, + targets: OrderedDict, + metrics: Set[str], + metric_stats: Dict[str, Set[str]], + min_step: Optional[int] = None, + max_step: int = None, + ) -> None: + """Insert dimension data into database""" + # Insert targets + self.db_manager.insert_data( + "monitoring_targets", + [(name, vpp_stage, micro_step) + for (name, vpp_stage, micro_step) in targets], + key_list=["target_name", "vpp_stage", "micro_step"] + ) + + # Insert metrics + self.db_manager.insert_data( + "monitoring_metrics", + [(metric,) for metric in metrics], + key_list=["metric_name"] + ) + + # Insert metric-stat relationships + for metric, stats in metric_stats.items(): + metric_id = self._get_metric_id(metric) + ordered_stats = get_ordered_stats(stats) + + self.db_manager.insert_data( + "metric_stats", + [(metric_id, stat) for stat in ordered_stats], + key_list=["metric_id", "stat_name"] + ) + + # Create metric tables for each partition + if min_step is not None and max_step is not None: + first_partition = min_step // self.step_partition_size + last_partition = max_step // self.step_partition_size + + for partition in range(first_partition, last_partition + 1): + step_start = partition * self.step_partition_size + self.create_metric_table( + metric_id, step_start, ordered_stats) + + def insert_rows(self, table_name, rows): + if not self.db_manager.table_exists(table_name): + raise RuntimeError(f"{table_name} not existed in {self.db_path}") + inserted = self.db_manager.insert_data(table_name, rows) + inserted = 0 if inserted is None else inserted + return inserted + + def create_metric_table(self, metric_id: int, step: int, stats: List[str]) -> str: + """Create metric table for a specific partition""" + table_name, partition_start_step, partition_end_step = self.get_metric_table_name( + metric_id, + step + ) + if self.db_manager.table_exists(table_name): + return table_name + + create_sql = MonitorSql.get_metric_table_definition( + table_name, stats, patition=( + partition_start_step, partition_end_step) + ) + self.db_manager.execute_sql(create_sql) + return table_name + + def update_global_stats(self, max_rank: int = None, min_step: Optional[int] = None, max_step: int = None) -> None: + """Update global statistics""" + updates = [ + ("max_rank", max_rank), + ("min_step", min_step), + ("max_step", max_step) + ] + for stat_name, value in updates: + if not value: + continue + self.db_manager.update_data( + table_name="global_stats", + updates={"stat_value": value}, + where={"stat_name": stat_name} + ) + + def get_metric_mapping(self) -> Dict[str, Tuple[int, List[str]]]: + """Get metric name to ID mapping with statistics""" + results = self.db_manager.execute_sql( + MonitorSql.get_metric_mapping_sql() + ) + + return { + row["metric_name"]: ( + row["metric_id"], + get_ordered_stats(row["stats"].split(",") + ) if row["stats"] else [] + ) for row in results + } + + def get_target_mapping(self) -> Dict[Tuple[str, int, int], int]: + """Get target mapping dictionary""" + results = self.db_manager.select_data( + table_name="monitoring_targets", + columns=["target_id", "target_name", "vpp_stage", "micro_step"] + ) + if not results: + return {} + return { + (row["target_name"], row["vpp_stage"], row["micro_step"]): row["target_id"] + for row in results + } + + def _get_metric_id(self, metric_name: str) -> Optional[int]: + """Get metric ID by name""" + result = self.db_manager.select_data( + table_name="monitoring_metrics", + columns=["metric_id"], + where={"metric_name": metric_name} + ) + return result[0]["metric_id"] if result else None diff --git a/debug/accuracy_tools/msprobe/core/monitor/utils.py b/debug/accuracy_tools/msprobe/core/monitor/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2c7a20a040aa47762fec1b1ca37944fe64d1c793 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/monitor/utils.py @@ -0,0 +1,372 @@ +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import namedtuple +from datetime import timezone, timedelta +from functools import wraps +from datetime import datetime +import os +import re + +from msprobe.core.common.const import MonitorConst +from msprobe.core.common.log import logger +from msprobe.core.common.utils import is_int +from msprobe.core.common.file_utils import check_file_or_directory_path, recursive_chmod + + +beijing_tz = timezone(timedelta(hours=8)) +MVResult = namedtuple('MVResult', ("exp_avg", "exp_avg_sq", "update", "ratio")) + + +class MsgConst: + """ + Class for log messages const + """ + SPECIAL_CHAR = ["\n", "\r", "\u007F", "\b", "\f", "\t", "\u000B", "%08", "%0a", "%0b", "%0c", "%0d", "%7f"] + + +def get_output_base_dir(): + return os.getenv(MonitorConst.MONITOR_OUTPUT_DIR, MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR) + + +def filter_special_chars(func): + @wraps(func) + def func_level(msg): + for char in MsgConst.SPECIAL_CHAR: + msg = msg.replace(char, '_') + return func(msg) + + return func_level + + +def validate_ops(ops): + if not isinstance(ops, list): + raise TypeError("ops should be a list") + valid_ops = [] + for op in ops: + if op not in MonitorConst.OP_LIST: + logger.warning(f"op {op} is not supported. Optional ops: {MonitorConst.OP_LIST}") + continue + valid_ops.append(op) + if not valid_ops: + default_op = MonitorConst.OP_LIST[0] + valid_ops.append(default_op) + logger.info_on_rank_0(f"There is no valid ops, default op {default_op} is used") + # 增加默认shape和dtype参数 + if "shape" not in valid_ops: + valid_ops.append("shape") + if "dtype" not in valid_ops: + valid_ops.append("dtype") + return valid_ops + + +def validate_ndigits(ndigits): + if not ndigits: + return + if not is_int(ndigits) or ndigits <= 0: + raise ValueError(f"ndigits({ndigits}) is not a positive integer, current is: {ndigits}.") + if ndigits > MonitorConst.MAX_NDIGITS: + raise ValueError(f"The maximum supported ndigits is {MonitorConst.MAX_NDIGITS}, current value: {ndigits}.") + + +def validate_ranks(ranks): + if not isinstance(ranks, list): + raise TypeError("module_ranks should be a list") + for rank in ranks: + if not isinstance(rank, int) or isinstance(rank, bool): + raise TypeError(f"element in module_ranks should be a int, get {type(rank)}") + + +def validate_targets(targets): + if not isinstance(targets, dict): + raise TypeError('targets in config.json should be a dict') + for module_name, field in targets.items(): + if not isinstance(module_name, str): + raise TypeError('key of targets should be module_name[str] in config.json') + if not isinstance(field, dict): + raise TypeError('values of targets should be cared filed e.g. {"input": "tensor"} in config.json') + + +def validate_l2_targets(targets): + if not isinstance(targets, dict): + raise TypeError('l2_targets in config.json should be a dict') + for hook_name, target_list in targets.items(): + if hook_name not in MonitorConst.L2_HOOKS: + raise TypeError(f'key of l2_targtes must be in {MonitorConst.L2_HOOKS}, got {hook_name}') + if not isinstance(target_list, list): + raise TypeError('values of l2_targets should be a list in config.json') + for item in target_list: + if not isinstance(item, str): + raise TypeError(f'item of "{hook_name}" in l2_targets should be module_name[str] in config.json') + + +def validate_recording_l2_features(recording_l2_features): + if not isinstance(recording_l2_features, bool): + raise TypeError("recording_l2_features should be a bool") + + +def validate_sa_order(sa_order): + if isinstance(sa_order, str): + sa_order = sa_order.replace(' ', '') + if sa_order not in MonitorConst.SA_ORDERS: + raise TypeError(f'sa_order must be in {MonitorConst.SA_ORDERS}, got {sa_order}') + + +def validate_print_struct(print_struct): + if not isinstance(print_struct, bool): + raise TypeError("print_struct should be a bool") + + +def validate_ur_distribution(ur_distribution): + if not isinstance(ur_distribution, bool): + raise TypeError('ur_distribution should be a bool') + + +def validate_xy_distribution(xy_distribution): + if not isinstance(xy_distribution, bool): + raise TypeError('xy_distribution should be a bool') + + +def validate_wg_distribution(wg_distribution): + if not isinstance(wg_distribution, bool): + raise TypeError('wg_distribution should be a bool') + + +def validate_mg_distribution(mg_distribution): + if not isinstance(mg_distribution, bool): + raise TypeError('mg_distribution should be a bool') + + +def validate_param_distribution(param_distribution): + if not isinstance(param_distribution, bool): + raise TypeError('param_distribution should be a bool') + + +def validate_cc_distribution(cc_distribution): + if not isinstance(cc_distribution, dict): + raise TypeError('cc_distribution should be a dictionary') + for key, value in cc_distribution.items(): + if key == 'enable': + if not isinstance(value, bool): + raise TypeError('cc_distribution enable should be a bool') + elif key == 'cc_codeline': + if not isinstance(value, list): + raise TypeError('cc_distribution cc_codeline should be a list') + elif key == 'cc_pre_hook': + if not isinstance(value, bool): + raise TypeError('cc_distribution cc_pre_hook should be a bool') + elif key == 'cc_log_only': + if not isinstance(value, bool): + raise TypeError('cc_distribution cc_log_only should be a bool') + else: + raise TypeError(f'{key} of cc_distribution is not supported.') + + +def validate_squash_name(squash_name): + if not isinstance(squash_name, bool): + raise TypeError('squash_name should be a bool') + + +def validate_alert(alert): + if not isinstance(alert, dict): + raise TypeError('alert should be a dictionary') + rules = alert.get('rules') + if rules and isinstance(rules, list): + for rule in rules: + rule_name = rule.get("rule_name") + if rule_name and rule_name not in MonitorConst.RULE_NAME: + raise TypeError(f"{rule_name} is not supported") + args = rule.get("args") + if args and isinstance(args, dict): + threshold = args.get("threshold") + if not isinstance(threshold, (float, int)) or threshold < 0: + raise TypeError('threshold must be float and not less than 0') + dump = alert.get('dump') + if dump and not isinstance(dump, bool): + raise TypeError('dump must be bool.') + + +def validate_step_count_per_record(step_count_per_record): + if not is_int(step_count_per_record): + raise TypeError('step_count_per_record must be int.') + if step_count_per_record < 1: + raise ValueError("step_count_per_record must greater than 0") + if step_count_per_record > 1e6: + raise ValueError("step_count_per_record must smaller than 1e6") + + +def validate_dynamic_on(dynamic_on): + if not isinstance(dynamic_on, bool): + raise TypeError('dynamic_on should be a bool') + + +def validate_monitor_mbs_grad(monitor_mbs_grad): + if not isinstance(monitor_mbs_grad, bool): + logger.warning(f'monitor_mbs_grad should be a bool, actual value is {monitor_mbs_grad}.') + return False + return monitor_mbs_grad + + +def validate_append_output(append_output): + if not isinstance(append_output, list): + raise TypeError('append_output should be a list') + if len(append_output) > 0 and len(append_output) != 2: + raise ValueError('append_output should be empty or contain exactly 2 elements') + + +def validate_config(config): + config['ops'] = validate_ops(config.get('ops', [])) + + ndigits = config.get('ndigits') + validate_ndigits(ndigits) + + eps = config.get('eps', 1e-8) + if not isinstance(eps, float): + raise TypeError("eps should be a float") + + ranks = config.get("module_ranks", []) + validate_ranks(ranks) + + targets = config.get("targets", {}) + validate_targets(targets) + + l2_targets = config.get("l2_targets", {}) + validate_l2_targets(l2_targets) + + recording_l2_features = config.get("recording_l2_features", False) + validate_recording_l2_features(recording_l2_features) + + sa_order = config.get("sa_order", "s,b,h,d") + validate_sa_order(sa_order) + + print_struct = config.get('print_struct', False) + validate_print_struct(print_struct) + + ur_distribution = config.get('ur_distribution', False) + validate_ur_distribution(ur_distribution) + + xy_distribution = config.get('xy_distribution', False) + validate_xy_distribution(xy_distribution) + + wg_distribution = config.get('wg_distribution', False) + validate_wg_distribution(wg_distribution) + + mg_distribution = config.get('mg_distribution', False) + validate_mg_distribution(mg_distribution) + + param_distribution = config.get('param_distribution', False) + validate_param_distribution(param_distribution) + + cc_distribution = config.get('cc_distribution', {}) + validate_cc_distribution(cc_distribution) + + alert = config.get('alert', {}) + validate_alert(alert) + + step_count_per_record = config.get('step_count_per_record', 1) + validate_step_count_per_record(step_count_per_record) + + config["start_step"] = validate_int_arg(config.get("start_step"), "start_step", + MonitorConst.DEFAULT_START_STEP, MonitorConst.DEFAULT_START_STEP) + config["collect_times"] = validate_int_arg(config.get("collect_times"), "collect_times", + MonitorConst.DEFAULT_MIN_COLLECT_TIMES, + MonitorConst.DEFAULT_MAX_COLLECT_TIMES) + config["step_interval"] = validate_int_arg(config.get("step_interval"), "step_interval", + MonitorConst.DEFAULT_STEP_INTERVAL, MonitorConst.DEFAULT_STEP_INTERVAL) + + squash_name = config.get('squash_name', True) + validate_squash_name(squash_name) + + time_tags = config.get("append_output", []) + validate_append_output(time_tags) + + config["monitor_mbs_grad"] = validate_monitor_mbs_grad(config.get('monitor_mbs_grad', False)) + + dynamic_on = config.get('dynamic_on', False) + validate_dynamic_on(dynamic_on) + + if not targets: + if xy_distribution: + config["all_xy"] = True + config["targets"] = {"": {}} + + +def time_str2time_digit(time_str): + time_format = '%b%d_%H-%M-%S' + if not isinstance(time_str, str): + raise TypeError(f"time_str:{time_str} should be a str") + try: + time_digit = datetime.strptime(time_str, time_format) + except Exception as e: + raise RuntimeError(f"illegal timestamp: {time_str}, timestamp should be prefix \ + of existing output dirpath, like 'Dec03_21-34-40'.") from e + return time_digit + + +def get_target_output_dir(monitor_path, time_start, time_end): + check_file_or_directory_path(monitor_path, isdir=True) + time_start = time_str2time_digit(time_start) if time_start is not None else time_start + time_end = time_str2time_digit(time_end) if time_end is not None else time_end + if time_start and time_end and time_start > time_end: + raise ValueError(f"time_start({time_start}) greater than time_end({time_end})") + result = {} + for dirname in os.listdir(monitor_path): + match = re.match(MonitorConst.OUTPUT_DIR_PATTERN, dirname) + if not match: + continue + time_tag = match.group(1) + rank = match.group(2) + target_time = time_str2time_digit(time_tag) + start_ok = time_start is None or target_time >= time_start + end_ok = time_end is None or target_time <= time_end + if start_ok and end_ok: + result[rank] = os.path.join(monitor_path, dirname) + return result + + +def chmod_tensorboard_dir(path): + """ + format配置为tensorboard时,需要补充文件权限设置 + """ + try: + recursive_chmod(path) + except Exception as e: + logger.warning(f"chmod tensorboard dir wrong because {e}, not updated, please check!!!") + + +def validate_set_monitor(grad_acc_steps, start_iteration): + """ + validate parameters of set_monitor. + """ + grad_acc_steps = validate_int_arg(grad_acc_steps, "grad_acc_steps", + MonitorConst.DEFAULT_GRAD_ACC_STEPS, MonitorConst.DEFAULT_GRAD_ACC_STEPS) + + start_iteration = validate_int_arg(start_iteration, "start_iteration", + MonitorConst.DEFAULT_START_ITERATION, MonitorConst.DEFAULT_START_ITERATION) + return grad_acc_steps, start_iteration + + +def validate_int_arg(value, name, minimum, default_value): + """Validate int args, if any exception occurs, use the default value.""" + if value is None: + return default_value + try: + if not is_int(value): + raise TypeError(f"{name} must be int") + if value < minimum: + raise ValueError(f"{name} must greater than {minimum}") + except Exception as e: + value = default_value + logger.warning(f"Validate {name} failed, {e}, replaced with default value {value}.") + return value diff --git a/debug/accuracy_tools/msprobe/core/overflow_check/abnormal_scene.py b/debug/accuracy_tools/msprobe/core/overflow_check/abnormal_scene.py index 54dae2576e48b7ad75df97fa046e6e90bbd144c2..0e0c50cc6aa0cf93f963a699ee36c13d888ec320 100644 --- a/debug/accuracy_tools/msprobe/core/overflow_check/abnormal_scene.py +++ b/debug/accuracy_tools/msprobe/core/overflow_check/abnormal_scene.py @@ -20,6 +20,7 @@ import numpy as np from msprobe.core.overflow_check.api_info import APIInfo from msprobe.core.overflow_check.level import OverflowLevel from msprobe.core.overflow_check.utils import has_nan_inf +from msprobe.core.common.decorator import recursion_depth_decorator class AnomalyScene: @@ -35,6 +36,7 @@ class AnomalyScene: raise NotImplementedError @staticmethod + @recursion_depth_decorator("AbnormalScene: AnomalyScene._has_anomaly") def _has_anomaly(data: Union[Dict, Any]) -> bool: """检查张量是否包含异常值""" if isinstance(data, dict): diff --git a/debug/accuracy_tools/msprobe/core/service.py b/debug/accuracy_tools/msprobe/core/service.py new file mode 100644 index 0000000000000000000000000000000000000000..a62d24974bb16843a4af80a0d7b555168e4a80e6 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/service.py @@ -0,0 +1,348 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import functools +import os +from abc import ABC, abstractmethod +from collections import defaultdict + +from msprobe.core.common.exceptions import DistributedNotInitializedError +from msprobe.core.common.file_utils import create_directory +from msprobe.core.common.runtime import Runtime +from msprobe.core.common.utils import Const, print_tools_ends_info, DumpPathAggregation +from msprobe.core.data_dump.api_registry import ApiRegistry +from msprobe.core.data_dump.data_collector import build_data_collector +from msprobe.core.kernel_dump.kernel_config import create_kernel_config_json + + +class BaseService(ABC): + def __init__(self, config): + self.config = copy.deepcopy(config) + self.config.level = getattr(config, 'level_ori', config.level) # 兼容MindSpore配置 + self.model = None + self.data_collector = build_data_collector(self.config) + self.current_iter = 0 + self.loop = 0 + self.init_step = 0 + self.cur_token_id = 0 + self.first_start = True + self.primitive_switch = False + self.current_rank = None + self.dump_iter_dir = None + self.should_stop_service = False + self.hooked_modules = [] + self.ori_customer_func = {} + self.debug_variable_counter = None + self.currrent_step_first_debug_save = True + self.logger = None # 子类中注入 + self.api_register = None # 子类中注入 + self.api_template = None # 子类中注入 + self.hook_manager = None # 子类中注入 + self._init_specific_components() + self._register_api_hook() + + @property + def _is_debug_level(self): + return self.config.level == Const.LEVEL_DEBUG + + @property + def _is_l2_level(self): + return self.config.level == Const.LEVEL_L2 + + @property + def _is_mix_level(self): + return self.config.level == Const.LEVEL_MIX + + @property + def _is_need_module_hook(self): + return self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L0] + + @property + def _is_need_api_hook(self): + return self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2] + + @property + def _is_no_dump_step(self): + return (self.config.step and self.current_iter not in self.config.step) + + @property + def _is_no_dump_rank(self): + return (self.config.rank and self.current_rank not in self.config.rank) + + @property + def _need_tensor_data(self): + """判断是否需要采集tensor数据""" + return bool( + self.config.task in self.data_collector.tasks_need_tensor_data or + (self.config.task == Const.STATISTICS and self.config.tensor_list) + ) + + @property + @abstractmethod + def _get_framework_type(self): + """获取框架类型""" + pass + + @staticmethod + @abstractmethod + def _get_current_rank(): + """获取当前rank_id""" + pass + + @staticmethod + def _change_jit_switch(status): + """修改JitDump开关,mindspore子类重写""" + pass + + def start(self, model=None, token_range=None): + """通用start模板""" + self._process_iteration() + if self._is_debug_level: + return + if model: + self.model = model + if self._is_need_module_hook and self.model not in self.hooked_modules: + self._register_module_hook() + self.hooked_modules.append(self.model) + if self._need_stop_service(): + return + Runtime.is_running = True + self.cur_token_id = 0 + if self.first_start: + try: + self.current_rank = self._get_current_rank() + except DistributedNotInitializedError: + self.current_rank = None + Runtime.current_rank = self.current_rank + if self._is_no_dump_rank: + Runtime.is_running = False + return + self._register_hook() + self.first_start = False + + if token_range: + self._register_infer_count_hook(self.model, token_range) + self.logger.info(f"{Const.TOOL_NAME}: debugger.start() is set successfully") + if token_range is None: + self.primitive_switch = True + self._change_jit_switch(True) + self.logger.info(f"Dump switch is turned on at step {self.current_iter}. ") + + self.create_dirs() + self.logger.info(f"Dump data will be saved in {self.dump_iter_dir}.") + + def stop(self): + """通用stop模板""" + if self._is_debug_level or self.should_stop_service: + return + if self._is_no_dump_step or self._is_no_dump_rank: + return + self.logger.info(f"{Const.TOOL_NAME}: debugger.stop() is set successfully. " + "Please set debugger.start() to turn on the dump switch again. ") + Runtime.is_running = False + self.primitive_switch = False + self._change_jit_switch(False) + if self._is_l2_level: + return + + self._process_async_dump() + self.data_collector.write_json() + + def step(self): + """通用step处理""" + if self.should_stop_service: + return + self._process_async_dump() + self.data_collector.write_json() + self.currrent_step_first_debug_save = True + self.loop += 1 + self._reset_status() + + def save(self, variable, name, save_backward): + ''' + Args: + variable: Union[List[variable], dict{str: variable}, mindspore.tensor, str, float, int] + name: str + save_backward: boolean + Return: + void + ''' + if not self._is_debug_level: + return + self.current_iter = self.loop + self.init_step + if self._is_no_dump_step: + return + + if self.currrent_step_first_debug_save: + try: + self.current_rank = self._get_current_rank() + except DistributedNotInitializedError: + self.current_rank = None + + self.create_dirs() + self.debug_variable_counter = defaultdict(int) + self.currrent_step_first_debug_save = False + + count = self.debug_variable_counter[name] + self.debug_variable_counter[name] += 1 + + name_with_count = f"{name}.{count}" + grad_name_with_count = f"{name}_grad.{count}" + + # forward save + self.data_collector.debug_data_collect_forward(variable, name_with_count) + + # backward save + if save_backward: + self.data_collector.debug_data_collect_backward(variable, grad_name_with_count) + + def register_custom_api(self, module, api_name, api_prefix): + self.ori_customer_func[str(module) + Const.SEP + api_name] = getattr(module, api_name) + ApiRegistry.register_custom_api(module, api_name, api_prefix, + functools.partial(self.build_hook, Const.API), self.api_template) + + def restore_custom_api(self, module, api): + ori_func = self.ori_customer_func.get(str(module) + Const.SEP + api) + if ori_func: + setattr(module, api, ori_func) + + def build_hook(self, hook_type, name): + return self.hook_manager.build_hook(hook_type, name) + + def create_dirs(self): + """统一目录创建逻辑""" + create_directory(self.config.dump_path) + if Runtime.run_mode == Const.PYNATIVE_GRAPH_MODE: + self.dump_iter_dir = os.path.join(self.config.dump_path, Const.PYNATIVE_MODE, f"step{self.current_iter}") + else: + self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}") + + cur_rank = self.current_rank if self.current_rank is not None else '' + if self._is_l2_level: + self._create_l2_dirs(cur_rank) + else: + self._create_default_dirs(cur_rank) + + @abstractmethod + def _init_specific_components(self): + """初始化框架特定组件""" + pass + + @abstractmethod + def _register_hook(self): + """注册hook函数""" + pass + + @abstractmethod + def _register_module_hook(self): + """注册模块级别的hook函数""" + + def _need_stop_service(self): + if self.should_stop_service: + return True + end_service = self.config.step and self.current_iter > max(self.config.step) or \ + self.data_collector and self.data_collector.data_processor.is_terminated + if end_service: + self.primitive_switch = False + self._change_jit_switch(False) + Runtime.is_running = False + self.should_stop_service = True + print_tools_ends_info() + return True + if self._is_no_dump_step: + return True + return False + + def _register_api_hook(self): + if self._is_need_api_hook: + self.api_register.initialize_hook(functools.partial(self.build_hook, Const.API)) + self.api_register.register_all_api() + self.logger.info(f"The api {self.config.task} hook function is successfully mounted to the model.") + + def _register_infer_count_hook(self, root_model, token_range): + """ + 通过root_model执行的轮次来判断当前在第几个token + param root_model: 需要采集的推理模型 + param token_range: [start, end], 采集infer的token循环范围,左右皆包含在内 + return: None + """ + def infer_hook(model, args): + if self.cur_token_id == token_range[0]: + Runtime.is_running = True + self.primitive_switch = True + self._change_jit_switch(True) + self.logger.info(f"Current token id: {self.cur_token_id}, start dump infer token.") + elif token_range[0] < self.cur_token_id <= token_range[1]: + self.logger.debug(f"Current token id: {self.cur_token_id}.") + elif self.cur_token_id == token_range[1] + 1: + Runtime.is_running = False + self.primitive_switch = False + self._change_jit_switch(False) + self.logger.info( + f"Current token id: {self.cur_token_id}, exceed token_range, early stop dump infer token.") + self.cur_token_id += 1 + # 此处root_model可以保证为 Module/Cell类型 或 [Module/Cell]类型 + if root_model and isinstance(root_model, list): + root_model = root_model[0] + self.logger.warning("Infer model can only input one to support token_range, choose the first one.") + + root_model.register_forward_pre_hook(infer_hook) + + def _create_l2_dirs(self, cur_rank): + create_directory(self.dump_iter_dir) + kernel_config_path = create_kernel_config_json(self.dump_iter_dir, cur_rank) + self.config.kernel_config_path = kernel_config_path + + def _create_default_dirs(self, cur_rank): + dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}") + create_directory(dump_dir) + + dump_data_dir = None + if self._need_tensor_data: + dump_data_dir = os.path.join(dump_dir, "dump_tensor_data") + create_directory(dump_data_dir) + + self._configure_dump_paths(dump_dir, dump_data_dir) + + def _configure_dump_paths(self, dump_dir, dump_data_dir): + dump_path_aggregation = DumpPathAggregation() + dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json") + dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json") + dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json") + dump_path_aggregation.dump_error_info_path = os.path.join(dump_dir, "dump_error_info.log") + dump_path_aggregation.dump_tensor_data_dir = dump_data_dir + dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json") + dump_path_aggregation.free_benchmark_file_path = os.path.join(dump_dir, "free_benchmark.csv") + self.data_collector.update_dump_paths(dump_path_aggregation) + self.data_collector.initialize_json_file(self._get_framework_type) + + def _process_iteration(self): + """处理迭代计数""" + self.current_iter = self.loop + self.init_step + self.data_collector.update_iter(self.current_iter) + Runtime.current_iter = self.current_iter + + def _process_async_dump(self): + """处理异步dump逻辑""" + if self.config.async_dump and self.config.task in [Const.STATISTICS, Const.TENSOR]: + self.data_collector.data_processor.dump_async_data() + + def _reset_status(self): + """通用状态重置""" + self.data_collector.reset_status() + self.hook_manager.reset_status() + if self._is_l2_level: + self.data_collector.data_processor.reset_status() diff --git a/debug/accuracy_tools/msprobe/core/single_save/__init__.py b/debug/accuracy_tools/msprobe/core/single_save/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msprobe/core/single_save/single_comparator.py b/debug/accuracy_tools/msprobe/core/single_save/single_comparator.py new file mode 100644 index 0000000000000000000000000000000000000000..61095de62cf693d53c4088593f644617476186bb --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/single_save/single_comparator.py @@ -0,0 +1,258 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import multiprocessing +from dataclasses import dataclass + +import numpy as np +import pandas as pd +from tqdm import tqdm + +from msprobe.core.common.file_utils import check_file_or_directory_path, create_directory, load_npy, save_excel +from msprobe.core.common.log import logger +from msprobe.core.common.utils import check_process_num + + +@dataclass +class CompareResult: + max_abs_error: float + max_relative_error: float + same_percentage: float + first_mismatch_index: int + percentage_within_thousandth: float + percentage_within_hundredth: float + + +class SingleComparator: + result_header = [ + 'step', + 'rank', + 'micro_step', + 'id', + 'shape1', + 'shape2', + '相同元素百分比(%)', + '首个不匹配元素索引', + '最大绝对误差', + '最大相对误差', + '误差在千分之一内元素占比(%)', + '误差在百分之一内元素占比(%)' + ] + + @classmethod + def compare(cls, dir1, dir2, output_path="./msprobe_compare_output", num_processes=8): + data_dir1 = os.path.join(dir1, "data") + data_dir2 = os.path.join(dir2, "data") + check_file_or_directory_path(data_dir1, isdir=True) + check_file_or_directory_path(data_dir2, isdir=True) + check_process_num(num_processes) + # 确保输出目录存在,如果不存在则创建 + if not os.path.exists(output_path): + create_directory(output_path) + cls.compare_data(data_dir1, data_dir2, output_path, num_processes) + + @classmethod + def compare_arrays(cls, array1, array2) -> CompareResult: + """ + 比较两个NumPy数组,计算最大绝对误差、最大相对误差和相同元素的百分比 + """ + # 计算每个维度上的最小尺寸 + if array1.ndim != array2.ndim: + array1 = array1.flatten() + array2 = array2.flatten() + min_shape = [min(s1, s2) for s1, s2 in zip(array1.shape, array2.shape)] + # 截取数组到相同的形状 + sliced_array1 = array1[tuple(slice(0, s) for s in min_shape)] + sliced_array2 = array2[tuple(slice(0, s) for s in min_shape)] + + abs_error = np.abs(sliced_array1 - sliced_array2) + max_abs_error = np.max(abs_error) + + # 计算相对误差,处理分母为零的情况 + with np.errstate(divide='ignore', invalid='ignore'): + relative_error = np.abs(sliced_array1 - sliced_array2) / \ + np.maximum(np.abs(sliced_array1), np.abs(sliced_array2)) + relative_error = np.nan_to_num(relative_error) + max_relative_error = np.max(relative_error) + + same_elements = np.sum(sliced_array1 == sliced_array2) + total_elements = sliced_array1.size + same_percentage = (same_elements / total_elements) * 100 + + # 展平数组 + flat_array1 = sliced_array1.flatten() + flat_array2 = sliced_array2.flatten() + + # 计算从第几个元素开始对不上 + mismatch_indices = np.nonzero(flat_array1 != flat_array2)[0] + first_mismatch_index = mismatch_indices[0] if mismatch_indices.size > 0 else None + + # 计算误差在千分之一内的元素占比 + threshold = 0.001 * np.maximum(np.abs(sliced_array1), np.abs(sliced_array2)) + error_within_thousandth = np.sum(abs_error <= threshold) + percentage_within_thousandth = (error_within_thousandth / total_elements) * 100 + + # 计算误差在百分之一内的元素占比 + threshold = 0.01 * np.maximum(np.abs(sliced_array1), np.abs(sliced_array2)) + error_within_hundredth = np.sum(abs_error <= threshold) + percentage_within_hundredth = (error_within_hundredth / total_elements) * 100 + + return CompareResult( + max_abs_error, + max_relative_error, + same_percentage, + first_mismatch_index, + percentage_within_thousandth, + percentage_within_hundredth + ) + + @classmethod + def get_steps(cls, tag_path): + for step_folder in os.listdir(tag_path): + if step_folder.startswith('step'): + try: + step = int(step_folder[4:]) + except Exception as e: + raise RuntimeError(f"parse step number error") from e + yield step, os.path.join(tag_path, step_folder) + + @classmethod + def get_ranks(cls, step_path): + for rank_folder in os.listdir(step_path): + if rank_folder.startswith('rank'): + try: + rank = int(rank_folder[4:]) + except Exception as e: + raise RuntimeError(f"parse rank number error") from e + yield rank, os.path.join(step_path, rank_folder) + + @classmethod + def get_micro_steps(cls, rank_path): + for micro_step_folder in os.listdir(rank_path): + if micro_step_folder.startswith('micro_step'): + try: + micro_step = int(micro_step_folder[10:]) + except Exception as e: + raise RuntimeError(f"parse nicro_step number error") from e + yield micro_step, os.path.join(rank_path, micro_step_folder) + else: + yield 0, rank_path + + @classmethod + def get_arrays(cls, micro_step_path): + for file in os.listdir(micro_step_path): + if file.endswith('.npy'): + try: + parts = file.rsplit('.', 2) + if len(parts) > 1 and parts[-2].isdigit(): + array_id = int(parts[-2]) + else: + array_id = 0 + except ValueError: + array_id = 0 + yield array_id, os.path.join(micro_step_path, file) + + @classmethod + def get_array_paths(cls, dir_path): + """ + 获取目录中所有符合结构的NumPy数组文件路径 + """ + array_paths = {} + if not os.path.exists(dir_path): + return array_paths + for tag in os.listdir(dir_path): + tag_path = os.path.join(dir_path, tag) + if not os.path.isdir(tag_path): + continue + for step, step_path in cls.get_steps(tag_path): + for rank, rank_path in cls.get_ranks(step_path): + for item in os.listdir(rank_path): + next_path = os.path.join(rank_path, item) + if re.match(r"micro_step(\d+)", item): + micro_step = re.match(r"micro_step(\d+)", item).group(1) + for array_id, array_path in cls.get_arrays(next_path): + array_paths.setdefault(tag, []).append( + (step, rank, int(micro_step), array_id, array_path)) + elif re.match(r"\w{1,100}_(\d{1,100})\.npy", item): + array_id = re.match(r"\w{1,100}_(\d{1,100})\.npy", item).group(1) + array_paths.setdefault(tag, []).append((step, rank, 0, int(array_id), next_path)) + else: + array_paths.setdefault(tag, []).append((step, rank, 0, 0, next_path)) + return array_paths + + @classmethod + def compare_single_tag(cls, tag, array_paths1, array_paths2, output_dir): + data = [] + paths1 = array_paths1.get(tag, []) + paths2 = array_paths2.get(tag, []) + path_dict1 = {(step, rank, micro_step, array_id): path for step, rank, micro_step, array_id, path in paths1} + path_dict2 = {(step, rank, micro_step, array_id): path for step, rank, micro_step, array_id, path in paths2} + common_keys = set(path_dict1.keys()) & set(path_dict2.keys()) + for key in common_keys: + try: + array1 = load_npy(path_dict1[key]) + array2 = load_npy(path_dict2[key]) + result = cls.compare_arrays(array1, array2) + step, rank, micro_step, array_id = key + data.append([ + step, rank, micro_step, array_id, + list(array1.shape), list(array2.shape), + result.same_percentage, + result.first_mismatch_index, + result.max_abs_error, + result.max_relative_error, + result.percentage_within_thousandth, + result.percentage_within_hundredth + ]) + except Exception as e: + logger.error(f"Error comparing {path_dict1[key]} and {path_dict2[key]}: {e}") + + try: + df = pd.DataFrame(data, columns=SingleComparator.result_header) + df = df.sort_values(by=['step', 'rank', 'micro_step', 'id']) + # 构建输出文件的完整路径 + output_file_path = os.path.join(output_dir, f'{tag}.xlsx') + save_excel(output_file_path, df) + except Exception as e: + logger.error(f"Error processing tag {tag}: {e}") + + @classmethod + def compare_data(cls, dir1, dir2, output_dir, num_processes=8): + """ + 比较两个目录中的NumPy数组文件,并将结果保存到指定目录的Excel文件中 + """ + + array_paths1 = cls.get_array_paths(dir1) + array_paths2 = cls.get_array_paths(dir2) + + all_tags = set(array_paths1.keys()) | set(array_paths2.keys()) + + with multiprocessing.Pool(processes=num_processes) as pool: + args = [(tag, array_paths1, array_paths2, output_dir) for tag in all_tags] + try: + results = pool.starmap_async(cls.compare_single_tag, args) + with tqdm(total=len(all_tags), desc="Processing data") as pbar: + while not results.ready(): + pbar.n = len(all_tags) - results._number_left + pbar.refresh() + results.wait() + results.get() + except Exception as e: + logger.error(f"Multiprocessing error: {e}") + finally: + pool.close() + pool.join() diff --git a/debug/accuracy_tools/msprobe/core/single_save/single_saver.py b/debug/accuracy_tools/msprobe/core/single_save/single_saver.py new file mode 100644 index 0000000000000000000000000000000000000000..14281e1b513c090299e8aaca3db0320e52742205 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/single_save/single_saver.py @@ -0,0 +1,146 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from collections import defaultdict + +from msprobe.core.common.file_utils import create_directory, save_json +from msprobe.core.common.const import Const +from msprobe.core.common.framework_adapter import FmkAdp +from msprobe.core.common.log import logger + + +support_nested_data_type = (list, tuple, dict) + + +class SingleSave: + _instance = None + + def __new__(cls, dump_path, fmk=Const.PT_FRAMEWORK): + if cls._instance is None: + cls._instance = super(SingleSave, cls).__new__(cls) + FmkAdp.set_fmk(fmk) + create_directory(dump_path) + + cls._instance.dump_path = dump_path + cls._instance.rank = FmkAdp.get_rank_id() + cls._instance.step_count = 0 + cls._instance.tag_count = defaultdict(int) + return cls._instance + + @staticmethod + def _analyze_tensor_data(data, data_name=None, save_dir=None): + ''' + data: Tensor + return: + result_data: with keys {"max", "min", "mean", "norm", "shape"} + ''' + result_data = {} + result_data["max"] = FmkAdp.tensor_max(data) + result_data["min"] = FmkAdp.tensor_min(data) + result_data["mean"] = FmkAdp.tensor_mean(data) + result_data["norm"] = FmkAdp.tensor_norm(data) + result_data["shape"] = list(data.shape) + if save_dir is not None and data_name is not None: + real_save_path = os.path.join(save_dir, data_name + ".npy") + FmkAdp.save_tensor(data, real_save_path) + return result_data + + @classmethod + def save_config(cls, data): + dump_file = os.path.join(cls._instance.dump_path, 'configurations.json') + save_json(dump_file, data, indent=4) + + @classmethod + def save_ex(cls, data, micro_batch=None): + ''' + data: dict{str: Union[Tensor, tuple, list]} + + return: void + ''' + + instance = cls._instance + + if not isinstance(data, dict): + logger.warning("SingleSave data type not valid, " + "should be dict. " + "Skip current save process.") + return + for key, value in data.items(): + if not isinstance(key, str): + logger.warning("key should be string when save data") + continue + if not isinstance(value, support_nested_data_type) and not FmkAdp.is_tensor(value): + logger.warning(f"value should be {support_nested_data_type} or Tensor when save data") + continue + real_dump_dir = os.path.join( + instance.dump_path, + "data", + key, + f"step{instance.step_count}", + f"rank{instance.rank}") + if micro_batch is not None: + real_dump_dir = os.path.join(real_dump_dir, f"micro_step{micro_batch}") + create_directory(real_dump_dir) + + if FmkAdp.is_tensor(value): + result = cls._analyze_tensor_data(value, key, real_dump_dir) + elif isinstance(value, (tuple, list)): + result = cls._analyze_list_tuple_data(value, key, real_dump_dir) + elif isinstance(value, dict): + result = cls._analyze_dict_data(value, key, real_dump_dir) + + result_json = {"data": result} + json_path = os.path.join(real_dump_dir, key + ".json") + save_json(json_path, result_json, indent=4) + + + @classmethod + def step(cls): + instance = cls._instance + instance.tag_count = defaultdict(int) + instance.step_count += 1 + + @classmethod + def save(cls, data): + instance = cls._instance + if not isinstance(data, dict): + logger.warning("SingleSave data type not valid, " + "should be dict. " + "Skip current save process.") + return + for key, value in data.items(): + cls.save_ex({key: value}, micro_batch=instance.tag_count[key]) + instance.tag_count[key] += 1 + + @classmethod + def _analyze_list_tuple_data(cls, data, data_name=None, save_dir=None): + lst = [] + for index, element in enumerate(data): + if not FmkAdp.is_tensor(element): + raise TypeError(f"SingleSave: Unsupported type: {type(element)}") + element_name = data_name + "." + str(index) + lst.append(cls._analyze_tensor_data(element, element_name, save_dir)) + return lst + + @classmethod + def _analyze_dict_data(cls, data, data_name=None, save_dir=None): + result_data = {} + for key, value in data.items(): + if not FmkAdp.is_tensor(value): + raise TypeError(f"SingleSave: Unsupported type: {type(value)}") + key_name = data_name + "." + str(key) + result_data[key] = cls._analyze_tensor_data(value, key_name, save_dir) + return result_data diff --git a/debug/accuracy_tools/msprobe/docs/01.installation.md b/debug/accuracy_tools/msprobe/docs/01.installation.md index 1ab5f6419ba07ec749bad139f874fbc7301fd8b3..92205360e62dea50c5975cb57a7411b23744c901 100644 --- a/debug/accuracy_tools/msprobe/docs/01.installation.md +++ b/debug/accuracy_tools/msprobe/docs/01.installation.md @@ -14,14 +14,20 @@ pip install mindstudio-probe ## 2 下载 whl 包安装 -|版本|发布日期|支持 PyTorch 版本|支持 MindSpore 版本|下载链接|校验码| -|:--:|:--:|:--:|:--:|:--:|:--:| -|1.2.2|2025.2.26|1.11/2.0/2.1/2.2|2.4.0|[mindstudio_probe-1.2.2-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.2/mindstudio_probe-1.2.2-py3-none-any.whl)|1db0cf4572bc0305c68705b74775f652c6cb2c2bedb6c6e57f43e31ab273b288| -|1.2.1|2025.2.07|1.11/2.0/2.1/2.2|2.4.0|[mindstudio_probe-1.2.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.2/mindstudio_probe-1.2.1-py3-none-any.whl)|b64b342118558e0339b39237f88a49b93fd24551b0cb202c872fbfef4260c86b| -|1.2.0|2025.1.13|1.11/2.0/2.1/2.2|2.4.0|[mindstudio_probe-1.2.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.2/mindstudio_probe-1.2.0-py3-none-any.whl)|1e3aeea1706112f6ee52fd1165037936bb209138f0b9ec42ea21e2c1c8942cdc| -|1.1.1|2024.12.09|1.11/2.0/2.1/2.2|2.4.0|[mindstudio_probe-1.1.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.1/mindstudio_probe-1.1.1-py3-none-any.whl)|577b597555dc155b76ba1a62d575c3546004644e140a456c3ba0824d46283735| -|1.1.0|2024.10.14|1.11/2.0/2.1/2.2|2.4.0|[mindstudio_probe-1.1.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.1/mindstudio_probe-1.1.0-py3-none-any.whl)|83a5a9b7c65a357639f8c9636d88c693b4cf0eb590d4f8f5cb56395ba69b1f6d| -|1.0.4|2024.09.09|1.11/2.0/2.1/2.2|2.4.0|[mindstudio_probe-1.0.4-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.0/mindstudio_probe-1.0.4-py3-none-any.whl)|4e1909566a71a855b356597750c20ee43d964a22b2c2b02ac08312a5def75fd6| +| 版本 | 发布日期 |支持 PyTorch 版本|支持 MindSpore 版本| 下载链接 |校验码| +|:-----:|:----------:|:--:|:--:|:----------------------------------------------------------------------------------------------------------------------------------:|:--:| +| 8.2.0 | 2025.9.03 |1.11/2.0/2.1/2.2|2.4.0/2.5.0/2.6.0| [mindstudio_probe-8.2.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/8.2/mindstudio_probe-8.2.0-py3-none-any.whl) |bbc1577d76754adf987069308177d3e0a04e36de9c7f22e75c34cf4ad0ce1af2| +| 8.1.2 | 2025.8.01 |1.11/2.0/2.1/2.2|2.4.0/2.5.0/2.6.0| [mindstudio_probe-8.1.2-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/8.1/mindstudio_probe-8.1.2-py3-none-any.whl) |ff07bb81fddd3b8f3096d119ca1481bde8fdb24f10644def5250caad727448ab| +| 8.1.1 | 2025.6.20 |1.11/2.0/2.1/2.2|2.4.0/2.5.0/2.6.0| [mindstudio_probe-8.1.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/8.1/mindstudio_probe-8.1.1-py3-none-any.whl) |2aad10a243575544d7feef552caf4d06aa93028488ebd0bbc9aa350379da859d| +| 8.1.0 | 2025.6.14 |1.11/2.0/2.1/2.2|2.4.0/2.5.0/2.6.0| [mindstudio_probe-8.1.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/8.1/mindstudio_probe-8.1.0-py3-none-any.whl) |d10c0a57d073bbe7c681042a11e93a0eaaaf5aa45e1cec997142ce2593d77afd| +| 8.0.0 | 2025.5.07 |1.11/2.0/2.1/2.2|2.4.0/2.5.0/2.6.0| [mindstudio_probe-8.0.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/8.0/mindstudio_probe-8.0.0-py3-none-any.whl) |6810eade7ae99e3b24657d5cab251119882decd791aa76a7aeeb94dea767daec| +| 1.3.0 | 2025.4.17 |1.11/2.0/2.1/2.2|2.4.0/2.5.0/2.6.0| [mindstudio_probe-1.3.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.3/mindstudio_probe-1.3.0-py3-none-any.whl) |85dbc5518b5c23d29c67d7b85d662517d0318352f372891f8d91e73e71b439c3| +| 1.2.2 | 2025.3.03 |1.11/2.0/2.1/2.2|2.4.0| [mindstudio_probe-1.2.2-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.2/mindstudio_probe-1.2.2-py3-none-any.whl) |961411bb460d327ea51d6ca4d0c8e8c5565f07c0852d7b8592b781ca35b87212| +| 1.2.1 | 2025.2.07 |1.11/2.0/2.1/2.2|2.4.0| [mindstudio_probe-1.2.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.2/mindstudio_probe-1.2.1-py3-none-any.whl) |b64b342118558e0339b39237f88a49b93fd24551b0cb202c872fbfef4260c86b| +| 1.2.0 | 2025.1.13 |1.11/2.0/2.1/2.2|2.4.0| [mindstudio_probe-1.2.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.2/mindstudio_probe-1.2.0-py3-none-any.whl) |1e3aeea1706112f6ee52fd1165037936bb209138f0b9ec42ea21e2c1c8942cdc| +| 1.1.1 | 2024.12.09 |1.11/2.0/2.1/2.2|2.4.0| [mindstudio_probe-1.1.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.1/mindstudio_probe-1.1.1-py3-none-any.whl) |577b597555dc155b76ba1a62d575c3546004644e140a456c3ba0824d46283735| +| 1.1.0 | 2024.10.14 |1.11/2.0/2.1/2.2|2.4.0| [mindstudio_probe-1.1.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.1/mindstudio_probe-1.1.0-py3-none-any.whl) |83a5a9b7c65a357639f8c9636d88c693b4cf0eb590d4f8f5cb56395ba69b1f6d| +| 1.0.4 | 2024.09.09 |1.11/2.0/2.1/2.2|2.4.0| [mindstudio_probe-1.0.4-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.0/mindstudio_probe-1.0.4-py3-none-any.whl) |4e1909566a71a855b356597750c20ee43d964a22b2c2b02ac08312a5def75fd6| | 1.0.3 | 2024.08.23 | 1.11/2.0/2.1/2.2 | 2.4.0 | [mindstudio_probe-1.0.3-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.0/mindstudio_probe-1.0.3-py3-none-any.whl) | 7060cc141a5b98ef770cd9220995d299393f32a61938261e632c7e8b5160bef2 | | 1.0.2 | 2024.08.09 | 1.11/2.0/2.1/2.2 | 2.4.0 | [mindstudio_probe-1.0.2-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.0/mindstudio_probe-1.0.2-py3-none-any.whl) | e4a980e5d98c426ce5ce9842520d9bc031d3b3de621c74b3d59414cc6e238e0e | | 1.0.1 | 2024.07.25 | 2.0/2.1/2.2 | 2.4.0 | [mindstudio_probe-1.0.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.0/mindstudio_probe-1.0.1-py3-none-any.whl) | b699e224e4d4e3bcf9412c54fa858a1ee370f0d7a2bc69cb3f1273ac14a6dc82 | @@ -34,8 +40,8 @@ sha256sum {name}.whl # 验证whl包,若校验码一致,则whl包在下载中 ```bash pip install ./mindstudio_probe-{version}-py3-none-any.whl # 安装whl包 ``` - -若覆盖安装,请在命令行末尾添加 `--force-reinstall` 参数。 +若覆盖安装,请在命令行末尾添加 `--force-reinstall` 参数。 +上面提供的whl包链接不包含adump功能,如果需要使用adump功能,请参考[从源码安装](#3-从源码安装)下载源码编译whl包。 ## 3 从源码安装 @@ -52,10 +58,39 @@ pip install ./mindstudio_probe*.whl |参数|说明|是否必选| |--|--|:--:| -|--include-mod|指定可选模块,可取值`adump`,表示在编whl包时加入adump模块。默认未配置该参数,表示编基础包。
• adump模块用于MindSpore静态图场景L2级别的dump。
• 仅MindSpore 2.5.0及以上版本支持adump模块。
• 若使用源码安装,编译环境需支持GCC 7或以上版本,和CMAKE 3.14或以上版本。
• 生成的whl包仅限编译时使用的python版本和处理器架构可用。|否| +|--include-mod|指定可选模块,可取值`adump`,表示在编whl包时加入adump模块。默认未配置该参数,表示编基础包。
• adump模块用于MindSpore静态图场景L2级别的dump。
• 仅MindSpore 2.5.0及以上版本支持adump模块。
• 若使用源码安装,编译环境需支持GCC 7.5或以上版本,和CMake 3.14或以上版本。
• 生成的whl包仅限编译时使用的python版本和处理器架构可用。|否| # 特性变更说明 +## 8.1.1 + +【数据采集】 + +- 单点保存能力增强,新增 MindSpore 和 Pytorch 框架异步单点保存,MindSpore 静态图单点保存能力。 +- task 支持 statistic + tenser 模式共存 +- MindSpore 静态图支持模块级 dump 及比对 +- 支持分析整网首个溢出节点 +- 提供对外接口支持用户注册自定义 api 的 dump + +【训练状态监控】 + +- 支持偏离历史值及时告警 +- 支持 nan 值和极大值即时告警 +- 支持堆栈信息采集 +- 支持 mbs 粒度梯度信息采集 +- 支持采集 shape, dtype 信息 +- 激活值监控支持多输入场景 + +【训练检查】 + +- 新增模块,用于[训练前配置项](./docs/31.config_check.md)对齐 +- 支持三方库,环境变量,训练超参,模型权重,输入数据及随机性函数检查 +- 支持 [checkpoint 比对](./docs/32.ckpt_compare.md) + +【单算子API自动生成脚本】 + +- 新增支持 MindSpore 框架 + ## 1.2.0 【数据采集】 @@ -80,8 +115,6 @@ pip install ./mindstudio_probe*.whl ## 1.1.1 -## 1.1.1 - 【数据采集】 - dump 支持 processgroup、namedtuple、slice 等数据类型 @@ -209,6 +242,6 @@ source {cann_path}/ascend-toolkit/set_env.sh 链接:[https://gitee.com/ascend/pytorch](https://gitee.com/ascend/pytorch)。 -## 3 安装 ModelLink +## 3 安装 MindSpeed LLM -链接:[https://gitee.com/ascend/ModelLink](https://gitee.com/ascend/ModelLink)。 +链接:[https://gitee.com/ascend/MindSpeed-LLM](https://gitee.com/ascend/MindSpeed-LLM)。 diff --git a/debug/accuracy_tools/msprobe/docs/02.config_introduction.md b/debug/accuracy_tools/msprobe/docs/02.config_introduction.md index f134bd4536294d209e7b3e6e73fd80b9be61041d..ac9967b332286d8157482f3b03ef9a10fec8a201 100644 --- a/debug/accuracy_tools/msprobe/docs/02.config_introduction.md +++ b/debug/accuracy_tools/msprobe/docs/02.config_introduction.md @@ -10,47 +10,59 @@ ### 1.1 通用配置 -| 参数 | 解释 | 是否必选 | -| ----------------- |------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -------- | -| task | dump 的任务类型,str 类型。可选参数:
"statistics":仅采集统计信息,默认值;
"tensor":采集统计信息和完全复刻整网的真实数据;
"run_ut":精度预检,仅 PyTorch 场景支持,采集数据时勿选;
"overflow_check":溢出检测;
"free_benchmark":无标杆比对;
"grad_probe":梯度监控;
"structure":仅采集模型结构以及调用栈信息,不采集具体数据。
根据 task 参数取值的不同,可以配置不同场景参数,详见:
[1.2 task 配置为 statistics](#12-task-配置为-statistics),
[1.3 task 配置为 tensor](#13-task-配置为-tensor),
[1.4 task 配置为 run_ut](#14-task-配置为-run_ut),
[1.5 task 配置为 overflow_check](#15-task-配置为-overflow_check),
[1.6 task 配置为 free_benchmark](#16-task-配置为-free_benchmark),
[1.7 task 配置为 grad_probe](#17-task-配置为-grad_probe)。
**配置示例**:"task": "tensor"。 | 否 | -| dump_path | 设置 dump 数据目录路径,str 类型。
**配置示例**:"dump_path": "./dump_path"。 | 是 | -| rank | 指定对某张卡上的数据进行采集,list[Union[int, str]] 类型,默认未配置(表示采集所有卡的数据),应配置元素为 ≥0 的整数或类似"4-6"的字符串,且须配置实际可用的 Rank ID。
PyTorch 场景: Rank ID 从 0 开始计数,最大取值为所有节点可用卡总数-1,若所配置的值大于实际训练所运行的卡的 Rank ID,则 dump 数据为空,比如当前环境 Rank ID 为 0 到 7,实际训练运行 0 到 3 卡,此时若配置 Rank ID 为 4 或不存在的 10 等其他值,dump 数据为空。
MindSpore 场景:所有节点的 Rank ID 均从 0 开始计数,最大取值为每个节点可用卡总数-1,config.json 配置一次 rank 参数对所有节点同时生效。
注意,单卡训练时,rank必须为[],即空列表,不能指定rank。
**配置示例**:"rank": [1, "4-6"]。 | 否 | -| step | 指定采集某个 step 的数据,list[Union[int, str]] 类型。默认未配置,表示采集所有 step 数据。采集特定 step 时,须指定为训练脚本中存在的 step,可逐个配置,也可以指定范围。
**配置示例**:"step": [0, 1 , 2, "4-6"]。 | 否 | -| level | dump 级别,str 类型,根据不同级别采集不同数据。可选参数:
"L0":dump 模块级精度数据,仅 PyTorch 与 MindSpore 动态图场景支持,使用背景详见 [1.1.1 模块级精度数据 dump 说明](#111-模块级精度数据-dump-说明);
"L1":dump API 级精度数据,默认值,仅 PyTorch 与 MindSpore 动态图场景支持;
"L2":dump kernel 级精度数据,PyTorch场景详细介绍见 [PyTorch 场景的 kernel dump 说明](./04.kernel_dump_PyTorch.md);MindSpore场景详细介绍见 [MindSpore 场景的 kernel dump 说明](./28.kernel_dump_MindSpore.md);
"mix":dump module 模块级和 API 级精度数据,即"L0"+"L1",仅 PyTorch 与 MindSpore 动态图场景支持。
"debug":单点保存功能,细节详见[单点保存工具 README](./28.debugger_save_instruction.md)
**配置示例**:"level": "L1"。 | 否 | -| enable_dataloader | 自动控制开关,bool 类型,仅 PyTorch 场景支持。可选参数 true(开启)或 false(关闭),默认为 false。配置为 true 后自动识别 step 参数指定的迭代,并在该迭代执行完成后退出训练,此时 start、stop 和 step 函数可不配置,开启该开关要求训练脚本是通过 torch.utils.data.dataloader 方式加载数据。仅支持 PyTorch 单卡训练使用,分布式训练场景下存在数据 dump 不全问题。 **这个特性下个版本将被废弃** | 否 | -| async_dump | 异步 dump 开关,bool 类型。可选参数 true(开启)或 false(关闭),默认为 false。配置为 true 后开启异步 dump,即采集的精度数据会在当前 step 训练结束后统一落盘,训练过程中工具不触发同步操作。由于使用该模式有**显存溢出**的风险,当 task 配置为 tensor 时,即真实数据的异步dump模式,必须配置 [list](#13-task-配置为-tensor) 参数,指定需要 dump 的 tensor 。该模式暂不支持复数类型 tensor
的统计量计算。 | 否 | +| 参数 | 解释 | 是否必选 | +|-------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -------- | +| task | dump 的任务类型,str 类型。可选参数:
"statistics":仅采集统计信息,默认值;
"tensor":采集统计信息和完全复刻整网的真实数据;
"run_ut":精度预检,仅 PyTorch 场景支持,采集数据时勿选;
"overflow_check":溢出检测;
"free_benchmark":无标杆比对,不支持 MSAdapter 场景;
"grad_probe":梯度监控, 不支持 MSAdapter 场景;
"structure":仅采集模型结构以及调用栈信息,不采集具体数据。
根据 task 参数取值的不同,可以配置不同场景参数,详见:
[1.2 task 配置为 statistics](#12-task-配置为-statistics),
[1.3 task 配置为 tensor](#13-task-配置为-tensor),
[1.4 task 配置为 run_ut](#14-task-配置为-run_ut),
[1.5 task 配置为 overflow_check](#15-task-配置为-overflow_check),
[1.6 task 配置为 free_benchmark](#16-task-配置为-free_benchmark),
[1.7 task 配置为 grad_probe](#17-task-配置为-grad_probe),
[1.8 task 配置为 structure](#18-task-配置为-structure),
[1.9 task 配置为 exception_dump](#19-task-配置为-exception_dump)。
**配置示例**:"task": "tensor"。 | 否 | +| dump_path | 设置 dump 数据目录路径,str 类型。
**配置示例**:"dump_path": "./dump_path"。 | 是 | +| rank | 指定对某张卡上的数据进行采集,list[Union[int, str]] 类型,默认未配置(表示采集所有卡的数据),应配置元素为 ≥0 的整数或类似"4-6"的字符串,且须配置实际可用的 Rank ID。
PyTorch 场景: Rank ID 从 0 开始计数,最大取值为所有节点可用卡总数-1,若所配置的值大于实际训练所运行的卡的 Rank ID,则 dump 数据为空,比如当前环境 Rank ID 为 0 到 7,实际训练运行 0 到 3 卡,此时若配置 Rank ID 为 4 或不存在的 10 等其他值,dump 数据为空。
MindSpore 场景:所有节点的 Rank ID 均从 0 开始计数,最大取值为每个节点可用卡总数-1,config.json 配置一次 rank 参数对所有节点同时生效。静态图 L0 级别 dump 暂不支持指定rank。
注意,单卡训练时,rank必须为[],即空列表,不能指定rank。
**配置示例**:"rank": [1, "4-6"]。 | 否 | +| step | 指定采集某个 step 的数据,list[Union[int, str]] 类型。默认未配置,表示采集所有 step 数据。采集特定 step 时,须指定为训练脚本中存在的 step,可逐个配置,也可以指定范围。
**配置示例**:"step": [0, 1 , 2, "4-6"]。 | 否 | +| level | dump 级别,str 类型,根据不同级别采集不同数据。可选参数:
"L0":dump 模块级精度数据,使用背景详见 [1.1.1 模块级精度数据 dump 说明](#111-模块级精度数据-dump-说明)。
"L1":dump API 级精度数据,默认值,仅 PyTorch、MSAdapter 以及 MindSpore 动态图场景支持。
"L2":dump kernel 级精度数据,PyTorch 场景详细介绍见 [PyTorch 场景的 kernel dump 说明](./04.kernel_dump_PyTorch.md);MindSpore 动态图场景详细介绍见 [MindSpore 动态图场景的 kernel dump 说明](./28.kernel_dump_MindSpore.md);MindSpore 静态图场景详细介绍见《MindSpore 场景的数据采集》中的 ["**8.1 静态图场景**"](./06.data_dump_MindSpore.md#81-静态图场景)小节。
"mix":dump module 模块级和 API 级精度数据,即"L0"+"L1",仅 PyTorch、MSAdapter 以及 MindSpore 动态图场景支持。
"debug":单点保存功能,详见[单点保存工具](./28.debugger_save_instruction.md)。
**配置示例**:"level": "L1"。 | 否 | +| enable_dataloader | 自动控制开关,bool 类型,仅 PyTorch 场景支持。可选参数 true(开启)或 false(关闭),默认为 false。配置为 true 后自动识别 step 参数指定的迭代,并在该迭代执行完成后退出训练,此时 start、stop 和 step 函数可不配置,开启该开关要求训练脚本是通过 torch.utils.data.dataloader 方式加载数据。仅支持 PyTorch 单卡训练使用,分布式训练场景下存在数据 dump 不全问题。 **这个特性下个版本将被废弃** | 否 | +| async_dump | 异步 dump 开关,bool 类型, 支持 task 为 tensor 或 statistic 模式, level 支持 L0、 L1、 mix、 debug 模式。可选参数 true(开启)或 false(关闭),默认为 false。配置为 true 后开启异步 dump,即采集的精度数据会在当前 step 训练结束后统一落盘,训练过程中工具不触发同步操作。由于使用该模式有**显存溢出**的风险,当 task 配置为 tensor 时,即真实数据的异步dump模式,必须配置 [list](#13-task-配置为-tensor) 参数,指定需要 dump 的 tensor 。该模式下,summary_mode 不支持 md5 值,也不支持复数类型 tensor 的统计量计算。
| 否 | +| precision | 控制统计值计算所用精度,可选值["high", "low"],默认值为"low"。选择"high"时,统计量使用float32进行计算,会增加device内存占用,精度更高,但在处理较大数值时可能会导致**显存溢出**;为"low"时使用与原始数据相同的类型进行计算,device内存占用较少。支持 Pytorch,MindSpore 动态图,MindSpore静态图 O0/O1 场景。支持 task 配置为 statistic 或 tensor, level 配置为 L0,L1,mix,debug。 | 否 | #### 1.1.1 模块级精度数据 dump 说明 -仅 PyTorch 与 MindSpore 动态图场景支持。 - 大模型场景下,通常不是简单的利用自动迁移能力实现从 GPU 到 NPU 的训练脚本迁移,而是会对 NPU 网络进行一系列针对性的适配,因此,常常会造成迁移后的 NPU 模型存在部分子结构不能与 GPU 原始模型完全对应。模型结构不一致导致 API 调用类型及数量不一致,若直接按照 API 粒度进行精度数据 dump 和比对,则无法完全比对所有的 API。 本小节介绍的功能是对模型中的大粒度模块进行数据 dump,使其比对时,对于无法以 API 粒度比对的模块可以直接以模块粒度进行比对。 -模块指的是继承 nn.Module 类(PyTorch场景)或 nn.Cell 类(MindSpore场景)的子类,通常情况下这类模块就是一个小模型,可以被视为一个整体,dump 数据时以模块为粒度进行 dump。 +模块指的是继承 nn.Module 类(PyTorch 与 MSAdapter 场景)或 nn.Cell 类(MindSpore 场景)的子类,通常情况下这类模块就是一个小模型,可以被视为一个整体,dump 数据时以模块为粒度进行 dump。 +特别地,在PyTorch场景中,为了规避BackwardHook函数的输出不能进行原地操作的框架限制,工具使用了`torch._C._autograd._set_creation_meta`接口对BackwardHook函数的输出张量进行属性重置,这可能会造成dump数据中缺少原地操作模块(nn.ReLU(inplace=True)及其上一个模块的反向数据。 ### 1.2 task 配置为 statistics - - - - - - - - + + + + + + + + + +
参数解释是否必选
scopePyTorch 和 MindSpore 动态图场景 dump 范围,list[str] 类型,默认未配置(list 也未配置时表示 dump 所有 API 的数据)。该参数可以在 [ ] 内配置两个模块名或 API 名,要求列表长度必须为2,需要配置按照工具命名格式的完整模块名或API名称,用于锁定区间,dump 该范围内的数据。
配置示例: +
scopePyTorch、MSAdapter 以及 MindSpore 动态图场景 dump 范围,list[str] 类型,默认未配置(list 也未配置时表示 dump 所有 API 的数据)。该参数可以在 [ ] 内配置两个模块名或 API 名,要求列表长度必须为2,需要配置按照工具命名格式的完整模块名或API名称,用于锁定区间,dump 该范围内的数据。
配置示例: "scope": ["Module.conv1.Conv2d.forward.0", "Module.fc2.Linear.forward.0"], 或 "scope": ["Cell.conv1.Conv2d.forward.0", "Cell.fc2.Dense.backward.0"], 或"scope": ["Tensor.add.0.forward", "Functional.square.2.forward"]。与 level 参数取值相关,level 为 L0 级别时,可配置模块名;level 为 L1 级别时,可配置 API 名, level为 mix 级别时,可配置为模块名或API名。
list自定义采集的算子列表,list[str] 类型,默认未配置(scope 也未配置时表示 dump 所有 API 的数据),包含以下配置方法:
PyTorch 和 MindSpore 动态图场景配置具体的 API 全称,dump 该 API 数据。在 PyTorch 场景,如果 level 配置成 L2,该配置为必填项。
配置示例:"list": ["Tensor.permute.1.forward", "Tensor.transpose.2.forward", "Torch.relu.3.backward"]。
PyTorch 和 MindSpore 动态图场景在level为 mix 级别时可以配置模块名称,dump该模块展开数据 (dump该模块从执行开始到执行结束期间的所有数据)。 +
PyTorch、MSAdapter 以及 MindSpore 动态图场景配置具体的 API 全称,dump 该 API 数据。在 PyTorch 场景,如果 level 配置成 L2,该配置为必填项。
配置示例:"list": ["Tensor.permute.1.forward", "Tensor.transpose.2.forward", "Torch.relu.3.backward"]。
PyTorch 和 MindSpore 动态图场景在level为 mix 级别时可以配置模块名称,dump该模块展开数据 (dump该模块从执行开始到执行结束期间的所有数据)。
配置示例:"list": ["Module.module.language_model.encoder.layers.0.mlp.ParallelMlp.forward.0"], 或 "list": ["Cell.network_with_loss.language_model.encoder.layers.0.mlp.ParallelMlp.forward.0"]
PyTorch 和 MindSpore 动态图场景指定某一类 API,dump 某一类的 API 级别输入输出数据。
配置示例:"list": ["relu"]。
PyTorch 和 MindSpore 动态图场景在level为 mix 级别时, 会dump名称中包含list中配置的字符串的API数据,还会将名称中包含list中配置的字符串的模块进行展开dump (dump该模块从执行开始到执行结束期间的所有数据)。
MindSpore 静态图场景配置 kernel_name,可以是算子的名称列表,也可以指定算子类型("level": "L2"时不支持),还可以配置算子名称的正则表达式(当字符串符合“name-regex(xxx)”格式时,后台则会将其作为正则表达式。
配置示例:list: ["name-regex(Default/.+)"]
可匹配算子名称以“Default/”开头的所有算子。
data_modedump 数据过滤,str 类型。
PyTorch 与 MindSpore 动态图场景:支持"all"、"forward"、"backward"、"input"和"output",除"all"外,其余参数可以自由组合。默认为["all"],即保存所有 dump 的数据。
配置示例:"data_mode": ["backward"] (仅保存反向数据)或 "data_mode": ["forward", "input"](仅保存前向的输入数据)。
MindSpore 静态图场景:仅支持"all"、"input"和"output"参数,且各参数只能单独配置,不支持自由组合。
配置示例:"data_mode": ["all"]。
summary_mode控制 dump 文件输出的模式,str 类型,仅 PyTorch 与 MindSpore 动态图场景支持,可选参数:
md5:dump 输出包含 CRC-32 值以及 API 统计信息的 dump.json 文件,用于验证数据的完整性;
statistics:dump 仅输出包含 API 统计信息的 dump.json 文件,默认值。
配置示例:"summary_mode": "md5"。
MindSpore静态图jit_level=O2场景L2级dump,支持上述配置的同时额外支持配置统计项列表,可选统计项为max、min、mean、l2norm,可从中任意选取组合搭配。其中mean、l2norm的结果为float数据格式。
配置示例:"summary_mode": ["max", "min"]。
PyTorch、MSAdapter 以及 MindSpore 动态图场景指定某一类 API,dump 某一类的 API 级别输入输出数据。
配置示例:"list": ["relu"]。
PyTorch、MSAdapter 以及 MindSpore 动态图场景在level为 mix 级别时, 会dump名称中包含list中配置的字符串的API数据,还会将名称中包含list中配置的字符串的模块进行展开dump (dump该模块从执行开始到执行结束期间的所有数据)。
MindSpore 静态图场景配置 kernel_name,可以是算子的名称列表,也可以指定算子类型(jit_level=O2 时不支持),还可以配置算子名称的正则表达式(当字符串符合“name-regex(xxx)”格式时,后台则会将其作为正则表达式。
配置示例:list: ["name-regex(Default/.+)"]
可匹配算子名称以“Default/”开头的所有算子。
tensor_list自定义采集真实数据的算子列表,list[str] 类型,默认未配置。包含以下配置方法:
PyTorch、MSAdapter 以及 MindSpore 动态图场景指定某一类 API 或模块,即会 dump 这一类 API 或模块输入输出的统计量信息和完整的 tensor 数据。
配置示例:"tensor_list": ["relu"]。
PyTorch、MSAdapter 以及 MindSpore 动态图场景目前只支持level配置为 L0, L1 和 mix 级别。
MindSpore 静态图场景不支持。
device控制统计值计算所用的设备,可选值["device", "host"],默认"host"。使用device计算会比host有性能加速,只支持min/max/avg/l2norm统计量。支持 MindSpore静态图 O0/O1 场景。
data_modedump 数据过滤,str 类型。
PyTorch、MSAdapter 以及 MindSpore 动态图场景:支持"all"、"forward"、"backward"、"input"和"output",除"all"外,其余参数可以自由组合。默认为["all"],即保存所有 dump 的数据。
配置示例:"data_mode": ["backward"] (仅保存反向数据)或 "data_mode": ["forward", "input"](仅保存前向的输入数据)。
MindSpore 静态图场景:L0 级别 dump 仅支持"all"、"forward"和"backward"参数;L2 级别 dump 仅支持"all"、"input"和"output"参数。且各参数只能单独配置,不支持自由组合。
配置示例:"data_mode": ["all"]。
summary_mode控制 dump 文件输出的模式,str 类型,支持 PyTorch、MSAdapter、MindSpore 动态图以及 MindSpore 静态图 L2 级别 jit_level=O2 场景和 L0 级别 jit_level=O0/O1 场景。
PyTorch、MSAdapter 以及 MindSpore 动态图场景:可选参数为
md5:dump 输出包含 CRC-32 值以及 API 统计信息的 dump.json 文件,用于验证数据的完整性;
statistics:dump 仅输出包含 API 统计信息的 dump.json 文件,默认值。
配置示例:"summary_mode": "md5"。
MindSpore 静态图 L2 级别 jit_level=O2 场景:支持上述配置的同时额外支持配置统计项列表,可选统计项为max、min、mean、l2norm,可从中任意选取组合搭配。其中mean、l2norm的结果为float数据格式。
MindSpore 静态图 L2 级别 jit_level=O0/O1 场景:支持上述配置的同时额外支持配置统计项列表,可选统计项为max、min、mean、l2norm、count、negative zero count、zero count、positive zero count、nan count、negative inf count、positive inf count、hash、md5,可从中任意选取组合搭配。注意:hash统计项在MindSpore2.7.0及以前版本计算MD5值,在以后版本计算SHA1值。
MindSpore 静态图 L0 级别 jit_level=O0/O1场景:仅支持上述配置中"statistics"字段和max、min、mean、l2norm中任意组合搭配的统计项列表。
配置示例:"summary_mode": ["max", "min"]。
-**说明**:"summary_mode"配置为"md5"时,所使用的校验算法为CRC-32算法。 +**说明**: + + +1. PyTorch、MSAdapter 以及 MindSpore 动态图场景,"summary_mode" 配置为 "md5" 时,所使用的校验算法为 CRC-32 算法;MindSpore 静态图场景,"summary_mode" 配置为 "md5" 时,所使用的校验算法为 MD5 算法。 + +**示例**: + - [PyTorch场景](03.config_examples.md#11-task-配置为-statistics) + - [MindSpore静态图场景](03.config_examples.md#21-task-配置为-statistics) + - [MindSpore动态图场景](03.config_examples.md#31-task-配置为-statistics) ### 1.3 task 配置为 tensor @@ -60,12 +72,14 @@ | list | 与[ 1.2 task 配置为 statistics ](#12-task-配置为-statistics)中的解释相同。 | 否 | | data_mode | 与[ 1.2 task 配置为 statistics ](#12-task-配置为-statistics)中的解释相同 | 否 | | file_format | tensor 数据的保存格式,str 类型,仅支持 MindSpore 静态图场景的 L2 级别配置该字段,其他场景不生效。可选参数:
"bin":dump 的 tensor 文件为二进制格式;
"npy":dump 的 tensor 文件后缀为 .npy,默认值。 | 否 | -| online_run_uta | 在线预检模式开关,bool 类型,可选参数 true(开启)、false(关闭),默认未配置,表示关闭。配置为 true 表示开启在线预检。| 否 | -| nfs_patha | 在线预检模式共享存储目录路径,str 类型,用于 GPU 设备和 NPU 设备间进行通信。仅在 online_run_ut 字段配置为 true 时生效,配置该参数后 host 和 port 不生效。 | 否 | -| hosta | 在线预检模式局域网场景信息接收端 IP,str 类型,用于 GPU 设备和 NPU 设备间进行通信,NPU 侧须配置为 GPU 侧的局域网 IP 地址。仅在 online_run_ut 字段配置为 true 时生效,局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。 | 否 | -| porta | 在线预检模式局域网场景信息接收端端口号,int 类型,用于 GPU 设备和 NPU 设备间进行通信,NPU 侧须配置为 GPU 侧的端口号。仅在 online_run_ut 字段配置为 true 时生效,局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。| 否 | +| summary_mode | 控制 dump 文件输出的模式,str 类型,支持 PyTorch、MSAdapter、MindSpore 动态图。可选参数:
md5:dump 输出包含 CRC-32 值以及 API 统计信息的 dump.json 文件,用于验证数据的完整性;
statistics:dump 仅输出包含 API 统计信息的 dump.json 文件,默认值。| 否 | + + +**示例**: + - [PyTorch场景](03.config_examples.md#12-task-配置为-tensor) + - [MindSpore静态图场景](03.config_examples.md#22-task-配置为-tensor) + - [MindSpore动态图场景](03.config_examples.md#32-task-配置为-tensor) -**a**:online_run_ut、nfs_path、host、port 等字段仅在线预检场景 NPU 机器生效。 ### 1.4 task 配置为 run_ut @@ -74,28 +88,46 @@ | white_lista | API dump 白名单,仅对指定的 API 进行 dump。
**配置示例**:"white_list": ["conv1d", "conv2d"]。默认未配置白名单,即 dump 全量 API 数据。 | 否 | | black_lista | API dump 黑名单,被指定的 API 不进行 dump。
**配置示例**:"black_list": ["conv1d", "conv2d"]。默认未配置黑名单,即 dump 全量 API 数据。 | 否 | | error_data_path | 配置保存精度未达标的 API 输入输出数据路径,默认为当前路径。
**配置示例**:"error_data_path": "./"。 | 否 | -| is_onlineb | 在线预检模式开关,bool 类型,可选参数 true(开启)、false(关闭),默认关闭。 | 否 | -| nfs_pathb | 在线预检模式共享存储目录路径,str 类型,用于 GPU 设备和 NPU 设备间进行通信。配置该参数后 host 和 port 不生效,仅在 is_online 字段配置为 true 时生效。 | 否 | -| hostb | 在线预检模式局域网场景信息接收端 IP,str 类型,用于 GPU 设备和 NPU 设备间进行通信,GPU 侧配置为本机地址 127.0.0.1 或本机局域网 IP。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。仅在 is_online 字段配置为 true 时生效。 | 否 | -| portb | 在线预检模式局域网场景信息接收端端口号,int 类型,用于 GPU 设备和 NPU 设备间进行通信,GPU 侧配置为本机可用端口。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。仅在 is_online 字段配置为 true 时生效。| 否 | -| rank_listb | 指定在线预检的 Rank ID,默认值为 [0],list[int] 类型,应配置为大于等于 0 的整数,且须根据实际卡的 Rank ID 配置,若所配置的值大于实际训练所运行的卡的 Rank ID,则在线预检输出数据为空。GPU 和 NPU 须配置一致。仅在 is_online 字段配置为 true 时生效。 | 否 | -**a**:white_list 和 black_list 同时配置时,二者配置的 API 名单若无交集,则白名单生效,若 API 名单存在交集,则白名单排除的部分以及交集的 API 不进行 dump。 +**说明**: + +1. white_list 和 black_list 同时配置时,二者配置的 API 名单若无交集,则白名单生效,若 API 名单存在交集,则白名单排除的部分以及交集的 API 不进行 dump。 + + +**示例**: +```json +{ + "task": "run_ut", + "dump_path": "/home/data_dump", + "rank": [], + "step": [], + "level": "L1", -**b**:is_online、nfs_path、host、port、rank_list 等字段仅在线预检场景 GPU 机器生效。 + "run_ut": { + "white_list": [], + "black_list": [], + "error_data_path": "./" + } +} +``` ### 1.5 task 配置为 overflow_check -PyTorch 与 MindSpore 动态图场景下,"level"须为"L0"或"L1";MindSpore 静态图场景下,"level"须为"L2",且模型编译优化等级(jit_level)须为"O2"。 +PyTorch、MSAdapter 以及 MindSpore 动态图场景下,"level"须为"L0"或"L1";MindSpore 静态图场景下,"level"须为"L2",且模型编译优化等级(jit_level)须为"O2"。 | 参数 | 解释 | 是否必选 | | ------------- | ---------------------- | -------- | -| overflow_nums | 最大溢出次数,int 类型,默认为 1,仅 PyTorch 与 MindSpore 动态图场景支持。表示第 N 次溢出后,不再进行溢出检测。过程中检测到溢出 API 对应的 输入输出 数据均 dump。
**配置示例**:"overflow_nums": 3。配置为 -1 时,表示持续检测溢出直到训练结束。 | 否 | -| check_mode | 溢出类型,str 类型,仅 MindSpore 场景支持,可选参数:
"aicore":开启 AI Core 的溢出检测,不支持 MindSpore v2.3.0 以上版本;
"atomic":开启 Atomic 的溢出检测,不支持 MindSpore v2.3.0 以上版本;
"all":开启算子的溢出检测,默认值。
**配置示例**:"check_mode": "all"。 | 否 | +| overflow_nums | 最大溢出次数,int 类型,默认为 1,仅 PyTorch、MSAdapter 以及 MindSpore 动态图场景支持。表示第 N 次溢出后,不再进行溢出检测。过程中检测到溢出 API 对应的 输入输出 数据均 dump。
**配置示例**:"overflow_nums": 3。配置为 -1 时,表示持续检测溢出直到训练结束。 | 否 | +| check_mode | 溢出类型,str 类型,仅 MindSpore v2.3.0 以下版本的静态图场景支持,可选参数:
"aicore":开启 AI Core 的溢出检测;
"atomic":开启 Atomic 的溢出检测;
"all":开启算子的溢出检测,默认值。
**配置示例**:"check_mode": "all"。 | 否 | + +**示例**: + - [PyTorch场景](03.config_examples.md#14-task-配置为-overflow_check) + - [MindSpore静态图场景](03.config_examples.md#23-task-配置为-overflow_check) + - [MindSpore动态图场景](03.config_examples.md#33-task-配置为-overflow_check) ### 1.6 task 配置为 free_benchmark -仅 PyTorch 场景与 MindSpore 动态图场景支持,且"level"为"L1"。 +仅 PyTorch 与 MindSpore 动态图场景支持,且"level"为"L1"。 - task 配置为 free_benchmark 时,开启**无标杆比对**,在 NPU 环境下通过对当前模型 API 的输入添加扰动因子,二次执行,将得到的输出与未添加扰动因子前的输出进行比对,从而**得出该模型中可能存在因迁移等变化导致精度降低的 API**。 @@ -119,6 +151,10 @@ PyTorch 与 MindSpore 动态图场景下,"level"须为"L0"或"L1";MindSpore max_sample每个算子预热的采样次数的最大阈值(仅 PyTorch 场景支持),int 类型,默认值为 20。须配置 "if_preheat": "true"。否 +**示例**: + - [PyTorch场景](03.config_examples.md#15-task-配置为-free_benchmark) + - [MindSpore动态图场景](03.config_examples.md#34-task-配置为-free_benchmark) + #### 1.6.1 无标杆比对数据存盘格式 无标杆比对在 dump_path 目录下输出结果文件 `free_benchmark.csv`,如下示例: @@ -162,5 +198,24 @@ PyTorch 与 MindSpore 动态图场景下,"level"须为"L0"或"L1";MindSpore | L1 | ("param_name", "max", "min", "norm", "shape") | 是 | | L2 | ("param_name", *intervals, "=0", "max", "min", "norm", "shape") | 是 | - intervals就是根据值分布bounds划分出的区间。 - MindSpore静态图模式下,L0级别中暂不支持"MD5" +**说明**: + +1. intervals就是根据值分布bounds划分出的区间。 +2. MindSpore静态图模式下,L0级别中暂不支持"MD5" + +### 1.8 task 配置为 structure +structure 模式仅采集模型结构,无其他特殊配置。 + +**示例**: + - [PyTorch场景](03.config_examples.md#16-task-配置为-structure) + - [MindSpore动态图场景](03.config_examples.md#35-task-配置为-structure) + +### 1.9 task 配置为 exception_dump +MindSpore 动态图场景下,"level"须为"L2"; MindSpore 静态图场景下,"level"须为"L2",且模型编译优化等级(jit_level)须为"O0"或"O1"。 + +在运行过程中会在指定目录下生成kernel_graph_exception_dump.json的中间文件,该文件包含异常dump的相关设置。 +除中间文件外的其他 dump 结果文件请参见 MindSpore 官方文档中的[ Ascend 下 O0/O1 模式 Dump 数据对象目录和数据文件介绍](https://www.mindspore.cn/docs/zh-CN/r2.5.0/model_train/debug/dump.html#%E6%95%B0%E6%8D%AE%E5%AF%B9%E8%B1%A1%E7%9B%AE%E5%BD%95%E5%92%8C%E6%95%B0%E6%8D%AE%E6%96%87%E4%BB%B6%E4%BB%8B%E7%BB%8D) + +**示例**: + - [MindSpore动态图场景](03.config_examples.md#36-task-配置为-exception_dump) + - [MindSpore静态图场景](03.config_examples.md#24-task-配置为-exception_dump) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/docs/03.config_examples.md b/debug/accuracy_tools/msprobe/docs/03.config_examples.md index 542250fac243f3ab2f1d0aff87bc509ac7c1a675..6350099be5a100a62d2f196c53c56baec7cca6e4 100644 --- a/debug/accuracy_tools/msprobe/docs/03.config_examples.md +++ b/debug/accuracy_tools/msprobe/docs/03.config_examples.md @@ -13,10 +13,12 @@ "rank": [], "step": [], "level": "L1", + "async_dump": false, "statistics": { "scope": [], "list": [], + "tensor_list": [], "data_mode": ["all"], "summary_mode": "statistics" } @@ -32,6 +34,7 @@ "rank": [], "step": [], "level": "L1", + "async_dump": false, "tensor": { "scope": [], @@ -165,6 +168,18 @@ } ``` +### 2.4 task 配置为 exception_dump + +```json +{ + "task": "exception_dump", + "dump_path": "/home/data_dump", + "rank": [], + "step": [], + "level": "L2" +} +``` + ## 3 MindSpore 动态图场景 ### 3.1 task 配置为 statistics @@ -252,3 +267,15 @@ "level": "mix" } ``` + +### 3.6 task 配置为 exception_dump + +```json +{ + "task": "exception_dump", + "dump_path": "/home/data_dump", + "rank": [], + "step": [], + "level": "L2" +} +``` diff --git a/debug/accuracy_tools/msprobe/docs/04.kernel_dump_PyTorch.md b/debug/accuracy_tools/msprobe/docs/04.kernel_dump_PyTorch.md index ce3fd54f5a6741b262f6248f70a9f1166ca0b4a6..346481aad12c42994669b7b3ea794843e49c1618 100644 --- a/debug/accuracy_tools/msprobe/docs/04.kernel_dump_PyTorch.md +++ b/debug/accuracy_tools/msprobe/docs/04.kernel_dump_PyTorch.md @@ -6,7 +6,7 @@ ## 1 kernel dump 配置示例 -使用 kernel dump 时,list 必须要填一个 API 名称,kernel dump 目前每个 step 只支持采集一个 API 的数据。 +使用 kernel dump 时,task 需要配置为 tensor , list 必须要填一个 API 名称,kernel dump 目前每个 step 只支持采集一个 API 的数据。 API 名称填写参考 L1 dump 结果文件 dump.json 中的API名称,命名格式为:`{api_type}.{api_name}.{API调用次数}.{forward/backward}`。 ```json diff --git a/debug/accuracy_tools/msprobe/docs/05.data_dump_PyTorch.md b/debug/accuracy_tools/msprobe/docs/05.data_dump_PyTorch.md index db9a989c9d1c731fd9099d311f3ab3b95e5c7d5d..823710244b5a8e7008fa3603edf8d43f1d7dae46 100644 --- a/debug/accuracy_tools/msprobe/docs/05.data_dump_PyTorch.md +++ b/debug/accuracy_tools/msprobe/docs/05.data_dump_PyTorch.md @@ -2,7 +2,7 @@ msprobe 工具主要通过在训练脚本内添加 dump 接口、启动训练的方式采集精度数据。 -dump的'tensor'模式采集数据量大小,可以参考[数据量基线](./26.data_dump_PyTorch_baseline.md)。 +dump "statistics"模式的性能膨胀大小"与"tensor"模式采集的数据量大小,可以参考[dump基线](./26.data_dump_PyTorch_baseline.md)。 本工具提供固定的 API 支持列表,若需要删除或增加 dump 的 API,可以在 msprobe/pytorch/hook_module/support_wrap_ops.yaml 文件内手动修改,如下示例: @@ -15,6 +15,54 @@ functional: # functional为算子类别,找到对应的类别,在该类别 删除API的场景:部分模型代码逻辑会存在API原生类型校验,工具执行dump操作时,对模型的API封装可能与模型的原生API类型不一致,此时可能引发校验失败,详见《[FAQ](FAQ.md)》中“异常情况”的第10和11条。 +加工具后loss/gnorm发生变化:可能是工具中的item操作引入同步,pt/ms框架的hook机制等原因导致的,详见《[工具导致计算结果变化](36.calculation_result_change.md)》。 + +## 快速上手 + +这个示例定义了一个 nn.Module 类型的简单网络,使用原型函数 PrecisionDebugger 进行数据采集。 + +```python +# 根据需要import包 +import torch +import torch.nn as nn +import torch.nn.functional as F + +# 导入工具的数据采集接口 +from msprobe.pytorch import PrecisionDebugger, seed_all + +# 在模型训练开始前固定随机性 +seed_all() + +# 在模型训练开始前实例化PrecisionDebugger +debugger = PrecisionDebugger() + +# 定义网络 +class ModuleOP(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear_1 = nn.Linear(in_features=8, out_features=4) + self.linear_2 = nn.Linear(in_features=4, out_features=2) + + def forward(self, x): + x1 = self.linear_1(x) + x2 = self.linear_2(x1) + r1 = F.relu(x2) + return r1 + +if __name__ == "__main__": + module = ModuleOP() + + # 开启数据 dump + debugger.start(model=module) + x = torch.randn(10, 8) + out = module(x) + loss = out.sum() + loss.backward() + + # 关闭数据 dump + debugger.stop() +``` + ## 1 接口介绍 ### 1.1 PrecisionDebugger @@ -30,9 +78,11 @@ PrecisionDebugger(config_path=None, task=None, dump_path=None, level=None, model 1. config_path:指定 dump 配置文件路径; 2. model:指定需要采集 Module 级数据的模型,支持传入 torch.nn.Module 或 list[torch.nn.Module] 类型,默认未配置。 level 配置为"L0"或"mix"时,必须在该接口或 **start** 接口中配置该参数。该参数在将来会从该接口移除,建议在 **start** 接口中配置该参数。 -3. 其他参数均在 [config.json](../config.json) 文件中可配,详细配置可见 [config.json 介绍](./02.config_introduction.md)。 +3. 其他参数均在 config.json 文件中可配,详细配置可见 [config.json 介绍](./02.config_introduction.md)。 + +此接口的参数均不是必要(均不配置的情况下默认采集所有 rank 和 step 的 L1 级别的统计数据),且优先级高于 config.json 文件中的配置,但可配置的参数相比 config.json 较少。 -此接口的参数均不是必要,且优先级高于 [config.json](../config.json) 文件中的配置,但可配置的参数相比 config.json 较少。 +注:此接口的初始化需与采集目标在同一个进程中,否则将无法采集目标数据。 ### 1.2 start @@ -41,19 +91,22 @@ level 配置为"L0"或"mix"时,必须在该接口或 **start** 接口中配置 **原型**: ```Python -debugger.start(model=None) +debugger.start(model=None, token_range=None) ``` 1. model:指定需要采集 Module 级数据的模型,支持传入 torch.nn.Module、list[torch.nn.Module]或Tuple[torch.nn.Module] 类型,默认未配置。 -level 配置为"L0"或"mix"时,必须在该接口或 **PrecisionDebugger** 接口中配置该参数。 +level 配置为"L0"|"mix"或token_range不为None时,必须在该接口或 **PrecisionDebugger** 接口中配置该参数。 本接口中的 model 比 PrecisionDebugger 中 model 参数优先级更高,会覆盖 PrecisionDebugger 中的 model 参数。 +
对于复杂模型,如果仅需要监控一部分(如model.A,model.A extends torch.nn.Module),传入需要监控的部分(如model.A)即可。 +注意:传入的当前层不会被dump,工具只会dump传入层的子层级。如传入了model.A,A本身不会被dump,而是会dump A.x, A.x.xx等。 +2. token_range:指定推理模型采集时的token循环始末范围,支持传入[int, int]类型,代表[start, end],范围包含边界,默认未配置。 ### 1.3 stop **功能说明**:停止精度数据采集。在 **start** 函数之后的任意位置添加。 若 **stop** 函数添加在反向计算代码(如loss.backward)之后,则会采集 **start** 和该函数之间的前反向数据。 若 **stop** 函数添加在反向计算代码之前,则需要将 [**step**](#15-step) 函数添加到反向计算代码之后,才能采集 **start** 和该函数之间的前反向数据。 -使用示例可参见 [2.1 快速上手](#21-快速上手) 和 [2.2 采集完整的前反向数据](#22-采集完整的前反向数据)。 +使用示例可参见 [快速上手](#快速上手) 和 [2.1 采集完整的前反向数据](#21-采集完整的前反向数据)。 **注意**:**stop** 函数必须调用,否则可能导致精度数据落盘不全。 @@ -66,7 +119,7 @@ debugger.stop() ### 1.4 forward_backward_dump_end **功能说明**:停止精度数据采集。与 **stop** 函数功能相同,该函数在将来会被移除,建议使用 **stop** 函数。 -使用示例可参见 [2.3 采集指定代码块的前反向数据](#23-采集指定代码块的前反向数据)。 +使用示例可参见 [2.2 采集指定代码块的前反向数据](#22-采集指定代码块的前反向数据)。 **原型**: @@ -77,7 +130,7 @@ forward_backward_dump_end() ### 1.5 step **功能说明**:结束一个 step 的数据采集,完成所有数据落盘并更新 dump 参数。在一个 step 结束的位置添加,且必须在 **stop** 函数之后的位置调用。 -该函数需要配合 **start** 和 **stop** 函数使用,尽量添加在反向计算代码(如loss.backward)之后,否则可能会导致反向数据丢失。使用示例可参见[2.2 采集完整的前反向数据](#22-采集完整的前反向数据)。 +该函数需要配合 **start** 和 **stop** 函数使用,尽量添加在反向计算代码(如loss.backward)之后,否则可能会导致反向数据丢失。使用示例可参见[2.1 采集完整的前反向数据](#21-采集完整的前反向数据)。 **原型**: @@ -88,7 +141,7 @@ debugger.step() ### 1.6 module_dump **功能说明**:开启模块级精度数据dump。该接口为函数模块化接口,即只会dump输入的模块数据,不会dump子模块和模块内API的数据。 -需要配合start、stop和step等接口使用。使用示例可参考[2.4 采集函数模块化数据](#24-采集函数模块化数据) +需要配合start、stop和step等接口使用。使用示例可参考[2.3 采集函数模块化数据](#23-采集函数模块化数据) **原型**: @@ -117,14 +170,14 @@ module_dump_end() **原型**: ```python -seed_all(seed=1234, mode=False, rm_dropout=True) +seed_all(seed=1234, mode=False, rm_dropout=False) ``` **参数说明**: 1. seed: 随机性种子。参数示例: seed=1000。默认值:1234。非必选 2. mode:确定性计算模式。可配置True或False。参数示例:mode=True。默认为False。非必选(注意:确定性计算会导致API执行性能降低,建议在发现模型多次执行结果不同的情况下开启) -3. rm_dropout:控制dropout失效的开关。可配置 True 或 False,默认值:True,非必选。参数示例:rm_dropout=True。 +3. rm_dropout:控制dropout失效的开关。可配置 True 或 False,默认值:False,非必选。参数示例:rm_dropout=True。 该参数设置为 True 后, 工具会自动将 `torch.nn.functional.dropout`、`torch.nn.functional.dropout2d`、`torch.nn.functional.dropout3d`、`torch.nn.Dropout`、`torch.nn.Dropout2d`、`torch.nn.Dropout3d` 的接口参数 p 置为0,以避免因随机dropout造成的网络随机性。 注意:通过rm_dropout控制dropout失效或生效需要在初始化dropout实例前调用才能生效。 @@ -183,58 +236,67 @@ save(variable, name, save_backward=True) **参数说明**: | 参数名称 | 参数含义 | 支持数据类型 | 是否必选| | ---------- | ------------------| ------------------- | ------------------- | -| variable | 需要保存的变量 |dict, list, torch.tensor, int, float, str | 是 | +| variable | 需要保存的变量 |dict, list, tuple, torch.tensor, int, float, str | 是 | | name | 指定的名称 | str | 是 | | save_backward | 是否保存反向数据 | boolean | 否 | -## 2 示例代码 +具体使用样例可参考:[单点保存工具使用介绍](./28.debugger_save_instruction.md)。 -### 2.1 快速上手 +### 1.10 set_init_step -这个示例定义了一个 nn.Module 类型的简单网络,在进行数据采集时使用原型函数 PrecisionDebugger 传入 config_path 参数和 model 参数。 +**功能说明**:设置起始step数,step数默认从0开始计数,使用该接口后step从指定值开始计数。该函数需要写在训练迭代的循环开始前,不能写在循环内。 -```python -# 根据需要import包 -import torch -import torch.nn as nn -import torch.nn.functional as F +**原型**: -# 导入工具的数据采集接口 -from msprobe.pytorch import PrecisionDebugger, seed_all +```Python +debugger.set_init_step(step) +``` -# 在模型训练开始前固定随机性 -seed_all() -# 在模型训练开始前实例化PrecisionDebugger -debugger = PrecisionDebugger(config_path='./config.json') +**参数说明**: -# 定义网络 -class ModuleOP(nn.Module): - def __init__(self) -> None: - super().__init__() - self.linear_1 = nn.Linear(in_features=8, out_features=4) - self.linear_2 = nn.Linear(in_features=4, out_features=2) +1.step: 指定的起始step数。 - def forward(self, x): - x1 = self.linear_1(x) - x2 = self.linear_2(x1) - r1 = F.relu(x2) - return r1 +### 1.11 register_custom_api -if __name__ == "__main__": - module = ModuleOP() - # 开启数据 dump - debugger.start(model=module) +**功能说明**:注册用户自定义的api到工具用于 L1 dump 。 - x = torch.randn(10, 8) - out = module(x) - loss = out.sum() - loss.backward() +**原型**: - # 关闭数据 dump - debugger.stop() +```Python +debugger.register_custom_api(module, api_name, api_prefix) ``` +**参数说明**: + +以 torch.matmul api 为例 + +1.module: api 所属的包,即传入 torch。 + +2.api_name: api 名,string类型,即传入 "matmul"。 + +3.api_prefix: [dump.json](./27.dump_json_instruction.md) 中 api 名的前缀,可选,默认为包名的字符串格式, 即 "torch"。 -### 2.2 采集完整的前反向数据 +### 1.12 restore_custom_api + +**功能说明**:恢复用户原有的自定义的api,取消 dump 。 + +**原型**: + +```Python +debugger.restore_custom_api(module, api_name) +``` +**参数说明**: + +以 torch.matmul api 为例 + +1.module: api 所属的包,即传入 torch。 + +2.api_name: api 名,string类型,即传入 "matmul"。 + + +## 2 示例代码 + + +### 2.1 采集完整的前反向数据 ```Python from msprobe.pytorch import PrecisionDebugger, seed_all @@ -255,7 +317,7 @@ for data, label in data_loader: debugger.step() # 结束一个step的dump ``` -### 2.3 采集指定代码块的前反向数据 +### 2.2 采集指定代码块的前反向数据 ```Python from msprobe.pytorch import PrecisionDebugger, seed_all @@ -279,7 +341,7 @@ for data, label in data_loader: debugger.step() # 结束一个step的dump ``` -### 2.4 采集函数模块化数据 +### 2.3 采集函数模块化数据 ```Python # 根据需要import包 @@ -321,6 +383,80 @@ if __name__ == "__main__": debugger.stop() ``` +### 2.4 跨文件采集数据 +为了确保所有API都被工具封装,PrecisionDebugger的实例化通常放在训练工程的入口位置,但有的时候,模型定义会在另一个文件中。 假设有两个文件,train.py(为训练工程入口)module.py(为模型定义文件),为了采集module.py中定义的ModuleOP模块中某些子模块或API的前反向数据,需要在train.py和module.py文件中分别导入PrecisionDebugger并进行如下配置。 + +train.py文件: + +```Python +# 根据需要import包 +import torch +from module import ModuleOP + +# 导入工具的数据采集接口 +from msprobe.pytorch import PrecisionDebugger + +# 将PrecisionDebugger的实例化放在文件的开始位置,即导包后的位置,确保所有API都被封装 +debugger = PrecisionDebugger(config_path='./config.json') + +if __name__ == "__main__": + module = ModuleOP() + + x = torch.randn(10, 8) + out = module(x) + loss = out.sum() + loss.backward() +``` + +module.py文件: + +```Python +import torch +import torch.nn as nn +import torch.nn.functional as F + +from msprobe.pytorch import PrecisionDebugger + +# 定义网络 +class ModuleOP(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear_1 = nn.Linear(in_features=8, out_features=4) + self.linear_2 = nn.Linear(in_features=4, out_features=2) + + def forward(self, x): + PrecisionDebugger.start() + x1 = self.linear_1(x) + PrecisionDebugger.stop() + x2 = self.linear_2(x1) + r1 = F.relu(x2) + return r1 + +``` + +### 2.5 推理模型采集指定token_range + +```Python +from vllm import LLM, SamplingParams +from msprobe.pytorch import PrecisionDebugger, seed_all +# 在模型训练开始前固定随机性 +seed_all() +# 请勿将PrecisionDebugger的初始化流程插入到循环代码中 +debugger = PrecisionDebugger(config_path="./config.json", dump_path="./dump_path") +# 模型定义及初始化等操作 +prompts = ["Hello, my name is"] +sampling_params = SamplingParams(temprature=0.8, top_p=0.95) +llm = LLM(model='...') +model = llm.llm_engine.model_executor.driver_worker.worker.model_runner.get_model() +# 开启数据dump, 指定采集推理模型逐字符循环推理中的第1~3次 +debugger.start(model=model, token_range=[1,3]) +# 推理模型生成的逻辑 +output = llm.generate(prompts, sampling_params=sampling_params) +# 关闭数据dump并落盘 +debugger.stop() +debugger.step() +``` + ## 3 dump 结果文件介绍 训练结束后,工具将 dump 的数据保存在 dump_path 参数指定的目录下。目录结构示例如下: @@ -334,17 +470,19 @@ if __name__ == "__main__": | | | | ├── Functional.linear.5.backward.output.pt # 命名格式为{api_type}.{api_name}.{API调用次数}.{forward/backward}.{input/output}.{参数序号}, 其中,“参数序号”表示该API的第n个输入或输出,例如1,则为第一个参数,若该参数为list格式,则根据list继续排序,例如1.1,表示该API的第1个参数的第1个元素。 | | | | ... | | | | ├── Module.conv1.Conv2d.forward.0.input.0.pt # 命名格式为{Module}.{module_name}.{class_name}.{forward/backward}.{调用次数}.{input/output}.{参数序号}, 其中,“参数序号”表示该Module的第n个参数,例如1,则为第一个参数,若该参数为list格式,则根据list继续排序,例如1.1,表示该Module的第1个参数的第1个元素。 -| | | | ├── Module.conv1.Conv2D.forward.0.parameters.bias.pt # 模块参数数据:命名格式为{Module}.{module_name}.{class_name}.forward.{调用次数}.parameters.{parameter_name}。 -| | | | └── Module.conv1.Conv2D.parameters_grad.weight.pt # 模块参数梯度数据:命名格式为{Module}.{module_name}.{class_name}.parameters_grad.{parameter_name}。因为同一模块的参数使用同一梯度进行更新,所以参数梯度文件名不包含调用次数。 +| | | | ├── Module.conv1.Conv2d.forward.0.parameters.bias.pt # 模块参数数据:命名格式为{Module}.{module_name}.{class_name}.forward.{调用次数}.parameters.{parameter_name}。 +| | | | └── Module.conv1.Conv2d.parameters_grad.weight.pt # 模块参数梯度数据:命名格式为{Module}.{module_name}.{class_name}.parameters_grad.{parameter_name}。因为同一模块的参数使用同一梯度进行更新,所以参数梯度文件名不包含调用次数。 | | | | # 当dump时传入的model参数为List[torch.nn.Module]或Tuple[torch.nn.Module]时,模块级数据的命名中包含该模块在列表中的索引index,命名格式为{Module}.{index}.*,*表示以上三种模块级数据的命名格式,例如:Module.0.conv1.Conv2d.forward.0.input.0.pt。 │ | | ├── dump.json │ | | ├── stack.json +│ | | ├── dump_error_info.log │ | | └── construct.json │ | ├── rank1 | | | ├── dump_tensor_data | | | | └── ... │ | | ├── dump.json │ | | ├── stack.json +│ | | ├── dump_error_info.log | | | └── construct.json │ | ├── ... │ | | @@ -355,7 +493,8 @@ if __name__ == "__main__": ``` * `rank`:设备 ID,每张卡的数据保存在对应的 `rank{ID}` 目录下。非分布式场景下没有 rank ID,目录名称为 rank。 * `dump_tensor_data`:保存采集到的张量数据。 -* `dump.json`: 保存API或Module前反向数据的统计量信息。包含dump数据的API名称或Module名称,各数据的dtype、 shape、max、min、mean、L2norm(L2范数,平方根)统计信息以及当配置summary_mode="md5"时的CRC-32数据。具体介绍可参考[dump.json文件说明](./27.dump_json_instruction.md#1-dumpjson文件介绍pytorch)。 +* `dump.json`: 保存API或Module前反向数据的统计量信息。包含dump数据的API名称或Module名称,各数据的dtype、 shape、max、min、mean、L2norm(L2范数,平方根)统计信息以及当配置summary_mode="md5"时的CRC-32数据。具体介绍可参考[dump.json文件说明](./27.dump_json_instruction.md#1-PyTorch场景下的dump.json文件)。 +* `dump_error_info.log`: 仅在dump工具报错时拥有此记录日志,用于记录dump错误日志。 * `stack.json`:API/Module的调用栈信息。 * `construct.json`:分层分级结构,level为L1时,construct.json内容为空。 @@ -366,12 +505,14 @@ dump 过程中,pt 文件在对应算子或者模块被执行后就会落盘, pt 文件保存的前缀和 PyTorch 对应关系如下: -| 前缀 | Torch模块 | -| ----------- | ------------------- | +| 前缀 | Torch模块 | +|-------------|---------------------| | Tensor | torch.Tensor | | Torch | torch | | Functional | torch.nn.functional | -| NPU | NPU 亲和算子 | +| NPU | NPU 亲和算子 | | VF | torch._VF | | Aten | torch.ops.aten | | Distributed | torch.distributed | +| MindSpeed | mindspeed.ops | + diff --git a/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md b/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md index f7507facd2a92f3acbefdc92fa6cd808a155d6e3..66b76cc1fb88d58f7a663e3e90fb95be45883dec 100644 --- a/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md +++ b/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md @@ -26,18 +26,26 @@ msprobe 工具通过在训练脚本中添加 `PrecisionDebugger` 接口并启动训练的方式,采集模型在运行过程中的精度数据。该工具支持对MindSpore的静态图和动态图场景进行不同Level等级的精度数据采集。 -dump 的"tensor"模式采集数据量大小,可以参考[数据量基线](data_dump_MindSpore/data_dump_MindSpore_baseline.md)。 +dump "statistics"模式的性能膨胀大小"与"tensor"模式采集的数据量大小,可以参考[dump基线](data_dump_MindSpore/data_dump_MindSpore_baseline.md)。 + +**注意**: + +* 因 MindSpore 框架自动微分机制的限制,dump 数据中可能会缺少原地操作模块/API 及其上一个模块/API 的反向数据。 + +* 使用msprobe工具后loss/gnorm发生变化:可能是工具中的item操作引入同步,pt/ms框架的hook机制等原因导致的,详见《工具导致计算结果变化》。 ## 5. 场景介绍 ### 5.1 静态图场景 -在静态图场景下,msprobe 仅支持 **L2 Level** 的数据采集。 +在静态图场景下,msprobe 支持 **L0 Level** 和 **L2 Level** 的数据采集。且当 MindSpore 版本高于 2.5.0 时,若需采集 **L2 Level** 数据,必须使用编包时添加了`--include-mod=adump`选项的 mindstudio-probe whl 包进行 msprobe 工具安装。 +- **L0 Level(Cell 级)** :采集 `Cell` 对象的数据,适用于需要分析特定网络模块的情况。仅支持 2.7.0 及以上版本的 MindSpore 框架。 + - **L2 Level(Kernel 级)** :采集底层算子的输入输出数据,适用于深入分析算子级别的精度问题。 采集方式请参见[示例代码 > 静态图场景](#71-静态图场景)。详细介绍请参见[《config.json 配置文件介绍》](./02.config_introduction.md#11-通用配置)中的“level 参数”和[《config.json 配置示例》](./03.config_examples.md#2-mindspore-静态图场景) 中的“MindSpore 静态图场景”。 ### 5.2 动态图场景 -在动态图场景下,msprobe 支持 **L0** 、**L1** 、**mix** 、**L2 Level**、 **debug** 的数据采集,具体分为以下几种情况: +在动态图场景下,msprobe 支持 **L0** 、**L1** 、**mix** 、**L2**、 **debug** 的数据采集,具体分为以下几种情况: - **使用高阶 API(如 `Model 高阶API`)** : - 需要使用 `MsprobeStep` 回调类来控制数据采集的启停,适用于 **L0** 、**L1** 、**mix** 、**L2** 数据采集。 @@ -46,7 +54,7 @@ dump 的"tensor"模式采集数据量大小,可以参考[数据量基线](data 采集方式请参见[示例代码 > 动态图场景](#72-动态图场景)。 -> **注意** :动态图模式下,使用 `PSJit` 或 `PIJit` 装饰的部分实际以静态图模式执行,此时的 **Kernel 级(L2 Level)** 数据采集方式与静态图场景相同。 +> **注意** :动态图模式下,使用 `mindspore.jit` 装饰的部分实际以静态图模式执行,此时的 **Kernel 级(L2 Level)** 数据采集方式与静态图场景相同。 - **L0 Level(Cell 级)** :采集 `Cell` 对象的数据,适用于需要分析特定网络模块的情况。 - **L1 Level(API 级)** :采集 MindSpore API 的输入输出数据,适用于定位 API 层面的精度问题。 @@ -56,7 +64,7 @@ dump 的"tensor"模式采集数据量大小,可以参考[数据量基线](data - **debug level (单点保存)**:单点保存网络中变量的正反向数据,适用于用户熟悉网络结构的场景。 -详细介绍请参见[《config.json 配置文件介绍》](./02.config_introduction.md#11-通用配置)中的“level 参数”和[《config.json 配置示例》](./03.config_examples.md#3-mindspore-动态图场景) 中的“MindSpore 动态图场景”。 +详细介绍请参见[《config.json 配置文件介绍》](./02.config_introduction.md#11-通用配置)中的“level 参数”。 ## 6 接口介绍 @@ -80,22 +88,25 @@ PrecisionDebugger(config_path=None, task=None, dump_path=None, level=None, step= #### 6.1.1 start -**功能说明**:启动精度数据采集。需在模型执行模式(静态图/动态图、O0/O1/O2编译等级)设置后调用。静态图场景下,必须在模型初始化及 mindspore.communication.init 调用前添加;动态图场景下,如果没有使用 [Model](https://gitee.com/link?target=https%3A%2F%2Fwww.mindspore.cn%2Ftutorials%2Fzh-CN%2Fr2.3.1%2Fadvanced%2Fmodel.html) 高阶 API 进行训练,则需要与 stop 函数一起添加在 for 循环内,否则只有需要传入model参数时,才使用该接口。 +**功能说明**:启动精度数据采集。静态图场景下,必须在mindspore.communication.init 调用前添加。如果没有使用 [Model](https://gitee.com/link?target=https%3A%2F%2Fwww.mindspore.cn%2Ftutorials%2Fzh-CN%2Fr2.3.1%2Fadvanced%2Fmodel.html) 高阶 API 进行训练,则需要与 stop 函数一起添加在 for 循环内,否则只有需要传入model参数时,才使用该接口。 **原型**: ```Python -start(model=None) +start(model=None, token_range=None) ``` **参数说明**: -1. model:指定需要采集数据的实例化模型,支持传入mindspore.nn.Cell、List[mindspore.nn.Cell]或Tuple[mindspore.nn.Cell] 类型, 默认未配置。Cell级别("L0" level)dump 与 "mix" level dump 时,必须传入 model 才可以采集 model 内的所有Cell 对象数据。API级别("L1" level)dump 时,传入 model 可以采集 model 内包含 primitive op 对象在内的所有 API 数据,若不传入 model 参数,则只采集非 primitive op 的 API 数据。 +1. model:指定需要采集数据的实例化模型,支持传入mindspore.nn.Cell、List[mindspore.nn.Cell]或Tuple[mindspore.nn.Cell] 类型,默认未配置。Cell级别("L0" level)dump 与 "mix" level dump 时,必须传入 model 才可以采集 model 内的所有 Cell 对象数据,且若存在会进行图编译的 Cell 对象(例如被 `mindspore.jit` 装饰的 Cell),则必须在第一个 step 训练开始前调用 `start` 接口。API级别("L1" level)dump 时,传入 model 可以采集 model 内包含 primitive op 对象在内的所有 API 数据,若不传入 model 参数,则只采集非 primitive op 的 API 数据。token_range不为None时,必须传入model参数。 +
对于复杂模型,如果仅需要监控一部分(如model.A,model.A extends mindspore.nn.Cell),传入需要监控的部分(如model.A)即可。 +注意:传入的当前层不会被dump,工具只会dump传入层的子层级。如传入了model.A,A本身不会被dump,而是会dump A.x, A.x.xx等。 +2. token_range:指定推理模型采集时的token循环始末范围,支持传入[int, int]类型,代表[start, end],范围包含边界,默认未配置。 #### 6.1.2 stop **功能说明**:停止精度数据采集。在 **start** 函数之后的任意位置添加。若 **stop** 函数添加在反向计算代码之后,则会采集 **start** 和该函数之间的前反向数据。 -若 **stop** 函数添加在反向计算代码之前,则需要将 [**step**](#613-step) 函数添加到反向计算代码之后,才能采集 **start** 和该函数之间的前反向数据。 +若 **stop** 函数添加在反向计算代码之前,则需要将 [**step**](#613-step) 函数添加到反向计算代码之后,才能采集 **start** 和该函数之间的前反向数据,参考[**采集指定代码块的前反向数据**](#7213-采集指定代码块的前反向数据)。 **仅未使用 Model 高阶 API 的动态图场景支持。** **注意**:**stop** 函数必须调用,否则可能导致精度数据落盘不全。 @@ -110,7 +121,7 @@ stop() **功能说明**:结束一个 step 的数据采集,完成所有数据落盘并更新 dump 参数。在一个 step 结束的位置添加,且必须在 **stop** 函数之后的位置调用。 该函数需要配合 **start** 和 **stop** 函数使用,尽量添加在反向计算代码之后,否则可能会导致反向数据丢失。 -**仅未使用 Model 高阶 API 的动态图场景支持。** +**仅未使用 Model 高阶 API 的动态图和静态图场景支持。** **原型**: @@ -144,15 +155,67 @@ save(variable, name, save_backward=True) **参数说明**: | 参数名称 | 参数含义 | 支持数据类型 | 是否必选| | ---------- | ------------------| ------------------- | ------------------- | -| variable | 需要保存的变量 |dict, list, torch.tensor, int, float, str | 是 | +| variable | 需要保存的变量 |dict, list, tuple, torch.tensor, int, float, str | 是 | | name | 指定的名称 | str | 是 | | save_backward | 是否保存反向数据 | boolean | 否 | +具体使用样例可参考:[单点保存工具使用介绍](./28.debugger_save_instruction.md)。 + +#### 6.1.6 set_init_step + +**功能说明**:设置起始step数,step数默认从0开始计数,使用该接口后step从指定值开始计数。该函数需要写在训练迭代的循环开始前,不能写在循环内。 + +**原型**: + +```Python +set_init_step(step) +``` + +**参数说明**: + +1.step: 指定的起始step数。 + + +#### 6.1.7 register_custom_api + +**功能说明**:注册用户自定义的api到工具,用于 L1 dump 。 + +**原型**: + +```Python +debugger.register_custom_api(module, api_name, api_prefix) +``` +**参数说明**: + +以 torch.matmul api 为例 + +1.module: api 所属的包,即传入 torch。 + +2.api_name: api 名,string类型,即传入 "matmul"。 + +3.api_prefix: [dump.json](./27.dump_json_instruction.md) 中 api 名的前缀,可选,默认为包名的字符串格式, 即 "torch"。 + +#### 6.1.8 restore_custom_api + +**功能说明**:恢复用户原有的自定义的api,取消 dump 。 +**原型**: + +```Python +debugger.restore_custom_api(module, api_name) +``` +**参数说明**: + +以 torch.matmul api 为例 + +1.module: api 所属的包,即传入 torch。 + +2.api_name: api 名,string类型,即传入 "matmul"。 -### 6.2 msprobe.mindspore.common.utils.MsprobeStep -**功能说明**:MindSpore Callback类,自动在每个step开始时调用start()接口,在每个step结束时调用stop()、step()接口。实现使用 Model 高阶 API 的动态图场景下 L0、L1、mix 级别的精度数据采集控制,控制粒度为单个 **Step** ,而 PrecisionDebugger.start, PrecisionDebugger.stop 接口的控制粒度任意训练代码段。 +### 6.2 msprobe.mindspore.MsprobeStep + +**功能说明**:MindSpore Callback类,自动在每个step开始时调用start()接口,在每个step结束时调用stop()、step()接口。实现使用 Model 高阶 API 的动态图场景下 L0、L1、mix 级别,和静态图场景下 L0级别的精度数据采集控制,控制粒度为单个 **Step** ,而 PrecisionDebugger.start, PrecisionDebugger.stop 接口的控制粒度为任意训练代码段。 **原型**: @@ -164,13 +227,23 @@ MsprobeStep(debugger) 1. debugger:PrecisionDebugger对象。 -### 6.3 msprobe.mindspore.seed_all +### 6.3 msprobe.mindspore.MsprobeInitStep + +**功能说明**:MindSpore Callback 类,自动获取并设置初始 step 值。仅适用于静态图 O0/O1 模式的断点续训场景。 + +**原型**: + +```Python +MsprobeInitStep() +``` + +### 6.4 msprobe.mindspore.seed_all **功能说明**:用于固定网络中的随机性和开启确定性计算。 **原型**: ```python -seed_all(seed=1234, mode=False, rm_dropout=True) +seed_all(seed=1234, mode=False, rm_dropout=False) ``` **参数说明**: @@ -179,14 +252,61 @@ seed_all(seed=1234, mode=False, rm_dropout=True) 2. mode:确定性计算使能,可配置 True 或 False,默认值:False,非必选。参数示例:mode=True。该参数设置为 True 后,将会开启算子确定性运行模式与归约类通信算子(AllReduce、ReduceScatter、Reduce)的确定性计算。注意:确定性计算会导致API执行性能降低,建议在发现模型多次执行结果不同的情况下开启。 -3. rm_dropout:控制dropout失效的开关。可配置 True 或 False,默认值:True,非必选。参数示例:rm_dropout=True。该参数设置为 True 后,将会使mindspore.ops.Dropout,mindspore.ops.Dropout2D,mindspore.ops.Dropout3D,mindspore.mint.nn.Dropout和mindspore.mint.nn.functional.dropout失效,以避免因随机dropout造成的网络随机性。建议在采集mindspore数据前开启。注意:通过rm_dropout控制dropout失效或生效需要在初始化Dropout实例前调用才能生效。 +3. rm_dropout:控制dropout失效的开关。可配置 True 或 False,默认值:False,非必选。参数示例:rm_dropout=True。该参数设置为 True 后,将会使mindspore.ops.Dropout,mindspore.ops.Dropout2D,mindspore.ops.Dropout3D,mindspore.mint.nn.Dropout和mindspore.mint.nn.functional.dropout失效,以避免因随机dropout造成的网络随机性。建议在采集mindspore数据前开启。注意:通过rm_dropout控制dropout失效或生效需要在初始化Dropout实例前调用才能生效。 +## 7. 示例代码 +### 7.1 静态图场景 +#### 7.1.1 L0 级别 -## 7. 示例代码 +**说明**: 静态图 L0 级别的Dump功能是基于mindspore.ops.TensorDump算子实现。在Ascend平台上的Graph模式下,可以通过设置环境变量 [MS_DUMP_SLICE_SIZE 和 MS_DUMP_WAIT_TIME](https://www.mindspore.cn/docs/zh-CN/r2.5.0/api_python/env_var_list.html) 解决在输出大Tesnor或输出Tensor比较密集场景下算子执行失败的问题。 -### 7.1 静态图场景 +##### 7.1.1.1 未使用 Model 高阶 API + + +```python +import mindspore as ms +ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend") + +from msprobe.mindspore import PrecisionDebugger +debugger = PrecisionDebugger(config_path="./config.json") + +# 模型、损失函数的定义以及初始化等操作 +# ... +model = Network() +# 数据集迭代的地方往往是模型开始训练的地方 +for data, label in data_loader: + debugger.start(model) # 进行 L0 级别下Cell 对象的数据采集时调用 + # 如下是模型每个 step 执行的逻辑 + grad_net = ms.grad(model)(data) + # ... + debugger.step() # 更新迭代数 +``` + +##### 7.1.1.2 使用 Model 高阶 API + + +```python +import mindspore as ms +from mindspore.train import Model +ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend") + +from msprobe.mindspore import PrecisionDebugger +from msprobe.mindspore.common.utils import MsprobeStep +debugger = PrecisionDebugger(config_path="./config.json") + +# 模型、损失函数的定义以及初始化等操作 +# ... + +model = Network() +# 进行 L0 级别下 Cell 对象的数据采集时调用 +debugger.start(model) +trainer = Model(model, loss_fn=loss_fn, optimizer=optimizer, metrics={'accuracy'}) +trainer.train(1, train_dataset, callbacks=[MsprobeStep(debugger)]) +``` + +#### 7.1.2 L2 级别 ```python import mindspore as ms @@ -198,7 +318,8 @@ debugger.start() # 请勿将以上初始化流程置于模型实例化或 mindspore.communication.init 调用后 # 模型定义和训练代码 # ... - +debugger.stop() +debugger.step() ``` ### 7.2 动态图场景 @@ -251,6 +372,34 @@ trainer = Model(model, loss_fn=loss_fn, optimizer=optimizer, metrics={'accuracy' trainer.train(1, train_dataset, callbacks=[MsprobeStep(debugger)]) ``` +##### 7.2.1.3 采集指定代码块的前反向数据 + +```python +import mindspore as ms +from mindspore import set_device +from mindspore.train import Model +ms.set_context(mode=ms.PYNATIVE_MODE) + +set_device("Ascend", 0) + +from msprobe.mindspore import PrecisionDebugger +from msprobe.mindspore.common.utils import MsprobeStep +debugger = PrecisionDebugger(config_path="./config.json") + +# 模型、损失函数的定义及初始化等操作 +# ... +# 数据集迭代的位置一般为模型训练开始的位置 +for data, label in data_loader: + debugger.start() # 开启数据dump + # 如下是模型每个step执行的逻辑 + output = model(data) + + debugger.stop() # 插入该函数到start函数之后,只dump start函数到该函数之间的前反向数据,可以支持start-stop-start-stop-step分段采集。 + # ... + loss.backward() + debugger.step() # 结束一个step的dump +``` + #### 7.2.2 L2 级别 ##### 7.2.2.1 未使用 Model 高阶 API @@ -297,15 +446,43 @@ trainer = Model(model, loss_fn=loss_fn, optimizer=optimizer, metrics={'accuracy' trainer.train(1, train_dataset) ``` + +#### 7.2.3 推理模型采集指定token_range +需要配合mindtorch套件改造原推理代码,套件包装后使用方式与torch一致,唯一区别为import的是msprobe.mindspore下的PrecisionDebugger。 + +```Python +from vllm import LLM, SamplingParams +from msprobe.mindspore import PrecisionDebugger, seed_all +# 在模型训练开始前固定随机性 +seed_all() +# 请勿将PrecisionDebugger的初始化流程插入到循环代码中 +debugger = PrecisionDebugger(config_path="./config.json", dump_path="./dump_path") +# 模型定义及初始化等操作 +prompts = ["Hello, my name is"] +sampling_params = SamplingParams(temprature=0.8, top_p=0.95) +llm = LLM(model='...') +model = llm.llm_engine.model_executor.driver_worker.worker.model_runner.get_model() +# 开启数据dump, 指定采集推理模型逐字符循环推理中的第1~3次 +debugger.start(model=model, token_range=[1,3]) +# 推理模型生成的逻辑 +output = llm.generate(prompts, sampling_params=sampling_params) +# 关闭数据dump并落盘 +debugger.stop() +debugger.step() +``` + ## 8. dump 结果文件介绍 ### 8.1 静态图场景 -训练结束后,数据将保存在 `dump_path` 指定的目录下。 +训练结束后,数据将保存在 `dump_path` 指定的目录下。
+L0 级别 dump 的目录结构与动态图场景下目录结构一致。
+L2 级别 dump 的目录结构如下所示: -若jit_level=O2,且使用mindstudio-probe发布包或源码编包时添加了`--include-mod=adump`选项,目录结构示例如下: +若jit_level=O2,MindSpore 版本不低于 2.5.0,且使用mindstudio-probe发布包或源码编包时添加了`--include-mod=adump`选项,目录结构示例如下: ``` ├── dump_path +│ ├── acl_dump_{device_id}.json │ ├── rank_0 │ | ├── {timestamp} │ | │ ├── step_0 @@ -329,10 +506,9 @@ trainer.train(1, train_dataset) **说明** 1. 若配置文件中指定落盘npy格式,但是实际数据格式不在npy支持范围内(如bf16、int4等),则该tensor会以原始码流落盘,并不会转换为npy格式。 2. 若原始文件全名长度超过255个字符,则文件基础名会被转换为长度为32位的随机数字字符串,原始文件名与转换后文件名的对应关系会保存在同目录下的`mapping.csv`文件中。 +3. acl_dump_{device_id}.json 为在 Dump 接口调用过程中生成的中间文件,一般情况下无需关注。 - -其他场景请参见 MindSpore 官方文档中的[数据对象目录](https://www.mindspore.cn/docs/zh-CN/r2.4.0/model_train/debug/dump.html)。 - +其他场景下,除 kernel_kbyk_dump.json(jit_level=O0/O1)、kernel_graph_dump.json(jit_level=O2)等无需关注的中间文件外的其他 dump 结果文件请参见 MindSpore 官方文档中的[ Ascend 下 O0/O1 模式 Dump 数据对象目录和数据文件介绍](https://www.mindspore.cn/docs/zh-CN/r2.5.0/model_train/debug/dump.html#%E6%95%B0%E6%8D%AE%E5%AF%B9%E8%B1%A1%E7%9B%AE%E5%BD%95%E5%92%8C%E6%95%B0%E6%8D%AE%E6%96%87%E4%BB%B6%E4%BB%8B%E7%BB%8D) ### 8.2 动态图场景 dump 结果目录结构示例如下: @@ -348,19 +524,21 @@ dump 结果目录结构示例如下: | | | | ├── Tensor.__add__.0.forward.output.0.npy | | | | ... | | | | ├── Jit.AlexNet.0.forward.input.0.npy -| | | | ├── Primitive.conv2d.Conv2D.0.forward.input.0.npy -| | | | ├── Cell.conv1.Conv2D.forward.0.parameters.weight.npy # 模块参数数据:命名格式为{Cell}.{cell_name}.{class_name}.forward.{调用次数}.parameters.{parameter_name}。 -| | | | ├── Cell.conv1.Conv2D.parameters_grad.weight.npy # 模块参数梯度数据:命名格式为{Cell}.{cell_name}.{class_name}.parameters_grad.{parameter_name}。因为同一模块的参数使用同一梯度进行更新,所以参数梯度文件名不包含调用次数。 +| | | | ├── Primitive.conv2d.Conv2d.0.forward.input.0.npy +| | | | ├── Cell.conv1.Conv2d.forward.0.parameters.weight.npy # 模块参数数据:命名格式为{Cell}.{cell_name}.{class_name}.forward.{调用次数}.parameters.{parameter_name}。 +| | | | ├── Cell.conv1.Conv2d.parameters_grad.weight.npy # 模块参数梯度数据:命名格式为{Cell}.{cell_name}.{class_name}.parameters_grad.{parameter_name}。因为同一模块的参数使用同一梯度进行更新,所以参数梯度文件名不包含调用次数。 | | | | └── Cell.relu.ReLU.forward.0.input.0.npy # 命名格式为{Cell}.{cell_name}.{class_name}.{forward/backward}.{调用次数}.{input/output}.{参数序号}, 其中,“参数序号”表示该Cell的第n个参数,例如1,则为第一个参数,若该参数为list格式,则根据list继续排序,例如1.1,表示该Cell的第1个参数的第1个元素。 | | | | # 当dump时传入的model参数为List[mindspore.nn.Cell]或Tuple[mindspore.nn.Cell]时,模块级数据的命名中包含该模块在列表中的索引index,命名格式为{Cell}.{index}.*,*表示以上三种模块级数据的命名格式,例如:Cell.0.relu.ReLU.forward.0.input.0.npy。 │ | | ├── dump.json │ | | ├── stack.json +│ | | ├── dump_error_info.log │ | | └── construct.json │ | ├── rank1 | | | ├── dump_tensor_data | | | | └── ... │ | | ├── dump.json │ | | ├── stack.json +│ | | ├── dump_error_info.log | | | └── construct.json │ | ├── ... │ | | @@ -372,17 +550,44 @@ dump 结果目录结构示例如下: * `rank`:设备 ID,每张卡的数据保存在对应的 `rank{ID}` 目录下。非分布式场景下没有 rank ID,目录名称为 rank。 * `dump_tensor_data`:保存采集到的张量数据。 -* `dump.json`: 保存API或Cell前反向数据的统计量信息。包含dump数据的API名称或Cell名称,各数据的dtype、 shape、max、min、mean、L2norm(L2范数,平方根)统计信息以及当配置summary_mode="md5"时的CRC-32数据。具体介绍可参考[dump.json文件说明](./27.dump_json_instruction.md#2-dumpjson文件示例mindspore)。 +* `dump.json`: 保存API或Cell前反向数据的统计量信息。包含dump数据的API名称或Cell名称,各数据的dtype、 shape、max、min、mean、L2norm(L2范数,平方根)统计信息以及当配置summary_mode="md5"时的CRC-32数据。具体介绍可参考[dump.json文件说明](./27.dump_json_instruction.md#2-mindspore-场景下的-dumpjson-文件)。 +* `dump_error_info.log`: 仅在dump工具报错时拥有此记录日志,用于记录dump错误日志。 * `stack.json`:API/Cell的调用栈信息。 -* `construct.json`:分层分级结构,level为L1时,construct.json内容为空。 +* `construct.json`:根据model层级展示分层分级结构,level为L1时,construct.json内容为空。 dump 过程中,npy 文件在对应API或者模块被执行后就会落盘,而 json 文件则需要在正常执行 PrecisionDebugger.stop() 后才会写入完整数据,因此,程序异常终止时,被执行API对应的 npy 文件已被保存,但 json 文件中的数据可能丢失。 -动态图场景下使能 PSJit 或 PIJit,装饰特定 Cell 或 function,被装饰的部分会全部/部分使能**静态图**流程。 +动态图场景下使用 `mindspore.jit` 装饰特定 Cell 或 function 时,被装饰的部分会被编译成**静态图**执行。 -- PSJit 场景下 config.json 文件配置 level 为 L1 时,被 PSJit 装饰的部分也作为 API 被 dump 到对应目录;配置 level 为 L2 时,则只会 dump 用户网络中静态图流程下的相关 kernel,其结果目录同jit_level 为 O0/O1 时的静态图 dump 相同。 -- PIJit 场景下 config.json 文件配置 level 为 L1 时,会被还原为动态图,按 API 粒度进行 dump;配置 level 为 L2 时,则只会 dump 用户网络中静态图流程下的相关 kernel。 +- config.json 文件配置 level 为 L0 或 mix,且 MindSpore 版本不低于 2.7.0 时, 若存在 construct 方法被 `mindspore.jit` 装饰的 Cell 对象,则 dump_path 下将生成 `graph` 与 `pynative` 目录,分别存放 construct 方法被 `mindspore.jit` 装饰的 Cell 对象的精度数据、其它Cell 或 API 对象的精度数据。示例如下: +```lua +├── dump_path +│ ├── graph +│ | ├── step0 +│ | | ├── rank0 +│ | │ | ├── dump_tensor_data +| | | | | ├── ... +│ | | | ├── dump.json +│ | | | ├── stack.json +│ | | | └── construct.json +│ | | ├── ... +│ ├── pynative +│ | ├── step0 +│ | | ├── rank0 +│ | │ | ├── dump_tensor_data +| | | | | ├── ... +│ | | | ├── dump.json +│ | | | ├── stack.json +│ | | | └── construct.json +│ | | ├── ... +``` + +**注意**:因为在被 `mindspore.jit` 装饰的 construct 方法前后插入的 Dump 算子既处于动态图模式,也处于静态图模式,所以最外层被装饰的 Cell 对象的精度数据将被重复采集。 + +- config.json 文件配置 level 为 L1 时, 若 `mindspore.jit` 的 `capture_mode` 参数设置为 ast(原 PSJit 场景), 则被装饰的部分也作为 API 被 dump 到对应目录;若 `mindspore.jit` 的 `capture_mode` 参数设置为 bytecode(原 PIJit 场景), 则被装饰的部分会被还原为动态图,按 API 粒度进行 dump。 + +- config.json 文件配置 level 为 L2 时, 仅会 dump 被 `mindspore.jit` 装饰部分的 kernel 精度数据,其结果目录同 jit_level 为 O0/O1 时的静态图 dump 结果相同。 npy文件名的前缀含义如下: @@ -393,12 +598,12 @@ npy文件名的前缀含义如下: | Primitive | mindspore.ops.Primitive API数据 | | Mint | mindspore.mint API数据 | | MintFunctional | mindspore.mint.nn.functional API数据 | +| MintDistributed | mindspore.mint.distributed API数据 | | Distributed | mindspore.communication.comm_func API数据 | | Jit | 被"jit"装饰的模块或函数数据 | | Cell | mindspore.nn.Cell 类(模块)数据 | - ## 9.补充说明 ### 9.1 修改 API 支持列表 diff --git a/debug/accuracy_tools/msprobe/docs/07.accuracy_checker_PyTorch.md b/debug/accuracy_tools/msprobe/docs/07.accuracy_checker_PyTorch.md index b07568e25a2915a4e8e5c2157e7de4252410f38d..1862e3d0c736db202690f0e2cc10f0b9bb9c8348 100644 --- a/debug/accuracy_tools/msprobe/docs/07.accuracy_checker_PyTorch.md +++ b/debug/accuracy_tools/msprobe/docs/07.accuracy_checker_PyTorch.md @@ -2,7 +2,7 @@ ## 1 简介 -**PyTorch 离线精度预检**通过扫描昇腾 NPU 上用户训练模型中的所有 API,输出模型精度的诊断和分析结果。具体而言,该工具通过采集模型中所有 API 的前反向信息,构造相应的单元测试,将 NPU 输出与标杆(CPU 高精度)比对,从而计算对应的精度指标,该过程通过子命令 run_ut 执行;将 NPU 环境下采集的预检数据拷贝至 GPU 环境,同样执行 run_ut;最后通过**新精度标准比对法**a将 NPU 和 GPU 的预检结果进行比对,从而找出 NPU 中存在精度问题的 API。同时,本工具支持**随机生成模式和真实数据模式**b。 +**PyTorch 离线精度预检**通过扫描昇腾 NPU 上用户训练模型中的PyTorch API,输出模型精度的诊断和分析结果。具体而言,该工具通过采集模型中所有 API 的前反向信息,构造相应的单元测试,将 NPU 输出与标杆(CPU 高精度)比对,从而计算对应的精度指标,该过程通过子命令 run_ut 执行;将 NPU 环境下采集的预检数据拷贝至 GPU 环境,同样执行 run_ut;最后通过**新精度标准比对法**a将 NPU 和 GPU 的预检结果进行比对,从而找出 NPU 中存在精度问题的 API。同时,本工具支持**随机生成模式和真实数据模式**b。 a. 依据新精度标准,对不同的API采取不同的比对算法(包括绝对阈值法,标杆比对法、二进制一致法、ULP误差比对法和双千指标法),最终给定预检判定结果; @@ -13,9 +13,9 @@ b. 在预检 dump 时可以选择由工具构造随机数获得 dump 数据或 1. 在 NPU 和 GPU 环境下分别安装 msprobe。详见[ msprobe 安装](./01.installation.md)章节。 2. 在 NPU 训练脚本内添加 msprobe 工具 dump 接口 PrecisionDebugger,采集待预检数据。注意需要配置 level="L1"。 3. 将 NPU 环境下 dump 的预检数据拷贝至 GPU 环境。 -4. 在 NPU 和 GPU 环境下分别执行 run_ut,生成的结果最终用于 api_precision_compare.py 函数的输入。详见 [3 离线预检操作指导](#3-离线预检操作指导)。 +4. 在 NPU 和 GPU 环境下分别执行 run_ut,生成的结果最终用于 api_precision_compare 的输入。详见 [3 离线预检操作指导](#3-离线预检操作指导)。 5. 将 NPU 和 GPU 执行 run_ut 生成的 `accuracy_checking_details_{timestamp}.csv` 结果文件拷贝至同一环境下。 -6. 运行 api_precision_compare.py,输出结果为预检操作的最终结果。详见 [5 预检结果比对](#5-预检结果比对)章节。 +6. 运行 api_precision_compare,输出结果为预检操作的最终结果。详见 [5 预检结果比对](#5-预检结果比对)章节。 ## 3 离线预检操作指导 @@ -34,16 +34,17 @@ run_ut 预检操作包括以下两种方式: msprobe -f pytorch run_ut -api_info ./dump_path/step{step_number}/rank{rank_number}/dump.json ``` - | 参数名称 | 解释 | 是否必选 | - | ---------------------------- | ------------------------------------------------------------ | ---------------------------------- | - | -api_info 或 --api_info_file | 指定 API 信息文件 dump.json。 | 是 | - | -save_error_data | 保存精度未达标的 API 输入输出数据。 | 否 | - | -o 或 --out_path | 指定 run_ut 执行结果存盘路径,默认“./”。 | 否 | - | -j 或 --jit_compile | 开启 jit 编译。 | 否 | - | -d 或 --device | 指定 Device ID,选择 UT 代码运行所在的卡,默认值为 0。 | 否 | + | 参数名称 | 解释 | 是否必选 | + |-------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| ---------------------------------- | + | -f 或 --framework | 指定训练框架。pytorch。 | 是 | + | -api_info 或 --api_info_file | 指定 API 信息文件 dump.json。 | 是 | + | -save_error_data | 保存精度未达标的 API 输入输出数据。 | 否 | + | -o 或 --out_path | 指定 run_ut 执行结果存盘路径,默认“./”。 | 否 | + | -j 或 --jit_compile | 开启 jit 编译。 | 否 | + | -d 或 --device | 指定 Device ID,选择 UT 代码运行所在的卡,默认值为 0。 | 否 | | -csv_path 或 --result_csv_path | 指定本次运行中断时生成的 `accuracy_checking_result_{timestamp}.csv` 文件路径,执行 run_ut 中断时,若想从中断处继续执行,配置此参数即可。需要指定为上次中断的 `accuracy_checking_result_{timestamp}.csv` 文件。详见 [3.3 断点续检](#33-断点续检)。 | run_ut 操作中断后继续执行场景下必须配置 | - | -f 或 --filter_api | 过滤模型中除最大值和最小值以外其他参数和结构相同的 API。适用于模型较大且重复 API 较多的场景。 | 否 | - | -config 或 --config_path | 指定离线预检操作过程中额外配置(包括黑名单、白名单等)的 [config.json](../config.json) 文件,默认未配置。config.json 文件的配置可参考[配置文件介绍](./02.config_introduction.md)。 | 否 | + | -f 或 --filter_api | 过滤模型中除最大值和最小值以外其他参数和结构相同的 API。适用于模型较大且重复 API 较多的场景。 | 否 | + | -config 或 --config_path | 指定离线预检操作过程中额外配置(包括黑名单、白名单等)的 [config.json](../config.json) 文件,默认未配置。config.json 文件的配置可参考[配置文件介绍](./02.config_introduction.md)。 | 否 | run_ut 执行结果包括 `accuracy_checking_result_{timestamp}.csv` 和 `accuracy_checking_details_{timestamp}.csv` 两个文件。`accuracy_checking_result_{timestamp}.csv` 属于 API 级,标明每个 API 是否通过测试。建议用户先查看 `accuracy_checking_result_{timestamp}.csv` 文件,对于其中没有通过测试的或者特定感兴趣的 API,根据其 API name 字段在 `accuracy_checking_details_{timestamp}.csv` 中查询其各个输出的达标情况以及比较指标。详细介绍请参见[ 4 预检结果](#4-预检结果)。 @@ -103,11 +104,12 @@ msprobe -f pytorch multi_run_ut -api_info ./dump_path/step{step_number}/rank{ran | 参数名称 | 解释 | 是否必选 | | ---------------------------- | ------------------------------------------------------------ | ---------------------------------- | +| -f 或 --framework | 指定训练框架。pytorch。 | 是 | | -api_info 或 --api_info_file | 指定 API 信息文件 dump.json。 | 是 | | -save_error_data | 保存精度未达标的 API 输入输出数据。 | 否 | | -o 或 --out_path | 指定 run_ut 执行结果存盘路径,默认“./”。 | 否 | | -j 或 --jit_compile | 开启 jit 编译。 | 否 | -| -n | 同时执行 run_ut 线程的数量,默认为 8,最大支持 64,但每个 Device 最大支持 8 个线程。当指定多个线程和多个 Device 时,线程数在每张卡上均分。 | 否 | +| -n 或 --num_splits | 同时执行 run_ut 线程的数量,默认为 8,最大支持 64,但每个 Device 最大支持 8 个线程。当指定多个线程和多个 Device 时,线程数在每张卡上均分。 | 否 | | -d 或 --device | 指定 Device ID,选择 UT 代码运行所在的卡,默认值为 0,支持同时指定 0~7,共 8 个 Device。 | 否 | | -csv_path 或 --result_csv_path | 指定本次运行中断时生成的 `accuracy_checking_result_{timestamp}.csv` 文件路径,执行 run_ut 中断时,若想从中断处继续执行,配置此参数即可。需要指定为上次中断的 `accuracy_checking_result_{timestamp}.csv` 文件。详见 [3.3 断点续检](#33-断点续检)。 | run_ut 操作中断后继续执行场景下必须配置 | | -f 或 --filter_api | 过滤模型中除最大值和最小值以外其他参数和结构相同的 API。适用于模型较大且重复 API 较多的场景。 | 否 | @@ -212,8 +214,9 @@ Forward Test Success 和 Backward Test Success 是否通过测试是由 `accurac msprobe -f pytorch api_precision_compare -npu /home/xxx/npu/accuracy_checking_details_{timestamp}.csv -gpu /home/xxx/gpu/accuracy_checking_details_{timestamp}.csv -o /home/xxx/ ``` -| 参数名称 | 说明 | 是否必选 | -| -------------------- | ------------- | -------- | +| 参数名称 | 说明 | 是否必选 | +|-----------------------| ------------- | -------- | +| -f 或 --framework | 指定训练框架。pytorch。 | 是 | | -npu 或 --npu_csv_path | NPU 预检结果 `accuracy_checking_details_{timestamp}.csv` 文件路径。默认从当前目录下识别该文件。 | 是 | | -gpu 或 --gpu_csv_path | GPU 预检结果 `accuracy_checking_details_{timestamp}.csv` 文件路径。默认从当前目录下识别该文件。 | 是 | | -o 或 --out_path | 指定 api_precision_compare.py 执行结果存盘路径,默认为当前目录。 | 否 | diff --git a/debug/accuracy_tools/msprobe/docs/08.accuracy_checker_online_PyTorch.md b/debug/accuracy_tools/msprobe/docs/08.accuracy_checker_online_PyTorch.md deleted file mode 100644 index a93ad3b62405d549a16e7196e2f2145de68e8674..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/docs/08.accuracy_checker_online_PyTorch.md +++ /dev/null @@ -1,233 +0,0 @@ -# PyTorch 场景的在线精度预检 - -## 1 简介 - -为了应对大模型场景下,通过离线预检方式 dump API 输入输出数据导致的存储资源紧张问题,提供在线精度预检功能。本功能实现在执行 NPU 训练操作的过程中,通过 TCP/IP 协议在 NPU -Host 与 GPU Host 设备间建立连接,将 NPU 上对应 API 的输入数据在 GPU 设备上运行,将两份输出数据进行比对,得到预检比对结果,从而减少数据 dump 的步骤,降低存储资源的占用。针对偏差较大的算子,两方比对(NPU vs. GPU)的方法缺少裁判进行裁定。 参考离线预检,在线预检场景同时支持两方比对和三方比对方式,按照 api 的精度标准要求,选择比对两方比对和三方比对。 - -## 2 在线精度预检流程 - -在线精度预检当前支持**局域网场景**和**共享存储场景**,请根据不同的场景选择对应的配置。 - -在线精度预检操作流程如下: - -1. 准备 GPU 和 NPU 可正常运行的训练环境,PyTorch 版本大于等于2.0,并保证两台 Host 在同一局域网内可正常通信或能通过共享存储进行通信。 -2. GPU 和 NPU Host 设备上同时安装msprobe工具,详见[ msprobe 安装](./01.installation.md)章节,其中在线预检要安装 twisted、pyOpenSSL,这些包为 Python 模块。 -3. 分别配置 GPU 侧、NPU 侧的 config.json 文件。 -4. 在 GPU 侧运行 `msprobe -f pytorch run_ut -config ./config.json`。 -5. 在 NPU 侧配置训练脚本。 -6. 在 NPU 侧执行训练。 - -## 3 在线精度预检操作指导 - -### 3.1 配置 config.json 文件 - -预检工具安装完成后,需要在 GPU 和 NPU 环境下分别配置 config.json。其中需要重点关注文件中的 is_online、is_benchmark_device、host 和 port 参数的配置,保障在线预检时 GPU 和 NPU 两台设备间的通信正常。 - -#### 3.1.1 GPU 侧在线预检配置说明 - -| 参数名称 | 说明 | 是否必选 | -|-----------------|--------------|------| -| task | 任务名称,str 类型,配置为 run_ut 表示预检任务。通过其他字段 is_online 判断离线预检、在线预检任务。 | 是 | -| white_list | 预检的 API 白名单,list[str] 类型。
**配置示例**:white_list=["conv1d", "conv2d"]。默认未配置白名单,即预检全量 API 数据。 | 否 | -| black_list | 预检的 API 黑名单,list[str] 类型。
**配置示例**:white_list=["conv1d", "conv2d"]。默认未配置黑名单,即预检全量 API 数据。 | 否 | -| error_data_path | 配置保存精度未达标的 API 输入输出数据路径,str 类型。在线预检模式下该参数不生效。 | 否 | -| is_online | 在线预检模式开关,bool 类型,可取值 True(开启)、False(关闭),默认关闭。 | 是 | -| nfs_path | 在线预检模式共享存储目录路径,str 类型,用于 GPU 设备和 NPU 设备间进行通信。配置该参数后 host、port 和 tls_path 不生效。 | 否 | -| host | 在线预检模式局域网场景信息接收端 IP,str 类型,用于 GPU 设备和 NPU 设备间进行通信,GPU 侧配置为本机地址 127.0.0.1 或本机局域网 IP。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。 | 否 | -| port | 在线预检模式局域网场景信息接收端端口号,int 类型,用于 GPU 设备和 NPU 设备间进行通信,GPU 侧配置为本机可用端口。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。 | 否 | -| rank_list | 指定在线预检的 Rank ID,默认值为 [0],list[int] 类型,应配置为大于等于 0 的整数,且须根据实际卡的 Rank ID 配置,若所配置的值大于实际训练所运行的卡的 Rank ID,则在线预检输出数据为空。GPU 和 NPU 须配置一致。 | 是 | -| tls_path | 在线预检模式局域网场景 SSL 证书路径,该路径下包含私钥文件 server.key 和公钥文件 server.crt,str 类型,未配置该参数时默认取值当前路径。tls_path配置为空字符串时,采用TCP协议明文传输api数据;当配置为路径时,采用TLS1.2协议加密传输数据,加密传输时安全性较高,传输速率较低。 | 否 | - - -#### 3.1.2 NPU 侧在线预检配置说明 - -| 参数名称 | 说明 | 是否必选 | -|------------------|-------------|------| -| task | 任务名称,str 类型,配置为 tensor 表示 dump API 统计信息和完全复刻整网的 API 运行情况的真实数据。通过字段 online_run_ut 判断是否使用在线预检功能。 | 是 | -| dump_path | dump 路径,str 类型,配置为合法路径即可,兼容 tensor 任务静态检查。 | 是 | -| level | dump 级别,str 类型,在线预检时配置为 L1,表示 dump API 级精度数据。在线预检可不配置,默认取值 L1。 | 是 | -| rank | 指定对某张卡上的数据进行 dump,list[int] 类型,默认未配置(表示 dump所有卡的数据),需要与 GPU 侧配置项 rank_list 保持一致。 | 否 | -| step | 指定 dump 某个 step 的数据,list[int] 类型,默认未配置,表示 dump 所有 step 的数据。dump 特定 step 时,须指定为训练脚本中存在的 step。 | 否 | -| scope | dump 范围,list[str] 类型,默认未配置(list 也未配置时表示 dump 所有 api 的数据),配置方式参考 [config.json 配置介绍](./02.config_introduction.md)。 | 否 | -| list | dump 范围,list[str] 类型,默认未配置(scope 也未配置时表示 dump 所有 api 的数据),配置方式参考 [config.json 配置介绍](./02.config_introduction.md)。 | 否 | -| online_run_ut | 在线预检模式开关,bool 类型,可取值 True(开启)、False(关闭),默认关闭。 | 是 | -| nfs_path | 在线预检模式共享存储目录路径,str 类型,用于 GPU 设备和 NPU 设备间进行通信。配置该参数后 host 和 port 不生效。 | 否 | -| host | 在线预检模式局域网场景信息接收端 IP,str 类型,用于 GPU 设备和 NPU 设备间进行通信,NPU 侧须配置为 GPU 侧的局域网 IP 地址。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。 | 否 | -| port | 在线预检模式局域网场景信息接收端端口号,int 类型,用于 GPU 设备和 NPU 设备间进行通信,NPU 侧须配置为 GPU 侧的端口号。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。 | 否 | -| tls_path | 在线预检模式局域网场景 SSL 证书路径,该路径下包含私钥文件 client.key 和公钥文件 client.crt,str 类型,未配置该参数时默认取值当前路径。tls_path配置为空字符串时,采用TCP协议明文传输api数据;当配置为路径时,采用TLS1.2协议加密传输数据,加密传输时安全性较高,传输速率较低。 | 否 | -| online_run_ut_recompute | 模型训练是否使用重计算机制,bool类型,默认为False,表示模型没有使用重计算。在线预检暂不支持重计算机制下反向算子的预检,当模型训练使用重计算时,跳过反向算子预检,默认模型关闭重计算。 | 否 | - -#### 3.1.3 局域网场景配置示例 - -若采用 TLS1.2 协议加密传输 api 数据,需配置 SSL 证书,可参考如下生成自签名证书方法,仅供调试使用,生产环境请申请正式证书。 -```shell -# 创建私钥文件server.key -openssl genrsa -out server.key 2048 - -# 创建签名请求文件server.csr -openssl req -new -key server.key -out server.csr - -# 自签名, 生成1年期公钥文件server.crt -openssl x509 -req -days 365 -in server.csr -signkey server.key -out server.crt -``` - -注意:配置TLS协议时,传输性能受机器环境和网络质量的影响,可能触发NPU超时中断模型训练,为避免训练和预检中断,丢弃长时间未传输的api数据,同时NPU侧配置HCCL环境变量,配置方式如下: - -a) 调整HCCL环境变量,关闭看门狗,避免WorkHCCL超时中断模型训练: -```shell -export HCCL_DESYNC_DEBUG=0 -export HCCL_ASYNC_ERROR_HANDLING=0 -``` -b) 调整通信算子超时设置(以1800s举例): -```shell -export HCCL_CONNECT_TIMEOUT=1800 -export HCCL_EXEC_TIMEOUT=1800 -``` - -GPU 侧: - -```json -{ - "task": "run_ut", - "run_ut": { - "white_list": [], - "black_list": [], - "error_data_path": "./", - "is_online": true, - "nfs_path": "", - "host": "127.0.0.1", - "port": 59208, - "rank_list": [0], - "tls_path": "" - } -} -``` - -NPU 侧: - -```json -{ - "task": "tensor", - "dump_path": "./dump_path", - "rank": [0], - "step": [0], - "level": "L1", - "tensor": { - "scope": [], - "list": [], - "online_run_ut": true, - "nfs_path": "", - "host": "xx.xx.xx.x", - "port": 59208, - "tls_path": "" - } -} -``` - -#### 3.1.4 共享存储场景配置示例 - -GPU 侧: - -```json -{ - "task": "run_ut", - "run_ut": { - "white_list": [], - "black_list": [], - "error_data_path": "./", - "is_online": true, - "nfs_path": "/nfs/xxx/data", - "host": "", - "port": -1, - "rank_list": [0], - "tls_path": "" - } -} -``` - -NPU 侧: - -```json -{ - "task": "tensor", - "dump_path": "./dump_path", - "rank": [0], - "step": [0], - "level": "L1", - "tensor": { - "scope": [], - "list": [], - "online_run_ut": true, - "nfs_path": "/nfs/xxx/data", - "host": "", - "port": -1, - "tls_path": "" - } -} -``` - -### 3.2 在 GPU 侧运行 run_ut - -由于 GPU 侧为通信接收端,需先于 NPU 侧执行 run_ut 操作,命令如下: - -```bash -msprobe -f pytorch run_ut -config ./config.json -``` - -GPU 侧配置好 config.json 文件后执行 run_ut 命令,此时 GPU 处于预检等待状态: - -- 局域网场景:当 NPU 侧启动训练后将预检的 API 输入和输出数据发送到 GPU 侧时,GPU 启动预检操作。 -- 共享存储场景:当 NPU 侧启动训练后将预检的 API 输入和输出数据发送到共享存储时,GPU 启动预检操作。 - -### 3.3 在 NPU 侧配置训练脚本 - -在 NPU 训练脚本中添加如下代码以获取 run_ut 操作的预检 API 输入和输出数据: - -```python -from msprobe.pytorch import PrecisionDebugger - -debugger = PrecisionDebugger("config.json") -... - -debugger.start() - -... - -debugger.stop() -debugger.step() -``` - -### 3.4 在 NPU 侧执行训练脚本 - -配置完 NPU 侧训练脚本后即可执行训练脚本,命令示例如下: - -```bash -bash train.sh -``` - -训练脚本执行完毕后,在GPU侧dump_path目录下生成比对结果文件,`accuracy_checking_result_{timestamp}_rank{rank_id}.csv`和`accuracy_checking_details_{timestamp}_rank{rank_id}.csv`记录两方比对结果,`api_precision_compare_result_{timestamp}_rank{rank_id}.csv`和`api_precision_compare_details_{timestamp}_rank{rank_id}.csv`记录三方比对结果。详细介绍请参见[离线精度预检中的 **4 预检结果**](./07.accuracy_checker_PyTorch.md#4-预检结果)。 - -## 4 支持的融合算子列表 - -预检工具当前支持的融合算子如下: - -- npu_apply_adam_w - -- npu_confusion_transpose - -- fast_gelu - -- npu_layer_norm_eval - -- npu_linear - -- npu_fusion_attention(该算子在 GPU 上预检时,需要额外安装 flash_attn,请用户自行安装。) - -- npu_rms_norm - -- npu_rotary_mul - -- npu_scaled_masked_softmax - -- npu_swiglu diff --git a/debug/accuracy_tools/msprobe/docs/09.accuracy_checker_MindSpore.md b/debug/accuracy_tools/msprobe/docs/09.accuracy_checker_MindSpore.md index 8e5ab781ce0652ea572e0a0e5fb053655c5f48ec..fe7b228131ca647583294a32559541694718c89d 100644 --- a/debug/accuracy_tools/msprobe/docs/09.accuracy_checker_MindSpore.md +++ b/debug/accuracy_tools/msprobe/docs/09.accuracy_checker_MindSpore.md @@ -2,11 +2,11 @@ ## 1 简介 -**MindSpore 动态图精度预检**a通过扫描昇腾 NPU 上用户训练 MindSpore 模型中的所有 Mint API,输出精度情况的诊断和分析。工具以模型中所有 Mint API 前反向的 dump 结果为输入,构造相应的 API 单元测试,将 NPU 输出与标杆(CPU 高精度)比对,计算对应的精度指标,从而找出 NPU 中存在精度问题的 Mint API。本工具支持**随机生成模式和真实数据模式**b。 +**MindSpore 动态图精度预检**a通过扫描昇腾 NPU 上用户训练 MindSpore 模型中的所有 Mint API 以及 Msadapter场景下迁移的 Mindspore API,输出精度情况的诊断和分析。工具以模型中所有 API 前反向的 dump 结果为输入,构造相应的 API 单元测试,将 NPU 输出与标杆(CPU 高精度)比对,计算对应的精度指标,从而找出 NPU 中存在精度问题的 API。本工具支持**随机生成模式和真实数据模式**b。 a. 支持 Mindspore 版本:2.4/2.5; -b. (可选)当使用Msadapter时,由于需要环境中同时存在 Torch 与 Msadapter,所以只支持在**安装原生Torch**的场景下通过export PYTHONPATH="xx/msadapter/build/lib"等通过**环境变量使能Msadapter的方式**的环境中进行预检,预检工具能够自动索引得到所需的 Torch 与 Msadapter环境,环境安装详细参考:[msadapter官网](https://gitee.com/mindspore/msadapter)。 +b. (可选)当使用Msadapter时,由于需要环境中同时存在 Torch 与 Msadapter,所以只支持在**安装原生Torch**的场景下通过export PYTHONPATH="xx/msadapter/build/lib"等通过**环境变量使能Msadapter的方式**的环境中进行预检,预检工具能够自动索引得到所需的 Torch 与 Msadapter环境,环境安装详细参考:[msadapter官网](https://gitee.com/mindspore/msadapter)(该网站需要申请权限方可访问)。 c. 在预检时可以由工具构造随机数据或者获取真实dump数据进行预检操作。随机生成模式执行效率高,可以快速获得结果,但结果准确度低,只能大致判断精度问题;真实数据模式执行效率略低于随机生成模式,并且需要较大磁盘空间存放待预检数据,但是结果准确度高,可以准确判断精度问题。 @@ -29,14 +29,24 @@ c. 在预检时可以由工具构造随机数据或者获取真实dump数据进 msprobe -f mindspore run_ut -api_info ./dump.json -o ./checker_result ``` -| 参数名称 | 说明 |参数类型 | 是否必选 | -| ---------------------------- |---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------- | ---------------------------------- | -| -api_info 或 --api_info_file | 指定 API 信息文件 dump.json。对其中的mint api以及部分Tensor api进行预检,预检支持的Tensor api列表详见 [ 预检支持列表](../mindspore/api_accuracy_checker/checker_support_api.yaml)。 | str | 是 | -| -o 或 --out_path | 指定预检结果存盘路径,默认“./”。 | str | 否 | -| -csv_path 或 --result_csv_path | 指定本次运行中断时生成的 `accuracy_checking_result_{timestamp}.csv` 文件路径,执行 run_ut 中断时,若想从中断处继续执行,配置此参数即可。需要指定为上次中断的 `accuracy_checking_result_{timestamp}.csv` 文件。详见 [3.3 断点续检](#33-断点续检)。 | str | 否 | +| 参数名称 | 说明 | 参数类型 | 是否必选 | +| ---------------------------- |--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------| ---------------------------------- | +| -f 或 --framework | 指定训练框架。mindspore。 | str | 是 | +| -api_info 或 --api_info_file | 指定 API 信息文件 dump.json。对其中的mint api以及部分Tensor api进行预检,预检支持的Tensor api列表详见 [ 预检支持列表](../mindspore/api_accuracy_checker/checker_support_api.yaml)。 | str | 是 | +| -o 或 --out_path | 指定预检结果存盘路径,默认“./”。 | str | 否 | +| -csv_path 或 --result_csv_path | 指定本次运行中断时生成的 `accuracy_checking_result_{timestamp}.csv` 文件路径,执行 run_ut 中断时,若想从中断处继续执行,配置此参数即可。需要指定为上次中断的 `accuracy_checking_result_{timestamp}.csv` 文件。详见 [3.3 断点续检](#33-断点续检)。 | str | 否 | +| -save_error_data | 保存(随机数据模式)精度未达标的 API 输入输出数据。 | 空 | 否 | 预检执行结果包括 `accuracy_checking_result_{timestamp}.csv` 和 `accuracy_checking_details_{timestamp}.csv` 两个文件。`accuracy_checking_result_{timestamp}.csv` 属于 API 级,标明每个 API 是否通过测试。建议用户先查看 `accuracy_checking_result_{timestamp}.csv` 文件,对于其中没有通过测试的或者特定感兴趣的 API,根据其 API Name 字段在 `accuracy_checking_details_{timestamp}.csv` 中查询其各个输出的达标情况以及比较指标。详细介绍请参见 [4 预检结果](#4-预检结果)。 +随机数据模式下,如果需要保存比对不达标的输入和输出数据,可以在 run_ut 执行命令结尾添加 `-save_error_data`,例如: + +```bash +msprobe -f mindspore run_ut -api_info ./dump.json -o ./checker_result -save_error_data +``` + +数据默认会存盘到 '{out_path}/error_data' 路径下。 + ### 3.2 使用 multi_run_ut 执行多线程预检 multi_run_ut 脚本,可以并行在多个Device执行 run_ut 操作,从而减少预检耗时。示例如下: @@ -45,16 +55,20 @@ multi_run_ut 脚本,可以并行在多个Device执行 run_ut 操作,从而 msprobe -f mindspore multi_run_ut -api_info ./dump.json -d 0 1 2 3 ``` -| 参数名称 | 说明 |参数类型 | 是否必选 | -| ---------------------------- |---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------- | ---------------------------------- | -| -api_info 或 --api_info_file | 指定 API 信息文件 dump.json。对其中的mint api以及部分Tensor api进行预检,预检支持的Tensor api列表详见 [ 预检支持列表](../mindspore/api_accuracy_checker/checker_support_api.yaml)。 | str | 是 | -| -o 或 --out_path | 指定预检结果存盘路径,默认“./”。 | str | 否 | -| -csv_path 或 --result_csv_path | 指定本次运行中断时生成的 `accuracy_checking_result_{timestamp}.csv` 文件路径,执行 run_ut 中断时,若想从中断处继续执行,配置此参数即可。需要指定为上次中断的 `accuracy_checking_result_{timestamp}.csv` 文件。详见 [3.3 断点续检](#33-断点续检)。 | str | 否 | -| -d 或 --device | 指定 Device ID,选择 UT 代码运行所在的卡,默认值为 0,支持同时指定 0 ~ Device数量 - 1 ,例如 0 1 2 3 4。 | List[int] | 否 | +| 参数名称 | 说明 | 参数类型 | 是否必选 | +|-------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------| ---------------------------------- | +| -f 或 --framework | 指定训练框架。mindspore。 | str | 是 | +| -api_info 或 --api_info_file | 指定 API 信息文件 dump.json。对其中的mint api以及部分Tensor api进行预检,预检支持的Tensor api列表详见 [ 预检支持列表](../mindspore/api_accuracy_checker/checker_support_api.yaml)。 | str | 是 | +| -o 或 --out_path | 指定预检结果存盘路径,默认“./”。 | str | 否 | +| -csv_path 或 --result_csv_path | 指定本次运行中断时生成的 `accuracy_checking_result_{timestamp}.csv` 文件路径,执行 run_ut 中断时,若想从中断处继续执行,配置此参数即可。需要指定为上次中断的 `accuracy_checking_result_{timestamp}.csv` 文件。详见 [3.3 断点续检](#33-断点续检)。 | str | 否 | +| -d 或 --device | 指定 Device ID,选择 UT 代码运行所在的卡,默认值为 0,支持同时指定 0 ~ Device数量 - 1 ,例如 0 1 2 3 4。 | List[int] | 否 | +| -save_error_data | 保存(随机数据模式)精度未达标的 API 输入输出数据。 | 空 | 否 | 在不同卡数下,使用38B语言大模型的预检耗时基线参考 [multi_run_ut耗时基线](accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md) +数据默认会存盘到 './ut_error_data{timestamp}' 路径下 + ### 3.3 断点续检 断点续检操作通过如下命令执行: diff --git a/debug/accuracy_tools/msprobe/docs/10.accuracy_compare_PyTorch.md b/debug/accuracy_tools/msprobe/docs/10.accuracy_compare_PyTorch.md index b4525d738d849a17ca5049bd2214784c6f788d21..0f804d91bc2b19b07c52dc5b3f2ff8caace8ce17 100644 --- a/debug/accuracy_tools/msprobe/docs/10.accuracy_compare_PyTorch.md +++ b/debug/accuracy_tools/msprobe/docs/10.accuracy_compare_PyTorch.md @@ -51,14 +51,17 @@ msprobe -f pytorch compare -i ./compare.json -o ./output -s 完整参数说明: -| 参数名 | 说明 | 是否必选 | -|-------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -------- | -| -i 或 --input_path | 指定[比对文件](#214-比对文件),str 类型。 | 是 | -| -o 或 --output_path | 配置比对结果文件存盘目录,str 类型,默认在当前目录创建output目录。文件名称基于时间戳自动生成,格式为:`compare_result_{timestamp}.xlsx`。 | 否 | -| -s 或 --stack_mode | 比对结果展示调用栈信息(NPU_Stack_Info)的开关,bool 类型。单卡场景开启时,根据[比对文件](#214-比对文件)的参数说明配置stack_path;多卡场景开启时,自动识别npu_dump目录下stack.json文件,如存在生成详细调用栈信息,否则不生成,此参数不生效。通过直接配置该参数开启,默认未配置,表示关闭。 | 否 | +| 参数名 | 说明 | 是否必选 | +|---------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -------- | +| -f 或 --framework | 指定训练框架。pytorch。 | 是 | +| -i 或 --input_path | 指定[比对文件](#51-比对文件),str 类型。 | 是 | +| -o 或 --output_path | 配置比对结果文件存盘目录,str 类型,默认在当前目录创建output目录。文件名称基于时间戳自动生成,格式为:`compare_result_{timestamp}.xlsx`。
提示:output目录下与结果件同名文件将被删除覆盖。 | 否 | +| -s 或 --stack_mode | 比对结果展示调用栈信息(NPU_Stack_Info)的开关,bool 类型。单卡场景开启时,根据[比对文件](#51-比对文件)的参数说明配置stack_path;多卡场景开启时,自动识别npu_dump目录下stack.json文件,如存在生成详细调用栈信息,否则不生成,此参数不生效。通过直接配置该参数开启,默认未配置,表示关闭。 | 否 | | -c 或 --compare_only | 仅比对开关,bool 类型。该参数默认未配置,会启用自动精度分析,工具自动针对比对结果进行分析,识别到第一个精度可能不达标节点(在比对结果文件中的 Accuracy Reached or Not 列显示为 No),并给出问题可能产生的原因(打屏展示并生成 `advisor_{timestamp}.txt` 文件)。通过配置该参数取消自动精度分析,仅输出比对结果表格。 | 否 | | -f 或 --fuzzy_match | 模糊匹配,bool 类型。开启后,对于网络中同一层级且命名仅调用次数不同的 API,可匹配并进行比对。通过直接配置该参数开启,默认未配置,表示关闭。 | 否 | -| -dm或--data_mapping | 自定义映射关系比对。需要指定自定义映射文件*.yaml。自定义映射文件的格式请参见[自定义映射文件](#215-自定义映射文件)。仅[API和模块无法自动匹配场景](#213-api和模块无法自动匹配场景)需要配置。仅支持逐卡比对,即使用[比对文件](#214-比对文件)的单卡场景示例。 | 否 | +| -hl 或 --highlight | 高亮颜色标记。开启后,比对结果件中通过红色或黄色标记精度可疑API或模块。通过直接配置该参数开启,默认未配置,表示关闭。 开启高亮颜色标记后,比对性能降低,如果比对结果行数超出excel单页限制,程序强制关闭高亮颜色标记。 | 否 | +| -dm或--data_mapping | 自定义映射关系比对。需要指定自定义映射文件*.yaml。自定义映射文件的格式请参见[自定义映射文件](#52-自定义映射文件)。仅[API和模块无法自动匹配场景](#213-api和模块无法自动匹配场景)需要配置。仅支持逐卡比对,即使用[比对文件](#51-比对文件)的单卡场景示例。 | 否 | +| -da或--diff_analyze | 自动识别网络中首差异节点,支持md5、统计量等dump数据。支持单卡/多卡场景。 | 否 | #### 2.1.2 整网比对场景 @@ -66,19 +69,17 @@ msprobe -f pytorch compare -i ./compare.json -o ./output -s 支持单卡和多卡,可同时比对多卡的 dump 数据。多机场景需要每个设备单独执行比对操作。 -1. 配置[config.json](../config.json)文件。 +1. 参见 [PyTorch 场景下的数据采集](./05.data_dump_PyTorch.md)章节完成 CPU 或 GPU 与 NPU 的精度数据 dump。 -2. 参见 [PyTorch 场景下的数据采集](./05.data_dump_PyTorch.md)章节完成 CPU 或 GPU 与 NPU 的精度数据 dump。 +2. 创建[比对文件](#51-比对文件)。 -3. 创建[比对文件](#214-比对文件)。 - -4. 运行命令: +3. 运行命令: ```shell msprobe -f pytorch compare -i ./compare.json -o ./output -s ``` -5. 查看比对结果,请参见 [3 精度比对结果分析](#3-精度比对结果分析)。 +4. 查看比对结果,请参见 [3 精度比对结果分析](#3-精度比对结果分析)。 #### 2.1.3 API和模块无法自动匹配场景 @@ -88,7 +89,7 @@ msprobe -f pytorch compare -i ./compare.json -o ./output -s 2. 参见[PyTorch 场景下的数据采集](./05.data_dump_PyTorch.md)章节完成 CPU 或 GPU 与 NPU 的精度数据 dump。 -3. 创建[比对文件](#214-比对文件)(单卡场景示例)。 +3. 创建[比对文件](#51-比对文件)(单卡场景示例)。 4. 运行命令: @@ -96,75 +97,59 @@ msprobe -f pytorch compare -i ./compare.json -o ./output -s msprobe -f pytorch compare -i ./compare.json -o ./output -s -dm data_mapping.yaml ``` - data_mapping.yaml文件配置请参见[自定义映射文件](#215-自定义映射文件)。 + data_mapping.yaml文件配置请参见[自定义映射文件](#52-自定义映射文件)。 该场景不支持-f模糊匹配。 5. 查看比对结果,请参见 [3 精度比对结果分析](#3-精度比对结果分析)。 -#### 2.1.4 比对文件 - - 以在当前目录创建 ./compare.json 为例。 - - 单卡场景示例: +#### 2.1.4 单点数据比对场景 - ```json - { - "npu_path": "./npu_dump/dump.json", - "bench_path": "./bench_dump/dump.json", - "stack_path": "./npu_dump/stack.json", - "is_print_compare_log": true - } - ``` +单点数据比对场景是指:CPU 或 GPU 与 NPU环境的网络中单点保存的数据比对。 - - 多卡场景示例: +支持单卡和多卡,可同时比对多卡的单点数据。多机场景需要每个设备单独执行比对操作。 - ```json - { - "npu_path": "./npu_dump/step0", - "bench_path": "./bench_dump/step0", - "is_print_compare_log": true - } - ``` +1. 参见 [单点保存工具](./28.debugger_save_instruction.md)章节完成 CPU 或 GPU 与 NPU 的单点数据采集。 -**参数说明**: +2. 创建[比对文件(单点数据)](#53-比对文件单点数据)。 -| 参数名 | 说明 | 是否必选 | -| -------------------- |-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------| -| npu_path | 配置 NPU 环境下的 dump.json 文件(单卡场景)或真实数据目录(多卡场景),str 类型。 | 是 | -| bench_path | 配置 CPU、GPU 或 NPU 环境下的 dump.json 文件(单卡场景)或真实数据目录(多卡场景),str 类型。 | 是 | -| stack_path | 配置 NPU dump 目录下的 stack.json 文件,str 类型。如果没有配置stack_path,命令行-s参数不生效,程序自动识别是否存在stack.json文件,如存在,则比对结果中呈现NPU_Stack_Info,如不存在,则不呈现。如果配置了stack_path,比对结果中是否呈现NPU_Stack_Info则通过命令行参数-s来控制。 | 否 | -| is_print_compare_log | 配置是否开启单个算子的日志打屏。可取值 true 或 false,默认为 true。关闭后则只输出常规日志,bool 类型。 | 否 | +3. 运行命令: -#### 2.1.5 自定义映射文件 + ```shell + msprobe -f pytorch compare -i ./compare.json -o ./output + ``` -文件名格式:*.yaml,*为文件名,可自定义。 +4. 查看比对结果,请参见 [3 精度比对结果分析](#3-精度比对结果分析)。 -文件内容格式: +#### 2.1.5 首差异算子节点识别场景 -```yaml -# API -{api_type}.{api_name}.{API调用次数}.{前向反向}.{input/output}.{参数序号}: {api_type}.{api_name}.{API调用次数}.{前向反向}.{input/output}.{参数序号} -# 模块 -{Module}.{module_name}.{前向反向}.{index}.{input/output}.{参数序号}: {Module}.{module_name}.{前向反向}.{index}.{input/output}.{参数序号} -``` +首差异算子节点识别场景是指:XPU与NPU环境的网络中通过 `msprobe dump`保存数据的数据分析,找到网络精度问题中出现的首个差异算子节点。 -冒号左侧和右侧分别为PyTorch框架不同版本或不同芯片环境的API的名称和module模块名称。 +支持单卡和多卡,可同时比对多卡的dump数据。 -API和模块名称请从《[PyTorch 场景的精度数据采集](05.data_dump_PyTorch.md)》中的dump.json文件获取。 +执行步骤: -文件内容示例: +1. [config.json](../config.json)文件level配置为L0或L1、task配置为tensor或statistics(也可设置 `summary_mode`为 `md5`)并指定需要dump的API或模块名。 +2. 参见[PyTorch 场景下的数据采集](./05.data_dump_PyTorch.md)章节完成 CPU 或 GPU 与 NPU 的精度数据 dump。 +3. 创建[比对文件](#51-比对文件)。 +4. 运行命令: -```yaml -# API -NPU.npu_fusion_attention.4.forward.input.0: NPU.npu_fusion_attention.4.forward.input.0 -# 模块 -Module.module.language_model.embedding.word_embedding.VocabParallelEmbedding.forward.0.input.0: Module.module.language_model.embedding.word_embedding.VocabParallelEmbedding.forward.0.input.0 -``` + ```shell + msprobe -f pytorch compare -i ./compare.json -o ./output -da + ``` +5. 查看比对结果,在用户指定输出目录下会生成`compare_result_rank{rank_id}_{timestamp}.json`以及`diff_analyze_{timestamp}.json`。 + - 目录结构: + ``` + output/ + ├── compare_result_rank0_{timestamp}.json + ├── compare_result_rank1_{timestamp}.json + ├── diff_analyze_{timestamp}.json + ``` -API和模块名称在dump.json文件中的“data_name”字段展示,如下图红框处所示: + - `compare_result_rank{rank_id}_{timestamp}.json`:包含该rank比对结果,包括API或模块名、比对状态、比对指标等。 + - `diff_analyze_{timestamp}.json`:包含首差异算子节点识别结果,包括算子节点名、算子类型、算子位置等。 -![pt_dump](./img/pt_dump.png) ### 2.2 比对函数方式 @@ -180,13 +165,14 @@ compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_mat **参数说明**: -| 参数名 | 说明 | 是否必选 | -| ------------ |----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -------- | +| 参数名 | 说明 | 是否必选 | +|--------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -------- | | input_param | 配置 dump 数据文件及目录,dict 类型。配置参数包括:
"npu_json_path":指定 NPU dump 目录下的 dump.json 文件。
**配置示例**:"npu_json_path": "./npu_dump/dump.json"。
"bench_json_path":指定 CPU、GPU 或 NPU dump 目录下的 dump.json 文件。
**配置示例**:"bench_json_path": "./bench_dump/dump.json"。
"stack_json_path":指定 NPU dump 目录下的 stack.json 文件。
**配置示例**:"stack_json_path": "./npu_dump/stack.json"。
"is_print_compare_log":配置是否开启单个算子的日志打屏。
**配置示例**:True 或 False。 | 是 | -| output_path | 配置比对结果文件存盘目录,str 类型。
**配置示例**:'./output'。文件名称基于时间戳自动生成,格式为:`compare_result_{timestamp}.xlsx`。 | 是 | +| output_path | 配置比对结果文件存盘目录,str 类型。
**配置示例**:'./output'。文件名称基于时间戳自动生成,格式为:`compare_result_{timestamp}.xlsx`。
提示:output目录下与结果件同名文件将被删除覆盖。 | 是 | | stack_mode | 配置 stack_mode 的开关,bool 类型。仅当配置 stack_json_path 时需要,开启时比对结果呈现NPU_Stack_Info,关闭时不呈现。当不配置stack_json_path 时,自动识别是否存在stack.json,存在时呈现NPU_Stack_Info,否则不呈现。
**配置示例**:stack_mode=True,默认为 False。 | 否 | | auto_analyze | 自动精度分析,bool 类型。开启后工具自动针对比对结果进行分析,识别到第一个精度可能不达标节点(在比对结果文件中的 Accuracy Reached or Not 列显示为 No),并给出问题可能产生的原因(打屏展示并生成 advisor_{timestamp}.txt 文件)。
**配置示例**:auto_analyze=False,默认为 True。 | 否 | | fuzzy_match | 模糊匹配,bool 类型。开启后,对于网络中同一层级且命名仅调用次数不同的 API,可匹配并进行比对。
**配置示例**:fuzzy_match=True,默认为 False。 | 否 | +| highlight | 高亮颜色标记。开启后,比对结果件中通过红色或黄色标记精度可疑API或模块。 开启高亮颜色标记后,比对性能降低,如果比对结果行数超出excel单页限制,程序强制关闭高亮颜色标记。
**配置示例**:highlight=True,默认为 False。 | 否 | **函数示例**: @@ -215,12 +201,12 @@ compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs) **参数说明**: -| 参数名 | 说明 | 是否必选 | -| -------------- |-----------------------------------------------------------------------------------------------------------------------------------------------------------| -------- | -| npu_dump_dir | 配置 NPU 环境下的 dump 目录。str 类型。dump 数据目录须指定到 step 级。
**配置示例**:'./npu_dump/step0'。 | 是 | -| bench_dump_dir | 配置 CPU、GPU 或 NPU 环境下的 dump 目录。str 类型。
**配置示例**:'./gpu_dump/step0'。 | 是 | -| output_path | 配置比对结果文件存盘目录。需要预先创建 output_path 目录。str 类型。
**配置示例**:'./output'。文件名称基于时间戳自动生成,格式为:`compare_result_rank{npu_ID}-rank{cpu/gpu/npu_ID}_{timestamp}.xlsx`。 | 是 | -| **kwargs | 支持 compare 的所有可选参数。 其中,stack_mode不生效,自动识别是否存在stack.json,如存在,呈现NPU_Stack_Info,否则不呈现。 | 否 | +| 参数名 | 说明 | 是否必选 | +| -------------- |------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -------- | +| npu_dump_dir | 配置 NPU 环境下的 dump 目录。str 类型。dump 数据目录须指定到 step 级。
**配置示例**:'./npu_dump/step0'。 | 是 | +| bench_dump_dir | 配置 CPU、GPU 或 NPU 环境下的 dump 目录。str 类型。
**配置示例**:'./gpu_dump/step0'。 | 是 | +| output_path | 配置比对结果文件存盘目录。需要预先创建 output_path 目录。str 类型。
**配置示例**:'./output'。文件名称基于时间戳自动生成,格式为:`compare_result_rank{npu_ID}_{timestamp}.xlsx`。
提示:output目录下与结果件同名文件将被删除覆盖。 | 是 | +| **kwargs | 支持 compare 的所有可选参数。 其中,stack_mode不生效,自动识别是否存在stack.json,如存在,呈现NPU_Stack_Info,否则不呈现。 | 否 | **函数示例**: @@ -247,26 +233,28 @@ PyTorch 精度比对是以 CPU 或 GPU 的计算结果为标杆,通过计算 **公共表头**: -|dump 数据模式|NPU Name (NPU 的 API 名)|Bench Name (bench 的 API 名)|NPU Dtype (NPU 数据类型)|Bench Dtype (bench 数据类型)|NPU Tensor Shape (NPU 张量形状)|Bench Tensor Shape (bench 张量形状)| -|:-:|:-:|:-:|:-:|:-:|:-:|:-:| -|真实数据模式|√|√|√|√|√|√| -|统计数据模式|√|√|√|√|√|√| -|MD5 模式|√|√|√|√|√|√| +|dump 数据模式|NPU Name (NPU 的 API 名)|Bench Name (bench 的 API 名)|NPU Dtype (NPU 数据类型)|Bench Dtype (bench 数据类型)|NPU Tensor Shape (NPU 张量形状)|Bench Tensor Shape (bench 张量形状)| NPU Requires_grad (NPU tensor是否计算梯度) | Bench Requires_grad (Bench tensor是否计算梯度) | +|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:------------------------------------:|:----------------------------------------:| +|真实数据模式|√|√|√|√|√|√| √ | √ | +|统计数据模式|√|√|√|√|√|√| √ | √ | +|MD5 模式|√|√|√|√|√|√| √ | √ | **个性表头**: 统计量有 4 种:最大值(max)、最小值(min)、平均值(mean)和 L2-范数(L2 norm)。 -|dump 数据模式|Cosine (tensor 余弦相似度)|MaxAbsErr (tensor 最大绝对误差)|MaxRelativeErr (tensor 最大相对误差)|One Thousandth Err Ratio (tensor 相对误差小于千分之一的比例)|Five Thousandth Err Ratio (tensor 相对误差小于千分之五的比例)|NPU 和 bench 的统计量绝对误差 (max, min, mean, L2 norm) diff| NPU 和 bench 的统计量相对误差 (max, min, mean, L2 norm) RelativeErr |NPU 和 bench 的统计量 (max, min, mean, L2 norm)|NPU MD5 (NPU 数据 CRC-32 值)|BENCH MD5 (bench 数据 CRC-32 值)|Result (比对结果)|Accuracy Reached or Not (计算精度是否达标)|Err_message (错误信息提示)|NPU_Stack_Info (堆栈信息)|Data_Name (NPU 真实数据名)| -|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:| -|真实数据模式|√|√|√|√|√|||√||||√|√|√|√| -|统计数据模式||||||√|√|√|||√||√|√|| -|MD5 模式|||||||||√|√|√|||√|| +|dump 数据模式|Cosine (tensor 余弦相似度)|EucDist (tensor 欧式距离)|MaxAbsErr (tensor 最大绝对误差)|MaxRelativeErr (tensor 最大相对误差)|One Thousandth Err Ratio (tensor 相对误差小于千分之一的比例)|Five Thousandth Err Ratio (tensor 相对误差小于千分之五的比例)|NPU 和 bench 的统计量绝对误差 (max, min, mean, L2 norm) diff| NPU 和 bench 的统计量相对误差 (max, min, mean, L2 norm) RelativeErr |NPU 和 bench 的统计量 (max, min, mean, L2 norm)|NPU MD5 (NPU 数据 CRC-32 值)|BENCH MD5 (bench 数据 CRC-32 值)| Requires_grad Consistent (计算梯度是否一致) | Result (比对结果) |Accuracy Reached or Not (计算精度是否达标)|Err_message (错误信息提示)|NPU_Stack_Info (堆栈信息)| Data_Name ([NPU真实数据名,Bench真实数据名]) | +|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:-----------------------------------:|:-------------:|:---:|:---:|:---:|:---------------------------------:| +|真实数据模式|√|√|√|√|√|√|||√||| √ | |√|√|√| √ | +|统计数据模式|||||||√|√|√||| √ | √ ||√|√| | +|MD5 模式||||||||||√|√| √ | √ |||√| | 上表中NPU_Stack_Info字段需要配置-s参数生成。 ### 3.2 颜色标记——真实数据模式、统计数据模式 +通过在命令行中配置-hl或--highlight开启,或者在比对函数中配置参数highlight=True开启,用于标记精度可疑API或模块。开启后,比对性能会有降低,建议比对较大dump.json文件时不配置此参数。 +颜色标记分为红色标记和黄色标记,红色标记优先级高于黄色标记。 在比对结果中的Err_message列呈现比对结果颜色标记的原因,具体含义如下: 红色标记情况: @@ -275,11 +263,12 @@ PyTorch 精度比对是以 CPU 或 GPU 的计算结果为标杆,通过计算 3. 一个 API 或模块的 One Thousandth Err Ratio 的 input/parameters > 0.9 同时 output < 0.6(真实数据模式)(仅标记output); 4. 一个 API 或模块的 output 的最大值相对误差 (Max diff 除以 max(0.01, Bench max)) > 0.5(统计数据模式)(仅标记output)。 -黄色标记情况(仅标记output): +黄色标记情况(1-4仅标记output,5无限制): 1. 一个 API 或模块的 input/parameters 与 output 的最大值绝对误差都大于 1,同时 output 比 input/parameters 大一个数量级以上(真实数据模式、统计数据模式); 2. 一个 API 或模块的 One Thousandth Err Ratio 的 input/parameters - output > 0.1(真实数据模式); 3. 一个 API 或模块的 output 的最大值相对误差 > 0.1 同时 input/parameters < 0.01(真实数据模式,统计数据模式); -4. 一个 API 或模块的 Cosine 的 input/parameters - output > 0.1(真实数据模式)。 +4. 一个 API 或模块的 Cosine 的 input/parameters - output > 0.1(真实数据模式); +5. 一个 API 或模块的 Requires_grad Consistent 为 False。 ### 3.3 比对结果(Result)——统计数据模式、MD5 模式 @@ -314,15 +303,19 @@ MD5 模式: ### 3.5 错误信息提示(Err_message)——真实数据模式、统计数据模式 1. "Need double check api accuracy.":四个统计值中至少 1 个相对误差 > 0.5(统计数据模式); -2. "Fuzzy matching data, the comparison arruracy may be affected.":NPU 或 Bench 的真实数据名没有匹配上(真实数据模式); -3. "Dump file: {} not found.":NPU 真实数据不存在或者读取出错(真实数据模式); -4. "No bench data matched.":Bench 的 API 没有匹配上、Bench 真实数据不存在或读取出错(真实数据模式); -5. "This is empty data, can not compare.":读取到的数据为空(真实数据模式); -6. "Shape of NPU and bench Tensor do not match. Skipped.":NPU 和 Bench 的数据结构不一致(真实数据模式); -7. "The Position of inf or nan in NPU and bench Tensor do not match.":NPU 和 Bench 的数据有 nan/inf(真实数据模式); -8. "This is type of 0-d tensor, can not calculate 'Cosine', 'One Thousandth Err Ratio' and 'Five Thousandths Err Ratio'.":NPU 为0维张量(真实数据模式); -9. "Dtype of NPU and bench Tensor do not match.":NPU 和 Bench 数据的数据类型不同(真实数据模式); -10. "":除以上情况的其余情况(真实数据模式、统计数据模式)。 +2. "Fuzzy matching data, the comparison accuracy may be affected.":NPU 或 Bench 的真实数据名没有匹配上(真实数据模式); +3. "Dump file: {} not found or read failed.":NPU 或 Bench 的真实数据者读取出错(真实数据模式); +4. "No bench data matched.":Bench 的 API 没有匹配上(真实数据模式,统计数据模式); +5. "NPU does not have data file.": NPU的真实数据不存在(真实数据模式); +6. "Bench does not have data file.": Bench的真实数据不存在(真实数据模式); +7. "Bench api/module unmatched.":Bench 的 API 没有匹配上(真实数据模式); +8. "This is empty data, can not compare.":读取到的数据为空(真实数据模式); +9. "Shape of NPU and bench Tensor do not match. Skipped.":NPU 和 Bench 的数据结构不一致(真实数据模式); +10. "The Position of inf or nan in NPU and bench Tensor do not match.":NPU 和 Bench 的数据有 nan/inf(真实数据模式); +11. "This is type of 0-d tensor, can not calculate 'Cosine', 'EucDist', 'One Thousandth Err Ratio' and 'Five Thousandths Err Ratio'.":NPU 为0维张量(真实数据模式); +12. "Dtype of NPU and bench Tensor do not match.":NPU 和 Bench 数据的数据类型不同(真实数据模式); +13. "Requires_grad inconsistent.":NPU 和 Bench 的 Requires_grad 不一致(真实数据模式,统计数据模式); +14. "":除以上情况的其余情况(真实数据模式、统计数据模式)。 除以上错误信息提示外,异常数据颜色高亮标记的原因叠加呈现于此列。 @@ -330,13 +323,15 @@ MD5 模式: 1. Cosine:通过计算两个向量的余弦值来判断其相似度,数值越接近于 1 说明计算出的两个张量越相似,实际可接受阈值为大于 0.99。在计算中可能会存在 nan,主要由于可能会出现其中一个向量为 0。 -2. MaxAbsErr:当最大绝对误差越接近 0 表示其计算的误差越小,实际可接受阈值为小于 0.001。 +2. EucDist:通过计算两个向量的欧式距离来判断其相似度,定义为多维空间中两个点之间的绝对距离。数值越接近0,张量越相似,数值越大,差异越大。 + +3. MaxAbsErr:当最大绝对误差越接近 0 表示其计算的误差越小,实际可接受阈值为小于 0.001。 -3. MaxRelativeErr:当最大相对误差越接近 0 表示其计算的误差越小。 +4. MaxRelativeErr:当最大相对误差越接近 0 表示其计算的误差越小。 - 当 dump 数据中存在 0 或 Nan 时,比对结果中最大相对误差则出现 inf 或 Nan 的情况,属于正常现象。 + 当 dump 数据中存在 0 或 nan 时,比对结果中最大相对误差则出现 inf 或 nan 的情况,属于正常现象。 -4. One Thousandth Err Ratio(相对误差小于千分之一的元素比例)、Five Thousandths Err Ratio(相对误差小于千分之五的元素比例)精度指标:是指 NPU 的 Tensor 中的元素逐个与对应的标杆数据对比,相对误差小于千分之一、千分之五的比例占总元素个数的比例。该数据仅作为精度下降趋势的参考,并不参与计算精度是否通过的判定。 +5. One Thousandth Err Ratio(相对误差小于千分之一的元素比例)、Five Thousandths Err Ratio(相对误差小于千分之五的元素比例)精度指标:是指 NPU 的 Tensor 中的元素逐个与对应的标杆数据对比,相对误差小于千分之一、千分之五的比例占总元素个数的比例。该数据仅作为精度下降趋势的参考,并不参与计算精度是否通过的判定。 ## 4 多卡比对结果提取汇总通信算子数据 @@ -358,11 +353,12 @@ msprobe -f pytorch merge_result -i ./input_dir -o ./output_dir -config ./config. **完整参数说明** -| 参数名 | 说明 | 是否必选 | -| ---------------------- |------------------------------------------------------------------------------------| -------- | -| -i 或 --input_dir | 多卡比对结果存盘目录,即使用compare比对的结果输出目录,str类型。所有比对结果应全部为真实数据比对结果或统计数据比对结果,否则可能导致汇总数据不完整。 | 是 | -| -o 或 --output_dir | 数据提取汇总结果存盘目录,str类型。文件名称基于时间戳自动生成,格式为:`multi_ranks_compare_merge_{timestamp}.xlsx`。 | 是 | -| -config或--config-path | 指定需要汇总数据的API和比对指标的yaml文件路径,str类型。
yaml文件详细介绍见下文“**yaml文件说明**”。 | 是 | +| 参数名 | 说明 | 是否必选 | +| --------------------- |-------------------------------------------------------------------------------------------------------------------| -------- | +| -f 或 --framework | 指定训练框架。pytorch。 | 是 | +| -i 或 --input_dir | 多卡比对结果存盘目录,即使用compare比对的结果输出目录,str类型。所有比对结果应全部为真实数据比对结果或统计数据比对结果,否则可能导致汇总数据不完整。 | 是 | +| -o 或 --output_dir | 数据提取汇总结果存盘目录,str类型。文件名称基于时间戳自动生成,格式为:`multi_ranks_compare_merge_{timestamp}.xlsx`。
提示:output目录下与结果件同名文件将被删除覆盖。 | 是 | +| -config或--config-path | 指定需要汇总数据的API和比对指标的yaml文件路径,str类型。
yaml文件详细介绍见下文“**yaml文件说明**”。 | 是 | **yaml文件说明** @@ -378,10 +374,10 @@ compare_index: - MeanRelativeErr ``` -| 参数名 | 说明 | -| ------------- | ------------------------------------------------------------ | -| api | 表示需要汇总的API或module名称。如果没有配置,工具会提示报错。
api名称配置格式为:`{api_type}.{api_name}.{API调用次数}.{前向反向}`
须按顺序配置以上四个字段,可按如下组合配置:
{api_type}
{api_type}.{api_name}
{api_type}.{api_name}.{API调用次数}
{api_type}.{api_name}.{API调用次数}.{前向反向}
这里的api指代API或module。 | -| compare_index | 表示需要汇总的比对指标。compare_index需为dump_mode对应比对指标的子集。如果没有配置,工具将根据比对结果自动提取dump_mode对应的全部比对指标进行汇总。
统计数据模式比对指标:Max diff、Min diff、Mean diff、Norm diff、MaxRelativeErr、MinRelativeErr、MeanRelativeErr、NormRelativeErr
真实数据模式比对指标:Cosine、MaxAbsErr、MaxRelativeErr、One Thousandth Err Ratio、Five Thousandths Err Ratio | +| 参数名 | 说明 | +| ------------- |---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| api | 表示需要汇总的API或module名称。如果没有配置,工具会提示报错。
api名称配置格式为:`{api_type}.{api_name}.{API调用次数}.{前向反向}`
须按顺序配置以上四个字段,可按如下组合配置:
{api_type}
{api_type}.{api_name}
{api_type}.{api_name}.{API调用次数}
{api_type}.{api_name}.{API调用次数}.{前向反向}
这里的api指代API或module。 | +| compare_index | 表示需要汇总的比对指标。compare_index需为dump_mode对应比对指标的子集。如果没有配置,工具将根据比对结果自动提取dump_mode对应的全部比对指标进行汇总。
统计数据模式比对指标:Max diff、Min diff、Mean diff、L2norm diff、MaxRelativeErr、MinRelativeErr、MeanRelativeErr、NormRelativeErr、Requires_grad Consistent
真实数据模式比对指标:Cosine、EucDist、MaxAbsErr、MaxRelativeErr、One Thousandth Err Ratio、Five Thousandths Err Ratio、Requires_grad Consistent | **汇总结果件说明** @@ -412,4 +408,189 @@ compare_index: 6. Distributed.broadcast:输入为要广播的数据,输出为广播后的数据。 7. Distributed.isend:点对点通信,输入为要发送的数据,输出为发送的数据。 8. Distributed.irecv:点对点通信,输入为原数据,输出为接收的新数据。 -9. Distributed.all_to_all_single:输出数据为所有卡上的数据切分后合并的结果。 \ No newline at end of file +9. Distributed.all_to_all_single:输出数据为所有卡上的数据切分后合并的结果。 + +## 5 附录 + +### 5.1 比对文件 + + 以在当前目录创建 ./compare.json 为例。 + + - 单卡场景示例: + + ```json + { + "npu_path": "./npu_dump/dump.json", + "bench_path": "./bench_dump/dump.json", + "stack_path": "./npu_dump/stack.json", + "is_print_compare_log": true + } + ``` + + - 多卡场景示例: + + ```json + { + "npu_path": "./npu_dump/step0", # 需填写到step层级(rank的上一层级) + "bench_path": "./bench_dump/step0", # 需填写到step层级(rank的上一层级) + "is_print_compare_log": true + } + ``` + +**参数说明** + +| 参数名 | 说明 | 是否必选 | +| -------------------- |-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------| +| npu_path | 配置NPU环境下的dump.json文件(单卡场景)或dump目录(多卡场景)。数据类型:str。 | 是 | +| bench_path | 配置CPU、GPU或NPU环境下的dump.json文件(单卡场景)或dump目录(多卡场景)。数据类型:str。 | 是 | +| stack_path | 配置NPU dump目录下的stack.json文件。数据类型:str。如果没有配置stack_path,命令行-s参数不生效,程序自动识别是否存在stack.json文件,如存在,则比对结果中呈现NPU_Stack_Info,如不存在,则不呈现。如果配置了stack_path,比对结果中是否呈现NPU_Stack_Info则通过命令行参数-s来控制。 | 否 | +| is_print_compare_log | 配置是否开启单个算子的日志打屏。可取值true或false,默认为true。关闭后则只输出常规日志。数据类型:bool。 | 否 | + +### 5.2 自定义映射文件 + +文件名格式:*.yaml,*为文件名,可自定义。 + +文件内容格式: + +```yaml +# API +{api_type}.{api_name}.{API调用次数}.{前向反向}.{input/output}.{参数序号}: {api_type}.{api_name}.{API调用次数}.{前向反向}.{input/output}.{参数序号} +# 模块 +{Module}.{module_name}.{前向反向}.{index}.{input/output}.{参数序号}: {Module}.{module_name}.{前向反向}.{index}.{input/output}.{参数序号} +``` + +冒号左侧和右侧分别为PyTorch框架不同版本或不同芯片环境的API的名称和module模块名称。 + +API和模块名称请从《[PyTorch 场景的精度数据采集](05.data_dump_PyTorch.md)》中的dump.json文件获取。 + +文件内容示例: + +```yaml +# API +NPU.npu_fusion_attention.4.forward.input.0: NPU.npu_fusion_attention.4.forward.input.0 +# 模块 +Module.module.language_model.embedding.word_embedding.VocabParallelEmbedding.forward.0.input.0: Module.module.language_model.embedding.word_embedding.VocabParallelEmbedding.forward.0.input.0 +``` + +当dump.json文件中存在“data_name”字段时,API和模块名称为data_name字段去掉文件后缀,如下图红框处所示: + +![pt_dump](./img/pt_dump.png) + +当dump.json文件中不存在“data_name”字段时,名称的拼写规则如下: + +input_args、input_kwargs和output使用统一的命名规则,当值是list类型时,名称后面添加'.{index}',当值类型是dict类型时,名称后面加'.{key}',当值类型是具体Tensor或null或int或float或bool或空list/dict等时,命名结束。 + +以下面api的dump文件为例: +```yaml + "Functional.max_pool2d.0.forward": { + "input_args": [ + { + "type": "torch.Tensor", + "dytpe": "torch_float32", + "shape": [ + 1, + 64, + 14, + 14 + ], + "Max": xxx, + "Min": xxx, + "Mean": xxx, + "Norm": xxx, + "requires_grad": true + }, + { + "type": "int", + "value": 3 + }, + { + "type": "int", + "value": 2 + }, + { + "type": "int", + "value": 1 + }, + { + "type": "int", + "value": 1 + } + ], + "input_kwargs": { + "ceil_mode": { + "type": "bool", + "value": false + }, + "return_indices": { + "type": "bool", + "value": false + }, + }, + "output": [ + { + "type": "torch.Tensor", + "dtype": "torch.float32", + "shape": [ + 1, + 64, + 7, + 7 + ], + "Max": xxx, + "Min": xxx, + "Mean": xxx, + "Norm": xxx, + "requires_grad": true + } + ] + } +``` + +初始名称为Functional.max_pool2d.0.forward,input_args是list,长度为5,第0项后面是Tensor,命名结束;第1-4项后面均是int,命名结束;按照顺序命名为 +``` +Functional.max_pool2d.0.forward.input.0 +Functional.max_pool2d.0.forward.input.1 +Functional.max_pool2d.0.forward.input.2 +Functional.max_pool2d.0.forward.input.3 +Functional.max_pool2d.0.forward.input.4 +``` +input_kwargs是dict,key是ceil_mode、return_indices,值均是bool,命名结束;命名为 +``` +Functional.max_pool2d.0.forward.input.ceil_mode +Functional.max_pool2d.0.forward.input.return_indices +``` +output是list,长度为1,第0项后面是Tensor,命名结束;按照顺序命名为 +``` +Functional.max_pool2d.0.forward.output.0 +``` +综上,生成的的op_name为 +``` +Functional.max_pool2d.0.forward.input.0 +Functional.max_pool2d.0.forward.input.1 +Functional.max_pool2d.0.forward.input.2 +Functional.max_pool2d.0.forward.input.3 +Functional.max_pool2d.0.forward.input.4 +Functional.max_pool2d.0.forward.input.ceil_mode +Functional.max_pool2d.0.forward.input.return_indices +Functional.max_pool2d.0.forward.output.0 +``` + +### 5.3 比对文件(单点数据) + + - 单卡场景示例: + + ```json + { + "npu_path": "./npu_dump/debug.json", + "bench_path": "./bench_dump/debug.json" + } + ``` + + - 多卡场景示例(step0目录下包含debug.json文件): + + ```json + { + "npu_path": "./npu_dump/step0", + "bench_path": "./bench_dump/step0" + } + ``` \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/docs/11.accuracy_compare_MindSpore.md b/debug/accuracy_tools/msprobe/docs/11.accuracy_compare_MindSpore.md index 1b1824a774f15a86106585669d5f3412b3faca2e..a3b6b64f7773fa178f930e617de117842954677e 100644 --- a/debug/accuracy_tools/msprobe/docs/11.accuracy_compare_MindSpore.md +++ b/debug/accuracy_tools/msprobe/docs/11.accuracy_compare_MindSpore.md @@ -19,7 +19,7 @@ msprobe精度比对工具主要用于如下场景: - 通过对同一个网络模型,在整网环境下分别在MindSpore动态图和PyTorch环境下获得API或模块dump数据,由用户指定可以比对的API或模块,以PyTorch数据作为标杆,进行自动比对,从而实现跨框架的精度对比。 - 通过对同一个网络模型,在整网环境下分别在MindSpore动态图和PyTorch环境下获得API或模块dump数据,由用户指定可以比对的模型代码中的Layer层,以PyTorch数据作为标杆,进行自动比对,从而实现跨框架的精度对比。 -执行精度比对操作需要安装msprobe工具。详见《[MindStudio精度调试工具](../README.md)》的“工具安装”章节。 +执行精度比对操作需要安装msprobe工具。详见[《msprobe 工具安装指南》](./01.installation.md)。 ## 2 命令行比对 @@ -35,17 +35,20 @@ msprobe -f mindspore compare -i ./compare.json -o ./output -s **完整参数说明** -| 参数名 | 说明 | 是否必选 | -| -------------------- | ------------------------------------------------------------ | -------- | -| -i或--input_path | 指定比对文件。比对文件内容及示例请参见[比对文件](#31-比对文件)或[比对文件(kernel)](#32-比对文件kernel)(比对文件(kernel)仅[不同版本下的全量kernel比对](#23-不同版本下的全量kernel比对)场景支持)。 | 是 | -| -o或--output_path | 配置比对结果文件存盘目录,默认会在当前目录创建output目录。文件名称基于时间戳自动生成,格式为:
`compare_result_{timestamp}.xlsx`
`compare_result_{rank_id}_{step_id}_{timestamp}.xlsx`(仅[不同版本下的全量kernel比对](#23-不同版本下的全量kernel比对)场景支持)。 | 否 | -| -s或--stack_mode | 比对结果展示调用栈信息(NPU_Stack_Info)的开关,bool 类型。单卡场景开启时,需要使用[比对文件](#31-比对文件)的单卡场景配置stack_path指定stack.json文件,才能生成详细调用栈信息,否则在比对时会报错;暂不支持多卡场景。通过直接配置该参数开启,默认未配置,表示关闭。 | 否 | -| -c或--compare_only | 仅比对开关,bool 类型。该参数默认未配置,会启用自动精度分析,工具自动针对比对结果进行分析,识别到第一个精度可能不达标节点(在比对结果文件中的 Accuracy Reached or Not 列显示为 No),并给出问题可能产生的原因(打屏展示并生成 `advisor_{timestamp}.txt` 文件)。通过配置该参数取消自动精度分析,仅输出比对结果表格。 | 否 | -| -f或--fuzzy_match | 模糊匹配。开启后,对于网络中同一层级且命名仅调用次数不同的API,可匹配并进行比对。通过直接配置该参数开启,默认未配置,表示关闭。 | 否 | -| -am或--api_mapping | 跨框架比对。配置该参数时表示开启跨框架API比对功能,可以指定自定义映射文件*.yaml,不指定映射文件时按照msprobe定义的默认映射关系进行比对。自定义映射文件的格式请参见[自定义映射文件(api_mapping)](#33-自定义映射文件api_mapping)。仅[跨框架的API比对](#25-跨框架的api比对)场景需要配置。 | 否 | -| -cm或--cell_mapping | 跨框架比对。配置该参数时表示开启跨框架cell模块比对功能,可以指定自定义映射文件*.yaml,不指定映射文件时按照msprobe定义的默认映射关系进行比对。自定义映射文件的格式请参见[自定义映射文件(cell_mapping)](#34-自定义映射文件cell_mapping)。仅[跨框架的cell模块比对](#26-跨框架的cell模块比对)场景需要配置。 | 否 | -| -dm或--data_mapping | 同框架或跨框架比对。通过映射文件指定两个具体参数的对应关系,可以在L0、L1或mix采集场景下使用。配置该参数的同时需要指定自定义映射文件*.yaml。自定义映射文件的格式请参见[自定义映射文件(data_mapping)](#35-自定义映射文件data_mapping)。 | 否 | -| -lm或--layer_mapping | 跨框架比对。配置该参数时表示开启跨框架Layer层的比对功能,指定模型代码中的Layer层后,可以识别对应dump数据中的模块或API。需要指定自定义映射文件*.yaml。自定义映射文件的格式请参见[自定义映射文件(Layer_mapping)](#36-自定义映射文件layer_mapping)。仅[跨框架的Layer层比对](#27-跨框架的layer层比对)场景需要配置。 | 否 | +| 参数名 | 说明 | 是否必选 | +|---------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -------- | +| -f 或 --framework | 指定训练框架。mindspore。 | 是 | +| -i或--input_path | 指定比对文件。比对文件内容及示例请参见[比对文件](#41-比对文件)或[比对文件(kernel)](#42-比对文件kernel)(比对文件(kernel)仅[不同版本下的全量kernel比对](#23-不同版本下的全量kernel比对)场景支持)。 | 是 | +| -o或--output_path | 配置比对结果文件存盘目录,默认会在当前目录创建output目录。文件名称基于时间戳自动生成,格式为:
`compare_result_{timestamp}.xlsx`
`compare_result_{rank_id}_{step_id}_{timestamp}.xlsx`(仅[不同版本下的全量kernel比对](#23-不同版本下的全量kernel比对)场景支持)。
提示:output目录下与结果件同名文件将被删除覆盖。 | 否 | +| -s或--stack_mode | 比对结果展示调用栈信息(NPU_Stack_Info)的开关,bool 类型。单卡场景开启时,需要使用[比对文件](#41-比对文件)的单卡场景配置stack_path指定stack.json文件,才能生成详细调用栈信息,否则在比对时会报错;暂不支持多卡场景。通过直接配置该参数开启,默认未配置,表示关闭。 | 否 | +| -c或--compare_only | 仅比对开关,bool 类型。该参数默认未配置,会启用自动精度分析,工具自动针对比对结果进行分析,识别到第一个精度可能不达标节点(在比对结果文件中的 Accuracy Reached or Not 列显示为 No),并给出问题可能产生的原因(打屏展示并生成 `advisor_{timestamp}.txt` 文件)。通过配置该参数取消自动精度分析,仅输出比对结果表格。 | 否 | +| -f或--fuzzy_match | 模糊匹配。开启后,对于网络中同一层级且命名仅调用次数不同的API,可匹配并进行比对。通过直接配置该参数开启,默认未配置,表示关闭。 | 否 | +| -hl或--highlight | 高亮颜色标记。开启后,比对结果件中通过红色或黄色标记精度可疑API或模块。通过直接配置该参数开启,默认未配置,表示关闭。 开启高亮颜色标记后,比对性能降低,如果比对结果行数超出excel单页限制,程序强制关闭高亮颜色标记。 | 否 | +| -am或--api_mapping | 跨框架比对。配置该参数时表示开启跨框架API比对功能,可以指定自定义映射文件*.yaml,不指定映射文件时按照msprobe定义的默认映射关系进行比对。自定义映射文件的格式请参见[自定义映射文件(api_mapping)](#43-自定义映射文件api_mapping)。仅[跨框架的API比对](#25-跨框架的api比对)场景需要配置。 | 否 | +| -cm或--cell_mapping | 跨框架比对。配置该参数时表示开启跨框架cell模块比对功能,可以指定自定义映射文件*.yaml,不指定映射文件时按照msprobe定义的默认映射关系进行比对。自定义映射文件的格式请参见[自定义映射文件(cell_mapping)](#44-自定义映射文件cell_mapping)。仅[跨框架的cell模块比对](#26-跨框架的cell模块比对)场景需要配置。 | 否 | +| -dm或--data_mapping | 同框架或跨框架比对。通过映射文件指定两个具体参数的对应关系,可以在L0、L1或mix采集场景下使用。配置该参数的同时需要指定自定义映射文件*.yaml。自定义映射文件的格式请参见[自定义映射文件(data_mapping)](#45-自定义映射文件data_mapping)。 | 否 | +| -lm或--layer_mapping | 跨框架比对。配置该参数时表示开启跨框架Layer层的比对功能,指定模型代码中的Layer层后,可以识别对应dump数据中的模块或API。需要指定自定义映射文件*.yaml。自定义映射文件的格式请参见[自定义映射文件(Layer_mapping)](#46-自定义映射文件layer_mapping)。仅[跨框架的Layer层比对](#27-跨框架的layer层比对)场景需要配置。 | 否 | +| -da或--diff_analyze | 自动识别网络中首差异节点,支持md5、统计量等dump数据。支持单卡/多卡场景。 | 否 | 动态图模式没有填写任何mapping时,按照同框架比对的方式进行比对,比对数据和标杆数据的Cell或Api名称需要完全相同才能匹配得上。 @@ -53,7 +56,7 @@ msprobe -f mindspore compare -i ./compare.json -o ./output -s 1. 参见《[MindSpore 场景的精度数据采集](./06.data_dump_MindSpore.md)》完成不同环境下MindSpore静态图精度数据的采集,得到不同框架版本的API dump数据。 -2. 创建比对文件,文件内容及示例请参见[比对文件](#31-比对文件)。 +2. 创建比对文件,文件内容及示例请参见[比对文件](#41-比对文件)。 3. 执行如下示例命令进行比对: @@ -67,7 +70,7 @@ msprobe -f mindspore compare -i ./compare.json -o ./output -s 1. 参见《[MindSpore 场景的精度数据采集](./06.data_dump_MindSpore.md)》完成不同环境下MindSpore静态图精度数据的采集,得到不同框架版本的kernel dump数据。 -2. 创建比对文件,文件内容及示例请参见[比对文件(kernel)](#32-比对文件kernel)。 +2. 创建比对文件,文件内容及示例请参见[比对文件(kernel)](#42-比对文件kernel)。 3. 执行如下示例命令进行比对: @@ -85,7 +88,7 @@ msprobe -f mindspore compare -i ./compare.json -o ./output -s 2. 参见《[MindSpore 场景的精度数据采集](./06.data_dump_MindSpore.md)》完成不同环境下MindSpore动态图精度数据的采集,得到不同框架版本的cell模块dump数据。 -3. 创建比对文件,文件内容及示例请参见[比对文件](#31-比对文件)。 +3. 创建比对文件,文件内容及示例请参见[比对文件](#41-比对文件)。 4. 执行如下示例命令进行比对: @@ -101,7 +104,7 @@ msprobe -f mindspore compare -i ./compare.json -o ./output -s 2. 参见《[MindSpore 场景的精度数据采集](./06.data_dump_MindSpore.md)》和《[PyTorch 场景的精度数据采集](./05.data_dump_PyTorch.md)》完成不同环境下API精度数据的采集,得到两个框架的API dump数据。 -3. 创建比对文件,文件内容及示例请参见[比对文件](#31-比对文件)。 +3. 创建比对文件,文件内容及示例请参见[比对文件](#41-比对文件)。 4. 执行如下示例命令进行比对: @@ -115,14 +118,14 @@ msprobe -f mindspore compare -i ./compare.json -o ./output -s msprobe -f mindspore compare -i ./compare.json -o ./output -s -am api_mapping.yaml ``` - api_mapping.yaml文件配置请参见[自定义映射文件(api_mapping)](#33-自定义映射文件api_mapping)。 + api_mapping.yaml文件配置请参见[自定义映射文件(api_mapping)](#43-自定义映射文件api_mapping)。 不传入api_mapping.yaml的情况下将按照内置的api映射进行匹配;传入api_mapping.yaml的情况下优先按照api_mapping.yaml的内容进行匹配,api_mapping.yaml中没有涉及的按照内置的api映射进行匹配。 此外,也可以通过data_mapping.yaml文件实现具体参数的匹配,例: ```shell msprobe -f mindspore compare -i ./compare.json -o ./output -s -dm data_mapping.yaml ``` - data_mapping.yaml的写法请参见[自定义映射文件(data_mapping)](#35-自定义映射文件data_mapping)。 + data_mapping.yaml的写法请参见[自定义映射文件(data_mapping)](#45-自定义映射文件data_mapping)。 5. 查看比对结果,请详见PyTorch目录下的《[PyTorch 场景的精度比对-精度比对结果分析](./10.accuracy_compare_PyTorch.md#3-精度比对结果分析)》章节。 @@ -132,7 +135,7 @@ msprobe -f mindspore compare -i ./compare.json -o ./output -s 2. 参见《[MindSpore 场景的精度数据采集](./06.data_dump_MindSpore.md)》和《[PyTorch 场景的精度数据采集](./05.data_dump_PyTorch.md)》完成不同环境下cell模块精度数据的采集,得到两个框架的cell模块dump数据。 -3. 创建比对文件,文件内容及示例请参见[比对文件](#31-比对文件)。 +3. 创建比对文件,文件内容及示例请参见[比对文件](#41-比对文件)。 4. 执行如下示例命令进行比对: @@ -146,14 +149,14 @@ msprobe -f mindspore compare -i ./compare.json -o ./output -s msprobe -f mindspore compare -i ./compare.json -o ./output -s -cm cell_mapping.yaml ``` - cell_mapping.yaml文件配置请参见[自定义映射文件(cell_mapping)](#34-自定义映射文件cell_mapping)。 + cell_mapping.yaml文件配置请参见[自定义映射文件(cell_mapping)](#44-自定义映射文件cell_mapping)。 不传入cell_mapping.yaml的情况下仅将Cell改成Module后进行匹配;传入cell_mapping.yaml的情况下将按照cell_mapping.yaml的内容进行匹配。 此外,也可以通过data_mapping.yaml文件实现具体参数的匹配,例: ```shell msprobe -f mindspore compare -i ./compare.json -o ./output -s -dm data_mapping.yaml ``` - data_mapping.yaml的写法请参见[自定义映射文件(data_mapping)](#35-自定义映射文件data_mapping)。 + data_mapping.yaml的写法请参见[自定义映射文件(data_mapping)](#45-自定义映射文件data_mapping)。 5. 查看比对结果,请详见PyTorch目录下的《[PyTorch 场景的精度比对-精度比对结果分析](./10.accuracy_compare_PyTorch.md#3-精度比对结果分析)》章节。 @@ -165,7 +168,7 @@ layer_mapping可以从Layer层识别整网的API和Cell,简化配置。 2. 参见《[MindSpore 场景的精度数据采集](./06.data_dump_MindSpore.md)》和《[PyTorch 场景的精度数据采集](./05.data_dump_PyTorch.md)》完成不同环境下API或模块精度数据的采集,得到两个框架的API或模块dump数据。 -3. 创建比对文件,文件内容及示例请参见[比对文件](#31-比对文件)。 +3. 创建比对文件,文件内容及示例请参见[比对文件](#41-比对文件)。 4. 执行如下示例命令进行比对: @@ -173,16 +176,89 @@ layer_mapping可以从Layer层识别整网的API和Cell,简化配置。 msprobe -f mindspore compare -i ./compare.json -o ./output -s -lm layer_mapping.yaml ``` - layer_mapping.yaml文件配置请参见[自定义映射文件(layer_mapping)](#36-自定义映射文件layer_mapping)。 + layer_mapping.yaml文件配置请参见[自定义映射文件(layer_mapping)](#46-自定义映射文件layer_mapping)。 此外,也可以通过data_mapping.yaml文件实现具体参数的匹配,例: ```shell msprobe -f mindspore compare -i ./compare.json -o ./output -s -dm data_mapping.yaml ``` - data_mapping.yaml的写法请参见[自定义映射文件(data_mapping)](#35-自定义映射文件data_mapping)。 + data_mapping.yaml的写法请参见[自定义映射文件(data_mapping)](#45-自定义映射文件data_mapping)。 5. 查看比对结果,请详见PyTorch目录下的《[PyTorch 场景的精度比对-精度比对结果分析](./10.accuracy_compare_PyTorch.md#3-精度比对结果分析)》章节。 +### 2.8 单点数据比对 +1. 参见 [单点保存工具](./28.debugger_save_instruction.md)章节完成 CPU 或 GPU 与 NPU 的单点数据采集。 + +2. 创建比对文件,文件内容及示例请参见[比对文件(单点数据)](#47-比对文件单点数据)。 + +3. 执行如下示例命令进行比对: + + ```shell + msprobe -f mindspore compare -i ./compare.json -o ./output + ``` + +4. Pytorch & MindSpore 动态图场景查看比对结果,请详见PyTorch目录下的《[PyTorch 场景的精度比对-精度比对结果分析](./10.accuracy_compare_PyTorch.md#3-精度比对结果分析)》章节。 +MindSpore静态图场景比对结果: +- `result.csv` 文件列出了所有执行精度比对的 单点保存数据 详细信息和比对结果,示例如下: + + ![compare_result](./img/save_compare_result_sample.png) +具体字段含义同PyTorch目录下的《[PyTorch 场景的精度比对-精度比对结果分析](./10.accuracy_compare_PyTorch.md#3-精度比对结果分析)》章节。 + +### 2.9 动静态图场景L0混合dump数据比对 + +1. 参见 [msprobe工具MindSpore场景精度数据采集指南](./06.data_dump_MindSpore.md),执行dump操作。
动态图场景下使用 `mindspore.jit` 装饰特定 Cell 或 function 时,被装饰的部分会被编译成静态图执行。采集的数据文件目录结构示例如下: + ```lua + ├── graph + │ ├── step0 + │ | ├── rank + │ | │ ├── dump_tensor_data + | | | | ├── Cell.wrap_net.net.Net.forward.0.input.0.npy + | | | | ├── Cell.wrap_net.net.Net.forward.0.output.0.npy + | | | | ... + │ | | ├── dump.json + │ | | ├── stack.json + │ | | └── construct.json + │ ├── ... + ├── pynative + │ ├── step0 + │ | ├── rank + │ | │ ├── dump_tensor_data + | | | | ├── Cell.dense1.Dense.forward.0.input.0.npy + | | | | ├── Cell.dense1.Dense.forward.0.output.0.npy + | | | | ... + │ | | ├── dump.json + │ | | ├── stack.json + │ | | └── construct.json + │ ├── ... + ``` + +2. 创建比对文件,文件内容及示例请参见[比对文件(动静态图场景L0混合数据)](#48-比对文件动静态图场景l0混合数据)。 + +3. 执行如下示例命令进行比对: + + ```shell + msprobe -f mindspore compare -i ./compare.json -o ./output + ``` + +4. 动静态图场景L0混合dump数据比对结果,示例如下: + ```lua + ├── graph + │ ├── step0 + │ | ├── advisor_rank_20250805043414.txt + │ | ├── compare_result_rank_20250805043411.xlsx + ├── pynative + │ ├── step0 + │ | ├── advisor_rank_20250805043416.txt + │ | ├── compare_result_rank_20250805043414.xlsx + ``` + +output目录下生成两个graph和pynative两个文件夹,每个文件夹下生成对应step的比对结果。 + +5. 查看比对结果,请详见PyTorch目录下的《[PyTorch 场景的精度比对-精度比对结果分析](./10.accuracy_compare_PyTorch.md#3-精度比对结果分析)》章节。 + +### 2.10 首差异算子节点识别 +参见《[PyTorch 场景的精度比对-首差异算子节点识别](./10.accuracy_compare_PyTorch.md#215-首差异算子节点识别场景)》章节。 + ## 3 多卡比对结果提取汇总通信算子数据 本功能是将多卡比对场景的比对结果,进行通信算子数据提取和汇总,输出整理好的通信算子多卡比对精度表。 @@ -204,11 +280,12 @@ msprobe -f mindspore merge_result -i ./input_dir -o ./output_dir -config ./confi **完整参数说明** -| 参数名 | 说明 | 是否必选 | -| ---------------------- | ------------------------------------------------------------ | -------- | -| -i 或 --input_dir | 多卡比对结果存盘目录,即使用compare比对的结果输出目录,str类型。所有比对结果应全部为真实数据比对结果或统计数据比对结果,否则可能导致汇总数据不完整。 | 是 | -| -o 或 --output_dir | 数据提取汇总结果存盘目录,str类型。文件名称基于时间戳自动生成,格式为:`multi_ranks_compare_merge_{timestamp}.xlsx`。 | 是 | -| -config或--config-path | 指定需要汇总数据的API和比对指标的yaml文件路径,str类型。
yaml文件详细介绍见下文“**yaml文件说明**”。 | 是 | +| 参数名 | 说明 | 是否必选 | +|-----------------------|-------------------------------------------------------------------------------------------------------------------| -------- | +| -f 或 --framework | 指定训练框架。mindspore。 | 是 | +| -i 或 --input_dir | 多卡比对结果存盘目录,即使用compare比对的结果输出目录,str类型。所有比对结果应全部为真实数据比对结果或统计数据比对结果,否则可能导致汇总数据不完整。 | 是 | +| -o 或 --output_dir | 数据提取汇总结果存盘目录,str类型。文件名称基于时间戳自动生成,格式为:`multi_ranks_compare_merge_{timestamp}.xlsx`。
提示:output目录下与结果件同名文件将被删除覆盖。 | 是 | +| -config或--config-path | 指定需要汇总数据的API和比对指标的yaml文件路径,str类型。
yaml文件详细介绍见下文“**yaml文件说明**”。 | 是 | **yaml文件说明** @@ -224,10 +301,10 @@ compare_index: - MeanRelativeErr ``` -| 参数名 | 说明 | -| ------------- | ------------------------------------------------------------ | -| api | 表示需要汇总的API或module名称。如果没有配置,工具会提示报错。
api名称配置格式为:`{api_type}.{api_name}.{API调用次数}.{前向反向}`
须按顺序配置以上四个字段,可按如下组合配置:
{api_type}
{api_type}.{api_name}
{api_type}.{api_name}.{API调用次数}
{api_type}.{api_name}.{API调用次数}.{前向反向}
这里的api指代API或module。 | -| compare_index | 表示需要汇总的比对指标。compare_index需为dump_mode对应比对指标的子集。如果没有配置,工具将根据比对结果自动提取dump_mode对应的全部比对指标进行汇总。
统计数据模式比对指标:Max diff、Min diff、Mean diff、Norm diff、MaxRelativeErr、MinRelativeErr、MeanRelativeErr、NormRelativeErr
真实数据模式比对指标:Cosine、MaxAbsErr、MaxRelativeErr、One Thousandth Err Ratio、Five Thousandths Err Ratio | +| 参数名 | 说明 | +| ------------- |---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| api | 表示需要汇总的API或module名称。如果没有配置,工具会提示报错。
api名称配置格式为:`{api_type}.{api_name}.{API调用次数}.{前向反向}`
须按顺序配置以上四个字段,可按如下组合配置:
{api_type}
{api_type}.{api_name}
{api_type}.{api_name}.{API调用次数}
{api_type}.{api_name}.{API调用次数}.{前向反向}
这里的api指代API或module。 | +| compare_index | 表示需要汇总的比对指标。compare_index需为dump_mode对应比对指标的子集。如果没有配置,工具将根据比对结果自动提取dump_mode对应的全部比对指标进行汇总。
统计数据模式比对指标:Max diff、Min diff、Mean diff、L2norm diff、MaxRelativeErr、MinRelativeErr、MeanRelativeErr、NormRelativeErr、Requires_grad Consistent
真实数据模式比对指标:Cosine、EucDist、MaxAbsErr、MaxRelativeErr、One Thousandth Err Ratio、Five Thousandths Err Ratio、Requires_grad Consistent | **汇总结果件说明** @@ -279,20 +356,20 @@ compare_index: 多卡场景示例如下: ```json { -"npu_path": "./npu_dump/step0", # 需填写到step层级(rank的上一层级) -"bench_path": "./bench_dump/step0", # 需填写到step层级(rank的上一层级) +"npu_path": "./npu_dump/step0", # 需填写到step层级(rank的上一层级) +"bench_path": "./bench_dump/step0", # 需填写到step层级(rank的上一层级) "is_print_compare_log": true } ``` **参数说明** -| 参数名 | 说明 | 是否必选 | -| -------------------- | ------------------------------------------------------------ |------| -| npu_path | 配置NPU环境下的dump.json文件(单卡场景)。跨框架场景指定为MindSpore的json文件。数据类型:str。 | 是 | -| bench_path | 配置CPU、GPU或NPU环境下的dump.json文件(单卡场景)。 跨框架场景指定为PyTorch的json文件。数据类型:str。 | 是 | -| stack_path | 配置NPU dump目录下的stack.json文件。数据类型:str。 如果没有配置stack_path,命令行-s参数不生效,程序自动识别是否存在stack.json文件,如存在,则比对结果中呈现NPU_Stack_Info,如不存在,则不呈现。如果配置了stack_path,比对结果中是否呈现NPU_Stack_Info则通过命令行参数-s来控制。 | 否 | -| is_print_compare_log | 配置是否开启单个算子的日志打屏。可取值true或false,默认为true。关闭后则只输出常规日志。数据类型:bool | 否 | +| 参数名 | 说明 | 是否必选 | +| -------------------- |-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------| +| npu_path | 配置NPU环境下的dump.json文件(单卡场景)或dump目录(多卡场景)。跨框架场景指定为MindSpore的dump.json文件或dump目录。数据类型:str。 | 是 | +| bench_path | 配置CPU、GPU或NPU环境下的dump.json文件(单卡场景)或dump目录(多卡场景)。跨框架场景指定为PyTorch的dump.json文件或dump目录。数据类型:str。 | 是 | +| stack_path | 配置NPU dump目录下的stack.json文件。数据类型:str。如果没有配置stack_path,命令行-s参数不生效,程序自动识别是否存在stack.json文件,如存在,则比对结果中呈现NPU_Stack_Info,如不存在,则不呈现。如果配置了stack_path,比对结果中是否呈现NPU_Stack_Info则通过命令行参数-s来控制。 | 否 | +| is_print_compare_log | 配置是否开启单个算子的日志打屏。可取值true或false,默认为true。关闭后则只输出常规日志。数据类型:bool。 | 否 | ### 4.2 比对文件(kernel) @@ -573,7 +650,7 @@ input_args、input_kwargs和output使用统一的命名规则,当值是list类 "md5": "28f8f74f" } ] -} +} ``` , 初始名称为`Cell.network.module.NetworkWithLoss.forward.0`,`input_args`是`list`,长度为2,按照顺序命名为 @@ -646,4 +723,47 @@ yaml文件中只需配置MindSpore与PyTorch模型代码中功能一致但名称 模型代码示例: -![ms_dump](./img/ms_layer.png) \ No newline at end of file +![ms_dump](./img/ms_layer.png) + +### 4.7 比对文件(单点数据) + +MindSpore动态图单卡场景示例如下: + ```json +{ +"npu_path": "./npu_dump/debug.json", +"bench_path": "./bench_dump/debug.json" +} + ``` + +MindSpore动态图多卡场景(step0目录下包含debug.json文件)示例如下: +```json +{ +"npu_path": "./npu_dump/step0", +"bench_path": "./bench_dump/step0" +} +``` + +MindSpore静态图场景(不区分单/多卡)示例如下: +```json +{ +"npu_path": "./npu_dump/", +"bench_path": "./bench_dump/", +"map_dict": {"input": "x"}, +"common": true +} +``` +- `npu_path`表示NPU dump文件目录,可指定到./npu_dump/ 或者./npu_dump/step0 或者./npu_dump/step0/rank0 保证对应即可,比对结果保持相同目录结构。 +- `bench_path`表示bench dump文件目录,指定同上。 +- `common`表示开启MindSpore静态图单点保存比对,默认关闭。 +- `map_dict`可用于当单点保存比对的`npy`文件名称不完全对应时,通过手动指定保证比对正确执行,比对指定名称对应,如{"input": "x"},则`input_float32_1.npy`会对应`x_float32_1.npy`。 + +### 4.8 比对文件(动静态图场景L0混合数据) + ```json +{ +"npu_path": "./npu_dump", +"bench_path": "./bench_dump", +"is_print_compare_log": true +} + ``` +- npu_path表示NPU dump文件目录,上面示例中的 ./npu_dump/ 是npu侧动静态图dump后graph和pynative目录的父目录。 +- bench_path表示Bench dump文件目录,上面示例中的 ./bench_dump/ 是bench侧动静态图dump后graph和pynative目录的父目录。 \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/docs/12.overflow_check_PyTorch.md b/debug/accuracy_tools/msprobe/docs/12.overflow_check_PyTorch.md index 97b049000c6aca9a69aeca66e1a27a4260b3d142..ea41bcdc10a07a1ddcb033c82765b09fd90ebbf3 100644 --- a/debug/accuracy_tools/msprobe/docs/12.overflow_check_PyTorch.md +++ b/debug/accuracy_tools/msprobe/docs/12.overflow_check_PyTorch.md @@ -12,13 +12,13 @@ msprobe 工具在 PyTorch 场景下提供溢出数据采集功能和溢出数据 ### 1.2 接口介绍 -溢出检测功能提供的接口与数据采集任务一致,详见[ PyTorch 场景的精度数据采集](./05.data_dump_PyTorch.md)中的"**1 接口介绍**"章节。 +溢出检测功能提供的接口与数据采集任务一致,详见[ PyTorch 场景的精度数据采集](./05.data_dump_PyTorch.md)中的"**接口介绍**"章节。 其中 PrecisionDebugger 中的 task 或是 config.json 中的 task 需要指定为 **overflow_check**,详见[配置文件介绍](./02.config_introduction.md)中的 "**1.1 通用配置介绍**"和"**1.5 task 配置为 overflow_check**"章节。 ### 1.3 示例代码 -溢出检测功能使用方式与数据采集任务一致,详见[ PyTorch 场景的精度数据采集](./05.data_dump_PyTorch.md)中的"**2 示例代码**"章节。 +溢出检测功能使用方式与数据采集任务一致,详见[ PyTorch 场景的精度数据采集](./05.data_dump_PyTorch.md)中的"**示例代码**"章节。 ### 1.4 结果文件介绍 @@ -28,7 +28,7 @@ msprobe 工具在 PyTorch 场景下提供溢出数据采集功能和溢出数据 溢出数据采集功能在昇腾 NPU 上支持饱和模式(仅支持 Atlas 训练系列产品)和 INF/NAN 模式。 -INF/NAN 模式遵循 IEEE 754 标准,根据定义输出 INF/NAN 的计算结果。与之对应的饱和模式在计算出现溢出时,饱和为浮点数极值(+-MAX)。对于 CANN 侧配置,Atlas 训练系列产品,默认为饱和模式,且不建议使用 INF/NAN 模式;Atlas A2 训练系列产品,默认为 INF/NAN 模式,且不建议使用饱和模式。 +INF/NAN 模式遵循 IEEE 754 标准,根据定义输出 INF/NAN 的计算结果。与之对应的饱和模式在计算出现溢出时,饱和为浮点数极值(+-MAX)。对于 CANN 侧配置,Atlas 训练系列产品,默认为饱和模式,且不支持使用 INF/NAN 模式;Atlas A2 训练系列产品,默认为 INF/NAN 模式,且不建议使用饱和模式。 INF/NAN 模式的使能方式如下: @@ -58,8 +58,9 @@ export INF_NAN_MODE_ENABLE=1 msprobe -f pytorch run_overflow_check -api_info ./dump_path/step{step_number}/rank{rank_number}/dump.json ``` -| 参数名称 | 说明 | 是否必选 | -| -------------------------- |------------------------------------| -------- | +| 参数名称 | 说明 | 是否必选 | +|---------------------------|------------------------------------| -------- | +| -f 或 --framework | 指定训练框架。pytorch。 | 是 | | -api_info或--api_info_file | 指定采集下来的 API 信息文件 dump.json。 | 是 | | -j或--jit_compile | 开启 jit 编译。 | 否 | | -d或--device | 指定 Device ID,选择 UT 代码运行所在的卡,默认值为0。 | 否 | diff --git a/debug/accuracy_tools/msprobe/docs/13.overflow_check_MindSpore.md b/debug/accuracy_tools/msprobe/docs/13.overflow_check_MindSpore.md index 33ff4a0259aef02d122022402966c65358e8efff..6872e19d66b38c7dfef9b2a6db0236d478423527 100644 --- a/debug/accuracy_tools/msprobe/docs/13.overflow_check_MindSpore.md +++ b/debug/accuracy_tools/msprobe/docs/13.overflow_check_MindSpore.md @@ -11,21 +11,23 @@ export INF_NAN_MODE_ENABLE=1 export MS_ASCEND_CHECK_OVERFLOW_MODE="INFNAN_MODE" ``` -**a**:在处理浮点数计算溢出问题时,NPU 当前支持两种溢出模式:INF/NAN 模式与饱和模式。INF/NAN 模式遵循 IEEE 754 标准,根据定义输出 INF/NAN 的计算结果。与之对应的饱和模式在计算出现溢出时,饱和为浮点数极值(+-MAX)。对于 CANN 侧配置,Atlas 训练系列产品,默认为饱和模式,且不建议使用 INF/NAN 模式;Atlas A2训练系列产品,默认为 INF/NAN 模式,且不建议使用饱和模式。对于 MindSpore 框架侧配置,仅支持对 Atlas A2 训练系列产品进行设置,默认为 INF/NAN 模式。CANN 侧 与 MindSpore 框架侧配置须一致。 +**a**:在处理浮点数计算溢出问题时,NPU 当前支持两种溢出模式:INF/NAN 模式与饱和模式。INF/NAN 模式遵循 IEEE 754 标准,根据定义输出 INF/NAN 的计算结果。与之对应的饱和模式在计算出现溢出时,饱和为浮点数极值(+-MAX)。对于 CANN 侧配置,Atlas 训练系列产品,默认为饱和模式,且不支持使用 INF/NAN 模式;Atlas A2训练系列产品,默认为 INF/NAN 模式,且不建议使用饱和模式。对于 MindSpore 框架侧配置,仅支持对 Atlas A2 训练系列产品进行设置,默认为 INF/NAN 模式。CANN 侧 与 MindSpore 框架侧配置须一致。 -溢出检测任务的配置示例见[MindSpore 静态图场景下 task 配置为 overflow_check](https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/docs/03.config_examples.md#23-task-%E9%85%8D%E7%BD%AE%E4%B8%BA-overflow_check)、[MindSpore 动态图场景下 task 配置为 overflow_check](https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/docs/03.config_examples.md#33-task-%E9%85%8D%E7%BD%AE%E4%B8%BA-overflow_check)。 +溢出检测任务的配置示例见[MindSpore 静态图场景下 task 配置为 overflow_check](03.config_examples.md#23-task-配置为-overflow_check)、[MindSpore 动态图场景下 task 配置为 overflow_check](03.config_examples.md#33-task-配置为-overflow_check)。 ## 1 接口介绍 -溢出检测功能提供的接口与数据采集任务一致,详见MindSpore 场景的精度数据采集中的["**1 接口介绍**"](./06.data_dump_MindSpore.md#1-接口介绍)章节。 +溢出检测功能提供的接口与数据采集任务一致,详见MindSpore 场景的精度数据采集中的["**接口介绍**"](./06.data_dump_MindSpore.md#6-接口介绍)章节。 需要注意,目前暂不支持动态图 "L1" level 下 primitive op 的溢出检测。 ## 2 示例代码 -溢出检测功能使用方式与数据采集任务一致,详见MindSpore 场景的精度数据采集中的["**2 示例代码**"](./06.data_dump_MindSpore.md#2-示例代码)节。 +溢出检测功能使用方式与数据采集任务一致,详见MindSpore 场景的精度数据采集中的["**示例代码**"](./06.data_dump_MindSpore.md#7-示例代码)节。 ## 3 溢出检测结果文件介绍 -溢出检测结果文件目录结构与含义与数据采集任务一致,但仅保存溢出 API 或 kernel 的真实数据或统计信息。详见MindSpore 场景的精度数据采集中的["**3 dump 结果文件介绍**"](./06.data_dump_MindSpore.md#3-dump-结果文件介绍)章节。 +溢出检测结果文件目录结构与含义与数据采集任务一致,但仅保存溢出 API 或 kernel 的真实数据或统计信息。详见MindSpore 场景的精度数据采集中的["**8. dump 结果文件介绍**"](./06.data_dump_MindSpore.md#8-dump-结果文件介绍)章节。 + +**说明**:在静态图 O2 编译等级下,若 MindSpore 版本为 2.4,或者 MindSpore 版本为 2.5,且未使用编包时添加了`--include-mod=adump`选项的 mindstudio-probe whl 包,则会产生 kernel_graph_overflow_check.json 中间文件,一般情况下无需关注。 \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/docs/14.data_parse_PyTorch.md b/debug/accuracy_tools/msprobe/docs/14.data_parse_PyTorch.md index 68a3d1a57dc1b649ffdb6d02d7be378900458e65..931e74a6dc2b266af888504e7d4e10e819ffa5d3 100644 --- a/debug/accuracy_tools/msprobe/docs/14.data_parse_PyTorch.md +++ b/debug/accuracy_tools/msprobe/docs/14.data_parse_PyTorch.md @@ -18,6 +18,9 @@ msprobe -f pytorch parse Parse >>> ``` +| 参数名称 | 说明 | 是否必选 | +|---------------------------|------------------------------------| -------- | +| -f 或 --framework | 指定训练框架。pytorch。 | 是 | 可在 parse 的界面中执行 Shell 命令,以及如下场景的相关解析命令(详细介绍请参见以下章节。): @@ -26,13 +29,7 @@ Parse >>> - 支持交互式指定 pkl 文件中 API 对应 dump 数据查看。 - 支持 API 进行可选层级比对和打印(统计级和像素级)。 -Ctrl+C 可以退出 parse 交互式界面。不退出 parse 交互式界面若需要执行非该界面下的内置 Shell 命令,且命令与 parse 交互式界面命令冲突时,非该界面命令需要使用 run 命令,在相关命令前加上 run 前缀,如下示例: - -```bash -msprobe -f pytorch parse -Parse >>> run vim cli.py -Parse >>> vim cli.py -``` +Ctrl+C 可以退出 parse 交互式界面。 ### 2.2 kernel 层级算子数据批量转换 @@ -44,11 +41,11 @@ Parse >>> vim cli.py cad -m my_dump_path [-out output_path] [-asc msaccucmp_path] ``` -| 参数名称 | 说明 | 是否必选 | -| -------- | ------------------------------------------------------------ | -------- | -| -m | 待转换 kernel dump 数据目录。需要指定到 kernel dump 数据的 deviceid 级目录。 | 是 | -| -out | 结果输出目录,须指定已存在的目录,默认为 ./parse_data/acl_batch_convert。未指定时保存在默认路径下,比对结束后会打印 log 提示输出结果存放路径。 | 否 | -| -asc | 指定 msaccucmp 路径,默认路径为:/usr/local/Ascend/ascend-toolkit/latest/tools/operator_cmp/compare/msaccucmp.py。 | 否 | +| 参数名称 | 说明 | 是否必选 | +|-------------------------| ------------------------------------------------------------ | -------- | +| -m 或 --my_dump_path | 待转换 kernel dump 数据目录。需要指定到 kernel dump 数据的 deviceid 级目录。 | 是 | +| -out 或 --output_path | 结果输出目录,须指定已存在的目录,默认为 ./parse_data/acl_batch_convert。未指定时保存在默认路径下,比对结束后会打印 log 提示输出结果存放路径。 | 否 | +| -asc 或 --msaccucmp_path | 指定 msaccucmp 路径,默认路径为:/usr/local/Ascend/ascend-toolkit/latest/tools/operator_cmp/compare/msaccucmp.py。 | 否 | **示例代码**: @@ -105,12 +102,12 @@ Parse >>> cad -m /home/xxx/my_dump_path/20000124003856/0 vc -m my_dump_path -g golden_dump_path [-out output_path] [-cmp_path msaccucmp_path] ``` -| 参数名称 | 说明 | 是否必选 | -| --------- | ------------------------------------------------------------ | -------- | -| -m | 待比对 kernel dump 数据目录。如果比对单个算子,需要指定到 kernel dump 数据的 model_id 级目录;如果批量比对,则指定到 cad 转换后的 timestamp 级目录。 | 是 | -| -g | 标杆 kernel dump 数据目录。如果比对单个算子,需要指定到 kernel dump 数据的 model_id 级目录;如果批量比对,则指定到 cad 转换后的 timestamp 级目录。 | 是 | -| -out | 结果输出目录,须指定已存在的目录,默认为 ./parse_data/acl_batch_comapre。未指定时保存在默认路径下,比对结束后会打印 log 提示输出结果存放路径。 | 否 | -| -cmp_path | 指定 msaccucmp 路径,默认路径为:/usr/local/Ascend/ascend-toolkit/latest/tools/operator_cmp/compare/msaccucmp.py | 否 | +| 参数名称 | 说明 | 是否必选 | +|------------------------------| ------------------------------------------------------------ | -------- | +| -m 或 --my_dump_path | 待比对 kernel dump 数据目录。如果比对单个算子,需要指定到 kernel dump 数据的 model_id 级目录;如果批量比对,则指定到 cad 转换后的 timestamp 级目录。 | 是 | +| -g 或 --golden_dump_path | 标杆 kernel dump 数据目录。如果比对单个算子,需要指定到 kernel dump 数据的 model_id 级目录;如果批量比对,则指定到 cad 转换后的 timestamp 级目录。 | 是 | +| -out 或 --output_path | 结果输出目录,须指定已存在的目录,默认为 ./parse_data/acl_batch_compare。未指定时保存在默认路径下,比对结束后会打印 log 提示输出结果存放路径。 | 否 | +| -cmp_path 或 --msaccucmp_path | 指定 msaccucmp 路径,默认路径为:/usr/local/Ascend/ascend-toolkit/latest/tools/operator_cmp/compare/msaccucmp.py | 否 | 输出结果:`batch_compare_{timestamp}.csv` 文件。 @@ -119,7 +116,7 @@ vc -m my_dump_path -g golden_dump_path [-out output_path] [-cmp_path msaccucmp_p ```bash # 传入待比对数据目录以及标杆数据目录 Parse >>> vc -m ./my_dump_path -g ./golden_data_path -[INFO]Compare result is saved in : parse_data/acl_batch_comapre/batch_compare_1707271118.csv +[INFO]Compare result is saved in : parse_data/acl_batch_compare/batch_compare_1707271118.csv ``` ### 2.3 kernel 算子数据的 npy 转换 @@ -130,12 +127,12 @@ Parse >>> vc -m ./my_dump_path -g ./golden_data_path dc -n file_name/file_path [-f format] [-out output_path] ``` -| 参数名称 | 说明 | 是否必选 | -| --------- | ------------------------------------------------------------ | -------- | -| -n | 需转换的 dump 数据文件或 dump 数据文件目录。 | 是 | -| -f | 开启 format 转换,指定该参数时需要配置 format 格式。当前内置的 Format 转换支持如下类型:
FRACTAL_NZ 转换 NCHW;
FRACTAL_NZ 转换成 NHWC;
FRACTAL_NZ 转换 ND;
HWCN 转换 FRACTAL_Z;
HWCN 转换成 NCHW;
HWCN 转换成 NHWC;
NC1HWC0 转换成 HWCN;
NC1HWC0 转换成 NCHW;
NC1HWC0 转换成 NHWC;
NCHW 转换成 FRACTAL_Z;
NCHW转换成NHWC;
NHWC转换成FRACTAL_Z;
NHWC转换成HWCN;
NHWC转换成NCHW;
NDC1HWC0转换成NCDHW。 | 否 | -| -out | 结果输出目录。 | 否 | -| -cmp_path | 指定 msaccucmp 路径,默认路径为:/usr/local/Ascend/ascend-toolkit/latest/tools/operator_cmp/compare/msaccucmp.py | 否 | +| 参数名称 | 说明 | 是否必选 | +|-------------------------| ------------------------------------------------------------ | -------- | +| -n 或 --name | 需转换的 dump 数据文件或 dump 数据文件目录。 | 是 | +| -f 或 --format | 开启 format 转换,指定该参数时需要配置 format 格式。当前内置的 Format 转换支持如下类型:
FRACTAL_NZ 转换 NCHW;
FRACTAL_NZ 转换成 NHWC;
FRACTAL_NZ 转换 ND;
HWCN 转换 FRACTAL_Z;
HWCN 转换成 NCHW;
HWCN 转换成 NHWC;
NC1HWC0 转换成 HWCN;
NC1HWC0 转换成 NCHW;
NC1HWC0 转换成 NHWC;
NCHW 转换成 FRACTAL_Z;
NCHW转换成NHWC;
NHWC转换成FRACTAL_Z;
NHWC转换成HWCN;
NHWC转换成NCHW;
NDC1HWC0转换成NCDHW。 | 否 | +| -out 或 --output_path | 结果输出目录。 | 否 | +| -cmp_path 或 --msaccucmp | 指定 msaccucmp 路径,默认路径为:/usr/local/Ascend/ascend-toolkit/latest/tools/operator_cmp/compare/msaccucmp.py | 否 | - 输出结果:npy 文件。 @@ -149,9 +146,9 @@ dc -n file_name/file_path [-f format] [-out output_path] pt -n file_path ``` - | 参数名称 | 说明 | 是否必选 | - | -------- | ------------- | -------- | - | -n | npy 文件路径。 | 是 | + | 参数名称 | 说明 | 是否必选 | + |-------------| ------------- | -------- | + | -n 或 --name | npy 文件路径。 | 是 | 打印统计信息:shape, dtype, max, min 和 mean。默认在 npy 文件路径下将该数据保存为 txt 文件。 @@ -197,10 +194,10 @@ TextFile:./parse_data/dump_convert/Add.fp32_vars_add_1fp32_vars_Relu_6.24.5.1636 pk -f pkl_path -n api_name ``` -| 参数名称 | 说明 | 是否必选 | -| -------- | ----------------------- | -------- | -| -f | 指定 dump.json 文件路径。 | 是 | -| -n | 指定 API 名称。 | 是 | +| 参数名称 | 说明 | 是否必选 | +|-------------| ----------------------- | -------- | +| -f 或 --file | 指定 dump.json 文件路径。 | 是 | +| -n 或 --name | 指定 API 名称。 | 是 | - 输出结果:打印统计信息(shape, dtype, max和min mean)。 - 若 pkl 文件中存在相应的堆栈信息,则会打印堆栈信息。 @@ -225,20 +222,20 @@ Statistic Info: 输入以下命令, 进行统计级和像素级比对。 ```bash -cn -m my_data*.npy -g gloden*.npy [-p num] [-al atol] [-rl rtol] +cn -m my_data*.npy -g golden*.npy [-p num] [-al atol] [-rl rtol] ``` - 统计级比对:对 tensor 整体进行余弦值及相对误差的计算。 - 像素级比对:对输入的两个 npy 文件进行逐元素比对。若两个 tensor 对应元素的相对误差或绝对误差大于**误差阈值**(-al 和 -rl 配置)则被标记为错误数据。 -| 参数名称 | 说明 | 是否必选 | -| -------- | ----------------------------------------------- | -------- | -| -m | 待比对数据。 | 是 | -| -g | 标杆数据。 | 是 | -| -p | 设置比对结束后打印错误元素的个数,默认值 20。 | 否 | -| -al | 判定数据存在精度问题的绝对误差阈值,默认 0.001。 | 否 | -| -rl | 判定数据存在精度问题的相对误差阈值,默认 0.001。 | 否 | -| -s | 将 npy 文件保存成 txt 文件,用于查看,默认开启。 | 否 | +| 参数名称 | 说明 | 是否必选 | +|-------------------------| ----------------------------------------------- | -------- | +| -m 或 --my_dump_path | 待比对数据。 | 是 | +| -g 或 --golden_dump_path | 标杆数据。 | 是 | +| -p 或 --print | 设置比对结束后打印错误元素的个数,默认值 20。 | 否 | +| -al 或 --atol | 判定数据存在精度问题的绝对误差阈值,默认 0.001。 | 否 | +| -rl 或 --rtol | 判定数据存在精度问题的相对误差阈值,默认 0.001。 | 否 | +| -s 或 --save | 将 npy 文件保存成 txt 文件,用于查看,默认开启。 | 否 | 输出结果: diff --git a/debug/accuracy_tools/msprobe/docs/17.grad_probe.md b/debug/accuracy_tools/msprobe/docs/17.grad_probe.md index f210088013415e40167f3eea3aab6163b0c947dc..fe24bc5cce3414b9f07333257aecf25d225ed45f 100644 --- a/debug/accuracy_tools/msprobe/docs/17.grad_probe.md +++ b/debug/accuracy_tools/msprobe/docs/17.grad_probe.md @@ -5,7 +5,7 @@ - 将模型权重的梯度数据导出。这种功能可以将模型权重的梯度值以统计量的形式采集出来,用以分析问题。 - 将两份梯度数据进行相似度对比。在有标杆问题中,可以确认训练过程中精度问题出现的step,以及抓取反向过程中的问题。 -工具支持PyTorch版本:2.0/2.1/2.2;支持MindSpore版本:r2.3。暂不支持deepspeed的zero1、zero2、zero3。 +工具支持PyTorch版本:2.0/2.1/2.2;支持MindSpore版本:r2.3。暂不支持deepspeed的ZeRO-1、ZeRO-2、ZeRO-3。 ## 工具特性 @@ -65,6 +65,7 @@ + 值分布:梯度数据落在各个区间的元素个数占总元素个数的比例。 + bounds:一个列表,用来划分出区间以统计值分布。例如传入bounds = [-10, 0, 10],此时有一个 grad_value: Tensor = [9.3 , 5.4, -1.0, -12.3],依据 bounds 划分出 (-inf, -10]、(-10, 0]、(0, 10]、(10, inf) 四个区间,然后统计grad_value里的数据落在每个区间内的个数,得到 1、1、2、0。如下图所示: + ![Alt text](./img/grad_probe_image-1.png) 2. 插入代码。示例代码如下: diff --git a/debug/accuracy_tools/msprobe/docs/18.online_dispatch.md b/debug/accuracy_tools/msprobe/docs/18.online_dispatch.md index e686c61b68add9c9a1ade9ae3e89b897c9b8d6bf..4d1833a3fbd32c37aaf5dc7c0993c176417ab584 100644 --- a/debug/accuracy_tools/msprobe/docs/18.online_dispatch.md +++ b/debug/accuracy_tools/msprobe/docs/18.online_dispatch.md @@ -70,15 +70,15 @@ PyTorch NPU在线精度比对是msprobe工具实现在PyTorch训练过程中直 | api_list | dump范围,dump_mode="list"时设置,需要Dump Aten Ir API名称,默认为None,Aten Ir API名称可以通过dir(torch.ops.aten)查看。 | 否 | | dump_path| dump文件生成的路径。 | 是 | | tag | 传入tag字符串,成为dump文件夹名一部分,默认为None。 | 否 | -| process_num | 多进程并发数,默认为0。 | 否 | +| process_num | 多进程并发数,默认为0,最大不超过CPU核数的四分之一。 | 否 | | debug | debug信息打印,默认为False。 | 否 | ### dump数据存盘说明 -dump数据存盘目录名格式:`atat_tag_rankid_{timestamp}`。 +dump数据存盘目录名格式:`msprobe_rankid_{timestamp}`。 子目录下包含1个比对结果csv文件、cpu和npudump数据目录,npu目录下包含Aten IR在NPU上的输入输出的dump数据,由于CPU的输入是直接使用NPU的输入执行,因此cpu目录下只包含执行输出的dump数据。 ```bash -atat_rank4_20230911170521 +msprobe_rank4_20230911170521 ├── compare_result_rank4_20230911170521.csv ├── cpu │ ├── native_batch_norm_backward_10_output.0.npy diff --git a/debug/accuracy_tools/msprobe/docs/19.monitor.md b/debug/accuracy_tools/msprobe/docs/19.monitor.md index fa1b7d06d6c52b55c49f26352f823de41b28cb2d..cd851b73405e53618981855eab055368edaebe18 100644 --- a/debug/accuracy_tools/msprobe/docs/19.monitor.md +++ b/debug/accuracy_tools/msprobe/docs/19.monitor.md @@ -10,7 +10,7 @@ 要求: - PyTorch场景:torch不低于**2.0** -- MindSpore场景:mindspore不低于**2.4.10**,仅支持**MindSpore动态图**,暂不支持**msadapter**套件 +- MindSpore场景:mindspore不低于**2.4.10**,仅支持**MindSpore动态图**,已支持**msadapter**套件 ## 功能介绍 下表中字段为训练状态轻量化监控工具的完整功能点: @@ -21,12 +21,13 @@ | [权重梯度监控](#权重梯度监控) | 开启权重梯度监控 | PyTorch、MindSpore | | [激活值监控](#激活值监控) | 开启激活值监控 | PyTorch、MindSpore | | [优化器状态监控](#优化器状态监控) | 开启优化器状态监控 | PyTorch、MindSpore | +| [采集module堆栈信息](#采集module堆栈信息) | 采集监控的第一个 step 的 module 对应的堆栈信息辅助问题定位 | PyTorch、MindSpore | | [指定监控对象](#指定监控对象) | 指定监控的nn.Module(nn.Cell)及对应的输入输出 | PyTorch、MindSpore | | [打印模型结构](#打印模型结构) | 打印模型结构 | PyTorch | -| [Module全量监控](#Module全量监控) | 对全量module的输入输出做监控 | PyTorch、MindSpore | -| [Parameter全量监控](#Parameter全量监控) | 对全量Parameter的输入输出做监控 | PyTorch、MindSpore | -| [输出格式和统计量](#输出格式和统计量) | format PyTorch支持`csv`、`tensorboard`和`api`,MindSpore仅支持`csv`,`ops`均支持,`ndigits`仅PyTorch支持 | PyTorch、MindSpore | -| [梯度异常时序判断](#梯度异常时序判断) | 梯度异常时自动梯度落盘 | PyTorch | +| [l2可解释特征监控](#l2可解释特征监控) | 开启模型状态的高阶监控 | PyTorch | +| [输出格式和统计量](#输出格式和统计量) | format PyTorch支持`csv`、`tensorboard`和`api`,MindSpore仅支持`csv`,`ops`、`ndigits`均支持 | PyTorch、MindSpore | +| [mbs粒度梯度监控](#mbs粒度梯度监控) | 开启梯度监控时,采集聚合前梯度时支持`micro_batch_size`粒度 | PyTorch、MindSpore | +| [异常告警](#异常告警) | 监控对象指标异常时自动告警,支持异常数据落盘 | PyTorch、MindSpore | | [csv格式数据转tensorboard可视化显示](#csv格式数据转tensorboard可视化显示) | 将csv转为tensorboard文件显示 | PyTorch | | [动态启停](#动态启停) | 训练过程中动态修改配置开启监控 | PyTorch、MindSpore | | [功能重载](#功能重载) | 训练中开启激活值监控。待废弃,请使用动态启停功能代替。 | PyTorch | @@ -205,15 +206,29 @@ monitor.monitor_gnorm_with_ad( 本工具针对分布式计算框架megatron和deepspeed框架做了适配,暂不支持其他框架。 +### 采集module堆栈信息 +- 工具配置示例: +```json +{ + "targets": { + }, + "format": "csv", + "stack_info": true +} +``` +开启 `stack_info` 后会采集监控的第一个 step 的所有 module 的堆栈信息,输出格式仅支持 csv 。 ## 高阶功能 + ### 指定监控对象 -工具支持对nn.Module(**激活值监控**)和nn.Parameter(**权重监控**、**权重梯度监控、优化器监控**)对象实现相应的监控行为,在配置文件的"targets"(dict)字段指定,targets格式为{module_name/param_name: {filed: format}}。 +工具支持对指定nn.Module进行状态监控,在配置文件的`targets`字段中指定,`targets`格式为{module_name: {}}。 + +module_name可以通过nn.Module的接口named_modules()获取。 #### 打印模型结构 -工具提供可选项`print_struct`打印模型结构,帮助配置targets。工具会在在第一个step后打印结构并停止训练进程,模型结构默认打印在`$MONITOR_OUTPUT_DIR/module_struct.json`。 +工具提供可选项`print_struct`打印模型结构,帮助配置targets。工具会在在第一个step后打印结构并停止训练进程,每张卡上的模型结构默认保存在`$MONITOR_OUTPUT_DIR/module_struct/rank{rank}/module_struct.json`, 其中{rank}为对应的卡号。 ```json { "print_struct": true @@ -221,7 +236,6 @@ monitor.monitor_gnorm_with_ad( ``` 输出样例: -字段`config`用于配置文件中指定module target。其余为各个元素的shape和dtype。 ```json "0:63.mlp.linear_fc2": { @@ -245,40 +259,30 @@ monitor.monitor_gnorm_with_ad( } }, ``` +对于module对象,通常关心前向/反向传播的输入和输出: -- Module - 对于module对象,通常关心其前向的输入(input)输出(output)和反向的输入--前向输出的梯度(output_grad)和输出--前向输入的梯度(input_grad)。同时需要声明这些对象的类型,通常为"tensor"或"tuple\[length]"。 +- 前向的输入(input) +- 前向的输出(output) +- 反向的输入,表示前向输出的梯度(output_grad) +- 反向的输出,表示前向输入的梯度(input_grad) - "tensor"可以直接用来计算统计量,"tuple"需要进一步指定监控的索引。如"tuple[2]:0",表示该对象为长度2的tuple,对第0元素进行监控;不指定索引时,默认对第0元素进行监控。 - module_name可以通过nn.Module的接口`named_modules()`获取。 -```json -// 示例:对一个名为"module.encoder.layers.0.mlp"的module,监控其前向输入第0元素和输出。 -{ - "targets": { - "module.encoder.layers.0.mlp": { - "input": "tuple[2]:0", - "output": "tensor" - } - } -} -``` -#### Module全量监控 -工具提供简便的全量module监控方式。或不配置targets、all_xy字段,同样表示全量监控。 +#### 指定监控对象 + +targets字段指定监控对象示例如下: ```json -{ - "targets": {}, - "all_xy": true +// 示例:对一个名为"module.encoder.layers.0.mlp"的module。 +"targets": { + "module.encoder.layers.0.mlp": {} } ``` +对于parameter对象,通常会关注其在一个训练迭代中的梯度(weight grad)、adam类优化器中的动量(1st moment, 2nd moment)。 +parameter归属于某一module,可以通过指定module_name来监控包含在这一module中的**所有**parameter。 -- Parameter - 对于parameter对象,通常会关注其在一个训练迭代中的梯度(weight grad)、adam类优化器中的动量(1st moment, 2nd moment)。 - parameter归属于某一module,也可以通过指定module_name来监控包含在这一module中的**所有**parameter。 +param_name可以通过nn.Module的接口`named_parameters()`获取。 - param_name可以通过nn.Module的接口`named_parameters()`获取。 ```json // 示例:监控"module.encoder.layers.0.mlp"的所有参数和"module.embedding.word_embedding.weight"这一参数 { @@ -289,8 +293,9 @@ monitor.monitor_gnorm_with_ad( } ``` -#### Parameter全量监控 -工具提供简便的全量parameter监控方式。或不配置targets,同样表示全量监控。 +#### 全量监控 + +工具提供简便的全量module对象监控方式。 ```json { @@ -298,7 +303,37 @@ monitor.monitor_gnorm_with_ad( } ``` +### l2可解释特征监控 +- 工具配置示例 +```json +{ + "l2_targets": { + "attention_hook": ["0:0.self_attention.core_attention.flash_attention"], + "linear_hook": ["0:0.self_attention.linear_qkv", "0:1.self_attention.linear_qkv"] + }, + "recording_l2_features": true, + "sa_order": "b,s,h,d" +} +``` +| 配置项 | 类型 | 说明 | 是否必选 | +|--------|------|------|--------| +| **l2_targets** | Dict[str, List[str]] | 指定需要监控的模型层配置
**支持的hook类型**:
• `attention_hook`:监控注意力层
  ▪️ 采集指标:`entropy` `softmax_max`
  ▪️ 必须通过[打印模型结构](#打印模型结构)获取准确层名
  ▪️ 不配置或配置空列表均表示不采集
• `linear_hook`:监控线性层
  ▪️ 采集指标:`sr`, `kernel_norm`
  ▪️ 必须通过[打印模型结构](#打印模型结构)获取准确层名, 不配置表示不采集
  ▪️ 配置空列表会自动识别符合条件的层(包含`weight`或`wg`2D参数属性的层) | 是 | +| **recording_l2_features** | bool | 是否开启L2层特征数据采集,默认为false表示不采集 | 否 | +| **sa_order** | str | 计算`attention_hook`内指标时,指定Attention输入(Q,K)的张量维度排列顺序,支持"s,b,h,d"和"b,s,h,d", 默认为"s,b,h,d"表示输入维度顺序为**s**equence_len​->**b**atch_size​->num_**h**eads​->head_**d**im | 否 | + + +#### L2可解释特征监控指标说明 + +| **指标名称** | **适用Hook类型** | **数学定义/计算方式** | **监控意义** | +|--------------------|-------------------|-------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------| +| **entropy** | attention_hook | $H(p)=-\sum p_i \log p_i$,其中$p_i$为注意力权重 | 衡量注意力分布的不确定性,**低熵值**表示注意力集中 | +| **softmax_max** | attention_hook | $\max(\text{softmax}(QK^T/\sqrt{d}))$ | 反映注意力机制的聚焦程度,**高值**表示存在显著主导的注意力token | +| **sr(stable_rank)** | linear_hook | $\frac{\|W\|_F}{\|W\|_2}$(稳定秩,Frobenius范数除以谱范数) | 评估权重矩阵的有效秩,**低值**表示矩阵接近低秩不稳定状态 | +| **kernel_norm** | linear_hook | $\|W\|_F$(Frobenius范数) | 权重矩阵的缩谱范数,反映输入在矩阵最大奇异向量张成空间的放大系数 | + + ### 输出格式和统计量 + 工具配置示例: ```json { @@ -333,7 +368,7 @@ export MONITOR_OUTPUT_DIR=/xxx/output_dir 监控结果写入csv文件中,可以通过`ndigits`字段设置小数位数。 表头为 vpp_stage | name | step | micro_step(optional) | *ops |。 仅在激活值监控的输出文件中包含micor_step。 - 激活值监控的name为.\, 其他任务的name为> + 激活值监控的name为.\, 其他任务的name为 - **api** 监控结果不落盘,在训练过程中可以通过`generate_wgrad_metrics`、`generate_xy_metrics`等接口获取,使用方式参考[公开接口](#公开接口) 。 @@ -349,16 +384,65 @@ export MONITOR_OUTPUT_DIR=/xxx/output_dir ![step_count_per_record](img/monitor/step_count_per_record.png) -### 梯度异常时序判断 +### mbs粒度梯度监控 + +当配置梯度监控任务时,工具默认`global_batch_size`粒度进行梯度监控。当需要监控`micro_batch_size`粒度梯度信息时,在配置文件中配置`monitor_mbs_grad`为`true`,配置示例如下: + +```json +{ + "wg_distribution": true, + "monitor_mbs_grad": true +} +``` + +应用范围 + +- **仅支持采集聚合前梯度**,在梯度累积场景下,聚合后梯度已无法区分`micro_batch`数据。 +- PyTorch场景下,Megatron和DeepSpeed训练框架下均支持,FSDP训练框架下暂不支持。 +- MindSpore场景下均支持。 + +### 异常告警 + +工具的异常告警功能旨在自动判断训练过程中的异常现象,用户可通过在配置文件中配置alert字段来指定告警规则,并在训练过程中根据该规则及时打屏对用户发出告警。 + + 1. 训练前配置相关参数 -工具支持自动判断训练过程中的梯度异常,需要在配置文件中设置alert相关字段。"AnomalyTurbulence"会将当前数值与历史均值比较,如果相对偏差超过阈值,会在打屏信息中提示用户。如果打开"`dump`"选项,则会将异常梯度相关信息落盘到目录`monitor_output/anomaly_detected`,用于后续时序判断。 +当前支持的异常告警规则如下: + +| 异常告警 |解释| rule_name | args是否可选 | +|--------------|----|-----------|---------------------------------------------------------------------| +| 历史均值偏离告警 |将当前数值与历史均值比较。如果相对偏差超过阈值,会在打屏信息中提示用户指标偏离。当前仅对`norm`和`mean`指标生效。| AnomalyTurbulence | 否,必须传入threshold。当指标超过`(1+threshold)*avg`时,识别为偏离历史均值。 | +| nan值/极大值告警 |根据是否提供threshold来判断nan值或极大值| AnomalyNan | 是, 若未配置args或未配置threshold,则默认检测nan,若提供threshold,则检测nan值以及绝对值超过阈值的极大值 | + +除此之外,我们在alert中支持dump配置项,如果打开"`dump`"选项,则会将异常信息落盘到目录`monitor_output/anomaly_detected`。 + +- 历史均值偏离告警案例如下: ```json "alert": { - "rules": [{"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.5}}], + "rules": [{"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.5}}], // 0.5表示偏离50%则提示偏离 + "dump": true + }, +``` +- nan值/极大值告警案例如下: +```json + "alert": { + "rules": [{"rule_name": "AnomalyNan", "args": {"threshold": 1e10}}], + "dump": true + }, +``` + +注:当配置多条异常告警规则时,优先告警第一条,如以下配置时每一层会优先报AnomalyNan的告警(一般不建议配置多条规则): +```json + "alert": { + "rules": [ + {"rule_name": "AnomalyNan", "args": {"threshold": 1e10}}, + {"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.5}} + ], "dump": true }, ``` + 2. 实例化工具时传入流水线并行group ```python monitor = TrainerMon( @@ -386,18 +470,20 @@ monitor = TrainerMon( } ``` +其中call_{xxx}中的xxx为API的执行调用顺序,为后续异常事件排序做准备。 + 3. 异常事件排序 当模型训练过程中出现较多异常数据,需要对异常事件排序。工具提供topk的异常排序能力,按照api的执行顺序进行排序,便于定界首次异常点。异常分析命令示例: ```shell -python3 -m msprobe.pytorch.monitor.anomaly_analyse -d $MONITOR_OUTPUT_DIR/anomaly_detected +python3 -m msprobe.core.monitor.anomaly_processor -d $MONITOR_OUTPUT_DIR/anomaly_detected ``` 异常事件分析结束,将topk事件写入文件`anomaly_detected/anomaly_analyse.json`。异常分析支持以下参数配置: -| 字段名 | 解释 | 是否必选 | -| ----------------- | ------------------------------------------------------------ | -------- | -| -d 或 --data_path | 指定梯度异常落盘文件夹,梯度监控功能输出,一般为$MONITOR_OUTPUT_DIR/anomaly_detected。 | 是 | +| 字段名 | 解释 | 是否必选 | +| ----------------- | --------------------------------------------------------- | -------- | +| -d 或 --data_path | 指定异常落盘文件夹,监控功能输出,一般为$MONITOR_OUTPUT_DIR/anomaly_detected。 | 是 | | -o 或 --out_path | 排序后的异常落盘文件地址,默认在--data_path路径下落盘一个anomaly_analyse.json文件。 | 否 | | -k 或 --topk | 指定保留前topk个异常,默认为8。 | 否 | | -s 或 --step_list | 指定分析的step范围,默认为[]。 | 否 | @@ -405,45 +491,78 @@ python3 -m msprobe.pytorch.monitor.anomaly_analyse -d $MONITOR_OUTPUT_DIR/anomal ### csv格式数据转tensorboard可视化显示 -将csv数据转换为tensorboard格式数据。 +**将csv数据转换为tensorboard格式数据。** ```python from msprobe.pytorch.monitor.csv2tb import csv2tensorboard_by_step # 前三个参数用来指定需要转换的一批文件,指定monitor输出目录及一个时间范围,会对这个范围内的文件进行转换 # process_num指定拉起的进程个数,默认为1,更多的进程个数可以加速转换 -# data_type_list是一个列表,指定需要转换的数据类型, 数据类型应来自输出件文件前缀,所有类型数据: -# ["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param"] -# 不指定就转换全部数据 -# output_dirpath可指定输出目录, 不传值时保存到"{curtime}_csv2tensorboard_by_step"文件夹,其中curtime为自动获取的当前时间戳 +# data_type_list是一个列表,指定需要转换的数据类型,默认转换全部数据,数据类型应来自输出件文件前缀,所有类型数据: +# ["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param_origin", "param_updated"] +# output_dirpath可指定输出目录,默认保存到"{curtime}_csv2tensorboard_by_step"文件夹,其中curtime为自动获取的当前时间戳 csv2tensorboard_by_step( + monitor_path="~/monitor_output", # 必填 + time_start="Dec03_21-34-40", # 必填 + time_end="Dec03_21-34-42", # 必填 + process_num=8, + data_type_list=["param_origin"] +) +``` +参数详细介绍请参见[公开接口](#公开接口)的“csv输出件转tensorboard输出件” + +**将csv数据转换为sqlite db数据。** +1. 创建Python脚本,以`csv2db.py`命名为例,将以下配置拷贝到文件中, 并按实际情况修改。 + +```python +from msprobe.core.monitor.csv2db import CSV2DBConfig, csv2db +config = CSV2DBConfig( monitor_path="~/monitor_output", time_start="Dec03_21-34-40", time_end="Dec03_21-34-42", process_num=8, - data_type_list=["param"] + data_type_list=["grad_unreduced"], + step_partition=500, + output_dirpath="~/monitor_output" ) +csv2db(config) ``` +参数详细介绍请参见[公开接口](#公开接口)的“csv转sqlite数据库接口” + +2. 执行如下命令开启转换。 +```shell +python csv2db.py +``` +完成转换,在`~/monitor_output`目录下生成`monitor_metrics.db`文件。 ### 动态启停 动态启停模式:支持用户在训练过程中随时启动/更新监控。 -用户可在训练开始前通过配置环境变量DYNAMIC_MONITOR=True来确认开启动态启停模式,该模式下需要配合config.json文件中的dynamic_on字段来使用。 +用户可在训练开始前通过配置环境变量`DYNAMIC_MONITOR=True`来确认进入动态启停模式,该模式下需要配合config.json文件中的`dynamic_on`字段来使用。 在动态启停模式下,启动和停止分别由如下控制: -- 启动: - 首次监控:config.json文件中dynamic_on字段为true,代表是否需要开启监控。 - 非首次监控:config文件时间戳更新且config.json文件中dynamic_on字段为true。 -- 停止: - 到达collect_times之后自动停止并改config.json文件中dynamic_on字段为false,可再通过上述操作重启。 +- **启动**: + - 首次监控:查看config.json文件中`dynamic_on`字段,若为`true`则在下一步开启监控。 + - 非首次监控:查看config.json文件时间戳,若时间戳更新且config.json文件中`dynamic_on`字段为`true`则在下一步开启监控。 +- **停止**: + 到达`collect_times`之后自动停止并改config.json文件中`dynamic_on`字段为`false`,可再通过上述操作重启。 + +**注意事项:**: -大部分情况下,用户可在看到异常趋势后再手动更新config.json文件并打开dynamic_on开关;此外,使用时若想要在一开始就启动监控,可直接打开dynamic_on开关做基础配置的监测(首次不要求时间戳更新) +- 默认监控启动皆统一在配置初始化或查询到更新后的下一步,即第n步挂上hook将在第n+1步启动采集,如需采集第0步数据请使用静态模式。 +- config.json中途修改出错时,若此时不在监控则不生效,若在监控则用原配置继续。 +- 达到`collect_times`之后程序会自动将该值置为`false`待下次改`true`重启。 -注意事项: +**支持的使用场景说明如下:** -- 默认监控启动皆统一在配置初始化或查询到更新后的下一步,也就是若第n步挂上hook则第n+1步才启动采集,如需采集第0步数据请用静态模式。 -- config中途修改出错时,若此时不在监控就不生效,若在监控则用原配置继续。 -- 达到collect_times之后会自动将该值置为false待下次改true重启。 +| 场景 | 监控模式 | 操作步骤 | 结果描述 | +|-----------------------------------------------|----|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------| +| 场景1: 使用默认静态模式 | 静态 | 1. 配置环境变量:`export DYNAMIC_MONITOR=False `
或不设置该环境变量 | 走默认分支进行数据采集和保存,不受config.json中`dynamic_on`影响 | +| 场景2: 进入动态启停模式,初始不启动监控 | 动态 | 1.配置环境变量:`export DYNAMIC_MONITOR=True`
2.配置config.json中`dynamic_on: false`或不设置该字段 | 初始状态下无监控,不进行数据采集和保存 | +| 场景3: 进入动态启停模式,初始即启动监控 | 动态 | 1.配置环境变量:`export DYNAMIC_MONITOR=True`
2.配置config.json中`dynamic_on: true` | 根据初始配置在第1步(初始计数为0)开启监控并保存,采集`collect_times`次数后结束监控 | +| 场景4: 进入动态启停模式,初始暂不启动监控,训练中途启动 | 动态 | 1.配置环境变量:`export DYNAMIC_MONITOR=True`
2.开始时配置config.json中`dynamic_on: false`或不设置该字段
3.训练中途修改config.json中`dynamic_on: true` | 训练中途根据最新配置在下一步开启监控并保存,采集`collect_times`次数后结束监控 | +| 场景5: 进入动态启停模式,监控还未结束时中途修改config.json采集配置 | 动态 | 1.配置环境变量:`export DYNAMIC_MONITOR=True`
2.期间配置`dynamic_on: true`启动采集
3.在采集还未达到`collect_times`次数前,中途修改config.json配置 | 更新前按旧配置采集并保存,更新后下一步以最新config.json采集且`collect_times`重新从0开始计数。此功能可配合中途`collect_times`改0来实现提前停止监控。 +| 场景6: 进入动态启停模式,在根据`collect_times`结束监控后,需重新启动监控 | 动态 | 1.配置环境变量:`export DYNAMIC_MONITOR=True`
2.期间`dynamic_on: true`启动采集
3.采集达到`collect_times`次数后结束监控,程序自动改`dynamic_on:false`
4.配置config.json中`dynamic_on:true`重启监控 | 更新前按旧配置采集并保存,中途停止监控后无采集,重启后下一步以最新config.json重启采集且`collect_times`重新从0开始计数。 ### 功能重载 此功能将在2026年废弃。请使用[动态启停](#动态启停)功能代替。 @@ -499,7 +618,25 @@ csv2tensorboard_by_step(monitor_path, time_start, time_end, process_num=1, data_ | time_start | 起始时间戳。搭配time_end一起使用。指定一个时间范围,会对这个范围内的文件进行转换。左闭右闭的区间。 | 是 | | time_end | 结束时间戳。搭配time_start一起使用。指定一个时间范围,会对这个范围内的文件进行转换。左闭右闭的区间。 | 是 | | process_num | 指定拉起的进程个数,默认为1,更多的进程个数可以加速转换。 | 否 | -| data_type_list | 指定需要转换的数据类型, 数据类型应来自输出件文件前缀,所有类型数据:
["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param"]。
不指定就转换全部数据。 | 否 | +| data_type_list | 指定需要转换的数据类型, 数据类型应来自输出件文件前缀,所有类型数据:
["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param_origin", "param_updated"]。
不指定就转换全部数据。 | 否 | +| output_dirpath | 指定转换后的输出路径,默认输出到"{curtime}_csv2tensorboard_by_step"文件夹,其中curtime为自动获取的当前时间戳。 | 否 | + +- csv转sqlite数据库接口 +```python +csv2db(config: CSV2DBConfig) -> None +``` +配置参数 (CSV2DBConfig) + +| 参数 | 说明 | 是否必选 | +| -------------- | ------------------------------------------------------------ | -------- | +| monitor_path | 待转换的csv存盘目录。 | 是 | +| time_start | 起始时间, 例如"Dec03_21-34-40"。搭配time_end一起使用,从而指定一个时间范围(闭区间),会对这个范围内的文件进行转换。默认为None不限制。 | 否 | +| time_end | 结束时间,例如"Dec03_21-34-41"。搭配time_start一起使用,从而指定一个时间范围(闭区间),会对这个范围内的文件进行转换。默认为None不限制。 | 否 | +| process_num | 指定拉起的进程个数,默认为1,更多的进程个数可以加速转换。 | 否 | +| data_type_list | 指定需要转换的数据类型, 数据类型应来自输出件文件前缀,所有类型数据:
["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param_origin", "param_updated", "other"]。
不指定就转换全部数据。 | 否 | +| step_partition | 控制数据库中按step分区的间隔,默认每500步一个表。 | 否 | +| output_dirpath | 指定转换后的输出路径,默认输出到"{curtime}_csv2db"文件夹,其中curtime为自动获取的当前时间戳。 | 否 | + - 在模型任意位置获取当前参数**梯度**统计量 ```python @@ -561,6 +698,7 @@ TrainerMon.monitor_gnorm_with_ad(model, grad_acc_steps, optimizer, dp_group, tp_ "mv_distribution": true, "param_distribution": true, "wg_distribution": true, + "monitor_mbs_grad": true, "cc_distribution": {"enable":true, "cc_codeline":[]}, "alert": { "rules": [{"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.5}}], @@ -578,33 +716,36 @@ TrainerMon.monitor_gnorm_with_ad(model, grad_acc_steps, optimizer, dp_group, tp_ 下面详细解释各个字段: -| 字段名字 | 是否必选 | 解释 | -| ----------------------- | -------- |-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| "targets" | 可选 | 指定需要监控的模型层和监控对象, 例如transformer的第0层language_model.encoder.layers.0,可选择监控input、output、input_grad、output_grad。如果不清楚模型结构, 可以将 "print_struct" 字段设置为 true, 监控工具会打印模型中torch module的名字和详细结构,并在第1个step后退出。未配置时默认为全量监控。 | -| "input" | 可选 | "tuple[2]:0"的意思是目标module的前向input参数为长度为2的tuple, 我们关心的是tuple第0个元素。 | -| "output" | 必选 | "tensor"的意思是目标module的前向output参数类型为tensor | -| "input_grad" | 可选 | "tuple[2]:0"的意思是目标module的后向input_grad参数是长度为2的tuple, 我们关心的是tuple的第0个元素。 | -| "output_grad" | 必选 | "tuple[1]:0"的意思是目标module的后向input_grad参数是长度为1的tuple, 我们关心的是tuple的第0个元素。 | -| "dynamic_on" | 可选 | 在动态启停时使用,true代表打开监控,false代表关闭监控,默认值为false,且达到collect_times之后会自动将该值置为false待下次改true重启。**仅PyTorch场景支持此参数**。 | -| "collect_times" | 可选 | 设置采集次数,达到该次数后停止监控,默认值为100000000,目的是一直采集。 | -| "start_step" | 可选 | 设置开始采集step,模型训练达到start_step后开始监控采集,默认值为0,表示从step0开始监控采集。 | -| "step_interval" | 可选 | 设置采集step间隔,默认值为1,表示每个step均采集监控数据。 | -| "print_struct" | 可选 | 设置为true后监控工具会打印模型中torch module的名字和详细结构,并在第1个step后退出。不填默认为false。**仅PyTorch场景支持此参数**。 | -| "module_ranks" | 可选 | 用于在分布式训练场景中希望控制在哪些rank开启module监控。如果不填,则默认在所有rank开启。 列表内rank要求为int类型。 | -| "ur_distribution" | 可选 | 若为true则会统计adam优化器指定模块(targets中指定)参数的update和ratio向量的数值分布,并展示在heatmap里,默认为false,同时format字段必须设置为tensorboard。
依赖histc算子, 需要CANN8.0.rc2以上版本, 否则会有严重的性能问题。**仅PyTorch场景支持此参数**。 | -| "xy_distribution" | 可选 | 若为true则会监控指定module(targets中指定)的输入输出张量。 默认为false。 | -| "all_xy" | 可选 | 开启xy_distribution后生效,若为true,监控所有module。默认为false。
与targets同时生效,all_xy配置为true时,若targets配置module_xx和指定对象,则module_xx按targets配置生效,其他module则监控全部对象,包含input、output、input_grad、output_grad。 | -| "forward_only" | 可选 | 开启xy_distribution后生效,若为true,仅监控指定module的前向,targets中的input_grad、output_grad不生效。默认为false。 | -| "backward_only" | 可选 | 开启xy_distribution后生效,若为true,仅监控指定module的反向,targets中的input、output不生效。默认为false。 | -| "mv_distribution" | 可选 | 若为true则会监控指定模块中的参数的优化器状态, 默认为false。版本依赖histc算子, 需要CANN8.0.rc2以上版本, 否则会有严重的性能问题。**仅PyTorch场景支持此参数**。 | +| "xy_distribution" | 可选 | 若为true则会监控指定module(targets中指定)的输入输出张量。 默认为false。 | +| "all_xy" | 可选 | 开启xy_distribution后生效,若为true,监控所有module。默认为false。
与targets同时生效,all_xy配置为true时,若targets配置module_xx和指定对象,则module_xx按targets配置生效,其他module则监控全部对象,包含input、output、input_grad、output_grad。 | +| "forward_only" | 可选 | 开启xy_distribution后生效,若为true,仅监控指定module的前向,targets中的input_grad、output_grad不生效。默认为false。 | +| "backward_only" | 可选 | 开启xy_distribution后生效,若为true,仅监控指定module的反向,targets中的input、output不生效。默认为false。 | +| "mv_distribution" | 可选 | 若为true则会监控指定模块中的参数的优化器状态, 默认为false。版本
可参考的实际案例:[MindSpeed&LLamaFactory数据采集和自动比对](./visualization/mindspeed_llamafactory_mapping.md) | 否 | +| -oc 或 --overflow_check | 是否开启溢出检测模式,开启后会在输出db文件中(`compare_{timestamp}.vis.db或build_{timestamp}.vis.db`)对每个溢出节点进行标记溢出等级,溢出等级说明参考[溢出等级说明](#312-溢出等级说明) | 否 | +| -f 或 --fuzzy_match | 是否开启模糊匹配,bool类型。模糊匹配说明参考[匹配说明](#311-匹配说明) | 否 | #### 3.1.1 匹配说明 @@ -87,7 +88,7 @@ msprobe -f pytorch graph -i ./compare.json -o ./output | npu_path | 指定待调试侧比对路径,str类型。工具根据路径格式自动进行单rank比对、多rank批量比对或多step批量比对,具体格式参考3.2 图构建和比对。 | 是 | | bench_path | 指定标杆侧比对路径,str类型。单图构建场景可以不配置。 | 否 | | is_print_compare_log | 配置是否开启单个算子的日志打屏。可取值 true 或 false,默认为 true。关闭后则只输出常规日志,bool 类型。 | 否 | - +| parallel_merge | 配置是否开启不同切分策略下的图合并,dict类型。rank_size、tp、pp参数按实际情况进行配置。比对时配置npu、bench,只构图配置npu。 配置示例见[3.2.5 不同切分策略下的图合并](#325-不同切分策略下的图合并)。 | 否 | ### 3.2 图构建和比对 @@ -164,7 +165,7 @@ npu_path或bench_path格式:必须包含dump.json、stack.json和construct.jso msprobe -f pytorch graph -i ./compare.json -o ./output ``` -比对完成后将在**output**下生成一个**vis后缀文件**。 +比对完成后将在**output**下生成一个**vis.db后缀文件**。 #### 3.2.3 批量构建或比对 ##### 3.2.3.1 多rank批量构建或比对 @@ -208,25 +209,15 @@ npu_path或bench_path格式:必须只包含rank+数字格式的文件夹,且 ``` msprobe -f pytorch graph -i ./compare.json -o ./output ``` -比对完成后将在**output**下生成n个**vis后缀文件**。 +比对完成后将在**output**下生成1个**vis.db后缀文件**。 图构建: ``` -├── build_rank0_{timestamp}.vis -├── build_rank1_{timestamp}.vis -├── build_rank2_{timestamp}.vis -├── build_rank3_{timestamp}.vis -├── ... -├── build_rankn_{timestamp}.vis +├── build_{timestamp}.vis.db ``` 图比对: ``` -├── compare_rank0_{timestamp}.vis -├── compare_rank1_{timestamp}.vis -├── compare_rank2_{timestamp}.vis -├── compare_rank3_{timestamp}.vis -├── ... -├── compare_rankn_{timestamp}.vis +├── compare_{timestamp}.vis.db ``` ##### 3.2.3.2 多step批量构建或比对 批量构建或比对多个step下的所有rank的数据 @@ -273,33 +264,15 @@ npu_path或bench_path格式:必须只包含step+数字格式的文件夹,且 ``` msprobe -f pytorch graph -i ./compare.json -o ./output ``` -比对完成后将在**output**下生成若干个**vis后缀文件**。 +比对完成后将在**output**下生成1个**vis.db后缀文件**。 图构建: ``` -├── build_step0_rank0_{timestamp}.vis -├── build_step0_rank1_{timestamp}.vis -├── build_step0_rank2_{timestamp}.vis -├── build_step0_rank3_{timestamp}.vis -├── build_step1_rank0_{timestamp}.vis -├── build_step1_rank1_{timestamp}.vis -├── build_step1_rank2_{timestamp}.vis -├── build_step1_rank3_{timestamp}.vis -├── ... -├── build_stepn_rankn_{timestamp}.vis +├── build_{timestamp}.vis.db ``` 图比对: ``` -├── compare_step0_rank0_{timestamp}.vis -├── compare_step0_rank1_{timestamp}.vis -├── compare_step0_rank2_{timestamp}.vis -├── compare_step0_rank3_{timestamp}.vis -├── compare_step1_rank0_{timestamp}.vis -├── compare_step1_rank1_{timestamp}.vis -├── compare_step1_rank2_{timestamp}.vis -├── compare_step1_rank3_{timestamp}.vis -├── ... -├── compare_stepn_rankn_{timestamp}.vis +├── compare_{timestamp}.vis.db ``` #### 3.2.4 仅模型结构比对 @@ -312,6 +285,49 @@ dump配置请参考[dump配置示例](./03.config_examples.md#16-task-配置为- 得到dump数据后,若需比较特定两个rank之间的数据,请参考[3.2.2 双图比对](#322-双图比对);若需进行多个rank或多个step的数据批量比对,请参考[3.2.3 批量构建或比对](#323-批量构建或比对)。 +#### 3.2.5 不同切分策略下的图合并 + +适用场景:不同模型并行切分策略下,两个模型产生了精度差异,需要进行整网数据比对,但被切分的数据或模型结构分布于多rank中无法进行比对,需要将分布在各个rank的数据或模型结构合并后再进行比对。 + +使用限制: + +- 当前支持的模型并行切分策略:Tensor Parallelism(TP)、Pipeline Parallelism(PP)、Virtual Pipeline Parallelism(VPP),暂不支持Context Parallelism(CP)和Expert Parallelism(EP)。 +- 当前支持基于Megatron、MindSpeed-LLM套件的模型进行图合并,其他套件的模型图合并效果有待验证; +- 当前仅支持msprobe工具dump的statistics数据, level需指定L0或者mix; +- 图合并比对时要确保Data Parallelism(DP)切分一致,例如rank=8 tp=1 pp=8的配置,dp=1,图合并将得到一张图,rank=8 tp=1 pp=4的配置,dp=2,图合并将得到两张图,暂不支持数量不一致的图进行比对。 + +使能方式: + +在compare.json里增加parallel_merge配置项, rank_size、tp、pp和vpp参数按实际情况进行配置。 + +参数说明: + +所需tp、pp和vpp参数来自于Megatron、MindSpeed-LLM套件中的训练脚本实际配置。 + +| 参数名 | 说明 | 是否必填 | +|-----------|--------------------------------------------------------------------------------------------------------------------------|------| +| rank_size | 模型实际训练所用加速卡的数量,int类型。`rank_size=tp*pp*cp*dp`,由于暂不支持CP合并,图合并功能中默认cp=1。 | 是 | +| tp | 张量并行大小,int类型。实际训练脚本中需指定`--tensor-model-parallel-size T`,其中`T`表示张量模型并行大小,即**图合并所需的参数tp**, `tp=T`。 | 是 | +| pp | 流水线并行的阶段数,int类型。实际训练脚本中需指定`--pipeline-model-parallel-size P`,其中`P`表示流水线并行的阶段数,即**图合并所需的参数pp**, `pp=P`。 | 是 | +| vpp | 虚拟流水线并行阶段数,int类型。虚拟流水线并行依赖流水线并行,实际训练脚本中需指定`--num-layers-per-virtual-pipeline-stage V`,其中`V`表示每个虚拟流水线阶段的层数;指定`--num-layers L`,其中`L`表示模型总层数,**图合并所需的参数vpp**=`L/V/P`。vpp参数可以不配置,默认vpp=1代表未开启虚拟流水线并行。 | 否 | +| order | 模型并行维度的排序顺序,str类型。Megatron默认为`tp-cp-ep-dp-pp`。 如果使用msprobe工具dump数据指定level为L0并且实际训练脚本中的order非默认值(例如实际训练脚本中指定`--use-tp-pp-dp-mapping`),请传入修改后的order。dump数据指定level为mix则无需修改。 | 否 | + +npu_path、bench_path的配置以及执行命令请参考[3.2.3 批量构建或比对](#323-批量构建或比对) + +如果只进行图构建,"bench_path"和"parallel_merge"中的"bench"参数可不配置。 + +``` +{ + "npu_path": "./npu_dump", + "bench_path": "./bench_dump", + "is_print_compare_log": true, + "parallel_merge": { + "npu": {"rank_size": 8, "tp": 8, "pp": 1}, + "bench": {"rank_size": 8, "tp": 1, "pp": 8} + } +} +``` + ## 4.启动tensorboard ### 4.1 可直连的服务器 @@ -330,8 +346,22 @@ ubuntu是机器地址,6008是端口号。 **注意,ubuntu需要替换为真实的服务器地址,例如真实的服务器地址为10.123.456.78,则需要在浏览器窗口输入http://10.123.456.78:6008** ### 4.2 不可直连的服务器 -**如果链接打不开(服务器无法直连需要挂vpn才能连接等场景),可以尝试使用vscode连接服务器,在vscode终端输入:** +**如果链接打不开(服务器无法直连需要挂vpn才能连接等场景),可以尝试以下方法,选择其一即可:** +1.本地电脑网络手动设置代理,例如Windows10系统,在【手动设置代理】中添加服务器地址(例如10.123.456.78) + +![proxy](./img/visualization/proxy.png) + +然后,在服务器中输入: +``` +tensorboard --logdir out_path --bind_all --port 6008[可选,端口号] +``` + +最后,在浏览器窗口输入http://10.123.456.78:6008 + +**注意,如果当前服务器开启了防火墙,则此方法无效,需要关闭防火墙,或者尝试后续方法** + +2.或者使用vscode连接服务器,在vscode终端输入: ``` tensorboard --logdir out_path ``` @@ -339,17 +369,30 @@ tensorboard --logdir out_path 按住CTRL点击链接即可 +3.或者将构图结果件vis文件从服务器传输至本地电脑,在本地电脑中安装tb_graph_ascend插件查看构图结果 + +电脑终端输入: +``` +tensorboard --logdir out_path +``` +按住CTRL点击链接即可 + ## 5.浏览器查看 ### 5.1 浏览器打开图 推荐使用谷歌浏览器,在浏览器中输入机器地址+端口号回车,出现TensorBoard页面,其中/#graph_ascend会自动拼接。 + ![vis_browser_1](./img/visualization/vis_browser_1.png) + 如果您切换了TensorBoard的其他功能,此时想回到模型分级可视化页面,可以点击左上方的**GRAPH_ASCEND** + ![vis_browser_2](./img/visualization/vis_browser_2.png) ### 5.2 查看图 ![vis_show_info.png](./img/visualization/vis_show_info.png) +MicroStep是指在一次完整的权重更新前执行的多次前向和反向传播过程,一次完整的训练迭代(step)可以进一步细分为多个更小的步骤(micro step)。其中分级可视化工具通过识别模型首层结构中一次完整的前反向作为一次micro step。 + ### 5.3 名称搜索 ![vis_search_info.png](./img/visualization/vis_search_info.png) @@ -357,37 +400,69 @@ tensorboard --logdir out_path ![vis_precision_info.png](./img/visualization/vis_precision_info.png) ### 5.5 未匹配节点筛选 -节点匹配规则: -1.名称一致 +参考[匹配说明](#311-匹配说明) ,不符合匹配规则的节点为无匹配节点,颜色标灰。适用于排查两个模型结构差异的场景。 -2.节点输入输出参数数量一致,参数type、shape一致 +![vis_unmatch_info.png](./img/visualization/vis_unmatch_info.png) -3.节点的层级一致(父节点们一致) +### 5.6 手动选择节点匹配 -![vis_unmatch_info.png](./img/visualization/vis_unmatch_info.png) +可通过浏览器界面,通过鼠标选择两个待匹配的灰色节点进行匹配。当前暂不支持真实数据模式。 + +![vis_match_info.png](./img/visualization/vis_match_info.png) ## 6.图比对说明 -### 颜色 +### 6.1 颜色 颜色越深,精度比对差异越大,越可疑,具体信息可见浏览器页面左下角颜色图例。 -### 疑似有精度问题判定 - -#### 真实数据模式 -节点中所有输入的最小双千指标和所有输出的最小双千分之一指标的差值,反映了双千指标的下降情况,**值越大精度差距越大,颜色标记越深**。 +#### 6.1.1 真实数据模式 +节点中所有输入的最小双千指标和所有输出的最小双千分之一指标的差值,反映了双千指标的下降情况,**该数值越大,表明两组模型的精度差异越大,在图中标注的对应颜色会更深**。 ``One Thousandth Err Ratio(双千分之一)精度指标:Tensor中的元素逐个与对应的标杆数据对比,相对误差小于千分之一的比例占总元素个数的比例,比例越接近1越好`` -#### 统计信息模式 -节点中输出的统计量相对误差,**值越大精度差距越大,颜色标记越深**。 +如果调试侧(NPU)节点的output指标中的最大值(MAX)或最小值(MIN)中存在 nan/inf/-inf,直接标记为最深颜色。 + +#### 6.1.2 统计信息模式 +节点中输出的统计量相对误差,**该数值越大,表明两组模型的精度差异越大,在图中标注的对应颜色会更深**。 ``相对误差:abs((npu统计值 - bench统计值) / bench统计值)`` -#### md5模式 +如果调试侧(NPU)节点的output指标中的最大值(MAX)或最小值(MIN)中存在 nan/inf/-inf,直接标记为最深颜色。 + +#### 6.1.3 md5模式 节点中任意输入输出的md5值不同。 +### 6.2 指标说明 + +精度比对从三个层面评估 API 的精度,依次是:真实数据模式、统计数据模式和 MD5 模式。比对结果分别有不同的指标。 + +**公共指标**: +- name: 参数名称,例如input.0 +- type: 类型,例如torch.Tensor +- dtype: 数据类型,例如torch.float32 +- shape: 张量形状,例如[32, 1, 32] +- Max: 最大值 +- Min: 最小值 +- Mean: 平均值 +- Norm: L2-范数 + +**真实数据模式指标**: +- Cosine: tensor 余弦相似度 +- EucDist: tensor 欧式距离 +- MaxAbsErr: tensor 最大绝对误差 +- MaxRelativeErr: tensor 最大相对误差 +- One Thousandth Err Ratio: tensor 相对误差小于千分之一的比例(双千分之一) +- Five Thousandth Err Ratio: tensor 相对误差小于千分之五的比例(双千分之五) + +**统计数据模式指标** +- (Max, Min, Mean, Norm) diff: 统计量绝对误差 +- (Max, Min, Mean, Norm) RelativeErr: 统计量相对误差 + +**MD5模式指标** +- md5: CRC-32 值 + ## 7.附录 ### 7.1 自定义映射文件(Layer) @@ -430,47 +505,15 @@ yaml文件中只需配置待调试侧与标杆侧模型代码中功能一致但 ![ms_dump](./img/ms_layer.png) -### 7.2 堆栈信息说明 - -**精简堆栈** - -保留一条当前模块或api的调用信息 - -```json -{ - "Module.layer1.0.bn1.BatchNorm2d.forward.0": [ - "File /home/torchvision/models/resnet.py, line 93, in forward, \n out = self.bn1(out)" - ] -} -``` - -**完整堆栈** - -当前模块或api完整的调用信息 - -```json -{ - "Module.layer1.0.bn1.BatchNorm2d.forward.0": [ - "File /home/torchvision/models/resnet.py, line 93, in forward, \n out = self.bn1(out)", - "File /home/torch/nn/modules/module.py, line 1568, in _call_impl, \n result = forward_call(*args, **kwargs)", - "File /home/torch/nn/modules/module.py, line 1518, in _wrapped_call_impl, \n return self._call_impl(*args, **kwargs)", - "File /home/torch/nn/modules/container.py, line 215, in forward, \n input = module(input)", - "File /home/torch/nn/modules/module.py, line 1568, in _call_impl, \n result = forward_call(*args, **kwargs)", - "File /home/torch/nn/modules/module.py, line 1518, in _wrapped_call_impl, \n return self._call_impl(*args, **kwargs)", - "File /home/torchvision/models/resnet.py, line 273, in _forward_impl, \n x = self.layer1(x)", - "File /home/torchvision/models/resnet.py, line 285, in forward, \n return self._forward_impl(x)", - "File /home/torch/nn/modules/module.py, line 1527, in _call_impl, \n return forward_call(*args, **kwargs)", - "File /home/torch/nn/modules/module.py, line 1518, in _wrapped_call_impl, \n return self._call_impl(*args, **kwargs)", - "File /home/visualization/resnet18.py, line 40, in , \n outputs = model(inputs)" - ] -} - -``` # FAQ 1. 图比对场景,节点呈现灰色,且没有精度比对数据,怎么处理? 节点呈现灰色,代表左边待调试侧节点与右边标杆侧节点没有匹配上,可能有以下几点原因: - **标杆侧确实没有能与待调试侧匹配上的节点**,属于代码实现上的差异,请确认此差异是否正常,是否会影响到整网精度。 -- **节点的输入或输出type、shape不一致,参数个数不一致,节点所在层级的父层级不一致**,导致节点无法匹配,具体匹配规则见[匹配说明](#311-匹配说明),可尝试使用模糊匹配功能,如何使用此功能请参考[构图命令行说明](#31-构图命令行说明)。如果是参数shape不一致,即使是模糊匹配功能也无法让节点匹配上,请检查参数shape不一致是否合理。 -- **节点名称不一致**,导致节点无法匹配,可使用layer mapping功能,如何使用此功能请参考[构图命令行说明](#31-构图命令行说明),如何自定义映射文件请参考[模型分级可视化如何配置layer mapping映射文件](./visualization/layer_mapping_example.md)。 +- **节点名称一致,但节点的输入或输出type、shape不一致,参数个数不一致,节点所在层级的父层级不一致,导致节点无法匹配** + - 具体匹配规则见[匹配说明](#311-匹配说明),可尝试使用模糊匹配功能,如何使用此功能请参考[构图命令行说明](#31-构图命令行说明); + - 如果是参数shape不一致,即使是模糊匹配功能也无法让节点匹配上,请检查参数shape不一致是否合理。 +- **节点名称不一致**,导致节点无法匹配,目前提供两种方法,选其一即可 + - 可使用layer mapping功能,如何使用此功能请参考[构图命令行说明](#31-构图命令行说明),如何自定义映射文件请参考[模型分级可视化如何配置layer mapping映射文件](./visualization/layer_mapping_example.md); + - 可通过浏览器页面手动选择未匹配节点进行匹配,请参考[手动选择节点匹配](#56-手动选择节点匹配)。 diff --git a/debug/accuracy_tools/msprobe/docs/22.visualization_MindSpore.md b/debug/accuracy_tools/msprobe/docs/22.visualization_MindSpore.md index 12306b8be027e7cee715f99f75b00f7504ba8252..4efd0532a25a58002d4012c0e712207a533018e6 100644 --- a/debug/accuracy_tools/msprobe/docs/22.visualization_MindSpore.md +++ b/debug/accuracy_tools/msprobe/docs/22.visualization_MindSpore.md @@ -2,55 +2,56 @@ 分级可视化工具将msprobe工具dump的精度数据进行解析,还原模型图结构,实现模型各个层级的精度数据比对,方便用户理解模型结构、分析精度问题。 -工具支持MindSpore版本:2.4.0 +工具支持MindSpore版本:>=2.4.0 -## 展示示例 +## 工具特性 -支持重建模型的层级结构; - -支持两个模型的结构差异比对; - -支持两个模型的精度数据比对,支持疑似有精度问题节点的快速搜索,自动跳转展开节点所在的层级。 +- 支持重建模型的层级结构; +- 支持两个模型的结构差异比对; +- 支持两个模型的精度数据比对; +- 支持模型数据的溢出检测; +- 支持多卡场景的批量构图,能够关联各卡的通信节点,分析各卡之间的数据传递; +- 支持节点名称搜索,按精度比对结果筛选节点,按溢出检测结果筛选节点,支持自动跳转展开节点所在的层级; +- 支持跨套件、跨框架的模型比对。 +- 支持不同切分策略下两个模型的精度数据比对:[不同切分策略下的图合并](#325-不同切分策略下的图合并)。 ![vis_show](./img/visualization/vis_showcase.png) ## 1.依赖安装 -分级可视化工具依赖**msprobe工具**和**tensorboard。** - ### 1.1 安装msprobe工具 -[msprobe工具安装](https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/docs/01.installation.md) +[msprobe工具安装](./01.installation.md) ### 1.2 安装tb_graph_ascend **请安装tb_graph_ascend,否则无法解析构图结果。** -``pip3 install tb-graph-ascend``即可。 +[tb_graph_ascend安装](../../../../plugins/tensorboard-plugins/tb_graph_ascend#2-安装方式) ## 2.模型结构数据采集 -[MindSpore场景的精度数据采集](https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md) +[PyTorch场景的数据采集](./06.data_dump_MindSpore.md) **仅支持动态图场景,需要选择level为L0(cell信息)或者mix(cell信息+api信息),才能采集到模型结构数据,即采集结果件construct.json内容不为空**。 ## 3.生成图结构文件 ### 3.1 构图命令行说明 - + **命令示例如下**: ``` msprobe -f mindspore graph -i ./compare.json -o ./output ``` **命令行参数说明**: -| 参数名 | 说明 | 是否必选 | -|-------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -------- | -| -i 或 --input_path | 指定比对文件,参考[比对文件说明](#313-比对文件说明) | 是 | -| -o 或 --output_path | 配置比对结果文件存盘目录,str 类型。文件名称基于时间戳自动生成,格式为:`compare_{timestamp}.vis或build_{timestamp}.vis`。 | 是 | -| -lm 或 --layer_mapping| 跨框架比对,MindSpore和PyTorch的比对场景。配置该参数时表示开启跨框架Layer层的比对功能,指定模型代码中的Layer层后,可以识别对应dump数据中的模块或API。需要指定自定义映射文件*.yaml。自定义映射文件的格式请参见[自定义映射文件(Layer)](#71-自定义映射文件layer), 如何配置自定义映射文件请参考[模型分级可视化如何配置layer mapping映射文件](./visualization/layer_mapping_example.md)。 | 否 | -| -oc 或 --overflow_check | 是否开启溢出检测模式,开启后会在输出vis文件中(`compare_{timestamp}.vis或build_{timestamp}.vis`)对每个溢出节点进行标记溢出等级,溢出等级说明参考[溢出等级说明](#312-溢出等级说明) | 否 | -| -f 或 --fuzzy_match | 是否开启模糊匹配,bool类型。模糊匹配说明参考[匹配说明](#311-匹配说明) | 否 | -| -cs 或 --complete_stack | 是否使用完整的堆栈信息,bool类型。默认使用精简的堆栈信息,数据量小有助于增加流畅度。完整堆栈和精简堆栈信息参考[堆栈信息说明](#72-堆栈信息说明) | 否 | +| 参数名 | 说明 | 是否必选 | +|-------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -------- | +| -f 或 --framework | 指定训练框架。mindspore。 | 是 | +| -i 或 --input_path | 指定比对文件,参考[比对文件说明](#313-比对文件说明) | 是 | +| -o 或 --output_path | 配置比对结果文件存盘目录,str 类型。文件名称基于时间戳自动生成,格式为:`compare_{timestamp}.vis.db或build_{timestamp}.vis.db`。 | 是 | +| -lm 或 --layer_mapping| 跨框架比对,MindSpore和PyTorch的比对场景。配置该参数时表示开启跨框架Layer层的比对功能,指定模型代码中的Layer层后,可以识别对应dump数据中的模块或API。需要指定自定义映射文件*.yaml。自定义映射文件的格式请参见[自定义映射文件(Layer)](#71-自定义映射文件layer), 如何配置自定义映射文件请参考[模型分级可视化如何配置layer mapping映射文件](./visualization/layer_mapping_example.md)。配置该参数后,将仅按节点名称进行比对,忽略节点的 type 和 shape。如果调试侧和标杆侧有名称不同的节点,则需要配置自定义映射文件,-lm参数传入自定义映射文件路径;如果调试侧和标杆侧节点名称相同,则仅指定-lm即可。 | 否 | +| -oc 或 --overflow_check | 是否开启溢出检测模式,开启后会在输出db文件中(`compare_{timestamp}.vis.db或build_{timestamp}.vis.db`)对每个溢出节点进行标记溢出等级,溢出等级说明参考[溢出等级说明](#312-溢出等级说明) | 否 | +| -f 或 --fuzzy_match | 是否开启模糊匹配,bool类型。模糊匹配说明参考[匹配说明](#311-匹配说明) | 否 | #### 3.1.1 匹配说明 @@ -62,7 +63,7 @@ msprobe -f mindspore graph -i ./compare.json -o ./output - 节点的层级一致(父节点们一致) 2.模糊匹配 -- Cell节点dump名称一致,两个匹配上的Cell节点, 忽略各自节点下所有api的dump调用次数,按照名称一致+Cell节点内的调用顺序进行匹配 +- Cell节点dump名称一致,两个匹配上的Cell节点,忽略各自节点下所有api的dump调用次数,按照名称一致+Cell节点内的调用顺序进行匹配 - ![fuzzy_match_ms.png](./img/visualization/fuzzy_match_ms.png) - 参数shape一致 @@ -83,12 +84,12 @@ msprobe -f mindspore graph -i ./compare.json -o ./output ``` **比对文件参数说明**: -| 参数名 | 说明 | 是否必选 | -|-------------------|-------------------------------------------------------------------------------------------------------|------| -| npu_path | 指定待调试侧比对路径,str类型。工具根据路径格式自动进行单rank比对、多rank批量比对或多step批量比对,具体格式参考3.2 图构建和比对。 | 是 | -| bench_path | 指定标杆侧比对路径,str类型。单图构建场景可以不配置 | 否 | -| is_print_compare_log | 配置是否开启单个算子的日志打屏。可取值 true 或 false,默认为 true。关闭后则只输出常规日志,bool 类型。 | 否 | - +| 参数名 | 说明 | 是否必选 | +|-------------------|----------------------------------------------------------------------------|------| +| npu_path | 指定待调试侧比对路径,str类型。工具根据路径格式自动进行单rank比对、多rank批量比对或多step批量比对,具体格式参考3.2 图构建和比对。 | 是 | +| bench_path | 指定标杆侧比对路径,str类型。单图构建场景可以不配置。 | 否 | +| is_print_compare_log | 配置是否开启单个算子的日志打屏。可取值 true 或 false,默认为 true。关闭后则只输出常规日志,bool 类型。 | 否 | +| parallel_merge | 配置是否开启不同切分策略下的图合并,dict类型。rank_size、tp、pp参数按实际情况进行配置。比对时配置npu、bench,只构图配置npu。 配置示例见[3.2.5 不同切分策略下的图合并](#325-不同切分策略下的图合并)。 | 否 | ### 3.2 图构建和比对 @@ -209,25 +210,15 @@ npu_path或bench_path格式:必须只包含rank+数字格式的文件夹,且 ``` msprobe -f mindspore graph -i ./compare.json -o ./output ``` -比对完成后将在**output**下生成n个**vis后缀文件**。 +比对完成后将在**output**下生成1个**vis.db后缀文件**。 图构建: ``` -├── build_rank0_{timestamp}.vis -├── build_rank1_{timestamp}.vis -├── build_rank2_{timestamp}.vis -├── build_rank3_{timestamp}.vis -├── ... -├── build_rankn_{timestamp}.vis +├── build_{timestamp}.vis.db ``` 图比对: ``` -├── compare_rank0_{timestamp}.vis -├── compare_rank1_{timestamp}.vis -├── compare_rank2_{timestamp}.vis -├── compare_rank3_{timestamp}.vis -├── ... -├── compare_rankn_{timestamp}.vis +├── compare_{timestamp}.vis.db ``` ##### 3.2.3.2 多step批量构建或比对 批量构建或比对多个step下的所有rank的数据 @@ -274,33 +265,15 @@ npu_path或bench_path格式:必须只包含step+数字格式的文件夹,且 ``` msprobe -f mindspore graph -i ./compare.json -o ./output ``` -比对完成后将在**output**下生成若干个**vis后缀文件**。 +比对完成后将在**output**下生成1个**vis.db后缀文件**。 图构建: ``` -├── build_step0_rank0_{timestamp}.vis -├── build_step0_rank1_{timestamp}.vis -├── build_step0_rank2_{timestamp}.vis -├── build_step0_rank3_{timestamp}.vis -├── build_step1_rank0_{timestamp}.vis -├── build_step1_rank1_{timestamp}.vis -├── build_step1_rank2_{timestamp}.vis -├── build_step1_rank3_{timestamp}.vis -├── ... -├── build_stepn_rankn_{timestamp}.vis +├── build_{timestamp}.vis.db ``` 图比对: ``` -├── compare_step0_rank0_{timestamp}.vis -├── compare_step0_rank1_{timestamp}.vis -├── compare_step0_rank2_{timestamp}.vis -├── compare_step0_rank3_{timestamp}.vis -├── compare_step1_rank0_{timestamp}.vis -├── compare_step1_rank1_{timestamp}.vis -├── compare_step1_rank2_{timestamp}.vis -├── compare_step1_rank3_{timestamp}.vis -├── ... -├── compare_stepn_rankn_{timestamp}.vis +├── compare_{timestamp}.vis.db ``` #### 3.2.4 仅模型结构比对 @@ -313,8 +286,50 @@ dump配置请参考[dump配置示例](./03.config_examples.md#35-task-配置为- 得到dump数据后,若需比较特定两个rank之间的数据,请参考[3.2.2 双图比对](#322-双图比对);若需进行多个rank或多个step的数据批量比对,请参考[3.2.3 批量构建或比对](#323-批量构建或比对)。 +#### 3.2.5 不同切分策略下的图合并 + +适用场景:不同模型并行切分策略下,两个模型产生了精度差异,需要进行整网数据比对,但被切分的数据或模型结构分布于多rank中无法进行比对,需要将分布在各个rank的数据或模型结构合并后再进行比对。 + +使用限制: + +- 当前支持的模型并行切分策略:Tensor Parallelism(TP)、Pipeline Parallelism(PP)、Virtual Pipeline Parallelism(VPP),暂不支持Context Parallelism(CP)和Expert Parallelism(EP)。 +- 当前支持基于Megatron、MindSpeed-LLM套件的模型进行图合并,其他套件的模型图合并效果有待验证; +- 当前仅支持msprobe工具dump的statistics数据, level需指定L0或者mix; +- 图合并比对时要确保Data Parallelism(DP)切分一致,例如rank=8 tp=1 pp=8的配置,dp=1,图合并将得到一张图,rank=8 tp=1 pp=4的配置,dp=2,图合并将得到两张图,暂不支持数量不一致的图进行比对。 -## 4.启动tensorboard +使能方式: + +在compare.json里增加parallel_merge配置项, rank_size、tp、pp和vpp参数按实际情况进行配置。 + +参数说明: + +所需tp、pp和vpp参数来自于Megatron、MindSpeed-LLM套件中的训练脚本实际配置。 + +| 参数名 | 说明 | 是否必填 | +|-----------|--------------------------------------------------------------------------------------------------------------------------|------| +| rank_size | 模型实际训练所用加速卡的数量,int类型。`rank_size=tp*pp*cp*dp`,由于暂不支持CP合并,图合并功能中默认cp=1。 | 是 | +| tp | 张量并行大小,int类型。实际训练脚本中需指定`--tensor-model-parallel-size T`,其中`T`表示张量模型并行大小,即**图合并所需的参数tp**, `tp=T`。 | 是 | +| pp | 流水线并行的阶段数,int类型。实际训练脚本中需指定`--pipeline-model-parallel-size P`,其中`P`表示流水线并行的阶段数,即**图合并所需的参数pp**, `pp=P`。 | 是 | +| vpp | 虚拟流水线并行阶段数,int类型。虚拟流水线并行依赖流水线并行,实际训练脚本中需指定`--num-layers-per-virtual-pipeline-stage V`,其中`V`表示每个虚拟流水线阶段的层数;指定`--num-layers L`,其中`L`表示模型总层数,**图合并所需的参数vpp**=`L/V/P`。vpp参数可以不配置,默认vpp=1代表未开启虚拟流水线并行。 | 否 | +| order | 模型并行维度的排序顺序,str类型。Megatron默认为`tp-cp-ep-dp-pp`。 如果使用msprobe工具dump数据指定level为L0并且实际训练脚本中的order非默认值(例如实际训练脚本中指定`--use-tp-pp-dp-mapping`),请传入修改后的order。dump数据指定level为mix则无需修改。 | 否 | + +npu_path、bench_path的配置以及执行命令请参考[3.2.3 批量构建或比对](#323-批量构建或比对) + +如果只进行图构建,"bench_path"和"parallel_merge"中的"bench"参数可不配置。 + +``` +{ + "npu_path": "./npu_dump", + "bench_path": "./bench_dump", + "is_print_compare_log": true, + "parallel_merge": { + "npu": {"rank_size": 8, "tp": 8, "pp": 1}, + "bench": {"rank_size": 8, "tp": 1, "pp": 8} + } +} +``` + +## 4.启动TensorBoard ### 4.1 可直连的服务器 @@ -329,11 +344,25 @@ tensorboard --logdir out_path --bind_all --port [可选,端口号] ubuntu是机器地址,6008是端口号。 -**注意,ubuntu需要替换为真实的服务器地址,例如真实的服务器地址为10.123.456.78,则需要在浏览器窗口输入http://10.123.456.78:6008** +**注意,ubuntu需要替换为真实的服务器地址,例如真实的服务器地址为10.123.456.78,则需要在浏览器窗口输入 http://10.123.456.78:6008** ### 4.2 不可直连的服务器 -**如果链接打不开(服务器无法直连需要挂vpn才能连接等场景),可以尝试使用vscode连接服务器,在vscode终端输入:** +**如果链接打不开(服务器无法直连需要挂vpn才能连接等场景),可以尝试以下方法,选择其一即可:** +1.本地电脑网络手动设置代理,例如Windows10系统,在【手动设置代理】中添加服务器地址(例如10.123.456.78) + +![proxy](./img/visualization/proxy.png) + +然后,在服务器中输入: +``` +tensorboard --logdir out_path --bind_all --port 6008[可选,端口号] +``` + +最后,在浏览器窗口输入 http://10.123.456.78:6008 + +**注意,如果当前服务器开启了防火墙,则此方法无效,需要关闭防火墙,或者尝试后续方法** + +2.或者使用vscode连接服务器,在vscode终端输入: ``` tensorboard --logdir out_path ``` @@ -341,17 +370,30 @@ tensorboard --logdir out_path 按住CTRL点击链接即可 +3.或者将构图结果件vis文件从服务器传输至本地电脑,在本地电脑中安装tb_graph_ascend插件查看构图结果 + +电脑终端输入: +``` +tensorboard --logdir out_path +``` +按住CTRL点击链接即可 + ## 5.浏览器查看 ### 5.1 浏览器打开图 推荐使用谷歌浏览器,在浏览器中输入机器地址+端口号回车,出现TensorBoard页面,其中/#graph_ascend会自动拼接。 + ![vis_browser_1](./img/visualization/vis_browser_1.png) + 如果您切换了TensorBoard的其他功能,此时想回到模型分级可视化页面,可以点击左上方的**GRAPH_ASCEND** + ![vis_browser_2](./img/visualization/vis_browser_2.png) ### 5.2 查看图 ![vis_show_info.png](./img/visualization/vis_show_info.png) +MicroStep是指在一次完整的权重更新前执行的多次前向和反向传播过程,一次完整的训练迭代(step)可以进一步细分为多个更小的步骤(micro step)。其中分级可视化工具通过识别模型首层结构中一次完整的前反向作为一次micro step。 + ### 5.3 名称搜索 ![vis_search_info.png](./img/visualization/vis_search_info.png) @@ -359,37 +401,69 @@ tensorboard --logdir out_path ![vis_precision_info.png](./img/visualization/vis_precision_info.png) ### 5.5 未匹配节点筛选 -节点匹配规则: -1.名称一致 +参考[匹配说明](#311-匹配说明) ,不符合匹配规则的节点为无匹配节点,颜色标灰。适用于排查两个模型结构差异的场景。 -2.节点输入输出参数数量一致,参数type、shape一致 +![vis_unmatch_info.png](./img/visualization/vis_unmatch_info.png) -3.节点的层级一致(父节点们一致) +### 5.6 手动选择节点匹配 -![vis_unmatch_info.png](./img/visualization/vis_unmatch_info.png) +可通过浏览器界面,通过鼠标选择两个待匹配的灰色节点进行匹配。当前暂不支持真实数据模式。 + +![vis_match_info.png](./img/visualization/vis_match_info.png) ## 6.图比对说明 -### 颜色 +### 6.1 颜色 颜色越深,精度比对差异越大,越可疑,具体信息可见浏览器页面左下角颜色图例。 -### 疑似有精度问题判定 - -#### 真实数据模式 -节点中所有输入的最小双千指标和所有输出的最小双千分之一指标的差值,反映了双千指标的下降情况,**值越大精度差距越大,颜色标记越深**。 +#### 6.1.1 真实数据模式 +节点中所有输入的最小双千指标和所有输出的最小双千分之一指标的差值,反映了双千指标的下降情况,**该数值越大,表明两组模型的精度差异越大,在图中标注的对应颜色会更深**。 ``One Thousandth Err Ratio(双千分之一)精度指标:Tensor中的元素逐个与对应的标杆数据对比,相对误差小于千分之一的比例占总元素个数的比例,比例越接近1越好`` -#### 统计信息模式 -节点中输出的统计量相对误差,**值越大精度差距越大,颜色标记越深**。 +如果调试侧(NPU)节点的output指标中的最大值(MAX)或最小值(MIN)中存在 nan/inf/-inf,直接标记为最深颜色。 + +#### 6.1.2 统计信息模式 +节点中输出的统计量相对误差,**该数值越大,表明两组模型的精度差异越大,在图中标注的对应颜色会更深**。 + +``相对误差:abs((npu统计值 - bench统计值) / bench统计值)`` -``相对误差:abs((npu统计值 - bench统计值) / bench统计值)`` +如果调试侧(NPU)节点的output指标中的最大值(MAX)或最小值(MIN)中存在 nan/inf/-inf,直接标记为最深颜色。 -#### md5模式 +#### 6.1.3 md5模式 节点中任意输入输出的md5值不同。 +### 6.2 指标说明 + +精度比对从三个层面评估 API 的精度,依次是:真实数据模式、统计数据模式和 MD5 模式。比对结果分别有不同的指标。 + +**公共指标**: +- name: 参数名称,例如input.0 +- type: 类型,例如mindspore.Tensor +- dtype: 数据类型,例如BFloat32 +- shape: 张量形状,例如[32, 1, 32] +- Max: 最大值 +- Min: 最小值 +- Mean: 平均值 +- Norm: L2-范数 + +**真实数据模式指标**: +- Cosine: tensor 余弦相似度 +- EucDist: tensor 欧式距离 +- MaxAbsErr: tensor 最大绝对误差 +- MaxRelativeErr: tensor 最大相对误差 +- One Thousandth Err Ratio: tensor 相对误差小于千分之一的比例(双千分之一) +- Five Thousandth Err Ratio: tensor 相对误差小于千分之五的比例(双千分之五) + +**统计数据模式指标** +- (Max, Min, Mean, Norm) diff: 统计量绝对误差 +- (Max, Min, Mean, Norm) RelativeErr: 统计量相对误差 + +**MD5模式指标** +- md5: CRC-32 值 + ## 7.附录 ### 7.1 自定义映射文件(Layer) @@ -427,66 +501,15 @@ yaml文件中只需配置MindSpore与PyTorch模型代码中功能一致但名称 ![ms_dump](./img/ms_layer.png) -### 7.2 堆栈信息说明 - -**精简堆栈** - -保留一条当前模块或api的调用信息 - -```json -{ - "Cell.model.language_model.embedding.word_embeddings.reduce_scatter_to_sp_region.ReduceScatterToSequenceParallelRegion.forward.0": [ - "File /home/mindformers/experimental/distri_cores/tensor_parallel/layers.py, line 770, in construct, \n output = self.reduce_scatter_to_sp_region(output_parallel)" - ] -} -``` - -**完整堆栈** - -当前模块或api完整的调用信息 - -```json -{ - "Cell.model.language_model.embedding.word_embeddings.reduce_scatter_to_sp_region.ReduceScatterToSequenceParallelRegion.forward.0": [ - "File /home/mindspore/nn/cell.py, line 507, in _run_construct, \n output = self._run_forward_hook(inputs, output)", - "File /home/mindspore/nn/cell.py, line 759, in _complex_call, \n output = self._run_construct(*args, **kwargs)", - "File /home/mindspore/nn/cell.py, line 747, in __call__, \n return self._complex_call(*args, **kwargs)", - "File /home/mindformers/experimental/distri_cores/tensor_parallel/layers.py, line 770, in construct, \n output = self.reduce_scatter_to_sp_region(output_parallel)", - "File /home/mindspore/nn/cell.py, line 2462, in _backward_hook_construct, \n outputs = self.construct(outputs, **kwargs)", - "File /home/mindspore/nn/cell.py, line 498, in _run_construct, \n output = self._backward_hook_construct(*inputs, **kwargs)", - "File /home/mindspore/nn/cell.py, line 745, in __call__, \n return self._run_construct(*args, **kwargs)", - "File /home/mindformers/experimental/distri_cores/transformer/language_model.py, line 151, in construct, \n embeddings = self.word_embeddings(input_ids)", - "File /home/mindspore/nn/cell.py, line 2460, in _backward_hook_construct, \n outputs = self.construct(*outputs, **kwargs)", - "File /home/mindspore/nn/cell.py, line 498, in _run_construct, \n output = self._backward_hook_construct(*inputs, **kwargs)", - "File /home/mindspore/nn/cell.py, line 745, in __call__, \n return self._run_construct(*args, **kwargs)", - "File /home/mindformers/experimental/distri_cores/transformer/language_model.py, line 391, in construct, \n text_embedding_out = self.embedding(enc_input_ids, enc_position_ids,", - "File /home/mindspore/nn/cell.py, line 2460, in _backward_hook_construct, \n outputs = self.construct(*outputs, **kwargs)", - "File /home/mindspore/nn/cell.py, line 498, in _run_construct, \n output = self._backward_hook_construct(*inputs, **kwargs)", - "File /home/mindspore/nn/cell.py, line 745, in __call__, \n return self._run_construct(*args, **kwargs)", - "File /home/model/gpt_model.py, line 104, in construct, \n lm_output = self.language_model(tokens,", - "File /home/mindspore/nn/cell.py, line 2460, in _backward_hook_construct, \n outputs = self.construct(*outputs, **kwargs)", - "File /home/mindspore/nn/cell.py, line 498, in _run_construct, \n output = self._backward_hook_construct(*inputs, **kwargs)", - "File /home/mindspore/nn/cell.py, line 745, in __call__, \n return self._run_construct(*args, **kwargs)", - "File /home/mindformers/experimental/distri_cores/pipeline_parallel/pipeline_cell.py, line 429, in construct, \n return self.model(*inputs)", - "File /home/mindspore/nn/cell.py, line 757, in _complex_call, \n output = self.construct(*args, **kwargs)", - "File /home/mindspore/nn/cell.py, line 747, in __call__, \n return self._complex_call(*args, **kwargs)", - "File /home/mindformers/experimental/distri_cores/pipeline_parallel/schedules.py, line 121, in run_forward, \n output_tensor = model(*input_data, recv_data=None)", - "File /home/mindformers/experimental/distri_cores/pipeline_parallel/schedules.py, line 735, in forward_backward_pipelining_without_interleaving, \n micro_input_data = run_forward(*micro_input_data,", - "File /home/mindformers/experimental/distri_cores/training.py, line 409, in forward_backward_with_pipelining, \n loss, logits, grads = forward_backward_pipelining_without_interleaving(", - "File /home/mindformers/experimental/distri_cores/training.py, line 533, in construct, \n (loss, _), grads = self.forward_backward_func(*inputs_tuple, loss_scale=current_step_loss_scale, **inputs_dict)", - "File /home/mindspore/nn/cell.py, line 757, in _complex_call, \n output = self.construct(*args, **kwargs)", - "File /home/mindspore/nn/cell.py, line 747, in __call__, \n return self._complex_call(*args, **kwargs)", - "File /home/mindformers/experimental/distri_cores/training.py, line 655, in train, \n loss, is_finite, loss_scale, learning_rate = train_one_step_cell(**data)", - "File /home/model/pretrain_gpt.py, line 303, in main, \n train(", - "File /home/model/pretrain_gpt.py, line 316, in , \n main()" - ] -} -``` # FAQ 1. 图比对场景,节点呈现灰色,且没有精度比对数据,怎么处理? 节点呈现灰色,代表左边待调试侧节点与右边标杆侧节点没有匹配上,可能有以下几点原因: - **标杆侧确实没有能与待调试侧匹配上的节点**,属于代码实现上的差异,请确认此差异是否正常,是否会影响到整网精度。 -- **节点的输入或输出type、shape不一致,参数个数不一致,节点所在层级的父层级不一致**,导致节点无法匹配,具体匹配规则见[匹配说明](#311-匹配说明),可尝试使用模糊匹配功能,如何使用此功能请参考[构图命令行说明](#31-构图命令行说明)。如果是参数shape不一致,即使是模糊匹配功能也无法让节点匹配上,请检查参数shape不一致是否合理。 -- **节点名称不一致**,导致节点无法匹配,可使用layer mapping功能,如何使用此功能请参考[构图命令行说明](#31-构图命令行说明),如何自定义映射文件请参考[模型分级可视化如何配置layer mapping映射文件](./visualization/layer_mapping_example.md)。 +- **节点名称一致,但节点的输入或输出type、shape不一致,参数个数不一致,节点所在层级的父层级不一致,导致节点无法匹配** + - 具体匹配规则见[匹配说明](#311-匹配说明),可尝试使用模糊匹配功能,如何使用此功能请参考[构图命令行说明](#31-构图命令行说明); + - 如果是参数shape不一致,即使是模糊匹配功能也无法让节点匹配上,请检查参数shape不一致是否合理。 +- **节点名称不一致**,导致节点无法匹配,目前提供两种方法,选其一即可 + - 可使用layer mapping功能,如何使用此功能请参考[构图命令行说明](#31-构图命令行说明),如何自定义映射文件请参考[模型分级可视化如何配置layer mapping映射文件](./visualization/layer_mapping_example.md); + - 可通过浏览器页面手动选择未匹配节点进行匹配,请参考[手动选择节点匹配](#56-手动选择节点匹配)。 diff --git a/debug/accuracy_tools/msprobe/docs/23.generate_operator_PyTorch.md b/debug/accuracy_tools/msprobe/docs/23.generate_operator_PyTorch.md index 59e2755ec3e5a3939af3a20d19fda12031a9bf51..c4dabebe2c1b3840988e06e20ef5d717647c62ba 100644 --- a/debug/accuracy_tools/msprobe/docs/23.generate_operator_PyTorch.md +++ b/debug/accuracy_tools/msprobe/docs/23.generate_operator_PyTorch.md @@ -15,7 +15,7 @@ b. 在生成单API脚本时可以选择由工具构造随机数获得 dump 数 2. 已完成对训练过程的dump,获得dump.json文件。 [PyTorch场景的数据采集](https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/docs/05.data_dump_PyTorch.md) - **目前仅支持复现API级的数据,故dump时level可选择L0(API信息)或者mix(module信息+API信息)。如需复现真实数据场景的API脚本,dump时task应选择tensor,如需复现随机数据场景的API脚本,dump时task选择statistics**。 + **目前仅支持复现API级的数据,故dump时level可选择L1(API信息)或者mix(module信息+API信息)。如需复现真实数据场景的API脚本,dump时task应选择tensor,如需复现随机数据场景的API脚本,dump时task选择statistics**。 3. 发现某个算子疑似存在精度问题,并得知算子名,如Functional.softmax.3、Tensor.add.0、Torch.matmul.5等 ### 2.1 配置config_op.json @@ -33,15 +33,15 @@ b. 在生成单API脚本时可以选择由工具构造随机数获得 dump 数 ``` **配置文件参数说明** - | 参数名称 | 解释 | 是否必选 | - | ---------------------------- | ------------------------------------------------------------ | ---------------------------------- | - | dump_json_path | dump.json的文件路径,包含所有dump算子的信息;如果已经提取了可疑算子并保存可以不指定。 | 否 | - | api_name | 算子名,如Functional.softmax.3、Tensor.add.0、Torch.matmul.5等。如果已经提取了可疑算子并保存可以不指定 | 否 | - | extract_api_path | 提取可疑算子的json文件路径 | 是 | - | propagation | 选择复现算子的forward还是backward,默认为forward | 否 | - | data_mode | 选择复现算子的随机数据(random_data)还是真实数据(real_data)模式,默认为random_data | 否 | - | random_seed | 仅random_data模式有效,表示手动设定的随机种子,默认为1234 | 否 | - | iter_times | 仅random_data模式有效,表示单API运行的次数 | 否 | + | 参数名称 | 解释 | 是否必选 | + | ---------------------------- |----------------------------------------------------------------------------| ---------------------------------- | + | dump_json_path | dump.json的文件路径,包含所有dump算子的信息;如果已经提取了可疑算子并保存可以不指定。 | 否 | + | api_name | 算子名,如Functional.softmax.3、Tensor.add.0、Torch.matmul.5等。如果已经提取了可疑算子并保存可以不指定 | 否 | + | extract_api_path | 提取可疑算子的json文件路径 | 是 | + | propagation | 选择复现算子的forward还是backward,默认为forward | 否 | + | data_mode | 选择复现算子的随机数据(random_data)还是真实数据(real_data)模式,默认为random_data | 否 | + | random_seed | 仅random_data模式有效,表示手动设定的随机种子,默认为1234 | 否 | + | iter_times | 仅random_data模式有效,表示单API运行的次数,由于安全相关原因,最大支持设置为1000 | 否 | ### 2.2 运行命令生成单API脚本 config_op.json配置好后,运行如下命令: diff --git a/debug/accuracy_tools/msprobe/docs/24.code_mapping_Mindspore.md b/debug/accuracy_tools/msprobe/docs/24.code_mapping_Mindspore.md index 05e3900d2647b07ed5334082e3ac519cfc7fb2b2..9e741305363e7ee8419a2df8661c6afad463201c 100644 --- a/debug/accuracy_tools/msprobe/docs/24.code_mapping_Mindspore.md +++ b/debug/accuracy_tools/msprobe/docs/24.code_mapping_Mindspore.md @@ -20,9 +20,10 @@ msprobe -f mindspore code_mapping --ir --dump_data [--outp ``` -| 参数名称 | 说明 |参数类型 | 是否必选 | -| ---------------------------- |-------------------------------------------------------------------------------------------------------------------------------------------|---------------------- | ---------------------------------- | -| --ir | 指定 MindSpore 静态图运行时生成的IR图文件。 | str | 是 | -| --dump_data | 指定dump数据文件(支持tensor或statistic模式的dump数据)。可指定单个dump数据 文件或dump数据文件的父目录,指定父目录表示关联目录下的所有dump数据文件。 | str | 是 | -| --output | 关联结果输出目录,默认为"./",只在tensor模式时生效,会把数据文件路径和代码调用栈的关联关系存到output路径下的code_mapping_{时间戳}.csv中。如果关联的是statistic模式,则会把statistic.csv中每个条目加上该条目对应的代码栈。 | str | 否 | +| 参数名称 | 说明 | 参数类型 | 是否必选 | +| ---------------------------- |-------------------------------------------------------------------------------------------------------------------------------------------|------| ---------------------------------- | +| -f 或 --framework | 指定训练框架。mindspore。 | str | 是 | +| --ir | 指定 MindSpore 静态图运行时生成的IR图文件。 | str | 是 | +| --dump_data | 指定dump数据文件(支持tensor或statistic模式的dump数据)。可指定单个dump数据 文件或dump数据文件的父目录,指定父目录表示关联目录下的所有dump数据文件。 | str | 是 | +| --output | 关联结果输出目录,默认为"./",只在tensor模式时生效,会把数据文件路径和代码调用栈的关联关系存到output路径下的code_mapping_{时间戳}.csv中。如果关联的是statistic模式,则会把statistic.csv中每个条目加上该条目对应的代码栈。 | str | 否 | diff --git a/debug/accuracy_tools/msprobe/docs/25.tool_function_introduction.md b/debug/accuracy_tools/msprobe/docs/25.tool_function_introduction.md index f6f5db9781223fc299df978dfd55a9d2af2e07e6..686660b09991ae66e48a3423458fd239b1f01099 100644 --- a/debug/accuracy_tools/msprobe/docs/25.tool_function_introduction.md +++ b/debug/accuracy_tools/msprobe/docs/25.tool_function_introduction.md @@ -2,28 +2,28 @@ ## 1 PyTorch框架 -| 功能名(英文) | 简介 | 适用场景/优势 | 当前版本局限性 | -|------------------------------------------------------------------------------------|---------------------------------------------------------------|--------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------| -| [数据采集
(dump)](./05.data_dump_PyTorch.md) | 采集模型训练过程中的API或Module层级的前反向输入输出数据,包括层次关系、统计值信息、真实数据和调用栈等。 | 1、将模型中训练的API或Module的前反向输入输出数据保存下来分析
2、模型出现溢出时,可用于查看哪些API或Module出现了溢出 | 1、API级数据采集仅支持白名单列表上的API
2、工具会做一些同步操作,引入工具可能会导致一些同步问题消失
3、当前对inplace操作API或Module的支持度有限
4、暂不支持参数及参数梯度的采集 | -| [离线预检
(api_accuracy_checker)](./07.accuracy_checker_PyTorch.md) | 为网络中每个API创建用例,检验其精度,并根据不同比对算法综合判定API在NPU上的精度是否达标,快速找出精度差异API。 | 1、对模型中所有的API做精度初步排查
2、精度排查不受模型累计误差影响 | 1、依赖GPU环境
2、不支持通信算子
3、仅支持部分融合算子 | -| [整网比对
(compare)](./10.accuracy_compare_PyTorch.md) | 计算模型整网NPU和标杆设备的精度误差指标,标记精度异常API或Module,助力快速定位精度问题根因。 | 1、整网比对定位精度可疑算子 | 1、由于使用整网dump数据,定位的可疑算子受累计误差影响
2、当模型规模较大时,比对所需时间较长 | -| [在线预检
(online_api_accuracy_checker)](./08.accuracy_checker_online_PyTorch.md) | 通过TCP通信或共享存储空间的方式,进行在线精度预检,解决离线预检大数据量落盘、传输困难痛点。 | 1、使用离线预检,数据量较大落盘困难或传输耗时长时,可通过在线预检进行精度排查 | 1、依赖GPU环境,NPU和GPU能够通信
2、重计算模式下,不支持反向aten算子预检 | -| [溢出检查
(overflow_checker)](./12.overflow_check_PyTorch.md) | 检测模型计算过程的输入输出,并在溢出时落盘数据,助力用户快速定位溢出位置。 | 1、当模型出现溢出时,用于快速定位最先溢出的API或Module
2、相比数据采集,性能更优,磁盘压力更小 | 1、局限性同数据采集 | -| [数据解析
(parse_tool)](./14.data_parse_PyTorch.md) | 互交式界面处理解析kernel层级dump数据,便于查看分析。 | 1、比对kernel层级dump数据的一致性 | 1、仅限于NPU | -| [无标杆比对
(free_benchmark)](./15.free_benchmarking_PyTorch.md) | 不依赖标杆数据,通过对算子输入增加微小扰动,计算扰动后输出与原始输出的相对误差,识别有精度风险算子。 | 1、无标杆数据场景下的算子精度排查
2、对个别算子进行升精度、“to cpu”等操作,以验证其对模型loss的影响 | 1、由于需要拷贝输入进行二次执行,所以在遇到大张量的输入时容易发生显存OOM的问题, 特别是反向比对过程。建议结合白名单使用
2、比对会延长训练时间,整网比对可能会造成严重的耗时膨胀,建议结合白名单使用 | -| [梯度状态监测
(grad_probe)](./17.grad_probe.md) | 可导出模型权重梯度数据并对比相似度,助力确认训练过程精度问题step和反向中的异常。 | 1、需要分析梯度数据时
2、需要定位发生问题的step时 | 暂无 | -| [在线精度比对
(online_dispatch)](./18.online_dispatch.md) | 训练过程中直接完成NPU和CPU的精度比对并输出比对结果。 | 1、执行一次就可获取NPU和CPU分别执行后的精度比对结果 | 暂无 | -| [训练状态监控
(monitor)](./19.monitor.md) | 收集模型训练过程中的激活值、梯度和优化器状态,助力分析计算、通信、优化器各部分异常情况。 | 1、通过监控模块级统计量指标,快速定位异常模块位置,如loss出现nan | 1、仅支持模块级别统计量指标分析
2、仅支持megatron、deepspeed框架
3、少量增加时间和显存膨胀 | -| [可视化比对
(visualization) ](./21.visualization_PyTorch.md) | 解析dump的精度数据,还原模型图结构,比对各层级精度数据,助力理解模型结构、分析精度问题。 | 1、整网精度比对定位可疑算子,通过浏览器展示比对结果,支持快速搜索到可疑算子
2、支持查看模型层级结果,比对模型层级结构差异 | 1、由于使用整网dump数据,定位的可疑算子受累计误差影响
2、当模型规模较大时,比对所需时间较长 | -| [单API自动生成脚本
(generate_operator) ](./23.generate_operator_PyTorch.md) | 解析dump的精度数据,提取可疑的API算子,自动生成单API复现脚本,并根据不同的API采用不同的比对算法,给定最终比对结果数据;帮助开发者分析算子精度问题。 | 1、该工具支持从整网dump下来的数据中提取可疑算子,并自动生成单API脚本
2、除了支持复现单API的前反向过程,同时会根据不同的API选择不同的比对方法,并给出比对结果 |1、不支持通信算子
2、融合算子需手动修改脚本进行适配
3、目前比对的标杆均为和CPU进行比对,暂不支持直接NPU和GPU比对 +| 功能名(英文) | 简介 | 适用场景/优势 | 当前版本局限性 | +| --------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| [数据采集
(dump)](./05.data_dump_PyTorch.md) | 采集模型训练过程中的API或Module层级的前反向输入输出数据,包括层次关系、统计值信息、真实数据和调用栈等。 | 1、将模型中训练的API或Module的前反向输入输出数据保存下来分析
2、模型出现溢出时,可用于查看哪些API或Module出现了溢出 | 1、API级数据采集仅支持白名单列表上的API
2、工具会做一些同步操作,引入工具可能会导致一些同步问题消失
3、当前对inplace操作API或Module的支持度有限
4、暂不支持参数及参数梯度的采集 | +| [离线预检
(api_accuracy_checker)](./07.accuracy_checker_PyTorch.md) | 为网络中每个API创建用例,检验其精度,并根据不同比对算法综合判定API在NPU上的精度是否达标,快速找出精度差异API。 | 1、对模型中所有的API做精度初步排查
2、精度排查不受模型累计误差影响 | 1、依赖GPU环境
2、不支持通信算子
3、仅支持部分融合算子 | +| [整网比对
(compare)](./10.accuracy_compare_PyTorch.md) | 计算模型整网NPU和标杆设备的精度误差指标,标记精度异常API或Module,助力快速定位精度问题根因。 | 1、整网比对定位精度可疑算子 | 1、由于使用整网dump数据,定位的可疑算子受累计误差影响
2、当模型规模较大时,比对所需时间较长 | +| [溢出检查
(overflow_checker)](./12.overflow_check_PyTorch.md) | 检测模型计算过程的输入输出,并在溢出时落盘数据,助力用户快速定位溢出位置。 | 1、当模型出现溢出时,用于快速定位最先溢出的API或Module
2、相比数据采集,性能更优,磁盘压力更小 | 1、局限性同数据采集 | +| [数据解析
(parse_tool)](./14.data_parse_PyTorch.md) | 交互式界面处理解析kernel层级dump数据,便于查看分析。 | 1、比对kernel层级dump数据的一致性 | 1、仅限于NPU | +| [无标杆比对
(free_benchmark)](./15.free_benchmarking_PyTorch.md) | 不依赖标杆数据,通过对算子输入增加微小扰动,计算扰动后输出与原始输出的相对误差,识别有精度风险算子。 | 1、无标杆数据场景下的算子精度排查
2、对个别算子进行升精度、“to cpu”等操作,以验证其对模型loss的影响 | 1、由于需要拷贝输入进行二次执行,所以在遇到大张量的输入时容易发生显存OOM的问题, 特别是反向比对过程。建议结合白名单使用
2、比对会延长训练时间,整网比对可能会造成严重的耗时膨胀,建议结合白名单使用 | +| [梯度状态监测
(grad_probe)](./17.grad_probe.md) | 可导出模型权重梯度数据并对比相似度,助力确认训练过程精度问题step和反向中的异常。 | 1、需要分析梯度数据时
2、需要定位发生问题的step时 | 暂无 | +| [在线精度比对
(online_dispatch)](./18.online_dispatch.md) | 训练过程中直接完成NPU和CPU的精度比对并输出比对结果。 | 1、执行一次就可获取NPU和CPU分别执行后的精度比对结果 | 暂无 | +| [训练状态监控
(monitor)](./19.monitor.md) | 收集模型训练过程中的激活值、梯度和优化器状态,助力分析计算、通信、优化器各部分异常情况。 | 1、通过监控模块级统计量指标,快速定位异常模块位置,如loss出现nan | 1、仅支持模块级别统计量指标分析
2、仅支持megatron、deepspeed框架
3、少量增加时间和显存膨胀 | +| [可视化比对
(visualization) ](./21.visualization_PyTorch.md) | 解析dump的精度数据,还原模型图结构,比对各层级精度数据,助力理解模型结构、分析精度问题。 | 1、整网精度比对定位可疑算子,通过浏览器展示比对结果,支持快速搜索到可疑算子
2、支持查看模型层级结果,比对模型层级结构差异 | 1、由于使用整网dump数据,定位的可疑算子受累计误差影响
2、当模型规模较大时,比对所需时间较长 | +| [单API自动生成脚本
(generate_operator) ](./23.generate_operator_PyTorch.md) | 解析dump的精度数据,提取可疑的API算子,自动生成单API复现脚本,并根据不同的API采用不同的比对算法,给定最终比对结果数据;帮助开发者分析算子精度问题。 | 1、该工具支持从整网dump下来的数据中提取可疑算子,并自动生成单API脚本
2、除了支持复现单API的前反向过程,同时会根据不同的API选择不同的比对方法,并给出比对结果 | 1、不支持通信算子
2、融合算子需手动修改脚本进行适配
3、目前比对的标杆均为和CPU进行比对,暂不支持直接NPU和GPU比对 | ## 2 MindSpore框架 -| 功能名(英文) | 简介 | 适用场景/优势 | 当前版本局限性 | -|----------------------------------------------------------------------|-------------------------------------------------------------------|------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------| -| [数据采集
(dump)](./06.data_dump_MindSpore.md) | 采集模型训练过程中的API或Cell层级的前反向输入输出数据,包括层次关系、统计值信息、真实数据和调用栈等。 | 1、将模型中训练的API或Cell的前反向输入输出数据保存下来分析
2、模型出现溢出时,可用于查看哪些API或Cell出现了溢出 | 1、API级数据采集仅支持白名单列表上的API
2、当前对inplace操作API或Cell的支持度有限
3、暂不支持参数及参数梯度的采集 | -| [离线预检
(api_accuracy_checker)](./09.accuracy_checker_MindSpore.md) | 为网络中每个API创建用例,检验其精度,并根据不同比对算法综合判定API在NPU上的精度是否达标,快速找出精度差异API。 | 1、对模型中所有的API做精度初步排查
2、精度排查不受模型累计误差影响 | 1、仅针对MindSpore.mint API | -| [整网比对
(compare)](./11.accuracy_compare_MindSpore.md) | NPU精度数据与标杆数据的比对,支持MindSpore框架内和与PyTorch跨框架的比对,助力快速定位精度异常API或Cell。 | 1、MindSpore同框架静态图比对
2、MindSpore同框架动态图比对
3、MindSpore vs PyTorch跨框架动态图比对 | 1、部分PyTorch的API关联不到MindSpore,需要手动配置映射关系 | -| [溢出检查
(overflow_checker)](./13.overflow_check_MindSpore.md) | 检测模型计算过程的输入输出,并在溢出时落盘数据,助力用户快速定位溢出位置。 | 1、当模型出现溢出时,可用于定位最先溢出的API或Cell或kernel
2、相比数据采集,性能更优,磁盘压力更小 | 1、除具有与数据采集功能相同的局限性外,动态图场景下,不支持 Primitive 和 Jit 类 API 的检测
2、动态图场景下,仅支持检测API或Cell级别溢出
3、静态图场景下,仅支持检测kernel级别溢出 | -| [无标杆比对
(free_benchmark)](./16.free_benchmarking_MindSpore.md) | 不依赖标杆数据,通过对算子输入增加微小扰动,计算扰动后输出与原始输出的相对误差,识别有精度风险算子。 | 1、无标杆数据场景下的算子精度排查
2、对个别算子进行升精度修复,验证其对模型loss的影响 | 1、仅支持动态图场景
2、由于需要拷贝输入进行二次执行,所以在遇到大张量的输入时容易发生显存OOM的问题, 特别是反向比对过程。建议结合白名单使用
3、比对会延长训练时间,整网比对可能会造成严重的耗时膨胀,建议结合白名单使用
4、不支持“to cpu”操作,不支持预热功能 | -| [可视化比对
(visualization) ](./22.visualization_MindSpore.md) | 解析dump的精度数据,还原模型图结构,比对各层级精度数据,助力理解模型结构、分析精度问题。 | 1、整网精度比对定位可疑算子,通过浏览器展示比对结果,支持快速搜索到可疑算子
2、支持查看模型层级结果,比对模型层级结构差异 | 1、由于使用整网dump数据,定位的可疑算子受累计误差影响
2、当模型规模较大时,比对所需时间较长 | +| 功能名(英文) | 简介 | 适用场景/优势 | 当前版本局限性 | +| ---------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| [数据采集
(dump)](./06.data_dump_MindSpore.md) | 采集模型训练过程中的API或Cell层级的前反向输入输出数据,包括层次关系、统计值信息、真实数据和调用栈等。 | 1、将模型中训练的API或Cell的前反向输入输出数据保存下来分析
2、模型出现溢出时,可用于查看哪些API或Cell出现了溢出 | 1、API级数据采集仅支持白名单列表上的API
2、当前对inplace操作API或Cell的支持度有限
3、暂不支持参数及参数梯度的采集 | +| [离线预检
(api_accuracy_checker)](./09.accuracy_checker_MindSpore.md) | 为网络中每个API创建用例,检验其精度,并根据不同比对算法综合判定API在NPU上的精度是否达标,快速找出精度差异API。 | 1、对模型中所有的API做精度初步排查
2、精度排查不受模型累计误差影响 | 1、仅针对MindSpore.mint API | +| [整网比对
(compare)](./11.accuracy_compare_MindSpore.md) | NPU精度数据与标杆数据的比对,支持MindSpore框架内和与PyTorch跨框架的比对,助力快速定位精度异常API或Cell。 | 1、MindSpore同框架静态图比对
2、MindSpore同框架动态图比对
3、MindSpore vs PyTorch跨框架动态图比对 | 1、部分PyTorch的API关联不到MindSpore,需要手动配置映射关系 | +| [溢出检查
(overflow_checker)](./13.overflow_check_MindSpore.md) | 检测模型计算过程的输入输出,并在溢出时落盘数据,助力用户快速定位溢出位置。 | 1、当模型出现溢出时,可用于定位最先溢出的API或Cell或kernel
2、相比数据采集,性能更优,磁盘压力更小 | 1、除具有与数据采集功能相同的局限性外,动态图场景下,不支持 Primitive 和 Jit 类 API 的检测
2、动态图场景下,仅支持检测API或Cell级别溢出
3、静态图场景下,仅支持检测kernel级别溢出 | +| [无标杆比对
(free_benchmark)](./16.free_benchmarking_MindSpore.md) | 不依赖标杆数据,通过对算子输入增加微小扰动,计算扰动后输出与原始输出的相对误差,识别有精度风险算子。 | 1、无标杆数据场景下的算子精度排查
2、对个别算子进行升精度修复,验证其对模型loss的影响 | 1、仅支持动态图场景
2、由于需要拷贝输入进行二次执行,所以在遇到大张量的输入时容易发生显存OOM的问题, 特别是反向比对过程。建议结合白名单使用
3、比对会延长训练时间,整网比对可能会造成严重的耗时膨胀,建议结合白名单使用
4、不支持“to cpu”操作,不支持预热功能 | +| [可视化比对
(visualization) ](./22.visualization_MindSpore.md) | 解析dump的精度数据,还原模型图结构,比对各层级精度数据,助力理解模型结构、分析精度问题。 | 1、整网精度比对定位可疑算子,通过浏览器展示比对结果,支持快速搜索到可疑算子
2、支持查看模型层级结果,比对模型层级结构差异 | 1、由于使用整网dump数据,定位的可疑算子受累计误差影响
2、当模型规模较大时,比对所需时间较长 | +| [训练状态监控
(monitor)](./19.monitor.md) | 收集模型训练过程中的激活值、梯度和优化器状态,助力分析计算、通信、优化器各部分异常情况。 | 1、通过监控模块级统计量指标,快速定位异常模块位置,如loss出现nan | 1、仅支持模块级别统计量指标分析
2、仅支持megatron、deepspeed框架
3、少量增加时间和显存膨胀 | diff --git a/debug/accuracy_tools/msprobe/docs/26.data_dump_PyTorch_baseline.md b/debug/accuracy_tools/msprobe/docs/26.data_dump_PyTorch_baseline.md index 5ca199ab6171a3634af0b26844d6ba8e7d04933f..537c185a016bd583533aa831bdc04a10c6c49c96 100644 --- a/debug/accuracy_tools/msprobe/docs/26.data_dump_PyTorch_baseline.md +++ b/debug/accuracy_tools/msprobe/docs/26.data_dump_PyTorch_baseline.md @@ -1,8 +1,19 @@ # PyTorch 场景的精度数据采集基线 +## "statistics"模式(未开启md5)采集时间膨胀参考基线 + +该基线为PyTorch框架下,使用"statistics"模式采集数据性能膨胀的参考基线。本基线测试了LLAMA2-7B语言大模型在不同采集模式8卡下的时间膨胀。 + +| 采集模式 | 无工具 (耗时) | 加工具但未使能 Dump (耗时) | 加工具并使能 Dump (耗时) | +|:--------:|:--------:|:--------------------:|:------------------:| +| L0 | ≈17.4 s | ≈17.4 s (无膨胀) | ≈78.4 s (膨胀4.5倍) | +| L1 | ≈17.4 s | ≈20.7 s (膨胀1.2倍) | ≈353 s (膨胀20倍) | +| mix | ≈17.4 s | ≈20.7 s (膨胀1.2倍) | ≈430 s (膨胀24.7 倍) | + + ## "tensor"模式采集数据量参考基线 -该基线为pytorch框架下,使用"tensor"模式采集数据量参考基线。本基线测试了两个模型,分别为LLAMA2-7B和LLAMA2-13B,测试了不同采集模式下,不同global_batch_size下,单卡和8卡下,数据量的变化。 +该基线为PyTorch框架下,使用"tensor"模式采集数据量参考基线。本基线测试了两个模型,分别为LLAMA2-7B和LLAMA2-13B,测试了不同采集模式下,不同global_batch_size下,单卡和8卡下,数据量的变化。 ### LLAMA2-7B @@ -25,8 +36,8 @@ - - + + diff --git a/debug/accuracy_tools/msprobe/docs/27.dump_json_instruction.md b/debug/accuracy_tools/msprobe/docs/27.dump_json_instruction.md index f994dc2301bcae6b23dc7a7503297aa4fe5b3724..bf992a02aba6c9b4c6c1d18077775c0a8f4325ea 100644 --- a/debug/accuracy_tools/msprobe/docs/27.dump_json_instruction.md +++ b/debug/accuracy_tools/msprobe/docs/27.dump_json_instruction.md @@ -1,8 +1,8 @@ # dump.json文件说明及示例 -## 1. dump.json文件示例(PyTorch) +## 1. PyTorch 场景下的 dump.json 文件 -### 1.1 L0级别 +### 1.1 L0 级别 L0级别的dump.json文件包括模块的前反向的输入输出,以及模块的参数和参数梯度。以PyTorch的Conv2d模块为例,网络中模块调用代码为: `output = self.conv2(input) # self.conv2 = torch.nn.Conv2d(64, 128, 5, padding=2, bias=True)` @@ -168,7 +168,7 @@ dump.json文件中包含以下数据名称: } ``` -### 1.2 L1级别 +### 1.2 L1 级别 L1级别的dump.json文件包括API的前反向的输入输出。以PyTorch的relu函数为例,网络中API调用代码为: `output = torch.nn.functional.relu(input)` @@ -264,13 +264,13 @@ dump.json文件中包含以下数据名称: } ``` -### 1.3 mix级别 +### 1.3 mix 级别 mix级别的dump.json文件同时包括L0和L1级别的dump数据,文件格式与上述示例相同。 -## 2. dump.json文件示例(MindSpore) +## 2. MindSpore 场景下的 dump.json 文件 -### 2.1 L0级别 +### 2.1 L0 级别 L0级别的dump.json文件包括模块的前反向的输入输出,以及模块的参数和参数梯度。 以MindSpore的Conv2d模块为例,dump.json文件中使用的模块调用代码为: @@ -429,7 +429,7 @@ dump.json文件中包含以下数据名称: } ``` -### 2.2 L1级别 +### 2.2 L1 级别 L1级别的dump.json文件包括API的前反向的输入输出,以MindSpore的relu函数为例,网络中API调用代码为: `output = mindspore.ops.relu(input)` @@ -521,5 +521,275 @@ L1级别的dump.json文件包括API的前反向的输入输出,以MindSpore的 } ``` -### 2.3 mix级别 +### 2.3 mix 级别 + mix级别的dump.json文件同时包括L0和L1级别的dump数据,文件格式与上述示例相同。 + +## 3. MSAdapter 场景下的 dump.json 文件 + +### 3.1 L0 级别 + +L0 级别的 dump.json 文件包括模块的前反向的输入输出,以及模块的参数和参数梯度。以 Conv2d 模块为例,网络中模块调用代码为: +`output = self.conv2(input) # self.conv2 = torch.nn.Conv2d(64, 128, 5, padding=2, bias=True)` + +dump.json文件中包含以下数据名称: + +- `Module.conv2.Conv2d.forward.0`:模块的前向数据,其中input_args为模块的输入数据(位置参数),input_kwargs为模块的输入数据(关键字参数),output为模块的输出数据,parameters为模块的参数数据,包括权重(weight)和偏置(bias)。 +- `Module.conv2.Conv2d.parameters_grad`:模块的参数梯度数据,包括权重(weight)和偏置(bias)的梯度。 +- `Module.conv2.Conv2d.backward.0`:模块的反向数据,其中input为模块反向的输入梯度(对应前向输出的梯度),output为模块的反向输出梯度(对应前向输入的梯度)。 + +**说明**:当dump时传入的model参数为List[torch.nn.Module]或Tuple[torch.nn.Module]时,模块级数据的命名中包含该模块在列表中的索引index,命名格式为`{Module}.{index}.*`,*表示以上三种模块级数据的命名格式,例如:`Module.0.conv1.Conv2d.forward.0`。 + +```json +{ + "task": "tensor", + "level": "L0", + "framework": "mindtorch", + "dump_data_dir": "/dump/path", + "data": { + "Module.conv2.Conv2d.forward.0": { + "input_args": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 8, + 16, + 14, + 14 + ], + "Max": 1.638758659362793, + "Min": 0.0, + "Mean": 0.2544615864753723, + "Norm": 70.50277709960938, + "requires_grad": true, + "data_name": "Module.conv2.Conv2d.forward.0.input.0.npy" + } + ], + "input_kwargs": {}, + "output": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 8, + 32, + 10, + 10 + ], + "Max": 1.6815717220306396, + "Min": -1.5120246410369873, + "Mean": -0.025344856083393097, + "Norm": 149.65576171875, + "requires_grad": true, + "data_name": "Module.conv2.Conv2d.forward.0.output.0.npy" + } + ], + "parameters": { + "weight": { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 32, + 16, + 5, + 5 + ], + "Max": 0.05992485210299492, + "Min": -0.05999220535159111, + "Mean": -0.0006165213999338448, + "Norm": 3.421217441558838, + "requires_grad": true, + "data_name": "Module.conv2.Conv2d.forward.0.parameters.weight.npy" + }, + "bias": { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 32 + ], + "Max": 0.05744686722755432, + "Min": -0.04894155263900757, + "Mean": 0.006410328671336174, + "Norm": 0.17263513803482056, + "requires_grad": true, + "data_name": "Module.conv2.Conv2d.forward.0.parameters.bias.npy" + } + } + }, + "Module.conv2.Conv2d.parameters_grad": { + "weight": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 32, + 16, + 5, + 5 + ], + "Max": 0.018550323322415352, + "Min": -0.008627401664853096, + "Mean": 0.0006675920449197292, + "Norm": 0.26084786653518677, + "requires_grad": false, + "data_name": "Module.conv2.Conv2d.parameters_grad.weight.npy" + } + ], + "bias": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 32 + ], + "Max": 0.014914230443537235, + "Min": -0.006656786892563105, + "Mean": 0.002657240955159068, + "Norm": 0.029451673850417137, + "requires_grad": false, + "data_name": "Module.conv2.Conv2d.parameters_grad.bias.npy" + } + ] + }, + "Module.conv2.Conv2d.backward.0": { + "input": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 8, + 32, + 10, + 10 + ], + "Max": 0.0015069986693561077, + "Min": -0.001139344065450132, + "Mean": 3.3215508210560074e-06, + "Norm": 0.020567523315548897, + "requires_grad": false, + "data_name": "Module.conv2.Conv2d.backward.0.input.0.npy" + } + ], + "output": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 8, + 16, + 14, + 14 + ], + "Max": 0.0007466732058674097, + "Min": -0.00044813455315306783, + "Mean": 6.814070275140693e-06, + "Norm": 0.01474067009985447, + "requires_grad": false, + "data_name": "Module.conv2.Conv2d.backward.0.output.0.npy" + } + ] + } + } +} +``` + +### 3.2 L1 级别 +L1级别的dump.json文件包括API的前反向的输入输出。以 relu API 为例,网络中 API 调用代码为: +`output = torch.nn.functional.relu(input)` + +dump.json文件中包含以下数据名称: +- `Functional.relu.0.forward`:API的前向数据,其中input_args为API的输入数据(位置参数),input_kwargs为API的输入数据(关键字参数),output为API的输出数据。 +- `Functional.relu.0.backward`:API的反向数据,其中input为API的反向输入梯度(对应前向输出的梯度),output为API的反向输出梯度(对应前向输入的梯度)。 + +```json +{ + "task": "tensor", + "level": "L1", + "framework": "mindtorch", + "dump_data_dir":"/dump/path", + "data": { + "Functional.relu.0.forward": { + "input_args": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 32, + 16, + 28, + 28 + ], + "Max": 1.3864083290100098, + "Min": -1.3364859819412231, + "Mean": 0.03711778670549393, + "Norm": 236.20692443847656, + "requires_grad": true, + "data_name": "Functional.relu.0.forward.input.0.npy" + } + ], + "input_kwargs": {}, + "output": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 32, + 16, + 28, + 28 + ], + "Max": 1.3864083290100098, + "Min": 0.0, + "Mean": 0.16849493980407715, + "Norm": 175.23345947265625, + "requires_grad": true, + "data_name": "Functional.relu.0.forward.output.0.npy" + } + ] + }, + "Functional.relu.0.backward": { + "input": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 32, + 16, + 28, + 28 + ], + "Max": 0.0001815402356442064, + "Min": -0.00013352684618439525, + "Mean": 0.00011915402356442064, + "Norm": 0.007598237134516239, + "requires_grad": false, + "data_name": "Functional.relu.0.backward.input.0.npy" + } + ], + "output": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 32, + 16, + 28, + 28 + ], + "Max": 0.0001815402356442064, + "Min": -0.00012117840378778055, + "Mean": 2.0098118724831693e-08, + "Norm": 0.006532244384288788, + "requires_grad": false, + "data_name": "Functional.relu.0.backward.output.0.npy" + } + ] + } + } +} +``` + +### 3.3 mix 级别 + +mix级别的dump.json文件同时包括L0和L1级别的dump数据,文件格式与上述示例相同。 \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/docs/28.debugger_save_instruction.md b/debug/accuracy_tools/msprobe/docs/28.debugger_save_instruction.md index 6f4d519d5f61d5efaaffe54a1bde4f140b539f72..f4eb3591b200be8a829077af9944220e8610d88f 100644 --- a/debug/accuracy_tools/msprobe/docs/28.debugger_save_instruction.md +++ b/debug/accuracy_tools/msprobe/docs/28.debugger_save_instruction.md @@ -1,49 +1,63 @@ -# 单点保存工具 README +# 单点保存工具 ## 简介 -L0, L1, mix dump存在盲区,网络中的非api/module的输入输出不会被批量dump下来。单点保存提供类似np.save和print的功能和使用体验,可以保存指定的变量。同时针对大模型场景进行了增强,具备以下特性: + +L0, L1, mix级别的dump能力存在盲区,网络中的非API或module的输入输出不会被批量dump下来。单点保存提供类似np.save和print的功能和使用体验,可以保存指定的变量。同时针对大模型场景进行了增强,具备以下特性: + - 可保存变量的反向梯度结果。 - 能直接保存嵌套结构数据(如 list、dict),无需手动遍历。 -- 自动分 rank 保存。 +- 自动分 Rank 保存。 +- 可分 Step 保存数据。 - 多次调用时会自动计数。 -- 可配置保存统计值或者张量。 +- 可配置保存统计值(MindSpore静态图暂不支持)或者张量。 +- 支持异步保存。 + +单点保存工具的使用过程中可能会涉及到工具跨文件使用的场景,具体使能方式见[跨文件采集数据](./05.data_dump_PyTorch.md#24-跨文件采集数据)。 ## 支持场景 -仅支持 PyTorch 与 MindSpore 的动态图场景。 -## 使能方式 +## 动态图场景(Pytorch&MindSpore) -### 配置文件说明 +### 使能方式 -通用配置: +#### 配置文件说明 -| 参数 | 解释 | 是否必选 | -| -------- |-------------------------------------------| -------- | -| task | dump 的任务类型,str 类型。 单点保存场景仅支持传入"statistics", "tensor"。 | 是 | -| level | dump 级别,str 类型,根据不同级别采集不同数据。单点保存场景传入"debug"。 | 是 | -| dump_path | 设置 dump 数据目录路径,str 类型。细节详见[通用配置说明](./02.config_introduction.md#11-通用配置) | 是 | -| rank | 指定对某张卡上的数据进行采集,list[Union[int, str]] 类型。细节详见[通用配置说明](./02.config_introduction.md#11-通用配置) | 否 | +通用配置 (细节详见[通用配置说明](./02.config_introduction.md#11-通用配置) ): + +| 参数 | 解释 | 是否必选 | +| ---------- | ------------------------------------------------------------------------------------------------------- | -------- | +| task | dump 的任务类型,str 类型。 单点保存场景仅支持传入"statistics", "tensor"。 | 是 | +| level | dump 级别,str 类型,根据不同级别采集不同数据。单点保存场景传入"debug"。 | 是 | +| dump_path | 设置 dump 数据目录路径,str 类型。 | 是 | +| rank | 指定对某张卡上的数据进行采集,list[Union[int, str]] 类型。 | 否 | +| step | 指定采集某个 Step 的数据,list[Union[int, str]] 类型。 | 否 | +| async_dump | 异步 dump 开关,bool 类型。该模式下,summary_mode 不支持 md5 值,也不支持复数类型 tensor 的统计量计算。 | 否 | "statistics" 任务子配置项: -| 参数 | 解释 | 是否必选 | -| -------- |-------------------------------------------| -------- | -| summary_mode | 控制 dump 文件输出的模式,str 类型。支持传入"statistics", "md5"。 细节详见[statistics任务子配置项说明](./02.config_introduction.md#12-task-配置为-statistics) | 否 | + +| 参数 | 解释 | 是否必选 | +| ------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------- | -------- | +| summary_mode | 控制 dump 文件输出的模式,str 类型。支持传入"statistics", "md5"。 细节详见[statistics任务子配置项说明](./02.config_introduction.md#12-task-配置为-statistics) | 否 | "tensor" 任务无子配置项。 -### 接口调用说明 +#### 接口调用说明 -调用PrecisionDebugger.save,传入需要保存的变量,指定变量名称以及是否需要保存反向数据。接口入参说明详见[pytorch单点保存接口](./05.data_dump_PyTorch.md#19-save),[mindspore单点保存接口](./06.data_dump_MindSpore.md#615-save) +调用PrecisionDebugger.save,传入需要保存的变量,指定变量名称以及是否需要保存反向数据。接口入参说明详见[PyTorch单点保存接口](./05.data_dump_PyTorch.md#19-save),[MindSpore单点保存接口](./06.data_dump_MindSpore.md#615-save) -### 实例(以pytorch场景为例) +#### 实例 +(以PyTorch场景为例,MindSpore场景只需要从msprobe.mindspore模块导包即可) 配置文件 + ```json { "task": "statistics", "dump_path": "./dump_path", "rank": [], + "step": [], "level": "debug", + "async_dump": false, "statistics": { "summary_mode": "statistics" } @@ -51,9 +65,10 @@ L0, L1, mix dump存在盲区,网络中的非api/module的输入输出不会被 ``` 初始化 + ```python # 训练启动py脚本 -from mindspore.pytorch import PrecisionDebugger +from msprobe.pytorch import PrecisionDebugger debugger = PrecisionDebugger("./config.json") for data, label in data_loader: # 执行模型训练 @@ -62,9 +77,10 @@ for data, label in data_loader: ``` 初始化(无配置文件) + ```python # 训练启动py脚本 -from mindspore.pytorch import PrecisionDebugger +from msprobe.pytorch import PrecisionDebugger debugger = PrecisionDebugger(dump_path="dump_path", level="debug") for data, label in data_loader: # 执行模型训练 @@ -72,23 +88,201 @@ for data, label in data_loader: ``` -调用保存接口 +调用保存接口示例(以PyTorch代码为例,MindSpore使用方法相同) + +```python +import torch +import torch.nn as nn +import torch.nn.functional as F + +from msprobe.pytorch import PrecisionDebugger, seed_all +# 在模型训练开始前实例化PrecisionDebugger +debugger = PrecisionDebugger(dump_path="dump_path", level="debug") + +# 定义网络 +class ModuleOP(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear_1 = nn.Linear(in_features=8, out_features=4) + self.linear_2 = nn.Linear(in_features=4, out_features=2) + + def forward(self, x): + x1 = self.linear_1(x) + x2 = self.linear_2(x1) + debugger.save(x2, "x2", save_backward=True) # 调用save接口 + r1 = F.relu(x2) + return r1 + +if __name__ == "__main__": + module = ModuleOP() + + x = torch.randn(10, 8) + out = module(x) + loss = out.sum() + loss.backward() +``` + +分step保存数据(以PyTorch代码为例,MindSpore使用方法相同) + +```python +import torch +import torch.nn as nn +import torch.nn.functional as F + +from msprobe.pytorch import PrecisionDebugger +# 在模型训练开始前实例化PrecisionDebugger +debugger = PrecisionDebugger(dump_path="dump_path", level="debug") + +# 定义网络 +class ModuleOP(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear_1 = nn.Linear(in_features=8, out_features=4) + self.linear_2 = nn.Linear(in_features=4, out_features=2) + + def forward(self, x): + x1 = self.linear_1(x) + x2 = self.linear_2(x1) + debugger.save(x2, "x2", save_backward=True) # 调用save接口 + r1 = F.relu(x2) + return r1 + +if __name__ == "__main__": + module = ModuleOP() + train_iter = 10 + for i in range(train_iter): + x = torch.randn(10, 8) + out = module(x) + loss = out.sum() + loss.backward() + debugger.step() # 调用debugger.step用于区分step保存 + +``` + +## 静态图场景(MindSpore) + +### 使能方式 + +### 接口说明 + +工具提供三个对外接口用于保存训练过程中的数据: + +| 接口名称 | 功能描述 | 支持设备 | MindSpore版本 | 使用场景说明 | +| --------- | ------------------------ | -------------- | ------------- | ---------------------------------------------- | +| save | 保存正向传播的tensor数据 | Ascend/GPU/CPU | >= 2.6.0 | 图模式下仅支持Ascend,PyNative模式下支持全平台 | +| save_grad | 保存反向传播的梯度数据 | Ascend/GPU/CPU | >= 2.6.0 | 图模式下仅支持Ascend,PyNative模式下支持全平台 | +| step | 更新训练步数 | Ascend/GPU/CPU | >= 2.6.0 | 控制数据保存的step目录 | + +### 详细接口定义 + +#### 1. save 接口 + +```python +save(save_dir: str, name: str, data: Union[Tensor, List, Tuple, Dict]) +``` + +**参数说明**: + +- `save_dir`: 数据保存目录路径 +- `name`: 数据标识名称(将作为文件名前缀) +- `data`: 支持多种数据类型: + - `mindspore.Tensor` 单个张量 + - `List/Tuple/Dict` 嵌套结构(会自动展开保存) + +**使用示例**: + ```python -# 训练过程中被调用py文件 -from mindspore.pytorch import PrecisionDebugger -dict_variable = {"key1": "value1", "key2": [1, 2]} -PrecisionDebugger.save(dict_variable, "dict_variable", save_backward=False) +from msprobe.mindspore import save + +class Net(nn.Cell): + def construct(self, x): + save("./dump_data", 'input', x) # 保存输入数据 + return x * 2 +``` + +#### 2. save_grad 接口 +```python +save_grad(save_dir: str, name: str, data: Tensor) -> Tensor ``` +**参数说明**: + +- `save_dir`: 梯度保存目录路径 +- `name`: 梯度标识名称(将作为文件名前缀) +- `data`: 必须是 `mindspore.Tensor`类型 + +**特别注意**: + +- 必须接收返回值并传回原计算图 +- 此操作不会影响计算精度 + +**使用示例**: + +```python +from msprobe.mindspore import save_grad + +class Net(nn.Cell): + def construct(self, x): + x = save_grad("./dump_data", 'grad', x) # 保存梯度数据 + return x * 2 +``` + +#### 3. step 接口 + +```python +step() +``` + +**功能说明**: + +- 递增训练步数计数器 +- 控制数据保存到不同的step目录(如step0/, step1/等) +- 如果不调用,所有数据会保存到同一个step目录 + +**使用示例**: + +```python +from msprobe.mindspore import save, step + +# 训练循环中 +for epoch in range(epochs): + train_one_epoch() + step() # 每个epoch后更新step +``` + + ## 输出结果 - * **"task" 配置为 "statistics" 场景** :在 dump 目录下会生成包含变量统计值信息的 `debug.json` 文件。 - * **"task" 配置为 "tensor" 场景** :除了在 dump 目录下生成包含变量统计值信息的 `debug.json` 文件外,还会在 dump 子目录 `dump_tensor_data` 中保存张量二进制文件,文件名称格式为 `{variable_name}{grad_flag}.{count}.tensor.{indexes}.{file_suffix}`。 - - variable_name: 传入save接口的变量名称。 - - grad_flag: 反向数据标识,反向数据为"_grad",正向数据为""。 - - count: 调用计数,多次以相同变量名称调用时的计数。 - - indexes: 索引,在保存嵌套结构数据时的索引。例如:嵌套结构为`{"key1": "value1", "key2": ["value2", "value3"]}`,"value2"的索引为"key2.0" - - file_suffix:文件后缀,pytorch场景为"pt",mindspore场景为"npy" +### 动态图场景(Pytorch&MindSpore) + +* **"task" 配置为 "statistics" 场景** :在 dump 目录下会生成包含变量统计值信息的 `debug.json` 文件。 + `debug.json` 中统计值的key命名格式为 `{variable_name}{grad_flag}.{count}.debug`。 +* **"task" 配置为 "tensor" 场景** :除了在 dump 目录下生成包含变量统计值信息的 `debug.json` 文件外,还会在 dump 子目录 `dump_tensor_data` 中保存张量二进制文件,文件名称格式为 `{variable_name}{grad_flag}.{count}.debug.{indexes}.{file_suffix}`。 + + - variable_name: 传入save接口的变量名称。 + - grad_flag: 反向数据标识,反向数据为"_grad",正向数据为""。 + - count: 调用计数,多次以相同变量名称调用时的计数。 + - indexes: 索引,在保存嵌套结构数据时的索引。例如:嵌套结构为 `{"key1": "value1", "key2": ["value2", "value3"]}`,"value2"的索引为"key2.0"。 + - file_suffix:文件后缀,PyTorch场景为"pt",MindSpore场景为"npy"。 + +### 静态图场景(MindSpore) + +在指定目录 `save_dir`下生成 `{step}/{rank}`目录,目录下生成指定 `{name}`的npy文件,如果是save_grad接口调用,则会生成 `{name}_grad`的npy文件。 +如 `save("./test_dump", 'x', x)` -> `./test_dump/step0/rank0/x_float32_0.npy`。 +或如 `z = save_grad("./test_dump", 'z', z)` -> `./test_dump/step0/rank0/z_grad_float32_0.npy`。 + +结构如下: + +``` +./save_dir/ + ├── step0/ + │ ├── rank0/ + │ │ ├── x_float32_0.npy # save保存的正向数据 + │ │ └── z_grad_float32_0.npy # save_grad保存的梯度数据 + ├── step1/ + │ ├── rank0/ + │ │ ├── ... +``` diff --git a/debug/accuracy_tools/msprobe/docs/28.kernel_dump_MindSpore.md b/debug/accuracy_tools/msprobe/docs/28.kernel_dump_MindSpore.md index 6b8cc558aa22526158033cfb35f31203d8b04278..4988586c0568b391739f7c14f1a9452461f1a6f1 100644 --- a/debug/accuracy_tools/msprobe/docs/28.kernel_dump_MindSpore.md +++ b/debug/accuracy_tools/msprobe/docs/28.kernel_dump_MindSpore.md @@ -1,4 +1,4 @@ -# MindSpore 场景的 kernel dump 说明 +# MindSpore 动态图场景的 kernel dump 说明 当使用 msprobe 数据采集功能时,level 配置为 "L2" 表示采集 kernel 层级的算子数据,仅支持昇腾 NPU 平台。 diff --git a/debug/accuracy_tools/msprobe/docs/29.data_dump_MSAdapter.md b/debug/accuracy_tools/msprobe/docs/29.data_dump_MSAdapter.md new file mode 100644 index 0000000000000000000000000000000000000000..6439a0a6eb1e5a12829c7f10ab0aa1baccac85d0 --- /dev/null +++ b/debug/accuracy_tools/msprobe/docs/29.data_dump_MSAdapter.md @@ -0,0 +1,235 @@ +# MSAdapter 场景的精度数据采集 + +MSAdapter 是一款 MindSpore 生态适配工具,可以将 PyTorch 训练脚本高效迁移至 MindSpore 框架执行,以实现在不改变原有 PyTorch 用户开发习惯的情况下,使得 PyTorch 代码能在昇腾上获得高效性能。 + +msprobe 工具主要通过在训练脚本内添加 dump 接口、启动训练的方式采集精度数据。 + +**注意**: + +- 为了正确识别 MSAdapter 场景,在导入 msprobe 工具前,需完成 torch 模块的导入。 + +- 因 MindSpore 框架自动微分机制的限制,dump 数据中可能会缺少原地操作模块/API 及其上一个模块/API 的反向数据。 + +本工具提供固定的 API 支持列表,若需要删除或增加 dump 的 API,可以在 msprobe/pytorch/hook_module/support_wrap_ops.yaml 文件内手动修改,如下示例: + +```yaml +functional: # functional为算子类别,找到对应的类别,在该类别下按照下列格式删除或添加API + - conv1d + - conv2d + - conv3d +``` + +删除 API 的场景:部分模型代码逻辑会存在 API 原生类型校验,工具执行dump操作时,对封装后的模型 API 可能与模型的原生 API 类型不一致,此时可能引发校验失败,详见《[FAQ](FAQ.md#33-异常情况)》中“异常情况”的第10和11条。 + +## 1. 工具安装 + +请参见[《msprobe 工具安装指南》](./01.installation.md)。 + +## 2 接口介绍 + +### 2.1 msprobe.mindspore.PrecisionDebugger + +**功能说明**:通过加载 dump 配置文件的方式来确定 dump 操作的详细配置。 + +**原型**: + +```Python +PrecisionDebugger(config_path=None, task=None, dump_path=None, level=None, step=None) +``` + +**参数说明**: + +1. config_path:指定 dump 配置文件路径,string 类型。参数示例:"./config.json"。未配置该路径时,默认使用 [config.json](../config.json) 文件的默认配置,配置选项含义可见 [config.json 介绍](./02.config_introduction.md)。 + +2. 其他参数与 [config.json](../config.json) 文件中的同名配置字段含义相同,具体可见 [config.json 介绍](./02.config_introduction.md)。当参数值非None时,优先级高于 [config.json](../config.json) 文件中的同名配置。 + +#### 2.1.1 start + +**功能说明**:启动精度数据采集。需要与 [**stop**](#212-stop) 接口一起添加在训练迭代的 for 循环内。 + +**原型**: + +```Python +start(model=None) +``` + +**参数说明**: + +1. model:指定需要采集 Module 级数据的模型,支持传入 torch.nn.Module、list[torch.nn.Module]或Tuple[torch.nn.Module] 类型,默认未配置。level 配置为 "L0" 或 "mix" 时,必须在该接口中配置该参数。API级别("L1" level)dump 时,传入 model 可以采集 model 内包含 primitive op 对象在内的所有 API 数据,若不传入 model 参数,则只采集非 primitive op 的 API 数据。 + +#### 2.1.2 stop + +**功能说明**:停止精度数据采集。在 **start** 接口调用之后的任意位置添加。若 **stop** 接口添加在反向计算代码之后,则会采集 **start** 和该接口之间的前反向数据。 +若 **stop** 接口添加在反向计算代码之前,则需要将 [**step**](#213-step) 接口添加到反向计算代码之后,才能采集 **start** 和该接口之间的前反向数据。 + +**注意**:**stop** 接口必须调用,否则可能导致精度数据落盘不全。 + +**原型**: + +```Python +stop() +``` + +#### 2.1.3 step + +**功能说明**:进行训练 step 数的自增,完成当前 step 所有数据的落盘并更新 dump 参数。在一个 step 训练结束的位置添加,且必须在 **stop** 接口之后的位置调用。该接口需要配合 **start** 和 **stop** 函数使用,尽量添加在反向计算代码之后,否则可能会导致反向数据丢失。 + +**原型**: + +```Python +step() +``` + +#### 2.1.4 forward_backward_dump_end + +**功能说明**:停止精度数据采集。与 **stop** 接口功能相同,该函数在将来会被移除,建议使用 **stop** 接口。 + +**原型**: + +```Python +forward_backward_dump_end() +``` + +#### 2.1.5 save + +**功能说明**:单点保存网络执行过程中正反向数值,并以统计值/张量文件落盘。 + +**原型**: +```python +save(variable, name, save_backward=True) +``` + +**参数说明**: +| 参数名称 | 参数含义 | 支持数据类型 | 是否必选| +| ---------- | ------------------| ------------------- | ------------------- | +| variable | 需要保存的变量 |dict, list, tuple, torch.tensor, int, float, str | 是 | +| name | 指定的名称 | str | 是 | +| save_backward | 是否保存反向数据 | boolean | 否 | + +### 2.2 msprobe.mindspore.seed_all + +**功能说明**:用于固定网络中的随机性和开启确定性计算。 + +**原型**: +```python +seed_all(seed=1234, mode=False, rm_dropout=True) +``` + +**参数说明**: + +1. seed: 随机性种子,默认值:1234,非必选。参数示例: seed=1000。该参数用于 random、numpy.random, mindspore.common.Initializer、mindspore.nn.probability.distribution的随机数生成以及 Python 中 str、bytes、datetime 对象的 hash 算法。 + +2. mode:确定性计算使能,可配置 True 或 False,默认值:False,非必选。参数示例:mode=True。该参数设置为 True 后,将会开启算子确定性运行模式与归约类通信算子(AllReduce、ReduceScatter、Reduce)的确定性计算。注意:确定性计算会导致 API 执行性能降低,建议在发现模型多次执行结果不同的情况下开启。 + +3. rm_dropout:控制 dropout 失效的开关。可配置 True 或 False,默认值:True,非必选。参数示例:rm_dropout=True。该参数设置为 True 后,将会使 mindspore.ops.Dropout,mindspore.ops.Dropout2D,mindspore.ops.Dropout3D,mindspore.mint.nn.Dropout和mindspore.mint.nn.functional.dropout 失效,以避免因随机 dropout 造成的网络随机性。建议在采集数据前调用。 + +**注意**:通过 rm_dropout 控制 dropout 失效或生效需要在初始化 Dropout 实例前调用才能生效。 + +## 3 示例代码 + +以下为添加了 msprobe 工具 dump 接口的示例训练脚本。 + +```python +import mindspore as ms +import torch +import torch.nn as nn +import torch.nn.functional as F + +# 导入工具的数据采集接口 +from msprobe.mindspore import PrecisionDebugger + +# 在模型训练开始前实例化PrecisionDebugger +debugger = PrecisionDebugger(config_path='./config.json') + + +# 定义网络 +class Net(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear1 = nn.Linear(in_features=8, out_features=4) + self.linear2 = nn.Linear(in_features=4, out_features=2) + + def forward(self, x): + x1 = self.linear1(x) + x2 = self.linear2(x1) + logits = F.relu(x2) + return logits + + +net = Net() + + +def train_step(inputs): + return net(inputs) + + +if __name__ == "__main__": + data = (torch.randn(10, 8), torch.randn(10, 8), torch.randn(10, 8)) + grad_fn = ms.value_and_grad(train_step, grad_position=0) + + for inputs in data: + # 开启数据 dump + debugger.start(model=net) + + out, grad = grad_fn(inputs) + + # 停止数据 dump + debugger.stop() + # 更新 step 信息 + debugger.step() +``` + +## 4 dump 结果文件介绍 + +训练结束后,工具将 dump 的数据保存在 dump_path 参数指定的目录下。目录结构示例如下: + +```lua +├── dump_path +│ ├── step0 +│ | ├── rank0 +│ | │ ├── dump_tensor_data +| | | | ├── Tensor.permute.1.forward.npy +| | | | ├── Functional.linear.5.backward.output.npy # 命名格式为{api_type}.{api_name}.{API调用次数}.{forward/backward}.{input/output}.{参数序号}, 其中,“参数序号”表示该API的第n个输入或输出,例如1,则为第一个参数,若该参数为list格式,则根据list继续排序,例如1.1,表示该API的第1个参数的第1个元素。 +| | | | ... +| | | | ├── Module.conv1.Conv2d.forward.0.input.0.npy # 命名格式为{Module}.{module_name}.{class_name}.{forward/backward}.{调用次数}.{input/output}.{参数序号}, 其中,“参数序号”表示该Module的第n个参数,例如1,则为第一个参数,若该参数为list格式,则根据list继续排序,例如1.1,表示该Module的第1个参数的第1个元素。 +| | | | ├── Module.conv1.Conv2D.forward.0.parameters.bias.npy # 模块参数数据:命名格式为{Module}.{module_name}.{class_name}.forward.{调用次数}.parameters.{parameter_name}。 +| | | | └── Module.conv1.Conv2D.parameters_grad.weight.npy # 模块参数梯度数据:命名格式为{Module}.{module_name}.{class_name}.parameters_grad.{parameter_name}。因为同一模块的参数使用同一梯度进行更新,所以参数梯度文件名不包含调用次数。 +| | | | # 当dump时传入的model参数为List[torch.nn.Module]或Tuple[torch.nn.Module]时,模块级数据的命名中包含该模块在列表中的索引index,命名格式为{Module}.{index}.*,*表示以上三种模块级数据的命名格式,例如:Module.0.conv1.Conv2d.forward.0.input.0.npy。 +│ | | ├── dump.json +│ | | ├── stack.json +│ | | └── construct.json +│ | ├── rank1 +| | | ├── dump_tensor_data +| | | | └── ... +│ | | ├── dump.json +│ | | ├── stack.json +| | | └── construct.json +│ | ├── ... +│ | | +| | └── rank7 +│ ├── step1 +│ | ├── ... +│ ├── step2 +``` +* `rank`:设备 ID,每张卡的数据保存在对应的 `rank{ID}` 目录下。非分布式场景下没有 rank ID,目录名称为 rank。 +* `dump_tensor_data`:保存采集到的张量数据。 +* `dump.json`: 保存 API 或 Module 前反向数据的统计量信息。包含 dump 数据的 API 名称或 Module 名称,各数据的 dtype、 shape、max、min、mean、L2norm(L2范数,平方根)统计信息以及当配置 summary_mode="md5" 时的 CRC-32 数据。具体介绍可参考[dump.json文件说明](./27.dump_json_instruction.md#3-msadapter-场景下的-dumpjson-文件)。 +* `stack.json`:API/Module 的调用栈信息。 +* `construct.json`:分层分级结构,level 为 L1 时,construct.json 内容为空。 + + +当 task 为 tensor 时,dump 过程中,npy 文件在对应算子或者模块被执行后就会落盘,而 json 文件则需要在正常执行 PrecisionDebugger.stop() 后才会写入完整数据。因此如果程序异常终止,终止前被执行算子的相关 npy 文件得以保存,但 json 文件中的数据可能丢失。 + +其中 rank 为设备上各卡的 ID,每张卡上 dump 的数据会生成对应 dump 目录。非分布式场景下没有 rank ID,目录名称为 rank。 + +npy 文件名的前缀含义如下: + +| 前缀 | 含义 | +| ----------- | ---------------------------- | +| Tensor | torch.Tensor API数据 | +| Torch | torch API数据 | +| Functional | torch.nn.functional API数据 | +| NPU | NPU 亲和API数据 | +| Distributed | torch.distributed API数据 | +| Jit | 被 "jit" 装饰的模块或函数数据 | +| Module | torch.nn.Module 类(模块)数据 | \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/docs/30.overflow_check_MSAdapter.md b/debug/accuracy_tools/msprobe/docs/30.overflow_check_MSAdapter.md new file mode 100644 index 0000000000000000000000000000000000000000..e963a60e8361be2569e7f85ee0d97df9194d6d91 --- /dev/null +++ b/debug/accuracy_tools/msprobe/docs/30.overflow_check_MSAdapter.md @@ -0,0 +1,31 @@ +# MSAdapter 场景的溢出检测 + +msprobe 工具提供 MSAdapter 场景下的溢出检测功能。其检测对象为 **API** 级别(除 Primitive 和 Jit 类 API)或**模块**级别,分别对应 config.json 配置中的 **"L1"** 、**"L0"** level。 + +需要注意,本工具仅支持在 INF/NAN 模式a下进行溢出检测。INF/NAN 模式的使能方式如下: + +```Shell +# 使能 CANN 侧 INF/NAN 模式 +export INF_NAN_MODE_ENABLE=1 +# 使能 MindSpore 框架侧 INF/NAN 模式 +export MS_ASCEND_CHECK_OVERFLOW_MODE="INFNAN_MODE" +``` + +**a**:在处理浮点数计算溢出问题时,NPU 当前支持两种溢出模式:INF/NAN 模式与饱和模式。INF/NAN 模式遵循 IEEE 754 标准,根据定义输出 INF/NAN 的计算结果。与之对应的饱和模式在计算出现溢出时,饱和为浮点数极值(+-MAX)。对于 CANN 侧配置,Atlas 训练系列产品,默认为饱和模式,且不支持使用 INF/NAN 模式;Atlas A2训练系列产品,默认为 INF/NAN 模式,且不建议使用饱和模式。对于 MindSpore 框架侧配置,仅支持对 Atlas A2 训练系列产品进行设置,默认为 INF/NAN 模式。CANN 侧 与 MindSpore 框架侧配置须一致。 + +溢出检测任务的配置示例见["**MindSpore 动态图场景 task 配置为 overflow_check**"](./03.config_examples.md#33-task-配置为-overflow_check)小节。 + + +## 1 接口介绍 + +溢出检测功能提供的接口与数据采集任务一致,详见 MSAdapter 场景的精度数据采集中的["**2 接口介绍**"](./29.data_dump_MSAdapter.md#2-接口介绍)小节。 + +需要注意,目前暂不支持 "L1" level 下 primitive op 的溢出检测。 + +## 2 示例代码 + +溢出检测功能使用方式与数据采集任务一致,详见 MSAdapter 场景的精度数据采集中的["**3 示例代码**"](./29.data_dump_MSAdapter.md#3-示例代码)小节。 + +## 3 溢出检测结果文件介绍 + +溢出检测结果文件目录结构与含义与数据采集任务一致,但仅保存溢出 API 或 模块 的真实数据或统计信息。详见 MSAdapter 场景的精度数据采集中的["**4 dump 结果文件介绍**"](./29.data_dump_MSAdapter.md#4-dump-结果文件介绍)小节。 \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/docs/31.config_check.md b/debug/accuracy_tools/msprobe/docs/31.config_check.md new file mode 100644 index 0000000000000000000000000000000000000000..eb9403aa5676c08f37dd26c54f63f2a663cef507 --- /dev/null +++ b/debug/accuracy_tools/msprobe/docs/31.config_check.md @@ -0,0 +1,107 @@ +# config check + +## 介绍 + +该工具主要适用于对比两个环境下可能影响训练精度的配置差异,支持mindspore和pytorch两个框架,包括: + +- 环境变量 +- 三方库版本 +- 训练超参 +- 权重 +- 数据集 +- 随机操作 + + +## 安装教程 + +参见 msprobe [安装教程](./01.installation.md) + +## 使用说明 + +用户需要在两个待比对的训练的环境上分别进行数据采集, 工具会采集两个环境下影响精度的配置,采集结果上传到同一机器进行比对。 + +### 数据采集 + +#### 静态数据采集 + +静态数据采集仅支持环境变量,三方库版本及训练超参采集,其中环境变量,三方库版本默认采集,训练超参采集需要用户传入启动训练的 shell 脚本路径或 yaml 配置文件, +支持多个输入,不传入表示不采集。 + +启动命令如下 +```shell +msprobe -f pytorch/mindspore config_check -d **.sh **.yaml -o output_path +``` +-f或--framework 代表训练框架,传入pytorch或mindspore,必选。 + +-d或--dump 代表数据采集模式,可传入启动训练的 shell 脚本路径或 yaml 配置文件路径,可选,不传入代表不采集。 + +-o或--output 代表输出路径,可选,默认为 config_check_pack.zip,必须以 `.zip` 后缀结尾。 + +#### 动态数据采集 + + +在训练流程执行到的第一个python脚本开始处插入如下代码: +``` +from msprobe.core.config_check import ConfigChecker +ConfigChecker.apply_patches(fmk) +``` + +说明: + +apply_patches:启动数据采集所需的各项patch,参数如下: + +- fmk:训练框架。可选 pytorch 和 mindspore ,不传默认为 pytorch。 + +在模型初始化好之后插入如下代码: +``` +from msprobe.core.config_check import ConfigChecker +ConfigChecker(model=model, shell_path="", output_zip_path="./config_check_pack.zip", fmk="pytorch") +``` + +说明: + +ConfigChecker:对模型挂上数据采集所需的hook,会在每次模型前向将要被执行的一刻进行数据采集。参数如下: + +- model:初始化好的模型。不传或缺省就不会采集权重和数据集。 +- shell_path:动态采集模式下支持 **megatron** 训练超参自动捕获,使用 **megatron** 时推荐不传入,其他情况下可传入训练脚本路径,类型为列表,传入一个或多个训练配置/启动脚本。不传或缺省就不会采集超参。 +- output_zip_path:输出zip包的路径,不传默认为"./config_check_pack.zip"。 +- fmk:当前是什么框架。可选 pytorch 和 mindspore ,不传默认为 pytorch。 + +采集完成后会得到一个zip包,里面包括各项[影响精度的配置](#介绍)。会分rank和step存储,其中step为micro_step。 + +在另一个环境上执行上述操作,得到另一个zip包 + +### 数据比对 + +将两个zip包传到同一个环境下,使用如下命令进行比对: + +```shell +msprobe -f pytorch config_check -c bench_zip_path cmp_zip_path -o output_path +``` + +-c或--compare 表示compare,数据对比,有两个参数。其中**bench_zip_path** 为标杆侧采集到的数据, **cmp_zip_path** 为待对比侧采集到的数据。 + +**output_path 里原有的比对结果会被覆盖**,不传默认为"./config_check_result", 在 **output_path** 里会生成2个目录和1个文件: +- bench:bench_zip_path里打包的数据。 +- cmp:cmp_zip_path里打包的数据。 +- result.xlsx:比对结果。里面会有多个sheet页,其中**summary**总览通过情况,其余页是具体检查项的详情。其中step为micro_step。 + +## 通过标准 + +以下五项检查通过: + +- 环境变量 +- 三方库版本 +- 训练超参 +- 权重 +- 数据集 + +这五项检查在**精度比对**前必须保证达成。 + + +## FAQ + +1. 在使用 MindSpeed-LLM 进行数据采集时,需要注意动态数据采集中的 [apply_patches](#动态数据采集) 函数需要在 MindSpeed-LLM +框架 pretrain_gpt.py 的 megatron_adaptor 函数导入之后执行。 + +2. 静态数据采集功能只能获取到系统中的环境变量,shell 脚本中解析的超参不支持复杂运算的数据还原,有类似问题时建议使用[动态采集方式](#动态数据采集)。 diff --git a/debug/accuracy_tools/msprobe/docs/32.ckpt_compare.md b/debug/accuracy_tools/msprobe/docs/32.ckpt_compare.md new file mode 100644 index 0000000000000000000000000000000000000000..0b9a070214653aed3c19cbad5e8888f0985ecf01 --- /dev/null +++ b/debug/accuracy_tools/msprobe/docs/32.ckpt_compare.md @@ -0,0 +1,69 @@ +# Checkpoint Compare + +## 介绍 +在模型训练过程中或结束后,可能保存一些检查点文件(checkpoint,简称ckpt)记录当前模型、优化器等训练状态, 工具支持比较两个不同的ckpt,评估模型相似度。 + +当前支持Megatron-LM、MindSpeed(PyTorch/MindTorch)的ckpt比较。支持TP、PP、EP、VPP模型并行;支持megatron.core、megatron.legacy、TransformerEngine的模型实现。 + + +## 安装教程 + +参见 msprobe [安装教程](./01.installation.md) + +## 使用说明 +Megatron、MindSpeed的ckpt加载依赖megatron,请确保megatron在python环境中或megatron在当前路径下。 + + +启动命令如下 +```shell +msprobe --framework pytorch config_check --compare path1 path2 -o output_path.json +``` + +| 参数名 | 解释 | 是否必选 | +|--------|-------|--------| +| -f 或 --framework | 深度学习框架,str类型。支持参数:pytorch,mindspore,注意:msadaptor场景传入mindspore。 | 是 | +| -c 或 --compare | 2个ckpt的路径 | 是 | +| -o 或 --output | 比对结果输出路径,默认为 ./ckpt_similarity.json。输出路径存在时将报错终止。 | 否 | + +Megatron-LM 和 MindSpeed 的 ckpt 目录结构如下: + +```txt +directory_name/ +├── iter_0000005/ # 某个iteration时的ckpt目录。 +│ └── mp_rank_xx_xxx/ # 单个rank的ckpt目录,xx_xxx为模型并行索引。 +│ └── model_optim_rng.pt # 包含模型参数、随机状态等的PyTorch binary文件。 +├── iter_0000010/ +├── latest_checkpointed_iteration.txt # 记录最后一个保存的ckpt的纯文本文件。 +``` + +对于--compare参数的两个路径,为directory_name时,工具通过latest_checkpointed_iteration.txt自动选择latest checkpoint进行比对. 为directory_name/iter_xxxxxxx时, 工具使用指定iteration的ckpt进行比对。暂不支持单个rank的比对。 + +## 输出示例 +Checkpoint比对结果以json文件输出,内容如下示例: +```json +{ + "decoder.layers.0.input_layernorm.weight": { + "l2": 0.0, + "cos": 0.999999, + "numel": 128, + "shape": [ + 128 + ] + }, + "decoder.layers.0.pre_mlp_layernorm.weight": { + "l2": 0.012, + "cos": 0.98, + "numel": 128, + "shape": [ + 128 + ] + } +} +``` + +统计量 | 解释 | +|-------|---------| +| l2 | 欧式距离,$\|\|a-b\|\|_2$ | +| cos | 余弦相似度, $\frac{}{\|\|a\|\|_2\|\|b\|\|_2}$ | +| numel | 参数的元素个数 | +| shape | 参数的shape | \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/docs/33.generate_operator_MindSpore.md b/debug/accuracy_tools/msprobe/docs/33.generate_operator_MindSpore.md new file mode 100644 index 0000000000000000000000000000000000000000..32dcf145fc96050559fb9316ce80a59674d8b4c9 --- /dev/null +++ b/debug/accuracy_tools/msprobe/docs/33.generate_operator_MindSpore.md @@ -0,0 +1,181 @@ +# 单算子API自动生成脚本 + +## 1 简介 + +单算子API自动生成脚本通过提取dump数据中的可疑算子,对其进行单API复现,输出单API精度的比对结果。具体而言,该工具可以从dump数据中提取可疑API的前反向信息,根据前反向数据生成单API的前反向过程,最后通过**新精度标准比对法**a将 NPU 和 CPU 的结果进行比对,从而给出不同比对方法下的比对结果。本工具支持**随机生成模式和真实数据模式**b。 + +a. 在生成单API脚本时可以选择由工具构造随机数获得 dump 数据或选择真实输入的数据进行单API复现。随机生成模式(对应 task: "statistics")执行效率高,可以快速获得结果,但数据精度低,只能大致判断精度问题;真实数据模式(对应 task: "tensor")执行效率略低于随机生成模式,但是数据精度高,可以准确判断精度问题。 + +## 2 使用方式 + +### 前提 +1. 安装 msprobe。详见[ msprobe 安装](./01.installation.md)章节。 +2. 已完成对训练过程的dump,获得dump.json文件。 + [MindSpore 场景下的数据采集](./06.data_dump_MindSpore.md)章节或[Msadapter 场景下的数据采集](./29.data_dump_MSAdapter.md)章节,注意需要配置 level="L1"。 + +3. 发现某个算子疑似存在精度问题,并得知算子名,如Mint.split.1、Functional.softmax.3、Tensor.add.0、Torch.matmul.5等,要求API入参与torch场景相同。 + +4.(可选)当需要使用Msadapter时,由于需要环境中同时存在 Torch 与 Msadapter,所以只支持在**安装原生Torch**的场景下通过export PYTHONPATH="xx/msadapter/build/lib"等通过**环境变量使能Msadapter的方式**的环境中进行预检,预检工具能够自动索引得到所需的 Torch 与 Msadapter环境,环境安装详细参考:[msadapter官网](https://gitee.com/mindspore/msadapter)(该网站需要申请权限方可访问)。 + +### 2.1 dump.json 示例(由dump生成) +``` +{ + "task": "statistics", + "level": "L1", + "framework": "mindtorch", + "dump_data_dir": null, + "data": { + "Tensor.reshape.779.forward": { + "input_args": [ + { + "type": "mindspore.Tensor", + "dtype": "BFloat16", + "shape": [ + 8192, + 896 + ], + "Max": 0.62890625, + "Min": -0.78515625, + "Mean": 0.00035858154296875, + "Norm": 105.0 + }, + { + "type": "int", + "value": -1 + }, + { + "type": "int", + "value": 896 + } + ], + "input_kwargs": {}, + "output": [ + { + "type": "mindspore.Tensor", + "dtype": "BFloat16", + "shape": [ + 8192, + 896 + ], + "Max": 0.62890625, + "Min": -0.78515625, + "Mean": 0.00035858154296875, + "Norm": 105.0 + } + ] + }, + "Tensor.reshape.779.backward": { + "input": [ + { + "type": "mindspore.Tensor", + "dtype": "BFloat16", + "shape": [ + 8192, + 896 + ], + "Max": 0.0, + "Min": 0.0, + "Mean": 0.0, + "Norm": 0.0 + } + ], + "output": [ + { + "type": "mindspore.Tensor", + "dtype": "BFloat16", + "shape": [ + 8192, + 896 + ], + "Max": 0.0, + "Min": 0.0, + "Mean": 0.0, + "Norm": 0.0 + } + ] + } + } + } +``` + +### 2.2 配置config_op.json +单API复现参数配置如下(以复现softmax算子为例): +``` +{ + "dump_json_path": "./dump.json", + "api_name": "Tensor.reshape.779", + "extract_api_path": "Tensor.reshape.779", + "propagation": "backward", + "data_mode": "random_data", + "random_seed": 42, + "iter_times": 1 +} +``` +**配置文件参数说明** + + | 参数名称 | 解释 | + | ---------------------------- |-------------------------------------------------------------------------------------------------------------------------------------------------------| + | dump_json_path | dump.json的文件路径,包含所有dump算子的信息;如果已经提取了可疑算子并保存可以不指定。 | + | api_name | 算子名(目前MindSpore支持类型包括:Mint,Tensor,Msadapter支持类型包括:Tensor,Functional,Torch类中可自动求导api),如Mint.split.1,Functional.softmax.3、Tensor.add.0、Torch.matmul.5等。 | + | extract_api_path | 提取可疑算子的json文件路径 | + | propagation | 选择复现算子的forward还是backward,默认为forward | + | data_mode | 选择复现算子的随机数据(random_data)还是真实数据(real_data)模式,默认为random_data | + | random_seed | 仅random_data模式有效,表示手动设定的随机种子,默认为42 | + | iter_times | 仅random_data模式有效,表示单API运行的次数,由于安全相关原因,最大支持设置为1000 | + + ### 2.3 运行命令生成单API脚本 +config_op.json配置好后,运行如下命令: +``` +msprobe -f mindspore op_generate -i ./config_op.json -o ./ +``` + +**参数说明** + | 参数名称 | 解释 | 是否必选 | + | ---------------------------- | ------------------------------------------------------------ | ---------------------------------- | + | -i 或 --config_input | config_op.json的路径 | 是 | + | -o 或 --api_output_path | 单API脚本的输出路径 | 是 | + +### 2.4 运行单API脚本 + 运行完op_generator.py后,会在指定路径下生成api_name.py的单API脚本,例如Mint.split.1.forward.py、Functional.softmax.3.backward.py、Tensor.add.0.forward.py、Torch.matmul.5.backward.py + +运行单API脚本即可获得不同比对方法下的比对结果 +``` +python api_name.py +``` + +**运行结果说明** + +单算子脚本生成到路径`./op_result_output`的 `accuracy_checking_result_{timestamp}.csv` 和 `accuracy_checking_details_{timestamp}.csv` 文件内容详情如下: + +`accuracy_checking_details_{timestamp}.csv` + +| 字段 | 含义 | +| ------------------- | ------------------------------------------------------------ | +| API Name | API 名称。 | +| Bench Dtype | 标杆数据的 API 数据类型。 | +| Tested Dtype | 被检验数据的 API 数据类型。 | +| Shape | API 的 Shape 信息。 | +| Cosine | 被检验数据与标杆数据的余弦相似度。 | +| MaxAbsErr | 被检验数据与标杆数据的最大绝对误差。 | +| MaxRelativeErr | 被检验数据与标杆数据的最大相对误差。 | +| Status | API 预检通过状态,pass 表示通过测试,error 表示未通过。 | +| Message | 提示信息。 | + +注意:PyTorch 无法对 dtype 为整数类型的 tensor 进行反向求导,而 MindSpore 支持。反向过程的预检仅比较 dtype 为浮点型的输出。 + +`accuracy_checking_result_{timestamp}.csv` + +| 字段 | 含义 | +| --------------------- | ----------------- | +| API Name | API 名称。 | +| Forward Test Success | 前向 API 是否通过测试,pass 为通过,error 为错误。 | +| Backward Test Success | 反向 API 是否通过测试,pass 为通过,error 为错误,如果是空白的话代表该 API 没有反向输出。 | +| Message | 提示信息。 | + +Forward Test Success 和 Backward Test Success 是否通过测试是由 `accuracy_checking_details_{timestamp}.csv` 中的余弦相似度、最大绝对误差判定结果决定的。具体规则详见 [3.1 API 预检指标](#31-api-预检指标)。 +需要注意的是 `accuracy_checking_details_{timestamp}.csv` 中可能存在一个 API 的前向(反向)有多个输出,那么每个输出记录一行,而在 `accuracy_checking_result_{timestamp}.csv` 中的结果需要该 API 的所有结果均为 pass 才能标记为 pass,只要存在一个 error 则标记 error。 + +### 3.1 API 预检指标 + + - API 预检指标是通过对 `accuracy_checking_details_{timestamp}.csv` 中的余弦相似度、最大绝对误差的数值进行判断,得出该 API 是否符合精度标准的参考指标。 + - 余弦相似度大于 0.99,并且最大绝对误差小于 0.0001,标记“pass”,否则标记为“error”。 \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/docs/34.RL_collect.md b/debug/accuracy_tools/msprobe/docs/34.RL_collect.md new file mode 100644 index 0000000000000000000000000000000000000000..0767b6ee7dd071dae540c5caebf9c47606d99d55 --- /dev/null +++ b/debug/accuracy_tools/msprobe/docs/34.RL_collect.md @@ -0,0 +1,101 @@ +# 强化学习数据采集 + +## 介绍 +在强化学习训练过程中,往往存在多个模型(actor、reward、reference)和两个阶段(推理、训练),问题定界困难。 + +本工具提供一种灵活存储强化学习训练过程中关键阶段性数据的能力,并支持对比两次采集的关键数据,以支持问题快速定界。 + +常用关键数据示例:prompt、response、reward、log_prob、ref_log_probe、old_log_probe、kl_loss。 + + +## 安装教程 + +参见 msprobe [安装教程](./01.installation.md)。 + +## 使用说明 + +### 数据采集 + +用户识别脚本中需要采集数据的地方,然后通过插入代码的方式采集关键数据。 + +当确定需要采集数据的地方,例如response,可以按如下方式对数据进行存储: +``` +from msprobe.core import SingleSave +SingleSave(dump_path="./dump_path", fmk="pytorch") +SingleSave.save(data={"response": response}) +``` +SingleSave:初始化一个单点存储模块,参数如下: +- dump_path:输出路径,"./dump_path"为输出路径,没有默认值,需要自己配置。 +- fmk:当前训练框架。可选"pytorch"或者"mindspore",默认"pytorch"。 + +SingleSave.save:将需要存储的数据落盘,数据会分step分rank存储,并支持后续比对,参数如下: +- data:是需要存储的数据,是一批key和value组成的dict,key是用户自己决定的数据名称,会成为后续落盘数据中的目录,value是代码里需要存储的变量,支持tensor、tuple、list。比如其中"response"是可以任意指定的key,response是训练过程中的真实tensor变量。 + +也支持一次性存储多个数据: +``` +from msprobe.core import SingleSave +SingleSave("./dump_path", fmk="pytorch") +SingleSave.save({ + "prompt": prompt, + "response": response + }) +``` + +### 配置保存 + +当确定需要采集数据配置json的地方,可以按如下方式对配置进行存储: +``` +from msprobe.core import SingleSave +SingleSave("./dump_path") +SingleSave.save_config(data=configurations_json) +``` + +SingleSave.save_config:将配置信息进行存储,入参如下: +- data:需要存储的配置信息,需要是一个dict。 + +采集到的数据目录结构如下: +```txt +dump_path/ +├── data/ # 固定为data +│ └── response/ # 关键数据名称,来自SingleSave.save的时候的key +│ └── step0/ # step数 +│ └── rank0/ # rank数 +│ └── micro_step0/ #micro_step数 +| └── response0.npy #存储的关键数据的真实npy文件 +| └── response0.json #存储的关键数据的统计量文件,包括tensor的最大、最小、均值、norm、shape +├── configurations.json # 配置json文件 +``` + +### 结果比对 + +两次采集数据之后得到dump_path1和dump_path2,可以创建一个比对脚本,例如compare.py,将两次训练的dump_path传入: +``` +from msprobe.core import SingleComparator +SingleComparator.compare( + dir1="dump_path1", + dir2="dump_path2", + output_path="output_path") +``` +SingleComparator.compare:将两次采集的数据进行比对,参数如下: +- dir1:需要比对的其中一个dump_path,对应SingleSave的dump_path。 +- dir2:需要比对的令一个dump_path,对应SingleSave的dump_path。 +- output_path:比对结果输出路径,默认为"./msprobe_compare_output"。 +- num_processes:比对的时候起多少个进程,默认为8。 +会在output_path下对每种关键数据都生成excel结果表格,比如response.xlsx,形式为关键数据的名字加上.xlsx后缀。 + +表格会体现每一个对应tensor的差异,解释: + +表头 | 解释 | +|-------|---------| +| step | 训练步数 | +| rank | 卡号 | +| micro_step | 梯度累计步数 | +| id | 参数的shape | +| shape1 | dump_path1中的数据形状 | +| shape2 | dump_path2中的数据形状 | +| 相同元素百分比 | 元素相同的个数占总元素个数的百分比 | +| 首个不匹配元素索引 | 首个匹配不上的元素是第几个 | +| 最大绝对误差 | 最大绝对误差 | +| 最大相对误差 | 最大相对误差 | +| 误差在千分之一内元素占比 | 误差在千分之一内元素个数占总元素个数的百分比 | +| 误差在百分之一内元素占比 | 误差在百分之一内元素个数占总元素个数的百分比 | diff --git a/debug/accuracy_tools/msprobe/docs/35.nan_analyze.md b/debug/accuracy_tools/msprobe/docs/35.nan_analyze.md new file mode 100644 index 0000000000000000000000000000000000000000..37b0be8bb4fba80e30d93eee4de99fd19e9bb47b --- /dev/null +++ b/debug/accuracy_tools/msprobe/docs/35.nan_analyze.md @@ -0,0 +1,73 @@ +# 整网首个溢出节点分析 + +## 介绍 +在分析inf、nan的场景下,会采集多个rank下的多个step的dump数据,前面出现的异常会传播到同rank后续的节点,并通过通信算子传播到其他rank的后续节点中,因此如何分析首个nan出现的节点位置尤为重要。 + +通过nan_analyze工具可以对pytorch的dump数据进行分析。在多卡场景下,检测到每张卡中产生inf/nan的节点。若是经过通信导致的inf/nan,可以分析并找出首个产生inf/nan的rank和节点。 + +## 安装教程 + +参见 msprobe [安装教程](./01.installation.md)。 + +## 使用说明 + +当前仅支持分析pytorch的dump数据。 + +### 采集数据 + +参见 [PyTorch 场景的精度数据采集](./05.data_dump_PyTorch.md)。 + +### 执行命令 + +```commandline +msprobe -f pytorch nan_analyze -i dump_step_path -o output_dir_path +``` + +| 参数 | 说明 | +|--------------------|---------------------------------------------| +| -f 或 --framework | 指定训练框架。pytorch。必选。 | +| -i 或 --input_path | dump数据的目录。需指定到step层级,如`-i /xxx/dump/step0/` | +| -o 或 --output_path | 输出文件的目录,可选,不填时默认在当前目录下创建 \"./output/" 目录。 | + +### 输出文件介绍 + +当日志打印 +``` +Cannot find any anomaly node, no need to generate analyze file. +``` +时,分析认为不存在异常节点,不生成分析文件。 + +存在异常节点时,生成`anomaly_analyze_{timestamp}.json`文件,结构为: +```json +{ + "rank_0": [ // 卡号 + { + "op_name": "Tensor.op_name.0.forward", // 节点名 + "data_info": { + "input_args": [], // input_args数据 + "input_kwargs": {}, // input_kwargs数据 + "output": [] // output数据 + }, + "construct_info": [], // 节点层级数据 + "stack_info": {} // 堆栈数据 + } + ] +} +``` + +## 异常判定 + +### 异常计算节点判定 +当某个计算节点的输入值正常,即Max或Min中不存在inf或nan,而输出值存在异常时认为从此节点开始产生了溢出,并有可能向后传递。 + +### 异常通信节点判定 +通信节点按照功能分为有向节点,如`send`, `recv`, `scatter`, `gather`, `broadcast`, `reduce`等,以及无向节点,如`all_gather`, `all_reduce`, `reduce_scatter`, `all_to_all`等。 + +对于有向节点,当src节点的input存在异常时,通常认为传入的数据中本身就存在异常,因此考虑异常节点发生在src节点所在rank的上一个或多个计算节点中;当src节点的input正常而output存在异常值,或dst节点的output存在异常值时,考虑是通信节点本身的操作产生了异常数据。 + +对于无向节点,当节点input存在异常时,认为传入的数据中本身就存在异常,因此考虑异常节点发生在src节点所在rank的上一个或多个计算节点中;当input正常而output异常时,考虑是通信节点本身的操作产生了异常数据。 + +### 顺序判定 +对于相连接的有向通信算子,认为src节点的异常发生早于dst节点;对于无向通信算子,认为异常是同时发生的。 + +对于计算节点按照dump的顺序排序。 diff --git a/debug/accuracy_tools/msprobe/docs/36.calculation_result_change.md b/debug/accuracy_tools/msprobe/docs/36.calculation_result_change.md new file mode 100644 index 0000000000000000000000000000000000000000..8fa97d8e433725f20ed595fb5161dc2885b0f132 --- /dev/null +++ b/debug/accuracy_tools/msprobe/docs/36.calculation_result_change.md @@ -0,0 +1,75 @@ +# 模型计算结果改变原因分析 + +## 介绍 +在模型训练场景下,使用seed_all接口同时固定随机性和打开计算,通信确定性计算,是能够保证模型跑两次得到的loss和gnorm结果完全一样。如果出现使能工具后loss或者gnorm出现偏差,可能是以下原因导致。 + +## 工具引入同步导致计算结果变化 +工具采集统计量数据时,会涉及到将device上的tensor计算后的统计量信息通过item的时候传到cpu侧,再落盘到json文件中,item操作是一个同步的操作,可能会导致模型的计算结果出现变化。**一般的现象就是模型计算出现了NaN,但加了工具后问题不复现了。** + +ASCEND_LAUNCH_BLOCKING 是一个环境变量,用于控制在 PyTorch 训练或在线推理场景中算子的执行模式。当设置为“1”时,算子将采用同步模式运行。因此如果出现加工具计算结果变化,可以设置ASCEND_LAUNCH_BLOCKING 为1,如果结果一样发生了变化,则说明是由于同步引起的结果改变。这个时候需要复现问题现象完成问题定位,推荐使用msprobe工具的异步dump功能,具体使用方式可查看[config配置](02.config_introduction.md)中的async_dump字段。 + +## Hook机制导致计算结果变化 + +pytorch/mindspore的hook机制会导致某些特殊场景下梯度计算的累加序产生变化,从而影响模型反向计算的gnorm结果。具体代码示例如下: +```python +import random, os +import numpy as np +import torch +from torch import nn + + +class Net(nn.Module): + def __init__(self) -> None: + super().__init__() + self.ln1 = nn.Linear(32, 32) + self.bn1 = nn.BatchNorm1d(32) + self.ln2 = nn.Linear(32, 32) + + def forward(self, x): + x1 = self.ln1(x) + + x2 = self.bn1(x) + x2 = self.ln2(x2) + return x1 + x2 + + +class BigNet(nn.Module): + def __init__(self) -> None: + super().__init__() + self.net1 = Net() + self.net2 = Net() + + def forward(self, x): + out1 = self.net1(x) + out2 = self.net2(out1) + return out1, out2 + + +def my_backward_hook(module, grad_input, grad_output): + pass + + +if __name__ == "__main__": + os.environ["HCCL_DETERMINISTIC"] = 'true' + + seed = 1234 + os.environ["PYTHONHASHSEED"] = str(seed) + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + + model = BigNet() + model.net2.register_full_backward_hook(my_backward_hook) + inputs = torch.randn(3, 32) + + out1, out2 = model(inputs) + loss = out1.sum() + out2.sum() + loss.backward() + + for name, param in model.named_parameters(): + print(f"{name}: {param.grad.mean()}") + +``` +执行一遍以上脚本,可以打印得到模型中各层的权重梯度,注释model.net2.register_full_backward_hook(my_backward_hook)后再执行一篇,可以看出bn层的权重梯度已经发生了变化。 + +**如果在msprobe L0,mix级别采集出现gnorm发生变化,可以尝试将采集级别改为L1,若L1级别gnorm不发生变化,则大概率是hook机制导致的梯度计算结果变化。** diff --git a/debug/accuracy_tools/msprobe/docs/FAQ.md b/debug/accuracy_tools/msprobe/docs/FAQ.md index 833ca07a236f33e69b102d4acb45d35cd6fe7e3a..a08a9d84203c785632cbfc9e2596ffbf790c116a 100644 --- a/debug/accuracy_tools/msprobe/docs/FAQ.md +++ b/debug/accuracy_tools/msprobe/docs/FAQ.md @@ -36,6 +36,30 @@ 该信息说明 module 挂载了被 PyTorch 框架废弃的 register_backward_hook,这与工具使用的 register_full_backward_hook 接口会产生冲突,故工具会跳过该 module 的反向数据采集。 - 如果您希望所有 module 数据都能采集下来,可以将模型中使用的 register_backward_hook 接口改为 PyTorch 框架推荐的 register_full_backward_pre_hook 或 register_full_backward_hook 接口。 +5. 在vllm场景下进行数据dump时,发现报错:`RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and npu:0!` + - 这是因为工具的debugger实例化早于LLM实例化导致的,解决方法就需要将debugger的实例化移至LLM实例化之后进行,可参考下方示例: + ```python + from vllm import LLM, SamplingParams + from msprobe.pytorch import PrecisionDebugger + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct") + + debugger = PrecisionDebugger("./config.json") # debugger实例化晚于LLM实例化 + + debugger.start() + outputs = llm.generate(prompts, sampling_params) + debugger.stop() + ``` + +6. 在使用 msprobe 进行 PyTorch 框架的数据采集功能时,请注意确认环境变量 NPU_ASD_ENABLE=0 ,即关闭特征值检测功能。 由于工具冲突, 在该功能开启的情况下可能导致某些 api 数据采集的缺失。 + # 2 精度预检(PyTorch) 1. 预检工具在 dump 和 run_ut 的过程中,是否需要同时开启或关闭 jit 编译(jit_compile)? @@ -58,11 +82,7 @@ 答:对于 fp16 的数据,CPU 会上升一个精度 fp32 去计算,这是和算子那边对齐的精度结论,CPU 用更高精度去计算会更接近真实值。 -6. 添加预检工具后截取操作报错:`IndexError: too many indices for tensor of dimension x` 或 `TypeError: len() of a 0-d tensor`。 - - 答:注释工具目录 `mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml` 文件中 Tensor: 下的 `- __getitem__`,工具会跳过采集该 API。如果是需要 dump 关键位置 API 也可以考虑根据报错堆栈信息注释引发报错的类型检查。 - -7. Tensor 魔法函数具体对应什么操作? +6. Tensor 魔法函数具体对应什么操作? 答: @@ -202,15 +222,11 @@ def npu_forward_fused_softmax(self, input_, mask): 答:正常现象,dataloader 通过 raise 结束程序,堆栈信息可忽略。 -10. 添加 msprobe 工具后截取操作报错:`IndexError: too many indices for tensor of dimension x` 或 `TypeError: len() of a 0-d tensor`。 - - 答:注释工具目录 `mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml` 文件中 `Tensor: ` 下的 `- __getitem__`,工具会跳过采集该 API。如果是需要采集关键位置 API 也可以考虑根据报错堆栈信息注释引发报错的类型检查。 - -11. 使用 msprobe 工具数据采集功能后,模型出现报错,报错信息为:`activation_func must be F.gelu` 或 `ValueError(Only support fusion of gelu and swiglu)`。 +10. 使用 msprobe 工具数据采集功能后,模型出现报错,报错信息为:`activation_func must be F.gelu` 或 `ValueError(Only support fusion of gelu and swiglu)`。 答:这一类报错常见于 Megatron/MindSpeed/ModelLink 等加速库或模型仓中,原因是工具本身会封装 torch 的 API(API类型和地址会发生改变),而有些 API 在工具使能前类型和地址就已经确定,此时工具无法对这类 API 再进行封装,而加速库中会对某些 API 进行类型检查,即会把工具无法封装的原始的 API和工具封装之后的 API 进行判断,所以会报错。 规避方式有3种:①将PrecisionDebugger的实例化放在文件的开始位置,即导包后的位置,确保所有API都被封装;②注释 `mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml` 文件中的 `-gelu` 或者 `-silu`,工具会跳过采集该 API。③ 可以考虑根据报错堆栈信息注释引发报错的类型检查。 -12. 添加 msprobe 工具后触发与 AsStrided 算子相关、或者编译相关的报错,如:`Failed to compile Op [AsStrided]`。 +11. 添加 msprobe 工具后触发与 AsStrided 算子相关、或者编译相关的报错,如:`Failed to compile Op [AsStrided]`。 答:注释工具目录 `mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml` 文件中 `Tensor: `下的 `-t` 和 `- transpose`。 diff --git a/debug/accuracy_tools/msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md b/debug/accuracy_tools/msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md index 0a76c51d71d77c9cbc86d98600203e6faa71a0f6..275aa66e53f25587facb2034dba5706b71bab0bb 100644 --- a/debug/accuracy_tools/msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +++ b/debug/accuracy_tools/msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md @@ -1,6 +1,17 @@ # MindSpore 场景的精度数据采集基线 -## "tensor"模式采集数据量参考基线 +## "statistics"模式(未开启md5)采集**时间**膨胀参考基线 + +该基线为MindSpore框架下,使用"statistics"模式采集数据性能膨胀参考基线。测试了38B语言大模型在不同采集模式8卡下的性能膨胀。 + +| 采集模式 | 无工具 (耗时) | 加工具但未使能 Dump (耗时) | 加工具并使能 Dump (耗时) | +|:--------:|:-------------:|:--------------------:|:----------------:| +| L0 | ≈340 ms | ≈340 ms (无膨胀) | ≈1.2 s (膨胀3.5倍) | +| L1 | ≈340 ms | ≈0.7–1.2 s (膨胀2~4倍) | ≈3.8 s (膨胀11倍) | +| mix | ≈340 ms | ≈0.7–1.2 s (膨胀2~4倍) | ≈5.5 s (膨胀16倍) | + + +## "tensor"模式采集**数据量**参考基线 该基线为MindSpore框架下,使用"tensor"模式采集数据量参考基线。本基线测试了38B语言大模型在不同采集模式下,不同global_batch_size下,单卡和8卡下,数据量的变化。 diff --git a/debug/accuracy_tools/msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md b/debug/accuracy_tools/msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md index 543d260650361431ffb8b5142ae3df6b09d0db1d..e5e43b6f8680d5eda9f8eaef9bf525c207c9e3be 100644 --- a/debug/accuracy_tools/msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +++ b/debug/accuracy_tools/msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md @@ -35,7 +35,7 @@ import os import numpy as np import mindspore as ms from mindspore import nn, ops -from mindspore import context +from mindspore import context, set_device, set_deterministic from mindspore import Tensor from msprobe.mindspore import PrecisionDebugger, seed_all @@ -50,7 +50,12 @@ config_path = os.path.join(script_dir, 'config.json') debugger = PrecisionDebugger(config_path=config_path) # 设置 MindSpore 设备上下文 -context.set_context(mode=ms.PYNATIVE_MODE, device_target="Ascend", device_id=0) +context.set_context(mode=ms.PYNATIVE_MODE) + +set_device("Ascend", 0) + +set_deterministic(True) +print("Context set successfully. Please wait for the training task.") # 定义卷积层 def conv_layer(in_channels, out_channels, kernel_size, stride=1, padding=0, pad_mode="valid", has_bias=True): @@ -199,7 +204,7 @@ python alexnet_model.py ## 5. 数据分析 -在 `dump_path` 参数指定的路径下(本例中为 `./output`),会出现如下目录结构,后续精度数据分析操作可使用 msprobe 工具的精度预检和精度比对等功能,详细流程请参见[《msprobe使用手册》](../../README.md#2-精度预检)。: +在 `dump_path` 参数指定的路径下(本例中为 `./output`),会出现如下目录结构,后续精度数据分析操作可使用 msprobe 工具的精度预检和精度比对等功能,详细流程请参见[《msprobe使用手册》](../../README.md#2-精度预检)。 ```bash output/ @@ -208,4 +213,5 @@ output/ ├── construct.json # level为L0时,保存Cell的层级关系信息。当前场景为空 ├── dump.json # 保存API前反向输入输出数据的统计量信息 └── stack.json # 保存API的调用栈 + ...... ``` \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/docs/img/compare_result.png b/debug/accuracy_tools/msprobe/docs/img/compare_result.png index 07cdb51707fe43d07723ed976275d99f55b50571..3c226d38c362416dc441fe4a48aa30f19d137c41 100644 Binary files a/debug/accuracy_tools/msprobe/docs/img/compare_result.png and b/debug/accuracy_tools/msprobe/docs/img/compare_result.png differ diff --git a/debug/accuracy_tools/msprobe/docs/img/ms_layer.png b/debug/accuracy_tools/msprobe/docs/img/ms_layer.png index d64fc0bbc0c7fe6c7d99151ec9d1ab589436eb09..ddacdc97b2934aab3d8d68cec5445f3d09136019 100644 Binary files a/debug/accuracy_tools/msprobe/docs/img/ms_layer.png and b/debug/accuracy_tools/msprobe/docs/img/ms_layer.png differ diff --git a/debug/accuracy_tools/msprobe/docs/img/save_compare_result_sample.png b/debug/accuracy_tools/msprobe/docs/img/save_compare_result_sample.png new file mode 100644 index 0000000000000000000000000000000000000000..51f902e1b9acdc17255ae7745a77a2b9bc5117b6 Binary files /dev/null and b/debug/accuracy_tools/msprobe/docs/img/save_compare_result_sample.png differ diff --git a/debug/accuracy_tools/msprobe/docs/img/visualization/proxy.png b/debug/accuracy_tools/msprobe/docs/img/visualization/proxy.png new file mode 100644 index 0000000000000000000000000000000000000000..3033214904ca3a8a1f50f187a382c47c23f05786 Binary files /dev/null and b/debug/accuracy_tools/msprobe/docs/img/visualization/proxy.png differ diff --git a/debug/accuracy_tools/msprobe/docs/img/visualization/vis_browser_1.png b/debug/accuracy_tools/msprobe/docs/img/visualization/vis_browser_1.png index 96e8521fde4b776ba915a00b5d77851b8406c153..ff10a9cd742a42a0133481bf20f83ff95ddf8a49 100644 Binary files a/debug/accuracy_tools/msprobe/docs/img/visualization/vis_browser_1.png and b/debug/accuracy_tools/msprobe/docs/img/visualization/vis_browser_1.png differ diff --git a/debug/accuracy_tools/msprobe/docs/img/visualization/vis_match_info.png b/debug/accuracy_tools/msprobe/docs/img/visualization/vis_match_info.png new file mode 100644 index 0000000000000000000000000000000000000000..adc470d4bfe8b706b1b05ccb331246930bcfabb4 Binary files /dev/null and b/debug/accuracy_tools/msprobe/docs/img/visualization/vis_match_info.png differ diff --git a/debug/accuracy_tools/msprobe/docs/img/visualization/vis_precision_info.png b/debug/accuracy_tools/msprobe/docs/img/visualization/vis_precision_info.png index ddd59b37f044fe64c02148b698b95296592e0399..33497c110f49bcd09f0caa0e3b632ae973620c61 100644 Binary files a/debug/accuracy_tools/msprobe/docs/img/visualization/vis_precision_info.png and b/debug/accuracy_tools/msprobe/docs/img/visualization/vis_precision_info.png differ diff --git a/debug/accuracy_tools/msprobe/docs/img/visualization/vis_search_info.png b/debug/accuracy_tools/msprobe/docs/img/visualization/vis_search_info.png index 7c55b33840163c388f8fde69f0bbc531b23f81f6..13dba9b83b6c3b58ab67265827ef32bf5fb1822b 100644 Binary files a/debug/accuracy_tools/msprobe/docs/img/visualization/vis_search_info.png and b/debug/accuracy_tools/msprobe/docs/img/visualization/vis_search_info.png differ diff --git a/debug/accuracy_tools/msprobe/docs/img/visualization/vis_show_info.png b/debug/accuracy_tools/msprobe/docs/img/visualization/vis_show_info.png index 9a6217e04848e671d784ed0b484d2fe10151bde7..c08d59be266f96e2b473ee1cb7141120e0ee3aa4 100644 Binary files a/debug/accuracy_tools/msprobe/docs/img/visualization/vis_show_info.png and b/debug/accuracy_tools/msprobe/docs/img/visualization/vis_show_info.png differ diff --git a/debug/accuracy_tools/msprobe/docs/img/visualization/vis_showcase.png b/debug/accuracy_tools/msprobe/docs/img/visualization/vis_showcase.png index e95b5eeee663d91a67b1ace422c8681797ca96c1..fe71f73ad01410caaa47333723ce040ccc2d88dc 100644 Binary files a/debug/accuracy_tools/msprobe/docs/img/visualization/vis_showcase.png and b/debug/accuracy_tools/msprobe/docs/img/visualization/vis_showcase.png differ diff --git a/debug/accuracy_tools/msprobe/docs/img/visualization/vis_unmatch_info.png b/debug/accuracy_tools/msprobe/docs/img/visualization/vis_unmatch_info.png index e4c9ed4306f9a7b20d031d32f18c815628030da6..6fb0c74a2a09986e22503dcc822c7b90506eb159 100644 Binary files a/debug/accuracy_tools/msprobe/docs/img/visualization/vis_unmatch_info.png and b/debug/accuracy_tools/msprobe/docs/img/visualization/vis_unmatch_info.png differ diff --git a/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png b/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png new file mode 100644 index 0000000000000000000000000000000000000000..791befb7d42725d2ab8fe24377f223893080e36b Binary files /dev/null and b/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png differ diff --git a/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png b/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png new file mode 100644 index 0000000000000000000000000000000000000000..f8ac5b391d462d2b2e720fc1189762821a9c03eb Binary files /dev/null and b/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png differ diff --git a/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png b/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png new file mode 100644 index 0000000000000000000000000000000000000000..7e876f81083f368dd20d66c3a91e37943344904a Binary files /dev/null and b/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png differ diff --git a/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png b/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png new file mode 100644 index 0000000000000000000000000000000000000000..e4798076a65454a1507e923ebe8111fdaa4926bb Binary files /dev/null and b/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png differ diff --git a/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png b/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png new file mode 100644 index 0000000000000000000000000000000000000000..9a7b68f54ec6ba1e43528fdae20ae495d50a0705 Binary files /dev/null and b/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png differ diff --git a/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png b/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png new file mode 100644 index 0000000000000000000000000000000000000000..9bc7c4621f884229c8c99a41c53499cad151dbe8 Binary files /dev/null and b/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png differ diff --git a/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png b/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png new file mode 100644 index 0000000000000000000000000000000000000000..1c8731a470c00bd108358407b2a13a9c613d417c Binary files /dev/null and b/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png differ diff --git a/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt b/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt new file mode 100644 index 0000000000000000000000000000000000000000..b0322dc208b0aa0490e4adb18bcd8b6ec7b1b557 --- /dev/null +++ b/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt @@ -0,0 +1,59 @@ +DeepSpeedEngine( + (module): Qwen2_5_VLForConditionalGeneration( + (visual): Qwen2_5_VisionTransformerPretrainedModel( + (patch_embed): Qwen2_5_VisionPatchEmbed( + (proj): Conv3d(3, 1280, kernel_size=(2, 14, 14), stride=(2, 14, 14), bias=False) + ) + (rotary_pos_emb): Qwen2_5_VisionRotaryEmbedding() + (blocks): ModuleList( + (0-15): 16 x Qwen2_5_VLVisionBlock( + (norm1): Qwen2RMSNorm((0,), eps=1e-06) + (norm2): Qwen2RMSNorm((0,), eps=1e-06) + (attn): Qwen2_5_VLVisionSdpaAttention( + (qkv): Linear(in_features=1280, out_features=3840, bias=True) + (proj): Linear(in_features=1280, out_features=1280, bias=True) + ) + (mlp): Qwen2_5_VLMLP( + (gate_proj): Linear(in_features=1280, out_features=3420, bias=True) + (up_proj): Linear(in_features=1280, out_features=3420, bias=True) + (down_proj): Linear(in_features=3420, out_features=1280, bias=True) + (act_fn): SiLU() + ) + ) + ) + (merger): Qwen2_5_VLPatchMerger( + (ln_q): Qwen2RMSNorm((0,), eps=1e-06) + (mlp): Sequential( + (0): Linear(in_features=5120, out_features=5120, bias=True) + (1): GELU(approximate='none') + (2): Linear(in_features=5120, out_features=2048, bias=True) + ) + ) + ) + (model): Qwen2_5_VLModel( + (embed_tokens): Embedding(151936, 2048) + (layers): ModuleList( + (0-7): 8 x Qwen2_5_VLDecoderLayer( + (self_attn): Qwen2_5_VLSdpaAttention( + (q_proj): Linear(in_features=2048, out_features=2048, bias=True) + (k_proj): Linear(in_features=2048, out_features=256, bias=True) + (v_proj): Linear(in_features=2048, out_features=256, bias=True) + (o_proj): Linear(in_features=2048, out_features=2048, bias=False) + (rotary_emb): Qwen2_5_VLRotaryEmbedding() + ) + (mlp): Qwen2MLP( + (gate_proj): Linear(in_features=2048, out_features=11008, bias=False) + (up_proj): Linear(in_features=2048, out_features=11008, bias=False) + (down_proj): Linear(in_features=11008, out_features=2048, bias=False) + (act_fn): SiLU() + ) + (input_layernorm): Qwen2RMSNorm((0,), eps=1e-06) + (post_attention_layernorm): Qwen2RMSNorm((0,), eps=1e-06) + ) + ) + (norm): Qwen2RMSNorm((0,), eps=1e-06) + (rotary_emb): Qwen2_5_VLRotaryEmbedding() + ) + (lm_head): Linear(in_features=2048, out_features=151936, bias=False) + ) +) diff --git a/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png b/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png new file mode 100644 index 0000000000000000000000000000000000000000..4be2f18a16cc28b740f5c1e7f181d18e9b1463c5 Binary files /dev/null and b/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png differ diff --git a/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png b/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png new file mode 100644 index 0000000000000000000000000000000000000000..21bd1b887c7cddc3f9fad2473a4746d7f4098587 Binary files /dev/null and b/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png differ diff --git a/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt b/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt new file mode 100644 index 0000000000000000000000000000000000000000..9f4ba4fd044482fddbc9d24d88d5b29980ebcc7e --- /dev/null +++ b/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt @@ -0,0 +1,80 @@ +[DistributedDataParallel( + (module): Float16Module( + (module): VLMModel( + (image_encoder): VisionModel( + (encoder): Qwen2VLViT( + (patch_embed): PatchEmbed( + (proj): Conv3d(3, 1280, kernel_size=(2, 14, 14), stride=(2, 14, 14), bias=False) + ) + (rotary_pos_emb): VisionRotaryEmbedding() + (blocks): Qwen2VLVisionTransformerBlock( + (layers): ModuleList( + (0-15): 16 x TransformerLayer( + (input_layernorm): RMSNorm() + (self_attention): Qwen2vlVitSelfAttention( + (core_attention): DotProductAttention( + (scale_mask_softmax): FusedScaleMaskSoftmax() + (attention_dropout): Dropout(p=0.0, inplace=False) + ) + (linear_proj): RowParallelLinear() + (linear_qkv): ColumnParallelLinear() + (q_layernorm): IdentityOp() + (k_layernorm): IdentityOp() + ) + (pre_cross_attn_layernorm): IdentityOp() + (cross_attention): IdentityOp() + (cross_attn_bda): IdentityFuncOp() + (pre_mlp_layernorm): RMSNorm() + (mlp): MLP( + (linear_fc1): ColumnParallelLinear() + (linear_fc2): RowParallelLinear() + ) + ) + ) + ) + ) + (projector): MultimodalProjector( + (layernorm): RMSNorm() + (encoder): MLP( + (linear_fc1): ColumnParallelLinear() + (linear_fc2): RowParallelLinear() + ) + ) + ) + (text_decoder): MMGPTModel( + (embedding): LanguageModelEmbedding( + (word_embeddings): VocabParallelEmbedding() + (embedding_dropout): Dropout(p=0.0, inplace=False) + ) + (rotary_pos_emb): Qwen2VLRotaryEmbedding_llm() + (decoder): TransformerBlock( + (layers): ModuleList( + (0-7): 8 x TransformerLayer( + (input_layernorm): RMSNorm() + (self_attention): Qwen2vlSelfAttention( + (core_attention): DotProductAttention( + (scale_mask_softmax): FusedScaleMaskSoftmax() + (attention_dropout): Dropout(p=0.0, inplace=False) + ) + (linear_proj): RowParallelLinear() + (linear_qkv): ColumnParallelLinear() + (q_layernorm): IdentityOp() + (k_layernorm): IdentityOp() + ) + (pre_cross_attn_layernorm): IdentityOp() + (cross_attention): IdentityOp() + (cross_attn_bda): IdentityFuncOp() + (pre_mlp_layernorm): RMSNorm() + (mlp): MLP( + (linear_fc1): ColumnParallelLinear() + (linear_fc2): RowParallelLinear() + ) + ) + ) + (final_layernorm): RMSNorm() + ) + (output_layer): ColumnParallelLinear() + ) + ) + ) +)] diff --git a/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png b/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png new file mode 100644 index 0000000000000000000000000000000000000000..7346684b53a59f1a67a894fcdae143b6a0b142c2 Binary files /dev/null and b/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png differ diff --git a/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png b/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png new file mode 100644 index 0000000000000000000000000000000000000000..c3d485a5a9dbcf58e2737f3d49f830bfc9d1b54d Binary files /dev/null and b/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png differ diff --git a/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactory_mapping.md b/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactory_mapping.md new file mode 100644 index 0000000000000000000000000000000000000000..c9d93e532a6cf3f34931e547aa4003e28feaedf4 --- /dev/null +++ b/debug/accuracy_tools/msprobe/docs/visualization/mindspeed_llamafactory_mapping.md @@ -0,0 +1,330 @@ +# MindSpeed&LLamaFactory数据采集和自动比对 + +## 0. 使用场景 +基于MindSpeed和LLamaFactory框架实现的同一模型,在模型超参、环境变量、初始权重、训练数据等一致的前提下,训练过程中出现了精度差异,需要进行**整网比对**寻找精度差异点。 + +本文选取Qwen2.5vl和Qwen2.5模型,指导用户如何进行MindSpeed&LLamaFactory数据采集和自动比对。 + +## 1. 数据采集 + +### 1.1 准备数据采集配置文件 + +数据采集前需要准备一个json文件,本案例命名为config.json,其内容包含了数据采集的所需配置。 + +本案例使用的配置内容如下,更多配置请参考[config.json 配置示例](../03.config_examples.md),配置详细介绍请参考[配置文件介绍](../02.config_introduction.md)。 + +```json +{ + "task": "statistics", + "dump_path": "/home/data_dump", + "rank": [], + "step": [0], + "level": "mix", + "async_dump": false, + + "statistics": { + "scope": [], + "list": [], + "tensor_list": [], + "data_mode": ["all"], + "summary_mode": "statistics" + } +} +``` +请注意,在数据采集结束后将进行模型分级可视化比对,配置文件中的`level`需要配置为`L0`(模块数据)或`mix`(模块+API数据)。 + +### 1.2 添加msprobe工具采集接口 + +本案例使用的工具采集接口配置如下,更多配置和接口介绍请参考[PyTorch 场景的精度数据采集](../05.data_dump_PyTorch.md)。 + +#### 1.2.1 LLamaFactory数据采集 + +LLamaFactory依赖Transformers的底层能力,msprobe工具采集功能将添加在Transformers中。 + +以Transformers 4.49.0版本为例,通过`pip3 show Transformers`获取`Location路径`,打开`Location路径/transformers/trainer.py`文件。 + +1. 在trainer.py文件中添加工具接口,初始化数据采集配置以及固定随机数: + + ![llamafactory1.png](./mindspeed_llamafactoary_img/llamafactory1.png) + +2. 在trainer.py文件**训练循环逻辑位置**添加工具接口,控制数据采集的启动、停止和step计数: + + ![llamafactory2.png](./mindspeed_llamafactoary_img/llamafactory2.png) + +3. 配置完成,启动模型训练脚本,数据将自动采集,落盘数据格式请参考[PyTorch 场景的精度数据采集-dump-结果文件介绍](../05.data_dump_PyTorch.md#3-dump-结果文件介绍)。 + +#### 1.2.2 MindSpeed数据采集 + +打开training.py文件,MindSpeed-MM路径为`mindspeed_mm/training.py`,MindSpeed-LLM路径为`mindspeed_llm/training/training.py`。 + +1. 在training.py文件中添加工具接口,初始化数据采集配置以及固定随机数: + + ![mindspeed1.png](./mindspeed_llamafactoary_img/mindspeed1.png) + +2. 在training.py文件**训练循环逻辑位置**添加工具接口,控制数据采集的启动、停止和step计数: + + ![mindspeed2.png](./mindspeed_llamafactoary_img/mindspeed2.png) + +3. 配置完成,启动模型训练脚本,数据将自动采集,落盘数据格式请参考[PyTorch 场景的精度数据采集-dump-结果文件介绍](../05.data_dump_PyTorch.md#3-dump-结果文件介绍)。 + +## 2. 自动比对 + +### 2.1 模型分级可视化比对 + +该功能将msprobe工具dump的精度数据进行解析,还原模型图结构,实现模型各个层级的精度数据比对,方便用户理解模型结构、分析精度问题。 + +我们将使用以下命令行进行模型分级可视化比对: + +``` +msprobe -f pytorch graph -i ./compare.json -o ./output -lm ./layer_mapping.yaml +``` +具体的参数说明请点击查看[分级可视化构图比对-构图命令行说明](../21.visualization_PyTorch.md#31-构图命令行说明)。 + +在基于MindSpeed和LLamaFactory框架的模型比对场景中,**-lm参数是必填的**,-lm参数所需的layer_mapping.yaml如何配置将在下面的章节进行介绍。 + +模型分级可视化比对完成后,可通过tensorboard(需安装[tb_graph_ascend插件](../21.visualization_PyTorch.md#1依赖安装))启动端口,在浏览器页面查看模型结构和精度比对结果,请参考[分级可视化构图比对-启动tensorboard](../21.visualization_PyTorch.md#4启动tensorboard)和[分级可视化构图比对-浏览器查看](../21.visualization_PyTorch.md#5浏览器查看)。 + +### 2.2 layer_mapping映射文件配置 +msprobe工具的比对功能会将比对双方dump名称一致的数据进行比对。由于MindSpeed和LLamaFactory框架代码实现的差异,一些模型层级和层级名称有所不同导致无法进行匹配,需要进行layer层名称映射,才能够比对。 + +#### 2.2.1 layer_mapping映射文件模板 + +此处提供了Qwen2.5vl和Qwen2.5模型的layer_mapping映射文件模板,可直接使用。**如果您使用其他模型,或对MindSpeed和LLamaFactory框架进行过定制开发修改过框架源码,此layer_mapping映射文件模板可能会失效,请按照后续步骤修改layer_mapping映射文件模板**。 + +每个模型有两个layer_mapping映射文件模板,分别是NPU侧为Mindspeed Bench侧为LLamaFactory,以及NPU侧为LLamaFactory Bench侧为Mindspeed,映射内容有所不同。 + +文件名格式:\*.yaml,*为文件名,可自定义。本文命名为layer_mapping.yaml。 + +**Qwen2.5vl** + +```yaml +# NPU侧为Mindspeed-MM, Bench侧为LLamaFactory +TopLayer: + 0.module: module + +Float16Module: + module.image_encoder: visual + module.text_decoder: model + +VisionModel: + encoder.patch_embed: patch_embed + encoder.rotary_pos_emb: rotary_pos_emb + encoder.blocks.layers: blocks + projector: merger + +TransformerLayer: + input_layernorm: norm1 + self_attention: attn + pre_mlp_layernorm: norm2 + +Qwen2vlVitSelfAttention: + linear_qkv: qkv + linear_proj: proj + +MLP: + linear_fc1: up_proj + linear_fc2: down_proj + +MultimodalProjector: + layernorm: ln_q + encoder: mlp + encoder.linear_fc1: mlp.0 + encoder.linear_fc2: mlp.2 + +MMGPTModel: + embedding.word_embeddings: embed_tokens + rotary_pos_emb: rotary_emb + decoder.layers: layers + decoder.final_layernorm: norm + output_layer: lm_head +``` +```yaml +# NPU侧为LLamaFactory, Bench侧为Mindspeed-MM +TopLayer: + module: 0.module + +Qwen2_5_VLForConditionalGeneration: + visual: module.image_encoder + model: module.text_decoder + lm_head: module.text_decoder.output_layer + +Qwen2_5_VisionTransformerPretrainedModel: + patch_embed: encoder.patch_embed + rotary_pos_emb: encoder.rotary_pos_emb + blocks: encoder.blocks.layers + merger: projector + +Qwen2_5_VLVisionBlock: + norm1: input_layernorm + attn: self_attention + norm2: pre_mlp_layernorm + +Qwen2_5_VLVisionSdpaAttention: + qkv: linear_qkv + proj: linear_proj + +Qwen2_5_VLMLP: + up_proj: linear_fc1 + down_proj: linear_fc2 + +Qwen2_5_VLPatchMerger: + ln_q: layernorm + mlp: encoder + mlp.0: encoder.linear_fc1 + mlp.2: encoder.linear_fc2 + +Qwen2_5_VLModel: + embed_tokens: embedding.word_embeddings + rotary_emb: rotary_pos_emb + layers: decoder.layers + norm: decoder.final_layernorm + +Qwen2_5_VLDecoderLayer: + self_attn: self_attention + self_attn.o_proj: self_attention.linear_proj + post_attention_layernorm: pre_mlp_layernorm +``` + +**Qwen2.5** + +```yaml +# NPU侧为Mindspeed-LLM, Bench侧为LLamaFactory +TopLayer: + 0.module: module + +Float16Module: + module: model + module.output_layer: lm_head + +GPTModel: + embedding.word_embeddings: embed_tokens + decoder.layers: layers + decoder.final_layernorm: norm + +TransformerLayer: + self_attention: self_attn + pre_mlp_layernorm: post_attention_layernorm + +SelfAttention: + linear_proj: o_proj + +MLP: + linear_fc1: up_proj + linear_fc2: down_proj +``` +```yaml +# NPU侧为LLamaFactory, Bench侧为Mindspeed-LLM +TopLayer: + module: 0.module + +Qwen2ForCausalLM: + model: module + lm_head: module.output_layer + +Qwen2Model: + embed_tokens: embedding.word_embeddings + layers: decoder.layers + norm: decoder.final_layernorm + +Qwen2DecoderLayer: + self_attn: self_attention + post_attention_layernorm: pre_mlp_layernorm + +Qwen2Attention: + o_proj: linear_proj + +Qwen2MLP: + up_proj: linear_fc1 + down_proj: linear_fc2 +``` + +#### 2.2.2 layer_mapping映射文件配置过程 +以Qwen2.5vl模型,NPU侧MindSpeed,Bench侧LLamaFactory为例。 + +1. 模型结构打印 + + 参考[添加msprobe工具采集接口](#12-添加msprobe工具采集接口)章节,配置过程中会在模型文件中添加`debugger.start(model=model)`,针对`start接口`中的`model`进行`print(model)`即可打印模型结构。 + + 打印的模型结构:[mindspeed-mm-qwen25vl.txt](./mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt),[llamafactory-qwen25vl.txt](./mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt) + +2. 基于模型结构由外到内进行layer mapping配置 + +- 结构1 + + ![1.png](./mindspeed_llamafactoary_img/1.png) + + ```yaml + TopLayer: # 代表模型最顶层 + 0.module: module # MindSpeed的model类型是list,msprobe采集会对其添加数字前缀,代表当前模型在list中的索引,因此要做0.module -> module的映射 + + Float16Module: # MindSpeed的Float16Module与LLamaFactory的Qwen2_5_VLForConditionalGeneration同级,对它们的子层进行映射 + module.image_encoder: visual # MindSpeed的Float16Module多了一个子层module,跨层级用"."分隔,配置为module.image_encoder + module.text_decoder: model + ``` +- 结构2 + + ![2.png](./mindspeed_llamafactoary_img/2.png) + + ```yaml + VisionModel: # MindSpeed的VisionModel与LLamaFactory的Qwen2_5_VisionPatchEmbed同级,对它们的子层进行映射 + encoder.patch_embed: patch_embed + encoder.rotary_pos_emb: rotary_pos_emb + encoder.blocks.layers: blocks + projector: merger + ``` +- 结构3 + + ![3.png](./mindspeed_llamafactoary_img/3.png) + + ```yaml + TransformerLayer: # MindSpeed的TransformerLayer与LLamaFactory的Qwen2_5_VLVisionBlock同级,对它们的子层进行映射 + input_layernorm: norm1 + self_attention: attn + pre_mlp_layernorm: norm2 + ``` +- 结构4 + + ![4.png](./mindspeed_llamafactoary_img/4.png) + + ```yaml + Qwen2vlVitSelfAttention: # MindSpeed的Qwen2vlVitSelfAttention与LLamaFactory的Qwen2_5_VLVisionSdpaAttention同级,对它们的子层进行映射 + linear_qkv: qkv + linear_proj: proj + + MLP: # MindSpeed的MLP与LLamaFactory的Qwen2_5_VLMLP同级,对它们的子层进行映射 + linear_fc1: up_proj + linear_fc2: down_proj + ``` +- 结构5 + + ![5.png](./mindspeed_llamafactoary_img/5.png) + + ```yaml + MultimodalProjector: # MindSpeed的MultimodalProjector与LLamaFactory的Qwen2_5_VLPatchMerger同级,对它们的子层进行映射 + layernorm: ln_q + encoder: mlp + encoder.linear_fc1: mlp.0 + encoder.linear_fc2: mlp.2 + ``` +- 结构6 + + ![6.png](./mindspeed_llamafactoary_img/6.png) + + ```yaml + MMGPTModel: # MindSpeed的MMGPTModel与LLamaFactory的Qwen2_5_VLModel同级,对它们的子层进行映射 + embedding.word_embeddings: embed_tokens + rotary_pos_emb: rotary_emb + decoder.layers: layers + decoder.final_layernorm: norm + output_layer: lm_head + ``` +- 结构7 + + ![7.png](./mindspeed_llamafactoary_img/7.png) + + 由于TransformerLayer和MLP层已经配置过,无法再重复配置,此处的节点映射可通过[手动选择节点匹配](#23-手动选择节点匹配)完成。 + +### 2.3 手动选择节点匹配 +如果通过layer_mapping映射配置后,还有节点未匹配上,可通过浏览器界面,使用鼠标选择两个待匹配的灰色节点进行匹配。 + +请参考[分级可视化构图比对-手动选择节点匹配](../21.visualization_PyTorch.md#56-手动选择节点匹配)。 diff --git a/debug/accuracy_tools/msprobe/mindspore/__init__.py b/debug/accuracy_tools/msprobe/mindspore/__init__.py index 089c29eb098ad4305edcca1306462f8924dd9291..c36ea84caace7de247ba97f2c8b504627786f9de 100644 --- a/debug/accuracy_tools/msprobe/mindspore/__init__.py +++ b/debug/accuracy_tools/msprobe/mindspore/__init__.py @@ -17,12 +17,12 @@ import os try: from msprobe.lib import _msprobe_c - os.environ["MS_HOOK_ENABLE"] = "on" os.environ["HOOK_TOOL_PATH"] = _msprobe_c.__file__ except ImportError: from .common.log import logger logger.info("Module _msprobe_c has not been installed. L2-Dump may not work normally.") from msprobe.mindspore.debugger.precision_debugger import PrecisionDebugger -from msprobe.mindspore.common.utils import seed_all -from msprobe.mindspore.monitor.module_hook import TrainerMon \ No newline at end of file +from msprobe.mindspore.common.utils import seed_all, MsprobeStep, MsprobeInitStep +from msprobe.mindspore.monitor.module_hook import TrainerMon +from msprobe.mindspore.dump.graph_tensor_dump import save, save_grad, step \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py index 98c6b4b98530ec447c2e239c11b5d4d7b927d874..af6e3602e353843d7d3763856c6d27b056488361 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py @@ -14,9 +14,11 @@ # limitations under the License. import os +from dataclasses import dataclass +from typing import Any, Optional from tqdm import tqdm - -from msprobe.core.common.const import Const, CompareConst, MsCompareConst +import numpy as np +from msprobe.core.common.const import Const, CompareConst from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, load_json, load_yaml from msprobe.core.common.utils import add_time_as_suffix from msprobe.mindspore.api_accuracy_checker.api_info import ApiInfo @@ -25,8 +27,12 @@ from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compar from msprobe.mindspore.api_accuracy_checker.data_manager import DataManager from msprobe.mindspore.api_accuracy_checker.utils import (check_and_get_from_json_dict, global_context, trim_output_compute_element_list) +from msprobe.mindspore.common.const import MsCompareConst from msprobe.mindspore.common.log import logger from msprobe.mindspore.api_accuracy_checker import torch_mindtorch_importer +from msprobe.core.data_dump.data_collector import build_data_collector +from msprobe.core.common.utils import Const, print_tools_ends_info, DumpPathAggregation +from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, MsCompareConst.SUPPORTED_API_LIST_FILE) @@ -58,13 +64,129 @@ class ProcessResultPacket: self.err_msg = err_msg +@dataclass +class Config: + execution_mode: str + dump_path: str + task: str + level: str + scope: Optional[Any] + list: Optional[Any] + framework: str + data_mode: str + file_format: str + dump_tensor_data_dir: str + async_dump: bool + summary_mode: Optional[Any] = None + + class ApiAccuracyChecker: def __init__(self, args): self.api_infos = dict() self.data_manager = DataManager(args.out_path, args.result_csv_path) # 在初始化时实例化 DataManager + self.save_error_data = args.save_error_data + if self.save_error_data: + config, dump_path_aggregation = self.init_save_error_data(args) + self.data_collector = build_data_collector(config) + self.data_collector.update_dump_paths(dump_path_aggregation) + + @staticmethod + def init_save_error_data(args): + config = Config( + execution_mode="pynative", + dump_path=f"{args.out_path}", + dump_tensor_data_dir=f"{args.out_path}", + task="tensor", # 任务类型,模拟保存tensor数据 + level="L1", # 级别 + scope=None, # 作用域 (None) + list=None, # API 列表 (None) + framework=Const.MS_FRAMEWORK, # 框架类型 + data_mode="all", + file_format="npy", + async_dump=False + ) + + dump_dir = f"{args.out_path}" + dump_data_dir = os.path.join(dump_dir, "error_data") + create_directory(dump_data_dir) + dump_path_aggregation = DumpPathAggregation() + dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json") + dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json") + dump_path_aggregation.dump_error_info_path = os.path.join(dump_dir, "dump_error_info.log") + dump_path_aggregation.dump_tensor_data_dir = dump_data_dir + return config, dump_path_aggregation + + @staticmethod + def prepare_api_input_aggregation(api_info, forward_or_backward=Const.FORWARD): + """ + Args: + api_info: ApiInfo + forward_or_backward: str + Returns: + ApiInputAggregation + """ + forward_inputs = api_info.get_compute_element_list(Const.FORWARD, Const.INPUT) + kwargs = api_info.get_kwargs() + if forward_or_backward == Const.FORWARD: + gradient_inputs = None + else: + gradient_inputs = api_info.get_compute_element_list(Const.BACKWARD, Const.INPUT) + return ApiInputAggregation(forward_inputs, kwargs, gradient_inputs) @staticmethod - def run_and_compare_helper(api_info, api_name_str, api_input_aggregation, forward_or_backward): + def is_api_checkable(api_name_str): + ''' + Args: + api_name_str: str, e.g. "MintFunctional.relu.0.forward", key in data field of api_info.json + Returns: + is_checkable: bool + Description: + tell whether this api is checkable based on the key in "data" dict in api_info.json + ''' + api_name_str_list = api_name_str.split(Const.SEP) + if len(api_name_str_list) < MsCompareConst.API_NAME_STR_LENGTH: + return False + api_type_str = api_name_str_list[0] + real_api_str = Const.SEP.join(api_name_str_list[1:-2]) + api_list = load_yaml(yaml_path) + supported_tensor_api_list = api_list.get(MsCompareConst.SUPPORTED_TENSOR_LIST_KEY) + supported_fusion_api_list = MsCompareConst.SUPPORTED_FUSION_LIST + if api_type_str in (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL) \ + and global_context.get_framework() == Const.MS_FRAMEWORK: + return True + if api_type_str in MsCompareConst.MT_VALID_API_TYPES \ + and global_context.get_framework() == Const.MT_FRAMEWORK: + return True + if api_type_str == MsCompareConst.TENSOR_API and real_api_str in supported_tensor_api_list \ + and global_context.get_framework() == Const.MS_FRAMEWORK: + return True + if api_type_str == MsCompareConst.FUNCTIONAL_API and real_api_str in supported_fusion_api_list \ + and global_context.get_framework() == Const.MS_FRAMEWORK: + return True + return False + + def post_forward_hook(self, api_or_module_name, primitive_instance, args, kwargs, output): + self.data_collector.update_api_or_module_name(api_or_module_name) + module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output) + self.data_collector.forward_data_collect_only_tensor( + api_or_module_name, + primitive_instance, + os.getpid(), + module_input_output + ) + + def backward_hook(self, api_or_module_name, module, grad_input, grad_output): + self.data_collector.update_api_or_module_name(api_or_module_name) + + module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input) + self.data_collector.backward_data_collect_only_tensor( + api_or_module_name, + module, + os.getpid(), + module_input_output + ) + + def run_and_compare_helper(self, api_info, api_name_str, api_input_aggregation, forward_or_backward): """ Args: api_info: ApiInfo @@ -82,13 +204,22 @@ class ApiAccuracyChecker: """ # get output if global_context.get_is_constructed(): - # constructed situation, need use constructed input to run mindspore api getting tested_output - tested_outputs = api_runner(api_input_aggregation, api_name_str, - forward_or_backward, global_context.get_framework()) + if forward_or_backward == Const.FORWARD: + tested_outputs, inputs, kwargs, forward_result_tuple = api_runner(api_input_aggregation, api_name_str, + forward_or_backward, + global_context.get_framework()) + elif forward_or_backward == Const.BACKWARD: + tested_outputs, gradient_inputs, backward_result_tuple = api_runner(api_input_aggregation, api_name_str, + forward_or_backward, + global_context.get_framework()) + else: + tested_outputs = api_runner(api_input_aggregation, api_name_str, + forward_or_backward, global_context.get_framework()) else: tested_outputs = api_info.get_compute_element_list(forward_or_backward, Const.OUTPUT) bench_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.PT_FRAMEWORK) + tested_outputs = trim_output_compute_element_list(tested_outputs, forward_or_backward) bench_outputs = trim_output_compute_element_list(bench_outputs, forward_or_backward) if len(tested_outputs) != len(bench_outputs): @@ -113,60 +244,26 @@ class ApiAccuracyChecker: compare_result_dict.get(CompareConst.MAX_ABS_ERR).pass_status == CompareConst.PASS: status = CompareConst.PASS err_msg = "" + else: status = CompareConst.ERROR err_msg = (compare_result_dict.get(CompareConst.COSINE).err_msg + compare_result_dict.get(CompareConst.MAX_ABS_ERR).err_msg) + if forward_or_backward == Const.FORWARD and self.save_error_data \ + and global_context.get_is_constructed(): + api_name_str_backward = f"{api_name_str}{Const.SEP}{Const.FORWARD}" + self.post_forward_hook(api_name_str_backward, None, inputs, kwargs, forward_result_tuple) + + if forward_or_backward == Const.BACKWARD and self.save_error_data \ + and global_context.get_is_constructed(): + api_name_str_backward = f"{api_name_str}{Const.SEP}{Const.BACKWARD}" + self.backward_hook(api_name_str_backward, None, gradient_inputs, backward_result_tuple) + basic_info_status = \ BasicInfoAndStatus(api_name_with_slot, bench_dtype, tested_dtype, shape, status, err_msg) output_list.append(tuple([api_name_str, forward_or_backward, basic_info_status, compare_result_dict])) return output_list - @staticmethod - def prepare_api_input_aggregation(api_info, forward_or_backward=Const.FORWARD): - """ - Args: - api_info: ApiInfo - forward_or_backward: str - Returns: - ApiInputAggregation - """ - forward_inputs = api_info.get_compute_element_list(Const.FORWARD, Const.INPUT) - kwargs = api_info.get_kwargs() - if forward_or_backward == Const.FORWARD: - gradient_inputs = None - else: - gradient_inputs = api_info.get_compute_element_list(Const.BACKWARD, Const.INPUT) - return ApiInputAggregation(forward_inputs, kwargs, gradient_inputs) - - @staticmethod - def is_api_checkable(api_name_str): - ''' - Args: - api_name_str: str, e.g. "MintFunctional.relu.0.forward", key in data field of api_info.json - Returns: - is_checkable: bool - Description: - tell whether this api is checkable based on the key in "data" dict in api_info.json - ''' - api_name_str_list = api_name_str.split(Const.SEP) - if len(api_name_str_list) < MsCompareConst.API_NAME_STR_LENGTH: - return False - api_type_str = api_name_str_list[0] - real_api_str = Const.SEP.join(api_name_str_list[1:-2]) - api_list = load_yaml(yaml_path) - supported_tensor_api_list = api_list.get(MsCompareConst.SUPPORTED_TENSOR_LIST_KEY) - if api_type_str in (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL) \ - and global_context.get_framework() == Const.MS_FRAMEWORK: - return True - if api_type_str in MsCompareConst.MT_VALID_API_TYPES \ - and global_context.get_framework() == Const.MT_FRAMEWORK: - return True - if api_type_str == MsCompareConst.TENSOR_API and real_api_str in supported_tensor_api_list \ - and global_context.get_framework() == Const.MS_FRAMEWORK: - return True - return False - def parse(self, api_info_path): api_info_dict = load_json(api_info_path) @@ -178,9 +275,9 @@ class ApiAccuracyChecker: MsCompareConst.TENSOR_TASK)) try: framework = check_and_get_from_json_dict(api_info_dict, MsCompareConst.FRAMEWORK, - "framework field in api_info.json", accepted_type=str, - accepted_value=(Const.MS_FRAMEWORK, - Const.MT_FRAMEWORK)) + "framework field in api_info.json", accepted_type=str, + accepted_value=(Const.MS_FRAMEWORK, + Const.MT_FRAMEWORK)) except Exception as e: framework = Const.MS_FRAMEWORK logger.warning(f"JSON parsing error in framework field: {e}") @@ -296,4 +393,4 @@ class ApiAccuracyChecker: elif process_result_packet.process_status == MsCompareConst.ProcessStatus.EXCEPTION_SKIP: self.data_manager.record_exception_skip(api_name_str, Const.BACKWARD, process_result_packet.err_msg) - self.data_manager.save_results(api_name_str) + self.data_manager.save_results(api_name_str) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_runner.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_runner.py index f42702be0b114e40e5e31dc4326bd9ca21f82202..30397ba956dfa38ddface89514a36a18d2297bcb 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_runner.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_runner.py @@ -13,13 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import ( + Any, + Dict, + List, + Tuple, + Union +) +import os +import numpy as np import mindspore from mindspore import ops -from msprobe.core.common.const import Const, MsCompareConst +from msprobe.core.common.const import Const from msprobe.core.common.exceptions import ApiAccuracyCheckerException from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement from msprobe.mindspore.api_accuracy_checker.type_mapping import float_dtype_str_list, torch_dtype_to_dtype_str from msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple +from msprobe.mindspore.api_accuracy_checker.bench_functions.fusion_operator import fusion +from msprobe.mindspore.common.const import MsCompareConst from msprobe.mindspore.common.log import logger @@ -35,6 +46,21 @@ if torch_mindtorch_importer.is_valid_pt_mt_env: else: import torch +# 为了可读性,我们先给每种返回形态起个别名 +ForwardResult = Tuple[ + List[ComputeElement], + Tuple[Any, ...], + Dict[str, Any], + Tuple[Any, ...], +] + +BackwardResultMT = Tuple[ + List[ComputeElement], + Union[Any, Tuple[Any, ...]], + Tuple[Any, ...], +] + +PyTorchBackward = List[ComputeElement] class ApiInputAggregation: @@ -64,7 +90,9 @@ api_parent_module_mapping = { (MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): mindtorch_func, (MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): torch.nn.functional, (MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): mindtorch_dist, - (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): torch.distributed + (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): torch.distributed, + (MsCompareConst.FUNCTIONAL_API, Const.MS_FRAMEWORK): mindspore.ops, + (MsCompareConst.FUSION_API, Const.PT_FRAMEWORK): fusion } @@ -83,7 +111,9 @@ api_parent_module_str_mapping = { (MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): "mindtorch_func", (MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): "torch.nn.functional", (MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): "mindtorch_dist", - (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): "torch.distributed" + (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): "torch.distributed", + (MsCompareConst.FUNCTIONAL_API, Const.MS_FRAMEWORK): "mindspore.ops", + (MsCompareConst.FUSION_API, Const.PT_FRAMEWORK): "fusion" } @@ -122,38 +152,45 @@ class ApiRunner: """ api_name_list = api_name_str.split(Const.SEP) if len(api_name_list) != 3: - err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format" - logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue)) + err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format." \ + f" Exception has been raised and will be captured/logged externally." + logger.warning_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue)) api_type_str, api_sub_name = api_name_list[0], api_name_list[1] - if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API] \ + if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API, + MsCompareConst.FUNCTIONAL_API] \ and api_platform == Const.MS_FRAMEWORK: - err_msg = f"ApiRunner.get_info_from_name failed: not mint, mint.nn.functional or Tensor api" - logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue)) + err_msg = f"ApiRunner.get_info_from_name failed: not mint, mint.nn.functional or Tensor api," \ + f" api_name={api_name_str}. Exception has been raised and will be captured/logged externally." + logger.warning_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue)) if api_type_str not in MsCompareConst.MT_VALID_API_TYPES and api_platform == Const.MT_FRAMEWORK: - err_msg = f"ApiRunner.get_info_from_name failed: not torch, functional or Tensor api" - logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue)) + err_msg = f"ApiRunner.get_info_from_name failed: not torch, functional or Tensor api," \ + f" api_name={api_name_str}. Exception has been raised and will be captured/logged externally." + logger.warning_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue)) return api_type_str, api_sub_name @staticmethod def get_api_instance(api_type_str, api_sub_name, api_platform): """ Args: - api_type_str: str, Union["MintFunctional", "Mint", "Tensor"] + api_type_str: str, Union["MintFunctional", "Mint", "Tensor", "Functional"] api_sub_name: str, e.g. "relu" - api_platform: str: Union["mindpore", "torch"] + api_platform: str: Union["mindspore", "pytorch"] Return: api_instance: function object Description: - get mindspore.mint/torch api fucntion + get mindspore.mint/torch api function mindspore.mint.{api_sub_name} <--> torch.{api_sub_name} mindspore.mint.nn.functional.{api_sub_name} <--> torch.nn.functional.{api_sub_name} """ - - api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform)) - api_parent_module_str = api_parent_module_str_mapping.get((api_type_str, api_platform)) + if api_sub_name in MsCompareConst.SUPPORTED_FUSION_LIST and api_platform == "pytorch": + api_parent_module = api_parent_module_mapping.get((MsCompareConst.FUSION_API, api_platform)) + api_parent_module_str = api_parent_module_str_mapping.get((MsCompareConst.FUSION_API, api_platform)) + else: + api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform)) + api_parent_module_str = api_parent_module_str_mapping.get((api_type_str, api_platform)) full_api_name = api_parent_module_str + Const.SEP + api_sub_name if not hasattr(api_parent_module, api_sub_name): @@ -168,7 +205,12 @@ class ApiRunner: return api_instance @staticmethod - def run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform): + def run_api( + api_instance, + api_input_aggregation, + forward_or_backward: str, + api_platform: str, + ) -> Union[ForwardResult, BackwardResultMT, PyTorchBackward]: inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform) for compute_element in api_input_aggregation.inputs) kwargs = {key: value.get_parameter(get_origin=False, tensor_platform=api_platform) @@ -179,6 +221,8 @@ class ApiRunner: forward_result = api_instance(*inputs, **kwargs) # can be single tensor or tuple forward_result_tuple = convert_to_tuple(forward_result) res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in forward_result_tuple] + if api_platform == Const.MS_FRAMEWORK or api_platform == Const.MT_FRAMEWORK: + return res_compute_element_list, inputs, kwargs, forward_result_tuple else: if gradient_inputs is None: err_msg = f"ApiRunner.run_api failed: run backward api but gradient_inputs is missing" @@ -196,6 +240,7 @@ class ApiRunner: backward_result = grad_func(*inputs, gradient_inputs) # can be single tensor or tuple backward_result_tuple = convert_to_tuple(backward_result) res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in backward_result_tuple] + return res_compute_element_list, gradient_inputs, backward_result_tuple else: # set requires_grad requires_grad_index = [] diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py index ead03d25ea5c2e6bb0422486f1939c5b31ee589b..da2f8ad612fcf3a42083894ff1b8e56db757f919 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py @@ -18,9 +18,10 @@ from abc import ABC, abstractmethod import mindspore import numpy as np import torch -from msprobe.core.common.const import CompareConst, MsCompareConst +from msprobe.core.common.const import CompareConst from msprobe.core.common.exceptions import ApiAccuracyCheckerException from msprobe.mindspore.common.log import logger +from msprobe.mindspore.common.const import MsCompareConst class CompareResult: diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py new file mode 100644 index 0000000000000000000000000000000000000000..edb24e695a734c459ff3e6a151cf1d4727998bf2 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py @@ -0,0 +1,580 @@ +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import namedtuple +import torch +import torch.nn as nn +import numpy as np + +from einops import rearrange + + +from msprobe.pytorch.common.utils import logger + +GTYPE = torch.float64 # arm host必须选择float64,x86环境选择float32即可,64也行。arm计算很慢,s=8k的场景建议使用x86 +SOFTMAX_BUILD_MODE = "QKV" # "MAX_SUM" + +FaForwardParams = namedtuple("FaForwardParams", + ["q", "k", "v", "drop_mask", "attn_mask", "pse", "scalar_value", "keep_prob"]) +FaBackwardParams = namedtuple("FaBackwardParams", + ["dx", "q", "k", "v", "softmax_res", "drop_mask", "pse", "scalar_value", "keep_prob"]) +RebuildSoftmaxParams = namedtuple("RebuildSoftmaxParams", + ["q", "k", "attn_mask", "pse", "scalar_value", "softmax_max", "softmax_sum"]) + + +def softmax_forward(x): + x_max = torch.max(x, dim=-1, keepdims=True)[0] + x_sub = x.sub(x_max) + y = torch.exp(x_sub) + x_sum = y.sum(dim=-1, keepdims=True) + res = y.div(x_sum) + return res, x_max, x_sum + + +def softmax_grad(dp, softmax_res): + muls = dp * softmax_res + muls_r = muls.sum(dim=-1, keepdims=True) + sub_r = dp - muls_r + res = sub_r * softmax_res + return res + + +def broadcast_kv(num_heads, num_kv_heads, kv_tensor, dtype): + # 检查维度 + if kv_tensor.dim() != 4: + raise ValueError(f"broadcast_kv: kv_tensor 必须是 4 维 (B, N_kv, S, D),但得到 {kv_tensor.shape}") + if num_kv_heads == 0 or num_kv_heads > num_heads: + raise ValueError("broadcast_kv: num_kv_heads 必须大于 0 且不超过 num_heads。") + if num_heads % num_kv_heads != 0: + raise ValueError(f"broadcast_kv: num_heads({num_heads}) 必须能被 num_kv_heads({num_kv_heads}) 整除。") + + + factor = num_heads // num_kv_heads + kv_shape = kv_tensor.shape + b = kv_shape[0] + s = kv_shape[2] + d = kv_shape[3] + kv_res = torch.zeros([b, num_heads, s, d]).to(dtype) + for i in range(num_heads): + j = i // factor + kv_res[:, i:i + 1, :, :] = kv_tensor[:, j:j + 1, :, :] + return kv_res + + +def calculate_qk(q, k, attn_mask, pse, scalar_value): + # 基本形状检查 + if q.dim() < 4 or k.dim() < 4: + raise ValueError(f"calculate_qk: q,k 必须至少 4 维,q={q.dim()},k={k.dim()}") + # 检查 head_dim 一致性 + if q.size(-1) != k.size(-1): + raise ValueError(f"calculate_qk: q.head_dim({q.size(-1)}) != k.head_dim({k.size(-1)})") + + if k.dim() != 4: + raise ValueError(f"k tensor dimension must be 4, but got {k.dim()} dimensions (shape: {k.shape})") + + if k.dim() == 3: + k = k.unsqueeze(1) # 在head维度扩展 + + if pse is None or len(pse.shape) == 0: + qk = torch.matmul(q, k.permute(0, 1, 3, 2)).mul(scalar_value) + else: + qk = (torch.matmul(q, k.permute(0, 1, 3, 2)) + pse).mul(scalar_value) + if attn_mask is None or len(attn_mask.shape) == 0: + return qk + else: + qk = qk + attn_mask.bool() * (-40000.0) # -10000 + return qk + + +def fusion_attention_forward(forward_params): + q = forward_params.q + k = forward_params.k + v = forward_params.v + drop_mask = forward_params.drop_mask + attn_mask = forward_params.attn_mask + pse = forward_params.pse + scalar_value = forward_params.scalar_value + keep_prob = forward_params.keep_prob + + # 拦截 keep_prob 为 0 的情况,防止除零 + if keep_prob == 0: + raise ValueError("fusion_attention_forward: keep_prob 不能为 0,避免除零错误。") + + qk = calculate_qk(q, k, attn_mask, pse, scalar_value) + softmax_res, softmax_max, softmax_sum = softmax_forward(qk) + if drop_mask is None or len(drop_mask.shape) == 0: + drop_res = softmax_res + else: + drop_res = softmax_res * drop_mask * (1.0 / keep_prob) + y = torch.matmul(drop_res, v) + return y, softmax_max, softmax_sum + + +def fusion_attention_backward(backward_params): + dx = backward_params.dx + q = backward_params.q + k = backward_params.k + v = backward_params.v + softmax_res = backward_params.softmax_res + drop_mask = backward_params.drop_mask + pse = backward_params.pse + scalar_value = backward_params.scalar_value + keep_prob = backward_params.keep_prob + + # 拦截 keep_prob 为 0 的情况,防止除零 + if keep_prob == 0: + raise ValueError("fusion_attention_backward: keep_prob 不能为 0,避免除零错误。") + + dp = torch.matmul(dx, v.permute(0, 1, 3, 2)) + if drop_mask is None or len(drop_mask.shape) == 0: + drop_res = softmax_res.permute(0, 1, 3, 2) + dp_drop = dp + else: + drop_res = softmax_res.mul(drop_mask).mul(1.0 / keep_prob).permute(0, 1, 3, 2) + dp_drop = dp * drop_mask * (1.0 / keep_prob) + dv = torch.matmul(drop_res, dx) + softmax_grad_res = (softmax_grad(dp_drop, softmax_res) * scalar_value) + dq = torch.matmul(softmax_grad_res, k) + dk = torch.matmul(softmax_grad_res.permute(0, 1, 3, 2), q) + return dq, dk, dv + + +def parse_bsnd_args(query, key, head_num, input_layout): + supported_input_layout = ["BSH", "SBH", "BSND", "BNSD", "TND"] + b, s1, s2, n1, n2, d, h1, h2 = None, None, None, head_num, None, None, None, None + + if not isinstance(input_layout, str) or input_layout not in supported_input_layout: + raise ValueError(f"Invalid input_layout arg which must be one of {supported_input_layout}.") + + if input_layout == "TND": + raise ValueError(f"input_layout {input_layout} does not supported for now.") + + # 防止 head_num 为 0 + if n1 == 0: + raise ValueError("parse_bsnd_args: head_num (n1) 不能为 0,避免除零错误。") + + try: + if input_layout == "BSH": + b, s1, h1 = query.shape + _, s2, h2 = key.shape + d = h1 // n1 + # 拦截 d 为 0 的情况 + if d == 0: + raise ValueError("parse_bsnd_args: 计算得到的 head_dim d 不能为 0。") + n2 = h2 // d + elif input_layout == "SBH": + s1, b, h1 = query.shape + s2, _, h2 = key.shape + d = h1 // n1 + if d == 0: + raise ValueError("parse_bsnd_args: 计算得到的 head_dim d 不能为 0。") + n2 = h2 // d + elif input_layout == "BSND": + b, s1, n1, d = query.shape + _, s2, n2, _ = key.shape + if d == 0: + raise ValueError("parse_bsnd_args: head_dim d 不能为 0。") + h1 = n1 * d + h2 = n2 * d + elif input_layout == "BNSD": + b, n1, s1, d = query.shape + _, n2, s2, _ = key.shape + if d == 0: + raise ValueError("parse_bsnd_args: head_dim d 不能为 0。") + h1 = n1 * d + h2 = n2 * d + except Exception as e: + raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e + + ret = (b, s1, s2, n1, n2, d, h1, h2, query.dtype) + return ret + + +def convert_from_bnsd(_input, input_layout): + """ + transform qkv from bnsd to input_layout. + B: batch_size + S: sequence_length + N: num_heads + D: head_dim + Args: + _input (torch.Tensor): tensor of shape (B,N,S,D) + input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND" + Returns: + tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H) + """ + if input_layout == "BSH": + # (B,N,S,D)=>(B,S,N*D) + out = rearrange(_input, 'b n s d -> b s (n d)').contiguous() + elif input_layout == "SBH": + # (B,N,S,D)=>(S,B,N*D) + out = rearrange(_input, 'b n s d -> s b (n d)').contiguous() + elif input_layout == "BSND": + # (B,N,S,D)=>(B,S,N,D) + out = rearrange(_input, 'b n s d -> b s n d').contiguous() + elif input_layout == "TND": + raise ValueError(f"input_layout {input_layout} does not supported for now.") + else: + out = _input + return out + + +def convert_to_bnsd(_input, n, input_layout): + """ + transform qkv from input_layout to bnsd. + B: batch_size + S: sequence_length + N: num_heads + D: head_dim + Args: + _input (torch.Tensor): tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H) + n (int): num_heads + input_layout (str):"BSH" or "SBH" or "BSND" or "BNSD" or "TND" + Returns: + tensor of shape (B,N,S,D) + """ + if input_layout == "BSH": + # (B,S,N*D)=>(B,N,S,D) + out = rearrange(_input, 'b s (n d) -> b n s d', n=n) + elif input_layout == "SBH": + # (S,B,N*D)=>(B,N,S,D) + out = rearrange(_input, 's b (n d) -> b n s d', n=n) + elif input_layout == "BSND": + # (B,S,N,D)=>(B,N,S,D) + out = rearrange(_input, 'b s n d -> b n s d', n=n) + elif input_layout == "TND": + raise ValueError(f"input_layout {input_layout} does not supported for now.") + else: + out = _input + if out.dim() != 4: + raise ValueError(f"convert qkv format failed with input_layout {input_layout}.") + return out.to(GTYPE) + + +def generate_attn_mask(*args): + """ + # 当sparse_mode=2、3、4时小算子到融合算子会走这个优化,反过来看就要拆解回原来的基本实现 + ===> attn_mask = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(dtype) + """ + + sparse_mode, attn_mask, b, n1, s1, s2, pre_tocken, next_tocken, dtype = args + shape = [s1, s2] + + if attn_mask is not None: + # 当FA的输入已经包含attn_mask时,可以认为已经是转换之后的mask矩阵了,有三种特殊场景,即稀疏矩阵场景,需要进行逆向还原 + if sparse_mode == 2 or sparse_mode == 3 or sparse_mode == 4: + logger.info(f"s1: {s1}, s2:{s2}, attn_mask.shape:{attn_mask.shape}, attn_mask.dtype:{attn_mask.dtype}") + + if attn_mask.dim() == 2 and attn_mask.shape[0] == 2048 and attn_mask.shape[1] == 2048: + if attn_mask.equal(torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(attn_mask.dtype)): + if sparse_mode == 2: + attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=1)) + elif sparse_mode == 3: + attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=s2 - s1 + 1)) + elif sparse_mode == 4: + attn_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1)) + attn_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1)) + attn_mask = attn_mask_u + attn_mask_l + logger.debug(f"反向转换attn_mask {attn_mask.shape}") + return attn_mask.to(dtype) + + return attn_mask.to(dtype) + + if attn_mask is not None: + if attn_mask.dim() == 2: + if attn_mask.shape[0] != s1 or attn_mask.shape[1] != s2: + raise ValueError(f"Invalid attn_mask shape `SS` {attn_mask.shape}") + shape = [s1, s2] + elif attn_mask.dim() == 4: + if attn_mask.shape[1] == 1: + shape = [b, 1, s1, s2] if b != 1 else [1, 1, s1, s2] + else: + shape = [b, n1, s1, s2] if b != 1 else [1, n1, s1, s2] + + if sparse_mode == 0: + attn_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1)) + attn_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1)) + attn_mask = attn_mask_u + attn_mask_l + elif sparse_mode == 1: # no sparse + attn_mask = torch.from_numpy(np.zeros(shape)) + elif sparse_mode == 2: + attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=1)) + elif sparse_mode == 3: + attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=s2 - s1 + 1)) + elif sparse_mode == 4: + attn_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1)) + attn_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1)) + attn_mask = attn_mask_u + attn_mask_l + # 注:不会出现sparse_mode=5的情况,该情况要求必须要传入attn_mask,且attn_mask矩阵数据格式须为BNSS或B1SS, + # 因此可以认为FA的输入已经是正确的attn_mask了 + return attn_mask.to(dtype) + + +def generate_kv(key, value, n1, n2): + # N不等长适配by cdy + if not (n1 == n2): + k_new = broadcast_kv(n1, n2, key, key.dtype) + v_new = broadcast_kv(n1, n2, value, value.dtype) + else: + k_new = key + v_new = value + return k_new, v_new + + +def rebuid_softmax_by_qkv(q, k, attn_mask, pse, scalar_value): + """ + attention = softmax(QK^T/sqrt(d))V + softmax(x_i) = e^(x_i - x_max) / sum(e^(x_i - x_max)) + """ + logger.info("Using QKV to rebuild original softmax") + qk = calculate_qk(q, k, attn_mask, pse, scalar_value) + softmax_res, _, _ = softmax_forward(qk) + return softmax_res + + +def rebuild_softmax_by_max_sum(softmax_params): + """ + attention = softmax(QK^T/sqrt(d))V + softmax(x_i) = e^(x_i - x_max_i) / x_sum_i) + """ + q = softmax_params.q + k = softmax_params.k + attn_mask = softmax_params.attn_mask + pse = softmax_params.pse + scalar_value = softmax_params.scalar_value + softmax_max = softmax_params.softmax_max + softmax_sum = softmax_params.softmax_sum + logger.info("Using softmax_max and softmax_sum to rebuild original softmax") + + qk = calculate_qk(q, k, attn_mask, pse, scalar_value) + if softmax_max.shape[-1] == 0: + raise ValueError(f"softmax_max.shape[-1] must be non-zero, softmax_max.shape: {softmax_max.shape}") + repeat_dim = qk.shape[-1] // softmax_max.shape[-1] + softmax_res = torch.exp(qk.sub(softmax_max.repeat(1, 1, 1, repeat_dim))).div( + softmax_sum.repeat(1, 1, 1, repeat_dim)) + return softmax_res + + +def get_head_num(*args, **kwargs): + if kwargs.get("head_num", None): + head_num = kwargs.get("head_num") + elif len(args) >= 4: + head_num = args[3] + else: + raise ValueError(f"Unsupported npu_fusion_attention args {args}.") + return head_num + + +def get_input_layout(*args, **kwargs): + if kwargs.get("input_layout", None): + input_layout = kwargs.get("input_layout") + elif len(args) >= 5: + input_layout = args[4] + else: + raise ValueError(f"Unsupported npu_fusion_attention args {args}.") + return input_layout + + +def npu_fusion_attention_forward_patch(*args, **kwargs): + if len(args) < 2: + raise RuntimeError("npu_fusion_attention_forward_patch: length of args should be greater than or equal to 2.") + + # query, key, value, head_num, input_layout + head_num = get_head_num(*args, **kwargs) + input_layout = get_input_layout(*args, **kwargs) + + b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], head_num, input_layout) + # 此处 d 已在 parse_bsnd_args 中检查为非零 + if n1 == n2 and s1 == s2: + logger.debug(f"running case : BNSD = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}") + else: + logger.debug(f"running case: BNSD = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {kwargs.get('sparse_mode', 0)}") + if n2 == 0: + raise ValueError("n2 不能为 0,避免除零错误。") + if not (n1 % n2 == 0 and n1 >= n2): + raise ValueError(f"N1与N2不匹配,请检查: n1 = {n1}, n2 = {n2}.") + + dims_kwargs = { + "b": b, "s1": s1, "s2": s2, "n1": n1, "n2": n2, + "d": d, "h1": h1, "h2": h2, "dtype": dtype + } + new_kwargs = { + "keep_prob": 1, # 注意:如果外部传入 keep_prob 为 0,也会在 fusion_attention_forward 中捕获 + "scalar_value": kwargs.get("scalar_value", 1 / (d ** 0.5)), + "sparse_mode": kwargs.get("sparse_mode", 0), + "prefix": kwargs.get("prefix"), + "pre_tockens": kwargs.get("pre_tockens", 2147483647), + "next_tockens": kwargs.get("next_tockens", 2147483647), + "pse": kwargs.get("pse"), + "padding_mask": kwargs.get("padding_mask"), + "attn_mask": kwargs.get("attn_mask") + } + + return args, dims_kwargs, new_kwargs + + +def npu_fusion_attention_backward_patch(*args, **kwargs): + if len(args) != 6: + raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.") + + b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], args[4], args[5]) + # 此处 d 已在 parse_bsnd_args 中检查为非零 + if n1 == n2 and s1 == s2: + logger.info(f"running case : bnsd = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}") + else: + logger.info(f"running case: bnsd = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {kwargs.get('sparse_mode', 0)}") + if n2 == 0: + raise ValueError("n2 不能为 0,避免除零错误。") + if not (n1 % n2 == 0 and n1 >= n2): + raise ValueError(f"N1与N2不匹配,请检查: n1 = {n1}, n2 = {n2}.") + + dims_kwargs = { + "b": b, "s1": s1, "s2": s2, "n1": n1, "n2": n2, + "d": d, "h1": h1, "h2": h2, "dtype": dtype + } + + new_kwargs = { + "keep_prob": 1, # 同上,fusion_attention_backward 内会拦截 keep_prob 为 0 的情况 + "scalar_value_value": kwargs.get("scalar_value_value", 1 / (d ** 0.5)), + "sparse_mode": kwargs.get("sparse_mode", 0), + "prefix": kwargs.get("prefix"), + "pre_tockens": kwargs.get("pre_tockens", 2147483647), + "next_tockens": kwargs.get("next_tockens", 2147483647), + "pse": kwargs.get("pse"), + "padding_mask": kwargs.get("padding_mask"), + "softmax_max": kwargs.get("softmax_max"), + "softmax_sum": kwargs.get("softmax_sum"), + "softmax_in": kwargs.get("softmax_in"), + "attention_in": kwargs.get("attention_in"), + "seed": kwargs.get("seed", 0), + "offset": kwargs.get("offset", 0), + "numels": kwargs.get("numels", 0), + "attn_mask": kwargs.get("attn_mask") + } + + return args, dims_kwargs, new_kwargs + + +class FlashAttentionScore(nn.Module): + def __init__(self): + super(FlashAttentionScore, self).__init__() + # You can initialize any parameters here if necessary + + def forward(self, *inputs, **kwargs): + # Extract the inputs for the attention calculation + new_args, dims_kwargs, new_kwargs = npu_fusion_attention_forward_patch(*inputs, **kwargs) + query, key, value = new_args[0], new_args[1], new_args[2] + + input_layout = get_input_layout(*inputs, **kwargs) + + n1 = dims_kwargs.get("n1") + n2 = dims_kwargs.get("n2") + s1 = dims_kwargs.get("s1") + s2 = dims_kwargs.get("s2") + b = dims_kwargs.get("b") + dtype = dims_kwargs.get("dtype") + attn_mask = new_kwargs.get("attn_mask") + keep_prob = new_kwargs.get("keep_prob") + sparse_mode = new_kwargs.get("sparse_mode") + pre_tockens = new_kwargs.get("pre_tockens") + next_tockens = new_kwargs.get("next_tokens") + pse = new_kwargs.get("real_shift") + scalar_value = new_kwargs.get("scalar_value") + + args_temp = [sparse_mode, attn_mask, b, n1, s1, s2, pre_tockens, next_tockens, dtype] + + attn_mask = generate_attn_mask(*args_temp) + query = convert_to_bnsd(query, n1, input_layout) + key = convert_to_bnsd(key, n2, input_layout) + value = convert_to_bnsd(value, n2, input_layout) + + forward_params = FaForwardParams( + q=query, + k=key, + v=value, + drop_mask=None, + attn_mask=attn_mask, + pse=pse, + scalar_value=scalar_value, + keep_prob=keep_prob + ) + + out_golden, softmax_max, softmax_sum = fusion_attention_forward(forward_params) + + # If output dimension is 5, reshape accordingly + if out_golden.dim() == 5: + out_golden = out_golden.reshape(out_golden.size(0), + out_golden.size(1) * out_golden.size(2), + out_golden.size(3), out_golden.size(4)) + + out_golden = convert_from_bnsd(out_golden, input_layout) + + # Ensure the output matches the desired layout + out_golden = out_golden.cpu(), softmax_max.repeat(1, 1, 1, 8).cpu(), softmax_sum.repeat(1, 1, 1, 8).cpu() + + return out_golden + + def backward(self, *inputs, **kwargs): + # The backward pass will be similar to what was described for the gradient computation + new_args, dims_kwargs, new_kwargs = npu_fusion_attention_backward_patch(*inputs, **kwargs) + query, key, value, dx, input_layout = new_args[0], new_args[1], new_args[2], new_args[3], new_args[5] + n1 = dims_kwargs.get("n1") + n2 = dims_kwargs.get("n2") + s1 = dims_kwargs.get("s1") + s2 = dims_kwargs.get("s2") + b = dims_kwargs.get("b") + dtype = dims_kwargs.get("dtype") + attn_mask = new_kwargs.get("attn_mask") + keep_prob = new_kwargs.get("keep_prob") + sparse_mode = new_kwargs.get("sparse_mode") + pre_tockens = new_kwargs.get("pre_tockens") + next_tockens = new_kwargs.get("next_tockens") + pse = new_kwargs.get("pse") + softmax_max = new_kwargs.get("softmax_max") + softmax_sum = new_kwargs.get("softmax_sum") + scalar_value = new_kwargs.get("scalar_value") + + args_temp = [sparse_mode, attn_mask, b, n1, s1, s2, pre_tockens, next_tockens, dtype] + attn_mask = generate_attn_mask(*args_temp) + + query = convert_to_bnsd(query, n1, input_layout) + dx = convert_to_bnsd(dx, n1, input_layout) + key = convert_to_bnsd(key, n2, input_layout) + value = convert_to_bnsd(value, n2, input_layout) + + k_new, v_new = generate_kv(key, value, n1, n2) + + if SOFTMAX_BUILD_MODE == "QKV": + softmax_res = rebuid_softmax_by_qkv(query, k_new, attn_mask, pse, scalar_value) + else: + softmax_params = RebuildSoftmaxParams(query, k_new, attn_mask, pse, scalar_value, softmax_max, softmax_sum) + softmax_res = rebuild_softmax_by_max_sum(softmax_params) + + backward_params = FaBackwardParams(dx, query, k_new, v_new, softmax_res, None, pse, scalar_value, keep_prob) + dq, dk, dv = fusion_attention_backward(backward_params) + + # Reshape as needed + if dq.dim() == 5: + dq = dq.reshape(dq.size(0), dq.size(1) * dq.size(2), dq.size(3), dq.size(4)) + if dk.dim() == 5: + dk = dk.reshape(dk.size(0), dk.size(1) * dk.size(2), dk.size(3), dk.size(4)) + if dv.dim() == 5: + dv = dv.reshape(dv.size(0), dv.size(1) * dv.size(2), dv.size(3), dv.size(4)) + + dq = convert_from_bnsd(dq, input_layout) + dk = convert_from_bnsd(dk, input_layout) + dv = convert_from_bnsd(dv, input_layout) + + return dq.cpu(), dk.cpu(), dv.cpu() diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py similarity index 34% rename from debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py rename to debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py index aace2f13cc0eeb34a51c03907c9a87a6479617c4..e1344541e89c4dafd9d49d63e3fdea117366bdd9 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py @@ -13,32 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -cipher_list = ":".join( - ["TLS_DHE_RSA_WITH_AES_128_GCM_SHA256", - "TLS_DHE_RSA_WITH_AES_256_GCM_SHA384", - "TLS_DHE_DSS_WITH_AES_128_GCM_SHA256", - "TLS_DHE_DSS_WITH_AES_256_GCM_SHA384", - "TLS_DHE_PSK_WITH_AES_128_GCM_SHA256", - "TLS_DHE_PSK_WITH_AES_256_GCM_SHA384", - "TLS_DHE_PSK_WITH_CHACHA20_POLY1305_SHA256", - "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", - "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", - "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", - "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", - "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", - "TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256", - "TLS_ECDHE_PSK_WITH_AES_128_GCM_SHA256", - "TLS_ECDHE_PSK_WITH_AES_256_GCM_SHA384", - "TLS_ECDHE_PSK_WITH_AES_128_CCM_SHA256", - "TLS_DHE_RSA_WITH_AES_128_CCM", - "TLS_DHE_RSA_WITH_AES_256_CCM", - "TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256", - "TLS_DHE_PSK_WITH_AES_128_CCM", - "TLS_DHE_PSK_WITH_AES_256_CCM", - "TLS_ECDHE_ECDSA_WITH_AES_128_CCM", - "TLS_ECDHE_ECDSA_WITH_AES_256_CCM", - "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256"] -).encode() - -STRUCT_UNPACK_MODE = "!Q" -STR_TO_BYTES_ORDER = "big" +from msprobe.mindspore.api_accuracy_checker.bench_functions.flash_attention_score import FlashAttentionScore + + +class FusionOperator: + """ + 所有融合算子的父类,定义了通用的接口和属性。 + """ + + # 初始化操作符字典 + def __init__(self): + self.flash_attention_score = None # 用于存放 FlashAttentionScore 操作符 + self._register_operators() + + def __getattr__(self, name): + """ 动态获取算子类 """ + if hasattr(self, name): + return getattr(self, name) + else: + raise AttributeError(f"'FusionOperator' object has no attribute '{name}'") + + def _register_operators(self): + """ 注册操作符到父类,以便通过 fusion.xxx 调用 """ + self.flash_attention_score = FlashAttentionScore() + + +fusion = FusionOperator() diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/cmd_parser.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/cmd_parser.py index 4af92bfa1002c419d0bd84e5dfd250b712b57136..a55df65a3772c99a6b63ebff171adb710714ab90 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/cmd_parser.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/cmd_parser.py @@ -39,6 +39,8 @@ def add_api_accuracy_checker_argument(parser): help=" The ut task result out path.") parser.add_argument("-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str, required=False, help=" the exit csv for continue") + parser.add_argument('-save_error_data', dest="save_error_data", action="store_true", + help=" Save compare failed api output.", required=False) def multi_add_api_accuracy_checker_argument(parser): @@ -49,6 +51,8 @@ def multi_add_api_accuracy_checker_argument(parser): help=" The ut task result out path.") parser.add_argument("-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str, required=False, help=" the exit csv for continue") + parser.add_argument('-save_error_data', dest="save_error_data", action="store_true", + help=" Save compare failed api output.", required=False) #以下属于多线程参数 parser.add_argument("-d", "--device", dest="device_id", nargs='+', type=int, help=" set device id to run ut, must be unique and in range 0-7", diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/compute_element.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/compute_element.py index 5dcb1421c2b96c61cad766cba7ad5c85107b5b29..26a47c7712fb263badcb43be8deb51d9741a2d6a 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/compute_element.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/compute_element.py @@ -17,7 +17,6 @@ import os import mindspore import numpy as np -import torch from mindspore._c_expression import typing from msprobe.core.common.const import Const from msprobe.core.common.exceptions import ApiAccuracyCheckerException @@ -68,8 +67,9 @@ class ComputeElement: elif compute_element_info is None: self._init_from_null_compute_element_info() else: - logger.error_log_with_exp( - "ComputeElement.__init__ failed: not init with parameter or compute_element info is not (list, dict)", + logger.warning_log_with_exp( + "ComputeElement.__init__ failed: not init with parameter or compute_element info is not (list, dict)." + " Exception has been raised and will be captured/logged externally.", ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType)) @staticmethod @@ -83,8 +83,9 @@ class ComputeElement: ms_dtype = ms_tensor.dtype dtype_str = ms_dtype_to_dtype_str.get(ms_dtype) if dtype_str not in dtype_str_to_torch_dtype: - err_msg = f"ComputeElement.transfer_to_torch_tensor failed: no matching torch dtype for {dtype_str}" - logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType)) + err_msg = f"ComputeElement.transfer_to_torch_tensor failed: no matching torch dtype" \ + f" for {dtype_str}. Exception has been raised and will be captured/logged externally." + logger.warning_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType)) else: torch_dtype = dtype_str_to_torch_dtype.get(dtype_str) @@ -110,8 +111,9 @@ class ComputeElement: dtype_str = ms_dtype_to_dtype_str.get(ms_dtype) if dtype_str not in dtype_str_to_mindtorch_dtype: - err_msg = f"ComputeElement.transfer_to_mindtorch_tensor failed: no matching mindtorch dtype for {dtype_str}" - logger.error_log_with_exp(err_msg, + err_msg = f"ComputeElement.transfer_to_mindtorch_tensor failed: no matching mindtorch dtype for" \ + f" {dtype_str}. Exception has been raised and will be captured/logged externally." + logger.warning_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType)) else: mindtorch_dtype = dtype_str_to_mindtorch_dtype.get(dtype_str) @@ -140,8 +142,9 @@ class ComputeElement: dtype_str = torch_dtype_to_dtype_str.get(torch_dtype) if dtype_str not in dtype_str_to_ms_dtype: err_msg = \ - f"ComputeElement._transfer_to_mindspore_tensor failed: no matching mindspore dtype for {dtype_str}" - logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType)) + f"ComputeElement._transfer_to_mindspore_tensor failed: no matching mindspore dtype for {dtype_str}. " \ + f"Exception has been raised and will be captured/logged externally." + logger.warning_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType)) else: ms_dtype = dtype_str_to_ms_dtype.get(dtype_str) @@ -199,8 +202,9 @@ class ComputeElement: parameter_tmp = mindspore.Tensor(ndarray, dtype=ms_dtype) else: err_msg = "ComputeElement.get_parameter failed: self.parameter type is not in " \ - "(int, float, str, slice, bool, torch.Tensor, mindspore.Tensor, MstensorMetaData)" - logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType)) + "(int, float, str, slice, bool, torch.Tensor, mindspore.Tensor, MstensorMetaData)." \ + "Exception has been raised and will be captured/logged externally." + logger.warning_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType)) # if necessary, do transfer if not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.PT_FRAMEWORK: @@ -297,8 +301,9 @@ class ComputeElement: self.shape = tuple() if not isinstance(parameter, self.supported_parameter_type): err_msg = "ComputeElement._init_with_parameter failed: " \ - "parameter type is not in (int, float, str, slice, bool, torch.Tensor, mindspore.Tensor)" - logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType)) + "parameter type is not in (int, float, str, slice, bool, torch.Tensor, mindspore.Tensor)." \ + "Exception has been raised and will be captured/logged externally." + logger.warning_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType)) if isinstance(parameter, mindspore.Tensor): self.shape = tuple(parameter.shape) self.dtype_str = ms_dtype_to_dtype_str.get(parameter.dtype) diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/data_manager.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/data_manager.py index 748adf7d02cafe3983fe1990b40b1e77e993698b..24f6eb717e7ebf8fabb59d397d493831011e1161 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/data_manager.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/data_manager.py @@ -16,12 +16,13 @@ import os import csv -from msprobe.core.common.const import Const, CompareConst, MsCompareConst +from msprobe.core.common.const import Const, CompareConst from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, read_csv from msprobe.core.common.utils import add_time_as_suffix, MsprobeBaseException from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms from msprobe.core.common.file_utils import check_file_or_directory_path from msprobe.mindspore.common.log import logger +from msprobe.mindspore.common.const import MsCompareConst class ResultCsvEntry: @@ -187,7 +188,7 @@ class DataManager: def record_exception_skip(self, api_name, forward_or_backward, err_msg): ''' - record exception_skip infomation into self.record_exception_skip. + record exception_skip information into self.record_exception_skip. self.record_exception_skip: dict{str: dict{"forward": str/None, "backward": str/None}} string in key is api_name, string in value is err_msg ''' @@ -269,7 +270,7 @@ class DataManager: entry.backward_pass_status, overall_err_msg ] - # change row if this api has excption_skip infomation + # change row if this api has exception_skip information if api_name in self.results_exception_skip: if self.results_exception_skip[api_name][Const.FORWARD] is not None: row[1] = CompareConst.SKIP diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..1546a6cd50a0b48ce8a2534787f088efa9cea15e --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py @@ -0,0 +1,460 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# 标准库 +import argparse +import json +import os +import re +import string + +# 应用程序自定义模块 +from msprobe.core.common.file_utils import ( + FileOpen, + load_json, + save_json, + make_dir, + change_mode, +) +from msprobe.core.common.utils import ( + check_file_or_directory_path, + check_op_str_pattern_valid, + is_int, +) +from msprobe.core.common.const import Const, MonitorConst, MsgConst, FileCheckConst +from msprobe.core.common.log import logger +from msprobe.core.common.decorator import recursion_depth_decorator + +OPERATOR_TYPE = ("Functional", "Tensor", "Torch", "Mint") + +API_INFO = 2 +FOUR_SEGMENT = 4 +FIVE_SEGMENT = 5 +DATA_NAME = "data_name" +API_MAX_LENGTH = 30 +PROPAGATION_LIST = [Const.FORWARD, Const.BACKWARD] +DATAMODE_LIST = ["random_data", "real_data"] +ITER_MAX_TIMES = 1000 +FRAMEWORK = 'framework' +REAL_DATA_PATH = 'real_data_path' +EXCLUED = {FRAMEWORK, REAL_DATA_PATH} + + +class APIInfo: + def __init__(self, api_full_name, api_info_dict, backward_info=None): + self.api_full_name = api_full_name + self.api_info_dict = api_info_dict + self.backward_info = backward_info + + @property + def api_type(self): + return self.api_full_name.split(Const.SEP, -1)[0] + + @classmethod + def from_json(cls, json_content, propagation): + forward_name, forward_dict = list(json_content.items())[0] + forward_info = cls(api_full_name=forward_name, api_info_dict=forward_dict) + + if propagation == Const.BACKWARD: + backward_name, backward_dict = list(json_content.items())[1] + backward_info = cls(api_full_name=backward_name, api_info_dict=backward_dict) + forward_info.backward_info = backward_info + + if not forward_info.is_supported_type(): + raise ValueError(f"type {forward_info.api_type} of API is not supported!") + + return forward_info + + def is_supported_type(self): + return self.api_type in OPERATOR_TYPE + + +class CommonConfig: + def __init__(self, json_config): + self.dump_json_path = json_config.get('dump_json_path') + self.api_name = json_config.get('api_name') + self.extract_api_path = json_config.get('extract_api_path') + self.propagation = json_config.get('propagation') + self.data_mode = json_config.get('data_mode') + self.random_seed = json_config.get('random_seed') + self.iter_times = json_config.get('iter_times') + self._check_config() + + def check_user_settings(self): + iter_t = self.iter_times + if iter_t <= 0 or iter_t > ITER_MAX_TIMES: + raise ValueError(f"iter_times should be range from 1 to {ITER_MAX_TIMES}.") + + json_file = self.extract_api_path + propagation = self.propagation + + json_content = load_json(json_file) + + # ensure the dict is not empty + if not json_content: + raise ValueError(f'json file is empty!') + + # ensure json_content is of type dict + if not isinstance(json_content, dict): + raise ValueError(f'content of json file is not a dict!') + + # ensure the length of json_content is within allowed limits + + filtered = {k: v for k, v in json_content.items() if k not in EXCLUED} + + if not filtered: + raise ValueError(f'json file is empty!') + + if len(filtered) > API_INFO: + raise ValueError(f'json file has more than one API, the API only contains forward and backward info') + + is_forward_phase = propagation == Const.FORWARD + + is_exact_api_count = len(filtered) == API_INFO + + all_keys_forward = all(k.endswith('forward') for k in filtered) + + if is_forward_phase and is_exact_api_count and all_keys_forward: + raise ValueError( + "json file has more than one API, the API only contains forward info。" + ) + + # Retrieve the first API name and dictionary + forward_item = next(iter(json_content.items()), None) + if not forward_item or not isinstance(forward_item[1], dict) or not forward_item[1]: + raise ValueError(f'Invalid forward API data in json_content!') + + # if propagation is backward, ensure json file contains forward and backward info + if propagation == Const.BACKWARD and len(filtered) < API_INFO: + raise ValueError(f'Backward propagation requires contains forward and backward info!') + + # if propagation is backward, ensure it has valid data + if propagation == Const.BACKWARD: + backward_item = list(json_content.items())[1] + if not isinstance(backward_item[1], dict) or not backward_item[1]: + raise ValueError(f'Invalid backward API data in json_content!') + + return json_content + + def _check_config(self): + if self.dump_json_path: + check_file_or_directory_path(self.dump_json_path) + if self.api_name: + check_op_str_pattern_valid(self.api_name) + if len(self.api_name) > API_MAX_LENGTH: + raise ValueError(f'API name {self.api_name} is too long!') + make_dir(os.path.dirname(self.extract_api_path)) + if self.propagation and self.propagation not in PROPAGATION_LIST: + raise ValueError(f'propagation is invalid, it should be one of {PROPAGATION_LIST}') + if self.data_mode and self.data_mode not in DATAMODE_LIST: + raise ValueError(f'data_mode is invalid, it should be one of {DATAMODE_LIST}') + if not is_int(self.random_seed): + raise ValueError(f'random_seed is invalid, it should be an int') + if not is_int(self.iter_times): + raise ValueError(f'iter_times is invalid, it should be an int') + + +class APIExtractor: + def __init__(self, api_name, dump_json_path, output_file): + self.api_name = api_name + self.dump_json_path = dump_json_path + self.output_file = output_file + self.data = None + self.framework = None + self.real_data_path = None + + def extract_op(self): + self.data = load_json(self.dump_json_path) + # 拿到 framework + self.framework = self.data.get(FRAMEWORK, None) + + new_data = {} + extract_key_pattern = re.compile(f"^{re.escape(self.api_name)}\..+") # 修改为只要包含或等于apiname即可,不需要是只包含 + + self.real_data_path = self.data.get('dump_data_dir', '') + + for key, value in self.data.get('data', {}).items(): + if extract_key_pattern.match(key): + if self.real_data_path: + value = self.load_real_data_path(value, self.real_data_path) + new_data[key] = value + + if self.real_data_path is not None: + new_data[REAL_DATA_PATH] = self.real_data_path + + # 把 framework 加进去 + if self.framework is not None: + new_data[FRAMEWORK] = self.framework + if not new_data: + logger.warning(f"Warning: The api '{self.api_name}' does not exist in the file.") + else: + save_json(self.output_file, new_data, indent=4) + logger.info( + f"The api '{self.api_name}' has been successfully extracted and saved in: {self.output_file}") + + def load_real_data_path(self, value, dump_data_dir): + parameters = [Const.INPUT_ARGS, Const.GRAD_INPUT, Const.INPUT, Const.OUTPUT, Const.GRAD_OUTPUT] + for parameter in parameters: + for v in value.get(parameter, []): + if v is not None: + self.update_data_name(v, dump_data_dir) + return value + + @recursion_depth_decorator("OpGenerator: APIExtractor.update_data_name") + def update_data_name(self, data, dump_data_dir): + if isinstance(data, list): + for item in data: + self.update_data_name(item, dump_data_dir) + elif DATA_NAME in data: + data[DATA_NAME] = os.path.join(dump_data_dir, data[DATA_NAME]) + + +class OperatorScriptGenerator: + def __init__(self, common_config, args_info_forward, kwargs_info_forward, args_info_backward): + self.common_config = common_config + self.args_info_forward = args_info_forward + self.kwargs_info_forward = kwargs_info_forward + self.args_info_backward = args_info_backward + + @staticmethod + def extract_detailed_api_segments(full_api_name): + """ + Function Description: + Extract the name of the API. + Parameter: + full_api_name_with_direction_status: Full name of the API. Example: torch.matmul.0.forward.output.0 + Return: + api_name: Name of api. Example: matmul, mul, etc. + full_api_name: Full name of api. Example: torch.matmul.0 + direction_status: Direction status of api. Example: forward, backward, etc. + """ + api_parts = full_api_name.split(Const.SEP) + api_parts_length = len(api_parts) + api_type, api_name, api_order = None, None, None + if api_parts_length == FOUR_SEGMENT: + api_type, api_name, api_order, _ = api_parts + elif api_parts_length == FIVE_SEGMENT: + api_type, prefix, api_name, api_order, _ = api_parts + api_name = Const.SEP.join([prefix, api_name]) + return api_type, api_name, api_order + + @staticmethod + def generate_forward_inputs_code(args_info): + names = [] + + def collect(info): + if isinstance(info, dict): + names.append(info["parameter_name"]) + else: + for sub in info: + collect(sub) + + collect(args_info) + + return ( + " forward_inputs = [\n" + " ComputeElement(parameter=info)\n" + " for info in (" + ", ".join(names) + ")\n" + " ]\n" + ) + + @staticmethod + def generate_kwargs_compute_element_dict_code(): + return ( + " # ---- 构造 kwargs 对应的 ComputeElement 字典 ----\n" + " kwargs_compute_element_dict = {\n" + " key_str: ComputeElement(compute_element_info=compute_element_info)\n" + " for key_str, compute_element_info in kwargs_device.items()\n" + " }\n" + ) + + @staticmethod + def generate_gradient_inputs_code(args_info_backward): + names = [] + + def collect(info): + if isinstance(info, dict): + names.append(info["parameter_name"]) + else: + for sub in info: + collect(sub) + + collect(args_info_backward) + + return ( + " # —— 构造反向梯度 ComputeElement 列表 —— #\n" + " gradient_inputs = [\n" + " ComputeElement(parameter=info)\n" + " for info in (" + ", ".join(names) + ")\n" + " ]\n" + ) + + def get_settings(self, api_full_name): + ''' + internal_settings contain all information needed for the operator program. + keys: + api_full_name: api_type.api_name.ordinal_number + api_type: type of API, one of torch.nn.functional, torch.Tensor or Torch + api_name: name of API + ordinal_number: how many times the same api has been called + direction_status: forward + random_seed: if mode is random_data, random seed is random_seed + iter_times: if mode is random_data, generate iter_times group of data; if mode is real_data, + iter_times does not matter + args_element_assignment: code for args assignment + args_list_generator_device: code for generate args list on device + args_list_generator_bench: code for generate args list on bench + kwargs_value_assignment: code for kwargs assignment + kwargs_dict_generator_device: code for generate kwargs dict on device + kwargs_dict_generator_bench: code for generate kwargs dict on bench + ''' + # Generate an internal setting dictionary based on user settings + # including API name, type, comparison standard, random seed, number of iterations and other information + internal_settings = {} + internal_settings["propagation"] = self.common_config.propagation + internal_settings["api_full_name"] = api_full_name + api_type, api_name, ordinal_number = self.extract_detailed_api_segments(api_full_name) + if api_type == "Functional": + internal_settings["api_type"] = "torch.nn.functional" + elif api_type == "Tensor": + internal_settings["api_type"] = "torch.Tensor" + else: + internal_settings["api_type"] = "torch" + internal_settings["api_name"] = api_name + internal_settings["ordinal_number"] = ordinal_number + internal_settings["direction_status"] = self.common_config.propagation + internal_settings["random_seed"] = self.common_config.random_seed + internal_settings["data_mode"] = self.common_config.data_mode + if self.common_config.data_mode == "real_data": + internal_settings["iter_times"] = 1 + else: + internal_settings["iter_times"] = self.common_config.iter_times + + internal_settings["args_info_forward"] = self.args_info_forward + internal_settings["kwargs_info_forward"] = self.kwargs_info_forward + internal_settings["args_info_backward"] = self.args_info_backward + + return internal_settings + + +def _op_generator_parser(parser): + parser.add_argument("-i", "--config_input", dest="config_input", type=str, + help=" Path of config json file", required=True) + parser.add_argument("-o", "--api_output_path", dest="api_output_path", type=str, + help=" Path of extract api_name.json.", required=True) + + +def parse_json_config(json_file_path): + if not json_file_path: + raise Exception("config_input path can not be empty, please check.") + json_config = load_json(json_file_path) + common_config = CommonConfig(json_config) + return common_config + + +def _run_operator_generate_commond(cmd_args): + common_config = parse_json_config(cmd_args.config_input) + + if common_config.dump_json_path: + api_extract = APIExtractor(common_config.api_name, common_config.dump_json_path, common_config.extract_api_path) + api_extract.extract_op() + framework = api_extract.framework + real_data_path = api_extract.real_data_path + check_file_or_directory_path(common_config.extract_api_path) + check_file_or_directory_path(cmd_args.api_output_path, isdir=True) + json_content = common_config.check_user_settings() + api_info = APIInfo.from_json(json_content, common_config.propagation) + + if common_config.propagation == Const.BACKWARD: + # read and check json + api_full_name_forward, api_info_dict_forward = api_info.api_full_name, api_info.api_info_dict + api_full_name_backward, api_info_dict_backward = (api_info.backward_info.api_full_name, + api_info.backward_info.api_info_dict) + args_info_forward = api_info_dict_forward.get(Const.INPUT_ARGS) + kwargs_info_forward = api_info_dict_forward.get(Const.INPUT_KWARGS) + if Const.GRAD_INPUT in api_info_dict_backward: + args_info_backward = api_info_dict_backward.get(Const.GRAD_INPUT) + elif Const.INPUT in api_info_dict_backward: + args_info_backward = api_info_dict_backward.get(Const.INPUT) + op_generate = OperatorScriptGenerator(common_config, args_info_forward, kwargs_info_forward, args_info_backward) + internal_settings = op_generate.get_settings(api_full_name_backward) + internal_settings[FRAMEWORK] = framework + internal_settings[REAL_DATA_PATH] = real_data_path + else: + # read and check json + api_full_name_forward, api_info_dict_forward = api_info.api_full_name, api_info.api_info_dict + + args_info_forward = api_info_dict_forward.get(Const.INPUT_ARGS) + + kwargs_info_forward = api_info_dict_forward.get(Const.INPUT_KWARGS) + + op_generate = OperatorScriptGenerator(common_config, args_info_forward, kwargs_info_forward, None) + internal_settings = op_generate.get_settings(api_full_name_forward) + internal_settings[FRAMEWORK] = framework + internal_settings[REAL_DATA_PATH] = real_data_path + + template_path = os.path.join(os.path.dirname(__file__), "operator_replication.template") + operator_script_path = os.path.join(cmd_args.api_output_path, + "{0}.py".format(internal_settings.get("api_full_name"))) + + class SafeDict(dict): + def __missing__(self, key): + # leave {key} in the output if it’s not in the dict + return '{' + key + '}' + + class RobustFormatter(string.Formatter): + def vformat(self, format_string, args, kwargs): + result = [] + # parse() 会把文本和每个占位符拆开 + for literal, field_name, format_spec, conversion in self.parse(format_string): + # 输出字面文本 + result.append(literal) + if field_name is None: + continue + try: + # 正常获取变量并格式化 + obj, _ = self.get_field(field_name, args, kwargs) + if conversion: + obj = self.convert_field(obj, conversion) + result.append(self.format_field(obj, format_spec)) + except Exception: + # 不管是 KeyError 还是 ValueError,都原样回写 {field_name[:format_spec]} + placeholder = '{' + field_name + if conversion: + placeholder += '!' + conversion + if format_spec: + placeholder += ':' + format_spec + placeholder += '}' + result.append(placeholder) + return ''.join(result) + + fmt = RobustFormatter() + with FileOpen(template_path, 'r') as ftemp, FileOpen(operator_script_path, 'w') as fout: + code_template = ftemp.read() + # 这里用 fmt.format,不用 format_map + fout.write(fmt.format(code_template, **internal_settings)) + + change_mode(operator_script_path, FileCheckConst.DATA_FILE_AUTHORITY) + + logger.info(f"Generate operator script successfully and the name is {operator_script_path}.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + _op_generator_parser(parser) + cmd_args = parser.parse_args() + _run_operator_generate_commond(cmd_args) diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template new file mode 100644 index 0000000000000000000000000000000000000000..7de12218a3f78f9be64192ad70d83bdaf6a437fe --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template @@ -0,0 +1,2081 @@ +import os +import re +import stat +import time +from enum import Enum, auto +from abc import ABC, abstractmethod +import csv +import random + +import gc +import sys +from pathlib import Path +import mindspore +from mindspore import ops + + +from tabulate import tabulate + +import logging + +import traceback + + + +def error_log_with_exp(self, msg: str, exp: Exception): + """ + msg: 你的错误提示 + exp: 你要记录的 Exception 实例 + """ + # 将 Exception 的类型、消息和 traceback 通过 exc_info 参数一并传给 .error() + self.error(msg, exc_info=(type(exp), exp, exp.__traceback__)) + +# 把它挂到 Logger 上 +logging.Logger.error_log_with_exp = error_log_with_exp + + + +# 1. 基本配置:设置日志级别为 INFO,默认输出到控制台 +logging.basicConfig(level=logging.INFO, + format='%(asctime)s [%(levelname)s] %(message)s', + datefmt='%H:%M:%S') + +logger = logging.getLogger() + + +# ======= 常数类 ======= + +class CodedException(Exception): + def __init__(self, code, error_info=''): + super().__init__() + self.code = code + self.error_info = self.err_strs.get(code) + error_info + + def __str__(self): + return self.error_info + + +class ApiAccuracyCheckerException(CodedException): + ParseJsonFailed = 0 + UnsupportType = 1 + WrongValue = 2 + ApiWrong = 3 + err_strs = { + ParseJsonFailed: "[msprobe] Api Accuracy Checker parse json failed: ", + UnsupportType: "[msprobe] Api Accuracy Checker get unsupported type: ", + WrongValue: "[msprobe] Api Accuracy Checker get wrong value: ", + ApiWrong: "[msprobe] Api Accuracy Checker something wrong with api: ", + } + + +class FileCheckConst: + """ + Class for file check const + """ + READ_ABLE = "read" + WRITE_ABLE = "write" + READ_WRITE_ABLE = "read and write" + DIRECTORY_LENGTH = 4096 + FILE_NAME_LENGTH = 255 + FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$" + FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$' + PKL_SUFFIX = ".pkl" + NUMPY_SUFFIX = ".npy" + JSON_SUFFIX = ".json" + PT_SUFFIX = ".pt" + CSV_SUFFIX = ".csv" + XLSX_SUFFIX = ".xlsx" + YAML_SUFFIX = ".yaml" + IR_SUFFIX = ".ir" + ZIP_SUFFIX = ".zip" + SHELL_SUFFIX = ".sh" + MAX_PKL_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 + MAX_NUMPY_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024 + MAX_JSON_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 + MAX_PT_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024 + MAX_CSV_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 + MAX_XLSX_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 + MAX_YAML_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 + MAX_IR_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 + MAX_ZIP_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024 + MAX_FILE_IN_ZIP_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 + COMMOM_FILE_SIZE = 1048576 # 1 * 1024 * 1024 + DIR = "dir" + FILE = "file" + DATA_DIR_AUTHORITY = 0o750 + DATA_FILE_AUTHORITY = 0o640 + FILE_SIZE_DICT = { + PKL_SUFFIX: MAX_PKL_SIZE, + NUMPY_SUFFIX: MAX_NUMPY_SIZE, + JSON_SUFFIX: MAX_JSON_SIZE, + PT_SUFFIX: MAX_PT_SIZE, + CSV_SUFFIX: MAX_CSV_SIZE, + XLSX_SUFFIX: MAX_XLSX_SIZE, + YAML_SUFFIX: MAX_YAML_SIZE, + IR_SUFFIX: MAX_IR_SIZE, + ZIP_SUFFIX: MAX_ZIP_SIZE + } + CSV_BLACK_LIST = r'^[+-=%@\+\-=%@]|;[+-=%@\+\-=%@]' + +class Const: + MAX_DEPTH = 10 + PT_FRAMEWORK = "pytorch" + MS_FRAMEWORK = "mindspore" + MT_FRAMEWORK = "mindtorch" + SEP = "." + KWARGS = 'kwargs' + INPUT = 'input' + OUTPUT = 'output' + INPUT_ARGS = 'input_args' + INPUT_KWARGS = 'input_kwargs' + GRAD_INPUT = 'grad_input' + GRAD_OUTPUT = 'grad_output' + BACKWARD = 'backward' + FORWARD = 'forward' + + +class CompareConst: + # compare result data + PASS = 'pass' + WARNING = 'Warning' + ERROR = 'error' + TRUE = 'TRUE' + FALSE = 'FALSE' + SKIP = 'SKIP' + + # compare result column name + COSINE = "Cosine" + EUC_DIST = "EucDist" + MAX_ABS_ERR = "MaxAbsErr" + MAX_RELATIVE_ERR = "MaxRelativeErr" + MIN_RELATIVE_ERR = "MinRelativeErr" + MEAN_RELATIVE_ERR = "MeanRelativeErr" + NORM_RELATIVE_ERR = "NormRelativeErr" + + # accuracy standards + COS_THRESHOLD = 0.99 + MAX_ABS_ERR_THRESHOLD = 0.001 + MAX_RELATIVE_ERR_THRESHOLD = 0.001 + COS_MAX_THRESHOLD = 0.9 + MAX_ABS_ERR_MAX_THRESHOLD = 1 + +class MsCompareConst: + # api_info field + MINT = "Mint" + MINT_FUNCTIONAL = "MintFunctional" + TENSOR_API = "Tensor" + FUNCTIONAL_API = "Functional" + FUSION_API = "FUSION" + + API_NAME_STR_LENGTH = 4 + MAX_RECURSION_DEPTH = 20 + + # Mindtorch api_info field + MINDTORCH_TENSOR = "Tensor" + MINDTORCH = "Torch" + MINDTORCH_FUNC = "Functional" + MINDTORCH_NPU = "NPU" + MINDTORCH_DIST = "Distributed" + + MT_VALID_API_TYPES = [ + MINDTORCH, MINDTORCH_FUNC, MINDTORCH_TENSOR + ] + SUPPORTED_FUSION_LIST = ["flash_attention_score"] + + TASK_FIELD = "task" + STATISTICS_TASK = "statistics" + FRAMEWORK = "framework" + TENSOR_TASK = "tensor" + DUMP_DATA_DIR_FIELD = "dump_data_dir" + DATA_FIELD = "data" + + # supported api yaml + SUPPORTED_API_LIST_FILE = "checker_support_api.yaml" + SUPPORTED_TENSOR_LIST_KEY = "tensor" + + # detail_csv + DETAIL_CSV_API_NAME = "API Name" + DETAIL_CSV_BENCH_DTYPE = "Bench Dtype" + DETAIL_CSV_TESTED_DTYPE = "Tested Dtype" + DETAIL_CSV_SHAPE = "Shape" + DETAIL_CSV_PASS_STATUS = "Status" + DETAIL_CSV_MESSAGE = "Message" + DETAIL_CSV_FILE_NAME = "accuracy_checking_details" + + # result_csv + RESULT_CSV_FORWARD_TEST_SUCCESS = "Forward Test Success" + RESULT_CSV_BACKWARD_TEST_SUCCESS = "Backward Test Success" + RESULT_CSV_FILE_NAME = "accuracy_checking_result" + + EPSILON = 1e-8 + + class ProcessStatus: + SUCCESS = "success" + API_NOT_FOUND = "api_not_found" + EXCEPTION_SKIP = "exception_skip" + +# ======= mindtorch支持 ======== + +import torch as mindtorch +from torch import Tensor as mindtorch_tensor +import torch.nn.functional as mindtorch_func +import torch.distributed as mindtorch_dist + +is_valid_pt_mt_env = True + + +def is_mindtorch(): + mindtorch_check_result = False + try: + import torch as test_torch + from mindspore import Tensor as MindsporeTensor + except ImportError: + return mindtorch_check_result + tensor = test_torch.tensor(0.0) + if isinstance(tensor, MindsporeTensor): + mindtorch_check_result = True + + return mindtorch_check_result + + +def remove_torch_related_paths(): + removed_paths = [] + if not is_mindtorch(): + return + try: + import torch as remove_torch + torch_file = remove_torch.__file__ + except ImportError: + return + + torch_dir = os.path.dirname(torch_file) + + torch_dir_path = Path(torch_dir).resolve() + parent_dir = torch_dir_path.parent + + paths_to_remove = [str(parent_dir)] + + for path in paths_to_remove: + try: + path_resolved = str(Path(path).resolve()) + except Exception as error: + logger.debug(f"Failed to resolve path {path}: {error}") + + + if path_resolved in sys.path: + index = sys.path.index(path_resolved) + removed_paths.append((path_resolved, index)) + sys.path.pop(index) + + return + + +def clear_torch_from_sys_modules(): + modules_to_remove = [] + for module in sys.modules: + if module == "torch" or module.startswith("torch."): + modules_to_remove.append(module) + + for module in modules_to_remove: + del sys.modules[module] + + +def set_pt_mt_env_invalid(): + global is_valid_pt_mt_env + is_valid_pt_mt_env = False + + +def delete_torch_paths(): + + if not is_mindtorch(): + set_pt_mt_env_invalid() + + clear_torch_from_sys_modules() + + for count_delete_env_path in range(MsCompareConst.MAX_RECURSION_DEPTH): + if not is_mindtorch(): + break + + remove_torch_related_paths() + + clear_torch_from_sys_modules() + + if count_delete_env_path >= MsCompareConst.MAX_RECURSION_DEPTH - 1: + raise Exception(f"Please check if you have a valid PyTorch and MindTorch environment, and ensure " + f"the PYTHONPATH environment variable depth does not exceed {Const.MAX_RECURSION_DEPTH}.") + + +if not is_mindtorch(): + set_pt_mt_env_invalid() + +else: + initial_sys_path = sys.path.copy() + delete_torch_paths() + + gc.collect() + + import torch + + if is_mindtorch(): + set_pt_mt_env_invalid() + + sys.path = initial_sys_path + + + +if not is_valid_pt_mt_env: + import torch + + + +# ======= 常数类 ======= + +import numpy as np +from mindspore._c_expression import typing +from mindspore.common import dtype as mstype + + +TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"] +TORCH_BOOL_TYPE = ["torch.bool"] +TORCH_INT_TYPE = ["torch.uint8", "torch.int8", "torch.int16", "torch.short", "torch.int32", "torch.int", + "torch.int64", "torch.long"] +TORCH_FLOAT_TYPE = ["torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.float", + "torch.float64", "torch.double"] +TORCH_COMPLEX_TYPE = ["torch.complex32", "torch.chalf", "torch.complex64", "torch.cfloat", "torch.complex128", "torch.cdouble"] +RAISE_PRECISION = {{ + "torch.float16": torch.float32, + "torch.half": torch.float32, + "torch.bfloat16": torch.float32, + "torch.float32": torch.float64, + "torch.float": torch.float64 +}} +THOUSANDTH_THRESHOLDING = 0.001 +BACKWARD = 'backward' +DIR = "dir" +FILE = "file" +READ_ABLE = "read" +WRITE_ABLE = "write" +READ_WRITE_ABLE = "read and write" +DIRECTORY_LENGTH = 4096 +FILE_NAME_LENGTH = 255 +SOFT_LINK_ERROR = "检测到软链接" +FILE_PERMISSION_ERROR = "文件权限错误" +INVALID_FILE_ERROR = "无效文件" +ILLEGAL_PATH_ERROR = "非法文件路径" +ILLEGAL_PARAM_ERROR = "非法打开方式" +FILE_TOO_LARGE_ERROR = "文件过大" +FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$" +FILE_SIZE_DICT = {{ + ".pkl": 1073741824, # 1 * 1024 * 1024 * 1024 + ".npy": 10737418240, # 10 * 1024 * 1024 * 1024 + ".json": 1073741824, # 1 * 1024 * 1024 * 1024 + ".pt": 10737418240, # 10 * 1024 * 1024 * 1024 + ".csv": 1073741824, # 1 * 1024 * 1024 * 1024 + ".xlsx": 1073741824, # 1 * 1024 * 1024 * 1024 + ".yaml": 1073741824, # 1 * 1024 * 1024 * 1024 + ".ir": 1073741824 # 1 * 1024 * 1024 * 1024 +}} +COMMOM_FILE_SIZE = 1048576 # 1 * 1024 * 1024 + + +INT8 = "Int8" +UINT8 = "UInt8" +INT16 = "Int16" +UINT16 = "UInt16" +INT32 = "Int32" +UINT32 = "UInt32" +INT64 = "Int64" +UINT64 = "UInt64" +FLOAT16 = "Float16" +FLOAT32 = "Float32" +FLOAT64 = "Float64" +BOOL = "Bool" +BFLOAT16 = "BFloat16" +INT4 = "Int4" + +dtype_str_to_ms_dtype = { + INT8: mstype.int8, + UINT8: mstype.uint8, + INT16: mstype.int16, + UINT16: mstype.uint16, + INT32: mstype.int32, + UINT32: mstype.uint32, + INT64: mstype.int64, + UINT64: mstype.uint64, + FLOAT16: mstype.float16, + FLOAT32: mstype.float32, + FLOAT64: mstype.float64, + BOOL: mstype.bool_, + BFLOAT16: mstype.bfloat16, + INT4: mstype.qint4x2 +} +ms_dtype_to_dtype_str = {value: key for key, value in dtype_str_to_ms_dtype.items()} + +dtype_str_to_np_dtype = { + INT8: np.int8, + UINT8: np.uint8, + INT16: np.int16, + UINT16: np.uint16, + INT32: np.int32, + UINT32: np.uint32, + INT64: np.int64, + UINT64: np.uint64, + FLOAT16: np.float16, + FLOAT32: np.float32, + FLOAT64: np.float64, + BOOL: np.bool_ +} +np_dtype_to_dtype_str = {value: key for key, value in dtype_str_to_np_dtype.items()} + +dtype_str_to_torch_dtype = { + INT8: torch.int8, + UINT8: torch.uint8, + INT16: torch.int16, + INT32: torch.int32, + INT64: torch.int64, + FLOAT16: torch.float16, + FLOAT32: torch.float32, + FLOAT64: torch.float64, + BOOL: torch.bool, + BFLOAT16: torch.bfloat16, +} +torch_dtype_to_dtype_str = {value: key for key, value in dtype_str_to_torch_dtype.items()} + + +dtype_str_to_mindtorch_dtype = { + INT8: mindtorch.int8, + UINT8: mindtorch.uint8, + INT16: mindtorch.int16, + INT32: mindtorch.int32, + INT64: mindtorch.int64, + FLOAT16: mindtorch.float16, + FLOAT32: mindtorch.float32, + FLOAT64: mindtorch.float64, + BOOL: mindtorch.bool, + BFLOAT16: mindtorch.bfloat16, +} +mindtorch_dtype_to_dtype_str = {value: key for key, value in dtype_str_to_mindtorch_dtype.items()} + +MINDSPORE_TENSOR_TYPE_STR = "mindspore.Tensor" +BOOL_TYPE_STR = "bool" +INT_TYPE_STR = "int" +FLOAT_TYPE_STR = "float" +SLICE_TYPE_STR = "slice" +TUPLE_TYPE_STR = "tuple" +STR_TYPE_STR = "str" +MINDSPORE_DTYPE_TYPE_STR = "mindspore.dtype" +TORCH_DTYPE_TYPE_STR = "torch.dtype" + +api_info_type_str_to_type = { + MINDSPORE_TENSOR_TYPE_STR: mindspore.Tensor, + BOOL_TYPE_STR: bool, + INT_TYPE_STR: int, + FLOAT_TYPE_STR: float, + SLICE_TYPE_STR: slice, + STR_TYPE_STR: str, + MINDSPORE_DTYPE_TYPE_STR: typing.Type, +} +type_to_api_info_type_str = {value: key for key, value in api_info_type_str_to_type.items()} + +DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE = np.float64 +DEFAULT_CONSTRUCT_NP_INT_DTYPE = np.float64 +DEFAULT_CONSTRUCT_NP_UINT_DTYPE = np.float64 + +float_dtype_str_list = [ + FLOAT16, + FLOAT32, + FLOAT64, + BFLOAT16, +] + +int_dtype_str_list = [ + INT8, + INT16, + INT32, + INT64, + BOOL, + INT4, +] + +uint_dtype_str_list = [ + UINT8, + UINT16, + UINT32, + UINT64, +] + +# ======= 比对类 ======= + +class CompareResult: + def __init__(self, compare_value, pass_status, err_msg): + self.compare_value = compare_value + self.pass_status = pass_status + self.err_msg = err_msg + + +class BaseCompareAlgorithm(ABC): + def __init__(self) -> None: + super().__init__() + self.compare_algorithm_name = None + self.err_msg_mapping = { + CompareConst.COSINE: { + CompareConst.PASS: "", + CompareConst.ERROR: f"cosine similarity is less than threshold: {CompareConst.COS_THRESHOLD} ", + CompareConst.SKIP: "two inputs are not valid for computing cosine similarity, skip comparing ", + }, + CompareConst.MAX_ABS_ERR: { + CompareConst.PASS: "", + CompareConst.ERROR: "max absolute difference is greater than " \ + f"threshold: {CompareConst.MAX_ABS_ERR_THRESHOLD} ", + CompareConst.SKIP: "two inputs are not valid for computing max absolute difference, skip comparing ", + }, + CompareConst.MAX_RELATIVE_ERR: { + CompareConst.PASS: "", + CompareConst.ERROR: "", + CompareConst.SKIP: "", + }, + } + + def __call__(self, bench_compute_element, tested_compute_element): + ''' + Args: + bench_compute_element: ComputeElement + tested_compute_element: ComputeElement + + Return: + compare_result: CompareResult + ''' + if self.check_validity(bench_compute_element, tested_compute_element): + compare_value = self.run_compare(bench_compute_element, tested_compute_element) + pass_status = self.check_pass(compare_value) + else: + logger.warning(f"not suitable for computing {self.compare_algorithm_name}, skip this.") + compare_value = None + pass_status = CompareConst.SKIP + + err_msg = self.err_msg_mapping.get(self.compare_algorithm_name).get(pass_status) + + compare_result = CompareResult(compare_value, pass_status, err_msg) + return compare_result + + @staticmethod + def convert_to_np_float64_ndarray(tensor): + if isinstance(tensor, mindspore.Tensor): + ndarray = tensor.astype(mindspore.float64).numpy() + elif isinstance(tensor, torch.Tensor): + ndarray = tensor.to(torch.float64, copy=True).numpy() + else: + err_msg = "BaseCompareAlgorithm.convert_to_np_float64_ndarray failed: " \ + "input is not mindspore.Tensor or torch.Tensor" + logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType)) + return ndarray + + @staticmethod + def check_two_tensor(bench_compute_element, tested_compute_element): + bench_parameter = bench_compute_element.get_parameter() + tested_parameter = tested_compute_element.get_parameter() + + bench_is_tensor = isinstance(bench_parameter, (mindspore.Tensor, torch.Tensor)) + tested_is_tensor = isinstance(tested_parameter, (mindspore.Tensor, torch.Tensor)) + shape_same = bench_compute_element.get_shape() == tested_compute_element.get_shape() + return bench_is_tensor and tested_is_tensor and shape_same + + @abstractmethod + def check_validity(self, bench_compute_element, tested_compute_element): + ''' + Args: + bench_compute_element: ComputeElement + tested_compute_element: ComputeElement + + Return: + check_res: boolean + ''' + raise NotImplementedError + + @abstractmethod + def run_compare(self, bench_compute_element, tested_compute_element): + ''' + Args: + bench_compute_element: ComputeElement + tested_compute_element: ComputeElement + + Return: + compare_value: float/int + ''' + raise NotImplementedError + + @abstractmethod + def check_pass(self, compare_value): + ''' + Args: + compare_value: float/int + + Return: + pass_status: str + ''' + raise NotImplementedError + + +class CosineSimilarityCompareAlgorithm(BaseCompareAlgorithm): + def __init__(self) -> None: + super().__init__() + self.compare_algorithm_name = CompareConst.COSINE + + def check_validity(self, bench_compute_element, tested_compute_element): + return self.check_two_tensor(bench_compute_element, tested_compute_element) + + def run_compare(self, bench_compute_element, tested_compute_element): + bench_ndarray = self.convert_to_np_float64_ndarray(bench_compute_element.get_parameter()) + tested_ndarray = self.convert_to_np_float64_ndarray(tested_compute_element.get_parameter()) + + bench_norm = np.linalg.norm(bench_ndarray) + tested_norm = np.linalg.norm(tested_ndarray) + dot_product = np.dot(bench_ndarray.flatten(), tested_ndarray.flatten()) + cosine_similarity = (MsCompareConst.EPSILON + dot_product) / (MsCompareConst.EPSILON + bench_norm * tested_norm) + return cosine_similarity + + def check_pass(self, compare_value): + if compare_value > CompareConst.COS_THRESHOLD: + return CompareConst.PASS + else: + return CompareConst.ERROR + + +class MaxAbsoluteDiffCompareAlgorithm(BaseCompareAlgorithm): + def __init__(self) -> None: + super().__init__() + self.compare_algorithm_name = CompareConst.MAX_ABS_ERR + + def check_validity(self, bench_compute_element, tested_compute_element): + return self.check_two_tensor(bench_compute_element, tested_compute_element) + + def run_compare(self, bench_compute_element, tested_compute_element): + bench_ndarray = self.convert_to_np_float64_ndarray(bench_compute_element.get_parameter()) + tested_ndarray = self.convert_to_np_float64_ndarray(tested_compute_element.get_parameter()) + + max_absolute_diff = np.max(np.abs(bench_ndarray - tested_ndarray)) + return max_absolute_diff + + def check_pass(self, compare_value): + if compare_value < CompareConst.MAX_ABS_ERR_THRESHOLD: + return CompareConst.PASS + else: + return CompareConst.ERROR + + +class MaxRelativeDiffCompareAlgorithm(BaseCompareAlgorithm): + def __init__(self) -> None: + super().__init__() + self.compare_algorithm_name = CompareConst.MAX_RELATIVE_ERR + + def check_validity(self, bench_compute_element, tested_compute_element): + return self.check_two_tensor(bench_compute_element, tested_compute_element) + + def run_compare(self, bench_compute_element, tested_compute_element): + bench_ndarray = self.convert_to_np_float64_ndarray(bench_compute_element.get_parameter()) + tested_ndarray = self.convert_to_np_float64_ndarray(tested_compute_element.get_parameter()) + + abs_diff = np.abs(bench_ndarray - tested_ndarray) + bench_ndarray_nonzero = np.abs(bench_ndarray) + (bench_ndarray == 0) * MsCompareConst.EPSILON + max_relative_diff = np.max(abs_diff / bench_ndarray_nonzero) + return max_relative_diff + + def check_pass(self, compare_value): + if compare_value < CompareConst.MAX_RELATIVE_ERR_THRESHOLD: + return CompareConst.PASS + else: + return CompareConst.ERROR + + +compare_algorithms = { + CompareConst.COSINE: CosineSimilarityCompareAlgorithm(), + CompareConst.MAX_ABS_ERR: MaxAbsoluteDiffCompareAlgorithm(), + CompareConst.MAX_RELATIVE_ERR: MaxRelativeDiffCompareAlgorithm(), +} + + + +class CompareStandard(Enum): + BINARY_EQUALITY_STANDARD = auto() + ABSOLUTE_THRESHOLD_STANDARD = auto() + ULP_ERROR_STANDARD = auto() + BENCHMARK_STANDARD = auto() + THOUSANDTH_STANDARD = auto() + + +class CompareStandard(Enum): + BINARY_EQUALITY_STANDARD = auto() + ABSOLUTE_THRESHOLD_STANDARD = auto() + ULP_ERROR_STANDARD = auto() + BENCHMARK_STANDARD = auto() + THOUSANDTH_STANDARD = auto() + + +# ======== 文件操作类 ========== + +from collections import defaultdict +from functools import wraps + + +def check_and_get_from_json_dict(dict_instance, key, key_description, accepted_type=None, accepted_value=None): + ''' + Args: + dict_instance: dict, dict parsed from input json + key: str + key_description: str + accepted_type: tuple + accepted_value: Union[tuple, list] + + Return: + value, the corresponding value of "key" in "dict_instance" + + Exception: + raise ApiAccuracyCheckerException.ParseJsonFailed error when + 1. dict_instance is not a dict + 2. value is None + 3. value is not accepted type + 4. value is not accepted value + ''' + if not isinstance(dict_instance, dict): + error_info = "check_and_get_from_json_dict failed: input is not a dict" + raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info) + value = dict_instance.get(key) + if value is None: + error_info = f"check_and_get_from_json_dict failed: {key_description} is missing" + raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info) + elif accepted_type is not None and not isinstance(value, accepted_type): + error_info = f"check_and_get_from_json_dict failed: {key_description} is not accepted type: {accepted_type}" + raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info) + elif accepted_value is not None and value not in accepted_value: + error_info = f"check_and_get_from_json_dict failed: {key_description} is not accepted value: {accepted_value}" + raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info) + return value + + +def convert_to_tuple(args): + if isinstance(args, (tuple, list)): + return tuple(args) + else: + input_list = [args] + return tuple(input_list) + + +def trim_output_compute_element_list(compute_element_list, forward_or_backward): + ''' + Args: + compute_element_list: List[ComputeElement] + forward_or_backward: str, Union["forward", "backward"] + ''' + trimmed_list = [] + for compute_element in compute_element_list: + if compute_element.get_parameter() is None or \ + (forward_or_backward == Const.BACKWARD and compute_element.get_dtype() not in float_dtype_str_list): + # trim case: 1. parameter is None. 2. backward output has non float parameter + continue + trimmed_list.append(compute_element) + return trimmed_list + + + + +# 记录工具函数递归的深度 +recursion_depth = defaultdict(int) + + +def recursion_depth_decorator(func_info, max_depth=Const.MAX_DEPTH): + """装饰一个函数,当函数递归调用超过限制时,抛出异常并打印函数信息。""" + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + func_id = id(func) + recursion_depth[func_id] += 1 + + try: + result = func(*args, **kwargs) + finally: + recursion_depth[func_id] -= 1 + return result + + return wrapper + + return decorator + + + +class FileChecker: + """ + The class for check file. + + Attributes: + file_path: The file or dictionary path to be verified. + path_type: file or dictionary + ability(str): FileCheckConst.WRITE_ABLE or FileCheckConst.READ_ABLE to set file has writability or readability + file_type(str): The correct file type for file + """ + + def __init__(self, file_path, path_type, ability=None, file_type=None, is_script=True): + self.file_path = file_path + self.path_type = self._check_path_type(path_type) + self.ability = ability + self.file_type = file_type + self.is_script = is_script + + @staticmethod + def _check_path_type(path_type): + if path_type not in [FileCheckConst.DIR, FileCheckConst.FILE]: + logger.error(f'The path_type must be {FileCheckConst.DIR} or {FileCheckConst.FILE}.') + raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR) + return path_type + + def common_check(self): + """ + 功能:用户校验基本文件权限:软连接、文件长度、是否存在、读写权限、文件属组、文件特殊字符 + 注意:文件后缀的合法性,非通用操作,可使用其他独立接口实现 + """ + check_path_exists(self.file_path) + check_link(self.file_path) + self.file_path = os.path.realpath(self.file_path) + check_path_length(self.file_path) + check_path_type(self.file_path, self.path_type) + self.check_path_ability() + if self.is_script: + check_path_owner_consistent(self.file_path) + check_path_pattern_valid(self.file_path) + check_common_file_size(self.file_path) + check_file_suffix(self.file_path, self.file_type) + if self.path_type == FileCheckConst.FILE: + check_dirpath_before_read(self.file_path) + return self.file_path + + def check_path_ability(self): + if self.ability == FileCheckConst.WRITE_ABLE: + check_path_writability(self.file_path) + if self.ability == FileCheckConst.READ_ABLE: + check_path_readability(self.file_path) + if self.ability == FileCheckConst.READ_WRITE_ABLE: + check_path_readability(self.file_path) + check_path_writability(self.file_path) + + +class FileOpen: + """ + The class for open file by a safe way. + + Attributes: + file_path: The file or dictionary path to be opened. + mode(str): The file open mode + """ + SUPPORT_READ_MODE = ["r", "rb"] + SUPPORT_WRITE_MODE = ["w", "wb", "a", "ab"] + SUPPORT_READ_WRITE_MODE = ["r+", "rb+", "w+", "wb+", "a+", "ab+"] + + def __init__(self, file_path, mode, encoding='utf-8'): + self.file_path = file_path + self.mode = mode + self.encoding = encoding + self._handle = None + + def __enter__(self): + self.check_file_path() + binary_mode = "b" + if binary_mode not in self.mode: + self._handle = open(self.file_path, self.mode, encoding=self.encoding) + else: + self._handle = open(self.file_path, self.mode) + return self._handle + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._handle: + self._handle.close() + + def check_file_path(self): + support_mode = self.SUPPORT_READ_MODE + self.SUPPORT_WRITE_MODE + self.SUPPORT_READ_WRITE_MODE + if self.mode not in support_mode: + logger.error("File open not support %s mode" % self.mode) + raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR) + check_link(self.file_path) + self.file_path = os.path.realpath(self.file_path) + check_path_length(self.file_path) + self.check_ability_and_owner() + check_path_pattern_valid(self.file_path) + if os.path.exists(self.file_path): + check_common_file_size(self.file_path) + check_dirpath_before_read(self.file_path) + + def check_ability_and_owner(self): + if self.mode in self.SUPPORT_READ_MODE: + check_path_exists(self.file_path) + check_path_readability(self.file_path) + check_path_owner_consistent(self.file_path) + if self.mode in self.SUPPORT_WRITE_MODE and os.path.exists(self.file_path): + check_path_writability(self.file_path) + check_path_owner_consistent(self.file_path) + if self.mode in self.SUPPORT_READ_WRITE_MODE and os.path.exists(self.file_path): + check_path_readability(self.file_path) + check_path_writability(self.file_path) + check_path_owner_consistent(self.file_path) + + +def check_link(path): + abs_path = os.path.abspath(path) + if os.path.islink(abs_path): + logger.error('The file path {} is a soft link.'.format(path)) + raise FileCheckException(FileCheckException.SOFT_LINK_ERROR) + + +def check_path_length(path, name_length=None): + file_max_name_length = name_length if name_length else FileCheckConst.FILE_NAME_LENGTH + if len(path) > FileCheckConst.DIRECTORY_LENGTH or \ + len(os.path.basename(path)) > file_max_name_length: + logger.error('The file path length exceeds limit.') + raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) + + +def check_path_exists(path): + if not os.path.exists(path): + logger.error('The file path %s does not exist.' % path) + raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) + + +def check_path_readability(path): + if not os.access(path, os.R_OK): + logger.error('The file path %s is not readable.' % path) + raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) + + +def check_path_writability(path): + if not os.access(path, os.W_OK): + logger.error('The file path %s is not writable.' % path) + raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) + + +def check_path_executable(path): + if not os.access(path, os.X_OK): + logger.error('The file path %s is not executable.' % path) + raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) + + +def check_other_user_writable(path): + st = os.stat(path) + if st.st_mode & 0o002: + logger.error('The file path %s may be insecure because other users have write permissions. ' % path) + raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) + + +def check_path_owner_consistent(path): + file_owner = os.stat(path).st_uid + if file_owner != os.getuid() and os.getuid() != 0: + logger.error('The file path %s may be insecure because is does not belong to you.' % path) + raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) + + +def check_path_pattern_valid(path): + if not re.match(FileCheckConst.FILE_VALID_PATTERN, path): + logger.error('The file path %s contains special characters.' % (path)) + raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) + + +def check_file_size(file_path, max_size): + try: + file_size = os.path.getsize(file_path) + except OSError as os_error: + logger.error(f'Failed to open "{file_path}". {str(os_error)}') + raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) from os_error + if file_size >= max_size: + logger.error(f'The size ({file_size}) of {file_path} exceeds ({max_size}) bytes, tools not support.') + raise FileCheckException(FileCheckException.FILE_TOO_LARGE_ERROR) + + +def check_common_file_size(file_path): + if os.path.isfile(file_path): + for suffix, max_size in FileCheckConst.FILE_SIZE_DICT.items(): + if file_path.endswith(suffix): + check_file_size(file_path, max_size) + return + check_file_size(file_path, FileCheckConst.COMMOM_FILE_SIZE) + + +def check_file_suffix(file_path, file_suffix): + if file_suffix: + if not file_path.endswith(file_suffix): + logger.error(f"The {file_path} should be a {file_suffix} file!") + raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) + + +def check_path_type(file_path, file_type): + if file_type == FileCheckConst.FILE: + if not os.path.isfile(file_path): + logger.error(f"The {file_path} should be a file!") + raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) + if file_type == FileCheckConst.DIR: + if not os.path.isdir(file_path): + logger.error(f"The {file_path} should be a dictionary!") + raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) + +def make_dir(dir_path): + check_path_before_create(dir_path) + dir_path = os.path.realpath(dir_path) + if os.path.isdir(dir_path): + return + try: + os.makedirs(dir_path, mode=FileCheckConst.DATA_DIR_AUTHORITY, exist_ok=True) + except OSError as ex: + raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, + f"Failed to create {dir_path}. " + f"Please check the path permission or disk space. {str(ex)}") from ex + file_check = FileChecker(dir_path, FileCheckConst.DIR) + file_check.common_check() + + + + +@recursion_depth_decorator('msprobe.core.common.file_utils.create_directory', max_depth=16) +def create_directory(dir_path): + """ + Function Description: + creating a safe directory with specified permissions + Parameter: + dir_path: directory path + Exception Description: + when invalid data throw exception + """ + check_link(dir_path) + check_path_before_create(dir_path) + dir_path = os.path.realpath(dir_path) + parent_dir = os.path.dirname(dir_path) + if not os.path.isdir(parent_dir): + create_directory(parent_dir) + make_dir(dir_path) + + +def check_path_before_create(path): + check_link(path) + if path_len_exceeds_limit(path): + raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, 'The file path length exceeds limit.') + + if not re.match(FileCheckConst.FILE_PATTERN, os.path.realpath(path)): + raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, + 'The file path {} contains special characters.'.format(path)) + + +def check_dirpath_before_read(path): + path = os.path.realpath(path) + dirpath = os.path.dirname(path) + + +def check_file_or_directory_path(path, isdir=False): + """ + Function Description: + check whether the path is valid + Parameter: + path: the path to check + isdir: the path is dir or file + Exception Description: + when invalid data throw exception + """ + if isdir: + path_checker = FileChecker(path, FileCheckConst.DIR, FileCheckConst.WRITE_ABLE) + else: + path_checker = FileChecker(path, FileCheckConst.FILE, FileCheckConst.READ_ABLE) + path_checker.common_check() + + +def change_mode(path, mode): + if not os.path.exists(path) or os.path.islink(path): + return + try: + os.chmod(path, mode) + except PermissionError as ex: + raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR, + 'Failed to change {} authority. {}'.format(path, str(ex))) from ex + + +def path_len_exceeds_limit(file_path): + return len(os.path.realpath(file_path)) > FileCheckConst.DIRECTORY_LENGTH or \ + len(os.path.basename(file_path)) > FileCheckConst.FILE_NAME_LENGTH + +def load_npy(filepath): + check_file_or_directory_path(filepath) + try: + npy = np.load(filepath, allow_pickle=False) + except Exception as e: + logger.error(f"The numpy file failed to load. Please check the path: {filepath}.") + raise RuntimeError(f"Load numpy file {filepath} failed.") from e + return npy + +def write_csv(data, filepath, mode="a+", malicious_check=False): + def csv_value_is_valid(value: str) -> bool: + if not isinstance(value, str): + return True + try: + # -1.00 or +1.00 should be considered as digit numbers + float(value) + except ValueError: + # otherwise, they will be considered as formular injections + return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value)) + return True + + if malicious_check: + for row in data: + for cell in row: + if not csv_value_is_valid(cell): + raise RuntimeError(f"Malicious value [{cell}] is not allowed " + f"to be written into the csv: {filepath}.") + + check_path_before_create(filepath) + file_path = os.path.realpath(filepath) + try: + with FileOpen(filepath, mode, encoding='utf-8-sig') as f: + writer = csv.writer(f) + writer.writerows(data) + except Exception as e: + logger.error(f'Save csv file "{os.path.basename(file_path)}" failed') + raise RuntimeError(f"Save csv file {file_path} failed.") from e + change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY) + print(f"file_path:{file_path}") + + + +def write_csv_header(csv_path, header_func): + """如果是第一次写入,则写入 CSV 表头""" + header = header_func() # 获取表头 + logger.debug(f"Writing CSV header: {header}") + write_csv([header], csv_path, mode="a+") + + +def get_result_csv_header(): + """获取结果 CSV 文件的表头""" + return [ + MsCompareConst.DETAIL_CSV_API_NAME, + MsCompareConst.RESULT_CSV_FORWARD_TEST_SUCCESS, + MsCompareConst.RESULT_CSV_BACKWARD_TEST_SUCCESS, + MsCompareConst.DETAIL_CSV_MESSAGE, + ] + + +def get_detail_csv_header(): + """获取详细 CSV 文件的表头""" + detail_csv_header_basic_info = [ + MsCompareConst.DETAIL_CSV_API_NAME, + MsCompareConst.DETAIL_CSV_BENCH_DTYPE, + MsCompareConst.DETAIL_CSV_TESTED_DTYPE, + MsCompareConst.DETAIL_CSV_SHAPE, + ] + detail_csv_header_compare_result = list(compare_algorithms.keys()) + detail_csv_header_status = [ + MsCompareConst.DETAIL_CSV_PASS_STATUS, + MsCompareConst.DETAIL_CSV_MESSAGE, + ] + return detail_csv_header_basic_info + detail_csv_header_compare_result + detail_csv_header_status + + +def check_csv_header(headers, required_constants, csv_path): + """校验 CSV 文件表头是否包含所有必需的常量""" + missing_constants = [const for const in required_constants if not any(const in header for header in headers)] + + if missing_constants: + raise MsprobeBaseException( + MsprobeBaseException.MISSING_HEADER_ERROR, + f"{csv_path} 缺少以下必需的表头字段: {missing_constants}" + ) +def add_time_as_suffix(name): + return '{}_{}.csv'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))) + + +# ======= 结果落盘管理类 ======== + +class DataManager: + def __init__(self, csv_dir, result_csv_path): + self.results = {} + self.results_exception_skip = {} + self.is_first_write = True # 标记用于添加表头 + self.csv_dir = csv_dir + self.api_names_set = set() # 存储已经出现的 API 名称的集合 + # 如果传入了 result_csv_path,则启用断点续检 + if result_csv_path: + self.resume_from_last_csv(result_csv_path) + self.initialize_api_names_set(result_csv_path) + else: + # 默认情况下,设置输出路径为空,等待首次写入时初始化 + self.result_out_path = os.path.join(self.csv_dir, add_time_as_suffix(MsCompareConst.RESULT_CSV_FILE_NAME)) + self.detail_out_path = os.path.join( + self.csv_dir, + os.path.basename(self.result_out_path).replace("result", "details") + ) + + if self.detail_out_path and os.path.exists(self.detail_out_path): + check_file_or_directory_path(self.detail_out_path) + + if self.result_out_path and os.path.exists(self.result_out_path): + check_file_or_directory_path(self.result_out_path) + + def initialize_api_names_set(self, result_csv_path): + """读取现有的 CSV 文件并存储已经出现的 API 名称到集合中""" + # 使用新的 read_csv 函数读取数据 + csv_data = read_csv(result_csv_path, as_pd=False) + + # 读取标题行 + headers = csv_data[0] if csv_data else [] # 如果文件为空,则 headers 会为空 + + # 使用提取的表头校验函数 + if check_csv_header(headers, get_result_csv_header(), result_csv_path): + + # 获取 "API Name" 列的索引 + api_name_index = None + for i, header in enumerate(headers): + if MsCompareConst.DETAIL_CSV_API_NAME in header: # CSV 文件的标题行包含了字节顺序标记,所以使用通过包含方式来查找 + api_name_index = i + break + + if api_name_index is None: + logger.warning(f"{result_csv_path} No column contains 'API Name'.") + return + + # 读取每一行的 API 名称 + for row in csv_data[1:]: # 跳过标题行,从第二行开始 + if row and len(row) > api_name_index: + api_name = row[api_name_index] + if api_name: + self.api_names_set.add(api_name) + + logger.debug(f"Initialized API names set from existing CSV: {self.api_names_set}") + + def is_unique_api(self, api_name): + """检查 API 名称是否唯一,如果已经存在则返回 False,否则加入集合并返回 True""" + if api_name in self.api_names_set: + return False + self.api_names_set.add(api_name) + return True + + def resume_from_last_csv(self, result_csv_path): + """从上次运行的 result_csv_path 恢复断点""" + # 获取上次的目录路径 + last_dir = os.path.dirname(result_csv_path) + + # 设置当前目录和输出路径,确保在首次写入时使用 + self.csv_dir = last_dir + self.detail_out_path = os.path.join(last_dir, os.path.basename(result_csv_path).replace("result", "details")) + if self.detail_out_path and os.path.exists(self.detail_out_path): + check_file_or_directory_path(self.detail_out_path) + self.result_out_path = result_csv_path + self.is_first_write = False + + def save_results(self, api_name_str): + if self.is_first_write: + # 直接写入表头 + logger.info("Writing CSV headers for the first time.") + write_csv_header(self.detail_out_path, get_detail_csv_header) + write_csv_header(self.result_out_path, get_result_csv_header) + self.is_first_write = False # 写入后标记为 False,避免重复写入表头 + + """写入详细输出和结果摘要并清理结果""" + logger.debug("Starting to write detailed output to CSV.") + self.to_detail_csv(self.detail_out_path) + logger.debug(f"Detailed output for {api_name_str} written to {self.detail_out_path}.") + + logger.debug("Starting to write result summary to CSV.") + self.to_result_csv(self.result_out_path) + logger.debug(f"Result summary for {api_name_str} written to {self.result_out_path}.") + + # 清理记录,准备下一次调用 + self.clear_results() + + def record(self, output_list): + if output_list is None: + return + for output in output_list: + api_real_name, forward_or_backward, basic_info, compare_result_dict = output + key = (api_real_name, forward_or_backward) + if key not in self.results: + self.results[key] = [] + self.results[key].append((basic_info, compare_result_dict)) + logger.debug(f"Complete self.results after recording: {self.results}") + + def record_exception_skip(self, api_name, forward_or_backward, err_msg): + ''' + record exception_skip information into self.record_exception_skip. + self.record_exception_skip: dict{str: dict{"forward": str/None, "backward": str/None}} + string in key is api_name, string in value is err_msg + ''' + if api_name not in self.results_exception_skip: + self.results_exception_skip[api_name] = {Const.FORWARD: None, Const.BACKWARD: None} + self.results_exception_skip[api_name][forward_or_backward] = err_msg + + def clear_results(self): + """清空 self.results 数据""" + logger.debug("Clearing self.results data.") + self.results.clear() + self.results_exception_skip.clear() + + def to_detail_csv(self, csv_path): + logger.debug("Preparing detail CSV headers and rows.") + detail_csv = [] + + detail_csv_header_compare_result = list(compare_algorithms.keys()) + + for _, results in self.results.items(): + for res in results: + basic_info, compare_result_dict = res + csv_row_basic_info = [ + basic_info.api_name, + basic_info.bench_dtype, + basic_info.tested_dtype, + basic_info.shape + ] + csv_row_compare_result = [ + compare_result_dict.get(algorithm_name).compare_value + for algorithm_name in detail_csv_header_compare_result + ] + csv_row_status = [basic_info.status, basic_info.err_msg] + csv_row = csv_row_basic_info + csv_row_compare_result + csv_row_status + detail_csv.append(csv_row) + logger.debug(f"Detail CSV row added: {csv_row}") + + logger.debug(f"Writing detail CSV to {csv_path}.") + write_csv(detail_csv, csv_path, mode="a+") + logger.debug(f"Detail CSV written successfully to {csv_path}.") + + def to_result_csv(self, csv_path): + ''' + depend on both self.results and self.results_exception_skip + ''' + logger.debug("Preparing result CSV data.") + result_csv = [] + + result_csv_dict = {} + for key, results in self.results.items(): + api_real_name, forward_or_backward = key + pass_status = CompareConst.PASS + overall_err_msg = "" + + for res in results: + basic_info, _ = res + if basic_info.status != CompareConst.PASS: + pass_status = CompareConst.ERROR + overall_err_msg += basic_info.err_msg + + overall_err_msg = "" if pass_status == CompareConst.PASS else overall_err_msg + + if api_real_name not in result_csv_dict: + result_csv_dict[api_real_name] = ResultCsvEntry() + if forward_or_backward == Const.FORWARD: + result_csv_dict[api_real_name].forward_pass_status = pass_status + result_csv_dict[api_real_name].forward_err_msg = overall_err_msg + else: + result_csv_dict[api_real_name].backward_pass_status = pass_status + result_csv_dict[api_real_name].backward_err_msg = overall_err_msg + + for api_name, entry in result_csv_dict.items(): + overall_err_msg = "" if (entry.forward_pass_status == CompareConst.PASS and + entry.backward_pass_status == CompareConst.PASS) else \ + entry.forward_err_msg + entry.backward_err_msg + row = [ + api_name, + entry.forward_pass_status, + entry.backward_pass_status, + overall_err_msg + ] + # change row if this api has exception_skip information + if api_name in self.results_exception_skip: + if self.results_exception_skip[api_name][Const.FORWARD] is not None: + row[1] = CompareConst.SKIP + row[-1] += self.results_exception_skip[api_name][Const.FORWARD] + if self.results_exception_skip[api_name][Const.BACKWARD] is not None: + row[2] = CompareConst.SKIP + row[-1] += self.results_exception_skip[api_name][Const.BACKWARD] + del self.results_exception_skip[api_name] + result_csv.append(row) + logger.debug(f"Result CSV row added: {row}") + for api_name in self.results_exception_skip: + current_exception_skip = self.results_exception_skip[api_name] + forward_status = None + backward_status = None + err_msg = "" + if current_exception_skip[Const.FORWARD] is not None: + forward_status = CompareConst.SKIP + err_msg += current_exception_skip[Const.FORWARD] + if current_exception_skip[Const.BACKWARD] is not None: + backward_status = CompareConst.SKIP + err_msg += current_exception_skip[Const.BACKWARD] + row = [api_name, forward_status, backward_status, err_msg] + result_csv.append(row) + + write_csv(result_csv, csv_path, mode="a+") + logger.debug(f"Result CSV written successfully to {csv_path}.") + + # 设置标记为 False,防止后续重复添加表头 + self.is_first_write = False + +# ======== 全局变量类 ======= + +class GlobalContext: + def __init__(self): + self.is_constructed = True + self.dump_data_dir = "" + self.framework = Const.MS_FRAMEWORK + + def init(self, is_constructed, dump_data_dir, framework): + self.is_constructed = is_constructed + self.dump_data_dir = dump_data_dir + self.framework = framework + + def get_dump_data_dir(self): + return self.dump_data_dir + + def get_is_constructed(self): + return self.is_constructed + + def get_framework(self): + return self.framework + + +global_context = GlobalContext() + +# ======== 输入类型类 ======= + +def seed_all(seed={random_seed}): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.use_deterministic_algorithms(True) + mindtorch.manual_seed(seed) + mindtorch.use_deterministic_algorithms(True) + mindspore.set_deterministic(True) + +class ApiInputAggregation: + def __init__(self, inputs, kwargs, gradient_inputs) -> None: + """ + Args: + inputs: List[ComputeElement] + kwargs: dict{str: ComputeElement} + gradient_inputs: Union[List[ComputeElement], None] + """ + self.inputs = inputs + self.kwargs = kwargs + self.gradient_inputs = gradient_inputs + + +api_parent_module_mapping = { + (MsCompareConst.MINT, Const.MS_FRAMEWORK): mindspore.mint, + (MsCompareConst.MINT, Const.PT_FRAMEWORK): torch, + (MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): mindspore.mint.nn.functional, + (MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): torch.nn.functional, + (MsCompareConst.TENSOR_API, Const.MS_FRAMEWORK): mindspore.Tensor, + (MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): torch.Tensor, + (MsCompareConst.MINDTORCH_TENSOR, Const.MT_FRAMEWORK): mindtorch_tensor, + (MsCompareConst.MINDTORCH_TENSOR, Const.PT_FRAMEWORK): torch.Tensor, + (MsCompareConst.MINDTORCH, Const.MT_FRAMEWORK): mindtorch, + (MsCompareConst.MINDTORCH, Const.PT_FRAMEWORK): torch, + (MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): mindtorch_func, + (MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): torch.nn.functional, + (MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): mindtorch_dist, + (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): torch.distributed, + (MsCompareConst.FUNCTIONAL_API, Const.MS_FRAMEWORK): mindspore.ops + +} + + +api_parent_module_str_mapping = { + (MsCompareConst.MINT, Const.MS_FRAMEWORK): "mindspore.mint", + (MsCompareConst.MINT, Const.PT_FRAMEWORK): "torch", + (MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): "mindspore.mint.nn.functional", + (MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): "torch.nn.functional", + (MsCompareConst.TENSOR_API, Const.MS_FRAMEWORK): "mindspore.Tensor", + (MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): "torch.Tensor", + (MsCompareConst.MINDTORCH_TENSOR, Const.MT_FRAMEWORK): "mindtorch_tensor", + (MsCompareConst.MINDTORCH_TENSOR, Const.PT_FRAMEWORK): "torch.Tensor", + (MsCompareConst.MINDTORCH, Const.MT_FRAMEWORK): "mindtorch", + (MsCompareConst.MINDTORCH, Const.PT_FRAMEWORK): "torch", + (MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): "mindtorch_func", + (MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): "torch.nn.functional", + (MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): "mindtorch_dist", + (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): "torch.distributed", + (MsCompareConst.FUNCTIONAL_API, Const.MS_FRAMEWORK): "mindspore.ops" +} + + +class ApiRunner: + def __call__(self, api_input_aggregation, api_name_str, forward_or_backward=Const.FORWARD, + api_platform=Const.MS_FRAMEWORK): + ''' + Args: + api_input_aggregation: ApiInputAggregation + api_name_str: str, e.g. "MintFunctional.relu.0" + forward_or_backward: str, Union["forward", "backward"] + api_platform: str, Union["mindspore", "torch", "mindtorch"] + + Return: + outputs: list[ComputeElement] + + Description: + run mindspore.mint/torch api + ''' + + api_type_str, api_sub_name = self.get_info_from_name(api_name_str, api_platform) + api_instance = self.get_api_instance(api_type_str, api_sub_name, api_platform) + + return self.run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform) + + @staticmethod + def get_info_from_name(api_name_str, api_platform=Const.MS_FRAMEWORK): + """ + Args: + api_name_str: str, the trimmed key of data dict in api_info.json. e.g. "MintFunctional.relu.0" + api_platform: str, the platform for the API, which can be either "mindspore" or "mindtorch". + It specifies which framework is being used. Default is "mindspore". + Return: + api_type_str: str, Union["MintFunctional", "Mint", "Tensor", "Torch", "Functional"] + api_sub_name: str, e.g. "relu" + """ + api_name_list = api_name_str.split(Const.SEP) + if len(api_name_list) != 3: + err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format" + logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue)) + api_type_str, api_sub_name = api_name_list[0], api_name_list[1] + if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API, + MsCompareConst.FUNCTIONAL_API] \ + and api_platform == Const.MS_FRAMEWORK: + err_msg = f"ApiRunner.get_info_from_name failed: not mint, mint.nn.functional or Tensor api" + logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue)) + + if api_type_str not in MsCompareConst.MT_VALID_API_TYPES and api_platform == Const.MT_FRAMEWORK: + err_msg = f"ApiRunner.get_info_from_name failed: not torch, functional or Tensor api" + logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue)) + return api_type_str, api_sub_name + + @staticmethod + def get_api_instance(api_type_str, api_sub_name, api_platform): + """ + Args: + api_type_str: str, Union["MintFunctional", "Mint", "Tensor", "Functional"] + api_sub_name: str, e.g. "relu" + api_platform: str: Union["mindspore", "pytorch"] + + Return: + api_instance: function object + + Description: + get mindspore.mint/torch api function + mindspore.mint.{api_sub_name} <--> torch.{api_sub_name} + mindspore.mint.nn.functional.{api_sub_name} <--> torch.nn.functional.{api_sub_name} + """ + + api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform)) + api_parent_module_str = api_parent_module_str_mapping.get((api_type_str, api_platform)) + full_api_name = api_parent_module_str + Const.SEP + api_sub_name + + if not hasattr(api_parent_module, api_sub_name): + err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not found" + logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong)) + + api_instance = getattr(api_parent_module, api_sub_name) + if not callable(api_instance): + err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not callable" + logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong)) + + return api_instance + + @staticmethod + def run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform): + inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform) + for compute_element in api_input_aggregation.inputs) + kwargs = {key: value.get_parameter(get_origin=False, tensor_platform=api_platform) + for key, value in api_input_aggregation.kwargs.items()} + gradient_inputs = api_input_aggregation.gradient_inputs + + if forward_or_backward == Const.FORWARD: + forward_result = api_instance(*inputs, **kwargs) # can be single tensor or tuple + forward_result_tuple = convert_to_tuple(forward_result) + res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in forward_result_tuple] + if api_platform == Const.MS_FRAMEWORK or api_platform == Const.MT_FRAMEWORK: + return res_compute_element_list, inputs, kwargs, forward_result_tuple + else: + if gradient_inputs is None: + err_msg = f"ApiRunner.run_api failed: run backward api but gradient_inputs is missing" + logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue)) + gradient_inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform) + for compute_element in gradient_inputs) + if api_platform == Const.MS_FRAMEWORK or api_platform == Const.MT_FRAMEWORK: + if len(gradient_inputs) == 1: + gradient_inputs = gradient_inputs[0] + + def api_with_kwargs(*forward_inputs): + return api_instance(*forward_inputs, **kwargs) + + grad_func = ops.GradOperation(get_all=True, sens_param=True)(api_with_kwargs) + backward_result = grad_func(*inputs, gradient_inputs) # can be single tensor or tuple + backward_result_tuple = convert_to_tuple(backward_result) + res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in backward_result_tuple] + return res_compute_element_list, gradient_inputs, backward_result_tuple + else: + # set requires_grad + requires_grad_index = [] + for index, tensor in enumerate(inputs): + if isinstance(tensor, torch.Tensor) and \ + torch_dtype_to_dtype_str.get(tensor.dtype) in float_dtype_str_list: + setattr(tensor, "requires_grad", True) + requires_grad_index.append(index) + forward_results = api_instance(*inputs, **kwargs) + forward_results = convert_to_tuple(forward_results) + for forward_res, gradient_in in zip(forward_results, gradient_inputs): + forward_res.backward(gradient_in) + backward_result_list = [] + for index in requires_grad_index: + backward_result_list.append(getattr(inputs[index], "grad")) + res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in backward_result_list] + + return res_compute_element_list + + +api_runner = ApiRunner() + +# ======== 数据结构类 ======== + +class ResultCsvEntry: + def __init__(self) -> None: + self.forward_pass_status = None + self.backward_pass_status = None + self.forward_err_msg = "" + self.backward_err_msg = "" + self.overall_err_msg = None + +class ProcessResultPacket: + def __init__(self, process_status, result, err_msg) -> None: + self.process_status = process_status + self.result = result + self.err_msg = err_msg + +class MstensorMetaData: + def __init__(self, dtype_str, npy_path, maximum, minimum, shape) -> None: + self.dtype_str = dtype_str + self.npy_path = npy_path + self.maximum = maximum + self.minimum = minimum + self.shape = shape + + +class DtypeMetaData: + def __init__(self, dtype_str) -> None: + self.dtype_str = dtype_str + + +class ComputeElement: + def __init__(self, compute_element_info=None, parameter=None): + self.supported_parameter_type = tuple(type_to_api_info_type_str.keys()) + tuple([torch.Tensor, tuple]) + if parameter is not None: + self._init_with_parameter(parameter) + elif isinstance(compute_element_info, (list, dict)): + self._init_from_compute_element_info(compute_element_info) + elif compute_element_info is None: + self._init_from_null_compute_element_info() + else: + pass + logger.error_log_with_exp( + "ComputeElement.__init__ failed: not init with parameter or compute_element info is not (list, dict)", + ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType)) + + @staticmethod + def transfer_to_torch_tensor(ms_tensor): + ''' + Args: + ms_tensor: mindspore.Tensor + Return: + torch_tensor: torch.Tensor + ''' + ms_dtype = ms_tensor.dtype + dtype_str = ms_dtype_to_dtype_str.get(ms_dtype) + if dtype_str not in dtype_str_to_torch_dtype: + err_msg = f"ComputeElement.transfer_to_torch_tensor failed: no matching torch dtype for {dtype_str}" + logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType)) + else: + torch_dtype = dtype_str_to_torch_dtype.get(dtype_str) + + if dtype_str in int_dtype_str_list: + middle_dtype = mindspore.int64 + else: + middle_dtype = mindspore.float64 + np_ndarray = ms_tensor.astype(middle_dtype).numpy() + torch_tensor = torch.from_numpy(np_ndarray).to(torch_dtype) + return torch_tensor + + @staticmethod + def transfer_to_mindtorch_tensor(ms_tensor): + """ + Args: + ms_tensor: mindspore.Tensor + Return: + mindtorch_tensor: mindtorch.Tensor + """ + + ms_dtype = ms_tensor.dtype + + dtype_str = ms_dtype_to_dtype_str.get(ms_dtype) + + if dtype_str not in dtype_str_to_mindtorch_dtype: + err_msg = f"ComputeElement.transfer_to_mindtorch_tensor failed: no matching mindtorch dtype for {dtype_str}" + logger.error_log_with_exp(err_msg, + ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType)) + else: + mindtorch_dtype = dtype_str_to_mindtorch_dtype.get(dtype_str) + + if dtype_str in int_dtype_str_list: + middle_dtype = mindspore.int64 + else: + middle_dtype = mindspore.float64 + + np_ndarray = ms_tensor.astype(middle_dtype).numpy() + + mindtorch_tensor = mindtorch.from_numpy(np_ndarray).to(ms_dtype) + + return mindtorch_tensor + + @staticmethod + def transfer_to_mindspore_tensor(torch_tensor): + ''' + Args: + torch_tensor: torch.Tensor + + Return: + ms_tensor: mindspore.Tensor + ''' + torch_dtype = torch_tensor.dtype + dtype_str = torch_dtype_to_dtype_str.get(torch_dtype) + if dtype_str not in dtype_str_to_ms_dtype: + err_msg = \ + f"ComputeElement._transfer_to_mindspore_tensor failed: no matching mindspore dtype for {dtype_str}" + logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType)) + else: + ms_dtype = dtype_str_to_ms_dtype.get(dtype_str) + + if dtype_str in int_dtype_str_list: + middle_dtype = torch.int64 + else: + middle_dtype = torch.float64 + np_ndarray = torch_tensor.to(middle_dtype, copy=True).numpy() + ms_tensor = mindspore.Tensor.from_numpy(np_ndarray).astype(ms_dtype) + return ms_tensor + + @staticmethod + def convert_inf_to_real_num(value, dtype_str): + if value == float("inf"): + np_dtype = dtype_str_to_np_dtype.get(dtype_str, DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE) + value = np.finfo(np_dtype).max + elif value == float("-inf"): + np_dtype = dtype_str_to_np_dtype.get(dtype_str, DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE) + value = np.finfo(np_dtype).min + return value + + def get_parameter(self, get_origin=True, tensor_platform=Const.MS_FRAMEWORK): + ''' + Args: + get_origin: boolean + tensor_platform: str, Union["mindspore", "pytorch"] + + Return: + parameter: Union[int, float, str, slice, tuple, torch.Tensor, mindspore.Tensor] + ''' + if self.parameter is None: + return self.parameter + if isinstance(self.parameter, tuple): + return tuple([compute_element.get_parameter(get_origin=get_origin, tensor_platform=tensor_platform) + for compute_element in self.parameter]) + elif isinstance(self.parameter, self.supported_parameter_type): + parameter_tmp = self.parameter + elif isinstance(self.parameter, DtypeMetaData): + if tensor_platform == Const.MS_FRAMEWORK: + parameter_tmp = dtype_str_to_ms_dtype.get(self.parameter.dtype_str) + elif tensor_platform == Const.PT_FRAMEWORK: + parameter_tmp = dtype_str_to_torch_dtype.get(self.parameter.dtype_str) + elif tensor_platform == Const.MT_FRAMEWORK: + parameter_tmp = dtype_str_to_mindtorch_dtype.get(self.parameter.dtype_str) + + elif isinstance(self.parameter, MstensorMetaData): + mstensor_meta_data = self.parameter + ms_dtype = dtype_str_to_ms_dtype.get(mstensor_meta_data.dtype_str) + if global_context.get_is_constructed(): + np_dtype = dtype_str_to_np_dtype.get(mstensor_meta_data.dtype_str, DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE) + ndarray = self._construct_ndarray(mstensor_meta_data.shape, mstensor_meta_data.maximum, + mstensor_meta_data.minimum, np_dtype) + else: + ndarray = load_npy(mstensor_meta_data.npy_path) + parameter_tmp = mindspore.Tensor(ndarray, dtype=ms_dtype) + else: + err_msg = "ComputeElement.get_parameter failed: self.parameter type is not in " \ + "(int, float, str, slice, bool, torch.Tensor, mindspore.Tensor, MstensorMetaData)" + logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType)) + + # if necessary, do transfer + if not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.PT_FRAMEWORK: + parameter = self.transfer_to_torch_tensor(parameter_tmp) + elif not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.MT_FRAMEWORK: + parameter = self.transfer_to_mindtorch_tensor(parameter_tmp) + elif not get_origin and isinstance(parameter_tmp, torch.Tensor) and tensor_platform == Const.MS_FRAMEWORK: + parameter = self.transfer_to_mindspore_tensor(parameter_tmp) + else: + parameter = parameter_tmp + + return parameter + + def get_shape(self): + return self.shape + + def get_dtype(self): + return self.dtype_str + + def _construct_ndarray(self, shape, maximum, minimum, np_dtype): + shape = tuple(shape) + np.random.seed({random_seed}) + if np_dtype == np.bool_: + ndarray = np.random.rand(*shape) > 0.5 + else: + maximum = self.convert_inf_to_real_num(maximum, np_dtype) + minimum = self.convert_inf_to_real_num(minimum, np_dtype) + ndarray = np.random.uniform(minimum, maximum, shape).astype(np_dtype) + return ndarray + + def _init_from_null_compute_element_info(self): + self.parameter = None + self.shape = tuple() + self.dtype = "None" + + def _init_from_compute_element_info(self, compute_element_info): + ''' + Args: + compute_element_info: Union[list, dict] + + Return: + void + + init member attributes: self.shape, self.dtype_str, self.parameter + ''' + if isinstance(compute_element_info, list): + self.shape = tuple() + self.dtype_str = TUPLE_TYPE_STR + self.parameter = tuple([ComputeElement(compute_element_info=sub_info) + for sub_info in compute_element_info]) + else: + type_str = check_and_get_from_json_dict(compute_element_info, "type", "type field in api_info.json", + accepted_type=str, accepted_value=api_info_type_str_to_type.keys()) + self.shape = tuple() + self.dtype_str = type_str + if type_str == MINDSPORE_TENSOR_TYPE_STR: + self._init_from_mstensor_compute_element_info(compute_element_info) + else: + value = check_and_get_from_json_dict(compute_element_info, "value", "value field in api_info.json") + if type_str == MINDSPORE_DTYPE_TYPE_STR: + self.parameter = DtypeMetaData(value) + elif type_str == SLICE_TYPE_STR: + self.parameter = slice(*tuple(value)) + else: # type_str in ("str", "int", "float", "bool") + self.parameter = value + + def _init_from_mstensor_compute_element_info(self, compute_element_info): + ''' + do not load real tensor, only record meta data + ''' + dtype_str = check_and_get_from_json_dict(compute_element_info, "dtype", "dtype field in api_info.json", + accepted_type=str, accepted_value=dtype_str_to_ms_dtype.keys()) + shape = check_and_get_from_json_dict(compute_element_info, "shape", "shape field in api_info.json", + accepted_type=(list,)) + if global_context.get_is_constructed(): + maximum = check_and_get_from_json_dict(compute_element_info, "Max", "Max field in api_info.json", + accepted_type=(int, float)) + minimum = check_and_get_from_json_dict(compute_element_info, "Min", "Min field in api_info.json", + accepted_type=(int, float)) + + npy_path = None + else: + maximum, minimum = None, None + data_name = check_and_get_from_json_dict(compute_element_info, "data_name", + "data_name field in api_info.json", accepted_type=(str,)) + npy_path = os.path.join(global_context.get_dump_data_dir(), data_name) + mstensor_meta_data = MstensorMetaData(dtype_str, npy_path, maximum, minimum, shape) + self.parameter = mstensor_meta_data + self.dtype_str = dtype_str + self.shape = tuple(shape) + + def _init_with_parameter(self, parameter): + self.parameter = parameter + print(f"parameter:{parameter}") + print(f"self.supported_parameter_type:{self.supported_parameter_type}") + if isinstance(parameter, dict): + # 这里假设 dict 中有 'type'、'shape'、'dtype' 等字段 + return self._init_from_compute_element_info(parameter) + self.shape = tuple() + if not isinstance(parameter, self.supported_parameter_type): + err_msg = "ComputeElement._init_with_parameter failed: " \ + "parameter type is not in (int, float, str, slice, bool, torch.Tensor, mindspore.Tensor)" + logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType)) + if isinstance(parameter, mindspore.Tensor): + self.shape = tuple(parameter.shape) + self.dtype_str = ms_dtype_to_dtype_str.get(parameter.dtype) + elif isinstance(parameter, torch.Tensor): + self.shape = tuple(parameter.shape) + self.dtype_str = torch_dtype_to_dtype_str.get(parameter.dtype) + elif isinstance(parameter, typing.Type): + self.dtype_str = MINDSPORE_DTYPE_TYPE_STR + self.parameter = DtypeMetaData(ms_dtype_to_dtype_str.get(parameter)) + elif isinstance(parameter, torch.dtype): + self.dtype_str = TORCH_DTYPE_TYPE_STR + self.parameter = DtypeMetaData(torch_dtype_to_dtype_str.get(parameter)) + elif isinstance(parameter, tuple): + self.dtype_str = TUPLE_TYPE_STR + self.parameter = tuple([ComputeElement(parameter=param) for param in parameter]) + else: + self.dtype_str = type_to_api_info_type_str.get(type(parameter)) + print(f"self.dtype_str{self.dtype_str}") + +class BasicInfoAndStatus: + def __init__(self, api_name, bench_dtype, tested_dtype, shape, status, err_msg) -> None: + self.api_name = api_name + self.bench_dtype = bench_dtype + self.tested_dtype = tested_dtype + self.shape = shape + self.status = status + self.err_msg = err_msg + +# ======== api执行类 ======= + +def get_input(propagation): + args_info_forward = {args_info_forward} + kwargs_info_forward = {kwargs_info_forward} + args_info_backward = {args_info_backward} + forward_inputs = [ComputeElement(compute_element_info=compute_element_info) + for compute_element_info in args_info_forward] + kwargs_compute_element_dict = { + key_str: ComputeElement(compute_element_info=compute_element_info) + for key_str, compute_element_info in kwargs_info_forward.items() + } + if args_info_backward: + gradient_inputs = [ComputeElement(compute_element_info=compute_element_info) + for compute_element_info in args_info_backward] + else: + gradient_inputs = None + return ApiInputAggregation( + forward_inputs, + kwargs_compute_element_dict, + gradient_inputs + ) + +# 运行和比对函数 +def run_and_compare_helper(api_name_str, api_input_aggregation, forward_or_backward): + """ + Args: + api_info: ApiInfo + api_name_str: str + api_input_aggregation: ApiInputAggregation + forward_or_backward: str: Union["forward", "backward"] + + Return: + output_list: List[tuple(str, str, BasicInfoAndStatus, dict{str: CompareResult})] + + Description: + get mindspore api output, run torch api and get output. + compare output. + record compare result. + """ + # get output + if forward_or_backward == Const.FORWARD: + tested_outputs, inputs, kwargs, forward_result_tuple = api_runner(api_input_aggregation, api_name_str, + forward_or_backward, + global_context.get_framework()) + print(f"inputs:{inputs}") + print(f"kwargs:{kwargs}") + print(f"forward_result_tuple:{forward_result_tuple}") + elif forward_or_backward == Const.BACKWARD: + tested_outputs, gradient_inputs, backward_result_tuple = api_runner(api_input_aggregation, api_name_str, + forward_or_backward, + global_context.get_framework()) + print(f"gradient_inputs:{gradient_inputs}") + print(f"backward_result_tuple:{backward_result_tuple}") + else: + tested_outputs = api_runner(api_input_aggregation, api_name_str, + forward_or_backward, global_context.get_framework()) + + bench_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.PT_FRAMEWORK) + + tested_outputs = trim_output_compute_element_list(tested_outputs, forward_or_backward) + bench_outputs = trim_output_compute_element_list(bench_outputs, forward_or_backward) + + # compare output + output_list = [] + for i, (bench_out, tested_out) in enumerate(zip(bench_outputs, tested_outputs)): + api_name_with_slot = Const.SEP.join([api_name_str, forward_or_backward, Const.OUTPUT, str(i)]) + bench_dtype = bench_out.get_dtype() + tested_dtype = tested_out.get_dtype() + shape = bench_out.get_shape() + + compare_result_dict = dict() + for compare_algorithm_name, compare_algorithm in compare_algorithms.items(): + compare_result = compare_algorithm(bench_out, tested_out) + compare_result_dict[compare_algorithm_name] = compare_result + + if compare_result_dict.get(CompareConst.COSINE).pass_status == CompareConst.PASS and \ + compare_result_dict.get(CompareConst.MAX_ABS_ERR).pass_status == CompareConst.PASS: + status = CompareConst.PASS + err_msg = "" + else: + status = CompareConst.ERROR + err_msg = (compare_result_dict.get(CompareConst.COSINE).err_msg + + compare_result_dict.get(CompareConst.MAX_ABS_ERR).err_msg) + + # self.pre_forward_hook(api_name_str, None, inputs, kwargs) + basic_info_status = \ + BasicInfoAndStatus(api_name_with_slot, bench_dtype, tested_dtype, shape, status, err_msg) + output_list.append(tuple([api_name_str, forward_or_backward, basic_info_status, compare_result_dict])) + return output_list + + +if __name__ == "__main__": + framework = "{framework}" + dump_data_dir = "{real_data_path}" + api_name = "{api_name}" + api_full_name = "{api_full_name}" + api_name_str = ".".join(api_full_name.split(".")[:3]) + propagation = "{propagation}" + data_mode = "{data_mode}" + seed_all({random_seed}) + + data_manager = DataManager("./op_result_output", None) + create_directory("./op_result_output") + + is_constructed = data_mode == "random_data" + global_context.init(is_constructed, dump_data_dir, framework) + + for i in range({iter_times}): + print(f"iter: {{i}}:") + if propagation == BACKWARD: + + + backward_inputs_aggregation = get_input(propagation) + + backward_output_list = run_and_compare_helper(api_name_str, backward_inputs_aggregation, + Const.BACKWARD) + process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.SUCCESS, + result=backward_output_list, err_msg="") + + + if process_result_packet.process_status is MsCompareConst.ProcessStatus.SUCCESS: + data_manager.record(process_result_packet.result) + elif process_result_packet.process_status == MsCompareConst.ProcessStatus.EXCEPTION_SKIP: + data_manager.record_exception_skip(api_name_str, Const.BACKWARD, process_result_packet.err_msg) + + data_manager.save_results(api_name_str) + else: + forward_inputs_aggregation = get_input(propagation) + + forward_output_list = run_and_compare_helper(api_name_str, forward_inputs_aggregation, + Const.FORWARD) + process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.SUCCESS, + result=forward_output_list, err_msg="") + + + if process_result_packet.process_status is MsCompareConst.ProcessStatus.SUCCESS: + data_manager.record(process_result_packet.result) + elif process_result_packet.process_status == MsCompareConst.ProcessStatus.EXCEPTION_SKIP: + data_manager.record_exception_skip(api_name_str, Const.FORWARD, process_result_packet.err_msg) + + data_manager.save_results(api_name_str) + + print("Compare finished.") \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py index e764140badf4c107ea83044353aba19a1c412fe0..37f6faa514eaf4855211f9db8ff45982c3b8b976 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py @@ -27,10 +27,14 @@ import numpy as np from tqdm import tqdm # 本地应用/库特定导入 -from msprobe.core.common.const import Const, CompareConst, MsCompareConst +from msprobe.core.common.const import Const, CompareConst from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import ApiAccuracyChecker, BasicInfoAndStatus from msprobe.mindspore.api_accuracy_checker.multi_data_manager import MultiDataManager from msprobe.mindspore.common.log import logger +from msprobe.mindspore.common.const import MsCompareConst + +from msprobe.core.data_dump.data_collector import build_data_collector +from msprobe.core.common.utils import Const, print_tools_ends_info, DumpPathAggregation class MultiApiAccuracyChecker(ApiAccuracyChecker): @@ -50,6 +54,12 @@ class MultiApiAccuracyChecker(ApiAccuracyChecker): # 初始化一个属性来存储当前的设备ID(用于日志中显示) self.current_device_id = None + self.save_error_data = args.save_error_data + if self.save_error_data: + config, dump_path_aggregation = self.init_save_error_data(args) + self.data_collector = build_data_collector(config) + self.data_collector.update_dump_paths(dump_path_aggregation) + def process_on_device(self, device_id, api_infos, progress_queue): """ 在特定设备上处理一部分API。 diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py index 84f2706cc55fa3d0a1fba13d54ba8310371f1a43..13e2645ea14932afa3ac3e9ea131e443b2ee931e 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py @@ -19,7 +19,8 @@ import sys from pathlib import Path import mindspore from msprobe.mindspore.common.log import logger -from msprobe.core.common.const import Const, CompareConst, MsCompareConst +from msprobe.core.common.const import Const, CompareConst +from msprobe.mindspore.common.const import MsCompareConst import torch as mindtorch from torch import Tensor as mindtorch_tensor import torch.nn.functional as mindtorch_func @@ -107,7 +108,8 @@ def delete_torch_paths(): if count_delete_env_path >= MsCompareConst.MAX_RECURSION_DEPTH - 1: raise Exception(f"Please check if you have a valid PyTorch and MindTorch environment, and ensure " - f"the PYTHONPATH environment variable depth does not exceed {Const.MAX_RECURSION_DEPTH}.") + f"the PYTHONPATH environment variable depth does not " + f"exceed {MsCompareConst.MAX_RECURSION_DEPTH}.") if not is_mindtorch(): diff --git a/debug/accuracy_tools/msprobe/mindspore/cell_processor.py b/debug/accuracy_tools/msprobe/mindspore/cell_processor.py index 6dc5d510ef51ab2a135a8bdf9f15ac670fba9e56..8589e86f6ca2568c77f519aecb0e04a95160ec3c 100644 --- a/debug/accuracy_tools/msprobe/mindspore/cell_processor.py +++ b/debug/accuracy_tools/msprobe/mindspore/cell_processor.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,21 +13,55 @@ # See the License for the specific language governing permissions and # limitations under the License. -from msprobe.core.data_dump.scope import ModuleRangeScope, MixRangeScope +import threading +from collections import OrderedDict + +from mindspore import Tensor +from mindspore.common.hook_handle import HookHandle +from mindspore.ops.operations import _inner_ops as inner + from msprobe.core.common.const import Const +from msprobe.core.common.exceptions import MsprobeException +from msprobe.core.common.runtime import Runtime +from msprobe.core.common.utils import ModuleQueue, ThreadSafe +from msprobe.core.data_dump.scope import ModuleRangeScope, MixRangeScope, BaseScope +from msprobe.mindspore.common.const import Const as MsConst +from msprobe.mindspore.common.log import logger +from msprobe.mindspore.common.utils import ( + is_mindtorch, + get_cells_and_names_with_index, + has_kwargs_in_forward_hook, + is_graph_mode_cell_dump_allowed, + is_backward_hook_output_a_view +) +from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +from msprobe.mindspore.dump.graph_mode_cell_dump import GraphModeCellDump + + +def get_cell_construct(construct): + def _construct(self, *args, **kwargs): + if hasattr(self, 'msprobe_hook'): + setattr(self, 'msprobe_input_kwargs', kwargs) + return construct(self, *args, **kwargs) + + return _construct class CellProcessor: + cell_queue = ModuleQueue() cell_count = {} - cell_stack = [] - api_parent_node = "" + cell_stack = {} + api_parent_node = {} module_node = {} + cell_bw_hook_kernels = {} + cell_backward_pre_hook = [] + cell_backward_hook = [] def __init__(self, scope): self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None @staticmethod - def set_cell_count(cell_name): + def set_and_get_calls_number(cell_name): if cell_name not in CellProcessor.cell_count: CellProcessor.cell_count[cell_name] = 0 else: @@ -36,44 +70,213 @@ class CellProcessor: @classmethod def reset_cell_stats(cls): + cls.cell_queue = ModuleQueue() cls.cell_count = {} - cls.cell_stack = [] - cls.api_parent_node = "" + cls.cell_stack = {} + cls.api_parent_node = {} cls.module_node = {} + cls.cell_bw_hook_kernels = {} + cls.cell_backward_pre_hook = [] + cls.cell_backward_hook = [] + + def register_cell_hook(self, models, build_hook, config: DebuggerConfig): + if not models: + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, + 'The model cannot be None, when level is "L0" or "mix"') + + is_registered = False + model_type = Const.MODULE if is_mindtorch() else Const.CELL + cells_with_index_in_pynative_mode, cells_with_index_in_graph_mode = get_cells_and_names_with_index(models) + construct_name = '_call_impl' if is_mindtorch() else '_run_construct' + + for index, cells_and_names in cells_with_index_in_pynative_mode.items(): + model = models if index == "-1" else models[int(index)] + for name, cell in cells_and_names: + if cell == model: + continue + + if not has_kwargs_in_forward_hook(): + if not hasattr(cell.__class__, 'msprobe_construct'): + setattr(cell.__class__, 'msprobe_construct', True) + if hasattr(cell.__class__, construct_name): + setattr(cell.__class__, construct_name, + get_cell_construct(getattr(cell.__class__, construct_name))) + setattr(cell, 'msprobe_hook', True) + + cell_index = (index + Const.SEP) if index != "-1" else "" + prefix = f'{model_type}{Const.SEP}{cell_index}{name}{Const.SEP}{cell.__class__.__name__}{Const.SEP}' + + forward_pre_hook = self.build_cell_hook(prefix, build_hook) + cell.register_forward_pre_hook(forward_pre_hook) + + if not is_registered: + logger.info("The cell hook function is successfully mounted to the model.") + is_registered = True + + if is_graph_mode_cell_dump_allowed(config): + cells_and_names_in_graph_mode = [] + for index, cells_and_names in cells_with_index_in_graph_mode.items(): + model = models if index == "-1" else models[int(index)] + for name, cell, parent_cell in cells_and_names: + if cell == model: + continue + cell_index = (index + Const.SEP) if index != "-1" else "" + cells_and_names_in_graph_mode.append((f'{cell_index}{name}', cell, parent_cell)) + + if cells_and_names_in_graph_mode: + Runtime.run_mode = MsConst.PYNATIVE_GRAPH_MODE + GraphModeCellDump(config, cells_and_names_in_graph_mode, strict=False).handle() + + def build_cell_hook(self, cell_name, build_data_hook): + @ThreadSafe.synchronized + def forward_pre_hook(cell, args): + if not Runtime.is_running: + return args - def node_hook(self, name_prefix, start_or_stop, **kwargs): - def begin_hook(cell, input_data): - full_name = self.set_and_get_reserved_name(cell, name_prefix, is_called_by_pre_hook=True) - if CellProcessor.cell_stack: - CellProcessor.module_node[full_name] = CellProcessor.cell_stack[-1] + index = CellProcessor.set_and_get_calls_number(cell_name) + full_forward_name = f'{cell_name}{Const.FORWARD}{Const.SEP}{index}' + full_backward_name = f'{cell_name}{Const.BACKWARD}{Const.SEP}{index}' + + self.set_construct_info_in_pre_hook(full_forward_name) + + if not hasattr(cell, 'msprobe_forward_hook'): + if is_mindtorch(): + cell.register_forward_hook(forward_hook, prepend=True, with_kwargs=True) + else: + forward_hook_dict = getattr(cell, '_forward_hook', OrderedDict()) + if has_kwargs_in_forward_hook(): + forward_hook_with_kwargs_dict = getattr(cell, '_forward_hook_with_kwargs', OrderedDict()) + handle = HookHandle(forward_hook_dict, extra_dict=forward_hook_with_kwargs_dict) + forward_hook_with_kwargs_dict[handle.handle_id] = True + else: + handle = HookHandle(forward_hook_dict) + forward_hook_dict[handle.handle_id] = forward_hook + forward_hook_dict.move_to_end(handle.handle_id, last=False) + + setattr(cell, 'msprobe_forward_hook', True) + + def get_backward_hook(backward_data_hook, full_backward_name): + @ThreadSafe.synchronized + def backward_hook_fn(cell, grad_input, grad_output): + new_output = backward_data_hook(cell, grad_input, grad_output) + self.set_construct_info_in_hook(full_backward_name) + cell.has_pre_hook_called = False + return new_output + + return backward_hook_fn + + enable_hooked = sum( + [isinstance(ele, Tensor) and ele.dtype not in MsConst.NonDifferentiableType for ele in args] + ) + if enable_hooked: + backward_hook = OrderedDict() + hook_set = build_data_hook(BaseScope.Module_Type_Module, full_forward_name) + backward_hook[full_backward_name] = get_backward_hook(hook_set.backward_hook, full_backward_name) + CellProcessor.cell_backward_hook.append(backward_hook) + bw_hook = inner.CellBackwardHook(full_backward_name, cell, + self.cell_backward_hook[-1]) + bw_hook.register_backward_hook() + CellProcessor.cell_bw_hook_kernels[full_forward_name] = bw_hook + + args = bw_hook(args) if is_backward_hook_output_a_view() else bw_hook(*args) + + return args + + @ThreadSafe.synchronized + def forward_hook(cell, args, kwargs_or_output, output_or_kwargs=None): + index = CellProcessor.cell_count.get(cell_name, 0) + full_forward_name = f'{cell_name}{Const.FORWARD}{Const.SEP}{index}' + full_backward_name = f'{cell_name}{Const.BACKWARD}{Const.SEP}{index}' + + self.set_construct_info_in_hook(full_forward_name) + + hook_set = build_data_hook(BaseScope.Module_Type_Module, full_forward_name) + hook_result = hook_set.forward_hook(cell, args, kwargs_or_output, output_or_kwargs) + if hook_result is not None: + outputs = hook_result else: - CellProcessor.module_node[full_name] = None + outputs = output_or_kwargs if has_kwargs_in_forward_hook() else kwargs_or_output + + bw_hook = CellProcessor.cell_bw_hook_kernels.get(full_forward_name) + if bw_hook: + if not isinstance(outputs, (Tensor, tuple)): + logger.warning("For backward hooks to be called," + " cell output should be a Tensor or a tuple of Tensors" + f" but received {type(outputs)}") + if is_backward_hook_output_a_view(): + new_outputs = bw_hook(outputs) + else: + if isinstance(outputs, tuple): + new_outputs = bw_hook(*outputs) + else: + new_outputs = bw_hook(outputs) + if isinstance(outputs, tuple) and len(outputs) == 1: + new_outputs = (new_outputs,) + outputs = new_outputs - CellProcessor.cell_stack.append(full_name) - CellProcessor.api_parent_node = full_name + def get_backward_pre_hook(full_backward_name, backward_data_hook): + @ThreadSafe.synchronized + def backward_pre_hook_fn(cell, grad_output): + cell.has_pre_hook_called = True + self.set_construct_info_in_pre_hook(full_backward_name) + if backward_data_hook: + backward_data_hook(cell, (), grad_output) + self.set_construct_info_in_hook(full_backward_name) + cell.has_pre_hook_called = False - if self.scope: - self.scope.begin_module(full_name) + return backward_pre_hook_fn - def end_hook(cell, input_data, output_data): - if CellProcessor.cell_stack: - CellProcessor.cell_stack.pop() - if CellProcessor.cell_stack: - CellProcessor.api_parent_node = CellProcessor.cell_stack[-1] + backward_pre_hook = OrderedDict() + backward_data_hook = None if bw_hook else hook_set.backward_hook + backward_pre_hook[full_backward_name] = get_backward_pre_hook(full_backward_name, backward_data_hook) + CellProcessor.cell_backward_pre_hook.append(backward_pre_hook) + bw_pre_hook = inner.CellBackwardHook(full_backward_name, cell, + self.cell_backward_pre_hook[-1]) + bw_pre_hook.register_backward_pre_hook() + + if is_backward_hook_output_a_view(): + result = bw_pre_hook(outputs) else: - CellProcessor.api_parent_node = None + if isinstance(outputs, tuple): + result = bw_pre_hook(*outputs) + else: + result = bw_pre_hook(outputs) + if isinstance(outputs, tuple): + if len(outputs) == 1: + result = (result,) + if len(result) != len(outputs): + raise TypeError( + f"The backward pre hook return value size is {len(result)} " + f"not equal to output size {len(outputs)}" + ) + return result - if self.scope: - self.scope.end_module(cell.mindstudio_reserved_name) + return forward_pre_hook - return begin_hook if Const.START == start_or_stop else end_hook + def set_construct_info_in_pre_hook(self, full_name): + tid = threading.get_ident() + if tid not in self.cell_stack: + CellProcessor.cell_stack[tid] = [] - def set_and_get_reserved_name(self, cell, cell_name, is_called_by_pre_hook=False): - if not is_called_by_pre_hook and hasattr(cell, 'has_pre_hook_called') and cell.has_pre_hook_called: - cell.has_pre_hook_called = False + if self.cell_stack[tid]: + CellProcessor.module_node[full_name] = self.cell_stack[tid][-1] else: - if is_called_by_pre_hook: - cell.has_pre_hook_called = True - index = self.set_cell_count(cell_name) - cell.mindstudio_reserved_name = cell_name + Const.SEP + str(index) - return cell.mindstudio_reserved_name + parent_name = CellProcessor.cell_queue.find_last(full_name) + CellProcessor.module_node[full_name] = parent_name + + CellProcessor.cell_queue.add_name(full_name) + CellProcessor.cell_stack[tid].append(full_name) + CellProcessor.api_parent_node[tid] = full_name + if self.scope: + self.scope.begin_module(full_name) + + def set_construct_info_in_hook(self, full_name): + tid = threading.get_ident() + CellProcessor.cell_queue.remove_name(full_name) + CellProcessor.api_parent_node[tid] = None + if self.cell_stack.get(tid): + CellProcessor.cell_stack[tid].pop() + if self.cell_stack.get(tid): + CellProcessor.api_parent_node[tid] = CellProcessor.cell_stack[tid][-1] + if self.scope: + self.scope.end_module(full_name) diff --git a/debug/accuracy_tools/msprobe/mindspore/code_mapping/bind.py b/debug/accuracy_tools/msprobe/mindspore/code_mapping/bind.py index 614abdf20f238bde69742103cf3b9e534e269313..147d954ecb12b83e2271885cac154ed48176dd24 100644 --- a/debug/accuracy_tools/msprobe/mindspore/code_mapping/bind.py +++ b/debug/accuracy_tools/msprobe/mindspore/code_mapping/bind.py @@ -119,9 +119,17 @@ def find_npy_files(npy_path): # 如果是目录,使用Path.rglob查找所有.npy文件 if npy_path_obj.is_dir(): - for file in npy_path_obj.rglob(Const.NUMPY_PATTERN): - check_file_or_directory_path(file) - npy_files.append(file.resolve()) + base_depth = len(npy_path_obj.resolve().parts) + for root, dirs, files in os.walk(npy_path_obj): + current_depth = len(Path(root).resolve().parts) - base_depth + if current_depth >= 10: + dirs[:] = [] + + for filename in files: + if filename.endswith(Const.NUMPY_SUFFIX): + file_path = Path(root) / filename + check_file_or_directory_path(file_path) + npy_files.append(file_path.resolve()) else: logger.info(f"The specified path is neither an .npy file nor a directory: {npy_path}") @@ -254,7 +262,18 @@ def bind_code_info_for_data(input_dir: str, nodes: Dict[str, GraphNode]) -> Dict corresponding_name = None name_without_ext = os.path.splitext(corresponding_name)[0] npy_path = os.path.realpath(npy_file) - node_scope = name_without_ext.split(".")[1] + + parts = name_without_ext.split(".") + if len(parts) < 2: + logger.error( + f'File name "{file_name}" in "{directory}" ' + f'does not conform to expected format (missing scope separator ".")!' + ) + raise Exception( + f'File name "{file_name}" has incorrect format, cannot extract node scope!' + ) + node_scope = parts[1] + trie = Trie() for key, value in match_dict.items(): trie.insert(key, value) diff --git a/debug/accuracy_tools/msprobe/mindspore/code_mapping/graph_parser.py b/debug/accuracy_tools/msprobe/mindspore/code_mapping/graph_parser.py index ee35750fb35c100e2025b0dcbdd9e20ef998b2ee..39fe1ee69f1b193d0f5ce5ebe1d0c454f60da51f 100644 --- a/debug/accuracy_tools/msprobe/mindspore/code_mapping/graph_parser.py +++ b/debug/accuracy_tools/msprobe/mindspore/code_mapping/graph_parser.py @@ -34,19 +34,6 @@ class Parser: if isinstance(subgraph_node.attrs, list): subgraph_node.attrs.extend(attrs) - @staticmethod - def parse_graph_attributes(text: str, graph_node: GraphNode) -> None: - attr_pattern = re.compile(r'# Attrs:\s*(.*)', re.DOTALL) - match = attr_pattern.search(text, graph_node.pos) - if match: - attrs = match.group(1).strip().split('\n') - for attr in attrs: - if not attr: - break - key, value = attr.split(':') - if isinstance(graph_node.attrs, dict): - graph_node.attrs[key.strip()] = value.strip() - @staticmethod def parse_code_info(text: str, start_pos: int, end_pos: int) -> List[str]: code_info = [] @@ -90,7 +77,7 @@ class Parser: @staticmethod def extract_constants(inputs_str: str) -> List[str]: - constant_pattern = re.compile(r'\b(\w+\(.*?\))') + constant_pattern = re.compile(r'\b([A-Za-z_][A-Za-z0-9_]{0,10000})\(([A-Za-z0-9_\s,.\-+/]{0,10000})\)') constants = constant_pattern.findall(inputs_str) return constants @@ -103,7 +90,8 @@ class Parser: self.nodes[func_name] = func_graph_info def parse_nodes(self, text: str, subgraph_info: GraphNode) -> None: - node_pattern = re.compile(r'(%\d+)\((\S+)\)\s*=\s*(\S+)\(') + node_pattern = re.compile( + r'(%\d{1,10000})\(([A-Za-z0-9_\.]{1,10000})\)\s*=\s*([A-Za-z_][A-Za-z0-9_]{0,10000})\(') matches = list(node_pattern.finditer(text)) for i, match in enumerate(matches): series_number = match.group(1) @@ -119,13 +107,15 @@ class Parser: constants = self.__class__.extract_constants(args_str) - scope_pattern = re.compile(r'# .*scope.*:\s*\((.*?)\)', re.IGNORECASE | re.MULTILINE) - + scope_pattern = re.compile( + r'^(?=.{0,300}$)[ \t]*\#[ \t]*[^\r\n]*?scope[^\r\n]*?:[ \t]*\(([^)\r\n]{1,200})\)[ \t]*$', + re.IGNORECASE | re.MULTILINE) scope_match = scope_pattern.search(text, end_pos) scope = scope_match.group(1) if scope_match else "" - id_pattern = re.compile(r'.*cnode_primal_attrs:' - r'\s*\{.*\b(?:forward_unique_id|unique_id):\s*\"(\d+)\".*', re.IGNORECASE) + id_pattern = re.compile( + r'cnode_primal_attrs:'r'\s*\{[\w+]{1, 10000}\b(?:forward_unique_id|unique_id):\s*\"(\d+)\"', + re.IGNORECASE) unique_id_match = id_pattern.search(text, end_pos, scope_match.start()) unique_id = unique_id_match.group(1) if unique_id_match else None @@ -186,7 +176,7 @@ class Parser: node_info.var_inputs.append(callee_name) def parse_subgraphs(self, text: str) -> None: - subgraph_pattern = re.compile(r'subgraph\s+@(\S+)(\([^\)]*\))?\s+.*\{') + subgraph_pattern = re.compile(r'/subgraph\s+@([\w+]{1,1000)(\([^\)]{1,100}\))?\s+\S[^\{]\{/+') matches = list(subgraph_pattern.finditer(text)) end_pos = 0 for match in matches: @@ -203,11 +193,6 @@ class Parser: subgraph_info.end = end_pos logging.info('Parsed subgraph: %s', subgraph_name) - def count_nodes(self) -> Tuple[int, int]: - total_nodes = len(self.nodes) - total_cnodes = sum(1 for node in self.nodes.values() if node.name.startswith('CNode')) - return total_nodes, total_cnodes - def create_backward_map(self): for node in self.nodes.values(): if node.scope and node.scope.startswith("Gradients"): diff --git a/debug/accuracy_tools/msprobe/mindspore/common/const.py b/debug/accuracy_tools/msprobe/mindspore/common/const.py index 9e8c79e51284b8e9696dde150481609f7da8b488..700c669e20dc18b3824126a09e5ceb20f67693a3 100644 --- a/debug/accuracy_tools/msprobe/mindspore/common/const.py +++ b/debug/accuracy_tools/msprobe/mindspore/common/const.py @@ -15,6 +15,7 @@ import numpy as np import mindspore as ms +from mindspore import dtype as mstype from msprobe.core.common.const import Const as CoreConst @@ -23,14 +24,20 @@ class Const: CELL = "cell" API = "api" KERNEL = "kernel" + CELL_AND_API = 'cell_and_api' TOOL_LEVEL_DICT = { CoreConst.LEVEL_L0: CELL, CoreConst.LEVEL_L1: API, - CoreConst.LEVEL_L2: KERNEL + CoreConst.LEVEL_L2: KERNEL, + CoreConst.LEVEL_MIX: CELL_AND_API } - PYNATIVE_MODE = "pynative" + + PYNATIVE_MODE = CoreConst.PYNATIVE_MODE + GRAPH_MODE = "graph" GRAPH_GE_MODE = "graph_ge" GRAPH_KBYK_MODE = "graph_kbyk" + PYNATIVE_GRAPH_MODE = CoreConst.PYNATIVE_GRAPH_MODE + JIT_LEVEL = "jit_level" JIT_LEVEL_O0 = "O0" JIT_LEVEL_O1 = "O1" @@ -61,6 +68,7 @@ class Const: DROPOUT_API_NAME_PREFIX = "dropout" GRAPH_DATA_MODE_LIST = [CoreConst.ALL, CoreConst.INPUT, CoreConst.OUTPUT] + GRAPH_CELL_DUMP_DATA_MODE_LIST = [CoreConst.ALL, CoreConst.FORWARD, CoreConst.BACKWARD] HOOK_MS_PREFIX_DICT = { OPS_DATA_PREFIX: OPS_PREFIX, @@ -69,6 +77,69 @@ class Const: MINT_NN_FUNC_DATA_PREFIX: MINT_NN_FUNC_PREFIX } + NonDifferentiableType = ( + mstype.bool_, mstype.int8, mstype.byte, mstype.uint8, mstype.ubyte, + mstype.int16, mstype.short, mstype.uint16, mstype.ushort, + mstype.int32, mstype.intc, mstype.uint32, mstype.uintc, + mstype.int64, mstype.intp, mstype.uint64, mstype.uintp + ) + + +class MsCompareConst: + # api_info field + MINT = "Mint" + MINT_FUNCTIONAL = "MintFunctional" + TENSOR_API = "Tensor" + FUNCTIONAL_API = "Functional" + FUSION_API = "FUSION" + + API_NAME_STR_LENGTH = 4 + MAX_RECURSION_DEPTH = 20 + + # Mindtorch api_info field + MINDTORCH_TENSOR = "Tensor" + MINDTORCH = "Torch" + MINDTORCH_FUNC = "Functional" + MINDTORCH_NPU = "NPU" + MINDTORCH_DIST = "Distributed" + + MT_VALID_API_TYPES = [ + MINDTORCH, MINDTORCH_FUNC, MINDTORCH_TENSOR + ] + SUPPORTED_FUSION_LIST = ["flash_attention_score"] + + TASK_FIELD = "task" + STATISTICS_TASK = "statistics" + FRAMEWORK = "framework" + TENSOR_TASK = "tensor" + DUMP_DATA_DIR_FIELD = "dump_data_dir" + DATA_FIELD = "data" + + # supported api yaml + SUPPORTED_API_LIST_FILE = "checker_support_api.yaml" + SUPPORTED_TENSOR_LIST_KEY = "tensor" + + # detail_csv + DETAIL_CSV_API_NAME = "API Name" + DETAIL_CSV_BENCH_DTYPE = "Bench Dtype" + DETAIL_CSV_TESTED_DTYPE = "Tested Dtype" + DETAIL_CSV_SHAPE = "Shape" + DETAIL_CSV_PASS_STATUS = "Status" + DETAIL_CSV_MESSAGE = "Message" + DETAIL_CSV_FILE_NAME = "accuracy_checking_details" + + # result_csv + RESULT_CSV_FORWARD_TEST_SUCCESS = "Forward Test Success" + RESULT_CSV_BACKWARD_TEST_SUCCESS = "Backward Test Success" + RESULT_CSV_FILE_NAME = "accuracy_checking_result" + + EPSILON = 1e-8 + + class ProcessStatus: + SUCCESS = "success" + API_NOT_FOUND = "api_not_found" + EXCEPTION_SKIP = "exception_skip" + class FreeBenchmarkConst: ADD_NOISE = "add_noise" diff --git a/debug/accuracy_tools/msprobe/mindspore/common/utils.py b/debug/accuracy_tools/msprobe/mindspore/common/utils.py index ded3faaa22b565ef35c17a7596782976ddf9125d..3a58c516472066acf73fb58fc7ec4259da288a01 100644 --- a/debug/accuracy_tools/msprobe/mindspore/common/utils.py +++ b/debug/accuracy_tools/msprobe/mindspore/common/utils.py @@ -13,19 +13,65 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import os import random +import sys +import types import mindspore as ms - from mindspore import ops +from mindspore.common.jit_config import JitConfig from mindspore.mint import nn +from msprobe.core.common.const import Const +from msprobe.core.common.decorator import recursion_depth_decorator from msprobe.core.common.exceptions import DistributedNotInitializedError from msprobe.core.common.file_utils import path_len_exceeds_limit, check_path_exists, save_npy from msprobe.core.common.log import logger -from msprobe.core.common.const import Const -from msprobe.core.common.utils import CompareException, check_seed_all +from msprobe.core.common.utils import CompareException, check_seed_all, is_save_variable_valid +from msprobe.mindspore.common.const import Const as MsConst + +try: + from mindspore._c_expression import _set_init_iter +except ImportError: + enable_dynamic_kbyk_dump = False +else: + enable_dynamic_kbyk_dump = True + +mindtorch_check_result = None +register_backward_hook_functions = {} +kwargs_exist_in_forward_hook = None +is_output_of_backward_hook_a_view = None + + +class MsprobeStep(ms.train.Callback): + def __init__(self, debugger): + super(MsprobeStep, self).__init__() + self.debugger = debugger + + def on_train_begin(self, run_context): + self.debugger.start() + if enable_dynamic_kbyk_dump: + _set_init_iter(0) + + def on_train_step_begin(self, run_context): + self.debugger.start() + + def on_train_step_end(self, run_context): + self.debugger.stop() + self.debugger.step() + + +class MsprobeInitStep(ms.train.Callback): + def on_train_begin(self, run_context): + try: + from ms._c_expression import _set_init_iter + except ImportError: + logger.warning('MsprobeInitStep does not work on this version of MindSpore.') + return + cb_params = run_context.original_args() + _set_init_iter(cb_params.cur_step_num) def get_rank_if_initialized(): @@ -51,6 +97,9 @@ def save_tensor_as_npy(tensor, file_path): def convert_to_int(value): + if isinstance(value, bool): + logger.error('The value in rank_id or step should be int, please check!') + raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR) try: return int(value) except Exception: @@ -58,8 +107,8 @@ def convert_to_int(value): def clean_input_kwargs(cell): - if hasattr(cell, 'input_kwargs'): - del cell.input_kwargs + if hasattr(cell, 'msprobe_input_kwargs'): + del cell.msprobe_input_kwargs def list_lowest_level_directories(root_dir): @@ -82,7 +131,7 @@ def list_lowest_level_directories(root_dir): return lowest_level_dirs -def seed_all(seed=1234, mode=False, rm_dropout=True): +def seed_all(seed=1234, mode=False, rm_dropout=False): check_seed_all(seed, mode, rm_dropout) os.environ['PYTHONHASHSEED'] = str(seed) ms.set_seed(seed) @@ -93,20 +142,6 @@ def seed_all(seed=1234, mode=False, rm_dropout=True): remove_dropout() -class MsprobeStep(ms.train.Callback): - - def __init__(self, debugger): - super(MsprobeStep, self).__init__() - self.debugger = debugger - - def on_train_step_begin(self, run_context): - self.debugger.start() - - def on_train_step_end(self, run_context): - self.debugger.stop() - self.debugger.step() - - class Dropout(ops.Dropout): def __init__(self, keep_prob=0.5, seed0=0, seed1=1): super().__init__(1., seed0, seed1) @@ -142,13 +177,12 @@ def remove_dropout(): nn.functional.dropout = dropout_ext -mindtorch_check_result = None - - def is_mindtorch(): global mindtorch_check_result if mindtorch_check_result is None: mindtorch_check_result = False + if 'torch' not in sys.modules: + return mindtorch_check_result try: import torch except ImportError: @@ -159,17 +193,17 @@ def is_mindtorch(): return mindtorch_check_result -register_backward_hook_functions = {} - - def set_register_backward_hook_functions(): global register_backward_hook_functions + if register_backward_hook_functions: + return + if is_mindtorch(): import torch from msprobe.mindspore.mindtorch import (_call_impl, register_full_backward_pre_hook, register_full_backward_hook) - if not hasattr(torch, "register_full_backward_hook"): + if not hasattr(torch.nn.Module, "register_full_backward_hook"): setattr(torch.nn.Module, "_call_impl", _call_impl) setattr(torch.nn.Module, "register_full_backward_pre_hook", register_full_backward_pre_hook) setattr(torch.nn.Module, "register_full_backward_hook", register_full_backward_hook) @@ -182,9 +216,11 @@ def set_register_backward_hook_functions(): def check_save_param(variable, name, save_backward): # try catch this api to skip invalid call - if not isinstance(variable, (list, dict, ms.Tensor, int, float, str)): + valid_data_types = (ms.Tensor, int, float, str) + if not is_save_variable_valid(variable, valid_data_types): + valid_data_types_with_nested_types = valid_data_types + (dict, tuple, list) logger.warning("PrecisionDebugger.save variable type not valid, " - "should be one of list, dict, ms.Tensor, int, float or string. " + f"should be one of {valid_data_types_with_nested_types}" "Skip current save process.") raise ValueError if not isinstance(name, str): @@ -196,4 +232,144 @@ def check_save_param(variable, name, save_backward): logger.warning("PrecisionDebugger.save_backward name not valid, " "should be bool. " "Skip current save process.") - raise ValueError \ No newline at end of file + raise ValueError + + +def is_graph_mode_cell_dump_allowed(config): + if config.task not in [Const.TENSOR, Const.STATISTICS] or is_mindtorch() or not hasattr(ops, 'DumpGradient'): + return False + valid_mix_level = [MsConst.CELL_AND_API, Const.LEVEL_MIX] + if config.level in valid_mix_level and config.execution_mode == MsConst.PYNATIVE_MODE: + return True + return config.level == MsConst.CELL or config.level == Const.LEVEL_L0 + + +@recursion_depth_decorator('msprobe.mindspore.common.utils.is_decorated_by_jit') +def is_decorated_by_jit(func): + closure = getattr(func, '__closure__', []) + if closure: + for obj in closure: + if isinstance(obj.cell_contents, JitConfig): + return True + elif isinstance(obj.cell_contents, types.FunctionType) and hasattr(obj.cell_contents, '__closure__'): + if is_decorated_by_jit(obj.cell_contents): + return True + return False + + +@recursion_depth_decorator('msprobe.mindspore.common.utils.get_cells_and_names') +def get_cells_and_names(model, cells_set=None, name_prefix='', parent_cell=None): + cells_set = cells_set if cells_set else set() + if model in cells_set: + return + + cells_set.add(model) + jit_decorated = is_decorated_by_jit(model.construct) + yield name_prefix, model, jit_decorated, parent_cell + if jit_decorated: + return + + children_cells = getattr(model, '_cells') + for name, cell in children_cells.items(): + if cell: + cells_name_prefix = f'{name_prefix}{Const.SEP}{name}' if name_prefix else name + jit_decorated = is_decorated_by_jit(model.construct) + if jit_decorated: + yield cells_name_prefix, cell, jit_decorated, model + else: + for ele in get_cells_and_names(cell, cells_set, cells_name_prefix, model): + yield ele + + +def get_cells_and_names_with_index(models): + cells_with_index_in_pynative_mode = {} + cells_with_index_in_graph_mode = {} + + def distinguish_cells(cells): + cells_in_pynative_mode = [] + cells_in_graph_mode = [] + for name, cell, jit_decorated, parent_cell in cells: + if jit_decorated: + cells_in_graph_mode.append((name, cell, parent_cell)) + else: + cells_in_pynative_mode.append((name, cell)) + return cells_in_pynative_mode, cells_in_graph_mode + + if is_mindtorch(): + if isinstance(models, (list, tuple)): + for index, model in enumerate(models): + cells_with_index_in_pynative_mode[str(index)] = model.named_modules() + else: + cells_with_index_in_pynative_mode["-1"] = models.named_modules() + else: + if isinstance(models, (list, tuple)): + for index, model in enumerate(models): + cells = get_cells_and_names(model) + cells_in_pynative_mode, cells_in_graph_mode = distinguish_cells(cells) + cells_with_index_in_pynative_mode[str(index)] = cells_in_pynative_mode + cells_with_index_in_graph_mode[str(index)] = cells_in_graph_mode + else: + cells = get_cells_and_names(models) + cells_in_pynative_mode, cells_in_graph_mode = distinguish_cells(cells) + cells_with_index_in_pynative_mode["-1"] = cells_in_pynative_mode + cells_with_index_in_graph_mode["-1"] = cells_in_graph_mode + + return cells_with_index_in_pynative_mode, cells_with_index_in_graph_mode + + +def has_kwargs_in_forward_hook(): + global kwargs_exist_in_forward_hook + + if kwargs_exist_in_forward_hook is None: + if is_mindtorch(): + kwargs_exist_in_forward_hook = True + return kwargs_exist_in_forward_hook + + try: + func_params = inspect.signature(nn.Cell.register_forward_hook).parameters + kwargs_exist_in_forward_hook = 'with_kwargs' in func_params + except Exception: + kwargs_exist_in_forward_hook = False + return kwargs_exist_in_forward_hook + + return kwargs_exist_in_forward_hook + + +def is_backward_hook_output_a_view(): + global is_output_of_backward_hook_a_view + + if is_output_of_backward_hook_a_view is None: + is_output_of_backward_hook_a_view = False + if getattr(ms, '__version__', '2.4.0') < '2.7.0': + return is_output_of_backward_hook_a_view + try: + from mindspore.ops.operations import _inner_ops as inner + call_func = getattr(inner.CellBackwardHook, '__call__') + func_params = inspect.signature(call_func).parameters + except Exception: + return is_output_of_backward_hook_a_view + if 'args' in func_params and func_params['args'].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: + is_output_of_backward_hook_a_view = True + + return is_output_of_backward_hook_a_view + + +def wrap_backward_hook_call_func(call_func): + if not is_backward_hook_output_a_view(): + return call_func + + from mindspore.common.api import _pynative_executor as executor + from mindspore._c_expression import CreationType + + def new_call(self, args): + outputs = call_func(self, args) + if isinstance(outputs, ms.Tensor): + executor.set_creation_type(outputs, CreationType.DEFAULT) + elif isinstance(outputs, tuple): + for item in outputs: + if isinstance(item, ms.Tensor): + executor.set_creation_type(item, CreationType.DEFAULT) + return outputs + new_call.__name__ = '__call__' + + return new_call diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/common_dir_compare.py b/debug/accuracy_tools/msprobe/mindspore/compare/common_dir_compare.py new file mode 100644 index 0000000000000000000000000000000000000000..a973f8e0833550c6e05a69b043bf747ae9c42fb5 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/compare/common_dir_compare.py @@ -0,0 +1,414 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.import functools + +import os +import multiprocessing +from dataclasses import dataclass +from typing import Dict, List, Tuple, Optional, Any +from concurrent.futures import ProcessPoolExecutor +from functools import partial +from pathlib import Path + +import pandas as pd +import numpy as np +from tqdm import tqdm + +from msprobe.core.common.log import logger +from msprobe.core.common.utils import CompareException +from msprobe.core.common.exceptions import FileCheckException +from msprobe.core.common.file_utils import check_file_or_directory_path, write_df_to_csv, create_directory, \ + check_path_before_create, load_npy +from msprobe.core.common.const import CompareConst +from msprobe.core.compare.npy_compare import compare_ops_apply +from msprobe.core.compare.multiprocessing_compute import check_accuracy +from msprobe.mindspore.compare.utils import check_name_map_dict + + +def common_dir_compare(input_params: Dict, output_dir: str) -> Optional[pd.DataFrame]: + """ + 高级目录比对函数,完全镜像输入目录结构 + + Args: + input_params: 包含npu_path和bench_path的字典 + output_dir: 输出根目录 + + Returns: + 当输入目录是平铺npy文件时返回DataFrame,否则返回None + """ + npu_root = Path(input_params.get('npu_path')) + bench_root = Path(input_params.get('bench_path')) + name_map_dict = input_params.get('map_dict', {}) + check_name_map_dict(name_map_dict) + file_tree = build_mirror_file_tree(npu_root, bench_root) + + # 处理文件比对 + with ProcessPoolExecutor() as executor: + results = list(tqdm( + executor.map( + partial(process_directory_pair, name_map_dict=name_map_dict, output_dir=output_dir), + file_tree.items() + ), + total=len(file_tree), + desc="Processing directories" + )) + return + + +def process_directory_pair(item: Tuple[Path, Tuple[Path, Path]], name_map_dict: Dict, output_dir: str): + """ + 处理一个目录对 + + Args: + item: (相对路径, (npu目录, bench目录))元组 + output_dir: 输出根目录 + + Returns: + 比对结果的DataFrame(仅平铺结构时返回) + """ + rel_path, (npu_dir, bench_dir) = item + + # 创建镜像输出目录 + output_path = Path(output_dir) / rel_path + create_directory(output_path) + + # 生成文件映射 + npu_files = find_npy_files(npu_dir) + bench_files = find_npy_files(bench_dir) + map_dict = generate_map_dict(npu_files, bench_files, name_map_dict) + + if not map_dict: + logger.warning(f"No file pairs found in {rel_path}") + return None + + # 执行比对 + result_df = do_multi_process(process_chunk, map_dict) + check_path_before_create(output_path) + # 保存结果 + result_path = os.path.join(output_path, 'result.csv') + write_df_to_csv(result_df, result_path) + logger.info(f"Results saved to {result_path}") + return None + + +def build_mirror_file_tree(npu_root: Path, bench_root: Path) -> Dict[Path, Tuple[Path, Path]]: + """ + 构建镜像文件树,键为相对路径,值为(npu_path, bench_path)元组 + + Args: + npu_root: NPU数据根目录 + bench_root: 基准数据根目录 + + Returns: + 文件树字典 + """ + file_tree = {} + + # 遍历NPU目录构建树结构 + # 使用os.walk遍历目录,限制深度为10层 + for root, dirs, files in os.walk(npu_root): + # 计算当前目录深度 + depth = len(Path(root).relative_to(npu_root).parts) + if depth > 10: + dirs.clear() # 清空dirs列表以阻止继续递归 + continue + + # 检查当前目录下是否有npy文件 + if any(f.endswith('.npy') for f in files): + # 获取相对路径 + dir_path = Path(root).relative_to(npu_root) + npu_dir_pair = os.path.join(npu_root, dir_path) + bench_dir_pair = os.path.join(bench_root, dir_path) + + try: + check_file_or_directory_path(bench_dir_pair, isdir=True) + except FileCheckException: + continue + + # 添加到文件树 + if dir_path not in file_tree: + file_tree[dir_path] = (npu_dir_pair, bench_dir_pair) + + return file_tree + + +def find_npy_files(directory): + npy_files_dict = {} + # 限制递归深度为1层,即只遍历当前目录和其直接子目录 + for root, dirs, files in os.walk(directory, topdown=True): + # 计算当前目录深度 + depth = root[len(directory):].count(os.sep) + # 如果深度超过10层则跳过 + if depth > 10: + dirs.clear() + for file in files: + if file.endswith(".npy"): + # 正确移除文件扩展名 + base_name = os.path.splitext(file) + if not base_name or len(base_name) < 1: + logger.warning("Invalid file encountered.") + continue + file_name = base_name[0] + + logger.info(f"Generating file info for file: {file}") + + # 使用一致的分割逻辑 + file_ele = file_name.split('_') + + if len(file_ele) < 2: + continue + + key = '_'.join(file_ele[:-2]) + if key: + # 文件的完整路径 + value = os.path.join(root, file) + # 添加到字典中 + if key not in npy_files_dict: + npy_files_dict[key] = [] + npy_files_dict[key].append(value) + return npy_files_dict + + +def generate_map_dict(npu_file_dict, bench_file_dict, name_map_dict=None): + result_dict = {} + for k, npu_file_list in npu_file_dict.items(): + bench_file_list = bench_file_dict.get(k) + if not bench_file_list and k in name_map_dict: + bench_file_list = bench_file_dict.get(name_map_dict.get(k)) + bench_length = len(bench_file_list) + if not (bench_file_list and bench_length): + continue + for i, npu_file in enumerate(npu_file_list): + if i >= bench_length: + break + bench_file = bench_file_list[i] + result_dict[f"{k}_{i}"] = (npu_file, bench_file) + return result_dict + + +def do_multi_process(func, map_dict): + lock = multiprocessing.Manager().RLock() + result_len = len(map_dict) + process_num = max(int((multiprocessing.cpu_count() + 1) // 4), 1) + # every block size + df_chunk_size = result_len // process_num + + # generate the same len of map_dict df + result_df = initialize_result_df(result_len) + if df_chunk_size > 0: + df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)] + else: + df_chunks = [result_df] + process_num = 1 + logger.info(f"Using {process_num} processes with chunk size {df_chunk_size}") + + # 分割字典 + map_chunks = split_dict(map_dict, df_chunk_size) + + # 创建结果列表和进程池 + results = [] + pool = multiprocessing.Pool(process_num) + + progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100) + + def update_progress(size, progress_lock, extra_param=None): + with progress_lock: + progress_bar.update(size) + + def err_call(args): + logger.error('multiprocess compare failed! Reason: {}'.format(args)) + try: + pool.close() + except OSError as e: + logger.error(f'pool terminate failed: {str(e)}') + results = [] + try: + # 提交任务到进程池 + for process_idx, (df_chunk, map_chunk) in enumerate(zip(df_chunks, map_chunks)): + start_idx = df_chunk_size * process_idx + result = pool.apply_async( + func, + args=(df_chunk, start_idx, map_chunk, lock), + error_callback=err_call, + callback=partial(update_progress, len(map_chunk), lock) + ) + results.append(result) + + final_results = [r.get() for r in results] + # 等待所有任务完成 + pool.close() + pool.join() + return pd.concat(final_results, ignore_index=True) + except Exception as e: + logger.error(f"\nMain process error: {str(e)}") + pool.terminate() + return pd.DataFrame({}) + finally: + pool.close() + + +def initialize_result_df(total_size): + """预分配结果DataFrame""" + columns = [ + CompareConst.NAME, + CompareConst.NPU_DTYPE, + CompareConst.BENCH_DTYPE, + CompareConst.NPU_SHAPE, + CompareConst.BENCH_SHAPE, + CompareConst.COSINE, + CompareConst.EUC_DIST, + CompareConst.MAX_ABS_ERR, + CompareConst.MAX_RELATIVE_ERR, + CompareConst.ONE_THOUSANDTH_ERR_RATIO, + CompareConst.FIVE_THOUSANDTHS_ERR_RATIO, + CompareConst.NPU_MAX, + CompareConst.NPU_MIN, + CompareConst.NPU_MEAN, + CompareConst.NPU_NORM, + CompareConst.BENCH_MAX, + CompareConst.BENCH_MIN, + CompareConst.BENCH_MEAN, + CompareConst.BENCH_NORM, + CompareConst.ACCURACY, + CompareConst.ERROR_MESSAGE, + CompareConst.DATA_NAME + ] + return pd.DataFrame(index=range(total_size), columns=columns) + + +def split_dict(input_dict, chunk_size): + """将字典按指定chunk_size分割""" + items = list(input_dict.items()) + if chunk_size > 0: + return [dict(items[i:i + chunk_size]) for i in range(0, len(items), chunk_size)] + return [input_dict] + + +def get_tensor_stats(tensor: np.ndarray) -> Tuple[float, float, float, float]: + """获取张量的统计信息""" + t_max = np.max(tensor) + t_min = np.min(tensor) + t_mean = np.mean(tensor) + t_l2norm = np.linalg.norm(tensor) + return t_max, t_min, t_mean, t_l2norm + + +def process_chunk(df, start_idx, map_chunk, lock): + """处理一个数据块""" + err_mess = [] + results = [] + for name, file_pair in map_chunk.items(): + err_msg = "" + npu_file, bench_file = file_pair + n_value = load_npy(npu_file) + # if need to support cross frame b_value need to add load_pt + b_value = load_npy(bench_file) + error_flag = False + + err_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg) + cos_sim, euc_dist, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio = err_list + a_max, a_min, a_mean, a_l2norm = get_tensor_stats(n_value) + b_max, b_min, b_mean, b_l2norm = get_tensor_stats(b_value) + err_mess.append(err_msg) + # 使用示例 + result = ComparisonResult( + name=name, # CompareConst.NAME + npu_dtype=n_value.dtype, # CompareConst.NPU_DTYPE + bench_dtype=b_value.dtype, # CompareConst.BENCH_DTYPE + npu_shape=n_value.shape, # CompareConst.NPU_SHAPE + bench_shape=b_value.shape, # CompareConst.BENCH_SHAPE + cosine=cos_sim, # CompareConst.COSINE + euc_dist=euc_dist, # CompareConst.EUC_DIST + max_abs_err=max_abs_err, # CompareConst.MAX_ABS_ERR + max_relative_err=max_relative_err, # CompareConst.MAX_RELATIVE_ERR + one_thousandth_err_ratio=one_thousand_err_ratio, # CompareConst.ONE_THOUSANDTH_ERR_RATIO + five_thousandth_err_ratio=five_thousand_err_ratio, # CompareConst.FIVE_THOUSANDTHS_ERR_RATIO + npu_max=a_max, # CompareConst.NPU_MAX + npu_min=a_min, # CompareConst.NPU_MIN + npu_mean=a_mean, # CompareConst.NPU_MEAN + npu_norm=a_l2norm, # CompareConst.NPU_NORM + bench_max=b_max, # CompareConst.BENCH_MAX + bench_min=b_min, # CompareConst.BENCH_MIN + bench_mean=b_mean, # CompareConst.BENCH_MEAN + bench_norm=b_l2norm, # CompareConst.BENCH_NORM + accuracy=check_accuracy(cos_sim, max_abs_err), # CompareConst.ACCURACY + error_message=err_msg, # CompareConst.ERROR_MESSAGE + data_name=[npu_file, bench_file] # CompareConst.DATA_NAME + ) + results.append(result) + return _save_part_df(df, start_idx, results, lock) + + +@dataclass +class ComparisonResult: + name: str # CompareConst.NAME + npu_dtype: Any # CompareConst.NPU_DTYPE + bench_dtype: Any # CompareConst.BENCH_DTYPE + npu_shape: Tuple[int, ...] # CompareConst.NPU_SHAPE + bench_shape: Tuple[int, ...] # CompareConst.BENCH_SHAPE + cosine: float # Cons t.COSINE + euc_dist: float # CompareConst.EUC_DIST + max_abs_err: float # CompareConst.MAX_ABS_ERR + max_relative_err: float # CompareConst.MAX_RELATIVE_ERR + one_thousandth_err_ratio: float # CompareConst.ONE_THOUSANDTH_ERR_RATIO + five_thousandth_err_ratio: float # CompareConst.FIVE_THOUSANDTHS_ERR_RATIO + npu_max: float # CompareConst.NPU_MAX + npu_min: float # CompareConst.NPU_MIN + npu_mean: float # CompareConst.NPU_MEAN + npu_norm: float # CompareConst.NPU_NORM + bench_max: float # CompareConst.BENCH_MAX + bench_min: float # CompareConst.BENCH_MIN + bench_mean: float # CompareConst.BENCH_MEAN + bench_norm: float # CompareConst.BENCH_NORM + accuracy: bool # CompareConst.ACCURACY + error_message: str # CompareConst.ERROR_MESSAGE + data_name: List[str] # CompareConst.DATA_NAME + + +def _save_part_df(df, start_idx, results, lock): + lock.acquire() + try: + for i, result in enumerate(results): + process_index = i + start_idx + df.loc[process_index, CompareConst.NAME] = result.name + df.loc[process_index, CompareConst.NPU_DTYPE] = result.npu_dtype + df.loc[process_index, CompareConst.BENCH_DTYPE] = result.bench_dtype + df.loc[process_index, CompareConst.NPU_SHAPE] = str(result.npu_shape) # 通常将tuple转为字符串存储 + df.loc[process_index, CompareConst.BENCH_SHAPE] = str(result.bench_shape) + df.loc[process_index, CompareConst.COSINE] = result.cosine + df.loc[process_index, CompareConst.EUC_DIST] = result.euc_dist + df.loc[process_index, CompareConst.MAX_ABS_ERR] = result.max_abs_err + df.loc[process_index, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err + df.loc[process_index, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = result.one_thousandth_err_ratio + df.loc[process_index, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = result.five_thousandth_err_ratio + df.loc[process_index, CompareConst.NPU_MAX] = result.npu_max + df.loc[process_index, CompareConst.NPU_MIN] = result.npu_min + df.loc[process_index, CompareConst.NPU_MEAN] = result.npu_mean + df.loc[process_index, CompareConst.NPU_NORM] = result.npu_norm + df.loc[process_index, CompareConst.BENCH_MAX] = result.bench_max + df.loc[process_index, CompareConst.BENCH_MIN] = result.bench_min + df.loc[process_index, CompareConst.BENCH_MEAN] = result.bench_mean + df.loc[process_index, CompareConst.BENCH_NORM] = result.bench_norm + df.loc[process_index, CompareConst.ACCURACY] = result.accuracy + df.loc[process_index, CompareConst.ERROR_MESSAGE] = result.error_message + df.loc[process_index, CompareConst.DATA_NAME] = str(result.data_name) # 列表转为字符串存储 + return df + except ValueError as e: + logger.error('result dataframe is not found.') + raise CompareException(CompareException.INVALID_DATA_ERROR) from e + except IndexError as e: + logger.error('result dataframe elements can not be access.') + raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e + finally: + lock.release() diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/distributed_compare.py b/debug/accuracy_tools/msprobe/mindspore/compare/distributed_compare.py index 46f825330dbb8b7ff5ce9d42cef5c6b74e3846f2..5064bedcdb8d65aa4406b77e5e8ae46696faf4d7 100644 --- a/debug/accuracy_tools/msprobe/mindspore/compare/distributed_compare.py +++ b/debug/accuracy_tools/msprobe/mindspore/compare/distributed_compare.py @@ -13,41 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os from msprobe.core.common.utils import CompareException from msprobe.core.common.file_utils import create_directory from msprobe.core.common.exceptions import FileCheckException from msprobe.mindspore.common.log import logger from msprobe.mindspore.compare.ms_compare import ms_compare -from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json +from msprobe.core.compare.utils import compare_distributed_inner from msprobe.mindspore.compare.ms_graph_compare import GraphMSComparator def ms_compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs): - if kwargs.get('suffix'): - logger.error("Argument 'suffix' is not supported for compare_distributed.") - raise CompareException(CompareException.INVALID_PARAM_ERROR) - is_print_compare_log = kwargs.get('is_print_compare_log', True) - # get the ranks and match by order - npu_ranks = sorted(check_and_return_dir_contents(npu_dump_dir, 'rank')) - bench_ranks = sorted(check_and_return_dir_contents(bench_dump_dir, 'rank')) - if len(npu_ranks) != len(bench_ranks): - logger.error('The number of ranks in the two runs are different. ' - 'Unable to match the ranks. Please use another folder to compare ' - 'or use compare() api and manually match the ranks.') - raise CompareException(CompareException.INVALID_PATH_ERROR) - for nr, br in zip(npu_ranks, bench_ranks): - npu_data_dir = os.path.join(npu_dump_dir, nr) - bench_data_dir = os.path.join(bench_dump_dir, br) - npu_path = extract_json(npu_data_dir, stack_json=False) - bench_path = extract_json(bench_data_dir, stack_json=False) - - dump_result_param = { - 'npu_json_path': npu_path, - 'bench_json_path': bench_path, - 'is_print_compare_log': is_print_compare_log - } - ms_compare(input_param=dump_result_param, output_path=output_path, suffix=f'_{nr}-{br}', **kwargs) + compare_distributed_inner(npu_dump_dir, bench_dump_dir, output_path, ms_compare, **kwargs) def ms_graph_compare(inputs, outputs): diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py b/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py index 8509a7f38add0c2e8d3f3638f4c247895e07bd6d..ae3dfa63d78b2b7e4553a4f68df90aa84dc362ea 100644 --- a/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py +++ b/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py @@ -13,410 +13,38 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import re -from collections import defaultdict - -import numpy as np -import pandas as pd - -from msprobe.core.common.const import CompareConst, Const -from msprobe.core.common.exceptions import FileCheckException -from msprobe.core.common.file_utils import FileOpen, create_directory, load_json, load_npy, load_yaml -from msprobe.core.common.log import logger -from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, \ - check_op_str_pattern_valid, get_dump_mode, set_dump_path -from msprobe.core.compare.acc_compare import Comparator, ModeConfig -from msprobe.core.compare.check import dtype_mapping +from msprobe.core.compare.acc_compare import Comparator, ModeConfig, MappingConfig, setup_comparison from msprobe.core.compare.layer_mapping import generate_data_mapping_by_layer_mapping -from msprobe.core.compare.utils import set_stack_json_path, reorder_op_x_list - - -class MappingConfig: - def __init__(self, cell_mapping=None, api_mapping=None, data_mapping=None): - self.cell_mapping = cell_mapping - self.api_mapping = api_mapping - self.data_mapping = data_mapping - - -class MSComparator(Comparator): - """ - 用于mindspore动态图同框架/跨框架精度比对,支持md5/summary/all模式。 - cell_mapping: mindspore在cell级别(L0)dump数据和pytorch的module之间的映射关系; - api_mapping: mindspore在api级别(L1)dump数据和pytorch的api之间的映射关系; - data_mapping: mindspore的cell或api的入参/出参和pytorch之间的映射关系; - is_cross_framework: 是否跨框架。 - """ - def __init__(self, mode_config, mapping_config=None, is_cross_framework=False): - super().__init__(mode_config) - self.frame_name = MSComparator.__name__ - - self.stack_mode = mode_config.stack_mode - self.auto_analyze = mode_config.auto_analyze - self.fuzzy_match = mode_config.fuzzy_match - self.dump_mode = mode_config.dump_mode - - if mapping_config: - self.cell_mapping = mapping_config.cell_mapping - self.api_mapping = mapping_config.api_mapping - self.data_mapping = mapping_config.data_mapping - - if self.data_mapping: - self.cross_frame = is_cross_framework - else: - self.cross_frame = self.cell_mapping is not None or self.api_mapping is not None - self.cell_mapping_dict = self.load_mapping_file(self.cell_mapping) - self.api_mapping_dict = self.load_mapping_file(self.api_mapping) - if self.api_mapping is not None: - self.ms_to_pt_mapping = self.load_internal_api() - - if isinstance(self.data_mapping, str) or self.data_mapping is None: - self.data_mapping_dict = self.load_mapping_file(self.data_mapping) - elif isinstance(self.data_mapping, dict): - self.data_mapping_dict = self.data_mapping - else: - raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got " - f"{type(self.data_mapping)}") - - def calc_accuracy(self, result_df, header): - condition_no_bench = result_df[CompareConst.BENCH_NAME] == CompareConst.N_A - result_df[condition_no_bench] = result_df[condition_no_bench].fillna(CompareConst.N_A) - result_df.loc[condition_no_bench, CompareConst.ERROR_MESSAGE] = CompareConst.NO_BENCH - - def calc_summary_diff(data_type: str): - def type_check(val): - check_series = pd.Series(False, index=val.index) - val_str = val.astype(str) - check_series[pd.to_numeric(val_str, errors='coerce').notna() | val_str.str.lower().eq('nan')] = True - return check_series - - def get_number(val): - return pd.to_numeric(val.astype(str), errors='coerce') - - ms_val = result_df['NPU ' + data_type] - pt_val = result_df['Bench ' + data_type] - diff_name = data_type.capitalize() + ' diff' - rel_err_name = ('norm' if data_type == 'l2norm' else data_type).capitalize() + 'RelativeErr' - condition_na = ~type_check(ms_val) | ~type_check(pt_val) - result_df.loc[condition_na, [diff_name, rel_err_name]] = CompareConst.N_A - result_df.loc[~(condition_no_bench | condition_na), diff_name] = get_number(ms_val) - get_number(pt_val) - condition_nan_diff = ~condition_no_bench & ~condition_na & result_df[diff_name].isna() - condition_not_nan_diff = ~condition_no_bench & ~condition_na & result_df[diff_name].notna() - result_df.loc[condition_nan_diff, [diff_name, rel_err_name]] = CompareConst.NAN - condition_pt_zero = pt_val == 0 - result_df.loc[condition_not_nan_diff & condition_pt_zero, rel_err_name] = CompareConst.NAN - condition_ref_err = condition_not_nan_diff & ~condition_pt_zero - result_df.loc[condition_ref_err, rel_err_name] = (result_df.loc[condition_ref_err, diff_name] / - pt_val[condition_ref_err] * 100) - result_df.loc[condition_ref_err, rel_err_name] = (result_df.loc[condition_ref_err, rel_err_name] - .abs().astype(str) + '%') - magnitude = get_number(result_df[diff_name]).abs() / ( - pd.Series(np.maximum(get_number(ms_val), get_number(pt_val))).abs() + CompareConst.EPSILON) - return magnitude > CompareConst.MAGNITUDE - - if self.dump_mode == Const.MD5: - condition_md5_equal = result_df[CompareConst.NPU_MD5] == result_df[CompareConst.BENCH_MD5] - result_df.loc[condition_md5_equal, CompareConst.RESULT] = CompareConst.PASS - result_df.loc[~condition_md5_equal & ~condition_no_bench, CompareConst.RESULT] = CompareConst.DIFF - elif self.dump_mode == Const.SUMMARY: - warning_list = [calc_summary_diff(data_type) for data_type in ['max', 'min', 'mean', 'l2norm']] - warning_flag = pd.DataFrame(warning_list).all() - result_df.loc[~condition_no_bench, [CompareConst.RESULT, CompareConst.ERROR_MESSAGE]] = '' - result_df.loc[warning_flag, CompareConst.RESULT] = CompareConst.WARNING - result_df.loc[warning_flag, CompareConst.ERROR_MESSAGE] = 'Need double check api accuracy.' - else: - fill_cols = [CompareConst.COSINE, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR, - CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO, - CompareConst.ERROR_MESSAGE] - result_df.loc[~condition_no_bench, fill_cols] = '' - result_df.loc[~condition_no_bench, CompareConst.ACCURACY] = CompareConst.ACCURACY_CHECK_YES - return result_df[header] - - def make_result_df(self, result): - header = CompareConst.HEAD_OF_COMPARE_MODE[self.dump_mode][:] - - if self.stack_mode: - header.append(CompareConst.STACK) - if self.dump_mode == Const.ALL: - header.append(CompareConst.DATA_NAME) - result.rename(columns={'op_name_x': CompareConst.NPU_NAME, - 'op_name_y': CompareConst.BENCH_NAME, - 'dtype_x': CompareConst.NPU_DTYPE, - 'dtype_y': CompareConst.BENCH_DTYPE, - 'shape_x': CompareConst.NPU_SHAPE, - 'shape_y': CompareConst.BENCH_SHAPE, - 'md5_x': CompareConst.NPU_MD5, - 'md5_y': CompareConst.BENCH_MD5, - 'data_name_x': CompareConst.DATA_NAME, - 'stack_info_x': CompareConst.STACK}, inplace=True) - - npu_summary = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM] - bench_summary = [CompareConst.BENCH_MAX, CompareConst.BENCH_MIN, CompareConst.BENCH_MEAN, - CompareConst.BENCH_NORM] - - def set_summary(summary): - if summary == CompareConst.N_A: - return [CompareConst.N_A] * 4 - summary_list = [] - for i in summary: - if i is None: - summary_list.append(CompareConst.N_A) - elif str(i).lower() == 'nan': - summary_list.append(CompareConst.NAN) - else: - summary_list.append(i) - return summary_list +from msprobe.mindspore.compare.utils import read_npy_data, check_cross_framework - result[npu_summary] = result['summary_x'].apply(set_summary).tolist() - result[bench_summary] = result['summary_y'].apply(set_summary).tolist() - result_df = pd.DataFrame(columns=header) - for h in header: - if h in result.columns: - result_df[h] = result[h] - return self.calc_accuracy(result_df, header) - def load_internal_api(self): - cur_path = os.path.dirname(os.path.realpath(__file__)) - yaml_path = os.path.abspath(os.path.join(cur_path, CompareConst.INTERNAL_API_MAPPING_FILE)) - return load_yaml(yaml_path) - - def load_mapping_file(self, mapping_file): - if isinstance(mapping_file, str): - mapping_dict = load_yaml(mapping_file) - else: - mapping_dict = {} - return mapping_dict - - def process_cell_mapping(self, npu_op_name): - if not npu_op_name: - return CompareConst.N_A - param_grad_flag = Const.PARAMS_GRAD in npu_op_name.split(Const.SEP) - if not param_grad_flag and not re.search(Const.REGEX_FORWARD_BACKWARD, npu_op_name): - return CompareConst.N_A - npu_op_name = npu_op_name.replace("Cell", "Module", 1) - if self.cell_mapping_dict: - # get cell name & class name from op_name - # Cell.fc1.Dense.forward.0.input.0 - cell_name = re.split(r'\.(?:forward|backward|parameters_grad)\.', npu_op_name.split(Const.SEP, 1)[-1])[0] - if cell_name in self.cell_mapping_dict: - npu_op_name = npu_op_name.replace(cell_name, self.cell_mapping_dict[cell_name], 1) - return npu_op_name - - def read_npy_data(self, dir_path, file_name, load_pt_file=False): - if not file_name: - return None - data_path = os.path.join(dir_path, file_name) - if load_pt_file: - import torch - from msprobe.pytorch.common.utils import load_pt - data_value = load_pt(data_path, True).detach() - if data_value.dtype == torch.bfloat16: - data_value = data_value.to(torch.float32) - data_value = data_value.numpy() - else: - data_value = load_npy(data_path) - return data_value - - def process_internal_api_mapping(self, npu_op_name): - # get api name & class name from op_name - # Functional.addcmul.0.forward.input.0 - ms_api_name = self.get_api_name(npu_op_name.split(Const.SEP)) - class_name = ms_api_name.split(Const.SEP)[0] - if class_name == "Mint": - return npu_op_name.replace("Mint", "Torch") - elif class_name == "MintFunctional": - return npu_op_name.replace("MintFunctional", "Functional") - elif self.ms_to_pt_mapping.get(ms_api_name): - return npu_op_name.replace(ms_api_name, self.ms_to_pt_mapping.get(ms_api_name)) - else: - return npu_op_name - - def get_api_name(self, api_list): - try: - api_name = api_list[0] + Const.SEP + api_list[1] - except IndexError as error: - logger.error(f'Failed to retrieve API name, please check if the dump data is reasonable') - raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error - return api_name - - def compare_process(self, file_lists): - npu_json_path, bench_json_path, stack_json_path = file_lists - npu_json_data = load_json(npu_json_path) - bench_json_data = load_json(bench_json_path) - stack_json_data = load_json(stack_json_path) if self.stack_mode else None - - npu_df = self.gen_data_df(npu_json_data, stack_json_data) - bench_df = self.gen_data_df(bench_json_data, stack_json_data) - if self.cell_mapping: - npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_cell_mapping) - elif self.api_mapping: - npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_internal_api_mapping) - if isinstance(self.api_mapping, str): - self.modify_compare_data_with_user_mapping(npu_df, bench_df) - else: - npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME] - npu_df[[Const.DTYPE, Const.SHAPE]] = npu_df[[Const.DTYPE, Const.SHAPE]].astype(str) - bench_df[[Const.DTYPE, Const.SHAPE]] = bench_df[[Const.DTYPE, Const.SHAPE]].astype(str) - npu_df[CompareConst.COMPARE_SHAPE] = npu_df[Const.SHAPE] - bench_df[CompareConst.COMPARE_KEY] = bench_df[CompareConst.OP_NAME] - bench_df[CompareConst.COMPARE_SHAPE] = bench_df[Const.SHAPE] - match_result = pd.merge(npu_df, bench_df, on=[CompareConst.COMPARE_KEY, CompareConst.COMPARE_SHAPE], - how='outer') - match_result = match_result[match_result['op_name_x'].notna()].fillna(CompareConst.N_A) - - def gen_dtype_condition(): - npu_dtype = match_result['dtype_x'] - bench_dtype = match_result['dtype_y'] - if self.cross_frame: - npu_dtype = npu_dtype.map(dtype_mapping).fillna(npu_dtype) - return ((npu_dtype == bench_dtype) | - ((npu_dtype == Const.FLOAT16) & (bench_dtype == Const.FLOAT32)) | - ((npu_dtype == Const.FLOAT32) & (bench_dtype == Const.FLOAT16)) | - ((npu_dtype == Const.FLOAT16) & (bench_dtype == Const.BFLOAT16)) | - ((npu_dtype == Const.BFLOAT16) & (bench_dtype == Const.FLOAT16)) | - ((npu_dtype == Const.TORCH_FLOAT16) & (bench_dtype == Const.TORCH_FLOAT32)) | - ((npu_dtype == Const.TORCH_FLOAT32) & (bench_dtype == Const.TORCH_FLOAT16)) | - ((npu_dtype == Const.TORCH_FLOAT16) & (bench_dtype == Const.TORCH_BFLOAT16)) | - ((npu_dtype == Const.TORCH_BFLOAT16) & (bench_dtype == Const.TORCH_FLOAT16))) - - match_result.loc[~gen_dtype_condition(), [i + '_y' for i in bench_df.columns]] = CompareConst.N_A - return self.make_result_df(match_result) - - def modify_compare_data_with_user_mapping(self, npu_df, bench_df): - def get_api_indices_dict(op_name_df): - api_indices_dict = defaultdict(list) - for op_index, name in enumerate(op_name_df[CompareConst.OP_NAME]): - api = self.get_api_name(name.split(Const.SEP)) - api_indices_dict[api].append(op_index) - return api_indices_dict - - ms_api_indices_dict = get_api_indices_dict(npu_df) - pt_api_indices_dict = get_api_indices_dict(bench_df) - - def gen_input_compare_key(pattern, term): - flag = True - for i, prefix in enumerate(mapping_dict.get(f'ms_{term}')): - if op_name.split(pattern)[1].startswith(str(prefix)): - npu_df.loc[index, CompareConst.COMPARE_KEY] = ( - op_name.replace(pattern + str(prefix), - pattern + str(mapping_dict.get(f'pt_{term}')[i]))) - flag = False - return flag - - for mapping_dict in self.api_mapping_dict: - keys_to_compare = [ - ('ms_args', 'pt_args'), - ('ms_output', 'pt_output'), - ('ms_parameters', 'pt_parameters'), - ('ms_parameters_grad', 'pt_parameters_grad'), - ] - if not all(len(mapping_dict.get(k1, [])) == len(mapping_dict.get(k2, [])) for k1, k2 in keys_to_compare): - logger.warning('The user-defined mapping table is incorrect,\ - make sure that the number of parameters is equal') - continue - - ms_api, pt_api = mapping_dict.get('ms_api'), mapping_dict.get('pt_api') - if ms_api not in ms_api_indices_dict or pt_api not in pt_api_indices_dict: - continue - for index in ms_api_indices_dict.get(ms_api): - op_name = npu_df.loc[index, CompareConst.OP_NAME].replace(ms_api, pt_api, 1) - if CompareConst.INPUT_PATTERN in op_name: - is_abandoned = gen_input_compare_key(CompareConst.INPUT_PATTERN, 'args') - elif CompareConst.KWARGS_PATTERN in op_name: - is_abandoned = gen_input_compare_key(CompareConst.KWARGS_PATTERN, 'args') - elif CompareConst.OUTPUT_PATTERN in op_name: - is_abandoned = gen_input_compare_key(CompareConst.OUTPUT_PATTERN, 'output') - elif CompareConst.PARAMS_PATTERN in op_name: - is_abandoned = gen_input_compare_key(CompareConst.PARAMS_PATTERN, 'parameters') - elif CompareConst.PARAMS_GRAD_PATTERN in op_name: - is_abandoned = gen_input_compare_key(CompareConst.PARAMS_GRAD_PATTERN, 'parameters_grad') - else: - logger.error(f'Excepted op_name: {op_name}') - raise CompareException(CompareException.INVALID_DATA_ERROR) - if is_abandoned: - npu_df.loc[index, CompareConst.COMPARE_KEY] = op_name + 'abandoned' - - def gen_data_df(self, data_json, stack_json_data): - result = { - CompareConst.OP_NAME: [], - Const.DTYPE: [], - Const.SHAPE: [], - Const.SUMMARY: [], - 'stack_info': [] - } - if self.dump_mode == Const.ALL: - result['data_name'] = [] - elif self.dump_mode == Const.MD5: - result[Const.MD5] = [] - for data_name in data_json['data']: - check_op_str_pattern_valid(data_name) - merge_list = self.gen_merge_list(data_json, data_name, stack_json_data) - if not merge_list: - continue - - op_name_list = merge_list.get(CompareConst.OP_NAME) - summary_list = merge_list.get(Const.SUMMARY) - data_name_list = merge_list.get('data_name') - op_name_reorder, summary_reorder, data_name_reorder = reorder_op_x_list(op_name_list, - summary_list, - data_name_list) - for op_name in op_name_reorder: - result[CompareConst.OP_NAME].append(op_name) - if (CompareConst.INPUT_PATTERN in op_name) or (CompareConst.KWARGS_PATTERN in op_name): - struct = merge_list[CompareConst.INPUT_STRUCT].pop(0) - elif CompareConst.OUTPUT_PATTERN in op_name: - struct = merge_list[CompareConst.OUTPUT_STRUCT].pop(0) - elif CompareConst.PARAMS_PATTERN in op_name: - struct = merge_list[CompareConst.PARAMS_STRUCT].pop(0) - else: - struct = merge_list[CompareConst.PARAMS_GRAD_STRUCT].pop(0) - result[Const.DTYPE].append(struct[0]) - result[Const.SHAPE].append(struct[1]) - if self.dump_mode == Const.MD5: - result[Const.MD5].append(struct[2]) - result[Const.SUMMARY].append(summary_reorder.pop(0)) - result['stack_info'].append(merge_list['stack_info'][0] if self.stack_mode else None) - if self.dump_mode == Const.ALL: - result['data_name'].append(data_name_reorder.pop(0)) - return pd.DataFrame(result) - - -def check_cross_framework(bench_json_path): - pattern = r'"data_name":\s*"[^"]+\.pt"' - with FileOpen(bench_json_path, 'r') as file: - for line in file: - if re.search(pattern, line): - return True - return False +def read_real_data(npu_dir, npu_data_name, bench_dir, bench_data_name, cross_frame) -> tuple: + n_value = read_npy_data(npu_dir, npu_data_name) + if cross_frame: + from msprobe.pytorch.compare.utils import read_pt_data + b_value = read_pt_data(bench_dir, bench_data_name) + else: + b_value = read_npy_data(bench_dir, bench_data_name) + return n_value, b_value def ms_compare(input_param, output_path, **kwargs): - try: - auto_analyze = kwargs.get('auto_analyze', True) - fuzzy_match = kwargs.get('fuzzy_match', False) - cell_mapping = kwargs.get('cell_mapping', None) - api_mapping = kwargs.get('api_mapping', None) - data_mapping = kwargs.get('data_mapping', None) - layer_mapping = kwargs.get('layer_mapping', None) - suffix = kwargs.get('suffix', '') + config = setup_comparison(input_param, output_path, **kwargs) - set_dump_path(input_param) - dump_mode = get_dump_mode(input_param) - if 'stack_json_path' in input_param: - stack_mode = kwargs.get('stack_mode', False) - else: - stack_mode = set_stack_json_path(input_param) # set stack_mode and set "stack_json_path" in input_param - check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True)) - create_directory(output_path) - check_compare_param(input_param, output_path, dump_mode, stack_mode) - except (CompareException, FileCheckException) as error: - logger.error('Compare failed. Please check the arguments and do it again!') - raise CompareException(error.code) from error - if layer_mapping: - data_mapping = generate_data_mapping_by_layer_mapping(input_param, layer_mapping, output_path) + if config.layer_mapping: + config.data_mapping = generate_data_mapping_by_layer_mapping(input_param, config.layer_mapping, output_path) - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - mapping_config = MappingConfig(cell_mapping, api_mapping, data_mapping) is_cross_framework = check_cross_framework(input_param.get('bench_json_path')) - ms_comparator = MSComparator(mode_config, mapping_config, is_cross_framework) - ms_comparator.compare_core(input_param, output_path, suffix=suffix) + + config_dict = { + 'stack_mode': config.stack_mode, + 'auto_analyze': config.auto_analyze, + 'fuzzy_match': config.fuzzy_match, + 'highlight': config.highlight, + 'dump_mode': config.dump_mode, + 'compared_file_type': config.compared_file_type + } + mode_config = ModeConfig(**config_dict) + mapping_config = MappingConfig(config.cell_mapping, config.api_mapping, config.data_mapping) + ms_comparator = Comparator(read_real_data, mode_config, mapping_config, is_cross_framework) + ms_comparator.compare_core(input_param, output_path, suffix=config.suffix) diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/ms_graph_compare.py b/debug/accuracy_tools/msprobe/mindspore/compare/ms_graph_compare.py index 701988ba483de4e13d85892dbb42d62c7cc805b8..40cf4be245fab4bd237ea381fde162b474ff8ff0 100644 --- a/debug/accuracy_tools/msprobe/mindspore/compare/ms_graph_compare.py +++ b/debug/accuracy_tools/msprobe/mindspore/compare/ms_graph_compare.py @@ -34,10 +34,11 @@ class RowData: self.basic_data = copy.deepcopy(CompareConst.MS_GRAPH_BASE) self.npy_data = copy.deepcopy(CompareConst.MS_GRAPH_NPY) self.statistic_data = copy.deepcopy(CompareConst.MS_GRAPH_STATISTIC) + self.csv = copy.deepcopy(CompareConst.MS_GRAPH_CSV) if mode == GraphMode.NPY_MODE: self.data = {**self.basic_data, **self.npy_data} else: - self.data = {**self.basic_data, **self.statistic_data} + self.data = {**self.basic_data, **self.statistic_data, **self.csv} def __call__(self): return self.data @@ -80,16 +81,18 @@ def statistic_data_read(statistic_file_list, statistic_file_path): data_list = [] statistic_data_list = [] header_index = { - 'Data Type': None, 'Shape': None, 'Max Value': None, - 'Min Value': None, 'Avg Value': None, 'L2Norm Value': None + 'Data Type': None, 'Shape': None, + 'Max Value': None, 'Min Value': None, 'Avg Value': None, 'L2Norm Value': None } for statistic_file in statistic_file_list: content = read_csv(statistic_file, as_pd=False) + if not content: + logger.error(f'Empty dump file: {statistic_file}') + raise CompareException(f'Empty dump file: {statistic_file}') header = content[0] - for key in header_index.keys(): - for index, value in enumerate(header): - if key == value: - header_index[key] = index + for index, value in enumerate(header): + if value in header_index: + header_index[value] = index statistic_data_list.extend(content[1:]) for key in header_index.keys(): @@ -97,8 +100,15 @@ def statistic_data_read(statistic_file_list, statistic_file_path): logger.warning(f"Data_path {statistic_file_path} has no key {key}.") for data in statistic_data_list: - compare_key = f"{data[1]}.{data[2]}.{data[3]}.{data[5]}" - op_name = f"{compare_key} {statistic_file_path}" + ''' + 13列分别是OpType, OpName, TaskId, StreamId, TimeStamp, IO, Slot, DataSize, + DataType, Shape, MaxValue, MinValue, L2NormValue + ''' + if len(data) < 13: + logger.error(f'Dump file {statistic_file_path} has been modified into incorrect format!') + raise CompareException(f'Dump file {statistic_file_path} has been modified into incorrect format!') + compare_key = f"{data[1]}.{data[2]}.{data[5]}.{data[6]}" # OpName, TaskId, IO, Slot + op_name = f"{compare_key}" timestamp = int(data[4]) result_data = [op_name, compare_key, timestamp] for key in header_index.keys(): @@ -106,6 +116,8 @@ def statistic_data_read(statistic_file_list, statistic_file_path): result_data.append(np.nan) else: result_data.append(data[header_index[key]]) + csv_file = f"{statistic_file_path}" + result_data.append(csv_file) data_list.append(result_data) return data_list @@ -159,8 +171,13 @@ class GraphMSComparator: self.output_path = output_path self.base_npu_path = input_param.get('npu_path', None) self.base_bench_path = input_param.get('bench_path', None) - self.rank_list = [convert_to_int(rank_id) for rank_id in input_param.get('rank_id', [])] - self.step_list = [convert_to_int(step_id) for step_id in input_param.get('step_id', [])] + rank_id_list = input_param.get('rank_id', []) + step_id_list = input_param.get('step_id', []) + if not isinstance(rank_id_list, list) or not isinstance(step_id_list, list): + logger.error("'rank_id' and 'step_id' should both be lists, please check!") + raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR) + self.rank_list = [convert_to_int(rank_id) for rank_id in rank_id_list] + self.step_list = [convert_to_int(step_id) for step_id in step_id_list] # split by rank and step, generate rank step path self.npu_rank_step_dict = self.generate_rank_step_path(self.base_npu_path) self.bench_rank_step_dict = self.generate_rank_step_path(self.base_bench_path) @@ -195,11 +212,12 @@ class GraphMSComparator: if not error_flag: result_list, err_msg = compare_ops_apply(n_value, b_value, False, "") result_dict[CompareConst.COSINE] = result_list[0] - result_dict[CompareConst.MAX_ABS_ERR] = result_list[1] - result_dict[CompareConst.MAX_RELATIVE_ERR] = result_list[2] - result_dict[CompareConst.ONE_THOUSANDTH_ERR_RATIO] = result_list[3] - result_dict[CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = result_list[4] - result_dict[CompareConst.ACCURACY] = check_accuracy(result_list[0], result_list[1]) + result_dict[CompareConst.EUC_DIST] = result_list[1] + result_dict[CompareConst.MAX_ABS_ERR] = result_list[2] + result_dict[CompareConst.MAX_RELATIVE_ERR] = result_list[3] + result_dict[CompareConst.ONE_THOUSANDTH_ERR_RATIO] = result_list[4] + result_dict[CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = result_list[5] + result_dict[CompareConst.ACCURACY] = check_accuracy(result_list[0], result_list[2]) result_dict[CompareConst.ERROR_MESSAGE] = err_msg return pd.Series(result_dict) @@ -215,6 +233,17 @@ class GraphMSComparator: result[f'{prefix} min'] = np.float32(rows[f'{prefix} min']) result[f'{prefix} mean'] = np.float32(rows[f'{prefix} mean']) result[f'{prefix} l2norm'] = np.float32(rows[f'{prefix} l2norm']) + result[f'{prefix} CSV File'] = rows[f'{prefix} CSV File'] + + def calculate_relative_error(numerator, denominator): + """Calculates relative error, handling division by zero and NaN.""" + if denominator != 0: + result = numerator / denominator + if not np.isnan(result): + return str(abs(result * 100)) + "%" + else: + return CompareConst.NAN + return CompareConst.N_A # 使用示例 update_result_dict(result_dict, row, 'NPU') @@ -222,34 +251,26 @@ class GraphMSComparator: error_flag, error_message = statistics_data_check(result_dict) result_dict[CompareConst.ERROR_MESSAGE] += error_message if not error_flag: - result_dict[CompareConst.MAX_DIFF] = np.abs( - result_dict[CompareConst.NPU_MAX] - result_dict[CompareConst.BENCH_MAX]) - result_dict[CompareConst.MIN_DIFF] = np.abs( - result_dict[CompareConst.NPU_MIN] - result_dict[CompareConst.BENCH_MIN]) - result_dict[CompareConst.MEAN_DIFF] = np.abs( - result_dict[CompareConst.NPU_MEAN] - result_dict[CompareConst.BENCH_MEAN]) - result_dict[CompareConst.NORM_DIFF] = np.abs( - result_dict[CompareConst.NPU_NORM] - result_dict[CompareConst.BENCH_NORM]) - result_dict[CompareConst.MAX_RELATIVE_ERR] = result_dict[CompareConst.MAX_DIFF] / result_dict[ - CompareConst.BENCH_MAX] if result_dict[CompareConst.BENCH_MAX] > 0 else 0 - if not np.isnan(result_dict[CompareConst.MAX_RELATIVE_ERR]): - result_dict[CompareConst.MAX_RELATIVE_ERR] = str( - result_dict[CompareConst.MAX_RELATIVE_ERR] * 100) + "%" - result_dict[CompareConst.MIN_RELATIVE_ERR] = result_dict[CompareConst.MIN_DIFF] / result_dict[ - CompareConst.BENCH_MIN] if result_dict[CompareConst.BENCH_MIN] > 0 else 0 - if not np.isnan(result_dict[CompareConst.MIN_RELATIVE_ERR]): - result_dict[CompareConst.MIN_RELATIVE_ERR] = \ - str(result_dict[CompareConst.MIN_RELATIVE_ERR] * 100) + "%" - result_dict[CompareConst.MEAN_RELATIVE_ERR] = result_dict[CompareConst.MEAN_DIFF] / result_dict[ - CompareConst.BENCH_MEAN] if result_dict[CompareConst.BENCH_MEAN] > 0 else 0 - if not np.isnan(result_dict[CompareConst.MEAN_RELATIVE_ERR]): - result_dict[CompareConst.MEAN_RELATIVE_ERR] = str( - result_dict[CompareConst.MEAN_RELATIVE_ERR] * 100) + "%" - result_dict[CompareConst.NORM_RELATIVE_ERR] = result_dict[CompareConst.NORM_DIFF] / result_dict[ - CompareConst.BENCH_NORM] if result_dict[CompareConst.BENCH_NORM] > 0 else 0 - if not np.isnan(result_dict[CompareConst.NORM_RELATIVE_ERR]): - result_dict[CompareConst.NORM_RELATIVE_ERR] = str( - result_dict[CompareConst.NORM_RELATIVE_ERR] * 100) + "%" + metrics = [ + (CompareConst.MAX_DIFF, CompareConst.NPU_MAX, CompareConst.BENCH_MAX), + (CompareConst.MIN_DIFF, CompareConst.NPU_MIN, CompareConst.BENCH_MIN), + (CompareConst.MEAN_DIFF, CompareConst.NPU_MEAN, CompareConst.BENCH_MEAN), + (CompareConst.NORM_DIFF, CompareConst.NPU_NORM, CompareConst.BENCH_NORM), + ] + relative_error_metrics = [ + (CompareConst.MAX_RELATIVE_ERR, CompareConst.MAX_DIFF, CompareConst.BENCH_MAX), + (CompareConst.MIN_RELATIVE_ERR, CompareConst.MIN_DIFF, CompareConst.BENCH_MIN), + (CompareConst.MEAN_RELATIVE_ERR, CompareConst.MEAN_DIFF, CompareConst.BENCH_MEAN), + (CompareConst.NORM_RELATIVE_ERR, CompareConst.NORM_DIFF, CompareConst.BENCH_NORM), + ] + + for diff_metric, npu_metric, bench_metric in metrics: + result_dict[diff_metric] = result_dict[npu_metric] - result_dict[bench_metric] + + for rel_metric, diff_metric, bench_metric in relative_error_metrics: + result_dict[rel_metric] = calculate_relative_error(result_dict[diff_metric], + result_dict[bench_metric]) + magnitude_diff = result_dict[CompareConst.MAX_DIFF] / ( max(result_dict[CompareConst.NPU_MAX], result_dict[CompareConst.BENCH_MAX]) + 1e-10) if np.isnan(result_dict[CompareConst.NPU_MAX]) and np.isnan(result_dict[CompareConst.BENCH_MAX]): @@ -281,20 +302,8 @@ class GraphMSComparator: compare_result_df = self.do_multi_process(compare_result_df, mode) compare_result_name = add_time_with_xlsx(f"compare_result_{str(rank_id)}_{str(step_id)}") compare_result_path = os.path.join(os.path.realpath(self.output_path), f"{compare_result_name}") - self.to_excel(compare_result_df, compare_result_path) - logger.info(f"Compare rank: {rank_id} step: {step_id} finish. Compare result: {compare_result_path}.") - - def to_excel(self, compare_result_df: pd.DataFrame, compare_result_path: str, slice_num=0, need_slice=False) -> int: - size = len(compare_result_df) - # sheet size cannot be larger than 1048576 - if size < CompareConst.MAX_EXCEL_LENGTH: - compare_result_path = compare_result_path.replace('.xlsx', f'_slice_{slice_num}.xlsx') if \ - need_slice else compare_result_path save_excel(compare_result_path, compare_result_df) - return slice_num + 1 - else: - slice_num = self.to_excel(compare_result_df.iloc[0: size // 2], compare_result_path, slice_num, True) - return self.to_excel(compare_result_df.iloc[size // 2:], compare_result_path, slice_num, True) + logger.info(f"Compare rank: {rank_id} step: {step_id} finish. Compare result: {compare_result_path}.") def compare_process(self, rank_id, step_id): # generate data_path @@ -316,7 +325,7 @@ class GraphMSComparator: bench_data_list.extend(data_list) if npu_mode == GraphMode.ERROR_MODE or bench_mode == GraphMode.ERROR_MODE: - logger.warning(f"Data_path {npu_data_path} or {bench_data_path} is not exist.") + logger.warning(f"Data path: npu_data_path or bench_data_path does not exist.") return [], '' if npu_mode != bench_mode: logger.error(f"NPU mode {npu_mode} not equal to MATCH mode {bench_mode}.") @@ -329,14 +338,15 @@ class GraphMSComparator: npu_data_df = pd.DataFrame(npu_data_list, columns=[CompareConst.NPU_NAME, 'Compare Key', 'TimeStamp', CompareConst.NPU_DTYPE, CompareConst.NPU_SHAPE, - CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, - CompareConst.NPU_NORM]) + CompareConst.NPU_MAX, CompareConst.NPU_MIN, + CompareConst.NPU_MEAN, CompareConst.NPU_NORM, + CompareConst.NPU_CSV_FILE]) bench_data_df = pd.DataFrame(bench_data_list, columns=[CompareConst.BENCH_NAME, 'Compare Key', 'TimeStamp', - CompareConst.BENCH_DTYPE, - CompareConst.BENCH_SHAPE, CompareConst.BENCH_MAX, - CompareConst.BENCH_MIN, CompareConst.BENCH_MEAN, - CompareConst.BENCH_NORM]) + CompareConst.BENCH_DTYPE, CompareConst.BENCH_SHAPE, + CompareConst.BENCH_MAX, CompareConst.BENCH_MIN, + CompareConst.BENCH_MEAN, CompareConst.BENCH_NORM, + CompareConst.BENCH_CSV_FILE]) npu_float_type = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM] npu_float_data_df = npu_data_df[npu_float_type].astype(str) diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_vf.py b/debug/accuracy_tools/msprobe/mindspore/compare/utils.py similarity index 34% rename from debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_vf.py rename to debug/accuracy_tools/msprobe/mindspore/compare/utils.py index 05ee3bc92257be9882c20cf825ebb7561f41ddb1..a6f9f4ae55a656c269509fc479f476aa8b9251b9 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_vf.py +++ b/debug/accuracy_tools/msprobe/mindspore/compare/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,47 +14,32 @@ # limitations under the License. import os -import torch from msprobe.core.common.const import Const -from msprobe.core.common.file_utils import load_yaml -from msprobe.pytorch.hook_module.hook_module import HOOKModule -from msprobe.pytorch.common.utils import torch_device_guard +from msprobe.core.common.file_utils import load_npy, FileChecker, FileCheckConst +from msprobe.core.common.utils import detect_framework_by_dump_json, CompareException, check_op_str_pattern_valid +from msprobe.core.common.log import logger -cur_path = os.path.dirname(os.path.realpath(__file__)) -yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") +def read_npy_data(dir_path, file_name): + if not file_name: + return None + data_path = os.path.join(dir_path, file_name) + path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, + FileCheckConst.NUMPY_SUFFIX, False) + data_path = path_checker.common_check() + data_value = load_npy(data_path) + return data_value -def get_vf_ops(): - yaml_data = load_yaml(yaml_path) - wrap_vf_ops = yaml_data.get('_VF') - return wrap_vf_ops +def check_cross_framework(bench_json_path): + framework = detect_framework_by_dump_json(bench_json_path) + return framework == Const.PT_FRAMEWORK -class HOOKVfOP(object): - pass - -class VfOPTemplate(HOOKModule): - def __init__(self, op_name, hook): - self.op_name_ = op_name - self.prefix_op_name_ = "VF" + Const.SEP + str(op_name) + Const.SEP - super().__init__(hook) - - @torch_device_guard - def forward(self, *args, **kwargs): - return getattr(torch._C._VariableFunctionsClass, str(self.op_name_))(*args, **kwargs) - - -def wrap_vf_op(op_name, hook): - def vf_op_template(*args, **kwargs): - return VfOPTemplate(op_name, hook)(*args, **kwargs) - - return vf_op_template - - -def wrap_vf_ops_and_bind(hook): - _vf_ops = get_vf_ops() - for op_name in _vf_ops: - setattr(HOOKVfOP, "wrap_" + op_name, wrap_vf_op(op_name, hook)) +def check_name_map_dict(name_map_dict): + if not isinstance(name_map_dict, dict): + logger.error("'map_dict' should be a dict, please check!") + raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR) + check_op_str_pattern_valid(str(name_map_dict)) diff --git a/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py b/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py index 92155b4ec4ebd636477ef67f1c75b43e7a82b802..1ea0a36256600adf9a5dcf1da383bc719c024b37 100644 --- a/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py +++ b/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py @@ -15,12 +15,18 @@ import os +from mindspore import nn + from msprobe.core.common.const import Const from msprobe.core.common.exceptions import MsprobeException from msprobe.core.common.file_utils import create_directory +from msprobe.core.common.log import logger from msprobe.mindspore.common.const import Const as MsConst from msprobe.mindspore.common.const import FreeBenchmarkConst -from msprobe.core.common.log import logger +from msprobe.mindspore.common.utils import is_mindtorch + +if is_mindtorch(): + import torch class DebuggerConfig: @@ -41,8 +47,13 @@ class DebuggerConfig: self.check_mode = task_config.check_mode self.framework = Const.MS_FRAMEWORK self.summary_mode = task_config.summary_mode + self.stat_cal_mode = task_config.stat_cal_mode if hasattr(task_config, 'stat_cal_mode') else None + self.device_stat_precision_mode = task_config.device_stat_precision_mode \ + if hasattr(task_config, 'device_stat_precision_mode') else None self.async_dump = common_config.async_dump if common_config.async_dump else False + self.precision = common_config.precision if common_config.precision else Const.DUMP_PRECISION_LOW self.check() + self._check_statistics_config(task_config) create_directory(self.dump_path) if self.task == Const.FREE_BENCHMARK: @@ -53,13 +64,44 @@ class DebuggerConfig: self.stage = FreeBenchmarkConst.DEFAULT_STAGE if not task_config.fuzz_stage else task_config.fuzz_stage if self.handler_type == FreeBenchmarkConst.FIX and \ self.pert_type != FreeBenchmarkConst.DEFAULT_PERT_TYPE: - raise ValueError("pert_mode must be improve_precision or empty when handler_type is fix, " - f"but got {self.pert_type}.") + logger.error("pert_mode must be improve_precision or empty when handler_type is fix, " + f"but got {self.pert_type}.") + raise ValueError if self.stage == Const.BACKWARD and self.handler_type == FreeBenchmarkConst.FIX: - raise ValueError("handler_type must be check or empty when fuzz_stage is backward, " - f"but got {self.handler_type}.") + logger.error("handler_type must be check or empty when fuzz_stage is backward, " + f"but got {self.handler_type}.") + raise ValueError self.dump_level = FreeBenchmarkConst.DEFAULT_DUMP_LEVEL + @staticmethod + def check_model(models, token_range=None): + if token_range and not models: + error_info = "The 'model' parameter must be provided when token_range is not None" + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, error_info) + + target_module_type = (torch.nn.Module, "torch.nn.Module") if is_mindtorch() else (nn.Cell, "mindspore.nn.Cell") + if models is None or isinstance(models, target_module_type[0]): + return models + if isinstance(models, (list, tuple)): + error_model = None + for model in models: + if not isinstance(model, target_module_type[0]): + error_model = model + break + if error_model is not None: + error_info = ( + f"The 'model' parameter must be a {target_module_type[1]} or list[{target_module_type[1]}] " + f"type, currently there is a {type(error_model)} type.") + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, error_info) + + else: + error_info = (f"The 'model' parameter must be a {target_module_type[1]} or list[{target_module_type[1]}] " + f"type, currently there is a {type(models)} type.") + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, error_info) + return models + def check(self): if not self.dump_path: raise Exception("Dump path is empty.") @@ -74,25 +116,48 @@ class DebuggerConfig: self.check_mode = "all" if not isinstance(self.async_dump, bool): raise Exception("The parameters async_dump should be bool.") - if self.async_dump and self.task == Const.TENSOR and not self.list: - raise Exception("The parameters async_dump is true in tensor task, the parameters list cannot be empty.") if self.task == Const.STRUCTURE and self.level_ori not in [Const.LEVEL_L0, Const.LEVEL_MIX]: logger.warning_on_rank_0( f"When the task is set to structure, the level should be one of {[Const.LEVEL_L0, Const.LEVEL_MIX]}. " f"If not, the default level is {Const.LEVEL_MIX}." ) self.level_ori = Const.LEVEL_MIX + if self.async_dump: + if self.task == Const.TENSOR: + if self.level_ori == Const.LEVEL_DEBUG: + self.list = [] # async_dump + debug level case ignore list + if not self.list and self.level_ori != Const.LEVEL_DEBUG: + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, + "The parameters async_dump is true in tensor task, the parameters list cannot be empty." + ) + is_unsupported_mode = self.summary_mode == Const.MD5 or \ + isinstance(self.summary_mode, list) and Const.MD5 in self.summary_mode + if is_unsupported_mode: + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, + f"The parameters async_dump is true, the parameters summary_mode cannot be/contain md5." + ) return True - def check_config_with_l2(self): - if self.level_ori != Const.LEVEL_L2: - return - if self.task != Const.TENSOR: + def check_config_with_l2(self, is_graph_config): + if not is_graph_config and self.task != Const.TENSOR: raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"When level is set to L2, the task must be set to tensor.") - if self.scope: + if not is_graph_config and self.scope: raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"When level is set to L2, the scope cannot be configured.") - if not self.list or len(self.list) != 1: + if not is_graph_config and (not self.list or len(self.list) != 1): raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"When level is set to L2, the list must be configured as a list with one api name.") + + def _check_statistics_config(self, task_config): + if self.task != Const.STATISTICS: + return + self.tensor_list = [] + if not hasattr(task_config, "tensor_list"): + return + if self.level_ori == Const.LEVEL_DEBUG and task_config.tensor_list: + logger.warning_on_rank_0("When level is set to debug, the tensor_list will be invalid.") + return + self.tensor_list = task_config.tensor_list diff --git a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py index 7694d71dd98ae1c7c4611f9435a274ac018e5df6..99aeeba899b21677f1d5dfe6a64c35e8bd87648f 100644 --- a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py @@ -17,24 +17,38 @@ import os from collections import defaultdict, namedtuple import mindspore as ms +from mindspore.ops.operations import _inner_ops as inner from mindspore._c_expression import MSContext -from msprobe.core.common.const import Const, FileCheckConst, MsgConst -from msprobe.core.common.exceptions import MsprobeException -from msprobe.core.common.file_utils import FileChecker -from msprobe.core.common.utils import get_real_step_or_rank +from msprobe.core.common.const import Const, MsgConst +from msprobe.core.common.utils import check_token_range, ThreadSafe +from msprobe.core.common.runtime import Runtime +from msprobe.core.debugger.precision_debugger import BasePrecisionDebugger from msprobe.mindspore.cell_processor import CellProcessor from msprobe.mindspore.common.const import Const as MsConst -from msprobe.mindspore.common.utils import set_register_backward_hook_functions, check_save_param +from msprobe.mindspore.common.utils import ( + set_register_backward_hook_functions, + check_save_param, + is_graph_mode_cell_dump_allowed, + wrap_backward_hook_call_func +) from msprobe.mindspore.debugger.debugger_config import DebuggerConfig -from msprobe.mindspore.dump.hook_cell.api_registry import api_register +from msprobe.mindspore.dump.graph_mode_cell_dump import GraphModeCellDump +from msprobe.mindspore.dump.hook_cell.api_register import get_api_register from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell from msprobe.mindspore.grad_probe.grad_monitor import GradientMonitor -from msprobe.mindspore.ms_config import parse_json_config -from msprobe.mindspore.runtime import Runtime -from msprobe.mindspore.service import Service +from msprobe.mindspore.ms_config import parse_task_config +from msprobe.mindspore.mindspore_service import MindsporeService from msprobe.mindspore.task_handler_factory import TaskHandlerFactory +try: + from mindspore._c_expression import _dump_start, _dump_stop, _dump_step, _set_init_iter, _dump_set_dynamic + import mindspore as ms +except ImportError: + enable_dynamic_kbyk_dump = False +else: + enable_dynamic_kbyk_dump = True + try: from msprobe.lib import _msprobe_c except ImportError: @@ -44,79 +58,50 @@ except ImportError: ConfigParameters = namedtuple("ConfigParameters", ["config_path", "task", "dump_path", "level"]) -class PrecisionDebugger: - _instance = None - task_not_need_service = [Const.GRAD_PROBE] - - def __new__(cls, config_path=None, task=None, dump_path=None, - level=None, step=None, opt=None): - if not cls._instance: - cls._instance = super().__new__(cls) - cls._instance.initialized = False - cls._instance.config = None - cls.service = None - cls.first_start = False - return cls._instance +class PrecisionDebugger(BasePrecisionDebugger): - def __init__(self, config_path=None, task=None, dump_path=None, - level=None, step=None): + def __init__( + self, + config_path=None, + task=None, + dump_path=None, + level=None, + step=None + ): if self.initialized: return - self.initialized = True - set_register_backward_hook_functions() + super().__init__(config_path, task, dump_path, level, step) - if not config_path: - config_path = os.path.join(os.path.dirname(__file__), "../../config.json") - - config_params = ConfigParameters(config_path, task, dump_path, level) - self.check_input_params(config_params) - - common_config, task_config = parse_json_config(config_path) - common_config.task = task if task else common_config.task - self.task = common_config.task if self.task == Const.GRAD_PROBE: - self.gm = GradientMonitor(common_config, task_config) + self.gm = GradientMonitor(self.common_config, self.task_config) return - common_config.step = get_real_step_or_rank( - step, Const.STEP) if step is not None else common_config.step - common_config.level = level if level else common_config.level - common_config.dump_path = dump_path if dump_path else common_config.dump_path - self.config = DebuggerConfig(common_config, task_config) + self.common_config.level = level if level else self.common_config.level + self.common_config.dump_path = dump_path if dump_path else self.common_config.dump_path + self.config = DebuggerConfig(self.common_config, self.task_config) - if _msprobe_c: + if self._is_kernel_dump() and not self.task_config.is_regex_valid: + raise ValueError('Illegal regular expressions exist in the list.') + + setattr(inner.CellBackwardHook, '__call__', + wrap_backward_hook_call_func(getattr(inner.CellBackwardHook, '__call__'))) + + if self._is_kernel_dump() and _msprobe_c: + os.environ["MS_HOOK_ENABLE"] = "on" _msprobe_c._PrecisionDebugger(framework="MindSpore", config_path=config_path) self.config.execution_mode = self._get_execution_mode() if self._need_service(): - self.config.check_config_with_l2() - self.service = Service(self.config) + self.service = MindsporeService(self.config) Runtime.step_count = 0 Runtime.is_running = False + if enable_dynamic_kbyk_dump and self.config.level_ori == Const.LEVEL_L2: + _dump_set_dynamic() @staticmethod - def check_input_params(args): - if args.config_path is not None: - if not isinstance(args.config_path, str): - raise MsprobeException( - MsprobeException.INVALID_PARAM_ERROR, f"config_path must be a string") - file_checker = FileChecker( - file_path=args.config_path, path_type=FileCheckConst.FILE, file_type=FileCheckConst.JSON_SUFFIX) - file_checker.common_check() - - if args.task is not None and args.task not in Const.TASK_LIST: - raise MsprobeException( - MsprobeException.INVALID_PARAM_ERROR, f"task must be one of {Const.TASK_LIST}") - - if args.dump_path is not None: - if not isinstance(args.dump_path, str): - raise MsprobeException( - MsprobeException.INVALID_PARAM_ERROR, f"dump_path must be a string") - - if args.level is not None and args.level not in Const.LEVEL_LIST: - raise MsprobeException( - MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}") + def _get_task_config(task, json_config): + return parse_task_config(task, json_config) @staticmethod def _get_execution_mode(): @@ -137,9 +122,7 @@ class PrecisionDebugger: return MsConst.PYNATIVE_MODE @staticmethod - def _is_graph_dump(config): - if config.level != MsConst.KERNEL: - return False + def _is_graph_dump(config: DebuggerConfig): if not config.list: return True is_graph = any(item.startswith("name-regex") for item in config.list) @@ -147,66 +130,70 @@ class PrecisionDebugger: return is_graph @classmethod - def start(cls, model=None): - instance = cls._instance - if not instance: - raise Exception(MsgConst.NOT_CREATED_INSTANCE) - if _msprobe_c: - _msprobe_c._PrecisionDebugger().start() - if instance.task in PrecisionDebugger.task_not_need_service: + def start(cls, model=None, token_range=None): + instance = cls._get_instance() + if instance is None: return - - instance.config.execution_mode = cls._get_execution_mode() - if cls._need_service(): - if not instance.service: - instance.service = Service(instance.config) - instance.service.start(model) + if cls._is_kernel_dump(): + cls._start_kernel_dump() else: - if not instance.first_start: - api_register.api_set_ori_func() - handler = TaskHandlerFactory.create(instance.config) - handler.handle() - + check_token_range(token_range) + instance.config.execution_mode = cls._get_execution_mode() + if cls._need_service(): + with ThreadSafe(): + if not instance.service: + instance.service = MindsporeService(instance.config) + instance.config.check_model(model, token_range) + instance.service.start(model, token_range) + else: + if not instance.first_start: + get_api_register().restore_all_api() + handler = TaskHandlerFactory.create(instance.config, model) + handler.handle() + Runtime.is_running = True instance.first_start = True - Runtime.is_running = True - - @classmethod - def forward_backward_dump_end(cls): - instance = cls._instance - instance.stop() @classmethod def stop(cls): - instance = cls._instance - if not instance: - raise Exception(MsgConst.NOT_CREATED_INSTANCE) - if _msprobe_c: - _msprobe_c._PrecisionDebugger().stop() + instance = cls._get_instance() + if instance is None: + return + if instance.task == Const.GRAD_PROBE: instance.gm.stop() - if instance.task in PrecisionDebugger.task_not_need_service: - return if instance.service: - instance.service.stop() - Runtime.is_running = False + with ThreadSafe(): + instance.service.stop() + else: + Runtime.is_running = False + if enable_dynamic_kbyk_dump and instance.config.level_ori == Const.LEVEL_L2: + ms.runtime.synchronize() + _dump_stop() + if cls._is_kernel_dump() and _msprobe_c: + _msprobe_c._PrecisionDebugger().stop() @classmethod def step(cls): - instance = cls._instance - if not instance: - raise Exception(MsgConst.NOT_CREATED_INSTANCE) - if _msprobe_c: - _msprobe_c._PrecisionDebugger().step() - if instance.task in PrecisionDebugger.task_not_need_service: + instance = cls._get_instance() + if instance is None: return + if instance.service: - instance.service.step() + with ThreadSafe(): + instance.service.step() + if is_graph_mode_cell_dump_allowed(instance.config): + GraphModeCellDump.step() + if enable_dynamic_kbyk_dump and instance.config.level_ori == Const.LEVEL_L2: + _dump_step(1) + if cls._is_kernel_dump() and _msprobe_c: + _msprobe_c._PrecisionDebugger().step() + HOOKCell.cell_count = defaultdict(int) CellProcessor.reset_cell_stats() - Runtime.step_count += 1 @classmethod + @ThreadSafe.synchronized def monitor(cls, opt): instance = cls._instance if not instance: @@ -216,6 +203,7 @@ class PrecisionDebugger: instance.gm.monitor(opt) @classmethod + @ThreadSafe.synchronized def save(cls, variable, name, save_backward=True): instance = cls._instance if not instance: @@ -226,19 +214,50 @@ class PrecisionDebugger: check_save_param(variable, name, save_backward) except ValueError: return - - instance.config.execution_mode = cls._get_execution_mode() - if cls._need_service(): - if not instance.service: - instance.service = Service(instance.config) - instance.service.save(variable, name, save_backward) + if not instance.service: + instance.service = MindsporeService(instance.config) + instance.service.save(variable, name, save_backward) @classmethod def _need_service(cls): instance = cls._instance if not instance: raise Exception(MsgConst.NOT_CREATED_INSTANCE) + if instance.config.level_ori == Const.LEVEL_L2: + return not instance._is_graph_dump(instance.config) if instance.config.execution_mode != MsConst.PYNATIVE_MODE: return False else: - return instance.config.task != Const.FREE_BENCHMARK and not instance._is_graph_dump(instance.config) \ No newline at end of file + return instance.config.task != Const.FREE_BENCHMARK + + @classmethod + def _is_kernel_dump(cls): + instance = cls._instance + if not instance: + raise Exception(MsgConst.NOT_CREATED_INSTANCE) + return instance.config.level_ori == Const.LEVEL_L2 + + @classmethod + def _start_kernel_dump(cls): + instance = cls._get_instance() + is_graph_config = cls._is_graph_dump(instance.config) + instance.config.check_config_with_l2(is_graph_config) + if not is_graph_config: + if not instance.service: + instance.service = MindsporeService(instance.config) + instance.service.start() + else: + if _msprobe_c: + _msprobe_c._PrecisionDebugger().start() + if not instance.first_start: + get_api_register().restore_all_api() + handlers = TaskHandlerFactory.create(instance.config) + for handler in handlers: + handler.handle() + if enable_dynamic_kbyk_dump: + _set_init_iter(0) + if enable_dynamic_kbyk_dump: + is_valid_rank = (not instance.config.rank or Runtime.rank_id in instance.config.rank) + is_valid_step = (not instance.config.step or Runtime.step_count in instance.config.step) + if is_valid_rank and is_valid_step: + _dump_start() diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/cell_dump_process.py b/debug/accuracy_tools/msprobe/mindspore/dump/cell_dump_process.py new file mode 100644 index 0000000000000000000000000000000000000000..c0903ca86b40a9dd38984898475a23dbedf9cb51 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/dump/cell_dump_process.py @@ -0,0 +1,928 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import atexit +from multiprocessing import Pool +import os +import re +import time +from dataclasses import dataclass +from typing import List, Optional, Union, Any + +import numpy as np +import pandas as pd +import mindspore as ms +from mindspore import nn, ops + +from msprobe.core.common.const import Const as CoreConst +from msprobe.core.common.const import FileCheckConst +from msprobe.core.common.file_utils import ( + load_npy, save_json, remove_path, load_yaml, + create_directory, read_csv, write_df_to_csv, write_csv, move_file, move_directory) +from msprobe.mindspore.common.log import logger + +CONSTRUCT_FILE_NAME = "construct.json" +DEFAULT_RANK_DIR = "rank0" +KEY_LAYERS = "layers" +construct = {} +cell_list = [] +free_cells = {} +parent_cell_types = {} +KEY_SIDE_EFFECT = "side_effect_io" +KEY_TOPLAYER = "TopLayer" +KEY_FORWARD = CoreConst.FORWARD +KEY_BACKWARD = CoreConst.BACKWARD +KEY_INPUT = CoreConst.INPUT +KEY_OUTPUT = CoreConst.OUTPUT +KEY_DUMP_TENSOR_DATA = "dump_tensor_data_" +KEY_STATISTIC_CSV = "statistic.csv" +KEY_TD_FLAG = "td_flag" +td = ops.TensorDump() +if (ms.__version__ >= "2.5.0"): + td_in = ops.TensorDump("in") +else: + td_in = ops.TensorDump() +dump_gradient_op_existed = False +if hasattr(ops, 'DumpGradient'): + gd = ops.DumpGradient() + dump_gradient_op_existed = True +else: + logger.warning('The operator "DumpGradient" does not exist. Cell dump can not work in graph mode.') +graph_step_flag = True +try: + from mindspore._c_expression import _set_init_iter +except ImportError: + graph_step_flag = False +td.add_prim_attr(KEY_SIDE_EFFECT, False) +td_in.add_prim_attr(KEY_SIDE_EFFECT, False) +td.add_prim_attr(KEY_TD_FLAG, True) +td_in.add_prim_attr(KEY_TD_FLAG, True) +dump_task = CoreConst.STATISTICS +np_ms_dtype_dict = { + "bool": ms.bool_, + "int8": ms.int8, + "byte": ms.byte, + "int16": ms.int16, + "short": ms.short, + "int32": ms.int32, + "intc": ms.intc, + "int64": ms.int64, + "intp": ms.intp, + "uint8": ms.uint8, + "ubyte": ms.ubyte, + "uint16": ms.uint16, + "ushort": ms.ushort, + "uint32": ms.uint32, + "uintc": ms.uintc, + "uint64": ms.uint64, + "uintp": ms.uintp, + "float16": ms.float16, + "half": ms.half, + "float32": ms.float32, + "single": ms.single, + "float64": ms.float64, + "double": ms.double, + "bfloat16": ms.bfloat16, + "complex64": ms.complex64, + "complex128": ms.complex128 +} + + +@dataclass +class CellDumpConfig: + net: object + dump_path: str + data_mode: str + task: str = CoreConst.STATISTICS + summary_mode: Optional[Union[List[str], str]] = None + step: int = 0 + + +def gen_file_path(dump_path, cell_prefix, suffix, io_type, index): + step_path = os.path.join(dump_path, "{step}") + rank_path = os.path.join(step_path, "{rank}") + data_path = os.path.join(rank_path, CoreConst.DUMP_TENSOR_DATA) + file_name = "" + if dump_task == CoreConst.TENSOR: + file_name = cell_prefix + CoreConst.SEP + suffix + CoreConst.SEP + io_type + CoreConst.SEP + str(index) + if dump_task == CoreConst.STATISTICS: + file_name = cell_prefix + CoreConst.HYPHEN + suffix + CoreConst.HYPHEN + io_type + CoreConst.HYPHEN + str(index) + return os.path.join(data_path, file_name) + + +def need_tensordump_in(cell_obj, attr, index): + if not hasattr(cell_obj, attr): + return False + attr_values = getattr(cell_obj, attr) + if index >= len(attr_values): + return False + return attr_values[index] == "in" + + +def cell_construct_wrapper(func, self): + def new_construct(self, *args, **kwargs): + new_args = [] + out_list = [] + + index = 0 + item = None + backward_or_all = self.data_mode in ["backward", "all"] + forward_or_all = self.data_mode in ["forward", "all"] + # The inputs of the cell. + for index, item in enumerate(args): + if backward_or_all and ops.is_tensor(item): + if need_tensordump_in(self, 'input_dump_mode', index): + item = gd(gen_file_path(self.dump_path, self.cell_prefix, KEY_BACKWARD, KEY_OUTPUT, index), + item, "out") + else: + item = gd(gen_file_path(self.dump_path, self.cell_prefix, KEY_BACKWARD, KEY_OUTPUT, index), + item, "in") + if forward_or_all and ops.is_tensor(item): + if need_tensordump_in(self, 'input_dump_mode', index): + temp = td_in( + gen_file_path(self.dump_path, self.cell_prefix, KEY_FORWARD, KEY_INPUT, index), + item + ) + else: + temp = td( + gen_file_path(self.dump_path, self.cell_prefix, KEY_FORWARD, KEY_INPUT, index), + item + ) + item = ops.depend(item, temp) + new_args.append(item) + + out = func(*new_args, **kwargs) + + # The outputs of the cell. + if isinstance(out, tuple): + for index, item in enumerate(out): + if backward_or_all and ops.is_tensor(item): + if need_tensordump_in(self, 'output_dump_mode', index): + item = gd(gen_file_path(self.dump_path, self.cell_prefix, KEY_BACKWARD, KEY_INPUT, index), + item, "out") + else: + item = gd(gen_file_path(self.dump_path, self.cell_prefix, KEY_BACKWARD, KEY_INPUT, index), + item, "in") + if forward_or_all and ops.is_tensor(item): + if need_tensordump_in(self, 'output_dump_mode', index): + temp = td_in( + gen_file_path(self.dump_path, self.cell_prefix, KEY_FORWARD, KEY_OUTPUT, index), + item + ) + else: + temp = td( + gen_file_path(self.dump_path, self.cell_prefix, KEY_FORWARD, KEY_OUTPUT, index), + item + ) + item = ops.depend(item, temp) + out_list.append(item) + elif forward_or_all and not ops.is_tensor(item): + out_list.append(item) + out_list = tuple(out_list) + return out_list + else: + if backward_or_all: + if need_tensordump_in(self, 'output_dump_mode', index): + out = gd(gen_file_path(self.dump_path, self.cell_prefix, KEY_BACKWARD, KEY_INPUT, 0), + out, "out") + else: + out = gd(gen_file_path(self.dump_path, self.cell_prefix, KEY_BACKWARD, KEY_INPUT, 0), + out, "in") + if forward_or_all and ops.is_tensor(out): + if need_tensordump_in(self, 'output_dump_mode', index): + temp = td_in( + gen_file_path(self.dump_path, self.cell_prefix, KEY_FORWARD, KEY_OUTPUT, 0), + out + ) + else: + temp = td( + gen_file_path(self.dump_path, self.cell_prefix, KEY_FORWARD, KEY_OUTPUT, 0), + out + ) + out = ops.depend(out, temp) + return out + + return new_construct.__get__(self, type(self)) + + +# 获取目录下所有文件名并根据TensorDump落盘自增id从小到大排序 +def sort_filenames(path): + filenames = os.listdir(path) + id_pattern = re.compile(rf'{CoreConst.REPLACEMENT_CHARACTER}(\d+){CoreConst.NUMPY_SUFFIX}$') + # 只保留能提取到数字id的文件,避免数组越界 + valid_files = [] + for filename in filenames: + match = id_pattern.findall(filename) + if match and match[0].isdigit(): + valid_files.append(filename) + else: + logger.warning(f"File {filename} does not match the expected pattern and will be ignored.") + valid_files.sort(key=lambda x: int(id_pattern.findall(x)[0])) + return valid_files + + +def rename_filename(path="", data_df=None): + if dump_task == CoreConst.TENSOR: + filenames = sort_filenames(path) + if dump_task == CoreConst.STATISTICS: + filenames = data_df[CoreConst.OP_NAME].tolist() + + filename_dict = {} + for index, filename in enumerate(filenames): + if dump_task == CoreConst.TENSOR: + name_field = filename.rsplit(CoreConst.REPLACEMENT_CHARACTER, 1)[0] + if dump_task == CoreConst.STATISTICS: + name_field = filename + + if name_field in filename_dict: + filename_dict[name_field] += 1 + else: + filename_dict[name_field] = 0 + + cell_index = filename_dict[name_field] + + # 修改文件名,增加重复调用Cell的序号 + if CoreConst.FORWARD_PATTERN in filename: + # Format: Cell.{cell_name}.{class_name}.{forward/backward}.{number}.{input/output}.{index}_{dtype}_{id}.npy + new_file_name = filename.replace(CoreConst.FORWARD_PATTERN, + CoreConst.FORWARD_PATTERN + str(cell_index) + CoreConst.SEP) + if CoreConst.BACKWARD_PATTERN in filename: + new_file_name = filename.replace(CoreConst.BACKWARD_PATTERN, + CoreConst.BACKWARD_PATTERN + str(cell_index) + CoreConst.SEP) + if dump_task == CoreConst.TENSOR: + move_file(os.path.join(path, filename), os.path.join(path, new_file_name)) + if dump_task == CoreConst.STATISTICS: + data_df.loc[index, CoreConst.OP_NAME] = new_file_name + logger.info("==========The rename_filename phase is Finished!==========") + + +# Extract the field between the first "." and the third to last ".", i.e. {cell_name} +def get_cell_name(cell_str): + parts = cell_str.split(CoreConst.SEP) + if len(parts) < 4: + return None + start_index = 1 + end_index = len(parts) - 3 + return CoreConst.SEP.join(parts[start_index:end_index]) + + +# Extract the field between the last "." and the second to last ".", i.e. {data_made} +def get_data_mode(cell_str): + last_dot_index = cell_str.rfind(CoreConst.SEP) + second_last_dot_index = cell_str.rfind(CoreConst.SEP, 0, last_dot_index) + data_mode = cell_str[second_last_dot_index + 1:last_dot_index] + return data_mode + + +# 判断二者之间是否存在父子关系 +def check_relation(cell_name, parent_cell_name): + layers_pattern = rf"{CoreConst.SEP}{KEY_LAYERS}{CoreConst.SEP}\d+$" + last_dot_index = cell_name.rfind(CoreConst.SEP) + if last_dot_index == -1: + return False + # 如果cell_name最后一个'.'之前的字段等于parent_cell_name,则判定存在父子关系 + sub_cell_name = cell_name[:last_dot_index] + if sub_cell_name == parent_cell_name: + return True + elif re.search(layers_pattern, cell_name): + # 如果cell_name以".layer.{layer_id}"结尾,且去掉该字段后等于parent_cell_name,则判定存在父子关系 + sub_cell_name = re.sub(layers_pattern, '', cell_name) + if sub_cell_name == parent_cell_name: + return True + return False + + +def get_parent_cell_name(child_cell_name): + parent_cell_name = '' + + last_dot_index = child_cell_name.rfind(CoreConst.SEP) + if last_dot_index == -1: + return parent_cell_name + + layers_pattern = rf"{CoreConst.SEP}{KEY_LAYERS}{CoreConst.SEP}\d+$" + if re.search(layers_pattern, child_cell_name): + parent_cell_name = re.sub(layers_pattern, '', child_cell_name) + else: + parent_cell_name = child_cell_name[:last_dot_index] + + return parent_cell_name + + +def get_construct(cell_list_input): + global free_cells, parent_cell_types + for cell in cell_list_input: + cell_name = get_cell_name(cell) + cell_data_mode = get_data_mode(cell) + found_flag = False + for parent_cell in cell_list_input: + parent_cell_name = get_cell_name(parent_cell) + parent_data_mode = get_data_mode(parent_cell) + has_relation = check_relation(cell_name, parent_cell_name) + if has_relation and parent_data_mode == cell_data_mode: + construct.update({cell: parent_cell}) + found_flag = True + break + if not found_flag: + cell_name_with_mode = f'{cell_name}{CoreConst.SEP}{cell_data_mode}' + if cell_name_with_mode in free_cells: + construct.update({cell: free_cells.get(cell_name_with_mode)}) + continue + + parent_cell = None + parent_cell_name = get_parent_cell_name(cell_name) + if parent_cell_name and cell_name in parent_cell_types: + parent_cell = CoreConst.SEP.join([CoreConst.CELL, parent_cell_name, parent_cell_types.get(cell_name)]) + second_last_dot_index = cell.rfind(CoreConst.SEP, 0, cell.rfind(CoreConst.SEP)) + parent_cell = f'{parent_cell}{cell[second_last_dot_index:]}' + free_cells[cell_name_with_mode] = parent_cell + + construct.update({cell: parent_cell}) + + +def generate_construct(path): + global construct + if dump_task == CoreConst.TENSOR: + # filename格式:Cell.clip_grad_norm.ClipGradNorm.forward.0.output.1_int32_0.npy + filenames = sort_filenames(path) + point_position = 3 + if dump_task == CoreConst.STATISTICS: + df = read_csv(path) + # filename格式:Cell.clip_grad_norm.ClipGradNorm.forward.0.output.1 + filenames = df[CoreConst.OP_NAME].tolist() + point_position = 2 + + # 提取文件名中Cell.{cell_name}.{class_name}.{data_mode}.{重复调用此cell的序号}字段,并存入cell_list + for filename in filenames: + mid_field = filename.rsplit(CoreConst.SEP, point_position)[0] + if KEY_INPUT in filename: + if mid_field in cell_list: + cell_list.remove(mid_field) + cell_list.append(mid_field) + else: + if mid_field not in cell_list: + index = filenames.index(filename) + output_field = mid_field + KEY_OUTPUT + find_flag = False + for filename_other in cell_list[index + 1:]: + if output_field in filename_other: + find_flag = True + if find_flag is False: + cell_list.append(mid_field) + + get_construct(cell_list) + + # 生成JSON文件 + rank_dir = os.path.dirname(path) + json_path = os.path.join(rank_dir, CONSTRUCT_FILE_NAME) + save_json(json_path, construct, indent=1) + + # 清空'construct'继续处理下一个路径下的数据 + construct = {} + logger.info(f"Construct data saved to {json_path}") + + +def process_file(file_path): + try: + # 读取.npy文件内容 + npy_content = load_npy(file_path) + logger.debug(f"Loaded {file_path}: shape is {npy_content.shape}, dtype is {npy_content.dtype}") + + # 文件名举例:Cell.network._backbone.loss.CrossEntropyLoss.forward.0.input.0_float32_165.npy + parts = os.path.basename(file_path).split(CoreConst.SEP) + data_dtype = "" + # 获取0_float32_165或者0_in_float32_165中的float32 + data_dtype_list = parts[-2].split('_') + if len(data_dtype_list) > 1: + data_dtype = data_dtype_list[-2] + # op_name是Cell.network._backbone.loss.CrossEntropyLoss.forward.0 + op_name = CoreConst.SEP.join(parts[:-3]) + ms_dtype = np_ms_dtype_dict.get(data_dtype) + if ms_dtype is None: + logger.warning(f"Get dtype None from file {file_path}") + + # 修改落盘文件名字,去掉TensorDump自带的数据类型和自增id字段 + data_file_name = os.path.basename(file_path) + data_file_dir = os.path.dirname(file_path) + parts = data_file_name.split(CoreConst.SEP) + if len(parts) >= 2: + param_index = parts[-2].split(CoreConst.REPLACEMENT_CHARACTER)[0] + pre_parts = CoreConst.SEP.join(parts[:-2]) + new_file_name = pre_parts + CoreConst.SEP + param_index + CoreConst.NUMPY_SUFFIX + move_file(os.path.join(data_file_dir, data_file_name), os.path.join(data_file_dir, new_file_name)) + logger.debug(f"{data_file_name} is renamed to {new_file_name}") + else: + logger.warning(f"Failed to rename {data_file_name}.") + new_file_name = data_file_name + + tensor_json = { + CoreConst.TYPE: 'mindspore.Tensor', + CoreConst.DTYPE: str(ms_dtype), + CoreConst.SHAPE: list(npy_content.shape), + CoreConst.MAX: npy_content.max().item(), + CoreConst.MIN: npy_content.min().item(), + CoreConst.MEAN: npy_content.mean().item(), + CoreConst.NORM: np.linalg.norm(npy_content).item(), + CoreConst.DATA_NAME: new_file_name + } + + # 根据文件名的最后一个部分(输入或输出)确定是添加到input_args还是output + if parts[-3] == KEY_INPUT: + return op_name, CoreConst.INPUT_ARGS, tensor_json + elif parts[-3] == KEY_OUTPUT: + return op_name, KEY_OUTPUT, tensor_json + else: + return None, None, None + + except Exception as e: + logger.error(f"Error reading {file_path}: {e}") + return None, None, None + + +def custom_sort(item, key_to_index): + key = item[0] + return key_to_index.get(key, float('inf')) + + +def convert_special_values(value: Any) -> Union[bool, float, None, str, Any]: + if isinstance(value, str): + if value.lower() == "true": + return True + elif value.lower() == "false": + return False + try: + return float(value) + except ValueError: + return value + elif pd.isna(value): + return None + return value + + +def process_csv(path): + data_info = [] + df = read_csv(path) + df = df.sort_values(by='Op Name', ascending=True) + columns = df.columns + colume_to_json_key = { + 'Max Value': CoreConst.MAX, + 'Min Value': CoreConst.MIN, + 'Avg Value': CoreConst.MEAN, + 'L2Norm Value': CoreConst.NORM + } + for _, row in df.iterrows(): + # op_name_value格式:Cell.network._backbone.loss.CrossEntropyLoss.forward.0.input.0 + op_name_value = row['Op Name'] + op_name = op_name_value.rsplit(CoreConst.SEP, 2)[0] + + # 获取input/output字段 + io_key = op_name_value.split(CoreConst.SEP)[-2] + + # shape读取出来为字符串类型转为list。"(1,4096)"->[1,4096] + shape_num = re.findall(r'\d+', row['Shape']) + shape = [int(num) for num in shape_num] + + tensor_json = { + CoreConst.TYPE: 'mindspore.Tensor', + CoreConst.DTYPE: str(np_ms_dtype_dict.get(row['Data Type'])), + CoreConst.SHAPE: shape + } + for col_name, json_key in colume_to_json_key.items(): + if col_name in columns: + value = convert_special_values(row[col_name]) + tensor_json[json_key] = value + + if io_key == KEY_INPUT: + data_info.append([op_name, CoreConst.INPUT_ARGS, tensor_json]) + elif io_key == KEY_OUTPUT: + data_info.append([op_name, KEY_OUTPUT, tensor_json]) + else: + data_info.append([None, None, None]) + return data_info + + +def generate_dump_info(path): + if not os.path.exists(path): + logger.error("The provided path does not exist.") + return + + if dump_task == CoreConst.TENSOR: + dump_data = {"task": "tensor", "level": "L0", "dump_data_dir": path, "data": {}} + with Pool(processes=10) as pool: + file_paths = [] + for file in os.listdir(path): + if file.endswith(FileCheckConst.NUMPY_SUFFIX): + file_paths.append((os.path.join(path, file),)) + file_paths.sort() + results = pool.starmap(process_file, file_paths) + if dump_task == CoreConst.STATISTICS: + dump_data = {"task": "statistics", "level": "L0", "framework": "mindspore", "dump_data_dir": None, "data": {}} + results = process_csv(path) + + # 收集结果 + for op_name, key, tensor_json in results: + if op_name: + if op_name not in dump_data.get(CoreConst.DATA, {}): + dump_data.get(CoreConst.DATA, {})[op_name] = {CoreConst.INPUT_ARGS: [], + CoreConst.INPUT_KWARGS: {}, + KEY_OUTPUT: []} + if key not in dump_data.get(CoreConst.DATA, {}).get(op_name, {}): + dump_data.get(CoreConst.DATA, {}).get(op_name, {})[key] = [] + dump_data.get(CoreConst.DATA, {}).get(op_name, {}).get(key, []).append(tensor_json) + + # 根据cell_list排序 + data_dict = dump_data.get(CoreConst.DATA, {}) + key_to_index = {key: index for index, key in enumerate(cell_list)} + sorted_data_dict = dict(sorted(data_dict.items(), key=lambda item: custom_sort(item, key_to_index))) + dump_data[CoreConst.DATA] = sorted_data_dict + + # 将数据写入dump.json + json_path = os.path.join(os.path.dirname(path), 'dump.json') + save_json(json_path, dump_data, indent=1) + + logger.info(f"Dump data saved to {json_path}") + + +def generate_stack_info(path): + if not os.path.exists(path): + logger.error("The provided path does not exist.") + return + + stack_data = {} + for cell_name in cell_list: + stack_data.update({cell_name: []}) + + # 将数据写入stack.json + json_path = os.path.join(os.path.dirname(path), 'stack.json') + save_json(json_path, stack_data, indent=1) + + # 删除csv文件 + if dump_task == CoreConst.STATISTICS: + remove_path(path) + + logger.info(f"Stack data saved to {json_path}") + + +def is_download_finished(directory, interval=3): + """ + 判断指定目录在一段时间后是否有数据被下载完成 + :param directory: 指定目录的路径 + :param interval: 检查的时间间隔(秒),默认为 3 秒 + :return: 如有数据被下载完成返回 True,否则返回 False + """ + # 检查目录是否存在 + if not os.path.exists(directory): + logger.warning(f"The specified directory {directory} does not exist.") + return False + initial_modification_time = os.path.getmtime(directory) + time.sleep(interval) + current_modification_time = os.path.getmtime(directory) + # 比较初始和当前修改时间 + if current_modification_time > initial_modification_time: + return False + else: + return True + + +def process(dump_path): + if not os.path.exists(dump_path): + logger.warning('No grap cell data is dumped.') + create_directory(dump_path) + return + + rank_id = os.environ.get('RANK_ID') + rank_dir = DEFAULT_RANK_DIR + if rank_id is not None: + rank_dir = CoreConst.RANK + str(rank_id) + + step_dir_list = os.listdir(dump_path) + for step_dir in step_dir_list: + step_path = os.path.join(dump_path, step_dir) + rank_path = os.path.join(step_path, rank_dir) + npy_path = os.path.join(rank_path, CoreConst.DUMP_TENSOR_DATA) + while True: + is_finished = is_download_finished(npy_path) + if not is_finished: + logger.info("There is data being downloaded in the specified directory, continue checking...") + else: + logger.info("There is no data being downloaded in the specified directory, Stop checking.") + break + logger.info("==========Start processing data that has already been stored on the disk!==========") + rename_filename(path=npy_path) + generate_construct(npy_path) + generate_dump_info(npy_path) + generate_stack_info(npy_path) + # 单卡场景,rank目录名称为rank + if rank_id is None: + new_rank_path = os.path.join(step_path, CoreConst.RANK) + try: + move_directory(rank_path, new_rank_path) + logger.info(f"Directory was successfully renamed to: {new_rank_path}") + except Exception as e: + logger.warning(f"Failed to renamed to {new_rank_path}: {e}") + logger.info("==========JSON file generation completed!==========") + + +# 删除csv文件中每行数据最后面的逗号 +def remove_trailing_commas(filename): + csv_data = read_csv(filename, as_pd=False) + for i in range(1, len(csv_data)): + if csv_data[i] and csv_data[i][-1] == "": + csv_data[i].pop() + write_csv(csv_data, filename, mode="w") + + +# 将相同step的csv文件合并,并加工后存入相应step目录下 +def merge_file(dump_path, rank_dir, file_dict): + rank_dir = rank_dir.replace(CoreConst.REPLACEMENT_CHARACTER, '') + for step_dir, file_list in file_dict.items(): + step_dir = CoreConst.STEP + step_dir + rank_path = os.path.join(dump_path, step_dir, rank_dir) + create_directory(rank_path) + output_file = os.path.join(rank_path, KEY_STATISTIC_CSV) + + all_dfs = [] + try: + for file_path in file_list: + remove_trailing_commas(file_path) + df = read_csv(file_path) + all_dfs.append(df) + + # 合并所有 DataFrame + merged_df = pd.concat(all_dfs, ignore_index=True) + # 按 Timestamp 字段升序排序 + merged_df = merged_df.sort_values(by='Timestamp', ascending=True) + # 删除Slot字段为0的数据 + merged_df = merged_df[merged_df['Slot'] != 0] + # 重置索引,从0开始排序 + merged_df.reset_index(drop=True, inplace=True) + except FileNotFoundError as e: + logger.error(f"File not found: {e.filename}") + + try: + # 获取op_name并加工为Cell.network._backbone.LlamaForCausalLM.forward.input.0格式 + merged_df[CoreConst.OP_NAME] = merged_df[CoreConst.OP_NAME].str.split(KEY_DUMP_TENSOR_DATA, expand=True)[1] + merged_df[CoreConst.OP_NAME] = ( + merged_df[CoreConst.OP_NAME].str.split(CoreConst.PIPE_SEPARATOR, expand=True)[0] + ) + merged_df[CoreConst.OP_NAME] = ( + merged_df[CoreConst.OP_NAME].str.replace(CoreConst.HYPHEN, CoreConst.SEP, regex=False) + ) + # 重命名op_name,改为Cell.{cell_name}.{class_name}.{forward/backward}.{number}.{input/output}.{index}格式 + rename_filename(data_df=merged_df) + + # 将合并并排序后的 DataFrame 保存到相应step目录下 + write_df_to_csv(merged_df, output_file) + except KeyError: + logger.error("The value of the ‘Op Name’ field does not contain KEY_DUMP_TENSOR_DATA," + " and the index is out of bounds.") + + +def process_statistics(dump_path): + if not os.path.exists(dump_path): + logger.warning('No grap cell data is dumped.') + create_directory(dump_path) + return + + rank_id = os.environ.get('RANK_ID') + rank_dir_kbk = "rank_0" + if rank_id is not None: + rank_dir_kbk = CoreConst.RANK + CoreConst.REPLACEMENT_CHARACTER + str(rank_id) + rank_path_kbk = os.path.join(dump_path, rank_dir_kbk) + + # 按相同step数将csv文件名分组存入file_dict + file_dict = {} + depth_limit = 4 + base_depth = rank_path_kbk.count(os.sep) + for root, _, files in os.walk(rank_path_kbk): + current_depth = root.count(os.sep) - base_depth + if current_depth > depth_limit: + continue + for file in files: + if file == KEY_STATISTIC_CSV: + file_path = os.path.join(root, file) + step_dir = os.path.basename(os.path.dirname(file_path)) + if step_dir in file_dict: + file_dict[step_dir].append(file_path) + else: + file_dict[step_dir] = [file_path] + + # 将相同step的csv文件合并,并加工后存入相应step目录下 + merge_file(dump_path, rank_dir_kbk, file_dict) + + rank_dir = rank_dir_kbk.replace(CoreConst.REPLACEMENT_CHARACTER, '') + dir_list = os.listdir(dump_path) + step_dir_list = [d for d in dir_list if d.startswith(CoreConst.STEP)] + for step_dir in step_dir_list: + step_path = os.path.join(dump_path, step_dir) + rank_path = os.path.join(step_path, rank_dir) + csv_path = os.path.join(rank_path, KEY_STATISTIC_CSV) + logger.info("==========Start processing data csv!==========") + generate_construct(csv_path) + generate_dump_info(csv_path) + generate_stack_info(csv_path) + remove_path(rank_path_kbk) + # 单卡场景,rank目录名称为rank + if rank_id is None: + new_rank_path = os.path.join(step_path, CoreConst.RANK) + try: + move_directory(rank_path, new_rank_path) + logger.info(f"Directory was successfully renamed to: {new_rank_path}") + except Exception as e: + logger.warning(f"Failed to renamed to {new_rank_path}: {e}") + logger.info("==========JSON file generation completed!==========") + + +def get_yaml_keys(yaml_data): + keys = [] + for key, _ in yaml_data.items(): + keys.append(key) + return keys + + +def get_tensordump_mode(input_str): + left_index = input_str.find('(') + right_index = input_str.find(')') + + # 提取括号内的字符串 + if left_index != -1 and right_index != -1: + inner_str = input_str[left_index + 1:right_index] + # 分割字符串得到元素列表 + elements = inner_str.split(',') + if len(elements) >= 2: + # 去除元素前后的空格 + first_element = elements[0].strip() + second_element = elements[1].strip() + return first_element, second_element + return None, None + + +def str_to_list(input_str): + # 去除首尾的方括号 + input_str = input_str.strip('[]') + # 按逗号分割并去除元素两端的空格 + return [item.strip() for item in input_str.split(',')] + + +def set_tensordump_mode(cell, input_str): + first_str, second_str = get_tensordump_mode(input_str) + inputs_mode = [] + outputs_mode = [] + if first_str and second_str: + inputs_mode = str_to_list(first_str) + outputs_mode = str_to_list(second_str) + if inputs_mode and outputs_mode: + cell.input_dump_mode = inputs_mode + cell.output_dump_mode = outputs_mode + + +def create_kbyk_json(dump_path, summary_mode, step): + if step: + step_str = "" + for s in step: + step_str += (str(s) + '|') + iteration = step_str[:-1] + else: + iteration = "all" + + if summary_mode == "statistics": + statistic_category = ["max", "min", "avg", "l2norm"] + elif "mean" in summary_mode: + mean_index = summary_mode.index("mean") + summary_mode[mean_index] = "avg" + statistic_category = summary_mode + else: + statistic_category = summary_mode + + config_json = { + "common_dump_settings": { + "op_debug_mode": 0, + "dump_mode": 1, + "path": dump_path, + "net_name": "Net", + "iteration": iteration, + "saved_data": "statistic", + "input_output": 0, + "kernels": ["TensorDump"], + "support_device": [0, 1, 2, 3, 4, 5, 6, 7], + "statistic_category": statistic_category + }, + "e2e_dump_settings": { + "enable": False, + "trans_flag": True, + "stat_calc_mode": "device" + } + } + + create_directory(dump_path) + rank_id = os.environ.get('RANK_ID') + if rank_id is None: + rank_id = 0 + config_json_path = os.path.join(dump_path, str(rank_id) + "kernel_kbyk_dump.json") + save_json(config_json_path, config_json, indent=4) + logger.info(config_json_path + " has been created.") + return config_json_path + + +def start(config: CellDumpConfig): + global dump_task, parent_cell_types + dump_task = config.task + net = config.net + dump_path = config.dump_path + data_mode = config.data_mode + summary_mode = config.summary_mode + step = config.step + if dump_task == CoreConst.STATISTICS: + # 使能KBK dump + config_json_path = create_kbyk_json(dump_path, summary_mode, step) + os.environ["MINDSPORE_DUMP_CONFIG"] = config_json_path + + # 执行过程中跳过TensorDump算子 + os.environ["MS_KERNEL_LAUNCH_SKIP"] = "TensorDump" + + # 初始化静态图KBK dump的step数,从0开始 + if not graph_step_flag: + raise Exception( + "Importing _set_init_iter failed, " + "please use the latest version package of MindSpore." + ) + _set_init_iter(0) + remove_path(config_json_path) + + if not dump_gradient_op_existed or net is None: + return + + if isinstance(net, nn.Cell): + net = (('', net, None),) + + td_config_path = "" + try: + import mindformers + mindformers_file = mindformers.__file__ + mindformers_dir = os.path.dirname(mindformers_file) + td_config_path = os.path.join(mindformers_dir, "configuration", "layer_mapping.yaml") + if not os.path.exists(td_config_path): + td_config_path = "" + logger.warning("The configuration file in mindformers was not loaded, the default mode will be used.") + except ImportError: + logger.warning("The mindFormers failed to load, the default mode will be used.") + + if td_config_path == "": + yaml_data = {} + else: + yaml_data = load_yaml(td_config_path) + first_layer_key = get_yaml_keys(yaml_data) + + black_list = ["grad_reducer", ""] + + for name_and_model in net: + parent_cell_types[name_and_model[0]] = name_and_model[2].__class__.__name__ + for name, cell in name_and_model[1].cells_and_names(name_prefix=name_and_model[0]): + class_name = cell.__class__.__name__ + # 跳过黑名单cell + if name in black_list: + logger.info(f"Cell {name}.{class_name} is skipped!") + continue + # 跳过框架内部的cell + if class_name.startswith(CoreConst.REPLACEMENT_CHARACTER): + logger.info(f"Cell {name}.{class_name} is skipped!") + continue + else: + # Format: Cell.{cell_name}.{class_name} + cell.cell_prefix = CoreConst.SEP.join([CoreConst.CELL, name, cell.__class__.__name__]) + if dump_task == CoreConst.STATISTICS: + cell.cell_prefix = cell.cell_prefix.replace(CoreConst.SEP, CoreConst.HYPHEN) + + # 根据yaml配置文件设置cell的TensorDump模式 + if class_name in first_layer_key: + layer_data = yaml_data.get(class_name) + if layer_data: + for child_name, child_cell in cell.cells_and_names(): + if child_name in layer_data: + set_tensordump_mode(child_cell, layer_data[child_name]) + top_layer_data = yaml_data.get(KEY_TOPLAYER) + if top_layer_data and name in top_layer_data: + set_tensordump_mode(cell, top_layer_data[name]) + + # 替换construct函数 + cell.construct = cell_construct_wrapper(cell.construct, cell) + logger.info(f"Cell {name}: construct function is wrapped!") + cell.dump_path = dump_path + cell.data_mode = data_mode + + logger.info("==========The cell_dump_process_start phase is Finished!==========") + if dump_task == CoreConst.TENSOR: + atexit.register(process, dump_path=dump_path) + if dump_task == CoreConst.STATISTICS: + atexit.register(process_statistics, dump_path=dump_path) diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/cell_dump_with_insert_gradient.py b/debug/accuracy_tools/msprobe/mindspore/dump/cell_dump_with_insert_gradient.py new file mode 100644 index 0000000000000000000000000000000000000000..bf43ee92d4980122a7ebe5d2995642f996c879fd --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/dump/cell_dump_with_insert_gradient.py @@ -0,0 +1,872 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import atexit +from multiprocessing import Pool +import os +import re +import time + +import numpy as np +import pandas as pd +import mindspore as ms +from mindspore import nn, ops + +from msprobe.core.common.const import Const as CoreConst +from msprobe.core.common.const import FileCheckConst +from msprobe.core.common.file_utils import ( + load_npy, save_json, remove_path, load_yaml, + create_directory, read_csv, write_df_to_csv, write_csv, move_file, move_directory) +from msprobe.mindspore.common.log import logger +from msprobe.mindspore.dump.cell_dump_process import CellDumpConfig + + +CONSTRUCT_FILE_NAME = "construct.json" +DEFAULT_RANK_DIR = "rank0" +KEY_LAYERS = "layers" +construct = {} +cell_list = [] +KEY_SIDE_EFFECT = "side_effect_io" +KEY_TOPLAYER = "TopLayer" +KEY_FORWARD = CoreConst.FORWARD +KEY_BACKWARD = CoreConst.BACKWARD +KEY_INPUT = CoreConst.INPUT +KEY_OUTPUT = CoreConst.OUTPUT +KEY_DUMP_TENSOR_DATA = "dump_tensor_data_" +KEY_STATISTIC_CSV = "statistic.csv" +KEY_TD_FLAG = "td_flag" +td = ops.TensorDump() +if (ms.__version__ >= "2.5.0"): + td_in = ops.TensorDump("in") +else: + td_in = ops.TensorDump() +graph_step_flag = True +try: + from mindspore._c_expression import _set_init_iter +except ImportError: + graph_step_flag = False +td.add_prim_attr(KEY_SIDE_EFFECT, False) +td_in.add_prim_attr(KEY_SIDE_EFFECT, False) +td.add_prim_attr(KEY_TD_FLAG, True) +td_in.add_prim_attr(KEY_TD_FLAG, True) +dump_task = CoreConst.STATISTICS +np_ms_dtype_dict = { + "bool": ms.bool_, + "int8": ms.int8, + "byte": ms.byte, + "int16": ms.int16, + "short": ms.short, + "int32": ms.int32, + "intc": ms.intc, + "int64": ms.int64, + "intp": ms.intp, + "uint8": ms.uint8, + "ubyte": ms.ubyte, + "uint16": ms.uint16, + "ushort": ms.ushort, + "uint32": ms.uint32, + "uintc": ms.uintc, + "uint64": ms.uint64, + "uintp": ms.uintp, + "float16": ms.float16, + "half": ms.half, + "float32": ms.float32, + "single": ms.single, + "float64": ms.float64, + "double": ms.double, + "bfloat16": ms.bfloat16, + "complex64": ms.complex64, + "complex128": ms.complex128 +} + + +def gen_file_path(dump_path, cell_prefix, suffix, io_type, index): + data_path = os.path.join(dump_path, '{step}', '{rank}', CoreConst.DUMP_TENSOR_DATA) + file_name = "" + if dump_task == CoreConst.TENSOR: + file_name = cell_prefix + CoreConst.SEP + suffix + CoreConst.SEP + io_type + CoreConst.SEP + str(index) + if dump_task == CoreConst.STATISTICS: + file_name = cell_prefix + CoreConst.HYPHEN + suffix + CoreConst.HYPHEN + io_type + CoreConst.HYPHEN + str(index) + return os.path.join(data_path, file_name) + + +def partial_func(func, dump_path, cell_prefix, index, io_type): + def newfunc(*args, **kwargs): + return func(dump_path, cell_prefix, index, io_type, *args, **kwargs) + return newfunc + + +def clip_gradient(dump_path, cell_prefix, index, io_type, dx): + if io_type == KEY_OUTPUT: + temp = td(gen_file_path(dump_path, cell_prefix, KEY_BACKWARD, io_type, index), dx) + dx = ops.depend(dx, temp) + elif io_type == KEY_INPUT: + temp = td_in(gen_file_path(dump_path, cell_prefix, KEY_BACKWARD, io_type, index), dx) + dx = ops.depend(dx, temp) + return dx + + +def need_tensordump_in(cell_obj, attr): + return hasattr(cell_obj, attr) and getattr(cell_obj, attr) == "in" + + +def cell_construct_wrapper(func, self): + def new_construct(self, *args, **kwargs): + new_args = [] + out_list = [] + + index = 0 + item = None + backward_or_all = self.data_mode in ["backward", "all"] + forward_or_all = self.data_mode in ["forward", "all"] + # The inputs of the cell. + for index, item in enumerate(args): + if backward_or_all and ops.is_tensor(item): + item = self.output_clips[index](item) + if forward_or_all and ops.is_tensor(item): + if need_tensordump_in(self, 'input_dump_mode'): + temp = td_in( + gen_file_path(self.dump_path, self.cell_prefix, KEY_FORWARD, KEY_INPUT, index), + item + ) + else: + temp = td( + gen_file_path(self.dump_path, self.cell_prefix, KEY_FORWARD, KEY_INPUT, index), + item + ) + item = ops.depend(item, temp) + new_args.append(item) + + out = func(*new_args, **kwargs) + + # The outputs of the cell. + if isinstance(out, tuple): + for index, item in enumerate(out): + if backward_or_all and ops.is_tensor(item): + item = self.input_clips[index](item) + if forward_or_all and ops.is_tensor(item): + if need_tensordump_in(self, 'output_dump_mode'): + temp = td_in( + gen_file_path(self.dump_path, self.cell_prefix, KEY_FORWARD, KEY_OUTPUT, index), + item + ) + else: + temp = td( + gen_file_path(self.dump_path, self.cell_prefix, KEY_FORWARD, KEY_OUTPUT, index), + item + ) + item = ops.depend(item, temp) + out_list.append(item) + elif forward_or_all and not ops.is_tensor(item): + out_list.append(item) + out_list = tuple(out_list) + return out_list + else: + if backward_or_all: + out = self.input_clips[0](out) + if forward_or_all and ops.is_tensor(out): + if need_tensordump_in(self, 'output_dump_mode'): + temp = td_in( + gen_file_path(self.dump_path, self.cell_prefix, KEY_FORWARD, KEY_OUTPUT, 0), + out + ) + else: + temp = td( + gen_file_path(self.dump_path, self.cell_prefix, KEY_FORWARD, KEY_OUTPUT, 0), + out + ) + out = ops.depend(out, temp) + return out + + return new_construct.__get__(self, type(self)) + + +# 获取目录下所有文件名并根据TensorDump落盘自增id从小到大排序 +def sort_filenames(path): + filenames = os.listdir(path) + id_pattern = re.compile(rf'{CoreConst.REPLACEMENT_CHARACTER}(\d+){CoreConst.NUMPY_SUFFIX}$') + # 只保留能提取到数字id的文件,避免数组越界 + valid_files = [] + for filename in filenames: + match = id_pattern.findall(filename) + if match and match[0].isdigit(): + valid_files.append(filename) + else: + logger.warning(f"File {filename} does not match the expected pattern and will be ignored.") + valid_files.sort(key=lambda x: int(id_pattern.findall(x)[0])) + return valid_files + + +def rename_filename(path="", data_df=None): + if dump_task == CoreConst.TENSOR: + filenames = sort_filenames(path) + if dump_task == CoreConst.STATISTICS: + filenames = data_df[CoreConst.OP_NAME].tolist() + + filename_dict = {} + for index, filename in enumerate(filenames): + if dump_task == CoreConst.TENSOR: + name_field = filename.rsplit(CoreConst.REPLACEMENT_CHARACTER, 1)[0] + if dump_task == CoreConst.STATISTICS: + name_field = filename + + if name_field in filename_dict: + filename_dict[name_field] += 1 + else: + filename_dict[name_field] = 0 + + cell_index = filename_dict[name_field] + + # 修改文件名,增加重复调用Cell的序号 + if CoreConst.FORWARD_PATTERN in filename: + # Format: Cell.{cell_name}.{class_name}.{forward/backward}.{number}.{input/output}.{index}_{dtype}_{id}.npy + new_file_name = filename.replace(CoreConst.FORWARD_PATTERN, + CoreConst.FORWARD_PATTERN + str(cell_index) + CoreConst.SEP) + if CoreConst.BACKWARD_PATTERN in filename: + new_file_name = filename.replace(CoreConst.BACKWARD_PATTERN, + CoreConst.BACKWARD_PATTERN + str(cell_index) + CoreConst.SEP) + if dump_task == CoreConst.TENSOR: + move_file(os.path.join(path, filename), os.path.join(path, new_file_name)) + if dump_task == CoreConst.STATISTICS: + data_df.loc[index, CoreConst.OP_NAME] = new_file_name + logger.info("==========The rename_filename phase is Finished!==========") + + +# Extract the field between the first "." and the third to last ".", i.e. {cell_name} +def get_cell_name(string): + parts = string.split(CoreConst.SEP) + if len(parts) < 4: + return None + start_index = 1 + end_index = len(parts) - 3 + return CoreConst.SEP.join(parts[start_index:end_index]) + + +# Extract the field between the last "." and the second to last ".", i.e. {data_made} +def get_data_mode(string): + last_dot_index = string.rfind(CoreConst.SEP) + second_last_dot_index = string.rfind(CoreConst.SEP, 0, last_dot_index) + data_mode = string[second_last_dot_index + 1:last_dot_index] + return data_mode + + +# 判断二者之间是否存在父子关系 +def check_relation(cell_name, parent_cell_name): + layers_pattern = rf"{CoreConst.SEP}{KEY_LAYERS}{CoreConst.SEP}\d+$" + last_dot_index = cell_name.rfind(CoreConst.SEP) + if last_dot_index == -1: + return False + # 如果cell_name最后一个'.'之前的字段等于parent_cell_name,则判定存在父子关系 + sub_cell_name = cell_name[:last_dot_index] + if sub_cell_name == parent_cell_name: + return True + elif re.search(layers_pattern, cell_name): + # 如果cell_name以".layer.{layer_id}"结尾,且去掉该字段后等于parent_cell_name,则判定存在父子关系 + sub_cell_name = re.sub(layers_pattern, '', cell_name) + if sub_cell_name == parent_cell_name: + return True + return False + + +def get_construct(cell_list_input): + for cell in cell_list_input: + cell_name = get_cell_name(cell) + cell_data_mode = get_data_mode(cell) + found_flag = False + for parent_cell in cell_list_input: + parent_cell_name = get_cell_name(parent_cell) + parent_data_mode = get_data_mode(parent_cell) + has_relation = check_relation(cell_name, parent_cell_name) + if has_relation and parent_data_mode == cell_data_mode: + construct.update({cell: parent_cell}) + found_flag = True + break + if not found_flag: + construct.update({cell: None}) + + +def generate_construct(path): + global construct + if dump_task == CoreConst.TENSOR: + # filename格式:Cell.clip_grad_norm.ClipGradNorm.forward.0.output.1_int32_0.npy + filenames = sort_filenames(path) + point_position = 3 + if dump_task == CoreConst.STATISTICS: + df = read_csv(path) + # filename格式:Cell.clip_grad_norm.ClipGradNorm.forward.0.output.1 + filenames = df[CoreConst.OP_NAME].tolist() + point_position = 2 + + # 提取文件名中Cell.{cell_name}.{class_name}.{data_mode}.{重复调用此cell的序号}字段,并存入cell_list + for filename in filenames: + mid_field = filename.rsplit(CoreConst.SEP, point_position)[0] + if KEY_INPUT in filename: + if mid_field in cell_list: + cell_list.remove(mid_field) + cell_list.append(mid_field) + else: + if mid_field not in cell_list: + index = filenames.index(filename) + output_field = mid_field + KEY_OUTPUT + find_flag = False + for filename_other in cell_list[index + 1:]: + if output_field in filename_other: + find_flag = True + if find_flag is False: + cell_list.append(mid_field) + + get_construct(cell_list) + + # 生成JSON文件 + rank_dir = os.path.dirname(path) + json_path = os.path.join(rank_dir, CONSTRUCT_FILE_NAME) + save_json(json_path, construct, indent=1) + + # 清空'construct'继续处理下一个路径下的数据 + construct = {} + logger.info(f"Construct data saved to {json_path}") + + +def process_file(file_path): + try: + # 读取.npy文件内容 + npy_content = load_npy(file_path) + logger.debug(f"Loaded {file_path}: shape is {npy_content.shape}, dtype is {npy_content.dtype}") + + # 文件名举例:Cell.network._backbone.loss.CrossEntropyLoss.forward.0.input.0_float32_165.npy + parts = os.path.basename(file_path).split(CoreConst.SEP) + data_dtype = "" + # 获取0_float32_165或者0_in_float32_165中的float32 + data_dtype_list = parts[-2].split('_') + if len(data_dtype_list) > 1: + data_dtype = data_dtype_list[-2] + # op_name是Cell.network._backbone.loss.CrossEntropyLoss.forward.0 + op_name = CoreConst.SEP.join(parts[:-3]) + ms_dtype = np_ms_dtype_dict.get(data_dtype) + if ms_dtype is None: + logger.warning(f"Get dtype None from file {file_path}") + + # 修改落盘文件名字,去掉TensorDump自带的数据类型和自增id字段 + data_file_name = os.path.basename(file_path) + data_file_dir = os.path.dirname(file_path) + parts = data_file_name.split(CoreConst.SEP) + if len(parts) >= 2: + param_index = parts[-2].split(CoreConst.REPLACEMENT_CHARACTER)[0] + pre_parts = CoreConst.SEP.join(parts[:-2]) + new_file_name = pre_parts + CoreConst.SEP + param_index + CoreConst.NUMPY_SUFFIX + move_file(os.path.join(data_file_dir, data_file_name), os.path.join(data_file_dir, new_file_name)) + logger.debug(f"{data_file_name} is renamed to {new_file_name}") + else: + logger.warning(f"Failed to rename {data_file_name}.") + new_file_name = data_file_name + + tensor_json = { + CoreConst.TYPE: 'mindspore.Tensor', + CoreConst.DTYPE: str(ms_dtype), + CoreConst.SHAPE: list(npy_content.shape), + CoreConst.MAX: npy_content.max().item(), + CoreConst.MIN: npy_content.min().item(), + CoreConst.MEAN: npy_content.mean().item(), + CoreConst.NORM: np.linalg.norm(npy_content).item(), + CoreConst.DATA_NAME: new_file_name + } + + # 根据文件名的最后一个部分(输入或输出)确定是添加到input_args还是output + if parts[-3] == KEY_INPUT: + return op_name, CoreConst.INPUT_ARGS, tensor_json + elif parts[-3] == KEY_OUTPUT: + return op_name, KEY_OUTPUT, tensor_json + else: + return None, None, None + + except Exception as e: + logger.error(f"Error reading {file_path}: {e}") + return None, None, None + + +def custom_sort(item, key_to_index): + key = item[0] + return key_to_index.get(key, float('inf')) + + +def convert_special_values(value): + if isinstance(value, str): + if value.lower() == "true": + return True + elif value.lower() == "false": + return False + try: + return float(value) + except ValueError: + return value + elif pd.isna(value): + return None + return value + + +def process_csv(path): + data_info = [] + df = read_csv(path) + df = df.sort_values(by='Op Name', ascending=True) + columns = df.columns + colume_to_json_key = { + 'Max Value': CoreConst.MAX, + 'Min Value': CoreConst.MIN, + 'Avg Value': CoreConst.MEAN, + 'L2Norm Value': CoreConst.NORM + } + for _, row in df.iterrows(): + # op_name_value格式:Cell.network._backbone.loss.CrossEntropyLoss.forward.0.input.0 + op_name_value = row['Op Name'] + op_name = op_name_value.rsplit(CoreConst.SEP, 2)[0] + + # 获取input/output字段 + io_key = op_name_value.split(CoreConst.SEP)[-2] + + # shape读取出来为字符串类型转为list。"(1,4096)"->[1,4096] + shape_num = re.findall(r'\d+', row['Shape']) + shape = [int(num) for num in shape_num] + + tensor_json = { + CoreConst.TYPE: 'mindspore.Tensor', + CoreConst.DTYPE: str(np_ms_dtype_dict.get(row['Data Type'])), + CoreConst.SHAPE: shape + } + for col_name, json_key in colume_to_json_key.items(): + if col_name in columns: + value = convert_special_values(row[col_name]) + tensor_json[json_key] = value + + if io_key == KEY_INPUT: + data_info.append([op_name, CoreConst.INPUT_ARGS, tensor_json]) + elif io_key == KEY_OUTPUT: + data_info.append([op_name, KEY_OUTPUT, tensor_json]) + else: + data_info.append([None, None, None]) + return data_info + + +def generate_dump_info(path): + if not os.path.exists(path): + logger.error("The provided path does not exist.") + return + + if dump_task == CoreConst.TENSOR: + dump_data = {"task": "tensor", "level": "L0", "dump_data_dir": path, "data": {}} + with Pool(processes=10) as pool: + file_paths = [] + for file in os.listdir(path): + if file.endswith(FileCheckConst.NUMPY_SUFFIX): + file_paths.append((os.path.join(path, file),)) + file_paths.sort() + results = pool.starmap(process_file, file_paths) + if dump_task == CoreConst.STATISTICS: + dump_data = {"task": "statistics", "level": "L0", "framework": "mindspore", "dump_data_dir": None, "data": {}} + results = process_csv(path) + + # 收集结果 + for op_name, key, tensor_json in results: + if op_name: + if op_name not in dump_data.get(CoreConst.DATA, {}): + dump_data.get(CoreConst.DATA, {})[op_name] = {CoreConst.INPUT_ARGS: [], + CoreConst.INPUT_KWARGS: {}, + KEY_OUTPUT: []} + if key not in dump_data.get(CoreConst.DATA, {}).get(op_name, {}): + dump_data.get(CoreConst.DATA, {}).get(op_name, {})[key] = [] + dump_data.get(CoreConst.DATA, {}).get(op_name, {}).get(key, []).append(tensor_json) + + # 根据cell_list排序 + data_dict = dump_data.get(CoreConst.DATA, {}) + key_to_index = {key: index for index, key in enumerate(cell_list)} + sorted_data_dict = dict(sorted(data_dict.items(), key=lambda item: custom_sort(item, key_to_index))) + dump_data[CoreConst.DATA] = sorted_data_dict + + # 将数据写入dump.json + json_path = os.path.join(os.path.dirname(path), 'dump.json') + save_json(json_path, dump_data, indent=1) + + logger.info(f"Dump data saved to {json_path}") + + +def generate_stack_info(path): + if not os.path.exists(path): + logger.error("The provided path does not exist.") + return + + stack_data = {} + for cell_name in cell_list: + stack_data.update({cell_name: []}) + + # 将数据写入stack.json + json_path = os.path.join(os.path.dirname(path), 'stack.json') + save_json(json_path, stack_data, indent=1) + + # 删除csv文件 + if dump_task == CoreConst.STATISTICS: + remove_path(path) + + logger.info(f"Stack data saved to {json_path}") + + +def is_download_finished(directory, interval=3): + """ + 判断指定目录在一段时间后是否有数据被下载完成 + :param directory: 指定目录的路径 + :param interval: 检查的时间间隔(秒),默认为 3 秒 + :return: 如有数据被下载完成返回 True,否则返回 False + """ + # 检查目录是否存在 + if not os.path.exists(directory): + logger.warning(f"The specified directory {directory} does not exist.") + return False, False + initial_modification_time = os.path.getmtime(directory) + time.sleep(interval) + current_modification_time = os.path.getmtime(directory) + # 比较初始和当前修改时间 + if current_modification_time > initial_modification_time: + return False, True + else: + return True, False + + +def process(dump_path): + rank_id = os.environ.get('RANK_ID') + rank_dir = DEFAULT_RANK_DIR + if rank_id is not None: + rank_dir = CoreConst.RANK + str(rank_id) + + step_dir_list = os.listdir(dump_path) + for step_dir in step_dir_list: + step_path = os.path.join(dump_path, step_dir) + rank_path = os.path.join(step_path, rank_dir) + npy_path = os.path.join(rank_path, CoreConst.DUMP_TENSOR_DATA) + check_times = 0 + while True: + is_finished, is_downloading = is_download_finished(npy_path) + if not is_finished: + if not is_downloading: + logger.warning(f'{npy_path} does not exist.') + break + check_times += 1 + if check_times < 1000: + logger.info("There is data being downloaded in the specified directory, continue checking...") + else: + logger.warning('Download timeout, stop checking.') + break + else: + logger.info("There is no data being downloaded in the specified directory, stop checking.") + break + logger.info("==========Start processing data that has already been stored on the disk!==========") + rename_filename(path=npy_path) + generate_construct(npy_path) + generate_dump_info(npy_path) + generate_stack_info(npy_path) + # 单卡场景,rank目录名称为rank + if rank_id is None: + new_rank_path = os.path.join(step_path, CoreConst.RANK) + try: + move_directory(rank_path, new_rank_path) + logger.debug(f"Directory was successfully renamed to: {new_rank_path}") + except Exception as e: + logger.warning(f"Failed to renamed to {new_rank_path}: {e}") + logger.info("==========JSON file generation completed!==========") + + +# 删除csv文件中每行数据最后面的逗号 +def remove_trailing_commas(filename): + csv_data = read_csv(filename, as_pd=False) + for i in range(1, len(csv_data)): + if csv_data[i] and csv_data[i][-1] == "": + csv_data[i].pop() + write_csv(csv_data, filename, mode="w") + + +# 将相同step的csv文件合并,并加工后存入相应step目录下 +def merge_file(dump_path, rank_dir, file_dict): + rank_dir = rank_dir.replace(CoreConst.REPLACEMENT_CHARACTER, '') + for step_dir, file_list in file_dict.items(): + step_dir = CoreConst.STEP + step_dir + rank_path = os.path.join(dump_path, step_dir, rank_dir) + create_directory(rank_path) + output_file = os.path.join(rank_path, KEY_STATISTIC_CSV) + + all_dfs = [] + try: + for file_path in file_list: + remove_trailing_commas(file_path) + df = read_csv(file_path) + all_dfs.append(df) + + # 合并所有 DataFrame + merged_df = pd.concat(all_dfs, ignore_index=True) + + # 按 Timestamp 字段升序排序 + merged_df = merged_df.sort_values(by='Timestamp', ascending=True) + # 删除Slot字段为0的数据 + merged_df = merged_df[merged_df['Slot'] != 0] + # 重置索引,从0开始排序 + merged_df.reset_index(drop=True, inplace=True) + + # 获取op_name并加工为Cell.network._backbone.LlamaForCausalLM.forward.input.0格式 + merged_df[CoreConst.OP_NAME] = merged_df[CoreConst.OP_NAME].str.split(KEY_DUMP_TENSOR_DATA, expand=True)[1] + merged_df[CoreConst.OP_NAME] = ( + merged_df[CoreConst.OP_NAME].str.split(CoreConst.PIPE_SEPARATOR, expand=True)[0] + ) + merged_df[CoreConst.OP_NAME] = ( + merged_df[CoreConst.OP_NAME].str.replace(CoreConst.HYPHEN, CoreConst.SEP, regex=False) + ) + # 重命名op_name,改为Cell.{cell_name}.{class_name}.{forward/backward}.{number}.{input/output}.{index}格式 + rename_filename(data_df=merged_df) + + # 将合并并排序后的 DataFrame 保存到相应step目录下 + write_df_to_csv(merged_df, output_file) + except FileNotFoundError: + logger.error("One or more files not found.") + except KeyError: + logger.error("The value of the ‘Op Name’ field does not contain KEY_DUMP_TENSOR_DATA," + " and the index is out of bounds.") + except Exception as e: + logger.error(f"An error occurred:{e}") + + +def process_statistics(dump_path): + rank_id = os.environ.get('RANK_ID') + rank_dir_kbk = "rank_0" + if rank_id is not None: + rank_dir_kbk = CoreConst.RANK + CoreConst.REPLACEMENT_CHARACTER + str(rank_id) + rank_path_kbk = os.path.join(dump_path, rank_dir_kbk) + + # 按相同step数将csv文件名分组存入file_dict + file_dict = {} + depth_limit = 4 + base_depth = rank_path_kbk.count(os.sep) + for root, _, files in os.walk(rank_path_kbk): + current_depth = root.count(os.sep) - base_depth + if current_depth > depth_limit: + continue + for file in files: + if file == KEY_STATISTIC_CSV: + file_path = os.path.join(root, file) + step_dir = os.path.basename(os.path.dirname(file_path)) + if step_dir in file_dict: + file_dict[step_dir].append(file_path) + else: + file_dict[step_dir] = [file_path] + + # 将相同step的csv文件合并,并加工后存入相应step目录下 + merge_file(dump_path, rank_dir_kbk, file_dict) + + rank_dir = rank_dir_kbk.replace(CoreConst.REPLACEMENT_CHARACTER, '') + dir_list = os.listdir(dump_path) + step_dir_list = [d for d in dir_list if d.startswith(CoreConst.STEP)] + for step_dir in step_dir_list: + step_path = os.path.join(dump_path, step_dir) + rank_path = os.path.join(step_path, rank_dir) + csv_path = os.path.join(rank_path, KEY_STATISTIC_CSV) + logger.info("==========Start processing data csv!==========") + generate_construct(csv_path) + generate_dump_info(csv_path) + generate_stack_info(csv_path) + remove_path(rank_path_kbk) + # 单卡场景,rank目录名称为rank + if rank_id is None: + new_rank_path = os.path.join(step_path, CoreConst.RANK) + try: + move_directory(rank_path, new_rank_path) + logger.info(f"Directory was successfully renamed to: {new_rank_path}") + except Exception as e: + logger.warning(f"Failed to renamed to {new_rank_path}: {e}") + logger.info("==========JSON file generation completed!==========") + + +def get_yaml_keys(yaml_data): + keys = [] + for key, _ in yaml_data.items(): + keys.append(key) + return keys + + +def get_tensordump_mode(input_str): + left_index = input_str.find('(') + right_index = input_str.find(')') + + # 提取括号内的字符串 + if left_index != -1 and right_index != -1: + inner_str = input_str[left_index + 1:right_index] + # 分割字符串得到元素列表 + elements = inner_str.split(',') + if len(elements) >= 2: + # 去除元素前后的空格 + first_element = elements[0].strip() + second_element = elements[1].strip() + return first_element, second_element + return None, None + + +def set_tensordump_mode(cell, input_str): + first_str, second_str = get_tensordump_mode(input_str) + if first_str and second_str: + cell.input_dump_mode = first_str + cell.output_dump_mode = second_str + + +def create_kbyk_json(dump_path, summary_mode, step): + if step: + step_str = "" + for s in step: + step_str += (str(s) + '|') + iteration = step_str[:-1] + else: + iteration = "all" + + if summary_mode == "statistics": + statistic_category = ["max", "min", "avg", "l2norm"] + elif "mean" in summary_mode: + mean_index = summary_mode.index("mean") + summary_mode[mean_index] = "avg" + statistic_category = summary_mode + else: + statistic_category = summary_mode + + config_json = { + "common_dump_settings": { + "op_debug_mode": 0, + "dump_mode": 1, + "path": dump_path, + "net_name": "Net", + "iteration": iteration, + "saved_data": "statistic", + "input_output": 0, + "kernels": ["TensorDump"], + "support_device": [0, 1, 2, 3, 4, 5, 6, 7], + "statistic_category": statistic_category + }, + "e2e_dump_settings": { + "enable": False, + "trans_flag": True, + "stat_calc_mode": "device" + } + } + + create_directory(dump_path) + rank_id = os.environ.get('RANK_ID') + if rank_id is None: + rank_id = 0 + config_json_path = os.path.join(dump_path, str(rank_id) + "kernel_kbyk_dump.json") + save_json(config_json_path, config_json, indent=4) + logger.info(config_json_path + " has been created.") + return config_json_path + + +def start(config: CellDumpConfig): + global dump_task + dump_task = config.task + net = config.net + dump_path = config.dump_path + data_mode = config.data_mode + summary_mode = config.summary_mode + step = config.step + if dump_task == CoreConst.STATISTICS: + # 使能KBK dump + config_json_path = create_kbyk_json(dump_path, summary_mode, step) + os.environ["MINDSPORE_DUMP_CONFIG"] = config_json_path + + # 执行过程中跳过TensorDump算子 + os.environ["MS_KERNEL_LAUNCH_SKIP"] = "TensorDump" + + # 初始化静态图KBK dump的step数,从0开始 + if not graph_step_flag: + raise Exception( + "Importing _set_init_iter failed, " + "please use the latest version package of MindSpore." + ) + _set_init_iter(0) + remove_path(config_json_path) + + if net is None: + return + + if isinstance(net, nn.Cell): + net = (('', net),) + + td_config_path = "" + try: + import mindformers + mindformers_file = mindformers.__file__ + mindformers_dir = os.path.dirname(mindformers_file) + td_config_path = os.path.join(mindformers_dir, "configuration", "layer_mapping.yaml") + if not os.path.exists(td_config_path): + td_config_path = "" + logger.warning("The configuration file in mindformers was not loaded, the default mode will be used.") + except ImportError: + logger.warning("The mindFormers failed to load, the default mode will be used.") + + if td_config_path == "": + yaml_data = {} + else: + yaml_data = load_yaml(td_config_path) + first_layer_key = get_yaml_keys(yaml_data) + + black_list = ["grad_reducer", ""] + + for name_and_model in net: + for name, cell in name_and_model[1].cells_and_names(name_prefix=name_and_model[0]): + class_name = cell.__class__.__name__ + # 跳过黑名单cell + if name in black_list: + logger.info(f"Cell {name}.{class_name} is skipped!") + continue + # 跳过框架内部的cell + if class_name.startswith(CoreConst.REPLACEMENT_CHARACTER): + logger.info(f"Cell {name}.{class_name} is skipped!") + continue + else: + # Format: Cell.{cell_name}.{class_name} + cell.cell_prefix = CoreConst.SEP.join([CoreConst.CELL, name, cell.__class__.__name__]) + if dump_task == CoreConst.STATISTICS: + cell.cell_prefix = cell.cell_prefix.replace(CoreConst.SEP, CoreConst.HYPHEN) + + # 根据yaml配置文件设置cell的TensorDump模式 + if class_name in first_layer_key: + layer_data = yaml_data.get(class_name) + if layer_data: + for child_name, child_cell in cell.cells_and_names(): + if child_name in layer_data: + set_tensordump_mode(child_cell, layer_data[child_name]) + top_layer_data = yaml_data.get(KEY_TOPLAYER) + if top_layer_data and name in top_layer_data: + set_tensordump_mode(cell, top_layer_data[name]) + + # 替换construct函数 + cell.construct = cell_construct_wrapper(cell.construct, cell) + logger.info(f"Cell {name}: construct function is wrapped!") + cell.dump_path = dump_path + cell.data_mode = data_mode + cell.input_clips = [] + cell.output_clips = [] + # It is assumed that each cell has a maximum of 50 outputs and 50 inputs. + for i in range(50): + cell.input_clips.append( + ops.InsertGradientOf(partial_func(clip_gradient, cell.dump_path, cell.cell_prefix, i, KEY_INPUT)) + ) + cell.output_clips.append( + ops.InsertGradientOf(partial_func(clip_gradient, cell.dump_path, cell.cell_prefix, i, KEY_OUTPUT)) + ) + + logger.info("==========The cell_dump_process_start phase is Finished!==========") + if dump_task == CoreConst.TENSOR: + atexit.register(process, dump_path=dump_path) + if dump_task == CoreConst.STATISTICS: + atexit.register(process_statistics, dump_path=dump_path) diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/dump_tool_factory.py b/debug/accuracy_tools/msprobe/mindspore/dump/dump_tool_factory.py index 0ca63b4a84aee00127bca37b7da36888e905a5aa..f29db2e5b317c6e01d3d50f03b5d0434e638f320 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/dump_tool_factory.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/dump_tool_factory.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,15 +14,18 @@ # limitations under the License. from msprobe.mindspore.common.const import Const +from msprobe.core.common.log import logger +from msprobe.mindspore.common.utils import is_graph_mode_cell_dump_allowed from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump from msprobe.mindspore.dump.kernel_kbyk_dump import KernelKbykDump +from msprobe.mindspore.dump.graph_mode_cell_dump import GraphModeCellDump class DumpToolFactory: tools = { Const.CELL: { - Const.GRAPH_KBYK_MODE: None, + Const.GRAPH_KBYK_MODE: GraphModeCellDump, Const.GRAPH_GE_MODE: None, Const.PYNATIVE_MODE: None }, @@ -39,14 +42,23 @@ class DumpToolFactory: } @staticmethod - def create(config: DebuggerConfig): - if len(config.data_mode) != 1 or config.data_mode[0] not in Const.GRAPH_DATA_MODE_LIST: - raise Exception("data_mode must be one of all, input, output.") + def create(config: DebuggerConfig, model=None): + if config.level == Const.CELL: + if not is_graph_mode_cell_dump_allowed(config): + raise Exception("Cell dump is not supported in graph mode.") + if len(config.data_mode) != 1 or config.data_mode[0] not in Const.GRAPH_CELL_DUMP_DATA_MODE_LIST: + raise Exception("data_mode must be one of all, forward, backward.") + else: + if len(config.data_mode) != 1 or config.data_mode[0] not in Const.GRAPH_DATA_MODE_LIST: + raise Exception("data_mode must be one of all, input, output.") + if config.level == Const.KERNEL: + return (KernelGraphDump(config), KernelKbykDump(config)) tool = DumpToolFactory.tools.get(config.level) if not tool: raise Exception("Valid level is needed.") tool = tool.get(config.execution_mode) if not tool: - raise Exception(f"Data dump is not supported in {config.execution_mode} mode " - f"when dump level is {config.level}.") - return tool(config) + logger.error(f"Data dump is not supported in {config.execution_mode} mode " + f"when dump level is {config.level}.") + raise ValueError + return tool(config, model) if tool == GraphModeCellDump else tool(config) diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/graph_mode_cell_dump.py b/debug/accuracy_tools/msprobe/mindspore/dump/graph_mode_cell_dump.py new file mode 100644 index 0000000000000000000000000000000000000000..53bf260e95dad473fbb65e2177bc402613c2eae5 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/dump/graph_mode_cell_dump.py @@ -0,0 +1,154 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import mindspore as ms +from mindspore import hal, ops, Tensor +from mindspore.ops.primitive import _run_op + +from msprobe.core.common.const import Const as CoreConst +from msprobe.core.common.runtime import Runtime +from msprobe.mindspore.common.const import Const +from msprobe.mindspore.common.log import logger +from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +import msprobe.mindspore.dump.cell_dump_process as cellDumperWithDumpGradient +import msprobe.mindspore.dump.cell_dump_with_insert_gradient as cellDumperWithInsertGradient + +tensordump_flag = True +try: + from mindspore._c_expression import _tensordump_set_step +except ImportError: + tensordump_flag = False + +graph_step_flag = True +try: + from mindspore._c_expression import _dump_step +except ImportError: + graph_step_flag = False + + +class GraphModeCellDump: + task = CoreConst.STATISTICS + + def __init__(self, config: DebuggerConfig, model, strict=True): + self.net = model + self.white_list = [] + self.black_list = [] + self.execution_mode = config.execution_mode + self.dump_path = config.dump_path if config.dump_path else "./" + self.rank = config.rank + self.step = config.step + self.scope = config.scope + self.list = config.list + self.data_mode = config.data_mode + self.file_format = config.file_format + GraphModeCellDump.task = config.task + self.summary_mode = config.summary_mode + self.check_config(strict) + self.set_step() + + @staticmethod + def step(): + # 更新TensorDump Step + if GraphModeCellDump.task == CoreConst.TENSOR: + hal.synchronize() + temp_tensor = ms.Tensor([1], dtype=ms.float32) + step_flag = "" + _run_op(ops.TensorDump(), "TensorDump", (step_flag, temp_tensor)) + ops.tensordump(step_flag, temp_tensor) + + # 更新静态图KBK dump的step数 + if GraphModeCellDump.task == CoreConst.STATISTICS: + if not graph_step_flag: + raise Exception( + "Importing _dump_step failed, " + "please use the latest version package of MindSpore." + ) + _dump_step(1) + + def check_config(self, strict): + if not self.net: + raise Exception("The model is empty and cell dump is not enabled.") + + if strict: + if self.rank: + raise Exception("In graph mode, cell dump does not currently support specifying rank.") + if self.scope: + raise Exception("In graph mode, cell dump does not currently support specifying scope.") + if self.list: + raise Exception("In graph mode, cell dump does not currently support specifying list.") + if len(self.data_mode) != 1 or self.data_mode[0] not in Const.GRAPH_CELL_DUMP_DATA_MODE_LIST: + raise Exception("In graph mode and cell dump, data_mode must be one of all, forword, backword.") + if self.file_format != []: + logger.warning("In graph mode, cell dump does not currently support specifying file_format." + " The file will be stored in npy format.") + if self.task == CoreConst.STATISTICS and self.summary_mode == CoreConst.MD5: + raise Exception("The L0 level statistics dump mode does not support " + "the calculation of md5 values currently In graph mode.") + else: + self.rank = [] + self.scope = [] + self.list = [] + self.file_format = [] + if len(self.data_mode) != 1 or self.data_mode[0] not in Const.GRAPH_CELL_DUMP_DATA_MODE_LIST: + self.data_mode = [CoreConst.ALL] + if self.task == CoreConst.STATISTICS and self.summary_mode == CoreConst.MD5: + self.summary_mode = CoreConst.STATISTICS + + return True + + def set_step(self): + if tensordump_flag: + _tensordump_set_step(self.step) + else: + raise Exception( + "Importing _tensordump_set_step failed, " + "please use the latest version package of MindSpore." + ) + + def handle(self): + os.environ['MS_JIT_MODULES'] = 'msprobe' + + if Runtime.run_mode == Const.PYNATIVE_GRAPH_MODE: + dump_path = os.path.join(self.dump_path, Const.GRAPH_MODE) + else: + dump_path = self.dump_path + + cell_dumper = cellDumperWithDumpGradient + + if self.execution_mode == Const.PYNATIVE_MODE: + enable_dump_gradient = hasattr(ops, 'DumpGradient') + if hasattr(ops, 'DumpGradient'): + try: + ops.DumpGradient()('grad.npy', Tensor([0], dtype=ms.float32), 'in') + except Exception: + enable_dump_gradient = False + logger.warning('the DumpGradient operator failed to execute.') + if not enable_dump_gradient: + cell_dumper = cellDumperWithInsertGradient + + dump_config = cell_dumper.CellDumpConfig( + net=self.net, + dump_path=dump_path, + data_mode=self.data_mode[0], + task=self.task, + summary_mode=self.summary_mode, + step=self.step + ) + + cell_dumper.start( + dump_config + ) diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/graph_tensor_dump.py b/debug/accuracy_tools/msprobe/mindspore/dump/graph_tensor_dump.py new file mode 100644 index 0000000000000000000000000000000000000000..62651144251ed4bc86f3a5e141eb1fd0ae16bbb6 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/dump/graph_tensor_dump.py @@ -0,0 +1,134 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from collections import OrderedDict +import mindspore as ms +from mindspore import hal, ops, Tensor +from mindspore.ops.primitive import _run_op + + +def _iterate_items(data): + if isinstance(data, (dict, OrderedDict)): + return data.items() + elif isinstance(data, (list, tuple)): + return enumerate(data) + else: + raise TypeError("Unsupported data type") + + +class _SaveBase: + def __init__(self, save_dir): + super(_SaveBase, self).__init__() + self.path = save_dir + self.save_func = _npy_save + + def get_save_func(self): + return self.save_func + + +@ms.jit_class +class _SaveCell(_SaveBase): + def __call__(self, name, data): + return self.get_save_func()(self.path, name, data) + + +class _SaveGradBase: + def __init__(self, save_dir, name): + super(_SaveGradBase, self).__init__() + self.file = save_dir + name + + +@ms.jit_class +class _SaveGradCell(_SaveGradBase): + def __init__(self, save_dir, name): + super(_SaveGradCell, self).__init__(save_dir, name) + self.ms_save_grad = ms.ops.InsertGradientOf( + _wrapper_save_grad_func(self.file)) + + def __call__(self, x): + if isinstance(x, ms.Tensor): + return self.ms_save_grad(x) + else: + raise TypeError(f"For 'save_grad', the type of argument 'data' must be mindspore.Tensor or torch.tensor, " + f"but got {type(x)}") + + +def _npy_save_ops(file, data): + if isinstance(data, ms.Tensor): + if data.dtype == ms.bfloat16: + data = data.float() + ms.ops.TensorDump()(file, data) + else: + raise TypeError(f"For 'save', the type of argument 'data' must be mindspore.Tensor or torch.tensor, " + f"but got {type(data)}") + + +def _wrapper_save_grad_func(file): + def _save_grad_func(grad): + data = grad + if data.dtype == ms.bfloat16: + data = data.float() + ms.ops.TensorDump()(file, data) + return grad + return _save_grad_func + + +def _npy_save(save_dir, item_name, data): + if isinstance(data, (list, tuple, dict, OrderedDict)): + for key, val in _iterate_items(data): + _npy_save(save_dir, f"{item_name}.{key}", val) + else: + if data is None: + return + _npy_save_ops(f"{save_dir}{item_name}", data) + + +def generate_dump_dir(save_dir, sep=os.sep): + """ + usage: generate dump directory path str in mindspore graph mode + """ + full_suffix = '{step}' + sep + '{rank}' + sep + if save_dir and save_dir[-1] != sep: + result_dir = save_dir + sep + full_suffix + else: + result_dir = save_dir + full_suffix + return result_dir + + +def save(save_dir, name, data): + """ + save tensor. + """ + dump_dir = generate_dump_dir(save_dir) + _SaveCell(dump_dir)(name, data) + + +def save_grad(save_dir, name, data): + """ + save grad. + """ + dump_dir = generate_dump_dir(save_dir) + suffix_name = name + '_grad' + return _SaveGradCell(dump_dir, suffix_name)(data) + + +def step(): + hal.synchronize() + temp_tensor = Tensor([1], dtype=ms.float32) + step_flag = "" + _run_op(ops.TensorDump(), "TensorDump", (step_flag, temp_tensor)) + ops.tensordump(step_flag, temp_tensor) + hal.synchronize() \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py new file mode 100644 index 0000000000000000000000000000000000000000..0ff9c56fb5a8fa7c2f044b2c93b69a831bd408b1 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py @@ -0,0 +1,175 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import inspect + +from mindspore import Tensor, ops, mint +from mindspore.mint import distributed +from mindspore.mint.nn import functional +from mindspore.communication import comm_func + +from msprobe.core.common.file_utils import load_yaml +from msprobe.core.common.utils import Const +from msprobe.core.data_dump.api_registry import ApiRegistry +from msprobe.mindspore.common.log import logger +from msprobe.mindspore.common.const import Const as MsConst +from msprobe.mindspore.common.utils import is_mindtorch +from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell + + +stub_tensor_existed = True +try: + from mindspore.common._stub_tensor import StubTensor +except ImportError: + stub_tensor_existed = False + +cur_path = os.path.dirname(os.path.realpath(__file__)) +if not is_mindtorch(): + _api_types = { + Const.MS_FRAMEWORK: { + Const.MS_API_TYPE_OPS: ((ops,), (ops,)), + Const.MS_API_TYPE_TENSOR: ((Tensor,), (Tensor,)), + Const.MS_API_TYPE_MINT: ((mint,), (mint,)), + Const.MS_API_TYPE_MINT_FUNC: ((functional,), (functional,)), + Const.MS_API_TYPE_COM: ((comm_func,), (comm_func,)), + Const.MS_API_TYPE_MINT_DIST: ((distributed,), (distributed,)) + } + } + if stub_tensor_existed: + _api_types.get(Const.MS_FRAMEWORK).update( + {Const.MS_API_TYPE_STUB_TENSOR: ((StubTensor,), (StubTensor,))} + ) + + _supported_api_list_path = (os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE),) + _blacklist = [] +else: + import torch + import torch_npu + _api_types = { + Const.MT_FRAMEWORK: { + Const.PT_API_TYPE_FUNCTIONAL: ((torch.nn.functional,), (torch.nn.functional,)), + Const.PT_API_TYPE_TENSOR: ((torch.Tensor,), (torch.Tensor,)), + Const.PT_API_TYPE_TORCH: ((torch,), (torch,)), + Const.PT_API_TYPE_NPU: ((torch_npu,), (torch_npu,)), + Const.PT_API_TYPE_DIST: ((torch.distributed,), (torch.distributed, torch.distributed.distributed_c10d)) + } + } + _supported_api_list_path = (os.path.join(cur_path, '../../../pytorch/hook_module', + MsConst.SUPPORTED_API_LIST_FILE),) + _blacklist = [] + +_inner_used_api = { + Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_OPS: ( + ops, "norm", "square", "sqrt", "is_complex", "stack", "is_floating_point" + ), + Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_TENSOR: ( + Tensor, "to", "numel", 'sum' + ), + Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_MINT: ( + mint, "max", "min", "mean", "norm" + ) +} + + +class ApiTemplate(HOOKCell): + def __init__(self, api_name, api_func, prefix, hook_build_func): + self.api_name = api_name + self.prefix_api_name = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) + Const.SEP + distributed_prefix = Const.DIST_API_TYPE_PREFIX if is_mindtorch() else Const.MINT_DIST_API_TYPE_PREFIX + self.op_is_distributed = prefix == distributed_prefix + super().__init__(hook_build_func) + self.api_func = api_func + + @staticmethod + def async_to_sync(output): + # Fake handle, used to return after the CommHandle executes the wait method + fake_handle = type("FakeHandle", (), {"wait": lambda self: None})() + if isinstance(output, tuple) and len(output) == 2 and hasattr(output[1], "wait"): + output[1].wait() + output = (output[0], fake_handle) + elif hasattr(output, "wait"): + output.wait() + output = fake_handle + return output + + def construct(self, *args, **kwargs): + if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX): + return args[0] if args else kwargs.get(Const.INPUT) + + output = self.api_func(*args, **kwargs) + + if self.prefix_api_name.startswith( + (MsConst.DISTRIBUTED_DATA_PREFIX, Const.MINT_DIST_API_TYPE_PREFIX) + ): + try: + bound = inspect.signature(self.api_func).bind(*args, **kwargs) + bound.apply_defaults() + use_async_op_flag = bound.arguments.get("async_op", False) + except Exception as e: + use_async_op_flag = False + logger.warning(f"fail to get dist api's func signature because {e}, no wait") + + if use_async_op_flag or self.api_name in ["isend", "irecv"]: + output = self.async_to_sync(output) + if self.api_name == "batch_isend_irecv" and isinstance(output, list): + output = [self.async_to_sync(handle) for handle in output] + + return output + + def forward(self, *args, **kwargs): + if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX): + return args[0] if args else kwargs.get(Const.INPUT) + return self.api_func(*args, **kwargs) + + +api_register = None +stub_tensor_set = False + + +def get_api_register(return_new=False): + global stub_tensor_set + + def stub_method(method): + def wrapped_method(*args, **kwargs): + return method(*args, **kwargs) + return wrapped_method + if not is_mindtorch() and stub_tensor_existed and not stub_tensor_set: + api_names = load_yaml(_supported_api_list_path[0]).get(Const.MS_API_TYPE_TENSOR, []) + for attr_name in dir(StubTensor): + attr = getattr(StubTensor, attr_name) + if attr_name in api_names and callable(attr): + setattr(StubTensor, attr_name, stub_method(attr)) + stub_tensor_set = True + + if return_new: + return ApiRegistry( + _api_types, + _inner_used_api, + _supported_api_list_path, + ApiTemplate, + _blacklist + ) + + global api_register + if api_register is None: + api_register = ApiRegistry( + _api_types, + _inner_used_api, + _supported_api_list_path, + ApiTemplate, + _blacklist + ) + return api_register diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py deleted file mode 100644 index 7aee1deccd9689985c7a2e270648bd0877cd7cf3..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py +++ /dev/null @@ -1,207 +0,0 @@ -# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from mindspore import Tensor, ops, mint -from mindspore.mint.nn import functional -from mindspore.common._stub_tensor import StubTensor -from mindspore.communication import comm_func - -from msprobe.mindspore.dump.hook_cell.wrap_api import (HOOKTensor, HOOKStubTensor, HOOKFunctionalOP, - HOOKMintOP, HOOKMintNNFunctionalOP, HOOKDistributedOP, - HOOKTorchOP, HOOKTorchTensor, HOOKTorchFunctionalOP, - HOOKTorchDistributedOP, HOOKTorchNpuOP, - get_wrap_api_list, get_wrap_torch_api_list, setup_hooks) -from msprobe.core.common.utils import Const -from msprobe.mindspore.common.utils import is_mindtorch - -if is_mindtorch(): - import torch - import torch_npu - - -def stub_method(method): - def wrapped_method(*args, **kwargs): - return method(*args, **kwargs) - return wrapped_method - - -class ApiRegistry: - def __init__(self): - self.tensor_ori_attr = {} - self.stub_tensor_ori_attr = {} - self.functional_ori_attr = {} - self.mint_ops_ori_attr = {} - self.mint_func_ops_ori_attr = {} - self.distributed_ori_attr = {} - self.norm_inner_ops_ori_attr = {} - - self.torch_ori_attr = {} - self.torch_tensor_ori_attr = {} - self.torch_functional_ori_attr = {} - self.torch_distributed_ori_attr = {} - self.torch_npu_ori_attr = {} - - self.tensor_hook_attr = {} - self.stub_tensor_hook_attr = {} - self.functional_hook_attr = {} - self.mint_ops_hook_attr = {} - self.mint_func_ops_hook_attr = {} - self.distibuted_hook_attr = {} - self.norm_inner_ops_hook_attr = {} - - self.torch_hook_attr = {} - self.torch_tensor_hook_attr = {} - self.torch_functional_hook_attr = {} - self.torch_distributed_hook_attr = {} - self.torch_npu_hook_attr = {} - - self.norm_inner_ops = ["norm", "square", "sqrt", "is_complex"] - - @staticmethod - def store_ori_attr(ori_api_group, api_list, api_ori_attr): - for api in api_list: - if Const.SEP in api: - sub_module_name, sub_op = api.rsplit(Const.SEP, 1) - sub_module = getattr(ori_api_group, sub_module_name) - ori_api_func = getattr(sub_module, sub_op) - else: - ori_api_func = getattr(ori_api_group, api) - if ori_api_group == StubTensor: - api_ori_attr[api] = stub_method(ori_api_func) - continue - api_ori_attr[api] = ori_api_func - - @staticmethod - def set_api_attr(api_group, attr_dict): - for api, api_attr in attr_dict.items(): - if Const.SEP in api: - sub_module_name, sub_op = api.rsplit(Const.SEP, 1) - sub_module = getattr(api_group, sub_module_name, None) - if sub_module is not None: - setattr(sub_module, sub_op, api_attr) - else: - setattr(api_group, api, api_attr) - - def norm_inner_op_set_hook_func(self): - self.set_api_attr(ops, self.norm_inner_ops_hook_attr) - - def norm_inner_op_set_ori_func(self): - self.set_api_attr(ops, self.norm_inner_ops_ori_attr) - - def api_set_hook_func(self): - if is_mindtorch(): - self.set_api_attr(torch, self.torch_hook_attr) - self.set_api_attr(torch.Tensor, self.torch_tensor_hook_attr) - self.set_api_attr(torch.nn.functional, self.torch_functional_hook_attr) - self.set_api_attr(torch.distributed, self.torch_distributed_hook_attr) - self.set_api_attr(torch.distributed.distributed_c10d, self.torch_distributed_hook_attr) - self.set_api_attr(torch_npu, self.torch_npu_hook_attr) - else: - self.set_api_attr(Tensor, self.tensor_hook_attr) - self.set_api_attr(StubTensor, self.stub_tensor_hook_attr) - self.set_api_attr(ops, self.functional_hook_attr) - self.set_api_attr(mint, self.mint_ops_hook_attr) - self.set_api_attr(functional, self.mint_func_ops_hook_attr) - self.set_api_attr(comm_func, self.distibuted_hook_attr) - - def api_set_ori_func(self): - if is_mindtorch(): - self.set_api_attr(torch, self.torch_ori_attr) - self.set_api_attr(torch.Tensor, self.torch_tensor_ori_attr) - self.set_api_attr(torch.nn.functional, self.torch_functional_ori_attr) - self.set_api_attr(torch.distributed, self.torch_distributed_ori_attr) - self.set_api_attr(torch.distributed.distributed_c10d, self.torch_distributed_ori_attr) - self.set_api_attr(torch_npu, self.torch_npu_ori_attr) - else: - self.set_api_attr(Tensor, self.tensor_ori_attr) - self.set_api_attr(StubTensor, self.stub_tensor_ori_attr) - self.set_api_attr(ops, self.functional_ori_attr) - self.set_api_attr(mint, self.mint_ops_ori_attr) - self.set_api_attr(functional, self.mint_func_ops_ori_attr) - self.set_api_attr(comm_func, self.distributed_ori_attr) - - def initialize_hook(self, hook): - setup_hooks(hook) - if is_mindtorch(): - wrap_torch_api_name = get_wrap_torch_api_list() - self.store_ori_attr(torch, - wrap_torch_api_name.torch_api_names, self.torch_ori_attr) - self.store_ori_attr(torch.Tensor, - wrap_torch_api_name.tensor_api_names, self.torch_tensor_ori_attr) - self.store_ori_attr(torch.nn.functional, - wrap_torch_api_name.functional_api_names, self.torch_functional_ori_attr) - self.store_ori_attr(torch.distributed, - wrap_torch_api_name.distributed_api_names, self.torch_distributed_ori_attr) - self.store_ori_attr(torch_npu, - wrap_torch_api_name.npu_api_names, self.torch_npu_ori_attr) - for attr_name in dir(HOOKTorchOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:] - self.torch_hook_attr[api_name] = getattr(HOOKTorchOP, attr_name) - for attr_name in dir(HOOKTorchTensor): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:] - self.torch_tensor_hook_attr[api_name] = getattr(HOOKTorchTensor, attr_name) - for attr_name in dir(HOOKTorchFunctionalOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:] - self.torch_functional_hook_attr[api_name] = getattr(HOOKTorchFunctionalOP, attr_name) - for attr_name in dir(HOOKTorchDistributedOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:] - self.torch_distributed_hook_attr[api_name] = getattr(HOOKTorchDistributedOP, attr_name) - for attr_name in dir(HOOKTorchNpuOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:] - self.torch_npu_hook_attr[api_name] = getattr(HOOKTorchNpuOP, attr_name) - return - - wrap_api_name = get_wrap_api_list() - self.store_ori_attr(Tensor, wrap_api_name.tensor_api_names, self.tensor_ori_attr) - self.store_ori_attr(StubTensor, wrap_api_name.stub_tensor_api_names, self.stub_tensor_ori_attr) - self.store_ori_attr(ops, wrap_api_name.ops_api_names, self.functional_ori_attr) - self.store_ori_attr(mint, wrap_api_name.mint_api_names, self.mint_ops_ori_attr) - self.store_ori_attr(functional, wrap_api_name.mint_nn_func_api_names, self.mint_func_ops_ori_attr) - self.store_ori_attr(comm_func, wrap_api_name.distributed_api_names, self.distributed_ori_attr) - self.store_ori_attr(ops, self.norm_inner_ops, self.norm_inner_ops_ori_attr) - for attr_name in dir(HOOKTensor): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:] - self.tensor_hook_attr[api_name] = getattr(HOOKTensor, attr_name) - for attr_name in dir(HOOKStubTensor): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:] - self.stub_tensor_hook_attr[api_name] = getattr(HOOKStubTensor, attr_name) - for attr_name in dir(HOOKFunctionalOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:] - self.functional_hook_attr[api_name] = getattr(HOOKFunctionalOP, attr_name) - if api_name in self.norm_inner_ops: - self.norm_inner_ops_hook_attr[api_name] = getattr(HOOKFunctionalOP, attr_name) - for attr_name in dir(HOOKMintOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:] - self.mint_ops_hook_attr[api_name] = getattr(HOOKMintOP, attr_name) - for attr_name in dir(HOOKMintNNFunctionalOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:] - self.mint_func_ops_hook_attr[api_name] = getattr(HOOKMintNNFunctionalOP, attr_name) - for attr_name in dir(HOOKDistributedOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:] - self.distibuted_hook_attr[api_name] = getattr(HOOKDistributedOP, attr_name) - - -api_register = ApiRegistry() diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py index b68a7d995a56497a219281c5a43d692c46cfac4d..e56b242b8b965d3c7290d87f94e2b2896184f631 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py @@ -13,12 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import threading from collections import defaultdict +import mindspore as ms from mindspore import nn from msprobe.mindspore.common.utils import is_mindtorch, register_backward_hook_functions +ms_version = ms.__version__ + def add_cell_count(name): HOOKCell.cell_count[name] += 1 @@ -28,42 +32,31 @@ def get_cell_count(name): return HOOKCell.cell_count[name] -def __init__(self, build_hook) -> None: +def __init__(self, hook_build_func) -> None: super(HOOKCell, self).__init__() - self.changed_status = False - self.input_kwargs = {} - self.prefix = "" - if not HOOKCell.g_stop_hook: - HOOKCell.g_stop_hook = True - self.changed_status = True - if hasattr(self, "prefix_api_name"): - self.prefix = self.prefix_api_name - - self.forward_data_collected = False - forward_pre_hook, forward_hook, backward_hook, backward_pre_hook = build_hook(self.prefix) - self.register_forward_pre_hook(forward_pre_hook) - self.register_forward_hook(forward_hook) - register_backward_hook_functions["full"](self, backward_hook) - register_backward_hook_functions["pre"](self, backward_pre_hook) + self.msprobe_input_kwargs = {} + prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else "" + if callable(hook_build_func): + hook_set = hook_build_func(prefix) + if ms_version < "2.6.0" and not is_mindtorch(): + getattr(self, "_forward_pre_hook", {})[id(self)] = hook_set.forward_pre_hook + if hook_set.forward_hook: + getattr(self, "_forward_hook", {})[id(self)] = hook_set.forward_hook + else: + self.register_forward_pre_hook(hook_set.forward_pre_hook) + if hook_set.forward_hook: + self.register_forward_hook(hook_set.forward_hook) -# 重载call,加全局标志。 def __call__(self, *args, **kwargs): - try: - self.input_kwargs = kwargs - out = super(HOOKCell, self).__call__(*args, **kwargs) - except Exception as e: - raise e - finally: - if self.changed_status: - self.changed_status = False - HOOKCell.g_stop_hook = False + tid = threading.get_ident() + self.msprobe_input_kwargs[tid] = kwargs + out = super(HOOKCell, self).__call__(*args, **kwargs) return out hook_cell_dict = { "cell_count": defaultdict(int), - "g_stop_hook": False, "add_cell_count": staticmethod(add_cell_count), "get_cell_count": staticmethod(get_cell_count), "__init__": __init__, diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/ms_hook_manager.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/ms_hook_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..ba6e73e342552e774481714fca563aa74ca34f42 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/ms_hook_manager.py @@ -0,0 +1,212 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading +from collections import OrderedDict + +import mindspore as ms +from mindspore import Tensor +from mindspore.common.api import _no_grad, _pynative_executor +from mindspore.ops.operations import _inner_ops as inner + +from msprobe.core.common.const import Const +from msprobe.core.common.log import logger +from msprobe.core.common.utils import replace_last_occurrence, ThreadSafe +from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputs +from msprobe.core.hook_manager import BaseHookManager, HookSet +from msprobe.mindspore.common.const import Const as MsConst +from msprobe.mindspore.common.utils import ( + has_kwargs_in_forward_hook, + is_mindtorch, + is_backward_hook_output_a_view +) +from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell + +ms_version = ms.__version__ + + +class MindsporeHookManager(BaseHookManager): + cell_bw_hook_kernels = {} + cell_backward_pre_hook = [] + cell_backward_hook = [] + + @property + def _is_recompute(self): + return None + + @staticmethod + def reset_status(): + BaseHookManager.reset_status() + MindsporeHookManager.cell_bw_hook_kernels.clear() + MindsporeHookManager.cell_backward_pre_hook.clear() + MindsporeHookManager.cell_backward_hook.clear() + + @staticmethod + def _no_grad_context(): + return _no_grad() + + @staticmethod + def _add_count(name): + HOOKCell.add_cell_count(name) + + @staticmethod + def _get_count(name): + return HOOKCell.get_cell_count(name) + + @staticmethod + def _process_kwargs_and_output(module, tid, hook_type, kwargs_or_output, output_or_kwargs): + if not has_kwargs_in_forward_hook() or hook_type == Const.API: + kwargs = module.msprobe_input_kwargs.get(tid, {}) if hasattr(module, 'msprobe_input_kwargs') else {} + output = kwargs_or_output + else: + kwargs = kwargs_or_output + output = output_or_kwargs + return kwargs, output + + def build_hook(self, hook_type, name): + if hook_type == Const.API: + hook_set = HookSet( + forward_pre_hook=self._build_forward_pre_hook(hook_type, name) + ) + else: + full_backward_name = replace_last_occurrence(name, Const.FORWARD, Const.BACKWARD) + hook_set = HookSet( + forward_hook=self._build_forward_hook(hook_type, name), + backward_pre_hook=self._build_backward_pre_hook(hook_type, full_backward_name), + backward_hook=self._build_backward_hook(hook_type, full_backward_name) + ) + return hook_set + + def _register_forward_hook(self, module, api_name): + if not hasattr(module, 'msprobe_forward_hook'): + forward_hook = self._build_forward_hook(Const.API, api_name) + if ms_version < "2.6.0" and not is_mindtorch(): + getattr(module, "_forward_hook", {})[id(module)] = forward_hook + else: + module.register_forward_hook(forward_hook) + setattr(module, 'msprobe_forward_hook', True) + + def _register_backward_hook(self, module, full_backward_name, args): + if not _pynative_executor.requires_grad(): + return args + + enable_hooked = sum( + [isinstance(ele, Tensor) and ele.dtype not in MsConst.NonDifferentiableType for ele in args] + ) + + if enable_hooked: + backward_hook_dict = OrderedDict() + backward_hook_dict[full_backward_name] = self._build_backward_hook(Const.API, full_backward_name) + MindsporeHookManager.cell_backward_hook.append(backward_hook_dict) + bw_hook = inner.CellBackwardHook(full_backward_name, module, MindsporeHookManager.cell_backward_hook[-1]) + bw_hook.register_backward_hook() + MindsporeHookManager.cell_bw_hook_kernels[full_backward_name] = bw_hook + args = bw_hook(args) if is_backward_hook_output_a_view() else bw_hook(*args) + return args + + def _register_backward_pre_hook(self, module, full_backward_name, output): + if not _pynative_executor.requires_grad(): + return output + + bw_hook = MindsporeHookManager.cell_bw_hook_kernels.get(full_backward_name) + if bw_hook: + if not isinstance(output, (Tensor, tuple)): + logger.debug("For backward hooks to be called, " + "cell output should be a Tensor or a tuple of Tensors " + f"but received {type(output)}") + if is_backward_hook_output_a_view(): + new_outputs = bw_hook(output) + else: + if isinstance(output, tuple): + new_outputs = bw_hook(*output) + else: + new_outputs = bw_hook(output) + if isinstance(output, tuple) and len(output) == 1: + new_outputs = (new_outputs,) + output = new_outputs + + def get_backward_pre_hook(backward_pre_hook, backward_post_hook): + @ThreadSafe.synchronized + def backward_pre_hook_fn(cell, grad_output): + backward_pre_hook(cell, grad_output) + if backward_post_hook: + backward_post_hook(cell, (), grad_output) + + return backward_pre_hook_fn + + backward_pre_hook = self._build_backward_pre_hook(Const.API, full_backward_name) + backward_post_hook = None if bw_hook else self._build_backward_hook(Const.API, full_backward_name) + + backward_pre_hook_dict = OrderedDict() + backward_pre_hook_dict[full_backward_name] = get_backward_pre_hook( + backward_pre_hook, + backward_post_hook + ) + MindsporeHookManager.cell_backward_pre_hook.append(backward_pre_hook_dict) + bw_pre_hook = inner.CellBackwardHook( + full_backward_name, + module, + MindsporeHookManager.cell_backward_pre_hook[-1] + ) + bw_pre_hook.register_backward_pre_hook() + + if is_backward_hook_output_a_view(): + result = bw_pre_hook(output) + else: + if isinstance(output, tuple): + result = bw_pre_hook(*output) + else: + result = bw_pre_hook(output) + if isinstance(output, tuple): + if len(output) == 1: + result = (result,) + if len(result) != len(output): + raise TypeError( + f"The backward pre hook return value size is {len(result)} " + f"not equal to output size {len(output)}" + ) + return result + + def _need_exchange(self, module): + if not hasattr(module, 'has_pre_hook_called') or not module.has_pre_hook_called: + return False + else: + return True + + def _get_params_dict(self, module): + params_dict = {} + if self.config.task != Const.STRUCTURE: + params_dict = { + key.split(Const.SEP)[-1]: value + for key, value in module.parameters_dict(recurse=False).items() + } + return params_dict + + def _build_backward_pre_hook(self, hook_type, full_name): + def backward_pre_hook(module, grad_input): + if self.config.level != Const.LEVEL_L2: + return + tid = threading.get_ident() + if not self._should_execute_hook(tid): + return + + with ThreadSafe(): + BaseHookManager.inner_switch[tid] = True + module_input = ModuleBackwardInputs(grad_input=grad_input) + self.data_collector.update_api_or_module_name(full_name) + self.data_collector.backward_input_data_collect(full_name, module, self._pid, module_input) + BaseHookManager.inner_switch[tid] = False + + return backward_pre_hook diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/primitive_hooks.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/primitive_hooks.py index 656e48c678956563a6f2d1d5f5ab8a4d03f074e7..71f218bd5f7dc08a6f0b6b2b190570009621f82b 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/primitive_hooks.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/primitive_hooks.py @@ -14,13 +14,17 @@ # limitations under the License. import os +import threading from mindspore import ops from mindspore.common.tensor import Tensor - -from msprobe.core.common.utils import Const, DumpException -from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputs, ModuleBackwardOutputs, - ModuleForwardInputsOutputs) +from msprobe.core.common.utils import Const, DumpException, ThreadSafe +from msprobe.core.data_dump.data_processor.base import ( + ModuleBackwardInputs, + ModuleBackwardOutputs, + ModuleForwardInputsOutputs +) +from msprobe.core.hook_manager import BaseHookManager from msprobe.mindspore.common.log import logger @@ -55,10 +59,13 @@ class PrimitiveHookService: callable: 反向 hook 函数。 """ + @ThreadSafe.synchronized def backward_hook(grad): + tid = threading.get_ident() + BaseHookManager.inner_switch[tid] = True + captured_grads.extend(grad) backward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.BACKWARD}" - try: if hook_type == Const.INPUT: self.service_instance.data_collector.update_api_or_module_name(backward_primitive_name) @@ -77,6 +84,7 @@ class PrimitiveHookService: logger.error(f"This is a primitive op {hook_type}_backward dump error: {exception}, " f"updated_primitive_name: {updated_primitive_name}") raise DumpException(DumpException.BACKWARD_DATA_COLLECTION_ERROR) from exception + BaseHookManager.inner_switch[tid] = False return backward_hook @@ -135,7 +143,12 @@ class PrimitiveHookService: return tuple(hooked_outputs) return out + @ThreadSafe.synchronized def pre_forward_hook(primitive_name, primitive_instance, args, kwargs): + tid = threading.get_ident() + BaseHookManager.inner_switch[tid] = True + + self.service_instance.data_collector.update_api_or_module_name(primitive_name) module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None) try: self.service_instance.data_collector.forward_input_data_collect( @@ -148,8 +161,14 @@ class PrimitiveHookService: logger.error(f"This is a primitive op dump error during forward input data collection: {exception}, " f"primitive_name: {primitive_name}") raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception + BaseHookManager.inner_switch[tid] = False + @ThreadSafe.synchronized def post_forward_hook(primitive_name, primitive_instance, args, kwargs, output): + tid = threading.get_ident() + BaseHookManager.inner_switch[tid] = True + + self.service_instance.data_collector.update_api_or_module_name(primitive_name) module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output) try: self.service_instance.data_collector.forward_output_data_collect( @@ -162,6 +181,7 @@ class PrimitiveHookService: logger.error(f"This is a primitive op dump error during forward output data collection: {exception}, " f"primitive_name: {primitive_name}") raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception + BaseHookManager.inner_switch[tid] = False def wrapped_primitive_call(instance_self, *args, **kwargs): """ @@ -179,7 +199,8 @@ class PrimitiveHookService: current_count = self.primitive_counters.get(primitive_name, 0) updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}{Const.SEP}{primitive_name}{Const.SEP}{current_count}" - if not self.service_instance.primitive_switch: + tid = threading.get_ident() + if not self.service_instance.primitive_switch or BaseHookManager.inner_switch[tid]: return origin_func(*args, **kwargs) captured_grads_input, captured_grads_output = [], [] @@ -192,8 +213,6 @@ class PrimitiveHookService: raise DumpException(DumpException.INPUT_HOOK_ERROR) from exception forward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.FORWARD}" - self.service_instance.data_collector.update_api_or_module_name(forward_primitive_name) - pre_forward_hook(forward_primitive_name, instance_self, hooked_inputs, kwargs) try: out = origin_func(*hooked_inputs, **kwargs) @@ -214,6 +233,7 @@ class PrimitiveHookService: return wrapped_primitive_call + @ThreadSafe.synchronized def update_primitive_counters(self, primitive_name): if primitive_name not in self.primitive_counters: self.primitive_counters[primitive_name] = 0 diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml index 723b0cbc93f78d50f703838eb488de6733008906..eae8f85a87fb2b0986cefb2e6faae7399a86f367 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml @@ -564,15 +564,15 @@ tensor: - all - amax - amin + - angle - any - arccos - arccosh - - argmax - - angle - arcsin - arcsinh - arctan - arctanh + - argmax - argmin - argsort - asin @@ -582,19 +582,23 @@ tensor: - atanh - baddbmm - bernoulli + - bfloat16 - bincount - bitwise_and - bitwise_or - bitwise_xor - bmm - bool + - bool astype - broadcast_to + - byte - ceil - - cholesky_solve - cholesky + - cholesky_solve - clamp - clip - conj + - copy - copysign - cos - cosh @@ -606,11 +610,13 @@ tensor: - deg2rad - diag - diagflat + - diagonal - diff - digamma - div - div_ - divide + - double - equal - erf - erfc @@ -618,13 +624,16 @@ tensor: - exp - expand_as - expm1 + - flatten - flip - fliplr - flipud + - float - float_power - floor - fmod - frac + - from_numpy - gather_elements - ge - geqrf @@ -648,12 +657,12 @@ tensor: - inner - int - inverse + - is_complex + - is_signed - isclose - isfinite - isinf - isnan - - is_complex - - is_signed - isneginf - isposinf - isreal @@ -704,28 +713,27 @@ tensor: - new_ones - new_zeros - nextafter - - norm - nonzero + - norm - not_equal - ormqr - permute - pow - prod - qr + - rad2deg - ravel - real - reciprocal - remainder - renorm - - rad2deg - - tile - repeat_interleave - reshape - reshape - - round + - resize - rot90 + - round - rsqrt - - sum_to_size - scatter - sgn - short @@ -745,7 +753,8 @@ tensor: - sub - sub_ - subtract - - subtract + - sum + - sum_to_size - svd - swapaxes - swapdims @@ -753,13 +762,13 @@ tensor: - take - tan - tanh - - trace - - swapaxes + - tensor_split - tile + - to - topk - - tril - - tensor_split + - trace - transpose + - tril - true_divide - trunc - unbind @@ -769,17 +778,6 @@ tensor: - view - where - xlogy - - from_numpy - - std - - take - - var - - all - - any - - copy - - diagonal - - flatten - - resize - - sum mint.ops: - abs @@ -1027,3 +1025,21 @@ communication.comm_func: - recv - isend - irecv + +mint.distributed: + - send + - recv + - broadcast + - all_reduce + - reduce + - all_gather + - gather + - isend + - irecv + - scatter + - reduce_scatter + - all_to_all_single + - all_to_all + - all_gather_into_tensor + - reduce_scatter_tensor + - batch_isend_irecv diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/wrap_api.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/wrap_api.py deleted file mode 100644 index 0e97929ecd7f8444b19fd531efc49883d0df58de..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/wrap_api.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -from mindspore import Tensor, mint, ops -from mindspore.common._stub_tensor import StubTensor -from mindspore.communication import comm_func -from mindspore.mint.nn import functional - -from msprobe.core.common.const import Const -from msprobe.core.common.file_utils import load_yaml -from msprobe.mindspore.common.const import Const as MsConst -from msprobe.mindspore.common.utils import is_mindtorch -from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell - -if is_mindtorch(): - import torch - import torch_npu - -cur_path = os.path.dirname(os.path.realpath(__file__)) -yaml_path = os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE) -torch_yaml_path = os.path.join(cur_path, "../../../pytorch/hook_module", MsConst.SUPPORTED_API_LIST_FILE) - - -class HOOKTensor(object): - pass - - -class HOOKStubTensor(object): - pass - - -class HOOKFunctionalOP(object): - pass - - -class HOOKMintOP(object): - pass - - -class HOOKMintNNFunctionalOP(object): - pass - - -class HOOKDistributedOP(object): - pass - - -class HOOKTorchOP(object): - pass - - -class HOOKTorchTensor(object): - pass - - -class HOOKTorchFunctionalOP(object): - pass - - -class HOOKTorchDistributedOP(object): - pass - - -class HOOKTorchNpuOP(object): - pass - - -class ApiTemplate(HOOKCell): - def __init__(self, api_name, api_dict, prefix, hook): - self.api_name = api_name - self.api_func = api_dict[api_name] - self.prefix_api_name = prefix + str(api_name.split(Const.SEP)[-1]) + Const.SEP - super().__init__(hook) - - @staticmethod - def async_to_sync(output): - # Fake handle, used to return after the CommHandle executes the wait method - fake_handle = type("FakeHandle", (), {"wait": lambda self: None})() - if isinstance(output, tuple) and len(output) == 2 and hasattr(output[1], "wait"): - output[1].wait() - output = (output[0], fake_handle) - elif hasattr(output, "wait"): - output.wait() - output = fake_handle - return output - - def construct(self, *args, **kwargs): - if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX): - return args[0] if args else kwargs.get(Const.INPUT) - - output = self.api_func(*args, **kwargs) - - if self.prefix_api_name.startswith(MsConst.DISTRIBUTED_DATA_PREFIX): - if kwargs.get("async_op") or self.api_name in ["isend", "irecv"]: - output = self.async_to_sync(output) - return output - - def forward(self, *args, **kwargs): - if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX): - return args[0] if args else kwargs.get(Const.INPUT) - return self.api_func(*args, **kwargs) - - -class WrapApiName: - def __init__(self, tensor_api_names, stub_tensor_api_names, ops_api_names, mint_api_names, mint_nn_func_api_names, - distributed_api_names): - self.tensor_api_names = tensor_api_names - self.stub_tensor_api_names = stub_tensor_api_names - self.ops_api_names = ops_api_names - self.mint_api_names = mint_api_names - self.mint_nn_func_api_names = mint_nn_func_api_names - self.distributed_api_names = distributed_api_names - - -class WrapTorchApiName: - def __init__(self, torch_api_names, tensor_api_names, functional_api_names, distributed_api_names, npu_api_names): - self.torch_api_names = torch_api_names - self.tensor_api_names = tensor_api_names - self.functional_api_names = functional_api_names - self.distributed_api_names = distributed_api_names - self.npu_api_names = npu_api_names - - -def get_wrap_api_list(): - api_list = load_yaml(yaml_path) - tensor_api = api_list.get(MsConst.SUPPORTED_TENSOR_LIST_KEY) - ops_api = api_list.get(MsConst.SUPPORTED_OPS_LIST_KEY) - mint_api = api_list.get(MsConst.SUPPORTED_MINT_LIST_KEY) - mint_nn_func_api = api_list.get(MsConst.SUPPORTED__MINT_NN_FUNC_LIST_KEY) - distributed_api = api_list.get(MsConst.SUPPORTED_COMM_LIST_KEY) - wrap_api_name = WrapApiName(set(tensor_api) & set(dir(Tensor)), - set(tensor_api) & set(dir(StubTensor)), - set(ops_api) & set(dir(ops)), - set(mint_api) & set(dir(mint)), - set(mint_nn_func_api) & set(dir(functional)), - set(distributed_api) & set(dir(comm_func))) - return wrap_api_name - - -def get_wrap_torch_api_list(): - api_list = load_yaml(torch_yaml_path) - torch_api = api_list.get("torch") - tensor_api = api_list.get("tensor") - functional_api = api_list.get("functional") - distributed_api = api_list.get("distributed") - npu_api = api_list.get("torch_npu") - wrap_api_name = WrapTorchApiName(set(torch_api) & set(dir(torch)), - set(tensor_api) & set(dir(torch.Tensor)), - set(functional_api) & set(dir(torch.nn.functional)), - set(distributed_api) & set(dir(torch.distributed)), - set(npu_api) & set(dir(torch_npu))) - return wrap_api_name - - -def wrap_api_func(api_name, api_dict, prefix, hook): - def api_function(*args, **kwargs): - return ApiTemplate(api_name, api_dict, prefix, hook)(*args, **kwargs) - return api_function - - -def wrap_api_func_and_bind(api_list, api_dict, prefix, hook, hook_class): - for api_name in api_list: - if callable(api_dict[api_name]): - setattr(hook_class, Const.ATTR_NAME_PREFIX + api_name, wrap_api_func(api_name, api_dict, prefix, hook)) - - -def setup_hooks(hook): - if is_mindtorch(): - torch_wrap_api_name = get_wrap_torch_api_list() - wrap_api_func_and_bind(torch_wrap_api_name.torch_api_names, - {f: getattr(torch, f) for f in dir(torch)}, - MsConst.TORCH_DATA_PREFIX, hook, HOOKTorchOP) - wrap_api_func_and_bind(torch_wrap_api_name.tensor_api_names, - {f: getattr(torch.Tensor, f) for f in dir(torch.Tensor)}, - MsConst.TENSOR_DATA_PREFIX, hook, HOOKTorchTensor) - wrap_api_func_and_bind(torch_wrap_api_name.functional_api_names, - {f: getattr(torch.nn.functional, f) for f in dir(torch.nn.functional)}, - MsConst.OPS_DATA_PREFIX, hook, HOOKTorchFunctionalOP) - wrap_api_func_and_bind(torch_wrap_api_name.distributed_api_names, - {f: getattr(torch.distributed, f) for f in dir(torch.distributed)}, - MsConst.DISTRIBUTED_DATA_PREFIX, hook, HOOKTorchDistributedOP) - wrap_api_func_and_bind(torch_wrap_api_name.npu_api_names, {f: getattr(torch_npu, f) for f in dir(torch_npu)}, - MsConst.TORCH_NPU_DATA_PREFIX, hook, HOOKTorchNpuOP) - return - - wrap_api_name = get_wrap_api_list() - wrap_api_func_and_bind(wrap_api_name.tensor_api_names, {f: getattr(Tensor, f) for f in dir(Tensor)}, - MsConst.TENSOR_DATA_PREFIX, hook, HOOKTensor) - wrap_api_func_and_bind(wrap_api_name.stub_tensor_api_names, {f: getattr(StubTensor, f) for f in dir(StubTensor)}, - MsConst.STUB_TENSOR_DATA_PREFIX, hook, HOOKStubTensor) - wrap_api_func_and_bind(wrap_api_name.ops_api_names, {f: getattr(ops, f) for f in dir(ops)}, - MsConst.OPS_DATA_PREFIX, hook, HOOKFunctionalOP) - wrap_api_func_and_bind(wrap_api_name.mint_api_names, {f: getattr(mint, f) for f in dir(mint)}, - MsConst.MINT_DATA_PREFIX, hook, HOOKMintOP) - wrap_api_func_and_bind(wrap_api_name.mint_nn_func_api_names, {f: getattr(functional, f) for f in dir(functional)}, - MsConst.MINT_NN_FUNC_DATA_PREFIX, hook, HOOKMintNNFunctionalOP) - wrap_api_func_and_bind(wrap_api_name.distributed_api_names, {f: getattr(comm_func, f) for f in dir(comm_func)}, - MsConst.DISTRIBUTED_DATA_PREFIX, hook, HOOKDistributedOP) diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py b/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py index 0a32200639a1f3805f815c37caaef5d3bb64c82f..5f4564394b76d3f5e86e5e0e3cbd13cd70b830a4 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py @@ -14,9 +14,13 @@ # limitations under the License. import os +import types from collections import defaultdict +import mindspore +from mindspore import nn from mindspore._c_expression import PyNativeExecutor_ + try: from mindspore.common.api import _MindsporeFunctionExecutor except ImportError: @@ -24,30 +28,33 @@ except ImportError: from msprobe.core.common.log import logger from msprobe.core.common.const import Const +from msprobe.core.common.utils import ThreadSafe +from msprobe.core.common.runtime import Runtime from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs -from msprobe.mindspore.dump.hook_cell.api_registry import api_register +from msprobe.mindspore.common.const import Const as MsConst +from msprobe.mindspore.dump.hook_cell.api_register import get_api_register + +_api_register = get_api_register() def dump_jit(name, in_feat, out_feat, is_forward): pid = os.getpid() - ori_args = str(name) - index = ori_args.find("<") - if index != 0 and index != -1: - result = ori_args[0:index] - elif name is not None and "<" not in str(name): - result = str(name) - else: - result = "JitFunction" - if JitDump.need_dump(): + name = name if name else "JitFunction" + if not JitDump.need_dump(): + return + with ThreadSafe(): if is_forward: - JitDump.jit_count[result] += 1 - name_template = (Const.JIT + Const.SEP + result + Const.SEP + - str(JitDump.jit_count[result]) + Const.SEP + Const.FORWARD) + if name in JitDump.jit_count: + JitDump.jit_count[name] += 1 + else: + JitDump.jit_count[name] = 0 + name_template = (Const.JIT + Const.SEP + name + Const.SEP + + str(JitDump.jit_count[name]) + Const.SEP + Const.FORWARD) JitDump.data_collector.update_api_or_module_name(name_template) module_input_output = ModuleForwardInputsOutputs(args=in_feat, kwargs={}, output=out_feat) JitDump.data_collector.forward_data_collect(name_template, None, pid, module_input_output) else: - name_template = Const.JIT + Const.SEP + result + Const.SEP + str(JitDump.jit_count[result]) + Const.SEP + \ + name_template = Const.JIT + Const.SEP + name + Const.SEP + str(JitDump.jit_count[name]) + Const.SEP + \ Const.BACKWARD JitDump.data_collector.update_api_or_module_name(name_template) module_input_output = ModuleBackwardInputsOutputs(grad_input=in_feat, grad_output=out_feat) @@ -57,7 +64,7 @@ def dump_jit(name, in_feat, out_feat, is_forward): class JitDump(_MindsporeFunctionExecutor): dump_config = None jit_enable = False - jit_dump_switch = True + jit_dump_switch = False jit_count = defaultdict(int) def __init__(self, *args, **kwargs): @@ -68,19 +75,17 @@ class JitDump(_MindsporeFunctionExecutor): self._executor = PyNativeExecutor_.get_instance() def __call__(self, *args, **kwargs): - if JitDump.jit_dump_switch: - api_register.api_set_ori_func() + _api_register.restore_all_api() out = super().__call__(*args, **kwargs) - if JitDump.jit_dump_switch and len(args) > 0: - if self.name and self.name != "construct": + if JitDump.jit_dump_switch and len(args) > 0 and self.name: + if self.name != "construct": dump_jit(self.name, args, out, True) - else: - dump_jit(args[0], args, out, True) + elif Runtime.run_mode != MsConst.PYNATIVE_GRAPH_MODE and isinstance(args[0], nn.Cell): + dump_jit(args[0].__class__.__name__, args, out, True) JitDump.jit_enable = True elif len(args) == 0: logger.warning(f"The jit function {self.name} has no input arguments, nothing will be dumped.") - if JitDump.jit_dump_switch: - api_register.api_set_hook_func() + _api_register.register_all_api() return out @classmethod @@ -101,9 +106,15 @@ class JitDump(_MindsporeFunctionExecutor): def grad(self, obj, grad, weights, grad_position, *args, **kwargs): if JitDump.jit_dump_switch and JitDump.jit_enable: - api_register.api_set_ori_func() - output = self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values())) + _api_register.restore_all_api() + if mindspore.__version__ >= "2.5": + output = self._executor.grad(grad, obj, weights, grad_position, False, *args, *(kwargs.values())) + else: + output = self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values())) if JitDump.jit_dump_switch and JitDump.jit_enable: - dump_jit(obj, args, None, False) - api_register.api_set_hook_func() + if isinstance(obj, types.FunctionType): + dump_jit(obj.__name__, args, None, False) + elif Runtime.run_mode != MsConst.PYNATIVE_GRAPH_MODE and isinstance(obj, nn.Cell): + dump_jit(obj.__class__.__name__, args, None, False) + _api_register.register_all_api() return output diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/kernel_kbyk_dump.py b/debug/accuracy_tools/msprobe/mindspore/dump/kernel_kbyk_dump.py index 2c46b0c73e7789ea41afb991bb985e089b2349cd..7e72e2bb5e8fa6a21b58b8aea709d534ebdd1cdd 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/kernel_kbyk_dump.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/kernel_kbyk_dump.py @@ -20,6 +20,9 @@ from msprobe.core.common.file_utils import create_directory, save_json from msprobe.mindspore.common.log import logger from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +import mindspore as ms +ms_version = ms.__version__ + class KernelKbykDump: COMMON_SETTINGS = "common_dump_settings" @@ -39,9 +42,20 @@ class KernelKbykDump: common_set["input_output"] = 0 common_set["kernels"] = [] common_set["support_device"] = [0, 1, 2, 3, 4, 5, 6, 7] - e2e_set = dict() - e2e_set["enable"] = True - e2e_set["trans_flag"] = True + common_set["statistic_category"] = [] + + if config.stat_cal_mode and config.device_stat_precision_mode: + e2e_set = { + "enable": not config.async_dump, + "trans_flag": True, + "stat_calc_mode": config.stat_cal_mode, + "device_stat_precision_mode": config.device_stat_precision_mode + } + else: + e2e_set = { + "enable": not config.async_dump, + "trans_flag": True + } if config.list: common_set["dump_mode"] = 1 @@ -61,10 +75,30 @@ class KernelKbykDump: common_set["input_output"] = 1 if config.data_mode[0] == Const.OUTPUT: common_set["input_output"] = 2 + if config.summary_mode: + if isinstance(config.summary_mode, str): + if config.summary_mode == Const.STATISTICS: + common_set["statistic_category"] = ["max", "min", "avg", "l2norm"] + else: + mode = self._process_hash(config.summary_mode) + common_set["statistic_category"] = [mode] + elif isinstance(config.summary_mode, list): + common_set["statistic_category"] = list({ + self._process_hash("avg" if mode == "mean" else mode) + for mode in config.summary_mode + }) self.dump_json[KernelKbykDump.COMMON_SETTINGS] = common_set self.dump_json[KernelKbykDump.E2E_SETTINGS] = e2e_set + @staticmethod + def _process_hash(value): + if ms_version <= "2.7.0" and (value == Const.HASH or value == Const.MD5): + value = "md5" + elif value == Const.MD5: + value = "hash:md5" + return value + def handle(self): json_path = self.dump_json[KernelKbykDump.COMMON_SETTINGS]["path"] create_directory(json_path) diff --git a/debug/accuracy_tools/msprobe/mindspore/dym_loader/hook_dynamic_loader.cc b/debug/accuracy_tools/msprobe/mindspore/dym_loader/hook_dynamic_loader.cc deleted file mode 100644 index b72d68741da491fc450c2d697a3ebfec895a3447..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +++ /dev/null @@ -1,140 +0,0 @@ -/** - * Copyright 2024 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "hook_dynamic_loader.h" -#include -#include -#include -#include "utils/log_adapter.h" - -namespace { - -// Utility function to check if a file path is valid -bool IsValidPath(const std::string &path) { - struct stat fileStat; - if (stat(path.c_str(), &fileStat) != 0) { - MS_LOG(ERROR) << "File does not exist or cannot be accessed: " << path; - return false; - } - - if (S_ISLNK(fileStat.st_mode)) { - MS_LOG(ERROR) << "File is a symbolic link, which is not allowed: " << path; - return false; - } - - if (!S_ISREG(fileStat.st_mode)) { - MS_LOG(ERROR) << "File is not a regular file: " << path; - return false; - } - - if (path.substr(path.find_last_of(".")) != ".so") { - MS_LOG(ERROR) << "File is not a .so file: " << path; - return false; - } - - return true; -} - -} // namespace - -HookDynamicLoader &HookDynamicLoader::GetInstance() { - static HookDynamicLoader instance; - return instance; -} - -bool HookDynamicLoader::loadFunction(void *handle, const std::string &functionName) { - void *func = dlsym(handle, functionName.c_str()); - if (!func) { - MS_LOG(WARNING) << "Could not load function: " << functionName << ", error: " << dlerror(); - return false; - } - funcMap_[functionName] = func; - return true; -} - -bool HookDynamicLoader::validateLibraryPath(const std::string &libPath) { - char *realPath = realpath(libPath.c_str(), nullptr); - if (!realPath) { - MS_LOG(WARNING) << "Failed to resolve realpath for the library: " << libPath; - return false; - } - - bool isValid = IsValidPath(realPath); - free(realPath); // Free memory allocated by realpath - return isValid; -} - -bool HookDynamicLoader::LoadLibrary() { - const char *libPath = std::getenv("HOOK_TOOL_PATH"); - if (!libPath) { - MS_LOG(WARNING) << "HOOK_TOOL_PATH is not set!"; - return false; - } - - std::string resolvedLibPath(libPath); - if (!validateLibraryPath(resolvedLibPath)) { - MS_LOG(WARNING) << "Library path validation failed."; - return false; - } - - std::lock_guard lock(mutex_); - if (handle_) { - MS_LOG(WARNING) << "Hook library already loaded!"; - return false; - } - - handle_ = dlopen(resolvedLibPath.c_str(), RTLD_LAZY | RTLD_LOCAL); - if (!handle_) { - MS_LOG(WARNING) << "Failed to load Hook library: " << dlerror(); - return false; - } - - for (const auto &functionName : functionList_) { - if (!loadFunction(handle_, functionName)) { - MS_LOG(WARNING) << "Failed to load function: " << functionName; - dlclose(handle_); - handle_ = nullptr; - return false; - } - } - - MS_LOG(INFO) << "Hook library loaded successfully."; - return true; -} - -bool HookDynamicLoader::UnloadLibrary() { - std::lock_guard lock(mutex_); - if (!handle_) { - MS_LOG(WARNING) << "Hook library hasn't been loaded."; - return false; - } - - dlclose(handle_); - handle_ = nullptr; - funcMap_.clear(); - MS_LOG(INFO) << "Library unloaded successfully."; - return true; -} - -void *HookDynamicLoader::GetHooker(const std::string &funcName) { - std::lock_guard lock(mutex_); - auto iter = funcMap_.find(funcName); - if (iter == funcMap_.end()) { - MS_LOG(WARNING) << "Function not found: " << funcName; - return nullptr; - } - return iter->second; -} diff --git a/debug/accuracy_tools/msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp b/debug/accuracy_tools/msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a53272f03fc8747716e570a4e76e7cd582b011da --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp @@ -0,0 +1,111 @@ +/* + * Copyright (C) 2024-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "hook_dynamic_loader.h" +#include +#include +#include +#include +#include "utils/log_adapter.h" + +namespace py = pybind11; + +HookDynamicLoader &HookDynamicLoader::GetInstance() +{ + static HookDynamicLoader instance; + return instance; +} + +bool HookDynamicLoader::LoadFunction(void *handle, const std::string &functionName) +{ + void *func = dlsym(handle, functionName.c_str()); + if (!func) { + MS_LOG(WARNING) << "Could not load function: " << functionName << ", error: " << dlerror(); + return false; + } + funcMap_[functionName] = func; + return true; +} + +bool HookDynamicLoader::LoadLibrary() +{ + std::string msprobePath = ""; + // 获取gil锁 + py::gil_scoped_acquire acquire; + try { + py::module msprobeMod = py::module::import("msprobe.lib._msprobe_c"); + if (!py::hasattr(msprobeMod, "__file__")) { + MS_LOG(WARNING) << "Adump mod not found"; + return false; + } + msprobePath = msprobeMod.attr("__file__").cast(); + } catch (const std::exception& e) { + MS_LOG(WARNING) << "Adump mod path unable to get: " << e.what(); + return false; + } + std::lock_guard lock(mutex_); + if (handle_) { + MS_LOG(WARNING) << "Hook library already loaded!"; + return false; + } + if (msprobePath == "") { + MS_LOG(WARNING) << "Adump path not loaded"; + return false; + } + handle_ = dlopen(msprobePath.c_str(), RTLD_LAZY | RTLD_LOCAL); + if (!handle_) { + MS_LOG(WARNING) << "Failed to load Hook library: " << dlerror(); + return false; + } + + for (const auto &functionName : functionList_) { + if (!LoadFunction(handle_, functionName)) { + MS_LOG(WARNING) << "Failed to load adump function"; + dlclose(handle_); + handle_ = nullptr; + return false; + } + } + + MS_LOG(INFO) << "Hook library loaded successfully."; + return true; +} + +bool HookDynamicLoader::UnloadLibrary() +{ + std::lock_guard lock(mutex_); + if (!handle_) { + MS_LOG(WARNING) << "Hook library hasn't been loaded."; + return false; + } + + dlclose(handle_); + handle_ = nullptr; + funcMap_.clear(); + MS_LOG(INFO) << "Library unloaded successfully."; + return true; +} + +void *HookDynamicLoader::GetHooker(const std::string &funcName) +{ + std::lock_guard lock(mutex_); + auto iter = funcMap_.find(funcName); + if (iter == funcMap_.end()) { + MS_LOG(WARNING) << "Function not found: " << funcName; + return nullptr; + } + return iter->second; +} diff --git a/debug/accuracy_tools/msprobe/mindspore/dym_loader/hook_dynamic_loader.h b/debug/accuracy_tools/msprobe/mindspore/dym_loader/hook_dynamic_loader.h index 6309e60b662a03d7f77cb450986ded5329fd8960..20c987715a411f2c622f720997239c499a2983cd 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dym_loader/hook_dynamic_loader.h +++ b/debug/accuracy_tools/msprobe/mindspore/dym_loader/hook_dynamic_loader.h @@ -1,5 +1,5 @@ -/** - * Copyright 2024 Huawei Technologies Co., Ltd +/* + * Copyright (C) 2024-2025. Huawei Technologies Co., Ltd. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,27 +27,26 @@ constexpr auto kHookBegin = "MS_DbgOnStepBegin"; constexpr auto kHookEnd = "MS_DbgOnStepEnd"; class HookDynamicLoader { - public: - static HookDynamicLoader &GetInstance(); +public: + static HookDynamicLoader &GetInstance(); - HookDynamicLoader(const HookDynamicLoader &) = delete; - HookDynamicLoader &operator=(const HookDynamicLoader &) = delete; + HookDynamicLoader(const HookDynamicLoader &) = delete; + HookDynamicLoader &operator=(const HookDynamicLoader &) = delete; - bool LoadLibrary(); - bool UnloadLibrary(); - void *GetHooker(const std::string &funcName); + bool LoadLibrary(); + bool UnloadLibrary(); + void *GetHooker(const std::string &funcName); - private: - // Helper functions - bool loadFunction(void *handle, const std::string &functionName); - bool validateLibraryPath(const std::string &libPath); +private: + // Helper functions + bool LoadFunction(void *handle, const std::string &functionName); - HookDynamicLoader() = default; + HookDynamicLoader() = default; - void *handle_ = nullptr; - std::vector functionList_ = {kHookBegin, kHookEnd}; - std::map funcMap_; - std::mutex mutex_; + void *handle_ = nullptr; + std::vector functionList_ = {kHookBegin, kHookEnd}; + std::map funcMap_; + std::mutex mutex_; }; #endif // HOOK_DYNAMIC_LOADER_H diff --git a/debug/accuracy_tools/msprobe/mindspore/exception_dump/__init__.py b/debug/accuracy_tools/msprobe/mindspore/exception_dump/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msprobe/mindspore/exception_dump/exception_dump_tool_factory.py b/debug/accuracy_tools/msprobe/mindspore/exception_dump/exception_dump_tool_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..db55811de80c4abfa738e50610b5582984b9e08d --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/exception_dump/exception_dump_tool_factory.py @@ -0,0 +1,51 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprobe.core.common.log import logger +from msprobe.mindspore.common.const import Const +from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +from msprobe.mindspore.exception_dump.kernel_graph_exception_dump import KernelGraphExceptionDump + + +class ExceptionDumpToolFactory: + tools = { + Const.CELL: { + Const.GRAPH_KBYK_MODE: None, + Const.GRAPH_GE_MODE: None, + Const.PYNATIVE_MODE: None + }, + Const.API: { + Const.GRAPH_KBYK_MODE: None, + Const.GRAPH_GE_MODE: None, + Const.PYNATIVE_MODE: None + }, + Const.KERNEL: { + Const.GRAPH_KBYK_MODE: KernelGraphExceptionDump, + Const.GRAPH_GE_MODE: None, + Const.PYNATIVE_MODE: KernelGraphExceptionDump + } + } + + @staticmethod + def create(config: DebuggerConfig): + tool = ExceptionDumpToolFactory.tools.get(config.level) + if not tool: + raise Exception("Valid level is needed.") + tool = tool.get(config.execution_mode) + if not tool: + logger.error(f"Exception dump is not supported in {config.execution_mode} mode " + f"when level is {config.level}.") + raise ValueError + return (tool(config),) diff --git a/debug/accuracy_tools/msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py b/debug/accuracy_tools/msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py new file mode 100644 index 0000000000000000000000000000000000000000..d9c4e6f72c129dccba86ff85319d7fdc6a139225 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py @@ -0,0 +1,57 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from msprobe.core.common.file_utils import create_directory, save_json +from msprobe.mindspore.common.log import logger +from msprobe.mindspore.debugger.debugger_config import DebuggerConfig + + +class KernelGraphExceptionDump: + + def __init__(self, config: DebuggerConfig): + self.dump_json = dict() + self.dump_json["common_dump_settings"] = dict() + self.dump_json["common_dump_settings"]["dump_mode"] = 0 + self.dump_json["common_dump_settings"]["path"] = "" + self.dump_json["common_dump_settings"]["net_name"] = "Net" + self.dump_json["common_dump_settings"]["iteration"] = "all" + self.dump_json["common_dump_settings"]["saved_data"] = "tensor" + self.dump_json["common_dump_settings"]["input_output"] = 0 + self.dump_json["common_dump_settings"]["kernels"] = [] + self.dump_json["common_dump_settings"]["support_device"] = [0, 1, 2, 3, 4, 5, 6, 7] + self.dump_json["common_dump_settings"]["op_debug_mode"] = 4 + self.dump_json["common_dump_settings"]["file_format"] = "npy" + self.dump_json["e2e_dump_settings"] = dict() + self.dump_json["e2e_dump_settings"]["enable"] = not config.async_dump + self.dump_json["e2e_dump_settings"]["trans_flag"] = True + + if config.stat_cal_mode and config.device_stat_precision_mode: + self.dump_json["e2e_dump_settings"]["stat_calc_mode"] = config.stat_cal_mode + self.dump_json["e2e_dump_settings"]["device_stat_precision_mode"] = config.device_stat_precision_mode + self.dump_json["common_dump_settings"]["path"] = config.dump_path + if len(config.step) > 0: + logger.warning("Step would change to all in this task.") + if len(config.rank) > 0: + self.dump_json["common_dump_settings"]["support_device"] = config.rank + + def handle(self): + json_path = self.dump_json["common_dump_settings"]["path"] + create_directory(json_path) + json_path = os.path.join(json_path, "kernel_graph_exception_check.json") + save_json(json_path, self.dump_json, indent=4) + logger.info(json_path + " has been created.") + os.environ["MINDSPORE_DUMP_CONFIG"] = json_path diff --git a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/api_pynative_self_check.py b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/api_pynative_self_check.py index 57b7de4fa567d73a19178256d79f5e4cbeb38864..8f77caabde3d5481cf034d41a48efbe86119848a 100644 --- a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/api_pynative_self_check.py +++ b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/api_pynative_self_check.py @@ -16,25 +16,30 @@ import functools import importlib import os +import threading import traceback import mindspore as ms + from msprobe.core.common.const import Const from msprobe.core.common.exceptions import DistributedNotInitializedError from msprobe.core.common.file_utils import check_path_length, load_yaml +from msprobe.core.common.runtime import Runtime +from msprobe.core.hook_manager import HookSet from msprobe.mindspore.common.const import Const as MsConst from msprobe.mindspore.common.const import FreeBenchmarkConst from msprobe.mindspore.common.log import logger from msprobe.mindspore.common.utils import get_rank_if_initialized from msprobe.mindspore.debugger.debugger_config import DebuggerConfig -from msprobe.mindspore.dump.hook_cell.api_registry import api_register +from msprobe.mindspore.dump.hook_cell.api_register import get_api_register from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell from msprobe.mindspore.free_benchmark.common.config import Config from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams from msprobe.mindspore.free_benchmark.common.utils import Tools from msprobe.mindspore.free_benchmark.handler.handler_factory import HandlerFactory from msprobe.mindspore.free_benchmark.perturbation.perturbation_factory import PerturbationFactory -from msprobe.mindspore.runtime import Runtime + +_api_register = get_api_register() class ApiPyNativeSelfCheck: @@ -60,8 +65,8 @@ class ApiPyNativeSelfCheck: self.store_original_func() def handle(self): - api_register.initialize_hook(self.build_hook) - api_register.api_set_hook_func() + _api_register.initialize_hook(self.build_hook) + _api_register.register_all_api() def build_hook(self, api_name): def pre_hook(cell, input_data): @@ -69,9 +74,10 @@ class ApiPyNativeSelfCheck: def forward_hook(api_name_with_id, cell, input_data, output_data): ret = None + tid = threading.get_ident() if not need_wrapper_func(): - del cell.input_kwargs + del cell.msprobe_input_kwargs[tid] return ret api_name_with_id = api_name_with_id[:-1] @@ -80,9 +86,9 @@ class ApiPyNativeSelfCheck: api_name_with_id[api_name_with_id.find(Const.SEP) + 1:api_name_with_id.rfind(Const.SEP)]) if api_name in self.api_list: ret = check_self(api_name_with_id, output_data, self.ori_func.get(api_name), - *input_data, **cell.input_kwargs) + *input_data, **cell.msprobe_input_kwargs[tid]) - del cell.input_kwargs + del cell.msprobe_input_kwargs[tid] return ret def backward_hook(cell, grad_input, grad_output): @@ -101,8 +107,13 @@ class ApiPyNativeSelfCheck: def pre_backward_hook(cell, grad_input): return None - - return pre_hook, wrap_forward_hook, wrap_backward_hook, pre_backward_hook + + return HookSet( + forward_hook=wrap_forward_hook, + forward_pre_hook=pre_hook, + backward_hook=wrap_backward_hook, + backward_pre_hook=pre_backward_hook + ) def store_original_func(self): for api_name in self.api_list: @@ -166,13 +177,13 @@ def check_self(api_name_with_id, output, ori_func, *args, **kwargs): return ret logger.info(f"[{api_name_with_id}] is {Config.handler_type}ing.") - api_register.api_set_ori_func() + _api_register.restore_all_api() try: perturbation = PerturbationFactory.create(api_name_with_id) params.fuzzed_result = perturbation.handle(params) if params.fuzzed_result is False: - api_register.api_set_hook_func() + _api_register.register_all_api() return ret if Config.stage == Const.BACKWARD: params.original_result = Tools.get_grad(params.original_func, *params.args, **params.kwargs) @@ -183,7 +194,7 @@ def check_self(api_name_with_id, output, ori_func, *args, **kwargs): logger.error(f"[{api_name_with_id}] Error: {str(e)}") logger.error(f"[{api_name_with_id}] Error detail: {traceback.format_exc()}") - api_register.api_set_hook_func() + _api_register.register_all_api() return ret diff --git a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/common/utils.py b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/common/utils.py index 14a72a5e6b6a6289595897a15c46a0e6397bcd1a..c3f9b27fe2b5792119f7105955cdecfa6bdc51d4 100644 --- a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/common/utils.py +++ b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/common/utils.py @@ -19,10 +19,10 @@ from typing import Any, Optional import mindspore as ms from mindspore import Tensor, ops +from msprobe.core.common.runtime import Runtime from msprobe.mindspore.common.const import FreeBenchmarkConst from msprobe.mindspore.free_benchmark.common.config import Config from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams -from msprobe.mindspore.runtime import Runtime class Tools: diff --git a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py index 3fd1430bff792d5043429caac8fe477e457b8bee..39ca164f2043c5d8f6d2e05987edfffe5bca2bee 100644 --- a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +++ b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,6 +14,7 @@ # limitations under the License. from msprobe.mindspore.common.const import FreeBenchmarkConst +from msprobe.mindspore.common.log import logger from msprobe.mindspore.free_benchmark.common.config import Config from msprobe.mindspore.free_benchmark.perturbation.add_noise import AddNoisePerturbation from msprobe.mindspore.free_benchmark.perturbation.bit_noise import BitNoisePerturbation @@ -41,4 +42,5 @@ class PerturbationFactory: if perturbation: return perturbation(api_name_with_id) else: - raise Exception(f'{Config.pert_type} is a invalid perturbation type') + logger.error(f'{Config.pert_type} is a invalid perturbation type') + raise ValueError diff --git a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/self_check_tool_factory.py b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/self_check_tool_factory.py index 35b5eb2ab65511fa4320dc97702a60a9c8d07f62..b21b15d1758a90e62861c7edf2976d38ab43c5f0 100644 --- a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/self_check_tool_factory.py +++ b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/self_check_tool_factory.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,6 +14,7 @@ # limitations under the License. from msprobe.mindspore.common.const import Const +from msprobe.core.common.log import logger from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.free_benchmark.api_pynative_self_check import ApiPyNativeSelfCheck @@ -41,8 +42,10 @@ class SelfCheckToolFactory: def create(config: DebuggerConfig): tool = SelfCheckToolFactory.tools.get(config.level) if not tool: - raise Exception(f"{config.level} is not supported.") + logger.error(f"{config.level} is not supported.") + raise ValueError tool = tool.get(config.execution_mode) if not tool: - raise Exception(f"Task free_benchmark is not supported in this mode: {config.execution_mode}.") + logger.error(f"Task free_benchmark is not supported in this mode: {config.execution_mode}.") + raise ValueError return tool(config) diff --git a/debug/accuracy_tools/msprobe/mindspore/grad_probe/global_context.py b/debug/accuracy_tools/msprobe/mindspore/grad_probe/global_context.py index 01e46e019a4d1634a4592970386d855637c34e8f..0215439aa642716e7d03175dc5ce6e3da032df17 100644 --- a/debug/accuracy_tools/msprobe/mindspore/grad_probe/global_context.py +++ b/debug/accuracy_tools/msprobe/mindspore/grad_probe/global_context.py @@ -16,6 +16,7 @@ import os import threading from typing import Dict, Union, Tuple +import time from msprobe.core.common.utils import is_int from msprobe.core.common.file_utils import create_directory, check_path_before_create @@ -40,8 +41,12 @@ class GlobalContext: def __new__(cls, *args, **kwargs): if cls._instance is None: cls._instance_lock.acquire() - cls._instance = object.__new__(cls) - cls._instance_lock.release() + try: + cls._instance = object.__new__(cls) + except Exception as e: + raise RuntimeError("grad_probe global context init failed") from e + finally: + cls._instance_lock.release() return cls._instance def init_context(self, config_dict: Dict): @@ -69,6 +74,8 @@ class GlobalContext: else: logger.warning("The output_path exists, the data will be covered.") + self._setting[GradConst.TIME_STAMP] = str(int(time.time())) + def get_context(self, key: str): if key not in self._setting: logger.warning(f"Unrecognized {key}.") diff --git a/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_analyzer.py b/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_analyzer.py index 8a154f4d65f63e55f6b0cf3165d3c905bcb68546..79f56436239e13a14b5826a693f634f08e157926 100644 --- a/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_analyzer.py +++ b/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_analyzer.py @@ -111,7 +111,8 @@ class CSVGenerator(Process): output_path = context.get_context(GradConst.OUTPUT_PATH) self.level = context.get_context(GradConst.LEVEL) self.bounds = context.get_context(GradConst.BOUNDS) - self.dump_dir = f"{output_path}/rank{rank_id}/Dump/" + time_stamp = context.get_context(GradConst.TIME_STAMP) + self.dump_dir = f"{output_path}/rank{rank_id}/Dump{time_stamp}/" self.save_dir = f"{output_path}/rank{rank_id}/" self.current_step = None self.stop_event = multiprocessing.Event() @@ -244,6 +245,8 @@ class CSVGenerator(Process): return ["Max", "Min", "Norm", "Shape"] def get_dist_header(self) -> List[str]: + if not self.bounds: + return [] intervals = [] for i, _ in enumerate(self.bounds): if i == 0: diff --git a/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_stat_csv.py b/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_stat_csv.py index 9cc30ea1b9d6575bcd5af94c27f19cb93ed7246d..820f7f21d0cd5e6b1fd98f93f1515515407358c0 100644 --- a/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_stat_csv.py +++ b/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_stat_csv.py @@ -15,6 +15,7 @@ import hashlib from abc import ABC, abstractmethod +import zlib import mindspore from mindspore import ops @@ -76,8 +77,8 @@ class CsvMd5(CsvItem): def generate_csv_content(csv_input): grad = csv_input.grad tensor_bytes = grad.float().numpy().tobytes() - md5_hash = hashlib.md5(tensor_bytes) - return [md5_hash.hexdigest()] + md5_hash = f"{zlib.crc32(tensor_bytes):08x}" + return [md5_hash] @register_csv_item(GradConst.DISTRIBUTION) diff --git a/debug/accuracy_tools/msprobe/mindspore/grad_probe/hook.py b/debug/accuracy_tools/msprobe/mindspore/grad_probe/hook.py index 1aa9fcfad10815d5845de66ab0ea6d4d7211741f..36857636fa301db37ae4267f8e18d41d9f0328a5 100644 --- a/debug/accuracy_tools/msprobe/mindspore/grad_probe/hook.py +++ b/debug/accuracy_tools/msprobe/mindspore/grad_probe/hook.py @@ -49,12 +49,10 @@ class HookInput: self.param_list = grad_context.get_context(GradConst.PARAM_LIST) self.rank_id = get_rank_id() output_path = grad_context.get_context(GradConst.OUTPUT_PATH) - self.dump_dir = os.path.join(output_path, f"rank{self.rank_id}", "Dump") + time_stamp = grad_context.get_context(GradConst.TIME_STAMP) + self.dump_dir = os.path.join(output_path, f"rank{self.rank_id}", f"Dump{time_stamp}") self.save_dir = os.path.join(output_path, f"rank{self.rank_id}") self.step_finish_flag = os.path.join(self.dump_dir, GradConst.STEP_FINISH) - if os.path.exists(self.save_dir): - logger.warning(f"Delete existing path {self.save_dir}.") - remove_path(self.save_dir) self.level = grad_context.get_context(GradConst.LEVEL) self.bounds = grad_context.get_context(GradConst.BOUNDS) self.mode = mindspore.get_context("mode") diff --git a/debug/accuracy_tools/msprobe/mindspore/mindspore_service.py b/debug/accuracy_tools/msprobe/mindspore/mindspore_service.py new file mode 100644 index 0000000000000000000000000000000000000000..29e15272d88a10385d575a41b19bf2712005649f --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/mindspore_service.py @@ -0,0 +1,114 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +import mindspore as ms +from mindspore.ops.primitive import Primitive + +from msprobe.core.common.utils import Const +from msprobe.core.service import BaseService +from msprobe.mindspore.cell_processor import CellProcessor +from msprobe.mindspore.common.log import logger +from msprobe.mindspore.common.utils import ( + get_rank_if_initialized, + is_mindtorch, + get_cells_and_names_with_index +) +from msprobe.mindspore.dump.hook_cell.api_register import get_api_register, ApiTemplate +from msprobe.mindspore.dump.hook_cell.ms_hook_manager import MindsporeHookManager +from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService +from msprobe.mindspore.dump.jit_dump import JitDump + +try: + from mindspore.common._pijit_context import PIJitCaptureContext +except ImportError: + pijit_label = False +else: + pijit_label = True + + +class MindsporeService(BaseService): + @property + def _get_framework_type(self): + return Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK + + @staticmethod + def _get_current_rank(): + return get_rank_if_initialized() + + def empty(self, *args, **kwargs): + pass + + def reset_status(self): + self._reset_status() + + def _init_specific_components(self): + self.logger = logger + self.api_register = get_api_register() + self.primitive_hook_service = PrimitiveHookService(self) + self.cell_processor = CellProcessor(self.data_collector.scope) + self.hook_manager = MindsporeHookManager(self.data_collector, self.config) + self._setup_jit_context() + self.api_template = ApiTemplate + + def _setup_jit_context(self): + if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]: + JitDump.set_config(self.config) + JitDump.set_data_collector(self.data_collector) + if hasattr(ms.common.api, "_MindsporeFunctionExecutor"): + ms.common.api._MindsporeFunctionExecutor = JitDump + else: + ms.common.api._JitExecutor = JitDump + ms.common.api._PyNativeExecutor.grad = JitDump.grad + if pijit_label: + PIJitCaptureContext.__enter__ = self.empty + PIJitCaptureContext.__exit__ = self.empty + + def _register_module_hook(self): + self.cell_processor.register_cell_hook(self.model, self.build_hook, self.config) + self.logger.info_on_rank_0(f"The module {self.config.task} hook function is successfully mounted to the model.") + + def _register_hook(self): + self._register_primitive_hook() + + def _register_primitive_hook(self): + if self.config.level not in [Const.LEVEL_MIX, Const.LEVEL_L1]: + return + if not self.model or self.config.task not in Const.DUMP_DATA_COLLECTION_LIST: + return + + primitive_set = set() + cells_and_names_with_index, _ = get_cells_and_names_with_index(self.model) + for cells_and_names in cells_and_names_with_index.values(): + for _, cell in cells_and_names: + for attribute, value in vars(cell).items(): + if isinstance(value, Primitive): + primitive_set.add((attribute, value)) + + for pname, primitive in primitive_set: + primitive_class_name = primitive.__class__.__name__ + primitive_combined_name = pname + Const.SEP + primitive_class_name + new_primitive = type('NewPrimitive', (primitive.__class__,), + {'__call__': self.primitive_hook_service.wrap_primitive(primitive.__call__, + primitive_combined_name)}) + primitive.__class__ = new_primitive + + def _reset_status(self): + super()._reset_status() + self.primitive_hook_service.primitive_counters.clear() + JitDump.jit_count = defaultdict(int) + + def _change_jit_switch(self, status): + JitDump.jit_dump_switch = status diff --git a/debug/accuracy_tools/msprobe/mindspore/mindtorch/mindtorch_adaptor.py b/debug/accuracy_tools/msprobe/mindspore/mindtorch/mindtorch_adaptor.py index 27e42d52ba6190ec7e7531af25464e6aa3996b2b..7ca256dbcf8f9011f5ca84898f643888ea7f890e 100644 --- a/debug/accuracy_tools/msprobe/mindspore/mindtorch/mindtorch_adaptor.py +++ b/debug/accuracy_tools/msprobe/mindspore/mindtorch/mindtorch_adaptor.py @@ -93,6 +93,8 @@ from torch.nn.modules.module import (_global_backward_pre_hooks, _global_backwar _global_forward_hooks, _global_forward_hooks_always_called) from torch.utils.hooks import RemovableHandle +from msprobe.mindspore.common.utils import is_backward_hook_output_a_view + def _call_impl(self, *args, **kwargs): forward_call = self.forward @@ -245,11 +247,14 @@ def _get_backward_hooks(self): def apply_backward_hook_on_tensors(cell_backward_hook, args): - is_tuple = True - if not isinstance(args, tuple): - args = (args,) - is_tuple = False - hooked_args = cell_backward_hook(*args) - if is_tuple and len(args) == 1: - hooked_args = (hooked_args, ) + if is_backward_hook_output_a_view(): + hooked_args = cell_backward_hook(args) + else: + is_tuple = True + if not isinstance(args, tuple): + args = (args,) + is_tuple = False + hooked_args = cell_backward_hook(*args) + if is_tuple and len(args) == 1: + hooked_args = (hooked_args, ) return hooked_args diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/common_func.py b/debug/accuracy_tools/msprobe/mindspore/monitor/common_func.py new file mode 100644 index 0000000000000000000000000000000000000000..5880d2284f33f8a9a406cabb298e921e85e6b6b5 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/common_func.py @@ -0,0 +1,52 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from mindspore import nn +from mindspore import communication +from msprobe.core.common.log import logger +from msprobe.mindspore.common.utils import is_mindtorch +if is_mindtorch(): + import torch + + +def is_valid_instance(model): + return isinstance(model, torch.nn.Module) if is_mindtorch() else isinstance(model, nn.Cell) + + +def get_submodules(model): + if not is_valid_instance(model): + logger.info("Counter invalid model, nothing to hook") + return {} + return model.named_modules() if is_mindtorch() else model.cells_and_names() + + +def get_parameters(model): + if not is_valid_instance(model): + return {} + if is_mindtorch(): + return model.named_parameters() + else: + return model.parameters_and_names() + + +def get_rank(): + if comm_is_initialized(): + return communication.get_rank() + return 0 + + +def comm_is_initialized(): + return communication.GlobalComm.INITED diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/anomaly_detect.py b/debug/accuracy_tools/msprobe/mindspore/monitor/data_writers.py similarity index 42% rename from debug/accuracy_tools/msprobe/mindspore/monitor/anomaly_detect.py rename to debug/accuracy_tools/msprobe/mindspore/monitor/data_writers.py index 3544ebbd025614349585bc799b15e00a5c2c7956..85c1096123c337a123b16f18236655bfe6e49c5e 100644 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/anomaly_detect.py +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/data_writers.py @@ -15,91 +15,20 @@ import itertools import os -import sys -import statistics as st -from abc import ABC -from dataclasses import dataclass, field -from typing import List +from dataclasses import dataclass from collections import defaultdict import pandas as pd - from mindspore import ops +from mindspore import Tensor from mindspore import _no_grad + from msprobe.core.common.log import logger from msprobe.core.common.file_utils import change_mode, create_directory, write_df_to_csv +from msprobe.core.monitor.anomaly_processor import AnomalyDataFactory, AnomalyTurbulence, AnomalyScanner from msprobe.core.common.const import FileCheckConst, MonitorConst -class ScanRule(ABC): - name = "ScanRule" - - def apply(self, history, cur): - raise NotImplementedError("abstract method apply is not implemented") - - -class AnomalyTurbulence(ScanRule): - name = "AnomalyTurbulence" - - def __init__(self, threshold) -> None: - self.threshold = threshold - - def apply(self, history, cur): - baseline = st.mean(history) if isinstance(history, list) else history - - up_bound = baseline + baseline * self.threshold - if baseline > 0: - return cur > up_bound - else: - return cur < up_bound - - -class AnomalyScanner: - - @staticmethod - def load_rules(specs: List[dict]): - """ - specs: [{"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.5}}] - """ - if specs is None: - return [] - alert_rules = [] - for spec in specs: - # 使用get方法获取键值,如果键不存在则返回None - rule_cls_name = spec.get("rule_name") - rule_args = spec.get("args") - - # 检查必要的键是否存在 - if rule_cls_name is None or rule_args is None: - logger.warning(f"Spec is missing required keys: {spec}") - continue - - cur_module = sys.modules.get(__name__) - try: - rule_cls = getattr(cur_module, rule_cls_name) - except AttributeError: - logger.error(f"Rule class '{rule_cls_name}' not found in the current module.") - continue - - try: - rule_instance = rule_cls(**rule_args) - alert_rules.append(rule_instance) - except Exception as e: - logger.error(f"Error creating instance of rule '{rule_cls_name}': {e}") - continue - - return alert_rules - - @staticmethod - def scan(scan_rules: List[ScanRule], history, cur): - anomaly = False - for rule in scan_rules: - anomaly = rule.apply(history, cur) - if anomaly: - return anomaly, rule.name - return anomaly, None - - class BCOLORS: HEADER = '\033[95m' OKBLUE = '\033[94m' @@ -112,130 +41,6 @@ class BCOLORS: UNDERLINE = '\033[4m' -class AnomalyDataFactory(ABC): - def __init__(self, rank, pp_stage, group_mates): - super().__init__() - self.rank = rank - self.pp_stage = pp_stage - self.group_mates = group_mates - self.micro_step = 0 - self.name2callid = {} - - def set_call_id(self, name2callid): - """根据当前GradContext信息更新call_id vpp_stage等信息 - """ - self.name2callid = name2callid - - def create(self, tag, message, step): - """如果检查出异常, 调用当前接口生成GradAnomalyData实例 - tag (tuple): metric tag ('0:1.post_attention_norm.weight/rank0/pre_grad', 'min') - message (str): anomaly detect message - step (int): training step - """ - if not isinstance(tag, tuple) or len(tag) != 2: - raise ValueError("tag must be a tuple with length 2") - tag_name = tag[0] - param_name = tag_name.split('/')[0] - call_id = self.name2callid.get(tag_name, -1) - if MonitorConst.NAME_SEP in param_name: - vpp_stage = int(param_name.split(MonitorConst.NAME_SEP)[0]) - else: - vpp_stage = 0 - - return GradAnomalyData( - self.rank, - step, - self.micro_step, - self.pp_stage, - vpp_stage, - call_id, - tag_name, - message, - self.group_mates - ) - - -class TrainStage: - DEFAULT_STAGE = -1 - FORWARD_STAGE = 0 - BACKWARD_STAGE = 1 - OPTIMIZER_STAGE = 2 - - -FORWARD_KEY = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT] -BACKWARD_KEY = [MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT, - MonitorConst.PRE_GRAD, MonitorConst.POST_GRAD, MonitorConst.ACC_GRAD] -OPTIMIZER_KEY = [MonitorConst.EXP_AVG, MonitorConst.EXP_AVG_SQ] -TRAIN_STAGE = { - **{key_: TrainStage.FORWARD_STAGE for key_ in FORWARD_KEY}, - **{key_: TrainStage.BACKWARD_STAGE for key_ in BACKWARD_KEY}, - **{key_: TrainStage.OPTIMIZER_STAGE for key_ in OPTIMIZER_KEY} -} - - -@dataclass(eq=True) -class GradAnomalyData: - rank: int = 0 - step: int = 0 - micro_step: int = 0 - pp_stage: int = 0 - vpp_stage: int = 0 - call_id: int = 0 - tag_name: str = field(default=None, compare=False) - message: str = field(default="", compare=False) - group_mates: list = field(default=None, compare=False) - - def __lt__(self, other): - """ - 自定义比较函数,用于确定 GradAnomalyData 实例之间的顺序。 - 比较规则为: - step 和 micro_step 值越小优先级越高; - vpp 和 pp 在前向阶段值越小优先级越高,在非前向阶段值越大优先级越高; - call_id 值越小优先级越高。 - """ - if not isinstance(other, GradAnomalyData): - return NotImplemented - - self_train_stage = self.get_train_stage(self.tag_name) - other_train_stage = self.get_train_stage(other.tag_name) - - def vpp_pp_comparator(anomaly): - """ - Determine the priority rule for vpp and pp based on train stage - Forward stage prefers smaller vpp and pp - Other stages prefer larger vpp and pp - """ - if self_train_stage == TrainStage.FORWARD_STAGE: - return anomaly.vpp_stage, anomaly.pp_stage - else: - return -anomaly.vpp_stage, -anomaly.pp_stage - - self_cmp = [self.step, self.micro_step, self_train_stage, *vpp_pp_comparator(self), self.call_id] - other_cmp = [other.step, other.micro_step, other_train_stage, *vpp_pp_comparator(other), other.call_id] - return self_cmp < other_cmp - - def __le__(self, other): - if not isinstance(other, GradAnomalyData): - return NotImplemented - return self == other or self < other - - @staticmethod - def get_train_stage(tag_name): - """ - :param tag_name: "0:fc2_0/rank0/input", "0:fc1.weight/rank0/post_grad", "0:fc2.weight/rank0/exp_avg_sq" - :return: int, if forward return 0; if backward return 1; if optimizer return 2 - """ - key_ = tag_name.split("/")[-1] - return TRAIN_STAGE.get(key_, TrainStage.DEFAULT_STAGE) - - def to_dict(self): - return self.__dict__ - - def get_key(self): - # 0:1.self_attention.core_attention_flash_0/rank0/input_grad - return ''.join([str(self.tag_name), "_step_", str(self.step), "_call_", str(self.call_id)]) - - @dataclass class WriterInput: path: str @@ -254,6 +59,41 @@ class BaseWriterWithAD: self.anomaly_factory = writer_input.anomaly_factory self.anomalies = [] self.ndigits = writer_input.ndigits + self.beta = 0.99 + + @staticmethod + def stack_tensors(tensor_list): + """ + Torch not support stack cpu and xpu tensors. Group the tensors into cpu_group and xpu_group, + stack them separately, migrate xpu_group to cpu, and then restore in the order of input. + + :param tensor_list: [tensor(-1.6165), tensor(-1.0985), tensor(-1.7777), tensor(-1.8408, device='npu:0')] + :return: result: list of float + """ + cpu_tensors = [] + xpu_tensors = [] + + for tensor in tensor_list: + if isinstance(tensor, Tensor): + # 将device上的tensor先stack后to cpu + xpu_tensors.append(tensor) + else: + cpu_tensors.append(tensor) + + xpu_stack = ops.stack(xpu_tensors).tolist() if xpu_tensors else ops.tensor([]) + + # 按照输入的顺序恢复 + result = [] + cpu_tensors_idx, xpu_tensors_idx = 0, 0 + for tensor in tensor_list: + if isinstance(tensor, Tensor): + result.append(xpu_stack[xpu_tensors_idx]) + xpu_tensors_idx += 1 + else: + result.append(cpu_tensors[cpu_tensors_idx]) + cpu_tensors_idx += 1 + + return result def get_anomalies(self): """返回已检测到的异常列表 @@ -272,12 +112,17 @@ class BaseWriterWithAD: Returns: None """ - detected = False - if self.ad_rules: - avg = self._update_tag2scalars(tag, scalar_value) - detected, rule_name = self._ad(scalar_value, history=avg) + if not self.ad_rules or tag[-1] in ["shape", "dtype"]: + return + if isinstance(scalar_value, Tensor): + scalar_value = scalar_value.item() + avg = self._update_tag2scalars(tag, scalar_value) + detected, rule_name = self._ad(scalar_value, history=avg) if detected: - exception_message = f"Rule {rule_name} reports anomaly signal in {tag} at step {global_step}." + if rule_name == AnomalyTurbulence.name and tag[-1] not in ["norm", "mean"]: + return + exception_message = (f"Rule {rule_name} reports anomaly signal in {tag} at step {global_step}, " + f"current value {scalar_value}, history mean {avg}.") logger.info(f"{BCOLORS.WARNING}> {exception_message}{BCOLORS.ENDC}") # append to self.anomalies for dump if self.anomaly_factory: @@ -290,8 +135,12 @@ class BaseWriterWithAD: tags = list(itertools.product(metric_value.keys(), op_list)) for op2tensor in metric_value.values(): tensors.extend(op2tensor.values()) + + if not tensors: + return + with _no_grad(): - metric_list = ops.stack(tensors).tolist() if tensors else [] + metric_list = self.stack_tensors(tensors) for tag, metric in zip(tags, metric_list): self.add_scalar(tag, metric, step, need_explain) @@ -311,11 +160,11 @@ class BaseWriterWithAD: Returns: float: The average value before update. """ + abs_scalar_value = abs(scalar_value) if tag not in self.tag2scalars: - self.tag2scalars[tag] = {'avg': scalar_value, 'count': 0} + self.tag2scalars[tag] = {'avg': abs_scalar_value, 'count': 0} avg = self.tag2scalars[tag]['avg'] - new_avg = (avg * self.tag2scalars[tag]['count'] + scalar_value) / (self.tag2scalars[tag]['count'] + 1) - self.tag2scalars[tag]['avg'] = new_avg + self.tag2scalars[tag]['avg'] = self.beta * avg + (1 - self.beta) * abs_scalar_value self.tag2scalars[tag]['count'] += 1 return avg @@ -353,11 +202,10 @@ class CSVWriterWithAD(BaseWriterWithAD): new_data = [] for name, metric_value in self.context_dict.items(): - if MonitorConst.NAME_SEP not in name: - new_data.append([name] + [step] + metric_value) - else: - new_data.append(name.split(MonitorConst.NAME_SEP) + [step] + metric_value) - new_data = pd.DataFrame(new_data).round(self.ndigits) + new_line = name.split(MonitorConst.NAME_SEP) + metric_value + new_line.insert(2, step) + new_data.append(new_line) + new_data = pd.DataFrame(new_data).round(self.ndigits).fillna("nan") write_df_to_csv(new_data, filepath, mode='a+', header=False) self.context_dict = defaultdict(list) @@ -375,30 +223,15 @@ class CSVWriterWithAD(BaseWriterWithAD): name += '.output' self.context_dict[name].append(scalar_value) - def write_metrics(self, op_list, metric_value, step, prefix='', need_explain=False): + def write_metrics(self, op_list, metric_value, step, prefix='', need_explain=False, **kwargs): need_explain = prefix == 'other' super().write_metrics(op_list, metric_value, step, prefix='', need_explain=need_explain) - # generate csv headers - # set hashmap to reduce the number of headers generated. - # 前向的norm用input.ops_和output.ops_,反向的用input_grad.ops_和output_grad.ops_ - if prefix in {"actv", "actv_grad"}: - if prefix == "actv": - input_and_output = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT] - else: - input_and_output = [MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT] - ops_ = [MonitorConst.DOT.join(i) for i in itertools.product(input_and_output, op_list)] - csv_header = ["module_name", "step", *ops_] + if prefix in [MonitorConst.ACTV, MonitorConst.ACTVGRAD] or kwargs.get("use_micro_step"): + self.header = MonitorConst.CSV_HEADER_MICRO_STEP + op_list else: - csv_header = ["param_name", "step", *op_list] - - keys = list(metric_value.keys()) - if keys and MonitorConst.NAME_SEP in keys[0]: - csv_header.insert(0, "vpp_stage") - - self.header = csv_header + self.header = MonitorConst.CSV_HEADER + op_list self.write_csv(prefix, step) - self.header = [] def close(self): pass diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/features.py b/debug/accuracy_tools/msprobe/mindspore/monitor/features.py index be958dadfe8fcc50f26f16c93b3a090269235d1e..3c31762b7d366982c58b5e81982a548f1e8df769 100644 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/features.py +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/features.py @@ -17,6 +17,8 @@ from mindspore import mint, ops, _no_grad from mindspore import Tensor from mindspore import dtype as mstype +from msprobe.core.common.log import logger + @_no_grad() def square_sum(x: Tensor): @@ -46,6 +48,8 @@ def get_max(x: Tensor): @_no_grad() def get_zeros(x: Tensor, eps: float): + if x.numel() == 0: + return Tensor(float('nan')) return mint.sum(mint.abs(x) < eps) / x.numel() @@ -54,10 +58,101 @@ def get_nans(t): return ops.isnan(t.astype(mstype.float32)).sum() -FUNC_MAP = {"min" : get_min, - "max" : get_max, - "mean" : get_mean, - "norm" : get_norm, - "nans" : get_nans, - "zeros": get_zeros - } \ No newline at end of file +def get_shape(t): + return t.shape + + +def get_dtype(t): + return t.dtype + + +FUNC_MAP = { + "min": get_min, + "max": get_max, + "mean": get_mean, + "norm": get_norm, + "nans": get_nans, + "zeros": get_zeros, + "shape": get_shape, + "dtype": get_dtype +} + + +def max_eigenvalue(input_tensor: Tensor, num_iterations=3): + input_tensor = input_tensor.float() + try: + check_tensor_dim(input_tensor, 2) + except (TypeError, ValueError) as e: + logger.warning(f"calcute max eigenvalue failed, {e}") + return Tensor(0) + in_features = input_tensor.shape[1] + u_tensor = ops.randn(in_features) + u_norm = u_tensor.norm() + if u_norm == 0: + return Tensor(0) + u_tensor /= u_tensor.norm() + input_seq = ops.matmul(input_tensor.T, input_tensor) + for _ in range(num_iterations): + v_tensor = ops.matmul(input_seq, u_tensor) + spectral_norm = ops.matmul(v_tensor.T, u_tensor) + v_norm = v_tensor.norm() + if v_norm > 0: + u_tensor = v_tensor / v_norm + else: + spectral_norm = Tensor(0) + break + return spectral_norm.sqrt() + + +def check_tensor_dim(tensor, n): + if not isinstance(tensor, Tensor): + raise TypeError( + f"Input must be a mindspore Tensor, but got {type(tensor)} instead." + ) + if len(tensor.shape) < n: + raise ValueError( + f"tensor dim must be at least {n} dimensions." + f"Got shape: {tuple(tensor.shape)} with {tensor.dim()} dims" + ) + + +def cal_entropy(qk_tensor: Tensor, mask=None): + try: + check_tensor_dim(qk_tensor, 2) + except (TypeError, ValueError) as e: + logger.warning(f"calculate entropy failed, {e}") + return Tensor(0), Tensor(0) + if mask is None: + mask = ops.tril(ops.ones((qk_tensor.shape[1], qk_tensor.shape[1]))) + qk_tensor = qk_tensor - ops.amax(qk_tensor, axis=1, keepdims=True) + qk_tensor = qk_tensor.masked_fill(mask == 0, float('-inf')) + softmax_qkt = ops.softmax(qk_tensor.float(), axis=1) + softmax_max = ops.mean(ops.amax(softmax_qkt, axis=1)) + entropy = ops.mean(-ops.nansum(softmax_qkt * ops.log(softmax_qkt), axis=1)) + return entropy, softmax_max + + +def cal_stable_rank(weight: Tensor): + eig = max_eigenvalue(weight) + if eig == Tensor(0): + return Tensor(0), Tensor(0) + f_norm = ops.norm(weight, ord='fro') + return f_norm / eig, eig + + +def cal_qkt(q_h: Tensor, k_h: Tensor, order="s,b,h,d"): + # q_h shape is (s, b, h, d) + try: + check_tensor_dim(q_h, 4) + check_tensor_dim(k_h, 4) + except (TypeError, ValueError) as e: + logger.warning(f"calculatee qkt failed, {e}") + return Tensor(0) + if order == "s,b,h,d": + qkt = ops.matmul(q_h[:, 0, 0, :], k_h[:, 0, 0, :].t()) / q_h.shape[-1] ** 0.5 + elif order == "b,s,h,d": + qkt = ops.matmul(q_h[0, :, 0, :], k_h[0, :, 0, :].t()) / q_h.shape[-1] ** 0.5 + else: + logger.warning(f"Calculate qk tensor failed: Order unsupported.") + qkt = Tensor(0) + return qkt diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py b/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py index 068be9ff6c782bec2bf637999ef5f0eabe0c2675..610778fa7c8ba7e1b981ec64dad81853d6721900 100644 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py @@ -13,28 +13,36 @@ # See the License for the specific language governing permissions and # limitations under the License. +from gzip import FEXTRA import os import re import uuid from collections import defaultdict from datetime import datetime +from functools import partial import pytz -import mindspore as ms +import pandas as pd +import mindspore from mindspore import Tensor, mint from mindspore import nn, _no_grad -from mindspore.communication import get_rank from msprobe.core.common.log import logger -from msprobe.core.common.const import MonitorConst -from msprobe.core.common.file_utils import load_json, save_json -from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name, validate_config, step_accumulates_one, \ - is_skip_step, get_metrics, get_single_metrics, get_target_output_dir -from msprobe.mindspore.monitor.module_spec_verifier import validate_config_spec -from msprobe.mindspore.monitor.anomaly_detect import AnomalyScanner, AnomalyDataFactory, \ - CSVWriterWithAD, BaseWriterWithAD, WriterInput -from msprobe.mindspore.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \ - get_process_group +from msprobe.core.common.const import MonitorConst, Const +from msprobe.core.common.file_utils import load_json, save_json, make_dir +from msprobe.core.monitor.utils import validate_config, get_output_base_dir, get_target_output_dir +from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFactory, AnomalyDataWriter +from msprobe.mindspore.common.utils import is_mindtorch +from msprobe.mindspore.monitor.common_func import is_valid_instance, get_parameters, get_submodules, get_rank, \ + comm_is_initialized +from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name, step_accumulates_one, is_skip_step, \ + get_metrics, get_entropy_metric, get_sr_metric +from msprobe.mindspore.monitor.optimizer_collect import OptimizerMonFactory +from msprobe.mindspore.monitor.data_writers import CSVWriterWithAD, BaseWriterWithAD, WriterInput +from msprobe.mindspore.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate +from msprobe.mindspore.monitor.features import cal_qkt +from msprobe.core.common.file_utils import write_df_to_csv +from msprobe.core.common.utils import analyze_api_call_stack FORMAT_MAPPING = { MonitorConst.CSV: CSVWriterWithAD, @@ -72,13 +80,24 @@ def param_is_data_parallel_duplicate(dp_group): def squash_param_name(param_name): - for pattern in ['layers?\.(.*)', 'embeddings?\.(.*)', 'final.*', 'output.*', 'norm.*']: + for pattern in ['^.*\.(layers?\..*)', '^.*\.(embeddings?\..*)', '^.*\.(final.*)', '^.*\.(output.*)', + '^.*\.(norm.*)']: match = re.findall(pattern, param_name) if match: return match[0] return param_name +def is_recording_module(module_name, l2_targets, vpp_stage): + if len(l2_targets) > 0: + for pattern in [vpp_stage + squash_param_name(module_name), vpp_stage + module_name]: + if pattern in l2_targets: + return pattern + return "" + else: + raise NotImplementedError("If monitering l2_features, the targets should be set specifically.") + + # Used For Module Forward & Backward Collect class ModuleHookContext: def __init__(self, module_name) -> None: @@ -88,30 +107,26 @@ class ModuleHookContext: self.actvgrad = [] self.module_name = module_name self.struct = {} - self.format_by_arg = {} - self.verified = False - self.focused_in_col = 0 - self.focused_out_col = 0 - self.ignore_in = False # no need to care when no key 'input' or 'input_grad' found - - def set_format_by_arg(self, key_name: str, target_config: dict): - cared = target_config.get(self.module_name, self.struct) - if key_name in cared: - if isinstance(cared[key_name], dict): - # current cared is self.struct - config = cared[key_name].get('config') - self.format_by_arg[key_name] = config - else: - # current cared is target_config[self.module_name] - self.format_by_arg[key_name] = cared[key_name] - elif key_name in ['input', 'input_grad']: - self.ignore_in = True + self.stack = "" def reset(self): self.actv.clear() self.actvgrad.clear() +class FeatureHookContext: + def __init__(self, module_name): + self.step = 0 + self.micro_step = 0 + self.attention_feature = {} + self.linear_feature = {} + self.module_name = module_name + + def reset(self): + self.attention_feature.clear() + self.linear_feature.clear() + + start_step = 0 @@ -186,6 +201,7 @@ class TrainerMon: self.config_file_path = config_file_path self.process_group = process_group self.params_have_main_grad = params_have_main_grad + self.is_mindtorch = is_mindtorch() self.config_timestamp = 0 # 后面有校验时间戳, 首次监控无需为了更新config文件时间戳而去改, 可通过dynamic_on开关直接打开 self.config = load_json(config_file_path) validate_config(self.config) @@ -218,10 +234,12 @@ class TrainerMon: self.dp_group = None self.tp_group = None self.micro_batch_number = 1 + self.optimizer_mon = None # TYPE3: 会随着训练中途config配置更新或监控状态改变而重置的变量 self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext) self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext) + self.feature_hook_context_by_module = defaultdict(FeatureHookContext) self.optimizer_context = defaultdict(OptimizerContext) self.cc_context = defaultdict(CommunicationContext) self.grad_context = GradContext() @@ -240,6 +258,8 @@ class TrainerMon: self.optimizer_hooked = False self.param_registered = False self.struct_printed = False + self.pre_step_hooks = [] + self.post_step_hooks = [] # 动静态区分 self.dynamic_enable = os.getenv("DYNAMIC_MONITOR", 'False').lower() == 'true' @@ -253,6 +273,18 @@ class TrainerMon: if self.collect_times > 0: self.monitoring = True + @staticmethod + def get_linear_hook_target(module): + if isinstance(module, nn.Embedding): + return '' + if hasattr(module, "num_embeddings") or hasattr(module, "vocab_start_index"): + return '' + for weight_name in ["weight", "wg"]: + if hasattr(module, weight_name) and isinstance(getattr(module, weight_name), Tensor): + if getattr(module, weight_name).dim() == 2: + return weight_name + return '' + def set_config(self): self.start_step = self.config.get("start_step", 0) self.collect_times = self.config.get("collect_times", 100000000) # 默认大值, 目的是一直采集 @@ -260,7 +292,6 @@ class TrainerMon: self.has_collect_times = 0 # 重设采集计数器 self.print_struct = self.config.get("print_struct", False) self.targets = self.config.get("targets", None) - self.is_select = self.config.get("is_select", False) self.module_rank_list = self.config.get("module_ranks", []) self.format = self.config.get('format', MonitorConst.CSV) # only csv supported in mindspore self.eps = self.config.get('eps', 1e-8) @@ -276,6 +307,12 @@ class TrainerMon: self.param_distribution = self.config.get("param_distribution", False) self.mg_direction = self.config.get('mg_direction', False) # main grad direction self.cc_distribution = self.config.get("cc_distribution", {}) # communication ops + self.stack_info = self.config.get('stack_info', False) + self.monitor_mbs_grad = self.config.get('monitor_mbs_grad', False) + self.recording_l2_features = self.config.get('recording_l2_features', False) + self.sa_order = self.config.get('sa_order', "s,b,h,d") + + if not self.cc_distribution.get('enable', False): self.cc_log_only = False else: @@ -283,8 +320,6 @@ class TrainerMon: self.cc_log_only = self.cc_distribution.get('cc_log_only', False) self.cc_logged_stack = defaultdict(set) self.cc_pre_hook = self.cc_distribution.get('cc_pre_hook', False) - self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self)) - api_register.redirect_api() self.common_info() # 初始化AnomalyData工厂 @@ -298,18 +333,25 @@ class TrainerMon: if self.format not in FORMAT_MAPPING: logger.error(f"Unsupported format: {self.format}, use default format: {MonitorConst.CSV}") self.format = MonitorConst.CSV - writer = FORMAT_MAPPING[self.format] self.step_count_per_record = self.config.get('step_count_per_record', 1) - self.summary_writer = writer( - WriterInput( - self.tensorboard_dir, - self.alert_rules, - self.unique_id, - self.anomaly_data_factory, - self.ndigits, - self.step_count_per_record + if not self.module_rank_list or (self.rank in self.module_rank_list): + writer = FORMAT_MAPPING[self.format] + self.summary_writer = writer( + WriterInput( + self.tensorboard_dir, + self.alert_rules, + self.unique_id, + self.anomaly_data_factory, + self.ndigits, + self.step_count_per_record + ) ) - ) + + # 初始化anomaly detected文件目录 + if self.anomaly_data_factory: + self.anomaly_data_writer = AnomalyDataWriter(os.path.join(self.output_base_dir, "anomaly_detected"), + self.rank) + self.anomaly_data_writer.init_detected_json() def common_info(self): if not self.xy_distribution: @@ -322,6 +364,8 @@ class TrainerMon: logger.info("> momentum and variance of adam is not monitored. ") if not self.wg_distribution: logger.info("> weight grad of specified module is not monitored. ") + if not self.recording_l2_features: + logger.info("> l2 features of specified module is not monitored. ") if not self.mg_direction: logger.info('> grad and momentum direction will not be compared.') if not self.cc_distribution.get('enable', False): @@ -341,6 +385,7 @@ class TrainerMon: self.micro_batch_number = grad_acc_steps self.dp_group = dp_group self.tp_group = tp_group + self.optimizer_mon = OptimizerMonFactory.create_optimizer_mon(optimizer) self.hook_step_final(optimizer) if not isinstance(model, list): model = [model] @@ -361,16 +406,29 @@ class TrainerMon: context.step - self.start_step) % self.step_interval == 0) if module_rank_valid and step_condition: self.has_collect_times += 1 + + if self.anomaly_data_factory: + self.anomaly_data_factory.set_call_id(self.param_name_call_id) self.write_xy_tb(context.step) self.write_grad_tb(context.step) self.write_mv_tb(context) self.write_param_tb(context) + self.write_features_tb(context.step) + if self.stack_info: + self.write_stack_info() + self.stack_info = False + for handle in self.handles["stack"]: + handle.remove() + self.handles["stack"].clear() if context.metric_dict: self.summary_writer.write_metrics(self.ops, context.metric_dict, context.step, 'other') context.metric_dict.clear() + if self.anomaly_data_factory: + self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies()) self.summary_writer.clear_anomalies() + self.call_id = 0 self.param_name_call_id.clear() @@ -380,7 +438,22 @@ class TrainerMon: context.step += 1 self.dynamic_monitor(optimizer) - optimizer.register_forward_hook(step_final_hook) + def patch_step(func, optimizer): + def wrapper(*args, **kwargs): + for hook in self.pre_step_hooks: + hook(optimizer, args, kwargs) + out = func(*args, **kwargs) + for hook in self.post_step_hooks: + hook(optimizer, args, kwargs) + step_final_hook(optimizer, args, kwargs) + return out + return wrapper + + if self.is_mindtorch: + optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer) + else: + optimizer.__class__.construct = patch_step(optimizer.__class__.construct, optimizer) + return def dynamic_monitor(self, optimizer): @@ -408,13 +481,14 @@ class TrainerMon: validate_config(config) self.config = config self.set_config() + self.start_step = context.step # 动态启停时不受原start_step影响,永远从下一步开始 logger.warning(f"config is updated at step{context.step - 1}, " f"will start new hook at step{context.step}.") except Exception as e: logger.error(f"set config wrong because {e}, not updated, please check!!!") return - self._remove_all_hooks() + self._remove_all_hooks(optimizer) self.register_hooks(optimizer) def register_hooks(self, optimizer): @@ -422,6 +496,9 @@ class TrainerMon: self.hook_modules() self.hook_optimizer(optimizer) self._patch_grad_sync() + if self.cc_distribution.get('enable', False): + self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self)) + api_register.redirect_api() self.monitoring = True def hook_modules(self): @@ -436,45 +513,42 @@ class TrainerMon: hooked_count = 0 for vpp_stage, model_chunk in enumerate(self.model): - if not isinstance(model_chunk, nn.Cell): + if not is_valid_instance(model_chunk): logger.info("Target Model is not Cell") continue vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}' - targets = [x for x, _ in model_chunk.cells_and_names()] if self.print_struct else self.targets.keys() - hooked_count += self._hook_module(targets, model_chunk, vpp_stage) + targets = [x for x, _ in get_submodules(model_chunk)] if self.print_struct else self.targets.keys() + l2_target_names = self.config.get('l2_targets', {}) + hooked_count += self._hook_module(targets, l2_target_names, model_chunk, vpp_stage) logger.info(f"> {hooked_count} modules are monitored.") def hook_optimizer(self, optimizer): - def optimizer_pre_hook_function(opt, grad_names, gradients): + def optimizer_pre_step_hook(opt, *args, **kwargs): context = self.optimizer_context[opt] + if (self.print_struct and not all(value == {} for value in self.module_struct.values()) + and not self.struct_printed): + self._save_module_struct() + if not self.cc_log_only: + raise Exception("exit after first monitor step when print model struct") if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times, self.collect_times): return - gradient_list = gradients[0] if isinstance(gradients, tuple) else gradients - is_select = self.is_select - for idx, grad in enumerate(gradient_list): - grad_name = grad_names[idx] - if is_select and grad_name not in self.targets: - continue - get_single_metrics(self.ops, grad_name, grad, context.param_weight_grad) - - if self.mv_distribution: - # fetch mean - for param in m_list: - name = param.name - if is_select and name not in self.targets: - continue - get_single_metrics(self.ops, name, param, context.exp_avg_metric) - # fetch variance - for param in v_list: - name = param.name - if is_select and name not in self.targets: - continue - get_single_metrics(self.ops, name, param, context.exp_avg_sq_metric) - if self.param_distribution: - for param in param_list: - get_single_metrics(self.ops, param.name, param, context.param_metric) - self.generate_wgrad_metrics() + + grad_dict = {} + if self.wg_distribution: + grad_dict = self.optimizer_mon.fetch_grad(self, self.param2name) + + if self.mv_distribution or self.ur_distribution or self.mg_direction: + if self.is_mindtorch: + context.param_exp_avg, context.param_exp_avg_sq, context.param_adam_update, \ + context.param_adam_ratio = self.optimizer_mon.fetch_mv(self, self.param2name) + else: + context.param_exp_avg, context.param_exp_avg_sq = self.get_mv_for_ms(optimizer) + + self.generate_wgrad_metrics(grad_dict) + self.generate_mv_metrics(context) + self.generate_param_metrics(context, MonitorConst.PRE_PARAM) + metric_dict = {} for cc in self.cc_context.values(): cc.aggregate() @@ -486,63 +560,86 @@ class TrainerMon: context.metric_dict = metric_dict return - def optimizer_pre_hook_wrapper(func, grad_names): - def wrapper(opt, gradients): - return func(opt, grad_names, gradients) - return wrapper + def optimizer_post_step_hook(optimizer, args, kwargs): + context = self.optimizer_context[optimizer] + self.generate_param_metrics(context, MonitorConst.POST_PARAM) + if self.optimizer_hooked or not self.is_target_rank(): return - m_list = [] - v_list = [] - param_list = [] - grad_names = [] - for param in optimizer.get_parameters(): - if MonitorConst.EXP_AVG_SQ in param.name: - v_list.append(param) - elif MonitorConst.EXP_AVG in param.name: - m_list.append(param) - elif param.name in ['global_step', 'learning_rate']: - pass - else: - param_list.append(param) - grad_names.append(param.name) - - handle = optimizer.register_forward_pre_hook( - optimizer_pre_hook_wrapper(optimizer_pre_hook_function, grad_names)) - self.handles['optimizer'].append(handle) + self.pre_step_hooks.append(optimizer_pre_step_hook) + self.post_step_hooks.append(optimizer_post_step_hook) self.optimizer_hooked = True return - def generate_wgrad_metrics(self): + def generate_wgrad_metrics(self, grad_dict): if not self.wg_distribution: - return {}, {} + return - if self.weight_hooked: - try: - get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric) - except Exception as e: - logger.warning(f"An error occurred while generating wgrad pre metrics") - return {}, {} + get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric) + get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post) - grad_dict = {} - for param, name in self.param2name.items(): - if self.duplicate_param.get(name, False): - continue - grad = param.main_grad if self.params_have_main_grad else param.grad - if grad is None: - logger.warning(f"grad is None: {name}, maybe something wrong happened.") + def generate_param_map(self, tag, param_tensor): + metrics = {} + if not self.is_mindtorch: + return param_tensor + for name in self.param2name.values(): + key = get_summary_writer_tag_name(name, tag, self.rank) + self.register_param_call_id("optimizer_pre_step_hook", key) + if name not in param_tensor or param_tensor[name] is None: continue - tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD) - self._register_param_call_id("hook_optimizer", tag) - grad_dict[tag] = grad - try: - get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post) - except Exception as e: - logger.warning(f"An error occurred while generating wgrad post metrics") + metrics[key] = param_tensor[name] + return metrics + + def generate_param_metrics(self, opt_context, stage=MonitorConst.PRE_PARAM): + if not self.param_distribution: + return + tag2param = { + self.name2tag.get(name, {}).get(stage): param + for name, param in self.name2param.items() + if param.numel() != 0 + } + get_metrics(self.ops, tag2param, self.eps, opt_context.param_metric) + + def get_mv_for_ms(self, opt): + if not self.mv_distribution: return {}, {} - return self.grad_context.post, self.grad_context.pre + common_opt = opt + if not is_valid_instance(opt): + common_opt = getattr(opt, 'optimizer') + if not is_valid_instance(common_opt): + logger.warning("Optimizer is not valid, please check usage") + return {}, {} + m_dict = {} + v_dict = {} + for name, param in get_parameters(common_opt): + if MonitorConst.EXP_AVG_SQ in name: + v_dict[name] = param + elif MonitorConst.EXP_AVG in name: + m_dict[name] = param + return m_dict, v_dict + + def generate_mv_metrics(self, opt_context): + if not self.mv_distribution: + return + opt_context.exp_avg_metric = {} + opt_context.exp_avg_sq_metric = {} + m_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG, opt_context.param_exp_avg) + v_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG_SQ, opt_context.param_exp_avg_sq) + get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric) + get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric) + + def write_stack_info(self): + stack_data = [] + header = ["module_name", "stack_info"] + stack_data.append(header) + for _, fwd_context in self.module_fwd_hook_context_by_module.items(): + stack_data.append([fwd_context.module_name, fwd_context.stack]) + filepath = os.path.join(self.tensorboard_dir, f'stack_info.csv') + if not os.path.exists(filepath): + data_frame = pd.DataFrame(columns=stack_data) + write_df_to_csv(data_frame, filepath) def write_xy_tb(self, step): if not self.xy_distribution: @@ -550,41 +647,90 @@ class TrainerMon: for _, fwd_context in self.module_fwd_hook_context_by_module.items(): if len(fwd_context.actv) == 0: continue - self.summary_writer.write_metrics(self.ops, fwd_context.actv, step, 'actv') + self.summary_writer.write_metrics(self.ops, fwd_context.actv, step, MonitorConst.ACTV) fwd_context.actv.clear() if self.grad_context.actv: - self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, 'actv_grad') + self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, MonitorConst.ACTVGRAD) def write_param_tb(self, opt_context): if not self.param_distribution: return - self.summary_writer.write_metrics(self.ops, opt_context.param_metric, opt_context.step, 'param') + param_metrics = {k: v for k, v in opt_context.param_metric.items() if MonitorConst.PRE_PARAM in k} + updated_param_metrics = {k: v for k, v in opt_context.param_metric.items() if MonitorConst.POST_PARAM in k} + self.summary_writer.write_metrics(self.ops, param_metrics, opt_context.step, MonitorConst.PRE_PARAM) + self.summary_writer.write_metrics(self.ops, updated_param_metrics, opt_context.step, MonitorConst.POST_PARAM) def write_mv_tb(self, opt_context): if not self.mv_distribution: return - self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric, opt_context.step, 'exp_avg') - self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric, opt_context.step, 'exp_avg_sq') + self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric, opt_context.step, MonitorConst.EXP_AVG) + self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric, opt_context.step, + MonitorConst.EXP_AVG_SQ) def write_grad_tb(self, step): if not self.wg_distribution: return - self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced') + self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced', + use_micro_step=self.monitor_mbs_grad) self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced') + def write_metrics_if_not_empty(self, features, metrics, step, hook_name): + if not features or len(features) == 0: + return + use_micro_step = hook_name not in ["linear_hook"] + self.summary_writer.write_metrics(metrics, features, step, hook_name, use_micro_step=use_micro_step) + features.clear() + + def write_features_tb(self, step): + if not self.recording_l2_features: + return + for context in self.feature_hook_context_by_module.values(): + num_features = len(context.attention_feature) + len(context.linear_feature) + if num_features == 0: + continue + self.write_metrics_if_not_empty(context.attention_feature, ["entropy", "softmax"], step, + "attention_hook") + self.write_metrics_if_not_empty(context.linear_feature, ["sr", "kernel_norm"], step, + "linear_hook") + def is_target_rank(self): if self.module_rank_list and (self.rank not in self.module_rank_list): return False return True - def build_tbtag_tensor_map(self, module_name, tag, tensor): - metrics = {} - key = get_summary_writer_tag_name(module_name, tag, str(self.rank)) + def build_tbtag_tensor_map(self, module_name, suffix, tag, tensor): + """ + :param module_name: str of module name + :param suffix: + :param tag: + :param tensor: torch.tensor or tuple/list of torch.tensor + :return: tensor_map + """ + tensor_map = {} if isinstance(tensor, Tensor): - self._register_param_call_id("_hook_module", key) - metrics[key] = tensor - return metrics + tensor = [tensor] + if isinstance(tensor, tuple) or isinstance(tensor, list): + if len(tensor) == 1: + key = get_summary_writer_tag_name(module_name + suffix, tag, self.rank) + self.register_param_call_id("_hook_module", key) + tensor_map[key] = tensor[0] + else: + for i, tensor_i in enumerate(tensor): + key = get_summary_writer_tag_name(module_name + f"_{i}" + suffix, tag, self.rank) + self.register_param_call_id("_hook_module", key) + tensor_map[key] = tensor_i + return tensor_map + + def register_param_call_id(self, hook_name: str, key: str): + """ + :param hook_name: + :param key: str, '0:relu_0/output_grad' + :return: + """ + logger.debug(f"{hook_name} {key}: {self.call_id}") + self.param_name_call_id[key] = self.call_id + self.call_id += 1 def _register_param_name(self): for vpp_stage, model_chunk in enumerate(self.model): @@ -593,8 +739,7 @@ class TrainerMon: def _register_chunk(self, model_chunk, prefix): index = 0 - for param in model_chunk.get_parameters(): - param_name = param.name + for param_name, param in get_parameters(model_chunk): if not param.requires_grad: continue if self._is_target_param(param_name, param, prefix): @@ -609,25 +754,45 @@ class TrainerMon: self.duplicate_param[name] = True if self.dp_group and param_is_data_parallel_duplicate(self.dp_group): self.duplicate_param[name] = True + keywords = [ + MonitorConst.PRE_GRAD, + MonitorConst.POST_GRAD, + MonitorConst.PRE_PARAM, + MonitorConst.POST_PARAM + ] self.name2tag[name] = { - MonitorConst.PRE_GRAD: get_summary_writer_tag_name(name, MonitorConst.PRE_GRAD, self.rank), - MonitorConst.POST_GRAD: get_summary_writer_tag_name(name, MonitorConst.POST_GRAD, self.rank) + k: get_summary_writer_tag_name(name, k, self.rank) + for k in keywords } index += 1 - def _hook_module(self, target_names, module, vpp_stage=''): - if not isinstance(module, nn.Cell): + def _save_module_struct(self): + output_dir = os.path.join(get_output_base_dir(), 'module_struct', f'rank{self.rank}') + make_dir(output_dir) + module_struct_file = os.path.realpath(os.path.join(output_dir, 'module_struct.json')) + save_json(module_struct_file, self.module_struct, indent=2) + logger.info(f"> save module struct to {module_struct_file}") + self.struct_printed = True + + def _hook_module(self, target_names, l2_target_names, module, vpp_stage=''): + if not is_valid_instance(module): # nothing to hook return 0 - def fwd_hook_fun(module, module_input, module_output, name): + def fwd_hook_fun(module, args, kwargs, module_output, name): + + module_input = [tensor for tensor in args if isinstance(tensor, Tensor)] + if kwargs: + kwargs_tensors = [tensor for tensor in kwargs.values() if isinstance(tensor, Tensor)] + module_input.extend(kwargs_tensors) + if module not in self.module_fwd_hook_context_by_module: self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name) context: ModuleHookContext = self.module_fwd_hook_context_by_module[module] if not context.struct: context.struct = { - MonitorConst.ACTV_IN: get_param_struct(module_input), - MonitorConst.ACTV_OUT: get_param_struct(module_output) + Const.INPUT: get_param_struct(module_input), + Const.OUTPUT: get_param_struct(module_output) } if self.print_struct: self.module_struct[context.module_name].update(context.struct) @@ -638,31 +803,18 @@ class TrainerMon: self.collect_times): step_accumulates_one(context, self.micro_batch_number) return - if not context.format_by_arg: - context.set_format_by_arg(MonitorConst.ACTV_IN, self.targets) - context.set_format_by_arg(MonitorConst.ACTV_OUT, self.targets) - if not context.format_by_arg: - return - if not context.verified: - if not context.ignore_in: - context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_IN], - module_input, context.module_name, - MonitorConst.ACTV_IN) - context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_OUT], - module_output, context.module_name, - MonitorConst.ACTV_OUT) - context.verified = True tbtag_tensor_map = {} - if not context.ignore_in: - cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col] - tbtag_tensor_map.update( - self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_IN, - cared_input)) - cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col] tbtag_tensor_map.update( - self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_OUT, - cared_output)) + self.build_tbtag_tensor_map( + f'{context.module_name}.{Const.INPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}', + MonitorConst.ACTV, module_input)) + module_output = [tensor for tensor in module_output if isinstance(tensor, Tensor)] \ + if isinstance(module_output, tuple) else module_output + tbtag_tensor_map.update( + self.build_tbtag_tensor_map( + f'{context.module_name}.{Const.OUTPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}', + MonitorConst.ACTV, module_output)) try: get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv) except Exception as e: @@ -687,31 +839,17 @@ class TrainerMon: step_accumulates_one(context, self.micro_batch_number) return - if not context.format_by_arg: - context.set_format_by_arg(MonitorConst.ACTVGRAD_IN, self.targets) - context.set_format_by_arg(MonitorConst.ACTVGRAD_OUT, self.targets) - if not context.format_by_arg: - return - if not context.verified: - if not context.ignore_in: - context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_IN], - input_grad, context.module_name, - MonitorConst.ACTVGRAD_IN) - context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_OUT], - output_grad, context.module_name, - MonitorConst.ACTVGRAD_OUT) - context.verified = True - + valid_input_grad = [tensor for tensor in input_grad if isinstance(tensor, Tensor)] tbtag_tensor_map = {} - if not context.ignore_in: - cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col] - tbtag_tensor_map.update( - self.build_tbtag_tensor_map( - f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_IN, cared_input_grad)) - cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col] tbtag_tensor_map.update( - self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_OUT, - cared_output_grad)) + self.build_tbtag_tensor_map( + f'{context.module_name}.{Const.INPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}', + MonitorConst.ACTVGRAD, valid_input_grad)) + + tbtag_tensor_map.update( + self.build_tbtag_tensor_map( + f'{context.module_name}.{Const.OUTPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}', + MonitorConst.ACTVGRAD, output_grad)) if context.micro_step == 0 and context.actvgrad: logger.warning(f"actvgrad context of {context.module_name} is not empty when first micro_step, " @@ -725,21 +863,95 @@ class TrainerMon: step_accumulates_one(context, self.micro_batch_number) return - def fwd_hook_fun_wrapper(fwd_hook_fun, name): - def wrapper(module, module_input, module_output): - return fwd_hook_fun(module, module_input, module_output, name) - return wrapper + def fwd_hook_register(module, fwd_hook_fun, name): + from packaging import version + if version.parse(mindspore.__version__) >= version.parse('2.6.0'): + def wrapper(module, args, kwargs, module_output): + return fwd_hook_fun(module, args, kwargs, module_output, name) + return module.register_forward_hook(wrapper, with_kwargs=True) + + else: + def wrapper(module, args, module_output): + return fwd_hook_fun(module, args, None, module_output, name) + return module.register_forward_hook(wrapper) + + def extract_attention_feature_hook(module, args, kwargs, module_output, name): + module_input = [tensor for tensor in args if isinstance(tensor, Tensor)] + if kwargs: + kwargs_tensors = [tensor for tensor in kwargs.values() if isinstance(tensor, Tensor)] + module_input.extend(kwargs_tensors) + + if module not in self.feature_hook_context_by_module: + self.feature_hook_context_by_module[module] = FeatureHookContext(name) + context: FeatureHookContext = self.feature_hook_context_by_module[module] + + tbtag_tensor_map = {} + if len(module_input) < 2: + logger.warning( + "Calculate attention feature failed, the length of module_input in attention hook's module should " + "be greater than or equal to 2.") + + q_h = module_input[0] + k_h = module_input[1] + qkt = cal_qkt(q_h, k_h, order=self.sa_order) + tbtag_tensor_map.update( + self.build_tbtag_tensor_map( + f'{context.module_name}.attention', f'{MonitorConst.NAME_SEP}{context.micro_step}', + 'qkt', qkt)) + get_entropy_metric(tbtag_tensor_map, context.attention_feature) + + context.micro_step += 1 + if context.micro_step == self.micro_batch_number: + context.micro_step = 0 + context.step += 1 + return + + def extract_linear_sr_hook(module, args, kwargs, module_output, name): + weight_name = self.get_linear_hook_target(module) + if weight_name == "": + return + if module not in self.feature_hook_context_by_module: + self.feature_hook_context_by_module[module] = FeatureHookContext(name) + context: FeatureHookContext = self.feature_hook_context_by_module[module] + + if context.micro_step == self.micro_batch_number - 1: + tbtag_tensor_map = {} + value = module.weight.data + tbtag_tensor_map.update( + self.build_tbtag_tensor_map( + f'{context.module_name}.linear', f'{MonitorConst.NAME_SEP}{context.micro_step}', + 'sr', value)) + + get_sr_metric(tbtag_tensor_map, context.linear_feature) + + context.micro_step += 1 + if context.micro_step == self.micro_batch_number: + context.micro_step = 0 + context.step += 1 + return + + def stack_hook(module, args, kwargs, module_output, name): + if module not in self.module_fwd_hook_context_by_module: + self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name) + context: ModuleHookContext = self.module_fwd_hook_context_by_module[module] + context.stack = analyze_api_call_stack(name) + return if self.backward_only and self.forward_only: logger.warning('not enable backward_only and forward_only simultaneously') hooked_count = 0 - if self.xy_distribution or self.print_struct: - for module_name, submodule in module.cells_and_names(): - name = self._is_target_module(module_name, target_names, vpp_stage) - if not name: - continue + + for module_name, submodule in get_submodules(module): + if self.stack_info: + name = vpp_stage + squash_param_name(module_name) + handle = fwd_hook_register(submodule, stack_hook, name=name) + self.handles["stack"].append(handle) + name = self._is_target_module(module_name, target_names, vpp_stage) + if not name: + continue + if self.xy_distribution or self.print_struct: if not self.backward_only: - handle = submodule.register_forward_hook(fwd_hook_fun_wrapper(fwd_hook_fun, name=name)) + handle = fwd_hook_register(submodule, fwd_hook_fun, name=name) self.handles['xy'].append(handle) if not self.forward_only: handle = submodule.register_backward_hook(bwd_hook_fun) @@ -747,6 +959,24 @@ class TrainerMon: self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name) logger.info(f"> {name} is monitored successfully") hooked_count += 1 + + if not self.print_struct and self.recording_l2_features: + for module_name, submodule in get_submodules(module): + func_map = { + "attention_hook": extract_attention_feature_hook, + "linear_hook": extract_linear_sr_hook + } + for hook in func_map.keys(): + if hook in l2_target_names: + temp_names = l2_target_names[hook] + name = is_recording_module(module_name, temp_names, vpp_stage) + if name: + handle = fwd_hook_register(submodule, func_map[hook], name=name) + print_feature_name = hook.split('_')[0] + logger.info_on_rank_0( + f'> {print_feature_name} features of {name} is monitored successfully') + self.handles["L2_features"].append(handle) + hooked_count += 1 return hooked_count def _patch_grad_sync(self): @@ -758,22 +988,30 @@ class TrainerMon: context = self.grad_context @_no_grad() - def param_hook(grad, context_dict, param, key): + def param_hook(grad, context_dict, param, name): + key = name + if self.monitor_mbs_grad: + key += f'{MonitorConst.NAME_SEP}{param.micro_step}' + key = get_summary_writer_tag_name(key, 'acc_grad', self.rank) + self.register_param_call_id("param_hook", key) param.micro_step += 1 - self._register_param_call_id("param_hook", key) + + if self.monitor_mbs_grad or (param.micro_step == self.micro_batch_number): + context_dict[key] = grad if param.micro_step == self.micro_batch_number: param.micro_step = 0 - context_dict[key] = grad - def param_hook_wrapper(param_hook, context_dict, param, key): + def param_hook_wrapper(param_hook, context_dict, param, name): def wrapper(grad): - return param_hook(grad, context_dict, param, key) + return param_hook(grad, context_dict, param, name) + return wrapper + logger.info("hooking weights.") for param, name in self.param2name.items(): - key = get_summary_writer_tag_name(name, 'acc_grad', self.rank) setattr(param, 'micro_step', 0) - handle = param.register_hook(param_hook_wrapper(param_hook, context_dict=context.acc, param=param, key=key)) + handle = param.register_hook( + param_hook_wrapper(param_hook, context_dict=context.acc, param=param, name=name)) self.handles['wgrads'].append(handle) self.weight_hooked = True @@ -799,26 +1037,21 @@ class TrainerMon: return pattern return "" - def _register_param_call_id(self, hook_name: str, key: str): - """ - :param hook_name: - :param key: str, '0:relu_0/output_grad' - :return: - """ - logger.debug(f"{hook_name} {key}: {self.call_id}") - self.param_name_call_id[key] = self.call_id - self.call_id += 1 - - def _remove_all_hooks(self): + def _remove_all_hooks(self, optimizer): # 清空hook handle for handle in self.handles['xy']: handle.remove() self.handles['xy'].clear() + for handle in self.handles['L2_features']: + handle.remove() + self.handles['L2_features'].clear() # 清空对应context缓存 - for _, fwd_context in self.module_fwd_hook_context_by_module.items(): + for fwd_context in self.module_fwd_hook_context_by_module.values(): fwd_context.reset() - for _, bwd_context in self.module_bwd_hook_context_by_module.items(): + for bwd_context in self.module_bwd_hook_context_by_module.values(): bwd_context.reset() + for feature_context in self.feature_hook_context_by_module.values(): + feature_context.reset() self.grad_context.reset() # 权重梯度和激活值梯度都在这 for handle in self.handles['wgrads']: @@ -827,9 +1060,8 @@ class TrainerMon: self.weight_hooked = False if self.optimizer_hooked: - for handle in self.handles['optimizer']: - handle.remove() - self.handles['optimizer'].clear() + self.pre_step_hooks.clear() + self.post_step_hooks.clear() for _, context in self.optimizer_context.items(): context.reset() self.optimizer_hooked = False @@ -837,6 +1069,7 @@ class TrainerMon: for handle in self.handles['cc']: handle.remove() self.handles['cc'].clear() + api_register.restore_api() for _, context in self.cc_context.items(): context.reset() @@ -867,4 +1100,4 @@ class TrainerMon: except Exception as e: logger.warning(f"Finish monitor, set config'dynamic_on=False fail because {e}, please check!!!") logger.info("Finish monitor") - self._remove_all_hooks() + self._remove_all_hooks(optimizer) diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/module_spec_verifier.py b/debug/accuracy_tools/msprobe/mindspore/monitor/module_spec_verifier.py deleted file mode 100644 index c06e8ea10f6a2178c3670e596ad64e333db44cab..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/module_spec_verifier.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re -import abc -from mindspore import Tensor - -from msprobe.core.common.log import logger - - -# 用于存储所有validator实现类的注册表 -config_validator_registry = {} - - -def register_config_validator(cls): - """装饰器 用于注册ConfigValidator的实现类""" - config_validator_registry[cls.__name__] = cls - return cls - - -class ConfigValidator(metaclass=abc.ABCMeta): - @abc.abstractmethod - def check_pattern_match(self, config_spec: str): - pass - - @abc.abstractmethod - def validate(self, actual_data, module_name: str, data_type: str, pattern_match): - pass - - -@register_config_validator -class TensorValidator(ConfigValidator): - def check_pattern_match(self, config_spec: str): - pattern = re.compile(r"tensor") - return pattern.match(config_spec) - - def validate(self, actual_data, module_name: str, data_type: str, pattern_match): - if not isinstance(actual_data, Tensor): - raise ValueError( - f"Format of {module_name} {data_type} does not match the required format 'tensor' in config.") - - -@register_config_validator -class TupleValidator(ConfigValidator): - def check_pattern_match(self, config_spec: str): - pattern = re.compile(r"tuple\[(\d+)\]:?(\d+)?") - return pattern.match(config_spec) - - def validate(self, actual_data, module_name: str, data_type: str, pattern_match): - length, index = pattern_match.groups() - if index is None: - index = 0 - length, index = int(length), int(index) - - if not (0 <= index < length): - raise ValueError( - f"Format of {module_name} {data_type} in config.json does not match the required format 'tuple[x]:y'." - f"y must be greater than or equal to 0 and less than x.") - if not isinstance(actual_data, tuple): - raise ValueError( - f"Type of {module_name} {data_type} does not match spec of config.json, should be tuple, please check.") - if len(actual_data) != length: - raise ValueError( - f"Length of {module_name} {data_type} does not match spec of config.json, should be {length}, " - f"actual is {len(actual_data)} please check.") - return index - - -def validate_config_spec(config_spec: str, actual_data, module_name: str, data_type: str): - focused_col = None - for _, validator_cls in config_validator_registry.items(): - config_validator = validator_cls() - pattern_match = config_validator.check_pattern_match(config_spec) - if pattern_match: - try: - focused_col = config_validator.validate(actual_data, module_name, data_type, pattern_match) - except ValueError as e: - logger.warning(f"config spec validate failed: {str(e)}") - return focused_col - logger.warning(f"config spec in {module_name} {data_type} not supported, " - f"expected spec:'tuple\[(\d+)\]:(\d+)' or 'tensor', actual spec: {config_spec}.") - return focused_col \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/optimizer_collect.py b/debug/accuracy_tools/msprobe/mindspore/monitor/optimizer_collect.py new file mode 100644 index 0000000000000000000000000000000000000000..b6eaad2d79ee6a7587bf4d7633f875858dc5eb0b --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/optimizer_collect.py @@ -0,0 +1,334 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import abstractmethod + +from mindspore import mint, ops + +from msprobe.mindspore.common.log import logger +from msprobe.core.common.const import MonitorConst + + +class OptimizerMon(object): + def __init__(self, optim) -> None: + self.fp16_to_fp32_param = {} + self.optim = optim + self.state = {} + + def narrow_from_flatten(self, param, flatten_state): + return flatten_state + + def get_state(self, optim): + if hasattr(optim, 'chained_optimizers'): + for opt in optim.chained_optimizers: + self._get_single_state(opt) + else: + self._get_single_state(optim) + + def fetch_grad(self, monitor, params2name): + if not self.fp16_to_fp32_param: + self.map_fp16_to_fp32_param(self.optim) + + grad_dict = {} + first_param = True + for param, name in params2name.items(): + if monitor.duplicate_param.get(name, False): + continue + if self.fp16_to_fp32_param and param not in self.fp16_to_fp32_param: + continue + grad = param.main_grad if monitor.params_have_main_grad else param.grad + element_in_cur_partition = self.fp16_to_fp32_param.get(param, param).numel() + if param.numel() != element_in_cur_partition: + if first_param: + grad = grad.flatten()[-element_in_cur_partition:] + else: # supposed to be the last one + grad = grad.flatten()[:element_in_cur_partition] + first_param = False + if grad is None: + continue + tag = monitor.name2tag.get(name, {}).get(MonitorConst.POST_GRAD) + monitor.register_param_call_id("hook_optimizer", tag) + grad_dict[tag] = grad + return grad_dict + + def map_fp16_to_fp32_param(self, optim): + pass + + def fetch_mv(self, monitor, params2name): + if not self.fp16_to_fp32_param: + self.map_fp16_to_fp32_param(self.optim) + if not self.state: + self.get_state(self.optim) + + exp_avg_dict = {} + exp_avg_sq_dict = {} + update_dict = {} + ratio_dict = {} + + if not self.state: + logger.warning('optimizer state can not accessed') + return exp_avg_dict, exp_avg_sq_dict, update_dict, ratio_dict + + for lp_param, name in params2name.items(): + if lp_param in self.fp16_to_fp32_param: + hp_param = self.fp16_to_fp32_param[lp_param] + else: + hp_param = lp_param + + if hp_param in self.state: + state_param = self.state.get(hp_param, {}) + exp_avg = self.narrow_from_flatten(lp_param, state_param.get("exp_avg", None)) + exp_avg_sq = self.narrow_from_flatten(lp_param, state_param.get("exp_avg_sq", None)) + if monitor.mv_distribution: + exp_avg_dict[name] = exp_avg + exp_avg_sq_dict[name] = exp_avg_sq + if monitor.mg_direction: + exp_avg_dict[name] = exp_avg + if monitor.ur_distribution: + if len(self.optim.param_groups) > 1: + logger.info(f"the length of optim.param_groups is {len(self.optim.param_groups)}.") + if 'step' in state_param: + step = state_param['step'] # Optimizer from pytorch or FusedAdam from apex(used by megatron) + elif 'step' in self.optim.param_groups[0]: + step = self.optim.param_groups[0]['step'] # AdamW from mindspeed + else: + logger.warning(f"step of {name} is None, maybe something wrong happened.") + continue + if exp_avg is None or exp_avg_sq is None: + logger.warning(f"exp_avg or exp_avg_sq of {name} is None, skip calculation.") + continue + exp_avg_hat = exp_avg / (1 - self.optim.defaults['betas'][0] ** step) + exp_avg_sq_hat = exp_avg_sq / (1 - self.optim.defaults['betas'][1] ** step) + update_dict[name] = exp_avg_hat / (mint.sqrt(exp_avg_sq_hat) + self.optim.defaults['eps']) + ratio_dict[name] = exp_avg_hat / mint.sqrt(exp_avg_sq_hat) + monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name]) + monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name]) + return exp_avg_dict, exp_avg_sq_dict, update_dict, ratio_dict + + def _get_single_state(self, optim): + state = {} + if hasattr(optim, 'param_to_cpu_states_map'): + state = optim.param_to_cpu_states_map + elif hasattr(optim, 'state'): + state = optim.state + elif hasattr(optim, 'optimizer') and hasattr(optim.optimizer, 'state'): + state = optim.optimizer.state + self.state.update(state) + + +class MixPrecisionOptimizerMon(OptimizerMon): + """ + 混合精度优化器监控类。在混合精度训练中监控和管理优化器。 + 混合精度训练通过适当降低某些计算的精度来加速训练过程并减少内存消耗。 + """ + def map_fp16_to_fp32_param(self, optim): + for fp16_group, fp32_group in zip(optim.float16_groups, optim.fp32_from_float16_groups): + for fp16_param, fp32_param in zip(fp16_group, fp32_group): + self.fp16_to_fp32_param[fp16_param] = fp32_param + + +class MegatronDistributedOptimizerMon(OptimizerMon): + def map_fp16_to_fp32_param(self, optim): + if not (hasattr(optim, "model_float16_groups") and + hasattr(optim, "shard_fp32_from_float16_groups")): + raise Exception( + "megatron distributed optimizer should have model_float16_groups and shard_fp32_from_float16_groups, " + "if not, please check megatron-lm version") + for fp16_group, shard_fp32_group in zip(optim.model_float16_groups, + optim.shard_fp32_from_float16_groups): + for fp16_param, shard_fp32_param in zip(fp16_group, shard_fp32_group): + self.fp16_to_fp32_param[fp16_param] = shard_fp32_param + + +class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon): + def map_fp16_to_fp32_param(self, optim): + for opt in optim.chained_optimizers: + super().map_fp16_to_fp32_param(opt) + + +class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon): + def map_fp16_to_fp32_param(self, optim): + for opt in optim.chained_optimizers: + super().map_fp16_to_fp32_param(opt) + + +class DeepSpeedZeroOptimizerMon(OptimizerMon): + """ + Base monitor class for DeepSpeed ZeRO optimizer. + ZeRO stage 0 no partition + ZeRO stage 1 partitions optimizer states across data parallel processes. + ZeRO stage 2 additionally partitions gradients. + ZeRO stage 3 additionally partitions parameters. + + This class provides monitoring capabilities for ZeRO optimizers by: + - Handling gradient collection for different ZeRO stages + - Managing optimizer state access for monitoring + """ + def __init__(self, optim): + super().__init__(optim) + self.stage = '' + self.bit16_groups = [] + self.fp32_flat_groups = [] + self.param2group = () + self.param2index = [] + self.group_offset = {} + + @abstractmethod + def get_grad_for_param(self, lp_param, group_idx, param_id): + raise NotImplementedError + + def param_not_in_partition(self, lp_param, group_idx): + param_slice_mapping = self.optim.state_dict()['param_slice_mappings'][group_idx] + hp_address = param_slice_mapping.get(self.optim.param_names.get(lp_param)) + return hp_address is None + + def get_position(self, lp_param, group_idx): + param_slice_mapping = self.optim.state_dict()['param_slice_mappings'][group_idx] + hp_address = param_slice_mapping.get(self.optim.param_names.get(lp_param)) + return hp_address.start, hp_address.numel + + def get_group_index(self): + param2group = {} + for group_idx, bit16_group in enumerate(self.bit16_groups): + for param in bit16_group: + param2group[param] = group_idx + return param2group + + def get_param_index(self, lp_param, group_idx): + if not self.param2index: + for group in self.bit16_groups: + param2index = {} + for index, param in enumerate(group): + param2index[param] = index + self.param2index.append(param2index) + + return self.param2index[group_idx][lp_param] + + def narrow_from_flatten(self, param, flatten_state): + if flatten_state is None: + return flatten_state + group_idx = self.param2group[param] + if self.param_not_in_partition(param, group_idx): + return None + start, numel = self.get_position(param, group_idx) + return flatten_state.narrow(0, start, numel) + + def map_fp16_to_fp32_param(self, optim): + for group_idx, group in enumerate(self.bit16_groups): + for param in group: + self.fp16_to_fp32_param[param] = self.fp32_flat_groups[group_idx] + + def fetch_grad(self, monitor, params2name): + grad_dict = {} + for lp_param, name in params2name.items(): + group_idx = self.param2group[lp_param] + param_id = self.get_param_index(lp_param, group_idx) + if self.param_not_in_partition(lp_param, group_idx): + continue + if self.stage == '1or2': + param_id = param_id - self.group_offset[group_idx] - 1 + grad = self.get_grad_for_param(lp_param, group_idx, param_id) + tag = monitor.name2tag.get(name, {}).get(MonitorConst.POST_GRAD) + monitor.register_param_call_id("hook_optimizer", tag) + grad_dict[tag] = grad + + return grad_dict + + +class DeepSpeedZeroOptimizerStage0Mon(DeepSpeedZeroOptimizerMon): + def __init__(self, optim): + super().__init__(optim) + self.stage = '0' + self.bit16_groups = optim.bf16_groups + self.fp32_flat_groups = optim.fp32_groups_flat_partition + self.param2group = self.get_group_index() + + def get_grad_for_param(self, lp_param, group_idx, param_id): + return self.optim.fp32_groups_gradient_dict[group_idx][param_id] + + +class DeepSpeedZeroOptimizerStage1or2Mon(DeepSpeedZeroOptimizerMon): + def __init__(self, optim): + super().__init__(optim) + self.stage = '1or2' + self.bit16_groups = optim.bit16_groups + self.fp32_flat_groups = optim.single_partition_of_fp32_groups + self.param2group = self.get_group_index() + self.group_offset = {} + self.get_group_offset() + + def get_grad_for_param(self, lp_param, group_idx, param_id): + if getattr(self.optim, "cpu_offload", False): + grads = self.optim.single_partition_of_fp32_groups[group_idx].grad + start, numel = self.get_position(lp_param, group_idx) + grad = grads.narrow(0, start, numel) + else: + grad = self.optim.averaged_gradients[group_idx][param_id] + return grad + + def get_group_offset(self): + for group_idx, group in enumerate(self.bit16_groups): + self.group_offset[group_idx] = -1 + for lp_param in group: + if self.param_not_in_partition(lp_param, group_idx): + self.group_offset[group_idx] = self.get_param_index(lp_param, group_idx) + else: + break + + +class DeepSpeedZeroOptimizerStage3Mon(DeepSpeedZeroOptimizerMon): + def __init__(self, optim): + super().__init__(optim) + self.stage = '3' + self.bit16_groups = optim.fp16_groups + self.fp32_flat_groups = optim.fp32_partitioned_groups_flat + self.param2group = self.get_group_index() + + def param_not_in_partition(self, lp_param, group_idx): + """Each param partioned across all zero ranks""" + return False + + def get_position(self, lp_param, group_idx): + param_id = self.optim.get_param_id(lp_param) + return self.optim.grad_position[param_id][1:] + + def get_grad_for_param(self, lp_param, group_idx, param_id): + return self.optim.averaged_gradients[group_idx][param_id] + + +class OptimizerMonFactory: + _optimizer_mon_map = { + "FP32Optimizer": OptimizerMon, + "Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon, + "DistributedOptimizer": MegatronDistributedOptimizerMon, + "SwapDistributedOptimizer": MegatronDistributedOptimizerMon, + "ChainedDistributedOptimizer": MegatronChainedDistributedOptimizerMon, + "ChainedSwapDistributedOptimizer": MegatronChainedDistributedOptimizerMon, + "ChainedFloat16OptimizerWithFloat16Params": MegatronChainedMixPrecisionOptimizerMon, + "BF16_Optimizer": DeepSpeedZeroOptimizerStage0Mon, + "DeepSpeedZeroOptimizer": DeepSpeedZeroOptimizerStage1or2Mon, + "DeepSpeedZeroOptimizer_Stage3": DeepSpeedZeroOptimizerStage3Mon, + "Adam": OptimizerMon + } + + @staticmethod + def create_optimizer_mon(optimizer): + # auto replace opt_ty + optimizer_class = optimizer.__class__.__name__ + if optimizer_class == "ChainedOptimizer": + optimizer_class = "Chained" + optimizer.chained_optimizers[0].__class__.__name__ + logger.info(f'The optimizer type is {optimizer_class}') + + optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(optimizer_class, OptimizerMon) + return optimizer_mon_class(optimizer) diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py b/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py index a27172f19ead537f276c5ce0820b405d7abb6e25..d862bb2ef6de2f4b3ff7a5a5a264ed7abbe50f37 100644 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py @@ -12,30 +12,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -import re -from datetime import datetime from mindspore import dtype as mstype, Tensor -from msprobe.mindspore.monitor.features import FUNC_MAP -from msprobe.core.common.const import MonitorConst -from msprobe.core.common.utils import is_int -from msprobe.core.common.log import logger -from msprobe.core.common.file_utils import check_file_or_directory_path +from msprobe.mindspore.monitor.features import FUNC_MAP, cal_entropy, cal_stable_rank -def get_single_metrics(op_list, tag, tensor, output=None): +def get_single_metrics(op_list, tag, tensor, eps=1e-8, output=None): if output is None: output = {} if tag not in output: output[tag] = {} for op in op_list: func = FUNC_MAP.get(op) - statistic = func(tensor) + if op == "zeros": + statistic = func(tensor, eps) + else: + statistic = func(tensor) if hasattr(statistic, "dtype") and statistic.dtype == mstype.bfloat16: statistic = float(statistic) statistic = Tensor(statistic) - output[tag][op] = statistic.astype(mstype.float32) + if isinstance(statistic, Tensor): + output[tag][op] = statistic.astype(mstype.float32) + else: + output[tag][op] = statistic def get_metrics(op_list, tag2tensor, eps, output=None): @@ -44,7 +43,7 @@ def get_metrics(op_list, tag2tensor, eps, output=None): for tag, tensor in tag2tensor.items(): if tag not in output: output[tag] = {} - get_single_metrics(op_list, tag, tensor, output) + get_single_metrics(op_list, tag, tensor, eps, output) return output @@ -78,224 +77,27 @@ def is_skip_step(step, start_step, step_interval, has_collect_times=0, collect_t return step < start_step or (step - start_step) % step_interval != 0 or has_collect_times >= collect_times -def validate_ops(ops): - if not isinstance(ops, list): - raise TypeError("ops should be a list") - valid_ops = [] - for op in ops: - if op not in MonitorConst.OP_LIST: - logger.warning(f"op {op} is not supported. Optional ops: {MonitorConst.OP_LIST}") - continue - valid_ops.append(op) - if not valid_ops: - default_op = MonitorConst.OP_LIST[0] - valid_ops.append(default_op) - logger.info(f"There is no valid ops, default op {default_op} is used") - return valid_ops - - -def validate_ranks(ranks): - if not isinstance(ranks, list): - raise TypeError("module_ranks should be a list") - for rank in ranks: - if not isinstance(rank, int): - raise TypeError(f"element in module_ranks should be a int, get {type(rank)}") - - -def validate_targets(targets): - if not isinstance(targets, dict): - raise TypeError('targets in config.json should be a dict') - for module_name, field in targets.items(): - if not isinstance(module_name, str): - raise TypeError('key of targets should be module_name[str] in config.json') - if not isinstance(field, dict): - raise TypeError('values of targets should be cared filed e.g. {"input": "tensor"} in config.json') - - -def validate_print_struct(print_struct): - if not isinstance(print_struct, bool): - raise TypeError("print_struct should be a bool") - - -def validate_ur_distribution(ur_distribution): - if not isinstance(ur_distribution, bool): - raise TypeError('ur_distribution should be a bool') - - -def validate_xy_distribution(xy_distribution): - if not isinstance(xy_distribution, bool): - raise TypeError('xy_distribution should be a bool') - - -def validate_wg_distribution(wg_distribution): - if not isinstance(wg_distribution, bool): - raise TypeError('wg_distribution should be a bool') - - -def validate_mg_distribution(mg_distribution): - if not isinstance(mg_distribution, bool): - raise TypeError('mg_distribution should be a bool') - - -def validate_param_distribution(param_distribution): - if not isinstance(param_distribution, bool): - raise TypeError('param_distribution should be a bool') - - -def validate_cc_distribution(cc_distribution): - if not isinstance(cc_distribution, dict): - raise TypeError('cc_distribution should be a dictionary') - expected_keys = { - 'enable': bool, - 'cc_codeline': list, - 'cc_pre_hook': bool, - 'cc_log_only': bool - } - for key, value in cc_distribution.items(): - if key in expected_keys: - if not isinstance(value, expected_keys[key]): - raise TypeError(f'cc_distribution {key} should be a {expected_keys[key].__name__}') - else: - raise TypeError(f'{key} of cc_distribution is not supported.') - - -def validate_alert(alert): - if not isinstance(alert, dict): - raise TypeError('alert should be a dictionary') - rules = alert.get('rules') - if rules and isinstance(rules, list): - for rule in rules: - rule_name = rule.get("rule_name") - if rule_name and rule_name not in MonitorConst.RULE_NAME: - raise TypeError(f"{rule_name} is not supported") - args = rule.get("args") - if args and isinstance(args, dict): - threshold = args.get("threshold") - if not isinstance(threshold, float) or threshold < 0: - raise TypeError('threshold must be float and not less than 0') - dump = alert.get('dump') - if dump and not isinstance(dump, bool): - raise TypeError('dump must be bool.') - - -def validate_step_count_per_record(step_count_per_record): - if not is_int(step_count_per_record): - raise TypeError('step_count_per_record must be int.') - if step_count_per_record < 1: - raise ValueError("step_count_per_record must greater than 0") - if step_count_per_record > 1e6: - raise ValueError("step_count_per_record must smaller than 1e6") - - -def validate_start_step(start_step): - if not is_int(start_step): - raise TypeError('start_step must be int.') - if start_step < 0: - raise ValueError("start_step must greater than 0") - if start_step > 1e8: - raise ValueError("start_step must smaller than 1e8") - - -def validate_step_interval(step_interval): - if not is_int(step_interval): - raise TypeError('step_interval must be int.') - if step_interval < 1: - raise ValueError("step_interval must greater than 1") - if step_interval > 1e8: - raise ValueError("step_interval must smaller than 1e8") - - -def validate_collect_times(collect_times): - if not is_int(collect_times): - raise TypeError('collect_times must be int.') - if collect_times < 1: - raise ValueError("collect_times must greater than 1") - - -def validate_config(config): - config['ops'] = validate_ops(config.get('ops', [])) - - eps = config.get('eps', 1e-8) - if not isinstance(eps, float): - raise TypeError("eps should be a float") - - ranks = config.get("module_ranks", []) - validate_ranks(ranks) - - targets = config.get("targets", {}) - validate_targets(targets) - - print_struct = config.get('print_struct', False) - validate_print_struct(print_struct) - - ur_distribution = config.get('ur_distribution', False) - validate_ur_distribution(ur_distribution) - - xy_distribution = config.get('xy_distribution', False) - validate_xy_distribution(xy_distribution) - - wg_distribution = config.get('wg_distribution', False) - validate_wg_distribution(wg_distribution) - - mg_distribution = config.get('mg_distribution', False) - validate_mg_distribution(mg_distribution) - - param_distribution = config.get('param_distribution', False) - validate_param_distribution(param_distribution) - - cc_distribution = config.get('cc_distribution', {}) - validate_cc_distribution(cc_distribution) - - alert = config.get('alert', {}) - validate_alert(alert) - - step_count_per_record = config.get('step_count_per_record', 1) - validate_step_count_per_record(step_count_per_record) - - start_step = config.get('start_step', 0) - validate_start_step(start_step) - - step_interval = config.get('step_interval', 1) - validate_step_interval(step_interval) - - collect_times = config.get('collect_times', 1e8) - validate_collect_times(collect_times) - - if not targets: - if xy_distribution: - config["all_xy"] = True - config["targets"] = {"": {}} - config["is_select"] = False - else: - config["is_select"] = True - - -def time_str2time_digit(time_str): - time_format = '%b%d_%H-%M-%S' - try: - time_digit = datetime.strptime(time_str, time_format) - except Exception as e: - raise RuntimeError(f"illegal timestamp: {time_str}, timestamp should be prefix \ - of existing output dirpath, like 'Dec03_21-34-40'.") from e - return time_digit +def get_entropy_metric(tag2tensor, out_dict=None): + if out_dict is None: + out_dict = {} + for tag, tensor in tag2tensor.items(): + if tag not in out_dict: + out_dict[tag] = {} + entropy, softmax = cal_entropy(tensor) + out_dict[tag]["entropy"] = entropy + out_dict[tag]["softmax"] = softmax + return out_dict -def get_target_output_dir(monitor_path, time_start, time_end): - check_file_or_directory_path(monitor_path, isdir=True) - time_start = time_str2time_digit(time_start) if time_start is not None else time_start - time_end = time_str2time_digit(time_end) if time_end is not None else time_end - if time_start and time_end and time_start > time_end: - raise ValueError(f"time_start({time_start}) greater than time_end({time_end})") - result = {} - for dirname in os.listdir(monitor_path): - match = re.match(MonitorConst.OUTPUT_DIR_PATTERN, dirname) - if not match: +def get_sr_metric(tag2tensor, out_dict=None): + if out_dict is None: + out_dict = {} + for tag, tensor in tag2tensor.items(): + if "sr" not in tag: continue - time_tag = match.group(1) - rank = match.group(2) - target_time = time_str2time_digit(time_tag) - start_ok = time_start is None or target_time >= time_start - end_ok = time_end is None or target_time <= time_end - if start_ok and end_ok: - result[rank] = os.path.join(monitor_path, dirname) - return result + if tag not in out_dict: + out_dict[tag] = {} + sr, eig = cal_stable_rank(tensor) + out_dict[tag]["sr"] = sr + out_dict[tag]["eig"] = eig + return out_dict diff --git a/debug/accuracy_tools/msprobe/mindspore/ms_config.py b/debug/accuracy_tools/msprobe/mindspore/ms_config.py index f20ed804c5bb8d8fbe4dba3e208060e8f52a3120..5194b2f5472e26afc658d648a11f34e0180dda36 100644 --- a/debug/accuracy_tools/msprobe/mindspore/ms_config.py +++ b/debug/accuracy_tools/msprobe/mindspore/ms_config.py @@ -14,7 +14,6 @@ # limitations under the License. from msprobe.core.common.const import Const -from msprobe.core.common.file_utils import load_json from msprobe.core.common.utils import is_int from msprobe.core.common_config import BaseConfig, CommonConfig from msprobe.core.grad_probe.constant import level_adp @@ -29,6 +28,7 @@ class TensorConfig(BaseConfig): self.check_mode = None self.file_format = json_config.get("file_format") self.check_config() + self._check_summary_mode() self._check_config() def _check_config(self): @@ -42,15 +42,27 @@ class StatisticsConfig(BaseConfig): self.file_format = None self.check_mode = None self.check_config() - self._check_config() - - def _check_config(self): - single_opt = ["statistics", "md5"] - muti_opt = ["md5", "max", "min", "mean", "l2norm"] - if isinstance(self.summary_mode, str) and self.summary_mode not in single_opt: - raise Exception("summary_mode is invalid") + self._check_summary_mode() + + self.tensor_list = json_config.get("tensor_list", []) + self._check_str_list_config(self.tensor_list, "tensor_list") + self.stat_cal_mode = json_config.get("device", "host") + self.device_stat_precision_mode = json_config.get("precision", "high") + self._check_stat_params() + + def _check_stat_params(self): + if self.stat_cal_mode not in ["device", "host"]: + raise Exception("Config param [device] is invalid, expected from [\"device\", \"host\"]") + if self.device_stat_precision_mode not in ["high", "low"]: + raise Exception("Config param [precision] is invalid, expected from [\"high\", \"low\"]") + + def _check_summary_mode(self): + muti_opt = ["max", "min", "mean", "count", "negative zero count", "positive zero count", "nan count", + "negative inf count", "positive inf count", "zero count", "l2norm", "hash", "md5"] + if isinstance(self.summary_mode, str) and self.summary_mode not in Const.SUMMARY_MODE: + raise Exception("summary_mode is an invalid string") if isinstance(self.summary_mode, list) and not all(opt in muti_opt for opt in self.summary_mode): - raise Exception("summary_mode is invalid") + raise Exception("summary_mode contains invalid option(s)") class OverflowCheckConfig(BaseConfig): @@ -68,6 +80,12 @@ class OverflowCheckConfig(BaseConfig): raise Exception("check_mode is invalid") +class ExceptionDumpConfig(BaseConfig): + def __init__(self, json_config): + super().__init__(json_config) + self.data_mode = ["all"] + + class FreeBenchmarkConfig(BaseConfig): def __init__(self, task_config): super().__init__(task_config) @@ -117,7 +135,8 @@ TaskDict = { Const.OVERFLOW_CHECK: OverflowCheckConfig, Const.FREE_BENCHMARK: FreeBenchmarkConfig, Const.GRAD_PROBE: GradProbeConfig, - Const.STRUCTURE: StructureConfig + Const.STRUCTURE: StructureConfig, + Const.EXCEPTION_DUMP: ExceptionDumpConfig } @@ -132,14 +151,3 @@ def parse_task_config(task, json_config): if task not in TaskDict: raise Exception("task is invalid.") return TaskDict.get(task)(task_map) - - -def parse_json_config(json_file_path): - if not json_file_path: - raise Exception("json file path is None") - json_config = load_json(json_file_path) - common_config = parse_common_config(json_config) - if not common_config.task: - common_config.task = Const.STATISTICS - task_config = parse_task_config(common_config.task, json_config) - return common_config, task_config diff --git a/debug/accuracy_tools/msprobe/mindspore/overflow_check/overflow_check_tool_factory.py b/debug/accuracy_tools/msprobe/mindspore/overflow_check/overflow_check_tool_factory.py index a2d3e290bd6b16b3deeb7f22a5e7d327ebaa2bc4..bae3104c249654dee5962b0d509cb53db5666b74 100644 --- a/debug/accuracy_tools/msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +++ b/debug/accuracy_tools/msprobe/mindspore/overflow_check/overflow_check_tool_factory.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from msprobe.core.common.log import logger from msprobe.mindspore.common.const import Const from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.overflow_check.kernel_graph_overflow_check import KernelGraphOverflowCheck @@ -31,7 +32,7 @@ class OverflowCheckToolFactory: Const.PYNATIVE_MODE: None }, Const.KERNEL: { - Const.GRAPH_KBYK_MODE: None, + Const.GRAPH_KBYK_MODE: KernelGraphOverflowCheck, Const.GRAPH_GE_MODE: KernelGraphOverflowCheck, Const.PYNATIVE_MODE: None } @@ -44,6 +45,7 @@ class OverflowCheckToolFactory: raise Exception("Valid level is needed.") tool = tool.get(config.execution_mode) if not tool: - raise Exception(f"Overflow check is not supported in {config.execution_mode} mode " - f"when level is {config.level}.") - return tool(config) + logger.error(f"Overflow check is not supported in {config.execution_mode} mode " + f"when level is {config.level}.") + raise ValueError + return (tool(config),) diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py deleted file mode 100644 index 5afbd046be4caf29c4b247a0f8fdd655c5208fd0..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/mindspore/service.py +++ /dev/null @@ -1,543 +0,0 @@ -# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import functools -import os -from collections import defaultdict - -import mindspore as ms -from mindspore import nn -from mindspore.common.api import _no_grad -from mindspore.ops.primitive import Primitive - -try: - from mindspore.common._pijit_context import PIJitCaptureContext -except ImportError: - pijit_label = False -else: - pijit_label = True - -from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException -from msprobe.core.common.file_utils import create_directory -from msprobe.core.common.utils import Const, print_tools_ends_info, DumpPathAggregation -from msprobe.core.data_dump.data_collector import build_data_collector -from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, - ModuleBackwardInputs) -from msprobe.core.data_dump.scope import BaseScope -from msprobe.mindspore.cell_processor import CellProcessor -from msprobe.mindspore.common.log import logger -from msprobe.mindspore.common.utils import (get_rank_if_initialized, clean_input_kwargs, - is_mindtorch, register_backward_hook_functions) -from msprobe.mindspore.dump.hook_cell.api_registry import api_register -from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService -from msprobe.mindspore.dump.jit_dump import JitDump -from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell -from msprobe.mindspore.dump.kernel_dump.kernel_config import create_kernel_config_json - -if is_mindtorch(): - import torch - - -class Service: - def __init__(self, config): - self.model = None - self.config = copy.deepcopy(config) - self.config.level = self.config.level_ori - self.data_collector = build_data_collector(self.config) - self.cell_processor = CellProcessor(self.data_collector.scope) - self.primitive_hook_service = PrimitiveHookService(self) - self.switch = False - self.inner_switch = False - self.primitive_switch = False - self.current_iter = 0 - self.first_start = True - self.current_rank = None - self.dump_iter_dir = None - self.start_call = False - self.should_stop_service = False - self.params_grad_info = {} - self.hook_handle_dict = {} - # 提前注册,确保注册尽可能多的API hook - self.register_api_hook() - self.init_for_debug_level() - - @staticmethod - def check_model_valid(models): - target_module_type = (torch.nn.Module, "torch.nn.Module") if is_mindtorch() else (nn.Cell, "mindspore.nn.Cell") - if models is None or isinstance(models, target_module_type[0]): - return models - error_model = None - if isinstance(models, (list, tuple)): - for model in models: - if not isinstance(model, target_module_type[0]): - error_model = model - break - else: - error_model = models - - if error_model is not None: - error_info = (f"The 'model' parameter must be a {target_module_type[1]} or list[{target_module_type[1]}] " - f"type, currently there is a {type(error_model)} type.") - raise MsprobeException( - MsprobeException.INVALID_PARAM_ERROR, error_info) - return models - - @staticmethod - def prepare_module_input_output(target_type, cell, input_data, output): - if target_type == BaseScope.Module_Type_Module: - module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs={}, output=output) - else: - module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs=cell.input_kwargs, output=output) - return module_input_output - - def build_hook(self, target_type, name): - def pre_hook(api_or_cell_name, cell, input_data): - if not self.should_execute_hook(target_type, cell, True): - clean_input_kwargs(cell) - return None - - with _no_grad(): - self.inner_switch = True - if target_type == BaseScope.Module_Type_Module: - api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name) - else: - cell.forward_data_collected = True - HOOKCell.add_cell_count(name) - module_input_output = self.prepare_module_input_output(target_type, cell, input_data, None) - self.data_collector.update_api_or_module_name(api_or_cell_name) - self.data_collector.forward_input_data_collect(api_or_cell_name, cell, pid, module_input_output) - self.inner_switch = False - return input_data - - def grad_hook(cell, ori_name, param_name): - def hook_fn(grad): - if not self.should_execute_hook(target_type, cell, False): - return None - self.inner_switch = True - self.data_collector.params_data_collect(ori_name, param_name, pid, grad) - self.inner_switch = False - return None - - return hook_fn - - def register_param_hook(ori_name, cell, params_dict): - ''' - 注册参数hook - ''' - # data_mode为forward时,不注册参数hook - if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode): - for param_name, param in params_dict.items(): - if param.requires_grad: - name = ori_name + Const.SEP + param_name - old_handle = self.hook_handle_dict.get(name) - if old_handle and hasattr(old_handle, "remove"): - old_handle.remove() - handle = param.register_hook(grad_hook(cell, ori_name, param_name)) - self.hook_handle_dict[name] = handle - - def init_params_grad_info(cell, params_dict): - ''' - 初始化参数梯度信息, 在前向hook结束后, 将参数梯度信息写入cache_data中用于占位 - ''' - if not params_dict: - return - if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode): - grad_name = cell.params_grad_name if hasattr(cell, 'params_grad_name') else None - # 判断是否已经在cache_data中进行了占位, 若没有则先写入cache_data中 - if not self.params_grad_info.get(grad_name): - data_info = {grad_name: {key: [None] for key, value in params_dict.items() if value.requires_grad}} - # 当模块中的参数有requires_grad属性为True时,才会进行梯度计算,此时才需要占位 - if data_info.get(grad_name): - # 将grad_name的data_info先写入cache_data中, 梯度计算后再更新 - self.data_collector.handle_data(grad_name, data_info, - flush=self.data_collector.data_processor.is_terminated) - # 记录当前模块的参数梯度信息已占位 - self.params_grad_info[grad_name] = True - - def forward_hook(api_or_cell_name, cell, input_data, output): - if not self.should_execute_hook(target_type, cell, True): - clean_input_kwargs(cell) - return None - with _no_grad(): - self.inner_switch = True - module_input_output = self.prepare_module_input_output(target_type, cell, input_data, output) - if target_type == BaseScope.Module_Type_Module: - api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name) - params_dict = {} - if self.config.task != Const.STRUCTURE: - params_dict = { - key.split(Const.SEP)[-1]: value - for key, value in cell.parameters_dict(recurse=False).items() - } - setattr(module_input_output, Const.PARAMS, params_dict) - # 判断是否需要注册参数hook - if params_dict: - ori_name = api_or_cell_name.rsplit(Const.SEP, 2)[0] - grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD - # 首次执行前向hook时,添加params_grad_name属性,并注册参数hook - setattr(cell, 'params_grad_name', grad_name) - register_param_hook(ori_name, cell, params_dict) - self.data_collector.update_api_or_module_name(api_or_cell_name) - self.data_collector.forward_data_collect(api_or_cell_name, cell, pid, module_input_output) - init_params_grad_info(cell, params_dict) - else: - self.data_collector.update_api_or_module_name(api_or_cell_name) - self.data_collector.forward_output_data_collect(api_or_cell_name, cell, pid, module_input_output) - - if self.data_collector.if_return_forward_new_output(): - forward_new_output = self.data_collector.get_forward_new_output() - self.inner_switch = False - return forward_new_output - clean_input_kwargs(cell) - self.inner_switch = False - return output - - def backward_hook(api_or_cell_name, cell, grad_input, grad_output): - if not self.should_execute_hook(target_type, cell, False): - return - self.inner_switch = True - - need_exchange = True - if target_type == BaseScope.Module_Type_Module: - if not hasattr(cell, 'has_pre_hook_called') or not cell.has_pre_hook_called: - need_exchange = False - api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name) - - self.data_collector.update_api_or_module_name(api_or_cell_name) - if self.data_collector: - # 框架最新接口变更,grad_input和grad_output的含义发生了变化,与torch含义保持一致,因此此处调换顺序传入 - if need_exchange: - module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input) - else: - module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output) - self.data_collector.backward_data_collect(api_or_cell_name, cell, pid, module_input_output) - self.inner_switch = False - - def pre_backward_hook(api_or_cell_name, cell, grad_input): - if not self.should_execute_hook(target_type, cell, False): - return - self.inner_switch = True - module_input = ModuleBackwardInputs(grad_input=grad_input) - self.data_collector.update_api_or_module_name(api_or_cell_name) - self.data_collector.backward_input_data_collect(api_or_cell_name, cell, pid, module_input) - - self.inner_switch = False - - pid = os.getpid() - if target_type == BaseScope.Module_Type_Module: - full_forward_name = name + Const.FORWARD - full_backward_name = name + Const.BACKWARD - else: - full_forward_name = name + str(HOOKCell.get_cell_count(name)) + Const.SEP + Const.FORWARD - full_backward_name = name + str(HOOKCell.get_cell_count(name)) + Const.SEP + Const.BACKWARD - pre_forward_hook = functools.partial(pre_hook, full_forward_name) - forward_hook = functools.partial(forward_hook, full_forward_name) - backward_hook = functools.partial(backward_hook, full_backward_name) - pre_backward_hook = functools.partial(pre_backward_hook, full_backward_name) - - def wrap_pre_forward_hook(cell, input_data): - return pre_forward_hook(cell, input_data) - - def wrap_forward_hook(cell, input_data, output_data): - return forward_hook(cell, input_data, output_data) - - def wrap_backward_hook(cell, grad_input, grad_output): - return backward_hook(cell, grad_input, grad_output) - - def wrap_pre_backward_hook(cell, grad_input): - return pre_backward_hook(cell, grad_input) - - return wrap_pre_forward_hook, wrap_forward_hook, wrap_backward_hook, wrap_pre_backward_hook - - def update_primitive_counters(self, primitive_name): - if primitive_name not in self.primitive_counters: - self.primitive_counters[primitive_name] = 0 - else: - self.primitive_counters[primitive_name] += 1 - - def step(self): - if self.config.level == Const.LEVEL_DEBUG: - return - if self.config.async_dump: - self.data_collector.fill_stack_tensor_data() - if self.config.task == Const.TENSOR: - self.data_collector.data_processor.dump_async_data() - self.data_collector.write_json() - self.current_iter += 1 - self.data_collector.update_iter(self.current_iter) - self.reset_status() - - def start(self, model=None): - if self.config.level == Const.LEVEL_DEBUG: - return - self.start_call = True - if self.should_stop_service: - return - if self.need_end_service(): - self.should_stop_service = True - self.switch = False - self.primitive_switch = False - print_tools_ends_info() - return - if self.config.step and self.current_iter not in self.config.step: - return - self.model = self.check_model_valid(model) - - logger.info(f"{Const.TOOL_NAME}: debugger.start() is set successfully") - - if self.first_start: - try: - self.current_rank = get_rank_if_initialized() - except DistributedNotInitializedError: - self.current_rank = None - - if self.config.rank and self.current_rank not in self.config.rank: - return - self.register_primitive_hook() - self.register_cell_hook() - if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]: - JitDump.set_config(self.config) - JitDump.set_data_collector(self.data_collector) - if hasattr(ms.common.api, "_MindsporeFunctionExecutor"): - ms.common.api._MindsporeFunctionExecutor = JitDump - else: - ms.common.api._JitExecutor = JitDump - ms.common.api._PyNativeExecutor.grad = JitDump.grad - if pijit_label: - PIJitCaptureContext.__enter__ = self.empty - PIJitCaptureContext.__exit__ = self.empty - self.first_start = False - - api_register.api_set_hook_func() - self.switch = True - self.primitive_switch = True - logger.info(f"Dump switch is turned on at step {self.current_iter}. ") - self.create_dirs() - logger.info(f"Dump data will be saved in {self.dump_iter_dir}.") - JitDump.jit_dump_switch = True - - def stop(self): - if self.config.level == Const.LEVEL_DEBUG: - return - if self.should_stop_service: - return - logger.info(f"{Const.TOOL_NAME}: debugger.stop() is set successfully. " - "Please set debugger.start() to turn on the dump switch again. ") - if not self.start_call: - logger.error(f"{Const.TOOL_NAME}: debugger.start() is not set in the current scope.") - raise Exception("debugger.start() is not set in the current scope.") - if self.config.step and self.current_iter not in self.config.step: - return - if self.config.rank and self.current_rank not in self.config.rank: - return - self.switch = False - self.primitive_switch = False - self.start_call = False - if self.config.async_dump: - self.data_collector.fill_stack_tensor_data() - if self.config.task == Const.TENSOR: - self.data_collector.data_processor.dump_async_data() - self.data_collector.write_json() - JitDump.jit_dump_switch = False - - def need_end_service(self): - if self.config.step and self.current_iter > max(self.config.step): - return True - if self.data_collector and self.data_collector.data_processor.is_terminated: - return True - return False - - def should_execute_hook(self, hook_type, cell, is_forward): - is_cell_hook = hook_type == BaseScope.Module_Type_Module - if is_cell_hook and not self.switch: - return False - elif not is_cell_hook and is_forward and not self.switch: - return False - elif not is_cell_hook and not is_forward and not cell.forward_data_collected: - return False - - if self.inner_switch: - return False - if not self.data_collector or self.data_collector.data_processor.is_terminated: - return False - return True - - def create_dirs(self): - create_directory(self.config.dump_path) - self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}") - cur_rank = self.current_rank if self.current_rank is not None else '' - if self.config.level == Const.LEVEL_L2: - create_directory(self.dump_iter_dir) - kernel_config_path = create_kernel_config_json(self.dump_iter_dir, cur_rank) - self.config.kernel_config_path = kernel_config_path - return - - dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}") - create_directory(dump_dir) - if self.config.task in self.data_collector.tasks_need_tensor_data: - dump_data_dir = os.path.join(dump_dir, "dump_tensor_data") - create_directory(dump_data_dir) - else: - dump_data_dir = None - - dump_path_aggregation = DumpPathAggregation() - dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json") - dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json") - dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json") - dump_path_aggregation.dump_tensor_data_dir = dump_data_dir - self.data_collector.update_dump_paths(dump_path_aggregation) - - self.data_collector.initialize_json_file( - framework=Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK - ) - - def empty(self, *args, **kwargs): - pass - - def register_api_hook(self): - if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]: - logger.info(f"The api {self.config.task} hook function is successfully mounted to the model.") - api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API)) - api_register.api_set_hook_func() - - def get_cells_and_names(self): - cells_and_names_with_index = {} - - def get_cell_or_module(model): - return model.named_modules() if is_mindtorch() else model.cells_and_names() - - if isinstance(self.model, (list, tuple)): - for index, model in enumerate(self.model): - cells_and_names_with_index[str(index)] = get_cell_or_module(model) - else: - cells_and_names_with_index["-1"] = get_cell_or_module(self.model) - return cells_and_names_with_index - - def register_primitive_hook(self): - if self.config.level not in [Const.LEVEL_MIX, Const.LEVEL_L1]: - return - if not self.model or self.config.task not in Const.DUMP_DATA_COLLECTION_LIST: - return - - primitive_set = set() - cells_and_names_with_index = self.get_cells_and_names() - for cells_and_names in cells_and_names_with_index.values(): - for _, cell in cells_and_names: - for attribute, value in vars(cell).items(): - if isinstance(value, Primitive): - primitive_set.add((attribute, value)) - - for pname, primitive in primitive_set: - primitive_class_name = primitive.__class__.__name__ - primitive_combined_name = pname + Const.SEP + primitive_class_name - new_primitive = type('NewPrimitive', (primitive.__class__,), - {'__call__': self.primitive_hook_service.wrap_primitive(primitive.__call__, - primitive_combined_name)}) - primitive.__class__ = new_primitive - - def register_cell_hook(self): - if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L0]: - logger.info(f"The cell {self.config.task} hook function is successfully mounted to the model.") - if not self.model: - raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, - f"The current level is {self.config.level}, the model cannot be None") - model_type = Const.MODULE if is_mindtorch() else Const.CELL - cells_and_names_with_index = self.get_cells_and_names() - - for index, cells_and_names in cells_and_names_with_index.items(): - model = self.model if index == "-1" else self.model[int(index)] - for name, cell in cells_and_names: - if cell == model: - continue - cell_index = (index + Const.SEP) if index != "-1" else "" - prefix = (model_type + Const.SEP + cell_index + name + - Const.SEP + cell.__class__.__name__ + Const.SEP) - _, forward_hook, backward_hook, _ = self.build_hook(BaseScope.Module_Type_Module, prefix) - cell.register_forward_hook(forward_hook) - cell.register_forward_pre_hook( - self.cell_processor.node_hook(prefix + Const.FORWARD, Const.START)) - cell.register_forward_hook( - self.cell_processor.node_hook(prefix + Const.FORWARD, Const.STOP)) - - register_backward_hook_functions["full"](cell, backward_hook) - register_backward_hook_functions["pre"]( - cell, self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.START)) - register_backward_hook_functions["full"]( - cell, self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.STOP)) - - def reset_status(self): - self.primitive_hook_service.primitive_counters.clear() - self.data_collector.reset_status() - JitDump.jit_count = defaultdict(int) - self.params_grad_info.clear() - if self.config.level == Const.LEVEL_L2: - self.data_collector.data_processor.reset_status() - return - if self.config.step and self.current_iter not in self.config.step: - return - if self.config.rank and self.current_rank not in self.config.rank: - return - - def init_for_debug_level(self): - if not (self.config.level == Const.LEVEL_DEBUG and self.config.task in [Const.TENSOR, Const.STATISTICS]): - return - try: - self.current_rank = get_rank_if_initialized() - except DistributedNotInitializedError: - self.current_rank = None - # dir: dump_path -- rank{} -- debug.json - self.dump_iter_dir = self.config.dump_path - cur_rank = self.current_rank if self.current_rank is not None else '' - dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}") - create_directory(dump_dir) - if self.config.task in self.data_collector.tasks_need_tensor_data: - dump_data_dir = os.path.join(dump_dir, "dump_tensor_data") - create_directory(dump_data_dir) - else: - dump_data_dir = None - - dump_path_aggregation = DumpPathAggregation() - dump_path_aggregation.dump_tensor_data_dir = dump_data_dir - dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json") - self.data_collector.update_dump_paths(dump_path_aggregation) - self.data_collector.initialize_json_file( - framework=Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK - ) - self.debug_variable_counter = defaultdict(int) - - def save(self, variable, name, save_backward): - ''' - Args: - variable: Union[List[variable], dict{str: variable}, mindspore.tensor, str, float, int] - name: str - save_backward: boolean - Return: - void - ''' - if self.config.level != Const.LEVEL_DEBUG: - return - count = self.debug_variable_counter[name] - self.debug_variable_counter[name] += 1 - - name_with_count = f"{name}.{count}" - grad_name_with_count = f"{name}_grad.{count}" - - # forward save - self.data_collector.debug_data_collect_forward(variable, name_with_count) - - # backward save - if save_backward: - self.data_collector.debug_data_collect_backward(variable, grad_name_with_count) diff --git a/debug/accuracy_tools/msprobe/mindspore/task_handler_factory.py b/debug/accuracy_tools/msprobe/mindspore/task_handler_factory.py index a9cb5e6dd4037dcdeffe3c4d9584ad93c42022d6..cad37cebe8b5de39e3954da0eea2edd26b79223e 100644 --- a/debug/accuracy_tools/msprobe/mindspore/task_handler_factory.py +++ b/debug/accuracy_tools/msprobe/mindspore/task_handler_factory.py @@ -18,6 +18,7 @@ from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.dump.dump_tool_factory import DumpToolFactory from msprobe.mindspore.overflow_check.overflow_check_tool_factory import OverflowCheckToolFactory from msprobe.mindspore.free_benchmark.self_check_tool_factory import SelfCheckToolFactory +from msprobe.mindspore.exception_dump.exception_dump_tool_factory import ExceptionDumpToolFactory class TaskHandlerFactory: @@ -25,15 +26,19 @@ class TaskHandlerFactory: Const.TENSOR: DumpToolFactory, Const.STATISTICS: DumpToolFactory, Const.OVERFLOW_CHECK: OverflowCheckToolFactory, - Const.FREE_BENCHMARK: SelfCheckToolFactory + Const.FREE_BENCHMARK: SelfCheckToolFactory, + Const.EXCEPTION_DUMP: ExceptionDumpToolFactory } @staticmethod - def create(config: DebuggerConfig): + def create(config: DebuggerConfig, model=None): task = TaskHandlerFactory.tasks.get(config.task) if not task: raise Exception("Valid task is needed.") - handler = task.create(config) + if task == DumpToolFactory: + handler = task.create(config, model) + else: + handler = task.create(config) if not handler: raise Exception("Can not find task handler") return handler diff --git a/debug/accuracy_tools/msprobe/msprobe.py b/debug/accuracy_tools/msprobe/msprobe.py index 8e0386fde6dccc071c3d9d8e1a86729a2c483c7c..221fb1da81a8562c6852fadb6927496779917fba 100644 --- a/debug/accuracy_tools/msprobe/msprobe.py +++ b/debug/accuracy_tools/msprobe/msprobe.py @@ -22,6 +22,8 @@ from msprobe.core.common.log import logger from msprobe.core.compare.utils import _compare_parser from msprobe.core.compare.compare_cli import compare_cli from msprobe.core.compare.merge_result.merge_result_cli import _merge_result_parser, merge_result_cli +from msprobe.core.config_check.config_check_cli import _config_checking_parser, \ + _run_config_checking_command def is_module_available(module_name): @@ -51,6 +53,9 @@ def main(): graph_service_cmd_parser = subparsers.add_parser('graph') op_generate_cmd_parser = subparsers.add_parser('op_generate') merge_result_parser = subparsers.add_parser('merge_result') + config_checking_parser = subparsers.add_parser('config_check') + nan_analyze_parser = subparsers.add_parser('nan_analyze') + _config_checking_parser(config_checking_parser) _compare_parser(compare_cmd_parser) _merge_result_parser(merge_result_parser) @@ -71,6 +76,7 @@ def main(): from msprobe.visualization.graph_service import _pt_graph_service_parser, _pt_graph_service_command from msprobe.pytorch.api_accuracy_checker.generate_op_script.op_generator import _op_generator_parser, \ _run_operator_generate_commond + from msprobe.nan_analyze.analyzer import _nan_analyze_parser, _run_nan_analyze _run_ut_parser(run_ut_cmd_parser) _run_ut_parser(multi_run_ut_cmd_parser) @@ -80,6 +86,7 @@ def main(): _run_overflow_check_parser(run_overflow_check_cmd_parser) _pt_graph_service_parser(graph_service_cmd_parser) _op_generator_parser(op_generate_cmd_parser) + _nan_analyze_parser(nan_analyze_parser) elif framework_args.framework == Const.MS_FRAMEWORK: from msprobe.mindspore.api_accuracy_checker.cmd_parser import add_api_accuracy_checker_argument from msprobe.visualization.graph_service import _ms_graph_service_parser, _ms_graph_service_command @@ -91,6 +98,10 @@ def main(): _ms_graph_service_parser(graph_service_cmd_parser) + from msprobe.mindspore.api_accuracy_checker.generate_op_script.op_generator import _op_generator_parser, \ + _run_operator_generate_commond + _op_generator_parser(op_generate_cmd_parser) + args = parser.parse_args(sys.argv[1:]) if sys.argv[2] == Const.PT_FRAMEWORK: if not is_torch_available: @@ -118,6 +129,10 @@ def main(): compare_cli(args) elif sys.argv[3] == "merge_result": merge_result_cli(args) + elif sys.argv[3] == "config_check": + _run_config_checking_command(args) + elif sys.argv[3] == "nan_analyze": + _run_nan_analyze(args) else: if not is_module_available(Const.MS_FRAMEWORK): logger.error("MindSpore does not exist, please install MindSpore library") @@ -134,9 +149,13 @@ def main(): mul_api_checker_main(args) elif sys.argv[3] == "graph": _ms_graph_service_command(args) + elif sys.argv[3] == 'op_generate': + _run_operator_generate_commond(args) elif sys.argv[3] == "code_mapping": from msprobe.mindspore.code_mapping.main import code_mapping_main code_mapping_main(args) + elif sys.argv[3] == "config_check": + _run_config_checking_command(args) if __name__ == "__main__": diff --git a/debug/accuracy_tools/msprobe/nan_analyze/__init__.py b/debug/accuracy_tools/msprobe/nan_analyze/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b14094e3f9a77a0970342980ed8de1017f58ce19 --- /dev/null +++ b/debug/accuracy_tools/msprobe/nan_analyze/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/nan_analyze/analyzer.py b/debug/accuracy_tools/msprobe/nan_analyze/analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..e147f23b7c7bd514a13251830e0365928876bc75 --- /dev/null +++ b/debug/accuracy_tools/msprobe/nan_analyze/analyzer.py @@ -0,0 +1,255 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from collections import defaultdict +import os +from itertools import dropwhile, chain + +from msprobe.core.common.file_utils import check_file_or_directory_path, save_json, make_dir +from msprobe.core.common.log import logger +from msprobe.core.common.const import Const +from msprobe.nan_analyze.utils import (RankPath, FileCache, is_communication_op, is_ignore_op, NanAnalyseConst, + analyze_anomaly_in_group) +from msprobe.nan_analyze.graph import DataNode, CommunicationNode + + +class NanAnalyzer: + def __init__(self, input_path, output_path): + self._input_path = input_path + self._output_path = output_path + self._paths = {} + self._resolve_input_path() + self._anomaly_nodes = [] # 记录所有异常节点 + self._cache = FileCache() + self._first_comm_nodes = {} # 记录各rank下首个通信节点的node_id + self._after_comm_anomalies = {} # 记录各rank下发生在通信节点之后的异常计算节点 + self._rank_comm_nodes_dict = {} # 记录各rank的通信节点 + + def analyze(self): + for analyze_func in [self._pre_analyze, self._analyze, self._post_analyze]: + analyze_func() + if self._anomaly_nodes: + self._gen_analyze_info() + return + logger.info('Cannot find any anomaly node, no need to generate analyze file.') + + def _resolve_input_path(self): + contents = os.listdir(self._input_path) + for path in contents: + if not path.startswith('rank'): + continue + rank_str = path.strip('rank') + if not rank_str: + rank = 0 + elif not rank_str.isdigit(): + continue + else: + rank = int(rank_str) + dump_path = os.path.join(self._input_path, path, NanAnalyseConst.DUMP_FILE) + construct_path = os.path.join(self._input_path, path, NanAnalyseConst.CONSTRUCT_FILE) + stack_path = os.path.join(self._input_path, path, NanAnalyseConst.STACK_FILE) + self._paths[rank] = RankPath(rank, dump_path, construct_path, stack_path) + + def _pre_analyze(self): + logger.info('Start searching anomaly node before communication.') + for path in self._paths.values(): + dump_data = self._cache.load_json(path.dump_path).get('data') + if not dump_data: + logger.warning(f'Rank {path.rank} has no dump data!') + continue + for op_name, op_data in dump_data.items(): + if is_communication_op(op_name): + self._first_comm_nodes[path.rank] = op_name + break + data_node = DataNode(op_name, path.rank, op_data) + if data_node.is_anomaly(): + self._anomaly_nodes.append(data_node) + break + + def _analyze(self): + logger.info('Start searching anomaly node during communication.') + self._rank_comm_nodes_dict = {rank: self._analyze_comm_nodes(rank) for rank in self._paths} + self._connect_comm_nodes() + self._pruning() + self._search_first_anomaly() + + def _post_analyze(self): + logger.info('Start searching anomaly node after communication.') + for nodes in self._after_comm_anomalies.values(): + if nodes: + self._anomaly_nodes.append(nodes[0]) + + def _gen_analyze_info(self): + if not os.path.exists(self._output_path): + make_dir(self._output_path) + file_name = f'anomaly_analyze_{time.time_ns()}.json' + result_file = os.path.join(self._output_path, file_name) + result_content = defaultdict(list) + for node in self._anomaly_nodes: + result_content[f'rank_{node.rank}'].append(node.gen_node_info(self._paths[node.rank])) + save_json(result_file, result_content, 2) + logger.info(f"The analyze result is saved in: {result_file}") + + def _analyze_comm_nodes(self, rank): + path = self._paths[rank] + data = self._cache.load_json(path.dump_path).get('data') + communication_nodes = {} + if rank not in self._first_comm_nodes: # 此rank没有通信节点 + return communication_nodes + last_node_id = None # 记录上一个通信节点的node_id + compute_ops = [] # 记录两个通信节点之间的计算节点 + sub_layer = 0 # 记录两个通信算子之间异常计算节点的调用序数 + for op_name in dropwhile(lambda k: k != self._first_comm_nodes[rank], data): + node_id = f'{rank}.{op_name}' + op_data = data[op_name] + if is_communication_op(op_name): + comm_node = CommunicationNode(node_id, rank, DataNode(op_name, rank, op_data, sub_layer=sub_layer), + compute_ops=compute_ops) + if last_node_id: + communication_nodes.get(last_node_id).add_next(comm_node) + communication_nodes[node_id] = comm_node + last_node_id = node_id + compute_ops = [] + sub_layer = 0 + elif not is_ignore_op(op_name): + data_node = DataNode(op_name, rank, op_data, sub_layer=sub_layer) + if data_node.is_anomaly(): + compute_ops.append(data_node) + sub_layer += 1 + if compute_ops: + self._after_comm_anomalies[rank] = compute_ops + return communication_nodes + + def _connect_comm_nodes(self): + searched_ranks = set() + for rank, nodes in list(self._rank_comm_nodes_dict.items())[:-1]: + searched_ranks.add(rank) + seen_nodes = set() + for cur_node in nodes.values(): + conn_info = cur_node.find_connected_nodes() + if not conn_info.get('ranks'): + conn_info['ranks'] = self._rank_comm_nodes_dict.keys() + if not self._find_connection(conn_info, cur_node, searched_ranks, seen_nodes): + logger.info(f'Cannot find connected communication node for "{cur_node.node_id}".') + + def _find_connection(self, conn_info, cur_node, searched_ranks, seen_nodes): + def connect(): + seen_nodes.add(search_node.node_id) + if search_node.type == NanAnalyseConst.DST: + cur_node.add_dst(search_node) + elif search_node.type == NanAnalyseConst.SRC: + search_node.layer = cur_node.layer + search_node.add_dst(cur_node) + else: + cur_node.add_link(search_node) + + found = cur_node.connected + for connected_rank in conn_info['ranks']: + if connected_rank in searched_ranks: + continue + tar_id_prefix = f'{connected_rank}.{conn_info["api"]}' + for search_id, search_node in self._rank_comm_nodes_dict[connected_rank].items(): + if search_id in seen_nodes: + continue + if not (search_id.startswith(tar_id_prefix) and search_node.type == conn_info.get('type')): + continue + search_conn_ranks = search_node.find_connected_nodes().get('ranks') + if ((not search_conn_ranks and search_node.api not in NanAnalyseConst.DIRECTED_API) or + cur_node.rank in search_conn_ranks): # 有些无向通信算子没有填ProcessGroup,默认连接所有rank + connect() + found = True + break + return found + + def _pruning(self): + deleted_node_id = [] + for nodes in self._rank_comm_nodes_dict.values(): + for node_id in list(nodes.keys()): + node = nodes[node_id] + if node.has_nan_inf() or node.compute_ops: + continue + deleted_node_id.append(node_id) + node.delete() + del nodes[node_id] + logger.debug(f'After pruning, following nodes are removed: [{", ".join(deleted_node_id)}]') + + def _search_first_anomaly(self): + nodes_queues = [] + for comm_nodes in self._rank_comm_nodes_dict.values(): + nodes_queues.append(sorted(list(comm_nodes.values()), key=lambda x: x.layer)) + seen_nodes = set() + + def get_next_node(node_list): + while node_list: + next_node = node_list.pop(0) + if next_node.node_id not in seen_nodes: + return next_node + return None + + def find_all_members(ori_node): + ids = get_relative_ids(ori_node) + id_queue = list(chain(*[get_relative_ids(self._get_node_by_id(n_id)).difference(ids) for n_id in ids])) + while id_queue: + new_id = id_queue.pop(0) + ids.add(new_id) + id_queue.extend(get_relative_ids(self._get_node_by_id(new_id)).difference(ids)) + return ids + + def get_relative_ids(ori_node): + if not ori_node: + return set() + return ({ori_node.node_id} | ori_node.link_nodes.keys() | ori_node.src_nodes.keys() | + ori_node.dst_nodes.keys()) + + while any(nodes_queues): + groups = [] + all_ids_in_groups = set() + for nodes in nodes_queues: + node = get_next_node(nodes) + if not node: + continue + if not groups or node.node_id in all_ids_in_groups: + new_group = find_all_members(node) + groups.append(new_group) + all_ids_in_groups.update(new_group) + for group in groups: + seen_nodes.update(group) + self._anomaly_nodes.extend(analyze_anomaly_in_group([self._get_node_by_id(n_id) for n_id in group])) + if self._anomaly_nodes: + self._anomaly_nodes = [min(self._anomaly_nodes, key=lambda x: (x.layer, x.sub_layer))] + return + + def _get_node_by_id(self, node_id): + splits = node_id.split(Const.SEP, 1) + if len(splits) < 2 or not splits[0].isdigit(): + logger.error(f'invalid node_id {node_id}') + raise RuntimeError(f'invalid node_id {node_id}') + rank = int(splits[0]) + return self._rank_comm_nodes_dict.get(rank, {}).get(node_id) + + +def _nan_analyze_parser(parser): + parser.add_argument("-i", "--input_path", dest="input_path", default="", type=str, + help=" The dump file path, over step level. eg: \"xxx/step_0/\".", + required=True) + parser.add_argument("-o", "--output_path", dest="output_path", default="./output", type=str, + help=" The nan inf analyze result output file path.", + required=False) + + +def _run_nan_analyze(args): + check_file_or_directory_path(args.input_path, True) + NanAnalyzer(args.input_path, args.output_path).analyze() \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/nan_analyze/graph.py b/debug/accuracy_tools/msprobe/nan_analyze/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..13ec61444734b0ebff9c8174bc3aaf36d0f94c4e --- /dev/null +++ b/debug/accuracy_tools/msprobe/nan_analyze/graph.py @@ -0,0 +1,193 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from msprobe.core.common.const import Const +from msprobe.core.common.log import logger +from msprobe.nan_analyze.utils import FileCache, RankPath, is_ignore_op, check_item_anomaly, NanAnalyseConst +from msprobe.core.common.exceptions import MsprobeException + + +@dataclass +class DataNode: + op_name: str + rank: int + inputs: list + input_args: list + input_kwargs: dict + outputs: dict + layer: int = 0 # 和communication_node的layer保持一致 + sub_layer: int = 0 # 调用顺序,越小表示越先调用 + + def __init__(self, op_name, rank, op_data, **kwargs): + self.op_name = op_name + self.rank = rank + self.inputs = op_data.get(Const.INPUT, []) + self.input_args = op_data.get(Const.INPUT_ARGS, []) + self.input_kwargs = op_data.get(Const.INPUT_KWARGS, {}) + self.outputs = op_data.get(Const.OUTPUT, {}) + self.sub_layer = kwargs.get('sub_layer', 0) + + @staticmethod + def find_complete_construct(construct_info, op_name): + construct = [op_name] + seen = set(op_name) + while True: + op_name = construct_info.get(op_name) + if not op_name or op_name in seen: + return construct + construct.insert(0, op_name) + seen.add(op_name) + + def find_stack(self, stack_info): + for item in stack_info.values(): + if not isinstance(item, list): + raise MsprobeException(MsprobeException.UNSUPPORTED_TYPE_ERROR, + f'The value\'s type in stack.json should be a list, not {type(item)}!') + if len(item) >= 2 and self.op_name in item[0]: + return item[1] + return {} + + def is_anomaly(self) -> bool: + if is_ignore_op(self.op_name): + return False + is_input_anomaly = (check_item_anomaly(self.inputs) or check_item_anomaly(self.input_args) or + check_item_anomaly(self.input_kwargs)) + is_output_anomaly = check_item_anomaly(self.outputs) + return (not is_input_anomaly) and is_output_anomaly + + def gen_node_info(self, path: RankPath): + cache = FileCache() + construct = cache.load_json(path.construct_path) + stack = cache.load_json(path.stack_path) + if Const.FORWARD in self.op_name: + data_info_list = {Const.INPUT_ARGS: self.input_args, Const.INPUT_KWARGS: self.input_kwargs, + Const.OUTPUT: self.outputs} + else: + data_info_list = {Const.INPUT: self.inputs, Const.OUTPUT: self.outputs} + return {'op_name': self.op_name, + 'data_info': data_info_list, + 'construct_info': self.find_complete_construct(construct, self.op_name), + 'stack_info': self.find_stack(stack)} + + +class CommunicationNode: + def __init__(self, node_id, rank, data: DataNode, layer=0, **kwargs): + self.node_id = node_id + self.rank = rank + self.data = data + self.layer = layer + op_name_split = self.data.op_name.split(Const.SEP) + if len(op_name_split) < 4: + logger.error(f'invalid op_name: {self.data.op_name}') + raise RuntimeError(f'invalid op_name: {self.data.op_name}') + self.api = op_name_split[1] + self.call_cnt = op_name_split[2] + self.pre_node = kwargs.get('pre_node') + self.link_nodes = kwargs.get('link_nodes', {}) + self.dst_nodes = kwargs.get('dst_nodes', {}) + self.src_nodes = kwargs.get('src_nodes', {}) + self.next_nodes = kwargs.get('next_nodes', {}) + self.compute_ops = kwargs.get('compute_ops', []) + self.type = self._resolve_type() + self.connected = False + + def add_next(self, node): + self.next_nodes[node.node_id] = node + node.pre_node = self + node.layer = self.layer + 1 + node.data.layer = node.layer + + def add_link(self, node): + self.link_nodes[node.node_id] = node + node.link_nodes[self.node_id] = self + node.layer = self.layer + node.data.layer = node.layer + self.connected = True + node.connected = True + + def add_dst(self, node): + self.dst_nodes[node.node_id] = node + node.src_nodes[self.node_id] = self + node.layer = self.layer + node.data.layer = node.layer + self.connected = True + node.connected = True + + def delete(self): + for node in self.next_nodes.values(): + node.pre_node = None + for node in self.dst_nodes.values(): + node.src_nodes.pop(self.node_id) + for node in self.src_nodes.values(): + node.dst_nodes.pop(self.node_id) + for node in self.link_nodes.values(): + node.link_nodes.pop(self.node_id) + if self.pre_node: + self.pre_node.next_nodes.pop(self.node_id) + + def has_nan_inf(self): + return self.input_has_nan_inf() or check_item_anomaly(self.data.outputs) + + def input_has_nan_inf(self): + return check_item_anomaly(self.data.input_args) or check_item_anomaly(self.data.input_kwargs) + + def find_connected_nodes(self): + """ + 根据 api/类型/入参/调用次数 确定相连接的node的op_name + """ + tar_api = NanAnalyseConst.P2P_API_MAPPING.get(self.api, self.api) + ranks = set() + for dst in [NanAnalyseConst.DST, NanAnalyseConst.DST_GROUP]: + if dst in self.data.input_kwargs: + dst_value = self.data.input_kwargs.get(dst) + if dst_value: + ranks.add(dst_value.get('value')) + break + for src in [NanAnalyseConst.SRC, NanAnalyseConst.SRC_GROUP]: + if src in self.data.input_kwargs: + src_value = self.data.input_kwargs.get(src) + if src_value: + ranks.add(src_value.get('value')) + break + if not ranks: + for item in self.data.input_args: + if isinstance(item, dict) and item.get(Const.TYPE) == 'int': + ranks.add(item.get('value')) + group = self.data.input_kwargs.get('group') + if group: + ranks.update(group.get('group_ranks')) + return {'ranks': ranks, 'api': f'Distributed.{tar_api}', + 'type': NanAnalyseConst.OPPOSITE_DIR.get(self.type, NanAnalyseConst.LINK)} + + def _resolve_type(self): + for src in [NanAnalyseConst.SRC, NanAnalyseConst.SRC_GROUP]: + if src in self.data.input_kwargs and self.data.input_kwargs[src]: + if self.data.input_kwargs[src].get('value') == self.rank: + return NanAnalyseConst.SRC + else: + return NanAnalyseConst.DST + for dst in [NanAnalyseConst.DST, NanAnalyseConst.DST_GROUP]: + if dst in self.data.input_kwargs and self.data.input_kwargs[dst]: + if self.data.input_kwargs[dst].get('value') == self.rank: + return NanAnalyseConst.DST + else: + return NanAnalyseConst.SRC + if self.api in NanAnalyseConst.DIRECTED_API: + for item in self.data.input_args: + if item.get(Const.TYPE) == 'int': + node_type = NanAnalyseConst.DIRECTED_API[self.api] + return node_type if item.get('value') == self.rank else NanAnalyseConst.OPPOSITE_DIR[node_type] + return NanAnalyseConst.LINK diff --git a/debug/accuracy_tools/msprobe/nan_analyze/utils.py b/debug/accuracy_tools/msprobe/nan_analyze/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..aaed65106f32f6c6cfc11911596bc11accd6a0df --- /dev/null +++ b/debug/accuracy_tools/msprobe/nan_analyze/utils.py @@ -0,0 +1,211 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict +from dataclasses import dataclass +import sys +import time +import psutil + +from msprobe.core.common.const import CompareConst +from msprobe.core.common.file_utils import check_file_or_directory_path, load_json + + +@dataclass +class RankPath: + rank: int + dump_path: str + construct_path: str + stack_path: str + + def __init__(self, rank, dump_path, construct_path, stack_path): + self.rank = rank + check_file_or_directory_path(dump_path) + self.dump_path = dump_path + check_file_or_directory_path(construct_path) + self.construct_path = construct_path + check_file_or_directory_path(stack_path) + self.stack_path = stack_path + + +class FileCache: + """ + lazy load file + """ + _instance = None + + def __new__(cls, *args, **kwargs): + if not cls._instance: + cls._instance = super().__new__(cls, *args, **kwargs) + return cls._instance + + def __init__(self): + self._max_memory_usage = psutil.virtual_memory().available / 4 # 最大占用当前可用内存空间的1/4 + self._cache = OrderedDict() + self._access_cnt = {} + self._access_time = {} + self._size = {} + + @staticmethod + def _sizeof(obj): + seen = set() + objs = [obj] + size = 0 + while objs: + obj = objs.pop() + obj_id = id(obj) + if obj_id in seen: + continue + seen.add(obj_id) + size += sys.getsizeof(obj) + if isinstance(obj, dict): + objs.extend(obj.keys()) + objs.extend(obj.values()) + elif isinstance(obj, (list, tuple, set, frozenset)): + objs.extend(obj) + return size + + def load_json(self, json_path): + if json_path in self._cache: + self._access_cnt[json_path] += 1 + self._access_time[json_path] = time.monotonic() + self._cache.move_to_end(json_path) + return self._cache[json_path] + self._cleanup() + return self._load(json_path) + + def _load(self, json_path): + data = load_json(json_path) + self._add_to_cache(json_path, data) + return data + + def _add_to_cache(self, key, data): + if key in self._cache: + self._cache.move_to_end(key) + else: + self._cache[key] = data + self._access_cnt[key] = 0 + self._access_time[key] = time.monotonic() + self._size[key] = self._sizeof(data) + + def _calc_cache_size(self): + return sys.getsizeof(self._cache) + sum(self._size.values()) + + def _cleanup(self): + while self._calc_cache_size() > self._max_memory_usage and self._cache: + least_frequent_key = min(self._access_cnt.keys(), key=lambda k: self._access_cnt[k]) + least_recent_key = min(self._access_time.keys(), key=lambda k: self._access_time[k]) + largest_key = max(self._cache.keys(), key=lambda k: self._size[k]) + key_to_rm = min([least_frequent_key, least_recent_key, largest_key], + key=lambda k: (self._access_cnt[k], self._access_time[k], -self._size[k])) + del self._cache[key_to_rm] + del self._access_cnt[key_to_rm] + del self._access_time[key_to_rm] + del self._size[key_to_rm] + + +def is_communication_op(op_name): + # 定义通信算子的关键字,覆盖各种通信操作,如all_reduce, send, broadcast等 + # 从wrap文件中读取,先硬编码在文件中 + return (op_name.startswith('Distributed.') and + any(keyword in op_name for keyword in NanAnalyseConst.COMMUNICATION_KEYWORDS)) + + +def is_ignore_op(op_name): + ignore_keywords = [ + 'Torch.empty', + 'Torch.fill' + ] + return any(keyword in op_name for keyword in ignore_keywords) + + +def check_item_anomaly(param): + def has_nan_inf(dict_obj, key): + return str(dict_obj.get(key)).lower() in CompareConst.OVERFLOW_LIST + + items = [] + if isinstance(param, list): + items = param + elif isinstance(param, dict): + items = param.values() + for item in items: + if not isinstance(item, dict): + continue + if has_nan_inf(item, 'Max') or has_nan_inf(item, 'Min'): + return True + return False + + +def analyze_anomaly_in_group(nodes_group): + anomaly_nodes = [] + + def get_compute_ops_from_comm_nodes(comm_nodes): + for comm_node in comm_nodes: + for op_node in comm_node.compute_ops: + op_node.layer = comm_node.layer + anomaly_nodes.append(op_node) + + def get_comm_ops(comm_nodes): + for node in comm_nodes: + node.data.layer = node.layer + anomaly_nodes.append(node.data) + + # 先看src或link中input是否有异常 + src_list = list(filter(lambda node: node.type in [NanAnalyseConst.SRC, NanAnalyseConst.LINK], nodes_group)) + input_anomaly_nodes = list(filter(lambda node: node.input_has_nan_inf(), src_list)) + # 如果有异常回溯计算节点找到异常来源 + # 使用cpu模拟节点进行计算,查看结果是否有问题。需要对所有计算节点录入/映射,暂不实现。 + get_compute_ops_from_comm_nodes(input_anomaly_nodes) + # 筛选入参没问题但出参有问题的通信节点 + output_anomaly_nodes = list(filter(lambda node: node.data.is_anomaly(), nodes_group)) + get_comm_ops(output_anomaly_nodes) + return anomaly_nodes + + +class NanAnalyseConst: + COMMUNICATION_KEYWORDS = { + 'send', # send 算子 + 'recv', # recv 算子 + 'broadcast', # broadcast 算子 + 'all_reduce', # all_reduce 算子 + 'reduce', # reduce 算子 + 'all_gather', # all_gather 算子 + 'gather', # gather 算子 + 'isend', # isend 算子 + 'irecv', # irecv 算子 + 'scatter', # scatter 算子 + 'reduce_scatter', # reduce_scatter 算子 + '_reduce_scatter_base', # _reduce_scatter_base 算子 + '_all_gather_base', # _all_gather_base 算子 + 'all_to_all_single', # all_to_all_single 算子 + 'all_to_all', # all_to_all 算子 + 'all_gather_into_tensor', # all_gather_into_tensor 算子 + 'reduce_scatter_tensor', # reduce_scatter_tensor 算子 + 'send_object_list', # send_object_list 算子 + 'recv_object_list' # recv_object_list 算子 + } + P2P_API_MAPPING = {'send': 'recv', 'recv': 'send', 'isend': 'irecv', 'irecv': 'isend', + 'send_object_list': 'recv_object_list', 'recv_object_list': 'send_object_list'} + SRC = 'src' + DST = 'dst' + SRC_GROUP = 'src_group' + DST_GROUP = 'dst_group' + LINK = 'link' + DIRECTED_API = {'send': DST, 'recv': SRC, 'isend': DST, 'irecv': SRC, 'broadcast': SRC, 'scatter': SRC, + 'gather': DST, 'send_object_list': DST, 'recv_object_list': SRC} + OPPOSITE_DIR = {SRC: DST, DST: SRC} + DUMP_FILE = "dump.json" + CONSTRUCT_FILE = "construct.json" + STACK_FILE = "stack.json" diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py index f2b2d6a30463c62846bcc02e147c9c319f55d1b8..7c9beab9b281a567f80c0049c444c659bbd84d6c 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py @@ -24,8 +24,7 @@ from msprobe.pytorch.pt_config import RunUTConfig RunUtConfig = namedtuple('RunUtConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path', 'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list', - 'black_list', 'error_data_path', 'online_config']) -OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path']) + 'black_list', 'error_data_path']) class Config: @@ -46,13 +45,7 @@ class Config: 'white_list': list, 'black_list': list, 'error_data_path': str, - 'precision': int, - 'is_online': bool, - 'nfs_path': str, - 'host': str, - 'port': int, - 'rank_list': list, - 'tls_path': str + 'precision': int } if key not in validators: raise ValueError(f"{key} must be one of {validators.keys()}") @@ -68,10 +61,6 @@ class Config: RunUTConfig.check_filter_list_config(key, value) if key == 'error_data_path': RunUTConfig.check_error_data_path_config(value) - if key == 'nfs_path': - RunUTConfig.check_nfs_path_config(value) - if key == 'tls_path': - RunUTConfig.check_tls_path_config(value) return value @@ -85,12 +74,6 @@ class CheckerConfig: self.white_list = msCheckerConfig.white_list self.black_list = msCheckerConfig.black_list self.error_data_path = msCheckerConfig.error_data_path - self.is_online = msCheckerConfig.is_online - self.nfs_path = msCheckerConfig.nfs_path - self.host = msCheckerConfig.host - self.port = msCheckerConfig.port - self.rank_list = msCheckerConfig.rank_list - self.tls_path = msCheckerConfig.tls_path if task_config: self.load_config(task_config) @@ -99,22 +82,7 @@ class CheckerConfig: self.white_list = task_config.white_list self.black_list = task_config.black_list self.error_data_path = task_config.error_data_path - self.is_online = task_config.is_online - self.nfs_path = task_config.nfs_path - self.host = task_config.host - self.port = task_config.port - self.rank_list = task_config.rank_list - self.tls_path = task_config.tls_path - def get_online_config(self): - return OnlineConfig( - is_online=self.is_online, - nfs_path=self.nfs_path, - host=self.host, - port=self.port, - rank_list=self.rank_list, - tls_path=self.tls_path - ) def get_run_ut_config(self, **config_params): return RunUtConfig( @@ -125,8 +93,7 @@ class CheckerConfig: save_error_data=config_params.get('save_error_data'), is_continue_run_ut=config_params.get('is_continue_run_ut'), real_data_path=config_params.get('real_data_path'), - white_list=self.white_list, - black_list=self.black_list, - error_data_path=config_params.get('error_data_path'), - online_config=self.get_online_config() + white_list=self.white_list.copy() if self.white_list else [], + black_list=self.black_list.copy() if self.black_list else [], + error_data_path=config_params.get('error_data_path') ) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py index 8f7db73b58f42a4a64728bb0f12d25cf6f9f9ebe..24ac8b17ced04ea186898644925c77912dd294be 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py @@ -40,7 +40,7 @@ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_validat from msprobe.pytorch.api_accuracy_checker.common.utils import extract_detailed_api_segments, extract_basic_api_segments from msprobe.core.common.file_utils import FileChecker, change_mode, create_directory from msprobe.pytorch.common.log import logger -from msprobe.core.common.utils import CompareException +from msprobe.core.common.utils import CompareException, check_op_str_pattern_valid from msprobe.core.common.const import Const, CompareConst, FileCheckConst CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path']) @@ -117,30 +117,6 @@ def api_precision_compare(config): change_mode(config.details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY) -def online_api_precision_compare(online_config): - rank = online_config.rank - result_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.result_csv_path).replace( - "_rank*.csv", f"_rank{rank}.csv") - details_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.details_csv_path).replace( - "_rank*.csv", f"_rank{rank}.csv") - detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()] - result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()] - if not os.path.exists(result_csv_path): - write_csv(result_csv_title, result_csv_path) - if not os.path.exists(details_csv_path): - write_csv(detail_csv_title, details_csv_path) - config = CompareConfig("", "", result_csv_path, details_csv_path) - try: - npu_data, gpu_data = online_config.npu_data, online_config.gpu_data - check_csv_columns(npu_data.columns, "npu_csv") - check_csv_columns(gpu_data.columns, "gpu_csv") - analyse_csv(npu_data, gpu_data, config) - except Exception as err: - logger.error(f"Online api precision compare Error: {str(err)}") - change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY) - change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY) - - def analyse_csv(npu_data, gpu_data, config): forward_status, backward_status = [], [] last_api_name, last_api_dtype, last_api_full_name = None, None, None @@ -151,6 +127,7 @@ def analyse_csv(npu_data, gpu_data, config): message = '' compare_column = ApiPrecisionOutputColumn() full_api_name_with_direction_status = row_npu[ApiPrecisionCompareColumn.API_NAME] + check_op_str_pattern_valid(full_api_name_with_direction_status) row_gpu = gpu_data[gpu_data[ApiPrecisionCompareColumn.API_NAME] == full_api_name_with_direction_status] api_name, api_full_name, direction_status = extract_detailed_api_segments(full_api_name_with_direction_status) if not api_full_name: @@ -430,6 +407,7 @@ def _api_precision_compare(parser=None): _api_precision_compare_parser(parser) args = parser.parse_args(sys.argv[1:]) _api_precision_compare_command(args) + logger.info("Compare task completed.") def _api_precision_compare_command(args): @@ -457,8 +435,3 @@ def _api_precision_compare_parser(parser): parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str, help=" The api precision compare task result out path.", required=False) - - -if __name__ == '__main__': - _api_precision_compare() - logger.info("Compare task completed.") diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py index cf5928e509e3138ea762cd9d7af6fc26a5d2c5c9..3387faaf96ec9eabb17d26ae98860c0d3b468ba0 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py @@ -40,6 +40,7 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import check_dty DETAIL_TEST_ROWS, BENCHMARK_COMPARE_SUPPORT_LIST from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments from msprobe.pytorch.common.log import logger +from msprobe.core.common.decorator import recursion_depth_decorator ResultInfo = namedtuple('ResultInfo', ['full_api_name', 'fwd_success_status', 'bwd_success_status', @@ -65,13 +66,6 @@ class Comparator: self.save_path_list = [result_csv_path] self.detail_save_path_list = [details_csv_path] - if config and config.online_config.is_online: - self.save_path_str = result_csv_path.replace(".csv", "_rank{}.csv") - self.detail_save_path_str = details_csv_path.replace(".csv", "_rank{}.csv") - self.save_path_list = [self.save_path_str.format(rank) for rank in config.online_config.rank_list] - self.detail_save_path_list = \ - [self.detail_save_path_str.format(rank) for rank in config.online_config.rank_list] - self.registry = self._register_compare_func() if not is_continue_run_ut: @@ -178,6 +172,41 @@ class Comparator: if not os.path.exists(detail_save_path): write_csv(DETAIL_TEST_ROWS, detail_save_path) + @recursion_depth_decorator("compare_core") + def _compare_core(self, api_name, bench_output, device_output): + compare_column = CompareColumn() + if not isinstance(bench_output, type(device_output)): + status = CompareConst.ERROR + message = "bench and npu output type is different." + elif isinstance(bench_output, dict): + b_keys, n_keys = set(bench_output.keys()), set(device_output.keys()) + if b_keys != n_keys: + status = CompareConst.ERROR + message = "bench and npu output dict keys are different." + else: + status, compare_column, message = self._compare_core(api_name, list(bench_output.values()), + list(device_output.values())) + elif isinstance(bench_output, torch.Tensor): + copy_bench_out = bench_output.detach().clone() + copy_device_output = device_output.detach().clone() + compare_column.bench_type = str(copy_bench_out.dtype) + compare_column.npu_type = str(copy_device_output.dtype) + compare_column.shape = tuple(device_output.shape) + status, compare_column, message = self._compare_torch_tensor(api_name, copy_bench_out, copy_device_output, + compare_column) + elif isinstance(bench_output, (bool, int, float, str)): + compare_column.bench_type = str(type(bench_output)) + compare_column.npu_type = str(type(device_output)) + status, compare_column, message = self._compare_builtin_type(bench_output, device_output, compare_column) + elif bench_output is None: + status = CompareConst.SKIP + message = "Bench output is None, skip this test." + else: + status = CompareConst.ERROR + message = "Unexpected output type in compare_core: {}".format(type(bench_output)) + + return status, compare_column, message + def write_summary_csv(self, test_result): test_rows = [] try: @@ -209,9 +238,8 @@ class Comparator: self.write_detail_csv(args) - def compare_output(self, full_api_name, data_info, is_online=False): + def compare_output(self, full_api_name, data_info): """Get compare result and write to result and detail csv. - is_online: bool, default False. True: called by online api precision compare, only compare without write to csv. """ _, api_name = extract_basic_api_segments(full_api_name) if not api_name: @@ -244,9 +272,7 @@ class Comparator: fwd_compare_alg_results, bwd_compare_alg_results, data_info.rank) - if is_online: - # get run_ut compare detail - return self._get_run_ut_detail(result_info) + self.record_results(result_info) return fwd_success_status == CompareConst.PASS, bwd_success_status == CompareConst.PASS \ or bwd_success_status == CompareConst.SPACE @@ -293,40 +319,6 @@ class Comparator: test_final_success = CompareConst.WARNING return test_final_success, detailed_result_total - def _compare_core(self, api_name, bench_output, device_output): - compare_column = CompareColumn() - if not isinstance(bench_output, type(device_output)): - status = CompareConst.ERROR - message = "bench and npu output type is different." - elif isinstance(bench_output, dict): - b_keys, n_keys = set(bench_output.keys()), set(device_output.keys()) - if b_keys != n_keys: - status = CompareConst.ERROR - message = "bench and npu output dict keys are different." - else: - status, compare_column, message = self._compare_core(api_name, list(bench_output.values()), - list(device_output.values())) - elif isinstance(bench_output, torch.Tensor): - copy_bench_out = bench_output.detach().clone() - copy_device_output = device_output.detach().clone() - compare_column.bench_type = str(copy_bench_out.dtype) - compare_column.npu_type = str(copy_device_output.dtype) - compare_column.shape = tuple(device_output.shape) - status, compare_column, message = self._compare_torch_tensor(api_name, copy_bench_out, copy_device_output, - compare_column) - elif isinstance(bench_output, (bool, int, float, str)): - compare_column.bench_type = str(type(bench_output)) - compare_column.npu_type = str(type(device_output)) - status, compare_column, message = self._compare_builtin_type(bench_output, device_output, compare_column) - elif bench_output is None: - status = CompareConst.SKIP - message = "Bench output is None, skip this test." - else: - status = CompareConst.ERROR - message = "Unexpected output type in compare_core: {}".format(type(bench_output)) - - return status, compare_column, message - def _compare_torch_tensor(self, api_name, bench_output, device_output, compare_column): cpu_shape = bench_output.shape npu_shape = device_output.shape diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py index 549230d0a9e200283f545eed608a8da5df6a53a8..89c4401b2cac863bc609cce14a9f4c3ca03951b7 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py @@ -73,27 +73,27 @@ DETAIL_TEST_ROWS = [ precision_configs = { - torch.float16 : { - 'small_value' : [ + torch.float16: { + 'small_value': [ 1e-3 ], - 'small_value_atol' : [ + 'small_value_atol': [ 1e-5 ] }, torch.bfloat16: { - 'small_value' : [ + 'small_value': [ 1e-3 ], - 'small_value_atol' : [ + 'small_value_atol': [ 1e-5 ] }, - torch.float32:{ - 'small_value' : [ + torch.float32: { + 'small_value': [ 1e-6 ], - 'small_value_atol' : [ + 'small_value_atol': [ 1e-9 ] } @@ -101,33 +101,33 @@ precision_configs = { ULP_PARAMETERS = { - torch.float16 : { - 'min_eb' : [ + torch.float16: { + 'min_eb': [ -14 ], - 'exponent_num' : [ + 'exponent_num': [ 10 ] }, - torch.bfloat16 : { - 'min_eb' : [ + torch.bfloat16: { + 'min_eb': [ -126 ], - 'exponent_num' : [ + 'exponent_num': [ 7 ] }, - torch.float32 : { - 'min_eb' : [ + torch.float32: { + 'min_eb': [ -126 ], - 'exponent_num' : [ + 'exponent_num': [ 23 ] } } - - + + class ApiPrecisionCompareColumn: API_NAME = 'API Name' DEVICE_DTYPE = 'DEVICE Dtype' @@ -202,7 +202,7 @@ class ApiPrecisionCompareColumn: CompareMessage = { - "topk" : "在npu上,topk的入参sorted=False时不生效,会返回有序tensor,而cpu上会返回无序tensor。 如果topk精度不达标,请检查是否是该原因导致的。" + "topk": "在npu上,topk的入参sorted=False时不生效,会返回有序tensor,而cpu上会返回无序tensor。 如果topk精度不达标,请检查是否是该原因导致的。" } diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml index 2ec9251009e61ef68dbfed987abe457d47b91e9a..2797d0c64cccad149727c4c6a1b86c5cb4290350 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml @@ -2,9 +2,4 @@ white_list: [] black_list: [] error_data_path: './' precision: 14 -is_online: False -nfs_path: "" -host: "" -port: -1 -rank_list: [0] -tls_path: "./" + diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py index 797210f09c3b55a64002a4aa84a3d39770ae803c..c58c058674f31d8acb24a008104cdd32b1969726 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py @@ -28,10 +28,10 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import binary_st ulp_standard_api, thousandth_standard_api from msprobe.core.common.file_utils import FileOpen, load_json, save_json from msprobe.core.common.utils import check_file_or_directory_path, check_op_str_pattern_valid, is_int -from msprobe.core.common.const import Const, MonitorConst, MsgConst +from msprobe.core.common.const import Const, MonitorConst, MsgConst, FileCheckConst from msprobe.core.common.log import logger -from msprobe.core.common.file_utils import make_dir -from msprobe.core.common.utils import recursion_depth_decorator +from msprobe.core.common.file_utils import make_dir, change_mode +from msprobe.core.common.decorator import recursion_depth_decorator TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"] TORCH_BOOL_TYPE = ["torch.bool"] @@ -50,6 +50,7 @@ DATA_NAME = "data_name" API_MAX_LENGTH = 30 PROPAGATION_LIST = [Const.FORWARD, Const.BACKWARD] DATAMODE_LIST = ["random_data", "real_data"] +ITER_MAX_TIMES = 1000 class APIInfo: @@ -97,6 +98,8 @@ class CommonConfig: iter_t = self.iter_times if iter_t <= 0: raise ValueError("iter_times should be an integer bigger than zero!") + if iter_t > ITER_MAX_TIMES: + raise ValueError("iter_times should not be greater than 1000!") json_file = self.extract_api_path propagation = self.propagation @@ -117,7 +120,7 @@ class CommonConfig: # Retrieve the first API name and dictionary forward_item = next(iter(json_content.items()), None) - if not forward_item or not isinstance(forward_item[1], dict): + if not forward_item or not isinstance(forward_item[1], dict) or not forward_item[1]: raise ValueError(f'Invalid forward API data in json_content!') # if propagation is backward, ensure json file contains forward and backward info @@ -127,7 +130,7 @@ class CommonConfig: # if propagation is backward, ensure it has valid data if propagation == Const.BACKWARD: backward_item = list(json_content.items())[1] - if not isinstance(backward_item[1], dict): + if not isinstance(backward_item[1], dict) or not backward_item[1]: raise ValueError(f'Invalid backward API data in json_content!') return json_content @@ -169,7 +172,7 @@ class APIExtractor: value = self.load_real_data_path(value, real_data_path) new_data[key] = value if not new_data: - logger.error(f"Error: The api '{self.api_name}' does not exist in the file.") + logger.warning(f"Warning: The api '{self.api_name}' does not exist in the file.") else: save_json(self.output_file, new_data, indent=4) logger.info( @@ -183,6 +186,7 @@ class APIExtractor: self.update_data_name(v, dump_data_dir) return value + @recursion_depth_decorator("OpGenerator: APIExtractor.update_data_name") def update_data_name(self, data, dump_data_dir): if isinstance(data, list): for item in data: @@ -407,19 +411,16 @@ class OperatorScriptGenerator: return kwargs_dict_generator - def _op_generator_parser(parser): - parser.add_argument("-i", "--config_input", dest="config_input", default='', type=str, - help=" Path of config json file", required=True) + parser.add_argument("-i", "--config_input", dest="config_input", type=str, + help=" Path of config json file", required=True) parser.add_argument("-o", "--api_output_path", dest="api_output_path", type=str, - help=" Path of extract api_name.json.", - required=True) + help=" Path of extract api_name.json.", required=True) def parse_json_config(json_file_path): if not json_file_path: - config_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) - json_file_path = os.path.join(config_dir, "config.json") + raise Exception("config_input path can not be empty, please check.") json_config = load_json(json_file_path) common_config = CommonConfig(json_config) return common_config @@ -467,6 +468,7 @@ def _run_operator_generate_commond(cmd_args): fout.write(code_template.format(**internal_settings)) except OSError: logger.error(f"Failed to open file. Please check file {template_path} or {operator_script_path}.") + change_mode(operator_script_path, FileCheckConst.DATA_FILE_AUTHORITY) logger.info(f"Generate operator script successfully and the name is {operator_script_path}.") diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template index 131fd211ad82dad8256c48e59195fc335efa936b..c60d84994745e94bef6d05a78d83fae81df7ed1e 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template @@ -1,6 +1,6 @@ -import json import os -import math +import re +import stat from enum import Enum, auto import torch try: @@ -25,6 +25,31 @@ RAISE_PRECISION = {{ }} THOUSANDTH_THRESHOLDING = 0.001 BACKWARD = 'backward' +DIR = "dir" +FILE = "file" +READ_ABLE = "read" +WRITE_ABLE = "write" +READ_WRITE_ABLE = "read and write" +DIRECTORY_LENGTH = 4096 +FILE_NAME_LENGTH = 255 +SOFT_LINK_ERROR = "检测到软链接" +FILE_PERMISSION_ERROR = "文件权限错误" +INVALID_FILE_ERROR = "无效文件" +ILLEGAL_PATH_ERROR = "非法文件路径" +ILLEGAL_PARAM_ERROR = "非法打开方式" +FILE_TOO_LARGE_ERROR = "文件过大" +FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$" +FILE_SIZE_DICT = {{ + ".pkl": 1073741824, # 1 * 1024 * 1024 * 1024 + ".npy": 10737418240, # 10 * 1024 * 1024 * 1024 + ".json": 1073741824, # 1 * 1024 * 1024 * 1024 + ".pt": 10737418240, # 10 * 1024 * 1024 * 1024 + ".csv": 1073741824, # 1 * 1024 * 1024 * 1024 + ".xlsx": 1073741824, # 1 * 1024 * 1024 * 1024 + ".yaml": 1073741824, # 1 * 1024 * 1024 * 1024 + ".ir": 1073741824 # 1 * 1024 * 1024 * 1024 +}} +COMMOM_FILE_SIZE = 1048576 # 1 * 1024 * 1024 class CompareStandard(Enum): BINARY_EQUALITY_STANDARD = auto() @@ -33,13 +58,189 @@ class CompareStandard(Enum): BENCHMARK_STANDARD = auto() THOUSANDTH_STANDARD = auto() +class FileChecker: + """ + The class for check file. + + Attributes: + file_path: The file or dictionary path to be verified. + path_type: file or dictionary + ability(str): FileCheckConst.WRITE_ABLE or FileCheckConst.READ_ABLE to set file has writability or readability + file_type(str): The correct file type for file + """ + + def __init__(self, file_path, path_type, ability=None, file_type=None, is_script=True): + self.file_path = file_path + self.path_type = self._check_path_type(path_type) + self.ability = ability + self.file_type = file_type + self.is_script = is_script + + @staticmethod + def _check_path_type(path_type): + if path_type not in [DIR, FILE]: + print(f'ERROR: The path_type must be {{DIR}} or {{FILE}}.') + raise Exception(ILLEGAL_PARAM_ERROR) + return path_type + + def common_check(self): + """ + 功能:用户校验基本文件权限:软连接、文件长度、是否存在、读写权限、文件属组、文件特殊字符 + 注意:文件后缀的合法性,非通用操作,可使用其他独立接口实现 + """ + FileChecker.check_path_exists(self.file_path) + FileChecker.check_link(self.file_path) + self.file_path = os.path.realpath(self.file_path) + FileChecker.check_path_length(self.file_path) + FileChecker.check_path_type(self.file_path, self.path_type) + self.check_path_ability() + if self.is_script: + FileChecker.check_path_owner_consistent(self.file_path) + FileChecker.check_path_pattern_valid(self.file_path) + FileChecker.check_common_file_size(self.file_path) + FileChecker.check_file_suffix(self.file_path, self.file_type) + if self.path_type == FILE: + FileChecker.check_dirpath_before_read(self.file_path) + return self.file_path + + def check_path_ability(self): + if self.ability == WRITE_ABLE: + FileChecker.check_path_writability(self.file_path) + if self.ability == READ_ABLE: + FileChecker.check_path_readability(self.file_path) + if self.ability == READ_WRITE_ABLE: + FileChecker.check_path_readability(self.file_path) + FileChecker.check_path_writability(self.file_path) + + @staticmethod + def check_path_exists(path): + if not os.path.exists(path): + print(f'ERROR: The file path %s does not exist.' % path) + raise Exception() + + @staticmethod + def check_link(path): + abs_path = os.path.abspath(path) + if os.path.islink(abs_path): + print('ERROR: The file path {{}} is a soft link.'.format(path)) + raise Exception(SOFT_LINK_ERROR) + + @staticmethod + def check_path_length(path, name_length=None): + file_max_name_length = name_length if name_length else FILE_NAME_LENGTH + if len(path) > DIRECTORY_LENGTH or \ + len(os.path.basename(path)) > file_max_name_length: + print(f'ERROR: The file path length exceeds limit.') + raise Exception(ILLEGAL_PATH_ERROR) + + @staticmethod + def check_path_type(file_path, file_type): + if file_type == FILE: + if not os.path.isfile(file_path): + print(f"ERROR: The {{file_path}} should be a file!") + raise Exception(INVALID_FILE_ERROR) + if file_type == DIR: + if not os.path.isdir(file_path): + print(f"ERROR: The {{file_path}} should be a dictionary!") + raise Exception(INVALID_FILE_ERROR) + + @staticmethod + def check_path_owner_consistent(path): + file_owner = os.stat(path).st_uid + if file_owner != os.getuid() and os.getuid() != 0: + print('ERROR: The file path %s may be insecure because is does not belong to you.' % path) + raise Exception(FILE_PERMISSION_ERROR) + + @staticmethod + def check_path_pattern_valid(path): + if not re.match(FILE_VALID_PATTERN, path): + print('ERROR: The file path %s contains special characters.' % (path)) + raise Exception(ILLEGAL_PATH_ERROR) + + @staticmethod + def check_common_file_size(file_path): + if os.path.isfile(file_path): + for suffix, max_size in FILE_SIZE_DICT.items(): + if file_path.endswith(suffix): + FileChecker.check_file_size(file_path, max_size) + return + FileChecker.check_file_size(file_path, COMMOM_FILE_SIZE) + + @staticmethod + def check_file_size(file_path, max_size): + try: + file_size = os.path.getsize(file_path) + except OSError as os_error: + print(f'ERROR: Failed to open "{{file_path}}". {{str(os_error)}}') + raise Exception(INVALID_FILE_ERROR) from os_error + if file_size >= max_size: + print(f'ERROR: The size ({{file_size}}) of {{file_path}} exceeds ({{max_size}}) bytes, tools not support.') + raise Exception(FILE_TOO_LARGE_ERROR) + + @staticmethod + def check_file_suffix(file_path, file_suffix): + if file_suffix: + if not file_path.endswith(file_suffix): + print(f"The {{file_path}} should be a {{file_suffix}} file!") + raise Exception(INVALID_FILE_ERROR) + + @staticmethod + def check_dirpath_before_read(path): + path = os.path.realpath(path) + dirpath = os.path.dirname(path) + if FileChecker.check_others_writable(dirpath): + print(f"WARNING: The directory is writable by others: {{dirpath}}.") + try: + FileChecker.check_path_owner_consistent(dirpath) + except Exception: + print(f"WARNING: The directory {{dirpath}} is not yours.") + + @staticmethod + def check_others_writable(directory): + dir_stat = os.stat(directory) + is_writable = ( + bool(dir_stat.st_mode & stat.S_IWGRP) or # 组可写 + bool(dir_stat.st_mode & stat.S_IWOTH) # 其他用户可写 + ) + return is_writable + + @staticmethod + def check_path_readability(path): + if not os.access(path, os.R_OK): + print('ERROR: The file path %s is not readable.' % path) + raise Exception(FILE_PERMISSION_ERROR) + + @staticmethod + def check_path_writability(path): + if not os.access(path, os.W_OK): + print('ERROR: The file path %s is not writable.' % path) + raise Exception(FILE_PERMISSION_ERROR) + + +def check_file_or_directory_path(path, isdir=False): + """ + Function Description: + check whether the path is valid + Parameter: + path: the path to check + isdir: the path is dir or file + Exception Description: + when invalid data throw exception + """ + if isdir: + path_checker = FileChecker(path, DIR, WRITE_ABLE) + else: + path_checker = FileChecker(path, FILE, READ_ABLE) + path_checker.common_check() + def load_pt(pt_path, to_cpu=False): pt_path = os.path.realpath(pt_path) + check_file_or_directory_path(pt_path) try: if to_cpu: - pt = torch.load(pt_path, map_location=torch.device("cpu")) + pt = torch.load(pt_path, map_location=torch.device("cpu"), weights_only=True) else: - pt = torch.load(pt_path) + pt = torch.load(pt_path, weights_only=True) except Exception as e: raise RuntimeError(f"load pt file {{pt_path}} failed") from e return pt @@ -202,6 +403,7 @@ def compare_tensor(out_device, out_bench, api_name): else: abs_err = torch.abs(out_device - out_bench) abs_bench = torch.abs(out_bench) + eps = 2 ** -23 if dtype_bench == torch.float32: eps = 2 ** -23 if dtype_bench == torch.float64: diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py index 498102b475f564564d6039a81e305fba3bceec17..8362b551a1ad5c7e18b593c73c25aef1a1dc9def 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py @@ -33,7 +33,7 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator from msprobe.pytorch.common import parse_json_info_forward_backward from msprobe.pytorch.common.log import logger from msprobe.core.common.file_utils import FileChecker, check_file_suffix, check_link, FileOpen, \ - create_directory, load_json, save_json + create_directory, load_json, save_json, read_csv from msprobe.core.common.file_utils import remove_path from msprobe.core.common.const import FileCheckConst, Const from msprobe.core.common.utils import CompareException @@ -50,6 +50,9 @@ def split_json_file(input_file, num_splits, filter_api): backward_data[f"{data_name}.backward"] = backward_data.pop(data_name) input_data = load_json(input_file) + if "dump_data_dir" not in input_data.keys(): + logger.error("Invalid input file, 'dump_data_dir' field is missing") + raise CompareException("Invalid input file, 'dump_data_dir' field is missing") if input_data.get("data") is None: logger.error("Invalid input file, 'data' field is missing") raise CompareException("Invalid input file, 'data' field is missing") @@ -67,15 +70,24 @@ def split_json_file(input_file, num_splits, filter_api): split_forward_data = dict(items[start:end]) temp_data = { **input_data, - "data":{ + "data": { **split_forward_data, **backward_data } } split_filename = os.path.join(input_dir, f"temp_part{i}.json") - save_json(split_filename, temp_data) split_files.append(split_filename) - + try: + save_json(split_filename, temp_data) + except Exception as e: + logger.error(f"An error occurred while saving split file: {e}") + for file in split_files: + try: + remove_path(file) + except Exception: + logger.error(f"File not found or could not be deleted: {file}") + msg = 'ERROR: Split json file failed, please check the input file and try again.' + raise CompareException(CompareException.PARSE_FILE_ERROR, msg) from e return split_files, total_items @@ -84,10 +96,6 @@ def signal_handler(signum, frame): raise KeyboardInterrupt() -signal.signal(signal.SIGINT, signal_handler) -signal.signal(signal.SIGTERM, signal_handler) - - ParallelUTConfig = namedtuple('ParallelUTConfig', ['api_files', 'out_path', 'num_splits', 'save_error_data_flag', 'jit_compile_flag', 'device_id', 'result_csv_path', 'total_items', 'config_path']) @@ -97,7 +105,7 @@ def run_parallel_ut(config): processes = [] device_id_cycle = cycle(config.device_id) if config.save_error_data_flag: - logger.info("UT task error datas will be saved") + logger.info("UT task error data will be saved") logger.info(f"Starting parallel UT with {config.num_splits} processes") progress_bar = tqdm(total=config.total_items, desc="Total items", unit="items") @@ -129,17 +137,20 @@ def run_parallel_ut(config): sys.stdout.flush() except ValueError as e: logger.warning(f"An error occurred while reading subprocess output: {e}") + finally: + if process.poll() is None: + process.stdout.close() def update_progress_bar(progress_bar, result_csv_path): while any(process.poll() is None for process in processes): - with FileOpen(result_csv_path, 'r') as result_file: - completed_items = len(result_file.readlines()) - 1 - progress_bar.update(completed_items - progress_bar.n) + result_file = read_csv(result_csv_path) + completed_items = len(result_file) + progress_bar.update(completed_items - progress_bar.n) time.sleep(1) for api_info in config.api_files: cmd = create_cmd(api_info, next(device_id_cycle)) - process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, + process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True, bufsize=1, shell=False) processes.append(process) threading.Thread(target=read_process_output, args=(process,), daemon=True).start() @@ -185,8 +196,8 @@ def run_parallel_ut(config): def prepare_config(args): - api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE, - ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX) + api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE, + ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX) api_info = api_info_file_checker.common_check() out_path = args.out_path if args.out_path else Const.DEFAULT_PATH create_directory(out_path) @@ -195,11 +206,11 @@ def prepare_config(args): split_files, total_items = split_json_file(api_info, args.num_splits, args.filter_api) config_path = args.config_path if args.config_path else None if config_path: - config_path_checker = FileChecker(config_path, FileCheckConst.FILE, + config_path_checker = FileChecker(config_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX) config_path = config_path_checker.common_check() result_csv_path = args.result_csv_path or os.path.join( - out_path, f"accuracy_checking_result_{time.strftime('%Y%m%d%H%M%S')}.csv") + out_path, f"accuracy_checking_result_{time.strftime('%Y%m%d%H%M%S')}.csv") if not args.result_csv_path: details_csv_path = os.path.join(out_path, f"accuracy_checking_details_{time.strftime('%Y%m%d%H%M%S')}.csv") comparator = Comparator(result_csv_path, details_csv_path, False) @@ -214,14 +225,12 @@ def prepare_config(args): def main(): + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) parser = argparse.ArgumentParser(description='Run UT in parallel') _run_ut_parser(parser) - parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8, + parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8, help='Number of splits for parallel processing. Range: 1-64') args = parser.parse_args() config = prepare_config(args) run_parallel_ut(config) - - -if __name__ == '__main__': - main() diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py index 6214d892906bef44d94474c6415674f39099357b..0f184d14b66d84607a6767ba9ef5210ff4fc5b69 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py @@ -34,8 +34,10 @@ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api, i from msprobe.core.common.file_utils import check_link, FileChecker from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments from msprobe.core.common.const import FileCheckConst, Const +from msprobe.core.common.utils import check_op_str_pattern_valid from msprobe.pytorch.common.log import logger from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward +from msprobe.core.common.decorator import recursion_depth_decorator def check_tensor_overflow(x): @@ -63,6 +65,7 @@ def check_tensor_overflow(x): return False +@recursion_depth_decorator("check_data_overflow") def check_data_overflow(x, device): if isinstance(x, (tuple, list)): if not x: @@ -75,6 +78,7 @@ def check_data_overflow(x, device): return torch_npu.npu.utils.npu_check_overflow(x) +@recursion_depth_decorator("is_bool_output") def is_bool_output(x): if isinstance(x, (tuple, list)): if not x: @@ -91,6 +95,7 @@ def run_overflow_check(forward_file): dump_path = os.path.dirname(forward_file) real_data_path = os.path.join(dump_path, Const.DUMP_TENSOR_DATA) for api_full_name, api_info_dict in tqdm(forward_content.items()): + check_op_str_pattern_valid(api_full_name) if is_unsupported_api(api_full_name, is_overflow_check=True): continue try: @@ -161,6 +166,7 @@ def _run_overflow_check(parser=None): _run_overflow_check_parser(parser) args = parser.parse_args(sys.argv[1:]) _run_overflow_check_command(args) + logger.info("UT task completed.") def _run_overflow_check_command(args): @@ -175,8 +181,3 @@ def _run_overflow_check_command(args): logger.error(f"Set NPU device id failed. device id is: {args.device_id}") raise NotImplementedError from error run_overflow_check(api_info) - - -if __name__ == '__main__': - _run_overflow_check() - logger.info("UT task completed.") diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py index 905687c1bfc932883396481410c333a7566fd342..4bf8ead7e6a7f686f2cb2f457884f054ab6e5237 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py @@ -45,14 +45,12 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareC from msprobe.pytorch.api_accuracy_checker.common.config import CheckerConfig from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward from msprobe.core.common.file_utils import FileChecker, change_mode, \ - create_directory, get_json_contents, read_csv, check_file_or_directory_path, check_crt_valid + create_directory, get_json_contents, read_csv, check_file_or_directory_path from msprobe.pytorch.common.log import logger from msprobe.pytorch.pt_config import parse_json_config from msprobe.core.common.const import Const, FileCheckConst, CompareConst -from msprobe.core.common.utils import safe_get_value, CompareException +from msprobe.core.common.utils import safe_get_value, CompareException, is_int, check_op_str_pattern_valid from msprobe.pytorch.common.utils import seed_all -from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec -from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params, generate_device_params, \ ExecParams @@ -65,6 +63,8 @@ DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv" not_backward_list = ['repeat_interleave'] unsupported_backward_list = ['masked_select'] +unsupported_api_list = ["to", "empty", "empty_like", "empty_strided", "new_empty", "new_empty_strided", + "empty_with_format"] tqdm_params = { @@ -83,29 +83,27 @@ tqdm_params = { } +seed_all() + + def run_ut(config): logger.info("start UT test") - if config.online_config.is_online: - logger.info(f"UT task result will be saved in {config.result_csv_path}".replace(".csv", "_rank*.csv")) - logger.info(f"UT task details will be saved in {config.details_csv_path}".replace(".csv", "_rank*.csv")) - else: - logger.info(f"UT task result will be saved in {config.result_csv_path}") - logger.info(f"UT task details will be saved in {config.details_csv_path}") + + logger.info(f"UT task result will be saved in {config.result_csv_path}") + logger.info(f"UT task details will be saved in {config.details_csv_path}") if config.save_error_data: - logger.info(f"UT task error_datas will be saved in {config.error_data_path}") + logger.info(f"UT task error_data will be saved in {config.error_data_path}") compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut, config=config) - if config.online_config.is_online: - run_api_online(config, compare) - else: - csv_df = read_csv(config.result_csv_path) - try: - api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)} - except IndexError: - logger.error(f"Read {config.result_csv_path} error, api_name_set is empty.") - api_name_set = set() - run_api_offline(config, compare, api_name_set) + + csv_df = read_csv(config.result_csv_path) + try: + api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)} + except IndexError: + logger.error(f"Read {config.result_csv_path} error, api_name_set is empty.") + api_name_set = set() + run_api_offline(config, compare, api_name_set) for result_csv_path, details_csv_path in zip(compare.save_path_list, compare.detail_save_path_list): change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY) change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY) @@ -117,6 +115,7 @@ def run_ut(config): def run_api_offline(config, compare, api_name_set): err_column = CompareColumn() for _, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items(), **tqdm_params)): + check_op_str_pattern_valid(api_full_name) if api_full_name in api_name_set: continue if is_unsupported_api(api_full_name): @@ -158,66 +157,13 @@ def run_api_offline(config, compare, api_name_set): gc.collect() -def run_api_online(config, compare): - attl = init_attl(config.online_config) - dispatcher = ConsumerDispatcher(compare=compare) - dispatcher.start(handle_func=run_torch_api_online, config=config) - - def tcp_communication_flow(): - while True: - api_data = attl.recv() - if api_data == 'STOP_': - continue - if api_data == 'KILL_': - time.sleep(1) - logger.info("==========接收到STOP信号==========") - dispatcher.stop() - attl.stop_serve() - time.sleep(1) - break - if not isinstance(api_data, ApiData): - continue - api_full_name = api_data.name - _, api_name = extract_basic_api_segments(api_full_name) - if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list): - continue - if api_data.rank in config.online_config.rank_list: - dispatcher.update_consume_queue(api_data) - - def shared_storage_communication_flow(): - flag_num = -1 - while True: - api_data = attl.download() - if api_data == "start": - if flag_num == -1: - flag_num += 1 - flag_num += 1 - if api_data == "end": - flag_num -= 1 - if flag_num == 0: - dispatcher.stop() - break - if not isinstance(api_data, ApiData): - continue - api_full_name = api_data.name - _, api_name = extract_basic_api_segments(api_full_name) - if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list): - continue - if api_data.rank in config.online_config.rank_list: - dispatcher.update_consume_queue(api_data) - - if config.online_config.nfs_path: - shared_storage_communication_flow() - else: - tcp_communication_flow() - - def blacklist_and_whitelist_filter(api_name, black_list, white_list): """ run api(api_name) if api_name not in black_list and in white_list. If api is both in black_list and black_list, black_list first. return: False for exec api, True for not exec """ + black_list.extend(unsupported_api_list) if black_list and api_name in black_list: return True if white_list and api_name not in white_list: @@ -286,7 +232,7 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict if grad_input_index is not None: grad_index = grad_input_index.get('grad_index') - if need_backward: + if need_backward and out is not None: if need_to_backward(grad_index, out): backward_args = backward_content[api_full_name].get("input") func_options = { @@ -308,20 +254,6 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message) -def run_torch_api_online(api_full_name, api_data, backward_content): - in_fwd_data_list = [] - api_type, api_name = extract_basic_api_segments(api_full_name) - args, kwargs, out = api_data.args, api_data.kwargs, api_data.result - in_fwd_data_list.append(args) - in_fwd_data_list.append(kwargs) - if kwargs.get("device"): - del kwargs["device"] - - device_out = exec_api(api_type, api_name, Const.CUDA_LOWERCASE, args, kwargs) - device_out = move2device_exec(device_out, "cpu") - return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank) - - def check_need_grad(api_info_dict): need_grad = True if api_info_dict.get(Const.INPUT_KWARGS) and "out" in api_info_dict.get(Const.INPUT_KWARGS): @@ -344,6 +276,9 @@ def need_to_backward(grad_index, out): def run_backward(args, grad, grad_index, out): if grad_index is not None: + if not is_int(grad_index): + logger.error(f"{grad_index} dtype is not int") + raise TypeError(f"{grad_index} dtype is not int") if grad_index >= len(out): logger.error(f"Run backward error when grad_index is {grad_index}") raise IndexError(f"Run backward error when grad_index is {grad_index}") @@ -378,16 +313,6 @@ def initialize_save_error_data(error_data_path): return error_data_path -def init_attl(config): - """config: OnlineConfig""" - attl = ATTL('gpu', ATTLConfig(is_benchmark_device=True, - connect_ip=config.host, - connect_port=config.port, - nfs_path=config.nfs_path, - tls_path=config.tls_path)) - return attl - - def _run_ut_parser(parser): parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", default="", type=str, help=" The api param tool result file: generate from api param tool, " @@ -430,6 +355,7 @@ def preprocess_forward_content(forward_content): arg_cache = {} for key, value in forward_content.items(): + check_op_str_pattern_valid(key) base_key = key.rsplit(Const.SEP, 1)[0] if key not in arg_cache: @@ -471,35 +397,6 @@ def _run_ut(parser=None): run_ut_command(args) -def checked_online_config(online_config): - if not online_config.is_online: - return - if not isinstance(online_config.is_online, bool): - raise ValueError("is_online must be bool type") - # rank_list - if not isinstance(online_config.rank_list, list): - raise ValueError("rank_list must be a list") - if online_config.rank_list and not all(isinstance(rank, int) for rank in online_config.rank_list): - raise ValueError("All elements in rank_list must be integers") - - # nfs_path - if online_config.nfs_path: - check_file_or_directory_path(online_config.nfs_path, isdir=True) - return - # tls_path - if online_config.tls_path: - check_file_or_directory_path(online_config.tls_path, isdir=True) - check_file_or_directory_path(os.path.join(online_config.tls_path, "server.key")) - check_file_or_directory_path(os.path.join(online_config.tls_path, "server.crt")) - check_crt_valid(os.path.join(online_config.tls_path, "server.crt")) - - # host and port - if not isinstance(online_config.host, str) or not re.match(Const.ipv4_pattern, online_config.host): - raise Exception(f"host: {online_config.host} is invalid.") - if not isinstance(online_config.port, int) or not (0 < online_config.port <= 65535): - raise Exception(f"port: {online_config.port} is invalid, port range 0-65535.") - - def run_ut_command(args): if args.config_path: config_path_checker = FileChecker(args.config_path, FileCheckConst.FILE, @@ -510,7 +407,7 @@ def run_ut_command(args): else: checker_config = CheckerConfig() - if not checker_config.is_online and not args.api_info_file: + if not args.api_info_file: logger.error("Please provide api_info_file for offline run ut.") raise Exception("Please provide api_info_file for offline run ut.") @@ -561,12 +458,18 @@ def run_ut_command(args): error_data_path = checker_config.error_data_path if save_error_data: if args.result_csv_path: - time_info = result_csv_path.split('.')[0].split('_')[-1] + parts_by_dot = result_csv_path.split(Const.SEP) + if len(parts_by_dot) < 2 or not parts_by_dot[0]: + raise ValueError("result_csv_path does not contain a valid file name with an extension.") + file_name_part = parts_by_dot[0] + parts_by_underscore = file_name_part.split(Const.REPLACEMENT_CHARACTER) + if len(parts_by_underscore) < 2: + raise ValueError("File name part does not contain enough '_' separated segments.") + time_info = parts_by_underscore[-1] + global UT_ERROR_DATA_DIR UT_ERROR_DATA_DIR = 'ut_error_data' + time_info error_data_path = initialize_save_error_data(error_data_path) - online_config = checker_config.get_online_config() - checked_online_config(online_config) config_params = { 'forward_content': forward_content, 'backward_content': backward_content, @@ -579,9 +482,8 @@ def run_ut_command(args): } run_ut_config = checker_config.get_run_ut_config(**config_params) run_ut(run_ut_config) + logger.info("UT task completed.") if __name__ == '__main__': - seed_all() _run_ut() - logger.info("UT task completed.") diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py index dc0174212e3f8f8cf70fa1701aadc664138dbcdf..60557c77d79c685dbcf8a312910816dcb06b2702 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py @@ -1,9 +1,7 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -18,8 +16,8 @@ import os from collections import namedtuple import re -import torch +import torch try: import torch_npu except ImportError: @@ -33,11 +31,9 @@ from msprobe.core.common.const import FileCheckConst, Const, CompareConst from msprobe.core.common.file_utils import FileChecker from msprobe.core.common.log import logger from msprobe.core.common.utils import CompareException +from msprobe.pytorch.hook_module.api_register import ApiTemplate, get_api_register from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate -from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate -from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate -from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate -from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate + hf_32_standard_api = ["conv1d", "conv2d"] not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'} @@ -108,17 +104,28 @@ def exec_api(exec_params): kwargs = exec_params.kwargs is_autocast = exec_params.is_autocast autocast_dtype = exec_params.autocast_dtype - - if api_type == "Functional": - torch_api = FunctionalOPTemplate(api_name, str, False) - if api_type == "Tensor": - torch_api = TensorOPTemplate(api_name, str, False) - if api_type == "Torch": - torch_api = TorchOPTemplate(api_name, str, False) - if api_type == "Aten": + out = None + + prefix_map = Const.API_DATA_PREFIX.get(Const.PT_FRAMEWORK, {}) + if not prefix_map or api_type not in prefix_map.values() or \ + api_type not in ( + Const.FUNCTIONAL_API_TYPE_PREFIX, + Const.TENSOR_API_TYPE_PREFIX, + Const.TORCH_API_TYPE_PREFIX, + Const.ATEN_API_TYPE_PREFIX, + Const.NPU_API_TYPE_PREFIX + ): + return out + + if api_type == Const.ATEN_API_TYPE_PREFIX: torch_api = AtenOPTemplate(api_name, None, False) - if api_type == "NPU": - torch_api = NpuOPTemplate(api_name, None, False, device) + else: + api_register = get_api_register() + api_register.initialize_hook(None) + api_func_type = list(prefix_map.keys())[list(prefix_map.values()).index(api_type)] + api_func = api_register.ori_api_attr.get(Const.PT_FRAMEWORK + Const.SEP + api_func_type, {}).get(api_name) + + torch_api = ApiTemplate(api_name, api_func, api_type, None, need_hook=False, device=device) if is_autocast: with autocast(dtype=autocast_dtype): out = torch_api.forward(*args, **kwargs) @@ -248,7 +255,8 @@ def record_skip_info(api_full_name, compare, compare_alg_results): def is_unsupported_api(api_name, is_overflow_check=False): split_name = api_name.split(Const.SEP)[0] - flag = (split_name == Const.DISTRIBUTED) or (is_overflow_check and split_name == Const.NPU) + unsupport_type_list = [Const.DISTRIBUTED, Const.MINDSPEED_API_TYPE_PREFIX] + flag = (split_name in unsupport_type_list) or (is_overflow_check and split_name == Const.NPU) if flag: logger.info(f"{split_name} api is not supported for run ut. SKIP.") return flag diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py deleted file mode 100644 index f31c29c6bb6fa8a863b83bf09d15aba09645436f..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +++ /dev/null @@ -1,203 +0,0 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import glob -import os.path -import time -from multiprocessing import Queue -from typing import Optional, Union, Dict, Any -from dataclasses import dataclass - -import torch - -from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData -from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import TCPClient -from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import TCPServer -from msprobe.core.common.file_utils import remove_path -from msprobe.pytorch.common.utils import logger, save_api_data, load_api_data, save_pkl, load_pkl - -BufferType = Union[ApiData, Dict[str, Any], str] # Union[Tensor, Tuple[Optional[Tensor]]] - - -@dataclass -class ATTLConfig: - is_benchmark_device: bool - connect_ip: str - connect_port: int - # storage_config - nfs_path: str = None - tls_path: str = None - check_sum: bool = True - queue_size: int = 50 - - -class ATTL: - def __init__(self, session_id: str, session_config: ATTLConfig, need_dump=True) -> None: - self.session_id = session_id - self.session_config = session_config - self.logger = logger - self.socket_manager = None - self.data_queue = Queue(maxsize=50) - self.dequeue_list = [] - self.message_end = False - self.kill_progress = False - self.nfs_path = None - if self.session_config.nfs_path: - self.nfs_path = self.session_config.nfs_path - elif self.session_config.is_benchmark_device: - - self.socket_manager = TCPServer(self.session_config.connect_port, - self.data_queue, - self.session_config.check_sum, - self.session_config.tls_path) - self.socket_manager.start() - elif need_dump: - self.socket_manager = TCPClient(self.session_config.connect_ip, - self.session_config.connect_port, - self.session_config.check_sum, - self.session_config.tls_path) - self.socket_manager.start() - - def stop_serve(self): - if isinstance(self.socket_manager, TCPServer): - self.socket_manager.stop() - - def send(self, buffer: BufferType) -> None: - """ - npu major in 'send' (client) - """ - - # if tcp connection lost, - if self.socket_manager.signal_exit: - raise ConnectionError(f"Failed to connect to {self.session_config.connect_ip}.") - - # know receiver receive and go next - if isinstance(buffer, ApiData): - buffer = move2target_device(buffer, torch.device('cpu')) - - if 'device' in buffer.kwargs: - buffer.kwargs.pop('device') - rank = buffer.rank if hasattr(buffer, "rank") and buffer.rank is not None else 0 - step = buffer.step if hasattr(buffer, "step") else 0 - try: - io_buff = save_api_data(buffer) - except Exception as e: - self.logger.info(f"{buffer.name} can not be saved, skip: {e}") - return - data = io_buff.getvalue() - self.socket_manager.add_to_sending_queue(data, rank=rank, step=step) - - def recv(self, timeout_ms=0) -> Optional[BufferType]: - buffer = '' - while not buffer: - if timeout_ms > 0: - time.sleep(timeout_ms / 1000.0) - if not buffer and not self.data_queue.empty(): - buffer = self.data_queue.get() - break - if not buffer and timeout_ms > 0: # timeout is the only case we give up and return None - break - if self.message_end and self.data_queue.empty(): - buffer = b"KILL_CONFIRM" - self.kill_progress = True - break - time.sleep(0.1) # waiting outside the lock before next attempt - if not buffer: - # this is a result of a timeout - self.logger.info(f"RECEIVE API DATA TIMED OUT") - else: - if buffer == b"STOP_": - return "STOP_" - if buffer == b"KILL_": - self.message_end = True - return "STOP_" - if buffer == b"KILL_CONFIRM": - self.kill_progress = True - return "KILL_" - try: - buffer = load_api_data(buffer) - except Exception as e: - self.logger.warning("there is something error. please check it. %s", e) - if isinstance(buffer, bytes): - return '' - if isinstance(buffer, str): - return buffer - - return buffer - - def upload(self, buffer: BufferType): - if isinstance(buffer, ApiData): - buffer = move2target_device(buffer, torch.device('cpu')) - file_path = os.path.join(self.session_config.nfs_path, buffer.name + ".pt") - else: - file_path = os.path.join(self.session_config.nfs_path, buffer + f"_{int(time.time())}") - - try: - save_pkl(buffer, file_path) - except Exception as e: - self.logger.warning("there is something error in save_pt. please check it. %s", e) - - def download(self): - buffer = None - cur_file = None - for file_type in ("start*", "*.pt", "end*"): - pattern = os.path.join(self.nfs_path, file_type) - files = glob.glob(pattern) - if len(files) > 0: - cur_file = files[0] - break - - if cur_file is not None: - try: - buffer = load_pkl(cur_file) - except Exception as e: - self.logger.warning("there is something error. please check it. %s", e) - remove_path(cur_file) - return buffer - - -def move2device_exec(obj, device): - if isinstance(obj, (tuple, list)): - data_list = [move2device_exec(val, device) for val in obj] - return data_list if isinstance(obj, list) else tuple(data_list) - if isinstance(obj, dict): - return {key: move2device_exec(val, device) for key, val in obj.items()} - elif isinstance(obj, torch.Tensor): - obj = obj.detach() - if obj.device.type != device: - obj = obj.to(device) - return obj - elif "return_types" in str(type(obj)): - return move2device_exec(tuple(obj), device) - elif isinstance(obj, torch._C.device): - return torch.device(device) - else: - return obj - - -def move2target_device(buffer: ApiData, target_device): - # handle args - new_args = move2device_exec(buffer.args, target_device) - - # handle kwargs - new_kwargs = move2device_exec(buffer.kwargs, target_device) - - # handle result - new_results = move2device_exec(buffer.result, target_device) - - if target_device == torch.device('cpu') or target_device == "cpu": - return ApiData(buffer.name, tuple(new_args), new_kwargs, new_results, buffer.step, buffer.rank) - else: - return ApiData(buffer.name, tuple(new_args), new_kwargs, buffer.result, buffer.step, buffer.rank) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py deleted file mode 100644 index fbb087deec73bb6e77c0d7581128c74e2d9be9fa..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +++ /dev/null @@ -1,362 +0,0 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import hashlib -import io -import struct -import time -import os -import signal -from queue import Queue -from threading import Thread -from typing import Union - -from twisted.internet import reactor, protocol, endpoints -from twisted.protocols.basic import FileSender - -from msprobe.pytorch.common.utils import logger -from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import STRUCT_UNPACK_MODE as unpack_mode, \ - STR_TO_BYTES_ORDER as bytes_order - -MAX_SENDING_QUEUE_SIZE = 20 - - -class TCPDataItem: - def __init__(self, data, - sequence_number: int, - rank: int = 0, - step: int = 0): - self.raw_data = data - self.sequence_number = sequence_number - self.rank = rank - self.step = step - self.retry_times = 0 - self.pending_time = 0 - self.busy_time = 0 - - -class TCPClient: - ACK_SUCCESS = b"OK___" - ACK_ERROR = b"ERROR" - ACK_BUSY = b"BUSY_" - ACK_STOP = b"STOP_" - ACK_STOP_CONFIRM = b"OVER_" - ACK_KILL_PROCESS = b"KILL_" - - QUEUE_PENDING_TIME = 60 - RESEND_RETRY_TIMES = 2 # 最大重传数 - RESEND_TIMER_TIME = 5 # 接收ACK超时定时器 - RESEND_PENDING_TIME = 60 # 连续pending时间超过1分钟则放弃该数据 - - def __init__(self, host="localhost", port=8000, check_sum=False, tls_path=None): - self.send_queue = Queue(MAX_SENDING_QUEUE_SIZE) - self.resend_dict = dict() - self.host = host - self.port = port - self.tls_path = tls_path - self.factory = None - self.sequence_number = 0 - self.signal_exit = False - self.tcp_manager = ClientProtocol(ack_queue_size=100, - chunk_size=655360, - check_sum=check_sum, - tls=self.tls_path) - self.send_thread = Thread(target=self._sending_queue_data) - self.send_thread.setDaemon(True) - self.send_thread.start() - self.destroy_thread = Thread(target=self._destroy_queue_data) - self.destroy_thread.setDaemon(True) - self.destroy_thread.start() - - @staticmethod - def run_reactor(): - reactor.run(installSignalHandlers=False) - - def start(self): - def conn_callback(cur_protocol): - if cur_protocol.transport and cur_protocol.transport.getPeer().host == self.host: - logger.debug(f"Process: {os.getpid()} connects to server successfully.") - else: - logger.warning(f"Process: {os.getpid()} fails to connect to server. ") - raise ConnectionError(f"Failed to connect to {self.host}.") - - def conn_err_callback(failure): - self.signal_exit = True - time.sleep(1) - reactor.stop() - logger.error(f"Failed to connected {self.host} {self.port}. Reason is {failure.getErrorMessage()}") - - def cur_protocol(): - return self.tcp_manager - - self.factory = MessageClientFactory() - self.factory.protocol = cur_protocol - if self.tls_path: - from twisted.internet import ssl - client_key = os.path.join(self.tls_path, "client.key") - client_crt = os.path.join(self.tls_path, "client.crt") - client_context_factory = ssl.DefaultOpenSSLContextFactory(client_key, client_crt) - endpoint = endpoints.SSL4ClientEndpoint(reactor, self.host, self.port, client_context_factory) - else: - endpoint = endpoints.TCP4ClientEndpoint(reactor, self.host, self.port) - d = endpoint.connect(self.factory) - d.addCallback(conn_callback) - d.addErrback(conn_err_callback) - - reactor_thread = Thread(target=self.run_reactor, daemon=True) - reactor_thread.start() - - def send_after_queue_empty(self, data): - while not self._ready_to_exit(): - if not self.tls_path: - self.add_to_sending_queue(data) - else: - for _ in range(MAX_SENDING_QUEUE_SIZE): - self.add_to_sending_queue(data) - time.sleep(2) - - def check_client_alive(self): - return self.factory.num_connections > 0 - - def stop(self): - self.tcp_manager.connection_timeout() - - def send_stop_signal(self): - self.send_after_queue_empty(self.ACK_STOP) - while not self._ready_to_exit(): - if not self.check_client_alive(): - break - time.sleep(1) - - def add_to_sending_queue(self, data: Union[bytes, TCPDataItem], rank: int = 0, step: int = 0): - if self._ready_to_exit(): - return - - send_data = data - if not isinstance(data, TCPDataItem): - send_data = TCPDataItem(data=data, - sequence_number=self.sequence_number, - rank=rank, - step=step) - self.sequence_number += 1 - try: - self.send_queue.put(send_data, block=True, timeout=self.QUEUE_PENDING_TIME) - except Exception as e: - logger.error(f"send_queue put send_data timeout, rank: {send_data.rank}, step: {send_data.step}," - f"sequence_number: {send_data.sequence_number}, send_queue size: {self.send_queue.qsize()}," - f"{str(e)}") - - def _send_data(self, data: TCPDataItem): - self.tcp_manager.send_wrapped_data(data.raw_data, - sequence_number=data.sequence_number, - rank=data.rank, - step=data.step - ) - - def _sending_queue_data(self): - while True: - if not self.tcp_manager.is_connected: - continue - - while self.send_queue.qsize() > 0: - if self._ready_to_exit(): - break - if len(self.resend_dict) < MAX_SENDING_QUEUE_SIZE: - data_obj = self.send_queue.get() - resend_key = str(data_obj.sequence_number) + "_" + str(data_obj.rank) + "_" + str(data_obj.step) - logger.debug(f"get {resend_key} from send_queue, and send to server.") - self._send_data(data_obj) - if resend_key not in self.resend_dict.keys(): - # Send data for the first time - self.resend_dict[resend_key] = data_obj - else: - time.sleep(0.1) - - if self._ready_to_exit(): - logger.debug("Successfully close sending process.") - break - time.sleep(0.1) - - def _destroy_queue_data(self): - while True: - if self._ready_to_exit(): - break - - while len(self.resend_dict) > 0 and self.tcp_manager.ack_queue.qsize() > 0: - ack_info, seq_number, rank, step = self.tcp_manager.ack_queue.get() - obj_key = str(seq_number) + "_" + str(rank) + "_" + str(step) - current_item = self.resend_dict.get(obj_key) - - if current_item is None: - continue - - if ack_info == self.ACK_SUCCESS: - self.resend_dict.pop(obj_key) - elif ack_info == self.ACK_BUSY: - logger.debug("RECV BUSY ACK") - if current_item.busy_time > 5: - self._resend_data(current_item) - else: - current_item.busy_time += 1 - elif ack_info == self.ACK_ERROR: - logger.debug("RECV ERROR ACK") - self._resend_data(current_item) - elif ack_info == self.ACK_STOP_CONFIRM: - logger.debug("RECV STOP ACK") - self.factory.num_connections -= 1 - - break - - time.sleep(0.1) - - def _resend_data(self, data: TCPDataItem): - if data.retry_times < self.RESEND_RETRY_TIMES: - data.retry_times += 1 - logger.debug(f"Resend data seq number: {data.sequence_number}") - self.add_to_sending_queue(data) - else: - self.resend_dict.pop(data.sequence_number) - logger.debug(f"SKIP send sequence number {data.sequence_number} after retry {data.retry_times} times!") - - def _pending_data(self, data: TCPDataItem): - if data.pending_time >= self.RESEND_PENDING_TIME: - self.resend_dict.pop(data.sequence_number) - logger.debug(f"SKIP send sequence number {data.sequence_number} after pending {data.pending_time} times!") - return - - # wait time is 100MB per second - pending_time = max(1, len(data.raw_data) // (2 ** 20 * 50)) - data.pending_time += pending_time - time.sleep(pending_time) - - def _ready_to_exit(self): - return self.signal_exit or self.tcp_manager.signal_exit - - -class ClientProtocol(protocol.Protocol): - TIMEOUT = 60 * 10 - - def __init__(self, ack_queue_size=100, chunk_size=65536, check_sum=False, tls=None): - self.buffer = io.BytesIO() - self.is_connected = False - self.check_sum = check_sum - self.tell = 0 - self.ack_queue = Queue(maxsize=ack_queue_size) - self.file_sender = FileSender() - self.file_sender.CHUNK_SIZE = chunk_size - self.signal_exit = False - self.defer = None - self.kill_process = False - self.ack = None - - self.timeout_call = None - - self.tls = tls - self.send_buffer = b"" - self.buffer_cnt = 0 - - def dataReceived(self, data): - if self.timeout_call.active(): - self.timeout_call.reset(self.TIMEOUT) - - self.buffer.seek(0, 2) - self.buffer.write(data) - self.buffer.seek(self.tell) - while True: - if len(self.buffer.getvalue()) >= 29: # 5 + 8 * 3 - ack = self.buffer.read(5) - self.ack = ack - seq_number = struct.unpack(unpack_mode, self.buffer.read(8))[0] - rank = struct.unpack(unpack_mode, self.buffer.read(8))[0] - step = struct.unpack(unpack_mode, self.buffer.read(8))[0] - logger.debug(f"receive 流水号: {seq_number}; RANK: {rank}; STEP: {step}; ACK: {ack}") - if ack == b"KILL_": - self.kill_process = True - logger.debug(f"接收到KILL信号, PID {os.getpid()}") - if ack == b"OVER_": - self.factory.num_connections -= 1 - self.tell += 29 - if not self.ack_queue.full(): - self.ack_queue.put((ack, seq_number, rank, step)) - self.buffer = io.BytesIO(self.buffer.getvalue()[self.tell:]) - self.tell = 0 - else: - time.sleep(0.1) - else: - break - - def send_wrapped_data(self, data, sequence_number: int = 0, rank: int = 0, step: int = 0): - length = len(data) - md5_hash = hashlib.md5(data).hexdigest() if self.check_sum else "" - data_meaasge = length.to_bytes(8, byteorder=bytes_order) + \ - sequence_number.to_bytes(8, byteorder=bytes_order) + \ - rank.to_bytes(8, byteorder=bytes_order) + \ - step.to_bytes(8, byteorder=bytes_order) + \ - md5_hash.encode() + \ - data - logger.debug(f"send 流水号: {sequence_number}; RANK: {rank}; STEP: {step}; LENGTH: {length}") - - while True: - if self.defer is None or self.defer.called: - self.defer = self.send_large_data(data_meaasge) - break - time.sleep(0.01) - - def send_large_data(self, data): - - if self.tls: - self.send_buffer += data - self.buffer_cnt += 1 - if self.buffer_cnt >= MAX_SENDING_QUEUE_SIZE: - d = self.file_sender.beginFileTransfer(io.BytesIO(self.send_buffer), self.transport) - self.send_buffer = b"" - self.buffer_cnt = 0 - else: - d = None - else: - d = self.file_sender.beginFileTransfer(io.BytesIO(data), self.transport) - return d - - def connection_timeout(self): - if self.factory.num_connections <= 0: - return - - self.factory.num_connections -= 1 - logger.debug(f"超时退出{self.transport.addr}, PID {os.getpid()}") - self.transport.loseConnection() - - def connectionMade(self): - self.timeout_call = reactor.callLater(self.TIMEOUT, self.connection_timeout) - self.is_connected = True - self.factory.num_connections += 1 - logger.info("successfully connect server") - - def connectionLost(self, reason): - self.signal_exit = True - self.factory.num_connections -= 1 - logger.info(f"Lost connection with server, reason is : {reason}") - - -class MessageClientFactory(protocol.ClientFactory): - def __init__(self): - self.num_connections = 0 - - def clientConnectionFailed(self, connector, reason): - logger.info(f"Fail to connection with server: {reason.getErrorMessage()}") - reactor.stop() - - def clientConnectionLost(self, connector, reason): - logger.info(f"Client lost connection with server: {reason.getErrorMessage()}") - reactor.stop() diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py deleted file mode 100644 index 8777af9cc37ad03dacfa82bf29854fb1c1babe95..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +++ /dev/null @@ -1,229 +0,0 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time -from collections import namedtuple - -import pandas as pd -import torch -import torch.multiprocessing as mp - -from msprobe.core.common.const import Const, CompareConst -from msprobe.pytorch.api_accuracy_checker.compare.api_precision_compare import online_api_precision_compare -from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import DETAIL_TEST_ROWS, thousandth_standard_api, \ - binary_standard_api, absolute_standard_api -from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import UtDataInfo, exec_api, ExecParams -from msprobe.pytorch.common.log import logger -from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import move2target_device -from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params - - -# NPU vs GPU api list -CompareApi = set(absolute_standard_api) | set(binary_standard_api) | set(thousandth_standard_api) - -current_time = time.strftime("%Y%m%d%H%M%S") -ONLINE_API_PRECISION_COMPARE_RESULT_FILE_NAME = "api_precision_compare_result_" + current_time + "_rank*.csv" -ONLINE_API_PRECISION_COMPARE_DETAILS_FILE_NAME = "api_precision_compare_details_" + current_time + "_rank*.csv" - -OnlineApiPrecisionCompareConfig = namedtuple('OnlineApiPrecisionCompareConfig', - ['npu_data', 'gpu_data', 'rank', 'result_csv_path', 'details_csv_path']) -# namedtuple of [instance of Comparator, func of run_touch_api_online, config of run_ut_config] -CommonCompareConfig = namedtuple('CommonCompareConfig', ['compare', 'handle_func', 'config']) - - -def run_ut_process(xpu_id, consumer_queue, common_config, api_precision_csv_file): - """ When consumer_queue(shared with ConsumerDispatcher) is not empty, consume api data from consumer_queue. - :param xpu_id: int - :param consumer_queue: shared queues of ConsumerDispatcher - :param common_config: namedtuple of CommonCompareConfig - :param api_precision_csv_file: list, length is 2, result file name and details file name - :return: - """ - gpu_device = torch.device(f'cuda:{xpu_id}') - - while True: - if consumer_queue.empty(): - time.sleep(0.1) - continue - - api_data = consumer_queue.get() - if api_data == "KILL_": - # current consumer finish - return - - _, api_name, _ = api_data.name.split(Const.SEP) - if api_name in CompareApi: - # NPU vs GPU - online_compare(api_data, gpu_device, common_config) - else: - # NPUvsCPU vs GPUvsCPU - online_precision_compare(api_data, gpu_device, common_config, api_precision_csv_file) - - -def online_precision_compare(api_data, device, common_config, api_precision_csv_file): - """online run_ut for precision_compare: NPUvsCPU vs GPUvsCPU - 1. get NPUvsCPU compare result - 2. get GPUvsCPU compare result - 3. call online_api_precision_compare - :param api_data - :param device - :param common_config: namedtuple of CommonCompareConfig - :param api_precision_csv_file: [result_file_name, details_file_name] - """ - compare, func, config = common_config.compare, common_config.handle_func, common_config.config - api_full_name = api_data.name - [api_type, api_name, _] = api_full_name.split(Const.SEP) - npu_args, npu_kwargs, npu_out = api_data.args, api_data.kwargs, api_data.result - - if npu_kwargs.get("device"): - del npu_kwargs["device"] - - try: - # NPU vs CPU - cpu_params = generate_cpu_params(npu_args, npu_kwargs, False, api_name) - cpu_args, cpu_kwargs = cpu_params.cpu_args, cpu_params.cpu_kwargs - cpu_exec_params = ExecParams(api_type, api_name, Const.CPU_LOWERCASE, cpu_args, cpu_kwargs, False, None) - cpu_out = exec_api(cpu_exec_params) - npu_data_info = UtDataInfo(None, None, npu_out, cpu_out, None, [], None, rank=api_data.rank) - npu_detail = compare.compare_output(api_full_name, npu_data_info, True) - npu_data = pd.DataFrame(npu_detail, columns=DETAIL_TEST_ROWS[-1]) - - # GPU vs CPU - api_data_gpu = move2target_device(api_data, device) # args, kwargs -> gpu, result -> npu - data_info = func(api_full_name, api_data_gpu, config.backward_content) - gpu_out = data_info.bench_output - gpu_data_info = UtDataInfo(None, None, gpu_out, cpu_out, None, [], None, rank=api_data.rank) - gpu_detail = compare.compare_output(api_full_name, gpu_data_info, True) - gpu_data = pd.DataFrame(gpu_detail, columns=DETAIL_TEST_ROWS[-1]) - - # NPUvsCPU vs GPUvsCPU - result_file_name, details_file_name = api_precision_csv_file - precision_compare_config = OnlineApiPrecisionCompareConfig(npu_data, gpu_data, api_data.rank, - result_file_name, details_file_name) - online_api_precision_compare(precision_compare_config) - - except Exception as err: - if "expected scalar type Long" in str(err): - logger.warning( - f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API " - f"'int32_to_int64' list in accuracy_tools/msprobe/core/common/const.py file.") - elif api_type in [Const.DISTRIBUTED]: - logger.info(f"{api_full_name} is not supported for run ut. SKIP.") - else: - logger.error(f"Run {api_full_name} UT Error: {str(err)}") - - compare.write_summary_csv((api_full_name, CompareConst.SKIP, CompareConst.SKIP, [[str(err)]], api_data.rank)) - - finally: - torch.cuda.empty_cache() - - -def online_compare(api_data, device, common_config): - """online run_ut for compare:NPU vs GPU - """ - compare, func, config = common_config.compare, common_config.handle_func, common_config.config - api_full_name = api_data.name - api_data = move2target_device(api_data, device) - try: - data_info = func(api_full_name, api_data, config.backward_content) - is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info) - logger.info(f"running api_full_name {api_full_name} ut, " - f"is_fwd_success: {is_fwd_success}, " - f"is_bwd_success: {is_bwd_success}") - except Exception as err: - [api_type, api_name, _] = api_full_name.split(Const.SEP) - if "expected scalar type Long" in str(err): - logger.warning( - f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API " - f"'int32_to_int64' list in accuracy_tools/msprobe/core/common/const.py file.") - elif api_type in [Const.DISTRIBUTED]: - logger.info(f"{api_full_name} is not supported for run ut. SKIP.") - else: - logger.error(f"Run {api_full_name} UT Error: {str(err)}") - - compare.write_summary_csv((api_full_name, CompareConst.SKIP, CompareConst.SKIP, [[str(err)]], api_data.rank)) - - finally: - torch.cuda.empty_cache() - - -class ConsumerDispatcher: - def __init__(self, compare, capacity=10, num_workers=8, device: str = "gpu") -> None: - self.num_workers = num_workers - self.capacity = capacity - self.compare = compare - self.queues = [] - self.processes = [] - self.reverse_sort = False - self.pool = None - self.device = device - self.data_id = 0 - self.lock = mp.Lock() - self.result_queue = mp.Queue() - mp.set_start_method("spawn", force=True) - - def start(self, handle_func, config): - self.queues = [mp.Queue(maxsize=self.capacity) for _ in range(self.num_workers)] - api_precision_csv_file = [ - ONLINE_API_PRECISION_COMPARE_RESULT_FILE_NAME, - ONLINE_API_PRECISION_COMPARE_DETAILS_FILE_NAME - ] - common_config = CommonCompareConfig(self.compare, handle_func, config) - for xpu_id, q in enumerate(self.queues): - p = mp.Process(name="run_ut_process", target=run_ut_process, - args=(xpu_id, q, common_config, api_precision_csv_file)) - - p.start() - self.processes.append(p) - logger.info( - f'Api_precision_compare task result will be saved in {ONLINE_API_PRECISION_COMPARE_RESULT_FILE_NAME}') - logger.info( - f"Api_precision_compare task details will be saved in {ONLINE_API_PRECISION_COMPARE_DETAILS_FILE_NAME}") - logger.info("Successfully start unittest process.") - - def stop(self): - for q in self.queues: - while q.full(): - time.sleep(0.1) - q.put("KILL_") - - for p in self.processes: - p.join() - logger.info("Successfully stop unittest process.") - logger.info(f"Api_precision_compare task result is saved in {ONLINE_API_PRECISION_COMPARE_RESULT_FILE_NAME}") - logger.info(f"Api_precision_compare task details is saved in {ONLINE_API_PRECISION_COMPARE_DETAILS_FILE_NAME}") - - def update_consume_queue(self, api_data): - while True: - index = self._choose_max_empty_site_strategy() - if index != -1: - q = self.queues[index] - q.put(api_data) - break - time.sleep(0.1) - - def _choose_max_empty_site_strategy(self): - maximum = 0 - index = -1 - # 充分利用多卡资源,防止任务过多分配给前面的卡 - _reverse = 1 if not self.reverse_sort else -1 - for i, q in enumerate(self.queues[::_reverse]): - empty_site = self.capacity - q.qsize() - if empty_site > maximum: - maximum = empty_site - index = i - index = len(self.queues) - index - 1 if index != -1 and self.reverse_sort else index - self.reverse_sort = not self.reverse_sort - return index diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py deleted file mode 100644 index 61650705e48056d5964c7ba48ff442247a08e4f9..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +++ /dev/null @@ -1,115 +0,0 @@ - -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from collections import defaultdict -from functools import wraps - -import torch -from torch.utils._python_dispatch import TorchDispatchMode -from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData -from msprobe.pytorch.common.utils import get_tensor_rank -from msprobe.core.common.const import Const -from msprobe.pytorch.common.log import logger -from msprobe.core.common.file_utils import load_yaml - - -def singleton(cls): - _instance = {} - - @wraps(cls) - def inner(): - if cls not in _instance: - _instance[cls] = cls() - return _instance[cls] - return inner - - -@singleton -class Counter: - def __init__(self) -> None: - self.index_dict = defaultdict(int) - - -counter = Counter() -yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "torch_ops_config.yaml") -yaml_file = load_yaml(yaml_path) - - -class AccuracyCheckerDispatch(TorchDispatchMode): - def __init__(self, attl): - super(AccuracyCheckerDispatch, self).__init__() - self.attl = attl - self.counter = counter - self.aten_ops_blacklist = [] - self.npu_adjust_autogard = [] - self.aten_ops_blacklist = yaml_file.get('aten_ops_blacklist', []) - self.npu_adjust_autogard = yaml_file.get('npu_adjust_autogard', []) - - def __torch_dispatch__(self, func, types, args=None, kwargs=None): - func_name_split_list = func.__name__.split(Const.SEP) - aten_api = func_name_split_list[0] - self.enable_autogard(aten_api) - if aten_api in self.aten_ops_blacklist: - npu_out = func(*args, **kwargs) - return npu_out - - res = func(*args, **kwargs) - cur_rank = get_tensor_rank(args, res) - cur_api_number = self.counter.index_dict[aten_api] - api_name = f'{Const.ATEN}{Const.SEP}{aten_api}{Const.SEP}{cur_api_number}' - logger.info(f"tools is dumping api: {api_name}, rank: {cur_rank}") - api_data = ApiData(api_name, args, kwargs, res, 0, cur_rank) - if "device" in api_data.kwargs: - api_data.kwargs.pop("device") - if self.attl.nfs_path: - self.attl.upload(api_data) - else: - self.attl.send(api_data) - self.counter.index_dict[aten_api] += 1 - - return res - - def enable_autogard(self, aten_api): - if aten_api in self.npu_adjust_autogard: - torch._C._dispatch_tls_set_dispatch_key_excluded(torch._C.DispatchKey.AutogradFunctionality, False) - - -def dispatch4data(func, attl, status): - @wraps(func) - def wrapper(*args, **kwargs): - if not status: - return func(*args, **kwargs) - with AccuracyCheckerDispatch(attl): - res = func(*args, **kwargs) - return res - - return wrapper - - -def run_ut_dispatch(attl, status, is_recompute=False): - """ - This function called by online_run_ut. - It is used to enable or disable dispatch for torch.autograd.backward function. - - Args: - attl (ATTL): online_run_ut class ATTL, which is used to upload or send api data to server. - status (bool): True means enable dispatch, False means disable dispatch. - is_recompute (bool): Flag of recompute, which is conflicted with aten api, then skip dispatch4data. - """ - if is_recompute: - return - torch.autograd.backward = dispatch4data(torch.autograd.backward, attl, status) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py deleted file mode 100644 index 411e36d4cb3014b75a46d58ebec99b7e8b7c7c44..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +++ /dev/null @@ -1,236 +0,0 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os.path -import struct -import hashlib -import time -import io -from threading import Thread - -from twisted.internet import reactor, protocol, endpoints - -from msprobe.pytorch.common.utils import logger -from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import cipher_list, \ - STRUCT_UNPACK_MODE as unpack_mode, STR_TO_BYTES_ORDER as bytes_order - - -class TCPServer: - def __init__(self, port, shared_queue, check_sum=False, tls_path=None) -> None: - self.port = port - self.shared_queue = shared_queue - self.check_sum = check_sum - self.tls_path = tls_path - self.factory = MessageServerFactory() - self.reactor_thread = None - - @staticmethod - def run_reactor(): - reactor.run(installSignalHandlers=False) - - def start(self): - self.factory.protocol = self.build_protocol - - if self.tls_path: - from OpenSSL import SSL - from twisted.internet import ssl - server_key = os.path.join(self.tls_path, "server.key") - server_crt = os.path.join(self.tls_path, "server.crt") - server_context_factory = ssl.DefaultOpenSSLContextFactory(server_key, server_crt, SSL.TLSv1_2_METHOD) - server_context_ = server_context_factory.getContext() - server_context_.set_cipher_list(cipher_list) - server_context_.set_options(SSL.OP_NO_RENEGOTIATION) - endpoint = endpoints.SSL4ServerEndpoint(reactor, self.port, server_context_factory) - else: - endpoint = endpoints.TCP4ServerEndpoint(reactor, self.port) - endpoint.listen(self.factory) - self.reactor_thread = Thread(target=self.run_reactor, daemon=True) - self.reactor_thread.start() - - def is_running(self): - return not self.factory.is_all_connection_closed() - - def stop(self): - self.factory.doStop() - reactor.callFromThread(reactor.sigInt, 2) - self.reactor_thread.join() - - def build_protocol(self): - return ServerProtocol(self.shared_queue, self.check_sum) - - -class ServerProtocol(protocol.Protocol): - ACK_SUCCESS = b"OK___" - ACK_ERROR = b"ERROR" - ACK_BUSY = b"BUSY_" - ACK_STOP = b"STOP_" - ACK_STOP_CONFIRM = b"OVER_" - ACK_KILL_PROCESS = b"KILL_" - - def __init__(self, shared_queue, check_sum=False): - self.start_time = None - self.buffer = io.BytesIO() - self.consumer_queue = shared_queue - self.check_sum = check_sum - self.length_width = 8 - self.md5_width = 32 - self.obj_length = None - self.tell = 0 - self.obj_md5 = None - self.obj_body = None - self.sequence_number = -1 - self.rank = -1 - self.step = -1 - self.sequence_number_dict = dict() - - def connectionMade(self): - self.buffer = io.BytesIO() - self.obj_length = None - self.tell = 0 - self.obj_md5 = None - self.obj_body = None - self.factory.transport_dict[self.transport] = 1 - self.factory.transport_list.append(self.transport) - logger.info(f"Connected to {self.transport.getPeer()} successfully.") - - def connectionLost(self, reason): - self.factory.transport_dict.pop(self.transport, None) - if len(self.factory.transport_dict) == 0: - self.consumer_queue.put(self.ACK_KILL_PROCESS) - - logger.info(f"Lost connection with {self.transport.getPeer()}. Reason is: {reason} 与客户端 断开连接, " - f"current connection number is: {len(self.factory.transport_dict)}") - - def send_ack(self, ack_info): - ack_message = b"".join([ - ack_info, - self.sequence_number.to_bytes(8, byteorder=bytes_order), - self.rank.to_bytes(8, byteorder=bytes_order), - self.step.to_bytes(8, byteorder=bytes_order) - ]) - self.transport.write(ack_message) - - def post_process(self): - send_busy_ack = False - while self.consumer_queue.full(): - if not send_busy_ack: - self.send_ack(self.ACK_BUSY) - logger.debug("sending BUSY ACK") - send_busy_ack = True - time.sleep(0.1) - - obj_key = str(self.sequence_number) + "_" + str(self.rank) + "_" + str(self.step) - - recv_md5 = hashlib.md5(self.obj_body).hexdigest() - if self.check_sum and recv_md5 != self.obj_md5: - # when needs check md5 and check no pass, indicates received data error, send b"ERROR" to client. - logger.debug(f"Error:接收数据有问题,流水号{self.sequence_number}, expected {self.obj_md5}, but get {recv_md5}") - self.send_ack(self.ACK_ERROR) - else: - if self.obj_body == self.ACK_STOP: - self.handle_with_stop() - else: - self.send_ack(self.ACK_SUCCESS) - if obj_key in self.sequence_number_dict: - logger.debug(f"这是一次异常的重传,可以忽略。 {obj_key}, {self.sequence_number_dict}") - else: - self.sequence_number_dict[obj_key] = self.obj_md5 - self.consumer_queue.put(self.obj_body, block=True) - - self.reset_env() - finish_time = time.time() - logger.debug(f"finish_time: {finish_time - self.start_time}") - - def handle_with_stop(self): - logger.debug(f"接收到停止传输信号 TCP{self.transport.getPeer()}") - self.send_ack(self.ACK_STOP_CONFIRM) - if len(self.factory.transport_dict) == 0: - _rank, _step, _sequence_number = 0, 0, 100000000 - ack_kill = self.ACK_KILL_PROCESS + \ - _sequence_number.to_bytes(8, byteorder='big') + \ - _rank.to_bytes(8, byteorder='big') + \ - _step.to_bytes(8, byteorder='big') - for trans in self.factory.transport_list: - trans.write(ack_kill) - logger.debug(f"发送KILL信息给{self.transport.getPeer()}") - self.consumer_queue.put(self.ACK_KILL_PROCESS) - time.sleep(2) - - def reset_env(self): - self.obj_length = None - self.sequence_number = -1 - self.rank = -1 - self.step = -1 - self.obj_md5 = None - self.obj_body = None - - def dataReceived(self, data): - self.buffer.seek(0, 2) - self.buffer.write(data) - self.buffer.seek(self.tell) - - # The first data packet is packet header, it contains obj_length, sequence_number, rank, step - if self.obj_length is None and len(self.buffer.getvalue()) >= self.length_width * 4: - self.start_time = time.time() - self.obj_length = struct.unpack(unpack_mode, self.buffer.read(self.length_width))[0] - self.sequence_number = struct.unpack(unpack_mode, self.buffer.read(self.length_width))[0] - self.rank = struct.unpack(unpack_mode, self.buffer.read(self.length_width))[0] - self.step = struct.unpack(unpack_mode, self.buffer.read(self.length_width))[0] - self.tell += self.length_width * 4 - logger.debug( - f"流水号: {self.sequence_number}; RANK: {self.rank}; STEP: {self.step}; Length: {self.obj_length}") - - # If needs check md5 but not parse md5 yet, read 32b md5 values - check_sum_and_md5 = (self.check_sum - and self.obj_length is not None - and self.obj_md5 is None - and len(self.buffer.getvalue()) - self.tell >= self.md5_width) - if check_sum_and_md5: - self.obj_md5 = self.buffer.read(self.md5_width).decode() - self.tell += self.md5_width - logger.debug(f"MD5: {self.obj_md5}") - - current_length = len(self.buffer.getvalue()) - self.tell - if self.obj_length is not None and 0 < self.obj_length <= current_length: - # Current api data receive finished - self.obj_body = self.buffer.read(self.obj_length) - - self.tell += self.obj_length - self.buffer = io.BytesIO(self.buffer.getvalue()[self.tell:]) - self.buffer.seek(0) - self.tell = 0 - recv_data_time = time.time() - logger.debug(f"self.sequence_number {self.sequence_number} " - f"recv_data_time {recv_data_time - self.start_time}") - - if self.obj_body == self.ACK_STOP: - # Indicates the current TCP link receives a STOP signal and remove from the transport_dict - _transport = self.factory.transport_dict.pop(self.transport, None) - logger.debug(f"接收到b'STOP_' self.sequence_number {self.sequence_number} ") - self.post_process() - - -class MessageServerFactory(protocol.ServerFactory): - def __init__(self) -> None: - """ - transport_dict: links that have not completed data transmission. - transport_list: Records all TCP links. Appends TCP link to the transport list - when a new TCP link is established. - """ - self.transport_dict = {} - self.transport_list = [] - - def is_all_connection_closed(self): - return len(self.transport_dict) == 0 diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml deleted file mode 100644 index 373e6ed0fc33a97537f38eedeb36a3e90122525a..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +++ /dev/null @@ -1,63 +0,0 @@ -aten_ops_blacklist: - - npu_binary_cross_entropy_with_logits_backward - - npu_ciou_backward - - _cudnn_rnn - - _local_scalar_dense - - _pin_memory - - _to_copy - - _unsafe_view - - clone - - contiguous - - copy_ - - cudnn_batch_norm - - cudnn_batch_norm_backward - - detach - - empty - - index_put_ - - lift_fresh - - max_pool2d_with_indices_backward # shape unmatch - - native_batch_norm_backward - - new_empty - - new_empty_strided - - new_full - - new_ones - - new_zeros - - ones - - ones_like - - permute - - rand - - rand_like - - randint - - randint_like - - randn - - randn_like - - randperm - - scalar_tensor - - select - - to - - transpose - - unbind - - view - - zero - - zero_ - - zeros - - zeros_like - - _record_function_enter_new - - _record_function_exit - - broadcast_ - - allreduce_ - - npu_clear_float_status - - npu_format_cast - - npu_dtype_cast - - npu_dtype_cast_backward - - _allgather_base_ - - _reduce_scatter_base_ - - is_same_size - -npu_adjust_autogard: - - adaptive_avg_pool2d - - batch_norm - - log_softmax - - nll_loss - - to - \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py index be15935ce9c9f77bc0a8447902f7f4a7b536a7fb..07655ba841120a80f64a9975a74abd7556569a41 100644 --- a/debug/accuracy_tools/msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py @@ -29,6 +29,8 @@ def softmax_func(x, axis=None): def npu_moe_gating_top_k_softmax(x, finished_optional, k): input_dtype = x.dtype + if x.dim() < 1: + raise ValueError("Input x must have at least 1 dimensions.") num_expert = x.shape[-1] softmax = softmax_func(x, -1) softmax = softmax.to(input_dtype) @@ -36,9 +38,13 @@ def npu_moe_gating_top_k_softmax(x, finished_optional, k): expert_idx = expert_idx[:, :k] y = torch.gather(softmax, index=expert_idx, dim=-1) if finished_optional is not None: + if finished_optional.dim() < 1: + raise ValueError("Finished_optional must have at least 1 dimensions.") finished_optional = finished_optional.view(finished_optional.shape[0], 1) finished_optional = finished_optional.expand(-1, k) expert_idx = torch.where(finished_optional, num_expert, expert_idx) + if y.dim() < 2: + raise ValueError("Variable y must have at least 2 dimensions.") row_idx = torch.arange(y.shape[0] * y.shape[1]).reshape(y.shape[1], y.shape[0]).t() return y, expert_idx, row_idx diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/npu_fusion_attention.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/npu_fusion_attention.py index 58a585f5a05f4b2d533d150db3a9fbfd907f5a07..3cdb4f6c0c7c5ac5f01c41e06905b731cd809029 100644 --- a/debug/accuracy_tools/msprobe/pytorch/bench_functions/npu_fusion_attention.py +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/npu_fusion_attention.py @@ -117,6 +117,12 @@ def fusion_attention_forward(forward_params): pse = forward_params.pse scale = forward_params.scale keep_prob = forward_params.keep_prob + + # 除零风险拦截:keep_prob 为 0 时会导致除零错误 + if keep_prob == 0: + raise ValueError("fusion_attention_forward: keep_prob cannot be zero to avoid division by zero.") + + qk = calculate_qk(q, k, atten_mask, pse, scale) softmax_res, softmax_max, softmax_sum = softmax_forward(qk) if drop_mask is None or len(drop_mask.shape) == 0: @@ -137,6 +143,11 @@ def fusion_attention_backward(backward_params): pse = backward_params.pse scale = backward_params.scale keep_prob = backward_params.keep_prob + + # 除零风险拦截:keep_prob 为 0 时会导致除零错误 + if keep_prob == 0: + raise ValueError("fusion_attention_backward: keep_prob cannot be zero to avoid division by zero.") + dp = torch.matmul(dx, v.permute(0, 1, 3, 2)) if drop_mask is None or len(drop_mask.shape) == 0: drop_res = softmax_res.permute(0, 1, 3, 2) @@ -164,23 +175,35 @@ def parse_bsnd_args(query, key, head_num, input_layout): if input_layout == "BSH": b, s1, h1 = query.shape _, s2, h2 = key.shape + if n1 == 0: + raise ValueError("parse_bsnd_args: head_num (n1) cannot be zero to avoid division by zero.") d = h1 // n1 + if d == 0: + raise ValueError("parse_bsnd_args: computed head dimension (d) is zero, division by zero risk.") n2 = h2 // d elif input_layout == "SBH": s1, b, h1 = query.shape s2, _, h2 = key.shape + if n1 == 0: + raise ValueError("parse_bsnd_args: head_num (n1) cannot be zero to avoid division by zero.") d = h1 // n1 + if d == 0: + raise ValueError("parse_bsnd_args: computed head dimension (d) is zero, division by zero risk.") n2 = h2 // d elif input_layout == "BSND": b, s1, n1, d = query.shape _, s2, n2, _ = key.shape h1 = n1 * d h2 = n2 * d + if d == 0: + raise ValueError("parse_bsnd_args: head dimension (d) is zero, division by zero risk.") elif input_layout == "BNSD": b, n1, s1, d = query.shape _, n2, s2, _ = key.shape h1 = n1 * d h2 = n2 * d + if d == 0: + raise ValueError("parse_bsnd_args: head dimension (d) is zero, division by zero risk.") except Exception as e: raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e @@ -446,6 +469,8 @@ def npu_fusion_attention_forward_patch(*args, **kwargs): input_layout = get_input_layout(*args, **kwargs) b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], head_num, input_layout) + if d == 0: + raise ValueError("npu_fusion_attention_forward_patch: head dimension (d) is zero, division by zero risk.") if n1 == n2 and s1 == s2: logger.debug(f"running case : BNSD = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}") else: @@ -478,6 +503,8 @@ def npu_fusion_attention_backward_patch(*args, **kwargs): raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.") b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], args[4], args[5]) + if d == 0: + raise ValueError("npu_fusion_attention_backward_patch: head dimension (d) is zero, division by zero risk.") if n1 == n2 and s1 == s2: logger.info(f"running case : bnsd = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}") else: diff --git a/debug/accuracy_tools/msprobe/pytorch/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/common/utils.py index 16067f6d2bee70645bcc337d1809a14f41ae5b96..4b52d7475f5f9092594df9fd8c58b591b2adfd5c 100644 --- a/debug/accuracy_tools/msprobe/pytorch/common/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/common/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -24,11 +24,12 @@ from functools import wraps import numpy as np import torch import torch.distributed as dist + from msprobe.core.common.exceptions import DistributedNotInitializedError from msprobe.core.common.file_utils import (FileCheckConst, change_mode, check_file_or_directory_path, check_path_before_create, FileOpen) from msprobe.core.common.log import logger -from msprobe.core.common.utils import check_seed_all +from msprobe.core.common.utils import check_seed_all, is_save_variable_valid from packaging import version try: @@ -38,7 +39,9 @@ except ImportError: else: is_gpu = False + torch_without_guard_version = torch.__version__ >= '2.1' +torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' if not is_gpu and not torch_without_guard_version: from torch_npu.utils.device_guard import torch_device_guard as torch_npu_device_guard @@ -57,7 +60,7 @@ def parameter_adapter(func): @wraps(func) def inner(self, *args, **kwargs): - if self.op_name_ == "__getitem__" and len(args) > 1 and isinstance(args[1], torch.Tensor): + if self.api_name == "__getitem__" and len(args) > 1 and isinstance(args[1], torch.Tensor): input_tensor = args[0] indices = args[1] if indices.dtype == torch.uint8: @@ -77,7 +80,7 @@ def parameter_adapter(func): else: res = [input_tensor[tensor_index] for tensor_index in indices] return getattr(torch._C._VariableFunctionsClass, "stack")(res, 0) - if self.op_name_ == "__eq__" and len(args) > 1 and args[1] is None: + if self.api_name == "__eq__" and len(args) > 1 and args[1] is None: return False return func(self, *args, **kwargs) @@ -147,7 +150,7 @@ def remove_dropout(): F.dropout3d = function_dropout3d -def seed_all(seed=1234, mode=False, rm_dropout=True): +def seed_all(seed=1234, mode=False, rm_dropout=False): check_seed_all(seed, mode, rm_dropout) try: random.seed(seed) @@ -309,14 +312,14 @@ def print_rank_0(message): logger.info(message) -def load_pt(pt_path, to_cpu=False): +def load_pt(pt_path, to_cpu=False, weights_only=True): pt_path = os.path.realpath(pt_path) check_file_or_directory_path(pt_path) try: if to_cpu: - pt = torch.load(pt_path, map_location=torch.device("cpu"), weights_only=True) + pt = torch.load(pt_path, map_location=torch.device("cpu"), weights_only=weights_only) else: - pt = torch.load(pt_path, weights_only=True) + pt = torch.load(pt_path, weights_only=weights_only) except Exception as e: raise RuntimeError(f"load pt file {pt_path} failed") from e return pt @@ -385,26 +388,6 @@ def load_pkl(pt_path): return pt -def save_api_data(api_data): - """Save data to io stream""" - try: - io_buff = io.BytesIO() - torch.save(api_data, io_buff) - except Exception as e: - raise RuntimeError(f"save api_data to io_buff failed") from e - return io_buff - - -def load_api_data(api_data_bytes): - """Load data from bytes stream""" - try: - buffer = io.BytesIO(api_data_bytes) - buffer = torch.load(buffer, map_location="cpu") - except Exception as e: - raise RuntimeError(f"load api_data from bytes failed") from e - return buffer - - def is_recomputation(): """Check if the current operation is in the re-computation phase. @@ -419,7 +402,11 @@ def is_recomputation(): bool: True if in the re-computation phase, False otherwise. """ backward_function_indices = [] - call_stack = inspect.stack() + try: + call_stack = inspect.stack() + except Exception as e: + logger.warning(f"Failed to capture stack trace, recomputation validation may be incorrect, error info: {e}.") + return False # Identify the function 'backward' is being executed within the 'torch/_tensor.py' file. for frame_info in call_stack: @@ -449,9 +436,11 @@ def is_recomputation(): def check_save_param(variable, name, save_backward): # try catch this api to skip invalid call - if not isinstance(variable, (list, dict, torch.Tensor, int, float, str)): + valid_data_types = (torch.Tensor, int, float, str) + if not is_save_variable_valid(variable, valid_data_types): + valid_data_types_with_nested_types = valid_data_types + (dict, tuple, list) logger.warning("PrecisionDebugger.save variable type not valid, " - "should be one of list, dict, torch.Tensor, int, float or string. " + f"should be one of {valid_data_types_with_nested_types}" "Skip current save process.") raise ValueError if not isinstance(name, str): @@ -466,10 +455,19 @@ def check_save_param(variable, name, save_backward): raise ValueError -def replace_last_occurrence(text, old, new): - if text is None: - return text - index = text.rfind(old) - if index != -1: - return text[:index] + text[index:].replace(old, new, 1) - return text +def is_torch_nn_module(variable): + return isinstance(variable, torch.nn.Module) and not isinstance(variable, torch.jit.ScriptModule) + + +def register_forward_pre_hook(module, forward_pre_hook): + if torch_version_above_or_equal_2: + module.register_forward_pre_hook(forward_pre_hook, with_kwargs=True) + else: + module.register_forward_pre_hook(forward_pre_hook) + + +def register_forward_hook(module, forward_hook): + if torch_version_above_or_equal_2: + module.register_forward_hook(forward_hook, with_kwargs=True) + else: + module.register_forward_hook(forward_hook) diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py index de62af421b5a37e39140a9836fb16853443740d7..6f8ad5cf60924581f9c112e1cb236f51f255a1dd 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,41 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os - -from msprobe.core.common.exceptions import FileCheckException -from msprobe.core.common.file_utils import create_directory -from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, get_dump_mode, \ - set_dump_path -from msprobe.core.compare.acc_compare import ModeConfig -from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json, set_stack_json_path -from msprobe.pytorch.common.log import logger -from msprobe.pytorch.compare.pt_compare import PTComparator, compare +from msprobe.core.compare.utils import compare_distributed_inner +from msprobe.pytorch.compare.pt_compare import compare def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs): - if kwargs.get("suffix"): - logger.error("Argument 'suffix' is not supported for compare_distributed.") - raise CompareException(CompareException.INVALID_PARAM_ERROR) - is_print_compare_log = kwargs.get("is_print_compare_log", True) - # get the ranks and match by order - npu_ranks = sorted(check_and_return_dir_contents(npu_dump_dir, 'rank')) - bench_ranks = sorted(check_and_return_dir_contents(bench_dump_dir, 'rank')) - if len(npu_ranks) != len(bench_ranks): - logger.error( - "The number of ranks in the two runs are different. " - "Unable to match the ranks. " - "Please use another folder to compare or use compare() api and manually match the ranks.") - raise CompareException(CompareException.INVALID_PATH_ERROR) - for nr, br in zip(npu_ranks, bench_ranks): - npu_data_dir = os.path.join(npu_dump_dir, nr) - bench_data_dir = os.path.join(bench_dump_dir, br) - npu_path = extract_json(npu_data_dir, stack_json=False) - bench_path = extract_json(bench_data_dir, stack_json=False) - - dump_result_param = { - "npu_json_path": npu_path, - "bench_json_path": bench_path, - "is_print_compare_log": is_print_compare_log - } - compare(input_param=dump_result_param, output_path=output_path, suffix=f'_{nr}-{br}', **kwargs) + compare_distributed_inner(npu_dump_dir, bench_dump_dir, output_path, compare, **kwargs) diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py index 308a82b3d6e9beb67a669ea05b83d7b8a6eddc90..b3be3e793df30ada65a98374eae4f358da433fd3 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,92 +13,34 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os.path +from msprobe.core.common.utils import CompareException +from msprobe.core.common.log import logger +from msprobe.core.compare.acc_compare import Comparator, ModeConfig, MappingConfig, setup_comparison +from msprobe.pytorch.compare.utils import read_pt_data -import torch -from msprobe.core.common.const import FileCheckConst -from msprobe.core.common.exceptions import FileCheckException -from msprobe.core.common.file_utils import FileChecker, create_directory, load_yaml -from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, get_dump_mode, \ - set_dump_path -from msprobe.core.compare.acc_compare import Comparator, ModeConfig -from msprobe.core.compare.utils import set_stack_json_path -from msprobe.pytorch.common.log import logger -from msprobe.pytorch.common.utils import load_pt - - -class PTComparator(Comparator): - def __init__(self, mode_config, data_mapping=None): - super().__init__(mode_config) - - self.stack_mode = mode_config.stack_mode - self.auto_analyze = mode_config.auto_analyze - self.fuzzy_match = mode_config.fuzzy_match - self.dump_mode = mode_config.dump_mode - - self.frame_name = PTComparator.__name__ - self.data_mapping = data_mapping - if isinstance(self.data_mapping, str) or self.data_mapping is None: - self.data_mapping_dict = self.load_mapping_file(self.data_mapping) - elif isinstance(self.data_mapping, dict): - self.data_mapping_dict = self.data_mapping - else: - raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got " - f"{type(self.data_mapping)}") - - @staticmethod - def load_mapping_file(mapping_file): - if isinstance(mapping_file, str): - mapping_dict = load_yaml(mapping_file) - else: - mapping_dict = {} - return mapping_dict - - def read_npy_data(self, dir_path, file_name): - if not file_name: - return None - data_path = os.path.join(dir_path, file_name) - path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, - FileCheckConst.PT_SUFFIX, False) - data_path = path_checker.common_check() - try: - # detach because numpy can not process gradient information - data_value = load_pt(data_path, to_cpu=True).detach() - except RuntimeError as e: - # 这里捕获 load_pt 中抛出的异常 - logger.error(f"Failed to load the .pt file at {data_path}.") - raise CompareException(CompareException.INVALID_FILE_ERROR) from e - except AttributeError as e: - # 这里捕获 detach 方法抛出的异常 - logger.error(f"Failed to detach the loaded tensor.") - raise CompareException(CompareException.DETACH_ERROR) from e - if data_value.dtype == torch.bfloat16: - data_value = data_value.to(torch.float32) - data_value = data_value.numpy() - return data_value +def read_real_data(npu_dir, npu_data_name, bench_dir, bench_data_name, _) -> tuple: + n_value = read_pt_data(npu_dir, npu_data_name) + b_value = read_pt_data(bench_dir, bench_data_name) + return n_value, b_value def compare(input_param, output_path, **kwargs): - try: - auto_analyze = kwargs.get('auto_analyze', True) - fuzzy_match = kwargs.get('fuzzy_match', False) - data_mapping = kwargs.get('data_mapping', None) - suffix = kwargs.get('suffix', '') - - set_dump_path(input_param) - dump_mode = get_dump_mode(input_param) - if "stack_json_path" in input_param: - stack_mode = kwargs.get('stack_mode', False) - else: - stack_mode = set_stack_json_path(input_param) # set stack_mode and set "stack_json_path" in input_param - check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True)) - create_directory(output_path) - check_compare_param(input_param, output_path, dump_mode, stack_mode) - except (CompareException, FileCheckException) as error: - logger.error('Compare failed. Please check the arguments and do it again!') - raise CompareException(error.code) from error - - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - pt_comparator = PTComparator(mode_config, data_mapping) - pt_comparator.compare_core(input_param, output_path, suffix=suffix) + if not isinstance(input_param, dict): + logger.error("input_param should be dict, please check!") + raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR) + config = setup_comparison(input_param, output_path, **kwargs) + + config_dict = { + 'stack_mode': config.stack_mode, + 'auto_analyze': config.auto_analyze, + 'fuzzy_match': config.fuzzy_match, + 'highlight': config.highlight, + 'dump_mode': config.dump_mode, + 'first_diff_analyze': config.first_diff_analyze, + 'compared_file_type': config.compared_file_type + } + mode_config = ModeConfig(**config_dict) + mapping_config = MappingConfig(data_mapping=config.data_mapping) + pt_comparator = Comparator(read_real_data, mode_config, mapping_config) + pt_comparator.compare_core(input_param, output_path, suffix=config.suffix) diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/pt_diff_analyze.py b/debug/accuracy_tools/msprobe/pytorch/compare/pt_diff_analyze.py new file mode 100644 index 0000000000000000000000000000000000000000..b558a20b6f592ac9ebd758a0041155beee413caa --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/compare/pt_diff_analyze.py @@ -0,0 +1,21 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from msprobe.pytorch.compare.distributed_compare import compare_distributed + + +def pt_diff_analyze(npu_dump_dir, bench_dump_dir, output_path, first_diff_analyze): + compare_distributed(npu_dump_dir, bench_dump_dir, output_path, first_diff_analyze=first_diff_analyze) diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/utils.py b/debug/accuracy_tools/msprobe/pytorch/compare/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..16473ff386d89de5f3bbb269e69837c07a950ea5 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/compare/utils.py @@ -0,0 +1,47 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import torch + +from msprobe.core.common.utils import logger, CompareException +from msprobe.core.common.file_utils import FileChecker, FileCheckConst +from msprobe.pytorch.common.utils import load_pt + + +def read_pt_data(dir_path, file_name): + if not file_name: + return None + + data_path = os.path.join(dir_path, file_name) + path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, + FileCheckConst.PT_SUFFIX, False) + data_path = path_checker.common_check() + try: + # detach because numpy can not process gradient information + data_value = load_pt(data_path, to_cpu=True).detach() + except RuntimeError as e: + # 这里捕获 load_pt 中抛出的异常 + logger.error(f"Failed to load the .pt file at {data_path}.") + raise CompareException(CompareException.INVALID_FILE_ERROR) from e + except AttributeError as e: + # 这里捕获 detach 方法抛出的异常 + logger.error(f"Failed to detach the loaded tensor.") + raise CompareException(CompareException.DETACH_ERROR) from e + if data_value.dtype == torch.bfloat16: + data_value = data_value.to(torch.float32) + data_value = data_value.numpy() + return data_value diff --git a/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py b/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py index 77e78bc38063602e64b533291d60b9b12fd2ae00..5fc1dc784dbb216460f67559375e8049104f8cd6 100644 --- a/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py +++ b/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py @@ -13,11 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch - from msprobe.core.common.const import Const from msprobe.core.common.exceptions import MsprobeException from msprobe.pytorch.common.log import logger +from msprobe.pytorch.common.utils import is_torch_nn_module class DebuggerConfig: @@ -35,6 +34,7 @@ class DebuggerConfig: self.overflow_nums = task_config.overflow_nums if task_config.overflow_nums else 1 self.framework = Const.PT_FRAMEWORK self.async_dump = common_config.async_dump if common_config.async_dump else False + self.precision = common_config.precision if common_config.precision else Const.DUMP_PRECISION_LOW if self.task == Const.FREE_BENCHMARK: self.fuzz_device = task_config.fuzz_device @@ -48,24 +48,15 @@ class DebuggerConfig: "max_sample": task_config.max_sample } - self.online_run_ut = False - if self.task == Const.TENSOR: - # dump api tensor and collaborate with online run_ut - self.online_run_ut = task_config.online_run_ut if task_config.online_run_ut else False - self.nfs_path = task_config.nfs_path if task_config.nfs_path else "" - self.tls_path = task_config.tls_path if task_config.tls_path else "" - self.host = task_config.host if task_config.host else "" - self.port = task_config.port if task_config.port else -1 - self.online_run_ut_recompute = task_config.online_run_ut_recompute \ - if isinstance(task_config.online_run_ut_recompute, bool) else False self.check() + self._check_statistics_config(task_config) if self.level == Const.LEVEL_L2: self.is_backward_kernel_dump = False self._check_and_adjust_config_with_l2() - def check_kwargs(self): + def check(self): if self.task and self.task not in Const.TASK_LIST: raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"The task <{self.task}> is not in the {Const.TASK_LIST}.") @@ -78,48 +69,61 @@ class DebuggerConfig: if not isinstance(self.async_dump, bool): raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"The parameters async_dump should be bool.") - if self.async_dump and self.task == Const.TENSOR and not self.list: - raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, - f"The parameters async_dump is true in tensor task, the parameters list cannot be " - f"empty.") if self.task == Const.STRUCTURE and self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX]: logger.warning_on_rank_0( f"When the task is set to structure, the level should be one of {[Const.LEVEL_L0, Const.LEVEL_MIX]}. " f"If not, the default level is {Const.LEVEL_MIX}." ) self.level = Const.LEVEL_MIX - - def check(self): - self.check_kwargs() + if self.async_dump: + if self.task == Const.TENSOR: + if self.level == Const.LEVEL_DEBUG: + self.list = [] # async_dump + debug level case ignore list + if not self.list and self.level != Const.LEVEL_DEBUG: + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, + f"The parameters async_dump is true in tensor task, the parameters list cannot be empty." + ) + if self.summary_mode == Const.MD5: + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, + f"The parameters async_dump is true, the parameters summary_mode cannot be md5." + ) return True - def check_model(self, instance, start_model): - if self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX]: - if instance.model is not None or start_model is not None: - logger.info_on_rank_0( - f"The current level is not L0 or mix level, so the model parameters will not be used.") + def check_model(self, instance, start_model, token_range=None): + instance.model = start_model if start_model is not None else instance.model + + if token_range and not instance.model: + error_info = "The 'model' parameter must be provided when token_range is not None" + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, error_info) + + if self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX] and token_range is None: return - if start_model is None and instance.model is None: + + if instance.model is None: logger.error_on_rank_0( - f"For level {self.level}, PrecisionDebugger or start interface must receive a 'model' parameter.") + f"For level {self.level} or non-empty token_range, " + f"PrecisionDebugger or start interface must receive a 'model' parameter.") raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"missing the parameter 'model'") - instance.model = start_model if start_model is not None else instance.model - if isinstance(instance.model, torch.nn.Module): + if is_torch_nn_module(instance.model): return - error_model = None if isinstance(instance.model, (list, tuple)): + error_model = None for model in instance.model: - if not isinstance(model, torch.nn.Module): + if not is_torch_nn_module(model): error_model = model break + if error_model is not None: + error_info = (f"The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] " + f"type, currently there is an unsupported {type(error_model)} type.") + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, error_info) else: - error_model = instance.model - - if error_model is not None: error_info = (f"The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] " - f"type, currently there is a {type(error_model)} type.") + f"type, currently there is an unsupported {type(instance.model)} type.") raise MsprobeException( MsprobeException.INVALID_PARAM_ERROR, error_info) @@ -130,8 +134,23 @@ class DebuggerConfig: if not self.list or len(self.list) != 1: raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"When level is set to L2, the list must be configured as a list with one api name.") + if self.task != Const.TENSOR: + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, + f"When level is set to L2, the task must be set to tensor.") + api_name = self.list[0] if api_name.endswith(Const.BACKWARD): self.is_backward_kernel_dump = True api_forward_name = api_name[:-len(Const.BACKWARD)] + Const.FORWARD self.list.append(api_forward_name) + + def _check_statistics_config(self, task_config): + if self.task != Const.STATISTICS: + return + self.tensor_list = [] + if not hasattr(task_config, "tensor_list"): + return + if self.level == Const.LEVEL_DEBUG and task_config.tensor_list: + logger.warning_on_rank_0("When level is set to debug, the tensor_list will be invalid.") + return + self.tensor_list = task_config.tensor_list diff --git a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py index 5bb1d3a14e82d7b4bce9d7da8921a1d701e82222..0e784dafbd62688534547a7c0af79a067f97fac2 100644 --- a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py @@ -13,36 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple +from torch.utils.data import dataloader -import torch -from msprobe.core.common.const import Const, FileCheckConst, MsgConst +from msprobe.core.common.const import Const, MsgConst from msprobe.core.common.exceptions import MsprobeException -from msprobe.core.common.file_utils import FileChecker -from msprobe.core.common.utils import get_real_step_or_rank +from msprobe.core.common.utils import check_token_range, ThreadSafe +from msprobe.core.debugger.precision_debugger import BasePrecisionDebugger from msprobe.pytorch.common.log import logger -from msprobe.pytorch.common.utils import check_save_param +from msprobe.pytorch.common.utils import check_save_param, is_torch_nn_module from msprobe.pytorch.debugger.debugger_config import DebuggerConfig from msprobe.pytorch.dump.module_dump.module_dump import ModuleDumper from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor -from msprobe.pytorch.pt_config import parse_json_config -from msprobe.pytorch.service import Service -from torch.utils.data import dataloader - -ConfigParameters = namedtuple("ConfigParameters", ["config_path", "task", - "dump_path", "level", "model"]) +from msprobe.pytorch.pytorch_service import PytorchService +from msprobe.pytorch.pt_config import parse_task_config -class PrecisionDebugger: - _instance = None - tasks_not_need_debugger = [Const.GRAD_PROBE] - - def __new__(cls, *args, **kwargs): - if cls._instance is None: - cls._instance = super(PrecisionDebugger, cls).__new__(cls) - cls._instance.config = None - cls._instance.enable_dataloader = False - return cls._instance +class PrecisionDebugger(BasePrecisionDebugger): def __init__( self, @@ -53,90 +39,67 @@ class PrecisionDebugger: model=None, step=None ): - if not hasattr(self, "initialized"): - config_params = ConfigParameters(config_path, - task, - dump_path, - level, - model) - self.check_input_params(config_params) - - self.initialized = True - self.model = model - common_config, task_config = parse_json_config(config_path, task) - self.task = task if task else common_config.task - if self.task == Const.GRAD_PROBE: - self.gm = GradientMonitor(common_config, task_config) - return - if step is not None: - common_config.step = get_real_step_or_rank(step, Const.STEP) - self.config = DebuggerConfig( - common_config, task_config, task, dump_path, level - ) - self.service = Service(self.config) - self.module_dumper = ModuleDumper(self.service) - self.enable_dataloader = self.config.enable_dataloader - if self.enable_dataloader: - logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.") - dataloader._BaseDataLoaderIter.__next__ = iter_tracer(dataloader._BaseDataLoaderIter.__next__) - - @property - def instance(self): - return self._instance + if self.initialized: + return + super().__init__(config_path, task, dump_path, level, step) + self.model = model + if self.task == Const.GRAD_PROBE: + self.gm = GradientMonitor(self.common_config, self.task_config) + return + self.config = DebuggerConfig( + self.common_config, self.task_config, task, dump_path, level + ) + self.service = PytorchService(self.config) + self.module_dumper = ModuleDumper(self.service) + self.ori_customer_func = {} + self.enable_dataloader = self.config.enable_dataloader + self._param_warning() @staticmethod - def check_input_params(args): - if args.config_path is not None: - if not isinstance(args.config_path, str): - raise MsprobeException( - MsprobeException.INVALID_PARAM_ERROR, f"config_path must be a string") - file_checker = FileChecker( - file_path=args.config_path, path_type=FileCheckConst.FILE, file_type=FileCheckConst.JSON_SUFFIX) - file_checker.common_check() - - if args.task is not None and args.task not in Const.TASK_LIST: - raise MsprobeException( - MsprobeException.INVALID_PARAM_ERROR, f"task must be one of {Const.TASK_LIST}") + def _get_task_config(task, json_config): + return parse_task_config(task, json_config) - if args.dump_path is not None: - if not isinstance(args.dump_path, str): + @staticmethod + def _iter_tracer(func): + def func_wrapper(*args, **kwargs): + debugger_instance = PrecisionDebugger._instance + if not debugger_instance: raise MsprobeException( - MsprobeException.INVALID_PARAM_ERROR, f"dump_path must be a string") + MsprobeException.INTERFACE_USAGE_ERROR, + f"PrecisionDebugger must be instantiated before executing the dataloader iteration" + ) - if args.level is not None and args.level not in Const.LEVEL_LIST: - raise MsprobeException( - MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}") + debugger_instance.enable_dataloader = False + if not debugger_instance.service.first_start: + debugger_instance.stop() + debugger_instance.step() + result = func(*args, **kwargs) + debugger_instance.start() + debugger_instance.enable_dataloader = True + return result - if args.model is not None: - logger.warning_on_rank_0( - "The 'model' parameter in the PrecisionDebugger will be deprecated in the future." - "It is recommended to pass the 'model' parameter in the start interface instead." - ) + return func_wrapper @classmethod - def start(cls, model=None): - instance = cls._instance - if not instance: - raise Exception(MsgConst.NOT_CREATED_INSTANCE) - if instance.task in PrecisionDebugger.tasks_not_need_debugger: + @ThreadSafe.synchronized + def start(cls, model=None, token_range=None): + instance = cls._get_instance() + if instance is None: return - instance.config.check_model(instance, model) + + check_token_range(token_range) + instance.config.check_model(instance, model, token_range) + if instance.enable_dataloader: logger.warning_on_rank_0("DataLoader is enabled, start() skipped.") else: - instance.service.start(instance.model) - - @classmethod - def forward_backward_dump_end(cls): - instance = cls._instance - instance.stop() + instance.service.start(instance.model, token_range) @classmethod + @ThreadSafe.synchronized def stop(cls): - instance = cls._instance - if not instance: - raise Exception(MsgConst.NOT_CREATED_INSTANCE) - if instance.task in PrecisionDebugger.tasks_not_need_debugger: + instance = cls._get_instance() + if instance is None: return if instance.enable_dataloader: logger.warning_on_rank_0("DataLoader is enabled, stop() skipped.") @@ -144,14 +107,15 @@ class PrecisionDebugger: instance.service.stop() @classmethod + @ThreadSafe.synchronized def step(cls): - if not cls._instance: - raise Exception(MsgConst.NOT_CREATED_INSTANCE) - if cls._instance.task in PrecisionDebugger.tasks_not_need_debugger: + instance = cls._get_instance() + if instance is None: return cls._instance.service.step() @classmethod + @ThreadSafe.synchronized def monitor(cls, model): if not cls._instance: raise Exception(MsgConst.NOT_CREATED_INSTANCE) @@ -160,6 +124,7 @@ class PrecisionDebugger: cls._instance.gm.monitor(model) @classmethod + @ThreadSafe.synchronized def save(cls, variable, name, save_backward=True): instance = cls._instance if not instance: @@ -172,12 +137,24 @@ class PrecisionDebugger: return instance.service.save(variable, name, save_backward) + def _param_warning(self): + if self.model is not None: + logger.warning_on_rank_0( + "The 'model' parameter in the PrecisionDebugger will be deprecated in the future." + "It is recommended to pass the 'model' parameter in the start interface instead." + ) + if self.enable_dataloader: + logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.") + dataloader._BaseDataLoaderIter.__next__ = self._iter_tracer(dataloader._BaseDataLoaderIter.__next__) + +@ThreadSafe.synchronized def module_dump(module, dump_name): - if not isinstance(module, torch.nn.Module): + if not is_torch_nn_module(module): raise MsprobeException( MsprobeException.INVALID_PARAM_ERROR, - f"the module argument in module_dump must be a torch.nn.Module subclass" + f"the module argument in module_dump must be a torch.nn.Module type, " + f"but currently there is an unsupported {type(module)} type." ) if not isinstance(dump_name, str): raise MsprobeException( @@ -193,6 +170,7 @@ def module_dump(module, dump_name): instance.module_dumper.start_module_dump(module, dump_name) +@ThreadSafe.synchronized def module_dump_end(): instance = PrecisionDebugger._instance if not instance: @@ -201,17 +179,3 @@ def module_dump_end(): f"PrecisionDebugger must be instantiated before using module_dump_end interface" ) instance.module_dumper.stop_module_dump() - - -def iter_tracer(func): - def func_wrapper(*args, **kwargs): - debugger_instance = PrecisionDebugger.instance - debugger_instance.enable_dataloader = False - if not debugger_instance.service.first_start: - debugger_instance.stop() - debugger_instance.step() - result = func(*args, **kwargs) - debugger_instance.start() - debugger_instance.enable_dataloader = True - return result - return func_wrapper diff --git a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/hook_wrapper.py b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/hook_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..f41bf674602d5dbd8fc12d599fb388091ee56f04 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/hook_wrapper.py @@ -0,0 +1,91 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import wraps + +import torch +from torch.utils.hooks import BackwardHook + +from msprobe.core.common.const import Const +from msprobe.core.common.decorator import recursion_depth_decorator +from msprobe.pytorch.common.log import logger + + +def wrap_setup_backward_hook(func): + def requires_clone(tensor): + return isinstance(tensor, torch.Tensor) and tensor.requires_grad and torch.is_grad_enabled() + + @recursion_depth_decorator("Dump: wrap_setup_backward_hook.parse_tensor", max_depth=Const.DUMP_MAX_DEPTH) + def parse_tensor(item, tensor_list): + if requires_clone(item): + tensor_list.append(item) + elif isinstance(item, (list, tuple)): + for value in item: + parse_tensor(value, tensor_list) + elif isinstance(item, dict): + for value in item.values(): + parse_tensor(value, tensor_list) + + @recursion_depth_decorator("Dump: wrap_setup_backward_hook.rebuild_args", max_depth=Const.DUMP_MAX_DEPTH) + def rebuild_args(item, tensor_iter): + if requires_clone(item): + result = next(tensor_iter) + if hasattr(result, "_base") and result._base is not None: + if torch._C._autograd._get_creation_meta(result) != torch._C._autograd.CreationMeta(0): + torch._C._autograd._set_creation_meta(result, torch._C._autograd.CreationMeta(0)) + return result + if isinstance(item, list): + for index, value in enumerate(item): + item[index] = rebuild_args(value, tensor_iter) + return item + if isinstance(item, dict): + for key, value in item.items(): + item[key] = rebuild_args(value, tensor_iter) + return item + if isinstance(item, tuple): + if hasattr(item, '_fields'): + return type(item)(*[rebuild_args(i, tensor_iter) for i in item]) + return type(item)([rebuild_args(i, tensor_iter) for i in item]) + return item + + @wraps(func) + def wrap_setup_hook_func(*args, **kwargs): + if len(args) < 2: + return func(*args, **kwargs) + + actual_args = args[1] + + tensor_list = [] + + parse_tensor(actual_args, tensor_list) + + new_args = args[0], tuple(tensor_list) + hooked_tensors = func(*new_args, **kwargs) + + tensor_iter = iter(hooked_tensors) + try: + new_data = rebuild_args(actual_args, tensor_iter) + except Exception as e: + logger.debug(f"Unsupported data in setup input/output hook. The detail info: {e}") + new_data = actual_args + + return new_data + + return wrap_setup_hook_func + + +def wrap_setup_input_output_hook(): + BackwardHook.setup_input_hook = wrap_setup_backward_hook(BackwardHook.setup_input_hook) + BackwardHook.setup_output_hook = wrap_setup_backward_hook(BackwardHook.setup_output_hook) diff --git a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_dump.py b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_dump.py index 4700de6f1f9f3b5ddfb9507decb6f8739b5eda9b..5bf26f7ac0d91cce630a3b9c8e648453ae4ab65c 100644 --- a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_dump.py +++ b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_dump.py @@ -13,74 +13,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch -from msprobe.core.common.const import Const -from msprobe.core.data_dump.scope import BaseScope from msprobe.pytorch.common.log import logger -from msprobe.pytorch.hook_module.api_registry import api_register - -torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' +from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser +from msprobe.pytorch.hook_module.api_register import get_api_register class ModuleDumper: def __init__(self, service): self.service = service - self.hook_handle_list = [] + self.api_register = get_api_register() def start_module_dump(self, module, dump_name): - api_register.api_originality() - self.register_hook(module, dump_name) - - def stop_module_dump(self): - api_register.api_modularity() - for hook_handle in self.hook_handle_list: - if isinstance(hook_handle, torch.utils.hooks.RemovableHandle): - hook_handle.remove() - self.hook_handle_list.clear() + if hasattr(module, 'msprobe_hook') and not hasattr(module, 'msprobe_module_dump'): + logger.info_on_rank_0("The init dump is enabled, and the module dump function will not be available.") + return - def register_hook(self, module, dump_name): - prefix_name = ( - BaseScope.Module_Type_Module + Const.SEP + - dump_name + Const.SEP + - module.__class__.__name__ + Const.SEP - ) - module_processor = self.service.module_processor - _, forward_hook, backward_hook, forward_hook_torch_version_below_2 = self.service.build_hook( - BaseScope.Module_Type_Module, - prefix_name - ) + ModuleProcesser.enable_module_dump = True + self.api_register.restore_all_api() + if not hasattr(module, 'msprobe_module_dump'): + self.service.module_processor.register_module_hook(module, self.service.build_hook, + recursive=False, module_names=[dump_name]) + setattr(module, 'msprobe_module_dump', True) - if module_processor.has_register_backward_hook(module): - logger.warning( - f"The {dump_name} module has registered deprecated register_backward_hook," - f"which may cause abnormal data dump. The backward data dump for this module will be skipped." - ) - if torch_version_above_or_equal_2: - forward_hook_handle = module.register_forward_hook(forward_hook, with_kwargs=True) - else: - if not module_processor.has_register_backward_hook(module): - backward_hook_handle = module.register_full_backward_hook( - module_processor.node_hook(prefix_name + Const.BACKWARD, Const.STOP) - ) - self.hook_handle_list.append(backward_hook_handle) - forward_hook_handle = module.register_forward_hook(forward_hook_torch_version_below_2) - self.hook_handle_list.append(forward_hook_handle) - if not module_processor.has_register_backward_hook(module): - backward_hook_handle = module.register_full_backward_hook(backward_hook) - self.hook_handle_list.append(backward_hook_handle) - - forward_pre_hook_handle = module.register_forward_pre_hook( - module_processor.node_hook(prefix_name + Const.FORWARD, Const.START) - ) - forward_hook_handle = module.register_forward_hook( - module_processor.node_hook(prefix_name + Const.FORWARD, Const.STOP) - ) - self.hook_handle_list.extend([forward_pre_hook_handle, forward_hook_handle]) - if torch_version_above_or_equal_2 and not module_processor.has_register_backward_hook(module): - backward_pre_hook_handle = module.register_full_backward_pre_hook( - module_processor.node_hook(prefix_name + Const.BACKWARD, Const.START) - ) - backward_hook_handle = module.register_full_backward_hook( - module_processor.node_hook(prefix_name + Const.BACKWARD, Const.STOP) - ) - self.hook_handle_list.extend([backward_pre_hook_handle, backward_hook_handle]) + def stop_module_dump(self): + ModuleProcesser.enable_module_dump = False + self.api_register.register_all_api() diff --git a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py index b5ca1da461fd4235a09172de4b9dcea34a624e58..947fc5c0974702b68d9006bf4708ce82f2ead50f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py +++ b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py @@ -13,65 +13,90 @@ # See the License for the specific language governing permissions and # limitations under the License. -from functools import wraps +import threading +import sys +from collections import OrderedDict import torch +from torch.utils.hooks import BackwardHook, RemovableHandle + from msprobe.core.common.const import Const +from msprobe.core.common.runtime import Runtime +from msprobe.core.common.utils import ModuleQueue, ThreadSafe from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope from msprobe.pytorch.common.log import logger -from msprobe.pytorch.common.utils import replace_last_occurrence -from torch.utils.checkpoint import checkpoint as origin_checkpoint -from torch.utils.checkpoint import set_checkpoint_early_stop -from torch.utils.hooks import BackwardHook +from msprobe.pytorch.common.utils import is_torch_nn_module, register_forward_pre_hook +from msprobe.pytorch.dump.module_dump.hook_wrapper import wrap_setup_input_output_hook torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' +if torch_version_above_or_equal_2: + from torch.utils.checkpoint import _StopRecomputationError + +def wrap_megatron_deallocate(func): + def wrapper_func(out, deallocate_pipeline_outputs=False): + if deallocate_pipeline_outputs and isinstance(out, torch.Tensor) and getattr(out, "_base") is not None: + out_clone = out.clone() + out.data = torch.empty((1,), device=out.device, dtype=out.dtype, ) + return func(out_clone, deallocate_pipeline_outputs) + return func(out, deallocate_pipeline_outputs) -def checkpoint_without_early_stop(*args, **kwargs): - with set_checkpoint_early_stop(False): - return origin_checkpoint(*args, **kwargs) + return wrapper_func -def replace_checkpoint(): - torch.utils.checkpoint.checkpoint = checkpoint_without_early_stop +def wrap_forward_with_hook_safety(module): + """ + 包装模块的forward方法,确保异常时也执行forward_hook。 + """ + original_forward = module.forward + + def wrapped_forward(*args, **kwargs): + try: + output = original_forward(*args, **kwargs) + return output + except _StopRecomputationError as e: + exception_output = None + if len(module._forward_hooks.values()) > 0: + # msprobe的forward_hook会出现在第一个,仅执行msprobe的forward_hook + hook_fn = list(module._forward_hooks.values())[0] + hook_fn(module, args, kwargs, exception_output) + raise e + if torch_version_above_or_equal_2: + module.forward = wrapped_forward class ModuleProcesser: + module_queue = ModuleQueue() module_count = {} - module_stack = [] - api_parent_node = "" + module_stack = {} + api_parent_node = {} module_node = {} + module_bw_hook_kernels = {} + module_with_backward_hook = {} + enable_module_dump = False def __init__(self, scope): self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None - BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook) - BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook) - replace_checkpoint() - - @staticmethod - def clone_return_value(func): - @wraps(func) - def clone_return_value_func(*args, **kwargs): - result = func(*args, **kwargs) - return ModuleProcesser.clone_if_tensor(result) - - return clone_return_value_func - - @staticmethod - def clone_if_tensor(result): - if isinstance(result, torch.Tensor): - return result.clone() - elif type(result) is tuple: - return tuple(ModuleProcesser.clone_if_tensor(x) for x in result) - elif type(result) is list: - return list(ModuleProcesser.clone_if_tensor(x) for x in result) - elif type(result) is dict: - return {k: ModuleProcesser.clone_if_tensor(v) for k, v in result.items()} - else: - return result + wrap_setup_input_output_hook() + try: + from megatron.core.pipeline_parallel import schedules + origin_func_id = id(schedules.deallocate_output_tensor) + schedules.deallocate_output_tensor = wrap_megatron_deallocate(schedules.deallocate_output_tensor) + for module in list(sys.modules.values()): + if module.__name__ == 'schedules': + continue + for func in module.__dict__: + if id(module.__dict__[func]) == origin_func_id: + module.__setattr__(func, schedules.deallocate_output_tensor) + logger.debug(f'patch {module.__name__}.{func}.') + logger.info_on_rank_0("Patch megatron method success.") + except ImportError: + logger.info_on_rank_0("No megatron find.") + except Exception as e: + logger.info_on_rank_0(f"Patch megatron method failed, detail:{str(e)}") @staticmethod - def module_count_func(module_name): + def set_and_get_calls_number(module_name): if module_name not in ModuleProcesser.module_count: ModuleProcesser.module_count[module_name] = 0 else: @@ -85,120 +110,177 @@ class ModuleProcesser: module._is_full_backward_hook is False @staticmethod - def get_modules_and_names(models): + def get_modules_and_names(models, recursive, module_names): modules_and_names_with_index = {} if isinstance(models, (list, tuple)): + if not recursive and len(module_names) != len(models): + return modules_and_names_with_index for index, model in enumerate(models): - modules_and_names_with_index[str(index)] = model.named_modules() + modules_and_names_with_index[str(index)] = model.named_modules() if recursive else \ + [(module_names[index], model)] else: - modules_and_names_with_index["-1"] = models.named_modules() + if not recursive and len(module_names) != 1: + return modules_and_names_with_index + modules_and_names_with_index["-1"] = models.named_modules() if recursive else \ + [(module_names[0], models)] return modules_and_names_with_index @classmethod def reset_module_stats(cls): + cls.module_queue = ModuleQueue() cls.module_count = {} - cls.module_stack = [] - cls.api_parent_node = "" + cls.module_stack = {} + cls.api_parent_node = {} cls.module_node = {} + cls.module_bw_hook_kernels = {} + cls.enable_module_dump = False + + def register_module_hook(self, models, build_hook, recursive=True, module_names=None): + if module_names is None: + module_names = [] - def register_module_hook(self, models, build_hook): - logger.info_on_rank_0("The init dump is enabled, and the module dump function will not be available.") - modules_and_names_with_index = self.get_modules_and_names(models) + modules_and_names_with_index = self.get_modules_and_names(models, recursive, module_names) for index, modules_and_names in modules_and_names_with_index.items(): model = models if index == "-1" else models[int(index)] for name, module in modules_and_names: - if module == model: + if recursive and module == model: continue + if not is_torch_nn_module(module): + logger.warning( + f"The module dump does not support {type(module)} type. " + f"The data dump for this module will be skipped." + ) + continue + if module.__class__.__name__ == "FullyShardedDataParallel": + continue + setattr(module, 'msprobe_hook', True) module_index = (index + Const.SEP) if index != "-1" else "" - prefix_name = (BaseScope.Module_Type_Module + Const.SEP + module_index + - name + Const.SEP + module.__class__.__name__ + Const.SEP) - pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 = build_hook( - BaseScope.Module_Type_Module, - prefix_name - ) + prefix_name = f'{BaseScope.Module_Type_Module}{Const.SEP}{module_index}{name}{Const.SEP}' + \ + f'{module.__class__.__name__}{Const.SEP}' + + forward_pre_hook = self.build_module_hook(prefix_name, build_hook) if self.has_register_backward_hook(module): logger.warning( f"The {prefix_name[:-1]} has registered deprecated register_backward_hook," f"which may cause abnormal data dump. The backward data dump for this module will be skipped." ) + ModuleProcesser.module_with_backward_hook[prefix_name] = True + wrap_forward_with_hook_safety(module) + register_forward_pre_hook(module, forward_pre_hook) + + def build_module_hook(self, module_name, build_data_hook): + @ThreadSafe.synchronized + def forward_pre_hook(module, args, kwargs=None): + if kwargs is None: + kwargs = {} + + if not Runtime.is_running: + return (args, kwargs) if torch_version_above_or_equal_2 else args + + if hasattr(module, 'msprobe_module_dump') and not self.enable_module_dump: + return (args, kwargs) if torch_version_above_or_equal_2 else args + + index = ModuleProcesser.set_and_get_calls_number(module_name) + full_forward_name = f'{module_name}{Const.FORWARD}{Const.SEP}{index}' + full_backward_name = f'{module_name}{Const.BACKWARD}{Const.SEP}{index}' + + self.set_construct_info_in_pre_hook(full_forward_name) + + if not hasattr(module, 'msprobe_forward_hook'): + forward_hooks_dict = getattr(module, '_forward_hooks', OrderedDict()) + handle = RemovableHandle(forward_hooks_dict) + forward_hooks_dict[handle.id] = forward_hook + forward_hooks_dict.move_to_end(handle.id, last=False) + if torch_version_above_or_equal_2: + forward_hooks_with_kwargs_dict = getattr(module, '_forward_hooks_with_kwargs', OrderedDict()) + forward_hooks_with_kwargs_dict[handle.id] = True + + setattr(module, 'msprobe_forward_hook', True) + + hook_set = build_data_hook(BaseScope.Module_Type_Module, full_forward_name) + + def get_backward_pre_hook(full_backward_name): + @ThreadSafe.synchronized + def backward_pre_hook_fn(module, grad_output): + self.set_construct_info_in_pre_hook(full_backward_name) + + return backward_pre_hook_fn + + def get_backward_hook(backward_data_hook, full_backward_name): + @ThreadSafe.synchronized + def backward_hook_fn(module, grad_input, grad_output): + new_output = backward_data_hook(module, grad_input, grad_output) + self.set_construct_info_in_hook(full_backward_name, is_forward=False) + return new_output + + return backward_hook_fn + + if not ModuleProcesser.module_with_backward_hook.get(module_name): + backward_pre_hook = get_backward_pre_hook(full_backward_name) + backward_hook = get_backward_hook(hook_set.backward_hook, full_backward_name) if torch_version_above_or_equal_2: - module.register_forward_hook(forward_hook, with_kwargs=True) + bw_hook = BackwardHook(module, [backward_hook], [backward_pre_hook]) else: - if not self.has_register_backward_hook(module): - module.register_full_backward_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.STOP)) - module.register_forward_hook(forward_hook_torch_version_below_2) - if not self.has_register_backward_hook(module): - module.register_full_backward_hook(backward_hook) - - module.register_forward_pre_hook(self.node_hook(prefix_name + Const.FORWARD, Const.START)) - module.register_forward_hook(self.node_hook(prefix_name + Const.FORWARD, Const.STOP)) - if torch_version_above_or_equal_2 and not self.has_register_backward_hook(module): - module.register_full_backward_pre_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.START)) - module.register_full_backward_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.STOP)) - - def node_hook(self, name_prefix, start_or_stop, **kwargs): - - def pre_hook(module, input, output=None): - try: - index = ModuleProcesser.module_count_func(name_prefix) - except IndexError as e: - index = None - pass - full_name = name_prefix + Const.SEP + str(index) - if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name: - module.mindstudio_reserved_name = [] - module.mindstudio_reserved_name.append(full_name) - if self.module_stack: - ModuleProcesser.module_node[full_name] = self.module_stack[-1] - else: - ModuleProcesser.module_node[full_name] = None + bw_hook = BackwardHook(module, [backward_hook]) + ModuleProcesser.module_bw_hook_kernels[full_forward_name] = bw_hook + args = bw_hook.setup_input_hook(args) + return (args, kwargs) if torch_version_above_or_equal_2 else args - ModuleProcesser.module_stack.append(full_name) - if self.module_stack: - ModuleProcesser.api_parent_node = self.module_stack[-1] - if self.scope: - self.scope.begin_module(full_name) + @ThreadSafe.synchronized + def forward_hook(module, args, kwargs_or_output, output_or_kwargs=None): + if hasattr(module, 'msprobe_module_dump') and not self.enable_module_dump: + return output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output + + index = ModuleProcesser.module_count.get(module_name) + full_name = f'{module_name}{Const.FORWARD}{Const.SEP}{index}' - def end_hook(module, input, output=None): - if self.module_stack: - ModuleProcesser.module_stack.pop() - if self.module_stack: - ModuleProcesser.api_parent_node = self.module_stack[-1] + hook_set = build_data_hook(BaseScope.Module_Type_Module, full_name) + hook_result = hook_set.forward_hook(module, args, kwargs_or_output, output_or_kwargs) + self.set_construct_info_in_hook(full_name) + + if hook_result is not None: + result = hook_result else: - ModuleProcesser.api_parent_node = None - if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name: - raise RuntimeError(f"module reserve name is None when pop") - current_name = module.mindstudio_reserved_name.pop() + result = output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output + + bw_hook = ModuleProcesser.module_bw_hook_kernels.get(full_name) + if bw_hook: + result = bw_hook.setup_output_hook(result) + + return result + + return forward_pre_hook + + def set_construct_info_in_pre_hook(self, full_name): + tid = threading.get_ident() + if tid not in self.module_stack: + ModuleProcesser.module_stack[tid] = [] + + if self.module_stack[tid]: + ModuleProcesser.module_node[full_name] = self.module_stack[tid][-1] + else: + parent_name = ModuleProcesser.module_queue.find_last(full_name) + ModuleProcesser.module_node[full_name] = parent_name + + ModuleProcesser.module_queue.add_name(full_name) + ModuleProcesser.module_stack[tid].append(full_name) + ModuleProcesser.api_parent_node[tid] = full_name + if self.scope: + self.scope.begin_module(full_name) + + def set_construct_info_in_hook(self, full_name, is_forward=True): + tid = threading.get_ident() + if torch_version_above_or_equal_2 or is_forward: + ModuleProcesser.module_queue.remove_name(full_name) + ModuleProcesser.api_parent_node[tid] = None + if self.module_stack.get(tid): + ModuleProcesser.module_stack[tid].pop() + if self.module_stack.get(tid): + ModuleProcesser.api_parent_node[tid] = ModuleProcesser.module_stack[tid][-1] if self.scope: - self.scope.end_module(current_name) - - def backward_hook(module, input, output=None): - try: - index = ModuleProcesser.module_count_func(name_prefix) - except IndexError as e: - index = None - pass - full_name = name_prefix + Const.SEP + str(index) - if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name: - module.mindstudio_reserved_name = [] - module.mindstudio_reserved_name.append(full_name) - forward_full_name = replace_last_occurrence(full_name, Const.BACKWARD, Const.FORWARD) - ModuleProcesser.module_node[full_name] = replace_last_occurrence( - ModuleProcesser.module_node.get(forward_full_name), Const.FORWARD, Const.BACKWARD) - ModuleProcesser.api_parent_node = None + self.scope.end_module(full_name) + else: if self.scope: self.scope.begin_module(full_name) - - if torch_version_above_or_equal_2: - if Const.START in start_or_stop: - return pre_hook - else: - return end_hook - else: - if Const.FORWARD in name_prefix and Const.START in start_or_stop: - return pre_hook - elif Const.BACKWARD in name_prefix: - return backward_hook - else: - return end_hook + ModuleProcesser.api_parent_node[tid] = full_name diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py index e3fd2b69fef2772354401a22344376258e77a008..6baa684cbff27001ac489eddf11269ba2c71dfae 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py @@ -16,7 +16,7 @@ import torch from msprobe.core.common.exceptions import FreeBenchmarkException -from msprobe.core.common.utils import recursion_depth_decorator +from msprobe.core.common.decorator import recursion_depth_decorator from msprobe.pytorch.free_benchmark.common.enums import DeviceType diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/single_benchmark.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/single_benchmark.py index 49e845da4011565f1b6ccf0c0e1193fb3fcffcbf..a5f18946c44c09bf1670173d45cc99ace3b0e79d 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/single_benchmark.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/single_benchmark.py @@ -16,7 +16,7 @@ import math import torch -from msprobe.core.common.utils import recursion_depth_decorator +from msprobe.core.common.decorator import recursion_depth_decorator from msprobe.pytorch.free_benchmark import logger from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig from msprobe.pytorch.free_benchmark.common.utils import TorchC diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/main.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/main.py index 66d7b7e10429dbfb939cdfa005422ce4f8e48f99..46a207c124e28004fa8877c99d4de53abdfe8617 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/main.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/main.py @@ -17,8 +17,8 @@ from abc import ABC import torch from msprobe.core.common.const import Const +from msprobe.core.common.utils import replace_last_occurrence from msprobe.pytorch.free_benchmark import logger -from msprobe.pytorch.free_benchmark.common.constant import CommonField from msprobe.pytorch.free_benchmark.common.enums import ( DeviceType, FuzzLevel, @@ -37,6 +37,7 @@ from msprobe.pytorch.free_benchmark.result_handlers.handler_factory import ( class FreeBenchmarkCheck(ABC): + grad_saver_dict = {} def __init__(self, config) -> None: super().__init__() @@ -68,7 +69,9 @@ class FreeBenchmarkCheck(ABC): grad_saver.kwargs = kwargs grad_saver.register_compare_func_for_inputs(args, data_processor) grad_saver.cache_backward_input(args) - setattr(module, CommonField.GRADSAVER, grad_saver) + + backward_name = replace_last_occurrence(name, Const.FORWARD, Const.BACKWARD) + FreeBenchmarkCheck.grad_saver_dict[backward_name] = grad_saver def forward(self, name, module, args, kwargs, output): if not self.config.fuzz_stage == Const.FORWARD: @@ -92,16 +95,16 @@ class FreeBenchmarkCheck(ABC): return perturbed_output, handler.get_unequal_rows() def backward(self, name, module, grad_output): - if not self.config.fuzz_stage == Const.BACKWARD: return try: - grad_saver = getattr(module, CommonField.GRADSAVER) + grad_saver = FreeBenchmarkCheck.grad_saver_dict[name] except AttributeError: logger.warning_on_rank_0( f"[msprobe] Free benchmark: get grad saver failed. api_name:{name}" ) return + del FreeBenchmarkCheck.grad_saver_dict[name] _new_grad_output = grad_output try: diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py index 41ec39e3a3b6233720c047d5d2b736d91bba989e..754e3b06e9670a04fcf7c20d5af3d7e1733b7af1 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py @@ -14,7 +14,7 @@ # limitations under the License. import torch -from msprobe.core.common.utils import recursion_depth_decorator +from msprobe.core.common.decorator import recursion_depth_decorator from msprobe.pytorch.free_benchmark import logger from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode @@ -95,13 +95,13 @@ class AddNoiseLayer(NpuBaseLayer): except Exception: logger.warning_on_rank_0( f"[msprobe] Free Benchmark: For {self.api_name}, " - f"when calculate maximun value, tensor is changed to float32." + f"when calculating the maximum value, the tensor is changed to float32." ) max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item() if max_val < abs_tol: logger.warning_on_rank_0( f"[msprobe] Free Benchmark: For {self.api_name}, " - f"Maximun value is less than the minimun threshold. Cancel add noise." + f"maximum value is less than the minimum threshold. Cancel adding noise." ) return False return True diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py index df1a73127aa0b69e42254cce1d3334810319f7cf..aec0c3ca96e39958316f6835261618c148c7ad4e 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py @@ -14,7 +14,7 @@ # limitations under the License. import torch -from msprobe.core.common.utils import recursion_depth_decorator +from msprobe.core.common.decorator import recursion_depth_decorator from msprobe.pytorch.free_benchmark import logger from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode @@ -100,13 +100,13 @@ class BitNoiseLayer(NpuBaseLayer): except Exception: logger.warning_on_rank_0( f"[msprobe] Free Benchmark: For {self.api_name}, " - f"when calculate maximun value, tensor is changed to float32." + f"when calculate the maximum value, the tensor is changed to float32." ) max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item() if max_val < abs_tol: logger.warning_on_rank_0( f"[msprobe] Free Benchmark: For {self.api_name}, " - f"Maximun value is less than the minimun threshold. Cancel add noise." + f"maximum value is less than the minimum threshold. Cancel adding noise." ) return False return True diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py index c4fbeaf82f8fcafba235a7faa6dd9073d4d556d8..521637a1d8b3bca226a6eacfc5f6f5a0d4bc1921 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py @@ -14,7 +14,7 @@ # limitations under the License. import torch -from msprobe.core.common.utils import recursion_depth_decorator +from msprobe.core.common.decorator import recursion_depth_decorator from msprobe.pytorch.free_benchmark import logger from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode from msprobe.pytorch.free_benchmark.common.params import DataParams diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py index 095e77ffaff39a795cb1418c1695608d91d7427b..daa271976f3b05f81b9997bd1775ee2809b776c9 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py @@ -15,7 +15,7 @@ import torch from msprobe.core.common.const import Const -from msprobe.core.common.utils import recursion_depth_decorator +from msprobe.core.common.decorator import recursion_depth_decorator from msprobe.pytorch.free_benchmark import logger from msprobe.pytorch.free_benchmark.common.constant import CommonField from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/base_handler.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/base_handler.py index 47f93ab7b89f44bdd4f92ceafc6e9dbe503d0374..e0d583dd012364f3bb75eb4d030dca21cfea2bc6 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/base_handler.py @@ -186,6 +186,8 @@ class FuzzHandler(ABC): ratio = self.ratio_calculate( origin_output, perturbed_output, norm_type=NormType.ENDLESS_NORM ) + if threshold == 0: + raise ValueError("Threshold cannot be zero. Check `get_threshold` implementation.") if ratio == ThresholdConfig.SYMBOL_FLIPPING: is_consistent = False else: diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/check_handler.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/check_handler.py index 9feec1531b16ff8ba63910f3f7c40aa275d0104e..d088cd1d1647a59c167f705702d9ad6afcf6e21b 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/check_handler.py @@ -49,6 +49,6 @@ class CheckerHandler(FuzzHandler): except Exception as e: logger.warning_on_rank_0( f"[msprobe] Free Benchmark: For {self.params.api_name}, " - f"when campare the result exception raise {e}" + f"when comparing the results, an exception is raised: {e}" ) return data_params.original_result diff --git a/debug/accuracy_tools/msprobe/pytorch/function_factory.py b/debug/accuracy_tools/msprobe/pytorch/function_factory.py index 247e2cd0ed5ea11047cc0d75954dbc1e92b889f4..f515b5d4783c0e20a2303579f6954d42a7b9deac 100644 --- a/debug/accuracy_tools/msprobe/pytorch/function_factory.py +++ b/debug/accuracy_tools/msprobe/pytorch/function_factory.py @@ -70,7 +70,7 @@ class Register(dict): def add_register_item(key, value): if key in self._dict: - logger.warning(f"{value.__name__} has been registered before, so we will overriden it.") + logger.warning(f"{value.__name__} has been registered before, so we will override it.") self[key] = value return value diff --git a/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_monitor.py b/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_monitor.py index 926476b8fb353531e54a485ccb47c4c59860c5d0..81d7575fc251c0b90703b13c537f61f778cf5136 100644 --- a/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_monitor.py +++ b/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_monitor.py @@ -46,7 +46,7 @@ class GradientMonitor: if not os.path.exists(self._output_path): create_directory(self._output_path) else: - logger.warning(f"the file in {self._output_path} will be recoverd") + logger.warning(f"the file in {self._output_path} will be deleted") self._step = -1 self._param2name = defaultdict(str) @@ -97,7 +97,7 @@ class GradientMonitor: create_directory(output_dirpath) output_path = os.path.join(output_dirpath, f"grad_summary_{self._step}.csv") if os.path.exists(output_path): - logger.warning(f"{output_path} will be recoverd") + logger.warning(f"{output_path} will be deleted") remove_path(output_path) header_result = GradStatCsv.generate_csv_header(self._level_adp, self._bounds) output_lines.insert(0, header_result) diff --git a/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_stat_csv.py b/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_stat_csv.py index 6391f8f5e1d00f62240c002c35757bba623c3929..bf72a7fb0ff31fddfd7fe6582582b733428eee2b 100644 --- a/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_stat_csv.py +++ b/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_stat_csv.py @@ -17,6 +17,7 @@ from abc import ABC, abstractmethod from collections import namedtuple import hashlib from functools import wraps +import zlib import torch from msprobe.core.grad_probe.constant import GradConst @@ -74,8 +75,8 @@ class CsvMd5(CsvItem): def generate_csv_content(csv_content_input): grad = csv_content_input.grad tensor_bytes = grad.cpu().detach().float().numpy().tobytes() - md5_hash = hashlib.md5(tensor_bytes) - return [md5_hash.hexdigest()] + md5_hash = f"{zlib.crc32(tensor_bytes):08x}" + return [md5_hash] @register_csv_item(GradConst.DISTRIBUTION) diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/api_register.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/api_register.py new file mode 100644 index 0000000000000000000000000000000000000000..1fbe04b783248f14aeeef13b10d447c34dd7c7ec --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/api_register.py @@ -0,0 +1,193 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import inspect +import os + +import torch +import torch.distributed as dist + +from msprobe.core.common.const import Const +from msprobe.core.common.file_utils import load_yaml +from msprobe.core.data_dump.api_registry import ApiRegistry +from msprobe.pytorch.common.log import logger +from msprobe.pytorch.common.utils import ( + torch_without_guard_version, + is_gpu, + torch_device_guard, + parameter_adapter +) +from msprobe.pytorch.function_factory import npu_custom_functions +from msprobe.pytorch.hook_module.hook_module import HOOKModule +from msprobe.pytorch.hook_module.utils import dynamic_import_op + +try: + import mindspeed.ops +except ImportError: + mindspeed_enable = False +else: + mindspeed_enable = True + +torch_version_above_2 = torch.__version__.split('+')[0] > '2.0' + +_inner_used_api = {} +_supported_api_list_path = (os.path.join(os.path.dirname(os.path.realpath(__file__)), Const.SUPPORT_API_FILE_NAME),) +_cuda_func_mapping = {"npu_fusion_attention": "gpu_fusion_attention"} +dist_data_collect_func = {} +dist_batch_data_collect_func = [] + +_api_types = { + Const.PT_FRAMEWORK: { + Const.PT_API_TYPE_FUNCTIONAL: ((torch.nn.functional,), (torch.nn.functional,)), + Const.PT_API_TYPE_TENSOR: ((torch.Tensor,), (torch.Tensor,)), + Const.PT_API_TYPE_TORCH: ((torch,), (torch,)), + Const.PT_API_TYPE_VF: ((torch._C._VariableFunctionsClass,), (torch._VF,)), + Const.PT_API_TYPE_DIST: ((dist,), (dist, dist.distributed_c10d)) + } +} +if not is_gpu: + import torch_npu + + if torch_without_guard_version: + _api_types.get(Const.PT_FRAMEWORK).update( + { + Const.PT_API_TYPE_NPU: ((torch.ops.npu, torch_npu), (torch_npu, torch.ops.npu)), + } + ) + else: + _api_types.get(Const.PT_FRAMEWORK).update( + {Const.PT_API_TYPE_NPU: ((torch_npu._C._VariableFunctionsClass,), (torch_npu,))} + ) + _api_types.get(Const.PT_FRAMEWORK).update( + { + Const.PT_API_TYPE_NPU_DIST: ( + (torch_npu.distributed,), + (torch_npu.distributed, torch_npu.distributed.distributed_c10d) + ) + } + ) + if mindspeed_enable: + _api_types.get(Const.PT_FRAMEWORK).update({Const.PT_API_TYPE_MINDSPEED: ((mindspeed.ops,), (mindspeed.ops,))}) + mindspeed_op_list = load_yaml(_supported_api_list_path[0]).get(Const.PT_API_TYPE_MINDSPEED) + mindspeed_op_file_list = [op.split(Const.SEP)[0] + Const.PY_SUFFIX for op in mindspeed_op_list] + dynamic_import_op(mindspeed.ops, mindspeed_op_file_list) + + +@parameter_adapter +def tensor_module_forward(module, *args, **kwargs): + return module.api_func(*args, **kwargs) + + +def dist_module_forward(module, *args, **kwargs): + handle = module.api_func(*args, **kwargs) + try: + bound = inspect.signature(module.api_func).bind(*args, **kwargs) + bound.apply_defaults() + use_async_op_flag = bound.arguments.get("async_op", False) + except Exception as e: + use_async_op_flag = False + logger.warning(f"fail to get dist api's func signature because {e}, no wait") + + def create_async_callback_func(catch_func): + full_name = module.full_forward_name if hasattr(module, "full_forward_name") else None + + def store_data(): + catch_func(module, full_name, args, kwargs, handle) + + return store_data + + if use_async_op_flag or module.api_name in ['isend', 'irecv']: + dist_data_collect_func[handle] = create_async_callback_func(module.distributed_forward_hook) + if module.api_name == 'batch_isend_irecv': + dist_batch_data_collect_func.append([handle, create_async_callback_func(module.distributed_forward_hook)]) + return handle + + +def redirect_wait(): + if hasattr(dist, "Work"): + from torch.distributed import Work + else: + from torch._C._distributed_c10d import Work + origin_wait = Work.wait + + def wrapped_wait(work): + def wrapped_wait(*args, **kwargs): + origin_wait(*args, **kwargs) + if args[0] in dist_data_collect_func: + store_func = dist_data_collect_func.pop(args[0]) + store_func() + return + for value in dist_batch_data_collect_func: + if args[0] in value[0]: + value[0].remove(args[0]) + if len(value[0]) == 0: + store_func = value[1] + store_func() + return + + return wrapped_wait + + Work.wait = wrapped_wait(Work) + + +def npu_module_forward(module, *args, **kwargs): + if not module.need_hook: + if module.api_name not in npu_custom_functions: + raise Exception(f'There is not bench function {module.api_name}') + if module.device == Const.CUDA_LOWERCASE: + module.api_name = _cuda_func_mapping.get(module.api_name, module.api_name) + if module.device in [Const.CUDA_LOWERCASE, Const.CPU_LOWERCASE]: + return npu_custom_functions[module.api_name](*args, **kwargs) + return module.api_func(*args, **kwargs) + + +forward_methods = { + "Tensor": tensor_module_forward, + "Distributed": dist_module_forward, + "NPU": npu_module_forward +} + + +class ApiTemplate(HOOKModule): + def __init__(self, api_name, api_func, prefix, hook_build_func, need_hook=True, device=Const.CPU_LOWERCASE): + self.api_name = api_name + self.prefix = prefix + self.prefix_api_name = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) + Const.SEP + self.need_hook = need_hook + self.device = device + self.op_is_distributed = prefix == Const.DIST_API_TYPE_PREFIX + if self.need_hook: + super().__init__(hook_build_func) + self.api_func = api_func + + @torch_device_guard + def forward(self, *args, **kwargs): + exec_func = forward_methods.get(self.prefix) + exec_func = functools.partial(exec_func, self) if exec_func else self.api_func + return exec_func(*args, **kwargs) + + +api_register = None + + +def get_api_register(return_new=False): + if return_new: + return ApiRegistry(_api_types, _inner_used_api, _supported_api_list_path, ApiTemplate) + + global api_register + if api_register is None: + api_register = ApiRegistry(_api_types, _inner_used_api, _supported_api_list_path, ApiTemplate) + return api_register diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/api_registry.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/api_registry.py deleted file mode 100644 index 1aad89bd6e89ae839513001b1d51572b50d8280b..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/api_registry.py +++ /dev/null @@ -1,166 +0,0 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.distributed as dist - -from msprobe.pytorch.hook_module import wrap_torch, wrap_functional, wrap_tensor, wrap_vf, wrap_distributed, wrap_aten -from msprobe.pytorch.hook_module.wrap_aten import get_aten_ops -from msprobe.pytorch.hook_module.wrap_distributed import get_distributed_ops -from msprobe.pytorch.hook_module.wrap_functional import get_functional_ops -from msprobe.pytorch.hook_module.wrap_tensor import get_tensor_ops -from msprobe.pytorch.hook_module.wrap_torch import get_torch_ops -from msprobe.pytorch.hook_module.wrap_vf import get_vf_ops -from msprobe.pytorch.common.utils import torch_without_guard_version, npu_distributed_api, is_gpu -from msprobe.core.common.const import Const - -torch_version_above_2 = torch.__version__.split('+')[0] > '2.0' - -if not is_gpu: - import torch_npu - from . import wrap_npu_custom - from .wrap_npu_custom import get_npu_ops - - -class ApiRegistry: - def __init__(self): - self.tensor_ori_attr = {} - self.torch_ori_attr = {} - self.functional_ori_attr = {} - self.distributed_ori_attr = {} - self.npu_distributed_ori_attr = {} - self.vf_ori_attr = {} - self.aten_ori_attr = {} - self.torch_npu_ori_attr = {} - - self.tensor_hook_attr = {} - self.torch_hook_attr = {} - self.functional_hook_attr = {} - self.distributed_hook_attr = {} - self.npu_distributed_hook_attr = {} - self.vf_hook_attr = {} - self.aten_hook_attr = {} - self.torch_npu_hook_attr = {} - - @staticmethod - def store_ori_attr(ori_api_group, api_list, api_ori_attr): - for api in api_list: - if '.' in api: - sub_module_name, sub_op = api.rsplit('.', 1) - sub_module = getattr(ori_api_group, sub_module_name) - api_ori_attr[api] = getattr(sub_module, sub_op) - else: - api_ori_attr[api] = getattr(ori_api_group, api) - - @staticmethod - def set_api_attr(api_group, attr_dict): - for api, api_attr in attr_dict.items(): - if '.' in api: - sub_module_name, sub_op = api.rsplit('.', 1) - sub_module = getattr(api_group, sub_module_name, None) - if sub_module is not None: - setattr(sub_module, sub_op, api_attr) - else: - setattr(api_group, api, api_attr) - - def api_modularity(self): - self.set_api_attr(torch.Tensor, self.tensor_hook_attr) - self.set_api_attr(torch, self.torch_hook_attr) - self.set_api_attr(torch.nn.functional, self.functional_hook_attr) - self.set_api_attr(dist, self.distributed_hook_attr) - self.set_api_attr(dist.distributed_c10d, self.distributed_hook_attr) - if not is_gpu and not torch_without_guard_version: - self.set_api_attr(torch_npu.distributed, self.npu_distributed_hook_attr) - self.set_api_attr(torch_npu.distributed.distributed_c10d, self.npu_distributed_hook_attr) - if torch_version_above_2: - self.set_api_attr(torch.ops.aten, self.aten_hook_attr) - self.set_api_attr(torch._VF, self.vf_hook_attr) - if not is_gpu: - self.set_api_attr(torch_npu, self.torch_npu_hook_attr) - - def api_originality(self): - self.set_api_attr(torch.Tensor, self.tensor_ori_attr) - self.set_api_attr(torch, self.torch_ori_attr) - self.set_api_attr(torch.nn.functional, self.functional_ori_attr) - self.set_api_attr(dist, self.distributed_ori_attr) - self.set_api_attr(dist.distributed_c10d, self.distributed_ori_attr) - if not is_gpu and not torch_without_guard_version: - self.set_api_attr(torch_npu.distributed, self.npu_distributed_ori_attr) - self.set_api_attr(torch_npu.distributed.distributed_c10d, self.npu_distributed_ori_attr) - if torch_version_above_2: - self.set_api_attr(torch.ops.aten, self.aten_ori_attr) - self.set_api_attr(torch._VF, self.vf_ori_attr) - if not is_gpu: - self.set_api_attr(torch_npu, self.torch_npu_ori_attr) - - def initialize_hook(self, hook, online_run_ut=False): - """ - initialize_hook - Args: - hook (_type_): initialize_hook - online_run_ut (bool): default False, whether online run_ut or not. - If online_run_ut is True, the hook will not wrap the aten ops. - """ - self.store_ori_attr(torch.Tensor, get_tensor_ops(), self.tensor_ori_attr) - wrap_tensor.wrap_tensor_ops_and_bind(hook) - for attr_name in dir(wrap_tensor.HOOKTensor): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - self.tensor_hook_attr[attr_name[5:]] = getattr(wrap_tensor.HOOKTensor, attr_name) - - self.store_ori_attr(torch, get_torch_ops(), self.torch_ori_attr) - wrap_torch.wrap_torch_ops_and_bind(hook) - for attr_name in dir(wrap_torch.HOOKTorchOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - self.torch_hook_attr[attr_name[5:]] = getattr(wrap_torch.HOOKTorchOP, attr_name) - - self.store_ori_attr(torch.nn.functional, get_functional_ops(), self.functional_ori_attr) - wrap_functional.wrap_functional_ops_and_bind(hook) - for attr_name in dir(wrap_functional.HOOKFunctionalOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - self.functional_hook_attr[attr_name[5:]] = getattr(wrap_functional.HOOKFunctionalOP, attr_name) - - self.store_ori_attr(dist, get_distributed_ops(), self.distributed_ori_attr) - wrap_distributed.wrap_distributed_ops_and_bind(hook) - if not is_gpu and not torch_without_guard_version: - self.store_ori_attr(torch_npu.distributed, npu_distributed_api, self.npu_distributed_ori_attr) - for attr_name in dir(wrap_distributed.HOOKDistributedOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - self.distributed_hook_attr[attr_name[5:]] = getattr(wrap_distributed.HOOKDistributedOP, attr_name) - if not is_gpu and not torch_without_guard_version and attr_name[5:] in npu_distributed_api: - self.npu_distributed_hook_attr[attr_name[5:]] = getattr(wrap_distributed.HOOKDistributedOP, - attr_name) - - if torch_version_above_2 and not online_run_ut: - self.store_ori_attr(torch.ops.aten, get_aten_ops(), self.aten_ori_attr) - wrap_aten.wrap_aten_ops_and_bind(hook) - for attr_name in dir(wrap_aten.HOOKAtenOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - self.aten_hook_attr[attr_name[5:]] = getattr(wrap_aten.HOOKAtenOP, attr_name) - - self.store_ori_attr(torch._VF, get_vf_ops(), self.vf_ori_attr) - wrap_vf.wrap_vf_ops_and_bind(hook) - for attr_name in dir(wrap_vf.HOOKVfOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - self.vf_hook_attr[attr_name[5:]] = getattr(wrap_vf.HOOKVfOP, attr_name) - - if not is_gpu: - self.store_ori_attr(torch_npu, get_npu_ops(), self.torch_npu_ori_attr) - wrap_npu_custom.wrap_npu_ops_and_bind(hook) - for attr_name in dir(wrap_npu_custom.HOOKNpuOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - self.torch_npu_hook_attr[attr_name[5:]] = getattr(wrap_npu_custom.HOOKNpuOP, attr_name) - - -api_register = ApiRegistry() diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py index b59d4be82f2b55326c2a1d6a8a9e127a8470bff6..548d1a9bcebac3fc2cd6da63b4977e141dd06952 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,52 +14,30 @@ # limitations under the License. import functools -import threading from collections import defaultdict import torch import torch.nn as nn import torch.utils.hooks as full_hooks -torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' +from msprobe.pytorch.common.utils import register_forward_pre_hook class HOOKModule(nn.Module): module_count = defaultdict(int) - inner_stop_hook = {} - def __init__(self, build_hook) -> None: + def __init__(self, hook_build_func) -> None: super(HOOKModule, self).__init__() - self.has_overflow = False - self.prefix = "" - self.current_thread = threading.current_thread().ident - if self.current_thread not in HOOKModule.inner_stop_hook: - HOOKModule.inner_stop_hook[self.current_thread] = False - self.stop_hook = HOOKModule.inner_stop_hook.get(self.current_thread, False) - - if not self.stop_hook: - if hasattr(self, "prefix_op_name_"): - self.prefix = self.prefix_op_name_ - - self.forward_data_collected = False - forward_pre_hook, forward_hook, backward_hook, _ = build_hook(self.prefix) - if torch_version_above_or_equal_2: - self.register_forward_pre_hook(forward_pre_hook, with_kwargs=True) - self.register_forward_hook(forward_hook, with_kwargs=True) - else: - self.register_forward_pre_hook(forward_pre_hook) - self.register_forward_hook(forward_hook) - self.register_backward_hook(backward_hook) + prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else "" + op_is_distributed = self.op_is_distributed if hasattr(self, "op_is_distributed") else False + if callable(hook_build_func): + hook_set = hook_build_func(prefix) + register_forward_pre_hook(self, hook_set.forward_pre_hook) + if op_is_distributed: + self.distributed_forward_hook = hook_set.distributed_forward_hook def __call__(self, *args, **kwargs): - changed = False - if not self.stop_hook: - HOOKModule.inner_stop_hook[self.current_thread] = True - changed = True - result = self._call_func(*args, **kwargs) - if changed: - HOOKModule.inner_stop_hook[self.current_thread] = False - return result + return self._call_func(*args, **kwargs) @staticmethod def reset_module_stats(): @@ -78,13 +56,7 @@ class HOOKModule(nn.Module): if len(self._backward_hooks) > 0: full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks() for hook in self._forward_pre_hooks.values(): - result_args, result_kwargs = hook(self, args, kwargs) - if result_args is not None: - if not isinstance(result_args, tuple): - result_args = (result_args,) - args = result_args - if result_kwargs is not None: - kwargs = result_kwargs + hook(self, args, kwargs) bw_hook = None if len(full_backward_hooks) > 0: bw_hook = full_hooks.BackwardHook(self, full_backward_hooks) @@ -111,6 +83,10 @@ class HOOKModule(nn.Module): return result else: return result + + if not (var.requires_grad and torch.is_grad_enabled()): + return result + grad_fn = var.grad_fn if grad_fn is not None: for hook in non_full_backward_hooks: diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/pt_hook_manager.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/pt_hook_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..96c56a586b29606e212bd13a45a75174a371d50e --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/pt_hook_manager.py @@ -0,0 +1,137 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import threading +from contextlib import nullcontext + +import torch + +from msprobe.core.common.const import Const +from msprobe.core.common.runtime import Runtime +from msprobe.core.common.utils import replace_last_occurrence, ThreadSafe +from msprobe.core.data_dump.data_processor.base import (ModuleForwardInputsOutputs) +from msprobe.core.hook_manager import BaseHookManager, HookSet +from msprobe.pytorch.common.utils import is_recomputation, torch_version_above_or_equal_2, register_forward_hook +from msprobe.pytorch.hook_module.hook_module import HOOKModule + + +class PytorchHookManager(BaseHookManager): + @property + def _is_recompute(self): + return is_recomputation() + + @staticmethod + def _no_grad_context(): + return nullcontext() + + @staticmethod + def _add_count(name): + HOOKModule.add_module_count(name) + + @staticmethod + def _get_count(name): + return HOOKModule.get_module_count(name) + + @staticmethod + def _process_kwargs_and_output(module, tid, hook_type, kwargs_or_output, output_or_kwargs): + if hook_type == Const.API: + kwargs = kwargs_or_output + output = output_or_kwargs + else: + kwargs = kwargs_or_output if torch_version_above_or_equal_2 else {} + output = output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output + return kwargs, output + + def build_hook(self, hook_type, name): + if hook_type == Const.API: + hook_set = HookSet( + forward_pre_hook=self._build_forward_pre_hook(hook_type, name), + distributed_forward_hook=self._build_distributed_forward_hook() + ) + else: + full_backward_name = replace_last_occurrence(name, Const.FORWARD, Const.BACKWARD) + hook_set = HookSet( + forward_hook=self._build_forward_hook(hook_type, name), + backward_hook=self._build_backward_hook(hook_type, full_backward_name) + ) + return hook_set + + def _register_forward_hook(self, module, api_name): + if not hasattr(module, 'msprobe_forward_hook'): + register_forward_hook(module, self._build_forward_hook(Const.API, api_name)) + setattr(module, 'msprobe_forward_hook', True) + + def _register_backward_hook(self, module, full_backward_name, args): + pass + + def _register_backward_pre_hook(self, module, full_backward_name, output): + var = output + while not isinstance(var, torch.Tensor): + if isinstance(var, dict): + var = next((v for v in var.values() if isinstance(v, torch.Tensor))) + elif isinstance(var, (list, tuple)): + if var: + var = var[0] + else: + return output + else: + return output + + if not (var.requires_grad and torch.is_grad_enabled()): + return output + + grad_fn = var.grad_fn + if grad_fn is not None: + backward_hook = self._build_backward_hook(Const.API, full_backward_name) + wrapper = functools.partial(backward_hook, module) + functools.update_wrapper(wrapper, backward_hook) + grad_fn.register_hook(wrapper) + + return output + + def _need_exchange(self, module): + return True + + def _get_params_dict(self, module): + params_dict = {} + if self.config.task != Const.STRUCTURE: + params_dict = { + key.split(Const.SEP)[-1]: value + for key, value in module.named_parameters(recurse=False) + } + return params_dict + + def _build_distributed_forward_hook(self): + def distributed_forward_hook(module, full_name, args, kwargs, output): + if not full_name or not Runtime.is_running: + return + + tid = threading.get_ident() + with ThreadSafe(): + BaseHookManager.inner_switch[tid] = True + self.data_collector.update_api_or_module_name(full_name) + module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output) + with self._no_grad_context(): + self.data_collector.forward_output_data_collect( + full_name, + module, + self._pid, + module_input_output, + self._is_recompute + ) + BaseHookManager.inner_switch[tid] = False + + return distributed_forward_hook diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/register_optimizer_hook.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/register_optimizer_hook.py index 75be9fc4532ea5863ed3daad569c062c4ccb91ba..b4f9a5f50639752e8094b38961ef600cc6d7b101 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/register_optimizer_hook.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/register_optimizer_hook.py @@ -32,8 +32,9 @@ def register_optimizer_hook(data_collector): def patch_clip_grad(func): def wrapper(*args, **kwargs): data_collector.optimizer_status = Const.CLIP_GRAD - func(*args, **kwargs) + result = func(*args, **kwargs) data_collector.optimizer_status = Const.END_PREFIX + Const.CLIP_GRAD + return result return wrapper diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..7da0221c22e259f58c59eaa540bb824f427dc668 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py @@ -0,0 +1,140 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import importlib +import types + +import torch + +from msprobe.core.common.log import logger +from msprobe.pytorch.common.utils import torch_version_above_or_equal_2 +from msprobe.pytorch.hook_module.api_register import get_api_register + +if torch_version_above_or_equal_2: + from torch._dynamo.convert_frame import convert_frame as _orig_convert_frame, Hooks + + +def wrap_jit_script_func(): + def patched_script(*args, **kwargs): + all_api_registered = api_register.all_api_registered + if all_api_registered: + api_register.restore_all_api() + result = original_script(*args, **kwargs) + if all_api_registered: + api_register.register_all_api() + return result + + original_script = torch.jit.script + api_register = get_api_register() + torch.jit.script = patched_script + + +def wrap_compile_script_func(): + def _patched_convert_frame(compiler_fn, hooks): + """ + 在调用原 convert_frame 生成的 _convert_frame 之前恢复 API, + 调用完之后再重新注册所有 API。 + """ + # 拿到原来 inner 版的 _convert_frame + inner_convert = _orig_convert_frame(compiler_fn, hooks) + + def _wrapped(frame: types.FrameType, cache_size: int, hooks: Hooks, frame_state): + reg = get_api_register() + # 进入前 restore + reg.restore_all_api() + try: + result = inner_convert(frame, cache_size, hooks, frame_state) + except Exception: + # 异常时也要确保 register + reg.register_all_api() + raise + # 正常结束后 register + reg.register_all_api() + return result + + # 保留原属性以兼容 + _wrapped._torchdynamo_orig_callable = compiler_fn # type: ignore[attr-defined] + _wrapped._clone_with_backend = lambda backend: _patched_convert_frame(backend, + hooks) # type: ignore[attr-defined] + return _wrapped + + import torch._dynamo.convert_frame as _cf_mod + _cf_mod.convert_frame = _patched_convert_frame + + +def patch_dynamo_compile(): + cf = importlib.import_module("torch._dynamo.convert_frame") + if not hasattr(cf, "_compile"): + logger.warning("No found torch._dynamo.convert_frame._compile") + + original = cf._compile + if getattr(original, "__msprobe_patched__", False): + return + + @functools.wraps(original) + def wrapped(*args, **kwargs): + result = None + try: + reg = get_api_register() + reg.restore_all_api() + except Exception as e: + logger.warning(f"[msprobe] Pre restore_all_api failed: {e}") + return result + + try: + result = original(*args, **kwargs) + except Exception: + logger.warning("[msprobe] _compile execution failed (returning None)") + result = None + finally: + try: + reg = get_api_register() + reg.register_all_api() # 改成注册hook + except Exception as e: + logger.warning(f"[msprobe] Post register_all_api failed: {e}") + return result + wrapped.__msprobe_patched__ = True + wrapped.__msprobe_original__ = original + cf._compile = wrapped + + +def unpatch_dynamo_compile() -> bool: + # 预留取消patch接口 + cf = importlib.import_module("torch._dynamo.convert_frame") + current = getattr(cf, "_compile", None) + if current is None: + return False + original = getattr(current, "__msprobe_original__", None) + if original is None: + return False + cf._compile = original + return True + + +def preprocess_func(): + try: + from torch.utils._device import _device_constructors + _device_constructors() + except ImportError: + pass + except Exception as e: + logger.warning(f"Failed to execute _device_constructors. Error Details: {str(e)}") + + +def wrap_script_func(): + wrap_jit_script_func() + if torch_version_above_or_equal_2: + patch_dynamo_compile() diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml b/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml index 4bc22f51ceb5497f307fb4ac3226c8c590ea459a..638fe8c620fabdb6214104fae7125173b347578a 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml @@ -149,9 +149,9 @@ tensor: - __bool__ - __div__ - __eq__ + - __floordiv__ - __ge__ - __gt__ - - __getitem__ - __iadd__ - __iand__ - __idiv__ @@ -160,23 +160,33 @@ tensor: - __imod__ - __imul__ - __ior__ + - __ipow__ - __irshift__ - __isub__ - __ixor__ + - __le__ - __lshift__ + - __lt__ - __matmul__ - __mod__ - __mul__ + - __ne__ - __nonzero__ - __or__ + - __pow__ - __radd__ + - __rdiv__ + - __rmod__ - __rmul__ + - __ror__ + - __rpow__ - __rshift__ + - __rsub__ + - __rxor__ - __setitem__ - __sub__ - __truediv__ - __xor__ - - __pow__ - abs - abs_ - absolute @@ -199,12 +209,14 @@ tensor: - addmv_ - addr - addr_ + - adjoint - align_as - align_to - all - allclose - amax - amin + - aminmax - angle - any - arccos @@ -216,12 +228,15 @@ tensor: - arcsinh - arcsinh_ - arctan + - arctan2 + - arctan2_ - arctan_ - arctanh - arctanh_ - argmax - argmin - argsort + - argwhere - asin - asin_ - asinh @@ -236,39 +251,51 @@ tensor: - baddbmm_ - bernoulli - bernoulli_ + - bfloat16 - bincount - bitwise_and - bitwise_and_ + - bitwise_left_shift + - bitwise_left_shift_ - bitwise_not - bitwise_not_ - bitwise_or - bitwise_or_ + - bitwise_right_shift + - bitwise_right_shift_ - bitwise_xor - bitwise_xor_ - bmm + - bool - broadcast_to + - byte - cauchy_ - ceil - ceil_ + - cfloat + - char - cholesky + - cholesky_inverse + - cholesky_solve - chunk - clamp - - cholesky_solve - - cholesky_inverse - clamp_ - clamp_max - clamp_max_ - - clip - clamp_min - clamp_min_ + - clip - clip_ + - conj_physical - copysign - copysign_ + - corrcoef - cos - cos_ - cosh - cosh_ - count_nonzero + - cov - cummax - cummin - cumprod @@ -282,20 +309,23 @@ tensor: - diag_embed - diagflat - diagonal + - diagonal_scatter - diff - - dist - digamma - digamma_ + - dist - div - div_ - divide - divide_ - dot + - double + - dsplit - eig - eq - eq_ - - erf - equal + - erf - erf_ - erfc - erfc_ @@ -304,18 +334,21 @@ tensor: - exp - exp2 - exp2_ - - expm1 - exp_ + - expand + - expand_as + - expm1 - expm1_ - exponential_ - fill_ - - fix - fill_diagonal_ + - fix - fix_ + - flatten - flip - fliplr - - flatten - flipud + - float - float_power - float_power_ - floor @@ -328,6 +361,7 @@ tensor: - fmod_ - frac - frac_ + - frexp - gather - gcd - gcd_ @@ -338,31 +372,37 @@ tensor: - ger - greater - greater_ - - gt - - gt_ - greater_equal - greater_equal_ + - gt + - gt_ + - half - hardshrink - heaviside - heaviside_ - histc + - histogram + - hsplit - hypot - hypot_ + - i0 + - i0_ - igamma - igamma_ - igammac - igammac_ - index_add - index_add_ - - inverse - index_copy - index_copy_ - index_fill - index_fill_ - index_put - index_put_ - - inner - index_select + - inner + - int + - inverse - isclose - isfinite - isinf @@ -380,7 +420,6 @@ tensor: - le_ - lerp - lerp_ - - where - less - less_ - less_equal @@ -397,43 +436,47 @@ tensor: - log_ - log_normal_ - log_softmax - - logcumsumexp - - logdet - logaddexp - logaddexp2 + - logcumsumexp + - logdet - logical_and - logical_and_ - logical_not - - logit - logical_not_ - logical_or - logical_or_ - logical_xor - logical_xor_ + - logit - logit_ - logsumexp + - long - lstsq - lt - lt_ + - lu - lu_solve - map2_ - map_ - masked_fill - - matmul - masked_fill_ - masked_scatter - masked_scatter_ - masked_select + - matmul - matrix_exp + - matrix_power - max - maximum - mean - - matrix_power - median - min - minimum - mm - mode + - moveaxis + - movedim - msort - mul - mul_ @@ -443,6 +486,11 @@ tensor: - mv - mvlgamma - mvlgamma_ + - nan_to_num + - nan_to_num_ + - nanmean + - nanmedian + - nanquantile - nansum - narrow - narrow_copy @@ -452,20 +500,29 @@ tensor: - neg_ - negative - negative_ + - nextafter + - nextafter_ - nonzero - norm - normal_ - not_equal - not_equal_ + - numpy + - orgqr + - ormqr + - outer - permute - pinverse - polygamma + - polygamma_ - pow - pow_ - - polygamma_ - prelu - prod - put_ + - q_zero_point + - qr + - quantile - rad2deg - rad2deg_ - ravel @@ -474,15 +531,16 @@ tensor: - relu - relu_ - remainder - - repeat_interleave - - reshape - remainder_ - renorm - renorm_ - repeat + - repeat_interleave + - reshape - reshape_as - resize_ - resize_as_ + - resolve_neg - roll - rot90 - round @@ -496,6 +554,7 @@ tensor: - select - sgn - sgn_ + - short - sigmoid - sigmoid_ - sign @@ -507,11 +566,13 @@ tensor: - sinc_ - sinh - sinh_ + - slice_scatter - slogdet - smm - softmax - solve - sort + - split - split_with_sizes - sqrt - sqrt_ @@ -521,21 +582,29 @@ tensor: - squeeze_ - sspaddmm - std + - stft + - stride - sub - sub_ + - subtract - sum - sum_to_size - svd + - swapaxes + - swapdims + - swapdims_ - symeig - t - t_ - take + - take_along_dim - tan - tan_ - tanh - tanh_ - tensor_split - tile + - to - topk - transpose - transpose_ @@ -543,8 +612,8 @@ tensor: - tril - tril_ - triu - - true_divide - triu_ + - true_divide - true_divide_ - trunc - trunc_ @@ -552,37 +621,20 @@ tensor: - unbind - unflatten - unfold + - unique + - unique_consecutive - unsafe_chunk - - unsqueeze - unsafe_split - unsafe_split_with_sizes + - unsqueeze + - unsqueeze_ - var - vdot - - unsqueeze_ - view_as + - vsplit + - where - xlogy - xlogy_ - - split - - stft - - nan_to_num - - dsplit - - orgqr - - bitwise_left_shift_ - - arctan2 - - histogram - - q_zero_point - - adjoint - - ormqr - - bitwise_right_shift_ - - nanquantile - - lu - - quantile - - arctan2_ - - qr - - diagonal_scatter - - corrcoef - - vsplit - - aminmax torch: - linalg.norm @@ -624,6 +676,7 @@ torch: - _batch_norm_impl_index - _convolution - _foreach_norm + - _fused_adamw_ - _softmax_backward_data - abs - abs_ @@ -642,13 +695,14 @@ torch: - addmv - addmv_ - addr - - amax - affine_grid_generator - align_tensors - all - alpha_dropout - - amin - alpha_dropout_ + - amax + - amin + - aminmax - angle - any - arange @@ -661,12 +715,14 @@ torch: - arcsinh - arcsinh_ - arctan + - arctan2 - arctan_ - arctanh - arctanh_ - argmax - argmin - argsort + - argwhere - asin - asin_ - asinh @@ -687,13 +743,13 @@ torch: - batch_norm_elemt - batch_norm_gather_stats - batch_norm_gather_stats_with_counts - - bernoulli - batch_norm_stats - batch_norm_update_stats + - bernoulli - bilinear + - binary_cross_entropy_with_logits - bincount - binomial - - binary_cross_entropy_with_logits - bitwise_and - bitwise_not - bitwise_or @@ -739,9 +795,9 @@ torch: - conv_transpose1d - conv_transpose2d - conv_transpose3d - - cos - convolution - copysign + - cos - cos_ - cosh - cosh_ @@ -755,14 +811,16 @@ torch: - cummin - cumprod - cumsum + - cumulative_trapezoid - deg2rad - deg2rad_ - det - diag - diag_embed - - diff - diagflat - diagonal + - diagonal_scatter + - diff - digamma - dist - div @@ -771,12 +829,15 @@ torch: - dropout - dropout_ - dsmm + - dsplit - dstack - eig - einsum - embedding - embedding_bag - embedding_renorm_ + - empty + - empty_like - eq - equal - erf @@ -791,12 +852,12 @@ torch: - expm1 - expm1_ - eye - - feature_dropout - feature_alpha_dropout - feature_alpha_dropout_ + - feature_dropout - feature_dropout_ - - fix - fill_ + - fix - fix_ - flatten - flip @@ -811,8 +872,9 @@ torch: - fmod - frac - frac_ - - full + - frexp - frobenius_norm + - full - full_like - gather - gcd @@ -824,8 +886,8 @@ torch: - greater_equal - grid_sampler - grid_sampler_2d - - group_norm - grid_sampler_3d + - group_norm - gru - gru_cell - gt @@ -835,23 +897,29 @@ torch: - heaviside - hinge_embedding_loss - histc + - histogram + - histogramdd - hsmm + - hsplit - hspmm - hstack - hypot + - i0 + - i0_ - igamma - igammac - index_add - index_copy - - inner - index_fill - index_put - index_put_ - index_select + - inner - instance_norm - inverse - isclose - isfinite + - isin - isinf - isnan - isneginf @@ -879,8 +947,8 @@ torch: - log1p_ - log2 - log2_ - - log_softmax - log_ + - log_softmax - logaddexp - logaddexp2 - logcumsumexp @@ -899,18 +967,18 @@ torch: - lt - lu_solve - lu_unpack - - masked_fill - margin_ranking_loss + - masked_fill - masked_scatter - masked_select - - matrix_exp - matmul + - matrix_exp - matrix_power - matrix_rank - max - max_pool1d - - max_pool2d - max_pool1d_with_indices + - max_pool2d - max_pool3d - maximum - mean @@ -929,18 +997,20 @@ torch: - mvlgamma - nan_to_num - nan_to_num_ + - nanmean - nanmedian + - nanquantile - nansum - narrow + - narrow_copy - native_batch_norm - native_group_norm - - narrow_copy - native_layer_norm - native_norm - ne - neg - - negative - neg_ + - negative - negative_ - nextafter - nonzero @@ -972,30 +1042,31 @@ torch: - ravel - real - reciprocal - - relu - reciprocal_ + - relu - relu_ - remainder - renorm - repeat_interleave - reshape - resize_as_ + - resolve_neg - roll - rot90 - round - round_ + - row_stack - rrelu - rrelu_ - rsqrt - - row_stack - rsqrt_ - rsub - saddmm - scalar_tensor - scatter - - select - scatter_add - searchsorted + - select - selu - selu_ - sgn @@ -1015,12 +1086,12 @@ torch: - solve - sort - sparse_coo_tensor - - square - split - split_with_sizes - spmm - sqrt - sqrt_ + - square - square_ - squeeze - sspaddmm @@ -1042,8 +1113,8 @@ torch: - tan_ - tanh - tanh_ - - tensordot - tensor_split + - tensordot - threshold - threshold_ - tile @@ -1059,19 +1130,21 @@ torch: - true_divide - trunc - trunc_ - - unique_consecutive - - xlogy - unbind + - unflatten + - unique_consecutive - unsafe_chunk - unsafe_split - - vander - - var - - vdot - unsafe_split_with_sizes - unsqueeze + - vander + - var - var_mean + - vdot + - vsplit - vstack - where + - xlogy - xlogy_ _VF: @@ -1165,6 +1238,34 @@ torch_npu: - npu_moe_finalize_routing - npu_moe_gating_top_k_softmax - npu_trans_quant_param + - npu_gelu + - npu_ffn + - npu_quant_matmul + - npu_format_cast_ + - npu_dynamic_quant + - npu_moe_compute_expert_tokens + - npu_weight_quant_batchmatmul + - npu_dynamic_quant_asymmetric + - npu_grouped_matmul + - npu_quant_scatter_ + - npu_group_quant + - npu_fused_infer_attention_score + - npu_quantize + - npu_fast_gelu + - npu_weight_quant_batchmatmul + - scatter_update + - scatter_update_ + - npu_moe_init_routing + - npu_scatter_nd_update_ + - npu_scatter_nd_update + - npu_prefetch + - npu_dynamic_block_quant + - npu_add_rms_norm + - _npu_flash_attention + - _npu_rotary_embedding + - _npu_reshape_and_cache + - _npu_paged_attention + - npu_moe_gating_top_k aten: - signbit @@ -1912,4 +2013,27 @@ distributed: - all_to_all - all_gather_into_tensor - reduce_scatter_tensor - - batch_isend_irecv \ No newline at end of file + - batch_isend_irecv + +npu_distributed: + - isend + - irecv + +mindspeed: + - dropout_add_layer_norm.npu_dropout_add_layer_norm + - npu_rotary_position_embedding.npu_rotary_position_embedding + - fusion_attention_v2.npu_fusion_attention + - npu_mm_all_reduce_add_rms_norm.npu_mm_all_reduce_add_rms_norm + - npu_mm_all_reduce_add_rms_norm_.npu_mm_all_reduce_add_rms_norm_ + - gmm.npu_gmm + - gmm.npu_gmm_v2 + - npu_grouped_mat_mul_all_reduce.npu_grouped_mat_mul_all_reduce + - ffn.npu_ffn + - npu_moe_token_permute.npu_moe_token_permute + - npu_moe_token_unpermute.npu_moe_token_unpermute + - npu_ring_attention_update.npu_ring_attention_update + - npu_matmul_add.npu_matmul_add_fp32 + - npu_groupmatmul_add.npu_groupmatmul_add_fp32 + - quant_gmm.npu_quant_gmm + - quant_gmm.npu_quant_gmm_v2 + - npu_apply_fused_ema_adamw.npu_apply_fused_ema_adamw \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/utils.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/utils.py index 41869403a547fc526ec422ecbb123af18ff81a39..68e434d0ad151fc70d2a7bbb333b195d4bbe0e2f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +14,11 @@ # limitations under the License. import os -from msprobe.core.common.file_utils import load_yaml +import importlib +import inspect + +from msprobe.core.common.file_utils import load_yaml, check_link +from msprobe.core.common.log import logger def get_ops(): @@ -26,3 +30,25 @@ def get_ops(): wrap_torch = ops.get('torch') wrap_npu_ops = ops.get('torch_npu') return set(wrap_functional) | set(wrap_tensor) | set(wrap_torch) | set(wrap_npu_ops) + + +def dynamic_import_op(package, white_list): + package_name = package.__name__ + ops = {} + ops_dir, _ = os.path.split(package.__file__) + check_link(ops_dir) + for file_name in os.listdir(ops_dir): + if file_name in white_list: + sub_module_name = file_name[:-3] + module_name = f"{package_name}.{sub_module_name}" + try: + module = importlib.import_module(module_name) + except Exception as e: + logger.warning(f"import {module_name} failed!") + continue + + func_members = inspect.getmembers(module, inspect.isfunction) + for func_member in func_members: + func_name, func = func_member[0], func_member[1] + ops[f"{sub_module_name}.{func_name}"] = func + return ops diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py deleted file mode 100644 index 1cd11842c31bacdad7c1bb90f98ac81c3415a40e..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from functools import wraps -import torch.distributed as dist - -from msprobe.pytorch.hook_module.hook_module import HOOKModule -from msprobe.pytorch.common.utils import torch_device_guard -from msprobe.core.common.const import Const -from msprobe.core.common.file_utils import load_yaml - - -cur_path = os.path.dirname(os.path.realpath(__file__)) -yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") - - -distributed_func = {} -for f in dir(dist): - distributed_func[f] = getattr(dist, f) - - -def get_distributed_ops(): - _all_distributed_ops = dir(dist) - yaml_data = load_yaml(yaml_path) - wrap_distributed_ops = yaml_data.get('distributed') - return set(wrap_distributed_ops) & set(_all_distributed_ops) - - -class HOOKDistributedOP(object): - pass - - -class DistributedOPTemplate(HOOKModule): - def __init__(self, op_name, build_hook): - self.op_name_ = op_name - self.prefix_op_name_ = "Distributed" + Const.SEP + str(op_name) + Const.SEP - super().__init__(build_hook) - if not self.stop_hook: - self.op_is_distributed = True - - @torch_device_guard - def forward(self, *args, **kwargs): - handle = distributed_func.get(self.op_name_)(*args, **kwargs) - if kwargs.get("async_op") or self.op_name_ in ["isend", "irecv"]: - if handle and hasattr(handle, 'wait'): - handle.wait() - if self.op_name_ == "batch_isend_irecv": - if isinstance(handle, list): - for req in handle: - req.wait() - return handle - - -def wrap_distributed_op(op_name, hook): - @wraps(DistributedOPTemplate) - def distributed_op_template(*args, **kwargs): - return DistributedOPTemplate(op_name, hook)(*args, **kwargs) - - distributed_op_template.__name__ = op_name - return distributed_op_template - - -def wrap_distributed_ops_and_bind(hook): - _distributed_ops = get_distributed_ops() - for op_name in _distributed_ops: - setattr(HOOKDistributedOP, "wrap_" + str(op_name), wrap_distributed_op(op_name, hook)) diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_functional.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_functional.py deleted file mode 100644 index 6164169476dab66ac2bdb8d0cbc41a04ddce6713..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_functional.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import torch - -from msprobe.pytorch.hook_module.hook_module import HOOKModule -from msprobe.pytorch.common.utils import torch_device_guard -from msprobe.core.common.const import Const -from msprobe.pytorch.common.log import logger -from msprobe.core.common.file_utils import load_yaml - - -cur_path = os.path.dirname(os.path.realpath(__file__)) -yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") - - -def get_functional_ops(): - yaml_data = load_yaml(yaml_path) - wrap_functional_ops = yaml_data.get('functional') - _all_functional_ops = dir(torch.nn.functional) - return set(wrap_functional_ops) & set(_all_functional_ops) - - -TorchFunctions = {func: getattr(torch.nn.functional, func) for func in get_functional_ops()} - - -class HOOKFunctionalOP(object): - pass - - -class FunctionalOPTemplate(HOOKModule): - def __init__(self, op_name, hook, need_hook=True): - self.op_name_ = op_name - self.prefix_op_name_ = "Functional" + Const.SEP + str(op_name) + Const.SEP - if need_hook: - super().__init__(hook) - - @torch_device_guard - def forward(self, *args, **kwargs): - return TorchFunctions[str(self.op_name_)](*args, **kwargs) - - -def wrap_functional_op(op_name, hook): - def functional_op_template(*args, **kwargs): - return FunctionalOPTemplate(op_name, hook)(*args, **kwargs) - - return functional_op_template - - -def wrap_functional_ops_and_bind(hook): - _functional_ops = get_functional_ops() - for op_name in _functional_ops: - setattr(HOOKFunctionalOP, "wrap_" + op_name, wrap_functional_op(op_name, hook)) diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py deleted file mode 100644 index 1c0afc59f50c069fbcd7e9a546c5b57c467400a9..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import torch - -from msprobe.pytorch.hook_module.hook_module import HOOKModule -from msprobe.pytorch.common.utils import torch_device_guard, torch_without_guard_version -from msprobe.core.common.const import Const -from msprobe.core.common.log import logger -from msprobe.core.common.file_utils import load_yaml -from msprobe.pytorch.function_factory import npu_custom_functions - -try: - import torch_npu -except ImportError: - logger.info("Failing to import torch_npu.") - - -cur_path = os.path.dirname(os.path.realpath(__file__)) -yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") -cuda_func_mapping = {"npu_fusion_attention" : "gpu_fusion_attention"} - - -def get_npu_ops(): - if torch_without_guard_version: - _npu_ops = dir(torch.ops.npu) - else: - _npu_ops = dir(torch_npu._C._VariableFunctionsClass) - yaml_data = load_yaml(yaml_path) - wrap_npu_ops = yaml_data.get('torch_npu') - return set(wrap_npu_ops) & set(_npu_ops) - - -class HOOKNpuOP(object): - pass - - -class NpuOPTemplate(HOOKModule): - - def __init__(self, op_name, hook, need_hook=True, device=Const.CPU_LOWERCASE): - self.op_name_ = op_name - self.prefix_op_name_ = "NPU" + Const.SEP + str(op_name) + Const.SEP - self.need_hook = need_hook - self.device = device - if need_hook: - super().__init__(hook) - - @torch_device_guard - def forward(self, *args, **kwargs): - if not self.need_hook: - if self.op_name_ not in npu_custom_functions: - raise Exception(f'There is not bench function {self.op_name_}') - if self.device == Const.CUDA_LOWERCASE: - self.op_name_ = cuda_func_mapping.get(self.op_name_, self.op_name_) - if self.device in [Const.CUDA_LOWERCASE, Const.CPU_LOWERCASE]: - return npu_custom_functions[self.op_name_](*args, **kwargs) - if torch_without_guard_version: - return getattr(torch.ops.npu, str(self.op_name_))(*args, **kwargs) - else: - return getattr(torch_npu._C._VariableFunctionsClass, str(self.op_name_))(*args, **kwargs) - - -def wrap_npu_op(op_name, hook): - def npu_op_template(*args, **kwargs): - return NpuOPTemplate(op_name, hook)(*args, **kwargs) - return npu_op_template - - -def wrap_npu_ops_and_bind(hook): - _npu_ops = get_npu_ops() - for op_name in _npu_ops: - setattr(HOOKNpuOP, "wrap_" + str(op_name), wrap_npu_op(op_name, hook)) diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_tensor.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_tensor.py deleted file mode 100644 index f93c09a12415f22d96306ebc9de919520c025236..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_tensor.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -import torch - -from msprobe.pytorch.hook_module.hook_module import HOOKModule -from msprobe.pytorch.common.utils import torch_device_guard, parameter_adapter -from msprobe.core.common.const import Const -from msprobe.core.common.file_utils import load_yaml - - -cur_path = os.path.dirname(os.path.realpath(__file__)) -yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") - - -def get_tensor_ops(): - _tensor_ops = dir(torch.Tensor) - yaml_data = load_yaml(yaml_path) - wrap_tensor_ops = yaml_data.get('tensor') - return set(wrap_tensor_ops) & set(_tensor_ops) - - -TensorOps = {op: getattr(torch.Tensor, op) for op in get_tensor_ops()} - - -class HOOKTensor(object): - pass - - -class TensorOPTemplate(HOOKModule): - - def __init__(self, op_name, hook, need_hook=True): - self.op_name_ = op_name - self.prefix_op_name_ = "Tensor" + Const.SEP + str(op_name) + Const.SEP - if need_hook: - super().__init__(hook) - - @torch_device_guard - @parameter_adapter - def forward(self, *args, **kwargs): - return TensorOps[str(self.op_name_)](*args, **kwargs) - - -def wrap_tensor_op(op_name, hook): - - def tensor_op_template(*args, **kwargs): - return TensorOPTemplate(op_name, hook)(*args, **kwargs) - - return tensor_op_template - - -def wrap_tensor_ops_and_bind(hook): - _tensor_ops = get_tensor_ops() - for op_name in _tensor_ops: - setattr(HOOKTensor, "wrap_" + str(op_name), wrap_tensor_op(op_name, hook)) diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_torch.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_torch.py deleted file mode 100644 index fc9d61c206bcfaeda7fefb5cb8b90fda2d67cb16..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_torch.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import torch - -from msprobe.pytorch.hook_module.hook_module import HOOKModule -from msprobe.pytorch.common.utils import torch_device_guard -from msprobe.core.common.const import Const -from msprobe.core.common.file_utils import load_yaml - - -cur_path = os.path.dirname(os.path.realpath(__file__)) -yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") - - -def get_torch_ops(): - _torch_ops = [] - yaml_data = load_yaml(yaml_path) - wrap_torch_ops = yaml_data.get('torch') - for operation in wrap_torch_ops: - if '.' in operation: - operation_sub_module_name, operation_sub_op = operation.rsplit('.', 1) - operation_sub_module = getattr(torch, operation_sub_module_name) - if operation_sub_op in dir(operation_sub_module): - _torch_ops.append(operation) - else: - if hasattr(torch, operation): - _torch_ops.append(operation) - return set(_torch_ops) - - -TorchOps = {} -for op in get_torch_ops(): - if '.' in op: - sub_module_name, sub_op = op.rsplit('.', 1) - sub_module = getattr(torch, sub_module_name) - TorchOps[op] = getattr(sub_module, sub_op) - else: - TorchOps[op] = getattr(torch, op) - - - -class HOOKTorchOP(object): - pass - - -class TorchOPTemplate(HOOKModule): - - def __init__(self, op_name, hook, need_hook=True): - self.op_name_ = op_name - self.prefix_op_name_ = "Torch" + Const.SEP + str(op_name) + Const.SEP - if need_hook: - super().__init__(hook) - - @torch_device_guard - def forward(self, *args, **kwargs): - return TorchOps[str(self.op_name_)](*args, **kwargs) - - -def wrap_torch_op(op_name, hook): - - def torch_op_template(*args, **kwargs): - return TorchOPTemplate(op_name, hook)(*args, **kwargs) - - return torch_op_template - - -def wrap_torch_ops_and_bind(hook): - _torch_ops = get_torch_ops() - for op_name in _torch_ops: - setattr(HOOKTorchOP, "wrap_" + op_name, wrap_torch_op(op_name, hook)) diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/csv2tb.py b/debug/accuracy_tools/msprobe/pytorch/monitor/csv2tb.py index 6ffd1ffabe7b113ff4e61786d4d9f0709b8b605b..0f015dceb4b18f210bf522ec4e1bba1f3d777c85 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/csv2tb.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/csv2tb.py @@ -22,12 +22,17 @@ from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm from msprobe.core.common.const import MonitorConst -from msprobe.core.common.file_utils import read_csv, create_directory, remove_path -from msprobe.core.common.utils import is_int +from msprobe.core.common.file_utils import read_csv, create_directory, remove_path, recursive_chmod +from msprobe.core.common.utils import check_process_num +from msprobe.core.common.decorator import recursion_depth_decorator +from msprobe.core.monitor.utils import get_target_output_dir from msprobe.pytorch.common.log import logger -from msprobe.pytorch.monitor.utils import get_target_output_dir -all_data_type_list = ["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param"] + +all_data_type_list = [ + "actv", "actv_grad", "exp_avg", "exp_avg_sq", + "grad_unreduced", "grad_reduced", "param_origin", "param_updated" +] CSV_FILE_SUFFIX = r"_\d+-\d+\.csv" @@ -46,7 +51,7 @@ def parse_step_line(line, ops): def parse_step_fn(filepath): data = read_csv(filepath) - ops = [k for k in data.keys() if k in MonitorConst.OP_LIST] + ops = [k for k in data.keys() if k in MonitorConst.OP_LIST[:-2]] parse_step_result = {} for _, line in data.iterrows(): @@ -74,8 +79,10 @@ def write_step(output_dirpath, parse_step_result, rank, data_type): for op, value in ops.items(): tag = f"{vpp_name}/{op}" writer.add_scalar(tag, value, step) + writer.close() +@recursion_depth_decorator("update_dict", max_depth=50) def update_dict(dict1, dict2): for key, value in dict2.items(): if key in dict1: @@ -112,14 +119,9 @@ def csv2tb_by_step_work(target_output_dirs, output_dirpath, data_type_list): write_step(output_dirpath, all_step_result, rank, data_type) -def check_process_num(process_num): - if not is_int(process_num) or process_num <= 0: - raise ValueError(f"process_num({process_num}) is not a positive integer") - - def check_data_type_list(data_type_list): if data_type_list is None: - logger.info(f"data_type_list is None, use defualt all_data_type_list: {all_data_type_list}") + logger.info(f"data_type_list is None, use default all_data_type_list: {all_data_type_list}") return if not isinstance(data_type_list, list): raise ValueError(f"data_type_list({data_type_list}) is not a list") @@ -161,4 +163,5 @@ def csv2tensorboard_by_step( p.start() for p in processes: p.join() + recursive_chmod(output_dirpath) logger.info(f"output has been saved to: {output_dirpath}") diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_detect.py b/debug/accuracy_tools/msprobe/pytorch/monitor/data_writers.py similarity index 48% rename from debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_detect.py rename to debug/accuracy_tools/msprobe/pytorch/monitor/data_writers.py index 63f20b1928c80e1e29d7cb8224f267c246fcaa8b..bd6bde7e9f6ede789f520acc2138492e99bac509 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_detect.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/data_writers.py @@ -14,12 +14,8 @@ # limitations under the License. import itertools import os -import statistics as st -import sys -from abc import ABC from collections import defaultdict -from dataclasses import dataclass, field -from typing import List +from dataclasses import dataclass import pandas as pd import torch @@ -27,78 +23,10 @@ from torch.utils.tensorboard import SummaryWriter from msprobe.core.common.const import FileCheckConst, MonitorConst from msprobe.core.common.file_utils import change_mode, create_directory, write_df_to_csv +from msprobe.core.monitor.anomaly_processor import AnomalyDataFactory, AnomalyTurbulence, AnomalyScanner from msprobe.pytorch.common.log import logger -class ScanRule(ABC): - name = "ScanRule" - - def apply(self, history, cur): - raise NotImplementedError("abstract method apply is not implemented") - - -class AnomalyTurbulence(ScanRule): - name = "AnomalyTurbulence" - - def __init__(self, threshold) -> None: - self.threshold = threshold - - def apply(self, history, cur): - baseline = st.mean(history) if isinstance(history, list) else history - - up_bound = baseline + baseline * self.threshold - if baseline > 0: - return cur > up_bound - else: - return cur < up_bound - - -class AnomalyScanner: - - @staticmethod - def load_rules(specs: List[dict]): - """ - specs: [{"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.5}}] - """ - if specs is None: - return [] - alert_rules = [] - for spec in specs: - # 使用get方法获取键值,如果键不存在则返回None - rule_cls_name = spec.get("rule_name") - rule_args = spec.get("args") - - # 检查必要的键是否存在 - if rule_cls_name is None or rule_args is None: - logger.warning(f"Spec is missing required keys: {spec}") - continue - - cur_module = sys.modules.get(__name__) - try: - rule_cls = getattr(cur_module, rule_cls_name) - except AttributeError: - logger.error(f"Rule class '{rule_cls_name}' not found in the current module.") - continue - - try: - rule_instance = rule_cls(**rule_args) - alert_rules.append(rule_instance) - except Exception as e: - logger.error(f"Error creating instance of rule '{rule_cls_name}': {e}") - continue - - return alert_rules - - @staticmethod - def scan(scan_rules: List[ScanRule], history, cur): - anomaly = False - for rule in scan_rules: - anomaly = rule.apply(history, cur) - if anomaly: - return anomaly, rule.name - return anomaly, None - - class BCOLORS: HEADER = '\033[95m' OKBLUE = '\033[94m' @@ -111,130 +39,6 @@ class BCOLORS: UNDERLINE = '\033[4m' -class AnomalyDataFactory(ABC): - def __init__(self, rank, pp_stage, group_mates): - super().__init__() - self.rank = rank - self.pp_stage = pp_stage - self.group_mates = group_mates - self.micro_step = 0 - self.name2callid = {} - - def set_call_id(self, name2callid): - """根据当前GradContext信息更新call_id vpp_stage等信息 - """ - self.name2callid = name2callid - - def create(self, tag, message, step): - """如果检查出异常, 调用当前接口生成GradAnomalyData实例 - tag (tuple): metric tag ('0:1.post_attention_norm.weight/rank0/pre_grad', 'min') - message (str): anomaly detect message - step (int): training step - """ - if not isinstance(tag, tuple) or len(tag) != 2: - raise ValueError("tag must be a tuple with length 2") - tag_name = tag[0] - param_name = tag_name.split('/')[0] - call_id = self.name2callid.get(tag_name, -1) - if MonitorConst.NAME_SEP in param_name: - vpp_stage = int(param_name.split(MonitorConst.NAME_SEP)[0]) - else: - vpp_stage = 0 - - return GradAnomalyData( - self.rank, - step, - self.micro_step, - self.pp_stage, - vpp_stage, - call_id, - tag_name, - message, - self.group_mates - ) - - -class TrainStage: - DEFAULT_STAGE = -1 - FORWARD_STAGE = 0 - BACKWARD_STAGE = 1 - OPTIMIZER_STAGE = 2 - - -FORWARD_KEY = [MonitorConst.ACTV] -BACKWARD_KEY = [MonitorConst.ACTVGRAD, MonitorConst.PRE_GRAD, - MonitorConst.POST_GRAD, MonitorConst.ACC_GRAD] -OPTIMIZER_KEY = [MonitorConst.EXP_AVG, MonitorConst.EXP_AVG_SQ] -TRAIN_STAGE = { - **{key_: TrainStage.FORWARD_STAGE for key_ in FORWARD_KEY}, - **{key_: TrainStage.BACKWARD_STAGE for key_ in BACKWARD_KEY}, - **{key_: TrainStage.OPTIMIZER_STAGE for key_ in OPTIMIZER_KEY} -} - - -@dataclass(eq=True) -class GradAnomalyData: - rank: int = 0 - step: int = 0 - micro_step: int = 0 - pp_stage: int = 0 - vpp_stage: int = 0 - call_id: int = 0 - tag_name: str = field(default=None, compare=False) - message: str = field(default="", compare=False) - group_mates: list = field(default=None, compare=False) - - def __lt__(self, other): - """ - 自定义比较函数,用于确定 GradAnomalyData 实例之间的顺序。 - 比较规则为: - step 和 micro_step 值越小优先级越高; - vpp 和 pp 在前向阶段值越小优先级越高,在非前向阶段值越大优先级越高; - call_id 值越小优先级越高。 - """ - if not isinstance(other, GradAnomalyData): - return NotImplemented - - self_train_stage = self.get_train_stage(self.tag_name) - other_train_stage = self.get_train_stage(other.tag_name) - - def vpp_pp_comparator(anomaly): - """ - Determine the priority rule for vpp and pp based on train stage - Forward stage prefers smaller vpp and pp - Other stages prefer larger vpp and pp - """ - if self_train_stage == TrainStage.FORWARD_STAGE: - return anomaly.vpp_stage, anomaly.pp_stage - else: - return -anomaly.vpp_stage, -anomaly.pp_stage - - self_cmp = [self.step, self.micro_step, self_train_stage, *vpp_pp_comparator(self), self.call_id] - other_cmp = [other.step, other.micro_step, other_train_stage, *vpp_pp_comparator(other), other.call_id] - return self_cmp < other_cmp - - def __le__(self, other): - if not isinstance(other, GradAnomalyData): - return NotImplemented - return self == other or self < other - - @staticmethod - def get_train_stage(tag_name): - """ - :param tag_name: "0:fc2.input:0/rank0/actv", "0:fc1.weight/rank0/post_grad", "0:fc2.weight/rank0/exp_avg_sq" - :return: int, if forward return 0; if backward return 1; if optimizer return 2 - """ - key_ = tag_name.split("/")[-1] - return TRAIN_STAGE.get(key_, TrainStage.DEFAULT_STAGE) - - def to_dict(self): - return self.__dict__ - - def get_key(self): - # 0:1.self_attention.core_attention_flash_0/rank0/input_grad - return ''.join([str(self.tag_name), "_step_", str(self.step), "_call_", str(self.call_id)]) - - @dataclass class WriterInput: path: str @@ -253,6 +57,41 @@ class BaseWriterWithAD: self.anomaly_factory = writer_input.anomaly_factory self.anomalies = [] self.ndigits = writer_input.ndigits + self.beta = 0.99 + + @staticmethod + def stack_tensors(tensor_list): + """ + Torch not support stack cpu and xpu tensors. Group the tensors into cpu_group and xpu_group, + stack them separately, migrate xpu_group to cpu, and then restore in the order of input. + + :param tensor_list: [tensor(-1.6165), tensor(-1.0985), tensor(-1.7777), tensor(-1.8408, device='npu:0')] + :return: result: list of float + """ + cpu_tensors = [] + xpu_tensors = [] + + for tensor in tensor_list: + if isinstance(tensor, torch.Tensor) and tensor.device.type != 'cpu': + # 将device上的tensor先stack后to cpu + xpu_tensors.append(tensor) + else: + cpu_tensors.append(tensor) + + xpu_stack = torch.stack(xpu_tensors).cpu() if xpu_tensors else torch.tensor([]) + + # 按照输入的顺序恢复 + result = [] + cpu_tensors_idx, xpu_tensors_idx = 0, 0 + for tensor in tensor_list: + if isinstance(tensor, torch.Tensor) and tensor.device.type != 'cpu': + result.append(xpu_stack[xpu_tensors_idx]) + xpu_tensors_idx += 1 + else: + result.append(cpu_tensors[cpu_tensors_idx]) + cpu_tensors_idx += 1 + + return result def get_anomalies(self): """返回已检测到的异常列表 @@ -271,12 +110,17 @@ class BaseWriterWithAD: Returns: None """ - detected = False - if self.ad_rules: - avg = self._update_tag2scalars(tag, scalar_value) - detected, rule_name = self._ad(scalar_value, history=avg) + if not self.ad_rules or tag[-1] in ["shape", "dtype"]: + return + if isinstance(scalar_value, torch.Tensor): + scalar_value = scalar_value.item() + avg = self._update_tag2scalars(tag, scalar_value) + detected, rule_name = self._ad(scalar_value, history=avg) if detected: - exception_message = f"Rule {rule_name} reports anomaly signal in {tag} at step {global_step}." + if rule_name == AnomalyTurbulence.name and tag[-1] not in ["norm", "mean"]: + return + exception_message = (f"Rule {rule_name} reports anomaly signal in {tag} at step {global_step}, " + f"current value {scalar_value}, history mean {avg}.") logger.info(f"{BCOLORS.WARNING}> {exception_message}{BCOLORS.ENDC}") # append to self.anomalies for dump if self.anomaly_factory: @@ -291,15 +135,15 @@ class BaseWriterWithAD: tensors.extend(op2tensor.values()) if not tensors: return - + n_slices = len(tensors) // MonitorConst.SLICE_SIZE with torch.no_grad(): for i in range(n_slices + 1): begin = i * MonitorConst.SLICE_SIZE - end = (i+1) * MonitorConst.SLICE_SIZE + end = (i + 1) * MonitorConst.SLICE_SIZE if begin == len(tensors): continue - metric_list = torch.stack(tensors[begin:end]).cpu() + metric_list = self.stack_tensors(tensors[begin:end]) for tag, metric in zip(tags[begin:end], metric_list): self.add_scalar(tag, metric, step) @@ -319,11 +163,11 @@ class BaseWriterWithAD: Returns: float: The average value before update. """ + abs_scalar_value = abs(scalar_value) if tag not in self.tag2scalars: - self.tag2scalars[tag] = {'avg': scalar_value, 'count': 0} + self.tag2scalars[tag] = {'avg': abs_scalar_value, 'count': 0} avg = self.tag2scalars[tag]['avg'] - new_avg = (avg * self.tag2scalars[tag]['count'] + scalar_value) / (self.tag2scalars[tag]['count'] + 1) - self.tag2scalars[tag]['avg'] = new_avg + self.tag2scalars[tag]['avg'] = self.beta * avg + (1 - self.beta) * abs_scalar_value self.tag2scalars[tag]['count'] += 1 return avg @@ -364,7 +208,6 @@ class CSVWriterWithAD(BaseWriterWithAD): new_line = name.split(MonitorConst.NAME_SEP) + metric_value new_line.insert(2, step) new_data.append(new_line) - new_data = pd.DataFrame(new_data).round(self.ndigits).fillna("nan") write_df_to_csv(new_data, filepath, mode='a+', header=False) self.context_dict = defaultdict(list) @@ -376,13 +219,19 @@ class CSVWriterWithAD(BaseWriterWithAD): super().add_scalar(tag, scalar_value, global_step) name = tag[0].split('/')[0] - self.context_dict[name].append(scalar_value.item()) + if isinstance(scalar_value, torch.Tensor): + value = scalar_value.item() + elif isinstance(scalar_value, torch.Size): + value = list(scalar_value) + else: + value = scalar_value + self.context_dict[name].append(value) - def write_metrics(self, ops, metric_value, step, prefix=''): + def write_metrics(self, ops, metric_value, step, prefix='', **kwargs): super().write_metrics(ops, metric_value, step, prefix='') - if prefix in [MonitorConst.ACTV, MonitorConst.ACTVGRAD]: - self.header = MonitorConst.CSV_HEADER_XY + ops + if prefix in [MonitorConst.ACTV, MonitorConst.ACTVGRAD] or kwargs.get("use_micro_step"): + self.header = MonitorConst.CSV_HEADER_MICRO_STEP + ops else: self.header = MonitorConst.CSV_HEADER + ops self.write_csv(prefix, step) diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/distributed/wrap_distributed.py b/debug/accuracy_tools/msprobe/pytorch/monitor/distributed/wrap_distributed.py index b2fa26a58e702120fcabd5d82f8e1e0ed27f3bc4..c209fdba97fa9a4a153516d340892fbefbf0284f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/distributed/wrap_distributed.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/distributed/wrap_distributed.py @@ -24,6 +24,7 @@ import torch.nn as nn from msprobe.core.common.const import MonitorConst from msprobe.core.common.file_utils import load_yaml from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name +from msprobe.pytorch.common.log import logger try: import torch_npu @@ -37,6 +38,7 @@ WrapDistributedOps = load_yaml(OpsPath).get("distributed", []) StackBlackListPath = os.path.join(os.path.dirname(__file__), "stack_blacklist.yaml") StackBlackList = load_yaml(StackBlackListPath).get("stack", []) +MAX_STRING_LENGTH = 1000 distributed_func = {} for f in dir(dist): @@ -139,6 +141,8 @@ def get_process_group(process_group): def stack_filter(stack): + if len(stack) > MAX_STRING_LENGTH: + logger.warning(f'The character string contains more than {MAX_STRING_LENGTH}. re match is skipped.') for pattern in StackBlackList: if re.search(pattern, stack): return False @@ -188,10 +192,12 @@ def update_data(old, new): def is_target_line(codeline): - stack = get_callstack() - whole_stack = ';'.join(stack) if codeline == []: return True + stack = get_callstack() + whole_stack = ';'.join(stack) + if len(whole_stack) > MAX_STRING_LENGTH: + logger.warning(f'The character string contains more than {MAX_STRING_LENGTH}. re match is skipped.') for pattern in codeline: if re.search(pattern, whole_stack): return True diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/features.py b/debug/accuracy_tools/msprobe/pytorch/monitor/features.py index 81c029d401f9194688d332ac711d6065f126ce6a..960f3dabe32acb8c8ed3018f7421747131e02984 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/features.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/features.py @@ -45,13 +45,18 @@ def get_max(x: torch.tensor): @torch.no_grad() def get_zeros(x: torch.tensor, eps: float): + if x.numel() == 0: + return torch.tensor(float('nan')) return torch.sum(torch.abs(x) < eps) / x.numel() @torch.no_grad() def get_sign_matches(x: torch.tensor, y: torch.tensor): + if y.numel() == 0: + return torch.tensor(1.) xs = x.sign() ys = y.sign() + try: same_direction_ratio = ((xs * ys).sum() / ys.numel() + 1) / 2 except RuntimeError as e: @@ -106,3 +111,97 @@ def cal_histc(tensor_cal, bins_total, min_val, max_val): @torch.no_grad() def get_nans(t): return torch.isnan(t).sum() + + +def check_tensor_dim(tensor, n): + """检查张量维度是否大于n + """ + if not isinstance(tensor, torch.Tensor): + raise TypeError( + f"Input must be a PyTorch tensor. Got {type(tensor)} instead. " + f"Consider using torch.tensor() for conversion." + ) + + if tensor.dim() < n: + raise ValueError( + f"Tensor must have at least {n} dimensions. " + f"Got shape: {tuple(tensor.shape)} with {tensor.dim()} dims." + ) + + +@torch.no_grad() +def max_eigenvalue(input_tensor: torch.Tensor, num_iterations=3): + input_tensor = input_tensor.float() + try: + check_tensor_dim(input_tensor, 2) + except (TypeError, ValueError) as e: + logger.warning(f"Calculate max eigenvalue failed: {e}") + return torch.tensor(0) + in_features = input_tensor.shape[1] + u_tensor = torch.randn(in_features).to(input_tensor.device) + u_norm = u_tensor.norm() + if u_norm.item() == 0: + return torch.tensor(0) + u_tensor = u_tensor / u_tensor.norm() + input_seq = torch.matmul(input_tensor.T, input_tensor) + for _ in range(num_iterations): + v_tensor = torch.matmul(input_seq, u_tensor) + spectral_norm = torch.matmul(v_tensor.T, u_tensor) + v_norm = v_tensor.norm() + if v_norm > 0: + u_tensor = v_tensor / v_norm + else: + spectral_norm = torch.tensor(0) + break + return spectral_norm.sqrt() + + +@torch.no_grad() +def cal_entropy(qk_tensor, mask=None): + try: + check_tensor_dim(qk_tensor, 2) + except (TypeError, ValueError) as e: + logger.warning(f"Calculate max eigenvalue failed: {e}") + return torch.tensor(0), torch.tensor(0) + if mask is None: + mask = torch.tril(torch.ones(qk_tensor.shape[1], qk_tensor.shape[1])).to( + qk_tensor.device) + qk_tensor = qk_tensor - torch.amax(qk_tensor, dim=1, keepdim=True) + qk_tensor = qk_tensor.masked_fill(mask == 0, float('-inf')) + softmax_qkt = torch.nn.functional.softmax(qk_tensor.float(), dim=1) + # softmax取QK矩阵最大值 + softmax_max = torch.mean(torch.amax(softmax_qkt, dim=1)) + entropy = torch.mean(-torch.nansum(softmax_qkt * + torch.log(softmax_qkt), dim=1)) + return entropy, softmax_max + + +@torch.no_grad() +def cal_qkt(q_h, k_h, order="s,b,h,d"): + # q_h shape is [s, b, h, d] + try: + check_tensor_dim(q_h, 4) + check_tensor_dim(k_h, 4) + except (TypeError, ValueError) as e: + logger.warning(f"Calculate qk tensor failed: {e}") + return torch.tensor(0) + + if order == "s,b,h,d": + qkt = torch.matmul( + q_h[:, 0, 0, :], k_h[:, 0, 0, :].t()) / q_h.shape[-1] ** 0.5 + elif order == "b,s,h,d": + qkt = torch.matmul( + q_h[0, :, 0, :], k_h[0, :, 0, :].t()) / q_h.shape[-1] ** 0.5 + else: + logger.warning("Calculate qk tensor failed: Order unsupported.") + qkt = torch.tensor(0) + return qkt + + +@torch.no_grad() +def cal_stable_rank(weight: torch.Tensor): + eig = max_eigenvalue(weight) + if eig == torch.tensor(0): + return torch.tensor(0), torch.tensor(0) + f_norm = torch.norm(weight, p="fro") + return f_norm / eig, eig diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py index d0285564d3cb5c00b69933db3259b7c3339c443d..0a2cd447c9fc32fb8e6ebcb2aa1baab05577c637 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py @@ -15,6 +15,7 @@ import json import os import uuid +import importlib from collections import defaultdict from datetime import datetime from functools import partial @@ -22,26 +23,30 @@ from functools import partial import pytz import torch import torch.distributed as dist +import pandas as pd from torch.utils.hooks import BackwardHook from msprobe.core.common.const import MonitorConst, Const -from msprobe.core.common.file_utils import load_json, save_json +from msprobe.core.common.file_utils import load_json, save_json, make_dir +from msprobe.core.common.decorator import recursion_depth_decorator +from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFactory, AnomalyDataWriter +from msprobe.core.common.file_utils import write_df_to_csv +from msprobe.core.common.utils import analyze_api_call_stack +from msprobe.core.monitor.utils import validate_config, validate_ops, \ + get_output_base_dir, get_target_output_dir, chmod_tensorboard_dir, validate_set_monitor from msprobe.pytorch.common.log import logger from msprobe.pytorch.common.utils import is_recomputation -from msprobe.pytorch.monitor.anomaly_analyse import AnomalyDataWriter -from msprobe.pytorch.monitor.anomaly_detect import AnomalyScanner, SummaryWriterWithAD, AnomalyDataFactory, \ - CSVWriterWithAD, BaseWriterWithAD, WriterInput +from msprobe.pytorch.monitor.utils import get_param_struct +from msprobe.pytorch.monitor.data_writers import SummaryWriterWithAD, CSVWriterWithAD, BaseWriterWithAD, WriterInput from msprobe.pytorch.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \ get_process_group -from msprobe.pytorch.monitor.features import get_sign_matches +from msprobe.pytorch.monitor.features import get_sign_matches, cal_qkt from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name, \ - TensorMetrics, squash_param_name -from msprobe.pytorch.monitor.module_spec_verifier import validate_config_spec + TensorMetrics, squash_param_name, get_entropy_metric, get_sr_metric from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory -from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, \ - get_output_base_dir, get_target_output_dir from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer + torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' if not torch_version_above_or_equal_2: raise ValueError("monitor require torch>=2.0") @@ -52,6 +57,7 @@ FORMAT_MAPPING = { MonitorConst.CSV: CSVWriterWithAD, MonitorConst.API: BaseWriterWithAD } +start_step = 0 def param_is_not_tensor_parallel_duplicate(param, tp_group): @@ -71,43 +77,24 @@ class ModuleHookContext: self.actvgrad = [] self.module_name = module_name self.struct = {} - self.format_by_arg = {} - self.verified = False - self.focused_in_col = 0 - self.focused_out_col = 0 - - def set_format_by_arg(self, key_name: str, target_config: dict): - """ 按照监控对象配置format_by_arg - 1) module_name 在 target 中配置监控对象 - 2) module_name 未在 targets 中配置,且 all_xy 全量监控 - 3) module_name 未在 targets 中配置,且 all_xy 未全量监控 - - :param key_name: str, one of [input, output, input_grad, output_grad] - :param target_config: target obj in config json. - :return: - """ - cared = target_config.get(self.module_name, self.struct) - if key_name in cared: - target_module_config = cared[key_name] - if isinstance(target_module_config, dict): - # current cared is self.struct, monitor all data for module_name - self.format_by_arg[key_name] = target_module_config.get('config') - elif isinstance(target_module_config, str): - # current cared is target_config[self.module_name] - self.format_by_arg[key_name] = target_module_config - else: - logger.warning_on_rank_0(f"target module config error, result maybe empty." - f"module_name: {self.module_name}, key_name: {key_name}") - self.format_by_arg[key_name] = None - else: - self.format_by_arg[key_name] = self.struct.get(key_name).get('config') + self.stack = "" def reset(self): self.actv.clear() self.actvgrad.clear() -start_step = 0 +class FeatureHookContext: + def __init__(self, module_name): + self.step = 0 + self.micro_step = 0 + self.attention_feature = {} + self.linear_feature = {} + self.module_name = module_name + + def reset(self): + self.attention_feature.clear() + self.linear_feature.clear() class OptimizerContext: @@ -184,8 +171,8 @@ class TrainerMon: self.params_have_main_grad = params_have_main_grad self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer) self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer) - self.origin_step_func = None - self.origin_start_grad_sync = None + self.fsdp_post_backward_hook = None + self.fsdp2_foreach_reduce = None self.config_timestamp = 0 # 后面有校验时间戳, 首次监控无需为了更新config文件时间戳而去改, 可通过dynamic_on开关直接打开 self.config = load_json(config_file_path) validate_config(self.config) @@ -220,22 +207,25 @@ class TrainerMon: self.dp_group = None self.tp_group = None self.enable_megatron = False + self.enable_deepspeed = False + self.fsdp_wrapped_module = False + self.fsdp2_wrapped_module = False self.micro_batch_number = 1 - self.optimizer_class = None self.optimizer_mon = None self.optimizer_trans = None # TYPE3: 会随着训练中途config配置更新或监控状态改变而重置的变量 self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext) self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext) + self.feature_hook_context_by_module = defaultdict(FeatureHookContext) self.optimizer_context = defaultdict(OptimizerContext) self.cc_context = defaultdict(CommunicationContext) self.grad_context = GradContext() self.handles = defaultdict(list) self.param2name = defaultdict(str) - self.name2index = defaultdict() self.name2indices = defaultdict() self.name2param = {} + self.origin2squash = {} self.duplicate_param = {} self.name2tag = {} self.param_name_call_id = {} @@ -246,6 +236,8 @@ class TrainerMon: self.optimizer_hooked = False self.param_registered = False self.struct_printed = False + self.pre_step_hooks = [] + self.post_step_hooks = [] # 动静态区分 self.dynamic_enable = os.getenv("DYNAMIC_MONITOR", 'False').lower() == 'true' @@ -294,6 +286,18 @@ class TrainerMon: cc_tensor.reset() return metrics + @staticmethod + def get_linear_hook_target(module): + if isinstance(module, torch.nn.Embedding): + return '' + if hasattr(module, "num_embeddings") or hasattr(module, "vocab_start_index"): + return '' + for weight_name in ["weight", "wg"]: + if hasattr(module, weight_name) and isinstance(getattr(module, weight_name), torch.Tensor): + if getattr(module, weight_name).dim() == 2: + return weight_name + return '' + def set_config(self): logger.info(f"current config: {self.config}") self.start_step = self.config.get("start_step", 0) @@ -316,6 +320,10 @@ class TrainerMon: self.param_distribution = self.config.get("param_distribution", False) self.mg_direction = self.config.get('mg_direction', False) self.cc_distribution = self.config.get("cc_distribution", {}) + self.stack_info = self.config.get('stack_info', False) + self.monitor_mbs_grad = self.config.get('monitor_mbs_grad', False) + self.recording_l2_features = self.config.get("recording_l2_features", False) + self.sa_order = self.config.get("sa_order", "s,b,h,d") if not self.cc_distribution.get('enable', False): self.cc_log_only = False @@ -324,8 +332,6 @@ class TrainerMon: self.cc_log_only = self.cc_distribution.get('cc_log_only', False) self.cc_logged_stack = defaultdict(set) self.cc_pre_hook = self.cc_distribution.get('cc_pre_hook', False) - self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self)) - api_register.redirect_api() self.common_info() @@ -338,11 +344,11 @@ class TrainerMon: # 初始化writer, 创建输出目录 if self.format not in FORMAT_MAPPING: - logger.error(f"Unsupported format: {self.format}, use default format: {MonitorConst.CSV}") + logger.warning(f"Unsupported format: {self.format}, use default format: {MonitorConst.CSV}") self.format = MonitorConst.CSV if self.ur_distribution and self.format != 'tensorboard': - logger.error("can only set ur_distribution when format is 'tensorboard', cancel ur_distribution") + logger.warning("can only set ur_distribution when format is 'tensorboard', cancel ur_distribution") self.ur_distribution = False writer = FORMAT_MAPPING[self.format] @@ -376,6 +382,8 @@ class TrainerMon: logger.info_on_rank_0("> momentum and variance of adam is not monitored. ") if not self.wg_distribution: logger.info_on_rank_0("> weight grad of specified module is not monitored. ") + if not self.recording_l2_features: + logger.info_on_rank_0("> l2 features of specified module is not monitored. ") if not self.mg_direction: logger.info_on_rank_0('> grad and momentum direction will not be compared.') if not self.cc_distribution.get('enable', False): @@ -405,13 +413,14 @@ class TrainerMon: start_iteration=0 ): """External interface""" + grad_acc_steps, start_iteration = validate_set_monitor(grad_acc_steps, start_iteration) global start_step start_step = start_iteration logger.info(f'grad acc steps {grad_acc_steps}') self.micro_batch_number = grad_acc_steps self.dp_group = dp_group self.tp_group = tp_group - self.optimizer_mon, self.optimizer_class = OptimizerMonFactory.create_optimizer_mon(optimizer) + self.optimizer_mon = OptimizerMonFactory.create_optimizer_mon(optimizer) self.hook_step_final(optimizer) if not isinstance(model, list): model = [model] @@ -427,6 +436,9 @@ class TrainerMon: self.hook_optimizer(optimizer) self._patch_grad_sync() self.hook_modules() + if self.cc_distribution.get('enable', False): + self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self)) + api_register.redirect_api() self.monitoring = True def adhoc_check(self, target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list): @@ -437,25 +449,48 @@ class TrainerMon: return self.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank) - def build_tbtag_tensor_map(self, module_name, tag, tensor): - key = get_summary_writer_tag_name(module_name, tag, self.rank) - self._register_param_call_id("_hook_module", key) - return {key: tensor} + def build_tbtag_tensor_map(self, module_name, suffix, tag, tensor): + """ + :param module_name: str of module name + :param suffix: + :param tag: + :param tensor: torch.tensor or tuple/list of torch.tensor + :return: tensor_map + """ + tensor_map = {} + if isinstance(tensor, torch.Tensor): + tensor = [tensor] + if isinstance(tensor, tuple) or isinstance(tensor, list): + if len(tensor) == 1: + key = get_summary_writer_tag_name(module_name + suffix, tag, self.rank) + self.register_param_call_id("_hook_module", key) + tensor_map[key] = tensor[0] + else: + for i, tensor_i in enumerate(tensor): + key = get_summary_writer_tag_name(module_name + f"_{i}" + suffix, tag, self.rank) + self.register_param_call_id("_hook_module", key) + tensor_map[key] = tensor_i + return tensor_map def generate_param_map(self, tag, param_tensor): metrics = {} for name in self.param2name.values(): key = get_summary_writer_tag_name(name, tag, self.rank) - self._register_param_call_id("optimizer_pre_step_hook", key) + self.register_param_call_id("optimizer_pre_step_hook", key) if name not in param_tensor or param_tensor[name] is None: continue metrics[key] = param_tensor[name] return metrics - def generate_param_metrics(self, opt_context): + def generate_param_metrics(self, opt_context, stage=MonitorConst.PRE_PARAM): if not self.param_distribution: return - get_metrics(self.ops, self.name2param, self.eps, opt_context.param_metric) + tag2param = { + self.name2tag.get(name, {}).get(stage): param + for name, param in self.name2param.items() + if param.numel() != 0 + } + get_metrics(self.ops, tag2param, self.eps, opt_context.param_metric) def generate_mv_metrics(self, opt_context): if not self.mv_distribution: @@ -467,28 +502,22 @@ class TrainerMon: get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric) get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric) - def generate_wgrad_metrics(self): + def generate_wgrad_metrics(self, post_grad_dict): if not self.wg_distribution: return {}, {} if self.weight_hooked: get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric) - grad_dict = {} - for param, name in self.param2name.items(): - if self.duplicate_param.get(name, False): - continue - grad = param.main_grad if self.params_have_main_grad else param.grad - if grad is None: - logger.warning(f"grad is None: {name}, maybe something wrong happened.") - continue - tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD) - self._register_param_call_id("hook_optimizer", tag) - grad_dict[tag] = grad + get_metrics(self.ops, post_grad_dict, self.eps, self.grad_context.post) + reduced_grad = self.grad_context.post + + if self.weight_hooked: + unreduced_grad = self.grad_context.acc_metric + else: + unreduced_grad = self.grad_context.pre - get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post) - unreduced_grad = self.grad_context.acc_metric if self.weight_hooked else self.grad_context.pre - return self.grad_context.post, unreduced_grad + return reduced_grad, unreduced_grad def generate_xy_metrics(self): actv = {} @@ -514,6 +543,17 @@ class TrainerMon: def write_adhoc_check(self, step): self.tensor_metrics.flush(self.summary_writer) + def write_stack_info(self): + stack_data = [] + header = ["module_name", "stack_info"] + stack_data.append(header) + for _, fwd_context in self.module_fwd_hook_context_by_module.items(): + stack_data.append([fwd_context.module_name, fwd_context.stack]) + filepath = os.path.join(self.tensorboard_dir, f'stack_info.csv') + if not os.path.exists(filepath): + data_frame = pd.DataFrame(columns=stack_data) + write_df_to_csv(data_frame, filepath) + def write_xy_tb(self, step): if not self.xy_distribution: return @@ -525,10 +565,31 @@ class TrainerMon: if self.grad_context.actv: self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, MonitorConst.ACTVGRAD) + def write_metrics_if_not_empty(self, features, metrics, step, hook_name): + if not features or len(features) == 0: + return + use_micro_step = hook_name not in ["linear_hook"] + self.summary_writer.write_metrics(metrics, features, step, hook_name, use_micro_step=use_micro_step) + features.clear() + + def write_features_tb(self, step): + if not self.recording_l2_features: + return + for context in self.feature_hook_context_by_module.values(): + num_features = len(context.attention_feature) + len(context.linear_feature) + if num_features == 0: + continue + self.write_metrics_if_not_empty(context.attention_feature, ["entropy", "softmax_max"], + step, "attention_hook") + self.write_metrics_if_not_empty(context.linear_feature, ["sr", "kernel_norm"], step, "linear_hook") + def write_param_tb(self, opt_context): if not self.param_distribution: return - self.summary_writer.write_metrics(self.ops, opt_context.param_metric, opt_context.step, MonitorConst.PARAM) + param_metrics = {k: v for k, v in opt_context.param_metric.items() if MonitorConst.PRE_PARAM in k} + updated_param_metrics = {k: v for k, v in opt_context.param_metric.items() if MonitorConst.POST_PARAM in k} + self.summary_writer.write_metrics(self.ops, param_metrics, opt_context.step, MonitorConst.PRE_PARAM) + self.summary_writer.write_metrics(self.ops, updated_param_metrics, opt_context.step, MonitorConst.POST_PARAM) def write_mv_tb(self, opt_context): if not self.mv_distribution: @@ -542,10 +603,11 @@ class TrainerMon: if not self.wg_distribution: return - if self.enable_megatron: - self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced') + if self.weight_hooked: + self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced', + use_micro_step=self.monitor_mbs_grad) else: - self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced') + self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced') self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced') def hook_optimizer(self, optimizer): @@ -567,21 +629,23 @@ class TrainerMon: # skip generate metrics if context.step < self.start_step or (context.step - self.start_step) % self.step_interval != 0: return - if MonitorConst.DEEPSPEED_ZERO_OPT_FILTER in self.optimizer_class: # use deepspeed with zero1/2/3 - if not self.name2indices: - self.name2indices = self.optimizer_mon.get_param_index(self.param2name, self.name2index, optimizer) - mv_result = self.optimizer_mon.fetch_mv(self, optimizer, self.param2name, self.name2indices) - self.param2name = mv_result.grad - else: - mv_result = self.optimizer_mon.fetch_mv(self, optimizer, self.param2name) - context.param_exp_avg = mv_result.exp_avg - context.param_exp_avg_sq = mv_result.exp_avg_sq - context.param_adam_update = mv_result.update - context.param_adam_ratio = mv_result.ratio - self.generate_wgrad_metrics() + grad_dict = {} + if self.wg_distribution: + grad_dict = self.optimizer_mon.fetch_grad(self, self.param2name) + + mv_result = None + if self.mv_distribution or self.ur_distribution or self.mg_direction: + mv_result = self.optimizer_mon.fetch_mv(self, self.param2name) + if mv_result: + context.param_exp_avg = mv_result.exp_avg + context.param_exp_avg_sq = mv_result.exp_avg_sq + context.param_adam_update = mv_result.update + context.param_adam_ratio = mv_result.ratio + + _, _ = self.generate_wgrad_metrics(grad_dict) self.generate_mv_metrics(context) - self.generate_param_metrics(context) + self.generate_param_metrics(context, MonitorConst.PRE_PARAM) tbtag_tensor_map = {} if self.mg_direction: @@ -609,17 +673,15 @@ class TrainerMon: context.metric_dict = metric_dict return - def patch_step(func, optimizer): - def wrapper(*args, **kwargs): - optimizer_pre_step_hook(optimizer, args, kwargs) - out = func(*args, **kwargs) - return out - return wrapper + def optimizer_post_step_hook(optimizer, args, kwargs): + context = self.optimizer_context[optimizer] + self.generate_param_metrics(context, MonitorConst.POST_PARAM) if self.optimizer_hooked: return - optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer) + self.pre_step_hooks.append(optimizer_pre_step_hook) + self.post_step_hooks.append(optimizer_post_step_hook) self.optimizer_hooked = True return @@ -649,6 +711,7 @@ class TrainerMon: validate_config(config) self.config = config self.set_config() + self.start_step = context.step # 动态启停时不受原start_step影响,永远从下一步开始 logger.warning(f"config is updated at step{context.step - 1}, " f"will start new hook at step{context.step}.") except Exception as e: @@ -674,10 +737,17 @@ class TrainerMon: if self.anomaly_data_factory: self.anomaly_data_factory.set_call_id(self.param_name_call_id) self.write_xy_tb(context.step) + self.write_features_tb(context.step) self.write_grad_tb(context.step) self.write_mv_tb(context) self.write_param_tb(context) self.write_adhoc_check(context.step) + if self.stack_info: + self.write_stack_info() + self.stack_info = False + for handle in self.handles["stack"]: + handle.remove() + self.handles["stack"].clear() if self.ur_distribution: for param_name, _ in context.param_adam_update.items(): @@ -696,6 +766,9 @@ class TrainerMon: if self.anomaly_data_factory: self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies()) self.summary_writer.clear_anomalies() + + if self.format == MonitorConst.TENSORBOARD: + chmod_tensorboard_dir(self.tensorboard_dir) self.call_id = 0 self.param_name_call_id.clear() @@ -707,13 +780,16 @@ class TrainerMon: def patch_step(func, optimizer): def wrapper(*args, **kwargs): + for hook in self.pre_step_hooks: + hook(optimizer, args, kwargs) out = func(*args, **kwargs) + for hook in self.post_step_hooks: + hook(optimizer, args, kwargs) step_final_hook(optimizer, args, kwargs) return out return wrapper optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer) - self.origin_step_func = optimizer.__class__.step return def hook_modules(self): @@ -731,10 +807,12 @@ class TrainerMon: vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}' targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[ 'targets'].keys() - hooked_count += self._hook_module(targets, model_chunk, vpp_stage) + l2_target_names = self.config.get('l2_targets', '') + hooked_count += self._hook_module(targets, l2_target_names, model_chunk, vpp_stage) logger.info_on_rank_0(f"> {hooked_count} modules are monitored.") + @recursion_depth_decorator('msprobe.pytorch.monitor.clone_if_tensor') def clone_if_tensor(args): if isinstance(args, tuple): return tuple([clone_if_tensor(arg) for arg in args]) @@ -756,11 +834,24 @@ class TrainerMon: BackwardHook.setup_output_hook = wrap_hook_setup(BackwardHook.setup_output_hook) return + def register_param_call_id(self, hook_name: str, key: str): + """ + :param hook_name: + :param key: str, '0:relu_0/output_grad' + :return: + """ + logger.debug(f"{hook_name} {key}: {self.call_id}") + self.param_name_call_id[key] = self.call_id + self.call_id += 1 + def _remove_all_hooks(self, optimizer): # 清空hook handle for handle in self.handles['xy']: handle.remove() self.handles['xy'].clear() + for handle in self.handles['L2_features']: + handle.remove() + self.handles['L2_features'].clear() # 清空对应context缓存 for _, fwd_context in self.module_fwd_hook_context_by_module.items(): fwd_context.reset() @@ -768,27 +859,23 @@ class TrainerMon: bwd_context.reset() self.grad_context.reset() # 权重梯度和激活值梯度都在这 - if self.origin_start_grad_sync: # megatron - try: - from megatron.core.distributed.param_and_grad_buffer import Bucket - Bucket.start_grad_sync = self.origin_start_grad_sync - logger.info("remove Bucket start_grad_sync") - except ImportError: - pass - try: - from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup - _ParamAndGradBucketGroup.start_grad_sync = self.origin_start_grad_sync - logger.info("remove _ParamAndGradBucketGroup start_grad_sync") - except ImportError: - pass - else: # not megatron + self.optimizer_mon.restore_grad_sync(self) + if self.fsdp_post_backward_hook: # fsdp + torch.distributed.fsdp._runtime_utils._post_backward_hook = self.fsdp_post_backward_hook + logger.info("remove patch_post_backward_hook in fsdp.") + if self.fsdp2_foreach_reduce: # fsdp2 + torch.distributed.fsdp._fully_shard._fsdp_collectives.foreach_reduce = self.fsdp2_foreach_reduce + importlib.reload(torch.distributed.fsdp._fully_shard._fsdp_param_group) + logger.info("remove patch_foreach_reduce_hook in fsdp2.") + else: # not megatron and not fsdp for handle in self.handles['wgrads']: handle.remove() self.handles['wgrads'].clear() self.weight_hooked = False if self.optimizer_hooked: - optimizer.__class__.step = self.origin_step_func + self.pre_step_hooks.clear() + self.post_step_hooks.clear() for _, context in self.optimizer_context.items(): context.reset() @@ -797,12 +884,12 @@ class TrainerMon: for handle in self.handles['cc']: handle.remove() self.handles['cc'].clear() + api_register.restore_api() for _, context in self.cc_context.items(): context.reset() # 清空节点缓存 self.param2name.clear() - self.name2index.clear() self.name2indices.clear() self.name2param.clear() self.duplicate_param.clear() @@ -841,14 +928,11 @@ class TrainerMon: logger.info(msg) def _save_module_struct(self): - save_module_struct = (not dist.is_initialized() - or (self.module_rank_list and dist.get_rank() == min(self.module_rank_list)) - or (not self.module_rank_list and dist.get_rank() == 0)) - - if save_module_struct: - module_struct_file = os.path.realpath(os.path.join(get_output_base_dir(), 'module_struct.json')) - save_json(module_struct_file, self.module_struct, indent=2) - logger.info(f"> save module struct to {module_struct_file}") + output_dir = os.path.join(get_output_base_dir(), 'module_struct', f'rank{self.rank}') + make_dir(output_dir) + module_struct_file = os.path.realpath(os.path.join(output_dir, 'module_struct.json')) + save_json(module_struct_file, self.module_struct, indent=2) + logger.info(f"> save module struct to {module_struct_file}") self.struct_printed = True def _is_target_param(self, param_name, param, prefix): @@ -856,33 +940,41 @@ class TrainerMon: squash_name = prefix + squash_param_name(param_name, self.squash_name) for target in self.config['targets'].keys(): if param_name.startswith(target) or squash_name.startswith(target) or name.startswith(target): - setattr(param, "zero_out_wgrad", True) return True return False def _register_chunk(self, model_chunk, prefix): - index = 0 for (param_name, param) in model_chunk.named_parameters(): if not param.requires_grad: continue + if not self.fsdp_wrapped_module and param_name.startswith("_fsdp_wrapped_module"): + self.fsdp_wrapped_module = True + if not self.fsdp2_wrapped_module and param.__class__.__name__ == "DTensor": + self.fsdp2_wrapped_module = True if self._is_target_param(param_name, param, prefix): name = prefix + squash_param_name(param_name, self.squash_name) if name in self.param2name.values(): name = prefix + param_name self.param2name[param] = name self.name2param[name] = param - self.name2index[name] = index + self.origin2squash[param_name] = name if self.tp_group and not param_is_not_tensor_parallel_duplicate(param, self.tp_group): self.duplicate_param[name] = True if self.dp_group and param_is_data_parallel_duplicate(self.dp_group): self.duplicate_param[name] = True + + keywords = [ + MonitorConst.PRE_GRAD, + MonitorConst.POST_GRAD, + MonitorConst.PRE_PARAM, + MonitorConst.POST_PARAM + ] self.name2tag[name] = { - MonitorConst.PRE_GRAD: get_summary_writer_tag_name(name, MonitorConst.PRE_GRAD, self.rank), - MonitorConst.POST_GRAD: get_summary_writer_tag_name(name, MonitorConst.POST_GRAD, self.rank) + k: get_summary_writer_tag_name(name, k, self.rank) + for k in keywords } - index += 1 def _register_param_name(self): for vpp_stage, model_chunk in enumerate(self.model): @@ -900,16 +992,35 @@ class TrainerMon: return pattern return "" - def _hook_module(self, target_names, module: torch.nn.Module, vpp_stage=''): + def _is_recording_module(self, module_name, l2_targets, vpp_stage, hook_name): + + if len(l2_targets) > 0: + for pattern in [ + vpp_stage + squash_param_name(module_name, self.squash_name), + vpp_stage + module_name, + ]: + if pattern in l2_targets: + return pattern + elif hook_name in ["linear_hook"]: + return vpp_stage + squash_param_name(module_name, self.squash_name) + return "" + + def _hook_module(self, target_names, l2_target_names, module: torch.nn.Module, vpp_stage=''): if '_modules' not in module.__dict__: # nothing to hook return 0 - def fwd_hook_fun(module, module_input, module_output, name): + def fwd_hook_fun(module, args, kwargs, module_output, name): if not module.training or is_recomputation(): # 1 only monitor training stage. # 2 when open recompute, skip recomputed forward stage. return + + module_input = [tensor for tensor in args if torch.is_tensor(tensor)] + if kwargs: + kwargs_tensors = [tensor for tensor in kwargs.values() if torch.is_tensor(tensor)] + module_input.extend(kwargs_tensors) + if module not in self.module_fwd_hook_context_by_module: self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name) context: ModuleHookContext = self.module_fwd_hook_context_by_module[module] @@ -918,34 +1029,20 @@ class TrainerMon: Const.INPUT: get_param_struct(module_input), Const.OUTPUT: get_param_struct(module_output) } + if self.print_struct: self.module_struct[context.module_name].update(context.struct) return - if not context.format_by_arg: - context.set_format_by_arg(Const.INPUT, self.config['targets']) - context.set_format_by_arg(Const.OUTPUT, self.config['targets']) - if not context.format_by_arg: - return - if not context.verified: - context.focused_in_col = validate_config_spec(context.format_by_arg[Const.INPUT], - module_input, context.module_name, - Const.INPUT) - context.focused_out_col = validate_config_spec(context.format_by_arg[Const.OUTPUT], - module_output, context.module_name, - Const.OUTPUT) - context.verified = True - # expect output be tensor type + tbtag_tensor_map = {} - cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col] tbtag_tensor_map.update( self.build_tbtag_tensor_map( - f'{context.module_name}.{Const.INPUT}{MonitorConst.NAME_SEP}{context.micro_step}', - MonitorConst.ACTV, cared_input)) - cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col] + f'{context.module_name}.{Const.INPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}', + MonitorConst.ACTV, module_input)) tbtag_tensor_map.update( self.build_tbtag_tensor_map( - f'{context.module_name}.{Const.OUTPUT}{MonitorConst.NAME_SEP}{context.micro_step}', - MonitorConst.ACTV, cared_output)) + f'{context.module_name}.{Const.OUTPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}', + MonitorConst.ACTV, module_output)) get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv) context.micro_step += 1 @@ -963,31 +1060,17 @@ class TrainerMon: if self.print_struct: self.module_struct[context.module_name].update(context.struct) return - if not context.format_by_arg: - context.set_format_by_arg(MonitorConst.INPUT_GRAD, self.config['targets']) - context.set_format_by_arg(MonitorConst.OUTPUT_GRAD, self.config['targets']) - if not context.format_by_arg: - return - if not context.verified: - context.focused_in_col = validate_config_spec( - context.format_by_arg[MonitorConst.INPUT_GRAD], - input_grad, context.module_name, MonitorConst.INPUT_GRAD) - context.focused_out_col = validate_config_spec( - context.format_by_arg[MonitorConst.OUTPUT_GRAD], - output_grad, context.module_name, MonitorConst.OUTPUT_GRAD) - context.verified = True tbtag_tensor_map = {} - cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col] tbtag_tensor_map.update( self.build_tbtag_tensor_map( - f'{context.module_name}.{Const.INPUT}{MonitorConst.NAME_SEP}{context.micro_step}', - MonitorConst.ACTV, cared_input_grad)) - cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col] + f'{context.module_name}.{Const.INPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}', + MonitorConst.ACTVGRAD, input_grad)) + tbtag_tensor_map.update( self.build_tbtag_tensor_map( - f'{context.module_name}.{Const.OUTPUT}{MonitorConst.NAME_SEP}{context.micro_step}', - MonitorConst.ACTV, cared_output_grad)) + f'{context.module_name}.{Const.OUTPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}', + MonitorConst.ACTVGRAD, output_grad)) if context.micro_step == 0 and context.actvgrad: logger.warning(f"actvgrad context of {context.module_name} is not empty when first micro_step, " @@ -1001,17 +1084,85 @@ class TrainerMon: context.micro_step = 0 return + def extract_attention_feature_hook(module, module_input, module_output, name): + if is_recomputation() or not module.training: + return + + if module not in self.feature_hook_context_by_module: + self.feature_hook_context_by_module[module] = FeatureHookContext(name) + context: FeatureHookContext = self.feature_hook_context_by_module[module] + tbtag_tensor_map = {} + if len(module_input) < 2: + logger.warning( + f"Length of module_input in attention hook ({name}) is {len(module_input)}, " + "expected >= 2. Skipping feature extraction for this module." + ) + return + q_h = module_input[0] + k_h = module_input[1] + qkt = cal_qkt(q_h, k_h, order=self.sa_order) + tbtag_tensor_map.update( + self.build_tbtag_tensor_map(f'{context.module_name}.attention', + f'{MonitorConst.NAME_SEP}{context.micro_step}', 'qkt', qkt) + ) + get_entropy_metric(tbtag_tensor_map, context.attention_feature) + + context.micro_step += 1 + if context.micro_step == self.micro_batch_number: + context.micro_step = 0 + context.step += 1 + return + + def extract_linear_sr_hook(module, module_input, module_output, name): + if is_recomputation() or not module.training: + return + weight_name = self.get_linear_hook_target(module) + if weight_name == '': + return + + if module not in self.feature_hook_context_by_module: + self.feature_hook_context_by_module[module] = FeatureHookContext(name) + context: FeatureHookContext = self.feature_hook_context_by_module[module] + + if context.micro_step == (self.micro_batch_number - 1): + tbtag_tensor_map = {} + value = getattr(module, weight_name).data + tbtag_tensor_map.update( + self.build_tbtag_tensor_map(f'{context.module_name}.linear', + '', 'sr', value) + ) + get_sr_metric(tbtag_tensor_map, context.linear_feature) + + context.micro_step += 1 + if context.micro_step == self.micro_batch_number: + context.micro_step = 0 + context.step += 1 + return + + def stack_hook(module, args, kwargs, module_output, name): + if module not in self.module_fwd_hook_context_by_module: + self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name) + context: ModuleHookContext = self.module_fwd_hook_context_by_module[module] + context.stack = analyze_api_call_stack(name) + return + if self.backward_only and self.forward_only: logger.warning('not enable backward_only and forward_only simultaneously') hooked_count = 0 - if self.xy_distribution or self.print_struct: - for module_name, submodule in module.named_modules(): - name = self._is_target_module(module_name, target_names, vpp_stage) - if not name: - continue + for module_name, submodule in module.named_modules(): + if self.stack_info: + name = vpp_stage + squash_param_name(module_name, self.squash_name) + handle = submodule.register_forward_hook(partial(stack_hook, name=name), with_kwargs=True) + self.handles['stack'].append(handle) + name = self._is_target_module(module_name, target_names, vpp_stage) + if not name: + continue + if submodule.__class__.__name__ == "FullyShardedDataParallel": + continue + if self.xy_distribution or self.print_struct: if not self.backward_only: - handle = submodule.register_forward_hook(partial(fwd_hook_fun, name=name)) + handle = submodule.register_forward_hook(partial(fwd_hook_fun, name=name), with_kwargs=True) self.handles['xy'].append(handle) if not self.forward_only and not self.has_register_backward_hook(name, submodule): handle = submodule.register_full_backward_hook(bwd_hook_fun) @@ -1019,91 +1170,141 @@ class TrainerMon: self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name) logger.info_on_rank_0(f"> {name} is monitored successfully") hooked_count += 1 + if not self.print_struct and self.recording_l2_features: + for module_name, submodule in module.named_modules(): + func_map = { + "attention_hook": extract_attention_feature_hook, + "linear_hook": extract_linear_sr_hook, + } + for hook_name in func_map.keys(): + if hook_name not in l2_target_names: + continue + temp_names = l2_target_names[hook_name] + name = self._is_recording_module(module_name, temp_names, vpp_stage, hook_name) + if name: + handle = submodule.register_forward_hook(partial(func_map[hook_name], name=name)) + print_feature_name = hook_name.split('_')[0] + logger.info_on_rank_0( + f'> {print_feature_name} features of {name} is monitored successfully') + self.handles["L2_features"].append(handle) + hooked_count += 1 + continue + return hooked_count def _patch_grad_sync(self): - def patch_sync(sync_grad_func): - def wrapper(bucket): + if not self.wg_distribution: + return + if self.fsdp_wrapped_module: + # patch fsdp _runtime_utils._post_backward_hook + self._patch_fsdp_post_backward_hook() + return + + if self.fsdp2_wrapped_module: + # patch fsdp2 _fully_shard._fsdp_collectives.foreach_reduce + self._patch_fsdp2_foreach_reduce() + return + + if self.monitor_mbs_grad: + self._hook_weights() + return + + self.optimizer_mon.patch_grad_sync(self) + + if self.enable_megatron or self.enable_deepspeed: + return + + # default hook weights + self._hook_weights() + + def _patch_fsdp_post_backward_hook(self): + """ + FSDP runtime 需要处理整个forward和backward计算和通信的流程,通过override nn.Module的forward,定义相应的逻辑。 + 对AccumulateGrad对象注册hook,可以在backward计算grad后立刻执行,在reduce_scatter操作前采集梯度累计后,通信聚合前的梯度。 + 每个forward阶段,fsdp对AccumulateGrad重复注册hook方法,monitor工具内注册hook无法生效, + 因此对_post_backward_hook进行patch,在backward后,reduce_scatter前采集梯度。 + """ + def patch_post_backward_hook(_post_backward_hook): + def wrapper(state, handle, *unused): grad_dict = {} - # Megatron between core_r0.6.0 and core_r0.8.0, this bucket is Bucket. - # When megatron is core_r0.9.0, this bucket is _ParamAndGradBucketGroup. - # In megatron version core_r0.9.0, func start_grad_sync from Bucket moved to _ParamAndGradBucketGroup. - bucket_params_id_list = [id(params) for params in bucket.params] + offset = 0 for param, name in self.param2name.items(): - if id(param) not in bucket_params_id_list: - continue - grad = param.main_grad if self.params_have_main_grad else param.grad - if grad is None: - logger.warning(f"grad is None: {name}, maybe something wrong happened.") + limit = param.numel() + if not limit: continue + grad = handle.flat_param.grad[offset:offset + limit] + offset += limit tag = self.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD) if tag is None: continue grad_dict[tag] = grad - self._register_param_call_id("sync_grad_func", tag) + self.register_param_call_id("_post_backward_hook", tag) get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre) - out = sync_grad_func(bucket) + out = _post_backward_hook(state, handle, *unused) return out - return wrapper - if not self.wg_distribution: - return + logger.info("Patch fsdp _post_backward_hook, collect pre_grad metrics.") + self.fsdp_post_backward_hook = torch.distributed.fsdp._runtime_utils._post_backward_hook + torch.distributed.fsdp._runtime_utils._post_backward_hook = \ + patch_post_backward_hook(torch.distributed.fsdp._runtime_utils._post_backward_hook) - try: - from megatron.core.distributed.param_and_grad_buffer import Bucket - self.origin_start_grad_sync = Bucket.start_grad_sync - Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync) - self.enable_megatron = True - logger.info("megatron version is >= core_r0.6.0 <= core_r0.8.0") - except ImportError: - self.enable_megatron = False + def _patch_fsdp2_foreach_reduce(self): + def patch_foreach_reduce(foreach_reduce): + def wrapper(fsdp_params, unsharded_grads, *unused): + grad_dict = {} + for param, grad in zip(fsdp_params, unsharded_grads): + tag = self.name2tag.get(self.origin2squash[param._param_fqn], {}).get(MonitorConst.PRE_GRAD) + if tag is None: + continue + grad_dict[tag] = grad + self.register_param_call_id("foreach_reduce", tag) + get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre) + out = foreach_reduce(fsdp_params, unsharded_grads, *unused) + return out + return wrapper - try: - from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup - self.origin_start_grad_sync = _ParamAndGradBucketGroup.start_grad_sync - _ParamAndGradBucketGroup.start_grad_sync = patch_sync(_ParamAndGradBucketGroup.start_grad_sync) - self.enable_megatron = True - logger.info("megatron version is > core_r0.8.0 <= core_r0.9.0") - except ImportError: - self.enable_megatron = False - - if not self.enable_megatron: - self._hook_weights() + logger.info("Patch fsdp2 foreach_reduce, collect pre_grad metrics.") + import torch.distributed.fsdp._fully_shard._fsdp_param_group as _fsdp_param_group + import torch.distributed.fsdp._fully_shard._fsdp_collectives as _fsdp_collectives + self.fsdp2_foreach_reduce = _fsdp_collectives.foreach_reduce + _fsdp_collectives.foreach_reduce = patch_foreach_reduce(_fsdp_collectives.foreach_reduce) + importlib.reload(_fsdp_param_group) # 关键操作,不然会因为torch一开始就import foreach_reduce导致patch失效 def _hook_weights(self): + """ + 遍历参数的梯度生成函数(grad_acc),并挂载hook,以便在该参数所有梯度计算后,采集通信聚合前梯度数据。 + """ context = self.grad_context @torch.no_grad - def param_hook(*args, context_dict, param, key, name): + def param_hook(*args, context_dict, param, name): + key = name + if self.monitor_mbs_grad: + key += f'{MonitorConst.NAME_SEP}{param.micro_step}' + + key = get_summary_writer_tag_name(key, 'acc_grad', self.rank) + self.register_param_call_id("param_hook", key) param.micro_step += 1 - self._register_param_call_id("param_hook", key) - if param.micro_step == self.micro_batch_number: - param.micro_step = 0 + + if self.monitor_mbs_grad or (param.micro_step == self.micro_batch_number): if self.params_have_main_grad: - context_dict[key] = param.main_grad.clone() + grad = param.main_grad else: - context_dict[key] = param.grad.clone() + grad = param.grad + context_dict[key] = grad.clone() + + if param.micro_step == self.micro_batch_number: + param.micro_step = 0 logger.info("hooking weights.") for param, name in self.param2name.items(): - key = get_summary_writer_tag_name(name, 'acc_grad', self.rank) setattr(param, 'micro_step', 0) param_tmp = param.expand_as(param) grad_acc = param_tmp.grad_fn.next_functions[0][0] handle = grad_acc.register_hook( - partial(param_hook, context_dict=context.acc, param=param, key=key, name=name)) + partial(param_hook, context_dict=context.acc, param=param, name=name)) self.grad_accs.append(grad_acc) self.handles['wgrads'].append(handle) self.weight_hooked = True - - def _register_param_call_id(self, hook_name: str, key: str): - """ - :param hook_name: - :param key: str, '0:relu_0/output_grad' - :return: - """ - logger.debug(f"{hook_name} {key}: {self.call_id}") - self.param_name_call_id[key] = self.call_id - self.call_id += 1 diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py index 87963812006413a90fd33bc70d6172a7c73c3f10..24e9b43d6378e37f09d3c654511c840e9168f6f6 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py @@ -17,6 +17,7 @@ import re import torch from msprobe.pytorch.monitor.features import get_max, get_min, get_zeros, get_nans, get_norm, get_mean +from msprobe.pytorch.monitor.features import cal_entropy, cal_stable_rank from msprobe.pytorch.monitor.utils import get_nan_tensor @@ -31,7 +32,8 @@ def squash_param_name(param_name, enable=True): if not enable: return param_name name = '' - for pattern in ['layers?\.(.*)', 'embeddings?\.(.*)', 'final.*', 'output.*', 'norm.*']: + for pattern in ['^.*\.(layers?\..*)', '^.*\.(embeddings?\..*)', '^.*\.(final.*)', '^.*\.(output.*)', + '^.*\.(norm.*)']: match = re.findall(pattern, param_name) if match: name += match[0] @@ -143,6 +145,20 @@ class IdentMetric(Metric): return tensor +@register_config_metric("shape") +class ShapeMetric(Metric): + @staticmethod + def get_metric_value(tensor, eps): + return tensor.shape + + +@register_config_metric("dtype") +class DtypeMetric(Metric): + @staticmethod + def get_metric_value(tensor, eps): + return tensor.dtype + + def get_metrics(ops, tag2tensor, eps, out_dict=None): """ :param ops: ["op1", "op2"] @@ -170,3 +186,27 @@ def get_metrics(ops, tag2tensor, eps, out_dict=None): fun_metric = config_metric_registry.get(metric_name) out_dict[tag][metric_name] = fun_metric.get_metric(tensor, eps) return out_dict + + +def get_sr_metric(tag2tensor, out_dict=None): + if out_dict is None: + out_dict = {} + for tag, tensor in tag2tensor.items(): + if "sr" not in tag: + continue + if tag not in out_dict: + out_dict[tag] = {} + sr, eig = cal_stable_rank(tensor) + out_dict[tag]['sr'] = sr + out_dict[tag]['kernel_norm'] = eig + + +def get_entropy_metric(tag2tensor, out_dict=None): + if out_dict is None: + out_dict = {} + for tag, tensor in tag2tensor.items(): + if tag not in out_dict: + out_dict[tag] = {} + entropy, softmax_max = cal_entropy(tensor) + out_dict[tag]['entropy'] = entropy + out_dict[tag]['softmax_max'] = softmax_max diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_spec_verifier.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_spec_verifier.py deleted file mode 100644 index 72c35c90bf9540a31cfa1176274a3d2c66bc8946..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_spec_verifier.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re -import abc -import torch - -from msprobe.pytorch.common.log import logger - -# 用于存储所有validator实现类的注册表 -config_validator_registry = {} - - -def register_config_validator(cls): - """装饰器 用于注册ConfigValidator的实现类""" - config_validator_registry[cls.__name__] = cls - return cls - - -class ConfigValidator(metaclass=abc.ABCMeta): - @abc.abstractmethod - def check_pattern_match(self, config_spec: str): - pass - - @abc.abstractmethod - def validate(self, actual_data, module_name: str, data_type: str, pattern_match): - pass - - -@register_config_validator -class TensorValidator(ConfigValidator): - def check_pattern_match(self, config_spec: str): - pattern = re.compile(r"tensor") - return pattern.match(config_spec) - - def validate(self, actual_data, module_name: str, data_type: str, pattern_match): - if not torch.is_tensor(actual_data): - raise ValueError( - f"Format of {module_name} {data_type} does not match the required format 'tensor' in config.") - - -@register_config_validator -class TupleValidator(ConfigValidator): - def check_pattern_match(self, config_spec: str): - pattern = re.compile(r"tuple\[(\d+)\]:?(\d+)?") - return pattern.match(config_spec) - - def validate(self, actual_data, module_name: str, data_type: str, pattern_match): - length, index = pattern_match.groups() - if index is None: - index = 0 - length, index = int(length), int(index) - - if not (0 <= index < length): - raise ValueError( - f"Format of {module_name} {data_type} in config.json does not match the required format 'tuple[x]:y'." - f"y must be greater than or equal to 0 and less than x.") - if not isinstance(actual_data, tuple): - raise ValueError( - f"Type of {module_name} {data_type} does not match spec of config.json, should be tuple, please check.") - if len(actual_data) != length: - raise ValueError( - f"Length of {module_name} {data_type} does not match spec of config.json, should be {length}, " - f"actual is {len(actual_data)} please check.") - return index - - -def validate_config_spec(config_spec: str, actual_data, module_name: str, data_type: str): - focused_col = None - if not config_spec or not isinstance(config_spec, str): - return focused_col - for _, validator_cls in config_validator_registry.items(): - config_validator = validator_cls() - pattern_match = config_validator.check_pattern_match(config_spec) - if pattern_match: - try: - focused_col = config_validator.validate(actual_data, module_name, data_type, pattern_match) - except ValueError as e: - logger.warning(f"config spec validate failed: {str(e)}") - return focused_col - logger.warning(f"config spec in {module_name} {data_type} not supported, " - f"expected spec:'tuple\[(\d+)\]:(\d+)' or 'tensor', actual spec: {config_spec}.") - return focused_col diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py b/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py index 602514836d2531ad4a6be3a23f56bc3b942ba199..2a5a95841579866c8b979111a7de4f4f751eceae 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py @@ -12,151 +12,196 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from collections import defaultdict +from abc import abstractmethod import torch -import torch.distributed as dist from msprobe.pytorch.common.log import logger -from msprobe.pytorch.monitor.utils import MVResult, MVGradResult +from msprobe.core.monitor.utils import MVResult +from msprobe.pytorch.monitor.module_metric import get_metrics +from msprobe.core.common.const import MonitorConst class OptimizerMon(object): - def __init__(self) -> None: + def __init__(self, torch_opt) -> None: self.fp16_to_fp32_param = {} - self.is_stage3 = False + self.torch_opt = torch_opt + self.state = {} + self.origin_funcs = [] + self.bucket_class = None + + def narrow_from_flatten(self, param, flatten_state): + return flatten_state + + def get_state(self, torch_opt): + if hasattr(torch_opt, 'chained_optimizers'): + for opt in torch_opt.chained_optimizers: + self._get_single_state(opt) + else: + self._get_single_state(torch_opt) - def fetch_mv(self, monitor, torch_opt, params2name): - pass + def fetch_grad(self, monitor, params2name): + if not self.fp16_to_fp32_param: + self.map_fp16_to_fp32_param(self.torch_opt) - def _fetch_mv_in_adam(self, monitor, torch_opt, params2name): - exp_avg_dict = defaultdict(float) - exp_avg_sq_dict = defaultdict(float) - update_dict = defaultdict() - ratio_dict = defaultdict() + grad_dict = {} + first_param = True for param, name in params2name.items(): - if param in self.fp16_to_fp32_param: - param = self.fp16_to_fp32_param[param] - - if param in torch_opt.state: - state_param = torch_opt.state.get(param, None) - exp_avg = state_param.get("exp_avg", None) - exp_avg_sq = state_param.get("exp_avg_sq", None) - if exp_avg is None or exp_avg_sq is None: - logger.warning(f"exp_avg or exp_avg_sq of {name} is None, maybe something wrong happened.") - continue + if monitor.duplicate_param.get(name, False): + continue + if self.fp16_to_fp32_param and param not in self.fp16_to_fp32_param: + continue + grad = param.main_grad if monitor.params_have_main_grad else param.grad + if grad.__class__.__name__ == 'DTensor': + grad = grad.to_local() + element_in_cur_partition = self.fp16_to_fp32_param.get(param, param).numel() + if param.numel() != element_in_cur_partition: + if first_param: + grad = grad.flatten()[-element_in_cur_partition:] + else: # supposed to be the last one + grad = grad.flatten()[:element_in_cur_partition] + first_param = False + + if grad is None: + if not monitor.fsdp_wrapped_module: + logger.warning(f"grad is None: {name}, maybe something wrong happened.") + continue + tag = monitor.name2tag.get(name, {}).get(MonitorConst.POST_GRAD) + monitor.register_param_call_id("hook_optimizer", tag) + grad_dict[tag] = grad + return grad_dict + + def map_fp16_to_fp32_param(self, torch_opt): + pass + + def fetch_mv(self, monitor, params2name): + if not self.fp16_to_fp32_param: + self.map_fp16_to_fp32_param(self.torch_opt) + if not self.state: + self.get_state(self.torch_opt) + + exp_avg_dict = {} + exp_avg_sq_dict = {} + update_dict = {} + ratio_dict = {} + + if not self.state: + logger.warning('optimizer state can not accessed') + return MVResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict) + + for lp_param, name in params2name.items(): + if lp_param in self.fp16_to_fp32_param: + hp_param = self.fp16_to_fp32_param[lp_param] + else: + hp_param = lp_param + + if hp_param in self.state: + state_param = self.state.get(hp_param, {}) + exp_avg = self.narrow_from_flatten(lp_param, state_param.get("exp_avg", None)) + exp_avg_sq = self.narrow_from_flatten(lp_param, state_param.get("exp_avg_sq", None)) if monitor.mv_distribution: exp_avg_dict[name] = exp_avg exp_avg_sq_dict[name] = exp_avg_sq if monitor.mg_direction: exp_avg_dict[name] = exp_avg if monitor.ur_distribution: - if len(torch_opt.param_groups) > 1: - logger.info(f"the length of torch_opt.param_groups is {len(torch_opt.param_groups)}.") + if len(self.torch_opt.param_groups) > 1: + logger.info(f"the length of torch_opt.param_groups is {len(self.torch_opt.param_groups)}.") if 'step' in state_param: step = state_param['step'] # Optimizer from pytorch or FusedAdam from apex(used by megatron) - elif 'step' in torch_opt.param_groups[0]: - step = torch_opt.param_groups[0]['step'] # AdamW from mindspeed + elif 'step' in self.torch_opt.param_groups[0]: + step = self.torch_opt.param_groups[0]['step'] # AdamW from mindspeed else: logger.warning(f"step of {name} is None, maybe something wrong happened.") continue - exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step) - exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step) - update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + torch_opt.defaults['eps']) + if exp_avg is None or exp_avg_sq is None: + logger.warning(f"exp_avg or exp_avg_sq of {name} is None, skip calculation.") + continue + exp_avg_hat = exp_avg / (1 - self.torch_opt.defaults['betas'][0] ** step) + exp_avg_sq_hat = exp_avg_sq / (1 - self.torch_opt.defaults['betas'][1] ** step) + update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + self.torch_opt.defaults['eps']) ratio_dict[name] = exp_avg_hat / torch.sqrt(exp_avg_sq_hat) monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name]) monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name]) return MVResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict) - - def _fetch_mv_grad_in_adam(self, monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat): - exp_avg_dict = defaultdict(float) - exp_avg_sq_dict = defaultdict(float) - update_dict = defaultdict() - ratio_dict = defaultdict() - param2name = defaultdict() - fp32_partitioned_groups_flat_grad = defaultdict() - partition_id = dist.get_rank() - - def get_flatten_grad(self, optimizer, group_idx): - if fp32_partitioned_groups_flat[group_idx].grad is None: - if partition_id == dist.get_world_size() - 1 and not self.is_stage3: - fp32_partitioned_groups_flat_grad = optimizer.flatten_dense_tensors_aligned( - optimizer.averaged_gradients[group_idx], - int(optimizer.partition_size[group_idx]) - ).to(fp32_partitioned_groups_flat[group_idx].dtype) - else: - fp32_partitioned_groups_flat_grad = optimizer.flatten( - optimizer.averaged_gradients[group_idx] - ).to(fp32_partitioned_groups_flat[group_idx].dtype) - return fp32_partitioned_groups_flat_grad - else: - return fp32_partitioned_groups_flat[group_idx].grad - - for group_idx in range(len(fp32_partitioned_groups_flat)): - fp32_partitioned_groups_flat_grad[group_idx] = get_flatten_grad(self, torch_opt, group_idx) - - for name in params2name.values(): - start_idx, end_idx, group_idx, group_with_rank = name2indices[name] - if group_with_rank != partition_id and isinstance(group_with_rank, int): - continue - fp32_param = fp32_partitioned_groups_flat[group_idx][start_idx: end_idx] - fp32_param.grad = fp32_partitioned_groups_flat_grad[group_idx][start_idx: end_idx] - param2name[fp32_param] = name - if not torch_opt.state: - continue - state_param = list(torch_opt.state.values())[group_idx] - exp_avg = state_param.get("exp_avg", None) - exp_avg_sq = state_param.get("exp_avg_sq", None) - if exp_avg is None or exp_avg_sq is None: - logger.warning(f"exp_avg or exp_avg_sq of {name} is None, maybe something wrong happened.") - continue - exp_avg = exp_avg[start_idx: end_idx] - exp_avg_sq = exp_avg_sq[start_idx: end_idx] - if monitor.mv_distribution: - exp_avg_dict[name] = exp_avg - exp_avg_sq_dict[name] = exp_avg_sq - if monitor.mg_direction: - exp_avg_dict[name] = exp_avg - if monitor.ur_distribution: - if 'step' in state_param: - step = state_param['step'] # Optimizer from pytorch or FusedAdam from apex(used by megatron) - elif 'step' in torch_opt.param_groups[group_idx]: - step = torch_opt.param_groups[group_idx]['step'] # AdamW from mindspeed - else: - logger.warning(f"step of {name} is None, maybe something wrong happened.") - continue - exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step) - exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step) - update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + torch_opt.defaults['eps']) - ratio_dict[name] = exp_avg_hat / torch.sqrt(exp_avg_sq_hat) - monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name]) - monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name]) - del fp32_partitioned_groups_flat_grad - return MVGradResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict, - grad=param2name) - - -class MixPrecisionOptimizerMon(OptimizerMon): + + def patch_grad_sync(self, monitor): + def patch_sync(sync_grad_func): + def wrapper(bucket): + grad_dict = {} + # Megatron between core_r0.6.0 and core_r0.8.0, this bucket is Bucket. + # When megatron is core_r0.9.0, this bucket is _ParamAndGradBucketGroup. + # In megatron version core_r0.9.0, func start_grad_sync from Bucket moved to _ParamAndGradBucketGroup. + bucket_params_id_list = [id(params) for params in bucket.params] + for param, name in monitor.param2name.items(): + if id(param) not in bucket_params_id_list: + continue + grad = param.main_grad if monitor.params_have_main_grad else param.grad + if grad is None: + logger.warning(f"grad is None: {name}, maybe something wrong happened.") + continue + tag = monitor.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD) + if tag is None: + continue + grad_dict[tag] = grad + monitor.register_param_call_id("sync_grad_func", tag) + get_metrics(monitor.ops, grad_dict, monitor.eps, monitor.grad_context.pre) + out = sync_grad_func(bucket) + return out + + return wrapper + + try: + from megatron.core.distributed.param_and_grad_buffer import Bucket + self.origin_funcs.append(Bucket.start_grad_sync) + self.bucket_class = Bucket + Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync) + monitor.enable_megatron = True + logger.info("megatron version is >= core_r0.6.0 <= core_r0.8.0") + except ImportError: + monitor.enable_megatron = False + + try: + from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup + self.origin_funcs.append(_ParamAndGradBucketGroup.start_grad_sync) + self.bucket_class = _ParamAndGradBucketGroup + _ParamAndGradBucketGroup.start_grad_sync = patch_sync(_ParamAndGradBucketGroup.start_grad_sync) + monitor.enable_megatron = True + logger.info("megatron version is > core_r0.8.0 <= core_r0.9.0") + except ImportError: + monitor.enable_megatron = False | monitor.enable_megatron + + def restore_grad_sync(self, monitor): + if not monitor.enable_megatron: + return + + self.bucket_class.start_grad_sync = self.origin_funcs[0] + + + def _get_single_state(self, torch_opt): + state = {} + if hasattr(torch_opt, 'param_to_cpu_states_map'): + state = torch_opt.param_to_cpu_states_map + elif hasattr(torch_opt, 'state'): + state = torch_opt.state + elif hasattr(torch_opt, 'optimizer') and hasattr(torch_opt.optimizer, 'state'): + state = torch_opt.optimizer.state + self.state.update(state) + + +class MegatronMixPrecisionOptimizerMon(OptimizerMon): """ 混合精度优化器监控类。在混合精度训练中监控和管理优化器。 混合精度训练通过适当降低某些计算的精度来加速训练过程并减少内存消耗。 """ - - def map_fp16_tp_fp32_param(self, torch_opt): + def map_fp16_to_fp32_param(self, torch_opt): for fp16_group, fp32_group in zip(torch_opt.float16_groups, torch_opt.fp32_from_float16_groups): for fp16_param, fp32_param in zip(fp16_group, fp32_group): self.fp16_to_fp32_param[fp16_param] = fp32_param - def fetch_mv(self, monitor, torch_opt, params2name): - if not self.fp16_to_fp32_param and torch_opt is not None: - self.map_fp16_tp_fp32_param(torch_opt) - - return self._fetch_mv_in_adam(monitor, torch_opt, params2name) - class MegatronDistributedOptimizerMon(OptimizerMon): - def map_fp16_tp_fp32_param(self, torch_opt): + def map_fp16_to_fp32_param(self, torch_opt): if not (hasattr(torch_opt, "model_float16_groups") and hasattr(torch_opt, "shard_fp32_from_float16_groups")): raise Exception( @@ -167,141 +212,223 @@ class MegatronDistributedOptimizerMon(OptimizerMon): for fp16_param, shard_fp32_param in zip(fp16_group, shard_fp32_group): self.fp16_to_fp32_param[fp16_param] = shard_fp32_param - def fetch_mv(self, monitor, torch_opt, params2name): - if not self.fp16_to_fp32_param and torch_opt is not None: - self.map_fp16_tp_fp32_param(torch_opt) - - return self._fetch_mv_in_adam(monitor, torch_opt, params2name) +class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon): + def map_fp16_to_fp32_param(self, torch_opt): + for opt in torch_opt.chained_optimizers: + super().map_fp16_to_fp32_param(opt) -class MegatronFP32OptimizerMon(OptimizerMon): - def fetch_mv(self, monitor, torch_opt, params2name): - return self._fetch_mv_in_adam(monitor, torch_opt, params2name) +class MegatronChainedMixPrecisionOptimizerMon(MegatronMixPrecisionOptimizerMon): + def map_fp16_to_fp32_param(self, torch_opt): + for opt in torch_opt.chained_optimizers: + super().map_fp16_to_fp32_param(opt) -class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon): - def fetch_mv(self, monitor, torch_opt, params2name): - if not self.fp16_to_fp32_param and torch_opt is not None: - for opt in torch_opt.chained_optimizers: - self.map_fp16_tp_fp32_param(opt) - if not isinstance(torch_opt, torch.optim.Optimizer): - torch_opt.state = {} - for opt in torch_opt.chained_optimizers: - torch_opt.state.update(opt.optimizer.state) - return self._fetch_mv_in_adam(monitor, torch_opt, params2name) +class DeepSpeedZeroOptimizerMon(OptimizerMon): + """ + Base monitor class for DeepSpeed ZeRO optimizer. + ZeRO stage 0 no partition + ZeRO stage 1 partitions optimizer states across data parallel processes. + ZeRO stage 2 additionally partitions gradients. + ZeRO stage 3 additionally partitions parameters. + + This class provides monitoring capabilities for ZeRO optimizers by: + - Handling gradient collection for different ZeRO stages + - Managing optimizer state access for monitoring + """ + def __init__(self, torch_opt): + super().__init__(torch_opt) + self.stage = '' + self.bit16_groups = [] + self.fp32_flat_groups = [] + self.param2group = () + self.param2index = [] + self.group_offset = {} + + @abstractmethod + def get_grad_for_param(self, lp_param, group_idx, param_id): + raise NotImplementedError + + def param_not_in_partition(self, lp_param, group_idx): + param_slice_mapping = self.torch_opt.state_dict()['param_slice_mappings'][group_idx] + hp_address = param_slice_mapping.get(self.torch_opt.param_names.get(lp_param)) + return hp_address is None + + def get_position(self, lp_param, group_idx): + param_slice_mapping = self.torch_opt.state_dict()['param_slice_mappings'][group_idx] + hp_address = param_slice_mapping.get(self.torch_opt.param_names.get(lp_param)) + return hp_address.start, hp_address.numel + + def get_group_index(self): + param2group = {} + for group_idx, bit16_group in enumerate(self.bit16_groups): + for param in bit16_group: + param2group[param] = group_idx + return param2group + + def get_param_index(self, lp_param, group_idx): + if not self.param2index: + for group in self.bit16_groups: + param2index = {} + for index, param in enumerate(group): + param2index[param] = index + self.param2index.append(param2index) + + return self.param2index[group_idx][lp_param] + + def narrow_from_flatten(self, param, flatten_state): + if flatten_state is None: + return flatten_state + group_idx = self.param2group[param] + if self.param_not_in_partition(param, group_idx): + return None + start, numel = self.get_position(param, group_idx) + return flatten_state.narrow(0, start, numel) + + def map_fp16_to_fp32_param(self, torch_opt): + for group_idx, group in enumerate(self.bit16_groups): + for param in group: + self.fp16_to_fp32_param[param] = self.fp32_flat_groups[group_idx] + + def fetch_grad(self, monitor, params2name): + grad_dict = {} + for lp_param, name in params2name.items(): + group_idx = self.param2group[lp_param] + param_id = self.get_param_index(lp_param, group_idx) + if self.param_not_in_partition(lp_param, group_idx): + continue + if self.stage == '1or2': + param_id = param_id - self.group_offset[group_idx] - 1 + grad = self.get_grad_for_param(lp_param, group_idx, param_id) + tag = monitor.name2tag.get(name, {}).get(MonitorConst.POST_GRAD) + monitor.register_param_call_id("hook_optimizer", tag) + grad_dict[tag] = grad + + return grad_dict + + def patch_grad_sync(self, monitor): + pass + def restore_grad_sync(self, monitor): + pass -class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon): - def fetch_mv(self, monitor, torch_opt, params2name): - if not self.fp16_to_fp32_param and torch_opt is not None: - for opt in torch_opt.chained_optimizers: - self.map_fp16_tp_fp32_param(opt) - if not isinstance(torch_opt, torch.optim.Optimizer): - torch_opt.state = {} - for opt in torch_opt.chained_optimizers: - torch_opt.state.update(opt.optimizer.state) - return self._fetch_mv_in_adam(monitor, torch_opt, params2name) - - -class DeepSpeedZeroOptimizerStage0Mon(OptimizerMon): - def fetch_mv(self, monitor, torch_opt, params2name): - return self._fetch_mv_in_adam(monitor, torch_opt, params2name) - - -class DeepSpeedZeroOptimizerStage3Mon(OptimizerMon): - def get_param_index(self, params2name, name2index, torch_opt): - fp16_groups = torch_opt.fp16_partitioned_groups - name2indices = defaultdict() - index_length = defaultdict() - index = 0 - idx = 0 - for group_idx, fp16_group in enumerate(fp16_groups): - for param in fp16_group: - param_length = len(param.flatten()) - index_length[idx] = (index, index + param_length, group_idx) - index += param_length - idx += 1 - for _, name in params2name.items(): - idx = name2index[name] - start_idx, end_idx, group_idx = index_length[idx] - name2indices[name] = (start_idx, end_idx, group_idx, None) - return name2indices - - def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None): - self.is_stage3 = True - fp32_partitioned_groups_flat = torch_opt.fp32_partitioned_groups_flat - return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat) - - -class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon): - @staticmethod - def get_group_index(fp32_length, world_size, index): - for i in range(len(fp32_length) - 1): - if fp32_length[i] <= index < fp32_length[i + 1]: - interval_start = fp32_length[i] - interval_length = fp32_length[i + 1] - fp32_length[i] - sub_interval_length = interval_length // world_size - sub_index = (index - interval_start) // sub_interval_length - sub_interval_start = interval_start + sub_index * sub_interval_length - return sub_interval_start, min(sub_index, world_size - 1) - return fp32_length[-1], 0 - - def get_param_index(self, params2name, name2index, torch_opt): - padding = torch_opt.groups_padding - world_size = dist.get_world_size() - fp32_length = [0] - for fp32_group_index, single_partition_of_fp32_group in enumerate(torch_opt.single_partition_of_fp32_groups): - fp32_length.append(len(single_partition_of_fp32_group) * world_size + fp32_length[fp32_group_index]) - - bf16_groups = [] - name2indices = defaultdict() - index_length = defaultdict() - index = 0 - idx = 0 - for group_idx, bf16_group in enumerate(torch_opt.bit16_groups): - bf16_groups.extend(bf16_group) - for param in bf16_group: - param_length = len(param.flatten()) - group_index, group_with_rank = self.get_group_index(fp32_length, world_size, index) - index_length[idx] = (index, index + param_length, group_idx, group_index, group_with_rank) - index += param_length - idx += 1 - group_length = len(bf16_groups) / len(torch_opt.bit16_groups) - for _, name in params2name.items(): - name_index = name2index[name] - start_idx, end_idx, group_idx, group_index, group_with_rank = index_length[name_index] - need_padding = True if group_with_rank == world_size - 1 else False - new_start_idx = start_idx - group_index - new_end_idx = end_idx - group_index - if need_padding and group_length - 1 <= name_index <= len(bf16_groups) - 1 and name_index % ( - group_length - 1) == 0: - new_end_idx -= padding[int(name_index // (group_length - 1) - 1)] - name2indices[name] = (new_start_idx, new_end_idx, group_idx, group_with_rank) - return name2indices - - def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None): - fp32_partitioned_groups_flat = torch_opt.single_partition_of_fp32_groups - return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat) - - -class DummyOptimizerMon(OptimizerMon): - def fetch_mv(self, monitor, torch_opt, params2name): - return self._fetch_mv_in_adam(monitor, torch_opt, params2name) +class DeepSpeedZeroOptimizerStage0Mon(DeepSpeedZeroOptimizerMon): + def __init__(self, torch_opt): + super().__init__(torch_opt) + self.stage = '0' + self.bit16_groups = torch_opt.bf16_groups + self.fp32_flat_groups = torch_opt.fp32_groups_flat_partition + self.param2group = self.get_group_index() + + def get_grad_for_param(self, lp_param, group_idx, param_id): + return self.torch_opt.fp32_groups_gradient_dict[group_idx][param_id] + + +class DeepSpeedZeroOptimizerStage1or2Mon(DeepSpeedZeroOptimizerMon): + def __init__(self, torch_opt): + super().__init__(torch_opt) + self.stage = '1or2' + self.bit16_groups = torch_opt.bit16_groups + self.fp32_flat_groups = torch_opt.single_partition_of_fp32_groups + self.param2group = self.get_group_index() + self.group_offset = {} + self.get_group_offset() + + def get_grad_for_param(self, lp_param, group_idx, param_id): + if getattr(self.torch_opt, "cpu_offload", False): + grads = self.torch_opt.single_partition_of_fp32_groups[group_idx].grad + start, numel = self.get_position(lp_param, group_idx) + grad = grads.narrow(0, start, numel) + else: + grad = self.torch_opt.averaged_gradients[group_idx][param_id] + return grad + + def get_group_offset(self): + for group_idx, group in enumerate(self.bit16_groups): + self.group_offset[group_idx] = -1 + for lp_param in group: + if self.param_not_in_partition(lp_param, group_idx): + self.group_offset[group_idx] = self.get_param_index(lp_param, group_idx) + else: + break + + + def patch_grad_sync(self, monitor): + def patch_sync(reduce_func): + def wrapper(zero_optimizer, *args, **kwargs): + grad_dict = {} + for i, param, _ in zero_optimizer.params_in_ipg_bucket: + if isinstance(param, int): # for ds >= 0.17.0 + param = zero_optimizer.bit16_groups[i][param] + name = monitor.param2name[param] + tag = monitor.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD) + grad_dict[tag] = zero_optimizer.get_gradient_for_reduction(param) + monitor.register_param_call_id("sync_grad_func", tag) + get_metrics(monitor.ops, grad_dict, monitor.eps, monitor.grad_context.pre) + out = reduce_func(zero_optimizer, *args, **kwargs) + return out + + return wrapper + try: + from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer + self.origin_funcs = [ + DeepSpeedZeroOptimizer.average_tensor, + DeepSpeedZeroOptimizer.buffered_reduce_fallback + ] + DeepSpeedZeroOptimizer.average_tensor = patch_sync(DeepSpeedZeroOptimizer.average_tensor) + DeepSpeedZeroOptimizer.buffered_reduce_fallback = \ + patch_sync(DeepSpeedZeroOptimizer.buffered_reduce_fallback) + monitor.enable_deepspeed = True + logger.info('deepspeed enabled') + except Exception as e: + monitor.enable_deepspeed = False | monitor.enable_deepspeed + logger.warning('Seems using deepspeed zero 1 or 2. But patch average tensor failed') + + def restore_grad_sync(self, monitor): + if not monitor.enable_deepspeed: + return + + from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer + DeepSpeedZeroOptimizer.average_tensor = self.origin_funcs[0] + DeepSpeedZeroOptimizer.buffered_reduce_fallback = self.origin_funcs[1] + + + +class DeepSpeedZeroOptimizerStage3Mon(DeepSpeedZeroOptimizerMon): + def __init__(self, torch_opt): + super().__init__(torch_opt) + self.stage = '3' + self.bit16_groups = torch_opt.fp16_groups + self.fp32_flat_groups = torch_opt.fp32_partitioned_groups_flat + self.param2group = self.get_group_index() + + def param_not_in_partition(self, lp_param, group_idx): + """Each param partioned across all zero ranks""" + return False + + def get_position(self, lp_param, group_idx): + param_id = self.torch_opt.get_param_id(lp_param) + return self.torch_opt.grad_position[param_id][1:] + + def get_grad_for_param(self, lp_param, group_idx, param_id): + return self.torch_opt.averaged_gradients[group_idx][param_id] class OptimizerMonFactory: _optimizer_mon_map = { - "FP32Optimizer": MegatronFP32OptimizerMon, - "Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon, + "FP32Optimizer": OptimizerMon, + "Float16OptimizerWithFloat16Params": MegatronMixPrecisionOptimizerMon, "DistributedOptimizer": MegatronDistributedOptimizerMon, + "SwapDistributedOptimizer": MegatronDistributedOptimizerMon, "ChainedDistributedOptimizer": MegatronChainedDistributedOptimizerMon, + "ChainedSwapDistributedOptimizer": MegatronChainedDistributedOptimizerMon, "ChainedFloat16OptimizerWithFloat16Params": MegatronChainedMixPrecisionOptimizerMon, "BF16_Optimizer": DeepSpeedZeroOptimizerStage0Mon, "DeepSpeedZeroOptimizer": DeepSpeedZeroOptimizerStage1or2Mon, "DeepSpeedZeroOptimizer_Stage3": DeepSpeedZeroOptimizerStage3Mon, - "Adam": DummyOptimizerMon + "Adam": OptimizerMon } @staticmethod @@ -310,6 +437,7 @@ class OptimizerMonFactory: optimizer_class = optimizer.__class__.__name__ if optimizer_class == "ChainedOptimizer": optimizer_class = "Chained" + optimizer.chained_optimizers[0].__class__.__name__ + logger.info(f'The optimizer type is {optimizer_class}') - optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(optimizer_class, DummyOptimizerMon) - return optimizer_mon_class(), optimizer_class + optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(optimizer_class, OptimizerMon) + return optimizer_mon_class(optimizer) diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/unittest/test_monitor.py b/debug/accuracy_tools/msprobe/pytorch/monitor/unittest/test_monitor.py deleted file mode 100644 index 4d5c1a717d80ee30414f25b44a93ddc7257ef2c7..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/unittest/test_monitor.py +++ /dev/null @@ -1,160 +0,0 @@ -# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import os -import re -from glob import glob - -import pandas as pd - -from msprobe.pytorch.common.log import logger - - -def parse_logfile(logfile): - grad_norm = [] - step = [] - with open(logfile) as f: - for line in f.readlines(): - if 'consumed samples' in line: - grad_norm.append(float(re.findall('(?<=grad norm\: )[\d\.]*', line)[0])) - return grad_norm - - -def parse_monitor_output(output_dir): - reduced = {} - unreduced = {} - for directory in glob(output_dir + '*'): - rank = int(re.findall('(?<=rank)[\d]*', directory)[0]) - unreduced[rank] = [] - reduced[rank] = [] - for file in os.listdir(directory): - df = pd.read_csv(os.path.join(directory, file)) - if '_unreduced_' in file: - unreduced[rank].append(df) - pass - elif '_reduced_' in file: - reduced[rank].append(df) - else: - logger.info(f'unexpected file {file} in {directory}') - return reduced, unreduced - - -def valid_reduce(reduced, unreduced, tp_size, dp_size, sequence_parallel): - steps = len(reduced[0]) - world_size = len(reduced) - errors = [] - for _, row in unreduced[0][0].iterrows(): - param = row['param_name'] - is_tp_duplicate = False - for step in range(2): - # sum reduced - reduced_mean = 0. - for rank in range(world_size): - if len(reduced[rank]) == 0: - continue - df = reduced[rank][step] - value = list(df[df['param_name'] == param]['mean']) - if not value: - if step == 0: - is_tp_duplicate = True - continue - reduced_mean += value[0] - - # sum unreduced - unreduced_mean = 0. - for rank in range(world_size): - df = unreduced[rank][step] - value = list(df[df['param_name'] == param]['mean']) - if not value: - continue - unreduced_mean += list(df[df['param_name'] == param]['mean'])[0] - - unreduced_mean /= dp_size - if is_tp_duplicate and (not sequence_parallel or 'embedding' in param): - unreduced_mean /= tp_size - try: - assert_equal(unreduced_mean, reduced_mean) - except AssertionError as e: - errors.append([param, step, e, is_tp_duplicate]) - if errors: - logger.info(errors) - else: - logger.info(f'grad mean is in consist between unreduced grad and reduced grad monitord.') - - -def assert_equal(a, b): - if b == 0 or a == 0: - return - if b == 0: - rel_diff = a - elif a == 0: - rel_diff = b - else: - rel_diff = abs(a / b - 1) - assert rel_diff < 0.01, f'{a}, {b}, {rel_diff}' - - -def valid_total_norm(total_norm, reduced, duplicate_embedding): - steps = len(total_norm) - world_size = len(reduced) - errors = [] - for step in range(steps): - calculated_norm = 0. - for rank in range(world_size): - if len(reduced[rank]) == 0: - if step == 0: - logger.info(f'rank {rank} is duplicated in dp group') - continue - for _, row in reduced[rank][step].iterrows(): - if duplicate_embedding and 'word_embedding' in row['param_name']: - continue - calculated_norm += row['norm'] ** 2 - try: - assert_equal(calculated_norm ** 0.5, total_norm[step]) - except AssertionError as e: - errors.append([step, e]) - if errors: - logger.info('total norm errors: ', errors) - else: - logger.info('grad norm in consist between training log and reduced gradients monitored') - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--monitor_output', '-m', type=str, required=True, - help='path prefix to the output of monitor e.g. monitor_output/Aug12_07-16') - parser.add_argument('--logfile', '-l', type=str, required=True, help='path to the training log file') - parser.add_argument('--tp_size', '-t', type=int, required=True, help='tp parallel size') - parser.add_argument('--dp_size', '-d', type=int, required=True, help='dp parallel size') - parser.add_argument('--pp_size', '-p', type=int, required=True, help='pp parallel size') - parser.add_argument('--untie_embeddings_and_output_weights', '-u', action="store_true", default=False, - help='whether untie_embeddings_and_output_weights in pp parallel') - parser.add_argument('--sequence_parallel', '-s', action="store_true", default=False, - help='whether sequence parallel is enabled. Add -s to store true') - - args = parser.parse_args() - - assert args.tp_size > 0, 'if tp not enabled, set tp_size = 1' - assert args.dp_size > 0, 'if tp not enabled, set dp_size = 1' - assert args.pp_size > 0, 'if tp not enabled, set pp_size = 1' - - total_norm = parse_logfile(args.logfile) - reduced, unreduced = parse_monitor_output(args.monitor_output) - - duplicate_embedding = not args.untie_embeddings_and_output_weights and args.pp_size > 1 - - valid_total_norm(total_norm, reduced, duplicate_embedding) - valid_reduce(reduced, unreduced, args.tp_size, args.dp_size, args.sequence_parallel) diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/utils.py b/debug/accuracy_tools/msprobe/pytorch/monitor/utils.py index 94afe56ffcfe7571a189c5f6959b2eb9a2779d81..ca339ad64823c20382910a339a1e4136008d675a 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/utils.py @@ -12,20 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect -from collections import namedtuple -from datetime import timezone, timedelta -from functools import wraps -from datetime import datetime -import os -import re - import torch -from msprobe.core.common.const import MonitorConst, Const from msprobe.pytorch.common.log import logger -from msprobe.core.common.utils import is_int -from msprobe.core.common.file_utils import check_file_or_directory_path device = "cpu" @@ -37,24 +26,6 @@ except ImportError: device = "cuda" NAN_TENSOR_ON_DEVICE = None -FILE_MAX_SIZE = 10 * 1024 * 1024 * 1024 -FILE_NAME_MAX_LENGTH = 255 -DIRECTORY_MAX_LENGTH = 4096 - -beijing_tz = timezone(timedelta(hours=8)) -MVResult = namedtuple('MVResult', ("exp_avg", "exp_avg_sq", "update", "ratio")) -MVGradResult = namedtuple('MVGradResult', ("exp_avg", "exp_avg_sq", "update", "ratio", "grad")) - - -class MsgConst: - """ - Class for log messages const - """ - SPECIAL_CHAR = ["\n", "\r", "\u007F", "\b", "\f", "\t", "\u000B", "%08", "%0a", "%0b", "%0c", "%0d", "%7f"] - - -def get_output_base_dir(): - return os.getenv(MonitorConst.MONITOR_OUTPUT_DIR, MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR) def get_nan_tensor(): @@ -64,16 +35,6 @@ def get_nan_tensor(): return NAN_TENSOR_ON_DEVICE -def filter_special_chars(func): - @wraps(func) - def func_level(msg): - for char in MsgConst.SPECIAL_CHAR: - msg = msg.replace(char, '_') - return func(msg) - - return func_level - - def get_param_struct(param): res = {} if isinstance(param, (tuple, list)): @@ -86,201 +47,4 @@ def get_param_struct(param): else: res['config'] = f'{type(param)}' logger.warning(f'Not support type({type(param)}) now, please check the type of param {param}') - return res - - -def validate_ops(ops): - if not isinstance(ops, list): - raise TypeError("ops should be a list") - valid_ops = [] - for op in ops: - if op not in MonitorConst.OP_LIST: - logger.warning(f"op {op} is not supported. Optional ops: {MonitorConst.OP_LIST}") - continue - valid_ops.append(op) - if not valid_ops: - default_op = MonitorConst.OP_LIST[0] - valid_ops.append(default_op) - logger.info_on_rank_0(f"There is no valid ops, default op {default_op} is used") - return valid_ops - - -def validate_ranks(ranks): - if not isinstance(ranks, list): - raise TypeError("module_ranks should be a list") - for rank in ranks: - if not isinstance(rank, int) or isinstance(rank, bool): - raise TypeError(f"element in module_ranks should be a int, get {type(rank)}") - - -def validate_targets(targets): - if not isinstance(targets, dict): - raise TypeError('targets in config.json should be a dict') - for module_name, field in targets.items(): - if not isinstance(module_name, str): - raise TypeError('key of targets should be module_name[str] in config.json') - if not isinstance(field, dict): - raise TypeError('values of targets should be cared filed e.g. {"input": "tensor"} in config.json') - - -def validate_print_struct(print_struct): - if not isinstance(print_struct, bool): - raise TypeError("print_struct should be a bool") - - -def validate_ur_distribution(ur_distribution): - if not isinstance(ur_distribution, bool): - raise TypeError('ur_distribution should be a bool') - - -def validate_xy_distribution(xy_distribution): - if not isinstance(xy_distribution, bool): - raise TypeError('xy_distribution should be a bool') - - -def validate_wg_distribution(wg_distribution): - if not isinstance(wg_distribution, bool): - raise TypeError('wg_distribution should be a bool') - - -def validate_mg_distribution(mg_distribution): - if not isinstance(mg_distribution, bool): - raise TypeError('mg_distribution should be a bool') - - -def validate_param_distribution(param_distribution): - if not isinstance(param_distribution, bool): - raise TypeError('param_distribution should be a bool') - - -def validate_cc_distribution(cc_distribution): - if not isinstance(cc_distribution, dict): - raise TypeError('cc_distribution should be a dictionary') - for key, value in cc_distribution.items(): - if key == 'enable': - if not isinstance(value, bool): - raise TypeError('cc_distribution enable should be a bool') - elif key == 'cc_codeline': - if not isinstance(value, list): - raise TypeError('cc_distribution cc_codeline should be a list') - elif key == 'cc_pre_hook': - if not isinstance(value, bool): - raise TypeError('cc_distribution cc_pre_hook should be a bool') - elif key == 'cc_log_only': - if not isinstance(value, bool): - raise TypeError('cc_distribution cc_log_only should be a bool') - else: - raise TypeError(f'{key} of cc_distribution is not supported.') - - -def validate_squash_name(squash_name): - if not isinstance(squash_name, bool): - raise TypeError('squash_name should be a bool') - - -def validate_alert(alert): - if not isinstance(alert, dict): - raise TypeError('alert should be a dictionary') - rules = alert.get('rules') - if rules and isinstance(rules, list): - for rule in rules: - rule_name = rule.get("rule_name") - if rule_name and rule_name not in MonitorConst.RULE_NAME: - raise TypeError(f"{rule_name} is not supported") - args = rule.get("args") - if args and isinstance(args, dict): - threshold = args.get("threshold") - if not isinstance(threshold, float) or threshold < 0: - raise TypeError('threshold must be float and not less than 0') - dump = alert.get('dump') - if dump and not isinstance(dump, bool): - raise TypeError('dump must be bool.') - - -def validate_step_count_per_record(step_count_per_record): - if not is_int(step_count_per_record): - raise TypeError('step_count_per_record must be int.') - if step_count_per_record < 1: - raise ValueError("step_count_per_record must greater than 0") - if step_count_per_record > 1e6: - raise ValueError("step_count_per_record must smaller than 1e6") - - -def validate_config(config): - config['ops'] = validate_ops(config.get('ops', [])) - - eps = config.get('eps', 1e-8) - if not isinstance(eps, float): - raise TypeError("eps should be a float") - - ranks = config.get("module_ranks", []) - validate_ranks(ranks) - - targets = config.get("targets", {}) - validate_targets(targets) - - print_struct = config.get('print_struct', False) - validate_print_struct(print_struct) - - ur_distribution = config.get('ur_distribution', False) - validate_ur_distribution(ur_distribution) - - xy_distribution = config.get('xy_distribution', False) - validate_xy_distribution(xy_distribution) - - wg_distribution = config.get('wg_distribution', False) - validate_wg_distribution(wg_distribution) - - mg_distribution = config.get('mg_distribution', False) - validate_mg_distribution(mg_distribution) - - param_distribution = config.get('param_distribution', False) - validate_param_distribution(param_distribution) - - cc_distribution = config.get('cc_distribution', {}) - validate_cc_distribution(cc_distribution) - - alert = config.get('alert', {}) - validate_alert(alert) - - step_count_per_record = config.get('step_count_per_record', 1) - validate_step_count_per_record(step_count_per_record) - - squash_name = config.get('squash_name', True) - validate_squash_name(squash_name) - - if not targets: - if xy_distribution: - config["all_xy"] = True - config["targets"] = {"": {}} - - -def time_str2time_digit(time_str): - time_format = '%b%d_%H-%M-%S' - try: - time_digit = datetime.strptime(time_str, time_format) - except Exception as e: - raise RuntimeError(f"illegal timestamp: {time_str}, timestamp should be prefix \ - of existing output dirpath, like 'Dec03_21-34-40'.") from e - return time_digit - - -def get_target_output_dir(monitor_path, time_start, time_end): - check_file_or_directory_path(monitor_path, isdir=True) - time_start = time_str2time_digit(time_start) if time_start is not None else time_start - time_end = time_str2time_digit(time_end) if time_end is not None else time_end - if time_start and time_end and time_start > time_end: - raise ValueError(f"time_start({time_start}) greater than time_end({time_end})") - result = {} - for dirname in os.listdir(monitor_path): - match = re.match(MonitorConst.OUTPUT_DIR_PATTERN, dirname) - if not match: - continue - time_tag = match.group(1) - rank = match.group(2) - target_time = time_str2time_digit(time_tag) - start_ok = time_start is None or target_time >= time_start - end_ok = time_end is None or target_time <= time_end - if start_ok and end_ok: - result[rank] = os.path.join(monitor_path, dirname) - return result + return res \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/compare.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/compare.py index 7a265e70fa4cbe95c897c35d68e4afa8ebd77249..18d8e0f1d0ab00fb723eafa9d0dc17d92bd164a6 100644 --- a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/compare.py @@ -125,8 +125,6 @@ class Saver: def write_summary_csv(self, test_result): test_rows = [] - if self.stack_info: - test_rows[0].append(self.COLUMN_STACK_INFO) check_op_str_pattern_valid(test_result.api_name) df_row = [test_result.api_name, test_result.is_fwd_success, test_result.is_bwd_success] diff --git a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py index b9201cfaac74e38bbbaee468b6c452895f8b38f9..4615f45543f1423ecc73cfe8ae4a88c2c83bfedd 100644 --- a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py +++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py @@ -16,6 +16,7 @@ import json import os import time +import multiprocessing from multiprocessing import Pool import torch @@ -52,6 +53,7 @@ class PtdbgDispatch(TorchDispatchMode): return if dump_path is None: logger.error("Please set dump_path when dump_mode is config!") + raise DispatchException("Please set dump_path when dump_mode is config!") check_file_or_directory_path(dump_path, True) self.device_id = torch_npu._C._npu_getDevice() @@ -85,6 +87,11 @@ class PtdbgDispatch(TorchDispatchMode): self.get_ops(yaml_path) self.lock = None + max_process_num = max(int((multiprocessing.cpu_count() + 1) // Const.CPU_QUARTER), 1) + if process_num > max_process_num: + logger.error(f"process_num should be less than or equal to {max_process_num}, but got {process_num}!") + raise DispatchException(f'process_num should be less than or equal to {max_process_num}, ' + f'but got {process_num}!') if process_num > 0: self.pool = Pool(process_num) if debug: @@ -97,7 +104,7 @@ class PtdbgDispatch(TorchDispatchMode): if not is_npu: return - logger.info(f'start write compare csv: Rank[{self.device_id}], Pid[{os.getpid()}') + logger.info(f'start write compare csv: Rank[{self.device_id}], Pid[{os.getpid()}]') if self.process_num > 0: self.pool.close() @@ -115,6 +122,8 @@ class PtdbgDispatch(TorchDispatchMode): if len(json_line_data) == 0: break msg = json.loads(json_line_data) + if len(msg) < 2: + raise ValueError("JSON data does not contain enough elements. Expected at least 2 elements.") self.all_summary[msg[0]] = msg[1] fp_handle.close() @@ -199,8 +208,10 @@ class PtdbgDispatch(TorchDispatchMode): dispatch_workflow(run_param, data_info) else: self.lock.acquire() - self.all_summary.append([]) - self.lock.release() + try: + self.all_summary.append([]) + finally: + self.lock.release() run_param.process_flag = True if self.check_fun(func, run_param): data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summary, None, npu_out_cpu, cpu_out, diff --git a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py index b185bc1110d4062d8a31b9cc94dc946d8fb8456c..a254599f683acec08421d38921c8748fd319d165 100644 --- a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py @@ -19,7 +19,9 @@ import os from datetime import datetime, timezone import torch -from msprobe.core.common.file_utils import FileOpen, save_npy, save_json +from msprobe.core.common.const import Const +from msprobe.core.common.decorator import recursion_depth_decorator +from msprobe.core.common.file_utils import FileOpen, save_npy, save_json, remove_path, check_link from msprobe.pytorch.common.log import logger @@ -91,6 +93,7 @@ def support_basic_type(data): return False +@recursion_depth_decorator("dump_data") def dump_data(data, prefix, dump_path): if isinstance(data, (tuple, list)) and data: for i, item in enumerate(data): @@ -107,8 +110,17 @@ def dump_data(data, prefix, dump_path): def save_temp_summary(api_index, single_api_summary, path, lock): summary_path = os.path.join(path, f'summary.json') lock.acquire() - data = [api_index, single_api_summary] - save_json(summary_path, data, mode='a') + try: + data = [api_index, single_api_summary] + save_json(summary_path, data, mode='a') + except Exception as e: + logger.error(f'save temp summary error:{e}') + try: + remove_path(summary_path) + except FileNotFoundError: + logger.error(f'file not found:{summary_path}') + finally: + lock.release() def dispatch_workflow(run_param: DispatchRunParam, data_info: DisPatchDataInfo): diff --git a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/utils.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/utils.py index ae8b9435a34ced607d4e70fab615b2b017083fe9..37105551a3bccca548fe2b6594f4848324746b49 100644 --- a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/utils.py @@ -27,8 +27,10 @@ else: pta_cpu_device = torch.device("cpu") from msprobe.core.common.const import CompareConst +from msprobe.core.common.decorator import recursion_depth_decorator from msprobe.pytorch.common.log import logger + cpu_device = torch._C.device("cpu") COLOR_RED = '\033[31m' COLOR_GREEN = '\033[32m' @@ -85,6 +87,7 @@ def get_callstack(): return callstack +@recursion_depth_decorator("data_to_cpu") def data_to_cpu(data, deep, data_cpu): global cpu_device list_cpu = [] diff --git a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/interactive_cli.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/interactive_cli.py index ac6f3d234e3a6681a580f16e56d94204223102f1..7f08b7929cd46961cb5850f16aa6ad7d7eace533 100644 --- a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/interactive_cli.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/interactive_cli.py @@ -45,12 +45,7 @@ class InteractiveCli(cmd.Cmd): @catch_exception def default(self, line=""): - self.util.execute_command(line) - return False - - @catch_exception - def do_run(self, line=""): - self.util.execute_command(line) + self.stdout.write("Command invalid, Only support command start with cad/vc/dc/pk/cn/pt\n") @catch_exception def do_vc(self, line=""): diff --git a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py index 66229d36b8d0b532eea48f1aa5d96e178ed80cdc..a3886641ed3af7180fab0c33b5c0cce53ee72974 100644 --- a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py @@ -13,12 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import hashlib import os import re import subprocess import sys import time +import zlib from collections import namedtuple import numpy as np @@ -37,8 +37,6 @@ try: from rich.table import Table from rich import print as rich_print from rich.columns import Columns - - install() except ImportError as err: install = None Panel = None @@ -114,11 +112,12 @@ class Util: @staticmethod def get_md5_for_numpy(obj): np_bytes = obj.tobytes() - md5_hash = hashlib.md5(np_bytes) - return md5_hash.hexdigest() + md5_crc = zlib.crc32(np_bytes) + return f"{md5_crc:08x}" @staticmethod def deal_with_dir_or_file_inconsistency(output_path): + logger.warning(f"Trying to delete {output_path}") remove_path(output_path) raise ParseException("Inconsistent directory structure or file.") @@ -227,7 +226,7 @@ class Util: def check_path_valid(self, path): path = self.path_strip(path) if not path or not os.path.exists(path): - self.log.error("The path %s does not exist." % path) + self.log.error("The path does not exist.") raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) isdir = check_file_type(path) == FileCheckConst.DIR check_file_or_directory_path(path, isdir=isdir) @@ -235,7 +234,7 @@ class Util: def check_files_in_path(self, path): if os.path.isdir(path) and len(os.listdir(path)) == 0: - self.log.error("No files in %s." % path) + self.log.error("No files found in path.") raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) def npy_info(self, source_data): @@ -264,7 +263,7 @@ class Util: match = re_pattern.match(name) if not match: continue - if extern_pattern != '' and re_pattern.match(extern_pattern) and not re.match(extern_pattern, name): + if extern_pattern != '' and re_pattern.match(extern_pattern) and not name.startswith(extern_pattern): continue file_list[name] = gen_info_func(name, match, file["root"]) return file_list diff --git a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/visualization.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/visualization.py index 5b53831b1c6fb9280dbad5621ee222baa2712225..5e54a45f948d6bc081e12f88467ba6c876c1cbc9 100644 --- a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/visualization.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/visualization.py @@ -83,4 +83,3 @@ class Visualization: self.util.log.info("\nStatistic Info:") title_printed = True self.util.log.info(summery_info) - pkl_handle.close() diff --git a/debug/accuracy_tools/msprobe/pytorch/pt_config.py b/debug/accuracy_tools/msprobe/pytorch/pt_config.py index 8293ac969490b103eef630081b6001234ca8bb07..4e2770f0258b52366dcc1019feae71aacfda29fc 100644 --- a/debug/accuracy_tools/msprobe/pytorch/pt_config.py +++ b/debug/accuracy_tools/msprobe/pytorch/pt_config.py @@ -16,9 +16,9 @@ import os import re -from msprobe.core.common.const import Const +from msprobe.core.common.const import Const, FileCheckConst from msprobe.core.common.exceptions import MsprobeException -from msprobe.core.common.file_utils import FileOpen, load_json, check_file_or_directory_path, check_crt_valid +from msprobe.core.common.file_utils import FileOpen, load_json, check_file_or_directory_path, FileChecker from msprobe.core.common.log import logger from msprobe.core.common.utils import is_int from msprobe.core.common_config import BaseConfig, CommonConfig @@ -35,44 +35,15 @@ from msprobe.pytorch.hook_module.utils import get_ops class TensorConfig(BaseConfig): def __init__(self, json_config): super().__init__(json_config) - self.online_run_ut = json_config.get("online_run_ut", False) - self.nfs_path = json_config.get("nfs_path", "") - self.host = json_config.get("host", "") - self.port = json_config.get("port", -1) - self.tls_path = json_config.get("tls_path", "./") - self.online_run_ut_recompute = json_config.get("online_run_ut_recompute", False) self.check_config() + self._check_summary_mode() self._check_file_format() - if self.online_run_ut: - self._check_online_run_ut() + def _check_file_format(self): if self.file_format is not None and self.file_format not in ["npy", "bin"]: raise Exception("file_format is invalid") - def _check_online_run_ut(self): - if not isinstance(self.online_run_ut, bool): - raise Exception(f"online_run_ut: {self.online_run_ut} is invalid.") - - if not isinstance(self.online_run_ut_recompute, bool): - raise Exception(f"online_run_ut_recompute: {self.online_run_ut_recompute} is invalid.") - - if self.nfs_path: - check_file_or_directory_path(self.nfs_path, isdir=True) - return - - if self.tls_path: - check_file_or_directory_path(self.tls_path, isdir=True) - check_file_or_directory_path(os.path.join(self.tls_path, "client.key")) - check_file_or_directory_path(os.path.join(self.tls_path, "client.crt")) - check_crt_valid(os.path.join(self.tls_path, "client.crt")) - - if not isinstance(self.host, str) or not re.match(Const.ipv4_pattern, self.host): - raise Exception(f"host: {self.host} is invalid.") - - if not isinstance(self.port, int) or not (0 < self.port <= 65535): - raise Exception(f"port: {self.port} is invalid, port range 0-65535.") - class StatisticsConfig(BaseConfig): def __init__(self, json_config): @@ -80,9 +51,8 @@ class StatisticsConfig(BaseConfig): self.check_config() self._check_summary_mode() - def _check_summary_mode(self): - if self.summary_mode and self.summary_mode not in ["statistics", "md5"]: - raise Exception("summary_mode is invalid") + self.tensor_list = json_config.get("tensor_list", []) + self._check_str_list_config(self.tensor_list, "tensor_list") class OverflowCheckConfig(BaseConfig): @@ -95,6 +65,8 @@ class OverflowCheckConfig(BaseConfig): def check_overflow_config(self): if self.overflow_nums is not None and not is_int(self.overflow_nums): raise Exception("overflow_num is invalid") + if self.overflow_nums is not None and self.overflow_nums != -1 and self.overflow_nums <= 0: + raise Exception("overflow_nums should be -1 or positive integer") if self.check_mode is not None and self.check_mode not in ["all", "aicore", "atomic"]: raise Exception("check_mode is invalid") @@ -148,7 +120,7 @@ class FreeBenchmarkCheckConfig(BaseConfig): self.pert_mode in PytorchFreeBenchmarkConst.CPU_MODE_LIST ): msg = ( - f"You neet to and can only set fuzz_device as {DeviceType.CPU} " + f"You need to and can only set fuzz_device as {DeviceType.CPU} " f"when pert_mode in {PytorchFreeBenchmarkConst.CPU_MODE_LIST}" ) logger.error_log_with_exp( @@ -246,12 +218,7 @@ class RunUTConfig(BaseConfig): self.white_list = json_config.get("white_list", Const.DEFAULT_LIST) self.black_list = json_config.get("black_list", Const.DEFAULT_LIST) self.error_data_path = json_config.get("error_data_path", Const.DEFAULT_PATH) - self.is_online = json_config.get("is_online", False) - self.nfs_path = json_config.get("nfs_path", "") - self.host = json_config.get("host", "") - self.port = json_config.get("port", -1) - self.rank_list = json_config.get("rank_list", Const.DEFAULT_LIST) - self.tls_path = json_config.get("tls_path", "./") + self.check_run_ut_config() @classmethod @@ -269,22 +236,11 @@ class RunUTConfig(BaseConfig): if not os.path.exists(error_data_path): raise Exception("error_data_path: %s does not exist" % error_data_path) - @classmethod - def check_nfs_path_config(cls, nfs_path): - if nfs_path and not os.path.exists(nfs_path): - raise Exception("nfs_path: %s does not exist" % nfs_path) - - @classmethod - def check_tls_path_config(cls, tls_path): - if tls_path and not os.path.exists(tls_path): - raise Exception("tls_path: %s does not exist" % tls_path) def check_run_ut_config(self): RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list) RunUTConfig.check_filter_list_config(Const.BLACK_LIST, self.black_list) RunUTConfig.check_error_data_path_config(self.error_data_path) - RunUTConfig.check_nfs_path_config(self.nfs_path) - RunUTConfig.check_tls_path_config(self.tls_path) class GradToolConfig(BaseConfig): diff --git a/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py b/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py new file mode 100644 index 0000000000000000000000000000000000000000..d9041ffc5579e6ebfe8897336c8b0985bd0f5e50 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py @@ -0,0 +1,66 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprobe.core.common.utils import Const +from msprobe.core.service import BaseService +from msprobe.pytorch.common.log import logger +from msprobe.pytorch.common.utils import get_rank_if_initialized +from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser +from msprobe.pytorch.hook_module.api_register import get_api_register, ApiTemplate, redirect_wait +from msprobe.pytorch.hook_module.hook_module import HOOKModule +from msprobe.pytorch.hook_module.pt_hook_manager import PytorchHookManager +from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook +from msprobe.pytorch.hook_module.script_wrapper import wrap_script_func, preprocess_func + + +class PytorchService(BaseService): + @property + def _get_framework_type(self): + return Const.PT_FRAMEWORK + + @staticmethod + def _get_current_rank(): + return get_rank_if_initialized() + + def reset_status(self): + self._reset_status() + + def _init_specific_components(self): + self.logger = logger + self.api_register = get_api_register() + self.module_processor = ModuleProcesser(self.data_collector.scope) + self.hook_manager = PytorchHookManager(self.data_collector, self.config) + self.api_template = ApiTemplate + + def _register_hook(self): + if self._is_mix_level: + register_optimizer_hook(self.data_collector) + + def _register_api_hook(self): + preprocess_func() + super()._register_api_hook() + wrap_script_func() + redirect_wait() + + def _register_module_hook(self): + ModuleProcesser.enable_module_dump = True + self.module_processor.register_module_hook(self.model, self.build_hook) + self.logger.info(f"The module {self.config.task} hook function is successfully mounted to the model.") + + + def _reset_status(self): + super()._reset_status() + ModuleProcesser.reset_module_stats() + HOOKModule.reset_module_stats() diff --git a/debug/accuracy_tools/msprobe/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py deleted file mode 100644 index fd81a7f1cf064506a4fb91481429828c97113509..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/service.py +++ /dev/null @@ -1,470 +0,0 @@ -# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import functools -import os -from collections import namedtuple, defaultdict - -import torch -from msprobe.core.common.const import Const -from msprobe.core.common.exceptions import DistributedNotInitializedError -from msprobe.core.common.file_utils import create_directory -from msprobe.core.common.utils import print_tools_ends_info, DumpPathAggregation -from msprobe.core.data_dump.data_collector import build_data_collector -from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs -from msprobe.core.data_dump.scope import BaseScope -from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData -from msprobe.pytorch.common.log import logger -from msprobe.pytorch.common.utils import get_rank_if_initialized, is_recomputation -from msprobe.pytorch.dump.kernel_dump.kernel_config import create_kernel_config_json -from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser -from msprobe.pytorch.hook_module.api_registry import api_register -from msprobe.pytorch.hook_module.hook_module import HOOKModule -from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook - -torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' -if torch_version_above_or_equal_2: - from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.dump_dispatch import run_ut_dispatch - -HookFn = namedtuple('hookFn', ['pre_hook', 'forward_hook', 'backward_hook', 'forward_hook_torch_version_below_2']) - - -class Service: - def __init__(self, config): - self.model = None - self.config = config - self.data_collector = build_data_collector(config) - self.module_processor = ModuleProcesser(self.data_collector.scope) - self.switch = False - self.inner_switch = False - self.current_iter = 0 - self.first_start = True - self.current_rank = None - self.dump_iter_dir = None - self.should_stop_service = False - self.attl = None - self.params_grad_info = {} - self.hook_handle_dict = {} - # 提前注册,确保注册尽可能多的API hook - self.register_api_hook() - self.init_for_debug_level() - - def build_hook(self, module_type, name): - def pre_hook(api_or_module_name, module, args, kwargs): - if not self.should_execute_hook(module_type, module, True): - return args, kwargs - is_recompute = is_recomputation() - - self.inner_switch = True - if module_type == BaseScope.Module_Type_Module: - api_or_module_name = module.mindstudio_reserved_name[-1] - else: - module.forward_data_collected = True - HOOKModule.add_module_count(name) - self.data_collector.update_api_or_module_name(api_or_module_name) - - if self.config.online_run_ut: - self.inner_switch = False - return None, None - if self.data_collector: - module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None) - self.data_collector.forward_input_data_collect( - api_or_module_name, - module, - pid, - module_input_output, - is_recompute - ) - - self.inner_switch = False - return args, kwargs - - def grad_hook(module, ori_name, param_name): - def hook_fn(grad): - if not self.should_execute_hook(module_type, module, False): - return grad - self.inner_switch = True - self.data_collector.params_data_collect(ori_name, param_name, pid, grad) - self.inner_switch = False - return grad - - return hook_fn - - def register_param_hook(ori_name, module, params_dict): - ''' - 注册参数hook - ''' - # data_mode为forward时,不注册参数hook - if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode): - for param_name, param in params_dict.items(): - if param.requires_grad: - name = ori_name + Const.SEP + param_name - old_handle = self.hook_handle_dict.get(name) - if old_handle and hasattr(old_handle, "remove"): - old_handle.remove() - handle = param.register_hook(grad_hook(module, ori_name, param_name)) - self.hook_handle_dict[name] = handle - - def init_params_grad_info(module, params_dict): - ''' - 初始化参数梯度信息, 在前向hook结束后, 将参数梯度信息写入cache_data中用于占位 - ''' - if not params_dict: - return - if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode): - grad_name = module.params_grad_name if hasattr(module, 'params_grad_name') else None - # 判断是否已经在cache_data中进行了占位, 若没有则先写入cache_data中 - if not self.params_grad_info.get(grad_name): - data_info = {grad_name: {key: [None] for key, value in params_dict.items() if value.requires_grad}} - # 当模块中的参数有requires_grad属性为True时,才会进行梯度计算,此时才需要占位 - if data_info.get(grad_name): - # 将grad_name的data_info先写入cache_data中, 梯度计算后再更新 - self.data_collector.handle_data(grad_name, data_info, - flush=self.data_collector.data_processor.is_terminated) - # 记录当前模块的参数梯度信息已占位 - self.params_grad_info[grad_name] = True - - def forward_hook(api_or_module_name, module, args, kwargs, output): - if not self.should_execute_hook(module_type, module, True): - return None - is_recompute = is_recomputation() - - self.inner_switch = True - if self.config.online_run_ut: - self.data_collector.update_api_or_module_name(api_or_module_name) - if self.data_collector.scope and not self.data_collector.scope.check(api_or_module_name): - return None - api_data = ApiData( - api_or_module_name[:-len(Const.FORWARD_NAME_SUFFIX)], - args, - kwargs, - output, - self.current_iter, - self.current_rank - ) - self.attl_send(api_data) - self.inner_switch = False - return None - - module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output) - if module_type == BaseScope.Module_Type_Module: - api_or_module_name = module.mindstudio_reserved_name[-1] - self.data_collector.update_api_or_module_name(api_or_module_name) - params_dict = {} - if self.config.task != Const.STRUCTURE: - params_dict = { - key.split(Const.SEP)[-1]: value - for key, value in module.named_parameters(recurse=False) - } - setattr(module_input_output, Const.PARAMS, params_dict) - # 判断是否需要注册参数hook - if params_dict: - ori_name = api_or_module_name.rsplit(Const.SEP, 2)[0] - grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD - # 首次执行前向hook时,添加params_grad_name属性,并注册参数hook - setattr(module, 'params_grad_name', grad_name) - register_param_hook(ori_name, module, params_dict) - self.data_collector.forward_data_collect( - api_or_module_name, - module, - pid, - module_input_output, - is_recompute - ) - init_params_grad_info(module, params_dict) - else: - self.data_collector.update_api_or_module_name(api_or_module_name) - self.data_collector.forward_output_data_collect( - api_or_module_name, - module, - pid, - module_input_output, - is_recompute - ) - - if self.data_collector.if_return_forward_new_output(): - forward_new_output = self.data_collector.get_forward_new_output() - self.inner_switch = False - return forward_new_output - self.inner_switch = False - return output - - def forward_hook_torch_version_below_2(api_or_module_name, module, args, output): - return forward_hook(api_or_module_name, module, args, {}, output) - - def backward_hook(api_or_module_name, module, grad_input, grad_output): - if not self.should_execute_hook(module_type, module, False): - return - is_recompute = is_recomputation() - - self.inner_switch = True - if module_type == BaseScope.Module_Type_Module: - api_or_module_name = module.mindstudio_reserved_name[-1] - self.data_collector.update_api_or_module_name(api_or_module_name) - - if self.config.online_run_ut: - self.inner_switch = False - return - - if self.data_collector: - # 此处获取到的grad_input实际为反向过程的输出数据,grad_output为反向过程的输入数据,因此传入时调换顺序 - module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input) - self.data_collector.backward_data_collect( - api_or_module_name, - module, - pid, - module_input_output, - is_recompute - ) - self.inner_switch = False - - pid = os.getpid() - full_forward_name = None - full_backward_name = None - if module_type == BaseScope.Module_Type_API: - full_forward_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.FORWARD - full_backward_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.BACKWARD - pre_forward_hook_fn = functools.partial(pre_hook, full_forward_name) - forward_hook_fn = functools.partial(forward_hook, full_forward_name) - backward_hook_fn = functools.partial(backward_hook, full_backward_name) - forward_hook_torch_version_below_2_fn = functools.partial( - forward_hook_torch_version_below_2, - full_forward_name - ) - return HookFn(pre_forward_hook_fn, forward_hook_fn, backward_hook_fn, forward_hook_torch_version_below_2_fn) - - def start(self, model): - if self.config.level == Const.LEVEL_DEBUG: - return - if self.need_stop_service(): - return - - self.model = model - if self.first_start: - try: - self.current_rank = get_rank_if_initialized() - except DistributedNotInitializedError: - self.current_rank = None - self.attl_init() - - if self.config.rank and self.current_rank not in self.config.rank: - return - self.register_module_hook() - if self.config.level == Const.LEVEL_MIX: - register_optimizer_hook(self.data_collector) - self.first_start = False - if self.config.online_run_ut and torch_version_above_or_equal_2: - run_ut_dispatch(self.attl, True, self.config.online_run_ut_recompute) - self.switch = True - logger.info_on_rank_0(f"Dump switch is turned on at step {self.current_iter}. ") - if not self.config.online_run_ut: - self.create_dirs() - logger.info_on_rank_0(f"Dump data will be saved in {self.dump_iter_dir}.") - - def stop(self): - if self.config.level == Const.LEVEL_DEBUG: - return - if self.should_stop_service: - return - if self.config.step and self.current_iter not in self.config.step: - return - if self.config.rank and self.current_rank not in self.config.rank: - return - self.switch = False - if self.config.level == Const.LEVEL_L2: - return - if self.config.online_run_ut and torch_version_above_or_equal_2: - run_ut_dispatch(self.attl, False, self.config.online_run_ut_recompute) - return - if self.config.async_dump: - self.data_collector.fill_stack_tensor_data() - if self.config.task == Const.TENSOR: - self.data_collector.data_processor.dump_async_data() - self.data_collector.write_json() - - def step(self): - if self.config.level == Const.LEVEL_DEBUG: - return - if self.should_stop_service: - return - if self.config.async_dump: - self.data_collector.fill_stack_tensor_data() - if self.config.task == Const.TENSOR: - self.data_collector.data_processor.dump_async_data() - self.data_collector.write_json() - self.current_iter += 1 - self.data_collector.update_iter(self.current_iter) - self.reset_status() - - def need_stop_service(self): - if self.should_stop_service: - return True - end_service = self.config.step and self.current_iter > max(self.config.step) or \ - self.data_collector and self.data_collector.data_processor.is_terminated - if end_service: - if self.config.online_run_ut: - # send stop signal if online_run_ut - self.attl_stop() - self.switch = False - self.should_stop_service = True - print_tools_ends_info() - return True - if self.config.step and self.current_iter not in self.config.step: - return True - return False - - def should_execute_hook(self, hook_type, module, is_forward): - is_module_hook = hook_type == BaseScope.Module_Type_Module - if is_module_hook and not self.switch: - return False - elif not is_module_hook and is_forward and not self.switch: - return False - elif not is_module_hook and not is_forward and not module.forward_data_collected: - return False - - if self.inner_switch: - return False - if not self.data_collector or self.data_collector.data_processor.is_terminated: - return False - return True - - def create_dirs(self): - create_directory(self.config.dump_path) - self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}") - cur_rank = self.current_rank if self.current_rank is not None else '' - if self.config.level == Const.LEVEL_L2: - create_directory(self.dump_iter_dir) - kernel_config_path = create_kernel_config_json(self.dump_iter_dir, cur_rank) - self.config.kernel_config_path = kernel_config_path - return - - dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}") - create_directory(dump_dir) - if self.config.task in self.data_collector.tasks_need_tensor_data: - dump_data_dir = os.path.join(dump_dir, "dump_tensor_data") - create_directory(dump_data_dir) - else: - dump_data_dir = None - - dump_path_aggregation = DumpPathAggregation() - dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json") - dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json") - dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json") - dump_path_aggregation.dump_tensor_data_dir = dump_data_dir - dump_path_aggregation.free_benchmark_file_path = os.path.join(dump_dir, "free_benchmark.csv") - self.data_collector.update_dump_paths(dump_path_aggregation) - self.data_collector.initialize_json_file(framework=Const.PT_FRAMEWORK) - - def register_api_hook(self): - if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]: - logger.info_on_rank_0(f"The api {self.config.task} hook function is successfully mounted to the model.") - api_register.initialize_hook( - functools.partial(self.build_hook, BaseScope.Module_Type_API), - self.config.online_run_ut - ) - api_register.api_modularity() - - def register_module_hook(self): - if self.config.level in [Const.LEVEL_L0, Const.LEVEL_MIX]: - logger.info_on_rank_0(f"The module {self.config.task} hook function is successfully mounted to the model.") - self.module_processor.register_module_hook(self.model, self.build_hook) - - def attl_init(self): - if self.config.online_run_ut: - from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTLConfig, ATTL - attl_config = ATTLConfig(is_benchmark_device=False, - connect_ip=self.config.host, - connect_port=self.config.port, - nfs_path=self.config.nfs_path, - tls_path=self.config.tls_path) - need_dump = len(self.config.rank) == 0 or self.current_rank in self.config.rank - self.attl = ATTL('npu', attl_config, need_dump=need_dump) - if self.config.nfs_path: - self.attl.upload("start") - - def attl_send(self, api_data): - logger.info(f"tools is dumping api: {api_data.name}, rank: {self.current_rank}") - api_type, _, _ = api_data.name.split(Const.SEP) - if api_type in [Const.DISTRIBUTED]: - logger.info(f"api {api_data.name} is not supported, skip") - return - if self.config.nfs_path: - self.attl.upload(api_data) - else: - self.attl.send(api_data) - - def attl_stop(self): - if self.config.nfs_path: - self.attl.upload("end") - elif self.attl.socket_manager is not None: - logger.info(f"pid: {os.getpid()} finished, start send STOP signal.") - self.attl.socket_manager.send_stop_signal() - - def reset_status(self): - ModuleProcesser.reset_module_stats() - HOOKModule.reset_module_stats() - self.data_collector.reset_status() - self.params_grad_info.clear() - - if self.config.level == Const.LEVEL_L2: - self.data_collector.data_processor.reset_status() - return - if self.config.step and self.current_iter not in self.config.step: - return - if self.config.rank and self.current_rank not in self.config.rank: - return - - def init_for_debug_level(self): - if not (self.config.level == Const.LEVEL_DEBUG and self.config.task in [Const.TENSOR, Const.STATISTICS]): - return - try: - self.current_rank = get_rank_if_initialized() - except DistributedNotInitializedError: - self.current_rank = None - - # dir: dump_path -- rank{} -- debug.json - self.dump_iter_dir = self.config.dump_path - cur_rank = self.current_rank if self.current_rank is not None else '' - dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}") - create_directory(dump_dir) - if self.config.task in self.data_collector.tasks_need_tensor_data: - dump_data_dir = os.path.join(dump_dir, "dump_tensor_data") - create_directory(dump_data_dir) - else: - dump_data_dir = None - - dump_path_aggregation = DumpPathAggregation() - dump_path_aggregation.dump_tensor_data_dir = dump_data_dir - dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json") - self.data_collector.update_dump_paths(dump_path_aggregation) - self.data_collector.initialize_json_file(framework=Const.PT_FRAMEWORK) - - self.debug_variable_counter = defaultdict(int) - - def save(self, variable, name, save_backward): - if self.config.level != Const.LEVEL_DEBUG: - return - count = self.debug_variable_counter[name] - self.debug_variable_counter[name] += 1 - - name_with_count = f"{name}.{count}" - grad_name_with_count = f"{name}_grad.{count}" - - # forward save - self.data_collector.debug_data_collect_forward(variable, name_with_count) - - # backward save - if save_backward: - self.data_collector.debug_data_collect_backward(variable, grad_name_with_count) diff --git a/debug/accuracy_tools/msprobe/test/common_set_up/__init__.py b/debug/accuracy_tools/msprobe/test/common_set_up/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msprobe/test/common_set_up/mindtorch.py b/debug/accuracy_tools/msprobe/test/common_set_up/mindtorch.py new file mode 100644 index 0000000000000000000000000000000000000000..665d17c21e743fb5ffe6a0d9e014fe0a2da4af99 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/common_set_up/mindtorch.py @@ -0,0 +1,29 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mindspore import Tensor +import torch + + +def create_msa_tensor(data, dtype=None): + return Tensor(data, dtype) + + +tensor_tensor = torch.tensor +setattr(torch, 'tensor', create_msa_tensor) + + +def reset_torch_tensor(): + setattr(torch, 'tensor', tensor_tensor) diff --git a/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py b/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py new file mode 100644 index 0000000000000000000000000000000000000000..696f85a7196d9b5ec8fd8256857ea2db10ca8acf --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py @@ -0,0 +1,65 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +from unittest import TestCase +from unittest.mock import MagicMock + +import mindspore as ms +from mindspore import mint + +try: + from mint import distributed +except ImportError: + distributed = MagicMock() + setattr(mint, 'distributed', distributed) + +from mindspore.common.api import _pynative_executor +mock_requires_grad = MagicMock(return_value=True) +setattr(_pynative_executor, "requires_grad", mock_requires_grad) + +from mindspore import ops +if not hasattr(ops, 'DumpGradient'): + DumpGradient = MagicMock() + setattr(ops, 'DumpGradient', DumpGradient) + +# ensure not to import torch_npu +from msprobe.mindspore import mindspore_service +from msprobe.mindspore.monitor import common_func + +from .mindtorch import reset_torch_tensor +from msprobe.mindspore.common import utils +from msprobe.mindspore.common.utils import is_mindtorch, register_backward_hook_functions + +utils.mindtorch_check_result = None +importlib.reload(mindspore_service) +importlib.reload(common_func) +reset_torch_tensor() + + +def register_backward_pre_hook(*args, **kwargs): + pass + + +register_backward_hook_functions['full'] = ms.nn.Cell.register_backward_hook +register_backward_hook_functions["pre"] = register_backward_pre_hook + + +class SetUp(TestCase): + def test_case(self): + self.assertTrue(hasattr(mint, 'distributed')) + self.assertTrue(hasattr(_pynative_executor, 'requires_grad')) + self.assertTrue(is_mindtorch()) + utils.mindtorch_check_result = None diff --git a/debug/accuracy_tools/msprobe/test/core_ut/common/test_db_manager.py b/debug/accuracy_tools/msprobe/test/core_ut/common/test_db_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..1537f689e051024545907e3f26aebb8dbd5553d2 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_db_manager.py @@ -0,0 +1,246 @@ +import unittest +import sqlite3 +import os +import tempfile +from typing import Dict, List +from unittest.mock import patch, MagicMock + +from msprobe.pytorch.common.log import logger +from msprobe.core.common.db_manager import DBManager + + +class TestDBManager(unittest.TestCase): + def setUp(self): + # 创建临时数据库文件 + self.temp_db = tempfile.NamedTemporaryFile(delete=False, suffix='.db') + self.db_path = self.temp_db.name + self.db_manager = DBManager(self.db_path) + + # 创建测试表 + self.test_table = "test_table" + self.create_test_table() + + def tearDown(self): + # 关闭并删除临时数据库文件 + if hasattr(self, 'temp_db'): + self.temp_db.close() + os.unlink(self.db_path) + + def create_test_table(self): + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute(f""" + CREATE TABLE IF NOT EXISTS {self.test_table} ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + value INTEGER, + timestamp DATETIME DEFAULT CURRENT_TIMESTAMP + ) + """) + conn.commit() + + def test_get_connection_success(self): + """测试成功获取数据库连接""" + conn, curs = self.db_manager._get_connection() + self.assertIsInstance(conn, sqlite3.Connection) + self.assertIsInstance(curs, sqlite3.Cursor) + self.db_manager._release_connection(conn, curs) + + @patch.object(logger, 'error') + def test_get_connection_success_failed(self, mock_logger): + """测试错误日志记录""" + with patch('sqlite3.connect', side_effect=sqlite3.Error("Test error")): + with self.assertRaises(sqlite3.Error): + self.db_manager._get_connection() + mock_logger.assert_called_with( + "Database connection failed: Test error") + + def test_insert_data_basic(self): + """测试基本数据插入""" + test_data = [ + (1, "item1", 100), + (2, "item2", 200) + ] + columns = ["id", "name", "value"] + + inserted = self.db_manager.insert_data( + table_name=self.test_table, + data=test_data, + key_list=columns + ) + self.assertEqual(inserted, 2) + + # 验证数据是否实际插入 + results = self.db_manager.select_data( + table_name=self.test_table, + columns=["id", "name", "value"] + ) + self.assertEqual(len(results), 2) + self.assertEqual(results[0]["name"], "item1") + + def test_insert_data_without_keys(self): + """测试无列名的数据插入""" + test_data = [ + (3, "item3", 300, 333), + (4, "item4", 400, 333) + ] + + inserted = self.db_manager.insert_data( + table_name=self.test_table, + data=test_data + ) + self.assertEqual(inserted, 2) + + def test_insert_data_empty(self): + """测试空数据插入""" + inserted = self.db_manager.insert_data( + table_name=self.test_table, + data=[] + ) + self.assertEqual(inserted, 0) + + def test_insert_data_mismatch_keys(self): + """测试列名与数据不匹配的情况""" + test_data = [(5, "item5")] + with self.assertRaises(ValueError): + self.db_manager.insert_data( + table_name=self.test_table, + data=test_data, + key_list=["id", "name", "value"] # 多了一个列 + ) + + def test_select_data_basic(self): + """测试基本数据查询""" + # 先插入测试数据 + self.db_manager.insert_data( + table_name=self.test_table, + data=[(10, "test10", 1000)], + key_list=["id", "name", "value"] + ) + + results = self.db_manager.select_data( + table_name=self.test_table, + columns=["name", "value"], + where={"id": 10} + ) + + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["name"], "test10") + self.assertEqual(results[0]["value"], 1000) + + def test_select_data_no_where(self): + """测试无条件查询""" + # 插入多条数据 + test_data = [ + (20, "item20", 2000), + (21, "item21", 2100) + ] + self.db_manager.insert_data( + table_name=self.test_table, + data=test_data, + key_list=["id", "name", "value"] + ) + + results = self.db_manager.select_data( + table_name=self.test_table, + columns=["id", "name", "value"] + ) + self.assertGreaterEqual(len(results), 2) + + def test_update_data_basic(self): + """测试基本数据更新""" + # 先插入测试数据 + self.db_manager.insert_data( + table_name=self.test_table, + data=[(30, "old_name", 3000)], + key_list=["id", "name", "value"] + ) + + updated = self.db_manager.update_data( + table_name=self.test_table, + updates={"name": "new_name", "value": 3500}, + where={"id": 30} + ) + self.assertEqual(updated, 1) + + # 验证更新结果 + results = self.db_manager.select_data( + table_name=self.test_table, + columns=["id", "name", "value"], + where={"id": 30} + ) + self.assertEqual(results[0]["name"], "new_name") + self.assertEqual(results[0]["value"], 3500) + + def test_execute_sql_select(self): + """测试执行SELECT SQL语句""" + self.db_manager.insert_data( + table_name=self.test_table, + data=[(50, "sql_item", 5000)], + key_list=["id", "name", "value"] + ) + + results = self.db_manager.execute_sql( + sql=f"SELECT name, value FROM {self.test_table} WHERE id = ?", + params=(50,) + ) + + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["name"], "sql_item") + + def test_execute_sql_non_select(self): + """测试执行非SELECT SQL语句""" + # 先插入数据 + self.db_manager.insert_data( + table_name=self.test_table, + data=[(60, "to_delete", 6000)], + key_list=["id", "name", "value"] + ) + + # 执行DELETE语句 + self.db_manager.execute_sql( + sql=f"DELETE FROM {self.test_table} WHERE id = 60" + ) + + # 验证数据已被删除 + results = self.db_manager.select_data( + table_name=self.test_table, + columns=["id", "name", "value"], + where={"id": 60} + ) + self.assertEqual(len(results), 0) + + def test_table_exists_true(self): + """测试表存在检查(存在的情况)""" + exists = self.db_manager.table_exists(self.test_table) + self.assertTrue(exists) + + def test_table_exists_false(self): + """测试表存在检查(不存在的情况)""" + exists = self.db_manager.table_exists("non_existent_table") + self.assertFalse(exists) + + def test_execute_multi_sql(self): + """测试批量执行多个SQL语句""" + sql_commands = [ + f"INSERT INTO {self.test_table} (id, name, value) VALUES (70, 'multi1', 7000)", + f"INSERT INTO {self.test_table} (id, name, value) VALUES (71, 'multi2', 7100)", + f"SELECT * FROM {self.test_table} WHERE id IN (70, 71)" + ] + + results = self.db_manager.execute_multi_sql(sql_commands) + + # 应该只有最后一个SELECT语句有结果 + self.assertEqual(len(results), 1) + self.assertEqual(len(results[0]), 2) + + @patch.object(logger, 'error') + def test_db_operation_decorator(self, mock_logger): + """测试数据库操作装饰器""" + # 模拟一个会失败的操作 + with patch.object(self.db_manager, '_get_connection', + side_effect=sqlite3.Error("Test error")): + result = self.db_manager.select_data(table_name=self.test_table) + self.assertIsNone(result) # 装饰器会捕获异常并返回None + mock_logger.assert_called_with( + "Database operation failed: Test error") diff --git a/debug/accuracy_tools/msprobe/test/core_ut/common/test_dump_file/dump_no_pt_no_ms.json b/debug/accuracy_tools/msprobe/test/core_ut/common/test_dump_file/dump_no_pt_no_ms.json new file mode 100644 index 0000000000000000000000000000000000000000..63a062d8ffa264a0254fc2bab0208dcf951ae094 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_dump_file/dump_no_pt_no_ms.json @@ -0,0 +1,3 @@ +{ + "task": "tensor" +} \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/core_ut/common/test_dump_file/ms_dump_no_framework.json b/debug/accuracy_tools/msprobe/test/core_ut/common/test_dump_file/ms_dump_no_framework.json new file mode 100644 index 0000000000000000000000000000000000000000..b223c74b2315af1b9454e5f1e70c29502d449c56 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_dump_file/ms_dump_no_framework.json @@ -0,0 +1,4 @@ +{ + "task": "tensor", + "type": "mindspore.float16" +} \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/core_ut/common/test_dump_file/pt_dump_no_framework.json b/debug/accuracy_tools/msprobe/test/core_ut/common/test_dump_file/pt_dump_no_framework.json new file mode 100644 index 0000000000000000000000000000000000000000..2444ae1fd4096b083a9e8a0e51c9166bb990f51f --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_dump_file/pt_dump_no_framework.json @@ -0,0 +1,4 @@ +{ + "task": "tensor", + "type": "torch.float16" +} \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/core_ut/common/test_file_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/common/test_file_utils.py index 9ed13f78aed57fd4d8153e2f005ea14d4fb33643..ecdc0e7311f1a5258730f9d6d7a27181b8a45dd7 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/common/test_file_utils.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_file_utils.py @@ -1,7 +1,8 @@ +import unittest from unittest.mock import patch, mock_open, MagicMock +from zipfile import ZipFile, ZipInfo +import tempfile -import numpy as np -import pandas as pd import pytest from msprobe.core.common.file_utils import * @@ -246,14 +247,23 @@ class TestFileOperations: save_yaml(str(self.yaml_file), test_data) mock_file.assert_called_once_with(str(self.yaml_file), 'w', encoding='utf-8') assert mock_flock.call_count == 2 - mock_dump.assert_called_once_with(test_data, mock_file(), sort_keys=False) + mock_dump.assert_called_once_with(test_data, mock_file(), sort_keys=False)\ - def test_save_excel(self): + def test_save_excel_tiny(self): df = pd.DataFrame({'col1': [1, 2], 'col2': [3, 4]}) with patch('pandas.DataFrame.to_excel') as mock_to_excel, \ + patch('pandas.ExcelWriter') as mock_writer, \ patch('os.chmod') as mock_chmod: save_excel(self.excel_file, df) - mock_to_excel.assert_called_once_with(str(self.excel_file), index=False) + mock_to_excel.assert_called_once_with(mock_writer().__enter__(), sheet_name='Sheet1', index=False) + + def test_save_excel_large(self): + df = pd.DataFrame({'col1': list(range(1500000)), 'col2': list(range(1500000, 0, -1))}) + with patch('pandas.DataFrame.to_excel') as mock_to_excel, \ + patch('pandas.ExcelWriter') as mock_writer, \ + patch('os.chmod') as mock_chmod: + save_excel(self.excel_file, df) + mock_to_excel.assert_called_with(mock_writer().__enter__(), sheet_name='part_1', index=False) def test_move_file(self): dst_file = self.test_dir / "moved_file" @@ -439,18 +449,19 @@ class TestUtilityOperations: def test_remove_path(self): # Test remove file with patch('os.path.exists', return_value=True), \ - patch('os.path.islink', return_value=True), \ + patch('os.path.islink', return_value=False), \ + patch('os.path.isfile', return_value=True), \ patch('os.remove') as mock_remove: - remove_path(str(self.test_file)) - mock_remove.assert_called_once_with(str(self.test_file)) + remove_path("/test_remove_path/test/test.txt") + mock_remove.assert_called_once_with("/test_remove_path/test/test.txt") # Test remove directory with patch('os.path.exists', return_value=True), \ patch('os.path.islink', return_value=False), \ patch('os.path.isfile', return_value=False), \ patch('shutil.rmtree') as mock_rmtree: - remove_path(str(self.test_dir)) - mock_rmtree.assert_called_once_with(str(self.test_dir)) + remove_path("/test_remove_path/test") + mock_rmtree.assert_called_once_with("/test_remove_path/test") def test_get_json_contents(self): json_content = '{"key": "value"}' @@ -495,24 +506,6 @@ class TestUtilityOperations: assert result[0]['file'] == 'file1.txt' -class TestCertificateOperations: - @pytest.fixture(autouse=True) - def setup(self, tmp_path): - self.cert_file = tmp_path / "test.pem" - self.mock_cert = MagicMock() - self.mock_cert.get_notBefore.return_value = b'20230101000000Z' - self.mock_cert.get_notAfter.return_value = b'20250101000000Z' - self.mock_cert.has_expired.return_value = False - - def test_check_crt_valid(self): - # Test expired certificate - self.mock_cert.has_expired.return_value = True - with patch('OpenSSL.crypto.load_certificate', return_value=self.mock_cert), \ - patch('builtins.open', mock_open(read_data='cert data')), \ - pytest.raises(RuntimeError): - check_crt_valid(self.cert_file) - - class TestDirectoryChecks: @pytest.fixture(autouse=True) def setup(self, tmp_path): @@ -533,4 +526,59 @@ class TestDirectoryChecks: # Test file path check_file_or_directory_path(self.test_file, isdir=False) # Test directory path - check_file_or_directory_path(self.test_dir, isdir=True) \ No newline at end of file + check_file_or_directory_path(self.test_dir, isdir=True) + + +cur_dir = os.path.dirname(os.path.realpath(__file__)) +zip_dir = os.path.join(cur_dir, 'test_temp_zip_file') + + +class TestCheckZipFile(unittest.TestCase): + def setUp(self): + os.makedirs(zip_dir, mode=0o750, exist_ok=True) + + def tearDown(self): + if os.path.exists(zip_dir): + shutil.rmtree(zip_dir) + + @staticmethod + def create_fake_zip_with_sizes(file_sizes): + """创建临时 zip 文件,file_sizes 为每个文件的大小列表,伪造一个具有 file_size=size 的 ZIP 条目""" + tmp_fd, tmp_path = tempfile.mkstemp(suffix=".zip", dir=zip_dir) + os.close(tmp_fd) + with ZipFile(tmp_path, 'w', allowZip64=True) as zipf: + for i, size in enumerate(file_sizes): + info = ZipInfo(f"file_{i}.bin") + zipf.writestr(info, b'') # 实际内容为空,但声明文件大小为 size + info.file_size = size + return tmp_path + + def test_valid_zip(self): + file_sizes = [100, 200, 300] + zip_path = self.create_fake_zip_with_sizes(file_sizes) + try: + check_zip_file(zip_path) + finally: + os.remove(zip_path) + + def test_single_file_too_large(self): + file_sizes = [FileCheckConst.MAX_FILE_IN_ZIP_SIZE + 1] + zip_path = self.create_fake_zip_with_sizes(file_sizes) + try: + with self.assertRaises(ValueError) as cm: + check_zip_file(zip_path) + self.assertIn("is too large to extract", str(cm.exception)) + finally: + os.remove(zip_path) + + def test_total_size_too_large(self): + count = 20 + size_each = (FileCheckConst.MAX_ZIP_SIZE // count) + 1 + file_sizes = [size_each] * count + zip_path = self.create_fake_zip_with_sizes(file_sizes) + try: + with self.assertRaises(ValueError) as cm: + check_zip_file(zip_path) + self.assertIn("Total extracted size exceeds the limit", str(cm.exception)) + finally: + os.remove(zip_path) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py index 3472ca9018e189ffb48e4d26cfeb79e1ba1ff16d..e7b1ae648a31fe691970f92da43769fc1ec9eda5 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ -# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Copyright (C) 2024-2025. Huawei Technologies Co., Ltd. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,12 +17,12 @@ import json import os import tempfile -from datetime import datetime, timezone +import unittest from unittest import TestCase from unittest.mock import MagicMock, mock_open, patch -import OpenSSL import numpy as np +from pathlib import Path from msprobe.core.common.const import Const from msprobe.core.common.file_utils import ( @@ -30,7 +30,6 @@ from msprobe.core.common.file_utils import ( FileCheckException, check_file_or_directory_path, check_file_size, - check_crt_valid, get_file_content_bytes, get_json_contents, save_json, @@ -39,21 +38,24 @@ from msprobe.core.common.log import logger from msprobe.core.common.exceptions import MsprobeException from msprobe.core.common.utils import (CompareException, check_compare_param, - check_configuration_param, _check_json, check_json_file, check_regex_prefix_format_valid, set_dump_path, get_dump_mode, - get_real_step_or_rank, - get_step_or_rank_from_string, + get_real_step_or_rank, + get_step_or_rank_from_string, get_stack_construct_by_dump_json_path, check_seed_all, safe_get_value, - recursion_depth_decorator, MsprobeBaseException, check_str_param, - is_json_file) + is_json_file, + detect_framework_by_dump_json, + is_save_variable_valid, + get_file_type, + check_dump_json_key) +from msprobe.core.common.decorator import recursion_depth_decorator class TestUtils(TestCase): @@ -131,14 +133,6 @@ class TestUtils(TestCase): self.assertEqual(len(mock__check_json.call_args[0]), 2) self.assertEqual(mock__check_json.call_args[0][1], "stack_path.json") - @patch.object(logger, "error") - def test_check_configuration_param(self, mock_error): - with self.assertRaises(CompareException) as context: - check_configuration_param(stack_mode="False", auto_analyze=True, fuzzy_match=False, - is_print_compare_log=True) - self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) - mock_error.assert_called_with("Invalid input parameter, False which should be only bool type.") - @patch.object(logger, "error") def test__check_json(self, mock_error): class TestOpen: @@ -203,7 +197,7 @@ class TestUtils(TestCase): with self.assertRaises(CompareException) as context: set_dump_path(input_param) self.assertEqual(context.exception.code, CompareException.INVALID_PATH_ERROR) - mock_error.assert_called_with("Please check the json path is valid. npu_path: None, bench_path: bench_path") + mock_error.assert_called_with("Please check the json path is valid and ensure that neither npu_path nor bench_path is None.") @patch.object(logger, "error") def test_get_dump_mode(self, mock_error): @@ -214,27 +208,53 @@ class TestUtils(TestCase): npu_json = { "task": Const.TENSOR, "dump_data_dir": "dump_data_dir", - "data": "data" + "data": {"api": "value"} } input_param["npu_json_path"] = "npu_path" - with patch("msprobe.core.common.utils.load_json", return_value=npu_json): + with patch("msprobe.core.common.utils.load_json", return_value=npu_json), \ + patch("msprobe.core.common.utils.get_file_type", return_value=Const.DUMP_JSON_FILE): dump_mode = get_dump_mode(input_param) self.assertEqual(dump_mode, Const.ALL) npu_json["task"] = Const.STATISTICS with patch("msprobe.core.common.utils.load_json", return_value=npu_json), \ - patch("msprobe.core.common.utils.md5_find", return_value=True): + patch("msprobe.core.common.utils.md5_find", return_value=True), \ + patch("msprobe.core.common.utils.get_file_type", return_value=Const.DUMP_JSON_FILE): dump_mode = get_dump_mode(input_param) self.assertEqual(dump_mode, Const.MD5) npu_json["task"] = Const.OVERFLOW_CHECK - with patch("msprobe.core.common.utils.load_json", return_value=npu_json): + with patch("msprobe.core.common.utils.load_json", return_value=npu_json), \ + patch("msprobe.core.common.utils.get_file_type", return_value=Const.DUMP_JSON_FILE): with self.assertRaises(CompareException) as context: dump_mode = get_dump_mode(input_param) self.assertEqual(context.exception.code, CompareException.INVALID_TASK_ERROR) mock_error.assert_called_with("Compare applies only to task is tensor or statistics") + def test_get_file_type(self): + # 测试有效的 file_path (dump.json) + file_path = 'path/to/dump.json' + expected_file_type = Const.DUMP_JSON_FILE + self.assertEqual(get_file_type(file_path), expected_file_type) + + # 测试有效的 file_path (debug.json) + file_path = 'path/to/debug.json' + expected_file_type = Const.DEBUG_JSON_FILE + self.assertEqual(get_file_type(file_path), expected_file_type) + + # 测试无效的 file_path + file_path = 'path/to/unknown.json' + with self.assertRaises(CompareException) as context: + get_file_type(file_path) + self.assertEqual(context.exception.code, CompareException.INVALID_PATH_ERROR) + + # 测试非字符串类型的 file_path + file_path = 12345 # 非字符串类型 + with self.assertRaises(CompareException) as context: + get_file_type(file_path) + self.assertEqual(context.exception.code, CompareException.INVALID_PATH_ERROR) + @patch('msprobe.core.common.file_utils.get_file_content_bytes') def test_get_json_contents_should_raise_exception(self, mock_get_file_content_bytes): mock_get_file_content_bytes.return_value = 'not a dict' @@ -334,7 +354,7 @@ class TestUtils(TestCase): def test_recursion_depth_decorator(self, mock_error): # 测试递归深度限制函数 recursion_list = [[]] - temp_list = recursion_list[0] + temp_list = recursion_list[0] for _ in range(Const.MAX_DEPTH): temp_list.append([]) temp_list = temp_list[0] @@ -436,55 +456,127 @@ class TestUtils(TestCase): self.assertFalse(is_json_file(file_path_false)) -class TestCheckCrtValid(TestCase): - """ - Test the check_crt_valid function. - """ +class TestDetectFrameworkByDumpJson(unittest.TestCase): + + @patch('msprobe.core.common.utils.load_json') + def test_valid_pytorch_framework(self, mock_load_json): + mock_load_json.return_value = {"framework": Const.PT_FRAMEWORK} + + result = detect_framework_by_dump_json("dummy_path") + + self.assertEqual(result, Const.PT_FRAMEWORK) + + @patch('msprobe.core.common.utils.load_json') + def test_valid_mindspore_framework(self, mock_load_json): + mock_load_json.return_value = {"framework": Const.MS_FRAMEWORK} + result = detect_framework_by_dump_json("dummy_path") + + self.assertEqual(result, Const.MS_FRAMEWORK) + + def test_detect_framework_in_file(self): + self.current_dir = Path(__file__).parent + file_path = self.current_dir / "test_dump_file/pt_dump_no_framework.json" + result = detect_framework_by_dump_json(file_path) + self.assertEqual(result, Const.PT_FRAMEWORK) + + self.current_dir = Path(__file__).parent + file_path = self.current_dir / "test_dump_file/ms_dump_no_framework.json" + result = detect_framework_by_dump_json(file_path) + self.assertEqual(result, Const.MS_FRAMEWORK) + + @patch("msprobe.core.common.utils.logger") + def test_detect_framework_exception(self, mock_logger): + self.current_dir = Path(__file__).parent + file_path = self.current_dir / "test_dump_file/dump_no_pt_no_ms.json" + with self.assertRaises(CompareException) as context: + result = detect_framework_by_dump_json(file_path) + self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) + mock_logger.error.assert_called_once_with(f"{file_path} must be based on the MindSpore or PyTorch framework.") + + +class TestIsSaveVariableValid(unittest.TestCase): def setUp(self): - self.cert_file_path = "cert_file_path.pem" - if not os.path.exists(self.cert_file_path): - with open(self.cert_file_path, 'w') as f: - f.write("This is a test certificate.") - - def tearDown(self): - if os.path.exists(self.cert_file_path): - os.remove(self.cert_file_path) - - @patch('msprobe.core.common.file_utils.datetime') - @patch('OpenSSL.crypto.load_certificate') - @patch('builtins.open', new_callable=mock_open, read_data="cert_data") - def test_check_crt_valid_success(self, mock_open_, mock_load_certificate, mock_datetime): - mock_cert = MagicMock() - mock_cert.get_notBefore.return_value = b'20220101' - mock_cert.get_notAfter.return_value = b'20230101' - mock_cert.has_expired.return_value = False - mock_load_certificate.return_value = mock_cert - mock_datetime.now.return_value = datetime(2022, 10, 1) - - check_crt_valid(self.cert_file_path) - mock_load_certificate.assert_called_once_with(OpenSSL.crypto.FILETYPE_PEM, 'cert_data') - - @patch('datetime.datetime') - @patch('OpenSSL.crypto.load_certificate') - @patch('builtins.open', new_callable=mock_open, read_data="cert_data") - def test_check_crt_valid_expired(self, mock_open_, mock_load_certificate, mock_datetime): - mock_cert = MagicMock() - mock_cert.get_notBefore.return_value = b'20220101' - mock_cert.get_notAfter.return_value = b'20230101' - mock_cert.has_expired.return_value = True - mock_load_certificate.return_value = mock_cert - mock_datetime.now.return_value = datetime(2022, 10, 1, tzinfo=timezone.utc) - - with self.assertRaises(RuntimeError) as context: - check_crt_valid(self.cert_file_path) - self.assertIn('The SSL certificate has expired and needs to be replaced', str(context.exception)) - - @patch('OpenSSL.crypto.load_certificate') - @patch('builtins.open', new_callable=mock_open, read_data="cert_data") - def test_check_crt_valid_exception(self, mock_open_, mock_load_certificate): - mock_load_certificate.side_effect = Exception('Test Exception') - - with self.assertRaises(RuntimeError) as context: - check_crt_valid(self.cert_file_path) - self.assertIn('The SSL certificate is invalid', str(context.exception)) + self.valid_special_types = (int, float, str, bool) + + @patch.object(Const, "DUMP_MAX_DEPTH", 5) + def test_is_save_variable_valid_DepthExceeded_ReturnsFalse(self): + # 构造深度 = 阈值 + 1 + nested = [0] * 3 + for _ in range(Const.DUMP_MAX_DEPTH + 1): # 注意 +1,确保“超过”阈值 + nested = [nested] + self.assertFalse(is_save_variable_valid(nested, self.valid_special_types)) + + def test_is_save_variable_valid_ValidSpecialTypes_ReturnsTrue(self): + for valid_type in self.valid_special_types: + self.assertTrue(is_save_variable_valid(valid_type(0), self.valid_special_types)) + + def test_is_save_variable_valid_ListWithValidElements_ReturnsTrue(self): + self.assertTrue(is_save_variable_valid([1, 2, 3], self.valid_special_types)) + + def test_is_save_variable_valid_ListWithInvalidElement_ReturnsFalse(self): + self.assertFalse(is_save_variable_valid([1, "test", [1, slice(1)]], self.valid_special_types)) + + def test_is_save_variable_valid_TupleWithValidElements_ReturnsTrue(self): + self.assertTrue(is_save_variable_valid((1, 2, 3), self.valid_special_types)) + + def test_is_save_variable_valid_TupleWithInvalidElement_ReturnsFalse(self): + self.assertFalse(is_save_variable_valid((1, "test", [1, slice(1)]), self.valid_special_types)) + + def test_is_save_variable_valid_DictWithValidElements_ReturnsTrue(self): + self.assertTrue(is_save_variable_valid({"a": 1, "b": "test"}, self.valid_special_types)) + + def test_is_save_variable_valid_DictWithInvalidKey_ReturnsFalse(self): + self.assertFalse(is_save_variable_valid({1: "test"}, self.valid_special_types)) + + def test_is_save_variable_valid_DictWithInvalidValue_ReturnsFalse(self): + self.assertFalse(is_save_variable_valid({"a": [1, slice(1)]}, self.valid_special_types)) + + +class TestCheckDumpJsonKey(unittest.TestCase): + def test_valid_input(self): + json_data = { + "task": "tensor", + "data": {"api1": "value1"} + } + task, api_data = check_dump_json_key(json_data, "NPU") + self.assertEqual(task, "tensor") + self.assertEqual(api_data, {"api1": "value1"}) + + @patch("msprobe.core.common.utils.logger") + def test_missing_task(self, mock_logger): + json_data = { + "data": {"api1": "value1"} + } + with self.assertRaises(CompareException) as context: + check_dump_json_key(json_data, "bench") + self.assertEqual(context.exception.code, CompareException.INVALID_TASK_ERROR) + mock_logger.error.assert_called_once_with( + "Task for bench is empty, please check." + ) + + @patch("msprobe.core.common.utils.logger") + def test_missing_data(self, mock_logger): + json_data = { + "task": "tensor" + } + with self.assertRaises(CompareException) as context: + check_dump_json_key(json_data, "npu") + self.assertEqual(context.exception.code, CompareException.INVALID_DATA_ERROR) + mock_logger.error.assert_called_once_with( + "Missing 'data' in dump.json, please check dump.json of npu." + ) + + @patch("msprobe.core.common.utils.logger") + def test_wrong_data_type(self, mock_logger): + json_data = { + "task": "tensor", + "data": [1] + } + with self.assertRaises(CompareException) as context: + check_dump_json_key(json_data, "npu") + self.assertEqual(context.exception.code, CompareException.INVALID_DATA_ERROR) + mock_logger.error.assert_called_once_with( + "Invalid type for 'data': expected a dict. Please check dump.json of npu." + ) + diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/find_first/test_analyzer.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/find_first/test_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..717f160be0c018dc21f7ac2fb0a9a4e863ee00e5 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/find_first/test_analyzer.py @@ -0,0 +1,232 @@ +import unittest +import os +import sys +import json +import tempfile +import shutil +from unittest.mock import patch, MagicMock, mock_open + +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))) + +from msprobe.core.compare.find_first.analyzer import DiffAnalyzer +from msprobe.core.compare.find_first.utils import RankPath, FileCache, DiffAnalyseConst +from msprobe.core.compare.find_first.graph import DataNode, CommunicationNode +from msprobe.core.common.const import Const + + +class TestDiffAnalyzer(unittest.TestCase): + def setUp(self): + # 创建临时目录用于测试 + self.temp_dir = tempfile.mkdtemp() + self.npu_path = os.path.join(self.temp_dir, "npu") + self.bench_path = os.path.join(self.temp_dir, "bench") + self.output_path = os.path.join(self.temp_dir, "output") + + # 创建目录结构 + os.makedirs(self.npu_path, exist_ok=True) + os.makedirs(self.bench_path, exist_ok=True) + os.makedirs(self.output_path, exist_ok=True) + + # 创建测试文件 + self.create_test_files() + + # 初始化分析器 + self.analyzer = DiffAnalyzer(self.npu_path, self.bench_path, self.output_path) + + def tearDown(self): + # 清理临时目录 + shutil.rmtree(self.temp_dir) + # 重置FileCache单例 + FileCache._instance = None + + def create_test_files(self): + # 创建比较结果文件 + compare_result_rank0 = os.path.join(self.output_path, "compare_result_rank0_123456.json") + compare_result_rank1 = os.path.join(self.output_path, "compare_result_rank1_123456.json") + + # 创建测试数据 + rank0_data = { + "Torch.add.1": { + "is_same": True, + "op_items": [ + {"NPU_Name": "input.0", "NPU_Max": 1.0, "NPU_Min": 0.0, "NPU_Mean": 0.5, "NPU_Norm": 0.7, "Stack": [["Torch.add.1", {"file": "test.py", "line": 10}]]} + ] + }, + "Distributed.all_reduce.2": { + "is_same": False, + "op_items": [ + {"NPU_Name": "input.0.dst", "NPU_Max": 1, "Stack": [["Distributed.all_reduce.2", {"file": "test.py", "line": 20}]]}, + {"NPU_Name": "output.0", "NPU_Max": 2.0, "Stack": "N/A"} + ] + }, + "Torch.mul.3": { + "is_same": False, + "op_items": [ + {"NPU_Name": "input.0", "NPU_Max": 2.0, "Stack": [["Torch.mul.3", {"file": "test.py", "line": 30}]]}, + {"NPU_Name": "output.0", "NPU_Max": 4.0, "Stack": "N/A"} + ] + } + } + + rank1_data = { + "Torch.add.1": { + "is_same": True, + "op_items": [ + {"NPU_Name": "input.0", "NPU_Max": 1.0, "Stack": [["Torch.add.1", {"file": "test.py", "line": 10}]]} + ] + }, + "Distributed.all_reduce.2": { + "is_same": True, + "op_items": [ + {"NPU_Name": "input.0.src", "NPU_Max": 0, "Stack": [["Distributed.all_reduce.2", {"file": "test.py", "line": 20}]]}, + {"NPU_Name": "output.0", "NPU_Max": 2.0, "Stack": "N/A"} + ] + } + } + + # 写入测试数据 + with open(compare_result_rank0, "w") as f: + json.dump(rank0_data, f) + + with open(compare_result_rank1, "w") as f: + json.dump(rank1_data, f) + + @patch('msprobe.core.compare.find_first.analyzer.DataProcessor') + def test_pre_process(self, mock_processor): + # 模拟预处理 + mock_processor_instance = mock_processor.return_value + self.analyzer.pre_processor = mock_processor_instance + + self.analyzer._pre_process() + + # 验证预处理调用 + mock_processor_instance.process.assert_called_once_with( + self.npu_path, self.bench_path, self.output_path + ) + + # 验证路径解析 + self.assertEqual(len(self.analyzer._paths), 2) # 应该有两个rank路径 + self.assertIn(0, self.analyzer._paths) + self.assertIn(1, self.analyzer._paths) + + def test_resolve_input_path(self): + # 测试解析输入路径 + self.analyzer._resolve_input_path(self.output_path) + + # 验证路径解析 + self.assertEqual(len(self.analyzer._paths), 2) # 应该有两个rank路径 + self.assertIn(0, self.analyzer._paths) + self.assertIn(1, self.analyzer._paths) + self.assertEqual(self.analyzer._paths[0].rank, 0) + self.assertEqual(self.analyzer._paths[1].rank, 1) + + @patch.object(FileCache, 'load_json') + def test_pre_analyze(self, mock_load_json): + # 模拟加载JSON数据 + mock_load_json.side_effect = lambda path: { + "Torch.add.1.forward": {"is_same": False, "op_items": []}, + "Distributed.all_reduce.2.forward": {"is_same": True, "op_items": []} + } if "rank0" in path else { + "Torch.add.1.forward": {"is_same": True, "op_items": []}, + "Distributed.all_reduce.2.forward": {"is_same": True, "op_items": []} + } + + # 设置路径 + self.analyzer._paths = { + 0: RankPath(0, os.path.join(self.output_path, "compare_result_rank0_123456.json")), + 1: RankPath(1, os.path.join(self.output_path, "compare_result_rank1_123456.json")) + } + + # 执行预分析 + self.analyzer._pre_analyze() + + # 验证结果 + self.assertEqual(len(self.analyzer._diff_nodes), 1) # 应该找到一个异常节点 + self.assertEqual(self.analyzer._diff_nodes[0].op_name, "Torch.add.1.forward") + self.assertEqual(self.analyzer._first_comm_nodes[1], "Distributed.all_reduce.2.forward") + + @patch.object(DiffAnalyzer, '_analyze_comm_nodes') + @patch.object(DiffAnalyzer, '_connect_comm_nodes') + @patch.object(DiffAnalyzer, '_pruning') + @patch.object(DiffAnalyzer, '_search_first_diff') + def test_analyze(self, mock_search, mock_pruning, mock_connect, mock_analyze_comm): + # 模拟分析过程 + self.analyzer._paths = { + 0: RankPath(0, os.path.join(self.output_path, "compare_result_rank0_123456.json")), + 1: RankPath(1, os.path.join(self.output_path, "compare_result_rank1_123456.json")) + } + + # 执行分析 + self.analyzer._analyze() + + # 验证调用 + mock_analyze_comm.assert_called() + mock_connect.assert_called_once() + mock_pruning.assert_called_once() + mock_search.assert_called_once() + + @patch.object(FileCache, 'load_json') + def test_analyze_comm_nodes(self, mock_load_json): + # 模拟加载JSON数据 + mock_load_json.return_value = { + "Distributed.all_reduce.1.forward": {"is_same": False, "op_items": []}, + "Torch.add.2": {"is_same": True, "op_items": []}, + "Distributed.all_reduce.3.forward": {"is_same": True, "op_items": []} + } + + # 设置首个通信节点 + self.analyzer._first_comm_nodes = {0: "Distributed.all_reduce.1.forward"} + + # 设置路径 + self.analyzer._paths = { + 0: RankPath(0, os.path.join(self.output_path, "compare_result_rank0_123456.json")) + } + + # 执行通信节点分析 + result = self.analyzer._analyze_comm_nodes(0) + + # 验证结果 + self.assertEqual(len(result), 2) # 应该有两个通信节点 + self.assertIn("0.Distributed.all_reduce.1.forward", result) + self.assertIn("0.Distributed.all_reduce.3.forward", result) + + def test_get_node_by_id(self): + # 设置通信节点字典 + node = MagicMock(spec=CommunicationNode) + self.analyzer._rank_comm_nodes_dict = {0: {"0.Distributed.all_reduce.1.forward": node}} + + # 测试获取节点 + result = self.analyzer._get_node_by_id("0.Distributed.all_reduce.1.forward") + + # 验证结果 + self.assertEqual(result, node) + + # 测试无效节点ID + with self.assertRaises(RuntimeError): + self.analyzer._get_node_by_id("invalid_id") + + @patch('msprobe.core.compare.find_first.analyzer.save_json') + @patch('msprobe.core.compare.find_first.analyzer.make_dir') + @patch('msprobe.core.compare.find_first.analyzer.time') + def test_gen_analyze_info(self, mock_time, mock_make_dir, mock_save_json): + # 模拟时间戳 + mock_time.time_ns.return_value = 123456789 + + # 设置异常节点 + node = MagicMock(spec=DataNode) + node.rank = 0 + node.gen_node_info.return_value = {"op_name": "test_op"} + self.analyzer._diff_nodes = [node] + + # 设置路径 + self.analyzer._paths = {0: MagicMock(spec=RankPath)} + + # 生成分析信息 + self.analyzer._gen_analyze_info() + + # 验证调用 + mock_save_json.assert_called_once() + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/find_first/test_data_processor.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/find_first/test_data_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..c37a707f1d882f6a2669ad916b295bac7491ce29 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/find_first/test_data_processor.py @@ -0,0 +1,44 @@ +import unittest +import os +import sys +from unittest.mock import patch, MagicMock + +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))) + +from msprobe.core.compare.find_first.data_processor import DataProcessor +from msprobe.core.common.const import Const + + +class TestDataProcessor(unittest.TestCase): + def setUp(self): + # 创建测试路径 + self.npu_path = "/path/to/npu" + self.bench_path = "/path/to/bench" + self.output_path = "/path/to/output" + + def test_init_pytorch(self): + # 测试PyTorch框架初始化 + processor = DataProcessor(Const.PT_FRAMEWORK) + from msprobe.pytorch.compare.distributed_compare import compare_distributed + + # 验证初始化 + self.assertEqual(processor.data_frame, Const.PT_FRAMEWORK) + self.assertEqual(processor.process_func, compare_distributed) + + def test_init_mindspore(self): + # 测试MindSpore框架初始化 + processor = DataProcessor(Const.MS_FRAMEWORK) + from msprobe.mindspore.compare.distributed_compare import ms_compare_distributed + + # 验证初始化 + self.assertEqual(processor.data_frame, Const.MS_FRAMEWORK) + self.assertEqual(processor.process_func, ms_compare_distributed) + + + def test_init_unsupported(self): + # 测试不支持的框架 + with self.assertRaises(ValueError): + DataProcessor("unsupported_framework") + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/find_first/test_find_first_graph.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/find_first/test_find_first_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..e9dd810d780e40b29a5dc0a76e3ca07870f0ee31 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/find_first/test_find_first_graph.py @@ -0,0 +1,224 @@ +import unittest +import os +import sys +from unittest.mock import patch, MagicMock + +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))) + +from msprobe.core.compare.find_first.graph import DataNode, CommunicationNode +from msprobe.core.compare.find_first.utils import RankPath, DiffAnalyseConst +from msprobe.core.common.const import Const, CompareConst + + +class TestDataNode(unittest.TestCase): + def setUp(self): + # 创建测试数据 + self.op_name = "Torch.add.1.forward" + self.rank = 0 + self.op_data = { + "is_same": False, + "op_items": [ + { + CompareConst.NPU_NAME: "input.0", + CompareConst.NPU_MAX: 1.0, + CompareConst.NPU_MIN: 0.0, + CompareConst.NPU_MEAN: 0.5, + CompareConst.NPU_NORM: 0.7, + CompareConst.STACK: [["Torch.add.1.forward", {"file": "test.py", "line": 10}]] + }, + { + CompareConst.NPU_NAME: "output.0", + CompareConst.NPU_MD5: "abc123", + CompareConst.STACK: CompareConst.N_A + } + ] + } + + def test_init(self): + # 测试初始化 + node = DataNode(self.op_name, self.rank, self.op_data) + + # 验证基本属性 + self.assertEqual(node.op_name, self.op_name) + self.assertEqual(node.rank, self.rank) + self.assertTrue(node.is_diff) # is_same为False,所以is_diff应为True + self.assertEqual(node.layer, 0) + self.assertEqual(node.sub_layer, 0) + + # 验证输入输出解析 + self.assertIn("input.0", node.inputs) + self.assertIn("output.0", node.outputs) + self.assertEqual(node.inputs["input.0"][CompareConst.NPU_MAX], 1.0) + self.assertEqual(node.outputs["output.0"][CompareConst.NPU_MD5], "abc123") + + # 验证堆栈信息 + self.assertIsNotNone(node.stack) + + def test_find_stack(self): + # 测试查找堆栈信息 + node = DataNode(self.op_name, self.rank, self.op_data) + stack_info = node.find_stack() + + # 验证堆栈信息 + self.assertEqual(stack_info, {"file": "test.py", "line": 10}) + + def test_gen_node_info(self): + # 测试生成节点信息 + node = DataNode(self.op_name, self.rank, self.op_data) + mock_path = MagicMock(spec=RankPath) + + info = node.gen_node_info(mock_path) + + # 验证节点信息 + self.assertEqual(info["op_name"], self.op_name) + self.assertIn(Const.INPUT, info["data_info"]) + self.assertIn(Const.OUTPUT, info["data_info"]) + self.assertIsNotNone(info["stack_info"]) + + +class TestCommunicationNode(unittest.TestCase): + def setUp(self): + # 创建测试数据 + self.op_name = "Distributed.all_reduce.1.forward" + self.rank = 0 + self.node_id = f"{self.rank}.{self.op_name}" + + # 创建模拟的DataNode + self.data_node = MagicMock(spec=DataNode) + self.data_node.op_name = self.op_name + self.data_node.is_diff = False + self.data_node.inputs = { + "input.0.dst": {CompareConst.NPU_MAX: 1}, + "input.1.group": {CompareConst.NPU_MD5: "[0,1,2]"} + } + + def test_init(self): + # 测试初始化 + node = CommunicationNode(self.node_id, self.rank, self.data_node) + + # 验证基本属性 + self.assertEqual(node.node_id, self.node_id) + self.assertEqual(node.rank, self.rank) + self.assertEqual(node.data, self.data_node) + self.assertFalse(node.is_diff) + self.assertEqual(node.layer, 0) + self.assertEqual(node.api, "all_reduce") + self.assertEqual(node.call_cnt, "1") + self.assertFalse(node.connected) + + # 验证节点关系初始化 + self.assertIsNone(node.pre_node) + self.assertEqual(node.link_nodes, {}) + self.assertEqual(node.dst_nodes, {}) + self.assertEqual(node.src_nodes, {}) + self.assertEqual(node.next_nodes, {}) + self.assertEqual(node.compute_ops, []) + + def test_add_next(self): + # 测试添加下一个节点 + node = CommunicationNode(self.node_id, self.rank, self.data_node) + next_node = CommunicationNode("1.Distributed.all_reduce.2.forward", 1, self.data_node) + + node.add_next(next_node) + + # 验证节点关系 + self.assertIn(next_node.node_id, node.next_nodes) + self.assertEqual(next_node.pre_node, node) + self.assertEqual(next_node.layer, node.layer + 1) + self.assertEqual(next_node.data.layer, next_node.layer) + + def test_add_link(self): + # 测试添加链接节点 + node = CommunicationNode(self.node_id, self.rank, self.data_node) + link_node = CommunicationNode("1.Distributed.all_reduce.2.forward", 1, self.data_node) + + node.add_link(link_node) + + # 验证节点关系 + self.assertIn(link_node.node_id, node.link_nodes) + self.assertIn(node.node_id, link_node.link_nodes) + self.assertEqual(link_node.layer, node.layer) + self.assertEqual(link_node.data.layer, link_node.layer) + self.assertTrue(node.connected) + self.assertTrue(link_node.connected) + + def test_add_dst(self): + # 测试添加目标节点 + node = CommunicationNode(self.node_id, self.rank, self.data_node) + dst_node = CommunicationNode("1.Distributed.all_reduce.2.forward", 1, self.data_node) + + node.add_dst(dst_node) + + # 验证节点关系 + self.assertIn(dst_node.node_id, node.dst_nodes) + self.assertIn(node.node_id, dst_node.src_nodes) + self.assertEqual(dst_node.layer, node.layer) + self.assertEqual(dst_node.data.layer, dst_node.layer) + self.assertTrue(node.connected) + self.assertTrue(dst_node.connected) + + def test_delete(self): + # 测试删除节点 + node = CommunicationNode(self.node_id, self.rank, self.data_node) + next_node = CommunicationNode("1.Distributed.all_reduce.2.forward", 1, self.data_node) + dst_node = CommunicationNode("2.Distributed.all_reduce.3.forward", 2, self.data_node) + src_node = CommunicationNode("3.Distributed.all_reduce.4.forward", 3, self.data_node) + link_node = CommunicationNode("4.Distributed.all_reduce.5.forward", 4, self.data_node) + pre_node = CommunicationNode("5.Distributed.all_reduce.6.forward", 5, self.data_node) + + # 建立节点关系 + node.add_next(next_node) + node.add_dst(dst_node) + src_node.add_dst(node) + node.add_link(link_node) + pre_node.add_next(node) + + # 删除节点 + node.delete() + + # 验证节点关系已清除 + self.assertIsNone(next_node.pre_node) + self.assertNotIn(node.node_id, dst_node.src_nodes) + self.assertNotIn(node.node_id, src_node.dst_nodes) + self.assertNotIn(node.node_id, link_node.link_nodes) + self.assertNotIn(node.node_id, pre_node.next_nodes) + + def test_find_connected_nodes(self): + # 测试查找连接节点 + node = CommunicationNode(self.node_id, self.rank, self.data_node) + + # 模拟输入数据 + node.data.inputs = { + "input.0.dst": {CompareConst.NPU_MAX: 1}, + "input.1.group": {CompareConst.NPU_MD5: "[0,1,2]"} + } + + result = node.find_connected_nodes() + + # 验证结果 + self.assertIn(1, result["ranks"]) + self.assertIn(0, result["ranks"]) + self.assertIn(2, result["ranks"]) + self.assertEqual(result["api"], "Distributed.all_reduce") + self.assertEqual(result["type"], DiffAnalyseConst.DST) + + def test_resolve_type(self): + # 测试解析节点类型 + # 测试SRC类型 + self.data_node.inputs = {"input.0.src": {CompareConst.NPU_MAX: 0}} + node = CommunicationNode(self.node_id, 0, self.data_node) + self.assertEqual(node.type, DiffAnalyseConst.SRC) + + # 测试DST类型 + self.data_node.inputs = {"input.0.dst": {CompareConst.NPU_MAX: 0}} + node = CommunicationNode(self.node_id, 0, self.data_node) + self.assertEqual(node.type, DiffAnalyseConst.DST) + + # 测试LINK类型(默认) + self.data_node.inputs = {"input.0": {CompareConst.NPU_MAX: 0}} + node = CommunicationNode(self.node_id, 0, self.data_node) + self.assertEqual(node.type, DiffAnalyseConst.LINK) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/find_first/test_find_first_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/find_first/test_find_first_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1f9ace22b56e58cba7836717d85cbb2c6b51dcf9 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/find_first/test_find_first_utils.py @@ -0,0 +1,161 @@ +import unittest +import os +import sys +import json +import tempfile +from unittest.mock import patch, MagicMock, mock_open + +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))) + +from msprobe.core.compare.find_first.utils import ( + RankPath, FileCache, is_communication_op, is_ignore_op, + DiffAnalyseConst, analyze_diff_in_group +) +from msprobe.core.common.const import Const + + +class TestRankPath(unittest.TestCase): + def setUp(self): + # 创建临时文件用于测试 + self.temp_dir = tempfile.mkdtemp() + self.dump_path = os.path.join(self.temp_dir, "dump.json") + # 创建一个空文件 + with open(self.dump_path, "w") as f: + f.write("{}") + + def tearDown(self): + # 清理临时文件 + if os.path.exists(self.dump_path): + os.remove(self.dump_path) + os.rmdir(self.temp_dir) + + def test_init(self): + # 测试正常初始化 + rank_path = RankPath(1, self.dump_path) + self.assertEqual(rank_path.rank, 1) + self.assertEqual(rank_path.dump_path, self.dump_path) + + @patch('msprobe.core.compare.find_first.utils.check_file_or_directory_path') + def test_init_with_invalid_path(self, mock_check): + # 测试无效路径 + mock_check.side_effect = ValueError("Invalid path") + with self.assertRaises(ValueError): + RankPath(1, "/invalid/path") + + +class TestFileCache(unittest.TestCase): + def setUp(self): + # 重置单例 + FileCache._instance = None + self.cache = FileCache() + # 创建临时文件用于测试 + self.temp_dir = tempfile.mkdtemp() + self.json_path = os.path.join(self.temp_dir, "test.json") + self.test_data = {"key": "value"} + with open(self.json_path, "w") as f: + json.dump(self.test_data, f) + + def tearDown(self): + # 清理临时文件 + if os.path.exists(self.json_path): + os.remove(self.json_path) + os.rmdir(self.temp_dir) + + def test_singleton(self): + # 测试单例模式 + cache2 = FileCache() + self.assertIs(self.cache, cache2) + + def test_load_json(self): + # 测试加载JSON文件 + result = self.cache.load_json(self.json_path) + self.assertEqual(result, self.test_data) + + # 测试缓存功能 + result2 = self.cache.load_json(self.json_path) + self.assertEqual(result2, self.test_data) + self.assertEqual(self.cache._access_cnt[self.json_path], 1) # 访问计数应该增加 + + @patch('msprobe.core.compare.find_first.utils.load_json') + def test_cleanup(self, mock_load_json): + # 模拟大文件加载 + mock_load_json.return_value = {"large": "x" * 1000000} # 创建一个大对象 + + # 修改最大内存使用量为很小的值,强制清理 + original_max = self.cache._max_memory_usage + self.cache._max_memory_usage = 100 # 设置为很小的值 + + # 加载多个文件触发清理 + for i in range(5): + self.cache.load_json(f"{self.json_path}_{i}") + + # 恢复原始值 + self.cache._max_memory_usage = original_max + + +class TestCommunicationFunctions(unittest.TestCase): + def test_is_communication_op(self): + # 测试通信算子识别 + self.assertTrue(is_communication_op("Distributed.all_reduce.1")) + self.assertTrue(is_communication_op("Distributed.send.2")) + self.assertTrue(is_communication_op("Distributed.broadcast.3")) + self.assertTrue(is_communication_op(f"{Const.MINT_DIST_API_TYPE_PREFIX}.all_gather.4")) + self.assertTrue(is_communication_op(f"{Const.MS_API_TYPE_COM}.reduce.5")) + + # 测试非通信算子 + self.assertFalse(is_communication_op("Torch.add.1")) + self.assertFalse(is_communication_op("Torch.matmul.2")) + + def test_is_ignore_op(self): + # 测试忽略算子识别 + self.assertTrue(is_ignore_op("Torch.empty.1")) + self.assertTrue(is_ignore_op("Torch.fill.2")) + + # 测试非忽略算子 + self.assertFalse(is_ignore_op("Torch.add.1")) + self.assertFalse(is_ignore_op("Torch.matmul.2")) + + +class TestDiffAnalyseConst(unittest.TestCase): + def test_constants(self): + # 测试常量定义 + self.assertIn('send', DiffAnalyseConst.COMMUNICATION_KEYWORDS) + self.assertIn('recv', DiffAnalyseConst.COMMUNICATION_KEYWORDS) + self.assertIn('all_reduce', DiffAnalyseConst.COMMUNICATION_KEYWORDS) + + # 测试P2P API映射 + self.assertEqual(DiffAnalyseConst.P2P_API_MAPPING['send'], 'recv') + self.assertEqual(DiffAnalyseConst.P2P_API_MAPPING['recv'], 'send') + + # 测试方向常量 + self.assertEqual(DiffAnalyseConst.OPPOSITE_DIR[DiffAnalyseConst.SRC], DiffAnalyseConst.DST) + self.assertEqual(DiffAnalyseConst.OPPOSITE_DIR[DiffAnalyseConst.DST], DiffAnalyseConst.SRC) + + +class TestAnalyzeDiffInGroup(unittest.TestCase): + def test_analyze_diff_in_group_empty(self): + # 测试空组 + result = analyze_diff_in_group([]) + self.assertEqual(result, []) + + def test_analyze_diff_in_group(self): + # 创建模拟的通信节点 + mock_node1 = MagicMock() + mock_node1.type = DiffAnalyseConst.SRC + mock_node1.is_diff = True + mock_node1.compute_ops = [MagicMock(), MagicMock()] + + mock_node2 = MagicMock() + mock_node2.type = DiffAnalyseConst.DST + mock_node2.is_diff = False + mock_node2.data.is_diff = True + + # 测试分析函数 + result = analyze_diff_in_group([mock_node1, mock_node2]) + + # 验证结果包含所有异常节点 + self.assertEqual(len(result), 4) # 2个计算节点 + 2个通信节点 + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py index b4566fcfe6f48d9040feb4dc22f3a96cd08719a7..3cbef9af6c1968d6aa4f553ec675154aaeba2a24 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py @@ -6,15 +6,49 @@ import threading import unittest from unittest.mock import patch +import numpy as np import pandas as pd import torch +from msprobe.core.common.file_utils import load_json from msprobe.core.common.const import CompareConst, Const from msprobe.core.common.utils import CompareException -from msprobe.core.compare.acc_compare import Comparator, ModeConfig, get_bench_data_name -from msprobe.core.compare.highlight import find_error_rows, find_compare_result_error_rows, ApiBatch -from msprobe.core.compare.utils import get_accuracy -from msprobe.pytorch.compare.pt_compare import PTComparator +from msprobe.core.compare.acc_compare import ModeConfig, MappingConfig, MappingDict, Comparator, ParseData, ProcessDf, \ + Match, CreateTable, CalcStatsDiff + +npu_op_item_data_fuzzy = { + 'op_name': 'Functional.conv2d.0.forward.input.0', + 'dtype': 'torch.float32', + 'shape': [1, 1, 28, 28], + 'summary': [3.029174327850342, -2.926689624786377, -0.06619918346405029], + 'stack_info': [], + 'data_name': 'Functional.conv2d.0.forward.input.0.pt', + 'compare_key': 'Functional.conv2d.0.forward.input.0', + 'compare_shape': [1, 1, 28, 28], +} +npu_op_item_fuzzy = pd.Series(npu_op_item_data_fuzzy) +npu_op_item_data_fuzzy_2 = { + 'op_name': 'Functional.conv2d.0.forward.input.1', + 'dtype': 'torch.float32', + 'shape': [1, 1, 28, 28], + 'summary': [3.029174327850342, -2.926689624786377, -0.06619918346405029], + 'stack_info': [], + 'data_name': 'Functional.conv2d.0.forward.input.1.pt', + 'compare_key': 'Functional.conv2d.0.forward.input.1', + 'compare_shape': [1, 1, 28, 28], +} +npu_op_item_fuzzy_2 = pd.Series(npu_op_item_data_fuzzy_2) +bench_op_item_data_fuzzy = { + 'op_name': 'Functional.conv2d.1.forward.input.0', + 'dtype': 'torch.float32', + 'shape': [1, 1, 28, 28], + 'summary': [3.029174327850342, -2.926689624786377, -0.06619918346405029], + 'stack_info': [], + 'data_name': 'Functional.conv2d.1.forward.input.0.pt', + 'compare_key': 'Functional.conv2d.1.forward.input.0', + 'compare_shape': [1, 1, 28, 28], +} +bench_op_item_fuzzy = pd.Series(bench_op_item_data_fuzzy) npu_dict = {'op_name': ['Functional.conv2d.0.forward.input.0', 'Functional.conv2d.0.forward.input.1', 'Functional.conv2d.0.forward.input.2', 'Functional.conv2d.0.forward.output'], @@ -159,50 +193,21 @@ aten_result = [ -10.640625, -0.008758544921875, 5.397906303405762, -5.796811580657959, 2.5283952709287405e-10, 'Warning', 'Need double check api accuracy.', 'None'], ['Aten__native_batch_norm_legit_functional.default_0_forward.output.1', 'Nan', 'torch.float32', 'Nan', [256], 'Nan', - ' ', ' ', ' ', ' ', ' ', 0.30550330877304077, -0.24485322833061218, -0.010361209511756897, 'Nan', 'Nan', 'Nan', + ' ', ' ', ' ', ' ', ' ', ' ', 0.30550330877304077, -0.24485322833061218, -0.010361209511756897, 'Nan', 'Nan', + 'Nan', 'Yes', '', 'None'], ['Aten__native_batch_norm_legit_functional.default_0_forward.output.2', 'Nan', 'torch.float32', 'Nan', [256], 'Nan', - ' ', ' ', ' ', ' ', ' ', 623.9192504882812, 432.96826171875, 520.2276611328125, 'Nan', 'Nan', 'Nan', + ' ', ' ', ' ', ' ', ' ', ' ', 623.9192504882812, 432.96826171875, 520.2276611328125, 'Nan', 'Nan', 'Nan', 'Yes', '', 'None'], ['Aten__native_batch_norm_legit_functional.default_0_forward.output.3', 'Nan', 'torch.float32', 'Nan', [256], 'Nan', - ' ', ' ', ' ', ' ', ' ', 2.4797861576080322, -3.055997371673584, -0.04795549064874649, 'Nan', 'Nan', 'Nan', + ' ', ' ', ' ', ' ', ' ', ' ', 2.4797861576080322, -3.055997371673584, -0.04795549064874649, 'Nan', 'Nan', 'Nan', 'Yes', '', 'None'], ['Aten__native_batch_norm_legit_functional.default_0_forward.output.4', 'Nan', 'torch.float32', 'Nan', [256], 'Nan', - ' ', ' ', ' ', ' ', ' ', 61.7945556640625, 42.59713363647461, 52.03831481933594, 'Nan', 'Nan', 'Nan', + ' ', ' ', ' ', ' ', ' ', ' ', 61.7945556640625, 42.59713363647461, 52.03831481933594, 'Nan', 'Nan', 'Nan', 'Yes', '', 'None']] highlight_dict = {'red_rows': [], 'yellow_rows': []} -num_0, num_1, num_2, num_3 = 0, 1, 2, 3 -summary_line_input = ['Functional_batch_norm_0_forward.input.0', 'Functional_batch_norm_0_forward.input.0', - 'torch.float16', - 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.01, 0, 0, 0, 1, 1, 1, 1, 1.01, 1, 1, 1, - 'Yes', ''] -summary_line_1 = ['Functional_batch_norm_0_forward.output.0', 'Functional_batch_norm_0_forward.output.0', - 'torch.float16', - 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 10, 0, 0, 0, 2, 0, 1, 1, 1, 1, 1, 1, - 'Warning', ''] -summary_line_2 = ['Functional_batch_norm_0_forward.output.1', 'Functional_batch_norm_0_forward.output.1', - 'torch.float16', - 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.02, 0, 0, 0, 0.12, 0, 1, 1, 0.1, 1, 1, 1, - 'Warning', ''] -summary_line_3 = ['Functional_batch_norm_0_forward.output.2', 'Functional_batch_norm_0_forward.output.2', - 'torch.float16', - 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0, 0, 0, 0, 2, 0, 1, 1, 1, 1, 1, 1, - 'Warning', ''] -line_input = ['Functional.batch.norm.0.forward.input.0', 'Functional.batch.norm.0.forward.input.0', 'torch.float16', - 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 1, 1, 1, 0.95, 1, 1, 1, 1, 1, 1.01, 1, 1, 1, - 'Yes', ''] -line_1 = ['Functional.batch.norm.0.forward.output.0', 'Functional.batch.norm.0.forward.output.0', 'torch.float16', - 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.8, 1, 1, 0.59, 1, 'nan', 0, 1, 1, 19, 1, 1, 1, - 'Warning', ''] -line_2 = ['Functional.batch.norm.0.forward.output.1', 'Functional.batch.norm.0.forward.output.1', 'torch.float16', - 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.9, 1, 1, 0.8, 1, 0, 0.12, 0, 1, 1, 0.1, 1, 1, 1, - 'Warning', ''] -line_3 = ['Functional.batch.norm.0.forward.output.2', 'Functional.batch.norm.0.forward.output.2', 'torch.float16', - 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.8, 1.1e+10, 1, 0.85, 1, 9, 0.12, 0, 1, 1, 0.1, 1, - 1, 1, 'Warning', ''] - op_data = { 'input_args': [{'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3], 'Max': 0.33033010363578796, 'Min': -0.331031858921051, 'Mean': -0.030964046716690063, @@ -263,6 +268,33 @@ def generate_dump_json(base_dir): json.dump(data, json_file) +def generate_dump_json_md5(base_dir): + data_path = os.path.join(base_dir, 'dump_md5.json') + data = { + 'task': 'statistics', + 'level': 'L1', + 'dump_data_dir': '', + 'data': { + 'Functional.linear.0.forward': { + 'input_args': [ + {'type': 'torch.Tensor', + 'dtype': 'torch.float32', + 'shape': [2, 2], + 'Max': 2, + 'Min': 0, + 'Mean': 1, + 'Norm': 1, + 'requires_grad': False, + 'md5': 123456 + } + ] + } + } + } + with open(data_path, 'w') as json_file: + json.dump(data, json_file) + + def generate_stack_json(base_dir): data_path = os.path.join(base_dir, 'stack.json') data = {'Functional.linear.0.forward': ['File']} @@ -296,145 +328,6 @@ class TestUtilsMethods(unittest.TestCase): if os.path.exists(base_dir3): shutil.rmtree(base_dir3) - def test_get_accuracy_graph_mode(self): - result = [] - get_accuracy(result, npu_dict_aten, bench_dict_functional, dump_mode=Const.SUMMARY) - self.assertEqual(result, aten_result) - - def test_find_error_rows(self): - api_batch = ApiBatch("Functional_batch_norm_0_forward", 0) - api_batch.input_len = 1 - api_batch.output_end_index = 4 - api_batch.params_end_index = 4 - summary_result = [summary_line_input, summary_line_1, summary_line_2, summary_line_3] - highlight_dict_test = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []} - find_error_rows(summary_result, api_batch, highlight_dict_test, dump_mode=Const.SUMMARY) - self.assertEqual(highlight_dict_test, - {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []}) - - def test_find_compare_result_error_rows(self): - result = [line_input, line_1, line_2, line_3] - result_df = pd.DataFrame(result) - highlight_dict_test = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []} - find_compare_result_error_rows(result_df, highlight_dict_test, dump_mode=Const.ALL) - self.assertEqual(highlight_dict_test, { - "red_rows": {1, 3}, - "yellow_rows": {2}, - "red_lines": [ - (1, ["maximum or minimum is nan, -inf, or inf"]), - (3, ["maximum absolute error exceeds 1e+10"]) - ], - "yellow_lines": [ - (2, ["The output's one thousandth err ratio decreases by more than 0.1 compared to the input/parameters's"]), - (3, [ - "maximum absolute error of both input/parameters and output exceed 1, " - "with the output larger by an order of magnitude", - "The output's cosine decreases by more than 0.1 compared to the input/parameters's"]) - ] - }) - - def test_calculate_summary_data(self): - npu_summary_data = [1, 1, 1, 1] - bench_summary_data = [2, 2, 2, 2] - result_item = ['', '', '', '', '', '', '', '', '', '', '', '', '', ''] - - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - - comparator = Comparator(mode_config) - comparator.calculate_summary_data(npu_summary_data, bench_summary_data, result_item) - self.assertEqual(result_item, - ['', '', '', '', '', '', -1, -1, -1, -1, '50.0%', '50.0%', '50.0%', '50.0%', '', '']) - - bench_summary_data = [0, 0, 0, 0] - result_item = ['', '', '', '', '', '', '', '', '', '', '', '', '', ''] - - comparator.calculate_summary_data(npu_summary_data, bench_summary_data, result_item) - self.assertEqual(result_item, ['', '', '', '', '', '', 1, 1, 1, 1, 'N/A', 'N/A', 'N/A', 'N/A', 'Warning', - 'Need double check api accuracy.']) - - def test_make_result_table_stack_mode_True(self): - result_md5 = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], '', '', '', 'File']] - result_summary = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], '', '', '', '', '', '', '', '', - 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', 'File']] - result_all = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], '', '', '', '', '', - 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', 'File', '-1']] - columns_md5_stack_mode_true = CompareConst.MD5_COMPARE_RESULT_HEADER + ['NPU_Stack_Info'] - result_table_md5_true = pd.DataFrame(result_md5, columns=columns_md5_stack_mode_true, dtype=object) - columns_summary_stack_mode_true = CompareConst.SUMMARY_COMPARE_RESULT_HEADER + ['NPU_Stack_Info'] - result_table_summary_true = pd.DataFrame(result_summary, columns=columns_summary_stack_mode_true, dtype=object) - columns_all_stack_mode_true = CompareConst.COMPARE_RESULT_HEADER + ['NPU_Stack_Info'] + ['Data_name'] - result_table_all_true = pd.DataFrame(result_all, columns=columns_all_stack_mode_true, dtype=object) - - stack_mode = True - auto_analyze = True - fuzzy_match = False - - dump_mode = Const.MD5 - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - result_df = Comparator(mode_config).make_result_table(result_md5) - self.assertTrue(result_df.equals(result_table_md5_true)) - - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - result_df = Comparator(mode_config).make_result_table(result_summary) - self.assertTrue(result_df.equals(result_table_summary_true)) - - dump_mode = Const.ALL - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - result_df = Comparator(mode_config).make_result_table(result_all) - self.assertTrue(result_df.equals(result_table_all_true)) - - def test_make_result_table_stack_mode_False(self): - result_md5_test = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], '', '', '', '']] - result_md5 = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], '', '', '']] - result_summary_test = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], '', '', '', '', '', '', '', '', - 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '']] - result_summary = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], '', '', '', '', '', '', '', '', - 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '']] - result_all_test = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], '', '', '', '', '', - 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '', '-1']] - result_all = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], '', '', '', '', '', - 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1']] - columns_md5_stack_mode_true = CompareConst.MD5_COMPARE_RESULT_HEADER - result_table_md5_true = pd.DataFrame(result_md5, columns=columns_md5_stack_mode_true, dtype='object') - columns_summary_stack_mode_true = CompareConst.SUMMARY_COMPARE_RESULT_HEADER - result_table_summary_true = pd.DataFrame(result_summary, columns=columns_summary_stack_mode_true, - dtype='object') - columns_all_stack_mode_true = CompareConst.COMPARE_RESULT_HEADER + ['Data_name'] - result_table_all_true = pd.DataFrame(result_all, columns=columns_all_stack_mode_true, dtype='object') - - stack_mode = False - auto_analyze = True - fuzzy_match = False - - dump_mode = Const.MD5 - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - result_df = Comparator(mode_config).make_result_table(result_md5_test) - self.assertTrue(result_df.equals(result_table_md5_true)) - - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - result_df = Comparator(mode_config).make_result_table(result_summary_test) - self.assertTrue(result_df.equals(result_table_summary_true)) - - dump_mode = Const.ALL - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - result_df = Comparator(mode_config).make_result_table(result_all_test) - self.assertTrue(result_df.equals(result_table_all_true)) - def test_gen_merge_list(self): op_data = { 'input_args': [ @@ -449,310 +342,617 @@ class TestUtilsMethods(unittest.TestCase): json_data = {'data': {'Functional.linear.0.forward': op_data}} op_name = 'Functional.linear.0.forward' stack_json_data = {'Functional.linear.0.forward': ['File']} - merge_list = { - 'input_struct': [('torch.float32', [2, 2])], - 'op_name': ['Functional.linear.0.forward.input.0'], - 'output_struct': [], - 'params_struct': [], - 'params_grad_struct': [], - 'stack_info': [['File']], - 'summary': [[1, 1, 1, 1]] + target_merge_list = [ + { + 'full_op_name': 'Functional.linear.0.forward.input.0', + 'type': 'torch.Tensor', + 'dtype': 'torch.float32', + 'shape': [2, 2], + 'requires_grad': 'False', + 'Max': 1, + 'Min': 1, + 'Mean': 1, + 'Norm': 1, + 'md5': '00000000', + 'data_name': 'Functional.linear.0.forward.input.0.pt', + 'state': 'input' + }, + { + 'full_op_name': 'Functional.linear.0.forward', + 'full_info': ['File'] + } + ] + + config_dict = { + 'stack_mode': True, + 'auto_analyze': True, + 'fuzzy_match': False, + 'dump_mode': Const.SUMMARY, } + mode_config = ModeConfig(**config_dict) + + result = ParseData(mode_config, 'rank0').gen_merge_list(json_data, op_name, stack_json_data) + self.assertEqual(result, target_merge_list) + + def test_check_op_item_fuzzy(self): + config_dict = { + 'stack_mode': False, + 'auto_analyze': True, + 'fuzzy_match': True, + 'dump_mode': Const.SUMMARY, + } + mode_config = ModeConfig(**config_dict) + mapping_config = MappingConfig() + + match = Match(mode_config, mapping_config, cross_frame=False) + result = match.check_op_item(npu_op_item_fuzzy, bench_op_item_fuzzy) + self.assertEqual(result, True) + + def test_compare_statistics(self): + generate_dump_json(base_dir) + generate_stack_json(base_dir) + file_list = [os.path.join(base_dir, 'dump.json'), os.path.join(base_dir, 'dump.json'), + os.path.join(base_dir, 'stack.json')] + + config_dict = { + 'stack_mode': True, + 'auto_analyze': True, + 'fuzzy_match': False, + 'dump_mode': Const.SUMMARY, + } + mode_config = ModeConfig(**config_dict) + mapping_config = MappingConfig() + + from msprobe.pytorch.compare.pt_compare import read_real_data + comparator = Comparator(read_real_data, mode_config, mapping_config) + parse_data = ParseData(mode_config, '') + npu_df, bench_df = parse_data.parse(file_list) + result = comparator.compare_statistics(npu_df, bench_df) + o_data = [ + ['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', + 'torch.float32', 'torch.float32', '[2, 2]', '[2, 2]', 'False', 'False', + 0, 0, 0, 0, '0.0%', 'N/A', '0.0%', '0.0%', 2, 0, 1, 1, 2, 0, 1, 1, + True, '', '', ['File'], 'input', 'Functional.linear.0.forward' + ] + ] + columns = CompareConst.SUMMARY_COMPARE_RESULT_HEADER + ['NPU_Stack_Info'] + ['state', 'api_origin_name'] + o_result = pd.DataFrame(o_data, columns=columns, dtype=object) + self.assertTrue(np.array_equal(result.to_numpy(), o_result.to_numpy())) + + +class TestParseData(unittest.TestCase): + + def setUp(self): + os.makedirs(base_dir, mode=0o750, exist_ok=True) + generate_dump_json(base_dir) + generate_dump_json_md5(base_dir) + generate_stack_json(base_dir) + + self.lock = threading.Lock() + + def tearDown(self): + if os.path.exists(base_dir): + shutil.rmtree(base_dir) + + def test_parse(self): + file_list = [os.path.join(base_dir, 'dump.json'), os.path.join(base_dir, 'dump.json'), + os.path.join(base_dir, 'stack.json')] stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) + mode_config = ModeConfig(stack_mode=stack_mode) + parse_data = ParseData(mode_config, 'rank0') + npu_df, bench_df = parse_data.parse(file_list) + + target_df = pd.DataFrame( + [['Functional.linear.0.forward.input.0', 'torch.float32', [2, 2], + [2, 0, 1, 1], ['File'], 'input', 'Functional.linear.0.forward', 'False']], + columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info', 'state', 'api_origin_name', 'requires_grad'] + ) + self.assertTrue(npu_df.equals(target_df)) + self.assertTrue(bench_df.equals(target_df)) + + def test_gen_data_df_summary(self): + npu_json_path = os.path.join(base_dir, 'dump.json') + stack_json_path = os.path.join(base_dir, 'stack.json') + npu_json_data = load_json(npu_json_path) + stack_json_data = load_json(stack_json_path) - result = Comparator(mode_config).gen_merge_list(json_data, op_name, stack_json_data) - self.assertEqual(result, merge_list) + stack_mode = True + mode_config = ModeConfig(stack_mode=stack_mode) + parse_data = ParseData(mode_config, 'rank0') + npu_df = parse_data.gen_data_df(npu_json_data, stack_json_data, 'NPU') + + target_df = pd.DataFrame( + [['Functional.linear.0.forward.input.0', 'torch.float32', + [2, 2], [2, 0, 1, 1], ['File'], 'input', 'Functional.linear.0.forward', 'False']], + columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info', 'state', 'api_origin_name', 'requires_grad'] + ) + self.assertTrue(npu_df.equals(target_df)) + + def test_gen_data_df_all(self): + npu_json_path = os.path.join(base_dir, 'dump.json') + stack_json_path = os.path.join(base_dir, 'stack.json') + npu_json_data = load_json(npu_json_path) + stack_json_data = load_json(stack_json_path) - def test_check_op_fuzzy_false(self): - stack_mode = False - auto_analyze = True - dump_mode = Const.SUMMARY + stack_mode = True + mode_config = ModeConfig(stack_mode=stack_mode, dump_mode=Const.ALL) + parse_data = ParseData(mode_config, 'rank0') + npu_df = parse_data.gen_data_df(npu_json_data, stack_json_data, 'NPU') + + target_df = pd.DataFrame( + [['Functional.linear.0.forward.input.0', 'torch.float32', [2, 2], + [2, 0, 1, 1], ['File'], 'input', 'Functional.linear.0.forward', 'False', 'Functional.linear.0.forward.input.0.pt']], + columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info', 'state', 'api_origin_name', 'requires_grad', 'data_name'] + ) + self.assertTrue(npu_df.equals(target_df)) + + def test_gen_data_df_md5(self): + npu_json_path = os.path.join(base_dir, 'dump_md5.json') + stack_json_path = os.path.join(base_dir, 'stack.json') + npu_json_data = load_json(npu_json_path) + stack_json_data = load_json(stack_json_path) - fuzzy_match = False - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) + stack_mode = True + mode_config = ModeConfig(stack_mode=stack_mode, dump_mode=Const.MD5) + parse_data = ParseData(mode_config, 'rank0') + npu_df = parse_data.gen_data_df(npu_json_data, stack_json_data, 'NPU') - pt_comparator = PTComparator(mode_config) - result = pt_comparator.check_op(npu_dict, bench_dict) - self.assertEqual(result, True) + target_df = pd.DataFrame( + [['Functional.linear.0.forward.input.0', 'torch.float32', [2, 2], + [2, 0, 1, 1], ['File'], 'input', 'Functional.linear.0.forward', 'False', 123456]], + columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info', 'state', 'api_origin_name', 'requires_grad', 'md5'] + ) + self.assertTrue(npu_df.equals(target_df)) - def test_check_op_fuzzy_true(self): - stack_mode = False - auto_analyze = True - dump_mode = Const.SUMMARY + def test_gen_merge_list(self): + npu_json_path = os.path.join(base_dir, 'dump.json') + stack_json_path = os.path.join(base_dir, 'stack.json') + npu_json_data = load_json(npu_json_path) + stack_json_data = load_json(stack_json_path) - fuzzy_match = True - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) + stack_mode = True + mode_config = ModeConfig(stack_mode=stack_mode) + parse_data = ParseData(mode_config, 'rank0') + merge_list = parse_data.gen_merge_list(npu_json_data, 'Functional.linear.0.forward', stack_json_data) + + target_merge_list = [ + { + 'full_op_name': 'Functional.linear.0.forward.input.0', + 'type': 'torch.Tensor', + 'dtype': 'torch.float32', + 'shape': [2, 2], + 'requires_grad': 'False', + 'Max': 2, + 'Min': 0, + 'Mean': 1, + 'Norm': 1, + 'md5': '00000000', + 'data_name': 'Functional.linear.0.forward.input.0.pt', + 'state': 'input' + }, + { + 'full_op_name': 'Functional.linear.0.forward', + 'full_info': ['File'] + } + ] + self.assertEqual(merge_list, target_merge_list) - pt_comparator = PTComparator(mode_config) - result = pt_comparator.check_op(npu_dict2, bench_dict) - self.assertEqual(result, True) + +class TestProcessDf(unittest.TestCase): + + def test_get_api_name_success(self): + api_list = ['Functional', 'linear', '0', 'forward', 'input', '0'] + + mode_config = ModeConfig() + mapping_config = MappingConfig() + mapping_dict = MappingDict(mapping_config) + process_df = ProcessDf(mode_config, mapping_config, mapping_dict) + api_name = process_df.get_api_name(api_list) + + target_api_name = 'Functional.linear' + self.assertEqual(api_name, target_api_name) + + @patch('msprobe.core.compare.acc_compare.logger') + def test_get_api_name_index_error(self, mock_logger): + api_list = ['Functional'] + with self.assertRaises(CompareException) as context: + mode_config = ModeConfig() + mapping_config = MappingConfig() + mapping_dict = MappingDict(mapping_config) + process_df = ProcessDf(mode_config, mapping_config, mapping_dict) + api_name = process_df.get_api_name(api_list) + self.assertEqual(context.exception.code, CompareException.INDEX_OUT_OF_BOUNDS_ERROR) + mock_logger.error.assert_called_once_with('Failed to retrieve API name, please check if the dump data is reasonable') + + def test_process_compare_key_and_shape(self): + npu_df_o = bench_df_o = pd.DataFrame( + [['Functional.linear.0.forward.input.0', 'torch.float32', [2, 2], [2, 0, 1, 1], ['File']]], + columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info'] + ) + + mode_config = ModeConfig() + mapping_config = MappingConfig() + mapping_dict = MappingDict(mapping_config) + process_df = ProcessDf(mode_config, mapping_config, mapping_dict) + npu_df, bench_df = process_df.process_compare_key_and_shape(npu_df_o, bench_df_o) + + target_df = pd.DataFrame( + [['Functional.linear.0.forward.input.0', 'torch.float32', [2, 2], [2, 0, 1, 1], ['File'], 'Functional.linear.0.forward.input.0', [2, 2]]], + columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info', 'compare_key', 'compare_shape'] + ) + self.assertTrue(npu_df.equals(target_df)) + self.assertTrue(bench_df.equals(target_df)) + + def test_process_internal_api_mapping(self): + mode_config = ModeConfig() + mapping_config = MappingConfig() + mapping_dict = MappingDict(mapping_config) + process_df = ProcessDf(mode_config, mapping_config, mapping_dict) + + # mint to torch + npu_op_name = 'Mint.mean.0.input.0' + target_name = 'Torch.mean.0.input.0' + name = process_df.process_internal_api_mapping(npu_op_name) + self.assertEqual(name, target_name) + + # mintfunctional to functional + npu_op_name = 'MintFunctional.mean.0.input.0' + target_name = 'Functional.mean.0.input.0' + name = process_df.process_internal_api_mapping(npu_op_name) + self.assertEqual(name, target_name) + + # inner mapping exists + npu_op_name = 'Functional.abs.0.input.0' + mapping_dict.ms_to_pt_mapping = {'Functional.abs': 'Torch.abs'} + target_name = 'Torch.abs.0.input.0' + name = process_df.process_internal_api_mapping(npu_op_name) + self.assertEqual(name, target_name) + + # inner mapping not found + npu_op_name = 'Functional.abs.0.input.0' + mapping_dict.ms_to_pt_mapping = {} + target_name = 'Functional.abs.0.input.0' + name = process_df.process_internal_api_mapping(npu_op_name) + self.assertEqual(name, target_name) + + def test_modify_compare_data_with_user_mapping(self): + mode_config = ModeConfig() + mapping_config = MappingConfig() + mapping_dict = MappingDict(mapping_config) + process_df = ProcessDf(mode_config, mapping_config, mapping_dict) + mapping_dict.api_mapping_dict = [{ + 'ms_api': 'Functional.conv2d', + 'pt_api': 'Torch.conv2d', + 'ms_args': [0], + 'pt_args': [0] + }] + + npu_df = pd.DataFrame([ + ['Functional.conv2d.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Functional.conv2d.0.forward.input.0', 'input', 'Functional.conv2d.0.forward'], + ['Functional.amax.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Functional.amax.0.forward.input.0', 'input', 'Functional.amax.0.forward'] + ], columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info', 'compare_key', 'state', 'api_origin_name']) + bench_df = pd.DataFrame([ + ['Torch.conv2d.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Torch.conv2d.0.forward.input.0', 'input', 'Functional.conv2d.0.forward'], + ['Torch.amax.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Torch.amax.0.forward.input.0', 'input', 'Functional.amax.0.forward'] + ], columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info', 'compare_key', 'state', 'api_origin_name']) + + process_df.modify_compare_data_with_user_mapping(npu_df, bench_df) + + def test_get_api_indices_dict(self): + mode_config = ModeConfig() + mapping_config = MappingConfig() + mapping_dict = MappingDict(mapping_config) + process_df = ProcessDf(mode_config, mapping_config, mapping_dict) + + op_name_df = pd.DataFrame([ + ['Functional.conv2d.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Functional.conv2d.0.forward.input.0'], + ['Functional.amax.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Functional.amax.0.forward.input.0'] + ], columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info', 'compare_key']) + + api_indices_dict = process_df.get_api_indices_dict(op_name_df) + expected = { + 'Functional.conv2d': [0], + 'Functional.amax': [1] + } + self.assertEqual(api_indices_dict, expected) + + def test_process_cell_mapping(self): + mode_config = ModeConfig() + mapping_config = MappingConfig() + mapping_dict = MappingDict(mapping_config) + process_df = ProcessDf(mode_config, mapping_config, mapping_dict) + + # not name + npu_op_name = None + name = process_df.process_cell_mapping(npu_op_name) + self.assertEqual(name, CompareConst.N_A) + + # not params_grad + npu_op_name = 'MintFunctional.embedding.0.input.0' + name = process_df.process_cell_mapping(npu_op_name) + self.assertEqual(name, CompareConst.N_A) + + # default replace + npu_op_name = 'Cell.network_with_loss.module.GPTModel.forward.1.input.0' + name = process_df.process_cell_mapping(npu_op_name) + self.assertEqual(name, 'Module.network_with_loss.module.GPTModel.forward.1.input.0') + + # mapping_dict + npu_op_name = 'Cell.fc1.Dense.forward.0.input.0' + mapping_dict.cell_mapping_dict = {'fc1.Dense': 'module.name'} + name = process_df.process_cell_mapping(npu_op_name) + self.assertEqual(name, 'Module.module.name.forward.0.input.0') + + def test_process_data_mapping(self): + mode_config = ModeConfig() + mapping_config = MappingConfig() + mapping_dict = MappingDict(mapping_config) + process_df = ProcessDf(mode_config, mapping_config, mapping_dict) + + npu_op_name = 'Functional.flash_attention_score.4.forward.input.0' + mapping_dict.data_mapping_dict = {'Functional.flash_attention_score.4.forward.input.0': 'NPU.npu_fusion_attention.4.forward.input.0'} + name = process_df.process_data_mapping(npu_op_name) + self.assertEqual(name, 'NPU.npu_fusion_attention.4.forward.input.0') + + +class TestMatch(unittest.TestCase): + + def test_put_unmatched_in_table(self): + mode_config = ModeConfig() + mapping_config = MappingConfig() + match = Match(mode_config, mapping_config, cross_frame=False) + + match_result = pd.DataFrame(columns=CompareConst.MATCH_RESULT_COLUMNS) + npu_op_item = pd.Series(['op', 'float32', [1, 2], 'summary', 'stack_info', + 'input', 'op_origin', 'False', 'data_name', 'op', [1, 2]], + index=['op_name_x', 'dtype_x', 'shape_x', 'summary_x', 'stack_info_x', + 'state_x', 'api_origin_name_x', 'data_name_x', 'requires_grad_x', + 'compare_key', 'compare_shape'] + ) + match_result = match.put_unmatched_in_table(match_result, npu_op_item) + target_match_result = pd.DataFrame([['op', 'float32', [1, 2], 'summary', 'stack_info', + 'input', 'op_origin', 'False', 'data_name', 'op', [1, 2], + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A']], + columns=CompareConst.MATCH_RESULT_COLUMNS) + self.assertTrue(match_result.equals(target_match_result)) + + def test_put_matched_in_table(self): + mode_config = ModeConfig() + mapping_config = MappingConfig() + match = Match(mode_config, mapping_config, cross_frame=False) + + match_result = pd.DataFrame(columns=CompareConst.MATCH_RESULT_COLUMNS) + npu_op_item = pd.Series(['op', 'float32', [1, 2], 'summary', 'stack_info', + 'input', 'op_origin', 'False', 'data_name', 'op', [1, 2]], + index=['op_name_x', 'dtype_x', 'shape_x', 'summary_x', 'stack_info_x', + 'state_x', 'api_origin_name_x', 'requires_grad_x', 'data_name_x', + 'compare_key', 'compare_shape'] + ) + bench_op_item = pd.Series(['op', 'float32', [1, 2], 'summary', 'stack_info', + 'input', 'op_origin', 'False', 'data_name', 'op', [1, 2]], + index=['op_name_y', 'dtype_y', 'shape_y', 'summary_y', 'stack_info_y', + 'state_y', 'api_origin_name_y', 'requires_grad_y', 'data_name_y', + 'compare_key', 'compare_shape'] + ) + match_result = match.put_matched_in_table(match_result, npu_op_item, bench_op_item) + target_match_result = pd.DataFrame([['op', 'float32', [1, 2], 'summary', 'stack_info', + 'input', 'op_origin', 'False', 'data_name', 'op', [1, 2], + 'op', 'float32', [1, 2], 'summary', 'stack_info', + 'input', 'op_origin', 'False', 'data_name']], + columns=CompareConst.MATCH_RESULT_COLUMNS) + self.assertTrue(match_result.equals(target_match_result)) + + def test_rename_api(self): + mode_config = ModeConfig() + mapping_config = MappingConfig() + match = Match(mode_config, mapping_config, cross_frame=False) + + op_name_1 = 'Functional.linear.0.forward.input.0' + result_1 = match.rename_api(op_name_1) + self.assertTrue(result_1, 'Functional.linear.input.0') + + op_name_2 = 'Functional.linear.0.backward.input.0' + result_2 = match.rename_api(op_name_2) + self.assertTrue(result_2, 'Functional.linear.input.0') + + op_name_3 = 'Functional.linear.0.x.input.0' + result_3 = match.rename_api(op_name_3) + self.assertTrue(result_3, 'Functional.linear.0.x.input.0') + + def test_check_op_item(self): + mode_config = ModeConfig() + mapping_config = MappingConfig() + match = Match(mode_config, mapping_config, cross_frame=False) + + npu_op_item = pd.Series(['op', 'float32', [1, 2], 'summary', 'stack_info', 'data_name', 'Functional.linear.0.forward.input.0', [1, 2]], + index=['op_name_x', 'dtype_x', 'shape_x', 'summary_x', 'stack_info_x', 'data_name_x', + 'compare_key', 'compare_shape'] + ) + bench_op_item = pd.Series(['op', 'float32', [1, 2], 'summary', 'stack_info', 'data_name', 'Functional.linear.1.forward.input.0', [1, 2]], + index=['op_name_y', 'dtype_y', 'shape_y', 'summary_y', 'stack_info_y', 'data_name_y', + 'compare_key', 'compare_shape'] + ) + result = match.check_op_item(npu_op_item, bench_op_item) + self.assertTrue(result) + + def test_process_fuzzy_match(self): + mode_config = ModeConfig() + mapping_config = MappingConfig() + match = Match(mode_config, mapping_config, cross_frame=False) + + npu_df = pd.DataFrame([ + ['Functional.conv2d.3.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', + 'input', 'Functional.conv2d.3.forward', 'True', 'Functional.conv2d.3.forward.input.0.pt', + 'Functional.conv2d.3.forward.input.0', [1, 2]], + ['Functional.amax.1.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', + 'input', 'Functional.amax.1.forward', 'True', 'Functional.amax.0.forward.input.0.pt', + 'Functional.amax.1.forward.input.0', [1, 2]] + ], columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info', + 'state', 'api_origin_name', 'requires_grad', 'data_name', + 'compare_key', 'compare_shape']) + bench_df = pd.DataFrame([ + ['Functional.conv2d.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', + 'input', 'Functional.conv2d.0.forward', 'True', 'Functional.conv2d.0.forward.input.0.pt', + 'Functional.conv2d.0.forward.input.0', [1, 2]], + ['Functional.amax.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', + 'input', 'Functional.amax.0.forward', 'True', 'Functional.amax.0.forward.input.0.pt', + 'Functional.amax.0.forward.input.0', [1, 2]] + ], columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info', + 'state', 'api_origin_name', 'requires_grad', 'data_name', 'compare_key', 'compare_shape']) + + match_result = match.process_fuzzy_match(npu_df, bench_df) + expected = pd.DataFrame( + [ + ['Functional.conv2d.3.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', + 'input', 'Functional.conv2d.3.forward', 'True', 'Functional.conv2d.3.forward.input.0.pt', + 'Functional.conv2d.3.forward.input.0', [1, 2], + 'Functional.conv2d.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', + 'input', 'Functional.conv2d.0.forward', 'True', 'Functional.conv2d.0.forward.input.0.pt'], + ['Functional.amax.1.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', + 'input', 'Functional.amax.1.forward', 'True', 'Functional.amax.0.forward.input.0.pt', + 'Functional.amax.1.forward.input.0', [1, 2], + 'Functional.amax.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', + 'input', 'Functional.amax.0.forward', 'True', 'Functional.amax.0.forward.input.0.pt'] + ] + , columns=CompareConst.MATCH_RESULT_COLUMNS) + + self.assertTrue(match_result.equals(expected)) def test_match_op_both_last_element(self): - stack_mode = False - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - - pt_comparator = PTComparator(mode_config) - a, b = pt_comparator.match_op([npu_dict], [bench_dict]) + config_dict = { + 'stack_mode': False, + 'auto_analyze': True, + 'fuzzy_match': False, + 'dump_mode': Const.SUMMARY, + } + mode_config = ModeConfig(**config_dict) + mapping_config = MappingConfig() + + match = Match(mode_config, mapping_config, cross_frame=False) + a, b = match.match_op([npu_op_item_fuzzy], [bench_op_item_fuzzy]) self.assertEqual(a, 0) self.assertEqual(b, 0) def test_match_op_only_npu_last_element(self): - stack_mode = False - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - - pt_comparator = PTComparator(mode_config) - a, b = pt_comparator.match_op([npu_dict], [bench_dict, 1]) + config_dict = { + 'stack_mode': False, + 'auto_analyze': True, + 'fuzzy_match': False, + 'dump_mode': Const.SUMMARY, + } + mode_config = ModeConfig(**config_dict) + mapping_config = MappingConfig() + + match = Match(mode_config, mapping_config, cross_frame=False) + a, b = match.match_op([npu_op_item_fuzzy], [bench_op_item_fuzzy, 1]) self.assertEqual(a, 0) self.assertEqual(b, 0) def test_match_op_only_bench_last_element(self): - stack_mode = False - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - - pt_comparator = PTComparator(mode_config) - a, b = pt_comparator.match_op([npu_dict, npu_dict2], [bench_dict]) + config_dict = { + 'stack_mode': False, + 'auto_analyze': True, + 'fuzzy_match': False, + 'dump_mode': Const.SUMMARY, + } + mode_config = ModeConfig(**config_dict) + mapping_config = MappingConfig() + + match = Match(mode_config, mapping_config, cross_frame=False) + a, b = match.match_op([npu_op_item_fuzzy, npu_op_item_data_fuzzy_2], [bench_op_item_fuzzy]) self.assertEqual(a, 0) self.assertEqual(b, 0) - def test_compare_process(self): - generate_dump_json(base_dir) - generate_stack_json(base_dir) - file_lists = [os.path.join(base_dir, 'dump.json'), os.path.join(base_dir, 'dump.json'), - os.path.join(base_dir, 'stack.json')] + def test_gen_dtype_condition(self): + mode_config = ModeConfig() + mapping_config = MappingConfig() + match = Match(mode_config, mapping_config, cross_frame=True) - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) + # data mapping + mapping_config.data_mapping = True + match_result = pd.DataFrame([1, 2, 3]) + result = match.gen_dtype_condition(match_result) + expected = pd.Series([True, True, True]) + self.assertTrue(result.equals(expected)) - result = PTComparator(mode_config).compare_process(file_lists) - o_data = [ - ['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], 0, 0, 0, 0, '0.0%', 'N/A', '0.0%', '0.0%', - 2, 0, 1, 1, 2, 0, 1, 1, '', '', ['File'] - ] - ] - columns = CompareConst.SUMMARY_COMPARE_RESULT_HEADER + ['NPU_Stack_Info'] - o_result = pd.DataFrame(o_data, columns=columns, dtype=object) - self.assertTrue(result.equals(o_result)) + # normal + mapping_config.data_mapping = None + match_result = pd.DataFrame([['Float16', 'Float32'], ['torch.float32', 'torch.bfloat16']], columns=['dtype_x', 'dtype_y']) + result = match.gen_dtype_condition(match_result) + expected = pd.Series([True, True]) + self.assertTrue(result.equals(expected)) - def test_merge_data(self): - op_data = { - 'input_args': [ - { - 'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [2, 2], - 'Max': 1, 'Min': 1, 'Mean': 1, 'Norm': 1, 'requires_grad': False, - 'data_name': 'Functional.linear.0.forward.input.0.pt', - 'full_op_name': 'Functional.linear.0.forward.input.0' - } - ] - } - json_data = {'data': {'Functional.linear.0.forward': op_data}} - stack_json_data = {'Functional.linear.0.forward': ['File']} + def test_process_cross_frame_dtype(self): + mode_config = ModeConfig() + mapping_config = MappingConfig() + match = Match(mode_config, mapping_config, cross_frame=True) - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - - result = Comparator(mode_config).merge_data(json_data, stack_json_data) - ops_all = { - 'Functional.linear.0.forward.input.0': { - 'data_name': None, 'stack_info': [['File']], - 'struct': ('torch.float32', [2, 2]), 'summary': [1, 1, 1, 1] - } + dtype_o = pd.Series(['Int8', 'Float16', 'torch.bool', 'Complex64', 'unknown']) + dtype = match.process_cross_frame_dtype(dtype_o) + self.assertTrue(dtype.equals(pd.Series(['int', 'float', 'bool', 'complex', 'unknown']))) + + +class TestCreateTable(unittest.TestCase): + + def test_process_data_name(self): + mode_config = ModeConfig() + create_table = CreateTable(mode_config) + + data = { + 'data_name_x': ['A', 'B', 'C'], + 'data_name_y': ['X', 'Y', 'Z'] } - self.assertEqual(result, ops_all) - - def test_compare_core_basic(self): - generate_dump_json(base_dir2) - generate_stack_json(base_dir2) - input_params = { - "npu_json_path": os.path.join(base_dir2, "dump.json"), - "bench_json_path": os.path.join(base_dir2, "dump.json"), - "stack_json_path": os.path.join(base_dir2, "stack.json"), + result_o = pd.DataFrame(data) + result = create_table.process_data_name(result_o) + target_data = { + 'data_name_x': [['A', 'X'], ['B', 'Y'], ['C', 'Z']], + 'data_name_y': ['X', 'Y', 'Z'] } - output_path = base_dir2 - - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - - PTComparator(mode_config).compare_core(input_params, output_path) - - output_files = os.listdir(output_path) - self.assertTrue(any(f.endswith(".xlsx") for f in output_files)) - - def test_compare_ops(self): - generate_dump_json(base_dir3) - generate_stack_json(base_dir3) - generate_pt(pt_dir) - dump_path = os.path.join(base_dir3, 'dump.json') - stack_path = os.path.join(base_dir3, 'stack.json') - input_param = {'npu_json_path': dump_path, 'bench_json_path': dump_path, 'stack_json_path': stack_path, - 'is_print_compare_log': True, 'npu_dump_data_dir': pt_dir, 'bench_dump_data_dir': pt_dir} - dump_path_dict = {'Functional.linear.0.forward.input.0': ['Functional.linear.0.forward.input.0.pt', - 'Functional.linear.0.forward.input.0.pt']} - result_df = pd.DataFrame({ - 'NPU Name': ['Functional.linear.0.forward.input.0'], - 'Bench Name': ['Functional.linear.0.forward.input.0'] - }) + target_result = pd.DataFrame(target_data) + self.assertTrue(result.equals(target_result)) - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.ALL - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - - pt_comparator = PTComparator(mode_config) - updated_df = pt_comparator.compare_ops(idx=0, dump_path_dict=dump_path_dict, result_df=result_df, - lock=self.lock, input_param=input_param) - - self.assertEqual(updated_df.loc[0, CompareConst.COSINE], 1.0) - self.assertEqual(updated_df.loc[0, CompareConst.MAX_ABS_ERR], 0) - - def test_do_multi_process(self): - data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1']] - o_data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], 'unsupported', 'unsupported', 'unsupported', - 'unsupported', 'unsupported', - 1, 1, 1, 1, 1, 1, 1, 1, 'None', 'No bench data matched.', '-1']] - columns = CompareConst.COMPARE_RESULT_HEADER + ['Data_name'] - result_df = pd.DataFrame(data, columns=columns) - o_result = pd.DataFrame(o_data, columns=columns) - generate_dump_json(base_dir) - input_param = {'bench_json_path': os.path.join(base_dir, 'dump.json')} + def test_set_summary(self): + mode_config = ModeConfig() + create_table = CreateTable(mode_config) - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.ALL - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) + # all nan + result = create_table.set_summary(['nan', 'NaN', 'nAn']) + expected = [CompareConst.NAN, CompareConst.NAN, CompareConst.NAN] + self.assertEqual(result, expected) - comparator = Comparator(mode_config) - result = comparator.do_multi_process(input_param, result_df) - self.assertTrue(result.equals(o_result)) + # mixed values + result = create_table.set_summary([1, 'nan', 2.0, 'NaN']) + expected = [1, CompareConst.NAN, 2.0, CompareConst.NAN] + self.assertEqual(result, expected) - def test_compare_by_op_1(self): - npu_op_name = 'Functional.linear.0.forward.input.0' - bench_op_name = 'N/A' - op_name_mapping_dict = {'Functional.linear.0.forward.input.0': [-1, -1]} - input_param = {} + # NA case + result = create_table.set_summary(CompareConst.N_A) + expected = [CompareConst.N_A, CompareConst.N_A, CompareConst.N_A, CompareConst.N_A] + self.assertEqual(result, expected) - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.ALL - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) + # empty input + result = create_table.set_summary([]) + expected = [] + self.assertEqual(result, expected) - pt_comparator = PTComparator(mode_config) - result = pt_comparator.compare_by_op(npu_op_name, bench_op_name, op_name_mapping_dict, input_param, {}) - self.assertEqual(result, ['unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported', - 'No bench data matched.']) +class TestCalcStatsDiff(unittest.TestCase): - def test_compare_by_op_2(self): - npu_op_name = 'Functional.linear.0.forward.input.0' - bench_op_name = 'Functional.linear.0.forward.input.0' + def test_type_check(self): + mode_config = ModeConfig() + calc_stats_diff = CalcStatsDiff(mode_config) - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.ALL - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - - pt_comparator = PTComparator(mode_config) - - pt_name = '-1' - pt_path = os.path.join(base_dir, pt_name) - op_name_mapping_dict = {'Functional.linear.0.forward.input.0': [pt_path, pt_path]} - input_param = {'npu_dump_data_dir': base_dir, 'bench_dump_data_dir': base_dir} - result = pt_comparator.compare_by_op(npu_op_name, bench_op_name, op_name_mapping_dict, input_param, - {'Functional.linear.0.forward': {'input_args': [ - {'data_name': 'Functional.linear.0.forward.input.0.pt'}]}}) - self.assertEqual(result, ['unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported', - f'Dump file: {pt_path} not found.']) - - pt_name = 'Functional.linear.0.forward.input.0.pt' - pt_path = os.path.join(base_dir, pt_name) - op_name_mapping_dict = {'Functional.linear.0.forward.input.0': [pt_path, pt_path]} - input_param = {'npu_dump_data_dir': base_dir, 'bench_dump_data_dir': base_dir} - result = pt_comparator.compare_by_op(npu_op_name, bench_op_name, op_name_mapping_dict, input_param, {}) - self.assertEqual(result, ['unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported', - 'Bench does not have data file.']) - - generate_pt(base_dir) - result = pt_comparator.compare_by_op(npu_op_name, bench_op_name, op_name_mapping_dict, input_param, - {'Functional.linear.0.forward': {'input_args': [ - {'data_name': 'Functional.linear.0.forward.input.0.pt'}]}}) - self.assertEqual(result, [1.0, 0.0, 0.0, 1.0, 1.0, '']) - - def test_get_bench_data_name_input(self): - bench_op_name = "Functional.linear.0.forward.input.0" - bench_data = {"Functional.linear.0.forward": {"input_args": [{"data_name": "Functional.linear.0.forward.input.0.pt"}], "input_kwargs": {}, "output": []}} - result = get_bench_data_name(bench_op_name, bench_data) - - self.assertEqual(result, "Functional.linear.0.forward.input.0.pt") - - def test_get_bench_data_name_output(self): - bench_op_name = "Functional.linear.0.forward.output.0" - bench_data = {"Functional.linear.0.forward": {"input_args": [], "input_kwargs": {}, "output": [{"data_name": "Functional.linear.0.forward.output.0.pt"}]}} - result = get_bench_data_name(bench_op_name, bench_data) - - self.assertEqual(result, "Functional.linear.0.forward.output.0.pt") - - -class TestComparator(unittest.TestCase): - def setUp(self): - mode_config = ModeConfig(dump_mode=Const.MD5) - self.comparator = Comparator(mode_config=mode_config) - self.npu_ops_all = { - 'op1': {'struct': ['float32', [1, 96, 2], '83dcefb7']}, - } - self.bench_ops_all = { - 'op1': {'struct': ['float32', [1, 96, 2], '83dcefb7']}, - } + series = pd.Series([float('nan'), 5, 'nan', 10, 'abc', None]) + result = calc_stats_diff.type_check(series) + expected = pd.Series([True, True, True, True, False, False]) + self.assertTrue(result.equals(expected)) - def test_normal(self): - expected_result = ['op1', 'op1', 'float32', 'float32', [1, 96, 2], [1, 96, 2], '83dcefb7', '83dcefb7', - CompareConst.PASS, CompareConst.NONE] - result = self.comparator.get_result_md5_compare('op1', 'op1', - self.npu_ops_all, self.bench_ops_all) - self.assertEqual(result, expected_result) + def test_get_number(self): + mode_config = ModeConfig() + calc_stats_diff = CalcStatsDiff(mode_config) - @patch('msprobe.core.compare.acc_compare.logger') - def test_length_exception(self, mock_logger): - self.npu_ops_all['op1']['struct'] = ['npu_val1', 'npu_val2'] - with self.assertRaises(CompareException) as context: - self.comparator.get_result_md5_compare('op1', 'op1', - self.npu_ops_all, self.bench_ops_all) - self.assertEqual(context.exception.code, CompareException.INDEX_OUT_OF_BOUNDS_ERROR) - mock_logger.error.assert_called_once_with("The length of npu_struct and bench_struct must be >= 3, " - "but got npu_struct=2 and bench_struct=3. Please check!") - - def test_with_extra_args(self): - expected_result = ['op1', 'op1', 'float32', 'float32', [1, 96, 2], [1, 96, 2], '83dcefb7', '83dcefb7', - CompareConst.PASS, 'extra_data'] - result = self.comparator.get_result_md5_compare('op1', 'op1', - self.npu_ops_all, self.bench_ops_all, True, ['extra_data']) - self.assertEqual(result, expected_result) + series = pd.Series([1, '2', 3.5, 'text', None]) + result = calc_stats_diff.get_number(series) + expected = pd.Series([1, 2, 3.5, float('nan'), float('nan')]) + self.assertTrue(result.equals(expected)) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_check.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_check.py index a1e5f8eee1bce9b170e6f4f7fdfeda65d47252c9..212cb49eaa40d3a6110e844e43406a31b307ca98 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_check.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_check.py @@ -1,8 +1,12 @@ # coding=utf-8 import unittest -from msprobe.core.compare.check import check_struct_match, check_type_shape_match, check_graph_mode, fuzzy_check_op, \ - fuzzy_check_name, check_dump_json_str, check_json_key_value, valid_key_value, check_stack_json_str +from unittest.mock import patch + +from msprobe.core.compare.check import check_dump_json_str, check_json_key_value, valid_key_value, \ + check_stack_json_str, check_configuration_param from msprobe.core.common.utils import CompareException +from msprobe.core.common.log import logger +from msprobe.core.compare.acc_compare import ComparisonConfig # test_check_struct_match @@ -65,87 +69,6 @@ op_name = 'Functional.conv2d.0.backward.input.0' class TestUtilsMethods(unittest.TestCase): - - def test_check_struct_match_success(self): - result = check_struct_match(npu_dict, bench_dict) - self.assertTrue(result) - - def test_check_struct_match_fail(self): - npu_dict2 = {'input_struct': [('torch.float32', [1, 1, 28, 28]), ('torch.float32', [16, 1, 5, 5]), - ('torch.float32', [16])], - 'output_struct': [('torch.float32', [1, 16, 28, 28])] - } - - bench_dict2 = {'input_struct': [('torch.float32', [2, 1, 28, 28]), ('torch.float32', [16, 1, 5, 5]), - ('torch.float32', [16])], - 'output_struct': [('torch.float32', [1, 16, 28, 28])] - } - result = check_struct_match(npu_dict2, bench_dict2) - self.assertFalse(result) - - def test_check_struct_index_error(self): - npu_dict3 = {'input_struct': [('a'), ('torch.float32'), - ('torch.float32')], - 'output_struct': [('torch.float32')] - } - - bench_dict3 = {'input_struct': [('torch.float32'), ('torch.float32'), - ('torch.float32')], - 'output_struct': [('torch.float32')] - } - with self.assertRaises(CompareException) as context: - result = check_struct_match(npu_dict3, bench_dict3) - self.assertEqual(context.exception.code, CompareException.INDEX_OUT_OF_BOUNDS_ERROR) - - def test_check_type_shape_match_success(self): - result = check_type_shape_match(npu_struct, bench_struct) - self.assertTrue(result) - - def test_check_type_shape_match_index_error(self): - npu_struct2 = [('a'), ('torch.float32'), ('torch.float32')] - bench_struct2 = [('torch.float32'), ('torch.float32'), ('torch.float32')] - with self.assertRaises(CompareException) as context: - result = check_type_shape_match(npu_struct2, bench_struct2) - self.assertEqual(context.exception.code, CompareException.INDEX_OUT_OF_BOUNDS_ERROR) - - def test_check_graph_mode(self): - op1 = "Aten" - op2 = "torch" - self.assertTrue(check_graph_mode(op1, op2)) - self.assertTrue(check_graph_mode(op2, op1)) - self.assertFalse(check_graph_mode(op1, op1)) - self.assertFalse(check_graph_mode(op2, op2)) - - def test_fuzzy_check_op_1(self): - npu_name_list = [] - bench_name_list = [] - result = fuzzy_check_op(npu_name_list, bench_name_list) - self.assertFalse(result) - - def test_fuzzy_check_op_2(self): - npu_name_list = [] - bench_name_list = ['Functional.conv2d.0.forward.input.0'] - result = fuzzy_check_op(npu_name_list, bench_name_list) - self.assertFalse(result) - - def test_fuzzy_check_op_3(self): - npu_name_list = ['Functional.conv2d.0.forward.input.0'] - bench_name_list = ['Functional.conv2d.1.forward.input.0'] - result = fuzzy_check_op(npu_name_list, bench_name_list) - self.assertTrue(result) - - def test_fuzzy_check_name_1(self): - npu_name = 'Functional.conv2d.0.backward.input.0' - bench_name = 'Functional.conv2d.1.backward.input.0' - result = fuzzy_check_name(npu_name, bench_name) - self.assertTrue(result) - - def test_fuzzy_check_name_2(self): - npu_name = 'Functional.conv2d.0.backward.input.0' - bench_name = 'Functional.conv2d.1.backward.input.1' - result = fuzzy_check_name(npu_name, bench_name) - self.assertFalse(result) - def test_check_dump_json_str(self): with self.assertRaises(CompareException) as context: check_dump_json_str(op_data, op_name) @@ -157,7 +80,7 @@ class TestUtilsMethods(unittest.TestCase): self.assertEqual(context.exception.code, CompareException.INVALID_CHAR_ERROR) def test_check_json_key_value_max_depth(self): - result = check_json_key_value(input_output, op_name, depth=11) + result = check_json_key_value(input_output, op_name, depth=401) self.assertEqual(result, None) def test_valid_key_value_type_shape(self): @@ -203,3 +126,25 @@ class TestUtilsMethods(unittest.TestCase): with self.assertRaises(CompareException) as context: check_stack_json_str(stack_info, op_name) self.assertEqual(context.exception.code, CompareException.INVALID_CHAR_ERROR) + + @patch.object(logger, "error") + def test_check_configuration_param(self, mock_error): + config = ComparisonConfig( + dump_mode='', + stack_mode='False', + auto_analyze=True, + fuzzy_match=False, + highlight=False, + data_mapping={}, + suffix='', + cell_mapping={}, + api_mapping={}, + layer_mapping={}, + first_diff_analyze=False, + compared_file_type='', + is_print_compare_log=True + ) + with self.assertRaises(CompareException) as context: + check_configuration_param(config) + self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) + mock_error.assert_called_with("Invalid input parameter, stack_mode which should be only bool type.") \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_npy_compare.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_npy_compare.py index aec6cdc51173ae817f32dd76455bec645659b45c..417ea2b9f043e67f5478af4fede9b5dbc19ba80b 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_npy_compare.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_npy_compare.py @@ -20,7 +20,7 @@ from unittest.mock import patch from msprobe.core.common.const import CompareConst from msprobe.core.compare.npy_compare import handle_inf_nan, reshape_value, get_error_flag_and_msg, \ npy_data_check, statistics_data_check, get_relative_err, GetCosineSimilarity, GetMaxAbsErr, GetMaxRelativeErr, \ - GetErrRatio, error_value_process, compare_ops_apply + GetErrRatio, error_value_process, compare_ops_apply, GetEuclideanDistance op_name = 'Functional.conv2d.0.backward.input.0' @@ -85,13 +85,14 @@ class TestUtilsMethods(unittest.TestCase): n_value = np.array([1, 2, np.inf, 4]) b_value = np.array([1, 2, 3, 4]) error_flag = True + error_file = 'fake file' - n_value, b_value, error_flag, err_msg = get_error_flag_and_msg(n_value, b_value, error_flag=error_flag) + n_value, b_value, error_flag, err_msg = get_error_flag_and_msg(n_value, b_value, error_flag=error_flag, error_file=error_file) self.assertEqual(n_value, CompareConst.READ_NONE) self.assertEqual(b_value, CompareConst.READ_NONE) self.assertTrue(error_flag) - self.assertEqual(err_msg, CompareConst.NO_BENCH) + self.assertEqual(err_msg, "Dump file: fake file not found or read failed.") def test_get_error_flag_and_msg_none(self): n_value = np.array([]) @@ -113,7 +114,7 @@ class TestUtilsMethods(unittest.TestCase): n_value, b_value, error_flag, err_msg = get_error_flag_and_msg(n_value, b_value, error_flag=error_flag) self.assertFalse(error_flag) - self.assertEqual(err_msg, "This is type of 0-d tensor, can not calculate 'Cosine', " + self.assertEqual(err_msg, "This is type of 0-d tensor, can not calculate 'Cosine', 'EucDist', " "'One Thousandth Err Ratio' and 'Five Thousandths Err Ratio'. ") def test_get_error_flag_and_msg_shape_unmatch(self): @@ -239,15 +240,17 @@ class TestUtilsMethods(unittest.TestCase): b_value_1 = np.array(1) relative_err = get_relative_err(n_value_1, b_value_1) n_value_1, b_value_1 = reshape_value(n_value_1, b_value_1) - result, err_msg = op.apply(n_value_1, b_value_1, relative_err) + err_msg = "This is type of 0-d tensor, can not calculate 'Cosine', 'EucDist', 'One Thousandth Err Ratio' and 'Five Thousandths Err Ratio'. " + result, err_msg = op.apply(n_value_1, b_value_1, relative_err, err_msg) self.assertEqual(result, CompareConst.UNSUPPORTED) - self.assertEqual(err_msg, "") + self.assertEqual(err_msg, "This is type of 0-d tensor, can not calculate 'Cosine', 'EucDist', 'One Thousandth Err Ratio' and 'Five Thousandths Err Ratio'. ") n_value_2 = np.array([1, 2]) b_value_2 = np.array([1, 2]) relative_err = get_relative_err(n_value_2, b_value_2) n_value_2, b_value_2 = reshape_value(n_value_2, b_value_2) - result, err_msg = op.apply(n_value_2, b_value_2, relative_err) + err_msg = "" + result, err_msg = op.apply(n_value_2, b_value_2, relative_err, err_msg) self.assertEqual(result, 1.0) self.assertEqual(err_msg, "") @@ -255,7 +258,8 @@ class TestUtilsMethods(unittest.TestCase): b_value_3 = np.array([0, 0]) relative_err = get_relative_err(n_value_3, b_value_3) n_value_3, b_value_3 = reshape_value(n_value_3, b_value_3) - result, err_msg = op.apply(n_value_3, b_value_3, relative_err) + err_msg = "" + result, err_msg = op.apply(n_value_3, b_value_3, relative_err, err_msg) self.assertEqual(result, 1.0) self.assertEqual(err_msg, "") @@ -263,7 +267,8 @@ class TestUtilsMethods(unittest.TestCase): b_value_4 = np.array([1, 2]) relative_err = get_relative_err(n_value_4, b_value_4) n_value_4, b_value_4 = reshape_value(n_value_4, b_value_4) - result, err_msg = op.apply(n_value_4, b_value_4, relative_err) + err_msg = "" + result, err_msg = op.apply(n_value_4, b_value_4, relative_err, err_msg) self.assertEqual(result, CompareConst.NAN) self.assertEqual(err_msg, 'Cannot compare by Cosine Similarity, All the data is Zero in npu dump data.') @@ -271,7 +276,8 @@ class TestUtilsMethods(unittest.TestCase): b_value_5 = np.array([0, 0]) relative_err = get_relative_err(n_value_5, b_value_5) n_value_5, b_value_5 = reshape_value(n_value_5, b_value_5) - result, err_msg = op.apply(n_value_5, b_value_5, relative_err) + err_msg = "" + result, err_msg = op.apply(n_value_5, b_value_5, relative_err, err_msg) self.assertEqual(result, CompareConst.NAN) self.assertEqual(err_msg, 'Cannot compare by Cosine Similarity, All the data is Zero in Bench dump data.') @@ -282,7 +288,9 @@ class TestUtilsMethods(unittest.TestCase): b_value_1 = np.array([1]) relative_err = get_relative_err(n_value_1, b_value_1) n_value_1, b_value_1 = reshape_value(n_value_1, b_value_1) - result, err_msg = op.apply(n_value_1, b_value_1, relative_err) + err_msg = "" + + result, err_msg = op.apply(n_value_1, b_value_1, relative_err, err_msg) self.assertEqual(result, CompareConst.UNSUPPORTED) self.assertEqual(err_msg, "This is a 1-d tensor of length 1.") @@ -294,8 +302,9 @@ class TestUtilsMethods(unittest.TestCase): b_value = np.array([1, 1]) relative_err = get_relative_err(n_value, b_value) n_value, b_value = reshape_value(n_value, b_value) + err_msg = "" - result, err_msg = op.apply(n_value, b_value, relative_err) + result, err_msg = op.apply(n_value, b_value, relative_err, err_msg) self.assertEqual(result, CompareConst.NAN) self.assertEqual(err_msg, "Cannot compare by Cosine Similarity, the dump data has NaN.") @@ -319,8 +328,9 @@ class TestUtilsMethods(unittest.TestCase): b_value = np.array([0, 0]) relative_err = get_relative_err(n_value, b_value) n_value, b_value = reshape_value(n_value, b_value) + err_msg = "" - result, err_msg = op.apply(n_value, b_value, relative_err) + result, err_msg = op.apply(n_value, b_value, relative_err, err_msg) self.assertEqual(result, 2.0) self.assertEqual(err_msg, "") @@ -333,8 +343,9 @@ class TestUtilsMethods(unittest.TestCase): b_value = np.array([1, 1]) relative_err = get_relative_err(n_value, b_value) n_value, b_value = reshape_value(n_value, b_value) + err_msg = "" - result, err_msg = op.apply(n_value, b_value, relative_err) + result, err_msg = op.apply(n_value, b_value, relative_err, err_msg) self.assertEqual(result, CompareConst.NAN) self.assertEqual(err_msg, "Cannot compare by MaxAbsError, the data contains nan/inf/-inf in dump data.") @@ -347,8 +358,9 @@ class TestUtilsMethods(unittest.TestCase): b_value = np.array([1, 1]) relative_err = get_relative_err(n_value, b_value) n_value, b_value = reshape_value(n_value, b_value) + err_msg = "" - result, err_msg = op.apply(n_value, b_value, relative_err) + result, err_msg = op.apply(n_value, b_value, relative_err, err_msg) self.assertEqual(result, 1.0) self.assertEqual(err_msg, "") @@ -361,8 +373,9 @@ class TestUtilsMethods(unittest.TestCase): b_value = np.array([1, 1]) relative_err = get_relative_err(n_value, b_value) n_value, b_value = reshape_value(n_value, b_value) + err_msg = "" - result, err_msg = op.apply(n_value, b_value, relative_err) + result, err_msg = op.apply(n_value, b_value, relative_err, err_msg) self.assertEqual(result, CompareConst.NAN) self.assertEqual(err_msg, "Cannot compare by MaxRelativeError, the data contains nan/inf/-inf in dump data.") @@ -375,8 +388,9 @@ class TestUtilsMethods(unittest.TestCase): b_value = np.array([1, 1]) relative_err = get_relative_err(n_value, b_value) n_value, b_value = reshape_value(n_value, b_value) + err_msg = "" - result, err_msg = op.apply(n_value, b_value, relative_err) + result, err_msg = op.apply(n_value, b_value, relative_err, err_msg) self.assertEqual(result, 0.5) self.assertEqual(err_msg, "") @@ -387,11 +401,12 @@ class TestUtilsMethods(unittest.TestCase): n_value = np.array(1) # 标量 b_value = np.array(1) relative_err = np.array(0) + err_msg = "This is type of 0-d tensor, can not calculate 'Cosine', 'EucDist', 'One Thousandth Err Ratio' and 'Five Thousandths Err Ratio'. " - result, err_msg = op.apply(n_value, b_value, relative_err) + result, err_msg = op.apply(n_value, b_value, relative_err, err_msg) self.assertEqual(result, CompareConst.UNSUPPORTED) - self.assertEqual(err_msg, "") + self.assertEqual(err_msg, "This is type of 0-d tensor, can not calculate 'Cosine', 'EucDist', 'One Thousandth Err Ratio' and 'Five Thousandths Err Ratio'. ") def test_GetThousandErrRatio_not_size(self): op = GetErrRatio(CompareConst.THOUSAND_RATIO_THRESHOLD) @@ -399,8 +414,9 @@ class TestUtilsMethods(unittest.TestCase): n_value = np.array([1, 2]) b_value = np.array([1, 2]) relative_err = np.array([]) # 空数组 + err_msg = "" - result, err_msg = op.apply(n_value, b_value, relative_err) + result, err_msg = op.apply(n_value, b_value, relative_err, err_msg) self.assertEqual(result, CompareConst.NAN) self.assertEqual(err_msg, "") @@ -412,8 +428,9 @@ class TestUtilsMethods(unittest.TestCase): b_value = np.array([1, 1]) relative_err = get_relative_err(n_value, b_value) n_value, b_value = reshape_value(n_value, b_value) + err_msg = "" - result, err_msg = op.apply(n_value, b_value, relative_err) + result, err_msg = op.apply(n_value, b_value, relative_err, err_msg) self.assertEqual(result, 0.5) self.assertEqual(err_msg, "") @@ -438,7 +455,7 @@ class TestUtilsMethods(unittest.TestCase): result, err_msg = error_value_process(n_value) - self.assertEqual(result, 0) + self.assertEqual(result, CompareConst.UNSUPPORTED) self.assertEqual(err_msg, "") def test_error_value_process_shape_unmatch(self): @@ -471,5 +488,34 @@ class TestUtilsMethods(unittest.TestCase): error_flag = False err_msg = '' a, b = compare_ops_apply(n_value, b_value, error_flag, err_msg) - self.assertEqual(a, [1.0, 0.0, 0.0, 1.0, 1.0]) + self.assertEqual(a, [1.0, 0.0, 0.0, 0.0, 1.0, 1.0]) self.assertEqual(b, '') + + +class TestGetEuclideanDistance(unittest.TestCase): + + def setUp(self): + self.euc_distance = GetEuclideanDistance() + + def test_euclidean_distance_normal(self): + # 测试计算两个张量之间的欧式距离 + n_value = np.array([1, 2, 3]) + b_value = np.array([4, 5, 6]) + relative_err = None + err_msg = "" + + result, msg = self.euc_distance.apply(n_value, b_value, relative_err, err_msg) + expected_distance = np.linalg.norm(n_value - b_value) + self.assertEqual(result, expected_distance) + self.assertEqual(msg, '') + + def test_euclidean_distance_0d_tensor(self): + # 测试计算两个张量之间的欧式距离 + n_value = np.array(1) + b_value = np.array(1) + relative_err = None + err_msg = "This is type of 0-d tensor, can not calculate 'Cosine', 'EucDist', 'One Thousandth Err Ratio' and 'Five Thousandths Err Ratio'. " + + result, msg = self.euc_distance.apply(n_value, b_value, relative_err, err_msg) + self.assertEqual(result, CompareConst.UNSUPPORTED) + self.assertEqual(msg, "This is type of 0-d tensor, can not calculate 'Cosine', 'EucDist', 'One Thousandth Err Ratio' and 'Five Thousandths Err Ratio'. ") diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py index ab8703dcd353ff32dc0722fc314ade6042d6f567..286fecf84d50e34c0110318b4296a2330ec071fe 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py @@ -6,15 +6,17 @@ import shutil import unittest from unittest.mock import patch import zlib +import tempfile import numpy as np +import pandas as pd from msprobe.core.common.const import CompareConst, Const from msprobe.core.common.utils import CompareException from msprobe.core.compare.utils import ApiItemInfo, _compare_parser, check_and_return_dir_contents, extract_json, \ - count_struct, get_accuracy, append_stack_info, get_rela_diff_summary_mode, get_un_match_accuracy, merge_tensor, \ - op_item_parse, read_op, rename_api, resolve_api_special_parameters, result_item_init, stack_column_process, \ - table_value_is_valid, get_name_and_state, reorder_op_name_list, reorder_op_x_list, gen_op_item + count_struct, get_accuracy, get_rela_diff_summary_mode, merge_tensor, op_item_parse, read_op, result_item_init, \ + stack_column_process, table_value_is_valid, reorder_op_name_list, gen_op_item, ApiBatch, get_paired_dirs, \ + reorder_index, gen_api_batches # test_read_op_1 op_data = { @@ -32,15 +34,19 @@ op_name = "Tensor.add_0.0.forward" op_result = [ {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3], 'md5': '00000000', 'Max': 0.33033010363578796, 'Min': -0.331031858921051, 'Mean': -0.030964046716690063, 'data_name': '-1', - 'Norm': 2.2533628940582275, 'requires_grad': True, 'full_op_name': 'Tensor.add_0.0.forward.input.0'}, + 'Norm': 2.2533628940582275, 'requires_grad': 'True', 'full_op_name': 'Tensor.add_0.0.forward.input.0', + 'state': 'input'}, {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3], 'md5': '00000000', 'Max': 0.003992878366261721, 'Min': -0.008102823048830032, 'Mean': -0.0002002553956117481, 'data_name': '-1', - 'Norm': 0.02844562754034996, 'requires_grad': False, 'full_op_name': 'Tensor.add_0.0.forward.input.1'}, + 'Norm': 0.02844562754034996, 'requires_grad': 'False', 'full_op_name': 'Tensor.add_0.0.forward.input.1', + 'state': 'input'}, {'full_op_name': 'Tensor.add_0.0.forward.input.alpha', 'dtype': "", 'shape': '[]', 'md5': '0dae4479', - 'Max': -0.1, 'Min': -0.1, 'Mean': -0.1, 'Norm': -0.1, 'data_name': '-1', 'type': 'float', 'value': -0.1}, + 'Max': -0.1, 'Min': -0.1, 'Mean': -0.1, 'Norm': -0.1, 'requires_grad': None, 'data_name': '-1', 'type': 'float', + 'value': -0.1, 'state': 'input'}, {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3], 'md5': '00000000', 'Max': 0.33033010363578796, 'Min': -0.331031858921051, 'Mean': -0.030964046716690063, 'data_name': '-1', - 'Norm': 2.2533628940582275, 'requires_grad': True, 'full_op_name': 'Tensor.add_0.0.forward.output.0'}] + 'Norm': 2.2533628940582275, 'requires_grad': 'True', 'full_op_name': 'Tensor.add_0.0.forward.output.0', + 'state': 'output'}] # test_read_op_1 op_data_b = { @@ -57,13 +63,16 @@ op_name_b = "Tensor.add_0.0.backward" op_result_b = [ {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3], 'data_name': '-1', 'md5': '00000000', 'Max': 0.33033010363578796, 'Min': -0.331031858921051, 'Mean': -0.030964046716690063, - 'Norm': 2.2533628940582275, 'requires_grad': True, 'full_op_name': 'Tensor.add_0.0.backward.input.0'}, + 'Norm': 2.2533628940582275, 'requires_grad': 'True', 'full_op_name': 'Tensor.add_0.0.backward.input.0', + 'state': 'input'}, {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3], 'data_name': '-1', 'md5': '00000000', 'Max': 0.003992878366261721, 'Min': -0.008102823048830032, 'Mean': -0.0002002553956117481, - 'Norm': 0.02844562754034996, 'requires_grad': False, 'full_op_name': 'Tensor.add_0.0.backward.input.1'}, + 'Norm': 0.02844562754034996, 'requires_grad': 'False', 'full_op_name': 'Tensor.add_0.0.backward.input.1', + 'state': 'input'}, {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3], 'data_name': '-1', 'md5': '00000000', 'Max': 0.33033010363578796, 'Min': -0.331031858921051, 'Mean': -0.030964046716690063, - 'Norm': 2.2533628940582275, 'requires_grad': True, 'full_op_name': 'Tensor.add_0.0.backward.output.0'}] + 'Norm': 2.2533628940582275, 'requires_grad': 'True', 'full_op_name': 'Tensor.add_0.0.backward.output.0', + 'state': 'output'}] # test_op_item_parse parse_item = [ @@ -77,14 +86,15 @@ parse_index = None parse_item_list = None parse_top_bool = True o_result_parse = [ - {'Max': 4097.0, 'Mean': 820.2, 'Min': 0.0, 'Norm': 4097.0, 'dtype': 'torch.int64', 'requires_grad': False, + {'Max': 4097.0, 'Mean': 820.2, 'Min': 0.0, 'Norm': 4097.0, 'dtype': 'torch.int64', 'requires_grad': 'False', 'shape': [5], 'type': 'torch.Tensor', 'full_op_name': 'Distributed.broadcast.0.forward.input.0', - 'data_name': '-1', 'md5': '00000000'}, + 'data_name': '-1', 'md5': '00000000', 'state': 'input'}, {'full_op_name': 'Distributed.broadcast.0.forward.input.1', 'dtype': "", 'shape': '[]', - 'md5': 'f4dbdf21', 'Max': 0, 'Min': 0, 'Mean': 0, 'Norm': 0, 'data_name': '-1', 'type': 'int', 'value': 0}, + 'md5': 'f4dbdf21', 'Max': 0, 'Min': 0, 'Mean': 0, 'Norm': 0, 'data_name': '-1', 'type': 'int', 'value': 0, + 'state': 'input', 'requires_grad': None}, {'Max': None, 'Mean': None, 'Min': None, 'Norm': None, 'data_name': '-1', 'dtype': 'slice', 'type': 'slice', 'full_op_name': 'Distributed.broadcast.0.forward.input.2', 'md5': '5fbbe87f', 'shape': '(3,)', - 'value': [None, None, None]} + 'value': [None, None, None], 'state': 'input', 'requires_grad': None} ] # test_resolve_api_special_parameters @@ -119,7 +129,8 @@ npu_dict = {'op_name': ['Functional.conv2d.0.forward.input.0', 'Functional.conv2 [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]], - 'stack_info': []} + 'stack_info': [], + 'requires_grad': [True, False, True, True, True, True, True, True]} bench_dict = {'op_name': ['Functional.conv2d.0.forward.input.0', 'Functional.conv2d.0.forward.input.1', 'Functional.conv2d.0.forward.input.2', 'Functional.conv2d.0.forward.output.0', 'Functional.conv2d.0.forward.parameters.weight', 'Functional.conv2d.0.forward.parameters.bias', @@ -137,39 +148,42 @@ bench_dict = {'op_name': ['Functional.conv2d.0.forward.input.0', 'Functional.con [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]], - 'stack_info': []} + 'stack_info': [], + 'requires_grad': [True, False, True, True, True, True, True, True]} highlight_dict = {'red_rows': [], 'yellow_rows': []} o_result = [ ['Functional.conv2d.0.forward.input.0', 'Functional.conv2d.0.forward.input.0', 'torch.float32', 'torch.float32', - [1, 1, 28, 28], [1, 1, 28, 28], 0.0, 0.0, 0.0, 0.0, '0.0%', '0.0%', '0.0%', '0.0%', + [1, 1, 28, 28], [1, 1, 28, 28], True, True, 0.0, 0.0, 0.0, 0.0, '0.0%', '0.0%', '0.0%', '0.0%', 3.029174327850342, -2.926689624786377, -0.06619918346405029, 1.0, - 3.029174327850342, -2.926689624786377, -0.06619918346405029, 1.0,'', '', 'None'], + 3.029174327850342, -2.926689624786377, -0.06619918346405029, 1.0, True, '', '', 'None'], ['Functional.conv2d.0.forward.input.1', 'Functional.conv2d.0.forward.input.1', 'torch.float32', 'torch.float32', - [16, 1, 5, 5], [16, 1, 5, 5], 0.0, 0.0, 0.0, 0.0, '0.0%', '0.0%', '0.0%', '0.0%', + [16, 1, 5, 5], [16, 1, 5, 5], False, False, 0.0, 0.0, 0.0, 0.0, '0.0%', '0.0%', '0.0%', '0.0%', 0.19919930398464203, -0.19974489510059357, 0.006269412115216255, 1.0, - 0.19919930398464203, -0.19974489510059357, 0.006269412115216255, 1.0, '', '', 'None'], + 0.19919930398464203, -0.19974489510059357, 0.006269412115216255, 1.0, True, '', '', 'None'], ['Functional.conv2d.0.forward.input.2', 'Functional.conv2d.0.forward.input.2', 'torch.float32', 'torch.float32', - [16], [16], 0.0, 0.0, 0.0, 0.0, '0.0%', '0.0%', '0.0%', '0.0%', + [16], [16], True, True, 0.0, 0.0, 0.0, 0.0, '0.0%', '0.0%', '0.0%', '0.0%', 0.19734230637550354, -0.18177609145641327, 0.007903944700956345, 1.0, - 0.19734230637550354, -0.18177609145641327, 0.007903944700956345, 1.0, '', '', 'None'], + 0.19734230637550354, -0.18177609145641327, 0.007903944700956345, 1.0, True, '', '', 'None'], ['Functional.conv2d.0.forward.parameters.weight', 'Functional.conv2d.0.forward.parameters.weight', 'torch.float32', 'torch.float32', - [1, 16, 28, 28], [1, 16, 28, 28], 0.0, 0.0, 0.0, 0.0, '0.0%', '0.0%', '0.0%', '0.0%', - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, '', '', 'None'], + [1, 16, 28, 28], [1, 16, 28, 28], True, True, 0.0, 0.0, 0.0, 0.0, '0.0%', '0.0%', '0.0%', '0.0%', + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, True, '', '', 'None'], ['Functional.conv2d.0.forward.parameters.bias', 'Functional.conv2d.0.forward.parameters.bias', 'torch.float32', 'torch.float32', - [1, 16, 28, 28], [1, 16, 28, 28], 0.0, 0.0, 0.0, 0.0, '0.0%', '0.0%', '0.0%', '0.0%', - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, '', '', 'None'], + [1, 16, 28, 28], [1, 16, 28, 28], True, True, 0.0, 0.0, 0.0, 0.0, '0.0%', '0.0%', '0.0%', '0.0%', + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, True, '', '', 'None'], ['Functional.conv2d.0.forward.output.0', 'Functional.conv2d.0.forward.output.0', 'torch.float32', 'torch.float32', - [1, 16, 28, 28], [1, 16, 28, 28], 0.0, 0.0, 0.0, 0.0, '0.0%', '0.0%', '0.0%', '0.0%', + [1, 16, 28, 28], [1, 16, 28, 28], True, True, 0.0, 0.0, 0.0, 0.0, '0.0%', '0.0%', '0.0%', '0.0%', 2.1166646480560303, -2.190781354904175, -0.003579073818400502, 1.0, - 2.1166646480560303, -2.190781354904175, -0.003579073818400502, 1.0, '', '', 'None'], - ['Functional.conv2d.0.parameters_grad.weight', 'Functional.conv2d.0.parameters_grad.weight', 'torch.float32', 'torch.float32', - [1, 16, 28, 28], [1, 16, 28, 28], 0.0, 0.0, 0.0, 0.0, '0.0%', '0.0%', '0.0%', '0.0%', - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, '', '', 'None'], - ['Functional.conv2d.0.parameters_grad.bias', 'Functional.conv2d.0.parameters_grad.bias', 'torch.float32', 'torch.float32', - [1, 16, 28, 28], [1, 16, 28, 28], 0.0, 0.0, 0.0, 0.0, '0.0%', '0.0%', '0.0%', '0.0%', - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, '', '', 'None'], + 2.1166646480560303, -2.190781354904175, -0.003579073818400502, 1.0, True, '', '', 'None'], + ['Functional.conv2d.0.parameters_grad.weight', 'Functional.conv2d.0.parameters_grad.weight', 'torch.float32', + 'torch.float32', + [1, 16, 28, 28], [1, 16, 28, 28], True, True, 0.0, 0.0, 0.0, 0.0, '0.0%', '0.0%', '0.0%', '0.0%', + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, True, '', '', 'None'], + ['Functional.conv2d.0.parameters_grad.bias', 'Functional.conv2d.0.parameters_grad.bias', 'torch.float32', + 'torch.float32', + [1, 16, 28, 28], [1, 16, 28, 28], True, True, 0.0, 0.0, 0.0, 0.0, '0.0%', '0.0%', '0.0%', '0.0%', + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, True, '', '', 'None'], ] # test_get_un_match_accuracy @@ -187,9 +201,11 @@ o_result_unmatch_1 = [ 'None'], ['Functional.conv2d.0.forward.output.0', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', 'None'], - ['Functional.conv2d.0.parameters_grad.weight', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', + ['Functional.conv2d.0.parameters_grad.weight', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', + 'N/A', 'None'], - ['Functional.conv2d.0.parameters_grad.bias', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', + ['Functional.conv2d.0.parameters_grad.bias', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', + 'N/A', 'None'] ] o_result_unmatch_2 = [ @@ -197,10 +213,12 @@ o_result_unmatch_2 = [ 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 3.029174327850342, -2.926689624786377, -0.06619918346405029, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None'], ['Functional.conv2d.0.forward.input.1', 'N/A', 'torch.float32', 'N/A', [16, 1, 5, 5], 'N/A', 'N/A', 'N/A', 'N/A', - 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 0.19919930398464203, -0.19974489510059357, 0.006269412115216255, 1.0, 'N/A', 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 0.19919930398464203, -0.19974489510059357, 0.006269412115216255, 1.0, 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None'], ['Functional.conv2d.0.forward.input.2', 'N/A', 'torch.float32', 'N/A', [16], 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', - 'N/A', 'N/A', 'N/A', 'N/A', 0.19734230637550354, -0.18177609145641327, 0.007903944700956345, 1.0, 'N/A', 'N/A', 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 0.19734230637550354, -0.18177609145641327, 0.007903944700956345, 1.0, 'N/A', 'N/A', + 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None'], ['Functional.conv2d.0.forward.parameters.weight', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', @@ -211,53 +229,64 @@ o_result_unmatch_2 = [ 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None'], ['Functional.conv2d.0.forward.output.0', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', - 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 2.1166646480560303, -2.190781354904175, -0.003579073818400502, 1.0, 'N/A', 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 2.1166646480560303, -2.190781354904175, -0.003579073818400502, 1.0, 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None'], - ['Functional.conv2d.0.parameters_grad.weight', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', + ['Functional.conv2d.0.parameters_grad.weight', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None'], - ['Functional.conv2d.0.parameters_grad.bias', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', + ['Functional.conv2d.0.parameters_grad.bias', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None'] ] o_result_unmatch_3 = [ - ['Functional.conv2d.0.forward.input.0', 'N/A', 'torch.float32', 'N/A', [1, 1, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', - 'N/A', 'N/A', 3.029174327850342, -2.926689624786377, -0.06619918346405029, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', - 'No bench data matched.', 'None', '-1'], - ['Functional.conv2d.0.forward.input.1', 'N/A', 'torch.float32', 'N/A', [16, 1, 5, 5], 'N/A', 'N/A', 'N/A', 'N/A', - 'N/A', 'N/A', 0.19919930398464203, -0.19974489510059357, 0.006269412115216255, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', - 'No bench data matched.', 'None', '-1'], - ['Functional.conv2d.0.forward.input.2', 'N/A', 'torch.float32', 'N/A', [16], 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', - 'N/A', 0.19734230637550354, -0.18177609145641327, 0.007903944700956345, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', - 'No bench data matched.', 'None', '-1'], - ['Functional.conv2d.0.forward.parameters.weight', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', - 'N/A', 'N/A', - 'N/A', 'N/A', 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', '-1'], - ['Functional.conv2d.0.forward.parameters.bias', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', - 'N/A', - 'N/A', 'N/A', 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', '-1'], - ['Functional.conv2d.0.forward.output.0', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', - 'N/A', 'N/A', 2.1166646480560303, -2.190781354904175, -0.003579073818400502, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', - 'No bench data matched.', 'None', '-1'], - ['Functional.conv2d.0.parameters_grad.weight', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', - 'N/A', 'N/A', 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', '-1'], - ['Functional.conv2d.0.parameters_grad.bias', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', - 'N/A', 'N/A', 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', '-1'] + ['Functional.conv2d.0.forward.input.0', 'N/A', 'torch.float32', 'N/A', [1, 1, 28, 28], 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 3.029174327850342, -2.926689624786377, -0.06619918346405029, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 'No bench data matched.', 'None', ['-1', '-1']], + ['Functional.conv2d.0.forward.input.1', 'N/A', 'torch.float32', 'N/A', [16, 1, 5, 5], 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 0.19919930398464203, -0.19974489510059357, 0.006269412115216255, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 'No bench data matched.', 'None', ['-1', '-1']], + ['Functional.conv2d.0.forward.input.2', 'N/A', 'torch.float32', 'N/A', [16], 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 0.19734230637550354, -0.18177609145641327, 0.007903944700956345, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 'No bench data matched.', 'None', ['-1', '-1']], + ['Functional.conv2d.0.forward.parameters.weight', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', ['-1', '-1']], + ['Functional.conv2d.0.forward.parameters.bias', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', ['-1', '-1']], + ['Functional.conv2d.0.forward.output.0', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 2.1166646480560303, -2.190781354904175, -0.003579073818400502, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 'No bench data matched.', 'None', ['-1', '-1']], + ['Functional.conv2d.0.parameters_grad.weight', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', ['-1', '-1']], + ['Functional.conv2d.0.parameters_grad.bias', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', ['-1', '-1']] ] # test_merge_tensor tensor_list = [ {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3], 'Max': 0.33033010363578796, 'Min': -0.331031858921051, 'Mean': -0.030964046716690063, 'Norm': 2.2533628940582275, 'requires_grad': True, - 'full_op_name': 'Tensor.add_.0.forward.input.0'}, + 'full_op_name': 'Tensor.add_.0.forward.input.0', 'state': 'input'}, {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3], 'Max': 0.003992878366261721, 'Min': -0.008102823048830032, 'Mean': -0.0002002553956117481, - 'Norm': 0.02844562754034996, 'requires_grad': False, 'full_op_name': 'Tensor.add_.0.forward.input.1'}, + 'Norm': 0.02844562754034996, 'requires_grad': False, 'full_op_name': 'Tensor.add_.0.forward.input.1', + 'state': 'input'}, {'full_op_name': 'Tensor.add_.0.forward.input.alpha.0', 'dtype': "", "shape": '[]', 'md5': None, - 'Max': -0.1, 'Min': -0.1, 'Mean': -0.1, 'Norm': -0.1, 'data_name': '-1'}, + 'Max': -0.1, 'Min': -0.1, 'Mean': -0.1, 'Norm': -0.1, 'data_name': '-1', 'state': 'input'}, {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3], 'Max': 0.33033010363578796, 'Min': -0.331031858921051, 'Mean': -0.030964046716690063, - 'Norm': 2.2533628940582275, 'requires_grad': True, 'full_op_name': 'Tensor.add_.0.forward.output.0'} + 'Norm': 2.2533628940582275, 'requires_grad': True, 'full_op_name': 'Tensor.add_.0.forward.output.0', + 'state': 'output'} ] result_op_dict = {'op_name': ['Tensor.add_.0.forward.input.0', 'Tensor.add_.0.forward.input.1', 'Tensor.add_.0.forward.input.alpha.0', 'Tensor.add_.0.forward.output.0'], @@ -266,22 +295,27 @@ result_op_dict = {'op_name': ['Tensor.add_.0.forward.input.0', 'Tensor.add_.0.fo 'output_struct': [('torch.float32', [16, 1, 3, 3])], 'params_struct': [], 'params_grad_struct': [], + 'debug_struct': [], 'summary': [[0.33033010363578796, -0.331031858921051, -0.030964046716690063, 2.2533628940582275], [0.003992878366261721, -0.008102823048830032, -0.0002002553956117481, 0.02844562754034996], [-0.1, -0.1, -0.1, -0.1], [0.33033010363578796, -0.331031858921051, -0.030964046716690063, 2.2533628940582275]], - 'stack_info': []} + 'stack_info': [], + 'state': ['input', 'input', 'input', 'output'], + 'requires_grad': [True, False, None, True]} tensor_list_md5 = [ {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3], 'Max': 0.003992878366261721, 'Min': -0.008102823048830032, 'Mean': -0.0002002553956117481, - 'Norm': 0.02844562754034996, 'requires_grad': False, 'full_op_name': 'Tensor.add_.0.forward.input.0', 'md5': 1}, + 'Norm': 0.02844562754034996, 'requires_grad': False, 'full_op_name': 'Tensor.add_.0.forward.input.0', 'md5': 1, + 'state': 'input'}, {'full_op_name': 'Tensor.add_.0.forward.kwargs.alpha.0', 'dtype': "", "shape": '[]', 'md5': None, - 'Max': -0.1, 'Min': -0.1, 'Mean': -0.1, 'Norm': -0.1, 'data_name': '-1'}, + 'Max': -0.1, 'Min': -0.1, 'Mean': -0.1, 'Norm': -0.1, 'data_name': '-1', 'state': 'input'}, {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3], 'Max': 0.33033010363578796, 'Min': -0.331031858921051, 'Mean': -0.030964046716690063, - 'Norm': 2.2533628940582275, 'requires_grad': True, 'full_op_name': 'Tensor.add_.0.forward.output.0', 'md5': 2} + 'Norm': 2.2533628940582275, 'requires_grad': True, 'full_op_name': 'Tensor.add_.0.forward.output.0', 'md5': 2, + 'state': 'output'} ] result_op_dict_md5 = {'op_name': ['Tensor.add_.0.forward.input.0', 'Tensor.add_.0.forward.kwargs.alpha.0', 'Tensor.add_.0.forward.output.0'], @@ -289,18 +323,22 @@ result_op_dict_md5 = {'op_name': ['Tensor.add_.0.forward.input.0', 'Tensor.add_. 'output_struct': [('torch.float32', [16, 1, 3, 3], 2)], 'params_struct': [], 'params_grad_struct': [], + 'debug_struct': [], 'summary': [ [0.003992878366261721, -0.008102823048830032, -0.0002002553956117481, 0.02844562754034996], [-0.1, -0.1, -0.1, -0.1], [0.33033010363578796, -0.331031858921051, -0.030964046716690063, 2.2533628940582275]], - 'stack_info': []} + 'stack_info': [], + 'state': ['input', 'input', 'output'], + 'requires_grad': [False, None, True] + } base_dir1 = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'test_acc_compare_utils1') base_dir2 = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'test_acc_compare_utils2') def create_json_files(base_dir): - file_names = ['dump.json', 'stack.json', 'construct.json'] + file_names = ['dump.json', 'stack.json', 'construct.json', 'debug.json'] for file_name in file_names: file_path = os.path.join(base_dir, file_name) @@ -333,29 +371,20 @@ class TestUtilsMethods(unittest.TestCase): def test_extract_json_1(self): create_json_files(base_dir1) - result = extract_json(base_dir1, stack_json=False) + result = extract_json(base_dir1, Const.DUMP_JSON_FILE) self.assertEqual(result, os.path.join(base_dir1, 'dump.json')) - result = extract_json(base_dir1, stack_json=True) + result = extract_json(base_dir1, Const.STACK_JSON_FILE) self.assertEqual(result, os.path.join(base_dir1, 'stack.json')) + result = extract_json(base_dir1, Const.DEBUG_JSON_FILE) + self.assertEqual(result, os.path.join(base_dir1, 'debug.json')) + def test_check_and_return_dir_contents(self): create_rank_dirs(base_dir2) result = check_and_return_dir_contents(base_dir2, 'rank') self.assertEqual(set(result), set(['rank0', 'rank1'])) - def test_rename_api_1(self): - test_name_1 = "Distributed.broadcast.0.forward.input.0" - expect_name_1 = "Distributed.broadcast.input.0" - actual_name_1 = rename_api(test_name_1, "forward") - self.assertEqual(actual_name_1, expect_name_1) - - def test_rename_api_2(self): - test_name_2 = "Torch.sum.0.backward.output.0" - expect_name_2 = "Torch.sum.output.0" - actual_name_2 = rename_api(test_name_2, "backward") - self.assertEqual(actual_name_2, expect_name_2) - def test_read_op(self): result = read_op(op_data, op_name) self.assertEqual(result, op_result) @@ -365,49 +394,44 @@ class TestUtilsMethods(unittest.TestCase): self.assertEqual(result, op_result_b) def test_op_item_parse(self): - result = op_item_parse(parse_item, parse_op_name) + result = op_item_parse(parse_item, parse_op_name, 'input') self.assertEqual(result, o_result_parse) def test_op_item_parse_max_depth(self): with self.assertRaises(CompareException) as context: - op_item_parse(parse_item, parse_op_name, depth=11) + op_item_parse(parse_item, parse_op_name, 'input', depth=401) self.assertEqual(context.exception.code, CompareException.RECURSION_LIMIT_ERROR) - def test_resolve_api_special_parameters(self): - item_list = [] - resolve_api_special_parameters(data_dict, full_op_name, item_list) - self.assertEqual(item_list, o_result_api_special) - def test_get_rela_diff_summary_mode_float_or_int(self): - result_item = [0] * 14 + result_item = [0] * 16 err_msg = '' npu_summary_data = [1, 1, 1, 1] - bench_summary_data = [1, 1, 1, 1] + bench_summary_data = [2, 2, 2, 2] result_item, accuracy_check, err_msg = get_rela_diff_summary_mode(result_item, npu_summary_data, bench_summary_data, err_msg) - self.assertEqual(result_item, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, '0.0%', '0.0%', '0.0%', '0.0%']) + self.assertEqual(result_item, [0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, '50.0%', '50.0%', '50.0%', '50.0%']) self.assertEqual(accuracy_check, '') self.assertEqual(err_msg, '') def test_get_rela_diff_summary_mode_bool(self): - result_item = [0] * 14 + result_item = [0] * 16 err_msg = '' npu_summary_data = [True, True, True, True] bench_summary_data = [True, True, True, True] result_item, accuracy_check, err_msg = get_rela_diff_summary_mode(result_item, npu_summary_data, bench_summary_data, err_msg) - self.assertEqual(result_item, [0, 0, 0, 0, 0, 0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A']) + self.assertEqual(result_item, [0, 0, 0, 0, 0, 0, 0, 0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A']) self.assertEqual(accuracy_check, '') self.assertEqual(err_msg, '') def test_get_rela_diff_summary_mode_nan(self): - result_item = [0] * 14 + result_item = [0] * 16 err_msg = '' npu_summary_data = [float('nan')] bench_summary_data = [float('nan')] result_item, accuracy_check, err_msg = get_rela_diff_summary_mode(result_item, npu_summary_data, bench_summary_data, err_msg) - self.assertEqual(result_item, [0, 0, 0, 0, 0, 0, 'Nan', 0, 0, 0, 'Nan', 0, 0, 0]) + self.assertEqual(result_item, [0, 0, 0, 0, 0, 0, 0, 0, 'Nan', 0, 0, 0, 'Nan', 0, 0, 0]) self.assertEqual(accuracy_check, '') self.assertEqual(err_msg, '') @@ -443,57 +467,6 @@ class TestUtilsMethods(unittest.TestCase): get_accuracy(result, npu_dict, bench_dict, dump_mode=Const.SUMMARY) self.assertEqual(result, o_result) - def test_append_stack_info_stack_exist_index_0(self): - result_item = ['item1'] - npu_stack_info = ['stack_info1'] - index = 0 - - append_stack_info(result_item, npu_stack_info, index) - - self.assertEqual(result_item, ['item1', 'stack_info1']) - - def test_append_stack_info_stack_exist_index_not_0(self): - result_item = ['item1'] - npu_stack_info = ['stack_info1'] - index = 1 - - append_stack_info(result_item, npu_stack_info, index) - - self.assertEqual(result_item, ['item1', CompareConst.NONE]) - - def test_append_stack_info_stack_empty_index_0(self): - result_item = ['item1'] - npu_stack_info = [] - index = 0 - - append_stack_info(result_item, npu_stack_info, index) - - self.assertEqual(result_item, ['item1', CompareConst.NONE]) - - def test_append_stack_info_stack_empty_index_not_0(self): - result_item = ['item1'] - npu_stack_info = [] - index = 1 - - append_stack_info(result_item, npu_stack_info, index) - - self.assertEqual(result_item, ['item1', CompareConst.NONE]) - - def test_get_un_match_accuracy_md5(self): - result = [] - get_un_match_accuracy(result, npu_dict, dump_mode=Const.MD5) - self.assertEqual(result, o_result_unmatch_1) - - def test_get_un_match_accuracy_summary(self): - result = [] - get_un_match_accuracy(result, npu_dict, dump_mode=Const.SUMMARY) - self.assertEqual(result, o_result_unmatch_2) - - def test_get_un_match_accuracy_all(self): - result = [] - get_un_match_accuracy(result, npu_dict, dump_mode=Const.ALL) - self.assertEqual(result, o_result_unmatch_3) - def test_merge_tensor_summary(self): op_dict = merge_tensor(tensor_list, dump_mode=Const.SUMMARY) self.assertEqual(op_dict, result_op_dict) @@ -552,19 +525,21 @@ class TestUtilsMethods(unittest.TestCase): b_name = 'Tensor.add.0.forward.input.0' b_struct = ('torch.float32', [96]) bench_stack_info = ['abc'] + requires_grad_pair = [True, True] n_info = ApiItemInfo(n_name, n_struct, npu_stack_info) b_info = ApiItemInfo(b_name, b_struct, bench_stack_info) dump_mode = Const.ALL - result_item = result_item_init(n_info, b_info, dump_mode) + result_item = result_item_init(n_info, b_info, requires_grad_pair, dump_mode) self.assertEqual(result_item, ['Tensor.add.0.forward.input.0', 'Tensor.add.0.forward.input.0', - 'torch.float32', 'torch.float32', [96], [96], ' ', ' ', ' ', ' ', ' ']) + 'torch.float32', 'torch.float32', [96], [96], True, True, + ' ', ' ', ' ', ' ', ' ', ' ']) dump_mode = Const.SUMMARY - result_item = result_item_init(n_info, b_info, dump_mode) + result_item = result_item_init(n_info, b_info, requires_grad_pair, dump_mode) self.assertEqual(result_item, ['Tensor.add.0.forward.input.0', 'Tensor.add.0.forward.input.0', - 'torch.float32', 'torch.float32', [96], [96], ' ', ' ', ' ', ' ', ' ', ' ', ' ', - ' ']) + 'torch.float32', 'torch.float32', [96], [96], True, True, + ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ']) def test_result_item_init_md5(self): n_name = 'Tensor.add.0.forward.input.0' @@ -573,13 +548,15 @@ class TestUtilsMethods(unittest.TestCase): b_name = 'Tensor.add.0.forward.input.0' b_struct = ('torch.float32', [96], 'e87000dc') bench_stack_info = ['abc'] + requires_grad_pair = [True, True] n_info = ApiItemInfo(n_name, n_struct, npu_stack_info) b_info = ApiItemInfo(b_name, b_struct, bench_stack_info) dump_mode = Const.MD5 - result_item = result_item_init(n_info, b_info, dump_mode) + result_item = result_item_init(n_info, b_info, requires_grad_pair, dump_mode) self.assertEqual(result_item, ['Tensor.add.0.forward.input.0', 'Tensor.add.0.forward.input.0', - 'torch.float32', 'torch.float32', [96], [96], 'e87000dc', 'e87000dc', 'pass']) + 'torch.float32', 'torch.float32', [96], [96], True, True, + 'e87000dc', 'e87000dc', True, 'pass']) def test_result_item_init_md5_index_error(self): n_name = 'Tensor.add.0.forward.input.0' @@ -588,12 +565,13 @@ class TestUtilsMethods(unittest.TestCase): b_name = 'Tensor.add.0.forward.input.0' b_struct = ('torch.float32', [96]) bench_stack_info = ['abc'] + requires_grad_pair = [True, True] n_info = ApiItemInfo(n_name, n_struct, npu_stack_info) b_info = ApiItemInfo(b_name, b_struct, bench_stack_info) dump_mode = Const.MD5 with self.assertRaises(CompareException) as context: - result_item = result_item_init(n_info, b_info, dump_mode) + result_item = result_item_init(n_info, b_info, requires_grad_pair, dump_mode) self.assertEqual(context.exception.code, CompareException.INDEX_OUT_OF_BOUNDS_ERROR) def test_table_value_is_valid_int(self): @@ -612,96 +590,70 @@ class TestUtilsMethods(unittest.TestCase): self.assertFalse(result) -class TestGetNameAndState(unittest.TestCase): - def test_valid_forward_input(self): - name = 'conv2d.forward.1.input.0' - expected_api = 'conv2d.forward.1.' - expected_state = 'input' - self.assertEqual(get_name_and_state(name), (expected_api, expected_state)) - - def test_valid_backward_output(self): - name = 'Functional.pad.0.backward.output.0' - expected_api = 'Functional.pad.0.backward.' - expected_state = 'output' - self.assertEqual(get_name_and_state(name), (expected_api, expected_state)) - - def test_valid_with_kwargs(self): - name = 'layer.norm.2.forward.kwargs.attr' - expected_api = 'layer.norm.2.forward.' - expected_state = 'kwargs' - self.assertEqual(get_name_and_state(name), (expected_api, expected_state)) - - def test_no_numeric_index(self): - name = 'conv2d.forward.input.0' - expected_api = 'conv2d.forward.' - expected_state = 'input' - self.assertEqual(get_name_and_state(name), (expected_api, expected_state)) - - def test_invalid__state(self): - name = 'conv2d.forward.1.invalidstate.0' - with self.assertRaises(CompareException) as context: - get_name_and_state(name) - self.assertIn('Invalid name string', str(context.exception.code)) +class TestReorderIndex(unittest.TestCase): + def test_reorder_index_mixed_states(self): + op_parsed_list = [ + {Const.STATE: "OTHER"}, + {Const.STATE: Const.OUTPUT}, + {Const.STATE: Const.PARAMS}, + {Const.STATE: Const.PARAMS_GRAD}, + {Const.STATE: Const.INPUT}, + {"not_state": 123}, # 没有 STATE,算作 other + ] + + reordered = reorder_index(op_parsed_list) + self.assertTrue(reordered == [0, 4, 2, 1, 3]) + + def test_reorder_index_all_params(self): + op_parsed_list = [ + {Const.STATE: Const.PARAMS}, + {Const.STATE: Const.PARAMS}, + {Const.STATE: Const.PARAMS}, + ] + reordered = reorder_index(op_parsed_list) + self.assertTrue(reordered == [0, 1]) + + def test_reorder_index_empty(self): + op_parsed_list = [] + reordered = reorder_index(op_parsed_list) + self.assertTrue(reordered == []) + + def test_reorder_index_single_element(self): + op_parsed_list = [{Const.STATE: Const.PARAMS}] + reordered = reorder_index(op_parsed_list) + self.assertTrue(reordered == []) class TestReorderOpNameList(unittest.TestCase): def test_reorder_op_name_list(self): # 标准顺序 - op_name_list = ["op.forward.input.0.0", "op.forward.output.0", "op.forward.output.1", "op.forward.parameters.1", "op.forward.parameters.2", "op.parameters_grad.0"] - result = reorder_op_name_list(op_name_list) - expected = ["op.forward.input.0.0", "op.forward.parameters.1", "op.forward.parameters.2", "op.forward.output.0", "op.forward.output.1", "op.parameters_grad.0"] - self.assertEqual(result, expected) + op_name_list = ["op.forward.input.0.0", "op.forward.output.0", "op.forward.output.1", "op.forward.parameters.1", + "op.forward.parameters.2", "op.parameters_grad.0"] + state_list = ["input", "output", "output", "parameters", "parameters", "parameters_grad"] + op_name_reorder, state_reorder = reorder_op_name_list(op_name_list, state_list) + expected_result = ["op.forward.input.0.0", "op.forward.parameters.1", "op.forward.parameters.2", + "op.forward.output.0", "op.forward.output.1", "op.parameters_grad.0"] + expected_state = ["input", "parameters", "parameters", "output", "output", "parameters_grad"] + self.assertEqual(op_name_reorder, expected_result) + self.assertEqual(state_reorder, expected_state) # 只有输入元素 op_name_list = ["op.forward.input.0", "op.forward.input.1"] - result = reorder_op_name_list(op_name_list) - expected = ["op.forward.input.0", "op.forward.input.1"] - self.assertEqual(result, expected) + state_list = ["input", "input"] + op_name_reorder, state_reorder = reorder_op_name_list(op_name_list, state_list) + expected_result = ["op.forward.input.0", "op.forward.input.1"] + expected_state = ["input", "input"] + self.assertEqual(op_name_reorder, expected_result) + self.assertEqual(state_reorder, expected_state) # 输入为空 op_name_list = [] - result = reorder_op_name_list(op_name_list) - expected = [] - self.assertEqual(result, expected) - - -class TestReorderOpXList(unittest.TestCase): - def test_reorder_op_x_list(self): - # 标准顺序 - op_name_list = ["op.forward.input.0", "op.forward.output.0", "op.forward.parameters.weight"] - summary_list = ["summary1", "summary2", "summary3"] - data_name_list = ["data1", "data2", "data3"] - result_op_name, result_summary, result_data_name = reorder_op_x_list(op_name_list, summary_list, data_name_list) - self.assertEqual(result_op_name, ["op.forward.input.0", "op.forward.parameters.weight", "op.forward.output.0"]) - self.assertEqual(result_summary, ["summary1", "summary3", "summary2"]) - self.assertEqual(result_data_name, ["data1", "data3", "data2"]) - - # 空 op_name_list 或 summary_list - op_name_list = [] - summary_list = [] - data_name_list = ["data1", "data2", "data3"] - result_op_name, result_summary, result_data_name = reorder_op_x_list(op_name_list, summary_list, data_name_list) - self.assertEqual(result_op_name, []) - self.assertEqual(result_summary, []) - self.assertEqual(result_data_name, ["data1", "data2", "data3"]) - - # 空 data_name_list - op_name_list = ["op.forward.input.0", "op.forward.output.0", "op.forward.parameters.weight"] - summary_list = ["summary1", "summary2", "summary3"] - data_name_list = [] - result_op_name, result_summary, result_data_name = reorder_op_x_list(op_name_list, summary_list, data_name_list) - self.assertEqual(result_op_name, ["op.forward.input.0", "op.forward.parameters.weight", "op.forward.output.0"]) - self.assertEqual(result_summary, ["summary1", "summary3", "summary2"]) - self.assertEqual(result_data_name, []) - - # data_name_list 为 None - op_name_list = ["op.forward.input.0", "op.forward.output.0", "op.forward.parameters.weight"] - summary_list = ["summary1", "summary2", "summary3"] - data_name_list = None - result_op_name, result_summary, result_data_name = reorder_op_x_list(op_name_list, summary_list, data_name_list) - self.assertEqual(result_op_name, ["op.forward.input.0", "op.forward.parameters.weight", "op.forward.output.0"]) - self.assertEqual(result_summary, ["summary1", "summary3", "summary2"]) - self.assertEqual(result_data_name, None) + state_list = [] + op_name_reorder, state_reorder = reorder_op_name_list(op_name_list, state_list) + expected_result = [] + expected_state = [] + self.assertEqual(op_name_reorder, expected_result) + self.assertEqual(state_reorder, expected_state) class TestGenOpItem(unittest.TestCase): @@ -719,7 +671,7 @@ class TestGenOpItem(unittest.TestCase): } op_name = 'op_test' - result = gen_op_item(op_data, op_name) + result = gen_op_item(op_data, op_name, 'input') self.assertEqual(result['data_name'], 'test_data') self.assertEqual(result['full_op_name'], 'test_data') @@ -730,6 +682,7 @@ class TestGenOpItem(unittest.TestCase): self.assertEqual(result['Mean'], 2) self.assertEqual(result['Norm'], 2) self.assertEqual(result['md5'], f"{zlib.crc32(str(op_data['value']).encode()):08x}") + self.assertEqual(result['state'], 'input') def test_gen_op_item_with_empty_data_name(self): op_data = { @@ -739,11 +692,12 @@ class TestGenOpItem(unittest.TestCase): } op_name = 'op_test' - result = gen_op_item(op_data, op_name) + result = gen_op_item(op_data, op_name, 'input') # data_name为空时,应该被设置为'-1' self.assertEqual(result['data_name'], '-1') self.assertEqual(result['full_op_name'], op_name) + self.assertEqual(result['state'], 'input') def test_gen_op_item_with_none_data_name(self): op_data = { @@ -753,11 +707,12 @@ class TestGenOpItem(unittest.TestCase): } op_name = 'op_test' - result = gen_op_item(op_data, op_name) + result = gen_op_item(op_data, op_name, 'input') # data_name为None时,应该被设置为'-1' self.assertEqual(result['data_name'], '-1') self.assertEqual(result['full_op_name'], op_name) + self.assertEqual(result['state'], 'input') def test_gen_op_item_with_type_torch_size(self): op_data = { @@ -767,7 +722,7 @@ class TestGenOpItem(unittest.TestCase): } op_name = 'op_test' - result = gen_op_item(op_data, op_name) + result = gen_op_item(op_data, op_name, 'input') self.assertEqual(result['dtype'], 'torch.Size') self.assertEqual(result['shape'], '[2, 3, 4]') @@ -775,6 +730,7 @@ class TestGenOpItem(unittest.TestCase): self.assertEqual(result['Min'], None) self.assertEqual(result['Mean'], None) self.assertEqual(result['Norm'], None) + self.assertEqual(result['state'], 'input') def test_gen_op_item_with_type_slice(self): op_data = { @@ -784,10 +740,11 @@ class TestGenOpItem(unittest.TestCase): } op_name = 'op_test' - result = gen_op_item(op_data, op_name) + result = gen_op_item(op_data, op_name, 'input') self.assertEqual(result['dtype'], 'slice') self.assertEqual(result['shape'], str(np.shape(np.array(op_data['value'])))) + self.assertEqual(result['state'], 'input') def test_gen_op_item_with_type_ellipsis(self): op_data = { @@ -797,7 +754,7 @@ class TestGenOpItem(unittest.TestCase): } op_name = 'op_test' - result = gen_op_item(op_data, op_name) + result = gen_op_item(op_data, op_name, 'input') self.assertEqual(result['dtype'], 'ellipsis') self.assertEqual(result['shape'], '[]') @@ -805,6 +762,7 @@ class TestGenOpItem(unittest.TestCase): self.assertEqual(result['Min'], '...') self.assertEqual(result['Mean'], '...') self.assertEqual(result['Norm'], '...') + self.assertEqual(result['state'], 'input') def test_gen_op_item_with_type_torch_process_group(self): op_data = { @@ -814,7 +772,7 @@ class TestGenOpItem(unittest.TestCase): } op_name = 'op_test' - result = gen_op_item(op_data, op_name) + result = gen_op_item(op_data, op_name, 'input') self.assertEqual(result['dtype'], 'torch.ProcessGroup') self.assertEqual(result['shape'], '[]') @@ -822,6 +780,7 @@ class TestGenOpItem(unittest.TestCase): self.assertEqual(result['Min'], '[0, 1]') self.assertEqual(result['Mean'], '[0, 1]') self.assertEqual(result['Norm'], '[0, 1]') + self.assertEqual(result['state'], 'input') def test_gen_op_item_with_default_dtype(self): op_data = { @@ -831,10 +790,11 @@ class TestGenOpItem(unittest.TestCase): } op_name = 'op_test' - result = gen_op_item(op_data, op_name) + result = gen_op_item(op_data, op_name, 'input') self.assertEqual(result['dtype'], str(type(op_data['value']))) self.assertEqual(result['shape'], '[]') + self.assertEqual(result['state'], 'input') def test_gen_op_item_with_md5(self): op_data = { @@ -844,7 +804,149 @@ class TestGenOpItem(unittest.TestCase): } op_name = 'op_test' - result = gen_op_item(op_data, op_name) + result = gen_op_item(op_data, op_name, 'input') expected_md5 = f"{zlib.crc32(str(op_data['value']).encode()):08x}" self.assertEqual(result['md5'], expected_md5) + self.assertEqual(result['state'], 'input') + + +class TestApiBatch(unittest.TestCase): + def test_ApiBatch_increment_input(self): + api_name = "functional.conv2d" + start = 2 + api_batch = ApiBatch(api_name, start) + + api_batch.increment(Const.INPUT) + + self.assertEqual(api_batch._state, Const.INPUT) + self.assertEqual(api_batch.input_len, 2) + self.assertEqual(api_batch.params_end_index, 4) + self.assertEqual(api_batch.output_end_index, 4) + self.assertEqual(api_batch.params_grad_end_index, 4) + + def test_ApiBatch_increment_output(self): + api_name = "functional.conv2d" + start = 2 + api_batch = ApiBatch(api_name, start) + + api_batch.increment(Const.OUTPUT) + + self.assertEqual(api_batch._state, Const.OUTPUT) + self.assertEqual(api_batch.input_len, 1) + self.assertEqual(api_batch.params_end_index, 3) + self.assertEqual(api_batch.output_end_index, 4) + self.assertEqual(api_batch.params_grad_end_index, 4) + + def test_ApiBatch_increment_kwargs(self): + api_name = "functional.conv2d" + start = 2 + api_batch = ApiBatch(api_name, start) + + api_batch.increment(Const.KWARGS) + + self.assertEqual(api_batch._state, Const.KWARGS) + self.assertEqual(api_batch.input_len, 2) + self.assertEqual(api_batch.params_end_index, 4) + self.assertEqual(api_batch.output_end_index, 4) + self.assertEqual(api_batch.params_grad_end_index, 4) + + def test_ApiBatch_increment_params(self): + api_name = "functional.conv2d" + start = 2 + api_batch = ApiBatch(api_name, start) + + api_batch.increment(Const.PARAMS) + + self.assertEqual(api_batch._state, Const.PARAMS) + self.assertEqual(api_batch.input_len, 1) + self.assertEqual(api_batch.params_end_index, 4) + self.assertEqual(api_batch.output_end_index, 4) + self.assertEqual(api_batch.params_grad_end_index, 4) + + def test_ApiBatch_increment_multiple_input(self): + api_name = "functional.conv2d" + start = 2 + api_batch = ApiBatch(api_name, start) + + api_batch.increment(Const.INPUT) + api_batch.increment(Const.INPUT) + + self.assertEqual(api_batch._state, Const.INPUT) + self.assertEqual(api_batch.input_len, 3) + self.assertEqual(api_batch.params_end_index, 5) + self.assertEqual(api_batch.output_end_index, 5) + self.assertEqual(api_batch.params_grad_end_index, 5) + + def test_ApiBatch_increment_multiple_output(self): + api_name = "functional.conv2d" + start = 2 + api_batch = ApiBatch(api_name, start) + + api_batch.increment(Const.OUTPUT) + api_batch.increment(Const.OUTPUT) + + self.assertEqual(api_batch._state, Const.OUTPUT) + self.assertEqual(api_batch.input_len, 1) + self.assertEqual(api_batch.params_end_index, 3) + self.assertEqual(api_batch.output_end_index, 5) + self.assertEqual(api_batch.params_grad_end_index, 5) + + +class TestGenApiBatches(unittest.TestCase): + def test_gen_api_batches_normal(self): + result_df_part1 = pd.DataFrame(o_result) + result_df_part1.columns = CompareConst.SUMMARY_COMPARE_RESULT_HEADER_STACK + new_columns = [ + ['input', 'Functional.conv2d.0.forward'], + ['input', 'Functional.conv2d.0.forward'], + ['input', 'Functional.conv2d.0.forward'], + ['parameters', 'Functional.conv2d.0.forward'], + ['parameters', 'Functional.conv2d.0.forward'], + ['output', 'Functional.conv2d.0.forward'], + ['parameters_grad', 'Functional.conv2d.0.forward'], + ['parameters_grad', 'Functional.conv2d.0.forward'] + ] + result_df_part2 = pd.DataFrame(new_columns) + result_df_part2.columns = [Const.STATE, Const.API_ORIGIN_NAME] + result_df = pd.concat([result_df_part1, result_df_part2], axis=1) + result = result_df.values + header = result_df.columns.tolist() + result_api_batches = gen_api_batches(result, header) + + api_batch = ApiBatch('Functional.conv2d.0.forward', 0) + api_batch.input_len = 3 + api_batch.output_end_index = 6 + api_batch.params_end_index = 5 + api_batch.params_grad_end_index = 8 + api_batch._state = 'parameters_grad' + + result_api_batch = result_api_batches[0] + self.assertEqual(result_api_batch.api_name, api_batch.api_name) + self.assertEqual(result_api_batch.start, api_batch.start) + self.assertEqual(result_api_batch.input_len, api_batch.input_len) + self.assertEqual(result_api_batch.params_end_index, api_batch.params_end_index) + self.assertEqual(result_api_batch.params_grad_end_index, api_batch.params_grad_end_index) + self.assertEqual(result_api_batch._state, api_batch._state) + + +class TestGetPairedSteps(unittest.TestCase): + def setUp(self): + self.npu_dir = tempfile.TemporaryDirectory() + self.bench_dir = tempfile.TemporaryDirectory() + + self.npu_files = ['step1', 'step2'] + for name in self.npu_files: + open(os.path.join(self.npu_dir.name, name), 'w').close() + + self.bench_files = ['step2', 'step3'] + for name in self.bench_files: + open(os.path.join(self.bench_dir.name, name), 'w').close() + + def tearDown(self): + self.npu_dir.cleanup() + self.bench_dir.cleanup() + + def test_get_paired_steps(self): + paired = get_paired_dirs(self.npu_dir.name, self.bench_dir.name) + self.assertEqual(set(paired), {'step2'}) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_cli.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..62d5467137cdb9d8b6fffaff6e82a10c992df3ea --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_cli.py @@ -0,0 +1,34 @@ +import unittest +from unittest.mock import patch, MagicMock + +from msprobe.core.compare.compare_cli import mix_compare + + +class TestMixCompare(unittest.TestCase): + @patch('msprobe.core.compare.compare_cli.get_paired_dirs') + @patch('msprobe.core.compare.compare_cli.compare_cli') + def test_mix_compare_with_matching_dirs(self, mock_compare_cli, mock_get_paired_dirs): + mock_args = MagicMock() + mock_args.output_path = "/output" + mock_input_param = {"npu_path": "/npu_dump", "bench_path": "/bench_dump", "is_print_compare_log": True} + mock_get_paired_dirs.side_effect = [ + ["graph", "pynative"], # 第一次调用的返回值 + ["step1", "step2"], # 第二次调用的返回值 + ["step1", "step2"] # 第三次调用的返回值 + ] + + result = mix_compare(mock_args, mock_input_param, 1) + + self.assertTrue(result) + + @patch('msprobe.core.compare.compare_cli.get_paired_dirs') + @patch('msprobe.core.compare.compare_cli.compare_cli') + def test_mix_compare_no_matching_dirs(self, mock_compare_cli, mock_get_paired_dirs): + mock_args = MagicMock() + mock_args.output_path = "/output" + mock_input_param = {"npu_path": "/npu_dump", "bench_path": "/bench_dump", "is_print_compare_log": True} + mock_get_paired_dirs.return_value = set() + + result = mix_compare(mock_args, mock_input_param, 1) + + self.assertFalse(result) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_first_diff_analyze.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_first_diff_analyze.py new file mode 100644 index 0000000000000000000000000000000000000000..cb2fde7eb6f22fc6edfb745ad488bb0bedf3ae6a --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_first_diff_analyze.py @@ -0,0 +1,176 @@ +import unittest +from unittest.mock import patch + +import pandas as pd + +from msprobe.core.common.const import Const, CompareConst +from msprobe.core.common.utils import CompareException +from msprobe.core.compare.diff_analyze.first_diff_analyze import FirstDiffAnalyze +from msprobe.core.compare.config import ModeConfig + + +class TestFirstDiffAnalyze(unittest.TestCase): + def setUp(self): + self.header = ['NPU name', 'L2norm diff', + 'MaxRelativeErr', 'MinRelativeErr', 'MeanRelativeErr', 'NormRelativeErr', + 'state', 'api_origin_name'] + self.data = [ + ['Functional.conv2d.0.forward.input.0', 1, '0.0%', '0.0%', '0.0%', '0.0%', 'input', 'Functional.conv2d.0.forward'], + ['Functional.conv2d.0.forward.input.1', 1, '99.0%', '99.0%', '99.0%', '99.0%', 'input', 'Functional.conv2d.0.forward'] + ] + self.result_df = pd.DataFrame(self.data, columns=self.header) + + @patch('msprobe.core.compare.diff_analyze.first_diff_analyze.thresholds', + {'compare_metrics': ['MaxRelativeErr', 'NormRelativeErr'], 'MaxRelativeErr': [0.5]}) + def test_single_metric_diff_check_true(self): + mode_config = ModeConfig(first_diff_analyze=True) + first_diff_analyze = FirstDiffAnalyze(mode_config, '') + result = first_diff_analyze.single_metric_diff_check('MaxRelativeErr', '60.0%') + self.assertTrue(result) + + @patch('msprobe.core.compare.diff_analyze.first_diff_analyze.thresholds', + {'compare_metrics': ['MaxRelativeErr', 'NormRelativeErr'], 'MaxRelativeErr': [0.5]}) + def test_single_metric_diff_check_false(self): + mode_config = ModeConfig(first_diff_analyze=True) + first_diff_analyze = FirstDiffAnalyze(mode_config, '') + result = first_diff_analyze.single_metric_diff_check('MaxRelativeErr', '30.0%') + self.assertFalse(result) + + @patch('msprobe.core.compare.diff_analyze.first_diff_analyze.thresholds', + {'compare_metrics': ['MaxRelativeErr', 'NormRelativeErr'], 'NormRelativeErr': [0.5]}) + def test_single_metric_diff_check_miss_threshold(self): + mode_config = ModeConfig(first_diff_analyze=True) + first_diff_analyze = FirstDiffAnalyze(mode_config, '') + with self.assertRaises(CompareException) as context: + result = first_diff_analyze.single_metric_diff_check('MaxRelativeErr', '30.0%') + self.assertEqual(context.exception.code, CompareException.MISSING_THRESHOLD_ERROR) + + @patch('msprobe.core.compare.diff_analyze.first_diff_analyze.thresholds', + {'compare_metrics': ['MaxRelativeErr', 'NormRelativeErr'], 'MaxRelativeErr': [0.5, 1.0]}) + def test_single_metric_diff_check_wrong_threshold(self): + mode_config = ModeConfig(first_diff_analyze=True) + first_diff_analyze = FirstDiffAnalyze(mode_config, '') + with self.assertRaises(CompareException) as context: + result = first_diff_analyze.single_metric_diff_check('MaxRelativeErr', '30.0%') + self.assertEqual(context.exception.code, CompareException.WRONG_THRESHOLD_ERROR) + + def test_single_api_check_within_threshold(self): + result_slice = [ + ['Functional.conv2d.0.forward.input.0', 1, '0.0%', '0.0%', '0.0%', '0.0%', 'input', 'Functional.conv2d.0.forward'], + ['Functional.conv2d.0.forward.input.1', 1, '0.1%', '0.1%', '0.1%', '0.1%', 'input', 'Functional.conv2d.0.forward'] + ] + expected_result = { + 'is_same': True, + 'op_items': [ + {'NPU name': 'Functional.conv2d.0.forward.input.0', 'L2norm diff': 1, + 'MaxRelativeErr': '0.0%', 'MinRelativeErr': '0.0%', + 'MeanRelativeErr': '0.0%', 'NormRelativeErr': '0.0%', + 'state': 'input', 'api_origin_name': 'Functional.conv2d.0.forward'}, + {'NPU name': 'Functional.conv2d.0.forward.input.1', 'L2norm diff': 1, + 'MaxRelativeErr': '0.1%', 'MinRelativeErr': '0.1%', + 'MeanRelativeErr': '0.1%', 'NormRelativeErr': '0.1%', + 'state': 'input', 'api_origin_name': 'Functional.conv2d.0.forward'} + ] + } + mode_config = ModeConfig(first_diff_analyze=True) + first_diff_analyze = FirstDiffAnalyze(mode_config, 'rank0') + result = first_diff_analyze.single_api_check(result_slice, self.header) + self.assertEqual(result, expected_result) + + def test_single_api_check_exceed_threshold(self): + result_slice = [ + ['Functional.conv2d.0.forward.input.0', 1, '88.0%', '88.0%', '88.0%', '88.0%', 'input', 'Functional.conv2d.0.forward'], + ['Functional.conv2d.0.forward.input.1', 1, '99.0%', '99.0%', '99.0%', '99.0%', 'input', 'Functional.conv2d.0.forward'] + ] + expected_result = { + 'is_same': False, + 'op_items': [ + {'NPU name': 'Functional.conv2d.0.forward.input.0', 'L2norm diff': 1, + 'MaxRelativeErr': '88.0%', 'MinRelativeErr': '88.0%', + 'MeanRelativeErr': '88.0%', 'NormRelativeErr': '88.0%', + 'state': 'input', 'api_origin_name': 'Functional.conv2d.0.forward'}, + {'NPU name': 'Functional.conv2d.0.forward.input.1', 'L2norm diff': 1, + 'MaxRelativeErr': '99.0%', 'MinRelativeErr': '99.0%', + 'MeanRelativeErr': '99.0%', 'NormRelativeErr': '99.0%', + 'state': 'input', 'api_origin_name': 'Functional.conv2d.0.forward'}, + ] + } + mode_config = ModeConfig(first_diff_analyze=True) + first_diff_analyze = FirstDiffAnalyze(mode_config, 'rank0') + result = first_diff_analyze.single_api_check(result_slice, self.header) + self.assertEqual(result, expected_result) + + def test_single_api_check_md5_same_true(self): + md5_header = CompareConst.MD5_COMPARE_RESULT_HEADER + [CompareConst.STACK, Const.STATE, Const.API_ORIGIN_NAME] + result_slice = [ + ['Functional.conv2d.0.forward.input.0', 'Functional.conv2d.0.forward.input.0', 'torch.int32', 'torch.int32', + '[]', '[]', 'True', 'True', '2144df1c', '2144df1c', True, 'pass', + '', 'input', 'Functional.conv2d.0.forward'] + ] + expected_result = { + 'is_same': True, + 'op_items': [ + {CompareConst.NPU_NAME: 'Functional.conv2d.0.forward.input.0', + CompareConst.BENCH_NAME: 'Functional.conv2d.0.forward.input.0', + CompareConst.NPU_DTYPE: 'torch.int32', CompareConst.BENCH_DTYPE: 'torch.int32', + CompareConst.NPU_SHAPE: '[]', CompareConst.BENCH_SHAPE: '[]', + CompareConst.NPU_REQ_GRAD: 'True', CompareConst.BENCH_REQ_GRAD: 'True', + CompareConst.NPU_MD5: '2144df1c', CompareConst.BENCH_MD5: '2144df1c', + CompareConst.REQ_GRAD_CONSIST: True, + CompareConst.RESULT: 'pass', CompareConst.STACK: '', + Const.STATE: 'input', Const.API_ORIGIN_NAME: 'Functional.conv2d.0.forward' + } + ] + } + mode_config = ModeConfig(dump_mode=Const.MD5, first_diff_analyze=True) + first_diff_analyze = FirstDiffAnalyze(mode_config, 'rank0') + result = first_diff_analyze.single_api_check(result_slice, md5_header) + self.assertEqual(result, expected_result) + + def test_single_api_check_md5_same_false(self): + md5_header = CompareConst.MD5_COMPARE_RESULT_HEADER + [CompareConst.STACK, Const.STATE, Const.API_ORIGIN_NAME] + result_slice = [ + ['Functional.conv2d.0.forward.input.0', 'Functional.conv2d.0.forward.input.0', 'torch.int32', 'torch.int32', + '[]', '[]', 'True', 'True', '2144df1c', '2100df1c', True, 'Different', + '', 'input', 'Functional.conv2d.0.forward'] + ] + expected_result = { + 'is_same': False, + 'op_items': [ + {CompareConst.NPU_NAME: 'Functional.conv2d.0.forward.input.0', + CompareConst.BENCH_NAME: 'Functional.conv2d.0.forward.input.0', + CompareConst.NPU_DTYPE: 'torch.int32', CompareConst.BENCH_DTYPE: 'torch.int32', + CompareConst.NPU_SHAPE: '[]', CompareConst.BENCH_SHAPE: '[]', + CompareConst.NPU_REQ_GRAD: 'True', CompareConst.BENCH_REQ_GRAD: 'True', + CompareConst.NPU_MD5: '2144df1c', CompareConst.BENCH_MD5: '2100df1c', + CompareConst.REQ_GRAD_CONSIST: True, + CompareConst.RESULT: 'Different', CompareConst.STACK: '', + Const.STATE: 'input', Const.API_ORIGIN_NAME: 'Functional.conv2d.0.forward' + } + ] + } + mode_config = ModeConfig(dump_mode=Const.MD5, first_diff_analyze=True) + first_diff_analyze = FirstDiffAnalyze(mode_config, 'rank0') + result = first_diff_analyze.single_api_check(result_slice, md5_header) + self.assertEqual(result, expected_result) + + def test_check_summary(self): + expected_result = { + 'Functional.conv2d.0.forward': { + 'is_same': False, + 'op_items': [ + {'NPU name': 'Functional.conv2d.0.forward.input.0', 'L2norm diff': 1, + 'MaxRelativeErr': '0.0%', 'MinRelativeErr': '0.0%', + 'MeanRelativeErr': '0.0%', 'NormRelativeErr': '0.0%', + 'state': 'input', 'api_origin_name': 'Functional.conv2d.0.forward'}, + {'NPU name': 'Functional.conv2d.0.forward.input.1', 'L2norm diff': 1, + 'MaxRelativeErr': '99.0%', 'MinRelativeErr': '99.0%', + 'MeanRelativeErr': '99.0%', 'NormRelativeErr': '99.0%', + 'state': 'input', 'api_origin_name': 'Functional.conv2d.0.forward'}, + ] + } + } + mode_config = ModeConfig(first_diff_analyze=True) + first_diff_analyze = FirstDiffAnalyze(mode_config, 'rank1') + result = first_diff_analyze.check(self.result_df) + self.assertEqual(result, expected_result) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_highlight.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_highlight.py index f561a3e05ec84c3ee75dac50ed5aec2a2af7f7b5..c5190d3a577b8a5a58395836caac9715a8af268a 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_highlight.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_highlight.py @@ -12,12 +12,53 @@ import openpyxl from openpyxl import load_workbook from openpyxl.styles import PatternFill - from msprobe.core.common.const import CompareConst, Const -from msprobe.core.compare.highlight import ApiBatch, CheckMaxRelativeDiff, CheckOrderMagnitude, \ - CheckOneThousandErrorRatio, CheckCosineSimilarity, add_highlight_row_info, compare_result_df_convert, \ - df_malicious_value_check, find_error_rows, highlight_rows_xlsx, update_highlight_err_msg, value_check - +from msprobe.core.compare.highlight import CheckMaxRelativeDiff, CheckOrderMagnitude, \ + CheckOneThousandErrorRatio, CheckCosineSimilarity, add_highlight_row_info, HighLight +from msprobe.core.compare.config import ModeConfig +from msprobe.core.compare.utils import ApiBatch + + +summary_line_input = ['Functional_batch_norm_0_forward.input.0', 'Functional_batch_norm_0_forward.input.0', + 'torch.float16', 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], True, True, + 0.01, 0, 0, 0, '0.0%', '0.0%', '0.0%', '0.0%', 1, 1, 1, 1, 1.01, 1, 1, 1, + True, 'Yes', ''] +summary_line_1 = ['Functional_batch_norm_0_forward.output.0', 'Functional_batch_norm_0_forward.output.0', + 'torch.float16', 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], True, True, + 10, 0, 0, 0, '0.0%', '0.0%', '0.0%', '0.0%', 2, 0, 1, 1, 1, 1, 1, 1, + True, 'Warning', ''] +summary_line_2 = ['Functional_batch_norm_0_forward.output.1', 'Functional_batch_norm_0_forward.output.1', + 'torch.float16', 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], True, True, + 0.02, 0, 0, 0, '0.0%', '0.0%', '0.0%', '0.0%', 0.12, 0, 1, 1, 0.1, 1, 1, 1, + True, 'Warning', ''] +summary_line_3 = ['Functional_batch_norm_0_forward.output.2', 'Functional_batch_norm_0_forward.output.2', + 'torch.float16', 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], True, True, + 0, 0, 0, 0, '0.0%', '0.0%', '0.0%', '0.0%', 2, 0, 1, 1, 1, 1, 1, 1, + True, 'Warning', ''] +line_input = ['Functional.batch.norm.0.forward.input.0', 'Functional.batch.norm.0.forward.input.0', 'torch.float16', + 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], True, True, + 1, 0.5, 1, 1, 0.95, 1, + 1, 1, 1, 1, + 1.01, 1, 1, 1, + True, 'Yes', '', 'input', 'Functional.batch.norm.0.forward'] +line_1 = ['Functional.batch.norm.0.forward.output.0', 'Functional.batch.norm.0.forward.output.0', 'torch.float16', + 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], True, True, + 0.8, 0.5, 1, 1, 0.59, 1, + 'nan', 0, 1, 1, + 19, 1, 1, 1, + True, 'Yes', '', 'output', 'Functional.batch.norm.0.forward'] +line_2 = ['Functional.batch.norm.0.forward.output.1', 'Functional.batch.norm.0.forward.output.1', 'torch.float16', + 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], True, True, + 0.9, 0.5, 1, 1, 0.8, 1, + 0, 0.12, 0, 1, + 1, 0.1, 1, 1, + True, 'Yes', '', 'output', 'Functional.batch.norm.0.forward'] +line_3 = ['Functional.batch.norm.0.forward.output.2', 'Functional.batch.norm.0.forward.output.2', 'torch.float16', + 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], True, True, + 0.8, 0.5, 1.1e+10, 1, 0.85, 1, + 9, 0.12, 0, 1, + 1, 0.1, 1, 1, + True, 'Yes', '', 'output', 'Functional.batch.norm.0.forward'] base_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'test_highlight') @@ -25,8 +66,8 @@ base_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'test_highl def generate_result_xlsx(base_dir): data_path = os.path.join(base_dir, 'target_result.xlsx') data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'] + 'torch.float32', 'torch.float32', [2, 2], [2, 2], 'True', 'True', + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'True', 'Yes', '', '-1'] ] columns = CompareConst.COMPARE_RESULT_HEADER + ['Data_name'] result_df = pd.DataFrame(data, columns=columns) @@ -90,8 +131,8 @@ class TestUtilsMethods(unittest.TestCase): shutil.rmtree(base_dir) def test_CheckOrderMagnitude_normal(self): - api_in = [1, 1, 1, 1, 1, 1, 5, 1, 1] - api_out = [1, 1, 1, 1, 1, 1, 1, 1, 1] + api_in = [1, 1, 1, 1, 1, 1, True, True, 5, 1, 1] + api_out = [1, 1, 1, 1, 1, 1, True, True, 1, 1, 1] info = (api_in, api_out, 1) color_columns = () dump_mode = Const.SUMMARY @@ -101,8 +142,8 @@ class TestUtilsMethods(unittest.TestCase): self.assertEqual(result, None) def test_CheckOneThousandErrorRatio_str(self): - api_in = [1, 1, 1, 1, 1, 1, 1, 1, 1, "unsupported"] - api_out = [1, 1, 1, 1, 1, 1, 1, 1, 1, "unsupported"] + api_in = [1, 1, 1, 1, 1, 1, True, True, 0.9, 0.5, 1, 1, "unsupported"] + api_out = [1, 1, 1, 1, 1, 1, True, True, 0.9, 0.5, 1, 1, "unsupported"] info = (api_in, api_out, 1) color_columns = () dump_mode = Const.ALL @@ -113,8 +154,8 @@ class TestUtilsMethods(unittest.TestCase): @patch("msprobe.core.compare.highlight.add_highlight_row_info") def test_CheckOneThousandErrorRatio_red(self, mock_add_highlight_row_info): - api_in = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] - api_out = [1, 1, 1, 1, 1, 1, 1, 1, 1, 0.5] + api_in = [1, 1, 1, 1, 1, 1, True, True, 0.9, 0.5, 1, 1, 1] + api_out = [1, 1, 1, 1, 1, 1, True, True, 0.9, 0.5, 1, 1, 0.5] info = (api_in, api_out, 1) ColorColumns = namedtuple('ColorColumns', ['red', 'yellow']) color_columns = ColorColumns(red=[], yellow=[]) @@ -125,8 +166,8 @@ class TestUtilsMethods(unittest.TestCase): mock_add_highlight_row_info.assert_called_once() def test_CheckCosineSimilarity_str(self): - api_in = [1, 1, 1, 1, 1, 1, "unsupported", 1, 1, "unsupported"] - api_out = [1, 1, 1, 1, 1, 1, "unsupported", 1, 1, "unsupported"] + api_in = [1, 1, 1, 1, 1, 1, True, True, "unsupported", 1, 1, "unsupported"] + api_out = [1, 1, 1, 1, 1, 1, True, True, "unsupported", 1, 1, "unsupported"] info = (api_in, api_out, 1) color_columns = () dump_mode = Const.ALL @@ -141,8 +182,8 @@ class TestUtilsMethods(unittest.TestCase): red_lines, yellow_lines = [], [] color_columns = ColorColumns(red=red_lines, yellow=yellow_lines) - api_in = {6: 0, 18: 1} - api_out = {6: 0.6, 18: 1} + api_in = {12: '1%'} + api_out = {12: '60%'} num = 1 info = (api_in, api_out, num) CheckMaxRelativeDiff().apply(info, color_columns, dump_mode=Const.SUMMARY) @@ -156,12 +197,12 @@ class TestUtilsMethods(unittest.TestCase): red_lines, yellow_lines = [], [] color_columns = ColorColumns(red=red_lines, yellow=yellow_lines) - api_in = {6: 0.001, 18: 1} - api_out = {6: 0.2, 18: 1} + api_in = {12: '0.1%'} + api_out = {12: '20%'} num = 1 info = (api_in, api_out, num) CheckMaxRelativeDiff().apply(info, color_columns, dump_mode=Const.SUMMARY) - red_lines, yellow_lines = [], [(1, ["The output's maximum relative error exceeds 0.1, while the input/parameters's is below 0.01"])] + red_lines, yellow_lines = [], [(1, ["The output's maximum relative error exceeds 0.1, while the input/parameter's is below 0.01"])] target_color_columns = ColorColumns(red=red_lines, yellow=yellow_lines) self.assertEqual(color_columns, target_color_columns) @@ -171,8 +212,8 @@ class TestUtilsMethods(unittest.TestCase): red_lines, yellow_lines = [], [] color_columns = ColorColumns(red=red_lines, yellow=yellow_lines) - api_in = {6: 0.001, 18: np.nan} - api_out = {6: 0.2, 18: 1} + api_in = {12: '15000%'} + api_out = {12: '20%'} num = 1 info = (api_in, api_out, num) result = CheckMaxRelativeDiff().apply(info, color_columns, dump_mode=Const.SUMMARY) @@ -181,26 +222,31 @@ class TestUtilsMethods(unittest.TestCase): def test_find_error_rows_normal(self): compare_result = np.array([ ["Functional.linear.0.forward.input.0", "Functional.linear.0.forward.input.0", - "torch.float32", "torch.float32", [2, 2], [2, 2], 0.0, 0.0, 0.0, 0.0, "0.0%", "0.0%", "0.0%", "0.0%", - 1, 1, 1, 1, 1, 1, 1, 1, "", ""], + "torch.float32", "torch.float32", [2, 2], [2, 2], 'True', 'True', + 0.0, 0.0, 0.0, 0.0, "0.0%", "0.0%", "0.0%", "0.0%", + 1, 1, 1, 1, 1, 1, 1, 1, True, "", ""], ["Functional.linear.0.forward.input.1", "Functional.linear.0.forward.input.1", - "torch.float32", "torch.float32", [2, 2], [2, 2], 0.0, 0.0, 0.0, 0.0, "0.0%", "0.0%", "0.0%", "0.0%", - 1, 1, 1, 1, 1, 1, 1, 1, "", ""], + "torch.float32", "torch.float32", [2, 2], [2, 2], 'True', 'True', + 0.0, 0.0, 0.0, 0.0, "0.0%", "0.0%", "0.0%", "0.0%", + 1, 1, 1, 1, 1, 1, 1, 1, True, "", ""], ["Functional.linear.0.forward.input.2", "Functional.linear.0.forward.input.2", - "torch.float32", "torch.float32", [2], [2], 0.0, 0.0, 0.0, 0.0, "0.0%", "0.0%", "0.0%", "0.0%", - 1, 1, 1, 1, 1, 1, 1, 1, "", ""], + "torch.float32", "torch.float32", [2], [2], 'True', 'True', + 0.0, 0.0, 0.0, 0.0, "0.0%", "0.0%", "0.0%", "0.0%", + 1, 1, 1, 1, 1, 1, 1, 1, True, "", ""], ["Functional.linear.0.forward.output.0", "Functional.linear.0.forward.output.0", - "torch.float32", "torch.float32", [2, 2], [2, 2], 0.0, 0.0, 0.0, 0.0, "0.0%", "0.0%", "0.0%", "0.0%", - 1, 1, 1, 1, 1, 1, 1, 1, "", ""], + "torch.float32", "torch.float32", [2, 2], [2, 2], 'True', 'True', + 0.0, 0.0, 0.0, 0.0, "0.0%", "0.0%", "0.0%", "0.0%", + 1, 1, 1, 1, 1, 1, 1, 1, True, "", ""], ], dtype=object) api_batch = ApiBatch("Functional.linear.0.forward", 0) api_batch.input_len = 3 api_batch.output_end_index = 4 api_batch.params_end_index = 4 highlight_dict = {"red_lines": [], "red_rows": set(), "yellow_lines": [], "yellow_rows": set()} - dump_mode = Const.ALL - find_error_rows(compare_result, api_batch, highlight_dict, dump_mode) + mode_config = ModeConfig(dump_mode=Const.ALL) + highlight = HighLight(mode_config, '') + highlight.find_error_rows(compare_result, api_batch, highlight_dict) self.assertEqual(highlight_dict, {"red_lines": [], "red_rows": set(), "yellow_lines": [], "yellow_rows": set()}) @@ -211,92 +257,13 @@ class TestUtilsMethods(unittest.TestCase): api_batch.output_end_index = 1 api_batch.params_end_index = 1 highlight_dict = {} - dump_mode = Const.MD5 - result = find_error_rows(compare_result, api_batch, highlight_dict, dump_mode) + mode_config = ModeConfig(dump_mode=Const.MD5) + highlight = HighLight(mode_config, '') + result = highlight.find_error_rows(compare_result, api_batch, highlight_dict) self.assertEqual(result, None) - def test_ApiBatch_increment_input(self): - api_name = "functional.conv2d" - start = 2 - api_batch = ApiBatch(api_name, start) - - api_batch.increment(Const.INPUT) - - self.assertEqual(api_batch._state, Const.INPUT) - self.assertEqual(api_batch.input_len, 2) - self.assertEqual(api_batch.params_end_index, 4) - self.assertEqual(api_batch.output_end_index, 4) - self.assertEqual(api_batch.params_grad_end_index, 4) - - def test_ApiBatch_increment_output(self): - api_name = "functional.conv2d" - start = 2 - api_batch = ApiBatch(api_name, start) - - api_batch.increment(Const.OUTPUT) - - self.assertEqual(api_batch._state, Const.OUTPUT) - self.assertEqual(api_batch.input_len, 1) - self.assertEqual(api_batch.params_end_index, 3) - self.assertEqual(api_batch.output_end_index, 4) - self.assertEqual(api_batch.params_grad_end_index, 4) - - def test_ApiBatch_increment_kwargs(self): - api_name = "functional.conv2d" - start = 2 - api_batch = ApiBatch(api_name, start) - - api_batch.increment(Const.KWARGS) - - self.assertEqual(api_batch._state, Const.KWARGS) - self.assertEqual(api_batch.input_len, 2) - self.assertEqual(api_batch.params_end_index, 4) - self.assertEqual(api_batch.output_end_index, 4) - self.assertEqual(api_batch.params_grad_end_index, 4) - - def test_ApiBatch_increment_params(self): - api_name = "functional.conv2d" - start = 2 - api_batch = ApiBatch(api_name, start) - - api_batch.increment(Const.PARAMS) - - self.assertEqual(api_batch._state, Const.PARAMS) - self.assertEqual(api_batch.input_len, 1) - self.assertEqual(api_batch.params_end_index, 4) - self.assertEqual(api_batch.output_end_index, 4) - self.assertEqual(api_batch.params_grad_end_index, 4) - - def test_ApiBatch_increment_multiple_input(self): - api_name = "functional.conv2d" - start = 2 - api_batch = ApiBatch(api_name, start) - - api_batch.increment(Const.INPUT) - api_batch.increment(Const.INPUT) - - self.assertEqual(api_batch._state, Const.INPUT) - self.assertEqual(api_batch.input_len, 3) - self.assertEqual(api_batch.params_end_index, 5) - self.assertEqual(api_batch.output_end_index, 5) - self.assertEqual(api_batch.params_grad_end_index, 5) - - def test_ApiBatch_increment_multiple_output(self): - api_name = "functional.conv2d" - start = 2 - api_batch = ApiBatch(api_name, start) - - api_batch.increment(Const.OUTPUT) - api_batch.increment(Const.OUTPUT) - - self.assertEqual(api_batch._state, Const.OUTPUT) - self.assertEqual(api_batch.input_len, 1) - self.assertEqual(api_batch.params_end_index, 3) - self.assertEqual(api_batch.output_end_index, 5) - self.assertEqual(api_batch.params_grad_end_index, 5) - @patch("msprobe.core.compare.highlight.logger") def test_value_check(self, mock_logger): value = "=functional.conv2d" @@ -304,7 +271,9 @@ class TestUtilsMethods(unittest.TestCase): i = 1 result_df_columns = CompareConst.COMPARE_RESULT_HEADER - value_check(value, api_name, i, result_df_columns) + mode_config = ModeConfig() + highlight = HighLight(mode_config, '') + highlight.value_check(value, api_name, i, result_df_columns) mock_logger.error.assert_called_once_with( "Malicious value [=functional.conv2d] at api_name [=functional.conv2d], column [Bench Name], " @@ -314,49 +283,61 @@ class TestUtilsMethods(unittest.TestCase): def test_df_malicious_value_check(self): columns = CompareConst.COMPARE_RESULT_HEADER data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', ''] + 'torch.float32', 'torch.float32', [2, 2], [2, 2], True, True, + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, True, 'Yes', ''] ] result_df = pd.DataFrame(data, columns=columns) - df_malicious_value_check(result_df, columns) + mode_config = ModeConfig(dump_mode=Const.ALL) + highlight = HighLight(mode_config, '') + highlight.df_malicious_value_check(result_df) def test_compare_result_df_convert(self): value = float("nan") - result = compare_result_df_convert(value) + mode_config = ModeConfig() + highlight = HighLight(mode_config, '') + result = highlight.compare_result_df_convert(value) self.assertEqual(result, "nan\t") def test_highlight_rows_xlsx_red(self): data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'] + 'torch.float32', 'torch.float32', [2, 2], [2, 2], True, True, + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, True, 'Yes', '', '-1'] ] columns = CompareConst.COMPARE_RESULT_HEADER + ['Data_name'] result_df = pd.DataFrame(data, columns=columns) highlight_dict = {'red_rows': [0]} file_path = os.path.join(base_dir, 'result.xlsx') - highlight_rows_xlsx(result_df, highlight_dict, file_path) + + mode_config = ModeConfig(dump_mode=Const.ALL) + highlight = HighLight(mode_config, '') + highlight.highlight_rows_xlsx(result_df, highlight_dict, file_path) + generate_result_xlsx(base_dir) self.assertTrue(compare_excel_files_with_highlight(file_path, os.path.join(base_dir, 'target_result.xlsx'))) def test_highlight_rows_xlsx_yellow(self): data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'] + 'torch.float32', 'torch.float32', [2, 2], [2, 2], True, True, + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, True, 'Yes', '', '-1'] ] columns = CompareConst.COMPARE_RESULT_HEADER + ['Data_name'] result_df = pd.DataFrame(data, columns=columns) highlight_dict = {'yellow_rows': [0]} file_path = os.path.join(base_dir, 'result.xlsx') - highlight_rows_xlsx(result_df, highlight_dict, file_path) + + mode_config = ModeConfig(dump_mode=Const.ALL) + highlight = HighLight(mode_config, '') + highlight.highlight_rows_xlsx(result_df, highlight_dict, file_path) + generate_result_xlsx(base_dir) self.assertTrue(compare_excel_files_with_highlight(file_path, os.path.join(base_dir, 'target_result_yellow.xlsx'))) @patch("msprobe.core.compare.highlight.save_workbook") def test_highlight_rows_xlsx_malicious_columns(self, mock_save_book): data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'] + 'torch.float32', 'torch.float32', [2, 2], [2, 2], True, True, + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, True, 'Yes', '', '-1'] ] columns = CompareConst.COMPARE_RESULT_HEADER + ['=Data_name'] result_df = pd.DataFrame(data, columns=columns) @@ -366,7 +347,9 @@ class TestUtilsMethods(unittest.TestCase): temp_output_file = 'temp_output.txt' sys.stdout = open(temp_output_file, 'w') - highlight_rows_xlsx(result_df, highlight_dict, file_path) + mode_config = ModeConfig(dump_mode=Const.ALL) + highlight = HighLight(mode_config, '') + highlight.highlight_rows_xlsx(result_df, highlight_dict, file_path) with open(temp_output_file, 'r') as f: output = f.read() @@ -377,11 +360,11 @@ class TestUtilsMethods(unittest.TestCase): @patch("msprobe.core.compare.highlight.save_workbook") def test_highlight_rows_xlsx_malicious_type(self, mock_save_book): data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - '=torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'], + '=torch.float32', 'torch.float32', [2, 2], [2, 2], True, True, + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, True, 'Yes', '', '-1'], ['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - '=torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'] + '=torch.float32', 'torch.float32', [2, 2], [2, 2], True, True, + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, True, 'Yes', '', '-1'] ] columns = CompareConst.COMPARE_RESULT_HEADER + ['Data_name'] result_df = pd.DataFrame(data, columns=columns) @@ -391,7 +374,9 @@ class TestUtilsMethods(unittest.TestCase): temp_output_file = 'temp_output.txt' sys.stdout = open(temp_output_file, 'w') - highlight_rows_xlsx(result_df, highlight_dict, file_path) + mode_config = ModeConfig(dump_mode=Const.ALL) + highlight = HighLight(mode_config, '') + highlight.highlight_rows_xlsx(result_df, highlight_dict, file_path) with open(temp_output_file, 'r') as f: output = f.read() @@ -406,20 +391,13 @@ class TestUtilsMethods(unittest.TestCase): add_highlight_row_info(color_list, num, highlight_err_msg) self.assertEqual(color_list, [(1, ["a", "b"]), (5, ["c", "highlight"])]) - def test_add_highlight_row_info_new(self): - color_list = [(1, ["a", "b"]), (5, ["c"])] - num = 6 - highlight_err_msg = "highlight" - add_highlight_row_info(color_list, num, highlight_err_msg) - self.assertEqual(color_list, [(1, ["a", "b"]), (5, ["c"]), (6, ["highlight"])]) - def test_update_highlight_err_msg(self): data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'], + 'torch.float32', 'torch.float32', [2, 2], [2, 2], True, True, + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, True, 'Yes', '', '-1'], ['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'] + 'torch.float32', 'torch.float32', [2, 2], [2, 2], True, True, + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, True, 'Yes', '', '-1'] ] columns = CompareConst.COMPARE_RESULT_HEADER + ['Data_name'] result_df = pd.DataFrame(data, columns=columns) @@ -429,27 +407,32 @@ class TestUtilsMethods(unittest.TestCase): 'red_lines': [(0, ['a', 'b'])], 'yellow_lines': [(0, ['c']), (1, ['d'])] } - update_highlight_err_msg(result_df, highlight_dict) + + mode_config = ModeConfig(dump_mode=Const.ALL) + highlight = HighLight(mode_config, '') + highlight.update_highlight_err_msg(result_df, highlight_dict) t_data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', 'a\nb', '-1'], + 'torch.float32', 'torch.float32', [2, 2], [2, 2], True, True, + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, True, 'Yes', 'a\nb', '-1'], ['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', 'd', '-1'] + 'torch.float32', 'torch.float32', [2, 2], [2, 2], True, True, + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, True, 'Yes', 'd', '-1'] ] target_result_df = pd.DataFrame(t_data, columns=columns) self.assertTrue(result_df.equals(target_result_df)) def test_update_highlight_err_msg_md5(self): data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], 'abc', 'abc', 'pass'] + 'torch.float32', 'torch.float32', True, True, [2, 2], [2, 2], 'abc', 'abc', True, 'pass'] ] columns = CompareConst.MD5_COMPARE_RESULT_HEADER result_df = pd.DataFrame(data, columns=columns) highlight_dict = {} - result = update_highlight_err_msg(result_df, highlight_dict) + mode_config = ModeConfig(dump_mode=Const.MD5) + highlight = HighLight(mode_config, '') + result = highlight.update_highlight_err_msg(result_df, highlight_dict) self.assertEqual(result, None) @@ -466,5 +449,41 @@ class TestUtilsMethods(unittest.TestCase): 'red_lines': [(0, ['a', 'b'])], 'yellow_lines': [(0, ['c']), (1, ['d'])] } - result = update_highlight_err_msg(result_df, highlight_dict) + mode_config = ModeConfig() + highlight = HighLight(mode_config, '') + result = highlight.update_highlight_err_msg(result_df, highlight_dict) self.assertEqual(result, None) + + def test_find_error_rows(self): + api_batch = ApiBatch("Functional_batch_norm_0_forward", 0) + api_batch.input_len = 1 + api_batch.output_end_index = 4 + api_batch.params_end_index = 4 + summary_result = [summary_line_input, summary_line_1, summary_line_2, summary_line_3] + highlight_dict_test = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []} + mode_config = ModeConfig() + highlight = HighLight(mode_config, '') + highlight.find_error_rows(summary_result, api_batch, highlight_dict_test) + self.assertEqual(highlight_dict_test, + {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []}) + + def test_find_compare_result_error_rows(self): + result = [line_input, line_1, line_2, line_3] + result_df = pd.DataFrame(result) + result_df.columns = CompareConst.COMPARE_RESULT_HEADER + [Const.STATE, Const.API_ORIGIN_NAME] + highlight_dict_test = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []} + mode_config = ModeConfig(dump_mode=Const.ALL) + highlight = HighLight(mode_config, '') + highlight.find_compare_result_error_rows(result_df, highlight_dict_test) + self.assertEqual(highlight_dict_test, { + "red_rows": {1, 3}, + "yellow_rows": {2}, + "red_lines": [ + (1, ["maximum or minimum is nan, -inf, or inf"]), + (3, ["maximum absolute error exceeds 1e+10"]) + ], + "yellow_lines": [ + (2, ["The output's one thousandth err ratio decreases by more than 0.1 compared to the input/parameter's"]), + (3, ["The output's cosine decreases by more than 0.1 compared to the input/parameter's"]) + ] + }) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_multiprocessing_compute.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_multiprocessing_compute.py index 9c2dea835fea13af7902bf796d9ab06c9eb6a61b..736bc6207e5d69a48c71bd4f771538952a4b8f5c 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_multiprocessing_compute.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_multiprocessing_compute.py @@ -7,125 +7,245 @@ import unittest import pandas as pd -from msprobe.core.common.const import CompareConst, Const +from msprobe.core.common.const import Const, CompareConst from msprobe.core.common.utils import CompareException -from msprobe.core.compare.acc_compare import Comparator, ModeConfig -from msprobe.core.compare.multiprocessing_compute import ComparisonResult, _handle_multi_process, _save_cmp_result, \ - check_accuracy, read_dump_data -from test_acc_compare import generate_dump_json +from msprobe.core.compare.acc_compare import ModeConfig +from msprobe.core.compare.multiprocessing_compute import check_accuracy, CompareRealData, ComparisonResult +from msprobe.pytorch.compare.pt_compare import read_real_data +from test_acc_compare import generate_dump_json, generate_pt, generate_stack_json data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', + 'torch.float32', 'torch.float32', [2, 2], [2, 2], True, True, + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, - 'Yes', '', '-1']] + True, 'Yes', '', ['-1', '-1']]] o_data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], - 'unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported', + 'torch.float32', 'torch.float32', [2, 2], [2, 2], True, True, + 'unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported', 1, 1, 1, 1, 1, 1, 1, 1, - 'None', 'No bench data matched.', '-1']] + True, 'None', 'NPU does not have data file.', ['-1', '-1']]] columns = CompareConst.COMPARE_RESULT_HEADER + ['Data_name'] result_df = pd.DataFrame(data, columns=columns) o_result = pd.DataFrame(o_data, columns=columns) base_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'test_cmp_multiprocessing_compute') +base_dir3 = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'test_acc_compare_data3') +pt_dir = os.path.join(base_dir3, f'dump_data_dir') class TestUtilsMethods(unittest.TestCase): + def test_check_accuracy(self): + max_abs_err = '' + + cos_1 = CompareConst.SHAPE_UNMATCH + result_1 = check_accuracy(cos_1, max_abs_err) + self.assertEqual(result_1, CompareConst.ACCURACY_CHECK_UNMATCH) + + cos_2 = CompareConst.NONE + result_2 = check_accuracy(cos_2, max_abs_err) + self.assertEqual(result_2, CompareConst.NONE) + + cos_3 = 'N/A' + result_3 = check_accuracy(cos_3, max_abs_err) + self.assertEqual(result_3, CompareConst.ACCURACY_CHECK_NO) + + cos_4 = '' + result_4 = check_accuracy(cos_4, max_abs_err) + self.assertEqual(result_4, CompareConst.NONE) + + cos_5 = 0.95 + max_abs_err = 0.002 + result_5 = check_accuracy(cos_5, max_abs_err) + self.assertEqual(result_5, CompareConst.ACCURACY_CHECK_NO) + + cos_6 = 0.85 + max_abs_err = 2 + result_6 = check_accuracy(cos_6, max_abs_err) + self.assertEqual(result_6, CompareConst.ACCURACY_CHECK_NO) + + cos_7 = 0.95 + max_abs_err = 0.001 + result_7 = check_accuracy(cos_7, max_abs_err) + self.assertEqual(result_7, CompareConst.ACCURACY_CHECK_YES) + + +class TestCompareRealData(unittest.TestCase): + def setUp(self): - self.result_df = pd.DataFrame(columns=[ - CompareConst.COSINE, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR, - CompareConst.ERROR_MESSAGE, CompareConst.ACCURACY, - CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO + self.result_df = pd.DataFrame([['']*8]*2, columns=[ + CompareConst.COSINE, CompareConst.EUC_DIST, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR, + CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO, + CompareConst.ACCURACY, CompareConst.ERROR_MESSAGE ]) os.makedirs(base_dir, mode=0o750, exist_ok=True) + os.makedirs(base_dir3, mode=0o750, exist_ok=True) + os.makedirs(pt_dir, mode=0o750, exist_ok=True) self.lock = threading.Lock() def tearDown(self): if os.path.exists(base_dir): shutil.rmtree(base_dir) - - def test_handle_multi_process(self): - stack_mode = False - auto_analyze = True - fuzzy_match = False - dump_mode = Const.ALL - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - - func = Comparator(mode_config).compare_ops - generate_dump_json(base_dir) - input_parma = {'bench_json_path': os.path.join(base_dir, 'dump.json')} - lock = multiprocessing.Manager().RLock() - result = _handle_multi_process(func, input_parma, result_df, lock) - self.assertTrue(result.equals(o_result)) + if os.path.exists(pt_dir): + shutil.rmtree(pt_dir) + if os.path.exists(base_dir3): + shutil.rmtree(base_dir3) def test_read_dump_data(self): - result = read_dump_data(result_df) + file_reader = read_real_data + mode_config = ModeConfig(dump_mode=Const.ALL) + cross_frame = False + compare_real_data = CompareRealData(file_reader, mode_config, cross_frame) + + # normal + result = compare_real_data.read_dump_data(result_df) self.assertEqual(result, {'Functional.linear.0.forward.input.0': ['-1', '-1']}) + # index error with self.assertRaises(CompareException) as context: - result = read_dump_data(pd.DataFrame()) - self.assertEqual(context.exception.code, CompareException.INDEX_OUT_OF_BOUNDS_ERROR) + result = compare_real_data.read_dump_data(pd.DataFrame()) + self.assertEqual(context.exception.code, CompareException.INVALID_KEY_ERROR) def test_save_cmp_result_success(self): + file_reader = read_real_data + mode_config = ModeConfig(dump_mode=Const.ALL) + cross_frame = False + compare_real_data = CompareRealData(file_reader, mode_config, cross_frame) + comparison_result = ComparisonResult( cos_result=[0.99, 0.98], max_err_result=[0.01, 0.02], max_relative_err_result=[0.001, 0.002], - err_msgs=['', 'Error in comparison'], + euc_dist_result=[0.5, 0.49], one_thousand_err_ratio_result=[0.1, 0.2], - five_thousand_err_ratio_result=[0.05, 0.1] + five_thousand_err_ratio_result=[0.05, 0.1], + err_msgs=['', 'Error in comparison'] ) offset = 0 - updated_df = _save_cmp_result(offset, comparison_result, self.result_df, self.lock) + updated_df = compare_real_data._save_cmp_result(offset, comparison_result, self.result_df, self.lock) self.assertEqual(updated_df.loc[0, CompareConst.COSINE], 0.99) self.assertEqual(updated_df.loc[1, CompareConst.COSINE], 0.98) self.assertEqual(updated_df.loc[1, CompareConst.ERROR_MESSAGE], 'Error in comparison') def test_save_cmp_result_index_error(self): + file_reader = read_real_data + mode_config = ModeConfig(dump_mode=Const.ALL) + cross_frame = False + compare_real_data = CompareRealData(file_reader, mode_config, cross_frame) + comparison_result = ComparisonResult( cos_result=[0.99], max_err_result=[], max_relative_err_result=[0.001], - err_msgs=[''], + euc_dist_result=[0.5], one_thousand_err_ratio_result=[0.1], - five_thousand_err_ratio_result=[0.05] + five_thousand_err_ratio_result=[0.05], + err_msgs=[''] ) with self.assertRaises(CompareException) as context: - _save_cmp_result(0, comparison_result, self.result_df, self.lock) + compare_real_data._save_cmp_result(0, comparison_result, self.result_df, self.lock) self.assertEqual(context.exception.code, CompareException.INDEX_OUT_OF_BOUNDS_ERROR) - def test_check_accuracy(self): - max_abs_err = '' - - cos_1 = CompareConst.SHAPE_UNMATCH - result_1 = check_accuracy(cos_1, max_abs_err) - self.assertEqual(result_1, CompareConst.ACCURACY_CHECK_UNMATCH) - - cos_2 = CompareConst.NONE - result_2 = check_accuracy(cos_2, max_abs_err) - self.assertEqual(result_2, CompareConst.NONE) - - cos_3 = 'N/A' - result_3 = check_accuracy(cos_3, max_abs_err) - self.assertEqual(result_3, CompareConst.ACCURACY_CHECK_NO) + def test_compare_by_op_bench_normal(self): + npu_op_name = 'Functional.linear.0.forward.input.0' + bench_op_name = 'Functional.linear.0.forward.input.0' + + file_reader = read_real_data + mode_config = ModeConfig(dump_mode=Const.ALL) + cross_frame = False + compare_real_data = CompareRealData(file_reader, mode_config, cross_frame) + + pt_name = '-1' + op_name_mapping_dict = {'Functional.linear.0.forward.input.0': [pt_name, pt_name]} + input_param = {'npu_dump_data_dir': base_dir, 'bench_dump_data_dir': base_dir} + result = compare_real_data.compare_by_op(npu_op_name, bench_op_name, op_name_mapping_dict, input_param) + self.assertEqual(result, ['unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported', + 'unsupported', 'NPU does not have data file.']) + + pt_name = 'Functional.linear.0.forward.input.0.pt' + op_name_mapping_dict = {'Functional.linear.0.forward.input.0': [pt_name, pt_name]} + input_param = {'npu_dump_data_dir': base_dir, 'bench_dump_data_dir': base_dir} + result = compare_real_data.compare_by_op(npu_op_name, bench_op_name, op_name_mapping_dict, input_param) + self.assertEqual(result, ['unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported', + 'unsupported', "Dump file: ['Functional.linear.0.forward.input.0.pt', 'Functional.linear.0.forward.input.0.pt'] not found or read failed."]) + + generate_pt(base_dir) + result = compare_real_data.compare_by_op(npu_op_name, bench_op_name, op_name_mapping_dict, input_param) + self.assertEqual(result, [1.0, 0.0, 0.0, 0.0, 1.0, 1.0, '']) + + def test_compare_by_op_bench_no_npu_real_data(self): + npu_op_name = 'Functional.linear.0.forward.input.0' + bench_op_name = 'N/A' + op_name_mapping_dict = {'Functional.linear.0.forward.input.0': [-1, -1]} + input_param = {} + + file_reader = read_real_data + mode_config = ModeConfig(dump_mode=Const.ALL) + cross_frame = False + compare_real_data = CompareRealData(file_reader, mode_config, cross_frame) + + result = compare_real_data.compare_by_op(npu_op_name, bench_op_name, op_name_mapping_dict, input_param) + self.assertEqual(result, ['unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported', + 'unsupported', 'NPU does not have data file.']) + + def test_compare_ops(self): + generate_dump_json(base_dir3) + generate_stack_json(base_dir3) + generate_pt(pt_dir) + dump_path = os.path.join(base_dir3, 'dump.json') + stack_path = os.path.join(base_dir3, 'stack.json') + input_param = {'npu_json_path': dump_path, 'bench_json_path': dump_path, 'stack_json_path': stack_path, + 'is_print_compare_log': True, 'npu_dump_data_dir': pt_dir, 'bench_dump_data_dir': pt_dir} + dump_path_dict = {'Functional.linear.0.forward.input.0': ['Functional.linear.0.forward.input.0.pt', + 'Functional.linear.0.forward.input.0.pt']} + result_df = pd.DataFrame({ + 'NPU Name': ['Functional.linear.0.forward.input.0'], + 'Bench Name': ['Functional.linear.0.forward.input.0'], + 'Err_message': '' + }) + + file_reader = read_real_data + mode_config = ModeConfig(dump_mode=Const.ALL) + cross_frame = False + compare_real_data = CompareRealData(file_reader, mode_config, cross_frame) + + updated_df = compare_real_data.compare_ops(idx=0, dump_path_dict=dump_path_dict, result_df=result_df, + lock=self.lock, input_param=input_param) + + self.assertEqual(updated_df.loc[0, CompareConst.COSINE], 1.0) + self.assertEqual(updated_df.loc[0, CompareConst.MAX_ABS_ERR], 0) + + def test_do_multi_process(self): + data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', + 'torch.float32', 'torch.float32', [2, 2], [2, 2], True, True, + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, True, 'Yes', '', ['-1', '-1']]] + o_data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', + 'torch.float32', 'torch.float32', [2, 2], [2, 2], True, True, + 'unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported', + 1, 1, 1, 1, 1, 1, 1, 1, True, 'None', 'NPU does not have data file.', ['-1', '-1']]] + columns = CompareConst.COMPARE_RESULT_HEADER + ['Data_name'] + result_df = pd.DataFrame(data, columns=columns) + o_result = pd.DataFrame(o_data, columns=columns) + generate_dump_json(base_dir) + input_param = {'bench_json_path': os.path.join(base_dir, 'dump.json')} - cos_4 = '' - result_4 = check_accuracy(cos_4, max_abs_err) - self.assertEqual(result_4, CompareConst.NONE) + file_reader = read_real_data + mode_config = ModeConfig(dump_mode=Const.ALL) + cross_frame = False + compare_real_data = CompareRealData(file_reader, mode_config, cross_frame) - cos_5 = 0.95 - max_abs_err = 0.002 - result_5 = check_accuracy(cos_5, max_abs_err) - self.assertEqual(result_5, CompareConst.ACCURACY_CHECK_NO) + result = compare_real_data.do_multi_process(input_param, result_df) + self.assertTrue(result.equals(o_result)) - cos_6 = 0.85 - max_abs_err = 2 - result_6 = check_accuracy(cos_6, max_abs_err) - self.assertEqual(result_6, CompareConst.ACCURACY_CHECK_NO) + def test_handle_multi_process(self): + file_reader = read_real_data + mode_config = ModeConfig(dump_mode=Const.ALL) + cross_frame = False + compare_real_data = CompareRealData(file_reader, mode_config, cross_frame) - cos_7 = 0.95 - max_abs_err = 0.001 - result_7 = check_accuracy(cos_7, max_abs_err) - self.assertEqual(result_7, CompareConst.ACCURACY_CHECK_YES) + func = compare_real_data.compare_ops + generate_dump_json(base_dir) + input_param = {'bench_json_path': os.path.join(base_dir, 'dump.json')} + lock = multiprocessing.Manager().RLock() + result = compare_real_data._handle_multi_process(func, input_param, result_df, lock) + self.assertTrue(result.equals(o_result)) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_postprocess_pass.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_postprocess_pass.py index 9cb33eb277848fa96bdf5b7456867d8579359723..f3623da772d1cbde684aa53119639faa93e4f068 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_postprocess_pass.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_postprocess_pass.py @@ -14,9 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. """ +from dataclasses import dataclass from unittest import TestCase -from msprobe.core.compare.layer_mapping.postprocess_pass import extract_next_item_last_number -from msprobe.core.compare.layer_mapping.postprocess_pass import replace_next_item_index +from msprobe.core.compare.layer_mapping.postprocess_pass import extract_next_item_last_number, \ + replace_next_item_index, renumber_index_pass + + +@dataclass +class DataItem: + """Class for keeping track of an item in inventory""" + type_name: str + full_scope: str + layer_scope: str class TestPostProcessPass(TestCase): @@ -46,3 +55,12 @@ class TestPostProcessPass(TestCase): replace_result = replace_next_item_index(input_data, prefix, inf_value) self.assertEqual(replace_result, input_data) + def test_renumber_index_pass(self): + a = DataItem("ParallelTransformer", "fake_data.layers.10", "fake_data.layers") + b = DataItem("ParallelTransformer", "fake_data.layers.12", "fake_data.layers") + c = DataItem("FakeLayer", "fake_data.layers.10.a.b.c", "fake_data.layers.a.b") + data_items = [a, b, c] + renumber_index_pass(data_items, "ParallelTransformer") + self.assertEqual(a.full_scope, "fake_data.layers.0") + self.assertEqual(b.full_scope, "fake_data.layers.2") + self.assertEqual(c.full_scope, "fake_data.layers.0.a.b.c") diff --git a/debug/accuracy_tools/msprobe/test/core_ut/config_check/bench.sh b/debug/accuracy_tools/msprobe/test/core_ut/config_check/bench.sh new file mode 100644 index 0000000000000000000000000000000000000000..217676ef0f451b6b8f2d2cecb14545d9a7f8dd8b --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/config_check/bench.sh @@ -0,0 +1,25 @@ +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +CKPT_SAVE_DIR="your model save ckpt path" +DATA_PATH="your data path" +TOKENIZER_MODEL="your tokenizer path" +CKPT_LOAD_DIR="your model ckpt path" +TP=1 + +DISTRIBUTED_ARGS=" + --master_port $MASTER_PORT +" + +GPT_ARGS=" + --tensor-model-parallel-size ${TP} \ + --sequence-parallel \ + --tokenizer-model ${TOKENIZER_MODEL} \ +" + +torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \ + $GPT_ARGS \ + --distributed-backend nccl \ + --load $CKPT_LOAD_DIR \ + --save $CKPT_SAVE_DIR \ + | tee logs/train_llama2_7b.log \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/core_ut/config_check/cmp.sh b/debug/accuracy_tools/msprobe/test/core_ut/config_check/cmp.sh new file mode 100644 index 0000000000000000000000000000000000000000..8df9e6507975c7edbcfee105d838563171c720e4 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/config_check/cmp.sh @@ -0,0 +1,25 @@ +MASTER_PORT=6001 +NNODES=1 +NODE_RANK=0 +CKPT_SAVE_DIR="./aaa" +DATA_PATH="./aaa" +TOKENIZER_MODEL="./aaa" +CKPT_LOAD_DIR="./aaa" +TP=2 + +DISTRIBUTED_ARGS=" + --master_port $MASTER_PORT +" + +GPT_ARGS=" + --tensor-model-parallel-size ${TP} \ + --sequence-parallel \ + --tokenizer-model ${TOKENIZER_MODEL} \ +" + +torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \ + $GPT_ARGS \ + --distributed-backend nccl \ + --load $CKPT_LOAD_DIR \ + --save $CKPT_SAVE_DIR \ + | tee logs/train_llama2_7b.log \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/core_ut/config_check/test_ckpt_compare.py b/debug/accuracy_tools/msprobe/test/core_ut/config_check/test_ckpt_compare.py new file mode 100644 index 0000000000000000000000000000000000000000..f51aa76f3aed9c8ab2553808563303f5309eb449 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/config_check/test_ckpt_compare.py @@ -0,0 +1,90 @@ +import unittest +from unittest.mock import patch, mock_open +import numpy as np +from msprobe.core.config_check.ckpt_compare import metrics +from msprobe.core.config_check.ckpt_compare import megatron_loader + + + +class TestMetrics(unittest.TestCase): + + def test_in_different_shape(self): + a = np.zeros((2, 3)) + b = np.zeros((2, 3)) + c = np.zeros((3, 2)) + self.assertFalse(metrics.in_different_shape(a, b)) + self.assertTrue(metrics.in_different_shape(a, c)) + + def test_l2_distance(self): + a = np.array([1.0, 2.0, 3.0]) + b = np.array([1.0, 2.0, 3.0]) + c = np.array([4.0, 5.0, 6.0]) + self.assertAlmostEqual(metrics.l2_distance(a, b), 0.0) + self.assertAlmostEqual(metrics.l2_distance(a, c), np.linalg.norm(a - c)) + self.assertIsNone(metrics.l2_distance(None, b)) + self.assertIsNone(metrics.l2_distance(a, None)) + self.assertIsNone(metrics.l2_distance(a, np.zeros((2, 2)))) + + def test_cos_sim(self): + a = np.array([1.0, 0.0, 0.0], dtype=np.float32) + b = np.array([1.0, 0.0, 0.0], dtype=np.float32) + c = np.array([0.0, 1.0, 0.0], dtype=np.float32) + + self.assertAlmostEqual(metrics.cos_sim(a, b), 1.0, places=6) + self.assertAlmostEqual(metrics.cos_sim(a, c), 0.0, places=6) + self.assertIsNone(metrics.cos_sim(a, np.zeros((2, 2), dtype=np.float32))) + + def test_numel(self): + a = np.zeros((2, 3)) + b = np.zeros((2, 3)) + c = np.zeros((3, 2)) + self.assertEqual(metrics.numel(a, b), 6) + self.assertEqual(metrics.numel(a, c), 6) + d = np.zeros((2, 2)) + self.assertEqual(metrics.numel(a, d), (6, 4)) + + def test_shape(self): + a = np.zeros((2, 3)) + b = np.zeros((2, 3)) + c = np.zeros((3, 2)) + self.assertEqual(metrics.shape(a, b), [2, 3]) + self.assertEqual(metrics.shape(a, c), [[2, 3], [3, 2]]) + + +class TestMegatronLoader(unittest.TestCase): + + def test__parse_real_layer_idx(self): + name = 'layers.2.attn/1' # vpp_stage = 1 + result = megatron_loader._parse_real_layer_idx(name, num_layers_per_stage=4, pp_size=2, pp_rank=1) + self.assertEqual(result, 'layers.14.attn') + + def test__parse_real_expert_idx(self): + name = 'layers.0.experts.3.weight' + result = megatron_loader._parse_real_expert_idx(name, num_experts_per_rank=4, exp_rank=2) + self.assertIn('experts.11', result) # 3 + 2*4 = 11 + + # No expert pattern + name2 = 'layers.0.weight' + self.assertEqual(megatron_loader._parse_real_expert_idx(name2, 4, 2), name2) + + def test__consolidate_tp_weights(self): + arr1 = np.ones((2,2)) + arr2 = np.zeros((2,2)) + weights = { + 'linear_fc1.weight': [arr1, arr2], + 'linear_fc2.weight': [arr1, arr2], + 'linear_fc2.bias': [arr1, arr1] + } + result = megatron_loader._consolidate_tp_weights(weights) + self.assertTrue(np.allclose(result['linear_fc1.weight'], np.concatenate([arr1, arr2], axis=0))) + self.assertTrue(np.allclose(result['linear_fc2.weight'], np.concatenate([arr1, arr2], axis=1))) + self.assertTrue(np.allclose(result['linear_fc2.bias'], arr1)) + + def test__parse_num_layers_per_stage(self): + keys = {'layers.0.weight': None, 'layers.1.weight': None, 'layers.2.weight': None} + self.assertEqual(megatron_loader._parse_num_layers_per_stage(keys), 3) + + +if __name__ == '__main__': + unittest.main() + \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/core_ut/config_check/test_config_check.py b/debug/accuracy_tools/msprobe/test/core_ut/config_check/test_config_check.py new file mode 100644 index 0000000000000000000000000000000000000000..d2522cc1467aefc4391d92435a658390efbc7cb4 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/config_check/test_config_check.py @@ -0,0 +1,173 @@ +import os +import random +import shutil +import unittest +import torch +import json +import numpy as np +import torch.nn as nn +import mindspore as ms +import mindspore.nn as ms_nn +from mindspore import Tensor +from msprobe.core.config_check.config_checker import ConfigChecker +from msprobe.core.config_check.checkers.pip_checker import PipPackageChecker +from msprobe.core.config_check.checkers.random_checker import RandomChecker +from msprobe.core.config_check.checkers.dataset_checker import DatasetChecker +from msprobe.core.config_check.checkers.weights_checker import WeightsChecker +from msprobe.core.common.file_utils import read_xlsx +from msprobe.core.common.framework_adapter import FmkAdp + + +testdir = os.path.dirname(__file__) +config_checking_dir = os.path.dirname(testdir) +temp_dir = os.path.join(config_checking_dir, "temp") +os.makedirs(temp_dir, exist_ok=True) +ms.set_context(device_target="CPU") + + +def seed_all(seed=1234, mode=False): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.use_deterministic_algorithms(mode) + ms.set_seed(seed) + + +class MockPyTorchModule(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 5) + self.relu = nn.ReLU() + + def forward(self, x, y): + x1 = self.linear(x) + x2 = self.relu(x1) + return x2 + + +class MockMindSporeModule(ms_nn.Cell): + def __init__(self): + super().__init__() + self.linear = ms_nn.Dense(10, 5) + self.relu = ms_nn.ReLU() + + def construct(self, x): + x1 = self.linear(x) + x2 = self.relu(x1) + return x2 + + +def get_test_dataset(): + inputs = [torch.rand(10, 10) for _ in range(10)] + labels = [torch.randint(0, 5, (10,)) for _ in range(10)] + ms_inputs = [Tensor(input.numpy()) for input in inputs] + ms_labels = [Tensor(label.numpy()) for label in labels] + return zip(inputs, labels), zip(ms_inputs, ms_labels) + + +def get_test_model(use_pytorch=True): + if use_pytorch: + test_module = MockPyTorchModule() + nn.init.constant_(test_module.linear.weight, 1.0) + nn.init.constant_(test_module.linear.bias, 1.0) + return test_module + else: + test_module = MockMindSporeModule() + for param in test_module.get_parameters(): + param.set_data(ms.Tensor(np.ones(param.data.shape), dtype=param.data.dtype)) + return test_module + + +@unittest.mock.patch("msprobe.core.config_check.checkers.pip_checker.collect_pip_data") +@unittest.mock.patch("msprobe.core.config_check.checkers.env_args_checker.collect_env_data") +def train_test(seed, output_zip_path, shell_path, mock_env, mock_pip): + if seed == 1234: + mock_pip.return_value = "transformers=0.0.1" + mock_env.return_value = {"NCCL_DETERMINISTIC": True} + else: + mock_pip.return_value = "transformers=0.0.2" + mock_env.return_value = {"HCCL_DETERMINISTIC": False, "ASCEND_LAUNCH_BLOCKING": 1} + seed_all(seed) + + use_pytorch = seed == 1234 + test_dataset, ms_test_dataset = get_test_dataset() + test_module = get_test_model(use_pytorch) + + if use_pytorch: + loss_fun = nn.CrossEntropyLoss() + optimizer = torch.optim.SGD(test_module.parameters(), lr=1e-2) + ConfigChecker(test_module, shell_path, output_zip_path) + + for input_data, label in test_dataset: + output = test_module(input_data, y=input_data) + loss = loss_fun(output, label) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + else: + loss_fun = ms_nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') + optimizer = ms_nn.SGD(test_module.trainable_params(), learning_rate=1e-2) + train_network = ms_nn.TrainOneStepCell(ms_nn.WithLossCell(test_module, loss_fun), optimizer) + ConfigChecker(test_module, shell_path, output_zip_path, fmk="mindspore") + + for input_data, label in ms_test_dataset: + loss = train_network(input_data, label) + + + +class TestConfigChecker(unittest.TestCase): + def tearDown(self): + FmkAdp.set_fmk("pytorch") + shutil.rmtree(temp_dir) + + + def test_all(self): + train_test(1234, os.path.join(temp_dir, "config_check_pack1.zip"), [os.path.join(testdir, "cmp.sh")]) + + ConfigChecker.pre_forward_fun_list = [] + ConfigChecker.step = 0 + RandomChecker.write_once = False + ConfigChecker.apply_patches("pytorch") + ConfigChecker.apply_patches("mindspore") + + train_test(1233, os.path.join(temp_dir, "config_check_pack2.zip"), [os.path.join(testdir, "bench.sh")]) + + ConfigChecker.compare(os.path.join(temp_dir, "config_check_pack1.zip"), + os.path.join(temp_dir, "config_check_pack2.zip"), + os.path.join(temp_dir, "compare_output")) + + compare_output_dir = os.path.join(temp_dir, "compare_output") + + total_check_result = read_xlsx(os.path.join(compare_output_dir, ConfigChecker.result_filename)) + self.assertEqual(total_check_result.columns.tolist(), ConfigChecker.result_header) + target_total_check_result = [ + ['env', "error"], + ['pip', "error"], + ['dataset', "error"], + ['weights', "error"], + ['hyperparameters', "error"], + ['random', "error"] + ] + self.assertEqual(total_check_result.values.tolist(), target_total_check_result) + + pip_data_check_result = read_xlsx(os.path.join(compare_output_dir, ConfigChecker.result_filename), + sheet_name=PipPackageChecker.target_name_in_zip) + self.assertEqual(pip_data_check_result.columns.tolist(), PipPackageChecker.result_header) + self.assertEqual(pip_data_check_result.iloc[0].tolist(), ['transformers', '0.0.1', '0.0.2', 'error']) + + random_check_result = read_xlsx(os.path.join(compare_output_dir, ConfigChecker.result_filename), + sheet_name=RandomChecker.target_name_in_zip) + self.assertEqual(random_check_result.columns.tolist(), RandomChecker.result_header) + self.assertEqual(len(random_check_result), 7) + + dataset_check_result = read_xlsx(os.path.join(compare_output_dir, ConfigChecker.result_filename), + sheet_name=DatasetChecker.target_name_in_zip) + self.assertEqual(dataset_check_result.columns.tolist(), DatasetChecker.result_header) + self.assertEqual(len(dataset_check_result), 20) + + weight_check_result = read_xlsx(os.path.join(compare_output_dir, ConfigChecker.result_filename), + sheet_name=WeightsChecker.target_name_in_zip) + self.assertEqual(weight_check_result.columns.tolist(), WeightsChecker.result_header) + self.assertEqual(len(weight_check_result), 20) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/config_check/test_dataset_checker.py b/debug/accuracy_tools/msprobe/test/core_ut/config_check/test_dataset_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..27db8e04d0890d579fae4ec02a7260102f08979b --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/config_check/test_dataset_checker.py @@ -0,0 +1,76 @@ +import unittest +import torch +import pandas as pd +from unittest.mock import patch, MagicMock + +from msprobe.core.config_check.checkers.dataset_checker import compare_dataset, \ + compare_dataset_dicts, parse_args_and_kargs, process_obj + + +class TestTensorProcessing(unittest.TestCase): + + def test_process_obj_tensor(self): + tensor = torch.tensor([1.0, 2.0, 3.0]) + result = process_obj(tensor) + self.assertEqual(isinstance(result, dict), True) + self.assertEqual(set(result.keys()), {'max', 'min', 'mean', 'norm'}) + + def test_process_obj_list(self): + obj = [torch.tensor([1.0]), torch.tensor([2.0])] + result = process_obj(obj) + self.assertEqual(isinstance(result, dict), True) + self.assertEqual(set(result.keys()), {0, 1}) + + def test_process_obj_dict(self): + obj = {'a': torch.tensor([1.0]), 'b': torch.tensor([2.0])} + result = process_obj(obj) + self.assertEqual(isinstance(result, dict), True) + self.assertEqual(set(result.keys()), {'a', 'b'}) + + def test_process_obj_other(self): + obj = "test" + result = process_obj(obj) + self.assertEqual(result, "") + + def test_parse_args_and_kargs(self): + args = (torch.tensor([1.0]),) + kwargs = {'a': torch.tensor([2.0])} + result = parse_args_and_kargs(args, kwargs) + self.assertEqual(isinstance(result, dict), True) + self.assertEqual(set(result.keys()), {'args', 'kwargs'}) + + def test_compare_dataset_dicts_equal(self): + dict1 = {'a': {'max': 1.0, 'min': 0.0, 'mean': 0.5, 'norm': 0.7}} + dict2 = {'a': {'max': 1.0, 'min': 0.0, 'mean': 0.5, 'norm': 0.7}} + results = compare_dataset_dicts(dict1, dict2) + self.assertEqual(len(results), 1) + self.assertEqual(results[0]['equal'], True) + + def test_compare_dataset_dicts_not_equal(self): + dict1 = {'a': {'max': 1.0, 'min': 0.0, 'mean': 0.5, 'norm': 0.7}} + dict2 = {'a': {'max': 2.0, 'min': 0.0, 'mean': 0.5, 'norm': 0.7}} + results = compare_dataset_dicts(dict1, dict2) + self.assertEqual(len(results), 1) + self.assertEqual(results[0]['equal'], False) + + def test_compare_dataset_dicts_nested(self): + dict1 = {'a': {'b': {'max': 1.0, 'min': 0.0, 'mean': 0.5, 'norm': 0.7}}} + dict2 = {'a': {'b': {'max': 1.0, 'min': 0.0, 'mean': 0.5, 'norm': 0.7}}} + results = compare_dataset_dicts(dict1, dict2) + self.assertEqual(len(results), 1) + self.assertEqual(results[0]['tag'], 'a.b') + + @patch('os.listdir', side_effect=[["step1"], ["rank1"]]) + @patch('os.path.isdir', return_value=True) + @patch('os.path.isfile', return_value=True) + @patch('msprobe.core.config_check.checkers.dataset_checker.load_json') + def test_compare_dataset(self, mock_load_json, mock_isfile, mock_isdir, mock_listdir): + mock_load_json.return_value = {'a': {'max': 1.0, 'min': 0.0, 'mean': 0.5, 'norm': 0.7}} + bench_dir = 'bench' + cmp_dir = 'cmp' + result = compare_dataset(bench_dir, cmp_dir) + self.assertEqual(isinstance(result, pd.DataFrame), True) + + + + \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/core_ut/config_check/test_megatron_load.py b/debug/accuracy_tools/msprobe/test/core_ut/config_check/test_megatron_load.py new file mode 100644 index 0000000000000000000000000000000000000000..e0a56d432b0b86e4d701ce37015044cacfde35c8 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/config_check/test_megatron_load.py @@ -0,0 +1,104 @@ +import unittest +from unittest.mock import patch, MagicMock +from msprobe.core.config_check.ckpt_compare import megatron_loader +import numpy as np + + +class TestMegatronLoader(unittest.TestCase): + def setUp(self) -> None: + self.mock_fmk_adp = MagicMock() + self.mock_logger = MagicMock() + self.patcher1 = patch('msprobe.core.config_check.ckpt_compare.megatron_loader.FmkAdp', self.mock_fmk_adp) + self.patcher2 = patch('msprobe.core.common.log.logger', self.mock_logger) + self.patcher1.start() + self.patcher2.start() + + def tearDown(self) -> None: + self.patcher1.stop() + self.patcher2.stop() + + def test__get_parameter_given_nested_dict_when_recursive_then_yield_all(self): + weights = {'a': {'b': 1, 'c': {'d': 2}}, 'e': 3} + self.mock_fmk_adp.is_tensor.side_effect = lambda x: isinstance(x, int) + self.mock_fmk_adp.asnumpy.side_effect = lambda x: x + + result = list(megatron_loader._get_parameter(weights)) + expected = [ + ('a.b', 1), + ('a.c.d', 2), + ('e', 3) + ] + self.assertEqual(result, expected) + + def test__get_parameter_given_flat_dict_when_no_nesting_then_yield_all(self): + weights = {'a': 1, 'b': 2} + self.mock_fmk_adp.is_tensor.return_value = True + self.mock_fmk_adp.asnumpy.side_effect = lambda x: x + + result = list(megatron_loader._get_parameter(weights)) + self.assertEqual(result, [('a', 1), ('b', 2)]) + + def test__parse_real_layer_idx_given_no_layer_index_when_no_match_then_return_original(self): + param_name = 'embedding.weight/0' + result = megatron_loader._parse_real_layer_idx(param_name, 1, 1, 0) + self.assertEqual(result, 'embedding.weight') + + # _parse_real_expert_idx tests + def test__parse_real_expert_idx_given_valid_name_when_exp_parallel_then_calculate_index(self): + param_name = 'experts.0.mlp.dense_h_to_4h.weight' + result = megatron_loader._parse_real_expert_idx(param_name, num_experts_per_rank=2, exp_rank=1) + self.assertEqual(result, 'experts.2.mlp.dense_h_to_4h.weight') + + def test__parse_real_expert_idx_given_no_expert_index_when_no_match_then_return_original(self): + param_name = 'non_expert.weight' + result = megatron_loader._parse_real_expert_idx(param_name, 1, 0) + self.assertEqual(result, 'non_expert.weight') + + # _consolidate_tp_weights tests + def test__consolidate_tp_weights_given_column_parallel_then_concat_axis0(self): + weights = { + 'linear_qkv': [np.array([[1, 2], [3, 4]]), np.array([[5, 6], [7, 8]])] + } + result = megatron_loader._consolidate_tp_weights(weights) + expected = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) + np.testing.assert_array_equal(result['linear_qkv'], expected) + + def test__consolidate_tp_weights_given_row_parallel_then_concat_axis1(self): + weights = { + 'linear_proj.weight': [np.array([[1, 2], [3, 4]]), np.array([[5, 6], [7, 8]])] + } + result = megatron_loader._consolidate_tp_weights(weights) + expected = np.array([[1, 2, 5, 6], [3, 4, 7, 8]]) + np.testing.assert_array_equal(result['linear_proj.weight'], expected) + + def test__consolidate_tp_weights_given_other_params_then_use_first(self): + weights = { + 'embedding': [np.array([1, 2, 3]), np.array([1, 2, 3])] + } + result = megatron_loader._consolidate_tp_weights(weights) + np.testing.assert_array_equal(result['embedding'], np.array([1, 2, 3])) + + # _parse_num_layers_per_stage tests + def test__parse_num_layers_per_stage_given_valid_keys_then_calculate_max(self): + tp_partition = { + 'layers.0.linear.weight': [], + 'layers.1.linear.weight': [], + 'layers.5.linear.weight': [] + } + result = megatron_loader._parse_num_layers_per_stage(tp_partition) + self.assertEqual(result, 6) + + @patch('os.listdir', return_value=[]) + def test_parse_parallel_size_given_empty_dir_then_raise_error(self, mock_listdir): + with self.assertRaises(ValueError): + megatron_loader.parse_parallel_size('empty_dir') + + @patch('os.path.exists', return_value=False) + @patch('re.findall', return_value=[]) + def test_parse_iteration_given_invalid_path_then_raise_error(self, mock_find, mock_exists): + with self.assertRaises(ValueError): + megatron_loader.parse_iteration('invalid_path') + + +if __name__ == '__main__': + unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/core_ut/config_check/test_random_checker.py b/debug/accuracy_tools/msprobe/test/core_ut/config_check/test_random_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..e19ba0ba0aed12820664a5d79a8c8e155ae1ccf8 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/config_check/test_random_checker.py @@ -0,0 +1,72 @@ +import unittest + +from msprobe.core.config_check.checkers.random_checker import stack_match + +class TestStackMatch(unittest.TestCase): + def test_identical_stacks(self): + stack1 = [ + "File /project/utils/funcs.py, line 42, in calculate_sum, return a + b", + "File /project/main.py, line 15, in main, result = calculate_sum(1, 2)" + ] + stack2 = [ + "File /project/utils/funcs.py, line 42, in calculate_sum, return a + b", + "File /project/main.py, line 15, in main, result = calculate_sum(1, 2)" + ] + self.assertTrue(stack_match(stack1, stack2)) + + def test_different_paths_same_file(self): + stack1 = [ + "File /user1/project/utils/funcs.py, line 42, in calculate_sum, return a + b" + ] + stack2 = [ + "File /user2/another_project/utils/funcs.py, line 42, in calculate_sum, return a + b" + ] + # 文件名相同,函数名和代码行相同 + self.assertTrue(stack_match(stack1, stack2)) + + def test_different_filenames(self): + stack1 = [ + "File /project/utils/funcs.py, line 42, in calculate_sum, return a + b" + ] + stack2 = [ + "File /project/utils/other_funcs.py, line 42, in calculate_sum, return a + b" + ] + # 文件名不同 + self.assertFalse(stack_match(stack1, stack2)) + + def test_different_line_numbers(self): + stack1 = [ + "File /project/utils/funcs.py, line 42, in calculate_sum, return a + b" + ] + stack2 = [ + "File /project/utils/funcs.py, line 45, in calculate_sum, return a + b" + ] + self.assertTrue(stack_match(stack1, stack2)) + + def test_different_functions(self): + stack1 = [ + "File /project/utils/funcs.py, line 42, in calculate_sum, return a + b" + ] + stack2 = [ + "File /project/utils/funcs.py, line 42, in multiply, return a * b" + ] + self.assertFalse(stack_match(stack1, stack2)) + + def test_similar_code_different_variables(self): + stack1 = [ + "File /project/main.py, line 15, in main, result = calculate_sum(a, b)" + ] + stack2 = [ + "File /project/main.py, line 15, in main, result = calculate_sum(x, y)" + ] + # 代码行前缀和结构相似 + self.assertTrue(stack_match(stack1, stack2)) + + def test_different_code_structure(self): + stack1 = [ + "File /project/main.py, line 15, in main, result = calculate_sum(a, b)" + ] + stack2 = [ + "File /project/main.py, line 15, in main, print('Hello, world!')" + ] + self.assertFalse(stack_match(stack1, stack2)) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/config_check/test_weight_checker.py b/debug/accuracy_tools/msprobe/test/core_ut/config_check/test_weight_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..4920c034455920075c617f7db90a26fbe826b10c --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/config_check/test_weight_checker.py @@ -0,0 +1,79 @@ +import unittest +from unittest.mock import patch +import pandas as pd +import os +import torch + +from msprobe.core.config_check.checkers.weights_checker import collect_weights_data, compare_weight, compare_weight_file + + +class TestWeightComparison(unittest.TestCase): + @patch('msprobe.core.config_check.utils.utils.get_tensor_features') + @patch('torch.nn.Module.named_parameters') + def test_collect_weights_data(self, mock_named_parameters, mock_get_tensor_features): + mock_model = unittest.mock.create_autospec(torch.nn.Module) + mock_named_parameters.return_value = [('param1', object())] + mock_get_tensor_features.return_value = {'max': 1, 'min': 0, 'mean': 0.5, 'norm': 1} + result = collect_weights_data(mock_model) + self.assertEqual(isinstance(result, dict), True) + + @patch('msprobe.core.config_check.checkers.weights_checker.load_json') + def test_compare_weight_file(self, mock_load_json): + mock_load_json.side_effect = [ + {'weight1': {'max': 1, 'min': 0, 'mean': 0.5, 'norm': 1}}, + {'weight1': {'max': 1, 'min': 0, 'mean': 0.5, 'norm': 1}} + ] + result = compare_weight_file('bench.json', 'cmp.json') + self.assertEqual(isinstance(result, list), True) + + @patch('msprobe.core.config_check.checkers.weights_checker.os_walk_for_files') + @patch('msprobe.core.config_check.checkers.weights_checker.load_json') + @patch('os.path.exists') + def test_compare_weight(self, mock_exists, mock_load_json, mock_os_walk_for_files): + mock_os_walk_for_files.return_value = [ + {"root": "bench/step1/rank0", "file": "weights.json"} + ] + mock_load_json.return_value = {'weight1': {'max': 1, 'min': 0, 'mean': 0.5, 'norm': 1}} + mock_exists.return_value = True + result = compare_weight('bench', 'cmp') + self.assertEqual(isinstance(result, pd.DataFrame), True) + + @patch('msprobe.core.config_check.checkers.weights_checker.load_json') + def test_compare_weight_file_different_weights(self, mock_load_json): + mock_load_json.side_effect = [ + {'weight1': {'max': 1, 'min': 0, 'mean': 0.5, 'norm': 1}}, + {'weight1': {'max': 2, 'min': 1, 'mean': 1.5, 'norm': 2}} + ] + result = compare_weight_file('bench.json', 'cmp.json') + self.assertEqual(isinstance(result, list), True) + for res in result: + if res["weight_name"] == "weight1": + self.assertEqual(res["equal"], False) + + @patch('msprobe.core.config_check.checkers.weights_checker.os_walk_for_files') + @patch('msprobe.core.config_check.checkers.weights_checker.load_json') + @patch('os.path.exists') + def test_compare_weight_cmp_file_missing(self, mock_exists, mock_load_json, mock_os_walk_for_files): + mock_os_walk_for_files.return_value = [ + {"root": "bench/step1/rank0", "file": "weights.json"} + ] + mock_load_json.return_value = {'weight1': {'max': 1, 'min': 0, 'mean': 0.5, 'norm': 1}} + mock_exists.return_value = False + result = compare_weight('bench', 'cmp') + self.assertEqual(isinstance(result, pd.DataFrame), True) + self.assertEqual(len(result[result["equal"] == "only bench have"]), 1) + + @patch('msprobe.core.config_check.checkers.weights_checker.os_walk_for_files') + @patch('msprobe.core.config_check.checkers.weights_checker.load_json') + @patch('os.path.exists') + def test_compare_weight_multiple_files(self, mock_exists, mock_load_json, mock_os_walk_for_files): + mock_os_walk_for_files.return_value = [ + {"root": "bench/step1/rank0", "file": "weights1.json"}, + {"root": "bench/step1/rank0", "file": "weights2.json"} + ] + mock_load_json.return_value = {'weight1': {'max': 1, 'min': 0, 'mean': 0.5, 'norm': 1}} + mock_exists.return_value = True + result = compare_weight('bench', 'cmp') + self.assertEqual(isinstance(result, pd.DataFrame), True) + self.assertEqual(len(result), 2) + diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py index 8ff89437646ee203aaa4a3fac5bbfea1538e9409..f9b6bd4d8a2266e0f449239a7df87d5caf9d1b10 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py @@ -70,31 +70,23 @@ class TestBaseDataProcessor(unittest.TestCase): @patch('inspect.stack') def test_analyze_api_call_stack(self, mock_stack): mock_stack.return_value = [ - (None, 'file0.py', 0, 'function0', ['code line 0'], None), - (None, 'file1.py', 10, 'function1', ['code line 1'], None), - (None, 'file2.py', 20, 'function2', ['code line 2'], None), (None, 'file3.py', 30, 'function3', ['code line 3'], None), - (None, 'file4.py', 40, 'function4', ['code line 4'], None), - (None, 'file5.py', 50, 'function5', ['code line 5'], None), - (None, 'file6.py', 60, 'function6', ['code line 6'], None), - (None, 'file7.py', 70, 'function7', ['code line 7'], None), + (None, 'file1.py', 40, 'function1', ['code line 1'], None), + (None, 'file2.py', 50, 'function2', ['code line 2'], None), + (None, 'file3.py', 60, 'function3', ['code line 3'], None), + (None, 'file1.py', 70, 'function1', ['code line 1'], None), + (None, 'file1.py', 80, 'function1', ['code line 1'], None), + (None, 'file2.py', 90, 'function2', ['code line 2'], None), + (None, 'file3.py', 100, 'function3', ['code line 3'], None) ] result = BaseDataProcessor.analyze_api_call_stack('test_stack') - expected_output = { - 'test_stack': [ - 'File file5.py, line 50, in function5, \n code line 5', - 'File file6.py, line 60, in function6, \n code line 6', - 'File file7.py, line 70, in function7, \n code line 7', - ] - } - self.assertEqual(result, expected_output) + expected_output = ( + 'File file1.py, line 80, in function1, \n code line 1', + 'File file2.py, line 90, in function2, \n code line 2', + 'File file3.py, line 100, in function3, \n code line 3', + ) - def test_convert_numpy_to_builtin(self): - self.assertEqual(BaseDataProcessor._convert_numpy_to_builtin(np.int32(5)), (5, 'int32')) - self.assertEqual(BaseDataProcessor._convert_numpy_to_builtin(np.float64(3.14)), (3.14, 'float64')) - self.assertEqual(BaseDataProcessor._convert_numpy_to_builtin(np.bool_(True)), (True, 'bool_')) - self.assertEqual(BaseDataProcessor._convert_numpy_to_builtin(np.str_('test')), ('test', 'str_')) - self.assertEqual(BaseDataProcessor._convert_numpy_to_builtin(5), (5, '')) + self.assertEqual(result, expected_output) def test_analyze_builtin(self): result = self.processor._analyze_builtin(slice(1, 10, 2)) @@ -113,12 +105,37 @@ class TestBaseDataProcessor(unittest.TestCase): expected = {'type': 'int', 'value': 1} self.assertEqual(result, expected) + def test_analyze_numpy(self): + result = BaseDataProcessor._analyze_numpy(np.int32(5)) + expected = {"type": 'int32', "value": 5} + self.assertEqual(result, expected) + + result = BaseDataProcessor._analyze_numpy(np.float32(3.14)) + expected = {"type": 'float32', "value": 3.140000104904175} + self.assertEqual(result, expected) + + result = BaseDataProcessor._analyze_numpy(np.bool_(True)) + expected = {"type": 'bool_', "value": True} + self.assertEqual(result, expected) + + result = BaseDataProcessor._analyze_numpy(np.str_("abc")) + expected = {"type": 'str_', "value": "abc"} + self.assertEqual(result, expected) + + result = BaseDataProcessor._analyze_numpy(np.byte(1)) + expected = {"type": 'int8', "value": 1} + self.assertEqual(result, expected) + + result = BaseDataProcessor._analyze_numpy(np.complex128(1 + 2j)) + expected = {"type": 'complex128', "value": (1 + 2j)} + self.assertEqual(result, expected) + def test_get_special_types(self): self.assertIn(int, BaseDataProcessor.get_special_types()) - def test_analyze_numpy(self): + def test_analyze_ndarray(self): ndarray = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) - result = BaseDataProcessor._analyze_numpy(ndarray, 'numpy.ndarray') + result = BaseDataProcessor._analyze_ndarray(ndarray, 'numpy.ndarray') expected_result = { 'type': 'numpy.ndarray', 'dtype': 'int32', @@ -126,7 +143,20 @@ class TestBaseDataProcessor(unittest.TestCase): 'Max': 6, 'Min': 1, 'Mean': 3.5, - 'Norm':9.539392014169456 + 'Norm': 9.539392014169456 + } + self.assertEqual(result, expected_result) + + ndarray = np.array([], dtype=np.int32) + result = BaseDataProcessor._analyze_ndarray(ndarray, 'numpy.ndarray') + expected_result = { + 'type': 'numpy.ndarray', + 'dtype': 'int32', + 'shape': (0,), + 'Max': None, + 'Min': None, + 'Mean': None, + 'Norm': None } self.assertEqual(result, expected_result) @@ -134,6 +164,7 @@ class TestBaseDataProcessor(unittest.TestCase): transform = lambda x, _: x * 2 Test = namedtuple("Test", ['a']) myNamedTuple = Test(1) + @dataclass class MyDataClass: last_hidden_state: int = None @@ -145,7 +176,7 @@ class TestBaseDataProcessor(unittest.TestCase): hidden_states=(2, 3), attentions=(4, 5) ) - expected_dataclass_res = {'last_hidden_state': 2, 'hidden_states': [4, 6], 'attentions': [8,10]} + expected_dataclass_res = {'last_hidden_state': 2, 'hidden_states': [4, 6], 'attentions': [8, 10]} self.assertEqual(BaseDataProcessor.recursive_apply_transform(2, transform), 4) self.assertEqual(BaseDataProcessor.recursive_apply_transform(myData, transform), expected_dataclass_res) self.assertEqual(BaseDataProcessor.recursive_apply_transform(myNamedTuple, transform), {'a': 2}) @@ -280,9 +311,9 @@ class TestBaseDataProcessor(unittest.TestCase): self.assertEqual(dst_data_structure, excepted_result) def test_analyze_element_to_all_none(self): - element = {"key1": [12, 3, {"key2": 10, "key3":["12"]}]} + element = {"key1": [12, 3, {"key2": 10, "key3": ["12"]}]} result = self.processor.analyze_element_to_all_none(element) - excepted_result = {"key1": [None, None, {"key2": None, "key3":[None]}]} + excepted_result = {"key1": [None, None, {"key2": None, "key3": [None]}]} self.assertEqual(result, excepted_result) @patch.object(MindsporeDataProcessor, "is_hookable_element", return_value=True) @@ -327,4 +358,4 @@ class TestBaseDataProcessor(unittest.TestCase): nested_data_structure, ["grad_name_1", "layer1", "layer2"], "grad_data_info" ) self.assertIsNone(self.processor.save_name) - self.assertEqual(result, grad) \ No newline at end of file + self.assertEqual(result, grad) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_mindspore_processor.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_mindspore_processor.py index b593d34c5d86c7fb3b4a0e8a3ff548c55555e09d..99dc1d5eee2d2ca5f9a3176f1a674d7ee8a6d260 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_mindspore_processor.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_mindspore_processor.py @@ -19,9 +19,10 @@ from unittest.mock import patch, MagicMock import zlib import mindspore as ms -from mindspore import Tensor +from mindspore import Tensor, ops, mint import numpy as np +from msprobe.core.common.const import Const from msprobe.core.data_dump.data_processor.base import BaseDataProcessor from msprobe.core.data_dump.data_processor.mindspore_processor import ( MindsporeDataProcessor, @@ -32,6 +33,13 @@ from msprobe.core.data_dump.data_processor.mindspore_processor import ( from msprobe.mindspore.common.log import logger +def patch_norm(value): + return ops.norm(value) + + +setattr(mint, "norm", patch_norm) + + class TestMindsporeDataProcessor(unittest.TestCase): def setUp(self): self.config = MagicMock() @@ -69,11 +77,16 @@ class TestMindsporeDataProcessor(unittest.TestCase): def test_get_stat_info_float_async(self): self.config.async_dump = True tensor = ms.tensor([1.0, 2.0, 3.0]) - result = self.processor.get_stat_info(tensor).stack_tensor_stat[1] - self.assertEqual(result[0].item(), 3.0) - self.assertEqual(result[1].item(), 1.0) - self.assertEqual(result[2].item(), 2.0) - self.assertEqual(result[3].item(), ms.ops.norm(tensor).item()) + result = self.processor.get_stat_info(tensor) + result_max = result.max + result_min = result.min + result_mean = result.mean + result_norm = result.norm + + self.assertEqual(result_max.item(), 3.0) + self.assertEqual(result_min.item(), 1.0) + self.assertEqual(result_mean.item(), 2.0) + self.assertEqual(result_norm.item(), ms.ops.norm(tensor).item()) def test_get_stat_info_int(self): self.config.async_dump = False @@ -87,9 +100,13 @@ class TestMindsporeDataProcessor(unittest.TestCase): def test_get_stat_info_int_async(self): self.config.async_dump = True tensor = ms.tensor([1, 2, 3]) - result = self.processor.get_stat_info(tensor).stack_tensor_stat[1] - self.assertEqual(result[0].item(), 3.0) - self.assertEqual(result[1].item(), 1.0) + result = self.processor.get_stat_info(tensor) + + result_max = result.max + result_min = result.min + + self.assertEqual(result_max.item(), 3.0) + self.assertEqual(result_min.item(), 1.0) def test_get_stat_info_bool(self): self.config.async_dump = False @@ -103,9 +120,13 @@ class TestMindsporeDataProcessor(unittest.TestCase): def test_get_stat_info_bool_async(self): self.config.async_dump = True tensor = ms.Tensor([True, False, True]) - result = self.processor.get_stat_info(tensor).stack_tensor_stat[1] - self.assertEqual(result[0].item(), True) - self.assertEqual(result[1].item(), False) + result = self.processor.get_stat_info(tensor) + + result_max = result.max + result_min = result.min + + self.assertEqual(result_max.item(), True) + self.assertEqual(result_min.item(), False) @patch.object(MindsporeDataProcessor, 'get_md5_for_tensor') def test__analyze_tensor(self, get_md5_for_tensor): @@ -117,14 +138,13 @@ class TestMindsporeDataProcessor(unittest.TestCase): expected_result = { 'type': 'mindspore.Tensor', 'dtype': 'Int32', - 'shape': (3,), - 'Max': 3, - 'Min': 1, - 'Mean': 2, - 'Norm': ms.ops.norm(tensor).item(), - 'md5': 'test_md5', + 'shape': (3,) } result = self.processor._analyze_tensor(tensor, suffix) + # 删除不必要的字段 + result.pop('tensor_stat_index', None) + result.pop('md5_index', None) + self.assertEqual(result, expected_result) @@ -150,12 +170,9 @@ class TestTensorDataProcessor(unittest.TestCase): 'type': 'mindspore.Tensor', 'dtype': str(tensor.dtype), 'shape': tensor.shape, - 'Max': 3.0, - 'Min': 1.0, - 'Mean': 2.0, - 'Norm': ms.ops.norm(tensor).item(), 'data_name': 'test_api.input.suffix.npy' } + result.pop('tensor_stat_index', None) self.assertEqual(expected, result) @@ -164,6 +181,7 @@ class TestOverflowCheckDataProcessor(unittest.TestCase): class Config: def __init__(self): self.overflow_nums = 1 + self.data_processor = OverflowCheckDataProcessor(Config(), None) def test___init__(self): @@ -174,6 +192,7 @@ class TestOverflowCheckDataProcessor(unittest.TestCase): def test_analyze_forward(self): def func(_): self.data_processor.has_overflow = True + with patch.object(BaseDataProcessor, "analyze_forward", return_value={"min", 0}): with patch.object(OverflowCheckDataProcessor, "maybe_save_overflow_data"): api_info = self.data_processor.analyze_forward("name", "module", "module_input_output") @@ -187,6 +206,7 @@ class TestOverflowCheckDataProcessor(unittest.TestCase): def test_analyze_backward(self): def func(_): self.data_processor.has_overflow = True + with patch.object(BaseDataProcessor, "analyze_backward", return_value={"min", 0}): with patch.object(OverflowCheckDataProcessor, "maybe_save_overflow_data"): api_info = self.data_processor.analyze_backward("name", "module", "module_input_output") @@ -218,33 +238,62 @@ class TestOverflowCheckDataProcessor(unittest.TestCase): self.data_processor.overflow_nums = 3 self.assertFalse(self.data_processor.is_terminated) + # from unittest.mock import MagicMock + def test__analyze_maybe_overflow_tensor(self): + # Mock DataWriter 和相关方法 + self.data_processor.data_writer = MagicMock() + + tensor_json = {Const.TENSOR_STAT_INDEX: 1} # 修正:添加正确的 tensor_stat_index + + # 模拟返回值 + self.data_processor.data_writer.get_buffer_values_max.return_value = 10 + self.data_processor.data_writer.get_buffer_values_min.return_value = -10 + self.data_processor.has_overflow = False - tensor_json = {"Max": None, "Min": 0} + # 调用函数并检查没有溢出 self.data_processor._analyze_maybe_overflow_tensor(tensor_json) self.assertFalse(self.data_processor.has_overflow) - tensor_json.update({"Max": -np.inf}) + + self.data_processor.has_overflow = False + # max 值为 -np.inf,应该触发溢出 + self.data_processor.data_writer.get_buffer_values_max.return_value = -np.inf self.data_processor._analyze_maybe_overflow_tensor(tensor_json) self.assertTrue(self.data_processor.has_overflow) + self.data_processor.has_overflow = False - tensor_json.update({"Max": np.inf}) + # max 值为 np.inf,应该触发溢出 + self.data_processor.data_writer.get_buffer_values_max.return_value = np.inf self.data_processor._analyze_maybe_overflow_tensor(tensor_json) self.assertTrue(self.data_processor.has_overflow) + self.data_processor.has_overflow = False - tensor_json.update({"Max": np.nan}) + # max 值为 np.nan,应该触发溢出 + self.data_processor.data_writer.get_buffer_values_max.return_value = np.nan self.data_processor._analyze_maybe_overflow_tensor(tensor_json) self.assertTrue(self.data_processor.has_overflow) - tensor_json.update({"Max": 0}) + + self.data_processor.has_overflow = False + # max 值为 0,不会触发溢出 + self.data_processor.data_writer.get_buffer_values_max.return_value = 0 + self.data_processor._analyze_maybe_overflow_tensor(tensor_json) + self.assertFalse(self.data_processor.has_overflow) + self.data_processor.has_overflow = False - tensor_json.update({"Min": -np.inf}) + # min 值为 -np.inf,应该触发溢出 + self.data_processor.data_writer.get_buffer_values_min.return_value = -np.inf self.data_processor._analyze_maybe_overflow_tensor(tensor_json) self.assertTrue(self.data_processor.has_overflow) + self.data_processor.has_overflow = False - tensor_json.update({"Min": np.inf}) + # min 值为 np.inf,应该触发溢出 + self.data_processor.data_writer.get_buffer_values_min.return_value = np.inf self.data_processor._analyze_maybe_overflow_tensor(tensor_json) self.assertTrue(self.data_processor.has_overflow) + self.data_processor.has_overflow = False - tensor_json.update({"Min": np.nan}) + # min 值为 np.nan,应该触发溢出 + self.data_processor.data_writer.get_buffer_values_min.return_value = np.nan self.data_processor._analyze_maybe_overflow_tensor(tensor_json) self.assertTrue(self.data_processor.has_overflow) @@ -260,7 +309,7 @@ class TestOverflowCheckDataProcessor(unittest.TestCase): return_value=False): ret = self.data_processor._analyze_tensor("tensor", "suffix") self.assertEqual(self.data_processor.cached_tensors_and_file_paths, {"file_path": "tensor"}) - mock_warning.assert_not_called() + mock_warning.assert_called_with("tensor_stat_index does not exist in tensor_json.") mock_super.assert_called_with("tensor", "suffix") self.assertEqual(ret.get("Max"), None) self.assertEqual(ret.get("data_name"), "dump_data_name") @@ -268,7 +317,8 @@ class TestOverflowCheckDataProcessor(unittest.TestCase): with patch("msprobe.core.data_dump.data_processor.mindspore_processor.path_len_exceeds_limit", return_value=True): self.data_processor._analyze_tensor("tensor", "suffix") - mock_warning.assert_called_with("The file path file_path length exceeds limit.") + mock_warning.assert_called_with("tensor_stat_index does not exist in tensor_json.") + class TestKernelDumpDataProcessor(unittest.TestCase): def setUp(self): @@ -293,7 +343,8 @@ class TestKernelDumpDataProcessor(unittest.TestCase): def test_analyze_pre_forward_without_adump(self, mock_logger_warning): self.processor.enable_kernel_dump = True self.processor.analyze_forward_input("test_api_name", None, None) - mock_logger_warning.assert_called_with("The current msprobe package does not compile adump, and kernel dump cannot be used.") + mock_logger_warning.assert_called_with( + "The current msprobe package does not compile adump, and kernel dump cannot be used.") self.assertFalse(self.processor.enable_kernel_dump) @patch('msprobe.core.data_dump.data_processor.mindspore_processor.KernelDumpDataProcessor.stop_kernel_dump') @@ -319,7 +370,8 @@ class TestKernelDumpDataProcessor(unittest.TestCase): self.processor.enable_kernel_dump = True self.processor.analyze_backward_input("test_api_name", None, None) self.assertFalse(self.processor.enable_kernel_dump) - mock_logger_warning.assert_called_with("The current msprobe package does not compile adump, and kernel dump cannot be used.") + mock_logger_warning.assert_called_with( + "The current msprobe package does not compile adump, and kernel dump cannot be used.") @patch('msprobe.core.data_dump.data_processor.mindspore_processor.KernelDumpDataProcessor.stop_kernel_dump') @patch.object(logger, 'info') diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py index 34064e7cc2b9d0aa5c0c2e98806b8993137a589c..45099b78d9044427ea04d518db858ce657918b25 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py @@ -19,6 +19,7 @@ from msprobe.core.data_dump.data_processor.pytorch_processor import ( KernelDumpDataProcessor ) from torch import distributed as dist +from torch._subclasses import FakeTensorMode class TestPytorchDataProcessor(unittest.TestCase): @@ -59,9 +60,18 @@ class TestPytorchDataProcessor(unittest.TestCase): def test_get_stat_info_with_meta_tensor(self): mock_data = self.mock_tensor(is_meta=True) - result = PytorchDataProcessor.get_stat_info(mock_data) + result = self.processor.get_stat_info(mock_data) self.assertIsInstance(result, TensorStatInfo) + def test_get_stat_info_with_fake_tensor(self): + with FakeTensorMode() as fake_tensor_mode: + fake_tensor = fake_tensor_mode.from_tensor(torch.randn(1, 2, 3)) + result = self.processor.get_stat_info(fake_tensor) + self.assertIsNone(result.max) + self.assertIsNone(result.min) + self.assertIsNone(result.mean) + self.assertIsNone(result.norm) + def test_get_stat_info_float(self): tensor = torch.tensor([1.0, 2.0, 3.0]) result = self.processor.get_stat_info(tensor) @@ -70,30 +80,15 @@ class TestPytorchDataProcessor(unittest.TestCase): self.assertEqual(result.mean, 2.0) self.assertEqual(result.norm, torch.norm(tensor).item()) - def test_get_stat_info_float_async(self): - tensor = torch.tensor([1.0, 2.0, 3.0]) - result = self.processor.get_stat_info_async(tensor).stack_tensor_stat[1] - self.assertEqual(result[0].item(), 3.0) - self.assertEqual(result[1].item(), 1.0) - self.assertEqual(result[2].item(), 2.0) - self.assertEqual(result[3].item(), torch.norm(tensor).item()) - def test_get_stat_info_int(self): tensor = torch.tensor([1, 2, 3], dtype=torch.int32) result = self.processor.get_stat_info(tensor) + self.assertEqual(result.max, 3) self.assertEqual(result.min, 1) self.assertEqual(result.mean, 2) self.assertEqual(result.norm, torch.norm(tensor.float()).item()) - def test_get_stat_info_int_async(self): - tensor = torch.tensor([1, 2, 3]) - result = self.processor.get_stat_info_async(tensor).stack_tensor_stat[1] - self.assertEqual(result[0].item(), 3.0) - self.assertEqual(result[1].item(), 1.0) - self.assertEqual(result[2].item(), 2.0) - self.assertEqual(result[3].item(), torch.norm(tensor.float()).item()) - def test_get_stat_info_empty(self): tensor = torch.tensor([]) result = self.processor.get_stat_info(tensor) @@ -110,15 +105,9 @@ class TestPytorchDataProcessor(unittest.TestCase): self.assertIsNone(result.mean) self.assertIsNone(result.norm) - def test_get_stat_info_bool_async(self): - tensor = torch.tensor([True, False, True]) - result = self.processor.get_stat_info_async(tensor).stack_tensor_stat[1] - self.assertEqual(result[0].item(), True) - self.assertEqual(result[1].item(), False) - def test_get_stat_info_with_scalar_tensor(self): scalar_tensor = torch.tensor(42.0) - result = PytorchDataProcessor.get_stat_info(scalar_tensor) + result = self.processor.get_stat_info(scalar_tensor) self.assertIsInstance(result, TensorStatInfo) self.assertEqual(result.max, 42.0) self.assertEqual(result.min, 42.0) @@ -127,7 +116,7 @@ class TestPytorchDataProcessor(unittest.TestCase): def test_get_stat_info_with_complex_tensor(self): complex_tensor = torch.tensor([1 + 2j, 3 + 4j], dtype=torch.complex64) - result = PytorchDataProcessor.get_stat_info(complex_tensor) + result = self.processor.get_stat_info(complex_tensor) expected_max = np.abs(np.array([1 + 2j, 3 + 4j])).max().item() expected_min = np.abs(np.array([1 + 2j, 3 + 4j])).min().item() expected_mean = np.abs(np.array([1 + 2j, 3 + 4j])).mean().item() @@ -136,49 +125,6 @@ class TestPytorchDataProcessor(unittest.TestCase): self.assertAlmostEqual(result.min, expected_min, places=6) self.assertAlmostEqual(result.mean, expected_mean, places=6) - def test_handle_tensor_extremum_nan_inf_all_nan(self): - tensor = torch.tensor([float('nan'), float('nan')]) - result = self.processor.handle_tensor_extremum_nan_inf(tensor, 'max') - self.assertTrue(np.isnan(result)) - - def test_handle_tensor_extremum_nan_inf_all_inf(self): - tensor = torch.tensor([float('inf'), float('inf')]) - result = self.processor.handle_tensor_extremum_nan_inf(tensor, 'max') - self.assertTrue(np.isinf(result)) - - def test_handle_tensor_extremum_nan_inf_all_negative_inf(self): - tensor = torch.tensor([float('-inf'), float('-inf')]) - result = self.processor.handle_tensor_extremum_nan_inf(tensor, 'min') - self.assertTrue(np.isinf(result) and result < 0) - - def test_handle_tensor_extremum_nan_inf_mixed(self): - tensor = torch.tensor([1.0, float('nan'), 3.0, float('-inf'), 2.0]) - result_max = self.processor.handle_tensor_extremum_nan_inf(tensor, 'max') - result_min = self.processor.handle_tensor_extremum_nan_inf(tensor, 'min') - self.assertEqual(result_max, 3.0) - self.assertEqual(result_min, 1.0) - - def test_handle_tensor_extremum_nan_inf_mixed_with_inf(self): - tensor = torch.tensor([1.0, float('nan'), 3.0, float('inf'), 2.0]) - result_max = self.processor.handle_tensor_extremum_nan_inf(tensor, 'max') - result_min = self.processor.handle_tensor_extremum_nan_inf(tensor, 'min') - self.assertEqual(result_max, 3.0) - self.assertEqual(result_min, 1.0) - - def test_handle_tensor_extremum_nan_inf_no_inf_nan(self): - tensor = torch.tensor([1.0, 2.0, 3.0]) - result_max = self.processor.handle_tensor_extremum_nan_inf(tensor, 'max') - result_min = self.processor.handle_tensor_extremum_nan_inf(tensor, 'min') - self.assertEqual(result_max, 3.0) - self.assertEqual(result_min, 1.0) - - def test_handle_tensor_extremum_nan_inf_all_inf_nan(self): - tensor = torch.tensor([float('nan'), float('inf'), float('-inf')]) - result_max = self.processor.handle_tensor_extremum_nan_inf(tensor, 'max') - result_min = self.processor.handle_tensor_extremum_nan_inf(tensor, 'min') - self.assertTrue(np.isinf(result_max)) - self.assertTrue(np.isinf(result_min)) - def test_analyze_builtin(self): result = self.processor._analyze_builtin(slice(1, torch.tensor(10, dtype=torch.int32), np.int64(2))) expected = {'type': 'slice', 'value': [1, 10, 2]} @@ -196,7 +142,7 @@ class TestPytorchDataProcessor(unittest.TestCase): dist.init_process_group(backend='gloo', world_size=1, rank=0) process_group_element = dist.group.WORLD result = self.processor.process_group_hash(process_group_element) - expected = hashlib.md5('[0]'.encode('utf-8')).hexdigest() + expected = f"{zlib.crc32(str([0]).encode('utf-8')):08x}" self.assertEqual(result, expected) def test_analyze_torch_size(self): @@ -222,7 +168,7 @@ class TestPytorchDataProcessor(unittest.TestCase): expected = { 'type': 'torch.ProcessGroup', 'group_ranks': [0], - 'group_id': hashlib.md5('[0]'.encode('utf-8')).hexdigest() + 'group_id': f"{zlib.crc32(str([0]).encode('utf-8')):08x}" } self.assertEqual(result, expected) @@ -237,6 +183,7 @@ class TestPytorchDataProcessor(unittest.TestCase): class TestReduceOp: def __str__(self): raise Exception("failed to convert str type") + arg = TestReduceOp() self.processor._analyze_reduce_op(arg) mock_logger_warning.assert_called_with( @@ -268,11 +215,35 @@ class TestPytorchDataProcessor(unittest.TestCase): self.assertEqual(result, self.processor._analyze_process_group(process_group_element)) def test_analyze_single_element_numpy_conversion(self): - numpy_element = np.int64(1) - converted_numpy, numpy_type = self.processor._convert_numpy_to_builtin(numpy_element) + numpy_element = np.int32(5) result = self.processor.analyze_single_element(numpy_element, []) - expected_result = {"type": numpy_type, "value": converted_numpy} - self.assertEqual(result, expected_result) + expected = {"type": 'int32', "value": 5} + self.assertEqual(result, expected) + + numpy_element = np.float32(3.14) + result = self.processor.analyze_single_element(numpy_element, []) + expected = {"type": 'float32', "value": 3.140000104904175} + self.assertEqual(result, expected) + + numpy_element = np.bool_(True) + result = self.processor.analyze_single_element(numpy_element, []) + expected = {"type": 'bool_', "value": True} + self.assertEqual(result, expected) + + numpy_element = np.str_("abc") + result = self.processor.analyze_single_element(numpy_element, []) + expected = {"type": 'str_', "value": "abc"} + self.assertEqual(result, expected) + + numpy_element = np.byte(1) + result = self.processor.analyze_single_element(numpy_element, []) + expected = {"type": 'int8', "value": 1} + self.assertEqual(result, expected) + + numpy_element = np.complex128(1 + 2j) + result = self.processor.analyze_single_element(numpy_element, []) + expected = {"type": 'complex128', "value": (1 + 2j)} + self.assertEqual(result, expected) def test_analyze_single_element_tensor(self): tensor_element = torch.tensor([1, 2, 3]) @@ -302,28 +273,20 @@ class TestPytorchDataProcessor(unittest.TestCase): 'type': 'torch.Tensor', 'dtype': str(tensor.dtype), 'shape': tensor.shape, - 'Max': 3.0, - 'Min': 1.0, - 'Mean': 2.0, - 'Norm': torch.norm(tensor).item(), - 'requires_grad': tensor.requires_grad, - 'md5': 'mocked_md5' + 'requires_grad': tensor.requires_grad } + result.pop('tensor_stat_index', None) + result.pop('md5_index', None) self.assertDictEqual(expected, result) def test_analyze_tensor_with_empty_tensor(self): tensor = torch.tensor([]) result = self.processor._analyze_tensor(tensor, 'suffix') - self.assertEqual(result['Max'], None) - self.assertEqual(result['Min'], None) - self.assertEqual(result['Mean'], None) - self.assertEqual(result['Norm'], None) - def test_analyze_tensor_with_inf_and_nan(self): - tensor = torch.tensor([1.0, float('inf'), float('nan'), -float('inf')]) - result = self.processor._analyze_tensor(tensor, 'suffix') - self.assertEqual(result['Max_except_inf_nan'], 1.0) - self.assertEqual(result['Min_except_inf_nan'], 1.0) + self.assertEqual(result['type'], "torch.Tensor") + self.assertEqual(result['dtype'], 'torch.float32') + self.assertEqual(result['shape'], torch.Size([0])) + self.assertEqual(result['requires_grad'], False) class TestTensorDataProcessor(unittest.TestCase): @@ -348,13 +311,10 @@ class TestTensorDataProcessor(unittest.TestCase): 'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': tensor.shape, - 'Max': 3.0, - 'Min': 1.0, - 'Mean': 2.0, - 'Norm': torch.norm(tensor).item(), 'requires_grad': False, 'data_name': 'test_api.input.suffix.pt' } + result.pop('tensor_stat_index', None) self.assertEqual(expected, result) @@ -372,6 +332,9 @@ class TestOverflowCheckDataProcessor(unittest.TestCase): sys.modules['torch_npu'] = Mock() sys.modules['torch_npu.npu'] = Mock() sys.modules['torch_npu.npu.utils'] = Mock() + self.tensor_json = { + 'tensor_stat_index': 123 # 默认情况下 tensor_stat_index 存在 + } def test_is_terminated(self): self.processor.overflow_nums = -1 @@ -386,7 +349,7 @@ class TestOverflowCheckDataProcessor(unittest.TestCase): def test_analyze_forward_input(self): with patch.object(BaseDataProcessor, "analyze_forward_input", return_value={"name": 1}): - api_info = self.processor.analyze_forward_input("name", "module","module_input_output") + api_info = self.processor.analyze_forward_input("name", "module", "module_input_output") self.assertEqual(self.processor.cached_api_info, {"name": 1}) self.assertIsNone(api_info) @@ -448,19 +411,43 @@ class TestOverflowCheckDataProcessor(unittest.TestCase): self.processor._is_support_inf_nan() self.assertTrue(self.processor.support_inf_nan) - def test_analyze_maybe_overflow_tensor(self): - tensor_json = {'Max': None, 'Min': None} - self.processor._analyze_maybe_overflow_tensor(tensor_json) + def test_max_tensor_or_min_tensor_is_none(self): + # 让 get_buffer_values_max 和 get_buffer_values_min 返回 None + self.processor.data_writer.get_buffer_values_max.return_value = None + self.processor.data_writer.get_buffer_values_min.return_value = None + + # 在该情况下应该直接返回,不做任何改变 + self.processor._analyze_maybe_overflow_tensor(self.tensor_json) + + # 确保 has_overflow 没有被设置 self.assertFalse(self.processor.has_overflow) - tensor_json = {'Max': float('inf'), 'Min': 1.0} - self.processor._analyze_maybe_overflow_tensor(tensor_json) + def test_tensor_is_inf_or_nan(self): + # 模拟 max_tensor 为 Inf + self.processor.data_writer.get_buffer_values_max.return_value = torch.tensor(float('inf')) + self.processor.data_writer.get_buffer_values_min.return_value = torch.tensor(1.0) + + # 测试应该设置 has_overflow 为 True + self.processor._analyze_maybe_overflow_tensor(self.tensor_json) self.assertTrue(self.processor.has_overflow) - tensor_json = {'Max': 1.0, 'Min': float('inf')} - self.processor._analyze_maybe_overflow_tensor(tensor_json) + # 模拟 min_tensor 为 NaN + self.processor.data_writer.get_buffer_values_max.return_value = torch.tensor(1.0) + self.processor.data_writer.get_buffer_values_min.return_value = torch.tensor(float('nan')) + + # 测试应该设置 has_overflow 为 True + self.processor._analyze_maybe_overflow_tensor(self.tensor_json) self.assertTrue(self.processor.has_overflow) + def test_normal_tensor(self): + # 模拟正常的 max_tensor 和 min_tensor + self.processor.data_writer.get_buffer_values_max.return_value = torch.tensor(1.0) + self.processor.data_writer.get_buffer_values_min.return_value = torch.tensor(-1.0) + + # 在正常情况下不应该改变 has_overflow + self.processor._analyze_maybe_overflow_tensor(self.tensor_json) + self.assertFalse(self.processor.has_overflow) + @patch('msprobe.core.common.file_utils.path_len_exceeds_limit', return_value=False) @patch.object(BaseDataProcessor, 'get_save_file_path', return_value=['test_api_name', 'test_api_name.0.forward.input.pt']) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_api_registry.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_api_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..1758755b54f89a43cb9a739bf512833b09a72781 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_api_registry.py @@ -0,0 +1,73 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from unittest import TestCase +from unittest.mock import patch + +import torch + +from msprobe.core.common.const import Const +from msprobe.core.data_dump.api_registry import _get_attr, ApiWrapper + + +class TestFunctions(TestCase): + def test__get_attr(self): + module = torch + + attr_name = 'linalg.norm' + target_value = torch.linalg.norm + actual_value = _get_attr(module, attr_name) + self.assertEqual(target_value, actual_value) + + attr_name = 'norm' + target_value = torch.norm + actual_value = _get_attr(module, attr_name) + self.assertEqual(target_value, actual_value) + + +class TestApiWrapper(TestCase): + api_types = { + Const.PT_FRAMEWORK: { + Const.PT_API_TYPE_TORCH: ((torch,), torch), + } + } + supported_api_list_path = (Const.SUPPORT_API_FILE_NAME,) + yaml_value = {'torch': ['linalg.norm', 'norm']} + api_names = {Const.PT_FRAMEWORK: {'torch': {'linalg.norm', 'norm'}}} + + def test___init__(self): + with patch('msprobe.core.data_dump.api_registry.load_yaml', return_value=self.yaml_value): + api_wrapper = ApiWrapper(self.api_types, self.supported_api_list_path) + self.assertEqual(api_wrapper.api_types, self.api_types) + self.assertEqual(api_wrapper.api_list_paths, self.supported_api_list_path) + self.assertEqual(api_wrapper.api_names, self.api_names) + self.assertEqual(api_wrapper.wrapped_api_functions, {}) + + api_wrapper = ApiWrapper(self.api_types, Const.SUPPORT_API_FILE_NAME) + self.assertEqual(api_wrapper.api_list_paths, list(self.supported_api_list_path)) + + with self.assertRaises(Exception) as context: + api_wrapper = ApiWrapper(self.api_types, (Const.SUPPORT_API_FILE_NAME, Const.SUPPORT_API_FILE_NAME)) + self.assertEqual(str(context.exception), + "The number of api_list_paths must be equal to the number of frameworks in 'api_types', " + "when api_list_paths is a list or tuple.") + + def test__get_api_names(self): + target_value = self.api_names + with patch('msprobe.core.data_dump.api_registry.load_yaml', return_value=self.yaml_value): + api_wrapper = ApiWrapper(self.api_types, self.supported_api_list_path) + actual_value = api_wrapper._get_api_names() + self.assertEqual(target_value, actual_value) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py index b9d2e7abef7244fc12dc71e3113c26af52529ce9..67ea481a9b6c983c29f80913820deb07e0b2219c 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py @@ -1,8 +1,7 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -13,7 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" + +import os import unittest from unittest.mock import patch, mock_open, MagicMock @@ -22,9 +22,6 @@ from msprobe.core.common.utils import Const from msprobe.core.data_dump.data_collector import DataCollector from msprobe.pytorch.debugger.debugger_config import DebuggerConfig from msprobe.pytorch.pt_config import parse_json_config -from msprobe.core.data_dump.json_writer import DataWriter -from msprobe.core.data_dump.data_processor.base import BaseDataProcessor -from msprobe.core.data_dump.data_processor.pytorch_processor import StatisticsDataProcessor class TestDataCollector(unittest.TestCase): @@ -38,6 +35,143 @@ class TestDataCollector(unittest.TestCase): config = DebuggerConfig(common_config, task_config, Const.STATISTICS, "./ut_dump", "L1") self.data_collector = DataCollector(config) + def test_dump_data_dir(self): + self.assertEqual(self.data_collector.dump_data_dir, None) + + self.data_collector.data_writer.dump_tensor_data_dir = "./test_dump" + self.assertEqual(self.data_collector.dump_data_dir, "./test_dump") + + def test_dump_file_path(self): + self.assertEqual(self.data_collector.dump_file_path, None) + + self.data_collector.data_writer.dump_file_path = "./test_dump/dump.json" + self.assertEqual(self.data_collector.dump_file_path, "./test_dump/dump.json") + + def test_scope_none_and_pid_match(self): + mock_name = "test_module" + current_pid = os.getpid() + result = self.data_collector.check_scope_and_pid(None, mock_name, current_pid) + self.assertTrue(result) + + def test_scope_valid_and_pid_match(self): + mock_scope = MagicMock() + mock_scope.check.return_value = True + mock_name = "valid_module" + current_pid = os.getpid() + result = self.data_collector.check_scope_and_pid(mock_scope, mock_name, current_pid) + self.assertTrue(result) + mock_scope.check.assert_called_once_with(mock_name) + + def test_scope_invalid_and_pid_match(self): + mock_scope = MagicMock() + mock_scope.check.return_value = False + mock_name = "invalid_module" + current_pid = os.getpid() + result = self.data_collector.check_scope_and_pid(mock_scope, mock_name, current_pid) + self.assertFalse(result) + + def test_scope_valid_but_pid_mismatch(self): + mock_scope = MagicMock() + mock_scope.check.return_value = True + mock_name = "valid_module" + fake_pid = os.getpid() + 1 + result = self.data_collector.check_scope_and_pid(mock_scope, mock_name, fake_pid) + self.assertFalse(result) + + def test_scope_none_but_pid_mismatch(self): + mock_name = "test_module" + fake_pid = os.getpid() + 1 + result = self.data_collector.check_scope_and_pid(None, mock_name, fake_pid) + self.assertFalse(result) + + def test_normal_case(self): + data_info = {"key1": {"other_field": "value"}} + self.data_collector.set_is_recomputable(data_info, True) + self.assertTrue(data_info["key1"]["is_recompute"]) + + self.data_collector.set_is_recomputable(data_info, False) + self.assertFalse(data_info["key1"]["is_recompute"]) + + def test_empty_data_info(self): + data_info = {} + original_data = data_info.copy() + self.data_collector.set_is_recomputable(data_info, True) + self.assertEqual(data_info, original_data) + + def test_data_info_length_not_one(self): + data_info = {"key1": {}, "key2": {}} + original_data = data_info.copy() + self.data_collector.set_is_recomputable(data_info, True) + self.assertEqual(data_info, original_data) + + def test_is_recompute_none(self): + data_info = {"key1": {}} + original_data = data_info.copy() + self.data_collector.set_is_recomputable(data_info, None) + self.assertEqual(data_info, original_data) + + def test_nested_structure(self): + data_info = {"layer1": {"sub_layer": {"value": 1}}} + self.data_collector.set_is_recomputable(data_info, True) + self.assertTrue(data_info["layer1"]["is_recompute"]) + self.assertEqual(data_info["layer1"]["sub_layer"]["value"], 1) + + def test_reset_status(self): + self.data_collector.optimizer_status = "test_optimizer_status" + self.data_collector.reset_status() + + self.assertEqual(self.data_collector.optimizer_status, "") + self.assertEqual( + self.data_collector.optimizer_status_first_start, + {Const.OPTIMIZER: True, Const.CLIP_GRAD: True} + ) + self.assertEqual(self.data_collector.backward_module_names, {}) + + def test_update_api_or_module_name(self): + self.assertEqual(self.data_collector.data_processor.current_api_or_module_name, None) + + self.data_collector.update_api_or_module_name("test_api_name") + self.assertEqual(self.data_collector.data_processor.current_api_or_module_name, "test_api_name") + + def test_write_json(self): + self.data_collector.data_writer = MagicMock() + + self.data_collector.write_json() + self.data_collector.data_writer.write_json.assert_called_once() + + def test_write_json_at_exit_with_async_dump_tensor(self): + self.data_collector.data_processor = MagicMock() + self.data_collector.data_writer = MagicMock() + self.data_collector.config.async_dump = True + self.data_collector.config.task = "tensor" + + self.data_collector.write_json_at_exit() + + self.data_collector.data_processor.dump_async_data.assert_called_once() + self.data_collector.data_writer.write_json.assert_called_once() + + def test_write_json_at_exit_with_no_async_dump(self): + self.data_collector.data_processor = MagicMock() + self.data_collector.data_writer = MagicMock() + self.data_collector.config.async_dump = False + self.data_collector.config.task = "tensor" + + self.data_collector.write_json_at_exit() + + self.data_collector.data_processor.dump_async_data.assert_not_called() + self.data_collector.data_writer.write_json.assert_called_once() + + def test_write_json_at_exit_with_statistics(self): + self.data_collector.data_processor = MagicMock() + self.data_collector.data_writer = MagicMock() + self.data_collector.config.async_dump = True + self.data_collector.config.task = "statistics" + + self.data_collector.write_json_at_exit() + + self.data_collector.data_processor.dump_async_data.assert_not_called() + self.data_collector.data_writer.write_json.assert_called_once() + def test_update_data(self): self.data_collector.config.task = Const.OVERFLOW_CHECK self.data_collector.data_processor.has_overflow = True @@ -59,6 +193,67 @@ class TestDataCollector(unittest.TestCase): mock_warning.assert_not_called() mock_debug.assert_called_once_with("msprobe is collecting data on Tensor.add.") + def test_call_stack_collect(self): + self.data_collector.data_processor = MagicMock() + self.data_collector.data_writer = MagicMock() + + test_name = "test_api" + mock_stack = ["func1", "func2", "func3"] + self.data_collector.data_processor.analyze_api_call_stack.return_value = mock_stack + + self.data_collector.call_stack_collect(test_name) + + self.data_collector.data_processor.analyze_api_call_stack.assert_called_once_with(test_name) + self.data_collector.data_writer.update_stack.assert_called_once_with(test_name, mock_stack) + + def test_update_construct_without_construct(self): + self.data_collector.data_writer = MagicMock() + + self.data_collector.config.level = "L1" + self.data_collector.update_construct("test") + self.data_collector.data_writer.update_construct.assert_not_called() + + def test_update_construct_with_first_start(self): + self.data_collector.module_processor = MagicMock() + self.data_collector.data_writer = MagicMock() + self.data_collector.config.level = "L0" + self.data_collector.optimizer_status = "optimizer" + self.data_collector.optimizer_status_first_start = {"optimizer": True} + + self.data_collector.update_construct("test_name") + calls = [ + unittest.mock.call({"optimizer": None}), + unittest.mock.call({"test_name": "optimizer"}), + unittest.mock.call(self.data_collector.module_processor.module_node) + ] + self.data_collector.data_writer.update_construct.assert_has_calls(calls) + + def test_update_construct_with_not_first_start(self): + self.data_collector.module_processor = MagicMock() + self.data_collector.data_writer = MagicMock() + self.data_collector.config.level = "L0" + self.data_collector.optimizer_status = "clip_grad" + self.data_collector.optimizer_status_first_start = {"clip_grad": False} + + self.data_collector.update_construct("test_name") + calls = [ + unittest.mock.call({"test_name": "clip_grad"}), + unittest.mock.call(self.data_collector.module_processor.module_node) + ] + self.data_collector.data_writer.update_construct.assert_has_calls(calls) + + def test_update_construct_with_module_prefix(self): + self.data_collector.module_processor = MagicMock() + self.data_collector.data_writer = MagicMock() + self.data_collector.config.level = "mix" + self.data_collector.optimizer_status = "other_status" + test_name = "Module_test_name" + + self.data_collector.update_construct(test_name) + self.data_collector.data_writer.update_construct.assert_called_with( + self.data_collector.module_processor.module_node + ) + def test_handle_data(self): with patch.object(DataCollector, "update_data") as mock_update_data, \ patch.object(DataCollector, "write_json") as mock_write_json, \ @@ -76,44 +271,212 @@ class TestDataCollector(unittest.TestCase): mock_flush.assert_not_called() mock_write_json.assert_called() - @patch.object(DataCollector, "update_construct") - @patch.object(DataWriter, "update_stack") - @patch.object(BaseDataProcessor, "analyze_api_call_stack") - @patch.object(DataCollector, "handle_data") - def test_forward_data_collect(self, mock_handle_data, _, __, ___): - with patch.object(DataCollector, "check_scope_and_pid", return_value=True), \ - patch.object(StatisticsDataProcessor, "analyze_forward", return_value={}): - with patch.object(StatisticsDataProcessor, "is_terminated", new=True): - self.data_collector.forward_data_collect("name", "module", "pid", "module_input_output") - mock_handle_data.assert_called_with("name", {}, flush=True) - - self.data_collector.forward_data_collect("name", "module", "pid", "module_input_output") - mock_handle_data.assert_called_with("name", {}, flush=False) - - @patch.object(DataCollector, "update_construct") - @patch.object(DataCollector, "handle_data") - def test_backward_data_collect(self, mock_handle_data, _): - with patch.object(DataCollector, "check_scope_and_pid", return_value=True), \ - patch.object(StatisticsDataProcessor, "analyze_backward", return_value={}): - with patch.object(StatisticsDataProcessor, "is_terminated", new=True): - self.data_collector.backward_data_collect("name", "module", "pid", "module_input_output") - mock_handle_data.assert_called_with("name", {}, flush=True) - - self.data_collector.backward_data_collect("name", "module", "pid", "module_input_output") - mock_handle_data.assert_called_with("name", {}, flush=False) - - @patch.object(DataWriter, "update_debug") - @patch.object(BaseDataProcessor, "analyze_debug_forward", return_value="data_info") - def test_debug_data_collect_forward(self, _, mock_update_debug): - self.data_collector.debug_data_collect_forward("variable", "name_with_count") - mock_update_debug.assert_called_with({"name_with_count": "data_info"}) - - @patch.object(DataWriter, "update_debug") - @patch.object(BaseDataProcessor, "analyze_debug_backward") - @patch.object(BaseDataProcessor, "analyze_element_to_all_none", return_value = "all_none_data_info") - def test_debug_data_collect_backward(self, _, mock_analyze_debug_backward, mock_update_debug): - self.data_collector.data_writer.cache_debug = {"data": None} - self.data_collector.debug_data_collect_backward("variable", "name_with_count") - mock_update_debug.assert_called_with({"name_with_count": "all_none_data_info"}) - mock_analyze_debug_backward.assert_called_with("variable", "name_with_count", self.data_collector.data_writer.cache_debug['data']) - self.data_collector.data_writer.cache_debug = None + +class TestForwardDataCollect(unittest.TestCase): + def setUp(self): + mock_json_data = { + "dump_path": "./test_fwd_dump", + } + with patch("msprobe.pytorch.pt_config.FileOpen", mock_open(read_data='')), \ + patch("msprobe.pytorch.pt_config.load_json", return_value=mock_json_data): + common_config, task_config = parse_json_config("./config.json", Const.STATISTICS) + config = DebuggerConfig(common_config, task_config, Const.STATISTICS, "./test_fwd_dump", "L1") + self.data_collector = DataCollector(config) + + self.data_collector.update_construct = MagicMock() + self.data_collector.config = MagicMock() + self.data_collector.data_processor = MagicMock() + self.data_collector.scope = "test_scope" + self.data_collector.check_scope_and_pid = MagicMock() + self.data_collector.set_is_recomputable = MagicMock() + self.data_collector.handle_data = MagicMock() + self.data_collector.call_stack_collect = MagicMock() + + self.Const = MagicMock() + self.Const.FREE_BENCHMARK = "free_benchmark" + self.Const.TENSOR = "tensor" + self.Const.FORWARD = "forward" + self.Const.BACKWARD = "backward" + self.Const.STRUCTURE = "structure" + self.Const.LEVEL_L2 = "L2" + + def test_forward_input_with_free_benchmark_task(self): + self.data_collector.config.task = self.Const.FREE_BENCHMARK + self.data_collector.check_scope_and_pid.return_value = True + + self.data_collector.forward_input_data_collect( + "forward_test", + "module1", + 123, + "input_output" + ) + + self.data_collector.data_processor.analyze_forward_input.assert_called_once_with( + "backward_test", + "module1", + "input_output" + ) + + def test_forward_input_with_scope_pid_check_fail(self): + self.data_collector.config.task = self.Const.TENSOR + self.data_collector.check_scope_and_pid.return_value = False + + self.data_collector.forward_input_data_collect( + "test", "module1", 123, "input_output" + ) + + self.data_collector.data_processor.analyze_forward_input.assert_not_called() + + def test_forward_input_with_structure_task(self): + self.data_collector.config.task = self.Const.STRUCTURE + self.data_collector.check_scope_and_pid.return_value = True + + self.data_collector.forward_input_data_collect( + "test", "module1", 123, "input_output" + ) + + self.data_collector.data_processor.analyze_forward_input.assert_not_called() + self.data_collector.set_is_recomputable.assert_called_once_with({}, None) + + def test_forward_input_with_level_l2(self): + self.data_collector.config.task = self.Const.TENSOR + self.data_collector.config.level = self.Const.LEVEL_L2 + self.data_collector.check_scope_and_pid.return_value = True + + self.data_collector.forward_input_data_collect( + "test", "module1", 123, "input_output" + ) + + self.data_collector.handle_data.assert_not_called() + + def test_forward_input_with_recompute(self): + self.data_collector.config.task = self.Const.TENSOR + self.data_collector.config.level = "L1" + self.data_collector.check_scope_and_pid.return_value = True + mock_data = {"key": "value"} + self.data_collector.data_processor.analyze_forward_input.return_value = mock_data + + self.data_collector.forward_input_data_collect( + "test", "module1", 123, "input_output", is_recompute=True + ) + + self.data_collector.set_is_recomputable.assert_called_once_with(mock_data, True) + self.data_collector.handle_data.assert_called_once_with( + "test", mock_data, flush=self.data_collector.data_processor.is_terminated + ) + + def test_forward_output_with_scope_check_fail(self): + self.data_collector.check_scope_and_pid.return_value = False + self.data_collector.forward_output_data_collect("test", "module", 123, "data") + self.data_collector.data_processor.analyze_forward_output.assert_not_called() + + def test_forward_output_with_structure_task(self): + self.data_collector.config.task = self.Const.STRUCTURE + self.data_collector.forward_output_data_collect("test", "module", 123, "data") + self.data_collector.data_processor.analyze_forward_output.assert_not_called() + + def test_forward_output_with_level_l2(self): + self.data_collector.config.level = self.Const.LEVEL_L2 + self.data_collector.forward_output_data_collect("test", "module", 123, "data") + self.data_collector.handle_data.assert_not_called() + + def test_forward_output_normal(self): + mock_data = {"key": "value"} + self.data_collector.data_processor.analyze_forward_output.return_value = mock_data + self.data_collector.forward_output_data_collect("test", "module", 123, "data", True) + self.data_collector.handle_data.assert_called_once_with( + "test", + mock_data, + flush=self.data_collector.data_processor.is_terminated + ) + + def test_forward_with_scope_check_fail(self): + self.data_collector.check_scope_and_pid.return_value = False + self.data_collector.forward_data_collect("test", "module", 123, "data") + self.data_collector.data_processor.analyze_forward.assert_not_called() + + def test_forward_with_structure_task(self): + self.data_collector.config.task = self.Const.STRUCTURE + self.data_collector.forward_data_collect("test", "module", 123, "data") + self.data_collector.data_processor.analyze_forward.assert_not_called() + + def test_forward_normal(self): + mock_data = {"key": "value"} + self.data_collector.data_processor.analyze_forward.return_value = mock_data + self.data_collector.forward_data_collect("test", "module", 123, "data", False) + self.data_collector.call_stack_collect.assert_called_once_with("test") + self.data_collector.handle_data.assert_called_once_with( + "test", + mock_data, + flush=self.data_collector.data_processor.is_terminated + ) + + +class TestBackwardDataCollector(unittest.TestCase): + def setUp(self): + mock_json_data = { + "dump_path": "./test_bwd_dump", + } + with patch("msprobe.pytorch.pt_config.FileOpen", mock_open(read_data='')), \ + patch("msprobe.pytorch.pt_config.load_json", return_value=mock_json_data): + common_config, task_config = parse_json_config("./config.json", Const.STATISTICS) + config = DebuggerConfig(common_config, task_config, Const.STATISTICS, "./test_bwd_dump", "L1") + self.data_collector = DataCollector(config) + + self.data_collector.config = MagicMock() + self.data_collector.data_processor = MagicMock() + self.data_collector.scope = "test_scope" + self.data_collector.check_scope_and_pid = MagicMock(return_value=True) + self.data_collector.set_is_recomputable = MagicMock() + self.data_collector.handle_data = MagicMock() + self.data_collector.update_construct = MagicMock() + self.data_collector.backward_module_names = {} + + self.Const = MagicMock() + self.Const.STRUCTURE = "structure" + self.Const.TENSOR = "tensor" + self.Const.LEVEL_L2 = "L2" + self.Const.SEP = "." + self.Const.MODULE_PREFIX = ["module"] + + def test_backward_with_scope_check_fail(self): + self.data_collector.check_scope_and_pid.return_value = False + self.data_collector.backward_data_collect("test", "module", 123, "data") + self.data_collector.data_processor.analyze_backward.assert_not_called() + + def test_backward_with_level_l2(self): + self.data_collector.config.level = self.Const.LEVEL_L2 + self.data_collector.backward_data_collect("test", "module", 123, "data") + self.data_collector.handle_data.assert_not_called() + + def test_backward_data_module_prefix_match(self): + self.data_collector.check_scope_and_pid.return_value = True + self.data_collector.config.task = self.Const.TENSOR + self.data_collector.config.level = "L1" + mock_data = {"key": "value"} + self.data_collector.data_processor.analyze_backward.return_value = mock_data + test_name = "Module.layer1.backward" + self.data_collector.backward_data_collect(test_name, "module", 123, "data") + self.assertEqual(self.data_collector.backward_module_names, {"Module": True}) + + def test_backward_input_with_structure_task(self): + self.data_collector.config.task = self.Const.STRUCTURE + self.data_collector.backward_input_data_collect("test", "module", 123, "data") + self.data_collector.data_processor.analyze_backward_input.assert_not_called() + + def test_backward_input_with_normal(self): + mock_data = {"key": "value"} + self.data_collector.data_processor.analyze_backward_input.return_value = mock_data + self.data_collector.backward_input_data_collect("test", "module", 123, "data", True) + self.data_collector.set_is_recomputable.assert_called_once_with(mock_data, True) + + def test_backward_output_with_scope_check_fail(self): + self.data_collector.check_scope_and_pid.return_value = False + self.data_collector.backward_output_data_collect("test", "module", 123, "data") + self.data_collector.data_processor.analyze_backward_output.assert_not_called() + + def test_backward_output_with_recompute(self): + mock_data = {"key": "value"} + self.data_collector.data_processor.analyze_backward_output.return_value = mock_data + self.data_collector.backward_output_data_collect("test", "module", 123, "data", False) + self.data_collector.set_is_recomputable.assert_called_once_with(mock_data, False) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_json_writer.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_json_writer.py index 9b20ffb2197882e16c1550cf013d1ba132096063..31f34544e86423b10f2028e0fde1987990d09fa5 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_json_writer.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_json_writer.py @@ -3,6 +3,7 @@ import os import unittest from unittest.mock import patch +from msprobe.core.common.const import Const from msprobe.core.common.utils import DumpPathAggregation from msprobe.core.common.file_utils import FileOpen, remove_path, load_json from msprobe.core.data_dump.json_writer import DataWriter @@ -13,6 +14,49 @@ class TestDataWriter(unittest.TestCase): self.data_writer = DataWriter() self.data_content = {"task": "tensor", "level": "L1", "data": {"Tensor.add": 1}} self.cur_path = os.path.dirname(os.path.realpath(__file__)) + self.stat_vector = [1.0, 2.0, 3.0, 4.0] # Example stat_vector for tests + self.data_writer.stat_stack_list = [self.stat_vector] # Mock the stat_stack_list + + def test_replace_stat_placeholders(self): + stat_result = [[1.0, 2.0, 3.0, 4.0]] # Mocking stat_result with a dummy value + data = {"type": "Tensor", "dtype": "float32", "shape": [1, 2, 3], Const.TENSOR_STAT_INDEX: 0} + + # Call _replace_stat_placeholders directly + self.data_writer._replace_stat_placeholders(data, stat_result) + + # Check that the function processed the placeholders correctly + self.assertEqual(data["Max"], 1.0) + self.assertEqual(data["Min"], 2.0) + self.assertEqual(data["Mean"], 3.0) + self.assertEqual(data["Norm"], 4.0) + + def test_append_stat_to_buffer(self): + index = self.data_writer.append_stat_to_buffer(self.stat_vector) + self.assertEqual(index, 1) # The first append will return index 0 + self.assertEqual(self.data_writer.stat_stack_list[0], + self.stat_vector) # Check if the stat is appended correctly + + def test_get_buffer_values_max(self): + max_value = self.data_writer.get_buffer_values_max(0) + self.assertEqual(max_value, 1.0) # The max value of stat_vector is 1.0 + + # Test when index is out of range + max_value_invalid = self.data_writer.get_buffer_values_max(1) + self.assertIsNone(max_value_invalid) # Should return None for invalid index + + def test_get_buffer_values_min(self): + min_value = self.data_writer.get_buffer_values_min(0) + self.assertEqual(min_value, 2.0) # The min value of stat_vector is 2.0 + + # Test when index is out of range + min_value_invalid = self.data_writer.get_buffer_values_min(1) + self.assertIsNone(min_value_invalid) # Should return None for invalid index + + def test_flush_stat_stack(self): + # Ensure that flush_stat_stack works and clears the stat_stack_list + result = self.data_writer.flush_stat_stack() + self.assertEqual(result, [[1.0, 2.0, 3.0, 4.0]]) # Returns the flushed stats + self.assertEqual(self.data_writer.stat_stack_list, []) # Ensure the list is cleared after flush def test_write_data_to_csv(self): cur_path = os.path.dirname(os.path.realpath(__file__)) @@ -42,9 +86,9 @@ class TestDataWriter(unittest.TestCase): remove_path(file_path) def test_reset_cache(self): - self.data_writer.cache_data={"data": 1} - self.data_writer.cache_stack={"stack": 2} - self.data_writer.cache_construct={"construct": 3} + self.data_writer.cache_data = {"data": 1} + self.data_writer.cache_stack = {"stack": 2} + self.data_writer.cache_construct = {"construct": 3} self.data_writer.reset_cache() self.assertEqual(self.data_writer.cache_data, {}) self.assertEqual(self.data_writer.cache_stack, {}) @@ -83,6 +127,7 @@ class TestDataWriter(unittest.TestCase): dump_path_aggregation.dump_tensor_data_dir = test_path dump_path_aggregation.free_benchmark_file_path = test_path dump_path_aggregation.debug_file_path = test_path + dump_path_aggregation.dump_error_info_path = test_path self.data_writer.update_dump_paths(dump_path_aggregation) self.assertTrue(self.data_writer.dump_file_path == test_path) @@ -117,8 +162,9 @@ class TestDataWriter(unittest.TestCase): self.assertEqual(self.data_writer.cache_data, expected) def test_update_stack(self): - self.data_writer.update_stack(self.data_content) - self.assertEqual(self.data_writer.cache_stack, self.data_content) + self.data_writer.cache_stack = {"stack1": ["test1"]} + self.data_writer.update_stack("test2", "stack1") + self.assertEqual(self.data_writer.cache_stack, {"stack1": ["test1", "test2"]}) def test_update_construct(self): self.data_writer.update_construct(self.data_content) @@ -136,13 +182,13 @@ class TestDataWriter(unittest.TestCase): os.remove(file_path) def test_write_stack_info_json(self): - self.data_writer.cache_stack = self.data_content + self.data_writer.cache_stack = {("api1", "api2"): ["stack1"]} file_path = os.path.join(self.cur_path, "stack.json") self.data_writer.write_stack_info_json(file_path) load_result = load_json(file_path) try: - self.assertEqual(load_result, self.data_content) + self.assertEqual(load_result, {"0": [["stack1"], ["api1", "api2"]]}) finally: os.remove(file_path) @@ -156,3 +202,48 @@ class TestDataWriter(unittest.TestCase): self.assertEqual(load_result, self.data_content) finally: os.remove(file_path) + + def test_replace_stat_placeholders_invalid_index(self): + data = { + "type": "Tensor", + "dtype": "float32", + "shape": [1, 2], + Const.TENSOR_STAT_INDEX: 10 # 超出索引 + } + stat_result = [[1.0, 2.0, 3.0, 4.0]] + self.data_writer._replace_stat_placeholders(data, stat_result) + self.assertIsNone(data.get(Const.TENSOR_STAT_INDEX)) + self.assertIn(Const.MAX, data) + self.assertIsNone(data[Const.MAX]) # 越界填 None + + def test_append_stat_to_buffer_multiple(self): + for i in range(5): + idx = self.data_writer.append_stat_to_buffer([i, i+1, i+2, i+3]) + self.assertEqual(idx, i + 1) + self.assertEqual(len(self.data_writer.stat_stack_list), 6) # 包含 setUp 中那一条 + + def test_get_buffer_values_max_invalid_data(self): + self.data_writer.stat_stack_list = [["not-a-number"]] # 非预期格式 + max_val = self.data_writer.get_buffer_values_max(0) + self.assertEqual(max_val, "not-a-number") # 仍然返回第一位 + + max_val = self.data_writer.get_buffer_values_max(-1) + self.assertIsNone(max_val) + + def test_flush_stat_stack_empty(self): + self.data_writer.stat_stack_list = [] + result = self.data_writer.flush_stat_stack() + self.assertEqual(result, []) + + def test_flush_stat_stack_with_tensor_like_items(self): + class DummyTensor: + def __init__(self, v): self.v = v + def item(self): return self.v + + self.data_writer.stat_stack_list = [ + [DummyTensor(1), DummyTensor(2), DummyTensor(3), DummyTensor(4)], + [5.5, 6.6, 7.7, 8.8] # 混合类型 + ] + result = self.data_writer.flush_stat_stack() + self.assertEqual(result, [[1, 2, 3, 4], [5.5, 6.6, 7.7, 8.8]]) + self.assertEqual(self.data_writer.stat_stack_list, []) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/monitor/__init__.py b/debug/accuracy_tools/msprobe/test/core_ut/monitor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_anomaly_analyse.py b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_anomaly_processor.py similarity index 46% rename from debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_anomaly_analyse.py rename to debug/accuracy_tools/msprobe/test/core_ut/monitor/test_anomaly_processor.py index 904be210a3771f1757e4410b5e0fa0f2ad6152f2..2511d60caa823366c778761ccb8fb9bca747d2f5 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_anomaly_analyse.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_anomaly_processor.py @@ -1,14 +1,282 @@ import os import unittest +from unittest import TestCase from unittest.mock import patch, MagicMock -from msprobe.pytorch.monitor.anomaly_detect import GradAnomalyData +from msprobe.core.monitor.anomaly_processor import ScanRule, AnomalyTurbulence, AnomalyNan, AnomalyScanner, \ + AnomalyDataFactory, GradAnomalyData, AnomalyDataWriter, AnomalyDataLoader, AnomalyAnalyse, \ + _get_step_and_stop, _anomaly_analyse, _get_parse_args -from msprobe.pytorch.monitor.anomaly_analyse import AnomalyDataWriter, AnomalyDataLoader, AnomalyAnalyse, \ - _get_parse_args, _get_step_and_stop, _anomaly_analyse +class TestScanRule(TestCase): + def test_apply_not_implemented(self): + scan_rule = ScanRule() + with self.assertRaises(Exception) as context: + scan_rule.apply(None, None) + + self.assertEqual(str(context.exception), "abstract method apply is not implemented") + + +class TestAnomalyTurbulence(TestCase): + + def setUp(self) -> None: + self.threshold = 0.2 + self.rule = AnomalyTurbulence(self.threshold) + + def test_apply_with_positive_baseline(self): + history = 12 + cur = 16 + result = self.rule.apply(cur, history=history) + self.assertTrue(result) + + def test_apply_with_non_positive_baseline(self): + history = 0 + cur = -1 + result = self.rule.apply(cur, history=history) + self.assertTrue(result) + + def test_apply_with_valid_value(self): + history = 0 + cur = 0 + result = self.rule.apply(cur, history=history) + self.assertFalse(result) + + +class TestAnomalyNan(TestCase): + + def setUp(self) -> None: + self.threshold = 1e10 + self.rule = AnomalyNan(self.threshold) + + def test_apply_with_nan(self): + cur = float("nan") + result = self.rule.apply(cur) + self.assertTrue(result) + + def test_apply_with_big_value(self): + cur = float("1e30") + result = self.rule.apply(cur) + self.assertTrue(result) + + def test_apply_with_valid_value(self): + cur = 0.5 + result = self.rule.apply(cur) + self.assertFalse(result) + + +class TestAnomalyScanner(TestCase): + + def test_load_rules_with_valied_spec(self): + specs = [ + {"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.2}} + ] + rules = AnomalyScanner.load_rules(specs) + + self.assertEqual(len(rules), 1) + self.assertIsInstance(rules[0], AnomalyTurbulence) + self.assertEqual(rules[0].threshold, 0.2) + + rules = AnomalyScanner.load_rules(None) + self.assertEqual(len(rules), 0) + + @patch("msprobe.core.monitor.anomaly_processor.logger") + def test_load_rules_with_missing_keys(self, mock_logger): + specs = [ + {"rule_name": "AnomalyTurbulence"} + ] + rules = AnomalyScanner.load_rules(specs) -class TestAnomalyDataWriter(unittest.TestCase): + self.assertEqual(len(rules), 0) + mock_logger.warning.assert_called_once_with(f"Spec is missing required keys: {specs[0]}") + + def test_load_rules_with_invalid_rule(self): + # test invalid rule_name + specs = [{"rule_name": "InvalidRule", "args": {"threshold": 0.2}}] + rules = AnomalyScanner.load_rules(specs) + self.assertEqual(len(rules), 0) + + # test invalid args + specs = [{"rule_name": "AnomalyTurbulence", "args": "invalid args"}] + rules = AnomalyScanner.load_rules(specs) + self.assertEqual(len(rules), 0) + + def test_scan(self): + ad_rules = [AnomalyTurbulence(0.2)] + # test scan with anomaly + expected = True, "AnomalyTurbulence" + self.assertEqual(AnomalyScanner.scan(ad_rules, 1.0, 2.0), expected) + # test scan with no anomaly + expected = False, None + self.assertEqual(AnomalyScanner.scan(ad_rules, 1.0, 1.0), expected) + + +class TestAnomalyDataFactory(TestCase): + + def setUp(self) -> None: + rank = 0 + pp_stage = 0 + group_mates = [0] + self.AnomalyDataFactory = AnomalyDataFactory(rank, pp_stage, group_mates) + + def test_set_call_id(self): + name2callid = {'param_name': 0} + self.AnomalyDataFactory.set_call_id(name2callid) + + self.assertEqual(self.AnomalyDataFactory.name2callid, {'param_name': 0}) + + def test_create_success(self): + tag = ('0:1.self_attention.core_attention_flash_0/rank0/output', 'min') + message = "Rule AnomalyTurbulence reports anomaly signal in ('0:1.self_attention.core_attention_flash_0/rank0/output', 'min') at step 2." + step = 2 + result = self.AnomalyDataFactory.create(tag, message, step) + + self.assertEqual(result.step, step) + self.assertEqual(result.tag_name, tag[0]) + self.assertEqual(result.message, message) + self.assertEqual(result.vpp_stage, 0) + + # test no vpp_stage + tag = ('1.self_attention.core_attention_flash_0/rank0/output', 'min') + result = self.AnomalyDataFactory.create(tag, message, step) + self.assertEqual(result.vpp_stage, 0) + + def test_create_failed(self): + error_tag = '0:1.self_attention.core_attention_flash_0/rank0/output' + message = "Rule AnomalyTurbulence reports anomaly signal in ('0:1.self_attention.core_attention_flash_0/rank0/output', 'min') at step 2." + step = 2 + with self.assertRaises(Exception) as context: + self.AnomalyDataFactory.create(error_tag, message, step) + self.assertEqual(str(context.exception), "tag must be a tuple with length 2") + + +class TestGradAnomalyData(TestCase): + + def setUp(self) -> None: + tag_name = "0:1.self_attention.core_attention_flash.output:0/rank0/actv" + message = "Rule AnomalyTurbulence reports anomaly signal in ('0:1.self_attention.core_attention_flash.output:0/rank0/actv', 'min') at step 2." + group_mates = [0] + self.GradAnomalyData = GradAnomalyData(tag_name=tag_name, message=message, group_mates=group_mates) + + def test_get_train_stage(self): + tag_name_list = ["0:fc2.input:0/rank0/actv", "0:fc1.weight/rank0/post_grad", "0:fc2.weight/rank0/exp_avg_sq", ""] + expected_train_stage_list = [0, 1, 2, -1] + for tag_name, expected_train_stage in zip(tag_name_list, expected_train_stage_list): + train_stage = GradAnomalyData.get_train_stage(tag_name) + self.assertEqual(train_stage, expected_train_stage) + + def test_to_dict(self): + expected = { + 'rank': 0, + 'step': 0, + 'micro_step': 0, + 'pp_stage': 0, + 'vpp_stage': 0, + 'call_id': 0, + 'tag_name': "0:1.self_attention.core_attention_flash.output:0/rank0/actv", + 'message': "Rule AnomalyTurbulence reports anomaly signal in ('0:1.self_attention.core_attention_flash.output:0/rank0/actv', 'min') at step 2.", + 'group_mates': [0] + } + + self.assertEqual(self.GradAnomalyData.to_dict(), expected) + + def test_get_key(self): + expected = "0:1.self_attention.core_attention_flash.output:0/rank0/actv_step_0_call_0" + + self.assertEqual(self.GradAnomalyData.get_key(), expected) + + def test_lt_different_step(self): + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") + data2 = GradAnomalyData(step=2, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") + self.assertLess(data1, data2) + self.assertGreater(data2, data1) + + def test_lt_same_step_different_micro_step(self): + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") + data2 = GradAnomalyData(step=1, micro_step=1, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") + self.assertLess(data1, data2) + self.assertGreater(data2, data1) + + def test_lt_same_step_same_micro_step_different_vpp_stage(self): + # same forward + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/actv") + data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=1, pp_stage=0, call_id=0, tag_name="xxx/actv") + self.assertLess(data1, data2) + self.assertGreater(data2, data1) + + # same backward + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/post_grad") + data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=1, pp_stage=0, call_id=0, tag_name="xxx/post_grad") + self.assertLess(data2, data1) + self.assertGreater(data1, data2) + + # diff train stage + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/actv") + data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=1, pp_stage=0, call_id=0, tag_name="xxx/post_grad") + self.assertLess(data1, data2) + self.assertGreater(data2, data1) + + def test_lt_same_step_same_micro_step_same_vpp_stage_different_pp_stage(self): + # same forward + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/actv") + data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=1, call_id=0, tag_name="xxx/actv") + self.assertLess(data1, data2) + self.assertGreater(data2, data1) + + # same backward + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/post_grad") + data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=1, call_id=0, tag_name="xxx/post_grad") + self.assertLess(data2, data1) + self.assertGreater(data1, data2) + + # diff train stage + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/input") + data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=1, call_id=0, tag_name="xxx/post_grad") + self.assertLess(data1, data2) + self.assertGreater(data2, data1) + + def test_lt_same_step_same_micro_step_same_vpp_stage_same_pp_stage_different_call_id(self): + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") + data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=1, tag_name="") + self.assertLess(data1, data2) + self.assertGreater(data2, data1) + + def test_lt_same_data(self): + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") + data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") + self.assertGreaterEqual(data1, data2) + self.assertLessEqual(data1, data2) + + def test_lt_not_instance(self): + data = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0) + not_instance = "not an instance of GradAnomalyData" + self.assertEqual(data.__lt__(not_instance), NotImplemented) + + def test_le_same_instance(self): + # 测试相同实例的情况 + data1 = GradAnomalyData() + self.assertTrue(data1 <= data1) + + def test_le_different_instance(self): + # 测试不同实例的情况 + data1 = GradAnomalyData() + data2 = GradAnomalyData() + self.assertTrue(data1 <= data2) + + def test_le_not_instance(self): + # 测试非GradAnomalyData实例的情况 + data = GradAnomalyData() + not_instance = "Not an instance of GradAnomalyData" + self.assertEqual(data.__le__(not_instance), NotImplemented) + + def test_le_different_instance_not_equal(self): + # 测试不同实例且不相等的情况 + data1 = GradAnomalyData() + data2 = GradAnomalyData() + data2.some_attribute = "some value" + self.assertTrue(data1 <= data2) + + +class TestAnomalyDataWriter(TestCase): def test_get_anomaly_dict(self): # 测试 get_anomaly_dict 方法 @@ -29,9 +297,9 @@ class TestAnomalyDataWriter(unittest.TestCase): } self.assertEqual(result, expected) - @patch('msprobe.pytorch.monitor.anomaly_analyse.os.path.exists') - @patch('msprobe.pytorch.monitor.anomaly_analyse.create_directory') - @patch('msprobe.pytorch.monitor.anomaly_analyse.save_json') + @patch('msprobe.core.monitor.anomaly_processor.os.path.exists') + @patch('msprobe.core.monitor.anomaly_processor.create_directory') + @patch('msprobe.core.monitor.anomaly_processor.save_json') def test_init_detected_json(self, mock_save_json, mock_create_directory, mock_exists): # 模拟路径检查 mock_exists.side_effect = [False, False, False] # dump_path, dump_rank_dir, json_path @@ -42,16 +310,15 @@ class TestAnomalyDataWriter(unittest.TestCase): writer.init_detected_json() # 检查是否创建了目录 - mock_create_directory.assert_any_call('/tmp/dump') mock_create_directory.assert_any_call('/tmp/dump/rank0') # 检查是否初始化了 JSON 文件 mock_save_json.assert_called_once_with(writer.json_path, {}, indent=1) - @patch('msprobe.pytorch.monitor.anomaly_analyse.check_file_or_directory_path') - @patch('msprobe.pytorch.monitor.anomaly_analyse.remove_path') - @patch('msprobe.pytorch.monitor.anomaly_analyse.save_json') - @patch('msprobe.pytorch.monitor.anomaly_analyse.logger') + @patch('msprobe.core.monitor.anomaly_processor.check_file_or_directory_path') + @patch('msprobe.core.monitor.anomaly_processor.remove_path') + @patch('msprobe.core.monitor.anomaly_processor.save_json') + @patch('msprobe.core.monitor.anomaly_processor.logger') def test_init_detected_json_existing_file(self, mock_logger, mock_save_json, mock_remove_path, mock_check_path): # 设置测试参数 dump_path = 'test/dump_path' @@ -72,9 +339,9 @@ class TestAnomalyDataWriter(unittest.TestCase): mock_logger.warning.assert_called_once_with(f"The existing file will be deleted: {writer.json_path}.") mock_save_json.assert_called_once_with(writer.json_path, {}, indent=1) - @patch('msprobe.pytorch.monitor.anomaly_analyse.os.path.exists') - @patch('msprobe.pytorch.monitor.anomaly_analyse.load_json') - @patch('msprobe.pytorch.monitor.anomaly_analyse.save_json') + @patch('msprobe.core.monitor.anomaly_processor.os.path.exists') + @patch('msprobe.core.monitor.anomaly_processor.load_json') + @patch('msprobe.core.monitor.anomaly_processor.save_json') def test_write_detected_json(self, mock_save_json, mock_load_json, mock_exists): mock_exists.side_effect = [True, True] # json_path 存在 @@ -101,9 +368,9 @@ class TestAnomalyDataWriter(unittest.TestCase): mock_save_json.assert_called_once_with(writer.json_path, expected_data, indent=1) -class TestAnomalyDataLoader(unittest.TestCase): +class TestAnomalyDataLoader(TestCase): - @patch('msprobe.pytorch.monitor.anomaly_analyse.GradAnomalyData') # 替换为 GradAnomalyData 的实际导入路径 + @patch('msprobe.core.monitor.anomaly_processor.GradAnomalyData') # 替换为 GradAnomalyData 的实际导入路径 def test_create_instances_from_dict(self, mock_GradAnomalyData): # 模拟 GradAnomalyData 的构造函数 def mock_constructor(**kwargs): @@ -122,11 +389,11 @@ class TestAnomalyDataLoader(unittest.TestCase): # 确保创建了两个实例,第三个因缺少 key2 被捕获 self.assertEqual(len(instances), 2) - @patch('msprobe.pytorch.monitor.anomaly_analyse.os.listdir') - @patch('msprobe.pytorch.monitor.anomaly_analyse.os.path.exists') - @patch('msprobe.pytorch.monitor.anomaly_analyse.load_json') - @patch('msprobe.pytorch.monitor.anomaly_analyse.check_file_or_directory_path') - @patch('msprobe.pytorch.monitor.anomaly_analyse.GradAnomalyData') + @patch('msprobe.core.monitor.anomaly_processor.os.listdir') + @patch('msprobe.core.monitor.anomaly_processor.os.path.exists') + @patch('msprobe.core.monitor.anomaly_processor.load_json') + @patch('msprobe.core.monitor.anomaly_processor.check_file_or_directory_path') + @patch('msprobe.core.monitor.anomaly_processor.GradAnomalyData') def test_get_anomalies_from_jsons(self, mock_GradAnomalyData, mock_check_path, mock_load_json, mock_exists, mock_listdir): mock_check_path.return_value = None @@ -146,7 +413,7 @@ class TestAnomalyDataLoader(unittest.TestCase): mock_GradAnomalyData.side_effect = mock_constructor # 假设构造成功 loader = AnomalyDataLoader('/tmp/data') - with patch('msprobe.pytorch.monitor.anomaly_analyse.os.path.isdir', return_value=True): + with patch('msprobe.core.monitor.anomaly_processor.os.path.isdir', return_value=True): anomalies = loader.get_anomalies_from_jsons() # 确保从 rank0 读取了异常数据 @@ -155,7 +422,7 @@ class TestAnomalyDataLoader(unittest.TestCase): mock_load_json.assert_called_once_with('/tmp/data/rank0/anomaly.json') -class TestAnomalyAnalyse(unittest.TestCase): +class TestAnomalyAnalyse(TestCase): def setUp(self): self.anomaly_analyse = AnomalyAnalyse() @@ -189,10 +456,10 @@ class TestAnomalyAnalyse(unittest.TestCase): self.assertEqual(len(result), 3) self.assertEqual(result, [anomalies[1], anomalies[0], anomalies[2]]) - @patch('msprobe.pytorch.monitor.anomaly_analyse.os.path.exists') - @patch('msprobe.pytorch.monitor.anomaly_analyse.AnomalyDataWriter.get_anomaly_dict') - @patch('msprobe.pytorch.monitor.anomaly_analyse.save_json') - @patch('msprobe.pytorch.monitor.anomaly_analyse.logger') + @patch('msprobe.core.monitor.anomaly_processor.os.path.exists') + @patch('msprobe.core.monitor.anomaly_processor.AnomalyDataWriter.get_anomaly_dict') + @patch('msprobe.core.monitor.anomaly_processor.save_json') + @patch('msprobe.core.monitor.anomaly_processor.logger') def test_rewrite_sorted_anomalies(self, mock_logger, mock_save_json, mock_get_anomaly_dict, mock_exists): # 设置 mock mock_exists.return_value = False @@ -202,7 +469,7 @@ class TestAnomalyAnalyse(unittest.TestCase): # 调用方法 self.anomaly_analyse.sorted_anomalies = self.anomalies - with patch("msprobe.pytorch.monitor.anomaly_analyse.check_file_or_directory_path", return_value=None): + with patch("msprobe.core.monitor.anomaly_processor.check_file_or_directory_path", return_value=None): self.anomaly_analyse.rewrite_sorted_anomalies(output_path) # 验证调用 @@ -214,17 +481,17 @@ class TestAnomalyAnalyse(unittest.TestCase): ) mock_logger.info.assert_called_once_with("anomaly_analyse.json is at output_path.") - @patch('msprobe.pytorch.monitor.anomaly_analyse.os.path.exists') - @patch('msprobe.pytorch.monitor.anomaly_analyse.logger') + @patch('msprobe.core.monitor.anomaly_processor.os.path.exists') + @patch('msprobe.core.monitor.anomaly_processor.logger') def test_rewrite_sorted_anomalies_file_exists(self, mock_logger, mock_exists): # 模拟文件已经存在的情况 mock_exists.return_value = True output_path = 'output_path' # 调用方法 - with patch("msprobe.pytorch.monitor.anomaly_analyse.check_file_or_directory_path", return_value=None), \ - patch("msprobe.pytorch.monitor.anomaly_analyse.remove_path", return_value=None), \ - patch("msprobe.pytorch.monitor.anomaly_analyse.save_json", return_value=None): + with patch("msprobe.core.monitor.anomaly_processor.check_file_or_directory_path", return_value=None), \ + patch("msprobe.core.monitor.anomaly_processor.remove_path", return_value=None), \ + patch("msprobe.core.monitor.anomaly_processor.save_json", return_value=None): self.anomaly_analyse.rewrite_sorted_anomalies(output_path) # 验证日志警告 @@ -232,35 +499,7 @@ class TestAnomalyAnalyse(unittest.TestCase): f"The existing file will be deleted: output_path/anomaly_analyse.json.") -class TestParseArgs(unittest.TestCase): - - @patch('msprobe.pytorch.monitor.anomaly_analyse.sys.argv', - new=['script_name', '-d', 'path/to/data', '-o', 'path/to/output', '-k', '5', '-s', '[1,2,3]']) - def test_parse_args_with_all_arguments(self): - args = _get_parse_args() - self.assertEqual(args.data_path_dir, 'path/to/data') - self.assertEqual(args.out_path, 'path/to/output') - self.assertEqual(args.top_k_number, 5) - self.assertEqual(args.step_list, '[1,2,3]') - - @patch('msprobe.pytorch.monitor.anomaly_analyse.sys.argv', new=['script_name', '-d', 'path/to/data']) - def test_parse_args_with_required_argument_only(self): - args = _get_parse_args() - self.assertEqual(args.data_path_dir, 'path/to/data') - self.assertEqual(args.out_path, '') - self.assertEqual(args.top_k_number, 8) # 默认值 - self.assertEqual(args.step_list, '[]') # 默认值 - - @patch('msprobe.pytorch.monitor.anomaly_analyse.sys.argv', new=['script_name', '-d', 'path/to/data', '-k', '10']) - def test_parse_args_with_topk_only(self): - args = _get_parse_args() - self.assertEqual(args.data_path_dir, 'path/to/data') - self.assertEqual(args.out_path, '') - self.assertEqual(args.top_k_number, 10) # 提供的值 - self.assertEqual(args.step_list, '[]') # 默认值 - - -class TestGetStepAndStop(unittest.TestCase): +class TestGetStepAndStop(TestCase): def test_valid_step_list_and_top_k(self): # 构造有效的 args 对象 @@ -318,13 +557,13 @@ class TestGetStepAndStop(unittest.TestCase): self.assertEqual(str(context.exception), "The top k number must be greater than 0.") -class TestAnomalyAnalyseFunction(unittest.TestCase): +class TestAnomalyAnalyseFunction(TestCase): - @patch('msprobe.pytorch.monitor.anomaly_analyse._get_parse_args') # 模拟命令行参数解析 - @patch('msprobe.pytorch.monitor.anomaly_analyse._get_step_and_stop') # 模拟步骤和顶级数字解析 - @patch('msprobe.pytorch.monitor.anomaly_analyse.AnomalyDataLoader') # 模拟数据加载器 - @patch('msprobe.pytorch.monitor.anomaly_analyse.AnomalyAnalyse') # 模拟异常分析器 - @patch('msprobe.pytorch.monitor.anomaly_analyse.logger') # 模拟日志记录 + @patch('msprobe.core.monitor.anomaly_processor._get_parse_args') # 模拟命令行参数解析 + @patch('msprobe.core.monitor.anomaly_processor._get_step_and_stop') # 模拟步骤和顶级数字解析 + @patch('msprobe.core.monitor.anomaly_processor.AnomalyDataLoader') # 模拟数据加载器 + @patch('msprobe.core.monitor.anomaly_processor.AnomalyAnalyse') # 模拟异常分析器 + @patch('msprobe.core.monitor.anomaly_processor.logger') # 模拟日志记录 def test_anomaly_analyse(self, mock_logger, mock_anomaly_analyse, mock_anomaly_data_loader, mock_get_step_and_stop, mock_get_parse_args): # 模拟命令行参数 @@ -376,5 +615,33 @@ class TestAnomalyAnalyseFunction(unittest.TestCase): mock_logger.info.assert_any_call("1: Top Anomaly 2") +class TestParseArgs(TestCase): + + @patch('msprobe.core.monitor.anomaly_processor.sys.argv', + new=['script_name', '-d', 'path/to/data', '-o', 'path/to/output', '-k', '5', '-s', '[1,2,3]']) + def test_parse_args_with_all_arguments(self): + args = _get_parse_args() + self.assertEqual(args.data_path_dir, 'path/to/data') + self.assertEqual(args.out_path, 'path/to/output') + self.assertEqual(args.top_k_number, 5) + self.assertEqual(args.step_list, '[1,2,3]') + + @patch('msprobe.core.monitor.anomaly_processor.sys.argv', new=['script_name', '-d', 'path/to/data']) + def test_parse_args_with_required_argument_only(self): + args = _get_parse_args() + self.assertEqual(args.data_path_dir, 'path/to/data') + self.assertEqual(args.out_path, '') + self.assertEqual(args.top_k_number, 8) # 默认值 + self.assertEqual(args.step_list, '[]') # 默认值 + + @patch('msprobe.core.monitor.anomaly_processor.sys.argv', new=['script_name', '-d', 'path/to/data', '-k', '10']) + def test_parse_args_with_topk_only(self): + args = _get_parse_args() + self.assertEqual(args.data_path_dir, 'path/to/data') + self.assertEqual(args.out_path, '') + self.assertEqual(args.top_k_number, 10) # 提供的值 + self.assertEqual(args.step_list, '[]') # 默认值 + + if __name__ == '__main__': unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py new file mode 100644 index 0000000000000000000000000000000000000000..6b53ade72b2755d7a815955656b5321a1c27d101 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py @@ -0,0 +1,234 @@ +import unittest +import os +import tempfile +import shutil +from unittest.mock import patch, MagicMock +import pandas as pd + +from msprobe.core.monitor.csv2db import ( + CSV2DBConfig, + validate_process_num, + validate_step_partition, + validate_data_type_list, + _pre_scan_single_rank, + _pre_scan, + process_single_rank, + import_data, + csv2db, + all_data_type_list, +) +from msprobe.core.common.const import MonitorConst + + +class TestCSV2DBValidations(unittest.TestCase): + def test_validate_process_num_valid(self): + """测试有效的进程数""" + validate_process_num(1) + validate_process_num(MonitorConst.MAX_PROCESS_NUM) + + def test_validate_process_num_invalid(self): + """测试无效的进程数""" + with self.assertRaises(ValueError): + validate_process_num(0) + with self.assertRaises(ValueError): + validate_process_num(-1) + with self.assertRaises(ValueError): + validate_process_num(MonitorConst.MAX_PROCESS_NUM + 1) + + def test_validate_step_partition_valid(self): + """测试有效的step分区""" + validate_step_partition(MonitorConst.MIN_PARTITION) + validate_step_partition(MonitorConst.MAX_PARTITION) + + def test_validate_step_partition_invalid(self): + """测试无效的step分区""" + with self.assertRaises(ValueError): + validate_step_partition(MonitorConst.MAX_PARTITION + 1) + with self.assertRaises(ValueError): + validate_step_partition(MonitorConst.MIN_PARTITION - 1) + with self.assertRaises(TypeError): + validate_step_partition(500.0) + + def test_validate_data_type_list_valid(self): + """测试有效的数据类型列表""" + validate_data_type_list(["actv", "grad_reduced"]) + validate_data_type_list(all_data_type_list[:2]) + + def test_validate_data_type_list_invalid(self): + """测试无效的数据类型列表""" + with self.assertRaises(ValueError): + validate_data_type_list(["invalid_type"]) + with self.assertRaises(ValueError): + validate_data_type_list(["actv", "invalid_type"]) + + +class TestPreScanFunctions(unittest.TestCase): + def setUp(self): + # 创建临时目录和测试CSV文件 + self.temp_dir = tempfile.mkdtemp() + self.temp_dir_rank2 = tempfile.mkdtemp() + self.test_csv_path_actv = os.path.join(self.temp_dir, "actv_0-100.csv") + self.test_csv_path_rank2_grad = os.path.join( + self.temp_dir_rank2, "grad_reduced_100-200.csv") + self.test_csv_path_rank_inv = os.path.join( + self.temp_dir_rank2, "invalid_metric_100-200.csv") + + # 创建测试CSV数据 + test_data_actv = { + "name": ["layer1", "layer2"], + "vpp_stage": [0, 0], + "micro_step": [0, 1], + "step": [10, 20], + "min": [0.1, 0.2], + "max": [1.0, 2.0] + } + test_data_grad = { + "name": ["layer1_weight", "layer2_weight"], + "vpp_stage": [0, 0], + "micro_step": [0, 1], + "step": [10, 20], + "min": [0.1, 0.2], + "max": [1.0, 2.0] + } + df = pd.DataFrame(test_data_actv) + df.to_csv(self.test_csv_path_actv, index=False) + df = pd.DataFrame(test_data_grad) + df.to_csv(self.test_csv_path_rank2_grad, index=False) + df = pd.DataFrame(test_data_grad) + df.to_csv(self.test_csv_path_rank_inv, index=False) + + def tearDown(self): + # 清理临时目录 + shutil.rmtree(self.temp_dir) + + def test_pre_scan_single_rank(self): + """测试单个rank的预扫描""" + rank = 0 + files = [self.test_csv_path_actv] + result = _pre_scan_single_rank(rank, files) + self.assertEqual(result["max_rank"], rank) + self.assertEqual(result["metrics"], {"actv"}) + self.assertEqual(result["min_step"], 0) + self.assertEqual(result["max_step"], 100) + self.assertEqual(result["metric_stats"], {"actv": {"min", "max"}}) + self.assertEqual(len(result["targets"]), 2) + + def test_pre_scan(self): + """测试完整预扫描流程""" + # 模拟MonitorDB + mock_db = MagicMock() + + # 测试数据 + data_dirs = {0: self.temp_dir, 2: self.temp_dir_rank2} + data_type_list = ["actv", "grad_reduced"] + + result = _pre_scan(mock_db, data_dirs, data_type_list) + + self.assertEqual(sorted(list(result.keys())), [0, 2]) + + mock_db.insert_dimensions.assert_called_once() + mock_db.update_global_stats.assert_called_with( + max_rank=2, min_step=0, max_step=200 + ) + + +class TestProcessSingleRank(unittest.TestCase): + @patch("msprobe.core.monitor.csv2db.MonitorDB") + @patch("msprobe.core.monitor.csv2db.read_csv") + def test_process_single_rank(self, mock_read_csv, mock_db_class): + """测试处理单个rank的数据""" + # 模拟数据库和映射 + mock_db = MagicMock() + mock_db_class.return_value = mock_db + mock_db.get_metric_table_name.return_value = ( + "metric_1_step_0_99", 0, 99) + mock_db.insert_rows.return_value = 2 + + # 模拟CSV数据 + mock_result = pd.DataFrame({ + "name": ["layer1", "layer2"], + "vpp_stage": [0, 0], + "micro_step": [0, 1], + "step": [10, 20], + "norm": [0.1, 0.2], + "max": [1.0, 2.0] + }) + mock_read_csv.return_value = mock_result + + # 测试数据 + task = (0, ["actv_10-20.csv"]) + metric_id_dict = {"actv": (1, ["norm", "max"])} + target_dict = {("layer1", 0, 0): 1, ("layer2", 0, 1): 2} + step_partition_size = 100 + db_path = "dummy.db" + + result = process_single_rank( + task, metric_id_dict, target_dict, step_partition_size, db_path) + + self.assertEqual(result, 2) + mock_db.insert_rows.assert_called_with( + "metric_1_step_0_99", [(0, 10, 1, 0.1, 1.0), (0, 20, 2, 0.2, 2.0)] + ) + + +class TestImportData(unittest.TestCase): + @patch("msprobe.core.monitor.csv2db._pre_scan") + def test_import_data_success(self, mock_pre_scan): + """测试数据导入成功场景""" + # 模拟预扫描结果 + mock_pre_scan.return_value = { + 0: ["actv_10-20.csv"], 1: ["actv_10-20.csv"]} + + # 模拟数据库 + mock_db = MagicMock() + mock_db.get_metric_mapping.return_value = {"actv": (1, ["min", "max"])} + mock_db.get_target_mapping.return_value = {("layer1", 0, 0): 1} + + # 测试数据 + data_dirs = {0: "dir0", 1: "dir1"} + data_type_list = ["actv"] + workers = 2 + + import_data(mock_db, data_dirs, data_type_list, workers) + + mock_db.init_schema.assert_called_once() + mock_pre_scan.assert_called_once() + + @patch("msprobe.core.monitor.csv2db._pre_scan") + def test_import_data_no_files(self, mock_pre_scan): + """测试没有找到数据文件的情况""" + mock_pre_scan.return_value = {} + + mock_db = MagicMock() + data_dirs = {0: "dir0"} + data_type_list = ["actv"] + + result = import_data(mock_db, data_dirs, data_type_list) + + self.assertFalse(result) + mock_pre_scan.assert_called_once() + + +class TestCSV2DBMain(unittest.TestCase): + @patch("msprobe.core.monitor.csv2db.import_data") + @patch("msprobe.core.monitor.csv2db.get_target_output_dir") + @patch("msprobe.core.monitor.csv2db.create_directory") + def test_csv2db(self, mock_create_dir, mock_get_dirs, mock_import): + """测试主函数csv2db""" + # 模拟配置 + config = CSV2DBConfig( + monitor_path="test_path", + data_type_list=["actv"], + process_num=4, + step_partition=500 + ) + + # 模拟依赖函数 + mock_get_dirs.return_value = {0: "dir0", 1: "dir1"} + mock_import.return_value = True + + csv2db(config) + + mock_get_dirs.assert_called_once() + mock_create_dir.assert_called_once() + mock_import.assert_called_once() diff --git a/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_db_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_db_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2c6ee1fda879434cf213d27f26314ecb506f3958 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_db_utils.py @@ -0,0 +1,269 @@ +import unittest +import os +import re +import tempfile +from collections import OrderedDict +from unittest.mock import patch + +from msprobe.core.common.const import MonitorConst +from msprobe.core.monitor.db_utils import MonitorDB, MonitorSql, update_ordered_dict, get_ordered_stats + + +def normalize_spaces(text): + return re.sub(r'\s+', ' ', text) + + +class TestDBUtils(unittest.TestCase): + def test_update_ordered_dict(self): + """测试update_ordered_dict函数""" + main_dict = OrderedDict([('a', 1), ('b', 2)]) + new_list = ['b', 'c', 'd'] + + result = update_ordered_dict(main_dict, new_list) + + self.assertEqual(list(result.keys()), ['a', 'b', 'c', 'd']) + self.assertEqual(result['a'], 1) + self.assertIsNone(result['c']) + + def test_get_ordered_stats(self): + """测试get_ordered_stats函数""" + test_stats = ['stat2', 'stat1', 'stat3'] + supported_stats = ['stat1', 'stat2', 'stat3', 'stat4'] + + with patch.object(MonitorConst, 'OP_MONVIS_SUPPORTED', supported_stats): + result = get_ordered_stats(test_stats) + + self.assertEqual(result, ['stat1', 'stat2', 'stat3']) + + def test_get_ordered_stats_with_non_iterable(self): + """测试get_ordered_stats处理非可迭代对象""" + result = get_ordered_stats(123) + self.assertEqual(result, []) + + +class TestMonitorSql(unittest.TestCase): + def test_get_table_definition_all_tables(self): + """测试获取所有表定义""" + result = MonitorSql.get_table_definition() + self.assertIsInstance(result, list) + self.assertEqual(len(result), 4) + self.assertTrue(all("CREATE TABLE" in sql for sql in result)) + + def test_get_table_definition_single_table(self): + """测试获取单个表定义""" + for table in ["monitoring_targets", "monitoring_metrics", "metric_stats", "global_stats"]: + result = MonitorSql.get_table_definition(table) + result = normalize_spaces(result) + self.assertIn(f"CREATE TABLE IF NOT EXISTS {table}", result) + + def test_get_table_definition_invalid_table(self): + """测试获取不存在的表定义""" + with self.assertRaises(ValueError): + MonitorSql.get_table_definition("invalid_table") + + def test_get_metric_table_definition_with_partition(self): + """测试带分区的指标表定义""" + stats = ["norm", "max"] + result = MonitorSql.get_metric_table_definition( + "test_metric", stats, [100, 200]) + result = normalize_spaces(result) + self.assertIn("norm REAL DEFAULT NULL", result) + self.assertIn("max REAL DEFAULT NULL", result) + self.assertIn( + "step INTEGER NOT NULL CHECK(step BETWEEN 100 AND 200)", result) + + def test_get_metric_mapping_sql(self): + """测试获取指标映射SQL""" + result = MonitorSql.get_metric_mapping_sql() + result = normalize_spaces(result) + self.assertIn("SELECT m.metric_id, m.metric_name", result) + self.assertIn("GROUP_CONCAT(ms.stat_name)", result) + + +class TestMonitorDB(unittest.TestCase): + def setUp(self): + # 创建临时数据库文件 + self.temp_db = tempfile.NamedTemporaryFile(delete=False, suffix='.db') + self.db_path = self.temp_db.name + self.monitor_db = MonitorDB(self.db_path, step_partition_size=100) + + # 初始化数据库schema + self.monitor_db.init_schema() + + def tearDown(self): + # 关闭并删除临时数据库文件 + if hasattr(self, 'temp_db'): + self.temp_db.close() + os.unlink(self.db_path) + + def test_init_schema(self): + """测试初始化数据库schema""" + # 验证表是否创建成功 + for table in ["monitoring_targets", "monitoring_metrics", "metric_stats", "global_stats"]: + self.assertTrue(self.monitor_db.db_manager.table_exists(table)) + + # 验证全局统计初始值 + results = self.monitor_db.db_manager.select_data( + "global_stats", columns=["stat_name", "stat_value"]) + self.assertEqual(len(results), 4) + self.assertEqual(results[0]['stat_value'], 0) # max_rank + + def test_get_metric_table_name(self): + """测试生成指标表名""" + # 测试分区边界 + self.assertEqual( + self.monitor_db.get_metric_table_name(1, 50), + ("metric_1_step_0_99", 0, 99) + ) + self.assertEqual( + self.monitor_db.get_metric_table_name(1, 100), + ("metric_1_step_100_199", 100, 199) + ) + self.assertEqual( + self.monitor_db.get_metric_table_name(1, 199), + ("metric_1_step_100_199", 100, 199) + ) + + def test_insert_dimensions(self): + """测试插入维度数据""" + targets = OrderedDict() + targets[("layer1", 0, 0)] = None + targets[("layer2", 0, 1)] = None + + metrics = {"metric1", "metric2"} + metric_stats = { + "metric1": {"norm", "max"}, + "metric2": {"min", "max"} + } + + self.monitor_db.insert_dimensions( + targets=targets, + metrics=metrics, + metric_stats=metric_stats, + min_step=0, + max_step=200 + ) + + # 验证目标插入 + target_results = self.monitor_db.db_manager.select_data( + "monitoring_targets", columns=["target_id"]) + self.assertEqual(len(target_results), 2) + + # 验证指标插入 + metric_results = self.monitor_db.db_manager.select_data( + "monitoring_metrics", columns=["metric_id"]) + self.assertEqual(len(metric_results), 2) + + # 验证指标统计关系插入 + stat_results = self.monitor_db.db_manager.select_data( + "metric_stats", columns=["metric_id"]) + self.assertEqual(len(stat_results), 4) # 2 metrics * 2 stats each + + # 验证指标表创建 + self.assertTrue( + self.monitor_db.db_manager.table_exists("metric_1_step_0_99")) + self.assertTrue(self.monitor_db.db_manager.table_exists( + "metric_1_step_100_199")) + self.assertTrue( + self.monitor_db.db_manager.table_exists("metric_2_step_0_99")) + self.assertTrue(self.monitor_db.db_manager.table_exists( + "metric_2_step_100_199")) + + def test_create_metric_table(self): + """测试创建指标表""" + table_name = self.monitor_db.create_metric_table( + metric_id=1, + step=50, + stats=["norm", "max"] + ) + + self.assertEqual(table_name, "metric_1_step_0_99") + self.assertTrue(self.monitor_db.db_manager.table_exists(table_name)) + + def test_update_global_stats(self): + """测试更新全局统计""" + self.monitor_db.update_global_stats( + max_rank=8, + min_step=10, + max_step=1000 + ) + + # 验证更新结果 + results = self.monitor_db.db_manager.select_data( + "global_stats", columns=["stat_name", "stat_value"]) + stats = {row['stat_name']: row['stat_value'] for row in results} + self.assertEqual(stats['max_rank'], 8) + self.assertEqual(stats['min_step'], 10) + self.assertEqual(stats['max_step'], 1000) + + def test_get_metric_mapping(self): + """测试获取指标映射""" + # 先插入测试数据 + self.monitor_db.db_manager.insert_data( + "monitoring_metrics", + [("metric1",), ("metric2",)], + ["metric_name"] + ) + + # 获取metric_id + metric1_id = self.monitor_db._get_metric_id("metric1") + metric2_id = self.monitor_db._get_metric_id("metric2") + + # 插入统计关系 + self.monitor_db.db_manager.insert_data( + "metric_stats", + [(metric1_id, "norm"), (metric1_id, "max"), (metric2_id, "min")], + ["metric_id", "stat_name"] + ) + + # 测试获取映射 + mapping = self.monitor_db.get_metric_mapping() + + self.assertEqual(len(mapping), 2) + self.assertEqual(mapping["metric1"][0], metric1_id) + self.assertEqual(sorted(mapping["metric1"][1]), ["max", "norm"]) + self.assertEqual(mapping["metric2"][1], ["min"]) + + def test_get_target_mapping(self): + """测试获取目标映射""" + # 先插入测试数据 + self.monitor_db.db_manager.insert_data( + "monitoring_targets", + [("target1", 0, 0), ("target2", 0, 1)], + ["target_name", "vpp_stage", "micro_step"] + ) + + # 测试获取映射 + mapping = self.monitor_db.get_target_mapping() + + self.assertEqual(len(mapping), 2) + self.assertIn(("target1", 0, 0), mapping) + self.assertIn(("target2", 0, 1), mapping) + + def test_insert_rows(self): + """测试插入行数据""" + # 先创建测试表 + self.monitor_db.db_manager.execute_sql( + "CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)" + ) + + # 测试插入 + inserted = self.monitor_db.insert_rows( + "test_table", + [(1, "item1"), (2, "item2")] + ) + + self.assertEqual(inserted, 2) + + # 验证数据 + results = self.monitor_db.db_manager.select_data( + "test_table", columns=["id", "name"]) + self.assertEqual(len(results), 2) + + def test_insert_rows_table_not_exists(self): + """测试插入行数据到不存在的表""" + with self.assertRaises(RuntimeError): + self.monitor_db.insert_rows( + "non_existent_table", + [(1, "item1")] + ) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/single_save/test_single_save.py b/debug/accuracy_tools/msprobe/test/core_ut/single_save/test_single_save.py new file mode 100644 index 0000000000000000000000000000000000000000..95110f7878adea55d6eec6a559a1e5ead1ce9b11 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/single_save/test_single_save.py @@ -0,0 +1,106 @@ +import unittest +import os +import shutil +import torch +import torch.nn as nn +import mindspore +import mindspore.nn as mnn +from mindspore import Tensor +from msprobe.core import SingleSave +from msprobe.core import SingleComparator +from msprobe.core.common.file_utils import read_xlsx + + +# 固定随机性 +torch.manual_seed(42) +mindspore.set_seed(42) + + +# 定义 PyTorch 简单网络 +class SimpleTorchNet(nn.Module): + def __init__(self): + super(SimpleTorchNet, self).__init__() + self.fc1 = nn.Linear(10, 5) + self.fc2 = nn.Linear(5, 1) + + def forward(self, x): + x = self.fc1(x) + x = torch.relu(x) + output = self.fc2(x) + return x, output + + +# 定义 MindSpore 简单网络 +class SimpleMindSporeNet(mnn.Cell): + def __init__(self): + super(SimpleMindSporeNet, self).__init__() + self.fc1 = mnn.Dense(10, 5) + self.fc2 = mnn.Dense(5, 1) + + def construct(self, x): + x = self.fc1(x) + x = mindspore.ops.relu(x) + output = self.fc2(x) + return x, output + + +class TestNetworkComparison(unittest.TestCase): + def setUp(self): + self.torch_dump_path = "./torch_dump" + self.mindspore_dump_path = "./mindspore_dump" + self.output_path = "./compare_output" + self.num_test_cases = 5 # 随机测试用例数量 + + def tearDown(self): + if os.path.exists(self.torch_dump_path): + shutil.rmtree(self.torch_dump_path) + if os.path.exists(self.mindspore_dump_path): + shutil.rmtree(self.mindspore_dump_path) + if os.path.exists(self.output_path): + shutil.rmtree(self.output_path) + + def run_torch_network(self): + net = SimpleTorchNet() + saver = SingleSave(self.torch_dump_path, fmk="pytorch") + + for i in range(self.num_test_cases): + # 为每个测试用例生成不同的随机输入 + input_tensor = torch.randn(1, 10) + x, output = net(input_tensor) + saver.save({"output1": x, "output2": output}) + saver.save({"output1": x}) + saver.step() # 每个输入对应一个step + + def run_mindspore_network(self): + net = SimpleMindSporeNet() + SingleSave._instance = None # 重置单例 + saver = SingleSave(self.mindspore_dump_path, fmk="mindspore") + + for i in range(self.num_test_cases): + # 为每个测试用例生成不同的随机输入 + input_tensor = Tensor(mindspore.numpy.randn(1, 10)) + x, output = net(input_tensor) + saver.save({"output1": x, "output2": output}) + saver.save({"output1": x}) + saver.step() # 每个输入对应一个step + + def test_network_comparison(self): + # 运行 PyTorch 网络并保存多组数据 + self.run_torch_network() + + # 运行 MindSpore 网络并保存多组数据 + self.run_mindspore_network() + + # 使用 SingleComparator 进行对比 + SingleComparator.compare(self.torch_dump_path, self.mindspore_dump_path, self.output_path) + + # 验证输出目录是否存在 + self.assertTrue(os.path.exists(self.output_path)) + + output1_xlsx = read_xlsx(os.path.join(self.output_path, "output1.xlsx")) + self.assertEqual(output1_xlsx.columns.tolist(), SingleComparator.result_header) + self.assertEqual(len(output1_xlsx), 10) + + output2_xlsx = read_xlsx(os.path.join(self.output_path, "output2.xlsx")) + self.assertEqual(output2_xlsx.columns.tolist(), SingleComparator.result_header) + self.assertEqual(len(output2_xlsx), 5) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/core_ut/test_hook_manager.py b/debug/accuracy_tools/msprobe/test/core_ut/test_hook_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..0ad390123ca153d933bd559baa61e667c054a013 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/test_hook_manager.py @@ -0,0 +1,194 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading + +import unittest +from unittest.mock import MagicMock, patch + +from msprobe.core.common.const import Const +from msprobe.core.common.runtime import Runtime +from msprobe.core.hook_manager import BaseHookManager + + +class TestBaseHookManager(unittest.TestCase): + class MockBaseHookManager(BaseHookManager): + @property + def _is_recompute(self): + return False + + @staticmethod + def _no_grad_context(): + return MagicMock() + + @staticmethod + def _add_count(name): + pass + + @staticmethod + def _get_count(name): + pass + + @staticmethod + def _process_kwargs_and_output(module, tid, hook_type, kwargs_or_output, output_or_kwargs): + return {"kwargs": kwargs_or_output}, output_or_kwargs + + def build_hook(self): + pass + + def _register_forward_hook(self, module, api_name): + pass + + def _register_backward_hook(self, module, full_backward_name, args): + pass + + def _register_backward_pre_hook(self, module, full_backward_name, output): + pass + + def _get_params_dict(self, module): + return {} + + def _need_exchange(self, module): + return False + + def setUp(self): + self.mock_data_collector = MagicMock() + self.mock_config = MagicMock() + self.mock_config.data_mode = ["all"] + self.manager = self.MockBaseHookManager( + self.mock_data_collector, + self.mock_config + ) + BaseHookManager.inner_switch[threading.get_ident()] = False + BaseHookManager.hook_handle_dict = {} + BaseHookManager.params_grad_info = {} + + def test_init(self): + self.assertEqual(self.manager.data_collector, self.mock_data_collector) + self.assertEqual(self.manager.config, self.mock_config) + + def test_should_execute_hook_conditions(self): + tid = threading.get_ident() + Runtime.is_running = True + BaseHookManager.inner_switch[tid] = False + self.mock_data_collector.data_processor.is_terminated = False + self.assertTrue(self.manager._should_execute_hook(Const.MODULE, tid)) + self.assertTrue(self.manager._should_execute_hook(Const.API, tid)) + self.assertTrue(self.manager._should_execute_hook(Const.API, tid, is_forward=False)) + + Runtime.is_running = False + self.assertFalse(self.manager._should_execute_hook(Const.MODULE, tid)) + self.assertFalse(self.manager._should_execute_hook(Const.API, tid)) + self.assertTrue(self.manager._should_execute_hook(Const.API, tid, is_forward=False)) + + Runtime.is_running = True + BaseHookManager.inner_switch[tid] = True + self.assertFalse(self.manager._should_execute_hook(Const.MODULE, tid)) + self.assertFalse(self.manager._should_execute_hook(Const.API, tid)) + self.assertFalse(self.manager._should_execute_hook(Const.API, tid, is_forward=False)) + + self.mock_data_collector.data_processor.is_terminated = True + BaseHookManager.inner_switch[tid] = False + self.assertFalse(self.manager._should_execute_hook(Const.MODULE, tid)) + self.assertFalse(self.manager._should_execute_hook(Const.API, tid)) + self.assertFalse(self.manager._should_execute_hook(Const.API, tid, is_forward=False)) + + def test_clear_input_kwargs(self): + module = MagicMock() + tid = threading.get_ident() + module.msprobe_input_kwargs[tid] = {"key": "value"} + self.manager._clear_input_kwargs(module, tid) + self.assertFalse(tid in module.msprobe_input_kwargs) + + def test_register_param_hook(self): + module = MagicMock() + params = {"param1": MagicMock(requires_grad=True)} + full_name = "module.forward" + + with patch.object(self.manager, '_build_grad_hook') as mock_build: + self.manager._register_param_hook(full_name, module, params) + + self.assertEqual(len(BaseHookManager.hook_handle_dict), 1) + self.assertTrue("module.param1" in BaseHookManager.hook_handle_dict) + + self.assertEqual(module.params_grad_name, "module.parameters_grad") + + def test_init_params_grad_info(self): + module = MagicMock() + module.params_grad_name = "grad_name" + params = {"param1": MagicMock(requires_grad=True)} + + self.manager._init_params_grad_info(module, params) + self.mock_data_collector.handle_data.assert_called() + self.assertTrue(BaseHookManager.params_grad_info.get("grad_name")) + + self.manager._init_params_grad_info(module, params) + self.mock_data_collector.handle_data.assert_called_once() + + @patch.object(BaseHookManager, "_should_execute_hook") + def test_forward_pre_hook_behavior(self, mock_should_execute_hook): + mock_should_execute_hook.return_value = True + hook = self.manager._build_forward_pre_hook(Const.API, "api_name") + module = MagicMock() + module.msprobe_input_kwargs = {"kwarg": "value"} + args = (1, 2) + + Runtime.is_running = True + self.mock_data_collector.data_processor.is_terminated = False + with patch.object(self.manager, '_no_grad_context') as mock_ctx: + hook(module, args) + self.mock_data_collector.forward_input_data_collect.assert_called_once() + + @patch.object(BaseHookManager, "_should_execute_hook") + def test_forward_hook_behavior(self, mock_should_execute_hook): + mock_should_execute_hook.return_value = True + hook = self.manager._build_forward_hook(Const.MODULE, "module_name") + module = MagicMock() + args = (1, 2) + kwargs = {"kwargs": []} + output = MagicMock() + + self.mock_data_collector.if_return_forward_new_output.return_value = False + with patch.object(self.manager, '_get_params_dict', return_value={}): + result = hook(module, args, kwargs, output) + self.assertEqual(result, output) + self.mock_data_collector.forward_data_collect.assert_called_once() + self.mock_data_collector.get_forward_new_output.assert_not_called() + + self.mock_data_collector.if_return_forward_new_output.return_value = True + self.mock_data_collector.get_forward_new_output.return_value = "new_output" + with patch.object(self.manager, '_get_params_dict', return_value={}): + result = hook(module, args, output) + self.assertEqual(result, "new_output") + + @patch.object(BaseHookManager, "_should_execute_hook") + def test_backward_hook_behavior(self, mock_should_execute_hook): + mock_should_execute_hook.return_value = True + hook = self.manager._build_backward_hook(Const.API, "api_name") + module = MagicMock() + grad_input = (MagicMock(),) + grad_output = (MagicMock(),) + + module.forward_data_collected = True + Runtime.is_running = True + hook(module, grad_input, grad_output) + + self.mock_data_collector.backward_data_collect.assert_called_once() + + with patch.object(self.manager, '_need_exchange', return_value=True): + hook(module, grad_input, grad_output) + args, _ = self.mock_data_collector.backward_data_collect.call_args_list[1] + self.assertEqual(args[3].grad_input, grad_output) + self.assertEqual(args[3].grad_output, grad_input) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/test_service.py b/debug/accuracy_tools/msprobe/test/core_ut/test_service.py new file mode 100644 index 0000000000000000000000000000000000000000..5a241790055a8a68b1a0c8c7c7eb21c6bacd4e1d --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/test_service.py @@ -0,0 +1,393 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import MagicMock, patch +import os +import tempfile + +from msprobe.core.service import BaseService +from msprobe.core.common.utils import Const +from msprobe.core.common.runtime import Runtime +from msprobe.core.data_dump.api_registry import ApiRegistry +from msprobe.core.hook_manager import BaseHookManager + + +class ConcreteBaseService(BaseService): + def _init_specific_components(self): + self.logger = MagicMock() + self.api_register = MagicMock() + self.hook_manager = MagicMock() + self.api_template = MagicMock() + + def _register_hook(self): + pass + + def _register_module_hook(self): + pass + + def _get_framework_type(self): + return "TestFramework" + + @staticmethod + def _get_current_rank(): + return 0 + + def _change_jit_switch(self, status): + pass + +class TestBaseService(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + self.config = MagicMock() + self.config.level = Const.LEVEL_DEBUG + self.config.level_ori = self.config.level + self.config.step = [1, 3] + self.config.rank = [0, 2] + self.config.dump_path = self.temp_dir.name + self.config.task = Const.STATISTICS + self.config.async_dump = True + self.config.tensor_list = [] + self.config.framework = "test_framwork" + + with patch('msprobe.core.service.build_data_collector'): + self.service = ConcreteBaseService(self.config) + + def tearDown(self): + self.temp_dir.cleanup() + + def test_initialization(self): + self.assertEqual(self.service.config.level, Const.LEVEL_DEBUG) + self.assertIsNone(self.service.model) + self.assertIsNotNone(self.service.data_collector) + self.assertEqual(self.service.current_iter, 0) + self.assertEqual(self.service.loop, 0) + self.assertTrue(self.service.first_start) + self.assertFalse(self.service.primitive_switch) + self.assertIsNone(self.service.current_rank) + self.assertIsNone(self.service.dump_iter_dir) + self.assertFalse(self.service.should_stop_service) + self.assertTrue(self.service.currrent_step_first_debug_save) + self.assertEqual(self.service.ori_customer_func, {}) + + def test_properties(self): + self.service.config.level = Const.LEVEL_DEBUG + self.assertTrue(self.service._is_debug_level) + + self.service.config.level = Const.LEVEL_L2 + self.assertTrue(self.service._is_l2_level) + + self.service.config.level = Const.LEVEL_MIX + self.assertTrue(self.service._is_mix_level) + + self.service.config.level = Const.LEVEL_MIX + self.assertTrue(self.service._is_need_module_hook) + + self.service.config.level = Const.LEVEL_MIX + self.assertTrue(self.service._is_need_api_hook) + + self.assertFalse(self.service._need_tensor_data) + + self.service.current_iter = 2 + self.assertTrue(self.service._is_no_dump_step) + + self.service.current_rank = 1 + self.assertTrue(self.service._is_no_dump_rank) + + @patch.object(BaseService, '_get_current_rank') + @patch.object(BaseService, '_process_iteration') + def test_start_debug_level(self, mock_process_iter, mock_get_rank): + self.service.config.level = Const.LEVEL_DEBUG + model_mock = MagicMock() + + self.service.start(model=model_mock) + + mock_get_rank.assert_not_called() + mock_process_iter.assert_called_once() + self.service.logger.info.assert_not_called() + self.assertFalse(Runtime.is_running) + + + @patch.object(ConcreteBaseService, '_register_hook') + @patch.object(ConcreteBaseService, '_register_module_hook') + def test_start_normal_level_first_time(self, mock_register_module_hook, mock_register_hook): + self.service.config.level = Const.LEVEL_MIX + self.service.config.step = [] + self.service.config.rank = [] + model_mock = MagicMock() + self.service.data_collector.data_processor.is_terminated = False + self.service.start(model=model_mock) + + self.assertEqual(self.service.current_rank, 0) + self.assertEqual(Runtime.current_rank, 0) + + mock_register_hook.assert_called_once() + mock_register_module_hook.assert_called_once() + + self.service.logger.info.assert_called_with(f"Dump data will be saved in {self.service.dump_iter_dir}.") + self.assertTrue(Runtime.is_running) + self.assertTrue(self.service.primitive_switch) + self.assertFalse(self.service.first_start) + + @patch.object(ConcreteBaseService, '_register_hook') + @patch.object(ConcreteBaseService, '_register_module_hook') + @patch.object(ConcreteBaseService, 'create_dirs') + def test_start_not_first_calls(self, mock_dirs, mock_register_module_hook, mock_register_hook): + self.service.config.level = Const.LEVEL_L1 + self.service.config.step = [] + self.service.config.rank = [] + self.service.data_collector.data_processor.is_terminated = False + self.service.first_start = False + model_mock = MagicMock() + + self.service.start(model=model_mock) + mock_register_hook.assert_not_called() + mock_register_module_hook.assert_not_called() + self.assertTrue(Runtime.is_running) + self.assertTrue(self.service.primitive_switch) + mock_dirs.assert_called_once() + + def test_start_with_infer_hook(self): + self.service.config.level = Const.LEVEL_L1 + self.service.config.step = [] + self.service.config.rank = [] + self.service.data_collector.data_processor.is_terminated = False + model_mock = MagicMock() + token_range = [10, 20] + + self.service.start(model=model_mock, token_range=token_range) + model_mock.register_forward_pre_hook.assert_called_once() + self.assertEqual(self.service.cur_token_id, 0) + + def test_stop_debug_level(self): + self.config.level = Const.LEVEL_DEBUG + self.service.stop() + self.service.logger.info.assert_not_called() + + @patch.object(BaseService, '_process_async_dump') + def test_stop_normal_level(self, mock_process_async_dump): + self.service.config.level = Const.LEVEL_L1 + self.service.current_iter = 1 + self.service.current_rank = 0 + + self.service.stop() + self.assertFalse(Runtime.is_running) + self.assertFalse(self.service.primitive_switch) + + self.service.logger.info.assert_called_with( + f"{Const.TOOL_NAME}: debugger.stop() is set successfully. " + "Please set debugger.start() to turn on the dump switch again. " + ) + mock_process_async_dump.assert_called_once() + self.service.data_collector.write_json.assert_called_once() + + def test_stop_no_dump_step(self): + self.config.level = Const.LEVEL_L1 + self.service.current_iter = 2 + self.service.stop() + self.service.logger.info.assert_not_called() + + def test_stop_no_dump_rank(self): + self.config.level = Const.LEVEL_L1 + self.service.current_iter = 1 + self.service.current_rank = 1 + self.service.stop() + self.service.logger.info.assert_not_called() + + @patch.object(BaseService, '_process_async_dump') + def test_step(self, mock_process_async_dump): + self.service.step() + self.assertEqual(self.service.loop, 1) + self.assertTrue(self.service.currrent_step_first_debug_save) + mock_process_async_dump.assert_called_once() + self.service.data_collector.write_json.assert_called_once() + self.service.data_collector.reset_status.assert_called_once() + + @patch.object(BaseService, '_process_async_dump') + def test_step_should_stop_service(self, mock_process_async_dump): + self.service.should_stop_service = True + self.service.step() + self.assertEqual(self.service.loop, 0) + mock_process_async_dump.assert_not_called() + + def test_save_debug_level(self): + self.service.loop = 1 + self.service.init_step = 0 + self.service.save("test_var", "test_name", True) + self.service.data_collector.debug_data_collect_forward.assert_called_with("test_var", "test_name.0") + self.service.data_collector.debug_data_collect_backward.assert_called_with("test_var", "test_name_grad.0") + + def test_save_not_debug_level(self): + self.service.config.level = Const.LEVEL_L0 + self.service.loop = 1 + self.service.init_step = 0 + self.service.save("test_var", "test_name", True) + self.service.data_collector.debug_data_collect_forward.assert_not_called() + + def test_save_no_dump_step(self): + self.config.level = Const.LEVEL_DEBUG + self.service.current_iter = 2 + self.service.save("test_var", "test_name", True) + self.service.data_collector.debug_data_collect_forward.assert_not_called() + + def test_save_first_time_in_step(self): + self.service.config.level = Const.LEVEL_DEBUG + self.service.loop = 1 + self.service.init_step = 0 + + self.service.save("test_var", "test_name", True) + + self.assertEqual(self.service.current_rank, 0) + self.assertFalse(self.service.currrent_step_first_debug_save) + self.assertEqual(self.service.debug_variable_counter, {"test_name": 1}) + + self.assertIsNotNone(self.service.dump_iter_dir) + self.assertTrue(os.path.exists(self.service.dump_iter_dir)) + + @patch.object(ApiRegistry, 'register_custom_api') + def test_register_and_restore_custom_api(self, mock_register_custom_api): + module_mock = MagicMock() + api_name = "test_api" + api_prefix = "test_prefix" + self.service.register_custom_api(module_mock, api_name, api_prefix) + key = f"{str(module_mock)}{Const.SEP}{api_name}" + self.assertIn(key, self.service.ori_customer_func) + mock_register_custom_api.assert_called_once() + self.service.restore_custom_api(module_mock, api_name) + self.assertEqual(module_mock.test_api, self.service.ori_customer_func.get(key)) + + def test_build_hook(self): + hook = self.service.build_hook("test_type", "test_name") + self.service.hook_manager.build_hook.assert_called_with("test_type", "test_name") + + def test_create_dirs_pynative_graph(self): + Runtime.run_mode = Const.PYNATIVE_GRAPH_MODE + self.service.current_iter = 1 + self.service.current_rank = 0 + + self.service.create_dirs() + + expected_dir = os.path.join(self.config.dump_path, Const.PYNATIVE_MODE, "step1", "rank0") + self.assertEqual( + self.service.dump_iter_dir, os.path.join(self.config.dump_path, Const.PYNATIVE_MODE, "step1")) + self.assertTrue(os.path.exists(expected_dir)) + + self.service.data_collector.update_dump_paths.assert_called() + self.service.data_collector.initialize_json_file.assert_called() + + def test_create_dirs_pynative_mode(self): + Runtime.run_mode = Const.PYNATIVE_MODE + self.service.current_iter = 1 + self.service.current_rank = 0 + self.service.create_dirs() + expected_dir = os.path.join(self.config.dump_path, "step1", "rank0") + self.assertEqual(self.service.dump_iter_dir, os.path.join(self.config.dump_path, "step1")) + self.assertTrue(os.path.exists(expected_dir)) + + def test_create_dirs_l2_level(self): + self.service.config.level = Const.LEVEL_L2 + self.service.current_iter = 1 + self.service.current_rank = 0 + self.service.create_dirs() + expected_dir = os.path.join(self.config.dump_path, "step1") + self.assertEqual(self.service.dump_iter_dir, expected_dir) + self.assertTrue(os.path.exists(expected_dir)) + + kernel_config_path = os.path.join(expected_dir, "kernel_config_0.json") + self.assertTrue(os.path.exists(kernel_config_path)) + self.assertEqual(self.service.config.kernel_config_path, kernel_config_path) + + def test_need_stop_service_conditions(self): + self.service.current_iter = 4 + self.service.config.step = [1, 2, 3] + self.assertTrue(self.service._need_stop_service()) + self.assertFalse(Runtime.is_running) + self.assertFalse(self.service.primitive_switch) + + self.service.current_iter = 1 + self.service.data_collector.data_processor.is_terminated = True + self.assertTrue(self.service._need_stop_service()) + + self.service.data_collector.data_processor.is_terminated = False + self.service.should_stop_service = False + self.service.current_iter = 1 + self.service.config.step = [1, 2, 3] + self.assertFalse(self.service._need_stop_service()) + + def test_register_api_hook(self): + self.service.config.level = Const.LEVEL_MIX + self.service._register_api_hook() + self.service.api_register.initialize_hook.assert_called() + self.service.api_register.register_all_api.assert_called() + self.service.logger.info.assert_called_with( + f"The api {self.config.task} hook function is successfully mounted to the model." + ) + + def test_register_infer_count_hook(self): + model_mock = MagicMock() + token_range = [5, 10] + + self.service._register_infer_count_hook(model_mock, token_range) + + model_mock.register_forward_pre_hook.assert_called_once() + + hook = model_mock.register_forward_pre_hook.call_args[0][0] + + self.service.cur_token_id = 4 + hook(model_mock, None) + self.assertFalse(Runtime.is_running) + + self.service.cur_token_id = 5 + hook(model_mock, None) + self.assertTrue(Runtime.is_running) + + self.service.cur_token_id = 7 + hook(model_mock, None) + self.assertTrue(Runtime.is_running) + + self.service.cur_token_id = 11 + hook(model_mock, None) + self.assertFalse(Runtime.is_running) + + def test_process_iteration(self): + self.service.loop = 5 + self.service.init_step = 10 + self.service._process_iteration() + + self.assertEqual(self.service.current_iter, 15) + self.assertEqual(Runtime.current_iter, 15) + self.service.data_collector.update_iter.assert_called_with(15) + + def test_process_async_dump(self): + self.service.config.async_dump = True + self.service.config.task = Const.STATISTICS + self.service._process_async_dump() + + self.service.data_collector.data_processor.dump_async_data.assert_called_once() + + def test_process_async_dump_not_needed(self): + self.service.config.async_dump = False + self.service._process_async_dump() + self.service.data_collector.data_processor.dump_async_data.assert_not_called() + + self.service.config.task = Const.OVERFLOW_CHECK + self.service._process_async_dump() + self.service.data_collector.data_processor.dump_async_data.assert_not_called() + + def test_reset_status(self): + self.service._reset_status() + self.service.data_collector.reset_status.assert_called_once() + self.assertEqual(BaseHookManager.params_grad_info, {}) diff --git a/debug/accuracy_tools/msprobe/test/cpp/include/test_utils.cpp b/debug/accuracy_tools/msprobe/test/cpp/include/test_utils.cpp index e744233b3199c15f5ce77b4690bbaa523b0bad45..08dddbed6b7b0c83691826998a2291bb99a40990 100644 --- a/debug/accuracy_tools/msprobe/test/cpp/include/test_utils.cpp +++ b/debug/accuracy_tools/msprobe/test/cpp/include/test_utils.cpp @@ -2,7 +2,6 @@ #include #include #include -#include std::string TEST_ExecShellCommand(const std::string& cmd) { @@ -18,11 +17,10 @@ std::string TEST_ExecShellCommand(const std::string& cmd) return result; } -std::string trim(const std::string& str) +std::string Trim(const std::string& str) { std::string::size_type first = str.find_first_not_of(" \t\n\r\f\v"); std::string::size_type last = str.find_last_not_of(" \t\n\r\f\v"); - if (first == std::string::npos || last == std::string::npos) { return ""; } diff --git a/debug/accuracy_tools/msprobe/test/cpp/include/test_utils.hpp b/debug/accuracy_tools/msprobe/test/cpp/include/test_utils.hpp index ed842b87db77e75e618acd7a25949145a1578c37..08326522a9b06b671e62b5ecacbcc722f485f439 100644 --- a/debug/accuracy_tools/msprobe/test/cpp/include/test_utils.hpp +++ b/debug/accuracy_tools/msprobe/test/cpp/include/test_utils.hpp @@ -5,4 +5,4 @@ #define CONFIG_EXAMPLE __RESOURCES_PATH__"/config.json" std::string TEST_ExecShellCommand(const std::string& cmd); -std::string trim(const std::string& str); +std::string Trim(const std::string& str); diff --git a/debug/accuracy_tools/msprobe/test/cpp/test_config.cpp b/debug/accuracy_tools/msprobe/test/cpp/test_config.cpp index e8b9b73fb66c3fcae40819545c84b7fafb5d2c4d..36033c06a7336ec673eacf24c0c964b6a7719b59 100644 --- a/debug/accuracy_tools/msprobe/test/cpp/test_config.cpp +++ b/debug/accuracy_tools/msprobe/test/cpp/test_config.cpp @@ -2,14 +2,14 @@ #include "gtest/gtest.h" #include "nlohmann/json.hpp" #include "test_utils.hpp" -#include "base/ErrorInfos.hpp" -#include "base/DebuggerConfig.hpp" +#include "base/ErrorInfosManager.h" +#include "base/DebuggerConfig.h" using namespace MindStudioDebugger; namespace MsProbeTest { -static const std::string cfgContent = R"({ +static const std::string CFG_CONTENT = R"({ "task": "statistics", "dump_path": "./dump_path", "rank": [], @@ -104,7 +104,7 @@ void TestConfigMindSpore::SetUp() DebuggerConfig::GetInstance().Reset(); CleanErrorInfoCache(); ErrorInfosManager::SetLogPath(logpath); - cfgJson = nlohmann::json::parse(cfgContent); + cfgJson = nlohmann::json::parse(CFG_CONTENT); } void TestConfigMindSpore::TearDown() @@ -173,7 +173,7 @@ TEST_F(TestConfigMindSpore, TestCommonCfg) ASSERT_EQ(DumpCfgFile(), 0); EXPECT_EQ(cfg.LoadConfig(framework, cfgPath), 0); EXPECT_EQ(cfg.GetTaskList(), std::vector({DebuggerTaskType::TASK_DUMP_STATISTICS})); - EXPECT_EQ(cfg.GetOutputPath(), trim(TEST_ExecShellCommand("realpath ./output1"))); + EXPECT_EQ(cfg.GetOutputPath(), Trim(TEST_ExecShellCommand("realpath ./output1"))); EXPECT_EQ(cfg.GetRankRange(), std::vector({0, 1, 8})); EXPECT_EQ(cfg.GetStepRange(), std::vector({2, 4, 6, 7, 8})); EXPECT_EQ(cfg.GetDebugLevel(), DebuggerLevel::L2); diff --git a/debug/accuracy_tools/msprobe/test/cpp/test_cpython_utils.cpp b/debug/accuracy_tools/msprobe/test/cpp/test_cpython_utils.cpp index 0d9188878c0864d66d76cc3a823b0a0a5cf644d5..8bb5af7123f41fb42091c2cb21d394bce2b1af8d 100644 --- a/debug/accuracy_tools/msprobe/test/cpp/test_cpython_utils.cpp +++ b/debug/accuracy_tools/msprobe/test/cpp/test_cpython_utils.cpp @@ -2,7 +2,7 @@ #include #include "test_utils.hpp" -#include "utils/CPythonUtils.hpp" +#include "utils/CPythonUtils.h" using namespace MindStudioDebugger; using namespace MindStudioDebugger::CPythonUtils; @@ -56,79 +56,79 @@ TEST_F(CPythonUtilsTest, CPythonAgent) { TEST_F(CPythonUtilsTest, PythonObjectFromTo) { // 测试PythonObject的From和To函数 - int32_t input_int = -42; - PythonObject obj_int = PythonObject::From(input_int); - EXPECT_TRUE(obj_int.IsNumber()); + int32_t inputInt = -42; + PythonObject objInt = PythonObject::From(inputInt); + EXPECT_TRUE(objInt.IsNumber()); - int32_t output_int; - EXPECT_EQ(obj_int.To(output_int), 0); - EXPECT_EQ(output_int, input_int); + int32_t outputInt; + EXPECT_EQ(objInt.To(outputInt), 0); + EXPECT_EQ(outputInt, inputInt); - uint32_t input_uint = 56; - PythonObject obj_uint = PythonObject::From(input_uint); - EXPECT_TRUE(obj_uint.IsNumber()); + uint32_t inputUint = 56; + PythonObject objUint = PythonObject::From(inputUint); + EXPECT_TRUE(objUint.IsNumber()); - uint32_t output_uint; - EXPECT_EQ(obj_uint.To(output_uint), 0); - EXPECT_EQ(output_uint, input_uint); + uint32_t outputUint; + EXPECT_EQ(objUint.To(outputUint), 0); + EXPECT_EQ(outputUint, inputUint); - double input_double = 3.14; - PythonObject obj_double = PythonObject::From(input_double); - EXPECT_TRUE(obj_double.IsNumber()); + double inputDouble = 3.14; + PythonObject objDouble = PythonObject::From(inputDouble); + EXPECT_TRUE(objDouble.IsNumber()); - double output_double; - EXPECT_EQ(obj_double.To(output_double), 0); - EXPECT_DOUBLE_EQ(output_double, input_double); + double outputDouble; + EXPECT_EQ(objDouble.To(outputDouble), 0); + EXPECT_DOUBLE_EQ(outputDouble, inputDouble); - std::string input_str = "hello"; - PythonObject obj_str = PythonObject::From(input_str); - EXPECT_TRUE(obj_str.IsString()); + std::string inputStr = "hello"; + PythonObject objStr = PythonObject::From(inputStr); + EXPECT_TRUE(objStr.IsString()); - std::string output_str; - EXPECT_EQ(obj_str.To(output_str), 0); - EXPECT_EQ(output_str, input_str); + std::string outputStr; + EXPECT_EQ(objStr.To(outputStr), 0); + EXPECT_EQ(outputStr, inputStr); - const char* input_char = "world"; - PythonObject obj_str1 = PythonObject::From(input_char); - EXPECT_TRUE(obj_str1.IsString()); + const char* inputChar = "world"; + PythonObject objStr1 = PythonObject::From(inputChar); + EXPECT_TRUE(objStr1.IsString()); - EXPECT_EQ(obj_str1.To(output_str), 0); - EXPECT_EQ(output_str, std::string(input_char)); + EXPECT_EQ(objStr1.To(outputStr), 0); + EXPECT_EQ(outputStr, std::string(inputChar)); - bool input_bool = true; - PythonObject obj_bool = PythonObject::From(input_bool); - EXPECT_TRUE(obj_bool.IsBool()); + bool inputBool = true; + PythonObject objBool = PythonObject::From(inputBool); + EXPECT_TRUE(objBool.IsBool()); - bool output_bool; - EXPECT_EQ(obj_bool.To(output_bool), 0); - EXPECT_EQ(output_bool, input_bool); + bool outputBool; + EXPECT_EQ(objBool.To(outputBool), 0); + EXPECT_EQ(outputBool, inputBool); - std::vector input_vector_int = {1, 2, 3, 100}; - PythonObject list_int_obj = PythonObject::From(input_vector_int); - EXPECT_TRUE(list_int_obj.IsList()); + std::vector inputVectorInt = {1, 2, 3, 100}; + PythonObject listIntObj = PythonObject::From(inputVectorInt); + EXPECT_TRUE(listIntObj.IsList()); - std::vector output_vector_int; - EXPECT_EQ(list_int_obj.To(output_vector_int), 0); + std::vector outputVectorInt; + EXPECT_EQ(listIntObj.To(outputVectorInt), 0); - size_t size = input_vector_int.size(); - EXPECT_EQ(size, output_vector_int.size()); + size_t size = inputVectorInt.size(); + EXPECT_EQ(size, outputVectorInt.size()); for (size_t i = 0; i < size; ++i) { - EXPECT_EQ(input_vector_int[i], output_vector_int[i]); + EXPECT_EQ(inputVectorInt[i], outputVectorInt[i]); } - std::vector input_vector_str = {"a", "bb", "ccc", "dddd"}; - PythonObject list_str_obj = PythonObject::From(input_vector_str); - EXPECT_TRUE(list_str_obj.IsList()); + std::vector inputVectorStr = {"a", "bb", "ccc", "dddd"}; + PythonObject listStrObj = PythonObject::From(inputVectorStr); + EXPECT_TRUE(listStrObj.IsList()); - std::vector output_vector_str; - EXPECT_EQ(list_str_obj.To(output_vector_str), 0); + std::vector outputVectorStr; + EXPECT_EQ(listStrObj.To(outputVectorStr), 0); - size = input_vector_str.size(); - EXPECT_EQ(size, output_vector_str.size()); + size = inputVectorStr.size(); + EXPECT_EQ(size, outputVectorStr.size()); for (size_t i = 0; i < size; ++i) { - EXPECT_EQ(input_vector_str[i], output_vector_str[i]); + EXPECT_EQ(inputVectorStr[i], outputVectorStr[i]); } } @@ -199,18 +199,18 @@ TEST_F(CPythonUtilsTest, PythonNumberObject) { PythonNumberObject o5(PythonObject::From(4.44)); PythonNumberObject o6(PythonObject::From("1111")); - int int_v; - EXPECT_EQ(o1.To(int_v), 0); - EXPECT_EQ(int_v, 123); - double double_v; - EXPECT_EQ(o2.To(double_v), 0); - EXPECT_TRUE(std::fabs(double_v - 3.14) < 1e-5); - EXPECT_EQ(o3.To(int_v), 0); - EXPECT_EQ(int_v, 321); - EXPECT_EQ(o4.To(double_v), 0); - EXPECT_TRUE(std::fabs(double_v - 2.33) < 1e-5); - EXPECT_EQ(o5.To(double_v), 0); - EXPECT_TRUE(std::fabs(double_v - 4.44) < 1e-5); + int intV; + EXPECT_EQ(o1.To(intV), 0); + EXPECT_EQ(intV, 123); + double doubleV; + EXPECT_EQ(o2.To(doubleV), 0); + EXPECT_TRUE(std::fabs(doubleV - 3.14) < 1e-5); + EXPECT_EQ(o3.To(intV), 0); + EXPECT_EQ(intV, 321); + EXPECT_EQ(o4.To(doubleV), 0); + EXPECT_TRUE(std::fabs(doubleV - 2.33) < 1e-5); + EXPECT_EQ(o5.To(doubleV), 0); + EXPECT_TRUE(std::fabs(doubleV - 4.44) < 1e-5); EXPECT_TRUE(o6.IsNone()); } diff --git a/debug/accuracy_tools/msprobe/test/cpp/test_data_utils.cpp b/debug/accuracy_tools/msprobe/test/cpp/test_data_utils.cpp index 11442f12bfea9179ecd4e2e357bcf70b4212ab84..dd6325c2183332bfa0cc2acc592301f7cf58bda8 100644 --- a/debug/accuracy_tools/msprobe/test/cpp/test_data_utils.cpp +++ b/debug/accuracy_tools/msprobe/test/cpp/test_data_utils.cpp @@ -2,7 +2,7 @@ #include #include #include -#include "utils/DataUtils.hpp" +#include "utils/DataUtils.h" using namespace MindStudioDebugger; using namespace MindStudioDebugger::DataUtils; @@ -10,15 +10,15 @@ using namespace MindStudioDebugger::DataUtils; namespace MsProbeTest { TEST(DataUtilsTest, TestUnpackUint64Value) { - uint64_t data_le = 0x0102030405060708; - uint64_t result = UnpackUint64Value_Le(&data_le); + uint64_t dataLe = 0x0102030405060708; + uint64_t result = UnpackUint64ValueLe(&dataLe); #if __BYTE_ORDER == __LITTLE_ENDIAN EXPECT_EQ(result, 0x0102030405060708); #else EXPECT_EQ(result, 0x0807060504030201); #endif - uint64_t data_be = 0x0102030405060708; - result = UnpackUint64Value_Be(&data_be); + uint64_t dataBe = 0x0102030405060708; + result = UnpackUint64ValueBe(&dataBe); #if __BYTE_ORDER == __LITTLE_ENDIAN EXPECT_EQ(result, 0x0807060504030201); #else @@ -74,7 +74,7 @@ TEST(DataUtilsTest, TestGetFormatString) { EXPECT_EQ(GetFormatString(TensorFormat::FORMAT_FRACTAL_Z), "FRACTAL_Z"); EXPECT_EQ(GetFormatString(TensorFormat::FORMAT_C1HWNC0), "C1HWNC0"); EXPECT_EQ(GetFormatString(TensorFormat::FORMAT_HWCN), "HWCN"); - EXPECT_EQ(GetFormatString(TensorFormat::FORMAT_C1HWNCoC0), "C1HWNCoC0"); + EXPECT_EQ(GetFormatString(TensorFormat::FORMAT_C1HWNCOC0), "C1HWNCoC0"); EXPECT_EQ(GetFormatString(TensorFormat::FORMAT_DHWNC), "DHWNC"); EXPECT_EQ(GetFormatString(TensorFormat::FORMAT_NCL), "NCL"); EXPECT_EQ(GetFormatString(TensorFormat::FORMAT_MAX), "UNKNOWN"); diff --git a/debug/accuracy_tools/msprobe/test/cpp/test_environ.cpp b/debug/accuracy_tools/msprobe/test/cpp/test_environ.cpp index 94c830227ae58637642a189f36ade78de9a2a75c..be30c5c219ce1b34b92a1453eeb2050479bc7b97 100644 --- a/debug/accuracy_tools/msprobe/test/cpp/test_environ.cpp +++ b/debug/accuracy_tools/msprobe/test/cpp/test_environ.cpp @@ -2,8 +2,8 @@ #include #include "include/test_utils.hpp" -#include "base/DebuggerConfig.hpp" -#include "base/Environment.hpp" +#include "base/DebuggerConfig.h" +#include "base/Environment.h" using namespace MindStudioDebugger; using namespace MindStudioDebugger::Environment; diff --git a/debug/accuracy_tools/msprobe/test/cpp/test_file_operation.cpp b/debug/accuracy_tools/msprobe/test/cpp/test_file_operation.cpp index 2886126e9f568fba6b8ce3eabd752653d4493108..99dbe8124d17fe7cdcfd1f812d2132f416cea1ea 100644 --- a/debug/accuracy_tools/msprobe/test/cpp/test_file_operation.cpp +++ b/debug/accuracy_tools/msprobe/test/cpp/test_file_operation.cpp @@ -4,8 +4,8 @@ #include #include "test_utils.hpp" -#include "utils/DataUtils.hpp" -#include "utils/FileOperation.hpp" +#include "utils/DataUtils.h" +#include "utils/FileOperation.h" using namespace MindStudioDebugger; using namespace MindStudioDebugger::FileOperation; diff --git a/debug/accuracy_tools/msprobe/test/cpp/test_file_utils.cpp b/debug/accuracy_tools/msprobe/test/cpp/test_file_utils.cpp index 03449f761be0c8548021218581f4cbff12d4e07d..022ae396ba3d7a8343792a438162a74ff526fc75 100644 --- a/debug/accuracy_tools/msprobe/test/cpp/test_file_utils.cpp +++ b/debug/accuracy_tools/msprobe/test/cpp/test_file_utils.cpp @@ -8,7 +8,7 @@ #include #include "test_utils.hpp" -#include "utils/FileUtils.hpp" +#include "utils/FileUtils.h" using namespace MindStudioDebugger; using namespace MindStudioDebugger::FileUtils; @@ -52,7 +52,7 @@ TEST_F(FileUtilsTest, TestIsPathExist) TEST_F(FileUtilsTest, TestGetAbsPath) { - std::string pwd = trim(TEST_ExecShellCommand("pwd")); + std::string pwd = Trim(TEST_ExecShellCommand("pwd")); EXPECT_EQ(pwd, GetAbsPath(".")); EXPECT_EQ(pwd + "/testpath", GetAbsPath("./testpath")); EXPECT_EQ(pwd + "/testpath", GetAbsPath("./testpath/")); @@ -210,8 +210,8 @@ TEST_F(FileUtilsTest, TestIsPathLengthLegal) TEST_F(FileUtilsTest, TestIsPathDepthValid) { EXPECT_TRUE(IsPathDepthValid("")); - EXPECT_TRUE(IsPathDepthValid(std::string(PATH_DEPTH_MAX, pathSeparator))); - EXPECT_FALSE(IsPathDepthValid(std::string(PATH_DEPTH_MAX + 1, pathSeparator))); + EXPECT_TRUE(IsPathDepthValid(std::string(PATH_DEPTH_MAX, PATH_SEPARATOR))); + EXPECT_FALSE(IsPathDepthValid(std::string(PATH_DEPTH_MAX + 1, PATH_SEPARATOR))); } TEST_F(FileUtilsTest, TestIsFileOwner) diff --git a/debug/accuracy_tools/msprobe/test/cpp/test_log.cpp b/debug/accuracy_tools/msprobe/test/cpp/test_log.cpp index 254b54359a50166e1d893c5b936eb220ee0b2a73..ddf6950fd5b6ff0c7f191e5aa0f8e79897db6c7c 100644 --- a/debug/accuracy_tools/msprobe/test/cpp/test_log.cpp +++ b/debug/accuracy_tools/msprobe/test/cpp/test_log.cpp @@ -2,7 +2,7 @@ #include "gtest/gtest.h" #include "test_utils.hpp" -#include "base/ErrorInfos.hpp" +#include "base/ErrorInfosManager.h" using namespace MindStudioDebugger; diff --git a/debug/accuracy_tools/msprobe/test/cpp/test_math_utils.cpp b/debug/accuracy_tools/msprobe/test/cpp/test_math_utils.cpp index 3b23e9c879c431ef7457990ba774aa0dc1321b45..8e57d2cd53b5c52ca62084e8ecbd49e3e8138682 100644 --- a/debug/accuracy_tools/msprobe/test/cpp/test_math_utils.cpp +++ b/debug/accuracy_tools/msprobe/test/cpp/test_math_utils.cpp @@ -3,7 +3,7 @@ #include #include #include -#include "utils/MathUtils.hpp" +#include "utils/MathUtils.h" using namespace MindStudioDebugger; using namespace MindStudioDebugger::MathUtils; diff --git a/debug/accuracy_tools/msprobe/test/cpp/test_precision_debugger.cpp b/debug/accuracy_tools/msprobe/test/cpp/test_precision_debugger.cpp index 69df0c18fcc27cd0ac359262649fcc588f2e9b9f..2832f2345d7d72efe2d1bb305c044f56d9b83e8d 100644 --- a/debug/accuracy_tools/msprobe/test/cpp/test_precision_debugger.cpp +++ b/debug/accuracy_tools/msprobe/test/cpp/test_precision_debugger.cpp @@ -2,9 +2,9 @@ #include #include "include/test_utils.hpp" -#include "third_party/ACL/AclApi.hpp" -#include "base/ErrorInfos.hpp" -#include "core/PrecisionDebugger.hpp" +#include "third_party/ACL/AclApi.h" +#include "base/ErrorInfosManager.h" +#include "core/PrecisionDebugger.h" using namespace MindStudioDebugger; @@ -17,15 +17,15 @@ public: std::string Name() const override {return "PrecisionDbgTaskStub";} bool Condition(const DebuggerConfig& cfg) const override {return true;} - void Initialize(const DebuggerConfig& cfg) {initialize_called = true;} - void OnStart() {start_called = true;} - void OnStop() {stop_called = true;} - void OnStep() {step_called = true;} + void Initialize(const DebuggerConfig& cfg) {initializeCalled = true;} + void OnStart() {startCalled = true;} + void OnStop() {stopCalled = true;} + void OnStep() {stepCalled = true;} - bool initialize_called{false}; - bool start_called{false}; - bool stop_called{false}; - bool step_called{false}; + bool initializeCalled{false}; + bool startCalled{false}; + bool stopCalled{false}; + bool stepCalled{false}; }; class PrecisionDbgTaskUselessStub : public PrecisionDbgTaskStub { @@ -35,11 +35,11 @@ public: TEST(PrecisionDebuggerTest, TestRegisterBeforeInit) { PrecisionDebugger& debugger = PrecisionDebugger::GetInstance(); - PrecisionDbgTaskStub stub_task; + PrecisionDbgTaskStub stubTask; DebuggerConfig::GetInstance().Reset(); - debugger.RegisterDebuggerTask(&stub_task); - stub_task.Register(); + debugger.RegisterDebuggerTask(&stubTask); + stubTask.Register(); EXPECT_FALSE(debugger.IsEnable()); EXPECT_EQ(debugger.GetCurStep(), 0); @@ -49,12 +49,12 @@ TEST(PrecisionDebuggerTest, TestRegisterBeforeInit) { debugger.Step(); EXPECT_EQ(debugger.GetCurStep(), 0); - EXPECT_FALSE(stub_task.initialize_called); - EXPECT_FALSE(stub_task.start_called); - EXPECT_FALSE(stub_task.stop_called); - EXPECT_FALSE(stub_task.step_called); + EXPECT_FALSE(stubTask.initializeCalled); + EXPECT_FALSE(stubTask.startCalled); + EXPECT_FALSE(stubTask.stopCalled); + EXPECT_FALSE(stubTask.stepCalled); - debugger.UnRegisterDebuggerTask(&stub_task); + debugger.UnRegisterDebuggerTask(&stubTask); debugger.UnRegisterDebuggerTask(nullptr); } @@ -81,39 +81,39 @@ TEST(PrecisionDebuggerTest, TestInit) { TEST(PrecisionDebuggerTest, TestSubTaskDispatch) { PrecisionDebugger& debugger = PrecisionDebugger::GetInstance(); - PrecisionDbgTaskStub stub_task1; - PrecisionDbgTaskStub stub_task2; - PrecisionDbgTaskUselessStub stub_task3; + PrecisionDbgTaskStub stubTask1; + PrecisionDbgTaskStub stubTask2; + PrecisionDbgTaskUselessStub stubTask3; MOCKER(MindStudioDebugger::AscendCLApi::LoadAclApi) .stubs() .then(returnValue(0)); - MOCKER(MindStudioDebugger::AscendCLApi::ACLAPI_aclrtSynchronizeDevice) + MOCKER(MindStudioDebugger::AscendCLApi::AclApiAclrtSynchronizeDevice) .stubs() .then(returnValue(0)) .expects(atLeast(1)); - stub_task1.Register(); + stubTask1.Register(); EXPECT_EQ(debugger.Initialize("MindSpore", CONFIG_EXAMPLE), 0); - stub_task2.Register(); - stub_task3.Register(); + stubTask2.Register(); + stubTask3.Register(); - EXPECT_TRUE(stub_task1.initialize_called); - EXPECT_TRUE(stub_task2.initialize_called); - EXPECT_FALSE(stub_task3.initialize_called); - EXPECT_FALSE(stub_task1.start_called); - EXPECT_FALSE(stub_task2.stop_called); - EXPECT_FALSE(stub_task3.step_called); + EXPECT_TRUE(stubTask1.initializeCalled); + EXPECT_TRUE(stubTask2.initializeCalled); + EXPECT_FALSE(stubTask3.initializeCalled); + EXPECT_FALSE(stubTask1.startCalled); + EXPECT_FALSE(stubTask2.stopCalled); + EXPECT_FALSE(stubTask3.stepCalled); debugger.Start(); - EXPECT_TRUE(stub_task1.start_called); - EXPECT_FALSE(stub_task3.start_called); + EXPECT_TRUE(stubTask1.startCalled); + EXPECT_FALSE(stubTask3.startCalled); debugger.Stop(); - EXPECT_TRUE(stub_task1.stop_called); - EXPECT_TRUE(stub_task2.stop_called); + EXPECT_TRUE(stubTask1.stopCalled); + EXPECT_TRUE(stubTask2.stopCalled); debugger.Step(); - EXPECT_TRUE(stub_task1.step_called); + EXPECT_TRUE(stubTask1.stepCalled); GlobalMockObject::verify(); GlobalMockObject::reset(); diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_api_accuracy_checker.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_api_accuracy_checker.py index 2cf47a2064626db792c55efb449b47ca8ab9b04e..7214dfab3885689b4c5b0c903e430ec9dcecb989 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_api_accuracy_checker.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_api_accuracy_checker.py @@ -4,6 +4,9 @@ import logging import os import json import csv +import tempfile +import shutil + from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import ApiAccuracyChecker @@ -40,13 +43,42 @@ def find_with_prefix(directory, prefix): class Args: - def __init__(self, api_info_file=None, out_path=None, result_csv_path=None): + def __init__(self, api_info_file=None, out_path=None, result_csv_path=None, save_error_data=False): self.api_info_file = api_info_file if api_info_file is not None else os.path.join(directory, "files", "api_info_statistics.json") self.out_path = out_path if out_path is not None else os.path.join(directory, "files") self.result_csv_path = result_csv_path if result_csv_path is not None else "" + self.save_error_data = save_error_data class TestApiAccuracyChecker(unittest.TestCase): + def test_init_save_error_data(self): + # 使用临时目录,不污染项目文件 + temp_dir = tempfile.mkdtemp() + try: + # 构造 args,只关注 out_path 和 save_error_data + args = Args(out_path=temp_dir, save_error_data=True) + config, dump_path_agg = ApiAccuracyChecker.init_save_error_data(args) + + # 1. config 字段检查 + self.assertEqual(config.execution_mode, "pynative") + self.assertEqual(config.task, "tensor") + self.assertEqual(config.dump_path, temp_dir) + self.assertEqual(config.dump_tensor_data_dir, temp_dir) + self.assertFalse(config.async_dump) + self.assertEqual(config.file_format, "npy") + + # 2. error_data 目录已创建 + error_dir = os.path.join(temp_dir, "error_data") + self.assertTrue(os.path.isdir(error_dir), f"{error_dir} should exist") + + # 3. dump_path_agg 路径检查 + self.assertEqual(dump_path_agg.dump_file_path, os.path.join(temp_dir, "dump.json")) + self.assertEqual(dump_path_agg.stack_file_path, os.path.join(temp_dir, "stack.json")) + self.assertEqual(dump_path_agg.dump_tensor_data_dir, error_dir) + + finally: + # 清理临时目录 + shutil.rmtree(temp_dir) def test_statistics_mode(self): api_info_statistics_path = os.path.join(directory, "files", "api_info_statistics.json") diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_api_runner.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_api_runner.py index dac2b9b364f6e4f271ff5ae8a43052a3fd2496d2..8d6387bb73bcf455b4066ee6796ac3227b15f10d 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_api_runner.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_api_runner.py @@ -121,10 +121,20 @@ class TestApiRunner(unittest.TestCase): ] for test_case in test_cases: api_instance, api_input_aggregation, forward_or_backward, api_platform, results_target = test_case - results_real = api_runner.run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform) + output = api_runner.run_api(api_instance, api_input_aggregation, forward_or_backward, + api_platform) + + # 如果返回的是 tuple,就拿第一个元素;否则直接当 list 用 + if isinstance(output, tuple): + results_real = output[0] + else: + results_real = output + # 下面跟原来测试逻辑一模一样 for res_real, res_target in zip(results_real, results_target): - assert (abs(res_real.get_parameter() - res_target.get_parameter(tensor_platform=api_platform)) < 1e-5).all() - + assert (abs( + res_real.get_parameter() + - res_target.get_parameter(tensor_platform=api_platform) + ) < 1e-5).all() def test_get_api_instance(self): #api_type_str, api_sub_name, api_platform, result_api diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_bench_functions.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_bench_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..2a2d3d2d7536f25a4171e9278f2d467711b34130 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_bench_functions.py @@ -0,0 +1,397 @@ +import unittest +import torch +import numpy as np + +from msprobe.mindspore.api_accuracy_checker.bench_functions.flash_attention_score import ( + softmax_forward, softmax_grad, broadcast_kv, calculate_qk, + fusion_attention_forward, fusion_attention_backward, + parse_bsnd_args, convert_from_bnsd, convert_to_bnsd, + generate_attn_mask, generate_kv, rebuid_softmax_by_qkv, + rebuild_softmax_by_max_sum, FlashAttentionScore, + npu_fusion_attention_forward_patch, npu_fusion_attention_backward_patch, + FaForwardParams, FaBackwardParams, RebuildSoftmaxParams, GTYPE, + get_head_num, get_input_layout +) + +class TestBenchFunctions(unittest.TestCase): + + def setUp(self): + torch.manual_seed(0) + np.random.seed(0) + + def test_softmax_forward(self): + x = torch.tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 4.0]], dtype=torch.float64) + res, x_max, x_sum = softmax_forward(x) + expected = torch.softmax(x, dim=-1) + self.assertTrue(torch.allclose(res, expected, atol=1e-6)) + self.assertTrue(torch.allclose(x_max, x.max(dim=-1, keepdim=True)[0], atol=1e-6)) + manual_sum = torch.exp(x - x_max).sum(dim=-1, keepdim=True) + self.assertTrue(torch.allclose(x_sum, manual_sum, atol=1e-6)) + + def test_softmax_grad(self): + x = torch.randn(4, 5, dtype=torch.float64) + y, _, _ = softmax_forward(x) + dp = torch.randn_like(y) + grad = softmax_grad(dp, y) + sum_grad = grad.sum(dim=-1) + self.assertTrue(torch.allclose(sum_grad, torch.zeros_like(sum_grad), atol=1e-6)) + + def test_broadcast_kv(self): + B, N_kv, S, D = 1, 2, 3, 4 + num_heads = 4 + kv = torch.arange(B*N_kv*S*D, dtype=torch.float32).reshape(B, N_kv, S, D) + out = broadcast_kv(num_heads, N_kv, kv, kv.dtype) + self.assertEqual(out.shape, (B, num_heads, S, D)) + self.assertTrue(torch.equal(out[:, :2, :, :], kv[:, 0:1, :, :].expand(B,2,S,D))) + self.assertTrue(torch.equal(out[:, 2:, :, :], kv[:, 1:2, :, :].expand(B,2,S,D))) + + def test_broadcast_kv_invalid(self): + kv = torch.randn(1,3,4) # 3D tensor + with self.assertRaises(ValueError): + broadcast_kv(4, 2, kv, kv.dtype) + kv4d = torch.randn(1,2,3,4) + with self.assertRaises(ValueError): + broadcast_kv(0, 2, kv4d, kv4d.dtype) + with self.assertRaises(ValueError): + broadcast_kv(4, 3, kv4d, kv4d.dtype) + with self.assertRaises(ValueError): + broadcast_kv(3, 2, kv4d, kv4d.dtype) + + def test_calculate_qk_basic(self): + q = torch.randn(2,2,3,4) + k = torch.randn(2,2,3,4) + scalar = 2.0 + qk = calculate_qk(q, k, None, None, scalar) + expected = torch.matmul(q, k.permute(0,1,3,2)) * scalar + self.assertTrue(torch.allclose(qk, expected, atol=1e-6)) + + def test_calculate_qk_errors(self): + q = torch.randn(2,2,3,4) + # head_dim mismatch + k = torch.randn(2,2,3,5) + with self.assertRaises(ValueError): + calculate_qk(q, k, None, None, 1.0) + # too few dims + q3 = torch.randn(2,3,4) + k3 = torch.randn(2,3,4) + with self.assertRaises(ValueError): + calculate_qk(q3, k3, None, None, 1.0) + + def test_calculate_qk_with_pse_and_mask(self): + q = torch.ones(1,1,2,2) + k = torch.ones(1,1,2,2) + pse = torch.ones(1,1,2,2) + mask = torch.zeros(1,1,2,2) + scalar = 1.0 + qk = calculate_qk(q, k, mask, pse, scalar) + expected = (torch.matmul(q, k.permute(0,1,3,2)) + pse) * scalar + mask.bool() * (-40000.0) + self.assertTrue(torch.allclose(qk, expected, atol=1e-6)) + + def test_parse_bsnd_args_bsh(self): + q = torch.randn(2,3,4) + k = torch.randn(2,5,4) + head_num = 2 + args = parse_bsnd_args(q, k, head_num, "BSH") + b, s1, s2, n1, n2, d, h1, h2, dtype = args + self.assertEqual((b, s1, s2, n1, n2, d, h1, h2), (2,3,5,2,2,2,4,4)) + + def test_parse_bsnd_args_errors(self): + q = torch.randn(2,3,0) # leads to d=0 + k = torch.randn(2,5,0) + with self.assertRaises(ValueError): + parse_bsnd_args(q, k, 2, "BSH") + with self.assertRaises(ValueError): + parse_bsnd_args(q, k, 2, "TND") + with self.assertRaises(ValueError): + parse_bsnd_args(q, k, 0, "BSH") + + def test_convert_from_and_to_bnsd(self): + B, N, S, D = 1, 2, 3, 4 + x = torch.arange(B*N*S*D).reshape(B, N, S, D) + for layout in ["BSH","SBH","BSND","BNSD"]: + out = convert_from_bnsd(x, layout) + back = convert_to_bnsd(out, N, layout) + self.assertTrue(torch.equal(back, x.to(GTYPE))) + with self.assertRaises(ValueError): + convert_to_bnsd(torch.randn(2,2), N, "TND") + + def test_generate_attn_mask_shapes_and_reverse(self): + b,n1,s1,s2 = 1,1,3,3 + dtype = torch.float32 + for mode in range(5): + mask = generate_attn_mask(mode, None, b, n1, s1, s2, 0, 0, dtype) + self.assertEqual(mask.shape, (s1, s2)) + # reverse from full 2048 mask + orig = torch.from_numpy(np.triu(np.ones([2048,2048]), k=1)).to(dtype) + for mode in [2,3,4]: + rev = generate_attn_mask(mode, orig, b, n1, s1, s2, 0, 0, dtype) + self.assertEqual(rev.shape, (s1, s2)) + + def test_generate_kv(self): + key = torch.randn(1,2,3,4) + value = torch.randn(1,2,3,4) + k_new, v_new = generate_kv(key, value, 2, 2) + self.assertTrue(torch.equal(k_new, key)) + k2, _ = generate_kv(key, value, 4, 2) + self.assertEqual(k2.shape[1], 4) + + + def test_rebuild_softmax_by_max_sum_and_errors(self): + # 正常路径 + B, N, S, D = 1, 1, 3, 4 + q = torch.randn(B, N, S, D) + k = torch.randn(B, N, S, D) + attn_mask = None + pse = None + scalar = 1.0 + # 手动构造 softmax_max、softmax_sum + qk, softmax_max, softmax_sum = softmax_forward(torch.matmul(q, k.permute(0,1,3,2))) + params = RebuildSoftmaxParams(q=q, k=k, attn_mask=attn_mask, pse=pse, + scalar_value=scalar, + softmax_max=softmax_max, softmax_sum=softmax_sum) + res = rebuild_softmax_by_max_sum(params) + self.assertTrue(torch.allclose(res, torch.softmax(torch.matmul(q,k.permute(0,1,3,2)), dim=-1))) + + # softmax_max 最后一维为 0 时抛错 + bad_max = torch.empty(B, N, S, 0) + bad_params = params._replace(softmax_max=bad_max) + with self.assertRaises(ValueError): + rebuild_softmax_by_max_sum(bad_params) + + def test_npu_patch_forward_and_backward_patch(self): + # forward_patch 长度不足报错 + with self.assertRaises(RuntimeError): + npu_fusion_attention_forward_patch(1) + # backward_patch 长度不等于6 报错 + with self.assertRaises(ValueError): + npu_fusion_attention_backward_patch(1,2,3) + + # 正常调用,检查返回结构 + B, S1, S2, N1, D = 1, 2, 2, 2, 4 + head_num = 1 + layout = "BSH" + q = torch.randn(B, S1, N1 * D) + k = torch.randn(B, S2, N1 * D) + # forward_patch 返回 args, dims_kwargs, new_kwargs + args, dims, new_kwargs = npu_fusion_attention_forward_patch(q, k, None, head_num, layout) + self.assertIn("b", dims) + # backward_patch + dx = torch.randn_like(q) + args2, dims2, new_kwargs2 = npu_fusion_attention_backward_patch(q, k, None, dx, head_num, layout) + self.assertIn("s1", dims2) + + + def test_fusion_attention_forward_with_drop_mask(self): + B, N, S, D = 1, 2, 3, 4 + q = torch.randn(B, N, S, D, dtype=torch.float64) + k = torch.randn(B, N, S, D, dtype=torch.float64) + v = torch.randn(B, N, S, D, dtype=torch.float64) + # 制造一个 drop_mask + drop_mask = torch.randint(0, 2, (B, N, S, S), dtype=torch.float64) + # 注意 drop_mask 需要能广播到 softmax_res 形状 (B,N,S,S) + params = FaForwardParams( + q=q, k=k, v=v, + drop_mask=drop_mask, + attn_mask=None, + pse=None, + scalar_value=1.0, + keep_prob=0.5 + ) + y1, _, _ = fusion_attention_forward(params) + # 手动计算:先 softmax,再 mask,再 matmul + qk = calculate_qk(q, k, None, None, 1.0) + sm, _, _ = softmax_forward(qk) + masked = sm * drop_mask * (1.0 / 0.5) + y2 = torch.matmul(masked, v) + self.assertTrue(torch.allclose(y1, y2)) + + def test_fusion_attention_backward_with_drop_mask(self): + B, N, S, D = 1, 2, 3, 4 + # 构造前向结果 + dx = torch.randn(B, N, S, D, dtype=torch.float64) + q = torch.randn(B, N, S, D, dtype=torch.float64) + k = torch.randn(B, N, S, D, dtype=torch.float64) + v = torch.randn(B, N, S, D, dtype=torch.float64) + # 构造 softmax_res (B,N,S,S) 和 drop_mask + sm = torch.softmax(torch.randn(B, N, S, S, dtype=torch.float64), dim=-1) + drop_mask = torch.randint(0, 2, (B, N, S, S), dtype=torch.float64) + params = FaBackwardParams( + dx=dx, q=q, k=k, v=v, + softmax_res=sm, + drop_mask=drop_mask, + pse=None, + scalar_value=1.0, + keep_prob=0.8 + ) + dq1, dk1, dv1 = fusion_attention_backward(params) + # 直接对比形状和 dtype + self.assertEqual(dq1.shape, q.shape) + self.assertEqual(dk1.shape, k.shape) + self.assertEqual(dv1.shape, v.shape) + self.assertEqual(dq1.dtype, torch.float64) + + def test_get_head_num_and_input_layout_errors(self): + # 既无 kwargs, 也无足够 args + with self.assertRaises(ValueError): + get_head_num(1, 2, 3) + with self.assertRaises(ValueError): + get_input_layout(1, 2, 3, 4) + + def test_npu_forward_patch_sanity(self): + # 测试 sparse_mode 非零路径 + B, S1, S2, N1, D = 1, 3, 5, 2, 4 + head_num = 2 + layout = "BSH" + q = torch.randn(B, S1, N1 * D) + k = torch.randn(B, S2, N1 * D) + # 传入 sparse_mode, pre/next token,pse + args, dims, new_kwargs = npu_fusion_attention_forward_patch( + q, k, None, + head_num, layout, + sparse_mode=3, + pre_tockens=1, + next_tockens=2, + pse=torch.ones(1), + ) + # dims 检查 + self.assertEqual(dims["b"], B) + self.assertEqual(dims["s1"], S1) + self.assertEqual(dims["s2"], S2) + self.assertEqual(new_kwargs["sparse_mode"], 3) + self.assertTrue("pse" in new_kwargs) + + def test_npu_backward_patch_sanity(self): + B, S1, S2, N1, D = 1, 4, 4, 2, 4 + head_num = 2 + layout = "BSH" + q = torch.randn(B, S1, N1 * D) + k = torch.randn(B, S2, N1 * D) + dx = torch.randn(B, S1, N1 * D) + # 正确长度 + args, dims, new_kwargs = npu_fusion_attention_backward_patch( + q, k, None, dx, head_num, layout + ) + self.assertEqual(dims["n1"], N1) + self.assertEqual(dims["n2"], N1) + # 传入 n2 不整除 n1 抛错 + with self.assertRaises(ValueError): + npu_fusion_attention_backward_patch( + q, k, None, dx, 3, layout + ) + + def test_gtype_constant(self): + # Ensure GTYPE matches expected torch dtype + self.assertEqual(GTYPE, torch.float64) + + def test_softmax_forward_and_sum(self): + x = torch.tensor([[0.5, -0.5], [2.0, 3.0]], dtype=GTYPE) + res, x_max, x_sum = softmax_forward(x) + expected = torch.softmax(x, dim=-1) + self.assertTrue(torch.allclose(res, expected, atol=1e-6)) + self.assertTrue(torch.allclose(x_sum, torch.exp(x - x_max).sum(dim=-1, keepdim=True), atol=1e-6)) + + def test_softmax_grad_zero_sum(self): + x = torch.randn(3, 4, dtype=GTYPE) + y, _, _ = softmax_forward(x) + dp = torch.randn_like(y) + grad = softmax_grad(dp, y) + self.assertTrue(torch.allclose(grad.sum(dim=-1), torch.zeros_like(grad.sum(dim=-1)), atol=1e-6)) + + def test_broadcast_kv_and_errors(self): + B, N_kv, S, D = 2, 1, 3, 4 + num_heads = 2 + kv = torch.arange(B*N_kv*S*D, dtype=torch.float32).reshape(B, N_kv, S, D) + out = broadcast_kv(num_heads, N_kv, kv, kv.dtype) + self.assertEqual(out.shape, (B, num_heads, S, D)) + # invalid dims + with self.assertRaises(ValueError): + broadcast_kv(2, 0, kv, kv.dtype) + with self.assertRaises(ValueError): + broadcast_kv(3, 2, kv, kv.dtype) + + def test_calculate_qk_basic_and_errors(self): + q = torch.randn(1,1,2,3) + k = torch.randn(1,1,2,3) + scalar = 0.5 + out = calculate_qk(q, k, None, None, scalar) + expected = torch.matmul(q, k.permute(0,1,3,2)) * scalar + self.assertTrue(torch.allclose(out, expected)) + # shape mismatch + k_bad = torch.randn(1,1,2,4) + with self.assertRaises(ValueError): + calculate_qk(q, k_bad, None, None, scalar) + # low dims + q3 = torch.randn(2,3,4) + with self.assertRaises(ValueError): + calculate_qk(q3, q3, None, None, scalar) + + def test_fusion_attention_forward_backward_no_mask(self): + B, N, S, D = 1, 2, 3, 4 + q = torch.randn(B,N,S,D,dtype=GTYPE) + k = torch.randn(B,N,S,D,dtype=GTYPE) + v = torch.randn(B,N,S,D,dtype=GTYPE) + params = FaForwardParams(q=q, k=k, v=v, drop_mask=None, attn_mask=None, + pse=None, scalar_value=1.0, keep_prob=1.0) + y, m, s = fusion_attention_forward(params) + # gradient + dx = torch.randn_like(y) + backward = FaBackwardParams(dx=dx, q=q, k=k, v=v, + softmax_res=torch.softmax(calculate_qk(q, k, None, None, 1.0), dim=-1), + drop_mask=None, pse=None, scalar_value=1.0, keep_prob=1.0) + dq, dk, dv = fusion_attention_backward(backward) + self.assertEqual(dq.shape, q.shape) + self.assertEqual(dk.shape, k.shape) + self.assertEqual(dv.shape, v.shape) + + def test_parse_and_convert_layouts(self): + q = torch.randn(2,3,4) + k = torch.randn(2,5,4) + head = 2 + args = parse_bsnd_args(q, k, head, "BSH") + self.assertEqual(args[0], 2) + B,N,S,D = 1,2,3,4 + x = torch.arange(B*N*S*D).reshape(B,N,S,D) + for layout in ["BSH","SBH","BSND","BNSD"]: + out = convert_from_bnsd(x, layout) + back = convert_to_bnsd(out, N, layout) + self.assertTrue(torch.equal(back, x.to(GTYPE))) + + def test_generate_attn_mask(self): + for mode in range(5): + mask = generate_attn_mask(mode, None, 1,1,3,3,0,0,torch.float32) + self.assertEqual(mask.shape, (3,3)) + # reverse large mask + orig = torch.from_numpy(np.triu(np.ones([2048,2048]),1)).to(torch.float32) + rev = generate_attn_mask(2, orig, 1,1,3,3,0,0,torch.float32) + self.assertEqual(rev.shape, (3,3)) + + def test_generate_kv(self): + k = torch.randn(1,2,3,4) + v = torch.randn(1,2,3,4) + k2, v2 = generate_kv(k, v, 4, 2) + self.assertEqual(k2.shape[1], 4) + + def test_get_head_and_layout(self): + with self.assertRaises(ValueError): + get_head_num(1) + with self.assertRaises(ValueError): + get_input_layout(1,2,3) + + def test_npu_patches(self): + # forward patch + q = torch.randn(1,2,2*4) + k = torch.randn(1,3,2*4) + with self.assertRaises(RuntimeError): + npu_fusion_attention_forward_patch(1) + args, dims, new_kwargs = npu_fusion_attention_forward_patch(q, k, None, 2, "BSH") + self.assertIn("b", dims) + # backward patch + dx = torch.randn_like(q) + with self.assertRaises(ValueError): + npu_fusion_attention_backward_patch(q,k,None,dx,3) + args2, dims2, new_kwargs2 = npu_fusion_attention_backward_patch(q,k,None,dx,2, "BSH") + self.assertIn("s1", dims2) + + +if __name__ == '__main__': + unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_data_manager.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_data_manager.py index bb4c8b197ef8362921858839ca3790224715a39a..9cfad00d8ff13e91eb84fff5f46ab434f9ed1d4d 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_data_manager.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_data_manager.py @@ -2,7 +2,8 @@ import unittest from unittest.mock import patch, mock_open, MagicMock import os from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import DataManager -from msprobe.core.common.const import MsCompareConst, CompareConst +from msprobe.core.common.const import CompareConst +from msprobe.mindspore.common.const import MsCompareConst class TestDataManager(unittest.TestCase): diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_op_generate.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_op_generate.py new file mode 100644 index 0000000000000000000000000000000000000000..9166074ce6dfd0f6a794d5fd1b196495450541f7 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_op_generate.py @@ -0,0 +1,309 @@ +import unittest +import tempfile +import os +import json + +from msprobe.core.common.const import Const +from msprobe.mindspore.api_accuracy_checker.generate_op_script.op_generator import ( + APIInfo, + CommonConfig, + parse_json_config, + OperatorScriptGenerator, + APIExtractor, +) +from msprobe.core.common.file_utils import ( + FileOpen, + load_json, + save_json, + make_dir, + change_mode, +) +from msprobe.core.common.const import FileCheckConst + +class TestCommonConfigCheckConfig(unittest.TestCase): + def setUp(self): + # 基本有效配置 + self.tmpdir = tempfile.TemporaryDirectory() + self.valid = { + "dump_json_path": None, + "api_name": "Functional.add", + "extract_api_path": os.path.join(self.tmpdir.name, "out.json"), + "propagation": Const.FORWARD, + "data_mode": "random_data", + "random_seed": 0, + "iter_times": 1, + } + + def tearDown(self): + self.tmpdir.cleanup() + + def make_cfg(self, overrides): + cfg_dict = {**self.valid, **overrides} + return CommonConfig(cfg_dict) + + def test_invalid_api_name_too_long(self): + long_name = "A" * 31 + with self.assertRaises(ValueError) as cm: + self.make_cfg({"api_name": long_name}) + self.assertIn("too long", str(cm.exception)) + + def test_invalid_propagation(self): + with self.assertRaises(ValueError): + self.make_cfg({"propagation": "INVALID"}) + + def test_invalid_data_mode(self): + with self.assertRaises(ValueError): + self.make_cfg({"data_mode": "not_a_mode"}) + + def test_random_seed_not_int(self): + with self.assertRaises(ValueError): + self.make_cfg({"random_seed": "zero"}) + + def test_iter_times_not_int(self): + with self.assertRaises(ValueError): + self.make_cfg({"iter_times": "ten"}) + + +class TestParseJsonConfig(unittest.TestCase): + def test_empty_path_raises(self): + with self.assertRaises(Exception) as cm: + parse_json_config("") # 空路径 + self.assertIn("config_input path can not be empty", str(cm.exception)) + +class TestAPIExtractorExtractOp(unittest.TestCase): + def setUp(self): + # 准备一个 dump_json_path 文件 + self.tmpdir = tempfile.TemporaryDirectory() + self.dump = { + "framework": "mindspore", + "dump_data_dir": "/data", + "data": { + "Functional.add.0": { + Const.INPUT_ARGS: [{"data_name": "a.bin"}], + Const.OUTPUT: [{"data_name": "b.bin"}] + }, + "Other.mul.1": { + Const.INPUT_ARGS: [{"data_name": "c.bin"}] + } + } + } + self.dump_path = os.path.join(self.tmpdir.name, "dump.json") + with open(self.dump_path, "w") as f: + json.dump(self.dump, f) + # 输出路径 + self.out_path = os.path.join(self.tmpdir.name, "extract.json") + + def tearDown(self): + self.tmpdir.cleanup() + + def test_extract_op_creates_file_with_expected_keys(self): + extractor = APIExtractor("Functional.add", self.dump_path, self.out_path) + extractor.extract_op() + # 文件存在 + self.assertTrue(os.path.isfile(self.out_path)) + data = load_json(self.out_path) + # 应包含匹配 key 以及 FRAMEWORK、REAL_DATA_PATH + self.assertIn("Functional.add.0", data) + self.assertEqual(data.get("framework"), "mindspore") + self.assertEqual(data.get("real_data_path"), "/data") + # data_name 已被拼接 + arg = data["Functional.add.0"][Const.INPUT_ARGS][0] + self.assertEqual(arg["data_name"], os.path.join("/data", "a.bin")) + +class TestAPIExtractorUpdateDataNameNested(unittest.TestCase): + def test_update_data_name_nested_list(self): + ex = APIExtractor("Any", None, None) + data = {"data_name": "root"} + nested = [ [data], [{"data_name": "leaf"}] ] + ex.update_data_name(nested, "/base") + # 所有层级的 data_name 都被更新 + self.assertEqual(nested[0][0]["data_name"], "/base/root") + self.assertEqual(nested[1][0]["data_name"], "/base/leaf") + +class TestOperatorScriptGeneratorSegments(unittest.TestCase): + def test_extract_segments_invalid_length(self): + # 既不是 4 段也不是 5 段 + t, name, order = OperatorScriptGenerator.extract_detailed_api_segments("a.b.c") + self.assertIsNone(t) + self.assertIsNone(name) + self.assertIsNone(order) + +class TestOperatorScriptGeneratorNestedInputs(unittest.TestCase): + def test_generate_forward_inputs_code_nested(self): + args = [ + {"parameter_name": "x"}, + [ {"parameter_name": "y1"}, {"parameter_name": "y2"} ], + ] + code = OperatorScriptGenerator.generate_forward_inputs_code(args) + self.assertIn("x", code) + self.assertIn("y1", code) + self.assertIn("y2", code) + + def test_generate_gradient_inputs_code_nested(self): + args = [ + {"parameter_name": "g1"}, + [ {"parameter_name": "g2"} ] + ] + code = OperatorScriptGenerator.generate_gradient_inputs_code(args) + self.assertIn("g1", code) + self.assertIn("g2", code) + + +class TestAPIInfo(unittest.TestCase): + def test_api_type_and_supported(self): + api = APIInfo("Functional.add.0.forward", {}) + self.assertEqual(api.api_type, "Functional") + self.assertTrue(api.is_supported_type()) + + def test_from_json_forward(self): + data = {"Functional.add.0": {"input_args": [], "input_kwargs": {}}} + info = APIInfo.from_json(data, Const.FORWARD) + self.assertEqual(info.api_full_name, "Functional.add.0") + self.assertIsNone(info.backward_info) + + def test_from_json_backward(self): + data = { + "Functional.add.0": {"input_args": [], "input_kwargs": {}}, + "Functional.add_grad.0": {"grad_input": []}, + } + info = APIInfo.from_json(data, Const.BACKWARD) + self.assertEqual(info.api_full_name, "Functional.add.0") + self.assertIsNotNone(info.backward_info) + self.assertEqual(info.backward_info.api_full_name, "Functional.add_grad.0") + + def test_from_json_unsupported_type(self): + data = {"Unknown.add.0": {}} + with self.assertRaises(ValueError): + APIInfo.from_json(data, Const.FORWARD) + + +class TestCommonConfig(unittest.TestCase): + def setUp(self): + # create a temp directory to satisfy make_dir and path checks + self.tmpdir = tempfile.TemporaryDirectory() + self.extract_path = os.path.join(self.tmpdir.name, "sub", "api.json") + # build a valid config dict + self.config = { + "dump_json_path": None, + "api_name": "Functional.add", + "extract_api_path": self.extract_path, + "propagation": Const.FORWARD, + "data_mode": "random_data", + "random_seed": 1, + "iter_times": 1, + } + # ensure parent dir of extract_api_path exists + os.makedirs(os.path.dirname(self.extract_path), exist_ok=True) + # write a dummy JSON file for parse_json_config + self.config_file = os.path.join(self.tmpdir.name, "config.json") + with open(self.config_file, "w") as f: + json.dump(self.config, f) + + def tearDown(self): + self.tmpdir.cleanup() + + def test_parse_json_config(self): + cfg = parse_json_config(self.config_file) + self.assertIsInstance(cfg, CommonConfig) + self.assertEqual(cfg.api_name, "Functional.add") + self.assertEqual(cfg.propagation, Const.FORWARD) + + def test_check_user_settings_invalid_iter(self): + cfg = CommonConfig(self.config.copy()) + cfg.iter_times = 0 + with self.assertRaises(ValueError) as ctx: + cfg.check_user_settings() + self.assertIn("iter_times should be range from 1", str(ctx.exception)) + + def test_check_user_settings_empty_json(self): + # create an empty JSON file to simulate empty extract_api_path + empty = os.path.join(self.tmpdir.name, "empty.json") + with open(empty, "w") as f: + json.dump({}, f) + cfg = CommonConfig({**self.config, "extract_api_path": empty}) + with self.assertRaises(ValueError) as ctx: + cfg.check_user_settings() + self.assertIn("json file is empty", str(ctx.exception)) + + +class TestOperatorScriptGenerator(unittest.TestCase): + def test_extract_detailed_api_segments_four(self): + t, name, order = OperatorScriptGenerator.extract_detailed_api_segments( + "Functional.mul.1.out" + ) + self.assertEqual((t, name, order), ("Functional", "mul", "1")) + + def test_extract_detailed_api_segments_five(self): + t, name, order = OperatorScriptGenerator.extract_detailed_api_segments( + "Functional.prefix.mul.2.out" + ) + self.assertEqual((t, name, order), ("Functional", "prefix.mul", "2")) + + def test_generate_forward_inputs_code(self): + args_info = [{"parameter_name": "x"}, {"parameter_name": "y"}] + code = OperatorScriptGenerator.generate_forward_inputs_code(args_info) + self.assertIn("x", code) + self.assertIn("y", code) + self.assertIn("ComputeElement", code) + + def test_generate_kwargs_compute_element_dict_code(self): + code = OperatorScriptGenerator.generate_kwargs_compute_element_dict_code() + self.assertIn("kwargs_compute_element_dict", code) + self.assertTrue(code.strip().startswith("# ---- 构造 kwargs")) + + def test_generate_gradient_inputs_code(self): + args_back = [{"parameter_name": "grad"}] + code = OperatorScriptGenerator.generate_gradient_inputs_code(args_back) + self.assertIn("grad", code) + self.assertIn("gradient_inputs", code) + + def test_get_settings_real_data(self): + # simulate CommonConfig-like object + common = type("C", (), { + "propagation": Const.FORWARD, + "random_seed": 42, + "data_mode": "real_data", + "iter_times": 100 + }) + gen = OperatorScriptGenerator(common, ["a"], {"k": "v"}, None) + settings = gen.get_settings("Functional.add.0") + self.assertEqual(settings["iter_times"], 1) + self.assertEqual(settings["random_seed"], 42) + + def test_get_settings_random_data(self): + common = type("C", (), { + "propagation": Const.FORWARD, + "random_seed": 7, + "data_mode": "random_data", + "iter_times": 5 + }) + gen = OperatorScriptGenerator(common, ["a"], {"k": "v"}, None) + settings = gen.get_settings("Tensor.sub.3") + self.assertEqual(settings["iter_times"], 5) + + +class TestAPIExtractor(unittest.TestCase): + def test_update_data_name_simple(self): + ex = APIExtractor("Functional.add", None, None) + data = {"data_name": "foo.bin"} + ex.update_data_name(data, "/dumpdir") + self.assertEqual(data["data_name"], os.path.join("/dumpdir", "foo.bin")) + + def test_load_real_data_path(self): + ex = APIExtractor("Functional.add", None, None) + # build a minimal value dict + val = { + Const.INPUT_ARGS: [{"data_name": "a.txt"}], + Const.GRAD_INPUT: [], + Const.INPUT: [], + Const.OUTPUT: [], + Const.GRAD_OUTPUT: [] + } + out = ex.load_real_data_path(val, "/mydump") + # ensure in-place mutation happened + self.assertEqual(val[Const.INPUT_ARGS][0]["data_name"], "/mydump/a.txt") + self.assertIs(out, val) + + +if __name__ == "__main__": + unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/common/test_ms_utils.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/common/test_ms_utils.py index 1ed3ca016108519fb3f643c9d4bb768f63a52d40..aad88bb9e68dec59bf98d134354447dc973bc546 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/common/test_ms_utils.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/common/test_ms_utils.py @@ -113,7 +113,7 @@ class TestMsprobeFunctions(unittest.TestCase): seed_all(42, True) # 验证 check_seed_all 的调用 - mock_check_seed_all.assert_called_once_with(42, True, True) + mock_check_seed_all.assert_called_once_with(42, True, False) # 验证环境变量是否设置正确 self.assertEqual(mock_environ.get('PYTHONHASHSEED'), '42') # 验证其他函数是否正确调用 diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/dump_file/mindspore_data/dump.json b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/dump_file/mindspore_data/dump.json index 5b954f6d6443c92e6321e5f55e373e99f428653d..48800c0455c6651b146600e61e636d4dc25fac31 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/dump_file/mindspore_data/dump.json +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/dump_file/mindspore_data/dump.json @@ -1,6 +1,7 @@ { "task": "statistics", "level": "mix", + "framework": "mindspore", "dump_data_dir": null, "data": { "Tensor.__add__.0.forward": { diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/dump_file/pytorch_data/dump.json b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/dump_file/pytorch_data/dump.json index 150cbd43b169573e48542aa0c46c26e7df69843e..b2704185ff19b961b43453f81247236d77677d83 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/dump_file/pytorch_data/dump.json +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/dump_file/pytorch_data/dump.json @@ -1,6 +1,7 @@ { "task": "statistics", "level": "mix", + "framework": "pytorch", "dump_data_dir": null, "data": { "Tensor.__add__.0.forward": { diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_common_dir_compare.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_common_dir_compare.py new file mode 100644 index 0000000000000000000000000000000000000000..6560b7af5366042ae0d5e33eb339341c2032d104 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_common_dir_compare.py @@ -0,0 +1,121 @@ +import unittest +import tempfile +import os +import numpy as np +import pandas as pd +from pathlib import Path +from msprobe.mindspore.compare.common_dir_compare import common_dir_compare + +class TestCommonDirCompare(unittest.TestCase): + def setUp(self): + # 创建临时目录 + self.npu_dir = tempfile.mkdtemp() + self.bench_dir = tempfile.mkdtemp() + self.output_dir = tempfile.mkdtemp() + + def tearDown(self): + # 清理临时目录 + for dir_path in [self.npu_dir, self.bench_dir, self.output_dir]: + for root, dirs, files in os.walk(dir_path, topdown=False): + for name in files: + os.remove(os.path.join(root, name)) + for name in dirs: + os.rmdir(os.path.join(root, name)) + os.rmdir(dir_path) + + def test_simple_directory_comparison(self): + """测试简单目录结构比对""" + # 创建测试npy文件 + np.save(os.path.join(self.npu_dir, "x_float32_0.npy"), np.random.rand(10, 10).astype(np.float32)) + np.save(os.path.join(self.npu_dir, "x_float32_1.npy"), np.random.rand(10, 10).astype(np.float32)) + np.save(os.path.join(self.bench_dir, "x_float32_0.npy"), np.random.rand(10, 10).astype(np.float32)) + np.save(os.path.join(self.bench_dir, "x_float32_1.npy"), np.random.rand(10, 10).astype(np.float32)) + + # 执行比对 + input_params = {'npu_path': self.npu_dir, 'bench_path': self.bench_dir} + result = common_dir_compare(input_params, self.output_dir) + + # 验证输出目录结构 + self.assertTrue(os.path.exists(os.path.join(self.output_dir, 'result.csv'))) + + # 验证结果文件内容 + result_df = pd.read_csv(os.path.join(self.output_dir, 'result.csv')) + self.assertEqual(len(result_df), 2) + self.assertIn('x_0', result_df['Name'].values) + self.assertIn('x_1', result_df['Name'].values) + + def test_nested_directory_comparison(self): + """测试嵌套目录结构比对""" + # 创建嵌套目录 + os.makedirs(os.path.join(self.npu_dir, "rank0")) + os.makedirs(os.path.join(self.bench_dir, "rank0")) + + # 创建测试npy文件 + np.save(os.path.join(self.npu_dir, "rank0", "y_float32_0.npy"), np.random.rand(5, 5).astype(np.float32)) + np.save(os.path.join(self.bench_dir, "rank0", "y_float32_0.npy"), np.random.rand(5, 5).astype(np.float32)) + + input_params = {'npu_path': self.npu_dir, 'bench_path': self.bench_dir} + result = common_dir_compare(input_params, self.output_dir) + + # 验证输出目录结构 + self.assertTrue(os.path.exists(os.path.join(self.output_dir, 'rank0', 'result.csv'))) + + # 验证结果文件内容 + result_df = pd.read_csv(os.path.join(self.output_dir, 'rank0', 'result.csv')) + self.assertEqual(len(result_df), 1) + self.assertIn('y_0', result_df['Name'].values) + + def test_filename_mapping(self): + """测试文件名映射功能""" + # 创建不同名称但通过映射关联的文件 + np.save(os.path.join(self.npu_dir, "a_float32_0.npy"), np.random.rand(4, 4).astype(np.float32)) + np.save(os.path.join(self.bench_dir, "b_float32_0.npy"), np.random.rand(4, 4).astype(np.float32)) + + input_params = {'npu_path': self.npu_dir, 'bench_path': self.bench_dir, 'map_dict': {'a': 'b'}} + result = common_dir_compare(input_params, self.output_dir) + + # 验证结果文件生成 + self.assertTrue(os.path.exists(os.path.join(self.output_dir, 'result.csv'))) + + # 验证结果文件内容 + result_df = pd.read_csv(os.path.join(self.output_dir, 'result.csv')) + self.assertEqual(len(result_df), 1) + self.assertIn('a_0', result_df['Name'].values) + + def test_large_number_of_files(self): + """测试大量文件比对""" + # 创建100对npy文件 + for i in range(100): + np.save(os.path.join(self.npu_dir, f"data_float32_{i}.npy"), np.random.rand(20, 20).astype(np.float32)) + np.save(os.path.join(self.bench_dir, f"data_float32_{i}.npy"), np.random.rand(20, 20).astype(np.float32)) + + input_params = {'npu_path': self.npu_dir, 'bench_path': self.bench_dir} + result = common_dir_compare(input_params, self.output_dir) + + # 验证所有结果都被处理 + result_df = pd.read_csv(os.path.join(self.output_dir, 'result.csv')) + self.assertEqual(len(result_df), 100) + + def test_empty_directory(self): + """测试空目录""" + input_params = {'npu_path': self.npu_dir, 'bench_path': self.bench_dir} + result = common_dir_compare(input_params, self.output_dir) + + # 应该没有结果文件生成 + self.assertEqual(len(os.listdir(self.output_dir)), 0) + + def test_different_data_types(self): + """测试不同数据类型的npy文件""" + np.save(os.path.join(self.npu_dir, "type_float32_0.npy"), np.random.rand(2, 2).astype(np.float32)) + np.save(os.path.join(self.bench_dir, "type_float64_0.npy"), np.random.rand(2, 2).astype(np.float64)) + + input_params = {'npu_path': self.npu_dir, 'bench_path': self.bench_dir, 'map_dict': {'type_float32': 'type_float64'}} + result = common_dir_compare(input_params, self.output_dir) + + # 验证数据类型被正确记录 + result_df = pd.read_csv(os.path.join(self.output_dir, 'result.csv')) + self.assertEqual(result_df.iloc[0]['NPU Dtype'], 'float32') + self.assertEqual(result_df.iloc[0]['Bench Dtype'], 'float64') + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_compare.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_compare.py index b5cbff9784a837ea4d64ac9eccdf30175564f712..eafe9384618502390b41adedd7d32db172ca8188 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_compare.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_compare.py @@ -1,19 +1,18 @@ # coding=utf-8 -import json import os -import random import shutil -import tempfile +import random import unittest +from unittest.mock import patch -import numpy as np import torch -import yaml +import numpy as np -from msprobe.core.common.utils import CompareException -from msprobe.core.compare.acc_compare import ModeConfig -from msprobe.mindspore.compare.ms_compare import MappingConfig, MSComparator, check_cross_framework +from msprobe.mindspore.compare.ms_compare import check_cross_framework, read_real_data, ms_compare from msprobe.core.common.const import Const +from msprobe.test.core_ut.compare.test_acc_compare import generate_dump_json, generate_stack_json +from msprobe.core.common.utils import CompareException + npu_dict = {'op_name': ['Functional.conv2d.0.forward.input.0', 'Functional.conv2d.0.forward.input.1', 'Functional.conv2d.0.forward.input.2', 'Functional.conv2d.0.forward.output'], @@ -173,6 +172,8 @@ json_data_template = { 'data': {} } +base_dir1 = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'test_ms_compare1') + def gen_data(is_ms=True): type_value = 'mindspore.Tensor' if is_ms else 'torch.Tensor' @@ -188,349 +189,65 @@ def gen_data(is_ms=True): } -def gen_api_mapping_test_data(need_user_mapping=False): - result_npu = json_data_template.copy() - result_bench = json_data_template.copy() - - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - mapping_config = MappingConfig() - ms_comparator = MSComparator(mode_config, mapping_config) - - api_mapping = ms_comparator.load_internal_api() - ms_api_list = np.random.choice(list(api_mapping.keys()), size=5, replace=False).astype(str).tolist() - ms_api_data = {} - pt_api_data = {} - user_mapping = [] - for api in ms_api_list: - call_num = random.randint(1, 10) - direction = random.choice(['forward', 'backward']) - data_name_ms = api + '.' + str(call_num) + '.' + direction - data_name_pt = api_mapping.get(api) + '.' + str(call_num) + '.' + direction - input_num = random.randint(1, 5) - output_num = random.randint(1, 5) - ms_data = {'input_args': [gen_data(True) for _ in range(input_num)], - 'output': [gen_data(True) for _ in range(output_num)]} - pt_data = {'input_args': [gen_data(False) for _ in range(input_num)], - 'output': [gen_data(False) for _ in range(output_num)]} - ms_api_data[data_name_ms] = ms_data - pt_api_data[data_name_pt] = pt_data - if need_user_mapping: - compare_num_input = random.randint(1, input_num) - compare_num_output = random.randint(1, output_num) - user_mapping_item = {'ms_api': api, - 'pt_api': api_mapping.get(api), - 'ms_args': sorted(np.random.choice(list(range(input_num)), size=compare_num_input, - replace=False).astype(int).tolist()), - 'pt_args': sorted(np.random.choice(list(range(input_num)), size=compare_num_input, - replace=False).astype(int).tolist()), - 'ms_output': sorted(np.random.choice(list(range(output_num)), size=compare_num_output, - replace=False).astype(int).tolist()), - 'pt_output': sorted(np.random.choice(list(range(output_num)), size=compare_num_output, - replace=False).astype(int).tolist())} - user_mapping.append(user_mapping_item) - ms_api_key_list = list(ms_api_data.keys()) - random.shuffle(ms_api_key_list) - result_npu['data'] = {k: ms_api_data.get(k) for k in ms_api_key_list} - pt_api_key_list = list(pt_api_data.keys()) - random.shuffle(pt_api_key_list) - result_bench['data'] = {k: pt_api_data.get(k) for k in pt_api_key_list} - return result_npu, result_bench, user_mapping +class TestUtilsMethods(unittest.TestCase): + def setUp(self): + os.makedirs(base_dir1, mode=0o750, exist_ok=True) + np.save(os.path.join(base_dir1, 'numpy_data.npy'), np.array([1, 2, 3])) + torch.save(torch.tensor([2, 3, 4]), os.path.join(base_dir1, 'torch_data.pt')) -class TestUtilsMethods(unittest.TestCase): + def tearDown(self): + if os.path.exists(base_dir1): + shutil.rmtree(base_dir1) - def test_check_op_ms(self): - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.ALL + @patch('msprobe.mindspore.compare.utils.detect_framework_by_dump_json') + def test_check_cross_framework_valid_pytorch(self, mock_detect_framework): + mock_detect_framework.return_value = Const.PT_FRAMEWORK - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - mapping_config = MappingConfig() + result = check_cross_framework("dummy_path") - ms_comparator = MSComparator(mode_config, mapping_config) - result = ms_comparator.check_op(npu_dict, bench_dict) self.assertTrue(result) - def test_data_mapping(self): - stack_json_data = {} - - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - mapping_config = MappingConfig(data_mapping=data_mapping) - ms_comparator = MSComparator(mode_config, mapping_config) - - npu_ops_all = ms_comparator.merge_data(npu_json_data, stack_json_data) - npu_ops_all_correct = { - 'Functional.flash_attention_score.4.forward.input.0': { - 'struct': ('BFloat16', [4096, 1, 2048]), - 'summary': [4.1875, -4.4375, -4.550282028503716e-05, 2316.379150390625], - 'data_name': None, - 'stack_info': [None] - }, - 'Functional.flash_attention_score.4.forward.output.0': { - 'struct': ('BFloat16', [4096, 1, 2048]), - 'summary': [4.1875, -4.4375, -4.550282028503716e-05, 2316.379150390625], - 'data_name': None, - 'stack_info': [None] - } - } - self.assertDictEqual(npu_ops_all, npu_ops_all_correct) - - bench_ops_all = ms_comparator.merge_data(bench_json_data, stack_json_data) - bench_ops_all_correct = { - 'NPU.npu_fusion_attention.4.forward.input.0': { - 'struct': ('torch.bfloat16', [4096, 1, 2048]), - 'summary': [4.1875, -4.4375, -4.553794860839844e-05, 2320.0], - 'data_name': None, - 'stack_info': [None] - }, - 'NPU.npu_fusion_attention.4.forward.output.0': { - 'struct': ('torch.bfloat16', [4096, 1, 2048]), - 'summary': [4.1875, -4.4375, -4.553794860839844e-05, 2320.0], - 'data_name': None, - 'stack_info': [None] - } - } - self.assertDictEqual(bench_ops_all, bench_ops_all_correct) - - result = ms_comparator.get_accuracy(npu_ops_all, bench_ops_all) - result_correct = [['Functional.flash_attention_score.4.forward.input.0', - 'NPU.npu_fusion_attention.4.forward.input.0', - 'BFloat16', 'torch.bfloat16', [4096, 1, 2048], [4096, 1, 2048], 0.0, 0.0, - 3.512832336127758e-08, -3.620849609375, '0.0%', '0.0%', '0.07714076816099476%', - '0.1560711038523707%', 4.1875, -4.4375, -4.550282028503716e-05, 2316.379150390625, - 4.1875, -4.4375, -4.553794860839844e-05, 2320.0, '', '', None], - ['Functional.flash_attention_score.4.forward.output.0', - 'NPU.npu_fusion_attention.4.forward.output.0', - 'BFloat16', 'torch.bfloat16', [4096, 1, 2048], [4096, 1, 2048], 0.0, 0.0, - 3.512832336127758e-08, -3.620849609375, '0.0%', '0.0%', '0.07714076816099476%', - '0.1560711038523707%', 4.1875, -4.4375, -4.550282028503716e-05, 2316.379150390625, - 4.1875, -4.4375, -4.553794860839844e-05, 2320.0, '', '', None] - ] - self.assertListEqual(result, result_correct) - - def test_dm_tensor_task(self): - self.compare_process_custom(dump_mode=Const.ALL) - - def compare_process_custom(self, dump_mode): - data_path = tempfile.mkdtemp(prefix='dump_data', dir='/tmp') - try: - npu_dump_path = os.path.join(data_path, 'npu_dump.json') - bench_dump_path = os.path.join(data_path, 'bench_dump.json') - npu_stack_path = os.path.join(data_path, 'npu_stack.json') - - with open(npu_dump_path, 'w') as n_d_f: - json.dump(npu_json_data, n_d_f) - with open(bench_dump_path, 'w') as b_d_f: - json.dump(bench_json_data, b_d_f) - with open(npu_stack_path, 'w') as n_s_f: - json.dump({}, n_s_f) - - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - mapping_config = MappingConfig() - - ms_comparator = MSComparator(mode_config, mapping_config) - result_df = ms_comparator.compare_process_custom((npu_dump_path, bench_dump_path, npu_stack_path)) - self.assertListEqual(result_df.values.tolist(), []) - finally: - shutil.rmtree(data_path) - - def test_check_cross_framework(self): - ms_data = { - "data_name": "Cell.model.language_model.encoder.layers.5.input_norm.FusedRMSNorm.forward.0.input.0.npy", - } - pt_data = { - "data_name": "Module.module.module.language_model.encoder.layers.0.input_norm.RMSNorm.forward.0.input.0.pt", + @patch('msprobe.mindspore.compare.utils.detect_framework_by_dump_json') + def test_check_cross_framework_invalid_framework(self, mock_detect_framework): + mock_detect_framework.return_value = Const.MS_FRAMEWORK + + result = check_cross_framework("dummy_path") + + self.assertFalse(result) + + def test_read_real_data_ms(self): + n_value, b_value = read_real_data(base_dir1, 'numpy_data.npy', base_dir1, 'numpy_data.npy', False) + self.assertTrue(np.array_equal(n_value, np.array([1, 2, 3]))) + self.assertTrue(np.array_equal(b_value, np.array([1, 2, 3]))) + + def test_read_real_data_cross_frame(self): + n_value, b_value = read_real_data(base_dir1, 'numpy_data.npy', base_dir1, 'torch_data.pt', True) + self.assertTrue(np.array_equal(n_value, np.array([1, 2, 3]))) + self.assertTrue(np.array_equal(b_value, np.array([2, 3, 4]))) + + def test_ms_compare(self): + generate_dump_json(base_dir1) + generate_stack_json(base_dir1) + + dump_path = os.path.join(base_dir1, 'dump.json') + + input_param = { + 'npu_json_path': dump_path, + 'bench_json_path': dump_path, + 'is_print_compare_log': True } + output_path = base_dir1 - def check_data(data): - with tempfile.NamedTemporaryFile(mode='w+', suffix='.json', encoding='utf-8', delete=True) as temp_file: - json.dump(data, temp_file, ensure_ascii=False, indent=4) - temp_file.flush() - return check_cross_framework(temp_file.name) - self.assertFalse(check_data(ms_data)) - self.assertTrue(check_data(pt_data)) - - def test_comapre_process(self): - data_path = tempfile.mkdtemp(prefix='dump_data', dir='/tmp') - try: - npu_dump_path = os.path.join(data_path, 'npu_dump.json') - bench_dump_path = os.path.join(data_path, 'bench_dump.json') - npu_stack_path = os.path.join(data_path, 'npu_stack.json') - - npu_data, bench_data, _ = gen_api_mapping_test_data() - with open(npu_dump_path, 'w', encoding='utf8') as n_d_f: - json.dump(npu_data, n_d_f) - with open(bench_dump_path, 'w', encoding='utf8') as b_d_f: - json.dump(bench_data, b_d_f) - with open(npu_stack_path, 'w', encoding='utf8') as n_s_f: - json.dump({}, n_s_f) - - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - mapping_config = MappingConfig(api_mapping=True) - - ms_comparator = MSComparator(mode_config, mapping_config) - result_df = ms_comparator.compare_process((npu_dump_path, bench_dump_path, npu_stack_path)) - self.assertTrue((result_df['Bench Name'] != 'N/A').all()) - finally: - shutil.rmtree(data_path) - - def test_compare_process_with_customize_api_mapping(self): - data_path = tempfile.mkdtemp(prefix='dump_data', dir='/tmp') - try: - npu_dump_path = os.path.join(data_path, 'npu_dump.json') - bench_dump_path = os.path.join(data_path, 'bench_dump.json') - npu_stack_path = os.path.join(data_path, 'npu_stack.json') - user_mapping_path = os.path.join(data_path, 'user_mapping.yaml') - - npu_data, bench_data, user_mapping = gen_api_mapping_test_data(True) - with open(npu_dump_path, 'w', encoding='utf8') as n_d_f: - json.dump(npu_data, n_d_f) - with open(bench_dump_path, 'w', encoding='utf8') as b_d_f: - json.dump(bench_data, b_d_f) - with open(npu_stack_path, 'w', encoding='utf8') as n_s_f: - json.dump({}, n_s_f) - with open(user_mapping_path, 'w', encoding='utf8') as u_m_f: - yaml.safe_dump(user_mapping, u_m_f) - - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - mapping_config = MappingConfig(api_mapping=user_mapping_path) - - ms_comparator = MSComparator(mode_config, mapping_config) - result_df = ms_comparator.compare_process((npu_dump_path, bench_dump_path, npu_stack_path)) - - user_mapping_dict = {} - for i in user_mapping: - user_mapping_dict[i.get('ms_api')] = {'input': i.get('ms_args'), 'output': i.get('ms_output')} - match_set = set() - for key in npu_data.get('data').keys(): - matched_dict = user_mapping_dict.get(key.rsplit('.', 2)[0]) - match_set.update({key + '.input.' + str(i) for i in matched_dict.get('input')}) - match_set.update({key + '.output.' + str(i) for i in matched_dict.get('output')}) - - self.assertTrue((result_df.loc[result_df['NPU Name'].isin(match_set), 'Bench Name'] != 'N/A').all()) - self.assertTrue((result_df.loc[~result_df['NPU Name'].isin(match_set), 'Bench Name'] == 'N/A').all()) - finally: - shutil.rmtree(data_path) - - def test_load_internal_api(self): - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - mapping_config = MappingConfig() - - ms_comparator = MSComparator(mode_config, mapping_config) - api_dict = ms_comparator.load_internal_api() - self.assertEqual(api_dict['Functional.abs'], 'Torch.abs') - - def test_process_cell_mapping(self): - self.base_test_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) - self.input_dir = os.path.join(self.base_test_dir, 'resources') - cell_mapping_path = os.path.join(self.input_dir, 'common', 'cell_mapping.yaml') - - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - mapping_config = MappingConfig(cell_mapping=cell_mapping_path) - - ms_comparator = MSComparator(mode_config, mapping_config) - npu_op_name = ms_comparator.process_cell_mapping(npu_cell_dict.get('op_name')[0]) - self.assertEqual(npu_op_name, 'Module.fc1.Linear.forward.0.input.0') - - def test_read_npy_data(self): - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.ALL - - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - mapping_config = MappingConfig() - - ms_comparator = MSComparator(mode_config, mapping_config) - - self.temp_file = tempfile.NamedTemporaryFile(suffix='.pt') - tensor = torch.Tensor([1, 2, 3]) - filename = self.temp_file.name.split('/')[-1] - torch.save(tensor, self.temp_file.name) - result = ms_comparator.read_npy_data('/tmp', filename, load_pt_file=True) - self.assertTrue(np.array_equal(result, np.array([1, 2, 3]))) - self.temp_file.close() - - self.temp_file = tempfile.NamedTemporaryFile(suffix='.npy') - tensor = np.array([1, 2, 3]) - filename = self.temp_file.name.split('/')[-1] - np.save(self.temp_file.name, tensor) - result = ms_comparator.read_npy_data('/tmp', filename, load_pt_file=False) - self.assertTrue(np.array_equal(result, np.array([1, 2, 3]))) - self.temp_file.close() - - def test_process_internal_api_mapping(self): - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.ALL - - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - mapping_config = MappingConfig(api_mapping=1) - - ms_comparator = MSComparator(mode_config, mapping_config) - - npu_op_name = "Mint.addcmul.0.forward.input.0" - result = ms_comparator.process_internal_api_mapping(npu_op_name) - self.assertEqual(result, "Torch.addcmul.0.forward.input.0") - - npu_op_name = "MintFunctional.addcmul.0.forward.input.0" - result = ms_comparator.process_internal_api_mapping(npu_op_name) - self.assertEqual(result, "Functional.addcmul.0.forward.input.0") - - npu_op_name = "Functional.abs" - result = ms_comparator.process_internal_api_mapping(npu_op_name) - self.assertEqual(result, "Torch.abs") - - def test_get_api_name(self): - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.ALL - - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - mapping_config = MappingConfig() - - ms_comparator = MSComparator(mode_config, mapping_config) - - api_list = ["Functional", "absolute", "0", "forward", "input", "0"] - result = ms_comparator.get_api_name(api_list) - self.assertEqual(result, "Functional.absolute") - - api_list = ["Mint"] - with self.assertRaises(CompareException): - ms_comparator.get_api_name(api_list) \ No newline at end of file + ms_compare(input_param, output_path) + output_files = os.listdir(output_path) + self.assertTrue(any(f.endswith(".xlsx") for f in output_files)) + + input_param2 = { + 'npu_json_path': '', + 'bench_json_path': dump_path, + 'is_print_compare_log': True + } + with self.assertRaises(CompareException) as context: + ms_compare(input_param2, output_path) + self.assertEqual(context.exception.code, 1) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_compare_utils.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_compare_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d7fb5e38fb82b309caf3ab2a1b621655d7babc86 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_compare_utils.py @@ -0,0 +1,24 @@ +import unittest +from unittest.mock import patch + +import numpy as np + +from msprobe.core.common.file_utils import FileCheckConst +from msprobe.mindspore.compare.utils import read_npy_data + + +class TestReadNpyData(unittest.TestCase): + + @patch('msprobe.mindspore.compare.utils.load_npy') + @patch('msprobe.mindspore.compare.utils.FileChecker') + @patch('os.path.join', return_value='/fake/path/to/file.npy') + def test_read_real_data_ms(self, mock_os, mock_file_checker, mock_load_npy): + mock_file_checker.return_value.common_check.return_value = '/fake/path/to/file.npy' + + mock_load_npy.return_value = np.array([1.0, 2.0, 3.0]) + + result = read_npy_data('/fake/dir', 'file_name.npy') + + mock_file_checker.assert_called_once_with('/fake/path/to/file.npy', FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.NUMPY_SUFFIX, False) + mock_load_npy.assert_called_once_with('/fake/path/to/file.npy') + self.assertTrue(np.array_equal(result, np.array([1.0, 2.0, 3.0]))) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_graph_compare.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_graph_compare.py index e3fd9348efe7dd4df0a6db2cd52a45f4757dae01..56f2044eb14062c789c34d66a5e593f4e5fc9fb1 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_graph_compare.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_graph_compare.py @@ -78,7 +78,7 @@ class TestMsGraphCompare(unittest.TestCase): result_correct = ( f"[['{npu_file_path}', '{bench_file_path}', dtype('float16'), dtype('float16'), (10, 10), (10, 10), " - f"44.0, 44.0, 44.0, inf, 44.0, 44.0, 44.0, inf, 'Yes', '', 1.0, 0.0, 0.0, 1.0, 1.0]]") + f"44.0, 44.0, 44.0, inf, 44.0, 44.0, 44.0, inf, 'Yes', '', 1.0, 0.0, 0.0, 0.0, 1.0, 1.0]]") self.assertNotEqual(len(files), 0) self.assertEqual(result, result_correct) @@ -93,15 +93,20 @@ class TestMsGraphCompare(unittest.TestCase): compare_result_db = ms_graph_comparator.compare_ops(compare_result_db, mode) result = compare_result_db.values.tolist() - op_name = 'Default_Switch-op1_kernel_graph1_Data_86.185.41.output' + op_name = 'Default_Switch-op1_kernel_graph1_Data_86.185.output.0' npu_file_path = os.path.join(self.npu_data_path, 'rank_0/mnist/0/0/statistic.csv') bench_file_path = os.path.join(self.bench_data_path, 'rank_0/mnist/0/0/statistic.csv') - npu_name = f'{op_name} {npu_file_path}' - bench_name = f'{op_name} {bench_file_path}' + npu_name = f'{op_name}' + bench_name = f'{op_name}' + npu_csv_file = f'{npu_file_path}' + bench_csv_file = f'{bench_file_path}' result_correct = [ [npu_name, bench_name, 'float32', 'float32', '-4096', '-4096', 1.0000799894332886, 0.9999160170555115, 1.0, 63.9995002746582, 1.0000799894332886, 0.9999160170555115, 1.0, 63.9995002746582, 'Yes', '', 0.0, 0.0, 0.0, - 0.0, '0.0%', '0.0%', '0.0%', '0.0%']] + 0.0, '0.0%', '0.0%', '0.0%', '0.0%', + npu_csv_file, bench_csv_file + ] + ] self.assertListEqual(result, result_correct) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_graph_cell_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_graph_cell_dump.py new file mode 100644 index 0000000000000000000000000000000000000000..60a54d9e1523da8b9fa05bf3aef3a02b667e26c3 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_graph_cell_dump.py @@ -0,0 +1,379 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import re +import unittest +from unittest.mock import MagicMock, patch +import tempfile +import sys +from types import SimpleNamespace + +import mindspore as ms +from mindspore import ops +import pandas as pd + +from msprobe.core.common.const import Const as CoreConst +from msprobe.mindspore.dump import cell_dump_process +from msprobe.mindspore.dump.cell_dump_process import cell_construct_wrapper +from msprobe.mindspore.dump.cell_dump_process import convert_special_values, sort_filenames +from msprobe.mindspore.dump.cell_dump_process import check_relation +from msprobe.mindspore.dump.cell_dump_process import process_csv, np_ms_dtype_dict +from msprobe.mindspore.dump.cell_dump_process import create_kbyk_json + +class TestCellWrapperProcess(unittest.TestCase): + + @patch('msprobe.mindspore.dump.cell_dump_process.ops.is_tensor') + @patch('msprobe.mindspore.dump.cell_dump_process.td') + @patch('msprobe.mindspore.dump.cell_dump_process.td_in') + def test_cell_construct_wrapper(self, mock_td_in, mock_td, mock_istensor): + + # Mock the TensorDump operations + mock_td.return_value = MagicMock() + mock_td_in.return_value = MagicMock() + mock_istensor.return_value = False + + # Create a mock cell with necessary attributes + mock_cell = MagicMock() + mock_cell.data_mode = "all" + mock_cell.dump_path = "mock_dump_path" + mock_cell.cell_prefix = "mock_cell_prefix" + + # Define a mock function to wrap + def mock_func(*args, **kwargs): + return args + + # Wrap the mock function using cell_construct_wrapper + wrapped_func = cell_construct_wrapper(mock_func, mock_cell) + + # Create mock inputs + mock_input = ms.Tensor([1, 2, 3]) + mock_args = (mock_input,) + + # Call the wrapped function + wrapped_func(mock_cell, *mock_args) + + # Verify that the TensorDump operations were not called + mock_td_in.assert_not_called() + mock_td.assert_not_called() + + +class TestSortFilenames(unittest.TestCase): + + @patch('os.listdir') + def test_sort_filenames(self, mock_listdir): + # Mock the list of filenames returned by os.listdir + mock_listdir.return_value = [ + 'Cell.network._backbone.model.LlamaModel.backward.0.input.0_float16_177.npy', + 'Cell.network._backbone.model.LlamaModel.forward.0.input.0_in_int32_1.npy', + 'Cell.network._backbone.model.LlamaModel.forward.0.output.10_float16_165.npy', + 'Cell.network._backbone.model.norm_out.LlamaRMSNorm.backward.0.input.0_float16_178.npy' + ] + + # Mock the CoreConst values + CoreConst.REPLACEMENT_CHARACTER = '_' + CoreConst.NUMPY_SUFFIX = '.npy' + + # Expected sorted filenames + expected_sorted_filenames = [ + 'Cell.network._backbone.model.LlamaModel.forward.0.input.0_in_int32_1.npy', + 'Cell.network._backbone.model.LlamaModel.forward.0.output.10_float16_165.npy', + 'Cell.network._backbone.model.LlamaModel.backward.0.input.0_float16_177.npy', + 'Cell.network._backbone.model.norm_out.LlamaRMSNorm.backward.0.input.0_float16_178.npy' + ] + + # Call the function + sorted_filenames = sort_filenames('/mock/path') + + # Assert the filenames are sorted correctly + self.assertEqual(sorted_filenames, expected_sorted_filenames) + + +class TestCheckRelation(unittest.TestCase): + + def setUp(self): + CoreConst.SEP = '.' + global KEY_LAYERS + KEY_LAYERS = "layers" + + def test_direct_parent_child_relation(self): + self.assertTrue(check_relation("network._backbone", "network")) + self.assertTrue(check_relation("network._backbone.model", "network._backbone")) + + def test_no_relation(self): + self.assertFalse(check_relation("network._backbone", "network.loss")) + self.assertFalse(check_relation("network._backbone.model", "network.loss")) + + def test_layer_pattern_relation(self): + self.assertTrue(check_relation("network.model.layers.0", "network.model")) + self.assertTrue(check_relation("network._backbone.model.layers.1", "network._backbone.model")) + + def test_edge_cases(self): + self.assertFalse(check_relation("", "network")) + self.assertFalse(check_relation("network.layer1", "")) + self.assertFalse(check_relation("", "")) + + +class TestRenameFilename(unittest.TestCase): + def setUp(self): + self.logger_patcher = patch.object(cell_dump_process, "logger", MagicMock()) + self.logger_patcher.start() + + def tearDown(self): + self.logger_patcher.stop() + + @patch.object(cell_dump_process, "sort_filenames") + @patch("msprobe.mindspore.dump.cell_dump_process.move_file") + def test_rename_filename_tensor(self, mock_move_file, mock_sort_filenames): + cell_dump_process.dump_task = CoreConst.TENSOR + + with tempfile.TemporaryDirectory() as tmpdir: + filenames = [ + "Cell.a.b.c.X.forward.input.0_float32_1.npy", + "Cell.a.b.c.X.forward.input.1_float32_2.npy", + "Cell.a.b.c.X.forward.output.0_float32_3.npy", + "Cell.a.b.c.X.forward.input.0_float32_11.npy", + "Cell.a.b.c.X.forward.input.1_float32_12.npy", + "Cell.a.b.c.X.forward.output.0_float32_13.npy", + "Cell.a.b.c.X.backward.input.0_float32_30.npy" + ] + for fname in filenames: + with open(os.path.join(tmpdir, fname), "wb") as f: + f.write(b"dummy") + + mock_sort_filenames.return_value = filenames + + rename_calls = [] + def fake_rename(src, dst): + rename_calls.append((os.path.basename(src), os.path.basename(dst))) + mock_move_file.side_effect = fake_rename + + cell_dump_process.rename_filename(path=tmpdir) + + expected = [ + ("Cell.a.b.c.X.forward.input.0_float32_1.npy", "Cell.a.b.c.X.forward.0.input.0_float32_1.npy"), + ("Cell.a.b.c.X.forward.input.1_float32_2.npy", "Cell.a.b.c.X.forward.0.input.1_float32_2.npy"), + ("Cell.a.b.c.X.forward.output.0_float32_3.npy", "Cell.a.b.c.X.forward.0.output.0_float32_3.npy"), + ("Cell.a.b.c.X.forward.input.0_float32_11.npy", "Cell.a.b.c.X.forward.1.input.0_float32_11.npy"), + ("Cell.a.b.c.X.forward.input.1_float32_12.npy", "Cell.a.b.c.X.forward.1.input.1_float32_12.npy"), + ("Cell.a.b.c.X.forward.output.0_float32_13.npy", "Cell.a.b.c.X.forward.1.output.0_float32_13.npy"), + ("Cell.a.b.c.X.backward.input.0_float32_30.npy", "Cell.a.b.c.X.backward.0.input.0_float32_30.npy") + ] + self.assertEqual(rename_calls, expected) + + @patch("msprobe.mindspore.dump.cell_dump_process.move_file") + def test_rename_filename_statistics(self, mock_move_file): + cell_dump_process.dump_task = CoreConst.STATISTICS + + data = { + 'Op Name': [ + "Cell.a.b.c.X.forward.input.0", + "Cell.a.b.c.X.forward.input.1", + "Cell.a.b.c.X.forward.output.0", + "Cell.a.b.c.X.forward.input.0", + "Cell.a.b.c.X.forward.input.1", + "Cell.a.b.c.X.forward.output.0", + "Cell.a.b.c.X.backward.input.0" + ] + } + df = pd.DataFrame(data) + + cell_dump_process.rename_filename(data_df=df) + + self.assertEqual(df['Op Name'].iloc[0], "Cell.a.b.c.X.forward.0.input.0") + self.assertEqual(df['Op Name'].iloc[1], "Cell.a.b.c.X.forward.0.input.1") + self.assertEqual(df['Op Name'].iloc[2], "Cell.a.b.c.X.forward.0.output.0") + self.assertEqual(df['Op Name'].iloc[3], "Cell.a.b.c.X.forward.1.input.0") + self.assertEqual(df['Op Name'].iloc[4], "Cell.a.b.c.X.forward.1.input.1") + self.assertEqual(df['Op Name'].iloc[5], "Cell.a.b.c.X.forward.1.output.0") + self.assertEqual(df['Op Name'].iloc[6], "Cell.a.b.c.X.backward.0.input.0") + + +class TestConvertSpecialValues(unittest.TestCase): + TEST_CASES = [ + ("true", True), + ("True", True), + ("false", False), + ("False", False), + ("1.23", 1.23), + ("0", 0.0), + ("-5.6", -5.6), + (42, 42), + (3.14, 3.14), + (pd.NA, None) + ] + + def test_convert_special_values(self): + for input_value, expected in self.TEST_CASES: + result = convert_special_values(input_value) + self.assertEqual(result, expected) + + +class TestProcessCsv(unittest.TestCase): + + @staticmethod + def make_df(rows): + import pandas as pd + return pd.DataFrame(rows) + + @patch("msprobe.mindspore.dump.cell_dump_process.read_csv") + def test_process_csv_input_and_output(self, mock_read_csv): + rows = [ + { + 'Op Name': 'Cell.net.layer.forward.0.input.0', + 'Shape': '(2,3)', + 'Data Type': 'float32', + 'Max Value': 1.0, + 'Min Value': 0.0, + 'Avg Value': 0.5, + 'L2Norm Value': 2.0 + }, + { + 'Op Name': 'Cell.net.layer.forward.0.output.0', + 'Shape': '(2,3)', + 'Data Type': 'float32', + 'Max Value': 2.0, + 'Min Value': -1.0, + 'Avg Value': 0.0, + 'L2Norm Value': 3.0 + } + ] + df = self.make_df(rows) + mock_read_csv.return_value = df + + result = process_csv("dummy_path") + self.assertEqual(len(result), 2) + + op_name, key, tensor_json = result[0] + self.assertEqual(op_name, 'Cell.net.layer.forward.0') + self.assertEqual(key, CoreConst.INPUT_ARGS) + self.assertEqual(tensor_json[CoreConst.TYPE], 'mindspore.Tensor') + self.assertEqual(tensor_json[CoreConst.DTYPE], str(np_ms_dtype_dict['float32'])) + self.assertEqual(tensor_json[CoreConst.SHAPE], [2, 3]) + self.assertEqual(tensor_json[CoreConst.MAX], 1.0) + self.assertEqual(tensor_json[CoreConst.MIN], 0.0) + self.assertEqual(tensor_json[CoreConst.MEAN], 0.5) + self.assertEqual(tensor_json[CoreConst.NORM], 2.0) + + op_name, key, tensor_json = result[1] + self.assertEqual(op_name, 'Cell.net.layer.forward.0') + self.assertEqual(key, CoreConst.OUTPUT) + self.assertEqual(tensor_json[CoreConst.MAX], 2.0) + self.assertEqual(tensor_json[CoreConst.MIN], -1.0) + self.assertEqual(tensor_json[CoreConst.MEAN], 0.0) + self.assertEqual(tensor_json[CoreConst.NORM], 3.0) + + @patch("msprobe.mindspore.dump.cell_dump_process.read_csv") + def test_process_csv_handles_missing_columns(self, mock_read_csv): + rows = [ + { + 'Op Name': 'Cell.net.layer.forward.0.input.0', + 'Shape': '(1,)', + 'Data Type': 'int32' + } + ] + df = self.make_df(rows) + mock_read_csv.return_value = df + + result = process_csv("dummy_path") + self.assertEqual(len(result), 1) + op_name, key, tensor_json = result[0] + self.assertEqual(tensor_json[CoreConst.DTYPE], str(np_ms_dtype_dict['int32'])) + self.assertEqual(tensor_json[CoreConst.SHAPE], [1]) + + @patch("msprobe.mindspore.dump.cell_dump_process.read_csv") + def test_process_csv_handles_unknown_io_key(self, mock_read_csv): + rows = [ + { + 'Op Name': 'Cell.net.layer.forward.0.unknown.0', + 'Shape': '(1,2)', + 'Data Type': 'float16' + } + ] + df = self.make_df(rows) + mock_read_csv.return_value = df + + result = process_csv("dummy_path") + self.assertEqual(len(result), 1) + op_name, key, tensor_json = result[0] + self.assertIsNone(op_name) + self.assertIsNone(key) + self.assertIsNone(tensor_json) + + @patch("msprobe.mindspore.dump.cell_dump_process.read_csv") + def test_process_csv_shape_parsing(self, mock_read_csv): + rows = [ + { + 'Op Name': 'Cell.net.layer.forward.0.input.0', + 'Shape': '(4, 5, 6)', + 'Data Type': 'float64' + } + ] + df = self.make_df(rows) + mock_read_csv.return_value = df + + result = process_csv("dummy_path") + self.assertEqual(result[0][2][CoreConst.SHAPE], [4, 5, 6]) + + @patch("msprobe.mindspore.dump.cell_dump_process.read_csv") + def test_process_csv_convert_special_values_bool_and_nan(self, mock_read_csv): + rows = [ + { + 'Op Name': 'Cell.net.layer.forward.0.input.0', + 'Shape': '(1,)', + 'Data Type': 'float32', + 'Max Value': 'True', + 'Min Value': 'False', + 'Avg Value': float('nan'), + 'L2Norm Value': 1.23 + } + ] + df = self.make_df(rows) + mock_read_csv.return_value = df + + result = process_csv("dummy_path") + tensor_json = result[0][2] + self.assertIs(tensor_json[CoreConst.MAX], True) + self.assertIs(tensor_json[CoreConst.MIN], False) + self.assertIsNone(tensor_json[CoreConst.MEAN]) + self.assertEqual(tensor_json[CoreConst.NORM], 1.23) + + +class TestCreateKbykJsonMultiRank(unittest.TestCase): + @patch("msprobe.mindspore.dump.cell_dump_process.create_directory", lambda path: None) + @patch( + "msprobe.mindspore.dump.cell_dump_process.save_json", + lambda path, data, indent=4: open(path, "w").write("test") + ) + def test_create_kbyk_json_multi_rank(self): + + test_cases = [ + (None, "0kernel_kbyk_dump.json"), + ("1", "1kernel_kbyk_dump.json"), + ("3", "3kernel_kbyk_dump.json"), + ] + + for rank_id_env, expected_prefix in test_cases: + with tempfile.TemporaryDirectory() as dump_path: + summary_mode = ["max"] + step = 0 + # Patch environment variable + if rank_id_env is not None: + with patch.dict(os.environ, {"RANK_ID": rank_id_env}): + config_json_path = create_kbyk_json(dump_path, summary_mode, step) + else: + with patch.dict(os.environ, {}, clear=True): + config_json_path = create_kbyk_json(dump_path, summary_mode, step) + self.assertEqual(os.path.basename(config_json_path), expected_prefix) + self.assertTrue(config_json_path.startswith(dump_path)) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_ms_debugger_config.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_ms_debugger_config.py index 033b0c1ea5769c3f1f8e19dd8b45c48918e15814..8a7195eac824485e75d8c1ba0752715c7c6a5600 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_ms_debugger_config.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_ms_debugger_config.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,14 +17,17 @@ import unittest from unittest.mock import patch from msprobe.core.common.const import Const -from msprobe.core.common_config import CommonConfig, BaseConfig +from msprobe.core.common.log import logger +from msprobe.core.common_config import CommonConfig from msprobe.mindspore.common.const import FreeBenchmarkConst from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +from msprobe.mindspore.ms_config import StatisticsConfig class TestDebuggerConfig(unittest.TestCase): + @patch.object(logger, "error") @patch("msprobe.mindspore.debugger.debugger_config.create_directory") - def test_init(self, _): + def test_init(self, _, mock_logger_error): json_config = { "dump_path": "/absolute_path", "rank": [], @@ -32,12 +35,13 @@ class TestDebuggerConfig(unittest.TestCase): "level": "L2" } common_config = CommonConfig(json_config) - task_config = BaseConfig(json_config) + task_config = StatisticsConfig(json_config) debugger_config = DebuggerConfig(common_config, task_config) self.assertEqual(debugger_config.task, Const.STATISTICS) self.assertEqual(debugger_config.file_format, "npy") self.assertEqual(debugger_config.check_mode, "all") self.assertEqual(debugger_config.overflow_nums, 1) + self.assertEqual(debugger_config.tensor_list, []) common_config.level = "L1" common_config.task = Const.FREE_BENCHMARK @@ -49,17 +53,16 @@ class TestDebuggerConfig(unittest.TestCase): task_config.handler_type = FreeBenchmarkConst.FIX task_config.pert_mode = FreeBenchmarkConst.ADD_NOISE - with self.assertRaises(Exception) as context: + with self.assertRaises(ValueError): DebuggerConfig(common_config, task_config) - self.assertEqual(str(context.exception), - "pert_mode must be improve_precision or empty when handler_type is fix, " - f"but got {FreeBenchmarkConst.ADD_NOISE}.") + mock_logger_error.assert_called_with("pert_mode must be improve_precision or empty when handler_type is fix, " + f"but got {FreeBenchmarkConst.ADD_NOISE}.") + mock_logger_error.reset_mock() task_config.handler_type = FreeBenchmarkConst.FIX task_config.pert_mode = FreeBenchmarkConst.DEFAULT_PERT_TYPE task_config.fuzz_stage = Const.BACKWARD - with self.assertRaises(Exception) as context: + with self.assertRaises(ValueError): DebuggerConfig(common_config, task_config) - self.assertEqual(str(context.exception), - "handler_type must be check or empty when fuzz_stage is backward, " - f"but got {task_config.handler_type}.") + mock_logger_error.assert_called_with("handler_type must be check or empty when fuzz_stage is backward, " + f"but got {task_config.handler_type}.") diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_ms_precision_debugger.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_ms_precision_debugger.py index 066ff537ce6fba12f712ae3d4681115499be35a6..6d26fc29b9357d4d6096f71d9bb61c3bdbb780ca 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_ms_precision_debugger.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_ms_precision_debugger.py @@ -16,14 +16,16 @@ import unittest from unittest.mock import patch, MagicMock -from msprobe.core.common_config import CommonConfig, BaseConfig from msprobe.core.common.const import Const, MsgConst +from msprobe.core.common_config import CommonConfig +from msprobe.core.debugger.precision_debugger import BasePrecisionDebugger from msprobe.mindspore.cell_processor import CellProcessor from msprobe.mindspore.common.const import Const as MsConst from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.debugger.precision_debugger import PrecisionDebugger from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell -from msprobe.mindspore.runtime import Runtime +from msprobe.mindspore.ms_config import StatisticsConfig +from msprobe.core.common.runtime import Runtime class TestPrecisionDebugger(unittest.TestCase): @@ -48,12 +50,12 @@ class TestPrecisionDebugger(unittest.TestCase): } common_config = CommonConfig(json_config) - task_config = BaseConfig(json_config) + task_config = StatisticsConfig(json_config) handler = Handler() mock_get_mode = MagicMock() mock_parse_json_config = MagicMock() - with patch("msprobe.mindspore.debugger.precision_debugger.parse_json_config", new=mock_parse_json_config), \ + with patch.object(BasePrecisionDebugger, "_parse_config_path", new=mock_parse_json_config), \ patch.object(PrecisionDebugger, "_get_execution_mode", new=mock_get_mode), \ patch("msprobe.mindspore.debugger.precision_debugger.TaskHandlerFactory.create", return_value=handler), \ patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions"): @@ -68,20 +70,20 @@ class TestPrecisionDebugger(unittest.TestCase): self.assertTrue(Handler.called) mock_get_mode.return_value = MsConst.PYNATIVE_MODE - with patch("msprobe.mindspore.debugger.precision_debugger.Service") as mock_Service, \ + with patch("msprobe.mindspore.debugger.precision_debugger.MindsporeService") as mock_Service, \ patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions"): debugger = PrecisionDebugger() debugger.start() service = mock_Service.return_value mock_Service.assert_called_with(debugger.config) - service.start.assert_called_with(None) + service.start.assert_called_with(None, None) PrecisionDebugger._instance = None with self.assertRaises(Exception) as context: debugger.start() self.assertEqual(str(context.exception), MsgConst.NOT_CREATED_INSTANCE) - with patch("msprobe.mindspore.debugger.precision_debugger.parse_json_config", new=mock_parse_json_config), \ + with patch.object(BasePrecisionDebugger, "_parse_config_path", new=mock_parse_json_config), \ patch.object(PrecisionDebugger, "_get_execution_mode", new=mock_get_mode), \ patch("msprobe.mindspore.debugger.precision_debugger.TaskHandlerFactory.create", return_value=handler), \ patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions"): @@ -98,6 +100,9 @@ class TestPrecisionDebugger(unittest.TestCase): def __init__(self): self.task = Const.TENSOR self.service = None + self.config = MagicMock() + self.config.level_ori = MagicMock() + self.config.level_ori.return_value = Const.LEVEL_L1 PrecisionDebugger._instance = None with self.assertRaises(Exception) as context: PrecisionDebugger.stop() @@ -123,20 +128,25 @@ class TestPrecisionDebugger(unittest.TestCase): mock_reset_cell.assert_called_once() def test_forward_backward_dump_end(self): - with patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions"): + json_config = { + "task": "statistics", + "dump_path": "/absolute_path", + "rank": [], + "step": [], + "level": "L1", + "async_dump": False + } + + common_config = CommonConfig(json_config) + task_config = StatisticsConfig(json_config) + with patch.object(BasePrecisionDebugger, "_parse_config_path", return_value=(common_config, task_config)), \ + patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions"): debugger = PrecisionDebugger() debugger.task = "statistics" debugger.service = MagicMock() debugger.forward_backward_dump_end() debugger.service.stop.assert_called_once() - def test_is_graph_dump_level_not_kernel(self): - config = MagicMock() - config.level = "NOT_KERNEL" - config.list = ["some_value"] - result = PrecisionDebugger._is_graph_dump(config) - self.assertFalse(result) - def test_is_graph_dump_empty_list(self): config = MagicMock() config.level = MsConst.KERNEL diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/dump/test_ms_kernel_config.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/dump/test_ms_kernel_config.py index 54c59b6409cb546384dcb50f47c7c27975fa1cb7..e760faefd31b2e2b60c24091c6eebed087f4268f 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/dump/test_ms_kernel_config.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/dump/test_ms_kernel_config.py @@ -16,11 +16,11 @@ import unittest from unittest.mock import patch -from msprobe.mindspore.dump.kernel_dump.kernel_config import create_kernel_config_json +from msprobe.core.kernel_dump.kernel_config import create_kernel_config_json class TestPtKernelConfig(unittest.TestCase): - @patch("msprobe.mindspore.dump.kernel_dump.kernel_config.save_json") + @patch("msprobe.core.kernel_dump.kernel_config.save_json") def test_create_kernel_config_json_with_rank(self, mock_save_json): dump_path = "./step0" cur_rank = 0 @@ -36,7 +36,7 @@ class TestPtKernelConfig(unittest.TestCase): } mock_save_json.assert_called_once_with(kernel_config_path, config_info, indent=4) - @patch("msprobe.mindspore.dump.kernel_dump.kernel_config.save_json") + @patch("msprobe.core.kernel_dump.kernel_config.save_json") def test_create_kernel_config_json_without_rank(self, mock_save_json): dump_path = "./step0" cur_rank = '' diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/common/test_ms_free_benchmark_utils.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/common/test_ms_free_benchmark_utils.py index 1f37e18c6ef8d3facb526f3c54169a17f4616189..d1f1b48dfb9a622b10fcf39c32394339f0a8df9c 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/common/test_ms_free_benchmark_utils.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/common/test_ms_free_benchmark_utils.py @@ -22,7 +22,7 @@ from msprobe.mindspore.common.const import FreeBenchmarkConst from msprobe.mindspore.free_benchmark.common.config import Config from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams from msprobe.mindspore.free_benchmark.common.utils import Tools, UnequalRow, make_unequal_row -from msprobe.mindspore.runtime import Runtime +from msprobe.core.common.runtime import Runtime class TestUtils(unittest.TestCase): diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/handler/test_ms_base_handler.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/handler/test_ms_base_handler.py index d7f5b0745cff481d0bc2e5771df36beb492d4015..91230456911fcace6877869a270264b6b95b6793 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/handler/test_ms_base_handler.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/handler/test_ms_base_handler.py @@ -24,6 +24,7 @@ from msprobe.mindspore.common.log import logger from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams from msprobe.mindspore.free_benchmark.common.utils import Tools from msprobe.mindspore.free_benchmark.handler.base_handler import BaseHandler +from msprobe.mindspore.dump.hook_cell.api_register import get_api_register class Handler(BaseHandler): @@ -45,6 +46,7 @@ class TestBaseHandler(unittest.TestCase): @classmethod def setUpClass(cls): cls.base_handler = Handler("api_name_with_id") + get_api_register(True).restore_all_api() def test___init__(self): base_handler = Handler("api_name_with_id") @@ -93,7 +95,7 @@ class TestBaseHandler(unittest.TestCase): first_tensor = Tensor([1.0, 1.2], dtype=ms.bfloat16) second_tensor = Tensor([1.5, 2.0], dtype=ms.bfloat16) - target = ops.max(ops.div(second_tensor.to(ms.float32), first_tensor.to(ms.float32)))[0].item() + target = ops.max(ops.div(ops.cast(second_tensor, ms.float32), ops.cast(first_tensor, ms.float32)))[0].item() ret = self.base_handler.get_endless_norm(first_tensor, second_tensor, abs_tol) self.assertEqual(ret, target) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/handler/test_ms_check_handler.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/handler/test_ms_check_handler.py index 58c0a7b46ad7ca5c05a157733d57dd8828ced24d..d983073794b31800c0deb56123635a8e3fb785c7 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/handler/test_ms_check_handler.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/handler/test_ms_check_handler.py @@ -24,7 +24,7 @@ from msprobe.mindspore.common.log import logger from msprobe.mindspore.free_benchmark.common.config import Config from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams from msprobe.mindspore.free_benchmark.handler.check_handler import CheckHandler -from msprobe.mindspore.runtime import Runtime +from msprobe.core.common.runtime import Runtime def where(*args): diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/perturbation/test_ms_base_perturbation.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/perturbation/test_ms_base_perturbation.py index 3469e809d3fb27f9e366d128cfa10d68c776e391..41f7ea6db9e55b4886fbf2d0b21b5c2abb2e1551 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/perturbation/test_ms_base_perturbation.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/perturbation/test_ms_base_perturbation.py @@ -33,7 +33,7 @@ class TestBasePerturbation(unittest.TestCase): self.assertFalse(TestBasePerturbation.base_pert.is_fuzzed) self.assertIsNone(TestBasePerturbation.base_pert.perturbation_value) - @patch("msprobe.mindspore.service.Service.should_execute_hook", return_value=False) + @patch("msprobe.core.hook_manager.BaseHookManager._should_execute_hook", return_value=False) def test_get_fuzzed_result(self, _): params = HandlerParams() params.args = [Tensor([1.0], dtype=ms.float32), Tensor([5.0], dtype=ms.float32)] diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/perturbation/test_ms_improve_precision.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/perturbation/test_ms_improve_precision.py index e200bb40868fab8a9618047244830aa8a74cec27..84f9766a053c9f0b1bed69f49a1a044300d9e215 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/perturbation/test_ms_improve_precision.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/perturbation/test_ms_improve_precision.py @@ -68,7 +68,7 @@ class TestImprovePrecisionPerturbation(unittest.TestCase): self.assertEqual(ret.dtype, target.dtype) self.assertFalse(self.improve_precision_pert.is_fuzzed) - @patch("msprobe.mindspore.service.Service.should_execute_hook", return_value=False) + @patch("msprobe.core.hook_manager.BaseHookManager._should_execute_hook", return_value=False) @patch.object(logger, "warning") def test_handle(self, mock_warning, _): self.improve_precision_pert.is_fuzzed = False diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/perturbation/test_ms_perturbation_factory.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/perturbation/test_ms_perturbation_factory.py index 858e664bbaddb3506bf53ea067eeca1c9706b43b..a4458912149fc8600d32a542c398a335be5d636d 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/perturbation/test_ms_perturbation_factory.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/perturbation/test_ms_perturbation_factory.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +14,9 @@ # limitations under the License. import unittest +from unittest.mock import patch +from msprobe.mindspore.common.log import logger from msprobe.mindspore.free_benchmark.perturbation.perturbation_factory import PerturbationFactory from msprobe.mindspore.free_benchmark.common.config import Config from msprobe.mindspore.common.const import FreeBenchmarkConst @@ -27,14 +29,14 @@ from msprobe.mindspore.free_benchmark.perturbation.exchange_value import Exchang class TestPerturbationFactory(unittest.TestCase): - def test_create(self): + @patch.object(logger, "error") + def test_create(self, mock_logger_error): api_name = "Functional.add.0" Config.pert_type = "UNKNOWN" - with self.assertRaises(Exception) as context: + with self.assertRaises(ValueError): PerturbationFactory.create(api_name) - self.assertEqual(str(context.exception), - "UNKNOWN is a invalid perturbation type") + mock_logger_error.assert_called_with("UNKNOWN is a invalid perturbation type") Config.pert_type = FreeBenchmarkConst.EXCHANGE_VALUE pert = PerturbationFactory.create(api_name) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_api_pynative_self_check.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_api_pynative_self_check.py index e589dd4d58715d74644047f8c7e7a6ce79ccf225..d67e9073e45cd0312706ef616f55118265b15d20 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_api_pynative_self_check.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_api_pynative_self_check.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import threading import os from unittest import TestCase from unittest.mock import patch @@ -23,18 +24,24 @@ from mindspore import Tensor, mint, ops from msprobe.core.common.const import Const from msprobe.mindspore.common.const import FreeBenchmarkConst from msprobe.mindspore.common.log import logger -from msprobe.mindspore.dump.hook_cell.api_registry import api_register -from msprobe.mindspore.free_benchmark.api_pynative_self_check import (ApiPyNativeSelfCheck, check_all_tensor, - check_self, data_pre_deal, - deal_fuzzed_and_original_result, - get_module, get_supported_ops, - get_target_arg_index, need_wrapper_func) +from msprobe.mindspore.free_benchmark.api_pynative_self_check import ( + ApiPyNativeSelfCheck, + check_all_tensor, + check_self, + data_pre_deal, + deal_fuzzed_and_original_result, + get_module, + get_supported_ops, + get_target_arg_index, + need_wrapper_func, + _api_register +) from msprobe.mindspore.free_benchmark.common.config import Config from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams from msprobe.mindspore.free_benchmark.common.utils import Tools from msprobe.mindspore.free_benchmark.handler.check_handler import CheckHandler from msprobe.mindspore.free_benchmark.handler.fix_handler import FixHandler -from msprobe.mindspore.runtime import Runtime +from msprobe.core.common.runtime import Runtime class DebuggerConfig: @@ -83,31 +90,34 @@ class TestApiPyNativeSelfCheck(TestCase): self.assertEqual(self_checker.ori_func, target_ori_func) def test_handle(self): - with patch.object(api_register, "initialize_hook") as mock_init_hook, \ - patch.object(api_register, "api_set_hook_func") as mock_set_hook: + with patch.object(_api_register, "initialize_hook") as mock_init_hook, \ + patch.object(_api_register, "register_all_api") as mock_set_hook: self.checker.handle() mock_init_hook.assert_called_with(self.checker.build_hook) mock_set_hook.assert_called_once() def test_build_hook(self): - _, forward_hook, backward_hook, _ = self.checker.build_hook("Functional.add.") + hook_set = self.checker.build_hook("Functional.add.") cell = Cell() + tid = threading.get_ident() + cell.msprobe_input_kwargs = {tid: {}} with patch("msprobe.mindspore.free_benchmark.api_pynative_self_check.need_wrapper_func", return_value=False): - self.assertIsNone(forward_hook(cell, "input", "output")) + self.assertIsNone(hook_set.forward_hook(cell, "input", "output")) cell = Cell() + cell.msprobe_input_kwargs = {tid: {}} self.checker.api_list = ["mindspore.ops.add"] self.checker.ori_func["mindspore.ops.add"] = "add" with patch("msprobe.mindspore.free_benchmark.api_pynative_self_check.need_wrapper_func", return_value=True), \ patch("msprobe.mindspore.free_benchmark.api_pynative_self_check.check_self", return_value="ret") as mock_check: - ret = forward_hook(cell, ("input",), ("output",)) + ret = hook_set.forward_hook(cell, ("input",), ("output",)) self.assertEqual(ret, "ret") mock_check.assert_called_with("Functional.add.0", ("output",), "add", "input") - self.assertIsNone(backward_hook("cell", "grad_input", "grad_output")) + self.assertIsNone(hook_set.backward_hook("cell", "grad_input", "grad_output")) def test_store_original_func(self): self.checker.api_list = ["mindspore.ops.add"] @@ -156,8 +166,8 @@ class TestApiPyNativeSelfCheck(TestCase): mock_warning.reset_mock() Config.stage = Const.FORWARD with patch.object(logger, "info") as mock_info, \ - patch.object(api_register, "api_set_ori_func") as mock_set_ori, \ - patch.object(api_register, "api_set_hook_func") as mock_set_hook, \ + patch.object(_api_register, "restore_all_api") as mock_set_ori, \ + patch.object(_api_register, "register_all_api") as mock_set_hook, \ patch("msprobe.mindspore.free_benchmark.api_pynative_self_check.deal_fuzzed_and_original_result", return_value="ret"): args = (1.0, 1.0) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_self_check_tool_factory.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_self_check_tool_factory.py index fa68b8896c26d4156833c54d2b2bf5b443164e8f..4f3ddd45b5a05162c60abe831967dd449f3f5ae6 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_self_check_tool_factory.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_self_check_tool_factory.py @@ -1,7 +1,6 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,11 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" import os import unittest +from unittest.mock import patch +from msprobe.core.common.log import logger from msprobe.mindspore.free_benchmark.self_check_tool_factory import SelfCheckToolFactory from msprobe.mindspore.free_benchmark.api_pynative_self_check import ApiPyNativeSelfCheck from msprobe.mindspore.debugger.debugger_config import DebuggerConfig @@ -28,7 +28,8 @@ from msprobe.core.common.const import Const class TestSelfCheckToolFactory(unittest.TestCase): - def test_create(self): + @patch.object(logger, "error") + def test_create(self, mock_logger_error): common_config = CommonConfig({}) common_config.task = Const.FREE_BENCHMARK common_config.dump_path = os.path.dirname(os.path.realpath(__file__)) @@ -36,16 +37,16 @@ class TestSelfCheckToolFactory(unittest.TestCase): config = DebuggerConfig(common_config, task_config) config.level = "UNKNOWN" - with self.assertRaises(Exception) as context: + with self.assertRaises(ValueError): SelfCheckToolFactory.create(config) - self.assertEqual(str(context.exception), "UNKNOWN is not supported.") + mock_logger_error.assert_called_with("UNKNOWN is not supported.") + mock_logger_error.reset_mock() config.level = MsConst.API config.execution_mode = MsConst.GRAPH_KBYK_MODE - with self.assertRaises(Exception) as context: + with self.assertRaises(ValueError): SelfCheckToolFactory.create(config) - self.assertEqual(str(context.exception), - f"Task free_benchmark is not supported in this mode: {MsConst.GRAPH_KBYK_MODE}.") + mock_logger_error.assert_called_with(f"Task free_benchmark is not supported in this mode: {MsConst.GRAPH_KBYK_MODE}.") config.execution_mode = MsConst.PYNATIVE_MODE tool = SelfCheckToolFactory.create(config) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/grad_probe/test_grad_analyzer.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/grad_probe/test_grad_analyzer.py index 802769d9005916c8723d436349d13ca7f557a00a..af8f6b0477f766507db55dbe345f2d802415dc14 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/grad_probe/test_grad_analyzer.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/grad_probe/test_grad_analyzer.py @@ -1,6 +1,7 @@ import os import shutil import json +import time import numpy as np import mindspore as ms from unittest import TestCase, mock @@ -15,7 +16,8 @@ class TestGradAnalyzer(TestCase): @classmethod def setUpClass(cls): cls.output_path = "./test_output" - cls.dump_dir = f"{cls.output_path}/rank0/Dump" + cls.time_stamp = str(int(time.time())) + cls.dump_dir = f"{cls.output_path}/rank0/Dump{cls.time_stamp}" cls.save_dir = f"{cls.output_path}/rank0" os.makedirs(cls.dump_dir, exist_ok=True) @@ -31,7 +33,8 @@ class TestGradAnalyzer(TestCase): 'get_context.side_effect': lambda x: { GradConst.OUTPUT_PATH: self.output_path, GradConst.LEVEL: GradConst.LEVEL2, - GradConst.BOUNDS: [-0.1, 0.0, 0.1] + GradConst.BOUNDS: [-0.1, 0.0, 0.1], + GradConst.TIME_STAMP: self.time_stamp, }[x] })) # Clear dump directory before each test diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/grad_probe/test_ms_grad_monitor.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/grad_probe/test_ms_grad_monitor.py deleted file mode 100644 index ae24457a444bfdddc796802126150577280d7e62..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/grad_probe/test_ms_grad_monitor.py +++ /dev/null @@ -1,182 +0,0 @@ -# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import hashlib -import json -import os -import shutil -from unittest import TestCase -from unittest.mock import patch - -import numpy as np -import mindspore -from mindspore import nn, Tensor -from mindspore.nn import SGD - -from msprobe.core.common.file_utils import FileOpen -from msprobe.core.grad_probe.constant import GradConst -from msprobe.mindspore import PrecisionDebugger -from msprobe.mindspore.grad_probe.global_context import grad_context - - -file_path = os.path.abspath(__file__) -directory = os.path.dirname(file_path) -config_json_path = os.path.join(directory, "config.json") - - -def main(): - PrecisionDebugger._instance = None - PrecisionDebugger.initialized = False - grad_context._setting[GradConst.CURRENT_STEP] = 0 - with patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions"): - debugger = PrecisionDebugger(config_json_path) - - class SimpleNet(nn.Cell): - def __init__(self): - super().__init__() - self.my_dense = nn.Dense(16, 5) - - def construct(self, x): - x = self.flatten(x) - logits = self.my_dense(x) - return logits - model = SimpleNet() - optimizer = SGD(model.trainable_params(), learning_rate=0.001) - - debugger.monitor(optimizer) - - fix_gradient = tuple([Tensor(np.arange(5*16).reshape((5, 16)), dtype=mindspore.float32), - Tensor(np.arange(5).reshape(5), dtype=mindspore.float32)]) - - steps = 10 - - for _ in range(steps): - optimizer(fix_gradient) - - -def save_dict_as_json(data, json_file_path): - with FileOpen(json_file_path, 'w') as f: - json.dump(data, f, ensure_ascii=False, indent=4) - print(f"字典已保存为json文件: {json_file_path}") - - -def get_hash(file_path): - with FileOpen(file_path, 'rb') as file: - hash_object = hashlib.md5() - for chunk in iter(lambda: file.read(4096), b""): - hash_object.update(chunk) - return hash_object.hexdigest() - - -class TestMsGradientMonitor(TestCase): - def test_gradient_monitor_L2(self): - gradient_output_path = os.path.join(directory, "gradient_output") - if os.path.isfile(config_json_path): - os.remove(config_json_path) - if os.path.isdir(gradient_output_path): - shutil.rmtree(gradient_output_path) - config_dict = { - "task": "grad_probe", - "dump_path": gradient_output_path, - "rank": [], - "step": [1], - "grad_probe": { - "grad_level": "L2", - "param_list": [] - } - } - save_dict_as_json(config_dict, config_json_path) - - main() - - my_dense_bias_path = os.path.join(gradient_output_path, "rank0", "step1", "my_dense.bias.npy") - self.assertTrue(os.path.isfile(my_dense_bias_path), "bias npy file not found") - my_dense_bias_real = np.load(my_dense_bias_path) - my_dense_bias_target = np.arange(5).reshape(5) > 0 - - self.assertTrue((my_dense_bias_real == my_dense_bias_target).all(), "bias ndarray not same as target") - - my_dense_weight_path = os.path.join(gradient_output_path, "rank0", "step1", "my_dense.weight.npy") - self.assertTrue(os.path.isfile(my_dense_weight_path), "weight npy file not found") - my_dense_weight_real = np.load(my_dense_weight_path) - my_dense_weight_target = np.arange(5*16).reshape((5, 16)) > 0 - - self.assertTrue((my_dense_weight_real == my_dense_weight_target).all(), "weight ndarray not same as target") - - real_md5_value = get_hash(os.path.join(gradient_output_path, "rank0", "grad_summary_1.csv")) - target_md5_value = "d5e71f1aa37d48ef0ca0a75932597a29" - self.assertEqual(real_md5_value, target_md5_value, "hash value of grad_summary_1.csv is not same as target") - - def test_gradient_monitor_L1(self): - gradient_output_path = os.path.join(directory, "gradient_output") - if os.path.isfile(config_json_path): - os.remove(config_json_path) - if os.path.isdir(gradient_output_path): - shutil.rmtree(gradient_output_path) - config_dict = { - "task": "grad_probe", - "dump_path": gradient_output_path, - "rank": [], - "step": [1], - "grad_probe": { - "grad_level": "L1", - "param_list": [] - } - } - save_dict_as_json(config_dict, config_json_path) - - main() - - my_dense_bias_path = os.path.join(gradient_output_path, "rank0", "step1", "my_dense.bias.npy") - self.assertTrue(os.path.isfile(my_dense_bias_path), "bias npy file not found") - my_dense_bias_real = np.load(my_dense_bias_path) - my_dense_bias_target = np.arange(5).reshape(5) > 0 - - self.assertTrue((my_dense_bias_real == my_dense_bias_target).all(), "bias ndarray not same as target") - - my_dense_weight_path = os.path.join(gradient_output_path, "rank0", "step1", "my_dense.weight.npy") - self.assertTrue(os.path.isfile(my_dense_weight_path), "weight npy file not found") - my_dense_weight_real = np.load(my_dense_weight_path) - my_dense_weight_target = np.arange(5*16).reshape((5, 16)) > 0 - - self.assertTrue((my_dense_weight_real == my_dense_weight_target).all(), "weight ndarray not same as target") - - real_md5_value = get_hash(os.path.join(gradient_output_path, "rank0", "grad_summary_1.csv")) - target_md5_value = "a4ad300992cb10965fbc12c2ee19dd37" - self.assertEqual(real_md5_value, target_md5_value, "hash value of grad_summary_1.csv is not same as target") - - def test_gradient_monitor_L0(self): - gradient_output_path = os.path.join(directory, "gradient_output") - if os.path.isfile(config_json_path): - os.remove(config_json_path) - if os.path.isdir(gradient_output_path): - shutil.rmtree(gradient_output_path) - config_dict = { - "task": "grad_probe", - "dump_path": gradient_output_path, - "rank": [], - "step": [1], - "grad_probe": { - "grad_level": "L0", - "param_list": [] - } - } - save_dict_as_json(config_dict, config_json_path) - - main() - - real_md5_value = get_hash(os.path.join(gradient_output_path, "rank0", "grad_summary_1.csv")) - target_md5_value = "62e137a119c0d1a44623f10049c3f80d" - self.assertEqual(real_md5_value, target_md5_value, "hash value of grad_summary_1.csv is not same as target") diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/hook_module/test_ms_hook_manager.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/hook_module/test_ms_hook_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..2426f5261960ea03cdd22bae3761b5051448fe89 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/hook_module/test_ms_hook_manager.py @@ -0,0 +1,141 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading + +import unittest +from unittest.mock import MagicMock, patch + +from msprobe.core.common.const import Const +from msprobe.core.hook_manager import HookSet, BaseHookManager +from msprobe.mindspore.dump.hook_cell.ms_hook_manager import MindsporeHookManager + + +class TestMindsporeHookManager(unittest.TestCase): + def setUp(self): + self.mock_data_collector = MagicMock() + self.mock_config = MagicMock() + self.mock_config.data_mode = ["all"] + self.mock_config.task = "statistics" + self.mock_config.level = Const.LEVEL_L1 + self.manager = MindsporeHookManager( + self.mock_data_collector, + self.mock_config + ) + BaseHookManager.inner_switch[threading.get_ident()] = False + + def test_properties(self): + self.assertIsNone(self.manager._is_recompute) + + with patch('msprobe.mindspore.dump.hook_cell.ms_hook_manager._no_grad') as mock_no_grad: + ctx = self.manager._no_grad_context() + mock_no_grad.assert_called_once() + + def test_add_count(self): + with patch('msprobe.mindspore.dump.hook_cell.ms_hook_manager.HOOKCell.add_cell_count') as mock_add: + self.manager._add_count("test_module") + mock_add.assert_called_once_with("test_module") + + def test_process_kwargs_and_output(self): + tid = threading.get_ident() + mock_module = MagicMock() + mock_module.msprobe_input_kwargs = {tid: {"kw1": "v1"}} + + kwargs, output = self.manager._process_kwargs_and_output( + mock_module, + tid, + Const.API, + "output_value", + "ignored" + ) + self.assertEqual(kwargs, {"kw1": "v1"}) + self.assertEqual(output, "output_value") + + with patch('msprobe.mindspore.dump.hook_cell.ms_hook_manager.has_kwargs_in_forward_hook', return_value=True): + kwargs, output = self.manager._process_kwargs_and_output( + mock_module, + tid, + Const.MODULE, + "kwargs_value", + "output_value" + ) + self.assertEqual(kwargs, "kwargs_value") + self.assertEqual(output, "output_value") + + def test_build_hook(self): + hook_set = self.manager.build_hook(Const.API, "test_api") + self.assertIsInstance(hook_set, HookSet) + self.assertEqual(hook_set.forward_pre_hook.__name__, "forward_pre_hook") + + hook_set = self.manager.build_hook(Const.MODULE, "test_module") + self.assertEqual(hook_set.forward_hook.__name__, "forward_hook") + self.assertEqual(hook_set.backward_pre_hook.__name__, "backward_pre_hook") + self.assertEqual(hook_set.backward_hook.__name__, "backward_hook") + + def test_need_exchange(self): + mock_module = MagicMock() + del mock_module.has_pre_hook_called + self.assertFalse(self.manager._need_exchange(mock_module)) + + mock_module.has_pre_hook_called = False + self.assertFalse(self.manager._need_exchange(mock_module)) + + mock_module.has_pre_hook_called = True + self.assertTrue(self.manager._need_exchange(mock_module)) + + def test_get_params_dict(self): + mock_module = MagicMock() + + self.mock_config.task = Const.STRUCTURE + params_dict = self.manager._get_params_dict(mock_module) + self.assertEqual(params_dict, {}) + + self.mock_config.task = "statistics" + mock_params = { + "test_module.weight": "w1", + "test_module.bias": "b1" + } + mock_module.parameters_dict.return_value = mock_params + params_dict = self.manager._get_params_dict(mock_module) + mock_module.parameters_dict.assert_called_once_with(recurse=False) + self.assertEqual(params_dict, {"weight": "w1", "bias": "b1"}) + + def test_build_backward_pre_hook(self): + hook_fn = self.manager._build_backward_pre_hook(Const.MODULE, "test_module_backward") + + mock_module = MagicMock() + mock_grad_input = ("grad1", "grad2") + + with patch.object(self.manager, '_should_execute_hook', return_value=False): + hook_fn(mock_module, mock_grad_input) + self.mock_data_collector.backward_input_data_collect.assert_not_called() + + self.mock_config.level = Const.LEVEL_L2 + with patch.object(self.manager, '_should_execute_hook', return_value=True): + hook_fn(mock_module, mock_grad_input) + + self.mock_data_collector.update_api_or_module_name.assert_called_with("test_module_backward") + self.mock_data_collector.backward_input_data_collect.assert_called_once() + + call_args = self.mock_data_collector.backward_input_data_collect.call_args[0] + module_input = call_args[3] + self.assertEqual(module_input.grad_input, mock_grad_input) + + self.assertFalse(BaseHookManager.inner_switch[threading.get_ident()]) + + self.mock_config.level = Const.LEVEL_L1 + with patch.object(self.manager, '_should_execute_hook', return_value=True): + hook_fn(mock_module, mock_grad_input) + self.mock_data_collector.backward_input_data_collect.assert_called_once() diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/config/test_config.json b/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/config/test_config.json new file mode 100644 index 0000000000000000000000000000000000000000..d8bcd303f5da0bf7d68fb4bcc2048dbd500efd83 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/config/test_config.json @@ -0,0 +1,12 @@ +{ + "start_step": 0, + "collect_times": 2, + "step_interval": 1, + "targets": {"layer1": {}, "layer2": {}}, + "format": "csv", + "ops": ["max", "min", "mean"], + "xy_distribution": true, + "mv_distribution": true, + "wg_distribution": true, + "param_distribution": true +} \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_common_func.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_common_func.py new file mode 100644 index 0000000000000000000000000000000000000000..8ecc7d99b781973682e359f5eb02671a4903c04a --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_common_func.py @@ -0,0 +1,118 @@ +import pytest +from unittest.mock import patch, MagicMock +from mindspore import nn, context +from mindspore.common.initializer import Normal +import mindspore as ms + +from msprobe.mindspore.monitor.common_func import ( + is_valid_instance, + get_submodules, + get_parameters, + get_rank, + comm_is_initialized, +) + +TORCH_AVAILABLE = False +try: + import torch + import torch.nn as torch_nn + TORCH_AVAILABLE = True +except ImportError: + TORCH_AVAILABLE = False + + +class TestModelUtils: + @classmethod + def setup_class(cls): + """Setup once for all tests in this class""" + cls.ms_model = MSModel() + if TORCH_AVAILABLE: + cls.torch_model = TorchModel() + + @classmethod + def teardown_class(cls): + """Cleanup after all tests in this class""" + pass + + + def test_is_valid_instance_if_model_is_cell_or_module_then_return_true(self): + with patch('msprobe.mindspore.monitor.common_func.is_mindtorch') as mock_is_mindtorch: + if TORCH_AVAILABLE: + mock_is_mindtorch.return_value = True + assert is_valid_instance(self.torch_model) + mock_is_mindtorch.return_value = False + assert is_valid_instance(self.ms_model) + + def test_is_valid_instance_if_input_is_string_then_return_false(self): + assert not is_valid_instance("not a model") + + def test_is_valid_instance_if_input_is_number_then_return_false(self): + assert not is_valid_instance(123) + + def test_get_submodules_if_model_is_valid_then_return_non_empty_dict(self): + with patch('msprobe.mindspore.monitor.common_func.is_mindtorch') as mock_is_mindtorch: + mock_is_mindtorch.return_value = True + if TORCH_AVAILABLE: + submodules = dict(get_submodules(self.torch_model)) + assert len(submodules) > 0 + assert any(name == 'conv1' for name in submodules) + + mock_is_mindtorch.return_value = False + submodules = dict(get_submodules(self.ms_model)) + assert len(submodules) > 0 + assert any(name.endswith('conv1') for name in submodules) + + + def test_get_submodules_if_model_is_invalid_then_return_empty_dict(self): + assert get_submodules("invalid") == {} + + def test_get_parameters_if_model_is_valid_then_return_non_empty_dict(self): + with patch('msprobe.mindspore.monitor.common_func.is_mindtorch') as mock_is_mindtorch: + mock_is_mindtorch.return_value = True + if TORCH_AVAILABLE: + params = dict(get_parameters(self.torch_model)) + assert any(name == 'conv1.weight' for name in params) + mock_is_mindtorch.return_value = False + params = dict(get_parameters(self.ms_model)) + assert any('conv1.weight' in name for name in params) + + + def test_get_parameters_if_model_is_invalid_then_return_empty_dict(self): + assert get_parameters(123) == {} + + def test_get_rank_if_comm_initialized_then_return_integer(self): + rank = get_rank() + assert isinstance(rank, int) + assert rank >= 0 + + def test_comm_is_initialized_when_called_then_return_boolean(self): + assert isinstance(comm_is_initialized(), bool) + + +# Test models +class MSModel(nn.Cell): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 64, 3, has_bias=True, weight_init=Normal(0.02)) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU() + + def construct(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + return x + +if TORCH_AVAILABLE: + class TorchModel(torch_nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch_nn.Conv2d(3, 64, 3) + self.bn1 = torch_nn.BatchNorm2d(64) + self.relu = torch_nn.ReLU() + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + return x \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_mon_module_hook.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_mon_module_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..87d630c3620ad242f43d8ef629f44698ce7a5bbd --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_mon_module_hook.py @@ -0,0 +1,399 @@ +import pytest +import os +import json +import numpy as np +import mock +from datetime import datetime +import unittest +import inspect +from unittest.mock import MagicMock, patch, mock_open +from collections import defaultdict + +import mindspore as ms +from mindspore import nn, ops, Tensor, Parameter +from msprobe.core.common.const import MonitorConst, Const +from msprobe.mindspore.monitor.module_hook import ( + TrainerMon, + ModuleHookContext, + OptimizerContext, + GradContext, + CommunicationContext +) + +class MyMomentum(nn.Optimizer): + def __init__(self, params, learning_rate, momentum=0.9): + super(MyMomentum, self).__init__(learning_rate, params) + self.moments = self.parameters.clone(prefix="exp_avg", init="zeros") + self.momentum = momentum + self.opt = ops.ApplyMomentum() + + def construct(self, gradients): + params = self.parameters + lr = self.get_lr() + gradients = self.flatten_gradients(gradients) + gradients = self.decay_weight(gradients) + gradients = self.gradients_centralization(gradients) + gradients = self.scale_grad(gradients) + + success = None + for param, mom, grad in zip(params, self.moments, gradients): + success = self.opt(param, mom, lr, grad, self.momentum) + return success + + +class TestContext(unittest.TestCase): + def test_communication_context(self): + cc_ctx = CommunicationContext() + cc_ctx.reset() + cc_ctx.data = {'tag1': {'min': [1, 2, 3], 'max': [10, 11, 12]}, + 'tag2': {'min': [16, 17, 18], 'max': [22, 23, 24]}} + cc_ctx.aggregate() + expected_aggregated_data = {'tag1': {'max': 12, 'min': 1}, 'tag2': {'max': 24, 'min': 16}} + self.assertEqual(cc_ctx.data, expected_aggregated_data) + + def test_grad_context(self): + grad_ctx = GradContext() + grad_ctx.reset() + self.assertEqual(grad_ctx.pre, {}) + self.assertEqual(grad_ctx.post, {}) + + def test_module_hook_context_initialization(self): + """测试 ModuleHookContext 初始化状态""" + ctx = ModuleHookContext(module_name="test_module") + + # 验证基本属性 + self.assertEqual(ctx.step, 0) + self.assertEqual(ctx.micro_step, 0) + self.assertEqual(ctx.module_name, "test_module") + self.assertEqual(ctx.stack, "") + + # 验证数据结构类型 + self.assertIsInstance(ctx.actv, defaultdict) + self.assertEqual(len(ctx.actv), 0) # 应为空字典 + + self.assertIsInstance(ctx.actvgrad, list) + self.assertEqual(len(ctx.actvgrad), 0) # 应为空列表 + + self.assertIsInstance(ctx.struct, dict) + self.assertEqual(len(ctx.struct), 0) # 应为空字典 + + def test_module_hook_context_reset(self): + """测试 ModuleHookContext 重置功能""" + ctx = ModuleHookContext(module_name="test") + + # 填充测试数据 + ctx.step = 5 + ctx.micro_step = 3 + ctx.actv['layer1']['weight'] = [1.2, 3.4] + ctx.actvgrad.append('grad_data') + ctx.stack = "test_stack" + ctx.struct['meta'] = {'size': 10} + + # 执行重置 + ctx.reset() + + # 验证重置后状态 + self.assertEqual(ctx.step, 5) # 不应重置 + self.assertEqual(ctx.micro_step, 3) # 不应重置 + self.assertEqual(len(ctx.actv), 0) # 字典应清空 + self.assertEqual(len(ctx.actvgrad), 0) # 列表应清空 + self.assertEqual(ctx.stack, "test_stack") # 不应重置 + self.assertEqual(len(ctx.struct), 1) # 不应重置 + + def test_optimizer_context_initialization(self): + """测试 OptimizerContext 初始化状态""" + ctx = OptimizerContext() + + # 验证基本属性 + self.assertEqual(ctx.step, 0) + + # 验证所有字典结构均为空 + self.assertIsInstance(ctx.param_mg_direction, defaultdict) + self.assertEqual(len(ctx.param_mg_direction), 0) + + self.assertIsInstance(ctx.param_adam_update, defaultdict) + self.assertEqual(len(ctx.param_adam_update), 0) + + self.assertIsInstance(ctx.param_adam_ratio, defaultdict) + self.assertEqual(len(ctx.param_adam_ratio), 0) + + self.assertIsInstance(ctx.param_weight_grad, defaultdict) + self.assertEqual(len(ctx.param_weight_grad), 0) + + self.assertIsInstance(ctx.param_exp_avg, defaultdict) + self.assertEqual(len(ctx.param_exp_avg), 0) + + self.assertIsInstance(ctx.param_exp_avg_sq, defaultdict) + self.assertEqual(len(ctx.param_exp_avg_sq), 0) + + self.assertIsInstance(ctx.exp_avg_metric, dict) + self.assertEqual(len(ctx.exp_avg_metric), 0) + + self.assertIsInstance(ctx.exp_avg_sq_metric, dict) + self.assertEqual(len(ctx.exp_avg_sq_metric), 0) + + self.assertIsInstance(ctx.metric_dict, dict) + self.assertEqual(len(ctx.metric_dict), 0) + + self.assertIsInstance(ctx.param_metric, dict) + self.assertEqual(len(ctx.param_metric), 0) + + def test_optimizer_context_reset(self): + """测试 OptimizerContext 重置功能""" + ctx = OptimizerContext() + + # 填充测试数据 + ctx.step = 100 + ctx.param_mg_direction['weight'] = 0.5 + ctx.param_adam_update['bias'] = (0.1, 0.2) + ctx.param_adam_ratio['embed'] = 0.8 + ctx.param_weight_grad['linear'] = [-0.4, 0.6] + ctx.param_exp_avg['conv'] = [0.9] + ctx.param_exp_avg_sq['norm'] = [0.99] + ctx.exp_avg_metric['acc'] = 0.75 + ctx.exp_avg_sq_metric['loss'] = 0.25 + ctx.metric_dict['f1'] = 0.9 + ctx.param_metric['weight_metric'] = 1.0 + + # 执行重置 + ctx.reset() + + # 验证重置后状态 + self.assertEqual(ctx.step, 100) # 不应重置 + + # 所有字典/默认字典应为空 + self.assertEqual(len(ctx.param_mg_direction), 0) + self.assertEqual(len(ctx.param_adam_update), 0) + self.assertEqual(len(ctx.param_adam_ratio), 0) + self.assertEqual(len(ctx.param_weight_grad), 0) + self.assertEqual(len(ctx.param_exp_avg), 0) + self.assertEqual(len(ctx.param_exp_avg_sq), 0) + self.assertEqual(len(ctx.exp_avg_metric), 0) + self.assertEqual(len(ctx.exp_avg_sq_metric), 0) + self.assertEqual(len(ctx.metric_dict), 0) + self.assertEqual(len(ctx.param_metric), 0) + + +class TestTrainerMonWithRealNetwork: + @classmethod + def setup_class(cls): + """Setup once for all tests in this class""" + cls.mock_config = { + "start_step": 0, + "collect_times": 10, + "step_interval": 1, + "format": "csv", + "ops": ["norm"], + "alert": {"rules": [], "dump": False}, + "xy_distribution": True, + "mv_distribution": True, + "forward_only": True + } + cls.config_file = "test_config.json" + with open(cls.config_file, 'w') as f: + json.dump(cls.mock_config, f) + + # Setup real network components + cls.net = nn.Dense(2, 3) + cls.loss_fn = nn.MAELoss() + cls.opt = MyMomentum(cls.net.trainable_params(), 0.01) + + @classmethod + def teardown_class(cls): + """Clean up after all tests""" + if os.path.exists(cls.config_file): + os.remove(cls.config_file) + + def setup_method(self): + """Setup before each test""" + self.trainer = TrainerMon(self.config_file) + self.trainer.set_monitor(self.net, self.opt) + + def test_monitor_with_real_training_step_when_valid_then_pass(self): + + # Create test data + data = Tensor(np.random.rand(1, 10, 2), ms.float32) + label = Tensor(np.random.rand(1, 10, 3), ms.float32) + + # Define forward function + def forward_fn(data, label): + logits = self.net(data) + loss = self.loss_fn(logits, label) + return loss, logits + + # Define grad function + grad_fn = ms.value_and_grad(forward_fn, None, self.opt.parameters, has_aux=True) + + # Define training step + def train_step(data, label): + (loss, _), grads = grad_fn(data, label) + self.opt(grads) + return loss + + # Execute training step + loss = train_step(data, label) + + # Verify monitoring results + assert isinstance(loss, Tensor) + assert len(self.trainer.module_fwd_hook_context_by_module) > 0 + assert len(self.trainer.optimizer_context) > 0 + + def test_monitor_with_multiple_training_steps_when_valid_then_pass(self): + + # Create test data + data = Tensor(np.random.rand(1, 10, 2), ms.float32) + label = Tensor(np.random.rand(1, 10, 3), ms.float32) + + # Define forward function + def forward_fn(data, label): + logits = self.net(data) + loss = self.loss_fn(logits, label) + return loss, logits + + # Define grad function + grad_fn = ms.value_and_grad(forward_fn, None, self.opt.parameters, has_aux=True) + + # Define training step + def train_step(data, label): + (loss, _), grads = grad_fn(data, label) + self.opt(grads) + return loss + + # Execute multiple training steps + for step in range(3): + loss = train_step(data, label) + + # Verify monitoring results + assert isinstance(loss, Tensor) + assert len(self.trainer.module_fwd_hook_context_by_module) > 0 + assert len(self.trainer.optimizer_context) > 0 + assert self.trainer.optimizer_context[self.opt].step == step + 1 + + def test_monitor_with_parameter_updates_when_valid_then_pass(self): + # Get initial parameters + initial_params = [param.value() for param in self.net.get_parameters()] + + # Create test data + data = Tensor(np.random.rand(1, 10, 2), ms.float32) + label = Tensor(np.random.rand(1, 10, 3), ms.float32) + + # Define forward function + def forward_fn(data, label): + logits = self.net(data) + loss = self.loss_fn(logits, label) + return loss, logits + + # Define grad function + grad_fn = ms.value_and_grad(forward_fn, None, self.opt.parameters, has_aux=True) + + # Define training step + def train_step(data, label): + (loss, _), grads = grad_fn(data, label) + self.opt(grads) + return loss + + # Execute training step + loss = train_step(data, label) + + # Get updated parameters + updated_params = [param.value() for param in self.net.get_parameters()] + + # Verify parameters have changed + for init_param, updated_param in zip(initial_params, updated_params): + assert not np.array_equal(init_param.asnumpy(), updated_param.asnumpy()) + + def test_monitor_with_gradient_collection_when_valid_then_pass(self): + # Enable gradient monitoring + self.trainer.wg_distribution = True + self.monitor_mbs_grad = True + self.trainer._hook_weights() + + # Create test data + data = Tensor(np.random.rand(1, 10, 2), ms.float32) + label = Tensor(np.random.rand(1, 10, 3), ms.float32) + + # Define forward function + def forward_fn(data, label): + logits = self.net(data) + loss = self.loss_fn(logits, label) + return loss, logits + + # Define grad function + grad_fn = ms.value_and_grad(forward_fn, None, self.opt.parameters, has_aux=True) + + # Define training step + def train_step(data, label): + (loss, _), grads = grad_fn(data, label) + # Assign to main_grad + for param, grad in zip(self.opt.parameters, grads): + param.main_grad = grad + self.opt(grads) + return loss + + # Execute training step + loss = train_step(data, label) + + # Verify gradients were collected + assert len(self.trainer.grad_context.post) > 0 + + def test_monitor_with_momentum_collection_when_valid_then_pass(self): + # Enable momentum monitoring + self.trainer.mv_distribution = True + + # Create test data + data = Tensor(np.random.rand(1, 10, 2), ms.float32) + label = Tensor(np.random.rand(1, 10, 3), ms.float32) + + # Define forward function + def forward_fn(data, label): + logits = self.net(data) + loss = self.loss_fn(logits, label) + return loss, logits + + # Define grad function + grad_fn = ms.value_and_grad(forward_fn, None, self.opt.parameters, has_aux=True) + + # Define training step + def train_step(data, label): + (loss, _), grads = grad_fn(data, label) + self.opt(grads) + return loss + + # Execute training step + loss = train_step(data, label) + + # Verify momentum was collected + opt_context = self.trainer.optimizer_context[self.opt] + assert len(opt_context.exp_avg_metric) > 0 + + def test_dynamic_monitor_when_change_then_pass(self): + self.trainer.dynamic_enable = True + + # Create test data + data = Tensor(np.random.rand(1, 10, 2), ms.float32) + label = Tensor(np.random.rand(1, 10, 3), ms.float32) + + # Define forward function + def forward_fn(data, label): + logits = self.net(data) + loss = self.loss_fn(logits, label) + return loss, logits + + # Define grad function + grad_fn = ms.value_and_grad(forward_fn, None, self.opt.parameters, has_aux=True) + + # Define training step + def train_step(data, label): + (loss, _), grads = grad_fn(data, label) + self.opt(grads) + return loss + + for step in range(3): + loss = train_step(data, label) + if step == 0: + self.mock_config['start_step'] = 2 # 修改为step2 + self.mock_config["collect_times"] = 1 + self.mock_config['dynamic_on'] = True + with open(self.config_file, 'w') as f: + json.dump(self.mock_config, f) + assert len(self.trainer.module_fwd_hook_context_by_module) > 0 diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_mon_utils.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_mon_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..01005fadf563a92e9fac68dbde4363fff70410fb --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_mon_utils.py @@ -0,0 +1,212 @@ +import unittest +import os +import tempfile +import re +from datetime import datetime +from mindspore import dtype as mstype, Tensor +from msprobe.mindspore.monitor.features import FUNC_MAP +from msprobe.core.common.const import MonitorConst +from msprobe.core.common.utils import is_int +from msprobe.core.common.log import logger +from msprobe.core.common.file_utils import check_file_or_directory_path + +class TestMonitorUtils(unittest.TestCase): + def setUp(self): + # 创建临时目录用于测试 + self.temp_dir = tempfile.mkdtemp() + + # 创建符合 MonitorConst.OUTPUT_DIR_PATTERN 的测试目录 + self.valid_dir1 = os.path.join(self.temp_dir, "Dec03_21-34-40_rank0") + os.makedirs(self.valid_dir1) + self.valid_dir2 = os.path.join(self.temp_dir, "Dec04_22-35-41_rank1") + os.makedirs(self.valid_dir2) + self.invalid_dir = os.path.join(self.temp_dir, "invalid_directory") + os.makedirs(self.invalid_dir) + + def tearDown(self): + # 清理临时目录 + for root, dirs, files in os.walk(self.temp_dir, topdown=False): + for name in files: + os.remove(os.path.join(root, name)) + for name in dirs: + os.rmdir(os.path.join(root, name)) + os.rmdir(self.temp_dir) + + def test_get_summary_writer_tag_name(self): + from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name + + # 测试不带rank的情况 + result = get_summary_writer_tag_name("module1", "tag1", None) + self.assertEqual(result, "module1/tag1") + + # 测试带rank的情况 + result = get_summary_writer_tag_name("module2", "tag2", 1) + self.assertEqual(result, "module2/rank1/tag2") + + def test_step_accumulates_one(self): + from msprobe.mindspore.monitor.utils import step_accumulates_one + + class MockContext: + def __init__(self): + self.micro_step = 0 + self.step = 0 + + # 测试micro_step未达到micro_batch_number的情况 + context = MockContext() + step_accumulates_one(context, 3) + self.assertEqual(context.micro_step, 1) + self.assertEqual(context.step, 0) + + # 测试micro_step达到micro_batch_number的情况 + context.micro_step = 2 + step_accumulates_one(context, 3) + self.assertEqual(context.micro_step, 0) + self.assertEqual(context.step, 1) + + def test_is_skip_step(self): + from msprobe.mindspore.monitor.utils import is_skip_step + + # 测试step小于start_step的情况 + self.assertTrue(is_skip_step(5, 10, 1)) + + # 测试step等于start_step的情况 + self.assertFalse(is_skip_step(10, 10, 1)) + + # 测试step大于start_step但不满足interval的情况 + self.assertTrue(is_skip_step(11, 10, 2)) + + # 测试step大于start_step且满足interval的情况 + self.assertFalse(is_skip_step(12, 10, 2)) + + # 测试has_collect_times大于等于collect_times的情况 + self.assertTrue(is_skip_step(12, 10, 2, has_collect_times=5, collect_times=5)) + + def test_validate_ops(self): + from msprobe.core.monitor.utils import validate_ops + + # 测试输入不是list的情况 + with self.assertRaises(TypeError): + validate_ops("not_a_list") + + # 测试包含不支持op的情况 + ops = ["mean", "unsupported_op"] + result = validate_ops(ops) + self.assertIn("mean", result) + self.assertNotIn("unsupported_op", result) + + # 测试空列表情况,应该返回默认op + result = validate_ops([]) + self.assertEqual(len(result), 3) # 默认op + shape + dtype + + # 测试shape和dtype自动添加 + result = validate_ops(["mean"]) + self.assertIn("mean", result) + self.assertIn("shape", result) + self.assertIn("dtype", result) + + def test_validate_ranks(self): + from msprobe.core.monitor.utils import validate_ranks + + # 测试输入不是list的情况 + with self.assertRaises(TypeError): + validate_ranks("not_a_list") + + # 测试包含非int元素的情况 + with self.assertRaises(TypeError): + validate_ranks([1, "not_an_int", 2]) + + # 测试正常情况 + try: + validate_ranks([1, 2, 3]) + except Exception as e: + self.fail(f"validate_ranks raised unexpected exception: {e}") + + def test_validate_targets(self): + from msprobe.core.monitor.utils import validate_targets + + # 测试输入不是dict的情况 + with self.assertRaises(TypeError): + validate_targets("not_a_dict") + + # 测试key不是str的情况 + with self.assertRaises(TypeError): + validate_targets({1: {"input": "tensor"}}) + + # 测试value不是dict的情况 + with self.assertRaises(TypeError): + validate_targets({"module1": "not_a_dict"}) + + # 测试正常情况 + try: + validate_targets({"module1": {"input": "tensor"}, "module2": {"output": "tensor"}}) + except Exception as e: + self.fail(f"validate_targets raised unexpected exception: {e}") + + def test_validate_config(self): + from msprobe.core.monitor.utils import validate_config + + # 测试基本配置验证 + config = { + "ops": ["mean", "max"], + "eps": 1e-6, + "module_ranks": [0, 1], + "targets": {"module1": {"input": "tensor"}}, + "print_struct": True, + "ur_distribution": False, + "xy_distribution": True, + "wg_distribution": False, + "mg_distribution": True, + "param_distribution": False, + "cc_distribution": { + "enable": True, + "cc_codeline": ["line1", "line2"], + "cc_pre_hook": False, + "cc_log_only": True + }, + "alert": { + }, + "step_count_per_record": 10, + "start_step": 0, + "step_interval": 1, + "collect_times": 100, + "monitor_mbs_grad": True, + "dynamic_on": False + } + + try: + validate_config(config) + except Exception as e: + self.fail(f"validate_config raised unexpected exception: {e}") + + # 测试无效eps类型 + invalid_config = config.copy() + invalid_config["eps"] = "not_a_float" + with self.assertRaises(TypeError): + validate_config(invalid_config) + + def test_time_str2time_digit(self): + from msprobe.core.monitor.utils import time_str2time_digit + + # 测试有效时间字符串 + time_str = "Dec03_21-34-40" + result = time_str2time_digit(time_str) + self.assertIsInstance(result, datetime) + self.assertEqual(result.month, 12) + self.assertEqual(result.day, 3) + self.assertEqual(result.hour, 21) + + # 测试无效时间字符串 + invalid_time_str = "InvalidTimeString" + with self.assertRaises(RuntimeError): + time_str2time_digit(invalid_time_str) + + def test_get_target_output_dir(self): + from msprobe.core.monitor.utils import get_target_output_dir + + # 测试不带时间范围的情况 + result = get_target_output_dir(self.temp_dir, None, None) + self.assertEqual(len(result), 0) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_ms_features.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_ms_features.py new file mode 100644 index 0000000000000000000000000000000000000000..fa3d2705a773c86ad12fa8f843df5f4cf4dde855 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_ms_features.py @@ -0,0 +1,83 @@ +import unittest +from unittest.mock import patch + +from mindspore import mint, ops +from mindspore import Tensor +from mindspore import dtype as mstype + +from msprobe.mindspore.monitor.features import max_eigenvalue, cal_entropy, cal_qkt, cal_stable_rank + + +class TestMathFunctions(unittest.TestCase): + def test_max_eigenvalue(self): + """测试最大特征值计算""" + # 创建已知特征值的矩阵 + A = ops.diag(Tensor([3.0, 2.0, 1.0])) + + # 测试不同迭代次数 + eigval = max_eigenvalue(A, num_iterations=5) + self.assertAlmostEqual(eigval.item(), 3.0, delta=0.1) + + # 测试全零矩阵 + zero_matrix = ops.zeros((3, 3)) + eigval = max_eigenvalue(zero_matrix) + self.assertAlmostEqual(eigval.item(), 0.0) + + def test_cal_entropy(self): + """测试注意力熵计算""" + # 创建简单的注意力分数 + qk = Tensor([[1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0]]) + + # 无mask + entropy, softmax_max = cal_entropy(qk) + self.assertAlmostEqual(entropy, 0.4715, delta=0.1) + self.assertAlmostEqual(softmax_max, 0.7988, delta=0.1) + + # 带mask 和默认生成相同 + mask = Tensor([[1, 0, 0], + [1, 1, 0], + [1, 1, 1]], dtype=mstype.float32) + entropy, softmax_max = cal_entropy(qk, mask) + self.assertAlmostEqual(entropy, 0.4715, delta=0.1) + self.assertAlmostEqual(softmax_max, 0.7988, delta=0.1) + + @patch("msprobe.mindspore.monitor.features.logger") + def test_cal_qkt(self, mock_logger): + """测试QK^T计算""" + # 测试s,b,h,d顺序 + q = ops.randn((10, 2, 4, 8)) # [s, b, h, d] + k = ops.randn((10, 2, 4, 8)) # [s, b, h, d] + q_batch = ops.randn((2, 10, 4, 8)) # [b, s, h, d] + qkt = cal_qkt(q, k, order="s,b,h,d") + self.assertEqual(qkt.shape, (10, 10)) # [s, s] + + # 测试b,s,h,d顺序 + qkt = cal_qkt(q_batch, q_batch, order="b,s,h,d") + self.assertEqual(qkt.shape, (10, 10)) # [s, s] + + # 测试无效顺序 + cal_qkt(q, k, order="invalid_order") + mock_logger.warning.assert_called_with( + "Calculate qk tensor failed: Order unsupported.") + + def test_cal_stable_rank(self): + """测试谱半径计算""" + # 创建已知谱半径的矩阵 + A = ops.diag(Tensor([3.0, 2.0, 1.0])) + sr, eig = cal_stable_rank(A) + + # 验证Frobenius范数 + fro_norm = ops.norm(A, ord='fro') + self.assertAlmostEqual(sr, fro_norm / 3.0, delta=.5) # 最大特征值为3 + + # 测试正交矩阵 + ortho = ops.eye(5) + sr, eig = cal_stable_rank(ortho) + self.assertAlmostEqual(sr, Tensor(2.23/1), delta=.5) # F范数应为2.23 + self.assertAlmostEqual(eig, Tensor(1.0), delta=.1) # 特征值应为1 + + +if __name__ == '__main__': + unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_ms_module_hook.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_ms_module_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..ad1beb97fcac842f6b16f89aa08a7982b27f28c0 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_ms_module_hook.py @@ -0,0 +1,126 @@ +import os +import unittest +from unittest.mock import MagicMock, patch +import mindspore as ms + +from msprobe.mindspore.monitor.module_hook import TrainerMon, ModuleHookContext, OptimizerContext, GradContext + +class TestTrainerMon(unittest.TestCase): + def setUp(self): + base_dir = os.path.dirname(os.path.realpath(__file__)) + self.config_path = os.path.join(base_dir, "config/test_config.json") + self.mock_config = { + "start_step": 0, + "collect_times": 2, + "step_interval": 1, + "targets": {"layer1": {}, "layer2": {}}, + "format": "csv", + "ops": ["max", "min", "mean"], + "xy_distribution": True, + "mv_distribution": True, + "wg_distribution": True, + "param_distribution": True + } + self.trainer = TrainerMon(self.config_path) + + def test_init_given_valid_config_when_initialized_then_sets_correct_attributes(self): + self.assertEqual(self.trainer.config_file_path, self.config_path) + self.assertEqual(self.trainer.start_step, 0) + self.assertEqual(self.trainer.collect_times, 2) + self.assertTrue(self.trainer.monitoring) + + @patch('os.getenv', return_value='/custom/output') + def test_get_output_base_dir_given_env_set_when_called_then_returns_custom_dir(self, mock_getenv): + from msprobe.mindspore.monitor.module_hook import get_output_base_dir + self.assertEqual(get_output_base_dir(), '/custom/output') + + @patch('os.path.getmtime', return_value=123456) + @patch('json.load', return_value={}) + def test_dynamic_monitor_given_updated_config_when_called_then_updates_config(self, mock_load, mock_mtime): + self.trainer.dynamic_enable = True + self.trainer.config_timestamp = 0 + self.trainer.monitoring = False + optimizer = MagicMock() + self.trainer.optimizer_context[optimizer] = OptimizerContext() + self.trainer.dynamic_monitor(optimizer) + self.assertEqual(self.trainer.config_timestamp, 123456) + + def test_is_target_rank_given_rank_in_list_when_called_then_returns_true(self): + self.trainer.module_rank_list = [0, 1] + self.trainer.rank = 0 + self.assertTrue(self.trainer.is_target_rank()) + + def test_is_target_rank_given_rank_not_in_list_when_called_then_returns_false(self): + self.trainer.module_rank_list = [1, 2] + self.trainer.rank = 0 + self.assertFalse(self.trainer.is_target_rank()) + + def test_hook_optimizer_given_valid_optimizer_when_called_then_adds_hooks(self): + optimizer = MagicMock() + self.trainer.hook_optimizer(optimizer) + self.assertEqual(len(self.trainer.pre_step_hooks), 1) + self.assertEqual(len(self.trainer.post_step_hooks), 1) + + def test_write_xy_tb_given_activation_data_when_called_then_writes_metrics(self): + context = ModuleHookContext("test_module") + context.actv = {"key": ms.Tensor(1.0)} + self.trainer.module_fwd_hook_context_by_module[MagicMock()] = context + self.trainer.summary_writer.write_metrics = MagicMock() + self.trainer.write_xy_tb(1) + self.trainer.summary_writer.write_metrics.assert_called() + + def test_write_grad_tb_given_grad_data_when_called_then_writes_metrics(self): + self.trainer.grad_context.acc_metric = {"grad1": ms.Tensor(0.5)} + self.trainer.grad_context.post = {"grad2": ms.Tensor(0.8)} + self.trainer.summary_writer.write_metrics = MagicMock() + self.trainer.write_grad_tb(1) + self.trainer.summary_writer.write_metrics.assert_called() + + def test_write_mv_tb_given_mv_data_when_called_then_writes_metrics(self): + context = OptimizerContext() + context.exp_avg_metric = {"m1": ms.Tensor(0.1)} + context.exp_avg_sq_metric = {"v1": ms.Tensor(0.2)} + self.trainer.summary_writer.write_metrics = MagicMock() + self.trainer.write_mv_tb(context) + self.trainer.summary_writer.write_metrics.assert_called() + + def test_write_param_tb_given_param_data_when_called_then_writes_metrics(self): + context = OptimizerContext() + context.param_metric = {"param_pre": ms.Tensor(1.0), "param_post": ms.Tensor(2.0)} + self.trainer.summary_writer.write_metrics = MagicMock() + self.trainer.write_param_tb(context) + self.trainer.summary_writer.write_metrics.assert_called() + + +class TestModuleHookContext(unittest.TestCase): + def test_reset_clears_activation_data(self): + context = ModuleHookContext("test") + context.actv = {"data": ms.Tensor(1.0)} + context.actvgrad = [ms.Tensor(2.0)] + context.reset() + self.assertEqual(len(context.actv), 0) + self.assertEqual(len(context.actvgrad), 0) + + +class TestOptimizerContext(unittest.TestCase): + def test_reset_clears_all_metrics(self): + context = OptimizerContext() + context.param_mg_direction = {"p1": 0.5} + context.param_adam_update = {"p1": ms.Tensor(0.1)} + context.reset() + self.assertEqual(len(context.param_mg_direction), 0) + self.assertEqual(len(context.param_adam_update), 0) + + +class TestGradContext(unittest.TestCase): + def test_reset_clears_grad_data(self): + context = GradContext() + context.pre = {"g1": ms.Tensor(0.1)} + context.post = {"g2": ms.Tensor(0.2)} + context.reset() + self.assertEqual(len(context.pre), 0) + self.assertEqual(len(context.post), 0) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_opt_collect.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_opt_collect.py new file mode 100644 index 0000000000000000000000000000000000000000..df3e54f1a173a943ec04d73957f80619547ce977 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_opt_collect.py @@ -0,0 +1,225 @@ +import pytest +import numpy as np +from mindspore import Tensor, nn, ops +from unittest.mock import MagicMock, patch + +from msprobe.core.common.const import MonitorConst +# Import the classes to test +from msprobe.core.common.log import logger +from msprobe.mindspore.monitor.optimizer_collect import ( + OptimizerMon, + MixPrecisionOptimizerMon, + MegatronDistributedOptimizerMon, + MegatronChainedDistributedOptimizerMon, + MegatronChainedMixPrecisionOptimizerMon, + DeepSpeedZeroOptimizerMon, + DeepSpeedZeroOptimizerStage0Mon, + DeepSpeedZeroOptimizerStage1or2Mon, + DeepSpeedZeroOptimizerStage3Mon, + OptimizerMonFactory +) + +class TestOptimizerMon: + @classmethod + def setup_class(cls): + """Setup once for all tests in this class""" + cls.mock_monitor = MagicMock() + cls.mock_monitor.name2tag = {"test_param": {MonitorConst.POST_GRAD: "test_tag"}} + cls.mock_monitor.duplicate_param = {} + cls.mock_monitor.params_have_main_grad = False + cls.mock_monitor.fsdp_wrapped_module = False + cls.mock_monitor.mv_distribution = True + cls.mock_monitor.mg_direction = True + cls.mock_monitor.ur_distribution = True + cls.mock_monitor.update_heatmap_visualizer = {"test_param": MagicMock()} + cls.mock_monitor.ratio_heatmap_visualizer = {"test_param": MagicMock()} + + def test_fetch_grad_if_param_has_valid_grad_then_return_correct_grad_values(self): + # Setup + param = MagicMock() + expected_grad = Tensor([1.0, 2.0, 3.0]) + param.grad = expected_grad + params2name = {param: "test_param"} + optimizer = MagicMock() + mon = OptimizerMon(optimizer) + + # Execute + result = mon.fetch_grad(self.mock_monitor, params2name) + + # Verify + assert len(result) == 1 + assert (result["test_tag"] == expected_grad).all() + self.mock_monitor.register_param_call_id.assert_called_once_with("hook_optimizer", "test_tag") + + def test_fetch_grad_if_param_has_main_grad_then_return_main_grad_values(self): + # Setup + param = MagicMock() + expected_grad = Tensor(np.array([1.5, 2.5])) + param.main_grad = expected_grad + param.grad = None + params2name = {param: "test_param"} + optimizer = MagicMock() + self.mock_monitor.params_have_main_grad = True + mon = OptimizerMon(optimizer) + + # Execute + result = mon.fetch_grad(self.mock_monitor, params2name) + + # Verify + assert len(result) == 1 + assert (result["test_tag"] == expected_grad).all() + + def test_fetch_mv_if_state_complete_then_return_correct_momentum_values(self): + # Setup + param = MagicMock() + params2name = {param: "test_param"} + optimizer = MagicMock() + optimizer.state = { + param: { + "exp_avg": Tensor([0.1]), + "exp_avg_sq": Tensor([0.2]), + "step": 10 + } + } + del optimizer.chained_optimizers + del optimizer.param_to_cpu_states_map + optimizer.defaults = {'betas': (0.9, 0.999), 'eps': 1e-8} + optimizer.param_groups = [{}] + + mon = OptimizerMon(optimizer) + mon.fp16_to_fp32_param = {} + + # Execute + exp_avg, exp_avg_sq, update, ratio = mon.fetch_mv(self.mock_monitor, params2name) + + # Verify + beta1, beta2 = optimizer.defaults['betas'] + step = optimizer.state[param]['step'] + + expected_exp_avg_hat = 0.1 / (1 - beta1**step) + expected_exp_avg_sq_hat = 0.2 / (1 - beta2**step) + expected_update = expected_exp_avg_hat / (np.sqrt(expected_exp_avg_sq_hat) + optimizer.defaults['eps']) + expected_ratio = expected_exp_avg_hat / np.sqrt(expected_exp_avg_sq_hat) + + assert exp_avg["test_param"] == Tensor([0.1]) + assert exp_avg_sq["test_param"] == Tensor([0.2]) + assert update["test_param"] == Tensor([expected_update]) + assert ratio["test_param"] == Tensor([expected_ratio]) + + def test_narrow_from_flatten_if_state_not_partitioned_then_return_original_state(self): + # Setup + param = MagicMock() + flatten_state = Tensor([1.0, 2.0, 3.0]) + mon = OptimizerMon(MagicMock()) + + # Execute + result = mon.narrow_from_flatten(param, flatten_state) + + # Verify + assert (result == flatten_state).all() + +class TestMixPrecisionOptimizerMon: + @classmethod + def setup_class(cls): + cls.mock_monitor = MagicMock() + cls.mock_monitor.mv_distribution = True + cls.mock_monitor.mg_direction = True + cls.mock_monitor.ur_distribution = True + cls.mock_monitor.update_heatmap_visualizer = {'param1': MagicMock(), 'param2': MagicMock()} + cls.mock_monitor.ratio_heatmap_visualizer = {'param1': MagicMock(), 'param2': MagicMock()} + + def test_map_fp16_to_fp32_param_if_multiple_groups_then_create_correct_mappings(self): + # Setup + optimizer = MagicMock() + fp16_params = [MagicMock(), MagicMock(), MagicMock()] + fp32_params = [MagicMock(), MagicMock(), MagicMock()] + optimizer.float16_groups = [fp16_params[:2], [fp16_params[2]]] + optimizer.fp32_from_float16_groups = [fp32_params[:2], [fp32_params[2]]] + + mon = MixPrecisionOptimizerMon(optimizer) + + # Execute + mon.map_fp16_to_fp32_param(optimizer) + + # Verify + assert len(mon.fp16_to_fp32_param) == 3 + for fp16, fp32 in zip(fp16_params, fp32_params): + assert mon.fp16_to_fp32_param[fp16] == fp32 + +class TestDeepSpeedZeroOptimizerStage1or2Mon: + @classmethod + def setup_class(cls): + """Setup once for all tests in this class""" + cls.mock_monitor = MagicMock() + cls.mock_monitor.name2tag = {"test_param": {MonitorConst.POST_GRAD: "test_tag"}} + cls.mock_monitor.duplicate_param = {} + cls.mock_monitor.params_have_main_grad = False + cls.mock_monitor.mg_direction = True + cls.mock_monitor.ur_distribution = True + + def test_fetch_grad_if_param_in_partition_then_return_correct_grad_slice(self): + # Setup + optimizer = MagicMock() + param = MagicMock() + params2name = {param: "test_param"} + expected_grad = Tensor(np.array([1.0, 2.0, 3.0])) + param.main_grad = expected_grad + param.grad = None + optimizer.bit16_groups = [[param]] + optimizer.cpu_offload = False + mon = DeepSpeedZeroOptimizerStage1or2Mon(optimizer) + mon.param2group = {param: 0} + mon.get_param_index = MagicMock(return_value=1) + mon.param_not_in_partition = MagicMock(return_value=False) + mon.get_position = MagicMock(return_value=(3, 3)) # start at index 3, length 3 + + # MagicMock the averaged_gradients structure + optimizer.averaged_gradients = { + 0: [ + None, # index 0 + Tensor(np.array([1.0, 2.0, 3.0])) # index 1 + ] + } + + # Execute + result = mon.fetch_grad(self.mock_monitor, params2name) + + # Verify + assert len(result) == 1 + assert (result["test_tag"] == expected_grad).all() + +class TestOptimizerMonFactory: + @classmethod + def setup_class(cls): + cls.mock_monitor = MagicMock() + cls.mock_monitor.mv_distribution = True + cls.mock_monitor.mg_direction = True + cls.mock_monitor.ur_distribution = True + cls.mock_monitor.update_heatmap_visualizer = {'param1': MagicMock(), 'param2': MagicMock()} + cls.mock_monitor.ratio_heatmap_visualizer = {'param1': MagicMock(), 'param2': MagicMock()} + + def test_create_optimizer_mon_if_chained_optimizer_then_return_correct_monitor_type(self): + # Setup + base_optimizer = MagicMock() + base_optimizer.__class__.__name__ = "DistributedOptimizer" + optimizer = MagicMock() + optimizer.__class__.__name__ = "ChainedOptimizer" + optimizer.chained_optimizers = [base_optimizer] + + # Execute + result = OptimizerMonFactory.create_optimizer_mon(optimizer) + + # Verify + assert isinstance(result, MegatronChainedDistributedOptimizerMon) + + def test_create_optimizer_mon_if_deepspeed_stage3_then_return_stage3_monitor(self): + # Setup + optimizer = MagicMock() + optimizer.__class__.__name__ = "DeepSpeedZeroOptimizer_Stage3" + + # Execute + result = OptimizerMonFactory.create_optimizer_mon(optimizer) + + # Verify + assert isinstance(result, DeepSpeedZeroOptimizerStage3Mon) + assert result.stage == '3' diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/save/test_debugger_save_mindspore.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/save/test_debugger_save_mindspore.py new file mode 100644 index 0000000000000000000000000000000000000000..d79fe48a75c65e6156efdcdfce627d85e98c77ff --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/save/test_debugger_save_mindspore.py @@ -0,0 +1,300 @@ +import unittest +import os +import json +import mindspore +import numpy as np +import shutil +from unittest.mock import patch + +from msprobe.mindspore import PrecisionDebugger +from msprobe.core.data_dump.data_processor.mindspore_processor import MindsporeDataProcessor +from msprobe.mindspore.dump.hook_cell.api_register import get_api_register + + +current_file = __file__ +parent_dir = os.path.abspath(os.path.dirname(current_file)) +test_dir = os.path.join(parent_dir, "test_dir") + +def deep_compare(obj1, obj2, float_tolerance=1e-5): + """ + Recursively compare two objects to check if they are the same. + Supports nested dictionaries and lists. + """ + if type(obj1) != type(obj2): + return False + if isinstance(obj1, dict): + if obj1.keys() != obj2.keys(): + return False + return all(deep_compare(obj1[key], obj2[key]) for key in obj1) + if isinstance(obj1, (tuple, list)): + if len(obj1) != len(obj2): + return False + return all(deep_compare(item1, item2) for item1, item2 in zip(obj1, obj2)) + if isinstance(obj1, (int, float)): + return abs(obj1 - obj2) < float_tolerance + return obj1 == obj2 + +class TestDebuggerSave(unittest.TestCase): + @staticmethod + def write_config_json(step, async_dump, mode, dump_path, config_file_path): + task = "tensor" if mode == "tensor" else "statistics" + statistics_summary_mode = "statistics" if mode == "statistics" else "md5" + config = { + "task": task, + "dump_path": dump_path, + "rank": [], + "step": step, + "level": "debug", + "enable_dataloader": False, + "async_dump": async_dump, + "statistics": { + "summary_mode": statistics_summary_mode, + } + } + with open(config_file_path, "w", encoding="utf-8") as f: + json.dump(config, f, indent=4, ensure_ascii=False) + + @staticmethod + def read_debug_json_into_dict(debug_json_path): + with open(debug_json_path, "r", encoding="utf-8") as f: + debug_json = json.load(f) + return debug_json + + @staticmethod + def check_real_npy(npy_path, target_ms_tensor, check_values=True, rtol=1e-5, atol=1e-8): + """ + Enhanced version with optional value comparison. + + Args: + npy_path (str): Path to the .npy file + target_ms_tensor: Target mindspore tensor to compare + check_values (bool): If True, also compare array values + rtol, atol: Relative and absolute tolerances for value comparison + + Returns: + bool: True if all checks pass + """ + # Convert mindspore tensor to numpy if needed + if hasattr(target_ms_tensor, 'numpy'): + target_ms_tensor = target_ms_tensor.numpy() + # Load the npy file + try: + npy_data = np.load(npy_path) + except FileNotFoundError: + print(f"Error: The file {npy_path} does not exist.") + return False + except Exception as e: + print(f"Error loading npy file: {e}") + return False + # Check shapes + if npy_data.shape != target_ms_tensor.shape: + print(f"Shape mismatch: npy data shape is {npy_data.shape}, target tensor shape is {target_ms_tensor.shape}") + return False + # Check dtypes + if npy_data.dtype != target_ms_tensor.dtype: + print(f"Shape mismatch: npy data dtype is {npy_data.dtype}, target tensor dtype is {target_ms_tensor.dtype}") + return False + # Optionally check values + if check_values: + if not np.allclose(npy_data, target_ms_tensor, rtol=rtol, atol=atol): + print("Value mismatch: npy data and target tensor values do not match within the specified tolerances.") + return False + + return True + + def setUp(self): + if not os.path.exists(test_dir): + os.makedirs(test_dir) + PrecisionDebugger._instance = None + self.original_mindspore_special_type = MindsporeDataProcessor.mindspore_special_type + MindsporeDataProcessor.mindspore_special_type = tuple([mindspore.Tensor]) + + def tearDown(self): + if os.path.exists(test_dir): + shutil.rmtree(test_dir) + PrecisionDebugger._instance = None + MindsporeDataProcessor.mindspore_special_type = self.original_mindspore_special_type + get_api_register(True).restore_all_api() + + @patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions") + def test_save_real_tensor(self, _): + data = {"a": mindspore.Tensor([1., 2.])} + step = [] + async_dump = False + mode = "tensor" + dump_path = os.path.join(test_dir, "debug_save") + config_file_path = os.path.join(test_dir, "config.json") + + self.write_config_json(step, async_dump, mode, dump_path, config_file_path) + debugger = PrecisionDebugger(config_file_path) + PrecisionDebugger.save(data, "data_dict", save_backward=False) + PrecisionDebugger.step() + + # check npy file + npy_path = os.path.join(dump_path, "step0", "rank", "dump_tensor_data", "data_dict.0.debug.a.npy") + assert self.check_real_npy(npy_path, data["a"]) + + # check debug json + target_debug_info = { + "a": { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 2 + ], + "Max": 2.0, + "Min": 1.0, + "Mean": 1.5, + "Norm": 2.2360680103302, + "data_name": "data_dict.0.debug.a.npy" + } + } + debug_json_path = os.path.join(dump_path, "step0", "rank", "debug.json") + debug_json_dict = self.read_debug_json_into_dict(debug_json_path) + assert deep_compare(debug_json_dict["data"]["data_dict.0.debug"], target_debug_info) + + @patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions") + def test_save_md5(self, _): + data = {"a": mindspore.Tensor([1., 2.])} + step = [] + async_dump = False + mode = "md5" + dump_path = os.path.join(test_dir, "debug_save") + config_file_path = os.path.join(test_dir, "config.json") + self.write_config_json(step, async_dump, mode, dump_path, config_file_path) + debugger = PrecisionDebugger(config_file_path) + PrecisionDebugger.save(data, "data_dict", save_backward=False) + PrecisionDebugger.step() + # check debug json + target_debug_info = { + "a": { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 2 + ], + "Max": 2.0, + "Min": 1.0, + "Mean": 1.5, + "Norm": 2.2360680103302, + "md5": "2e3fa576" + } + } + debug_json_path = os.path.join(dump_path, "step0", "rank", "debug.json") + debug_json_dict = self.read_debug_json_into_dict(debug_json_path) + assert deep_compare(debug_json_dict["data"]["data_dict.0.debug"], target_debug_info) + + @patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions") + def test_save_multiple_steps(self, _): + data = {"a": mindspore.Tensor([1., 2.])} + step = [0, 1, 2] + async_dump = False + mode = "tensor" + dump_path = os.path.join(test_dir, "debug_save") + config_file_path = os.path.join(test_dir, "config.json") + self.write_config_json(step, async_dump, mode, dump_path, config_file_path) + debugger = PrecisionDebugger(config_file_path) + for _ in step: + PrecisionDebugger.save(data, "data_dict", save_backward=False) + PrecisionDebugger.step() + # check npy file + for i in step: + npy_path = os.path.join(dump_path, f"step{i}", "rank", "dump_tensor_data", "data_dict.0.debug.a.npy") + assert self.check_real_npy(npy_path, data["a"]) + # check debug json + target_debug_info = { + "a": { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 2 + ], + "Max": 2.0, + "Min": 1.0, + "Mean": 1.5, + "Norm": 2.2360680103302, + "data_name": "data_dict.0.debug.a.npy" + } + } + for i in step: + debug_json_path = os.path.join(dump_path, f"step{i}", "rank", "debug.json") + debug_json_dict = self.read_debug_json_into_dict(debug_json_path) + assert deep_compare(debug_json_dict["data"]["data_dict.0.debug"], target_debug_info) + + @patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions") + def test_save_multiple_times(self, _): + data = {"a": mindspore.Tensor([1., 2.])} + step = [] + call_times = 3 + async_dump = False + mode = "tensor" + dump_path = os.path.join(test_dir, "debug_save") + config_file_path = os.path.join(test_dir, "config.json") + self.write_config_json(step, async_dump, mode, dump_path, config_file_path) + debugger = PrecisionDebugger(config_file_path) + for _ in range(call_times): + PrecisionDebugger.save(data, "data_dict", save_backward=False) + PrecisionDebugger.step() + # check npy file + for i in range(call_times): + npy_path = os.path.join(dump_path, "step0", "rank", "dump_tensor_data", f"data_dict.{i}.debug.a.npy") + assert self.check_real_npy(npy_path, data["a"]) + # check debug json + for i in range(call_times): + target_debug_info = { + "a": { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 2 + ], + "Max": 2.0, + "Min": 1.0, + "Mean": 1.5, + "Norm": 2.2360680103302, + "data_name": f"data_dict.{i}.debug.a.npy" + } + } + debug_json_path = os.path.join(dump_path, "step0", "rank", "debug.json") + debug_json_dict = self.read_debug_json_into_dict(debug_json_path) + assert deep_compare(debug_json_dict["data"][f"data_dict.{i}.debug"], target_debug_info) + + @patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions") + def test_save_compilcated_data_structure(self, _): + x = mindspore.Tensor([1., 2.]) + complicated_structure = [{"a_key": x}] + step = [] + async_dump = False + mode = "tensor" + dump_path = os.path.join(test_dir, "debug_save") + config_file_path = os.path.join(test_dir, "config.json") + self.write_config_json(step, async_dump, mode, dump_path, config_file_path) + debugger = PrecisionDebugger(config_file_path) + PrecisionDebugger.save(complicated_structure, "complicated_structure") + PrecisionDebugger.step() + complicated_structure_info_list = [ + x, + os.path.join(dump_path, "step0", "rank", "dump_tensor_data", "complicated_structure.0.debug.0.a_key.npy"), + "complicated_structure.0.debug", + [ + { + "a_key": { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 2 + ], + "Max": 2.0, + "Min": 1.0, + "Mean": 1.5, + "Norm": 2.2360680103302, + "data_name": "complicated_structure.0.debug.0.a_key.npy" + } + } + ], + ] + debug_json_path = os.path.join(dump_path, "step0", "rank", "debug.json") + debug_json_dict = self.read_debug_json_into_dict(debug_json_path) + target_tensor, target_tensor_path, target_tensor_key, target_tensor_info = complicated_structure_info_list + assert self.check_real_npy(target_tensor_path, target_tensor) + assert deep_compare(debug_json_dict["data"][target_tensor_key], target_tensor_info) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_cell_processor.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_cell_processor.py index 40f5c0164115e18cdd49c046ce29967e7a3f63eb..183351fcf12b43df0d74511af14f77be23592a85 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_cell_processor.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_cell_processor.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,135 +13,354 @@ # See the License for the specific language governing permissions and # limitations under the License. +import threading + import unittest from unittest.mock import MagicMock, patch +import mindspore as ms +from mindspore import Tensor +from mindspore.ops.operations import _inner_ops + from msprobe.core.common.const import Const +from msprobe.core.common.exceptions import MsprobeException +from msprobe.core.common.runtime import Runtime from msprobe.core.data_dump.scope import ModuleRangeScope -from msprobe.mindspore.cell_processor import CellProcessor - - -class MockCell: - def __init__(self): - self.mindstudio_reserved_name = None +from msprobe.core.hook_manager import HookSet +from msprobe.mindspore.cell_processor import CellProcessor, get_cell_construct +from msprobe.mindspore.common.log import logger class TestCellProcessor(unittest.TestCase): + @classmethod + def setUpClass(cls): + CellProcessor.reset_cell_stats() + cls.scope = MagicMock(spec=ModuleRangeScope) + cls.processor = CellProcessor(cls.scope) - def setUp(self): - # 重置静态变量 + @classmethod + def tearDownClass(cls): CellProcessor.reset_cell_stats() - self.scope = MagicMock(spec=ModuleRangeScope) - self.processor = CellProcessor(self.scope) - def test_init_with_module_range_scope(self): - self.assertIsInstance(self.processor.scope, ModuleRangeScope) + def test_class_attribute(self): + self.assertTrue(hasattr(CellProcessor, 'cell_count')) + self.assertTrue(hasattr(CellProcessor, 'cell_stack')) + self.assertTrue(hasattr(CellProcessor, 'api_parent_node')) + self.assertTrue(hasattr(CellProcessor, 'module_node')) + self.assertTrue(hasattr(CellProcessor, 'cell_bw_hook_kernels')) + self.assertTrue(hasattr(CellProcessor, 'cell_backward_pre_hook')) + self.assertTrue(hasattr(CellProcessor, 'cell_backward_hook')) - def test_init_with_none_scope(self): + def test__init(self): + self.assertIsInstance(self.processor.scope, ModuleRangeScope) processor = CellProcessor(None) self.assertIsNone(processor.scope) - def test_set_cell_count_new_cell(self): - count = self.processor.set_cell_count("cell1") + def test_get_cell_construct(self): + def construct(self, *args, **kwargs): + return len(args) + + _constrct = get_cell_construct(construct) + ret = _constrct(self, 'argument') + self.assertFalse(hasattr(self, 'msprobe_input_kwargs')) + self.assertEqual(ret, 1) + + setattr(self, 'msprobe_hook', True) + _constrct = get_cell_construct(construct) + ret = _constrct(self, 'argument') + self.assertEqual(self.msprobe_input_kwargs, {}) + self.assertEqual(ret, 1) + + del self.msprobe_hook + del self.msprobe_input_kwargs + + def test_set_and_get_calls_number(self): + CellProcessor.cell_count = {} + count = self.processor.set_and_get_calls_number("cell") self.assertEqual(count, 0) - self.assertEqual(CellProcessor.cell_count["cell1"], 0) + self.assertEqual(CellProcessor.cell_count["cell"], 0) - def test_set_cell_count_existing_cell(self): - self.processor.set_cell_count("cell1") - count = self.processor.set_cell_count("cell1") + count = self.processor.set_and_get_calls_number("cell") self.assertEqual(count, 1) - self.assertEqual(CellProcessor.cell_count["cell1"], 1) + self.assertEqual(CellProcessor.cell_count["cell"], 1) + + CellProcessor.cell_count = {} def test_reset_cell_stats(self): - self.processor.set_cell_count("cell1") + CellProcessor.cell_count['cell'] = 0 + CellProcessor.cell_stack['tid'] = 'cell' + CellProcessor.api_parent_node['tid'] = 'cell' + CellProcessor.module_node['cell'] = 'null' + CellProcessor.cell_bw_hook_kernels['cell'] = 'bw' + CellProcessor.cell_backward_pre_hook.append('backward_pre_hook') + CellProcessor.cell_backward_hook.append('backward_hook') + CellProcessor.reset_cell_stats() self.assertEqual(CellProcessor.cell_count, {}) - self.assertEqual(CellProcessor.cell_stack, []) - self.assertEqual(CellProcessor.api_parent_node, "") + self.assertEqual(CellProcessor.cell_stack, {}) + self.assertEqual(CellProcessor.api_parent_node, {}) self.assertEqual(CellProcessor.module_node, {}) + self.assertEqual(CellProcessor.cell_bw_hook_kernels, {}) + self.assertEqual(CellProcessor.cell_backward_pre_hook, []) + self.assertEqual(CellProcessor.cell_backward_hook, []) + + def test_register_cell_hook(self): + with self.assertRaises(MsprobeException) as context: + self.processor.register_cell_hook([], None, 'config') + self.assertEqual(str(context.exception), '[msprobe] 无效参数:The model cannot be None, when level is "L0" or "mix"') + + with patch('msprobe.mindspore.cell_processor.is_mindtorch') as mock_is_mindtorch, \ + patch('msprobe.mindspore.cell_processor.get_cells_and_names_with_index') as mock_get_cells_and_names, \ + patch('msprobe.mindspore.cell_processor.CellProcessor.build_cell_hook') as mock_build_cell_hook, \ + patch('msprobe.mindspore.cell_processor.get_cell_construct') as mock_get_cell_construct, \ + patch('msprobe.mindspore.cell_processor.is_graph_mode_cell_dump_allowed') \ + as mock_is_graph_mode_cell_dump_allowed, \ + patch.object(logger, 'info') as mock_logger_info: + mock_cell = MagicMock() + mock_sub_cell = MagicMock() + mock_get_cells_and_names.return_value = ({'-1': [('cell', mock_cell), ('sub_cell', mock_sub_cell)]}, {}) + mock_build_cell_hook.return_value = 'forward_pre_hook' + mock_get_cell_construct.return_value = '_construct' + mock_is_graph_mode_cell_dump_allowed.return_value = False + + mock_is_mindtorch.return_value = False + setattr(MagicMock, '_run_construct', '_run_construct') + self.processor.register_cell_hook(mock_cell, None, 'config') + self.assertTrue(mock_sub_cell.__class__.msprobe_construct) + mock_get_cell_construct.assert_called_with('_run_construct') + self.assertEqual(mock_sub_cell.__class__._run_construct, '_construct') + self.assertTrue(mock_sub_cell.msprobe_hook) + mock_build_cell_hook.assert_called_with('Cell.sub_cell.MagicMock.', None) + mock_cell.assert_not_called() + mock_sub_cell.register_forward_pre_hook.assert_called_with('forward_pre_hook') + mock_sub_cell.register_forward_hook.assert_not_called() + mock_logger_info.assert_called_with('The cell hook function is successfully mounted to the model.') + + del MagicMock._run_construct + del mock_sub_cell.__class__._run_construct + del mock_sub_cell.__class__.msprobe_construct + + mock_get_cell_construct.reset_mock() + mock_another_sub_cell = MagicMock() + setattr(mock_another_sub_cell.__class__, 'msprobe_construct', True) + mock_get_cells_and_names.return_value = ( + {'-1': [('cell', mock_cell), ('another_sub_cell', mock_another_sub_cell)]}, + {} + ) + self.processor.register_cell_hook(mock_cell, None, 'config') + mock_get_cell_construct.assert_not_called() + mock_another_sub_cell.register_forward_pre_hook.assert_called_with('forward_pre_hook') + mock_another_sub_cell.register_forward_hook.assert_not_called() + + del mock_another_sub_cell.__class__.msprobe_construct + + mock_build_cell_hook.reset_mock() + mock_get_cell_construct.reset_mock() + mock_another_sub_cell.reset_mock() + setattr(MagicMock, '_call_impl', '_call_impl') + mock_is_mindtorch.return_value = True + self.processor.register_cell_hook(mock_cell, None, 'config') + self.assertTrue(mock_another_sub_cell.__class__.msprobe_construct) + mock_get_cell_construct.assert_called_with('_call_impl') + mock_build_cell_hook.assert_called_with('Module.another_sub_cell.MagicMock.', None) + mock_cell.assert_not_called() + mock_another_sub_cell.register_forward_pre_hook.assert_called_with('forward_pre_hook') + mock_another_sub_cell.register_forward_hook.assert_not_called() + + del MagicMock._call_impl + del mock_another_sub_cell.__class__._call_impl + del mock_another_sub_cell.__class__.msprobe_construct + + def test_build_cell_hook(self): + CellProcessor.reset_cell_stats() + Runtime.is_running = True + + cell_name = 'Cell.cell.Cell.' + mock_build_data_hook = MagicMock() + mock_backward_data_hook = MagicMock() + target_grad_output = (Tensor([0.5]),) + mock_backward_data_hook.return_value = target_grad_output + mock_hook_set = HookSet(backward_hook=mock_backward_data_hook) + mock_build_data_hook.return_value = mock_hook_set + mock_cell = MagicMock() + + with patch.object(_inner_ops, 'CellBackwardHook') as mock_CellBackwardHook: + forward_pre_hook = self.processor.build_cell_hook(cell_name, mock_build_data_hook) + forward_hook = forward_pre_hook.__closure__[1].cell_contents.__closure__[2].cell_contents + + mock_bw = mock_CellBackwardHook.return_value + mock_bw.return_value = (Tensor([0.0]),) + args = (Tensor([1.0]),) + target_args = (Tensor([0.0]),) + full_forward_name = f'{cell_name}{Const.FORWARD}.0' + full_backward_name = f'{cell_name}{Const.BACKWARD}.0' + # call testing function - forward_pre_hook + ret = forward_pre_hook(mock_cell, args) + self.assertIsNone(CellProcessor.module_node[full_forward_name]) + self.assertEqual(CellProcessor.cell_stack[threading.get_ident()], [full_forward_name]) + self.assertEqual(CellProcessor.api_parent_node[threading.get_ident()], full_forward_name) + self.scope.begin_module.assert_called_with(full_forward_name) + mock_build_data_hook.assert_called_with('Module', full_forward_name) + self.assertEqual(len(CellProcessor.cell_backward_hook), 1) + mock_CellBackwardHook.assert_called_with(full_backward_name, mock_cell, + CellProcessor.cell_backward_hook[-1]) + mock_bw.register_backward_hook.assert_called_once() + mock_bw.assert_called_with(*args) + self.assertTrue((ret[0] == target_args[0]).all()) + + backward_hook = CellProcessor.cell_backward_hook[-1][full_backward_name] + grad_input = (Tensor([1.0]),) + grad_output = (Tensor([2.0]),) + # call testing function - backward_hook + ret = backward_hook(mock_cell, grad_input, grad_output) + mock_backward_data_hook.assert_called_with(mock_cell, grad_input, grad_output) + self.assertFalse(mock_cell.has_pre_hook_called) + self.assertEqual(CellProcessor.cell_stack[threading.get_ident()], []) + self.assertIsNone(CellProcessor.api_parent_node[threading.get_ident()]) + self.scope.end_module.assert_called_with(full_backward_name) + self.assertTrue((ret[0] == target_grad_output[0]).all()) + + mock_build_data_hook.reset_mock() + args = (Tensor([1], dtype=ms.int32),) + full_forward_name = f'{cell_name}{Const.FORWARD}.1' + # call testing function - forward_pre_hook + ret = forward_pre_hook(mock_cell, args) + self.assertIsNone(CellProcessor.module_node[full_forward_name]) + self.assertEqual(CellProcessor.cell_stack[threading.get_ident()], [full_forward_name]) + self.assertEqual(CellProcessor.api_parent_node[threading.get_ident()], full_forward_name) + self.scope.begin_module.assert_called_with(full_forward_name) + self.assertEqual(len(CellProcessor.cell_backward_hook), 1) + mock_build_data_hook.assert_not_called() + + full_forward_name = f'{cell_name}{Const.FORWARD}.0' + CellProcessor.cell_count = {cell_name: 0} + CellProcessor.cell_stack[threading.get_ident()] = [full_forward_name] + CellProcessor.api_parent_node[threading.get_ident()] = full_forward_name + CellProcessor.module_node = {full_forward_name: None} + self.scope.reset_mock() + mock_CellBackwardHook.reset_mock() + mock_bw.reset_mock() + target_output = Tensor([0.5]) + args = (Tensor([1.0]),) + output = Tensor([2.0]) + mock_bw.return_value = target_output + mock_backward_data_hook.reset_mock() + mock_forward_data_hook = MagicMock() + mock_forward_data_hook.return_value = output + mock_build_data_hook.return_value = HookSet( + forward_hook=mock_forward_data_hook, backward_hook=mock_backward_data_hook + ) + # call testing function - forward_hook + ret = forward_hook(mock_cell, args, output) + self.assertEqual(CellProcessor.cell_count.get(cell_name), 0) + self.assertEqual(CellProcessor.cell_stack[threading.get_ident()], []) + self.assertIsNone(CellProcessor.api_parent_node[threading.get_ident()]) + self.scope.end_module.assert_called_with(full_forward_name) + self.assertEqual(mock_bw.call_count, 2) + self.assertEqual(mock_bw.call_args_list[0][0][0], output) + self.assertEqual(mock_bw.call_args_list[1][0][0], target_output) + self.assertEqual(mock_CellBackwardHook.call_count, 1) + self.assertEqual(len(CellProcessor.cell_backward_pre_hook), 1) + self.assertTrue((ret == target_output).all()) + + backward_pre_hook = CellProcessor.cell_backward_pre_hook[-1][full_backward_name] + mock_backward_data_hook.reset_mock() + grad_output = (Tensor([2.0]),) + # call testing function - backward_pre_hook + ret = backward_pre_hook(mock_cell, grad_output) + self.assertTrue(mock_cell.has_pre_hook_called) + self.scope.begin_module.assert_called_with(full_backward_name) + self.assertEqual(CellProcessor.cell_stack[threading.get_ident()], [full_backward_name]) + self.assertEqual(CellProcessor.api_parent_node[threading.get_ident()], full_backward_name) + self.assertEqual(CellProcessor.module_node, {full_forward_name: None, full_backward_name: None}) + self.scope.begin_module.assert_called_with(full_backward_name) + mock_backward_data_hook.assert_not_called() + self.assertIsNone(ret) + + CellProcessor.cell_count = {cell_name: 0} + CellProcessor.cell_stack[threading.get_ident()] = [full_forward_name] + CellProcessor.api_parent_node[threading.get_ident()] = full_forward_name + CellProcessor.module_node = {full_forward_name: None} + mock_bw.reset_mock() + args = (Tensor([1.0]),) + output = (Tensor([2.0]),) + mock_forward_data_hook.return_value = output + target_output = (Tensor([0.5]),) + # call testing function - forward_hook + ret = forward_hook(mock_cell, args, output) + self.assertEqual(mock_bw.call_count, 2) + self.assertEqual(mock_bw.call_args_list[0][0][0], *output) + self.assertEqual(mock_bw.call_args_list[1][0][0], mock_bw.return_value) + self.assertTrue((ret[0] == target_output[0]).all()) + + CellProcessor.cell_count = {cell_name: 0} + CellProcessor.cell_stack[threading.get_ident()] = [full_forward_name] + CellProcessor.api_parent_node[threading.get_ident()] = full_forward_name + CellProcessor.module_node = {full_forward_name: None} + CellProcessor.cell_bw_hook_kernels.clear() + CellProcessor.cell_backward_pre_hook.clear() + mock_bw.reset_mock() + mock_bw.return_value = (Tensor([0.5]),) + output = (Tensor([1.0]), Tensor([2.0])) + mock_forward_data_hook.return_value = output + with self.assertRaises(TypeError) as context: + # call testing function - forward_hook + forward_hook(mock_cell, args, output) + self.assertEqual(str(context.exception), + 'The backward pre hook return value size is 1 not equal to output size 2') + mock_bw.assert_called_with(*output) + + self.scope.reset_mock() + backward_pre_hook = CellProcessor.cell_backward_pre_hook[-1][full_backward_name] + # call testing function - backward_pre_hook + ret = backward_pre_hook(mock_cell, grad_output) + self.assertFalse(mock_cell.has_pre_hook_called) + self.scope.begin_module.assert_called_with(full_backward_name) + mock_backward_data_hook.assert_called_with(mock_cell, (), grad_output) + self.assertEqual(CellProcessor.cell_stack[threading.get_ident()], []) + self.assertIsNone(CellProcessor.api_parent_node[threading.get_ident()]) + self.assertEqual(CellProcessor.module_node, {full_forward_name: None, full_backward_name: None}) + self.scope.end_module.assert_called_with(full_backward_name) + self.assertIsNone(ret) + + CellProcessor.reset_cell_stats() + Runtime.is_running = False + + def test_set_construct_info_in_pre_hook(self): + CellProcessor.reset_cell_stats() + self.processor.set_construct_info_in_pre_hook('full_name') + self.assertEqual(CellProcessor.module_node['full_name'], None) + self.assertEqual(CellProcessor.cell_stack[threading.get_ident()], ['full_name']) + self.assertEqual(CellProcessor.api_parent_node[threading.get_ident()], 'full_name') + self.scope.begin_module.assert_called_with('full_name') + + self.scope.begin_module.reset_mock() + self.processor.set_construct_info_in_pre_hook('sub_cell_name') + self.assertEqual(CellProcessor.module_node, {'full_name': None, 'sub_cell_name': 'full_name'}) + self.assertEqual(CellProcessor.cell_stack[threading.get_ident()], ['full_name', 'sub_cell_name']) + self.assertEqual(CellProcessor.api_parent_node[threading.get_ident()], 'sub_cell_name') + self.scope.begin_module.assert_called_with('sub_cell_name') - @patch('msprobe.core.common.const.Const') - def test_node_hook_begin(self, mock_const): - mock_const.SEP = "." # 确保 SEPARATOR 设置为字符串 - mock_const.START = "start" - cell = MockCell() - self.processor.node_hook("prefix", "start")(cell, "input") - - expected_name = "prefix" + mock_const.SEP + "0" - self.assertEqual(cell.mindstudio_reserved_name, expected_name) - self.assertIn(expected_name, CellProcessor.cell_stack) - self.assertEqual(CellProcessor.api_parent_node, expected_name) - self.scope.begin_module.assert_called_once_with(expected_name) - - @patch('msprobe.core.common.const.Const') - def test_node_hook_end(self, mock_const): - mock_const.START = "start" - cell = MockCell() - self.processor.node_hook("prefix", "start")(cell, "input") - self.processor.node_hook("prefix", "stop")(cell, "input", "output") - - self.assertEqual(len(CellProcessor.cell_stack), 0) - self.assertIsNone(CellProcessor.api_parent_node) - self.scope.end_module.assert_called_once_with(cell.mindstudio_reserved_name) - - @patch('msprobe.core.common.const.Const') - def test_multiple_node_hook_calls(self, mock_const): - mock_const.SEP = "." # 确保 SEPARATOR 设置为字符串 - mock_const.START = "start" - cell = MockCell() - - # First call - self.processor.node_hook("prefix", "start")(cell, "input") - expected_name1 = "prefix" + mock_const.SEP + "0" - - # Second call - self.processor.node_hook("prefix", "start")(cell, "input") - expected_name2 = "prefix" + mock_const.SEP + "1" - - self.assertEqual(cell.mindstudio_reserved_name, expected_name2) - self.assertEqual(CellProcessor.api_parent_node, expected_name2) - - # End first call - self.processor.node_hook("prefix", "stop")(cell, "input", "output") - self.assertEqual(len(CellProcessor.cell_stack), 1) # Still one item in stack - self.assertEqual(CellProcessor.api_parent_node, expected_name1) - - # End second call - self.processor.node_hook("prefix", "stop")(cell, "input", "output") - self.assertEqual(len(CellProcessor.cell_stack), 0) # Stack should be empty now - self.assertIsNone(CellProcessor.api_parent_node) - - def test_set_and_get_reserved_name(self): - cell = MockCell() - cell.mindstudio_reserved_name = "mindstudio_reserved_name" CellProcessor.reset_cell_stats() - cell_name = "Cell.net.Net.forward" - ret = self.processor.set_and_get_reserved_name(cell, cell_name) - self.assertEqual(ret, cell_name + Const.SEP + "0") - self.assertEqual(cell.mindstudio_reserved_name, ret) - self.assertEqual(CellProcessor.cell_count[cell_name], 0) - self.assertFalse(hasattr(cell, "has_pre_hook_called")) - - cell.has_pre_hook_called = False - ret = self.processor.set_and_get_reserved_name(cell, cell_name) - self.assertEqual(ret, cell_name + Const.SEP + "1") - self.assertEqual(cell.mindstudio_reserved_name, ret) - self.assertEqual(CellProcessor.cell_count[cell_name], 1) - self.assertFalse(cell.has_pre_hook_called) - - cell.has_pre_hook_called = True - cell.mindstudio_reserved_name = "mindstudio_reserved_name" + def test_set_construct_info_in_hook(self): CellProcessor.reset_cell_stats() - ret = self.processor.set_and_get_reserved_name(cell, cell_name) - self.assertEqual(ret, "mindstudio_reserved_name") - self.assertEqual(cell.mindstudio_reserved_name, ret) - self.assertEqual(CellProcessor.cell_count, {}) - self.assertFalse(cell.has_pre_hook_called) + self.processor.set_construct_info_in_hook('full_name') + self.assertIsNone(CellProcessor.api_parent_node[threading.get_ident()]) + self.scope.end_module.assert_called_with('full_name') + + self.scope.end_module.reset_mock() + CellProcessor.cell_stack[threading.get_ident()] = ['full_name'] + self.processor.set_construct_info_in_hook('full_name') + self.assertEqual(CellProcessor.cell_stack, {threading.get_ident(): []}) + self.assertIsNone(CellProcessor.api_parent_node[threading.get_ident()]) + self.scope.end_module.assert_called_with('full_name') + + self.scope.end_module.reset_mock() + CellProcessor.cell_stack[threading.get_ident()] = ['Cell.0', 'Cell.1'] + self.processor.set_construct_info_in_hook('full_name') + self.assertEqual(CellProcessor.cell_stack, {threading.get_ident():['Cell.0']}) + self.assertEqual(CellProcessor.api_parent_node[threading.get_ident()], 'Cell.0') + self.scope.end_module.assert_called_with('full_name') - ret = self.processor.set_and_get_reserved_name(cell, cell_name, is_called_by_pre_hook=True) - self.assertEqual(ret, cell_name + Const.SEP + "0") - self.assertEqual(cell.mindstudio_reserved_name, ret) - self.assertEqual(CellProcessor.cell_count[cell_name], 0) - self.assertTrue(cell.has_pre_hook_called) CellProcessor.reset_cell_stats() diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_dump_tool_factory.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_dump_tool_factory.py index 8f5d207c41923175b6efe4f9dc313896f879fd89..009c70a9ebdc112174301ae0f64836980a0ac841 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_dump_tool_factory.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_dump_tool_factory.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,16 +16,19 @@ from unittest import TestCase from unittest.mock import patch +from msprobe.core.common.log import logger from msprobe.core.common_config import CommonConfig, BaseConfig from msprobe.core.common.const import Const as CoreConst from msprobe.mindspore.common.const import Const from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.dump.dump_tool_factory import DumpToolFactory +from msprobe.mindspore.ms_config import StatisticsConfig class TestDumpToolFactory(TestCase): + @patch.object(logger, "error") @patch("msprobe.mindspore.debugger.debugger_config.create_directory") - def test_create(self, _): + def test_create(self, _, mock_logger_error): json_config = { "task": "statistics", "dump_path": "/absolute_path", @@ -35,7 +38,7 @@ class TestDumpToolFactory(TestCase): } common_config = CommonConfig(json_config) - task_config = BaseConfig(json_config) + task_config = StatisticsConfig(json_config) config = DebuggerConfig(common_config, task_config) config.data_mode = [CoreConst.INPUT, CoreConst.OUTPUT] @@ -54,18 +57,21 @@ class TestDumpToolFactory(TestCase): DumpToolFactory.create(config) self.assertEqual(str(context.exception), "Valid level is needed.") - config.level = Const.KERNEL - with self.assertRaises(Exception) as context: - DumpToolFactory.create(config) - self.assertEqual(str(context.exception), "Data dump is not supported in None mode when dump level is kernel.") - config.execution_mode = Const.GRAPH_GE_MODE config.level = Const.CELL - with self.assertRaises(Exception) as context: - DumpToolFactory.create(config) - self.assertEqual(str(context.exception), "Data dump is not supported in graph_ge mode when dump level is cell.") + with patch('msprobe.mindspore.dump.dump_tool_factory.is_graph_mode_cell_dump_allowed') as \ + mock_is_cell_dump_allowed: + mock_is_cell_dump_allowed.return_value = True + with self.assertRaises(ValueError): + DumpToolFactory.create(config) + mock_logger_error.assert_called_with("Data dump is not supported in graph_ge mode when dump level is cell.") + mock_logger_error.reset_mock() + + mock_is_cell_dump_allowed.return_value = False + with self.assertRaises(Exception) as context: + DumpToolFactory.create(config) + self.assertEqual(str(context.exception), "Cell dump is not supported in graph mode.") - config.execution_mode = Const.GRAPH_KBYK_MODE config.level = Const.KERNEL dumper = DumpToolFactory.create(config) - self.assertEqual(dumper.dump_json["common_dump_settings"]["net_name"], "Net") + self.assertIsInstance(dumper, tuple) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_exception_dump_tool_factory.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_exception_dump_tool_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..42c42d281d4f115bfc16f49ff45a62e4aa99c9a1 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_exception_dump_tool_factory.py @@ -0,0 +1,61 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys + +from unittest import TestCase +from unittest.mock import patch + +from msprobe.core.common_config import CommonConfig, BaseConfig +from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +from msprobe.mindspore.exception_dump.kernel_graph_exception_dump import KernelGraphExceptionDump +from msprobe.core.common.file_utils import move_file + + +class TestKernelGraphExceptionDump(TestCase): + @patch("msprobe.mindspore.debugger.debugger_config.create_directory") + def test_handle(self, _): + json_config = { + "task": "exception_dump", + "dump_path": "/absolute_path", + "rank": [], + "step": [], + "level": "L2" + } + + common_config = CommonConfig(json_config) + task_config = BaseConfig(json_config) + task_config.check_mode = "atomic" + config = DebuggerConfig(common_config, task_config) + checker = KernelGraphExceptionDump(config) + self.assertEqual(checker.dump_json["common_dump_settings"]["op_debug_mode"], 4) + + _msprobe_c_existed = True + try: + from msprobe.lib import _msprobe_c + except ImportError: + _msprobe_c_existed = False + + with patch("msprobe.mindspore.exception_dump.kernel_graph_exception_dump.create_directory"), \ + patch("msprobe.mindspore.exception_dump.kernel_graph_exception_dump.logger.info"), \ + patch("msprobe.mindspore.exception_dump.kernel_graph_exception_dump.save_json") as mock_save_json: + + checker.handle() + self.assertIn("kernel_graph_exception_check.json", mock_save_json.call_args_list[0][0][0]) + self.assertIn("kernel_graph_exception_check.json", os.environ.get("MINDSPORE_DUMP_CONFIG")) + + if "MINDSPORE_DUMP_CONFIG" in os.environ: + del os.environ["MINDSPORE_DUMP_CONFIG"] \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_graph_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_graph_dump.py index 329274b19d862c8c0e50af0fdbd051909e6a60d6..e607f2c2a8a701417ae28a6c353349d1430d5e98 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_graph_dump.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_graph_dump.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,6 +14,7 @@ # limitations under the License. import os +import sys from unittest import TestCase from unittest.mock import patch @@ -21,6 +22,7 @@ from unittest.mock import patch from msprobe.core.common_config import CommonConfig, BaseConfig from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump +from msprobe.core.common.file_utils import move_file class TestKernelGraphDump(TestCase): @@ -44,10 +46,26 @@ class TestKernelGraphDump(TestCase): self.assertEqual(dumper.dump_json["common_dump_settings"]["file_format"], "bin") self.assertEqual(dumper.dump_json["common_dump_settings"]["input_output"], 2) + _msprobe_c_existed = True + try: + from msprobe.lib import _msprobe_c + except ImportError: + _msprobe_c_existed = False + with patch("msprobe.mindspore.dump.kernel_graph_dump.create_directory"), \ patch("msprobe.mindspore.dump.kernel_graph_dump.logger.info"), \ patch("msprobe.mindspore.dump.kernel_graph_dump.save_json") as mock_save_json: + if _msprobe_c_existed: + dumper.handle() + mock_save_json.assert_not_called() + + _msprobe_c_path = _msprobe_c.__file__ + _msprobe_c_test_path = _msprobe_c_path.replace('_msprobe_c.so', '_msprobe_c_test.so') + move_file(_msprobe_c_path, _msprobe_c_test_path) + sys.modules.pop('msprobe.lib') + sys.modules.pop('msprobe.lib._msprobe_c') + os.environ["GRAPH_OP_RUN"] = "1" with self.assertRaises(Exception) as context: dumper.handle() @@ -63,3 +81,5 @@ class TestKernelGraphDump(TestCase): del os.environ["MINDSPORE_DUMP_CONFIG"] if "MS_ACL_DUMP_CFG_PATH" in os.environ: del os.environ["MS_ACL_DUMP_CFG_PATH"] + if _msprobe_c_existed: + move_file(_msprobe_c_test_path, _msprobe_c_path) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py index b484bc9b7cdceec3b8906600b16b2d4fdc6b1b5e..b3b2466e8104de3735cd1f1d1de17a79e52bd38e 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,6 +14,7 @@ # limitations under the License. import os +import sys from unittest import TestCase from unittest.mock import patch @@ -21,6 +22,7 @@ from unittest.mock import patch from msprobe.core.common_config import CommonConfig, BaseConfig from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.overflow_check.kernel_graph_overflow_check import KernelGraphOverflowCheck +from msprobe.core.common.file_utils import move_file class TestKernelGraphOverflowCheck(TestCase): @@ -41,11 +43,27 @@ class TestKernelGraphOverflowCheck(TestCase): checker = KernelGraphOverflowCheck(config) self.assertEqual(checker.dump_json["common_dump_settings"]["op_debug_mode"], 2) + _msprobe_c_existed = True + try: + from msprobe.lib import _msprobe_c + except ImportError: + _msprobe_c_existed = False + os.environ["MS_ACL_DUMP_CFG_PATH"] = "path" with patch("msprobe.mindspore.overflow_check.kernel_graph_overflow_check.create_directory"), \ patch("msprobe.mindspore.overflow_check.kernel_graph_overflow_check.logger.info"), \ patch("msprobe.mindspore.overflow_check.kernel_graph_overflow_check.save_json") as mock_save_json: + if _msprobe_c_existed: + checker.handle() + mock_save_json.assert_not_called() + + _msprobe_c_path = _msprobe_c.__file__ + _msprobe_c_test_path = _msprobe_c_path.replace('_msprobe_c.so', '_msprobe_c_test.so') + move_file(_msprobe_c_path, _msprobe_c_test_path) + sys.modules.pop('msprobe.lib') + sys.modules.pop('msprobe.lib._msprobe_c') + os.environ["GRAPH_OP_RUN"] = "1" with self.assertRaises(Exception) as context: checker.handle() @@ -60,3 +78,5 @@ class TestKernelGraphOverflowCheck(TestCase): if "MINDSPORE_DUMP_CONFIG" in os.environ: del os.environ["MINDSPORE_DUMP_CONFIG"] + if _msprobe_c_existed: + move_file(_msprobe_c_test_path, _msprobe_c_path) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_kbyk_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_kbyk_dump.py index c52ea4de2adef5a3c579c3deceece9d84b89309c..c3e2076483094264293ec4ac454e669ff842dd8a 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_kbyk_dump.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_kbyk_dump.py @@ -21,8 +21,14 @@ from unittest.mock import patch from msprobe.core.common_config import CommonConfig, BaseConfig from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +from msprobe.mindspore.ms_config import StatisticsConfig from msprobe.mindspore.dump.kernel_kbyk_dump import KernelKbykDump +from collections import Counter + +import mindspore as ms +ms_version = ms.__version__ + class TestKernelKbykDump(TestCase): @patch("msprobe.mindspore.debugger.debugger_config.create_directory") @@ -36,7 +42,7 @@ class TestKernelKbykDump(TestCase): } common_config = CommonConfig(json_config) - task_config = BaseConfig(json_config) + task_config = StatisticsConfig(json_config) config = DebuggerConfig(common_config, task_config) dumper = KernelKbykDump(config) self.assertEqual(dumper.dump_json["common_dump_settings"]["iteration"], "0|2") @@ -53,6 +59,138 @@ class TestKernelKbykDump(TestCase): if "MINDSPORE_DUMP_CONFIG" in os.environ: del os.environ["MINDSPORE_DUMP_CONFIG"] + @patch("msprobe.mindspore.debugger.debugger_config.create_directory") + def test_handle_when_async_dump_then_pass(self, _): + json_config = { + "task": "statistics", + "dump_path": "/absolute_path", + "rank": [], + "step": [0, 2], + "level": "L2", + "async_dump": True + } + + common_config = CommonConfig(json_config) + task_config = StatisticsConfig(json_config) + config = DebuggerConfig(common_config, task_config) + dumper = KernelKbykDump(config) + self.assertEqual(dumper.dump_json["e2e_dump_settings"]["enable"], False) + + os.environ["MS_ACL_DUMP_CFG_PATH"] = "path" + with patch("msprobe.mindspore.dump.kernel_kbyk_dump.create_directory"), \ + patch("msprobe.mindspore.dump.kernel_kbyk_dump.logger.info") as mock_info, \ + patch("msprobe.mindspore.dump.kernel_kbyk_dump.save_json") as mock_save_json: + dumper.handle() + self.assertIn("kernel_kbyk_dump.json", mock_save_json.call_args_list[0][0][0]) + mock_info.assert_called_with("/absolute_path/kernel_kbyk_dump.json has been created.") + + self.assertEqual(os.environ.get("MS_ACL_DUMP_CFG_PATH"), None) + if "MINDSPORE_DUMP_CONFIG" in os.environ: + del os.environ["MINDSPORE_DUMP_CONFIG"] + + @patch("msprobe.mindspore.debugger.debugger_config.create_directory") + def test_handle_when_device_then_pass(self, _): + json_config = { + "task": "statistics", + "dump_path": "/absolute_path", + "rank": [], + "step": [0, 2], + "level": "L2", + "statistics": { + "list": [], + "data_mode": ["all"], + "device": "device", + "summary_mode": "statistics" + } + } + + common_config = CommonConfig(json_config) + task_config = StatisticsConfig(json_config["statistics"]) + config = DebuggerConfig(common_config, task_config) + dumper = KernelKbykDump(config) + self.assertEqual(dumper.dump_json["e2e_dump_settings"]["stat_calc_mode"], "device") + + os.environ["MS_ACL_DUMP_CFG_PATH"] = "path" + with patch("msprobe.mindspore.dump.kernel_kbyk_dump.create_directory"), \ + patch("msprobe.mindspore.dump.kernel_kbyk_dump.logger.info") as mock_info, \ + patch("msprobe.mindspore.dump.kernel_kbyk_dump.save_json") as mock_save_json: + dumper.handle() + self.assertIn("kernel_kbyk_dump.json", mock_save_json.call_args_list[0][0][0]) + mock_info.assert_called_with("/absolute_path/kernel_kbyk_dump.json has been created.") + + self.assertEqual(os.environ.get("MS_ACL_DUMP_CFG_PATH"), None) + if "MINDSPORE_DUMP_CONFIG" in os.environ: + del os.environ["MINDSPORE_DUMP_CONFIG"] + + @patch("msprobe.mindspore.debugger.debugger_config.create_directory") + def test_handle_when_precision_then_pass(self, _): + json_config = { + "task": "statistics", + "dump_path": "/absolute_path", + "rank": [], + "step": [0, 2], + "level": "L2", + "statistics": { + "list": [], + "data_mode": ["all"], + "precision": "low", + "summary_mode": "statistics" + } + } + + common_config = CommonConfig(json_config) + task_config = StatisticsConfig(json_config["statistics"]) + config = DebuggerConfig(common_config, task_config) + dumper = KernelKbykDump(config) + self.assertEqual(dumper.dump_json["e2e_dump_settings"]["device_stat_precision_mode"], "low") + + os.environ["MS_ACL_DUMP_CFG_PATH"] = "path" + with patch("msprobe.mindspore.dump.kernel_kbyk_dump.create_directory"), \ + patch("msprobe.mindspore.dump.kernel_kbyk_dump.logger.info") as mock_info, \ + patch("msprobe.mindspore.dump.kernel_kbyk_dump.save_json") as mock_save_json: + dumper.handle() + self.assertIn("kernel_kbyk_dump.json", mock_save_json.call_args_list[0][0][0]) + mock_info.assert_called_with("/absolute_path/kernel_kbyk_dump.json has been created.") + + self.assertEqual(os.environ.get("MS_ACL_DUMP_CFG_PATH"), None) + if "MINDSPORE_DUMP_CONFIG" in os.environ: + del os.environ["MINDSPORE_DUMP_CONFIG"] + + @patch("msprobe.mindspore.debugger.debugger_config.create_directory") + def test_handle_when_default_then_pass(self, _): + json_config = { + "task": "statistics", + "dump_path": "/absolute_path", + "rank": [], + "step": [0, 2], + "level": "L2", + "statistics": { + "list": [], + "data_mode": ["all"], + "summary_mode": "statistics" + } + } + + common_config = CommonConfig(json_config) + task_config = StatisticsConfig(json_config) + config = DebuggerConfig(common_config, task_config) + dumper = KernelKbykDump(config) + self.assertEqual(dumper.dump_json["e2e_dump_settings"]["device_stat_precision_mode"], "high") + self.assertEqual(dumper.dump_json["e2e_dump_settings"]["stat_calc_mode"], "host") + self.assertEqual(dumper.dump_json["e2e_dump_settings"]["enable"], True) + + os.environ["MS_ACL_DUMP_CFG_PATH"] = "path" + with patch("msprobe.mindspore.dump.kernel_kbyk_dump.create_directory"), \ + patch("msprobe.mindspore.dump.kernel_kbyk_dump.logger.info") as mock_info, \ + patch("msprobe.mindspore.dump.kernel_kbyk_dump.save_json") as mock_save_json: + dumper.handle() + self.assertIn("kernel_kbyk_dump.json", mock_save_json.call_args_list[0][0][0]) + mock_info.assert_called_with("/absolute_path/kernel_kbyk_dump.json has been created.") + + self.assertEqual(os.environ.get("MS_ACL_DUMP_CFG_PATH"), None) + if "MINDSPORE_DUMP_CONFIG" in os.environ: + del os.environ["MINDSPORE_DUMP_CONFIG"] + @patch("msprobe.mindspore.debugger.debugger_config.create_directory") def test_handle_tensor(self, _): json_config = { @@ -246,4 +384,42 @@ class TestKernelKbykDump(TestCase): patch("msprobe.mindspore.dump.kernel_kbyk_dump.save_json") as mock_save_json: dumper.handle() mock_info.assert_called_with("/absolute_path/kernel_kbyk_dump.json has been created.") - self.assertEqual(os.environ.get("MS_ACL_DUMP_CFG_PATH"), None) \ No newline at end of file + self.assertEqual(os.environ.get("MS_ACL_DUMP_CFG_PATH"), None) + + @patch("msprobe.mindspore.debugger.debugger_config.create_directory") + def test_handle_statistics(self, _): + json_config = { + "task": "statistics", + "dump_path": "/absolute_path", + "rank": [], + "step": [0, 2], + "level": "L2", + "statistics": { + "list": [], + "data_mode": ["all"], + "device": "host", + "summary_mode": ["hash", "md5", "max", "mean"] + } + } + + common_config = CommonConfig(json_config) + task_config = StatisticsConfig(json_config["statistics"]) + config = DebuggerConfig(common_config, task_config) + dumper = KernelKbykDump(config) + self.assertEqual(dumper.dump_json["e2e_dump_settings"]["stat_calc_mode"], "host") + self.assertEqual(dumper.dump_json["common_dump_settings"]["saved_data"], "statistic") + if ms_version > "2.7.0": + self.assertEqual(Counter(dumper.dump_json["common_dump_settings"]["statistic_category"]), Counter(["max", "hash", "hash:md5", "avg"])) + else: + self.assertEqual(Counter(dumper.dump_json["common_dump_settings"]["statistic_category"]), Counter(["max", "md5", "avg"])) + os.environ["MS_ACL_DUMP_CFG_PATH"] = "path" + with patch("msprobe.mindspore.dump.kernel_kbyk_dump.create_directory"), \ + patch("msprobe.mindspore.dump.kernel_kbyk_dump.logger.info") as mock_info, \ + patch("msprobe.mindspore.dump.kernel_kbyk_dump.save_json") as mock_save_json: + dumper.handle() + self.assertIn("kernel_kbyk_dump.json", mock_save_json.call_args_list[0][0][0]) + mock_info.assert_called_with("/absolute_path/kernel_kbyk_dump.json has been created.") + + self.assertEqual(os.environ.get("MS_ACL_DUMP_CFG_PATH"), None) + if "MINDSPORE_DUMP_CONFIG" in os.environ: + del os.environ["MINDSPORE_DUMP_CONFIG"] \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_config.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_config.py index 7717f9c336202d67ee524f59c3c5f328e70a045f..9320f49ad7e1719288f675d699c3086a3ac53671 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_config.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_config.py @@ -17,32 +17,11 @@ import unittest from unittest.mock import patch from msprobe.core.common.const import Const -from msprobe.mindspore.ms_config import (parse_json_config, parse_task_config, +from msprobe.mindspore.ms_config import (parse_task_config, TensorConfig, StatisticsConfig, OverflowCheckConfig, FreeBenchmarkConfig) class TestMsConfig(unittest.TestCase): - def test_parse_json_config(self): - mock_json_data = { - "dump_path": "./dump/", - "rank": [], - "step": [], - "level": "L1", - "statistics": { - "scope": [], - "list": [], - "data_mode": ["all"], - "summary_mode": "statistics" - } - } - with patch("msprobe.mindspore.ms_config.load_json", return_value=mock_json_data): - common_config, task_config = parse_json_config("./config.json") - self.assertEqual(common_config.task, Const.STATISTICS) - self.assertEqual(task_config.data_mode, ["all"]) - - with self.assertRaises(Exception) as context: - parse_json_config(None) - self.assertEqual(str(context.exception), "json file path is None") def test_parse_task_config(self): mock_json_config = { diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_debug_save.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_debug_save.py index 495eedbf41384f820c2ca054fd73192d1966a8bd..2ae1d92502f7a2c7023e61786b8d41b2d36581f6 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_debug_save.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_debug_save.py @@ -16,8 +16,11 @@ from unittest import TestCase from unittest.mock import patch import mindspore +from msprobe.core.debugger.precision_debugger import BasePrecisionDebugger from msprobe.mindspore import PrecisionDebugger -from msprobe.core.common_config import CommonConfig, BaseConfig +from msprobe.core.common_config import CommonConfig +from msprobe.mindspore.ms_config import StatisticsConfig + class TestMindsporeDebuggerSave(TestCase): def setUp(self): @@ -35,8 +38,8 @@ class TestMindsporeDebuggerSave(TestCase): } } common_config = CommonConfig(statistics_task_json) - task_config = BaseConfig(statistics_task_json) - with patch("msprobe.mindspore.debugger.precision_debugger.parse_json_config", return_value=(common_config, task_config)), \ + task_config = StatisticsConfig(statistics_task_json) + with patch.object(BasePrecisionDebugger, "_parse_config_path", return_value=(common_config, task_config)), \ patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions"): self.debugger = PrecisionDebugger() @@ -52,26 +55,27 @@ class TestMindsporeDebuggerSave(TestCase): "framework": "mindspore", "dump_data_dir": None, "data": { - "x_tensor.0": { + "x_tensor.0.debug": { "type": "mindspore.Tensor", "dtype": "Float32", - "shape": (1,), - "Max": 1.0, - "Min": 1.0, - "Mean": 1.0, - "Norm": 1.0 + "shape": (1,) }, - "x_tensor_grad.0": { + "x_tensor_grad.0.debug": { "type": "mindspore.Tensor", "dtype": "Float32", - "shape": (1,), - "Max": 2.0, - "Min": 2.0, - "Mean": 2.0, - "Norm": 2.0 + "shape": (1,) } } } + + grad_fn = mindspore.value_and_grad(forward_func, (0, 1)) grad_fn(x, y) - self.assertEqual(self.debugger.service.data_collector.data_writer.cache_debug, result_json) \ No newline at end of file + + result = self.debugger.service.data_collector.data_writer.cache_debug + # Remove 'tensor_stat_index' from all entries in the data dictionary + for key in result["data"]: + if 'tensor_stat_index' in result["data"][key]: + del result["data"][key]['tensor_stat_index'] + + self.assertEqual(result, result_json) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py index 912830ea1ab705aae63c69f5c240887d4b4ce5b7..7777ab41878829a290b396565a90b8001ab86977 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py @@ -1,7 +1,7 @@ -# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# Copyright (c) 2025, Huawei Technologies Co., Ltd. # All rights reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -11,291 +11,146 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. language governing permissions and # limitations under the License. -import unittest from collections import defaultdict +import unittest from unittest.mock import MagicMock, patch - -from mindspore import nn, ops - -from msprobe.core.common.exceptions import MsprobeException -from msprobe.core.common.utils import Const, DumpPathAggregation -from msprobe.core.data_dump.scope import BaseScope -from msprobe.mindspore.cell_processor import CellProcessor -from msprobe.mindspore.common.log import logger -from msprobe.mindspore.common.utils import register_backward_hook_functions -from msprobe.mindspore.dump.hook_cell.api_registry import ApiRegistry, api_register -from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell from msprobe.mindspore.dump.jit_dump import JitDump -from msprobe.mindspore.service import Service +from msprobe.mindspore.mindspore_service import MindsporeService +from msprobe.core.common.utils import Const +from mindspore import ops +try: + from mindspore.common._pijit_context import PIJitCaptureContext +except ImportError: + pijit_label = False +else: + pijit_label = True -class TestService(unittest.TestCase): + +class TestMindsporeService(unittest.TestCase): def setUp(self): - self.config_mock = MagicMock() - self.config_mock.level_ori = Const.LEVEL_L0 - self.config_mock.dump_path = "/tmp/dump" - self.config_mock.step = [] - self.config_mock.rank = [] - self.config_mock.task = Const.TENSOR - self.config_mock.framework = Const.MS_FRAMEWORK - self.config_mock.list = [] - self.config_mock.scope = [] - self.service = Service(self.config_mock) - self.service.model = MagicMock(spec=nn.Cell) + + self.config = MagicMock() + self.config.step = [] + self.config.rank = [] + self.config.level_ori = Const.LEVEL_MIX + self.config.task = Const.STATISTICS + + with patch('msprobe.core.service.build_data_collector'): + self.service = MindsporeService(self.config) + + self.service.logger = MagicMock() self.service.data_collector = MagicMock() self.service.primitive_hook_service = MagicMock() - - def tearDown(self) -> None: - api_register.api_set_ori_func() - - def test_init(self): - self.assertEqual(self.service.config.level, "L0") - self.assertFalse(self.service.switch) - self.assertFalse(self.service.should_stop_service) - self.assertFalse(self.service.start_call) - self.assertTrue(self.service.first_start) - - def test_check_model_valid_with_valid_cell(self): - model = nn.Cell() - model_list = [model] - self.assertEqual(self.service.check_model_valid(model), model) - self.assertEqual(self.service.check_model_valid(model_list), model_list) - - def test_check_model_valid_with_invalid_type(self): - model = nn.Cell() - with self.assertRaises(MsprobeException): - self.service.check_model_valid("not a cell") - with self.assertRaises(MsprobeException): - self.service.check_model_valid(["not a cell", model]) - - def test_update_primitive_counters(self): - self.service.primitive_counters = {} - self.service.update_primitive_counters("conv2d") - self.assertEqual(self.service.primitive_counters["conv2d"], 0) - self.service.update_primitive_counters("conv2d") - self.assertEqual(self.service.primitive_counters["conv2d"], 1) - - @patch('msprobe.mindspore.service.create_directory') - def test_create_dirs(self, mock_create_directory): - self.service.current_iter = 1 - self.service.current_rank = 0 - self.service.data_collector.tasks_need_tensor_data = [Const.TENSOR] - self.service.data_collector.update_dump_paths = MagicMock() - self.service.create_dirs() - expected_calls = [ - ("/tmp/dump"), - ("/tmp/dump/step1/rank0"), - "/tmp/dump/step1/rank0/dump_tensor_data" - ] - mock_create_directory.assert_has_calls( - [unittest.mock.call(path) for path in expected_calls], any_order=True) - - args, _ = self.service.data_collector.update_dump_paths.call_args - self.assertEqual(args[0].dump_file_path, "/tmp/dump/step1/rank0/dump.json") - self.assertEqual(args[0].stack_file_path, "/tmp/dump/step1/rank0/stack.json") - self.assertEqual(args[0].construct_file_path, "/tmp/dump/step1/rank0/construct.json") - self.assertEqual(args[0].dump_tensor_data_dir, "/tmp/dump/step1/rank0/dump_tensor_data") - self.service.data_collector.initialize_json_file.assert_called_once_with( - framework=Const.MS_FRAMEWORK + self.service.cell_processor = MagicMock() + self.service.api_register = MagicMock() + + @patch('msprobe.mindspore.mindspore_service.is_mindtorch') + def test_framework_type(self, mock_is_mindtorch): + mock_is_mindtorch.return_value = True + self.assertEqual(self.service._get_framework_type, Const.MT_FRAMEWORK) + mock_is_mindtorch.return_value = False + self.assertEqual(self.service._get_framework_type, Const.MS_FRAMEWORK) + + @patch('msprobe.mindspore.mindspore_service.get_rank_if_initialized') + def test_get_current_rank(self, mock_get_rank): + mock_get_rank.return_value = 3 + self.assertEqual(MindsporeService._get_current_rank(), 3) + + def test_init_specific_components(self): + with patch('msprobe.core.service.build_data_collector'): + service = MindsporeService(self.config) + + self.assertIsNotNone(service.logger) + self.assertIsNotNone(service.api_register) + self.assertIsNotNone(service.primitive_hook_service) + self.assertIsNotNone(service.cell_processor) + self.assertIsNotNone(service.hook_manager) + + @patch.object(JitDump, "set_data_collector") + @patch.object(JitDump, "set_config") + @patch('msprobe.mindspore.mindspore_service.ms.common.api') + def test_setup_jit_context_with_pijit(self, mock_ms_api, mock_jit_set_config, mock_set_data_collector): + mock_ms_api.__dict__['_MindsporeFunctionExecutor'] = MagicMock() + self.service._setup_jit_context() + + mock_jit_set_config.assert_called_once_with(self.config) + mock_set_data_collector.assert_called_once_with(self.service.data_collector) + self.assertEqual(mock_ms_api._MindsporeFunctionExecutor, JitDump) + self.assertEqual(mock_ms_api._PyNativeExecutor.grad, JitDump.grad) + if pijit_label: + self.assertEqual(PIJitCaptureContext.__enter__, self.service.empty) + self.assertEqual(PIJitCaptureContext.__exit__, self.service.empty) + + @patch('msprobe.mindspore.mindspore_service.JitDump') + def test_change_jit_switch(self, mock_jit_dump): + self.service._change_jit_switch(True) + self.assertTrue(mock_jit_dump.jit_dump_switch) + + self.service._change_jit_switch(False) + self.assertFalse(mock_jit_dump.jit_dump_switch) + + def test_register_module_hook(self): + model_mock = MagicMock() + self.service.model = model_mock + self.service._register_module_hook() + + self.service.cell_processor.register_cell_hook.assert_called_once_with( + model_mock, self.service.build_hook, self.config ) - - @patch.object(Service, 'need_end_service', return_value=False) - def test_start_stop_cycle(self, mock_need_end_service): - self.service.model = nn.Cell() - with patch.object(self.service, 'register_cell_hook') as mock_register_hook: - self.should_stop_service = False - self.service.start(self.service.model) - self.assertTrue(self.service.switch) - self.service.stop() - self.assertFalse(self.service.switch) - mock_register_hook.assert_called_once() - mock_need_end_service.assert_called_once() - - def test_should_execute_hook_return_false(self): - cell = MagicMock() - self.service.switch = False - self.assertFalse(self.service.should_execute_hook("Module", cell, True)) - self.assertFalse(self.service.should_execute_hook("api", cell, True)) - - self.service.switch = True - cell.forward_data_collected = False - self.assertFalse(self.service.should_execute_hook("api", cell, False)) - - self.service.inner_switch = True - self.assertFalse(self.service.should_execute_hook("Module", cell, True)) - - self.service.inner_switch = False - self.service.data_collector = None - self.assertFalse(self.service.should_execute_hook("Module", cell, True)) - - def test_should_execute_hook_return_true(self): - cell = MagicMock() - self.service.switch = True - self.service.inner_switch = False - self.service.data_collector = MagicMock() - self.service.data_collector.data_processor = MagicMock() - self.service.data_collector.data_processor.is_terminated = False - self.assertTrue(self.service.should_execute_hook("Module", cell, True)) - - cell.forward_data_collected = True - self.assertTrue(self.service.should_execute_hook("api", cell, False)) - - def test_need_end_service_with_high_step(self): - self.service.config.step = [1, 2, 3] - self.service.current_iter = 4 - self.assertTrue(self.service.need_end_service()) - - def test_need_end_service_with_low_step(self): - self.service.config.step = [1, 2, 3] - self.service.current_iter = 2 - self.service.data_collector.data_processor.is_terminated = False - self.assertFalse(self.service.need_end_service()) - - def test_start_with_termination_condition(self): - self.service.config.step = [1, 2, 3] - self.service.current_iter = 4 - self.service.start() - self.assertFalse(self.service.switch) - self.assertTrue(self.service.should_stop_service) - self.assertFalse(self.service.primitive_switch) - - @patch('msprobe.mindspore.service.print_tools_ends_info') - @patch.object(Service, 'need_end_service', return_value=True) - def test_start_with_end_service(self, mock_need_end_service, mock_print_tools_ends_info): - self.service.start(self.service.model) - mock_need_end_service.assert_called_once() - mock_print_tools_ends_info.assert_called_once() - self.assertFalse(self.service.switch) - self.assertTrue(self.service.should_stop_service) - - @patch.object(Service, 'need_end_service', return_value=False) - @patch.object(logger, 'info') - @patch.object(Service, 'register_cell_hook') - @patch.object(Service, 'register_primitive_hook') - @patch.object(Service, 'create_dirs') - @patch('msprobe.mindspore.service.get_rank_if_initialized', return_value=0) - def test_start_first_time(self, mock_get_rank, mock_create_dirs, mock_register_primitive_hook, - mock_register_cell_hook, mock_logger, mock_need_end_service): - self.service.first_start = True - self.service.should_stop_service = False - self.service.start(self.service.model) - mock_get_rank.assert_called_once() - mock_register_cell_hook.assert_called_once() - mock_register_primitive_hook.assert_called_once() - mock_need_end_service.assert_called_once() - mock_create_dirs.assert_called_once() - self.assertFalse(self.service.first_start) - self.assertTrue(self.service.switch) - self.assertTrue(self.service.primitive_switch) - mock_logger.assert_called_with(f"Dump data will be saved in {self.service.dump_iter_dir}.") - - @patch.object(Service, 'register_primitive_hook') - @patch.object(Service, 'register_cell_hook') - @patch.object(Service, 'need_end_service', return_value=False) - @patch.object(JitDump, 'set_config') - @patch.object(JitDump, 'set_data_collector') - @patch.object(ApiRegistry, 'api_set_hook_func') - def test_start_with_jit_dump_enabled(self, mock_api_set_hook_func, mock_set_data_collector, - mock_set_config, mock_need_end_service, mock_register_cell_hook, - mock_register_primitive_hook): - self.service.config.level = Const.LEVEL_MIX - self.service.first_start = True - self.service.should_stop_service = False - self.service.start(self.service.model) - mock_set_config.assert_called_with(self.service.config) - mock_set_data_collector.assert_called_with(self.service.data_collector) - mock_api_set_hook_func.assert_called_once() - mock_need_end_service.assert_called_once() - mock_register_cell_hook.assert_called_once() - mock_register_primitive_hook.assert_called_once() - self.assertTrue(JitDump.jit_dump_switch) - - def test_step_updates(self): - CellProcessor.cell_count = {"test_api": 1} - HOOKCell.cell_count = {"test_api": 1} - JitDump.jit_count = {"test_api": 1} - self.service.primitive_hook_service.primitive_counters = {"test_api": 1} - self.service.current_iter = 0 - self.service.step() - self.assertEqual(self.service.current_iter, 1) - self.service.data_collector.update_iter.assert_called_once_with(1) - self.service.data_collector.reset_status.assert_called_once() - self.assertEqual(JitDump.jit_count, defaultdict(int)) - self.assertEqual((self.service.primitive_hook_service.primitive_counters), {}) - - @patch.object(Service, 'should_execute_hook') - def test_build_forward_and_backward_hooks(self, mock_should_execute_hook): - mock_should_execute_hook.return_value = True - self.service.data_collector = MagicMock() - self.service.data_collector.update_api_or_module_name = MagicMock() - self.service.data_collector.forward_data_collect = MagicMock() - self.service.data_collector.if_return_forward_new_output = MagicMock(return_value=False) - self.service.data_collector.backward_data_collect = MagicMock() - - mock_cell = MagicMock() - mock_cell.mindstudio_reserved_name = "TestCell" - mock_input = (MagicMock(),) - mock_output = MagicMock() - - _, forward_hook, backward_hook, _ = self.service.build_hook(BaseScope.Module_Type_Module, "TestHook") - - forward_hook(mock_cell, mock_input, mock_output) - self.service.data_collector.update_api_or_module_name.assert_called_with('TestCell') - self.service.data_collector.forward_data_collect.assert_called() - - self.service.data_collector.reset_mock() - - mock_grad_input = (MagicMock(),) - mock_grad_output = MagicMock() - - backward_hook(mock_cell, mock_grad_input, mock_grad_output) - self.service.data_collector.update_api_or_module_name.assert_called_with('TestHookbackward.0') - self.service.data_collector.backward_data_collect.assert_called() - + def test_register_primitive_hook(self): self.service.config.level = Const.LEVEL_MIX primitive_attr = ops.Add() primitive_name = "primitive_api" + mock_model = MagicMock() cell_mock = MagicMock() cell_mock.primitive_api = primitive_attr primitive_combined_name = primitive_name + Const.SEP + primitive_attr.__class__.__name__ - self.service.model.cells_and_names.return_value = [("cell_name", cell_mock)] - self.service.register_primitive_hook() + self.service.model = mock_model + with patch('msprobe.mindspore.mindspore_service.get_cells_and_names_with_index') as mock_get_cells_and_names: + mock_get_cells_and_names.return_value = ({'-1': [("cell_name", cell_mock)]}, {}) + self.service._register_primitive_hook() self.assertTrue(hasattr(primitive_attr.__class__, '__call__')) self.assertEqual(self.service.primitive_hook_service.wrap_primitive.call_args[0][1], primitive_combined_name) - - @patch.object(ApiRegistry, 'initialize_hook') - @patch.object(ApiRegistry, 'api_set_hook_func') - @patch("msprobe.mindspore.service.logger.info") - def test_register_hook_new_with_level_mix(self, mock_logger, mock_api_set_hook_func, mock_initialize_hook): - self.service.config.level = Const.LEVEL_MIX - self.service.register_api_hook() - self.service.register_cell_hook() - mock_logger.assert_called_with(f"The cell {self.service.config.task} hook function " - "is successfully mounted to the model.") - mock_api_set_hook_func.assert_called() - mock_initialize_hook.assert_called() - - @patch.object(CellProcessor, 'node_hook') - def test_register_hook_new_with_level_l0(self, mock_node_hook): - global register_backward_hook_functions - self.service.config.level = Const.LEVEL_L0 - cell_mock = MagicMock() - self.service.model.cells_and_names.return_value = [("cell_name", cell_mock)] - register_backward_hook_functions["pre"] = cell_mock.register_backward_pre_hook - register_backward_hook_functions["full"] = cell_mock.register_backward_hook - self.service.register_cell_hook() - cell_mock.register_forward_hook.assert_called() - cell_mock.register_backward_hook.assert_called() - mock_node_hook.assert_called() - register_backward_hook_functions = {} - - def test_register_hook_new_without_model_raises_exception(self): - self.service.config.level = Const.LEVEL_L0 - self.service.model = None - with self.assertRaises(MsprobeException): - self.service.register_cell_hook() + + def test_reset_status(self): + self.service.primitive_hook_service.primitive_counters = defaultdict(int) + self.service.primitive_hook_service.primitive_counters['test_prim'] = 5 + self.service._reset_status() + self.assertEqual(self.service.primitive_hook_service.primitive_counters, {}) + with patch('msprobe.mindspore.mindspore_service.JitDump') as mock_jit_dump: + mock_jit_dump.jit_count = defaultdict(int) + mock_jit_dump.jit_count['test_jit'] = 3 + self.service._reset_status() + self.assertEqual(mock_jit_dump.jit_count, {}) + + @patch('msprobe.mindspore.mindspore_service.JitDump') + def test_start_jit_enabled(self, mock_jit_dump): + self.service.data_collector.data_processor.is_terminated = False + model_mock = MagicMock() + self.service.start(model=model_mock) + self.assertTrue(mock_jit_dump.jit_dump_switch) + + @patch('msprobe.mindspore.mindspore_service.JitDump') + def test_stop_jit_disabled(self, mock_jit_dump): + self.config.level = Const.LEVEL_MIX + self.service.current_iter = 1 + self.service.current_rank = 0 + + self.service.stop() + + self.assertFalse(mock_jit_dump.jit_dump_switch) + + @patch('msprobe.mindspore.mindspore_service.JitDump') + @patch('msprobe.mindspore.mindspore_service.ms.common.api') + def test_setup_jit_context_level_not_supported(self, mock_ms_api, mock_jit_dump): + self.service.config.level = Const.LEVEL_DEBUG + + self.service._setup_jit_context() + + mock_jit_dump.set_config.assert_not_called() + mock_jit_dump.set_data_collector.assert_not_called() diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py index f46f171aa38585ea801f1fd3a9716bd3876a63a5..520a688dcf475fd3e6831ab0b25c7dda9faeb31b 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py @@ -1,7 +1,6 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,19 +12,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" + from unittest import TestCase from unittest.mock import patch -from msprobe.mindspore.common.const import Const +from msprobe.core.common.log import logger from msprobe.core.common_config import CommonConfig, BaseConfig +from msprobe.mindspore.common.const import Const from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.overflow_check.overflow_check_tool_factory import OverflowCheckToolFactory class TestOverflowCheckToolFactory(TestCase): + @patch.object(logger, "error") @patch("msprobe.mindspore.debugger.debugger_config.create_directory") - def test_create(self, _): + def test_create(self, _, mock_logger_error): json_config = { "task": "overflow_check", "dump_path": "/absolute_path", @@ -45,12 +46,11 @@ class TestOverflowCheckToolFactory(TestCase): config.execution_mode = Const.GRAPH_GE_MODE config.level = "cell" - with self.assertRaises(Exception) as context: + with self.assertRaises(ValueError): OverflowCheckToolFactory.create(config) - self.assertEqual(str(context.exception), - f"Overflow check is not supported in {config.execution_mode} mode " - f"when level is {config.level}.") + mock_logger_error.assert_called_with(f"Overflow check is not supported in {config.execution_mode} mode " + f"when level is {config.level}.") config.level = "kernel" - dumper = OverflowCheckToolFactory.create(config) + dumper = OverflowCheckToolFactory.create(config)[0] self.assertEqual(dumper.dump_json["common_dump_settings"]["file_format"], "npy") diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py index 3cafd49f2c101c45dbb65a08803dd77c6bca485d..a69caed2569cd875224cf7b87ea16ced69ce3ae5 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py @@ -1,8 +1,7 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -13,95 +12,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" + +from collections import defaultdict +import tempfile import unittest -import mindspore as ms -import numpy as np -import os from unittest.mock import Mock, patch -from mindspore import nn +import numpy as np +import mindspore as ms +from mindspore import Tensor, ops -import tempfile from msprobe.core.common.utils import Const -from msprobe.mindspore.service import Service -from msprobe.core.common.exceptions import MsprobeException +from msprobe.mindspore.mindspore_service import MindsporeService +from msprobe.core.common.runtime import Runtime from msprobe.core.common_config import CommonConfig, BaseConfig from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell -from collections import defaultdict from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService -from mindspore.common.tensor import Tensor - - -class DummyModel(nn.Cell): - def __init__(self): - super(DummyModel, self).__init__() - self.dense = nn.Dense(2, 2) - - def construct(self, x): - return self.dense(x) - - -class TestService(unittest.TestCase): - @patch("msprobe.mindspore.debugger.debugger_config.create_directory") - def setUp(self, _): - json_config = { - "task": "statistics", - "dump_path": "/absolute_path", - "rank": [], - "step": [0, 2], - "level": "L1" - } - - common_config = CommonConfig(json_config) - task_config = BaseConfig(json_config) - config = DebuggerConfig(common_config, task_config) - self.service = Service(config) - self.service.model = Mock() - self.service.data_collector = Mock() - self.service.switch = True # Make sure the switch is on for testing - self.service.primitive_switch = True # Make sure the switch is on for testing - - def test_check_model_valid_none(self): - model = None - self.assertIsNone(self.service.check_model_valid(model)) - - def test_check_model_valid_valid_model(self): - model = DummyModel() - self.assertEqual(self.service.check_model_valid(model), model) - - def test_check_model_valid_invalid_model(self): - model = "invalid_model" - with self.assertRaises(MsprobeException) as context: - self.service.check_model_valid(model) - - def test_update_primitive_counters(self): - primitive_name = "test_primitive" - self.service.primitive_hook_service.update_primitive_counters(primitive_name) - self.assertEqual(self.service.primitive_hook_service.primitive_counters[primitive_name], 0) - self.service.primitive_hook_service.update_primitive_counters(primitive_name) - self.assertEqual(self.service.primitive_hook_service.primitive_counters[primitive_name], 1) - - def test_step_updates_iteration(self): - initial_iter = self.service.current_iter - self.service.step() - self.assertEqual(self.service.current_iter, initial_iter + 1) - - @patch.object(HOOKCell, 'cell_count', new_callable=lambda: defaultdict(int)) - def test_step_resets_counters(self, _): - # 假设在 step 调用之前已经有一些 primitive_counters - self.service.primitive_hook_service.primitive_counters["test_primitive"] = 5 - self.service.step() - self.assertEqual(self.service.primitive_hook_service.primitive_counters, {}) - self.assertEqual(HOOKCell.cell_count, defaultdict(int)) - - def test_step_calls_update_iter(self): - # 检查是否在调用 step 时调用了 update_iter - with patch.object(self.service.data_collector, 'update_iter') as mock_update_iter: - initial_iter = self.service.current_iter - self.service.step() - mock_update_iter.assert_called_once_with(initial_iter + 1) +from msprobe.mindspore.ms_config import StatisticsConfig class TestPrimitiveHookService(unittest.TestCase): @@ -118,21 +46,16 @@ class TestPrimitiveHookService(unittest.TestCase): } common_config = CommonConfig(json_config) - task_config = BaseConfig(json_config) + task_config = StatisticsConfig(json_config) config = DebuggerConfig(common_config, task_config) - self.service = Service(config) - self.service.model = Mock() - self.service.data_collector = Mock() - self.service.switch = True # Make sure the switch is on for testing - - # 模拟一个 service_instance 和 data_collector - self.mock_service_instance = Service(config) - self.mock_service_instance.switch = True - self.mock_service_instance.data_collector = Mock() - self.mock_service_instance.data_collector.dump_file_path = json_config["dump_path"] - # 初始化 PrimitiveHookService - self.primitive_hook_service = PrimitiveHookService(self.mock_service_instance) + with patch('msprobe.core.service.build_data_collector'), \ + patch('msprobe.mindspore.mindspore_service.CellProcessor'), \ + patch('msprobe.mindspore.mindspore_service.PrimitiveHookService'), \ + patch('msprobe.mindspore.mindspore_service.get_api_register'): + self.mock_service_instance = MindsporeService(config) + Runtime.is_running = True + self.primitive_hook_service = PrimitiveHookService(self.mock_service_instance) def tearDown(self): # 测试结束时删除临时目录 @@ -147,7 +70,6 @@ class TestPrimitiveHookService(unittest.TestCase): # 调用 wrap_primitive 获取包装函数通过闭包显式调用backward_hook hook_primitive_inputs = self.primitive_hook_service.wrap_primitive(None, "example").__closure__[0].cell_contents - wrapped_primitive_call = self.primitive_hook_service.wrap_primitive(None, "example") create_backward_hook = hook_primitive_inputs.__closure__[0].cell_contents @@ -162,7 +84,6 @@ class TestPrimitiveHookService(unittest.TestCase): backward_hook(grad_2) self.assertEqual(len(captured_grads), 6) # 捕获到两个梯度 - print(f"1After first backward_hook call, len(captured_grads): {len(captured_grads)}") # 调用到达阈值,验证数据收集 self.assertTrue(self.mock_service_instance.data_collector.backward_output_data_collect.called) @@ -176,7 +97,6 @@ class TestPrimitiveHookService(unittest.TestCase): # 调用 wrap_primitive 获取包装函数通过闭包显式调用backward_hook hook_primitive_inputs = self.primitive_hook_service.wrap_primitive(None, "example").__closure__[0].cell_contents - wrapped_primitive_call = self.primitive_hook_service.wrap_primitive(None, "example") create_backward_hook = hook_primitive_inputs.__closure__[0].cell_contents @@ -213,14 +133,7 @@ class TestPrimitiveHookService(unittest.TestCase): # 调用 wrap_primitive 获取包装函数通过闭包显式调用backward_hook hook_primitive_inputs = self.primitive_hook_service.wrap_primitive(None, "example").__closure__[0].cell_contents - wrapped_primitive_call = self.primitive_hook_service.wrap_primitive(None, "example") - if wrapped_primitive_call.__closure__: - for i, closure in enumerate(wrapped_primitive_call.__closure__): - print(f"Closure[{i}]:", closure.cell_contents) - - if hook_primitive_inputs.__closure__: - for i, closure in enumerate(hook_primitive_inputs.__closure__): - print(f"2Closure[{i}]:", closure.cell_contents) + create_backward_hook = hook_primitive_inputs.__closure__[0].cell_contents backward_hook = create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type) @@ -234,7 +147,6 @@ class TestPrimitiveHookService(unittest.TestCase): backward_hook(grad_2) self.assertEqual(len(captured_grads), 6) # 捕获到两个梯度 - print(f"After first backward_hook call, len(captured_grads): {len(captured_grads)}") # 调用到达阈值,验证数据收集 self.assertTrue(self.mock_service_instance.data_collector.backward_input_data_collect.called) @@ -281,18 +193,15 @@ class TestPrimitiveHookService(unittest.TestCase): updated_primitive_name = "test_primitive_input" # 调用 hook_primitive_inputs - hooked_inputs = self.primitive_hook_service.wrap_primitive(None, "example").__closure__[0].cell_contents(args, - captured_grads_input, - updated_primitive_name) - - # 验证 hooked_inputs 是否正确添加了 hook - for arg, hooked_arg in zip(args, hooked_inputs): - if isinstance(arg, Tensor): - print(f"Captured hooked_arg after hook: {hooked_arg}") - self.assertTrue(hasattr(hooked_arg, 'grad_fn')) - - # 打印调试信息 - print(f"Captured gradients after hook: {captured_grads_input}") + hook_primitive_inputs = self.primitive_hook_service.wrap_primitive(None, "example").__closure__[0].cell_contents + with patch.object(ops, 'HookBackward') as mock_HookBackward: + target_value = Tensor([1.0]) + mock_hbw = mock_HookBackward.return_value + mock_hbw.return_value = target_value + hooked_inputs = hook_primitive_inputs(args, captured_grads_input, updated_primitive_name) + self.assertEqual(mock_HookBackward.call_count, len(args)) + for hooked_input in hooked_inputs: + self.assertTrue((hooked_input == target_value).all()) def test_hook_primitive_outputs(self): # 模拟前向输出 @@ -301,17 +210,16 @@ class TestPrimitiveHookService(unittest.TestCase): updated_primitive_name = "test_primitive_output" # 调用 hook_primitive_outputs - hook_primitive_outputs = self.primitive_hook_service.wrap_primitive(None, "example").__closure__[ - 1].cell_contents - hooked_outputs = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name) - - # 验证 hooked_outputs 是否正确添加了 hook - for tensor, hooked_tensor in zip(out, hooked_outputs): - if isinstance(tensor, Tensor): - self.assertTrue(hasattr(hooked_tensor, 'grad_fn')) - - # 打印调试信息 - print(f"Captured gradients after output hook: {captured_grads_output}") + hook_primitive_outputs = self.primitive_hook_service.wrap_primitive(None, + "example").__closure__[1].cell_contents + with patch.object(ops, 'HookBackward') as mock_HookBackward: + target_value = Tensor([1.0]) + mock_hbw = mock_HookBackward.return_value + mock_hbw.return_value = target_value + hooked_outputs = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name) + self.assertEqual(mock_HookBackward.call_count, len(out)) + for hooked_output in hooked_outputs: + self.assertTrue((hooked_output == target_value).all()) def test_wrapped_primitive_call_args(self): # 模拟前向输入 @@ -324,19 +232,18 @@ class TestPrimitiveHookService(unittest.TestCase): # 调用 wrapped_primitive_call 并检查 hooked_inputs 是否与原始 args 相同 try: - hooked_inputs = wrapped_primitive_call.__closure__[0].cell_contents(args, captured_grads_input, - updated_primitive_name) - for arg, hooked_arg in zip(args, hooked_inputs): - if isinstance(arg, Tensor): - self.assertTrue(hasattr(hooked_arg, 'grad_fn')) - self.assertTrue(np.array_equal(arg.asnumpy(), hooked_arg.asnumpy())) - print(f"Arg type: {type(arg)}, Hooked input type: {type(hooked_arg)}") - else: - self.assertEqual(arg, hooked_arg) + with patch.object(ops, 'HookBackward') as mock_HookBackward: + target_value = Tensor([1.0]) + mock_hbw = mock_HookBackward.return_value + mock_hbw.return_value = target_value + hooked_inputs = wrapped_primitive_call.__closure__[0].cell_contents(args, captured_grads_input, + updated_primitive_name) + self.assertEqual(mock_HookBackward.call_count, len(args)) + for hooked_input in hooked_inputs: + self.assertTrue((hooked_input == target_value).all()) except Exception as e: self.fail(f"wrapped_primitive_call raised an exception: {e}") - def test_update_primitive_counters_multiple(self): # 测试更新 primitive 计数器的功能,增加多个不同名称的测试 primitive_names = ["MatMul", "Conv2D", "ReLU", "Softmax"] @@ -366,7 +273,7 @@ class TestPrimitiveHookService(unittest.TestCase): def test_wrap_primitive_no_hook_with_invalid_input(self): # 测试在 switch 关闭时传入无效输入时的行为 - self.mock_service_instance.switch = False + Runtime.is_running = False invalid_inputs = [None, "invalid_tensor", 123] @@ -415,13 +322,11 @@ class TestPrimitiveHookService(unittest.TestCase): for captured_grads in captured_grads_sets: updated_primitive_name = "MatMul.Backward" - num_tensors = len(captured_grads) hook = self.primitive_hook_service.wrap_primitive(Mock(), "MatMul") backward_hook = hook(Mock(), captured_grads, updated_primitive_name, Const.INPUT) self.assertIsNotNone(backward_hook) - @patch('msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward') def test_wrap_primitive_forward_and_backward_hooks(self, mock_hook_backward): # 模拟前向和后向钩子在同一个 primitive 中的行为 @@ -446,9 +351,6 @@ class TestPrimitiveHookService(unittest.TestCase): self.primitive_hook_service.update_primitive_counters(name) self.assertEqual(self.primitive_hook_service.primitive_counters[name], i) - - - def test_update_primitive_counters(self): primitive_name = "MatMul" self.primitive_hook_service.update_primitive_counters(primitive_name) @@ -495,7 +397,7 @@ class TestPrimitiveHookService(unittest.TestCase): wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul") # 模拟反向传播过程,调用包装的 primitive - with patch.object(self.mock_service_instance.data_collector, 'backward_data_collect') as mock_backward_collect: + with patch.object(self.mock_service_instance.data_collector, 'backward_data_collect'): result = wrapped_func(Mock(), input_tensor) # 验证结果是 Tensor 实例 @@ -503,7 +405,7 @@ class TestPrimitiveHookService(unittest.TestCase): def test_wrap_primitive_no_hook_when_switch_off(self): # 模拟 switch 关闭的情况 - self.mock_service_instance.switch = False + Runtime.is_running = False # 模拟 Tensor 输入 input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32)) @@ -543,7 +445,6 @@ class TestPrimitiveHookService(unittest.TestCase): # 测试 create_backward_hook 的功能 captured_grads = [] updated_primitive_name = "MatMul.Backward" - num_tensors = 2 # 创建 backward hook backward_hook = self.primitive_hook_service.wrap_primitive(Mock(), "MatMul") diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_task_handler_factory.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_task_handler_factory.py index 752b5f916d50083dba842707feeeda2edcbe14f0..bce627650093f0bd19e3d6d831bd9e5b065dd286 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_task_handler_factory.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_task_handler_factory.py @@ -20,6 +20,7 @@ from unittest.mock import patch from msprobe.core.common_config import CommonConfig, BaseConfig from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump +from msprobe.mindspore.dump.kernel_kbyk_dump import KernelKbykDump from msprobe.mindspore.task_handler_factory import TaskHandlerFactory from msprobe.mindspore.common.const import Const @@ -47,7 +48,9 @@ class TestTaskHandlerFactory(TestCase): config.execution_mode = Const.GRAPH_GE_MODE handler = TaskHandlerFactory.create(config) - self.assertTrue(isinstance(handler, KernelGraphDump)) + self.assertTrue(isinstance(handler, tuple)) + self.assertTrue(isinstance(handler[1], KernelKbykDump)) + self.assertTrue(isinstance(handler[0], KernelGraphDump)) with patch("msprobe.mindspore.task_handler_factory.TaskHandlerFactory.tasks", new=tasks): with self.assertRaises(Exception) as context: diff --git a/debug/accuracy_tools/msprobe/test/nan_analyze_ut/test_nan_analyzer.py b/debug/accuracy_tools/msprobe/test/nan_analyze_ut/test_nan_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..833844e5beb3c0b762bf34369ea977b5ce79dee8 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/nan_analyze_ut/test_nan_analyzer.py @@ -0,0 +1,358 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import os.path +import unittest +from unittest.mock import patch +import argparse + +from msprobe.nan_analyze.analyzer import _nan_analyze_parser, NanAnalyzer + + +class DumpDataBuilder: + def __init__(self): + self.nodes = {} + self.layer = {} + + @staticmethod + def gen_data(is_normal, **kwargs): + def gen_single_data(normal): + return { + 'type': 'torch.Tensor', + 'dtype': 'torch.float32', + 'shape': [ + 2, + 1024 + ], + 'Max': 2.0 if normal else 'inf', + 'Min': 1.0 if normal else '-inf', + 'Mean': 1.5 if normal else 'nan', + 'Norm': 2.236 if normal else 'nan', + 'requires_grad': False + } + + def gen_int(value): + return { + 'type': 'int', + 'value': value + } + + def gen_process_group(ranks): + return { + 'type': 'ProcessGroup', + 'group_ranks': ranks + } + + data_type = kwargs.get('type') + if data_type == 'compute': + return { + 'input_args': [gen_single_data(True)], + 'input_kwargs': {}, + 'output': [gen_single_data(is_normal)] + } + if data_type == 'p2p_src': + return { + 'input_args': [gen_single_data(kwargs.get('is_input_normal'))], + 'input_kwargs': {'dst': gen_int(kwargs.get('dst'))}, + 'output': [gen_single_data(kwargs.get('is_output_normal'))] + } + if data_type == 'p2p_dst': + return { + 'input_args': [gen_single_data(kwargs.get('is_input_normal'))], + 'input_kwargs': {'src': gen_int(kwargs.get('src'))}, + 'output': [gen_single_data(kwargs.get('is_output_normal'))] + } + if data_type == 'p2g_src': + return { + 'input_args': [gen_single_data(kwargs.get('is_input_normal')), gen_int(kwargs.get('src'))], + 'input_kwargs': {'group': gen_process_group(kwargs.get('ranks'))}, + 'output': [gen_single_data(kwargs.get('is_output_normal'))] + } + if data_type == 'p2g_dst': + return { + 'input_args': [gen_single_data(kwargs.get('is_input_normal')), gen_int(kwargs.get('dst'))], + 'input_kwargs': {'group': gen_process_group(kwargs.get('ranks'))}, + 'output': [gen_single_data(kwargs.get('is_output_normal'))] + } + if data_type == 'link': + return { + 'input_args': [gen_single_data(kwargs.get('is_input_normal'))], + 'input_kwargs': {}, + 'output': [gen_single_data(kwargs.get('is_output_normal'))] + } + + def add_node(self, is_normal, **kwargs): + name = kwargs.get("name", 'operator') + layer = self.layer.get(name, 0) + if kwargs.get('type') == 'compute': + node_name = f'Torch.operator.{layer}.forward' + else: + node_name = f'Distributed.{name}.{layer}.forward' + self.nodes[node_name] = self.gen_data(is_normal, **kwargs) + self.layer[name] = layer + 1 + return self + + def build(self): + return self.nodes + + +rank_order_dict = { + # (name, type, src, dst, ranks) + 0: [(0, 'compute', 0, 0, []), + (0, 'compute', 0, 0, []), + (0, 'compute', 0, 0, []), + (0, 'compute', 0, 0, []), + (0, 'compute', 0, 0, []), + ('send', 'p2p_src', 0, 1, []), + ('recv', 'p2p_dst', 1, 0, []), + (0, 'compute', 0, 0, []), + ('broadcast', 'p2g_src', 0, 0, [0, 1, 2, 3]), + (0, 'compute', 0, 0, []), + ('all_gather', 'link', 0, 0, []), + (0, 'compute', 0, 0, []), + ('gather', 'p2g_dst', 0, 0, [0, 1, 2, 3]), + (0, 'compute', 0, 0, [])], + 1: [('recv', 'p2p_dst', 0, 0, []), + (0, 'compute', 0, 0, []), + (0, 'compute', 0, 0, []), + ('send', 'p2p_src', 0, 2, []), + ('recv', 'p2p_dst', 2, 0, []), + (0, 'compute', 0, 0, []), + (0, 'compute', 0, 0, []), + ('send', 'p2p_src', 0, 0, []), + ('broadcast', 'p2g_src', 0, 0, [0, 1, 2, 3]), + (0, 'compute', 0, 0, []), + ('all_gather', 'link', 0, 0, []), + (0, 'compute', 0, 0, []), + ('gather', 'p2g_dst', 0, 0, [0, 1, 2, 3]), + (0, 'compute', 0, 0, [])], + 2: [('recv', 'p2p_dst', 1, 0, []), + (0, 'compute', 0, 0, []), + (0, 'compute', 0, 0, []), + ('send', 'p2p_src', 0, 3, []), + ('recv', 'p2p_dst', 3, 0, []), + (0, 'compute', 0, 0, []), + (0, 'compute', 0, 0, []), + ('send', 'p2p_src', 0, 1, []), + ('broadcast', 'p2g_src', 0, 0, [0, 1, 2, 3]), + (0, 'compute', 0, 0, []), + ('all_gather', 'link', 0, 0, []), + (0, 'compute', 0, 0, []), + ('gather', 'p2g_dst', 0, 0, [0, 1, 2, 3]), + (0, 'compute', 0, 0, [])], + 3: [('recv', 'p2p_dst', 2, 0, []), + (0, 'compute', 0, 0, []), + (0, 'compute', 0, 0, []), + (0, 'compute', 0, 0, []), + ('send', 'p2p_src', 0, 2, []), + ('broadcast', 'p2g_src', 0, 0, [0, 1, 2, 3]), + (0, 'compute', 0, 0, []), + ('all_gather', 'link', 0, 0, []), + (0, 'compute', 0, 0, []), + ('gather', 'p2g_dst', 0, 0, [0, 1, 2, 3]), + (0, 'compute', 0, 0, []), + (0, 'compute', 0, 0, []), + (0, 'compute', 0, 0, []), + (0, 'compute', 0, 0, [])] +} + + +def do_nothing(*args, **kwargs): + return + + +def gen_normal_dump_json(rank): + builder = DumpDataBuilder() + for name, data_type, src, dst, ranks in rank_order_dict[rank]: + builder = builder.add_node(True, name=name, type=data_type, src=src, dst=dst, ranks=ranks, + is_input_normal=True, is_output_normal=True) + return {'task': 'statistics', + 'level': 'mix', + 'dump_data_dir': None, + 'data': builder.build() + } + + +def gen_pre_anomaly_dump_json(rank): + builder = DumpDataBuilder() + for i, (name, data_type, src, dst, ranks) in enumerate(rank_order_dict[rank]): + is_normal = True + if i == rank and i in [0, 1]: + is_normal = False + builder = builder.add_node(is_normal, name=name, type=data_type, src=src, dst=dst, ranks=ranks, + is_input_normal=True, is_output_normal=True) + return {'task': 'statistics', + 'level': 'mix', + 'dump_data_dir': None, + 'data': builder.build() + } + + +def gen_anomaly_dump_json(rank): + builder = DumpDataBuilder() + start = 999 + for i, (name, data_type, src, dst, ranks) in enumerate(rank_order_dict[rank]): + is_normal = True + is_input_normal = True + is_output_normal = True + if rank == 0: + if i == 7: + is_normal = False + elif i == 8: + is_input_normal = False + is_output_normal = False + else: + if name == 'broadcast': + start = i + is_output_normal = False + elif i > start: + is_normal = False + is_input_normal = False + is_output_normal = False + builder = builder.add_node(is_normal, name=name, type=data_type, src=src, dst=dst, ranks=ranks, + is_input_normal=is_input_normal, is_output_normal=is_output_normal) + return {'task': 'statistics', + 'level': 'mix', + 'dump_data_dir': None, + 'data': builder.build() + } + + +def gen_after_anomaly_dump_json(rank): + builder = DumpDataBuilder() + for i, (name, data_type, src, dst, ranks) in enumerate(rank_order_dict[rank]): + is_normal = (rank != 2 or i != 13) and (rank != 3 or i != 11) + builder = builder.add_node(is_normal, name=name, type=data_type, src=src, dst=dst, ranks=ranks, + is_input_normal=True, is_output_normal=True) + return {'task': 'statistics', + 'level': 'mix', + 'dump_data_dir': None, + 'data': builder.build() + } + + +json_dict = {os.path.join('./step0', f'rank{i if i > 0 else ""}', 'construct.json'): {} for i in range(4)} + + +def gen_stack_json(rank): + return {f'0': [list(json_dict[os.path.join('./step0', f'rank{rank if rank > 0 else ""}', 'dump.json')]['data'].keys()), + ['File /root/example.py, line 10, in test_fcn, \\n test(tensor)']]} + + +class mock_time: + _uni_value = 1 + + @staticmethod + def set_uni_value(var): + mock_time._uni_value = var + + @staticmethod + def time_ns(): + return mock_time._uni_value + + + +class MockedFileCache: + def load_json(self, file_path): + return json_dict[file_path] + + +class TestAnalyzer(unittest.TestCase): + def setUp(self): + self.output = {} + self.input_path = './step0' + self.output_path = './output' + with patch('os.listdir', return_value=['rank', 'rank1', 'rank2', 'rank3', 'rank_others']), \ + patch('msprobe.nan_analyze.utils.check_file_or_directory_path', do_nothing), \ + patch('msprobe.nan_analyze.analyzer.FileCache', MockedFileCache): + self.analyzer = NanAnalyzer(self.input_path, self.output_path) + + def mocked_save_json(self, file, content, indent): + self.output[file] = content + + def test_nan_analyze_parser(self): + args = [ + '-i', '/path/to/input', + '-o', '/path/to/output', + ] + + parser = argparse.ArgumentParser() + _nan_analyze_parser(parser) + parsed_args = parser.parse_args(args) + self.assertEqual(parsed_args.input_path, '/path/to/input') + self.assertEqual(parsed_args.output_path, '/path/to/output') + + def test_normal(self): + json_dict.update({os.path.join('./step0', f'rank{i if i > 0 else ""}', 'dump.json'): gen_normal_dump_json(i) for i in range(4)}) + json_dict.update({os.path.join('./step0', f'rank{i if i > 0 else ""}', 'stack.json'): gen_stack_json(i) for i in range(4)}) + with patch('os.path.exists', return_value=True), \ + patch('msprobe.nan_analyze.analyzer.logger.info', print), \ + patch('msprobe.nan_analyze.analyzer.logger.warning', print), \ + patch('msprobe.nan_analyze.analyzer.FileCache', MockedFileCache): + self.analyzer.analyze() + self.assertFalse(bool(self.output)) + + def test_pre_anomaly(self): + json_dict.update({os.path.join('./step0', f'rank{i if i > 0 else ""}', 'dump.json'): gen_pre_anomaly_dump_json(i) for i in range(4)}) + json_dict.update({os.path.join('./step0', f'rank{i if i > 0 else ""}', 'stack.json'): gen_stack_json(i) for i in range(4)}) + with patch('os.path.exists', return_value=True), \ + patch('msprobe.nan_analyze.analyzer.save_json', self.mocked_save_json), \ + patch('msprobe.nan_analyze.analyzer.logger.info', print), \ + patch('msprobe.nan_analyze.analyzer.logger.warning', print), \ + patch('msprobe.nan_analyze.analyzer.FileCache', MockedFileCache), \ + patch('msprobe.nan_analyze.graph.FileCache', MockedFileCache), \ + patch('msprobe.nan_analyze.analyzer.time', mock_time): + mock_time.set_uni_value(1) + self.analyzer.analyze() + res_json = self.output.get(os.path.join('./output', 'anomaly_analyze_1.json')) + self.assertTrue(bool(res_json)) + self.assertEqual('Torch.operator.0.forward', res_json['rank_0'][0]['op_name']) + + def test_anomaly(self): + json_dict.update({os.path.join('./step0', f'rank{i if i > 0 else ""}', 'dump.json'): gen_anomaly_dump_json(i) for i in range(4)}) + json_dict.update({os.path.join('./step0', f'rank{i if i > 0 else ""}', 'stack.json'): gen_stack_json(i) for i in range(4)}) + with patch('os.path.exists', return_value=True), \ + patch('msprobe.nan_analyze.analyzer.save_json', self.mocked_save_json), \ + patch('msprobe.nan_analyze.analyzer.logger.info', print), \ + patch('msprobe.nan_analyze.analyzer.logger.warning', print), \ + patch('msprobe.nan_analyze.analyzer.FileCache', MockedFileCache), \ + patch('msprobe.nan_analyze.graph.FileCache', MockedFileCache), \ + patch('msprobe.nan_analyze.analyzer.time', mock_time): + mock_time.set_uni_value(2) + self.analyzer.analyze() + res_json = self.output.get(os.path.join('./output', 'anomaly_analyze_2.json')) + self.assertTrue(bool(res_json)) + self.assertEqual('Torch.operator.5.forward', res_json['rank_0'][0]['op_name']) + + def test_after_anomaly(self): + json_dict.update({os.path.join('./step0', f'rank{i if i > 0 else ""}', 'dump.json'): gen_after_anomaly_dump_json(i) for i in range(4)}) + json_dict.update({os.path.join('./step0', f'rank{i if i > 0 else ""}', 'stack.json'): gen_stack_json(i) for i in range(4)}) + with patch('os.path.exists', return_value=True), \ + patch('msprobe.nan_analyze.analyzer.save_json', self.mocked_save_json), \ + patch('msprobe.nan_analyze.analyzer.logger.info', print), \ + patch('msprobe.nan_analyze.analyzer.logger.warning', print), \ + patch('msprobe.nan_analyze.analyzer.FileCache', MockedFileCache), \ + patch('msprobe.nan_analyze.graph.FileCache', MockedFileCache), \ + patch('msprobe.nan_analyze.analyzer.time', mock_time): + mock_time.set_uni_value(3) + self.analyzer.analyze() + res_json = self.output.get(os.path.join('./output', 'anomaly_analyze_3.json')) + self.assertTrue(bool(res_json)) + self.assertEqual(res_json['rank_2'][0]['op_name'], 'Torch.operator.6.forward') + self.assertEqual(res_json['rank_3'][0]['op_name'], 'Torch.operator.6.forward') + + diff --git a/debug/accuracy_tools/msprobe/test/nan_analyze_ut/test_nan_analyzer_graph.py b/debug/accuracy_tools/msprobe/test/nan_analyze_ut/test_nan_analyzer_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..9bd3777ab6eb51de763a7c42d7cea98a0a4139d0 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/nan_analyze_ut/test_nan_analyzer_graph.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import unittest +import os +from unittest.mock import patch + +from msprobe.nan_analyze.graph import CommunicationNode, DataNode +from msprobe.nan_analyze.utils import RankPath + +from msprobe.core.common.exceptions import MsprobeException +from test_nan_analyzer import DumpDataBuilder, gen_normal_dump_json, MockedFileCache, json_dict, gen_stack_json, do_nothing + + +dump_json = {i: gen_normal_dump_json(i) for i in range(4)} + + +class TestCommunicationNode(unittest.TestCase): + def test_add_next(self): + op_name_0 = 'Distributed.send.0.forward' + op_name_1 = 'Distributed.recv.0.forward' + comm_node_0 = CommunicationNode(f'0.{op_name_0}', 0, DataNode(op_name_0, 0, dump_json[0]['data'][op_name_0])) + comm_node_1 = CommunicationNode(f'0.{op_name_1}', 0, DataNode(op_name_1, 0, dump_json[0]['data'][op_name_1])) + comm_node_0.add_next(comm_node_1) + self.assertEqual(comm_node_0.layer + 1, comm_node_1.layer) + self.assertTrue(comm_node_0 is comm_node_1.pre_node) + self.assertTrue(comm_node_1.node_id in comm_node_0.next_nodes) + + def test_add_link(self): + op_name = 'Distributed.all_gather.0.forward' + comm_node_0 = CommunicationNode(f'0.{op_name}', 0, DataNode(op_name, 0, dump_json[0]['data'][op_name])) + comm_node_1 = CommunicationNode(f'1.{op_name}', 1, DataNode(op_name, 1, dump_json[1]['data'][op_name])) + comm_node_0.add_link(comm_node_1) + self.assertEqual(comm_node_0.layer, comm_node_1.layer) + self.assertTrue(comm_node_0.node_id in comm_node_1.link_nodes) + self.assertTrue(comm_node_1.node_id in comm_node_0.link_nodes) + + def test_add_dst(self): + op_name = 'Distributed.broadcast.0.forward' + comm_node_0 = CommunicationNode(f'0.{op_name}', 0, DataNode(op_name, 0, dump_json[0]['data'][op_name])) + comm_node_1 = CommunicationNode(f'2.{op_name}', 2, DataNode(op_name, 2, dump_json[2]['data'][op_name])) + comm_node_0.add_dst(comm_node_1) + self.assertEqual(comm_node_0.layer, comm_node_1.layer) + self.assertTrue(comm_node_0.node_id in comm_node_1.src_nodes) + self.assertTrue(comm_node_1.node_id in comm_node_0.dst_nodes) + + def test_delete(self): + op_name = 'Distributed.broadcast.0.forward' + comm_node_0 = CommunicationNode(f'0.{op_name}', 0, DataNode(op_name, 0, dump_json[0]['data'][op_name])) + comm_node_1 = CommunicationNode(f'2.{op_name}', 2, DataNode(op_name, 2, dump_json[2]['data'][op_name])) + op_name = 'Distributed.recv.0.forward' + comm_node_2 = CommunicationNode(f'0.{op_name}', 0, DataNode(op_name, 0, dump_json[0]['data'][op_name])) + comm_node_2.add_next(comm_node_0) + comm_node_0.add_dst(comm_node_1) + comm_node_0.delete() + self.assertFalse(comm_node_1.src_nodes) + self.assertFalse(comm_node_2.next_nodes) + + def test_has_nan_inf(self): + op_name = 'Distributed.broadcast.0.forward' + comm_node_0 = CommunicationNode(f'0.{op_name}', 0, DataNode(op_name, 0, dump_json[0]['data'][op_name])) + self.assertFalse(comm_node_0.has_nan_inf()) + + def test_input_has_nan_inf(self): + op_name = 'Distributed.broadcast.0.forward' + comm_node_0 = CommunicationNode(f'0.{op_name}', 0, DataNode(op_name, 0, dump_json[0]['data'][op_name])) + self.assertFalse(comm_node_0.input_has_nan_inf()) + + def test_find_connected_nodes(self): + op_name = 'Distributed.broadcast.0.forward' + comm_node_0 = CommunicationNode(f'0.{op_name}', 0, DataNode(op_name, 0, dump_json[0]['data'][op_name])) + comm_node_1 = CommunicationNode(f'1.{op_name}', 1, DataNode(op_name, 1, dump_json[1]['data'][op_name])) + comm_node_2 = CommunicationNode(f'2.{op_name}', 2, DataNode(op_name, 2, dump_json[2]['data'][op_name])) + comm_node_3 = CommunicationNode(f'3.{op_name}', 3, DataNode(op_name, 3, dump_json[3]['data'][op_name])) + comm_node_0.add_dst(comm_node_1) + comm_node_0.add_dst(comm_node_2) + comm_node_0.add_dst(comm_node_3) + conn_info = comm_node_0.find_connected_nodes() + self.assertEqual(conn_info['ranks'], {0, 1, 2, 3}) + self.assertEqual(conn_info['api'], 'Distributed.broadcast') + self.assertEqual(conn_info['type'], 'dst') + + def test_resolve_type(self): + op_name = 'Distributed.broadcast.0.forward' + comm_node_0 = CommunicationNode(f'0.{op_name}', 0, DataNode(op_name, 0, dump_json[0]['data'][op_name])) + comm_node_1 = CommunicationNode(f'1.{op_name}', 1, DataNode(op_name, 1, dump_json[1]['data'][op_name])) + self.assertEqual(comm_node_0.type, 'src') + self.assertEqual(comm_node_1.type, 'dst') + + op_name = 'Distributed.all_gather.0.forward' + comm_node_2 = CommunicationNode(f'0.{op_name}', 0, DataNode(op_name, 0, dump_json[0]['data'][op_name])) + self.assertEqual(comm_node_2.type, 'link') + + +class TestDataNode(unittest.TestCase): + def setUp(self): + json_dict.update({os.path.join('./step0', f'rank{i if i > 0 else ""}', 'dump.json'): gen_normal_dump_json(i) for i in range(4)}) + json_dict.update({os.path.join('./step0', f'rank{i if i > 0 else ""}', 'stack.json'): gen_stack_json(i) for i in range(4)}) + json_dict[os.path.join('./step0', 'rank', 'construct.json')] = { + 'Torch.operator.1.forward': 'Module.module.test_model.forward.0', + 'Module.module.test_model.forward.0': 'Module.module.parent_model.forward.0', + 'Module.module.parent_model.forward.0': 'Module.module.root_model.forward.0', + 'Module.module.root_model.forward.0': None + } + + def test_find_stack(self): + with patch('msprobe.nan_analyze.graph.FileCache', MockedFileCache): + op_name = 'Torch.operator.1.forward' + data_node = DataNode(op_name, 0, dump_json[0]['data'][op_name]) + stack_info = data_node.find_stack(json_dict[os.path.join('./step0', 'rank', 'stack.json')]) + self.assertEqual(stack_info[0], 'File /root/example.py, line 10, in test_fcn, \\n test(tensor)') + with self.assertRaises(MsprobeException) as context: + data_node.find_stack({op_name: 'blabla'}) + self.assertEqual(context.exception.code, 4) + + def test_find_complete_construct(self): + with patch('msprobe.nan_analyze.graph.FileCache', MockedFileCache): + op_name = 'Torch.operator.1.forward' + construct = DataNode.find_complete_construct(json_dict[os.path.join('./step0', 'rank', 'construct.json')], + op_name) + self.assertEqual(len(construct), 4) + self.assertEqual(construct[0], 'Module.module.root_model.forward.0') + + def test_is_anomaly(self): + data_node_0 = DataNode('Torch.operator.1.forward', 0, DumpDataBuilder.gen_data(False, type='compute')) + data_node_1 = DataNode('Torch.operator.1.forward', 0, DumpDataBuilder.gen_data(True, type='compute')) + self.assertTrue(data_node_0.is_anomaly()) + self.assertFalse(data_node_1.is_anomaly()) + + def test_gen_node_info(self): + with patch('msprobe.nan_analyze.graph.FileCache', MockedFileCache), \ + patch('msprobe.nan_analyze.utils.check_file_or_directory_path', do_nothing): + op_name = 'Torch.operator.1.forward' + data_node = DataNode(op_name, 0, dump_json[0]['data'][op_name]) + node_info = data_node.gen_node_info(RankPath(0, os.path.join('./step0', 'rank', 'dump.json'), + os.path.join('./step0', 'rank', 'construct.json'), + os.path.join('./step0', 'rank', 'stack.json'))) + data_info = node_info['data_info'] + self.assertEqual(data_info['input_args'][0]['Max'], 2.0) + stack_info = node_info['stack_info'] + self.assertEqual(stack_info[0], 'File /root/example.py, line 10, in test_fcn, \\n test(tensor)') \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/nan_analyze_ut/test_nan_analyzer_utils.py b/debug/accuracy_tools/msprobe/test/nan_analyze_ut/test_nan_analyzer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a1f37e2e38b0f68045826c062d7bb7bcda261fa5 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/nan_analyze_ut/test_nan_analyzer_utils.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import unittest +from unittest.mock import patch + +from msprobe.nan_analyze.utils import (FileCache, is_communication_op, is_ignore_op, check_item_anomaly, + analyze_anomaly_in_group) +from msprobe.nan_analyze.graph import CommunicationNode, DataNode +from test_nan_analyzer import DumpDataBuilder + + +json_dict = {chr(no): {f'test_{chr(no)}_{i}': [f'content_{j}' for j in range(10)] for i in range(10)} for no in range(ord('a'), ord('z') + 1)} + + +def mocked_load_json(json_path): + return json_dict.get(json_path) + + +class MockedMemory: + def __init__(self): + self.available = 100000 + + +def mocked_virtual_memory(): + return MockedMemory() + + +class TestFileCache(unittest.TestCase): + def test_load_json(self): + with patch('msprobe.nan_analyze.utils.load_json', mocked_load_json), \ + patch('psutil.virtual_memory', mocked_virtual_memory): + cache = FileCache() + self.assertFalse('a' in cache._cache) + a = cache.load_json('a') + self.assertTrue('a' in cache._cache) + self.assertTrue('test_a_5' in a) + + def test_clean_up(self): + with patch('msprobe.nan_analyze.utils.load_json', mocked_load_json), \ + patch('psutil.virtual_memory', mocked_virtual_memory): + cache = FileCache() + for _ in range(100): + cache.load_json('a') + for i, no in enumerate(range(ord('a'), ord('g'))): + cache.load_json(chr(no)) + self.assertEqual('b' in cache._cache, 0 < i < 3) + self.assertTrue('a' in cache._cache) + +class TestUtils(unittest.TestCase): + def test_is_communication_op(self): + self.assertTrue(is_communication_op('Distributed.broadcast.0.forward')) + self.assertFalse(is_communication_op('Torch.operator.1.forward')) + + def test_is_ignore_op(self): + self.assertTrue(is_ignore_op('Torch.empty.1.forward')) + self.assertFalse(is_ignore_op('Torch.operator.1.forward')) + + def test_check_item_anomaly(self): + self.assertTrue(check_item_anomaly(DumpDataBuilder.gen_data(False, type='compute')['output'])) + self.assertFalse(check_item_anomaly(DumpDataBuilder.gen_data(True, type='compute')['output'])) + + def test_analyze_anomaly_in_group(self): + name = 'broadcast' + data_type = 'p2g_src' + src = 0 + dst = 0 + ranks = [0, 1, 2, 3] + op_name = f'Distributed.{name}.0.forward' + data = DumpDataBuilder.gen_data(False, name=name, type=data_type, src=src, dst=dst, ranks=ranks, + is_input_normal=True, is_output_normal=False) + node_id = f'0.{op_name}' + node = CommunicationNode(node_id, 0, DataNode(op_name, 0, data)) + anomalies = analyze_anomaly_in_group([node]) + self.assertEqual(anomalies[0].op_name, op_name) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py index a5d83ba3830c558d68884862b9870342c61701fa..8d3a922421e30a09a07057b94fc81e4d8cfe2e9c 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py @@ -29,12 +29,17 @@ class TestUtils(unittest.TestCase): self.processor.save_tensors_in_element(api_name, tensor) file_path = os.path.join(self.save_path, f'{api_name}.0.pt') self.assertTrue(os.path.exists(file_path)) - + + @patch.object(Const, "MAX_DEPTH", 50) def test_recursion_limit_error(self): tensor = torch.randn(10, 10) with self.assertRaises(DumpException) as context: - self.processor._save_recursive("test_api", [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, - [tensor, [tensor, [tensor, [tensor]]]]]]]]]]], 0) + self.processor._save_recursive("test_api", [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, + [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, + [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, + [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, + [tensor, [tensor, [tensor, [tensor + ]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]], 0) self.assertTrue(isinstance(context.exception, DumpException)) self.assertEqual(context.exception.code, DumpException.RECURSION_LIMIT_ERROR) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py index df03485dc6c77371750fd0b67ca2c37ff7e2ed7b..7c82585324effcac9a08dd1d2d5827c894311775 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py @@ -2,7 +2,7 @@ import unittest import os from unittest.mock import patch -from msprobe.pytorch.api_accuracy_checker.common.config import Config, CheckerConfig, OnlineConfig, msCheckerConfig +from msprobe.pytorch.api_accuracy_checker.common.config import Config, CheckerConfig, msCheckerConfig class TestUtConfig(): @@ -10,12 +10,6 @@ class TestUtConfig(): self.white_list = ['api1', 'api2'] self.black_list = ['api3'] self.error_data_path = '/path/to/error_data' - self.is_online = True - self.nfs_path = '/path/to/nfs' - self.host = 'localhost' - self.port = 8080 - self.rank_list = [0, 1, 2] - self.tls_path = '/path/to/tls' class TestConfig(unittest.TestCase): @@ -60,46 +54,19 @@ class TestConfig(unittest.TestCase): self.assertEqual(checker_config.white_list, msCheckerConfig.white_list) self.assertEqual(checker_config.black_list, msCheckerConfig.black_list) self.assertEqual(checker_config.error_data_path, msCheckerConfig.error_data_path) - self.assertEqual(checker_config.is_online, msCheckerConfig.is_online) - self.assertEqual(checker_config.nfs_path, msCheckerConfig.nfs_path) - self.assertEqual(checker_config.host, msCheckerConfig.host) - self.assertEqual(checker_config.port, msCheckerConfig.port) - self.assertEqual(checker_config.rank_list, msCheckerConfig.rank_list) - self.assertEqual(checker_config.tls_path, msCheckerConfig.tls_path) + def test_init_with_task_config(self): checker_config = CheckerConfig(self.task_config) self.assertEqual(checker_config.white_list, self.task_config.white_list) self.assertEqual(checker_config.black_list, self.task_config.black_list) self.assertEqual(checker_config.error_data_path, self.task_config.error_data_path) - self.assertEqual(checker_config.is_online, self.task_config.is_online) - self.assertEqual(checker_config.nfs_path, self.task_config.nfs_path) - self.assertEqual(checker_config.host, self.task_config.host) - self.assertEqual(checker_config.port, self.task_config.port) - self.assertEqual(checker_config.rank_list, self.task_config.rank_list) - self.assertEqual(checker_config.tls_path, self.task_config.tls_path) + def test_load_config(self): checker_config = CheckerConfig() checker_config.load_config(self.task_config) - self.assertEqual(checker_config.is_online, self.task_config.is_online) - self.assertEqual(checker_config.nfs_path, self.task_config.nfs_path) - self.assertEqual(checker_config.host, self.task_config.host) - self.assertEqual(checker_config.port, self.task_config.port) - self.assertEqual(checker_config.rank_list, self.task_config.rank_list) - self.assertEqual(checker_config.tls_path, self.task_config.tls_path) - - def test_get_online_config(self): - checker_config = CheckerConfig() - checker_config.load_config(self.task_config) - online_config = checker_config.get_online_config() - self.assertIsInstance(online_config, OnlineConfig) - self.assertEqual(online_config.is_online, self.task_config.is_online) - self.assertEqual(online_config.nfs_path, self.task_config.nfs_path) - self.assertEqual(online_config.host, self.task_config.host) - self.assertEqual(online_config.port, self.task_config.port) - self.assertEqual(online_config.rank_list, self.task_config.rank_list) - self.assertEqual(online_config.tls_path, self.task_config.tls_path) + def test_get_run_ut_config(self): forward_content = {'api1': 'data1', 'api2': 'data2'} diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py index 15a7908ad8de6d4883e0574ceaf451a03dbfbfe3..9e14e035ab6550250c3b992c653fefcb1f9dc8d1 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py @@ -350,59 +350,7 @@ class TestApiPrecisionCompare(unittest.TestCase): if os.path.exists(base_path): os.rmdir(base_path) - def test_online_api_precision_compare(self): - # 准备测试目录和文件 - base_path = 'test_online_compare_tmp' - os.makedirs(base_path, exist_ok=True) - - # 创建测试用的CSV文件 - npu_csv = os.path.join(base_path, 'npu.csv') - gpu_csv = os.path.join(base_path, 'gpu.csv') - result_csv = os.path.join(base_path, 'results_rank1.csv') - details_csv = os.path.join(base_path, 'details_rank1.csv') - - # 准备在线比较的配置 - online_config = MagicMock() - online_config.rank = 1 - online_config.result_csv_path = os.path.join(base_path, "results_rank*.csv") - online_config.details_csv_path = os.path.join(base_path, "details_rank*.csv") - - # 将测试数据写入CSV文件 - df = pd.DataFrame(self.test_data) - df.to_csv(npu_csv, index=False) - df.to_csv(gpu_csv, index=False) - - # 设置online_config的数据 - online_config.npu_data = pd.read_csv(npu_csv) - online_config.gpu_data = pd.read_csv(gpu_csv) - - try: - # 执行在线比较 - online_api_precision_compare(online_config) - - # 验证结果文件是否生成 - self.assertTrue(os.path.exists(result_csv)) - self.assertTrue(os.path.exists(details_csv)) - - # 读取并验证结果 - result_df = pd.read_csv(result_csv) - self.assertFalse(result_df.empty) - - details_df = pd.read_csv(details_csv) - self.assertFalse(details_df.empty) - - # 验证文件权限 - self.assertEqual(os.stat(result_csv).st_mode & 0o777, FileCheckConst.DATA_FILE_AUTHORITY) - self.assertEqual(os.stat(details_csv).st_mode & 0o777, FileCheckConst.DATA_FILE_AUTHORITY) - - finally: - # 清理测试文件 - for file_path in [npu_csv, gpu_csv, result_csv, details_csv]: - if os.path.exists(file_path): - os.remove(file_path) - if os.path.exists(base_path): - os.rmdir(base_path) - + def test_skip_due_to_empty_output(self): self.row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] = ' ' api_name = "abs" diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py index 0a88476d600958b26eaf6ca20a9a70d35b4221cc..f4c7057e976a421c216bc56d91ab938bd6702e04 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py @@ -306,7 +306,7 @@ class TestDataGenerateMethods(unittest.TestCase): tensor = gen_common_tensor(low_info, high_info, shape, data_dtype, None) self.assertTrue(tensor.max() == float('inf')) self.assertTrue(tensor.min() == float('-inf')) - + low_info = [float('inf'), float('inf')] high_info = [float('inf'), float('inf')] tensor = gen_common_tensor(low_info, high_info, shape, data_dtype, None) @@ -317,14 +317,12 @@ class TestDataGenerateMethods(unittest.TestCase): high_info = [2, float('inf')] tensor = gen_common_tensor(low_info, high_info, shape, data_dtype, None) self.assertTrue(tensor.max() == float('inf')) - self.assertTrue(torch.allclose(tensor.min(), torch.tensor(1.0), atol = 0.5)) - + low_info = [1, float('-inf')] high_info = [2, float('-inf')] tensor = gen_common_tensor(low_info, high_info, shape, data_dtype, None) - self.assertTrue(torch.allclose(tensor.max(), torch.tensor(2.0), atol = 0.3)) self.assertTrue(tensor.min() == float('-inf')) - + low_info = [1, float('nan')] high_info = [2, float('nan')] tensor = gen_common_tensor(low_info, high_info, shape, data_dtype, None) @@ -338,7 +336,7 @@ class TestDataGenerateMethods(unittest.TestCase): shape = (0, 0) tensor = gen_common_tensor(low_info, high_info, shape, data_dtype, None) self.assertEqual(tensor.numel(), 0) - + shape = (1, 2) low_info = [2, float('nan')] high_info = [2, float('nan')] diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py index 1ad191a0d4e85715e6199367d1d305c10a728630..8eb8fde4fdca88c97a4165f541f6dd6e7133303f 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py @@ -136,7 +136,7 @@ class TestMultiRunUT(unittest.TestCase): def setUp(self): self.test_json_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "dump.json") - self.test_data = {'data': {'key1': 'TRUE', 'key2': 'TRUE', 'key3': 'TRUE'}} + self.test_data = {'dump_data_dir': '/test', 'data': {'key1': 'TRUE', 'key2': 'TRUE', 'key3': 'TRUE'}} self.test_json_content = json.dumps(self.test_data) self.forward_split_files_content = [ {'key1': 'TRUE', 'key2': 'TRUE'}, diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py index cb54b4ccfef5c1aa19c4a3527b6b5cfdac7dcc77..bd07015d37ab9c1bac067db73fffe8287062526a 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py @@ -2,6 +2,7 @@ import os import copy import shutil +import tempfile import unittest from unittest.mock import patch, DEFAULT from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import * @@ -58,80 +59,80 @@ class TestFileCheck(unittest.TestCase): def test_config_path_soft_link_check(self): args = Args(config_path=self.soft_json_path, api_info_path=self.hard_json_path, out_path=self.hard_path) - + with self.assertRaises(Exception) as context: run_ut_command(args) self.assertEqual(context.exception.code, FileCheckException.SOFT_LINK_ERROR) def test_api_info_path_soft_link_check(self): args = Args(config_path=self.hard_json_path, api_info_path=self.soft_json_path, out_path=self.hard_path) - + with self.assertRaises(Exception) as context: run_ut_command(args) self.assertEqual(context.exception.code, FileCheckException.SOFT_LINK_ERROR) def test_out_path_soft_link_check(self): args = Args(config_path=self.hard_json_path, api_info_path=self.hard_json_path, out_path=self.soft_path) - + with self.assertRaises(Exception) as context: run_ut_command(args) self.assertEqual(context.exception.code, FileCheckException.SOFT_LINK_ERROR) - + def test_result_csv_path_soft_link_check(self): - args = Args(config_path=self.hard_json_path, api_info_path=self.hard_json_path, out_path=self.hard_path, + args = Args(config_path=self.hard_json_path, api_info_path=self.hard_json_path, out_path=self.hard_path, result_csv_path=self.csv_path) - + with self.assertRaises(Exception) as context: run_ut_command(args) self.assertEqual(context.exception.code, FileCheckException.SOFT_LINK_ERROR) def test_config_path_empty_check(self): args = Args(config_path=self.empty_path, api_info_path=self.hard_json_path, out_path=self.hard_path) - + with self.assertRaises(Exception) as context: run_ut_command(args) self.assertEqual(context.exception.code, FileCheckException.ILLEGAL_PATH_ERROR) - + def test_api_info_path_empty_check(self): args = Args(config_path=self.hard_json_path, api_info_path=self.empty_path, out_path=self.hard_path) - + with self.assertRaises(Exception) as context: run_ut_command(args) self.assertEqual(context.exception.code, FileCheckException.ILLEGAL_PATH_ERROR) - + def test_out_path_empty_check(self): args = Args(config_path=self.hard_json_path, api_info_path=self.hard_json_path, out_path=self.empty_path) with self.assertRaises(Exception) as context: run_ut_command(args) self.assertEqual(context.exception.code, FileCheckException.ILLEGAL_PATH_ERROR) - + def test_result_csv_path_empty_check(self): - args = Args(config_path=self.hard_json_path, api_info_path=self.hard_json_path, out_path=self.hard_path, + args = Args(config_path=self.hard_json_path, api_info_path=self.hard_json_path, out_path=self.hard_path, result_csv_path=self.empty_path) with self.assertRaises(Exception) as context: run_ut_command(args) self.assertEqual(context.exception.code, FileCheckException.ILLEGAL_PATH_ERROR) - + def test_config_path_invalid_check(self): args = Args(config_path=123, api_info_path=self.hard_json_path, out_path=self.hard_path) with self.assertRaises(Exception) as context: run_ut_command(args) self.assertEqual(context.exception.code, FileCheckException.ILLEGAL_PATH_ERROR) - + def test_api_info_path_invalid_check(self): args = Args(config_path=self.hard_json_path, api_info_path="123", out_path=self.hard_path) with self.assertRaises(Exception) as context: run_ut_command(args) self.assertEqual(context.exception.code, FileCheckException.ILLEGAL_PATH_ERROR) - + def test_out_path_invalid_check(self): args = Args(config_path=self.hard_json_path, api_info_path=self.hard_json_path, out_path=123) with self.assertRaises(Exception) as context: run_ut_command(args) self.assertEqual(context.exception.code, FileCheckException.ILLEGAL_PATH_ERROR) - + def test_result_csv_path_invalid_check(self): - args = Args(config_path=self.hard_json_path, api_info_path=self.hard_json_path, out_path=self.hard_path, + args = Args(config_path=self.hard_json_path, api_info_path=self.hard_json_path, out_path=self.hard_path, result_csv_path=123) with self.assertRaises(Exception) as context: run_ut_command(args) @@ -196,26 +197,26 @@ class TestRunUtMethods(unittest.TestCase): self.assertIsNone(data_info.bench_output) self.assertIsNone(data_info.grad_in) self.assertIsNone(data_info.in_fwd_data_list) - + def test_blacklist_and_whitelist_filter(self): api_name = "test_api" black_list = ["test_api"] white_list = [] result = blacklist_and_whitelist_filter(api_name, black_list, white_list) self.assertTrue(result) - + api_name = "test_api" black_list = [] white_list = ["another_api"] result = blacklist_and_whitelist_filter(api_name, black_list, white_list) self.assertTrue(result) - + api_name = "test_api" black_list = ["test_api"] white_list = ["test_api"] result = blacklist_and_whitelist_filter(api_name, black_list, white_list) self.assertTrue(result) - + api_name = "test_api" black_list = [] white_list = ["test_api"] @@ -230,87 +231,58 @@ class TestRunUtMethods(unittest.TestCase): api_name = "Distributed.all_reduce" result = is_unsupported_api(api_name) self.assertTrue(result) - + def test_no_backward(self): grad_index = None - out = (1, 2, 3) + out = (1, 2, 3) result = need_to_backward(grad_index, out) self.assertFalse(result) grad_index = 0 - out = 42 + out = 42 result = need_to_backward(grad_index, out) self.assertTrue(result) + def test_check_need_grad_given_out_kwarg_then_return_false(self): + from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import check_need_grad -class TestRunUtOnlineConfig(unittest.TestCase): + api_info_dict = {"input_kwargs": {"out": True}} + result = check_need_grad(api_info_dict) + self.assertFalse(result) - @patch('msprobe.pytorch.api_accuracy_checker.run_ut.run_ut.check_crt_valid') - def test_checked_online_config(self, mock_check_crt_valid): - class OnlineConfigClass: - is_online = True - rank_list = [0, 1] - nfs_path = "" - tls_path = "" - host = "127.0.0.1" - port = 12345 + def test_check_need_grad_given_no_out_kwarg_then_return_true(self): + from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import check_need_grad - mock_check_crt_valid.return_value = None + api_info_dict = {"input_kwargs": {}} + result = check_need_grad(api_info_dict) + self.assertTrue(result) - online_config = OnlineConfigClass() - res = checked_online_config(online_config) - self.assertIsNone(res) + def test_preprocess_forward_content_given_duplicate_apis_then_filter(self): + from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import preprocess_forward_content - # test is_online - online_config.is_online = "True" - with self.assertRaises(Exception) as context: - checked_online_config(online_config) - self.assertIn(str(context.exception), f"is_online must be bool type") - online_config.is_online = True + forward_content = { + "torch.add_1": {"input_args": [{"value": 1}], "input_kwargs": {}}, + "torch.add_2": {"input_args": [{"value": 1}], "input_kwargs": {}}, + "torch.sub": {"input_args": [{"value": 2}], "input_kwargs": {}} + } - # test rank_list - online_config.rank_list = "1234" - with self.assertRaises(Exception) as context: - checked_online_config(online_config) - self.assertIn(str(context.exception), f"rank_list must be a list") - online_config.rank_list = ["1", "2"] - with self.assertRaises(Exception) as context: - checked_online_config(online_config) - self.assertIn(str(context.exception), f"All elements in rank_list must be integers") - online_config.rank_list = [1, 2] + result = preprocess_forward_content(forward_content) - # test nfs_path - online_config.nfs_path = "./nfs_path" - with self.assertRaises(Exception) as context: - checked_online_config(online_config) - self.assertIn(str(context.exception), "[msprobe] 非法文件路径: ") - online_config.nfs_path = "" + self.assertEqual(len(result), 2) # One duplicate should be removed - # test tls_path - online_config.tls_path = "./tls_path" - with self.assertRaises(Exception) as context: - checked_online_config(online_config) - self.assertIn(str(context.exception), "[msprobe] 非法文件路径: ") - - os.makedirs(online_config.tls_path) - with open(os.path.join(online_config.tls_path, "server.key"), 'w') as file: - file.write("1") - with open(os.path.join(online_config.tls_path, "server.crt"), 'w') as file: - file.write("1") - checked_online_config(online_config) - shutil.rmtree(online_config.tls_path) - online_config.tls_path = "" - - # test host - online_config.host = "invalid_host" - with self.assertRaises(Exception) as context: - checked_online_config(online_config) - self.assertIn(str(context.exception), f"host: {online_config.host} is invalid.") - online_config.host = "127.0.0.1" + def test_initialize_save_error_data_given_valid_path_then_return_path(self): + self.temp_dir = tempfile.TemporaryDirectory() + self.test_dir = self.temp_dir.name - # test port - online_config.port = -1 - with self.assertRaises(Exception) as context: - checked_online_config(online_config) - self.assertIn(str(context.exception), f"port: {online_config.port} is invalid, port range 0-65535.") - online_config.port = 6123 + from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import initialize_save_error_data + + error_data_path = os.path.join(self.test_dir, "error_data") + result = initialize_save_error_data(error_data_path) + + self.assertTrue(os.path.exists(result)) + self.assertIn("ut_error_data", result) + self.temp_dir.cleanup() + + +if __name__ == '__main__': + unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut_utils.py index 0cf30461aec70b85577c38ebed011bf9f818874d..8cead7b0093ce68dca8a12b0ea6dbcde78a70c0b 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut_utils.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut_utils.py @@ -1,13 +1,27 @@ -# coding=utf-8 +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest -from unittest.mock import patch, MagicMock + import torch + from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import * from msprobe.core.common.file_utils import create_directory, write_csv class TestRunUtUtils(unittest.TestCase): - def setUp(self): save_path = "temp_save_path" create_directory(save_path) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_attl.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_attl.py deleted file mode 100644 index 7d4e6e950dc1d3e51ef69ca46895fcf5078c5f67..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_attl.py +++ /dev/null @@ -1,108 +0,0 @@ -# coding=utf-8 -import unittest -from unittest.mock import patch -from multiprocessing import Queue - -from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import * -from msprobe.core.common.file_utils import create_directory - -class TestATTL(unittest.TestCase): - - def setUp(self): - nfs_path = "temp_nfs_path" - create_directory(nfs_path) - self.nfs_path = os.path.realpath(nfs_path) - self.session_id = "test_session" - self.session_config = ATTLConfig(is_benchmark_device=False, connect_ip='127.0.0.1', - connect_port=8080, nfs_path=self.nfs_path , check_sum=False, queue_size=100) - self.attls = ATTL(self.session_id, self.session_config, need_dump=False) - self.buffer = ApiData('test_api', args=(torch.randn(2, 2),), kwargs={'device': 'cpu'}, - result=torch.randn(2, 2), step=1, rank=1) - - def tearDown(self): - for filename in os.listdir(self.nfs_path): - os.remove(os.path.join(self.nfs_path, filename)) - os.rmdir(self.nfs_path) - - def test_attl_config(self): - config = ATTLConfig(is_benchmark_device=True, connect_ip='192.168.1.1', connect_port=9090, - nfs_path=self.nfs_path, tls_path='/path/to/tls', check_sum=False, queue_size=100) - self.assertEqual(config.is_benchmark_device, True) - self.assertEqual(config.connect_ip, '192.168.1.1') - self.assertEqual(config.connect_port, 9090) - self.assertEqual(config.nfs_path, self.nfs_path) - self.assertEqual(config.tls_path, '/path/to/tls') - self.assertFalse(config.check_sum) - self.assertEqual(config.queue_size, 100) - - @patch('msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl.move2target_device') - def test_upload_api_data(self, mock_move2target_device): - mock_move2target_device.return_value = self.buffer - self.attls.upload(self.buffer) - mock_move2target_device.assert_called_once_with(self.buffer, torch.device('cpu')) - - @patch('glob.glob') - def test_download_no_files(self, mock_glob): - mock_glob.return_value = [] - result = self.attls.download() - self.assertIsNone(result) - - @patch('glob.glob') - @patch('msprobe.pytorch.common.utils.load_pt') - def test_download_with_exception(self, mock_load_pt, mock_glob): - mock_glob.return_value = ['/tmp/start_file.pt'] - mock_load_pt.side_effect = Exception('Load error') - with patch.object(self.attls.logger, 'warning') as mock_logger: - result = self.attls.download() - self.assertIsNone(result) - mock_logger.assert_called_once() - - def test_move2device_exec_tensor(self): - tensor = torch.randn(2, 2) - device = torch.device("cpu") - moved_tensor = move2device_exec(tensor, device) - self.assertEqual(moved_tensor.device, device) - - def test_move2device_exec_list(self): - tensor_list = [torch.randn(2, 2), torch.randn(2, 2)] - device = torch.device("cpu") - moved_list = move2device_exec(tensor_list, device) - for tensor in moved_list: - self.assertEqual(tensor.device, device) - - def test_move2device_exec_tuple(self): - tensor_tuple = (torch.randn(2, 2), torch.randn(2, 2)) - device = torch.device("cpu") - moved_tuple = move2device_exec(tensor_tuple, device) - for tensor in moved_tuple: - self.assertEqual(tensor.device, device) - - def test_move2device_exec_dict(self): - tensor_dict = {"a": torch.randn(2, 2), "b": torch.randn(2, 2)} - device = torch.device("cpu") - moved_dict = move2device_exec(tensor_dict, device) - for tensor in moved_dict.values(): - self.assertEqual(tensor.device, device) - - def test_move2device_exec_device(self): - device = torch.device("cpu") - moved_device = move2device_exec(torch.device("cpu"), device) - self.assertEqual(moved_device, device) - - def test_move2device_exec_non_tensor(self): - obj = "This is a string" - device = torch.device("cpu") - self.assertEqual(move2device_exec(obj, device), obj) - - def test_move2target_device_to_cpu(self): - tensor_args = (torch.randn(2, 2), torch.randn(3, 3)) - tensor_kwargs = {'key1': torch.randn(2, 2), 'key2': torch.randn(3, 3)} - tensor_result = torch.randn(2, 2) - buffer = ApiData('test_api', tensor_args, tensor_kwargs, tensor_result, 1, 1) - target_device = torch.device('cpu') - moved_buffer = move2target_device(buffer, target_device) - self.assertEqual(moved_buffer.result.device, target_device) - for tensor in moved_buffer.args: - self.assertEqual(tensor.device, target_device) - for tensor in moved_buffer.kwargs.values(): - self.assertEqual(tensor.device, target_device) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_client.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_client.py deleted file mode 100644 index d35cfc3387559064298a451fb9d868838bb25aac..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_client.py +++ /dev/null @@ -1,33 +0,0 @@ -# coding=utf-8 -import unittest -from unittest.mock import patch, MagicMock -from multiprocessing import Queue - -from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import * -from msprobe.core.common.file_utils import create_directory - - -class TestClient(unittest.TestCase): - - def setUp(self) -> None: - self.host = "localhost" - self.port = 8000 - self.check_sum = False - tls_path = "temp_tls_path" - create_directory(tls_path) - self.tls_path = os.path.realpath(tls_path) - - def tearDown(self) -> None: - for filename in os.listdir(self.tls_path): - os.remove(os.path.join(self.tls_path, filename)) - os.rmdir(self.tls_path) - - def test_TCPDataItem(self): - data_item = TCPDataItem(data="example_data", sequence_number=10, rank=1, step=2) - self.assertEqual(data_item.raw_data, "example_data") - self.assertEqual(data_item.sequence_number, 10) - self.assertEqual(data_item.rank, 1) - self.assertEqual(data_item.step, 2) - self.assertEqual(data_item.retry_times, 0) - self.assertEqual(data_item.pending_time, 0) - self.assertEqual(data_item.busy_time, 0) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_pt_accuracy_server.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_pt_accuracy_server.py deleted file mode 100644 index b60cfdc323bed57e1cda1fc2d9db3197638cee4c..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_pt_accuracy_server.py +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import io -import queue -import struct -import time -import unittest -from unittest.mock import MagicMock, patch - -from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import ( - TCPServer, - ServerProtocol, - MessageServerFactory -) - - -class TestTCPServer(unittest.TestCase): - def setUp(self): - self.shared_queue = queue.Queue() - self.tcp_server = TCPServer("6000", self.shared_queue) - self.tcp_server.tls_path = "/test/path" - self.tcp_server.factory = MagicMock() - - @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server.reactor") - def test_run_reactor(self, mock_reactor): - self.tcp_server.run_reactor() - mock_reactor.run.assert_called_once_with(installSignalHandlers=False) - - def test_is_running(self): - self.tcp_server.is_running() - self.tcp_server.factory.is_all_connection_closed.assert_called_once_with() - - @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server.reactor") - def test_stop(self, mock_reactor): - self.tcp_server.reactor_thread = MagicMock() - self.tcp_server.stop() - mock_reactor.callFromThread.assert_called_once() - self.tcp_server.reactor_thread.join.assert_called_once() - - -class TestServerProtocol(unittest.TestCase): - def setUp(self): - self.shared_queue = queue.Queue() - self.server_protocol = ServerProtocol(self.shared_queue) - self.server_protocol.start_time = time.time() - self.server_protocol.factory = MagicMock() - self.server_protocol.factory.transport_dict = {} - self.server_protocol.factory.transport_list = [] - self.server_protocol.transport = MagicMock() - - def test_connectionMade(self): - self.server_protocol.connectionMade() - self.assertEqual(self.server_protocol.tell, 0) - self.assertEqual(self.server_protocol.factory.transport_dict[self.server_protocol.transport], 1) - self.assertTrue(self.server_protocol.transport in self.server_protocol.factory.transport_list) - - def test_connectionLost(self): - self.server_protocol.factory.transport_dict[self.server_protocol.transport] = 1 - self.server_protocol.connectionLost("test") - self.assertEqual(len(self.server_protocol.factory.transport_dict), 0) - self.assertEqual(self.server_protocol.consumer_queue.get(), self.server_protocol.ACK_KILL_PROCESS) - - def test_send_ack(self): - self.server_protocol.sequence_number = 1 - self.server_protocol.rank = 0 - self.server_protocol.step = 0 - self.server_protocol.send_ack(b'test message') - expected_value = b''.join([ - b'test message', - b'\x00\x00\x00\x00\x00\x00\x00\x01', - b'\x00\x00\x00\x00\x00\x00\x00\x00', - b'\x00\x00\x00\x00\x00\x00\x00\x00', - ]) - self.server_protocol.transport.write.called_once_with(expected_value) - - @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server.hashlib.md5") - def test_post_process_error(self, mock_hashlib_md5): - self.shared_queue.maxsize = 1 - self.server_protocol.send_ack = MagicMock() - - def mock_send_ack_method1(): - self.server_protocol.consumer_queue.put(1) - - def mock_send_ack_method2(): - pass - - self.server_protocol.send_ack.side_effect = [mock_send_ack_method1, mock_send_ack_method2] - self.server_protocol.check_sum = True - mock_hashlib_md5.hexdiges.return_value = "123" - self.server_protocol.rank = 0 - self.server_protocol.step = 0 - self.server_protocol.post_process() - mock_hashlib_md5.assert_called() - self.server_protocol.send_ack.assert_any_call(self.server_protocol.ACK_ERROR) - self.assertEqual(self.server_protocol.rank, -1) - self.assertEqual(self.server_protocol.step, -1) - - @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server.hashlib.md5") - def test_post_process_success(self, _): - self.shared_queue.maxsize = 1 - self.server_protocol.send_ack = MagicMock() - - def mock_send_ack_method1(): - self.server_protocol.consumer_queue.put(1) - - def mock_send_ack_method2(): - pass - - self.server_protocol.send_ack.side_effect = [mock_send_ack_method1, mock_send_ack_method2] - self.server_protocol.check_sum = False - self.server_protocol.obj_body = self.server_protocol.ACK_SUCCESS - self.server_protocol.post_process() - self.server_protocol.send_ack.assert_any_call(self.server_protocol.ACK_SUCCESS) - - def test_handle_with_stop(self): - self.server_protocol.send_ack = MagicMock() - self.server_protocol.handle_with_stop() - self.server_protocol.send_ack.assert_called_once_with(self.server_protocol.ACK_STOP_CONFIRM) - self.assertEqual(self.server_protocol.consumer_queue.get(), self.server_protocol.ACK_KILL_PROCESS) - - def test_reset_env(self): - self.server_protocol.obj_length = 10 - self.server_protocol.sequence_number = 1 - self.server_protocol.rank = 2 - self.server_protocol.step = 3 - self.server_protocol.reset_env() - self.assertEqual(self.server_protocol.obj_length, None) - self.assertEqual(self.server_protocol.sequence_number, -1) - self.assertEqual(self.server_protocol.rank, -1) - self.assertEqual(self.server_protocol.step, -1) - - def test_dataReceived(self): - self.server_protocol.buffer = io.BytesIO() - self.server_protocol.post_process = MagicMock() - unpack_mode = '!Q' - header = struct.pack(unpack_mode, 10) - header += struct.pack(unpack_mode, 1) - header += struct.pack(unpack_mode, 2) - header += struct.pack(unpack_mode, 3) - - self.server_protocol.dataReceived(header) - - self.assertEqual(self.server_protocol.obj_length, 10) - self.assertEqual(self.server_protocol.sequence_number, 1) - self.assertEqual(self.server_protocol.rank, 2) - self.assertEqual(self.server_protocol.step, 3) - - -class TestMessageServerFactory(unittest.TestCase): - def setUp(self): - self.message_server_factory = MessageServerFactory() - - def test_is_all_connection_closed(self): - all_conn_closed = self.message_server_factory.is_all_connection_closed() - self.assertTrue(all_conn_closed) - - self.message_server_factory.transport_dict = {"test1": 1} - all_conn_closed = self.message_server_factory.is_all_connection_closed() - self.assertFalse(all_conn_closed) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_pt_device_dispatch.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_pt_device_dispatch.py deleted file mode 100644 index 79f569cdcaeaa662f403b73fd4047caf7c2f0311..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_pt_device_dispatch.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -from unittest.mock import MagicMock, patch - -from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import run_ut_process, \ - online_precision_compare, online_compare, ConsumerDispatcher -from msprobe.pytorch.common.log import logger - - -class TestDeviceDispatchFunc(unittest.TestCase): - @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch.online_compare") - @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch.torch") - def test_run_ut_process(self, mock_torch, mock_online_compare): - xpu_id = 1 - mock_consumer_queue = MagicMock() - mock_consumer_queue.empty.side_effect = [True, False, False] - mock_api_data = MagicMock() - mock_api_data.name.split.return_value = ("test", "conv2d", 1) - mock_consumer_queue.get.side_effect = [mock_api_data, "KILL_"] - - run_ut_process(xpu_id, mock_consumer_queue, None, None) - mock_torch.device.assert_called_once_with('cuda:1') - mock_online_compare.assert_called_with(mock_api_data, mock_torch.device(), None) - - @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch.UtDataInfo") - @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch.exec_api") - @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch.generate_cpu_params") - def test_online_precision_compare(self, mock_gen_cpu_params, mock_exec_api, mock_ut_data_info): - with patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch.move2target_device"), \ - patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch.pd"), \ - patch( - "msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch.online_api_precision_compare"): - mock_gen_cpu_params.return_value = (MagicMock()) - mock_api_data = MagicMock() - mock_api_data.name.split.return_value = ("tensor", "conv2d", 1) - mock_com_config = MagicMock() - mock_api_precision_csv_file = [MagicMock(), MagicMock()] - online_precision_compare(mock_api_data, None, mock_com_config, mock_api_precision_csv_file) - mock_gen_cpu_params.assert_called() - mock_exec_api.assert_called() - mock_ut_data_info.assert_called() - - @patch.object(logger, "info") - @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch.move2target_device") - def test_online_compare_success(self, mock_move2target_device, mock_logger_info): - api_data = MagicMock() - api_data.name = "test_api_name" - common_config = MagicMock() - common_config.compare.compare_output.return_value = ("test_fwd_success", "test_bwd_success") - online_compare(api_data, None, common_config) - mock_move2target_device.assert_called() - mock_logger_info.assert_called_once_with("running api_full_name test_api_name ut, " - "is_fwd_success: test_fwd_success, " - "is_bwd_success: test_bwd_success") - - @patch.object(logger, "error") - @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch.move2target_device") - def test_online_compare_failed(self, mock_move2target_device, mock_logger_error): - api_data = MagicMock() - api_data.name.split.return_value = ["tensor", "conv2d", 1] - common_config = MagicMock() - online_compare(api_data, None, common_config) - mock_move2target_device.assert_called() - mock_logger_error.assert_called() - - -class TestConsumerDispatcher(unittest.TestCase): - @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch.mp") - def setUp(self, mock_mq): - self.mock_mq = mock_mq - self.consumer_dispatcher = ConsumerDispatcher(None) - - @patch.object(logger, "info") - @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch.mp") - @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch.CommonCompareConfig") - def test_start(self, mock_com_compare_config, mock_mq, mock_log_info): - self.consumer_dispatcher.start(None, None) - mock_com_compare_config.assert_called_once_with(None, None, None) - mock_mq.Process.assert_called() - mock_log_info.assert_any_call("Successfully start unittest process.") - - @patch.object(logger, "info") - def test_stop(self, mock_log_info): - mock_queue = MagicMock() - mock_queue.full.side_effect = [True, False] - self.consumer_dispatcher.queues = [mock_queue] - - mock_process = MagicMock() - self.consumer_dispatcher.processes = [mock_process] - self.consumer_dispatcher.stop() - mock_log_info.assert_any_call("Successfully stop unittest process.") - mock_process.join.assert_called() - - def test_update_consume_queue(self): - self.consumer_dispatcher._choose_max_empty_site_strategy = MagicMock() - self.consumer_dispatcher._choose_max_empty_site_strategy.return_value = 0 - mock_queue = MagicMock() - self.consumer_dispatcher.queues = [mock_queue] - self.consumer_dispatcher.update_consume_queue("test_data") - mock_queue.put.assert_called_once_with("test_data") - - def test_choose_max_empty_site_strategy(self): - mock_queue = MagicMock() - mock_queue.qsize.return_value = 1 - self.consumer_dispatcher.queues = [mock_queue] - self.consumer_dispatcher.capacity = 5 - self.consumer_dispatcher.reverse_sort = False - result = self.consumer_dispatcher._choose_max_empty_site_strategy() - self.assertEqual(result, 0) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/common/test_pt_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/common/test_pt_utils.py index cdc922cc98d59b59ec0be85833d2000cd38913c8..709d252f39640578be2dab9edf390862efed8e9a 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/common/test_pt_utils.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/common/test_pt_utils.py @@ -1,17 +1,39 @@ -import os +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import io +import os +import tempfile import unittest from unittest.mock import MagicMock, patch -import tempfile import torch import torch.distributed as dist - -from msprobe.core.common.file_utils import FileCheckConst from msprobe.core.common.exceptions import DistributedNotInitializedError from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData -from msprobe.pytorch.common.utils import parameter_adapter, get_rank_if_initialized, \ - get_tensor_rank, get_rank_id, print_rank_0, load_pt, save_pt, save_api_data, load_api_data, save_pkl, load_pkl +from msprobe.pytorch.common.utils import ( + parameter_adapter, + get_rank_if_initialized, + get_tensor_rank, + get_rank_id, + print_rank_0, + load_pt, + save_pt, + save_pkl, + load_pkl +) class TestParameterAdapter(unittest.TestCase): @@ -19,7 +41,7 @@ class TestParameterAdapter(unittest.TestCase): def setUp(self): self.func_mock = MagicMock() self.decorated_func = parameter_adapter(self.func_mock) - self.op_name_ = "__getitem__" + self.api_name = "__getitem__" def test_handle_masked_select_bfloat16(self): input_tensor = torch.tensor([1.0, 2.0], dtype=torch.bfloat16) @@ -45,7 +67,7 @@ class TestParameterAdapter(unittest.TestCase): self.assertTrue(torch.equal(result, torch.tensor([20.0, 30.0]))) def test_op_name_eq_with_none(self): - self.op_name_ = "__eq__" + self.api_name = "__eq__" args = (torch.tensor([1]), None) result = self.decorated_func(self, *args) self.assertFalse(result) @@ -186,6 +208,12 @@ class TestSavePT(unittest.TestCase): self.tensor = torch.tensor([1, 2, 3]) self.filepath = 'temp_tensor.pt' + def tearDown(self): + try: + os.remove(self.filepath) + except FileNotFoundError: + pass + @patch('msprobe.pytorch.common.utils.save_pt') @patch('os.path.realpath', return_value='temp_tensor.pt') @patch('msprobe.core.common.file_utils.check_path_before_create') @@ -193,21 +221,6 @@ class TestSavePT(unittest.TestCase): def test_save_pt_success(self, mock_change_mode, mock_check_path, mock_realpath, mock_torch_save): mock_torch_save(self.tensor, self.filepath) mock_torch_save.assert_called_once_with(self.tensor, self.filepath) - mock_change_mode.assert_called_once_with(self.filepath, FileCheckConst.DATA_FILE_AUTHORITY) - -class TestSavePT(unittest.TestCase): - - def setUp(self): - self.tensor = torch.tensor([1, 2, 3]) - self.filepath = 'temp_tensor.pt' - - @patch('torch.save') - @patch('os.path.realpath', return_value='temp_tensor.pt') - @patch('msprobe.core.common.file_utils.check_path_before_create') - @patch('msprobe.core.common.file_utils.change_mode') - def test_save_pt_success(self, mock_change_mode, mock_check_path, mock_realpath, mock_torch_save): - save_pt(self.tensor, self.filepath) - mock_torch_save.assert_called_once_with(self.tensor, self.filepath) @patch('torch.save', side_effect=Exception("Save failed")) @patch('os.path.realpath', return_value='temp_tensor.pt') @@ -218,47 +231,6 @@ class TestSavePT(unittest.TestCase): save_pt(self.tensor, self.filepath) self.assertIn("save pt file temp_tensor.pt failed", str(context.exception)) - def tearDown(self): - try: - os.remove(self.filepath) - except FileNotFoundError: - pass - - -class TestSaveApiData(unittest.TestCase): - - def test_save_api_data_success(self): - api_data = {"key": "value"} - io_buff = save_api_data(api_data) - self.assertIsInstance(io_buff, io.BytesIO) - io_buff.seek(0) - loaded_data = torch.load(io_buff) - self.assertEqual(loaded_data, api_data) - - def test_save_api_data_failure(self): - api_data = MagicMock() - with patch('torch.save', side_effect=Exception("save error")): - with self.assertRaises(RuntimeError) as context: - save_api_data(api_data) - self.assertIn("save api_data to io_buff failed", str(context.exception)) - - -class TestLoadApiData(unittest.TestCase): - - def test_load_api_data_success(self): - mock_tensor = torch.tensor([1, 2, 3]) - buffer = io.BytesIO() - torch.save(mock_tensor, buffer) - buffer.seek(0) - result = load_api_data(buffer.read()) - self.assertTrue(torch.equal(result, mock_tensor)) - - def test_load_api_data_failure(self): - invalid_bytes = b'not a valid tensor' - with self.assertRaises(RuntimeError) as context: - load_api_data(invalid_bytes) - self.assertIn("load api_data from bytes failed", str(context.exception)) - class TestSavePkl(unittest.TestCase): diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_match.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_match.py deleted file mode 100644 index ac28e994e9c8e77f8ae675fec3322eaf64a64321..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_match.py +++ /dev/null @@ -1,20 +0,0 @@ -# coding=utf-8 -import unittest -from msprobe.pytorch.compare import match - - -class TestMatch(unittest.TestCase): - def test_graph_mapping(self): - op1 = "Aten_convolution_1_forward_0.input.0" - op2 = "Torch_conv2d_0_forward_0.input.0" - op3 = "Torch_batch_norm_0_forward_0.input.0" - op4 = "Aten_convolution.default_1_forward_0.input.0" - op5 = "Aten_foo_1_forward_0.input.0" - self.assertTrue(match.graph_mapping.match(op1, op2)) - self.assertTrue(match.graph_mapping.match(op2, op1)) - self.assertTrue(match.graph_mapping.match(op4, op2)) - self.assertTrue(match.graph_mapping.match(op2, op4)) - self.assertFalse(match.graph_mapping.match(op1, op3)) - self.assertFalse(match.graph_mapping.match(op3, op1)) - self.assertFalse(match.graph_mapping.match(op5, op2)) - self.assertFalse(match.graph_mapping.match(op2, op5)) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_pt_compare.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_pt_compare.py index b079e646c4a8f4098bb233e3e6259ef3ebea9c94..e4c8b722b182b8c0a4e82ba1b0eeb1a6ed847ee2 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_pt_compare.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_pt_compare.py @@ -3,16 +3,12 @@ import os import shutil import unittest -import numpy as np import torch -from msprobe.core.common.const import Const from msprobe.core.common.utils import CompareException -from msprobe.core.compare.acc_compare import ModeConfig -from msprobe.pytorch.compare.pt_compare import PTComparator, compare +from msprobe.pytorch.compare.pt_compare import compare from msprobe.test.core_ut.compare.test_acc_compare import generate_dump_json, generate_stack_json - base_dir1 = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'test_pt_compare1') base_dir2 = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'test_pt_compare2') @@ -40,36 +36,6 @@ class TestUtilsMethods(unittest.TestCase): if os.path.exists(base_dir2): shutil.rmtree(base_dir2) - def test_read_npy_data_bf16(self): - generate_bf16_pt(base_dir1) - - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.ALL - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - - pt_comparator = PTComparator(mode_config) - result = pt_comparator.read_npy_data(base_dir1, 'bf16.pt') - - target_result = torch.tensor([1, 2, 3, 4], dtype=torch.float32).numpy() - self.assertTrue(np.array_equal(result, target_result)) - - def test_read_npy_data_dict(self): - generate_dict_pt(base_dir1) - - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.ALL - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - - pt_comparator = PTComparator(mode_config) - - with self.assertRaises(CompareException) as context: - result = pt_comparator.read_npy_data(base_dir1, 'dict.pt') - self.assertEqual(context.exception.code, CompareException.DETACH_ERROR) - def test_compare(self): generate_dump_json(base_dir2) generate_stack_json(base_dir2) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_pt_compare_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_pt_compare_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..558df47a108f27858cc571f6854ca3f403fc6fee --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_pt_compare_utils.py @@ -0,0 +1,72 @@ +import os +import shutil +import threading +import unittest +from unittest import mock +from unittest.mock import patch + +import numpy as np + +from msprobe.pytorch.compare import utils +from msprobe.pytorch.compare.utils import read_pt_data +from msprobe.test.core_ut.compare.test_acc_compare import generate_pt +from msprobe.core.common.utils import CompareException + + +base_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'test_pt_compare_utils_data') +pt_dir = os.path.join(base_dir, f'dump_data_dir') + + +class TestReadPtData(unittest.TestCase): + + def setUp(self): + os.makedirs(base_dir, mode=0o750, exist_ok=True) + os.makedirs(pt_dir, mode=0o750, exist_ok=True) + + self.lock = threading.Lock() + + def tearDown(self): + if os.path.exists(pt_dir): + shutil.rmtree(pt_dir) + if os.path.exists(base_dir): + shutil.rmtree(base_dir) + + def test_read_pt_data_normal(self): + generate_pt(pt_dir) + result = read_pt_data(pt_dir, 'Functional.linear.0.forward.input.0.pt') + expected = np.array([1.0, 2.0, 3.0, 4.0]) + self.assertTrue(np.array_equal(result, expected)) + + def test_read_pt_data_no_file_name(self): + result = read_pt_data(pt_dir, None) + self.assertEqual(result, None) + + @patch.object(utils, 'load_pt') + @patch.object(utils, 'FileChecker') + def test_read_pt_data_runtime_error(self, mock_file_checker_class, mock_load_pt): + mock_file_checker = mock.Mock() + mock_file_checker.common_check.return_value = 'fake/path/file.pt' + mock_file_checker_class.return_value = mock_file_checker + + mock_load_pt.side_effect = RuntimeError('failed to load') + + with self.assertRaises(CompareException) as context: + read_pt_data('fake/path', 'file.pt') + self.assertEqual(context.exception.code, CompareException.INVALID_FILE_ERROR) + + @patch.object(utils, 'load_pt') + @patch.object(utils, 'FileChecker') + def test_read_pt_data_attribute_error(self, mock_file_checker_class, mock_load_pt): + mock_file_checker = mock.Mock() + mock_file_checker.common_check.return_value = 'fake/path/file.pt' + mock_file_checker_class.return_value = mock_file_checker + + class FakeTensor: + def detach(self): + raise AttributeError('no detach') + + mock_load_pt.return_value = FakeTensor() + + with self.assertRaises(CompareException) as context: + read_pt_data('fake/path', 'file.pt') + self.assertEqual(context.exception.code, CompareException.DETACH_ERROR) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_debugger_config.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_debugger_config.py index 4fc27c267ebe65ea46ecf0f17bc47ff702eb241d..56af1d2198d2270733cf533fe0ea556598bcf30b 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_debugger_config.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_debugger_config.py @@ -1,6 +1,7 @@ import unittest -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch +import torch from msprobe.core.common.const import Const from msprobe.core.common.exceptions import MsprobeException from msprobe.pytorch.debugger.debugger_config import DebuggerConfig @@ -34,40 +35,96 @@ class TestDebuggerConfig(unittest.TestCase): self.assertEqual(debugger.handler_type, "check") self.assertTrue(debugger.preheat_config["if_preheat"]) - def test_online_run_ut_initialization(self): - self.task_config.online_run_ut = True - self.task_config.nfs_path = "./nfs_path" - self.task_config.tls_path = "./tls_path" - self.task_config.host = "localhost" - self.task_config.port = 8080 - debugger = DebuggerConfig(self.common_config, self.task_config, Const.TENSOR, None, None) - self.assertTrue(debugger.online_run_ut) - self.assertEqual(debugger.nfs_path, "./nfs_path") - self.assertEqual(debugger.port, 8080) + def test_check_kwargs_with_invalid_task(self): + self.common_config.task = "invalid_task" + with self.assertRaises(MsprobeException) as context: + DebuggerConfig(self.common_config, self.task_config, None, None, None) + self.assertIn(f"The task is not in the {Const.TASK_LIST}", str(context.exception)) + + def test_check_kwargs_with_invalid_level(self): + self.common_config.level = "invalid_level" + with self.assertRaises(MsprobeException) as context: + DebuggerConfig(self.common_config, self.task_config, None, None, None) + self.assertIn(f"The level is not in the {Const.LEVEL_LIST}.", str(context.exception)) - def test_valid_task_and_level(self): - config = DebuggerConfig(self.common_config, self.task_config, "tensor", None, "L1") - config.check_kwargs() + def test_check_kwargs_with_invalid_dump_path(self): + self.common_config.dump_path = None + with self.assertRaises(MsprobeException) as context: + DebuggerConfig(self.common_config, self.task_config, None, None, None) + self.assertIn(f"The dump_path not found.", str(context.exception)) - def test_invalid_task(self): + def test_check_kwargs_with_invalid_async_dump(self): + self.common_config.async_dump = 1 with self.assertRaises(MsprobeException) as context: - config = DebuggerConfig(self.common_config, self.task_config, "invalid_task", None, "L1") - config.check_kwargs() - self.assertIn("not in the", str(context.exception)) + DebuggerConfig(self.common_config, self.task_config, None, None, None) + self.assertIn(f"The parameters async_dump should be bool.", str(context.exception)) + + def test_check_kwargs_with_async_dump_and_debug(self): + self.common_config.async_dump = True + self.common_config.task = Const.TENSOR + self.common_config.level = Const.LEVEL_DEBUG + self.task_config.list = ["linear"] + config = DebuggerConfig(self.common_config, self.task_config, None, None, None) + self.assertEqual(config.list, []) - def test_invalid_level(self): + def test_check_kwargs_with_async_dump_and_not_debug(self): + self.common_config.async_dump = True + self.common_config.task = Const.TENSOR + self.common_config.level = Const.LEVEL_MIX + self.task_config.list = [] + self.task_config.summary_mode = Const.SUMMARY_MODE with self.assertRaises(MsprobeException) as context: - config = DebuggerConfig(self.common_config, self.task_config, "tensor", None, "invalid_level") - config.check_kwargs() - self.assertIn("not in the", str(context.exception)) + DebuggerConfig(self.common_config, self.task_config, None, None, None) + self.assertIn(f"the parameters list cannot be empty.", str(context.exception)) + + def test_check_kwargs_with_structure_task(self): + self.common_config.task = Const.STRUCTURE + self.common_config.level = Const.LEVEL_L1 + config = DebuggerConfig(self.common_config, self.task_config, None, None, None) + self.assertEqual(config.level, Const.LEVEL_MIX) - def test_missing_dump_path(self): + def test_check_async_dump_and_md5(self): + self.common_config.async_dump = True + self.common_config.task = Const.STATISTICS + self.common_config.level = Const.LEVEL_L1 + self.task_config.summary_mode = Const.MD5 + with self.assertRaises(MsprobeException) as context: + DebuggerConfig(self.common_config, self.task_config, None, None, None) + self.assertIn(f"the parameters summary_mode cannot be md5.", str(context.exception)) + + def test_check_model_with_model_is_none(self): + self.common_config.level = Const.LEVEL_L0 + instance = MagicMock() + instance.model = None + config = DebuggerConfig(self.common_config, self.task_config, None, None, None) with self.assertRaises(MsprobeException) as context: - self.common_config.dump_path = None - config = DebuggerConfig(self.common_config, self.task_config, "tensor", None, "L1") - config.check_kwargs() - self.assertIn("dump_path not found", str(context.exception)) + config.check_model(instance, None, None) + self.assertIn("missing the parameter 'model'", str(context.exception)) + + def test_check_model_with_single_model(self): + self.common_config.level = Const.LEVEL_MIX + model1 = torch.nn.ReLU() + model2 = torch.nn.Linear(2, 2) + + instance = MagicMock() + instance.model = model1 + config = DebuggerConfig(self.common_config, self.task_config, None, None, None) + config.check_model(instance, model2, None) + + self.assertEqual(instance.model, model2) + + def test_check_model_with_incorrect_model(self): + self.common_config.level = Const.LEVEL_L0 + model1 = torch.nn.ReLU() + model2 = [torch.nn.Linear(2, 2), torch.nn.ReLU(), "test_model"] + + instance = MagicMock() + instance.model = model1 + config = DebuggerConfig(self.common_config, self.task_config, None, None, None) + with self.assertRaises(MsprobeException) as context: + config.check_model(instance, model2, None) + self.assertIn("must be a torch.nn.Module or list[torch.nn.Module]", str(context.exception)) def test_check_and_adjust_config_with_l2_scope_not_empty(self): self.common_config.dump_path = "./dump_path" @@ -100,3 +157,50 @@ class TestDebuggerConfig(unittest.TestCase): debugger = DebuggerConfig(self.common_config, self.task_config, None, None, None) debugger._check_and_adjust_config_with_l2() self.assertIn("Functional.conv2d.0.forward", self.task_config.list) + + def test_check_and_adjust_config_with_l2_task_not_tensor(self): + self.common_config.dump_path = "./dump_path" + self.common_config.task = Const.STATISTICS + + self.task_config.scope = [] + self.task_config.list = ["Functional.conv2d.0.forward"] + debugger = DebuggerConfig(self.common_config, self.task_config, None, None, None) + with self.assertRaises(MsprobeException) as context: + debugger._check_and_adjust_config_with_l2() + self.assertIn("the task must be set to tensor", str(context.exception)) + + def test_check_statistics_config_task_not_statistics(self): + self.common_config.dump_path = "./dump_path" + self.common_config.task = Const.TENSOR + + debugger = DebuggerConfig(self.common_config, self.task_config, None, None, None) + debugger._check_statistics_config(self.task_config) + self.assertFalse(hasattr(debugger, "tensor_list")) + + def test_check_statistics_config_not_tensor_list(self): + self.common_config.dump_path = "./dump_path" + self.common_config.task = Const.STATISTICS + delattr(self.task_config, "tensor_list") + + debugger = DebuggerConfig(self.common_config, self.task_config, None, None, None) + debugger._check_statistics_config(self.task_config) + self.assertEqual(debugger.tensor_list, []) + + def test_check_statistics_config_debug_level(self): + self.common_config.dump_path = "./dump_path" + self.common_config.task = Const.STATISTICS + self.common_config.level = Const.DEBUG + + debugger = DebuggerConfig(self.common_config, self.task_config, None, None, None) + self.task_config.tensor_list = ["Functional.conv2d"] + debugger._check_statistics_config(self.task_config) + self.assertEqual(debugger.tensor_list, []) + + def test_check_statistics_config_success(self): + self.common_config.dump_path = "./dump_path" + self.common_config.task = Const.STATISTICS + + self.task_config.tensor_list = ["Functional.conv2d"] + debugger = DebuggerConfig(self.common_config, self.task_config, None, None, None) + debugger._check_statistics_config(self.task_config) + self.assertEqual(debugger.tensor_list, self.task_config.tensor_list) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_debugger_start.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_debugger_start.py new file mode 100644 index 0000000000000000000000000000000000000000..217565e48192523e1e44dbc0a8e41f633ab71bef --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_debugger_start.py @@ -0,0 +1,105 @@ +import os +import torch +import torch.nn as nn +from torch.utils.data import TensorDataset, DataLoader +import unittest +from unittest.mock import patch +from msprobe.core.common_config import CommonConfig +from msprobe.core.debugger.precision_debugger import BasePrecisionDebugger +from msprobe.pytorch.pt_config import StatisticsConfig +from msprobe.pytorch.debugger.precision_debugger import PrecisionDebugger +from msprobe.core.common.file_utils import load_json +import shutil + +# 生成随机分类数据 +X = torch.randn(100, 2) +y = ((X[:, 0] + X[:, 1]) > 0).float().reshape(-1, 1) + +# 创建数据加载器 +dataset = TensorDataset(X, y) +dataloader = DataLoader(dataset, batch_size=10) + + +# 定义单层神经网络 +class SingleLayerNet(nn.Module): + def __init__(self): + super().__init__() + self.layer = nn.Linear(2, 1) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + return self.sigmoid(x) + + +class MultiStartDebugger: + debugger = None + dump_path = None + hooked_model = [] + + @classmethod + def init(cls, dump_path): + cls.dump_path = dump_path + json_config = { + "task": "statistics", + "dump_path": "/absolute_path", + "rank": [], + "step": [], + "level": "L1", + "async_dump": False + } + + common_config = CommonConfig(json_config) + task_config = StatisticsConfig(json_config) + with patch.object(BasePrecisionDebugger, "_parse_config_path", return_value=(common_config, task_config)): + cls.debugger = PrecisionDebugger(task="statistics", level="L0", dump_path=dump_path, step=["2-3"]) + + @classmethod + def debugger_start(cls, model, tag): + cls.debugger.service.first_start = True if model not in cls.hooked_model else False + cls.debugger.service.config.dump_path = os.path.join(cls.dump_path, tag) + cls.debugger.start(model=model) + if not cls.debugger.service.first_start and model not in cls.hooked_model: + cls.hooked_model.append(model) + + @classmethod + def debugger_stop(cls): + cls.debugger.stop() + cls.debugger.service.reset_status() + + @classmethod + def debugger_step(cls): + cls.debugger.step() + + +class TestPTDebuggerStart(unittest.TestCase): + def test_debugger_multiple_start(self): + dump_path = "./test_debugger_multiple_start_dump" + + model1 = SingleLayerNet() + model2 = SingleLayerNet() + MultiStartDebugger.init(dump_path) + + for batch_X, batch_y in dataloader: + MultiStartDebugger.debugger_start(model=model1, tag="model1") + output1 = model1(batch_X) + MultiStartDebugger.debugger_stop() + + MultiStartDebugger.debugger_start(model=model2, tag="model2") + output2 = model2(batch_X) + MultiStartDebugger.debugger_stop() + MultiStartDebugger.debugger_step() + + model1_dump_path = os.path.join(dump_path, "model1") + self.assertTrue(os.path.exists(model1_dump_path)) + self.assertEqual(len(os.listdir(model1_dump_path)), 2) + model1_construct_json = load_json(os.path.join(model1_dump_path, "step2", "rank", "construct.json")) + self.assertEqual(len(model1_construct_json), 1) + + model2_dump_path = os.path.join(dump_path, "model2") + self.assertTrue(os.path.exists(model2_dump_path)) + self.assertEqual(len(os.listdir(model2_dump_path)), 2) + model2_construct_json = load_json(os.path.join(model2_dump_path, "step2", "rank", "construct.json")) + self.assertEqual(len(model2_construct_json), 1) + + shutil.rmtree(dump_path) + diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_precision_debugger.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_precision_debugger.py index a2f3e8a816e356b68e598138b30a9e14b42107d9..69df2c8923cbd9692f27ea2ebbaac8c01409a19a 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_precision_debugger.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_precision_debugger.py @@ -8,9 +8,12 @@ import torch from msprobe.core.common.const import Const, MsgConst from msprobe.core.common.utils import get_real_step_or_rank from msprobe.core.common.exceptions import MsprobeException, FileCheckException -from msprobe.pytorch.debugger.precision_debugger import PrecisionDebugger, iter_tracer +from msprobe.pytorch.debugger.precision_debugger import PrecisionDebugger from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor from msprobe.test.pytorch_ut.grad_probe.test_grad_monitor import common_config, task_config +from msprobe.core.common_config import CommonConfig +from msprobe.core.debugger.precision_debugger import BasePrecisionDebugger +from msprobe.pytorch.pt_config import StatisticsConfig, GradToolConfig class Args: @@ -23,6 +26,29 @@ class Args: class TestPrecisionDebugger(unittest.TestCase): + grad_json_config = { + "task": Const.GRAD_PROBE, + "dump_path": "/absolute_path", + "rank": [], + "step": [], + "level": "L1", + "async_dump": False + } + + grad_common_config = CommonConfig(grad_json_config) + grad_task_config = GradToolConfig(grad_json_config) + + json_config = { + "task": "statistics", + "dump_path": "/absolute_path", + "rank": [], + "step": [], + "level": "L1", + "async_dump": False + } + + statistics_common_config = CommonConfig(json_config) + statistics_task_config = StatisticsConfig(json_config) def test_init(self): gm = GradientMonitor(common_config, task_config) @@ -30,43 +56,43 @@ class TestPrecisionDebugger(unittest.TestCase): step = get_real_step_or_rank([0, 1, "3-5"], Const.STEP) self.assertListEqual(step, [0, 1, 3, 4, 5]) - def test_instance(self): - debugger1 = PrecisionDebugger(dump_path="./dump_path") - debugger2 = PrecisionDebugger(dump_path="./dump_path") - self.assertIs(debugger1.instance, debugger2.instance) - def test_check_input_params(self): - args = Args(config_path = 1) + args = Args(config_path=1) with self.assertRaises(MsprobeException) as context: - PrecisionDebugger.check_input_params(args) + PrecisionDebugger._check_input_params(args.config_path, args.task, args.dump_path, args.level) self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR) - args = Args(config_path = "./") + args = Args(config_path="./") with self.assertRaises(FileCheckException) as context: - PrecisionDebugger.check_input_params(args) + PrecisionDebugger._check_input_params(args.config_path, args.task, args.dump_path, args.level) self.assertEqual(context.exception.code, FileCheckException.INVALID_FILE_ERROR) - args = Args(task = 1) + args = Args(task=1) with self.assertRaises(MsprobeException) as context: - PrecisionDebugger.check_input_params(args) + PrecisionDebugger._check_input_params(args.config_path, args.task, args.dump_path, args.level) self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR) - args = Args(dump_path = 1) + args = Args(dump_path=1) with self.assertRaises(MsprobeException) as context: - PrecisionDebugger.check_input_params(args) + PrecisionDebugger._check_input_params(args.config_path, args.task, args.dump_path, args.level) self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR) - args = Args(level = 1) + args = Args(level=1) with self.assertRaises(MsprobeException) as context: - PrecisionDebugger.check_input_params(args) + PrecisionDebugger._check_input_params(args.config_path, args.task, args.dump_path, args.level) self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR) - args = Args(config_path = os.path.join(os.path.dirname(__file__), "../../../config.json"), - task = Const.TASK_LIST[0], - dump_path="./dump_path", - level = Const.LEVEL_LIST[0], - model = torch.nn.Module()) - checked_input_params = PrecisionDebugger.check_input_params(args) + args = Args(config_path=os.path.join(os.path.dirname(__file__), "../../../config.json"), + task=Const.TASK_LIST[0], + dump_path="./dump_path", + level=Const.LEVEL_LIST[0], + model=torch.nn.Module()) + checked_input_params = PrecisionDebugger._check_input_params( + args.config_path, + args.task, + args.dump_path, + args.level + ) self.assertIsNone(checked_input_params) def test_start_grad_probe(self): @@ -75,12 +101,17 @@ class TestPrecisionDebugger(unittest.TestCase): PrecisionDebugger.start() self.assertEqual(str(context.exception), MsgConst.NOT_CREATED_INSTANCE) - PrecisionDebugger._instance = PrecisionDebugger(task=Const.GRAD_PROBE, dump_path="./dump_path") + with patch.object(BasePrecisionDebugger, "_parse_config_path", + return_value=(self.grad_common_config, self.grad_task_config)): + PrecisionDebugger._instance = PrecisionDebugger(task=Const.GRAD_PROBE, dump_path="./dump_path") checked_start = PrecisionDebugger.start() self.assertIsNone(checked_start) def test_start_statistics(self): - debugger = PrecisionDebugger(dump_path="./dump_path") + PrecisionDebugger._instance = None + with patch.object(BasePrecisionDebugger, "_parse_config_path", + return_value=(self.statistics_common_config, self.statistics_task_config)): + debugger = PrecisionDebugger(dump_path="./dump_path") debugger.service = MagicMock() debugger.config = MagicMock() debugger.task = 'statistics' @@ -88,7 +119,12 @@ class TestPrecisionDebugger(unittest.TestCase): debugger.service.start.assert_called_once() def test_forward_backward_dump_end(self): - debugger = PrecisionDebugger(dump_path="./dump_path") + with patch.object( + BasePrecisionDebugger, + "_parse_config_path", + return_value=(self.statistics_common_config,self.statistics_task_config) + ): + debugger = PrecisionDebugger(dump_path="./dump_path", task='statistics') debugger.service = MagicMock() debugger.config = MagicMock() debugger.task = 'statistics' @@ -101,11 +137,14 @@ class TestPrecisionDebugger(unittest.TestCase): PrecisionDebugger.stop() self.assertEqual(str(context.exception), MsgConst.NOT_CREATED_INSTANCE) - PrecisionDebugger._instance = PrecisionDebugger(task=Const.GRAD_PROBE, dump_path="./dump_path") + with patch.object(BasePrecisionDebugger, "_parse_config_path", + return_value=(self.grad_common_config, self.grad_task_config)): + PrecisionDebugger._instance = PrecisionDebugger(task=Const.GRAD_PROBE, dump_path="./dump_path") checked_stop = PrecisionDebugger.stop() self.assertIsNone(checked_stop) def test_stop_statistics(self): + PrecisionDebugger._instance = None debugger = PrecisionDebugger(dump_path="./dump_path") debugger.service = MagicMock() debugger.task = '' @@ -117,8 +156,9 @@ class TestPrecisionDebugger(unittest.TestCase): PrecisionDebugger._instance = None PrecisionDebugger.step() self.assertEqual(str(context.exception), MsgConst.NOT_CREATED_INSTANCE) - - PrecisionDebugger._instance = PrecisionDebugger(task=Const.GRAD_PROBE, dump_path="./dump_path") + with patch.object(BasePrecisionDebugger, "_parse_config_path", + return_value=(self.grad_common_config, self.grad_task_config)): + PrecisionDebugger._instance = PrecisionDebugger(task=Const.GRAD_PROBE, dump_path="./dump_path") checked_step = PrecisionDebugger.step() self.assertIsNone(checked_step) @@ -135,7 +175,12 @@ class TestPrecisionDebugger(unittest.TestCase): PrecisionDebugger.monitor(torch.nn.Module()) self.assertEqual(str(context.exception), MsgConst.NOT_CREATED_INSTANCE) - debugger = PrecisionDebugger(task=Const.STATISTICS, dump_path="./dump_path") + with patch.object( + BasePrecisionDebugger, + "_parse_config_path", + return_value=(self.statistics_common_config, self.statistics_task_config) + ): + debugger = PrecisionDebugger(task=Const.STATISTICS, dump_path="./dump_path") checked_monitor = debugger.monitor(torch.nn.Module()) self.assertIsNone(checked_monitor) @@ -146,40 +191,57 @@ class TestPrecisionDebugger(unittest.TestCase): debugger.gm.monitor(torch.nn.Module()) debugger.gm.monitor.assert_called_once() - @patch('msprobe.pytorch.debugger.precision_debugger.PrecisionDebugger') - def test_iter_tracer(self, mock_debugger): - mock_debugger_instance = mock_debugger.instance = MagicMock() - mock_debugger_instance.service.first_start = False - - @iter_tracer - def dataloader_func(): - return "test_iter_tracer" - result = dataloader_func() - self.assertEqual(result, "test_iter_tracer") - - mock_debugger_instance.stop.assert_called_once() - mock_debugger_instance.step.assert_called_once() - mock_debugger_instance.start.assert_called_once() - self.assertTrue(mock_debugger_instance.enable_dataloader) - - @patch('msprobe.pytorch.debugger.precision_debugger.PrecisionDebugger') - def test_iter_tracer_first_start(self, mock_debugger): - mock_debugger_instance = mock_debugger.instance = MagicMock() - mock_debugger_instance.service.first_start = True - - @iter_tracer - def dataloader_func(): - return "test_iter_tracer" - result = dataloader_func() - self.assertEqual(result, "test_iter_tracer") - - mock_debugger_instance.stop.assert_not_called() - mock_debugger_instance.step.assert_not_called() - mock_debugger_instance.start.assert_called_once() - self.assertTrue(mock_debugger_instance.enable_dataloader) - def tearDown(self): if os.path.exists("./dump_path/"): shutil.rmtree("./dump_path/") if os.path.exists("./grad_output/"): shutil.rmtree("./grad_output/") + + +class TestIterTracer(unittest.TestCase): + def setUp(self): + self.debugger = MagicMock() + self.debugger.service.first_start = False + self.debugger.enable_dataloader = True + self.ori_instance = PrecisionDebugger._instance + PrecisionDebugger._instance = self.debugger + + def tearDown(self): + PrecisionDebugger._instance = self.ori_instance + + def test_debugger_with_not_first_start(self): + @PrecisionDebugger._iter_tracer + def test_func(): + return "test case 1" + + result = test_func() + + self.assertEqual(result, "test case 1") + self.debugger.stop.assert_called_once() + self.debugger.step.assert_called_once() + self.debugger.start.assert_called_once() + + def test_debugger_with_first_start(self): + self.debugger.service.first_start = True + + @PrecisionDebugger._iter_tracer + def test_func(): + return "test case 2" + + result = test_func() + self.assertEqual(result, "test case 2") + self.debugger.stop.assert_not_called() + self.debugger.step.assert_not_called() + self.debugger.start.assert_called_once() + + def test_no_debugger_instance(self): + PrecisionDebugger._instance = None + + @PrecisionDebugger._iter_tracer + def test_func(): + return "test case 3" + + with self.assertRaises(MsprobeException) as context: + result = test_func() + self.assertEqual(result, "test case 3") + self.assertEqual(context.exception.code, MsprobeException.INTERFACE_USAGE_ERROR) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger_save/test_debugger_save_pytorch.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger_save/test_debugger_save_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..ea2350295206fb475e106f03e13afdeeba25289c --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger_save/test_debugger_save_pytorch.py @@ -0,0 +1,418 @@ +import unittest +import os +import json +import torch +import numpy as np +import shutil + +from msprobe.pytorch import PrecisionDebugger + +current_file = __file__ +parent_dir = os.path.abspath(os.path.dirname(current_file)) +test_dir = os.path.join(parent_dir, "test_dir") + +def deep_compare(obj1, obj2, float_tolerance=1e-5): + """ + Recursively compare two objects to check if they are the same. + Supports nested dictionaries and lists. + """ + if type(obj1) != type(obj2): + return False + if isinstance(obj1, dict): + if obj1.keys() != obj2.keys(): + return False + return all(deep_compare(obj1[key], obj2[key]) for key in obj1) + if isinstance(obj1, (tuple, list)): + if len(obj1) != len(obj2): + return False + return all(deep_compare(item1, item2) for item1, item2 in zip(obj1, obj2)) + if isinstance(obj1, (int, float)): + return abs(obj1 - obj2) < float_tolerance + return obj1 == obj2 + +class TestDebuggerSave(unittest.TestCase): + @staticmethod + def write_config_json(step, async_dump, mode, dump_path, config_file_path): + task = "tensor" if mode == "tensor" else "statistics" + statistics_summary_mode = "statistics" if mode == "statistics" else "md5" + config = { + "task": task, + "dump_path": dump_path, + "rank": [], + "step": step, + "level": "debug", + "enable_dataloader": False, + "async_dump": async_dump, + "statistics": { + "summary_mode": statistics_summary_mode, + } + } + with open(config_file_path, "w", encoding="utf-8") as f: + json.dump(config, f, indent=4, ensure_ascii=False) + + @staticmethod + def read_debug_json_into_dict(debug_json_path): + with open(debug_json_path, "r", encoding="utf-8") as f: + debug_json = json.load(f) + return debug_json + + + @staticmethod + def check_real_pt(pt_path, target_pt_tensor, check_values=True, rtol=1e-5, atol=1e-8): + """ + Enhanced version with optional value comparison. + + Args: + pt_path (str): Path to the .pt file + target_pt_tensor: Target torch tensor to compare + check_values (bool): If True, also compare array values + rtol, atol: Relative and absolute tolerances for value comparison + + Returns: + bool: True if all checks pass + """ + # Load the pt file + try: + pt_data = torch.load(pt_path) + except FileNotFoundError: + print(f"Error: The file {pt_path} does not exist.") + return False + except Exception as e: + print(f"Error loading pt file: {e}") + return False + # Check shapes + if pt_data.shape != target_pt_tensor.shape: + print(f"Shape mismatch: pt data shape is {pt_data.shape}, target tensor shape is {target_pt_tensor.shape}") + return False + # Check dtypes + if pt_data.dtype != target_pt_tensor.dtype: + print(f"Shape mismatch: pt data dtype is {pt_data.dtype}, target tensor dtype is {target_pt_tensor.dtype}") + return False + # Optionally check values + if check_values: + if not torch.allclose(pt_data, target_pt_tensor, rtol=rtol, atol=atol): + print("Value mismatch: pt data and target tensor values do not match within the specified tolerances.") + return False + return True + + def setUp(self): + if not os.path.exists(test_dir): + os.makedirs(test_dir) + PrecisionDebugger._instance = None + + def tearDown(self): + if os.path.exists(test_dir): + shutil.rmtree(test_dir) + PrecisionDebugger._instance = None + + def test_save_real_tensor(self): + data = {"a": torch.Tensor([1., 2.])} + step = [] + async_dump = False + mode = "tensor" + dump_path = os.path.join(test_dir, "debug_save") + config_file_path = os.path.join(test_dir, "config.json") + self.write_config_json(step, async_dump, mode, dump_path, config_file_path) + debugger = PrecisionDebugger(config_file_path) + PrecisionDebugger.save(data, "data_dict", save_backward=False) + PrecisionDebugger.step() + # check pt file + pt_path = os.path.join(dump_path, "step0", "rank", "dump_tensor_data", "data_dict.0.debug.a.pt") + assert self.check_real_pt(pt_path, data["a"]) + # check debug json + target_debug_info = { + "a": { + "type": "torch.Tensor", + "dtype": "torch.float32", + "shape": [ + 2 + ], + "Max": 2.0, + "Min": 1.0, + "Mean": 1.5, + "Norm": 2.2360680103302, + "requires_grad": False, + "data_name": "data_dict.0.debug.a.pt" + } + } + debug_json_path = os.path.join(dump_path, "step0", "rank", "debug.json") + debug_json_dict = self.read_debug_json_into_dict(debug_json_path) + assert deep_compare(debug_json_dict["data"]["data_dict.0.debug"], target_debug_info) + + def test_save_md5(self): + data = {"a": torch.Tensor([1., 2.])} + step = [] + async_dump = False + mode = "md5" + dump_path = os.path.join(test_dir, "debug_save") + config_file_path = os.path.join(test_dir, "config.json") + self.write_config_json(step, async_dump, mode, dump_path, config_file_path) + debugger = PrecisionDebugger(config_file_path) + PrecisionDebugger.save(data, "data_dict", save_backward=False) + PrecisionDebugger.step() + # check debug json + target_debug_info = { + "a": { + "type": "torch.Tensor", + "dtype": "torch.float32", + "shape": [ + 2 + ], + "Max": 2.0, + "Min": 1.0, + "Mean": 1.5, + "Norm": 2.2360680103302, + "requires_grad": False, + "md5": "2e3fa576" + } + } + debug_json_path = os.path.join(dump_path, "step0", "rank", "debug.json") + debug_json_dict = self.read_debug_json_into_dict(debug_json_path) + assert deep_compare(debug_json_dict["data"]["data_dict.0.debug"], target_debug_info) + + def test_save_multiple_steps(self): + data = {"a": torch.Tensor([1., 2.])} + step = [0, 1, 2] + async_dump = False + mode = "tensor" + dump_path = os.path.join(test_dir, "debug_save") + config_file_path = os.path.join(test_dir, "config.json") + self.write_config_json(step, async_dump, mode, dump_path, config_file_path) + debugger = PrecisionDebugger(config_file_path) + for _ in step: + PrecisionDebugger.save(data, "data_dict", save_backward=False) + PrecisionDebugger.step() + # check pt file + for i in step: + pt_path = os.path.join(dump_path, f"step{i}", "rank", "dump_tensor_data", "data_dict.0.debug.a.pt") + assert self.check_real_pt(pt_path, data["a"]) + # check debug json + target_debug_info = { + "a": { + "type": "torch.Tensor", + "dtype": "torch.float32", + "shape": [ + 2 + ], + "Max": 2.0, + "Min": 1.0, + "Mean": 1.5, + "Norm": 2.2360680103302, + "requires_grad": False, + "data_name": "data_dict.0.debug.a.pt" + } + } + for i in step: + debug_json_path = os.path.join(dump_path, f"step{i}", "rank", "debug.json") + debug_json_dict = self.read_debug_json_into_dict(debug_json_path) + assert deep_compare(debug_json_dict["data"]["data_dict.0.debug"], target_debug_info) + + def test_async_save_tensor(self): + data = {"a": torch.Tensor([1., 2.])} + step = [] + async_dump = True + mode = "tensor" + dump_path = os.path.join(test_dir, "debug_save") + config_file_path = os.path.join(test_dir, "config.json") + + self.write_config_json(step, async_dump, mode, dump_path, config_file_path) + debugger = PrecisionDebugger(config_file_path) + PrecisionDebugger.save(data, "data_dict", save_backward=False) + PrecisionDebugger.step() + + # check pt file + pt_path = os.path.join(dump_path, "step0", "rank", "dump_tensor_data", "data_dict.0.debug.a.pt") + assert self.check_real_pt(pt_path, data["a"]) + + # check debug json + target_debug_info = { + "a": { + "type": "torch.Tensor", + "dtype": "torch.float32", + "shape": [ + 2 + ], + "data_name": "data_dict.0.debug.a.pt", + "Max": 2.0, + "Min": 1.0, + "Mean": 1.5, + "Norm": 2.2360680103302, + "requires_grad": False, + } + } + debug_json_path = os.path.join(dump_path, "step0", "rank", "debug.json") + debug_json_dict = self.read_debug_json_into_dict(debug_json_path) + assert deep_compare(debug_json_dict["data"]["data_dict.0.debug"], target_debug_info) + + def test_save_multiple_times(self): + data = {"a": torch.Tensor([1., 2.])} + step = [] + call_times = 3 + async_dump = False + mode = "tensor" + dump_path = os.path.join(test_dir, "debug_save") + config_file_path = os.path.join(test_dir, "config.json") + + self.write_config_json(step, async_dump, mode, dump_path, config_file_path) + debugger = PrecisionDebugger(config_file_path) + for _ in range(call_times): + PrecisionDebugger.save(data, "data_dict", save_backward=False) + PrecisionDebugger.step() + + # check pt file + for i in range(call_times): + pt_path = os.path.join(dump_path, "step0", "rank", "dump_tensor_data", f"data_dict.{i}.debug.a.pt") + assert self.check_real_pt(pt_path, data["a"]) + + # check debug json + for i in range(call_times): + target_debug_info = { + "a": { + "type": "torch.Tensor", + "dtype": "torch.float32", + "shape": [ + 2 + ], + "Max": 2.0, + "Min": 1.0, + "Mean": 1.5, + "Norm": 2.2360680103302, + "requires_grad": False, + "data_name": f"data_dict.{i}.debug.a.pt" + } + } + + debug_json_path = os.path.join(dump_path, "step0", "rank", "debug.json") + debug_json_dict = self.read_debug_json_into_dict(debug_json_path) + assert deep_compare(debug_json_dict["data"][f"data_dict.{i}.debug"], target_debug_info) + + def test_save_backward(self): + x = torch.Tensor([1., 2.]) + target_x_grad = torch.Tensor([1., 1.]) + def _forward_simple_func(x): + PrecisionDebugger.save(x, "x_tensor") + return x.sum() + step = [] + async_dump = False + mode = "tensor" + dump_path = os.path.join(test_dir, "debug_save") + config_file_path = os.path.join(test_dir, "config.json") + self.write_config_json(step, async_dump, mode, dump_path, config_file_path) + debugger = PrecisionDebugger(config_file_path) + x.requires_grad = True + loss = _forward_simple_func(x) + loss.backward() + PrecisionDebugger.step() + x_info_list = [ + x, + os.path.join(dump_path, "step0", "rank", "dump_tensor_data", "x_tensor.0.debug.pt"), + "x_tensor.0.debug", + { + "type": "torch.Tensor", + "dtype": "torch.float32", + "shape": [ + 2 + ], + "Max": 2.0, + "Min": 1.0, + "Mean": 1.5, + "Norm": 2.2360680103302, + "requires_grad": True, + "data_name": "x_tensor.0.debug.pt" + }, + ] + x_grad_info_list = [ + target_x_grad, + os.path.join(dump_path, "step0", "rank", "dump_tensor_data", "x_tensor_grad.0.debug.pt"), + "x_tensor_grad.0.debug", + { + "type": "torch.Tensor", + "dtype": "torch.float32", + "shape": [ + 2 + ], + "Max": 1.0, + "Min": 1.0, + "Mean": 1.0, + "Norm": 1.4142135381698608, + "requires_grad": False, + "data_name": "x_tensor_grad.0.debug.pt" + }, + ] + check_list = [x_info_list, x_grad_info_list] + debug_json_path = os.path.join(dump_path, "step0", "rank", "debug.json") + debug_json_dict = self.read_debug_json_into_dict(debug_json_path) + for check_info in check_list: + target_tensor, target_tensor_path, target_tensor_key, target_tensor_info = check_info + assert self.check_real_pt(target_tensor_path, target_tensor) + assert deep_compare(debug_json_dict["data"][target_tensor_key], target_tensor_info) + + def test_save_compilcated_data_structure_backward(self): + x = torch.Tensor([1., 2.]) + target_x_grad = torch.Tensor([1., 1.]) + def _forward_complicated_func(x): + complicated_structure = [{"a_key": x}] + PrecisionDebugger.save(complicated_structure, "complicated_structure") + return complicated_structure[0]["a_key"].sum() + step = [] + async_dump = False + mode = "tensor" + dump_path = os.path.join(test_dir, "debug_save") + config_file_path = os.path.join(test_dir, "config.json") + self.write_config_json(step, async_dump, mode, dump_path, config_file_path) + debugger = PrecisionDebugger(config_file_path) + x.requires_grad = True + loss = _forward_complicated_func(x) + loss.backward() + PrecisionDebugger.step() + complicated_structure_info_list = [ + x, + os.path.join(dump_path, "step0", "rank", "dump_tensor_data", "complicated_structure.0.debug.0.a_key.pt"), + "complicated_structure.0.debug", + [ + { + "a_key": { + "type": "torch.Tensor", + "dtype": "torch.float32", + "shape": [ + 2 + ], + "Max": 2.0, + "Min": 1.0, + "Mean": 1.5, + "Norm": 2.2360680103302, + "requires_grad": True, + "data_name": "complicated_structure.0.debug.0.a_key.pt" + } + } + ], + ] + complicated_structure_grad_info_list = [ + target_x_grad, + os.path.join(dump_path, "step0", "rank", "dump_tensor_data", "complicated_structure_grad.0.debug.0.a_key.pt"), + "complicated_structure_grad.0.debug", + [ + { + "a_key": { + "type": "torch.Tensor", + "dtype": "torch.float32", + "shape": [ + 2 + ], + "Max": 1.0, + "Min": 1.0, + "Mean": 1.0, + "Norm": 1.4142135381698608, + "requires_grad": False, + "data_name": "complicated_structure_grad.0.debug.0.a_key.pt" + } + } + ], + ] + check_list = [complicated_structure_info_list, complicated_structure_grad_info_list] + debug_json_path = os.path.join(dump_path, "step0", "rank", "debug.json") + debug_json_dict = self.read_debug_json_into_dict(debug_json_path) + for check_info in check_list: + target_tensor, target_tensor_path, target_tensor_key, target_tensor_info = check_info + assert self.check_real_pt(target_tensor_path, target_tensor) + assert deep_compare(debug_json_dict["data"][target_tensor_key], target_tensor_info) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py index 63d6abc3a2430bb6f092820c4b97a02cdf675612..4ba3556c277f3326520547a6124170f32a9cc8e8 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py @@ -16,45 +16,68 @@ import unittest from unittest.mock import patch, MagicMock -import torch -import torch.nn as nn -from msprobe.pytorch import PrecisionDebugger -from msprobe.pytorch.hook_module.api_registry import api_register -from msprobe.pytorch.service import torch_version_above_or_equal_2 +from torch import nn + +from msprobe.pytorch.common.log import logger +from msprobe.pytorch.dump.module_dump.module_dump import ModuleDumper +from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser class TestModuleDumper(unittest.TestCase): - @classmethod - def setUpClass(cls): - PrecisionDebugger._instance = None - api_register.api_originality() + def setUp(self): + self.service = MagicMock() + with patch('msprobe.pytorch.dump.module_dump.module_dump.get_api_register'): + self.module_dumper = ModuleDumper(self.service) - @classmethod - def tearDownClass(cls): - PrecisionDebugger._instance = None - api_register.api_originality() + def test__init__(self): + self.service = MagicMock() + with patch('msprobe.pytorch.dump.module_dump.module_dump.get_api_register') as mock_get_api_register: + self.module_dumper = ModuleDumper(self.service) + self.assertEqual(self.module_dumper.service, self.service) + mock_get_api_register.assert_called_once() - def setUp(self): - self.module = nn.Linear(8, 4) - debugger = PrecisionDebugger(dump_path="./") - self.module_dumper = debugger.module_dumper + def test_start_module_dump(self): + module = nn.Module() + with patch.object(logger, 'info_on_rank_0') as mock_info: + module.msprobe_hook = True + ModuleProcesser.enable_module_dump = False + self.module_dumper.api_register.restore_all_api.reset_mock() + self.module_dumper.start_module_dump(module, 'dump_name') + mock_info.assert_called_with('The init dump is enabled, and the module dump function will not be available.') + self.assertFalse(ModuleProcesser.enable_module_dump) + self.module_dumper.api_register.restore_all_api.assert_not_called() + self.assertFalse(hasattr(module, 'msprobe_module_dump')) + + del module.msprobe_hook + mock_info.reset_mock() + self.module_dumper.start_module_dump(module, 'dump_name') + mock_info.assert_not_called() + self.assertTrue(ModuleProcesser.enable_module_dump) + self.module_dumper.api_register.restore_all_api.assert_called_once() + self.module_dumper.service.module_processor.register_module_hook.assert_called_with( + module, + self.module_dumper.service.build_hook, + recursive=False, + module_names=['dump_name'] + ) + self.assertTrue(module.msprobe_module_dump) + ModuleProcesser.enable_module_dump = False + + self.module_dumper.api_register.restore_all_api.reset_mock() + self.module_dumper.service.module_processor.register_module_hook.reset_mock() + self.module_dumper.start_module_dump(module, 'dump_name') + mock_info.assert_not_called() + self.assertTrue(ModuleProcesser.enable_module_dump) + self.module_dumper.api_register.restore_all_api.assert_called_once() + self.module_dumper.service.module_processor.register_module_hook.assert_not_called() + + ModuleProcesser.enable_module_dump = False def test_stop_module_dump(self): - self.module_dumper.hook_handle_list.extend([1, 2, 3]) - with patch('msprobe.pytorch.dump.module_dump.module_dump.api_register') as mock_api_register: - mock_handle1 = MagicMock(spec=torch.utils.hooks.RemovableHandle) - mock_handle2 = MagicMock(spec=torch.utils.hooks.RemovableHandle) - self.module_dumper.hook_handle_list.extend([mock_handle1, mock_handle2]) - - self.module_dumper.stop_module_dump() - mock_handle1.remove.assert_called_once() - mock_handle2.remove.assert_called_once() - self.assertEqual(self.module_dumper.hook_handle_list, []) - mock_api_register.api_modularity.assert_called_once() - - def test_register_hook(self): - self.module_dumper.register_hook(self.module, "TestModule") - if torch_version_above_or_equal_2: - self.assertEqual(len(self.module_dumper.hook_handle_list), 6) - else: - self.assertEqual(len(self.module_dumper.hook_handle_list), 5) + ModuleProcesser.enable_module_dump = True + self.module_dumper.api_register.register_all_api.reset_mock() + self.module_dumper.stop_module_dump() + self.assertFalse(ModuleProcesser.enable_module_dump) + self.module_dumper.api_register.register_all_api.assert_called_once() + + self.module_dumper.api_register.register_all_api.reset_mock() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_processer.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_processer.py index f8a561b61b6a758a525675bdc59957e5c923b261..50c43288d2638b6a636085ef1acc9b568cdd6aaf 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_processer.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_processer.py @@ -1,104 +1,306 @@ -import unittest -from unittest.mock import MagicMock +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from io import StringIO +import threading +import unittest +from unittest.mock import patch, MagicMock import torch +import msprobe.pytorch.dump.module_dump.module_processer as mp from msprobe.core.data_dump.scope import ModuleRangeScope -from msprobe.pytorch.common.utils import Const -from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser +from msprobe.pytorch.dump.module_dump.module_processer import ( + ModuleProcesser, + wrap_megatron_deallocate, + wrap_forward_with_hook_safety +) +from torch.utils.checkpoint import _StopRecomputationError +ori_checkpoint = torch.utils.checkpoint.checkpoint -class TestModuleProcesser(unittest.TestCase): +class TestModule(torch.nn.Module): + """测试用的模块类,可控制是否抛出异常""" + + def __init__(self, raise_exception=False): + super().__init__() + self.raise_exception = raise_exception + + def forward(self, x, *args, **kwargs): + if self.raise_exception: + raise _StopRecomputationError() + return x * 2 + + +def forward_hook_fn(module, args, kwargs_or_output, output_or_kwargs=None): + print(f"The forward_hook executed normally.") + + +class TestWrapper(unittest.TestCase): + def setUp(self): + torch.utils.checkpoint.checkpoint = ori_checkpoint + self.held_output = StringIO() + self.original_stdout = sys.stdout + sys.stdout = self.held_output + + def tearDown(self): + """恢复标准输出""" + sys.stdout = self.original_stdout + + def get_output(self): + """获取捕获的输出内容""" + return self.held_output.getvalue().strip() + + def test_wrap_megatron_deallocate(self): + mock_func = MagicMock(return_value="output_test") + wrapped = wrap_megatron_deallocate(mock_func) + + mock_tensor = MagicMock(spec=torch.Tensor) + mock_tensor._base = True + mock_tensor.device = "cpu" + mock_tensor.dtype = torch.float32 + mock_tensor.clone.return_value = "cloned" + + result = wrapped(mock_tensor, deallocate_pipeline_outputs=True) + mock_tensor.clone.assert_called_once() + self.assertEqual(mock_tensor.data.shape, (1,)) + self.assertEqual(result, "output_test") + mock_func.assert_called_once_with("cloned", True) + + result = wrapped("normal_input", False) + self.assertEqual(result, "output_test") + mock_func.assert_called_with("normal_input", False) + + def test_normal_forward_execution(self): + """测试正常执行forward时的情况""" + # 准备测试模块和hook + module = TestModule(raise_exception=False) + module.register_forward_hook(forward_hook_fn) + + # 应用包装函数 + wrap_forward_with_hook_safety(module) + + # 执行forward + input_tensor = torch.tensor(3.0) + output = module(input_tensor) + + # 验证结果和hook调用 + self.assertEqual(output.item(), 6.0) + self.assertIn("The forward_hook executed normally.", self.get_output()) + + def test_stop_recomputation_exception_triggers_hook(self): + """测试抛出_StopRecomputationError时hook被调用""" + # 准备测试模块和hook + module = TestModule(raise_exception=True) + module.register_forward_hook(forward_hook_fn) + + # 应用包装函数 + wrap_forward_with_hook_safety(module) + + # 执行forward并验证异常 + input_tensor = torch.tensor(3.0) + with self.assertRaises(_StopRecomputationError): + module(input_tensor) + + self.assertIn("The forward_hook executed normally.", self.get_output()) + + +class TestModuleProcesser(unittest.TestCase): def setUp(self): - self.mock_tensor = MagicMock(spec=torch.Tensor) + ModuleProcesser.module_count = {} + ModuleProcesser.module_stack = {} + ModuleProcesser.module_node = {} + ModuleProcesser.api_parent_node = {} + + self.scope = ModuleRangeScope([], []) self.mock_scope = MagicMock() - self.processor = ModuleProcesser(self.mock_scope) - - def test_scope_is_module_range_scope(self): - scope = ModuleRangeScope([], []) - processor = ModuleProcesser(scope) - self.assertEqual(processor.scope, scope) - - def test_scope_is_not_module_range_scope(self): - scope = "not a ModuleRangeScope" - processor = ModuleProcesser(scope) - self.assertIsNone(processor.scope) - - def test_clone_return_value_and_test_clone_if_tensor(self): - def func(x): - return x - - input = torch.tensor([1]) - input_tuple = (torch.tensor([1]), torch.tensor([2])) - input_list = [torch.tensor([1]), torch.tensor([2])] - input_dict = {"A": torch.tensor([1]), "B": torch.tensor([2])} - - result = ModuleProcesser.clone_return_value(func)(input) - result[0] = 2 - self.assertNotEqual(result, input) - result_tuple = ModuleProcesser.clone_return_value(func)(input_tuple) - result_tuple[0][0] = 2 - self.assertNotEqual(result_tuple, input_tuple) - result_list = ModuleProcesser.clone_return_value(func)(input_list) - result_list[0][0] = 2 - self.assertNotEqual(result_list, input_list) - result_dict = ModuleProcesser.clone_return_value(func)(input_dict) - result_dict["A"][0] = 2 - self.assertNotEqual(result_dict, input_dict) - - def test_module_count_func(self): - test = ModuleProcesser(None) - self.assertEqual(test.module_count, {}) - module_name = "nope" - test.module_count_func(module_name) - self.assertEqual(test.module_count["nope"], 0) - - def test_node_hook_forward_start(self): - name_prefix = "forward_layer" - hook = self.processor.node_hook(name_prefix, start_or_stop=Const.START) - module = MagicMock() - input = (self.mock_tensor,) - module.mindstudio_reserved_name = None - hook(module, input) - expected_name = f"forward_layer{Const.SEP}0" - self.assertEqual(module.mindstudio_reserved_name, [expected_name]) - self.assertIn(expected_name, ModuleProcesser.module_stack) - self.assertEqual(ModuleProcesser.api_parent_node, expected_name) - - def test_node_hook_forward_stop(self): - name_prefix = "forward_layer" - hook = self.processor.node_hook(name_prefix, start_or_stop=Const.STOP) - ModuleProcesser.module_stack.append(f"forward_layer{Const.SEP}0") - - module = MagicMock() - input = (self.mock_tensor,) - reserved_name = f"forward_layer{Const.SEP}0" - module.mindstudio_reserved_name = [reserved_name] - hook(module, input) - self.assertNotIn([f"forward_layer{Const.SEP}0"], ModuleProcesser.module_stack) - self.assertEqual(ModuleProcesser.api_parent_node, reserved_name) - - def test_node_hook_backward(self): - name_prefix = "backward_layer" - hook = self.processor.node_hook(name_prefix, start_or_stop=Const.START) - - module = MagicMock() - input = (self.mock_tensor,) - module.mindstudio_reserved_name = None - ModuleProcesser.module_node[f"forward_layer{Const.SEP}0"] = None - hook(module, input) - expected_name = f"backward_layer{Const.SEP}0" - self.assertEqual(module.mindstudio_reserved_name, [expected_name]) - self.assertIn(expected_name, ModuleProcesser.module_node) + + @patch('msprobe.pytorch.dump.module_dump.module_processer.wrap_setup_input_output_hook') + def test_init_with_valid_scope(self, mock_wrap): + processor = ModuleProcesser(self.scope) + self.assertEqual(processor.scope, self.scope) + mock_wrap.assert_called_once() + + @patch('msprobe.pytorch.dump.module_dump.module_processer.logger.info_on_rank_0') + def test_init_without_megatron(self, mock_log): + ModuleProcesser(self.scope) + mock_log.assert_called_with("No megatron find.") + + def test_set_and_get_calls_number(self): + count = ModuleProcesser.set_and_get_calls_number("test_module") + self.assertEqual(count, 0) + + count = ModuleProcesser.set_and_get_calls_number("test_module") + self.assertEqual(count, 1) def test_has_register_backward_hook(self): - module = MagicMock() - module._backward_hooks = {0: lambda: None} - module._is_full_backward_hook = False - result = self.processor.has_register_backward_hook(module) - self.assertTrue(result) - - module._is_full_backward_hook = True - result = self.processor.has_register_backward_hook(module) - self.assertFalse(result) + module1 = torch.nn.Linear(10, 10) + self.assertFalse(ModuleProcesser.has_register_backward_hook(module1)) + + module2 = MagicMock() + module2._backward_hooks = [1, 2, 3] + module2._is_full_backward_hook = False + self.assertTrue(ModuleProcesser.has_register_backward_hook(module2)) + + module2._is_full_backward_hook = True + self.assertFalse(ModuleProcesser.has_register_backward_hook(module2)) + + def test_get_modules_and_names_with_model_list(self): + mock_model1 = MagicMock() + mock_model2 = MagicMock() + mock_model1.named_modules.return_value = [("layer1", "obj1"), ("layer2", "obj2")] + mock_model2.named_modules.return_value = [("layer3", "obj3")] + + result = ModuleProcesser.get_modules_and_names( + [mock_model1, mock_model2], + recursive=True, + module_names=["model1", "model2"] + ) + self.assertEqual(result, { + "0": [("layer1", "obj1"), ("layer2", "obj2")], + "1": [("layer3", "obj3")] + }) + + def test_get_modules_and_names_with_model_tuple(self): + mock_model1 = MagicMock() + mock_model2 = MagicMock() + mock_model1.named_modules.return_value = [("layer1", "obj1")] + mock_model2.named_modules.return_value = [("layer2", "obj2")] + + result = ModuleProcesser.get_modules_and_names( + (mock_model1, mock_model2), + recursive=True, + module_names=["model1", "model2"] + ) + self.assertEqual(result, { + "0": [("layer1", "obj1")], + "1": [("layer2", "obj2")] + }) + + def test_get_modules_and_names_with_single_recursive(self): + mock_model = MagicMock() + mock_model.named_modules.return_value = [("layer1", "obj1")] + + result = ModuleProcesser.get_modules_and_names( + mock_model, + recursive=True, + module_names=["single_model"] + ) + self.assertEqual(result, { + "-1": [("layer1", "obj1")] + }) + + def test_get_modules_and_names_with_single_non_recursive(self): + mock_model = MagicMock() + result = ModuleProcesser.get_modules_and_names( + mock_model, + recursive=False, + module_names=["single_model"] + ) + self.assertEqual(result, { + "-1": [("single_model", mock_model)] + }) + + def test_get_modules_and_names_invalid_case(self): + result = ModuleProcesser.get_modules_and_names( + [MagicMock(), MagicMock()], + recursive=False, + module_names=["only_one_name"] + ) + self.assertEqual(result, {}) + + result = ModuleProcesser.get_modules_and_names( + MagicMock(), + recursive=False, + module_names=["name1", "name2"] + ) + self.assertEqual(result, {}) + + def test_reset_module_stats(self): + ModuleProcesser.module_count = {"test": 1} + ModuleProcesser.module_stack = ["layer1"] + ModuleProcesser.api_parent_node = "parent" + ModuleProcesser.module_node = {"key": "value"} + ModuleProcesser.module_bw_hook_kernels = {"hook": "data"} + ModuleProcesser.enable_module_dump = True + + ModuleProcesser.reset_module_stats() + + self.assertEqual(ModuleProcesser.module_count, {}) + self.assertEqual(ModuleProcesser.module_stack, {}) + self.assertEqual(ModuleProcesser.api_parent_node, {}) + self.assertEqual(ModuleProcesser.module_node, {}) + self.assertEqual(ModuleProcesser.module_bw_hook_kernels, {}) + self.assertFalse(ModuleProcesser.enable_module_dump) + + def test_set_construct_info_in_pre_hook_with_stack(self): + processor = ModuleProcesser(self.mock_scope) + ModuleProcesser.module_stack[threading.get_ident()] = ["parent_module"] + processor.scope = self.mock_scope + + processor.set_construct_info_in_pre_hook("current_module") + + self.assertEqual(ModuleProcesser.module_node["current_module"], "parent_module") + self.assertEqual( + ModuleProcesser.module_stack[threading.get_ident()], + ["parent_module", "current_module"] + ) + self.assertEqual(ModuleProcesser.api_parent_node[threading.get_ident()], "current_module") + self.mock_scope.begin_module.assert_called_once_with("current_module") + + def test_set_construct_info_in_pre_hook_empty_stack(self): + processor = ModuleProcesser(self.mock_scope) + processor.scope = self.mock_scope + processor.set_construct_info_in_pre_hook("root_module") + + self.assertIsNone(ModuleProcesser.module_node["root_module"]) + self.assertEqual(ModuleProcesser.module_stack[threading.get_ident()], ["root_module"]) + self.assertEqual(ModuleProcesser.api_parent_node[threading.get_ident()], "root_module") + + def test_set_construct_info_in_hook_with_forward(self): + mp.torch_version_above_or_equal_2 = True + processor = ModuleProcesser(self.mock_scope) + ModuleProcesser.module_stack = {threading.get_ident(): ["parent", "current"]} + processor.scope = self.mock_scope + + processor.set_construct_info_in_hook("current") + + self.assertEqual(ModuleProcesser.module_stack[threading.get_ident()], ["parent"]) + self.assertEqual(ModuleProcesser.api_parent_node[threading.get_ident()], "parent") + self.mock_scope.end_module.assert_called_once_with("current") + + def test_set_construct_info_in_hook_with_backward(self): + mp.torch_version_above_or_equal_2 = False + processor = ModuleProcesser(self.mock_scope) + processor.scope = self.mock_scope + + processor.set_construct_info_in_hook("backward_module", is_forward=False) + + self.assertEqual(ModuleProcesser.api_parent_node[threading.get_ident()], "backward_module") + self.mock_scope.begin_module.assert_called_once_with("backward_module") + + def test_set_construct_info_in_hook_empty_stack(self): + mp.torch_version_above_or_equal_2 = True + processor = ModuleProcesser(self.mock_scope) + + processor.set_construct_info_in_hook("module") + + self.assertEqual(ModuleProcesser.api_parent_node, {threading.get_ident(): None}) + + +if __name__ == "__main__": + unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_pt_hook_wrapper.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_pt_hook_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..c070846a6042716ba146ef1b4e263638cca7a343 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_pt_hook_wrapper.py @@ -0,0 +1,81 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import MagicMock, patch + +import torch + +from msprobe.pytorch.dump.module_dump.hook_wrapper import wrap_setup_backward_hook + + +class TestWrapSetupBackwardHook(unittest.TestCase): + def setUp(self): + self.mock_func = MagicMock() + self.mock_func.return_value = ["clone_tensor1", "clone_tensor2"] + + self.decorated_func = wrap_setup_backward_hook(self.mock_func) + + self.tensor = torch.randn(3, requires_grad=True) + torch.set_grad_enabled(True) + + def test_insufficient_args(self): + result = self.decorated_func("test_case1") + self.mock_func.assert_called_once_with("test_case1") + self.assertListEqual(result, ["clone_tensor1", "clone_tensor2"]) + + def test_normal_processing_flow(self): + test_tensor = torch.randn(2, requires_grad=False) + test_data = { + "tensors": [self.tensor, torch.randn(2, requires_grad=True)], + "nested": { + "tuple": (self.tensor, test_tensor) + } + } + + mock_self = MagicMock() + mock_self.module.inplace = False + test_tensor1 = torch.randn(4, requires_grad=True) + test_tensor2 = torch.randn(4, requires_grad=True) + test_tensor3 = torch.randn(4, requires_grad=True) + self.mock_func.return_value = [test_tensor1, test_tensor2, test_tensor3] + result = self.decorated_func(mock_self, test_data) + + self.assertIsInstance(result, dict) + self.assertFalse(torch.equal(result["tensors"][0], self.tensor)) + self.assertTrue(torch.equal(result["tensors"][1], test_tensor2)) + self.assertIsInstance(result["nested"]["tuple"][0], torch.Tensor) + self.assertTrue(torch.equal(result["nested"]["tuple"][1], test_tensor)) + + def test_complex_data_structures(self): + test_case = [ + self.tensor, + {"dict": torch.randn(4, requires_grad=True)}, + (torch.randn(5, requires_grad=True),), + [torch.randn(6, requires_grad=True)] + ] + + mock_self = MagicMock() + mock_self.module.inplace = False + test_tensor1 = torch.randn(4, requires_grad=True) + test_tensor2 = torch.randn(5, requires_grad=True) + test_tensor3 = torch.randn(6, requires_grad=True) + self.mock_func.return_value = [self.tensor, test_tensor1, test_tensor2, test_tensor3] + result = self.decorated_func(mock_self, test_case) + + self.assertIsInstance(result, list) + self.assertTrue(torch.equal(result[1]["dict"], test_tensor1)) + self.assertTrue(torch.equal(result[2][0], test_tensor2)) + self.assertTrue(torch.equal(result[3][0], test_tensor3)) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_pt_kernel_config.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_pt_kernel_config.py index fbeeb07ffc9ac43eedc22ed95d1fa142bb2dd6e4..89176f5f51ce9f93b13bc14906be62e0425d957c 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_pt_kernel_config.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_pt_kernel_config.py @@ -16,11 +16,11 @@ import unittest from unittest.mock import patch -from msprobe.pytorch.dump.kernel_dump.kernel_config import create_kernel_config_json +from msprobe.core.kernel_dump.kernel_config import create_kernel_config_json class TestPtKernelConfig(unittest.TestCase): - @patch("msprobe.pytorch.dump.kernel_dump.kernel_config.save_json") + @patch("msprobe.core.kernel_dump.kernel_config.save_json") def test_create_kernel_config_json_with_rank(self, mock_save_json): dump_path = "./step0" cur_rank = 0 @@ -36,7 +36,7 @@ class TestPtKernelConfig(unittest.TestCase): } mock_save_json.assert_called_once_with(kernel_config_path, config_info, indent=4) - @patch("msprobe.pytorch.dump.kernel_dump.kernel_config.save_json") + @patch("msprobe.core.kernel_dump.kernel_config.save_json") def test_create_kernel_config_json_without_rank(self, mock_save_json): dump_path = "./step0" cur_rank = '' diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py index be2215dcd9cb22577a84954b9283ed68825de86e..d4e568303a8b22058cba4ad879b160b3169a6cae 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py @@ -166,7 +166,7 @@ class TestPerturbedLayer(TestCase): layer.pre_check(y) mock_logger.assert_called_with( "[msprobe] Free Benchmark: For test_api_name, " - "Maximun value is less than the minimun threshold. Cancel add noise." + "maximum value is less than the minimum threshold. Cancel adding noise." ) # 对于输入张量,add_noise扰动因子对大于极小值的部分增加一个小值 @@ -212,7 +212,7 @@ class TestPerturbedLayer(TestCase): layer.pre_check(y) mock_logger.assert_called_with( "[msprobe] Free Benchmark: For test_api_name, " - "Maximun value is less than the minimun threshold. Cancel add noise." + "maximum value is less than the minimum threshold. Cancel adding noise." ) # 对于低精度输入、run cpu会升精度在cpu上计算,并会打印日志 diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/test_main.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/test_main.py index 4d777a2f0c86b7abaddebbe1ec52b6608d514c90..1d533d22411a26170def59b40a14cda3cdfe3ab8 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/test_main.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/test_main.py @@ -1,4 +1,3 @@ -import functools from abc import ABC from unittest import TestCase @@ -93,31 +92,3 @@ class TestInterface(TestCase): output=out, ) self.assertEqual(result.dtype, torch.float16) - - def testBackwardCheck(self): - # 对于反向接口,在pre forward时暂存input, 然后在backwrad后进行对比 - config = Config(Const.BACKWARD, HandlerType.CHECK) - checker = FreeBenchmarkCheck(config) - processor = UnequalDataProcessor() - # 初始化输入输出 - x = torch.tensor([2, 3], dtype=torch.float16, requires_grad=True) - y = torch.tensor([2, 3], dtype=torch.float16, requires_grad=True) - grad_output = torch.tensor([1, 1], dtype=torch.float16) - backward_name = Const.SEP.join([self.api_name, Const.BACKWARD]) - # 执行前向生成grad saver实例 - mul_module = WrapMul(self.api_name) - checker.pre_forward(backward_name, mul_module, processor, (x, y), {}) - # 执行算子前向和反向, 并反向获取扰动后grad_input - out = mul_module(x, y) - checker.backward(backward_name, mul_module, grad_output) - out.backward(torch.ones_like(out)) - # module是否添加暂存器, 其中反向钩子执行扰动后grad_input是否正确 - self.assertTrue(hasattr(mul_module, CommonField.GRADSAVER)) - grad_saver = getattr(mul_module, CommonField.GRADSAVER) - self.assertEqual(grad_saver.perturbed_grad_input[0][0], 2) - handler = FuzzHandlerFactory.create(grad_saver.handler_params) - # 模拟一个张量的梯度更新时触发反向检测 - grad_saver.compare_grad_results( - handler, torch.tensor(1.0), torch.tensor(2.0), 0 - ) - self.assertEqual(len(processor.unequal_rows), 0) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/grad_probe/test_grad_csv.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/grad_probe/test_grad_csv.py index f39d3f091faf8d57f80cccbadc15259ee54269f0..80e32ac2890dfa0247eb5cd76aefe45cbc735345 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/grad_probe/test_grad_csv.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/grad_probe/test_grad_csv.py @@ -24,7 +24,7 @@ class TestGradCSV(unittest.TestCase): def test_level_L0_content(self): generated_csv_line = GradStatCsv.generate_csv_line("model.conv2d", level_adp["L0"], grad_tensor, [-1, 0, 1]) - self.assertEqual(['model.conv2d', '678a6c7d9d9716682b56fda097d0936c', 2.0, -2.0, 2.851315498352051, [2, 2]], + self.assertEqual(['model.conv2d', 'e2863940', 2.0, -2.0, 2.851315498352051, [2, 2]], generated_csv_line) def test_level_L1_content(self): diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_api_registry.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_api_registry.py deleted file mode 100644 index 837ad23df76be2a012a7408dab4879847937f229..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_api_registry.py +++ /dev/null @@ -1,130 +0,0 @@ -import unittest -from msprobe.pytorch.hook_module.api_registry import ApiRegistry, torch_version_above_2, is_gpu - - -class TestApiRegistry(unittest.TestCase): - - def test_store_ori_attr(self): - class A(): - a1 = 1 - class B(): - a = A() - b1 = 1 - b2 = 2 - - api_list = ["a.a1", "b1", "b2"] - expect_output = {"a.a1":1, "b1":1, "b2":2} - actual_output = dict() - ApiRegistry.store_ori_attr(B, api_list, actual_output) - self.assertEqual(actual_output, expect_output) - - - def test_set_api_attr(self): - class A(): - a1 = 1 - class B(): - a = A().__class__ - b1 = 1 - - attr_dict = {"a.a2":2, "b2":2, "b3":3} - ApiRegistry.set_api_attr(B, attr_dict) - - for k, v in attr_dict.items(): - if '.' in k: - sub_module_name, sub_op = k.rsplit('.', 1) - sub_module = getattr(B, sub_module_name, None) - - self.assertEqual(getattr(sub_module, sub_op), v) - else: - self.assertEqual(getattr(B, k), v) - - def test_api_modularity(self): - - import torch - import torch.distributed as dist - #import torch_npu #门禁没有安装torch_npu - from msprobe.pytorch.hook_module.api_registry import torch_without_guard_version, npu_distributed_api, is_gpu, torch_version_above_2 - - - - reg = ApiRegistry() - attr_dict = {"b2":2, "b3":3} - reg.tensor_hook_attr = attr_dict - reg.torch_hook_attr = attr_dict - reg.functional_hook_attr = attr_dict - reg.distributed_hook_attr = attr_dict - reg.npu_distributed_hook_attr = attr_dict - reg.aten_hook_attr = attr_dict - reg.vf_hook_attr = attr_dict - reg.torch_npu_hook_attr = attr_dict - - reg.api_modularity() - self.assertEqual(torch.Tensor.b2, 2) - - self.assertEqual(torch.b2, 2) - self.assertEqual(torch.nn.functional.b2, 2) - self.assertEqual(dist.b2, 2) - self.assertEqual(dist.distributed_c10d.b2, 2) - #if not is_gpu and not torch_without_guard_version: - #self.assertEqual(torch_npu.distributed.b2, 2) - #self.assertEqual(torch_npu.distributed.distributed_c10d.b2, 2) - if torch_version_above_2: - self.assertEqual(torch.ops.aten.b2, 2) - self.assertEqual(torch._VF.b2, 2) - #if not is_gpu: - #self.assertEqual(torch_npu.b2, 2) - - - def test_api_originality(self): - import torch - import torch.distributed as dist - #import torch_npu #门禁没有安装torch_npu - from msprobe.pytorch.hook_module.api_registry import torch_without_guard_version, npu_distributed_api, is_gpu, torch_version_above_2 - - - - reg = ApiRegistry() - attr_dict = {"b2":2, "b3":3} - reg.tensor_hook_attr = attr_dict - reg.torch_hook_attr = attr_dict - reg.functional_hook_attr = attr_dict - reg.distributed_hook_attr = attr_dict - reg.npu_distributed_hook_attr = attr_dict - reg.aten_hook_attr = attr_dict - reg.vf_hook_attr = attr_dict - reg.torch_npu_hook_attr = attr_dict - - reg.api_originality() - self.assertEqual(torch.Tensor.b2, 2) - - self.assertEqual(torch.b2, 2) - self.assertEqual(torch.nn.functional.b2, 2) - self.assertEqual(dist.b2, 2) - self.assertEqual(dist.distributed_c10d.b2, 2) - #if not is_gpu and not torch_without_guard_version: - #self.assertEqual(torch_npu.distributed.b2, 2) - #self.assertEqual(torch_npu.distributed.distributed_c10d.b2, 2) - if torch_version_above_2: - self.assertEqual(torch.ops.aten.b2, 2) - self.assertEqual(torch._VF.b2, 2) - #if not is_gpu: - #self.assertEqual(torch_npu.b2, 2) - - def test_initialize_hook(self): - def hook_test(): - pass - - reg = ApiRegistry() - reg.initialize_hook(hook_test) - empty_list = [] - self.assertFalse(empty_list==reg.tensor_hook_attr) - self.assertFalse(empty_list==reg.torch_hook_attr) - self.assertFalse(empty_list==reg.functional_hook_attr) - self.assertFalse(empty_list==reg.distributed_hook_attr) - self.assertFalse(empty_list==reg.npu_distributed_hook_attr) - if torch_version_above_2: - #print(True) - self.assertFalse(empty_list==reg.aten_hook_attr) - if not is_gpu: - #print(True) - self.assertFalse(empty_list==reg.torch_npu_hook_attr) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_hook_module.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_hook_module.py deleted file mode 100644 index 1524a82ae1fc81eee245fa73bde4b4938cb89638..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_hook_module.py +++ /dev/null @@ -1,34 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch -import threading -from msprobe.pytorch.hook_module.hook_module import HOOKModule - -class TestHOOKModuleInit(unittest.TestCase): - - def setUp(self): - self.mock_build_hook = MagicMock(return_value=(MagicMock(), MagicMock(), MagicMock(), None)) - - def test_thread_handling(self): - module = HOOKModule(self.mock_build_hook) - current_thread_id = module.current_thread - self.assertEqual(current_thread_id, threading.current_thread().ident) - - -class TestHOOKModuleCall(unittest.TestCase): - def setUp(self): - self.mock_build_hook = MagicMock(return_value=(MagicMock(), MagicMock(), MagicMock(), None)) - self.module = HOOKModule(self.mock_build_hook) - - @patch.object(HOOKModule, '_call_func') - def test_call_function(self, mock_call_func): - mock_call_func.return_value = "test_result" - result = self.module("input_data") - mock_call_func.assert_called_once_with("input_data", **{}) - self.assertEqual(result, "test_result") - - @patch.object(HOOKModule, '_call_func') - def test_call_func_with_hooks(self, mock_call_func): - mock_call_func.return_value = "test_result_with_hooks" - result = self.module("input_data") - self.assertEqual(result, "test_result_with_hooks") - HOOKModule.inner_stop_hook[self.module.current_thread] = False diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_api_register.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_api_register.py new file mode 100644 index 0000000000000000000000000000000000000000..689cff0ee52a57847069e73a038df0738fdd7c73 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_api_register.py @@ -0,0 +1,211 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import MagicMock, patch + +import msprobe.pytorch.hook_module.api_register as api_register +from msprobe.pytorch.hook_module.api_register import ( + tensor_module_forward, + dist_module_forward, + npu_module_forward, + get_api_register, + ApiTemplate +) + + +class TestAPIRegister(unittest.TestCase): + def setUp(self): + api_register.api_register = None + + def test_tensor_module_forward(self): + mock_module = MagicMock() + mock_module.api_name = "test_name" + mock_module.api_func.return_value = "test_result" + + args = (1, 2, 3) + kwargs = {"key": "value"} + result = tensor_module_forward(mock_module, *args, **kwargs) + + mock_module.api_func.assert_called_once_with(*args, **kwargs) + self.assertEqual(result, "test_result") + + @patch('msprobe.pytorch.hook_module.api_register.logger.warning') + def test_basic_dist_module_forward(self, mock_logger): + mock_module = MagicMock() + mock_module.api_func.return_value = "test_handle" + mock_module.api_name = "test_api" + + result = dist_module_forward(mock_module, 1, 2, key="value") + mock_module.api_func.assert_called_once_with(1, 2, key="value") + self.assertEqual(result, "test_handle") + mock_logger.assert_not_called() + + @patch('msprobe.pytorch.hook_module.api_register.ApiRegistry') + def test_get_api_register_with_new_obj(self, mock_api_registry): + get_api_register(return_new=True) + mock_api_registry.assert_called_once() + self.assertIsNone(api_register.api_register) + + @patch('msprobe.pytorch.hook_module.api_register.ApiRegistry') + def test_get_api_register_with_not_new_obj(self, mock_api_registry): + get_api_register() + mock_api_registry.assert_called_once() + self.assertIsNotNone(api_register.api_register) + + +class TestNpuModuleForward(unittest.TestCase): + def setUp(self): + self.npu_custom_functions = { + "custom_func": MagicMock(return_value="custom_result"), + "npu_fusion_attention": MagicMock(return_value="nfa_result"), + "gpu_fusion_attention": MagicMock(return_value="gfa_result") + } + + self.module = MagicMock() + self.module.api_func.return_value = "test_result" + + def test_with_hook_enabled(self): + self.module.need_hook = True + result = npu_module_forward(self.module, 1, 2, key="value") + self.module.api_func.assert_called_once_with(1, 2, key="value") + self.assertEqual(result, "test_result") + + def test_with_unknown_api(self): + self.module.need_hook = False + self.module.api_name = "unknown_func" + with patch('msprobe.pytorch.hook_module.api_register.npu_custom_functions', new=self.npu_custom_functions): + with self.assertRaises(Exception) as context: + npu_module_forward(self.module, 1, 2, key="value") + self.assertIn("There is not bench function unknown_func", str(context.exception)) + + def test_cuda_device_with_mapping(self): + self.module.need_hook = False + self.module.api_name = "npu_fusion_attention" + self.module.device = 'cuda' + + with patch('msprobe.pytorch.hook_module.api_register.npu_custom_functions', new=self.npu_custom_functions): + result = npu_module_forward(self.module, 1, 2, key="value") + self.npu_custom_functions["gpu_fusion_attention"].assert_called_once_with(1, 2, key="value") + self.assertEqual(result, "gfa_result") + + def test_cpu_device(self): + self.module.need_hook = False + self.module.api_name = "custom_func" + self.module.device = "cpu" + + with patch('msprobe.pytorch.hook_module.api_register.npu_custom_functions', new=self.npu_custom_functions): + result = npu_module_forward(self.module, 1, 2, key="value") + self.npu_custom_functions["custom_func"].assert_called_once_with(1, 2, key="value") + self.assertEqual(result, "custom_result") + + def test_unsupported_device(self): + self.module.need_hook = False + self.module.api_name = "custom_func" + self.module.device = "unsupported_device" + + with patch('msprobe.pytorch.hook_module.api_register.npu_custom_functions', new=self.npu_custom_functions): + result = npu_module_forward(self.module, 1, 2, key="value") + self.module.api_func.assert_called_once_with(1, 2, key="value") + self.assertEqual(result, "test_result") + + +class TestApiTemplate(unittest.TestCase): + def setUp(self): + self.api_name = "Tensor.test_api" + self.api_func = MagicMock(return_value="test_result") + self.prefix = "test_prefix" + self.hook_build_func = MagicMock() + self.mock_hook_module = MagicMock() + + def test_init(self): + with patch('msprobe.pytorch.hook_module.api_register.HOOKModule') as mock_hook_module: + template = ApiTemplate( + self.api_name, + self.api_func, + self.prefix, + self.hook_build_func, + need_hook=False + ) + + self.assertEqual(template.api_name, self.api_name) + self.assertEqual(template.api_func, self.api_func) + self.assertEqual(template.prefix, self.prefix) + self.assertEqual(template.prefix_api_name, "test_prefix.test_api.") + self.assertEqual(template.device, "cpu") + self.assertFalse(template.op_is_distributed) + self.assertFalse(template.need_hook) + + def test_init_with_distributed_prefix(self): + with patch('msprobe.pytorch.hook_module.api_register.HOOKModule'): + self.prefix = "Distributed" + template = ApiTemplate( + self.api_name, + self.api_func, + self.prefix, + self.hook_build_func, + need_hook=False, + device="npu" + ) + + self.assertEqual(template.device, "npu") + self.assertEqual(template.prefix_api_name, "Distributed.test_api.") + self.assertTrue(template.op_is_distributed) + + def test_init_without_hook(self): + with patch('msprobe.pytorch.hook_module.api_register.HOOKModule') as mock_hook_module: + template = ApiTemplate( + self.api_name, + self.api_func, + self.prefix, + self.hook_build_func, + need_hook=False, + device="npu" + ) + + self.assertFalse(template.need_hook) + self.mock_hook_module.assert_not_called() + + def test_forward_with_prefix_match(self): + with patch('msprobe.pytorch.hook_module.api_register.HOOKModule'): + self.prefix = "Tensor" + template = ApiTemplate( + self.api_name, + self.api_func, + self.prefix, + self.hook_build_func, + need_hook=False, + device="npu" + ) + + result = template.forward("arg1", key="value") + + self.assertEqual(result, "test_result") + + def test_forward_without_prefix_match(self): + with patch('msprobe.pytorch.hook_module.api_register.HOOKModule'): + template = ApiTemplate( + self.api_name, + self.api_func, + self.prefix, + self.hook_build_func, + need_hook=False, + device="npu" + ) + + result = template.forward("arg1", key="value") + + self.api_func.assert_called_once_with("arg1", key="value") + self.assertEqual(result, "test_result") diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_hook_manager.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_hook_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..6d407f640f6243afbda610c17032937a58403b0f --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_hook_manager.py @@ -0,0 +1,122 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading + +import unittest +from unittest.mock import MagicMock, patch +from contextlib import nullcontext +from msprobe.pytorch.hook_module.pt_hook_manager import PytorchHookManager +from msprobe.core.common.const import Const +from msprobe.core.hook_manager import HookSet, BaseHookManager + + +class TestPytorchHookManager(unittest.TestCase): + def setUp(self): + self.mock_data_collector = MagicMock() + self.mock_config = MagicMock() + self.mock_config.data_mode = ["all"] + self.mock_config.task = "statistics" + self.manager = PytorchHookManager( + self.mock_data_collector, + self.mock_config + ) + BaseHookManager.inner_switch[threading.get_ident()] = False + + def test_properties(self): + with patch('msprobe.pytorch.hook_module.pt_hook_manager.is_recomputation', return_value=True): + self.assertTrue(self.manager._is_recompute) + + with patch('msprobe.pytorch.hook_module.pt_hook_manager.is_recomputation', return_value=False): + self.assertFalse(self.manager._is_recompute) + + def test_no_grad_context(self): + self.assertIsInstance(self.manager._no_grad_context(), nullcontext) + + def test_add_count(self): + with patch('msprobe.pytorch.hook_module.pt_hook_manager.HOOKModule.add_module_count') as mock_add: + self.manager._add_count("test_layer") + mock_add.assert_called_once_with("test_layer") + + def test_process_kwargs_and_output(self): + kwargs, output = self.manager._process_kwargs_and_output( + None, + None, + "API", + "kwargs_value", + "output_value" + ) + self.assertEqual(kwargs, "kwargs_value") + self.assertEqual(output, "output_value") + + with patch('msprobe.pytorch.hook_module.pt_hook_manager.torch_version_above_or_equal_2', new=True): + kwargs, output = self.manager._process_kwargs_and_output( + None, + None, + None, + "kwargs_value", + "output_value" + ) + self.assertEqual(kwargs, "kwargs_value") + self.assertEqual(output, "output_value") + + with patch('msprobe.pytorch.hook_module.pt_hook_manager.torch_version_above_or_equal_2', new=False): + kwargs, output = self.manager._process_kwargs_and_output( + None, + None, + None, + "kwargs_value", + "output_value" + ) + self.assertEqual(kwargs, {}) + self.assertEqual(output, "kwargs_value") + + def test_build_hook(self): + hook_set = self.manager.build_hook(Const.API, "test_api") + self.assertIsInstance(hook_set, HookSet) + self.assertTrue(callable(hook_set.forward_pre_hook)) + self.assertTrue(callable(hook_set.distributed_forward_hook)) + self.assertIsNone(hook_set.forward_hook) + self.assertIsNone(hook_set.backward_pre_hook) + self.assertIsNone(hook_set.backward_hook) + + hook_set = self.manager.build_hook(Const.MODULE, "test_module") + self.assertEqual(hook_set.forward_hook.__name__, "forward_hook") + self.assertEqual(hook_set.backward_hook.__name__, "backward_hook") + + def test_need_exchange(self): + self.assertTrue(self.manager._need_exchange(None)) + self.assertTrue(self.manager._need_exchange(MagicMock())) + + def test_get_params_dict(self): + mock_module = MagicMock() + + self.mock_config.task = Const.STRUCTURE + params_dict = self.manager._get_params_dict(mock_module) + self.assertEqual(params_dict, {}) + + self.mock_config.task = "statistics" + + mock_named_params = [ + ("conv.weight", MagicMock()), + ("bn.bias", MagicMock()) + ] + mock_module.named_parameters.return_value = mock_named_params + params_dict = self.manager._get_params_dict(mock_module) + mock_module.named_parameters.assert_called_once_with(recurse=False) + + self.assertEqual(set(params_dict.keys()), {"weight", "bias"}) + self.assertEqual(params_dict["weight"], mock_named_params[0][1]) + self.assertEqual(params_dict["bias"], mock_named_params[1][1]) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_hook_module.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_hook_module.py new file mode 100644 index 0000000000000000000000000000000000000000..500fe32c2630f83ece00112e5aee32068fc46b6c --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_hook_module.py @@ -0,0 +1,61 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading +import unittest +from collections import defaultdict +from unittest.mock import MagicMock, patch + +from msprobe.core.hook_manager import HookSet +from msprobe.pytorch.hook_module.hook_module import HOOKModule + + +class TestHOOKModule(unittest.TestCase): + def setUp(self): + self.mock_build_hook = MagicMock(return_value=HookSet(MagicMock(), MagicMock(), MagicMock())) + HOOKModule.module_count = defaultdict(int) + HOOKModule.inner_stop_hook = defaultdict(bool) + + @patch.object(HOOKModule, '_call_func') + def test_call_with_stop_hooks(self, mock_call_func): + mock_call_func.return_value = "test_result" + module1 = HOOKModule(self.mock_build_hook) + + result = module1("arg1", "arg2", key="value") + mock_call_func.assert_called_once_with("arg1", "arg2", key="value") + self.assertEqual(result, "test_result") + + @patch.object(HOOKModule, '_call_func') + def test_call_with_start_hooks(self, mock_call_func): + mock_call_func.return_value = "test_result" + module1 = HOOKModule(self.mock_build_hook) + + result = module1("arg1", "arg2", key="value") + mock_call_func.assert_called_once_with("arg1", "arg2", key="value") + self.assertEqual(result, "test_result") + + def test_reset_module_stats(self): + HOOKModule.module_count = {"Tensor.add.0.forward": 0} + HOOKModule.reset_module_stats() + self.assertDictEqual(HOOKModule.module_count, defaultdict(int)) + + def test_add_module_count(self): + HOOKModule.add_module_count("Tensor.add.0.forward") + self.assertEqual(HOOKModule.module_count["Tensor.add.0.forward"], 1) + + def test_get_module_count(self): + HOOKModule.module_count = {"Tensor.add.0.forward": 0} + result = HOOKModule.get_module_count("Tensor.add.0.forward") + self.assertEqual(result, 0) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_hook_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_hook_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..de70f8a11fb789189b56495c1ae9d1bad977fb71 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_hook_utils.py @@ -0,0 +1,80 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest +from unittest.mock import MagicMock, patch + +from msprobe.pytorch.hook_module.utils import get_ops, dynamic_import_op + + +class MockPackage: + __name__ = "mock_package" + __file__ = "/fake_path/__init__.py" + + +class TestUtils(unittest.TestCase): + def setUp(self): + self.yaml_content = { + 'functional': ['func1', 'func2'], + 'tensor': ['tensor_op1'], + 'torch': ['torch_op1', 'torch_op2'], + 'torch_npu': ['npu_op1'] + } + + self.mock_listdir = patch('os.listdir').start() + self.mock_check_link = patch('msprobe.pytorch.hook_module.utils.check_link').start() + + def tearDown(self): + patch.stopall() + + def test_get_ops(self): + with patch('msprobe.pytorch.hook_module.utils.load_yaml') as mock_load: + mock_load.return_value = self.yaml_content + result = get_ops() + self.assertEqual( + result, + { + 'func1', + 'func2', + 'tensor_op1', + 'torch_op1', + 'torch_op2', + 'npu_op1' + } + ) + + @patch('msprobe.pytorch.hook_module.utils.inspect') + def test_dynamic_import_op_success(self, mock_inspect): + mock_func = lambda x: x + mock_inspect.getmembers = MagicMock() + mock_inspect.getmembers.return_value = [['test_func', mock_func]] + + self.mock_listdir.return_value = ['valid.py', 'invalid.py'] + mock_module = MagicMock() + + with patch('importlib.import_module', return_value=mock_module) as mock_import: + ops = dynamic_import_op(MockPackage(), white_list=['valid.py']) + self.assertEqual(ops, {'valid.test_func': mock_func}) + mock_import.assert_called_once_with('mock_package.valid') + + def test_dynamic_import_op_failure(self): + self.mock_listdir.return_value = ['fail.py'] + with patch('importlib.import_module') as mock_import: + mock_import.side_effect = ImportError("Fake error") + with patch('msprobe.pytorch.hook_module.utils.logger.warning') as mock_logger: + ops = dynamic_import_op(MockPackage(), white_list=['fail.py']) + self.assertEqual(ops, {}) + mock_logger.assert_called_once_with("import mock_package.fail failed!") \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_jit_script_wrapper.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_jit_script_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..8e93c8de18f10f477a109a6134aa653604ef3f47 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_jit_script_wrapper.py @@ -0,0 +1,51 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import MagicMock, patch + +import torch +from msprobe.pytorch.hook_module.script_wrapper import wrap_jit_script_func + + +class TestWrapJitScriptFunc(unittest.TestCase): + def setUp(self): + self.original_script = torch.jit.script + + self.mock_api_register = MagicMock() + self.mock_api_register.all_api_registered = True + self.mock_api_register.register_all_api = MagicMock() + self.mock_api_register.restore_all_api = MagicMock() + + def tearDown(self): + torch.jit.script = self.original_script + + @patch('torch.jit.script', new_callable=MagicMock) + @patch('msprobe.pytorch.hook_module.script_wrapper.get_api_register', return_value=MagicMock()) + def test_patched_script(self, mock_get_api, mock_original_script): + mock_original_script.return_value = "mocked_result" + mock_get_api.return_value = self.mock_api_register + + wrap_jit_script_func() + + self.assertNotEqual(torch.jit.script, self.original_script) + + result = torch.jit.script("test_input") + + mock_original_script.assert_called_once_with("test_input") + self.assertEqual(result, "mocked_result") + + self.mock_api_register.restore_all_api.assert_called_once() + self.mock_api_register.register_all_api.assert_called_once() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py index af669cb5c73de85e51f36f62f9e7dc61bb599ca1..650fe6026fb71b9df67fc5f716bbd1531f5cc958 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py @@ -1,15 +1,34 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest from unittest.mock import MagicMock, patch import torch - +from msprobe.core.hook_manager import HookSet from msprobe.pytorch.function_factory import npu_custom_grad_functions -from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate, white_aten_ops, \ +from msprobe.pytorch.hook_module.wrap_aten import ( + AtenOPTemplate, + white_aten_ops, AtenOPPacketTemplate +) def mock_build_hook(prefix): - return (MagicMock(), MagicMock(), MagicMock(), MagicMock()) + return HookSet(MagicMock(), MagicMock(), MagicMock()) + class TestAtenOPTemplate(unittest.TestCase): @@ -79,8 +98,8 @@ class TestAtenOPPacketTemplate(unittest.TestCase): del self.mock_op_packet.nonexistent_attr with self.assertRaises(AttributeError) as context: _ = self.template.nonexistent_attr - self.assertIn("or OpOverloadPacket does not have attribute 'nonexistent_attr'.", \ - str(context.exception)) + self.assertIn("or OpOverloadPacket does not have attribute 'nonexistent_attr'.", + str(context.exception)) @patch('msprobe.pytorch.hook_module.wrap_aten.AtenOPTemplate', autospec=True) def test_getattr_op_overload(self, MockAtenOPTemplate): diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py deleted file mode 100644 index 246feb56becf9942de9214f5b24b8471e9b4024a..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +++ /dev/null @@ -1,41 +0,0 @@ -import unittest -import torch.distributed as dist -from msprobe.pytorch.hook_module.wrap_distributed import * - -class TestWrapDistributed(unittest.TestCase): - def hook(name, prefix): - def forward_pre_hook(nope, input, kwargs): - return input, kwargs - - def forward_hook(nope, input, kwargs, result): - return 2 - - def backward_hook(): - pass - - def forward_hook_torch_version_below_2(): - pass - - return forward_pre_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 - - def test_get_distributed_ops(self): - ops = get_distributed_ops() - self.assertIsInstance(ops, set) - - def test_DistributedOPTemplate(self): - self.setUp() - op_name = 'all_reduce' - if op_name in get_distributed_ops(): - op = DistributedOPTemplate(op_name, self.hook) - self.assertEqual(op.op_name_, op_name) - - def test_wrap_distributed_op(self): - op_name = 'all_reduce' - if op_name in get_distributed_ops(): - wrapped_op = wrap_distributed_op(op_name, self.hook) - self.assertTrue(callable(wrapped_op)) - - def test_wrap_distributed_ops_and_bind(self): - wrap_distributed_ops_and_bind(self.hook) - for op_name in get_distributed_ops(): - self.assertTrue(hasattr(HOOKDistributedOP, "wrap_" + str(op_name))) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py deleted file mode 100644 index 282551e3cefdb2ae63efda284f5e7ae7482ae81c..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +++ /dev/null @@ -1,73 +0,0 @@ -import unittest -import torch -import torch.nn.functional as F -from msprobe.pytorch.hook_module.wrap_functional import get_functional_ops, \ - wrap_functional_ops_and_bind, HOOKFunctionalOP -from msprobe.pytorch.common.utils import remove_dropout - - -class TestDropoutFunctions(unittest.TestCase): - - def setUp(self): - self.input_tensor = torch.ones(10, 10) - remove_dropout() - - def test_function_dropout_no_dropout(self): - output = F.dropout(self.input_tensor, p = 0., training = True) - self.assertTrue(torch.equal(self.input_tensor, output)) - - def test_function_dropout_train_vs_eval(self): - output_train = F.dropout(self.input_tensor, p = 0., training = True) - output_eval = F.dropout(self.input_tensor, p = 0., training = False) - self.assertTrue(torch.equal(output_train, output_eval)) - - def test_function_dropout_invalid_probability(self): - with self.assertRaises(ValueError): - F.dropout(self.input_tensor, p = -0.1) - with self.assertRaises(ValueError): - F.dropout(self.input_tensor, p = 1.1) - - def test_function_dropout2d_no_dropout(self): - output = F.dropout2d(self.input_tensor, p = 0., training = True) - self.assertTrue(torch.equal(self.input_tensor, output)) - - def test_function_dropout2d_train_vs_eval(self): - output_train = F.dropout2d(self.input_tensor, p = 0., training = True) - output_eval = F.dropout2d(self.input_tensor, p = 0., training = False) - self.assertTrue(torch.equal(output_train, output_eval)) - - def test_function_dropout2d_invalid_probability(self): - with self.assertRaises(ValueError): - F.dropout2d(self.input_tensor, p = -0.1) - with self.assertRaises(ValueError): - F.dropout2d(self.input_tensor, p = 1.1) - - def test_function_dropout3d_no_dropout(self): - input_tensor_3d = self.input_tensor.unsqueeze(0) - output = F.dropout3d(input_tensor_3d, p = 0., training = True) - self.assertTrue(torch.equal(input_tensor_3d, output)) - - def test_function_dropout3d_train_vs_eval(self): - input_tensor_3d = self.input_tensor.unsqueeze(0) - output_train = F.dropout3d(input_tensor_3d, p = 0., training = True) - output_eval = F.dropout3d(input_tensor_3d, p = 0., training = False) - self.assertTrue(torch.equal(output_train, output_eval)) - - def test_function_dropout3d_invalid_probability(self): - input_tensor_3d = self.input_tensor.unsqueeze(0) - with self.assertRaises(ValueError): - F.dropout3d(input_tensor_3d, p = -0.1) - with self.assertRaises(ValueError): - F.dropout3d(input_tensor_3d, p = 1.1) - - -class TestWrapFunctional(unittest.TestCase): - - def test_get_functional_ops(self): - expected_ops = {'relu', 'sigmoid', 'softmax'} - actual_ops = get_functional_ops() - self.assertTrue(expected_ops.issubset(actual_ops)) - - def test_wrap_functional_ops_and_bind(self): - wrap_functional_ops_and_bind(None) - self.assertTrue(hasattr(HOOKFunctionalOP, 'wrap_relu')) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_npu_custom.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_npu_custom.py deleted file mode 100644 index 573d6d000f37f429619b89507cecd1258fbe4c8b..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_npu_custom.py +++ /dev/null @@ -1,43 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch - -from msprobe.core.common.const import Const -from msprobe.core.common.log import logger -from msprobe.pytorch.function_factory import npu_custom_functions -from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate - -try: - import torch_npu -except ImportError: - logger.info("Failing to import torch_npu.") - - -class TestNpuOPTemplate(unittest.TestCase): - - def setUp(self): - self.mock_hook = MagicMock(return_value=(MagicMock(), MagicMock(), MagicMock(), None)) - self.template = NpuOPTemplate("sum", self.mock_hook) - - def test_init(self): - self.assertEqual(self.template.op_name_, "sum") - self.assertEqual(self.template.prefix_op_name_, f"NPU{Const.SEP}sum{Const.SEP}") - self.assertTrue(self.template.need_hook) - self.assertEqual(self.template.device, Const.CPU_LOWERCASE) - - @patch('torch.ops.npu.sum') - def test_forward_without_hook(self, mock_npu_sum): - self.template.need_hook = False - npu_custom_functions["sum"] = MagicMock(return_value="output_from_custom") - - result = self.template.forward(1, 2, key='value') - self.assertEqual(result, "output_from_custom") - mock_npu_sum.assert_not_called() - - @patch('torch.ops.npu.sum') - def test_forward_with_hook(self, mock_npu_sum): - self.template.need_hook = True - mock_npu_sum.return_value = "output_from_npu" - - result = self.template.forward(1, 2, key='value') - self.assertEqual(result, "output_from_npu") - mock_npu_sum.assert_called_once_with(1, 2, key='value') diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py deleted file mode 100644 index 6868c5bda7a88c84702d15e995c7f60af2b4e4c5..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +++ /dev/null @@ -1,40 +0,0 @@ -import unittest -import torch -from msprobe.pytorch.hook_module.wrap_tensor import get_tensor_ops, HOOKTensor, TensorOPTemplate, wrap_tensor_op, wrap_tensor_ops_and_bind - -class TestWrapTensor(unittest.TestCase): - - def hook(name, prefix): - def forward_pre_hook(nope, input, kwargs): - return input, kwargs - - def forward_hook(nope, input, kwargs, result): - return 2 - - def backward_hook(): - pass - - def forward_hook_torch_version_below_2(): - pass - - return forward_pre_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 - - def test_get_tensor_ops(self): - result = get_tensor_ops() - self.assertIsInstance(result, set) - - def test_HOOKTensor(self): - hook_tensor = HOOKTensor() - self.assertIsInstance(hook_tensor, HOOKTensor) - - def test_TensorOPTemplate(self): - tensor_op_template = TensorOPTemplate('add', self.hook) - self.assertTrue(tensor_op_template.op_name_, 'add') - - def test_wrap_tensor_op(self): - wrapped_op = wrap_tensor_op('add', self.hook) - self.assertTrue(callable(wrapped_op)) - - def test_wrap_tensor_ops_and_bind(self): - wrap_tensor_ops_and_bind(self.hook) - self.assertTrue(hasattr(HOOKTensor, 'wrap_add')) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py deleted file mode 100644 index e0e4d000c0bd83be4facbbb406357427faf875ec..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +++ /dev/null @@ -1,48 +0,0 @@ -import unittest -import torch -from msprobe.pytorch.hook_module.wrap_torch import * - -class TestWrapTorch(unittest.TestCase): - - def hook(name, prefix): - def forward_pre_hook(nope, input, kwargs): - return input, kwargs - - def forward_hook(nope, input, kwargs, result): - return 2 - - def backward_hook(): - pass - - def forward_hook_torch_version_below_2(): - pass - - return forward_pre_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 - - def setUp(self): - - self.op_name = 'add' - self.torch_op = wrap_torch_op(self.op_name, self.hook) - - def test_get_torch_ops(self): - self.setUp() - ops = get_torch_ops() - self.assertIsInstance(ops, set) - self.assertIn(self.op_name, ops) - - def test_TorchOPTemplate(self): - self.setUp() - template = TorchOPTemplate(self.op_name, self.hook) - self.assertEqual(template.op_name_, self.op_name) - self.assertEqual(template.prefix_op_name_, "Torch." + str(self.op_name) + ".") - - def test_forward(self): - self.setUp() - template = TorchOPTemplate(self.op_name, self.hook) - result = template.forward(torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])) - torch.testing.assert_close(result, torch.tensor([5, 7, 9])) - - def test_wrap_torch_ops_and_bind(self): - self.setUp() - wrap_torch_ops_and_bind(self.hook) - self.assertTrue(hasattr(HOOKTorchOP, "wrap_" + self.op_name)) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py deleted file mode 100644 index 98efb4bc5b8a30284fe820124e48af7f487d1c54..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +++ /dev/null @@ -1,11 +0,0 @@ -import unittest -import torch -from msprobe.pytorch.hook_module import wrap_vf - -class TestWrapVF(unittest.TestCase): - def setUp(self): - self.hook = lambda x: x - - def test_get_vf_ops(self): - ops = wrap_vf.get_vf_ops() - self.assertIsInstance(ops, list) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/config/stack_config.json b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/config/stack_config.json new file mode 100644 index 0000000000000000000000000000000000000000..461b447ce0cd33fdcbab3476f7c1e3bcdee9dfad --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/config/stack_config.json @@ -0,0 +1,5 @@ +{ + "targets": {}, + "format": "csv", + "stack_info": true +} \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/demo_model.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/demo_model.py index f5de419440224cca261b62df2495e8ce28b8e2d4..820b1f7476d3d92288069bc00ac798c44bf14da6 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/demo_model.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/demo_model.py @@ -1,7 +1,25 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch import torch.nn.functional as F from msprobe.pytorch import TrainerMon from msprobe.pytorch.common import seed_all +from msprobe.pytorch.hook_module.api_register import get_api_register + +get_api_register().restore_all_api() device = torch.device('cpu') dtype_float32 = torch.float32 diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_anomaly_detect.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_anomaly_detect.py deleted file mode 100644 index fa0960e2cc1842a138b47fad3f86c1ed0d089db8..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_anomaly_detect.py +++ /dev/null @@ -1,291 +0,0 @@ -import unittest -from unittest import TestCase -from unittest.mock import patch - -from msprobe.pytorch.monitor.anomaly_detect import AnomalyTurbulence, AnomalyScanner, \ - AnomalyDataFactory, GradAnomalyData, BaseWriterWithAD, ScanRule, WriterInput - - -class TestScanRule(TestCase): - def test_apply_not_implemented(self): - scan_rule = ScanRule() - with self.assertRaises(Exception) as context: - scan_rule.apply(None, None) - - self.assertEqual(str(context.exception), "abstract method apply is not implemented") - - -class TestAnomalyTurbulence(TestCase): - - def setUp(self) -> None: - self.threshold = 0.2 - self.rule = AnomalyTurbulence(self.threshold) - - def test_apply_with_positive_baseline(self): - history = [10, 12, 14] - cur = 16 - result = self.rule.apply(history, cur) - self.assertTrue(result) - - def test_apply_with_non_positive_baseline(self): - history = [0, 0, 0] - cur = -1 - result = self.rule.apply(history, cur) - self.assertTrue(result) - - -class TestAnomalyScanner(TestCase): - - def test_load_rules_with_valied_spec(self): - specs = [ - {"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.2}} - ] - rules = AnomalyScanner.load_rules(specs) - - self.assertEqual(len(rules), 1) - self.assertIsInstance(rules[0], AnomalyTurbulence) - self.assertEqual(rules[0].threshold, 0.2) - - rules = AnomalyScanner.load_rules(None) - self.assertEqual(len(rules), 0) - - @patch("msprobe.pytorch.monitor.anomaly_detect.logger") - def test_load_rules_with_missing_keys(self, mock_logger): - specs = [ - {"rule_name": "AnomalyTurbulence"} - ] - rules = AnomalyScanner.load_rules(specs) - - self.assertEqual(len(rules), 0) - mock_logger.warning.assert_called_once_with(f"Spec is missing required keys: {specs[0]}") - - def test_load_rules_with_invalid_rule(self): - # test invalid rule_name - specs = [{"rule_name": "InvalidRule", "args": {"threshold": 0.2}}] - rules = AnomalyScanner.load_rules(specs) - self.assertEqual(len(rules), 0) - - # test invalid args - specs = [{"rule_name": "AnomalyTurbulence", "args": "invalid args"}] - rules = AnomalyScanner.load_rules(specs) - self.assertEqual(len(rules), 0) - - def test_scan(self): - ad_rules = [AnomalyTurbulence(0.2)] - # test scan with anomaly - expected = True, "AnomalyTurbulence" - self.assertEqual(AnomalyScanner.scan(ad_rules, 1.0, 2.0), expected) - # test scan with no anomaly - expected = False, None - self.assertEqual(AnomalyScanner.scan(ad_rules, 1.0, 1.0), expected) - - -class TestAnomalyDataFactory(TestCase): - - def setUp(self) -> None: - rank = 0 - pp_stage = 0 - group_mates = [0] - self.AnomalyDataFactory = AnomalyDataFactory(rank, pp_stage, group_mates) - - def test_set_call_id(self): - name2callid = {'param_name': 0} - self.AnomalyDataFactory.set_call_id(name2callid) - - self.assertEqual(self.AnomalyDataFactory.name2callid, {'param_name': 0}) - - def test_create_success(self): - tag = ('0:1.self_attention.core_attention_flash_0/rank0/output', 'min') - message = "Rule AnomalyTurbulence reports anomaly signal in ('0:1.self_attention.core_attention_flash_0/rank0/output', 'min') at step 2." - step = 2 - result = self.AnomalyDataFactory.create(tag, message, step) - - self.assertEqual(result.step, step) - self.assertEqual(result.tag_name, tag[0]) - self.assertEqual(result.message, message) - self.assertEqual(result.vpp_stage, 0) - - # test no vpp_stage - tag = ('1.self_attention.core_attention_flash_0/rank0/output', 'min') - result = self.AnomalyDataFactory.create(tag, message, step) - self.assertEqual(result.vpp_stage, 0) - - def test_create_failed(self): - error_tag = '0:1.self_attention.core_attention_flash_0/rank0/output' - message = "Rule AnomalyTurbulence reports anomaly signal in ('0:1.self_attention.core_attention_flash_0/rank0/output', 'min') at step 2." - step = 2 - with self.assertRaises(Exception) as context: - self.AnomalyDataFactory.create(error_tag, message, step) - self.assertEqual(str(context.exception), "tag must be a tuple with length 2") - - -class TestGradAnomalyData(TestCase): - - def setUp(self) -> None: - tag_name = "0:1.self_attention.core_attention_flash.output:0/rank0/actv" - message = "Rule AnomalyTurbulence reports anomaly signal in ('0:1.self_attention.core_attention_flash.output:0/rank0/actv', 'min') at step 2." - group_mates = [0] - self.GradAnomalyData = GradAnomalyData(tag_name=tag_name, message=message, group_mates=group_mates) - - def test_get_train_stage(self): - tag_name_list = ["0:fc2.input:0/rank0/actv", "0:fc1.weight/rank0/post_grad", "0:fc2.weight/rank0/exp_avg_sq", ""] - expected_train_stage_list = [0, 1, 2, -1] - for tag_name, expected_train_stage in zip(tag_name_list, expected_train_stage_list): - train_stage = GradAnomalyData.get_train_stage(tag_name) - self.assertEqual(train_stage, expected_train_stage) - - def test_to_dict(self): - expected = { - 'rank': 0, - 'step': 0, - 'micro_step': 0, - 'pp_stage': 0, - 'vpp_stage': 0, - 'call_id': 0, - 'tag_name': "0:1.self_attention.core_attention_flash.output:0/rank0/actv", - 'message': "Rule AnomalyTurbulence reports anomaly signal in ('0:1.self_attention.core_attention_flash.output:0/rank0/actv', 'min') at step 2.", - 'group_mates': [0] - } - - self.assertEqual(self.GradAnomalyData.to_dict(), expected) - - def test_get_key(self): - expected = "0:1.self_attention.core_attention_flash.output:0/rank0/actv_step_0_call_0" - - self.assertEqual(self.GradAnomalyData.get_key(), expected) - - def test_lt_different_step(self): - data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") - data2 = GradAnomalyData(step=2, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") - self.assertLess(data1, data2) - self.assertGreater(data2, data1) - - def test_lt_same_step_different_micro_step(self): - data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") - data2 = GradAnomalyData(step=1, micro_step=1, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") - self.assertLess(data1, data2) - self.assertGreater(data2, data1) - - def test_lt_same_step_same_micro_step_different_vpp_stage(self): - # same forward - data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/actv") - data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=1, pp_stage=0, call_id=0, tag_name="xxx/actv") - self.assertLess(data1, data2) - self.assertGreater(data2, data1) - - # same backward - data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/post_grad") - data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=1, pp_stage=0, call_id=0, tag_name="xxx/post_grad") - self.assertLess(data2, data1) - self.assertGreater(data1, data2) - - # diff train stage - data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/actv") - data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=1, pp_stage=0, call_id=0, tag_name="xxx/post_grad") - self.assertLess(data1, data2) - self.assertGreater(data2, data1) - - def test_lt_same_step_same_micro_step_same_vpp_stage_different_pp_stage(self): - # same forward - data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/actv") - data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=1, call_id=0, tag_name="xxx/actv") - self.assertLess(data1, data2) - self.assertGreater(data2, data1) - - # same backward - data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/post_grad") - data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=1, call_id=0, tag_name="xxx/post_grad") - self.assertLess(data2, data1) - self.assertGreater(data1, data2) - - # diff train stage - data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/input") - data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=1, call_id=0, tag_name="xxx/post_grad") - self.assertLess(data1, data2) - self.assertGreater(data2, data1) - - def test_lt_same_step_same_micro_step_same_vpp_stage_same_pp_stage_different_call_id(self): - data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") - data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=1, tag_name="") - self.assertLess(data1, data2) - self.assertGreater(data2, data1) - - def test_lt_same_data(self): - data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") - data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") - self.assertGreaterEqual(data1, data2) - self.assertLessEqual(data1, data2) - - def test_lt_not_instance(self): - data = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0) - not_instance = "not an instance of GradAnomalyData" - self.assertEqual(data.__lt__(not_instance), NotImplemented) - - def test_le_same_instance(self): - # 测试相同实例的情况 - data1 = GradAnomalyData() - self.assertTrue(data1 <= data1) - - def test_le_different_instance(self): - # 测试不同实例的情况 - data1 = GradAnomalyData() - data2 = GradAnomalyData() - self.assertTrue(data1 <= data2) - - def test_le_not_instance(self): - # 测试非GradAnomalyData实例的情况 - data = GradAnomalyData() - not_instance = "Not an instance of GradAnomalyData" - self.assertEqual(data.__le__(not_instance), NotImplemented) - - def test_le_different_instance_not_equal(self): - # 测试不同实例且不相等的情况 - data1 = GradAnomalyData() - data2 = GradAnomalyData() - data2.some_attribute = "some value" - self.assertTrue(data1 <= data2) - - -class TestBaseWriterWithAD(TestCase): - - def setUp(self) -> None: - self.BaseWriter = BaseWriterWithAD(WriterInput('', None, None)) - - def test_get_anomalies(self): - expected = [] - - self.assertEqual(self.BaseWriter.get_anomalies(), expected) - - def test_clear_anomalies(self): - self.BaseWriter.anomalies = ['anomaly1', 'anomaly2'] - self.BaseWriter.clear_anomalies() - - self.assertEqual(self.BaseWriter.anomalies, []) - - @patch("msprobe.pytorch.monitor.anomaly_detect.logger") - def test_add_scalar(self, mock_logger): - AnomalyTurbulence_obj = AnomalyTurbulence(0.2) - self.BaseWriter.ad_rules = [AnomalyTurbulence_obj] - self.BaseWriter.tag2scalars = {'tag': {'avg': 1.0, 'count': 1}} - self.BaseWriter.add_scalar('tag', 2.0) - - mock_logger.info.assert_called_once() - - def test_ad(self): - AnomalyTurbulence_obj = AnomalyTurbulence(0.2) - self.BaseWriter.ad_rules = [AnomalyTurbulence_obj] - expected = True, "AnomalyTurbulence" - - self.assertEqual(self.BaseWriter._ad(2.0, 1.0), expected) - - def test_update_tag2scalars(self): - self.BaseWriter._update_tag2scalars('tag1', 1.0) - self.assertEqual(self.BaseWriter.tag2scalars['tag1']['avg'], 1.0) - self.assertEqual(self.BaseWriter.tag2scalars['tag1']['count'], 1) - self.BaseWriter._update_tag2scalars('tag1', 2.0) - self.assertEqual(self.BaseWriter.tag2scalars['tag1']['avg'], 1.5) - self.assertEqual(self.BaseWriter.tag2scalars['tag1']['count'], 2) - - -if __name__ == '__main__': - unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_csv2tb.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_csv2tb.py index f2bc82ffafc2a1f10719d4a46669bc0050c12782..09e860e7ac5048bd059f888eabfd8ad1d7f45d37 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_csv2tb.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_csv2tb.py @@ -1,8 +1,22 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os import shutil import random import unittest -import pytest import torch import numpy as np import torch.nn as nn @@ -11,14 +25,13 @@ from tensorboard.backend.event_processing.event_accumulator import EventAccumula from msprobe.pytorch import TrainerMon from msprobe.core.common.const import MonitorConst from msprobe.pytorch.monitor.csv2tb import parse_step_fn, csv2tensorboard_by_step +from msprobe.pytorch.hook_module.api_register import get_api_register +get_api_register().restore_all_api() base_dir = os.path.dirname(os.path.realpath(__file__)) config_json_path = os.path.join(base_dir, "config", "all_config.json") monitor_output = os.path.join(base_dir, "./monitor_output_csv2tb") -os.environ[MonitorConst.MONITOR_OUTPUT_DIR] = monitor_output -timestamp_dirpath = None -csv2tb_dirpath = None def seed_all(seed=1234, mode=False): @@ -28,8 +41,8 @@ def seed_all(seed=1234, mode=False): torch.manual_seed(seed) torch.use_deterministic_algorithms(mode) -seed_all() +seed_all() inputs = [torch.rand(10, 10) for _ in range(10)] labels = [torch.randint(0, 5, (10,)) for _ in range(10)] @@ -47,31 +60,6 @@ class MockModule(nn.Module): return x2 -def data_collect(): - loss_fun = nn.CrossEntropyLoss() - test_module = MockModule() - nn.init.constant_(test_module.linear.weight, 1.0) - nn.init.constant_(test_module.linear.bias, 1.0) - optimizer = torch.optim.Adam(test_module.parameters()) - - monitor = TrainerMon(config_json_path, params_have_main_grad=False) - monitor.set_monitor(test_module, grad_acc_steps=1, optimizer=optimizer) - - for input_data, label in zip(inputs, labels): - output = test_module(input_data) - loss = loss_fun(output, label) - optimizer.zero_grad() - loss.backward() - optimizer.step() - - global timestamp_dirpath, csv2tb_dirpath - timestamp_dirpath = os.path.join(monitor_output, os.listdir(monitor_output)[0]) - csv2tensorboard_by_step(monitor_output) - for dirname in os.listdir(monitor_output): - if "csv2tensorboard" in dirname: - csv2tb_dirpath = os.path.join(monitor_output, dirname, "rank0") - - def extract_scalars_from_tensorboard(log_dir): # 初始化 EventAccumulator event_acc = EventAccumulator(log_dir) @@ -126,97 +114,102 @@ def compare_scalar_dicts(dict1, dict2): return True -@pytest.fixture(scope="session") -def setup_all(): - data_collect() - yield - shutil.rmtree(monitor_output) - -@pytest.mark.usefixtures("setup_all") class TestGradMonitor(unittest.TestCase): + timestamp_dirpath = None + csv2tb_dirpath = None + + @classmethod + def setUpClass(cls): + + os.environ[MonitorConst.MONITOR_OUTPUT_DIR] = monitor_output + if os.path.exists(monitor_output): + shutil.rmtree(monitor_output) + + loss_fun = nn.CrossEntropyLoss() + test_module = MockModule() + nn.init.constant_(test_module.linear.weight, 1.0) + nn.init.constant_(test_module.linear.bias, 1.0) + optimizer = torch.optim.Adam(test_module.parameters()) + + monitor = TrainerMon(config_json_path, params_have_main_grad=False) + monitor.set_monitor(test_module, grad_acc_steps=1, optimizer=optimizer) + + for input_data, label in zip(inputs, labels): + output = test_module(input_data) + loss = loss_fun(output, label) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + cls.timestamp_dirpath = os.path.join(monitor_output, os.listdir(monitor_output)[0]) + csv2tensorboard_by_step(monitor_output) + for dirname in os.listdir(monitor_output): + if "csv2tensorboard" in dirname: + cls.csv2tb_dirpath = os.path.join(monitor_output, dirname, "rank0") + os.environ.pop(MonitorConst.MONITOR_OUTPUT_DIR) def setUp(self): self.maxDiff = None - + def test_actv(self): - data = parse_step_fn(os.path.join(timestamp_dirpath,"actv_0-2.csv")) + data = parse_step_fn(os.path.join(self.timestamp_dirpath, "actv_0-2.csv")) result = { 'vp0:.input:micro0': { - 0: {'nans': 0.0,'norm': 5.550016}, - 1: {'nans': 0.0,'norm': 5.975112}, - 2: {'nans': 0.0,'norm': 5.789881} - }, + 0: {'nans': 0.0, 'norm': 5.550016}, + 1: {'nans': 0.0, 'norm': 5.975112}, + 2: {'nans': 0.0, 'norm': 5.789881} + }, 'vp0:.output:micro0': { - 0: {'nans': 0.0,'norm': 41.842655}, - 1: {'nans': 0.0,'norm': 44.40981}, - 2: {'nans': 0.0,'norm': 43.578354} - }, + 0: {'nans': 0.0, 'norm': 41.842655}, + 1: {'nans': 0.0, 'norm': 44.40981}, + 2: {'nans': 0.0, 'norm': 43.578354} + }, 'vp0:linear.input:micro0': { - 0: {'nans': 0.0,'norm': 5.550016}, - 1: {'nans': 0.0,'norm': 5.975112}, - 2: {'nans': 0.0,'norm': 5.789881} - }, + 0: {'nans': 0.0, 'norm': 5.550016}, + 1: {'nans': 0.0, 'norm': 5.975112}, + 2: {'nans': 0.0, 'norm': 5.789881} + }, 'vp0:linear.output:micro0': { - 0: {'nans': 0.0,'norm': 41.842655}, - 1: {'nans': 0.0,'norm': 44.40981}, - 2: {'nans': 0.0,'norm': 43.578354} - }, + 0: {'nans': 0.0, 'norm': 41.842655}, + 1: {'nans': 0.0, 'norm': 44.40981}, + 2: {'nans': 0.0, 'norm': 43.578354} + }, 'vp0:relu.input:micro0': { - 0: {'nans': 0.0,'norm': 41.842655}, - 1: {'nans': 0.0,'norm': 44.40981}, - 2: {'nans': 0.0,'norm': 43.578354} - }, + 0: {'nans': 0.0, 'norm': 41.842655}, + 1: {'nans': 0.0, 'norm': 44.40981}, + 2: {'nans': 0.0, 'norm': 43.578354} + }, 'vp0:relu.output:micro0': { - 0: {'nans': 0.0,'norm': 41.842655}, - 1: {'nans': 0.0,'norm': 44.40981}, - 2: {'nans': 0.0,'norm': 43.578354} - } + 0: {'nans': 0.0, 'norm': 41.842655}, + 1: {'nans': 0.0, 'norm': 44.40981}, + 2: {'nans': 0.0, 'norm': 43.578354} } - self.assertEqual(dict_equal(data, result), True) - tb_data = extract_scalars_from_tensorboard(os.path.join(csv2tb_dirpath, "actv")) + } + self.assertDictEqual(data, result) + tb_data = extract_scalars_from_tensorboard(os.path.join(self.csv2tb_dirpath, "actv")) print(tb_data) tb_result = { 'vp0:.input:micro0/nans': [(0, 0.0), - (1, 0.0), - (2, 0.0), - (3, 0.0), - (4, 0.0), - (5, 0.0), - (6, 0.0), - (7, 0.0), - (8, 0.0), - (9, 0.0)], + (1, 0.0), + (2, 0.0), + (3, 0.0), + (4, 0.0), + (5, 0.0), + (6, 0.0), + (7, 0.0), + (8, 0.0), + (9, 0.0)], 'vp0:.input:micro0/norm': [(0, 5.550015926361084), - (1, 5.975111961364746), - (2, 5.789881229400635), - (3, 6.052319049835205), - (4, 5.573315143585205), - (5, 5.864360809326172), - (6, 5.292460918426514), - (7, 5.477899074554443), - (8, 5.884613990783691), - (9, 5.456457138061523)], + (1, 5.975111961364746), + (2, 5.789881229400635), + (3, 6.052319049835205), + (4, 5.573315143585205), + (5, 5.864360809326172), + (6, 5.292460918426514), + (7, 5.477899074554443), + (8, 5.884613990783691), + (9, 5.456457138061523)], 'vp0:.output:micro0/nans': [(0, 0.0), - (1, 0.0), - (2, 0.0), - (3, 0.0), - (4, 0.0), - (5, 0.0), - (6, 0.0), - (7, 0.0), - (8, 0.0), - (9, 0.0)], - 'vp0:.output:micro0/norm': [(0, 41.842655181884766), - (1, 44.40980911254883), - (2, 43.57835388183594), - (3, 45.83631134033203), - (4, 42.0673828125), - (5, 43.46839141845703), - (6, 39.77947235107422), - (7, 40.200843811035156), - (8, 44.453147888183594), - (9, 40.841522216796875)], - 'vp0:linear.input:micro0/nans': [(0, 0.0), (1, 0.0), (2, 0.0), (3, 0.0), @@ -226,117 +219,136 @@ class TestGradMonitor(unittest.TestCase): (7, 0.0), (8, 0.0), (9, 0.0)], + 'vp0:.output:micro0/norm': [(0, 41.842655181884766), + (1, 44.40980911254883), + (2, 43.57835388183594), + (3, 45.83631134033203), + (4, 42.0673828125), + (5, 43.46839141845703), + (6, 39.77947235107422), + (7, 40.200843811035156), + (8, 44.453147888183594), + (9, 40.841522216796875)], + 'vp0:linear.input:micro0/nans': [(0, 0.0), + (1, 0.0), + (2, 0.0), + (3, 0.0), + (4, 0.0), + (5, 0.0), + (6, 0.0), + (7, 0.0), + (8, 0.0), + (9, 0.0)], 'vp0:linear.input:micro0/norm': [(0, 5.550015926361084), - (1, 5.975111961364746), - (2, 5.789881229400635), - (3, 6.052319049835205), - (4, 5.573315143585205), - (5, 5.864360809326172), - (6, 5.292460918426514), - (7, 5.477899074554443), - (8, 5.884613990783691), - (9, 5.456457138061523)], + (1, 5.975111961364746), + (2, 5.789881229400635), + (3, 6.052319049835205), + (4, 5.573315143585205), + (5, 5.864360809326172), + (6, 5.292460918426514), + (7, 5.477899074554443), + (8, 5.884613990783691), + (9, 5.456457138061523)], 'vp0:linear.output:micro0/nans': [(0, 0.0), - (1, 0.0), - (2, 0.0), - (3, 0.0), - (4, 0.0), - (5, 0.0), - (6, 0.0), - (7, 0.0), - (8, 0.0), - (9, 0.0)], + (1, 0.0), + (2, 0.0), + (3, 0.0), + (4, 0.0), + (5, 0.0), + (6, 0.0), + (7, 0.0), + (8, 0.0), + (9, 0.0)], 'vp0:linear.output:micro0/norm': [(0, 41.842655181884766), - (1, 44.40980911254883), - (2, 43.57835388183594), - (3, 45.83631134033203), - (4, 42.0673828125), - (5, 43.46839141845703), - (6, 39.77947235107422), - (7, 40.200843811035156), - (8, 44.453147888183594), - (9, 40.841522216796875)], + (1, 44.40980911254883), + (2, 43.57835388183594), + (3, 45.83631134033203), + (4, 42.0673828125), + (5, 43.46839141845703), + (6, 39.77947235107422), + (7, 40.200843811035156), + (8, 44.453147888183594), + (9, 40.841522216796875)], 'vp0:relu.input:micro0/nans': [(0, 0.0), - (1, 0.0), - (2, 0.0), - (3, 0.0), - (4, 0.0), - (5, 0.0), - (6, 0.0), - (7, 0.0), - (8, 0.0), - (9, 0.0)], + (1, 0.0), + (2, 0.0), + (3, 0.0), + (4, 0.0), + (5, 0.0), + (6, 0.0), + (7, 0.0), + (8, 0.0), + (9, 0.0)], 'vp0:relu.input:micro0/norm': [(0, 41.842655181884766), - (1, 44.40980911254883), - (2, 43.57835388183594), - (3, 45.83631134033203), - (4, 42.0673828125), - (5, 43.46839141845703), - (6, 39.77947235107422), - (7, 40.200843811035156), - (8, 44.453147888183594), - (9, 40.841522216796875)], + (1, 44.40980911254883), + (2, 43.57835388183594), + (3, 45.83631134033203), + (4, 42.0673828125), + (5, 43.46839141845703), + (6, 39.77947235107422), + (7, 40.200843811035156), + (8, 44.453147888183594), + (9, 40.841522216796875)], 'vp0:relu.output:micro0/nans': [(0, 0.0), - (1, 0.0), - (2, 0.0), - (3, 0.0), - (4, 0.0), - (5, 0.0), - (6, 0.0), - (7, 0.0), - (8, 0.0), - (9, 0.0)], + (1, 0.0), + (2, 0.0), + (3, 0.0), + (4, 0.0), + (5, 0.0), + (6, 0.0), + (7, 0.0), + (8, 0.0), + (9, 0.0)], 'vp0:relu.output:micro0/norm': [(0, 41.842655181884766), - (1, 44.40980911254883), - (2, 43.57835388183594), - (3, 45.83631134033203), - (4, 42.0673828125), - (5, 43.46839141845703), - (6, 39.77947235107422), - (7, 40.200843811035156), - (8, 44.453147888183594), - (9, 40.841522216796875)]} - self.assertEqual(compare_scalar_dicts(tb_data, tb_result), True) - + (1, 44.40980911254883), + (2, 43.57835388183594), + (3, 45.83631134033203), + (4, 42.0673828125), + (5, 43.46839141845703), + (6, 39.77947235107422), + (7, 40.200843811035156), + (8, 44.453147888183594), + (9, 40.841522216796875)]} + self.assertDictEqual(tb_data, tb_result) def test_actv_grad(self): - data = parse_step_fn(os.path.join(timestamp_dirpath,"actv_grad_0-2.csv")) + data = parse_step_fn(os.path.join(self.timestamp_dirpath, "actv_grad_0-2.csv")) nan = np.nan result = { 'vp0:.input:micro0': { - 0: {'norm': nan, 'nans': nan}, - 1: {'norm': nan, 'nans': nan}, + 0: {'norm': nan, 'nans': nan}, + 1: {'norm': nan, 'nans': nan}, 2: {'norm': nan, 'nans': nan} - }, + }, 'vp0:.output:micro0': { - 0: {'norm': 0.282843, 'nans': 0.0}, - 1: {'norm': 0.282617, 'nans': 0.0}, + 0: {'norm': 0.282843, 'nans': 0.0}, + 1: {'norm': 0.282617, 'nans': 0.0}, 2: {'norm': 0.282655, 'nans': 0.0} - }, + }, 'vp0:relu.input:micro0': { - 0: {'norm': 0.282843, 'nans': 0.0}, - 1: {'norm': 0.282617, 'nans': 0.0}, + 0: {'norm': 0.282843, 'nans': 0.0}, + 1: {'norm': 0.282617, 'nans': 0.0}, 2: {'norm': 0.282655, 'nans': 0.0} - }, + }, 'vp0:relu.output:micro0': { - 0: {'norm': 0.282843, 'nans': 0.0}, - 1: {'norm': 0.282617, 'nans': 0.0}, + 0: {'norm': 0.282843, 'nans': 0.0}, + 1: {'norm': 0.282617, 'nans': 0.0}, 2: {'norm': 0.282655, 'nans': 0.0} - }, + }, 'vp0:linear.input:micro0': { - 0: {'norm': nan, 'nans': nan}, - 1: {'norm': nan, 'nans': nan}, + 0: {'norm': nan, 'nans': nan}, + 1: {'norm': nan, 'nans': nan}, 2: {'norm': nan, 'nans': nan} - }, + }, 'vp0:linear.output:micro0': { - 0: {'norm': 0.282843, 'nans': 0.0}, - 1: {'norm': 0.282617, 'nans': 0.0}, + 0: {'norm': 0.282843, 'nans': 0.0}, + 1: {'norm': 0.282617, 'nans': 0.0}, 2: {'norm': 0.282655, 'nans': 0.0} - } } - self.assertEqual(dict_equal(data, result), True) - - tb_data = extract_scalars_from_tensorboard(os.path.join(csv2tb_dirpath, "actv_grad")) + } + print(data) + + tb_data = extract_scalars_from_tensorboard(os.path.join(self.csv2tb_dirpath, "actv_grad")) tb_result = { 'vp0:.input:micro0/nans': [(0, nan), (1, nan), @@ -457,88 +469,90 @@ class TestGradMonitor(unittest.TestCase): (6, 0.28316599130630493), (7, 0.28274500370025635), (8, 0.2833530008792877), - (9, 0.2825529873371124)]} - self.assertEqual(compare_scalar_dicts(tb_data, tb_result), True) + (9, 0.2825529873371124)] + } + print(tb_data) - def test_param(self): - data = parse_step_fn(os.path.join(timestamp_dirpath,"param_0-2.csv")) + data = parse_step_fn(os.path.join(self.timestamp_dirpath, "param_origin_0-2.csv")) result = { 'vp0:linear.bias': { 0: {'nans': 0.0, 'norm': 2.236068}, 1: {'nans': 0.0, 'norm': 2.236198}, 2: {'nans': 0.0, 'norm': 2.235769} - }, + }, 'vp0:linear.weight': { 0: {'nans': 0.0, 'norm': 7.071068}, 1: {'nans': 0.0, 'norm': 7.068808}, 2: {'nans': 0.0, 'norm': 7.06771} - } } - self.assertEqual(dict_equal(data, result), True) - tb_data = extract_scalars_from_tensorboard(os.path.join(csv2tb_dirpath, "param")) + } + self.assertDictEqual(data, result) + tb_data = extract_scalars_from_tensorboard(os.path.join(self.csv2tb_dirpath, "param_origin")) tb_result = { 'vp0:linear.weight/norm': [ - (0, 7.071067810058594), - (1, 7.068808078765869), - (2, 7.067709922790527), - (3, 7.0673418045043945), - (4, 7.066926956176758), - (5, 7.066311836242676), - (6, 7.065629959106445), - (7, 7.065262794494629), - (8, 7.065001964569092), - (9, 7.064840793609619)], + (0, 7.071067810058594), + (1, 7.068808078765869), + (2, 7.067709922790527), + (3, 7.0673418045043945), + (4, 7.066926956176758), + (5, 7.066311836242676), + (6, 7.065629959106445), + (7, 7.065262794494629), + (8, 7.065001964569092), + (9, 7.064840793609619)], 'vp0:linear.weight/nans': [ - (0, 0.0), - (1, 0.0), - (2, 0.0), - (3, 0.0), - (4, 0.0), - (5, 0.0), - (6, 0.0), - (7, 0.0), - (8, 0.0), - (9, 0.0)], + (0, 0.0), + (1, 0.0), + (2, 0.0), + (3, 0.0), + (4, 0.0), + (5, 0.0), + (6, 0.0), + (7, 0.0), + (8, 0.0), + (9, 0.0)], 'vp0:linear.bias/norm': [ - (0, 2.2360680103302), - (1, 2.2361979484558105), - (2, 2.235769033432007), - (3, 2.235903024673462), - (4, 2.2360129356384277), - (5, 2.2359039783477783), - (6, 2.2357990741729736), - (7, 2.2357349395751953), - (8, 2.2356700897216797), - (9, 2.235619068145752)], + (0, 2.2360680103302), + (1, 2.2361979484558105), + (2, 2.235769033432007), + (3, 2.235903024673462), + (4, 2.2360129356384277), + (5, 2.2359039783477783), + (6, 2.2357990741729736), + (7, 2.2357349395751953), + (8, 2.2356700897216797), + (9, 2.235619068145752) + ], 'vp0:linear.bias/nans': [ - (0, 0.0), - (1, 0.0), - (2, 0.0), - (3, 0.0), - (4, 0.0), - (5, 0.0), - (6, 0.0), - (7, 0.0), - (8, 0.0), - (9, 0.0)] - } - self.assertEqual(compare_scalar_dicts(tb_data, tb_result), True) + (0, 0.0), + (1, 0.0), + (2, 0.0), + (3, 0.0), + (4, 0.0), + (5, 0.0), + (6, 0.0), + (7, 0.0), + (8, 0.0), + (9, 0.0) + ] + } + self.assertDictEqual(tb_data, tb_result) def test_exp_avg(self): - data = parse_step_fn(os.path.join(timestamp_dirpath,"exp_avg_0-2.csv")) + data = parse_step_fn(os.path.join(self.timestamp_dirpath, "exp_avg_0-2.csv")) result = { 'vp0:linear.bias': { 1: {'nans': 0.0, 'norm': 0.024495}, 2: {'nans': 0.0, 'norm': 0.052203} - }, + }, 'vp0:linear.weight': { 1: {'nans': 0.0, 'norm': 0.052394}, 2: {'nans': 0.0, 'norm': 0.099221} - } } - self.assertEqual(dict_equal(data, result), True) - tb_data = extract_scalars_from_tensorboard(os.path.join(csv2tb_dirpath, "exp_avg")) + } + self.assertDictEqual(data, result) + tb_data = extract_scalars_from_tensorboard(os.path.join(self.csv2tb_dirpath, "exp_avg")) tb_result = { 'vp0:linear.bias/nans': [(1, 0.0), (2, 0.0), @@ -576,22 +590,22 @@ class TestGradMonitor(unittest.TestCase): (7, 0.11372199654579163), (8, 0.12264800071716309), (9, 0.09017200022935867)]} - self.assertEqual(compare_scalar_dicts(tb_data, tb_result), True) + self.assertDictEqual(tb_data, tb_result) def test_exp_avg_sq(self): - data = parse_step_fn(os.path.join(timestamp_dirpath,"exp_avg_sq_0-2.csv")) + data = parse_step_fn(os.path.join(self.timestamp_dirpath, "exp_avg_sq_0-2.csv")) result = { 'vp0:linear.bias': { 1: {'nans': 0.0, 'norm': 4.2e-05}, 2: {'nans': 0.0, 'norm': 9.6e-05} - }, + }, 'vp0:linear.weight': { 1: {'nans': 0.0, 'norm': 6.7e-05}, 2: {'nans': 0.0, 'norm': 0.000126} - } } - self.assertEqual(dict_equal(data, result), True) - tb_data = extract_scalars_from_tensorboard(os.path.join(csv2tb_dirpath, "exp_avg_sq")) + } + self.assertDictEqual(data, result) + tb_data = extract_scalars_from_tensorboard(os.path.join(self.csv2tb_dirpath, "exp_avg_sq")) tb_result = { 'vp0:linear.bias/nans': [(1, 0.0), (2, 0.0), @@ -629,24 +643,24 @@ class TestGradMonitor(unittest.TestCase): (7, 0.00026000000070780516), (8, 0.00028700000257231295), (9, 0.0003060000017285347)]} - self.assertEqual(compare_scalar_dicts(tb_data, tb_result), True) - + self.assertDictEqual(tb_data, tb_result) + def test_grad_reduced(self): - data = parse_step_fn(os.path.join(timestamp_dirpath,"grad_reduced_0-2.csv")) + data = parse_step_fn(os.path.join(self.timestamp_dirpath, "grad_reduced_0-2.csv")) result = { 'vp0:linear.bias': { 0: {'nans': 0.0, 'norm': 0.244949}, 1: {'nans': 0.0, 'norm': 0.314345}, 2: {'nans': 0.0, 'norm': 0.281475} - }, + }, 'vp0:linear.weight': { 0: {'nans': 0.0, 'norm': 0.523935}, 1: {'nans': 0.0, 'norm': 0.595672}, 2: {'nans': 0.0, 'norm': 0.497603} - } } - self.assertEqual(dict_equal(data, result), True) - tb_data = extract_scalars_from_tensorboard(os.path.join(csv2tb_dirpath, "grad_reduced")) + } + self.assertDictEqual(data, result) + tb_data = extract_scalars_from_tensorboard(os.path.join(self.csv2tb_dirpath, "grad_reduced")) tb_result = { 'vp0:linear.bias/nans': [(0, 0.0), (1, 0.0), @@ -688,25 +702,25 @@ class TestGradMonitor(unittest.TestCase): (7, 0.4831080138683319), (8, 0.3234719932079315), (9, 0.32385098934173584)]} - self.assertEqual(compare_scalar_dicts(tb_data, tb_result), True) - + self.assertDictEqual(tb_data, tb_result) + def test_grad_unreduced(self): - data = parse_step_fn(os.path.join(timestamp_dirpath,"grad_unreduced_0-2.csv")) + data = parse_step_fn(os.path.join(self.timestamp_dirpath, "grad_unreduced_0-2.csv")) result = { 'vp0:linear.bias': { 0: {'nans': 0.0, 'norm': 0.244949}, 1: {'nans': 0.0, 'norm': 0.314345}, 2: {'nans': 0.0, 'norm': 0.281475} - }, + }, 'vp0:linear.weight': { 0: {'nans': 0.0, 'norm': 0.523935}, 1: {'nans': 0.0, 'norm': 0.595672}, 2: {'nans': 0.0, 'norm': 0.497603} - } } - self.assertEqual(dict_equal(data, result), True) + } + self.assertDictEqual(data, result) - tb_data = extract_scalars_from_tensorboard(os.path.join(csv2tb_dirpath, "grad_unreduced")) + tb_data = extract_scalars_from_tensorboard(os.path.join(self.csv2tb_dirpath, "grad_unreduced")) tb_result = { 'vp0:linear.bias/nans': [(0, 0.0), (1, 0.0), @@ -748,4 +762,8 @@ class TestGradMonitor(unittest.TestCase): (7, 0.4831080138683319), (8, 0.3234719932079315), (9, 0.32385098934173584)]} - self.assertEqual(compare_scalar_dicts(tb_data, tb_result), True) + self.assertDictEqual(tb_data, tb_result) + + +if __name__ == '__main__': + unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_data_writers.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_data_writers.py new file mode 100644 index 0000000000000000000000000000000000000000..34204267935cd7691f5bcccce6c1af5451a2c34f --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_data_writers.py @@ -0,0 +1,52 @@ +import unittest +from unittest import TestCase +from unittest.mock import patch + +from msprobe.core.monitor.anomaly_processor import AnomalyTurbulence +from msprobe.pytorch.monitor.data_writers import BaseWriterWithAD, WriterInput + + +class TestBaseWriterWithAD(TestCase): + + def setUp(self) -> None: + self.BaseWriter = BaseWriterWithAD(WriterInput('', None, None)) + + def test_get_anomalies(self): + expected = [] + + self.assertEqual(self.BaseWriter.get_anomalies(), expected) + + def test_clear_anomalies(self): + self.BaseWriter.anomalies = ['anomaly1', 'anomaly2'] + self.BaseWriter.clear_anomalies() + + self.assertEqual(self.BaseWriter.anomalies, []) + + @patch("msprobe.pytorch.monitor.data_writers.logger") + def test_add_scalar(self, mock_logger): + AnomalyTurbulence_obj = AnomalyTurbulence(0.2) + self.BaseWriter.ad_rules = [AnomalyTurbulence_obj] + tag = ('0:1.post_attention_norm.weight/rank0/pre_grad', 'mean') + self.BaseWriter.tag2scalars = {tag: {'avg': 1.0, 'count': 1}} + self.BaseWriter.add_scalar(tag, 2.0) + + mock_logger.info.assert_called_once() + + def test_ad(self): + AnomalyTurbulence_obj = AnomalyTurbulence(0.2) + self.BaseWriter.ad_rules = [AnomalyTurbulence_obj] + expected = True, "AnomalyTurbulence" + + self.assertEqual(self.BaseWriter._ad(2.0, 1.0), expected) + + def test_update_tag2scalars(self): + self.BaseWriter._update_tag2scalars('tag1', 1.0) + self.assertEqual(self.BaseWriter.tag2scalars['tag1']['avg'], 1.0) + self.assertEqual(self.BaseWriter.tag2scalars['tag1']['count'], 1) + self.BaseWriter._update_tag2scalars('tag1', 2.0) + self.assertEqual(self.BaseWriter.tag2scalars['tag1']['avg'], 1.01) + self.assertEqual(self.BaseWriter.tag2scalars['tag1']['count'], 2) + + +if __name__ == '__main__': + unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_features.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_features.py index ff00cf7490d8110f2198df57ee5d91b6b75f5092..a2e7ecb512d195305e48badd03c7e1b4ae30b237 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_features.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_features.py @@ -1,8 +1,10 @@ import unittest +from unittest.mock import patch + import torch from msprobe.pytorch.monitor.features import square_sum, get_min, get_mean, get_norm, get_max, get_zeros, \ get_sign_matches, eff_rank, mNTK, lambda_max_subsample, cal_histc, get_nans - +from msprobe.pytorch.monitor.features import max_eigenvalue, cal_entropy, cal_qkt, cal_stable_rank class TestMathFunctions(unittest.TestCase): def test_square_sum(self): @@ -87,6 +89,74 @@ class TestMathFunctions(unittest.TestCase): result = get_nans(tensor) self.assertEqual(result, 1) + def test_max_eigenvalue(self): + """测试最大特征值计算""" + # 创建已知特征值的矩阵 + A = torch.diag(torch.tensor([3.0, 2.0, 1.0])) + + # 测试不同迭代次数 + eigval = max_eigenvalue(A, num_iterations=5) + self.assertAlmostEqual(eigval.item(), 3.0, delta=0.1) + + # 测试全零矩阵 + zero_matrix = torch.zeros(3, 3) + eigval = max_eigenvalue(zero_matrix) + self.assertAlmostEqual(eigval.item(), 0.0) + + def test_cal_entropy(self): + """测试注意力熵计算""" + # 创建简单的注意力分数 + qk = torch.tensor([[1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0]]) + + # 无mask + entropy, softmax_max = cal_entropy(qk) + self.assertAlmostEqual(entropy, 0.4715, delta=0.1) + self.assertAlmostEqual(softmax_max, 0.7988, delta=0.1) + + # 带mask 和默认生成相同 + mask = torch.tensor([[1, 0, 0], + [1, 1, 0], + [1, 1, 1]], dtype=torch.float) + entropy, softmax_max = cal_entropy(qk, mask) + self.assertAlmostEqual(entropy, 0.4715, delta=0.1) + self.assertAlmostEqual(softmax_max, 0.7988, delta=0.1) + + @patch("msprobe.pytorch.monitor.features.logger") + def test_cal_qkt(self, mock_logger): + """测试QK^T计算""" + # 测试s,b,h,d顺序 + q = torch.randn(10, 2, 4, 8) # [s, b, h, d] + k = torch.randn(10, 2, 4, 8) # [s, b, h, d] + q_batch = torch.randn(2, 10, 4, 8) # [b, s, h, d] + qkt = cal_qkt(q, k, order="s,b,h,d") + self.assertEqual(qkt.shape, (10, 10)) # [s, s] + + # 测试b,s,h,d顺序 + qkt = cal_qkt(q_batch, q_batch, order="b,s,h,d") + self.assertEqual(qkt.shape, (10, 10)) # [s, s] + + # 测试无效顺序 + cal_qkt(q, k, order="invalid_order") + mock_logger.warning.assert_called_with( + "Calculate qk tensor failed: Order unsupported.") + + def test_cal_stable_rank(self): + """测试谱半径计算""" + # 创建已知谱半径的矩阵 + A = torch.diag(torch.tensor([3.0, 2.0, 1.0])) + sr, eig = cal_stable_rank(A) + + # 验证Frobenius范数 + fro_norm = torch.norm(A, p='fro') + self.assertAlmostEqual(sr, fro_norm / 3.0, delta=.5) # 最大特征值为3 + + # 测试正交矩阵 + ortho = torch.eye(5) + sr, eig = cal_stable_rank(ortho) + self.assertAlmostEqual(sr, torch.tensor(2.23/1), delta=.5) # F范数应为2.23 + self.assertAlmostEqual(eig, torch.tensor(1.0), delta=.1) # 特征值应为1 if __name__ == '__main__': unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_module_hook.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_module_hook.py index eefacb73c8e76636086554775b0e6f2e916ddf6e..6c3d2b925a4fbcd7ec7c81b85614b6be0e731b0c 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_module_hook.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_module_hook.py @@ -1,3 +1,18 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os.path import shutil import unittest @@ -8,10 +23,13 @@ import torch from msprobe.core.common.const import MonitorConst, Const from torch import distributed as dist +from msprobe.pytorch import TrainerMon +from msprobe.pytorch.hook_module.api_register import get_api_register from msprobe.pytorch.monitor.module_hook import CommunicationContext, GradContext, ModuleHookContext, \ param_is_not_tensor_parallel_duplicate, param_is_data_parallel_duplicate from msprobe.test.pytorch_ut.monitor.demo_model import monitor_demo -from msprobe.pytorch import TrainerMon + +get_api_register().restore_all_api() base_dir = os.path.dirname(os.path.realpath(__file__)) @@ -72,13 +90,13 @@ class TestModuleHook(unittest.TestCase): self.assertTrue(os.path.exists(actv_grad_0_csv)) # validate columns and lines actv_0 = pd.read_csv(actv_0_csv) - expect_columns = ['vpp_stage', 'name', 'step', 'micro_step', 'norm', 'nans'] + expect_columns = ['vpp_stage', 'name', 'step', 'micro_step', 'norm', 'nans', "shape", "dtype"] self.assertListEqual(list(actv_0.columns), expect_columns) - self.assertEqual(actv_0.shape, tuple([6, 6])) + self.assertEqual(actv_0.shape, tuple([6, 8])) actv_grad_0 = pd.read_csv(actv_grad_0_csv) - expect_columns = ['vpp_stage', 'name', 'step', 'micro_step', 'norm', 'nans'] + expect_columns = ['vpp_stage', 'name', 'step', 'micro_step', 'norm', 'nans', "shape", "dtype"] self.assertListEqual(list(actv_grad_0.columns), expect_columns) - self.assertEqual(actv_0.shape, tuple([6, 6])) + self.assertEqual(actv_0.shape, tuple([6, 8])) def test_wg_distribution(self): self.get_dist_mock(False) @@ -95,13 +113,13 @@ class TestModuleHook(unittest.TestCase): self.assertTrue(os.path.exists(grad_reduced_0_csv)) self.assertTrue(os.path.exists(grad_unreduced_0_csv)) # validate columns and lines - expect_columns = ["vpp_stage", "name", "step", "norm"] + expect_columns = ["vpp_stage", "name", "step", "norm", "shape", "dtype"] grad_reduced_0 = pd.read_csv(grad_reduced_0_csv) self.assertListEqual(list(grad_reduced_0.columns), expect_columns) - self.assertEqual(grad_reduced_0.shape, tuple([2, 4])) + self.assertEqual(grad_reduced_0.shape, tuple([2, 6])) grad_unreduced_0 = pd.read_csv(grad_unreduced_0_csv) self.assertListEqual(list(grad_unreduced_0.columns), expect_columns) - self.assertEqual(grad_unreduced_0.shape, tuple([2, 4])) + self.assertEqual(grad_unreduced_0.shape, tuple([2, 6])) def test_mv_distribution(self): self.get_dist_mock(False) @@ -118,13 +136,13 @@ class TestModuleHook(unittest.TestCase): self.assertTrue(os.path.exists(exp_avg_1_csv)) self.assertTrue(os.path.exists(exp_avg_sq_1_csv)) # validate columns and lines - expect_columns = ["vpp_stage", "name", "step", "norm"] + expect_columns = ["vpp_stage", "name", "step", "norm", "shape", "dtype"] exp_avg_1 = pd.read_csv(exp_avg_1_csv) self.assertListEqual(list(exp_avg_1.columns), expect_columns) - self.assertEqual(exp_avg_1.shape, tuple([2, 4])) + self.assertEqual(exp_avg_1.shape, tuple([2, 6])) exp_avg_sq_1 = pd.read_csv(exp_avg_sq_1_csv) self.assertListEqual(list(exp_avg_sq_1.columns), expect_columns) - self.assertEqual(exp_avg_sq_1.shape, tuple([2, 4])) + self.assertEqual(exp_avg_sq_1.shape, tuple([2, 6])) def test_ur_distribution(self): self.get_dist_mock(False) @@ -149,6 +167,18 @@ class TestModuleHook(unittest.TestCase): ) self.assertIsNotNone(hooker) + def test_stack_collect(self): + self.get_dist_mock(False) + stack_monitor_output = "./test_stack_info" + clean_output(stack_monitor_output) + os.environ[MonitorConst.MONITOR_OUTPUT_DIR] = stack_monitor_output + stack_config = os.path.join(base_dir, "config/stack_config.json") + monitor_demo(stack_config) + output_dir_list = os.listdir(stack_monitor_output) + self.assertEqual(len(output_dir_list), 1) + stack_csv_path = os.path.join(stack_monitor_output, output_dir_list[0], "stack_info.csv") + self.assertTrue(os.path.exists(stack_csv_path)) + def test_adhoc_check(self): # mock dist self.get_dist_mock(True) @@ -243,61 +273,6 @@ class TestParamIsDataParallelDuplicate(unittest.TestCase): self.assertFalse(result) -class TestModuleHookContext(unittest.TestCase): - def setUp(self): - self.module_name = "test_module" - self.context = ModuleHookContext(self.module_name) - self.context.struct = { - Const.INPUT: { - "config": "tuple[1]", - "0": "size=(2, 784), dtype=torch.float32", - }, - Const.OUTPUT: { - "config": "tensor", - "tensor": "size=(2, 10), dtype=torch.float32" - }, - MonitorConst.INPUT_GRAD: { - "config": "tuple[1]", - "0": "size=(2, 784), dtype=torch.float32" - }, - MonitorConst.OUTPUT_GRAD: { - "config": "tuple[1]", - "0": "size=(2, 10), dtype=torch.float32" - } - } - self.target_config = { - self.module_name: { - Const.INPUT: "tuple[1]:0", - Const.OUTPUT: "tensor", - MonitorConst.INPUT_GRAD: "tuple[1]:0" - } - } - - def test_set_format_by_arg_module_name_in_target_config(self): - self.context.set_format_by_arg(Const.INPUT, self.target_config) - self.assertEqual(self.context.format_by_arg[Const.INPUT], "tuple[1]:0") - self.context.set_format_by_arg(Const.OUTPUT, self.target_config) - self.assertEqual(self.context.format_by_arg[Const.OUTPUT], "tensor") - self.context.set_format_by_arg(MonitorConst.INPUT_GRAD, self.target_config) - self.assertEqual(self.context.format_by_arg[MonitorConst.INPUT_GRAD], "tuple[1]:0") - self.context.set_format_by_arg(MonitorConst.OUTPUT_GRAD, self.target_config) - self.assertEqual(self.context.format_by_arg[MonitorConst.OUTPUT_GRAD], "tuple[1]") - - def test_set_format_by_arg_module_name_not_in_target_config(self): - target_config = {} - self.context.set_format_by_arg(Const.INPUT, target_config) - self.assertEqual(self.context.format_by_arg[Const.INPUT], "tuple[1]") - self.context.set_format_by_arg(Const.OUTPUT, target_config) - self.assertEqual(self.context.format_by_arg[Const.OUTPUT], "tensor") - - @patch('msprobe.pytorch.monitor.module_hook.logger') - def test_set_format_by_arg_target_module_config_error(self, mock_logger): - target_config = {self.module_name: {Const.INPUT: 123}} - self.context.set_format_by_arg(Const.INPUT, target_config) - self.assertIsNone(self.context.format_by_arg.get(Const.INPUT)) - mock_logger.warning_on_rank_0.assert_called_once() - - class TestContext(unittest.TestCase): def test_communication_context(self): cc_ctx = CommunicationContext() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py index 0462ac3f39531119b40d3cc5051fad77f687b9b5..2f10d4f12906bfd91cfd304d157f5946dd68524b 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py @@ -5,10 +5,11 @@ from unittest.mock import patch, MagicMock import torch from msprobe.core.common.const import MonitorConst -from msprobe.pytorch.monitor.utils import filter_special_chars, MsgConst, get_param_struct, validate_ops, \ - validate_ranks, validate_targets, validate_print_struct, validate_ur_distribution, validate_xy_distribution, \ +from msprobe.core.monitor.utils import filter_special_chars, MsgConst, validate_ops, validate_ranks, \ + validate_targets, validate_print_struct, validate_ur_distribution, validate_xy_distribution, \ validate_mg_distribution, validate_wg_distribution, validate_cc_distribution, validate_alert, validate_config, \ - get_output_base_dir + get_output_base_dir, validate_l2_targets, validate_recording_l2_features, validate_sa_order +from msprobe.pytorch.monitor.utils import get_param_struct from msprobe.pytorch.common.utils import is_recomputation @@ -44,12 +45,12 @@ class TestValidationFunctions(unittest.TestCase): def test_validate_ops(self): ops = ['op1', 'op2', 'norm', 'max'] valid_ops = validate_ops(ops) - self.assertEqual(valid_ops, ['norm', 'max']) + self.assertEqual(valid_ops, ['norm', 'max', "shape", "dtype"]) def test_no_valid_ops(self): ops = ['op1', 'op2'] valid_ops = validate_ops(ops) - target_ops = [MonitorConst.OP_LIST[0]] + target_ops = [MonitorConst.OP_LIST[0], "shape", "dtype"] self.assertEqual(valid_ops, target_ops) def test_validate_ranks(self): @@ -104,13 +105,72 @@ class TestValidationFunctions(unittest.TestCase): 'alert': {'rules': [{'rule_name': 'AnomalyTurbulence', 'args': {'threshold': 10.0}}], 'dump': True} } validate_config(config) - target_ops = [MonitorConst.OP_LIST[0]] + target_ops = [MonitorConst.OP_LIST[0], "shape", "dtype"] self.assertEqual(config["ops"], target_ops) del config["targets"] validate_config(config) self.assertEqual(config["targets"], {"": {}}) self.assertEqual(config["all_xy"], True) + # ===== validate_l2_targets 测试 ===== + def test_validate_l2_targets_valid_input(self): + """测试合法输入""" + valid_targets = { + "attention_hook": ["0:0.self_attention.core_attention.flash_attention"], + "linear_hook": [] + } + validate_l2_targets(valid_targets) + + def test_validate_l2_targets_invalid_root_type(self): + """测试非 dict 输入""" + with self.assertRaises(TypeError) as cm: + validate_l2_targets("not_a_dict") + self.assertEqual(str(cm.exception), + 'l2_targets in config.json should be a dict') + + def test_validate_l2_targets_invalid_hook_name(self): + """测试非法 hook_name""" + with self.assertRaises(TypeError) as cm: + validate_l2_targets({"invalid_hook": ["module1"]}) + self.assertIn(f'key of l2_targtes must be in {MonitorConst.L2_HOOKS}', + str(cm.exception)) + + def test_validate_l2_targets_invalid_value_type(self): + """测试非法 value 类型""" + with self.assertRaises(TypeError) as cm: + validate_l2_targets({"linear_hook": "not_a_list"}) + self.assertEqual(str(cm.exception), + 'values of l2_targets should be a list in config.json') + + def test_validate_l2_targets_invalid_item_type(self): + """测试非法 list item 类型""" + with self.assertRaises(TypeError) as cm: + validate_l2_targets({"linear_hook": [123]}) + self.assertEqual(str(cm.exception), + 'item of "linear_hook" in l2_targets should be module_name[str] in config.json') + + # ===== validate_recording_l2_features 测试 ===== + def test_validate_recording_l2_features_valid(self): + """测试合法布尔值输入""" + validate_recording_l2_features(True) + validate_recording_l2_features(False) + + def test_validate_recording_l2_features_invalid_type(self): + """测试非法类型输入""" + with self.assertRaises(TypeError) as cm: + validate_recording_l2_features("xx") + self.assertEqual(str(cm.exception), + "recording_l2_features should be a bool") + + def test_valid_orders(self): + validate_sa_order("b,s,h,d") + validate_sa_order("s, b,h, d") + + def test_invalid_orders(self): + with self.assertRaises(TypeError) as cm: + validate_recording_l2_features("xx") + self.assertEqual(str(cm.exception), + f'sa_order must be in {MonitorConst.SA_ORDERS}, got xx') class TestIsRecomputation(unittest.TestCase): @patch('inspect.stack') diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_optimizer_collect.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_optimizer_collect.py index 793b086b02db03f8a04b159f35f1df55fc1a9d2c..e8cbd00a0f31589104a50340f990558bd0277be9 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_optimizer_collect.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_optimizer_collect.py @@ -3,18 +3,51 @@ from collections import defaultdict from unittest.mock import Mock, patch, MagicMock import torch +from msprobe.core.common.const import MonitorConst from msprobe.pytorch.monitor.optimizer_collect import OptimizerMon, \ - OptimizerMonFactory, DummyOptimizerMon, \ - MixPrecisionOptimizerMon, MegatronDistributedOptimizerMon, MegatronFP32OptimizerMon, \ + OptimizerMonFactory, MegatronMixPrecisionOptimizerMon, MegatronDistributedOptimizerMon, \ MegatronChainedDistributedOptimizerMon, MegatronChainedMixPrecisionOptimizerMon, \ - DeepSpeedZeroOptimizerStage0Mon, DeepSpeedZeroOptimizerStage1or2Mon, DeepSpeedZeroOptimizerStage3Mon - -from msprobe.pytorch.monitor.utils import MVResult, MVGradResult - + DeepSpeedZeroOptimizerMon, DeepSpeedZeroOptimizerStage0Mon, \ + DeepSpeedZeroOptimizerStage1or2Mon, DeepSpeedZeroOptimizerStage3Mon +from msprobe.core.monitor.utils import MVResult + + +def setup_param_groups(num_groups=2, params_per_group=5): + bit16_groups = [] + param_names = {} + grad_position = {} + param_slice_mappings = [] + count = 0 + for group_idx in range(num_groups): + group = [] + param_slice_mapping = {} + offset = 0 + for i in range(params_per_group): + name = f'param{group_idx}_{i}' + p = torch.nn.Parameter(torch.randn(2,3, dtype=torch.bfloat16)) + p.ds_tensor = torch.nn.Parameter(torch.randn(1,3, dtype=torch.bfloat16)) + p.ds_id = count + param_slice_mapping[name] = MagicMock(start=offset, numel=p.numel()) + group.append(p) + param_names[p] = name + grad_position[count] = [group_idx, offset, p.numel()] + offset += p.numel() + count += 1 + bit16_groups.append(group) + param_slice_mappings.append(param_slice_mapping) + + return bit16_groups, param_names, param_slice_mappings, grad_position + +def setup_mock_monitor(): + mock_monitor = MagicMock() + mock_monitor.mv_distribution = True + mock_monitor.mg_direction = False + mock_monitor.ur_distribution = False + + return mock_monitor class TestOptimizerMon(unittest.TestCase): def setUp(self) -> None: - # 初始化需要的monitor, torch_opt, params2name等对象 self.monitor = Mock() self.monitor.mv_distribution = True self.monitor.mg_direction = True @@ -23,11 +56,11 @@ class TestOptimizerMon(unittest.TestCase): self.monitor.ratio_heatmap_visualizer = {'param1': Mock(), 'param2': Mock()} def test_fetch_mv(self): - optimizer_mon = OptimizerMon() - res = optimizer_mon.fetch_mv(None, None, None) - self.assertEqual(res, None) + optimizer_mon = OptimizerMon(None) + res = optimizer_mon.fetch_mv(None, {}) + self.assertEqual(res.exp_avg, {}) - def test_fetch_mv_in_adam(self): + def test_fetch_mv(self): self.torch_opt = Mock() self.torch_opt.state = { 'param1': {'exp_avg': torch.tensor(0.1), 'exp_avg_sq': torch.tensor(0.2), 'step': torch.tensor(10)}, @@ -37,48 +70,10 @@ class TestOptimizerMon(unittest.TestCase): self.torch_opt.defaults = {'betas': (0.9, 0.999), 'eps': 1e-8} self.params2name = {'param1': 'param1', 'param2': 'param2'} - self.optimizer_mon = OptimizerMon() - result = self.optimizer_mon._fetch_mv_in_adam(self.monitor, self.torch_opt, self.params2name) + self.optimizer_mon = OptimizerMon(None) + result = self.optimizer_mon.fetch_mv(self.monitor, self.params2name) self.assertIsInstance(result, MVResult) - @patch('msprobe.pytorch.monitor.optimizer_collect.dist') - def test_fetch_mv_grad_in_adam(self, mock_dist): - self.optimizer_mon = OptimizerMon() - self.monitor = MagicMock() - self.torch_opt = MagicMock() - self.params2name = defaultdict(str) - self.name2indices = defaultdict(tuple) - self.fp32_partitioned_groups_flat = defaultdict(torch.Tensor) - - # Mocking the dist.get_rank() and dist.get_world_size() - mock_dist.get_rank.return_value = 0 - mock_dist.get_world_size.return_value = 1 - - # Mocking the wrapped_optimizer - self.torch_opt.state = defaultdict(dict) - self.torch_opt.averaged_gradients = defaultdict(torch.Tensor) - self.torch_opt.partition_size = defaultdict(int) - self.torch_opt.flatten_dense_tensors_aligned = MagicMock() - self.torch_opt.flatten = MagicMock() - - # Mocking the torch_opt.param_groups - self.torch_opt.param_groups = [{'step': 1, 'betas': (0.9, 0.999)}, - {'step': 2, 'betas': (0.9, 0.999)}, - {'step': 3, 'betas': (0.9, 0.999)}] - - # Mocking the monitor.mv_distribution, monitor.mg_direction, monitor.ur_distribution - self.monitor.mv_distribution = True - self.monitor.mg_direction = True - self.monitor.ur_distribution = True - - # Mocking the monitor.update_heatmap_visualizer and monitor.ratio_heatmap_visualizer - self.monitor.update_heatmap_visualizer = defaultdict(MagicMock) - self.monitor.ratio_heatmap_visualizer = defaultdict(MagicMock) - - result = self.optimizer_mon._fetch_mv_grad_in_adam(self.monitor, self.torch_opt, self.params2name, - self.name2indices, self.fp32_partitioned_groups_flat) - self.assertIsInstance(result, MVGradResult) - class TestMixPrecisionOptimizerMon(unittest.TestCase): def test_fetch_mv_with_fp16_to_fp32_param_and_mix_prec_opt(self): @@ -89,16 +84,16 @@ class TestMixPrecisionOptimizerMon(unittest.TestCase): self.mix_prec_opt = MagicMock() self.mix_prec_opt.float16_groups = [MagicMock()] self.mix_prec_opt.fp32_from_float16_groups = [MagicMock()] - self.optimizer = MixPrecisionOptimizerMon() + self.optimizer = MegatronMixPrecisionOptimizerMon(self.torch_opt) self.optimizer.fp16_to_fp32_param = {} - # Mock _fetch_mv_in_adam method and set a fixed return value + # Mock fetch_mv method and set a fixed return value mv_result = MVResult(exp_avg={}, exp_avg_sq={}, update={}, ratio={}) - self.mock_fetch_mv_in_adam = MagicMock(return_value=mv_result) - self.optimizer._fetch_mv_in_adam = self.mock_fetch_mv_in_adam + self.mock_fetch_mv = MagicMock(return_value=mv_result) + self.optimizer.fetch_mv = self.mock_fetch_mv - res = self.optimizer.fetch_mv(self.monitor, self.torch_opt, self.params2name) - self.mock_fetch_mv_in_adam.assert_called_once_with(self.monitor, self.torch_opt, self.params2name) + res = self.optimizer.fetch_mv(self.monitor, self.params2name) + self.mock_fetch_mv.assert_called_once_with(self.monitor, self.params2name) self.assertIsInstance(res, MVResult) @@ -110,17 +105,17 @@ class TestChainedMixPrecisionOptimizerMon(unittest.TestCase): self.params2name = MagicMock() self.torch_opt.float16_groups = [MagicMock()] self.torch_opt.fp32_from_float16_groups = [MagicMock()] - self.optimizer = MegatronChainedMixPrecisionOptimizerMon() + self.optimizer = MegatronChainedMixPrecisionOptimizerMon(self.torch_opt) self.optimizer.optimizer = [MagicMock(), MagicMock()] self.optimizer.fp16_to_fp32_param = {} - # Mock _fetch_mv_in_adam method and set a fixed return value + # Mock fetch_mv method and set a fixed return value mv_result = MVResult(exp_avg={}, exp_avg_sq={}, update={}, ratio={}) - self.mock_fetch_mv_in_adam = MagicMock(return_value=mv_result) - self.optimizer._fetch_mv_in_adam = self.mock_fetch_mv_in_adam + self.mock_fetch_mv = MagicMock(return_value=mv_result) + self.optimizer.fetch_mv = self.mock_fetch_mv - res = self.optimizer.fetch_mv(self.monitor, self.torch_opt, self.params2name) - self.mock_fetch_mv_in_adam.assert_called_once_with(self.monitor, self.torch_opt, self.params2name) + res = self.optimizer.fetch_mv(self.monitor, self.params2name) + self.mock_fetch_mv.assert_called_once_with(self.monitor, self.params2name) self.assertIsInstance(res, MVResult) @@ -129,26 +124,27 @@ class TestMegatronChainedDistributedOptimizerMon(unittest.TestCase): self.monitor = MagicMock() self.torch_opt = MagicMock() self.params2name = MagicMock() + self.torch_opt.chained_optimizers = [MagicMock(), MagicMock()] mv_result = MVResult(exp_avg={}, exp_avg_sq={}, update={}, ratio={}) - self.mock_fetch_mv_in_adam = MagicMock(return_value=mv_result) - self.optimizer = MegatronChainedDistributedOptimizerMon() + self.mock_fetch_mv = MagicMock(return_value=mv_result) + self.optimizer = MegatronChainedDistributedOptimizerMon(self.torch_opt) def test_fetch_mv_with_valid_optimizer(self): - self.torch_opt.model_float16_groups = [MagicMock()] - self.torch_opt.shard_fp32_from_float16_groups = [MagicMock()] - self.optimizer._fetch_mv_in_adam = self.mock_fetch_mv_in_adam + for opt in self.torch_opt.chained_optimizers: + opt.model_float16_groups = [MagicMock()] + opt.shard_fp32_from_float16_groups = [MagicMock()] + self.optimizer.fetch_mv = self.mock_fetch_mv - res = self.optimizer.fetch_mv(self.monitor, self.torch_opt, self.params2name) + res = self.optimizer.fetch_mv(self.monitor, self.params2name) self.assertIsInstance(res, MVResult) def test_fetch_mv_with_invalid_optimizer(self): - self.torch_opt = Mock() - self.torch_opt.model_float16_groups = None - self.torch_opt.shard_fp32_from_float16_groups = None - self.optimizer._fetch_mv_in_adam = self.mock_fetch_mv_in_adam + for opt in self.torch_opt.chained_optimizers: + del opt.model_float16_groups + del opt.shard_fp32_from_float16_groups with self.assertRaises(Exception): - self.optimizer.fetch_mv(self.monitor, self.torch_opt, self.params2name) + self.optimizer.fetch_mv(self.monitor, self.params2name) class TestMegatronDistributedOptimizerMon(unittest.TestCase): @@ -157,25 +153,23 @@ class TestMegatronDistributedOptimizerMon(unittest.TestCase): self.torch_opt = MagicMock() self.params2name = MagicMock() mv_result = MVResult(exp_avg={}, exp_avg_sq={}, update={}, ratio={}) - self.mock_fetch_mv_in_adam = MagicMock(return_value=mv_result) - self.optimizer = MegatronDistributedOptimizerMon() + self.mock_fetch_mv = MagicMock(return_value=mv_result) + self.optimizer = MegatronDistributedOptimizerMon(self.torch_opt) def test_fetch_mv_with_valid_optimizer(self): self.torch_opt.model_float16_groups = [MagicMock()] self.torch_opt.shard_fp32_from_float16_groups = [MagicMock()] - self.optimizer._fetch_mv_in_adam = self.mock_fetch_mv_in_adam + self.optimizer.fetch_mv = self.mock_fetch_mv - res = self.optimizer.fetch_mv(self.monitor, self.torch_opt, self.params2name) + res = self.optimizer.fetch_mv(self.monitor, self.params2name) self.assertIsInstance(res, MVResult) def test_fetch_mv_with_invalid_optimizer(self): - self.torch_opt = Mock() self.torch_opt.model_float16_groups = None self.torch_opt.shard_fp32_from_float16_groups = None - self.optimizer._fetch_mv_in_adam = self.mock_fetch_mv_in_adam with self.assertRaises(Exception): - self.optimizer.fetch_mv(self.monitor, self.torch_opt, self.params2name) + self.optimizer.fetch_mv(self.monitor, self.params2name) class TestCommonFetchMv(unittest.TestCase): @@ -184,103 +178,189 @@ class TestCommonFetchMv(unittest.TestCase): self.torch_opt = MagicMock() self.params2name = MagicMock() - def test_megatron_fp32_optimizer_mon(self): - self.optimizer = MegatronFP32OptimizerMon() - res = self.optimizer.fetch_mv(self.monitor, self.torch_opt, self.params2name) + def test_optimizer_mon(self): + self.optimizer = OptimizerMon(None) + res = self.optimizer.fetch_mv(self.monitor, self.params2name) self.assertIsInstance(res, MVResult) - def test_deepspeed_zero_optimizer_stage0_mon(self): - self.optimizer = DeepSpeedZeroOptimizerStage0Mon() - res = self.optimizer.fetch_mv(self.monitor, self.torch_opt, self.params2name) - self.assertIsInstance(res, MVResult) - def test_dummy_optimizer_mon(self): - self.optimizer = DummyOptimizerMon() - res = self.optimizer.fetch_mv(self.monitor, self.torch_opt, self.params2name) - self.assertIsInstance(res, MVResult) +class TestDeepSpeedZeroOptimizer(unittest.TestCase): + def setUp(self): + bit16_groups, param_names, param_slice_mappings, _ = setup_param_groups() + mock_opt = MagicMock() + mock_opt.state_dict.return_value = { + 'param_slice_mappings': param_slice_mappings + } + mock_opt.param_names = param_names + mock_opt.bit16_groups = bit16_groups + self.torch_opt = mock_opt + self.mock_monitor = setup_mock_monitor() + self.optimizer_mon = DeepSpeedZeroOptimizerMon(mock_opt) + self.optimizer_mon.bit16_groups = mock_opt.bit16_groups + self.optimizer_mon.param2group = self.optimizer_mon.get_group_index() + + def test_param_not_in_partition(self): + param_in_partition = list(self.torch_opt.param_names.keys())[0] + param_not_in_partition = torch.randn(2,3) + + self.assertFalse( + self.optimizer_mon.param_not_in_partition(param_in_partition, 0) + ) + self.assertTrue( + self.optimizer_mon.param_not_in_partition(param_not_in_partition, 0) + ) + + def test_get_position(self): + param_in_partition = list(self.torch_opt.param_names.keys())[0] + start, numel = self.optimizer_mon.get_position(param_in_partition, 0) + self.assertEqual(start, 0) + self.assertEqual(numel, 6) -class TestDeepSpeedZeroOptimizerStage3Mon(unittest.TestCase): - def test_get_param_index(self): - self.torch_opt = Mock() - self.torch_opt.fp16_partitioned_groups = [ - [Mock(flatten=lambda: [1, 2, 3]), - Mock(flatten=lambda: [4, 5])], - [Mock(flatten=lambda: [6, 7, 8, 9])] - ] - self.params2name = {'param1': 'weight1', 'param2': 'weight2'} - self.name2index = {'weight1': 0, 'weight2': 2} + def test_get_group_index(self): + param = list(self.torch_opt.param_names.keys())[6] + self.assertEqual(self.optimizer_mon.param2group[param], 1) - optimizer_stage3_mon = DeepSpeedZeroOptimizerStage3Mon() - name2indices = optimizer_stage3_mon.get_param_index(self.params2name, self.name2index, self.torch_opt) +class TestDeepSpeedZeroOptimizerStage0Mon(unittest.TestCase): + def setUp(self): + bit16_groups, param_names, param_slice_mappings, _ = setup_param_groups() - expected_name2indices = {'weight1': (0, 3, 0, None), 'weight2': (5, 9, 1, None)} - self.assertDictEqual(dict(name2indices), expected_name2indices) + mock_opt = MagicMock() + mock_opt.state_dict.return_value = { + 'param_slice_mappings': param_slice_mappings + } + mock_opt.param_names = param_names + mock_opt.bf16_groups = bit16_groups + mock_opt.fp32_groups_flat_partition = [torch.stack(group,dim=0).flatten().float() \ + for group in bit16_groups]# mock name 2 index in subgroup + mock_opt.state = { + flat_group: { + 'exp_avg': torch.ones_like(flat_group), + 'exp_avg_sq': torch.ones_like(flat_group) + } for flat_group in mock_opt.fp32_groups_flat_partition + } + mock_opt.cpu_offload = False + + self.torch_opt = mock_opt + self.mock_monitor = setup_mock_monitor() + self.optimizer_mon = DeepSpeedZeroOptimizerStage0Mon(mock_opt) + + def test_get_grad_for_param(self): + param = list(self.torch_opt.param_names.keys())[0] + group_idx = 0 + param_id = 2 + grad_expected = torch.randn_like(param) + self.torch_opt.fp32_groups_gradient_dict = [[0, 0, grad_expected, 0]] + grad = self.optimizer_mon.get_grad_for_param(param, group_idx, param_id) + + self.assertTrue(torch.equal(grad_expected, grad)) + + def test_fetch_grad(self): + self.torch_opt.fp32_groups_gradient_dict = [[torch.randn_like(param) for param in group] for group in self.optimizer_mon.bit16_groups] + self.mock_monitor.name2tag = {name:{MonitorConst.POST_GRAD: name} for name in self.torch_opt.param_names.values()} + result = self.optimizer_mon.fetch_grad(self.mock_monitor, self.torch_opt.param_names) + for _, name in self.torch_opt.param_names.items(): + group_index, param_id = [int(i) for i in name.replace('param','').split('_')] + self.assertTrue(torch.equal(result[name], self.torch_opt.fp32_groups_gradient_dict[group_index][param_id])) def test_fetch_mv(self): - self.monitor = MagicMock() - self.torch_opt = MagicMock() - self.params2name = MagicMock() - self.torch_opt.fp16_partitioned_groups = MagicMock() - self.optimizer = DeepSpeedZeroOptimizerStage3Mon() - - # mock _fetch_mv_grad_in_adam - mv_result = MVGradResult(exp_avg={}, exp_avg_sq={}, update={}, ratio={}, grad={}) - self.mock_fetch_mv_grad_in_adam = MagicMock(return_value=mv_result) - self.optimizer._fetch_mv_grad_in_adam = self.mock_fetch_mv_grad_in_adam - - res = self.optimizer.fetch_mv(self.monitor, self.torch_opt, self.params2name) - self.assertIsInstance(res, MVGradResult) + del self.torch_opt.chained_optimizers + del self.torch_opt.param_to_cpu_states_map + result = self.optimizer_mon.fetch_mv(self.mock_monitor, self.torch_opt.param_names) + for param, name in self.torch_opt.param_names.items(): + self.assertTrue(torch.equal(result.exp_avg[name], torch.ones_like(param).flatten())) + self.assertTrue(torch.equal(result.exp_avg_sq[name], torch.ones_like(param).flatten())) class TestDeepSpeedZeroOptimizerStage1or2Mon(unittest.TestCase): - def test_get_group_index(self): - self.fp32_length = [10, 20, 30, 40] - self.world_size = 4 - self.indexes = [5, 7, 12, 25, 35, 45] - self.expected_results = [(40, 0), (40, 0), (12, 1), (24, 2), (34, 2), (40, 0)] - - optimizer = DeepSpeedZeroOptimizerStage1or2Mon() - results = [optimizer.get_group_index(self.fp32_length, self.world_size, index) for index in self.indexes] - self.assertEqual(results, self.expected_results) + def setUp(self): + bit16_groups, param_names, param_slice_mappings, _ = setup_param_groups() - @patch('msprobe.pytorch.monitor.optimizer_collect.dist') - def test_get_param_index(self, mock_dist): - mock_dist.get_world_size.return_value = 4 + mock_opt = MagicMock() + mock_opt.state_dict.return_value = { + 'param_slice_mappings': param_slice_mappings + } + mock_opt.param_names = param_names + mock_opt.bit16_groups = bit16_groups + mock_opt.single_partition_of_fp32_groups = [torch.stack(group,dim=0).flatten().float() \ + for group in bit16_groups] + mock_opt.averaged_gradients = {group_idx: [torch.randn_like(param) for param in group] for group_idx, group in enumerate(bit16_groups)}# mock name 2 index in subgroup + mock_opt.state = { + flat_group: { + 'exp_avg': torch.ones_like(flat_group), + 'exp_avg_sq': torch.ones_like(flat_group) + } for flat_group in mock_opt.single_partition_of_fp32_groups + } + mock_opt.cpu_offload = False + + self.torch_opt = mock_opt + self.mock_monitor = setup_mock_monitor() + self.optimizer_mon = DeepSpeedZeroOptimizerStage1or2Mon(mock_opt) + + def test_get_grad_for_param(self): + param = list(self.torch_opt.param_names.keys())[0] + group_idx = 0 + param_id = 2 + grad_expected = torch.randn_like(param) + self.torch_opt.averaged_gradients = [[0, 0, grad_expected, 0]] + grad = self.optimizer_mon.get_grad_for_param(param, group_idx, param_id) + + self.assertTrue(torch.equal(grad_expected, grad)) + + def test_fetch_grad(self): + self.mock_monitor.name2tag = {name:{MonitorConst.POST_GRAD: name} for name in self.torch_opt.param_names.values()} + result = self.optimizer_mon.fetch_grad(self.mock_monitor, self.torch_opt.param_names) + for param, name in self.torch_opt.param_names.items(): + group_index, param_id = [int(i) for i in name.replace('param','').split('_')] + self.assertTrue(torch.equal(result[name], self.torch_opt.averaged_gradients[group_index][param_id])) - self.params2name = {'param1': 'weight', 'param2': 'bias'} - self.name2index = {'weight': 0, 'bias': 1} + def test_fetch_mv(self): + del self.torch_opt.chained_optimizers + del self.torch_opt.param_to_cpu_states_map + result = self.optimizer_mon.fetch_mv(self.mock_monitor, self.torch_opt.param_names) + for param, name in self.torch_opt.param_names.items(): + self.assertTrue(torch.equal(result.exp_avg[name], torch.ones_like(param).flatten())) + self.assertTrue(torch.equal(result.exp_avg_sq[name], torch.ones_like(param).flatten())) - self.optimizer_monitor = DeepSpeedZeroOptimizerStage1or2Mon() - self.torch_opt = MagicMock() - self.torch_opt.groups_padding = [1, 2, 3] - self.torch_opt.single_partition_of_fp32_groups = [torch.tensor([1, 2]), torch.tensor([3, 4, 5])] - self.torch_opt.bit16_groups = [ - [torch.tensor([6, 7]), torch.tensor([8])], - [torch.tensor([9, 10, 11])] - ] - - name2indices = self.optimizer_monitor.get_param_index(self.params2name, self.name2index, self.torch_opt) - for name, indices in name2indices.items(): - self.assertIn(name, self.params2name.values()) - self.assertIsInstance(indices, tuple) - self.assertEqual(len(indices), 4) +class TestDeepSpeedZeroOptimizerStage3Mon(unittest.TestCase): + def setUp(self): + bit16_groups, param_names, _, grad_position = setup_param_groups() + + mock_opt = MagicMock() + mock_opt.param_names = param_names + mock_opt.fp16_groups = bit16_groups + mock_opt.fp32_partitioned_groups_flat = [torch.stack(group,dim=0).flatten().float() + for group in bit16_groups] + mock_opt.averaged_gradients = {group_idx: [torch.randn_like(param) for param in group] + for group_idx, group in enumerate(bit16_groups)} + mock_opt.grad_position = grad_position + mock_opt.get_param_id = lambda x: int(param_names[x].split('_')[1]) + mock_opt.state = { + flat_group: { + 'exp_avg': torch.ones_like(flat_group), + 'exp_avg_sq': torch.ones_like(flat_group) + } for flat_group in mock_opt.fp32_partitioned_groups_flat + } + + self.torch_opt = mock_opt + self.optimizer_mon = DeepSpeedZeroOptimizerStage3Mon(mock_opt) + self.mock_monitor = setup_mock_monitor() + + def test_fetch_grad(self): + self.mock_monitor.name2tag = {name:{MonitorConst.POST_GRAD: name} for name in self.torch_opt.param_names.values()} + result = self.optimizer_mon.fetch_grad(self.mock_monitor, self.torch_opt.param_names) + for param, name in self.torch_opt.param_names.items(): + group_index, param_id = [int(i) for i in name.replace('param','').split('_')] + self.assertTrue(torch.equal(result[name], self.torch_opt.averaged_gradients[group_index][param_id])) def test_fetch_mv(self): - self.monitor = MagicMock() - self.torch_opt = MagicMock() - self.params2name = MagicMock() - self.torch_opt.fp16_partitioned_groups = MagicMock() - self.optimizer = DeepSpeedZeroOptimizerStage1or2Mon() - - # mock _fetch_mv_grad_in_adam - mv_result = MVGradResult(exp_avg={}, exp_avg_sq={}, update={}, ratio={}, grad={}) - self.mock_fetch_mv_grad_in_adam = MagicMock(return_value=mv_result) - self.optimizer._fetch_mv_grad_in_adam = self.mock_fetch_mv_grad_in_adam - - res = self.optimizer.fetch_mv(self.monitor, self.torch_opt, self.params2name) - self.assertIsInstance(res, MVGradResult) + del self.torch_opt.chained_optimizers + del self.torch_opt.param_to_cpu_states_map + result = self.optimizer_mon.fetch_mv(self.mock_monitor, self.torch_opt.param_names) + for param, name in self.torch_opt.param_names.items(): + self.assertTrue(torch.equal(result.exp_avg[name], torch.ones_like(param).flatten())) + self.assertTrue(torch.equal(result.exp_avg_sq[name], torch.ones_like(param).flatten())) class TestOptimizerMonFactory(unittest.TestCase): @@ -291,48 +371,48 @@ class TestOptimizerMonFactory(unittest.TestCase): mix_optimizer_class = MagicMock() mix_optimizer_class.__name__ = "Float16OptimizerWithFloat16Params" mix_optimizer.__class__ = mix_optimizer_class - self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(mix_optimizer)[0], - MixPrecisionOptimizerMon) + self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(mix_optimizer), + MegatronMixPrecisionOptimizerMon) dis_optimizer = MagicMock() dis_optimizer_class = MagicMock() dis_optimizer_class.__name__ = "DistributedOptimizer" dis_optimizer.__class__ = dis_optimizer_class - self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(dis_optimizer)[0], + self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(dis_optimizer), MegatronDistributedOptimizerMon) fp32_optimizer = MagicMock() fp32_optimizer_class = MagicMock() fp32_optimizer_class.__name__ = "FP32Optimizer" fp32_optimizer.__class__ = fp32_optimizer_class - self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(fp32_optimizer)[0], - MegatronFP32OptimizerMon) + self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(fp32_optimizer), + OptimizerMon) chained_optimizer = MagicMock() chained_optimizer_class = MagicMock() chained_optimizer_class.__name__ = "ChainedOptimizer" chained_optimizer.__class__ = chained_optimizer_class chained_optimizer.chained_optimizers = [mix_optimizer, mix_optimizer] - self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(chained_optimizer)[0], + self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(chained_optimizer), MegatronChainedMixPrecisionOptimizerMon) chained_optimizer.chained_optimizers = [dis_optimizer, dis_optimizer] - self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(chained_optimizer)[0], + self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(chained_optimizer), MegatronChainedDistributedOptimizerMon) deepspeed_optimizer = MagicMock() deepspeed_optimizer_class = MagicMock() deepspeed_optimizer_class.__name__ = "BF16_Optimizer" deepspeed_optimizer.__class__ = deepspeed_optimizer_class - self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(deepspeed_optimizer)[0], + self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(deepspeed_optimizer), DeepSpeedZeroOptimizerStage0Mon) deepspeed_optimizer_class.__name__ = "DeepSpeedZeroOptimizer" - self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(deepspeed_optimizer)[0], + self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(deepspeed_optimizer), DeepSpeedZeroOptimizerStage1or2Mon) deepspeed_optimizer_class.__name__ = "DeepSpeedZeroOptimizer_Stage3" - self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(deepspeed_optimizer)[0], + self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(deepspeed_optimizer), DeepSpeedZeroOptimizerStage3Mon) - # 测试未知的优化器类型,应该返回DummyOptimizerMon + # 测试未知的优化器类型,应该返回OptimizerMon unknown_optimizer = MagicMock() unknown_optimizer_class = MagicMock() unknown_optimizer_class.__name__ = "unknown" unknown_optimizer.__class__ = unknown_optimizer_class - self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(unknown_optimizer)[0], DummyOptimizerMon) + self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(unknown_optimizer), OptimizerMon) if __name__ == '__main__': diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/parse_tool/test_interactive_cli.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/parse_tool/test_interactive_cli.py index b875bd7e8e17f6b869b2a1b1498982b2a17e1258..3a09d41588a94043f54161023ffbba573c60d76c 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/parse_tool/test_interactive_cli.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/parse_tool/test_interactive_cli.py @@ -26,63 +26,50 @@ class TestInteractiveCli(unittest.TestCase): @patch('msprobe.pytorch.parse_tool.lib.interactive_cli.ParseTool.prepare', return_value=None) def test_prepare(self, mock_prepare): self.interactive_cli.prepare() - mock_prepare.assert_called_once() - @patch('msprobe.pytorch.parse_tool.lib.interactive_cli.Util.execute_command', return_value=None) - def test_default(self, mock_execute_command): - res = self.interactive_cli.default() - - mock_execute_command.assert_called_once() - self.assertFalse(res) - - @patch('msprobe.pytorch.parse_tool.lib.interactive_cli.Util.execute_command', return_value=None) - def test_do_run(self, mock_execute_command): - self.interactive_cli.do_run() - - mock_execute_command.assert_called_once() + def test_default(self, command='rm'): + res = self.interactive_cli.default(command) + self.assertIsNone(res) @patch('msprobe.pytorch.parse_tool.lib.interactive_cli.ParseTool.do_compare_converted_dir') @patch('msprobe.pytorch.parse_tool.lib.interactive_cli.ParseTool.do_vector_compare') def test_do_vc(self, mock_do_vector_compare, mock_do_compare_converted_dir): - with patch('msprobe.pytorch.parse_tool.lib.interactive_cli.Util.check_path_valid'), \ - patch('msprobe.pytorch.parse_tool.lib.interactive_cli.Util.check_files_in_path'): - with patch('msprobe.pytorch.parse_tool.lib.interactive_cli.Util.dir_contains_only', return_value=False): - self.interactive_cli.do_vc('-m my_dump_path -g golden_dump_path -out output_path -cmp_path msaccucmp_path') - + with (patch('msprobe.pytorch.parse_tool.lib.interactive_cli.Util.check_path_valid'), + patch('msprobe.pytorch.parse_tool.lib.interactive_cli.Util.check_files_in_path')): + with patch('msprobe.pytorch.parse_tool.lib.interactive_cli.Util.dir_contains_only', + return_value=False): + self.interactive_cli.do_vc( + '-m my_dump_path -g golden_dump_path -out output_path -cmp_path msaccucmp_path') mock_do_vector_compare.assert_called_once() - with patch('msprobe.pytorch.parse_tool.lib.interactive_cli.Util.dir_contains_only', return_value=True): - self.interactive_cli.do_vc('-m my_dump_path -g golden_dump_path -out output_path -cmp_path msaccucmp_path') - + with patch('msprobe.pytorch.parse_tool.lib.interactive_cli.Util.dir_contains_only', + return_value=True): + self.interactive_cli.do_vc( + '-m my_dump_path -g golden_dump_path -out output_path -cmp_path msaccucmp_path') mock_do_compare_converted_dir.assert_called_once() @patch('msprobe.pytorch.parse_tool.lib.interactive_cli.ParseTool.do_convert_dump', return_value=None) def test_do_dc(self, mock_do_convert_dump): self.interactive_cli.do_dc('-n file_name/file_path -f format -out output_path') - mock_do_convert_dump.assert_called_once() @patch('msprobe.pytorch.parse_tool.lib.interactive_cli.ParseTool.do_print_data', return_value=None) def test_do_pt(self, mock_do_print_data): self.interactive_cli.do_pt('-n file_path') - mock_do_print_data.assert_called_once() @patch('msprobe.pytorch.parse_tool.lib.interactive_cli.ParseTool.do_parse_pkl', return_value=None) def test_do_pk(self, mock_do_parse_pkl): self.interactive_cli.do_pk('-f pkl_path -n api_name') - mock_do_parse_pkl.assert_called_once() @patch('msprobe.pytorch.parse_tool.lib.interactive_cli.ParseTool.do_compare_data', return_value=None) def test_do_cn(self, mock_do_comapre_data): self.interactive_cli.do_cn('-m my_data*.npy -g golden*.npu -p num -al atol -rl rtol') - mock_do_comapre_data.assert_called_once() @patch('msprobe.pytorch.parse_tool.lib.interactive_cli.ParseTool.do_convert_api_dir', return_value=None) def test_do_cad(self, mock_do_convert_api_dir): self.interactive_cli.do_cad('-m my_dump_path -out output_path -asc msaccucmp_path') - mock_do_convert_api_dir.assert_called_once() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/parse_tool/test_parse_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/parse_tool/test_parse_utils.py index dfec4d20366c6e834939130009dc6d33d1cbe9ed..c148f84d0d20213631e9be039521a14d970849e9 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/parse_tool/test_parse_utils.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/parse_tool/test_parse_utils.py @@ -88,7 +88,7 @@ class TestUtils(unittest.TestCase): obj = np.array([1, 2, 3, 4, 5]) res = self.util.get_md5_for_numpy(obj) - self.assertEqual(res, '3cd8e13ca72251bfd8c08e209abcf46f') + self.assertEqual(res, 'baa24928') def test_deal_with_dir_or_file_inconsistency(self): with self.assertRaises(ParseException): diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_config.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_config.py index c1b8bac47fda100636b55fbc5ad452c2843e8aaa..2712281bef7cafe1e4e2ad82def5eb13c7716f9b 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_config.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_config.py @@ -4,7 +4,7 @@ import unittest from unittest.mock import patch from msprobe.core.common.const import Const -from msprobe.pytorch.pt_config import parse_json_config, parse_task_config, TensorConfig, \ +from msprobe.pytorch.pt_config import parse_json_config, parse_task_config, \ StatisticsConfig, OverflowCheckConfig, FreeBenchmarkCheckConfig, RunUTConfig, GradToolConfig @@ -82,81 +82,6 @@ class TestPtConfig(unittest.TestCase): self.assertEqual(result.error_data_path, '/home/dump_path') -class TestTensorConfig(unittest.TestCase): - - def setUp(self): - self.json_config = { - "online_run_ut": False, - "host": "127.0.0.1", - "port": 8080 - } - self.config = TensorConfig(self.json_config) - - def test_check_file_format_valid(self): - self.config.file_format = "npy" - self.config._check_file_format() - - self.config.file_format = "bin" - self.config._check_file_format() - - def test_check_file_format_invalid(self): - self.config.file_format = "invalid_format" - with self.assertRaises(Exception) as context: - self.config._check_file_format() - self.assertIn(str(context.exception), "file_format is invalid") - - @patch('msprobe.pytorch.pt_config.check_crt_valid') - def test_check_online_run_ut(self, mock_check_crt_valid): - mock_check_crt_valid.return_value = True - - self.config.online_run_ut = "True" - with self.assertRaises(Exception) as context: - self.config._check_online_run_ut() - self.assertIn(str(context.exception), f"online_run_ut: {self.config.online_run_ut} is invalid.") - self.config.online_run_ut = True - - self.config.online_run_ut_recompute = "True" - with self.assertRaises(Exception) as context: - self.config._check_online_run_ut() - self.assertIn(str(context.exception), f"online_run_ut_recompute: {self.config.online_run_ut} is invalid.") - self.config.online_run_ut_recompute = False - - self.config.nfs_path = "./nfs_path" - with self.assertRaises(Exception) as context: - self.config._check_online_run_ut() - self.assertIn(str(context.exception), "[msprobe] 非法文件路径: ") - self.config.nfs_path = "" - - self.config.tls_path = "./tls_path" - with self.assertRaises(Exception) as context: - self.config._check_online_run_ut() - self.assertIn(str(context.exception), "[msprobe] 非法文件路径: ") - - os.makedirs(self.config.tls_path) - with open(os.path.join(self.config.tls_path, "client.key"), 'w') as file: - file.write("1") - with open(os.path.join(self.config.tls_path, "client.crt"), 'w') as file: - file.write("1") - self.config._check_online_run_ut() - shutil.rmtree(self.config.tls_path) - self.config.tls_path = "" - - self.config.host = "invalid_host" - with self.assertRaises(Exception) as context: - self.config._check_online_run_ut() - self.assertIn(str(context.exception), f"host: {self.config.host} is invalid.") - self.config.host = "127.0.0.1" - - self.config.port = -1 - with self.assertRaises(Exception) as context: - self.config._check_online_run_ut() - self.assertIn(str(context.exception), f"port: {self.config.port} is invalid, port range 0-65535.") - self.config.port = 6123 - - # all config right - self.config._check_online_run_ut() - - class TestStatisticsConfig(unittest.TestCase): def setUp(self): @@ -181,7 +106,7 @@ class TestStatisticsConfig(unittest.TestCase): self.config.summary_mode = "invalid_mode" with self.assertRaises(Exception) as context: self.config._check_summary_mode() - self.assertIn(str(context.exception), "summary_mode is invalid") + self.assertIn(str(context.exception), "[msprobe] 无效参数:") def test_check_summary_mode_none(self): self.config.summary_mode = None @@ -261,14 +186,14 @@ class TestFreeBenchmarkCheckConfig(unittest.TestCase): config = FreeBenchmarkCheckConfig(invalid_config) mock_error.assert_called_once() self.assertIn("fuzz_device is invalid", str(mock_error.call_args)) - + @patch('msprobe.core.common.log.logger.error_log_with_exp') def test_check_fuzz_device_cpu_mode_invalid(self, mock_error): invalid_config = self.valid_config.copy() invalid_config["fuzz_device"] = "cpu" invalid_config["pert_mode"] = "INVALID_CPU_MODE" config = FreeBenchmarkCheckConfig(invalid_config) - self.assertIn("You neet to and can only set fuzz_device as ", str(mock_error.call_args)) + self.assertIn("You need to and can only set fuzz_device as ", str(mock_error.call_args)) @patch('msprobe.core.common.log.logger.error_log_with_exp') def test_check_handler_type_invalid(self, mock_error): @@ -277,7 +202,7 @@ class TestFreeBenchmarkCheckConfig(unittest.TestCase): config = FreeBenchmarkCheckConfig(invalid_config) mock_error.assert_called_once() self.assertIn("handler_type is invalid", str(mock_error.call_args)) - + @patch('msprobe.core.common.log.logger.error_log_with_exp') def test_check_fuzz_stage_invalid(self, mock_error): invalid_config = self.valid_config.copy() @@ -319,7 +244,7 @@ class TestFreeBenchmarkCheckConfig(unittest.TestCase): config = FreeBenchmarkCheckConfig(invalid_config) mock_error.assert_called_once() self.assertIn("preheat_step must be greater than 0", str(mock_error.call_args)) - + @patch('msprobe.core.common.log.logger.error_log_with_exp') def test_check_preheat_max_sample_not_int(self, mock_error): invalid_config = self.valid_config.copy() @@ -328,7 +253,7 @@ class TestFreeBenchmarkCheckConfig(unittest.TestCase): config = FreeBenchmarkCheckConfig(invalid_config) mock_error.assert_called_once() self.assertIn("max_sample is invalid, it should be an integer", str(mock_error.call_args)) - + @patch('msprobe.core.common.log.logger.error_log_with_exp') def test_check_max_sample_invalid_not_great_than_zero(self, mock_error): invalid_config = self.valid_config.copy() @@ -363,60 +288,6 @@ class TestFreeBenchmarkCheckConfig(unittest.TestCase): self.assertIn("The pert_mode when opening fix handler must be one of", str(mock_error.call_args)) -class TestRunUTConfig(unittest.TestCase): - - @patch('msprobe.pytorch.hook_module.utils.get_ops', return_value=['relu', 'gelu', 'conv2d']) - def setUp(self, mock_get_ops): - self.config = RunUTConfig({ - "white_list": ["relu"], - "black_list": ["gelu"] - }) - - def test_check_filter_list_config_invalid_type(self): - with self.assertRaises(Exception) as context: - RunUTConfig.check_filter_list_config(Const.WHITE_LIST, "not_a_list") - self.assertIn("must be a list type", str(context.exception)) - - def test_check_filter_list_element_config_invalid_type(self): - with self.assertRaises(Exception) as context: - RunUTConfig.check_filter_list_config("white_list", [1, 1]) - self.assertIn("All elements in ", str(context.exception)) - - def test_check_filter_list_config_invalid_item(self): - with self.assertRaises(Exception) as context: - RunUTConfig.check_filter_list_config("white_list", ["api1"]) - self.assertIn("Invalid api in white_list:", str(context.exception)) - - @patch('os.path.exists', return_value=False) - def test_check_error_data_path_config_not_exist(self, mock_exists): - with self.assertRaises(Exception) as context: - RunUTConfig.check_error_data_path_config("./invalid_path") - self.assertIn("does not exist", str(context.exception)) - - @patch('os.path.exists', return_value=False) - def test_check_nfs_path_config_not_exist(self, mock_exists): - with self.assertRaises(Exception) as context: - RunUTConfig.check_nfs_path_config("./invalid_nfs") - self.assertIn("does not exist", str(context.exception)) - - @patch('os.path.exists', return_value=False) - def test_check_tls_path_config_not_exist(self, mock_exists): - with self.assertRaises(Exception) as context: - RunUTConfig.check_tls_path_config("./invalid_tls") - self.assertIn("does not exist", str(context.exception)) - - def test_check_run_ut_config(self): - with patch.object(RunUTConfig, 'check_filter_list_config') as mock_filter, \ - patch.object(RunUTConfig, 'check_error_data_path_config') as mock_error, \ - patch.object(RunUTConfig, 'check_nfs_path_config') as mock_nfs, \ - patch.object(RunUTConfig, 'check_tls_path_config') as mock_tls: - self.config.check_run_ut_config() - mock_filter.assert_called() - mock_error.assert_called() - mock_nfs.assert_called() - mock_tls.assert_called() - - class TestGradToolConfig(unittest.TestCase): def setUp(self): self.level_adp = {"L1": None, "L2": None} @@ -442,3 +313,7 @@ class TestGradToolConfig(unittest.TestCase): with self.assertRaises(Exception) as context: GradToolConfig(json_config) self.assertTrue("param_list must be a list" in str(context.exception)) + + +if __name__ == '__main__': + unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_debug_save.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_debug_save.py index 534437260e66d9e586d69d557d30e308a9f4f3ee..e517e1cefe4987b62aa2040f1a4e9db0b8dfbe98 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_debug_save.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_debug_save.py @@ -18,6 +18,7 @@ import torch from msprobe.pytorch import PrecisionDebugger from msprobe.core.common_config import CommonConfig, BaseConfig +from msprobe.core.debugger.precision_debugger import BasePrecisionDebugger class TestPytorchDebuggerSave(TestCase): @@ -36,13 +37,14 @@ class TestPytorchDebuggerSave(TestCase): } common_config = CommonConfig(statistics_task_json) task_config = BaseConfig(statistics_task_json) - with patch("msprobe.pytorch.debugger.precision_debugger.parse_json_config", return_value=(common_config, task_config)): + with patch.object(BasePrecisionDebugger, "_parse_config_path", return_value=(common_config, task_config)): self.debugger = PrecisionDebugger() def test_forward_and_backward(self): def forward_func(x, y): PrecisionDebugger.save(x, "x_tensor") return x * y + x = torch.tensor([1.]) y = torch.tensor([2.]) x.requires_grad = True @@ -53,28 +55,28 @@ class TestPytorchDebuggerSave(TestCase): "framework": "pytorch", "dump_data_dir": None, "data": { - "x_tensor.0": { + "x_tensor.0.debug": { "type": "torch.Tensor", "dtype": "torch.float32", "shape": torch.Size([1]), - "Max": 1.0, - "Min": 1.0, - "Mean": 1.0, - "Norm": 1.0, "requires_grad": True }, - "x_tensor_grad.0": { + "x_tensor_grad.0.debug": { "type": "torch.Tensor", "dtype": "torch.float32", "shape": torch.Size([1]), - "Max": 2.0, - "Min": 2.0, - "Mean": 2.0, - "Norm": 2.0, "requires_grad": False } } } + loss = forward_func(x, y) loss.backward() - self.assertEqual(self.debugger.service.data_collector.data_writer.cache_debug, result_json) \ No newline at end of file + + result = self.debugger.service.data_collector.data_writer.cache_debug + # Remove 'tensor_stat_index' from all entries in the data dictionary + for key in result["data"]: + if 'tensor_stat_index' in result["data"][key]: + del result["data"][key]['tensor_stat_index'] + + self.assertEqual(result, result_json) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_service.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_service.py new file mode 100644 index 0000000000000000000000000000000000000000..d1419ab12caf47bd10c7aabb7991ffcc694b8f5d --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_service.py @@ -0,0 +1,103 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import MagicMock, patch +from msprobe.pytorch.pytorch_service import PytorchService +from msprobe.core.common.utils import Const +from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser +from msprobe.pytorch.hook_module.hook_module import HOOKModule + + +class TestPytorchService(unittest.TestCase): + def setUp(self): + self.config = MagicMock() + self.config.step = [] + self.config.rank = [] + self.config.level = Const.LEVEL_MIX + self.config.task = Const.STATISTICS + + with patch('msprobe.core.service.build_data_collector'): + self.service = PytorchService(self.config) + + self.service.logger = MagicMock() + self.service.data_collector = MagicMock() + self.service.module_processor = MagicMock() + self.service.api_register = MagicMock() + + def test_framework_type(self): + self.assertEqual(self.service._get_framework_type, Const.PT_FRAMEWORK) + + @patch('msprobe.pytorch.pytorch_service.get_rank_if_initialized') + def test_get_current_rank(self, mock_get_rank): + mock_get_rank.return_value = 5 + self.assertEqual(self.service._get_current_rank(), 5) + + def test_init_specific_components(self): + with patch('msprobe.core.service.build_data_collector'): + service = PytorchService(self.config) + + self.assertIsNotNone(service.logger) + self.assertIsNotNone(service.api_register) + self.assertIsNotNone(service.module_processor) + self.assertIsNotNone(service.hook_manager) + + def test_register_hook(self): + self.service._register_hook() + + @patch('msprobe.pytorch.pytorch_service.register_optimizer_hook') + def test_register_hook_mix_level(self, mock_register_opt): + self.service.config.level = Const.LEVEL_MIX + self.service._register_hook() + mock_register_opt.assert_called_once_with(self.service.data_collector) + + @patch('msprobe.pytorch.pytorch_service.register_optimizer_hook') + def test_register_hook_not_mix_level(self, mock_register_opt): + self.service.config.level = Const.LEVEL_L1 + self.service._register_hook() + mock_register_opt.assert_not_called() + + @patch('msprobe.pytorch.pytorch_service.wrap_script_func') + def test_register_api_hook(self, mock_wrap_jit): + self.service.config.level = Const.LEVEL_L1 + self.service._register_api_hook() + mock_wrap_jit.assert_called_once() + self.service.api_register.initialize_hook.assert_called_once() + + def test_register_module_hook(self): + model_mock = MagicMock() + self.service.model = model_mock + self.service._register_module_hook() + + self.service.module_processor.register_module_hook.assert_called_once_with( + model_mock, self.service.build_hook + ) + + self.assertTrue(self.service.module_processor.enable_module_dump) + + + @patch.object(HOOKModule, 'reset_module_stats') + @patch.object(ModuleProcesser, 'reset_module_stats') + def test_reset_status(self, mock_reset_module_processor, mock_reset_hook_module): + self.service._reset_status() + mock_reset_hook_module.assert_called_once() + mock_reset_module_processor.assert_called_once() + self.service.data_collector.reset_status.assert_called_once() + + + def test_register_module_hook(self): + self.service.model = MagicMock() + self.service._register_module_hook() + self.service.module_processor.register_module_hook.assert_called_once() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_service.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_service.py deleted file mode 100644 index 6687f3111050ea53e14e62f3afd55ae1eff2b8c0..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_service.py +++ /dev/null @@ -1,150 +0,0 @@ -import unittest -from unittest.mock import patch, mock_open, MagicMock - -from msprobe.core.common.utils import Const -from msprobe.pytorch.debugger.debugger_config import DebuggerConfig -from msprobe.pytorch.pt_config import parse_json_config -from msprobe.pytorch.service import Service - - -class TestService(unittest.TestCase): - def setUp(self): - mock_json_data = { - "dump_path": "./dump/", - } - with patch("msprobe.pytorch.pt_config.FileOpen", mock_open(read_data='')), \ - patch("msprobe.pytorch.pt_config.load_json", return_value=mock_json_data): - common_config, task_config = parse_json_config("./config.json", Const.STATISTICS) - self.config = DebuggerConfig(common_config, task_config, Const.STATISTICS, "./ut_dump", "L1") - self.service = Service(self.config) - - def test_start_success(self): - with patch("msprobe.pytorch.service.get_rank_if_initialized", return_value=0), \ - patch("msprobe.pytorch.service.Service.create_dirs", return_value=None): - self.service.start(None) - self.assertEqual(self.service.current_rank, 0) - - def test_start_fail(self): - self.service.config.rank = [1, 2] - self.service.current_rank = 3 - self.assertIsNone(self.service.start(None)) - - self.service.config.step = [1, 2] - self.service.current_iter = 3 - self.assertIsNone(self.service.start(None)) - - @patch("msprobe.core.data_dump.data_collector.DataCollector.write_json") - def test_stop_success(self, mock_write_json): - mock_write_json.return_value = None - self.service.stop() - - self.assertFalse(self.service.switch) - - def test_stop_fail(self): - self.service.switch = True - - self.service.config.rank = [1, 2] - self.service.current_rank = 3 - res = self.service.stop() - self.assertIsNone(res) - self.assertTrue(self.service.switch) - - self.service.config.step = [1, 2] - self.service.current_iter = 3 - res = self.service.stop() - self.assertIsNone(res) - self.assertTrue(self.service.switch) - - self.service.config.level = "L2" - res = self.service.stop() - self.assertIsNone(res) - self.assertTrue(self.service.switch) - - self.service.should_stop_service = True - res = self.service.stop() - self.assertIsNone(res) - self.assertTrue(self.service.switch) - - def test_step_success(self): - self.service.step() - self.assertEqual(self.service.current_iter, 1) - - def test_step_fail(self): - self.service.should_stop_service = True - self.assertIsNone(self.service.step()) - - def test_register_module_hook_with_level0(self): - self.service.model = MagicMock() - self.service.build_hook = MagicMock() - self.config.level = "L0" - with patch("msprobe.pytorch.service.logger.info_on_rank_0") as mock_logger, \ - patch("msprobe.pytorch.service.ModuleProcesser.register_module_hook") as mock_register_module_hook: - self.service.register_module_hook() - self.assertEqual(mock_logger.call_count, 1) - mock_register_module_hook.assert_called_once() - - def test_register_api_hook_with_level1(self): - self.service.build_hook = MagicMock() - self.config.level = "L1" - with patch("msprobe.pytorch.service.logger.info_on_rank_0") as mock_logger, \ - patch("msprobe.pytorch.service.api_register.initialize_hook") as mock_init_hook, \ - patch("msprobe.pytorch.service.api_register.api_modularity") as mock_api_modularity: - self.service.register_api_hook() - self.assertEqual(mock_logger.call_count, 1) - mock_init_hook.assert_called_once() - mock_api_modularity.assert_called_once() - - def test_create_dirs(self): - with patch("msprobe.pytorch.service.create_directory"), \ - patch("msprobe.core.data_dump.data_collector.DataCollector.update_dump_paths"), \ - patch("msprobe.core.data_dump.data_collector.DataCollector.initialize_json_file"): - self.service.create_dirs() - self.assertEqual(self.service.dump_iter_dir, "./ut_dump/step0") - - def test_need_end_service(self): - self.service.should_stop_service = True - self.assertTrue(self.service.need_stop_service()) - - self.service.should_stop_service = False - self.service.config.step = [1, 3] - self.service.current_iter = 1 - self.assertFalse(self.service.need_stop_service()) - - self.service.current_iter = 2 - self.assertTrue(self.service.need_stop_service()) - - self.service.current_iter = 4 - self.service.config.level = "L0" - self.service.config.online_run_ut = False - self.assertTrue(self.service.need_stop_service()) - self.assertFalse(self.service.switch) - self.assertTrue(self.service.should_stop_service) - - def test_should_execute_hook_return_false(self): - module = MagicMock() - self.service.switch = False - self.assertFalse(self.service.should_execute_hook("Module", module, True)) - self.assertFalse(self.service.should_execute_hook("api", module, True)) - - self.service.switch = True - module.forward_data_collected = False - self.assertFalse(self.service.should_execute_hook("api", module, False)) - - self.service.inner_switch = True - self.assertFalse(self.service.should_execute_hook("Module", module, True)) - - self.service.inner_switch = False - self.service.data_collector = None - self.assertFalse(self.service.should_execute_hook("Module", module, True)) - - def test_should_execute_hook_return_true(self): - module = MagicMock() - self.service.switch = True - self.service.inner_switch = False - self.service.data_collector = MagicMock() - self.service.data_collector.data_processor = MagicMock() - self.service.data_collector.data_processor.is_terminated = False - self.assertTrue(self.service.should_execute_hook("Module", module, True)) - - module.forward_data_collected = True - self.assertTrue(self.service.should_execute_hook("api", module, False)) diff --git a/debug/accuracy_tools/msprobe/test/resources/layer_mapping/mindspore/dump.json b/debug/accuracy_tools/msprobe/test/resources/layer_mapping/mindspore/dump.json index b55f9e0699fe6329ceeb09a51fe20118c65545e7..153d84e7d117b5be89dfdb522edc39dc066929cb 100644 --- a/debug/accuracy_tools/msprobe/test/resources/layer_mapping/mindspore/dump.json +++ b/debug/accuracy_tools/msprobe/test/resources/layer_mapping/mindspore/dump.json @@ -1,6 +1,7 @@ { "task": "statistics", "level": "mix", + "framework": "mindspore", "dump_data_dir": null, "data": { "Cell.network_with_loss.module.language_model.embedding.word_embeddings.VocabParallelEmbedding.forward.0": { diff --git a/debug/accuracy_tools/msprobe/test/resources/layer_mapping/pytorch/dump.json b/debug/accuracy_tools/msprobe/test/resources/layer_mapping/pytorch/dump.json index d7dd1c0c38e2d24c8b0d19c346a50eb33437d232..02239176a9d690c4ce70c06cc6ab117a3c122811 100644 --- a/debug/accuracy_tools/msprobe/test/resources/layer_mapping/pytorch/dump.json +++ b/debug/accuracy_tools/msprobe/test/resources/layer_mapping/pytorch/dump.json @@ -1,6 +1,7 @@ { "task": "statistics", "level": "mix", + "framework": "pytorch", "dump_data_dir": null, "data": { "Module.module.module.language_model.embedding.word_embeddings.VocabParallelEmbedding.forward.0": { diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_graph_builder.py b/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_graph_builder.py index 706dc8bf82e59f413c3fd559a39af89c6a70be47..fef328ce41a2287be611131d3e21f86ce8f8ef0b 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_graph_builder.py +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_graph_builder.py @@ -32,11 +32,6 @@ class TestGraphBuilder(unittest.TestCase): self.assertIsInstance(graph, Graph) self.assertEqual(len(graph.node_map), 3) - @patch('msprobe.visualization.builder.graph_builder.save_json_file') - def test_to_json(self, mock_save_json_file): - GraphBuilder.to_json("step/rank/output.vis", self.config) - mock_save_json_file.assert_called_once() - @patch('msprobe.visualization.graph.node_op.NodeOp.get_node_op') @patch('msprobe.visualization.builder.msprobe_adapter.get_input_output', return_value=([], [])) def test__init_nodes(self, mock_get_input_output, mock_get_node_op): @@ -111,3 +106,23 @@ class TestGraphBuilder(unittest.TestCase): self.assertEqual(graph.root.subnodes[2].op, NodeOp.module) self.assertEqual(len(graph.root.subnodes[0].subnodes), 0) self.assertEqual(graph.root.subnodes[0].id, 'Module.a.0') + + def test_add_parameters_grad(self): + graph = Graph('TestNet') + graph.add_node(NodeOp.module, 'Module.a.backward.0', graph.root) + graph.add_node(NodeOp.module, 'Module.b.backward.0', graph.root) + graph.add_node(NodeOp.module, 'Module.a.backward.1', graph.root) + graph.add_node(NodeOp.module, 'Module.aa.backward.0', graph.get_node('Module.a.backward.0')) + graph.add_node(NodeOp.module, 'Module.aaa.backward.0', graph.get_node('Module.a.backward.0')) + graph.add_node(NodeOp.module, 'Module.aa.backward.1', graph.get_node('Module.a.backward.1')) + graph.add_node(NodeOp.module, 'Module.aaa.backward.1', graph.get_node('Module.a.backward.1')) + + data_dict = {'Module.a.parameters_grad': {}, 'Module.aaa.parameters_grad': {}} + GraphBuilder._add_parameters_grad(graph, data_dict) + root_nodes_id = [node.id for node in graph.get_node('TestNet').subnodes] + sub_nodes_id0 = [node.id for node in graph.get_node('Module.a.backward.0').subnodes] + sub_nodes_id1 = [node.id for node in graph.get_node('Module.a.backward.1').subnodes] + + self.assertEqual(root_nodes_id[-1], 'Module.a.backward.1') + self.assertEqual(sub_nodes_id0[-1], 'Module.aaa.backward.0') + self.assertEqual(sub_nodes_id1[-1], 'Module.a.parameters_grad') diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_graph_merger.py b/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_graph_merger.py new file mode 100644 index 0000000000000000000000000000000000000000..36f551edc8df75b5d1d6c8c0ddf40e4eff0084fc --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_graph_merger.py @@ -0,0 +1,409 @@ +import unittest +from unittest.mock import patch, MagicMock, call +from msprobe.visualization.builder.graph_merger import ( + GraphMerger, BaseGraphMerger, PPMerger, TPMerger, + NoParallelMerger, TPPPMerger, FullMerger +) +from msprobe.core.common.const import Const +from msprobe.visualization.utils import GraphConst, ParallelParam +from msprobe.visualization.graph.node_op import NodeOp +from msprobe.visualization.graph.graph import Graph +from msprobe.core.common.exceptions import MsprobeException + + +class TestGraphMerger(unittest.TestCase): + def setUp(self): + self.build_graph_results = MagicMock() + self.parallel_param = ParallelParam(tp=1, pp=1, rank_size=1) + self.is_bench = False + + def test_select_strategy_no_parallel(self): + self.parallel_param.tp = self.parallel_param.pp = self.parallel_param.rank_size = 1 + merger = GraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + self.assertIsInstance(merger.strategy, NoParallelMerger) + + def test_select_strategy_tp(self): + self.parallel_param.tp = self.parallel_param.rank_size = 2 + self.parallel_param.pp = 1 + merger = GraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + self.assertIsInstance(merger.strategy, TPMerger) + + def test_select_strategy_pp(self): + self.parallel_param.pp = self.parallel_param.rank_size = 2 + self.parallel_param.tp = 1 + merger = GraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + self.assertIsInstance(merger.strategy, PPMerger) + + def test_select_strategy_tp_pp(self): + self.parallel_param.tp = self.parallel_param.pp = 2 + self.parallel_param.rank_size = 4 + merger = GraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + self.assertIsInstance(merger.strategy, TPPPMerger) + + def test_select_strategy_full(self): + self.parallel_param.tp = 2 + self.parallel_param.pp = 2 + self.parallel_param.rank_size = 8 + merger = GraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + self.assertIsInstance(merger.strategy, FullMerger) + + def test_merge_graph(self): + merger = GraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + merger.strategy.merge_graphs = MagicMock() + merger.merge_graph() + merger.strategy.merge_graphs.assert_called_once() + + +class TestBaseGraphMerger(unittest.TestCase): + def setUp(self): + self.build_graph_results = [MagicMock(rank=i) for i in range(2)] + self.parallel_param = ParallelParam(tp=1, pp=1, rank_size=2) + self.is_bench = False + self.merger = BaseGraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + + def test_sort_merged_api_collection(self): + graph = MagicMock() + root = MagicMock() + graph.root = root + subnode1 = MagicMock(id=f"{GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS}.0", op=NodeOp.api_collection) + subnode1.subnodes = [MagicMock(id="op_Rank1.0"), MagicMock(id="op_Rank0.0")] + root.subnodes = [subnode1] + self.merger.sort_merged_api_collection(graph) + self.assertEqual([n.id for n in subnode1.subnodes], ["op_Rank0.0", "op_Rank1.0"]) + + def test_update_node_data_key(self): + data_dict = { + "old_id.input.0": {"full_op_name": "old_id.op"}, + "other_key": {"value": "test"} + } + new_dict = self.merger._update_node_data_key("old_id", "new_id", data_dict) + self.assertEqual(new_dict, { + "new_id.input.0": {"full_op_name": "new_id.op"}, + "other_key": {"value": "test"} + }) + + def test_compare_value_same(self): + self.assertTrue(self.merger._compare_value_same(1, 1)) + self.assertFalse(self.merger._compare_value_same(1, 2)) + self.assertTrue(self.merger._compare_value_same("a", "a")) + self.assertTrue(self.merger._compare_value_same(1, 1.00000001, has_uncertainty=True)) + self.assertFalse(self.merger._compare_value_same(1, 1.1, has_uncertainty=True)) + + def test_merge_graph_api_collection(self): + results = [MagicMock() for _ in range(2)] + graph0, graph1 = Graph("name1"), Graph("name2") + results[0].graph, results[1].graph = graph0, graph1 + root0, root1 = MagicMock(), MagicMock() + graph0.root, graph1.root = root0, root1 + node0 = MagicMock(id=f"{GraphConst.APIS_BETWEEN_MODULES}.0") + node0_sub1 = MagicMock(id="sub_op.0") + node0.subnodes = [node0_sub1] + node1 = MagicMock(id=f"{GraphConst.APIS_BETWEEN_MODULES}.0") + node1_sub1 = MagicMock(id="sub_op.0") + graph0.node_map = {f"{GraphConst.APIS_BETWEEN_MODULES}.0": node0} + node1.subnodes = [node1_sub1] + root0.subnodes = [node0] + root1.subnodes = [node1] + + self.merger.merge_graph_api_collection(results) + + self.assertEqual(len(root0.subnodes), 1) + self.assertTrue(root0.subnodes[0].id.startswith(GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS)) + self.assertEqual(len(root0.subnodes[0].subnodes), 1) + + def test_split_graph_results_by_groups(self): + groups = [[0, 1], [2, 3]] + results = [MagicMock(rank=i) for i in range(4)] + self.merger.build_graph_results = results + split = self.merger.split_graph_results_by_groups(groups) + self.assertEqual(len(split), 2) + self.assertEqual([r.rank for r in split[0]], [0, 1]) + self.assertEqual([r.rank for r in split[1]], [2, 3]) + + def test_compare_node_param_data(self): + main_node = MagicMock() + other_nodes = [MagicMock()] + main_node.id = "id" + other_nodes[0].id = "id" + main_node.input_data = {"input.0": {Const.DTYPE: "torch.float16", Const.MAX: 1}} + other_nodes[0].input_data = {"input.0": {Const.DTYPE: "torch.float16", Const.MAX: 2}} + in_diff, out_diff = self.merger.compare_node_param_data(main_node, other_nodes) + self.assertEqual(list(in_diff.keys()), ["input.0"]) + + def test_compare_param_same(self): + param1 = {Const.MAX: 1, Const.MIN: 0, Const.MEAN: 0.5, Const.NORM: 1} + param2 = {Const.MAX: 1, Const.MIN: 0, Const.MEAN: 0.5, Const.NORM: 1} + self.assertTrue(self.merger.compare_param_same(param1, param2)) + + param2[Const.MAX] = 2 + self.assertFalse(self.merger.compare_param_same(param1, param2)) + + def test_add_all_nodes_rank(self): + graph0, graph1 = MagicMock(), MagicMock() + node0, node1 = MagicMock(), MagicMock() + graph0.node_map.values.return_value = [node0] + graph1.node_map.values.return_value = [node1] + self.build_graph_results[0].graph = graph0 + self.build_graph_results[1].graph = graph1 + + self.merger._add_all_nodes_rank() + + self.assertEqual(node0.rank, 0) + self.assertEqual(node1.rank, 1) + + def test_get_default_groups(self): + self.parallel_param.tp = 4 + self.parallel_param.pp = 2 + self.parallel_param.rank_size = 8 + merger = BaseGraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + tp_groups, pp_groups = merger.get_default_groups() + self.assertEqual(tp_groups, [[0, 1, 2, 3], [4, 5, 6, 7]]) + self.assertEqual(pp_groups, [[0, 4], [1, 5], [2, 6], [3, 7]]) + + self.parallel_param.tp = 2 + self.parallel_param.pp = 2 + self.parallel_param.rank_size = 8 + merger = BaseGraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + tp_groups, pp_groups = merger.get_default_groups() + self.assertEqual(tp_groups, [[0, 1], [2, 3], [4, 5], [6, 7]]) + self.assertEqual(pp_groups, [[0, 4], [1, 5], [2, 6], [3, 7]]) + + self.parallel_param.tp = 2 + self.parallel_param.pp = 3 + self.parallel_param.rank_size = 8 + merger = BaseGraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + with self.assertRaises(MsprobeException): + merger.get_default_groups() + + +class TestPPMerger(unittest.TestCase): + def setUp(self): + self.build_graph_results = [MagicMock(rank=i) for i in range(4)] + self.parallel_param = ParallelParam(tp=1, pp=4, rank_size=4) + self.is_bench = False + self.merger = PPMerger(self.build_graph_results, self.parallel_param, self.is_bench) + + def test_trace_p2p_mapping(self): + p2p_mapping = {0: 2, 1: 3, 2: 4, 3: 5, 4: 6, 5: 7, 6: 4, 7: 5} + chains = self.merger._trace_p2p_mapping(p2p_mapping) + self.assertEqual(len(chains), 2) + self.assertIn([0, 2, 4, 6], chains) + self.assertIn([1, 3, 5, 7], chains) + + @patch('msprobe.visualization.builder.graph_merger.PPMerger._merge_nodes') + def test_merge_nodes(self, mock_merge): + main_graph = MagicMock() + main_node = MagicMock(id="module.layers.0.forward") + other_graphs = [MagicMock() for _ in range(3)] + for i, g in enumerate(other_graphs): + g.get_node.return_value = MagicMock(id=f"module.layers.{i}.forward") + + self.merger._merge_nodes(main_graph, main_node, other_graphs) + mock_merge.assert_called() + + def test_merge_graphs(self): + self.merger.get_groups = MagicMock(return_value=[[0, 1, 2, 3]]) + self.merger.merge_pp_graphs = MagicMock(return_value=self.build_graph_results[:1]) + results = self.merger.merge_graphs() + self.assertEqual(len(results), 1) + + def test_get_groups(self): + for i, result in enumerate(self.build_graph_results): + graph = MagicMock() + node = MagicMock(id=f"Distributed.send.{i}.forward") + node.input_data = {f"Distributed.send.{i}.forward.input.dst": {"value": (i + 1) % 4}} + graph.node_map.values.return_value = [node] + result.graph = graph + + groups = self.merger.get_groups() + self.assertEqual(len(groups), 1) + self.assertEqual(groups[0], [0, 1, 2, 3]) + + def test_merge_other_unique_nodes(self): + main_graph = MagicMock() + main_node = MagicMock() + other_nodes = [MagicMock()] + main_node.subnodes = [MagicMock(id="main_sub.0")] + other_nodes[0].subnodes = [MagicMock(id="other_sub.0")] + + self.merger._merge_other_unique_nodes(main_graph, main_node, other_nodes) + self.assertEqual(len(main_node.subnodes), 2) + + def test_sort_nodes(self): + graph = MagicMock() + start_node = MagicMock(id="module.layers.0.forward%0%0") + start_node.op = NodeOp.module + api_node = MagicMock(id="Torch.mul.forward.0%0%0") + graph.node_map = {"module.layers.0.forward%0%0": start_node, "Torch.mul.forward.0%0%0": api_node} + parent_node = MagicMock() + parent_node.subnodes = [start_node, api_node] + start_node.upnode = parent_node + + self.merger._sort_nodes(graph, start_node) + self.assertEqual(parent_node.subnodes[0].id, "module.layers.0.forward") + self.assertEqual(parent_node.subnodes[1].id, "Torch.mul_rank0.forward.0") + + def test_add_node_to_main_graph(self): + graph = MagicMock() + node = MagicMock() + subnode = MagicMock() + node.subnodes = [subnode] + + self.merger._add_node_to_main_graph(graph, node) + graph.node_map.__setitem__.assert_has_calls([call(node.id, node), call(subnode.id, subnode)]) + + def test_get_node_sort_rule(self): + node = MagicMock(id="module.layers.0.forward%1%2") + self.assertEqual(self.merger._get_node_sort_rule(node), (2, 1)) + self.assertEqual(self.merger._get_node_sort_rule(node, rank_ascending=False), (-2, 1)) + + def test_mark_node_id_position_rank(self): + node = MagicMock() + parent_node = MagicMock() + parent_node.subnodes = [MagicMock(), node, MagicMock()] + node.upnode = parent_node + node.id = "module.layers.0.forward" + + self.merger._mark_node_id_position_rank(node, 2) + self.assertEqual(node.id, "module.layers.0.forward%1%2") + + def test_update_node_id(self): + graph = MagicMock() + start_node = MagicMock(id="module.layers.0.forward%1%2") + start_node.op = NodeOp.module + start_node.pp_index = 1 + graph.node_map = {start_node.id: start_node} + + self.merger._update_node_id(graph, start_node) + self.assertEqual(start_node.id, "module.layers.1.forward") + + +class TestTPMerger(unittest.TestCase): + def setUp(self): + self.build_graph_results = [MagicMock(rank=i) for i in range(4)] + self.parallel_param = ParallelParam(tp=4, pp=1, rank_size=4) + self.is_bench = False + self.merger = TPMerger(self.build_graph_results, self.parallel_param, self.is_bench) + + def test_merge_params(self): + params = { + "input.0": [ + {Const.MAX: 1, Const.MIN: 0, Const.MEAN: 0.5, Const.NORM: 1}, + {Const.MAX: 2, Const.MIN: 0, Const.MEAN: 0.7, Const.NORM: 1.2} + ] + } + merge_info = self.merger._merge_params(params) + self.assertIn("The Max value merging method for input.0 is: max(1, 2) = 2", merge_info) + self.assertIn("The Mean value merging method for input.0 is: (0.5 + 0.7) / 2 = 0.6", merge_info) + + def test_get_need_merge_node(self): + main_node = MagicMock(id="module.matmul_rank0.forward") + other_graphs = [MagicMock() for _ in range(3)] + tp_merge_mapping = {0: [1, 2, 3]} + + for i, g in enumerate(other_graphs): + g.node_map = {f"module.matmul_rank{i + 1}.forward": MagicMock()} + + nodes = self.merger._get_need_merge_node(main_node, other_graphs, tp_merge_mapping) + self.assertEqual(len(nodes), 0) + + def test_merge_graphs(self): + self.merger.get_groups = MagicMock(return_value=[[0, 1, 2, 3]]) + self.merger.merge_tp_graphs = MagicMock(return_value=self.build_graph_results[:1]) + results = self.merger.merge_graphs() + self.assertEqual(len(results), 1) + + def test_get_groups(self): + for i, result in enumerate(self.build_graph_results): + graph = MagicMock() + node = MagicMock(id=f"all_reduce.{i}") + node.input_data = {f"all_reduce.{i}.input.group": {"group_ranks": [0, 1, 2, 3]}} + graph.node_map.values.return_value = [node] + result.graph = graph + + groups = self.merger.get_groups() + self.assertEqual(len(groups), 1) + self.assertEqual(groups[0], [0, 1, 2, 3]) + + def test_handle_tp_matmul_reduce(self): + node = MagicMock(id=f"module.RowParallelLinear.forward.0") + node.op = NodeOp.module + matmul_node = MagicMock(id="matmul.0") + matmul_node.output_data = {"output.0": {Const.MAX: 1}} + reduce_node = MagicMock(id="all_reduce.0") + reduce_node.input_data = {"input.0": {Const.MAX: 1}} + reduce_node.output_data = {"output.0": {Const.MAX: 2}} + node.subnodes = [matmul_node, reduce_node] + other_graphs = [MagicMock()] + + self.merger._handle_tp_matmul_reduce(node, other_graphs, {}) + self.assertEqual(matmul_node.output_data["output.0"][Const.MAX], 2) + + +class TestNoParallelMerger(unittest.TestCase): + def setUp(self): + self.build_graph_results = [MagicMock()] + self.parallel_param = ParallelParam(tp=1, pp=1, rank_size=1) + self.is_bench = False + self.merger = NoParallelMerger(self.build_graph_results, self.parallel_param, self.is_bench) + + def test_merge_graphs(self): + self.merger.merge_graph_api_collection = MagicMock() + results = self.merger.merge_graphs() + self.assertEqual(results, self.build_graph_results) + self.merger.merge_graph_api_collection.assert_called_once_with(self.build_graph_results) + + +class TestTPPPMerger(unittest.TestCase): + def setUp(self): + self.build_graph_results = [MagicMock(rank=i) for i in range(4)] + self.parallel_param = ParallelParam(tp=2, pp=2, rank_size=4) + self.is_bench = False + self.merger = TPPPMerger(self.build_graph_results, self.parallel_param, self.is_bench) + + @patch('msprobe.visualization.builder.graph_merger.TPMerger') + @patch('msprobe.visualization.builder.graph_merger.PPMerger') + def test_merge_graphs(self, mock_pp, mock_tp): + tp_merger = MagicMock() + pp_merger = MagicMock() + mock_tp.return_value = tp_merger + mock_pp.return_value = pp_merger + + pp_merger.get_groups.return_value = [[0, 1], [2, 3]] + tp_merger.get_groups.return_value = [[0, 2], [1, 3]] + tp_merger.merge_tp_graphs.return_value = [MagicMock()] + + results = self.merger.merge_graphs() + self.assertEqual(len(results), 1) + + +class TestFullMerger(unittest.TestCase): + def setUp(self): + self.build_graph_results = [MagicMock(rank=i) for i in range(8)] + self.parallel_param = ParallelParam(tp=2, pp=4, rank_size=8, vpp=1) + self.is_bench = False + self.merger = FullMerger(self.build_graph_results, self.parallel_param, self.is_bench) + + @patch('msprobe.visualization.builder.graph_merger.TPMerger') + @patch('msprobe.visualization.builder.graph_merger.PPMerger') + def test_merge_graphs(self, mock_pp, mock_tp): + tp_merger = MagicMock() + pp_merger = MagicMock() + mock_tp.return_value = tp_merger + mock_pp.return_value = pp_merger + + pp_merger.get_groups.return_value = [[0, 1, 2, 3], [4, 5, 6, 7]] + tp_merger.get_groups.return_value = [[0, 4], [1, 5], [2, 6], [3, 7]] + + pp_result0 = MagicMock(rank=0) + pp_result1 = MagicMock(rank=4) + pp_merger.merge_pp_graphs.side_effect = [[pp_result0], [pp_result1]] + + tp_merger.merge_tp_graphs.side_effect = [[MagicMock()], [MagicMock()]] + + results = self.merger.merge_graphs() + self.assertEqual(len(results), 1) + + +if __name__ == '__main__': + unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_msprobe_adapter.py b/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_msprobe_adapter.py index bee32a34a0509d5559b47d7a1625f618dc132d4e..e2ca516542a9840e0230a58eca5d0ad20c6f7579 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_msprobe_adapter.py +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_msprobe_adapter.py @@ -11,6 +11,7 @@ from msprobe.visualization.builder.msprobe_adapter import ( _format_data ) from msprobe.visualization.utils import GraphConst +from msprobe.visualization.graph.base_node import BaseNode import torch from msprobe.core.common.const import Const @@ -55,11 +56,9 @@ class TestMsprobeAdapter(unittest.TestCase): @patch('msprobe.visualization.builder.msprobe_adapter.get_accuracy') def test_compare_node(self, mock_get_accuracy): - node_ids = ["node1", "node2"] - data_dicts = [{'node1': {"input_args": [], "input_kwargs": {}, "output": {}}}, - {'node2': {"input_args": [], "input_kwargs": {}, "output": {}}}] - stack_json_data = {} - result = compare_node(node_ids, data_dicts, stack_json_data, GraphConst.REAL_DATA_COMPARE) + node_n = BaseNode('', 'node1') + node_b = BaseNode('', 'node2') + result = compare_node(node_n, node_b, GraphConst.REAL_DATA_COMPARE) mock_get_accuracy.assert_called_once() self.assertIsInstance(result, list) diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/compare/test_graph_comparator.py b/debug/accuracy_tools/msprobe/test/visualization_ut/compare/test_graph_comparator.py index f4d68ccb530919dbdfedaa12bea716b2c70e278d..4accdacd76a434b6329a9fce378e38927092e9ae 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/compare/test_graph_comparator.py +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/compare/test_graph_comparator.py @@ -1,5 +1,6 @@ import os import unittest +from typing import Any from dataclasses import dataclass from unittest.mock import patch from unittest.mock import MagicMock @@ -12,7 +13,7 @@ from msprobe.visualization.utils import GraphConst class Args: input_path: str = None output_path: str = None - layer_mapping: str = None + layer_mapping: Any = None framework: str = None overflow_check: bool = False fuzzy_match: bool = False @@ -39,7 +40,7 @@ class TestGraphComparator(unittest.TestCase): mock_load_data_json_file.return_value = "data_dict" mock_load_json_file.return_value = "construct_dict" mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE - self.comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path)) + self.comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path), False) self.comparator._parse_param(self.dump_path_param, self.output_path) self.assertEqual(self.comparator.dump_path_param, { @@ -57,7 +58,7 @@ class TestGraphComparator(unittest.TestCase): mock_load_data_json_file.return_value = "data_dict" mock_load_json_file.return_value = "construct_dict" mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE - comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path)) + comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path), False) comparator._compare_nodes = MagicMock() comparator._postcompare = MagicMock() @@ -76,7 +77,7 @@ class TestGraphComparator(unittest.TestCase): node = MagicMock() compare_result_list = [("output1", "data1"), ("input1", "data2")] - comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path)) + comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path), False) comparator.ma = MagicMock() comparator.ma.prepare_real_data.return_value = True @@ -100,7 +101,7 @@ class TestGraphComparator(unittest.TestCase): mock_run_real_data.return_value = mock_df mock_get_csv_df.return_value = mock_df mock_get_node_error_status.return_value = True - comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path)) + comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path), False) comparator.ma = MagicMock() comparator.ma.compare_mode = GraphConst.REAL_DATA_COMPARE comparator._handle_api_collection_index = MagicMock() @@ -118,7 +119,7 @@ class TestGraphComparator(unittest.TestCase): mock_load_data_json_file.return_value = "data_dict" mock_load_json_file.return_value = "construct_dict" mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE - comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path)) + comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path), False) apis = BaseNode(NodeOp.api_collection, 'Apis_Between_Modules.0') api1 = BaseNode(NodeOp.function_api, 'Tensor.a.0') api1.data = {GraphConst.JSON_INDEX_KEY: 0.9} @@ -145,11 +146,12 @@ class TestGraphComparator(unittest.TestCase): mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE mock_mapping_match.return_value = (node_b, [], []) mock_compare_node.return_value = ['result'] - comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path)) + comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path, layer_mapping=True), True) comparator.mapping_dict = True comparator._compare_nodes(node_n) self.assertEqual(node_n.matched_node_link, ['Tensor.b.0']) self.assertEqual(node_b.matched_node_link, ['Tensor.a.0']) + comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path), False) comparator.mapping_dict = False node_n = BaseNode(NodeOp.function_api, 'Tensor.a.0') node_b = BaseNode(NodeOp.function_api, 'Tensor.a.0') @@ -185,6 +187,6 @@ class TestGraphComparator(unittest.TestCase): 'stack_json_path': os.path.join(dir_name, 'input', 'step0', 'rank0', 'stack.json'), 'is_print_compare_log': True } - comparator = GraphComparator(self.graphs, dump_path_param, Args(output_path=self.output_path)) + comparator = GraphComparator(self.graphs, dump_path_param, Args(output_path=self.output_path), False) comparator.add_compare_result_to_node(node, compare_result_list) self.assertEqual(node.data, {'precision_index': 0}) diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/compare/test_mode_adapter.py b/debug/accuracy_tools/msprobe/test/visualization_ut/compare/test_mode_adapter.py index 87d1f9ee5f01c7c9b2f264f3e6ec16b5155c1f8e..25ad91605900b2026235aff693762dec0556d27c 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/compare/test_mode_adapter.py +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/compare/test_mode_adapter.py @@ -2,7 +2,8 @@ import json import unittest from unittest.mock import patch, MagicMock from msprobe.visualization.compare.mode_adapter import ModeAdapter -from msprobe.visualization.graph.base_node import BaseNode, NodeOp +from msprobe.visualization.graph.base_node import BaseNode +from msprobe.visualization.graph.node_op import NodeOp from msprobe.visualization.utils import GraphConst, ToolTip from msprobe.core.common.const import CompareConst @@ -25,22 +26,22 @@ class TestModeAdapter(unittest.TestCase): node_data = {'Tensor.__imul__.0.forward.input.0': {'type': 'torch.Tensor', 'dtype': 'torch.int64', 'shape': [], 'Max': 16388, 'Min': 16388, 'Mean': 16388, 'Norm': 16388, - 'requires_grad': False, 'md5': 'a563a4ea', + 'requires_grad': 'False', 'md5': 'a563a4ea', 'full_op_name': 'Tensor.__imul__.0.forward.input.0', - 'data_name': '-1'}, + 'data_name': '-1', 'state': 'input'}, 'Tensor.__imul__.0.forward.input.1': {'type': 'torch.Tensor', 'dtype': 'torch.int64', 'shape': [], 'Max': 4097, 'Min': 4097, 'Mean': 4097, 'Norm': 4097, - 'requires_grad': False, 'md5': 'ce564339', + 'requires_grad': 'False', 'md5': 'ce564339', 'full_op_name': 'Tensor.__imul__.0.forward.input.1', - 'data_name': '-1'}} + 'data_name': '-1', 'state': 'input'}} compare_dict = {'Tensor.__imul__.0.forward.input.0': ['Tensor.__imul__.0.forward.input.0', 'Tensor.__imul__.0.forward.input.0', 'torch.int64', - 'torch.int64', [], [], 'a563a4ea', 'a563a4ea', 'pass', - []], + 'torch.int64', [], [], 'False', 'False', + 'a563a4ea', 'a563a4ea', True, 'pass', []], 'Tensor.__imul__.0.forward.input.1': ['Tensor.__imul__.0.forward.input.1', 'Tensor.__imul__.0.forward.input.1', 'torch.int64', - 'torch.int64', [], [], 'ce564339', 'ce564559', 'diff', - 'None']} + 'torch.int64', [], [], 'False', 'False', + 'ce564339', 'ce564559', True, 'diff', 'None']} precision_index = ModeAdapter._add_md5_compare_data(node_data, compare_dict) self.assertEqual(precision_index, 0) @@ -48,68 +49,68 @@ class TestModeAdapter(unittest.TestCase): tensor_data = {'Module.module.Float16Module.forward.0.input.0': ['Module.module.Float16Module.forward.0.input.0', 'Module.module.Float16Module.forward.0.input.0', - 'torch.int64', 'torch.int64', [1, 1024], [1, 1024], - 1.0, 0.0, 0.0, 1.0, 1.0, 29992.0, 1.0, 9100.3125, - 474189.09375, 29992.0, 1.0, 9100.3125, 474189.09375, - 'Yes', '', None, + 'torch.int64', 'torch.int64', [1, 1024], [1, 1024], 'False', 'False', + 1.0, 0.0, 0.0, 1.0, 1.0, + 29992.0, 1.0, 9100.3125, 474189.09375, + 29992.0, 1.0, 9100.3125, 474189.09375, + True, 'Yes', '', None, 'Module.module.Float16Module.forward.0.input.0.pt'], 'Module.module.Float16Module.forward.0.input.1': [ 'Module.module.Float16Module.forward.0.input.1', 'Module.module.Float16Module.forward.0.input.1', - 'torch.int64', 'torch.int64', [1, 1024], [1, 1024], - 1.0, 0.0, 0.0, None, 1.0, 1023.0, 0.0, 511.5, - 18904.755859375, 1023.0, 0.0, 511.5, 18904.755859375, - 'Yes', '', 'None', + 'torch.int64', 'torch.int64', [1, 1024], [1, 1024], 'False', 'False', + 1.0, 0.0, 0.0, None, 1.0, + 1023.0, 0.0, 511.5, 18904.755859375, + 1023.0, 0.0, 511.5, 18904.755859375, + True, 'Yes', '', 'None', 'Module.module.Float16Module.forward.0.input.1.pt'], 'Module.module.Float16Module.forward.0.input.2': [ 'Module.module.Float16Module.forward.0.input.2', 'Module.module.Float16Module.forward.0.input.2', - 'torch.bool', 'torch.bool', [1, 1, 1024, 1024], - [1, 1, 1024, 1024], 1.0, 0.0, 0.0, 1.0, 1.0, True, - False, None, None, True, False, None, None, 'Yes', '', - 'None', + 'torch.bool', 'torch.bool', [1, 1, 1024, 1024], [1, 1, 1024, 1024], 'False', 'False', + 1.0, 0.0, 0.0, 1.0, 1.0, + True, False, None, None, True, False, None, None, + True, 'Yes', '', 'None', 'Module.module.Float16Module.forward.0.input.2.pt'], 'Module.module.Float16Module.forward.0.kwargs.labels': [ 'Module.module.Float16Module.forward.0.kwargs.labels', - 'Module.module.Float16Module.forward.0.kwargs.labels', 'torch.int64', - 'torch.int64', [1, 1024], - [1, 1024], 1.0, 0.0, 0.0, 1.0, 1.0, 29992.0, 1.0, 9108.99609375, 474332.28125, - 29992.0, 1.0, - 9108.99609375, 474332.28125, 'Yes', '', 'None', + 'Module.module.Float16Module.forward.0.kwargs.labels', + 'torch.int64', 'torch.int64', [1, 1024], [1, 1024], 'False', 'False', + 1.0, 0.0, 0.0, 1.0, 1.0, + 29992.0, 1.0, 9108.99609375, 474332.28125, + 29992.0, 1.0, 9108.99609375, 474332.28125, + True, 'Yes', '', 'None', 'Module.module.Float16Module.forward.0.kwargs.labels.pt'], 'Module.module.Float16Module.forward.0.output.0': [ 'Module.module.Float16Module.forward.0.output.0', 'Module.module.Float16Module.forward.0.output.0', - 'torch.float32', 'torch.float32', [1, 1024], - [1, 1024], 0.994182636336, 4.863566398621, - 0.461487948895, 0.0068359375, 0.0234375, - 15.402446746826172, 7.318280220031738, - 11.375151634216309, 366.3365173339844, - 10.538880348205566, 10.215872764587402, - 10.378824234008789, 332.1264953613281, 'No', '', - 'None', + 'torch.float32', 'torch.float32', [1, 1024], [1, 1024], 'False', 'False', + 0.994182636336, 4.863566398621, 0.461487948895, 0.0068359375, 0.0234375, + 15.402446746826172, 7.318280220031738, 11.375151634216309, 366.3365173339844, + 10.538880348205566, 10.215872764587402, 10.378824234008789, 332.1264953613281, + True, 'No', '', 'None', 'Module.module.Float16Module.forward.0.output.0.pt']} node_data = {'Module.module.Float16Module.forward.0.input.0': {'type': 'torch.Tensor', 'dtype': 'torch.int64', 'shape': [1, 1024], 'Max': 29992.0, 'Min': 1.0, 'Mean': 9100.3125, 'Norm': 474189.09375, - 'requires_grad': False, + 'requires_grad': 'False', 'md5': '00000000'}, 'Module.module.Float16Module.forward.0.input.1': {'type': 'torch.Tensor', 'dtype': 'torch.int64', 'shape': [1, 1024], 'Max': 1023.0, 'Min': 0.0, 'Mean': 511.5, 'Norm': 18904.755859375, - 'requires_grad': False, + 'requires_grad': 'False', 'md5': '00000000'}, 'Module.module.Float16Module.forward.0.input.2': {'type': 'torch.Tensor', 'dtype': 'torch.bool', 'shape': [1, 1, 1024, 1024], 'Max': True, 'Min': False, 'Mean': None, 'Norm': None, - 'requires_grad': False, + 'requires_grad': 'False', 'md5': '00000000'}, 'Module.module.Float16Module.forward.0.kwargs.labels': {'type': 'torch.Tensor', 'dtype': 'torch.int64', 'shape': None, 'Max': 29992.0, 'Min': 1.0, 'Mean': 9108.99609375, 'Norm': 474332.28125, - 'requires_grad': False, + 'requires_grad': 'False', 'md5': '00000000'}, 'Module.module.Float16Module.forward.0.kwargs.None': None} min_thousandth = ModeAdapter._add_real_compare_data(node_data, tensor_data) @@ -119,51 +120,60 @@ class TestModeAdapter(unittest.TestCase): compare_data_dict = { 'Module.module.Float16Module.forward.0.input.0': ['Module.module.Float16Module.forward.0.input.0', 'Module.module.Float16Module.forward.0.input.0', - 'torch.int64', 'torch.int64', [4, 4096], [4, 4096], 0.0, - 0.0, 0.0, 0.0, '0.0%', '0.0%', '0.0%', '0.0%', 30119.0, - 1.0, 8466.25, 1786889.625, 30119.0, 1.0, 8466.25, - 1786889.625, '', ''], + 'torch.int64', 'torch.int64', [4, 4096], [4, 4096], + 'False', 'False', + 0.0, 0.0, 0.0, 0.0, '0.0%', '0.0%', '0.0%', '0.0%', + 30119.0, 1.0, 8466.25, 1786889.625, + 30119.0, 1.0, 8466.25, 1786889.625, + True, '', '', None], 'Module.module.Float16Module.forward.0.input.1': ['Module.module.Float16Module.forward.0.input.1', 'Module.module.Float16Module.forward.0.input.1', - 'torch.int64', 'torch.int64', [4, 4096], [4, 4096], 0.0, - 0.0, 0.0, 0.0, '0.0%', 'N/A', '0.0%', '0.0%', 4095.0, 0.0, - 2047.5, 302642.375, 4095.0, 0.0, 2047.5, 302642.375, '', - '', 'None'], + 'torch.int64', 'torch.int64', [4, 4096], [4, 4096], + 'False', 'False', + 0.0, 0.0, 0.0, 0.0, '0.0%', 'N/A', '0.0%', '0.0%', + 4095.0, 0.0, 2047.5, 302642.375, + 4095.0, 0.0, 2047.5, 302642.375, + True, '', '', 'None'], 'Module.module.Float16Module.forward.0.input.2': ['Module.module.Float16Module.forward.0.input.2', 'Module.module.Float16Module.forward.0.input.2', - 'torch.bool', 'torch.bool', [1, 1, 4096, 4096], - [1, 1, 4096, 4096], 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', - 'N/A', 'N/A', 'N/A', True, False, None, None, True, False, - None, None, '', '', 'None'], + 'torch.bool', 'torch.bool', + [1, 1, 4096, 4096], [1, 1, 4096, 4096], + 'False', 'False', + 'N/A', 'N/A', 'N/A', 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', + True, False, None, None, True, False, None, None, + True, '', '', 'None'], 'Module.module.Float16Module.forward.0.input.labels': ['Module.module.Float16Module.forward.0.input.labels', 'Module.module.Float16Module.forward.0.input.labels', - 'torch.float16', 'torch.float16', [4, 4096], - [4, 4096], + 'torch.float16', 'torch.float16', + [4, 4096], [4, 4096], + 'False', 'False', 0.0, 0.0, 0.0, 0.0, '0.0%', '0.0%', '0.0%', '0.0%', 30119.0, 0.00001, 8460.7685546875, 1786117.625, - 30119.0, - 1.0, 8460.7685546875, 1786117.625, '', '', 'None']} + 30119.0, 1.0, 8460.7685546875, 1786117.625, + True, '', '', 'None']} node_data = {'Module.module.Float16Module.forward.0.input.0': {'type': 'torch.Tensor', 'dtype': 'torch.int64', 'shape': [4, 4096], 'Max': 30119.0, 'Min': 1.0, 'Mean': 8466.25, 'Norm': 1786889.625, - 'requires_grad': False, + 'requires_grad': 'False', 'data_name': '-1', 'md5': '00000000'}, 'Module.module.Float16Module.forward.0.input.1': {'type': 'torch.Tensor', 'dtype': 'torch.int64', 'shape': [4, 4096], 'Max': 4095.0, 'Min': 0.0, 'Mean': 2047.5, 'Norm': 302642.375, - 'requires_grad': False, + 'requires_grad': 'False', 'data_name': '-1', 'md5': '00000000'}, 'Module.module.Float16Module.forward.0.input.2': {'type': 'torch.Tensor', 'dtype': 'torch.bool', 'shape': [1, 1, 4096, 4096], 'Max': True, 'Min': False, 'Mean': None, 'Norm': None, - 'requires_grad': False, + 'requires_grad': 'False', 'data_name': '-1', 'md5': '00000000'}, 'Module.module.Float16Module.forward.0.input.labels': {'type': 'torch.Tensor', 'dtype': 'torch.float16', 'shape': [4, 4096], 'Max': 30119.0, 'Min': 0.00001, 'Mean': 8460.7685546875, - 'Norm': 1786117.625, 'requires_grad': False, + 'Norm': 1786117.625, + 'requires_grad': 'False', 'data_name': '-1', 'md5': '00000000'}, 'Module.module.Float16Module.forward.0.kwargs.None': None} precision_index = ModeAdapter._add_summary_compare_data(node_data, compare_data_dict) @@ -225,27 +235,6 @@ class TestModeAdapter(unittest.TestCase): self.adapter.add_csv_data(compare_result_list) self.assertEqual(self.adapter.csv_data, compare_result_list) - def test_add_error_key(self): - node_data = {'key': {}} - self.adapter.compare_mode = GraphConst.REAL_DATA_COMPARE - self.adapter.add_error_key(node_data) - self.assertEqual(node_data['key'][GraphConst.ERROR_KEY], - [CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO]) - node_data = {'key': {}} - self.adapter.compare_mode = GraphConst.SUMMARY_COMPARE - self.adapter.add_error_key(node_data) - self.assertEqual(node_data['key'][GraphConst.ERROR_KEY], - [CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR, - CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR]) - node_data = {'key': []} - self.adapter.add_error_key(node_data) - self.assertEqual(node_data['key'], []) - - node_data = {'key': {}} - self.adapter.compare_mode = '111' - self.adapter.add_error_key(node_data) - self.assertEqual(node_data['key'], {'error_key': []}) - def test_get_tool_tip(self): self.adapter.compare_mode = GraphConst.MD5_COMPARE tips = self.adapter.get_tool_tip() diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/graph/test_base_node.py b/debug/accuracy_tools/msprobe/test/visualization_ut/graph/test_base_node.py index 480b95620e6a81577d825b7af55b45fc0a04c34c..4abd0f35377d89301763046c1c0db672880ccd45 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/graph/test_base_node.py +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/graph/test_base_node.py @@ -1,6 +1,6 @@ import unittest -from msprobe.visualization.graph.base_node import BaseNode, NodeOp -from msprobe.visualization.utils import GraphConst +from msprobe.visualization.graph.base_node import BaseNode +from msprobe.visualization.graph.node_op import NodeOp class TestBaseNode(unittest.TestCase): @@ -20,7 +20,6 @@ class TestBaseNode(unittest.TestCase): other_node = BaseNode(self.node_op, self.node_id, self.up_node) self.assertEqual(self.node, other_node) - def test_set_input_output(self): input_data = {'input1': 'value1'} output_data = {'output1': 'value2'} @@ -42,21 +41,6 @@ class TestBaseNode(unittest.TestCase): self.assertEqual(self.node.matched_node_link, ancestors) self.assertEqual(other_node.matched_node_link, ancestors) - def test_to_dict(self): - expected_result = { - 'id': self.node_id, - 'node_type': self.node_op.value, - 'data': {}, - 'output_data': {}, - 'input_data': {}, - 'upnode': self.up_node.id, - 'subnodes': [], - 'matched_node_link': [], - 'suggestions': {}, - 'stack_info': [] - } - self.assertEqual(self.node.to_dict(), expected_result) - def test_get_ancestors(self): expected_ancestors = ['up_node_1'] self.assertEqual(self.node.get_ancestors(), expected_ancestors) diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/graph/test_graph.py b/debug/accuracy_tools/msprobe/test/visualization_ut/graph/test_graph.py index 81f9fdca5277de6e1670da409bcf93e56ece3206..0d88509055860a759c6267838d85f2d99587211f 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/graph/test_graph.py +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/graph/test_graph.py @@ -27,12 +27,6 @@ class TestGraph(unittest.TestCase): self.assertIn("api.8", graph.node_map) self.assertNotIn("api", graph.node_map) - def test_to_dict(self): - self.graph.add_node(self.node_op, self.node_id) - result = self.graph.to_dict() - self.assertEqual(result[GraphConst.JSON_ROOT_KEY], "model_name") - self.assertIn(self.node_id, result[GraphConst.JSON_NODE_KEY]) - def test_str(self): self.graph.add_node(self.node_op, self.node_id) expected_str = f'{self.node_id}' @@ -55,17 +49,6 @@ class TestGraph(unittest.TestCase): self.assertIsNotNone(matched_node) self.assertEqual(ancestors, ['node_id_a']) - def test_dfs(self): - graph = Graph("model_name") - graph.add_node(NodeOp.module, "node_a") - graph.add_node(NodeOp.module, "node_b") - node_a = BaseNode(self.node_op, self.node_id) - result = {} - graph.dfs(node_a, result) - self.assertEqual(result, {'node_id': {'id': 'node_id', 'node_type': 0, 'data': {}, - 'output_data': {}, 'input_data': {}, 'upnode': 'None', 'subnodes': [], - 'matched_node_link': [], 'suggestions': {}, 'stack_info': []}}) - def test_split_nodes_by_micro_step(self): nodes = [BaseNode(NodeOp.module, 'a.forward.0'), BaseNode(NodeOp.module, 'a.backward.0'), BaseNode(NodeOp.api_collection, 'apis.0'), BaseNode(NodeOp.module, 'a.forward.1'), diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/input/step0/rank0/construct.json b/debug/accuracy_tools/msprobe/test/visualization_ut/input/step0/rank0/construct.json index 0967ef424bce6791893e9a57bb952f80fd536e93..f38780de744675a62cee03c58fb4682448c210a6 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/input/step0/rank0/construct.json +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/input/step0/rank0/construct.json @@ -1 +1,4 @@ -{} +{ + "Tensor1": "Module1", + "Module1": null +} diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/input/step0/rank0/dump.json b/debug/accuracy_tools/msprobe/test/visualization_ut/input/step0/rank0/dump.json index 330122252bd65cb01bbf9f0cd6c912f407b32a28..18501445cf403fecabace4e817f79e3b29edace0 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/input/step0/rank0/dump.json +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/input/step0/rank0/dump.json @@ -2,5 +2,5 @@ "task": "statistics", "level": "mix", "dump_data_dir": null, - "data": {} + "data": {"api": {"input": [{}]}} } diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step0/rank0/construct.json b/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step0/rank0/construct.json index 0967ef424bce6791893e9a57bb952f80fd536e93..f38780de744675a62cee03c58fb4682448c210a6 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step0/rank0/construct.json +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step0/rank0/construct.json @@ -1 +1,4 @@ -{} +{ + "Tensor1": "Module1", + "Module1": null +} diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step0/rank0/dump.json b/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step0/rank0/dump.json index 330122252bd65cb01bbf9f0cd6c912f407b32a28..d40eabd5eeddb1c0eb723e49b5674a9bb0635fa7 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step0/rank0/dump.json +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step0/rank0/dump.json @@ -2,5 +2,6 @@ "task": "statistics", "level": "mix", "dump_data_dir": null, - "data": {} + "data": {"api": {"input": [{}]}}, + "framework": "pytorch" } diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step0/rank1/construct.json b/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step0/rank1/construct.json index 0967ef424bce6791893e9a57bb952f80fd536e93..f38780de744675a62cee03c58fb4682448c210a6 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step0/rank1/construct.json +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step0/rank1/construct.json @@ -1 +1,4 @@ -{} +{ + "Tensor1": "Module1", + "Module1": null +} diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step0/rank1/dump.json b/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step0/rank1/dump.json index 330122252bd65cb01bbf9f0cd6c912f407b32a28..d40eabd5eeddb1c0eb723e49b5674a9bb0635fa7 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step0/rank1/dump.json +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step0/rank1/dump.json @@ -2,5 +2,6 @@ "task": "statistics", "level": "mix", "dump_data_dir": null, - "data": {} + "data": {"api": {"input": [{}]}}, + "framework": "pytorch" } diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step1/rank0/construct.json b/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step1/rank0/construct.json index 0967ef424bce6791893e9a57bb952f80fd536e93..f38780de744675a62cee03c58fb4682448c210a6 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step1/rank0/construct.json +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step1/rank0/construct.json @@ -1 +1,4 @@ -{} +{ + "Tensor1": "Module1", + "Module1": null +} diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step1/rank0/dump.json b/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step1/rank0/dump.json index 330122252bd65cb01bbf9f0cd6c912f407b32a28..d40eabd5eeddb1c0eb723e49b5674a9bb0635fa7 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step1/rank0/dump.json +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step1/rank0/dump.json @@ -2,5 +2,6 @@ "task": "statistics", "level": "mix", "dump_data_dir": null, - "data": {} + "data": {"api": {"input": [{}]}}, + "framework": "pytorch" } diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step2/rank0/construct.json b/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step2/rank0/construct.json index 0967ef424bce6791893e9a57bb952f80fd536e93..f38780de744675a62cee03c58fb4682448c210a6 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step2/rank0/construct.json +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step2/rank0/construct.json @@ -1 +1,4 @@ -{} +{ + "Tensor1": "Module1", + "Module1": null +} diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step2/rank0/dump.json b/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step2/rank0/dump.json index 330122252bd65cb01bbf9f0cd6c912f407b32a28..d40eabd5eeddb1c0eb723e49b5674a9bb0635fa7 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step2/rank0/dump.json +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step2/rank0/dump.json @@ -2,5 +2,6 @@ "task": "statistics", "level": "mix", "dump_data_dir": null, - "data": {} + "data": {"api": {"input": [{}]}}, + "framework": "pytorch" } diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/test_db_utils.py b/debug/accuracy_tools/msprobe/test/visualization_ut/test_db_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0682dab09592a7a37ea416c3caff6b76fd212481 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/test_db_utils.py @@ -0,0 +1,200 @@ +import os +import sqlite3 +import unittest +from unittest.mock import Mock, patch +import tempfile +from msprobe.visualization.db_utils import ( # 请替换为实际模块名 + create_table_sql_from_dict, + create_insert_sql_from_dict, + to_db, + add_table_index, + post_process_db, + node_to_db, + config_to_db, + get_graph_unique_id, + get_node_unique_id, + node_columns, + indexes +) + + +class TestDatabaseFunctions(unittest.TestCase): + + def setUp(self): + # 创建临时文件作为测试数据库 + self.temp_db = tempfile.NamedTemporaryFile(delete=False, suffix='.db').name + self.addCleanup(os.remove, self.temp_db) + + def test_create_table_sql_from_dict(self): + """测试表创建SQL语句生成""" + test_table = "test_table" + test_columns = { + "id": "INTEGER PRIMARY KEY", + "name": "TEXT NOT NULL" + } + + expected_sql = """CREATE TABLE IF NOT EXISTS test_table ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL +);""" + + generated_sql = create_table_sql_from_dict(test_table, test_columns) + self.assertEqual(generated_sql, expected_sql) + + def test_create_insert_sql_from_dict(self): + """测试插入SQL语句生成""" + test_table = "test_table" + test_columns = {"id": "INTEGER", "name": "TEXT"} + + # 测试普通插入 + expected_sql = "INSERT INTO test_table (id, name) VALUES (?, ?)" + generated_sql = create_insert_sql_from_dict(test_table, test_columns) + self.assertEqual(generated_sql, expected_sql) + + # 测试忽略插入 + expected_sql = "INSERT OR IGNORE INTO test_table (id, name) VALUES (?, ?)" + generated_sql = create_insert_sql_from_dict(test_table, test_columns, ignore_insert=True) + self.assertEqual(generated_sql, expected_sql) + + def test_to_db(self): + """测试数据库写入功能""" + # 创建测试表和数据 + create_sql = create_table_sql_from_dict("test_table", {"id": "INTEGER PRIMARY KEY", "name": "TEXT"}) + insert_sql = create_insert_sql_from_dict("test_table", {"id": "INTEGER", "name": "TEXT"}) + test_data = [(1, "test1"), (2, "test2"), (3, "test3")] + + # 执行写入 + to_db(self.temp_db, create_sql, insert_sql, test_data) + + # 验证数据是否正确写入 + conn = sqlite3.connect(self.temp_db) + cursor = conn.cursor() + cursor.execute("SELECT * FROM test_table") + result = cursor.fetchall() + conn.close() + + self.assertEqual(result, test_data) + + def test_add_table_index(self): + """测试索引添加功能""" + # 先创建测试表 + conn = sqlite3.connect(self.temp_db) + cursor = conn.cursor() + cursor.execute(create_table_sql_from_dict('tb_nodes', node_columns)) + conn.commit() + conn.close() + + # 添加索引 + add_table_index(self.temp_db) + + # 验证索引是否创建 + conn = sqlite3.connect(self.temp_db) + cursor = conn.cursor() + cursor.execute("PRAGMA index_list(tb_nodes)") + indexes_result = cursor.fetchall() + conn.close() + + # 检查是否存在预期的索引 + index_names = [idx[1] for idx in indexes_result] + for index_name in indexes.keys(): + self.assertIn(index_name, index_names) + + def test_post_process_db(self): + """测试数据库后处理功能""" + # 创建测试表 + conn = sqlite3.connect(self.temp_db) + cursor = conn.cursor() + cursor.execute(create_table_sql_from_dict('tb_nodes', node_columns)) + conn.commit() + conn.close() + + with patch('msprobe.visualization.db_utils.add_table_index') as mock_add_index: + post_process_db(self.temp_db) + mock_add_index.assert_called_once_with(self.temp_db) + + def test_get_graph_unique_id(self): + """测试图形唯一ID生成""" + # 创建模拟graph对象 + mock_graph = Mock() + mock_graph.data_source = "test_source" + mock_graph.step = 5 + mock_graph.rank = 2 + + expected_id = "test_source_5_2" + self.assertEqual(get_graph_unique_id(mock_graph), expected_id) + + def test_get_node_unique_id(self): + """测试节点唯一ID生成""" + # 创建模拟对象 + mock_graph = Mock() + mock_graph.data_source = "test_source" + mock_graph.step = 5 + mock_graph.rank = 2 + + mock_node = Mock() + mock_node.id = "node_123" + + expected_id = "test_source_5_2_node_123" + self.assertEqual(get_node_unique_id(mock_graph, mock_node), expected_id) + + @patch('msprobe.visualization.db_utils.to_db') + @patch('msprobe.visualization.db_utils.get_node_unique_id') + @patch('msprobe.visualization.db_utils.get_graph_unique_id') + @patch('msprobe.visualization.builder.msprobe_adapter.format_node_data') + def test_node_to_db(self, mock_format, mock_graph_id, mock_node_id, mock_to_db): + """测试节点数据写入数据库""" + # 配置模拟 + mock_graph_id.return_value = "graph_123" + mock_node_id.return_value = "node_456" + mock_format.return_value = {} + + # 创建模拟graph和node + mock_node = Mock() + mock_node.id = "node1" + mock_node.op.value = "OPERATION" + mock_node.upnode = None + mock_node.subnodes = [] + mock_node.data = {} + mock_node.micro_step_id = 1 + mock_node.matched_node_link = {} + mock_node.stack_info = {} + mock_node.parallel_merge_info = None + mock_node.matched_distributed = {} + mock_node.input_data = {} + mock_node.output_data = {} + + mock_graph = Mock() + mock_graph.get_sorted_nodes.return_value = [mock_node] + mock_graph.data_source = "test_source" + mock_graph.data_path = "/test/path" + mock_graph.step = 1 + mock_graph.rank = 0 + + # 执行测试 + node_to_db(mock_graph, self.temp_db) + + # 验证to_db被正确调用 + self.assertTrue(mock_to_db.called) + + @patch('msprobe.visualization.db_utils.to_db') + def test_config_to_db(self, mock_to_db): + """测试配置数据写入数据库""" + mock_config = Mock() + mock_config.graph_b = False + mock_config.task = "test_task" + mock_config.tool_tip = "test tooltip" + mock_config.micro_steps = 10 + mock_config.overflow_check = 1 + mock_config.node_colors = {} + mock_config.rank_list = [0, 1, 2, 3] + mock_config.step_list = [0] + + # 执行测试 + config_to_db(mock_config, self.temp_db) + + # 验证to_db被正确调用 + self.assertTrue(mock_to_db.called) + + +if __name__ == '__main__': + unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/test_graph_service.py b/debug/accuracy_tools/msprobe/test/visualization_ut/test_graph_service.py index 7dfd9564ebc21327f3e7e29be90da7f78c3b0393..78e3d1e7bb848b322560d4967ea6673547385c42 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/test_graph_service.py +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/test_graph_service.py @@ -7,7 +7,7 @@ import argparse from dataclasses import dataclass from unittest.mock import patch -from msprobe.visualization.graph_service import _compare_graph, _build_graph, _compare_graph_ranks, \ +from msprobe.visualization.graph_service import _compare_graph_result, _build_graph_result, _compare_graph_ranks, \ _compare_graph_steps, _build_graph_ranks, _build_graph_steps, _graph_service_command, _graph_service_parser from msprobe.core.common.utils import CompareException @@ -21,6 +21,8 @@ class Args: overflow_check: bool = False fuzzy_match: bool = False complete_stack: bool = False + parallel_merge: bool = False + parallel_params: tuple = None class TestGraphService(unittest.TestCase): @@ -34,8 +36,8 @@ class TestGraphService(unittest.TestCase): 'is_print_compare_log': True } self.layer_mapping = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'layer_mapping.yaml') - self.pattern = r'\b\w+\.vis\b' - self.pattern_rank = r'[\w_]+\.vis\b' + self.pattern = r'\b\w+\.vis.db\b' + self.pattern_rank = r'[\w_]+\.vis.db\b' self.output_json = [] for i in range(7): self.output_json.append(os.path.join(self.current_path, f"compare{i}.json")) @@ -45,30 +47,31 @@ class TestGraphService(unittest.TestCase): last_call_args = mock_log_info.call_args[0][0] self.assertIn(log_info, last_call_args) matches = re.findall(self.pattern, last_call_args) - self.assertTrue(os.path.exists(os.path.join(self.output, matches[0]))) + if matches: + self.assertTrue(os.path.exists(os.path.join(self.output, matches[0]))) @patch('msprobe.core.common.log.logger.info') - def test_compare_graph(self, mock_log_info): + def test_compare_graph_result(self, mock_log_info): args = Args(output_path=self.output, framework='pytorch') - result = _compare_graph(self.input_param, args) + result = _compare_graph_result(self.input_param, args) self.assertEqual(mock_log_info.call_count, 2) self.assertIsNotNone(result) args = Args(output_path=self.output, framework='mindspore') - result = _compare_graph(self.input_param, args) + result = _compare_graph_result(self.input_param, args) self.assertIsNotNone(result) args = Args(output_path=self.output, framework='pytorch', layer_mapping=self.layer_mapping) - result = _compare_graph(self.input_param, args) + result = _compare_graph_result(self.input_param, args) self.assertIsNotNone(result) args = Args(output_path=self.output, framework='pytorch', overflow_check=True) - result = _compare_graph(self.input_param, args) + result = _compare_graph_result(self.input_param, args) self.assertIsNotNone(result) @patch('msprobe.core.common.log.logger.info') - def test_build_graph(self, mock_log_info): - result = _build_graph(os.path.join(self.input, 'step0', 'rank0'), Args(overflow_check=True)) + def test_build_graph_result(self, mock_log_info): + result = _build_graph_result(os.path.join(self.input, 'step0', 'rank0'), Args(overflow_check=True)) self.assertEqual(mock_log_info.call_count, 1) self.assertIsNotNone(result) @@ -81,7 +84,7 @@ class TestGraphService(unittest.TestCase): } args = Args(output_path=self.output, framework='pytorch') _compare_graph_ranks(input_param, args) - self.assert_log_info(mock_log_info) + self.assert_log_info(mock_log_info, 'Successfully exported compare graph results.') input_param1 = { 'npu_path': os.path.join(self.input, 'step0'), @@ -101,7 +104,7 @@ class TestGraphService(unittest.TestCase): } args = Args(output_path=self.output, framework='pytorch') _compare_graph_steps(input_param, args) - self.assert_log_info(mock_log_info) + self.assert_log_info(mock_log_info, 'Successfully exported compare graph results.') input_param1 = { 'npu_path': self.input, @@ -115,12 +118,12 @@ class TestGraphService(unittest.TestCase): @patch('msprobe.core.common.log.logger.info') def test_build_graph_ranks(self, mock_log_info): _build_graph_ranks(os.path.join(self.input, 'step0'), Args(output_path=self.output)) - self.assert_log_info(mock_log_info, "Model graph built successfully, the result file is saved in") + self.assert_log_info(mock_log_info, "Successfully exported build graph results.") @patch('msprobe.core.common.log.logger.info') def test_build_graph_steps(self, mock_log_info): _build_graph_steps(self.input, Args(output_path=self.output)) - self.assert_log_info(mock_log_info, "Model graph built successfully, the result file is saved in") + self.assert_log_info(mock_log_info, "Successfully exported build graph results.") @patch('msprobe.core.common.log.logger.info') def test_graph_service_command(self, mock_log_info): @@ -129,8 +132,10 @@ class TestGraphService(unittest.TestCase): args = Args(input_path=self.output_json[0], output_path=self.output, framework='pytorch') _graph_service_command(args) - self.assert_log_info(mock_log_info) + self.assert_log_info(mock_log_info, 'Exporting compare graph result successfully, the result file is saved in') + @patch('msprobe.core.common.log.logger.info') + def test_graph_service_command1(self, mock_log_info): input_param1 = { 'npu_path': os.path.join(self.input, 'step0', 'rank0'), 'is_print_compare_log': True @@ -139,8 +144,10 @@ class TestGraphService(unittest.TestCase): json.dump(input_param1, f, indent=4) args = Args(input_path=self.output_json[1], output_path=self.output, framework='pytorch') _graph_service_command(args) - self.assert_log_info(mock_log_info, "Model graph built successfully, the result file is saved in") + self.assert_log_info(mock_log_info, "Model graph exported successfully, the result file is saved in") + @patch('msprobe.core.common.log.logger.info') + def test_graph_service_command2(self, mock_log_info): input_param2 = { 'npu_path': os.path.join(self.input, 'step0'), 'bench_path': os.path.join(self.input, 'step0'), @@ -150,8 +157,10 @@ class TestGraphService(unittest.TestCase): json.dump(input_param2, f, indent=4) args = Args(input_path=self.output_json[2], output_path=self.output, framework='pytorch') _graph_service_command(args) - self.assert_log_info(mock_log_info) + self.assert_log_info(mock_log_info, 'Successfully exported compare graph results.') + @patch('msprobe.core.common.log.logger.info') + def test_graph_service_command3(self, mock_log_info): input_param3 = { 'npu_path': self.input, 'bench_path': self.input, @@ -161,8 +170,10 @@ class TestGraphService(unittest.TestCase): json.dump(input_param3, f, indent=4) args = Args(input_path=self.output_json[3], output_path=self.output, framework='pytorch') _graph_service_command(args) - self.assert_log_info(mock_log_info) + self.assert_log_info(mock_log_info, 'Successfully exported compare graph results.') + @patch('msprobe.core.common.log.logger.info') + def test_graph_service_command4(self, mock_log_info): input_param4 = { 'npu_path': os.path.join(self.input, 'step0'), 'is_print_compare_log': True @@ -171,8 +182,10 @@ class TestGraphService(unittest.TestCase): json.dump(input_param4, f, indent=4) args = Args(input_path=self.output_json[4], output_path=self.output, framework='pytorch') _graph_service_command(args) - self.assert_log_info(mock_log_info, "Model graph built successfully, the result file is saved in") + self.assert_log_info(mock_log_info, "Successfully exported build graph results.") + @patch('msprobe.core.common.log.logger.info') + def test_graph_service_command5(self, mock_log_info): input_param5 = { 'npu_path': self.input, 'is_print_compare_log': True @@ -181,8 +194,10 @@ class TestGraphService(unittest.TestCase): json.dump(input_param5, f, indent=4) args = Args(input_path=self.output_json[5], output_path=self.output, framework='pytorch') _graph_service_command(args) - self.assert_log_info(mock_log_info, "Model graph built successfully, the result file is saved in") + self.assert_log_info(mock_log_info, "Successfully exported build graph results.") + @patch('msprobe.core.common.log.logger.info') + def test_graph_service_command6(self, mock_log_info): input_param6 = { 'npu_path': self.input, 'bench_path': os.path.join(self.input, 'step0'), diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/test_visualization_utils.py b/debug/accuracy_tools/msprobe/test/visualization_ut/test_visualization_utils.py index e5b0afaadf9def910c248b945ad15084300a65c0..41ea145208dc658a83bb5c791d6b05a0abb30616 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/test_visualization_utils.py +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/test_visualization_utils.py @@ -1,7 +1,7 @@ import os import unittest from msprobe.visualization.utils import (load_json_file, load_data_json_file, str2float, check_directory_content, - GraphConst) + GraphConst, SerializableArgs) class TestMappingConfig(unittest.TestCase): @@ -37,6 +37,21 @@ class TestMappingConfig(unittest.TestCase): input_type = check_directory_content(os.path.join(self.input, "step0", "rank0")) self.assertEqual(input_type, GraphConst.FILES) + def test_serializable_args(self): + class TmpArgs: + def __init__(self, a, b, c): + self.a = a + self.b = b + self.c = c + input_args1 = TmpArgs('a', 123, [1, 2, 3]) + serializable_args1 = SerializableArgs(input_args1) + self.assertEqual(serializable_args1.__dict__, input_args1.__dict__) + input_args2 = TmpArgs('a', 123, lambda x: print(x)) + serializable_args2 = SerializableArgs(input_args2) + self.assertNotEqual(serializable_args2.__dict__, input_args2.__dict__) + + + if __name__ == '__main__': unittest.main() diff --git a/debug/accuracy_tools/msprobe/visualization/builder/graph_builder.py b/debug/accuracy_tools/msprobe/visualization/builder/graph_builder.py index 814882e6b819e9e6b6b421aec5f8f0b89f03f7c6..1312d244f71a535a56a842b053fff9a525daea44 100644 --- a/debug/accuracy_tools/msprobe/visualization/builder/graph_builder.py +++ b/debug/accuracy_tools/msprobe/visualization/builder/graph_builder.py @@ -14,24 +14,28 @@ # limitations under the License. import re +from dataclasses import dataclass from msprobe.core.common.const import Const -from msprobe.core.common.file_utils import load_json +from msprobe.core.common.file_utils import load_json, save_json +from msprobe.core.common.utils import load_stack_json +from msprobe.core.common.log import logger from msprobe.visualization.builder.msprobe_adapter import get_input_output from msprobe.visualization.builder.msprobe_adapter import op_patterns from msprobe.visualization.graph.graph import Graph from msprobe.visualization.graph.node_op import NodeOp -from msprobe.visualization.utils import save_json_file, GraphConst +from msprobe.visualization.utils import GraphConst +from msprobe.visualization.db_utils import node_to_db, config_to_db class GraphBuilder: backward_pattern = re.compile(r"(\.backward\.)(\d+)$") forward_pattern = re.compile(r"(\.forward\.)(\d+)$") - # 匹配以大写字母开头,后接任意字母,并以Template(结尾 - template_pattern = re.compile(r'\b[A-Z][a-zA-Z]*Template\(') + # 匹配以大写字母开头,后接任意字母,并以Template(结尾,或包含api_template(的字符串 + template_pattern = re.compile(r'\b([A-Z][a-zA-Z]*Template|api_template|api_instance)\(') @staticmethod - def build(construct_path, data_path, stack_path, model_name='DefaultModel', complete_stack=False): + def build(construct_path, data_path, stack_path, model_name='DefaultModel'): """ GraphBuilder的对外提供的构图方法 Args: @@ -39,74 +43,35 @@ class GraphBuilder: data_path: dump.json路径 stack_path: stack.json路径 model_name: 模型名字,依赖外部输入 - complete_stack: 完整的堆栈信息 Returns: Graph,代表图的数据结构 """ construct_dict = load_json(construct_path) + if not construct_dict: + logger.error("The content of 'construct.json' is empty, failed to build graph. " + "When dumping data, it is necessary to select level L0 or mix in order to " + "collect model structure data, that is, the content of 'construct.json' is not empty.") + raise RuntimeError dump_dict = load_json(data_path) - stack_dict = load_json(stack_path) - if not complete_stack: - GraphBuilder._simplify_stack(stack_dict) + stack_dict = load_stack_json(stack_path) data_dict = dump_dict.get(GraphConst.DATA_KEY, {}) graph = Graph(model_name, data_path=dump_dict.get('dump_data_dir', ''), dump_data=data_dict) GraphBuilder._init_nodes(graph, construct_dict, data_dict, stack_dict) GraphBuilder._collect_apis_between_modules(graph) + GraphBuilder._add_parameters_grad(graph, data_dict) return graph @staticmethod - def to_json(filename, config): - """ - 将graph导出成.vis文件的接口 - """ - result = {} + def to_db(filename, config): + config.graph_n.step = config.step + config.graph_n.rank = config.rank + config.graph_n.compare_mode = config.compare_mode + node_to_db(config.graph_n, filename) if config.graph_b: - result[GraphConst.JSON_NPU_KEY] = config.graph_n.to_dict() - result[GraphConst.JSON_BENCH_KEY] = config.graph_b.to_dict() - else: - result = config.graph_n.to_dict() - if config.tool_tip: - result[GraphConst.JSON_TIP_KEY] = config.tool_tip - if config.node_colors: - result[GraphConst.COLORS] = config.node_colors - if config.micro_steps: - result[GraphConst.MICRO_STEPS] = config.micro_steps - if config.task: - result[GraphConst.JSON_TASK_KEY] = config.task - result[GraphConst.OVERFLOW_CHECK] = config.overflow_check - save_json_file(filename, result) - - @staticmethod - def _simplify_stack(stack_dict): - """ - 精简堆栈内容,模块级保留包含"模块名("的堆栈,api级保留"xxxTemplate("的下一行堆栈 - - 例如模块 Module.layer3.0.bn2.BatchNorm2d.forward.0,模块名为bn2,匹配"bn2(", - 保留堆栈"File /home/models/resnet.py, line 97, in forward, \n out = self.bn2(out)" - - 例如Api Tensor.__iadd__.4.forward,堆栈为: - "File /home/wrap_tensor.py, line 61, return TensorOPTemplate(op_name, hook)(*args, **kwargs)", - "File /home/torchvision/models/resnet.py, line 102, in forward, \n out += identity", - 匹配到第一行的"TensorOPTemplate(",保留下一行堆栈 - """ - module_pattern = re.compile(op_patterns[0]) - for dump_name, stack_list in stack_dict.items(): - if not isinstance(stack_list, list): - continue - if module_pattern.match(dump_name): - parts = dump_name.split(Const.SEP) - if len(parts) < abs(Const.LAYER_NAME_INDEX): - continue - module_name = parts[Const.LAYER_NAME_INDEX] - for stack in stack_list: - if re.search(module_name + r'\(', stack): - stack_list = [stack] - break - else: - for index, stack in enumerate(stack_list): - if GraphBuilder.template_pattern.search(stack) and index < len(stack_list) - 1: - stack_list = [stack_list[index + 1]] - break - stack_dict[dump_name] = stack_list + config.graph_b.data_source = GraphConst.JSON_BENCH_KEY + config.graph_b.step = config.step + config.graph_b.rank = config.rank + node_to_db(config.graph_b, filename) + config_to_db(config, filename) @staticmethod def _handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id): @@ -186,6 +151,8 @@ class GraphBuilder: # 数据格式:"output": [[{param1}, {param2}, ...]] if GraphBuilder._is_valid_batch_p2p_output(param_list): for param in param_list[0]: + if not isinstance(param, dict): + continue info = {GraphConst.OP: param.get(GraphConst.OP), GraphConst.PEER: param.get(GraphConst.PEER), GraphConst.GROUP_ID: param.get(GraphConst.GROUP_ID)} node.batch_p2p_info.append(info) @@ -235,10 +202,46 @@ class GraphBuilder: graph.root.subnodes = output + @staticmethod + def _add_parameters_grad(graph, data_dict): + """ + 将parameters_grad信息添加到graph中, + 对应模块的parameters_grad节点添加到对应模块的最后一次backward节点(backward计数最大)内作为子节点 + + 例如,graph有节点Module.a.backward.0, Module.a.backward.1, Module.a.backward.2 + 则Module.a.parameters_grad添加在Module.a.backward.2内作为子节点 + """ + prefixes = [] + suffix = Const.SEP + Const.PARAMS_GRAD + for node_id in data_dict.keys(): + if node_id not in graph.node_map and node_id.endswith(suffix): + prefixes.append(node_id.replace(suffix, '')) + + max_info = {prefix: 0 for prefix in prefixes} + + for key in graph.node_map.keys(): + parts = key.split(Const.SEP) + if len(parts) > 2 and parts[-2] == Const.BACKWARD: + num = int(parts[-1]) + prefix = Const.SEP.join(parts[:-2]) + if prefix in max_info and num > max_info[prefix]: + max_info[prefix] = num + + for prefix, num in max_info.items(): + node_id = prefix + Const.SEP + Const.BACKWARD + Const.SEP + str(num) + node = graph.get_node(node_id) + if node: + parameters_grad_node_id = graph.add_node(NodeOp.module, prefix + suffix, up_node=node) + # 添加输入输出数据 + node_data = data_dict.get(parameters_grad_node_id, {}) + input_data, output_data = get_input_output(node_data, parameters_grad_node_id) + # 更新数据 + graph.get_node(parameters_grad_node_id).set_input_output(input_data, output_data) + class GraphExportConfig: def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None, task='', - overflow_check=False): + overflow_check=False, compare_mode=None, step=0, rank=0, step_list=None, rank_list=None): self.graph_n = graph_n self.graph_b = graph_b self.tool_tip = tool_tip @@ -246,3 +249,25 @@ class GraphExportConfig: self.micro_steps = micro_steps self.task = task self.overflow_check = overflow_check + self.compare_mode = compare_mode + self.step = step + self.rank = rank + self.step_list = step_list + self.rank_list = rank_list + + +@dataclass +class GraphInfo: + graph: Graph + construct_path: str + data_path: str + stack_path: str + + +@dataclass +class BuildGraphTaskInfo: + graph_info_n: GraphInfo + graph_info_b: GraphInfo + npu_rank: str + bench_rank: str + time_str: str diff --git a/debug/accuracy_tools/msprobe/visualization/builder/graph_merger.py b/debug/accuracy_tools/msprobe/visualization/builder/graph_merger.py new file mode 100644 index 0000000000000000000000000000000000000000..f1a2a6ca32c94285f7584a72f8106ff3e81b3bdd --- /dev/null +++ b/debug/accuracy_tools/msprobe/visualization/builder/graph_merger.py @@ -0,0 +1,987 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import math + +from msprobe.core.common.const import Const +from msprobe.visualization.graph.graph import Graph, BaseNode +from msprobe.visualization.graph.node_op import NodeOp +from msprobe.core.common.log import logger +from msprobe.visualization.utils import GraphConst +from msprobe.core.common.decorator import recursion_depth_decorator +from msprobe.core.common.parallel_state import get_tp_pp_default_groups + +MAX_INFO = 'The Max value merging method for ' +MIN_INFO = 'The Min value merging method for ' +MEAN_INFO = 'The Mean value merging method for ' +NORM_INFO = 'The Norm value merging method for ' + + +class GraphMerger: + def __init__(self, build_graph_results, parallel_param, is_bench=False): + self.strategy = self._select_strategy(build_graph_results, parallel_param, is_bench) + + @staticmethod + def _select_strategy(results, param, is_bench): + if param.tp == param.pp == param.rank_size == 1: + return NoParallelMerger(results, param, is_bench) + elif param.tp == param.rank_size: + return TPMerger(results, param, is_bench) + elif param.pp == param.rank_size: + return PPMerger(results, param, is_bench) if param.vpp == 1 else VPPMerger(results, param, is_bench) + elif param.pp == 1: + return TPMerger(results, param, is_bench) + elif param.tp == 1: + return PPMerger(results, param, is_bench) if param.vpp == 1 else VPPMerger(results, param, is_bench) + elif param.tp * param.pp == param.rank_size: + return TPPPMerger(results, param, is_bench) + else: + return FullMerger(results, param, is_bench) + + def merge_graph(self): + return self.strategy.merge_graphs() + + +class BaseGraphMerger: + def __init__(self, build_graph_results, parallel_param, is_bench): + self.unmerged_module = [Const.CLIP_GRAD, Const.OPTIMIZER] + self.dtype_list = Const.TORCH_INT_DTYPE + Const.TORCH_FLOAT_DTYPE + [Const.FLOAT16, Const.FLOAT32, + Const.BFLOAT16] + self.build_graph_results = build_graph_results + self.parallel_param = parallel_param + self.is_bench = is_bench + self.log_prefix = '[Bench]' if self.is_bench else '[NPU]' + self._add_all_nodes_rank() + + @staticmethod + def sort_merged_api_collection(graph): + def extract_rank(node): + match = re.search(r'_Rank(\d+)', node.id) + return int(match.group(1)) if match else None + + for sub_node in graph.root.subnodes: + if sub_node.op == NodeOp.api_collection and sub_node.id.startswith( + GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS): + sub_node.subnodes = sorted(sub_node.subnodes, key=extract_rank) + + @staticmethod + def _update_node_data_key(old_id, new_id, data_dict): + new_dict = {} + for key, value in data_dict.items(): + new_key = key.replace(old_id, new_id) + if 'full_op_name' in value: + value['full_op_name'] = value.get('full_op_name').replace(old_id, new_id) + new_dict[new_key] = value + return new_dict + + @staticmethod + def _compare_value_same(main_value, other_value, has_uncertainty=False): + if not isinstance(main_value, (int, float)) or not isinstance(other_value, (int, float)): + return True + # 没开启确定性计算,各rank的mean和norm有细微差异,如果相对误差在阈值内则认为是相同的 + if has_uncertainty: + diff = abs(main_value - other_value) + if math.isnan(diff): + return math.isnan(main_value) and math.isnan(other_value) + elif math.isinf(diff): + return math.isinf(main_value) and math.isinf(other_value) + else: + return diff < GraphConst.UNCERTAINTY_THRESHOLD if main_value == 0 else \ + abs(diff / main_value) < GraphConst.UNCERTAINTY_THRESHOLD + else: + return main_value == other_value + + def merge_graphs(self): + raise NotImplementedError("This method should be implemented by subclasses.") + + def merge_graph_api_collection(self, results: list): + """ + graph合并时,将各rank的游离api集合合并为一个总的游离api集合 + example: + rank0: Apis_Between_Modules.0 rank1: Apis_Between_Modules.0 + Module.module.Float16Module.forward.0 Module.module.Float16Module.forward.0 + Apis_Between_Modules.1 Apis_Between_Modules.1 + + merged: Apis_Between_Modules_All_Ranks.0 + |_ Apis_Between_Modules_Rank0.0 + |_ Apis_Between_Modules_Rank1.0 + Module.module.Float16Module.forward.0 + Apis_Between_Modules_All_Ranks.1 + |_ Apis_Between_Modules_Rank0.1 + |_ Apis_Between_Modules_Rank1.1 + """ + main_graph_result = results[0] + main_root_sub_nodes = main_graph_result.graph.root.subnodes + new_main_root_sub_nodes = [] + for main_node in main_root_sub_nodes: + # 如果游离api集合已合并为一个总的游离api集合,总的游离api集合之间还要再合并 + if main_node.id.startswith(GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS): + new_main_root_sub_nodes.append(main_node) + for other_graph_result in results[1:]: + other_node = other_graph_result.graph.get_node(main_node.id) + if not other_node: + continue + for sub_node in other_node.subnodes: + sub_node.upnode = main_node + main_graph_result.graph.node_map[sub_node.id] = sub_node + for sub_sub_node in sub_node.subnodes: + main_graph_result.graph.node_map[sub_sub_node.id] = sub_sub_node + main_node.subnodes.extend(other_node.subnodes) + # 游离api集合合并为一个总的游离api集合 + elif main_node.id.startswith(GraphConst.APIS_BETWEEN_MODULES): + all_collection_node_id = main_graph_result.graph.add_node(NodeOp.api_collection, + GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS, + id_accumulation=True) + all_collection_node = main_graph_result.graph.get_node(all_collection_node_id) + new_main_root_sub_nodes.append(all_collection_node) + # Apis_Between_Modules.0 --> Apis_Between_Modules_Rank0.0 + origin_main_node_id = main_node.id + main_node.id = GraphConst.APIS_BETWEEN_MODULES + f'_Rank{main_graph_result.rank}.' + \ + main_node.id.split(Const.SEP)[-1] + all_collection_node.subnodes = [main_node] + main_node.upnode = all_collection_node + main_graph_result.graph.node_map[main_node.id] = main_node + del main_graph_result.graph.node_map[origin_main_node_id] + for other_graph_result in results[1:]: + other_node = other_graph_result.graph.get_node(origin_main_node_id) + if not other_node: + continue + # Apis_Between_Modules.0 --> Apis_Between_Modules_Rank1.0 + other_node.id = GraphConst.APIS_BETWEEN_MODULES + f'_Rank{other_graph_result.rank}.' + \ + other_node.id.split(Const.SEP)[-1] + main_graph_result.graph.node_map[other_node.id] = other_node + for sub_node in other_node.subnodes: + # api节点,在api名称上添加rank信息 + old_id = sub_node.id + parts = sub_node.id.split(Const.SEP) + parts[1] += f'_rank{other_graph_result.rank}' + sub_node.id = Const.SEP.join(parts) + sub_node.input_data = self._update_node_data_key(old_id, sub_node.id, sub_node.input_data) + sub_node.output_data = self._update_node_data_key(old_id, sub_node.id, sub_node.output_data) + main_graph_result.graph.node_map[sub_node.id] = sub_node + all_collection_node.subnodes.append(other_node) + other_node.upnode = all_collection_node + else: + new_main_root_sub_nodes.append(main_node) + main_graph_result.graph.root.subnodes = new_main_root_sub_nodes + + def split_graph_results_by_groups(self, groups): + """ + 基于pp或tp域,划分待合并的graph + """ + rank_results_mapping = {result.rank: result for result in self.build_graph_results} + return [[rank_results_mapping.get(rank) for rank in ranks] for ranks in groups] + + def compare_node_param_data(self, main_node, other_nodes, compare_data=True): + """ + 当前节点与若干其他节点比较输入输出参数的数据是否一致,如果发现有不一致的参数,将参数暂存于列表中 + :param main_node: 当前节点 + :param other_nodes: 其他节点列表 + :param compare_data: 是否进行数据比对,如果compare_data=False则直接认为数据不一致 + :return: 输入不一致的参数dict,输出不一致的参数dict,两个dict都为空列表代表两个节点的输入输出完全一致 + """ + if not other_nodes: + return {}, {} + data_types = {'input_data': {}, 'output_data': {}} + for data_type, data_dict in data_types.items(): + main_data_dict = getattr(main_node, data_type) + for key, main_param in main_data_dict.items(): + same_flag = compare_data + if main_param.get(Const.DTYPE) not in self.dtype_list: + continue + tp_need_merge_params = [main_param] + for other_node in other_nodes: + param_key = key.replace(main_node.id, other_node.id) if main_node.id != other_node.id else key + other_param = getattr(other_node, data_type).get(param_key, {}) + if other_param.get(Const.DTYPE) not in self.dtype_list: + break + tp_need_merge_params.append(other_param) + if compare_data and not self.compare_param_same(main_param, other_param, has_uncertainty=True): + same_flag = False + if not same_flag: + # {input.0: [{"Max": 0, "Min": 0, ...}, {"Max": 0.1, "Min": 0, ...}, ...]} + data_dict[key.replace(main_node.id + Const.SEP, '')] = tp_need_merge_params + return data_types.get('input_data'), data_types.get('output_data') + + def compare_param_same(self, main_param, other_param, has_uncertainty=False): + if not self._compare_value_same(main_param.get(Const.MAX), other_param.get(Const.MAX)): + return False + if not self._compare_value_same(main_param.get(Const.MIN), other_param.get(Const.MIN)): + return False + if not self._compare_value_same(main_param.get(Const.MEAN), other_param.get(Const.MEAN), has_uncertainty): + return False + if not self._compare_value_same(main_param.get(Const.NORM), other_param.get(Const.NORM), has_uncertainty): + return False + return True + + def get_default_groups(self): + """ + 根据GPU总数、TP数、PP数初始化并行组 + + return: + tp_groups: 张量并行组列表,每个元素是一个包含组内rank的列表 + pp_groups: 流水线并行组列表,每个元素是一个包含组内rank的列表 + """ + tp_groups, pp_groups = get_tp_pp_default_groups(self.parallel_param.rank_size, self.parallel_param.tp, + self.parallel_param.pp, order=self.parallel_param.order) + + return tp_groups, pp_groups + + def _add_all_nodes_rank(self): + for result in self.build_graph_results: + for node in result.graph.node_map.values(): + node.rank = result.rank + + +class PPMerger(BaseGraphMerger): + LAYERS_PATTERN = re.compile(r"(layers\.|layer\.)\d+(\.)") + MARK_PATTERN = re.compile(r"%(\d+)%(\d+)$") + MARK = '%' + + @staticmethod + def _trace_p2p_mapping(p2p_mapping: dict): + """ + 将字典分组为独立的链,每个链都从未访问过的键开始,按照字典中的映射关系进行追踪 + p2p_mapping内容为p2p通信的send映射,追踪映射关系建立pp域 + example: p2p_mapping={0: 2, 1: 3, 2: 4, 3: 5, 4: 6, 5: 7, 6: 4, 7: 5}, return=[[0, 2, 4, 6], [1, 3, 5, 7]] + """ + visited = set() + result = [] + + def collect_keys(start_key): + """ + 追踪从某一个键开始的所有“连续”键,直到无法再找到下一个键为止 + """ + current_key = start_key + chain = [] + while current_key in p2p_mapping and current_key not in visited: + chain.append(current_key) + visited.add(current_key) + current_key = p2p_mapping[current_key] + return chain + + for key in p2p_mapping: + if key not in visited: + chain_result = collect_keys(key) + if chain_result: + result.append(chain_result) + return result + + @recursion_depth_decorator("msprobe.visualization.builder.graph_merger.PPMerger._merge_nodes", 1000) + def _merge_nodes(self, main_graph, main_node, other_graphs): + """ + 其他rank graph中被pp切分的节点,需要合并到main graph + """ + other_nodes = [] + for other_graph in other_graphs: + other_node = other_graph.get_node(main_node.id) + # 表明此节点只有main graph有 + if not other_node: + other_nodes.clear() + return + other_nodes.append(other_node) + if other_nodes: + param_in, param_out = self.compare_node_param_data(main_node, other_nodes) + # 各个rank都有的模块,且输入输出都不一致,且节点id符合正则,判定为被pp切分的模块,需要合并结构 + pp_merged_condition = param_in and param_out and self.LAYERS_PATTERN.search(main_node.id) + # backward可能没有output,是否要pp合并从对应的forward节点判断 + if Const.SEP + Const.BACKWARD + Const.SEP in main_node.id: + f_node = main_graph.node_map.get( + main_node.id.replace(Const.SEP + Const.BACKWARD + Const.SEP, Const.SEP + Const.FORWARD + Const.SEP)) + if f_node and hasattr(f_node, 'is_pp_merged'): + pp_merged_condition = True + if pp_merged_condition: + main_node.is_pp_merged = True + main_up_node = main_node.upnode + for other_node in other_nodes: + # pp切分中被切分的层在各rank的名称是一样的,这里给其他rank的同名层增加位置和rank标记 + self._mark_node_id_position_rank(other_node, other_node.rank) + self._add_node_to_main_graph(main_graph, other_node) + # 其他rank被pp切分的模块节点添加到当前rank的graph + other_node.upnode = main_up_node + main_up_node.subnodes.append(other_node) + # 已找到被pp切分的模块节点,不再递归其内部 + return + # 各个rank都有的forward模块,且输入一致,输出不一致,判定为模块内部包含被pp切分的模块,此模块的输出要使用最后一个rank的输出 + elif not param_in and param_out and Const.SEP + Const.FORWARD + Const.SEP in main_node.id: + main_node.output_data = other_nodes[-1].output_data + # 各个rank都有的backward模块,且输出一致,输入不一致,判定为模块内部包含被pp切分的模块,此模块的输入要使用最后一个rank的输入 + elif param_in and not param_out and Const.SEP + Const.BACKWARD + Const.SEP in main_node.id: + main_node.input_data = other_nodes[-1].input_data + self._merge_other_unique_nodes(main_graph, main_node, other_nodes) + for sub_node in main_node.subnodes: + if sub_node.op == NodeOp.module: + self._merge_nodes(main_graph, sub_node, other_graphs) + + def merge_graphs(self): + results_groups = self.split_graph_results_by_groups(self.get_groups()) + results = [] + for result_groups in results_groups: + self.merge_graph_api_collection(result_groups) + results.extend(self.merge_pp_graphs(result_groups)) + return results + + def merge_pp_graphs(self, results): + if not results or len(results) < 2: + return results + graphs = [x.graph for x in results] + main_graph_result = results[0] + for main_node in main_graph_result.graph.root.subnodes: + if main_node.op == NodeOp.module and main_node.id not in self.unmerged_module: + self._merge_nodes(main_graph_result.graph, main_node, graphs[1:]) + self._sort_nodes(main_graph_result.graph, main_node) + return [main_graph_result] + + def get_groups(self): + """ + 在各rank寻找p2p通信节点,建立各rank之间p2p的映射关系 + """ + p2p_mapping = {} + for result in self.build_graph_results: + rank = result.rank + pp_rank = None + for node in result.graph.node_map.values(): + if not node.id.startswith(Const.DISTRIBUTED + Const.SEP): + continue + if '.batch_isend_irecv.' in node.id: + for p2p_info in node.batch_p2p_info: + target_rank = p2p_info.get(GraphConst.PEER) + if target_rank is not None and target_rank != rank and p2p_info.get(GraphConst.OP) == 'isend': + pp_rank = target_rank + break + elif '.send.' in node.id or '.isend.' in node.id: + # example: Distributed.isend.0.forward --> Distributed.isend.0.forward.input.dst + dst_kwarg = f'{node.id}{Const.SEP}{Const.INPUT}{Const.SEP}{GraphConst.DST}' + dst = node.input_data.get(dst_kwarg, {}).get('value') + if dst is not None: + pp_rank = dst + break + if pp_rank is not None: + break + if pp_rank is not None: + p2p_mapping[rank] = pp_rank + pp_groups = self._trace_p2p_mapping(p2p_mapping) + if not pp_groups: + logger.info('Unable to get pp groups based on Distributed Api (batch_isend_irecv, send, or isend), ' + 'generate pp groups using parallel param "rank_size", "tp" and "pp".') + _, pp_groups = self.get_default_groups() + logger.info(f'{self.log_prefix} All pp groups is {pp_groups}.') + return pp_groups + + def _merge_other_unique_nodes(self, main_graph, main_node, other_nodes): + """ + 其他rank graph中other_node的子节点列表如果包含独有的节点,需要合并到main graph + """ + lists = [main_node.subnodes] + for other_node in other_nodes: + lists.append(other_node.subnodes) + dicts = [{node.id: node for node in lst} for lst in lists] + unique_node_ids = {} + # 计算每个集合的独有元素 + for i, current_dict in enumerate(dicts): + other_ids = set() + for j, other_dict in enumerate(dicts): + if i != j: + # 更新并集,添加当前遍历到的集合的元素 + other_ids.update(other_dict.keys()) + result = set(current_dict.keys()) - other_ids + if i != 0 and result: + # 计算当前集合与其他集合并集的差集,即独有元素,保持原始顺序 + unique_node_ids[i] = [node_id for node_id in current_dict if node_id in result] + unique_nodes = [] + if unique_node_ids: + for i, items in unique_node_ids.items(): + for item in items: + unique_nodes.append(dicts[i].get(item)) + if unique_nodes: + for unique_node in unique_nodes: + self._mark_node_id_position_rank(unique_node, unique_node.rank) + self._add_node_to_main_graph(main_graph, unique_node) + main_node.subnodes.append(unique_node) + unique_node.upnode = main_node + + def _sort_nodes(self, main_graph, start_node): + stack = [start_node] + while stack: + node = stack.pop() + if self.MARK_PATTERN.search(node.id): + is_forward = (Const.SEP + Const.FORWARD + Const.SEP in node.id or + Const.SEP + Const.FORWARD + self.MARK in node.id) + new_sub_nodes1, new_sub_nodes2 = [], [] + for item in node.upnode.subnodes: + new_sub_nodes2.append(item) if self.MARK_PATTERN.search(item.id) else new_sub_nodes1.append(item) + + order = True if is_forward else False + new_sub_nodes2.sort(key=lambda n: self._get_node_sort_rule(n, rank_ascending=order)) + new_sub_nodes = new_sub_nodes1 + new_sub_nodes2 if is_forward else new_sub_nodes2 + new_sub_nodes1 + + index = -1 + node_iter = new_sub_nodes if is_forward else reversed(new_sub_nodes) + for item in node_iter: + if self.LAYERS_PATTERN.search(item.id): + index += 1 + if self.MARK_PATTERN.search(item.id): + item.pp_index = index + + for item in new_sub_nodes2: + self._update_node_id(main_graph, item) + + node.upnode.subnodes = new_sub_nodes + + stack.extend(node.subnodes) + + def _add_node_to_main_graph(self, main_graph: Graph, node: BaseNode): + if node.id in main_graph.node_map: + logger.warning(f'{node.id} is exist!') + else: + main_graph.node_map[node.id] = node + for sub_node in node.subnodes: + self._add_node_to_main_graph(main_graph, sub_node) + + def _get_node_sort_rule(self, node, rank_ascending=True): + match = self.MARK_PATTERN.search(node.id) + if match: + # position代表当前节点在父节点中的位置序号 + position, rank = int(match.group(1)), int(match.group(2)) + if rank_ascending: + return rank, position + else: + return -rank, position + return (float('inf'), float('inf')) if rank_ascending else (-float('inf'), -float('inf')) + + def _mark_node_id_position_rank(self, node: BaseNode, rank): + position = 0 + for index, item in enumerate(node.upnode.subnodes): + if item.id == node.id: + position = index + break + # 各rank重复节点添加所处层级位置排序信息position和rank号,用%分隔 + node.id = node.id + f'{self.MARK}{position}' + f'{self.MARK}{rank}' + for sub_node in node.subnodes: + self._mark_node_id_position_rank(sub_node, rank) + + def _update_node_id(self, graph, start_node: BaseNode, pp_index=""): + stack = [(start_node, pp_index)] + while stack: + node, pp_index = stack.pop() + # 修改节点id之前删除node_map的信息,修改完再添加回去 + if node.id not in graph.node_map: + logger.warning(f'Update node id {node.id} fail!') + else: + del graph.node_map[node.id] + old_id = self.MARK_PATTERN.sub("", node.id) + if node.op == NodeOp.module: + # 被pp切分的模块节点,基于位置和rank信息修改模块名称计数信息 + if self.LAYERS_PATTERN.search(node.id) and self.MARK_PATTERN.search(node.id): + if hasattr(node, 'pp_index'): + pp_index = str(node.pp_index) + node.id = self.LAYERS_PATTERN.sub(r"\g<1>" + pp_index + r"\g<2>", node.id) + else: + # api节点,在api名称上添加rank信息 + parts = node.id.split(Const.SEP) + parts[1] += f'_rank{node.id.split(PPMerger.MARK)[-1]}' + node.id = Const.SEP.join(parts) + # 把之前添加的位置和rank信息删掉 + node.id = self.MARK_PATTERN.sub("", node.id) + # node id更新了,那么data的key中包含node id也要更新 + node.input_data = self._update_node_data_key(old_id, node.id, node.input_data) + node.output_data = self._update_node_data_key(old_id, node.id, node.output_data) + graph.node_map[node.id] = node + # 将子节点加入栈中 + for sub_node in node.subnodes: + stack.append((sub_node, pp_index)) + + +class TPMerger(BaseGraphMerger): + RANK_PATTERN = re.compile(r"_rank(\d+)\.") + OPERATION_TABLE = { + Const.MAX: { + 'initial': lambda p: p.get(Const.MAX), + 'merge': lambda current, other: max(current, other.get(Const.MAX)), + 'finalize': lambda current, count: current, + 'formula': lambda key, values: f'{MAX_INFO}{key} is: max({", ".join(map(str, values))})' + }, + Const.MIN: { + 'initial': lambda p: p.get(Const.MIN), + 'merge': lambda current, other: min(current, other.get(Const.MIN)), + 'finalize': lambda current, count: current, + 'formula': lambda key, values: f'{MIN_INFO}{key} is: min({", ".join(map(str, values))})' + }, + Const.MEAN: { + 'initial': lambda p: p.get(Const.MEAN), + 'merge': lambda current, other: current + other.get(Const.MEAN), + 'finalize': lambda current, count: current / count, + 'formula': lambda key, values: f'{MEAN_INFO}{key} is: ({" + ".join(map(str, values))}) / {len(values)}' + }, + Const.NORM: { + 'initial': lambda p: pow(p.get(Const.NORM), 2.0), + 'merge': lambda current, other: current + pow(other.get(Const.NORM), 2.0), + 'finalize': lambda current, count: pow(current, 1 / 2.0), + 'formula': lambda key, values: f'{NORM_INFO}{key} is: ({" + ".join([f"{v} ** 2" for v in values])}) ** 0.5' + } + } + TP_MERGED_INFO = f'This data is the merged data after tensor parallelism(TP), and the data is merged from rank ' + + @staticmethod + def _merge_params(tp_need_merge_param: dict): + """ + 合并tp切分的各rank参数统计值 + tp_need_merge_param: {input.0: [{"Max": 0, "Min": 0, ...}, {"Max": 0.1, "Min": 0, ...}, ...]} + return: 计算详情 + """ + merge_info = [] + for key, param_list in tp_need_merge_param.items(): + if len(param_list) < 2: + continue + main_param = param_list[0] + + for stat, ops in TPMerger.OPERATION_TABLE.items(): + current_value = ops['initial'](main_param) + value_list = [current_value if stat != Const.NORM else main_param.get(Const.NORM)] + + for other_param in param_list[1:]: + current_value = ops['merge'](current_value, other_param) + value_list.append(other_param.get(stat) if stat != Const.NORM else other_param.get(Const.NORM)) + + final_value = ops['finalize'](current_value, len(param_list)) + main_param[stat] = final_value + formula_base = f'{ops["formula"](key, value_list)}' + f' = {final_value}' + + merge_info.append(formula_base) + + return merge_info + + @staticmethod + def _get_need_merge_node(main_node, other_graphs, tp_merge_mapping): + """ + 获取需要TP合并的节点列表 + 如果是TP+PP的混合并行,此时数据已经被PP合并过,一些node_id被标记上rank信息,此时需要基于rank映射才能获取到需要TP合并的节点列表,例如: + main_node = Torch.matmul_rank4.32.forward other_node = Torch.matmul_rank5.32.forward + 需要建立4->5的映射,才能基于Torch.matmul_rank4.32.forward找到Torch.matmul_rank5.32.forward + """ + other_nodes = [] + match = TPMerger.RANK_PATTERN.search(main_node.id) + # 节点名称被标记rank信息,且提供了映射 + if match and tp_merge_mapping: + rank = int(match.group(1)) + tp_mapping_ranks = tp_merge_mapping.get(rank) + if not tp_mapping_ranks: + return other_nodes + if len(tp_mapping_ranks) != len(other_graphs): + return other_nodes + for i, graph in enumerate(other_graphs): + # 基于映射得到目标rank,替换node_id当前rank信息后去目标graph取node + tp_mapping_id = TPMerger.RANK_PATTERN.sub(f"_rank{tp_mapping_ranks[i]}.", main_node.id) + other_node = graph.node_map.get(tp_mapping_id) + if not other_node or main_node.get_ancestors() != other_node.get_ancestors(): + other_nodes.clear() + break + other_nodes.append(other_node) + else: + for graph in other_graphs: + other_node = graph.node_map.get(main_node.id) + if not other_node or main_node.get_ancestors() != other_node.get_ancestors(): + other_nodes.clear() + break + other_nodes.append(other_node) + + return other_nodes + + @staticmethod + def _slice_list_at_id(node_list, target_id1, target_id2): + start_index, end_index = -1, -1 + for index, node in enumerate(node_list): + if target_id1 in node.id: + start_index = index + elif target_id2 in node.id: + end_index = index + return [] if start_index == -1 or end_index == -1 else node_list[start_index:end_index + 1] + + def merge_graphs(self): + results_groups = self.split_graph_results_by_groups(self.get_groups()) + results = [] + for result_groups in results_groups: + self.merge_graph_api_collection(result_groups) + results.extend(self.merge_tp_graphs(result_groups)) + return results + + def merge_tp_graphs(self, results, tp_merge_mapping=None): + if not results or len(results) < 2: + return results + graphs = [x.graph for x in results] + main_graph_result = results[0] + for main_node in main_graph_result.graph.node_map.values(): + should_continue = ( + not main_node.upnode or main_node.upnode.op != NodeOp.module or + main_node.upnode.id in self.unmerged_module or main_node.id.startswith(Const.DISTRIBUTED) or + main_node.parallel_merge_info != []) + if should_continue: + continue + self._handle_tp_matmul_reduce(main_node, graphs[1:], tp_merge_mapping) + other_nodes = self._get_need_merge_node(main_node, graphs[1:], tp_merge_mapping) + tp_need_merge_param_in, tp_need_merge_param_out = self.compare_node_param_data(main_node, other_nodes) + if tp_need_merge_param_in or tp_need_merge_param_out: + ranks = [main_node.rank] + for other_node in other_nodes: + ranks.append(other_node.rank) + main_node.parallel_merge_info.append(f'{self.TP_MERGED_INFO}{ranks}.') + merge_info_in = self._merge_params(tp_need_merge_param_in) + merge_info_out = self._merge_params(tp_need_merge_param_out) + main_node.parallel_merge_info.extend(merge_info_in + merge_info_out) + for main_node in main_graph_result.graph.node_map.values(): + self._merge_tp_megatron_column_row_parallel(main_node, graphs[1:], tp_merge_mapping) + return [main_graph_result] + + def get_groups(self): + tp_groups = [] + for result in self.build_graph_results: + for node in result.graph.node_map.values(): + if any(op in node.id for op in GraphConst.REDUCE_OPERATIONS): + group_ranks = node.input_data.get(f'{node.id}.input.group', {}).get('group_ranks') + if group_ranks and group_ranks not in tp_groups: + tp_groups.append(group_ranks) + break + if not tp_groups: + logger.info('Unable to get tp groups based on Distributed Api (reduce_scatter or all_reduce), ' + 'generate tp groups using parallel param "rank_size", "tp" and "pp".') + tp_groups, _ = self.get_default_groups() + logger.info(f'{self.log_prefix} All tp groups is {tp_groups}.') + return tp_groups + + def _handle_tp_matmul_reduce(self, node, other_graphs, tp_merge_mapping): + """ + 前向RowParallel和反向ColumnParallel层的matmul输出需要替换成matmul计算完成后all_reduce/reduce_scatter的输出 + """ + if node.op != NodeOp.module: + return + splits = node.id.split(Const.SEP) + if len(splits) < 4: + return + is_forward_with_row_parallel = splits[-2] == Const.FORWARD and 'RowParallelLinear' in splits[-3] + is_backward_with_column_parallel = splits[-2] == Const.BACKWARD and 'ColumnParallelLinear' in splits[-3] + if not is_forward_with_row_parallel and not is_backward_with_column_parallel: + return + matmul_list = [] + reduce_list = [] + for sub_node in node.subnodes: + if 'matmul' in sub_node.id: + matmul_list.append(sub_node) + if ('_reduce_scatter_base' in sub_node.id or 'reduce_scatter_tensor' in sub_node.id or + 'all_reduce' in sub_node.id): + reduce_list.append(sub_node) + if not matmul_list or not reduce_list: + return + for matmul_node in matmul_list: + if not matmul_node.output_data: + continue + # matmul的output0,将传递给all_reduce/reduce_scatter,作为all_reduce的input0,或作为reduce_scatter的input1 + matmul_node_output_param = list(matmul_node.output_data.values())[0] + for reduce_node in reduce_list: + if not reduce_node.output_data: + continue + if 'all_reduce' in reduce_node.id: + if not reduce_node.input_data: + continue + reduce_node_input_param = list(reduce_node.input_data.values())[0] + else: + if len(reduce_node.input_data) < 2: + continue + reduce_node_input_param = list(reduce_node.input_data.values())[1] + if not self.compare_param_same(matmul_node_output_param, reduce_node_input_param): + continue + # matmul的input统计值与其他rank的数据进行合并 + other_nodes = self._get_need_merge_node(matmul_node, other_graphs, tp_merge_mapping) + tp_need_merge_param_in, _ = self.compare_node_param_data(matmul_node, other_nodes) + if tp_need_merge_param_in: + ranks = [matmul_node.rank] + for other_node in other_nodes: + ranks.append(other_node.rank) + matmul_node.parallel_merge_info.append(f'{self.TP_MERGED_INFO}{ranks}.') + merge_info_in = self._merge_params(tp_need_merge_param_in) + matmul_node.parallel_merge_info.extend(merge_info_in) + # matmul的output0替换为all_reduce/reduce_scatter的output0 + reduce_node_output_param = list(reduce_node.output_data.values())[0] + keys = [Const.MAX, Const.MIN, Const.MEAN, Const.NORM] + matmul_node_output_param.update({k: reduce_node_output_param.get(k) for k in keys}) + full_op_name = reduce_node_output_param.get('full_op_name') + param_name = full_op_name if full_op_name else reduce_node.id + matmul_node.parallel_merge_info.append(f'The output of this data is merged from {param_name}') + reduce_list.remove(reduce_node) + break + + def _merge_tp_megatron_column_row_parallel(self, node, other_graphs, tp_merge_mapping): + if node.op != NodeOp.module or node.parallel_merge_info: + return + splits = node.id.split(Const.SEP) + if len(splits) < 4: + return + is_forward_with_column_parallel = splits[-2] == Const.FORWARD and 'ColumnParallelLinear' in splits[-3] + if not is_forward_with_column_parallel: + return + if not node.upnode: + return + # 获取[ColumnParallelLinear, RowParallelLinear]结构 + nodes = self._slice_list_at_id(node.upnode.subnodes, node.id, 'RowParallelLinear') + if len(nodes) < 2: + return + stack = nodes[:] + while stack: + current_node = stack.pop() + stack.extend(reversed(current_node.subnodes)) + + if current_node.parallel_merge_info or current_node.id.startswith(Const.DISTRIBUTED): + continue + + other_nodes = self._get_need_merge_node(current_node, other_graphs, tp_merge_mapping) + param_in, param_out = self.compare_node_param_data(current_node, other_nodes, False) + + if param_in or param_out: + ranks = [current_node.rank] + for other_node in other_nodes: + ranks.append(other_node.rank) + current_node.parallel_merge_info.append(f'{self.TP_MERGED_INFO}{ranks}.') + # ColumnParallelLinear层的输入、其中的matmul输入不需要合并 + if current_node == nodes[0] or ('matmul' in current_node.id and current_node.upnode == nodes[0]): + param_in.pop('input.0', None) + # RowParallelLinear层的输出、其中的matmul输出不需要合并, bias不需要合并 + elif current_node == nodes[-1] or ('matmul' in current_node.id and current_node.upnode == nodes[-1]): + param_out = {} + param_in.pop('parameters.bias', None) + + merge_info_in = self._merge_params(param_in) + merge_info_out = self._merge_params(param_out) + current_node.parallel_merge_info.extend(merge_info_in + merge_info_out) + + +class NoParallelMerger(BaseGraphMerger): + def merge_graphs(self): + self.merge_graph_api_collection(self.build_graph_results) + return self.build_graph_results + + +class TPPPMerger(BaseGraphMerger): + def merge_graphs(self): + tp_merger = TPMerger(self.build_graph_results, self.parallel_param, self.is_bench) + pp_merger = PPMerger(self.build_graph_results, self.parallel_param, self.is_bench) \ + if self.parallel_param.vpp == 1 else VPPMerger(self.build_graph_results, self.parallel_param, self.is_bench) + pp_groups = pp_merger.get_groups() + tp_groups = tp_merger.get_groups() + # 进入TP+PP混合处理器,PP和TP必然大于1 + tp_merge_mapping = {} + for tp_group in tp_groups[1:]: + tp_merge_mapping[tp_group[0]] = tp_group[1:] + self.merge_graph_api_collection(self.build_graph_results) + # 先合并pp,需要知道pp域,在各自pp域中合并 + results_groups_pp = self.split_graph_results_by_groups(pp_groups) + pp_results = [] + for results in results_groups_pp: + pp_results.extend(pp_merger.merge_pp_graphs(results)) + # pp合并完成后,直接进行tp合并,最终得到一个graph + tp_result = tp_merger.merge_tp_graphs(pp_results, tp_merge_mapping) + self.sort_merged_api_collection(tp_result[0].graph) + return tp_result + + +class FullMerger(BaseGraphMerger): + def merge_graphs(self): + tp_merger = TPMerger(self.build_graph_results, self.parallel_param, self.is_bench) + pp_merger = PPMerger(self.build_graph_results, self.parallel_param, self.is_bench) \ + if self.parallel_param.vpp == 1 else VPPMerger(self.build_graph_results, self.parallel_param, self.is_bench) + pp_groups = pp_merger.get_groups() + tp_groups = tp_merger.get_groups() + tp_merge_mapping = {} + if len(tp_groups) < 1: + raise RuntimeError(f'Graph merged error, and tp_groups is {tp_groups}.') + for tp_group in tp_groups[1:]: + if len(tp_group) < 1: + raise RuntimeError(f'Graph merged error, and tp_group is {tp_group}.') + tp_merge_mapping[tp_group[0]] = tp_group[1:] + # 先合并pp,需要知道pp域,在各自pp域中合并 + results_groups_pp = self.split_graph_results_by_groups(pp_groups) + pp_results = {} + for pp_result in results_groups_pp: + self.merge_graph_api_collection(pp_result) + pp_result = pp_merger.merge_pp_graphs(pp_result)[0] + pp_results[pp_result.rank] = pp_result + # pp合并完成后,基于tp域划分pp合并结果 + lists_to_be_tp_merged = [] + for tp_group in tp_groups: + list_to_be_tp_merged = [] + for rank in tp_group: + pp_result = pp_results.get(rank) + if pp_result: + list_to_be_tp_merged.append(pp_result) + if list_to_be_tp_merged: + lists_to_be_tp_merged.append(list_to_be_tp_merged) + tp_results = [] + for list_to_be_tp_merged in lists_to_be_tp_merged: + self.merge_graph_api_collection(list_to_be_tp_merged) + tp_merged_result = tp_merger.merge_tp_graphs(list_to_be_tp_merged, tp_merge_mapping) + self.sort_merged_api_collection(tp_merged_result[0].graph) + tp_results.extend(tp_merged_result) + return tp_results + + +class VPPMerger(PPMerger): + LAYERS_NUM_PATTERN = re.compile(r"(layers\.|layer\.)(\d+)(\.)") + FORWARD_PATTERN = re.compile(r'\.forward\.\d+$') + + @staticmethod + def _replace_vpp_id(s, vpp_id): + parts = s.split(Const.SEP) + if len(parts) < 2 or not parts[1].isdigit(): + return s + parts[1] = str(vpp_id) + return Const.SEP.join(parts) + + def merge_pp_graphs(self, results): + if not results or len(results) < 2: + return results + graphs = [x.graph for x in results] + main_graph_result = results[0] + for main_node in main_graph_result.graph.root.subnodes: + if main_node.op == NodeOp.module and main_node.id not in self.unmerged_module: + self._merge_nodes(main_graph_result.graph, main_node, graphs[1:]) + self._sort_nodes(main_graph_result.graph, main_node) + self._merge_vpp_data(main_graph_result.graph) + self._merge_vpp_chunks(main_graph_result.graph) + return [main_graph_result] + + def _merge_vpp_data(self, graph): + """ + 所有chunk的数据都合并到chunk0,前向chunk0的输出使用最后一个chunk的输出,反向chunk0的输入使用最后一个chunk的输入 + """ + module_list = [] + for node in reversed(graph.root.subnodes): + parts = node.id.split(Const.SEP) + if len(parts) < 2: + continue + if parts[1] in [GraphConst.VPP_CHUNK_0, str(self.parallel_param.vpp - 1)]: + module_list.append(node) + if not module_list: + return + stack = module_list[:] + while stack: + current_node = stack.pop() + if hasattr(current_node, 'is_pp_merged') or hasattr(current_node, + 'pp_index') or current_node.op != NodeOp.module: + continue + is_forward = self.FORWARD_PATTERN.search(current_node.id) + stack.extend(reversed(current_node.subnodes)) + target_id = self._replace_vpp_id(current_node.id, self.parallel_param.vpp - 1) + target_node = graph.node_map.get(target_id) + if not target_node: + continue + if is_forward: + current_node.output_data = self._update_node_data_key(target_node.id, current_node.id, + target_node.output_data) + else: + current_node.input_data = self._update_node_data_key(target_node.id, current_node.id, + target_node.input_data) + + def _merge_vpp_chunks(self, graph): + """ + 所有chunk都合并到chunk0,layers层搬到chunk0并重排序号 + """ + chunk_id_list = [i for i in range(1, self.parallel_param.vpp)] + chunk_0_list = [] + for node in reversed(graph.root.subnodes): + parts = node.id.split(Const.SEP) + if len(parts) < 2: + continue + if parts[1] == GraphConst.VPP_CHUNK_0: + chunk_0_list.append(node) + if not chunk_0_list: + return + stack = chunk_0_list[:] + layers_need_merge_dict = {} + while stack: + current_node = stack.pop() + if hasattr(current_node, 'is_pp_merged') or hasattr(current_node, 'pp_index') \ + and current_node.upnode.id not in layers_need_merge_dict: + layers_need_merge_dict[current_node.upnode.id] = current_node.upnode + continue + stack.extend(reversed(current_node.subnodes)) + for node in layers_need_merge_dict.values(): + is_forward = self.FORWARD_PATTERN.search(node.id) + for vpp_id in chunk_id_list: + target_node = graph.node_map.get(self._replace_vpp_id(node.id, vpp_id)) + if not target_node: + continue + # 其他chunk的layers都搬到chunk0,forward追加到后面,backward追加到前面 + if is_forward: + node.subnodes.extend(target_node.subnodes) + else: + node.subnodes = target_node.subnodes + node.subnodes + for sub_node in target_node.subnodes: + sub_node.upnode = node + # 获取其他chunk的层级链路,删除所有父节点,不在前端展示已合并的其他chunk节点 + ancestors = target_node.get_ancestors() + if len(ancestors) < 2: + continue + for module_id in ancestors[1:]: + graph.node_map.pop(module_id, None) + graph.root.subnodes = [node for node in graph.root.subnodes if node.id != ancestors[1]] + # layers层重排序号 + self._sort_layers(node.subnodes, graph, is_forward) + + def _sort_layers(self, node_list, graph, is_forward): + if not is_forward: + node_list = list(reversed(node_list)) + index = -1 + for node in node_list: + match = self.LAYERS_NUM_PATTERN.search(node.id) + if match: + index += 1 + parts = node.id.split(Const.SEP) + # Module.0.xxx代表第一个chunk,不必重排序 + if len(parts) < 2 or parts[1] == GraphConst.VPP_CHUNK_0: + continue + # layers层修改chunk号和layers序号,非layers层修改chunk号 + new_node_id_prefix = '' + if match: + prefix, number, dot = match.groups() + new_string = prefix + str(index) + dot + start, end = match.span() + new_node_id_prefix = node.id[:start] + new_string + new_node_id_prefix = self._replace_vpp_id(new_node_id_prefix, GraphConst.VPP_CHUNK_0) + new_node_id = new_node_id_prefix + node.id[end:] + else: + new_node_id = self._replace_vpp_id(node.id, GraphConst.VPP_CHUNK_0) + graph.node_map.pop(node.id, None) + node.input_data = self._update_node_data_key(node.id, new_node_id, node.input_data) + node.output_data = self._update_node_data_key(node.id, new_node_id, node.output_data) + node.id = new_node_id + graph.node_map[new_node_id] = node + stack = node.subnodes[:] + while stack: + current_node = stack.pop() + if current_node.op != NodeOp.module: + continue + stack.extend(reversed(current_node.subnodes)) + match = self.LAYERS_NUM_PATTERN.search(current_node.id) + if match: + _, e = match.span() + new_current_node_id = new_node_id_prefix + current_node.id[e:] + else: + new_current_node_id = self._replace_vpp_id(current_node.id, GraphConst.VPP_CHUNK_0) + current_node.input_data = self._update_node_data_key(current_node.id, new_current_node_id, + current_node.input_data) + current_node.output_data = self._update_node_data_key(current_node.id, new_current_node_id, + current_node.output_data) + graph.node_map.pop(current_node.id, None) + current_node.id = new_current_node_id + graph.node_map[new_current_node_id] = current_node diff --git a/debug/accuracy_tools/msprobe/visualization/builder/msprobe_adapter.py b/debug/accuracy_tools/msprobe/visualization/builder/msprobe_adapter.py index ee5e3f519ed126b2aaa493e0d3a3b7fce33313e4..bcac6d258dcfa94d2f55a786eb3f5941d1b1ee33 100644 --- a/debug/accuracy_tools/msprobe/visualization/builder/msprobe_adapter.py +++ b/debug/accuracy_tools/msprobe/visualization/builder/msprobe_adapter.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,13 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import re -import math -from msprobe.core.compare.acc_compare import read_op, merge_tensor, get_accuracy + +from msprobe.core.compare.acc_compare import ModeConfig +from msprobe.core.compare.multiprocessing_compute import CompareRealData +from msprobe.core.compare.utils import read_op, merge_tensor, get_accuracy, make_result_table from msprobe.core.common.utils import set_dump_path, get_dump_mode from msprobe.visualization.utils import GraphConst from msprobe.core.common.const import Const -from msprobe.core.compare.acc_compare import ModeConfig + # 用于将节点名字解析成对应的NodeOp的规则 op_patterns = [ @@ -51,16 +54,20 @@ def run_real_data(dump_path_param, csv_path, framework, is_cross_frame=False): framework: 框架类型, pytorch或mindspore is_cross_frame: 是否进行跨框架比对,仅支持mindspore比pytorch, 其中pytorch为标杆 """ - mode_config = ModeConfig(stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=Const.ALL) + config_dict = { + 'stack_mode': False, + 'auto_analyze': True, + 'fuzzy_match': False, + 'dump_mode': Const.ALL + } + mode_config = ModeConfig(**config_dict) if framework == Const.PT_FRAMEWORK: - from msprobe.pytorch.compare.pt_compare import PTComparator - return PTComparator(mode_config).do_multi_process(dump_path_param, csv_path) + from msprobe.pytorch.compare.pt_compare import read_real_data + return CompareRealData(read_real_data, mode_config, is_cross_frame).do_multi_process(dump_path_param, csv_path) else: - from msprobe.mindspore.compare.ms_compare import MSComparator, MappingConfig - ms_comparator = MSComparator(mode_config, MappingConfig()) - ms_comparator.cross_frame = is_cross_frame - return ms_comparator.do_multi_process(dump_path_param, csv_path) + from msprobe.mindspore.compare.ms_compare import read_real_data + return CompareRealData(read_real_data, mode_config, is_cross_frame).do_multi_process(dump_path_param, csv_path) def get_input_output(node_data, node_id): @@ -120,11 +127,13 @@ def compare_data_fuzzy(data_dict_list1, data_dict_list2): return True -def format_node_data(data_dict, node_id=None): +def format_node_data(data_dict, node_id=None, compare_mode=None): """ 删除节点数据中不需要展示的字段 """ - del_list = ['requires_grad', 'full_op_name'] + del_list = ['state', 'full_op_name'] + if GraphConst.MD5_COMPARE != compare_mode: + del_list.append(Const.MD5) if node_id and GraphConst.BATCH_P2P in node_id: del_list.extend(['op', 'peer', 'tag', 'group_id']) for _, value in data_dict.items(): @@ -137,31 +146,27 @@ def format_node_data(data_dict, node_id=None): return data_dict -def compare_node(node_ids, data_dicts, stack_json_data, compare_mode): +def compare_node(node_n, node_b, compare_mode): """ 调用acc_compare.py中的get_accuracy获得精度对比指标 真实数据对比模式无法获得精度对比指标,需要调用多进程比对接口 Returns: 包含参数信息和对比指标(真实数据对比模式除外)的list """ - merge_n = _parse_node(node_ids[0], data_dicts[0], stack_json_data, compare_mode) - merge_b = _parse_node(node_ids[1], data_dicts[1], stack_json_data, compare_mode) - result = [] dump_mode = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode) + merge_n = _parse_node(node_n, dump_mode) + merge_b = _parse_node(node_b, dump_mode) + result = [] get_accuracy(result, merge_n, merge_b, dump_mode) return result -def _parse_node(node_id, data_dict, stack_json_data, compare_mode): +def _parse_node(node, dump_mode): """ 转换节点,使其能够作为acc_compare.py中的get_accuracy的入参 """ - dump_mode = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode) - op_parsed_list = read_op(data_dict.get(node_id, {}), node_id) - if node_id in stack_json_data: - op_parsed_list.append( - {'full_op_name': node_id, 'full_info': stack_json_data[node_id]}) - else: - op_parsed_list.append({'full_op_name': node_id, 'full_info': None}) + op_parsed_list = [] + op_parsed_list.extend(node.input_data.values()) + op_parsed_list.extend(node.output_data.values()) result = merge_tensor(op_parsed_list, dump_mode) if not result: result['op_name'] = [] @@ -172,7 +177,7 @@ def _format_decimal_string(s): """ 使用正则表达式匹配包含数字、小数点和可选的百分号的字符串 """ - pattern = re.compile(r'\d{1,20}\.\d{1,20}%?') + pattern = re.compile(r'^\d{1,20}\.\d{1,20}%?$') matches = pattern.findall(s) for match in matches: is_percent = match.endswith('%') @@ -227,3 +232,12 @@ def _format_data(data_dict): if all_null: data_dict.clear() data_dict[GraphConst.VALUE] = GraphConst.NULL + + +def get_csv_df(stack_mode, csv_data, compare_mode): + """ + 调用acc接口写入csv + """ + + dump_mode = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode) + return make_result_table(csv_data, dump_mode, stack_mode) diff --git a/debug/accuracy_tools/msprobe/visualization/compare/graph_comparator.py b/debug/accuracy_tools/msprobe/visualization/compare/graph_comparator.py index 902d721a8d1047b687b878eb45a802a1df4154bd..0595a58107f613f29d6f1c1347257660b5cf34da 100644 --- a/debug/accuracy_tools/msprobe/visualization/compare/graph_comparator.py +++ b/debug/accuracy_tools/msprobe/visualization/compare/graph_comparator.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,34 +14,40 @@ # limitations under the License. import re -from msprobe.visualization.builder.msprobe_adapter import compare_node, get_compare_mode, run_real_data -from msprobe.visualization.utils import GraphConst, load_json_file, load_data_json_file, get_csv_df +from msprobe.visualization.builder.msprobe_adapter import compare_node, get_compare_mode, run_real_data, get_csv_df +from msprobe.visualization.utils import GraphConst, load_json_file, load_data_json_file from msprobe.visualization.graph.graph import Graph, NodeOp -from msprobe.visualization.graph.node_colors import NodeColors from msprobe.visualization.compare.mode_adapter import ModeAdapter from msprobe.core.common.const import Const +from msprobe.core.common.decorator import recursion_depth_decorator class GraphComparator: - def __init__(self, graphs, dump_path_param, args, mapping_dict=None): + MAX_DEPTH = 1000 + + def __init__(self, graphs, dump_path_param, args, is_cross_framework, mapping_dict=None): self.graph_n = graphs[0] self.graph_b = graphs[1] self._parse_param(dump_path_param, args.output_path) self.framework = args.framework + self.layer_mapping = args.layer_mapping self.mapping_dict = mapping_dict self.fuzzy_match = args.fuzzy_match self.pattern = re.compile(r'\.\d+\.') + self.is_cross_framework = is_cross_framework + self.parallel_merge = args.parallel_merge if hasattr(args, 'parallel_merge') else False + self.rank_pattern = re.compile(r"_rank\d+") def compare(self): """ 比较函数,初始化结束后单独调用。比较结果写入graph_n """ if self.fuzzy_match: - self._compare_nodes_fuzzy(self.graph_n.root) + self._compare_nodes_fuzzy(self.graph_n.root, False if self.parallel_merge else True) else: self._compare_nodes(self.graph_n.root) self._postcompare() - + def add_compare_result_to_node(self, node, compare_result_list): """ 将比对结果添加到节点的输入输出数据中 @@ -66,7 +72,59 @@ class GraphComparator: self.ma.parse_result(node, [compare_in_dict, compare_out_dict])) node.data[GraphConst.JSON_INDEX_KEY] = precision_index node.data.update(other_dict) - + + def _compare_nodes(self, node_root): + """ + 遍历NPU树中的节点,如果在Bench中找到具有相同名称的节点,检查他们的祖先和参数信息,检查一致则及逆行精度数据对比 + 这里采用先序遍历,好处在于当这个节点被比较时,他的先序已经被匹配,这可以为后续的模糊匹配提供重要信息 + """ + def compare_single_node(node_n): + if self.layer_mapping: + node_b, ancestors_n, ancestors_b = Graph.mapping_match(node_n, self.graph_b, self.mapping_dict) + if node_b: + ancestors_n.append(node_n.id) + ancestors_b.append(node_b.id) + node_n.matched_node_link = ancestors_b + node_b.matched_node_link = ancestors_n + else: + node_b, ancestors = Graph.match(self.graph_n, node_n, self.graph_b) + if node_b: + ancestors.append(node_b.id) + node_n.add_link(node_b, ancestors) + if node_b: + # 真实数据比对只会得到基本信息,并没有精度指标,需要调用多进程对比接口 + self._get_and_add_result(node_n, node_b) + node_list.extend(node_n.subnodes) + + node_list = [node_root] + while node_list: + compare_single_node(node_list.pop(0)) + + def _compare_nodes_fuzzy(self, node_root, check_shape=True): + def compare_single_nodes_fuzzy(node_n): + if node_n.op != NodeOp.function_api: + # 模块经过模糊匹配 + node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(node_n, self.graph_b.node_map.get(node_n.id), + check_shape) + if node_b: + self._process_matched_nodes(node_n, node_b, ancestors_n, ancestors_b) + # 匹配上的两个模块中的所有api, 忽略dump调用次数,按照名称一致+模块中的调用顺序进行匹配 + recount_result_n = self._recount_api_node(node_n) + recount_result_b = self._recount_api_node(node_b) + for recount_node_id, node_id_n in recount_result_n.items(): + api_node_n = self.graph_n.node_map.get(node_id_n) + if not api_node_n: + continue + api_node_b, ancestors_n, ancestors_b = Graph.fuzzy_match( + api_node_n, self.graph_b.node_map.get(recount_result_b.get(recount_node_id)), check_shape) + if api_node_b: + self._process_matched_nodes(api_node_n, api_node_b, ancestors_n, ancestors_b) + node_list.extend(node_n.subnodes) + + node_list = [node_root] + while node_list: + compare_single_nodes_fuzzy(node_list.pop(0)) + def _parse_param(self, dump_path_param, output_path): self.dump_path_param = dump_path_param self.output_path = output_path @@ -81,7 +139,7 @@ class GraphComparator: if not self.ma.compare_mode == GraphConst.REAL_DATA_COMPARE: return df = get_csv_df(True, self.ma.csv_data, self.ma.compare_mode) - df = run_real_data(self.dump_path_param, df, self.framework, True if self.mapping_dict else False) + df = run_real_data(self.dump_path_param, df, self.framework, self.is_cross_framework) compare_data_dict = {row[0]: row.tolist() for _, row in df.iterrows()} for node in self.ma.compare_nodes: precision_index, _ = self.ma.parse_result(node, [compare_data_dict]) @@ -92,64 +150,26 @@ class GraphComparator: api集合的指标, md5模式使用集合中所有api最小的指标,statistics和tensor模式使用集合中所有api最大的指标 md5模式下指标为0代表最差,statistics和tensor模式下指标为1代表最差 """ + def handle_api_collection_index(api_collection_node): + precision_index = GraphConst.MAX_INDEX_KEY if self.ma.compare_mode == GraphConst.MD5_COMPARE \ + else GraphConst.MIN_INDEX_KEY + for api in api_collection_node.subnodes: + precision_index = min(precision_index, + api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MAX_INDEX_KEY)) \ + if self.ma.compare_mode == GraphConst.MD5_COMPARE \ + else max(precision_index, api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MIN_INDEX_KEY)) + api_collection_node.data[GraphConst.JSON_INDEX_KEY] = precision_index + for node in self.graph_n.root.subnodes: - if node.op == NodeOp.api_collection: - precision_index = GraphConst.MAX_INDEX_KEY if self.ma.compare_mode == GraphConst.MD5_COMPARE \ - else GraphConst.MIN_INDEX_KEY - for api in node.subnodes: - precision_index = min(precision_index, - api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MAX_INDEX_KEY)) \ - if self.ma.compare_mode == GraphConst.MD5_COMPARE \ - else max(precision_index, api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MIN_INDEX_KEY)) - node.data[GraphConst.JSON_INDEX_KEY] = precision_index - - def _compare_nodes(self, node_n): - """ - 递归遍历NPU树中的节点,如果在Bench中找到具有相同名称的节点,检查他们的祖先和参数信息,检查一致则及逆行精度数据对比 - 这里采用先序遍历,好处在于当这个节点被比较时,他的先序已经被匹配,这可以为后续的模糊匹配提供重要信息 - """ - if self.mapping_dict: - node_b, ancestors_n, ancestors_b = Graph.mapping_match(node_n, self.graph_b, self.mapping_dict) - if node_b: - ancestors_n.append(node_n.id) - ancestors_b.append(node_b.id) - node_n.matched_node_link = ancestors_b - node_b.matched_node_link = ancestors_n - else: - node_b, ancestors = Graph.match(self.graph_n, node_n, self.graph_b) - if node_b: - ancestors.append(node_b.id) - node_n.add_link(node_b, ancestors) - if node_b: - # 真实数据比对只会得到基本信息,并没有精度指标,需要调用多进程对比接口 - self._get_and_add_result(node_n, node_b) - for subnode in node_n.subnodes: - self._compare_nodes(subnode) - - def _compare_nodes_fuzzy(self, node_n): - if node_n.op != NodeOp.function_api: - # 模块经过模糊匹配 - node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(node_n, self.graph_b.node_map.get(node_n.id)) - if node_b: - self._process_matched_nodes(node_n, node_b, ancestors_n, ancestors_b) - # 匹配上的两个模块中的所有api, 忽略dump调用次数,按照名称一致+模块中的调用顺序进行匹配 - recount_result_n = self._recount_api_node(node_n) - recount_result_b = self._recount_api_node(node_b) - for recount_node_id, node_id_n in recount_result_n.items(): - api_node_n = self.graph_n.node_map.get(node_id_n) - if not api_node_n: - continue - api_node_b, ancestors_n, ancestors_b = Graph.fuzzy_match( - api_node_n, self.graph_b.node_map.get(recount_result_b.get(recount_node_id))) - if api_node_b: - self._process_matched_nodes(api_node_n, api_node_b, ancestors_n, ancestors_b) - for sub_node in node_n.subnodes: - self._compare_nodes_fuzzy(sub_node) + if node.op == NodeOp.api_collection and node.id.startswith(GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS): + for sub_node in node.subnodes: + handle_api_collection_index(sub_node) + handle_api_collection_index(node) + elif node.op == NodeOp.api_collection: + handle_api_collection_index(node) def _get_and_add_result(self, node_n, node_b): - compare_result_list = compare_node([node_n.id, node_b.id], - [self.data_n_dict, self.data_b_dict], - self.stack_json_data, self.ma.compare_mode) + compare_result_list = compare_node(node_n, node_b, self.ma.compare_mode) if compare_result_list: self.ma.add_csv_data(compare_result_list) self.add_compare_result_to_node(node_n, compare_result_list) @@ -166,6 +186,8 @@ class GraphComparator: if sub_node.op == NodeOp.function_api: # 忽略dump调用次数 count_removed_id = self.pattern.sub(Const.SEP, sub_node.id) + if self.rank_pattern.search(count_removed_id): + count_removed_id = self.rank_pattern.sub('', count_removed_id) node_count[count_removed_id] = node_count.get(count_removed_id, 0) + 1 # 赋予模块中的调用顺序 recount_node_id = count_removed_id + str(node_count.get(count_removed_id)) diff --git a/debug/accuracy_tools/msprobe/visualization/compare/mode_adapter.py b/debug/accuracy_tools/msprobe/visualization/compare/mode_adapter.py index 535192d80c566c48cedde4ea5b4474b6dc82dec0..2f1c7d5721accb1165e57181573dc9cd9715746f 100644 --- a/debug/accuracy_tools/msprobe/visualization/compare/mode_adapter.py +++ b/debug/accuracy_tools/msprobe/visualization/compare/mode_adapter.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import math +import json from msprobe.core.common.const import CompareConst, Const from msprobe.visualization.utils import ToolTip, GraphConst, str2float @@ -25,6 +25,12 @@ class ModeAdapter: self.csv_data = [] self.compare_nodes = [] + @staticmethod + def _is_invalid(value): + if not isinstance(value, float): + return False + return math.isnan(value) or math.isinf(value) + @staticmethod def _add_md5_compare_data(node_data, compare_data_dict): precision_index = GraphConst.MAX_INDEX_KEY @@ -49,6 +55,8 @@ class ModeAdapter: for key, value in node_data.items(): if not isinstance(value, dict): continue + if value.get(Const.MAX) is None: + continue compare_data = compare_data_dict.get(key) if compare_data: headers = CompareConst.COMPARE_RESULT_HEADER @@ -67,9 +75,13 @@ class ModeAdapter: if thousandth is not None: numbers.append(thousandth) node_data[key] = value + if ModeAdapter._is_invalid(value.get(Const.MAX)) or ModeAdapter._is_invalid(value.get(Const.MIN)): + numbers.append(CompareConst.N_A) # 双千指标都是None的异常情况 if not numbers: min_thousandth = None + elif CompareConst.N_A in numbers: + min_thousandth = CompareConst.N_A else: min_thousandth = min(numbers + [min_thousandth]) return min_thousandth @@ -81,6 +93,8 @@ class ModeAdapter: for key, data_info in node_data.items(): if not isinstance(data_info, dict): continue + if data_info.get(Const.MAX) is None: + continue compare_data = compare_data_dict.get(key) if compare_data: # 对应比对结果csv的列 @@ -92,6 +106,8 @@ class ModeAdapter: relative_err = str2float(data_info.get(item)) max_relative_err = max(max_relative_err, relative_err) node_data[key] = data_info + if ModeAdapter._is_invalid(data_info.get(Const.MAX)) or ModeAdapter._is_invalid(data_info.get(Const.MIN)): + max_relative_err = GraphConst.MAX_INDEX_KEY max_relative_err = 1 if max_relative_err > 1 else max_relative_err return max_relative_err @@ -133,7 +149,11 @@ class ModeAdapter: ModeAdapter._check_list_len(compare_data_dict_list, 1) min_thousandth_in = ModeAdapter._add_real_compare_data(node.input_data, compare_data_dict_list[0]) min_thousandth_out = ModeAdapter._add_real_compare_data(node.output_data, compare_data_dict_list[0]) - if min_thousandth_in is not None and min_thousandth_out is not None: + if CompareConst.N_A == min_thousandth_out: + change_percentage = GraphConst.MAX_INDEX_KEY + elif CompareConst.N_A == min_thousandth_in: + change_percentage = GraphConst.MIN_INDEX_KEY + elif min_thousandth_in is not None and min_thousandth_out is not None: change_percentage = min_thousandth_in - min_thousandth_out else: change_percentage = GraphConst.MIN_INDEX_KEY @@ -141,6 +161,7 @@ class ModeAdapter: else change_percentage precision_index = GraphConst.MAX_INDEX_KEY \ if change_percentage > GraphConst.MAX_INDEX_KEY else change_percentage + precision_index = self._ignore_precision_index(node.id, precision_index) return precision_index, other_dict def prepare_real_data(self, node): @@ -157,24 +178,6 @@ class ModeAdapter: return self.csv_data.extend(compare_result_list) - def add_error_key(self, node_data): - """ - 根据不同的模式进行提供不同错误信息 - """ - for key, value in node_data.items(): - if not isinstance(value, dict): - continue - if self.compare_mode == GraphConst.SUMMARY_COMPARE: - message = [CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR, - CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR] - elif self.compare_mode == GraphConst.REAL_DATA_COMPARE: - message = [CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] - else: - # 输出件优化 - message = [] - value[GraphConst.ERROR_KEY] = message - node_data[key] = value - def get_tool_tip(self): """ 用于前端展示字段的具体含义 @@ -195,3 +198,11 @@ class ModeAdapter: CompareConst.MAX_ABS_ERR: ToolTip.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR: ToolTip.MAX_RELATIVE_ERR} return json.dumps(tips) + + def _ignore_precision_index(self, node_id, precision_index): + node_id_split = node_id.split(Const.SEP) + if len(node_id_split) < 2: + return precision_index + if node_id.split(Const.SEP)[1] in GraphConst.IGNORE_PRECISION_INDEX: + return GraphConst.MAX_INDEX_KEY if self.compare_mode == GraphConst.MD5_COMPARE else GraphConst.MIN_INDEX_KEY + return precision_index diff --git a/debug/accuracy_tools/msprobe/visualization/db_utils.py b/debug/accuracy_tools/msprobe/visualization/db_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..98907b88f0d27eeceafa74405c7aaf56eecba143 --- /dev/null +++ b/debug/accuracy_tools/msprobe/visualization/db_utils.py @@ -0,0 +1,252 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sqlite3 +import json +import re +from msprobe.core.common.log import logger +from msprobe.core.common.file_utils import change_mode, check_path_before_create, FileChecker +from msprobe.core.common.const import FileCheckConst +from msprobe.visualization.utils import GraphConst +from msprobe.visualization.builder.msprobe_adapter import format_node_data + +TEXT_PRIMARY_KEY = 'TEXT PRIMARY KEY' +TEXT_NOT_NULL = 'TEXT NOT NULL' +INTEGER_NOT_NULL = 'INTEGER NOT NULL' +TEXT = 'TEXT' +INTEGER = 'INTEGER' + +node_columns = { + 'id': TEXT_PRIMARY_KEY, + 'graph_id': TEXT_NOT_NULL, + 'node_order': INTEGER_NOT_NULL, + 'node_name': TEXT_NOT_NULL, + 'node_type': TEXT_NOT_NULL, + 'up_node': TEXT, + 'sub_nodes': TEXT, + 'precision_index': INTEGER, + 'overflow_level': TEXT, + 'micro_step_id': INTEGER_NOT_NULL, + 'matched_node_link': TEXT, + 'stack_id': TEXT, + 'parallel_merge_info': TEXT, + 'matched_distributed': TEXT, + 'modified': INTEGER_NOT_NULL, + 'input_data': TEXT, + 'output_data': TEXT, + 'data_source': TEXT, + 'dump_data_dir': TEXT, + 'step': INTEGER_NOT_NULL, + 'rank': INTEGER_NOT_NULL +} + +config_columns = { + 'id': TEXT_PRIMARY_KEY, + 'graph_type': TEXT_NOT_NULL, + 'task': TEXT, + 'tool_tip': TEXT, + 'micro_steps': INTEGER, + 'overflow_check': INTEGER, + 'node_colors': TEXT_NOT_NULL, + 'rank_list': TEXT_NOT_NULL, + 'step_list': TEXT_NOT_NULL +} + +stack_columns = { + 'id': TEXT_PRIMARY_KEY, + 'stack_info': TEXT +} + +indexes = { + "index1": ["step", "rank", "data_source", "up_node", "node_order"], + "index2": ["step", "rank", "data_source", "node_name"], + "index3": ["step", "rank", "data_source", "node_order"], + "index4": ["step", "rank", "node_order"], + "index5": ["step", "rank", "micro_step_id", "node_order"], + "index6": ["step", "rank", "modified", "matched_node_link"] +} + +SAFE_NAME_PATTERN = re.compile(r'^[a-zA-Z0-9_]+$') + + +def is_safe_identifier(name): + """验证标识符是否安全(防止SQL注入)""" + return isinstance(name, str) and SAFE_NAME_PATTERN.match(name) is not None + + +def create_table_sql_from_dict(table_name, columns_dict): + """ + 根据提供的表名和列定义字典生成CREATE TABLE SQL语句。 + """ + if not is_safe_identifier(table_name): + raise ValueError(f"Invalid table name: {table_name} - potential SQL injection risk!") + + sql = f"CREATE TABLE IF NOT EXISTS {table_name} (\n" + + column_definitions = [] + for column_name, column_type in columns_dict.items(): + if not is_safe_identifier(column_name): + raise ValueError(f"Invalid column name: {column_name} - potential SQL injection risk!") + + column_definitions.append(f" {column_name} {column_type}") + + sql += ",\n".join(column_definitions) + sql += "\n);" + + return sql + + +def create_insert_sql_from_dict(table_name, columns_dict, ignore_insert=False): + """ + 根据提供的表名和数据字典生成INSERT INTO SQL语句。 + """ + if not is_safe_identifier(table_name): + raise ValueError(f"Invalid table name: {table_name} - potential SQL injection risk!") + + columns = list(columns_dict.keys()) + + for column_name in columns: + if not is_safe_identifier(column_name): + raise ValueError(f"Invalid column name: {column_name} - potential SQL injection risk!") + + placeholders = ["?"] * len(columns) + + columns_string = ", ".join(columns) + placeholders_string = ", ".join(placeholders) + + sql_prefix = "INSERT OR IGNORE INTO" if ignore_insert else "INSERT INTO" + sql = f"{sql_prefix} {table_name} ({columns_string}) VALUES ({placeholders_string})" + return sql + + +def to_db(db_path, create_table_sql, insert_sql, data, db_insert_size=1000): + if not os.path.exists(db_path): + check_path_before_create(db_path) + else: + FileChecker(db_path, FileCheckConst.FILE, FileCheckConst.READ_WRITE_ABLE, + FileCheckConst.DB_SUFFIX).common_check() + try: + conn = sqlite3.connect(db_path) + except sqlite3.Error as e: + logger.error(f"Unable to create database connection: {e}") + raise RuntimeError("Unable to create database connection") from e + + try: + cursor = conn.cursor() + cursor.execute(create_table_sql) + if len(data) == 1: + cursor.execute(insert_sql, data[0]) + conn.commit() + else: + for i in range(0, len(data), db_insert_size): + batch = data[i:i + db_insert_size] + cursor.executemany(insert_sql, batch) + conn.commit() + except sqlite3.Error as e: + logger.error(f"An sqlite3 error occurred: {e}") + raise RuntimeError("An sqlite3 error occurred") from e + finally: + conn.close() + + +def add_table_index(db_path): + FileChecker(db_path, FileCheckConst.FILE, FileCheckConst.READ_WRITE_ABLE, FileCheckConst.DB_SUFFIX).common_check() + try: + conn = sqlite3.connect(db_path) + except sqlite3.Error as e: + logger.error(f"Unable to create database connection: {e}") + raise RuntimeError("Unable to create database connection") from e + + try: + cursor = conn.cursor() + for index_name, columns in indexes.items(): + if not is_safe_identifier(index_name): + raise ValueError(f"Invalid index name: {index_name} - potential SQL injection risk!") + + for column in columns: + if not is_safe_identifier(column): + raise ValueError(f"Invalid column name in index: {column} - potential SQL injection risk!") + + columns_str = ', '.join(columns) + index_sql = f''' + CREATE INDEX IF NOT EXISTS {index_name} ON tb_nodes ({columns_str}); + ''' + cursor.execute(index_sql) + conn.commit() + except sqlite3.Error as e: + logger.error(f"Failed to add table index: {e}") + raise RuntimeError("Failed to add table index") from e + finally: + conn.close() + + +def post_process_db(db_path): + add_table_index(db_path) + change_mode(db_path, FileCheckConst.DATA_FILE_AUTHORITY) + + +def node_to_db(graph, db_name): + create_table_sql = create_table_sql_from_dict('tb_nodes', node_columns) + insert_sql = create_insert_sql_from_dict('tb_nodes', node_columns) + data = [] + stack_dict = {} + for i, node in enumerate(graph.get_sorted_nodes()): + stack_info_text = json.dumps(node.stack_info) + if stack_info_text not in stack_dict: + stack_dict[stack_info_text] = get_stack_unique_id(graph, stack_dict) + data.append((get_node_unique_id(graph, node), get_graph_unique_id(graph), i, node.id, node.op.value, + node.upnode.id if node.upnode else '', + json.dumps([node.id for node in node.subnodes]) if node.subnodes else '', + node.data.get(GraphConst.JSON_INDEX_KEY), node.data.get(GraphConst.OVERFLOW_LEVEL), + node.micro_step_id if node.micro_step_id is not None else 0, json.dumps(node.matched_node_link), + stack_dict.get(stack_info_text), + json.dumps(node.parallel_merge_info) if node.parallel_merge_info else '', + json.dumps(node.matched_distributed), 0, + json.dumps(format_node_data(node.input_data, node.id, graph.compare_mode)), + json.dumps(format_node_data(node.output_data, node.id, graph.compare_mode)), + graph.data_source, graph.data_path, graph.step, graph.rank)) + to_db(db_name, create_table_sql, insert_sql, data) + stack_to_db(stack_dict, db_name) + + +def config_to_db(config, db_name): + create_table_sql = create_table_sql_from_dict('tb_config', config_columns) + insert_sql = create_insert_sql_from_dict('tb_config', config_columns, ignore_insert=True) + data = [("1", "compare" if config.graph_b else "build", config.task, config.tool_tip, config.micro_steps, + config.overflow_check, json.dumps(config.node_colors), json.dumps(config.rank_list), + json.dumps(config.step_list))] + to_db(db_name, create_table_sql, insert_sql, data) + + +def stack_to_db(stack_dict, db_name): + create_table_sql = create_table_sql_from_dict('tb_stack', stack_columns) + insert_sql = create_insert_sql_from_dict('tb_stack', stack_columns) + data = [] + for stack_info_text, unique_id in stack_dict.items(): + data.append((unique_id, stack_info_text)) + to_db(db_name, create_table_sql, insert_sql, data) + + +def get_graph_unique_id(graph): + return f'{graph.data_source}_{graph.step}_{graph.rank}' + + +def get_node_unique_id(graph, node): + return f'{get_graph_unique_id(graph)}_{node.id}' + + +def get_stack_unique_id(graph, stack_dict): + return f'{get_graph_unique_id(graph)}_{len(stack_dict)}' diff --git a/debug/accuracy_tools/msprobe/visualization/graph/base_node.py b/debug/accuracy_tools/msprobe/visualization/graph/base_node.py index 2642ff1e97ebcc055212d4d776eb7c8a08866dc8..febcbcac84aaa91b3674c75ddc6497abeadba7b9 100644 --- a/debug/accuracy_tools/msprobe/visualization/graph/base_node.py +++ b/debug/accuracy_tools/msprobe/visualization/graph/base_node.py @@ -12,10 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from msprobe.core.overflow_check.level import OverflowLevel -from msprobe.visualization.graph.node_op import NodeOp from msprobe.visualization.utils import GraphConst from msprobe.visualization.builder.msprobe_adapter import format_node_data, compare_data, compare_data_fuzzy +from msprobe.core.common.log import logger class BaseNode: @@ -35,6 +36,8 @@ class BaseNode: self.overflow_level = None self.matched_distributed = {} self.batch_p2p_info = [] + self.rank = 0 + self.parallel_merge_info = [] def __str__(self): info = f'id:\t{self.id}' @@ -86,35 +89,19 @@ class BaseNode: self.matched_node_link = ancestors node.matched_node_link = ancestors - def to_dict(self): - """ - 输出数据 - """ - result = { - 'id': self.id, - 'node_type': self.op.value, - 'output_data': format_node_data(self.output_data, self.id), - 'input_data': format_node_data(self.input_data, self.id), - 'upnode': self.upnode.id if self.upnode else 'None', - 'subnodes': [node.id for node in self.subnodes], - 'matched_node_link': self.matched_node_link, - 'suggestions': self.suggestions, - 'stack_info': self.stack_info - } - if self.micro_step_id is not None: - result['micro_step_id'] = self.micro_step_id - result['data'] = self.data - if self.matched_distributed: - result[GraphConst.MATCHED_DISTRIBUTED] = self.matched_distributed - return result - def get_ancestors(self): """ 获取节点所有祖先的列表 """ ancestors = [] current_node = self.upnode + seen_nodes = set() while current_node: + if current_node.id in seen_nodes: + logger.warning(f'Detected a cycle in the node structure and cannot get node ancestors, ' + f'current node is {current_node.id}.') + return [] + seen_nodes.add(current_node.id) ancestors.append(current_node.id) current_node = current_node.upnode return list(reversed(ancestors)) diff --git a/debug/accuracy_tools/msprobe/visualization/graph/distributed_analyzer.py b/debug/accuracy_tools/msprobe/visualization/graph/distributed_analyzer.py index 5e68d6b2528aea4d6645da2885fa76a7b9bb97b2..90ac8dcfed106018789297248c5555554716cbc0 100644 --- a/debug/accuracy_tools/msprobe/visualization/graph/distributed_analyzer.py +++ b/debug/accuracy_tools/msprobe/visualization/graph/distributed_analyzer.py @@ -82,7 +82,7 @@ class DistributedAnalyzer: """ target_rank = node.input_data.get(f'{node.id}{GraphConst.INPUT}{parameter}', {}).get('value') if target_rank is None: - logger.warning(f'The parameter {parameter} of node {node.id} does not exist, {CANNOT_MATCH}{rank}') + logger.debug(f'The parameter {parameter} of node {node.id} does not exist, {CANNOT_MATCH}{rank}') return target_rank @staticmethod @@ -95,27 +95,18 @@ class DistributedAnalyzer: """ group = node.input_data.get(f'{node.id}{GraphConst.INPUT}group', {}) if not group: - logger.warning(f'The kwarg group of node {node.id} does not exist, {CANNOT_MATCH}{rank}') + logger.debug(f'The kwarg group of node {node.id} does not exist, {CANNOT_MATCH}{rank}') return None, None group_ranks = group.get('group_ranks') if not group_ranks: - logger.warning(f'The group_ranks of node {node.id} does not exist, {CANNOT_MATCH}{rank}') + logger.debug(f'The group_ranks of node {node.id} does not exist, {CANNOT_MATCH}{rank}') return None, None group_id = group.get('group_id') if not group_id: - logger.warning(f'The group_id of node {node.id} does not exist, {CANNOT_MATCH}{rank}') + logger.debug(f'The group_id of node {node.id} does not exist, {CANNOT_MATCH}{rank}') return None, None return group_ranks, group_id - @staticmethod - def _get_batch_group_info(node, rank): - for data in node.input_data.values(): - group_id = data.get('group_id') - if group_id is not None: - return group_id - logger.warning(f'The group_id of node {node.id} does not exist, {CANNOT_MATCH}{rank}') - return None - def distributed_match(self): for rank, graph in self.graphs.items(): nodes = graph.node_map @@ -192,7 +183,7 @@ class DistributedAnalyzer: op = info_dict.get(GraphConst.OP) target_rank = info_dict.get(GraphConst.PEER) if op is None or target_rank is None: - logger.warning('Cannot get param op or peer.') + logger.debug('Cannot get param op or peer.') continue group_id = op + Const.REPLACEMENT_CHARACTER + Const.RANK + str(target_rank) + \ Const.REPLACEMENT_CHARACTER + info_dict.get(GraphConst.GROUP_ID, '') @@ -224,7 +215,7 @@ class DistributedAnalyzer: """ target_graph = self.graphs.get(target_rank) if not target_graph: - logger.warning(f'Graph data does not exist, {CANNOT_MATCH}{target_rank}') + logger.debug(f'Graph data does not exist, {CANNOT_MATCH}{target_rank}') return None target_group_mapping = self.group_node_mapping.get(target_rank) # p2p通信,想要获取目标节点,需要替换unique_group_id中的rank和api name, @@ -235,7 +226,7 @@ class DistributedAnalyzer: target_node_id = target_group_mapping.get(target_unique_group_id, '') target_node = target_graph.node_map.get(target_node_id) if not target_node: - logger.warning(f'Node {target_node_id} does not exist, {CANNOT_MATCH}{target_rank}') + logger.debug(f'Node {target_node_id} does not exist, {CANNOT_MATCH}{target_rank}') return None return target_node @@ -285,13 +276,13 @@ class DistributedAnalyzer: source_rank = (target_node.input_data.get(f'{target_node.id}{GraphConst.INPUT}{target_config_info[1]}', {}) .get('value')) if source_rank is None: - logger.warning( + logger.debug( f'The kwarg {target_config_info[1]} of node {target_node.id} does not exist, ' f'{CANNOT_MATCH}{source_rank}') return if source_rank != rank: # 点对点通信,待匹配目标节点包含的rank信息要与当前rank一致 - logger.warning( + logger.debug( f'{node.id} of rank{rank} is expected to communicate with {target_node.id} of rank{target_rank}, ' f'but the data shows that {target_node.id} communicates with rank{source_rank}.' f'The rank is inconsistent, cannot match distributed node') @@ -300,7 +291,7 @@ class DistributedAnalyzer: # 点对点通信,两个匹配节点的输出数据要一致 if not DistributedAnalyzer._node_output_all_equal(node.output_data.get(node.id + '.output.0'), target_node.output_data.get(target_node.id + '.output.0')): - logger.warning(f'{node.id} output of rank{rank} is different from the {target_node.id} ' + logger.debug(f'{node.id} output of rank{rank} is different from the {target_node.id} ' f'output of rank{target_rank}, cannot match distributed node') return @@ -341,7 +332,7 @@ class DistributedAnalyzer: if not target_group_id: continue if group_id != target_group_id: - logger.warning( + logger.debug( f'{node.id} of rank{rank} is expected to communicate with {target_node.id} of rank{target_rank}' f', but the data shows that the group id of the two nodes are different, ' f'cannot match distributed node') @@ -377,7 +368,7 @@ class DistributedAnalyzer: target_api_name = self.config.get(api_name)[0] target_rank = int(id_info[1].replace(Const.RANK, '')) except Exception as e: - logger.warning(f'Failed to parsing batch p2p parameter with error info: {e}.') + logger.debug(f'Failed to parse batch p2p parameter with error info: {e}.') continue target_node = self._get_target_node(rank, unique_group_id, api_name, target_rank, target_api_name) if not target_node: diff --git a/debug/accuracy_tools/msprobe/visualization/graph/graph.py b/debug/accuracy_tools/msprobe/visualization/graph/graph.py index 5ce12d1cadb9aec2cc7c65954bb861b85032212d..e03418d771fd3f9343329db7a563751e414e320a 100644 --- a/debug/accuracy_tools/msprobe/visualization/graph/graph.py +++ b/debug/accuracy_tools/msprobe/visualization/graph/graph.py @@ -18,9 +18,7 @@ from msprobe.visualization.graph.node_op import NodeOp from msprobe.visualization.utils import GraphConst from msprobe.core.common.log import logger from msprobe.core.common.const import Const - - -MAX_RECUR_LEVEL = 100 +from msprobe.core.common.decorator import recursion_depth_decorator class Graph: @@ -31,6 +29,10 @@ class Graph: self.root = self.get_node(model_name) self.data_path = data_path self.dump_data = dump_data + self.data_source = GraphConst.JSON_NPU_KEY + self.step = 0 + self.rank = 0 + self.compare_mode = GraphConst.SUMMARY_COMPARE def __str__(self): infos = [f'{str(self.node_map.get(node_id))}' for node_id in self.node_map] @@ -67,22 +69,16 @@ class Graph: ancestors_b = node_b.get_ancestors() return node_b, ancestors_n, ancestors_b - @staticmethod - def fuzzy_match(node_n, node_b): - if not node_n or not node_b or not node_n.fuzzy_eq(node_b): + def fuzzy_match(node_n, node_b, check_shape=True): + if not node_n or not node_b: + return None, [], [] + if check_shape and not node_n.fuzzy_eq(node_b): return None, [], [] ancestors_n = node_n.get_ancestors() ancestors_b = node_b.get_ancestors() return node_b, ancestors_n, ancestors_b - @staticmethod - def dfs(node, result): - info = node.to_dict() - result[node.id] = info - for subnode in node.subnodes: - Graph.dfs(subnode, result) - @staticmethod def split_nodes_by_micro_step(nodes): """ @@ -127,6 +123,25 @@ class Graph: result[micro_step].append(node) return result + def get_sorted_nodes(self): + """ + 通过深度优先遍历graph,获得排过序的node列表 + """ + visited = set() + order = [] + + @recursion_depth_decorator('msprobe.visualization.graph.graph.Graph.get_nodes_order.visit', max_depth=500) + def visit(node): + if node.id in visited: + return + visited.add(node.id) + for sub_node in node.subnodes: + visit(sub_node) + order.append(node) + + visit(self.root) + return order + def add_node(self, node_op, node_id, up_node=None, id_accumulation=False): """ 在graph中进行节点的添加 @@ -157,19 +172,6 @@ class Graph: """ return self.node_map.get(node_id, None) - def to_dict(self): - """ - 用于数据输出 - """ - result = {} - result[GraphConst.JSON_ROOT_KEY] = self.root.id if self.root else 'None' - result[GraphConst.JSON_DATA_KEY] = self.data_path - result[GraphConst.JSON_NODE_KEY] = {} - for node_id in self.node_map: - info = self.node_map.get(node_id).to_dict() - result[GraphConst.JSON_NODE_KEY][node_id] = info - return result - def paging_by_micro_step(self, graph_other=None): """ 给graph首层节点增加micro step标记,供前端分页展示,有助于在处理大规模图数据时进行优化和管理 @@ -179,6 +181,15 @@ class Graph: graph_other: 可选参数,另一个graph Returns: 分批的数量 """ + + @recursion_depth_decorator( + 'msprobe.visualization.graph.graph.Graph.paging_by_micro_step.propagate_micro_step_id', max_depth=500) + def propagate_micro_step_id(node): + if node.upnode is not None and node.micro_step_id is None: + node.micro_step_id = node.upnode.micro_step_id + for sub_node in node.subnodes: + propagate_micro_step_id(sub_node) + batches_n = Graph.split_nodes_by_micro_step(self.root.subnodes) for batch_number, nodes in batches_n.items(): for node in nodes: @@ -188,6 +199,7 @@ class Graph: node_other = graph_other.get_node(node.matched_node_link[-1]) if node_other: node_other.micro_step_id = batch_number + propagate_micro_step_id(self.root) # 遍历graph_other根节点下的所有子节点,确保未匹配节点也有micro_step_id if graph_other: for node in graph_other.root.subnodes: @@ -197,6 +209,7 @@ class Graph: except ValueError: micro_step_id = 0 node.micro_step_id = micro_step_id + propagate_micro_step_id(graph_other.root) return len(batches_n) def overflow_check(self): diff --git a/debug/accuracy_tools/msprobe/visualization/graph/node_op.py b/debug/accuracy_tools/msprobe/visualization/graph/node_op.py index 33bfa9cc2e34a0960c3ff236a1bd183a5753a0ab..85d7e65bc528298d596398e10e4f9b9d2a35882f 100644 --- a/debug/accuracy_tools/msprobe/visualization/graph/node_op.py +++ b/debug/accuracy_tools/msprobe/visualization/graph/node_op.py @@ -24,7 +24,6 @@ class NodeOp(Enum): function_api = 1 api_collection = 9 - @staticmethod def get_node_op(node_name: str): """ @@ -37,5 +36,5 @@ class NodeOp(Enum): pattern = op_patterns[index] if re.match(pattern, node_name): return op - logger.warning(f"Cannot parsing node_name {node_name} into NodeOp, default parsing as module.") + logger.warning(f"Cannot parse node_name {node_name} into NodeOp, default parsing as module.") return NodeOp.module diff --git a/debug/accuracy_tools/msprobe/visualization/graph_service.py b/debug/accuracy_tools/msprobe/visualization/graph_service.py index 75b0014c1c09abb8dfecf285fed5eed3063827a0..1f14aa94d8eb79c00e6cff46ab4cf19af94ce7ec 100644 --- a/debug/accuracy_tools/msprobe/visualization/graph_service.py +++ b/debug/accuracy_tools/msprobe/visualization/graph_service.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,90 +15,112 @@ import os import time -import json +from copy import deepcopy +from multiprocessing import cpu_count, Pool from msprobe.core.common.file_utils import (check_file_type, create_directory, FileChecker, check_file_or_directory_path, load_json) from msprobe.core.common.const import FileCheckConst, Const -from msprobe.core.common.utils import CompareException -from msprobe.core.overflow_check.checker import AnomalyDetector +from msprobe.core.common.utils import CompareException, get_dump_mode from msprobe.visualization.compare.graph_comparator import GraphComparator -from msprobe.visualization.utils import GraphConst, check_directory_content -from msprobe.visualization.builder.graph_builder import GraphBuilder, GraphExportConfig +from msprobe.visualization.utils import GraphConst, check_directory_content, SerializableArgs, load_parallel_param, \ + sort_rank_number_strings, check_whether_parallel_merge, validate_parallel_param, get_step_or_rank_int +from msprobe.visualization.builder.graph_builder import GraphBuilder, GraphExportConfig, GraphInfo, BuildGraphTaskInfo from msprobe.core.common.log import logger from msprobe.visualization.graph.node_colors import NodeColors from msprobe.core.compare.layer_mapping import generate_api_mapping_by_layer_mapping from msprobe.core.compare.utils import check_and_return_dir_contents +from msprobe.core.common.utils import detect_framework_by_dump_json from msprobe.visualization.graph.distributed_analyzer import DistributedAnalyzer +from msprobe.visualization.builder.graph_merger import GraphMerger +from msprobe.visualization.db_utils import post_process_db current_time = time.strftime("%Y%m%d%H%M%S") +build_output_db_name = f'build_{current_time}.vis.db' +compare_output_db_name = f'compare_{current_time}.vis.db' -def _compare_graph(input_param, args): - logger.info('Start building model graphs...') - # 对两个数据进行构图 - dump_path_n = input_param.get('npu_path') - dump_path_b = input_param.get('bench_path') - construct_path_n = FileChecker(os.path.join(dump_path_n, GraphConst.CONSTRUCT_FILE), - FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check() - construct_path_b = FileChecker(os.path.join(dump_path_b, GraphConst.CONSTRUCT_FILE), - FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check() - data_path_n = FileChecker(os.path.join(dump_path_n, GraphConst.DUMP_FILE), FileCheckConst.FILE, - FileCheckConst.READ_ABLE).common_check() - data_path_b = FileChecker(os.path.join(dump_path_b, GraphConst.DUMP_FILE), FileCheckConst.FILE, - FileCheckConst.READ_ABLE).common_check() - stack_path_n = FileChecker(os.path.join(dump_path_n, GraphConst.STACK_FILE), FileCheckConst.FILE, - FileCheckConst.READ_ABLE).common_check() - stack_path_b = FileChecker(os.path.join(dump_path_b, GraphConst.STACK_FILE), FileCheckConst.FILE, - FileCheckConst.READ_ABLE).common_check() - graph_n = GraphBuilder.build(construct_path_n, data_path_n, stack_path_n, complete_stack=args.complete_stack) - graph_b = GraphBuilder.build(construct_path_b, data_path_b, stack_path_b, complete_stack=args.complete_stack) - logger.info('Model graphs built successfully, start Comparing graphs...') - # 基于graph、stack和data进行比较 +def _compare_graph(graph_n: GraphInfo, graph_b: GraphInfo, input_param, args): dump_path_param = { - 'npu_json_path': data_path_n, - 'bench_json_path': data_path_b, - 'stack_json_path': stack_path_n, + 'npu_json_path': graph_n.data_path, + 'bench_json_path': graph_b.data_path, + 'stack_json_path': graph_n.stack_path, 'is_print_compare_log': input_param.get("is_print_compare_log", True) } - mapping_dict = None + mapping_dict = {} if args.layer_mapping: - yaml_path = FileChecker(args.layer_mapping, FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check() try: - mapping_dict = generate_api_mapping_by_layer_mapping(data_path_n, data_path_b, yaml_path) + mapping_dict = generate_api_mapping_by_layer_mapping(graph_n.data_path, graph_b.data_path, + args.layer_mapping) except Exception: logger.warning('The layer mapping file parsing failed, please check file format, mapping is not effective.') - graph_comparator = GraphComparator([graph_n, graph_b], dump_path_param, args, mapping_dict=mapping_dict) + is_cross_framework = detect_framework_by_dump_json(graph_n.data_path) != \ + detect_framework_by_dump_json(graph_b.data_path) + if is_cross_framework and not args.layer_mapping: + logger.error('The cross_frame graph comparison failed. ' + 'Please specify -lm or --layer_mapping when performing cross_frame graph comparison.') + raise CompareException(CompareException.CROSS_FRAME_ERROR) + + graph_comparator = GraphComparator([graph_n.graph, graph_b.graph], dump_path_param, args, is_cross_framework, + mapping_dict=mapping_dict) graph_comparator.compare() - micro_steps = graph_n.paging_by_micro_step(graph_b) + return graph_comparator + + +def _compare_graph_result(input_param, args): + logger.info('Start building model graphs...') + # 对两个数据进行构图 + graph_n = _build_graph_info(input_param.get('npu_path'), args) + graph_b = _build_graph_info(input_param.get('bench_path'), args) + logger.info('Model graphs built successfully, start comparing graphs...') + # 基于graph、stack和data进行比较 + graph_comparator = _compare_graph(graph_n, graph_b, input_param, args) + # 增加micro step标记 + micro_steps = graph_n.graph.paging_by_micro_step(graph_b.graph) # 开启溢出检测 if args.overflow_check: - graph_n.overflow_check() - graph_b.overflow_check() + graph_n.graph.overflow_check() + graph_b.graph.overflow_check() - return CompareGraphResult(graph_n, graph_b, graph_comparator, micro_steps) + return CompareGraphResult(graph_n.graph, graph_b.graph, graph_comparator, micro_steps) -def _export_compare_graph_result(args, graphs, graph_comparator, micro_steps, - output_file_name=f'compare_{current_time}.vis'): - create_directory(args.output_path) - output_path = os.path.join(args.output_path, output_file_name) +def _export_compare_graph_result(args, result): + graphs = [result.graph_n, result.graph_b] + graph_comparator = result.graph_comparator + micro_steps = result.micro_steps + logger.info(f'Start exporting compare graph result, file name: {compare_output_db_name}...') + output_db_path = os.path.join(args.output_path, compare_output_db_name) task = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(graph_comparator.ma.compare_mode) export_config = GraphExportConfig(graphs[0], graphs[1], graph_comparator.ma.get_tool_tip(), NodeColors.get_node_colors(graph_comparator.ma.compare_mode), micro_steps, task, - args.overflow_check) - GraphBuilder.to_json(output_path, export_config) - logger.info(f'Model graphs compared successfully, the result file is saved in {output_path}') - - -def _build_graph(dump_path, args): - logger.info('Start building model graph...') + args.overflow_check, graph_comparator.ma.compare_mode, result.step, result.rank, + args.step_list if hasattr(args, 'step_list') else [0], + args.rank_list if hasattr(args, 'rank_list') else [0]) + try: + GraphBuilder.to_db(output_db_path, export_config) + logger.info(f'Exporting compare graph result successfully, the result file is saved in {output_db_path}') + return '' + except RuntimeError as e: + logger.error(f'Failed to export compare graph result, file: {compare_output_db_name}, error: {e}') + return compare_output_db_name + + +def _build_graph_info(dump_path, args, graph=None): construct_path = FileChecker(os.path.join(dump_path, GraphConst.CONSTRUCT_FILE), FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check() data_path = FileChecker(os.path.join(dump_path, GraphConst.DUMP_FILE), FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check() stack_path = FileChecker(os.path.join(dump_path, GraphConst.STACK_FILE), FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check() - graph = GraphBuilder.build(construct_path, data_path, stack_path, complete_stack=args.complete_stack) + if not graph: + graph = GraphBuilder.build(construct_path, data_path, stack_path) + return GraphInfo(graph, construct_path, data_path, stack_path) + + +def _build_graph_result(dump_path, args): + logger.info('Start building model graphs...') + graph = _build_graph_info(dump_path, args).graph + # 增加micro step标记 micro_steps = graph.paging_by_micro_step() # 开启溢出检测 if args.overflow_check: @@ -106,15 +128,118 @@ def _build_graph(dump_path, args): return BuildGraphResult(graph, micro_steps) -def _export_build_graph_result(out_path, graph, micro_steps, overflow_check, - output_file_name=f'build_{current_time}.vis'): - create_directory(out_path) - output_path = os.path.join(out_path, output_file_name) - GraphBuilder.to_json(output_path, GraphExportConfig(graph, micro_steps=micro_steps, overflow_check=overflow_check)) - logger.info(f'Model graph built successfully, the result file is saved in {output_path}') +def _run_build_graph_compare(input_param, args, nr, br): + logger.info(f'Start building graph for {nr}...') + graph_n = _build_graph_info(input_param.get('npu_path'), args) + graph_b = _build_graph_info(input_param.get('bench_path'), args) + logger.info(f'Building graph for {nr} finished.') + return BuildGraphTaskInfo(graph_n, graph_b, nr, br, current_time) + + +def _run_build_graph_single(dump_ranks_path, rank, step, args): + logger.info(f'Start building graph for {rank}...') + dump_path = os.path.join(dump_ranks_path, rank) + result = _build_graph_result(dump_path, args) + if rank != Const.RANK: + result.rank = get_step_or_rank_int(rank, True) + logger.info(f'Building graph for step: {step}, rank: {rank} finished.') + return result + + +def _run_graph_compare(graph_task_info, input_param, args): + logger.info(f'Start comparing data for {graph_task_info.npu_rank}...') + graph_n = graph_task_info.graph_info_n + graph_b = graph_task_info.graph_info_b + nr = graph_task_info.npu_rank + graph_comparator = _compare_graph(graph_n, graph_b, input_param, args) + micro_steps = graph_n.graph.paging_by_micro_step(graph_b.graph) + # 开启溢出检测 + if args.overflow_check: + graph_n.graph.overflow_check() + graph_b.graph.overflow_check() + graph_result = CompareGraphResult(graph_n.graph, graph_b.graph, graph_comparator, micro_steps) + if nr != Const.RANK: + graph_result.rank = get_step_or_rank_int(nr, True) + logger.info(f'Comparing data for {graph_task_info.npu_rank} finished.') + return graph_result + + +def _export_build_graph_result(args, result): + out_path = args.output_path + graph = result.graph + micro_steps = result.micro_steps + overflow_check = args.overflow_check + logger.info(f'Start exporting graph for {build_output_db_name}...') + output_db_path = os.path.join(out_path, build_output_db_name) + config = GraphExportConfig(graph, micro_steps=micro_steps, overflow_check=overflow_check, rank=result.rank, + step=result.step, rank_list=args.rank_list if hasattr(args, 'rank_list') else [0], + step_list=args.step_list if hasattr(args, 'step_list') else [0]) + try: + GraphBuilder.to_db(output_db_path, config) + logger.info(f'Model graph exported successfully, the result file is saved in {output_db_path}') + return None + except RuntimeError as e: + logger.error(f'Failed to export model graph, file: {build_output_db_name}, error: {e}') + return build_output_db_name + + +def is_real_data_compare(input_param, npu_ranks, bench_ranks): + dump_rank_n = input_param.get('npu_path') + dump_rank_b = input_param.get('bench_path') + has_real_data = False + for nr, br in zip(npu_ranks, bench_ranks): + dump_path_param = { + 'npu_json_path': FileChecker(os.path.join(dump_rank_n, nr, GraphConst.DUMP_FILE), FileCheckConst.FILE, + FileCheckConst.READ_ABLE).common_check(), + 'bench_json_path': FileChecker(os.path.join(dump_rank_b, br, GraphConst.DUMP_FILE), FileCheckConst.FILE, + FileCheckConst.READ_ABLE).common_check() + } + has_real_data |= get_dump_mode(dump_path_param) == Const.ALL + return has_real_data + + +def _mp_compare(input_param, serializable_args, nr, br): + graph_task_info = _run_build_graph_compare(input_param, serializable_args, nr, br) + return _run_graph_compare(graph_task_info, input_param, serializable_args) def _compare_graph_ranks(input_param, args, step=None): + with Pool(processes=max(int((cpu_count() + 1) // 4), 1)) as pool: + def err_call(err): + logger.error(f'Error occurred while comparing graph ranks: {err}') + try: + pool.close() + except OSError as e: + logger.error(f'Error occurred while terminating the pool: {e}') + + serializable_args = SerializableArgs(args) + # 暂存所有rank的graph,用于匹配rank间的分布式节点 + compare_graph_results = _get_compare_graph_results(input_param, serializable_args, step, pool, err_call) + + serializable_args.rank_list = [result.rank for result in compare_graph_results] + + # 匹配rank间的分布式节点 + if len(compare_graph_results) > 1: + DistributedAnalyzer({obj.rank: obj.graph_n for obj in compare_graph_results}, + args.overflow_check).distributed_match() + DistributedAnalyzer({obj.rank: obj.graph_b for obj in compare_graph_results}, + args.overflow_check).distributed_match() + + export_res_task_list = [] + create_directory(args.output_path) + for result in compare_graph_results: + export_res_task_list.append(pool.apply_async(_export_compare_graph_result, + args=(serializable_args, result), + error_callback=err_call)) + export_res_list = [res.get() for res in export_res_task_list] + if any(export_res_list): + failed_names = list(filter(lambda x: x, export_res_list)) + logger.error(f'Unable to export compare graph results: {", ".join(failed_names)}.') + else: + logger.info('Successfully exported compare graph results.') + + +def _get_compare_graph_results(input_param, serializable_args, step, pool, err_call): dump_rank_n = input_param.get('npu_path') dump_rank_b = input_param.get('bench_path') npu_ranks = sorted(check_and_return_dir_contents(dump_rank_n, Const.RANK)) @@ -123,32 +248,34 @@ def _compare_graph_ranks(input_param, args, step=None): logger.error('The number of ranks in the two runs are different. Unable to match the ranks.') raise CompareException(CompareException.INVALID_PATH_ERROR) compare_graph_results = [] - for nr, br in zip(npu_ranks, bench_ranks): - logger.info(f'Start processing data for {nr}...') - input_param['npu_path'] = os.path.join(dump_rank_n, nr) - input_param['bench_path'] = os.path.join(dump_rank_b, br) - output_file_name = f'compare_{step}_{nr}_{current_time}.vis' if step else f'compare_{nr}_{current_time}.vis' - result = _compare_graph(input_param, args) - result.output_file_name = output_file_name - if nr != Const.RANK: - try: - result.rank = int(nr.replace(Const.RANK, "")) - except Exception as e: - logger.error('The folder name format is incorrect, expected rank+number.') - raise CompareException(CompareException.INVALID_PATH_ERROR) from e - # 暂存所有rank的graph,用于匹配rank间的分布式节点 - compare_graph_results.append(result) - - # 匹配rank间的分布式节点 - if len(compare_graph_results) > 1: - DistributedAnalyzer({obj.rank: obj.graph_n for obj in compare_graph_results}, - args.overflow_check).distributed_match() - DistributedAnalyzer({obj.rank: obj.graph_b for obj in compare_graph_results}, - args.overflow_check).distributed_match() - - for result in compare_graph_results: - _export_compare_graph_result(args, [result.graph_n, result.graph_b], result.graph_comparator, - result.micro_steps, output_file_name=result.output_file_name) + if is_real_data_compare(input_param, npu_ranks, bench_ranks): + mp_task_dict = {} + for nr, br in zip(npu_ranks, bench_ranks): + input_param['npu_path'] = os.path.join(dump_rank_n, nr) + input_param['bench_path'] = os.path.join(dump_rank_b, br) + build_key = f'{step}_{nr}' if step else f'{nr}' + input_param_copy = deepcopy(input_param) + mp_task_dict[build_key] = pool.apply_async(_run_build_graph_compare, + args=(input_param_copy, serializable_args, nr, br), + error_callback=err_call) + + mp_res_dict = {k: v.get() for k, v in mp_task_dict.items()} + for mp_res in mp_res_dict.values(): + compare_graph_results.append(_run_graph_compare(mp_res, input_param, serializable_args)) + else: + compare_graph_tasks = [] + for nr, br in zip(npu_ranks, bench_ranks): + input_param['npu_path'] = os.path.join(dump_rank_n, nr) + input_param['bench_path'] = os.path.join(dump_rank_b, br) + input_param_copy = deepcopy(input_param) + compare_graph_tasks.append(pool.apply_async(_mp_compare, + args=(input_param_copy, serializable_args, nr, br), + error_callback=err_call)) + compare_graph_results = [task.get() for task in compare_graph_tasks] + if step is not None: + for result in compare_graph_results: + result.step = get_step_or_rank_int(step) + return compare_graph_results def _compare_graph_steps(input_param, args): @@ -159,98 +286,202 @@ def _compare_graph_steps(input_param, args): bench_steps = sorted(check_and_return_dir_contents(dump_step_b, Const.STEP)) if npu_steps != bench_steps: - logger.error('The number of steps in the two runs are different. Unable to match the steps.') + logger.error('The number of steps in the two runs is different. Unable to match the steps.') raise CompareException(CompareException.INVALID_PATH_ERROR) + args.step_list = sorted([get_step_or_rank_int(step) for step in npu_steps]) + for folder_step in npu_steps: logger.info(f'Start processing data for {folder_step}...') input_param['npu_path'] = os.path.join(dump_step_n, folder_step) input_param['bench_path'] = os.path.join(dump_step_b, folder_step) - _compare_graph_ranks(input_param, args, step=folder_step) + _compare_graph_ranks(input_param, args, step=folder_step) if not args.parallel_merge \ + else _compare_graph_ranks_parallel(input_param, args, step=folder_step) def _build_graph_ranks(dump_ranks_path, args, step=None): - ranks = sorted(check_and_return_dir_contents(dump_ranks_path, Const.RANK)) - build_graph_results = [] - for rank in ranks: - logger.info(f'Start processing data for {rank}...') - dump_path = os.path.join(dump_ranks_path, rank) - output_file_name = f'build_{step}_{rank}_{current_time}.vis' if step else f'build_{rank}_{current_time}.vis' - result = _build_graph(dump_path, args) - result.output_file_name = output_file_name - if rank != Const.RANK: + ranks = sort_rank_number_strings(check_and_return_dir_contents(dump_ranks_path, Const.RANK)) + serializable_args = SerializableArgs(args) + with Pool(processes=max(int((cpu_count() + 1) // 4), 1)) as pool: + def err_call(err): + logger.error(f'Error occurred while comparing graph ranks: {err}') try: - result.rank = int(rank.replace(Const.RANK, "")) - except Exception as e: - logger.error('The folder name format is incorrect, expected rank+number.') - raise CompareException(CompareException.INVALID_PATH_ERROR) from e - build_graph_results.append(result) - - if len(build_graph_results) > 1: - DistributedAnalyzer({obj.rank: obj.graph for obj in build_graph_results}, - args.overflow_check).distributed_match() - - for result in build_graph_results: - _export_build_graph_result(args.output_path, result.graph, result.micro_steps, args.overflow_check, - result.output_file_name) + pool.close() + except OSError as e: + logger.error(f'Error occurred while terminating the pool: {e}') + + build_graph_tasks = [] + for rank in ranks: + build_graph_tasks.append(pool.apply_async(_run_build_graph_single, + args=(dump_ranks_path, rank, step, serializable_args), + error_callback=err_call)) + build_graph_results = [task.get() for task in build_graph_tasks] + + if step is not None: + for result in build_graph_results: + result.step = get_step_or_rank_int(step) + + if args.parallel_params: + validate_parallel_param(args.parallel_params[0], dump_ranks_path) + build_graph_results = GraphMerger(build_graph_results, args.parallel_params[0]).merge_graph() + + if len(build_graph_results) > 1 and not args.parallel_merge: + DistributedAnalyzer({obj.rank: obj.graph for obj in build_graph_results}, + args.overflow_check).distributed_match() + + create_directory(args.output_path) + export_build_graph_tasks = [] + serializable_args.rank_list = [result.rank for result in build_graph_results] + for result in build_graph_results: + export_build_graph_tasks.append(pool.apply_async(_export_build_graph_result, + args=(serializable_args, result), + error_callback=err_call)) + export_build_graph_result = [task.get() for task in export_build_graph_tasks] + if any(export_build_graph_result): + failed_names = list(filter(lambda x: x, export_build_graph_result)) + logger.error(f'Unable to export build graph results: {failed_names}.') + else: + logger.info(f'Successfully exported build graph results.') def _build_graph_steps(dump_steps_path, args): steps = sorted(check_and_return_dir_contents(dump_steps_path, Const.STEP)) + args.step_list = sorted([get_step_or_rank_int(step) for step in steps]) + for step in steps: logger.info(f'Start processing data for {step}...') dump_ranks_path = os.path.join(dump_steps_path, step) _build_graph_ranks(dump_ranks_path, args, step) +def _compare_and_export_graph(graph_task_info, input_param, args): + result = _run_graph_compare(graph_task_info, input_param, args) + return _export_compare_graph_result(args, result) + + +def _compare_graph_ranks_parallel(input_param, args, step=None): + args.fuzzy_match = True + npu_path = input_param.get('npu_path') + bench_path = input_param.get('bench_path') + ranks_n = sort_rank_number_strings(check_and_return_dir_contents(npu_path, Const.RANK)) + ranks_b = sort_rank_number_strings(check_and_return_dir_contents(bench_path, Const.RANK)) + parallel_params = load_parallel_param(input_param) + if len(parallel_params) != 2: + raise RuntimeError('Parallel params error in compare graph!') + validate_parallel_param(parallel_params[0], npu_path) + validate_parallel_param(parallel_params[1], bench_path, '[Bench]') + serializable_args = SerializableArgs(args) + + with Pool(processes=max(int((cpu_count() + 1) // 4), 1)) as pool: + def err_call(err): + logger.error(f'Error occurred while comparing graph ranks: {err}') + try: + pool.close() + except OSError as e: + logger.error(f'Error occurred while terminating the pool: {e}') + + # 1.并行构图 + build_graph_tasks_n = [] + build_graph_tasks_b = [] + for rank in ranks_n: + build_graph_tasks_n.append(pool.apply_async(_run_build_graph_single, + args=(npu_path, rank, step, serializable_args), + error_callback=err_call)) + for rank in ranks_b: + build_graph_tasks_b.append(pool.apply_async(_run_build_graph_single, + args=(bench_path, rank, step, serializable_args), + error_callback=err_call)) + graph_results_n = [task.get() for task in build_graph_tasks_n] + graph_results_b = [task.get() for task in build_graph_tasks_b] + + # 2.图合并 + build_graph_results_n = GraphMerger(graph_results_n, parallel_params[0]).merge_graph() + build_graph_results_b = GraphMerger(graph_results_b, parallel_params[1], True).merge_graph() + if len(build_graph_results_n) != len(build_graph_results_b): + raise RuntimeError(f'Parallel merge failed because the dp of npu: {len(build_graph_results_n)} ' + f'is inconsistent with that of bench: {len(build_graph_results_b)}!') + serializable_args.rank_list = [result.rank for result in build_graph_results_n] + # 3.并行图比对和输出 + export_res_task_list = [] + create_directory(args.output_path) + for i, result_n in enumerate(build_graph_results_n): + graph_n = result_n.graph + graph_b = build_graph_results_b[i].graph + graph_task_info = BuildGraphTaskInfo( + _build_graph_info(os.path.join(npu_path, f'rank{graph_n.root.rank}'), args, graph_n), + _build_graph_info(os.path.join(bench_path, f'rank{graph_b.root.rank}'), args, graph_b), + f'rank{graph_n.root.rank}', f'rank{graph_b.root.rank}', current_time) + export_res_task_list.append(pool.apply_async(_compare_and_export_graph, + args=(graph_task_info, input_param, serializable_args), + error_callback=err_call)) + export_res_list = [res.get() for res in export_res_task_list] + if any(export_res_list): + failed_names = list(filter(lambda x: x, export_res_list)) + logger.error(f'Unable to export compare graph results: {", ".join(failed_names)}.') + else: + logger.info('Successfully exported compare graph results.') + + def _graph_service_parser(parser): parser.add_argument("-i", "--input_path", dest="input_path", type=str, help=" The compare input path, a dict json.", required=True) parser.add_argument("-o", "--output_path", dest="output_path", type=str, help=" The compare task result out path.", required=True) - parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str, + parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str, nargs='?', const=True, help=" The layer mapping file path.", required=False) parser.add_argument("-oc", "--overflow_check", dest="overflow_check", action="store_true", help=" whether open overflow_check for graph.", required=False) parser.add_argument("-f", "--fuzzy_match", dest="fuzzy_match", action="store_true", help=" Whether to perform a fuzzy match on the api name.", required=False) - parser.add_argument("-cs", "--complete_stack", dest="complete_stack", action="store_true", - help=" Whether to use complete stack information.", required=False) def _graph_service_command(args): input_param = load_json(args.input_path) npu_path = input_param.get("npu_path") bench_path = input_param.get("bench_path") + args.parallel_merge = check_whether_parallel_merge(input_param) + args.parallel_params = load_parallel_param(input_param) if args.parallel_merge else None check_file_or_directory_path(npu_path, isdir=True) if bench_path: check_file_or_directory_path(bench_path, isdir=True) if check_file_type(npu_path) == FileCheckConst.DIR and not bench_path: content = check_directory_content(npu_path) + output_db_path = os.path.join(args.output_path, build_output_db_name) if content == GraphConst.RANKS: _build_graph_ranks(npu_path, args) elif content == GraphConst.STEPS: _build_graph_steps(npu_path, args) else: - result = _build_graph(npu_path, args) - _export_build_graph_result(args.output_path, result.graph, result.micro_steps, args.overflow_check) + result = _build_graph_result(npu_path, args) + create_directory(args.output_path) + file_name = _export_build_graph_result(args, result) + if file_name: + logger.error('Failed to export model build graph.') elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR: content_n = check_directory_content(npu_path) content_b = check_directory_content(bench_path) + output_db_path = os.path.join(args.output_path, compare_output_db_name) if content_n != content_b: raise ValueError('The directory structures of npu_path and bench_path are inconsistent.') if content_n == GraphConst.RANKS: - _compare_graph_ranks(input_param, args) + if args.parallel_merge: + _compare_graph_ranks_parallel(input_param, args) + else: + _compare_graph_ranks(input_param, args) elif content_n == GraphConst.STEPS: _compare_graph_steps(input_param, args) else: - result = _compare_graph(input_param, args) - _export_compare_graph_result(args, [result.graph_n, result.graph_b], - result.graph_comparator, result.micro_steps) + result = _compare_graph_result(input_param, args) + create_directory(args.output_path) + file_name = _export_compare_graph_result(args, result) + if file_name: + logger.error('Failed to export model compare graph.') else: logger.error("The npu_path or bench_path should be a folder.") raise CompareException(CompareException.INVALID_COMPARE_MODE) + # 所有数据输出db结束后,添加索引,修改权限 + post_process_db(output_db_path) def _pt_graph_service_parser(parser): @@ -270,18 +501,18 @@ def _ms_graph_service_command(args): class CompareGraphResult: - def __init__(self, graph_n, graph_b, graph_comparator, micro_steps, rank=0, output_file_name=''): + def __init__(self, graph_n, graph_b, graph_comparator, micro_steps, rank=0, step=0): self.graph_n = graph_n self.graph_b = graph_b self.graph_comparator = graph_comparator self.micro_steps = micro_steps self.rank = rank - self.output_file_name = output_file_name + self.step = step class BuildGraphResult: - def __init__(self, graph, micro_steps, rank=0, output_file_name=''): + def __init__(self, graph, micro_steps=0, rank=0, step=0): self.graph = graph self.micro_steps = micro_steps self.rank = rank - self.output_file_name = output_file_name + self.step = step diff --git a/debug/accuracy_tools/msprobe/visualization/utils.py b/debug/accuracy_tools/msprobe/visualization/utils.py index 623bcd11c45f1ff8e9c283d30a982af239706ce4..9092ead53212070fe248e5ae870cce1dc6bbd644 100644 --- a/debug/accuracy_tools/msprobe/visualization/utils.py +++ b/debug/accuracy_tools/msprobe/visualization/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,9 +16,12 @@ import os import re import json +import pickle from msprobe.core.common.file_utils import FileOpen from msprobe.core.common.const import CompareConst, Const -from msprobe.core.compare.acc_compare import Comparator, ModeConfig +from msprobe.core.common.log import logger +from msprobe.core.common.exceptions import MsprobeException +from msprobe.core.compare.utils import check_and_return_dir_contents def load_json_file(file_path): @@ -42,23 +45,6 @@ def load_data_json_file(file_path): return load_json_file(file_path).get(GraphConst.DATA_KEY, {}) -def save_json_file(file_path, data): - """ - 保存json文件 - """ - with FileOpen(file_path, 'w') as f: - f.write(json.dumps(data, indent=4)) - - -def get_csv_df(stack_mode, csv_data, compare_mode): - """ - 调用acc接口写入csv - """ - dump_mode = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode) - mode_config = ModeConfig(stack_mode=stack_mode, dump_mode=dump_mode) - return Comparator(mode_config).make_result_table(csv_data) - - def str2float(percentage_str): """ 百分比字符串转换转换为浮点型 @@ -73,12 +59,19 @@ def str2float(percentage_str): return 0 -def is_integer(s): +def get_step_or_rank_int(x: str, is_rank=False): + """ + 获取字符串rank{int}或者step{int}中的int值,如果x=rank或step,返回0 + """ + if x in [Const.RANK, Const.STEP]: + return 0 + description = Const.RANK if is_rank else Const.STEP try: - int(s) - return True - except Exception: - return False + x_int = int(x.replace(Const.RANK, "")) if is_rank else int(x.replace(Const.STEP, "")) + except Exception as e: + logger.error(f'The folder name format is incorrect, expected {description}+number, such as rank0, step1, etc.') + raise RuntimeError from e + return x_int def check_directory_content(input_path): @@ -126,6 +119,83 @@ def check_directory_content(input_path): "all rank{number} named folders (such as rank0), or all files.") +def extract_rank_number(rank_str): + try: + return int(rank_str[4:]) + except ValueError: + return 0 + + +def sort_rank_number_strings(rank_number_strings): + sorted_list = sorted(rank_number_strings, key=extract_rank_number) + return sorted_list + + +def check_whether_parallel_merge(input_param): + parallel_merge = input_param.get("parallel_merge") + if not isinstance(parallel_merge, dict) or not parallel_merge: + return False + if not parallel_merge.get('npu'): + return False + return True + + +def load_parallel_param(input_param): + parallel_merge = input_param.get("parallel_merge", {}) + config_n = parallel_merge.get('npu', {}) + config_b = parallel_merge.get('bench', {}) + param_n = ParallelParam(config_n.get('rank_size'), config_n.get('tp'), config_n.get('pp'), config_n.get('vpp', 1), + config_n.get('order', 'tp-cp-ep-dp-pp')) + param_b = ParallelParam(config_b.get('rank_size'), config_b.get('tp'), config_b.get('pp'), config_b.get('vpp', 1), + config_b.get('order', 'tp-cp-ep-dp-pp')) + return (param_n,) if not config_b else (param_n, param_b) + + +def validate_parallel_param(parallel_param, dump_path, log_prefix='[NPU]'): + params = [parallel_param.tp, parallel_param.pp, parallel_param.rank_size] + ranks = check_and_return_dir_contents(dump_path, Const.RANK) + if len(ranks) != parallel_param.rank_size: + logger.error(f'{log_prefix} The parallel param "rank_size" error, ' + f'you set {parallel_param.rank_size} but expected {len(ranks)}.') + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + if any(x is None for x in params): + logger.error(f'{log_prefix} The parallel params "tp/pp/rank_size" must not be null!') + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + if any(x <= 0 for x in params): + logger.error(f'{log_prefix} The parallel params "tp/pp/vpp/rank_size" must be greater than 0!') + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + if parallel_param.tp > parallel_param.rank_size: + logger.error(f'{log_prefix} The parallel param "tp" must be less than or equal to "rank_size"!') + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + if parallel_param.pp > parallel_param.rank_size: + logger.error(f'{log_prefix} The parallel param "pp" must be less than or equal to "rank_size"!') + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + if parallel_param.rank_size % parallel_param.tp != 0: + logger.error(f'{log_prefix} The parallel param "rank_size" must be divisible by "tp"!') + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + if parallel_param.rank_size % parallel_param.pp != 0: + logger.error(f'{log_prefix} The parallel param "rank_size" must be divisible by "pp"!') + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + if parallel_param.tp * parallel_param.pp > parallel_param.rank_size: + logger.error(f'{log_prefix} The parallel params "tp * pp" must be less than or equal to "rank_size"!') + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + if parallel_param.vpp > 1 and parallel_param.pp < 2: + logger.error(f'{log_prefix} When configuring the parallel param "vpp", the "pp" param must be greater than 1!') + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + if not isinstance(parallel_param.order, str): + logger.error(f'{log_prefix} The parallel params "order" must be of string type!') + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + + +class ParallelParam: + def __init__(self, rank_size, tp, pp, vpp=1, order='tp-cp-ep-dp-pp'): + self.rank_size = rank_size + self.tp = tp + self.pp = pp + self.vpp = vpp + self.order = order + + class ToolTip: MAX_DIFF = 'NPU与标杆API统计信息比对,最大值的差值' MIN_DIFF = 'NPU与标杆API统计信息比对,最小值的差值' @@ -143,14 +213,12 @@ class ToolTip: '当最大相对误差越接近0表示其计算的误差越小。' '当dump数据中存在0或Nan时,比对结果中最大相对误差则出现inf或Nan的情况,属于正常现象' ) - SMALL_VALUE_TIP = '{}, 由于{}小于{}, 建议不参考此相对误差,请参考绝对误差' class GraphConst: CONSTRUCT_FILE = 'construct.json' DUMP_FILE = 'dump.json' STACK_FILE = 'stack.json' - GRAPH_FILE = 'graph.vis' ERROR_KEY = 'error_key' SUMMARY_COMPARE = 0 MD5_COMPARE = 1 @@ -164,35 +232,23 @@ class GraphConst: JSON_DATA_KEY = 'dump_data_dir' JSON_TASK_KEY = 'task' DATA_KEY = 'data' - REAL_DATA_TH = 0.1 - MAX_RELATIVE_ERR_TH = 0.5 ROUND_TH = 6 JSON_INDEX_KEY = 'precision_index' MATCHED_DISTRIBUTED = 'matched_distributed' OVERFLOW_LEVEL = 'overflow_level' MAX_INDEX_KEY = 1 MIN_INDEX_KEY = 0 - SUGGEST_KEY = 'text' - TAG_NA = 'na' - OUTPUT_INDEX_TWO = -2 - OUTPUT_INDEX_THREE = -3 - OUTPUT_MIN_LEN = 3 INPUT = '.input.' OUTPUT = '.output.' STR_MAX_LEN = 50 - SMALL_VALUE = 1e-3 - MD5_INDEX_LIST = [CompareConst.RESULT] - REAL_DATA_INDEX_LIST = [CompareConst.COSINE, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR, - CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] - SUMMARY_INDEX_LIST = [CompareConst.MAX_DIFF, CompareConst.MIN_DIFF, CompareConst.MEAN_DIFF, - CompareConst.NORM_DIFF, CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR, - CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR] - VALUE_INDEX_LIST = [Const.MAX, Const.MIN, Const.MEAN, Const.NORM] + MD5_INDEX_LIST = CompareConst.MD5_COMPARE_INDEX + [CompareConst.REQ_GRAD_CONSIST] + REAL_DATA_INDEX_LIST = CompareConst.ALL_COMPARE_INDEX + [CompareConst.REQ_GRAD_CONSIST] + SUMMARY_INDEX_LIST = CompareConst.SUMMARY_COMPARE_INDEX + [CompareConst.REQ_GRAD_CONSIST] APIS_BETWEEN_MODULES = 'Apis_Between_Modules' + APIS_BETWEEN_MODULES_ALL_RANKS = 'Apis_Between_Modules_All_Ranks' NULL = 'null' NONE = 'None' VALUE = 'value' - BRACE = '{}' DESCRIPTION = 'description' COLORS = 'Colors' MICRO_STEPS = 'MicroSteps' @@ -223,3 +279,31 @@ class GraphConst: OP = 'op' PEER = 'peer' GROUP_ID = 'group_id' + + UNCERTAINTY_THRESHOLD = 1e-6 + REDUCE_OPERATIONS = ['reduce_scatter', 'all_reduce'] + + IGNORE_PRECISION_INDEX = {'empty', 'empty_like', 'empty_with_format', 'new_empty_strided', 'new_empty', + 'empty_strided'} + VPP_CHUNK_0 = '0' + + +def is_serializable(obj): + """ + Check if an object is serializable + """ + try: + pickle.dumps(obj) + return True + except (pickle.PicklingError, AttributeError, TypeError): + return False + except Exception as e: + logger.error('Unexpected error occurred while pickling obj.') + raise RuntimeError('Unexpected error occurred while pickling obj.') from e + + +class SerializableArgs: + def __init__(self, args): + for k, v in vars(args).items(): + if is_serializable(v): + setattr(self, k, v) diff --git a/debug/accuracy_tools/setup.py b/debug/accuracy_tools/setup.py index 2da7fcf667765a841b9db1bbf5628fad5b1cf8a9..30a20dd4cc91a58f6e96c4cc6f14bca2849a9fe9 100644 --- a/debug/accuracy_tools/setup.py +++ b/debug/accuracy_tools/setup.py @@ -14,7 +14,7 @@ # limitations under the License. -__version__ = '1.2.2' +__version__ = '8.2.0' import subprocess import platform @@ -24,17 +24,19 @@ import setuptools INSTALL_REQUIRED = [ "wheel", "einops", - "numpy < 2.0", + "numpy >=1.23.0, < 2.0", "pandas >= 1.3.5, < 2.1", "pyyaml", "rich", "tqdm", - "openpyxl", - "pyopenssl", + "openpyxl >= 3.0.6", + "pyopenssl==24.2.1", "twisted", "matplotlib", "tensorboard", - "tabulate" + "tabulate", + "pwinput", + "psutil" ] EXCLUDE_PKGS = [ diff --git a/debug/resources/training_process.png b/debug/resources/training_process.png new file mode 100644 index 0000000000000000000000000000000000000000..e1cf2f20471624cd86edbf45444bb431086d6065 Binary files /dev/null and b/debug/resources/training_process.png differ diff --git a/dynolog_npu/README.md b/dynolog_npu/README.md deleted file mode 100644 index 9cc015e66c656c65fa48ad73a8246487a2016bef..0000000000000000000000000000000000000000 --- a/dynolog_npu/README.md +++ /dev/null @@ -1,148 +0,0 @@ -# Ascend Extension for dynolog - -## 安装方式 - -### 1. clone 代码 - -```bash -git clone https://gitee.com/ascend/mstt.git -``` - -### 2. 安装依赖 -dynolog的编译依赖,确保安装了以下依赖: -
采集模式global_batch_size单卡8卡
L0113GB97GB
225B194GB
337G291GB
225GB194GB
337GB291GB
L11440GB3.4TB
2720GB5.4TB
3960GB7.3TB
- - - - - - - - - - - - -
Language - Toolchain -
C++ - gcc 8.5.0+ -
Rust - Rust 1.58.1 (1.56+ required for clap dependency) -
- -- 安装rust - -```bash -curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh - -source $HOME/.cargo/env -``` - -- 安装ninja - -```bash -# debian -sudo apt-get install -y cmake ninja-build - -# centos -sudo yum install -y cmake ninja -``` - -### 3. 编译 - -默认编译生成dyno和dynolog二进制文件, -t参数可以支持将二进制文件打包成deb包或rpm包. - -```bash -# 编译dyno和dynolog二进制文件 -bash scripts/build.sh - -# 编译deb包, 当前支持amd64和aarch64平台, 默认为amd64, 编译aarch64平台需要修改third_party/dynolog/scripts/debian/control文件中的Architecture改为aarch64 -bash scripts/build.sh -t deb - -# 编译rpm包, 当前只支持amd64平台 -bash scripts/build.sh -t rpm -``` - -## 使用方式 - -### Profiler trace dump功能 -Profiler trace dump功能基于dynolog开发,实现类似于动态profiling的动态触发Ascend Torch Profiler采集profiling的功能。用户基于dyno CLI命令行可以动态触发指定节点的训练进程trace dump。 - -- 查看nputrace支持的命令和帮助 - -```bash -dyno nputrace --help -``` - -- nputrace使用方式 - -```bash -dyno nputrace [SUBCOMMANDS] --log-file -``` - -nputrace子命令支持的参数选项 - -| 子命令 | 参数类型 | 说明 | -|-------|-------|-------| -| record_shapes | action | 是否采集算子的InputShapes和InputTypes,设置参数采集,默认不采集 | -| profile_memory | action | 是否采集算子内存信息,设置参数采集,默认不采集 | -| with_stack | action | 是否采集Python调用栈,设置参数采集,默认不采集 | -| with_flops | action | 是否采集算子flops,设置参数采集,默认不采集 | -| with_modules | action | 是否采集modules层级的Python调用栈,设置参数采集,默认不采集 | -| analyse | action | 采集后是否自动解析,设置参数解析,默认不解析 | -| l2_cache | action | 是否采集L2 Cache数据,设置参数采集,默认不采集 | -| op_attr | action | 是否采集算子属性信息,设置参数采集,默认不采集 | -| data_simplification | String | 解析完成后是否数据精简,可选值范围[`true`, `false`],默认值`true` | -| activities | String | 控制CPU、NPU事件采集范围,可选值范围[`CPU,NPU`, `NPU,CPU`, `CPU`, `NPU`],默认值`CPU,NPU` | -| profiler_level | String | 控制profiler的采集等级,可选值范围[`Level_none`, `Level0`, `Level1`, `Level2`],默认值`Level0`| -| aic_metrics | String | AI Core的性能指标采集项,可选值范围[`AiCoreNone`, `PipeUtilization`, `ArithmeticUtilization`, `Memory`, `MemoryL0`, `ResourceConflictRatio`, `MemoryUB`, `L2Cache`, `MemoryAccess`],默认值`AiCoreNone`| -| export_type | String | profiler解析导出数据的类型,可选值范围[`Text`, `Db`],默认值`Text`| -| gc_detect_threshold | Option | GC检测阈值,单位ms,只采集超过阈值的GC事件。该参数为可选参数,默认不设置时不开启GC检测 | - -- nputrace示例命令 - -```bash -# 示例1:采集框架、CANN和device数据,同时采集完后自动解析以及解析完成不做数据精简,落盘路径为/tmp/profile_data -dyno nputrace --activities CPU,NPU --analyse --data_simplification false --log-file /tmp/profile_data - -# 示例2:只采集CANN和device数据,同时采集完后自动解析以及解析完成后开启数据精简,落盘路径为/tmp/profile_data -dyno nputrace --activities NPU --analyse --data_simplification true --log-file /tmp/profile_data - -# 示例3:只采集CANN和device数据,只采集不解析,落盘路径为/tmp/profile_data -dyno nputrace --activities NPU --log-file /tmp/profile_data -``` - -### NPU Monitor功能 -NPU Monitor基于MSPTI/MSTX能力开发,实现了轻量级在线监控能力,能够用于性能问题的初步定位。 - -```bash -dyno npu-monitor --help -``` - -- npu-monitor使用方式 - -```bash -dyno npu-monitor [SUBCOMMANDS] -``` - -npu-monitor子命令支持的参数选项 -| 子命令 | 参数类型 | 说明 | -|-------|-------|-------| -| npu_monitor_start | action | 开启性能监控,设置参数开启,默认不采集 | -| npu_monitor_stop | action | 停止性能监控,设置参数开启,默认不采集 | -| report_interval_s | int | 性能监控数据上报周期,单位s,需要在启动时设置。默认值60 | -| mspti_activity_kind | String | 性能监控数据上报数据类型,可以设置单个或多个,多个类型以逗号分隔,需要在启动时设置。可选值范围[`Marker`, `Kernel`, `API`, `Hccl`, `Memory`, `MemSet`, `MemCpy`] , 默认值`Marker`| - -- npu-monitor示例命令 - -```bash -# 示例1:开启性能监控,使用默认配置 -dyno npu-monitor --npu_monitor_start - -# 示例2:暂停性能监控 -dyno npu-monitor --npu_monitor_stop - -# 示例3:开启性能监控,上报周期30s, 上报数据类型Marker和Kernel -dyno npu-monitor --npu_monitor_start 30 --mspti_activity_kind Marker,Kernel -``` \ No newline at end of file diff --git a/dynolog_npu/plugin/Readme.md b/dynolog_npu/plugin/Readme.md deleted file mode 100644 index c59bfffad5aaac5383b407e3ff3d23ed126131f5..0000000000000000000000000000000000000000 --- a/dynolog_npu/plugin/Readme.md +++ /dev/null @@ -1,17 +0,0 @@ - - -# Build and Install npu-dynolog-plugin -``` -# install pybind11 -pip install pybind11 - -# build dynolog_npu_plugin wheel -python3 setup.py bdist_wheel -# install -pip install dist/{dynolog-npu-plugin-xxx.wheel} - -# example -import IPCMonitor -dyno_worker = IPCMonitor.PyDynamicMonitorProxy() -dyno_worker.init_dyno(0) -``` diff --git a/dynolog_npu/plugin/bindings.cpp b/dynolog_npu/plugin/bindings.cpp deleted file mode 100644 index c0cdaa4d577b3a76ec2d6f3eae4b426556a56532..0000000000000000000000000000000000000000 --- a/dynolog_npu/plugin/bindings.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include -#include "ipc_monitor/PyDynamicMonitorProxy.h" - -namespace py = pybind11; - -PYBIND11_MODULE(IPCMonitor, m) { - py::class_(m, "PyDynamicMonitorProxy") - .def(py::init<>()) - .def("init_dyno", &dynolog_npu::ipc_monitor::PyDynamicMonitorProxy::InitDyno, py::arg("npuId")) - .def("poll_dyno", &dynolog_npu::ipc_monitor::PyDynamicMonitorProxy::PollDyno); -} \ No newline at end of file diff --git a/dynolog_npu/plugin/ipc_monitor/DynoLogNpuMonitor.cpp b/dynolog_npu/plugin/ipc_monitor/DynoLogNpuMonitor.cpp deleted file mode 100644 index 940f5aae167f088361057fe2a7a389a76f5bb2b4..0000000000000000000000000000000000000000 --- a/dynolog_npu/plugin/ipc_monitor/DynoLogNpuMonitor.cpp +++ /dev/null @@ -1,36 +0,0 @@ -#include "DynoLogNpuMonitor.h" - -#include - -#include "utils.h" - -namespace dynolog_npu { -namespace ipc_monitor { - -bool DynoLogNpuMonitor::Init() -{ - if (isInitialized_) { - std::cout << "[WRARNING] DynoLog npu monitor already initialized" << std::endl; - return true; - } - bool res = ipcClient_.RegisterInstance(npuId_); - if (res) { - isInitialized_ = true; - std::cout << "[INFO] DynoLog npu monitor initialized success !" << std::endl; - } - return res; -} - -std::string DynoLogNpuMonitor::Poll() -{ - std::string res = ipcClient_.IpcClientNpuConfig(); - if (res.empty()) { - std::cout << "[INFO] Request for dynolog server is empty !" << std::endl; - return ""; - } - std::cout << "[INFO] Received NPU configuration successfully" << std::endl; - return res; -} - -} // namespace ipc_monitor -} // namespace dynolog_npu \ No newline at end of file diff --git a/dynolog_npu/plugin/ipc_monitor/MonitorBase.h b/dynolog_npu/plugin/ipc_monitor/MonitorBase.h deleted file mode 100644 index 108023c7624b747e5987be9184d6c594decd360a..0000000000000000000000000000000000000000 --- a/dynolog_npu/plugin/ipc_monitor/MonitorBase.h +++ /dev/null @@ -1,18 +0,0 @@ -#ifndef MONITOR_BASE_H -#define MONITOR_BASE_H -#include - -namespace dynolog_npu { -namespace ipc_monitor { - -class MonitorBase { -public: - virtual bool Init() = 0; - virtual std::string Poll() = 0; - virtual void SetNpuId(int id) = 0; -}; - -} // namespace ipc_monitor -} // namespace dynolog_npu - -#endif \ No newline at end of file diff --git a/dynolog_npu/plugin/ipc_monitor/PyDynamicMonitorProxy.h b/dynolog_npu/plugin/ipc_monitor/PyDynamicMonitorProxy.h deleted file mode 100644 index 8b5f88abf9d2cf589bec685cd3a520729afe8dd5..0000000000000000000000000000000000000000 --- a/dynolog_npu/plugin/ipc_monitor/PyDynamicMonitorProxy.h +++ /dev/null @@ -1,40 +0,0 @@ -#ifndef PYDYNAMIC_MONITOR_PROXY_H -#define PYDYNAMIC_MONITOR_PROXY_H - -#include -#include -#include "MonitorBase.h" -#include "DynoLogNpuMonitor.h" - -namespace dynolog_npu { -namespace ipc_monitor { - -class PyDynamicMonitorProxy { -public: - PyDynamicMonitorProxy() = default; - bool InitDyno(int npuId) - { - try { - monitor_ = DynoLogNpuMonitor::GetInstance(); - monitor_->SetNpuId(npuId); - bool res = monitor_->Init(); - return res; - } catch (const std::exception &e) { - std::cout << "[ERROR] Error when init dyno " << e.what() << std::endl; - return false; - } - } - - std::string PollDyno() - { - return monitor_->Poll(); - }; - -private: - MonitorBase *monitor_ = nullptr; -}; - -} // namespace ipc_monitor -} // namespace dynolog_npu - -#endif diff --git a/dynolog_npu/plugin/ipc_monitor/utils.cpp b/dynolog_npu/plugin/ipc_monitor/utils.cpp deleted file mode 100644 index 936821fd34bc34bc9db9e09515132e8af39ba57a..0000000000000000000000000000000000000000 --- a/dynolog_npu/plugin/ipc_monitor/utils.cpp +++ /dev/null @@ -1,135 +0,0 @@ -#include "utils.h" - -namespace dynolog_npu { -namespace ipc_monitor { -std::unordered_map submoduleMap = { - {SubModule::IPC, "IPC"}, -}; - -std::unordered_map errCodeMap = { - {ErrCode::SUC, "success"}, - {ErrCode::PARAM, "invalid parameter"}, - {ErrCode::TYPE, "invalid type"}, - {ErrCode::VALUE, "invalid value"}, - {ErrCode::PTR, "invalid pointer"}, - {ErrCode::INTERNAL, "internal error"}, - {ErrCode::MEMORY, "memory error"}, - {ErrCode::NOT_SUPPORT, "feature not supported"}, - {ErrCode::NOT_FOUND, "resource not found"}, - {ErrCode::UNAVAIL, "resource unavailable"}, - {ErrCode::SYSCALL, "system call failed"}, - {ErrCode::TIMEOUT, "timeout error"}, - {ErrCode::PERMISSION, "permission error"}, -}; - -std::string getCurrentTimestamp() -{ - auto now = std::chrono::system_clock::now(); - auto micros = std::chrono::duration_cast(now.time_since_epoch()); - - std::time_t currentTime = std::chrono::system_clock::to_time_t(now); - std::tm* timeInfo = std::localtime(¤tTime); - - auto milli_time = std::chrono::duration_cast(micros).count() % 1000; - auto micro_time = micros.count() % 1000; - - std::ostringstream oss; - oss << std::put_time(timeInfo, "%Y-%m-%d-%H:%M:%S"); - return oss.str(); -} - -std::string formatErrorCode(SubModule submodule, ErrCode errorCode) -{ - std::ostringstream oss; - oss << "\n[ERROR] " << getCurrentTimestamp() << " (PID:" << getpid() << ")"; - oss << "ERR" << std::setw(2) << std::setfill('0') << static_cast(submodule); // 2: 字段宽度 - oss << std::setw(3) << std::setfill('0') << static_cast(errorCode); // 3: 字段宽度 - oss << " " << submoduleMap[submodule] << " " << errCodeMap[errorCode]; - - return oss.str(); -}; - - -int32_t GetProcessId() -{ - return static_cast(getpid()); -} - -std::pair GetParentPidAndCommand(int32_t pid) -{ - std::string fileName = "/proc/" + std::to_string(pid) + "/stat"; - std::ifstream statFile(fileName); - if (!statFile) { - return std::make_pair(0, ""); - } - int32_t parentPid = 0; - std::string command; - std::string line; - if (std::getline(statFile, line)) { - int ret = sscanf(line.c_str(), "%*d (%[^)]) %*c %d", command.data(), &parentPid); - if (ret == 2) { // 2: 接收到2个字符 - std::cout << "[INFO] Success to get parent pid: " << parentPid << std::endl; - return std::make_pair(parentPid, command); - } - } - std::cout << "[WARNING] Failed to parse /proc/" << pid << "/stat" << std::endl; - return std::make_pair(0, ""); -} - -std::vector> GetPidCommandPairsofAncestors() -{ - std::vector> process_pids_and_cmds; - process_pids_and_cmds.reserve(MaxParentPids + 1); - int32_t current_pid = GetProcessId(); - for (int i = 0; i <= MaxParentPids && (i == 0 || current_pid > 1); i++) { - std::pair parent_pid_and_cmd = GetParentPidAndCommand(current_pid); - process_pids_and_cmds.push_back(std::make_pair(current_pid, parent_pid_and_cmd.second)); - current_pid = parent_pid_and_cmd.first; - } - return process_pids_and_cmds; -} - -std::vector GetPids() -{ - const auto &pids = GetPidCommandPairsofAncestors(); - std::vector res; - res.reserve(pids.size()); - for (const auto &pidPair : pids) { - res.push_back(pidPair.first); - } - return res; -} -std::string GenerateUuidV4() -{ - static std::random_device randomDevice; - static std::mt19937 gen(randomDevice()); - static std::uniform_int_distribution<> dis(0, 15); // range (0, 15) - static std::uniform_int_distribution<> dis2(8, 11); // range (8, 11) - - std::stringstream stringStream; - stringStream << std::hex; - for (int i = 0; i < 8; i++) { // 8 times - stringStream << dis(gen); - } - stringStream << "-"; - for (int j = 0; j < 4; j++) { // 4 times - stringStream << dis(gen); - } - stringStream << "-4"; // add -4 - for (int k = 0; k < 3; k++) { // 3 times - stringStream << dis(gen); - } - stringStream << "-"; - stringStream << dis2(gen); - for (int m = 0; m < 3; m++) { // 3 times - stringStream << dis(gen); - } - stringStream << "-"; - for (int n = 0; n < 12; n++) { // 12 times - stringStream << dis(gen); - } - return stringStream.str(); -} - -} // namespace ipc_monitor -} // namespace dynolog_npu diff --git a/dynolog_npu/plugin/ipc_monitor/utils.h b/dynolog_npu/plugin/ipc_monitor/utils.h deleted file mode 100644 index 0d8ceb8cfd0bf81b6d8b807c6ac1b505276ddf83..0000000000000000000000000000000000000000 --- a/dynolog_npu/plugin/ipc_monitor/utils.h +++ /dev/null @@ -1,63 +0,0 @@ -#ifndef IPC_MONITOR_UTILS_H -#define IPC_MONITOR_UTILS_H -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - - -namespace dynolog_npu { -namespace ipc_monitor { - -constexpr int MaxParentPids = 5; -int32_t GetProcessId(); -std::string GenerateUuidV4(); -std::vector GetPids(); -std::pair GetParentPidAndCommand(int32_t pid); -std::vector> GetPidCommandPairsofAncestors(); -std::string getCurrentTimestamp(); - -enum class SubModule { - IPC = 0 -}; - -enum class ErrCode { - SUC = 0, - PARAM = 1, - TYPE = 2, - VALUE = 3, - PTR = 4, - INTERNAL = 5, - MEMORY = 6, - NOT_SUPPORT = 7, - NOT_FOUND = 8, - UNAVAIL = 9, - SYSCALL = 10, - TIMEOUT = 11, - PERMISSION = 12, -}; - - -std::string formatErrorCode(SubModule submodule, ErrCode errorCode); - -#define IPC_ERROR(error) formatErrorCode(SubModule::IPC, error) - -template -inline T ReinterpretConvert(V ptr) { - return reinterpret_cast(ptr); -} - - -} // namespace ipc_monitor -} // namespace dynolog_npu - -#endif - diff --git a/dynolog_npu/third_party/dynolog b/dynolog_npu/third_party/dynolog deleted file mode 160000 index d5d37bc182bc2aa8fa60ba7d5ee897bacb5cbd4b..0000000000000000000000000000000000000000 --- a/dynolog_npu/third_party/dynolog +++ /dev/null @@ -1 +0,0 @@ -Subproject commit d5d37bc182bc2aa8fa60ba7d5ee897bacb5cbd4b diff --git a/flight_recoder/analysis_flight.py b/flight_recoder/analysis_flight.py deleted file mode 100644 index f81f771ab1c81ad79cb93401e200b600a4b17af3..0000000000000000000000000000000000000000 --- a/flight_recoder/analysis_flight.py +++ /dev/null @@ -1,164 +0,0 @@ -# Copyright (c) 2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. - -import os -import pickle -import sys -import logging -from collections import defaultdict - -from check_path import get_valid_read_path - - -logging.basicConfig( - level=logging.INFO, # 设置日志级别为 INFO - format="%(asctime)s - %(levelname)s - %(message)s", # 设置日志格式 - handlers=[logging.StreamHandler()], # 输出到控制台 -) - - -SAFE_CLASSES = { - # 内置安全类型 - "builtins": {"str", "int", "float", "list", "dict", "tuple"}, -} - - -class SafeUnpickler(pickle.Unpickler): - def find_class(self, module, name): - # 检查模块和类是否在白名单中 - if module in SAFE_CLASSES and name in SAFE_CLASSES[module]: - return super().find_class(module, name) - raise pickle.UnpicklingError(f"Forbidden class: {module}.{name}") - - -def load_recorder_data(path, world_size): - """加载所有 rank 的 recorder 数据""" - recorder_dict = {} - for rank in range(world_size): - file_path = os.path.join(path, str(rank)) if not path.endswith("/") else path + str(rank) - file_path = get_valid_read_path(file_path) - try: - with open(file_path, "rb") as f: - res = SafeUnpickler(f).load() - recorder_dict[str(rank)] = res - except Exception as e: - logging.error(f"Failed to load data from {file_path}: {e}") - return recorder_dict - - -def extract_hccl_info(recorder_dict): - """从 recorder 数据中提取 HCCL 相关信息""" - hccl_dict = {} - for rank, recorder in recorder_dict.items(): - entries = recorder.get("entries", []) - if not entries: - continue - last_entry = entries[-1] - hccl_dict[rank] = { - "state": last_entry.get("state", None), - "record_id": last_entry.get("record_id", None), - "pg_id": last_entry.get("pg_id", None), - "time_discovered_completed_ns": last_entry.get("time_discovered_completed_ns", None), - "name": last_entry.get("frames", [{}])[0].get("name", None), - } - return hccl_dict - - -def analyze_pg_groups(hccl_dict): - """分析 HCCL 数据,按 pg_id 分组并检查问题""" - pg_groups = defaultdict(list) - for _, op in hccl_dict.items(): - pg_groups[op["pg_id"]].append(op) - - for pg_id, group in pg_groups.items(): - scheduled_ops = [op for op in group if op["state"] == "scheduled"] - completed_ops = [op for op in group if op["state"] == "completed"] - - # 情况 1: 所有卡都是 scheduled,且 record_id 和 name 相同 - if len(scheduled_ops) == len(group): - record_id = scheduled_ops[0]["record_id"] - name = scheduled_ops[0]["name"] - all_same = all(op["record_id"] == record_id and op["name"] == name for op in scheduled_ops) - if all_same: - logging.info( - f"The pg_id {pg_id}'s Communication Operator {name}" - " executed too slowly, causing the HCCL to time out." - ) - - # 情况 2: 存在 completed 算子且 该算子的record_id 比其他 scheduled 算子少 1 - elif completed_ops and scheduled_ops: - completed_op = completed_ops[0] - scheduled_record_id = scheduled_ops[0]["record_id"] - if completed_op["record_id"] == scheduled_record_id - 1: - logging.info( - f"The pg_id {pg_id}'s rank {completed_op['pg_id']}'s " - "Computational task took too long, causing the other ranks' " - "HCCL task to time out." - ) - - # 情况 3: 所有算子均为 completed - elif not scheduled_ops and completed_ops: - latest_op = max(completed_ops, key=lambda x: x["time_discovered_completed_ns"] or 0) - logging.info( - f"The computational task of the pg_id {pg_id} " - f"after the communication operator {latest_op['name']} " - "took too long." - ) - - else: - logging.info(f"The situation cannot be recognized!") - - -def get_int_arg(args, idx, default): - if len(args) > idx: - try: - return int(args[idx]) - except ValueError: - logging.warning(f"Invalid input {args[idx]}, using default: {default}") - return default - - -def main(): - # 设置默认值 - default_path = os.getenv("TORCH_HCCL_DEBUG_INFO_TEMP_FILE") - default_world_size = 8 - - # 获取命令行参数,如果未提供则使用默认值 - path = sys.argv[1] if len(sys.argv) > 1 else default_path - world_size = get_int_arg(sys.argv, 2, default_world_size) - - if not path: - raise ValueError("Path is required and cannot be empty.") - - logging.info(f"Path: {path}") - logging.info(f"World Size: {world_size}") - - # 加载数据 - recorder_dict = load_recorder_data(path, world_size) - if not recorder_dict: - logging.error("No valid recorder data found.") - return - - # 提取 HCCL 信息 - hccl_dict = extract_hccl_info(recorder_dict) - - # 分析 HCCL 数据 - analyze_pg_groups(hccl_dict) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/flight_recoder/check_path.py b/flight_recoder/check_path.py deleted file mode 100644 index b34e4dcdb68b28b44f387cb14919ad127658ca8f..0000000000000000000000000000000000000000 --- a/flight_recoder/check_path.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) 2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re -import os -import sys -import stat - - -PATH_WHITE_LIST_REGEX = re.compile(r"[^_A-Za-z0-9/.-]") -MAX_READ_FILE_SIZE_4G = 4294967296 # 4G, 4 * 1024 * 1024 * 1024 -MAX_READ_FILE_SIZE_32G = 34359738368 # 32G, 32 * 1024 * 1024 * 1024 -MAX_READ_FILE_SIZE_512G = 549755813888 # 512G, 512 * 1024 * 1024 * 1024 - -# group not writable, others no permission, max stat is 750 -WRITE_FILE_NOT_PERMITTED_STAT = stat.S_IWGRP | stat.S_IWOTH | stat.S_IROTH | stat.S_IXOTH -# group not writable, others not writable, max stat is 755 -READ_FILE_NOT_PERMITTED_STAT = stat.S_IWGRP | stat.S_IWOTH - - -def type_to_str(value_type): - return ' or '.join([ii.__name__ for ii in value_type]) if isinstance(value_type, tuple) else value_type.__name__ - - -def check_type(value, value_type, param_name="value"): - if not isinstance(value, value_type): - raise TypeError('{} must be {}, not {}.'.format(param_name, type_to_str(value_type), type(value).__name__)) - - -def get_valid_path(path): - check_type(path, str, "path") - if not path or len(path) == 0: - raise ValueError("The value of the path cannot be empty.") - if PATH_WHITE_LIST_REGEX.search(path): # Check special char - raise ValueError("Input path contains invalid characters.") # Not printing out the path value for invalid char - path = os.path.expanduser(path) # Consider paths starting with "~" - if os.path.islink(os.path.abspath(path)): # when checking link, get rid of the "/" at the path tail if any - raise ValueError("The value of the path cannot be soft link: {}.".format(path)) - - real_path = os.path.realpath(path) - - if len(real_path) > 4096: - raise ValueError("The length of file path should be less than 4096.") - - if real_path != path and PATH_WHITE_LIST_REGEX.search(real_path): # Check special char again - raise ValueError("Input path contains invalid characters.") # Not printing out the path value for invalid char - - return real_path - - -def is_belong_to_user_or_group(file_stat): - return file_stat.st_uid == os.getuid() or file_stat.st_gid in os.getgroups() - - -def get_valid_read_path(path, size_max=MAX_READ_FILE_SIZE_4G, check_user_stat=True, is_dir=False): - real_path = get_valid_path(path) - if not os.path.isfile(real_path): - raise ValueError("The path {} doesn't exists or not a file.".format(path)) - - file_stat = os.stat(real_path) - if check_user_stat and not sys.platform.startswith("win") and not is_belong_to_user_or_group(file_stat): - raise ValueError("The file {} doesn't belong to the current user or group.".format(path)) - if check_user_stat and os.stat(path).st_mode & READ_FILE_NOT_PERMITTED_STAT > 0: - raise ValueError("The file {} is group writable, or is others writable.".format(path)) - if not os.access(real_path, os.R_OK) or file_stat.st_mode & stat.S_IRUSR == 0: # At least been 400 - raise ValueError("Current user doesn't have read permission to the file {}.".format(path)) - if not is_dir and size_max > 0 and file_stat.st_size > size_max: - raise ValueError("The file {} exceeds size limitation of {}.".format(path, size_max)) - return real_path \ No newline at end of file diff --git a/flight_recoder/flight_recoder.md b/flight_recoder/flight_recoder.md deleted file mode 100644 index 8b398a6730bae0823b04c20a22258a81392922c9..0000000000000000000000000000000000000000 --- a/flight_recoder/flight_recoder.md +++ /dev/null @@ -1,49 +0,0 @@ -# 飞行记录器超时类问题分析 - -训练任务卡住是阻塞AI大规模分布式集群训练任务的主要和关键问题,当前需要等待集合通信超时才能感知,影响集群可用性。框架需要支持检测训练任务卡住问题,做到提前识别并保存必要的诊断信息,提高问题定位效率和集群设备可用性。当HeartbeatMonitor长时间未检测到心跳时,即可认为训练任务已经卡住,需要触发诊断信息保存。 - -本工具提供torch npu上飞行记录器flight recorder记录日志的读取解析能力,并根据解析后的日志提供超时类问题的初步分析能力,主要支持以下三种情况的超时类问题的识别和分析 - -|问题| 具体内容 | -| --- | --- | -|类型一 | 同通信域内的某张卡计算超时,导致其他卡等待触发飞行记录器和hccl time out | -|类型二 | 同通信域内的通信算子之后的非通信任务耗时过长| -|类型三 | 同通信域内的某个通信算子进行通信时执行超时 | - -## 使用方法 - -### 1 飞行记录器开启方法 - -按照如下方法设置环境变量开启飞行记录器 - -``` -export TORCH_HCCL_ENABLE_MONITORING=1 #用于检测是否开启卡住问题检测 -export TORCH_HCCL_DUMP_ON_TIMEOUT=1 # 用于控制是否保存诊断信息 -export TORCH_HCCL_TRACE_BUFFER_SIZE=1 # 用于控制保存的集合通信状态数量 -export TORCH_HCCL_HEARTBEAT_TIMEOUT_SEC=20 # 用于控制心跳超时时间,即训练业务多久未下发集合通信算子时需要判定为卡住,默认10分钟,单位s。(需要小于HCCL_EXEC_TIMEOUT,避免集合通信先报超时错误) -export TORCH_HCCL_DEBUG_INFO_TEMP_FILE=/tmp/ #保存诊断信息的文件路径 -``` - -### 2 工具使用方法 - -``` -python analysis_flight.py path world_size -``` - -脚本从命令行参数获取 `path` 和 `world_size` 的值,并记录日志。如果未提供命令行参数,则使用默认值。 - -* `path`:从命令行第一个参数获取,如果未提供则使用 `default_path`, default_path从TORCH_HCCL_DEBUG_INFO_TEMP_FILE获取。 -* `world_size`:从命令行第二个参数获取,如果未提供则使用 `default_world_size`,默认为8。 - -| 参数名| 含义 | 使用限制 | -| --- | --- | --- | -| path | 飞行记录器的日志 | 可选。数据类型:string 默认为环境变量中的TORCH_HCCL_DEBUG_INFO_TEMP_FILE,若设置日志格式指定有前缀,则需要在路径中加入前缀 | -| world_size | 同一个通信域中的卡数 | 可选。数据类型:int 默认为8 | - -### 3 输出示例 - -``` -2025-02-19 08:10:07,160 - INFO - Path: /tmp/ -2025-02-19 08:10:07,160 - INFO - World Size: 8 -2025-02-19 08:10:07,162 - INFO - The pg_id 0's rank 0's Computational task took too long, causing the other ranks' HCCL task to time out. -``` diff --git a/msmonitor/README.md b/msmonitor/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f38851159b9a55bc672fed12469d4d4f65583a2d --- /dev/null +++ b/msmonitor/README.md @@ -0,0 +1,93 @@ +# msMonitor + +## 📌 简介 +msMonitor是MindStudio推出的一站式在线监控工具,提供用户在集群场景性能监控定位端到端能力。msMonitor基于[dynolog](https://github.com/facebookincubator/dynolog)开发,结合AI框架([Ascend PyTorch Profiler](https://www.hiascend.com/document/detail/zh/mindstudio/81RC1/T&ITools/Profiling/atlasprofiling_16_0090.html#ZH-CN_TOPIC_0000002353635602__zh-cn_topic_0000002370275077_section17272160135118)、[MindSpore Profiler](https://www.hiascend.com/document/detail/zh/mindstudio/81RC1/T&ITools/Profiling/atlasprofiling_16_0087.html#ZH-CN_TOPIC_0000002353475766__zh-cn_topic_0000002370195177_section0157845102716))的动态采集能力和[MSPTI](https://www.hiascend.com/document/detail/zh/mindstudio/81RC1/T&ITools/Profiling/atlasprofiling_16_0021.html),为用户提供**nputrace**和**npumonitor**功能: +1. **npumonitor功能**:轻量常驻后台,监控关键算子耗时 +2. **nputrace功能**:获取到框架、CANN以及device的详细性能数据 + +![msMonitor](./docs/resources/msMonitor.png) + +如上图所示msMonitor分为三部分: +1. **Dynolog daemon**:dynolog守护进程,每个节点只有一个守护进程,负责接收dyno CLI的RPC请求、触发nputrace和npumonitor功能、上报数据的处理以及最终数据的展示。 +2. **Dyno CLI**:dyno客户端,为用户提供nputrace和npumonitor子命令,任意节点都可以安装。 +3. **MSPTI Monitor**:基于MSPTI实现的监控子模块,通过调用MSPTI的API获取性能数据,并上报给Dynolog daemon。 + + +## 💻 版本说明 +msMonitor由三个文件组成,其中dyno和dynolog可以被打包为deb包或者rpm包。最新的预编译安装包和版本依赖请查看[msMonitor release](./docs/release_notes.md)。目前msMonitor支持在[PyTorch](https://gitee.com/ascend/pytorch)框架和[MindSpore](https://www.mindspore.cn/)框架上运行。 + +| 文件名 | 用途 | +|------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------| +| dyno | dyno客户端二进制文件 | +| dynolog | dynolog服务端二进制文件 | +| msmonitor_plugin-{mindstudio_version}-cp{python_version}-cp{python_version}-linux_{system_architecture}.whl | MSPTI Monitor、IPC等公共能力工具包,{mindstudio_version}表示mindstudio版本号,{python_version}表示python版本号,{system_architecture}表示CPU架构系统 | + +## 🚀 快速上手 +### Step 1: 安装 +请参见[msMonitor安装手册](./docs/install.md),安装msMonitor工具,推荐使用下载软件包安装。 + +### Step 2: 运行 +npumonitor和nputrace功能详细说明请参考[特性介绍](#-特性介绍)章节,下面介绍msMonitor常见的使用场景: +1. 先使用npumonitor功能获取关键算子耗时 +2. 当发现监控到关键算子耗时劣化,使用nputrace功能采集详细性能数据做分析 + +**操作步骤** +1. 拉起dynolog daemon进程,详细介绍请参考[dynolog介绍](./docs/dynolog.md) + +- 示例 +```bash +# 命令行方式开启dynolog daemon +dynolog --enable-ipc-monitor --certs-dir /home/server_certs + +# 如需使用Tensorboard展示数据,传入参数--metric_log_dir用于指定Tensorboard文件落盘路径 +# 例如: +dynolog --enable-ipc-monitor --certs-dir /home/server_certs --metric_log_dir /tmp/metric_log_dir # dynolog daemon的日志路径为:/var/log/dynolog.log +``` + +2. 使能msMonitor环境变量 +```bash +export MSMONITOR_USE_DAEMON=1 +``` + +3. 设置LD_PRELOAD使能MSPTI(使能npumonitor功能设置) +```bash +# 默认路径示例:export LD_PRELOAD=/usr/local/Ascend/ascend-toolkit/latest/lib64/libmspti.so +export LD_PRELOAD=/ascend-toolkit/latest/lib64/libmspti.so +``` +4. 拉起训练/推理任务 +```bash +bash run_ai_task.sh +``` +5. 使用dyno命令行触发npumonitor监控关键算子耗时 +```bash +# 开启npu-monitor,上报周期30s, 上报数据类型为Kernel +dyno --certs-dir /home/client_certs npu-monitor --npu-monitor-start --report-interval-s 30 --mspti-activity-kind Kernel +``` +```bash +# 关闭npu-monitor +dyno --certs-dir /home/client_certs npu-monitor --npu-monitor-stop +``` +6. 使用dyno命令行触发nputrace采集详细trace数据(需要关闭npumonitor功能才能触发nputrace功能) +```bash +# 从第10个step开始采集,采集2个step,采集框架、CANN和device数据,同时采集完后自动解析以及解析完成不做数据精简,落盘路径为/tmp/profile_data +dyno --certs-dir /home/client_certs nputrace --start-step 10 --iterations 2 --activities CPU,NPU --analyse --data-simplification false --log-file /tmp/profile_data +``` + +## 📖 特性介绍 +⚠️ 由于底层资源限制,npumonitor功能和nputrace不能同时开启。 + +1. 执行 dyno 命令后,响应结果里有一个 ‘response’ 的json字符串。该字符串中的 ‘commandStatus’ 字段用于标识命令是否生效:‘effective’ 表示命令会生效,‘ineffective’ 表示命令无效。其他字段均为 dynolog 的原生字段。 + +### 📈 npumonitor特性 +npumonitor特性为用户提供轻量化监控关键指标的能力,npumonitor基于[MSPTI](https://www.hiascend.com/document/detail/zh/mindstudio/81RC1/T&ITools/Profiling/atlasprofiling_16_0021.html)开发,用户可以通过npumonitor查看模型运行时的计算、通信算子执行耗时。 +具体使用方式请参考[npumonitor使用方式](./docs/npumonitor.md),MindSpore框架下使用方式请参考[MindSpore框架下msMonitor的使用方法](./docs/mindspore_adapter.md)。 + +### 📊 nputrace特性 +nputrace特性为用户提供动态触发AI框架([Ascend PyTorch Profiler](https://www.hiascend.com/document/detail/zh/mindstudio/81RC1/T&ITools/Profiling/atlasprofiling_16_0090.html)、[MindSpore Profiler](https://www.hiascend.com/document/detail/zh/mindstudio/81RC1/T&ITools/Profiling/atlasprofiling_16_0087.html))采集解析的能力,即实现模型拉起后不需要中断模型运行,可多次触发不同配置Profiler采集解析。采集的性能数据可以使用[MindStudio Insight](https://www.hiascend.com/document/detail/zh/mindstudio/81RC1/GUI_baseddevelopmenttool/msascendinsightug/Insight_userguide_0002.html)进行可视化,效果图如下。 +具体使用方式请参考[nputrace使用方式](./docs/nputrace.md),MindSpore框架下使用方式请参考[MindSpore框架下msMonitor的使用方法](./docs/mindspore_adapter.md) +![MindStudio Insight TimeLine可视化效果图](./docs/resources/mindstudio_insight.png) + +## 🔒 安全声明 +[msMonitor安全声明](./docs/security_statement.md) +## ❓ FAQ +[msMonitor FAQ](./docs/faq.md) \ No newline at end of file diff --git a/msmonitor/docs/dynolog.md b/msmonitor/docs/dynolog.md new file mode 100644 index 0000000000000000000000000000000000000000..2d4486fb276ce9b604ab6c45075d5b320dcb3a22 --- /dev/null +++ b/msmonitor/docs/dynolog.md @@ -0,0 +1,28 @@ +# dynolog介绍 + +dynolog负责接收dyno CLI的RPC请求,触发nputrace和npumonitor功能。 + +- dynolog daemon可以通过systemd或者命令行任意一种方法开启 + +```bash +# 方法1:使用systemd拉起service +# 修改配置文件/etc/dynolog.gflags, 使能ipc_monitor +echo "--enable_ipc_monitor" | sudo tee -a /etc/dynolog.gflags +sudo systemctl start dynolog +``` + +```bash +# 方法2:命令行执行 +dynolog --enable-ipc-monitor --certs-dir /home/server_certs +``` + +## dynolog常用参数 + +| 命令 | 参数类型 | 说明 | 是否必选 | +|---------------------|--------|-----------------------------------------------------|:----:| +| --enable-ipc-monitor | action | 是否启用IPC监控功能,用于与dyno进行通信,设置参数开启,默认不开启 | N | +| --port | i32 | dynolog daemon进程监听的端口号,默认值1778 | N | +| --certs-dir | String | 用于指定dyno与dynolog RPC通信时TLS证书的路径,当值为`NO_CERTS`时不使用证书校验 | Y | +| --metric_log_dir | String | 用于指定Metric数据的落盘路径 | N | +| --use_JSON | action | 是否使用JSON格式记录metric数据到日志中,默认不启用 | N | + diff --git a/msmonitor/docs/faq.md b/msmonitor/docs/faq.md new file mode 100644 index 0000000000000000000000000000000000000000..25015d4f0b876642d8e22b58d8b377d8cd528637 --- /dev/null +++ b/msmonitor/docs/faq.md @@ -0,0 +1,5 @@ +# FAQ + +- **Q:dyno CLI发送npumonitor命令后没有数据上报?** +- A:npumonitor功能基于MSPTI接口开发,如果没有数据上报,请先检查LD_PRELOAD是否正常设置libmspti.so的路径,然后检查dynolog日志中有无正常接收到dyno CLI的RPC请求。 +---------------------------------------------------------------------------------------------------------------------------------------------------- \ No newline at end of file diff --git a/msmonitor/docs/install.md b/msmonitor/docs/install.md new file mode 100644 index 0000000000000000000000000000000000000000..497dc52df885a3cf62876eca6c576f4498ff926c --- /dev/null +++ b/msmonitor/docs/install.md @@ -0,0 +1,116 @@ +# msMonitor安装 +## 下载软件包安装(推荐) +最新的预编译安装包和版本依赖请查看[msMonitor release](./release_notes.md),并根据指导进行校验和安装。 + +## 源码编译安装 + +### 1. clone 代码 + +```bash +git clone https://gitee.com/ascend/mstt.git +``` + +### 2. 安装依赖 +dynolog的编译依赖,确保安装了以下依赖: + + + + + + + + + + + + + +
Language + Toolchain +
C++ + gcc 8.5.0+ +
Rust + Rust >= 1.81 +
+ +- 安装rust + +```bash +curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh + +source $HOME/.cargo/env +``` + +- 安装ninja + +```bash +# debian +sudo apt-get install -y cmake ninja-build + +# centos +sudo yum install -y cmake ninja +``` + +- 安装protobuf (tensorboard_logger三方依赖,用于对接tensorboard展示) +- **说明**:要求protobuf版本为3.12或更高版本 +```bash +# debian +sudo apt install -y protobuf-compiler libprotobuf-dev + +# centos +sudo yum install -y protobuf protobuf-devel protobuf-compiler + +# Python +pip install protobuf +``` + +- (可选)安装openssl(RPC TLS认证)& 生成证书密钥 +- **说明**:如果不需要使用TLS证书密钥加密,该步骤可跳过。 + +```bash +# debian +sudo apt-get install -y openssl + +# centos +sudo yum install -y openssl +``` +dyno CLI与dynolog daemon之间的RPC通信使用TLS证书密钥加密,在启动dyno和dynolog二进制时可以指定证书密钥存放的路径,路径下需要满足如下结构和名称。 +**用户应使用与自己需求相符的密钥生成和存储机制,并保证密钥安全性与机密性。** + +服务端证书目录结构: +```bash +server_certs +├── ca.crt (根证书,用于验证其他证书的合法性,必选) +├── server.crt (服务器端的证书,用于向客户端证明服务器身份,必选) +├── server.key (服务器端的私钥文件,与server.crt配对使用,支持加密,必选) +└── ca.crl (证书吊销列表,包含已被吊销的证书信息,可选) +``` +客户端证书目录结构: +```bash +client_certs +├── ca.crt (根证书,用于验证其他证书的合法性,必选) +├── client.crt (客户端证书,用于向服务器证明客户端身份,必选) +├── client.key (客户端的私钥文件,与client.crt配对使用,支持加密,必选) +└── ca.crl (证书吊销列表,包含已被吊销的证书信息,可选) +``` + +### 3. 编译 + +- dynolog编译 + +默认编译生成dyno和dynolog二进制文件,-t参数可以支持将二进制文件打包成deb包或rpm包。 + +```bash +# 编译dyno和dynolog二进制文件 +bash scripts/build.sh + +# 编译deb包, 当前支持amd64和aarch64平台, 默认为amd64, 编译aarch64平台需要修改third_party/dynolog/scripts/debian/control文件中的Architecture改为arm64 +bash scripts/build.sh -t deb + +# 编译rpm包, 当前只支持amd64平台 +bash scripts/build.sh -t rpm +``` + +- msmonitor-plugin wheel包编译 + +msmonitor-plugin wheel包提供IPCMonitor,MsptiMonitor等公共能力,使用nputrace和npumonitor功能前必须安装该wheel包,具体编译安装指导可参考[msmonitor-plugin编包指导](../plugin/README.md)。 \ No newline at end of file diff --git a/msmonitor/docs/mindspore_adapter.md b/msmonitor/docs/mindspore_adapter.md new file mode 100644 index 0000000000000000000000000000000000000000..97b8795a3183e911750b3ab01f7f3ea238778562 --- /dev/null +++ b/msmonitor/docs/mindspore_adapter.md @@ -0,0 +1,86 @@ +## MindSpore框架下msMonitor的使用方法 + +### 1. 动态profiling自定义for循环方式 + +Step 1:拉起dynolog daemon进程 + +Step 2:使能dynolog环境变量 + +Step 3:配置msMonitor日志路径 + +- 前3步以及第5步操作可以参考[npumonitor](./npumonitor.md)或[nputrace](./nputrace.md) + +Step 4: 拉起训练任务 +在训练任务中实例化DynamicProfilerMonitor对象,且在每一次训练后,调用step()方法。 + +- 示例代码如下: +```python +import numpy as np +import mindspore +import mindspore.dataset as ds +from mindspore import nn +from mindspore.profiler import DynamicProfilerMonitor + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.fc = nn.Dense(2, 2) + + def construct(self, x): + return self.fc(x) + + +def generator_net(): + for _ in range(2): + yield np.ones([2, 2]).astype(np.float32), np.ones([2]).astype(np.int32) + + +def train(test_net): + optimizer = nn.Momentum(test_net.trainable_params(), 1, 0.9) + loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True) + data = ds.GeneratorDataset(generator_net(), ["data", "label"]) + model = mindspore.train.Model(test_net, loss, optimizer) + model.train(1, data) + +if __name__ == '__main__': + dp = DynamicProfilerMonitor() + step_num = 100 + # 定义模型 + net = Net() + for i in range(step_num): + # 模型训练 + train(net) + # 调用step方法实现npu trace dump或npu monitor功能 + dp.step() +``` + +Step 5:使用dyno CLI使能trace dump或npu-monitor + +### 2. 动态profiling callback方式 +该使能方式与动态profiling自定义for循环方式一致,唯一区别是将step()方法适配在step_begin、step_end回调函数中。 +- 示例代码如下: +```python +import mindspore +from mindspore.profiler import DynamicProfilerMonitor + +class StopAtStep(mindspore.Callback): + def __init__(self, start_step, stop_step): + super(StopAtStep, self).__init__() + self.start_step = start_step + self.stop_step = stop_step + self.dp = DynamicProfilerMonitor() + + def step_begin(self, run_context): + cb_params = run_context.original_args() + step_num = cb_params.cur_step_num + if step_num == self.start_step: + self.dp.start() + + def step_end(self, run_context): + cb_params = run_context.original_args() + step_num = cb_params.cur_step_num + if self.start_step <= step_num < self.stop_step: + self.dp.step() # 调用step方法实现npu trace dump或npu monitor功能 + if step_num == self.stop_step: + self.dp.stop() +``` diff --git a/msmonitor/docs/npumonitor.md b/msmonitor/docs/npumonitor.md new file mode 100644 index 0000000000000000000000000000000000000000..2b1c94fff5c7364f58745a174f2b059149a6f582 --- /dev/null +++ b/msmonitor/docs/npumonitor.md @@ -0,0 +1,102 @@ +# npumonitor特性 + +npumonitor通过dyno CLI中的npumonitor子命令开启: + +```bash +dyno --certs-dir npu-monitor [SUBCOMMANDS] +``` +**说明**: +- 1. dyno和dynolog中--certs-dir传入参数值须保持一致; +- 2. 可传入证书路径,如果不使用TLS证书密钥,设置为NO_CERTS。 + + +查看npumonitor支持的命令和帮助 + +```bash +dyno npu-monitor --help +``` + +npu-monitor的SUBCOMMANDS(子命令)选项如下: + +| 子命令 | 参数类型 | 说明 | PyTorch支持 | MindSpore支持 | 是否必选 | +|-----------------------|-------|------------------------------------------------------------------------------------------------------------------------------------------------------|:---------:|:-----------:|:-----------:| +| --npu-monitor-start | action | 开启性能监控,设置参数后生效,默认不生效 | Y | Y | N | +| --npu-monitor-stop | action | 停止性能监控,设置参数后生效,默认不生效 | Y | Y | N | +| --report-interval-s | int | 性能监控数据上报周期,单位s,需要在启动时设置。默认值60 | Y | Y | N | +| --mspti-activity-kind | String | 性能监控数据上报数据类型,可以设置单个或多个,多个类型以逗号分隔,每次设置时刷新全局上报类型。可选值范围[`Marker`, `Kernel`, `API`, `Hccl`, `Memory`, `MemSet`, `MemCpy`, `Communication`] , 默认值`Marker` | Y | Y | N | +| --log-file | String | 性能数据采集落盘的路径,当前仅支持`mspti-activity-kind`设置为`Marker`,`Kernel`,`API`,`Communication`4类数据的导出, 落盘数据格式为db,db中内容说明请参考[msprof导出db格式数据说明](https://www.hiascend.com/document/detail/zh/canncommercial/82RC1/devaids/Profiling/atlasprofiling_16_1144.html),默认值为空,表示不落盘 | Y | Y | N | + +## npu-monitor使用方法 + +Step 1: 拉起dynolog daemon进程,详细介绍请参考[dynolog介绍](./dynolog.md) + +- 示例 +```bash +# 命令行方式开启dynolog daemon +dynolog --enable-ipc-monitor --certs-dir /home/server_certs + +# 如需使用Tensorboard展示数据,传入参数--metric_log_dir用于指定Tensorboard文件落盘路径 +# 例如: +dynolog --enable-ipc-monitor --certs-dir /home/server_certs --metric_log_dir /tmp/metric_log_dir # dynolog daemon的日志路径为:/var/log/dynolog.log +``` + +Step 2:在训练/推理任务拉起窗口使能dynolog环境变量 +```bash +export MSMONITOR_USE_DAEMON=1 +``` + +Step 3:配置Msmonitor日志路径(可选,默认路径为当前目录下的msmonitor_log) +```bash +export MSMONITOR_LOG_PATH= +# 示例: +export MSMONITOR_LOG_PATH=/tmp/msmonitor_log +``` + +Step 4:设置LD_PRELOAD使能MSPTI +```bash +# 示例:export LD_PRELOAD=/usr/local/Ascend/ascend-toolkit/latest/lib64/libmspti.so +export LD_PRELOAD=/ascend-toolkit/latest/lib64/libmspti.so + ``` + +Step 5:拉起训练/推理任务 +```bash +# 训练任务中需要使用pytorch的优化器/继承原生优化器 +bash train.sh +``` + +Step 6:使用dyno CLI使能npu-monitor +```bash +# 示例1:开启性能监控,使用默认配置 +dyno --certs-dir /home/client_certs npu-monitor --npu-monitor-start + +# 示例2:暂停性能监控 +dyno --certs-dir /home/client_certs npu-monitor --npu-monitor-stop + +# 示例3:性能监控过程中修改配置 +# 上报周期30s, 上报数据类型Marker和Kernel +dyno --certs-dir /home/client_certs npu-monitor --report-interval-s 30 --mspti-activity-kind Marker,Kernel + +# 示例4:性能监控开启时修改配置 +# 上报周期30s, 上报数据类型Marker和Kernel +dyno --certs-dir /home/client_certs npu-monitor --npu-monitor-start --report-interval-s 30 --mspti-activity-kind Marker,Kernel + +# 示例5:性能监控开启时修改配置,开启数据采集落盘 +# 数据落盘路径为/tmp/msmonitor_db,落盘周期为30s,采集数据类型为Marker,Kernel,Communication +dyno --certs-dir /home/client_certs npu-monitor --npu-monitor-start --report-interval-s 30 --mspti-activity-kind Marker,Kernel,Communication --log-file /tmp/msmonitor_db + +# 示例6:多机场景下性能监控开启时修改配置 +# 多机场景下向特定机器x.x.x.x发送参数信息,参数表示上报周期30s, 上报数据类型Marker和Kernel +dyno --certs-dir /home/client_certs --hostname x.x.x.x npu-monitor --npu-monitor-start --report-interval-s 30 --mspti-activity-kind Marker,Kernel +``` + +Step 7:(可选)观测Tensorboard上报数据 +``` +# 请确保安装了Tensorboard: +pip install tensorboard + +# 然后运行: +tensorboard --logdir={metric_log_dir} # metric_log_dir为Step1中dynolog命令行中--metric_log_dir参数指定的路径 + +# 打开浏览器访问http://localhost:6006即可看到对应可视化图表, 其中localhost为服务器的ip地址,6006为tensorboard默认端口 +``` +> tensorboard 具体使用参数见https://github.com/tensorflow/tensorboard \ No newline at end of file diff --git a/msmonitor/docs/nputrace.md b/msmonitor/docs/nputrace.md new file mode 100644 index 0000000000000000000000000000000000000000..65b307956cdc0aea8c8acf6531930588db1fa8e2 --- /dev/null +++ b/msmonitor/docs/nputrace.md @@ -0,0 +1,90 @@ +# nputrace特性 + +nputrace通过dyno CLI中的nputrace子命令开启: + +```bash +dyno --certs-dir nputrace [SUBCOMMANDS] +``` +**说明**: +- 1. dyno和dynolog中--certs-dir传入参数值须保持一致; +- 2. 可传入证书路径,如果不使用TLS证书密钥,设置为NO_CERTS。 + +查看nputrace支持的命令和帮助 + +```bash +dyno nputrace --help +``` + +nputrace的SUBCOMMANDS(子命令)选项如下: + +| 子命令 | 参数类型 | 说明 | PyTorch支持 | MindSpore支持 | 是否必选 | +|-----------------------|-------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:---------:|:-----------:|:----------:| +| --job-id | u64 | 采集任务的job id,默认值0,dynolog原生参数 | N | N | N | +| --pids | String | 采集任务的pid列表,多个pid用逗号分隔,默认值0,dynolog原生参数 | N | N | N | +| --process-limit | u64 | 最大采集进程的数量,默认值3,dynolog原生参数 | N | N | N | +| --profile-start-time | u64 | 用于同步采集的Unix时间戳,单位毫秒,默认值0,dynolog原生参数 | N | N | N | +| --duration-ms | u64 | 采集的周期,单位毫秒,默认值500,dynolog原生参数 | N | N | N | +| --iterations | i64 | 采集总迭代数,必须传入正整数,dynolog原生参数,需与start-step参数同时指定 | Y | Y | Y | +| --log-file | String | 采集落盘的路径 | Y | Y | Y | +| --start-step | i64 | 开始采集的迭代数,必须传入正整数或-1,设置为-1时表示从下一个step开始采集 | Y | Y | Y | +| --record-shapes | action | 是否采集算子的InputShapes和InputTypes,设置参数采集,默认不采集 | Y | Y | N | +| --profile-memory | action | 是否采集算子内存信息,设置参数采集,默认不采集 | Y | Y | N | +| --with-stack | action | 是否采集Python调用栈,设置参数采集,默认不采集 | Y | Y | N | +| --with-flops | action | 是否采集算子flops,设置参数采集,默认不采集 | Y | N | N | +| --with-modules | action | 是否采集modules层级的Python调用栈,设置参数采集,默认不采集 | Y | N | N | +| --analyse | action | 采集后是否自动解析,设置参数解析,默认不解析 | Y | Y | N | +| --l2-cache | action | 是否采集L2 Cache数据,设置参数采集,默认不采集 | Y | Y | N | +| --op-attr | action | 是否采集算子属性信息,设置参数采集,默认不采集 | Y | N | N | +| --msprof-tx | action | 是否使能MSTX,设置参数采集,默认不使能 | Y | Y | N | +| --mstx-domain-include | Option | 使能--msprof-tx采集mstx打点数据的情况下,配置该开关,设置实际采集的domain范围,与--mstx-domain-exclude参数互斥,若同时设置,则只有--mstx-domain-include生效。该参数为可选参数,默认不使能。可配置一个或多个domain,例如:--mstx-domain-include domain1, domain2 | Y | Y | N | +| --mstx-domain-exclude | Option | 使能--msprof-tx采集mstx打点数据的情况下,配置该开关,设置实际不采集的domain范围,与--mstx-domain-include参数互斥,若同时设置,则只有--mstx-domain-include生效。该参数为可选参数,默认不使能。可配置一个或多个domain,例如:--mstx-domain-exclude domain1, domain2 | Y | Y | N | +| --data-simplification | String | 解析完成后是否数据精简,可选值范围[`true`, `false`],默认值`true` | Y | Y | N | +| --activities | String | 控制CPU、NPU事件采集范围,可以设置单个或多个,多个类型以逗号分隔,可选值范围[`CPU`, `NPU`],默认值`CPU,NPU` | Y | Y | N | +| --profiler-level | String | 控制profiler的采集等级,可选值范围[`Level_none`, `Level0`, `Level1`, `Level2`],默认值`Level0` | Y | Y | N | +| --aic-metrics | String | AI Core的性能指标采集项,可选值范围[`AiCoreNone`, `PipeUtilization`, `ArithmeticUtilization`, `Memory`, `MemoryL0`, `ResourceConflictRatio`, `MemoryUB`, `L2Cache`, `MemoryAccess`],当--profiler-level设置为Level_none或Level0,默认值`AiCoreNone`,当--profiler-level设置为Level1或Level2,默认值`PipeUtilization` | Y | Y | N | +| --export-type | String | profiler解析导出数据的类型,只能设置其中一个,不能同时设置db和text,可选值范围[`Text`, `Db`],默认值`Text` | Y | Y | N | +| --gc-detect-threshold | Option | GC检测阈值,单位ms,只采集超过阈值的GC事件。默认不设置时不开启GC检测 | Y | N | N | +| --host-sys | String | 采集[host侧系统数据](https://www.hiascend.com/document/detail/zh/mindstudio/80RC1/T&ITools/Profiling/atlasprofiling_16_0014.html)(CPU利用率、内存利用率、磁盘I/O利用率、网络I/O利用率等)。可以设置单个或多个,多个类型以逗号分隔,可选值范围[`cpu`, `mem`, `disk`, `network`, `osrt`] , 默认不设置时不开启host侧系统数据采集 | Y | Y | N | +| --sys-io | action | 采集NIC、ROCE数据。设置参数采集,默认不采集 | Y | Y | N | +| --sys-interconnection | action | 采集集合通信带宽数据(HCCS)、PCIe、片间传输带宽数据。设置参数采集,默认不采集 | Y | Y | N | + +## nputrace使用方法 + +Step 1:拉起dynolog daemon进程,详细介绍请参考[dynolog介绍](./dynolog.md) + +- 示例 +```bash +# 命令行方式开启dynolog daemon +dynolog --enable-ipc-monitor --certs-dir /home/server_certs +``` + +Step 2:在训练/推理任务拉起窗口使能dynolog环境变量 +```bash +export MSMONITOR_USE_DAEMON=1 +``` + +Step 3:拉起训练/推理任务 +```bash +# 训练任务中需要使用pytorch的优化器/继承原生优化器 +bash train.sh +``` + +Step 4:使用dyno CLI动态触发trace dump +```bash +# 示例1:从第10个step开始采集,采集2个step,采集框架、CANN和device数据,同时采集完后自动解析以及解析完成不做数据精简,落盘路径为/tmp/profile_data +dyno --certs-dir /home/client_certs nputrace --start-step 10 --iterations 2 --activities CPU,NPU --analyse --data-simplification false --log-file /tmp/profile_data + +# 示例2:从下一个step开始采集,采集2个step,采集框架、CANN和device数据,同时采集完后自动解析以及解析完成不做数据精简,落盘路径为/tmp/profile_data +dyno --certs-dir /home/client_certs nputrace --start-step -1 --iterations 2 --activities CPU,NPU --analyse --data-simplification false --log-file /tmp/profile_data + +# 示例3:从第10个step开始采集,采集2个step,只采集CANN和device数据,同时采集完后自动解析以及解析完成后开启数据精简,落盘路径为/tmp/profile_data +dyno --certs-dir /home/client_certs nputrace --start-step 10 --iterations 2 --activities NPU --analyse --data-simplification true --log-file /tmp/profile_data + +# 示例4:从第10个step开始采集,采集2个step,只采集CANN和device数据,只采集不解析,落盘路径为/tmp/profile_data + +dyno --certs-dir /home/client_certs nputrace --start-step 10 --iterations 2 --activities NPU --log-file /tmp/profile_data + +# 示例5:多机场景下向特定机器x.x.x.x发送参数信息,参数表示从第10个step开始采集,采集2个step,只采集CANN和device数据,只采集不解析,落盘路径为/tmp/profile_data +dyno --certs-dir /home/client_certs --hostname x.x.x.x nputrace --start-step 10 --iterations 2 --activities NPU --log-file /tmp/profile_data +``` +nputrace落盘的数据格式和交付件介绍请参考[MindSpore&PyTorch框架性能数据文件参考](https://www.hiascend.com/document/detail/zh/mindstudio/81RC1/T&ITools/Profiling/atlasprofiling_16_0177.html) \ No newline at end of file diff --git a/msmonitor/docs/release_notes.md b/msmonitor/docs/release_notes.md new file mode 100644 index 0000000000000000000000000000000000000000..1bb2e5ced16b7ed12c7b124204e7b2194b01a502 --- /dev/null +++ b/msmonitor/docs/release_notes.md @@ -0,0 +1,60 @@ +# 版本说明 + +| msmonitor版本 | 发布日期 | 下载链接 | 校验码 | 配套CANN版本 | 配套torch_npu版本 | 配套MindSpore版本 | +|---------------|----------------|--------------------------------------------------------------------------|------------------------------------------------------------------------|--------------|--------------|--------------| +| 8.1.0 | 2025-07-11 | [aarch64_8.1.0.zip](https://ptdbg.obs.cn-north-4.myhuaweicloud.com/profiler/msmonitor/8.1.0/aarch64_8.1.0.zip) | ce136120c0288291cc0a7803b1efc8c8416c6105e9d54c17ccf2e2510869fada | 8.1.RC1及以上 | v7.1.0及以上 | 2.7.0-rc1及以上 | +| | 2025-07-11 | [x86_8.1.0.zip](https://ptdbg.obs.cn-north-4.myhuaweicloud.com/profiler/msmonitor/8.1.0/x86_8.1.0.zip) | 097d11c7994793b6389b19259269ceb3b6b7ac5ed77da3949b3f09da2103b7f2 | 8.1.RC1及以上 | v7.1.0及以上 | 2.7.0-rc1及以上 | + +Step 1: 根据aarch64还是x86选择对应安装包链接下载。 + +Step 2: 校验包完整性 + + 1. 根据以上下载链接下载包到Linux安装环境。 + + 2. 进入zip包所在目录,执行如下命令。 + + ``` + sha256sum {name}.zip + ``` + + {name}为zip包名称。 + + 若回显呈现对应版本zip包一致的**校验码**,则表示下载了正确的性能工具zip安装包。示例如下: + + ```bash + sha256sum aarch64_8.1.0.zip + ``` + +Step 3: 包安装(以x86版本为例) + + 1. 解压压缩包 + ```bash + mkdir x86 + unzip x86_8.1.0.zip -d x86 + ``` + + 2. 进入目录 + ```bash + cd x86 + ``` + + 3. 安装whl包 + ```bash + pip install msmonitor_plugin-{mindstudio_version}-cp{python_version}-cp{python_version}-linux_{system_architecture}.whl + ``` + + 4. 安装dynolog + + 有以下三种安装方式可供选择,根据用户服务器系统自行选择: + + 方式一:使用deb软件包安装(适用于Debian/Ubuntu等系统); + ``` + dpkg -i --force-overwrite dynolog*.deb + ``` + + 方式二:使用rpm软件包安装(适用于RedHat/Fedora/openSUSE等系统); + ``` + rpm -ivh dynolog-*.rpm --nodeps + ``` + + 方式三:直接复制bin文件夹到系统中。 diff --git a/msmonitor/docs/resources/mindstudio_insight.png b/msmonitor/docs/resources/mindstudio_insight.png new file mode 100644 index 0000000000000000000000000000000000000000..1c4085db535bea6929874ef718f2b1a32dceb387 Binary files /dev/null and b/msmonitor/docs/resources/mindstudio_insight.png differ diff --git a/msmonitor/docs/resources/msMonitor.png b/msmonitor/docs/resources/msMonitor.png new file mode 100644 index 0000000000000000000000000000000000000000..50d6110784ab2c4fe43e4c1a440f268f490e966e Binary files /dev/null and b/msmonitor/docs/resources/msMonitor.png differ diff --git a/msmonitor/docs/security_statement.md b/msmonitor/docs/security_statement.md new file mode 100644 index 0000000000000000000000000000000000000000..8a79e3553649c39be9727662f1fa225d9c79e277 --- /dev/null +++ b/msmonitor/docs/security_statement.md @@ -0,0 +1,17 @@ +## 安全声明 + +### 安全风险 + +**dynolog 原生全零监听安全风险** + +msMonitor 引入了开源第三方库 dynolog,该库的 `dynolog/src/rpc/SimpleJsonServer.cpp` 文件包含全零监听功能(bind to `in6addr_any`),存在网络暴露安全风险。msMonitor 对dynolog进行了 NPU 适配,适配引入的文件 `msmonitor/dynolog_npu/dynolog/src/rpc/SimpleJsonServer.cpp` 中包含全零监听代码,为保证工具功能和易用性,未对原生dynolog全零监听代码进行修改,该安全风险来源于 dynolog 开源三方库。 + +**风险消减措施** + +建议用户配置 iptables 等防火墙机制限制对 RPC 端口的网络访问。 + +### 通信矩阵 + +| 序号 | 代码仓 | 功能 | 源设备 | 源IP | 源端口 | 目的设备 | 目的IP | 目的端口
(侦听) | 协议 | 端口说明 | 端口配置 | 侦听端口是否可更改 | 认证方式 | 加密方式 | 所属平面 | 版本 | 特殊场景 | 备注 | +|:----|:------------|:-----------|:------------------|:---------------------|:------|:-------------------|:---------------------|:--------------|:-----------|:-------------------------------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:----------|:-----|:-----|:-------|:-----------------------|:-----|:---| +| 1 | msMonitor | dyno和dynolog RPC通信 | dyno客户端 | 运行dyno客户端进程的服务器的ip | | dynolog服务端所在服务器 | dynolog服务端所在服务器的ip | 1778 | TCP | RPC通信 | 不涉及 | 可修改 | 证书密钥 | TLS | 业务面 | 所有版本 | 无 | | diff --git a/msmonitor/dynolog_npu/CMakeLists.txt b/msmonitor/dynolog_npu/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..2aeddf3d667645f8739ef3b7aa6b55c0780ad8f3 --- /dev/null +++ b/msmonitor/dynolog_npu/CMakeLists.txt @@ -0,0 +1,94 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +cmake_minimum_required(VERSION 3.16) + +project(Dynolog VERSION 1.0) +option(BUILD_TESTS "Build the unit tests" ON) +option(USE_ODS_GRAPH_API "Enable logger to Meta ODS using public Graph API." +OFF) +option(USE_JSON_GENERATED_PERF_EVENTS "Add performance events generated using +Intel json spec, see hbt/src/perf_event/json_events/intel" +OFF) +option(USE_PROMETHEUS "Enable logging to prometheus, this requires +prometheus-cpp to be installed on the system" +OFF) +option(USE_TENSORBOARD "Enable logging to tensorboard, this requires +protobuf to be installed on the system" +ON) + +if(USE_PROMETHEUS) + find_package(prometheus-cpp CONFIG REQUIRED) +endif() + +file(READ "version.txt" DYNOLOG_VERSION) +string(STRIP ${DYNOLOG_VERSION} DYNOLOG_VERSION) + +execute_process ( + COMMAND git rev-parse --short HEAD + OUTPUT_VARIABLE DYNOLOG_GIT_REV + OUTPUT_STRIP_TRAILING_WHITESPACE +) + +set(DYNOLOG_VERSION "\"${DYNOLOG_VERSION}\"") +set(DYNOLOG_GIT_REV "\"${DYNOLOG_GIT_REV}\"") +message("Dynolog version = ${DYNOLOG_VERSION}") +message("Dynolog git rev = ${DYNOLOG_GIT_REV}") + +set(CMAKE_VERBOSE_MAKEFILE ON) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED True) +set(CMAKE_POSITION_INDEPENDENT_CODE True) + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread") +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread") + +if(BUILD_TESTS) + enable_testing() + add_subdirectory("third_party/googletest" "third_party/googletest") +endif() + +include_directories(".") +add_subdirectory(dynolog) +add_subdirectory(cli) +# The following dummy depdendency ensures the cli is built +add_dependencies(dynolog_lib dyno) +add_subdirectory(hbt) + +# Third party deps +set(BUILD_SHARED_LIBS OFF CACHE INTERNAL "") +set(BUILD_SAMPLES OFF CACHE BOOL "") +set(BUILD_TEST OFF CACHE BOOL "") +set(BUILD_SHARED_LIBS OFF CACHE BOOL "") + +set(BUILD_TESTING OFF CACHE BOOL "") +set(WITH_GFLAGS OFF CACHE BOOL "") +add_subdirectory(third_party/glog) +target_link_libraries(dynolog_lib PUBLIC glog::glog) + +set(GFLAGS_BUILD_TESTING OFF CACHE BOOL "") +add_subdirectory(third_party/gflags) +target_link_libraries(dynolog_lib PUBLIC gflags::gflags) + +# https://github.com/nlohmann/json#cmake +set(JSON_BuildTests OFF CACHE INTERNAL "") +add_subdirectory(third_party/json) +target_link_libraries(dynolog_lib PUBLIC nlohmann_json::nlohmann_json) + +add_subdirectory(third_party/pfs) +target_include_directories(dynolog_lib PUBLIC third_party/pfs/include) +target_link_libraries(dynolog_lib PUBLIC pfs) + +add_subdirectory(third_party/fmt) +target_link_libraries(dynolog_lib PUBLIC fmt::fmt) + +if(USE_TENSORBOARD) + set(Protobuf_USE_STATIC_LIBS ON) + add_subdirectory(third_party/tensorboard_logger) + target_include_directories(dynolog_lib PUBLIC third_party/tensorboard_logger/include) + target_link_libraries(dynolog_lib PUBLIC tensorboard_logger) +endif() + +if(USE_ODS_GRAPH_API) + add_subdirectory(third_party/cpr) + target_link_libraries(dynolog_lib PUBLIC cpr::cpr) +endif() \ No newline at end of file diff --git a/msmonitor/dynolog_npu/cli/Cargo.toml b/msmonitor/dynolog_npu/cli/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..65b8d94af52eaa9188ad107732060c4e6e727f0f --- /dev/null +++ b/msmonitor/dynolog_npu/cli/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "dyno" +version = "0.1.0" +edition = "2021" + +[dependencies] +anyhow = "1.0.57" +clap = { version = "4.4", features = ["derive"]} +serde_json = "1.0" +rustls = "0.21.0" +rustls-pemfile = "1.0" +webpki = "0.22" +x509-parser = "0.15" +der-parser = "8" +pem = "1.1" +chrono = "0.4" +num-bigint = "0.4" +openssl = { version = "0.10", features = ["vendored"] } +rpassword = "7.2.0" diff --git a/msmonitor/dynolog_npu/cli/src/commands/dcgm.rs b/msmonitor/dynolog_npu/cli/src/commands/dcgm.rs new file mode 100644 index 0000000000000000000000000000000000000000..9f342c04dfd93e2ae0a411e49434950a0ce2bd16 --- /dev/null +++ b/msmonitor/dynolog_npu/cli/src/commands/dcgm.rs @@ -0,0 +1,32 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use anyhow::Result; +use crate::DynoClient; +use super::utils; + +// This module contains the handling logic for dcgm + +/// Pause dcgm module profiling +pub fn run_dcgm_pause( + mut client: DynoClient, + duration_s: i32, +) -> Result<()> { + let msg = format!(r#"{{"fn":"dcgmPause", "duration_s":{}}}"#, duration_s); + utils::send_msg(&mut client, &msg)?; + let resp_str = utils::get_resp(&mut client)?; + println!("{}", resp_str); + Ok(()) +} + +/// Resume dcgm module profiling +pub fn run_dcgm_resume( + mut client: DynoClient, +) -> Result<()> { + utils::send_msg(&mut client, r#"{"fn":"dcgmResume"}"#)?; + let resp_str = utils::get_resp(&mut client)?; + println!("{}", resp_str); + Ok(()) +} \ No newline at end of file diff --git a/msmonitor/dynolog_npu/cli/src/commands/gputrace.rs b/msmonitor/dynolog_npu/cli/src/commands/gputrace.rs new file mode 100644 index 0000000000000000000000000000000000000000..677ccf27d105aa438bb9768a2d45387f5c4cecf2 --- /dev/null +++ b/msmonitor/dynolog_npu/cli/src/commands/gputrace.rs @@ -0,0 +1,213 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use anyhow::Result; +use serde_json::Value; +use crate::DynoClient; +use super::utils; + +// This module contains the handling logic for dyno gputrace + +#[derive(Debug)] +pub enum GpuTraceTriggerConfig { + DurationBased { + profile_start_time: u64, + duration_ms: u64, + }, + IterationBased { + profile_start_iteration_roundup: u64, + iterations: i64, + }, +} + +impl GpuTraceTriggerConfig { + fn config(&self) -> String { + match *self { + GpuTraceTriggerConfig::DurationBased { + profile_start_time, + duration_ms, + } => format!( + "PROFILE_START_TIME={}\nACTIVITIES_DURATION_MSECS={}", + profile_start_time, duration_ms + ), + GpuTraceTriggerConfig::IterationBased { + profile_start_iteration_roundup, + iterations, + } => format!( + r#"PROFILE_START_ITERATION=0 +PROFILE_START_ITERATION_ROUNDUP={} +ACTIVITIES_ITERATIONS={}"#, + profile_start_iteration_roundup, iterations + ), + } + } +} + +#[derive(Debug)] +pub struct GpuTraceOptions { + pub record_shapes: bool, + pub profile_memory: bool, + pub with_stacks: bool, + pub with_flops: bool, + pub with_modules: bool, +} + +impl GpuTraceOptions { + fn config(&self) -> String { + format!( + r#" +PROFILE_REPORT_INPUT_SHAPES={} +PROFILE_PROFILE_MEMORY={} +PROFILE_WITH_STACK={} +PROFILE_WITH_FLOPS={} +PROFILE_WITH_MODULES={}"#, + self.record_shapes, + self.profile_memory, + self.with_stacks, + self.with_flops, + self.with_modules + ) + } +} + +#[derive(Debug)] +pub struct GpuTraceConfig { + pub log_file: String, + pub trigger_config: GpuTraceTriggerConfig, + pub trace_options: GpuTraceOptions, +} + +impl GpuTraceConfig { + fn config(&self) -> String { + format!( + "ACTIVITIES_LOG_FILE={}\n{}{}", + self.log_file, + self.trigger_config.config(), + self.trace_options.config() + ) + } +} + +/// Gputrace command triggers GPU profiling on pytorch apps +pub fn run_gputrace( + mut client: DynoClient, + job_id: u64, + pids: &str, + process_limit: u32, + config: GpuTraceConfig, +) -> Result<()> { + let kineto_config = config.config(); + println!("Kineto config = \n{}", kineto_config); + let kineto_config = kineto_config.replace('\n', "\\n"); + + let request_json = format!( + r#" +{{ + "fn": "setKinetOnDemandRequest", + "config": "{}", + "job_id": {}, + "pids": [{}], + "process_limit": {} +}}"#, + kineto_config, job_id, pids, process_limit + ); + + utils::send_msg(&mut client, &request_json)?; + + let resp_str = utils::get_resp(&mut client)?; + + println!("response = {}\n", resp_str); + + let resp_v: Value = serde_json::from_str(&resp_str)?; + let processes = resp_v["processesMatched"].as_array().unwrap(); + + if processes.is_empty() { + println!("No processes were matched, please check --job-id or --pids flags"); + } else { + println!("Matched {} processes", processes.len()); + println!("Trace output files will be written to:"); + + for pid in processes { + let pid = pid.as_i64().unwrap(); + println!( + " {}", + config.log_file.replace(".json", &format!("_{}.json", pid)) + ); + } + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use crate::*; + + #[test] + fn test_gputrace_trigger_config() { + let trigger_config = GpuTraceTriggerConfig::DurationBased { + profile_start_time: 1000, + duration_ms: 42, + }; + assert_eq!( + trigger_config.config(), + r#"PROFILE_START_TIME=1000 +ACTIVITIES_DURATION_MSECS=42"# + ); + + let trigger_config = GpuTraceTriggerConfig::IterationBased { + profile_start_iteration_roundup: 1000, + iterations: 42, + }; + assert_eq!( + trigger_config.config(), + r#"PROFILE_START_ITERATION=0 +PROFILE_START_ITERATION_ROUNDUP=1000 +ACTIVITIES_ITERATIONS=42"# + ); + } + + #[test] + fn test_gputrace_config() { + let mut test_trace_options = GpuTraceOptions { + record_shapes: true, + profile_memory: false, + with_stacks: true, + with_flops: false, + with_modules: true, + }; + assert_eq!( + test_trace_options.config(), + r#" +PROFILE_REPORT_INPUT_SHAPES=true +PROFILE_PROFILE_MEMORY=false +PROFILE_WITH_STACK=true +PROFILE_WITH_FLOPS=false +PROFILE_WITH_MODULES=true"# + ); + + test_trace_options.profile_memory = true; + + let test_trace_config = GpuTraceConfig { + log_file: String::from("/tmp/test_trace.json"), + trigger_config: GpuTraceTriggerConfig::DurationBased { + profile_start_time: 1000, + duration_ms: 42, + }, + trace_options: test_trace_options, + }; + assert_eq!( + test_trace_config.config(), + r#"ACTIVITIES_LOG_FILE=/tmp/test_trace.json +PROFILE_START_TIME=1000 +ACTIVITIES_DURATION_MSECS=42 +PROFILE_REPORT_INPUT_SHAPES=true +PROFILE_PROFILE_MEMORY=true +PROFILE_WITH_STACK=true +PROFILE_WITH_FLOPS=false +PROFILE_WITH_MODULES=true"# + ); + } +} \ No newline at end of file diff --git a/dynolog_npu/dynolog_npu/cli/src/commands/mod.rs b/msmonitor/dynolog_npu/cli/src/commands/mod.rs similarity index 97% rename from dynolog_npu/dynolog_npu/cli/src/commands/mod.rs rename to msmonitor/dynolog_npu/cli/src/commands/mod.rs index 18950d3c1a01d972db58a614a46f08176b02c725..1f3bf17ad791c73b531323eafd6f2601b7db6b03 100644 --- a/dynolog_npu/dynolog_npu/cli/src/commands/mod.rs +++ b/msmonitor/dynolog_npu/cli/src/commands/mod.rs @@ -9,10 +9,11 @@ // handling code. Additionally, explicitly "exporting" all the command modules here allows // us to avoid having to explicitly list all the command modules in main.rs. +pub mod status; +pub mod version; pub mod dcgm; pub mod gputrace; pub mod nputrace; pub mod npumonitor; -pub mod status; -pub mod version; +pub mod utils; // ... add new command modules here \ No newline at end of file diff --git a/dynolog_npu/dynolog_npu/cli/src/commands/npumonitor.rs b/msmonitor/dynolog_npu/cli/src/commands/npumonitor.rs similarity index 66% rename from dynolog_npu/dynolog_npu/cli/src/commands/npumonitor.rs rename to msmonitor/dynolog_npu/cli/src/commands/npumonitor.rs index 1edfaea5939f5cee5df8618720d1bfa16d0071b5..239145ec2e976e225305a44d484319ea625c0e57 100644 --- a/dynolog_npu/dynolog_npu/cli/src/commands/npumonitor.rs +++ b/msmonitor/dynolog_npu/cli/src/commands/npumonitor.rs @@ -1,9 +1,6 @@ -use std::net::TcpStream; - use anyhow::Result; - -#[path = "utils.rs"] -mod utils; +use crate::DynoClient; +use super::utils; #[derive(Debug)] pub struct NpuMonitorConfig { @@ -11,28 +8,27 @@ pub struct NpuMonitorConfig { pub npu_monitor_stop: bool, pub report_interval_s: u32, pub mspti_activity_kind: String, + pub log_file: String } impl NpuMonitorConfig { fn config(&self) -> String { format!( - r#" -NPU_MONITOR_START={} + r#"NPU_MONITOR_START={} NPU_MONITOR_STOP={} REPORT_INTERVAL_S={} -MSPTI_ACTIVITY_KIND={}"#, +MSPTI_ACTIVITY_KIND={} +NPU_MONITOR_LOG_FILE={}"#, self.npu_monitor_start, self.npu_monitor_stop, self.report_interval_s, - self.mspti_activity_kind + self.mspti_activity_kind, + self.log_file ) } } -pub fn run_npumonitor( - client: TcpStream, - config: NpuMonitorConfig, -) -> Result<()> { +pub fn run_npumonitor(mut client: DynoClient, config: NpuMonitorConfig) -> Result<()> { let config_str = config.config(); println!("Npu monitor config = \n{}", config_str); let config_str = config_str.replace('\n', "\\n"); @@ -49,10 +45,8 @@ pub fn run_npumonitor( config_str ); - utils::send_msg(&client, &request_json).expect("Error sending message to service"); - - let resp_str = utils::get_resp(&client).expect("Unable to decode output bytes"); - + utils::send_msg(&mut client, &request_json)?; + let resp_str = utils::get_resp(&mut client)?; println!("response = {}", resp_str); Ok(()) diff --git a/dynolog_npu/dynolog_npu/cli/src/commands/nputrace.rs b/msmonitor/dynolog_npu/cli/src/commands/nputrace.rs similarity index 78% rename from dynolog_npu/dynolog_npu/cli/src/commands/nputrace.rs rename to msmonitor/dynolog_npu/cli/src/commands/nputrace.rs index 4bf7132de338d8eee0de556449269712617772e2..66f5d576c8381e90e3debaf7b2d351670997cce9 100644 --- a/dynolog_npu/dynolog_npu/cli/src/commands/nputrace.rs +++ b/msmonitor/dynolog_npu/cli/src/commands/nputrace.rs @@ -1,10 +1,12 @@ -use std::net::TcpStream; - +// Copyright (c) Meta Platforms, Inc. and affiliates. +// Copyright (c) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. use anyhow::Result; use serde_json::Value; - -#[path = "utils.rs"] -mod utils; +use crate::DynoClient; +use super::utils; #[derive(Debug)] pub enum NpuTraceTriggerConfig { @@ -13,7 +15,7 @@ pub enum NpuTraceTriggerConfig { duration_ms: u64, }, IterationBased { - start_step: u64, + start_step: i64, iterations: i64, }, } @@ -55,9 +57,15 @@ pub struct NpuTraceOptions { pub aic_metrics: String, pub l2_cache: bool, pub op_attr: bool, + pub msprof_tx: bool, pub gc_detect_threshold: Option, pub data_simplification: String, pub export_type: String, + pub host_sys: String, + pub sys_io: bool, + pub sys_interconnection: bool, + pub mstx_domain_include: Option, + pub mstx_domain_exclude: Option, } impl NpuTraceOptions { @@ -75,9 +83,15 @@ PROFILE_PROFILER_LEVEL={} PROFILE_AIC_METRICS={} PROFILE_L2_CACHE={} PROFILE_OP_ATTR={} +PROFILE_MSPROF_TX={} PROFILE_GC_DETECT_THRESHOLD={} PROFILE_DATA_SIMPLIFICATION={} -PROFILE_EXPORT_TYPE={}"#, +PROFILE_EXPORT_TYPE={} +PROFILE_HOST_SYS={} +PROFILE_SYS_IO={} +PROFILE_SYS_INTERCONNECTION={} +PROFILE_MSTX_DOMAIN_INCLUDE={} +PROFILE_MSTX_DOMAIN_EXCLUDE={}"#, self.record_shapes, self.profile_memory, self.with_stack, @@ -89,9 +103,15 @@ PROFILE_EXPORT_TYPE={}"#, self.aic_metrics, self.l2_cache, self.op_attr, + self.msprof_tx, self.gc_detect_threshold.map_or("None".to_string(), |v| v.to_string()), self.data_simplification, - self.export_type + self.export_type, + self.host_sys, + self.sys_io, + self.sys_interconnection, + self.mstx_domain_include.clone().map_or("None".to_string(), |v| v.to_string()), + self.mstx_domain_exclude.clone().map_or("None".to_string(), |v| v.to_string()) ) } } @@ -115,7 +135,7 @@ impl NpuTraceConfig { } pub fn run_nputrace( - client: TcpStream, + mut client: DynoClient, job_id: u64, pids: &str, process_limit: u32, @@ -137,9 +157,9 @@ pub fn run_nputrace( config_str, job_id, pids, process_limit ); - utils::send_msg(&client, &request_json).expect("Error sending message to service"); + utils::send_msg(&mut client, &request_json)?; - let resp_str = utils::get_resp(&client).expect("Unable to decode output bytes"); + let resp_str = utils::get_resp(&mut client)?; println!("response = {}", resp_str); @@ -181,7 +201,7 @@ mod test { ACTIVITIES_DURATION_MSECS=1000"# ); - let trigger_config = NpuTraceTriggerConfig::IterationBased { + let trigger_config = NpuTraceTriggerConfig::IterationBased { profile_start_step: 1000, iterations: 1000, }; @@ -213,9 +233,15 @@ ACTIVITIES_ITERATIONS=1000"# aic_metrics: "AiCoreNone".to_string(), l2_cache: true, op_attr: true, + msprof_tx: true, gc_detect_threshold: 0.1, data_simplification: "true", export_type: "Text".to_string(), + host_sys: "cpu".to_string(), + sys_io: true, + sys_interconnection: true, + mstx_domain_include: "domain1".to_string(), + mstx_domain_exclude: "domain2".to_string(), }, }; assert_eq!( @@ -234,9 +260,15 @@ PROFILE_PROFILER_LEVEL=Level0 PROFILE_AIC_METRICS=AiCoreNone PROFILE_L2_CACHE=true PROFILE_OP_ATTR=true +PROFILE_MSPROF_TX=true PROFILE_GC_DETECT_THRESHOLD=0.1 PROFILE_DATA_SIMPLIFICATION=true -PROFILE_EXPORT_TYPE=Text"# +PROFILE_EXPORT_TYPE=Text +PROFILE_HOST_SYS=cpu +PROFILE_SYS_IO=true +PROFILE_SYS_INTERCONNECTION=true +PROFILE_MSTX_DOMAIN_INCLUDE=domain1 +PROFILE_MSTX_DOMAIN_EXCLUDE=domain2"# ); } } diff --git a/msmonitor/dynolog_npu/cli/src/commands/status.rs b/msmonitor/dynolog_npu/cli/src/commands/status.rs new file mode 100644 index 0000000000000000000000000000000000000000..1be17956c12f270871878d584ea1da904c7f0418 --- /dev/null +++ b/msmonitor/dynolog_npu/cli/src/commands/status.rs @@ -0,0 +1,15 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use anyhow::Result; +use crate::DynoClient; +use super::utils; + +pub fn run_status(mut client: DynoClient) -> Result<()> { + utils::send_msg(&mut client, r#"{"fn":"getStatus"}"#)?; + let resp_str = utils::get_resp(&mut client)?; + println!("{}", resp_str); + Ok(()) +} \ No newline at end of file diff --git a/msmonitor/dynolog_npu/cli/src/commands/utils.rs b/msmonitor/dynolog_npu/cli/src/commands/utils.rs new file mode 100644 index 0000000000000000000000000000000000000000..c2fdd3de6186b1c4a60041ba9dbb276811970fa8 --- /dev/null +++ b/msmonitor/dynolog_npu/cli/src/commands/utils.rs @@ -0,0 +1,50 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use std::io::{Read, Write}; + +use anyhow::Result; + +use crate::DynoClient; + +pub fn send_msg(client: &mut DynoClient, msg: &str) -> Result<()> { + match client { + DynoClient::Secure(secure_client) => { + let msg_len: [u8; 4] = i32::try_from(msg.len()).unwrap().to_ne_bytes(); + secure_client.write_all(&msg_len)?; + secure_client.write_all(msg.as_bytes())?; + secure_client.flush()?; + } + DynoClient::Insecure(insecure_client) => { + let msg_len: [u8; 4] = i32::try_from(msg.len()).unwrap().to_ne_bytes(); + insecure_client.write_all(&msg_len)?; + insecure_client.write_all(msg.as_bytes())?; + insecure_client.flush()?; + } + } + Ok(()) +} + +pub fn get_resp(client: &mut DynoClient) -> Result { + let mut len_buf = [0u8; 4]; + let mut resp_buf; + + match client { + DynoClient::Secure(secure_client) => { + secure_client.read_exact(&mut len_buf)?; + let len = u32::from_ne_bytes(len_buf) as usize; + resp_buf = vec![0u8; len]; + secure_client.read_exact(&mut resp_buf)?; + } + DynoClient::Insecure(insecure_client) => { + insecure_client.read_exact(&mut len_buf)?; + let len = u32::from_ne_bytes(len_buf) as usize; + resp_buf = vec![0u8; len]; + insecure_client.read_exact(&mut resp_buf)?; + } + } + + Ok(String::from_utf8(resp_buf)?) +} \ No newline at end of file diff --git a/msmonitor/dynolog_npu/cli/src/commands/version.rs b/msmonitor/dynolog_npu/cli/src/commands/version.rs new file mode 100644 index 0000000000000000000000000000000000000000..31139d5685289e7b4fadd28a97c41cf943e8c291 --- /dev/null +++ b/msmonitor/dynolog_npu/cli/src/commands/version.rs @@ -0,0 +1,18 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use anyhow::Result; +use crate::DynoClient; +use super::utils; + +// This module contains the handling logic for querying dyno version + +/// Get version info +pub fn run_version(mut client: DynoClient) -> Result<()> { + utils::send_msg(&mut client, r#"{"fn":"getVersion"}"#)?; + let resp_str = utils::get_resp(&mut client)?; + println!("{}", resp_str); + Ok(()) +} \ No newline at end of file diff --git a/dynolog_npu/dynolog_npu/cli/src/main.rs b/msmonitor/dynolog_npu/cli/src/main.rs similarity index 37% rename from dynolog_npu/dynolog_npu/cli/src/main.rs rename to msmonitor/dynolog_npu/cli/src/main.rs index 8bc4a2af0e2c19d6e783663924578e3c2ad7408a..525b787327aa4e7d90925d2e962eb250ec17480e 100644 --- a/dynolog_npu/dynolog_npu/cli/src/main.rs +++ b/msmonitor/dynolog_npu/cli/src/main.rs @@ -2,14 +2,28 @@ // // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. - +use std::fs::File; +use std::io::{BufReader, Read}; +use rustls::{Certificate, RootCertStore, PrivateKey, ClientConnection, StreamOwned}; +use std::sync::Arc; use std::net::TcpStream; use std::net::ToSocketAddrs; +use std::path::PathBuf; +use std::io; +use rpassword::prompt_password; use anyhow::Result; use clap::Parser; use std::collections::HashSet; +use x509_parser::prelude::*; +use x509_parser::num_bigint::ToBigInt; +use std::fs::read_to_string; +use x509_parser::public_key::RSAPublicKey; +use x509_parser::der_parser::oid; +use num_bigint::BigUint; +use openssl::pkey::PKey; + // Make all the command modules accessible to this file. mod commands; use commands::gputrace::GpuTraceConfig; @@ -37,18 +51,22 @@ use commands::*; /// the command dispatching logic clear and concise, please keep the code in the match branch to a minimum. const DYNO_PORT: u16 = 1778; +const MIN_RSA_KEY_LENGTH: u64 = 3072; // 最小 RSA 密钥长度(位) #[derive(Debug, Parser)] +#[command(author, version, about, long_about = None)] struct Opts { - #[clap(long, default_value = "localhost")] + #[arg(long, default_value = "localhost")] hostname: String, - #[clap(long, default_value_t = DYNO_PORT)] + #[arg(long, default_value_t = DYNO_PORT)] port: u16, - #[clap(subcommand)] + #[arg(long, required = true)] + certs_dir: String, + #[command(subcommand)] cmd: Command, } -const ALLOWED_VALUES: &[&str] = &["Marker", "Kernel", "API", "Hccl", "Memory", "MemSet", "MemCpy"]; +const ALLOWED_VALUES: &[&str] = &["Marker", "Kernel", "API", "Hccl", "Memory", "MemSet", "MemCpy", "Communication"]; fn parse_mspti_activity_kinds(src: &str) -> Result{ let allowed_values: HashSet<&str> = ALLOWED_VALUES.iter().cloned().collect(); @@ -60,10 +78,49 @@ fn parse_mspti_activity_kinds(src: &str) -> Result{ return Err(format!("Invalid MSPTI activity kind: {}, Possible values: {:?}.]", kind, allowed_values)); } } - + Ok(src.to_string()) } +const ALLOWED_HOST_SYSTEM_VALUES: &[&str] = &["cpu", "mem", "disk", "network", "osrt"]; + +fn parse_host_sys(src: &str) -> Result{ + if src == "None" { + return Ok(src.to_string()); + } + + let allowed_host_sys_values: HashSet<&str> = ALLOWED_HOST_SYSTEM_VALUES.iter().cloned().collect(); + + let host_systems: Vec<&str> = src.split(',').map(|s| s.trim()).collect(); + + for host_system in &host_systems { + if !allowed_host_sys_values.contains(host_system) { + return Err(format!("Invalid NPU Trace host system: {}, Possible values: {:?}.]", host_system, + allowed_host_sys_values)); + } + } + let result = host_systems.join(","); + Ok(result) +} + +const INSTANT_START_STEP: i64 = -1; // nputrace子命令,表示从下个step开启采集任务 + +fn parse_start_step(src: &str) -> Result { + let start_step = src.trim().parse::().map_err(|e| format!("{}", e))?; + if start_step < INSTANT_START_STEP { + return Err(format!("Must be non-negative integer or {}", INSTANT_START_STEP)); + } + Ok(start_step) +} + +fn parse_iterations(src: &str) -> Result { + let iterations = src.trim().parse::().map_err(|e| format!("{}", e))?; + if iterations <= 0 { + return Err("Must be a positive integer".to_string()); + } + Ok(iterations) +} + #[derive(Debug, Parser)] enum Command { /// Check the status of a dynolog process @@ -73,44 +130,44 @@ enum Command { /// Capture gputrace Gputrace { /// Job id of the application to trace. - #[clap(long, default_value_t = 0)] + #[arg(long, default_value_t = 0)] job_id: u64, /// List of pids to capture trace for (comma separated). - #[clap(long, default_value = "0")] + #[arg(long, default_value = "0")] pids: String, /// Duration of trace to collect in ms. - #[clap(long, default_value_t = 500)] + #[arg(long, default_value_t = 500)] duration_ms: u64, /// Training iterations to collect, this takes precedence over duration. - #[clap(long, default_value_t = -1)] + #[arg(long, default_value_t = -1)] iterations: i64, /// Log file for trace. - #[clap(long)] + #[arg(long)] log_file: String, /// Unix timestamp used for synchronized collection (milliseconds since epoch). - #[clap(long, default_value_t = 0)] + #[arg(long, default_value_t = 0)] profile_start_time: u64, /// Start iteration roundup, starts an iteration based trace at a multiple /// of this value. - #[clap(long, default_value_t = 1)] + #[arg(long, default_value_t = 1)] profile_start_iteration_roundup: u64, /// Max number of processes to profile. - #[clap(long, default_value_t = 3)] + #[arg(long, default_value_t = 3)] process_limit: u32, /// Record PyTorch operator input shapes and types. - #[clap(long, action)] + #[arg(long)] record_shapes: bool, /// Profile PyTorch memory. - #[clap(long, action)] + #[arg(long)] profile_memory: bool, /// Capture Python stacks in traces. - #[clap(long, action)] + #[arg(long)] with_stacks: bool, /// Annotate operators with analytical flops. - #[clap(long, action)] + #[arg(long)] with_flops: bool, /// Capture PyTorch operator modules in traces. - #[clap(long, action)] + #[arg(long)] with_modules: bool, }, /// Capture nputrace. Subcommand functions aligned with Ascend Torch Profiler. @@ -125,7 +182,7 @@ enum Command { #[clap(long, default_value_t = 500)] duration_ms: u64, /// Training iterations to collect, this takes precedence over duration. - #[clap(long, default_value_t = -1)] + #[clap(long, value_parser = parse_iterations, allow_negative_numbers = true)] iterations: i64, /// Log file for trace. #[clap(long)] @@ -133,9 +190,9 @@ enum Command { /// Unix timestamp used for synchronized collection (milliseconds since epoch). #[clap(long, default_value_t = 0)] profile_start_time: u64, - /// Number of steps to start profile. - #[clap(long, default_value_t = 0)] - start_step: u64, + /// Number of steps to start profile, -1 means start from next step. + #[clap(long, value_parser = parse_start_step, allow_negative_numbers = true)] + start_step: i64, /// Max number of processes to profile. #[clap(long, default_value_t = 3)] process_limit: u32, @@ -172,6 +229,9 @@ enum Command { /// Whether to collect op attributes. #[clap(long, action)] op_attr: bool, + /// Whether to enable MSTX. + #[clap(long, action)] + msprof_tx: bool, /// GC detect threshold. #[clap(long)] gc_detect_threshold: Option, @@ -181,6 +241,21 @@ enum Command { /// Types of data exported by the profiler. #[clap(long, value_parser = ["Text", "Db"], default_value = "Text")] export_type: String, + /// Obtain the system data on the host side. + #[clap(long, value_parser = parse_host_sys, default_value = "None")] + host_sys: String, + /// Whether to enable sys io. + #[clap(long, action)] + sys_io: bool, + /// Whether to enable sys interconnection. + #[clap(long, action)] + sys_interconnection: bool, + /// The domain that needs to be enabled in mstx mode. + #[clap(long)] + mstx_domain_include: Option, + /// Domains that do not need to be enabled in mstx mode. + #[clap(long)] + mstx_domain_exclude: Option, }, /// Ascend MSPTI Monitor NpuMonitor { @@ -196,6 +271,9 @@ enum Command { /// MSPTI collect activity kind #[clap(long, value_parser = parse_mspti_activity_kinds, default_value = "Marker")] mspti_activity_kind: String, + /// Log file for NPU monitor. + #[clap(long, default_value = "")] + log_file: String, }, /// Pause dcgm profiling. This enables running tools like Nsight compute and avoids conflicts. DcgmPause { @@ -207,29 +285,411 @@ enum Command { DcgmResume, } -/// Create a socket connection to dynolog -fn create_dyno_client(host: &str, port: u16) -> Result { +struct ClientConfigPath { + cert_path: PathBuf, + key_path: PathBuf, + ca_cert_path: PathBuf, +} + +fn verify_certificate(cert_der: &[u8], is_root_cert: bool) -> Result<()> { + // 解析 X509 证书 + let (_, cert) = X509Certificate::from_der(cert_der) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("Failed to parse cert: {:?}", e)))?; + + // 检查证书版本是否为 X.509v3 + if cert.tbs_certificate.version != X509Version(2) { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Certificate is not X.509v3" + ).into()); + } + + // 检查证书签名算法 + let sig_alg = cert.signature_algorithm.algorithm; + + // 定义不安全的算法 OID + let md2_rsa = oid!(1.2.840.113549.1.1.2); // MD2 with RSA + let md5_rsa = oid!(1.2.840.113549.1.1.4); // MD5 with RSA + let sha1_rsa = oid!(1.2.840.113549.1.1.5); // SHA1 with RSA + + // 检查是否使用不安全的算法 + if sig_alg == md2_rsa || sig_alg == md5_rsa || sig_alg == sha1_rsa { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Certificate uses insecure signature algorithm" + ).into()); + } + + // 定义 RSA 签名算法 OID + let rsa_sha256 = oid!(1.2.840.113549.1.1.11); // RSA with SHA256 + let rsa_sha384 = oid!(1.2.840.113549.1.1.12); // RSA with SHA384 + let rsa_sha512 = oid!(1.2.840.113549.1.1.13); // RSA with SHA512 + + // 检查 RSA 密钥长度 + if sig_alg == rsa_sha256 || sig_alg == rsa_sha384 || sig_alg == rsa_sha512 { + // 获取公钥 + if let Ok((_, public_key)) = SubjectPublicKeyInfo::from_der(&cert.tbs_certificate.subject_pki.subject_public_key.data) { + if let Ok((_, rsa_key)) = RSAPublicKey::from_der(&public_key.subject_public_key.data) { + // 检查 RSA 密钥长度 + let modulus = BigUint::from_bytes_be(&rsa_key.modulus); + let key_length = modulus.bits(); + if key_length < MIN_RSA_KEY_LENGTH { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("RSA key length {} bits is less than required {} bits", key_length, MIN_RSA_KEY_LENGTH) + ).into()); + } + } + } + } + + // 检查证书的扩展域 + let mut has_ca_constraint = false; + let mut has_key_usage = false; + let mut has_crl_sign = false; + let mut has_cert_sign = false; + + for ext in cert.tbs_certificate.extensions() { + if ext.oid == oid_registry::OID_X509_EXT_BASIC_CONSTRAINTS { + if let Ok((_, constraints)) = BasicConstraints::from_der(ext.value) { + has_ca_constraint = constraints.ca; + } else { + println!("Failed to parse Basic Constraints"); + } + } else if ext.oid == oid_registry::OID_X509_EXT_KEY_USAGE { + println!("Found Key Usage extension"); + if let Ok((_, usage)) = KeyUsage::from_der(ext.value) { + has_key_usage = true; + has_cert_sign = usage.key_cert_sign(); + has_crl_sign = usage.crl_sign(); + } else { + println!("Failed to parse Key Usage"); + } + } + } + + // 根据证书类型进行不同的验证 + if is_root_cert { + // 根证书验证要求 + if !has_ca_constraint { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Root certificate must have CA constraint" + ).into()); + } + if !has_key_usage { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Root certificate must have key usage extension" + ).into()); + } + if !has_cert_sign { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Root certificate must have certificate signature permission" + ).into()); + } + if !has_crl_sign { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Root certificate must have CRL signature permission" + ).into()); + } + } else { + // 客户端证书验证要求 + if has_ca_constraint { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Client certificate should not have CA constraint" + ).into()); + } + if !has_key_usage { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Client certificate must have key usage extension" + ).into()); + } + } + + // 检查证书有效期 + let now = chrono::Utc::now(); + let not_before = chrono::DateTime::from_timestamp( + cert.tbs_certificate.validity.not_before.timestamp(), + 0 + ).ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid not_before date"))?; + + let not_after = chrono::DateTime::from_timestamp( + cert.tbs_certificate.validity.not_after.timestamp(), + 0 + ).ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid not_after date"))?; + + if now < not_before { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("Certificate is not yet valid. Valid from: {}", not_before) + ).into()); + } + + if now > not_after { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("Certificate has expired. Expired at: {}", not_after) + ).into()); + } + + Ok(()) +} + +fn is_cert_revoked(cert_der: &[u8], crl_path: &PathBuf) -> Result { + // 解析 X509 证书 + let (_, cert) = X509Certificate::from_der(cert_der) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("Failed to parse cert: {:?}", e)))?; + + // 读取 CRL 文件 + let crl_data = read_to_string(crl_path)?; + let (_, pem) = pem::parse_x509_pem(crl_data.as_bytes()) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("Failed to parse CRL PEM: {:?}", e)))?; + + // 解析 CRL + let (_, crl) = CertificateRevocationList::from_der(&pem.contents) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("Failed to parse CRL: {:?}", e)))?; + + // 检查 CRL 的有效期 + let now = chrono::Utc::now(); + let crl_not_before = chrono::DateTime::from_timestamp( + crl.tbs_cert_list.this_update.timestamp(), + 0 + ).ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid CRL this_update date"))?; + + let crl_not_after = if let Some(next_update) = crl.tbs_cert_list.next_update { + chrono::DateTime::from_timestamp( + next_update.timestamp(), + 0 + ).ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid CRL next_update date"))? + } else { + crl_not_before + chrono::Duration::days(365) + }; + + // 检查 CRL 是否在有效期内 + if now < crl_not_before { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("CRL is not yet valid. Valid from: {}", crl_not_before) + ).into()); + } + + if now > crl_not_after { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("CRL has expired. Expired at: {}", crl_not_after) + ).into()); + } + + // 获取证书序列号 + let cert_serial = cert.tbs_certificate.serial.to_bigint() + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Failed to convert certificate serial to BigInt"))?; + + // 检查 CRL 吊销条目 + for revoked in crl.iter_revoked_certificates() { + let revoked_serial = revoked.user_certificate.to_bigint() + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Failed to convert revoked certificate serial to BigInt"))?; + + if revoked_serial == cert_serial { + return Ok(true); + } + } + Ok(false) +} + +enum DynoClient { + Secure(StreamOwned), + Insecure(TcpStream), +} + +fn create_dyno_client( + host: &str, + port: u16, + certs_dir: &str, +) -> Result { + if certs_dir == "NO_CERTS" { + println!("Running in no-certificate mode"); + create_dyno_client_with_no_certs(host, port) + } else { + println!("Running in certificate mode"); + let certs_dir = PathBuf::from(certs_dir); + let config = ClientConfigPath { + cert_path: certs_dir.join("client.crt"), + key_path: certs_dir.join("client.key"), + ca_cert_path: certs_dir.join("ca.crt"), + }; + let client = create_dyno_client_with_certs(host, port, &config)?; + Ok(DynoClient::Secure(client)) + } +} + +fn create_dyno_client_with_no_certs( + host: &str, + port: u16, +) -> Result { let addr = (host, port) .to_socket_addrs()? .next() .expect("Failed to connect to the server"); + let stream = TcpStream::connect(addr)?; + Ok(DynoClient::Insecure(stream)) +} - TcpStream::connect(addr).map_err(|err| err.into()) +// 安全清除密码的函数 +fn secure_clear_password(password: &mut Vec) { + if !password.is_empty() { + // 使用零覆盖密码数据 + for byte in password.iter_mut() { + *byte = 0; + } + // 清空向量 + password.clear(); + // 收缩向量容量,释放内存 + password.shrink_to_fit(); + } } +fn create_dyno_client_with_certs( + host: &str, + port: u16, + config: &ClientConfigPath, +) -> Result> { + let addr = (host, port) + .to_socket_addrs()? + .next() + .ok_or_else(|| io::Error::new( + io::ErrorKind::NotFound, + "Could not resolve the host address" + ))?; + + let stream = TcpStream::connect(addr)?; + + println!("Loading CA cert from: {}", config.ca_cert_path.display()); + let mut root_store = RootCertStore::empty(); + let ca_file = File::open(&config.ca_cert_path)?; + let mut ca_reader = BufReader::new(ca_file); + let ca_certs = rustls_pemfile::certs(&mut ca_reader)?; + for ca_cert in &ca_certs { + verify_certificate(ca_cert, true)?; // 验证根证书 + } + for ca_cert in ca_certs { + root_store.add(&Certificate(ca_cert))?; + } + + println!("Loading client cert from: {}", config.cert_path.display()); + let cert_file = File::open(&config.cert_path)?; + let mut cert_reader = BufReader::new(cert_file); + let certs = rustls_pemfile::certs(&mut cert_reader)?; + + // 检查客户端证书的基本要求 + for cert in &certs { + verify_certificate(cert, false)?; // 验证客户端证书 + } + + // 检查证书吊销状态 + let crl_path = config.cert_path.parent().unwrap().join("ca.crl"); + if crl_path.exists() { + println!("Checking CRL file: {}", crl_path.display()); + for cert in &certs { + match is_cert_revoked(cert, &crl_path) { + Ok(true) => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Certificate is revoked" + ).into()); + } + Ok(false) => { + continue; + } + Err(e) => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("CRL verification failed: {}", e) + ).into()); + } + } + } + } else { + println!("CRL file does not exist: {}", crl_path.display()); + } + + let certs = certs.into_iter().map(Certificate).collect(); + + println!("Loading client key from: {}", config.key_path.display()); + let key_file = File::open(&config.key_path)?; + let mut key_reader = BufReader::new(key_file); + + // 检查私钥是否加密 + let mut key_data = Vec::new(); + key_reader.read_to_end(&mut key_data)?; + let key_str = String::from_utf8_lossy(&key_data); + let is_encrypted = key_str.contains("ENCRYPTED"); + + // 根据是否加密来加载私钥 + let keys = if is_encrypted { + // 如果私钥是加密的,请求用户输入密码 + let mut password = prompt_password("Please enter the certificate password: ")?.into_bytes(); + let pkey = PKey::private_key_from_pem_passphrase(&key_data, &password) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("Failed to decrypt private key: {}", e)))?; + + // 手动清除密码 + secure_clear_password(&mut password); + + // 返回私钥 + vec![pkey.private_key_to_der() + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("Failed to convert private key to DER: {}", e)))?] + } else { + // 如果私钥未加密,直接加载 + let mut key_reader = BufReader::new(File::open(&config.key_path)?); + rustls_pemfile::pkcs8_private_keys(&mut key_reader)? + }; + + if keys.is_empty() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "No private key found in the key file" + ).into()); + } + let key = PrivateKey(keys[0].clone()); + + let config = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_client_auth_cert(certs, key)?; + + let server_name = rustls::ServerName::try_from(host) + .map_err(|e| io::Error::new( + io::ErrorKind::InvalidInput, + format!("Invalid hostname: {}", e) + ))?; + + let conn = rustls::ClientConnection::new( + Arc::new(config), + server_name + )?; + + Ok(StreamOwned::new(conn, stream)) +} + + fn main() -> Result<()> { let Opts { hostname, port, + certs_dir, cmd, } = Opts::parse(); - let dyno_client = - create_dyno_client(&hostname, port).expect("Couldn't connect to the server..."); + let client = create_dyno_client(&hostname, port, &certs_dir) + .expect("Couldn't connect to the server..."); match cmd { - Command::Status => status::run_status(dyno_client), - Command::Version => version::run_version(dyno_client), + Command::Status => status::run_status(client), + Command::Version => version::run_version(client), Command::Gputrace { job_id, pids, @@ -268,7 +728,7 @@ fn main() -> Result<()> { trigger_config, trace_options, }; - gputrace::run_gputrace(dyno_client, job_id, &pids, process_limit, trace_config) + gputrace::run_gputrace(client, job_id, &pids, process_limit, trace_config) } Command::Nputrace { job_id, @@ -290,9 +750,15 @@ fn main() -> Result<()> { aic_metrics, l2_cache, op_attr, + msprof_tx, gc_detect_threshold, data_simplification, export_type, + host_sys, + sys_io, + sys_interconnection, + mstx_domain_include, + mstx_domain_exclude, } => { let trigger_config = if iterations > 0 { NpuTraceTriggerConfig::IterationBased { @@ -318,33 +784,41 @@ fn main() -> Result<()> { aic_metrics, l2_cache, op_attr, + msprof_tx, gc_detect_threshold, data_simplification, export_type, + host_sys, + sys_io, + sys_interconnection, + mstx_domain_include, + mstx_domain_exclude, }; let trace_config = NpuTraceConfig { log_file, trigger_config, trace_options, }; - nputrace::run_nputrace(dyno_client, job_id, &pids, process_limit, trace_config) + nputrace::run_nputrace(client, job_id, &pids, process_limit, trace_config) } Command::NpuMonitor { npu_monitor_start, npu_monitor_stop, report_interval_s, mspti_activity_kind, + log_file, } => { let npu_mon_config = NpuMonitorConfig { npu_monitor_start, npu_monitor_stop, report_interval_s, - mspti_activity_kind + mspti_activity_kind, + log_file }; - npumonitor::run_npumonitor(dyno_client, npu_mon_config) + npumonitor::run_npumonitor(client, npu_mon_config) } - Command::DcgmPause { duration_s } => dcgm::run_dcgm_pause(dyno_client, duration_s), - Command::DcgmResume => dcgm::run_dcgm_resume(dyno_client), + Command::DcgmPause { duration_s } => dcgm::run_dcgm_pause(client, duration_s), + Command::DcgmResume => dcgm::run_dcgm_resume(client), // ... add new commands here } } \ No newline at end of file diff --git a/msmonitor/dynolog_npu/cmake/Findopenssl.cmake b/msmonitor/dynolog_npu/cmake/Findopenssl.cmake new file mode 100644 index 0000000000000000000000000000000000000000..746134a7ac94fb96822018d755a894005cde4928 --- /dev/null +++ b/msmonitor/dynolog_npu/cmake/Findopenssl.cmake @@ -0,0 +1,81 @@ +set(PACKAGE_VERSION 3.0.16) + +set(PKG_NAME openssl) +set(SHA256_VALUE "47ad8d3b2745717edf612fd75366faa3da4ef36b87343632de0df2433f425721") +set(GIT_TAG "openssl-3.0.16") +set(DOWNLOAD_PATH "${CMAKE_SOURCE_DIR}/third_party") +set(DIR_NAME "${DOWNLOAD_PATH}/openssl") +set(LIBDIR "lib64") + +function(download_opensource_pkg pkg_name) + message("start to download ${pkg_name}...") + set(options) + set(oneValueArgs SHA256 GIT_TAG DOWNLOAD_PATH DIR_NAME BUILD_CMD) + set(multiValueArgs PATCHES) + cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + if (NOT PKG_DOWNLOAD_PATH) + set(PKG_DOWNLOAD_PATH "${CMAKE_SOURCE_DIR}/third_party") + endif() + file(MAKE_DIRECTORY ${PKG_DOWNLOAD_PATH}) + + execute_process( + WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/cmake + COMMAND bash download_opensource.sh ${pkg_name} ${PKG_DOWNLOAD_PATH} ${PKG_SHA256} ${PKG_GIT_TAG} + RESULT_VARIABLE RESULT + ) + if (NOT RESULT EQUAL 0) + message(FATAL_ERROR "Failed to download ${pkg_name}(${RESULT}).") + endif() + if (PKG_BUILD_CMD) + execute_process(COMMAND bash -c "cd ${PKG_DOWNLOAD_PATH}/${DIR_NAME};${PKG_BUILD_CMD}") + endif() +endfunction() + +download_opensource_pkg(${PKG_NAME} + SHA256 ${SHA256_VALUE} + GIT_TAG ${GIT_TAG} + DOWNLOAD_PATH ${DOWNLOAD_PATH} +) + +include_directories(${DIR_NAME}/include) +set(BUILD_DEPENDENCY_PATH "${DOWNLOAD_PATH}/openssl_build_dependency") +file(GLOB OPENSSL_LIB "${BUILD_DEPENDENCY_PATH}/${LIBDIR}/libssl.a") +file(GLOB CRYPTO_LIB "${BUILD_DEPENDENCY_PATH}/${LIBDIR}/libcrypto.a") +if (OPENSSL_LIB AND CRYPTO_LIB) + set(${PKG_NAME}_FOUND TRUE) + set(${PKG_NAME}_LIBRARIES "${OPENSSL_LIB};${CRYPTO_LIB}") + return() +endif() + +execute_process( + WORKING_DIRECTORY ${DIR_NAME} + COMMAND ./config -fPIC no-shared --prefix=${BUILD_DEPENDENCY_PATH} --libdir=${LIBDIR} + RESULT_VARIABLE RESULT +) +if (NOT RESULT EQUAL 0) + message(FATAL_ERROR "Failed to build openssl. ${RESULT}") +endif() + +execute_process( + WORKING_DIRECTORY ${DIR_NAME} + COMMAND make -j16 + RESULT_VARIABLE RESULT +) +if (NOT RESULT EQUAL 0) + message(FATAL_ERROR "Failed to build openssl. ${RESULT}") +endif() + +execute_process( + WORKING_DIRECTORY ${DIR_NAME} + COMMAND make install +) + +file(GLOB OPENSSL_LIB "${BUILD_DEPENDENCY_PATH}/${LIBDIR}/libssl.a") +file(GLOB CRYPTO_LIB "${BUILD_DEPENDENCY_PATH}/${LIBDIR}/libcrypto.a") +if (NOT OPENSSL_LIB OR NOT CRYPTO_LIB) + message(FATAL_ERROR "Failed to build openssl.") +endif() + +set(${PKG_NAME}_LIBRARIES "${OPENSSL_LIB};${CRYPTO_LIB}") +set(${PKG_NAME}_FOUND TRUE) diff --git a/msmonitor/dynolog_npu/cmake/config.ini b/msmonitor/dynolog_npu/cmake/config.ini new file mode 100644 index 0000000000000000000000000000000000000000..ef6cbf729d1082ac22f55511f1f1a2a93bfd3420 --- /dev/null +++ b/msmonitor/dynolog_npu/cmake/config.ini @@ -0,0 +1,2 @@ +[openssl] +url = https://gitee.com/mirrors/openssl.git \ No newline at end of file diff --git a/msmonitor/dynolog_npu/cmake/download_opensource.sh b/msmonitor/dynolog_npu/cmake/download_opensource.sh new file mode 100644 index 0000000000000000000000000000000000000000..f1ba1cd859263f00cb424567810789c9e562e6a6 --- /dev/null +++ b/msmonitor/dynolog_npu/cmake/download_opensource.sh @@ -0,0 +1,84 @@ +#!/bin/bash + +if [ "$#" -lt 2 ]; then + echo "Usage: $0 [ ] [ ]" + exit 1 +fi + +pkg_name=$1 +path=$2 + +if [ "$#" -ge 3 ]; then + sha256_value=$3 +fi +if [ "$#" -ge 4 ]; then + tag=$4 +fi + +url=$(awk -F " = " '/\['${pkg_name}'\]/{a=1}a==1&&$1~/url/{print $2;exit}' config.ini) +lib_path=$MSTT_LIB_PATH +if [ -n "$lib_path" ]; then + url=${lib_path}$(echo $url | awk -F '/' -v OFS='/' '{print $5,$8}') +fi +if [[ ! $url = https* ]]; then + echo "The URL of $pkg_name is illegal." + exit 1 +fi + +echo "Start to download ${url}..." + +if [ ! -d "$path" ]; then + echo "The specified path does not exist: $path" + exit 1 +fi +cd ${path} + +extension=$(echo "${url}" | awk -F'[./]' '{print $NF}') +if [[ "${extension}" == "gz" || "${extension}" == "zip" ]]; then + fullname="${path}/$(basename "${url}")" + if [[ -e ${fullname} ]]; then + echo "Source ${fullname} is exists, will not download again." + else + curl -L "${url}" -o ${fullname} -k + if [ $? -eq 0 ]; then + echo "Download successful: ${url}" + else + echo "Download failed: ${url}" + exit 1 + fi + fi + + if [[ ! -z "${sha256_value}" ]]; then + sha256data=$(sha256sum "${fullname}" | cut -d' ' -f1) + if [[ "${sha256data}" != "${sha256_value}" ]]; then + echo "Failed to verify sha256: ${url}" + exit 1 + fi + fi + + if [[ "${extension}" == "gz" ]]; then + tar -zxvf ${fullname} -C ./ -n > /dev/null + elif [[ "${extension}" == "zip" ]]; then + unzip -n ${fullname} -d ./ > /dev/null + fi +elif [[ "${extension}" == "git" ]]; then + repository="$(basename ${url} .git)" + if [[ -e ${repository} ]]; then + echo "Source ${repository} is exists, will not clone again." + else + if [[ -z "${tag}" ]]; then + git clone ${url} + else + git clone ${url} -b "${tag}" + fi + if [ $? -eq 0 ]; then + echo "Download successful: ${url}" + else + echo "Download failed: ${url}" + exit 1 + fi + fi +else + echo "Unknow url ${url}" + exit 1 +fi diff --git a/msmonitor/dynolog_npu/dynolog/src/CMakeLists.txt b/msmonitor/dynolog_npu/dynolog/src/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d9b17c187e3edf4f424bd5d1db2ae1f82c411096 --- /dev/null +++ b/msmonitor/dynolog_npu/dynolog/src/CMakeLists.txt @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +set(CMAKE_SKIP_RPATH TRUE) + +cmake_minimum_required(VERSION 3.16) +add_definitions(-DDYNOLOG_VERSION=${DYNOLOG_VERSION} -DDYNOLOG_GIT_REV=${DYNOLOG_GIT_REV}) + +message("Use Prometheus = ${USE_PROMETHEUS}") +message("Use ODS Graph API = ${USE_ODS_GRAPH_API}") +message("Use Tensorboard = ${USE_TENSORBOARD}") + +# our build script will first create a src/ dir where all source code will exist +file (GLOB dynolog_src "*.h" "*.cpp") + +# Remove main from library, only needed for exec. +list(REMOVE_ITEM dynolog_src "${CMAKE_CURRENT_SOURCE_DIR}/Main.cpp") +add_library(dynolog_lib ${dynolog_src}) + +if(USE_ODS_GRAPH_API) + target_compile_options(dynolog_lib PUBLIC "-DUSE_GRAPH_ENDPOINT") +endif() + +if(USE_TENSORBOARD) + target_compile_options(dynolog_lib PUBLIC "-DUSE_TENSORBOARD") +endif() + +if(USE_PROMETHEUS) + find_package(prometheus-cpp CONFIG REQUIRED) + add_definitions(-DUSE_PROMETHEUS) + target_link_libraries(dynolog_lib PRIVATE prometheus-cpp::pull) +endif() + +target_compile_options(dynolog_lib PRIVATE + -fPIC + -fstack-protector-all + -ftrapv +) + +target_link_options(dynolog_lib PRIVATE + -Wl,-z,relro,-z,now,-z,noexecstack + -s +) + +target_link_libraries(dynolog_lib PUBLIC Monitor) +target_link_libraries(dynolog_lib PUBLIC BuiltinMetrics) + +add_subdirectory(rpc) + +add_subdirectory(ipcfabric) +target_link_libraries(dynolog_lib PUBLIC dynolog_ipcfabric_lib) + +# depends on ipcfabric +add_subdirectory(tracing) +target_link_libraries(dynolog_lib PUBLIC dynolog_ipcmonitor_lib) + +add_subdirectory(gpumon) +target_link_libraries(dynolog_lib PUBLIC dynolog_dcgm_lib "-ldl") + +add_subdirectory(rdmamon) +target_link_libraries(dynolog_lib PUBLIC dynolog_rdmamon_lib) + +add_subdirectory(metric_frame) + +add_executable(dynolog Main.cpp) +target_link_libraries(dynolog PRIVATE dynolog_lib dynolog_rpc_lib) + +target_compile_options(dynolog PRIVATE + -fPIC + -fstack-protector-all + -ftrapv +) + +target_link_options(dynolog PRIVATE + -Wl,-z,relro,-z,now,-z,noexecstack + -s +) \ No newline at end of file diff --git a/msmonitor/dynolog_npu/dynolog/src/DynologTensorBoardLogger.cpp b/msmonitor/dynolog_npu/dynolog/src/DynologTensorBoardLogger.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5d6ae0d03edf1f74ae4f6537e2206c0527d7a581 --- /dev/null +++ b/msmonitor/dynolog_npu/dynolog/src/DynologTensorBoardLogger.cpp @@ -0,0 +1,173 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "DynologTensorBoardLogger.h" + +#include + +#include "hbt/src/common/System.h" + +#include +#include +#include +#include +#include +#include + +#ifdef USE_TENSORBOARD +DEFINE_string(metric_log_dir, "", "The Path to store tensorboard logs"); + +namespace dynolog { + +const std::string TensorBoardLoggerImpl::log_file_name_ = "tfevents.pb"; +std::filesystem::path TensorBoardLoggerManager::log_path_ = ""; + +DynologTensorBoardLogger::DynologTensorBoardLogger(const std::string& metric_log_dir) + : logPath_(metric_log_dir) +{ + if (!validateLogDir(logPath_)) { + throw std::runtime_error("Unable to record logs in the target folder"); + } + + LOG(INFO) << "Initialized tensorboard logger on = " << logPath_; +} + +void DynologTensorBoardLogger::finalize() +{ + TensorBoardLoggerManager::logPath(logPath_); + auto logging_guard = TensorBoardLoggerManager::singleton(); + auto prom = logging_guard.manager; + auto deviceId = dynamic_metrics_["deviceId"] == "-1" ? "host": dynamic_metrics_["deviceId"]; + auto kind = dynamic_metrics_["kind"]; + std::string real_tag = kind == "Marker" + ? kind + "/" + dynamic_metrics_["domain"] + : kind; + std::string metric_name = "duration"; + MsptiMetricDesc desc{deviceId, kind, real_tag, metric_name, kvs_["duration"]}; + prom->log(desc); +} + +bool DynologTensorBoardLogger::validateLogDir(const std::string& path) +{ + std::filesystem::path log_path(path); + + static const std::unordered_map INVALID_CHAR = { + {"\n", "\\n"}, {"\f", "\\f"}, {"\r", "\\r"}, {"\b", "\\b"}, {"\t", "\\t"}, + {"\v", "\\v"}, {"\u007F", "\\u007F"}, {"\"", "\\\""}, {"'", "\'"}, + {"\\", "\\\\"}, {"%", "\\%"}, {">", "\\>"}, {"<", "\\<"}, {"|", "\\|"}, + {"&", "\\&"}, {"$", "\\$"}, {";", "\\;"}, {"`", "\\`"} + }; + for (auto &item: INVALID_CHAR) { + if (path.find(item.first) != std::string::npos) { + LOG(ERROR) << "The path contains invalid character: " << item.second; + return false; + } + } + + if (!std::filesystem::exists(log_path)) { + LOG(ERROR) << "Error: Path does not exist: " << path; + return false; + } + + if (!std::filesystem::is_directory(log_path)) { + LOG(ERROR) << "Error: Path is not a directory: " << path; + return false; + } + + if (std::filesystem::is_symlink(log_path)) { + LOG(ERROR) << "Error: Path is a symbolic link: " << path; + return false; + } + + struct stat info; + if (stat(path.c_str(), &info) != 0) { + LOG(ERROR) << "Error: Cannot stat path: " << path; + return false; + } + + uid_t current_uid = getuid(); + if (info.st_uid != current_uid && current_uid != 0) { + LOG(ERROR) << "Error: Path is not owned by current user"; + return false; + } + return true; +} + +// static +std::shared_ptr TensorBoardLoggerManager::singleton_() +{ + static std::shared_ptr manager_ = + std::make_shared(); + return manager_; +} + +// static +TensorBoardLoggerManager::LoggingGuard TensorBoardLoggerManager::singleton() +{ + auto s = singleton_(); + return LoggingGuard{.manager = s, .lock_guard = s->lock()}; +} + +bool TensorBoardLoggerManager::isValidMetric(const MsptiMetricDesc &desc) +{ + auto it = validMetrics_.find(desc.kind_); + if (it == validMetrics_.end() || !it->second.count(desc.metric_name_)) { + return false; + } + return true; +} + +uint64_t TensorBoardLoggerManager::getCurStep(const std::string& device, const std::string& kind) +{ + auto key = std::make_pair(device, kind); + return device_kind2_step_[key]++; +} + +void TensorBoardLoggerManager::log(const MsptiMetricDesc& desc) +{ + if (!isValidMetric(desc)) { + return; + } + + auto device = desc.device_id_; + // 读取tensorboardImpl,调用Log方法写入 + auto it = device_loggers_.find(device); + std::shared_ptr logger; + if (it == device_loggers_.end()) { + std::string device_log_path = log_path_ / ("device_" + device); + device_loggers_[device] = std::make_shared(device_log_path, ""); + } + logger = device_loggers_[device]; + logger->log(desc.tag_, desc.val_, getCurStep(device, desc.kind_)); +} + +void TensorBoardLoggerImpl::log(const std::string &key, double val, uint64_t step) +{ + if (!std::filesystem::exists(log_path_)) { + std::error_code ec; + std::filesystem::create_directories(log_path_, ec); + if (ec) { + LOG(ERROR) << "failed to create log dir: " << log_path_ << "errorcode: " << ec.message(); + return; + } + } + + if (logger_ == nullptr) { + logger_ = std::make_shared(log_path_ / log_file_name_); + } + logger_->add_scalar(key, step, val); +} +} +#endif \ No newline at end of file diff --git a/msmonitor/dynolog_npu/dynolog/src/DynologTensorBoardLogger.h b/msmonitor/dynolog_npu/dynolog/src/DynologTensorBoardLogger.h new file mode 100644 index 0000000000000000000000000000000000000000..26241aaf46be01248c5140c8fdf5faabc4fadf90 --- /dev/null +++ b/msmonitor/dynolog_npu/dynolog/src/DynologTensorBoardLogger.h @@ -0,0 +1,122 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +#include "dynolog/src/Logger.h" + +#include "MsMonitorMetrics.h" + +#ifdef USE_TENSORBOARD + +#include "tensorboard_logger.h" + +DECLARE_string(metric_log_dir); + +namespace dynolog { + +class TensorBoardLoggerImpl { +public: + explicit TensorBoardLoggerImpl(std::string log_path, std::string tag = "") + : log_path_(log_path), tag_(tag) {}; + void log(const std::string& key, double val, uint64_t step); +private: + std::filesystem::path log_path_; + std::string tag_; + static const std::string log_file_name_; + std::shared_ptr logger_; +}; + +class TensorBoardLoggerManager { +public: + struct LoggingGuard { + std::shared_ptr manager; + std::lock_guard lock_guard; + }; + + void log(const MsptiMetricDesc& desc); + + static void logPath(const std::string& cfg_log_path) + { + log_path_ = cfg_log_path; + } + + static LoggingGuard singleton(); + + bool isValidMetric(const MsptiMetricDesc& desc); + + uint64_t getCurStep(const std::string& device, const std::string& kind); + +private: + std::lock_guard lock() + { + return std::lock_guard{mutex_}; + } + static std::shared_ptr singleton_(); + + std::mutex mutex_; + static std::filesystem::path log_path_; + + std::unordered_map> device_loggers_; + std::map, std::uint64_t> device_kind2_step_; +}; + +class DynologTensorBoardLogger final : public Logger { +public: + explicit DynologTensorBoardLogger(const std::string& metric_log_dir); + void setTimestamp(Timestamp ts) override {} + + void logInt(const std::string& key, int64_t val) override + { + kvs_[key] = static_cast(val); + } + + void logFloat(const std::string& key, float val) override + { + kvs_[key] = static_cast(val); + } + + void logUint(const std::string& key, uint64_t val) override + { + kvs_[key] = static_cast(val); + } + + // logStr for dynamic metris + void logStr(const std::string& key, const std::string& val) override + { + if (validDynamicMetrics_.count(key)) { + dynamic_metrics_[key] = val; + } + } + + void finalize() override; + +private: + bool validateLogDir(const std::string& path); + +private: + std::unordered_map kvs_; + std::unordered_map dynamic_metrics_; + std::string logPath_; + std::string hostName_; +}; + +} // namespace dynolog +#endif \ No newline at end of file diff --git a/msmonitor/dynolog_npu/dynolog/src/LibkinetoConfigManager.cpp b/msmonitor/dynolog_npu/dynolog/src/LibkinetoConfigManager.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9fcbb141a90846bd7c5f4c22779807f13e70208d --- /dev/null +++ b/msmonitor/dynolog_npu/dynolog/src/LibkinetoConfigManager.cpp @@ -0,0 +1,361 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include "dynolog/src/LibkinetoConfigManager.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "hbt/src/common/System.h" +#ifdef __linux__ +#include +#endif + +namespace dynolog { + +namespace { + +const int VerboseLevel = 2; +constexpr std::chrono::seconds kKeepAliveTimeSecs(60); +constexpr char kConfigFile[] = "/etc/libkineto.conf"; + +inline void setThreadName(const std::string& name) { +#ifdef __linux__ + constexpr size_t kMaxBuff = 16; + std::array buff; + std::size_t len = name.copy(buff.begin(), 0, kMaxBuff - 1); + buff[len] = '\0'; + ::prctl(PR_SET_NAME, buff.begin(), 0, 0, 0); +#endif +} + +} // namespace + +static std::string addTraceIdToConfigString( + const std::string& trace_id, + const std::string& config) { + const std::string kTraceIdIdentifier = "REQUEST_TRACE_ID"; + return fmt::format( + R"( + {} + {}={})", + config, + kTraceIdIdentifier, + trace_id); +} + +static std::string generateTraceId(int32_t pid) { + // Hostname + PID + timestamp should be a unique trace id in the context + // of this code's execution + std::string str_trace_id = fmt::format( + "{}:{}:{}", facebook::hbt::getHostName(), pid, std::time(nullptr)); + std::size_t hashed_trace_id = std::hash{}(str_trace_id); + return std::to_string(hashed_trace_id); +} + +LibkinetoConfigManager::LibkinetoConfigManager() { + managerThread_ = new std::thread(&LibkinetoConfigManager::start, this); +} + +LibkinetoConfigManager::~LibkinetoConfigManager() { + stopFlag_ = true; + managerCondVar_.notify_one(); + managerThread_->join(); + delete managerThread_; + managerThread_ = nullptr; +} + +std::shared_ptr LibkinetoConfigManager::getInstance() { + static auto instance = std::make_shared(); + return instance; +} + +void LibkinetoConfigManager::start() { + setThreadName("kinetoConfigMgr"); + // Periodically clean the job table and check base config changes. + // If a libkineto instance hasn't contacted us for a while, remove it. + LOG(INFO) << "Starting LibkinetoConfigManager runloop"; + while (true) { + refreshBaseConfig(); + std::unique_lock lock(mutex_); + managerCondVar_.wait_for(lock, kKeepAliveTimeSecs); + if (stopFlag_) { + break; + } + runGc(); + } +} + +// return "" on errors. Otherwise a config string. +static std::string readConfigFromConfigFile(const char* filename) { + // Read whole file into a string. + std::ifstream file(filename); + if (!file) { + return ""; + } + std::string conf; + try { + conf.assign( + std::istreambuf_iterator(file), std::istreambuf_iterator()); + } catch (std::exception& e) { + LOG(ERROR) << "Error in reading libkineto config from config file: " + << e.what(); + } + return conf; +} + +void LibkinetoConfigManager::refreshBaseConfig() { + auto cfg = readConfigFromConfigFile(kConfigFile); + if (!cfg.empty() && cfg != baseConfig_) { + std::lock_guard guard(mutex_); + baseConfig_ = cfg; + } +} + +void LibkinetoConfigManager::runGc() { + auto t = std::chrono::system_clock::now(); + int job_count = jobs_.size(); + for (auto job_it = jobs_.begin(); job_it != jobs_.end();) { + auto& procs = job_it->second; + for (auto proc_it = procs.begin(); proc_it != procs.end();) { + struct LibkinetoProcess& proc = proc_it->second; + if ((t - proc.lastRequestTime) > kKeepAliveTimeSecs) { + LOG(INFO) << fmt::format( + "Stopped tracking process ({}) from job {}", + fmt::join(proc_it->first, ","), + job_it->first); + onProcessCleanup(proc_it->first); + proc_it = procs.erase(proc_it); + } else { + proc_it++; + } + } + if (procs.empty()) { + LOG(INFO) << "Stopped tracking job " << job_it->first; + jobInstancesPerGpu_.erase(job_it->first); + job_it = jobs_.erase(job_it); + } else { + job_it++; + } + } + if (job_count != jobs_.size()) { + LOG(INFO) << "Tracked jobs: " << jobs_.size(); + } +} + +int32_t LibkinetoConfigManager::registerLibkinetoContext( + const std::string& jobId, + int32_t pid, + int32_t gpu) { + std::lock_guard guard(mutex_); + auto& instances = jobInstancesPerGpu_[jobId][gpu]; + instances.insert(pid); + LOG(INFO) << fmt::format("Registered process ({}) for job {}.", pid, jobId); + return instances.size(); +} + +// Called by libkineto instances periodically. +// In addition to returning a configuration string if one is found, +// register the jobId and set of pids with the config manager. +// This is how we keep track of running instances of libkineto. +// LibkinetoConfigManager::run() periodically scans the table +// for processes no longer calling this function and removes them. +std::string LibkinetoConfigManager::obtainOnDemandConfig( + const std::string& jobId, + const std::vector& pids, + int32_t configType) { + VLOG(VerboseLevel) << fmt::format( + "obtainOnDemandConfig({}, ({}), {})", + jobId, + fmt::join(pids, ","), + configType); + std::string ret; + std::set pids_set(pids.begin(), pids.end()); + std::lock_guard guard(mutex_); + + auto _emplace_result = jobs_[jobId].emplace(pids_set, LibkinetoProcess{}); + const auto& it = _emplace_result.first; + bool newProcess = _emplace_result.second; + struct LibkinetoProcess& process = it->second; + + if (newProcess) { + // First time - intialize! + // 'pids' is an ordered ancestor list starting with the + // child (leaf) process, i.e. the one making this request. + // Keep a copy of this pid so that clients can know which + // pids are being profiled. + process.pid = pids[0]; // Remember child (leaf) process + LOG(INFO) << fmt::format( + "Registered process ({}) for job {}.", fmt::join(pids, ", "), jobId); + + onRegisterProcess(pids_set); + } + if ((configType & int(LibkinetoConfigType::EVENTS)) && + !process.eventProfilerConfig.empty()) { + ret += process.eventProfilerConfig + "\n"; + process.eventProfilerConfig.clear(); + } + + if ((configType & int(LibkinetoConfigType::ACTIVITIES)) && + !process.activityProfilerConfig.empty()) { + ret += process.activityProfilerConfig + "\n"; + process.activityProfilerConfig.clear(); + } + // Track last request time so we know which libkineto instances + // are currently active. + process.lastRequestTime = std::chrono::system_clock::now(); + return ret; +} + +void LibkinetoConfigManager::setOnDemandConfigForProcess( + GpuProfilerResult& res, + LibkinetoProcess& process, + const std::string& config, + int32_t configType /* LibkinetoConfigType */, + int32_t limit) { + res.processesMatched.push_back(process.pid); + + if (res.eventProfilersTriggered.size() < limit && + (configType & int(LibkinetoConfigType::EVENTS))) { + if (process.eventProfilerConfig.empty()) { + process.eventProfilerConfig = config; + res.eventProfilersTriggered.push_back(process.pid); + } else { + res.eventProfilersBusy++; + } + } + if (res.activityProfilersTriggered.size() < limit && + (configType & int(LibkinetoConfigType::ACTIVITIES))) { + if (process.activityProfilerConfig.empty()) { + preCheckOnDemandConfig(process); + + std::string trace_id = generateTraceId(process.pid); + std::string updatedConfig = addTraceIdToConfigString(trace_id, config); + + res.activityProfilersTriggered.push_back(process.pid); + process.activityProfilerConfig = updatedConfig; + res.traceIds.push_back(trace_id); + + LOG(INFO) << " PID: " << process.pid << ", Trace Id: " << trace_id; + } else { + res.activityProfilersBusy++; + } + } +} + +// Called by clients to control one or more libkineto instances. +// The config is any legal libkineto on-demand config (see wiki). +// Set config type to indicate whether this request is for +// event profiling, activity profiling or both. +// The limit argument is used when the job uses multiple processes or +// the pid is a parent pid of multiple processes with libkineto. +// For example, when specifying a pid with 8 child processes, +// the limit argument can be used to profile 2 of those. +GpuProfilerResult LibkinetoConfigManager::setOnDemandConfig( + const std::string& jobId, + const std::set& pids, + const std::string& config, + int32_t configType /* LibkinetoConfigType */, + int32_t limit) { + LOG(INFO) << fmt::format( + + "Initiating on-demand GPU profiling for job ID {}, pids [{}]", + jobId, + fmt::join(pids, ",")); + + GpuProfilerResult res; + res.activityProfilersBusy = 0; + res.eventProfilersBusy = 0; + + size_t nPids = pids.size(); + // For backwards compatibility with older versions of the dyno CLI, + // there are two conditions under which all processes should be traced: + // 1. target PIDs are empty + // 2. target PIDs contain a single PID, 0. + // As older versions of the CLI are phased out, 2) will no longer need to be + // accounted for. + bool traceAllPids = nPids == 0 || (nPids == 1 && *pids.begin() == 0); + { + std::lock_guard guard(mutex_); + if (auto it = jobs_.find(jobId); it != jobs_.end()) { + auto& processes = it->second; + for (auto& pair : processes) { + for (const auto& pid : pair.first) { + // Trace the process if we find a match or target pids is empty. + if (traceAllPids || pids.find(pid) != pids.end()) { + auto& process = pair.second; + setOnDemandConfigForProcess( + res, process, config, configType, limit); + // the user could provide multiple pids that belong to the same the + // LibkientoProcess object, so we break after the first match for + // the LibkinetoProcess. + break; + } + } + } + if (res.activityProfilersTriggered.size() > 0) { + onSetOnDemandConfig(pids); + } + } + } + + LOG(INFO) << "On-demand request: " << res.processesMatched.size() + << " matching processes"; + if (configType & int(LibkinetoConfigType::EVENTS)) { + LOG(INFO) << "Installed event profiler config for " + << res.eventProfilersTriggered.size() << " process(es) " << "(" + << res.eventProfilersBusy << " busy)"; + } + if (configType & int(LibkinetoConfigType::ACTIVITIES)) { + LOG(INFO) << "Installed activity profiler config for " + << res.activityProfilersTriggered.size() << " process(es) " << "(" + << res.activityProfilersBusy << " busy)"; + } + return res; +} + +int LibkinetoConfigManager::processCount(const std::string& jobId) const { + int count = 0; + std::lock_guard guard(mutex_); + auto it = jobs_.find(jobId); + if (it != jobs_.end()) { + count = it->second.size(); + } + LOG(INFO) << "Process count for job ID " << jobId << ": " << count; + return count; +} + +void LibkinetoConfigManager::updateNpuStatus( + const std::string& jobId, + int32_t pid, + int32_t status, + const std::string& msgType) { + // jobId, pid为预留参数,目前无用 + std::lock_guard guard(mutex_); + if (msgType == kLibkinetoTraceStatus) { + npuTraceStatus_ = status; + } else if (msgType == kLibkinetoMonitorStatus) { + npuMonitorStatus_ = status; + } +} + +int32_t LibkinetoConfigManager::getNpuTraceStatus() +{ + std::lock_guard guard(mutex_); + return npuTraceStatus_; +} + +int32_t LibkinetoConfigManager::getNpuMonitorStatus() +{ + std::lock_guard guard(mutex_); + return npuMonitorStatus_; +} + +} // namespace dynolog diff --git a/msmonitor/dynolog_npu/dynolog/src/LibkinetoConfigManager.h b/msmonitor/dynolog_npu/dynolog/src/LibkinetoConfigManager.h new file mode 100644 index 0000000000000000000000000000000000000000..437609ea0492a32624c2202777e48fe821248840 --- /dev/null +++ b/msmonitor/dynolog_npu/dynolog/src/LibkinetoConfigManager.h @@ -0,0 +1,110 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "dynolog/src/LibkinetoTypes.h" + +namespace dynolog { + +const std::string kLibkinetoTraceStatus = "npuTraceStatus"; +const std::string kLibkinetoMonitorStatus = "npuMonitorStatus"; + +class LibkinetoConfigManager { + public: + LibkinetoConfigManager(); + virtual ~LibkinetoConfigManager(); + + int32_t + registerLibkinetoContext(const std::string& jobId, int32_t pid, int32_t gpu); + static std::shared_ptr getInstance(); + + std::string getBaseConfig() { + std::lock_guard guard(mutex_); + return baseConfig_; + } + + std::string obtainOnDemandConfig( + const std::string& jobId, + const std::vector& pids, + int32_t configType); + + GpuProfilerResult setOnDemandConfig( + const std::string& jobId, + const std::set& pids, + const std::string& config, + int32_t configType, + int32_t limit); + + void updateNpuStatus(const std::string& jobId, int32_t pid, int32_t status, const std::string& msgType); + int32_t getNpuTraceStatus(); + int32_t getNpuMonitorStatus(); + + // Return the number of active libkineto processes + // with the given Chronos / Tangram Job Id + int processCount(const std::string& jobId) const; + + protected: + struct LibkinetoProcess { + int32_t pid; + std::chrono::system_clock::time_point lastRequestTime; + std::string eventProfilerConfig; + std::string activityProfilerConfig; + }; + + // A few callbacks for additional instrumentation. + virtual void onRegisterProcess(const std::set& /*pids*/) {} + + virtual void preCheckOnDemandConfig(const LibkinetoProcess& /*process*/) {} + + virtual void onSetOnDemandConfig(const std::set& /*pids*/) {} + + virtual void onProcessCleanup(const std::set& /*pids*/) {} + + // Map of pid ancestry -> LibkinetoProcess + using ProcessMap = std::map, LibkinetoProcess>; + std::map jobs_; + + // Map of gpu id -> pids + using InstancesPerGpuMap = std::map>; + // Job id -> InstancesPerGpu + std::map jobInstancesPerGpu_; + mutable std::mutex mutex_; + + void setOnDemandConfigForProcess( + GpuProfilerResult& res, + LibkinetoProcess& process, + const std::string& config, + int32_t configType, + int32_t limit); + + private: + // Garbage collection and config refresh - periodically clean up + // data from terminated processes. + void start(); + void runGc(); + void refreshBaseConfig(); + + std::string baseConfig_; + std::thread* managerThread_{nullptr}; + std::atomic_bool stopFlag_{false}; + std::condition_variable managerCondVar_; + int32_t npuTraceStatus_ = 0; + int32_t npuMonitorStatus_ = 0; + // mutable std::mutex mutex_; TODO make private again +}; + +} // namespace dynolog diff --git a/dynolog_npu/dynolog_npu/dynolog/src/Main.cpp b/msmonitor/dynolog_npu/dynolog/src/Main.cpp similarity index 57% rename from dynolog_npu/dynolog_npu/dynolog/src/Main.cpp rename to msmonitor/dynolog_npu/dynolog/src/Main.cpp index 8e5177768327e37173d4e7661e334a9400bd6172..89c263947971d38a0c2b9d8d1c5dd4e1cd833104 100644 --- a/dynolog_npu/dynolog_npu/dynolog/src/Main.cpp +++ b/msmonitor/dynolog_npu/dynolog/src/Main.cpp @@ -15,6 +15,7 @@ #include "dynolog/src/KernelCollector.h" #include "dynolog/src/Logger.h" #include "dynolog/src/ODSJsonLogger.h" + #include "dynolog/src/PerfMonitor.h" #include "dynolog/src/ScubaLogger.h" #include "dynolog/src/ServiceHandler.h" @@ -28,6 +29,10 @@ #include "dynolog/src/PrometheusLogger.h" #endif +#ifdef USE_TENSORBOARD +#include "dynolog/src/DynologTensorBoardLogger.h" +#endif + using namespace dynolog; using json = nlohmann::json; namespace hbt = facebook::hbt; @@ -62,39 +67,47 @@ DEFINE_bool( "Enabled GPU monitorng, currently supports NVIDIA GPUs."); DEFINE_bool(enable_perf_monitor, false, "Enable heartbeat perf monitoring."); -std::unique_ptr getLogger(const std::string& scribe_category = "") { - std::vector> loggers; +std::unique_ptr getLogger(const std::string& scribe_category = "") +{ + std::vector> loggers; #ifdef USE_PROMETHEUS - if (FLAGS_use_prometheus) { + if (FLAGS_use_prometheus) { loggers.push_back(std::make_unique()); - } + } #endif - if (FLAGS_use_fbrelay) { +#ifdef USE_TENSORBOARD + if (!FLAGS_metric_log_dir.empty()) { + loggers.push_back(std::make_unique(FLAGS_metric_log_dir)); + } +#endif + if (FLAGS_use_fbrelay) { loggers.push_back(std::make_unique()); - } - if (FLAGS_use_ODS) { + } + if (FLAGS_use_ODS) { loggers.push_back(std::make_unique()); - } - if (FLAGS_use_JSON) { + } + if (FLAGS_use_JSON) { loggers.push_back(std::make_unique()); - } - if (FLAGS_use_scuba && !scribe_category.empty()) { + } + if (FLAGS_use_scuba && !scribe_category.empty()) { loggers.push_back(std::make_unique(scribe_category)); - } - return std::make_unique(std::move(loggers)); + } + return std::make_unique(std::move(loggers)); } -auto next_wakeup(int sec) { - return std::chrono::steady_clock::now() + std::chrono::seconds(sec); +auto next_wakeup(int sec) +{ + return std::chrono::steady_clock::now() + std::chrono::seconds(sec); } -void kernel_monitor_loop() { - KernelCollector kc; +void kernel_monitor_loop() +{ + KernelCollector kc; - LOG(INFO) << "Running kernel monitor loop : interval = " + LOG(INFO) << "Running kernel monitor loop : interval = " << FLAGS_kernel_monitor_reporting_interval_s << " s."; - while (1) { + while (1) { auto logger = getLogger(); auto wakeup_timepoint = next_wakeup(FLAGS_kernel_monitor_reporting_interval_s); @@ -105,20 +118,21 @@ void kernel_monitor_loop() { /* sleep override */ std::this_thread::sleep_until(wakeup_timepoint); - } + } } -void perf_monitor_loop() { - PerfMonitor pm( +void perf_monitor_loop() +{ + PerfMonitor pm( hbt::CpuSet::makeAllOnline(), std::vector{"instructions", "cycles"}, getDefaultPmuDeviceManager(), getDefaultMetrics()); - LOG(INFO) << "Running perf monitor loop : interval = " - << FLAGS_perf_monitor_reporting_interval_s << " s."; + LOG(INFO) << "Running perf monitor loop : interval = " + << FLAGS_perf_monitor_reporting_interval_s << " s."; - while (1) { + while (1) { auto logger = getLogger(); auto wakeup_timepoint = next_wakeup(FLAGS_perf_monitor_reporting_interval_s); @@ -129,22 +143,24 @@ void perf_monitor_loop() { logger->finalize(); /* sleep override */ std::this_thread::sleep_until(wakeup_timepoint); - } + } } -auto setup_server(std::shared_ptr handler) { - return std::make_unique>( - handler, FLAGS_port); +auto setup_server(std::shared_ptr handler) +{ + return std::make_unique>( + handler, FLAGS_port); } -void gpu_monitor_loop(std::shared_ptr dcgm) { - auto logger = getLogger(FLAGS_scribe_category); +void gpu_monitor_loop(std::shared_ptr dcgm) +{ + auto logger = getLogger(FLAGS_scribe_category); - LOG(INFO) << "Running DCGM loop : interval = " - << FLAGS_dcgm_reporting_interval_s << " s."; - LOG(INFO) << "DCGM fields: " << gpumon::FLAGS_dcgm_fields; + LOG(INFO) << "Running DCGM loop : interval = " + << FLAGS_dcgm_reporting_interval_s << " s."; + LOG(INFO) << "DCGM fields: " << gpumon::FLAGS_dcgm_fields; - while (1) { + while (1) { auto wakeup_timepoint = next_wakeup(FLAGS_dcgm_reporting_interval_s); dcgm->update(); @@ -152,55 +168,66 @@ void gpu_monitor_loop(std::shared_ptr dcgm) { /* sleep override */ std::this_thread::sleep_until(wakeup_timepoint); - } + } } -int main(int argc, char** argv) { - gflags::ParseCommandLineFlags(&argc, &argv, true); - FLAGS_logtostderr = 1; - google::InitGoogleLogging(argv[0]); +int main(int argc, char** argv) +{ + gflags::ParseCommandLineFlags(&argc, &argv, true); + FLAGS_logtostderr = 1; + google::InitGoogleLogging(argv[0]); - LOG(INFO) << "Starting Ascend Extension for dynolog, version = " DYNOLOG_VERSION - << ", build git-hash = " DYNOLOG_GIT_REV; + LOG(INFO) << "Starting Ascend Extension for dynolog, version = " DYNOLOG_VERSION + << ", build git-hash = " DYNOLOG_GIT_REV; - std::shared_ptr dcgm; + std::shared_ptr dcgm; - std::unique_ptr ipcmon; - std::unique_ptr ipcmon_thread, gpumon_thread, pm_thread; + std::shared_ptr ipcmon; + std::unique_ptr ipcmon_thread; + std::unique_ptr data_ipcmon_thread; + std::unique_ptr gpumon_thread; + std::unique_ptr pm_thread; - if (FLAGS_enable_ipc_monitor) { + if (FLAGS_enable_ipc_monitor) { LOG(INFO) << "Starting IPC Monitor"; - ipcmon = std::make_unique(); + ipcmon = std::make_shared(); + ipcmon->setLogger(std::move(getLogger())); ipcmon_thread = - std::make_unique([&ipcmon]() { ipcmon->loop(); }); - } + std::make_unique([ipcmon]() { ipcmon->loop(); }); + data_ipcmon_thread = + std::make_unique([ipcmon]() { ipcmon->dataLoop(); }); + } - if (FLAGS_enable_gpu_monitor) { + if (FLAGS_enable_gpu_monitor) { dcgm = gpumon::DcgmGroupInfo::factory( gpumon::FLAGS_dcgm_fields, FLAGS_dcgm_reporting_interval_s * 1000); gpumon_thread = std::make_unique(gpu_monitor_loop, dcgm); - } - std::thread km_thread{kernel_monitor_loop}; - if (FLAGS_enable_perf_monitor) { + } + std::thread km_thread{kernel_monitor_loop}; + if (FLAGS_enable_perf_monitor) { pm_thread = std::make_unique(perf_monitor_loop); - } + } + + // setup service + auto handler = std::make_shared(dcgm); - // setup service - auto handler = std::make_shared(dcgm); + // use simple json RPC server for now + // In the current scenario, the process can only be terminated and all threads closed using Ctrl+C. + auto server = setup_server(handler); + server->run(); - // use simple json RPC server for now - auto server = setup_server(handler); - server->run(); + if (km_thread.joinable()) { + km_thread.join(); + } - km_thread.join(); - if (pm_thread) { + if (pm_thread && pm_thread->joinable()) { pm_thread->join(); - } - if (gpumon_thread) { + } + if (gpumon_thread && gpumon_thread->joinable()) { gpumon_thread->join(); - } + } - server->stop(); + server->stop(); - return 0; -} \ No newline at end of file + return 0; +} diff --git a/msmonitor/dynolog_npu/dynolog/src/Metrics.cpp b/msmonitor/dynolog_npu/dynolog/src/Metrics.cpp new file mode 100644 index 0000000000000000000000000000000000000000..959bb1778ced1cd55ba1070c72b7f39a13cbe85f --- /dev/null +++ b/msmonitor/dynolog_npu/dynolog/src/Metrics.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include "dynolog/src/Metrics.h" + +#include +#include + +namespace dynolog { + +const std::vector getAllMetrics() +{ + static std::vector metrics_ = { + {.name = "kindName", + .type = MetricType::Instant, + .desc = "Report data kind name"}, + {.name = "duration", + .type = MetricType::Delta, + .desc = "Total execution time for corresponding kind"}, + {.name = "timestamp", + .type = MetricType::Instant, + .desc = "The timestamp of the reported data"}, + {.name = "deviceId", + .type = MetricType::Instant, + .desc = "The ID of the device for reporting data"}, + }; + return metrics_; +} + +// These metrics are dynamic per network drive +const std::vector getNetworkMetrics() +{ + static std::vector metrics_ = {}; + return metrics_; +} + +} // namespace dynolog \ No newline at end of file diff --git a/msmonitor/dynolog_npu/dynolog/src/MsMonitorMetrics.h b/msmonitor/dynolog_npu/dynolog/src/MsMonitorMetrics.h new file mode 100644 index 0000000000000000000000000000000000000000..e5c4b9c677f3f22e90aa8c11feb32f6402abd4aa --- /dev/null +++ b/msmonitor/dynolog_npu/dynolog/src/MsMonitorMetrics.h @@ -0,0 +1,48 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DYNOLOG_NPU_MSMONITOR_METRICS_H +#define DYNOLOG_NPU_MSMONITOR_METRICS_H + +#include +#include +#include + +namespace dynolog { + +const std::unordered_set validDynamicMetrics_ { + {"deviceId", "kind", "domain"} +}; + +const std::unordered_map> validMetrics_ { + {"Marker", {"duration"}}, + {"Kernel", {"duration"}}, + {"API", {"duration"}}, + {"Hccl", {"duration"}}, + {"Memory", {"duration"}}, + {"MemSet", {"duration"}}, + {"MemCpy", {"duration"}} +}; + +struct MsptiMetricDesc { + std::string device_id_; + std::string kind_; + std::string tag_; + std::string metric_name_; + double val_; +}; +} // namespace dynolog + +#endif \ No newline at end of file diff --git a/msmonitor/dynolog_npu/dynolog/src/rpc/CMakeLists.txt b/msmonitor/dynolog_npu/dynolog/src/rpc/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..754f55fd4ab3197718baad5d94a63db7ea3b2b0d --- /dev/null +++ b/msmonitor/dynolog_npu/dynolog/src/rpc/CMakeLists.txt @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +set(CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake") +find_package(openssl REQUIRED) + +add_library(dynolog_rpc_lib STATIC + SimpleJsonServer.cpp SimpleJsonServer.h + ${CMAKE_CURRENT_SOURCE_DIR}/../ServiceHandler.h +) + +target_include_directories(dynolog_rpc_lib + INTERFACE ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_include_directories(dynolog_rpc_lib + PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/.. +) +target_link_libraries(dynolog_rpc_lib PRIVATE dynolog_lib) +target_link_libraries(dynolog_rpc_lib PUBLIC gflags::gflags) +target_link_libraries(dynolog_rpc_lib PUBLIC glog::glog) +target_link_libraries(dynolog_rpc_lib PUBLIC nlohmann_json::nlohmann_json) +target_link_libraries(dynolog_rpc_lib PUBLIC fmt::fmt) +target_link_libraries(dynolog_rpc_lib PRIVATE ${openssl_LIBRARIES}) diff --git a/msmonitor/dynolog_npu/dynolog/src/rpc/SimpleJsonServer.cpp b/msmonitor/dynolog_npu/dynolog/src/rpc/SimpleJsonServer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..78117789e7373357434f95ae4d6927f1f3ac6b1e --- /dev/null +++ b/msmonitor/dynolog_npu/dynolog/src/rpc/SimpleJsonServer.cpp @@ -0,0 +1,792 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include "dynolog/src/rpc/SimpleJsonServer.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +DEFINE_string(certs_dir, "", "TLS crets dir"); + +constexpr int CLIENT_QUEUE_LEN = 50; +const std::string NO_CERTS_MODE = "NO_CERTS"; +const size_t MIN_RSA_KEY_LENGTH = 3072; +constexpr char BACKSPACE_ASCII = 8; +constexpr char DEL_ASCII = 127; + +namespace dynolog { + +SimpleJsonServerBase::SimpleJsonServerBase(int port) : port_(port) +{ + try { + initSocket(); + if (FLAGS_certs_dir != NO_CERTS_MODE) { + init_openssl(); + ctx_ = create_context(); + configure_context(ctx_); + } + } catch (const std::exception& e) { + LOG(ERROR) << "Failed to initialize server: " << e.what(); + if (sock_fd_ != -1) { + close(sock_fd_); + sock_fd_ = -1; + } + throw; + } +} + +SimpleJsonServerBase::~SimpleJsonServerBase() +{ + if (thread_) { + stop(); + } + close(sock_fd_); + if (FLAGS_certs_dir != NO_CERTS_MODE && ctx_) { + SSL_CTX_free(ctx_); + } +} + +void SimpleJsonServerBase::initSocket() +{ + struct sockaddr_in6 server_addr; + + /* Create socket for listening (client requests). */ + sock_fd_ = ::socket(AF_INET6, SOCK_STREAM, 0); + if (sock_fd_ == -1) { + int error_code = errno; + std::perror("socket()"); + throw std::runtime_error("socket() failed, error is " + std::string(std::strerror(error_code))); + } + + /* Set socket to reuse address in case server is restarted. */ + int flag = 1; + int ret = + ::setsockopt(sock_fd_, SOL_SOCKET, SO_REUSEADDR, &flag, sizeof(flag)); + if (ret == -1) { + int error_code = errno; + std::perror("setsockopt()"); + throw std::runtime_error("setsockopt() failed, error is " + std::string(std::strerror(error_code))); + } + + // in6addr_any allows us to bind to both IPv4 and IPv6 clients. + server_addr.sin6_addr = in6addr_any; + server_addr.sin6_family = AF_INET6; + server_addr.sin6_port = htons(port_); + + /* Bind address and socket together */ + ret = ::bind(sock_fd_, (struct sockaddr*)&server_addr, sizeof(server_addr)); + if (ret == -1) { + int error_code = errno; + std::perror("bind()"); + close(sock_fd_); + throw std::runtime_error("bind() failed, error is " + std::string(std::strerror(error_code))); + } + + /* Create listening queue (client requests) */ + ret = ::listen(sock_fd_, CLIENT_QUEUE_LEN); + if (ret == -1) { + int error_code = errno; + std::perror("listen()"); + close(sock_fd_); + throw std::runtime_error("listen() failed, error is " + std::string(std::strerror(error_code))); + } + + /* Get port if assigned 0 */ + if (port_ == 0) { + socklen_t len_out = sizeof(server_addr); + ret = ::getsockname(sock_fd_, (struct sockaddr*)&server_addr, &len_out); + if (ret < 0 || len_out != sizeof(server_addr)) { + std::perror("getsockname()"); + } else { + port_ = ntohs(server_addr.sin6_port); + LOG(INFO) << "System assigned port = " << ntohs(server_addr.sin6_port); + } + } + + LOG(INFO) << "Listening to connections on port " << port_; + initSuccess_ = true; +} + +/* A simple wrapper to accept connections and read data + * + * Messages are prefixed using the length so we know how long a message + * to actually read. + * : int32_t len + * : char json[] + */ +class ClientSocketWrapper { +public: + ~ClientSocketWrapper() + { + if (FLAGS_certs_dir != NO_CERTS_MODE && ssl_) { + int shutdown_ret = SSL_shutdown(ssl_); + if (shutdown_ret <= 0) { + LOG(ERROR) << "SSL_shutdown failed, error code: " << shutdown_ret; + shutdown_ret = SSL_shutdown(ssl_); + } + SSL_free(ssl_); + } + if (client_sock_fd_ != -1) { + ::close(client_sock_fd_); + } + } + + bool accept(int server_socket, SSL_CTX* ctx) + { + struct sockaddr_in6 client_addr; + socklen_t client_addr_len = sizeof(client_addr); + std::array client_addr_str; + + client_sock_fd_ = ::accept( + server_socket, (struct sockaddr*)&client_addr, &client_addr_len); + if (client_sock_fd_ == -1) { + std::perror("accept()"); + return false; + } + + inet_ntop( + AF_INET6, + &(client_addr.sin6_addr), + client_addr_str.data(), + client_addr_str.size()); + LOG(INFO) << "Received connection from " << client_addr_str.data(); + + if (FLAGS_certs_dir == NO_CERTS_MODE) { + LOG(INFO) << "No certs mode"; + return true; + } + + ssl_ = SSL_new(ctx); + SSL_set_fd(ssl_, client_sock_fd_); + if (SSL_accept(ssl_) <= 0) { + ERR_print_errors_fp(stderr); + return false; + } + LOG(INFO) << "SSL handshake success"; + return true; + } + + std::string get_message() + { + int32_t msg_size = -1; + if (!read_helper((uint8_t*)&msg_size, sizeof(msg_size)) || msg_size <= 0) { + LOG(ERROR) << "Invalid message size = " << msg_size; + return ""; + } + + std::string message; + message.resize(msg_size); + int recv = 0; + int ret = 1; + while (recv < msg_size && ret > 0) { + ret = read_helper((uint8_t*)&message[recv], msg_size - recv); + recv += ret > 0 ? ret : 0; + } + + if (recv != msg_size) { + LOG(ERROR) << "Received partial message, expected size " << msg_size + << " found : " << recv; + LOG(ERROR) << "Message received = " << message; + return ""; + } + + return message; + } + + bool send_response(const std::string& response) + { + int32_t size = response.size(); + int ret; + if (FLAGS_certs_dir == NO_CERTS_MODE) { + ret = ::write(client_sock_fd_, (void*)&size, sizeof(size)); + if (ret == -1) { + std::perror("write()"); + return false; + } + } else { + ret = SSL_write(ssl_, (void*)&size, sizeof(size)); + if (ret <= 0) { + ERR_print_errors_fp(stderr); + return false; + } + } + int sent = 0; + while (sent < size && ret > 0) { + if (FLAGS_certs_dir == NO_CERTS_MODE) { + ret = ::write(client_sock_fd_, (void*)&response[sent], size - sent); + if (ret == -1) { + std::perror("write()"); + } else { + sent += ret; + } + } else { + ret = SSL_write(ssl_, (void*)&response[sent], size - sent); + if (ret <= 0) { + ERR_print_errors_fp(stderr); + } else { + sent += ret; + } + } + } + + if (sent < response.size()) { + LOG(ERROR) << "Unable to write full response"; + return false; + } + return ret > 0; + } + +private: + int read_helper(uint8_t* buf, int size) + { + if (FLAGS_certs_dir == NO_CERTS_MODE) { + int ret = ::read(client_sock_fd_, (void*)buf, size); + if (ret == -1) { + std::perror("read()"); + } + return ret; + } + int ret = SSL_read(ssl_, (void*)buf, size); + if (ret <= 0) { + ERR_print_errors_fp(stderr); + } + return ret; + } + int client_sock_fd_ = -1; + SSL* ssl_ = nullptr; +}; + +/* Accepts socket connections and processes the payloads. + * This will inturn call the Handler functions */ +void SimpleJsonServerBase::loop() noexcept +{ + if (sock_fd_ == -1 || !initSuccess_) { + return; + } + + while (run_) { + processOne(); + } +} + +void SimpleJsonServerBase::processOne() noexcept +{ + LOG(INFO) << "Waiting for connection."; + ClientSocketWrapper client; + if (!client.accept(sock_fd_, ctx_)) { + return; + } + std::string request_str = client.get_message(); + LOG(INFO) << "RPC message received = " << request_str; + auto response_str = processOneImpl(request_str); + if (response_str.empty()) { + return; + } + if (!client.send_response(response_str)) { + LOG(ERROR) << "Failed to send response"; + } +} + +void SimpleJsonServerBase::run() +{ + LOG(INFO) << "Launching RPC thread"; + thread_ = std::make_unique([this]() { this->loop(); }); +} + +void SimpleJsonServerBase::init_openssl() +{ + #if OPENSSL_VERSION_NUMBER >= 0x10100000L + // OpenSSL 1.1.0+ (包括3.0+) + OPENSSL_init_ssl(OPENSSL_INIT_LOAD_SSL_STRINGS | + OPENSSL_INIT_LOAD_CRYPTO_STRINGS, NULL); + #else + // OpenSSL 1.0.x 及更早版本 + SSL_load_error_strings(); + OpenSSL_add_ssl_algorithms(); + #endif +} + +SSL_CTX* SimpleJsonServerBase::create_context() +{ + const SSL_METHOD* method = TLS_server_method(); + SSL_CTX* ctx = SSL_CTX_new(method); + if (!ctx) { + ERR_print_errors_fp(stderr); + throw std::runtime_error("Unable to create SSL context"); + } + return ctx; +} + +static bool is_cert_revoked(X509* cert, X509_STORE* store) +{ + if (!cert || !store) { + LOG(ERROR) << "Invalid certificate or store pointer"; + return false; + } + // 获取证书的颁发者名称 + X509_NAME* issuer = X509_get_issuer_name(cert); + if (!issuer) { + LOG(ERROR) << "Failed to get certificate issuer"; + return false; + } + // 获取证书的序列号 + const ASN1_INTEGER* serial = X509_get_serialNumber(cert); + if (!serial) { + LOG(ERROR) << "Failed to get certificate serial number"; + return false; + } + // 创建证书验证上下文 + X509_STORE_CTX* ctx = X509_STORE_CTX_new(); + if (!ctx) { + LOG(ERROR) << "Failed to create certificate store context"; + return false; + } + bool is_revoked = false; + STACK_OF(X509_CRL)* crls = nullptr; + try { + // 初始化证书验证上下文 + if (!X509_STORE_CTX_init(ctx, store, cert, nullptr)) { + LOG(ERROR) << "Failed to initialize certificate store context"; + X509_STORE_CTX_free(ctx); + return false; + } + // 获取CRL列表 + crls = X509_STORE_CTX_get1_crls(ctx, issuer); + if (!crls) { + LOG(INFO) << "No CRLs found for issuer"; + X509_STORE_CTX_free(ctx); + return false; + } + time_t current_time = time(nullptr); + for (int i = 0; i < sk_X509_CRL_num(crls); i++) { + X509_CRL* crl = sk_X509_CRL_value(crls, i); + if (!crl) { + LOG(ERROR) << "Invalid CRL at index " << i; + continue; + } + // 检查 CRL 的有效期 + const ASN1_TIME* crl_this_update = X509_CRL_get0_lastUpdate(crl); + const ASN1_TIME* crl_next_update = X509_CRL_get0_nextUpdate(crl); + if (!crl_this_update) { + LOG(ERROR) << "Failed to get CRL this_update time"; + continue; + } + // 检查 CRL 是否已生效 + if (X509_cmp_time(crl_this_update, ¤t_time) > 0) { + LOG(INFO) << "CRL is not yet valid"; + continue; + } + // 检查 CRL 是否过期 + if (crl_next_update) { + if (X509_cmp_time(crl_next_update, ¤t_time) < 0) { + LOG(INFO) << "CRL has expired"; + continue; + } + } + // 检查证书是否在 CRL 中 + STACK_OF(X509_REVOKED)* revoked = X509_CRL_get_REVOKED(crl); + if (revoked) { + for (int j = 0; j < sk_X509_REVOKED_num(revoked); j++) { + X509_REVOKED* rev = sk_X509_REVOKED_value(revoked, j); + if (rev) { + const ASN1_INTEGER* rev_serial = X509_REVOKED_get0_serialNumber(rev); + if (rev_serial && ASN1_INTEGER_cmp(serial, rev_serial) == 0) { + LOG(INFO) << "Certificate is found in CRL"; + is_revoked = true; + break; + } + } + } + } + if (is_revoked) { + break; + } + } + } catch (const std::exception& e) { + LOG(ERROR) << "Exception while checking CRL: " << e.what(); + is_revoked = false; + } + if (crls) { + sk_X509_CRL_pop_free(crls, X509_CRL_free); + } + X509_STORE_CTX_free(ctx); + return is_revoked; +} + +// 禁用终端回显的函数,但显示星号 +int get_password_with_stars(char* buf, size_t bufsize) +{ + char* secure_buf = static_cast(CRYPTO_secure_malloc(bufsize, __FILE__, __LINE__)); + if (!secure_buf) { + return -1; /* 内存申请失败 */ + } + struct termios old_flags; + struct termios new_flags; + size_t idx = 0; + if (tcgetattr(fileno(stdin), &old_flags) != 0) { + CRYPTO_secure_free(secure_buf, __FILE__, __LINE__); + return -1; + } + new_flags = old_flags; + new_flags.c_lflag &= ~ECHO; + tcsetattr(fileno(stdin), TCSANOW, &new_flags); + + int ch; + while ((ch = getchar()) != '\n' && ch != EOF && idx + 1 < bufsize) { + if (ch == DEL_ASCII || ch == BACKSPACE_ASCII) { + if (idx > 0) { + idx--; + printf("\b \b"); + fflush(stdout); + } + } else { + secure_buf[idx++] = static_cast (ch); + printf("*"); + fflush(stdout); + } + } + secure_buf[idx] = '\0'; + std::copy_n(secure_buf, idx + 1, buf); + tcsetattr(fileno(stdin), TCSANOW, &old_flags); + OPENSSL_cleanse(secure_buf, bufsize); + CRYPTO_secure_free(secure_buf, __FILE__, __LINE__); + return idx; +} + +// 验证证书版本和签名算法 +void SimpleJsonServerBase::verify_certificate_version_and_algorithm(X509* cert) +{ + // 1. 检查证书版本是否为 X.509v3 + if (X509_get_version(cert) != 2) { // 2 表示 X.509v3 + throw std::runtime_error("Certificate is not X.509v3"); + } + + // 2. 检查证书签名算法 + const X509_ALGOR* sig_alg = X509_get0_tbs_sigalg(cert); + if (!sig_alg) { + throw std::runtime_error("Failed to get signature algorithm"); + } + + int sig_nid = OBJ_obj2nid(sig_alg->algorithm); + // 检查是否使用不安全的算法 + if (sig_nid == NID_md2WithRSAEncryption || + sig_nid == NID_md5WithRSAEncryption || + sig_nid == NID_sha1WithRSAEncryption) { + throw std::runtime_error("Certificate uses insecure signature algorithm: " + std::to_string(sig_nid)); + } +} + +// 验证 RSA 密钥长度 +void SimpleJsonServerBase::verify_rsa_key_length(EVP_PKEY* pkey) +{ + if (EVP_PKEY_base_id(pkey) == EVP_PKEY_RSA) { + size_t key_length = 0; +#if OPENSSL_VERSION_NUMBER >= 0x30000000L + // OpenSSL 3.0 及以上版本 + key_length = EVP_PKEY_get_size(pkey) * 8; // 转换为位数 +#else + // OpenSSL 1.1.1 及以下版本 + RSA* rsa = EVP_PKEY_get0_RSA(pkey); + if (!rsa) { + throw std::runtime_error("Failed to get RSA key"); + } + + const BIGNUM* n = nullptr; + RSA_get0_key(rsa, &n, nullptr, nullptr); + if (!n) { + throw std::runtime_error("Failed to get RSA modulus"); + } + + key_length = BN_num_bits(n); +#endif + if (key_length < MIN_RSA_KEY_LENGTH) { + throw std::runtime_error("RSA key length " + std::to_string(key_length) + " bits is less than required " + std::to_string(MIN_RSA_KEY_LENGTH) + " bits"); + } + } +} + +// 验证证书有效期 +void SimpleJsonServerBase::verify_certificate_validity(X509* cert) +{ + ASN1_TIME* not_before = X509_get_notBefore(cert); + ASN1_TIME* not_after = X509_get_notAfter(cert); + if (!not_before || !not_after) { + throw std::runtime_error("Failed to get certificate validity period"); + } + + time_t current_time = time(nullptr); + struct tm tm_before = {}; + struct tm tm_after = {}; + if (!ASN1_TIME_to_tm(not_before, &tm_before) || + !ASN1_TIME_to_tm(not_after, &tm_after)) { + throw std::runtime_error("Failed to convert certificate dates"); + } + + time_t not_before_time = mktime(&tm_before); + time_t not_after_time = mktime(&tm_after); + + // 检查证书是否已生效 + if (current_time < not_before_time) { + BIO* bio = BIO_new(BIO_s_mem()); + if (bio) { + ASN1_TIME_print(bio, not_before); + char* not_before_str = nullptr; + long len = BIO_get_mem_data(bio, ¬_before_str); + if (len > 0) { + std::string time_str(not_before_str, len); + BIO_free(bio); + throw std::runtime_error("Server certificate is not yet valid. Valid from: " + time_str); + } + BIO_free(bio); + } + throw std::runtime_error("Server certificate is not yet valid"); + } + + // 检查证书是否已过期 + if (current_time > not_after_time) { + BIO* bio = BIO_new(BIO_s_mem()); + if (bio) { + ASN1_TIME_print(bio, not_after); + char* not_after_str = nullptr; + long len = BIO_get_mem_data(bio, ¬_after_str); + if (len > 0) { + std::string time_str(not_after_str, len); + BIO_free(bio); + throw std::runtime_error("Server certificate has expired. Expired at: " + time_str); + } + BIO_free(bio); + } + throw std::runtime_error("Server certificate has expired"); + } +} + +// 验证证书扩展域 +void SimpleJsonServerBase::verify_certificate_extensions(X509* cert) +{ + bool has_ca_constraint = false; + bool has_key_usage = false; + bool has_cert_sign = false; + bool has_crl_sign = false; + + const STACK_OF(X509_EXTENSION)* exts = X509_get0_extensions(cert); + if (exts) { + for (int i = 0; i < sk_X509_EXTENSION_num(exts); i++) { + X509_EXTENSION* ext = sk_X509_EXTENSION_value(exts, i); + ASN1_OBJECT* obj = X509_EXTENSION_get_object(ext); + + if (OBJ_obj2nid(obj) == NID_basic_constraints) { + BASIC_CONSTRAINTS* constraints = (BASIC_CONSTRAINTS*)X509V3_EXT_d2i(ext); + if (constraints) { + has_ca_constraint = constraints->ca; + BASIC_CONSTRAINTS_free(constraints); + } + } else if (OBJ_obj2nid(obj) == NID_key_usage) { + ASN1_BIT_STRING* usage = (ASN1_BIT_STRING*)X509V3_EXT_d2i(ext); + if (usage) { + has_key_usage = true; + if (usage->data) { + has_cert_sign = (usage->data[0] & KU_KEY_CERT_SIGN) != 0; + has_crl_sign = (usage->data[0] & KU_CRL_SIGN) != 0; + } else { + /* 位串为空,视为两项均未置位 */ + has_cert_sign = false; + has_crl_sign = false; + } + ASN1_BIT_STRING_free(usage); + } + } + } + } + + if (has_ca_constraint) { + throw std::runtime_error("Client certificate should not have CA constraint"); + } + if (!has_key_usage) { + throw std::runtime_error("Client certificate must have key usage extension"); + } +} + +// 加载私钥 +void SimpleJsonServerBase::load_private_key(SSL_CTX* ctx, const std::string& server_key) +{ + FILE* key_file = fopen(server_key.c_str(), "r"); + if (!key_file) { + throw std::runtime_error("Failed to open server key file"); + } + + bool is_encrypted = false; + char line[256]; + while (fgets(line, sizeof(line), key_file)) { + if (strstr(line, "ENCRYPTED")) { + is_encrypted = true; + break; + } + } + rewind(key_file); + + if (is_encrypted) { + char password[256] = {0}; + std::cout << "Please enter the certificate password: "; + get_password_with_stars(password, sizeof(password)); + std::cout << std::endl; + + EVP_PKEY* pkey = PEM_read_PrivateKey( + key_file, + nullptr, + [](char* buf, int size, int rwflag, void* userdata) -> int { + const char* password = static_cast(userdata); + int pwlen = strlen(password); + if (pwlen > size) return 0; + std::copy(password, password + pwlen, buf); + return pwlen; + }, + password); + + fclose(key_file); + // 直接清空 char[] 密码 + std::fill(std::begin(password), std::end(password), 0); + + if (!pkey) { + ERR_print_errors_fp(stderr); + throw std::runtime_error("Failed to load encrypted server private key"); + } + + if (SSL_CTX_use_PrivateKey(ctx, pkey) <= 0) { + EVP_PKEY_free(pkey); + ERR_print_errors_fp(stderr); + throw std::runtime_error("Failed to use server private key"); + } + + EVP_PKEY_free(pkey); + } else { + fclose(key_file); + if (SSL_CTX_use_PrivateKey_file(ctx, server_key.c_str(), SSL_FILETYPE_PEM) <= 0) { + ERR_print_errors_fp(stderr); + throw std::runtime_error("Failed to load server private key"); + } + } +} + +// 加载和验证 CRL +void SimpleJsonServerBase::load_and_verify_crl(SSL_CTX* ctx, const std::string& crl_file) +{ + X509_STORE* store = SSL_CTX_get_cert_store(ctx); + if (!store) { + throw std::runtime_error("Failed to get certificate store"); + } + + if (access(crl_file.c_str(), F_OK) != -1) { + FILE* crl_file_ptr = fopen(crl_file.c_str(), "r"); + if (!crl_file_ptr) { + LOG(WARNING) << "Failed to open CRL file: " << crl_file; + return; + } + + X509_CRL* crl = PEM_read_X509_CRL(crl_file_ptr, nullptr, nullptr, nullptr); + fclose(crl_file_ptr); + + if (!crl) { + LOG(WARNING) << "Failed to read CRL from file: " << crl_file; + return; + } + + if (X509_STORE_add_crl(store, crl) != 1) { + LOG(WARNING) << "Failed to add CRL to certificate store"; + X509_CRL_free(crl); + return; + } + + X509* cert = SSL_CTX_get0_certificate(ctx); + if (!cert) { + X509_CRL_free(crl); + throw std::runtime_error("Failed to get server certificate"); + } + + if (is_cert_revoked(cert, store)) { + X509_CRL_free(crl); + throw std::runtime_error("Server certificate is revoked"); + } + + X509_CRL_free(crl); + } +} + +void SimpleJsonServerBase::configure_context(SSL_CTX* ctx) +{ + if (FLAGS_certs_dir.empty()) { + throw std::runtime_error("--certs-dir must be specified!"); + } + + std::string certs_dir = FLAGS_certs_dir; + if (!certs_dir.empty() && certs_dir.back() != '/') + certs_dir += '/'; + + std::string server_cert = certs_dir + "server.crt"; + std::string server_key = certs_dir + "server.key"; + std::string ca_cert = certs_dir + "ca.crt"; + std::string crl_file = certs_dir + "ca.crl"; + + LOG(INFO) << "Loading server cert: " << server_cert; + LOG(INFO) << "Loading server key: " << server_key; + LOG(INFO) << "Loading CA cert: " << ca_cert; + + // 1. 加载并验证服务器证书 + if (SSL_CTX_use_certificate_file(ctx, server_cert.c_str(), SSL_FILETYPE_PEM) <= 0) { + ERR_print_errors_fp(stderr); + throw std::runtime_error("Failed to load server certificate"); + } + + X509* cert = SSL_CTX_get0_certificate(ctx); + if (!cert) { + throw std::runtime_error("Failed to get server certificate"); + } + + // 2. 验证证书版本和签名算法 + verify_certificate_version_and_algorithm(cert); + + // 3. 验证 RSA 密钥长度 + EVP_PKEY* pkey = X509_get_pubkey(cert); + if (!pkey) { + throw std::runtime_error("Failed to get public key"); + } + verify_rsa_key_length(pkey); + EVP_PKEY_free(pkey); + + // 4. 验证证书有效期 + verify_certificate_validity(cert); + + // 5. 验证证书扩展域 + verify_certificate_extensions(cert); + + // 6. 加载私钥 + load_private_key(ctx, server_key); + + // 7. 加载 CA 证书 + if (SSL_CTX_load_verify_locations(ctx, ca_cert.c_str(), NULL) <= 0) { + ERR_print_errors_fp(stderr); + throw std::runtime_error("Failed to load CA certificate"); + } + + // 8. 加载和验证 CRL + load_and_verify_crl(ctx, crl_file); + + // 9. 设置证书验证选项 + SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, NULL); +} + +} // namespace dynolog \ No newline at end of file diff --git a/msmonitor/dynolog_npu/dynolog/src/rpc/SimpleJsonServer.h b/msmonitor/dynolog_npu/dynolog/src/rpc/SimpleJsonServer.h new file mode 100644 index 0000000000000000000000000000000000000000..9efe0d086ca8ff606db972a1da94ddcf7eb4734a --- /dev/null +++ b/msmonitor/dynolog_npu/dynolog/src/rpc/SimpleJsonServer.h @@ -0,0 +1,80 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include "dynolog/src/ServiceHandler.h" + +DECLARE_string(certs_dir); + +namespace dynolog { +// This is a simple service built using UNIX Sockets +// with remote procedure calls implemented via JSON string. +class SimpleJsonServerBase { +public: + explicit SimpleJsonServerBase(int port); + virtual ~SimpleJsonServerBase(); + + int getPort() const + { + return port_; + } + + bool initSuccessful() const + { + return initSuccess_; + } + // spin up a new thread to process requets + void run(); + + void stop() + { + run_ = 0; + thread_->join(); + } + + // synchronously processes a request + void processOne() noexcept; + +protected: + void initSocket(); + void init_openssl(); + SSL_CTX *create_context(); + void configure_context(SSL_CTX *ctx); + + // process requests in a loop + void loop() noexcept; + + // implement processing of request using the handler + virtual std::string processOneImpl(const std::string &request_str) + { + return ""; + } + + void verify_certificate_version_and_algorithm(X509 *cert); + void verify_rsa_key_length(EVP_PKEY *pkey); + void verify_certificate_validity(X509 *cert); + void verify_certificate_extensions(X509 *cert); + void load_private_key(SSL_CTX *ctx, const std::string &server_key); + void load_and_verify_crl(SSL_CTX *ctx, const std::string &crl_file); + + int port_; + int sock_fd_{-1}; + bool initSuccess_{false}; + + std::atomic run_{true}; + std::unique_ptr thread_; + + SSL_CTX *ctx_{nullptr}; +}; + +} // namespace dynolog \ No newline at end of file diff --git a/msmonitor/dynolog_npu/dynolog/src/rpc/SimpleJsonServerInl.h b/msmonitor/dynolog_npu/dynolog/src/rpc/SimpleJsonServerInl.h new file mode 100644 index 0000000000000000000000000000000000000000..b51d973e4ff3f7d3f4434f79bac186640f8f0e79 --- /dev/null +++ b/msmonitor/dynolog_npu/dynolog/src/rpc/SimpleJsonServerInl.h @@ -0,0 +1,150 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include +#include "dynolog/src/rpc/SimpleJsonServer.h" + +namespace dynolog { + +template +class SimpleJsonServer : public SimpleJsonServerBase { + public: + explicit SimpleJsonServer(std::shared_ptr handler, int port) + : SimpleJsonServerBase(port), handler_(std::move(handler)) {} + + ~SimpleJsonServer() {} + + std::string processOneImpl(const std::string& request) override; + + private: + std::shared_ptr handler_; +}; + +// convert to json and validate the request message +// the request should contain : +// { "fn" : "" +// .. +// } + +nlohmann::json toJson(const std::string& message) { + using json = nlohmann::json; + json result; + if (message.empty()) { + return result; + } + try { + result = json::parse(message); + } catch (json::parse_error&) { + LOG(ERROR) << "Error parsing message = " << message; + return result; + } + + if (result.empty() || !result.is_object()) { + LOG(ERROR) + << "Request message should not be empty and should be json object."; + return json(); + } + + if (!result.contains("fn")) { + LOG(ERROR) << "Request must contain a 'fn' field for the RPC call " + << " request json = " << result.dump(); + return json(); + } + + return result; +} + +std::string GetCommandStatus(const std::string& configStr) +{ + auto npuTraceStatus = LibkinetoConfigManager::getInstance()->getNpuTraceStatus(); + auto npuMonitorStatus = LibkinetoConfigManager::getInstance()->getNpuMonitorStatus(); + std::string prefix = "NPU_MONITOR_START"; + if (configStr.compare(0, prefix.size(), prefix) == 0) { + if (npuTraceStatus == 1) { + return "ineffective"; + } + else if (npuTraceStatus == 0) { + return "effective"; + } + else { + return "unknown"; + } + } else { + if (npuMonitorStatus == 1) { + return "ineffective"; + } + else if (npuMonitorStatus == 0) { + return "effective"; + } + else { + return "unknown"; + } + } +} + +template +std::string SimpleJsonServer::processOneImpl( + const std::string& request_str) { + using json = nlohmann::json; + json request = toJson(request_str); + json response; + + if (request.empty()) { + LOG(ERROR) << "Failed parsing request, continuing ..."; + return ""; + } + + if (request["fn"] == "getStatus") { + response["status"] = handler_->getStatus(); + } else if (request["fn"] == "getVersion") { + response["version"] = handler_->getVersion(); + } else if (request["fn"] == "setKinetOnDemandRequest") { + if (!request.contains("config") || !request.contains("pids")) { + response["status"] = "failed"; + } else { + try { + std::string config = request.value("config", ""); + std::vector pids = request.at("pids").get>(); + std::set pids_set{pids.begin(), pids.end()}; // TODO directly convert? + + int job_id = request.value("job_id", 0); + int process_limit = request.value("process_limit", 1000); + auto result = handler_->setKinetOnDemandRequest(job_id, pids_set, config, process_limit); + auto commandStatus = GetCommandStatus(config); + response["commandStatus"] = commandStatus; + response["processesMatched"] = result.processesMatched; + response["eventProfilersTriggered"] = result.eventProfilersTriggered; + response["activityProfilersTriggered"] = result.activityProfilersTriggered; + response["eventProfilersBusy"] = result.eventProfilersBusy; + response["activityProfilersBusy"] = result.activityProfilersBusy; + } catch (const std::exception& ex) { + LOG(ERROR) << "setKinetOnDemandRequest: parsing exception = " << ex.what(); + response["status"] = fmt::format("failed with exception = {}", ex.what()); + } + } + } else if (request["fn"] == "dcgmProfPause") { + if (!request.contains("duration_s")) { + response["status"] = "failed"; + } else { + int duration_s = request.value("duration_s", 300); + bool result = handler_->dcgmProfPause(duration_s); + response["status"] = result; + } + } else if (request["fn"] == "dcgmProfResume") { + bool result = handler_->dcgmProfResume(); + response["status"] = result; + } else { + LOG(ERROR) << "Unknown RPC call = " << request["fn"]; + return ""; + } + + return response.dump(); +} + +} // namespace dynolog diff --git a/msmonitor/dynolog_npu/dynolog/src/tracing/CMakeLists.txt b/msmonitor/dynolog_npu/dynolog/src/tracing/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..4afd436bcc378db13f6b925fbd319c7b381a5f2b --- /dev/null +++ b/msmonitor/dynolog_npu/dynolog/src/tracing/CMakeLists.txt @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +add_library (dynolog_ipcmonitor_lib IPCMonitor.cpp IPCMonitor.h + ${CMAKE_CURRENT_SOURCE_DIR}/../LibkinetoConfigManager.h +) + +target_include_directories(dynolog_ipcmonitor_lib + INTERFACE ${CMAKE_CURRENT_SOURCE_DIR} +) +target_include_directories(dynolog_ipcmonitor_lib + PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/.. +) + +target_link_libraries(dynolog_ipcmonitor_lib PUBLIC glog::glog) +target_link_libraries(dynolog_ipcmonitor_lib PUBLIC dynolog_ipcfabric_lib) +target_link_libraries(dynolog_ipcmonitor_lib PUBLIC nlohmann_json::nlohmann_json) diff --git a/msmonitor/dynolog_npu/dynolog/src/tracing/IPCMonitor.cpp b/msmonitor/dynolog_npu/dynolog/src/tracing/IPCMonitor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..096c03716d96ae84218d3c1293099aa22b3e3484 --- /dev/null +++ b/msmonitor/dynolog_npu/dynolog/src/tracing/IPCMonitor.cpp @@ -0,0 +1,221 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include "dynolog/src/tracing/IPCMonitor.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "dynolog/src/LibkinetoConfigManager.h" +#include "dynolog/src/ipcfabric/Utils.h" + +namespace dynolog { +namespace tracing { + +constexpr int kSleepUs = 10000; +constexpr int kDataMsgSleepUs = 1000; +const std::string kLibkinetoRequest = "req"; +const std::string kLibkinetoContext = "ctxt"; +const std::string kLibkinetoData = "data"; + +IPCMonitor::IPCMonitor(const std::string& ipc_fabric_name) +{ + ipc_manager_ = FabricManager::factory(ipc_fabric_name); + data_ipc_manager_ = FabricManager::factory(ipc_fabric_name + "_data"); + // below ensures singleton exists + LOG(INFO) << "Kineto config manager : active processes = " + << LibkinetoConfigManager::getInstance()->processCount("0"); +} + +void IPCMonitor::loop() +{ + while (ipc_manager_) { + if (ipc_manager_->recv()) { + std::unique_ptr msg = ipc_manager_->retrieve_msg(); + processMsg(std::move(msg)); + } + /* sleep override */ + usleep(kSleepUs); + } +} + +void IPCMonitor::dataLoop() +{ + while (data_ipc_manager_) { + if (data_ipc_manager_->recv()) { + std::unique_ptr msg = data_ipc_manager_->retrieve_msg(); + processDataMsg(std::move(msg)); + } + /* sleep override */ + usleep(kDataMsgSleepUs); + } +} + +void IPCMonitor::processMsg(std::unique_ptr msg) +{ + if (!ipc_manager_) { + LOG(ERROR) << "Fabric Manager not initialized"; + return; + } + // sizeof(msg->metadata.type) = 32, well above the size of the constant + // strings we are comparing against. memcmp is safe + if (memcmp( // NOLINT(facebook-security-vulnerable-memcmp) + msg->metadata.type, + kLibkinetoContext.data(), + kLibkinetoContext.size()) == 0) { + registerLibkinetoContext(std::move(msg)); + } else if ( + memcmp( // NOLINT(facebook-security-vulnerable-memcmp) + msg->metadata.type, + kLibkinetoRequest.data(), + kLibkinetoRequest.size()) == 0) { + getLibkinetoOnDemandRequest(std::move(msg)); + } else if ( + memcmp( // NOLINT(facebook-security-vulnerable-memcmp) + msg->metadata.type, + kLibkinetoTraceStatus.data(), + kLibkinetoTraceStatus.size()) == 0) { + updateLibkinetoStatus(std::move(msg), kLibkinetoTraceStatus); + } else if ( + memcmp( // NOLINT(facebook-security-vulnerable-memcmp) + msg->metadata.type, + kLibkinetoMonitorStatus.data(), + kLibkinetoMonitorStatus.size()) == 0) { + updateLibkinetoStatus(std::move(msg), kLibkinetoMonitorStatus); + } else { + LOG(ERROR) << "TYPE UNKOWN: " << msg->metadata.type; + } +} + +void tracing::IPCMonitor::setLogger(std::unique_ptr logger) +{ + logger_ = std::move(logger); +} + +void IPCMonitor::LogData(const nlohmann::json& result) +{ + auto timestamp = result["timestamp"].get(); + logger_->logUint("timestamp", timestamp); + auto duration = result["duration"].get(); + logger_->logUint("duration", duration); + auto deviceId = result["deviceId"].get(); + logger_->logStr("deviceId", std::to_string(deviceId)); + auto kind = result["kind"].get(); + logger_->logStr("kind", kind); + if (result.contains("domain") && result["domain"].is_string()) { + auto domain = result["domain"].get(); + logger_->logStr("domain", domain); + } + if (result.contains("name") && result["name"].is_string()) { + auto name = result["name"].get(); + logger_->logStr("name", name); + } + logger_->finalize(); +} + +void IPCMonitor::processDataMsg(std::unique_ptr msg) +{ + if (!data_ipc_manager_) { + LOG(ERROR) << "Fabric Manager not initialized"; + return; + } + if (memcmp( // NOLINT(facebook-security-vulnerable-memcmp) + msg->metadata.type, + kLibkinetoData.data(), + kLibkinetoData.size()) == 0) { + std::string message = std::string((char*)msg->buf.get(), msg->metadata.size); + try { + nlohmann::json result = nlohmann::json::parse(message); + LOG(INFO) << "Received data message : " << result; + LogData(result); + } catch (nlohmann::json::parse_error&) { + LOG(ERROR) << "Error parsing message = " << message; + return; + } + } else { + LOG(ERROR) << "TYPE UNKOWN: " << msg->metadata.type; + } +} + +void IPCMonitor::getLibkinetoOnDemandRequest( + std::unique_ptr msg) +{ + if (!ipc_manager_) { + LOG(ERROR) << "Fabric Manager not initialized"; + return; + } + std::string ret_config = ""; + ipcfabric::LibkinetoRequest* req = + (ipcfabric::LibkinetoRequest*)msg->buf.get(); + if (req->n == 0) { + LOG(ERROR) << "Missing pids parameter for type " << req->type; + return; + } + std::vector pids(req->pids, req->pids + req->n); + try { + ret_config = LibkinetoConfigManager::getInstance()->obtainOnDemandConfig( + std::to_string(req->jobid), pids, req->type); + } catch (const std::runtime_error& ex) { + LOG(ERROR) << "Kineto config manager exception : " << ex.what(); + } + std::unique_ptr ret = + ipcfabric::Message::constructMessage( + ret_config, kLibkinetoRequest); + if (!ipc_manager_->sync_send(*ret, msg->src)) { + LOG(ERROR) << "Failed to return config to libkineto: IPC sync_send fail"; + } + return; +} + +void IPCMonitor::registerLibkinetoContext( + std::unique_ptr msg) +{ + if (!ipc_manager_) { + LOG(ERROR) << "Fabric Manager not initialized"; + return; + } + ipcfabric::LibkinetoContext* ctxt = + (ipcfabric::LibkinetoContext*)msg->buf.get(); + int32_t size = -1; + try { + size = LibkinetoConfigManager::getInstance()->registerLibkinetoContext( + std::to_string(ctxt->jobid), ctxt->pid, ctxt->gpu); + } catch (const std::runtime_error& ex) { + LOG(ERROR) << "Kineto config manager exception : " << ex.what(); + } + std::unique_ptr ret = + ipcfabric::Message::constructMessage( + size, kLibkinetoContext); + if (!ipc_manager_->sync_send(*ret, msg->src)) { + LOG(ERROR) << "Failed to send ctxt from dyno: IPC sync_send fail"; + } + return; +} + +void IPCMonitor::updateLibkinetoStatus( + std::unique_ptr msg, const std::string& msgType) +{ + struct NpuStatus { + int32_t status; + pid_t pid; + int64_t jobId; + }; + NpuStatus* status = (NpuStatus*)msg->buf.get(); + try { + LibkinetoConfigManager::getInstance()->updateNpuStatus( + std::to_string(status->jobId), status->pid, status->status, msgType); + } catch (const std::runtime_error& ex) { + LOG(ERROR) << "Kineto config manager exception when updateNpuStatus: " << ex.what(); + } +} + +} // namespace tracing +} // namespace dynolog diff --git a/msmonitor/dynolog_npu/dynolog/src/tracing/IPCMonitor.h b/msmonitor/dynolog_npu/dynolog/src/tracing/IPCMonitor.h new file mode 100644 index 0000000000000000000000000000000000000000..b9c2c2ebc205c5af56a83c11fe822931ab044674 --- /dev/null +++ b/msmonitor/dynolog_npu/dynolog/src/tracing/IPCMonitor.h @@ -0,0 +1,46 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +// Use glog for FabricManager.h +#define USE_GOOGLE_LOG + +#include "dynolog/src/ipcfabric/FabricManager.h" +#include "dynolog/src/Logger.h" + +namespace dynolog { +namespace tracing { + +class IPCMonitor { +public: + using FabricManager = dynolog::ipcfabric::FabricManager; + IPCMonitor(const std::string& ipc_fabric_name = "dynolog"); + virtual ~IPCMonitor() {} + + void loop(); + void dataLoop(); + +public: + virtual void processMsg(std::unique_ptr msg); + virtual void processDataMsg(std::unique_ptr msg); + void getLibkinetoOnDemandRequest(std::unique_ptr msg); + void registerLibkinetoContext(std::unique_ptr msg); + void updateLibkinetoStatus(std::unique_ptr msg, const std::string& msgType); + void setLogger(std::unique_ptr logger); + void LogData(const nlohmann::json& result); + + std::unique_ptr ipc_manager_; + std::unique_ptr data_ipc_manager_; + std::unique_ptr logger_; + + // friend class test_case_name##_##test_name##_Test + friend class IPCMonitorTest_LibkinetoRegisterAndOndemandTest_Test; +}; + +} // namespace tracing +} // namespace dynolog diff --git a/msmonitor/plugin/CMakeLists.txt b/msmonitor/plugin/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..6795fa89d091f478f36760a0028b2a191826f8a2 --- /dev/null +++ b/msmonitor/plugin/CMakeLists.txt @@ -0,0 +1,74 @@ +cmake_minimum_required(VERSION 3.16) +project(IPCMonitor) + +set(CMAKE_SKIP_RPATH TRUE) + +set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) + +find_package(pybind11 REQUIRED) +find_package(Python REQUIRED COMPONENTS Interpreter Development) + +set(CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake") +set(ENV{PROJECT_ROOT_PATH} "${CMAKE_SOURCE_DIR}") +include(utils) +find_package(glog MODULE REQUIRED) +find_package(nlohmannjson MODULE REQUIRED) +find_package(sqlite3 MODULE REQUIRED) + +include_directories( + ${CMAKE_CURRENT_SOURCE_DIR}/ipc_monitor + ${CMAKE_CURRENT_SOURCE_DIR}/ipc_monitor/metric + ${CMAKE_CURRENT_SOURCE_DIR}/ipc_monitor/mspti_monitor + ${CMAKE_CURRENT_SOURCE_DIR}/third_party/securec/include +) + +file(GLOB_RECURSE IPC_SOURCES + ${CMAKE_CURRENT_SOURCE_DIR}/ipc_monitor/*.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ipc_monitor/metric/*.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ipc_monitor/mspti_monitor/*.cpp +) + +file(GLOB_RECURSE SECUREC_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/third_party/securec/src/*.c) + +set(SOURCES + bindings.cpp + ${IPC_SOURCES} + ${SECUREC_SOURCES} + ${sqlite3_SOURCES} +) + +add_library(IPCMonitor MODULE ${SOURCES}) + +set_target_properties(IPCMonitor + PROPERTIES + OUTPUT_NAME IPCMonitor_C + LIBRARY_OUTPUT_DIRECTORY ${CMAKE_INSTALL_PREFIX}/IPCMonitor/lib64 + PREFIX "" +) + +target_link_libraries(IPCMonitor PRIVATE + pybind11::module + pthread + ${glog_LIBRARIES} + ${CMAKE_CURRENT_SOURCE_DIR}/stub/libmspti.so +) + +target_compile_options(IPCMonitor PRIVATE + -fPIC + -fstack-protector-all + -ftrapv +) + +target_link_options(IPCMonitor PRIVATE + -Wl,-z,relro,-z,now,-z,noexecstack + -s +) + +set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -O0 -g") +set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -g") +set(CMAKE_EXE_LINKER_FLAGS_DEBUG "${CMAKE_EXE_LINKER_FLAGS_DEBUG} -O0 -g") + +set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -D_FORTIFY_SOURCE=2 -O2") +set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} -D_FORTIFY_SOURCE=2 -O2") diff --git a/msmonitor/plugin/IPCMonitor/__init__.py b/msmonitor/plugin/IPCMonitor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bd4cb27a8858f009c75b3b220c7001114e1adb2e --- /dev/null +++ b/msmonitor/plugin/IPCMonitor/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .dynamic_monitor_proxy import PyDynamicMonitorProxy diff --git a/msmonitor/plugin/IPCMonitor/dynamic_monitor_proxy.py b/msmonitor/plugin/IPCMonitor/dynamic_monitor_proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..b2ceea962c0cd6ed47c1fe5cf1f10aca493f4dfb --- /dev/null +++ b/msmonitor/plugin/IPCMonitor/dynamic_monitor_proxy.py @@ -0,0 +1,50 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os +import importlib +from .singleton import Singleton +from .utils import get_parallel_group_info + +so_path = os.path.join(os.path.dirname(__file__), "lib64") +sys.path.append(os.path.realpath(so_path)) +ipcMonitor_C_module = importlib.import_module("IPCMonitor_C") + + +@Singleton +class PyDynamicMonitorProxy: + + @classmethod + def init_dyno(cls, npu_id: int): + return ipcMonitor_C_module.init_dyno(npu_id) + + @classmethod + def poll_dyno(cls): + return ipcMonitor_C_module.poll_dyno() + + @classmethod + def enable_dyno_npu_monitor(cls, config_map: dict): + if str(config_map.get("NPU_MONITOR_STOP")).lower() in ("true", "1"): + ipcMonitor_C_module.set_cluster_config_data({"parallel_group_info": get_parallel_group_info()}) + ipcMonitor_C_module.enable_dyno_npu_monitor(config_map) + + @classmethod + def finalize_dyno(cls): + ipcMonitor_C_module.finalize_dyno() + + @classmethod + def update_profiler_status(cls, status: dict): + ipcMonitor_C_module.update_profiler_status(status) \ No newline at end of file diff --git a/dynolog_npu/plugin/setup.py b/msmonitor/plugin/IPCMonitor/singleton.py similarity index 42% rename from dynolog_npu/plugin/setup.py rename to msmonitor/plugin/IPCMonitor/singleton.py index 151b9b3fb3fa1a42e147685f632163c8b3f5a564..6386c2ab66c0d403421f47576b6dce9694b26569 100644 --- a/dynolog_npu/plugin/setup.py +++ b/msmonitor/plugin/IPCMonitor/singleton.py @@ -1,42 +1,25 @@ -# Copyright (c) 2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -from setuptools import setup -from pybind11.setup_helpers import Pybind11Extension - -BASE_DIR = os.path.dirname(os.path.realpath(__file__)) - -# Define the extension module -ext_modules = [ - Pybind11Extension( - "IPCMonitor", # Name of the Python module - sources=["bindings.cpp", - "ipc_monitor/utils.cpp", - "ipc_monitor/DynoLogNpuMonitor.cpp", - "ipc_monitor/NpuIpcClient.cpp", - ], # Source files - include_dirs=[os.path.join(BASE_DIR, "ipc_monitor")], # Include Pybind11 headers - language="c++", # Specify the language - ), -] - -# Set up the package -setup( - name="dynolog_npu_plugin", - version="0.1", - description="dynolog npu plugins", - ext_modules=ext_modules, - install_requires=["pybind11"], -) \ No newline at end of file +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class Singleton(object): + def __init__(self, cls): + self._cls = cls + self._instance = {} + + def __call__(self): + if self._cls not in self._instance: + self._instance[self._cls] = self._cls() + return self._instance[self._cls] diff --git a/msmonitor/plugin/IPCMonitor/utils.py b/msmonitor/plugin/IPCMonitor/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9656b1ef73934df2e38b3b1ea577e5b9d2c84922 --- /dev/null +++ b/msmonitor/plugin/IPCMonitor/utils.py @@ -0,0 +1,142 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +import warnings +from typing import Optional + + +def get_pytorch_rank_id() -> Optional[int]: + """Get pytorch rank id.""" + try: + import torch + rank_id = os.environ.get("RANK") + if rank_id is None and torch.distributed.is_available() and torch.distributed.is_initialized(): + rank_id = torch.distributed.get_rank() + if rank_id is not None and not isinstance(rank_id, int): + rank_id = int(rank_id) + except Exception as ex: + raise RuntimeError(f"Get rank id failed in pytorch: {str(ex)}") from ex + return rank_id + + +def get_pytorch_parallel_group_info() -> str: + """Get pytorch parallel group info.""" + try: + import torch + from torch.distributed.distributed_c10d import _world as distributed_world + if torch.distributed.is_available() and torch.distributed.is_initialized(): + group_info = {} + global_rank = torch.distributed.get_rank() + for group in distributed_world.pg_map.keys(): + if torch.distributed.get_backend(group) != "hccl": + continue + hccl_group = group._get_backend(torch.device("npu")) + comm_name = hccl_group.get_hccl_comm_name(global_rank, init_comm=False) + if comm_name: + group_info[comm_name] = { + "group_name": hccl_group.options.hccl_config.get("group_name", ""), + "group_rank": torch.distributed.get_group_rank(group, global_rank), + "global_ranks": torch.distributed.get_process_group_ranks(group) + } + default_group = torch.distributed.distributed_c10d._get_default_group() + comm_name = default_group._get_backend(torch.device("npu")).get_hccl_comm_name(global_rank, init_comm=False) + if comm_name: + group_info[comm_name] = { + "group_name": "default_group", + "group_rank": torch.distributed.get_group_rank(default_group, global_rank), + "global_ranks": torch.distributed.get_process_group_ranks(default_group) + } + if group_info: + return json.dumps(group_info) + except Exception as ex: + raise RuntimeError(f"Get parallel group info in pytorch failed: {str(ex)}.") from ex + return "" + + +def get_mindspore_rank_id() -> Optional[int]: + """Get mindspore rank id.""" + try: + import mindspore.communication as comm + rank_id = os.environ.get("RANK_ID") + if rank_id is None and comm.GlobalComm.INITED: + rank_id = comm.get_rank() + if rank_id is not None and not isinstance(rank_id, int): + rank_id = int(rank_id) + except Exception as ex: + raise RuntimeError(f"Get rank id failed in mindspore: {str(ex)}") from ex + return rank_id + + +def get_mindspore_parallel_group_info() -> str: + """Get mindspore parallel group info.""" + try: + import mindspore.communication as comm + import mindspore.communication._comm_helper as comm_helper + if comm.GlobalComm.INITED and comm.GlobalComm.BACKEND == comm_helper.Backend.HCCL: + group_info = {} + for group_name in comm_helper._get_group_map().keys(): + comm_name = comm.get_comm_name(group_name) + if not comm_name: + continue + group_info[comm_name] = { + "group_name": group_name, + "group_rank": comm.get_local_rank(group_name), + "global_ranks": comm.get_process_group_ranks(group_name) + } + if group_info: + return json.dumps(group_info) + except Exception as ex: + raise RuntimeError(f"Get parallel group info in mindspore failed: {str(ex)}.") from ex + return "" + + +def get_rank_id() -> int: + """Get rank id.""" + rank_id = None + try: + rank_id = get_pytorch_rank_id() + except Exception as ex: + warnings.warn(f"{str(ex)}") + + if rank_id is None: + try: + rank_id = get_mindspore_rank_id() + except Exception as ex: + warnings.warn(f"{str(ex)}") + + if rank_id is None: + warnings.warn("Failed to get rank id from pytorch and mindspore, set rank id to -1.") + rank_id = -1 + + return rank_id + + +def get_parallel_group_info() -> str: + """Get parallel group info.""" + parallel_group_info = "" + try: + parallel_group_info = get_pytorch_parallel_group_info() + except Exception as ex: + warnings.warn(f"{str(ex)}") + + if not parallel_group_info: + try: + parallel_group_info = get_mindspore_parallel_group_info() + except Exception as ex: + warnings.warn(f"{str(ex)}") + + return parallel_group_info diff --git a/msmonitor/plugin/README.md b/msmonitor/plugin/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cac84ea8e45644d2a0dc45edc18be156a620b978 --- /dev/null +++ b/msmonitor/plugin/README.md @@ -0,0 +1,56 @@ +# msmonitor-plugin编包指导 +## 模块说明 +### IPCMonitor +提供IPC(Inter-Process Communication)通信接口,用于实现 +1. IPC控制通道: profiler backend向dynolog daemon获取profiler配置 +2. IPC数据通道: mspti monitor向dynolog daemon发送性能数据 + +__PyDynamicMonitorProxy接口说明__: +* `init_dyno` 向dynolog daemon发送注册请求 + * input: npu_id(int) + * return: None +* `poll_dyno` 向dynolog daemon获取Profiler控制参数 + * input: None + * return: str,返回控制参数 +* `enable_dyno_npu_monitor` 开启mspti监控 + * input: cfg_map(Dict[str,str]) 参数配置 + * return: None +* `finalize_dyno` 释放msmonitor中相关资源、线程 + * input: None + * return: None +* `update_profiler_status` 上报profiler_status + * input: status(Dict[str,str]) + * return: None + +## 安装方式 +### 1. 通过shell脚本一键安装 +``` +chmod +x build.sh +./build.sh +``` +### 2. 手动安装 +* 安装依赖 +``` +pip install wheel +pip install pybind11 +``` +* 编译whl包 +``` +bash ./stub/build_stub.sh +python3 setup.py bdist_wheel +``` +以上命令执行完成后在dist目录下生成msMonitor插件whl安装包msmonitor_plugin-{mindstudio_version}-cp{python_version}-cp{python_version}-linux_{system_architecture}.whl +* 安装 +``` +pip install dist/msmonitor_plugin-{mindstudio_version}-cp{python_version}-cp{python_version}-linux_{system_architecture}.whl +``` +* 卸载 +``` +pip uninstall msmonitor-plugin +``` + +## 日志 +用户可以通过配置MSMONITOR_LOG_PATH环境变量,指定到自定义的日志文件路径,默认路径为当前目录下的msmonitor_log +``` +export MSMONITOR_LOG_PATH=/tmp/msmonitor_log # /tmp/msmonitor_log为自定义日志文件路径 +``` diff --git a/msmonitor/plugin/bindings.cpp b/msmonitor/plugin/bindings.cpp new file mode 100644 index 0000000000000000000000000000000000000000..794b49767b9748b8011746fca712273c8eab3515 --- /dev/null +++ b/msmonitor/plugin/bindings.cpp @@ -0,0 +1,43 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "ipc_monitor/PyDynamicMonitorProxy.h" +#include "ipc_monitor/mspti_monitor/MsptiMonitor.h" + +namespace py = pybind11; + + +PYBIND11_MODULE(IPCMonitor_C, m) { + m.def("init_dyno", [](int npu_id) -> bool { + return dynolog_npu::ipc_monitor::PyDynamicMonitorProxy::GetInstance()->InitDyno(npu_id); + }, py::arg("npu_id")); + m.def("poll_dyno", []() -> std::string { + return dynolog_npu::ipc_monitor::PyDynamicMonitorProxy::GetInstance()->PollDyno(); + }); + m.def("enable_dyno_npu_monitor", [](std::unordered_map& config_map) -> void { + dynolog_npu::ipc_monitor::PyDynamicMonitorProxy::GetInstance()->EnableMsptiMonitor(config_map); + }, py::arg("config_map")); + m.def("finalize_dyno", []() -> void { + dynolog_npu::ipc_monitor::PyDynamicMonitorProxy::GetInstance()->FinalizeDyno(); + }); + m.def("set_cluster_config_data", [](const std::unordered_map& cluster_config) -> void { + dynolog_npu::ipc_monitor::MsptiMonitor::GetInstance()->SetClusterConfigData(cluster_config); + }, py::arg("cluster_config")); + m.def("update_profiler_status", [](std::unordered_map& status) -> void { + dynolog_npu::ipc_monitor::PyDynamicMonitorProxy::GetInstance()->UpdateProfilerStatus(status); + }, py::arg("status")); +} diff --git a/dynolog_npu/plugin/build.sh b/msmonitor/plugin/build.sh old mode 100755 new mode 100644 similarity index 80% rename from dynolog_npu/plugin/build.sh rename to msmonitor/plugin/build.sh index ce20d9d2be546afbc63e3aace524f74858eff6ff..939aaa2baf6f92d8fc4600340d363997942db33b --- a/dynolog_npu/plugin/build.sh +++ b/msmonitor/plugin/build.sh @@ -3,7 +3,10 @@ # install pybind11 pip install pybind11 -# build dynolog_npu_plugin wheel +# build stub +sh ./stub/build_stub.sh + +# build msmonitor_plugin wheel python3 setup.py bdist_wheel # find .whl files in dist @@ -18,4 +21,4 @@ fi # pip install whl echo "pip install ${files}" -pip install ${files} \ No newline at end of file +pip install ${files} diff --git a/msmonitor/plugin/cmake/Findglog.cmake b/msmonitor/plugin/cmake/Findglog.cmake new file mode 100644 index 0000000000000000000000000000000000000000..bbebee6ee217f23e65244c9072550ed612ac7c97 --- /dev/null +++ b/msmonitor/plugin/cmake/Findglog.cmake @@ -0,0 +1,42 @@ +set(PACKAGE_VERSION 0.6.0) + +set(PKG_NAME glog) +set(DOWNLOAD_PATH "$ENV{PROJECT_ROOT_PATH}/third_party") +set(GIT_TAG "v0.6.0") +set(DIR_NAME "${DOWNLOAD_PATH}/glog") + +if (NOT ${PKG_NAME}_FOUND) + +download_opensource_pkg(${PKG_NAME} + GIT_TAG ${GIT_TAG} + DOWNLOAD_PATH ${DOWNLOAD_PATH} +) + +execute_process( + WORKING_DIRECTORY ${DIR_NAME} + COMMAND cmake -S . -B build -G "Unix Makefiles" -DBUILD_SHARED_LIBS=OFF -DCMAKE_INSTALL_PREFIX=${DIR_NAME}/install -DCMAKE_INSTALL_LIBDIR=${DIR_NAME}/install/lib64 -DWITH_GFLAGS=OFF -DWITH_GTEST=OFF -DWITH_SYMBOLIZE=OFF -DCMAKE_POLICY_VERSION_MINIMUM=3.5 + RESULT_VARIABLE RESULT +) +if (NOT RESULT EQUAL 0) + message(FATAL_ERROR "Failed to build glog. ${RESULT}") +endif() + +execute_process( + WORKING_DIRECTORY ${DIR_NAME} + COMMAND cmake --build build --target install + RESULT_VARIABLE RESULT +) +if (NOT RESULT EQUAL 0) + message(FATAL_ERROR "Failed to build glog. ${RESULT}") +endif() + +file(GLOB GLOG_LIB "${DIR_NAME}/install/lib64/libglog.a") +if (NOT GLOG_LIB) + message(FATAL_ERROR "Failed to build glog.") +endif() + +set(${PKG_NAME}_LIBRARIES ${GLOG_LIB}) +include_directories(${DIR_NAME}/install/include) +set(${PKG_NAME}_FOUND TRUE) + +endif() diff --git a/msmonitor/plugin/cmake/Findnlohmannjson.cmake b/msmonitor/plugin/cmake/Findnlohmannjson.cmake new file mode 100644 index 0000000000000000000000000000000000000000..a657cc3accbff93f0599c8d30fc9228339563c7c --- /dev/null +++ b/msmonitor/plugin/cmake/Findnlohmannjson.cmake @@ -0,0 +1,18 @@ +set(PACKAGE_VERSION 3.12.0) + +set(PKG_NAME nlohmannjson) +set(DOWNLOAD_PATH "$ENV{PROJECT_ROOT_PATH}/third_party") +set(GIT_TAG "v3.12.0") +set(DIR_NAME "${DOWNLOAD_PATH}/nlohmann-json") + +if (NOT ${PKG_NAME}_FOUND) + +download_opensource_pkg(${PKG_NAME} + GIT_TAG ${GIT_TAG} + DOWNLOAD_PATH ${DOWNLOAD_PATH} +) + +include_directories(${DIR_NAME}/include) +set(${PKG_NAME}_FOUND TRUE) + +endif() diff --git a/msmonitor/plugin/cmake/Findsqlite3.cmake b/msmonitor/plugin/cmake/Findsqlite3.cmake new file mode 100644 index 0000000000000000000000000000000000000000..afc8c73adc7b8ace75c5b193b9c9832683b6c644 --- /dev/null +++ b/msmonitor/plugin/cmake/Findsqlite3.cmake @@ -0,0 +1,22 @@ +set(PACKAGE_VERSION 3.50.3) + +set(PKG_NAME sqlite3) +set(DOWNLOAD_PATH "$ENV{PROJECT_ROOT_PATH}/third_party") +set(DIR_NAME "${DOWNLOAD_PATH}/sqlite-amalgamation-3500300") + +if (NOT ${PKG_NAME}_FOUND) + +download_opensource_pkg(${PKG_NAME} + DOWNLOAD_PATH ${DOWNLOAD_PATH} +) + +file(GLOB SQLITE3_SRC "${DIR_NAME}/sqlite3.c") +if (NOT SQLITE3_SRC) + message(FATAL_ERROR "Failed to get sqlite3 source code.") +endif() + +set(${PKG_NAME}_SOURCES ${SQLITE3_SRC}) +include_directories(${DIR_NAME}) +set(${PKG_NAME}_FOUND TRUE) + +endif() diff --git a/msmonitor/plugin/cmake/config.ini b/msmonitor/plugin/cmake/config.ini new file mode 100644 index 0000000000000000000000000000000000000000..f7b2eaa4e96bfec6a3b39b75b0b2667962b04e5a --- /dev/null +++ b/msmonitor/plugin/cmake/config.ini @@ -0,0 +1,8 @@ +[glog] +url = https://gitee.com/mirrors/glog.git + +[nlohmannjson] +url = https://gitee.com/mirrors/nlohmann-json.git + +[sqlite3] +url = https://sqlite.org/2025/sqlite-amalgamation-3500300.zip \ No newline at end of file diff --git a/msmonitor/plugin/cmake/download_opensource.sh b/msmonitor/plugin/cmake/download_opensource.sh new file mode 100644 index 0000000000000000000000000000000000000000..d04e59b445edafd17061a227763eb09cd3b439b7 --- /dev/null +++ b/msmonitor/plugin/cmake/download_opensource.sh @@ -0,0 +1,73 @@ +#!/bin/bash + +if [ "$#" -lt 2 ]; then + echo "Usage: $0 [ ]" + exit 1 +fi + +pkg_name=$1 +path=$2 + +if [ "$#" -ge 3 ]; then + tag=$3 +fi + +url=$(awk -F " = " '/\['${pkg_name}'\]/{a=1}a==1&&$1~/url/{print $2;exit}' config.ini) +lib_path=$MSTT_LIB_PATH +if [ -n "$lib_path" ]; then + url=${lib_path}$(echo $url | awk -F '/' -v OFS='/' '{print $5,$8}') +fi +if [[ ! $url = https* ]]; then + echo "The URL of $pkg_name is illegal." + exit 1 +fi + +echo "Start to download ${url}..." + +if [ ! -d "$path" ]; then + echo "The specified path does not exist: $path" + exit 1 +fi +cd ${path} + +extension=$(echo "${url}" | awk -F'[./]' '{print $NF}') +if [[ "${extension}" == "gz" || "${extension}" == "zip" ]]; then + fullname="${path}/$(basename "${url}")" + if [[ -e ${fullname} ]]; then + echo "Source ${fullname} is exists, will not download again." + else + curl -L "${url}" -o ${fullname} -k + if [ $? -eq 0 ]; then + echo "Download successful: ${url}" + else + echo "Download failed: ${url}" + exit 1 + fi + fi + + if [[ "${extension}" == "gz" ]]; then + tar -zxvf ${fullname} -C ./ -n > /dev/null + elif [[ "${extension}" == "zip" ]]; then + unzip -n ${fullname} -d ./ > /dev/null + fi +elif [[ "${extension}" == "git" ]]; then + repository="$(basename ${url} .git)" + if [[ -e ${repository} ]]; then + echo "Source ${repository} is exists, will not clone again." + else + if [[ -z "${tag}" ]]; then + git clone ${url} + else + git clone ${url} -b "${tag}" + fi + if [ $? -eq 0 ]; then + echo "Download successful: ${url}" + else + echo "Download failed: ${url}" + exit 1 + fi + fi +else + echo "Unknow url ${url}" + exit 1 +fi diff --git a/msmonitor/plugin/cmake/utils.cmake b/msmonitor/plugin/cmake/utils.cmake new file mode 100644 index 0000000000000000000000000000000000000000..3d815d268558aabb27dd3503a2b79a0368da914d --- /dev/null +++ b/msmonitor/plugin/cmake/utils.cmake @@ -0,0 +1,25 @@ + +function(download_opensource_pkg pkg_name) + message("start to download ${pkg_name}...") + set(options) + set(oneValueArgs GIT_TAG DOWNLOAD_PATH DIR_NAME BUILD_CMD) + set(multiValueArgs PATCHES) + cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + if (NOT PKG_DOWNLOAD_PATH) + set(PKG_DOWNLOAD_PATH "${CMAKE_SOURCE_DIR}/../third_party") + endif() + file(MAKE_DIRECTORY ${PKG_DOWNLOAD_PATH}) + + execute_process( + WORKING_DIRECTORY $ENV{PROJECT_ROOT_PATH}/cmake + COMMAND bash download_opensource.sh ${pkg_name} ${PKG_DOWNLOAD_PATH} ${PKG_GIT_TAG} + RESULT_VARIABLE RESULT + ) + if (NOT RESULT EQUAL 0) + message(FATAL_ERROR "Failed to download ${pkg_name}(${RESULT}).") + endif() + if (PKG_BUILD_CMD) + execute_process(COMMAND bash -c "cd ${PKG_DOWNLOAD_PATH}/${DIR_NAME};${PKG_BUILD_CMD}") + endif() +endfunction() diff --git a/msmonitor/plugin/ipc_monitor/DynoLogNpuMonitor.cpp b/msmonitor/plugin/ipc_monitor/DynoLogNpuMonitor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a8b8220a9a362ae3f1e67e99a425c466f4d1e63e --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/DynoLogNpuMonitor.cpp @@ -0,0 +1,146 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "DynoLogNpuMonitor.h" +#include +#include +#include +#include "utils.h" +#include "MsptiMonitor.h" + +namespace dynolog_npu { +namespace ipc_monitor { +DynoLogNpuMonitor::DynoLogNpuMonitor() +{ + // init glog + if (!google::IsGoogleLoggingInitialized()) { + std::string logPath; + if (CreateMsmonitorLogPath(logPath)) { + fprintf(stderr, "[INFO] [%d] Msmonitor log will record to %s\n", GetProcessId(), logPath.c_str()); + logPath = logPath + "/msmonitor_"; + google::InitGoogleLogging("MsMonitor"); + google::SetStderrLogging(google::GLOG_ERROR); + google::SetLogDestination(google::GLOG_INFO, logPath.c_str()); + google::SetLogFilenameExtension(".log"); + } else { + fprintf(stderr, "Failed to create log path, log will not record\n"); + } + } + msptiActivityDisableMarkerDomain("communication"); // filter inner communication marker data for now +} + +bool DynoLogNpuMonitor::Init() +{ + if (isInitialized_) { + LOG(WARNING) << "DynoLog npu monitor already initialized"; + return true; + } + if (!ipcClient_.Init()) { + LOG(ERROR) << "DynoLog npu monitor ipcClient init failed"; + return false; + } + bool res = ipcClient_.RegisterInstance(npuId_); + if (res) { + isInitialized_ = true; + LOG(INFO) << "DynoLog npu monitor initialized successfully"; + } + return res; +} + +ErrCode DynoLogNpuMonitor::DealMonitorReq(MsptiMonitorCfg& cmd) +{ + auto msptiMonitor = MsptiMonitor::GetInstance(); + if (cmd.monitorStop) { + if (msptiMonitor->IsStarted()) { + LOG(INFO) << "Stop mspti monitor thread successfully"; + msptiMonitor->Stop(); + } + return ErrCode::SUC; + } + + if (cmd.reportIntervals != 0) { + msptiMonitor->SetFlushInterval(cmd.reportIntervals); + } + + if (cmd.monitorStart && !msptiMonitor->IsStarted()) { + if (!cmd.savePath.empty() && !msptiMonitor->CheckAndSetSavePath(cmd.savePath)) { + LOG(ERROR) << "Invalid log path, mspti monitor start failed"; + return ErrCode::PERMISSION; + } + + LOG(INFO) << "Start mspti monitor thread successfully"; + msptiMonitor->Start(); + } + + if (msptiMonitor->IsStarted() && !cmd.enableActivities.empty()) { + auto curActivities = msptiMonitor->GetEnabledActivities(); + std::vector enableKinds; + std::vector disableKinds; + std::set_difference(cmd.enableActivities.begin(), cmd.enableActivities.end(), curActivities.begin(), curActivities.end(), + std::back_inserter(enableKinds)); + std::set_difference(curActivities.begin(), curActivities.end(), cmd.enableActivities.begin(), cmd.enableActivities.end(), + std::back_inserter(disableKinds)); + for (auto activity : enableKinds) { + msptiMonitor->EnableActivity(activity); + } + for (auto activity : disableKinds) { + msptiMonitor->DisableActivity(activity); + } + } + return ErrCode::SUC; +} + +std::string DynoLogNpuMonitor::Poll() +{ + std::string res = ipcClient_.IpcClientNpuConfig(); + if (res.size() == 4) { // res为4,表示dynolog注册进程成功 + LOG(INFO) << "Regist to dynolog daemon successfully"; + return ""; + } + if (res.empty()) { + return ""; + } + LOG(INFO) << "Received NPU configuration successfully"; + return res; +} + +void DynoLogNpuMonitor::EnableMsptiMonitor(std::unordered_map& cfg_map) +{ + auto cmd = InputParser::GetInstance()->DynoLogGetOpts(cfg_map); + if (cmd.isMonitor) { + auto ans = DealMonitorReq(cmd); + if (ans != ErrCode::SUC) { + LOG(ERROR) << "Deal monitor request failed, because" << IPC_ERROR(ans); + } + UpdateNpuStatus(static_cast(MsptiMonitor::GetInstance()->IsStarted()), MSG_TYPE_MONITOR_STATUS); + } +} + +void DynoLogNpuMonitor::Finalize() +{ + MsptiMonitor::GetInstance()->Uninit(); +} + +void DynoLogNpuMonitor::UpdateNpuStatus(int32_t status, const std::string& msgType) +{ + bool res = ipcClient_.SendNpuStatus(status, msgType); + if (res) { + LOG(INFO) << "Send npu status successfully"; + } else { + LOG(WARNING) << "Send npu status failed"; + } +} +} // namespace ipc_monitor +} // namespace dynolog_npu diff --git a/dynolog_npu/plugin/ipc_monitor/DynoLogNpuMonitor.h b/msmonitor/plugin/ipc_monitor/DynoLogNpuMonitor.h similarity index 37% rename from dynolog_npu/plugin/ipc_monitor/DynoLogNpuMonitor.h rename to msmonitor/plugin/ipc_monitor/DynoLogNpuMonitor.h index 40ee21072710312a86cd75befdcefa67e24efb8f..3641925db49fbbbe89f3931bff2f3c4831b66c38 100644 --- a/dynolog_npu/plugin/ipc_monitor/DynoLogNpuMonitor.h +++ b/msmonitor/plugin/ipc_monitor/DynoLogNpuMonitor.h @@ -1,9 +1,25 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #ifndef DYNOLOG_NPU_MONITOR_H #define DYNOLOG_NPU_MONITOR_H #include "MonitorBase.h" #include "NpuIpcClient.h" #include "singleton.h" +#include "InputParser.h" namespace dynolog_npu { namespace ipc_monitor { @@ -12,14 +28,23 @@ class DynoLogNpuMonitor : public MonitorBase, public Singleton; public: - DynoLogNpuMonitor() = default; + DynoLogNpuMonitor(); bool Init() override; + ErrCode DealMonitorReq(MsptiMonitorCfg& cmd); std::string Poll() override; + void EnableMsptiMonitor(std::unordered_map& cfg_map); + void Finalize(); + void UpdateNpuStatus(int32_t status, const std::string& msgType); void SetNpuId(int id) override { npuId_ = id; } + IpcClient *GetIpcClient() + { + return &ipcClient_; + } + private: bool isInitialized_ = false; int32_t npuId_ = 0; @@ -29,5 +54,4 @@ private: } // namespace ipc_monitor } // namespace dynolog_npu -#endif - +#endif // DYNOLOG_NPU_MONITOR_H diff --git a/msmonitor/plugin/ipc_monitor/InputParser.cpp b/msmonitor/plugin/ipc_monitor/InputParser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0558f00530106ae4df340ad2608863056ca8e06f --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/InputParser.cpp @@ -0,0 +1,70 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "InputParser.h" +#include +#include +#include "utils.h" + +namespace dynolog_npu { +namespace ipc_monitor { + +const std::string MSPTI_ACTIVITY_KIND_KEY = "MSPTI_ACTIVITY_KIND"; +const std::string REPORT_INTERVAL_S_KEY = "REPORT_INTERVAL_S"; +const std::string NPU_MONITOR_START_KEY = "NPU_MONITOR_START"; +const std::string NPU_MONITOR_STOP_KEY = "NPU_MONITOR_STOP"; +const std::string NPU_MONITOR_SAVE_PATH = "NPU_MONITOR_LOG_FILE"; + +const std::unordered_map kindStrMap { + {"Marker", MSPTI_ACTIVITY_KIND_MARKER}, + {"Kernel", MSPTI_ACTIVITY_KIND_KERNEL}, + {"API", MSPTI_ACTIVITY_KIND_API}, + {"Hccl", MSPTI_ACTIVITY_KIND_HCCL}, + {"Memory", MSPTI_ACTIVITY_KIND_MEMORY}, + {"MemSet", MSPTI_ACTIVITY_KIND_MEMSET}, + {"MemCpy", MSPTI_ACTIVITY_KIND_MEMCPY}, + {"Communication", MSPTI_ACTIVITY_KIND_COMMUNICATION}, +}; + +std::set str2Kinds(const std::string& kindStrs) +{ + std::set res; + auto kindStrList = split(kindStrs, ','); + for (auto& kindStr : kindStrList) { + auto kind = kindStrMap.find(kindStr); + if (kind == kindStrMap.end()) { + return {MSPTI_ACTIVITY_KIND_INVALID}; + } + res.insert(kind->second); + } + return res; +} + +MsptiMonitorCfg InputParser::DynoLogGetOpts(std::unordered_map& cmd) +{ + if (!cmd.count(NPU_MONITOR_START_KEY)) { + return {{MSPTI_ACTIVITY_KIND_INVALID}, 0, false, false, false, ""}; + } + auto activityKinds = str2Kinds(cmd[MSPTI_ACTIVITY_KIND_KEY]); + uint32_t reportTimes = 0; + Str2Uint32(reportTimes, cmd[REPORT_INTERVAL_S_KEY]); + bool startSwitch = false; + Str2Bool(startSwitch, cmd[NPU_MONITOR_START_KEY]); + bool endSwitch = false; + Str2Bool(endSwitch, cmd[NPU_MONITOR_STOP_KEY]); + return {activityKinds, reportTimes, startSwitch, endSwitch, true, cmd[NPU_MONITOR_SAVE_PATH]}; +} +} // namespace ipc_monitor +} // namespace dynolog_npu diff --git a/msmonitor/plugin/ipc_monitor/InputParser.h b/msmonitor/plugin/ipc_monitor/InputParser.h new file mode 100644 index 0000000000000000000000000000000000000000..fb03655259fc3e67eb79b40a327fe99d8bf79617 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/InputParser.h @@ -0,0 +1,45 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef INPUT_PARSER_H +#define INPUT_PARSER_H + +#include +#include +#include +#include "mspti.h" +#include "singleton.h" + +namespace dynolog_npu { +namespace ipc_monitor { + +struct MsptiMonitorCfg { + std::set enableActivities; + uint32_t reportIntervals; + bool monitorStart; + bool monitorStop; + bool isMonitor; + std::string savePath; +}; + + +class InputParser : public Singleton { +public: + MsptiMonitorCfg DynoLogGetOpts(std::unordered_map& cmd); +}; + +} // namespace ipc_monitor +} // namespace dynolog_npu +#endif // INPUT_PARSER_H diff --git a/msmonitor/plugin/ipc_monitor/MonitorBase.h b/msmonitor/plugin/ipc_monitor/MonitorBase.h new file mode 100644 index 0000000000000000000000000000000000000000..a46fc7fe31e9464c00cfecd01a29b85f3977705b --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/MonitorBase.h @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MONITOR_BASE_H +#define MONITOR_BASE_H + +#include + +namespace dynolog_npu { +namespace ipc_monitor { + +class MonitorBase { +public: + virtual bool Init() = 0; + virtual std::string Poll() = 0; + virtual void SetNpuId(int id) = 0; +}; + +} // namespace ipc_monitor +} // namespace dynolog_npu +#endif // MONITOR_BASE_H diff --git a/dynolog_npu/plugin/ipc_monitor/NpuIpcClient.cpp b/msmonitor/plugin/ipc_monitor/NpuIpcClient.cpp similarity index 55% rename from dynolog_npu/plugin/ipc_monitor/NpuIpcClient.cpp rename to msmonitor/plugin/ipc_monitor/NpuIpcClient.cpp index 97966e8eeacc7276426feb237aa122eb8dee046f..4bd425ed12ff538527c9b53a4c28f25f380add3b 100644 --- a/dynolog_npu/plugin/ipc_monitor/NpuIpcClient.cpp +++ b/msmonitor/plugin/ipc_monitor/NpuIpcClient.cpp @@ -1,57 +1,112 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "NpuIpcClient.h" - -#include +#include namespace dynolog_npu { namespace ipc_monitor { +bool IpcClient::Init() +{ + pids_ = GetPids(); + return true; +} -bool IpcClient::RegisterInstance(int32_t id) +bool IpcClient::RegisterInstance(int32_t npu) { NpuContext context{ - .npu = id, + .npu = npu, .pid = getpid(), .jobId = JOB_ID, }; - std::unique_ptr message = Message::ConstructMessage(context, "ctxt"); + std::unique_ptr message = Message::ConstructMessage(context, MSG_TYPE_CONTEXT); try { - if (!SyncSendMessage(*message, std::string(DYNO_IPC_NAME))) { - std::cout << "[WARNING]Failed to send register ctxt for pid " << context.pid << " with dyno" << std::endl; + if (!SyncSendMessage(*message, DYNO_IPC_NAME)) { + LOG(WARNING) << "Failed to send register ctxt for pid " << context.pid << " with dyno"; return false; } } catch (const std::exception &e) { - std::cout << "[WARNING] Error when SyncSendMessage: " << e.what() << std::endl; + LOG(WARNING) << "Error when SyncSendMessage: " << e.what(); return false; } - std::cout << "[INFO] Resigter pid " << context.pid << " for dynolog success !" << std::endl; + LOG(INFO) << "Resigter pid " << context.pid << " for dynolog success!"; return true; } + +bool IpcClient::SendNpuStatus(int32_t status, const std::string& msgType) +{ + NpuStatus npuStatus{ + .status = status, + .pid = GetProcessId(), + .jobId = JOB_ID, + }; + std::unique_ptr message = Message::ConstructMessage(npuStatus, msgType); + try { + if (!SyncSendMessage(*message, DYNO_IPC_NAME)) { + LOG(WARNING) << "Failed to send msmonitor status for pid " << npuStatus.pid << " with dyno"; + return false; + } + } catch (const std::exception &e) { + LOG(WARNING) << "Error when SyncSendMessage: " << e.what(); + return false; + } + LOG(INFO) << "Send msmonitor status for pid " << npuStatus.pid << " for dynolog success!"; + return true; +} + std::string IpcClient::IpcClientNpuConfig() { auto size = pids_.size(); - auto *req = (NpuRequest *)malloc(sizeof(NpuRequest) + sizeof(int32_t) * size); + auto *req = ReinterpretConvert(malloc(sizeof(NpuRequest) + sizeof(int32_t) * size)); + if (req == nullptr) { + LOG(ERROR) << " Malloc for NpuRequest failed !"; + return ""; + } req->type = DYNO_IPC_TYPE; req->pidSize = size; req->jobId = JOB_ID; - for (int i = 0; i < size; i++) { + for (size_t i = 0; i < size; i++) { req->pids[i] = pids_[i]; } - std::unique_ptr message = Message::ConstructMessage(*req, "req", size); - if (!SyncSendMessage(*message, std::string(DYNO_IPC_NAME))) { - std::cout << "[WARNING] Failed to send config to dyno server fail !" << std::endl; + std::unique_ptr message; + try{ + message = Message::ConstructMessage(*req, MSG_TYPE_REQUEST, size); + } + catch (const std::exception &e) { + LOG(ERROR) << "ConstructMessage failed: " << e.what(); + free(req); + req = nullptr; + throw; + } + if (!message || !SyncSendMessage(*message, DYNO_IPC_NAME)) { + LOG(WARNING) << "Failed to send config to dyno server"; free(req); req = nullptr; return ""; } free(req); + req = nullptr; message = PollRecvMessage(MAX_IPC_RETRIES, MAX_SLEEP_US); if (!message) { - std::cout << "[WARNING] Failed to receive on-demand config !" << std::endl; + LOG(WARNING) << "Failed to receive on-demand config"; return ""; } std::string res = std::string(ReinterpretConvert(message->buf.get()), message->metadata.size); - return res; } + std::unique_ptr IpcClient::ReceiveMessage() { std::lock_guard wguard(dequeLock_); @@ -62,10 +117,11 @@ std::unique_ptr IpcClient::ReceiveMessage() msgDynoDeque_.pop_front(); return message; } + bool IpcClient::SyncSendMessage(const Message &message, const std::string &destName, int numRetry, int seepTimeUs) { if (destName.empty()) { - std::cout << "[WARNING] Can not send to empty socket name !" << std::endl; + LOG(WARNING) << "Can not send to empty socket name!"; return false; } int i = 0; @@ -79,11 +135,12 @@ bool IpcClient::SyncSendMessage(const Message &message, const std::string &destN seepTimeUs *= 2; // 2: double sleep time } } catch (const std::exception &e) { - std::cout << "[ERROR] Error when SyncSendMessage: " << e.what() << std::endl; + LOG(ERROR) << "Error when SyncSendMessage: " << e.what(); return false; } return i < numRetry; } + bool IpcClient::Recv() { try { @@ -94,7 +151,7 @@ bool IpcClient::Recv() try { successFlag = ep_.TryPeekMessage(*peekCtxt); } catch (std::exception &e) { - std::cout << "[ERROR] Error when TryPeekMessage: " << e.what() << std::endl; + LOG(ERROR) << "Error when TryPeekMessage: " << e.what(); return false; } if (successFlag) { @@ -108,7 +165,7 @@ bool IpcClient::Recv() try { successFlag = ep_.TryRcvMessage(*recvCtxt); } catch (std::exception &e) { - std::cout << "[ERROR] Error when TryRecvMsg: " << e.what() << std::endl; + LOG(ERROR) << "Error when TryRecvMsg: " << e.what(); return false; } if (successFlag) { @@ -118,11 +175,12 @@ bool IpcClient::Recv() } } } catch (std::exception &e) { - std::cout << "[ERROR] Error in Recv(): " << e.what() << std::endl; + LOG(ERROR) << "Error in Recv(): " << e.what(); return false; } return false; } + std::unique_ptr IpcClient::PollRecvMessage(int maxRetry, int sleeTimeUs) { for (int i = 0; i < maxRetry; i++) { @@ -133,6 +191,5 @@ std::unique_ptr IpcClient::PollRecvMessage(int maxRetry, int sleeTimeUs } return nullptr; } - } // namespace ipc_monitor -} // namespace dynolog_npu \ No newline at end of file +} // namespace dynolog_npu diff --git a/dynolog_npu/plugin/ipc_monitor/NpuIpcClient.h b/msmonitor/plugin/ipc_monitor/NpuIpcClient.h similarity index 48% rename from dynolog_npu/plugin/ipc_monitor/NpuIpcClient.h rename to msmonitor/plugin/ipc_monitor/NpuIpcClient.h index ae7b00eb51b935db4e799fab470c3343e78bcb6f..42cfcccf6da2a23355ffb4afc4e115c82c696173 100644 --- a/dynolog_npu/plugin/ipc_monitor/NpuIpcClient.h +++ b/msmonitor/plugin/ipc_monitor/NpuIpcClient.h @@ -1,40 +1,67 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #ifndef NPU_IPC_CLIENT_H #define NPU_IPC_CLIENT_H -#include -#include + +#include #include #include -#include -#include -#include -#include #include "NpuIpcEndPoint.h" #include "utils.h" +#include "securec.h" namespace dynolog_npu { namespace ipc_monitor { constexpr int TYPE_SIZE = 32; constexpr int JOB_ID = 0; -constexpr const char *DYNO_IPC_NAME = "dynolog"; constexpr const int DYNO_IPC_TYPE = 3; constexpr const int MAX_IPC_RETRIES = 5; constexpr const int MAX_SLEEP_US = 10000; +const std::string DYNO_IPC_NAME = "dynolog"; +const std::string MSG_TYPE_REQUEST = "req"; +const std::string MSG_TYPE_CONTEXT = "ctxt"; +const std::string MSG_TYPE_TRACE_STATUS = "npuTraceStatus"; +const std::string MSG_TYPE_MONITOR_STATUS = "npuMonitorStatus"; +const std::string MSG_TYPE_DATA = "data"; + struct NpuRequest { int type; int pidSize; int64_t jobId; int32_t pids[0]; }; + struct NpuContext { int32_t npu; pid_t pid; int64_t jobId; }; + +struct NpuStatus { + int32_t status; + pid_t pid; + int64_t jobId; +}; + struct Metadata { size_t size = 0; char type[TYPE_SIZE] = ""; }; + struct Message { Metadata metadata; std::unique_ptr buf; @@ -45,19 +72,26 @@ struct Message { if (type.size() + 1 > sizeof(ipcNpuMessage->metadata.type)) { throw std::runtime_error("Type string is too long to fit in metadata.type" + IPC_ERROR(ErrCode::PARAM)); } - memcpy(ipcNpuMessage->metadata.type, type.c_str(), type.size() + 1); + if (memcpy_s(ipcNpuMessage->metadata.type, sizeof(ipcNpuMessage->metadata.type), + type.c_str(), type.size() + 1) != EOK) { + throw std::runtime_error("memcpy_s failed" + IPC_ERROR(ErrCode::MEMORY)); + } #if __cplusplus >= 201703L if constexpr (std::is_same::value == true) { ipcNpuMessage->metadata.size = data.size(); ipcNpuMessage->buf = std::make_unique(ipcNpuMessage->metadata.size); - memcpy(ipcNpuMessage->buf.get(), data.c_str(), sizeof(data)); + if (memcpy_s(ipcNpuMessage->buf.get(), ipcNpuMessage->metadata.size, data.c_str(), data.size()) != EOK) { + throw std::runtime_error("memcpy_s failed" + IPC_ERROR(ErrCode::MEMORY)); + } return ipcNpuMessage; } #endif static_assert(std::is_trivially_copyable::value); ipcNpuMessage->metadata.size = sizeof(data); ipcNpuMessage->buf = std::make_unique(ipcNpuMessage->metadata.size); - memcpy(ipcNpuMessage->buf.get(), &data, sizeof(data)); + if (memcpy_s(ipcNpuMessage->buf.get(), ipcNpuMessage->metadata.size, &data, sizeof(data)) != EOK) { + throw std::runtime_error("memcpy_s failed" + IPC_ERROR(ErrCode::MEMORY)); + } return ipcNpuMessage; } @@ -68,36 +102,62 @@ struct Message { if (type.size() + 1 > sizeof(ipcNpuMessage->metadata.type)) { throw std::runtime_error("Type string is too long to fit in metadata.type" + IPC_ERROR(ErrCode::PARAM)); } - memcpy(ipcNpuMessage->metadata.type, type.c_str(), type.size() + 1); + if (memcpy_s(ipcNpuMessage->metadata.type, sizeof(ipcNpuMessage->metadata.type), + type.c_str(), type.size() + 1) != EOK) { + throw std::runtime_error("memcpy_s failed" + IPC_ERROR(ErrCode::MEMORY)); + } static_assert(std::is_trivially_copyable::value); static_assert(std::is_trivially_copyable::value); ipcNpuMessage->metadata.size = sizeof(data) + sizeof(U) * n; ipcNpuMessage->buf = std::make_unique(ipcNpuMessage->metadata.size); - memcpy(ipcNpuMessage->buf.get(), &data, ipcNpuMessage->metadata.size); + if (memcpy_s(ipcNpuMessage->buf.get(), ipcNpuMessage->metadata.size, + &data, ipcNpuMessage->metadata.size) != EOK) { + throw std::runtime_error("memcpy_s failed" + IPC_ERROR(ErrCode::MEMORY)); + } + return ipcNpuMessage; + } + + static std::unique_ptr ConstructStrMessage(const std::string &data, const std::string &type) + { + std::unique_ptr ipcNpuMessage = std::make_unique(Message()); + if (type.size() + 1 > sizeof(ipcNpuMessage->metadata.type)) { + throw std::runtime_error("Type string is too long to fit in metadata.type" + IPC_ERROR(ErrCode::PARAM)); + } + if (memcpy_s(ipcNpuMessage->metadata.type, sizeof(ipcNpuMessage->metadata.type), + type.c_str(), type.size() + 1) != EOK) { + throw std::runtime_error("memcpy_s failed" + IPC_ERROR(ErrCode::MEMORY)); + } + ipcNpuMessage->metadata.size = data.size(); + ipcNpuMessage->buf = std::make_unique(ipcNpuMessage->metadata.size); + if (memcpy_s(ipcNpuMessage->buf.get(), ipcNpuMessage->metadata.size, data.c_str(), data.size()) != EOK) { + throw std::runtime_error("memcpy_s failed" + IPC_ERROR(ErrCode::MEMORY)); + } return ipcNpuMessage; } }; + class IpcClient { public: IpcClient(const IpcClient &) = delete; IpcClient &operator = (const IpcClient &) = delete; IpcClient() = default; + bool Init(); bool RegisterInstance(int32_t npu); + bool SendNpuStatus(int32_t npuTraceStatus, const std::string& msgType); std::string IpcClientNpuConfig(); + bool SyncSendMessage(const Message &message, const std::string &destName, int numRetry = 10, + int seepTimeUs = 10000); private: - std::vector pids_ = GetPids(); + std::vector pids_; NpuIpcEndPoint<0> ep_{ "dynoconfigclient" + GenerateUuidV4() }; std::mutex dequeLock_; std::deque> msgDynoDeque_; std::unique_ptr ReceiveMessage(); - bool SyncSendMessage(const Message &message, const std::string &destName, int numRetry = 10, - int seepTimeUs = 10000); bool Recv(); std::unique_ptr PollRecvMessage(int maxRetry, int sleeTimeUs); }; - } // namespace ipc_monitor } // namespace dynolog_npu -#endif \ No newline at end of file +#endif // NPU_IPC_CLIENT_H diff --git a/dynolog_npu/plugin/ipc_monitor/NpuIpcEndPoint.h b/msmonitor/plugin/ipc_monitor/NpuIpcEndPoint.h similarity index 77% rename from dynolog_npu/plugin/ipc_monitor/NpuIpcEndPoint.h rename to msmonitor/plugin/ipc_monitor/NpuIpcEndPoint.h index 6560fa515646226ddbffbca49c4f818eb0d0ebcf..38d837ad94abe739bbd29bfaecd9a621c8e99e39 100644 --- a/dynolog_npu/plugin/ipc_monitor/NpuIpcEndPoint.h +++ b/msmonitor/plugin/ipc_monitor/NpuIpcEndPoint.h @@ -1,23 +1,37 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #ifndef NPU_IPC_ENDPOINT_H #define NPU_IPC_ENDPOINT_H -#include + #include #include #include #include #include +#include #include -#include -#include -#include #include "utils.h" +#include "securec.h" namespace dynolog_npu { namespace ipc_monitor { using fileDesT = int; constexpr const char STR_END_CHAR = '\0'; -constexpr int SOCKET_FD_CHMOD = 0666; +constexpr int SOCKET_FD_CHMOD = 0640; struct NpuPayLoad { size_t size; @@ -46,23 +60,34 @@ public: if (socketFd == -1) { throw std::runtime_error(std::strerror(errno) + IPC_ERROR(ErrCode::PARAM)); } + int ret = 0; struct sockaddr_un address; size_t addressLen = SetSocketAdress(addressName, address); if (address.sun_path[0] != STR_END_CHAR) { - unlink(address.sun_path); + ret = unlink(address.sun_path); + } + if (ret == -1) { + throw std::runtime_error("Unlink failed, error is " + std::string(strerror(errno)) + IPC_ERROR(ErrCode::PARAM)); } - int res = bind(socketFd, ReinterpretConvert(&address), addressLen); - if (res == -1) { + + ret = bind(socketFd, ReinterpretConvert(&address), addressLen); + if (ret == -1) { throw std::runtime_error("Bind socket failed." + IPC_ERROR(ErrCode::PARAM)); } + if (address.sun_path[0] != STR_END_CHAR) { - chmod(address.sun_path, SOCKET_FD_CHMOD); + ret = chmod(address.sun_path, SOCKET_FD_CHMOD); + } + if (ret == -1) { + throw std::runtime_error("Chmod failed, error is " + std::string(strerror(errno)) + IPC_ERROR(ErrCode::PARAM)); } } + ~NpuIpcEndPoint() { close(socketFd); } + [[nodiscard]] auto BuildSendNpuCtxt(const std::string &desAddrName, const std::vector &npuPayLoad, const std::vector &fileDes) { @@ -80,7 +105,11 @@ public: throw std::runtime_error("Memcpy failed when fileDes size large than ctxt fileDesPtr " + IPC_ERROR(ErrCode::PARAM)); } - memcpy(ctxt->fileDesPtr, fileDes.data(), fileDes.size() * sizeof(fileDesT)); + if (memcpy_s(ctxt->fileDesPtr, sizeof(ctxt->fileDesPtr), + fileDes.data(), fileDes.size() * sizeof(fileDesT)) != EOK) { + throw std::runtime_error("Memcpy failed when fileDes size large than ctxt fileDesPtr " + + IPC_ERROR(ErrCode::MEMORY)); + } } return ctxt; } @@ -137,7 +166,7 @@ public: throw std::runtime_error("TryPeekMessage occur " + std::string(std::strerror(errno))); } - const char *GetName(Ctxt const & ctxt) const noexcept + const char *GetName(Ctxt const & ctxt) const { if (ctxt.messageName.sun_path[0] != STR_END_CHAR) { throw std::runtime_error("GetName() want to got abstract socket, but got " + @@ -173,8 +202,10 @@ protected: auto BuildNpuCtxt_(const std::vector &npuPayLoad, unsigned numFileDes) { auto ctxt = std::make_unique(npuPayLoad.size()); - std::memset(&ctxt->msghdr, 0, sizeof(ctxt->msghdr)); - for (auto i = 0; i < npuPayLoad.size(); i++) { + if (memset_s(&ctxt->msghdr, sizeof(ctxt->msghdr), 0, sizeof(ctxt->msghdr)) != EOK) { + throw std::runtime_error("Memset failed when build ctxt " + IPC_ERROR(ErrCode::MEMORY)); + } + for (size_t i = 0; i < npuPayLoad.size(); i++) { ctxt->iov[i] = {npuPayLoad[i].data, npuPayLoad[i].size}; } ctxt->msghdr.msg_name = &ctxt->messageName; @@ -197,8 +228,7 @@ protected: return ctxt; } }; - } // namespace ipc_monitor } // namespace dynolog_npu -#endif +#endif // NPU_IPC_ENDPOINT_H diff --git a/msmonitor/plugin/ipc_monitor/PyDynamicMonitorProxy.h b/msmonitor/plugin/ipc_monitor/PyDynamicMonitorProxy.h new file mode 100644 index 0000000000000000000000000000000000000000..a3e8105312baf9bac52e1eae8f4cf706a5787547 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/PyDynamicMonitorProxy.h @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef PYDYNAMIC_MONITOR_PROXY_H +#define PYDYNAMIC_MONITOR_PROXY_H + +#include +#include "MonitorBase.h" +#include "DynoLogNpuMonitor.h" + +namespace dynolog_npu { +namespace ipc_monitor { + +class PyDynamicMonitorProxy : public Singleton { + friend class Singleton; + +public: + PyDynamicMonitorProxy() = default; + bool InitDyno(int npuId) + { + try { + monitor_ = DynoLogNpuMonitor::GetInstance(); + monitor_->SetNpuId(npuId); + bool res = monitor_->Init(); + return res; + } catch (const std::exception &e) { + LOG(ERROR) << "Error when init dyno " << e.what(); + return false; + } + } + + std::string PollDyno() + { + return monitor_->Poll(); + } + + void EnableMsptiMonitor(std::unordered_map& config_map) + { + DynoLogNpuMonitor::GetInstance()->EnableMsptiMonitor(config_map); + } + + void FinalizeDyno() + { + DynoLogNpuMonitor::GetInstance()->Finalize(); + } + + void UpdateProfilerStatus(std::unordered_map& status) + { + int32_t npuTraceStatus = 0; + auto it = status.find("profiler_status"); + if (it != status.end() && !it->second.empty()) { + Str2Int32(npuTraceStatus, it->second); + } else { + LOG(ERROR) << "Missing key 'profiler_status'."; + return; + } + DynoLogNpuMonitor::GetInstance()->UpdateNpuStatus(npuTraceStatus, MSG_TYPE_TRACE_STATUS); + } +private: + MonitorBase *monitor_ = nullptr; +}; + +} // namespace ipc_monitor +} // namespace dynolog_npu +#endif // PYDYNAMIC_MONITOR_PROXY_H diff --git a/msmonitor/plugin/ipc_monitor/TimerTask.h b/msmonitor/plugin/ipc_monitor/TimerTask.h new file mode 100644 index 0000000000000000000000000000000000000000..beacab87587b972386e38e2691c9cb9dcca8969d --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/TimerTask.h @@ -0,0 +1,118 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef TIMER_TASK_H +#define TIMER_TASK_H + +#include +#include +#include +#include +#include +#include + +namespace dynolog_npu { +namespace ipc_monitor { +class TimerTask { +public: + TimerTask(const std::string& name, int interval) + : interval(interval), name(name), manual_trigger(false), running(false) {} + + virtual ~TimerTask() + { + Stop(); + } + + void Run() + { + if (running) { + LOG(ERROR) << name << " Timer task is already running."; + return; + } + running = true; + taskThread = std::thread(&TimerTask::TaskRun, this); + } + + void Trigger() + { + std::unique_lock lock(cv_mutex); + manual_trigger = true; + if (running.load()) { + cv.notify_one(); + } + } + + // 停止定时任务 + void Stop() + { + if (!running) { + LOG(WARNING) << name << "Timer task is not running."; + return; + } + + running = false; + cv.notify_one(); + if (taskThread.joinable()) { + taskThread.join(); + } + } + + void SetInterval(int intervalTimes) + { + interval.store(intervalTimes); + } + + virtual void RunPreTask() {}; + virtual void RunPostTask() {}; + virtual void ExecuteTask() = 0; + bool IsRunning() { return running.load(); } +private: + // 定时任务线程函数 + void TaskRun() + { + LOG(INFO) << name << " Timer task started."; + RunPreTask(); + while (running) { + std::unique_lock lock(cv_mutex); + if (interval.load()) { + cv.wait_for(lock, std::chrono::seconds(interval.load()), [&] {return manual_trigger || !running;}); + } else { + cv.wait(lock, [&] {return manual_trigger || !running;}); + } + if (!running) { + break; + } + if (manual_trigger) { + manual_trigger = false; + } + if (running) { + ExecuteTask(); + } + } + RunPostTask(); + LOG(INFO) << name << " Timer task stopped."; + } + + std::atomic interval; + std::string name; + std::condition_variable cv; + std::mutex cv_mutex; + std::atomic manual_trigger; + std::atomic running; + std::thread taskThread; +}; +} // namespace ipc_monitor +} // namespace dynolog_npu +#endif // TIMER_TASK_H diff --git a/msmonitor/plugin/ipc_monitor/db/Connection.cpp b/msmonitor/plugin/ipc_monitor/db/Connection.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6934c14f909a715be18d34c9a883c0bc48a895a4 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/db/Connection.cpp @@ -0,0 +1,232 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "db/Connection.h" +#include "utils.h" + +namespace { +constexpr int32_t TIMEOUT = INT32_MAX; +const std::string CREATE_TABLE = "CREATE TABLE"; +const std::string CREATE_INDEX = "CREATE INDEX"; +const std::string DROP_TABLE = "DROP TABLE"; +const std::string UPDATE = "UPDATE"; +const std::string DELETE = "DELETE"; +const std::string CHECK = "CHECK"; +} + +namespace dynolog_npu { +namespace ipc_monitor { +namespace db { +Connection::Connection(const std::string &path) +{ + auto rc = sqlite3_open(path.c_str(), &db_); + if (rc != SQLITE_OK) { + LOG(ERROR) << "Open database failed: " << rc << ", msg: " << sqlite3_errmsg(db_); + sqlite3_close_v2(db_); + db_ = nullptr; + } else { + sqlite3_exec(db_, "PRAGMA synchronous=OFF;", nullptr, nullptr, nullptr); + } +} + +Connection::~Connection() +{ + if (stmt_) { + sqlite3_finalize(stmt_); + } + if (db_) { + auto rc = sqlite3_close(db_); + if (rc != SQLITE_OK) { + LOG(ERROR) << "Close database failed: " << rc << ", msg: " << sqlite3_errmsg(db_); + sqlite3_close_v2(db_); + } + db_ = nullptr; + } +} + +bool Connection::ExecuteSql(const std::string &sql, const std::string &sqlType) +{ + CHAR_PTR errMsg{nullptr}; + sqlite3_busy_timeout(db_, TIMEOUT); + auto rc = sqlite3_exec(db_, sql.c_str(), nullptr, nullptr, &errMsg); + if (rc != SQLITE_OK) { + if (sqlType == CHECK) { + LOG(WARNING) << "Execute sql failed: " << rc << ", type: " << sqlType << ", msg: " << errMsg; + } else { + LOG(ERROR) << "Execute sql failed: " << rc << ", type: " << sqlType << ", msg: " << errMsg; + } + sqlite3_free(errMsg); + return false; + } + return true; +} + +bool Connection::CheckTableExists(const std::string &tableName) +{ + std::string sql{"SELECT 1 FROM sqlite_master WHERE type='table' AND name='" + tableName + "' LIMIT 1"}; + std::vector> result; + if(ExecuteQuery(sql, result)) { + return !result.empty(); + } + return false; +} + +bool Connection::ExecuteCreateTable(const std::string &sql) +{ + return ExecuteSql(sql, CREATE_TABLE); +} + +bool Connection::ExecuteCreateIndex(const std::string &sql) +{ + return ExecuteSql(sql, CREATE_INDEX); +} + +bool Connection::ExecuteDropTable(const std::string &sql) +{ + return ExecuteSql(sql, DROP_TABLE); +} + +bool Connection::ExecuteUpdate(const std::string &sql) +{ + return ExecuteSql(sql, UPDATE); +} + +bool Connection::ExecuteDelete(const std::string &sql) +{ + return ExecuteSql(sql, DELETE); +} + +std::vector Connection::ExecuteGetTableColumns(const std::string &tableName) +{ + std::vector columns; + std::string sql = "PRAGMA table_info(" + tableName + ")"; + sqlite3_busy_timeout(db_, TIMEOUT); + auto rc = sqlite3_prepare_v2(db_, sql.c_str(), -1, &stmt_, nullptr); + if (rc != SQLITE_OK) { + LOG(ERROR) << "Execute sql failed: " << rc << ", msg: " << sqlite3_errmsg(db_); + return columns; + } + while (sqlite3_step(stmt_) == SQLITE_ROW) { + std::string name, type; + GetColumn(name); + GetColumn(type); + columns.emplace_back(name, type); + index_ = 0; + } + return columns; +} + +bool Connection::InsertCmd(const std::string &tableName, uint32_t colNum) +{ + std::string sql = "INSERT INTO " + tableName + " VALUES ("; + for (uint32_t i = 0; i < colNum; ++i) { + sql += "?"; + if (i < colNum - 1) { + sql += ", "; + } + } + sql += ")"; + sqlite3_busy_timeout(db_, TIMEOUT); + auto rc = sqlite3_prepare_v2(db_, sql.c_str(), -1, &stmt_, nullptr); + if (rc != SQLITE_OK) { + LOG(ERROR) << "Execute sql failed: " << rc << ", msg: " << sqlite3_errmsg(db_); + return false; + } + return true; +} + +bool Connection::QueryCmd(const std::string &sql) +{ + auto rc = sqlite3_prepare_v2(db_, sql.c_str(), -1, &stmt_, nullptr); + if (rc != SQLITE_OK) { + LOG(ERROR) << "Execute sql failed: " << rc << ", msg: " << sqlite3_errmsg(db_); + return false; + } + return true; +} + +bool Connection::BindParameters(int32_t value) +{ + return sqlite3_bind_int(stmt_, ++index_, value) == SQLITE_OK; +} + +bool Connection::BindParameters(uint32_t value) +{ + return sqlite3_bind_int64(stmt_, ++index_, value) == SQLITE_OK; +} + +bool Connection::BindParameters(int64_t value) +{ + return sqlite3_bind_int64(stmt_, ++index_, value) == SQLITE_OK; +} + +bool Connection::BindParameters(uint64_t value) +{ + return sqlite3_bind_int64(stmt_, ++index_, value) == SQLITE_OK; +} + +bool Connection::BindParameters(double value) +{ + return sqlite3_bind_double(stmt_, ++index_, value) == SQLITE_OK; +} + +bool Connection::BindParameters(std::string value) +{ + return sqlite3_bind_text(stmt_, ++index_, value.c_str(), -1, SQLITE_TRANSIENT) == SQLITE_OK; +} + +void Connection::GetColumn(uint16_t &value) +{ + value = static_cast(sqlite3_column_int(stmt_, ++index_)); +} + +void Connection::GetColumn(int32_t &value) +{ + value = sqlite3_column_int(stmt_, ++index_); +} + +void Connection::GetColumn(uint32_t &value) +{ + value = static_cast(sqlite3_column_int64(stmt_, ++index_)); +} + +void Connection::GetColumn(int64_t &value) +{ + value = sqlite3_column_int64(stmt_, ++index_); +} + +void Connection::GetColumn(uint64_t &value) +{ + value = static_cast(sqlite3_column_int64(stmt_, ++index_)); +} + +void Connection::GetColumn(double &value) +{ + value = sqlite3_column_double(stmt_, ++index_); +} + +void Connection::GetColumn(std::string &value) +{ + const unsigned char *text = sqlite3_column_text(stmt_, ++index_); + if (text == nullptr) { + value.clear(); + } else { + value = std::string(ReinterpretConvert(text)); + } +} +} // namespace db +} // namespace ipc_monitor +} // namespace dynolog_npu diff --git a/msmonitor/plugin/ipc_monitor/db/Connection.h b/msmonitor/plugin/ipc_monitor/db/Connection.h new file mode 100644 index 0000000000000000000000000000000000000000..93f939863759d038e977192aed67763e58ec7fdf --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/db/Connection.h @@ -0,0 +1,208 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef IPC_MONITOR_DB_CONNECTION_H +#define IPC_MONITOR_DB_CONNECTION_H + +#include +#include +#include +#include + +#include +#include + +namespace dynolog_npu { +namespace ipc_monitor { +namespace db { +using CHAR_PTR = char*; +struct TableColumn { + std::string name; + std::string type; + bool isPrimaryKey = false; + + TableColumn(const std::string &name, const std::string &type, bool isPrimaryKey = false) + : name(name), type(type), isPrimaryKey(isPrimaryKey) {} + + std::string ToString() const + { + return name + " " + type + (isPrimaryKey ? " PRIMARY KEY" : ""); + } + + bool operator==(const TableColumn &other) const + { + return (name == other.name) && (type == other.type); + } +}; + +template +struct IndexSequence {}; + +template +struct IndexSequenceMaker : IndexSequenceMaker {}; + +template +struct IndexSequenceMaker<0, S...> { + using type = IndexSequence; +}; + +template +using MakeIndexSequence = typename IndexSequenceMaker::type; + +class Connection { +public: + explicit Connection(const std::string &path); + ~Connection(); + bool CheckTableExists(const std::string &tableName); + bool ExecuteSql(const std::string &sql, const std::string &sqlType); + bool ExecuteCreateTable(const std::string &sql); + bool ExecuteCreateIndex(const std::string &sql); + bool ExecuteDropTable(const std::string &sql); + template + bool ExecuteInsert(const std::string &tableName, const std::vector> &data); + template + bool ExecuteQuery(const std::string &sql, std::vector> &result); + bool ExecuteUpdate(const std::string &sql); + bool ExecuteDelete(const std::string &sql); + std::vector ExecuteGetTableColumns(const std::string &tableName); + +private: + bool InsertCmd(const std::string &tableName, uint32_t colNum); + bool BindParameters(int32_t value); + bool BindParameters(uint32_t value); + bool BindParameters(int64_t value); + bool BindParameters(uint64_t value); + bool BindParameters(double value); + bool BindParameters(std::string value); + template + void ExecuteInsertHelper(T &row, IndexSequence); + template + int ExecuteInsertHelperHerlper(T t); + template + void InsertRow(T &row); + + bool QueryCmd(const std::string &sql); + void GetColumn(uint16_t &value); + void GetColumn(int32_t &value); + void GetColumn(uint32_t &value); + void GetColumn(int64_t &value); + void GetColumn(uint64_t &value); + void GetColumn(double &value); + void GetColumn(std::string &value); + template + void ExecuteQueryHelper(T &row, IndexSequence); + template + int ExecuteQueryHelperHelper(T &t); + template + void GetRow(T &row); + +private: + int index_{0}; + sqlite3 *db_{nullptr}; + sqlite3_stmt *stmt_{nullptr}; +}; + +template +int Connection::ExecuteInsertHelperHerlper(T t) +{ + return BindParameters(t) ? 0 : -1; +} + +template +void Connection::ExecuteInsertHelper(T &row, IndexSequence) +{ + std::initializer_list {(ExecuteInsertHelperHerlper(std::get(row)), 0)...}; +} + +template +void Connection::InsertRow(T &row) +{ + using TupleType = typename std::decay::type; + ExecuteInsertHelper(row, MakeIndexSequence::value>{}); +} + +template +int Connection::ExecuteQueryHelperHelper(T &t) +{ + GetColumn(t); + return 0; +} + +template +void Connection::ExecuteQueryHelper(T &row, IndexSequence) +{ + std::initializer_list {(ExecuteQueryHelperHelper(std::get(row)), 0)...}; +} + +template +void Connection::GetRow(T &row) +{ + using TupleType = typename std::decay::type; + ExecuteQueryHelper(row, MakeIndexSequence::value>{}); +} + +template +bool Connection::ExecuteInsert(const std::string &tableName, const std::vector> &data) +{ + uint32_t colNum = sizeof...(Args); + sqlite3_exec(db_, "BEGIN", nullptr, nullptr, nullptr); + if (!InsertCmd(tableName, colNum)) { + return false; + } + for (const auto &row : data) { + index_ = 0; + sqlite3_reset(stmt_); + InsertRow(row); + auto rc = sqlite3_step(stmt_); + if (rc != SQLITE_DONE) { + LOG(ERROR) << "ExecuteInsert failed: " << rc << ", msg: " << sqlite3_errmsg(db_) << ", insert failed"; + if (sqlite3_exec(db_, "ROLLBACK", nullptr, nullptr, nullptr) != SQLITE_OK) { + LOG(ERROR) << "ExecuteInsert failed: " << rc << ", rollback failed"; + } + return false; + } + } + sqlite3_exec(db_, "COMMIT", nullptr, nullptr, nullptr); + return true; +} + +template +bool Connection::ExecuteQuery(const std::string &sql, std::vector> &result) +{ + if (!QueryCmd(sql)) { + return false; + } + while(true) { + auto rc = sqlite3_step(stmt_); + if (rc != SQLITE_ROW) { + if (rc != SQLITE_DONE) { + LOG(ERROR) << "ExecuteQuery failed: " << rc << ", msg: " << sqlite3_errmsg(db_) << ", query failed"; + return false; + } + break; + } + index_ = -1; + std::tuple row; + GetRow(row); + result.emplace_back(row); + } + return true; +} +} // namespace db +} // namespace ipc_monitor +} // namespace dynolog_npu + +#endif // IPC_MONITOR_DB_CONNECTION_H diff --git a/msmonitor/plugin/ipc_monitor/db/DBConstant.h b/msmonitor/plugin/ipc_monitor/db/DBConstant.h new file mode 100644 index 0000000000000000000000000000000000000000..1fb35b31829061fac6aff9f2b33f3754d2ecbc3a --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/db/DBConstant.h @@ -0,0 +1,48 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef IPC_MONITOR_DB_CONSTANT_H +#define IPC_MONITOR_DB_CONSTANT_H + +#include + +namespace dynolog_npu { +namespace ipc_monitor { +namespace db { +const std::string SQL_TEXT_TYPE = "TEXT"; +const std::string SQL_INT_TYPE = "INTEGER"; +const std::string SQL_REAL_TYPE = "REAL"; +const std::string SQL_NUMERIC_TYPE = "NUMERIC"; + +const std::string TABLE_STRING_IDS = "STRING_IDS"; +const std::string TABLE_SESSION_TIME_INFO = "SESSION_TIME_INFO"; +const std::string TABLE_CANN_API = "CANN_API"; +const std::string TABLE_TASK = "TASK"; +const std::string TABLE_COMPUTE_TASK_INFO = "COMPUTE_TASK_INFO"; +const std::string TABLE_COMMUNICATION_OP = "COMMUNICATION_OP"; +const std::string TABLE_MSTX = "MSTX_EVENTS"; +const std::string TABLE_MSTX_EVENT_TYPE = "ENUM_MSTX_EVENT_TYPE"; +const std::string TABLE_HCCL_DATA_TYPE = "ENUM_HCCL_DATA_TYPE"; +const std::string TABLE_API_TYPE = "ENUM_API_TYPE"; +const std::string TABLE_HOST_INFO = "HOST_INFO"; +const std::string TABLE_RANK_DEVICE_MAP = "RANK_DEVICE_MAP"; +const std::string TABLE_META_DATA = "META_DATA"; + +} // namespace db +} // namespace ipc_monitor +} // namespace dynolog_npu + +#endif // IPC_MONITOR_DB_CONSTANT_H diff --git a/msmonitor/plugin/ipc_monitor/db/DBInfo.h b/msmonitor/plugin/ipc_monitor/db/DBInfo.h new file mode 100644 index 0000000000000000000000000000000000000000..b729c105f553d19d5d275d9bdd80c7248ca885c5 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/db/DBInfo.h @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef IPC_MONITOR_DB_INFO_H +#define IPC_MONITOR_DB_INFO_H + +#include "db/DBRunner.h" +#include "db/DataBase.h" + +namespace dynolog_npu { +namespace ipc_monitor { +namespace db { +struct DBInfo { + DBInfo() = default; + + bool ConstructDBRunner(const std::string& dbPath) + { + dbRunner = std::make_shared(dbPath); + return dbRunner != nullptr; + } + + std::shared_ptr database{nullptr}; + std::shared_ptr dbRunner{nullptr}; +}; +} // namespace db +} // namespace ipc_monitor +} // namespace dynolog_npu + +#endif // IPC_MONITOR_DB_INFO_H diff --git a/msmonitor/plugin/ipc_monitor/db/DBProcessManager.cpp b/msmonitor/plugin/ipc_monitor/db/DBProcessManager.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7e51a17d3e035bd11d97108d389fd934cc22845b --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/db/DBProcessManager.cpp @@ -0,0 +1,427 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "db/DBProcessManager.h" +#include "db/DBConstant.h" +#include "singleton.h" + +namespace dynolog_npu { +namespace ipc_monitor { +namespace db { +namespace { +constexpr uint64_t MSTX_CONNECTION_ID_OFFSET = 4000000000ULL; +const std::string MSTX_TASK_TYPE = "MsTx"; +const std::string NA = "N/A"; +const std::vector> HCCL_DATA_TYPE = { + {msptiCommunicationDataType::MSPTI_ACTIVITY_COMMUNICATION_INT8, "INT8"}, + {msptiCommunicationDataType::MSPTI_ACTIVITY_COMMUNICATION_INT16, "INT16"}, + {msptiCommunicationDataType::MSPTI_ACTIVITY_COMMUNICATION_INT32, "INT32"}, + {msptiCommunicationDataType::MSPTI_ACTIVITY_COMMUNICATION_INT64, "INT64"}, + {msptiCommunicationDataType::MSPTI_ACTIVITY_COMMUNICATION_UINT8, "UINT8"}, + {msptiCommunicationDataType::MSPTI_ACTIVITY_COMMUNICATION_UINT16, "UINT16"}, + {msptiCommunicationDataType::MSPTI_ACTIVITY_COMMUNICATION_UINT32, "UINT32"}, + {msptiCommunicationDataType::MSPTI_ACTIVITY_COMMUNICATION_UINT64, "UINT64"}, + {msptiCommunicationDataType::MSPTI_ACTIVITY_COMMUNICATION_FP16, "FP16"}, + {msptiCommunicationDataType::MSPTI_ACTIVITY_COMMUNICATION_FP32, "FP32"}, + {msptiCommunicationDataType::MSPTI_ACTIVITY_COMMUNICATION_FP64, "FP64"}, + {msptiCommunicationDataType::MSPTI_ACTIVITY_COMMUNICATION_BFP16, "BFP16"}, + {msptiCommunicationDataType::MSPTI_ACTIVITY_COMMUNICATION_INT128, "INT128"}, + {msptiCommunicationDataType::MSPTI_ACTIVITY_COMMUNICATION_INVALID_TYPE, "INVALID_TYPE"} +}; + +constexpr uint16_t MSTX_MARKER_TYPE = 0; +constexpr uint16_t MSTX_RANGE_TYPE = 2; +const std::vector> MSTX_EVENT_TYPE = { + {0, "marker"}, + {1, "push/pop"}, + {2, "start/end"}, + {3, "marker_ex"} +}; + +const std::vector> META_DATA = { + {"SCHEMA_VERSION_MICRO", "1"}, + {"SCHEMA_VERSION_MINOR", "1"}, + {"SCHEMA_VERSION_MAJOR", "1"}, + {"SCHEMA_VERSION", "1.1.1"} +}; + +constexpr uint16_t API_NODE_TYPE = 10000; +const std::vector> API_TYPE = { + {5000, "runtime"}, + {5500, "hccl"}, + {10000, "node"}, + {15000, "model"}, + {20000, "acl"}, + {50001, "op"}, + {50002, "queue"}, + {50003, "trace"}, + {50004, "mstx"} +}; + +uint64_t ConcatGlobalTid(uint32_t pid, uint32_t tid) +{ + constexpr uint32_t INT32_BIT_COUNT = 32; + return (static_cast(pid) << INT32_BIT_COUNT) | tid; +} + +std::string GetMsmonitorDbPath(const std::string &outputPath) +{ + auto identity = join({std::to_string(GetProcessId()), getCurrentTimestamp(), std::to_string(GetRankId())}, "_"); + return outputPath + "/msmonitor_" + identity + ".db"; +} +} // namecpace + +class IdPool : public Singleton { +public: + IdPool() = default; + ~IdPool() = default; + uint64_t GetUint64Id(const std::string &key); + StringIdFormat GetStringIdData(); + +private: + uint64_t uint64Index_{0}; + std::mutex uint64IdMapMutex_; + std::unordered_map uint64IdMap_; +}; + +uint64_t IdPool::GetUint64Id(const std::string &key) +{ + std::lock_guard lock(uint64IdMapMutex_); + auto it = uint64IdMap_.find(key); + if (it != uint64IdMap_.end()) { + return it->second; + } + uint64IdMap_.emplace(key, uint64Index_); + return uint64Index_++; +} + +StringIdFormat IdPool::GetStringIdData() +{ + std::lock_guard lock(uint64IdMapMutex_); + StringIdFormat stringIdData; + stringIdData.reserve(uint64IdMap_.size()); + for (auto it : uint64IdMap_) { + stringIdData.emplace_back(it.second, it.first); + } + return stringIdData; +} + +void DBProcessManager::SetReportInterval(uint32_t interval) +{ + if (reportInterval_.load() != interval) { + LOG(INFO) << "DBProcessManager SetReportInterval interval: " << interval; + if (IsRunning()) { + SaveData(); + } + SetInterval(interval); + reportInterval_.store(interval); + } +} + +void DBProcessManager::RunPreTask() +{ + sessionStartTime_ = getCurrentTimestamp64(); +} + +void DBProcessManager::ExecuteTask() +{ + if (!SaveData()) { + LOG(ERROR) << "DBProcessManager SaveData failed"; + } +} + +bool DBProcessManager::CheckAndInitDB() +{ + std::lock_guard lock(dbMutex_); + if(msMonitorDB_.database == nullptr || msMonitorDB_.dbRunner == nullptr) { + std::shared_ptr msMonitorDB{nullptr}; + MakeSharedPtr(msMonitorDB); + msMonitorDB_.database = msMonitorDB; + auto dbPath = GetMsmonitorDbPath(savePath_); + LOG(INFO) << "msMonitor db will be save to " << dbPath; + return msMonitorDB_.database != nullptr && msMonitorDB_.ConstructDBRunner(dbPath); + } + return true; +} + +bool DBProcessManager::SaveData() +{ + if(!CheckAndInitDB()) { + LOG(ERROR) << "DBProcessManager init msmonitor db failed"; + return false; + } + + bool flag = true; + APIFormat apiData; + CommunicationOpFormat communicationOpData; + TaskFormat taskData; + ComputeTaskInfoFormat computeTaskInfoData; + MstxFormat mstxData; + + { + std::lock_guard lock(dataMutex_); + apiData = std::move(apiData_); + communicationOpData = std::move(communicationOpData_); + taskData = std::move(taskData_); + computeTaskInfoData = std::move(computeTaskInfoData_); + mstxData = std::move(mstxData_); + } + + flag = (apiData.empty() || SaveIncDataToDB(apiData, TABLE_CANN_API)) && flag; + flag = (communicationOpData.empty() || SaveIncDataToDB(communicationOpData, TABLE_COMMUNICATION_OP)) && flag; + flag = (taskData.empty() || SaveIncDataToDB(taskData, TABLE_TASK)) && flag; + flag = (computeTaskInfoData.empty() || SaveIncDataToDB(computeTaskInfoData, TABLE_COMPUTE_TASK_INFO)) && flag; + flag = (mstxData.empty() || SaveIncDataToDB(mstxData, TABLE_MSTX)) && flag; + + return flag; +} + +bool DBProcessManager::SaveConstantData() +{ + bool flag = true; + flag = InsertDataToDB(HCCL_DATA_TYPE, TABLE_HCCL_DATA_TYPE, msMonitorDB_) && flag; + flag = InsertDataToDB(MSTX_EVENT_TYPE, TABLE_MSTX_EVENT_TYPE, msMonitorDB_) && flag; + flag = InsertDataToDB(API_TYPE, TABLE_API_TYPE, msMonitorDB_) && flag; + flag = InsertDataToDB(META_DATA, TABLE_META_DATA, msMonitorDB_) && flag; + + std::vector> hostInfoData {{GetHostUid(), GetHostName()}}; + flag = InsertDataToDB(hostInfoData, TABLE_HOST_INFO, msMonitorDB_) && flag; + + std::vector> sessionTimeInfoData {{sessionStartTime_, getCurrentTimestamp64()}}; + flag = InsertDataToDB(sessionTimeInfoData, TABLE_SESSION_TIME_INFO, msMonitorDB_) && flag; + + auto stringIdData = IdPool::GetInstance()->GetStringIdData(); + flag = (stringIdData.empty() || InsertDataToDB(stringIdData, TABLE_STRING_IDS, msMonitorDB_)) && flag; + return flag; +} + +bool DBProcessManager::SaveParallelGroupData() +{ + const std::string parallel_group_info_key = "parallel_group_info"; + auto iter = clusterConfigData.find(parallel_group_info_key); + if (iter == clusterConfigData.end()) { + LOG(WARNING) << "DBProcessManager SaveParallelGroupData parallel_group_info is not found"; + return true; + } + const std::string& parallel_group_info = iter->second; + if (!parallel_group_info.empty()) { + std::vector> data {{parallel_group_info_key, parallel_group_info}}; + return InsertDataToDB(data, TABLE_META_DATA, msMonitorDB_); + } + return true; +} + +bool DBProcessManager::SaveRankDeviceData() +{ + if (msMonitorDB_.dbRunner->CheckTableExists(TABLE_RANK_DEVICE_MAP)) { + return true; + } + if (deviceSet_.empty()) { + return false; + } + auto rankId = GetRankId(); + std::vector> rankDeviceData; + rankDeviceData.reserve(deviceSet_.size()); + for (auto deviceId : deviceSet_) { + rankDeviceData.emplace_back(rankId, deviceId); + } + if (!InsertDataToDB(rankDeviceData, TABLE_RANK_DEVICE_MAP, msMonitorDB_)) { + LOG(ERROR) << "DBProcessManager insert rank device map data failed"; + return false; + } + return true; +} + +void DBProcessManager::RunPostTask() +{ + SaveData(); + + std::lock_guard lock(dataMutex_); + if (hasSavedData_) { + if (CheckAndInitDB()) { + SaveConstantData(); + SaveParallelGroupData(); + SaveRankDeviceData(); + } else { + LOG(ERROR) << "DBProcessManager init msmonitor db failed"; + } + } + sessionStartTime_ = 0; + hasSavedData_ = false; + reportInterval_.store(DEFAULT_FLUSH_INTERVAL); + deviceSet_.clear(); + apiData_.clear(); + computeTaskInfoData_.clear(); + communicationOpData_.clear(); + taskData_.clear(); + mstxData_.clear(); + mstxRangeHostDataMap_.clear(); + mstxRangeDeviceDataMap_.clear(); + savePath_.clear(); + msMonitorDB_.database = nullptr; + msMonitorDB_.dbRunner = nullptr; +} + +void DBProcessManager::ProcessApiData(msptiActivityApi *record) +{ + std::lock_guard lock(dataMutex_); + uint64_t name = IdPool::GetInstance()->GetUint64Id(record->name); + uint64_t globalTid = ConcatGlobalTid(record->pt.processId, record->pt.threadId); + uint64_t connectionId = record->correlationId; + apiData_.emplace_back(static_cast(record->start), static_cast(record->end), + API_NODE_TYPE, globalTid, connectionId, name); +} + +std::string DBProcessManager::ConstructCommOpName(const std::string &opName, const std::string &groupName) +{ + uint64_t opCount = communicationGroupOpCount_[groupName]++; + std::string groupId; + auto it = communicationGroupNameMap_.find(groupName); + if (it == communicationGroupNameMap_.end()) { + static const size_t GROUP_ID_LEN = 3; + auto groupHashId = std::to_string(CalcHashId(groupName)); + if (groupHashId.size() >= GROUP_ID_LEN) { + groupHashId = groupHashId.substr(groupHashId.size()-GROUP_ID_LEN); + } + communicationGroupNameMap_.emplace(groupName, groupHashId); + groupId = groupHashId; + } else { + groupId = it->second; + } + return opName + "_" + groupId + "_" + std::to_string(opCount) + "_1"; +} + +void DBProcessManager::ProcessCommunicationData(msptiActivityCommunication *record) +{ + std::lock_guard lock(dataMutex_); + uint64_t groupName = IdPool::GetInstance()->GetUint64Id(record->commName); + auto commOpName = ConstructCommOpName(record->name, record->commName); + uint64_t opName = IdPool::GetInstance()->GetUint64Id(commOpName); + uint32_t opId = communicationOpId_.fetch_add(1); + uint64_t algType = IdPool::GetInstance()->GetUint64Id(record->algType); + uint64_t opType = IdPool::GetInstance()->GetUint64Id(record->name); + uint64_t connectionId = record->correlationId; + communicationOpData_.emplace_back(opName, static_cast(record->start), static_cast(record->end), + connectionId, groupName, opId, 0, 0, static_cast(record->dataType), + algType, static_cast(record->count), opType); +} + +void DBProcessManager::ProcessKernelData(msptiActivityKernel *record) +{ + std::lock_guard lock(dataMutex_); + uint64_t opName = IdPool::GetInstance()->GetUint64Id(record->name); + uint64_t taskType = IdPool::GetInstance()->GetUint64Id(record->type); + uint64_t globalTaskId = globalTaskId_.fetch_add(1); + uint64_t NAId = IdPool::GetInstance()->GetUint64Id(NA); + computeTaskInfoData_.emplace_back(opName, globalTaskId, UINT32_MAX, UINT32_MAX, taskType, + NAId, NAId, NAId, NAId, NAId, NAId, NAId, NAId, NAId, NAId); + uint64_t connectionId = record->correlationId; + uint32_t deviceId = record->ds.deviceId; + taskData_.emplace_back(static_cast(record->start), static_cast(record->end), + deviceId, connectionId, globalTaskId, GetProcessId(), taskType, UINT32_MAX, + static_cast(record->ds.streamId), UINT32_MAX, UINT32_MAX); + deviceSet_.insert(deviceId); +} + +void DBProcessManager::ProcessMstxData(msptiActivityMarker *record) +{ + std::lock_guard lock(dataMutex_); + if (record->sourceKind == msptiActivitySourceKind::MSPTI_ACTIVITY_SOURCE_KIND_HOST) { + ProcessMstxHostData(record); + } else if (record->sourceKind == msptiActivitySourceKind::MSPTI_ACTIVITY_SOURCE_KIND_DEVICE) { + ProcessMstxDeviceData(record); + } +} + +void DBProcessManager::ProcessMstxHostData(msptiActivityMarker *record) +{ + uint64_t connectionId = record->id + MSTX_CONNECTION_ID_OFFSET; + uint64_t timestamp = static_cast(record->timestamp); + uint64_t message = IdPool::GetInstance()->GetUint64Id(record->name); + uint64_t domain = IdPool::GetInstance()->GetUint64Id(record->domain); + uint64_t globalTid = ConcatGlobalTid(record->objectId.pt.processId, record->objectId.pt.threadId); + if (record->flag == msptiActivityFlag::MSPTI_ACTIVITY_FLAG_MARKER_INSTANTANEOUS || + record->flag == msptiActivityFlag::MSPTI_ACTIVITY_FLAG_MARKER_INSTANTANEOUS_WITH_DEVICE) { + mstxData_.emplace_back(timestamp, timestamp, MSTX_MARKER_TYPE, UINT32_MAX, UINT32_MAX, + message, globalTid, globalTid, domain, connectionId); + } else if (record->flag == msptiActivityFlag::MSPTI_ACTIVITY_FLAG_MARKER_START || + record->flag == msptiActivityFlag::MSPTI_ACTIVITY_FLAG_MARKER_START_WITH_DEVICE) { + mstxRangeHostDataMap_.emplace(connectionId, MstxHostData{connectionId, timestamp, globalTid, domain, message}); + } else if (record->flag == msptiActivityFlag::MSPTI_ACTIVITY_FLAG_MARKER_END || + record->flag == msptiActivityFlag::MSPTI_ACTIVITY_FLAG_MARKER_END_WITH_DEVICE) { + auto it = mstxRangeHostDataMap_.find(connectionId); + if (it != mstxRangeHostDataMap_.end()) { + mstxData_.emplace_back(it->second.timestamp, timestamp, MSTX_RANGE_TYPE, UINT32_MAX, UINT32_MAX, + it->second.message, it->second.globalTid, globalTid, it->second.domain, connectionId); + mstxRangeHostDataMap_.erase(it); + } + } +} + +void DBProcessManager::ProcessMstxDeviceData(msptiActivityMarker *record) +{ + uint64_t connectionId = record->id + MSTX_CONNECTION_ID_OFFSET; + uint64_t timestamp = static_cast(record->timestamp); + uint64_t taskType = IdPool::GetInstance()->GetUint64Id(MSTX_TASK_TYPE); + if (record->flag == msptiActivityFlag::MSPTI_ACTIVITY_FLAG_MARKER_INSTANTANEOUS_WITH_DEVICE) { + taskData_.emplace_back(timestamp, timestamp, + static_cast(record->objectId.ds.deviceId), connectionId, + globalTaskId_.fetch_add(1), GetProcessId(), taskType, UINT32_MAX, + static_cast(record->objectId.ds.streamId), UINT32_MAX, UINT32_MAX); + } else if (record->flag == msptiActivityFlag::MSPTI_ACTIVITY_FLAG_MARKER_START_WITH_DEVICE) { + mstxRangeDeviceDataMap_.emplace(connectionId, + MstxDeviceData{connectionId, timestamp, globalTaskId_.fetch_add(1)}); + } else if (record->flag == msptiActivityFlag::MSPTI_ACTIVITY_FLAG_MARKER_END_WITH_DEVICE) { + auto it = mstxRangeDeviceDataMap_.find(connectionId); + if (it != mstxRangeDeviceDataMap_.end()) { + uint32_t deviceId = static_cast(record->objectId.ds.deviceId); + taskData_.emplace_back(it->second.timestamp, timestamp, + deviceId, connectionId, it->second.globalTaskId, GetProcessId(), taskType, + UINT32_MAX, static_cast(record->objectId.ds.streamId), UINT32_MAX, UINT32_MAX); + mstxRangeDeviceDataMap_.erase(it); + deviceSet_.insert(deviceId); + } + } +} + +ErrCode DBProcessManager::ConsumeMsptiData(msptiActivity *record) +{ + if (record == nullptr) { + LOG(ERROR) << "DBProcessManager::ConsumeMsptiData record is null"; + return ErrCode::VALUE; + } + switch (record->kind) { + case msptiActivityKind::MSPTI_ACTIVITY_KIND_API: + ProcessApiData(ReinterpretConvert(record)); + break; + case msptiActivityKind::MSPTI_ACTIVITY_KIND_COMMUNICATION: + ProcessCommunicationData(ReinterpretConvert(record)); + break; + case msptiActivityKind::MSPTI_ACTIVITY_KIND_KERNEL: + ProcessKernelData(ReinterpretConvert(record)); + break; + case msptiActivityKind::MSPTI_ACTIVITY_KIND_MARKER: + ProcessMstxData(ReinterpretConvert(record)); + break; + default: + LOG(WARNING) << record->kind << " is not supported for DBProcessManager"; + break; + } + return ErrCode::SUC; +} +} // namespace db +} // namespace ipc_monitor +} // namespace dynolog_npu diff --git a/msmonitor/plugin/ipc_monitor/db/DBProcessManager.h b/msmonitor/plugin/ipc_monitor/db/DBProcessManager.h new file mode 100644 index 0000000000000000000000000000000000000000..100e56fdabab522bb637f779f3575e63ef420acc --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/db/DBProcessManager.h @@ -0,0 +1,159 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef IPC_MONITOR_DB_PROCESS_MANAGER_H +#define IPC_MONITOR_DB_PROCESS_MANAGER_H + +#include +#include +#include +#include "MsptiDataProcessBase.h" +#include "db/DBInfo.h" + +namespace dynolog_npu { +namespace ipc_monitor { +namespace db { +// STRING_IDS: id, value +using StringIdFormat = std::vector>; +// CANN_API: startNs, endNs, type, globalTid, connectionId, name +using APIFormat = std::vector>; +// COMMUNICATION_OP: opName, startNs, endNs, connectionId, groupName, +// opId, relay, retry, dataType, algType, count, opType +using CommunicationOpFormat = std::vector>; +// COMPUTE_TASK_INFO: name, globalTaskId, blockDim, mixBlockDim, taskType, opType, inputFormats, inputDataTypes, +// inputShapes, outputFormats, outputDataTypes, outputShapes, attrInfo, opState, hf32Eligible +using ComputeTaskInfoFormat = std::vector>; +// TASK: startNs, endNs, deviceId, connectionId, globalTaskId, +// globalPid, taskType, contextId, streamId, taskId, modelId +using TaskFormat = std::vector>; +// MSTX: startNs, endNs, eventType, rangeId, category, +// message, globalTid, endGlobalTid, domainId, connectionId +using MstxFormat = std::vector>; + +struct MstxHostData { + uint64_t connectionId; + uint64_t timestamp; + uint64_t globalTid; + uint64_t domain; + uint64_t message; +}; + +struct MstxDeviceData { + uint64_t connectionId; + uint64_t timestamp; + uint64_t globalTaskId; +}; + +class DBProcessManager : public MsptiDataProcessBase { +public: + DBProcessManager(std::string savePath) + : MsptiDataProcessBase("DBProcessManager"), savePath_(std::move(savePath)) {} + ~DBProcessManager() = default; + ErrCode ConsumeMsptiData(msptiActivity *record) override; + void SetReportInterval(uint32_t interval) override; + void RunPreTask() override; + void ExecuteTask() override; + void RunPostTask() override; + +private: + void ProcessApiData(msptiActivityApi *record); + void ProcessCommunicationData(msptiActivityCommunication *record); + void ProcessKernelData(msptiActivityKernel *record); + void ProcessMstxData(msptiActivityMarker *record); + void ProcessMstxHostData(msptiActivityMarker *record); + void ProcessMstxDeviceData(msptiActivityMarker *record); + bool CheckAndInitDB(); + bool SaveData(); + bool SaveConstantData(); + bool SaveParallelGroupData(); + bool SaveRankDeviceData(); + std::string ConstructCommOpName(const std::string &opName, const std::string &groupName); + template + bool SaveIncDataToDB(const std::vector> &data, const std::string &tableName); + +private: + uint64_t sessionStartTime_{0}; + std::string savePath_; + std::mutex dbMutex_; + DBInfo msMonitorDB_; + std::atomic reportInterval_{0}; + + std::mutex dataMutex_; + bool hasSavedData_{false}; + std::unordered_set deviceSet_; + // api data + APIFormat apiData_; + // communication data + std::atomic communicationOpId_{0}; + std::unordered_map communicationGroupOpCount_; + std::unordered_map communicationGroupNameMap_; + CommunicationOpFormat communicationOpData_; + // compute task info data + std::atomic globalTaskId_{0}; + ComputeTaskInfoFormat computeTaskInfoData_; + // task data + TaskFormat taskData_; + // mstx data + std::unordered_map mstxRangeHostDataMap_; + std::unordered_map mstxRangeDeviceDataMap_; + MstxFormat mstxData_; +}; + +template +bool InsertDataToDB(const std::vector> &data, const std::string &tableName, DBInfo &msMonitorDB) +{ + LOG(INFO) << "InsertDataToDB tableName: " << tableName; + if (data.empty()) { + LOG(WARNING) << tableName << " is empty"; + return true; + } + if (msMonitorDB.dbRunner == nullptr) { + LOG(ERROR) << "msMonitorDB dbRunner is null"; + return false; + } + if (msMonitorDB.database == nullptr) { + LOG(ERROR) << "msMonitorDB database is null"; + return false; + } + if (!msMonitorDB.dbRunner->CreateTable(tableName, msMonitorDB.database->GetTableCols(tableName))) { + LOG(ERROR) << "msMonitorDB " << tableName << " CreateTable failed"; + return false; + } + if (!msMonitorDB.dbRunner->InsertData(tableName, data)) { + LOG(ERROR) << "msMonitorDB " << tableName << " InsertData failed"; + return false; + } + return true; +} + +template +bool DBProcessManager::SaveIncDataToDB(const std::vector> &data, const std::string &tableName) +{ + if (data.empty()) { + LOG(WARNING) << tableName << " is empty"; + return true; + } + bool ret = InsertDataToDB(data, tableName, msMonitorDB_); + hasSavedData_ = hasSavedData_ || ret; + return ret; +} +} // namespace db +} // namespace ipc_monitor +} // namespace dynolog_npu +#endif // IPC_MONITOR_DB_PROCESS_MANAGER_H diff --git a/msmonitor/plugin/ipc_monitor/db/DBRunner.cpp b/msmonitor/plugin/ipc_monitor/db/DBRunner.cpp new file mode 100644 index 0000000000000000000000000000000000000000..171597615c2f3cc049defa64d8588c722c4c3561 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/db/DBRunner.cpp @@ -0,0 +1,160 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "db/DBRunner.h" +#include + +namespace dynolog_npu { +namespace ipc_monitor { +namespace db { +namespace { +std::string GetColumnsString(const std::vector &columns) +{ + std::vector columnStrings(columns.size()); + std::transform(columns.begin(), columns.end(), columnStrings.begin(), [](const TableColumn &column) { + return column.ToString(); + }); + return join(columnStrings, ","); +} +} + +bool DBRunner::CheckTableExists(const std::string &tableName) const +{ + std::shared_ptr conn{nullptr}; + MakeSharedPtr(conn, path_); + if (conn == nullptr) { + return false; + } + return conn->CheckTableExists(tableName); +} + +bool DBRunner::CreateTable(const std::string &tableName, const std::vector &columns) const +{ + if (tableName.empty()) { + LOG(ERROR) << "Create table failed, table name is empty"; + return false; + } + std::shared_ptr conn{nullptr}; + MakeSharedPtr(conn, path_); + if (conn == nullptr) { + return false; + } + LOG(INFO) << "Create table " << tableName; + std::string columnsString = GetColumnsString(columns); + std::string sql = "CREATE TABLE IF NOT EXISTS " + tableName + " (" + columnsString + ")"; + if (!conn->ExecuteCreateTable(sql)) { + LOG(ERROR) << "Create table " << tableName << " failed"; + return false; + } + LOG(INFO) << "Create table " << tableName << " success"; + return true; +} + +bool DBRunner::CreateIndex(const std::string &tableName, const std::string &indexName, + const std::vector &colNames) const +{ + if (tableName.empty() || indexName.empty() || colNames.empty()) { + LOG(ERROR) << "Create index failed, table name or index name or column name is empty"; + return false; + } + std::shared_ptr conn{nullptr}; + MakeSharedPtr(conn, path_); + if (conn == nullptr) { + return false; + } + LOG(INFO) << "Create index " << indexName << " on table " << tableName; + std::string valueStr = join(colNames, ","); + std::string sql = "CREATE INDEX IF NOT EXISTS " + indexName + " ON " + tableName + " (" + valueStr + ")"; + if (!conn->ExecuteCreateIndex(sql)) { + LOG(ERROR) << "Create index " << indexName << " on table " << tableName << " failed, sql: " << sql; + return false; + } + LOG(INFO) << "Create index " << indexName << " on table " << tableName << " success"; + return true; +} + +bool DBRunner::DropTable(const std::string &tableName) const +{ + if (tableName.empty()) { + LOG(ERROR) << "Drop table failed, table name is empty"; + return false; + } + std::shared_ptr conn{nullptr}; + MakeSharedPtr(conn, path_); + if (conn == nullptr) { + return false; + } + LOG(INFO) << "Drop table " << tableName; + std::string sql = "DROP TABLE " + tableName; + if (!conn->ExecuteDropTable(sql)) { + LOG(ERROR) << "Drop table " << tableName << " failed"; + return false; + } + LOG(INFO) << "Drop table " << tableName << " success"; + return true; +} + +bool DBRunner::DeleteData(const std::string &sql) const +{ + std::shared_ptr conn{nullptr}; + MakeSharedPtr(conn, path_); + if (conn == nullptr) { + return false; + } + LOG(INFO) << "Delete data, sql: " << sql; + if (!conn->ExecuteDelete(sql)) { + LOG(ERROR) << "Delete data failed, sql: " << sql; + return false; + } + LOG(INFO) << "Delete data success, sql: " << sql; + return true; +} + +bool DBRunner::UpdateData(const std::string &sql) const +{ + std::shared_ptr conn{nullptr}; + MakeSharedPtr(conn, path_); + if (conn == nullptr) { + return false; + } + LOG(INFO) << "Update data, sql: " << sql; + if (!conn->ExecuteUpdate(sql)) { + LOG(ERROR) << "Update data failed, sql: " << sql; + return false; + } + LOG(INFO) << "Update data success, sql: " << sql; + return true; +} + +std::vector DBRunner::GetTableColumns(const std::string &tableName) const +{ + std::shared_ptr conn{nullptr}; + MakeSharedPtr(conn, path_); + if (conn == nullptr) { + return {}; + } + LOG(INFO) << "Get table columns, table name: " << tableName; + auto cols = conn->ExecuteGetTableColumns(tableName); + if (cols.empty()) { + LOG(ERROR) << "Get table columns failed, table name: " << tableName; + return cols; + } + LOG(INFO) << "Get table columns success, table name: " << tableName; + return cols; +} +} // namespace db +} // namespace ipc_monitor +} // namespace dynolog_npu diff --git a/msmonitor/plugin/ipc_monitor/db/DBRunner.h b/msmonitor/plugin/ipc_monitor/db/DBRunner.h new file mode 100644 index 0000000000000000000000000000000000000000..f9619ce48c587faa3b869e3fcf5fcd88de076ba1 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/db/DBRunner.h @@ -0,0 +1,88 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef IPC_MONITOR_DB_RUNNER_H +#define IPC_MONITOR_DB_RUNNER_H +#include "db/Connection.h" +#include "utils.h" + +namespace dynolog_npu { +namespace ipc_monitor { +namespace db { +class DBRunner { +public: + explicit DBRunner(const std::string &dbPath): path_(dbPath) {}; + ~DBRunner() = default; + bool CheckTableExists(const std::string &tableName) const; + bool CreateTable(const std::string &tableName, const std::vector &cols) const; + bool CreateIndex(const std::string &tableName, const std::string &indexName, + const std::vector &colNames) const; + bool DropTable(const std::string &tableName) const; + template + bool InsertData(const std::string &tableName, const std::vector> &data) const; + bool DeleteData(const std::string &sql) const; + template + bool QueryData(const std::string &sql, std::vector> &result) const; + bool UpdateData(const std::string &sql) const; + std::vector GetTableColumns(const std::string &tableName) const; +private: + std::string path_; +}; + +template +bool DBRunner::InsertData(const std::string &tableName, const std::vector> &data) const +{ + if (tableName.empty()) { + LOG(ERROR) << "Table name is empty"; + return false; + } + LOG(INFO) << "Start insert data to " << tableName; + std::shared_ptr conn{nullptr}; + MakeSharedPtr(conn, path_); + if (conn == nullptr) { + LOG(ERROR) << "Create connection for " << tableName << " failed"; + return false; + } + if (!conn->ExecuteInsert(tableName, data)) { + LOG(ERROR) << "Insert data to " << tableName << " failed"; + return false; + } + LOG(INFO) << "Insert data to " << tableName << " success"; + return true; +} + +template +bool DBRunner::QueryData(const std::string &sql, std::vector> &result) const +{ + LOG(INFO) << "Start query data"; + std::shared_ptr conn{nullptr}; + MakeSharedPtr(conn, path_); + if (conn == nullptr) { + LOG(ERROR) << "Create connection failed: " << sql; + return false; + } + if (!conn->ExecuteQuery(sql, result)) { + LOG(ERROR) << "Query data failed: " << sql; + return false; + } + LOG(INFO) << "Query data success: " << sql; + return true; +} +} // namespace db +} // namespace ipc_monitor +} // namespace dynolog_npu + +#endif // IPC_MONITOR_DB_RUNNER_H diff --git a/msmonitor/plugin/ipc_monitor/db/DataBase.cpp b/msmonitor/plugin/ipc_monitor/db/DataBase.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a3197a5d9f974039196adcf6ec59f32dd2eb0b4e --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/db/DataBase.cpp @@ -0,0 +1,155 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "db/DataBase.h" +#include "db/DBConstant.h" + +namespace dynolog_npu { +namespace ipc_monitor { +namespace db { +namespace { +const TableColumns STRING_IDS = { + {"id", SQL_INT_TYPE}, + {"value", SQL_TEXT_TYPE} +}; + +const TableColumns SESSION_TIME_INFO = { + {"startTimeNs", SQL_INT_TYPE}, + {"endTimeNs", SQL_INT_TYPE} +}; + +const TableColumns ENUM_TABLE = { + {"id", SQL_INT_TYPE, true}, + {"name", SQL_TEXT_TYPE} +}; + +const TableColumns META_DATA = { + {"name", SQL_TEXT_TYPE}, + {"value", SQL_TEXT_TYPE} +}; + +const TableColumns HOST_INFO = { + {"hostUid", SQL_TEXT_TYPE}, + {"hostName", SQL_TEXT_TYPE} +}; + +const TableColumns RANK_DEVICE_MAP = { + {"rankId", SQL_INT_TYPE}, + {"deviceId", SQL_INT_TYPE} +}; + +const TableColumns CANN_API = { + {"startNs", SQL_INT_TYPE}, + {"endNs", SQL_INT_TYPE}, + {"type", SQL_INT_TYPE}, + {"globalTid", SQL_INT_TYPE}, + {"connectionId", SQL_INT_TYPE}, + {"name", SQL_INT_TYPE} +}; + +const TableColumns TASK = { + {"startNs", SQL_INT_TYPE}, + {"endNs", SQL_INT_TYPE}, + {"deviceId", SQL_INT_TYPE}, + {"connectionId", SQL_INT_TYPE}, + {"globalTaskId", SQL_INT_TYPE}, + {"globalPid", SQL_INT_TYPE}, + {"taskType", SQL_INT_TYPE}, + {"contextId", SQL_INT_TYPE}, + {"streamId", SQL_INT_TYPE}, + {"taskId", SQL_INT_TYPE}, + {"modelId", SQL_INT_TYPE} +}; + +const TableColumns COMPUTE_TASK_INFO = { + {"name", SQL_INT_TYPE}, + {"globalTaskId", SQL_INT_TYPE}, + {"blockDim", SQL_INT_TYPE}, + {"mixBlockDim", SQL_INT_TYPE}, + {"taskType", SQL_INT_TYPE}, + {"opType", SQL_INT_TYPE}, + {"inputFormats", SQL_INT_TYPE}, + {"inputDataTypes", SQL_INT_TYPE}, + {"inputShapes", SQL_INT_TYPE}, + {"outputFormats", SQL_INT_TYPE}, + {"outputDataTypes", SQL_INT_TYPE}, + {"outputShapes", SQL_INT_TYPE}, + {"attrInfo", SQL_INT_TYPE}, + {"opState", SQL_INT_TYPE}, + {"hf32Eligible", SQL_INT_TYPE} +}; + +const TableColumns COMMUNICATION_OP = { + {"opName", SQL_INT_TYPE}, + {"startNs", SQL_INT_TYPE}, + {"endNs", SQL_INT_TYPE}, + {"connectionId", SQL_INT_TYPE}, + {"groupName", SQL_INT_TYPE}, + {"opId", SQL_INT_TYPE}, + {"relay", SQL_INT_TYPE}, + {"retry", SQL_INT_TYPE}, + {"dataType", SQL_INT_TYPE}, + {"algType", SQL_INT_TYPE}, + {"count", SQL_NUMERIC_TYPE}, + {"opType", SQL_INT_TYPE} +}; + +const TableColumns MSTX = { + {"startNs", SQL_INT_TYPE}, + {"endNs", SQL_INT_TYPE}, + {"eventType", SQL_INT_TYPE}, + {"rangeId", SQL_INT_TYPE}, + {"category", SQL_INT_TYPE}, + {"message", SQL_INT_TYPE}, + {"globalTid", SQL_INT_TYPE}, + {"endGlobalTid", SQL_INT_TYPE}, + {"domainId", SQL_INT_TYPE}, + {"connectionId", SQL_INT_TYPE} +}; +} // namespace + +const TableColumns& Database::GetTableCols(const std::string &tableName) +{ + auto iter = tableColumns_.find(tableName); + if (iter == tableColumns_.end()) { + LOG(ERROR) << "Table " << tableName << " is not found"; + return {}; + } + return iter->second; +} + +MsMonitorDB::MsMonitorDB() +{ + dbName_ = "msmonitor.db"; + tableColumns_ = { + {TABLE_STRING_IDS, STRING_IDS}, + {TABLE_SESSION_TIME_INFO, SESSION_TIME_INFO}, + {TABLE_COMMUNICATION_OP, COMMUNICATION_OP}, + {TABLE_HCCL_DATA_TYPE, ENUM_TABLE}, + {TABLE_MSTX, MSTX}, + {TABLE_MSTX_EVENT_TYPE, ENUM_TABLE}, + {TABLE_API_TYPE, ENUM_TABLE}, + {TABLE_CANN_API, CANN_API}, + {TABLE_TASK, TASK}, + {TABLE_COMPUTE_TASK_INFO, COMPUTE_TASK_INFO}, + {TABLE_META_DATA, META_DATA}, + {TABLE_HOST_INFO, HOST_INFO}, + {TABLE_RANK_DEVICE_MAP, RANK_DEVICE_MAP} + }; +} +} // namespace db +} // namespace ipc_monitor +} // namespace dynolog_npu diff --git a/msmonitor/plugin/ipc_monitor/db/DataBase.h b/msmonitor/plugin/ipc_monitor/db/DataBase.h new file mode 100644 index 0000000000000000000000000000000000000000..2aa847016a10608e3c65ab2328efd5f54c314a70 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/db/DataBase.h @@ -0,0 +1,48 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef IPC_MONITOR_DB_BASE_H +#define IPC_MONITOR_DB_BASE_H + +#include +#include "db/Connection.h" + +namespace dynolog_npu { +namespace ipc_monitor { +namespace db { +using TableColumns = std::vector; + +class Database { +public: + Database() = default; + virtual ~Database() = default; + void SetDBName(std::string dbName) { dbName_ = std::move(dbName); } + std::string GetDBName() const { return dbName_; } + const TableColumns& GetTableCols(const std::string &tableName); +protected: + std::string dbName_; + std::unordered_map tableColumns_; +}; + +class MsMonitorDB : public Database { +public: + MsMonitorDB(); +}; +} // namespace db +} // namespace ipc_monitor +} // namespace dynolog_npu + +#endif // IPC_MONITOR_DB_BASE_H diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricApiProcess.cpp b/msmonitor/plugin/ipc_monitor/metric/MetricApiProcess.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3998d2d5c22fb798d2bf605c38e6909d83b8e3f1 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricApiProcess.cpp @@ -0,0 +1,88 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "MetricApiProcess.h" + +#include +#include + +#include "utils.h" + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +std::string ApiMetric::seriesToJson() +{ + nlohmann::json jsonMsg; + jsonMsg["kind"] = "API"; + jsonMsg["deviceId"] = -1; + jsonMsg["duration"] = duration; + jsonMsg["timestamp"] = timestamp; + return jsonMsg.dump(); +} + +void MetricApiProcess::ConsumeMsptiData(msptiActivity *record) +{ + msptiActivityApi* apiData = ReinterpretConvert(record); + std::shared_ptr tmp; + MakeSharedPtr(tmp); + if (tmp == nullptr || memcpy_s(tmp.get(), sizeof(msptiActivityApi), apiData, sizeof(msptiActivityApi)) != EOK) { + LOG(ERROR) << "memcpy_s failed " << IPC_ERROR(ErrCode::MEMORY); + return; + } + { + std::unique_lock lock(dataMutex); + records.emplace_back(std::move(tmp)); + } +} + +std::vector MetricApiProcess::AggregatedData() +{ + std::vector> copyRecords; + { + std::unique_lock lock(dataMutex); + copyRecords = std::move(records); + records.clear(); + } + if (copyRecords.empty()) { + return {}; + } + ApiMetric apiMetric{}; + auto ans = std::accumulate(copyRecords.begin(), copyRecords.end(), 0ULL, + [](uint64_t acc, std::shared_ptr api) { + return acc + api->end - api->start; + }); + apiMetric.duration = ans; + apiMetric.deviceId = -1; + apiMetric.timestamp = getCurrentTimestamp64(); + return {apiMetric}; +} + +void MetricApiProcess::SendProcessMessage() +{ + auto afterAggregated = AggregatedData(); + for (auto& metric: afterAggregated) { + SendMessage(metric.seriesToJson()); + } +} + +void MetricApiProcess::Clear() +{ + records.clear(); +} +} +} +} diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricApiProcess.h b/msmonitor/plugin/ipc_monitor/metric/MetricApiProcess.h new file mode 100644 index 0000000000000000000000000000000000000000..c9357e58eec78ebf4b67941c14c16c3747daa46f --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricApiProcess.h @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef METRIC_API_PROCESS_H +#define METRIC_API_PROCESS_H + +#include +#include +#include "MetricProcessBase.h" + + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +struct ApiMetric { + uint64_t duration; + uint64_t timestamp; + uint32_t deviceId; +public: + std::string seriesToJson(); +}; + +class MetricApiProcess : public MetricProcessBase { +public: + MetricApiProcess() = default; + void ConsumeMsptiData(msptiActivity *record) override; + std::vector AggregatedData(); + void SendProcessMessage() override; + void Clear() override; +private: + std::mutex dataMutex; + std::vector> records; +}; +} +} +} + +#endif \ No newline at end of file diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricCommunicationProcess.cpp b/msmonitor/plugin/ipc_monitor/metric/MetricCommunicationProcess.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9126b1dfd13721e4f679c6cebe887dcfa47e125f --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricCommunicationProcess.cpp @@ -0,0 +1,95 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "MetricCommunicationProcess.h" +#include +#include +#include "utils.h" + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +std::string CommunicationMetric::seriesToJson() +{ + nlohmann::json jsonMsg; + jsonMsg["kind"] = "Communication"; + jsonMsg["deviceId"] = deviceId; + jsonMsg["duration"] = duration; + jsonMsg["timestamp"] = timestamp; + return jsonMsg.dump(); +} + +void MetricCommunicationProcess::ConsumeMsptiData(msptiActivity *record) +{ + msptiActivityCommunication* communicationData = ReinterpretConvert(record); + std::shared_ptr tmp; + MakeSharedPtr(tmp); + if (tmp == nullptr || memcpy_s(tmp.get(), sizeof(msptiActivityCommunication), communicationData, sizeof(msptiActivityCommunication)) != EOK) { + LOG(ERROR) << "memcpy_s failed " << IPC_ERROR(ErrCode::MEMORY); + return; + } + { + std::unique_lock lock(dataMutex); + records.emplace_back(std::move(tmp)); + } +} + +std::vector MetricCommunicationProcess::AggregatedData() +{ + std::vector> copyRecords; + { + std::unique_lock lock(dataMutex); + copyRecords = std::move(records); + records.clear(); + } + if (copyRecords.empty()) { + return {}; + } + std::unordered_map>> deviceId2CommunicationData = + groupby(copyRecords, [](const std::shared_ptr& data) -> std::uint32_t { + return data->ds.deviceId; + }); + std::vector ans; + auto curTimestamp = getCurrentTimestamp64(); + for (auto& pair: deviceId2CommunicationData) { + CommunicationMetric communicationMetric{}; + auto& communicationDatas = pair.second; + communicationMetric.duration = std::accumulate(communicationDatas.begin(), communicationDatas.end(), 0ULL, + [](uint64_t acc, std::shared_ptr communication) { + return acc + communication->end - communication->start; + }); + communicationMetric.deviceId = pair.first; + communicationMetric.timestamp = curTimestamp; + ans.emplace_back(communicationMetric); + } + return ans; +} + +void MetricCommunicationProcess::SendProcessMessage() +{ + auto afterAggregated = AggregatedData(); + for (auto& metric: afterAggregated) { + SendMessage(metric.seriesToJson()); + } +} + +void MetricCommunicationProcess::Clear() +{ + records.clear(); +} +} +} +} diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricCommunicationProcess.h b/msmonitor/plugin/ipc_monitor/metric/MetricCommunicationProcess.h new file mode 100644 index 0000000000000000000000000000000000000000..745652f8ef57082c16de2980658253f9a2ad5364 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricCommunicationProcess.h @@ -0,0 +1,52 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef METRIC_COMMUNICATION_PROCESS_H +#define METRIC_COMMUNICATION_PROCESS_H + +#include +#include +#include "MetricProcessBase.h" + + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +struct CommunicationMetric { + std::string kindName; + uint64_t duration; + uint64_t timestamp; + uint32_t deviceId; +public: + std::string seriesToJson(); +}; + +class MetricCommunicationProcess : public MetricProcessBase { +public: + MetricCommunicationProcess() = default; + void ConsumeMsptiData(msptiActivity *record) override; + std::vector AggregatedData(); + void SendProcessMessage() override; + void Clear() override; +private: + std::mutex dataMutex; + std::vector> records; +}; +} +} +} + +#endif \ No newline at end of file diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricHcclProcess.cpp b/msmonitor/plugin/ipc_monitor/metric/MetricHcclProcess.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4ebf7d7b6faea5a457466bc2aab369e089c58a27 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricHcclProcess.cpp @@ -0,0 +1,95 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "MetricHcclProcess.h" +#include +#include +#include "utils.h" + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +std::string HcclMetric::seriesToJson() +{ + nlohmann::json jsonMsg; + jsonMsg["kind"] = "Hccl"; + jsonMsg["deviceId"] = deviceId; + jsonMsg["duration"] = duration; + jsonMsg["timestamp"] = timestamp; + return jsonMsg.dump(); +} + +void MetricHcclProcess::ConsumeMsptiData(msptiActivity *record) +{ + msptiActivityHccl* hcclData = ReinterpretConvert(record); + std::shared_ptr tmp; + MakeSharedPtr(tmp); + if (tmp == nullptr || memcpy_s(tmp.get(), sizeof(msptiActivityHccl), hcclData, sizeof(msptiActivityHccl)) != EOK) { + LOG(ERROR) << "memcpy_s failed " << IPC_ERROR(ErrCode::MEMORY); + return; + } + { + std::unique_lock lock(dataMutex); + records.emplace_back(std::move(tmp)); + } +} + +std::vector MetricHcclProcess::AggregatedData() +{ + std::vector> copyRecords; + { + std::unique_lock lock(dataMutex); + copyRecords = std::move(records); + records.clear(); + } + if (copyRecords.empty()) { + return {}; + } + std::unordered_map>> deviceId2HcclData = + groupby(copyRecords, [](const std::shared_ptr& data) -> std::uint32_t { + return data->ds.deviceId; + }); + std::vector ans; + auto curTimestamp = getCurrentTimestamp64(); + for (auto& pair: deviceId2HcclData) { + HcclMetric hcclMetric{}; + auto& hcclDatas = pair.second; + hcclMetric.duration = std::accumulate(hcclDatas.begin(), hcclDatas.end(), 0ULL, + [](uint64_t acc, std::shared_ptr hccl) { + return acc + hccl->end - hccl->start; + }); + hcclMetric.deviceId = pair.first; + hcclMetric.timestamp = curTimestamp; + ans.emplace_back(hcclMetric); + } + return ans; +} + +void MetricHcclProcess::SendProcessMessage() +{ + auto afterAggregated = AggregatedData(); + for (auto& metric: afterAggregated) { + SendMessage(metric.seriesToJson()); + } +} + +void MetricHcclProcess::Clear() +{ + records.clear(); +} +} +} +} diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricHcclProcess.h b/msmonitor/plugin/ipc_monitor/metric/MetricHcclProcess.h new file mode 100644 index 0000000000000000000000000000000000000000..2c846949d35f1dc3b0c5d359e15dc8d2818db6b5 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricHcclProcess.h @@ -0,0 +1,52 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef METRIC_HCCL_PROCESS_H +#define METRIC_HCCL_PROCESS_H + +#include +#include +#include "MetricProcessBase.h" + + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +struct HcclMetric { + std::string kindName; + uint64_t duration; + uint64_t timestamp; + uint32_t deviceId; +public: + std::string seriesToJson(); +}; + +class MetricHcclProcess : public MetricProcessBase { +public: + MetricHcclProcess() = default; + void ConsumeMsptiData(msptiActivity *record) override; + std::vector AggregatedData(); + void SendProcessMessage() override; + void Clear() override; +private: + std::mutex dataMutex; + std::vector> records; +}; +} +} +} + +#endif \ No newline at end of file diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricKernelProcess.cpp b/msmonitor/plugin/ipc_monitor/metric/MetricKernelProcess.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4013d841b4cff860d4934a618b3c64a9b32ee718 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricKernelProcess.cpp @@ -0,0 +1,96 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "MetricKernelProcess.h" + +#include + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +std::string KernelMetric::seriesToJson() +{ + nlohmann::json jsonMsg; + jsonMsg["kind"] = "Kernel"; + jsonMsg["deviceId"] = deviceId; + jsonMsg["duration"] = duration; + jsonMsg["timestamp"] = timestamp; + return jsonMsg.dump(); +} + +void MetricKernelProcess::ConsumeMsptiData(msptiActivity *record) +{ + msptiActivityKernel* kernel = ReinterpretConvert(record); + std::shared_ptr tmp; + MakeSharedPtr(tmp); + if (tmp == nullptr || memcpy_s(tmp.get(), sizeof(msptiActivityKernel), kernel, sizeof(msptiActivityKernel)) != EOK) { + LOG(ERROR) << "memcpy_s failed " << IPC_ERROR(ErrCode::MEMORY); + return; + } + { + std::unique_lock lock(dataMutex); + records.emplace_back(std::move(tmp)); + } +} + +std::vector MetricKernelProcess::AggregatedData() +{ + std::vector> copyRecords; + { + std::unique_lock lock(dataMutex); + copyRecords = std::move(records); + records.clear(); + } + if (copyRecords.empty()) { + return {}; + } + std::unordered_map>> deviceId2KernelData = + groupby(copyRecords, [](const std::shared_ptr& data) -> std::uint32_t { + return data->ds.deviceId; + }); + std::vector ans; + auto curTimestamp = getCurrentTimestamp64(); + for (auto& pair: deviceId2KernelData) { + auto deviceId = pair.first; + auto& kernelDatas = pair.second; + KernelMetric kernelMetric{}; + kernelMetric.duration = std::accumulate(kernelDatas.begin(), kernelDatas.end(), 0ULL, + [](uint64_t acc, std::shared_ptr kernel) { + return acc + kernel->end - kernel->start; + }); + kernelMetric.deviceId = deviceId; + kernelMetric.timestamp = curTimestamp; + ans.emplace_back(kernelMetric); + } + + return ans; +} + +void MetricKernelProcess::SendProcessMessage() +{ + auto afterAggregated = AggregatedData(); + for (auto& metric: afterAggregated) { + SendMessage(metric.seriesToJson()); + } +} + +void MetricKernelProcess::Clear() +{ + records.clear(); +} +} +} +} diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricKernelProcess.h b/msmonitor/plugin/ipc_monitor/metric/MetricKernelProcess.h new file mode 100644 index 0000000000000000000000000000000000000000..9bd034283ece0ba3cd5cfc5f5215b104ef37334c --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricKernelProcess.h @@ -0,0 +1,50 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef METRIC_KERNEL_PROCESS_H +#define METRIC_KERNEL_PROCESS_H + +#include +#include "MetricProcessBase.h" + + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +struct KernelMetric { + uint64_t duration; + uint64_t timestamp; + uint32_t deviceId; +public: + std::string seriesToJson(); +}; + +class MetricKernelProcess : public MetricProcessBase { +public: + MetricKernelProcess() = default; + void ConsumeMsptiData(msptiActivity *record) override; + std::vector AggregatedData(); + void SendProcessMessage() override; + void Clear() override; +private: + std::mutex dataMutex; + std::vector> records; +}; +} +} +} + +#endif \ No newline at end of file diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricManager.cpp b/msmonitor/plugin/ipc_monitor/metric/MetricManager.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e2306f0ef1287d54d61b4cdc1a5f18d8f50528ea --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricManager.cpp @@ -0,0 +1,95 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "MetricManager.h" +#include "MetricKernelProcess.h" +#include "MetricApiProcess.h" +#include "MetricMemCpyProcess.h" +#include "MetricHcclProcess.h" +#include "MetricMarkProcess.h" +#include "MetricMemSetProcess.h" +#include "MetricMemProcess.h" +#include "MetricCommunicationProcess.h" +#include "utils.h" + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +MetricManager::MetricManager(): MsptiDataProcessBase("MetricManager"), + kindSwitchs_(MSPTI_ACTIVITY_KIND_COUNT), consumeStatus_(MSPTI_ACTIVITY_KIND_COUNT) { + metrics.resize(MSPTI_ACTIVITY_KIND_COUNT); + metrics[MSPTI_ACTIVITY_KIND_KERNEL] = std::make_shared(); + metrics[MSPTI_ACTIVITY_KIND_API] = std::make_shared(); + metrics[MSPTI_ACTIVITY_KIND_MEMCPY] = std::make_shared(); + metrics[MSPTI_ACTIVITY_KIND_MARKER] = std::make_shared(); + metrics[MSPTI_ACTIVITY_KIND_MEMSET] = std::make_shared(); + metrics[MSPTI_ACTIVITY_KIND_HCCL] = std::make_shared(); + metrics[MSPTI_ACTIVITY_KIND_MEMORY] = std::make_shared(); + metrics[MSPTI_ACTIVITY_KIND_COMMUNICATION] = std::make_shared(); +} + +void MetricManager::RunPostTask() +{ + for (int i = 0; i < MSPTI_ACTIVITY_KIND_COUNT; i++) { + if (kindSwitchs_[i].load()) { + kindSwitchs_[i] = false; + metrics[i]->Clear(); + } + } +} + +ErrCode MetricManager::ConsumeMsptiData(msptiActivity *record) +{ + if (!kindSwitchs_[record->kind]) { + return ErrCode::PERMISSION; + } + auto metricProcess = metrics[record->kind]; + consumeStatus_[record->kind] = true; + metricProcess->ConsumeMsptiData(record); + consumeStatus_[record->kind] = false; + return ErrCode::SUC; +} + +void MetricManager::SetReportInterval(uint32_t intervalTimes) +{ + if (reportInterval_.load() != intervalTimes) { + SendMetricMsg(); + SetInterval(intervalTimes); + reportInterval_.store(intervalTimes); + } +} + +void MetricManager::ExecuteTask() +{ + SendMetricMsg(); +} + +void MetricManager::SendMetricMsg() +{ + for (int i = 0; i < MSPTI_ACTIVITY_KIND_COUNT; i++) { + if (kindSwitchs_[i].load()) { + metrics[i]->SendProcessMessage(); + } + } +} + +void MetricManager::EnableKindSwitch(msptiActivityKind kind, bool flag) +{ + kindSwitchs_[kind] = flag; +} +} +} +} \ No newline at end of file diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricManager.h b/msmonitor/plugin/ipc_monitor/metric/MetricManager.h new file mode 100644 index 0000000000000000000000000000000000000000..55d6aa7687fc071f2c9aa7c049f4567f6f3d7696 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricManager.h @@ -0,0 +1,48 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef METRIC_MANAGER_H +#define METRIC_MANAGER_H + +#include +#include +#include "MsptiDataProcessBase.h" +#include "MetricProcessBase.h" + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { +class MetricManager : public MsptiDataProcessBase { +public: + MetricManager(); + ~MetricManager() = default; + ErrCode ConsumeMsptiData(msptiActivity *record) override; + void SetReportInterval(uint32_t intervalTimes) override; + void ExecuteTask() override; + void EnableKindSwitch(msptiActivityKind kind, bool flag) override; + void RunPostTask() override; + +private: + void SendMetricMsg(); +private: + std::vector> kindSwitchs_; + std::vector> consumeStatus_; + std::atomic reportInterval_; + std::vector> metrics; +}; +} // namespace metric +} // namespace ipc_monitor +} // namespace dynolog_npu +#endif // METRIC_MANAGER_H diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricMarkProcess.cpp b/msmonitor/plugin/ipc_monitor/metric/MetricMarkProcess.cpp new file mode 100644 index 0000000000000000000000000000000000000000..458be03f3cdd9b6f959fb2fe1c4e579f1f46b045 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricMarkProcess.cpp @@ -0,0 +1,159 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "MetricMarkProcess.h" + +#include +#include +#include + +#include "utils.h" + + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +constexpr size_t COMPLETE_RANGE_DATA_SIZE = 4; + +std::string MarkMetric::seriesToJson() +{ + nlohmann::json jsonMsg; + jsonMsg["kind"] = "Marker"; + jsonMsg["deviceId"] = deviceId; + jsonMsg["domain"] = domain; + jsonMsg["name"] = name; + jsonMsg["duration"] = duration; + jsonMsg["timestamp"] = timestamp; + return jsonMsg.dump(); +} + +bool MetricMarkProcess::TransMarkData2Range(const std::vector>& markDatas, + RangeMarkData& rangemarkData) +{ + if (markDatas.size() != COMPLETE_RANGE_DATA_SIZE) { + return false; + } + + for (auto& activityMarker: markDatas) { + if (activityMarker->flag == MSPTI_ACTIVITY_FLAG_MARKER_START_WITH_DEVICE) { + if (activityMarker->sourceKind == MSPTI_ACTIVITY_SOURCE_KIND_DEVICE) { + rangemarkData.deviceId = activityMarker->objectId.ds.deviceId; + rangemarkData.deviceStart = activityMarker->timestamp; + } else { + rangemarkData.start = activityMarker->timestamp; + rangemarkData.name = activityMarker->name; + } + } + if (activityMarker->flag == MSPTI_ACTIVITY_FLAG_MARKER_END_WITH_DEVICE) { + if (activityMarker->sourceKind == MSPTI_ACTIVITY_SOURCE_KIND_DEVICE) { + rangemarkData.deviceEnd = activityMarker->timestamp; + } else { + rangemarkData.end = activityMarker->timestamp; + } + } + } + auto markId = markDatas[0]->id; + std::string domainName = "default"; + auto it = domainMsg.find(markId); + if (it != domainMsg.end()) { + domainName = *it->second; + } + rangemarkData.domain = domainName; + id2Marker.erase(markId); + domainMsg.erase(markId); + return true; +} + +void MetricMarkProcess::ConsumeMsptiData(msptiActivity *record) +{ + msptiActivityMarker* markerData = ReinterpretConvert(record); + std::shared_ptr tmp; + MakeSharedPtr(tmp); + if (tmp == nullptr || memcpy_s(tmp.get(), sizeof(msptiActivityMarker), markerData, sizeof(msptiActivityMarker)) != EOK) { + LOG(ERROR) << "memcpy_s failed " << IPC_ERROR(ErrCode::MEMORY); + return; + } + { + std::unique_lock lock(dataMutex); + records.emplace_back(std::move(tmp)); + if (markerData->flag == MSPTI_ACTIVITY_FLAG_MARKER_START_WITH_DEVICE && + markerData->sourceKind == MSPTI_ACTIVITY_SOURCE_KIND_HOST) { + std::string domainStr = markerData->domain; + auto markId = markerData->id; + domainMsg.emplace(markId, std::make_shared(domainStr)); + } + } +} + +std::vector MetricMarkProcess::AggregatedData() +{ + std::vector> copyRecords; + { + std::unique_lock lock(dataMutex); + copyRecords = std::move(records); + records.clear(); + } + for (auto& record: copyRecords) { + id2Marker[record->id].emplace_back(std::move(record)); + } + std::vector rangeDatas; + for (auto pair = id2Marker.rbegin(); pair != id2Marker.rend(); ++pair) { + auto markId = pair->first; + auto markDatas = pair->second; + RangeMarkData rangeMark{}; + if (TransMarkData2Range(markDatas, rangeMark)) { + rangeDatas.emplace_back(rangeMark); + } + } + + std::vector ans; + MarkMetric hostMarkMetric{}; + MarkMetric devMarkMetric{}; + for (const auto& data : rangeDatas) { + hostMarkMetric.name = data.name; + hostMarkMetric.domain = data.domain; + hostMarkMetric.deviceId = -1; + hostMarkMetric.duration = data.end - data.start; + hostMarkMetric.timestamp = data.start; + ans.emplace_back(hostMarkMetric); + + devMarkMetric.name = data.name; + devMarkMetric.domain = data.domain; + devMarkMetric.deviceId = data.deviceId; + devMarkMetric.duration = data.deviceEnd - data.deviceStart; + devMarkMetric.timestamp = data.deviceStart; + ans.emplace_back(devMarkMetric); + } + return ans; +} + +void MetricMarkProcess::SendProcessMessage() +{ + auto afterAggregated = AggregatedData(); + for (auto& metric: afterAggregated) { + SendMessage(metric.seriesToJson()); + } +} + +void MetricMarkProcess::Clear() +{ + records.clear(); + domainMsg.clear(); + id2Marker.clear(); +} +} +} +} diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricMarkProcess.h b/msmonitor/plugin/ipc_monitor/metric/MetricMarkProcess.h new file mode 100644 index 0000000000000000000000000000000000000000..35c9ddf700e349283e6b1b1880dffe2fed5a4090 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricMarkProcess.h @@ -0,0 +1,70 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef METRIC_MARK_PROCESS_H +#define METRIC_MARK_PROCESS_H + +#include +#include +#include "MetricProcessBase.h" + + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +struct MarkMetric { + std::string name; + std::string domain; + uint64_t duration; + uint64_t timestamp; + uint32_t deviceId; +public: + std::string seriesToJson(); +}; + +struct RangeMarkData { + std::string name; + std::string domain; + uint64_t duration; + uint64_t start{0}; + uint64_t end{0}; + uint64_t deviceStart{0}; + uint64_t deviceEnd{0}; + uint32_t deviceId; +}; + + +class MetricMarkProcess : public MetricProcessBase { +public: + MetricMarkProcess() = default; + void ConsumeMsptiData(msptiActivity *record) override; + std::vector AggregatedData(); + void SendProcessMessage() override; + void Clear() override; +private: + bool TransMarkData2Range(const std::vector>& markDatas, + RangeMarkData& rangemarkData); +private: + std::mutex dataMutex; + std::unordered_map> domainMsg; + std::vector> records; + std::map>> id2Marker; +}; +} +} +} + +#endif \ No newline at end of file diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricMemCpyProcess.cpp b/msmonitor/plugin/ipc_monitor/metric/MetricMemCpyProcess.cpp new file mode 100644 index 0000000000000000000000000000000000000000..aaa4fd62682a07442b3bd2dcb21462affa5244e8 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricMemCpyProcess.cpp @@ -0,0 +1,95 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "MetricMemCpyProcess.h" + +#include + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +std::string MemCpyMetric::seriesToJson() +{ + nlohmann::json jsonMsg; + jsonMsg["kind"] = "MemCpy"; + jsonMsg["deviceId"] = deviceId; + jsonMsg["duration"] = duration; + jsonMsg["timestamp"] = timestamp; + return jsonMsg.dump(); +} + +void MetricMemCpyProcess::ConsumeMsptiData(msptiActivity *record) +{ + msptiActivityMemcpy* kernel = ReinterpretConvert(record); + std::shared_ptr tmp; + MakeSharedPtr(tmp); + if (tmp == nullptr || memcpy_s(tmp.get(), sizeof(msptiActivityMemcpy), kernel, sizeof(msptiActivityMemcpy)) != EOK) { + LOG(ERROR) << "memcpy_s failed " << IPC_ERROR(ErrCode::MEMORY); + return; + } + { + std::unique_lock lock(dataMutex); + records.emplace_back(std::move(tmp)); + } +} + +std::vector MetricMemCpyProcess::AggregatedData() +{ + std::vector> copyRecords; + { + std::unique_lock lock(dataMutex); + copyRecords = std::move(records); + records.clear(); + } + if (copyRecords.empty()) { + return {}; + } + std::unordered_map>> deviceId2Memcpy = + groupby(copyRecords, [](const std::shared_ptr& data) -> std::uint32_t { + return data->deviceId; + }); + std::vector ans; + auto curTimestamp = getCurrentTimestamp64(); + for (auto& pair: deviceId2Memcpy) { + auto deviceId = pair.first; + MemCpyMetric memCpyMetric{}; + auto& memCpyDatas = pair.second; + memCpyMetric.duration = std::accumulate(memCpyDatas.begin(), memCpyDatas.end(), 0ULL, + [](uint64_t acc, std::shared_ptr memcpy) { + return acc + memcpy->end - memcpy->start; + }); + memCpyMetric.deviceId = deviceId; + memCpyMetric.timestamp = curTimestamp; + ans.emplace_back(memCpyMetric); + } + return ans; +} + +void MetricMemCpyProcess::SendProcessMessage() +{ + auto afterAggregated = AggregatedData(); + for (auto& metric: afterAggregated) { + SendMessage(metric.seriesToJson()); + } +} + +void MetricMemCpyProcess::Clear() +{ + records.clear(); +} +} +} +} diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricMemCpyProcess.h b/msmonitor/plugin/ipc_monitor/metric/MetricMemCpyProcess.h new file mode 100644 index 0000000000000000000000000000000000000000..9b3b845f31a669190ac6b480d40090c7dc2785ad --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricMemCpyProcess.h @@ -0,0 +1,50 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef METRIC_MEMCPY_PROCESS_H +#define METRIC_MEMCPY_PROCESS_H + +#include +#include "MetricProcessBase.h" + + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +struct MemCpyMetric { + uint64_t duration; + uint64_t timestamp; + uint32_t deviceId; +public: + std::string seriesToJson(); +}; + +class MetricMemCpyProcess : public MetricProcessBase { +public: + MetricMemCpyProcess() = default; + void ConsumeMsptiData(msptiActivity *record) override; + std::vector AggregatedData(); + void SendProcessMessage() override; + void Clear() override; +private: + std::mutex dataMutex; + std::vector> records; +}; +} +} +} + +#endif \ No newline at end of file diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricMemProcess.cpp b/msmonitor/plugin/ipc_monitor/metric/MetricMemProcess.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3f51476595a4fcfe29e4938c1e64d94a35815913 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricMemProcess.cpp @@ -0,0 +1,95 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "MetricMemProcess.h" + +#include + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +std::string MemMetric::seriesToJson() +{ + nlohmann::json jsonMsg; + jsonMsg["kind"] = "Memory"; + jsonMsg["deviceId"] = deviceId; + jsonMsg["duration"] = duration; + jsonMsg["timestamp"] = timestamp; + return jsonMsg.dump(); +} + +void MetricMemProcess::ConsumeMsptiData(msptiActivity *record) +{ + msptiActivityMemory* mem = ReinterpretConvert(record); + std::shared_ptr tmp; + MakeSharedPtr(tmp); + if (tmp == nullptr || memcpy_s(tmp.get(), sizeof(msptiActivityMemory), mem, sizeof(msptiActivityMemory)) != EOK) { + LOG(ERROR) << "memcpy_s failed " << IPC_ERROR(ErrCode::MEMORY); + return; + } + { + std::unique_lock lock(dataMutex); + records.emplace_back(std::move(tmp)); + } +} + +std::vector MetricMemProcess::AggregatedData() +{ + std::vector> copyRecords; + { + std::unique_lock lock(dataMutex); + copyRecords = std::move(records); + records.clear(); + } + if (copyRecords.empty()) { + return {}; + } + std::unordered_map>> deviceId2MemData = + groupby(copyRecords, [](const std::shared_ptr& data) -> std::uint32_t { + return data->deviceId; + }); + std::vector ans; + auto curTimestamp = getCurrentTimestamp64(); + for (auto& pair: deviceId2MemData) { + auto deviceId = pair.first; + auto& memDatas = pair.second; + MemMetric memMetric{}; + memMetric.duration = std::accumulate(memDatas.begin(), memDatas.end(), 0ULL, + [](uint64_t acc, std::shared_ptr mem) { + return acc + mem->end - mem->start; + }); + memMetric.deviceId = deviceId; + memMetric.timestamp = curTimestamp; + ans.emplace_back(memMetric); + } + return ans; +} + +void MetricMemProcess::SendProcessMessage() +{ + auto afterAggregated = AggregatedData(); + for (auto& metric: afterAggregated) { + SendMessage(metric.seriesToJson()); + } +} + +void MetricMemProcess::Clear() +{ + records.clear(); +} +} +} +} \ No newline at end of file diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricMemProcess.h b/msmonitor/plugin/ipc_monitor/metric/MetricMemProcess.h new file mode 100644 index 0000000000000000000000000000000000000000..c6548c18de1f0cf68ca0490b2640b15d8025ea29 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricMemProcess.h @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef METRIC_MEM_PROCESS_H +#define METRIC_MEM_PROCESS_H + +#include +#include "MetricProcessBase.h" + + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +struct MemMetric { + std::string name; + uint64_t duration; + uint64_t timestamp; + uint32_t deviceId; +public: + std::string seriesToJson(); +}; + +class MetricMemProcess : public MetricProcessBase { +public: + MetricMemProcess() = default; + void ConsumeMsptiData(msptiActivity *record) override; + std::vector AggregatedData(); + void SendProcessMessage() override; + void Clear() override; +private: + std::mutex dataMutex; + std::vector> records; +}; +} +} +} + +#endif \ No newline at end of file diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricMemSetProcess.cpp b/msmonitor/plugin/ipc_monitor/metric/MetricMemSetProcess.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f165ae2d7bdf8b82d0028a604c80fe617168c55c --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricMemSetProcess.cpp @@ -0,0 +1,95 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "MetricMemSetProcess.h" + +#include + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +std::string MemSetMetric::seriesToJson() +{ + nlohmann::json jsonMsg; + jsonMsg["kind"] = "MemSet"; + jsonMsg["deviceId"] = deviceId; + jsonMsg["duration"] = duration; + jsonMsg["timestamp"] = timestamp; + return jsonMsg.dump(); +} + +void MetricMemSetProcess::ConsumeMsptiData(msptiActivity *record) +{ + msptiActivityMemset* memSet = ReinterpretConvert(record); + std::shared_ptr tmp; + MakeSharedPtr(tmp); + if (tmp == nullptr || memcpy_s(tmp.get(), sizeof(msptiActivityMemset), memSet, sizeof(msptiActivityMemset)) != EOK) { + LOG(ERROR) << "memcpy_s failed " << IPC_ERROR(ErrCode::MEMORY); + return; + } + { + std::unique_lock lock(dataMutex); + records.emplace_back(std::move(tmp)); + } +} + +std::vector MetricMemSetProcess::AggregatedData() +{ + std::vector> copyRecords; + { + std::unique_lock lock(dataMutex); + copyRecords = std::move(records); + records.clear(); + } + if (copyRecords.empty()) { + return {}; + } + std::unordered_map>> deviceId2MemsetData = + groupby(copyRecords, [](const std::shared_ptr& data) -> std::uint32_t { + return data->deviceId; + }); + std::vector ans; + auto curTimestamp = getCurrentTimestamp64(); + for (auto& pair: deviceId2MemsetData) { + MemSetMetric memSetMetric{}; + auto deviceId = pair.first; + auto& memSetDatas = pair.second; + memSetMetric.duration = std::accumulate(memSetDatas.begin(), memSetDatas.end(), 0ULL, + [](uint64_t acc, std::shared_ptr memSet) { + return acc + memSet->end - memSet->start; + }); + memSetMetric.deviceId = deviceId; + memSetMetric.timestamp = curTimestamp; + ans.emplace_back(memSetMetric); + } + return ans; +} + +void MetricMemSetProcess::SendProcessMessage() +{ + auto afterAggregated = AggregatedData(); + for (auto& metric: afterAggregated) { + SendMessage(metric.seriesToJson()); + } +} + +void MetricMemSetProcess::Clear() +{ + records.clear(); +} +} +} +} \ No newline at end of file diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricMemSetProcess.h b/msmonitor/plugin/ipc_monitor/metric/MetricMemSetProcess.h new file mode 100644 index 0000000000000000000000000000000000000000..5d725e6edf5c4bd074d9cc1751a7dde263b52f67 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricMemSetProcess.h @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef METRIC_MEM_SET_PROCESS_H +#define METRIC_MEM_SET_PROCESS_H + +#include +#include "metric/MetricProcessBase.h" + + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +struct MemSetMetric { + std::string name; + uint64_t duration; + uint64_t timestamp; + uint32_t deviceId; +public: + std::string seriesToJson(); +}; + +class MetricMemSetProcess : public MetricProcessBase { +public: + MetricMemSetProcess() = default; + void ConsumeMsptiData(msptiActivity *record) override; + std::vector AggregatedData(); + void SendProcessMessage() override; + void Clear() override; +private: + std::mutex dataMutex; + std::vector> records; +}; +} +} +} + +#endif \ No newline at end of file diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricProcessBase.h b/msmonitor/plugin/ipc_monitor/metric/MetricProcessBase.h new file mode 100644 index 0000000000000000000000000000000000000000..2d066a9b27080116e6d87c908a0d08e837c5add9 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricProcessBase.h @@ -0,0 +1,61 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef METRIC_PROCESS_BASE_H +#define METRIC_PROCESS_BASE_H + +#include +#include + +#include "DynoLogNpuMonitor.h" +#include "NpuIpcClient.h" +#include "mspti.h" + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { +class MetricProcessBase { +public: + void SendMessage(std::string message) + { + if (message.empty()) { + LOG(ERROR) << "SendMessage message is empty"; + return; + } + static const std::string destName = DYNO_IPC_NAME + "_data"; + static const int maxRetry = 5; + static const int retryWaitTimeUs = 1000; + auto msg = Message::ConstructStrMessage(message, MSG_TYPE_DATA); + if (!msg) { + LOG(ERROR) << "ConstructStrMessage failed, message: " << message; + return; + } + auto ipcClient = DynoLogNpuMonitor::GetInstance()->GetIpcClient(); + if (!ipcClient) { + LOG(ERROR) << "DynoLogNpuMonitor ipcClient is nullptr"; + return; + } + if (!ipcClient->SyncSendMessage(*msg, destName, maxRetry, retryWaitTimeUs)) { + LOG(ERROR) << "send mspti message failed: " << message; + } + } + virtual void ConsumeMsptiData(msptiActivity *record) = 0; + virtual void Clear() = 0; + virtual void SendProcessMessage() = 0; +}; +} +} +} +#endif \ No newline at end of file diff --git a/msmonitor/plugin/ipc_monitor/mspti_monitor/MsptiDataProcessBase.h b/msmonitor/plugin/ipc_monitor/mspti_monitor/MsptiDataProcessBase.h new file mode 100644 index 0000000000000000000000000000000000000000..ce91db8472399956dff083fdacf33413c4b9bf20 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/mspti_monitor/MsptiDataProcessBase.h @@ -0,0 +1,43 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MSPTI_DATA_PROCESS_BASE_H +#define MSPTI_DATA_PROCESS_BASE_H + +#include "mspti.h" +#include "utils.h" +#include "TimerTask.h" + +namespace dynolog_npu { +namespace ipc_monitor { +class MsptiDataProcessBase : public TimerTask { +public: + explicit MsptiDataProcessBase(const std::string& name) : TimerTask(name, DEFAULT_FLUSH_INTERVAL) {} + ~MsptiDataProcessBase() = default; + void ExecuteTask() override {} + virtual void EnableKindSwitch(msptiActivityKind kind, bool flag) {} + virtual ErrCode ConsumeMsptiData(msptiActivity *record) { return ErrCode::SUC; } + virtual void SetReportInterval(uint32_t interval) {} + void SetClusterConfigData(const std::unordered_map& configData) + { + clusterConfigData = configData; + } + +protected: + std::unordered_map clusterConfigData; +}; +} // namespace ipc_monitor +} // namespace dynolog_npu +#endif // MSPTI_DATA_PROCESS_BASE_H diff --git a/msmonitor/plugin/ipc_monitor/mspti_monitor/MsptiMonitor.cpp b/msmonitor/plugin/ipc_monitor/mspti_monitor/MsptiMonitor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c3d36ed2cd0313dddc7fb938258ea78918d02f9a --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/mspti_monitor/MsptiMonitor.cpp @@ -0,0 +1,287 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "MsptiMonitor.h" + +#include +#include +#include +#include + +#include "DynoLogNpuMonitor.h" +#include "MetricManager.h" +#include "db/DBProcessManager.h" +#include "utils.h" + +namespace { +constexpr size_t DEFAULT_BUFFER_SIZE = 8 * 1024 * 1024; +constexpr size_t MAX_BUFFER_SIZE = 256 * 1024 * 1024; +constexpr uint32_t MAX_ALLOC_CNT = MAX_BUFFER_SIZE / DEFAULT_BUFFER_SIZE; +} + +namespace dynolog_npu { +namespace ipc_monitor { +MsptiMonitor::~MsptiMonitor() +{ + Uninit(); +} + +void MsptiMonitor::Start() +{ + if (start_.load()) { + return; + } + if (savePath_.empty()) { + std::shared_ptr metricManager{nullptr}; + MakeSharedPtr(metricManager); + dataProcessor_ = metricManager; + } else { + std::shared_ptr dbProcessManager{nullptr}; + MakeSharedPtr(dbProcessManager, savePath_); + dataProcessor_ = dbProcessManager; + } + if (dataProcessor_ == nullptr) { + LOG(ERROR) << "MsptiMonitor Start failed, dataProcessor init failed"; + return; + } + SetThreadName("MsptiMonitor"); + if (Thread::Start() != 0) { + LOG(ERROR) << "MsptiMonitor start failed"; + return; + } + start_.store(true); + dataProcessor_->SetReportInterval(flushInterval_); + dataProcessor_->Run(); + LOG(INFO) << "MsptiMonitor start successfully"; +} + +void MsptiMonitor::Stop() +{ + if (!start_.load()) { + LOG(WARNING) << "MsptiMonitor is not running"; + return; + } + + if (msptiActivityFlushAll(1) != MSPTI_SUCCESS) { + LOG(WARNING) << "MsptiMonitor stop msptiActivityFlushAll failed"; + } + Uninit(); + LOG(INFO) << "MsptiMonitor stop successfully"; +} + +void MsptiMonitor::Uninit() +{ + if (!start_.load()) { + return; + } + start_.store(false); + cv_.notify_one(); + Thread::Stop(); + if (dataProcessor_ != nullptr) { + dataProcessor_->Stop(); + dataProcessor_ = nullptr; + } + savePath_.clear(); +} + +bool MsptiMonitor::CheckAndSetSavePath(const std::string &path) +{ + if (path.empty()) { + LOG(ERROR) << "MsptiMonitor CheckAndSetSavePath failed, path is empty"; + return false; + } + std::string absPath = PathUtils::RelativeToAbsPath(path); + if (PathUtils::DirPathCheck(absPath)) { + std::string realPath = PathUtils::RealPath(absPath); + if (PathUtils::CreateDir(realPath)) { + savePath_ = realPath; + return true; + } + LOG(ERROR) << "MsptiMonitor CheckAndSetSavePath failed, Create save path: " << realPath << " failed."; + } else { + LOG(ERROR) << "MsptiMonitor CheckAndSetSavePath failed, save path: " << absPath << " is invalid."; + } + return false; +} + +void MsptiMonitor::EnableActivity(msptiActivityKind kind) +{ + if (MSPTI_ACTIVITY_KIND_INVALID < kind && kind < MSPTI_ACTIVITY_KIND_COUNT) { + std::lock_guard lock(activityMtx_); + if (msptiActivityEnable(kind) == MSPTI_SUCCESS) { + enabledActivities_.insert(kind); + } else { + LOG(ERROR) << "MsptiMonitor enableActivity failed, kind: " << static_cast(kind); + } + if (dataProcessor_ != nullptr) { + dataProcessor_->EnableKindSwitch(kind, true); + } + } +} + +void MsptiMonitor::DisableActivity(msptiActivityKind kind) +{ + if (MSPTI_ACTIVITY_KIND_INVALID < kind && kind < MSPTI_ACTIVITY_KIND_COUNT) { + std::lock_guard lock(activityMtx_); + if (msptiActivityDisable(kind) == MSPTI_SUCCESS) { + enabledActivities_.erase(kind); + } else { + LOG(ERROR) << "MsptiMonitor disableActivity failed, kind: " << static_cast(kind); + } + if (dataProcessor_ != nullptr) { + dataProcessor_->EnableKindSwitch(kind, false); + } + } +} + +void MsptiMonitor::SetFlushInterval(uint32_t interval) +{ + flushInterval_.store(interval); + checkFlush_.store(true); + if (start_.load()) { + cv_.notify_one(); + } + if (dataProcessor_ != nullptr) { + dataProcessor_->SetReportInterval(interval); + } +} + +bool MsptiMonitor::IsStarted() +{ + return start_.load(); +} + +std::set MsptiMonitor::GetEnabledActivities() +{ + std::lock_guard lock(activityMtx_); + return enabledActivities_; +} + +void MsptiMonitor::Run() +{ + if (msptiSubscribe(&subscriber_, nullptr, nullptr) != MSPTI_SUCCESS) { + LOG(ERROR) << "MsptiMonitor run failed, msptiSubscribe failed"; + return; + } + if (msptiActivityRegisterCallbacks(BufferRequest, BufferComplete) != MSPTI_SUCCESS) { + LOG(ERROR) << "MsptiMonitor run failed, msptiActivityRegisterCallbacks failed"; + return; + } + while (true) { + std::unique_lock lock(cvMtx_); + if (flushInterval_.load() > 0) { + cv_.wait_for(lock, std::chrono::seconds(flushInterval_.load()), + [&]() { return checkFlush_.load() || !start_.load();}); + } else { + cv_.wait(lock, [&]() { return checkFlush_.load () || !start_.load();}); + } + if (!start_.load()) { + break; + } + if (checkFlush_.load()) { + checkFlush_.store(false); + } + if (flushInterval_.load() > 0) { + if (msptiActivityFlushAll(1) != MSPTI_SUCCESS) { + LOG(ERROR) << "MsptiMonitor run msptiActivityFlushAll failed"; + } + } + } + if (msptiUnsubscribe(subscriber_) != MSPTI_SUCCESS) { + LOG(ERROR) << "MsptiMonitor run failed, msptiUnsubscribe failed"; + } + { + std::lock_guard lock(activityMtx_); + for (auto kind : enabledActivities_) { + msptiActivityDisable(kind); + } + enabledActivities_.clear(); + } + checkFlush_.store(false); + flushInterval_.store(0); +} + +std::atomic MsptiMonitor::allocCnt{0}; + +void MsptiMonitor::BufferRequest(uint8_t **buffer, size_t *size, size_t *maxNumRecords) +{ + if (buffer == nullptr || size == nullptr || maxNumRecords == nullptr) { + return; + } + *maxNumRecords = 0; + if (allocCnt.load() >= MAX_ALLOC_CNT) { + *buffer = nullptr; + *size = 0; + LOG(ERROR) << "MsptiMonitor BufferRequest failed, allocCnt: " << allocCnt.load(); + return; + } + uint8_t *pBuffer = ReinterpretConvert(MsptiMalloc(DEFAULT_BUFFER_SIZE, ALIGN_SIZE)); + if (pBuffer == nullptr) { + *buffer = nullptr; + *size = 0; + } else { + *buffer = pBuffer; + *size = DEFAULT_BUFFER_SIZE; + allocCnt++; + LOG(INFO) << "MsptiMonitor BufferRequest, size: " << *size; + } +} + +void MsptiMonitor::BufferComplete(uint8_t *buffer, size_t size, size_t validSize) +{ + if (validSize > 0 && buffer != nullptr) { + LOG(INFO) << "MsptiMonitor BufferComplete, size: " << size << ", validSize: " << validSize; + msptiActivity *record = nullptr; + msptiResult status = MSPTI_SUCCESS; + do { + status = msptiActivityGetNextRecord(buffer, validSize, &record); + if (status == MSPTI_SUCCESS) { + BufferConsume(record); + } else if (status == MSPTI_ERROR_MAX_LIMIT_REACHED) { + break; + } else { + LOG(ERROR) << "MsptiMonitor BufferComplete failed, status: " << static_cast(status); + break; + } + } while (true); + allocCnt--; + } + MsptiFree(buffer); +} + +void MsptiMonitor::BufferConsume(msptiActivity *record) +{ + if (record == nullptr) { + return; + } + auto dataProcessor = GetDataProcessor(); + if (dataProcessor != nullptr) { + dataProcessor->ConsumeMsptiData(record); + } +} + +std::shared_ptr MsptiMonitor::GetDataProcessor() +{ + return GetInstance()->dataProcessor_; +} + +void MsptiMonitor::SetClusterConfigData(const std::unordered_map& configData) +{ + if (dataProcessor_ != nullptr) { + dataProcessor_->SetClusterConfigData(configData); + } +} +} // namespace ipc_monitor +} // namespace dynolog_npu diff --git a/msmonitor/plugin/ipc_monitor/mspti_monitor/MsptiMonitor.h b/msmonitor/plugin/ipc_monitor/mspti_monitor/MsptiMonitor.h new file mode 100644 index 0000000000000000000000000000000000000000..5e988948f7a077b640388c65a82b850b70917924 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/mspti_monitor/MsptiMonitor.h @@ -0,0 +1,69 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MSPTI_MONITOR_H +#define MSPTI_MONITOR_H + +#include +#include +#include +#include +#include "mspti.h" +#include "thread.h" +#include "singleton.h" +#include "MsptiDataProcessBase.h" + + +namespace dynolog_npu { +namespace ipc_monitor { +class MsptiMonitor : public Singleton, public Thread { +public: + virtual ~MsptiMonitor(); + void Start(); + void Stop(); + void EnableActivity(msptiActivityKind kind); + void DisableActivity(msptiActivityKind kind); + void SetFlushInterval(uint32_t interval); + bool IsStarted(); + std::set GetEnabledActivities(); + void Uninit(); + bool CheckAndSetSavePath(const std::string& path); + void SetClusterConfigData(const std::unordered_map& configData); + +private: + static void BufferRequest(uint8_t **buffer, size_t *size, size_t *maxNumRecords); + static void BufferComplete(uint8_t *buffer, size_t size, size_t validSize); + static void BufferConsume(msptiActivity *record); + static std::shared_ptr GetDataProcessor(); + static std::atomic allocCnt; + +private: + void Run() override; + +private: + std::atomic start_{false}; + std::mutex cvMtx_; + std::condition_variable cv_; + msptiSubscriberHandle subscriber_{nullptr}; + std::mutex activityMtx_; + std::set enabledActivities_; + std::atomic checkFlush_{false}; + std::atomic flushInterval_{0}; + std::string savePath_; + std::shared_ptr dataProcessor_{nullptr}; +}; +} // namespace ipc_monitor +} // namespace dynolog_npu +#endif // MSPTI_MONITOR_H diff --git a/msmonitor/plugin/ipc_monitor/mspti_monitor/mspti.h b/msmonitor/plugin/ipc_monitor/mspti_monitor/mspti.h new file mode 100644 index 0000000000000000000000000000000000000000..baa8a201dd2840852bfce5c407fa43412077aeab --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/mspti_monitor/mspti.h @@ -0,0 +1,295 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MSPTI_STUB_H +#define MSPTI_STUB_H + +constexpr int ACTIVITY_STRUCT_ALIGNMENT = 8; +#if defined(_WIN32) +#define START_PACKED_ALIGNMENT __pragma(pack(push, 1)) +#define PACKED_ALIGNMENT __declspec(align(ACTIVITY_STRUCT_ALIGNMENT)) +#define END_PACKED_ALIGNMENT __pragma(pack(pop)) +#elif defined(__GNUC__) +#define START_PACKED_ALIGNMENT +#define PACKED_ALIGNMENT __attribute__((__packed__)) __attribute__((aligned(ACTIVITY_STRUCT_ALIGNMENT))) +#define END_PACKED_ALIGNMENT +#else +#define START_PACKED_ALIGNMENT +#define PACKED_ALIGNMENT +#define END_PACKED_ALIGNMENT +#endif + +#include +#include + +#define MSPTI_INVALID_DEVICE_ID ((uint32_t) 0xFFFFFFFFU) +#define MSPTI_INVALID_STREAM_ID ((uint32_t) 0xFFFFFFFFU) +#define MSPTI_INVALID_CORRELATION_ID ((uint64_t) 0) +using msptiCallbackId = uint32_t; + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef enum { + MSPTI_SUCCESS = 0, + MSPTI_ERROR_INVALID_PARAMETER = 1, + MSPTI_ERROR_MULTIPLE_SUBSCRIBERS_NOT_SUPPORTED = 2, + MSPTI_ERROR_MAX_LIMIT_REACHED = 3, + MSPTI_ERROR_DEVICE_OFFLINE = 4, + MSPTI_ERROR_QUERY_EMPTY = 5, + MSPTI_ERROR_INNER = 999, + MSPTI_ERROR_FOECE_INT = 0x7fffffff +} msptiResult; + +typedef enum { + MSPTI_CB_DOMAIN_INVALID = 0, + MSPTI_CB_DOMAIN_RUNTIME = 1, + MSPTI_CB_DOMAIN_HCCL = 2, + MSPTI_CB_DOMAIN_SIZE, + MSPTI_CB_DOMAIN_FORCE_INT = 0x7fffffff +} msptiCallbackDomain; + +typedef enum { + MSPTI_API_ENTER = 0, + MSPTI_API_EXIT = 1, + MSPTI_API_CBSITE_FORCE_INT = 0x7fffffff +} msptiApiCallbackSite; + +typedef struct { + msptiApiCallbackSite callbackSite; + const char *functionName; + const void *functionParams; + const void *functionReturnValue; + const char *symbolName; + uint64_t correlationId; + uint64_t reserved1; + uint64_t reserved2; + uint64_t *correlationData; +} msptiCallbackData; + +typedef enum { + MSPTI_ACTIVITY_KIND_INVALID = 0, + MSPTI_ACTIVITY_KIND_MARKER = 1, + MSPTI_ACTIVITY_KIND_KERNEL = 2, + MSPTI_ACTIVITY_KIND_API = 3, + MSPTI_ACTIVITY_KIND_HCCL = 4, + MSPTI_ACTIVITY_KIND_MEMORY = 5, + MSPTI_ACTIVITY_KIND_MEMSET = 6, + MSPTI_ACTIVITY_KIND_MEMCPY = 7, + MSPTI_ACTIVITY_KIND_EXTERNAL_CORRELATION = 8, + MSPTI_ACTIVITY_KIND_COMMUNICATION = 9, + MSPTI_ACTIVITY_KIND_COUNT, + MSPTI_ACTIVITY_KIND_FORCE_INT = 0x7fffffff +} msptiActivityKind; + +typedef enum { + MSPTI_ACTIVITY_FLAG_NONE = 0, + MSPTI_ACTIVITY_FLAG_MARKER_INSTANTANEOUS = 1 << 0, + MSPTI_ACTIVITY_FLAG_MARKER_START = 1 << 1, + MSPTI_ACTIVITY_FLAG_MARKER_END = 1 << 2, + MSPTI_ACTIVITY_FLAG_MARKER_INSTANTANEOUS_WITH_DEVICE = 1 << 3, + MSPTI_ACTIVITY_FLAG_MARKER_START_WITH_DEVICE = 1 << 4, + MSPTI_ACTIVITY_FLAG_MARKER_END_WITH_DEVICE = 1 << 5 +} msptiActivityFlag; + +typedef enum { + MSPTI_ACTIVITY_SOURCE_KIND_HOST = 0, + MSPTI_ACTIVITY_SOURCE_KIND_DEVICE = 1 +} msptiActivitySourceKind; + +typedef enum { + MSPTI_ACTIVITY_MEMORY_OPERATION_TYPE_ALLOCATATION = 0, + MSPTI_ACTIVITY_MEMORY_OPERATION_TYPE_RELEASE = 1 +} msptiActivityMemoryOperationType; + +typedef enum { + MSPTI_ACTIVITY_MEMORY_KIND_UNKNOWN = 0, + MSPTI_ACTIVITY_MEMORY_KIND_DEVICE = 1 +} msptiActivityMemoryKind; + +typedef enum { + MSPTI_ACTIVITY_MEMCPY_KIND_UNKNOWN = 0, + MSPTI_ACTIVITY_MEMCPY_KIND_HTOH = 1, + MSPTI_ACTIVITY_MEMCPY_KIND_HTOD = 2, + MSPTI_ACTIVITY_MEMCPY_KIND_DTOH = 3, + MSPTI_ACTIVITY_MEMCPY_KIND_DTOD = 4, + MSPTI_ACTIVITY_MEMCPY_KIND_DEFAULT = 5 +} msptiActivityMemcpyKind; + +typedef enum { + MSPTI_ACTIVITY_COMMUNICATION_INT8 = 0, + MSPTI_ACTIVITY_COMMUNICATION_INT16 = 1, + MSPTI_ACTIVITY_COMMUNICATION_INT32 = 2, + MSPTI_ACTIVITY_COMMUNICATION_FP16 = 3, + MSPTI_ACTIVITY_COMMUNICATION_FP32 = 4, + MSPTI_ACTIVITY_COMMUNICATION_INT64 = 5, + MSPTI_ACTIVITY_COMMUNICATION_UINT64 = 6, + MSPTI_ACTIVITY_COMMUNICATION_UINT8 = 7, + MSPTI_ACTIVITY_COMMUNICATION_UINT16 = 8, + MSPTI_ACTIVITY_COMMUNICATION_UINT32 = 9, + MSPTI_ACTIVITY_COMMUNICATION_FP64 = 10, + MSPTI_ACTIVITY_COMMUNICATION_BFP16 = 11, + MSPTI_ACTIVITY_COMMUNICATION_INT128 = 12, + MSPTI_ACTIVITY_COMMUNICATION_INVALID_TYPE = 0x0000FFFF +} msptiCommunicationDataType; + +START_PACKED_ALIGNMENT + +typedef union PACKED_ALIGNMENT { + struct { + uint32_t processId; + uint32_t threadId; + } pt; + struct { + uint32_t deviceId; + uint32_t streamId; + } ds; +} msptiObjectId; + +typedef struct PACKED_ALIGNMENT { + msptiActivityKind kind; +} msptiActivity; + +typedef struct PACKED_ALIGNMENT { + msptiActivityKind kind; + uint64_t start; + uint64_t end; + struct { + uint32_t processId; + uint32_t threadId; + } pt; + uint64_t correlationId; + const char* name; +} msptiActivityApi; + +typedef struct PACKED_ALIGNMENT { + msptiActivityKind kind; + uint64_t start; + uint64_t end; + struct { + uint32_t deviceId; + uint32_t streamId; + } ds; + uint64_t correlationId; + const char *type; + const char *name; +} msptiActivityKernel; + +typedef struct PACKED_ALIGNMENT { + msptiActivityKind kind; + msptiActivityFlag flag; + msptiActivitySourceKind sourceKind; + uint64_t timestamp; + uint64_t id; + msptiObjectId objectId; + const char *name; + const char *domain; +} msptiActivityMarker; + +typedef struct PACKED_ALIGNMENT { + msptiActivityKind kind; + uint64_t start; + uint64_t end; + struct { + uint32_t deviceId; + uint32_t streamId; + } ds; + double bandWidth; + const char *name; + const char *commName; +} msptiActivityHccl; + +typedef struct PACKED_ALIGNMENT { + msptiActivityKind kind; + msptiActivityMemoryOperationType memoryOperationType; + msptiActivityMemoryKind memoryKind; + uint64_t correlationId; + uint64_t start; + uint64_t end; + uint64_t address; + uint64_t bytes; + uint32_t processId; + uint32_t deviceId; + uint32_t streamId; +} msptiActivityMemory; + +typedef struct PACKED_ALIGNMENT { + msptiActivityKind kind; + uint32_t value; + uint64_t bytes; + uint64_t start; + uint64_t end; + uint32_t deviceId; + uint32_t streamId; + uint64_t correlationId; + uint8_t isAsync; +} msptiActivityMemset; + +typedef struct PACKED_ALIGNMENT { + msptiActivityKind kind; + msptiActivityMemcpyKind copyKind; + uint64_t bytes; + uint64_t start; + uint64_t end; + uint32_t deviceId; + uint32_t streamId; + uint64_t correlationId; + uint8_t isAsync; +} msptiActivityMemcpy; + +typedef struct PACKED_ALIGNMENT { + msptiActivityKind kind; + msptiCommunicationDataType dataType; + uint64_t count; + struct { + uint32_t deviceId; + uint32_t streamId; + } ds; + uint64_t start; + uint64_t end; + const char* algType; + const char* name; + const char* commName; + uint64_t correlationId; +} msptiActivityCommunication; + +END_PACKED_ALIGNMENT + +typedef void(*msptiCallbackFunc)(void* userdata, msptiCallbackDomain domain, msptiCallbackId cbid, const msptiCallbackData *cbdata); +typedef void(*msptiBuffersCallbackRequestFunc)(uint8_t **buffer, size_t *size, size_t *maxNumRecords); +typedef void(*msptiBuffersCallbackCompleteFunc)(uint8_t *buffer, size_t size, size_t validSize); + +struct msptiSubscriber_st { + msptiCallbackFunc callback; + void *userdata; +}; + +typedef struct msptiSubscriber_st *msptiSubscriberHandle; + +msptiResult msptiSubscribe(msptiSubscriberHandle *subscriber, msptiCallbackFunc callback, void *userdata); +msptiResult msptiUnsubscribe(msptiSubscriberHandle subscriber); +msptiResult msptiActivityRegisterCallbacks(msptiBuffersCallbackRequestFunc funcBufferRequested, msptiBuffersCallbackCompleteFunc funcBufferCompleted); +msptiResult msptiActivityEnable(msptiActivityKind kind); +msptiResult msptiActivityDisable(msptiActivityKind kind); +msptiResult msptiActivityGetNextRecord(uint8_t *buffer, size_t validBufferSizeBytes, msptiActivity **record); +msptiResult msptiActivityFlushAll(uint32_t flag); +msptiResult msptiActivityEnableMarkerDomain(const char* name); +msptiResult msptiActivityDisableMarkerDomain(const char* name); + +#ifdef __cplusplus +} +#endif // __cplusplus +#endif // MSPTI_STUB_H diff --git a/dynolog_npu/plugin/ipc_monitor/singleton.h b/msmonitor/plugin/ipc_monitor/singleton.h similarity index 48% rename from dynolog_npu/plugin/ipc_monitor/singleton.h rename to msmonitor/plugin/ipc_monitor/singleton.h index 8bb106f3adc8b365ef81feb603c6aaac917a00e2..5143f404f19a2c6d2da96e171bb39e1cd2b549b6 100644 --- a/dynolog_npu/plugin/ipc_monitor/singleton.h +++ b/msmonitor/plugin/ipc_monitor/singleton.h @@ -1,31 +1,47 @@ -#ifndef SINGLETON_H -#define SINGLETON_H -#include - -namespace dynolog_npu { -namespace ipc_monitor { - -template -class Singleton { -public: - static T *GetInstance() noexcept(std::is_nothrow_constructible::value) { - static T instance; - return &instance; - } - - virtual ~Singleton() = default; - -protected: - explicit Singleton() = default; - -private: - explicit Singleton(const Singleton &obj) = delete; - Singleton& operator=(const Singleton &obj) = delete; - explicit Singleton(Singleton &&obj) = delete; - Singleton& operator=(Singleton &&obj) = delete; -}; - -} // ipc_monitor -} // dynolog_npu - +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef SINGLETON_H +#define SINGLETON_H +#include + +namespace dynolog_npu { +namespace ipc_monitor { + +template +class Singleton { +public: + static T *GetInstance() noexcept(std::is_nothrow_constructible::value) + { + static T instance; + return &instance; + } + + virtual ~Singleton() = default; + +protected: + explicit Singleton() = default; + +private: + explicit Singleton(const Singleton &obj) = delete; + Singleton& operator=(const Singleton &obj) = delete; + explicit Singleton(Singleton &&obj) = delete; + Singleton& operator=(Singleton &&obj) = delete; +}; + +} // ipc_monitor +} // dynolog_npu + #endif \ No newline at end of file diff --git a/msmonitor/plugin/ipc_monitor/thread.h b/msmonitor/plugin/ipc_monitor/thread.h new file mode 100644 index 0000000000000000000000000000000000000000..b674cbb6cb86c3bf6177da81684549211ef92ee8 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/thread.h @@ -0,0 +1,90 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef IPC_MONITOR_THREAD_H +#define IPC_MONITOR_THREAD_H + +#include +#include +#include +#include +#include "utils.h" + +namespace dynolog_npu { +namespace ipc_monitor { +class Thread { +public: + Thread() + : is_alive_(false), + pid_(0), + thread_name_("IPCMonitor") {} + + ~Thread() + { + if (is_alive_) { + (void)pthread_cancel(pid_); + (void)pthread_join(pid_, nullptr); + } + } + + void SetThreadName(const std::string &name) + { + if (!name.empty()) { + thread_name_ = name; + } + } + + std::string GetThreadName() + { + return thread_name_; + } + + int Start() + { + int ret = pthread_create(&pid_, nullptr, Execute, ReinterpretConvert(this)); + is_alive_ = (ret == 0) ? true : false; + return ret; + } + + int Stop() + { + return Join(); + } + + int Join() + { + int ret = pthread_join(pid_, nullptr); + is_alive_ = (ret == 0) ? false : true; + return ret; + } + +private: + static void* Execute(void *args) + { + Thread *thr = ReinterpretConvert(args); + prctl(PR_SET_NAME, ReinterpretConvert(thr->GetThreadName().data())); + thr->Run(); + return nullptr; + } + virtual void Run() = 0; + +private: + bool is_alive_; + pthread_t pid_; + std::string thread_name_; +}; +} // ipc_monitor +} // dynolog_npu +#endif // IPC_MONITOR_THREAD_H diff --git a/msmonitor/plugin/ipc_monitor/utils.cpp b/msmonitor/plugin/ipc_monitor/utils.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f78889791229acaf17211ea1e13ca3010cb01ade --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/utils.cpp @@ -0,0 +1,573 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace dynolog_npu { +namespace ipc_monitor { +namespace { +template +std::string IntToHexStr(T number) +{ + std::stringstream strStream; + strStream << std::hex << number; + return strStream.str(); +} +} // namespace + +std::unordered_map submoduleMap = { + {SubModule::IPC, "IPC"}, +}; + +std::unordered_map errCodeMap = { + {ErrCode::SUC, "success"}, + {ErrCode::PARAM, "invalid parameter"}, + {ErrCode::TYPE, "invalid type"}, + {ErrCode::VALUE, "invalid value"}, + {ErrCode::PTR, "invalid pointer"}, + {ErrCode::INTERNAL, "internal error"}, + {ErrCode::MEMORY, "memory error"}, + {ErrCode::NOT_SUPPORT, "feature not supported"}, + {ErrCode::NOT_FOUND, "resource not found"}, + {ErrCode::UNAVAIL, "resource unavailable"}, + {ErrCode::SYSCALL, "system call failed"}, + {ErrCode::TIMEOUT, "timeout error"}, + {ErrCode::PERMISSION, "permission error"}, +}; + +std::string getCurrentTimestamp() +{ + auto now = std::chrono::system_clock::now(); + auto micros = std::chrono::duration_cast(now.time_since_epoch()); + + std::ostringstream oss; + std::time_t currentTime = std::chrono::system_clock::to_time_t(now); + std::tm timeInfo; + if (localtime_r(¤tTime, &timeInfo) != nullptr) { + auto milli_time = std::chrono::duration_cast(micros).count() % 1000; + auto micro_time = micros.count() % 1000; + + oss << std::put_time(&timeInfo, "%Y%m%d%H%M%S"); + constexpr int kMilliTimeWidth = 3; + oss << std::setw(kMilliTimeWidth) << std::setfill('0') << milli_time; + } + + return oss.str(); +} + +uint64_t getCurrentTimestamp64() +{ + auto now = std::chrono::system_clock::now(); + auto ns = std::chrono::duration_cast(now.time_since_epoch()); + return ns.count(); +} + +std::string formatErrorCode(SubModule submodule, ErrCode errorCode) +{ + std::ostringstream oss; + oss << "\n[ERROR] " << getCurrentTimestamp() << " (PID:" << getpid() << ")"; + oss << "ERR" << std::setw(2) << std::setfill('0') << static_cast(submodule); // 2: 字段宽度 + oss << std::setw(3) << std::setfill('0') << static_cast(errorCode); // 3: 字段宽度 + oss << " " << submoduleMap[submodule] << " " << errCodeMap[errorCode]; + return oss.str(); +}; + +int32_t GetProcessId() +{ + static int32_t pid = []() -> int32_t { + return static_cast(getpid()); + }(); + return pid; +} + +bool ParseProcStat(const std::string& line, std::string& command, int& parentPid) +{ + size_t lparen = line.find('('); + size_t rparen = line.rfind(')'); + if (lparen == std::string::npos || rparen == std::string::npos || rparen <= lparen + 1) { + LOG(WARNING) << "cannot find command name: " << line; + return false; + } + command = line.substr(lparen + 1, rparen - lparen - 1); + + std::string afterCmd = line.substr(rparen + 1); + std::istringstream iss(afterCmd); + std::string state; + int ppid; + if (!(iss >> state >> ppid)) { + LOG(WARNING) << "Failed to parse state/ppid from: " << afterCmd; + return false; + } + parentPid = ppid; + return true; +} + +std::pair GetParentPidAndCommand(int32_t pid) +{ + std::string fileName = "/proc/" + std::to_string(pid) + "/stat"; + std::ifstream statFile(fileName); + if (!statFile) { + return std::make_pair(0, ""); + } + int32_t parentPid = 0; + std::string command; + std::string line; + if (std::getline(statFile, line)) { + bool ret = ParseProcStat(line, command, parentPid); + if (ret) { + return std::make_pair(parentPid, command); + } + } + LOG(WARNING) << "Failed to parse /proc/" << pid << "/stat"; + return std::make_pair(0, ""); +} + +std::vector> GetPidCommandPairsofAncestors() +{ + std::vector> process_pids_and_cmds; + process_pids_and_cmds.reserve(MaxParentPids + 1); + int32_t current_pid = GetProcessId(); + for (int i = 0; i <= MaxParentPids && (i == 0 || current_pid > 1); i++) { + std::pair parent_pid_and_cmd = GetParentPidAndCommand(current_pid); + process_pids_and_cmds.push_back(std::make_pair(current_pid, parent_pid_and_cmd.second)); + current_pid = parent_pid_and_cmd.first; + } + return process_pids_and_cmds; +} + +std::vector GetPids() +{ + const auto &pids = GetPidCommandPairsofAncestors(); + std::vector res; + res.reserve(pids.size()); + for (const auto &pidPair : pids) { + res.push_back(pidPair.first); + } + LOG(INFO) << "Success to get parent pid: " << res; + return res; +} + +std::string GenerateUuidV4() +{ + static std::random_device randomDevice; + static std::mt19937 gen(randomDevice()); + static std::uniform_int_distribution<> dis(0, 15); // range (0, 15) + static std::uniform_int_distribution<> dis2(8, 11); // range (8, 11) + + std::stringstream stringStream; + stringStream << std::hex; + for (int i = 0; i < 8; i++) { // 8 times + stringStream << dis(gen); + } + stringStream << "-"; + for (int j = 0; j < 4; j++) { // 4 times + stringStream << dis(gen); + } + stringStream << "-4"; // add -4 + for (int k = 0; k < 3; k++) { // 3 times + stringStream << dis(gen); + } + stringStream << "-"; + stringStream << dis2(gen); + for (int m = 0; m < 3; m++) { // 3 times + stringStream << dis(gen); + } + stringStream << "-"; + for (int n = 0; n < 12; n++) { // 12 times + stringStream << dis(gen); + } + return stringStream.str(); +} + +bool Str2Uint32(uint32_t& dest, const std::string& str) +{ + if (str.empty()) { + LOG(ERROR) << "Str to uint32 failed, input string is null"; + return false; + } + size_t pos = 0; + try { + dest = static_cast(std::stoul(str, &pos)); + } catch(...) { + LOG(ERROR) << "Str to uint32 failed, input string is " << str; + return false; + } + if (pos != str.size()) { + LOG(ERROR) << "Str to uint32 failed, input string is " << str; + return false; + } + return true; +} + +bool Str2Int32(int32_t& dest, const std::string& str) +{ + if (str.empty()) { + LOG(ERROR) << "Str to int32 failed, input string is null"; + return false; + } + size_t pos = 0; + try { + dest = static_cast(std::stol(str, &pos)); + } catch(...) { + LOG(ERROR) << "Str to int32 failed, input string is " << str; + return false; + } + if (pos != str.size()) { + LOG(ERROR) << "Str to int32 failed, input string is " << str; + return false; + } + return true; +} + +bool Str2Bool(bool& dest, const std::string& str) +{ + std::string lower_str = str; + std::transform(lower_str.begin(), lower_str.end(), lower_str.begin(), ::tolower); + + if (lower_str == "true" || lower_str == "1") { + dest = true; + return true; + } + + if (lower_str == "false" || lower_str == "0") { + dest = false; + return true; + } + LOG(ERROR) << "Str to bool failed, input string is " << str; + return false; +} + +std::string& trim(std::string& str) +{ + if (str.empty()) { + return str; + } + str.erase(0, str.find_first_not_of(" ")); + str.erase(str.find_last_not_of(" ") + 1); + return str; +} + +// split函数 +std::vector split(const std::string& str, char delimiter) +{ + std::vector tokens; + std::string token; + std::istringstream tokenStream(str); + + while (std::getline(tokenStream, token, delimiter)) { + tokens.push_back(token); + } + + return tokens; +} + +std::string join(const std::vector &strs, const std::string &delimiter) +{ + std::stringstream ss; + for (size_t i = 0, len = strs.size(); i < len; ++i) { + ss << strs[i] << (i == len - 1 ? "" : delimiter); + } + return ss.str(); +} + +void *MsptiMalloc(size_t size, size_t alignment) +{ + if (alignment > 0) { + size = (size + alignment - 1) / alignment * alignment; + } +#if defined(_POSIX_C_SOURCE) && _POSIX_C_SOURCE >= 200112L + void *ptr = nullptr; + if (posix_memalign(&ptr, alignment, size) != 0) { + ptr = nullptr; + } + return ptr; +#else + return malloc(size); +#endif +} + +void MsptiFree(uint8_t *ptr) +{ + if (ptr != nullptr) { + free(ptr); + } +} + +bool PathUtils::IsFileExist(const std::string &path) +{ + if (path.empty() || path.size() > PATH_MAX) { + return false; + } + return access(path.c_str(), F_OK) == 0; +} + +bool PathUtils::IsFileWritable(const std::string &path) +{ + if (path.empty() || path.size() > PATH_MAX) { + return false; + } + return access(path.c_str(), W_OK) == 0; +} + +bool PathUtils::IsDir(const std::string &path) +{ + if (path.empty() || path.size() > PATH_MAX) { + return false; + } + struct stat st{}; + int ret = lstat(path.c_str(), &st); + if (ret != 0) { + return false; + } + return S_ISDIR(st.st_mode); +} + +bool PathUtils::CreateDir(const std::string &path) +{ + if (path.empty() || path.size() > PATH_MAX) { + return false; + } + if (IsFileExist(path)) { + return IsDir(path); + } + size_t pos = 0; + while ((pos = path.find_first_of('/', pos)) != std::string::npos) { + std::string baseDir = path.substr(0, ++pos); + if (IsFileExist(baseDir)) { + if (IsDir(baseDir)) { + continue; + } else { + return false; + } + } + if (mkdir(baseDir.c_str(), DATA_DIR_AUTHORITY) != 0) { + if (errno != EEXIST) { + return false; + } + } + } + auto ret = mkdir(path.c_str(), DATA_DIR_AUTHORITY); + return (ret == 0 || errno == EEXIST) ? true : false; +} + +std::string PathUtils::RealPath(const std::string &path) +{ + if (path.empty() || path.size() > PATH_MAX) { + return ""; + } + char realPath[PATH_MAX] = {0}; + if (realpath(path.c_str(), realPath) == nullptr) { + return ""; + } + return std::string(realPath); +} + +std::string PathUtils::RelativeToAbsPath(const std::string &path) +{ + if (path.empty() || path.size() > PATH_MAX) { + return ""; + } + if (path[0] != '/') { + char pwdPath[PATH_MAX] = {0}; + if (getcwd(pwdPath, PATH_MAX) != nullptr) { + return std::string(pwdPath) + "/" + path; + } + return ""; + } + return std::string(path); +} + +std::string PathUtils::DirName(const std::string &path) +{ + if (path.empty()) { + return ""; + } + std::string tempPath = std::string(path.begin(), path.end()); + char* cPath = dirname(const_cast(tempPath.data())); + return cPath ? std::string(cPath) : ""; +} + +bool PathUtils::CreateFile(const std::string &path) +{ + if (path.empty() || path.size() > PATH_MAX || !CreateDir(DirName(path))) { + return false; + } + int fd = creat(path.c_str(), DATA_FILE_AUTHORITY); + return (fd < 0 || close(fd) != 0) ? false : true; +} + +bool PathUtils::IsSoftLink(const std::string &path) +{ + if (path.empty() || path.size() > PATH_MAX || !IsFileExist(path)) { + return false; + } + struct stat st{}; + if (lstat(path.c_str(), &st) != 0) { + return false; + } + return S_ISLNK(st.st_mode); +} + +bool PathUtils::DirPathCheck(const std::string& path) +{ + if (path.empty() || path.size() > PATH_MAX) { + fprintf(stderr, "[ERROR] The length of Path %s is invalid.\n", path.c_str()); + return false; + } + if (IsSoftLink(path)) { + fprintf(stderr, "[ERROR] Path %s is soft link.\n", path.c_str()); + return false; + } + if (!IsFileExist(path) && !CreateDir(path)) { + fprintf(stderr, "[ERROR] Path %s not exist and create failed.\n", path.c_str()); + return false; + } + if (!IsDir(path) || !IsFileWritable(path)) { + fprintf(stderr, "[ERROR] %s is not a directory or is not writable.\n", path.c_str()); + return false; + } + return true; +} + +int GetRankId() +{ + static int rankId = []() -> int { + pybind11::gil_scoped_acquire gil; + return pybind11::module::import("IPCMonitor.utils").attr("get_rank_id")().cast(); + }(); + return rankId; +} + +uint64_t CalcHashId(const std::string &data) +{ + static const uint32_t UINT32_BITS = 32; + uint32_t prime[2] = {29, 131}; + uint32_t hash[2] = {0}; + for (char d : data) { + hash[0] = hash[0] * prime[0] + static_cast(d); + hash[1] = hash[1] * prime[1] + static_cast(d); + } + return (static_cast(hash[0]) << UINT32_BITS) | hash[1]; +} + +std::string GetHostName() +{ + char hostName[PATH_MAX] = {0}; + if (gethostname(hostName, PATH_MAX) != 0) { + return ""; + } + return std::string(hostName); +} + +std::string GetHostUid() +{ + static const uint8_t SECOND_LEAST_BIT = 1 << 1; + struct ifaddrs *ifaddr = nullptr; + if (getifaddrs(&ifaddr) == -1) { + if (ifaddr != nullptr) { + freeifaddrs(ifaddr); + } + return 0; + } + std::vector universalMacAddrs; + std::vector localMacAddrs; + for (struct ifaddrs *ifa = ifaddr; ifa != nullptr; ifa = ifa->ifa_next) { + if (ifa->ifa_addr == nullptr || ifa->ifa_addr->sa_family != AF_PACKET) { + continue; + } + if ((ifa->ifa_flags & IFF_LOOPBACK) != 0) { + continue; + } + struct sockaddr_ll *lladdr = ReinterpretConvert(ifa->ifa_addr); + uint32_t len = static_cast(lladdr->sll_halen); + if (len > 0) { + std::string addr; + for (uint32_t i = 0; i < len; ++i) { + std::string hexAddr = IntToHexStr(static_cast(lladdr->sll_addr[i])); + addr += (hexAddr.length() > 1) ? hexAddr : ("0" + hexAddr); + } + if ((lladdr->sll_addr[0] & SECOND_LEAST_BIT) == 0) { + universalMacAddrs.emplace_back(addr); + } else { + localMacAddrs.emplace_back(addr); + } + } + } + if (ifaddr != nullptr) { + freeifaddrs(ifaddr); + } + if (universalMacAddrs.empty() && localMacAddrs.empty()) { + return 0; + } + auto &macAddrs = universalMacAddrs.empty() ? localMacAddrs : universalMacAddrs; + std::sort(macAddrs.begin(), macAddrs.end()); + return std::to_string(CalcHashId(join(macAddrs, "-"))); +} + +bool CreateMsmonitorLogPath(std::string &path) +{ + const char* logPathEnvVal = getenv("MSMONITOR_LOG_PATH"); + std::string logPath; + if (logPathEnvVal != nullptr) { + logPath = logPathEnvVal; + } + if (logPath.empty()) { + char cwdPath[PATH_MAX] = {0}; + if (getcwd(cwdPath, PATH_MAX) != nullptr) { + logPath = cwdPath; + } + } + if (logPath.empty()) { + fprintf(stderr, "[ERROR] Failed to get msmonitor log path.\n"); + return false; + } + logPath = logPath + "/msmonitor_log"; + std::string absPath = PathUtils::RelativeToAbsPath(logPath); + if (PathUtils::DirPathCheck(absPath)) { + std::string realPath = PathUtils::RealPath(absPath); + if (PathUtils::CreateDir(realPath)) { + path = realPath; + return true; + } + fprintf(stderr, "[ERROR] Create LOG_PATH: %s failed.\n", realPath.c_str()); + } else { + fprintf(stderr, "[ERROR] LOG_PATH: %s of Msmonitor is invalid.\n", absPath.c_str()); + } + return false; +} +} // namespace ipc_monitor +} // namespace dynolog_npu diff --git a/msmonitor/plugin/ipc_monitor/utils.h b/msmonitor/plugin/ipc_monitor/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..8a0d93ce42536acc9c214ab2231803c45a298bf8 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/utils.h @@ -0,0 +1,127 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef IPC_MONITOR_UTILS_H +#define IPC_MONITOR_UTILS_H + +#include +#include +#include +#include +#include +#include +#include + +namespace dynolog_npu { +namespace ipc_monitor { +constexpr int MaxParentPids = 5; +int32_t GetProcessId(); +std::string GenerateUuidV4(); +std::vector GetPids(); +std::pair GetParentPidAndCommand(int32_t pid); +std::vector> GetPidCommandPairsofAncestors(); +std::string getCurrentTimestamp(); +uint64_t getCurrentTimestamp64(); +bool Str2Uint32(uint32_t& dest, const std::string& str); +bool Str2Int32(int32_t& dest, const std::string& str); +bool Str2Bool(bool& dest, const std::string& str); +std::string& trim(std::string& str); +std::vector split(const std::string& str, char delimiter); +std::string join(const std::vector& strs, const std::string& delimiter); + +constexpr size_t ALIGN_SIZE = 8; +void *MsptiMalloc(size_t size, size_t alignment); +void MsptiFree(uint8_t *ptr); +const mode_t DATA_FILE_AUTHORITY = 0640; +const mode_t DATA_DIR_AUTHORITY = 0750; +const uint32_t DEFAULT_FLUSH_INTERVAL = 60; + +enum class SubModule { + IPC = 0 +}; + +enum class ErrCode { + SUC = 0, + PARAM = 1, + TYPE = 2, + VALUE = 3, + PTR = 4, + INTERNAL = 5, + MEMORY = 6, + NOT_SUPPORT = 7, + NOT_FOUND = 8, + UNAVAIL = 9, + SYSCALL = 10, + TIMEOUT = 11, + PERMISSION = 12, +}; + +std::string formatErrorCode(SubModule submodule, ErrCode errorCode); + +#define IPC_ERROR(error) formatErrorCode(SubModule::IPC, error) + +template +inline T ReinterpretConvert(V ptr) +{ + return reinterpret_cast(ptr); +} + +template +inline void MakeSharedPtr(std::shared_ptr& ptr, Args&&... args) +{ + try { + ptr = std::make_shared(std::forward(args)...); + } catch(std::bad_alloc& e) { + throw; + } catch (...) { + ptr = nullptr; + return; + } +} + +template +auto groupby(const Container& vec, KeyFunc keyFunc) +{ + using KeyType = decltype(keyFunc(*vec.begin())); + using ValueType = typename Container::value_type; + std::unordered_map> grouped; + for (const auto& item : vec) { + grouped[keyFunc(item)].push_back(item); + } + return grouped; +} + +int GetRankId(); +uint64_t CalcHashId(const std::string &data); +std::string GetHostName(); +std::string GetHostUid(); +bool CreateMsmonitorLogPath(std::string& path); + +struct PathUtils { + static bool IsFileExist(const std::string &path); + static bool IsFileWritable(const std::string &path); + static bool IsDir(const std::string &path); + static bool CreateDir(const std::string &path); + static std::string RealPath(const std::string &path); + static std::string RelativeToAbsPath(const std::string &path); + static std::string DirName(const std::string &path); + static bool CreateFile(const std::string &path); + static bool IsSoftLink(const std::string &path); + static bool DirPathCheck(const std::string &path); +}; +} // namespace ipc_monitor +} // namespace dynolog_npu +#endif // IPC_MONITOR_UTILS_H diff --git a/msmonitor/plugin/setup.py b/msmonitor/plugin/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..2cd1881672488a000aa71dbb58f6c34525c98e2f --- /dev/null +++ b/msmonitor/plugin/setup.py @@ -0,0 +1,69 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys + +import subprocess +import pybind11 + +from setuptools import setup, Extension, find_namespace_packages +from setuptools.command.build_ext import build_ext + + +class CMakeExtension(Extension): + def __init__(self, name, sourcedir=""): + super().__init__(name, sources=[]) + self.sourcedir = os.path.abspath(sourcedir) + + +class CMakeBuild(build_ext): + def run(self): + for ext in self.extensions: + self.build_extension(ext) + + def build_extension(self, ext): + cfg = 'Debug' if self.debug else 'Release' + build_args = ['--config', cfg] + + ext_dir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) + cmake_args = [ + '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + ext_dir, + '-DPYTHON_EXECUTABLE=' + sys.executable, + '-DCMAKE_PREFIX_PATH=' + pybind11.get_cmake_dir(), + '-DCMAKE_INSTALL_PREFIX=' + ext_dir, + '-DCMAKE_BUILD_TYPE=' + cfg + ] + + env = os.environ.copy() + env['CXXFLAGS'] = '{} -DVERSION_INFO=\\"{}\\"'.format(env.get('CXXFLAGS', ''), + self.distribution.get_version()) + + if not os.path.exists(self.build_temp): + os.makedirs(self.build_temp) + subprocess.check_call(['cmake', ext.sourcedir] + cmake_args, cwd=self.build_temp, env=env) + subprocess.check_call(['cmake', '--build', '.', '-j', '8'] + build_args, + cwd=self.build_temp) + + +setup( + name="msmonitor_plugin", + version="8.1.0", + description="msMonitor plugin", + packages=find_namespace_packages(include=["IPCMonitor*"]), + include_package_data=True, + ext_modules=[CMakeExtension('IPCMonitor')], + cmdclass=dict(build_ext=CMakeBuild), + install_requires=["pybind11"], +) diff --git a/msmonitor/plugin/stub/build_stub.sh b/msmonitor/plugin/stub/build_stub.sh new file mode 100644 index 0000000000000000000000000000000000000000..97ec0699aec5923497ee32a7252b0337db059f7f --- /dev/null +++ b/msmonitor/plugin/stub/build_stub.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +CDIR="$(cd "$(dirname "$0")" ; pwd -P)" + +cd ${CDIR} + +gcc -fPIC -shared -o libmspti.so -I../ipc_monitor/mspti_monitor mspti.cpp diff --git a/msmonitor/plugin/stub/mspti.cpp b/msmonitor/plugin/stub/mspti.cpp new file mode 100644 index 0000000000000000000000000000000000000000..566b4b83f9e9a9f79645a9c6e21cbaa0584af2a6 --- /dev/null +++ b/msmonitor/plugin/stub/mspti.cpp @@ -0,0 +1,62 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mspti.h" + +msptiResult msptiSubscribe(msptiSubscriberHandle *subscriber, msptiCallbackFunc callback, void *userdata) +{ + return MSPTI_SUCCESS; +} + +msptiResult msptiUnsubscribe(msptiSubscriberHandle subscriber) +{ + return MSPTI_SUCCESS; +} + +msptiResult msptiActivityRegisterCallbacks(msptiBuffersCallbackRequestFunc funcBufferRequested, msptiBuffersCallbackCompleteFunc funcBufferCompleted) +{ + return MSPTI_SUCCESS; +} + +msptiResult msptiActivityEnable(msptiActivityKind kind) +{ + return MSPTI_SUCCESS; +} + +msptiResult msptiActivityDisable(msptiActivityKind kind) +{ + return MSPTI_SUCCESS; +} + +msptiResult msptiActivityGetNextRecord(uint8_t *buffer, size_t validBufferSizeBytes, msptiActivity **record) +{ + return MSPTI_SUCCESS; +} + +msptiResult msptiActivityFlushAll(uint32_t flag) +{ + return MSPTI_SUCCESS; +} + +msptiResult msptiActivityEnableMarkerDomain(const char* name) +{ + return MSPTI_SUCCESS; +} + +msptiResult msptiActivityDisableMarkerDomain(const char* name) +{ + return MSPTI_SUCCESS; +} diff --git a/msmonitor/plugin/third_party/securec/include/securec.h b/msmonitor/plugin/third_party/securec/include/securec.h new file mode 100644 index 0000000000000000000000000000000000000000..fa575ffe359104deabd5d32154c9afbc81065ddf --- /dev/null +++ b/msmonitor/plugin/third_party/securec/include/securec.h @@ -0,0 +1,161 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2014-2021. All rights reserved. + * Licensed under Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, + * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, + * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + * Description: The user of this secure c library should include this header file in you source code. + * This header file declare all supported API prototype of the library, + * such as memcpy_s, strcpy_s, wcscpy_s,strcat_s, strncat_s, sprintf_s, scanf_s, and so on. + * Create: 2014-02-25 + * Notes: Do not modify this file by yourself. + */ + +#ifndef SECUREC_H_5D13A042_DC3F_4ED9_A8D1_882811274C27 +#define SECUREC_H_5D13A042_DC3F_4ED9_A8D1_882811274C27 + +#include "securectype.h" +#ifndef SECUREC_HAVE_STDARG_H +#define SECUREC_HAVE_STDARG_H 1 +#endif + +#if SECUREC_HAVE_STDARG_H +#include +#endif + +#ifndef SECUREC_HAVE_ERRNO_H +#define SECUREC_HAVE_ERRNO_H 1 +#endif + +/* EINVAL ERANGE may defined in errno.h */ +#if SECUREC_HAVE_ERRNO_H +#if SECUREC_IN_KERNEL +#include +#else +#include +#endif +#endif + +/* Define error code */ +#if defined(SECUREC_NEED_ERRNO_TYPE) || !defined(__STDC_WANT_LIB_EXT1__) || \ + (defined(__STDC_WANT_LIB_EXT1__) && (!__STDC_WANT_LIB_EXT1__)) +#ifndef SECUREC_DEFINED_ERRNO_TYPE +#define SECUREC_DEFINED_ERRNO_TYPE +/* Just check whether macrodefinition exists. */ +#ifndef errno_t +typedef int errno_t; +#endif +#endif +#endif + +/* Success */ +#ifndef EOK +#define EOK 0 +#endif + +#ifndef EINVAL +/* The src buffer is not correct and destination buffer can not be reset */ +#define EINVAL 22 +#endif + +#ifndef EINVAL_AND_RESET +/* Once the error is detected, the dest buffer must be reset! Value is 22 or 128 */ +#define EINVAL_AND_RESET 150 +#endif + +#ifndef ERANGE +/* The destination buffer is not long enough and destination buffer can not be reset */ +#define ERANGE 34 +#endif + +#ifndef ERANGE_AND_RESET +/* Once the error is detected, the dest buffer must be reset! Value is 34 or 128 */ +#define ERANGE_AND_RESET 162 +#endif + +#ifndef EOVERLAP_AND_RESET +/* Once the buffer overlap is detected, the dest buffer must be reset! Value is 54 or 128 */ +#define EOVERLAP_AND_RESET 182 +#endif + +/* If you need export the function of this library in Win32 dll, use __declspec(dllexport) */ +#ifndef SECUREC_API +#if defined(SECUREC_DLL_EXPORT) +#if defined(_MSC_VER) +#define SECUREC_API __declspec(dllexport) +#else /* build for linux */ +#define SECUREC_API __attribute__((visibility("default"))) +#endif /* end of _MSC_VER and SECUREC_DLL_EXPORT */ +#elif defined(SECUREC_DLL_IMPORT) +#if defined(_MSC_VER) +#define SECUREC_API __declspec(dllimport) +#else +#define SECUREC_API +#endif /* end of _MSC_VER and SECUREC_DLL_IMPORT */ +#else +/* + * Standardized function declaration. If a security function is declared in the your code, + * it may cause a compilation alarm,Please delete the security function you declared. + * Adding extern under windows will cause the system to have inline functions to expand, + * so do not add the extern in default + */ +#if defined(_MSC_VER) +#define SECUREC_API +#else +#define SECUREC_API extern +#endif +#endif +#endif + +#ifdef __cplusplus +extern "C" { +#endif +/* + * Description: The GetHwSecureCVersion function get SecureC Version string and version number. + * Parameter: verNumber - to store version number (for example value is 0x500 | 0xa) + * Return: version string + */ +SECUREC_API const char *GetHwSecureCVersion(unsigned short *verNumber); + +#if SECUREC_ENABLE_MEMSET +/* + * Description: The memset_s function copies the value of c (converted to an unsigned char) into each of + * the first count characters of the object pointed to by dest. + * Parameter: dest - destination address + * Parameter: destMax - The maximum length of destination buffer + * Parameter: c - the value to be copied + * Parameter: count - copies count bytes of value to dest + * Return: EOK if there was no runtime-constraint violation + */ +SECUREC_API errno_t memset_s(void *dest, size_t destMax, int c, size_t count); +#endif + +#ifndef SECUREC_ONLY_DECLARE_MEMSET +#define SECUREC_ONLY_DECLARE_MEMSET 0 +#endif + +#if !SECUREC_ONLY_DECLARE_MEMSET + +#if SECUREC_ENABLE_MEMCPY +/* + * Description: The memcpy_s function copies n characters from the object pointed to + * by src into the object pointed to by dest. + * Parameter: dest - destination address + * Parameter: destMax - The maximum length of destination buffer + * Parameter: src - source address + * Parameter: count - copies count bytes from the src + * Return: EOK if there was no runtime-constraint violation + */ +SECUREC_API errno_t memcpy_s(void *dest, size_t destMax, const void *src, size_t count); +#endif + +#endif + +#ifdef __cplusplus +} +#endif +#endif diff --git a/msmonitor/plugin/third_party/securec/include/securectype.h b/msmonitor/plugin/third_party/securec/include/securectype.h new file mode 100644 index 0000000000000000000000000000000000000000..c406d198971a926ead3f9564072a6d9e828e6894 --- /dev/null +++ b/msmonitor/plugin/third_party/securec/include/securectype.h @@ -0,0 +1,501 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2014-2021. All rights reserved. + * Licensed under Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, + * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, + * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + * Description: Define internal used macro and data type. The marco of SECUREC_ON_64BITS + * will be determined in this header file, which is a switch for part + * of code. Some macro are used to suppress warning by MS compiler. + * Create: 2014-02-25 + * Notes: User can change the value of SECUREC_STRING_MAX_LEN and SECUREC_MEM_MAX_LEN + * macro to meet their special need, but The maximum value should not exceed 2G. + */ +/* + * [Standardize-exceptions]: Performance-sensitive + * [reason]: Strict parameter verification has been done before use + */ + +#ifndef SECURECTYPE_H_A7BBB686_AADA_451B_B9F9_44DACDAE18A7 +#define SECURECTYPE_H_A7BBB686_AADA_451B_B9F9_44DACDAE18A7 + +#ifndef SECUREC_USING_STD_SECURE_LIB +#if defined(_MSC_VER) && _MSC_VER >= 1400 +#if defined(__STDC_WANT_SECURE_LIB__) && (!__STDC_WANT_SECURE_LIB__) +/* Security functions have been provided since vs2005, default use of system library functions */ +#define SECUREC_USING_STD_SECURE_LIB 0 +#else +#define SECUREC_USING_STD_SECURE_LIB 1 +#endif +#else +#define SECUREC_USING_STD_SECURE_LIB 0 +#endif +#endif + +/* Compatibility with older Secure C versions, shielding VC symbol redefinition warning */ +#if defined(_MSC_VER) && (_MSC_VER >= 1400) && (!SECUREC_USING_STD_SECURE_LIB) +#ifndef SECUREC_DISABLE_CRT_FUNC +#define SECUREC_DISABLE_CRT_FUNC 1 +#endif +#ifndef SECUREC_DISABLE_CRT_IMP +#define SECUREC_DISABLE_CRT_IMP 1 +#endif +#else /* MSC VER */ +#ifndef SECUREC_DISABLE_CRT_FUNC +#define SECUREC_DISABLE_CRT_FUNC 0 +#endif +#ifndef SECUREC_DISABLE_CRT_IMP +#define SECUREC_DISABLE_CRT_IMP 0 +#endif +#endif + +#if SECUREC_DISABLE_CRT_FUNC +#ifdef __STDC_WANT_SECURE_LIB__ +#undef __STDC_WANT_SECURE_LIB__ +#endif +#define __STDC_WANT_SECURE_LIB__ 0 +#endif + +#if SECUREC_DISABLE_CRT_IMP +#ifdef _CRTIMP_ALTERNATIVE +#undef _CRTIMP_ALTERNATIVE +#endif +#define _CRTIMP_ALTERNATIVE /* Comment Microsoft *_s function */ +#endif + +/* Compile in kernel under macro control */ +#ifndef SECUREC_IN_KERNEL +#ifdef __KERNEL__ +#define SECUREC_IN_KERNEL 1 +#else +#define SECUREC_IN_KERNEL 0 +#endif +#endif + +/* make kernel symbols of functions available to loadable modules */ +#ifndef SECUREC_EXPORT_KERNEL_SYMBOL +#if SECUREC_IN_KERNEL +#define SECUREC_EXPORT_KERNEL_SYMBOL 1 +#else +#define SECUREC_EXPORT_KERNEL_SYMBOL 0 +#endif +#endif + +#if SECUREC_IN_KERNEL +#ifndef SECUREC_ENABLE_SCANF_FILE +#define SECUREC_ENABLE_SCANF_FILE 0 +#endif +#ifndef SECUREC_ENABLE_WCHAR_FUNC +#define SECUREC_ENABLE_WCHAR_FUNC 0 +#endif +#else /* SECUREC_IN_KERNEL */ +#ifndef SECUREC_ENABLE_SCANF_FILE +#define SECUREC_ENABLE_SCANF_FILE 1 +#endif +#ifndef SECUREC_ENABLE_WCHAR_FUNC +#define SECUREC_ENABLE_WCHAR_FUNC 1 +#endif +#endif + +/* Default secure function declaration, default declarations for non-standard functions */ +#ifndef SECUREC_SNPRINTF_TRUNCATED +#define SECUREC_SNPRINTF_TRUNCATED 1 +#endif + +#if SECUREC_USING_STD_SECURE_LIB +#if defined(_MSC_VER) && _MSC_VER >= 1400 +/* Declare secure functions that are not available in the VS compiler */ +#ifndef SECUREC_ENABLE_MEMSET +#define SECUREC_ENABLE_MEMSET 1 +#endif +/* VS 2005 have vsnprintf_s function */ +#ifndef SECUREC_ENABLE_VSNPRINTF +#define SECUREC_ENABLE_VSNPRINTF 0 +#endif +#ifndef SECUREC_ENABLE_SNPRINTF +/* VS 2005 have vsnprintf_s function Adapt the snprintf_s of the security function */ +#define snprintf_s _snprintf_s +#define SECUREC_ENABLE_SNPRINTF 0 +#endif +/* Before VS 2010 do not have v functions */ +#if _MSC_VER <= 1600 || defined(SECUREC_FOR_V_SCANFS) +#ifndef SECUREC_ENABLE_VFSCANF +#define SECUREC_ENABLE_VFSCANF 1 +#endif +#ifndef SECUREC_ENABLE_VSCANF +#define SECUREC_ENABLE_VSCANF 1 +#endif +#ifndef SECUREC_ENABLE_VSSCANF +#define SECUREC_ENABLE_VSSCANF 1 +#endif +#endif + +#else /* MSC VER */ +#ifndef SECUREC_ENABLE_MEMSET +#define SECUREC_ENABLE_MEMSET 0 +#endif +#ifndef SECUREC_ENABLE_SNPRINTF +#define SECUREC_ENABLE_SNPRINTF 0 +#endif +#ifndef SECUREC_ENABLE_VSNPRINTF +#define SECUREC_ENABLE_VSNPRINTF 0 +#endif +#endif + +#ifndef SECUREC_ENABLE_MEMMOVE +#define SECUREC_ENABLE_MEMMOVE 0 +#endif +#ifndef SECUREC_ENABLE_MEMCPY +#define SECUREC_ENABLE_MEMCPY 0 +#endif +#ifndef SECUREC_ENABLE_STRCPY +#define SECUREC_ENABLE_STRCPY 0 +#endif +#ifndef SECUREC_ENABLE_STRNCPY +#define SECUREC_ENABLE_STRNCPY 0 +#endif +#ifndef SECUREC_ENABLE_STRCAT +#define SECUREC_ENABLE_STRCAT 0 +#endif +#ifndef SECUREC_ENABLE_STRNCAT +#define SECUREC_ENABLE_STRNCAT 0 +#endif +#ifndef SECUREC_ENABLE_SPRINTF +#define SECUREC_ENABLE_SPRINTF 0 +#endif +#ifndef SECUREC_ENABLE_VSPRINTF +#define SECUREC_ENABLE_VSPRINTF 0 +#endif +#ifndef SECUREC_ENABLE_SSCANF +#define SECUREC_ENABLE_SSCANF 0 +#endif +#ifndef SECUREC_ENABLE_VSSCANF +#define SECUREC_ENABLE_VSSCANF 0 +#endif +#ifndef SECUREC_ENABLE_SCANF +#define SECUREC_ENABLE_SCANF 0 +#endif +#ifndef SECUREC_ENABLE_VSCANF +#define SECUREC_ENABLE_VSCANF 0 +#endif + +#ifndef SECUREC_ENABLE_FSCANF +#define SECUREC_ENABLE_FSCANF 0 +#endif +#ifndef SECUREC_ENABLE_VFSCANF +#define SECUREC_ENABLE_VFSCANF 0 +#endif +#ifndef SECUREC_ENABLE_STRTOK +#define SECUREC_ENABLE_STRTOK 0 +#endif +#ifndef SECUREC_ENABLE_GETS +#define SECUREC_ENABLE_GETS 0 +#endif + +#else /* SECUREC USE STD SECURE LIB */ + +#ifndef SECUREC_ENABLE_MEMSET +#define SECUREC_ENABLE_MEMSET 1 +#endif +#ifndef SECUREC_ENABLE_MEMMOVE +#define SECUREC_ENABLE_MEMMOVE 1 +#endif +#ifndef SECUREC_ENABLE_MEMCPY +#define SECUREC_ENABLE_MEMCPY 1 +#endif +#ifndef SECUREC_ENABLE_STRCPY +#define SECUREC_ENABLE_STRCPY 1 +#endif +#ifndef SECUREC_ENABLE_STRNCPY +#define SECUREC_ENABLE_STRNCPY 1 +#endif +#ifndef SECUREC_ENABLE_STRCAT +#define SECUREC_ENABLE_STRCAT 1 +#endif +#ifndef SECUREC_ENABLE_STRNCAT +#define SECUREC_ENABLE_STRNCAT 1 +#endif +#ifndef SECUREC_ENABLE_SPRINTF +#define SECUREC_ENABLE_SPRINTF 1 +#endif +#ifndef SECUREC_ENABLE_VSPRINTF +#define SECUREC_ENABLE_VSPRINTF 1 +#endif +#ifndef SECUREC_ENABLE_SNPRINTF +#define SECUREC_ENABLE_SNPRINTF 1 +#endif +#ifndef SECUREC_ENABLE_VSNPRINTF +#define SECUREC_ENABLE_VSNPRINTF 1 +#endif +#ifndef SECUREC_ENABLE_SSCANF +#define SECUREC_ENABLE_SSCANF 1 +#endif +#ifndef SECUREC_ENABLE_VSSCANF +#define SECUREC_ENABLE_VSSCANF 1 +#endif +#ifndef SECUREC_ENABLE_SCANF +#if SECUREC_ENABLE_SCANF_FILE +#define SECUREC_ENABLE_SCANF 1 +#else +#define SECUREC_ENABLE_SCANF 0 +#endif +#endif +#ifndef SECUREC_ENABLE_VSCANF +#if SECUREC_ENABLE_SCANF_FILE +#define SECUREC_ENABLE_VSCANF 1 +#else +#define SECUREC_ENABLE_VSCANF 0 +#endif +#endif + +#ifndef SECUREC_ENABLE_FSCANF +#if SECUREC_ENABLE_SCANF_FILE +#define SECUREC_ENABLE_FSCANF 1 +#else +#define SECUREC_ENABLE_FSCANF 0 +#endif +#endif +#ifndef SECUREC_ENABLE_VFSCANF +#if SECUREC_ENABLE_SCANF_FILE +#define SECUREC_ENABLE_VFSCANF 1 +#else +#define SECUREC_ENABLE_VFSCANF 0 +#endif +#endif + +#ifndef SECUREC_ENABLE_STRTOK +#define SECUREC_ENABLE_STRTOK 1 +#endif +#ifndef SECUREC_ENABLE_GETS +#define SECUREC_ENABLE_GETS 1 +#endif +#endif /* SECUREC_USE_STD_SECURE_LIB */ + +#if !SECUREC_ENABLE_SCANF_FILE +#if SECUREC_ENABLE_FSCANF +#undef SECUREC_ENABLE_FSCANF +#define SECUREC_ENABLE_FSCANF 0 +#endif +#if SECUREC_ENABLE_VFSCANF +#undef SECUREC_ENABLE_VFSCANF +#define SECUREC_ENABLE_VFSCANF 0 +#endif +#if SECUREC_ENABLE_SCANF +#undef SECUREC_ENABLE_SCANF +#define SECUREC_ENABLE_SCANF 0 +#endif +#if SECUREC_ENABLE_FSCANF +#undef SECUREC_ENABLE_FSCANF +#define SECUREC_ENABLE_FSCANF 0 +#endif + +#endif + +#if SECUREC_IN_KERNEL +#include +#include +#else +#ifndef SECUREC_HAVE_STDIO_H +#define SECUREC_HAVE_STDIO_H 1 +#endif +#ifndef SECUREC_HAVE_STRING_H +#define SECUREC_HAVE_STRING_H 1 +#endif +#ifndef SECUREC_HAVE_STDLIB_H +#define SECUREC_HAVE_STDLIB_H 1 +#endif +#if SECUREC_HAVE_STDIO_H +#include +#endif +#if SECUREC_HAVE_STRING_H +#include +#endif +#if SECUREC_HAVE_STDLIB_H +#include +#endif +#endif + +/* + * If you need high performance, enable the SECUREC_WITH_PERFORMANCE_ADDONS macro, default is enable. + * The macro is automatically closed on the windows platform and linux kernel + */ +#ifndef SECUREC_WITH_PERFORMANCE_ADDONS +#if SECUREC_IN_KERNEL +#define SECUREC_WITH_PERFORMANCE_ADDONS 0 +#else +#define SECUREC_WITH_PERFORMANCE_ADDONS 1 +#endif +#endif + +/* If enable SECUREC_COMPATIBLE_WIN_FORMAT, the output format will be compatible to Windows. */ +#if (defined(_WIN32) || defined(_WIN64) || defined(_MSC_VER)) && !defined(SECUREC_COMPATIBLE_LINUX_FORMAT) +#ifndef SECUREC_COMPATIBLE_WIN_FORMAT +#define SECUREC_COMPATIBLE_WIN_FORMAT +#endif +#endif + +#if defined(SECUREC_COMPATIBLE_WIN_FORMAT) +/* On windows platform, can't use optimized function for there is no __builtin_constant_p like function */ +/* If need optimized macro, can define this: define __builtin_constant_p(x) 0 */ +#ifdef SECUREC_WITH_PERFORMANCE_ADDONS +#undef SECUREC_WITH_PERFORMANCE_ADDONS +#define SECUREC_WITH_PERFORMANCE_ADDONS 0 +#endif +#endif + +#if defined(__VXWORKS__) || defined(__vxworks) || defined(__VXWORKS) || defined(_VXWORKS_PLATFORM_) || \ + defined(SECUREC_VXWORKS_VERSION_5_4) +#ifndef SECUREC_VXWORKS_PLATFORM +#define SECUREC_VXWORKS_PLATFORM +#endif +#endif + +/* If enable SECUREC_COMPATIBLE_LINUX_FORMAT, the output format will be compatible to Linux. */ +#if !defined(SECUREC_COMPATIBLE_WIN_FORMAT) && !defined(SECUREC_VXWORKS_PLATFORM) +#ifndef SECUREC_COMPATIBLE_LINUX_FORMAT +#define SECUREC_COMPATIBLE_LINUX_FORMAT +#endif +#endif + +#ifdef SECUREC_COMPATIBLE_LINUX_FORMAT +#ifndef SECUREC_HAVE_STDDEF_H +#define SECUREC_HAVE_STDDEF_H 1 +#endif +/* Some system may no stddef.h */ +#if SECUREC_HAVE_STDDEF_H +#if !SECUREC_IN_KERNEL +#include +#endif +#endif +#endif + +/* + * Add the -DSECUREC_SUPPORT_FORMAT_WARNING=1 compiler option to supoort -Wformat=2. + * Default does not check the format is that the same data type in the actual code. + * In the product is different in the original data type definition of VxWorks and Linux. + */ +#ifndef SECUREC_SUPPORT_FORMAT_WARNING +#define SECUREC_SUPPORT_FORMAT_WARNING 0 +#endif + +#if SECUREC_SUPPORT_FORMAT_WARNING +#define SECUREC_ATTRIBUTE(x, y) __attribute__((format(printf, (x), (y)))) +#else +#define SECUREC_ATTRIBUTE(x, y) +#endif + +/* + * Add the -DSECUREC_SUPPORT_BUILTIN_EXPECT=0 compiler option, if compiler can not support __builtin_expect. + */ +#ifndef SECUREC_SUPPORT_BUILTIN_EXPECT +#define SECUREC_SUPPORT_BUILTIN_EXPECT 1 +#endif + +#if SECUREC_SUPPORT_BUILTIN_EXPECT && defined(__GNUC__) && ((__GNUC__ > 3) || \ + (defined(__GNUC_MINOR__) && (__GNUC__ == 3 && __GNUC_MINOR__ > 3))) +/* + * This is a built-in function that can be used without a declaration, if warning for declaration not found occurred, + * you can add -DSECUREC_NEED_BUILTIN_EXPECT_DECLARE to compiler options + */ +#ifdef SECUREC_NEED_BUILTIN_EXPECT_DECLARE +long __builtin_expect(long exp, long c); +#endif + +#define SECUREC_LIKELY(x) __builtin_expect(!!(x), 1) +#define SECUREC_UNLIKELY(x) __builtin_expect(!!(x), 0) +#else +#define SECUREC_LIKELY(x) (x) +#define SECUREC_UNLIKELY(x) (x) +#endif + +/* Define the max length of the string */ +#ifndef SECUREC_STRING_MAX_LEN +#define SECUREC_STRING_MAX_LEN 0x7fffffffUL +#endif +#define SECUREC_WCHAR_STRING_MAX_LEN (SECUREC_STRING_MAX_LEN / sizeof(wchar_t)) + +/* Add SECUREC_MEM_MAX_LEN for memcpy and memmove */ +#ifndef SECUREC_MEM_MAX_LEN +#define SECUREC_MEM_MAX_LEN 0x7fffffffUL +#endif +#define SECUREC_WCHAR_MEM_MAX_LEN (SECUREC_MEM_MAX_LEN / sizeof(wchar_t)) + +#if SECUREC_STRING_MAX_LEN > 0x7fffffffUL +#error "max string is 2G" +#endif + +#if (defined(__GNUC__) && defined(__SIZEOF_POINTER__)) +#if (__SIZEOF_POINTER__ != 4) && (__SIZEOF_POINTER__ != 8) +#error "unsupported system" +#endif +#endif + +#if defined(_WIN64) || defined(WIN64) || defined(__LP64__) || defined(_LP64) +#define SECUREC_ON_64BITS +#endif + +#if (!defined(SECUREC_ON_64BITS) && defined(__GNUC__) && defined(__SIZEOF_POINTER__)) +#if __SIZEOF_POINTER__ == 8 +#define SECUREC_ON_64BITS +#endif +#endif + +#if defined(__SVR4) || defined(__svr4__) +#define SECUREC_ON_SOLARIS +#endif + +#if (defined(__hpux) || defined(_AIX) || defined(SECUREC_ON_SOLARIS)) +#define SECUREC_ON_UNIX +#endif + +/* + * Codes should run under the macro SECUREC_COMPATIBLE_LINUX_FORMAT in unknown system on default, + * and strtold. + * The function strtold is referenced first at ISO9899:1999(C99), and some old compilers can + * not support these functions. Here provides a macro to open these functions: + * SECUREC_SUPPORT_STRTOLD -- If defined, strtold will be used + */ +#ifndef SECUREC_SUPPORT_STRTOLD +#define SECUREC_SUPPORT_STRTOLD 0 +#if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT)) +#if defined(__USE_ISOC99) || \ + (defined(_AIX) && defined(_ISOC99_SOURCE)) || \ + (defined(__hpux) && defined(__ia64)) || \ + (defined(SECUREC_ON_SOLARIS) && (!defined(_STRICT_STDC) && !defined(__XOPEN_OR_POSIX)) || \ + defined(_STDC_C99) || defined(__EXTENSIONS__)) +#undef SECUREC_SUPPORT_STRTOLD +#define SECUREC_SUPPORT_STRTOLD 1 +#endif +#endif +#if ((defined(SECUREC_WRLINUX_BELOW4) || defined(_WRLINUX_BELOW4_))) +#undef SECUREC_SUPPORT_STRTOLD +#define SECUREC_SUPPORT_STRTOLD 0 +#endif +#endif + +#if SECUREC_WITH_PERFORMANCE_ADDONS + +#ifndef SECUREC_TWO_MIN +#define SECUREC_TWO_MIN(a, b) ((a) < (b) ? (a) : (b)) +#endif + +/* This macro do not check buffer overlap by default */ +#define SECUREC_MEMCPY_SM(dest, destMax, src, count) \ + (!(((size_t)(destMax) == 0) || \ + (((unsigned long long)(destMax) & (unsigned long long)(-2)) > SECUREC_MEM_MAX_LEN) || \ + ((size_t)(count) > (size_t)(destMax)) || ((void *)(dest)) == NULL || ((const void *)(src) == NULL)) ? \ + (memcpy((dest), (src), (count)), EOK) : \ + (memcpy_s((dest), (destMax), (src), (count)))) + +#define SECUREC_MEMSET_SM(dest, destMax, c, count) \ + (!((((unsigned long long)(destMax) & (unsigned long long)(-2)) > SECUREC_MEM_MAX_LEN) || \ + ((void *)(dest) == NULL) || ((size_t)(count) > (size_t)(destMax))) ? \ + (memset((dest), (c), (count)), EOK) : \ + (memset_s((dest), (destMax), (c), (count)))) + +#endif +#endif diff --git a/msmonitor/plugin/third_party/securec/src/memcpy_s.c b/msmonitor/plugin/third_party/securec/src/memcpy_s.c new file mode 100644 index 0000000000000000000000000000000000000000..a7fd48748e50a7180c2afd8a1def9b05180eb8bc --- /dev/null +++ b/msmonitor/plugin/third_party/securec/src/memcpy_s.c @@ -0,0 +1,555 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2014-2021. All rights reserved. + * Licensed under Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, + * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, + * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + * Description: memcpy_s function + * Create: 2014-02-25 + */ +/* + * [Standardize-exceptions] Use unsafe function: Portability + * [reason] Use unsafe function to implement security function to maintain platform compatibility. + * And sufficient input validation is performed before calling + */ + +#include "securecutil.h" + +#if SECUREC_WITH_PERFORMANCE_ADDONS +#ifndef SECUREC_MEMCOPY_THRESHOLD_SIZE +#define SECUREC_MEMCOPY_THRESHOLD_SIZE 64UL +#endif + +#define SECUREC_SMALL_MEM_COPY(dest, src, count) do { \ + if (SECUREC_ADDR_ALIGNED_8(dest) && SECUREC_ADDR_ALIGNED_8(src)) { \ + /* Use struct assignment */ \ + switch (count) { \ + case 1: \ + *(unsigned char *)(dest) = *(const unsigned char *)(src); \ + break; \ + case 2: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 2); \ + break; \ + case 3: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 3); \ + break; \ + case 4: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 4); \ + break; \ + case 5: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 5); \ + break; \ + case 6: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 6); \ + break; \ + case 7: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 7); \ + break; \ + case 8: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 8); \ + break; \ + case 9: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 9); \ + break; \ + case 10: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 10); \ + break; \ + case 11: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 11); \ + break; \ + case 12: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 12); \ + break; \ + case 13: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 13); \ + break; \ + case 14: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 14); \ + break; \ + case 15: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 15); \ + break; \ + case 16: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 16); \ + break; \ + case 17: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 17); \ + break; \ + case 18: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 18); \ + break; \ + case 19: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 19); \ + break; \ + case 20: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 20); \ + break; \ + case 21: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 21); \ + break; \ + case 22: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 22); \ + break; \ + case 23: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 23); \ + break; \ + case 24: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 24); \ + break; \ + case 25: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 25); \ + break; \ + case 26: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 26); \ + break; \ + case 27: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 27); \ + break; \ + case 28: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 28); \ + break; \ + case 29: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 29); \ + break; \ + case 30: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 30); \ + break; \ + case 31: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 31); \ + break; \ + case 32: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 32); \ + break; \ + case 33: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 33); \ + break; \ + case 34: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 34); \ + break; \ + case 35: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 35); \ + break; \ + case 36: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 36); \ + break; \ + case 37: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 37); \ + break; \ + case 38: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 38); \ + break; \ + case 39: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 39); \ + break; \ + case 40: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 40); \ + break; \ + case 41: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 41); \ + break; \ + case 42: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 42); \ + break; \ + case 43: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 43); \ + break; \ + case 44: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 44); \ + break; \ + case 45: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 45); \ + break; \ + case 46: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 46); \ + break; \ + case 47: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 47); \ + break; \ + case 48: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 48); \ + break; \ + case 49: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 49); \ + break; \ + case 50: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 50); \ + break; \ + case 51: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 51); \ + break; \ + case 52: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 52); \ + break; \ + case 53: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 53); \ + break; \ + case 54: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 54); \ + break; \ + case 55: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 55); \ + break; \ + case 56: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 56); \ + break; \ + case 57: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 57); \ + break; \ + case 58: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 58); \ + break; \ + case 59: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 59); \ + break; \ + case 60: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 60); \ + break; \ + case 61: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 61); \ + break; \ + case 62: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 62); \ + break; \ + case 63: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 63); \ + break; \ + case 64: \ + SECUREC_COPY_VALUE_BY_STRUCT((dest), (src), 64); \ + break; \ + default: \ + /* Do nothing */ \ + break; \ + } /* END switch */ \ + } else { \ + unsigned char *tmpDest_ = (unsigned char *)(dest); \ + const unsigned char *tmpSrc_ = (const unsigned char *)(src); \ + switch (count) { \ + case 64: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 63: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 62: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 61: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 60: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 59: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 58: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 57: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 56: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 55: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 54: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 53: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 52: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 51: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 50: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 49: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 48: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 47: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 46: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 45: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 44: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 43: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 42: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 41: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 40: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 39: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 38: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 37: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 36: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 35: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 34: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 33: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 32: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 31: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 30: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 29: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 28: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 27: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 26: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 25: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 24: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 23: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 22: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 21: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 20: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 19: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 18: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 17: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 16: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 15: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 14: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 13: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 12: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 11: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 10: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 9: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 8: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 7: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 6: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 5: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 4: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 3: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 2: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + case 1: \ + *(tmpDest_++) = *(tmpSrc_++); \ + /* fall-through */ /* FALLTHRU */ \ + default: \ + /* Do nothing */ \ + break; \ + } \ + } \ +} SECUREC_WHILE_ZERO + +/* + * Performance optimization + */ +#define SECUREC_MEMCPY_OPT(dest, src, count) do { \ + if ((count) > SECUREC_MEMCOPY_THRESHOLD_SIZE) { \ + SECUREC_MEMCPY_WARP_OPT((dest), (src), (count)); \ + } else { \ + SECUREC_SMALL_MEM_COPY((dest), (src), (count)); \ + } \ +} SECUREC_WHILE_ZERO +#endif + +/* + * Handling errors + */ +SECUREC_INLINE errno_t SecMemcpyError(void *dest, size_t destMax, const void *src, size_t count) +{ + if (destMax == 0 || destMax > SECUREC_MEM_MAX_LEN) { + SECUREC_ERROR_INVALID_RANGE("memcpy_s"); + return ERANGE; + } + if (dest == NULL || src == NULL) { + SECUREC_ERROR_INVALID_PARAMTER("memcpy_s"); + if (dest != NULL) { + (void)SECUREC_MEMSET_FUNC_OPT(dest, 0, destMax); + return EINVAL_AND_RESET; + } + return EINVAL; + } + if (count > destMax) { + (void)SECUREC_MEMSET_FUNC_OPT(dest, 0, destMax); + SECUREC_ERROR_INVALID_RANGE("memcpy_s"); + return ERANGE_AND_RESET; + } + if (SECUREC_MEMORY_IS_OVERLAP(dest, src, count)) { + (void)SECUREC_MEMSET_FUNC_OPT(dest, 0, destMax); + SECUREC_ERROR_BUFFER_OVERLAP("memcpy_s"); + return EOVERLAP_AND_RESET; + } + /* Count is 0 or dest equal src also ret EOK */ + return EOK; +} + +#if defined(SECUREC_COMPATIBLE_WIN_FORMAT) + /* + * The fread API in windows will call memcpy_s and pass 0xffffffff to destMax. + * To avoid the failure of fread, we don't check desMax limit. + */ +#define SECUREC_MEMCPY_PARAM_OK(dest, destMax, src, count) (SECUREC_LIKELY((count) <= (destMax) && \ + (dest) != NULL && (src) != NULL && \ + (count) > 0 && SECUREC_MEMORY_NO_OVERLAP((dest), (src), (count)))) +#else +#define SECUREC_MEMCPY_PARAM_OK(dest, destMax, src, count) (SECUREC_LIKELY((count) <= (destMax) && \ + (dest) != NULL && (src) != NULL && (destMax) <= SECUREC_MEM_MAX_LEN && \ + (count) > 0 && SECUREC_MEMORY_NO_OVERLAP((dest), (src), (count)))) +#endif + +/* + * + * The memcpy_s function copies n characters from the object pointed to by src into the object pointed to by dest + * + * + * dest Destination buffer. + * destMax Size of the destination buffer. + * src Buffer to copy from. + * count Number of characters to copy + * + * + * dest buffer is updated. + * + * + * EOK Success + * EINVAL dest is NULL and destMax != 0 and destMax <= SECUREC_MEM_MAX_LEN + * EINVAL_AND_RESET dest != NULL and src is NULL and destMax != 0 and destMax <= SECUREC_MEM_MAX_LEN + * ERANGE destMax > SECUREC_MEM_MAX_LEN or destMax is 0 + * ERANGE_AND_RESET count > destMax and destMax != 0 and destMax <= SECUREC_MEM_MAX_LEN + * and dest != NULL and src != NULL + * EOVERLAP_AND_RESET dest buffer and source buffer are overlapped and + * count <= destMax destMax != 0 and destMax <= SECUREC_MEM_MAX_LEN and dest != NULL + * and src != NULL and dest != src + * + * if an error occurred, dest will be filled with 0. + * If the source and destination overlap, the behavior of memcpy_s is undefined. + * Use memmove_s to handle overlapping regions. + */ +errno_t memcpy_s(void *dest, size_t destMax, const void *src, size_t count) +{ + if (SECUREC_MEMCPY_PARAM_OK(dest, destMax, src, count)) { + SECUREC_MEMCPY_WARP_OPT(dest, src, count); + return EOK; + } + /* Meet some runtime violation, return error code */ + return SecMemcpyError(dest, destMax, src, count); +} + +#if SECUREC_EXPORT_KERNEL_SYMBOL +EXPORT_SYMBOL(memcpy_s); +#endif + +#if SECUREC_WITH_PERFORMANCE_ADDONS +/* + * Performance optimization + */ +errno_t memcpy_sOptAsm(void *dest, size_t destMax, const void *src, size_t count) +{ + if (SECUREC_MEMCPY_PARAM_OK(dest, destMax, src, count)) { + SECUREC_MEMCPY_OPT(dest, src, count); + return EOK; + } + /* Meet some runtime violation, return error code */ + return SecMemcpyError(dest, destMax, src, count); +} + +/* Trim judgement on "destMax <= SECUREC_MEM_MAX_LEN" */ +errno_t memcpy_sOptTc(void *dest, size_t destMax, const void *src, size_t count) +{ + if (SECUREC_LIKELY(count <= destMax && dest != NULL && src != NULL && \ + count > 0 && SECUREC_MEMORY_NO_OVERLAP((dest), (src), (count)))) { + SECUREC_MEMCPY_OPT(dest, src, count); + return EOK; + } + /* Meet some runtime violation, return error code */ + return SecMemcpyError(dest, destMax, src, count); +} +#endif + diff --git a/msmonitor/plugin/third_party/securec/src/memset_s.c b/msmonitor/plugin/third_party/securec/src/memset_s.c new file mode 100644 index 0000000000000000000000000000000000000000..d9a657fd326af60ec1195b226aa762855042299b --- /dev/null +++ b/msmonitor/plugin/third_party/securec/src/memset_s.c @@ -0,0 +1,510 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2014-2021. All rights reserved. + * Licensed under Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, + * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, + * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + * Description: memset_s function + * Create: 2014-02-25 + */ +/* + * [Standardize-exceptions] Use unsafe function: Portability + * [reason] Use unsafe function to implement security function to maintain platform compatibility. + * And sufficient input validation is performed before calling + */ + +#include "securecutil.h" + +#define SECUREC_MEMSET_PARAM_OK(dest, destMax, count) (SECUREC_LIKELY((destMax) <= SECUREC_MEM_MAX_LEN && \ + (dest) != NULL && (count) <= (destMax))) + +#if SECUREC_WITH_PERFORMANCE_ADDONS + +/* Use union to clear strict-aliasing warning */ +typedef union { + SecStrBuf32 buf32; + SecStrBuf31 buf31; + SecStrBuf30 buf30; + SecStrBuf29 buf29; + SecStrBuf28 buf28; + SecStrBuf27 buf27; + SecStrBuf26 buf26; + SecStrBuf25 buf25; + SecStrBuf24 buf24; + SecStrBuf23 buf23; + SecStrBuf22 buf22; + SecStrBuf21 buf21; + SecStrBuf20 buf20; + SecStrBuf19 buf19; + SecStrBuf18 buf18; + SecStrBuf17 buf17; + SecStrBuf16 buf16; + SecStrBuf15 buf15; + SecStrBuf14 buf14; + SecStrBuf13 buf13; + SecStrBuf12 buf12; + SecStrBuf11 buf11; + SecStrBuf10 buf10; + SecStrBuf9 buf9; + SecStrBuf8 buf8; + SecStrBuf7 buf7; + SecStrBuf6 buf6; + SecStrBuf5 buf5; + SecStrBuf4 buf4; + SecStrBuf3 buf3; + SecStrBuf2 buf2; +} SecStrBuf32Union; +/* C standard initializes the first member of the consortium. */ +static const SecStrBuf32 g_allZero = {{ + 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, + 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, + 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, + 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U +}}; +static const SecStrBuf32 g_allFF = {{ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF +}}; + +/* Clear conversion warning strict aliasing" */ +SECUREC_INLINE const SecStrBuf32Union *SecStrictAliasingCast(const SecStrBuf32 *buf) +{ + return (const SecStrBuf32Union *)buf; +} + +#ifndef SECUREC_MEMSET_THRESHOLD_SIZE +#define SECUREC_MEMSET_THRESHOLD_SIZE 32UL +#endif + +#define SECUREC_UNALIGNED_SET(dest, c, count) do { \ + unsigned char *pDest_ = (unsigned char *)(dest); \ + switch (count) { \ + case 32: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 31: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 30: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 29: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 28: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 27: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 26: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 25: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 24: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 23: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 22: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 21: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 20: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 19: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 18: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 17: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 16: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 15: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 14: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 13: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 12: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 11: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 10: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 9: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 8: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 7: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 6: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 5: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 4: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 3: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 2: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + case 1: \ + *(pDest_++) = (unsigned char)(c); \ + /* fall-through */ /* FALLTHRU */ \ + default: \ + /* Do nothing */ \ + break; \ + } \ +} SECUREC_WHILE_ZERO + +#define SECUREC_SET_VALUE_BY_STRUCT(dest, dataName, n) do { \ + *(SecStrBuf##n *)(dest) = *(const SecStrBuf##n *)(&((SecStrictAliasingCast(&(dataName)))->buf##n)); \ +} SECUREC_WHILE_ZERO + +#define SECUREC_ALIGNED_SET_OPT_ZERO_FF(dest, c, count) do { \ + switch (c) { \ + case 0: \ + switch (count) { \ + case 1: \ + *(unsigned char *)(dest) = (unsigned char)0; \ + break; \ + case 2: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 2); \ + break; \ + case 3: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 3); \ + break; \ + case 4: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 4); \ + break; \ + case 5: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 5); \ + break; \ + case 6: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 6); \ + break; \ + case 7: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 7); \ + break; \ + case 8: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 8); \ + break; \ + case 9: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 9); \ + break; \ + case 10: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 10); \ + break; \ + case 11: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 11); \ + break; \ + case 12: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 12); \ + break; \ + case 13: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 13); \ + break; \ + case 14: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 14); \ + break; \ + case 15: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 15); \ + break; \ + case 16: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 16); \ + break; \ + case 17: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 17); \ + break; \ + case 18: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 18); \ + break; \ + case 19: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 19); \ + break; \ + case 20: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 20); \ + break; \ + case 21: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 21); \ + break; \ + case 22: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 22); \ + break; \ + case 23: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 23); \ + break; \ + case 24: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 24); \ + break; \ + case 25: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 25); \ + break; \ + case 26: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 26); \ + break; \ + case 27: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 27); \ + break; \ + case 28: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 28); \ + break; \ + case 29: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 29); \ + break; \ + case 30: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 30); \ + break; \ + case 31: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 31); \ + break; \ + case 32: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allZero, 32); \ + break; \ + default: \ + /* Do nothing */ \ + break; \ + } \ + break; \ + case 0xFF: \ + switch (count) { \ + case 1: \ + *(unsigned char *)(dest) = (unsigned char)0xffU; \ + break; \ + case 2: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 2); \ + break; \ + case 3: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 3); \ + break; \ + case 4: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 4); \ + break; \ + case 5: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 5); \ + break; \ + case 6: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 6); \ + break; \ + case 7: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 7); \ + break; \ + case 8: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 8); \ + break; \ + case 9: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 9); \ + break; \ + case 10: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 10); \ + break; \ + case 11: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 11); \ + break; \ + case 12: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 12); \ + break; \ + case 13: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 13); \ + break; \ + case 14: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 14); \ + break; \ + case 15: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 15); \ + break; \ + case 16: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 16); \ + break; \ + case 17: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 17); \ + break; \ + case 18: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 18); \ + break; \ + case 19: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 19); \ + break; \ + case 20: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 20); \ + break; \ + case 21: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 21); \ + break; \ + case 22: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 22); \ + break; \ + case 23: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 23); \ + break; \ + case 24: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 24); \ + break; \ + case 25: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 25); \ + break; \ + case 26: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 26); \ + break; \ + case 27: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 27); \ + break; \ + case 28: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 28); \ + break; \ + case 29: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 29); \ + break; \ + case 30: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 30); \ + break; \ + case 31: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 31); \ + break; \ + case 32: \ + SECUREC_SET_VALUE_BY_STRUCT((dest), g_allFF, 32); \ + break; \ + default: \ + /* Do nothing */ \ + break; \ + } \ + break; \ + default: \ + SECUREC_UNALIGNED_SET((dest), (c), (count)); \ + break; \ + } /* END switch */ \ +} SECUREC_WHILE_ZERO + +#define SECUREC_SMALL_MEM_SET(dest, c, count) do { \ + if (SECUREC_ADDR_ALIGNED_8((dest))) { \ + SECUREC_ALIGNED_SET_OPT_ZERO_FF((dest), (c), (count)); \ + } else { \ + SECUREC_UNALIGNED_SET((dest), (c), (count)); \ + } \ +} SECUREC_WHILE_ZERO + +/* + * Performance optimization + */ +#define SECUREC_MEMSET_OPT(dest, c, count) do { \ + if ((count) > SECUREC_MEMSET_THRESHOLD_SIZE) { \ + SECUREC_MEMSET_PREVENT_DSE((dest), (c), (count)); \ + } else { \ + SECUREC_SMALL_MEM_SET((dest), (c), (count)); \ + } \ +} SECUREC_WHILE_ZERO +#endif + +/* + * Handling errors + */ +SECUREC_INLINE errno_t SecMemsetError(void *dest, size_t destMax, int c) +{ + /* Check destMax is 0 compatible with _sp macro */ + if (destMax == 0 || destMax > SECUREC_MEM_MAX_LEN) { + SECUREC_ERROR_INVALID_RANGE("memset_s"); + return ERANGE; + } + if (dest == NULL) { + SECUREC_ERROR_INVALID_PARAMTER("memset_s"); + return EINVAL; + } + SECUREC_MEMSET_PREVENT_DSE(dest, c, destMax); /* Set entire buffer to value c */ + SECUREC_ERROR_INVALID_RANGE("memset_s"); + return ERANGE_AND_RESET; +} + +/* + * + * The memset_s function copies the value of c (converted to an unsigned char) + * into each of the first count characters of the object pointed to by dest. + * + * + * dest Pointer to destination. + * destMax The size of the buffer. + * c Character to set. + * count Number of characters. + * + * + * dest buffer is updated. + * + * + * EOK Success + * EINVAL dest == NULL and destMax != 0 and destMax <= SECUREC_MEM_MAX_LEN + * ERANGE destMax > SECUREC_MEM_MAX_LEN or (destMax is 0 and count > destMax) + * ERANGE_AND_RESET count > destMax and destMax != 0 and destMax <= SECUREC_MEM_MAX_LEN and dest != NULL + * + * if return ERANGE_AND_RESET then fill dest to c ,fill length is destMax + */ +errno_t memset_s(void *dest, size_t destMax, int c, size_t count) +{ + if (SECUREC_MEMSET_PARAM_OK(dest, destMax, count)) { + SECUREC_MEMSET_PREVENT_DSE(dest, c, count); + return EOK; + } + /* Meet some runtime violation, return error code */ + return SecMemsetError(dest, destMax, c); +} + +#if SECUREC_EXPORT_KERNEL_SYMBOL +EXPORT_SYMBOL(memset_s); +#endif + +#if SECUREC_WITH_PERFORMANCE_ADDONS +/* + * Performance optimization + */ +errno_t memset_sOptAsm(void *dest, size_t destMax, int c, size_t count) +{ + if (SECUREC_MEMSET_PARAM_OK(dest, destMax, count)) { + SECUREC_MEMSET_OPT(dest, c, count); + return EOK; + } + /* Meet some runtime violation, return error code */ + return SecMemsetError(dest, destMax, c); +} + +/* + * Performance optimization, trim judgement on "destMax <= SECUREC_MEM_MAX_LEN" + */ +errno_t memset_sOptTc(void *dest, size_t destMax, int c, size_t count) +{ + if (SECUREC_LIKELY(count <= destMax && dest != NULL)) { + SECUREC_MEMSET_OPT(dest, c, count); + return EOK; + } + /* Meet some runtime violation, return error code */ + return SecMemsetError(dest, destMax, c); +} +#endif + diff --git a/msmonitor/plugin/third_party/securec/src/securecutil.c b/msmonitor/plugin/third_party/securec/src/securecutil.c new file mode 100644 index 0000000000000000000000000000000000000000..0053a72cfab51526702fecc78d1cbe4616e68abb --- /dev/null +++ b/msmonitor/plugin/third_party/securec/src/securecutil.c @@ -0,0 +1,81 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2014-2021. All rights reserved. + * Licensed under Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, + * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, + * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + * Description: Provides internal functions used by this library, such as memory + * copy and memory move. Besides, include some helper function for + * printf family API, such as SecVsnprintfImpl + * Create: 2014-02-25 + */ + +/* Avoid duplicate header files,not include securecutil.h */ +#include "securecutil.h" + +#if defined(ANDROID) && !defined(SECUREC_CLOSE_ANDROID_HANDLE) && (SECUREC_HAVE_WCTOMB || SECUREC_HAVE_MBTOWC) +#include +#if SECUREC_HAVE_WCTOMB +/* + * Convert wide characters to narrow multi-bytes + */ +int wctomb(char *s, wchar_t wc) +{ + return (int)wcrtomb(s, wc, NULL); +} +#endif + +#if SECUREC_HAVE_MBTOWC +/* + * Converting narrow multi-byte characters to wide characters + * mbrtowc returns -1 or -2 upon failure, unlike mbtowc, which only returns -1 + * When the return value is less than zero, we treat it as a failure + */ +int mbtowc(wchar_t *pwc, const char *s, size_t n) +{ + return (int)mbrtowc(pwc, s, n, NULL); +} +#endif +#endif + +/* The V100R001C01 version num is 0x5 (High 8 bits) */ +#define SECUREC_C_VERSION 0x500U +#define SECUREC_SPC_VERSION 0x10U +#define SECUREC_VERSION_STR "1.1.16" + +/* + * Get version string and version number. + * The rules for version number are as follows: + * 1) SPC verNumber<->verStr like: + * 0x201<->C01 + * 0x202<->C01SPC001 Redefine numbers after this version + * 0x502<->C01SPC002 + * 0x503<->C01SPC003 + * ... + * 0X50a<->SPC010 + * 0X50b<->SPC011 + * ... + * 0x700<->C02 + * 0x701<->C01SPC001 + * 0x702<->C02SPC002 + * ... + * 2) CP verNumber<->verStr like: + * 0X601<->CP0001 + * 0X602<->CP0002 + * ... + */ +const char *GetHwSecureCVersion(unsigned short *verNumber) +{ + if (verNumber != NULL) { + *verNumber = (unsigned short)(SECUREC_C_VERSION | SECUREC_SPC_VERSION); + } + return SECUREC_VERSION_STR; +} +#if SECUREC_EXPORT_KERNEL_SYMBOL +EXPORT_SYMBOL(GetHwSecureCVersion); +#endif + diff --git a/msmonitor/plugin/third_party/securec/src/securecutil.h b/msmonitor/plugin/third_party/securec/src/securecutil.h new file mode 100644 index 0000000000000000000000000000000000000000..7e3bd691f9ece9decd2fcb3c239697c806597246 --- /dev/null +++ b/msmonitor/plugin/third_party/securec/src/securecutil.h @@ -0,0 +1,574 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2014-2021. All rights reserved. + * Licensed under Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, + * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, + * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + * Description: Define macro, data struct, and declare internal used function prototype, + * which is used by secure functions. + * Create: 2014-02-25 + */ + +#ifndef SECURECUTIL_H_46C86578_F8FF_4E49_8E64_9B175241761F +#define SECURECUTIL_H_46C86578_F8FF_4E49_8E64_9B175241761F +#include "securec.h" + +#if (defined(_MSC_VER)) && (_MSC_VER >= 1400) +/* Shield compilation alerts using discarded functions and Constant expression to maximize code compatibility */ +#define SECUREC_MASK_MSVC_CRT_WARNING __pragma(warning(push)) \ + __pragma(warning(disable : 4996 4127)) +#define SECUREC_END_MASK_MSVC_CRT_WARNING __pragma(warning(pop)) +#else +#define SECUREC_MASK_MSVC_CRT_WARNING +#define SECUREC_END_MASK_MSVC_CRT_WARNING +#endif +#define SECUREC_WHILE_ZERO SECUREC_MASK_MSVC_CRT_WARNING while (0) SECUREC_END_MASK_MSVC_CRT_WARNING + +/* Automatically identify the platform that supports strnlen function, and use this function to improve performance */ +#ifndef SECUREC_HAVE_STRNLEN +#if (defined(_XOPEN_SOURCE) && _XOPEN_SOURCE >= 700) || (defined(_POSIX_C_SOURCE) && _POSIX_C_SOURCE >= 200809L) +#if SECUREC_IN_KERNEL +#define SECUREC_HAVE_STRNLEN 0 +#else +#if defined(__GLIBC__) && __GLIBC__ >= 2 && defined(__GLIBC_MINOR__) && __GLIBC_MINOR__ >= 10 +#define SECUREC_HAVE_STRNLEN 1 +#else +#define SECUREC_HAVE_STRNLEN 0 +#endif +#endif +#else +#define SECUREC_HAVE_STRNLEN 0 +#endif +#endif + +#if SECUREC_IN_KERNEL +/* In kernel disable functions */ +#ifndef SECUREC_ENABLE_SCANF_FILE +#define SECUREC_ENABLE_SCANF_FILE 0 +#endif +#ifndef SECUREC_ENABLE_SCANF_FLOAT +#define SECUREC_ENABLE_SCANF_FLOAT 0 +#endif +#ifndef SECUREC_ENABLE_SPRINTF_FLOAT +#define SECUREC_ENABLE_SPRINTF_FLOAT 0 +#endif +#ifndef SECUREC_HAVE_MBTOWC +#define SECUREC_HAVE_MBTOWC 0 +#endif +#ifndef SECUREC_HAVE_WCTOMB +#define SECUREC_HAVE_WCTOMB 0 +#endif +#ifndef SECUREC_HAVE_WCHART +#define SECUREC_HAVE_WCHART 0 +#endif +#else /* Not in kernel */ +/* Systems that do not support file, can define this macro to 0. */ +#ifndef SECUREC_ENABLE_SCANF_FILE +#define SECUREC_ENABLE_SCANF_FILE 1 +#endif +#ifndef SECUREC_ENABLE_SCANF_FLOAT +#define SECUREC_ENABLE_SCANF_FLOAT 1 +#endif +/* Systems that do not support float, can define this macro to 0. */ +#ifndef SECUREC_ENABLE_SPRINTF_FLOAT +#define SECUREC_ENABLE_SPRINTF_FLOAT 1 +#endif +#ifndef SECUREC_HAVE_MBTOWC +#define SECUREC_HAVE_MBTOWC 1 +#endif +#ifndef SECUREC_HAVE_WCTOMB +#define SECUREC_HAVE_WCTOMB 1 +#endif +#ifndef SECUREC_HAVE_WCHART +#define SECUREC_HAVE_WCHART 1 +#endif +#endif + +#ifndef SECUREC_ENABLE_INLINE +#define SECUREC_ENABLE_INLINE 0 +#endif + +#ifndef SECUREC_INLINE +#if SECUREC_ENABLE_INLINE +#define SECUREC_INLINE static inline +#else +#define SECUREC_INLINE static +#endif +#endif + +#ifndef SECUREC_WARP_OUTPUT +#if SECUREC_IN_KERNEL +#define SECUREC_WARP_OUTPUT 1 +#else +#define SECUREC_WARP_OUTPUT 0 +#endif +#endif + +#ifndef SECUREC_STREAM_STDIN +#define SECUREC_STREAM_STDIN stdin +#endif + +#define SECUREC_MUL_SIXTEEN(x) ((x) << 4U) +#define SECUREC_MUL_EIGHT(x) ((x) << 3U) +#define SECUREC_MUL_TEN(x) ((((x) << 2U) + (x)) << 1U) +/* Limited format input and output width, use signed integer */ +#define SECUREC_MAX_WIDTH_LEN_DIV_TEN 21474836 +#define SECUREC_MAX_WIDTH_LEN (SECUREC_MAX_WIDTH_LEN_DIV_TEN * 10) +/* Is the x multiplied by 10 greater than */ +#define SECUREC_MUL_TEN_ADD_BEYOND_MAX(x) (((x) > SECUREC_MAX_WIDTH_LEN_DIV_TEN)) + +#define SECUREC_FLOAT_BUFSIZE (309 + 40) /* Max length of double value */ +#define SECUREC_FLOAT_BUFSIZE_LB (4932 + 40) /* Max length of long double value */ +#define SECUREC_FLOAT_DEFAULT_PRECISION 6 + +/* This macro does not handle pointer equality or integer overflow */ +#define SECUREC_MEMORY_NO_OVERLAP(dest, src, count) \ + (((src) < (dest) && ((const char *)(src) + (count)) <= (char *)(dest)) || \ + ((dest) < (src) && ((char *)(dest) + (count)) <= (const char *)(src))) + +#define SECUREC_MEMORY_IS_OVERLAP(dest, src, count) \ + (((src) < (dest) && ((const char *)(src) + (count)) > (char *)(dest)) || \ + ((dest) < (src) && ((char *)(dest) + (count)) > (const char *)(src))) + +/* + * Check whether the strings overlap, len is the length of the string not include terminator + * Length is related to data type char or wchar , do not force conversion of types + */ +#define SECUREC_STRING_NO_OVERLAP(dest, src, len) \ + (((src) < (dest) && ((src) + (len)) < (dest)) || \ + ((dest) < (src) && ((dest) + (len)) < (src))) + +/* + * Check whether the strings overlap for strcpy wcscpy function, dest len and src Len are not include terminator + * Length is related to data type char or wchar , do not force conversion of types + */ +#define SECUREC_STRING_IS_OVERLAP(dest, src, len) \ + (((src) < (dest) && ((src) + (len)) >= (dest)) || \ + ((dest) < (src) && ((dest) + (len)) >= (src))) + +/* + * Check whether the strings overlap for strcat wcscat function, dest len and src Len are not include terminator + * Length is related to data type char or wchar , do not force conversion of types + */ +#define SECUREC_CAT_STRING_IS_OVERLAP(dest, destLen, src, srcLen) \ + (((dest) < (src) && ((dest) + (destLen) + (srcLen)) >= (src)) || \ + ((src) < (dest) && ((src) + (srcLen)) >= (dest))) + +#if SECUREC_HAVE_STRNLEN +#define SECUREC_CALC_STR_LEN(str, maxLen, outLen) do { \ + *(outLen) = strnlen((str), (maxLen)); \ +} SECUREC_WHILE_ZERO +#define SECUREC_CALC_STR_LEN_OPT(str, maxLen, outLen) do { \ + if ((maxLen) > 8) { \ + /* Optimization or len less then 8 */ \ + if (*((str) + 0) == '\0') { \ + *(outLen) = 0; \ + } else if (*((str) + 1) == '\0') { \ + *(outLen) = 1; \ + } else if (*((str) + 2) == '\0') { \ + *(outLen) = 2; \ + } else if (*((str) + 3) == '\0') { \ + *(outLen) = 3; \ + } else if (*((str) + 4) == '\0') { \ + *(outLen) = 4; \ + } else if (*((str) + 5) == '\0') { \ + *(outLen) = 5; \ + } else if (*((str) + 6) == '\0') { \ + *(outLen) = 6; \ + } else if (*((str) + 7) == '\0') { \ + *(outLen) = 7; \ + } else if (*((str) + 8) == '\0') { \ + /* Optimization with a length of 8 */ \ + *(outLen) = 8; \ + } else { \ + /* The offset is 8 because the performance of 8 byte alignment is high */ \ + *(outLen) = 8 + strnlen((str) + 8, (maxLen) - 8); \ + } \ + } else { \ + SECUREC_CALC_STR_LEN((str), (maxLen), (outLen)); \ + } \ +} SECUREC_WHILE_ZERO +#else +#define SECUREC_CALC_STR_LEN(str, maxLen, outLen) do { \ + const char *strEnd_ = (const char *)(str); \ + size_t availableSize_ = (size_t)(maxLen); \ + while (availableSize_ > 0 && *strEnd_ != '\0') { \ + --availableSize_; \ + ++strEnd_; \ + } \ + *(outLen) = (size_t)(strEnd_ - (str)); \ +} SECUREC_WHILE_ZERO +#define SECUREC_CALC_STR_LEN_OPT SECUREC_CALC_STR_LEN +#endif + +#define SECUREC_CALC_WSTR_LEN(str, maxLen, outLen) do { \ + const wchar_t *strEnd_ = (const wchar_t *)(str); \ + size_t len_ = 0; \ + while (len_ < (maxLen) && *strEnd_ != L'\0') { \ + ++len_; \ + ++strEnd_; \ + } \ + *(outLen) = len_; \ +} SECUREC_WHILE_ZERO + +/* + * Performance optimization, product may disable inline function. + * Using function pointer for MEMSET to prevent compiler optimization when cleaning up memory. + */ +#ifdef SECUREC_USE_ASM +#define SECUREC_MEMSET_FUNC_OPT memset_opt +#define SECUREC_MEMCPY_FUNC_OPT memcpy_opt +#else +#define SECUREC_MEMSET_FUNC_OPT memset +#define SECUREC_MEMCPY_FUNC_OPT memcpy +#endif + +#define SECUREC_MEMCPY_WARP_OPT(dest, src, count) (void)SECUREC_MEMCPY_FUNC_OPT((dest), (src), (count)) + +#ifndef SECUREC_MEMSET_BARRIER +#if defined(__GNUC__) +/* Can be turned off for scenarios that do not use memory barrier */ +#define SECUREC_MEMSET_BARRIER 1 +#else +#define SECUREC_MEMSET_BARRIER 0 +#endif +#endif + +#ifndef SECUREC_MEMSET_INDIRECT_USE +/* Can be turned off for scenarios that do not allow pointer calls */ +#define SECUREC_MEMSET_INDIRECT_USE 1 +#endif + +#if SECUREC_MEMSET_BARRIER +#define SECUREC_MEMORY_BARRIER(dest) __asm__ __volatile__("": : "r"(dest) : "memory") +#else +#define SECUREC_MEMORY_BARRIER(dest) +#endif + +#if SECUREC_MEMSET_BARRIER +#define SECUREC_MEMSET_PREVENT_DSE(dest, value, count) do { \ + (void)SECUREC_MEMSET_FUNC_OPT(dest, value, count); \ + SECUREC_MEMORY_BARRIER(dest); \ +} SECUREC_WHILE_ZERO +#elif SECUREC_MEMSET_INDIRECT_USE +#define SECUREC_MEMSET_PREVENT_DSE(dest, value, count) do { \ + void *(* const volatile fn_)(void *s_, int c_, size_t n_) = SECUREC_MEMSET_FUNC_OPT; \ + (void)(*fn_)((dest), (value), (count)); \ +} SECUREC_WHILE_ZERO +#else +#define SECUREC_MEMSET_PREVENT_DSE(dest, value, count) (void)SECUREC_MEMSET_FUNC_OPT((dest), (value), (count)) +#endif + +#ifdef SECUREC_FORMAT_OUTPUT_INPUT +#if defined(SECUREC_COMPATIBLE_WIN_FORMAT) || defined(__ARMCC_VERSION) +typedef __int64 SecInt64; +typedef unsigned __int64 SecUnsignedInt64; +#if defined(__ARMCC_VERSION) +typedef unsigned int SecUnsignedInt32; +#else +typedef unsigned __int32 SecUnsignedInt32; +#endif +#else +typedef unsigned int SecUnsignedInt32; +typedef long long SecInt64; +typedef unsigned long long SecUnsignedInt64; +#endif + +#ifdef SECUREC_FOR_WCHAR +#if defined(SECUREC_VXWORKS_PLATFORM) && !defined(__WINT_TYPE__) +typedef wchar_t wint_t; +#endif +#ifndef WEOF +#define WEOF ((wchar_t)(-1)) +#endif +#define SECUREC_CHAR(x) L ## x +typedef wchar_t SecChar; +typedef wchar_t SecUnsignedChar; +typedef wint_t SecInt; +typedef wint_t SecUnsignedInt; +#else /* no SECUREC_FOR_WCHAR */ +#define SECUREC_CHAR(x) (x) +typedef char SecChar; +typedef unsigned char SecUnsignedChar; +typedef int SecInt; +typedef unsigned int SecUnsignedInt; +#endif +#endif + +/* + * Determine whether the address is 8-byte aligned + * Some systems do not have uintptr_t type, so use NULL to clear tool alarm 507 + */ +#define SECUREC_ADDR_ALIGNED_8(addr) ((((size_t)(addr)) & 7U) == 0) /* Use 7 to check aligned 8 */ + +/* + * If you define the memory allocation function, you need to define the function prototype. + * You can define this macro as a header file. + */ +#if defined(SECUREC_MALLOC_PROTOTYPE) +SECUREC_MALLOC_PROTOTYPE +#endif + +#ifndef SECUREC_MALLOC +#define SECUREC_MALLOC(x) malloc((size_t)(x)) +#endif + +#ifndef SECUREC_FREE +#define SECUREC_FREE(x) free((void *)(x)) +#endif + +/* Improve performance with struct assignment, buf1 is not defined to avoid tool false positive */ +#define SECUREC_COPY_VALUE_BY_STRUCT(dest, src, n) do { \ + *(SecStrBuf##n *)(void *)(dest) = *(const SecStrBuf##n *)(const void *)(src); \ +} SECUREC_WHILE_ZERO + +typedef struct { + unsigned char buf[2]; /* Performance optimization code structure assignment length 2 bytes */ +} SecStrBuf2; +typedef struct { + unsigned char buf[3]; /* Performance optimization code structure assignment length 3 bytes */ +} SecStrBuf3; +typedef struct { + unsigned char buf[4]; /* Performance optimization code structure assignment length 4 bytes */ +} SecStrBuf4; +typedef struct { + unsigned char buf[5]; /* Performance optimization code structure assignment length 5 bytes */ +} SecStrBuf5; +typedef struct { + unsigned char buf[6]; /* Performance optimization code structure assignment length 6 bytes */ +} SecStrBuf6; +typedef struct { + unsigned char buf[7]; /* Performance optimization code structure assignment length 7 bytes */ +} SecStrBuf7; +typedef struct { + unsigned char buf[8]; /* Performance optimization code structure assignment length 8 bytes */ +} SecStrBuf8; +typedef struct { + unsigned char buf[9]; /* Performance optimization code structure assignment length 9 bytes */ +} SecStrBuf9; +typedef struct { + unsigned char buf[10]; /* Performance optimization code structure assignment length 10 bytes */ +} SecStrBuf10; +typedef struct { + unsigned char buf[11]; /* Performance optimization code structure assignment length 11 bytes */ +} SecStrBuf11; +typedef struct { + unsigned char buf[12]; /* Performance optimization code structure assignment length 12 bytes */ +} SecStrBuf12; +typedef struct { + unsigned char buf[13]; /* Performance optimization code structure assignment length 13 bytes */ +} SecStrBuf13; +typedef struct { + unsigned char buf[14]; /* Performance optimization code structure assignment length 14 bytes */ +} SecStrBuf14; +typedef struct { + unsigned char buf[15]; /* Performance optimization code structure assignment length 15 bytes */ +} SecStrBuf15; +typedef struct { + unsigned char buf[16]; /* Performance optimization code structure assignment length 16 bytes */ +} SecStrBuf16; +typedef struct { + unsigned char buf[17]; /* Performance optimization code structure assignment length 17 bytes */ +} SecStrBuf17; +typedef struct { + unsigned char buf[18]; /* Performance optimization code structure assignment length 18 bytes */ +} SecStrBuf18; +typedef struct { + unsigned char buf[19]; /* Performance optimization code structure assignment length 19 bytes */ +} SecStrBuf19; +typedef struct { + unsigned char buf[20]; /* Performance optimization code structure assignment length 20 bytes */ +} SecStrBuf20; +typedef struct { + unsigned char buf[21]; /* Performance optimization code structure assignment length 21 bytes */ +} SecStrBuf21; +typedef struct { + unsigned char buf[22]; /* Performance optimization code structure assignment length 22 bytes */ +} SecStrBuf22; +typedef struct { + unsigned char buf[23]; /* Performance optimization code structure assignment length 23 bytes */ +} SecStrBuf23; +typedef struct { + unsigned char buf[24]; /* Performance optimization code structure assignment length 24 bytes */ +} SecStrBuf24; +typedef struct { + unsigned char buf[25]; /* Performance optimization code structure assignment length 25 bytes */ +} SecStrBuf25; +typedef struct { + unsigned char buf[26]; /* Performance optimization code structure assignment length 26 bytes */ +} SecStrBuf26; +typedef struct { + unsigned char buf[27]; /* Performance optimization code structure assignment length 27 bytes */ +} SecStrBuf27; +typedef struct { + unsigned char buf[28]; /* Performance optimization code structure assignment length 28 bytes */ +} SecStrBuf28; +typedef struct { + unsigned char buf[29]; /* Performance optimization code structure assignment length 29 bytes */ +} SecStrBuf29; +typedef struct { + unsigned char buf[30]; /* Performance optimization code structure assignment length 30 bytes */ +} SecStrBuf30; +typedef struct { + unsigned char buf[31]; /* Performance optimization code structure assignment length 31 bytes */ +} SecStrBuf31; +typedef struct { + unsigned char buf[32]; /* Performance optimization code structure assignment length 32 bytes */ +} SecStrBuf32; +typedef struct { + unsigned char buf[33]; /* Performance optimization code structure assignment length 33 bytes */ +} SecStrBuf33; +typedef struct { + unsigned char buf[34]; /* Performance optimization code structure assignment length 34 bytes */ +} SecStrBuf34; +typedef struct { + unsigned char buf[35]; /* Performance optimization code structure assignment length 35 bytes */ +} SecStrBuf35; +typedef struct { + unsigned char buf[36]; /* Performance optimization code structure assignment length 36 bytes */ +} SecStrBuf36; +typedef struct { + unsigned char buf[37]; /* Performance optimization code structure assignment length 37 bytes */ +} SecStrBuf37; +typedef struct { + unsigned char buf[38]; /* Performance optimization code structure assignment length 38 bytes */ +} SecStrBuf38; +typedef struct { + unsigned char buf[39]; /* Performance optimization code structure assignment length 39 bytes */ +} SecStrBuf39; +typedef struct { + unsigned char buf[40]; /* Performance optimization code structure assignment length 40 bytes */ +} SecStrBuf40; +typedef struct { + unsigned char buf[41]; /* Performance optimization code structure assignment length 41 bytes */ +} SecStrBuf41; +typedef struct { + unsigned char buf[42]; /* Performance optimization code structure assignment length 42 bytes */ +} SecStrBuf42; +typedef struct { + unsigned char buf[43]; /* Performance optimization code structure assignment length 43 bytes */ +} SecStrBuf43; +typedef struct { + unsigned char buf[44]; /* Performance optimization code structure assignment length 44 bytes */ +} SecStrBuf44; +typedef struct { + unsigned char buf[45]; /* Performance optimization code structure assignment length 45 bytes */ +} SecStrBuf45; +typedef struct { + unsigned char buf[46]; /* Performance optimization code structure assignment length 46 bytes */ +} SecStrBuf46; +typedef struct { + unsigned char buf[47]; /* Performance optimization code structure assignment length 47 bytes */ +} SecStrBuf47; +typedef struct { + unsigned char buf[48]; /* Performance optimization code structure assignment length 48 bytes */ +} SecStrBuf48; +typedef struct { + unsigned char buf[49]; /* Performance optimization code structure assignment length 49 bytes */ +} SecStrBuf49; +typedef struct { + unsigned char buf[50]; /* Performance optimization code structure assignment length 50 bytes */ +} SecStrBuf50; +typedef struct { + unsigned char buf[51]; /* Performance optimization code structure assignment length 51 bytes */ +} SecStrBuf51; +typedef struct { + unsigned char buf[52]; /* Performance optimization code structure assignment length 52 bytes */ +} SecStrBuf52; +typedef struct { + unsigned char buf[53]; /* Performance optimization code structure assignment length 53 bytes */ +} SecStrBuf53; +typedef struct { + unsigned char buf[54]; /* Performance optimization code structure assignment length 54 bytes */ +} SecStrBuf54; +typedef struct { + unsigned char buf[55]; /* Performance optimization code structure assignment length 55 bytes */ +} SecStrBuf55; +typedef struct { + unsigned char buf[56]; /* Performance optimization code structure assignment length 56 bytes */ +} SecStrBuf56; +typedef struct { + unsigned char buf[57]; /* Performance optimization code structure assignment length 57 bytes */ +} SecStrBuf57; +typedef struct { + unsigned char buf[58]; /* Performance optimization code structure assignment length 58 bytes */ +} SecStrBuf58; +typedef struct { + unsigned char buf[59]; /* Performance optimization code structure assignment length 59 bytes */ +} SecStrBuf59; +typedef struct { + unsigned char buf[60]; /* Performance optimization code structure assignment length 60 bytes */ +} SecStrBuf60; +typedef struct { + unsigned char buf[61]; /* Performance optimization code structure assignment length 61 bytes */ +} SecStrBuf61; +typedef struct { + unsigned char buf[62]; /* Performance optimization code structure assignment length 62 bytes */ +} SecStrBuf62; +typedef struct { + unsigned char buf[63]; /* Performance optimization code structure assignment length 63 bytes */ +} SecStrBuf63; +typedef struct { + unsigned char buf[64]; /* Performance optimization code structure assignment length 64 bytes */ +} SecStrBuf64; + +/* + * User can change the error handler by modify the following definition, + * such as logging the detail error in file. + */ +#if defined(_DEBUG) || defined(DEBUG) +#if defined(SECUREC_ERROR_HANDLER_BY_ASSERT) +#define SECUREC_ERROR_INVALID_PARAMTER(msg) assert(msg "invalid argument" == NULL) +#define SECUREC_ERROR_INVALID_RANGE(msg) assert(msg "invalid dest buffer size" == NULL) +#define SECUREC_ERROR_BUFFER_OVERLAP(msg) assert(msg "buffer overlap" == NULL) +#elif defined(SECUREC_ERROR_HANDLER_BY_PRINTF) +#if SECUREC_IN_KERNEL +#define SECUREC_ERROR_INVALID_PARAMTER(msg) printk("%s invalid argument\n", msg) +#define SECUREC_ERROR_INVALID_RANGE(msg) printk("%s invalid dest buffer size\n", msg) +#define SECUREC_ERROR_BUFFER_OVERLAP(msg) printk("%s buffer overlap\n", msg) +#else +#define SECUREC_ERROR_INVALID_PARAMTER(msg) printf("%s invalid argument\n", msg) +#define SECUREC_ERROR_INVALID_RANGE(msg) printf("%s invalid dest buffer size\n", msg) +#define SECUREC_ERROR_BUFFER_OVERLAP(msg) printf("%s buffer overlap\n", msg) +#endif +#elif defined(SECUREC_ERROR_HANDLER_BY_FILE_LOG) +#define SECUREC_ERROR_INVALID_PARAMTER(msg) LogSecureCRuntimeError(msg " EINVAL\n") +#define SECUREC_ERROR_INVALID_RANGE(msg) LogSecureCRuntimeError(msg " ERANGE\n") +#define SECUREC_ERROR_BUFFER_OVERLAP(msg) LogSecureCRuntimeError(msg " EOVERLAP\n") +#endif +#endif + +/* Default handler is none */ +#ifndef SECUREC_ERROR_INVALID_PARAMTER +#define SECUREC_ERROR_INVALID_PARAMTER(msg) +#endif +#ifndef SECUREC_ERROR_INVALID_RANGE +#define SECUREC_ERROR_INVALID_RANGE(msg) +#endif +#ifndef SECUREC_ERROR_BUFFER_OVERLAP +#define SECUREC_ERROR_BUFFER_OVERLAP(msg) +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +/* Assembly language memory copy and memory set for X86 or MIPS ... */ +#ifdef SECUREC_USE_ASM +void *memcpy_opt(void *dest, const void *src, size_t n); +void *memset_opt(void *s, int c, size_t n); +#endif + +#if defined(SECUREC_ERROR_HANDLER_BY_FILE_LOG) +void LogSecureCRuntimeError(const char *errDetail); +#endif + +#ifdef __cplusplus +} +#endif /* __cplusplus */ +#endif + diff --git a/dynolog_npu/scripts/apply_dyno_patches.sh b/msmonitor/scripts/apply_dyno_patches.sh similarity index 100% rename from dynolog_npu/scripts/apply_dyno_patches.sh rename to msmonitor/scripts/apply_dyno_patches.sh diff --git a/dynolog_npu/scripts/build.sh b/msmonitor/scripts/build.sh similarity index 71% rename from dynolog_npu/scripts/build.sh rename to msmonitor/scripts/build.sh index aa3508e14faa6bfea06afe0cd3083ad1a5317037..d51d2191334e49f2752e08f94182ab55f163ee51 100644 --- a/dynolog_npu/scripts/build.sh +++ b/msmonitor/scripts/build.sh @@ -1,5 +1,31 @@ #!/bin/bash set -e +export BUILD_PROMETHEUS=0 +export BUILD_TENSORBOARD=1 +export USE_TENSORBOARD="OFF" + +# 设置 CARGO_HOME +export CARGO_HOME="/root/.cargo" + +# 创建 Cargo 配置目录 +mkdir -p ${CARGO_HOME} + +# 创建 config.toml(安全编译选项) +cat > ${CARGO_HOME}/config.toml << EOF +[net] +git-fetch-with-cli = true + +[build] +rustflags = [ + "-C", "relocation_model=pie", + "-C", "link-args=-Wl,-z,now", + "-C", "link-args=-Wl,-z,relro", + "-C", "strip=symbols", + "-C", "overflow_checks", + "-C", "link-args=-static-libgcc", + "-C", "link-args=-static-libstdc++" +] +EOF check_gcc_version() { if ! command -v gcc >/dev/null 2>&1; then @@ -30,8 +56,8 @@ check_rust_version() { local RUST_MAJOR=$(echo $RUST_VERSION | cut -d. -f1) local RUST_MINOR=$(echo $RUST_VERSION | cut -d. -f2) - if [ "$RUST_MAJOR" -lt 1 ] || ([ "$RUST_MAJOR" -eq 1 ] && [ "$RUST_MINOR" -lt 56 ]); then - echo "ERROR: Rust version must be greater than or equal to 1.56.0" + if [ "$RUST_MAJOR" -lt 1 ] || ([ "$RUST_MAJOR" -eq 1 ] && [ "$RUST_MINOR" -lt 81 ]); then + echo "ERROR: Rust version must be greater than or equal to 1.81" echo "Current Rust version: $RUST_VERSION" return 1 fi @@ -42,7 +68,7 @@ check_rust_version() { update_and_checkout_submodule() { DYNLOG_COMMIT_ID="a9b6aeddcd6363252f5388cb0dd942981a09a24b" - git submodule update --init --recursive + git submodule update --init if [ $? -ne 0 ]; then echo "ERROR: update git submodule failed" return 1 @@ -50,6 +76,15 @@ update_and_checkout_submodule() { cd ./third_party/dynolog git checkout ${DYNLOG_COMMIT_ID} + + if [ "${BUILD_TENSORBOARD}" -ne 0 ]; then + if [ ! -d "./third_party/tensorboard_logger" ]; then + git submodule add https://github.com/RustingSword/tensorboard_logger.git ./third_party/tensorboard_logger + fi + USE_TENSORBOARD="ON" + fi + git submodule update --init --recursive + if [ $? -ne 0 ]; then echo "ERROR: switch to dynolog specified commit failed" cd .. @@ -85,6 +120,8 @@ check_rust_version echo "------------------ Update and checkout submodule -------------------" update_and_checkout_submodule +cp -r dynolog_npu/cmake third_party/dynolog + echo "------------------ Generate patch for Ascend -----------------------" bash scripts/gen_dyno_patches.sh @@ -98,11 +135,21 @@ if [ -z "$PACKAGE_TYPE" ]; then bash scripts/build.sh echo "Build dynolog success without packaging" elif [ "$PACKAGE_TYPE" = "deb" ]; then + ARCHITECTURE=$(uname -m) + CONTROL_FILE="scripts/debian/control" + ARCH="amd64" + if [[ "$ARCHITECTURE" == "aarch64" ]]; then + sed -i 's/^Architecture: .*/Architecture: arm64/' "$CONTROL_FILE" + ARCH="arm64" + echo "dpkg Architecture set to arm64" + fi + export ARCH=$ARCH bash scripts/debian/make_deb.sh + unset ARCH mv dynolog_*.deb ../../ echo "Build dynolog deb package success" elif [ "$PACKAGE_TYPE" = "rpm" ]; then bash scripts/rpm/make_rpm.sh - mv dynolog_*.rpm ../../ + mv dynolog-*.rpm ../../ echo "Build dynolog rpm package success" fi diff --git a/dynolog_npu/scripts/gen_dyno_patches.sh b/msmonitor/scripts/gen_dyno_patches.sh similarity index 100% rename from dynolog_npu/scripts/gen_dyno_patches.sh rename to msmonitor/scripts/gen_dyno_patches.sh diff --git a/msmonitor/scripts/run_tests.sh b/msmonitor/scripts/run_tests.sh new file mode 100644 index 0000000000000000000000000000000000000000..78329da25d0345bec509165e088301637f4046b6 --- /dev/null +++ b/msmonitor/scripts/run_tests.sh @@ -0,0 +1,64 @@ +#!/bin/bash +# 该脚本用于CI环境,执行系统测试用例 +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. + +# 严格模式,任何错误都会导致脚本退出 +set -e + +# 测试目录 +ST_DIR="test/st" + +# 当前目录检查 +if [[ $(basename $(pwd)) != "msmonitor" ]]; then + if [[ -d "msmonitor" ]]; then + echo "进入msmonitor目录" + cd msmonitor + else + echo "错误: 请在msmonitor目录或其父目录下运行此脚本" + exit 1 + fi +fi + +# 设置必要的环境变量 +export LD_LIBRARY_PATH=third_party/dynolog/third_party/prometheus-cpp/_build/lib:$LD_LIBRARY_PATH + +echo "执行系统测试 (test/st 目录)" + +# 检查系统测试目录是否存在 +if [[ ! -d "$ST_DIR" ]]; then + echo "错误: 系统测试目录 $ST_DIR 不存在" + exit 1 +fi + +# 查找所有以test开头的.py文件 +st_files=$(find $ST_DIR -name "test*.py") + +if [[ -z "$st_files" ]]; then + echo "错误: 没有找到测试文件" + exit 1 +fi + +# 执行每个测试文件,遇到失败立即停止 +for test_file in $st_files; do + echo "===============================================" + echo "执行测试: $test_file" + + # 直接执行Python文件 + python "$test_file" + result=$? + + if [ $result -eq 0 ]; then + echo "[通过] 测试成功: $test_file" + else + echo "[失败] 测试失败: $test_file" + echo "===============================================" + echo "测试执行中止: 发现失败的测试" + exit 1 + fi +done + +echo "===============================================" +echo "系统测试执行完毕" +echo "[成功] 所有测试通过" +exit 0 \ No newline at end of file diff --git a/msmonitor/test/st/gen_tls_certs.sh b/msmonitor/test/st/gen_tls_certs.sh new file mode 100644 index 0000000000000000000000000000000000000000..5ae6536beddec7a964ddd3c6b8d01ad1de89fb8c --- /dev/null +++ b/msmonitor/test/st/gen_tls_certs.sh @@ -0,0 +1,92 @@ +#!/bin/bash + +set -e +set -o pipefail + +# Configuration +CERTS_DIR="certs" +DAYS_VALID=3650 +SERVER_CN="localhost" +SERVER_SAN="DNS:localhost,IP:127.0.0.1" +CLIENT_CN="client" + +# Check if openssl is installed +if ! command -v openssl &> /dev/null; then + echo "Error: openssl is not installed. Please install it first." + exit 1 +fi + +# Create certs directory +mkdir -p "$CERTS_DIR" + +# Generate CA root certificate and private key +echo "Generating CA root certificate and private key..." +openssl genrsa -out "$CERTS_DIR/ca.key" 2048 +openssl req -x509 -new -key "$CERTS_DIR/ca.key" -days "$DAYS_VALID" -out "$CERTS_DIR/ca.crt" -subj "/CN=Test-CA" + +# Generate client private key and CSR +echo "Generating client private key and CSR..." +openssl genrsa -out "$CERTS_DIR/client.key" 2048 +openssl req -new -key "$CERTS_DIR/client.key" -out "$CERTS_DIR/client.csr" -subj "/CN=$CLIENT_CN" + +# Create v3.ext for client certificate +echo "Creating v3.ext for client certificate..." +cat > "$CERTS_DIR/v3.ext" < "$CERTS_DIR/server.ext" <= 3.7 ,tensorboard >= 2.11.2 + +### 2. 安装方式 + +#### 2.1 pip 安装(推荐) + +- 现本插件已经上传到 pypi 社区,用户可在 python 环境下直接通过以下 pip 指令进行安装: + ``` + pip install tb-graph-ascend + ``` +- 也可在 pypi 社区上下载离线 whl 包,传输到无法访问公网的环境上离线安装使用。访问[下载链接](https://pypi.org/project/tb-graph-ascend/#files)选择 whl 包进行下载,之后便可使用指令安装(此处{version}为 whl 包实际版本) + ``` + pip install tb-graph_ascend_{version}-py3-none-any.whl + ``` + +#### 2.2 从源代码安装 + +1. 从仓库下载源码并切换到 master 分支: + + ``` + git clone https://gitee.com/ascend/mstt.git -b master + ``` + +2. 进入目录 `plugins/tensorboard-plugins/tb_graph_ascend` 下 +3. 编译前端代码,根据操作系统选取不同指令 + + ``` + cd fe + // 安装前端依赖 + npm install --force + // Windows系统 + npm run buildWin + // 其他可使用cp指令的系统,如Linux或Mac + npm run buildLinux + ``` + + **注意**: 此步骤需要安装 [Node.js](https://nodejs.org/zh-cn/download) 环境 + +4. 回到上级目录直接安装: + ``` + cd ../ + python setup.py develop + ``` + +- 或: 构建 whl 包安装 + ``` + python setup.py bdist_wheel + ``` + 在 `plugins/tensorboard-plugins/tb_graph_ascend/dist` 目录下取出 whl 包,使用以下指令安装(此处{version}为 whl 包实际版本) + ``` + pip install tb-graph_ascend_{version}-py3-none-any.whl + ``` + +### 3. 解析数据说明 + +将通过[msprobe](https://gitee.com/ascend/mstt/tree/master/debug/accuracy_tools/msprobe#10-%E5%88%86%E7%BA%A7%E5%8F%AF%E8%A7%86%E5%8C%96%E6%9E%84%E5%9B%BE%E6%AF%94%E5%AF%B9)工具构图功能采集得到的文件后缀为.vis 的模型结构文件(文件本身为 json 格式)放置于某个文件夹中,路径名称下文称之为 `output_path` + +- E.g. \ + `---output_path` \ + `-----output.vis` \ + `-----output2.vis` + +### 4. 启动方式 + +1. 启动 TensorBoard + + ``` + tensorboard --logdir output_path + ``` + + 注意:确保默认端口 6006 可连通。 + + 如果需要切换端口号需要在尾部加上指定的端口号,如`--port=6007` + + ``` + tensorboard --logdir output_path --port=6007 + ``` + +2. 在浏览器上打开 tensorboard + + 在浏览器中打开 URL: `http://localhost:6006`。 + + 注意:如果`--logdir` 指定目录下的文件太大或太多,请等候,刷新浏览器查看加载结果。 + +3. 建议在本地启动 tensorboard,如果网络浏览器与启动 TensorBoard 的机器不在同一台机器上,需要远程启动,可参照[远程启动方式](#413-远程查看数据),但需用户自行评估**安全风险**。 + +## 三、浏览器查看 + +**注意:本工具不支持同时通过多个浏览器窗口同时访问同一个 TensorBoard 服务,否则会出现页面无法正常显示的情况。** + +### 3.1 主界面 + +![输入图片说明](./doc/images/main-page.png) + +### 3.2 操作方式: + +- **节点双击打开,单击选中。** +- **选中的节点边框呈现蓝色,比对场景下若其存在对应节点,则对应节点边框为浅蓝色。** +- **键盘 WS 根据鼠标位置放大缩小,AD 左右移动。** +- **鼠标滚轮上下移动,鼠标可拖动页面。** +- **比对场景鼠标右键可选中节点,并可展开至对应侧的节点并选中。** + +![输入图片说明](./doc/images/operator-image.png) + +### 3.3 名称搜索 + +![输入图片说明](./doc/images/vis_search_info.png) + +### 3.4 精度筛选/溢出筛选 + +注意:单图场景不存在精度筛选和溢出筛选,下图为双图比对场景。
+ +![输入图片说明](./doc/images/vis_precision_info.png) + +### 3.5 未匹配节点筛选 + +参考匹配说明 ,不符合匹配规则的节点为无匹配节点,颜色标灰。适用于排查两个模型结构差异的场景。
+ +![输入图片说明](./doc/images/vis_unmatch_info.png) + +### 3.6 手动选择节点匹配 + +可通过浏览器界面,通过鼠标选择两个待匹配的灰色节点进行匹配。当前暂不支持真实数据模式。
+如果选中"操作选中节点及其子节点":
+点击匹配后会将两个节点及其子节点按照 Module 名称依次匹配,取消匹配后会将子节点的匹配关系清除。
+否则:
+点击匹配后只会将两个节点进行匹配,取消匹配后会将节点的匹配关系清除 +注意:匹配结束之后,需要点击保存才能持久化到源文件里面 + +![输入图片说明](./doc/images/vis_match_info.png) + +### 3.7 生成匹配配置文件 + +可保存已经已匹配节点的匹配关系到配置文件中,并支持读取配置文件中的数据,进行匹配操作。
+默认保存在当前目录下,文件名为`[当前文件名].vis.config`,每次切换文件都会扫描当前录下的后缀名为.vis.config 配置文件,并更新配置文件列表。 +注意:匹配结束之后,需要点击保存才能持久化到源文件里面 +![输入图片说明](./doc/images/vis_save_match_info.png) + +### 3.8 支持用户自定义精度指标配置 + +![输入图片说明](./doc/images/vis_update_precision.png) + +## 四、附录 + +### 4.1 安全加固建议 + +#### 4.1.1 免责声明 + +本工具为基于 TensorBoard 底座开发的插件,使用本插件需要基于 TensorBoard 运行,请自行关注 TensorBoard 相关安全配置和安全风险。 + +打开本工具时,本工具会对 logdir 目录下的 vis 文件以及其父目录进行安全检查,如果存在安全风险,本工具会展示如下提示信息,询问用户是否继续执行,用户选择继续执行后,可以操作未通过安全检查的文件和目录,用户需要自行承担操作风险。如果用户选择不继续执行,则用户只能操作通过安全检查的文件。 + +![输入图片说明](./doc/images/safe_warning.png) + +#### 4.1.2 TensorBoard 版本说明 + +满足[相关依赖](#1-相关依赖)中要求的 TensorBoard 版本皆可正常使用本插件功能,但为 TensorBoard 本身安全风险考虑,建议使用最新版本 TensorBoard 。 + +#### 4.1.3 远程查看数据 + +如果网络浏览器与启动 TensorBoard 的机器不在同一台机器上, TensorBoard 提供了远程查看数据的指令启动方式,但此种方式会将服务器对应端口在局域网内公开(全零监听),请用户自行关注安全风险。 + +- 在启动指令尾部加上`--bind_all`或`--host={服务器IP}`参数启用远程查看方式,如: + + ``` + tensorboard --logdir output_path --port=6006 --host=xxx.xxx.xxx.xxx + 或 + tensorboard --logdir output_path --port=6006 --bind_all + ``` + +- 在打开浏览器访问界面时,需将 URL 内主机名由`localhost`替换为主机的 ip 地址,如`http://xxx.xxx.xxx.xxx:6006` + +### 4.2 通信矩阵 + +| 序号 | 代码仓 | 功能 | 源设备 | 源 IP | 源端口 | 目的设备 | 目的 IP | 目的端口
(侦听) | 协议 | 端口说明 | 端口配置 | 侦听端口是否可更改 | 所属平面 | 版本 | 特殊场景 | 备注 | +| :--- | :------------------ | :------------------------- | :------------------------------ | :--------------------------------- | :----- | :----------------------- | :------------------------------ | :-------------------- | :--- | :------------------- | :------- | :----------------- | :------- | :------- | :------- | :--- | +| 1 | tensorboard-plugins | TensorBoard 底座前后端通信 | 访问 TensorBoard 浏览器所在机器 | 访问 TensorBoard 浏览器所在机器 ip | | TensorBoard 服务所在机器 | TensorBoard 服务所在服务器的 ip | 6006 | HTTP | tensorboard 服务通信 | `--port` | 可修改 | 业务面 | 所有版本 | 无 | | + +### 4.3 公网地址说明 + +[公网地址说明](./doc/公网地址说明.csv) diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/main-page.png b/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/main-page.png new file mode 100644 index 0000000000000000000000000000000000000000..b8e2a6dbcc5f55f3369406148dfc378890ccdc73 Binary files /dev/null and b/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/main-page.png differ diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/operator-image.png b/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/operator-image.png new file mode 100644 index 0000000000000000000000000000000000000000..b4463c05dc0e6a379d68592ec4129bd397ae0dd6 Binary files /dev/null and b/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/operator-image.png differ diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/safe_warning.png b/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/safe_warning.png new file mode 100644 index 0000000000000000000000000000000000000000..15b14eab1fd7a5710cc540e1126c346f5291fb8a Binary files /dev/null and b/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/safe_warning.png differ diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/vis_match_info.png b/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/vis_match_info.png new file mode 100644 index 0000000000000000000000000000000000000000..5785c81a2f308a4fb87db1c8528262e3b1821932 Binary files /dev/null and b/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/vis_match_info.png differ diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/vis_precision_info.png b/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/vis_precision_info.png new file mode 100644 index 0000000000000000000000000000000000000000..79c6ff77f4fffedfcbaee47767d3f8a4f1b0d5b3 Binary files /dev/null and b/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/vis_precision_info.png differ diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/vis_save_match_info.png b/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/vis_save_match_info.png new file mode 100644 index 0000000000000000000000000000000000000000..3670076eb8dc4cc315cfc89acec3d1d8d739ed6e Binary files /dev/null and b/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/vis_save_match_info.png differ diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/vis_search_info.png b/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/vis_search_info.png new file mode 100644 index 0000000000000000000000000000000000000000..7c51a804862591005725e1c2e1da0ff0ac152df1 Binary files /dev/null and b/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/vis_search_info.png differ diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/vis_unmatch_info.png b/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/vis_unmatch_info.png new file mode 100644 index 0000000000000000000000000000000000000000..5f698d109543e6171e2df28bafa83a09d3dd351d Binary files /dev/null and b/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/vis_unmatch_info.png differ diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/vis_update_precision.png b/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/vis_update_precision.png new file mode 100644 index 0000000000000000000000000000000000000000..b764fc983c0178e6f2f1d77807a6a4635a7dbd9e Binary files /dev/null and b/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/vis_update_precision.png differ diff --git "a/plugins/tensorboard-plugins/tb_graph_ascend/doc/\345\205\254\347\275\221\345\234\260\345\235\200\350\257\264\346\230\216.csv" "b/plugins/tensorboard-plugins/tb_graph_ascend/doc/\345\205\254\347\275\221\345\234\260\345\235\200\350\257\264\346\230\216.csv" new file mode 100644 index 0000000000000000000000000000000000000000..ace10f82da1fc85c01850220d9ec812e9b7ecdce --- /dev/null +++ "b/plugins/tensorboard-plugins/tb_graph_ascend/doc/\345\205\254\347\275\221\345\234\260\345\235\200\350\257\264\346\230\216.csv" @@ -0,0 +1,11 @@ +IPַ/URLַ//ַ,;˵ +http://www.apache.org/licenses/LICENSE-2.0,License +pmail_mindstudio@huawei.com,MindStudioٷ +https://gitee.com/ascend/mstt/tree/master/plugins/tensorboard-plugins/tb_graph_ascend,ֵַ +https://npms.io,npm߹ַ +http://codepen.io/shyndman/pen/,룬ע +https://github.com/webcomponents/shadycss/issues/193,룬ע +http://jsbin.com/temexa/4,룬ע +https://fonts.googleapis.com/,룬ʽļ +https://developer.mozilla.org/,룬ע +https://github.com/vaadin/vaadin-time-picker/issues/145,룬ע diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/prettier.json b/plugins/tensorboard-plugins/tb_graph_ascend/fe/.prettierrc similarity index 83% rename from plugins/tensorboard-plugins/tb_plugin/fe/prettier.json rename to plugins/tensorboard-plugins/tb_graph_ascend/fe/.prettierrc index ef5789da9458a66e7dacc1dfdeeb764642331734..e3d2acb00457084b2f6cccafb8c95740e0344485 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/prettier.json +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/.prettierrc @@ -8,5 +8,6 @@ "useTabs": false, "trailingComma": "all", "proseWrap": "always", - "endOfLine": "lf" + "endOfLine": "lf", + "printWidth": 120 } diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/Accuracy/entity.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/index.html similarity index 50% rename from plugins/tensorboard-plugins/tb_plugin/fe/src/components/Accuracy/entity.ts rename to plugins/tensorboard-plugins/tb_graph_ascend/fe/index.html index 270c4cb6535633f9a03e5b9fe02dca6121cd3ba7..fbb070b3715607b970b860e2f3f13e6e210581e5 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/Accuracy/entity.ts +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/index.html @@ -1,30 +1,30 @@ -/*--------------------------------------------------------------------------------------------- - * Copyright (c) Microsoft Corporation. All rights reserved. + + + -export interface FileInfo { - id: number; - fileName: string; - fileContent: string; - checked: boolean; - lossTag: string; - iterTag: string; - iters: number[]; - losses: number[][]; - iterLosses: { [iter: number]: number }; -} + + + + + + + + + \ No newline at end of file diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/package-lock.json b/plugins/tensorboard-plugins/tb_graph_ascend/fe/package-lock.json new file mode 100644 index 0000000000000000000000000000000000000000..79dcfed452d4fc0d50e094760b37ce091c9cc030 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/package-lock.json @@ -0,0 +1,6479 @@ +{ + "name": "tb-graph-ascend", + "version": "0.1.0", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "tb-graph-ascend", + "version": "0.1.0", + "dependencies": { + "@jridgewell/gen-mapping": "^0.3.12", + "@polymer/decorators": "^3.0.0", + "@polymer/iron-collapse": "^3.0.1", + "@polymer/iron-icon": "^3.0.1", + "@polymer/paper-button": "^3.0.1", + "@polymer/paper-checkbox": "^3.1.0", + "@polymer/paper-dialog": "^3.0.1", + "@polymer/paper-tooltip": "^3.0.1", + "@polymer/polymer": "^3.5.1", + "@vaadin/button": "24.6.5", + "@vaadin/checkbox": "24.6.5", + "@vaadin/checkbox-group": "^24.6.5", + "@vaadin/combo-box": "24.6.5", + "@vaadin/confirm-dialog": "24.6.5", + "@vaadin/context-menu": "24.6.5", + "@vaadin/details": "24.6.5", + "@vaadin/grid": "24.6.5", + "@vaadin/icon": "24.6.5", + "@vaadin/icons": "24.6.5", + "@vaadin/notification": "24.6.5", + "@vaadin/progress-bar": "24.6.5", + "@vaadin/select": "24.6.5", + "@vaadin/tabs": "24.6.5", + "@vaadin/tabsheet": "24.6.5", + "@vaadin/text-field": "24.6.5", + "@vaadin/tooltip": "24.6.5", + "axios": "^1.8.4", + "brace-expansion": "^1.1.12", + "clean-webpack-plugin": "^4.0.0", + "cross-env": "^7.0.3", + "css-loader": "^7.1.2", + "d3": "^7.9.0", + "dagre": "^0.8.5", + "form-data": "^4.0.4", + "i18next": "^23.16.8", + "i18next-browser-languagedetector": "^7.2.2", + "lodash": "^4.17.21", + "prettier": "^3.4.2", + "style-loader": "^4.0.0" + }, + "devDependencies": { + "@types/d3": "^7.4.3", + "@types/lodash": "^4.17.20", + "@types/node": "^16.4.13", + "@types/offscreencanvas": "^2019.6.3", + "@types/requirejs": "^2.1.33", + "@types/resize-observer-browser": "^0.1.6", + "@types/three": "^0.131.0", + "html-loader": "^5.1.0", + "html-webpack-plugin": "^5.6.3", + "inline-chunk-html-plugin": "^1.1.1", + "ts-loader": "^9.5.1", + "tslib": "^2.6.2", + "typescript": "^5.4.5", + "webpack": "^5.96.1", + "webpack-cli": "^5.1.4", + "webpack-dev-server": "4.15.1" + } + }, + "node_modules/@babel/runtime": { + "version": "7.28.2", + "resolved": "https://registry.npmmirror.com/@babel/runtime/-/runtime-7.28.2.tgz", + "integrity": "sha512-KHp2IflsnGywDjBWDkR9iEqiWSpc8GIi0lgTT3mOElT0PP1tG26P4tmFI2YvAdzgq9RGyoHZQEIEdZy6Ec5xCA==", + "license": "MIT", + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@discoveryjs/json-ext": { + "version": "0.5.7", + "resolved": "https://registry.npmmirror.com/@discoveryjs/json-ext/-/json-ext-0.5.7.tgz", + "integrity": "sha512-dBVuXR082gk3jsFp7Rd/JI4kytwGHecnCoTtXFb7DB6CNHp4rg5k1bhg0nWdLGLnOV71lmDzGQaLMy8iPLY0pw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10.0.0" + } + }, + "node_modules/@jridgewell/gen-mapping": { + "version": "0.3.12", + "resolved": "https://registry.npmmirror.com/@jridgewell/gen-mapping/-/gen-mapping-0.3.12.tgz", + "integrity": "sha512-OuLGC46TjB5BbN1dH8JULVVZY4WTdkF7tV9Ys6wLL1rubZnCMstOhNHueU5bLCrnRuDhKPDM4g6sw4Bel5Gzqg==", + "license": "MIT", + "dependencies": { + "@jridgewell/sourcemap-codec": "^1.5.0", + "@jridgewell/trace-mapping": "^0.3.24" + } + }, + "node_modules/@jridgewell/resolve-uri": { + "version": "3.1.2", + "resolved": "https://registry.npmmirror.com/@jridgewell/resolve-uri/-/resolve-uri-3.1.2.tgz", + "integrity": "sha512-bRISgCIjP20/tbWSPWMEi54QVPRZExkuD9lJL+UIxUKtwVJA8wW1Trb1jMs1RFXo1CBTNZ/5hpC9QvmKWdopKw==", + "license": "MIT", + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@jridgewell/source-map": { + "version": "0.3.6", + "resolved": "https://registry.npmmirror.com/@jridgewell/source-map/-/source-map-0.3.6.tgz", + "integrity": "sha512-1ZJTZebgqllO79ue2bm3rIGud/bOe0pP5BjSRCRxxYkEZS8STV7zN84UBbiYu7jy+eCKSnVIUgoWWE/tt+shMQ==", + "license": "MIT", + "dependencies": { + "@jridgewell/gen-mapping": "^0.3.5", + "@jridgewell/trace-mapping": "^0.3.25" + } + }, + "node_modules/@jridgewell/sourcemap-codec": { + "version": "1.5.0", + "resolved": "https://registry.npmmirror.com/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.0.tgz", + "integrity": "sha512-gv3ZRaISU3fjPAgNsriBRqGWQL6quFx04YMPW/zD8XMLsU32mhCCbfbO6KZFLjvYpCZ8zyDEgqsgf+PwPaM7GQ==", + "license": "MIT" + }, + "node_modules/@jridgewell/trace-mapping": { + "version": "0.3.25", + "resolved": "https://registry.npmmirror.com/@jridgewell/trace-mapping/-/trace-mapping-0.3.25.tgz", + "integrity": "sha512-vNk6aEwybGtawWmy/PzwnGDOjCkLWSD2wqvjGGAgOAwCGWySYXfYoxt00IJkTF+8Lb57DwOb3Aa0o9CApepiYQ==", + "license": "MIT", + "dependencies": { + "@jridgewell/resolve-uri": "^3.1.0", + "@jridgewell/sourcemap-codec": "^1.4.14" + } + }, + "node_modules/@leichtgewicht/ip-codec": { + "version": "2.0.5", + "resolved": "https://registry.npmmirror.com/@leichtgewicht/ip-codec/-/ip-codec-2.0.5.tgz", + "integrity": "sha512-Vo+PSpZG2/fmgmiNzYK9qWRh8h/CHrwD0mo1h1DzL4yzHNSfWYujGTYsWGreD000gcgmZ7K4Ys6Tx9TxtsKdDw==", + "dev": true, + "license": "MIT" + }, + "node_modules/@lit-labs/ssr-dom-shim": { + "version": "1.3.0", + "resolved": "https://registry.npmmirror.com/@lit-labs/ssr-dom-shim/-/ssr-dom-shim-1.3.0.tgz", + "integrity": "sha512-nQIWonJ6eFAvUUrSlwyHDm/aE8PBDu5kRpL0vHMg6K8fK3Diq1xdPjTnsJSwxABhaZ+5eBi1btQB5ShUTKo4nQ==", + "license": "BSD-3-Clause" + }, + "node_modules/@lit/reactive-element": { + "version": "2.0.4", + "resolved": "https://registry.npmmirror.com/@lit/reactive-element/-/reactive-element-2.0.4.tgz", + "integrity": "sha512-GFn91inaUa2oHLak8awSIigYz0cU0Payr1rcFsrkf5OJ5eSPxElyZfKh0f2p9FsTiZWXQdWGJeXZICEfXXYSXQ==", + "license": "BSD-3-Clause", + "dependencies": { + "@lit-labs/ssr-dom-shim": "^1.2.0" + } + }, + "node_modules/@open-wc/dedupe-mixin": { + "version": "1.4.0", + "resolved": "https://registry.npmmirror.com/@open-wc/dedupe-mixin/-/dedupe-mixin-1.4.0.tgz", + "integrity": "sha512-Sj7gKl1TLcDbF7B6KUhtvr+1UCxdhMbNY5KxdU5IfMFWqL8oy1ZeAcCANjoB1TL0AJTcPmcCFsCbHf8X2jGDUA==", + "license": "MIT" + }, + "node_modules/@polymer/decorators": { + "version": "3.0.0", + "resolved": "https://registry.npmmirror.com/@polymer/decorators/-/decorators-3.0.0.tgz", + "integrity": "sha512-qh+VID9nDV9q3ABvIfWgm7/+udl7v2HKsMLPXFm8tj1fI7qr7yWJMFwS3xWBkMmuNPtmkS8MDP0vqLAQIEOWzg==", + "license": "BSD-3-Clause", + "dependencies": { + "@polymer/polymer": "^3.0.5" + } + }, + "node_modules/@polymer/font-roboto": { + "version": "3.0.2", + "resolved": "https://registry.npmmirror.com/@polymer/font-roboto/-/font-roboto-3.0.2.tgz", + "integrity": "sha512-tx5TauYSmzsIvmSqepUPDYbs4/Ejz2XbZ1IkD7JEGqkdNUJlh+9KU85G56Tfdk/xjEZ8zorFfN09OSwiMrIQWA==", + "license": "BSD-3-Clause" + }, + "node_modules/@polymer/iron-a11y-keys-behavior": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/@polymer/iron-a11y-keys-behavior/-/iron-a11y-keys-behavior-3.0.1.tgz", + "integrity": "sha512-lnrjKq3ysbBPT/74l0Fj0U9H9C35Tpw2C/tpJ8a+5g8Y3YJs1WSZYnEl1yOkw6sEyaxOq/1DkzH0+60gGu5/PQ==", + "license": "BSD-3-Clause", + "dependencies": { + "@polymer/polymer": "^3.0.0" + } + }, + "node_modules/@polymer/iron-behaviors": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/@polymer/iron-behaviors/-/iron-behaviors-3.0.1.tgz", + "integrity": "sha512-IMEwcv1lhf1HSQxuyWOUIL0lOBwmeaoSTpgCJeP9IBYnuB1SPQngmfRuHKgK6/m9LQ9F9miC7p3HeQQUdKAE0w==", + "license": "BSD-3-Clause", + "dependencies": { + "@polymer/iron-a11y-keys-behavior": "^3.0.0-pre.26", + "@polymer/polymer": "^3.0.0" + } + }, + "node_modules/@polymer/iron-checked-element-behavior": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/@polymer/iron-checked-element-behavior/-/iron-checked-element-behavior-3.0.1.tgz", + "integrity": "sha512-aDr0cbCNVq49q+pOqa6CZutFh+wWpwPMLpEth9swx+GkAj+gCURhuQkaUYhIo5f2egDbEioR1aeHMnPlU9dQZA==", + "license": "BSD-3-Clause", + "dependencies": { + "@polymer/iron-form-element-behavior": "^3.0.0-pre.26", + "@polymer/iron-validatable-behavior": "^3.0.0-pre.26", + "@polymer/polymer": "^3.0.0" + } + }, + "node_modules/@polymer/iron-collapse": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/@polymer/iron-collapse/-/iron-collapse-3.0.1.tgz", + "integrity": "sha512-yg6q5ZyckQR9VL9VmLrSTkSFXWy9AcJC8KtnD5cg0EHRPbakE8I9S/gVAgeP4nMWV2a/BjLLC4IBygcCMDhAGw==", + "license": "BSD-3-Clause", + "dependencies": { + "@polymer/iron-resizable-behavior": "^3.0.0-pre.26", + "@polymer/polymer": "^3.0.0" + } + }, + "node_modules/@polymer/iron-fit-behavior": { + "version": "3.1.0", + "resolved": "https://registry.npmmirror.com/@polymer/iron-fit-behavior/-/iron-fit-behavior-3.1.0.tgz", + "integrity": "sha512-ABcgIYqrjhmUT8tiuolqeGttF/8pd3sEymUDrO1vXbZu4FWIvoLNndrMDFvs++AGd12Mjf5pYy84NJc6dB8Vig==", + "license": "BSD-3-Clause", + "dependencies": { + "@polymer/polymer": "^3.0.0" + } + }, + "node_modules/@polymer/iron-flex-layout": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/@polymer/iron-flex-layout/-/iron-flex-layout-3.0.1.tgz", + "integrity": "sha512-7gB869czArF+HZcPTVSgvA7tXYFze9EKckvM95NB7SqYF+NnsQyhoXgKnpFwGyo95lUjUW9TFDLUwDXnCYFtkw==", + "license": "BSD-3-Clause", + "dependencies": { + "@polymer/polymer": "^3.0.0" + } + }, + "node_modules/@polymer/iron-form-element-behavior": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/@polymer/iron-form-element-behavior/-/iron-form-element-behavior-3.0.1.tgz", + "integrity": "sha512-G/e2KXyL5AY7mMjmomHkGpgS0uAf4ovNpKhkuUTRnMuMJuf589bKqE85KN4ovE1Tzhv2hJoh/igyD6ekHiYU1A==", + "license": "BSD-3-Clause", + "dependencies": { + "@polymer/polymer": "^3.0.0" + } + }, + "node_modules/@polymer/iron-icon": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/@polymer/iron-icon/-/iron-icon-3.0.1.tgz", + "integrity": "sha512-QLPwirk+UPZNaLnMew9VludXA4CWUCenRewgEcGYwdzVgDPCDbXxy6vRJjmweZobMQv/oVLppT2JZtJFnPxX6g==", + "license": "BSD-3-Clause", + "dependencies": { + "@polymer/iron-flex-layout": "^3.0.0-pre.26", + "@polymer/iron-meta": "^3.0.0-pre.26", + "@polymer/polymer": "^3.0.0" + } + }, + "node_modules/@polymer/iron-meta": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/@polymer/iron-meta/-/iron-meta-3.0.1.tgz", + "integrity": "sha512-pWguPugiLYmWFV9UWxLWzZ6gm4wBwQdDy4VULKwdHCqR7OP7u98h+XDdGZsSlDPv6qoryV/e3tGHlTIT0mbzJA==", + "license": "BSD-3-Clause", + "dependencies": { + "@polymer/polymer": "^3.0.0" + } + }, + "node_modules/@polymer/iron-overlay-behavior": { + "version": "3.0.3", + "resolved": "https://registry.npmmirror.com/@polymer/iron-overlay-behavior/-/iron-overlay-behavior-3.0.3.tgz", + "integrity": "sha512-Q/Fp0+uOQQ145ebZ7T8Cxl4m1tUKYjyymkjcL2rXUm+aDQGb1wA1M1LYxUF5YBqd+9lipE0PTIiYwA2ZL/sznA==", + "license": "BSD-3-Clause", + "dependencies": { + "@polymer/iron-a11y-keys-behavior": "^3.0.0-pre.26", + "@polymer/iron-fit-behavior": "^3.0.0-pre.26", + "@polymer/iron-resizable-behavior": "^3.0.0-pre.26", + "@polymer/polymer": "^3.0.0" + } + }, + "node_modules/@polymer/iron-resizable-behavior": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/@polymer/iron-resizable-behavior/-/iron-resizable-behavior-3.0.1.tgz", + "integrity": "sha512-FyHxRxFspVoRaeZSWpT3y0C9awomb4tXXolIJcZ7RvXhMP632V5lez+ch5G5SwK0LpnAPkg35eB0LPMFv+YMMQ==", + "license": "BSD-3-Clause", + "dependencies": { + "@polymer/polymer": "^3.0.0" + } + }, + "node_modules/@polymer/iron-selector": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/@polymer/iron-selector/-/iron-selector-3.0.1.tgz", + "integrity": "sha512-sBVk2uas6prW0glUe2xEJJYlvxmYzM40Au9OKbfDK2Qekou/fLKcBRyIYI39kuI8zWRaip8f3CI8qXcUHnKb1A==", + "license": "BSD-3-Clause", + "dependencies": { + "@polymer/polymer": "^3.0.0" + } + }, + "node_modules/@polymer/iron-validatable-behavior": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/@polymer/iron-validatable-behavior/-/iron-validatable-behavior-3.0.1.tgz", + "integrity": "sha512-wwpYh6wOa4fNI+jH5EYKC7TVPYQ2OfgQqocWat7GsNWcsblKYhLYbwsvEY5nO0n2xKqNfZzDLrUom5INJN7msQ==", + "license": "BSD-3-Clause", + "dependencies": { + "@polymer/iron-meta": "^3.0.0-pre.26", + "@polymer/polymer": "^3.0.0" + } + }, + "node_modules/@polymer/neon-animation": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/@polymer/neon-animation/-/neon-animation-3.0.1.tgz", + "integrity": "sha512-cDDc0llpVCe0ATbDS3clDthI54Bc8YwZIeTGGmBJleKOvbRTUC5+ssJmRL+VwVh+VM5FlnQlx760ppftY3uprg==", + "license": "BSD-3-Clause", + "dependencies": { + "@polymer/iron-resizable-behavior": "^3.0.0-pre.26", + "@polymer/iron-selector": "^3.0.0-pre.26", + "@polymer/polymer": "^3.0.0" + } + }, + "node_modules/@polymer/paper-behaviors": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/@polymer/paper-behaviors/-/paper-behaviors-3.0.1.tgz", + "integrity": "sha512-6knhj69fPJejv8qR0kCSUY+Q0XjaUf0OSnkjRjmTJPAwSrRYtgqE+l6P1FfA+py1X/cUjgne9EF5rMZAKJIg1g==", + "license": "BSD-3-Clause", + "dependencies": { + "@polymer/iron-behaviors": "^3.0.0-pre.26", + "@polymer/iron-checked-element-behavior": "^3.0.0-pre.26", + "@polymer/paper-ripple": "^3.0.0-pre.26", + "@polymer/polymer": "^3.0.0" + } + }, + "node_modules/@polymer/paper-button": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/@polymer/paper-button/-/paper-button-3.0.1.tgz", + "integrity": "sha512-JRNBc+Oj9EWnmyLr7FcCr8T1KAnEHPh6mosln9BUdkM+qYaYsudSICh3cjTIbnj6AuF5OJidoLkM1dlyj0j6Zg==", + "license": "BSD-3-Clause", + "dependencies": { + "@polymer/iron-flex-layout": "^3.0.0-pre.26", + "@polymer/paper-behaviors": "^3.0.0-pre.27", + "@polymer/paper-styles": "^3.0.0-pre.26", + "@polymer/polymer": "^3.0.0" + } + }, + "node_modules/@polymer/paper-checkbox": { + "version": "3.1.0", + "resolved": "https://registry.npmmirror.com/@polymer/paper-checkbox/-/paper-checkbox-3.1.0.tgz", + "integrity": "sha512-kXm6yDG1tT8if0XuJ2cc9NF+g8Ev4wG+rnf0a+Sx+O7J6fn1jcnBlYn72FlrfjVjDQZDBFmT6nynhD5PvFw8iQ==", + "license": "BSD-3-Clause", + "dependencies": { + "@polymer/iron-a11y-keys-behavior": "^3.0.0-pre.26", + "@polymer/iron-checked-element-behavior": "^3.0.0-pre.26", + "@polymer/paper-behaviors": "^3.0.0-pre.27", + "@polymer/paper-ripple": "^3.0.0-pre.26", + "@polymer/paper-styles": "^3.0.0-pre.26", + "@polymer/polymer": "^3.0.0" + } + }, + "node_modules/@polymer/paper-dialog": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/@polymer/paper-dialog/-/paper-dialog-3.0.1.tgz", + "integrity": "sha512-KvglYbEq7AWJvui2j6WKLnOvgVMeGjovAydGrPRj7kVzCiD49Eq/hpYFJTRV5iDcalWH+mORUpw+jrFnG9+Kgw==", + "license": "BSD-3-Clause", + "dependencies": { + "@polymer/iron-overlay-behavior": "^3.0.0-pre.27", + "@polymer/neon-animation": "^3.0.0-pre.26", + "@polymer/paper-dialog-behavior": "^3.0.0-pre.26", + "@polymer/polymer": "^3.0.0" + } + }, + "node_modules/@polymer/paper-dialog-behavior": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/@polymer/paper-dialog-behavior/-/paper-dialog-behavior-3.0.1.tgz", + "integrity": "sha512-wbI4kCK8le/9MHT+IXzvHjoatxf3kd3Yn0tgozAiAwfSZ7N4Ubpi5MHrK0m9S9PeIxKokAgBYdTUrezSE5378A==", + "license": "BSD-3-Clause", + "dependencies": { + "@polymer/iron-overlay-behavior": "^3.0.0-pre.27", + "@polymer/paper-styles": "^3.0.0-pre.26", + "@polymer/polymer": "^3.0.0" + } + }, + "node_modules/@polymer/paper-ripple": { + "version": "3.0.2", + "resolved": "https://registry.npmmirror.com/@polymer/paper-ripple/-/paper-ripple-3.0.2.tgz", + "integrity": "sha512-DnLNvYIMsiayeICroYxx6Q6Hg1cUU8HN2sbutXazlemAlGqdq80qz3TIaVdbpbt/pvjcFGX2HtntMlPstCge8Q==", + "license": "BSD-3-Clause", + "dependencies": { + "@polymer/iron-a11y-keys-behavior": "^3.0.0-pre.26", + "@polymer/polymer": "^3.0.0" + } + }, + "node_modules/@polymer/paper-styles": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/@polymer/paper-styles/-/paper-styles-3.0.1.tgz", + "integrity": "sha512-y6hmObLqlCx602TQiSBKHqjwkE7xmDiFkoxdYGaNjtv4xcysOTdVJsDR/R9UHwIaxJ7gHlthMSykir1nv78++g==", + "license": "BSD-3-Clause", + "dependencies": { + "@polymer/font-roboto": "^3.0.1", + "@polymer/iron-flex-layout": "^3.0.0-pre.26", + "@polymer/polymer": "^3.0.0" + } + }, + "node_modules/@polymer/paper-tooltip": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/@polymer/paper-tooltip/-/paper-tooltip-3.0.1.tgz", + "integrity": "sha512-yiUk09opTEnE1lK+tb501ENb+yQBi4p++Ep0eGJAHesVYKVMPNgPphVKkIizkDaU+n0SE+zXfTsRbYyOMDYXSg==", + "license": "BSD-3-Clause", + "dependencies": { + "@polymer/paper-styles": "^3.0.0-pre.26", + "@polymer/polymer": "^3.0.0" + } + }, + "node_modules/@polymer/polymer": { + "version": "3.5.2", + "resolved": "https://registry.npmmirror.com/@polymer/polymer/-/polymer-3.5.2.tgz", + "integrity": "sha512-fWwImY/UH4bb2534DVSaX+Azs2yKg8slkMBHOyGeU2kKx7Xmxp6Lee0jP8p6B3d7c1gFUPB2Z976dTUtX81pQA==", + "license": "BSD-3-Clause", + "dependencies": { + "@webcomponents/shadycss": "^1.9.1" + } + }, + "node_modules/@types/body-parser": { + "version": "1.19.5", + "resolved": "https://registry.npmmirror.com/@types/body-parser/-/body-parser-1.19.5.tgz", + "integrity": "sha512-fB3Zu92ucau0iQ0JMCFQE7b/dv8Ot07NI3KaZIkIUNXq82k4eBAqUaneXfleGY9JWskeS9y+u0nXMyspcuQrCg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/connect": "*", + "@types/node": "*" + } + }, + "node_modules/@types/bonjour": { + "version": "3.5.13", + "resolved": "https://registry.npmmirror.com/@types/bonjour/-/bonjour-3.5.13.tgz", + "integrity": "sha512-z9fJ5Im06zvUL548KvYNecEVlA7cVDkGUi6kZusb04mpyEFKCIZJvloCcmpmLaIahDpOQGHaHmG6imtPMmPXGQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/node": "*" + } + }, + "node_modules/@types/connect": { + "version": "3.4.38", + "resolved": "https://registry.npmmirror.com/@types/connect/-/connect-3.4.38.tgz", + "integrity": "sha512-K6uROf1LD88uDQqJCktA4yzL1YYAK6NgfsI0v/mTgyPKWsX1CnJ0XPSDhViejru1GcRkLWb8RlzFYJRqGUbaug==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/node": "*" + } + }, + "node_modules/@types/connect-history-api-fallback": { + "version": "1.5.4", + "resolved": "https://registry.npmmirror.com/@types/connect-history-api-fallback/-/connect-history-api-fallback-1.5.4.tgz", + "integrity": "sha512-n6Cr2xS1h4uAulPRdlw6Jl6s1oG8KrVilPN2yUITEs+K48EzMJJ3W1xy8K5eWuFvjp3R74AOIGSmp2UfBJ8HFw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/express-serve-static-core": "*", + "@types/node": "*" + } + }, + "node_modules/@types/d3": { + "version": "7.4.3", + "resolved": "https://registry.npmmirror.com/@types/d3/-/d3-7.4.3.tgz", + "integrity": "sha512-lZXZ9ckh5R8uiFVt8ogUNf+pIrK4EsWrx2Np75WvF/eTpJ0FMHNhjXk8CKEx/+gpHbNQyJWehbFaTvqmHWB3ww==", + "dev": true, + "dependencies": { + "@types/d3-array": "*", + "@types/d3-axis": "*", + "@types/d3-brush": "*", + "@types/d3-chord": "*", + "@types/d3-color": "*", + "@types/d3-contour": "*", + "@types/d3-delaunay": "*", + "@types/d3-dispatch": "*", + "@types/d3-drag": "*", + "@types/d3-dsv": "*", + "@types/d3-ease": "*", + "@types/d3-fetch": "*", + "@types/d3-force": "*", + "@types/d3-format": "*", + "@types/d3-geo": "*", + "@types/d3-hierarchy": "*", + "@types/d3-interpolate": "*", + "@types/d3-path": "*", + "@types/d3-polygon": "*", + "@types/d3-quadtree": "*", + "@types/d3-random": "*", + "@types/d3-scale": "*", + "@types/d3-scale-chromatic": "*", + "@types/d3-selection": "*", + "@types/d3-shape": "*", + "@types/d3-time": "*", + "@types/d3-time-format": "*", + "@types/d3-timer": "*", + "@types/d3-transition": "*", + "@types/d3-zoom": "*" + } + }, + "node_modules/@types/d3-array": { + "version": "3.2.1", + "resolved": "https://registry.npmmirror.com/@types/d3-array/-/d3-array-3.2.1.tgz", + "integrity": "sha512-Y2Jn2idRrLzUfAKV2LyRImR+y4oa2AntrgID95SHJxuMUrkNXmanDSed71sRNZysveJVt1hLLemQZIady0FpEg==", + "dev": true + }, + "node_modules/@types/d3-axis": { + "version": "3.0.6", + "resolved": "https://registry.npmmirror.com/@types/d3-axis/-/d3-axis-3.0.6.tgz", + "integrity": "sha512-pYeijfZuBd87T0hGn0FO1vQ/cgLk6E1ALJjfkC0oJ8cbwkZl3TpgS8bVBLZN+2jjGgg38epgxb2zmoGtSfvgMw==", + "dev": true, + "dependencies": { + "@types/d3-selection": "*" + } + }, + "node_modules/@types/d3-brush": { + "version": "3.0.6", + "resolved": "https://registry.npmmirror.com/@types/d3-brush/-/d3-brush-3.0.6.tgz", + "integrity": "sha512-nH60IZNNxEcrh6L1ZSMNA28rj27ut/2ZmI3r96Zd+1jrZD++zD3LsMIjWlvg4AYrHn/Pqz4CF3veCxGjtbqt7A==", + "dev": true, + "dependencies": { + "@types/d3-selection": "*" + } + }, + "node_modules/@types/d3-chord": { + "version": "3.0.6", + "resolved": "https://registry.npmmirror.com/@types/d3-chord/-/d3-chord-3.0.6.tgz", + "integrity": "sha512-LFYWWd8nwfwEmTZG9PfQxd17HbNPksHBiJHaKuY1XeqscXacsS2tyoo6OdRsjf+NQYeB6XrNL3a25E3gH69lcg==", + "dev": true + }, + "node_modules/@types/d3-color": { + "version": "3.1.3", + "resolved": "https://registry.npmmirror.com/@types/d3-color/-/d3-color-3.1.3.tgz", + "integrity": "sha512-iO90scth9WAbmgv7ogoq57O9YpKmFBbmoEoCHDB2xMBY0+/KVrqAaCDyCE16dUspeOvIxFFRI+0sEtqDqy2b4A==", + "dev": true + }, + "node_modules/@types/d3-contour": { + "version": "3.0.6", + "resolved": "https://registry.npmmirror.com/@types/d3-contour/-/d3-contour-3.0.6.tgz", + "integrity": "sha512-BjzLgXGnCWjUSYGfH1cpdo41/hgdWETu4YxpezoztawmqsvCeep+8QGfiY6YbDvfgHz/DkjeIkkZVJavB4a3rg==", + "dev": true, + "dependencies": { + "@types/d3-array": "*", + "@types/geojson": "*" + } + }, + "node_modules/@types/d3-delaunay": { + "version": "6.0.4", + "resolved": "https://registry.npmmirror.com/@types/d3-delaunay/-/d3-delaunay-6.0.4.tgz", + "integrity": "sha512-ZMaSKu4THYCU6sV64Lhg6qjf1orxBthaC161plr5KuPHo3CNm8DTHiLw/5Eq2b6TsNP0W0iJrUOFscY6Q450Hw==", + "dev": true + }, + "node_modules/@types/d3-dispatch": { + "version": "3.0.6", + "resolved": "https://registry.npmmirror.com/@types/d3-dispatch/-/d3-dispatch-3.0.6.tgz", + "integrity": "sha512-4fvZhzMeeuBJYZXRXrRIQnvUYfyXwYmLsdiN7XXmVNQKKw1cM8a5WdID0g1hVFZDqT9ZqZEY5pD44p24VS7iZQ==", + "dev": true + }, + "node_modules/@types/d3-drag": { + "version": "3.0.7", + "resolved": "https://registry.npmmirror.com/@types/d3-drag/-/d3-drag-3.0.7.tgz", + "integrity": "sha512-HE3jVKlzU9AaMazNufooRJ5ZpWmLIoc90A37WU2JMmeq28w1FQqCZswHZ3xR+SuxYftzHq6WU6KJHvqxKzTxxQ==", + "dev": true, + "dependencies": { + "@types/d3-selection": "*" + } + }, + "node_modules/@types/d3-dsv": { + "version": "3.0.7", + "resolved": "https://registry.npmmirror.com/@types/d3-dsv/-/d3-dsv-3.0.7.tgz", + "integrity": "sha512-n6QBF9/+XASqcKK6waudgL0pf/S5XHPPI8APyMLLUHd8NqouBGLsU8MgtO7NINGtPBtk9Kko/W4ea0oAspwh9g==", + "dev": true + }, + "node_modules/@types/d3-ease": { + "version": "3.0.2", + "resolved": "https://registry.npmmirror.com/@types/d3-ease/-/d3-ease-3.0.2.tgz", + "integrity": "sha512-NcV1JjO5oDzoK26oMzbILE6HW7uVXOHLQvHshBUW4UMdZGfiY6v5BeQwh9a9tCzv+CeefZQHJt5SRgK154RtiA==", + "dev": true + }, + "node_modules/@types/d3-fetch": { + "version": "3.0.7", + "resolved": "https://registry.npmmirror.com/@types/d3-fetch/-/d3-fetch-3.0.7.tgz", + "integrity": "sha512-fTAfNmxSb9SOWNB9IoG5c8Hg6R+AzUHDRlsXsDZsNp6sxAEOP0tkP3gKkNSO/qmHPoBFTxNrjDprVHDQDvo5aA==", + "dev": true, + "dependencies": { + "@types/d3-dsv": "*" + } + }, + "node_modules/@types/d3-force": { + "version": "3.0.10", + "resolved": "https://registry.npmmirror.com/@types/d3-force/-/d3-force-3.0.10.tgz", + "integrity": "sha512-ZYeSaCF3p73RdOKcjj+swRlZfnYpK1EbaDiYICEEp5Q6sUiqFaFQ9qgoshp5CzIyyb/yD09kD9o2zEltCexlgw==", + "dev": true + }, + "node_modules/@types/d3-format": { + "version": "3.0.4", + "resolved": "https://registry.npmmirror.com/@types/d3-format/-/d3-format-3.0.4.tgz", + "integrity": "sha512-fALi2aI6shfg7vM5KiR1wNJnZ7r6UuggVqtDA+xiEdPZQwy/trcQaHnwShLuLdta2rTymCNpxYTiMZX/e09F4g==", + "dev": true + }, + "node_modules/@types/d3-geo": { + "version": "3.1.0", + "resolved": "https://registry.npmmirror.com/@types/d3-geo/-/d3-geo-3.1.0.tgz", + "integrity": "sha512-856sckF0oP/diXtS4jNsiQw/UuK5fQG8l/a9VVLeSouf1/PPbBE1i1W852zVwKwYCBkFJJB7nCFTbk6UMEXBOQ==", + "dev": true, + "dependencies": { + "@types/geojson": "*" + } + }, + "node_modules/@types/d3-hierarchy": { + "version": "3.1.7", + "resolved": "https://registry.npmmirror.com/@types/d3-hierarchy/-/d3-hierarchy-3.1.7.tgz", + "integrity": "sha512-tJFtNoYBtRtkNysX1Xq4sxtjK8YgoWUNpIiUee0/jHGRwqvzYxkq0hGVbbOGSz+JgFxxRu4K8nb3YpG3CMARtg==", + "dev": true + }, + "node_modules/@types/d3-interpolate": { + "version": "3.0.4", + "resolved": "https://registry.npmmirror.com/@types/d3-interpolate/-/d3-interpolate-3.0.4.tgz", + "integrity": "sha512-mgLPETlrpVV1YRJIglr4Ez47g7Yxjl1lj7YKsiMCb27VJH9W8NVM6Bb9d8kkpG/uAQS5AmbA48q2IAolKKo1MA==", + "dev": true, + "dependencies": { + "@types/d3-color": "*" + } + }, + "node_modules/@types/d3-path": { + "version": "3.1.1", + "resolved": "https://registry.npmmirror.com/@types/d3-path/-/d3-path-3.1.1.tgz", + "integrity": "sha512-VMZBYyQvbGmWyWVea0EHs/BwLgxc+MKi1zLDCONksozI4YJMcTt8ZEuIR4Sb1MMTE8MMW49v0IwI5+b7RmfWlg==", + "dev": true + }, + "node_modules/@types/d3-polygon": { + "version": "3.0.2", + "resolved": "https://registry.npmmirror.com/@types/d3-polygon/-/d3-polygon-3.0.2.tgz", + "integrity": "sha512-ZuWOtMaHCkN9xoeEMr1ubW2nGWsp4nIql+OPQRstu4ypeZ+zk3YKqQT0CXVe/PYqrKpZAi+J9mTs05TKwjXSRA==", + "dev": true + }, + "node_modules/@types/d3-quadtree": { + "version": "3.0.6", + "resolved": "https://registry.npmmirror.com/@types/d3-quadtree/-/d3-quadtree-3.0.6.tgz", + "integrity": "sha512-oUzyO1/Zm6rsxKRHA1vH0NEDG58HrT5icx/azi9MF1TWdtttWl0UIUsjEQBBh+SIkrpd21ZjEv7ptxWys1ncsg==", + "dev": true + }, + "node_modules/@types/d3-random": { + "version": "3.0.3", + "resolved": "https://registry.npmmirror.com/@types/d3-random/-/d3-random-3.0.3.tgz", + "integrity": "sha512-Imagg1vJ3y76Y2ea0871wpabqp613+8/r0mCLEBfdtqC7xMSfj9idOnmBYyMoULfHePJyxMAw3nWhJxzc+LFwQ==", + "dev": true + }, + "node_modules/@types/d3-scale": { + "version": "4.0.9", + "resolved": "https://registry.npmmirror.com/@types/d3-scale/-/d3-scale-4.0.9.tgz", + "integrity": "sha512-dLmtwB8zkAeO/juAMfnV+sItKjlsw2lKdZVVy6LRr0cBmegxSABiLEpGVmSJJ8O08i4+sGR6qQtb6WtuwJdvVw==", + "dev": true, + "dependencies": { + "@types/d3-time": "*" + } + }, + "node_modules/@types/d3-scale-chromatic": { + "version": "3.1.0", + "resolved": "https://registry.npmmirror.com/@types/d3-scale-chromatic/-/d3-scale-chromatic-3.1.0.tgz", + "integrity": "sha512-iWMJgwkK7yTRmWqRB5plb1kadXyQ5Sj8V/zYlFGMUBbIPKQScw+Dku9cAAMgJG+z5GYDoMjWGLVOvjghDEFnKQ==", + "dev": true + }, + "node_modules/@types/d3-selection": { + "version": "3.0.11", + "resolved": "https://registry.npmmirror.com/@types/d3-selection/-/d3-selection-3.0.11.tgz", + "integrity": "sha512-bhAXu23DJWsrI45xafYpkQ4NtcKMwWnAC/vKrd2l+nxMFuvOT3XMYTIj2opv8vq8AO5Yh7Qac/nSeP/3zjTK0w==", + "dev": true + }, + "node_modules/@types/d3-shape": { + "version": "3.1.7", + "resolved": "https://registry.npmmirror.com/@types/d3-shape/-/d3-shape-3.1.7.tgz", + "integrity": "sha512-VLvUQ33C+3J+8p+Daf+nYSOsjB4GXp19/S/aGo60m9h1v6XaxjiT82lKVWJCfzhtuZ3yD7i/TPeC/fuKLLOSmg==", + "dev": true, + "dependencies": { + "@types/d3-path": "*" + } + }, + "node_modules/@types/d3-time": { + "version": "3.0.4", + "resolved": "https://registry.npmmirror.com/@types/d3-time/-/d3-time-3.0.4.tgz", + "integrity": "sha512-yuzZug1nkAAaBlBBikKZTgzCeA+k1uy4ZFwWANOfKw5z5LRhV0gNA7gNkKm7HoK+HRN0wX3EkxGk0fpbWhmB7g==", + "dev": true + }, + "node_modules/@types/d3-time-format": { + "version": "4.0.3", + "resolved": "https://registry.npmmirror.com/@types/d3-time-format/-/d3-time-format-4.0.3.tgz", + "integrity": "sha512-5xg9rC+wWL8kdDj153qZcsJ0FWiFt0J5RB6LYUNZjwSnesfblqrI/bJ1wBdJ8OQfncgbJG5+2F+qfqnqyzYxyg==", + "dev": true + }, + "node_modules/@types/d3-timer": { + "version": "3.0.2", + "resolved": "https://registry.npmmirror.com/@types/d3-timer/-/d3-timer-3.0.2.tgz", + "integrity": "sha512-Ps3T8E8dZDam6fUyNiMkekK3XUsaUEik+idO9/YjPtfj2qruF8tFBXS7XhtE4iIXBLxhmLjP3SXpLhVf21I9Lw==", + "dev": true + }, + "node_modules/@types/d3-transition": { + "version": "3.0.9", + "resolved": "https://registry.npmmirror.com/@types/d3-transition/-/d3-transition-3.0.9.tgz", + "integrity": "sha512-uZS5shfxzO3rGlu0cC3bjmMFKsXv+SmZZcgp0KD22ts4uGXp5EVYGzu/0YdwZeKmddhcAccYtREJKkPfXkZuCg==", + "dev": true, + "dependencies": { + "@types/d3-selection": "*" + } + }, + "node_modules/@types/d3-zoom": { + "version": "3.0.8", + "resolved": "https://registry.npmmirror.com/@types/d3-zoom/-/d3-zoom-3.0.8.tgz", + "integrity": "sha512-iqMC4/YlFCSlO8+2Ii1GGGliCAY4XdeG748w5vQUbevlbDu0zSjH/+jojorQVBK/se0j6DUFNPBGSqD3YWYnDw==", + "dev": true, + "dependencies": { + "@types/d3-interpolate": "*", + "@types/d3-selection": "*" + } + }, + "node_modules/@types/eslint": { + "version": "9.6.1", + "resolved": "https://registry.npmmirror.com/@types/eslint/-/eslint-9.6.1.tgz", + "integrity": "sha512-FXx2pKgId/WyYo2jXw63kk7/+TY7u7AziEJxJAnSFzHlqTAS3Ync6SvgYAN/k4/PQpnnVuzoMuVnByKK2qp0ag==", + "license": "MIT", + "dependencies": { + "@types/estree": "*", + "@types/json-schema": "*" + } + }, + "node_modules/@types/eslint-scope": { + "version": "3.7.7", + "resolved": "https://registry.npmmirror.com/@types/eslint-scope/-/eslint-scope-3.7.7.tgz", + "integrity": "sha512-MzMFlSLBqNF2gcHWO0G1vP/YQyfvrxZ0bF+u7mzUdZ1/xK4A4sru+nraZz5i3iEIk1l1uyicaDVTB4QbbEkAYg==", + "license": "MIT", + "dependencies": { + "@types/eslint": "*", + "@types/estree": "*" + } + }, + "node_modules/@types/estree": { + "version": "1.0.7", + "resolved": "https://registry.npmmirror.com/@types/estree/-/estree-1.0.7.tgz", + "integrity": "sha512-w28IoSUCJpidD/TGviZwwMJckNESJZXFu7NBZ5YJ4mEUnNraUn9Pm8HSZm/jDF1pDWYKspWE7oVphigUPRakIQ==", + "license": "MIT" + }, + "node_modules/@types/express": { + "version": "4.17.21", + "resolved": "https://registry.npmmirror.com/@types/express/-/express-4.17.21.tgz", + "integrity": "sha512-ejlPM315qwLpaQlQDTjPdsUFSc6ZsP4AN6AlWnogPjQ7CVi7PYF3YVz+CY3jE2pwYf7E/7HlDAN0rV2GxTG0HQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/body-parser": "*", + "@types/express-serve-static-core": "^4.17.33", + "@types/qs": "*", + "@types/serve-static": "*" + } + }, + "node_modules/@types/express-serve-static-core": { + "version": "5.0.6", + "resolved": "https://registry.npmmirror.com/@types/express-serve-static-core/-/express-serve-static-core-5.0.6.tgz", + "integrity": "sha512-3xhRnjJPkULekpSzgtoNYYcTWgEZkp4myc+Saevii5JPnHNvHMRlBSHDbs7Bh1iPPoVTERHEZXyhyLbMEsExsA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/node": "*", + "@types/qs": "*", + "@types/range-parser": "*", + "@types/send": "*" + } + }, + "node_modules/@types/express/node_modules/@types/express-serve-static-core": { + "version": "4.19.6", + "resolved": "https://registry.npmmirror.com/@types/express-serve-static-core/-/express-serve-static-core-4.19.6.tgz", + "integrity": "sha512-N4LZ2xG7DatVqhCZzOGb1Yi5lMbXSZcmdLDe9EzSndPV2HpWYWzRbaerl2n27irrm94EPpprqa8KpskPT085+A==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/node": "*", + "@types/qs": "*", + "@types/range-parser": "*", + "@types/send": "*" + } + }, + "node_modules/@types/geojson": { + "version": "7946.0.16", + "resolved": "https://registry.npmmirror.com/@types/geojson/-/geojson-7946.0.16.tgz", + "integrity": "sha512-6C8nqWur3j98U6+lXDfTUWIfgvZU+EumvpHKcYjujKH7woYyLj2sUmff0tRhrqM7BohUw7Pz3ZB1jj2gW9Fvmg==", + "dev": true + }, + "node_modules/@types/glob": { + "version": "7.2.0", + "resolved": "https://registry.npmmirror.com/@types/glob/-/glob-7.2.0.tgz", + "integrity": "sha512-ZUxbzKl0IfJILTS6t7ip5fQQM/J3TJYubDm3nMbgubNNYS62eXeUpoLUC8/7fJNiFYHTrGPQn7hspDUzIHX3UA==", + "license": "MIT", + "dependencies": { + "@types/minimatch": "*", + "@types/node": "*" + } + }, + "node_modules/@types/html-minifier-terser": { + "version": "6.1.0", + "resolved": "https://registry.npmmirror.com/@types/html-minifier-terser/-/html-minifier-terser-6.1.0.tgz", + "integrity": "sha512-oh/6byDPnL1zeNXFrDXFLyZjkr1MsBG667IM792caf1L2UPOOMf65NFzjUH/ltyfwjAGfs1rsX1eftK0jC/KIg==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/http-errors": { + "version": "2.0.4", + "resolved": "https://registry.npmmirror.com/@types/http-errors/-/http-errors-2.0.4.tgz", + "integrity": "sha512-D0CFMMtydbJAegzOyHjtiKPLlvnm3iTZyZRSZoLq2mRhDdmLfIWOCYPfQJ4cu2erKghU++QvjcUjp/5h7hESpA==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/http-proxy": { + "version": "1.17.16", + "resolved": "https://registry.npmmirror.com/@types/http-proxy/-/http-proxy-1.17.16.tgz", + "integrity": "sha512-sdWoUajOB1cd0A8cRRQ1cfyWNbmFKLAqBB89Y8x5iYyG/mkJHc0YUH8pdWBy2omi9qtCpiIgGjuwO0dQST2l5w==", + "dev": true, + "dependencies": { + "@types/node": "*" + } + }, + "node_modules/@types/json-schema": { + "version": "7.0.15", + "resolved": "https://registry.npmmirror.com/@types/json-schema/-/json-schema-7.0.15.tgz", + "integrity": "sha512-5+fP8P8MFNC+AyZCDxrB2pkZFPGzqQWUzpSeuuVLvm8VMcorNYavBqoFcxK8bQz4Qsbn4oUEEem4wDLfcysGHA==", + "license": "MIT" + }, + "node_modules/@types/lodash": { + "version": "4.17.20", + "resolved": "https://registry.npmmirror.com/@types/lodash/-/lodash-4.17.20.tgz", + "integrity": "sha512-H3MHACvFUEiujabxhaI/ImO6gUrd8oOurg7LQtS7mbwIXA/cUqWrvBsaeJ23aZEPk1TAYkurjfMbSELfoCXlGA==", + "dev": true + }, + "node_modules/@types/mime": { + "version": "1.3.5", + "resolved": "https://registry.npmmirror.com/@types/mime/-/mime-1.3.5.tgz", + "integrity": "sha512-/pyBZWSLD2n0dcHE3hq8s8ZvcETHtEuF+3E7XVt0Ig2nvsVQXdghHVcEkIWjy9A0wKfTn97a/PSDYohKIlnP/w==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/minimatch": { + "version": "5.1.2", + "resolved": "https://registry.npmmirror.com/@types/minimatch/-/minimatch-5.1.2.tgz", + "integrity": "sha512-K0VQKziLUWkVKiRVrx4a40iPaxTUefQmjtkQofBkYRcoaaL/8rhwDWww9qWbrgicNOgnpIsMxyNIUM4+n6dUIA==", + "license": "MIT" + }, + "node_modules/@types/node": { + "version": "16.18.126", + "resolved": "https://registry.npmmirror.com/@types/node/-/node-16.18.126.tgz", + "integrity": "sha512-OTcgaiwfGFBKacvfwuHzzn1KLxH/er8mluiy8/uM3sGXHaRe73RrSIj01jow9t4kJEW633Ov+cOexXeiApTyAw==", + "license": "MIT" + }, + "node_modules/@types/node-forge": { + "version": "1.3.11", + "resolved": "https://registry.npmmirror.com/@types/node-forge/-/node-forge-1.3.11.tgz", + "integrity": "sha512-FQx220y22OKNTqaByeBGqHWYz4cl94tpcxeFdvBo3wjG6XPBuZ0BNgNZRV5J5TFmmcsJ4IzsLkmGRiQbnYsBEQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/node": "*" + } + }, + "node_modules/@types/offscreencanvas": { + "version": "2019.7.3", + "resolved": "https://registry.npmmirror.com/@types/offscreencanvas/-/offscreencanvas-2019.7.3.tgz", + "integrity": "sha512-ieXiYmgSRXUDeOntE1InxjWyvEelZGP63M+cGuquuRLuIKKT1osnkXjxev9B7d1nXSug5vpunx+gNlbVxMlC9A==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/qs": { + "version": "6.9.18", + "resolved": "https://registry.npmmirror.com/@types/qs/-/qs-6.9.18.tgz", + "integrity": "sha512-kK7dgTYDyGqS+e2Q4aK9X3D7q234CIZ1Bv0q/7Z5IwRDoADNU81xXJK/YVyLbLTZCoIwUoDoffFeF+p/eIklAA==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/range-parser": { + "version": "1.2.7", + "resolved": "https://registry.npmmirror.com/@types/range-parser/-/range-parser-1.2.7.tgz", + "integrity": "sha512-hKormJbkJqzQGhziax5PItDUTMAM9uE2XXQmM37dyd4hVM+5aVl7oVxMVUiVQn2oCQFN/LKCZdvSM0pFRqbSmQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/requirejs": { + "version": "2.1.37", + "resolved": "https://registry.npmmirror.com/@types/requirejs/-/requirejs-2.1.37.tgz", + "integrity": "sha512-jmFgr3mwN2NSmtRP6IpZ2nfRS7ufSXuDYQ6YyPFArN8x5dARQcD/DXzT0J6NYbvquVT4pg9K9HWdi6e6DZR9iQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/resize-observer-browser": { + "version": "0.1.11", + "resolved": "https://registry.npmmirror.com/@types/resize-observer-browser/-/resize-observer-browser-0.1.11.tgz", + "integrity": "sha512-cNw5iH8JkMkb3QkCoe7DaZiawbDQEUX8t7iuQaRTyLOyQCR2h+ibBD4GJt7p5yhUHrlOeL7ZtbxNHeipqNsBzQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/retry": { + "version": "0.12.0", + "resolved": "https://registry.npmmirror.com/@types/retry/-/retry-0.12.0.tgz", + "integrity": "sha512-wWKOClTTiizcZhXnPY4wikVAwmdYHp8q6DmC+EJUzAMsycb7HB32Kh9RN4+0gExjmPmZSAQjgURXIGATPegAvA==", + "dev": true + }, + "node_modules/@types/send": { + "version": "0.17.4", + "resolved": "https://registry.npmmirror.com/@types/send/-/send-0.17.4.tgz", + "integrity": "sha512-x2EM6TJOybec7c52BX0ZspPodMsQUd5L6PRwOunVyVUhXiBSKf3AezDL8Dgvgt5o0UfKNfuA0eMLr2wLT4AiBA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/mime": "^1", + "@types/node": "*" + } + }, + "node_modules/@types/serve-index": { + "version": "1.9.4", + "resolved": "https://registry.npmmirror.com/@types/serve-index/-/serve-index-1.9.4.tgz", + "integrity": "sha512-qLpGZ/c2fhSs5gnYsQxtDEq3Oy8SXPClIXkW5ghvAvsNuVSA8k+gCONcUCS/UjLEYvYps+e8uBtfgXgvhwfNug==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/express": "*" + } + }, + "node_modules/@types/serve-static": { + "version": "1.15.7", + "resolved": "https://registry.npmmirror.com/@types/serve-static/-/serve-static-1.15.7.tgz", + "integrity": "sha512-W8Ym+h8nhuRwaKPaDw34QUkwsGi6Rc4yYqvKFo5rm2FUEhCFbzVWrxXUxuKK8TASjWsysJY0nsmNCGhCOIsrOw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/http-errors": "*", + "@types/node": "*", + "@types/send": "*" + } + }, + "node_modules/@types/sockjs": { + "version": "0.3.36", + "resolved": "https://registry.npmmirror.com/@types/sockjs/-/sockjs-0.3.36.tgz", + "integrity": "sha512-MK9V6NzAS1+Ud7JV9lJLFqW85VbC9dq3LmwZCuBe4wBDgKC0Kj/jd8Xl+nSviU+Qc3+m7umHHyHg//2KSa0a0Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/node": "*" + } + }, + "node_modules/@types/three": { + "version": "0.131.1", + "resolved": "https://registry.npmmirror.com/@types/three/-/three-0.131.1.tgz", + "integrity": "sha512-unnjsolcm7R90e4XK9qMq4JYEzly0XQNa0pG8RAOMZeVzj3FLIFPymAYUx4Osz0gY9jFZz8omIQplqiieEE7gw==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/trusted-types": { + "version": "2.0.7", + "resolved": "https://registry.npmmirror.com/@types/trusted-types/-/trusted-types-2.0.7.tgz", + "integrity": "sha512-ScaPdn1dQczgbl0QFTeTOmVHFULt394XJgOQNoyVhZ6r2vLnMLJfBPd53SB52T/3G36VI1/g2MZaX0cwDuXsfw==", + "license": "MIT" + }, + "node_modules/@types/ws": { + "version": "8.18.0", + "resolved": "https://registry.npmmirror.com/@types/ws/-/ws-8.18.0.tgz", + "integrity": "sha512-8svvI3hMyvN0kKCJMvTJP/x6Y/EoQbepff882wL+Sn5QsXb3etnamgrJq4isrBxSJj5L2AuXcI0+bgkoAXGUJw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/node": "*" + } + }, + "node_modules/@vaadin/a11y-base": { + "version": "24.6.11", + "resolved": "https://registry.npmmirror.com/@vaadin/a11y-base/-/a11y-base-24.6.11.tgz", + "integrity": "sha512-yBZ0QGPngbItIJQx3FRIa9IXDW2Ftf6SFFPGhbdAZafJPBlFi6FElP9cVtL3qjJlI5KKBp/UXEcC8ehPK207gw==", + "license": "Apache-2.0", + "dependencies": { + "@open-wc/dedupe-mixin": "^1.3.0", + "@polymer/polymer": "^3.0.0", + "@vaadin/component-base": "~24.6.11", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/button": { + "version": "24.6.5", + "resolved": "https://registry.npmmirror.com/@vaadin/button/-/button-24.6.5.tgz", + "integrity": "sha512-i+pgR0Gn6EWxLgWEQOi7yXXQSQklsr7a+yotlet1GOB+DymE+w9RVp4WOZ6T8yaqTICKcDQldFkreTzFVxsHAQ==", + "license": "Apache-2.0", + "dependencies": { + "@open-wc/dedupe-mixin": "^1.3.0", + "@polymer/polymer": "^3.0.0", + "@vaadin/a11y-base": "~24.6.5", + "@vaadin/component-base": "~24.6.5", + "@vaadin/vaadin-lumo-styles": "~24.6.5", + "@vaadin/vaadin-material-styles": "~24.6.5", + "@vaadin/vaadin-themable-mixin": "~24.6.5", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/checkbox": { + "version": "24.6.5", + "resolved": "https://registry.npmmirror.com/@vaadin/checkbox/-/checkbox-24.6.5.tgz", + "integrity": "sha512-XYhiEr9PMnonyMot7y8PIK/9GlNNCxupFg7ZFsrGvsDpfD5fCfpsnv+c3kUwVN6yFze3NVQEXOyp/UGoZNoNmQ==", + "license": "Apache-2.0", + "dependencies": { + "@open-wc/dedupe-mixin": "^1.3.0", + "@polymer/polymer": "^3.0.0", + "@vaadin/a11y-base": "~24.6.5", + "@vaadin/component-base": "~24.6.5", + "@vaadin/field-base": "~24.6.5", + "@vaadin/vaadin-lumo-styles": "~24.6.5", + "@vaadin/vaadin-material-styles": "~24.6.5", + "@vaadin/vaadin-themable-mixin": "~24.6.5", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/checkbox-group": { + "version": "24.6.5", + "resolved": "https://registry.npmmirror.com/@vaadin/checkbox-group/-/checkbox-group-24.6.5.tgz", + "integrity": "sha512-1K34LnXxINlMSrwAynLW46nyAGqz6kZW4ogZeKESXa+JogjOiHCaVy127xIKYmfJD2yR4ti31VPQKPNQXlZpxA==", + "license": "Apache-2.0", + "dependencies": { + "@open-wc/dedupe-mixin": "^1.3.0", + "@polymer/polymer": "^3.0.0", + "@vaadin/a11y-base": "~24.6.5", + "@vaadin/checkbox": "~24.6.5", + "@vaadin/component-base": "~24.6.5", + "@vaadin/field-base": "~24.6.5", + "@vaadin/vaadin-lumo-styles": "~24.6.5", + "@vaadin/vaadin-material-styles": "~24.6.5", + "@vaadin/vaadin-themable-mixin": "~24.6.5", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/combo-box": { + "version": "24.6.5", + "resolved": "https://registry.npmmirror.com/@vaadin/combo-box/-/combo-box-24.6.5.tgz", + "integrity": "sha512-u/xC9QegwWgmw9TutPRoIzeBpUgG6Kt9CmJbNZNeWBrP9Nicz/QAawApynvjWQtmm7zIKXp7SPzW1Gqwpe09mQ==", + "license": "Apache-2.0", + "dependencies": { + "@open-wc/dedupe-mixin": "^1.3.0", + "@polymer/polymer": "^3.0.0", + "@vaadin/a11y-base": "~24.6.5", + "@vaadin/component-base": "~24.6.5", + "@vaadin/field-base": "~24.6.5", + "@vaadin/input-container": "~24.6.5", + "@vaadin/item": "~24.6.5", + "@vaadin/lit-renderer": "~24.6.5", + "@vaadin/overlay": "~24.6.5", + "@vaadin/vaadin-lumo-styles": "~24.6.5", + "@vaadin/vaadin-material-styles": "~24.6.5", + "@vaadin/vaadin-themable-mixin": "~24.6.5", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/component-base": { + "version": "24.6.11", + "resolved": "https://registry.npmmirror.com/@vaadin/component-base/-/component-base-24.6.11.tgz", + "integrity": "sha512-7jR6vcJeCBgY2CNbAPLOcUTsxYspqdkA0slUGk3GwfgsRDD5FLkzqQDSM5+yE6O2+4Wah2Tk+kG/GsKGtlUlwg==", + "license": "Apache-2.0", + "dependencies": { + "@open-wc/dedupe-mixin": "^1.3.0", + "@polymer/polymer": "^3.0.0", + "@vaadin/vaadin-development-mode-detector": "^2.0.0", + "@vaadin/vaadin-usage-statistics": "^2.1.0", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/confirm-dialog": { + "version": "24.6.5", + "resolved": "https://registry.npmmirror.com/@vaadin/confirm-dialog/-/confirm-dialog-24.6.5.tgz", + "integrity": "sha512-YkEFCj4psD/+K50fpqoJIUYRvdyzoRfWdC3ZDaDcliWND7m50cNZBxj1psZ+iOP32MVK+4Ke16fFbEY0tKoyWg==", + "license": "Apache-2.0", + "dependencies": { + "@open-wc/dedupe-mixin": "^1.3.0", + "@polymer/polymer": "^3.0.0", + "@vaadin/button": "~24.6.5", + "@vaadin/component-base": "~24.6.5", + "@vaadin/dialog": "~24.6.5", + "@vaadin/overlay": "~24.6.5", + "@vaadin/vaadin-lumo-styles": "~24.6.5", + "@vaadin/vaadin-material-styles": "~24.6.5", + "@vaadin/vaadin-themable-mixin": "~24.6.5", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/context-menu": { + "version": "24.6.5", + "resolved": "https://registry.npmmirror.com/@vaadin/context-menu/-/context-menu-24.6.5.tgz", + "integrity": "sha512-WLFKmoyIG+GI/UQH4EohhBsLhsYPGV1wdE80Gu+0Gl3/aGLm1ofl6ls+iVzC+/AOBIjNFG1TGmrxMIti1zk0PA==", + "license": "Apache-2.0", + "dependencies": { + "@open-wc/dedupe-mixin": "^1.3.0", + "@polymer/polymer": "^3.0.0", + "@vaadin/a11y-base": "~24.6.5", + "@vaadin/component-base": "~24.6.5", + "@vaadin/item": "~24.6.5", + "@vaadin/list-box": "~24.6.5", + "@vaadin/lit-renderer": "~24.6.5", + "@vaadin/overlay": "~24.6.5", + "@vaadin/vaadin-lumo-styles": "~24.6.5", + "@vaadin/vaadin-material-styles": "~24.6.5", + "@vaadin/vaadin-themable-mixin": "~24.6.5", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/details": { + "version": "24.6.5", + "resolved": "https://registry.npmmirror.com/@vaadin/details/-/details-24.6.5.tgz", + "integrity": "sha512-V22OCdRnT7qOVsVpedGfrwDPE9dFWdhFDv66RfkiWGHpPoq0+dYUpP2Y5Iy7YRCxqVnogVBiE8qHPgZAO4U18A==", + "license": "Apache-2.0", + "dependencies": { + "@open-wc/dedupe-mixin": "^1.3.0", + "@polymer/polymer": "^3.0.0", + "@vaadin/a11y-base": "~24.6.5", + "@vaadin/button": "~24.6.5", + "@vaadin/component-base": "~24.6.5", + "@vaadin/vaadin-lumo-styles": "~24.6.5", + "@vaadin/vaadin-material-styles": "~24.6.5", + "@vaadin/vaadin-themable-mixin": "~24.6.5", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/dialog": { + "version": "24.6.11", + "resolved": "https://registry.npmmirror.com/@vaadin/dialog/-/dialog-24.6.11.tgz", + "integrity": "sha512-EWMkf8yH6sFEG4e9zTpPW20sL8vTz5TpBvJmD4rUUqvxEphHCUI+uO0N05/XFa2k7xZS6q8G38hc1BcCYr5Syw==", + "license": "Apache-2.0", + "dependencies": { + "@open-wc/dedupe-mixin": "^1.3.0", + "@polymer/polymer": "^3.0.0", + "@vaadin/component-base": "~24.6.11", + "@vaadin/lit-renderer": "~24.6.11", + "@vaadin/overlay": "~24.6.11", + "@vaadin/vaadin-lumo-styles": "~24.6.11", + "@vaadin/vaadin-material-styles": "~24.6.11", + "@vaadin/vaadin-themable-mixin": "~24.6.11", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/field-base": { + "version": "24.6.11", + "resolved": "https://registry.npmmirror.com/@vaadin/field-base/-/field-base-24.6.11.tgz", + "integrity": "sha512-dRjxKzbW3xQAau1xuO8uZepVWaImS2wEyKDK9Oh+y8iiu4smYEmo9e4aqMqQN/sOHU6OSa4YtbyJZlvD1sBXrA==", + "license": "Apache-2.0", + "dependencies": { + "@open-wc/dedupe-mixin": "^1.3.0", + "@polymer/polymer": "^3.0.0", + "@vaadin/a11y-base": "~24.6.11", + "@vaadin/component-base": "~24.6.11", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/grid": { + "version": "24.6.5", + "resolved": "https://registry.npmmirror.com/@vaadin/grid/-/grid-24.6.5.tgz", + "integrity": "sha512-BlZO8+oWTmrnCbZESa73IbMuXfxQu7Viotd88NXY/ixq/8LiQqj2yNHtKTPz2l2QL1ke57ckFsjzN6w52nYc5g==", + "license": "Apache-2.0", + "dependencies": { + "@open-wc/dedupe-mixin": "^1.3.0", + "@polymer/polymer": "^3.0.0", + "@vaadin/a11y-base": "~24.6.5", + "@vaadin/checkbox": "~24.6.5", + "@vaadin/component-base": "~24.6.5", + "@vaadin/lit-renderer": "~24.6.5", + "@vaadin/text-field": "~24.6.5", + "@vaadin/vaadin-lumo-styles": "~24.6.5", + "@vaadin/vaadin-material-styles": "~24.6.5", + "@vaadin/vaadin-themable-mixin": "~24.6.5", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/grid/node_modules/@vaadin/checkbox": { + "version": "24.6.11", + "resolved": "https://registry.npmmirror.com/@vaadin/checkbox/-/checkbox-24.6.11.tgz", + "integrity": "sha512-Uvd6gZ3xQQrZTtCJL6f4uLbg6mXsAKjiZto7Je39yJwUHz8r5MIQr+4mLF4zc6mYVSH/Ihj/a4n9FOuTwSEuQw==", + "license": "Apache-2.0", + "dependencies": { + "@open-wc/dedupe-mixin": "^1.3.0", + "@polymer/polymer": "^3.0.0", + "@vaadin/a11y-base": "~24.6.11", + "@vaadin/component-base": "~24.6.11", + "@vaadin/field-base": "~24.6.11", + "@vaadin/vaadin-lumo-styles": "~24.6.11", + "@vaadin/vaadin-material-styles": "~24.6.11", + "@vaadin/vaadin-themable-mixin": "~24.6.11", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/icon": { + "version": "24.6.5", + "resolved": "https://registry.npmmirror.com/@vaadin/icon/-/icon-24.6.5.tgz", + "integrity": "sha512-y6Jy69nySb3tZqEIYAYpyGTiNkKS//ro+w6tuD0a0gu+GrfTv90XDNEY9FvGvnUHsM44OoiQRH3kD15kmISkxQ==", + "license": "Apache-2.0", + "dependencies": { + "@open-wc/dedupe-mixin": "^1.3.0", + "@polymer/polymer": "^3.0.0", + "@vaadin/component-base": "~24.6.5", + "@vaadin/vaadin-lumo-styles": "~24.6.5", + "@vaadin/vaadin-themable-mixin": "~24.6.5", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/icons": { + "version": "24.6.5", + "resolved": "https://registry.npmmirror.com/@vaadin/icons/-/icons-24.6.5.tgz", + "integrity": "sha512-zd8KKkJ18EI70IQGoCz3hcQed+VFPnqECKci8vt+OJi1n5j7qzPW4sbEOLZxr6cWrnN1eNdSHfJCQWXrFfL0bQ==", + "license": "Apache-2.0", + "dependencies": { + "@polymer/polymer": "^3.0.0", + "@vaadin/icon": "~24.6.5" + } + }, + "node_modules/@vaadin/input-container": { + "version": "24.6.7", + "resolved": "https://registry.npmmirror.com/@vaadin/input-container/-/input-container-24.6.7.tgz", + "integrity": "sha512-376ZyD74jrKvjiM+gE0xNScyZPU7REMBbGXpmM4DpoLYgw60m01D3fliZaOTVDyXc3gvxWIai3L1vCY0KYpD6w==", + "license": "Apache-2.0", + "dependencies": { + "@polymer/polymer": "^3.0.0", + "@vaadin/component-base": "~24.6.7", + "@vaadin/vaadin-lumo-styles": "~24.6.7", + "@vaadin/vaadin-material-styles": "~24.6.7", + "@vaadin/vaadin-themable-mixin": "~24.6.7", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/item": { + "version": "24.6.7", + "resolved": "https://registry.npmmirror.com/@vaadin/item/-/item-24.6.7.tgz", + "integrity": "sha512-9xpJEVhgHF3YQGVeet2uakMTH7SyEbQx+uT5Kld/r1CiCYOKUxbERXrFuJ/5/lgakXjDvN1d7rYDcjPb3CUfsQ==", + "license": "Apache-2.0", + "dependencies": { + "@open-wc/dedupe-mixin": "^1.3.0", + "@polymer/polymer": "^3.0.0", + "@vaadin/a11y-base": "~24.6.7", + "@vaadin/component-base": "~24.6.7", + "@vaadin/vaadin-lumo-styles": "~24.6.7", + "@vaadin/vaadin-material-styles": "~24.6.7", + "@vaadin/vaadin-themable-mixin": "~24.6.7", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/list-box": { + "version": "24.6.7", + "resolved": "https://registry.npmmirror.com/@vaadin/list-box/-/list-box-24.6.7.tgz", + "integrity": "sha512-yUBHonI6uD28l2h+CUh2KPzXe+Ptv6UWtNJIIevX/xkQhptquXzE01bVXlh1NcLVppnu21gaxFs/l+/rHlAKpw==", + "license": "Apache-2.0", + "dependencies": { + "@open-wc/dedupe-mixin": "^1.3.0", + "@polymer/polymer": "^3.0.0", + "@vaadin/a11y-base": "~24.6.7", + "@vaadin/component-base": "~24.6.7", + "@vaadin/item": "~24.6.7", + "@vaadin/vaadin-lumo-styles": "~24.6.7", + "@vaadin/vaadin-material-styles": "~24.6.7", + "@vaadin/vaadin-themable-mixin": "~24.6.7", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/lit-renderer": { + "version": "24.6.11", + "resolved": "https://registry.npmmirror.com/@vaadin/lit-renderer/-/lit-renderer-24.6.11.tgz", + "integrity": "sha512-JugFumbBQP4r28+HcbdDUVVGs5VRsqanLsifjkVrz/xb4saWv460lEYco5ES+StH+xZ2IuJZmEjEFUBSrVR/tA==", + "license": "Apache-2.0", + "dependencies": { + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/notification": { + "version": "24.6.5", + "resolved": "https://registry.npmmirror.com/@vaadin/notification/-/notification-24.6.5.tgz", + "integrity": "sha512-9OgYmZn3qU3pVMaoIRITNs6gymrnswYO7bk9+8e97o3W4A9TIcAO6F2HTgLO5ieMuuOI1DSlVCpXbrM3xBe8pw==", + "license": "Apache-2.0", + "dependencies": { + "@open-wc/dedupe-mixin": "^1.3.0", + "@polymer/polymer": "^3.0.0", + "@vaadin/component-base": "~24.6.5", + "@vaadin/lit-renderer": "~24.6.5", + "@vaadin/overlay": "~24.6.5", + "@vaadin/vaadin-lumo-styles": "~24.6.5", + "@vaadin/vaadin-material-styles": "~24.6.5", + "@vaadin/vaadin-themable-mixin": "~24.6.5", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/overlay": { + "version": "24.6.11", + "resolved": "https://registry.npmmirror.com/@vaadin/overlay/-/overlay-24.6.11.tgz", + "integrity": "sha512-RrN16YWKpg2pxAu6RxlUxBbovf+1+RYRmVvUk0WKBuaA4EBli0mN1vkZbK0XEsPvC2bGHPMWWqwmUzaIXv+1bw==", + "license": "Apache-2.0", + "dependencies": { + "@open-wc/dedupe-mixin": "^1.3.0", + "@polymer/polymer": "^3.0.0", + "@vaadin/a11y-base": "~24.6.11", + "@vaadin/component-base": "~24.6.11", + "@vaadin/vaadin-lumo-styles": "~24.6.11", + "@vaadin/vaadin-material-styles": "~24.6.11", + "@vaadin/vaadin-themable-mixin": "~24.6.11", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/popover": { + "version": "24.6.7", + "resolved": "https://registry.npmmirror.com/@vaadin/popover/-/popover-24.6.7.tgz", + "integrity": "sha512-GqdDsi+x6+6YNBNPC+BvrshrwXlcmL+nR8v5sY+l1TMPVKNWFb2579Qzc9vvu7jMOr2rQd3F+ZjPoMAqgwuZHw==", + "license": "Apache-2.0", + "dependencies": { + "@open-wc/dedupe-mixin": "^1.3.0", + "@vaadin/a11y-base": "~24.6.7", + "@vaadin/component-base": "~24.6.7", + "@vaadin/lit-renderer": "~24.6.7", + "@vaadin/overlay": "~24.6.7", + "@vaadin/vaadin-lumo-styles": "~24.6.7", + "@vaadin/vaadin-material-styles": "~24.6.7", + "@vaadin/vaadin-themable-mixin": "~24.6.7", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/progress-bar": { + "version": "24.6.5", + "resolved": "https://registry.npmmirror.com/@vaadin/progress-bar/-/progress-bar-24.6.5.tgz", + "integrity": "sha512-lJPRV1SAP0Z46pcgQ9RiV8ZVqytDpIDZ7oMJW7WsjS70CAlrqJZF0JoJ3WoqUrHasNhxU7jjx+iXVXw7CzRrDg==", + "license": "Apache-2.0", + "dependencies": { + "@open-wc/dedupe-mixin": "^1.3.0", + "@polymer/polymer": "^3.0.0", + "@vaadin/component-base": "~24.6.5", + "@vaadin/vaadin-lumo-styles": "~24.6.5", + "@vaadin/vaadin-material-styles": "~24.6.5", + "@vaadin/vaadin-themable-mixin": "~24.6.5", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/scroller": { + "version": "24.6.7", + "resolved": "https://registry.npmmirror.com/@vaadin/scroller/-/scroller-24.6.7.tgz", + "integrity": "sha512-JLqrJCVcfo3GELWd8xNLGif+xz4WpiodPn4uW5/kI3lqLKYg7RKhEu9dg1zRpSEUou5SVFQCMB9m+D1AwyoQGQ==", + "license": "Apache-2.0", + "dependencies": { + "@open-wc/dedupe-mixin": "^1.3.0", + "@polymer/polymer": "^3.0.0", + "@vaadin/a11y-base": "~24.6.7", + "@vaadin/component-base": "~24.6.7", + "@vaadin/vaadin-lumo-styles": "~24.6.7", + "@vaadin/vaadin-material-styles": "~24.6.7", + "@vaadin/vaadin-themable-mixin": "~24.6.7", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/select": { + "version": "24.6.5", + "resolved": "https://registry.npmmirror.com/@vaadin/select/-/select-24.6.5.tgz", + "integrity": "sha512-dDVv4d4QLs7EZEJuOkBI/wjmR7mZ5TyUacCKmscq+Ke7DQrq46DuCUjj82+OSFC7z2m3+v5wflfVMciQehR1+Q==", + "license": "Apache-2.0", + "dependencies": { + "@open-wc/dedupe-mixin": "^1.3.0", + "@polymer/polymer": "^3.2.0", + "@vaadin/a11y-base": "~24.6.5", + "@vaadin/button": "~24.6.5", + "@vaadin/component-base": "~24.6.5", + "@vaadin/field-base": "~24.6.5", + "@vaadin/input-container": "~24.6.5", + "@vaadin/item": "~24.6.5", + "@vaadin/list-box": "~24.6.5", + "@vaadin/lit-renderer": "~24.6.5", + "@vaadin/overlay": "~24.6.5", + "@vaadin/vaadin-lumo-styles": "~24.6.5", + "@vaadin/vaadin-material-styles": "~24.6.5", + "@vaadin/vaadin-themable-mixin": "~24.6.5", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/tabs": { + "version": "24.6.5", + "resolved": "https://registry.npmmirror.com/@vaadin/tabs/-/tabs-24.6.5.tgz", + "integrity": "sha512-svUqDjwzlnKsAOYB0szST4Tjhspnb007bMf16fhmkM12u3KK053hEZ2TYX7lNVFLC3RiDvGa8i6nCAK2SVXCDQ==", + "license": "Apache-2.0", + "dependencies": { + "@open-wc/dedupe-mixin": "^1.3.0", + "@polymer/polymer": "^3.0.0", + "@vaadin/a11y-base": "~24.6.5", + "@vaadin/component-base": "~24.6.5", + "@vaadin/item": "~24.6.5", + "@vaadin/vaadin-lumo-styles": "~24.6.5", + "@vaadin/vaadin-material-styles": "~24.6.5", + "@vaadin/vaadin-themable-mixin": "~24.6.5", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/tabsheet": { + "version": "24.6.5", + "resolved": "https://registry.npmmirror.com/@vaadin/tabsheet/-/tabsheet-24.6.5.tgz", + "integrity": "sha512-dn4RFFdK+7Hu6Hhq/V0jb1pwwcLxipgMjAmYGsot4vapqFKSdqea1WpVo6TvVkGXCg3TIrYq5SRbzrIzh9FEzg==", + "license": "Apache-2.0", + "dependencies": { + "@open-wc/dedupe-mixin": "^1.3.0", + "@polymer/polymer": "^3.0.0", + "@vaadin/component-base": "~24.6.5", + "@vaadin/scroller": "~24.6.5", + "@vaadin/tabs": "~24.6.5", + "@vaadin/vaadin-lumo-styles": "~24.6.5", + "@vaadin/vaadin-material-styles": "~24.6.5", + "@vaadin/vaadin-themable-mixin": "~24.6.5", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/text-field": { + "version": "24.6.5", + "resolved": "https://registry.npmmirror.com/@vaadin/text-field/-/text-field-24.6.5.tgz", + "integrity": "sha512-zujt5k6i6pkVbfUiQlYWBGa/MUAmWeq0xhDLgHIapzUlEIq6gf67KFwEfhfmwdVzGQImFTTKUBWhO4DERRF0Nw==", + "license": "Apache-2.0", + "dependencies": { + "@open-wc/dedupe-mixin": "^1.3.0", + "@polymer/polymer": "^3.0.0", + "@vaadin/a11y-base": "~24.6.5", + "@vaadin/component-base": "~24.6.5", + "@vaadin/field-base": "~24.6.5", + "@vaadin/input-container": "~24.6.5", + "@vaadin/vaadin-lumo-styles": "~24.6.5", + "@vaadin/vaadin-material-styles": "~24.6.5", + "@vaadin/vaadin-themable-mixin": "~24.6.5", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/tooltip": { + "version": "24.6.5", + "resolved": "https://registry.npmmirror.com/@vaadin/tooltip/-/tooltip-24.6.5.tgz", + "integrity": "sha512-IPcMN61PO+u9IgHyM3GCqrzSUQUo13Tysvp58Z7OvtZg/IgQpcEtWkC2m+Qg9rwJAZu/x37Qfd/8on0TQWzlMg==", + "license": "Apache-2.0", + "dependencies": { + "@open-wc/dedupe-mixin": "^1.3.0", + "@polymer/polymer": "^3.0.0", + "@vaadin/a11y-base": "~24.6.5", + "@vaadin/component-base": "~24.6.5", + "@vaadin/overlay": "~24.6.5", + "@vaadin/popover": "~24.6.5", + "@vaadin/vaadin-lumo-styles": "~24.6.5", + "@vaadin/vaadin-material-styles": "~24.6.5", + "@vaadin/vaadin-themable-mixin": "~24.6.5", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/vaadin-development-mode-detector": { + "version": "2.0.7", + "resolved": "https://registry.npmmirror.com/@vaadin/vaadin-development-mode-detector/-/vaadin-development-mode-detector-2.0.7.tgz", + "integrity": "sha512-9FhVhr0ynSR3X2ao+vaIEttcNU5XfzCbxtmYOV8uIRnUCtNgbvMOIcyGBvntsX9I5kvIP2dV3cFAOG9SILJzEA==", + "license": "Apache-2.0" + }, + "node_modules/@vaadin/vaadin-lumo-styles": { + "version": "24.6.11", + "resolved": "https://registry.npmmirror.com/@vaadin/vaadin-lumo-styles/-/vaadin-lumo-styles-24.6.11.tgz", + "integrity": "sha512-WRluczao8lZgImdtl66v09YjFULb1iLAhcU48aiR9igAT7h6aLeHYBvRH3AA/gBlUNwHd4xlBSl89p4HP2GGog==", + "license": "Apache-2.0", + "dependencies": { + "@polymer/polymer": "^3.0.0", + "@vaadin/component-base": "~24.6.11", + "@vaadin/icon": "~24.6.11", + "@vaadin/vaadin-themable-mixin": "~24.6.11" + } + }, + "node_modules/@vaadin/vaadin-lumo-styles/node_modules/@vaadin/icon": { + "version": "24.6.11", + "resolved": "https://registry.npmmirror.com/@vaadin/icon/-/icon-24.6.11.tgz", + "integrity": "sha512-CKOh+I84+GZRfMHrhtATtrw3bSW5eUArgGT4cKsOY3asoCZXUdTObPD/PqKfP4e2uAA1bgLl27kOc+W8dmibJA==", + "license": "Apache-2.0", + "dependencies": { + "@open-wc/dedupe-mixin": "^1.3.0", + "@polymer/polymer": "^3.0.0", + "@vaadin/component-base": "~24.6.11", + "@vaadin/vaadin-lumo-styles": "~24.6.11", + "@vaadin/vaadin-themable-mixin": "~24.6.11", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/vaadin-material-styles": { + "version": "24.6.11", + "resolved": "https://registry.npmmirror.com/@vaadin/vaadin-material-styles/-/vaadin-material-styles-24.6.11.tgz", + "integrity": "sha512-tDumwlaDp/s9u++MPi64I1o2ls/drWOZf4xVPhztUjt3NwYJUeVXtwu39q0wBRIeRM7UBrs06kug2CVT72U4qQ==", + "license": "Apache-2.0", + "dependencies": { + "@polymer/polymer": "^3.0.0", + "@vaadin/component-base": "~24.6.11", + "@vaadin/vaadin-themable-mixin": "~24.6.11" + } + }, + "node_modules/@vaadin/vaadin-themable-mixin": { + "version": "24.6.11", + "resolved": "https://registry.npmmirror.com/@vaadin/vaadin-themable-mixin/-/vaadin-themable-mixin-24.6.11.tgz", + "integrity": "sha512-xCmn3X+2C7nI9LQn2OqLLkLw7VeJOCo99DlHwnxeLZpJJ/s8bjDXcIWflS+IOChzHixgEFkDSoLcNYoCR1RvYg==", + "license": "Apache-2.0", + "dependencies": { + "@open-wc/dedupe-mixin": "^1.3.0", + "lit": "^3.0.0" + } + }, + "node_modules/@vaadin/vaadin-usage-statistics": { + "version": "2.1.3", + "resolved": "https://registry.npmmirror.com/@vaadin/vaadin-usage-statistics/-/vaadin-usage-statistics-2.1.3.tgz", + "integrity": "sha512-8r4TNknD7OJQADe3VygeofFR7UNAXZ2/jjBFP5dgI8+2uMfnuGYgbuHivasKr9WSQ64sPej6m8rDoM1uSllXjQ==", + "hasInstallScript": true, + "license": "Apache-2.0", + "dependencies": { + "@vaadin/vaadin-development-mode-detector": "^2.0.0" + }, + "engines": { + "node": "^12.20.0 || ^14.13.1 || >=16.0.0" + } + }, + "node_modules/@webassemblyjs/ast": { + "version": "1.14.1", + "resolved": "https://registry.npmmirror.com/@webassemblyjs/ast/-/ast-1.14.1.tgz", + "integrity": "sha512-nuBEDgQfm1ccRp/8bCQrx1frohyufl4JlbMMZ4P1wpeOfDhF6FQkxZJ1b/e+PLwr6X1Nhw6OLme5usuBWYBvuQ==", + "license": "MIT", + "dependencies": { + "@webassemblyjs/helper-numbers": "1.13.2", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2" + } + }, + "node_modules/@webassemblyjs/floating-point-hex-parser": { + "version": "1.13.2", + "resolved": "https://registry.npmmirror.com/@webassemblyjs/floating-point-hex-parser/-/floating-point-hex-parser-1.13.2.tgz", + "integrity": "sha512-6oXyTOzbKxGH4steLbLNOu71Oj+C8Lg34n6CqRvqfS2O71BxY6ByfMDRhBytzknj9yGUPVJ1qIKhRlAwO1AovA==", + "license": "MIT" + }, + "node_modules/@webassemblyjs/helper-api-error": { + "version": "1.13.2", + "resolved": "https://registry.npmmirror.com/@webassemblyjs/helper-api-error/-/helper-api-error-1.13.2.tgz", + "integrity": "sha512-U56GMYxy4ZQCbDZd6JuvvNV/WFildOjsaWD3Tzzvmw/mas3cXzRJPMjP83JqEsgSbyrmaGjBfDtV7KDXV9UzFQ==", + "license": "MIT" + }, + "node_modules/@webassemblyjs/helper-buffer": { + "version": "1.14.1", + "resolved": "https://registry.npmmirror.com/@webassemblyjs/helper-buffer/-/helper-buffer-1.14.1.tgz", + "integrity": "sha512-jyH7wtcHiKssDtFPRB+iQdxlDf96m0E39yb0k5uJVhFGleZFoNw1c4aeIcVUPPbXUVJ94wwnMOAqUHyzoEPVMA==", + "license": "MIT" + }, + "node_modules/@webassemblyjs/helper-numbers": { + "version": "1.13.2", + "resolved": "https://registry.npmmirror.com/@webassemblyjs/helper-numbers/-/helper-numbers-1.13.2.tgz", + "integrity": "sha512-FE8aCmS5Q6eQYcV3gI35O4J789wlQA+7JrqTTpJqn5emA4U2hvwJmvFRC0HODS+3Ye6WioDklgd6scJ3+PLnEA==", + "license": "MIT", + "dependencies": { + "@webassemblyjs/floating-point-hex-parser": "1.13.2", + "@webassemblyjs/helper-api-error": "1.13.2", + "@xtuc/long": "4.2.2" + } + }, + "node_modules/@webassemblyjs/helper-wasm-bytecode": { + "version": "1.13.2", + "resolved": "https://registry.npmmirror.com/@webassemblyjs/helper-wasm-bytecode/-/helper-wasm-bytecode-1.13.2.tgz", + "integrity": "sha512-3QbLKy93F0EAIXLh0ogEVR6rOubA9AoZ+WRYhNbFyuB70j3dRdwH9g+qXhLAO0kiYGlg3TxDV+I4rQTr/YNXkA==", + "license": "MIT" + }, + "node_modules/@webassemblyjs/helper-wasm-section": { + "version": "1.14.1", + "resolved": "https://registry.npmmirror.com/@webassemblyjs/helper-wasm-section/-/helper-wasm-section-1.14.1.tgz", + "integrity": "sha512-ds5mXEqTJ6oxRoqjhWDU83OgzAYjwsCV8Lo/N+oRsNDmx/ZDpqalmrtgOMkHwxsG0iI//3BwWAErYRHtgn0dZw==", + "license": "MIT", + "dependencies": { + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-buffer": "1.14.1", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2", + "@webassemblyjs/wasm-gen": "1.14.1" + } + }, + "node_modules/@webassemblyjs/ieee754": { + "version": "1.13.2", + "resolved": "https://registry.npmmirror.com/@webassemblyjs/ieee754/-/ieee754-1.13.2.tgz", + "integrity": "sha512-4LtOzh58S/5lX4ITKxnAK2USuNEvpdVV9AlgGQb8rJDHaLeHciwG4zlGr0j/SNWlr7x3vO1lDEsuePvtcDNCkw==", + "license": "MIT", + "dependencies": { + "@xtuc/ieee754": "^1.2.0" + } + }, + "node_modules/@webassemblyjs/leb128": { + "version": "1.13.2", + "resolved": "https://registry.npmmirror.com/@webassemblyjs/leb128/-/leb128-1.13.2.tgz", + "integrity": "sha512-Lde1oNoIdzVzdkNEAWZ1dZ5orIbff80YPdHx20mrHwHrVNNTjNr8E3xz9BdpcGqRQbAEa+fkrCb+fRFTl/6sQw==", + "license": "Apache-2.0", + "dependencies": { + "@xtuc/long": "4.2.2" + } + }, + "node_modules/@webassemblyjs/utf8": { + "version": "1.13.2", + "resolved": "https://registry.npmmirror.com/@webassemblyjs/utf8/-/utf8-1.13.2.tgz", + "integrity": "sha512-3NQWGjKTASY1xV5m7Hr0iPeXD9+RDobLll3T9d2AO+g3my8xy5peVyjSag4I50mR1bBSN/Ct12lo+R9tJk0NZQ==", + "license": "MIT" + }, + "node_modules/@webassemblyjs/wasm-edit": { + "version": "1.14.1", + "resolved": "https://registry.npmmirror.com/@webassemblyjs/wasm-edit/-/wasm-edit-1.14.1.tgz", + "integrity": "sha512-RNJUIQH/J8iA/1NzlE4N7KtyZNHi3w7at7hDjvRNm5rcUXa00z1vRz3glZoULfJ5mpvYhLybmVcwcjGrC1pRrQ==", + "license": "MIT", + "dependencies": { + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-buffer": "1.14.1", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2", + "@webassemblyjs/helper-wasm-section": "1.14.1", + "@webassemblyjs/wasm-gen": "1.14.1", + "@webassemblyjs/wasm-opt": "1.14.1", + "@webassemblyjs/wasm-parser": "1.14.1", + "@webassemblyjs/wast-printer": "1.14.1" + } + }, + "node_modules/@webassemblyjs/wasm-gen": { + "version": "1.14.1", + "resolved": "https://registry.npmmirror.com/@webassemblyjs/wasm-gen/-/wasm-gen-1.14.1.tgz", + "integrity": "sha512-AmomSIjP8ZbfGQhumkNvgC33AY7qtMCXnN6bL2u2Js4gVCg8fp735aEiMSBbDR7UQIj90n4wKAFUSEd0QN2Ukg==", + "license": "MIT", + "dependencies": { + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2", + "@webassemblyjs/ieee754": "1.13.2", + "@webassemblyjs/leb128": "1.13.2", + "@webassemblyjs/utf8": "1.13.2" + } + }, + "node_modules/@webassemblyjs/wasm-opt": { + "version": "1.14.1", + "resolved": "https://registry.npmmirror.com/@webassemblyjs/wasm-opt/-/wasm-opt-1.14.1.tgz", + "integrity": "sha512-PTcKLUNvBqnY2U6E5bdOQcSM+oVP/PmrDY9NzowJjislEjwP/C4an2303MCVS2Mg9d3AJpIGdUFIQQWbPds0Sw==", + "license": "MIT", + "dependencies": { + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-buffer": "1.14.1", + "@webassemblyjs/wasm-gen": "1.14.1", + "@webassemblyjs/wasm-parser": "1.14.1" + } + }, + "node_modules/@webassemblyjs/wasm-parser": { + "version": "1.14.1", + "resolved": "https://registry.npmmirror.com/@webassemblyjs/wasm-parser/-/wasm-parser-1.14.1.tgz", + "integrity": "sha512-JLBl+KZ0R5qB7mCnud/yyX08jWFw5MsoalJ1pQ4EdFlgj9VdXKGuENGsiCIjegI1W7p91rUlcB/LB5yRJKNTcQ==", + "license": "MIT", + "dependencies": { + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-api-error": "1.13.2", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2", + "@webassemblyjs/ieee754": "1.13.2", + "@webassemblyjs/leb128": "1.13.2", + "@webassemblyjs/utf8": "1.13.2" + } + }, + "node_modules/@webassemblyjs/wast-printer": { + "version": "1.14.1", + "resolved": "https://registry.npmmirror.com/@webassemblyjs/wast-printer/-/wast-printer-1.14.1.tgz", + "integrity": "sha512-kPSSXE6De1XOR820C90RIo2ogvZG+c3KiHzqUoO/F34Y2shGzesfqv7o57xrxovZJH/MetF5UjroJ/R/3isoiw==", + "license": "MIT", + "dependencies": { + "@webassemblyjs/ast": "1.14.1", + "@xtuc/long": "4.2.2" + } + }, + "node_modules/@webcomponents/shadycss": { + "version": "1.11.2", + "resolved": "https://registry.npmmirror.com/@webcomponents/shadycss/-/shadycss-1.11.2.tgz", + "integrity": "sha512-vRq+GniJAYSBmTRnhCYPAPq6THYqovJ/gzGThWbgEZUQaBccndGTi1hdiUP15HzEco0I6t4RCtXyX0rsSmwgPw==", + "license": "BSD-3-Clause" + }, + "node_modules/@webpack-cli/configtest": { + "version": "2.1.1", + "resolved": "https://registry.npmmirror.com/@webpack-cli/configtest/-/configtest-2.1.1.tgz", + "integrity": "sha512-wy0mglZpDSiSS0XHrVR+BAdId2+yxPSoJW8fsna3ZpYSlufjvxnP4YbKTCBZnNIcGN4r6ZPXV55X4mYExOfLmw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=14.15.0" + }, + "peerDependencies": { + "webpack": "5.x.x", + "webpack-cli": "5.x.x" + } + }, + "node_modules/@webpack-cli/info": { + "version": "2.0.2", + "resolved": "https://registry.npmmirror.com/@webpack-cli/info/-/info-2.0.2.tgz", + "integrity": "sha512-zLHQdI/Qs1UyT5UBdWNqsARasIA+AaF8t+4u2aS2nEpBQh2mWIVb8qAklq0eUENnC5mOItrIB4LiS9xMtph18A==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=14.15.0" + }, + "peerDependencies": { + "webpack": "5.x.x", + "webpack-cli": "5.x.x" + } + }, + "node_modules/@webpack-cli/serve": { + "version": "2.0.5", + "resolved": "https://registry.npmmirror.com/@webpack-cli/serve/-/serve-2.0.5.tgz", + "integrity": "sha512-lqaoKnRYBdo1UgDX8uF24AfGMifWK19TxPmM5FHc2vAGxrJ/qtyUyFBWoY1tISZdelsQ5fBcOusifo5o5wSJxQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=14.15.0" + }, + "peerDependencies": { + "webpack": "5.x.x", + "webpack-cli": "5.x.x" + }, + "peerDependenciesMeta": { + "webpack-dev-server": { + "optional": true + } + } + }, + "node_modules/@xtuc/ieee754": { + "version": "1.2.0", + "resolved": "https://registry.npmmirror.com/@xtuc/ieee754/-/ieee754-1.2.0.tgz", + "integrity": "sha512-DX8nKgqcGwsc0eJSqYt5lwP4DH5FlHnmuWWBRy7X0NcaGR0ZtuyeESgMwTYVEtxmsNGY+qit4QYT/MIYTOTPeA==", + "license": "BSD-3-Clause" + }, + "node_modules/@xtuc/long": { + "version": "4.2.2", + "resolved": "https://registry.npmmirror.com/@xtuc/long/-/long-4.2.2.tgz", + "integrity": "sha512-NuHqBY1PB/D8xU6s/thBgOAiAP7HOYDQ32+BFZILJ8ivkUkAHQnWfn6WhL79Owj1qmUnoN/YPhktdIoucipkAQ==", + "license": "Apache-2.0" + }, + "node_modules/accepts": { + "version": "1.3.8", + "resolved": "https://registry.npmmirror.com/accepts/-/accepts-1.3.8.tgz", + "integrity": "sha512-PYAthTa2m2VKxuvSD3DPC/Gy+U+sOA1LAuT8mkmRuvw+NACSaeXEQ+NHcVF7rONl6qcaxV3Uuemwawk+7+SJLw==", + "dev": true, + "license": "MIT", + "dependencies": { + "mime-types": "~2.1.34", + "negotiator": "0.6.3" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/accepts/node_modules/negotiator": { + "version": "0.6.3", + "resolved": "https://registry.npmmirror.com/negotiator/-/negotiator-0.6.3.tgz", + "integrity": "sha512-+EUsqGPLsM+j/zdChZjsnX51g4XrHFOIXwfnCVPGlQk/k5giakcKsuxCObBRu6DSm9opw/O6slWbJdghQM4bBg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/acorn": { + "version": "8.14.1", + "resolved": "https://registry.npmmirror.com/acorn/-/acorn-8.14.1.tgz", + "integrity": "sha512-OvQ/2pUDKmgfCg++xsTX1wGxfTaszcHVcTctW4UJB4hibJx2HXxxO5UmVgyjMa+ZDsiaf5wWLXYpRWMmBI0QHg==", + "license": "MIT", + "bin": { + "acorn": "bin/acorn" + }, + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/ajv": { + "version": "8.17.1", + "resolved": "https://registry.npmmirror.com/ajv/-/ajv-8.17.1.tgz", + "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", + "license": "MIT", + "dependencies": { + "fast-deep-equal": "^3.1.3", + "fast-uri": "^3.0.1", + "json-schema-traverse": "^1.0.0", + "require-from-string": "^2.0.2" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/epoberezkin" + } + }, + "node_modules/ajv-formats": { + "version": "2.1.1", + "resolved": "https://registry.npmmirror.com/ajv-formats/-/ajv-formats-2.1.1.tgz", + "integrity": "sha512-Wx0Kx52hxE7C18hkMEggYlEifqWZtYaRgouJor+WMdPnQyEK13vgEWyVNup7SoeeoLMsr4kf5h6dOW11I15MUA==", + "license": "MIT", + "dependencies": { + "ajv": "^8.0.0" + }, + "peerDependencies": { + "ajv": "^8.0.0" + }, + "peerDependenciesMeta": { + "ajv": { + "optional": true + } + } + }, + "node_modules/ajv-keywords": { + "version": "5.1.0", + "resolved": "https://registry.npmmirror.com/ajv-keywords/-/ajv-keywords-5.1.0.tgz", + "integrity": "sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw==", + "license": "MIT", + "dependencies": { + "fast-deep-equal": "^3.1.3" + }, + "peerDependencies": { + "ajv": "^8.8.2" + } + }, + "node_modules/ansi-html-community": { + "version": "0.0.8", + "resolved": "https://registry.npmmirror.com/ansi-html-community/-/ansi-html-community-0.0.8.tgz", + "integrity": "sha512-1APHAyr3+PCamwNw3bXCPp4HFLONZt/yIH0sZp0/469KWNTEy+qN5jQ3GVX6DMZ1UXAi34yVwtTeaG/HpBuuzw==", + "dev": true, + "engines": [ + "node >= 0.8.0" + ], + "license": "Apache-2.0", + "bin": { + "ansi-html": "bin/ansi-html" + } + }, + "node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmmirror.com/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/ansi-styles": { + "version": "4.3.0", + "resolved": "https://registry.npmmirror.com/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "dev": true, + "license": "MIT", + "dependencies": { + "color-convert": "^2.0.1" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/anymatch": { + "version": "3.1.3", + "resolved": "https://registry.npmmirror.com/anymatch/-/anymatch-3.1.3.tgz", + "integrity": "sha512-KMReFUr0B4t+D+OBkjR3KYqvocp2XaSzO55UcB6mgQMd3KbcE+mWTyvVV7D/zsdEbNnV6acZUutkiHQXvTr1Rw==", + "dev": true, + "license": "ISC", + "dependencies": { + "normalize-path": "^3.0.0", + "picomatch": "^2.0.4" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/array-flatten": { + "version": "1.1.1", + "resolved": "https://registry.npmmirror.com/array-flatten/-/array-flatten-1.1.1.tgz", + "integrity": "sha512-PCVAQswWemu6UdxsDFFX/+gVeYqKAod3D3UVm91jHwynguOwAvYPhx8nNlM++NqRcK6CxxpUafjmhIdKiHibqg==", + "dev": true, + "license": "MIT" + }, + "node_modules/array-union": { + "version": "1.0.2", + "resolved": "https://registry.npmmirror.com/array-union/-/array-union-1.0.2.tgz", + "integrity": "sha512-Dxr6QJj/RdU/hCaBjOfxW+q6lyuVE6JFWIrAUpuOOhoJJoQ99cUn3igRaHVB5P9WrgFVN0FfArM3x0cueOU8ng==", + "license": "MIT", + "dependencies": { + "array-uniq": "^1.0.1" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/array-uniq": { + "version": "1.0.3", + "resolved": "https://registry.npmmirror.com/array-uniq/-/array-uniq-1.0.3.tgz", + "integrity": "sha512-MNha4BWQ6JbwhFhj03YK552f7cb3AzoE8SzeljgChvL1dl3IcvggXVz1DilzySZkCja+CXuZbdW7yATchWn8/Q==", + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/asynckit": { + "version": "0.4.0", + "resolved": "https://registry.npmmirror.com/asynckit/-/asynckit-0.4.0.tgz", + "integrity": "sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==", + "license": "MIT" + }, + "node_modules/axios": { + "version": "1.8.4", + "resolved": "https://registry.npmmirror.com/axios/-/axios-1.8.4.tgz", + "integrity": "sha512-eBSYY4Y68NNlHbHBMdeDmKNtDgXWhQsJcGqzO3iLUM0GraQFSS9cVgPX5I9b3lbdFKyYoAEGAZF1DwhTaljNAw==", + "license": "MIT", + "dependencies": { + "follow-redirects": "^1.15.6", + "form-data": "^4.0.0", + "proxy-from-env": "^1.1.0" + } + }, + "node_modules/balanced-match": { + "version": "1.0.2", + "resolved": "https://registry.npmmirror.com/balanced-match/-/balanced-match-1.0.2.tgz", + "integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==", + "license": "MIT" + }, + "node_modules/batch": { + "version": "0.6.1", + "resolved": "https://registry.npmmirror.com/batch/-/batch-0.6.1.tgz", + "integrity": "sha512-x+VAiMRL6UPkx+kudNvxTl6hB2XNNCG2r+7wixVfIYwu/2HKRXimwQyaumLjMveWvT2Hkd/cAJw+QBMfJ/EKVw==", + "dev": true, + "license": "MIT" + }, + "node_modules/binary-extensions": { + "version": "2.3.0", + "resolved": "https://registry.npmmirror.com/binary-extensions/-/binary-extensions-2.3.0.tgz", + "integrity": "sha512-Ceh+7ox5qe7LJuLHoY0feh3pHuUDHAcRUeyL2VYghZwfpkNIy/+8Ocg0a3UuSoYzavmylwuLWQOf3hl0jjMMIw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/body-parser": { + "version": "1.20.3", + "resolved": "https://registry.npmmirror.com/body-parser/-/body-parser-1.20.3.tgz", + "integrity": "sha512-7rAxByjUMqQ3/bHJy7D6OGXvx/MMc4IqBn/X0fcM1QUcAItpZrBEYhWGem+tzXH90c+G01ypMcYJBO9Y30203g==", + "dev": true, + "license": "MIT", + "dependencies": { + "bytes": "3.1.2", + "content-type": "~1.0.5", + "debug": "2.6.9", + "depd": "2.0.0", + "destroy": "1.2.0", + "http-errors": "2.0.0", + "iconv-lite": "0.4.24", + "on-finished": "2.4.1", + "qs": "6.13.0", + "raw-body": "2.5.2", + "type-is": "~1.6.18", + "unpipe": "1.0.0" + }, + "engines": { + "node": ">= 0.8", + "npm": "1.2.8000 || >= 1.4.16" + } + }, + "node_modules/bonjour-service": { + "version": "1.3.0", + "resolved": "https://registry.npmmirror.com/bonjour-service/-/bonjour-service-1.3.0.tgz", + "integrity": "sha512-3YuAUiSkWykd+2Azjgyxei8OWf8thdn8AITIog2M4UICzoqfjlqr64WIjEXZllf/W6vK1goqleSR6brGomxQqA==", + "dev": true, + "license": "MIT", + "dependencies": { + "fast-deep-equal": "^3.1.3", + "multicast-dns": "^7.2.5" + } + }, + "node_modules/boolbase": { + "version": "1.0.0", + "resolved": "https://registry.npmmirror.com/boolbase/-/boolbase-1.0.0.tgz", + "integrity": "sha512-JZOSA7Mo9sNGB8+UjSgzdLtokWAky1zbztM3WRLCbZ70/3cTANmQmOdR7y2g+J0e2WXywy1yS468tY+IruqEww==", + "dev": true, + "license": "ISC" + }, + "node_modules/brace-expansion": { + "version": "1.1.12", + "resolved": "https://registry.npmmirror.com/brace-expansion/-/brace-expansion-1.1.12.tgz", + "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0", + "concat-map": "0.0.1" + } + }, + "node_modules/braces": { + "version": "3.0.3", + "resolved": "https://registry.npmmirror.com/braces/-/braces-3.0.3.tgz", + "integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==", + "dev": true, + "license": "MIT", + "dependencies": { + "fill-range": "^7.1.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/browserslist": { + "version": "4.24.4", + "resolved": "https://registry.npmmirror.com/browserslist/-/browserslist-4.24.4.tgz", + "integrity": "sha512-KDi1Ny1gSePi1vm0q4oxSF8b4DR44GF4BbmS2YdhPLOEqd8pDviZOGH/GsmRwoWJ2+5Lr085X7naowMwKHDG1A==", + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/browserslist" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/browserslist" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "MIT", + "dependencies": { + "caniuse-lite": "^1.0.30001688", + "electron-to-chromium": "^1.5.73", + "node-releases": "^2.0.19", + "update-browserslist-db": "^1.1.1" + }, + "bin": { + "browserslist": "cli.js" + }, + "engines": { + "node": "^6 || ^7 || ^8 || ^9 || ^10 || ^11 || ^12 || >=13.7" + } + }, + "node_modules/buffer-from": { + "version": "1.1.2", + "resolved": "https://registry.npmmirror.com/buffer-from/-/buffer-from-1.1.2.tgz", + "integrity": "sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==", + "license": "MIT" + }, + "node_modules/bytes": { + "version": "3.1.2", + "resolved": "https://registry.npmmirror.com/bytes/-/bytes-3.1.2.tgz", + "integrity": "sha512-/Nf7TyzTx6S3yRJObOAV7956r8cr2+Oj8AC5dt8wSP3BQAoeX58NoHyCU8P8zGkNXStjTSi6fzO6F0pBdcYbEg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/call-bind-apply-helpers": { + "version": "1.0.2", + "resolved": "https://registry.npmmirror.com/call-bind-apply-helpers/-/call-bind-apply-helpers-1.0.2.tgz", + "integrity": "sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==", + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/call-bound": { + "version": "1.0.4", + "resolved": "https://registry.npmmirror.com/call-bound/-/call-bound-1.0.4.tgz", + "integrity": "sha512-+ys997U96po4Kx/ABpBCqhA9EuxJaQWDQg7295H4hBphv3IZg0boBKuwYpt4YXp6MZ5AmZQnU/tyMTlRpaSejg==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": "^1.0.2", + "get-intrinsic": "^1.3.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/camel-case": { + "version": "4.1.2", + "resolved": "https://registry.npmmirror.com/camel-case/-/camel-case-4.1.2.tgz", + "integrity": "sha512-gxGWBrTT1JuMx6R+o5PTXMmUnhnVzLQ9SNutD4YqKtI6ap897t3tKECYla6gCWEkplXnlNybEkZg9GEGxKFCgw==", + "dev": true, + "license": "MIT", + "dependencies": { + "pascal-case": "^3.1.2", + "tslib": "^2.0.3" + } + }, + "node_modules/caniuse-lite": { + "version": "1.0.30001707", + "resolved": "https://registry.npmmirror.com/caniuse-lite/-/caniuse-lite-1.0.30001707.tgz", + "integrity": "sha512-3qtRjw/HQSMlDWf+X79N206fepf4SOOU6SQLMaq/0KkZLmSjPxAkBOQQ+FxbHKfHmYLZFfdWsO3KA90ceHPSnw==", + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/browserslist" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/caniuse-lite" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "CC-BY-4.0" + }, + "node_modules/chalk": { + "version": "4.1.2", + "resolved": "https://registry.npmmirror.com/chalk/-/chalk-4.1.2.tgz", + "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", + "dev": true, + "license": "MIT", + "dependencies": { + "ansi-styles": "^4.1.0", + "supports-color": "^7.1.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/chalk?sponsor=1" + } + }, + "node_modules/chokidar": { + "version": "3.6.0", + "resolved": "https://registry.npmmirror.com/chokidar/-/chokidar-3.6.0.tgz", + "integrity": "sha512-7VT13fmjotKpGipCW9JEQAusEPE+Ei8nl6/g4FBAmIm0GOOLMua9NDDo/DWp0ZAxCr3cPq5ZpBqmPAQgDda2Pw==", + "dev": true, + "license": "MIT", + "dependencies": { + "anymatch": "~3.1.2", + "braces": "~3.0.2", + "glob-parent": "~5.1.2", + "is-binary-path": "~2.1.0", + "is-glob": "~4.0.1", + "normalize-path": "~3.0.0", + "readdirp": "~3.6.0" + }, + "engines": { + "node": ">= 8.10.0" + }, + "funding": { + "url": "https://paulmillr.com/funding/" + }, + "optionalDependencies": { + "fsevents": "~2.3.2" + } + }, + "node_modules/chrome-trace-event": { + "version": "1.0.4", + "resolved": "https://registry.npmmirror.com/chrome-trace-event/-/chrome-trace-event-1.0.4.tgz", + "integrity": "sha512-rNjApaLzuwaOTjCiT8lSDdGN1APCiqkChLMJxJPWLunPAt5fy8xgU9/jNOchV84wfIxrA0lRQB7oCT8jrn/wrQ==", + "license": "MIT", + "engines": { + "node": ">=6.0" + } + }, + "node_modules/clean-css": { + "version": "5.3.3", + "resolved": "https://registry.npmmirror.com/clean-css/-/clean-css-5.3.3.tgz", + "integrity": "sha512-D5J+kHaVb/wKSFcyyV75uCn8fiY4sV38XJoe4CUyGQ+mOU/fMVYUdH1hJC+CJQ5uY3EnW27SbJYS4X8BiLrAFg==", + "dev": true, + "license": "MIT", + "dependencies": { + "source-map": "~0.6.0" + }, + "engines": { + "node": ">= 10.0" + } + }, + "node_modules/clean-webpack-plugin": { + "version": "4.0.0", + "resolved": "https://registry.npmmirror.com/clean-webpack-plugin/-/clean-webpack-plugin-4.0.0.tgz", + "integrity": "sha512-WuWE1nyTNAyW5T7oNyys2EN0cfP2fdRxhxnIQWiAp0bMabPdHhoGxM8A6YL2GhqwgrPnnaemVE7nv5XJ2Fhh2w==", + "license": "MIT", + "dependencies": { + "del": "^4.1.1" + }, + "engines": { + "node": ">=10.0.0" + }, + "peerDependencies": { + "webpack": ">=4.0.0 <6.0.0" + } + }, + "node_modules/clone-deep": { + "version": "4.0.1", + "resolved": "https://registry.npmmirror.com/clone-deep/-/clone-deep-4.0.1.tgz", + "integrity": "sha512-neHB9xuzh/wk0dIHweyAXv2aPGZIVk3pLMe+/RNzINf17fe0OG96QroktYAUm7SM1PBnzTabaLboqqxDyMU+SQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-plain-object": "^2.0.4", + "kind-of": "^6.0.2", + "shallow-clone": "^3.0.0" + }, + "engines": { + "node": ">=6" + } + }, + "node_modules/color-convert": { + "version": "2.0.1", + "resolved": "https://registry.npmmirror.com/color-convert/-/color-convert-2.0.1.tgz", + "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "color-name": "~1.1.4" + }, + "engines": { + "node": ">=7.0.0" + } + }, + "node_modules/color-name": { + "version": "1.1.4", + "resolved": "https://registry.npmmirror.com/color-name/-/color-name-1.1.4.tgz", + "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", + "dev": true, + "license": "MIT" + }, + "node_modules/colorette": { + "version": "2.0.20", + "resolved": "https://registry.npmmirror.com/colorette/-/colorette-2.0.20.tgz", + "integrity": "sha512-IfEDxwoWIjkeXL1eXcDiow4UbKjhLdq6/EuSVR9GMN7KVH3r9gQ83e73hsz1Nd1T3ijd5xv1wcWRYO+D6kCI2w==", + "dev": true, + "license": "MIT" + }, + "node_modules/combined-stream": { + "version": "1.0.8", + "resolved": "https://registry.npmmirror.com/combined-stream/-/combined-stream-1.0.8.tgz", + "integrity": "sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==", + "license": "MIT", + "dependencies": { + "delayed-stream": "~1.0.0" + }, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/commander": { + "version": "2.20.3", + "resolved": "https://registry.npmmirror.com/commander/-/commander-2.20.3.tgz", + "integrity": "sha512-GpVkmM8vF2vQUkj2LvZmD35JxeJOLCwJ9cUkugyk2nuhbv3+mJvpLYYt+0+USMxE+oj+ey/lJEnhZw75x/OMcQ==", + "license": "MIT" + }, + "node_modules/compressible": { + "version": "2.0.18", + "resolved": "https://registry.npmmirror.com/compressible/-/compressible-2.0.18.tgz", + "integrity": "sha512-AF3r7P5dWxL8MxyITRMlORQNaOA2IkAFaTr4k7BUumjPtRpGDTZpl0Pb1XCO6JeDCBdp126Cgs9sMxqSjgYyRg==", + "dev": true, + "license": "MIT", + "dependencies": { + "mime-db": ">= 1.43.0 < 2" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/compression": { + "version": "1.8.0", + "resolved": "https://registry.npmmirror.com/compression/-/compression-1.8.0.tgz", + "integrity": "sha512-k6WLKfunuqCYD3t6AsuPGvQWaKwuLLh2/xHNcX4qE+vIfDNXpSqnrhwA7O53R7WVQUnt8dVAIW+YHr7xTgOgGA==", + "dev": true, + "license": "MIT", + "dependencies": { + "bytes": "3.1.2", + "compressible": "~2.0.18", + "debug": "2.6.9", + "negotiator": "~0.6.4", + "on-headers": "~1.0.2", + "safe-buffer": "5.2.1", + "vary": "~1.1.2" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/concat-map": { + "version": "0.0.1", + "resolved": "https://registry.npmmirror.com/concat-map/-/concat-map-0.0.1.tgz", + "integrity": "sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==", + "license": "MIT" + }, + "node_modules/connect-history-api-fallback": { + "version": "2.0.0", + "resolved": "https://registry.npmmirror.com/connect-history-api-fallback/-/connect-history-api-fallback-2.0.0.tgz", + "integrity": "sha512-U73+6lQFmfiNPrYbXqr6kZ1i1wiRqXnp2nhMsINseWXO8lDau0LGEffJ8kQi4EjLZympVgRdvqjAgiZ1tgzDDA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.8" + } + }, + "node_modules/content-disposition": { + "version": "0.5.4", + "resolved": "https://registry.npmmirror.com/content-disposition/-/content-disposition-0.5.4.tgz", + "integrity": "sha512-FveZTNuGw04cxlAiWbzi6zTAL/lhehaWbTtgluJh4/E95DqMwTmha3KZN1aAWA8cFIhHzMZUvLevkw5Rqk+tSQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "safe-buffer": "5.2.1" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/content-type": { + "version": "1.0.5", + "resolved": "https://registry.npmmirror.com/content-type/-/content-type-1.0.5.tgz", + "integrity": "sha512-nTjqfcBFEipKdXCv4YDQWCfmcLZKm81ldF0pAopTvyrFGVbcR6P/VAAd5G7N+0tTr8QqiU0tFadD6FK4NtJwOA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/cookie": { + "version": "0.7.1", + "resolved": "https://registry.npmmirror.com/cookie/-/cookie-0.7.1.tgz", + "integrity": "sha512-6DnInpx7SJ2AK3+CTUE/ZM0vWTUboZCegxhC2xiIydHR9jNuTAASBrfEpHhiGOZw/nX51bHt6YQl8jsGo4y/0w==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/cookie-signature": { + "version": "1.0.6", + "resolved": "https://registry.npmmirror.com/cookie-signature/-/cookie-signature-1.0.6.tgz", + "integrity": "sha512-QADzlaHc8icV8I7vbaJXJwod9HWYp8uCqf1xa4OfNu1T7JVxQIrUgOWtHdNDtPiywmFbiS12VjotIXLrKM3orQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/core-util-is": { + "version": "1.0.3", + "resolved": "https://registry.npmmirror.com/core-util-is/-/core-util-is-1.0.3.tgz", + "integrity": "sha512-ZQBvi1DcpJ4GDqanjucZ2Hj3wEO5pZDS89BWbkcrvdxksJorwUDDZamX9ldFkp9aw2lmBDLgkObEA4DWNJ9FYQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/cross-env": { + "version": "7.0.3", + "resolved": "https://registry.npmmirror.com/cross-env/-/cross-env-7.0.3.tgz", + "integrity": "sha512-+/HKd6EgcQCJGh2PSjZuUitQBQynKor4wrFbRg4DtAgS1aWO+gU52xpH7M9ScGgXSYmAVS9bIJ8EzuaGw0oNAw==", + "license": "MIT", + "dependencies": { + "cross-spawn": "^7.0.1" + }, + "bin": { + "cross-env": "src/bin/cross-env.js", + "cross-env-shell": "src/bin/cross-env-shell.js" + }, + "engines": { + "node": ">=10.14", + "npm": ">=6", + "yarn": ">=1" + } + }, + "node_modules/cross-spawn": { + "version": "7.0.6", + "resolved": "https://registry.npmmirror.com/cross-spawn/-/cross-spawn-7.0.6.tgz", + "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", + "license": "MIT", + "dependencies": { + "path-key": "^3.1.0", + "shebang-command": "^2.0.0", + "which": "^2.0.1" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/css-loader": { + "version": "7.1.2", + "resolved": "https://registry.npmmirror.com/css-loader/-/css-loader-7.1.2.tgz", + "integrity": "sha512-6WvYYn7l/XEGN8Xu2vWFt9nVzrCn39vKyTEFf/ExEyoksJjjSZV/0/35XPlMbpnr6VGhZIUg5yJrL8tGfes/FA==", + "license": "MIT", + "dependencies": { + "icss-utils": "^5.1.0", + "postcss": "^8.4.33", + "postcss-modules-extract-imports": "^3.1.0", + "postcss-modules-local-by-default": "^4.0.5", + "postcss-modules-scope": "^3.2.0", + "postcss-modules-values": "^4.0.0", + "postcss-value-parser": "^4.2.0", + "semver": "^7.5.4" + }, + "engines": { + "node": ">= 18.12.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" + }, + "peerDependencies": { + "@rspack/core": "0.x || 1.x", + "webpack": "^5.27.0" + }, + "peerDependenciesMeta": { + "@rspack/core": { + "optional": true + }, + "webpack": { + "optional": true + } + } + }, + "node_modules/css-select": { + "version": "4.3.0", + "resolved": "https://registry.npmmirror.com/css-select/-/css-select-4.3.0.tgz", + "integrity": "sha512-wPpOYtnsVontu2mODhA19JrqWxNsfdatRKd64kmpRbQgh1KtItko5sTnEpPdpSaJszTOhEMlF/RPz28qj4HqhQ==", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "boolbase": "^1.0.0", + "css-what": "^6.0.1", + "domhandler": "^4.3.1", + "domutils": "^2.8.0", + "nth-check": "^2.0.1" + }, + "funding": { + "url": "https://github.com/sponsors/fb55" + } + }, + "node_modules/css-what": { + "version": "6.1.0", + "resolved": "https://registry.npmmirror.com/css-what/-/css-what-6.1.0.tgz", + "integrity": "sha512-HTUrgRJ7r4dsZKU6GjmpfRK1O76h97Z8MfS1G0FozR+oF2kG6Vfe8JE6zwrkbxigziPHinCJ+gCPjA9EaBDtRw==", + "dev": true, + "license": "BSD-2-Clause", + "engines": { + "node": ">= 6" + }, + "funding": { + "url": "https://github.com/sponsors/fb55" + } + }, + "node_modules/cssesc": { + "version": "3.0.0", + "resolved": "https://registry.npmmirror.com/cssesc/-/cssesc-3.0.0.tgz", + "integrity": "sha512-/Tb/JcjK111nNScGob5MNtsntNM1aCNUDipB/TkwZFhyDrrE47SOx/18wF2bbjgc3ZzCSKW1T5nt5EbFoAz/Vg==", + "license": "MIT", + "bin": { + "cssesc": "bin/cssesc" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/d3": { + "version": "7.9.0", + "resolved": "https://registry.npmmirror.com/d3/-/d3-7.9.0.tgz", + "integrity": "sha512-e1U46jVP+w7Iut8Jt8ri1YsPOvFpg46k+K8TpCb0P+zjCkjkPnV7WzfDJzMHy1LnA+wj5pLT1wjO901gLXeEhA==", + "dependencies": { + "d3-array": "3", + "d3-axis": "3", + "d3-brush": "3", + "d3-chord": "3", + "d3-color": "3", + "d3-contour": "4", + "d3-delaunay": "6", + "d3-dispatch": "3", + "d3-drag": "3", + "d3-dsv": "3", + "d3-ease": "3", + "d3-fetch": "3", + "d3-force": "3", + "d3-format": "3", + "d3-geo": "3", + "d3-hierarchy": "3", + "d3-interpolate": "3", + "d3-path": "3", + "d3-polygon": "3", + "d3-quadtree": "3", + "d3-random": "3", + "d3-scale": "4", + "d3-scale-chromatic": "3", + "d3-selection": "3", + "d3-shape": "3", + "d3-time": "3", + "d3-time-format": "4", + "d3-timer": "3", + "d3-transition": "3", + "d3-zoom": "3" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-array": { + "version": "3.2.4", + "resolved": "https://registry.npmmirror.com/d3-array/-/d3-array-3.2.4.tgz", + "integrity": "sha512-tdQAmyA18i4J7wprpYq8ClcxZy3SC31QMeByyCFyRt7BVHdREQZ5lpzoe5mFEYZUWe+oq8HBvk9JjpibyEV4Jg==", + "dependencies": { + "internmap": "1 - 2" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-axis": { + "version": "3.0.0", + "resolved": "https://registry.npmmirror.com/d3-axis/-/d3-axis-3.0.0.tgz", + "integrity": "sha512-IH5tgjV4jE/GhHkRV0HiVYPDtvfjHQlQfJHs0usq7M30XcSBvOotpmH1IgkcXsO/5gEQZD43B//fc7SRT5S+xw==", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-brush": { + "version": "3.0.0", + "resolved": "https://registry.npmmirror.com/d3-brush/-/d3-brush-3.0.0.tgz", + "integrity": "sha512-ALnjWlVYkXsVIGlOsuWH1+3udkYFI48Ljihfnh8FZPF2QS9o+PzGLBslO0PjzVoHLZ2KCVgAM8NVkXPJB2aNnQ==", + "dependencies": { + "d3-dispatch": "1 - 3", + "d3-drag": "2 - 3", + "d3-interpolate": "1 - 3", + "d3-selection": "3", + "d3-transition": "3" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-chord": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/d3-chord/-/d3-chord-3.0.1.tgz", + "integrity": "sha512-VE5S6TNa+j8msksl7HwjxMHDM2yNK3XCkusIlpX5kwauBfXuyLAtNg9jCp/iHH61tgI4sb6R/EIMWCqEIdjT/g==", + "dependencies": { + "d3-path": "1 - 3" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-color": { + "version": "3.1.0", + "resolved": "https://registry.npmmirror.com/d3-color/-/d3-color-3.1.0.tgz", + "integrity": "sha512-zg/chbXyeBtMQ1LbD/WSoW2DpC3I0mpmPdW+ynRTj/x2DAWYrIY7qeZIHidozwV24m4iavr15lNwIwLxRmOxhA==", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-contour": { + "version": "4.0.2", + "resolved": "https://registry.npmmirror.com/d3-contour/-/d3-contour-4.0.2.tgz", + "integrity": "sha512-4EzFTRIikzs47RGmdxbeUvLWtGedDUNkTcmzoeyg4sP/dvCexO47AaQL7VKy/gul85TOxw+IBgA8US2xwbToNA==", + "dependencies": { + "d3-array": "^3.2.0" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-delaunay": { + "version": "6.0.4", + "resolved": "https://registry.npmmirror.com/d3-delaunay/-/d3-delaunay-6.0.4.tgz", + "integrity": "sha512-mdjtIZ1XLAM8bm/hx3WwjfHt6Sggek7qH043O8KEjDXN40xi3vx/6pYSVTwLjEgiXQTbvaouWKynLBiUZ6SK6A==", + "dependencies": { + "delaunator": "5" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-dispatch": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/d3-dispatch/-/d3-dispatch-3.0.1.tgz", + "integrity": "sha512-rzUyPU/S7rwUflMyLc1ETDeBj0NRuHKKAcvukozwhshr6g6c5d8zh4c2gQjY2bZ0dXeGLWc1PF174P2tVvKhfg==", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-drag": { + "version": "3.0.0", + "resolved": "https://registry.npmmirror.com/d3-drag/-/d3-drag-3.0.0.tgz", + "integrity": "sha512-pWbUJLdETVA8lQNJecMxoXfH6x+mO2UQo8rSmZ+QqxcbyA3hfeprFgIT//HW2nlHChWeIIMwS2Fq+gEARkhTkg==", + "dependencies": { + "d3-dispatch": "1 - 3", + "d3-selection": "3" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-dsv": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/d3-dsv/-/d3-dsv-3.0.1.tgz", + "integrity": "sha512-UG6OvdI5afDIFP9w4G0mNq50dSOsXHJaRE8arAS5o9ApWnIElp8GZw1Dun8vP8OyHOZ/QJUKUJwxiiCCnUwm+Q==", + "dependencies": { + "commander": "7", + "iconv-lite": "0.6", + "rw": "1" + }, + "bin": { + "csv2json": "bin/dsv2json.js", + "csv2tsv": "bin/dsv2dsv.js", + "dsv2dsv": "bin/dsv2dsv.js", + "dsv2json": "bin/dsv2json.js", + "json2csv": "bin/json2dsv.js", + "json2dsv": "bin/json2dsv.js", + "json2tsv": "bin/json2dsv.js", + "tsv2csv": "bin/dsv2dsv.js", + "tsv2json": "bin/dsv2json.js" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-dsv/node_modules/commander": { + "version": "7.2.0", + "resolved": "https://registry.npmmirror.com/commander/-/commander-7.2.0.tgz", + "integrity": "sha512-QrWXB+ZQSVPmIWIhtEO9H+gwHaMGYiF5ChvoJ+K9ZGHG/sVsa6yiesAD1GC/x46sET00Xlwo1u49RVVVzvcSkw==", + "engines": { + "node": ">= 10" + } + }, + "node_modules/d3-dsv/node_modules/iconv-lite": { + "version": "0.6.3", + "resolved": "https://registry.npmmirror.com/iconv-lite/-/iconv-lite-0.6.3.tgz", + "integrity": "sha512-4fCk79wshMdzMp2rH06qWrJE4iolqLhCUH+OiuIgU++RB0+94NlDL81atO7GX55uUKueo0txHNtvEyI6D7WdMw==", + "dependencies": { + "safer-buffer": ">= 2.1.2 < 3.0.0" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/d3-ease": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/d3-ease/-/d3-ease-3.0.1.tgz", + "integrity": "sha512-wR/XK3D3XcLIZwpbvQwQ5fK+8Ykds1ip7A2Txe0yxncXSdq1L9skcG7blcedkOX+ZcgxGAmLX1FrRGbADwzi0w==", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-fetch": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/d3-fetch/-/d3-fetch-3.0.1.tgz", + "integrity": "sha512-kpkQIM20n3oLVBKGg6oHrUchHM3xODkTzjMoj7aWQFq5QEM+R6E4WkzT5+tojDY7yjez8KgCBRoj4aEr99Fdqw==", + "dependencies": { + "d3-dsv": "1 - 3" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-force": { + "version": "3.0.0", + "resolved": "https://registry.npmmirror.com/d3-force/-/d3-force-3.0.0.tgz", + "integrity": "sha512-zxV/SsA+U4yte8051P4ECydjD/S+qeYtnaIyAs9tgHCqfguma/aAQDjo85A9Z6EKhBirHRJHXIgJUlffT4wdLg==", + "dependencies": { + "d3-dispatch": "1 - 3", + "d3-quadtree": "1 - 3", + "d3-timer": "1 - 3" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-format": { + "version": "3.1.0", + "resolved": "https://registry.npmmirror.com/d3-format/-/d3-format-3.1.0.tgz", + "integrity": "sha512-YyUI6AEuY/Wpt8KWLgZHsIU86atmikuoOmCfommt0LYHiQSPjvX2AcFc38PX0CBpr2RCyZhjex+NS/LPOv6YqA==", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-geo": { + "version": "3.1.1", + "resolved": "https://registry.npmmirror.com/d3-geo/-/d3-geo-3.1.1.tgz", + "integrity": "sha512-637ln3gXKXOwhalDzinUgY83KzNWZRKbYubaG+fGVuc/dxO64RRljtCTnf5ecMyE1RIdtqpkVcq0IbtU2S8j2Q==", + "dependencies": { + "d3-array": "2.5.0 - 3" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-hierarchy": { + "version": "3.1.2", + "resolved": "https://registry.npmmirror.com/d3-hierarchy/-/d3-hierarchy-3.1.2.tgz", + "integrity": "sha512-FX/9frcub54beBdugHjDCdikxThEqjnR93Qt7PvQTOHxyiNCAlvMrHhclk3cD5VeAaq9fxmfRp+CnWw9rEMBuA==", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-interpolate": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/d3-interpolate/-/d3-interpolate-3.0.1.tgz", + "integrity": "sha512-3bYs1rOD33uo8aqJfKP3JWPAibgw8Zm2+L9vBKEHJ2Rg+viTR7o5Mmv5mZcieN+FRYaAOWX5SJATX6k1PWz72g==", + "dependencies": { + "d3-color": "1 - 3" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-path": { + "version": "3.1.0", + "resolved": "https://registry.npmmirror.com/d3-path/-/d3-path-3.1.0.tgz", + "integrity": "sha512-p3KP5HCf/bvjBSSKuXid6Zqijx7wIfNW+J/maPs+iwR35at5JCbLUT0LzF1cnjbCHWhqzQTIN2Jpe8pRebIEFQ==", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-polygon": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/d3-polygon/-/d3-polygon-3.0.1.tgz", + "integrity": "sha512-3vbA7vXYwfe1SYhED++fPUQlWSYTTGmFmQiany/gdbiWgU/iEyQzyymwL9SkJjFFuCS4902BSzewVGsHHmHtXg==", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-quadtree": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/d3-quadtree/-/d3-quadtree-3.0.1.tgz", + "integrity": "sha512-04xDrxQTDTCFwP5H6hRhsRcb9xxv2RzkcsygFzmkSIOJy3PeRJP7sNk3VRIbKXcog561P9oU0/rVH6vDROAgUw==", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-random": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/d3-random/-/d3-random-3.0.1.tgz", + "integrity": "sha512-FXMe9GfxTxqd5D6jFsQ+DJ8BJS4E/fT5mqqdjovykEB2oFbTMDVdg1MGFxfQW+FBOGoB++k8swBrgwSHT1cUXQ==", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-scale": { + "version": "4.0.2", + "resolved": "https://registry.npmmirror.com/d3-scale/-/d3-scale-4.0.2.tgz", + "integrity": "sha512-GZW464g1SH7ag3Y7hXjf8RoUuAFIqklOAq3MRl4OaWabTFJY9PN/E1YklhXLh+OQ3fM9yS2nOkCoS+WLZ6kvxQ==", + "dependencies": { + "d3-array": "2.10.0 - 3", + "d3-format": "1 - 3", + "d3-interpolate": "1.2.0 - 3", + "d3-time": "2.1.1 - 3", + "d3-time-format": "2 - 4" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-scale-chromatic": { + "version": "3.1.0", + "resolved": "https://registry.npmmirror.com/d3-scale-chromatic/-/d3-scale-chromatic-3.1.0.tgz", + "integrity": "sha512-A3s5PWiZ9YCXFye1o246KoscMWqf8BsD9eRiJ3He7C9OBaxKhAd5TFCdEx/7VbKtxxTsu//1mMJFrEt572cEyQ==", + "dependencies": { + "d3-color": "1 - 3", + "d3-interpolate": "1 - 3" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-selection": { + "version": "3.0.0", + "resolved": "https://registry.npmmirror.com/d3-selection/-/d3-selection-3.0.0.tgz", + "integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-shape": { + "version": "3.2.0", + "resolved": "https://registry.npmmirror.com/d3-shape/-/d3-shape-3.2.0.tgz", + "integrity": "sha512-SaLBuwGm3MOViRq2ABk3eLoxwZELpH6zhl3FbAoJ7Vm1gofKx6El1Ib5z23NUEhF9AsGl7y+dzLe5Cw2AArGTA==", + "dependencies": { + "d3-path": "^3.1.0" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-time": { + "version": "3.1.0", + "resolved": "https://registry.npmmirror.com/d3-time/-/d3-time-3.1.0.tgz", + "integrity": "sha512-VqKjzBLejbSMT4IgbmVgDjpkYrNWUYJnbCGo874u7MMKIWsILRX+OpX/gTk8MqjpT1A/c6HY2dCA77ZN0lkQ2Q==", + "dependencies": { + "d3-array": "2 - 3" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-time-format": { + "version": "4.1.0", + "resolved": "https://registry.npmmirror.com/d3-time-format/-/d3-time-format-4.1.0.tgz", + "integrity": "sha512-dJxPBlzC7NugB2PDLwo9Q8JiTR3M3e4/XANkreKSUxF8vvXKqm1Yfq4Q5dl8budlunRVlUUaDUgFt7eA8D6NLg==", + "dependencies": { + "d3-time": "1 - 3" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-timer": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/d3-timer/-/d3-timer-3.0.1.tgz", + "integrity": "sha512-ndfJ/JxxMd3nw31uyKoY2naivF+r29V+Lc0svZxe1JvvIRmi8hUsrMvdOwgS1o6uBHmiz91geQ0ylPP0aj1VUA==", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-transition": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/d3-transition/-/d3-transition-3.0.1.tgz", + "integrity": "sha512-ApKvfjsSR6tg06xrL434C0WydLr7JewBB3V+/39RMHsaXTOG0zmt/OAXeng5M5LBm0ojmxJrpomQVZ1aPvBL4w==", + "dependencies": { + "d3-color": "1 - 3", + "d3-dispatch": "1 - 3", + "d3-ease": "1 - 3", + "d3-interpolate": "1 - 3", + "d3-timer": "1 - 3" + }, + "engines": { + "node": ">=12" + }, + "peerDependencies": { + "d3-selection": "2 - 3" + } + }, + "node_modules/d3-zoom": { + "version": "3.0.0", + "resolved": "https://registry.npmmirror.com/d3-zoom/-/d3-zoom-3.0.0.tgz", + "integrity": "sha512-b8AmV3kfQaqWAuacbPuNbL6vahnOJflOhexLzMMNLga62+/nh0JzvJ0aO/5a5MVgUFGS7Hu1P9P03o3fJkDCyw==", + "dependencies": { + "d3-dispatch": "1 - 3", + "d3-drag": "2 - 3", + "d3-interpolate": "1 - 3", + "d3-selection": "2 - 3", + "d3-transition": "2 - 3" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/dagre": { + "version": "0.8.5", + "resolved": "https://registry.npmmirror.com/dagre/-/dagre-0.8.5.tgz", + "integrity": "sha512-/aTqmnRta7x7MCCpExk7HQL2O4owCT2h8NT//9I1OQ9vt29Pa0BzSAkR5lwFUcQ7491yVi/3CXU9jQ5o0Mn2Sw==", + "license": "MIT", + "dependencies": { + "graphlib": "^2.1.8", + "lodash": "^4.17.15" + } + }, + "node_modules/debug": { + "version": "2.6.9", + "resolved": "https://registry.npmmirror.com/debug/-/debug-2.6.9.tgz", + "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", + "dev": true, + "license": "MIT", + "dependencies": { + "ms": "2.0.0" + } + }, + "node_modules/default-gateway": { + "version": "6.0.3", + "resolved": "https://registry.npmmirror.com/default-gateway/-/default-gateway-6.0.3.tgz", + "integrity": "sha512-fwSOJsbbNzZ/CUFpqFBqYfYNLj1NbMPm8MMCIzHjC83iSJRBEGmDUxU+WP661BaBQImeC2yHwXtz+P/O9o+XEg==", + "dev": true, + "dependencies": { + "execa": "^5.0.0" + }, + "engines": { + "node": ">= 10" + } + }, + "node_modules/define-lazy-prop": { + "version": "2.0.0", + "resolved": "https://registry.npmmirror.com/define-lazy-prop/-/define-lazy-prop-2.0.0.tgz", + "integrity": "sha512-Ds09qNh8yw3khSjiJjiUInaGX9xlqZDY7JVryGxdxV7NPeuqQfplOpQ66yJFZut3jLa5zOwkXw1g9EI2uKh4Og==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/del": { + "version": "4.1.1", + "resolved": "https://registry.npmmirror.com/del/-/del-4.1.1.tgz", + "integrity": "sha512-QwGuEUouP2kVwQenAsOof5Fv8K9t3D8Ca8NxcXKrIpEHjTXK5J2nXLdP+ALI1cgv8wj7KuwBhTwBkOZSJKM5XQ==", + "license": "MIT", + "dependencies": { + "@types/glob": "^7.1.1", + "globby": "^6.1.0", + "is-path-cwd": "^2.0.0", + "is-path-in-cwd": "^2.0.0", + "p-map": "^2.0.0", + "pify": "^4.0.1", + "rimraf": "^2.6.3" + }, + "engines": { + "node": ">=6" + } + }, + "node_modules/delaunator": { + "version": "5.0.1", + "resolved": "https://registry.npmmirror.com/delaunator/-/delaunator-5.0.1.tgz", + "integrity": "sha512-8nvh+XBe96aCESrGOqMp/84b13H9cdKbG5P2ejQCh4d4sK9RL4371qou9drQjMhvnPmhWl5hnmqbEE0fXr9Xnw==", + "dependencies": { + "robust-predicates": "^3.0.2" + } + }, + "node_modules/delayed-stream": { + "version": "1.0.0", + "resolved": "https://registry.npmmirror.com/delayed-stream/-/delayed-stream-1.0.0.tgz", + "integrity": "sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==", + "license": "MIT", + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/depd": { + "version": "2.0.0", + "resolved": "https://registry.npmmirror.com/depd/-/depd-2.0.0.tgz", + "integrity": "sha512-g7nH6P6dyDioJogAAGprGpCtVImJhpPk/roCzdb3fIh61/s/nPsfR6onyMwkCAR/OlC3yBC0lESvUoQEAssIrw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/destroy": { + "version": "1.2.0", + "resolved": "https://registry.npmmirror.com/destroy/-/destroy-1.2.0.tgz", + "integrity": "sha512-2sJGJTaXIIaR1w4iJSNoN0hnMY7Gpc/n8D4qSCJw8QqFWXf7cuAgnEHxBpweaVcPevC2l3KpjYCx3NypQQgaJg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.8", + "npm": "1.2.8000 || >= 1.4.16" + } + }, + "node_modules/detect-node": { + "version": "2.1.0", + "resolved": "https://registry.npmmirror.com/detect-node/-/detect-node-2.1.0.tgz", + "integrity": "sha512-T0NIuQpnTvFDATNuHN5roPwSBG83rFsuO+MXXH9/3N1eFbn4wcPjttvjMLEPWJ0RGUYgQE7cGgS3tNxbqCGM7g==", + "dev": true, + "license": "MIT" + }, + "node_modules/dns-packet": { + "version": "5.6.1", + "resolved": "https://registry.npmmirror.com/dns-packet/-/dns-packet-5.6.1.tgz", + "integrity": "sha512-l4gcSouhcgIKRvyy99RNVOgxXiicE+2jZoNmaNmZ6JXiGajBOJAesk1OBlJuM5k2c+eudGdLxDqXuPCKIj6kpw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@leichtgewicht/ip-codec": "^2.0.1" + }, + "engines": { + "node": ">=6" + } + }, + "node_modules/dom-converter": { + "version": "0.2.0", + "resolved": "https://registry.npmmirror.com/dom-converter/-/dom-converter-0.2.0.tgz", + "integrity": "sha512-gd3ypIPfOMr9h5jIKq8E3sHOTCjeirnl0WK5ZdS1AW0Odt0b1PaWaHdJ4Qk4klv+YB9aJBS7mESXjFoDQPu6DA==", + "dev": true, + "license": "MIT", + "dependencies": { + "utila": "~0.4" + } + }, + "node_modules/dom-serializer": { + "version": "1.4.1", + "resolved": "https://registry.npmmirror.com/dom-serializer/-/dom-serializer-1.4.1.tgz", + "integrity": "sha512-VHwB3KfrcOOkelEG2ZOfxqLZdfkil8PtJi4P8N2MMXucZq2yLp75ClViUlOVwyoHEDjYU433Aq+5zWP61+RGag==", + "dev": true, + "license": "MIT", + "dependencies": { + "domelementtype": "^2.0.1", + "domhandler": "^4.2.0", + "entities": "^2.0.0" + }, + "funding": { + "url": "https://github.com/cheeriojs/dom-serializer?sponsor=1" + } + }, + "node_modules/dom-serializer/node_modules/entities": { + "version": "2.2.0", + "resolved": "https://registry.npmmirror.com/entities/-/entities-2.2.0.tgz", + "integrity": "sha512-p92if5Nz619I0w+akJrLZH0MX0Pb5DX39XOwQTtXSdQQOaYH03S1uIQp4mhOZtAXrxq4ViO67YTiLBo2638o9A==", + "dev": true, + "license": "BSD-2-Clause", + "funding": { + "url": "https://github.com/fb55/entities?sponsor=1" + } + }, + "node_modules/domelementtype": { + "version": "2.3.0", + "resolved": "https://registry.npmmirror.com/domelementtype/-/domelementtype-2.3.0.tgz", + "integrity": "sha512-OLETBj6w0OsagBwdXnPdN0cnMfF9opN69co+7ZrbfPGrdpPVNBUj02spi6B1N7wChLQiPn4CSH/zJvXw56gmHw==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/fb55" + } + ], + "license": "BSD-2-Clause" + }, + "node_modules/domhandler": { + "version": "4.3.1", + "resolved": "https://registry.npmmirror.com/domhandler/-/domhandler-4.3.1.tgz", + "integrity": "sha512-GrwoxYN+uWlzO8uhUXRl0P+kHE4GtVPfYzVLcUxPL7KNdHKj66vvlhiweIHqYYXWlw+T8iLMp42Lm67ghw4WMQ==", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "domelementtype": "^2.2.0" + }, + "engines": { + "node": ">= 4" + }, + "funding": { + "url": "https://github.com/fb55/domhandler?sponsor=1" + } + }, + "node_modules/domutils": { + "version": "2.8.0", + "resolved": "https://registry.npmmirror.com/domutils/-/domutils-2.8.0.tgz", + "integrity": "sha512-w96Cjofp72M5IIhpjgobBimYEfoPjx1Vx0BSX9P30WBdZW2WIKU0T1Bd0kz2eNZ9ikjKgHbEyKx8BB6H1L3h3A==", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "dom-serializer": "^1.0.1", + "domelementtype": "^2.2.0", + "domhandler": "^4.2.0" + }, + "funding": { + "url": "https://github.com/fb55/domutils?sponsor=1" + } + }, + "node_modules/dot-case": { + "version": "3.0.4", + "resolved": "https://registry.npmmirror.com/dot-case/-/dot-case-3.0.4.tgz", + "integrity": "sha512-Kv5nKlh6yRrdrGvxeJ2e5y2eRUpkUosIW4A2AS38zwSz27zu7ufDwQPi5Jhs3XAlGNetl3bmnGhQsMtkKJnj3w==", + "dev": true, + "license": "MIT", + "dependencies": { + "no-case": "^3.0.4", + "tslib": "^2.0.3" + } + }, + "node_modules/dunder-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmmirror.com/dunder-proto/-/dunder-proto-1.0.1.tgz", + "integrity": "sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==", + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": "^1.0.1", + "es-errors": "^1.3.0", + "gopd": "^1.2.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/ee-first": { + "version": "1.1.1", + "resolved": "https://registry.npmmirror.com/ee-first/-/ee-first-1.1.1.tgz", + "integrity": "sha512-WMwm9LhRUo+WUaRN+vRuETqG89IgZphVSNkdFgeb6sS/E4OrDIN7t48CAewSHXc6C8lefD8KKfr5vY61brQlow==", + "dev": true, + "license": "MIT" + }, + "node_modules/electron-to-chromium": { + "version": "1.5.126", + "resolved": "https://registry.npmmirror.com/electron-to-chromium/-/electron-to-chromium-1.5.126.tgz", + "integrity": "sha512-AtH1uLcTC72LA4vfYcEJJkrMk/MY/X0ub8Hv7QGAePW2JkeUFHEL/QfS4J77R6M87Sss8O0OcqReSaN1bpyA+Q==", + "license": "ISC" + }, + "node_modules/encodeurl": { + "version": "2.0.0", + "resolved": "https://registry.npmmirror.com/encodeurl/-/encodeurl-2.0.0.tgz", + "integrity": "sha512-Q0n9HRi4m6JuGIV1eFlmvJB7ZEVxu93IrMyiMsGC0lrMJMWzRgx6WGquyfQgZVb31vhGgXnfmPNNXmxnOkRBrg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/enhanced-resolve": { + "version": "5.18.1", + "resolved": "https://registry.npmmirror.com/enhanced-resolve/-/enhanced-resolve-5.18.1.tgz", + "integrity": "sha512-ZSW3ma5GkcQBIpwZTSRAI8N71Uuwgs93IezB7mf7R60tC8ZbJideoDNKjHn2O9KIlx6rkGTTEk1xUCK2E1Y2Yg==", + "license": "MIT", + "dependencies": { + "graceful-fs": "^4.2.4", + "tapable": "^2.2.0" + }, + "engines": { + "node": ">=10.13.0" + } + }, + "node_modules/entities": { + "version": "4.5.0", + "resolved": "https://registry.npmmirror.com/entities/-/entities-4.5.0.tgz", + "integrity": "sha512-V0hjH4dGPh9Ao5p0MoRY6BVqtwCjhz6vI5LT8AJ55H+4g9/4vbHx1I54fS0XuclLhDHArPQCiMjDxjaL8fPxhw==", + "dev": true, + "license": "BSD-2-Clause", + "engines": { + "node": ">=0.12" + }, + "funding": { + "url": "https://github.com/fb55/entities?sponsor=1" + } + }, + "node_modules/envinfo": { + "version": "7.14.0", + "resolved": "https://registry.npmmirror.com/envinfo/-/envinfo-7.14.0.tgz", + "integrity": "sha512-CO40UI41xDQzhLB1hWyqUKgFhs250pNcGbyGKe1l/e4FSaI/+YE4IMG76GDt0In67WLPACIITC+sOi08x4wIvg==", + "dev": true, + "license": "MIT", + "bin": { + "envinfo": "dist/cli.js" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/es-define-property": { + "version": "1.0.1", + "resolved": "https://registry.npmmirror.com/es-define-property/-/es-define-property-1.0.1.tgz", + "integrity": "sha512-e3nRfgfUZ4rNGL232gUgX06QNyyez04KdjFrF+LTRoOXmrOgFKDg4BCdsjW8EnT69eqdYGmRpJwiPVYNrCaW3g==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-errors": { + "version": "1.3.0", + "resolved": "https://registry.npmmirror.com/es-errors/-/es-errors-1.3.0.tgz", + "integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-module-lexer": { + "version": "1.6.0", + "resolved": "https://registry.npmmirror.com/es-module-lexer/-/es-module-lexer-1.6.0.tgz", + "integrity": "sha512-qqnD1yMU6tk/jnaMosogGySTZP8YtUgAffA9nMN+E/rjxcfRQ6IEk7IiozUjgxKoFHBGjTLnrHB/YC45r/59EQ==", + "license": "MIT" + }, + "node_modules/es-object-atoms": { + "version": "1.1.1", + "resolved": "https://registry.npmmirror.com/es-object-atoms/-/es-object-atoms-1.1.1.tgz", + "integrity": "sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA==", + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-set-tostringtag": { + "version": "2.1.0", + "resolved": "https://registry.npmmirror.com/es-set-tostringtag/-/es-set-tostringtag-2.1.0.tgz", + "integrity": "sha512-j6vWzfrGVfyXxge+O0x5sh6cvxAog0a/4Rdd2K36zCMV5eJ+/+tOAngRO8cODMNWbVRdVlmGZQL2YS3yR8bIUA==", + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.6", + "has-tostringtag": "^1.0.2", + "hasown": "^2.0.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/escalade": { + "version": "3.2.0", + "resolved": "https://registry.npmmirror.com/escalade/-/escalade-3.2.0.tgz", + "integrity": "sha512-WUj2qlxaQtO4g6Pq5c29GTcWGDyd8itL8zTlipgECz3JesAiiOKotd8JU6otB3PACgG6xkJUyVhboMS+bje/jA==", + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/escape-html": { + "version": "1.0.3", + "resolved": "https://registry.npmmirror.com/escape-html/-/escape-html-1.0.3.tgz", + "integrity": "sha512-NiSupZ4OeuGwr68lGIeym/ksIZMJodUGOSCZ/FSnTxcrekbvqrgdUxlJOMpijaKZVjAJrWrGs/6Jy8OMuyj9ow==", + "dev": true, + "license": "MIT" + }, + "node_modules/eslint-scope": { + "version": "5.1.1", + "resolved": "https://registry.npmmirror.com/eslint-scope/-/eslint-scope-5.1.1.tgz", + "integrity": "sha512-2NxwbF/hZ0KpepYN0cNbo+FN6XoK7GaHlQhgx/hIZl6Va0bF45RQOOwhLIy8lQDbuCiadSLCBnH2CFYquit5bw==", + "license": "BSD-2-Clause", + "dependencies": { + "esrecurse": "^4.3.0", + "estraverse": "^4.1.1" + }, + "engines": { + "node": ">=8.0.0" + } + }, + "node_modules/esrecurse": { + "version": "4.3.0", + "resolved": "https://registry.npmmirror.com/esrecurse/-/esrecurse-4.3.0.tgz", + "integrity": "sha512-KmfKL3b6G+RXvP8N1vr3Tq1kL/oCFgn2NYXEtqP8/L3pKapUA4G8cFVaoF3SU323CD4XypR/ffioHmkti6/Tag==", + "license": "BSD-2-Clause", + "dependencies": { + "estraverse": "^5.2.0" + }, + "engines": { + "node": ">=4.0" + } + }, + "node_modules/esrecurse/node_modules/estraverse": { + "version": "5.3.0", + "resolved": "https://registry.npmmirror.com/estraverse/-/estraverse-5.3.0.tgz", + "integrity": "sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA==", + "license": "BSD-2-Clause", + "engines": { + "node": ">=4.0" + } + }, + "node_modules/estraverse": { + "version": "4.3.0", + "resolved": "https://registry.npmmirror.com/estraverse/-/estraverse-4.3.0.tgz", + "integrity": "sha512-39nnKffWz8xN1BU/2c79n9nB9HDzo0niYUqx6xyqUnyoAnQyyWpOTdZEeiCch8BBu515t4wp9ZmgVfVhn9EBpw==", + "license": "BSD-2-Clause", + "engines": { + "node": ">=4.0" + } + }, + "node_modules/etag": { + "version": "1.8.1", + "resolved": "https://registry.npmmirror.com/etag/-/etag-1.8.1.tgz", + "integrity": "sha512-aIL5Fx7mawVa300al2BnEE4iNvo1qETxLrPI/o05L7z6go7fCw1J6EQmbK4FmJ2AS7kgVF/KEZWufBfdClMcPg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/eventemitter3": { + "version": "4.0.7", + "resolved": "https://registry.npmmirror.com/eventemitter3/-/eventemitter3-4.0.7.tgz", + "integrity": "sha512-8guHBZCwKnFhYdHr2ysuRWErTwhoN2X8XELRlrRwpmfeY2jjuUN4taQMsULKUVo1K4DvZl+0pgfyoysHxvmvEw==", + "dev": true + }, + "node_modules/events": { + "version": "3.3.0", + "resolved": "https://registry.npmmirror.com/events/-/events-3.3.0.tgz", + "integrity": "sha512-mQw+2fkQbALzQ7V0MY0IqdnXNOeTtP4r0lN9z7AAawCXgqea7bDii20AYrIBrFd/Hx0M2Ocz6S111CaFkUcb0Q==", + "license": "MIT", + "engines": { + "node": ">=0.8.x" + } + }, + "node_modules/execa": { + "version": "5.1.1", + "resolved": "https://registry.npmmirror.com/execa/-/execa-5.1.1.tgz", + "integrity": "sha512-8uSpZZocAZRBAPIEINJj3Lo9HyGitllczc27Eh5YYojjMFMn8yHMDMaUHE2Jqfq05D/wucwI4JGURyXt1vchyg==", + "dev": true, + "dependencies": { + "cross-spawn": "^7.0.3", + "get-stream": "^6.0.0", + "human-signals": "^2.1.0", + "is-stream": "^2.0.0", + "merge-stream": "^2.0.0", + "npm-run-path": "^4.0.1", + "onetime": "^5.1.2", + "signal-exit": "^3.0.3", + "strip-final-newline": "^2.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sindresorhus/execa?sponsor=1" + } + }, + "node_modules/express": { + "version": "4.21.2", + "resolved": "https://registry.npmmirror.com/express/-/express-4.21.2.tgz", + "integrity": "sha512-28HqgMZAmih1Czt9ny7qr6ek2qddF4FclbMzwhCREB6OFfH+rXAnuNCwo1/wFvrtbgsQDb4kSbX9de9lFbrXnA==", + "dev": true, + "license": "MIT", + "dependencies": { + "accepts": "~1.3.8", + "array-flatten": "1.1.1", + "body-parser": "1.20.3", + "content-disposition": "0.5.4", + "content-type": "~1.0.4", + "cookie": "0.7.1", + "cookie-signature": "1.0.6", + "debug": "2.6.9", + "depd": "2.0.0", + "encodeurl": "~2.0.0", + "escape-html": "~1.0.3", + "etag": "~1.8.1", + "finalhandler": "1.3.1", + "fresh": "0.5.2", + "http-errors": "2.0.0", + "merge-descriptors": "1.0.3", + "methods": "~1.1.2", + "on-finished": "2.4.1", + "parseurl": "~1.3.3", + "path-to-regexp": "0.1.12", + "proxy-addr": "~2.0.7", + "qs": "6.13.0", + "range-parser": "~1.2.1", + "safe-buffer": "5.2.1", + "send": "0.19.0", + "serve-static": "1.16.2", + "setprototypeof": "1.2.0", + "statuses": "2.0.1", + "type-is": "~1.6.18", + "utils-merge": "1.0.1", + "vary": "~1.1.2" + }, + "engines": { + "node": ">= 0.10.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, + "node_modules/fast-deep-equal": { + "version": "3.1.3", + "resolved": "https://registry.npmmirror.com/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz", + "integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==", + "license": "MIT" + }, + "node_modules/fast-uri": { + "version": "3.0.6", + "resolved": "https://registry.npmmirror.com/fast-uri/-/fast-uri-3.0.6.tgz", + "integrity": "sha512-Atfo14OibSv5wAp4VWNsFYE1AchQRTv9cBGWET4pZWHzYshFSS9NQI6I57rdKn9croWVMbYFbLhJ+yJvmZIIHw==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/fastify" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/fastify" + } + ], + "license": "BSD-3-Clause" + }, + "node_modules/fastest-levenshtein": { + "version": "1.0.16", + "resolved": "https://registry.npmmirror.com/fastest-levenshtein/-/fastest-levenshtein-1.0.16.tgz", + "integrity": "sha512-eRnCtTTtGZFpQCwhJiUOuxPQWRXVKYDn0b2PeHfXL6/Zi53SLAzAHfVhVWK2AryC/WH05kGfxhFIPvTF0SXQzg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 4.9.1" + } + }, + "node_modules/faye-websocket": { + "version": "0.11.4", + "resolved": "https://registry.npmmirror.com/faye-websocket/-/faye-websocket-0.11.4.tgz", + "integrity": "sha512-CzbClwlXAuiRQAlUyfqPgvPoNKTckTPGfwZV4ZdAhVcP2lh9KUxJg2b5GkE7XbjKQ3YJnQ9z6D9ntLAlB+tP8g==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "websocket-driver": ">=0.5.1" + }, + "engines": { + "node": ">=0.8.0" + } + }, + "node_modules/fill-range": { + "version": "7.1.1", + "resolved": "https://registry.npmmirror.com/fill-range/-/fill-range-7.1.1.tgz", + "integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==", + "dev": true, + "license": "MIT", + "dependencies": { + "to-regex-range": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/finalhandler": { + "version": "1.3.1", + "resolved": "https://registry.npmmirror.com/finalhandler/-/finalhandler-1.3.1.tgz", + "integrity": "sha512-6BN9trH7bp3qvnrRyzsBz+g3lZxTNZTbVO2EV1CS0WIcDbawYVdYvGflME/9QP0h0pYlCDBCTjYa9nZzMDpyxQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "debug": "2.6.9", + "encodeurl": "~2.0.0", + "escape-html": "~1.0.3", + "on-finished": "2.4.1", + "parseurl": "~1.3.3", + "statuses": "2.0.1", + "unpipe": "~1.0.0" + }, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/find-up": { + "version": "4.1.0", + "resolved": "https://registry.npmmirror.com/find-up/-/find-up-4.1.0.tgz", + "integrity": "sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw==", + "dev": true, + "license": "MIT", + "dependencies": { + "locate-path": "^5.0.0", + "path-exists": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/flat": { + "version": "5.0.2", + "resolved": "https://registry.npmmirror.com/flat/-/flat-5.0.2.tgz", + "integrity": "sha512-b6suED+5/3rTpUBdG1gupIl8MPFCAMA0QXwmljLhvCUKcUvdE4gWky9zpuGCcXHOsz4J9wPGNWq6OKpmIzz3hQ==", + "dev": true, + "license": "BSD-3-Clause", + "bin": { + "flat": "cli.js" + } + }, + "node_modules/follow-redirects": { + "version": "1.15.9", + "resolved": "https://registry.npmmirror.com/follow-redirects/-/follow-redirects-1.15.9.tgz", + "integrity": "sha512-gew4GsXizNgdoRyqmyfMHyAmXsZDk6mHkSxZFCzW9gwlbtOW44CDtYavM+y+72qD/Vq2l550kMF52DT8fOLJqQ==", + "funding": [ + { + "type": "individual", + "url": "https://github.com/sponsors/RubenVerborgh" + } + ], + "license": "MIT", + "engines": { + "node": ">=4.0" + }, + "peerDependenciesMeta": { + "debug": { + "optional": true + } + } + }, + "node_modules/form-data": { + "version": "4.0.4", + "resolved": "https://registry.npmmirror.com/form-data/-/form-data-4.0.4.tgz", + "integrity": "sha512-KrGhL9Q4zjj0kiUt5OO4Mr/A/jlI2jDYs5eHBpYHPcBEVSiipAvn2Ko2HnPe20rmcuuvMHNdZFp+4IlGTMF0Ow==", + "license": "MIT", + "dependencies": { + "asynckit": "^0.4.0", + "combined-stream": "^1.0.8", + "es-set-tostringtag": "^2.1.0", + "hasown": "^2.0.2", + "mime-types": "^2.1.12" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/forwarded": { + "version": "0.2.0", + "resolved": "https://registry.npmmirror.com/forwarded/-/forwarded-0.2.0.tgz", + "integrity": "sha512-buRG0fpBtRHSTCOASe6hD258tEubFoRLb4ZNA6NxMVHNw2gOcwHo9wyablzMzOA5z9xA9L1KNjk/Nt6MT9aYow==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/fresh": { + "version": "0.5.2", + "resolved": "https://registry.npmmirror.com/fresh/-/fresh-0.5.2.tgz", + "integrity": "sha512-zJ2mQYM18rEFOudeV4GShTGIQ7RbzA7ozbU9I/XBpm7kqgMywgmylMwXHxZJmkVoYkna9d2pVXVXPdYTP9ej8Q==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/fs-monkey": { + "version": "1.0.6", + "resolved": "https://registry.npmmirror.com/fs-monkey/-/fs-monkey-1.0.6.tgz", + "integrity": "sha512-b1FMfwetIKymC0eioW7mTywihSQE4oLzQn1dB6rZB5fx/3NpNEdAWeCSMB+60/AeT0TCXsxzAlcYVEFCTAksWg==", + "dev": true + }, + "node_modules/fs.realpath": { + "version": "1.0.0", + "resolved": "https://registry.npmmirror.com/fs.realpath/-/fs.realpath-1.0.0.tgz", + "integrity": "sha512-OO0pH2lK6a0hZnAdau5ItzHPI6pUlvI7jMVnxUQRtw4owF2wk8lOSabtGDCTP4Ggrg2MbGnWO9X8K1t4+fGMDw==", + "license": "ISC" + }, + "node_modules/function-bind": { + "version": "1.1.2", + "resolved": "https://registry.npmmirror.com/function-bind/-/function-bind-1.1.2.tgz", + "integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==", + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/get-intrinsic": { + "version": "1.3.0", + "resolved": "https://registry.npmmirror.com/get-intrinsic/-/get-intrinsic-1.3.0.tgz", + "integrity": "sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ==", + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": "^1.0.2", + "es-define-property": "^1.0.1", + "es-errors": "^1.3.0", + "es-object-atoms": "^1.1.1", + "function-bind": "^1.1.2", + "get-proto": "^1.0.1", + "gopd": "^1.2.0", + "has-symbols": "^1.1.0", + "hasown": "^2.0.2", + "math-intrinsics": "^1.1.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/get-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmmirror.com/get-proto/-/get-proto-1.0.1.tgz", + "integrity": "sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==", + "license": "MIT", + "dependencies": { + "dunder-proto": "^1.0.1", + "es-object-atoms": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/get-stream": { + "version": "6.0.1", + "resolved": "https://registry.npmmirror.com/get-stream/-/get-stream-6.0.1.tgz", + "integrity": "sha512-ts6Wi+2j3jQjqi70w5AlN8DFnkSwC+MqmxEzdEALB2qXZYV3X/b1CTfgPLGJNMeAWxdPfU8FO1ms3NUfaHCPYg==", + "dev": true, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/glob": { + "version": "7.2.3", + "resolved": "https://registry.npmmirror.com/glob/-/glob-7.2.3.tgz", + "integrity": "sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q==", + "deprecated": "Glob versions prior to v9 are no longer supported", + "license": "ISC", + "dependencies": { + "fs.realpath": "^1.0.0", + "inflight": "^1.0.4", + "inherits": "2", + "minimatch": "^3.1.1", + "once": "^1.3.0", + "path-is-absolute": "^1.0.0" + }, + "engines": { + "node": "*" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/glob-parent": { + "version": "5.1.2", + "resolved": "https://registry.npmmirror.com/glob-parent/-/glob-parent-5.1.2.tgz", + "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", + "dev": true, + "license": "ISC", + "dependencies": { + "is-glob": "^4.0.1" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/glob-to-regexp": { + "version": "0.4.1", + "resolved": "https://registry.npmmirror.com/glob-to-regexp/-/glob-to-regexp-0.4.1.tgz", + "integrity": "sha512-lkX1HJXwyMcprw/5YUZc2s7DrpAiHB21/V+E1rHUrVNokkvB6bqMzT0VfV6/86ZNabt1k14YOIaT7nDvOX3Iiw==", + "license": "BSD-2-Clause" + }, + "node_modules/globby": { + "version": "6.1.0", + "resolved": "https://registry.npmmirror.com/globby/-/globby-6.1.0.tgz", + "integrity": "sha512-KVbFv2TQtbzCoxAnfD6JcHZTYCzyliEaaeM/gH8qQdkKr5s0OP9scEgvdcngyk7AVdY6YVW/TJHd+lQ/Df3Daw==", + "license": "MIT", + "dependencies": { + "array-union": "^1.0.1", + "glob": "^7.0.3", + "object-assign": "^4.0.1", + "pify": "^2.0.0", + "pinkie-promise": "^2.0.0" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/globby/node_modules/pify": { + "version": "2.3.0", + "resolved": "https://registry.npmmirror.com/pify/-/pify-2.3.0.tgz", + "integrity": "sha512-udgsAY+fTnvv7kI7aaxbqwWNb0AHiB0qBO89PZKPkoTmGOgdbrHDKD+0B2X4uTfJ/FT1R09r9gTsjUjNJotuog==", + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/gopd": { + "version": "1.2.0", + "resolved": "https://registry.npmmirror.com/gopd/-/gopd-1.2.0.tgz", + "integrity": "sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/graceful-fs": { + "version": "4.2.11", + "resolved": "https://registry.npmmirror.com/graceful-fs/-/graceful-fs-4.2.11.tgz", + "integrity": "sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==", + "license": "ISC" + }, + "node_modules/graphlib": { + "version": "2.1.8", + "resolved": "https://registry.npmmirror.com/graphlib/-/graphlib-2.1.8.tgz", + "integrity": "sha512-jcLLfkpoVGmH7/InMC/1hIvOPSUh38oJtGhvrOFGzioE1DZ+0YW16RgmOJhHiuWTvGiJQ9Z1Ik43JvkRPRvE+A==", + "license": "MIT", + "dependencies": { + "lodash": "^4.17.15" + } + }, + "node_modules/handle-thing": { + "version": "2.0.1", + "resolved": "https://registry.npmmirror.com/handle-thing/-/handle-thing-2.0.1.tgz", + "integrity": "sha512-9Qn4yBxelxoh2Ow62nP+Ka/kMnOXRi8BXnRaUwezLNhqelnN49xKz4F/dPP8OYLxLxq6JDtZb2i9XznUQbNPTg==", + "dev": true, + "license": "MIT" + }, + "node_modules/has-flag": { + "version": "4.0.0", + "resolved": "https://registry.npmmirror.com/has-flag/-/has-flag-4.0.0.tgz", + "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/has-symbols": { + "version": "1.1.0", + "resolved": "https://registry.npmmirror.com/has-symbols/-/has-symbols-1.1.0.tgz", + "integrity": "sha512-1cDNdwJ2Jaohmb3sg4OmKaMBwuC48sYni5HUw2DvsC8LjGTLK9h+eb1X6RyuOHe4hT0ULCW68iomhjUoKUqlPQ==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/has-tostringtag": { + "version": "1.0.2", + "resolved": "https://registry.npmmirror.com/has-tostringtag/-/has-tostringtag-1.0.2.tgz", + "integrity": "sha512-NqADB8VjPFLM2V0VvHUewwwsw0ZWBaIdgo+ieHtK3hasLz4qeCRjYcqfB6AQrBggRKppKF8L52/VqdVsO47Dlw==", + "license": "MIT", + "dependencies": { + "has-symbols": "^1.0.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/hasown": { + "version": "2.0.2", + "resolved": "https://registry.npmmirror.com/hasown/-/hasown-2.0.2.tgz", + "integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==", + "license": "MIT", + "dependencies": { + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/he": { + "version": "1.2.0", + "resolved": "https://registry.npmmirror.com/he/-/he-1.2.0.tgz", + "integrity": "sha512-F/1DnUGPopORZi0ni+CvrCgHQ5FyEAHRLSApuYWMmrbSwoN2Mn/7k+Gl38gJnR7yyDZk6WLXwiGod1JOWNDKGw==", + "dev": true, + "license": "MIT", + "bin": { + "he": "bin/he" + } + }, + "node_modules/hpack.js": { + "version": "2.1.6", + "resolved": "https://registry.npmmirror.com/hpack.js/-/hpack.js-2.1.6.tgz", + "integrity": "sha512-zJxVehUdMGIKsRaNt7apO2Gqp0BdqW5yaiGHXXmbpvxgBYVZnAql+BJb4RO5ad2MgpbZKn5G6nMnegrH1FcNYQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "inherits": "^2.0.1", + "obuf": "^1.0.0", + "readable-stream": "^2.0.1", + "wbuf": "^1.1.0" + } + }, + "node_modules/hpack.js/node_modules/readable-stream": { + "version": "2.3.8", + "resolved": "https://registry.npmmirror.com/readable-stream/-/readable-stream-2.3.8.tgz", + "integrity": "sha512-8p0AUk4XODgIewSi0l8Epjs+EVnWiK7NoDIEGU0HhE7+ZyY8D1IMY7odu5lRrFXGg71L15KG8QrPmum45RTtdA==", + "dev": true, + "license": "MIT", + "dependencies": { + "core-util-is": "~1.0.0", + "inherits": "~2.0.3", + "isarray": "~1.0.0", + "process-nextick-args": "~2.0.0", + "safe-buffer": "~5.1.1", + "string_decoder": "~1.1.1", + "util-deprecate": "~1.0.1" + } + }, + "node_modules/hpack.js/node_modules/safe-buffer": { + "version": "5.1.2", + "resolved": "https://registry.npmmirror.com/safe-buffer/-/safe-buffer-5.1.2.tgz", + "integrity": "sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g==", + "dev": true, + "license": "MIT" + }, + "node_modules/hpack.js/node_modules/string_decoder": { + "version": "1.1.1", + "resolved": "https://registry.npmmirror.com/string_decoder/-/string_decoder-1.1.1.tgz", + "integrity": "sha512-n/ShnvDi6FHbbVfviro+WojiFzv+s8MPMHBczVePfUpDJLwoLT0ht1l4YwBCbi8pJAveEEdnkHyPyTP/mzRfwg==", + "dev": true, + "license": "MIT", + "dependencies": { + "safe-buffer": "~5.1.0" + } + }, + "node_modules/html-entities": { + "version": "2.6.0", + "resolved": "https://registry.npmmirror.com/html-entities/-/html-entities-2.6.0.tgz", + "integrity": "sha512-kig+rMn/QOVRvr7c86gQ8lWXq+Hkv6CbAH1hLu+RG338StTpE8Z0b44SDVaqVu7HGKf27frdmUYEs9hTUX/cLQ==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/mdevils" + }, + { + "type": "patreon", + "url": "https://patreon.com/mdevils" + } + ] + }, + "node_modules/html-loader": { + "version": "5.1.0", + "resolved": "https://registry.npmmirror.com/html-loader/-/html-loader-5.1.0.tgz", + "integrity": "sha512-Jb3xwDbsm0W3qlXrCZwcYqYGnYz55hb6aoKQTlzyZPXsPpi6tHXzAfqalecglMQgNvtEfxrCQPaKT90Irt5XDA==", + "dev": true, + "license": "MIT", + "dependencies": { + "html-minifier-terser": "^7.2.0", + "parse5": "^7.1.2" + }, + "engines": { + "node": ">= 18.12.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" + }, + "peerDependencies": { + "webpack": "^5.0.0" + } + }, + "node_modules/html-minifier-terser": { + "version": "7.2.0", + "resolved": "https://registry.npmmirror.com/html-minifier-terser/-/html-minifier-terser-7.2.0.tgz", + "integrity": "sha512-tXgn3QfqPIpGl9o+K5tpcj3/MN4SfLtsx2GWwBC3SSd0tXQGyF3gsSqad8loJgKZGM3ZxbYDd5yhiBIdWpmvLA==", + "dev": true, + "license": "MIT", + "dependencies": { + "camel-case": "^4.1.2", + "clean-css": "~5.3.2", + "commander": "^10.0.0", + "entities": "^4.4.0", + "param-case": "^3.0.4", + "relateurl": "^0.2.7", + "terser": "^5.15.1" + }, + "bin": { + "html-minifier-terser": "cli.js" + }, + "engines": { + "node": "^14.13.1 || >=16.0.0" + } + }, + "node_modules/html-minifier-terser/node_modules/commander": { + "version": "10.0.1", + "resolved": "https://registry.npmmirror.com/commander/-/commander-10.0.1.tgz", + "integrity": "sha512-y4Mg2tXshplEbSGzx7amzPwKKOCGuoSRP/CjEdwwk0FOGlUbq6lKuoyDZTNZkmxHdJtp54hdfY/JUrdL7Xfdug==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=14" + } + }, + "node_modules/html-webpack-plugin": { + "version": "5.6.3", + "resolved": "https://registry.npmmirror.com/html-webpack-plugin/-/html-webpack-plugin-5.6.3.tgz", + "integrity": "sha512-QSf1yjtSAsmf7rYBV7XX86uua4W/vkhIt0xNXKbsi2foEeW7vjJQz4bhnpL3xH+l1ryl1680uNv968Z+X6jSYg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/html-minifier-terser": "^6.0.0", + "html-minifier-terser": "^6.0.2", + "lodash": "^4.17.21", + "pretty-error": "^4.0.0", + "tapable": "^2.0.0" + }, + "engines": { + "node": ">=10.13.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/html-webpack-plugin" + }, + "peerDependencies": { + "@rspack/core": "0.x || 1.x", + "webpack": "^5.20.0" + }, + "peerDependenciesMeta": { + "@rspack/core": { + "optional": true + }, + "webpack": { + "optional": true + } + } + }, + "node_modules/html-webpack-plugin/node_modules/commander": { + "version": "8.3.0", + "resolved": "https://registry.npmmirror.com/commander/-/commander-8.3.0.tgz", + "integrity": "sha512-OkTL9umf+He2DZkUq8f8J9of7yL6RJKI24dVITBmNfZBmri9zYZQrKkuXiKhyfPSu8tUhnVBB1iKXevvnlR4Ww==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 12" + } + }, + "node_modules/html-webpack-plugin/node_modules/html-minifier-terser": { + "version": "6.1.0", + "resolved": "https://registry.npmmirror.com/html-minifier-terser/-/html-minifier-terser-6.1.0.tgz", + "integrity": "sha512-YXxSlJBZTP7RS3tWnQw74ooKa6L9b9i9QYXY21eUEvhZ3u9XLfv6OnFsQq6RxkhHygsaUMvYsZRV5rU/OVNZxw==", + "dev": true, + "license": "MIT", + "dependencies": { + "camel-case": "^4.1.2", + "clean-css": "^5.2.2", + "commander": "^8.3.0", + "he": "^1.2.0", + "param-case": "^3.0.4", + "relateurl": "^0.2.7", + "terser": "^5.10.0" + }, + "bin": { + "html-minifier-terser": "cli.js" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/htmlparser2": { + "version": "6.1.0", + "resolved": "https://registry.npmmirror.com/htmlparser2/-/htmlparser2-6.1.0.tgz", + "integrity": "sha512-gyyPk6rgonLFEDGoeRgQNaEUvdJ4ktTmmUh/h2t7s+M8oPpIPxgNACWa+6ESR57kXstwqPiCut0V8NRpcwgU7A==", + "dev": true, + "funding": [ + "https://github.com/fb55/htmlparser2?sponsor=1", + { + "type": "github", + "url": "https://github.com/sponsors/fb55" + } + ], + "license": "MIT", + "dependencies": { + "domelementtype": "^2.0.1", + "domhandler": "^4.0.0", + "domutils": "^2.5.2", + "entities": "^2.0.0" + } + }, + "node_modules/htmlparser2/node_modules/entities": { + "version": "2.2.0", + "resolved": "https://registry.npmmirror.com/entities/-/entities-2.2.0.tgz", + "integrity": "sha512-p92if5Nz619I0w+akJrLZH0MX0Pb5DX39XOwQTtXSdQQOaYH03S1uIQp4mhOZtAXrxq4ViO67YTiLBo2638o9A==", + "dev": true, + "license": "BSD-2-Clause", + "funding": { + "url": "https://github.com/fb55/entities?sponsor=1" + } + }, + "node_modules/http-deceiver": { + "version": "1.2.7", + "resolved": "https://registry.npmmirror.com/http-deceiver/-/http-deceiver-1.2.7.tgz", + "integrity": "sha512-LmpOGxTfbpgtGVxJrj5k7asXHCgNZp5nLfp+hWc8QQRqtb7fUy6kRY3BO1h9ddF6yIPYUARgxGOwB42DnxIaNw==", + "dev": true, + "license": "MIT" + }, + "node_modules/http-errors": { + "version": "2.0.0", + "resolved": "https://registry.npmmirror.com/http-errors/-/http-errors-2.0.0.tgz", + "integrity": "sha512-FtwrG/euBzaEjYeRqOgly7G0qviiXoJWnvEH2Z1plBdXgbyjv34pHTSb9zoeHMyDy33+DWy5Wt9Wo+TURtOYSQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "depd": "2.0.0", + "inherits": "2.0.4", + "setprototypeof": "1.2.0", + "statuses": "2.0.1", + "toidentifier": "1.0.1" + }, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/http-parser-js": { + "version": "0.5.9", + "resolved": "https://registry.npmmirror.com/http-parser-js/-/http-parser-js-0.5.9.tgz", + "integrity": "sha512-n1XsPy3rXVxlqxVioEWdC+0+M+SQw0DpJynwtOPo1X+ZlvdzTLtDBIJJlDQTnwZIFJrZSzSGmIOUdP8tu+SgLw==", + "dev": true, + "license": "MIT" + }, + "node_modules/http-proxy": { + "version": "1.18.1", + "resolved": "https://registry.npmmirror.com/http-proxy/-/http-proxy-1.18.1.tgz", + "integrity": "sha512-7mz/721AbnJwIVbnaSv1Cz3Am0ZLT/UBwkC92VlxhXv/k/BBQfM2fXElQNC27BVGr0uwUpplYPQM9LnaBMR5NQ==", + "dev": true, + "dependencies": { + "eventemitter3": "^4.0.0", + "follow-redirects": "^1.0.0", + "requires-port": "^1.0.0" + }, + "engines": { + "node": ">=8.0.0" + } + }, + "node_modules/http-proxy-middleware": { + "version": "2.0.9", + "resolved": "https://registry.npmmirror.com/http-proxy-middleware/-/http-proxy-middleware-2.0.9.tgz", + "integrity": "sha512-c1IyJYLYppU574+YI7R4QyX2ystMtVXZwIdzazUIPIJsHuWNd+mho2j+bKoHftndicGj9yh+xjd+l0yj7VeT1Q==", + "dev": true, + "dependencies": { + "@types/http-proxy": "^1.17.8", + "http-proxy": "^1.18.1", + "is-glob": "^4.0.1", + "is-plain-obj": "^3.0.0", + "micromatch": "^4.0.2" + }, + "engines": { + "node": ">=12.0.0" + }, + "peerDependencies": { + "@types/express": "^4.17.13" + }, + "peerDependenciesMeta": { + "@types/express": { + "optional": true + } + } + }, + "node_modules/human-signals": { + "version": "2.1.0", + "resolved": "https://registry.npmmirror.com/human-signals/-/human-signals-2.1.0.tgz", + "integrity": "sha512-B4FFZ6q/T2jhhksgkbEW3HBvWIfDW85snkQgawt07S7J5QXTk6BkNV+0yAeZrM5QpMAdYlocGoljn0sJ/WQkFw==", + "dev": true, + "engines": { + "node": ">=10.17.0" + } + }, + "node_modules/i18next": { + "version": "23.16.8", + "resolved": "https://registry.npmmirror.com/i18next/-/i18next-23.16.8.tgz", + "integrity": "sha512-06r/TitrM88Mg5FdUXAKL96dJMzgqLE5dv3ryBAra4KCwD9mJ4ndOTS95ZuymIGoE+2hzfdaMak2X11/es7ZWg==", + "funding": [ + { + "type": "individual", + "url": "https://locize.com" + }, + { + "type": "individual", + "url": "https://locize.com/i18next.html" + }, + { + "type": "individual", + "url": "https://www.i18next.com/how-to/faq#i18next-is-awesome.-how-can-i-support-the-project" + } + ], + "license": "MIT", + "dependencies": { + "@babel/runtime": "^7.23.2" + } + }, + "node_modules/i18next-browser-languagedetector": { + "version": "7.2.2", + "resolved": "https://registry.npmmirror.com/i18next-browser-languagedetector/-/i18next-browser-languagedetector-7.2.2.tgz", + "integrity": "sha512-6b7r75uIJDWCcCflmbof+sJ94k9UQO4X0YR62oUfqGI/GjCLVzlCwu8TFdRZIqVLzWbzNcmkmhfqKEr4TLz4HQ==", + "license": "MIT", + "dependencies": { + "@babel/runtime": "^7.23.2" + } + }, + "node_modules/iconv-lite": { + "version": "0.4.24", + "resolved": "https://registry.npmmirror.com/iconv-lite/-/iconv-lite-0.4.24.tgz", + "integrity": "sha512-v3MXnZAcvnywkTUEZomIActle7RXXeedOR31wwl7VlyoXO4Qi9arvSenNQWne1TcRwhCL1HwLI21bEqdpj8/rA==", + "dev": true, + "license": "MIT", + "dependencies": { + "safer-buffer": ">= 2.1.2 < 3" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/icss-utils": { + "version": "5.1.0", + "resolved": "https://registry.npmmirror.com/icss-utils/-/icss-utils-5.1.0.tgz", + "integrity": "sha512-soFhflCVWLfRNOPU3iv5Z9VUdT44xFRbzjLsEzSr5AQmgqPMTHdU3PMT1Cf1ssx8fLNJDA1juftYl+PUcv3MqA==", + "license": "ISC", + "engines": { + "node": "^10 || ^12 || >= 14" + }, + "peerDependencies": { + "postcss": "^8.1.0" + } + }, + "node_modules/import-local": { + "version": "3.2.0", + "resolved": "https://registry.npmmirror.com/import-local/-/import-local-3.2.0.tgz", + "integrity": "sha512-2SPlun1JUPWoM6t3F0dw0FkCF/jWY8kttcY4f599GLTSjh2OCuuhdTkJQsEcZzBqbXZGKMK2OqW1oZsjtf/gQA==", + "dev": true, + "license": "MIT", + "dependencies": { + "pkg-dir": "^4.2.0", + "resolve-cwd": "^3.0.0" + }, + "bin": { + "import-local-fixture": "fixtures/cli.js" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/inflight": { + "version": "1.0.6", + "resolved": "https://registry.npmmirror.com/inflight/-/inflight-1.0.6.tgz", + "integrity": "sha512-k92I/b08q4wvFscXCLvqfsHCrjrF7yiXsQuIVvVE7N82W3+aqpzuUdBbfhWcy/FZR3/4IgflMgKLOsvPDrGCJA==", + "deprecated": "This module is not supported, and leaks memory. Do not use it. Check out lru-cache if you want a good and tested way to coalesce async requests by a key value, which is much more comprehensive and powerful.", + "license": "ISC", + "dependencies": { + "once": "^1.3.0", + "wrappy": "1" + } + }, + "node_modules/inherits": { + "version": "2.0.4", + "resolved": "https://registry.npmmirror.com/inherits/-/inherits-2.0.4.tgz", + "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==", + "license": "ISC" + }, + "node_modules/inline-chunk-html-plugin": { + "version": "1.1.1", + "resolved": "https://registry.npmmirror.com/inline-chunk-html-plugin/-/inline-chunk-html-plugin-1.1.1.tgz", + "integrity": "sha512-6W1eGIj8z/Yla6xJx5il6jJfCxMZS3kVkbiLQThbbjdsDLRIWkUVmpnhfW2l6WAwCW+qfy0zoXVGBZM1E5XF3g==", + "deprecated": "Package no longer supported. Contact Support at https://www.npmjs.com/support for more info.", + "dev": true + }, + "node_modules/internmap": { + "version": "2.0.3", + "resolved": "https://registry.npmmirror.com/internmap/-/internmap-2.0.3.tgz", + "integrity": "sha512-5Hh7Y1wQbvY5ooGgPbDaL5iYLAPzMTUrjMulskHLH6wnv/A+1q5rgEaiuqEjB+oxGXIVZs1FF+R/KPN3ZSQYYg==", + "engines": { + "node": ">=12" + } + }, + "node_modules/interpret": { + "version": "3.1.1", + "resolved": "https://registry.npmmirror.com/interpret/-/interpret-3.1.1.tgz", + "integrity": "sha512-6xwYfHbajpoF0xLW+iwLkhwgvLoZDfjYfoFNu8ftMoXINzwuymNLd9u/KmwtdT2GbR+/Cz66otEGEVVUHX9QLQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10.13.0" + } + }, + "node_modules/ipaddr.js": { + "version": "2.2.0", + "resolved": "https://registry.npmmirror.com/ipaddr.js/-/ipaddr.js-2.2.0.tgz", + "integrity": "sha512-Ag3wB2o37wslZS19hZqorUnrnzSkpOVy+IiiDEiTqNubEYpYuHWIf6K4psgN2ZWKExS4xhVCrRVfb/wfW8fWJA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 10" + } + }, + "node_modules/is-binary-path": { + "version": "2.1.0", + "resolved": "https://registry.npmmirror.com/is-binary-path/-/is-binary-path-2.1.0.tgz", + "integrity": "sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw==", + "dev": true, + "license": "MIT", + "dependencies": { + "binary-extensions": "^2.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/is-core-module": { + "version": "2.16.1", + "resolved": "https://registry.npmmirror.com/is-core-module/-/is-core-module-2.16.1.tgz", + "integrity": "sha512-UfoeMA6fIJ8wTYFEUjelnaGI67v6+N7qXJEvQuIGa99l4xsCruSYOVSQ0uPANn4dAzm8lkYPaKLrrijLq7x23w==", + "dev": true, + "license": "MIT", + "dependencies": { + "hasown": "^2.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-docker": { + "version": "2.2.1", + "resolved": "https://registry.npmmirror.com/is-docker/-/is-docker-2.2.1.tgz", + "integrity": "sha512-F+i2BKsFrH66iaUFc0woD8sLy8getkwTwtOBjvs56Cx4CgJDeKQeqfz8wAYiSb8JOprWhHH5p77PbmYCvvUuXQ==", + "dev": true, + "bin": { + "is-docker": "cli.js" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/is-extglob": { + "version": "2.1.1", + "resolved": "https://registry.npmmirror.com/is-extglob/-/is-extglob-2.1.1.tgz", + "integrity": "sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-glob": { + "version": "4.0.3", + "resolved": "https://registry.npmmirror.com/is-glob/-/is-glob-4.0.3.tgz", + "integrity": "sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-extglob": "^2.1.1" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-number": { + "version": "7.0.0", + "resolved": "https://registry.npmmirror.com/is-number/-/is-number-7.0.0.tgz", + "integrity": "sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.12.0" + } + }, + "node_modules/is-path-cwd": { + "version": "2.2.0", + "resolved": "https://registry.npmmirror.com/is-path-cwd/-/is-path-cwd-2.2.0.tgz", + "integrity": "sha512-w942bTcih8fdJPJmQHFzkS76NEP8Kzzvmw92cXsazb8intwLqPibPPdXf4ANdKV3rYMuuQYGIWtvz9JilB3NFQ==", + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/is-path-in-cwd": { + "version": "2.1.0", + "resolved": "https://registry.npmmirror.com/is-path-in-cwd/-/is-path-in-cwd-2.1.0.tgz", + "integrity": "sha512-rNocXHgipO+rvnP6dk3zI20RpOtrAM/kzbB258Uw5BWr3TpXi861yzjo16Dn4hUox07iw5AyeMLHWsujkjzvRQ==", + "license": "MIT", + "dependencies": { + "is-path-inside": "^2.1.0" + }, + "engines": { + "node": ">=6" + } + }, + "node_modules/is-path-inside": { + "version": "2.1.0", + "resolved": "https://registry.npmmirror.com/is-path-inside/-/is-path-inside-2.1.0.tgz", + "integrity": "sha512-wiyhTzfDWsvwAW53OBWF5zuvaOGlZ6PwYxAbPVDhpm+gM09xKQGjBq/8uYN12aDvMxnAnq3dxTyoSoRNmg5YFg==", + "license": "MIT", + "dependencies": { + "path-is-inside": "^1.0.2" + }, + "engines": { + "node": ">=6" + } + }, + "node_modules/is-plain-obj": { + "version": "3.0.0", + "resolved": "https://registry.npmmirror.com/is-plain-obj/-/is-plain-obj-3.0.0.tgz", + "integrity": "sha512-gwsOE28k+23GP1B6vFl1oVh/WOzmawBrKwo5Ev6wMKzPkaXaCDIQKzLnvsA42DRlbVTWorkgTKIviAKCWkfUwA==", + "dev": true, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/is-plain-object": { + "version": "2.0.4", + "resolved": "https://registry.npmmirror.com/is-plain-object/-/is-plain-object-2.0.4.tgz", + "integrity": "sha512-h5PpgXkWitc38BBMYawTYMWJHFZJVnBquFE57xFpjB8pJFiF6gZ+bU+WyI/yqXiFR5mdLsgYNaPe8uao6Uv9Og==", + "dev": true, + "license": "MIT", + "dependencies": { + "isobject": "^3.0.1" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-stream": { + "version": "2.0.1", + "resolved": "https://registry.npmmirror.com/is-stream/-/is-stream-2.0.1.tgz", + "integrity": "sha512-hFoiJiTl63nn+kstHGBtewWSKnQLpyb155KHheA1l39uvtO9nWIop1p3udqPcUd/xbF1VLMO4n7OI6p7RbngDg==", + "dev": true, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/is-wsl": { + "version": "2.2.0", + "resolved": "https://registry.npmmirror.com/is-wsl/-/is-wsl-2.2.0.tgz", + "integrity": "sha512-fKzAra0rGJUUBwGBgNkHZuToZcn+TtXHpeCgmkMJMMYx1sQDYaCSyjJBSCa2nH1DGm7s3n1oBnohoVTBaN7Lww==", + "dev": true, + "dependencies": { + "is-docker": "^2.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/isarray": { + "version": "1.0.0", + "resolved": "https://registry.npmmirror.com/isarray/-/isarray-1.0.0.tgz", + "integrity": "sha512-VLghIWNM6ELQzo7zwmcg0NmTVyWKYjvIeM83yjp0wRDTmUnrM678fQbcKBo6n2CJEF0szoG//ytg+TKla89ALQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/isexe": { + "version": "2.0.0", + "resolved": "https://registry.npmmirror.com/isexe/-/isexe-2.0.0.tgz", + "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==", + "license": "ISC" + }, + "node_modules/isobject": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/isobject/-/isobject-3.0.1.tgz", + "integrity": "sha512-WhB9zCku7EGTj/HQQRz5aUQEUeoQZH2bWcltRErOpymJ4boYE6wL9Tbr23krRPSZ+C5zqNSrSw+Cc7sZZ4b7vg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/jest-worker": { + "version": "27.5.1", + "resolved": "https://registry.npmmirror.com/jest-worker/-/jest-worker-27.5.1.tgz", + "integrity": "sha512-7vuh85V5cdDofPyxn58nrPjBktZo0u9x1g8WtjQol+jZDaE+fhN+cIvTj11GndBnMnyfrUOG1sZQxCdjKh+DKg==", + "license": "MIT", + "dependencies": { + "@types/node": "*", + "merge-stream": "^2.0.0", + "supports-color": "^8.0.0" + }, + "engines": { + "node": ">= 10.13.0" + } + }, + "node_modules/jest-worker/node_modules/supports-color": { + "version": "8.1.1", + "resolved": "https://registry.npmmirror.com/supports-color/-/supports-color-8.1.1.tgz", + "integrity": "sha512-MpUEN2OodtUzxvKQl72cUF7RQ5EiHsGvSsVG0ia9c5RbWGL2CI4C7EpPS8UTBIplnlzZiNuV56w+FuNxy3ty2Q==", + "license": "MIT", + "dependencies": { + "has-flag": "^4.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/supports-color?sponsor=1" + } + }, + "node_modules/json-parse-even-better-errors": { + "version": "2.3.1", + "resolved": "https://registry.npmmirror.com/json-parse-even-better-errors/-/json-parse-even-better-errors-2.3.1.tgz", + "integrity": "sha512-xyFwyhro/JEof6Ghe2iz2NcXoj2sloNsWr/XsERDK/oiPCfaNhl5ONfp+jQdAZRQQ0IJWNzH9zIZF7li91kh2w==", + "license": "MIT" + }, + "node_modules/json-schema-traverse": { + "version": "1.0.0", + "resolved": "https://registry.npmmirror.com/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", + "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", + "license": "MIT" + }, + "node_modules/kind-of": { + "version": "6.0.3", + "resolved": "https://registry.npmmirror.com/kind-of/-/kind-of-6.0.3.tgz", + "integrity": "sha512-dcS1ul+9tmeD95T+x28/ehLgd9mENa3LsvDTtzm3vyBEO7RPptvAD+t44WVXaUjTBRcrpFeFlC8WCruUR456hw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/launch-editor": { + "version": "2.10.0", + "resolved": "https://registry.npmmirror.com/launch-editor/-/launch-editor-2.10.0.tgz", + "integrity": "sha512-D7dBRJo/qcGX9xlvt/6wUYzQxjh5G1RvZPgPv8vi4KRU99DVQL/oW7tnVOCCTm2HGeo3C5HvGE5Yrh6UBoZ0vA==", + "dev": true, + "license": "MIT", + "dependencies": { + "picocolors": "^1.0.0", + "shell-quote": "^1.8.1" + } + }, + "node_modules/lit": { + "version": "3.2.1", + "resolved": "https://registry.npmmirror.com/lit/-/lit-3.2.1.tgz", + "integrity": "sha512-1BBa1E/z0O9ye5fZprPtdqnc0BFzxIxTTOO/tQFmyC/hj1O3jL4TfmLBw0WEwjAokdLwpclkvGgDJwTIh0/22w==", + "license": "BSD-3-Clause", + "dependencies": { + "@lit/reactive-element": "^2.0.4", + "lit-element": "^4.1.0", + "lit-html": "^3.2.0" + } + }, + "node_modules/lit-element": { + "version": "4.1.1", + "resolved": "https://registry.npmmirror.com/lit-element/-/lit-element-4.1.1.tgz", + "integrity": "sha512-HO9Tkkh34QkTeUmEdNYhMT8hzLid7YlMlATSi1q4q17HE5d9mrrEHJ/o8O2D0cMi182zK1F3v7x0PWFjrhXFew==", + "license": "BSD-3-Clause", + "dependencies": { + "@lit-labs/ssr-dom-shim": "^1.2.0", + "@lit/reactive-element": "^2.0.4", + "lit-html": "^3.2.0" + } + }, + "node_modules/lit-html": { + "version": "3.2.1", + "resolved": "https://registry.npmmirror.com/lit-html/-/lit-html-3.2.1.tgz", + "integrity": "sha512-qI/3lziaPMSKsrwlxH/xMgikhQ0EGOX2ICU73Bi/YHFvz2j/yMCIrw4+puF2IpQ4+upd3EWbvnHM9+PnJn48YA==", + "license": "BSD-3-Clause", + "dependencies": { + "@types/trusted-types": "^2.0.2" + } + }, + "node_modules/loader-runner": { + "version": "4.3.0", + "resolved": "https://registry.npmmirror.com/loader-runner/-/loader-runner-4.3.0.tgz", + "integrity": "sha512-3R/1M+yS3j5ou80Me59j7F9IMs4PXs3VqRrm0TU3AbKPxlmpoY1TNscJV/oGJXo8qCatFGTfDbY6W6ipGOYXfg==", + "license": "MIT", + "engines": { + "node": ">=6.11.5" + } + }, + "node_modules/locate-path": { + "version": "5.0.0", + "resolved": "https://registry.npmmirror.com/locate-path/-/locate-path-5.0.0.tgz", + "integrity": "sha512-t7hw9pI+WvuwNJXwk5zVHpyhIqzg2qTlklJOf0mVxGSbe3Fp2VieZcduNYjaLDoy6p9uGpQEGWG87WpMKlNq8g==", + "dev": true, + "license": "MIT", + "dependencies": { + "p-locate": "^4.1.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/lodash": { + "version": "4.17.21", + "resolved": "https://registry.npmmirror.com/lodash/-/lodash-4.17.21.tgz", + "integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==", + "license": "MIT" + }, + "node_modules/lower-case": { + "version": "2.0.2", + "resolved": "https://registry.npmmirror.com/lower-case/-/lower-case-2.0.2.tgz", + "integrity": "sha512-7fm3l3NAF9WfN6W3JOmf5drwpVqX78JtoGJ3A6W0a6ZnldM41w2fV5D490psKFTpMds8TJse/eHLFFsNHHjHgg==", + "dev": true, + "license": "MIT", + "dependencies": { + "tslib": "^2.0.3" + } + }, + "node_modules/math-intrinsics": { + "version": "1.1.0", + "resolved": "https://registry.npmmirror.com/math-intrinsics/-/math-intrinsics-1.1.0.tgz", + "integrity": "sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/media-typer": { + "version": "0.3.0", + "resolved": "https://registry.npmmirror.com/media-typer/-/media-typer-0.3.0.tgz", + "integrity": "sha512-dq+qelQ9akHpcOl/gUVRTxVIOkAJ1wR3QAvb4RsVjS8oVoFjDGTc679wJYmUmknUF5HwMLOgb5O+a3KxfWapPQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/memfs": { + "version": "3.5.3", + "resolved": "https://registry.npmmirror.com/memfs/-/memfs-3.5.3.tgz", + "integrity": "sha512-UERzLsxzllchadvbPs5aolHh65ISpKpM+ccLbOJ8/vvpBKmAWf+la7dXFy7Mr0ySHbdHrFv5kGFCUHHe6GFEmw==", + "dev": true, + "dependencies": { + "fs-monkey": "^1.0.4" + }, + "engines": { + "node": ">= 4.0.0" + } + }, + "node_modules/merge-descriptors": { + "version": "1.0.3", + "resolved": "https://registry.npmmirror.com/merge-descriptors/-/merge-descriptors-1.0.3.tgz", + "integrity": "sha512-gaNvAS7TZ897/rVaZ0nMtAyxNyi/pdbjbAwUpFQpN70GqnVfOiXpeUUMKRBmzXaSQ8DdTX4/0ms62r2K+hE6mQ==", + "dev": true, + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/merge-stream": { + "version": "2.0.0", + "resolved": "https://registry.npmmirror.com/merge-stream/-/merge-stream-2.0.0.tgz", + "integrity": "sha512-abv/qOcuPfk3URPfDzmZU1LKmuw8kT+0nIHvKrKgFrwifol/doWcdA4ZqsWQ8ENrFKkd67Mfpo/LovbIUsbt3w==", + "license": "MIT" + }, + "node_modules/methods": { + "version": "1.1.2", + "resolved": "https://registry.npmmirror.com/methods/-/methods-1.1.2.tgz", + "integrity": "sha512-iclAHeNqNm68zFtnZ0e+1L2yUIdvzNoauKU4WBA3VvH/vPFieF7qfRlwUZU+DA9P9bPXIS90ulxoUoCH23sV2w==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/micromatch": { + "version": "4.0.8", + "resolved": "https://registry.npmmirror.com/micromatch/-/micromatch-4.0.8.tgz", + "integrity": "sha512-PXwfBhYu0hBCPw8Dn0E+WDYb7af3dSLVWKi3HGv84IdF4TyFoC0ysxFd0Goxw7nSv4T/PzEJQxsYsEiFCKo2BA==", + "dev": true, + "license": "MIT", + "dependencies": { + "braces": "^3.0.3", + "picomatch": "^2.3.1" + }, + "engines": { + "node": ">=8.6" + } + }, + "node_modules/mime": { + "version": "1.6.0", + "resolved": "https://registry.npmmirror.com/mime/-/mime-1.6.0.tgz", + "integrity": "sha512-x0Vn8spI+wuJ1O6S7gnbaQg8Pxh4NNHb7KSINmEWKiPE4RKOplvijn+NkmYmmRgP68mc70j2EbeTFRsrswaQeg==", + "dev": true, + "license": "MIT", + "bin": { + "mime": "cli.js" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/mime-db": { + "version": "1.52.0", + "resolved": "https://registry.npmmirror.com/mime-db/-/mime-db-1.52.0.tgz", + "integrity": "sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==", + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/mime-types": { + "version": "2.1.35", + "resolved": "https://registry.npmmirror.com/mime-types/-/mime-types-2.1.35.tgz", + "integrity": "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==", + "license": "MIT", + "dependencies": { + "mime-db": "1.52.0" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/mimic-fn": { + "version": "2.1.0", + "resolved": "https://registry.npmmirror.com/mimic-fn/-/mimic-fn-2.1.0.tgz", + "integrity": "sha512-OqbOk5oEQeAZ8WXWydlu9HJjz9WVdEIvamMCcXmuqUYjTknH/sqsWvhQ3vgwKFRR1HpjvNBKQ37nbJgYzGqGcg==", + "dev": true, + "engines": { + "node": ">=6" + } + }, + "node_modules/minimalistic-assert": { + "version": "1.0.1", + "resolved": "https://registry.npmmirror.com/minimalistic-assert/-/minimalistic-assert-1.0.1.tgz", + "integrity": "sha512-UtJcAD4yEaGtjPezWuO9wC4nwUnVH/8/Im3yEHQP4b67cXlD/Qr9hdITCU1xDbSEXg2XKNaP8jsReV7vQd00/A==", + "dev": true, + "license": "ISC" + }, + "node_modules/minimatch": { + "version": "3.1.2", + "resolved": "https://registry.npmmirror.com/minimatch/-/minimatch-3.1.2.tgz", + "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "license": "ISC", + "dependencies": { + "brace-expansion": "^1.1.7" + }, + "engines": { + "node": "*" + } + }, + "node_modules/ms": { + "version": "2.0.0", + "resolved": "https://registry.npmmirror.com/ms/-/ms-2.0.0.tgz", + "integrity": "sha512-Tpp60P6IUJDTuOq/5Z8cdskzJujfwqfOTkrwIwj7IRISpnkJnT6SyJ4PCPnGMoFjC9ddhal5KVIYtAt97ix05A==", + "dev": true, + "license": "MIT" + }, + "node_modules/multicast-dns": { + "version": "7.2.5", + "resolved": "https://registry.npmmirror.com/multicast-dns/-/multicast-dns-7.2.5.tgz", + "integrity": "sha512-2eznPJP8z2BFLX50tf0LuODrpINqP1RVIm/CObbTcBRITQgmC/TjcREF1NeTBzIcR5XO/ukWo+YHOjBbFwIupg==", + "dev": true, + "license": "MIT", + "dependencies": { + "dns-packet": "^5.2.2", + "thunky": "^1.0.2" + }, + "bin": { + "multicast-dns": "cli.js" + } + }, + "node_modules/nanoid": { + "version": "3.3.11", + "resolved": "https://registry.npmmirror.com/nanoid/-/nanoid-3.3.11.tgz", + "integrity": "sha512-N8SpfPUnUp1bK+PMYW8qSWdl9U+wwNWI4QKxOYDy9JAro3WMX7p2OeVRF9v+347pnakNevPmiHhNmZ2HbFA76w==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "MIT", + "bin": { + "nanoid": "bin/nanoid.cjs" + }, + "engines": { + "node": "^10 || ^12 || ^13.7 || ^14 || >=15.0.1" + } + }, + "node_modules/negotiator": { + "version": "0.6.4", + "resolved": "https://registry.npmmirror.com/negotiator/-/negotiator-0.6.4.tgz", + "integrity": "sha512-myRT3DiWPHqho5PrJaIRyaMv2kgYf0mUVgBNOYMuCH5Ki1yEiQaf/ZJuQ62nvpc44wL5WDbTX7yGJi1Neevw8w==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/neo-async": { + "version": "2.6.2", + "resolved": "https://registry.npmmirror.com/neo-async/-/neo-async-2.6.2.tgz", + "integrity": "sha512-Yd3UES5mWCSqR+qNT93S3UoYUkqAZ9lLg8a7g9rimsWmYGK8cVToA4/sF3RrshdyV3sAGMXVUmpMYOw+dLpOuw==", + "license": "MIT" + }, + "node_modules/no-case": { + "version": "3.0.4", + "resolved": "https://registry.npmmirror.com/no-case/-/no-case-3.0.4.tgz", + "integrity": "sha512-fgAN3jGAh+RoxUGZHTSOLJIqUc2wmoBwGR4tbpNAKmmovFoWq0OdRkb0VkldReO2a2iBT/OEulG9XSUc10r3zg==", + "dev": true, + "license": "MIT", + "dependencies": { + "lower-case": "^2.0.2", + "tslib": "^2.0.3" + } + }, + "node_modules/node-forge": { + "version": "1.3.1", + "resolved": "https://registry.npmmirror.com/node-forge/-/node-forge-1.3.1.tgz", + "integrity": "sha512-dPEtOeMvF9VMcYV/1Wb8CPoVAXtp6MKMlcbAt4ddqmGqUJ6fQZFXkNZNkNlfevtNkGtaSoXf/vNNNSvgrdXwtA==", + "dev": true, + "license": "(BSD-3-Clause OR GPL-2.0)", + "engines": { + "node": ">= 6.13.0" + } + }, + "node_modules/node-releases": { + "version": "2.0.19", + "resolved": "https://registry.npmmirror.com/node-releases/-/node-releases-2.0.19.tgz", + "integrity": "sha512-xxOWJsBKtzAq7DY0J+DTzuz58K8e7sJbdgwkbMWQe8UYB6ekmsQ45q0M/tJDsGaZmbC+l7n57UV8Hl5tHxO9uw==", + "license": "MIT" + }, + "node_modules/normalize-path": { + "version": "3.0.0", + "resolved": "https://registry.npmmirror.com/normalize-path/-/normalize-path-3.0.0.tgz", + "integrity": "sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/npm-run-path": { + "version": "4.0.1", + "resolved": "https://registry.npmmirror.com/npm-run-path/-/npm-run-path-4.0.1.tgz", + "integrity": "sha512-S48WzZW777zhNIrn7gxOlISNAqi9ZC/uQFnRdbeIHhZhCA6UqpkOT8T1G7BvfdgP4Er8gF4sUbaS0i7QvIfCWw==", + "dev": true, + "dependencies": { + "path-key": "^3.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/nth-check": { + "version": "2.1.1", + "resolved": "https://registry.npmmirror.com/nth-check/-/nth-check-2.1.1.tgz", + "integrity": "sha512-lqjrjmaOoAnWfMmBPL+XNnynZh2+swxiX3WUE0s4yEHI6m+AwrK2UZOimIRl3X/4QctVqS8AiZjFqyOGrMXb/w==", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "boolbase": "^1.0.0" + }, + "funding": { + "url": "https://github.com/fb55/nth-check?sponsor=1" + } + }, + "node_modules/object-assign": { + "version": "4.1.1", + "resolved": "https://registry.npmmirror.com/object-assign/-/object-assign-4.1.1.tgz", + "integrity": "sha512-rJgTQnkUnH1sFw8yT6VSU3zD3sWmu6sZhIseY8VX+GRu3P6F7Fu+JNDoXfklElbLJSnc3FUQHVe4cU5hj+BcUg==", + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/object-inspect": { + "version": "1.13.4", + "resolved": "https://registry.npmmirror.com/object-inspect/-/object-inspect-1.13.4.tgz", + "integrity": "sha512-W67iLl4J2EXEGTbfeHCffrjDfitvLANg0UlX3wFUUSTx92KXRFegMHUVgSqE+wvhAbi4WqjGg9czysTV2Epbew==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/obuf": { + "version": "1.1.2", + "resolved": "https://registry.npmmirror.com/obuf/-/obuf-1.1.2.tgz", + "integrity": "sha512-PX1wu0AmAdPqOL1mWhqmlOd8kOIZQwGZw6rh7uby9fTc5lhaOWFLX3I6R1hrF9k3zUY40e6igsLGkDXK92LJNg==", + "dev": true, + "license": "MIT" + }, + "node_modules/on-finished": { + "version": "2.4.1", + "resolved": "https://registry.npmmirror.com/on-finished/-/on-finished-2.4.1.tgz", + "integrity": "sha512-oVlzkg3ENAhCk2zdv7IJwd/QUD4z2RxRwpkcGY8psCVcCYZNq4wYnVWALHM+brtuJjePWiYF/ClmuDr8Ch5+kg==", + "dev": true, + "license": "MIT", + "dependencies": { + "ee-first": "1.1.1" + }, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/on-headers": { + "version": "1.0.2", + "resolved": "https://registry.npmmirror.com/on-headers/-/on-headers-1.0.2.tgz", + "integrity": "sha512-pZAE+FJLoyITytdqK0U5s+FIpjN0JP3OzFi/u8Rx+EV5/W+JTWGXG8xFzevE7AjBfDqHv/8vL8qQsIhHnqRkrA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/once": { + "version": "1.4.0", + "resolved": "https://registry.npmmirror.com/once/-/once-1.4.0.tgz", + "integrity": "sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w==", + "license": "ISC", + "dependencies": { + "wrappy": "1" + } + }, + "node_modules/onetime": { + "version": "5.1.2", + "resolved": "https://registry.npmmirror.com/onetime/-/onetime-5.1.2.tgz", + "integrity": "sha512-kbpaSSGJTWdAY5KPVeMOKXSrPtr8C8C7wodJbcsd51jRnmD+GZu8Y0VoU6Dm5Z4vWr0Ig/1NKuWRKf7j5aaYSg==", + "dev": true, + "dependencies": { + "mimic-fn": "^2.1.0" + }, + "engines": { + "node": ">=6" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/open": { + "version": "8.4.2", + "resolved": "https://registry.npmmirror.com/open/-/open-8.4.2.tgz", + "integrity": "sha512-7x81NCL719oNbsq/3mh+hVrAWmFuEYUqrq/Iw3kUzH8ReypT9QQ0BLoJS7/G9k6N81XjW4qHWtjWwe/9eLy1EQ==", + "dev": true, + "dependencies": { + "define-lazy-prop": "^2.0.0", + "is-docker": "^2.1.1", + "is-wsl": "^2.2.0" + }, + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/p-limit": { + "version": "2.3.0", + "resolved": "https://registry.npmmirror.com/p-limit/-/p-limit-2.3.0.tgz", + "integrity": "sha512-//88mFWSJx8lxCzwdAABTJL2MyWB12+eIY7MDL2SqLmAkeKU9qxRvWuSyTjm3FUmpBEMuFfckAIqEaVGUDxb6w==", + "dev": true, + "license": "MIT", + "dependencies": { + "p-try": "^2.0.0" + }, + "engines": { + "node": ">=6" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/p-locate": { + "version": "4.1.0", + "resolved": "https://registry.npmmirror.com/p-locate/-/p-locate-4.1.0.tgz", + "integrity": "sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A==", + "dev": true, + "license": "MIT", + "dependencies": { + "p-limit": "^2.2.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/p-map": { + "version": "2.1.0", + "resolved": "https://registry.npmmirror.com/p-map/-/p-map-2.1.0.tgz", + "integrity": "sha512-y3b8Kpd8OAN444hxfBbFfj1FY/RjtTd8tzYwhUqNYXx0fXx2iX4maP4Qr6qhIKbQXI02wTLAda4fYUbDagTUFw==", + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/p-retry": { + "version": "4.6.2", + "resolved": "https://registry.npmmirror.com/p-retry/-/p-retry-4.6.2.tgz", + "integrity": "sha512-312Id396EbJdvRONlngUx0NydfrIQ5lsYu0znKVUzVvArzEIt08V1qhtyESbGVd1FGX7UKtiFp5uwKZdM8wIuQ==", + "dev": true, + "dependencies": { + "@types/retry": "0.12.0", + "retry": "^0.13.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/p-try": { + "version": "2.2.0", + "resolved": "https://registry.npmmirror.com/p-try/-/p-try-2.2.0.tgz", + "integrity": "sha512-R4nPAVTAU0B9D35/Gk3uJf/7XYbQcyohSKdvAxIRSNghFl4e71hVoGnBNQz9cWaXxO2I10KTC+3jMdvvoKw6dQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/param-case": { + "version": "3.0.4", + "resolved": "https://registry.npmmirror.com/param-case/-/param-case-3.0.4.tgz", + "integrity": "sha512-RXlj7zCYokReqWpOPH9oYivUzLYZ5vAPIfEmCTNViosC78F8F0H9y7T7gG2M39ymgutxF5gcFEsyZQSph9Bp3A==", + "dev": true, + "license": "MIT", + "dependencies": { + "dot-case": "^3.0.4", + "tslib": "^2.0.3" + } + }, + "node_modules/parse5": { + "version": "7.2.1", + "resolved": "https://registry.npmmirror.com/parse5/-/parse5-7.2.1.tgz", + "integrity": "sha512-BuBYQYlv1ckiPdQi/ohiivi9Sagc9JG+Ozs0r7b/0iK3sKmrb0b9FdWdBbOdx6hBCM/F9Ir82ofnBhtZOjCRPQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "entities": "^4.5.0" + }, + "funding": { + "url": "https://github.com/inikulin/parse5?sponsor=1" + } + }, + "node_modules/parseurl": { + "version": "1.3.3", + "resolved": "https://registry.npmmirror.com/parseurl/-/parseurl-1.3.3.tgz", + "integrity": "sha512-CiyeOxFT/JZyN5m0z9PfXw4SCBJ6Sygz1Dpl0wqjlhDEGGBP1GnsUVEL0p63hoG1fcj3fHynXi9NYO4nWOL+qQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/pascal-case": { + "version": "3.1.2", + "resolved": "https://registry.npmmirror.com/pascal-case/-/pascal-case-3.1.2.tgz", + "integrity": "sha512-uWlGT3YSnK9x3BQJaOdcZwrnV6hPpd8jFH1/ucpiLRPh/2zCVJKS19E4GvYHvaCcACn3foXZ0cLB9Wrx1KGe5g==", + "dev": true, + "license": "MIT", + "dependencies": { + "no-case": "^3.0.4", + "tslib": "^2.0.3" + } + }, + "node_modules/path-exists": { + "version": "4.0.0", + "resolved": "https://registry.npmmirror.com/path-exists/-/path-exists-4.0.0.tgz", + "integrity": "sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/path-is-absolute": { + "version": "1.0.1", + "resolved": "https://registry.npmmirror.com/path-is-absolute/-/path-is-absolute-1.0.1.tgz", + "integrity": "sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg==", + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/path-is-inside": { + "version": "1.0.2", + "resolved": "https://registry.npmmirror.com/path-is-inside/-/path-is-inside-1.0.2.tgz", + "integrity": "sha512-DUWJr3+ULp4zXmol/SZkFf3JGsS9/SIv+Y3Rt93/UjPpDpklB5f1er4O3POIbUuUJ3FXgqte2Q7SrU6zAqwk8w==", + "license": "(WTFPL OR MIT)" + }, + "node_modules/path-key": { + "version": "3.1.1", + "resolved": "https://registry.npmmirror.com/path-key/-/path-key-3.1.1.tgz", + "integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==", + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/path-parse": { + "version": "1.0.7", + "resolved": "https://registry.npmmirror.com/path-parse/-/path-parse-1.0.7.tgz", + "integrity": "sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw==", + "dev": true, + "license": "MIT" + }, + "node_modules/path-to-regexp": { + "version": "0.1.12", + "resolved": "https://registry.npmmirror.com/path-to-regexp/-/path-to-regexp-0.1.12.tgz", + "integrity": "sha512-RA1GjUVMnvYFxuqovrEqZoxxW5NUZqbwKtYz/Tt7nXerk0LbLblQmrsgdeOxV5SFHf0UDggjS/bSeOZwt1pmEQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/picocolors": { + "version": "1.1.1", + "resolved": "https://registry.npmmirror.com/picocolors/-/picocolors-1.1.1.tgz", + "integrity": "sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==", + "license": "ISC" + }, + "node_modules/picomatch": { + "version": "2.3.1", + "resolved": "https://registry.npmmirror.com/picomatch/-/picomatch-2.3.1.tgz", + "integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8.6" + }, + "funding": { + "url": "https://github.com/sponsors/jonschlinkert" + } + }, + "node_modules/pify": { + "version": "4.0.1", + "resolved": "https://registry.npmmirror.com/pify/-/pify-4.0.1.tgz", + "integrity": "sha512-uB80kBFb/tfd68bVleG9T5GGsGPjJrLAUpR5PZIrhBnIaRTQRjqdJSsIKkOP6OAIFbj7GOrcudc5pNjZ+geV2g==", + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/pinkie": { + "version": "2.0.4", + "resolved": "https://registry.npmmirror.com/pinkie/-/pinkie-2.0.4.tgz", + "integrity": "sha512-MnUuEycAemtSaeFSjXKW/aroV7akBbY+Sv+RkyqFjgAe73F+MR0TBWKBRDkmfWq/HiFmdavfZ1G7h4SPZXaCSg==", + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/pinkie-promise": { + "version": "2.0.1", + "resolved": "https://registry.npmmirror.com/pinkie-promise/-/pinkie-promise-2.0.1.tgz", + "integrity": "sha512-0Gni6D4UcLTbv9c57DfxDGdr41XfgUjqWZu492f0cIGr16zDU06BWP/RAEvOuo7CQ0CNjHaLlM59YJJFm3NWlw==", + "license": "MIT", + "dependencies": { + "pinkie": "^2.0.0" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/pkg-dir": { + "version": "4.2.0", + "resolved": "https://registry.npmmirror.com/pkg-dir/-/pkg-dir-4.2.0.tgz", + "integrity": "sha512-HRDzbaKjC+AOWVXxAU/x54COGeIv9eb+6CkDSQoNTt4XyWoIJvuPsXizxu/Fr23EiekbtZwmh1IcIG/l/a10GQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "find-up": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/postcss": { + "version": "8.5.3", + "resolved": "https://registry.npmmirror.com/postcss/-/postcss-8.5.3.tgz", + "integrity": "sha512-dle9A3yYxlBSrt8Fu+IpjGT8SY8hN0mlaA6GY8t0P5PjIOZemULz/E2Bnm/2dcUOena75OTNkHI76uZBNUUq3A==", + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/postcss/" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/postcss" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "MIT", + "dependencies": { + "nanoid": "^3.3.8", + "picocolors": "^1.1.1", + "source-map-js": "^1.2.1" + }, + "engines": { + "node": "^10 || ^12 || >=14" + } + }, + "node_modules/postcss-modules-extract-imports": { + "version": "3.1.0", + "resolved": "https://registry.npmmirror.com/postcss-modules-extract-imports/-/postcss-modules-extract-imports-3.1.0.tgz", + "integrity": "sha512-k3kNe0aNFQDAZGbin48pL2VNidTF0w4/eASDsxlyspobzU3wZQLOGj7L9gfRe0Jo9/4uud09DsjFNH7winGv8Q==", + "license": "ISC", + "engines": { + "node": "^10 || ^12 || >= 14" + }, + "peerDependencies": { + "postcss": "^8.1.0" + } + }, + "node_modules/postcss-modules-local-by-default": { + "version": "4.2.0", + "resolved": "https://registry.npmmirror.com/postcss-modules-local-by-default/-/postcss-modules-local-by-default-4.2.0.tgz", + "integrity": "sha512-5kcJm/zk+GJDSfw+V/42fJ5fhjL5YbFDl8nVdXkJPLLW+Vf9mTD5Xe0wqIaDnLuL2U6cDNpTr+UQ+v2HWIBhzw==", + "license": "MIT", + "dependencies": { + "icss-utils": "^5.0.0", + "postcss-selector-parser": "^7.0.0", + "postcss-value-parser": "^4.1.0" + }, + "engines": { + "node": "^10 || ^12 || >= 14" + }, + "peerDependencies": { + "postcss": "^8.1.0" + } + }, + "node_modules/postcss-modules-scope": { + "version": "3.2.1", + "resolved": "https://registry.npmmirror.com/postcss-modules-scope/-/postcss-modules-scope-3.2.1.tgz", + "integrity": "sha512-m9jZstCVaqGjTAuny8MdgE88scJnCiQSlSrOWcTQgM2t32UBe+MUmFSO5t7VMSfAf/FJKImAxBav8ooCHJXCJA==", + "license": "ISC", + "dependencies": { + "postcss-selector-parser": "^7.0.0" + }, + "engines": { + "node": "^10 || ^12 || >= 14" + }, + "peerDependencies": { + "postcss": "^8.1.0" + } + }, + "node_modules/postcss-modules-values": { + "version": "4.0.0", + "resolved": "https://registry.npmmirror.com/postcss-modules-values/-/postcss-modules-values-4.0.0.tgz", + "integrity": "sha512-RDxHkAiEGI78gS2ofyvCsu7iycRv7oqw5xMWn9iMoR0N/7mf9D50ecQqUo5BZ9Zh2vH4bCUR/ktCqbB9m8vJjQ==", + "license": "ISC", + "dependencies": { + "icss-utils": "^5.0.0" + }, + "engines": { + "node": "^10 || ^12 || >= 14" + }, + "peerDependencies": { + "postcss": "^8.1.0" + } + }, + "node_modules/postcss-selector-parser": { + "version": "7.1.0", + "resolved": "https://registry.npmmirror.com/postcss-selector-parser/-/postcss-selector-parser-7.1.0.tgz", + "integrity": "sha512-8sLjZwK0R+JlxlYcTuVnyT2v+htpdrjDOKuMcOVdYjt52Lh8hWRYpxBPoKx/Zg+bcjc3wx6fmQevMmUztS/ccA==", + "license": "MIT", + "dependencies": { + "cssesc": "^3.0.0", + "util-deprecate": "^1.0.2" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/postcss-value-parser": { + "version": "4.2.0", + "resolved": "https://registry.npmmirror.com/postcss-value-parser/-/postcss-value-parser-4.2.0.tgz", + "integrity": "sha512-1NNCs6uurfkVbeXG4S8JFT9t19m45ICnif8zWLd5oPSZ50QnwMfK+H3jv408d4jw/7Bttv5axS5IiHoLaVNHeQ==", + "license": "MIT" + }, + "node_modules/prettier": { + "version": "3.5.3", + "resolved": "https://registry.npmmirror.com/prettier/-/prettier-3.5.3.tgz", + "integrity": "sha512-QQtaxnoDJeAkDvDKWCLiwIXkTgRhwYDEQCghU9Z6q03iyek/rxRh/2lC3HB7P8sWT2xC/y5JDctPLBIGzHKbhw==", + "license": "MIT", + "bin": { + "prettier": "bin/prettier.cjs" + }, + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/prettier/prettier?sponsor=1" + } + }, + "node_modules/pretty-error": { + "version": "4.0.0", + "resolved": "https://registry.npmmirror.com/pretty-error/-/pretty-error-4.0.0.tgz", + "integrity": "sha512-AoJ5YMAcXKYxKhuJGdcvse+Voc6v1RgnsR3nWcYU7q4t6z0Q6T86sv5Zq8VIRbOWWFpvdGE83LtdSMNd+6Y0xw==", + "dev": true, + "license": "MIT", + "dependencies": { + "lodash": "^4.17.20", + "renderkid": "^3.0.0" + } + }, + "node_modules/process-nextick-args": { + "version": "2.0.1", + "resolved": "https://registry.npmmirror.com/process-nextick-args/-/process-nextick-args-2.0.1.tgz", + "integrity": "sha512-3ouUOpQhtgrbOa17J7+uxOTpITYWaGP7/AhoR3+A+/1e9skrzelGi/dXzEYyvbxubEF6Wn2ypscTKiKJFFn1ag==", + "dev": true, + "license": "MIT" + }, + "node_modules/proxy-addr": { + "version": "2.0.7", + "resolved": "https://registry.npmmirror.com/proxy-addr/-/proxy-addr-2.0.7.tgz", + "integrity": "sha512-llQsMLSUDUPT44jdrU/O37qlnifitDP+ZwrmmZcoSKyLKvtZxpyV0n2/bD/N4tBAAZ/gJEdZU7KMraoK1+XYAg==", + "dev": true, + "license": "MIT", + "dependencies": { + "forwarded": "0.2.0", + "ipaddr.js": "1.9.1" + }, + "engines": { + "node": ">= 0.10" + } + }, + "node_modules/proxy-addr/node_modules/ipaddr.js": { + "version": "1.9.1", + "resolved": "https://registry.npmmirror.com/ipaddr.js/-/ipaddr.js-1.9.1.tgz", + "integrity": "sha512-0KI/607xoxSToH7GjN1FfSbLoU0+btTicjsQSWQlh/hZykN8KpmMf7uYwPW3R+akZ6R/w18ZlXSHBYXiYUPO3g==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.10" + } + }, + "node_modules/proxy-from-env": { + "version": "1.1.0", + "resolved": "https://registry.npmmirror.com/proxy-from-env/-/proxy-from-env-1.1.0.tgz", + "integrity": "sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg==", + "license": "MIT" + }, + "node_modules/qs": { + "version": "6.13.0", + "resolved": "https://registry.npmmirror.com/qs/-/qs-6.13.0.tgz", + "integrity": "sha512-+38qI9SOr8tfZ4QmJNplMUxqjbe7LKvvZgWdExBOmd+egZTtjLB67Gu0HRX3u/XOq7UU2Nx6nsjvS16Z9uwfpg==", + "dev": true, + "license": "BSD-3-Clause", + "dependencies": { + "side-channel": "^1.0.6" + }, + "engines": { + "node": ">=0.6" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/randombytes": { + "version": "2.1.0", + "resolved": "https://registry.npmmirror.com/randombytes/-/randombytes-2.1.0.tgz", + "integrity": "sha512-vYl3iOX+4CKUWuxGi9Ukhie6fsqXqS9FE2Zaic4tNFD2N2QQaXOMFbuKK4QmDHC0JO6B1Zp41J0LpT0oR68amQ==", + "license": "MIT", + "dependencies": { + "safe-buffer": "^5.1.0" + } + }, + "node_modules/range-parser": { + "version": "1.2.1", + "resolved": "https://registry.npmmirror.com/range-parser/-/range-parser-1.2.1.tgz", + "integrity": "sha512-Hrgsx+orqoygnmhFbKaHE6c296J+HTAQXoxEF6gNupROmmGJRoyzfG3ccAveqCBrwr/2yxQ5BVd/GTl5agOwSg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/raw-body": { + "version": "2.5.2", + "resolved": "https://registry.npmmirror.com/raw-body/-/raw-body-2.5.2.tgz", + "integrity": "sha512-8zGqypfENjCIqGhgXToC8aB2r7YrBX+AQAfIPs/Mlk+BtPTztOvTS01NRW/3Eh60J+a48lt8qsCzirQ6loCVfA==", + "dev": true, + "license": "MIT", + "dependencies": { + "bytes": "3.1.2", + "http-errors": "2.0.0", + "iconv-lite": "0.4.24", + "unpipe": "1.0.0" + }, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/readable-stream": { + "version": "3.6.2", + "resolved": "https://registry.npmmirror.com/readable-stream/-/readable-stream-3.6.2.tgz", + "integrity": "sha512-9u/sniCrY3D5WdsERHzHE4G2YCXqoG5FTHUiCC4SIbr6XcLZBY05ya9EKjYek9O5xOAwjGq+1JdGBAS7Q9ScoA==", + "dev": true, + "license": "MIT", + "dependencies": { + "inherits": "^2.0.3", + "string_decoder": "^1.1.1", + "util-deprecate": "^1.0.1" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/readdirp": { + "version": "3.6.0", + "resolved": "https://registry.npmmirror.com/readdirp/-/readdirp-3.6.0.tgz", + "integrity": "sha512-hOS089on8RduqdbhvQ5Z37A0ESjsqz6qnRcffsMU3495FuTdqSm+7bhJ29JvIOsBDEEnan5DPu9t3To9VRlMzA==", + "dev": true, + "license": "MIT", + "dependencies": { + "picomatch": "^2.2.1" + }, + "engines": { + "node": ">=8.10.0" + } + }, + "node_modules/rechoir": { + "version": "0.8.0", + "resolved": "https://registry.npmmirror.com/rechoir/-/rechoir-0.8.0.tgz", + "integrity": "sha512-/vxpCXddiX8NGfGO/mTafwjq4aFa/71pvamip0++IQk3zG8cbCj0fifNPrjjF1XMXUne91jL9OoxmdykoEtifQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "resolve": "^1.20.0" + }, + "engines": { + "node": ">= 10.13.0" + } + }, + "node_modules/relateurl": { + "version": "0.2.7", + "resolved": "https://registry.npmmirror.com/relateurl/-/relateurl-0.2.7.tgz", + "integrity": "sha512-G08Dxvm4iDN3MLM0EsP62EDV9IuhXPR6blNz6Utcp7zyV3tr4HVNINt6MpaRWbxoOHT3Q7YN2P+jaHX8vUbgog==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.10" + } + }, + "node_modules/renderkid": { + "version": "3.0.0", + "resolved": "https://registry.npmmirror.com/renderkid/-/renderkid-3.0.0.tgz", + "integrity": "sha512-q/7VIQA8lmM1hF+jn+sFSPWGlMkSAeNYcPLmDQx2zzuiDfaLrOmumR8iaUKlenFgh0XRPIUeSPlH3A+AW3Z5pg==", + "dev": true, + "license": "MIT", + "dependencies": { + "css-select": "^4.1.3", + "dom-converter": "^0.2.0", + "htmlparser2": "^6.1.0", + "lodash": "^4.17.21", + "strip-ansi": "^6.0.1" + } + }, + "node_modules/require-from-string": { + "version": "2.0.2", + "resolved": "https://registry.npmmirror.com/require-from-string/-/require-from-string-2.0.2.tgz", + "integrity": "sha512-Xf0nWe6RseziFMu+Ap9biiUbmplq6S9/p+7w7YXP/JBHhrUDDUhwa+vANyubuqfZWTveU//DYVGsDG7RKL/vEw==", + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/requires-port": { + "version": "1.0.0", + "resolved": "https://registry.npmmirror.com/requires-port/-/requires-port-1.0.0.tgz", + "integrity": "sha512-KigOCHcocU3XODJxsu8i/j8T9tzT4adHiecwORRQ0ZZFcp7ahwXuRU1m+yuO90C5ZUyGeGfocHDI14M3L3yDAQ==", + "dev": true + }, + "node_modules/resolve": { + "version": "1.22.10", + "resolved": "https://registry.npmmirror.com/resolve/-/resolve-1.22.10.tgz", + "integrity": "sha512-NPRy+/ncIMeDlTAsuqwKIiferiawhefFJtkNSW0qZJEqMEb+qBt/77B/jGeeek+F0uOeN05CDa6HXbbIgtVX4w==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-core-module": "^2.16.0", + "path-parse": "^1.0.7", + "supports-preserve-symlinks-flag": "^1.0.0" + }, + "bin": { + "resolve": "bin/resolve" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/resolve-cwd": { + "version": "3.0.0", + "resolved": "https://registry.npmmirror.com/resolve-cwd/-/resolve-cwd-3.0.0.tgz", + "integrity": "sha512-OrZaX2Mb+rJCpH/6CpSqt9xFVpN++x01XnN2ie9g6P5/3xelLAkXWVADpdz1IHD/KFfEXyE6V0U01OQ3UO2rEg==", + "dev": true, + "license": "MIT", + "dependencies": { + "resolve-from": "^5.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/resolve-from": { + "version": "5.0.0", + "resolved": "https://registry.npmmirror.com/resolve-from/-/resolve-from-5.0.0.tgz", + "integrity": "sha512-qYg9KP24dD5qka9J47d0aVky0N+b4fTU89LN9iDnjB5waksiC49rvMB0PrUJQGoTmH50XPiqOvAjDfaijGxYZw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/retry": { + "version": "0.13.1", + "resolved": "https://registry.npmmirror.com/retry/-/retry-0.13.1.tgz", + "integrity": "sha512-XQBQ3I8W1Cge0Seh+6gjj03LbmRFWuoszgK9ooCpwYIrhhoO80pfq4cUkU5DkknwfOfFteRwlZ56PYOGYyFWdg==", + "dev": true, + "engines": { + "node": ">= 4" + } + }, + "node_modules/rimraf": { + "version": "2.7.1", + "resolved": "https://registry.npmmirror.com/rimraf/-/rimraf-2.7.1.tgz", + "integrity": "sha512-uWjbaKIK3T1OSVptzX7Nl6PvQ3qAGtKEtVRjRuazjfL3Bx5eI409VZSqgND+4UNnmzLVdPj9FqFJNPqBZFve4w==", + "deprecated": "Rimraf versions prior to v4 are no longer supported", + "license": "ISC", + "dependencies": { + "glob": "^7.1.3" + }, + "bin": { + "rimraf": "bin.js" + } + }, + "node_modules/robust-predicates": { + "version": "3.0.2", + "resolved": "https://registry.npmmirror.com/robust-predicates/-/robust-predicates-3.0.2.tgz", + "integrity": "sha512-IXgzBWvWQwE6PrDI05OvmXUIruQTcoMDzRsOd5CDvHCVLcLHMTSYvOK5Cm46kWqlV3yAbuSpBZdJ5oP5OUoStg==" + }, + "node_modules/rw": { + "version": "1.3.3", + "resolved": "https://registry.npmmirror.com/rw/-/rw-1.3.3.tgz", + "integrity": "sha512-PdhdWy89SiZogBLaw42zdeqtRJ//zFd2PgQavcICDUgJT5oW10QCRKbJ6bg4r0/UY2M6BWd5tkxuGFRvCkgfHQ==" + }, + "node_modules/safe-buffer": { + "version": "5.2.1", + "resolved": "https://registry.npmmirror.com/safe-buffer/-/safe-buffer-5.2.1.tgz", + "integrity": "sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "license": "MIT" + }, + "node_modules/safer-buffer": { + "version": "2.1.2", + "resolved": "https://registry.npmmirror.com/safer-buffer/-/safer-buffer-2.1.2.tgz", + "integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==", + "license": "MIT" + }, + "node_modules/schema-utils": { + "version": "4.3.0", + "resolved": "https://registry.npmmirror.com/schema-utils/-/schema-utils-4.3.0.tgz", + "integrity": "sha512-Gf9qqc58SpCA/xdziiHz35F4GNIWYWZrEshUc/G/r5BnLph6xpKuLeoJoQuj5WfBIx/eQLf+hmVPYHaxJu7V2g==", + "license": "MIT", + "dependencies": { + "@types/json-schema": "^7.0.9", + "ajv": "^8.9.0", + "ajv-formats": "^2.1.1", + "ajv-keywords": "^5.1.0" + }, + "engines": { + "node": ">= 10.13.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" + } + }, + "node_modules/select-hose": { + "version": "2.0.0", + "resolved": "https://registry.npmmirror.com/select-hose/-/select-hose-2.0.0.tgz", + "integrity": "sha512-mEugaLK+YfkijB4fx0e6kImuJdCIt2LxCRcbEYPqRGCs4F2ogyfZU5IAZRdjCP8JPq2AtdNoC/Dux63d9Kiryg==", + "dev": true, + "license": "MIT" + }, + "node_modules/selfsigned": { + "version": "2.4.1", + "resolved": "https://registry.npmmirror.com/selfsigned/-/selfsigned-2.4.1.tgz", + "integrity": "sha512-th5B4L2U+eGLq1TVh7zNRGBapioSORUeymIydxgFpwww9d2qyKvtuPU2jJuHvYAwwqi2Y596QBL3eEqcPEYL8Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/node-forge": "^1.3.0", + "node-forge": "^1" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/semver": { + "version": "7.7.1", + "resolved": "https://registry.npmmirror.com/semver/-/semver-7.7.1.tgz", + "integrity": "sha512-hlq8tAfn0m/61p4BVRcPzIGr6LKiMwo4VM6dGi6pt4qcRkmNzTcWq6eCEjEh+qXjkMDvPlOFFSGwQjoEa6gyMA==", + "license": "ISC", + "bin": { + "semver": "bin/semver.js" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/send": { + "version": "0.19.0", + "resolved": "https://registry.npmmirror.com/send/-/send-0.19.0.tgz", + "integrity": "sha512-dW41u5VfLXu8SJh5bwRmyYUbAoSB3c9uQh6L8h/KtsFREPWpbX1lrljJo186Jc4nmci/sGUZ9a0a0J2zgfq2hw==", + "dev": true, + "license": "MIT", + "dependencies": { + "debug": "2.6.9", + "depd": "2.0.0", + "destroy": "1.2.0", + "encodeurl": "~1.0.2", + "escape-html": "~1.0.3", + "etag": "~1.8.1", + "fresh": "0.5.2", + "http-errors": "2.0.0", + "mime": "1.6.0", + "ms": "2.1.3", + "on-finished": "2.4.1", + "range-parser": "~1.2.1", + "statuses": "2.0.1" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/send/node_modules/encodeurl": { + "version": "1.0.2", + "resolved": "https://registry.npmmirror.com/encodeurl/-/encodeurl-1.0.2.tgz", + "integrity": "sha512-TPJXq8JqFaVYm2CWmPvnP2Iyo4ZSM7/QKcSmuMLDObfpH5fi7RUGmd/rTDf+rut/saiDiQEeVTNgAmJEdAOx0w==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/send/node_modules/ms": { + "version": "2.1.3", + "resolved": "https://registry.npmmirror.com/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", + "dev": true, + "license": "MIT" + }, + "node_modules/serialize-javascript": { + "version": "6.0.2", + "resolved": "https://registry.npmmirror.com/serialize-javascript/-/serialize-javascript-6.0.2.tgz", + "integrity": "sha512-Saa1xPByTTq2gdeFZYLLo+RFE35NHZkAbqZeWNd3BpzppeVisAqpDjcp8dyf6uIvEqJRd46jemmyA4iFIeVk8g==", + "license": "BSD-3-Clause", + "dependencies": { + "randombytes": "^2.1.0" + } + }, + "node_modules/serve-index": { + "version": "1.9.1", + "resolved": "https://registry.npmmirror.com/serve-index/-/serve-index-1.9.1.tgz", + "integrity": "sha512-pXHfKNP4qujrtteMrSBb0rc8HJ9Ms/GrXwcUtUtD5s4ewDJI8bT3Cz2zTVRMKtri49pLx2e0Ya8ziP5Ya2pZZw==", + "dev": true, + "license": "MIT", + "dependencies": { + "accepts": "~1.3.4", + "batch": "0.6.1", + "debug": "2.6.9", + "escape-html": "~1.0.3", + "http-errors": "~1.6.2", + "mime-types": "~2.1.17", + "parseurl": "~1.3.2" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/serve-index/node_modules/depd": { + "version": "1.1.2", + "resolved": "https://registry.npmmirror.com/depd/-/depd-1.1.2.tgz", + "integrity": "sha512-7emPTl6Dpo6JRXOXjLRxck+FlLRX5847cLKEn00PLAgc3g2hTZZgr+e4c2v6QpSmLeFP3n5yUo7ft6avBK/5jQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/serve-index/node_modules/http-errors": { + "version": "1.6.3", + "resolved": "https://registry.npmmirror.com/http-errors/-/http-errors-1.6.3.tgz", + "integrity": "sha512-lks+lVC8dgGyh97jxvxeYTWQFvh4uw4yC12gVl63Cg30sjPX4wuGcdkICVXDAESr6OJGjqGA8Iz5mkeN6zlD7A==", + "dev": true, + "license": "MIT", + "dependencies": { + "depd": "~1.1.2", + "inherits": "2.0.3", + "setprototypeof": "1.1.0", + "statuses": ">= 1.4.0 < 2" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/serve-index/node_modules/inherits": { + "version": "2.0.3", + "resolved": "https://registry.npmmirror.com/inherits/-/inherits-2.0.3.tgz", + "integrity": "sha512-x00IRNXNy63jwGkJmzPigoySHbaqpNuzKbBOmzK+g2OdZpQ9w+sxCN+VSB3ja7IAge2OP2qpfxTjeNcyjmW1uw==", + "dev": true, + "license": "ISC" + }, + "node_modules/serve-index/node_modules/setprototypeof": { + "version": "1.1.0", + "resolved": "https://registry.npmmirror.com/setprototypeof/-/setprototypeof-1.1.0.tgz", + "integrity": "sha512-BvE/TwpZX4FXExxOxZyRGQQv651MSwmWKZGqvmPcRIjDqWub67kTKuIMx43cZZrS/cBBzwBcNDWoFxt2XEFIpQ==", + "dev": true, + "license": "ISC" + }, + "node_modules/serve-index/node_modules/statuses": { + "version": "1.5.0", + "resolved": "https://registry.npmmirror.com/statuses/-/statuses-1.5.0.tgz", + "integrity": "sha512-OpZ3zP+jT1PI7I8nemJX4AKmAX070ZkYPVWV/AaKTJl+tXCTGyVdC1a4SL8RUQYEwk/f34ZX8UTykN68FwrqAA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/serve-static": { + "version": "1.16.2", + "resolved": "https://registry.npmmirror.com/serve-static/-/serve-static-1.16.2.tgz", + "integrity": "sha512-VqpjJZKadQB/PEbEwvFdO43Ax5dFBZ2UECszz8bQ7pi7wt//PWe1P6MN7eCnjsatYtBT6EuiClbjSWP2WrIoTw==", + "dev": true, + "license": "MIT", + "dependencies": { + "encodeurl": "~2.0.0", + "escape-html": "~1.0.3", + "parseurl": "~1.3.3", + "send": "0.19.0" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/setprototypeof": { + "version": "1.2.0", + "resolved": "https://registry.npmmirror.com/setprototypeof/-/setprototypeof-1.2.0.tgz", + "integrity": "sha512-E5LDX7Wrp85Kil5bhZv46j8jOeboKq5JMmYM3gVGdGH8xFpPWXUMsNrlODCrkoxMEeNi/XZIwuRvY4XNwYMJpw==", + "dev": true, + "license": "ISC" + }, + "node_modules/shallow-clone": { + "version": "3.0.1", + "resolved": "https://registry.npmmirror.com/shallow-clone/-/shallow-clone-3.0.1.tgz", + "integrity": "sha512-/6KqX+GVUdqPuPPd2LxDDxzX6CAbjJehAAOKlNpqqUpAqPM6HeL8f+o3a+JsyGjn2lv0WY8UsTgUJjU9Ok55NA==", + "dev": true, + "license": "MIT", + "dependencies": { + "kind-of": "^6.0.2" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/shebang-command": { + "version": "2.0.0", + "resolved": "https://registry.npmmirror.com/shebang-command/-/shebang-command-2.0.0.tgz", + "integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==", + "license": "MIT", + "dependencies": { + "shebang-regex": "^3.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/shebang-regex": { + "version": "3.0.0", + "resolved": "https://registry.npmmirror.com/shebang-regex/-/shebang-regex-3.0.0.tgz", + "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==", + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/shell-quote": { + "version": "1.8.2", + "resolved": "https://registry.npmmirror.com/shell-quote/-/shell-quote-1.8.2.tgz", + "integrity": "sha512-AzqKpGKjrj7EM6rKVQEPpB288oCfnrEIuyoT9cyF4nmGa7V8Zk6f7RRqYisX8X9m+Q7bd632aZW4ky7EhbQztA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/side-channel": { + "version": "1.1.0", + "resolved": "https://registry.npmmirror.com/side-channel/-/side-channel-1.1.0.tgz", + "integrity": "sha512-ZX99e6tRweoUXqR+VBrslhda51Nh5MTQwou5tnUDgbtyM0dBgmhEDtWGP/xbKn6hqfPRHujUNwz5fy/wbbhnpw==", + "dev": true, + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "object-inspect": "^1.13.3", + "side-channel-list": "^1.0.0", + "side-channel-map": "^1.0.1", + "side-channel-weakmap": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/side-channel-list": { + "version": "1.0.0", + "resolved": "https://registry.npmmirror.com/side-channel-list/-/side-channel-list-1.0.0.tgz", + "integrity": "sha512-FCLHtRD/gnpCiCHEiJLOwdmFP+wzCmDEkc9y7NsYxeF4u7Btsn1ZuwgwJGxImImHicJArLP4R0yX4c2KCrMrTA==", + "dev": true, + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "object-inspect": "^1.13.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/side-channel-map": { + "version": "1.0.1", + "resolved": "https://registry.npmmirror.com/side-channel-map/-/side-channel-map-1.0.1.tgz", + "integrity": "sha512-VCjCNfgMsby3tTdo02nbjtM/ewra6jPHmpThenkTYh8pG9ucZ/1P8So4u4FGBek/BjpOVsDCMoLA/iuBKIFXRA==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.2", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.5", + "object-inspect": "^1.13.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/side-channel-weakmap": { + "version": "1.0.2", + "resolved": "https://registry.npmmirror.com/side-channel-weakmap/-/side-channel-weakmap-1.0.2.tgz", + "integrity": "sha512-WPS/HvHQTYnHisLo9McqBHOJk2FkHO/tlpvldyrnem4aeQp4hai3gythswg6p01oSoTl58rcpiFAjF2br2Ak2A==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.2", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.5", + "object-inspect": "^1.13.3", + "side-channel-map": "^1.0.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/signal-exit": { + "version": "3.0.7", + "resolved": "https://registry.npmmirror.com/signal-exit/-/signal-exit-3.0.7.tgz", + "integrity": "sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ==", + "dev": true + }, + "node_modules/sockjs": { + "version": "0.3.24", + "resolved": "https://registry.npmmirror.com/sockjs/-/sockjs-0.3.24.tgz", + "integrity": "sha512-GJgLTZ7vYb/JtPSSZ10hsOYIvEYsjbNU+zPdIHcUaWVNUEPivzxku31865sSSud0Da0W4lEeOPlmw93zLQchuQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "faye-websocket": "^0.11.3", + "uuid": "^8.3.2", + "websocket-driver": "^0.7.4" + } + }, + "node_modules/source-map": { + "version": "0.6.1", + "resolved": "https://registry.npmmirror.com/source-map/-/source-map-0.6.1.tgz", + "integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==", + "license": "BSD-3-Clause", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/source-map-js": { + "version": "1.2.1", + "resolved": "https://registry.npmmirror.com/source-map-js/-/source-map-js-1.2.1.tgz", + "integrity": "sha512-UXWMKhLOwVKb728IUtQPXxfYU+usdybtUrK/8uGE8CQMvrhOpwvzDBwj0QhSL7MQc7vIsISBG8VQ8+IDQxpfQA==", + "license": "BSD-3-Clause", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/source-map-support": { + "version": "0.5.21", + "resolved": "https://registry.npmmirror.com/source-map-support/-/source-map-support-0.5.21.tgz", + "integrity": "sha512-uBHU3L3czsIyYXKX88fdrGovxdSCoTGDRZ6SYXtSRxLZUzHg5P/66Ht6uoUlHu9EZod+inXhKo3qQgwXUT/y1w==", + "license": "MIT", + "dependencies": { + "buffer-from": "^1.0.0", + "source-map": "^0.6.0" + } + }, + "node_modules/spdy": { + "version": "4.0.2", + "resolved": "https://registry.npmmirror.com/spdy/-/spdy-4.0.2.tgz", + "integrity": "sha512-r46gZQZQV+Kl9oItvl1JZZqJKGr+oEkB08A6BzkiR7593/7IbtuncXHd2YoYeTsG4157ZssMu9KYvUHLcjcDoA==", + "dev": true, + "license": "MIT", + "dependencies": { + "debug": "^4.1.0", + "handle-thing": "^2.0.0", + "http-deceiver": "^1.2.7", + "select-hose": "^2.0.0", + "spdy-transport": "^3.0.0" + }, + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/spdy-transport": { + "version": "3.0.0", + "resolved": "https://registry.npmmirror.com/spdy-transport/-/spdy-transport-3.0.0.tgz", + "integrity": "sha512-hsLVFE5SjA6TCisWeJXFKniGGOpBgMLmerfO2aCyCU5s7nJ/rpAepqmFifv/GCbSbueEeAJJnmSQ2rKC/g8Fcw==", + "dev": true, + "license": "MIT", + "dependencies": { + "debug": "^4.1.0", + "detect-node": "^2.0.4", + "hpack.js": "^2.1.6", + "obuf": "^1.1.2", + "readable-stream": "^3.0.6", + "wbuf": "^1.7.3" + } + }, + "node_modules/spdy-transport/node_modules/debug": { + "version": "4.4.0", + "resolved": "https://registry.npmmirror.com/debug/-/debug-4.4.0.tgz", + "integrity": "sha512-6WTZ/IxCY/T6BALoZHaE4ctp9xm+Z5kY/pzYaCHRFeyVhojxlrm+46y68HA6hr0TcwEssoxNiDEUJQjfPZ/RYA==", + "dev": true, + "license": "MIT", + "dependencies": { + "ms": "^2.1.3" + }, + "engines": { + "node": ">=6.0" + }, + "peerDependenciesMeta": { + "supports-color": { + "optional": true + } + } + }, + "node_modules/spdy-transport/node_modules/ms": { + "version": "2.1.3", + "resolved": "https://registry.npmmirror.com/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", + "dev": true, + "license": "MIT" + }, + "node_modules/spdy/node_modules/debug": { + "version": "4.4.0", + "resolved": "https://registry.npmmirror.com/debug/-/debug-4.4.0.tgz", + "integrity": "sha512-6WTZ/IxCY/T6BALoZHaE4ctp9xm+Z5kY/pzYaCHRFeyVhojxlrm+46y68HA6hr0TcwEssoxNiDEUJQjfPZ/RYA==", + "dev": true, + "license": "MIT", + "dependencies": { + "ms": "^2.1.3" + }, + "engines": { + "node": ">=6.0" + }, + "peerDependenciesMeta": { + "supports-color": { + "optional": true + } + } + }, + "node_modules/spdy/node_modules/ms": { + "version": "2.1.3", + "resolved": "https://registry.npmmirror.com/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", + "dev": true, + "license": "MIT" + }, + "node_modules/statuses": { + "version": "2.0.1", + "resolved": "https://registry.npmmirror.com/statuses/-/statuses-2.0.1.tgz", + "integrity": "sha512-RwNA9Z/7PrK06rYLIzFMlaF+l73iwpzsqRIFgbMLbTcLD6cOao82TaWefPXQvB2fOC4AjuYSEndS7N/mTCbkdQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/string_decoder": { + "version": "1.3.0", + "resolved": "https://registry.npmmirror.com/string_decoder/-/string_decoder-1.3.0.tgz", + "integrity": "sha512-hkRX8U1WjJFd8LsDJ2yQ/wWWxaopEsABU1XfkM8A+j0+85JAGppt16cr1Whg6KIbb4okU6Mql6BOj+uup/wKeA==", + "dev": true, + "license": "MIT", + "dependencies": { + "safe-buffer": "~5.2.0" + } + }, + "node_modules/strip-ansi": { + "version": "6.0.1", + "resolved": "https://registry.npmmirror.com/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "dev": true, + "license": "MIT", + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/strip-final-newline": { + "version": "2.0.0", + "resolved": "https://registry.npmmirror.com/strip-final-newline/-/strip-final-newline-2.0.0.tgz", + "integrity": "sha512-BrpvfNAE3dcvq7ll3xVumzjKjZQ5tI1sEUIKr3Uoks0XUl45St3FlatVqef9prk4jRDzhW6WZg+3bk93y6pLjA==", + "dev": true, + "engines": { + "node": ">=6" + } + }, + "node_modules/style-loader": { + "version": "4.0.0", + "resolved": "https://registry.npmmirror.com/style-loader/-/style-loader-4.0.0.tgz", + "integrity": "sha512-1V4WqhhZZgjVAVJyt7TdDPZoPBPNHbekX4fWnCJL1yQukhCeZhJySUL+gL9y6sNdN95uEOS83Y55SqHcP7MzLA==", + "license": "MIT", + "engines": { + "node": ">= 18.12.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" + }, + "peerDependencies": { + "webpack": "^5.27.0" + } + }, + "node_modules/supports-color": { + "version": "7.2.0", + "resolved": "https://registry.npmmirror.com/supports-color/-/supports-color-7.2.0.tgz", + "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "dev": true, + "license": "MIT", + "dependencies": { + "has-flag": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/supports-preserve-symlinks-flag": { + "version": "1.0.0", + "resolved": "https://registry.npmmirror.com/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz", + "integrity": "sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/tapable": { + "version": "2.2.1", + "resolved": "https://registry.npmmirror.com/tapable/-/tapable-2.2.1.tgz", + "integrity": "sha512-GNzQvQTOIP6RyTfE2Qxb8ZVlNmw0n88vp1szwWRimP02mnTsx3Wtn5qRdqY9w2XduFNUgvOwhNnQsjwCp+kqaQ==", + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/terser": { + "version": "5.39.0", + "resolved": "https://registry.npmmirror.com/terser/-/terser-5.39.0.tgz", + "integrity": "sha512-LBAhFyLho16harJoWMg/nZsQYgTrg5jXOn2nCYjRUcZZEdE3qa2zb8QEDRUGVZBW4rlazf2fxkg8tztybTaqWw==", + "license": "BSD-2-Clause", + "dependencies": { + "@jridgewell/source-map": "^0.3.3", + "acorn": "^8.8.2", + "commander": "^2.20.0", + "source-map-support": "~0.5.20" + }, + "bin": { + "terser": "bin/terser" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/terser-webpack-plugin": { + "version": "5.3.14", + "resolved": "https://registry.npmmirror.com/terser-webpack-plugin/-/terser-webpack-plugin-5.3.14.tgz", + "integrity": "sha512-vkZjpUjb6OMS7dhV+tILUW6BhpDR7P2L/aQSAv+Uwk+m8KATX9EccViHTJR2qDtACKPIYndLGCyl3FMo+r2LMw==", + "license": "MIT", + "dependencies": { + "@jridgewell/trace-mapping": "^0.3.25", + "jest-worker": "^27.4.5", + "schema-utils": "^4.3.0", + "serialize-javascript": "^6.0.2", + "terser": "^5.31.1" + }, + "engines": { + "node": ">= 10.13.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" + }, + "peerDependencies": { + "webpack": "^5.1.0" + }, + "peerDependenciesMeta": { + "@swc/core": { + "optional": true + }, + "esbuild": { + "optional": true + }, + "uglify-js": { + "optional": true + } + } + }, + "node_modules/thunky": { + "version": "1.1.0", + "resolved": "https://registry.npmmirror.com/thunky/-/thunky-1.1.0.tgz", + "integrity": "sha512-eHY7nBftgThBqOyHGVN+l8gF0BucP09fMo0oO/Lb0w1OF80dJv+lDVpXG60WMQvkcxAkNybKsrEIE3ZtKGmPrA==", + "dev": true, + "license": "MIT" + }, + "node_modules/to-regex-range": { + "version": "5.0.1", + "resolved": "https://registry.npmmirror.com/to-regex-range/-/to-regex-range-5.0.1.tgz", + "integrity": "sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-number": "^7.0.0" + }, + "engines": { + "node": ">=8.0" + } + }, + "node_modules/toidentifier": { + "version": "1.0.1", + "resolved": "https://registry.npmmirror.com/toidentifier/-/toidentifier-1.0.1.tgz", + "integrity": "sha512-o5sSPKEkg/DIQNmH43V0/uerLrpzVedkUh8tGNvaeXpfpuwjKenlSox/2O/BTlZUtEe+JG7s5YhEz608PlAHRA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.6" + } + }, + "node_modules/ts-loader": { + "version": "9.5.2", + "resolved": "https://registry.npmmirror.com/ts-loader/-/ts-loader-9.5.2.tgz", + "integrity": "sha512-Qo4piXvOTWcMGIgRiuFa6nHNm+54HbYaZCKqc9eeZCLRy3XqafQgwX2F7mofrbJG3g7EEb+lkiR+z2Lic2s3Zw==", + "dev": true, + "license": "MIT", + "dependencies": { + "chalk": "^4.1.0", + "enhanced-resolve": "^5.0.0", + "micromatch": "^4.0.0", + "semver": "^7.3.4", + "source-map": "^0.7.4" + }, + "engines": { + "node": ">=12.0.0" + }, + "peerDependencies": { + "typescript": "*", + "webpack": "^5.0.0" + } + }, + "node_modules/ts-loader/node_modules/source-map": { + "version": "0.7.4", + "resolved": "https://registry.npmmirror.com/source-map/-/source-map-0.7.4.tgz", + "integrity": "sha512-l3BikUxvPOcn5E74dZiq5BGsTb5yEwhaTSzccU6t4sDOH8NWJCstKO5QT2CvtFoK6F0saL7p9xHAqHOlCPJygA==", + "dev": true, + "license": "BSD-3-Clause", + "engines": { + "node": ">= 8" + } + }, + "node_modules/tslib": { + "version": "2.8.1", + "resolved": "https://registry.npmmirror.com/tslib/-/tslib-2.8.1.tgz", + "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", + "dev": true, + "license": "0BSD" + }, + "node_modules/type-is": { + "version": "1.6.18", + "resolved": "https://registry.npmmirror.com/type-is/-/type-is-1.6.18.tgz", + "integrity": "sha512-TkRKr9sUTxEH8MdfuCSP7VizJyzRNMjj2J2do2Jr3Kym598JVdEksuzPQCnlFPW4ky9Q+iA+ma9BGm06XQBy8g==", + "dev": true, + "license": "MIT", + "dependencies": { + "media-typer": "0.3.0", + "mime-types": "~2.1.24" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/typescript": { + "version": "5.8.2", + "resolved": "https://registry.npmmirror.com/typescript/-/typescript-5.8.2.tgz", + "integrity": "sha512-aJn6wq13/afZp/jT9QZmwEjDqqvSGp1VT5GVg+f/t6/oVyrgXM6BY1h9BRh/O5p3PlUPAe+WuiEZOmb/49RqoQ==", + "dev": true, + "license": "Apache-2.0", + "bin": { + "tsc": "bin/tsc", + "tsserver": "bin/tsserver" + }, + "engines": { + "node": ">=14.17" + } + }, + "node_modules/unpipe": { + "version": "1.0.0", + "resolved": "https://registry.npmmirror.com/unpipe/-/unpipe-1.0.0.tgz", + "integrity": "sha512-pjy2bYhSsufwWlKwPc+l3cN7+wuJlK6uz0YdJEOlQDbl6jo/YlPi4mb8agUkVC8BF7V8NuzeyPNqRksA3hztKQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/update-browserslist-db": { + "version": "1.1.3", + "resolved": "https://registry.npmmirror.com/update-browserslist-db/-/update-browserslist-db-1.1.3.tgz", + "integrity": "sha512-UxhIZQ+QInVdunkDAaiazvvT/+fXL5Osr0JZlJulepYu6Jd7qJtDZjlur0emRlT71EN3ScPoE7gvsuIKKNavKw==", + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/browserslist" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/browserslist" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "MIT", + "dependencies": { + "escalade": "^3.2.0", + "picocolors": "^1.1.1" + }, + "bin": { + "update-browserslist-db": "cli.js" + }, + "peerDependencies": { + "browserslist": ">= 4.21.0" + } + }, + "node_modules/util-deprecate": { + "version": "1.0.2", + "resolved": "https://registry.npmmirror.com/util-deprecate/-/util-deprecate-1.0.2.tgz", + "integrity": "sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==", + "license": "MIT" + }, + "node_modules/utila": { + "version": "0.4.0", + "resolved": "https://registry.npmmirror.com/utila/-/utila-0.4.0.tgz", + "integrity": "sha512-Z0DbgELS9/L/75wZbro8xAnT50pBVFQZ+hUEueGDU5FN51YSCYM+jdxsfCiHjwNP/4LCDD0i/graKpeBnOXKRA==", + "dev": true, + "license": "MIT" + }, + "node_modules/utils-merge": { + "version": "1.0.1", + "resolved": "https://registry.npmmirror.com/utils-merge/-/utils-merge-1.0.1.tgz", + "integrity": "sha512-pMZTvIkT1d+TFGvDOqodOclx0QWkkgi6Tdoa8gC8ffGAAqz9pzPTZWAybbsHHoED/ztMtkv/VoYTYyShUn81hA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4.0" + } + }, + "node_modules/uuid": { + "version": "8.3.2", + "resolved": "https://registry.npmmirror.com/uuid/-/uuid-8.3.2.tgz", + "integrity": "sha512-+NYs2QeMWy+GWFOEm9xnn6HCDp0l7QBD7ml8zLUmJ+93Q5NF0NocErnwkTkXVFNiX3/fpC6afS8Dhb/gz7R7eg==", + "dev": true, + "license": "MIT", + "bin": { + "uuid": "dist/bin/uuid" + } + }, + "node_modules/vary": { + "version": "1.1.2", + "resolved": "https://registry.npmmirror.com/vary/-/vary-1.1.2.tgz", + "integrity": "sha512-BNGbWLfd0eUPabhkXUVm0j8uuvREyTh5ovRa/dyow/BqAbZJyC+5fU+IzQOzmAKzYqYRAISoRhdQr3eIZ/PXqg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/watchpack": { + "version": "2.4.2", + "resolved": "https://registry.npmmirror.com/watchpack/-/watchpack-2.4.2.tgz", + "integrity": "sha512-TnbFSbcOCcDgjZ4piURLCbJ3nJhznVh9kw6F6iokjiFPl8ONxe9A6nMDVXDiNbrSfLILs6vB07F7wLBrwPYzJw==", + "license": "MIT", + "dependencies": { + "glob-to-regexp": "^0.4.1", + "graceful-fs": "^4.1.2" + }, + "engines": { + "node": ">=10.13.0" + } + }, + "node_modules/wbuf": { + "version": "1.7.3", + "resolved": "https://registry.npmmirror.com/wbuf/-/wbuf-1.7.3.tgz", + "integrity": "sha512-O84QOnr0icsbFGLS0O3bI5FswxzRr8/gHwWkDlQFskhSPryQXvrTMxjxGP4+iWYoauLoBvfDpkrOauZ+0iZpDA==", + "dev": true, + "license": "MIT", + "dependencies": { + "minimalistic-assert": "^1.0.0" + } + }, + "node_modules/webpack": { + "version": "5.98.0", + "resolved": "https://registry.npmmirror.com/webpack/-/webpack-5.98.0.tgz", + "integrity": "sha512-UFynvx+gM44Gv9qFgj0acCQK2VE1CtdfwFdimkapco3hlPCJ/zeq73n2yVKimVbtm+TnApIugGhLJnkU6gjYXA==", + "license": "MIT", + "dependencies": { + "@types/eslint-scope": "^3.7.7", + "@types/estree": "^1.0.6", + "@webassemblyjs/ast": "^1.14.1", + "@webassemblyjs/wasm-edit": "^1.14.1", + "@webassemblyjs/wasm-parser": "^1.14.1", + "acorn": "^8.14.0", + "browserslist": "^4.24.0", + "chrome-trace-event": "^1.0.2", + "enhanced-resolve": "^5.17.1", + "es-module-lexer": "^1.2.1", + "eslint-scope": "5.1.1", + "events": "^3.2.0", + "glob-to-regexp": "^0.4.1", + "graceful-fs": "^4.2.11", + "json-parse-even-better-errors": "^2.3.1", + "loader-runner": "^4.2.0", + "mime-types": "^2.1.27", + "neo-async": "^2.6.2", + "schema-utils": "^4.3.0", + "tapable": "^2.1.1", + "terser-webpack-plugin": "^5.3.11", + "watchpack": "^2.4.1", + "webpack-sources": "^3.2.3" + }, + "bin": { + "webpack": "bin/webpack.js" + }, + "engines": { + "node": ">=10.13.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" + }, + "peerDependenciesMeta": { + "webpack-cli": { + "optional": true + } + } + }, + "node_modules/webpack-cli": { + "version": "5.1.4", + "resolved": "https://registry.npmmirror.com/webpack-cli/-/webpack-cli-5.1.4.tgz", + "integrity": "sha512-pIDJHIEI9LR0yxHXQ+Qh95k2EvXpWzZ5l+d+jIo+RdSm9MiHfzazIxwwni/p7+x4eJZuvG1AJwgC4TNQ7NRgsg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@discoveryjs/json-ext": "^0.5.0", + "@webpack-cli/configtest": "^2.1.1", + "@webpack-cli/info": "^2.0.2", + "@webpack-cli/serve": "^2.0.5", + "colorette": "^2.0.14", + "commander": "^10.0.1", + "cross-spawn": "^7.0.3", + "envinfo": "^7.7.3", + "fastest-levenshtein": "^1.0.12", + "import-local": "^3.0.2", + "interpret": "^3.1.1", + "rechoir": "^0.8.0", + "webpack-merge": "^5.7.3" + }, + "bin": { + "webpack-cli": "bin/cli.js" + }, + "engines": { + "node": ">=14.15.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" + }, + "peerDependencies": { + "webpack": "5.x.x" + }, + "peerDependenciesMeta": { + "@webpack-cli/generators": { + "optional": true + }, + "webpack-bundle-analyzer": { + "optional": true + }, + "webpack-dev-server": { + "optional": true + } + } + }, + "node_modules/webpack-cli/node_modules/commander": { + "version": "10.0.1", + "resolved": "https://registry.npmmirror.com/commander/-/commander-10.0.1.tgz", + "integrity": "sha512-y4Mg2tXshplEbSGzx7amzPwKKOCGuoSRP/CjEdwwk0FOGlUbq6lKuoyDZTNZkmxHdJtp54hdfY/JUrdL7Xfdug==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=14" + } + }, + "node_modules/webpack-dev-middleware": { + "version": "5.3.4", + "resolved": "https://registry.npmmirror.com/webpack-dev-middleware/-/webpack-dev-middleware-5.3.4.tgz", + "integrity": "sha512-BVdTqhhs+0IfoeAf7EoH5WE+exCmqGerHfDM0IL096Px60Tq2Mn9MAbnaGUe6HiMa41KMCYF19gyzZmBcq/o4Q==", + "dev": true, + "dependencies": { + "colorette": "^2.0.10", + "memfs": "^3.4.3", + "mime-types": "^2.1.31", + "range-parser": "^1.2.1", + "schema-utils": "^4.0.0" + }, + "engines": { + "node": ">= 12.13.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" + }, + "peerDependencies": { + "webpack": "^4.0.0 || ^5.0.0" + } + }, + "node_modules/webpack-dev-server": { + "version": "4.15.1", + "resolved": "https://registry.npmmirror.com/webpack-dev-server/-/webpack-dev-server-4.15.1.tgz", + "integrity": "sha512-5hbAst3h3C3L8w6W4P96L5vaV0PxSmJhxZvWKYIdgxOQm8pNZ5dEOmmSLBVpP85ReeyRt6AS1QJNyo/oFFPeVA==", + "dev": true, + "dependencies": { + "@types/bonjour": "^3.5.9", + "@types/connect-history-api-fallback": "^1.3.5", + "@types/express": "^4.17.13", + "@types/serve-index": "^1.9.1", + "@types/serve-static": "^1.13.10", + "@types/sockjs": "^0.3.33", + "@types/ws": "^8.5.5", + "ansi-html-community": "^0.0.8", + "bonjour-service": "^1.0.11", + "chokidar": "^3.5.3", + "colorette": "^2.0.10", + "compression": "^1.7.4", + "connect-history-api-fallback": "^2.0.0", + "default-gateway": "^6.0.3", + "express": "^4.17.3", + "graceful-fs": "^4.2.6", + "html-entities": "^2.3.2", + "http-proxy-middleware": "^2.0.3", + "ipaddr.js": "^2.0.1", + "launch-editor": "^2.6.0", + "open": "^8.0.9", + "p-retry": "^4.5.0", + "rimraf": "^3.0.2", + "schema-utils": "^4.0.0", + "selfsigned": "^2.1.1", + "serve-index": "^1.9.1", + "sockjs": "^0.3.24", + "spdy": "^4.0.2", + "webpack-dev-middleware": "^5.3.1", + "ws": "^8.13.0" + }, + "bin": { + "webpack-dev-server": "bin/webpack-dev-server.js" + }, + "engines": { + "node": ">= 12.13.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" + }, + "peerDependencies": { + "webpack": "^4.37.0 || ^5.0.0" + }, + "peerDependenciesMeta": { + "webpack": { + "optional": true + }, + "webpack-cli": { + "optional": true + } + } + }, + "node_modules/webpack-dev-server/node_modules/rimraf": { + "version": "3.0.2", + "resolved": "https://registry.npmmirror.com/rimraf/-/rimraf-3.0.2.tgz", + "integrity": "sha512-JZkJMZkAGFFPP2YqXZXPbMlMBgsxzE8ILs4lMIX/2o0L9UBw9O/Y3o6wFw/i9YLapcUJWwqbi3kdxIPdC62TIA==", + "deprecated": "Rimraf versions prior to v4 are no longer supported", + "dev": true, + "dependencies": { + "glob": "^7.1.3" + }, + "bin": { + "rimraf": "bin.js" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/webpack-merge": { + "version": "5.10.0", + "resolved": "https://registry.npmmirror.com/webpack-merge/-/webpack-merge-5.10.0.tgz", + "integrity": "sha512-+4zXKdx7UnO+1jaN4l2lHVD+mFvnlZQP/6ljaJVb4SZiwIKeUnrT5l0gkT8z+n4hKpC+jpOv6O9R+gLtag7pSA==", + "dev": true, + "license": "MIT", + "dependencies": { + "clone-deep": "^4.0.1", + "flat": "^5.0.2", + "wildcard": "^2.0.0" + }, + "engines": { + "node": ">=10.0.0" + } + }, + "node_modules/webpack-sources": { + "version": "3.2.3", + "resolved": "https://registry.npmmirror.com/webpack-sources/-/webpack-sources-3.2.3.tgz", + "integrity": "sha512-/DyMEOrDgLKKIG0fmvtz+4dUX/3Ghozwgm6iPp8KRhvn+eQf9+Q7GWxVNMk3+uCPWfdXYC4ExGBckIXdFEfH1w==", + "license": "MIT", + "engines": { + "node": ">=10.13.0" + } + }, + "node_modules/websocket-driver": { + "version": "0.7.4", + "resolved": "https://registry.npmmirror.com/websocket-driver/-/websocket-driver-0.7.4.tgz", + "integrity": "sha512-b17KeDIQVjvb0ssuSDF2cYXSg2iztliJ4B9WdsuB6J952qCPKmnVq4DyW5motImXHDC1cBT/1UezrJVsKw5zjg==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "http-parser-js": ">=0.5.1", + "safe-buffer": ">=5.1.0", + "websocket-extensions": ">=0.1.1" + }, + "engines": { + "node": ">=0.8.0" + } + }, + "node_modules/websocket-extensions": { + "version": "0.1.4", + "resolved": "https://registry.npmmirror.com/websocket-extensions/-/websocket-extensions-0.1.4.tgz", + "integrity": "sha512-OqedPIGOfsDlo31UNwYbCFMSaO9m9G/0faIHj5/dZFDMFqPTcx6UwqyOy3COEaEOg/9VsGIpdqn62W5KhoKSpg==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=0.8.0" + } + }, + "node_modules/which": { + "version": "2.0.2", + "resolved": "https://registry.npmmirror.com/which/-/which-2.0.2.tgz", + "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", + "license": "ISC", + "dependencies": { + "isexe": "^2.0.0" + }, + "bin": { + "node-which": "bin/node-which" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/wildcard": { + "version": "2.0.1", + "resolved": "https://registry.npmmirror.com/wildcard/-/wildcard-2.0.1.tgz", + "integrity": "sha512-CC1bOL87PIWSBhDcTrdeLo6eGT7mCFtrg0uIJtqJUFyK+eJnzl8A1niH56uu7KMa5XFrtiV+AQuHO3n7DsHnLQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/wrappy": { + "version": "1.0.2", + "resolved": "https://registry.npmmirror.com/wrappy/-/wrappy-1.0.2.tgz", + "integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==", + "license": "ISC" + }, + "node_modules/ws": { + "version": "8.18.3", + "resolved": "https://registry.npmmirror.com/ws/-/ws-8.18.3.tgz", + "integrity": "sha512-PEIGCY5tSlUt50cqyMXfCzX+oOPqN0vuGqWzbcJ2xvnkzkq46oOpz7dQaTDBdfICb4N14+GARUDw2XV2N4tvzg==", + "dev": true, + "engines": { + "node": ">=10.0.0" + }, + "peerDependencies": { + "bufferutil": "^4.0.1", + "utf-8-validate": ">=5.0.2" + }, + "peerDependenciesMeta": { + "bufferutil": { + "optional": true + }, + "utf-8-validate": { + "optional": true + } + } + } + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/package.json b/plugins/tensorboard-plugins/tb_graph_ascend/fe/package.json new file mode 100644 index 0000000000000000000000000000000000000000..3185e4db9ccbed0919d9385ac721e9f4c0df2dbb --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/package.json @@ -0,0 +1,71 @@ +{ + "name": "tb-graph-ascend", + "version": "0.1.0", + "private": "true", + "main": "index.js", + "scripts": { + "dev": "webpack serve --config webpack.dev.js", + "buildLinux": "cross-env NODE_ENV=production webpack && cp dist/index.html ../server/static/", + "buildWin": "cross-env NODE_ENV=production webpack && copy dist\\index.html ..\\server\\static\\", + "prettier": "prettier --config ./.prettierrc --write ./src/**/*.ts" + }, + "devDependencies": { + "@types/d3": "^7.4.3", + "@types/lodash": "^4.17.20", + "@types/node": "^16.4.13", + "@types/offscreencanvas": "^2019.6.3", + "@types/requirejs": "^2.1.33", + "@types/resize-observer-browser": "^0.1.6", + "@types/three": "^0.131.0", + "html-loader": "^5.1.0", + "html-webpack-plugin": "^5.6.3", + "inline-chunk-html-plugin": "^1.1.1", + "ts-loader": "^9.5.1", + "tslib": "^2.6.2", + "typescript": "^5.4.5", + "webpack": "^5.96.1", + "webpack-cli": "^5.1.4", + "webpack-dev-server": "4.15.1" + }, + "dependencies": { + "@jridgewell/gen-mapping": "^0.3.12", + "@polymer/decorators": "^3.0.0", + "@polymer/iron-collapse": "^3.0.1", + "@polymer/iron-icon": "^3.0.1", + "@polymer/paper-button": "^3.0.1", + "@polymer/paper-checkbox": "^3.1.0", + "@polymer/paper-dialog": "^3.0.1", + "@polymer/paper-tooltip": "^3.0.1", + "@polymer/polymer": "^3.5.1", + "@vaadin/button": "24.6.5", + "@vaadin/checkbox": "24.6.5", + "@vaadin/checkbox-group": "^24.6.5", + "@vaadin/combo-box": "24.6.5", + "@vaadin/confirm-dialog": "24.6.5", + "@vaadin/context-menu": "24.6.5", + "@vaadin/details": "24.6.5", + "@vaadin/grid": "24.6.5", + "@vaadin/icon": "24.6.5", + "@vaadin/icons": "24.6.5", + "@vaadin/notification": "24.6.5", + "@vaadin/progress-bar": "24.6.5", + "@vaadin/select": "24.6.5", + "@vaadin/tabs": "24.6.5", + "@vaadin/tabsheet": "24.6.5", + "@vaadin/text-field": "24.6.5", + "@vaadin/tooltip": "24.6.5", + "axios": "^1.8.4", + "brace-expansion": "^1.1.12", + "clean-webpack-plugin": "^4.0.0", + "cross-env": "^7.0.3", + "css-loader": "^7.1.2", + "d3": "^7.9.0", + "dagre": "^0.8.5", + "form-data": "^4.0.4", + "i18next": "^23.16.8", + "i18next-browser-languagedetector": "^7.2.2", + "lodash": "^4.17.21", + "prettier": "^3.4.2", + "style-loader": "^4.0.0" + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/common/constant.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/common/constant.ts new file mode 100644 index 0000000000000000000000000000000000000000..986d12c65fb4265a0f06bb3534862e977d656b02 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/common/constant.ts @@ -0,0 +1,83 @@ +/* Copyright (c) 2025, Huawei Technologies. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// NPU侧模型的节点前缀 +export const NPU_PREFIX = 'N___'; +// 标杆侧模型的节点前缀 +export const BENCH_PREFIX = 'B___'; +// 未匹配节点颜色 +export const UNMATCHED_COLOR = '#C7C7C7'; + +export const JSON_TYPE = 'json' +export const DB_TYPE = 'db' + +// 双图下单个图形的最小宽度 +export const MIN_GRAPG_WIDTH = 200; + +// 预设颜色 +export const defaultColorSetting = [ + { key: '#FFFCF3', values: [0, 0.2] }, + { key: '#FFEDBE', values: [0.2, 0.4] }, + { key: '#FFDC7F', values: [0.4, 0.6] }, + { key: '#FFC62E', values: [0.6, 0.8] }, + { key: '#ff704d', values: [0.8, 1] }, +]; +// 预设颜色设置项 +export const defaultColorSelects = [{ key: 'NaN', values: [NaN, NaN] }]; + +export enum NODE_TYPE { + MODULE = 0, // 圆角矩形,有可展开,不可展开两种情况,可展开的宽度较宽,不可展开,宽度较窄 + UNEXPAND_NODE = 1, // 椭圆形,不可展开,API + API_LIST = 9, // API列表 + MULTI_COLLECTION = 8, // 融合算子 +} + +// 渲染信息 +export const DURATION_TIME = 160; // 动画时间 +export const SELECTED_STROKE_COLOR = 'rgb(31, 63, 207)'; // 选中节点颜色 +export const BENCH_NODE_COLOR = 'rgba(255, 255, 255, 1)'; // 基准模型节点颜色 +export const BENCH_STROKE_COLOR = 'rgb(161, 161, 161)'; // 基准模型边框颜色 +export const NO_MATCHED_NODE_COLOR = 'rgb(199, 199, 199)'; // 未匹配节点颜色 +export const BASE_NODE_COLOR = 'rgb(255, 255, 255)'; // 基准节点颜色,没有精度信息、API、FUSION的填充色 +export const STROKE_WIDTH = 1.5; // 边框宽度 +export const SELECTED_STROKE_WIDTH = 2; // 边框颜色 + +export const MOVE_STEP = 40; // 移动步长 +export const SCALE_STEP = 0.2; // 缩放步长 + +export const MAX_SCALE = 3; // 最大缩放 +export const MIN_SCALE = 1; // 最小缩放 + +// 溢出检测颜色 +export enum OVERFLOW_COLOR { + medium = ' #B6C7FC', + high = ' #7E96F0', + critical = ' #4668B8', + default = 'rgb(199, 199, 199)', +} + +export const NODE_TYPE_STYLES = { + // 节点样式 + [NODE_TYPE.MODULE]: { strokeDasharray: '20,0', rx: '5', ry: '5' }, + [NODE_TYPE.UNEXPAND_NODE]: { strokeDasharray: '20,0', rx: '50%', ry: '50%', fontSize: 6 }, + [NODE_TYPE.API_LIST]: { strokeDasharray: '15,1', rx: '5', ry: '5' }, + [NODE_TYPE.MULTI_COLLECTION]: { strokeDasharray: '2,1', rx: '5', ry: '5' }, +}; + +export const PREFIX_MAP = { + Single: '', + NPU: NPU_PREFIX, + Bench: BENCH_PREFIX, +}; diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/common/graph-board-layout/index.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/common/graph-board-layout/index.ts new file mode 100644 index 0000000000000000000000000000000000000000..000bb979e86acd56d486d4b1512b73ac137ad25e --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/common/graph-board-layout/index.ts @@ -0,0 +1,168 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import { customElement } from '@polymer/decorators'; +import { html, PolymerElement } from '@polymer/polymer'; +import { DarkModeMixin } from '../../polymer/dark_mode_mixin'; +import './tensorboardColor'; + +@customElement('graph-board-layout') +class TfDashboardLayout extends DarkModeMixin(PolymerElement) { + static readonly template = html` + + +
+ +
+ + + `; + + _toggleSidebar(): void { + // 通过 ID 获取元素并隐藏 + const sidebar = this.shadowRoot?.querySelector('#sidebar'); + const sidebarToggle = this.shadowRoot?.querySelector('#sidebar-toggle'); + // 检查并切换 display 样式 + if (sidebar) { + sidebar?.classList.toggle('sider-hidden'); // 改为显示 + sidebarToggle?.classList.toggle('sidebar-toggle-fold'); // 改变箭头方向 + } + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/common/graph-board-layout/tensorboardColor.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/common/graph-board-layout/tensorboardColor.ts new file mode 100644 index 0000000000000000000000000000000000000000..e76ed139f9d2956cad1cb1c782358b9ee1975c00 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/common/graph-board-layout/tensorboardColor.ts @@ -0,0 +1,57 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +const style = document.createElement('style'); +style.setAttribute('is', 'custom-style'); +style.textContent = ` + :root { + --tb-orange-weak: #ffa726; + --tb-orange-strong: #f57c00; + --tb-orange-dark: #dc7320; + --tb-grey-darker: #e2e2e2; + --tb-grey-lighter: #f3f3f3; + --tb-ui-dark-accent: #757575; + --tb-ui-light-accent: #e0e0e0; + --tb-ui-border: var(--paper-grey-300); + --tb-graph-faded: #e0d4b3; + --tb-secondary-text-color: var(--paper-grey-800); + --tb-raised-button-shadow-color: rgba(0, 0, 0, 0.2); + --primary-background-color: #fff; + --secondary-background-color:rgb(247, 247, 247); + --tb-layout-background-color: #f5f5f5; + --tb-link: #1976d2; /* material blue 700. */ + --tb-link-visited: #7b1fa2; /* material purple 700. */ + } + + :root .dark-mode { + --tb-ui-border: var(--paper-grey-700); + --tb-ui-dark-accent: var(--paper-grey-400); + --tb-ui-light-accent: var(--paper-grey-600); + --tb-secondary-text-color: var(--paper-grey-400); + --tb-raised-button-shadow-color: rgba(255, 255, 255, 0.5); + --primary-text-color: #fff; + --secondary-text-color: var(--paper-grey-400); + --primary-background-color: #303030; /* material grey A400. */ + --secondary-background-color: #3a3a3a; + --tb-layout-background-color: #3a3a3a; + --tb-link: #42a5f5; /* material blue 400. */ + --tb-link-visited: #ba68c8; /* material purple 300. */ + /* Overrides paper-material */ + --shadow-elevation-2dp_-_box-shadow: 0 2px 2px 0 rgba(255, 255, 255, 0.14), + 0 1px 5px 0 rgba(255, 255, 255, 0.12), + 0 3px 1px -2px rgba(255, 255, 255, 0.2); + } +`; +document.head.appendChild(style); diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/common/i18n.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/common/i18n.ts new file mode 100644 index 0000000000000000000000000000000000000000..87fabb2ff4a75a547f5ae17448455348ca1e8110 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/common/i18n.ts @@ -0,0 +1,96 @@ +/* Copyright (c) 2025, Huawei Technologies. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import i18next from 'i18next'; +import LanguageDetector from 'i18next-browser-languagedetector'; + +i18next + .use(LanguageDetector) + .init({ + fallbackLng: 'zh-CN', + resources: { + 'en': { + translation: { + fit: "Fit Screen", + settings: "Settings", + function: 'function', + show_debug_minimap: "show debug minimap", + show_bench_minimap: "show bench minimap", + run: "Run", + tag: "Tag", + invalid_rank_id: "Tip: The target file does not exist", + data_side: "Data Side", + search_node: "Search Node", + node_list: "Node List", + debug: "Debug", + bench: "Bench", + accuracy_error: "Accuracy Error", + overflow: "Overflow", + match_accuracy_error: "Match Accuracy Error Node", + overflow_filter_node: "Overflow Filter Node", + no_matching_nodes: "No matching nodes", + precision_desc: { + summary: "The relative error between the statistical output of the debug side and the benchmark side of the node, the larger the value, the greater the precision gap, the darker the color mark, the relative error indicator (RelativeErr): | (debug value - benchmark value) / benchmark value |", + all: "The difference between the minimum double thousand indicator of all inputs and the minimum double thousandth indicator of all outputs of the node, reflecting the decline of the double thousand indicator, the larger the value, the greater the precision gap, the darker the color mark, the double thousandth precision indicator (One Thousandth Err Ratio): The relative error of each element in the tensor is compared with the corresponding benchmark data, the proportion of relative error less than one thousandth of the total number of elements, the closer the proportion is to 1, the better", + md5: "If the md5 value of any input or output of the node is different, it will be marked red" + }, + node_match: "Node Match", + select_match_config_file: "Select Match Config File", + select_match_config_file_desc: "Select the corresponding configuration file, read the matching node information, and match the corresponding node.", + node_search: "Node Search" + } + }, + 'zh-CN': { + translation: { + fit: "自适应屏幕", + settings: "设置", + function: "功能", + show_debug_minimap: "调试侧缩略图", + show_bench_minimap: "标杆侧缩略图", + run: "目录", + tag: "文件", + invalid_rank_id: "提示:目标文件不存在", + data_side: "数据侧", + search_node: "节点搜索", + node_list: "节点列表", + debug: "调试侧", + bench: "标杆侧", + accuracy_error: "精度误差", + overflow: "精度溢出", + match_accuracy_error: "符合精度误差节点", + overflow_filter_node: "符合溢出筛选节点", + no_matching_nodes: "无匹配节点", + precision_desc: { + "summary": "节点中调试侧和标杆侧输出的统计量相对误差,值越大精度差距越大,颜色标记越深,相对误差指标(RelativeErr):| (调试值 - 标杆值) / 标杆值 |", + "all": "节点中所有输入的最小双千指标和所有输出的最小双千分之一指标的差值,反映了双千指标的下降情况,值越大精度差距越大,颜色标记越深,双千分之一精度指标(One Thousandth Err Ratio):Tensor中的元素逐个与对应的标杆数据对比,相对误差小于千分之一的比例占总元素个数的比例,比例越接近1越好", + "md5": "节点中任意输入输出的md5值不同则标记为红色" + }, + node_match: "节点匹配", + select_match_config_file: "选择匹配配置文件", + select_match_config_file_desc: "选择对应配置文件,会读取匹配节点信息,并将对应节点进行匹配。", + node_search: "节点搜索" + } + } + }, + detection: { + order: ['navigator'] // 只使用浏览器语言检测 + }, + debug: false, + interpolation: { + escapeValue: false + } + }); + +export default i18next; \ No newline at end of file diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_ascend/index.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_ascend/index.ts new file mode 100644 index 0000000000000000000000000000000000000000..b085bee4647da0d1445d88f06077e6c9e2466973 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_ascend/index.ts @@ -0,0 +1,462 @@ +/* Copyright (c) 2025, Huawei Technologies. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { customElement, observe, property } from '@polymer/decorators'; +import { html, PolymerElement } from '@polymer/polymer'; +import { LegacyElementMixin } from '../polymer/legacy_element_mixin'; +import { LoadGraphFileInfoListType } from './type'; +import useGraphAscend from './useGraphAscend'; +import { formatBytes, safeJSONParse } from '../utils'; +import { isEmpty } from 'lodash'; +import '../graph_board/index'; +import '../graph_info_board/index'; +import '../graph_controls_board/index'; +import '../common/graph-board-layout'; +import '@vaadin/confirm-dialog' +import { Notification } from '@vaadin/notification'; +import request from '../utils/request'; +import type { SelectedItemType, SelectionType, ProgressType, GraphConfigType, GraphAllNodeType, NodeListType, UnmatchedNodeType } from './type'; + +@customElement('graph-ascend') +class TfGraphDashboard extends LegacyElementMixin(PolymerElement) { + static readonly template = html` + + + +
+
+ +
+ + +
+
+ +
+

如果您仍坚持继续,请知悉以下风险:

+
非授权路径访问可能存在信息泄露和文件内容篡改。 文件过大或格式异常,可能导致性能问题或服务中断。路径中存在软链接或权限不当,可能存在越权访问和数据篡改风险。
+

继续操作将由您自行承担相关后果。如非明确知晓风险,请取消操作并联系管理员处理。

+
+ +
+
+
+ + `; + + @property({ type: Object }) + metaDir: Record = {}; + + @property({ type: Object, notify: true }) + selection: SelectionType | null = null; + + @property({ type: Object, notify: true }) + nodelist: NodeListType = {} as NodeListType; + + @property({ type: Object, notify: true }) + unmatched: UnmatchedNodeType = {} as UnmatchedNodeType; + + @property({ type: Object, notify: true }) + matchedlist: any; + + @property({ type: String, notify: true }) + selectedNode: string = ''; + + @property({ type: String, notify: true }) + jumpToNode: string = ''; + + @property({ type: Object, notify: true }) + colors: any; + + @property({ type: Boolean, notify: true }) + isOverflowFilter: boolean = false; + + @property({ type: Object }) + progressData: ProgressType = { progress: 0, progressValue: 0, done: false }; + + @property({ type: Boolean }) + isSingleGraph: boolean = false; + + @property({ type: Array }) + microsteps: number[] = []; + + @property({ type: Array }) + steps: Array = [{ value: 0, label: '0' }]; + + @property({ type: Array }) + ranks: Array = [{ value: 0, label: '0' }]; + + + @property({ type: Array }) + overflowcheck: boolean = false; + + @property({ type: Object }) + tooltips: object = {}; + + @property({ type: Object }) + colorset: object = {}; + + @property({ type: Object }) + npuMatchNodes: object = {}; + + @property({ type: Object }) + benchMatchNodes: object = {}; + + @property({ type: Object }) + matchedConfigFiles: string[] = []; + @property({ type: Object }) + task: string = ''; + + @property({ type: Boolean }) + safeDialogOpened: boolean = false; + + @property({ type: Array }) + fileListError: Array = []; + + @property({ type: Object }) + needLoadAllNodeList: boolean = true; + + private currentSelection: SelectionType | null = null; + private useGraphAscend = useGraphAscend(); + private eventSource: EventSource | null = null; + + @observe('selection') + updateGraphData = () => { + if (!this.selection?.run || !this.selection?.tag) { + return; + } + const isFileChange = this.currentSelection?.run !== this.selection?.run || this.currentSelection?.tag !== this.selection?.tag; + const isDBChange = this.currentSelection?.rank !== this.selection?.rank || this.currentSelection?.step !== this.selection?.step; + if (isFileChange) { + switch (this.selection?.type) { + case 'json': + this.loadJSONGraphData(this.selection); + break; + case 'db': + this.loadDBGraphData(this.selection, true); + break; + default: + break; + } + } + else if (isDBChange) { + this.loadDBGraphData(this.selection, false); + } + else if (this.currentSelection?.microStep !== this.selection?.microStep) { + this.initGraphBoard(); // 只改变microsteps时,不重新加载图数据 + } + else { + return + } + this.set('needLoadAllNodeList', true); + this.currentSelection = this.selection; + }; + + override async ready(): Promise { + super.ready(); + const { data, error } = await this.useGraphAscend.loadGraphFileInfoList(true); + const safeDialog = this.shadowRoot?.querySelector('#safe-dialog') as HTMLElement; + safeDialog.addEventListener('cancel', this.onSafeDialogCancel as any); + if (!isEmpty(error)) { + this.set('safeDialogOpened', true); + this.set('fileListError', error); + } + this.set('metaDir', data); + document.addEventListener( + 'contextMenuTag-changed', + (event: any) => this.set('jumpToNode', event.detail?.nodeName), + { passive: true }, + ); + } + // 关闭默认安全模式继续 + onSafeDialogCancel = async () => { + const { data, error } = await this.useGraphAscend.loadGraphFileInfoList(false); + if (!isEmpty(error)) { + Notification.show('文件列表加载失败', { + position: 'middle', + duration: 2000, + theme: 'error', + }); + return; + } + this.set('metaDir', data); + } + loadDBGraphData = async (metaData: SelectionType, isInitDB: boolean = false) => { + if (isInitDB) { + this.progreesLoading('正在初始化数据库', '请稍后', { progress: 10, progressValue: 10, done: false }); + await request({ url: 'loadGraphData', method: 'GET', params: metaData }); + await this.loadGraphConfig(metaData) + } + this.progreesLoading('正在初始化图', '请稍后', { progress: 90, progressValue: 90, done: false }); + this.initGraphBoard(); // 先读取配置,再加载图,顺序很重要 + this.progreesLoading('初始化完成', '请稍后', { progress: 100, progressValue: 100, done: true }); + } + + + loadJSONGraphData = async (metaData: SelectionType) => { + if (this.eventSource) { + this.eventSource.close(); + this.eventSource = null; + } + + this.eventSource = new EventSource(`loadGraphData?run=${metaData.run}&tag=${metaData.tag}&type=${metaData.type}`); + this.eventSource.onmessage = async (e) => { + const data = safeJSONParse(e.data); + if (data?.error) { + this.progreesError('初始化图失败', data.error); + } + if (data?.status === 'reading') { + this.progressReading('正在读取文件', data); + } + if (data?.status === 'loading') { + if (data.done) { + this.eventSource?.close(); + this.eventSource = null; + try { + await this.loadGraphConfig(metaData) + this.initGraphBoard(); // 先读取配置,再加载图,顺序很重要 + this.progreesLoading('初始化完成', '请稍后', data); + } catch (error) { + this.progreesError('初始化图失败', error); + } + } else { + this.progreesLoading('正在解析文件', '正在初始化模型,请稍后.', data); + } + } + }; + + this.eventSource.onerror = (e) => { + if (!this.progressData || !this.progressData.done) { + this.progreesError('加载失败', '请检查文件格式是否正确'); + } + this.eventSource?.close(); + }; + } + + loadGraphConfig = async (metaData) => { + const { success, data, error } = await this.useGraphAscend.loadGraphConfig(metaData); + const config = data as GraphConfigType; + if (success) { + this.set('colors', config.colors); + this.set('tooltips', safeJSONParse(config.tooltips)); + this.set('overflowcheck', config.overflowCheck); + this.set('colorset', Object.entries(config.colors || {})); + this.set('isSingleGraph', config.isSingleGraph); + this.set('task', config.task); + this.set('matchedConfigFiles', ['未选择', ...config.matchedConfigFiles]); + const microstepsCount = Number(config.microSteps); + const ranks = config.ranks || [0]; + const steps = config.steps || [0]; + if (microstepsCount) { + const microstepsArray = Array.from({ length: microstepsCount + 1 }, (_, index) => ({ + label: index === 0 ? 'ALL' : String(index - 1), + value: index - 1, + })); + this.set('microsteps', microstepsArray); + } + if (ranks.length > 0) { + const ranksArray = ranks.map((rank) => ({ + label: rank, + value: rank, + })) + this.set('ranks', ranksArray); + } + if (steps.length > 0) { + const stepsArray = steps.map((step) => ({ + label: step, + value: step, + })) + this.set('steps', stepsArray); + } + } else { + Notification.show(`图配置加载失败:${error}`, { + position: 'middle', + duration: 2000, + theme: 'error', + }); + } + }; + + loadGraphAllNodeList = async (metaData: SelectionType) => { + const { success, data, error } = await this.useGraphAscend.loadGraphAllNodeList(metaData); + const allNodeList = data as GraphAllNodeType; + if (success) { + const nodelist = {} as NodeListType; + const unmatched = {} as UnmatchedNodeType; + if (this.isSingleGraph) { + nodelist.npu = allNodeList?.npuNodeList; + } else { + nodelist.npu = allNodeList?.npuNodeList; + nodelist.bench = allNodeList?.benchNodeList; + unmatched.npuNodeList = allNodeList?.npuUnMatchNodes; + unmatched.benchNodeList = allNodeList?.benchUnMatchNodes; + } + this.set('npuMatchNodes', allNodeList?.npuMatchNodes); + this.set('benchMatchNodes', allNodeList?.benchMatchNodes); + this.set('nodelist', nodelist); + this.set('unmatched', unmatched); + } else { + Notification.show(`图节点列表加载失败:${error}`, { + position: 'middle', + duration: 2000, + theme: 'error', + }); + } + }; + + initGraphBoard = () => { + (this.shadowRoot?.querySelector('#graph-board') as any)?.initGraphHierarchy(this.jumpToNode); + if (this.jumpToNode) { + this.set('selectedNode', this.jumpToNode); + this.set('jumpToNode', ''); + } + }; + + onFitTap(): void { + (this.shadowRoot?.querySelector('#graph-board') as any).fitScreen(); + } + + progressReading = (title, data) => { + data.progressValue = data.done ? 1 : data.progress / 100.0; + data.size = formatBytes(data.size); + data.read = formatBytes(data.read); + data.title = title; + data.info = `文件大小: ${data.size}, 已读取: ${data.read}`; + this.set('progressData', data); + }; + + progreesLoading = (title, info, progressData) => { + const data = { + ...progressData, + title, + info, + }; + data.progressValue = progressData.done ? 1 : progressData.progress / 100.0; + this.set('progressData', data); + }; + + progreesError = (title, info) => { + const data = { + ...this.progressData, + title, + info, + }; + this.updateStyles({ + '--progress-background-color': 'red', + '--progress-color': 'red', + }); + this.set('progressData', data); + }; +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_ascend/type/index.d.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_ascend/type/index.d.ts new file mode 100644 index 0000000000000000000000000000000000000000..dc9b7455e37f692a940c49f2fc79a3c4e7998982 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_ascend/type/index.d.ts @@ -0,0 +1,83 @@ +/* Copyright (c) 2025, Huawei Technologies. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +export interface ProgressType { + progress?: number; + progressValue?: number; + size?: number; + read?: number; + done?: boolean; +} + +export interface SelectedItemType { + value: number; + label: string; +} + +export interface SelectionType { + run: string; + tag: string; + type: 'json' | 'db'; + microStep?: number; + step?: number; + rank?: number; +} + +export interface GraphConfigType { + tooltips: string; + colors: Record; + overflowCheck: boolean; + microSteps: number; + isSingleGraph: boolean; + matchedConfigFiles: string[]; + task: string; + ranks: number[]; + steps: number[] +} + +export interface GraphAllNodeType { + npuNodeList: string[]; + benchNodeList: string[]; + npuUnMatchNodes: string[]; + benchUnMatchNodes: string[]; + npuMatchNodes: string[]; + benchMatchNodes: string[]; +} + +export interface NodeListType { + npu: string[]; + bench: string[]; +} + +export interface UnmatchedNodeType { + npuNodeList: string[]; + benchNodeList: string[]; +} + +export interface LoadGraphFileInfoListType { + data: { + [string]: string[]; + }; + error: [ + { + run: string; + tag: string; + info: string; + } + ]; +} \ No newline at end of file diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_ascend/useGraphAscend.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_ascend/useGraphAscend.ts new file mode 100644 index 0000000000000000000000000000000000000000..d340d8a6fb9562de017ff6301f7497291f13c38f --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_ascend/useGraphAscend.ts @@ -0,0 +1,56 @@ +/* Copyright (c) 2025, Huawei Technologies. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import request from '../utils/request'; +import { LoadGraphFileInfoListType, SelectionType } from './type'; +const useGraphAscend = () => { + + const loadGraphFileInfoList = async (isSafeCheck: boolean): Promise => { + try { + const params = { + isSafeCheck + }; + const result = await request({ url: 'load_meta_dir', method: 'GET', params: params }); + return result as unknown as LoadGraphFileInfoListType; + } catch (err) { + return { + data: {}, + error: [ + { + run: '', + tag: '', + info: '加载文件列表失败', + } + ], + }; + } + }; + const loadGraphConfig = async (metaData: SelectionType): Promise => { + const result = await request({ url: 'loadGraphConfigInfo', method: 'POST', data: { metaData } }); // 获取异步的 ArrayBuffer + return result; + }; + + const loadGraphAllNodeList = async (metaData: SelectionType): Promise => { + const result = await request({ url: 'loadGraphAllNodeList', method: 'POST', data: { metaData } }); // 获取异步的 ArrayBuffer + return result; + }; + + return { + loadGraphConfig, + loadGraphAllNodeList, + loadGraphFileInfoList, + }; +}; +export default useGraphAscend; diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_board/components/hierarchy/index.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_board/components/hierarchy/index.ts new file mode 100644 index 0000000000000000000000000000000000000000..111ea4f6b16b368d0b110db346ff4358169cf905 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_board/components/hierarchy/index.ts @@ -0,0 +1,774 @@ +/* Copyright (c) 2025, Huawei Technologies. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { PolymerElement, html } from '@polymer/polymer'; +import { customElement, property, observe } from '@polymer/decorators'; +import * as d3 from 'd3'; +import useGraph from './useGraph'; +import { changeGraphPosition } from '../../../utils/index'; +import { parseTransform } from '../../../utils/index'; +import { isEmpty, throttle } from 'lodash'; +import * as minimap from '../minimap/minimap'; +import { + NPU_PREFIX, + BENCH_PREFIX, + MOVE_STEP, + SCALE_STEP, + NODE_TYPE, + MAX_SCALE, + MIN_SCALE, + PREFIX_MAP, +} from '../../../common/constant'; +import '../minimap/index'; +import '@vaadin/context-menu'; +import { Notification } from '@vaadin/notification'; +import type { UseGraphType } from '../../type'; +import type { HierarchyNodeType, ContextMenuItem, PreProcessDataConfigType, GraphType } from '../../type'; +import type { ContextMenuItemSelectedEvent } from '@vaadin/context-menu'; +import type { SelectionType } from '../../../graph_ascend/type'; + +const EXPAND_MATCHED_NODE = 1; +const DATA_COMMUNICATION = 2; +const DATA_COMMUNICATION_TYEPE = { + send: '数据发送', + receive: '数据接收', + send_receive: '数据发送接收', +}; +@customElement('graph-hierarchy') +class Hierarchy extends PolymerElement { + static readonly template = html` + +
+ + + + + + +