From a1476d0549718eae99e70859ae821741914e5aee Mon Sep 17 00:00:00 2001 From: yangyidiao Date: Tue, 15 Apr 2025 20:39:29 +0800 Subject: [PATCH] 1 --- .gitee/ISSUE_TEMPLATE/bug-report.yml | 1 + .gitee/ISSUE_TEMPLATE/documentation.yml | 1 + .gitee/ISSUE_TEMPLATE/feature-request.yml | 1 + .gitee/ISSUE_TEMPLATE/question.yml | 1 + README.md | 9 + blacklist.txt | 8 + classify_rule.yaml | 14 +- cmake/scripts/utest/gen_tiling_data_stub.py | 3 +- docs/GroupedMatmul.md | 340 ++++ docs/GroupedMatmulV2.md | 342 ++++ docs/GroupedMatmulV3.md | 409 ++++ docs/GroupedMatmulV4.md | 453 +++++ ...76\350\256\241\344\273\213\347\273\215.md" | 115 ++ ...6\346\240\270\346\226\271\346\241\210.png" | Bin 0 -> 10765 bytes ...6\347\273\204\346\226\271\346\241\210.png" | Bin 0 -> 34477 bytes ...6\346\240\270\346\226\271\346\241\210.png" | Bin 0 -> 15095 bytes ...\257UB_Buffer\345\210\206\351\205\215.png" | Bin 0 -> 17929 bytes ...7\346\265\201\347\250\213\345\233\276.png" | Bin 0 -> 27396 bytes .../transformer/grouped_matmul/CMakeLists.txt | 57 + .../grouped_matmul_generate_data.py | 38 + .../grouped_matmul_print_result.py | 27 + .../grouped_matmul/grouped_matmul_utils.cpp | 199 ++ .../grouped_matmul/grouped_matmul_utils.h | 150 ++ .../grouped_matmul/run_grouped_matmul_case.sh | 32 + .../grouped_matmul/test_grouped_matmul_v2.cpp | 115 ++ .../grouped_matmul/test_grouped_matmul_v3.cpp | 128 ++ .../grouped_matmul/test_grouped_matmul_v4.cpp | 135 ++ .../grouped_matmul/grouped_matmul.cpp | 202 ++ .../grouped_matmul/grouped_matmul.h | 453 +++++ .../grouped_matmul/grouped_matmul_antiquant.h | 544 ++++++ .../grouped_matmul_antiquant_a16w8_msd.h | 950 +++++++++ .../grouped_matmul_quant_mixcore.h | 419 ++++ .../grouped_matmul/grouped_matmul_utils.h | 210 ++ .../grouped_matmul/grouped_matmul_vector.h | 76 + .../grouped_matmul/ophost/CMakeLists.txt | 55 + .../ophost/aclnn_grouped_matmul.cpp | 1724 +++++++++++++++++ .../ophost/aclnn_grouped_matmul.h | 63 + .../ophost/aclnn_grouped_matmul_v2.h | 65 + .../ophost/aclnn_grouped_matmul_v3.h | 65 + .../ophost/aclnn_grouped_matmul_v4.h | 90 + .../ophost/fallback_grouped_matmul.cpp | 235 +++ .../grouped_matmul/ophost/grouped_matmul.cpp | 76 + .../grouped_matmul/ophost/grouped_matmul.h | 36 + .../ophost/grouped_matmul_def.cpp | 142 ++ .../ophost/grouped_matmul_proto.cpp | 1492 ++++++++++++++ .../ophost/grouped_matmul_tiling.cpp | 999 ++++++++++ .../ophost/grouped_matmul_tiling.h | 59 + .../utils/inc/tests/utils/aclnn_tensor_list.h | 55 + .../utils/inc/tests/utils/tensor_intf.h | 3 + .../framework/utils/src/aclnn_tensor_list.cpp | 175 ++ .../framework/utils/src/tensor_intf.cpp | 26 +- .../ops_test/src/transformer/CMakeLists.txt | 1 + .../transformer/grouped_matmul/CMakeLists.txt | 73 + .../comm/inc/aclnn_grouped_matmul_case.h | 52 + .../comm/inc/aclnn_grouped_matmul_param.h | 70 + .../comm/inc/grouped_matmul_case.h | 54 + .../comm/inc/grouped_matmul_param.h | 56 + .../comm/src/aclnn_grouped_matmul_case.cpp | 141 ++ .../comm/src/aclnn_grouped_matmul_param.cpp | 140 ++ .../comm/src/grouped_matmul_case.cpp | 213 ++ .../comm/src/grouped_matmul_param.cpp | 43 + .../grouped_matmul/utest/ts_grouped_matmul.h | 29 + .../utest/ts_grouped_matmul_kernel.cpp | 638 ++++++ .../utest/ts_grouped_matmul_tiling.cpp | 233 +++ .../utest_aclnn/ts_aclnn_grouped_matmul.cpp | 385 ++++ .../utest_aclnn/ts_aclnn_grouped_matmul.h | 32 + 66 files changed, 12948 insertions(+), 4 deletions(-) create mode 100644 docs/GroupedMatmul.md create mode 100644 docs/GroupedMatmulV2.md create mode 100644 docs/GroupedMatmulV3.md 
create mode 100644 docs/GroupedMatmulV4.md create mode 100644 "docs/common/GroupedMatmul\347\256\227\345\255\220\350\256\276\350\256\241\344\273\213\347\273\215.md" create mode 100644 "docs/fig/GMM\345\257\271\350\247\222\347\272\277\345\210\206\346\240\270\346\226\271\346\241\210.png" create mode 100644 "docs/fig/GMM\345\257\271\350\247\222\347\272\277\345\210\206\347\273\204\346\226\271\346\241\210.png" create mode 100644 "docs/fig/GMM\346\250\252\345\220\221\345\210\206\346\240\270\346\226\271\346\241\210.png" create mode 100644 "docs/fig/GMM\351\207\217\345\214\226\345\234\272\346\231\257UB_Buffer\345\210\206\351\205\215.png" create mode 100644 "docs/fig/GMM\351\207\217\345\214\226\345\234\272\346\231\257\346\265\201\347\250\213\345\233\276.png" create mode 100644 examples/transformer/grouped_matmul/CMakeLists.txt create mode 100644 examples/transformer/grouped_matmul/grouped_matmul_generate_data.py create mode 100644 examples/transformer/grouped_matmul/grouped_matmul_print_result.py create mode 100644 examples/transformer/grouped_matmul/grouped_matmul_utils.cpp create mode 100644 examples/transformer/grouped_matmul/grouped_matmul_utils.h create mode 100644 examples/transformer/grouped_matmul/run_grouped_matmul_case.sh create mode 100644 examples/transformer/grouped_matmul/test_grouped_matmul_v2.cpp create mode 100644 examples/transformer/grouped_matmul/test_grouped_matmul_v3.cpp create mode 100644 examples/transformer/grouped_matmul/test_grouped_matmul_v4.cpp create mode 100644 src/transformer/grouped_matmul/grouped_matmul.cpp create mode 100644 src/transformer/grouped_matmul/grouped_matmul.h create mode 100644 src/transformer/grouped_matmul/grouped_matmul_antiquant.h create mode 100644 src/transformer/grouped_matmul/grouped_matmul_antiquant_a16w8_msd.h create mode 100644 src/transformer/grouped_matmul/grouped_matmul_quant_mixcore.h create mode 100644 src/transformer/grouped_matmul/grouped_matmul_utils.h create mode 100644 src/transformer/grouped_matmul/grouped_matmul_vector.h create mode 100644 src/transformer/grouped_matmul/ophost/CMakeLists.txt create mode 100644 src/transformer/grouped_matmul/ophost/aclnn_grouped_matmul.cpp create mode 100644 src/transformer/grouped_matmul/ophost/aclnn_grouped_matmul.h create mode 100644 src/transformer/grouped_matmul/ophost/aclnn_grouped_matmul_v2.h create mode 100644 src/transformer/grouped_matmul/ophost/aclnn_grouped_matmul_v3.h create mode 100644 src/transformer/grouped_matmul/ophost/aclnn_grouped_matmul_v4.h create mode 100644 src/transformer/grouped_matmul/ophost/fallback_grouped_matmul.cpp create mode 100644 src/transformer/grouped_matmul/ophost/grouped_matmul.cpp create mode 100644 src/transformer/grouped_matmul/ophost/grouped_matmul.h create mode 100644 src/transformer/grouped_matmul/ophost/grouped_matmul_def.cpp create mode 100644 src/transformer/grouped_matmul/ophost/grouped_matmul_proto.cpp create mode 100644 src/transformer/grouped_matmul/ophost/grouped_matmul_tiling.cpp create mode 100644 src/transformer/grouped_matmul/ophost/grouped_matmul_tiling.h create mode 100644 tests/ut/ops_test/framework/utils/inc/tests/utils/aclnn_tensor_list.h create mode 100644 tests/ut/ops_test/framework/utils/src/aclnn_tensor_list.cpp create mode 100644 tests/ut/ops_test/src/transformer/grouped_matmul/CMakeLists.txt create mode 100644 tests/ut/ops_test/src/transformer/grouped_matmul/comm/inc/aclnn_grouped_matmul_case.h create mode 100644 tests/ut/ops_test/src/transformer/grouped_matmul/comm/inc/aclnn_grouped_matmul_param.h create mode 100644 
tests/ut/ops_test/src/transformer/grouped_matmul/comm/inc/grouped_matmul_case.h create mode 100644 tests/ut/ops_test/src/transformer/grouped_matmul/comm/inc/grouped_matmul_param.h create mode 100644 tests/ut/ops_test/src/transformer/grouped_matmul/comm/src/aclnn_grouped_matmul_case.cpp create mode 100644 tests/ut/ops_test/src/transformer/grouped_matmul/comm/src/aclnn_grouped_matmul_param.cpp create mode 100644 tests/ut/ops_test/src/transformer/grouped_matmul/comm/src/grouped_matmul_case.cpp create mode 100644 tests/ut/ops_test/src/transformer/grouped_matmul/comm/src/grouped_matmul_param.cpp create mode 100644 tests/ut/ops_test/src/transformer/grouped_matmul/utest/ts_grouped_matmul.h create mode 100644 tests/ut/ops_test/src/transformer/grouped_matmul/utest/ts_grouped_matmul_kernel.cpp create mode 100644 tests/ut/ops_test/src/transformer/grouped_matmul/utest/ts_grouped_matmul_tiling.cpp create mode 100644 tests/ut/ops_test/src/transformer/grouped_matmul/utest_aclnn/ts_aclnn_grouped_matmul.cpp create mode 100644 tests/ut/ops_test/src/transformer/grouped_matmul/utest_aclnn/ts_aclnn_grouped_matmul.h diff --git a/.gitee/ISSUE_TEMPLATE/bug-report.yml b/.gitee/ISSUE_TEMPLATE/bug-report.yml index b0492446..a56d876f 100644 --- a/.gitee/ISSUE_TEMPLATE/bug-report.yml +++ b/.gitee/ISSUE_TEMPLATE/bug-report.yml @@ -56,6 +56,7 @@ body: - PFA - FIA - FFN + - GMM - 其他 validations: required: true \ No newline at end of file diff --git a/.gitee/ISSUE_TEMPLATE/documentation.yml b/.gitee/ISSUE_TEMPLATE/documentation.yml index 4c314310..b74bfcb7 100644 --- a/.gitee/ISSUE_TEMPLATE/documentation.yml +++ b/.gitee/ISSUE_TEMPLATE/documentation.yml @@ -37,6 +37,7 @@ body: - PFA - FIA - FFN + - GMM - 其他 validations: required: true \ No newline at end of file diff --git a/.gitee/ISSUE_TEMPLATE/feature-request.yml b/.gitee/ISSUE_TEMPLATE/feature-request.yml index 9c53ffa5..139caafb 100644 --- a/.gitee/ISSUE_TEMPLATE/feature-request.yml +++ b/.gitee/ISSUE_TEMPLATE/feature-request.yml @@ -43,6 +43,7 @@ body: - PFA - FIA - FFN + - GMM - 其他 validations: required: true \ No newline at end of file diff --git a/.gitee/ISSUE_TEMPLATE/question.yml b/.gitee/ISSUE_TEMPLATE/question.yml index 66855d89..ad5ebf46 100644 --- a/.gitee/ISSUE_TEMPLATE/question.yml +++ b/.gitee/ISSUE_TEMPLATE/question.yml @@ -25,6 +25,7 @@ body: - PFA - FIA - FFN + - GMM - 其他 validations: required: true \ No newline at end of file diff --git a/README.md b/README.md index 35c22132..3bd0efc4 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,7 @@ cann-ops-adv,是基于昇腾硬件的融合算子库(adv表示advanced)。 | ├── flash_attention_score # 训练FA算子示例代码 | ├── flash_attention_score_grad # 训练FAG算子示例代码 | ├── fused_infer_attention_score # 推理FIA算子示例代码 + | ├── grouped_matmul # 推理GroupedMatmul算子示例代码 | ├── incre_flash_attention # 推理IFA算子示例代码 | ├── prompt_flash_attention # 推理PFA算子示例代码 | @@ -54,6 +55,9 @@ cann-ops-adv,是基于昇腾硬件的融合算子库(adv表示advanced)。 | | ├── fused_infer_attention_score # 推理FIA算子源代码 | | | ├── ophost # ophost目录,包含tiling策略、aclnn接口、算子原型、信息库配置 | | | ├── fused_infer_attention_score*.* # FIA算子Kernel源文件 + | | ├── grouped_matmul # 推理GroupedMatmul算子源代码 + | | | ├── ophost # ophost目录,包含tiling策略、aclnn接口、算子原型、信息库配置 + | | | ├── grouped_matmul*.* # GroupedMatmul算子kernel源文件 | | ├── incre_flash_attention # 推理IFA算子源代码 | | | ├── ophost # ophost目录,包含tiling策略、aclnn接口、算子原型、信息库配置 | | | ├── incre_flash_attention*.* # IFA算子kernel源文件 @@ -85,6 +89,10 @@ cann-ops-adv,是基于昇腾硬件的融合算子库(adv表示advanced)。 | FlashAttentionScoreGradV2 | 
FlashAttentionScoreV2的反向计算,相较于FlashAttentionScoreGrad,新增psetype、q_start_idx、kv_start_idx参数。 |
  • [FlashAttentionScoreGradV2](./docs/FlashAttentionScoreGradV2.md)
  • [FlashAttentionUnpaddingScoreGradV2](./docs/FlashAttentionUnpaddingScoreGradV2.md) | | FusedInferAttentionScore | 融合PromptFlashAttentionV3,IncreFlashAttentionV4的功能。
    IFA新增: lse输出、per-token伪量化特性。
    PFA新增: lse输出、伪量化、左Padding、Paged Attention特性。 | [FusedInferAttentionScore](./docs/FusedInferAttentionScore.md) | | FusedInferAttentionScoreV2 | 在FusedInferAttentionScore基础上, IFA 新增kv伪量化参数分离。
    PFA新增:prefix特性。 | [FusedInferAttentionScoreV2](./docs/FusedInferAttentionScoreV2.md) | +| GroupedMatmul | 实现分组矩阵乘计算,每组矩阵乘的维度大小可以不同。 | [GroupedMatmul](./docs/GroupedMatmul.md) | +| GroupedMatmulV2 | 在GroupedMatmul基础上新增:支持不同分组轴;非量化场景支持x、weight转置;非量化场景支持x、weight输入都为float32类型;量化和伪量化场景支持weight转置。 | [GroupedMatmulV2](./docs/GroupedMatmulV2.md) | +| GroupedMatmulV3 | 在GroupedMatmulV2基础上groupList从aclIntArray指针输入改为了aclTensor指针输入。 | [GroupedMatmulV3](./docs/GroupedMatmulV3.md) | +| GroupedMatmulV4 | 在GroupedMatmulV3基础上新增:支持groupListOptional中数值为分组轴上每组大小;量化场景支持静态量化和动态量化bfloat16和float16输出,包括带激活及不带激活场景;支持伪量化weight是INT4的输入。 | [GroupedMatmulV4](./docs/GroupedMatmulV4.md) | | IncreFlashAttention | 使用FlashAttention算法实现self-attention(自注意力)的计算。 | [IncreFlashAttention](./docs/IncreFlashAttention.md) | | IncreFlashAttentionV2 | 在IncreFlashAttention基础上新增量化特性。 | [IncreFlashAttentionV2](./docs/IncreFlashAttentionV2.md) | | IncreFlashAttentionV3 | 在IncreFlashAttentionV2基础上新增位置编码、page attention、kv cache反量化特性。 | [IncreFlashAttentionV3](./docs/IncreFlashAttentionV3.md) | @@ -276,6 +284,7 @@ cann-ops-adv仓提供了如下融合算子的代码实现设计,方便开发 - [FA/FAG算子设计介绍](./docs/common/FA-FAG算子设计介绍.md) - [FFN算子设计介绍](./docs/common/FFN算子设计介绍.md) - [IFA算子设计介绍](./docs/common/IFA算子设计介绍.md) +- [GroupedMatmul算子设计介绍](./docs/common/GroupedMatmul算子设计介绍.md) - [PFA算子设计介绍](./docs/common/PFA算子设计介绍.md) ## 贡献指南 diff --git a/blacklist.txt b/blacklist.txt index b5b50d4a..cdc8c24e 100644 --- a/blacklist.txt +++ b/blacklist.txt @@ -17,6 +17,14 @@ src/transformer/ffn/ophost/ffn.cpp src/transformer/ffn/ophost/fallback_ffn.cpp src/transformer/ffn/ophost/ffn_def.cpp src/transformer/ffn/ffn_nonquant_nz.h +src/transformer/grouped_matmul/ophost/aclnn_grouped_matmul.h +src/transformer/grouped_matmul/ophost/aclnn_grouped_matmul_v2.h +src/transformer/grouped_matmul/ophost/aclnn_grouped_matmul_v3.h +src/transformer/grouped_matmul/ophost/aclnn_grouped_matmul_v4.h +src/transformer/grouped_matmul/ophost/grouped_matmul.h +src/transformer/grouped_matmul/ophost/grouped_matmul_proto.h +src/transformer/grouped_matmul/ophost/grouped_matmul_def.cpp +src/transformer/grouped_matmul/ophost/fallback_grouped_matmul.cpp src/transformer/incre_flash_attention/ifa_public_define.h src/transformer/incre_flash_attention/incre_flash_attention.cpp src/transformer/incre_flash_attention/incre_flash_attention_allvec_new.h diff --git a/classify_rule.yaml b/classify_rule.yaml index 06b86ab5..2f5bfa18 100644 --- a/classify_rule.yaml +++ b/classify_rule.yaml @@ -57,7 +57,19 @@ src: - tests/ut/ops_test/src/transformer/ffn options: - ffn - + + grouped_matmul: + module: True + src: + - src/transformer/grouped_matmul + tests: + ut: + ops_test: + src: + - tests/ut/ops_test/src/transformer/grouped_matmul + options: + - grouped_matmul + incre_flash_attention: module: True src: diff --git a/cmake/scripts/utest/gen_tiling_data_stub.py b/cmake/scripts/utest/gen_tiling_data_stub.py index c02f1e58..0586c058 100644 --- a/cmake/scripts/utest/gen_tiling_data_stub.py +++ b/cmake/scripts/utest/gen_tiling_data_stub.py @@ -231,8 +231,7 @@ class Process: "#undef GET_TILING_DATA_MEMBER\n" "#define GET_TILING_DATA_MEMBER(tiling_type, member, var, tiling) \\\n" "decltype(tiling_type::member) var; \\\n" - "size_t offset = (size_t)(&((tiling_type *)0)->member); \\\n" - "(void)memcpy_s(&var, sizeof(decltype(var)), tiling + offset, sizeof(decltype(var))); \n" + "(void)memcpy_s(&var, sizeof(decltype(var)), tiling + (size_t)(&((tiling_type *)0)->member), sizeof(decltype(var))); \n" ) source = bgn_src + def_src 
cls._write_file(file=stub_file, src=source) diff --git a/docs/GroupedMatmul.md b/docs/GroupedMatmul.md new file mode 100644 index 00000000..8d9b1791 --- /dev/null +++ b/docs/GroupedMatmul.md @@ -0,0 +1,340 @@ +声明:本文使用[Creative Commons License version 4.0](https://creativecommons.org/licenses/by/4.0/legalcode)许可协议,转载、引用或修改等操作请遵循此许可协议。 + +# GroupedMatmul + +## 支持的产品型号 +- Atlas A2 训练系列产品/Atlas 800I A2 推理产品 + +产品形态详细说明请参见[昇腾产品形态说明](https://www.hiascend.com/document/redirect/CannCommunityProductForm) + +## 功能描述 + +- 算子功能:实现分组矩阵乘计算,每组矩阵乘的维度大小可以不同。基本功能为矩阵乘,如$y_i[m_i,n_i]=x_i[m_i,k_i] \times weight_i[k_i,n_i], i=1...g$,其中g为分组个数,$m_i/k_i/n_i$为应对shape。根据x、weight、y的Tensor数量支持如下4种场景: + + - x、weight、y都为多tensor,即每组的数据对应的Tensor是独立的。 + - x为单tensor,weight/y为多tensor,此时需要通过可选参数group_list说明x在行上的分组情况,如group_list[0]=10说明x的前10行参与第一组矩阵乘计算。 + - x、weight为多tensor,y为单tensor,此时每组矩阵乘的结果放在同一个Tensor中连续存放。 + - x、y为单tensor,weight为多tensor,属于前两种情况的组合。 + + **说明:** 单tensor指一个tensor list中所有分组的tensor在M轴上合并为1个;否则为多tensor。 +- 计算公式: + - **非量化场景:** + + $$ + y_i=x_i\times weight_i + bias_i + $$ + + - **量化场景:** + + $$ + y_i=(x_i\times weight_i + bias_i) * scale_i + offset_i + $$ + + - **反量化场景:** + + $$ + y_i=(x_i\times weight_i + bias_i) * scale_i + $$ + + - **伪量化场景:** + + $$ + y_i=x_i\times (weight_i + antiquant\_offset_i) * antiquant\_scale_i + bias_i + $$ + +## 实现原理 + +详细实现原理参考[GroupedMatmul设计](./common/GroupedMatmul算子设计介绍.md)。 + +## 算子执行接口 + +每个算子分为[两段式接口](common/两段式接口.md),必须先调用“aclnnGroupedMatmulGetWorkspaceSize”接口获取入参并根据计算流程计算所需workspace大小,再调用“aclnnGroupedMatmul”接口执行计算。 + +* `aclnnStatus aclnnGroupedMatmulGetWorkspaceSize(const aclTensorList* x, const aclTensorList* weight, const aclTensorList* biasOptional, const aclTensorList* scaleOptional, const aclTensorList* offsetOptional, const aclTensorList* antiquantScaleOptional, const aclTensorList* antiquantOffsetOptional, const aclIntArray* groupListOptional, int64_t splitItem, const aclTensorList* y, uint64_t* workspaceSize, aclOpExecutor** executor)` +* `aclnnStatus aclnnGroupedMatmul(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor, aclrtStream stream)` + +**说明**: + +- 算子执行接口对外屏蔽了算子内部实现逻辑以及不同代际NPU的差异,且开发者无需编译算子,实现了算子的精简调用。 +- 若开发者不使用算子执行接口的调用算子,也可以定义基于Ascend IR的算子描述文件,通过ATC工具编译获得算子om文件,然后加载模型文件执行算子,详细调用方法可参见《应用开发指南》的[单算子调用 > 单算子模型执行](https://hiascend.com/document/redirect/CannCommunityCppOpcall)章节。 + +### aclnnGroupedMatmulGetWorkspaceSize + +- **参数说明:** + - x(aclTensorList\*,计算输入):必选参数,Device侧的aclTensorList,公式中的输入x,数据类型支持FLOAT16、BFLOAT16、INT8、FLOAT32,[数据格式](common/数据格式.md)支持ND,支持的最大长度为128个。 + - weight(aclTensorList\*,计算输入):必选参数,Device侧的aclTensorList,公式中的weight,数据类型支持FLOAT16、BFLOAT16、INT8、FLOAT32,[数据格式](common/数据格式.md)支持ND,支持的最大长度为128个。 + - biasOptional(aclTensorList\*,计算输入)可选参数,Device侧的aclTensorList,公式中的bias,数据类型支持FLOAT16、FLOAT32、INT32,[数据格式](common/数据格式.md)支持ND,长度与weight相同。 + - scaleOptional(aclTensorList\*,计算输入)可选参数,Device侧的aclTensorList,代表量化参数中的缩放因子,数据类型支持UINT64,[数据格式](common/数据格式.md)支持ND,长度与weight相同。 + - offsetOptional(aclTensorList\*,计算输入)可选参数,Device侧的aclTensorList,代表量化参数中的偏移量,数据类型支持FLOAT32,[数据格式](common/数据格式.md)支持ND,长度与weight相同。 + - antiquantScaleOptional(aclTensorList\*,计算输入)可选参数,Device侧的aclTensorList,代表伪量化参数中的缩放因子,数据类型支持FLOAT16、BFLOAT16,[数据格式](common/数据格式.md)支持ND,长度与weight相同。 + - antiquantOffsetOptional(aclTensorList\*,计算输入)可选参数,Device侧的aclTensorList,代表伪量化参数中的偏移量,数据类型支持FLOAT16、BFLOAT16,[数据格式](common/数据格式.md)支持ND,长度与weight相同。 + - 
groupListOptional(aclIntArray\*,计算输入):可选参数,Host侧的aclIntArray类型,代表输入和输出M方向的matmul索引情况,数据类型支持INT64,[数据格式](common/数据格式.md)支持ND,长度与weight相同。 + - splitItem(int64\_t,计算输入):整数型参数,代表输出是否要做tensor切分,0/1代表输出为多tensor;2/3代表输出为单tensor,默认值为0。 + - y(aclTensorList\*,计算输出):Device侧的aclTensorList,公式中的输出y,数据类型支持FLOAT16、BFLOAT16、INT8、FLOAT32,[数据格式](common/数据格式.md)支持ND,支持的最大长度为128个。 + - workspaceSize(uint64\_t\*,出参):返回需要在Device侧申请的workspace大小。 + - executor(aclOpExecutor\*\*,出参):返回op执行器,包含了算子计算流程。 + +- **返回值:** + + 返回aclnnStatus状态码,具体参见[aclnn返回码](common/aclnn返回码.md)。 + + ``` + 第一段接口完成入参校验,若出现以下错误码,则对应原因为: + - 返回161001(ACLNN_ERR_PARAM_NULLPTR): + 1.如果传入参数是必选输入、输出或者必选属性,且是空指针。 + 2.传入参数weight的元素存在空指针。 + 3.传入参数x的元素为空指针,且传出参数y的元素不为空指针。 + 4.传入参数x的元素不为空指针,且传出参数y的元素为空指针。 + - 返回161002(ACLNN_ERR_PARAM_INVALID): + 1.x、weight、biasOptional、scaleOptional、offsetOptional、antiquantScaleOptional、antiquantOffsetOptional、groupListOptional、splitItem、y的数据类型和数据格式不在支持的范围内。 + 2.weight的长度大于128。 + 3.若bias不为空,bias的长度不等于weight的长度。 + 4.splitItem为2、3的场景,y长度不等于1。 + 5.splitItem为0、1的场景,y长度不等于weight的长度,groupListOptional长度不等于weight的长度。 + ``` + +### aclnnGroupedMatmul + +- **参数说明:** + - workspace(void\*,入参):在Device侧申请的workspace内存地址。 + - workspaceSize(uint64\_t,入参):在Device侧申请的workspace大小,由第一段接口aclnnGroupedMatmulGetWorkspaceSize获取。 + - executor(aclOpExecutor\*,入参):op执行器,包含了算子计算流程。 + - stream(aclrtStream,入参):指定执行任务的AscendCL stream流。 + +- **返回值:** + + 返回aclnnStatus状态码,具体参见[aclnn返回码](common/aclnn返回码.md)。 + +## 约束与限制 + - 非量化场景支持的输入类型为: + - x为FLOAT16、weight为FLOAT16、biasOptional为FLOAT16、scaleOptional为空、offsetOptional为空、antiquantScaleOptional为空、antiquantOffsetOptional为空、y为FLOAT16; + - x为BFLOAT16、weight为BFLOAT16、biasOptional为FLOAT32、scaleOptional为空、offsetOptional为空、antiquantScaleOptional为空、antiquantOffsetOptional为空、y为BFLOAT16; + - x为FLOAT32、weight为FLOAT32、biasOptional为FLOAT32、scaleOptional为空、offsetOptional为空、antiquantScaleOptional为空、antiquantOffsetOptional为空、y为FLOAT32; + - 量化场景支持的输入类型为: + + - x为INT8、weight为INT8、biasOptional为INT32、scaleOptional为UINT64、offsetOptional为空、antiquantScaleOptional为空、antiquantOffsetOptional为空、y为INT8; + - 伪量化场景支持的输入类型为: + - x为FLOAT16、weight为INT8、biasOptional为FLOAT16、scaleOptional为空,offsetOptional为空,antiquantScaleOptional为FLOAT16、antiquantOffsetOptional为FLOAT16、y为FLOAT16; + - x为BFLOAT16、weight为INT8、biasOptional为FLOAT32、scaleOptional为空,offsetOptional为空,antiquantScaleOptional为BFLOAT16、antiquantOffsetOptional为BFLOAT16、y为BFLOAT16; + - 如果传入groupListOptional,groupListOptional必须为非负递增数列,groupListOptional长度不能为1。 + - 当前支持的场景: + 支持场景中单表示单tensor,多表示多tensor,表示顺序为x,weight,y,例,单多单表示支持x为单tensor,weight多tensor,y单tensor的场景。 + + | 支持场景 | 场景限制 | + |:-------:| :-------| + | 多多多 |1)仅支持splitItem为0/1
    2)x中tensor支持2-6维,weight中tensor需为2维,y中tensor维度和x保持一致
    3)若x中存在tensor大于2维,groupListOptional必须传空
    4)若x中tensor为2维且传入groupListOptional,groupListOptional的差值需与x中tensor的第一维一一对应 | + | 单多单 |1)仅支持splitItem为2/3
    2)必须传groupListOptional,且最后一个值与x中tensor的第一维相等
    3)x,weight,y中tensor需为2维
    4)weight中每个tensor的N轴必须相等 | + | 单多多 |1)仅支持splitItem为0/1
    2)必须传groupListOptional,groupListOptional的差值需与y中tensor的第一维一一对应
    3)x,weight,y中tensor需为2维 | + | 多多单 |1)仅支持splitItem为2/3
    2)x,weight,y中tensor需为2维
    3)weight中每个tensor的N轴必须相等
    4)若传入groupListOptional,groupListOptional的差值需与x中tensor的第一维一一对应 |
 - x和weight中每一组tensor的最后一维大小都应小于65536。$x_i$的最后一维指当属性transpose_x为false时$x_i$的K轴或当transpose_x为true时$x_i$的M轴。$weight_i$的最后一维指当属性transpose_weight为false时$weight_i$的N轴或当transpose_weight为true时$weight_i$的K轴。
 - x和weight中每一组tensor的每一维大小在32字节对齐后都应小于int32的最大值2147483647。
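为帮助理解groupListOptional的取值方式(非负递增的累加偏移,前后差值即每组的行数$m_i$),下面给出一段Host侧示意代码。该代码并非算子库接口,其中的函数名与变量名均为本文为说明而假设的:

```c++
#include <cstdint>
#include <cstdio>
#include <vector>

// 示意:由累加偏移形式的groupList还原每组的m_i。
// 第i组参与计算的行区间为[groupList[i-1], groupList[i]),其中groupList[-1]视为0。
std::vector<int64_t> GroupSizesFromGroupList(const std::vector<int64_t>& groupList) {
  std::vector<int64_t> sizes;
  int64_t prev = 0;
  for (int64_t cur : groupList) {
    sizes.push_back(cur - prev);  // 差值即该组的m_i
    prev = cur;
  }
  return sizes;
}

int main() {
  // 例:groupList = {10, 25, 40}表示x的前10行属于第1组,
  // 第10~24行属于第2组,第25~39行属于第3组
  std::vector<int64_t> groupList = {10, 25, 40};
  for (int64_t m : GroupSizesFromGroupList(groupList)) {
    std::printf("m_i = %ld\n", m);
  }
  return 0;
}
```

当splitItem为2/3(输出为单tensor)时,各组结果也按同样的行区间在输出tensor中连续存放。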
## 算子原型

```c++
REG_OP(GroupedMatmul)
    .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_BF16, DT_INT8, DT_FLOAT}))
    .DYNAMIC_INPUT(weight, TensorType({DT_FLOAT16, DT_BF16, DT_INT8, DT_FLOAT, DT_INT4}))
    .DYNAMIC_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .DYNAMIC_INPUT(scale, TensorType({DT_UINT64, DT_BF16, DT_FLOAT32}))
    .DYNAMIC_INPUT(offset, TensorType({DT_FLOAT32}))
    .DYNAMIC_INPUT(antiquant_scale, TensorType({DT_FLOAT16, DT_BF16}))
    .DYNAMIC_INPUT(antiquant_offset, TensorType({DT_FLOAT16, DT_BF16}))
    .OPTIONAL_INPUT(group_list, TensorType({DT_INT64}))
    .OPTIONAL_INPUT(per_token_scale, TensorType({DT_FLOAT}))
    .DYNAMIC_OUTPUT(y, TensorType({DT_FLOAT16, DT_BF16, DT_INT8, DT_FLOAT}))
    .ATTR(split_item, Int, 0)
    .ATTR(dtype, Int, 0)
    .ATTR(transpose_weight, Bool, false)
    .ATTR(transpose_x, Bool, false)
    .ATTR(group_type, Int, -1)
    .ATTR(group_list_type, Int, 0)
    .ATTR(act_type, Int, 0)
    .OP_END_FACTORY_REG(GroupedMatmul)
```

参数解释请参见**算子执行接口**。

## 调用示例
- PyTorch框架调用

  如果通过PyTorch单算子方式调用该融合算子,则需要参考PyTorch融合算子[torch_npu.npu_grouped_matmul](https://hiascend.com/document/redirect/PyTorchAPI);如果用户定制了该融合算子,则需要参考《Ascend C算子开发》手册[适配PyTorch框架](https://hiascend.com/document/redirect/CannCommunityAscendCInvorkOnNetwork)。

- aclnn单算子调用方式

  通过aclnn单算子调用示例代码如下,仅供参考,具体编译和执行过程请参考[编译与运行样例](common/编译与运行样例.md)。

```c++
#include <iostream>
#include <vector>
#include "acl/acl.h"
#include "aclnnop/aclnn_grouped_matmul.h"

#define CHECK_RET(cond, return_expr) \
  do {                               \
    if (!(cond)) {                   \
      return_expr;                   \
    }                                \
  } while (0)

#define LOG_PRINT(message, ...)     \
  do {                              \
    printf(message, ##__VA_ARGS__); \
  } while (0)

int64_t GetShapeSize(const std::vector<int64_t>& shape) {
  int64_t shapeSize = 1;
  for (auto i : shape) {
    shapeSize *= i;
  }
  return shapeSize;
}

int Init(int32_t deviceId, aclrtStream* stream) {
  // 固定写法,AscendCL初始化
  auto ret = aclInit(nullptr);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret);
  ret = aclrtSetDevice(deviceId);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret); return ret);
  ret = aclrtCreateStream(stream);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret);
  return 0;
}

// 模板参数T表示单个元素的Host侧承载类型,FLOAT16数据可用uint16_t(2字节)占位
template <typename T>
int CreateAclTensor(const std::vector<int64_t>& shape, void** deviceAddr,
                    aclDataType dataType, aclTensor** tensor) {
  auto size = GetShapeSize(shape) * sizeof(T);
  // 调用aclrtMalloc申请Device侧内存
  auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret);

  // 调用aclrtMemcpy将Host侧数据拷贝到Device侧内存上
  std::vector<char> hostData(size, 0);
  ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);

  // 计算连续tensor的strides
  std::vector<int64_t> strides(shape.size(), 1);
  for (int64_t i = shape.size() - 2; i >= 0; i--) {
    strides[i] = shape[i + 1] * strides[i + 1];
  }

  // 调用aclCreateTensor接口创建aclTensor
  *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND,
                            shape.data(), shape.size(), *deviceAddr);
  return 0;
}

template <typename T>
int CreateAclTensorList(const std::vector<std::vector<int64_t>>& shapes, void** deviceAddr,
                        aclDataType dataType, aclTensorList** tensor) {
  int size = shapes.size();
  std::vector<aclTensor*> tensors(size);
  for (int i = 0; i < size; i++) {
    int ret = CreateAclTensor<T>(shapes[i], deviceAddr + i, dataType, tensors.data() + i);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
  }
  *tensor = aclCreateTensorList(tensors.data(), size);
  return ACL_SUCCESS;
}

int main() {
  // 1. (固定写法)device/stream初始化,参考AscendCL对外接口列表
  // 根据自己的实际device填写deviceId
  int32_t deviceId = 0;
  aclrtStream stream;
  auto ret = Init(deviceId, &stream);
  // check根据自己的需要处理
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Init acl failed. ERROR: %d\n", ret); return ret);

  // 2. 构造输入与输出,需要根据API的接口自定义构造
  std::vector<std::vector<int64_t>> xShape = {{1, 16}, {4, 32}};
  std::vector<std::vector<int64_t>> weightShape = {{16, 24}, {32, 16}};
  std::vector<std::vector<int64_t>> biasShape = {{24}, {16}};
  std::vector<std::vector<int64_t>> yShape = {{1, 24}, {4, 16}};
  void* xDeviceAddr[2];
  void* weightDeviceAddr[2];
  void* biasDeviceAddr[2];
  void* yDeviceAddr[2];
  aclTensorList* x = nullptr;
  aclTensorList* weight = nullptr;
  aclTensorList* bias = nullptr;
  aclIntArray* groupedList = nullptr;
  aclTensorList* scale = nullptr;
  aclTensorList* offset = nullptr;
  aclTensorList* antiquantScale = nullptr;
  aclTensorList* antiquantOffset = nullptr;
  aclTensorList* y = nullptr;
  int64_t splitItem = 0;

  // 创建x aclTensorList(FLOAT16数据以uint16_t占位)
  ret = CreateAclTensorList<uint16_t>(xShape, xDeviceAddr, aclDataType::ACL_FLOAT16, &x);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // 创建weight aclTensorList
  ret = CreateAclTensorList<uint16_t>(weightShape, weightDeviceAddr, aclDataType::ACL_FLOAT16, &weight);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // 创建bias aclTensorList
  ret = CreateAclTensorList<uint16_t>(biasShape, biasDeviceAddr, aclDataType::ACL_FLOAT16, &bias);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // 创建y aclTensorList
  ret = CreateAclTensorList<uint16_t>(yShape, yDeviceAddr, aclDataType::ACL_FLOAT16, &y);
  CHECK_RET(ret == ACL_SUCCESS, return ret);

  uint64_t workspaceSize = 0;
  aclOpExecutor* executor;

  // 3. 调用CANN算子库API
  // 调用aclnnGroupedMatmul第一段接口
  ret = aclnnGroupedMatmulGetWorkspaceSize(x, weight, bias, scale, offset, antiquantScale, antiquantOffset, groupedList, splitItem, y, &workspaceSize, &executor);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnGroupedMatmulGetWorkspaceSize failed. ERROR: %d\n", ret); return ret);
  // 根据第一段接口计算出的workspaceSize申请device内存
  void* workspaceAddr = nullptr;
  if (workspaceSize > 0) {
    ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret); return ret);
  }
  // 调用aclnnGroupedMatmul第二段接口
  ret = aclnnGroupedMatmul(workspaceAddr, workspaceSize, executor, stream);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnGroupedMatmul failed. ERROR: %d\n", ret); return ret);

  // 4. (固定写法)同步等待任务执行结束
  ret = aclrtSynchronizeStream(stream);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSynchronizeStream failed. ERROR: %d\n", ret); return ret);

  // 5. 获取输出的值,将Device侧内存上的结果拷贝至Host侧,需要根据具体API的接口定义修改
  for (int i = 0; i < 2; i++) {
    auto size = GetShapeSize(yShape[i]);
    std::vector<uint16_t> resultData(size, 0);
    ret = aclrtMemcpy(resultData.data(), size * sizeof(resultData[0]), yDeviceAddr[i],
                      size * sizeof(resultData[0]), ACL_MEMCPY_DEVICE_TO_HOST);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return ret);
    for (int64_t j = 0; j < size; j++) {
      // 此处打印的是FLOAT16结果的原始比特值,如需十进制数值请自行转换
      LOG_PRINT("result[%ld] is: %d\n", j, resultData[j]);
    }
  }

  // 6. 释放aclTensor和aclScalar,需要根据具体API的接口定义修改
  aclDestroyTensorList(x);
  aclDestroyTensorList(weight);
  aclDestroyTensorList(bias);
  aclDestroyTensorList(y);

  // 7. 释放device资源,需要根据具体API的接口定义修改
  for (int i = 0; i < 2; i++) {
    aclrtFree(xDeviceAddr[i]);
    aclrtFree(weightDeviceAddr[i]);
    aclrtFree(biasDeviceAddr[i]);
    aclrtFree(yDeviceAddr[i]);
  }
  if (workspaceSize > 0) {
    aclrtFree(workspaceAddr);
  }
  aclrtDestroyStream(stream);
  aclrtResetDevice(deviceId);
  aclFinalize();
  return 0;
}
```
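若需对上述示例的输出做简单的Host侧数值校验,可参考如下CPU示意实现(并非算子库代码,仅按本文“非量化场景”公式$y_i=x_i\times weight_i + bias_i$以float精度重算一遍;FLOAT16与float之间的转换需使用者自行实现):

```c++
#include <cstdint>
#include <cstdio>
#include <vector>

// 示意:单组y = x × weight + bias的CPU参考实现(行优先排布,float精度)
void GroupedMatmulRefOneGroup(const float* x, const float* w, const float* bias,
                              float* y, int64_t m, int64_t k, int64_t n) {
  for (int64_t i = 0; i < m; ++i) {
    for (int64_t j = 0; j < n; ++j) {
      float acc = (bias != nullptr) ? bias[j] : 0.0f;
      for (int64_t p = 0; p < k; ++p) {
        acc += x[i * k + p] * w[p * n + j];
      }
      y[i * n + j] = acc;
    }
  }
}

int main() {
  // 以示例中的第一组shape (m,k,n)=(1,16,24)为例;示例的Host侧输入全为0,参考结果应为全0
  const int64_t m = 1, k = 16, n = 24;
  std::vector<float> x(m * k, 0.0f), w(k * n, 0.0f), bias(n, 0.0f), y(m * n, 0.0f);
  GroupedMatmulRefOneGroup(x.data(), w.data(), bias.data(), y.data(), m, k, n);
  std::printf("ref y[0] = %f\n", y[0]);
  return 0;
}
```

对示例中的两组,分别以(m,k,n)=(1,16,24)和(4,32,16)调用即可。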
diff --git a/docs/GroupedMatmulV2.md b/docs/GroupedMatmulV2.md
new file mode 100644
index 00000000..539c4f10
--- /dev/null
+++ b/docs/GroupedMatmulV2.md
@@ -0,0 +1,342 @@
声明:本文使用[Creative Commons License version 4.0](https://creativecommons.org/licenses/by/4.0/legalcode)许可协议,转载、引用或修改等操作请遵循此许可协议。

# GroupedMatmulV2

## 支持的产品型号
- Atlas A2 训练系列产品/Atlas 800I A2 推理产品

产品形态详细说明请参见[昇腾产品形态说明](https://www.hiascend.com/document/redirect/CannCommunityProductForm)

## 功能描述

- 算子功能:实现分组矩阵乘计算,每组矩阵乘的维度大小可以不同。基本功能为矩阵乘,如$y_i[m_i,n_i]=x_i[m_i,k_i] \times weight_i[k_i,n_i], i=1...g$,其中g为分组个数,$m_i/k_i/n_i$为对应shape。相较于[GroupedMatmul](GroupedMatmul.md)接口,**此接口新增**:
  - 支持不同分组轴,由groupType表示。
  - 非量化场景,支持x,weight转置(转置指若shape为[M,K]时,则stride为[1,M],数据排布为[K,M]的场景)。
  - 非量化场景支持x,weight输入都为float32类型。
  - 量化、伪量化场景,支持weight转置,支持weight为单tensor。

- 计算公式:
  - **非量化场景:**

    $$
    y_i=x_i\times weight_i + bias_i
    $$

  - **量化场景:**

    $$
    y_i=(x_i\times weight_i + bias_i) * scale_i + offset_i
    $$

  - **反量化场景:**

    $$
    y_i=(x_i\times weight_i + bias_i) * scale_i
    $$

  - **伪量化场景:**

    $$
    y_i=x_i\times (weight_i + antiquant\_offset_i) * antiquant\_scale_i + bias_i
    $$

## 实现原理

详细实现原理参考[GroupedMatmul设计](./common/GroupedMatmul算子设计介绍.md)。

## 算子执行接口

每个算子分为[两段式接口](common/两段式接口.md),必须先调用“aclnnGroupedMatmulV2GetWorkspaceSize”接口获取入参并根据计算流程计算所需workspace大小,再调用“aclnnGroupedMatmulV2”接口执行计算。

* `aclnnStatus aclnnGroupedMatmulV2GetWorkspaceSize(const aclTensorList* x, const aclTensorList* weight, const aclTensorList* biasOptional, const aclTensorList* scaleOptional, const aclTensorList* offsetOptional, const aclTensorList* antiquantScaleOptional, const aclTensorList* antiquantOffsetOptional, const aclIntArray* groupListOptional, int64_t splitItem, int64_t groupType, const aclTensorList* y, uint64_t* workspaceSize, aclOpExecutor** executor)`
* `aclnnStatus aclnnGroupedMatmulV2(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor, aclrtStream stream)`

**说明**:

- 算子执行接口对外屏蔽了算子内部实现逻辑以及不同代际NPU的差异,且开发者无需编译算子,实现了算子的精简调用。
- 若开发者不使用算子执行接口调用算子,也可以定义基于Ascend IR的算子描述文件,通过ATC工具编译获得算子om文件,然后加载模型文件执行算子,详细调用方法可参见《应用开发指南》的[单算子调用 > 单算子模型执行](https://hiascend.com/document/redirect/CannCommunityCppOpcall)章节。

### aclnnGroupedMatmulV2GetWorkspaceSize

- **参数说明:**
  - x(aclTensorList\*,计算输入):必选参数,Device侧的aclTensorList,公式中的输入x,数据类型支持FLOAT16、BFLOAT16、INT8、FLOAT32,[数据格式](common/数据格式.md)支持ND,支持的最大长度为128个。
  - 
weight(aclTensorList\*,计算输入):必选参数,Device侧的aclTensorList,公式中的weight,数据类型支持FLOAT16、BFLOAT16、INT8、FLOAT32,[数据格式](common/数据格式.md)支持ND,支持的最大长度为128个。 + - biasOptional(aclTensorList\*,计算输入)可选参数,Device侧的aclTensorList,公式中的bias,数据类型支持FLOAT16、FLOAT32、INT32,[数据格式](common/数据格式.md)支持ND,长度与weight相同。 + - scaleOptional(aclTensorList\*,计算输入)可选参数,Device侧的aclTensorList,代表量化参数中的缩放因子,数据类型支持UINT64,[数据格式](common/数据格式.md)支持ND,长度与weight相同。 + - offsetOptional(aclTensorList\*,计算输入)可选参数,Device侧的aclTensorList,代表量化参数中的偏移量,数据类型支持FLOAT32,[数据格式](common/数据格式.md)支持ND,长度与weight相同。 + - antiquantScaleOptional(aclTensorList\*,计算输入)可选参数,Device侧的aclTensorList,代表伪量化参数中的缩放因子,数据类型支持FLOAT16、BFLOAT16,[数据格式](common/数据格式.md)支持ND,长度与weight相同。 + - antiquantOffsetOptional(aclTensorList\*,计算输入)可选参数,Device侧的aclTensorList,代表伪量化参数中的偏移量,数据类型支持FLOAT16、BFLOAT16,[数据格式](common/数据格式.md)支持ND,长度与weight相同。 + - groupListOptional(aclIntArray\*,计算输入):可选参数,Host侧的aclIntArray类型,分组轴方向的matmul索引情况,分组轴由参数groupType表示,数据类型支持INT64,[数据格式](common/数据格式.md)支持ND,长度与weight相同。 + - splitItem(int64\_t,计算输入):整数型参数,代表输出是否要做tensor切分,0/1代表输出为多tensor;2/3代表输出为单tensor,默认值为0。 + - groupType(int64\_t,计算输入):整数型参数,代表需要分组的轴,如矩阵乘为C[m,n]=A[m,k]xB[k,n],则groupType取值-1:不分组,0:m轴分组,1:n轴分组,2:k轴分组,默认值为-1,当前不支持n轴分组。 + - y(aclTensorList\*,计算输出):Device侧的aclTensorList,公式中的输出y,数据类型支持FLOAT16、BFLOAT16、INT8、FLOAT32,[数据格式](common/数据格式.md)支持ND,支持的最大长度为128个。 + - workspaceSize(uint64\_t\*,出参):返回需要在Device侧申请的workspace大小。 + - executor(aclOpExecutor\*\*,出参):返回op执行器,包含了算子计算流程。 + +- **返回值:** + + 返回aclnnStatus状态码,具体参见[aclnn返回码](common/aclnn返回码.md)。 + + ``` + 第一段接口完成入参校验,若出现以下错误码,则对应原因为: + - 返回161001(ACLNN_ERR_PARAM_NULLPTR): + 1.如果传入参数是必选输入、输出或者必选属性,且是空指针。 + 2.传入参数weight的元素存在空指针。 + 3.传入参数x的元素为空指针,且传出参数y的元素不为空指针。 + 4.传入参数x的元素不为空指针,且传出参数y的元素为空指针。 + - 返回161002(ACLNN_ERR_PARAM_INVALID): + 1.x、weight、biasOptional、scaleOptional、offsetOptional、antiquantScaleOptional、antiquantOffsetOptional、groupListOptional、splitItem、groupType、y的数据类型和数据格式不在支持的范围内。 + 2.weight的长度大于128;若bias不为空,bias的长度不等于weight的长度。 + 3.splitItem为2、3的场景,y长度不等于1; + 4.splitItem为0、1的场景,y长度不等于weight的长度,groupListOptional长度不等于weight的长度。 + ``` + +### aclnnGroupedMatmulV2 + +- **参数说明:** + - workspace(void\*,入参):在Device侧申请的workspace内存地址。 + - workspaceSize(uint64\_t,入参):在Device侧申请的workspace大小,由第一段接口aclnnGroupedMatmulV2GetWorkspaceSize获取。 + - executor(aclOpExecutor\*,入参):op执行器,包含了算子计算流程。 + - stream(aclrtStream,入参):指定执行任务的AscendCL stream流。 + +- **返回值:** + + 返回aclnnStatus状态码,具体参见[aclnn返回码](common/aclnn返回码.md)。 + +## 约束与限制 + - 非量化场景支持的输入类型为: + - x为FLOAT16、weight为FLOAT16、biasOptional为FLOAT16、scaleOptional为空、offsetOptional为空、antiquantScaleOptional为空、antiquantOffsetOptional为空、y为FLOAT16; + - x为BFLOAT16、weight为BFLOAT16、biasOptional为FLOAT32、scaleOptional为空、offsetOptional为空、antiquantScaleOptional为空、antiquantOffsetOptional为空、y为BFLOAT16; + - x为FLOAT32、weight为FLOAT32、biasOptional为FLOAT32、scaleOptional为空、offsetOptional为空、antiquantScaleOptional为空、antiquantOffsetOptional为空、y为FLOAT32; + - 量化场景支持的输入类型为: + + - x为INT8、weight为INT8、biasOptional为INT32、scaleOptional为UINT64、offsetOptional为空、antiquantScaleOptional为空、antiquantOffsetOptional为空、y为INT8; + - 伪量化场景支持的输入类型为: + - x为FLOAT16、weight为INT8、biasOptional为FLOAT16、scaleOptional为空,offsetOptional为空,antiquantScaleOptional为FLOAT16、antiquantOffsetOptional为FLOAT16、y为FLOAT16; + - x为BFLOAT16、weight为INT8、biasOptional为FLOAT32、scaleOptional为空,offsetOptional为空,antiquantScaleOptional为BFLOAT16、antiquantOffsetOptional为BFLOAT16、y为BFLOAT16; + - 如果传入groupListOptional,groupListOptional必须为非负递增数列,groupListOptional长度不能为1。 + - 不同groupType支持场景: + - 
量化、伪量化仅支持groupType为-1和0场景。 + - 支持场景中单表示单tensor,多表示多tensor,表示顺序为x、weight、y。例如单多单表示支持x为单tensor、weight多tensor、y单tensor的场景。 + + | groupType | 支持场景 | 场景限制 | + |:---------:|:-------:| :-------| + | -1 | 多多多 |1)仅支持splitItem为0/1
    2)x中tensor支持2-6维,weight中tensor需为2维,y中tensor维度和x保持一致
    3)groupListOptional必须传空 | + | 0 | 单单单 |1)仅支持splitItem为2/3
    2)weight中tensor需为3维,x,y中tensor需为2维
    3)必须传groupListOptional,且最后一个值与x中tensor的第一维相等 | + | 0 | 单多单 |1)仅支持splitItem为2/3
    2)必须传groupListOptional,且最后一个值与x中tensor的第一维相等
    3)x,weight,y中tensor需为2维
    4)weight中每个tensor的N轴必须相等 | + | 0 | 单多多 |1)仅支持splitItem为0/1
    2)必须传groupListOptional,groupListOptional的差值需与y中tensor的第一维一一对应
    3)x,weight,y中tensor需为2维 | + | 0 | 多多单 |1)仅支持splitItem为2/3
    2)x,weight,y中tensor需为2维
    3)weight中每个tensor的N轴必须相等
    4)若传入groupListOptional,groupListOptional的差值需与x中tensor的第一维一一对应 | + | 2 | 单单单 |1)仅支持splitItem为2/3
    2)x,weight中tensor需为2维,y中tensor需为3维
    3)必须传groupListOptional,且最后一个值与x中tensor的第二维相等| + - x和weight中每一组tensor的最后一维大小都应小于65536。$x_i$的最后一维指当属性transpose_x为false时$x_i$的K轴或当transpose_x为true时$x_i$的M轴。$weight_i$的最后一维指当属性transpose_weight为false时$weight_i$的N轴或当transpose_weight为true时$weight_i$的K轴。 + - x和weight中每一组tensor的每一维大小在32字节对齐后都应小于int32的最大值2147483647。 + +## 算子原型 + +```c++ +REG_OP(GroupedMatmul) + .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_BF16, DT_INT8, DT_FLOAT})) + .DYNAMIC_INPUT(weight, TensorType({DT_FLOAT16, DT_BF16, DT_INT8, DT_FLOAT, DT_INT4})) + .DYNAMIC_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) + .DYNAMIC_INPUT(scale, TensorType({DT_UINT64, DT_BF16, DT_FLOAT32})) + .DYNAMIC_INPUT(offset, TensorType({DT_FLOAT32})) + .DYNAMIC_INPUT(antiquant_scale, TensorType({DT_FLOAT16, DT_BF16})) + .DYNAMIC_INPUT(antiquant_offset, TensorType({DT_FLOAT16, DT_BF16})) + .OPTIONAL_INPUT(group_list, TensorType({DT_INT64})) + .OPTIONAL_INPUT(per_token_scale, TensorType({DT_FLOAT})) + .DYNAMIC_OUTPUT(y, TensorType({DT_FLOAT16, DT_BF16, DT_INT8, DT_FLOAT})) + .ATTR(split_item, Int, 0) + .ATTR(dtype, Int, 0) + .ATTR(transpose_weight, Bool, false) + .ATTR(transpose_x, Bool, false) + .ATTR(group_type, Int, -1) + .ATTR(group_list_type, Int, 0) + .ATTR(act_type, Int, 0) + .OP_END_FACTORY_REG(GroupedMatmul) +``` + +参数解释请参见**算子执行接口**。 + +## 调用示例 +- PyTorch框架调用 + + 如果通过PyTorch单算子方式调用该融合算子,则需要参考PyTorch融合算子[torch_npu.npu_grouped_matmul](https://hiascend.com/document/redirect/PyTorchAPI);如果用户定制了该融合算子,则需要参考《Ascend C算子开发》手册[适配PyTorch框架](https://hiascend.com/document/redirect/CannCommunityAscendCInvorkOnNetwork)。 + +- aclnn单算子调用方式 + + 通过aclnn单算子调用示例代码如下,仅供参考,具体编译和执行过程请参考[编译与运行样例](common/编译与运行样例.md)。 + +```c++ +#include +#include +#include "acl/acl.h" +#include "aclnnop/aclnn_grouped_matmul_v2.h" + +#define CHECK_RET(cond, return_expr) \ + do { \ + if (!(cond)) { \ + return_expr; \ + } \ + } while (0) + +#define LOG_PRINT(message, ...) \ + do { \ + printf(message, ##__VA_ARGS__); \ + } while (0) + +int64_t GetShapeSize(const std::vector& shape) { + int64_t shapeSize = 1; + for (auto i : shape) { + shapeSize *= i; + } + return shapeSize; +} + +int Init(int32_t deviceId, aclrtStream* stream) { + // 固定写法,AscendCL初始化 + auto ret = aclInit(nullptr); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret); + ret = aclrtSetDevice(deviceId); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret); return ret); + ret = aclrtCreateStream(stream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret); + return 0; +} + +template +int CreateAclTensor(const std::vector& shape, void** deviceAddr, + aclDataType dataType, aclTensor** tensor) { + auto size = GetShapeSize(shape) * sizeof(T); + // 调用aclrtMalloc申请Device侧内存 + auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret); + + // 调用aclrtMemcpy将Host侧数据拷贝到Device侧内存上 + std::vector hostData(size, 0); + ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. 
ERROR: %d\n", ret); return ret);

  // 计算连续tensor的strides
  std::vector<int64_t> strides(shape.size(), 1);
  for (int64_t i = shape.size() - 2; i >= 0; i--) {
    strides[i] = shape[i + 1] * strides[i + 1];
  }

  // 调用aclCreateTensor接口创建aclTensor
  *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND,
                            shape.data(), shape.size(), *deviceAddr);
  return 0;
}

// 模板参数T表示单个元素的Host侧承载类型,FLOAT16数据可用uint16_t(2字节)占位
template <typename T>
int CreateAclTensorList(const std::vector<std::vector<int64_t>>& shapes, void** deviceAddr,
                        aclDataType dataType, aclTensorList** tensor) {
  int size = shapes.size();
  std::vector<aclTensor*> tensors(size);
  for (int i = 0; i < size; i++) {
    int ret = CreateAclTensor<T>(shapes[i], deviceAddr + i, dataType, tensors.data() + i);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
  }
  *tensor = aclCreateTensorList(tensors.data(), size);
  return ACL_SUCCESS;
}

int main() {
  // 1. (固定写法)device/stream初始化,参考AscendCL对外接口列表
  // 根据自己的实际device填写deviceId
  int32_t deviceId = 0;
  aclrtStream stream;
  auto ret = Init(deviceId, &stream);
  // check根据自己的需要处理
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Init acl failed. ERROR: %d\n", ret); return ret);

  // 2. 构造输入与输出,需要根据API的接口自定义构造
  std::vector<std::vector<int64_t>> xShape = {{1, 16}, {4, 32}};
  std::vector<std::vector<int64_t>> weightShape = {{16, 24}, {32, 16}};
  std::vector<std::vector<int64_t>> biasShape = {{24}, {16}};
  std::vector<std::vector<int64_t>> yShape = {{1, 24}, {4, 16}};
  void* xDeviceAddr[2];
  void* weightDeviceAddr[2];
  void* biasDeviceAddr[2];
  void* yDeviceAddr[2];
  aclTensorList* x = nullptr;
  aclTensorList* weight = nullptr;
  aclTensorList* bias = nullptr;
  aclIntArray* groupedList = nullptr;
  aclTensorList* scale = nullptr;
  aclTensorList* offset = nullptr;
  aclTensorList* antiquantScale = nullptr;
  aclTensorList* antiquantOffset = nullptr;
  aclTensorList* y = nullptr;
  int64_t splitItem = 0;
  int64_t groupType = -1;

  // 创建x aclTensorList(FLOAT16数据以uint16_t占位)
  ret = CreateAclTensorList<uint16_t>(xShape, xDeviceAddr, aclDataType::ACL_FLOAT16, &x);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // 创建weight aclTensorList
  ret = CreateAclTensorList<uint16_t>(weightShape, weightDeviceAddr, aclDataType::ACL_FLOAT16, &weight);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // 创建bias aclTensorList
  ret = CreateAclTensorList<uint16_t>(biasShape, biasDeviceAddr, aclDataType::ACL_FLOAT16, &bias);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // 创建y aclTensorList
  ret = CreateAclTensorList<uint16_t>(yShape, yDeviceAddr, aclDataType::ACL_FLOAT16, &y);
  CHECK_RET(ret == ACL_SUCCESS, return ret);

  uint64_t workspaceSize = 0;
  aclOpExecutor* executor;

  // 3. 调用CANN算子库API
  // 调用aclnnGroupedMatmulV2第一段接口
  ret = aclnnGroupedMatmulV2GetWorkspaceSize(x, weight, bias, scale, offset, antiquantScale, antiquantOffset, groupedList, splitItem, groupType, y, &workspaceSize, &executor);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnGroupedMatmulV2GetWorkspaceSize failed. ERROR: %d\n", ret); return ret);
  // 根据第一段接口计算出的workspaceSize申请device内存
  void* workspaceAddr = nullptr;
  if (workspaceSize > 0) {
    ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret); return ret);
  }
  // 调用aclnnGroupedMatmulV2第二段接口
  ret = aclnnGroupedMatmulV2(workspaceAddr, workspaceSize, executor, stream);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnGroupedMatmulV2 failed. ERROR: %d\n", ret); return ret);

  // 4. (固定写法)同步等待任务执行结束
  ret = aclrtSynchronizeStream(stream);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSynchronizeStream failed. ERROR: %d\n", ret); return ret);

  // 5.
获取输出的值,将Device侧内存上的结果拷贝至Host侧,需要根据具体API的接口定义修改 + for (int i = 0; i < 2; i++) { + auto size = GetShapeSize(yShape[i]); + std::vector resultData(size, 0); + ret = aclrtMemcpy(resultData.data(), size * sizeof(resultData[0]), yDeviceAddr[i], + size * sizeof(resultData[0]), ACL_MEMCPY_DEVICE_TO_HOST); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return ret); + for (int64_t j = 0; j < size; j++) { + LOG_PRINT("result[%ld] is: %d\n", j, resultData[j]); + } + } + + // 6. 释放aclTensor和aclScalar,需要根据具体API的接口定义修改 + aclDestroyTensorList(x); + aclDestroyTensorList(weight); + aclDestroyTensorList(bias); + aclDestroyTensorList(y); + + // 7. 释放device资源,需要根据具体API的接口定义修改 + for (int i = 0; i < 2; i++) { + aclrtFree(xDeviceAddr[i]); + aclrtFree(weightDeviceAddr[i]); + aclrtFree(biasDeviceAddr[i]); + aclrtFree(yDeviceAddr[i]); + } + if (workspaceSize > 0) { + aclrtFree(workspaceAddr); + } + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + return 0; +} +``` + diff --git a/docs/GroupedMatmulV3.md b/docs/GroupedMatmulV3.md new file mode 100644 index 00000000..fd282f58 --- /dev/null +++ b/docs/GroupedMatmulV3.md @@ -0,0 +1,409 @@ +声明:本文使用[Creative Commons License version 4.0](https://creativecommons.org/licenses/by/4.0/legalcode)许可协议,转载、引用或修改等操作请遵循此许可协议。 + +# GroupedMatmulV3 + +## 支持的产品型号 + +- Atlas A2 训练系列产品/Atlas 800I A2 推理产品 +- Atlas 推理系列产品 + +产品形态详细说明请参见[昇腾产品形态说明](https://www.hiascend.com/document/redirect/CannCommunityProductForm) + +## 功能描述 + +- 算子功能:实现分组矩阵乘计算,每组矩阵乘的维度大小可以不同。基本功能为矩阵乘,如$y_i[m_i,n_i]=x_i[m_i,k_i] \times weight_i[k_i,n_i], i=1...g$,其中g为分组个数,$m_i/k_i/n_i$为应对shape。输入输出数据类型均为aclTensorList,支持aclTensorList长度为1,对应的功能为: + + - k轴分组:$k_i$各不相同,但$m_i/n_i$每组相同,此时$x_i/weight_i$可以在$k_i$上拼接,对应aclTensorList长度为1。 + - m轴分组:$k_i$各组相同,$weight_i/y_i$可以在$n_i$上拼接。 + - n轴分组:$k_i$各组相同,$x_i/weight_i$分别可以在$m_i/n_i$上拼接。 + + - Atlas A2 训练系列产品/Atlas 800I A2 推理产品: + + 相较于[GroupedMatmul](GroupedMatmul.md)接口,**此接口新增:** + - 非量化场景,支持weight转置(转置指若shape为[M,K]时,则stride为[1,M],数据排布为[K,M]的场景)。 + - 支持m轴和k轴分组,由groupType表示。 + - x、weight、y都为单tensor非量化场景支持x,weight输入都为float32类型。 + - 量化、伪量化场景,支持weight转置,支持weight为单tensor。 + - 对于[aclnnGroupedMatmulGetWorkspaceSize](GroupedMatmul.md)接口支持的特性,该接口不支持x为单tensor,weight/y为多tensor场景。 + + **说明:** + + - 单tensor指一个tensor list中所有分组的tensor在groupType指定的分组轴上合并为1个;否则为多tensor。 + - tensor转置:指若tensor shape为[M,K]时,则stride为[1,M],数据排布为[K,M]的场景,即非连续tensor。 + +- 计算公式: + - **非量化场景:** + + $$ + y_i=x_i\times weight_i + bias_i + $$ + + - **量化场景:** + + $$ + y_i=(x_i\times weight_i + bias_i) * scale_i + offset_i + $$ + + - **反量化场景:** + + $$ + y_i=(x_i\times weight_i + bias_i) * scale_i + $$ + + - **伪量化场景:** + + $$ + y_i=x_i\times (weight_i + antiquant\_offset_i) * antiquant\_scale_i + bias_i + $$ + +## 实现原理 + +详细实现原理参考[GroupedMatmul设计](./common/GroupedMatmul算子设计介绍.md)。 + +## 算子执行接口 + +每个算子分为[两段式接口](common/两段式接口.md),必须先调用“aclnnGroupedMatmulV3GetWorkspaceSize”接口获取入参并根据计算流程计算所需workspace大小,再调用“aclnnGroupedMatmulV3”接口执行计算。 + +* `aclnnStatus aclnnGroupedMatmulV3GetWorkspaceSize(const aclTensorList* x, const aclTensorList* weight, const aclTensorList* biasOptional, const aclTensorList* scaleOptional, const aclTensorList* offsetOptional, const aclTensorList* antiquantScaleOptional, const aclTensorList* antiquantOffsetOptional, const aclTensor* groupListOptional, int64_t splitItem, int64_t groupType, const aclTensorList* y, uint64_t* workspaceSize, aclOpExecutor** executor)` +* `aclnnStatus aclnnGroupedMatmulV3(void* workspace, uint64_t 
workspaceSize, aclOpExecutor* executor, aclrtStream stream)` + +**说明**: + +- 算子执行接口对外屏蔽了算子内部实现逻辑以及不同代际NPU的差异,且开发者无需编译算子,实现了算子的精简调用。 +- 若开发者不使用算子执行接口的调用算子,也可以定义基于Ascend IR的算子描述文件,通过ATC工具编译获得算子om文件,然后加载模型文件执行算子,详细调用方法可参见《应用开发指南》的[单算子调用 > 单算子模型执行](https://hiascend.com/document/redirect/CannCommunityCppOpcall)章节。 + +### aclnnGroupedMatmulV3GetWorkspaceSize + +- **参数说明:** + - x(aclTensorList\*,计算输入):必选参数,Device侧的aclTensorList,公式中的输入x,[数据格式](common/数据格式.md)支持ND,支持的最大长度为128个。 + - Atlas A2 训练系列产品/Atlas 800I A2 推理产品:数据类型支持FLOAT16、BFLOAT16、INT8、FLOAT32。 + - Atlas 推理系列产品:数据类型支持FLOAT16。 + - weight(aclTensorList\*,计算输入):必选参数,Device侧的aclTensorList,公式中的weight,[数据格式](common/数据格式.md)支持ND,支持的最大长度为128个。 + - Atlas A2 训练系列产品/Atlas 800I A2 推理产品:数据类型支持FLOAT16、BFLOAT16、INT8、FLOAT32。 + - Atlas 推理系列产品:数据类型支持FLOAT16。 + - biasOptional(aclTensorList\*,计算输入)可选参数,Device侧的aclTensorList,公式中的bias,[数据格式](common/数据格式.md)支持ND,长度与weight相同。 + - Atlas A2 训练系列产品/Atlas 800I A2 推理产品:数据类型支持FLOAT16、FLOAT32、INT32。 + - Atlas 推理系列产品:数据类型支持FLOAT16。 + - scaleOptional(aclTensorList\*,计算输入)可选参数,Device侧的aclTensorList,代表量化参数中的缩放因子,数据类型支持UINT64,[数据格式](common/数据格式.md)支持ND,长度与weight相同。 + - Atlas 推理系列产品:功能暂不支持,需传空指针。 + - offsetOptional(aclTensorList\*,计算输入)可选参数,Device侧的aclTensorList,代表量化参数中的偏移量,数据类型支持FLOAT32,[数据格式](common/数据格式.md)支持ND,长度与weight相同。 + - Atlas 推理系列产品:功能暂不支持,需传空指针。 + - antiquantScaleOptional(aclTensorList\*,计算输入)可选参数,Device侧的aclTensorList,代表伪量化参数中的缩放因子,[数据格式](common/数据格式.md)支持ND,长度与weight相同。 + - Atlas A2 训练系列产品/Atlas 800I A2 推理产品:数据类型支持FLOAT16、BFLOAT16。 + - Atlas 推理系列产品:数据类型支持FLOAT16,功能暂不支持,需传空指针。 + - antiquantOffsetOptional(aclTensorList\*,计算输入)可选参数,Device侧的aclTensorList,代表伪量化参数中的偏移量,[数据格式](common/数据格式.md)支持ND,长度与weight相同。 + - Atlas A2 训练系列产品/Atlas 800I A2 推理产品:数据类型支持FLOAT16、BFLOAT16。 + - Atlas 推理系列产品:数据类型支持FLOAT16,功能暂不支持,需传空指针。 + - groupListOptional(aclTensor\*,计算输入):可选参数,Host侧的aclTensor类型,代表输入和输出分组轴方向的matmul大小分布,数据类型支持INT64,[数据格式](common/数据格式.md)支持ND。 + - splitItem(int64\_t,计算输入):整数型参数,代表输出是否要做tensor切分,0/1代表输出为多tensor;2/3代表输出为单tensor。 + - groupType(int64\_t,计算输入):整数型参数,代表需要分组的轴,如矩阵乘为C[m,n]=A[m,k]xB[k,n],则groupType取值-1:不分组,0:m轴分组,1:n轴分组,2:k轴分组。 + - Atlas A2 训练系列产品/Atlas 800I A2 推理产品:当前不支持n轴分组。 + - Atlas 推理系列产品:当前只支持m轴分组。 + - y(aclTensorList\*,计算输出):Device侧的aclTensorList,公式中的输出y,[数据格式](common/数据格式.md)支持ND,支持的最大长度为128个。 + - Atlas A2 训练系列产品/Atlas 800I A2 推理产品:数据类型支持FLOAT16、BFLOAT16、INT8、FLOAT32。 + - Atlas 推理系列产品:数据类型支持FLOAT16。 + - workspaceSize(uint64\_t\*,出参):返回需要在Device侧申请的workspace大小。 + - executor(aclOpExecutor\*\*,出参):返回op执行器,包含了算子计算流程。 + +- **返回值:** + + 返回aclnnStatus状态码,具体参见[aclnn返回码](common/aclnn返回码.md)。 + + ``` + 第一段接口完成入参校验,若出现以下错误码,则对应原因为: + - 返回161001(ACLNN_ERR_PARAM_NULLPTR): + 1.如果传入参数是必选输入、输出或者必选属性,且是空指针。 + 2.传入参数weight的元素存在空指针。 + 3.传入参数x的元素为空指针,且传出参数y的元素不为空指针。 + 4.传入参数x的元素不为空指针,且传出参数y的元素为空指针。 + - 返回161002(ACLNN_ERR_PARAM_INVALID): + 1.x、weight、biasOptional、scaleOptional、offsetOptional、antiquantScaleOptional、antiquantOffsetOptional、groupListOptional、splitItem、groupType、y的数据类型和数据格式不在支持的范围内。 + 2.weight的长度大于128;若bias不为空,bias的长度不等于weight的长度。 + 3.groupListOptional维度为1。 + 4.splitItem为2、3的场景,y长度不等于1。 + 5.splitItem为0、1的场景,y长度不等于weight的长度,groupListOptional长度不等于weight的长度。 + ``` + +### aclnnGroupedMatmulV3 + +- **参数说明:** + - workspace(void\*,入参):在Device侧申请的workspace内存地址。 + - workspaceSize(uint64\_t,入参):在Device侧申请的workspace大小,由第一段接口aclnnGroupedMatmulV3GetWorkspaceSize获取。 + - executor(aclOpExecutor\*,入参):op执行器,包含了算子计算流程。 + - stream(aclrtStream,入参):指定执行任务的AscendCL stream流。 + +- **返回值:** + + 
返回aclnnStatus状态码,具体参见[aclnn返回码](common/aclnn返回码.md)。 + +## 约束与限制 + - 如果传入groupListOptional,groupListOptional必须为非负递增数列,groupListOptional长度不能为1。 + - x和weight中每一组tensor的每一维大小在32字节对齐后都应小于int32的最大值2147483647。 + - Atlas A2 训练系列产品/Atlas 800I A2 推理产品: + - 非量化场景支持的输入类型为: + - x为FLOAT16、weight为FLOAT16、biasOptional为FLOAT16、scaleOptional为空、offsetOptional为空、antiquantScaleOptional为空、antiquantOffsetOptional为空、y为FLOAT16; + - x为BFLOAT16、weight为BFLOAT16、biasOptional为FLOAT32、scaleOptional为空、offsetOptional为空、antiquantScaleOptional为空、antiquantOffsetOptional为空、y为BFLOAT16; + - x为FLOAT32、weight为FLOAT32、biasOptional为FLOAT32、scaleOptional为空、offsetOptional为空、antiquantScaleOptional为空、antiquantOffsetOptional为空、y为FLOAT32(仅x、weight、y都为单tensor场景支持); + - 量化场景支持的输入类型为: + + - x为INT8、weight为INT8、biasOptional为INT32、scaleOptional为UINT64、offsetOptional为空、antiquantScaleOptional为空、antiquantOffsetOptional为空、y为INT8; + - 伪量化场景支持的输入类型为: + - x为FLOAT16、weight为INT8、biasOptional为FLOAT16、scaleOptional为空,offsetOptional为空,antiquantScaleOptional为FLOAT16、antiquantOffsetOptional为FLOAT16、y为FLOAT16; + - x为BFLOAT16、weight为INT8、biasOptional为FLOAT32、scaleOptional为空,offsetOptional为空,antiquantScaleOptional为BFLOAT16、antiquantOffsetOptional为BFLOAT16、y为BFLOAT16; + - 不同groupType支持场景: + - 量化、伪量化仅支持groupType为-1和0场景。 + - 支持场景中单表示单tensor,多表示多tensor,表示顺序为x,weight,y,例如单多单表示支持x为单tensor,weight多tensor,y单tensor的场景。 + | groupType | 支持场景 | 场景限制 | + |:---------:|:-------:| :-------| + | -1 | 多多多 |1)仅支持splitItem为0/1
    2)x中tensor支持2-6维,weight中tensor需为2维,y中tensor维度和x保持一致
    3)groupListOptional必须传空
    4)支持weight转置,但weight的tensorList中每个tensor是否转置需保持统一
    5)x不支持转置 | + | 0 | 单单单 |1)仅支持splitItem为2/3
    2)weight中tensor需为3维,x,y中tensor需为2维
    3)必须传groupListOptional,且当groupListType为0时,最后一个值与x中tensor的第一维相等,当groupListType为1时,数值的总和与x中tensor的第一维相等
    4)groupListOptional第1维最大支持1024,即最多支持1024个group
    5)支持weight转置
    6)x不支持转置 | + | 0 | 单多单 |1)仅支持splitItem为2/3
    2)必须传groupListOptional,且当groupListType为0时,最后一个值与x中tensor的第一维相等,当groupListType为1时,数值的总和与x中tensor的第一维相等,长度最大为128
    3)x,weight,y中tensor需为2维
    4)weight中每个tensor的N轴必须相等
    5)支持weight转置,但weight的tensorList中每个tensor是否转置需保持统一
    6)x不支持转置 | + | 0 | 多多单 |1)仅支持splitItem为2/3
    2)x,weight,y中tensor需为2维
    3)weight中每个tensor的N轴必须相等
    4)若传入groupListOptional,当groupListType为0时,groupListOptional的差值需与x中tensor的第一维一一对应,当groupListType为1时,groupListOptional的数值需与x中tensor的第一维一一对应,且长度最大为128
    5)支持weight转置,但weight的tensorList中每个tensor是否转置需保持统一
    6)x不支持转置 | + | 2 | 单单单 |1)仅支持splitItem为2/3
    2)x,weight中tensor需为2维,y中tensor需为3维
    3)必须传groupListOptional,且当groupListType为0时,最后一个值与x中tensor的第二维相等,当groupListType为1时,数值的总和与x中tensor的第二维相等
    4)groupListOptional第1维最大支持1024,即最多支持1024个group
    5)x必须转置,weight不能转置 | + + - x和weight中每一组tensor的最后一维大小都应小于65536。$x_i$的最后一维指当x不转置时$x_i$的K轴或当x转置时$x_i$的M轴。$weight_i$的最后一维指当weight不转置时$weight_i$的N轴或当weight转置时$weight_i$的K轴。 + - Atlas 推理系列产品: + - 输入输出只支持float16的数据类型,输出y的n轴大小需要是16的倍数。 + | groupType | 支持场景 | 场景限制 | + |:---------:|:-------:| :------ | + | 0 | 单单单 |1)仅支持splitItem为2/3
    2)weight中tensor需为3维,x,y中tensor需为2维
    3)必须传groupListOptional,且当groupListType为0时,最后一个值与x中tensor的第一维相等,当groupListType为1时,数值的总和与x中tensor的第一维相等
    4)groupListOptional第1维最大支持1024,即最多支持1024个group
    5)支持weight转置,不支持x转置 | + +## 算子原型 + +```c++ +REG_OP(GroupedMatmul) + .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_BF16, DT_INT8, DT_FLOAT})) + .DYNAMIC_INPUT(weight, TensorType({DT_FLOAT16, DT_BF16, DT_INT8, DT_FLOAT, DT_INT4})) + .DYNAMIC_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) + .DYNAMIC_INPUT(scale, TensorType({DT_UINT64, DT_BF16, DT_FLOAT32})) + .DYNAMIC_INPUT(offset, TensorType({DT_FLOAT32})) + .DYNAMIC_INPUT(antiquant_scale, TensorType({DT_FLOAT16, DT_BF16})) + .DYNAMIC_INPUT(antiquant_offset, TensorType({DT_FLOAT16, DT_BF16})) + .OPTIONAL_INPUT(group_list, TensorType({DT_INT64})) + .OPTIONAL_INPUT(per_token_scale, TensorType({DT_FLOAT})) + .DYNAMIC_OUTPUT(y, TensorType({DT_FLOAT16, DT_BF16, DT_INT8, DT_FLOAT})) + .ATTR(split_item, Int, 0) + .ATTR(dtype, Int, 0) + .ATTR(transpose_weight, Bool, false) + .ATTR(transpose_x, Bool, false) + .ATTR(group_type, Int, -1) + .ATTR(group_list_type, Int, 0) + .ATTR(act_type, Int, 0) + .OP_END_FACTORY_REG(GroupedMatmul) +``` + +参数解释请参见**算子执行接口**。 + +## 调用示例 +- PyTorch框架调用 + + 如果通过PyTorch单算子方式调用该融合算子,则需要参考PyTorch融合算子[torch_npu.npu_grouped_matmul](https://hiascend.com/document/redirect/PyTorchAPI);如果用户定制了该融合算子,则需要参考《Ascend C算子开发》手册[适配PyTorch框架](https://hiascend.com/document/redirect/CannCommunityAscendCInvorkOnNetwork)。 + +- aclnn单算子调用方式 + + 通过aclnn单算子调用示例代码如下,仅供参考,具体编译和执行过程请参考[编译与运行样例](common/编译与运行样例.md)。 + +```c++ +#include +#include +#include "acl/acl.h" +#include "aclnnop/aclnn_grouped_matmul_v3.h" + +#define CHECK_RET(cond, return_expr) \ + do { \ + if (!(cond)) { \ + return_expr; \ + } \ + } while (0) + +#define LOG_PRINT(message, ...) \ + do { \ + printf(message, ##__VA_ARGS__); \ + } while (0) + +int64_t GetShapeSize(const std::vector& shape) { + int64_t shapeSize = 1; + for (auto i : shape) { + shapeSize *= i; + } + return shapeSize; +} + +int Init(int32_t deviceId, aclrtStream* stream) { + // 固定写法,AscendCL初始化 + auto ret = aclInit(nullptr); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret); + ret = aclrtSetDevice(deviceId); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret); return ret); + ret = aclrtCreateStream(stream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret); + return 0; +} + +template +int CreateAclTensor_New(const std::vector& hostData, const std::vector& shape, void** deviceAddr, + aclDataType dataType, aclTensor** tensor) { + auto size = GetShapeSize(shape) * sizeof(T); + // 调用aclrtMalloc申请Device侧内存 + auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret); + + // 调用aclrtMemcpy将Host侧数据拷贝到Device侧内存上 + ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. 
ERROR: %d\n", ret); return ret);

  // 计算连续tensor的strides
  std::vector<int64_t> strides(shape.size(), 1);
  for (int64_t i = shape.size() - 2; i >= 0; i--) {
    strides[i] = shape[i + 1] * strides[i + 1];
  }

  // 调用aclCreateTensor接口创建aclTensor
  *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND,
                            shape.data(), shape.size(), *deviceAddr);
  return 0;
}

// 模板参数T表示单个元素的Host侧承载类型,FLOAT16数据可用uint16_t(2字节)占位
template <typename T>
int CreateAclTensor(const std::vector<int64_t>& shape, void** deviceAddr,
                    aclDataType dataType, aclTensor** tensor) {
  auto size = GetShapeSize(shape) * sizeof(T);
  // 调用aclrtMalloc申请Device侧内存
  auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret);

  // 调用aclrtMemcpy将Host侧数据拷贝到Device侧内存上
  std::vector<char> hostData(size, 0);
  ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);

  // 计算连续tensor的strides
  std::vector<int64_t> strides2(shape.size(), 1);
  for (int64_t i = shape.size() - 2; i >= 0; i--) {
    strides2[i] = shape[i + 1] * strides2[i + 1];
  }

  // 调用aclCreateTensor接口创建aclTensor
  *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides2.data(), 0, aclFormat::ACL_FORMAT_ND,
                            shape.data(), shape.size(), *deviceAddr);
  return 0;
}

template <typename T>
int CreateAclTensorList(const std::vector<std::vector<int64_t>>& shapes, void** deviceAddr,
                        aclDataType dataType, aclTensorList** tensor) {
  int size = shapes.size();
  std::vector<aclTensor*> tensors(size);
  for (int i = 0; i < size; i++) {
    int ret = CreateAclTensor<T>(shapes[i], deviceAddr + i, dataType, tensors.data() + i);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
  }
  *tensor = aclCreateTensorList(tensors.data(), size);
  return ACL_SUCCESS;
}

int main() {
  // 1. (固定写法)device/stream初始化,参考AscendCL对外接口列表
  // 根据自己的实际device填写deviceId
  int32_t deviceId = 0;
  aclrtStream stream;
  auto ret = Init(deviceId, &stream);
  // check根据自己的需要处理
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Init acl failed. ERROR: %d\n", ret); return ret);

  // 2. 构造输入与输出,需要根据API的接口自定义构造
  std::vector<std::vector<int64_t>> xShape = {{512, 256}};
  std::vector<std::vector<int64_t>> weightShape = {{2, 256, 256}};
  std::vector<std::vector<int64_t>> biasShape = {{2, 256}};
  std::vector<std::vector<int64_t>> yShape = {{512, 256}};
  std::vector<int64_t> groupListShape = {2};
  std::vector<int64_t> groupListData = {256, 512};
  void* xDeviceAddr[1];
  void* weightDeviceAddr[1];
  void* biasDeviceAddr[1];
  void* yDeviceAddr[1];
  void* groupListDeviceAddr;
  aclTensorList* x = nullptr;
  aclTensorList* weight = nullptr;
  aclTensorList* bias = nullptr;
  aclTensor* groupedList = nullptr;
  aclTensorList* scale = nullptr;
  aclTensorList* offset = nullptr;
  aclTensorList* antiquantScale = nullptr;
  aclTensorList* antiquantOffset = nullptr;
  aclTensorList* y = nullptr;
  int64_t splitItem = 2;
  int64_t groupType = 0;

  // 创建x aclTensorList(FLOAT16数据以uint16_t占位)
  ret = CreateAclTensorList<uint16_t>(xShape, xDeviceAddr, aclDataType::ACL_FLOAT16, &x);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // 创建weight aclTensorList
  ret = CreateAclTensorList<uint16_t>(weightShape, weightDeviceAddr, aclDataType::ACL_FLOAT16, &weight);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // 创建bias aclTensorList
  ret = CreateAclTensorList<uint16_t>(biasShape, biasDeviceAddr, aclDataType::ACL_FLOAT16, &bias);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // 创建y aclTensorList
  ret = CreateAclTensorList<uint16_t>(yShape, yDeviceAddr, aclDataType::ACL_FLOAT16, &y);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // 创建group_list aclTensor
  ret = CreateAclTensor_New(groupListData, groupListShape, &groupListDeviceAddr, aclDataType::ACL_INT64, &groupedList);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  uint64_t workspaceSize = 0;
  aclOpExecutor* executor;
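  // 补充说明(示意,非算子库要求的固定写法):本例groupType=0,groupedList取{256, 512},
  // 表示M轴按累加偏移分组:第1组对应x的第0~255行,第2组对应第256~511行,
  // 分别与weight[0]、weight[1](shape均为[256, 256])做矩阵乘。
  // 按约束,groupList应为非负且逐项不减的数列,最后一个值等于x的第一维(本例为512)。
  int64_t prevOffset = 0;
  for (int64_t v : groupListData) {
    CHECK_RET(v >= prevOffset, LOG_PRINT("groupList must be non-negative and non-decreasing\n"); return -1);
    prevOffset = v;
  }
  CHECK_RET(prevOffset == xShape[0][0], LOG_PRINT("last groupList value must equal M of x\n"); return -1);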
+int main() {
+  // 1. (Fixed pattern) device/stream initialization; see the AscendCL API reference
+  // Set deviceId according to the actual device in use
+  int32_t deviceId = 0;
+  aclrtStream stream;
+  auto ret = Init(deviceId, &stream);
+  // Handle the check as needed for your application
+  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Init acl failed. ERROR: %d\n", ret); return ret);
+
+  // 2. Build the inputs and outputs; adapt this to the API being called
+  std::vector<std::vector<int64_t>> xShape = {{512, 256}};
+  std::vector<std::vector<int64_t>> weightShape = {{2, 256, 256}};
+  std::vector<std::vector<int64_t>> biasShape = {{2, 256}};
+  std::vector<std::vector<int64_t>> yShape = {{512, 256}};
+  std::vector<int64_t> groupListShape = {2};
+  std::vector<int64_t> groupListData = {256, 512};
+  void* xDeviceAddr[1];
+  void* weightDeviceAddr[1];
+  void* biasDeviceAddr[1];
+  void* yDeviceAddr[1];
+  void* groupListDeviceAddr;
+  aclTensorList* x = nullptr;
+  aclTensorList* weight = nullptr;
+  aclTensorList* bias = nullptr;
+  aclTensor* groupedList = nullptr;
+  aclTensorList* scale = nullptr;
+  aclTensorList* offset = nullptr;
+  aclTensorList* antiquantScale = nullptr;
+  aclTensorList* antiquantOffset = nullptr;
+  aclTensorList* y = nullptr;
+  int64_t splitItem = 2;
+  int64_t groupType = 0;
+
+  // Create the x aclTensorList (int16_t elements hold the raw FLOAT16 bits)
+  ret = CreateAclTensorList<int16_t>(xShape, xDeviceAddr, aclDataType::ACL_FLOAT16, &x);
+  CHECK_RET(ret == ACL_SUCCESS, return ret);
+  // Create the weight aclTensorList
+  ret = CreateAclTensorList<int16_t>(weightShape, weightDeviceAddr, aclDataType::ACL_FLOAT16, &weight);
+  CHECK_RET(ret == ACL_SUCCESS, return ret);
+  // Create the bias aclTensorList
+  ret = CreateAclTensorList<int16_t>(biasShape, biasDeviceAddr, aclDataType::ACL_FLOAT16, &bias);
+  CHECK_RET(ret == ACL_SUCCESS, return ret);
+  // Create the y aclTensorList
+  ret = CreateAclTensorList<int16_t>(yShape, yDeviceAddr, aclDataType::ACL_FLOAT16, &y);
+  CHECK_RET(ret == ACL_SUCCESS, return ret);
+  // Create the group_list aclTensor
+  ret = CreateAclTensor_New(groupListData, groupListShape, &groupListDeviceAddr, aclDataType::ACL_INT64, &groupedList);
+  CHECK_RET(ret == ACL_SUCCESS, return ret);
+  uint64_t workspaceSize = 0;
+  aclOpExecutor* executor;
+
+  // 3. Call the CANN operator library API
+  // Call the first-stage interface aclnnGroupedMatmulV3GetWorkspaceSize
+  ret = aclnnGroupedMatmulV3GetWorkspaceSize(x, weight, bias, scale, offset, antiquantScale, antiquantOffset, groupedList, splitItem, groupType, y, &workspaceSize, &executor);
+  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnGroupedMatmulV3GetWorkspaceSize failed. ERROR: %d\n", ret); return ret);
+  // Allocate device memory using the workspaceSize returned by the first-stage interface
+  void* workspaceAddr = nullptr;
+  if (workspaceSize > 0) {
+    ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret); return ret);
+  }
+  // Call the second-stage interface aclnnGroupedMatmulV3
+  ret = aclnnGroupedMatmulV3(workspaceAddr, workspaceSize, executor, stream);
+  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnGroupedMatmulV3 failed. ERROR: %d\n", ret); return ret);
+
+  // 4. (Fixed pattern) wait synchronously for the task to finish
+  ret = aclrtSynchronizeStream(stream);
+  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSynchronizeStream failed. ERROR: %d\n", ret); return ret);
+
+  // 5. Fetch the output: copy the result from device memory to the host; adapt to the API being called
+  for (int i = 0; i < 1; i++) {
+    auto size = GetShapeSize(yShape[i]);
+    std::vector<int16_t> resultData(size, 0);
+    ret = aclrtMemcpy(resultData.data(), size * sizeof(resultData[0]), yDeviceAddr[i],
+                      size * sizeof(resultData[0]), ACL_MEMCPY_DEVICE_TO_HOST);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return ret);
+    for (int64_t j = 0; j < size; j++) {
+      LOG_PRINT("result[%ld] is: %d\n", j, resultData[j]);
+    }
+  }
+
+  // 6. Destroy the aclTensor/aclTensorList objects; adapt to the API being called
+  aclDestroyTensorList(x);
+  aclDestroyTensorList(weight);
+  aclDestroyTensorList(bias);
+  aclDestroyTensorList(y);
+
+  // 7. Release device resources; adapt to the API being called
+  for (int i = 0; i < 1; i++) {
+    aclrtFree(xDeviceAddr[i]);
+    aclrtFree(weightDeviceAddr[i]);
+    aclrtFree(biasDeviceAddr[i]);
+    aclrtFree(yDeviceAddr[i]);
+  }
+  if (workspaceSize > 0) {
+    aclrtFree(workspaceAddr);
+  }
+  aclrtDestroyStream(stream);
+  aclrtResetDevice(deviceId);
+  aclFinalize();
+  return 0;
+}
+```
+
diff --git a/docs/GroupedMatmulV4.md b/docs/GroupedMatmulV4.md
new file mode 100644
index 00000000..82082dc8
--- /dev/null
+++ b/docs/GroupedMatmulV4.md
@@ -0,0 +1,453 @@
+Statement: this document is licensed under [Creative Commons License version 4.0](https://creativecommons.org/licenses/by/4.0/legalcode); please comply with this license when reproducing, citing, or modifying it.
+
+# GroupedMatmulV4
+
+## Supported Products
+
+- Atlas A2 training series products / Atlas 800I A2 inference products
+- Atlas inference series products
+
+For a detailed description of the product forms, see the [Ascend product form description](https://www.hiascend.com/document/redirect/CannCommunityProductForm).
+
+## Function Description
+
+- Operator function: performs grouped matrix multiplication, where each group's matmul may have different dimension sizes. The basic operation is the matrix product $y_i[m_i,n_i]=x_i[m_i,k_i] \times weight_i[k_i,n_i], i=1...g$, where g is the number of groups and $m_i/k_i/n_i$ are the corresponding shapes. All inputs and outputs are of type aclTensorList, and an aclTensorList of length 1 is supported, corresponding to:
+
+  - k-axis grouping: the $k_i$ differ between groups, but $m_i/n_i$ are the same in every group, so the $x_i/weight_i$ can be concatenated along $k_i$, giving an aclTensorList of length 1.
+  - m-axis grouping: the $k_i$ are the same in every group, and $weight_i/y_i$ can be concatenated along $n_i$.
+  - n-axis grouping: the $k_i$ are the same in every group, and $x_i/weight_i$ can be concatenated along $m_i/n_i$ respectively.
+
+  Compared with the [GroupedMatmulV3](GroupedMatmulV3.md) interface, **this interface adds:**
+  - Support for groupListOptional values that give the per-group sizes along the grouping axis.
+  - Atlas A2 training series products / Atlas 800I A2 inference products:
+    - Static quantization (per-tensor + per-channel) with bfloat16 and float16 output, with or without activation.
+    - Dynamic quantization (per-token + per-channel) with bfloat16 and float16 output, with or without activation.
+    - Pseudo-quantization with INT4 weight input, without activation, in both per-channel and per-group modes.
+
+  **Note:**
+  - "Single tensor" means that all per-group tensors of a tensor list are merged into one along the grouping axis specified by groupType; otherwise it is "multiple tensors".
+  - Tensor transpose: for a tensor of shape [M,K], the case where the stride is [1,M] and the data layout is [K,M], i.e. a non-contiguous tensor.
+
+- Formulas:
+  - **Non-quantized case:**
+
+    $$
+    y_i=x_i\times weight_i + bias_i
+    $$
+
+  - **Quantized case (per-channel):**
+
+    $$
+    y_i=(x_i\times weight_i + bias_i) * scale_i + offset_i
+    $$
+
+  - **Quantized case (per-token):**
+
+    $$
+    y_i=(x_i\times weight_i + bias_i) * scale_i * per\_token\_scale_i
+    $$
+
+  - **Dequantized case:**
+
+    $$
+    y_i=(x_i\times weight_i + bias_i) * scale_i
+    $$
+
+  - **Pseudo-quantized case:**
+
+    $$
+    y_i=x_i\times (weight_i + antiquant\_offset_i) * antiquant\_scale_i + bias_i
+    $$
+
+## Implementation
+
+For the detailed implementation, see the [GroupedMatmul design notes](./common/GroupedMatmul算子设计介绍.md).
+
+## Operator Execution API
+
+Each operator uses a [two-stage interface](common/两段式接口.md): first call "aclnnGroupedMatmulV4GetWorkspaceSize" to take the input arguments and compute the workspace size required by the computation flow, then call "aclnnGroupedMatmulV4" to execute the computation.
+
+* `aclnnStatus aclnnGroupedMatmulV4GetWorkspaceSize(const aclTensorList *x, const aclTensorList *weight, const aclTensorList *biasOptional, const aclTensorList *scaleOptional, const aclTensorList *offsetOptional, const aclTensorList *antiquantScaleOptional, const aclTensorList *antiquantOffsetOptional, const aclTensorList *perTokenScaleOptional, const aclTensor *groupListOptional, const aclTensorList *activationInputOptional, const aclTensorList *activationQuantScaleOptional, const aclTensorList *activationQuantOffsetOptional, int64_t splitItem, int64_t groupType, int64_t groupListType, int64_t actType, aclTensorList *out, aclTensorList *activationFeatureOutOptional, aclTensorList *dynQuantScaleOutOptional, uint64_t *workspaceSize, aclOpExecutor **executor)`
+* `aclnnStatus aclnnGroupedMatmulV4(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor, aclrtStream stream)`
+
+**Note:**
+
+- The execution API hides the operator's internal implementation and the differences between NPU generations, and developers do not need to compile the operator, so the operator can be invoked in a streamlined way.
+- Developers who do not call the operator through the execution API can instead write an Ascend IR operator description, compile an operator .om file with the ATC tool, and then load the model file to execute the operator; see the "Application Development Guide" section [Single-operator invocation > Single-operator model execution](https://hiascend.com/document/redirect/CannCommunityCppOpcall).
+
+### aclnnGroupedMatmulV4GetWorkspaceSize
+
+- **Parameter description:**
+  - x (aclTensorList\*, input): required; device-side aclTensorList, the x in the formulas. The [data format](common/数据格式.md) supports ND; at most 128 tensors are supported.
+    - Atlas A2 training series products / Atlas 800I A2 inference products: data types FLOAT16, BFLOAT16, INT8, FLOAT32.
+    - Atlas inference series products: data type FLOAT16.
+  - weight (aclTensorList\*, input): required; device-side aclTensorList, the weight in the formulas. The [data format](common/数据格式.md) supports ND; at most 128 tensors are supported.
+    - Atlas A2 training series products / Atlas 800I A2 inference products: data types FLOAT16, BFLOAT16, INT8, FLOAT32, INT4.
+    - Atlas inference series products: data type FLOAT16.
+  - biasOptional (aclTensorList\*, input): optional; device-side aclTensorList, the bias in the formulas. The [data format](common/数据格式.md) supports ND; the length equals that of weight.
+    - Atlas A2 training series products / Atlas 800I A2 inference products: data types FLOAT16, FLOAT32, INT32.
+    - Atlas inference series products: data type FLOAT16.
+  - scaleOptional (aclTensorList\*, input): optional; device-side aclTensorList, the scaling factor of the quantization parameters. The [data format](common/数据格式.md) supports ND; the length equals that of weight.
+    - Atlas A2 training series products / Atlas 800I A2 inference products: data types UINT64, BFLOAT16, FLOAT32.
+    - Atlas inference series products: data type UINT64; this feature is not yet supported, pass a null pointer.
+  - offsetOptional (aclTensorList\*, input): optional; device-side aclTensorList, the offset of the quantization parameters. Data type FLOAT32; the [data format](common/数据格式.md) supports ND; the length equals that of weight.
+    - Atlas inference series products: this feature is not yet supported, pass a null pointer.
+  - antiquantScaleOptional (aclTensorList\*, input): optional; device-side aclTensorList, the scaling factor of the pseudo-quantization parameters. The [data format](common/数据格式.md) supports ND; the length equals that of weight. For the combined constraints, see [Constraints and Limitations](#constraints-and-limitations).
+    - Atlas A2 training series products / Atlas 800I A2 inference products: data types FLOAT16, BFLOAT16.
+    - Atlas inference series products: data type FLOAT16; this feature is not yet supported, pass a null pointer.
+  - antiquantOffsetOptional (aclTensorList\*, input): optional; device-side aclTensorList, the offset of the pseudo-quantization parameters. The [data format](common/数据格式.md) supports ND; the length equals that of weight. For the combined constraints, see [Constraints and Limitations](#constraints-and-limitations).
+    - Atlas A2 training series products / Atlas 800I A2 inference products: data types FLOAT16, BFLOAT16.
+    - Atlas inference series products: data type FLOAT16; this feature is not yet supported, pass a null pointer.
+  - perTokenScaleOptional (aclTensorList\*, input): optional; device-side aclTensorList, the scaling factor introduced by quantizing x. Data type FLOAT32; the [data format](common/数据格式.md) supports ND; only 1-dimensional tensors are supported, and the length equals that of x.
+  - groupListOptional (aclTensor\*, input): optional; host-side aclTensor describing the distribution of matmul sizes along the grouping axis of the inputs and outputs. Data type INT64; the [data format](common/数据格式.md) supports ND.
+  - activationInputOptional (aclTensorList\*, input): optional; host-side aclTensorList, the backward input of the activation function; currently only nullptr may be passed.
+  - activationQuantScaleOptional (aclTensorList\*, input): optional; host-side aclTensorList; currently only nullptr may be passed.
+  - activationQuantOffsetOptional (aclTensorList\*, input): optional; host-side aclTensorList; currently only nullptr may be passed.
+  - splitItem (int64\_t, input): integer specifying whether the output is split into tensors: 0/1 mean the output is multiple tensors; 2/3 mean the output is a single tensor.
+  - groupType (int64\_t, input): integer specifying the axis to group on. For C[m,n]=A[m,k]xB[k,n]: -1 = no grouping, 0 = group on the m axis, 1 = group on the n axis, 2 = group on the k axis.
+    - Atlas A2 training series products / Atlas 800I A2 inference products: n-axis grouping is currently not supported.
+    - Atlas inference series products: only m-axis grouping is currently supported.
+  - groupListType (int64\_t, input): integer, 0 or 1. 0 means the values in groupListOptional are the cumsum (cumulative sums) of the group sizes along the grouping axis; 1 means the values are the per-group sizes. A small conversion sketch follows this parameter list.
+  - actType (int64\_t, input): integer specifying the activation function type.
+    - Atlas A2 training series products / Atlas 800I A2 inference products: range 0-5, with the following enum values:
+      * 0: GMMActType::GMM_ACT_TYPE_NONE;
+      * 1: GMMActType::GMM_ACT_TYPE_RELU;
+      * 2: GMMActType::GMM_ACT_TYPE_GELU_TANH;
+      * 3: GMMActType::GMM_ACT_TYPE_GELU_ERR_FUNC (not supported);
+      * 4: GMMActType::GMM_ACT_TYPE_FAST_GELU;
+      * 5: GMMActType::GMM_ACT_TYPE_SILU;
+    - Atlas inference series products: only 0 may be passed (GMMActType::GMM_ACT_TYPE_NONE).
+  - out (aclTensorList\*, output): device-side aclTensorList, the output y in the formulas. The [data format](common/数据格式.md) supports ND; at most 128 tensors are supported.
+    - Atlas A2 training series products / Atlas 800I A2 inference products: data types FLOAT16, BFLOAT16, INT8, FLOAT32.
+    - Atlas inference series products: data type FLOAT16.
+  - activationFeatureOutOptional (aclTensorList\*, output): device-side aclTensorList, the input data of the activation function; currently only nullptr may be passed.
+  - dynQuantScaleOutOptional (aclTensorList\*, output): device-side aclTensorList; currently only nullptr may be passed.
+  - workspaceSize (uint64\_t\*, output): returns the workspace size to allocate on the device.
+  - executor (aclOpExecutor\*\*, output): returns the op executor containing the operator computation flow.
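+
+  As referenced in the groupListType item above, the two conventions differ only by a prefix sum. A minimal host-side sketch (standard C++ only; the values are illustrative):
+
+  ```c++
+  #include <cstdint>
+  #include <numeric>
+  #include <vector>
+
+  int main() {
+    // groupListType = 1: per-group sizes along the grouping axis
+    std::vector<int64_t> groupSizes = {256, 256};
+    // groupListType = 0: cumulative sums of those sizes -> {256, 512}
+    std::vector<int64_t> groupCumsum(groupSizes.size());
+    std::partial_sum(groupSizes.begin(), groupSizes.end(), groupCumsum.begin());
+    return 0;
+  }
+  ```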
+
+- **Return value:**
+
+  Returns an aclnnStatus code; see [aclnn return codes](common/aclnn返回码.md).
+
+  ```
+  The first-stage interface performs argument validation. The error codes below indicate:
+  - 161001 (ACLNN_ERR_PARAM_NULLPTR) is returned when:
+    1. A required input, output, or required attribute is a null pointer.
+    2. An element of weight is a null pointer.
+    3. An element of x is a null pointer while the corresponding element of y is not.
+    4. An element of x is not a null pointer while the corresponding element of y is.
+  - 161002 (ACLNN_ERR_PARAM_INVALID) is returned when:
+    1. The data type or data format of x, weight, biasOptional, scaleOptional, offsetOptional, antiquantScaleOptional, antiquantOffsetOptional, groupListOptional, splitItem, groupType, actType, or y is out of the supported range.
+    2. The length of weight exceeds 128, or bias is non-empty and its length differs from that of weight.
+    3. groupListOptional is not 1-dimensional.
+    4. With splitItem 2 or 3, the length of y is not 1.
+    5. With splitItem 0 or 1, the length of y or of groupListOptional differs from that of weight.
+  ```
+
+### aclnnGroupedMatmulV4
+
+- **Parameter description:**
+  - workspace (void\*, input): address of the workspace memory allocated on the device.
+  - workspaceSize (uint64\_t, input): size of the device-side workspace, obtained from the first-stage interface aclnnGroupedMatmulV4GetWorkspaceSize.
+  - executor (aclOpExecutor\*, input): op executor containing the operator computation flow.
+  - stream (aclrtStream, input): the AscendCL stream on which the task executes.
+
+- **Return value:**
+
+  Returns an aclnnStatus code; see [aclnn return codes](common/aclnn返回码.md).
+
+## Constraints and Limitations
+  - If groupListOptional is passed: when groupListType is 0, groupListOptional must be a non-negative, monotonically non-decreasing sequence; when groupListType is 1, groupListOptional must be a non-negative sequence, and its length must not be 1.
+  - Every dimension of every tensor in x and weight, after 32-byte alignment, must be smaller than the INT32 maximum 2147483647.
+  - Atlas A2 training series products / Atlas 800I A2 inference products:
+    - Supported input types in the non-quantized case:
+      - x FLOAT16, weight FLOAT16, biasOptional FLOAT16, scaleOptional empty, offsetOptional empty, antiquantScaleOptional empty, antiquantOffsetOptional empty, perTokenScaleOptional empty, activationInputOptional empty, y FLOAT16.
+      - x BFLOAT16, weight BFLOAT16, biasOptional FLOAT32, scaleOptional empty, offsetOptional empty, antiquantScaleOptional empty, antiquantOffsetOptional empty, perTokenScaleOptional empty, activationInputOptional empty, y BFLOAT16.
+      - x FLOAT32, weight FLOAT32, biasOptional FLOAT32, scaleOptional empty, offsetOptional empty, antiquantScaleOptional empty, antiquantOffsetOptional empty, perTokenScaleOptional empty, activationInputOptional empty, y FLOAT32 (only supported when x, weight, and y are all single tensors).
+    - Supported input types in the per-channel quantized case:
+      - x INT8, weight INT8, biasOptional INT32, scaleOptional UINT64, offsetOptional empty, antiquantScaleOptional empty, antiquantOffsetOptional empty, perTokenScaleOptional empty, activationInputOptional empty, y INT8.
+      - x INT8, weight INT8, biasOptional INT32, scaleOptional BFLOAT16, offsetOptional empty, antiquantScaleOptional empty, antiquantOffsetOptional empty, perTokenScaleOptional empty, activationInputOptional empty, y BFLOAT16.
+      - x INT8, weight INT8, biasOptional INT32, scaleOptional FLOAT32, offsetOptional empty, antiquantScaleOptional empty, antiquantOffsetOptional empty, perTokenScaleOptional empty, activationInputOptional empty, y FLOAT16.
+    - Supported input types in the per-token quantized case:
+      - x INT8, weight INT8, biasOptional INT32, scaleOptional BFLOAT16, offsetOptional empty, antiquantScaleOptional empty, antiquantOffsetOptional empty, perTokenScaleOptional FLOAT32, activationInputOptional empty, y BFLOAT16.
+      - x INT8, weight INT8, biasOptional INT32, scaleOptional FLOAT32, offsetOptional empty, antiquantScaleOptional empty, antiquantOffsetOptional empty, perTokenScaleOptional FLOAT32, activationInputOptional empty, y FLOAT16.
+    - Supported input types in the pseudo-quantized case:
+      - x FLOAT16, weight INT8 or INT4, biasOptional FLOAT16, scaleOptional empty, offsetOptional empty, antiquantScaleOptional FLOAT16, antiquantOffsetOptional FLOAT16, perTokenScaleOptional empty, activationInputOptional empty, y FLOAT16.
+      - x BFLOAT16, weight INT8 or INT4, biasOptional FLOAT32, scaleOptional empty, offsetOptional empty, antiquantScaleOptional BFLOAT16, antiquantOffsetOptional BFLOAT16, perTokenScaleOptional empty, activationInputOptional empty, y BFLOAT16.
+    - The shapes of the pseudo-quantization parameters antiquantScaleOptional and antiquantOffsetOptional must satisfy the table below (g is the number of matmul groups, G the per-group count, $G_i$ the per-group count of the i-th tensor):
+
+    | Scenario | Sub-scenario | Shape constraint |
+    |:---------:|:-------:| :-------|
+    | pseudo-quant per-channel | single weight | $[g, n]$|
+    | pseudo-quant per-channel | multiple weights | $[n_i]$|
+    | pseudo-quant per-group | single weight | $[g, G, n]$|
+    | pseudo-quant per-group | multiple weights | $[G_i, n_i]$|
+    - In the pseudo-quantized case, if the weight type is INT8, only per-channel mode is supported; if the weight type is INT4, both per-channel and per-group modes are supported. In per-group mode, the per-group count G or $G_i$ must divide the corresponding $k_i$ evenly. If weight is multiple tensors, define the per-group length $s_i = k_i / G_i$; all $s_i (i=1,2,...g)$ must be equal.
+    - In the pseudo-quantized case, if the weight type is INT4, the last dimension of every tensor in weight must be even. The last dimension of $weight_i$ is its N axis when weight is not transposed, or its K axis when weight is transposed. Additionally, in the per-group case with transposed weight, the per-group length $s_i$ must be even.
+
+    - Supported scenarios per groupType:
+      - The quantized and pseudo-quantized cases only support groupType -1 and 0.
+      - In the scenario column, 单 (single) and 多 (multi) refer to x, weight, y in order; e.g. 单多单 means x is a single tensor, weight is multiple tensors, and y is a single tensor. A small validation sketch for groupListOptional follows the tables below.
+
+    | groupType | Scenario | Restrictions |
+    |:---------:|:-------:| :-------|
+    | -1 | 多多多 | 1) only splitItem 0/1<br>2) tensors in x may be 2- to 6-dimensional, tensors in weight must be 2-D, and the tensor dimensions in y match x<br>3) groupListOptional must be empty<br>4) weight may be transposed, but all tensors in the weight tensor list must agree on whether they are transposed<br>5) x must not be transposed |
+    | 0 | 单单单 | 1) only splitItem 2/3<br>2) tensors in weight must be 3-D; tensors in x and y must be 2-D<br>3) groupListOptional is required; with groupListType 0 its last value must equal the first dimension of the x tensor, and with groupListType 1 the sum of its values must equal the first dimension of the x tensor<br>4) the first dimension of groupListOptional supports at most 1024, i.e. at most 1024 groups<br>5) weight may be transposed<br>6) x must not be transposed |
+    | 0 | 单多单 | 1) only splitItem 2/3<br>2) groupListOptional is required; with groupListType 0 its last value must equal the first dimension of the x tensor, and with groupListType 1 the sum of its values must equal the first dimension of the x tensor; its length is at most 128<br>3) tensors in x, weight, and y must be 2-D<br>4) the N axes of all tensors in weight must be equal<br>5) weight may be transposed, but all tensors in the weight tensor list must agree on whether they are transposed<br>6) x must not be transposed |
+    | 0 | 多多单 | 1) only splitItem 2/3<br>2) tensors in x, weight, and y must be 2-D<br>3) the N axes of all tensors in weight must be equal<br>4) if groupListOptional is passed, with groupListType 0 its successive differences must match the first dimensions of the x tensors one-to-one, and with groupListType 1 its values must match the first dimensions of the x tensors one-to-one; its length is at most 128<br>5) weight may be transposed, but all tensors in the weight tensor list must agree on whether they are transposed<br>6) x must not be transposed |
+    | 2 | 单单单 | 1) only splitItem 2/3<br>2) tensors in x and weight must be 2-D; tensors in y must be 3-D<br>3) groupListOptional is required; with groupListType 0 its last value must equal the second dimension of the x tensor, and with groupListType 1 the sum of its values must equal the second dimension of the x tensor<br>4) the first dimension of groupListOptional supports at most 1024, i.e. at most 1024 groups<br>5) x must be transposed and weight must not be transposed |
+
+    - The last dimension of every tensor in x and weight must be smaller than 65536. The last dimension of $x_i$ is its K axis when x is not transposed, or its M axis when x is transposed. The last dimension of $weight_i$ is its N axis when weight is not transposed, or its K axis when weight is transposed.
+    - Only the per-token quantized case and the dequantized case support activation computation.
+
+  - Atlas inference series products:
+    - Inputs and outputs only support the float16 data type, and the n-axis size of the output y must be a multiple of 16.
+
+    | groupType | Scenario | Restrictions |
+    |:---------:|:-------:| :------ |
+    | 0 | 单单单 | 1) only splitItem 2/3<br>2) tensors in weight must be 3-D; tensors in x and y must be 2-D<br>3) groupListOptional is required; with groupListType 0 its last value must equal the first dimension of the x tensor, and with groupListType 1 the sum of its values must equal the first dimension of the x tensor<br>4) the first dimension of groupListOptional supports at most 1024, i.e. at most 1024 groups<br>5) weight may be transposed; x must not be transposed |
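+
+As noted above the tables, here is a host-side sketch of the checks these rows imply for a cumsum-style (groupListType = 0) groupListOptional (standard C++ only; the function name is illustrative):
+
+```c++
+#include <cstdint>
+#include <vector>
+
+// Valid iff: non-empty, at most 1024 groups, non-negative and monotonically
+// non-decreasing, and the last value equals M (the first dimension of x).
+bool ValidCumsumGroupList(const std::vector<int64_t>& groupList, int64_t m) {
+  if (groupList.empty() || groupList.size() > 1024) {
+    return false;
+  }
+  int64_t prev = 0;
+  for (int64_t v : groupList) {
+    if (v < prev) {
+      return false;
+    }
+    prev = v;
+  }
+  return prev == m;
+}
+```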
+
+## Operator Prototype
+
+```c++
+REG_OP(GroupedMatmul)
+    .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_BF16, DT_INT8, DT_FLOAT}))
+    .DYNAMIC_INPUT(weight, TensorType({DT_FLOAT16, DT_BF16, DT_INT8, DT_FLOAT, DT_INT4}))
+    .DYNAMIC_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
+    .DYNAMIC_INPUT(scale, TensorType({DT_UINT64, DT_BF16, DT_FLOAT32}))
+    .DYNAMIC_INPUT(offset, TensorType({DT_FLOAT32}))
+    .DYNAMIC_INPUT(antiquant_scale, TensorType({DT_FLOAT16, DT_BF16}))
+    .DYNAMIC_INPUT(antiquant_offset, TensorType({DT_FLOAT16, DT_BF16}))
+    .OPTIONAL_INPUT(group_list, TensorType({DT_INT64}))
+    .OPTIONAL_INPUT(per_token_scale, TensorType({DT_FLOAT}))
+    .DYNAMIC_OUTPUT(y, TensorType({DT_FLOAT16, DT_BF16, DT_INT8, DT_FLOAT}))
+    .ATTR(split_item, Int, 0)
+    .ATTR(dtype, Int, 0)
+    .ATTR(transpose_weight, Bool, false)
+    .ATTR(transpose_x, Bool, false)
+    .ATTR(group_type, Int, -1)
+    .ATTR(group_list_type, Int, 0)
+    .ATTR(act_type, Int, 0)
+    .OP_END_FACTORY_REG(GroupedMatmul)
+```
+
+For parameter details, see **Operator Execution API**.
+
+## Invocation Examples
+- Calling through the PyTorch framework
+
+  To call this fused operator as a PyTorch single operator, refer to the PyTorch fused operator [torch_npu.npu_grouped_matmul](https://hiascend.com/document/redirect/PyTorchAPI); if you have customized this fused operator, refer to the "Ascend C Operator Development" manual, section [Adapting to the PyTorch Framework](https://hiascend.com/document/redirect/CannCommunityAscendCInvorkOnNetwork).
+
+- Calling through the aclnn single-operator API
+
+  A sample of calling the operator through the aclnn single-operator API is given below, for reference only. For the build and run procedure, see the [build-and-run sample](common/编译与运行样例.md).
+
+```c++
+#include <iostream>
+#include <vector>
+#include "acl/acl.h"
+#include "aclnnop/aclnn_grouped_matmul_v4.h"
+
+#define CHECK_RET(cond, return_expr) \
+  do {                               \
+    if (!(cond)) {                   \
+      return_expr;                   \
+    }                                \
+  } while (0)
+
+#define LOG_PRINT(message, ...)     \
+  do {                              \
+    printf(message, ##__VA_ARGS__); \
+  } while (0)
+
+int64_t GetShapeSize(const std::vector<int64_t>& shape) {
+  int64_t shapeSize = 1;
+  for (auto i : shape) {
+    shapeSize *= i;
+  }
+  return shapeSize;
+}
+
+int Init(int32_t deviceId, aclrtStream* stream) {
+  // Fixed pattern: AscendCL initialization
+  auto ret = aclInit(nullptr);
+  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret);
+  ret = aclrtSetDevice(deviceId);
+  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret); return ret);
+  ret = aclrtCreateStream(stream);
+  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret);
+  return 0;
+}
+template <typename T>
+int CreateAclTensor_New(const std::vector<T>& hostData, const std::vector<int64_t>& shape, void** deviceAddr,
+                        aclDataType dataType, aclTensor** tensor) {
+  auto size = GetShapeSize(shape) * sizeof(T);
+  // Call aclrtMalloc to allocate device-side memory
+  auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
+  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret);
+
+  // Call aclrtMemcpy to copy the host-side data to the device-side memory
+  ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
+  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);
+
+  // Compute the strides of a contiguous tensor
+  std::vector<int64_t> strides(shape.size(), 1);
+  for (int64_t i = shape.size() - 2; i >= 0; i--) {
+    strides[i] = shape[i + 1] * strides[i + 1];
+  }
+
+  // Call aclCreateTensor to create an aclTensor
+  *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND,
+                            shape.data(), shape.size(), *deviceAddr);
+  return 0;
+}
+
+template <typename T>
+int CreateAclTensor(const std::vector<int64_t>& shape, void** deviceAddr,
+                    aclDataType dataType, aclTensor** tensor) {
+  auto size = GetShapeSize(shape) * sizeof(T);
+  // Call aclrtMalloc to allocate device-side memory
+  auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
+  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret);
+
+  // Copy zero-initialized host data to the device-side memory
+  std::vector<char> hostData(size, 0);
+  ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
+  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);
+
+  // Compute the strides of a contiguous tensor
+  std::vector<int64_t> strides(shape.size(), 1);
+  for (int64_t i = shape.size() - 2; i >= 0; i--) {
+    strides[i] = shape[i + 1] * strides[i + 1];
+  }
+
+  // Call aclCreateTensor to create an aclTensor
+  *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND,
+                            shape.data(), shape.size(), *deviceAddr);
+  return 0;
+}
+
+template <typename T>
+int CreateAclTensorList(const std::vector<std::vector<int64_t>>& shapes, void** deviceAddr,
+                        aclDataType dataType, aclTensorList** tensor) {
+  int size = shapes.size();
+  aclTensor* tensors[size];
+  for (int i = 0; i < size; i++) {
+    int ret = CreateAclTensor<T>(shapes[i], deviceAddr + i, dataType, tensors + i);
+    CHECK_RET(ret == ACL_SUCCESS, return ret);
+  }
+  *tensor = aclCreateTensorList(tensors, size);
+  return ACL_SUCCESS;
+}
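+
+// main() below exercises the same single-single-single, m-axis-grouping case
+// as the V3 sample, with splitItem = 3, groupListType = 0 (cumsum) and
+// actType = 0 (no activation); every V4-only optional (perTokenScale and the
+// activation tensor lists/outputs) is passed as nullptr.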
+int main() {
+  // 1. (Fixed pattern) device/stream initialization; see the AscendCL API reference
+  // Set deviceId according to the actual device in use
+  int32_t deviceId = 0;
+  aclrtStream stream;
+  auto ret = Init(deviceId, &stream);
+  // Handle the check as needed for your application
+  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Init acl failed. ERROR: %d\n", ret); return ret);
+
+  // 2. Build the inputs and outputs; adapt this to the API being called
+  std::vector<std::vector<int64_t>> xShape = {{512, 256}};
+  std::vector<std::vector<int64_t>> weightShape = {{2, 256, 256}};
+  std::vector<std::vector<int64_t>> biasShape = {{2, 256}};
+  std::vector<std::vector<int64_t>> yShape = {{512, 256}};
+  std::vector<int64_t> groupListShape = {2};
+  std::vector<int64_t> groupListData = {256, 512};
+  void* xDeviceAddr[1];
+  void* weightDeviceAddr[1];
+  void* biasDeviceAddr[1];
+  void* yDeviceAddr[1];
+  void* groupListDeviceAddr;
+  aclTensorList* x = nullptr;
+  aclTensorList* weight = nullptr;
+  aclTensorList* bias = nullptr;
+  aclTensor* groupedList = nullptr;
+  aclTensorList* scale = nullptr;
+  aclTensorList* offset = nullptr;
+  aclTensorList* antiquantScale = nullptr;
+  aclTensorList* antiquantOffset = nullptr;
+  aclTensorList* perTokenScale = nullptr;
+  aclTensorList* activationInput = nullptr;
+  aclTensorList* activationQuantScale = nullptr;
+  aclTensorList* activationQuantOffset = nullptr;
+  aclTensorList* y = nullptr;
+  aclTensorList* activationFeatureOut = nullptr;
+  aclTensorList* dynQuantScaleOut = nullptr;
+  int64_t splitItem = 3;
+  int64_t groupType = 0;
+  int64_t groupListType = 0;
+  int64_t actType = 0;
+
+  // Create the x aclTensorList (int16_t elements hold the raw FLOAT16 bits)
+  ret = CreateAclTensorList<int16_t>(xShape, xDeviceAddr, aclDataType::ACL_FLOAT16, &x);
+  CHECK_RET(ret == ACL_SUCCESS, return ret);
+  // Create the weight aclTensorList
+  ret = CreateAclTensorList<int16_t>(weightShape, weightDeviceAddr, aclDataType::ACL_FLOAT16, &weight);
+  CHECK_RET(ret == ACL_SUCCESS, return ret);
+  // Create the bias aclTensorList
+  ret = CreateAclTensorList<int16_t>(biasShape, biasDeviceAddr, aclDataType::ACL_FLOAT16, &bias);
+  CHECK_RET(ret == ACL_SUCCESS, return ret);
+  // Create the y aclTensorList
+  ret = CreateAclTensorList<int16_t>(yShape, yDeviceAddr, aclDataType::ACL_FLOAT16, &y);
+  CHECK_RET(ret == ACL_SUCCESS, return ret);
+  // Create the group_list aclTensor
+  ret = CreateAclTensor_New(groupListData, groupListShape, &groupListDeviceAddr, aclDataType::ACL_INT64, &groupedList);
+  CHECK_RET(ret == ACL_SUCCESS, return ret);
+
+  uint64_t workspaceSize = 0;
+  aclOpExecutor* executor;
+
+  // 3. Call the CANN operator library API
+  // Call the first-stage interface aclnnGroupedMatmulV4GetWorkspaceSize
+  ret = aclnnGroupedMatmulV4GetWorkspaceSize(x, weight, bias, scale, offset, antiquantScale, antiquantOffset, perTokenScale, groupedList, activationInput, activationQuantScale, activationQuantOffset, splitItem, groupType, groupListType, actType, y, activationFeatureOut, dynQuantScaleOut, &workspaceSize, &executor);
+  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnGroupedMatmulV4GetWorkspaceSize failed. ERROR: %d\n", ret); return ret);
+  // Allocate device memory using the workspaceSize returned by the first-stage interface
+  void* workspaceAddr = nullptr;
+  if (workspaceSize > 0) {
+    ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret); return ret);
+  }
+  // Call the second-stage interface aclnnGroupedMatmulV4
+  ret = aclnnGroupedMatmulV4(workspaceAddr, workspaceSize, executor, stream);
+  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnGroupedMatmulV4 failed. ERROR: %d\n", ret); return ret);
+
+  // 4. (Fixed pattern) wait synchronously for the task to finish
+  ret = aclrtSynchronizeStream(stream);
+  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSynchronizeStream failed. ERROR: %d\n", ret); return ret);
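+
+  // Note: y was created as ACL_FLOAT16 but read back into int16_t below, so
+  // the printed values are raw fp16 bit patterns rather than decoded floats.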
+  // 5. Fetch the output: copy the result from device memory to the host; adapt to the API being called
+  for (int i = 0; i < 1; i++) {
+    auto size = GetShapeSize(yShape[i]);
+    std::vector<int16_t> resultData(size, 0);
+    ret = aclrtMemcpy(resultData.data(), size * sizeof(resultData[0]), yDeviceAddr[i],
+                      size * sizeof(resultData[0]), ACL_MEMCPY_DEVICE_TO_HOST);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return ret);
+    for (int64_t j = 0; j < size; j++) {
+      LOG_PRINT("result[%ld] is: %d\n", j, resultData[j]);
+    }
+  }
+
+  // 6. Destroy the aclTensor/aclTensorList objects; adapt to the API being called
+  aclDestroyTensorList(x);
+  aclDestroyTensorList(weight);
+  aclDestroyTensorList(bias);
+  aclDestroyTensorList(y);
+
+  // 7. Release device resources; adapt to the API being called
+  for (int i = 0; i < 1; i++) {
+    aclrtFree(xDeviceAddr[i]);
+    aclrtFree(weightDeviceAddr[i]);
+    aclrtFree(biasDeviceAddr[i]);
+    aclrtFree(yDeviceAddr[i]);
+  }
+  if (workspaceSize > 0) {
+    aclrtFree(workspaceAddr);
+  }
+  aclrtDestroyStream(stream);
+  aclrtResetDevice(deviceId);
+  aclFinalize();
+  return 0;
+}
+```
+
diff --git "a/docs/common/GroupedMatmul\347\256\227\345\255\220\350\256\276\350\256\241\344\273\213\347\273\215.md" "b/docs/common/GroupedMatmul\347\256\227\345\255\220\350\256\276\350\256\241\344\273\213\347\273\215.md"
new file mode 100644
index 00000000..2457c604
--- /dev/null
+++ "b/docs/common/GroupedMatmul\347\256\227\345\255\220\350\256\276\350\256\241\344\273\213\347\273\215.md"
@@ -0,0 +1,115 @@
+Statement: this document is licensed under [Creative Commons License version 4.0](https://creativecommons.org/licenses/by/4.0/legalcode); please comply with this license when reproducing, citing, or modifying it.
+
+# 1 GroupedMatmul Fused Operator Design
+
+The GroupedMatmul operator performs grouped matrix multiplication, where each group's matmul may have different dimension sizes. The matmul formula is:
+$$
+y_i[m_i,n_i]=x_i[m_i,k_i] \times weight_i[k_i,n_i], i=1...g
+$$
+where g is the number of groups and $m_i, k_i, n_i$ are the corresponding shapes.
+Implementing GroupedMatmul additionally has to consider two aspects:
+1. Support for different parameters and data types: with or without bias, different activation function types, and the non-quantized, quantized, and pseudo-quantized cases. Each case has its own computation flow and its own performance-optimization approach, so the implementation is split into different templates, each with its own template parameters.
+2. The on-chip AiCore memory is limited, so completing an operator generally requires tiling the data and pipelining the data-movement and compute stages. This has a large impact on the operator and is the main knob during performance tuning; the host-side tiling function exists precisely to compute the parameters for this tiling and pipelining.
+# 2 Scenario Breakdown
+Functionally the operator splits into non-quantized, quantized, and pseudo-quantized scenarios; at the code level the concrete template is selected in three ways:
+1. Compile-time macros: the macros ORIG_DTYPE_X and ORIG_DTYPE_WEIGHT, generated from the data types of x and weight;
+2. tilingKey: e.g. whether x/weight are transposed;
+3. tilingData: e.g. isPerTokenQuant in the tiling indicates per-token quantization.
+
+Notes:
+- Non-quantized means x, weight, and y are all floating-point types such as float16/bfloat16/float32. This is a pure-cube scenario whose computation is implemented with the matmul high-level API;
+- Quantized means x and weight are low-precision integer types. GroupedMatmul supports the A8W8 case, including requantization, per-tensor + per-channel quantization, and per-token + per-channel quantization (per-token quantization for short);
+- Pseudo-quantized means x is a floating-point type and weight is a low-precision integer type. GroupedMatmul supports the A16W8 and A16W4 cases.
+
+## 2.1 Per-token quantization
+This chapter uses the per-token quantization scenario as an example of the GroupedMatmul algorithm flow:
+matmul(int32) -> dequantize(fp32) -> mul(fp32) -> activation(fp32, optional) -> cast(fp16/bf16), where the perTokenScale input of mul(fp32) must additionally be broadcast from shape (m) to (m,n).
+
+![GroupedMatmul量化场景流程图](../fig/GMM量化场景流程图.png)
+
+## 2.2 Grouping modes
+Depending on the scenario, GroupedMatmul groups along the m axis or the k axis, also called splitting M and splitting K. Forward training groups along the m axis; computing gradients in the backward pass requires grouping along the k axis.
+- m-axis grouping: the $k_i$ are the same in every group, and $weight_i/y_i$ can be concatenated along $n_i$; here group_type = 0.
+m-axis grouping applies to the non-quantized forward-training scenario as well as the quantized and pseudo-quantized scenarios.
+```mermaid
+graph LR
+    A[(x:GM,K)] --> B([GroupedMatmul:group_type=M])
+    C[(w:G,K,N)] --> B([GroupedMatmul:group_type=M])
+    D[(group_list:G)] --> B([GroupedMatmul:group_type=M])
+    B([GroupedMatmul:group_type=M]) --> E[(y:GM,N)]
+```
+- k-axis grouping: the $k_i$ differ between groups, but $m_i/n_i$ are the same in every group, so the $x_i/weight_i$ can be concatenated along $k_i$. k-axis grouping is only used in the non-quantized training scenario, to compute the gradient of the loss with respect to weight: computing the weight gradient requires transposing x, and after the transpose x switches from m-axis grouping to k-axis grouping, as the worked equation after the diagram below shows.
+```mermaid
+graph LR
+    A[(x:GM,K)] --> B([GroupedMatmul:group_type=K,transpose_x:True])
+    E[(dy:GM,N)] --> B([GroupedMatmul:group_type=K,transpose_x:True])
+    D[(group_list:G)] --> B([GroupedMatmul:group_type=K,transpose_x:True])
+    B([GroupedMatmul:group_type=K,transpose_x:True]) --> C[(dw:G,K,N)]
+```
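+
+Written out per group (a worked form of the diagram above, with $x_i$ the rows of x belonging to group $i$, of shape $[m_i, k]$, and $dy_i$ the matching output gradient):
+
+$$
+dw_i[k,n]=\sum_{m=1}^{m_i} x_i[m,k] \cdot dy_i[m,n]=(x_i^T \times dy_i)[k,n], \quad i=1...g
+$$
+
+After $x$ is transposed, the per-group axis $m_i$ becomes the reduction (inner, i.e. k) axis of the matmul, which is exactly why the grouping axis moves from m to k.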
+## 2.3 Multi-tensor / single-tensor support
+GroupedMatmul supports inputs and outputs as either multiple tensors or a single tensor.
+"Single tensor" means that all per-group tensors of a tensor list are merged into one along the grouping axis specified by groupType; otherwise it is multiple tensors.
+The table below lists the tensor shapes supported by each scheme; 单 means single tensor and 多 means multiple tensors, in the order x, weight, y — e.g. 单多单 means x is a single tensor, weight is multiple tensors, and y is a single tensor.
+| group_type | supported scenario | x shape | weight shape | y shape | optional-dynamic inputs shape if needed | group_list shape if passed | per_token_scale shape if passed |
+| :--------: | :----------------: | :-----: | :----------: | :-----: | :-------------------------------------: | :------------------------: | :-----------------------------: |
+| -1 | 多多多 | [(M1,K1),(M2,K2),...] | [(K1,N1),(K2,N2),...] | [(M1,N1),(M2,N2),...] | [(N1),(N2),...] | not support | not support |
+| 0 | 单单单 | [(M,K)] | [(G,K,N)] | [(M,N)] | [(G,N)] | (G) | (M) |
+| 0 | 单多单 | [(M,K)] | [(K,N),(K,N),...] | [(M,N)] | [(N),(N),...] | (G) | not support |
+| 0 | 多多单 | [(M1,K1),(M2,K2),...] | [(K1,N),(K2,N),...] | [(M,N)] | [(N),(N),...] | (G) | not support |
+
+For example, in the 多多多 scenario with x shape={{4,16}, {12,16}, {16,16}}, weight shape={{16,8},{16,8},{16,8}}, and y shape={{4,8},{12,8},{16,8}}, the same computation in the 单单单 scenario uses x shape={32,16}, weight shape={3,16,8}, y shape={32,8}, groupList={4,12,16}.
+
+# 3 Tiling design
+## 3.1 tilingData design
+```c++
+BEGIN_TILING_DATA_DEF(GMMTilingData)
+  TILING_DATA_FIELD_DEF_STRUCT(GMMBaseParams, gmmBaseParams);
+  TILING_DATA_FIELD_DEF_STRUCT(GMMArray, gmmArray);
+  TILING_DATA_FIELD_DEF_STRUCT(TCubeTiling, mmTilingData);
+END_TILING_DATA_DEF;
+```
+
+tilingData consists of the three parts in the struct above:
+1. GMMBaseParams: basic tiling parameters such as the number of GroupedMatmul groups, the AiCore count, UB tiling parameters, and matmul core-partition tiling parameters, plus feature parameters such as whether there is an activation function, the quantization type (per-token or per-tensor), and the activation type;
+2. GMMArray: when the input is multiple tensors, three arrays record the shape of each group's matmul, which the kernel reads via GlobalTensor::GetValue(). When the input is all single tensors, one of m/k/n is in group_list and the other two are identical across all groups, so the kernel only needs the first entry of each array.
+3. TCubeTiling: the tilingData for the matmul high-level API.
+
+To avoid allocating the three GMMArray arrays in kernel stack space and to avoid copying them, the kernel uses the GET_TILING_DATA_MEMBER interface to copy the struct members other than GMMArray, as the sketch below illustrates.
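+
+A kernel-side sketch of this (illustrative only, not the repository's actual kernel; it assumes the four-argument `GET_TILING_DATA_MEMBER(type, member, var, tilingPtr)` macro form and the usual Ascend C kernel prologue):
+
+```c++
+// Copy only the small tiling members to the kernel stack; gmmArray stays in
+// global memory and its entries are fetched on demand via
+// GlobalTensor::GetValue(), as described above.
+extern "C" __global__ __aicore__ void grouped_matmul(GM_ADDR x, GM_ADDR weight,
+                                                     GM_ADDR y, GM_ADDR tiling) {
+    GET_TILING_DATA_MEMBER(GMMTilingData, gmmBaseParams, baseParams, tiling);
+    GET_TILING_DATA_MEMBER(GMMTilingData, mmTilingData, mmTiling, tiling);
+    // ... select the template instance from baseParams / the tiling key and run ...
+}
+```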
+## 3.2 UB buffer allocation
+During initialization, GroupedMatmul decides how the UB buffer is allocated and reused. The non-quantized and pseudo-quantized flows are simple and do not reuse UB buffers; the quantized flows are numerous and complex, so UB buffers are reused to raise the amount of data processed per pass.
+Define a buffer's number of "shares" as the bytes allocated to it divided by the number of elements processed per pass (baseM\*baseN, denoted ubCalcSize; baseM/baseN here are vector-compute parameters). Taking per-token quantization as an example, with baseM = 24 and baseN = 256, ubCalcSize = baseM * baseN = 6144, so one share is 6144 bytes ≈ 6KB.
+
+The UB buffer is allocated as follows:
+
+![GroupedMatmul量化场景UB_buffer分配](../fig/GMM量化场景UB_Buffer分配.png)
+
+1. The vector-compute input, i.e. the matmul output, is int32; with double buffering enabled the input needs sizeof(int32) * 2 = 8 shares of buffer;
+2. The vector-compute output is fp16/bf16 and needs sizeof(fp16/bf16) * 2 = 4 shares of buffer;
+3. The intermediate computation needs a tmpBuffer: the broadcast output needs sizeof(fp32) = 4 shares, the dequantization output another sizeof(fp32) = 4 shares, and a sharedBuffer must also be allocated for reuse, covering the broadcast scratch, the dequantization scratch, and the Mul output, at 8 shares; so tmpBuffer needs 16 shares in total.
+
+In total, 28 shares * 6KB = 168KB of UB buffer are allocated.
+
+## 3.3 Base-block core partitioning
+The GroupedMatmul implementation must handle inputs made of multiple tensors, i.e. each group's matmul shape may differ, and the kernel cannot configure a separate matmul high-level-API instance per group (neither the tiling struct nor the per-core stack space allows it). To adapt to matmuls of different shapes, GroupedMatmul partitions work across cores in base blocks (horizontal partitioning), computing in units of baseM by baseN, where baseM/baseN are matmul parameters.
+
+![基本块分核](../fig/GMM横向分核方案.png)
+
+## 3.4 Diagonal core partitioning
+Partitioning purely by base blocks easily causes same-address access: for example, when nDim=coreNum in the base-block scheme, all cores read the same address of the left matrix at the same time, which noticeably hurts performance. So when the number of base blocks exceeds coreNum (below coreNum, the diagonal scheme cannot avoid same-address access anyway), the diagonal scheme below can be used, in which different cores stagger their data accesses as much as possible at any given time. Each square is one output base block, and the number is the traversal order of the base blocks (horizontal partitioning is the original base-block scheme).
+
+![对角线分核](../fig/GMM对角线分核方案.png)
+
+In scenarios suited to the diagonal optimization, input data is accessed multiple times under the base-block scheme. When nDim/mDim is very large, the "diagonal" cannot simply be extended downward arbitrarily: with a large k, the inputs along the diagonal are all distinct, data already loaded once gets evicted from the L2 cache by new data, and later loads must go back to DDR, degrading performance. The diagonal range therefore has to be limited so that the data cached in L2 is fully reused.
+
+![对角线分组](../fig/GMM对角线分组方案.png)
+
+The diagonal traversal is split into groups by thresholds. To avoid same-address access as far as possible, the thresholds in both directions should be at least the actual number of physical cores:
+$$
+\min(T_m, T_n) \geq numCore
+$$
+
+To make full use of the cache, the total data volume of $X$, $Weight$, and $Y$ for one group block should not exceed the device's L2 cache size:
+$$
+sizeof(dtype) \cdot (T_m \cdot singleM \cdot K + K \cdot T_n \cdot singleN + T_m \cdot singleM \cdot T_n \cdot singleN) \leq L2_{size}
+$$
+If the two conditions cannot both be satisfied, trade off according to the actual situation.
\ No newline at end of file
diff --git "a/docs/fig/GMM\345\257\271\350\247\222\347\272\277\345\210\206\346\240\270\346\226\271\346\241\210.png" "b/docs/fig/GMM\345\257\271\350\247\222\347\272\277\345\210\206\346\240\270\346\226\271\346\241\210.png"
new file mode 100644
index 0000000000000000000000000000000000000000..a1ab0b01255221d269162eef9ede5a7b51203c81
GIT binary patch
[literal 10765: base85-encoded PNG data omitted, as are the binary patches for the remaining docs/fig/*.png figures]
zy%wqno!!StXsT>8gE2+^H(HgreaRAd@&jR4Af2t;P^*HXkf?D3q(`+R~$64l-xkZjV zX8_qv!-DyPyLsSil3@^1oEH-ACbdSE7j3d-5*Ptq1!_K|0&HcDCMTHDm? zjP0+`&GMQgp!JC3SRwcl6I-=BWz&gaLHOgTyAwY14d{A$*)*CHsOcSbYpc>TlAvSmT!iL|ZUK7Y5=QSBq;&X})ksmTNz<00t&%I40v4$3FhZhfSUv$-{N zVUTWGL;LB!@m12 zf_C(H5z}A+`i0L9Rz}C$;Y7Wij;PR!&vw`%qLMx$Eazm|ye_{jr~RK^Fj-N3Iq%0+F{DQy?Jej7(=c1xtgw?B?B z3+isM)A3LVRBjV{T}@Br-eFJ@6#NZzk25>d7E(|BZ2sajHY}FZW>P0q#Ze~!3ru)? za#TSZ%%snC_ zna9@7hF-cIWF%iMUO;}=en$;G@yEcrI@Z75O#xTIGfXRFSj;6CZ19YuZ-2WcCVXb* zStj{AQ@eIJ5Lqlt^3T69LbJP*3jaC?h5b2BqVyl4y7rE0F|B8C2 z&D4=;RX;JX)PZgl=4)R8>q=UOxgm&@e>Er>S2M~k>6p;Ot4shR}o z<0VJTWyP4f;A(7It+~BQmB~qT$5bo-cZxDEZT0+CbAHr}s%NN!8Kz-vM;-UC>hrJ| zSp6CK&N*wuvTUqZl8R&99|ATG9ySYE5?4R{$Kz@}TqPr!p4(NYWG4t5u)&#psbsNV zBOGxfnP#*P8S8RigF@9XW4JxopBeRt_Ls0ujl*~gm8k~%LN^*+Zh;ik1A5Ddr$G&j*bDEydk#hEzSb0+VZ8Uvs?PI9)Cg z^9eV=Zvb(Yha)xZyJi~i!}d}s$JZa$bxaXp9iOEt4xFXixvAqHeHyCMNbgTCJTTFB zD`;dlM~?r=kZ6YV|H*~7)-+I(%o|h|8k$!{(Aj2 z_?#)oSC?y1_ep5WlTv1TSq{i-ygMBor9Fv&Ob1i(2BQ#xE&zDd-?gQPTUB@W=R0Er zd*+YnbN_ff5jSylS~LeCv97T!ymoS9LSlk3u}GGI)4E_fT*ZNLa>(1i8r^$QBZ zaRhBfH;`M>BEN}3)b6*kG#CE;cROzpd!;TZ?!RN%GTy4BqSswBV%DqY;3HL zj}HN-Q*K^fUT$uoB9-y_=-?oj9%QjfKlqu*R#`$_eaVLnEz1}QqzwEAsH* zAP5O>XM4NO?b-|&92%NDBxtSlEA9hxkoN>3xwY|dadCA?shvRm%iNGL{2a5)`(y;gcY~cIeV8knu%zpYc zKJE?HZIOWy0b`v}V`C%xqwW+K;jXW(?d|JxbaDa<0=Iy`K7?{3ptp5T~l-Q z&(KdWx7ufCQWP*e3>U4(ngX66FE0-!cfN9ao-I*<$UHsN@9F8`eZHDD9EpeUaZHU* zPKg!I12c&i740w8nS*aIR}mW*mjt%bP_5Z4Rk5fl_|(eEHyt5iWd{Hrs>-dJ5dwb76(Z_3Bl*=a5&AyG5t zQ>|KhyY8Ja9PR7tOM_KXRrMLvl_n<-#MQ_vPb@1bQ7BRb^MHWWb)z=@V2tqgt)IU? zX!eCCrWNXi@G7ZWLWFR^W@w}zxwt$)TLp`qp~+FD(Q0fYI}2(HMNwk$75+u)dd<7qC146&pJtI(p-5N3H}k zKCm6ZxI%??omt$F!SL|#TND&$mql8Ebif_#QNmKX@~`=H-&euj4ceiwaA41)U?~i{ zA_4&R)S0s)NeWQSO}owQx*)?VIU_@+tlsfZxj5mzAI};S6H~QA*paksRmdz_&s_?m z&w|x@wYe@YkDP)+ga!+g!J40)ZOV2#2~3|ZEZDRd_VgpxA_OC`Atrq2R<7E z!^DQlgNv#!_dzuQP*_+94j}qfszpUbGi~HE1gU|6fkfU5Xqo$Ze%3vVM|L( z%cv%)IX<{}>7mgdv}$lXJYH+#q@c*3-?{=#AGBmun)-aPW%{MJ$jJHCMrnPSDQ~h0 zaBn~({O;BbIO@^@vTVLI8jTZQLQHt6Ko#4#_;@TZDf{5yps!*y*jBQ#nes)YNs?4? zFfcIk4MnQvlZV%eUuus9`XT}Llk)bqep%&+7g_XIuMDOVH?YaK2~izJWkl$O#Io2SIY2&{@#JM2$2H#h(O z{Tti~j*_FpGb)9gvT3WEAyL6S|7MwC!-7j+1pe>8xKG}xj|!%GQ_Ox~WE2n+mKG`( H(DC~(dEkU? literal 0 HcmV?d00001 diff --git "a/docs/fig/GMM\346\250\252\345\220\221\345\210\206\346\240\270\346\226\271\346\241\210.png" "b/docs/fig/GMM\346\250\252\345\220\221\345\210\206\346\240\270\346\226\271\346\241\210.png" new file mode 100644 index 0000000000000000000000000000000000000000..bf5622f4a1e10f64c45e8d8abb864a4878375029 GIT binary patch literal 15095 zcmch;2RK}7zdkI9BuEfM8$s0QEi+2Ah|z`UHADo%=)K!PlrTh(7K~m(5WP%<5JvB! 
zccS;se~o0nd+&37?>XQ1Kj%9xW~`aH*0a`Ae)n^i34W#|gO3No!@|PCmy?C6Vqsn7 z#KO8_f^!vkB?qQl1w7%}%IY{^VG%T8{;tHZ6HsAc-N%xHN~pUgttCqYQRyTsqWHd1 zlqZR<8f2+@-72BhJ-%N(w#!o^p&U;n#5#BHxrG*2RABJt)6xxg1q8RlFTb~ ze0j@~4CSmh`#f;;p9tq-4fzf?f(z*guF*xk9T$zZr6w~S)InG46o^9Q@mLw6ptXzFoT{|WapP$3MKf-28e6vyxLgd3n zBKaPXZF-IF%Qb&{tE&L?y3*VZ) zS9d-O^r-ep{UEF=CTlZ4e=1|;9WO-^jX8dgRoAm<3$s-JIPN1)?{L0qbB&#^8&{&` zJ)Q}nNYIPIjiKvF*retML0o-j_CgOnl3`17KUmh}c#C#S?99ef5lbnOefNy5W5HLg zxzm{u+H3`>wKvmmC5cd7b_$@ZTjAgF-D!5)w+I+vN6Ap0egCl$94zysjW}5)ithm` zv~Aj(MbNHR-`wYiO`A8X%z3VVh)rhiDkQT=j8$G;Ld9XOwMr>h<5+7UHTm{%hD zYS!6`jjAMMn%;US%UzKUnNPB1ZhkTOxG6j_mUQqa1x5QSQT8QpSC5P z0)6oayi4Y-xnEMEy0K4mTU4%8WM+zQymMf*WZO{tO@`^@RR04cNeFK!*c6$@$sLi$-Hj_WtNYhs#VejW}yhh_yqpX;8$%-VT=XY!M=X|fWM+l*V zSw7*#D>A1qdspo|(xRhBT5<{BQ&rJu<%P%z6E$=*6s6HR<-#bWSyu!c6yz$j3OTK^@=Bj`6%_^PIJ|$yTsd{L&?>TMwHTHX` zZD7`e6;?=os!(QwYr!3rDZZyeISjRG(egE7BE4m|iHPRg<}Jf&70&T1x$dn8#>Tld zLf#Eg$B{Z3N-v;-kaLMv5&07j^|MByjbz55yo2^4+eu``v2$ikudIx?y6&2Av+L5y zzlbj>4+!cEnX+V3Y84PRDXAe22v9ZOZ=P*nKq!u9k|!tHe_7X6)FEn+I2i zk~au=y{JCG?#8J29II7(#ChVn=;`kVZg)r|Q!h_HT&S}5x^Fz~rC_UvA07m9-wi$~ z*)*{+G0aR4k}Mo^OZG5P-h0ZKo^rux6T}u{V5a@vM*4dq8!&fELT&Inp0U+kDQ4hw z5R)~LVi1Hld5M9`Uw*)8%R!xSUcYj-B(uA5P7O(PjFO?kRg63uC_3+ZTMFj^J>Qtj z&WIo=3TnnlCN_d2Bi^>K_p-9?wLu`gl$5M4Z40V!bCx!L*b}hZETYw={^3!dc=BnV zxYduT0dVE{hd$2Z1x7&`iEQX_0?aa`Me`GBT?sw``5h)meD}rqq@(>4pIsS8+s7BX z%C%H;M`Qaw=Uuyrllm7{-lvln56q^+aoIzd=>K=VyB!$aeI z#fSQIixxiQLwgAje@N5QdOMzb#P#qVC zsEdto^!y5n5V#ZoS{Ncw$N6t)8}W8qAMDbeUlT9Bb80{f02UFpX6l_ww1RRQ_UhU6 zpw5Tr*`MM3yaT*R-@G(=N_E_-+*A?z2pd4TCGPVlnrMAg3kJOPuU>vAK9%4K7S_8~ zs;E}R<*~*Qy+|dx4pRC3 zV1kBx;<|~2^_V$F5;*fKUrcNN0|%@Zs3`!vswK;Q6tLWC`4@<(Xr6j>f}XdCALPt2 znYY{OG-Rj>@87dQY<@o_wNQEdm>9{s5@%V?D9P|i8(w)zkAbr{7VQL`S#LknJ=!3m zBMUx~C}_z8m)X;Y^xR33(lX!tj$9rmyk7VXFJa=7l?O#TVM(azBXz#p>Cmu3^Y*77 zzwJSQT9~l>&VnL9Wt*y6^OG;B%~u8EFtp?7j%OhOJ6MT2(PI=2dxyqL@Sfmi6EY21 zS*k2h3kv#vU^V*SL!XEQ18{=%9O;^Ies!54zNBn<-vD%$s49Kn0aPb~RzGVQ-!QZV z4^h^9+YG-!2V5-lLC4YWf}K9L$8IfkkQNa}=XC7awBSQ3c(B!_Ds z$(}@FSmk6?LH`{Y-+Cag1zMq%pi)Yu# zO3v5Hy!|ZmWprILQR&!yTOEl5bU?oZy`PRZPQb)hToY~unggfeTOfP6-a7^4qxwH| z)L&6Y{--dvE=;rqTvBa-cYh|*h$6FKvb$=7#^5F9K6>j5I$O1bO}eb#CuEHpaOp)) zrLD%e4Ij~cFz>B*11woUZ=#)6U*XYkT25xSveo|dV zutO!L%ELZ~>d@}`*HpIeuB>IPvN_fdKZZsZspn}*^y%1*Mf2>mC7YLC?s=A~J8BDV zPnClVs`{e(g@H{|Qtt~9%^()r_$+uo(FCh}=5tz_W!-?FAp3cHewhiT3ufw!#Xy@k zz!ija*tH>j*8}x6yS*smFX-P!+LW&dbqNvG;mz@lGlR-0j6;iKXCoFrQ

    *TuU;% zc9Q*dsOhd#vCf}pSl%>1Q8BQ&z-}jBKf)R{9>W8UmM;~cA-v(YO<#B#!YKaCteFK?d93wUGH7ujfbAiMIwSO|FKN?dNrv^Stjo zBPw>?24*GDwcYLPd>EjN(68goIt^Of7L`Y^xB(Xhaq<1bc-E-10EQzEILSf0{@|atO5JXEKeRXsef(Aqv#5 zdibuGx`b&pw?5ZocPGVuzB=Ot%b%#Zedhw_NGe+1=JBd5I*CwKrw2#r>GKpq{LURl zA@R9HzekCkIui}Uk>HpNl~ud}h0q&a9MugjVc1bw&9Ki~70g8V2A~kb@Ti<(-dqdI z33gTsOEEsx@McYzWhZ-11D#E+DQZ#uYu!FsDWqqZEd1jOc6Sr1dqJ=e=$G3B47-qg zOvx_4{ckDHzh@o)cqdALIZCd#Dv;1;XXe|j2>#-+O1s&4bX*DNC4yZ84b3ccBpz&_ z$Q@q#O(NJRV9?jncYi;O*y9(VF))PtcTDRHg}s&Fl%PJ@w~>z#lrtquQuy9F;7ov~m{4{&=Debi z4!-xi?|scKaolqG;?eSF7bk>T%fjfsJ7^AIK*AZI0Hd6__aEm4gL{HzPRiFvDvXi~ zh_tUibx}_rx2DQVhsV2$Pk^`*_ho@Iv1?xu580GVT)dV#CqTberPxPRe?R;c;!Byf zL?2bT6=9Y1%Qwx{l5wTYZ&`2$-ZRjL6n0O=NeLXH=w~Pp$LUAPXFt~08|>!3WEhT$ z<_r+Th80S&fzEy~Ki%oB%#uk+<`XwqQRvpTa3u-93zlxVm2RI-Q~kqdo$oE88QHRQ z_=gQR63m=F`$OWR_z&6pqKLf!E*HeytS`h=2KuT8in;%fXLM4*6Dh!As{Q^g`WMY% zg_P$8@I->g3ZI(zU8uG&fo@ekPi&L&RdnwxJ=iTiU^6)B^JI+<5`WdXHD`@c)2OF)CMpR+TW^+0#Rf=ml?Qsy^8sKCCp5*g4;f0b;~7!vF9n14$a0;00Q&&(?3p^-lOQ5da+ z>jHJ|63Q8{Kt3|z1MR_YF0zqVzx1n@lQ8yv)_>k@fq+})X5~k?>iY+L=BWOumYpD_ zR!?gWpy`U=gxeLTcxw`Ty-?m}ZS>Z62i;Je+Knr*;8>zhhcx-cwRFf;6F2rIb46T^ zsDr_7HvGRs#k{wND=$bsbA&4VreA}r$T}osZfiU{A)V#O^u$xx!p+hje*a03g)F1e zY-BKpyzzkDIjE#$DYE&Enz?Du-IZVHLGQg({kPd-fZ5g$SnA&lYi88Do%NP$xu zK494*LQ)5S6G%Tl$aOL?%(HYDqr*maIz6dsYD0^X>5WdXcO^x= zDosWcFY1MIN|#wDB*8k8EOH><9l4j_CTVlDFw<1Llt9jNUpckepIRFFU@{a7vx>2? z>cDWVmTLd7=PS#sh{LpR1U>NfpdndF+X>e* z3&nKAz75G%eW3J1C-tHFeJ-K{I2%(4o6wY?M9X(e4zF> z9I?icg3nH)0oBZ2&0@y_q~bYQ?lm(M9t`@2bY^tVFQQ<_FuyBr?zO+euIuqfj_ zxEsJdW~{?t%cb{d8r3Hh^ zK|-uIg{x;N5eq&WL(VRQVKNFn31pAt@X!KdHy{$j&4r`2Q5v#R1pM<0e5)S~I@QV%|Mety}1U*=j zn!JGSeFz9sOq*wx)k|3d7w=w;TXrs-C}e^D122$dVo1=f7S6uTfqYN&Gv5GJ-Zq2T zARX&Fl?ACMUvFrsf}unNKi>@crB!San{cB=_nKkfr|bdrutRB9D@yji2YAPrj$Oy~ zY%!u_?NDCRs@FkdqylP;uX@nU(rr56jz@R3-zWln&Ji^xbb_bWK1i3@OBLQS+wkSS8oGU70G(`Cq|>^%b0lB-yendS zel!jjL}CaQ4{758j{1$_Hz9-sc7R3bW@5VFb~{4*7aOMsxRb?;Lbf8YWHAeMMFWK3q2I zxtlWj0SQ1t%{shO?%9Zz!(7r0@bhCp>s%Iat!iFnH9~*R9V04Zt|%xTvT&VOvH^v7 zma3KS6^^#=kn4tG)H$|}`#5){=EZ$Ep7ag@C1 zts5Cs{1zKes9=~e+>;1*A=PC7**%mEl@0*W3&Y2YF`$fCdPka3ETVR3+M|%(zz|}e zBdJzv^f+cY0acsdExUu+pZv6na5IxIR+erghh!zE`3v^~ikN=#ki5xbi>iArPgG)? zfz3dUrR~q8SNLLURvAQlVR9SkH8@b>#q zzqo)YjAC9E_%w+i#T>g#0O3*W$ z6yP`7RR8z<#?6>7o-x_uaC5_;_ApsCQQLZ~a%+2DVTA0rR0y}*S?D-9m<$mrZrI2m zxeZ1>X2ARXU|Drza2~jT@7PW_fC8uZEbxEiVD2fw#&U_TT&8m=(qH^$;Q$|X3X2|J zUQF%Rv{nfC>X3yK^(6UAqDswt;iQWKGP9>|ceT@ah4VGdsF@6mjbN*bq;fs)_kvHA z=#U`~Wiy&V?)2b=2h@|Nk z8e}W#fU1!Yu0E0s2`+ZU#e%d*B$k-b&W_;idU2#P2&}( zPL=@L?IFprdm0WFHd%L^Uycj_+5#0}u-HQ!P)CX?@_(KN0UDP$BDN!`&CAGI=QmmT z(V=xIqS!KglC{MML+AaHBix)fW$plbT{xcJ{imOwT>_7CQ-@AcP^hMHm7t%2@+B8K z@ZB1)t>KResg0e}q+<%&ci`Yh-i{Z?fkB-3v1_r>J;_@a2_3>cf_FAc1( z5D!qb(9v}qL;Z;WH`oLFrnR`l7Nr4|rNY7UW)~a&Ye@k)$<6n239Q*-^Zm9*m6vLn zJw(UMQX3M!*$!DYrs~MI~8C^;qk)@zG*TBlHxyA zdw4=s2W^&;TnM>O`W4`mNC+VM3F6lS5@%Vh;4wg6bd4OT?Tnt<^%t#eKOa%-eIFVI zzge?*uiv69-zL=ZoVK>B+=#1S#rjisqHptC232rjG3C3$=2zgKLjQw&TplB=QmhCL;B2NaE|_1VfI@KA)e)R{(zyk;*b^-7&Lnbmlh18&=j1PqDi(sNw)jkLr9~18tsduCKms!9q}>fY!Hq! 
zC(Dn%ix?J(0Y^-=ka+%fV^7R+@(jj4IxzD8%}2SxnS&cOHq&HA_?3DG217-m7K`oPK*P z&KfZVbfG|jc85b2fs6pT8&@k+=Upcul{xjWwBR8erF-Ee9;NAk;kzjGz}H%8p@AIJ zG*E4#e?GJQsiuT)Gm7`-$!i_!$I!dC>)s5pdgy$}1t5V3ZULAYD;$P047Tysi$ATx zj%+fewckKrO+x>snN5kjAU1nDoK`b9{7F$*t7w};G8iQPjl~1OU|J?+x{3-?iXOw^fiHTG=yXZz;t(bT9#ryJR z>Q*~4AM6@2vtvKsSxIR+;Auc-2zMMe<*lngg-L(#5Jlw$T$(LoYPQILf7&gs)lC+6 z9-g_^u-bZmlSv2C1}R-yY#RBg%LRg7$vlXc$PdpDZVtc>Hbz+VypAM44Va9q_|3*| zg09AyS4&e*C-)Y7UTI`_j=n>rQmEg#@Lg|@*_@8!a1eUsw`{5Tz-6ZOZ`k zd7L>0JCYt=fIv_v#&|g@bbkN?u=<~nWN?5555u1b*;whk+bMZavTzkWR|okl--v-K z`J`V^RW(k-EF+l<_V^oC)t{)-q5m2OdF^6-u}(Jy=+xKcEYNE-U$x-Aqd$Tl`^Otag&nKkRn zjlu4uMmoTr6wgq1Pc|~DkJhF2&;NL|klGpM__5gaJvp%32p;;CR#@!8VRg{oR|;zW z?Y=Y$$m@->HhlP*s4Fe*6H!aYCMeII;D7;?mo&6;-G9nw_hU$bdEdWx)IBAl9K=F* z=vxYlydlG%^E_X!@+)fljW{i$!|ghx&|9EcY8F%**)u#mOS&2j-<5w08p7>aFuea% zgHZ#~F?}XFD1>Vmard}G4tI^|(*OwXk;(K@EL8GSi>AxlO@m-7YlSDDrRdH5CM)p& zM|^T1%@&p6KuCMO@GXR#-s)wQ-E1>Dj@v{JlSu>o`@(M>i2@swKHLjUH&9d>JpUVs zG=V<1#J{p+!2ssOhh55MXSK+TzB%UqVBgp1OfyFoT=JDrnmLMh3V%BI#utu=&Y2>W z*1lf#2!80I=-1JiF{Wv?Tl?BXr?&Mu8evg;eE6?*cKPsWbfWvt-eG70-gPcC5ItjR z<1Fk?{qk2DZArHQj}g{xhaofEokl9N{l(wPsfz%{0Xud>0e%dpT{=US znkLO+;Xg4YLd~S`m##J~u0kwchRMQ(Y-0#b2r0PEXmtObPg~H?oexqZNPsn97>d3U z;M41^C}Z}<{j{7k;(|0xXAixBc@Gl?o67t@S@!wi4nu$QUm{Nw<&OJ9Xwqk|)%VEW zK3}SIezp?MPb{k}X-oE$=IvMsOWX-v1zfG_{R4W{=%tT|07D_M)v=TC3pzopx&yTa zh`PDg4Pd++nfz(Ipuldy01ZzNxm#tTLg2)52@TC1^DmZg6X)fCO5-T3Wpk^| zhaCKsjq%Y57YzF^x$6sQTaNT4GLeg2B2yttR8w(;h+`G}pY}@u^wxR*Ek@6ufb~*u z$&H5WwcR?u0U9e^0)foUdUjX3dovYxL%(l%s{pZz-x)QA0s;>B_uMDs;F79UTUjX$ zrk4JuH7Ez{KREc;*JzmWd*;&Ep$vv=(V{g*775lKRhBAW?`Zrjt46a(8rq~xUByZA zN5+)lb4gjUr-M^5QUcm&J-M0c{pQ3g%NB~#>E3IP0oUYB%%Go64Pd{}L&&D<4?m<9 zbCdck*0n1AwF=Z1^Q$!03j_eCT`G zKI29Bi4DZ~VoOxQ*o{RTLDT-S=`r5RmKuAEsf`HD?HjXOl|vkn zEejfkgr_-Z)Wo)rp;|f!B}#xdE1#@ETm9M zmqc*j(OfwQy|RK)t5Fl?8)`gE9n3?0I2>dn*9pfoCBJ+!82nqRt@WhFZmTB<+i#hE z{o2cBX<%ZG99O-)A5MXYX?bJLVi9KPu%`T9Q5O1w0S6V=p|YangG1rebne%XyWh2==64~^DPxJX5x?Gh*%-ojFL;~ zONi9flt){lcOXM#`j=|bUFeB35IRX}UKS1X zdPJvILsKsf5(Y5IWxX5!UFZB?q~-rr0!$$9elY4;$ zXA-wX%)zVYX~o@IzsOAshTH%O*A~DIU=|0G|1?t8XKr1(C)ooN8W14hqqbc$?tcOQcQvE>M2As``k^#fUJ$E8Wn5)ROsKTY~0X?nng z(UE>g&Gs}L?inbpdMSY|mX^&4%sI84UZ*dCAuo;H3sNGPOZ<3nCK$g9uen+FqhY;a z3$}KksFrRxnqf`@8rqdD1(0f?sQ3eo&1J%CEGMuqHDY*dCb(zo$`Fjv=jX5Zm=)u@ z56!Ut@ZHy+e~Vti^!G!U;hSirIUx61gq`-m_BwNS*e&Chz3ZN9s99L5;TtiUV@#ME z7nxxSV6sb{>R#L&ka}^^^mOn1@8Hv8g#Ie-rv_PREGpLl z=cZ+2!E%6Ns_tH8){Tfy9Bi)A{u!EqY|%1JW2b)5trYZrzG^?FIECy3ho{z@CBTg(vrCOcz4uupyuzC z^@0ev`qq+CMacZVa3o;2a3oCZ*dJMBua4lW^*N5BO+twCy4eta#hNFCiEu&x!FTCg znE=$%5=T-l-B~`-E1+K_#9iVJ#%j^V(&Q(v!kG`9QuY$XUmuds{~N>6-x+4Zj~Lo! 
z2UnW(eExt*EGA8jl1W91p&Ek}4^JV?>E1q%`1$XFw5o>{ExG>X)sG`Gki4SY3*Ibpv$Yk+={PeG zKoj_}1t9^lF~pDTF80xHFFpG;Le^i!-`rn(v2E{4#hdedIJQTgjRj?L(=gL7toZ7= z9FyPDreufV8}wWd5=c{6JceLnh{lpFO=WP$JP`5{?k#KfO}70dk^L!)9a+&Z!&m?l zM(Unxi+PRk8-UdVK8bvx_qaW!xrKfs*Eq~_bqmkA4YI4+U#S$tuj`3Arc47R?%-Jj zbx;)gK`kq5WIsZ#m%Us>i^-PQH_I@HAj^6;WOO!f@R z=Km%UlUFIJ@P#f1p=k{jJ_xght%w+|$hLK~yV_=JhGa zl>^Bs(#4*18KX>G9}!{Ulhd6p)Zt7>ChEb~RD)lQ%Z4{_0y8tS6|h;t!g!4f6{2}$ zWMq2Uh&hV)sl%IL;J@AEM!${Poz5Xc^hJy=RuIrvwOpP@^Kp536Tk_Mb{3UCh*Wu< zIPY#0=U%L?T8hw+^wezBCptEM?r-Au-U^wvJsi}t8LxI)@go$wk2JhE9$F!fp#?JW zHmI}xQnRk}Y1;GYw<&~w*s$kFPB~2{yG&3L1g7|$A9Z$iCJVcK5c2|^ox~Y>AKmoX zya|1kU~1a7Hd>J);#T5l+uV!j1g^P-*5~x}Xh+y(-M_-nWn-ewVX+f9_0PRyykfJi z4|O}8!qXn-Q#Zw3_g5*2ynx#-4WJQqZ{p>``hEkrr4p#t9B`bG_Y-KRY5^rGB69D6 zgR86SRtiPH)?8~;#iaL19O~4>b;ogK@Z(KFn$^`+5%*md-OtFEKo1;kh5MW?`kXIL zo01aMo$Zfa9EM-Gnmb_C@}&~I){6ZU>o0e@xw#2+L8Z%vL2vcgTSStA5HKd2jfAwK z?w?CbchH9Ce6>Y*Zi~q-G$O8#p;_jg6XJIt-u4UnTEpAOQ-6NE5-xuJJ_vYo6X?J0s{k6ypMCu6T)f5&(9_=&N{>d zu_oy^$M{`i+)fYeDjRl|deTE`IHr@I9UkzU$aHxF(_sU3alW(Im437FC5b5CJ%SW5 zue`jU*(u(SJSv!OH&%Ka6dF-9PV;y^T3_zVHj;4Po@;%O;<-IKdCtGfTE5)i zce6*29m_!DS>Tua_Qe#RddIo7ipljocSEnefs+lC&&mE;m)Obb)@(~@b~QI&4H_5S z`;Gh_4dy2y>u9}ZqmV~PM8xy-aBH;%SR;;_(;R43jlC#kv7?EDdce2$Ip0YER>>Go z=!W+JFo7pDH1=%H7i3mX0TF*$J^C9qNk>pTv; z4;yJg0bA37v;i(p2x&xcDnbYd2?a*Wu1@J;56*vKM^8^r1M2_<%!>iTNu3U$`!~jG z1WxRL75yy)b^J5mrvAv(%xtk{@@%~~7iMxi(&g<8td00#OwX?XD-l?}d3gz0Za}|o z64`2K00p3V0n1Ylh%yy%RMmMF0IpC4wq7K|M$r`DERk_|i!GHW^ zB9!)gGY|?UVJ8c@*yg_228I9_3o)dNy@?gE81xhy1`dBn-&ERug+a z-vxRV_$)WAxHm9TvzTE3{*FC(vdjZA=P_%fx8_ZS1aKc&=I1lyDA)7Tqm;9e6krx; z0gDYcWk#>A?dF#XYGBsu+B!S$0}~P$;a7t9)>l_OHydtZ+5{Hl?Bw}=1+CX^j|{Ds z<6>vB?NiHd!l$Q;;)Luvl{F-F%SVT^;o>K&rOC<3n+Hhmt#9Ph@apmPs0Yf<&dy)Q zbG7S_77~F)cev9fP8hc!cD$6n*>j_j2bhDUjg5`NMNC&V+x`IV%S&Bjd}QQkCWPC1 zvR)Lu4F#$MXCIgrSgOm}z*;y4pPDLH<~wfrxK#qf+vUAqIzm4RbmP|{&$GRugN;dA zg22+EJHUmAdLG)AXSo2E3sjjVGGYx4Ghw-n`jd5FP9O<~;S{|Ce9n{7*7x7lcG}{-ABt^&t#6$w?_e3!l93{XcYjhY-MV$}_f^P8nfJ&D_y`~n2w6fLA`b#RCjzc7;bDQl%a@G; zz{@LJaWw}J2&L=k7bcz-g#ZL11xY}J6kXHymwXfzTrvb75}kQ6m)@q2d~h)p&w?nX z`Hf&AY*df9K`p9W2&kVq;f3Y|02^t}HG#;lVsx%xu}%e`=p@JRS+jj49jk z@v(q_0HUgugF}@j11|h?;J^F(dsM-@qd~^jw1focNQDbfk?pFm@>9X7&24D731e*e zY03sX7!i*g^tJh^9XMhxE-q$f+h=EIeltUB$Li|p^YimPy}g(F3zyRxHs@z&bq=J1 zl@k*a({H=l+I-)U{yM!zBqne3LUDTmR6yrjQ1CWFjFq2%X?l8kagmRS=^X?=BJei1 zzMh+wc5wXgdt)OnEv?kIj~}F^?KIf%rCnSuCnqPn!*HuAD;pXc)6>Q2a5P$fr1^DL zC)-ZRNK0=!@_@&8&o#xLJ$tscwx*zgsxV-2;zG)bZv9zOi{aF+P6gevhl;Ikvn^ zjE6TnK0dyH` z+t;ho)i^rr92^L^ovg>xs(5;O0#X}E<362I(b1~2{Z>)}G`FOnAY@Rru&Aika<0;0 zb0i}pV|{(St?ivarq0KYUq4dQ9ot5uA|XLUMd>$qUzH2LDFgc9`=V`WP6wz1=u1T} zah2uc(xJDCj!vuBtuq+tnSDJUEp3>En5nh($jAuLg~8@j0wiSz%or zmlP{JXz0~GHc)N(9{h8^AY%8vlap`2%Wg?+oiSRe(x=f)5ATW5%HrI$1S7JU$W^^z z(B+%;Ki`PqyO=fj_U#*fM9*-6v6jcxY7o$1dU_g=?dadOCJ>b<^YZiWLw(J5ufIYA zc7u&p77#!OFegtvQN;rq35Y4-7~|4XSNGR?e{6z8Cf7^sJVWE6)I3W>jJM-V**(pPZZ=Q048#{@Up^N`A0YNn>Mb z>Y)-HPEbIAodSnlW4M@srlzLip8|}FlM_2USKW+@ns%rl%i?;E%hFndCq0=m@btWET)8jhJ#AC3`tUeVxtBn% zaj=?0%^BmuKNqeDf2KJI^v^{xo{rDGlaEhpZ8iu5_eBM`sV$V5 zk8ddA=%w$*%Re_4;{j9FVy5)IljPyo{yR_`&R-cZ1M}y~^I29|AfF&LX+{1*cJCdEM)C z*L&!>$vf>46d1^$)xgHWa&kVSvr_NSE5!_3Oe1(R7WMQPI1*1L4vvei5KL6QBj4Lz zFd$e}ka|hY*RQ{3S|84PA_yLL${q$5RsjhvHg0ckM@B|sQpvEfv%h){%NZrC4Fmf1 zwp!h5wIg8V;i7gKIMSw1^nM!oY%mB`0P)i!-qbPw@KzWAdqq{up|ODK%5$DL=EHLRPX*P z@EAbT{%v(OaM3b1R~9%QIyhQtY$Bp;Gsk%tD#eYwJ9!#bpgNCjM%4HF?&AySpERBjy<=8U>+e^u=3{%6{Y({h z+Q;$WDj}>Pj()hcwD8Kv=y(PGO3E1qbeTV0^+;UpGOBaG@p!xODanUl?bp6-n^vgL zqYt@9*5lFnTn~2cs0&af+LX%q0 
zH@Il_73A^l?ztaBm-Ad~;YMu4uf3mSZPY2I5t4@~XA5+H(72x8O*sW70sK(ZK`R_f z7{0gG1a$wMt~X_bL9aGira-=sd`0C4{9eO?F*h4+*SlW5<;WMmTA0iTX+8D>A}{J6 zlO5PwelwuQAox;LncYS`o}F_pN5Vqs&w4+(v_8bvX8XFGZj^W&Ej+?Wbce1DYSdnQ zvUyVgQe|mra$2c#@%T+swq@Sh!n6wV^=J^-x^RZKWfa)C>05Zy)~p-&IP>(1M|5!M z3=g_YL>2s5r+{=S1oJd>zWGnKqRnrD+OS$I)v8!5*Qf~{39wBs0E7MuyP4fan}E3$ zZ@bC)Zuqp9@gnl)dNt?hOt>2>3mAWUk0`luvarn@D*}Jvxh_m05yutWk+GG&VlKAU zx%tc{3kT%}CkrAm3Q!>oP~l}!#c~R#6l_+co!~38%`*DcrMqfQ3w6i>M-sR`!A)2q zi|PUmU0S=jkU+CdyvIE%p9oRB(KwbW+%ZG8T=O=ysw%j`B6;(GPV@Ptj1F-c#^^oy z#-P2@9yrb4YQIjPsa$Di&|hJEBbt+7I<4}97G#=!==~O#?8jioa?B94 zXFtpJZ-MGkT0if}Ct1XIu-$z*N>fum6dgtZ^H#;w|EGX&5+>#?-Av2E%uN#+$=TxU zQBqJeUD8(QZyF^#rGD?ncB$=pcblNKa3@|7apO!no~?IEI_Z^VrF8MV=2(@49Q8S* za$T4?IED&Fo-rMTB3NtRg=5FsMUn6U`x;jSm|0ylg48As4haa%ZjCjQbUZirIJbl{ zW0{luyU&L&Q|?fE-Z`$RZyG>an9&s#zBg5D_8rk`3_D%-G{q>0^hRQ7-?Yg-T^fnXi}n&iJLH9CM^{jrw*Pi z9i1(wu&^Pm@2?c z?y07GVuQX%QzA9BBg{d+Ymv|HdXpyH`xaqs&pD85bwWaTe1@+j(y=JBmWxyZZZkO8 zmBSsii_FRJlknE{9d=o^=Tm!Go}Tw`JmsLGpY>swt|n0evRl(L%P7O-oNYwvDEoYD zcAXt=joQ6T8+pkfbJ;buVi%#x?hpkAUCzZTu(F6luvvJ9n@sfDkuvPx-YuCS<#i#U^Q;+uGCYD|Vhy@TuG&%(xKHiyV$cXh@@=uY}n+cY#^ ziopdXR!>zWrHG2fd`c_3D;#B;N5knQSb10;B%A8dq$#2&`oc&Qg(oA~o} zTse%W99WR9laR-YHpxkkyU*5x1ryzi`-caoFjON`eCZBd{jfR&V0~?Kl@|I?dT0Z& zlguJ%YvLYanTI9f>zBR3X$)A?8cNb-({@zz`0sj`Tkr@RY_$Wt%=TmiulJcV^+g~b zn5COMMi}qo@k-A-(+9)dHh$FC8`wL{7bKukMiHJ>H-9kcp|t+lwP4m8`Q?Z^%SCq5 zQi^xXkUQse&E=9$Zg1x*H+Q6u1UXBCGdbc_<({Hcm{6}1jWWkka4=hF2c@2NW}F5E z9hfth={@!`Df&loW1ZjvF0wBY4Gnk_S;4d-txekI<~~1B@7MIP=semx2m~}Xz zr=>uE0Y|v>Vp8vRnoeH550=$ef+6zgLCublmZk~m-QeI*-ch(T(`M%*HwL--=~55t z%hS_Bb30KLLTL>2F^Gzcn)~>t!j~bw!PG2sX$GbmywpO*5-&US$qGV@Lh6=3CZI?? zbS)YOLuxkY)2wRCgvuoxrMxvc$OjVw;UO1q2$21WO#FF+gHz;?J;(K%LS*W?^X%hl(LHn4I%ybSBBDEz^i!{XL z<)M+7YAa~aNR01zaj3C^kXPvVIXL4TR?IAESc_YM(Aw9o9%*k@Mu=a8ArliMToNZu zz7UTSZH$R`rzhEXN<_v2EeA)j!0^F8af0G)rTXo+^`@>s9wQwsZCFv2$@tup5FP}7 zh2N|^gy{-8896_BQ2VDW<$zIYA}s@{CcmUCri4o;B>YOhyX39gYQ50%u9E=Q-HJn} z@oN)$-F<%PX($Xr@GL&->DoyS3S!ym0@V7e3`14bdBmrmDLb1{gp2r?=v5{y{R71i zzrnmA9zk!HQA}o(DD{_j({S9&b{3}k&=;M{;BOxx>fN~-uPQ6S5?Ng2bnKHW5tu`T zL7G#fXaR;CSky{z;@)!+8M~Wa>65UYEms~bI!5?G1lH4C z{`EUctCsuAxeqHzMfNwrpO`DdtGvpKNy}8x0_($ESKjZ;;18kxz~%lm7ByMswSLj* zkklGHtMbgrCKZl3O0(&N!A^Fr1fQYD-2x$gnm^K-DPgqhvJdf_V>RV!meJ(NdUY7Z z`%(m*;kB3CTm@9WLpH@PYo_GL3{L5XB7XnW+CN)oDG5KBnZgE zrWX^~sxVZTnRgmAr~IxqJ3mmxNk-OHv(T9X$HAdyeGn0WO%#V_K9L?WwMF*~5754mCgDubYC|svz|^Q6;Y2W z#ck3UN?fxPDc#M^-?7ys&PJSOSIXDWRR~L0F z&5|=97dtFr7g3EU^h8fEZCarlmrR5}au^enG;m)l-Rru>xWdq)qL1gwM8u|tSF*kT zX7nNAyThZ`>h3-}p3Z806vG}@L6}GLT{YDvHP=n_O}dW)2dZ7!=P1G3Dup?XCAZ)@ zG;O3o=Q> z?|KPCR}$ngZR4q#9@L$N0)sb$cJCr_r@ItWi`HqFWt(?$GD^Cgju+(?V7Gffc4Qfw zN=c@>8T^hp&9z(|t3*w}{B>XJV&-vb;Ogi%666V=FWDBKJsl&?5uKFWtQn!M4EK0YUn0v)}I zSx8w4OBkXz{%nlf=(rsW4jq~I;riamwv?cnlK4fZFCqek$Cm5UqzER24@0V`5{vcR zuz!GXMt2fQ%xIYvA>_E5N8Rw-6c1faTt0!CpG;Kx2Qw`duP+J1(i=rbM`J9*(DM!Z zl-D|M5@@<1_)GdNCWw;R;$=(_8=TS=EE*(IFh)-xf~a?1zo+PQ)y+s)ix9(ygEj8i zY`HiE&pg)-zU6cem9Rm1w}g!som26Z^VvqMYgZDfsVlEFzPxpIq!Cjp{h9*~QTl#! 
z5?*>v^Xk&Cq=YUG5mn_g209=!Oi`?{X6;1wY>wK>mXLZ;2zlrTPl_4B?DydkI!pcF z1#LSvg&{KVrzF|-_=rL!zJ?xi6o1mOp5fKGDuK9w_8Mo+6|78Z?VpnDp;yx%i|jt; zC}}o0#g6+ks$pXgDPw;wI-S@EcazE)=0mIcU0lszHqGXN5JY+!%u?|q9E+t|fKEw% znyP|7RIcB?SZ8f_U$l>w?{dA$sO+)JXCx=<8{08bcXQH$>guj6Byf3VnJYF4eMJ?$ z`FVk;$96K*L}w3raI%EpTFH~0N?N+v^IKYB`a9)T#|%H+jeAD*2QKcQJwT$Gu4 zScT@id+mPilTagrp$|H5ZwG&rtmq*_Ad}a}ji0=?YivGq1{hvU8O__^NDSN< zTcHK#8hOdL8K<|Ws&h18RItiGosdrfB%j0r*e2s5A-36$#wp4qu>C=aAS;K^(494}h zwI&`bB9{9bW5wAh0-9nJMNE4%tLVopw}x?S8HV@;vc+O@>O%b#FZ94!lYV>icURrQvc)-JZeZ<4@`x zI8?=R?yIEP~IG+h!uI$EZwPHnGkT8Sw*YfI{EL4w$8T41R2RLG_306EE zkNufl(blu?$jKjb_D&7w9`-7vMN3Fz^2o0-qJ73ke~&a>W(JzFsnId$1VycLa;4f% zk2wQ-T+MT|>-oc!_p46K8)p2>xb9I+zm46iu)gaoGvn$~i6NgziW``OEzpgBRIGpGotiJ#G4hXm_KDTfRfgmo_VhZ`mAW2~Vp z#*joM^3HKubZC0{r#TnwEvt%A#5(R*@{Jkdn7IjgnZ$0D57_fsupwc{)hc&V$pxRj zYF|{fA2W7?M3Yb4ebyjFiO55l@YFAOVPrT*?8CKnck_oBsjbV?TzCL{8xI+A%Od*wgxX2$@_!B$(dJs9h_%lf;X zt<9&Zpo@O6fFjXaP3e%DwS8ux@r#P<$j>4!?l;ZueVM`ERqX8bC#o<-$h=cIJ|H#Q zd~8bc5a@R~>&$qZ!}1xZ3DE^6EcJ)Uvq(zOTY}}p^c)ri3bvb7kIJ!3CzD$Y>1iXZ z^t_)2K8KulHYSEdjLs%Dxo+5$Sy2x~ZW9+`-)>@TUHKSS^~-{K5usLBz|h|*NG--OVOLcROg zKbgMf8@8eev$3xnoRsaHO>(>|-i6uXWXy$g0jwQ)sCcSI9_85T+$}P`s8@hqq7vsT zV6Wl}NlhJpt&ymv*tUeBwvvl(5KbM!S*ubfovi>DppqQPVr~&X6Jm{Anch{JIKS*_ z=5UG#_tw4Y%S+@PBpyoySsh}}WQ@-cHlA5IRXumix(#$Q@%JKz8rZA?l37|Z!qxa9 zziS18D;9%F%oimS-w)GjRW?>4nUuk2Uj_z-xShH!garSne;eSRwbi`38#KEs4N=Q; z7@8Cb98;#D1qZR0S6^@218ZDshU0mOqVoFYPBxZArIMx~SX*^WZlHv2CaK|OD2!Ao zMFJ9D$d?~>LCo&5ir7#0?Idlja{5{ScrwA|Yav#Z;b~~%{7hI`ZQ`3URac(SA^yyL zDM__12N7KuWQYG-sTft@{bX1KRT(#(lT4gJ08eZrC{ETeHut zu0l>9#8F_E)6ZDXlP_UWh%3#ll8R})d|VB)lJ5q*Q=1nA%C3NKi+nR;t&crq)nDBn zb7Q0hT}Ev-4B*{kv_@xhKvm(wdrf3` zs#?EhDQZYO)7u}emA3I>PVbo&3cFy+MnJBl+4pR3g{A_Rx^wIyn|{%Z;getp#L#j` zX_=VqyF^kxjqMj$vp5qQF=z;VncJ#|s7XDHV9}(c=Jt4AY`T&L1>;cKvASyYyr>vj zMfEi@49fclyz2V+0Ds}sGX4-@7J?nsh%CjJ6*cC=@E1BZHX_;;N5a{en&k_z7bry0 zN9pntvkPxnki58(f9J#ot#wGWWM`M=``B~-G6DO5L(j5(cB7l^z7_xeSuBRM^p<3JXy3m zmLB;?kyTj0e&uUfTRg+Qxowg#9MO04FJFeQ5Wqq^-jedB-8IkZ=G2W$_T;3LEk#HZ z5(>^DLamkgJlHyhrKbJ6IgRItgUf1b0Y$-hcn}`vs_P^sO8qLmpta&G*rMm*n6up? z$a6wPe&)n1RqWq^*po;wR&m?EwKe`hZCmz3dU0^DyUrjN>zV$+wc7`#7oc($+rbwI z@|U+%Z>tOPG~jA#PHv$0Mp)VvI)|dZP5Kg7g4OrOtOG6rp&9(u7KgkhNKvgxzi)hyA4OUk zGt>C%rrfkwzcM9G)m2TGE*R%aG?y7QagvYl+i*s9s2wtbeB=FyYVJ@u$yr5kGjj^f zzreGFp3$x~{x_19)p6T-Q<~Qv zK@{;WsNAnYWN4^ds)EhNLB&kIrP%L8(vm6+3vjq@cZLMJgZv>uN=~dAgwhOeZTV&r z#+azp8P$t7Uc7)2HXrBKnw=d0<_bo(Hslgetz46*I5iD5$Y@^*MZq=6f zO}5_8gSWRd)%MB_nljeflAL7$o3>n-Ez8S+7ooY8yp2s`j~yYgqWQxb#6B0@VuKB? 
z9j4n={FOS!8IH!suxU zZPM?qcRs#5niwUqV+%FOIaOD>eIAk_F1W>eXS{p#rnehsM%8Bc#mzRYhRVt3@-XRa z8j@w^DshF9N{8U--X5w@$^q;(iLi34nKWRx`pJ2*GlzLO@6a+dygXb(_iV8aH*(_l`==?*?xc7p;Y%_M%N(hI z_P7};Z?J1?mp{~s#niIjY{Zg1^z7?0Pe`#EN>8dx}XxJ!zA-;J% zgxNp`m$EKVYlX=7 zn0yJtw%s)Fwiup=R9o($Xz%X$SAD1*Ww{AV8#Koth8jqu1V_@g~57x&;;d_u1e98M`MpJ8#=YD|>B|J2C) zn0oCFH#{W!%k`)>o99J>9AV8&E~Cm;u&HCYe=!eH$z9py$q!!FlCZdT*nMp6Uvvee z_6byI(N^-ztsM^%Z+qE9ouBbR4OOC2!xmpphF*#cmVygZm4Jxq^PWVKb$kSO{sJ=~ z*Fo8=D}f}&kkF+#;o$L^jipg@^BPE8luf5wtG|cD&hUgtaqFXWx2D^dLbnMSe_s$*j&jov@Du7)98~mVaOs}m5fZ&+qkQi*1Ep~;dkt);7JWt<6w?IlYMpd^ zY=1?@-<@k`ELxVe)$QN*%Tr*l7t+5__I+@mcojgRWA%8okC_jFZLG zQhLfc)bfQdnZAZ!!W3fTa4V^r$)g~9N8U8%Y|wS^^)^E+9|Zvw*N(AxI>l;gPL?cx zcvyC{esQjFxA^-9)Dg}ooykNv73Z2AAZKun-Mbd6NT(9tU=f!z_9orGFzHwhqXd^b zvd_Cj4V4icXMm+~7px+?2m^fu*aRT}!BCEAbdclN$M|MLa`;=xd$4|}+3zI>@9i?% z{a3NdxgJrOKdmZ$XzdA**mg`xyoJV3Fz2zvrcTIIdosDLe?1HgPD-4Qw+Yh^)2e-n zl79`UorolQv+4jj%=kW+z%lWzVs(3ep^z#tgvs3lrjV{)aq6|pmI zQROUy)QGYtz&2WKY(lAo6c&+_N@r0JG-ko>PH&)XF&%O=`!hKF$dm;OO1($vygr>$ zO+Y5VH8P*@9Dshn0BP}|Z;uz6*YAKJIos!!)Mxb~losQWFR*G*Rw%Q0{V&JO`Wf5C z1!;h&>KAc3y2ff_XeiLHX_zRWewxf*+G2RQ z6Lu7cUF$!gk#H&D;EGci(_Sax)d=SNr+74ff+a=>5RISq=%oZgFGOwVmKgOqSJOYc zh7bdrMTNih{Q%pjMU%@BVkF@2K;OM>r;Jt-`Mg_ge0}+z$^hhhiw58fY|Yny$E$gz z5rzhPv46a*aVxG_6X^kxSLxO55dv}=mwU`ja;^~;n|Rk>H^qKkWTs$-L@!`g5lHw3 zmrv6HPz1~#t7mPN1Y1j(z01g0)&xw8@JJ;E9s)f=Z&rpso%MtmIPN&weLjqCBJ&T( zhb=TC7#n(nUQu?y?LxjY%K@Ipe3VZ)(&}-v9r|`<^_Cq*B?Q(Ax1$Q@bRa}!g%1}q zmmB~*mGq}3`&|gmBQ0cUw|%wZRkZ-n3QaF}RX}uPgSz(!RTcfwPA#WRcA124Zbb;K z$-(%qvyrWx3@eMioqD#tDl~IQI&$qjQEP>@7#=}P5K{ei)DA<(T&x2FB2IFV7h31? zu~m=qkO=TXj%D9;98FgjljuGHi}SlDF|eC13h2I_Nzu}l0MKr9bZ7pszVY3c`EL19 zx?MF6PYH>HV`nT52yf->qO~v#$QQd`$wTytwK0JX1w@*46il>?qW8kK@teUm-|L5* zT9G7Vi{XY|!C&u}QD6=N9s}ut*Iy4OiiUB#-I^QNf(S^K{I`dQSg4?(RY!GbEy=r? 
zuK3ME*0AkQHXghPNt{OhzdcqwAah@joJpAk2VRP{^&s11aP0vQo&(&0m)%h{md^P( zMX7k#1g#HWA0F{{uFGD8B7o9v63>c>QoYHS&xckMrJ9KlhdNj3u6sxV_U7-02!t;tJ7{*mRH<(32hUv{dr05$b9a`VCY~DqDt3iM6KgncE?RJy=o3O zUu~hiDF5D(Id{(ZsNQ`)rlfK&*0lT9=KXVP^h5W;h8@mDI2Adyd&QjDe}O-6Gq2`M zGVLCRy1Ng3{x{?YhmWI!J^Lu2T~Ejb-k-?;dKD;q0dB?sm_aw-dOcRQGHUYz70G*c z47#rhQZC8Qz$yFI*m|Sn?71l5GC&5*AHacMKOGhHxiGO~zSR~VzEQF1g)JZWiv}?P z+f!AyUH!qLNnm(Cq^;|YtO=r=I+KDf10>c|=PVcQP$i^bOP;tkx04VU*x z$vV~b*BdF<>47D~iBtILdy|~d!XGlD7KX1)ZgthJ6KVlm5zv(S^6=%Gi+Lc6EDeEV z`@aOwE6IPM27~o2o&wNIW5z;oh4|#KyXF?OQro6#G*PBmZ$xm%sIa*NL%f*h<1S6Bhf3V+9@PbJV8~;8KAgwZ)`X?IJoQB4f~rPHwqgErBY0Nbr7DCS)g8GVd)1**VZ6K@8Gk#ALQh3 zckSIHX`0&G1N>h<;lV(@O6v)Cu1$;mEnlB0GXmSC`P|CJdJ{z!i5Z5VQRwSkqRGR}^Q6T=f@?#E&KkkLo_hDg1x9~~S3XloN0r`u^ zVd!9Th`BNq_50+ImMRHHRikvT>tr%c^Vb-`ZiYKDP~09TbNvRJ6>|CAY8Xh?7pA~w zY^;o7tvYIQ@TZ-i2y!TTG^vx%k2`x@GNIT2hrbx^gx0*uD@v-y4_bjwe$(i(qgwy4 z)ONd8kkX3~@HjM~K*G`W!rNa=PDjr+O%W|xJf;K;gO1KvyhXjCoL~{T-1+^PP$64v zFrJc|4DLR^`9VHbjGFKUcdt0 z1QNQ;GgCbf+=9vbzxe95lSEPl967ggx1%2cV(2Ran+NBZ1Unr|1hqgY4;MvD4DPk@ z;Q^Nhda{Z<8l}DZCtYJ)#^Ud7#?J$k6eJ9vqidwDpnd#h>gq3-s-eb3{ewRd*wTNC zoT8SbxK77D9Z{Wq3aL3Q%NF?ShYg0hLb|L6WLw9F^CUauFI@&0%MvAr-QseJ{S06C zu1P>3W7yNz*o>ZKTYEpBqn{v{{K0^Nijm$(s3l5SMCh?b4FM_Gz)lMJU14RS&z5kg zNxdp~fJ=Xh{~K8%e$hfX+x&Z7PF`Qr%Y6S5X$m4J{h!%y$lBs3{-q%^21~udD&4ZA zlJ(+yW#2J!!#g^9g>3@trDE3$-N4kuVz@Ayp~^6;K8A8NFd6iNBA%C zlaZd7fYQ)N*vnougmyGUs=tk+^B~~7eo~QEpBVX6^9BVI3)(W9WgW^CrEzD^0P>CE z4VQ^41Z{o)xIR--@Hi)Be0pwQ6UJ)^M{Cgz3e-HeEaAN%e~niX^A@r2=*c%i+k^g5 zMd^e%T5_AZUv7Y~4__Xzx#b*;D|iDT7pMQ8b$5oK6oY3_z4ZAH5K`C8Ac&qQL70Wr z@mQh|KvcZT7H`GSp87RSE;kbJIT4Er5d6)^94`Z@ar%#W>e}^Atqf?~dqjq8NTJhGCiGSNMh@rx7|-D)N+)mv>JA2sg6fBD z2RDy9E<;DOw~ZC*U8cTU?SCC+IpwuXb@CCU=hT~-oyuDwXg9EW{Zp8 z=@4rgFVy!nkYLSpi|_40UOp4a`!xCCazC{Le(df9A=PdHQDWUxjt{Iqt6|u3iW?-R zpq4MTy^rys@7v(E- zQ4=C$GH`|BTS{5~~ ze!kmB;GZQddOPN#Vkf*stXkNjjr|XJ@pM6z1mv(O^}ISv;f0#bWE1T%=wc)|PKcA&}Y=&nk=2?Dp%KZqMl3c-?P&VRzma(U(nJ1Z zKW_fzliZ3L#s9u*o$;Cdb`1EiwSWircw|G&?{YD8*>cJMj}Hg(&9%Lox@&nPx_7Fc zU-;zP_^{bpJvREVe|yL-c(ax5fFt%7POs`1uN1ocNaVFKsb&KQdwO&g)v15c{KboJ zCcPu+4}SuLqt#VP{X?<%W21#seR8J*$EoT!Zc9i;zf z^?gF?{|{Z%`2dMq_3U8|}4VZg>v6K4#-zOdi`VHvrMBGXAcD_B*@VgS@&=cR=R|3~pabQU)!zv8aQ;cB6l zR^FbrlJz|sW8pGOU6oMx{{wb*#cViyfI;bJqKRWk6g2r+5nbHG&Fs%P7}t^S!^4R$ z^1I-GpUqK81N+l8`i0^8&+mrv!w0y>zrFDP3p`I~O-;mS<4Ao9>k24+HoWN1@<&lT zpL%<{Ga&C)EcRwlFIEZxLzEA&vK&f1RBwxUAoYcwj<(ZMR@TWpdOlfwXE8+j9LC*L zp&}UFe2pTTM$WviCYB^_t8RDVyquJb-#a_ zz$@qIC%#meS^bnq4B=4MW)Umt;c|F3yj2~pg-dFjc5#6AB^rp(U+0$i?IF>`*+yL} zjF){@iT~^jrG`HG>a5I>^UxtPx3b~C3<+$g#&7P_BcRKkF+Z!yq5!j)v*+zA!?>)p zNYqZbD=2{_0Ef{xIQTgbk=3_mg#h_szsltO;~@W}U8uFzQ?{71u>tBOdihI3$`*i1 z10||B3V9{Tu3ipb-x`yNL)j5x(np;=r06GyWBd_-I5svLPbDG-UC@RZv*VVVC%}@= zFUF0`HVkF?F@eGIKz>=0sV6w*ifN3`e`0G=XjKW?_M>`W7gGCjPDX!5N~FU$aS96f zDe`B=Z;YIAR zn_~@R?qtlrGeGqL60`nU>OA4rAn*2OxX+N;X(S)TIT$P-qF%28C>#s> zrG?*zw?vF8hkZyNacTn9E76_ALg#N>B?i<64~5{20R;I^=*Lgn*o<^ak|IRsslQpJ zzo}1smBZqAb$XMd8Hg60DxZ{%VDkkY@Rmjw7u(?tV@essVyJ%y1Z>?tGlbB zzb0^quZ2T4$AA_f5aztY-_aV#lLZJ7WrFGmkw)RPvyvoPo?8Adsuwon0Qm~bIUQG()y z^5rG#?3^lIKAG3lRdCscZ$2uJ1dkkjAXdTTn*g?8p4QqCu&z zz#gJfQ9$m!o~@*(SWtGy%?WwHw3}9`ahSc?8f9xTF1THc1rHA_Mu8~vRwJP~c*cX4oZBfZ7{b)9&Ra;D{ zbMnqm8^ET6$B#!+OQVeFT3>Mtlf`L>U@AM6*8)tCXl1)?WxwKL@fd29S5nC5E@d?B zl<3Hjsc=5Kf6`d=3{?d0+fkT*B)9fOVE^^FK*~r$I}cl_i3l5p?AI~L8*_r_MRM^x zNq*CP=86&Zp*KSEChLRdP7DjQP)sG?FDZ@;_?&bdPbfkyj!{k{%i^uCpFgGwNEM55 zT6_$08>a&6vlogDTFQmD2#r`SHj2E1gafGEaF)nP6^TBnm+MMFmq5y>tnW}3BCbz)1qS{)}67JY9bCX@#x zIP;T_vGoh&Kje%pNqziRZlx>oBJ?@Iw!9wG#<@f%BxVnxMiHnG0vYOguBIdhj~*a_ 
z3fYY7h6@X(=COEO{TjId-IX4|$Ti9N9>M`gAFh0cmLYwZX@4P!bWmmkJz7~Vk1{KK zjrElT8fP{ZXPefKBEt$eE6lMJfDgm)p!Q@OeznvX52swT*d+Xdh$XmQczLS^J_k=6 z=D8Vn-#jJm-;*d2(f6shbJ$KP#QbA2+0!7$pV)Oy#mFG?|>LBS_7si6LxA4i4p4 zH&TDYOI}sRFmY|BYCH}0@|(;KDkHS{DZWU=y3gbT?A^ZHi`FMAqGcD_gbwYVk{8&L z{9(3%bjRq`JOViDSHjB`32T>!VpI3A@oE7+2H-2BC8)^@ji}q;!2SPSS5+Nl`RSjS zF;R1awyfJ#RFY;lk==c2-DQBUW3g5Gmq-Qd^s{&`M_!8-eE1tmLyhgJ@OgW) zsdG=WkltZ&uBK>s-b$uRAGXd6IGtrB-!t;E=`|Zq!?aamhCVT78DyLiyFx_nL zZ_WWCk?z#Zpw2_3*#xDFxN4ddwgHP1oplq4j>jW!bIwiRLGB*mNr6>Dly;;MxW* zPBS0V^pY-@Jw_Kkfr>fK&(9&`IAKd~o1h0vjbQ{&D_k4+7Mg0~4bambu6hLpocCtu zt6eP-eihiXTtg!`m#^+lN5jIxW{TC;M&A5=YW(-2qTA8OyHSl=Yk-OexR;bLFfclv z$KBac0{E|AzsAHgsH)A)70sXUI_am1;yZ!wWLkztGj{v+2@dE=k*wgo<6r`PMMcFZ z|7F@=>K_ocUmFFa{={1ZK0k}{xy!CKfB|R05xeTz+n>bB3XJWZXL?SKN3GF8tXCw0W%qlxWYRb~r#~==lCd7~96x02<{pRK8z; zRRJiP>G)a-*@3t8|=UDDE^If`tQ??e{psf1bUJEbS}VCi0QxQaE-oiKdo^2r?BGVi1Tqp5Tk84ct$qA6GKfU|4wl)Fqm5(}6TklQ=&CJXkh+Q{F z(h+eP42_( z4&@6qn%mws+r`q?)qO{bc65E$jf4aB1=2~tklc1mPGY%Wk^r=Sk4x*+#v>@yI6{mJ zCoDTF3mD$Ih^3_^{0L)D&qd=#vFz2;YfpjJk`gOjT~d1b;ojb$B7mekhn10%Vs2q! zVP+N*8fwq{T$m20_QwxZZTtLka#Wx(0GS@>JUKXB<~_u{IU~5na;{^4V`OCQ*!f^_ zn+`8Xk8vm!@$b&*01h}kU=FY!WKnz=08xH%(G%dI-rTq!l)9Un(=aeZ248u4J3H5? z(v_B$QiAI>OFvR~=`*7nnV6hj--Y2a3IP0Uh$5hA?Cj?NsU8LCiRq3m2N0j_?Ez*! zY33*+HT6eoEY17HizoW(LamLBwee}`8xg{n0L&>ZH{!Qi+RbCBs)`5XdVq!J=GAjQfbR?R29V0^9zdr5!Pf+w;eg72 z_T_$6Q_AnW!p_bvq1ZrI*Vh7O>+q>O0(R;JY`br=B27B020RiHAqh!YY3YThF^x(4 zY!EUc;vXhGsAvTsr~?e`bD(yb_yHnb=e~^-?tB@1JiNYvf!_^yc69=B7pe*OJ4I!{ zQ3g&AfTla_PSb-i;Gd7k9RCzSZ9bpY(8NTTDR0|6IJVdkn zuyO;%q`gi5pBUGTw6wGu)OF*aUdvJT_Xk;72a5XWpVLHHmX<3mT68$HSO`+jM1eyl z+S)*zG@UHp*M4`0Ap~?lRsGxBkU#$lRB-TT24>ZjXNydJ6=70(`aU$C2eD_kb2a0y zf2w*9XcaIi`b(C9gBr8}zB_LE*I-kiBk~1`42gh(19G)}bh2{+j_v@aD4>9K3U|M8 zTYu%b0*5pJD?%`EPQvLmE`uhQ;}%+NA(2bmm??0eL|Euk+1cJMC@rlS;#B(hQJ4E; z{sdsRViu%HQ2gyK)BQZEpH{Es$Uk;B0d~faxy-4cx3_mwsin15hZJr0EaP|2hm}sS z<0*e5Z+jc~pWQkTs`)Q0?0?^?{#S8^$wySmxAa8G-pa;LhEhUQ22v`l>-WC^H00XD literal 0 HcmV?d00001 diff --git "a/docs/fig/GMM\351\207\217\345\214\226\345\234\272\346\231\257\346\265\201\347\250\213\345\233\276.png" "b/docs/fig/GMM\351\207\217\345\214\226\345\234\272\346\231\257\346\265\201\347\250\213\345\233\276.png" new file mode 100644 index 0000000000000000000000000000000000000000..eea1fa6a4a1e0ea17c318bbcd9e0f8efae70f093 GIT binary patch literal 27396 zcmc$`by!thyDq*|6bq0LlvJcsx_$>-O{jV0g>)Scf06req;H5 z@7~|H_qTuN{Bh28=5^_0%{AAYbBytf@!a=wKhO04AT5UW5dR?r0zs1y7nXxS?sP#Q zx1Zm?4UX&$$SQ)}11oVATL=Wb4f*d@6fHUd1o8qRA^i5EW8&tFi+bc*(#=i(3&p7y ztj94Z&%G|2=Rfl{I_8Eyd*8xH3<-;;e+$>J4^EFBdAO98@KfGym>u(1ymDx#X;m@# zA%PMmuWV{q?tK59rXz7H_tszYv1Y^f#Jy>`n*=PL`7XTprndduk%R(yS)(@kZKCusq$gWg6G;lT}Myw3von)ytz;&M0Q!O5BTu^W|!NlY*U)rHo2Ik(rN>N zc(VJAd2)N*Hf2Ves-g{p>5pa3u)V#7AAC&ygH0hdgrVC4R?j~Ud(grZ;re?YmwD{prTPUu`(0p94uQNOXoJ4tTfJ1O zJ}v3=`^COVGXTQ{WAR{-Efp7WTsiTI8JFH)xxzuTvpyKFvKSQtG4^ys%gxYU`W@V2 z5{F4jItU#KM>oofG1dUXxo*v!(3@?Kwr*>tHSo}Icd^%vv32mb3*FszbWc0Yy;3aZSVBoQ*^cv|cd^HafGES&}>Qowx?4-`4^3 zlc%>

    nosvxTg2@WlV3t2`|h8$~Nbkr#w~6#AE4M)`!9>KG)(YG<~gGlJ4%zmrlG z?>2<*B{OqbZmu|jX8zc*@UiMrUa;}d__(sFYFCNc9Pv%=-DxlBoM9LWqyaln2wWC2 z`lJLr)xX`r{9Psn@&B8-ISXoNXh@f{81)%zNaxgF?ss!1S+x3~b#;S76-rH4GekU} z3Ol;(iuD)U{JE1DC@ICJ-a!skcopw`=qTZ^%wtxfC1KR4dME62@AO8uomnc-OR~mh zrK_8<$@OgCpu^j_;reuMDMK`PC_mrgS5MDJ<^s(UMq1i~g9BSL)h71_Ki-?m18hoZ zLRRAss7sO2(T4i^C@3gqv-M|~F^_p|c;7uP!YsLEA4GT^l=X9Rd_1>)b8gPo@z-l} zNiW*myu3JW`}}s@z2jqP85yO0I}?))kE<~czN^DwRi!a0+`yG#83C#H?>}n7M@BX$ zD+>;C5_(q?6A3RiAOqc?%k0V8Itck>^*x+UES3u6O7Y=$@Je0JYy#UCS!WQDbfH8% zifU>X)3*JiqoW3fhSCY_QX3zXl&W1$wuf_MTwGij)oXTZ*OKMW$3{n!xb5F@iFB?j z6|2{}o|RQspH!F(t71@*9_w3@ykucH*k9?sb?eq6TskY$GOfd|DB7YqS!3h%mF^g# zgT0&-chgekL>Rae5ufwNytX|WaG|qSlIs_w;8>UX>EXsODjL3iiHb1)>AZIWhb4{W zgr=#kZu-zcnk5N@)$ooPJD6Jo1FyB4Aw>mNmN!mRH=`9En)~HZ0+Kgtp%|+s8t4@{ zC{WC^^^DL-ggj{58YuUo;a2wfd>!0 zTjoz0u&}XtZjO0Q6KPK$+*hDGn|N0w>Dj#i7k$I_V2`P5@kw@ZT8-0zp@v3Mb+yBr zH_bvkM*8~iQ~0kkLiy#g&4!1DOrcP0Y-}hL8Z~lpeX%pO!@RjXDsaA@0fvCqnuV6O zpt!h?zYObBR00n2bW_tr22&_4X`r|yfg8JL)u zczAf2nD#buQWO<=){o8|W}S=%Qr|8tENpCSklJjHm%?B$X-z(`fzyCtNlHq}2jmqL zl!x+KFSd2ZFix4Vlhb;fZ{!@48LjWyZ$amj^}J$YvOZYt1>@XxKgo3o zl3Ij5ra5;OShUccoR6ZBLBOeekGI(V>CDMW!_AqkUMpXM6hiPLyK7h5KH}-CoJOmn z1==F0th_+dHfr)6)RT)4_}7{^Pfv(T;m4shVBN@}Q_M--ta&KdxdQHUOm=R5{LGq8 z=y`2b)u-PhdD1hH6b+g;11=0&jth4iG(FDtSHLnv$n&XZYRcvOaKo^CqO>2C#5J|& z>s08s&_UW#^ZTj@D$ObU#zro{w|6gF$nvT+1{F zfvaUofsM=hx+y)IE`{2Q9-QgBJie(pSG9Z6jD7HisbC&^A>O8oJN>s*@fO)PnON!R zr8-lScuRXtEHU>^PL!(ptTxq^l-4Tps;~{Ne3&7(z^t=j(@&? z|GuZQlgs%q*Do_Gi;JE@mf}Z>$MsoTTN`LzWvwbZm1L>m7 zrYaml#KUenk~=Uk0IuBxJPa`NEwqAIez)_3egO|KtZ)vt$s~R^=Y=0%z3-uQb#~?z z6;(2#%ewb|U|=;Kc;VeLkd?)lyx z(@hP9E7TOgVBF;sn-xSyLNoPF2N!cU?g@cZz#1mb+1tPH1-eKSE@5AoU4dHbn;c0uLb>Dg9OZCK6PcF)h(>g(%A7_o|d|3Jrg z1H0e%5u#xXKcRXnY+;<3POme98+|IrkPnR_C3U|$IYDt(d$eMZ>TRp&xYpBUgKGY6 z$4hM|IrR#gPnvHL=<9c48vsl+rZ?^k3)eBp!*sW8{qZBSzMhwEqxF@FwZwDVE8QPI z^o0cl!__Uf6Ztq+Rb$1SYRTLj$y9~CqRliCqJOW&VB%7chJQc0?{a?h+-*5+T@iCG z#06(_N`cE{Fx)daPYjJ0r(rC=vEuq9c<5lb#ocsNo)l$W9>fp3orAcsYjK zJ1u2G8zw&2O82F9^b2KLnaGc^F1asiQY8^N^dvE=%A@&dG)4zCCf#Ex8!f8&-L^8* zrWy|*xkcJ@#are#o8`w=Idd)ZY zf2}4$M%6rKEYO{ox1o-(&X8Dkee0j%XKB|aq!^o8Nn=}?_0t*9heK*v08Lb5hB0qW zsl(N5Os;JE`8hvj^E$-d{Cd|he|bh}#E*v=Mychx-I=X78WR*nn#1U}U*UIC&e*k< z^09khEjjYHq{1lawM^ssP3EsP1tb0vtmo&`2JB`=?g)7OXB?P`a-8>guzYjqqtFpTd)O z@tV(NEE1IT(Xg1d)pBHbziE+2iJ($^-JyAcSxP*Ybd-EmeGsm3-ew*r!WN}qq@_{k zoc!{n<-8_cqZC@bf|t}goy2D1-YDCqY;!$;M%|>-$;h41PlBi`C@qy0ei$Ilpi$M8 zWmM)(`ZI*rdPE27`~c(T%Is1=n&ddOJL_nepN@{2k&b%qY{!zFk>cz!Vt9=<3s#n2 z|2=b>jMTMo~`lbf-=$zG?^C;QMcE6NL15 zQ%SxwiPXi_8s8p&Q;~PAnEUy`D4=-Mr25cQj@MdrRJF<Bv5 zC_7IR){&92H&F4gQ1RP_boZ*Q;U)2?*9f)W+L@^v8y~N+o1l5xVVC)3Qvf%Gu1OQK zYCTKLF3PgQ!=@`zh)2eQpJPt~Gwioy2W>|VFXdMZXWqJN<%8)@N8dagS1wcXDf^E+ z_S@SwzSf1_>rG}E(=x#dE63XF+fnrMdmp`G(!Re3e<~y*{6l|o`v;4yrfV_kTueN_ zvt>VCHGlPlnZwR(CCPh@>qkk_Cz|k*O)>Yz#Tf3=8_AquiB-$l@72YZYiPZ`B6i$K z>+TI3_AUr0V%Chm_%g;!{p67g;#|s-Cy|nSu;^^f&IQqQxk{AOT0;ucDm;yAINB;K zf<;s2FskWVUX2yi5lZ7Et!rwW2Q=7E_3;fY(o1yF+HMu?>`aWM>n3fdTw7N)9JRWU z@SO?+cLU=pzv-Jjzev1;z9R~rW>J!5tLSD;07IZA(@#3*qO?`(v&}&$PrXu-G})(f zZ>dkANmN6Xu9ET;H@{|8mDPvTt>jBcq?~7GRc-P|+>7O!4;BO}s-u(m7}0x4VI(1F z<<**0O_{@7?p4bkd^}yV&Szb09f7~?m~~HCj^X_@!RRS^gBx~lZ%(`uTnllMbMj~X!)ow2OwYW#7cty)gOrb=KG~9d6n@lbUjyYCd;{ktZySX!I z4;ScCg9{8-3x&hh`~;@FM}kQ=h91+C5N4>tq8b;SuUx~y@3oWCva@YA~BmNu|xuSxHzlK!N8j!wwY!SJ9a)`Qc^P_*V*A8z4u zB6EI^byH#IFXEjOW?|nO>S!XbIWjMQIxQuZVu(_4CD3EKyqYBDr_Og`lJK1KtE#K{ zHlAy3E#R&K>~XsGVBtbd;f8)XdF9-gr!0Gca9^d+BO#J3E{xVo=^;G+v zI~bim#xx`QAE_PZ$wj;RJ&rKdtMv0ocom#7O_=J7euU>VL|%Ps!zV^7$>V4|r{hL> 
zdT#i@LT`BFs4mWud#ic(I00d)RXH;~N$H~(3s%T1{cwC|lt909gN7$6&fz>06)8B87%mM?{*acYVe(hl| z*zG^j((-j!;w=lUu3jQfiobAjYw zU4V!ml|4a9Q|-vs1##hW-MV30d3_Xh4EJ!e%n@qnn^SIbh6XN)v{P8#m?kwG#USQf z2eh#h#;@%Un#R?J73;TmqpJOX!8zpGg5Rsem%M-3JKRA*JkD=fF-)O42t8&OGlKrQ z9QBn}jp^BYTAt;qKh()q`j(VFrxa47R3AN9%2LKGqZ;!*zm}c(a@4PsnVF|aU-pB7 z$ytz!)Xf^nh5(DH>oCbd#?OubMu$DqKH@^7-k_&pP&6$Q4O!IWg#7~@OWfqRgLIFD zKJ07xIS;uuofc911%Y}S@$8;N{5}2oUbQIF)=84QllZ777fW$WH*4COI)k0=#Kb4& z>=+np28JQFE=62@kE)JOm*#d-N3HqG*gP7vP>VtiQj55J>fwZUbfF;&Wo#YVlkrQ5 z*3iKStxa%RF|3lr!LizViO>lea)1Dr=&fGmESZ}uki@kc_#*}|MXQQ^N6^zmxDs$ zdn7>-X(LuDMqNCgNu$9I!QS94h7(lHH)`8YOUSutbTz{&-@G&E8hWvp#7%CcH`uPs zL@vTM_t{#(MP9kt^4IrS!nrxGA~OjEeE{I0gLrUMgT(nS0(o!LoGI zG_#a#PRwC`&@}V-;>{Tmpn>c4yCtwN zZ5*8n2?G#l<~+hhmjuRNm)9h@khUJ5cdb4-n{ANiSecL)O0?+B|rkt_W$VbBe8TdE57mC_+)enrifYnN;t zZP071*I6v#gZ2D5m6DVW%h0SE%Ou~noihrCh>{n_c`h+_teVqPl#h3I3=Ah z*^EQ5C>n7e8Xag`5 zI}$B^2!vY{8C>|h<~e`MyQ28^3?{Gk00yr=;6HY4<=*Tw%`H{Lv?v;xAfSZ6g6&Lv zo;yL@Gzf68PnYk}uM&e>2<#Ga*33Bo>Ii`-jmg$W-wH+C zVNfF{6}ofjQK}o4Ez+W31$_l%jxl6pIcLqAtLyrT&)4_yFfziw$cFp|M21<}*(bGT z8>Tp+hJ`*y%~^lC3TAd5No8=m&cmZWZ!kAn|38HgB0x4~8m1iUQPhO*K9N3g$Kk1rmtg zUup?BY2u{8@HRonUnl6Lz!YD9MiWDtk!NwO^A(H1Z&hm$F+U1;hlcnpxzW{#ywH=g z_5gz}RI|Ys`z8Zl)kM*BH6Z@Cj z?^NvRNvD@?*W67CN%4dbKLC8+{(TD8Xp@IWcJVF9`Vw)K?XRS!as0;3uy+Fne+!|X zDFDxVnlht33WnTmTIib9@Zz`62);`pzFqL2=uA-ykVI2ubgb`tihXe_N&C#>uo9?R z>yScXww;#ZyZylO`G6LJ8`6BVtNa{jJiz`DV>XqOa7P&g_rTq%gd6%-SN3xr?-7Dv zB05b^`Vrf&F-%O-v&}{nj8^dgW(x%07_oZM(DQV~UvP{3y<}I}0(&rTYj_vgoN*9E zGx*54MT`us{H@6jNSP6MfldYoJWc=2Zl(gT5oDPCe^*Dg!c*{vKD;bSZbb^Wo9*0T=ZgcmVyFB(V`%mVChNCq@>LHf|(^(D1jSZ?PR1vho7tIz>(CF*g^w}L9ui8 zUe^SMNa_v|;X;G<>Nb;atJgzA24!u3=K}D4l#kd*JPQzC_*(WNteFY>Hp#-&J%}q0 zuwmUZw#1s7%}Oo-c58b>+2D_(%Ox`ksf2VBPZMhB%h%1^2Pb+e>mCX}WbbpS{PXd0 z2GH>msBJ@~^&Gd`;^#}6JUZ^z)b&nOt~|lP!K88@>}&u$*k+2LU%Zer{X)pAMnS53 z#c91gRAf|;wK@Vn^r*1YSzMKfavcdzzaV0Y1Ll)|R@${*zt0@AVf#L%RUDhZ?oHlobSa{V++lC0)kc=RPYa;V1`EZ{U z>MPBrMCP}`WPMtz7S6w{^rE#a;_u@{@Y5GpjNqKerS8Lovm2w1vP=AweTp=kmlF-% zPJg*uE^R<~UtGdu8_cU|zl#n8#BF?ZQ{sT$WzH`YefKFQ^M+dnm%*tXQ?R ziN{&dd8*q=-8tBy*i(z{?II6=^!by+#EBm+JC{8Ve?f!_(&3P`Fdq3 zcdv92(KRutG46c?6iZUMqpwnk0PYq!WoNyQrvYzD{ZIt3`K|L)Y+M&d_ORAzqcjNn zB4co4=pRT#WS^fA&?}G%DU>_|3(Uil&@oc1d4dK*p+kHGfF&#^!NJ5dHlw;YPESH4 zZ8$W8$JJm|0CZO17#LM^^(j2ZW#WeQQrLA5KyDUZwnl1ii83tf$9ka&j0P$TTlmd= z`Nb5{t17HZ%sMsGL1dV}j+^2nfXthYv!fK%UTs#E>W`5L z`3F6YEXe-;++;_@@7H=yV&-9G!I2u-`%FN|X8HD@hGq$P{EY8~4au??AN5!~9L&%& z`o-H?5AAl35D?>{g%069AfR4RA~&0`+}F5(i{@s(wzrV~P4twL>!?WR_-G)Edg?I` zM^5qNtIsXEu$P|?4rU#xnOBa!R37LSj7!4cW_*}C&V*6~uOPqS1MDJiyG zI7LZM{BwLN(g?Kj)leFzaZ3w5EVR#(@@u#OClpKcQFX`X-nV^zAEonO05 z^|qTDf7hTXrdEUa1k{#8BHk^+>|zi8N?bd}%Z6vkn0o%f^Ybs#0A|^*+hLqNxDS~b zR>SqJ2rCjvY_;R5rr}OOP)y+Sb#bW$RwTGham`km&9t0CAk9%n(SJ|`$jO=e90O%$)=$PrW$ZzvYSj*NN-0 zvvb3++pdpX&ma&-D&RFE`%i=gltIc9{dvnG^4jw{yp;?PH|cLB+c*1RC5_P+>WhnkK|stB5ot7*e$ycDx@qcAW#r*G}y;Q!G-x5*Ryo0N=y zpB9vw;Ar|v_0|8F5%!nw{#S?(95>|fUej!SR+oB_YvE;@m9ggyqaxjK1r>0-w*WkP z`VZu*r3gYqgv8A6|3EH9B;>+c)Watr*kat7HBbB-iy^<@23Smv;yXA$5uYf}o(tlc z-27KNKKQ!bPRIl3@SNGg9OWa6+q)5xK*@CXLofK(<_Wii%UVVJ?JhTf+*Eh1h`DE) zulMg*?hTk|ygQjj>dcZwfo}*vc2c)z2Z;1HkbvILy%KFahvD5z8tj+4Nig3XU!(l! 
zHdrtl?<>7_5bNlngy4oD8l@PXnS0b_*b}YaPVjTxX)<{Kt&a%13S2Rd%~a#=q&D^- zX35?KUlr!Etb`+Av#qE6Erz22k6j^8qeM8X@>|p8T1^;$(%`L16TXW2eN@t7l_haZ-3I zSE}o_=bB3FUs|MUFY&rL7I;|Jk4rm1SNy{REaNPC&o&Mc7+F0uoF!u1$I4epuHko$ zzvQr|v&j@af=qWuZ6+?(tb~R5BArGBwI~6&El9sv-q@G`xjpvu76$w9YOFgV%Zhzn zg;^zMM9#KB@TZlX=hlmsD_b0z4@Y3%a6^))`UEsn%R6U&{=7KZadCXjark8Lz}?m+ z+0V~<=eJY%C!6FFKXK2US9SC4hi#i+_12qMq9q0VJA)Sny;Tfy9qr!EFQZWNy6ZS) z3caS`SL394eHxIrk=3bch~<`9(zvnmmH9dI%kK`CiD&W)VZ$l7E+BvG{sfEw2|gDS zo?$CHX8t7AJM5jKN8(N)3gPnrzjnJ8U97N|7>$9#e*mViaD3ea+eystXuNWgF|m-r zY5^}>j$y8Yr(r)H8N4O^tVe;8zyGr-c=}SWLmKA30W2+==$E~5&1y|JbSGX;T z+Sz`)-0xIWjMH_mx9jbLfM-M_`L+jdWv|mQm!B2(!f%%=2z!C^?pD~V9Ds8zKYX$L zAf{wOEX-w5Tubqpg_2E!oGNwWcSuZx=iuNiQqrH{fl>-0oOaL*AK#-}>QDFZ*8@EB z7FzO9`K6x^^Hv*OIYK*^M_46#O@HffS5PYeAeN5uN^;?|R8%>7!^ES-qrK4T>!WUb zljnsRH%<>(bhCLcT8)bej9XNl+i(aE{kuDV5T> zY;?I&{`<$rxRjh8QY}ZyLr=2PK=Dx{@`Nf1^KO6n#+EM4gu0FL=1!)CTT#&ws3MQ0 z#vEUVd{iQeDdfoN#iS!C)QEox{N(NB&FiJ`qt$O^KZa zbKZDY=WI7d(_j3>Lh)#ddUM!AC;HQJe1)J*tM@Lxn=7Qt@32_8I>28(!1>qkzi`k6 zTxyb3qTBH6GmIXV#2F-ZT0Y$_44{@6tBMSX#U!56q@ka93eAmve*G%*d~$(KE#(<1P4`b?Z=P-7yX z9DQBMB4yT^*}_3j!U#queC|MP*y&83?oHJ*%!~Ssy@^5jNlvN1!&uQ+(X1rgX1Bc@ z<+m)(m1VCHIy(|07ksl5#*!28mpnkNQ`er}n$Kva8uNmb^lKipK(sTMHhca{#nVTd z{@Kf!_4ZU9O2V0rYuVn{C0f$5f?I84s}`zpc*)o~F-+Zz95*d(nKFU2^la@7IJt+R z-%n0r!mr+?L@!ZI61K>hJY=u0bp2fND>g=0SpSEYN)9tRN@ixw&&6*Z>fcm9F`vj# z%J{EGe%VSJE_&~AbbSTp^-Pn)%`6zdojLHw^_9y>AJB? z5{#tnzMIR?p-ETMgV+;7=UE5ihOudFlB%M2?$@C=7jcxM7j*AcCL=x9HuSJD%u+a< zC7Rn$s>E$y6Y;%YHP{XeO@2ih?j|T5ZDH&IJNXaVgNf-snZ;o?#ozgI{*Qi>z*FJz z^tp#ke>zZ>5qj8?p$U31C_#NaetvBFQj=J*k)?+U2CGvWHUVrd3FT#M`q#gOD2pT$ z_uZz(>}R+Fvbq&(+XWsBMu73blZwSJr9S#F;T8FS)PSeFeHV^WJBqqk=_h=>6YEs2 zhe6mSpgBOQpCM0Sd6ido46`$U4+&p9wiC9o9bw?nBz<6V{GL|_kj4sS;rMH;q51&5 zzi7tV^`|W&OFl1m<@w+s1(PTfTDhO=cScAJycSFa!4`mx9S_OYjwC!kgK+2%=-vH2eA* z)yb5*f)%Y*vu2L%d0e$T13{=l<0+Tzb*@5AfIJ1t>XyJWPU?Q8`_U@01v`{X_``*KU*8@KO~iV6B&%USIPl7d5-dG6Pu zl2N|~d^PggIZ1U*T~MyPII5%?zm@Md57iWE3#r0ykJV4@^UQ4s7#+)C8MbZqpY=Cg zt}v>`WRyY`6ra;D4|y1loa%I?A$ix-#37g5&cHB-ekNgl*}E@(;#CW8#>J@_VX@Dd zV70bNC5r?&rBqFlU8;{)*l1f(tIxDzuDs5SS1 zTQ+E1VEXJ*wJH@v-4OPgmi7a}rh*B?8{y%Je?088A>&VH&6?$=2XOQ zJsLrkw6By9Zj(!_uC?5n@Rcrg-JRl7BBc=JCQ11gj^n!9+(EGtb+lAO9{zgPB-j`5 zwX}wtk^7}592%VV-;0)tO157MN?u=o?n0%oMtE;C~1rNsyj+H?T(N1hjyiC~{ubX6fDha*;sN3sPo^ z!6!AW_7-X4-xrDs?-siVsicIN$Hwq6zuwHxbJcH)Pt10j`$Jt-BC5oJ@io!tAn9oi zi}x9W)2kxgEz>L92iVjk&r3OSMzaYq1Em?%IH;}1z7L6L)C8QlUT<$Y^{S11=@sX$ zHLUQ{UnR>Ds8Ro4^WU6ahYknQY5zV$+hCX9X_X)xa(a5o`&*!1(M^ZNV^6$fo~97j zdb*;VZrNoOqbM_Pk6|=Mrz22l-ZsnEM?|5ES9GOpC$MXFS zbMM};3;j)*>Rio^l+`p!tt<_{jE*0l>j6GHbv!$YK~S4?YieV9qu^ODJe8eTbIGpu zQ?W%=mFuol^_<29vb@5a;pE;(+v1b?Fk_zh8rv#S3{w`->I4#FTB?EFKG_rgR21=e1k!n;spk2wP<)7dNpFr_g!OV7(@zNVFd z93;sNUz^y+gbk67iD9+)cAJ-cs?P zhN2tUoqtEJ|4~~sAVqgy82YcA^`Ab0iV0Bh0IF2}>+k+;Y0`icZzbmc510D41nVCM zGaDNlYih=g&8!EE>8Px@QFtFC$=Vx`*mt}$m*PCl^qwFz$MEC350riE7BYbo(&)Z^ zeKeO$?4FXcDoLiN(P`y>92!n(4g-+h?2Qc)!8Z)X~zsmFG{NJ`K%w zt$Bzv1W!M4un8)E(zI=4_|=ULl^(lm!j!k881OMWo4r>KFSETZb5-G;!AhOU%|7D} zF&(2o?8$cBjs*ud>q^#nd*)MCwE39mXxxbS7$q&OlFmAzP-LYS*ka~aR@m5HZU)zT z6|S#m+i75tk;;8zKT#Fz3?qL_I%qSqGvh@rP|#Bo{r%`y_zO&bT+F58aOi${A^7^P z31)8+xwCS-*3aV<5yl5tVP^gXTv$ab1j7BGE{*0CME4a|O;?rxj`9BmB7n#e7!z5D zC&MFEUGe>H+Ft=7pheQmBPjlU&uB{|&MHS@z2BvL$85c^yVO3rFO(2i792IEFRX4W zZ>K&q&6^?q&5x`-tmp!G!c?!Y)|Ao_WdZ_%sFM2{{0}(~vR=GileurBa#4<& z$7$RNuMA&U!^;@Ah)g9{8`DgfvLiFEV=^}Ns>fWw!^SSxAhtU`=>CRLE)F!?KErYo zjXUZ;l6V|kU0w0<@hceGf~gCNMf9hlt1F67nuRv(E>Tc+#b3Li?8l})GBrQ|UTyAOe=gPJ+M 
z{AkzJJ`g{H?0x{z3o5u;9WqQY-_}S@t1b2~B#0adEwRr3Wf5W}X$Lr>E!Ut|LmC6dU>)z5l8n1HxH8a@nr($<583C^rfe z7Zv?ktWgI_P(Xs+)rh>zOrx9H#&^fP#p&tk5LExr(794Q-yv|_v%H*S=MjSw-9HN% zM|47h+Q*L{<>gsiPj?Rw4?%&+*GG8pu`xRDn%v#i`jS9%x_;wE&XW|+<}0EWTz$< z*4W4o?tgi9z^7@$2!;^Wwye7ZqH)|F9v+sKJD@3`twBLS-@bhV8TYe|ZZ0!0dunRx z5SC$0O-+xR6OSgV1;GnuEX=%;k`j;|n>OS^y~uV>(~Sx>^4X?C3~{yd82#*Z%Q;ok zQY*3sS}*1V7^KDaz>M^CFq#fqlM3qU>awy6<2qZDl}jC=mJJ+QuCAA$_aocMy23y$ zSx+oeE0a}5fwD_5|8GHQABi@*d`<&FebgBs3SKMDx-VtD0h}xG>>R##9&#@58{Ec9 zPD6v}IdaRZFDxk$aM)G{bwY0FJRBU=+1ZAG2AU@~g0TP_xWlu@_b|bLkZ@Q`>xJ^c zUsaw+0|S?8H+m#S-asr5Gl6RA>L<>owd=yPnf(3MUv;cYt@S7gKV*!&>yHu=_k@a6 z2*_%1=3#obifdq;O zpc|od)U)Z7p%_I*$1lcHM2U21s?}|<#9an8iQA9vHIZS^7GXtvOpbc@K>RsNi=Bl> zKKYqV0RJvLuZC{od|_W46O(DhT$(;9lEcsoNVNYYq%s^>S@=E4I+5XA#a<`Br(d3; zwtq{e^08EI@p+HOI?Jd6s?-IVr2+j5ojZ`tWar;`(zCI~S~ePBM76Y?WUgoTX!HCDfT1XgY>4 zr)iSP&viKuD+$xJ!WsB%`pQz{n@z^Y~nFsQ( zmgiMg?sb$c0aD>ED3-(j`XxDa&4w^^15@6rrv~+r$a}o50LI8Tq!O3y8?cC=q)B8`RVBVzVUZx*%tgd8RxJM(lg zhifC0NW-HH2(s%d25J)=;2#712IWR;!&KE-fu=b+n%$>(3o~ z{kZSi5Kb&4@7eLL6>D2$g#oR@+F1TlYU%iXKDo%rpDBw2%aD4w9|Qv|$uf!LqT_|T z@`A$q9*-=hQ*@*4$)M}&nE`DAr+jQ(4{0?7hSK*BM>%d>ek||K(`IE+7^MCB94aMk zR8iL~B*4Xc`fg+ORd%O*M}q z1%Zfr=pap~dZKL-q&%&~PDPo`{M<+)=ol>rXp5m8zMD0I&sGQXD^7CkMHV-ETG;gF zA>!llizL^F{~-SD-^W{12zXIJ$#I(Z<$xif}hVhTnv! zYfu5{Qamk&j`M@t zr0Wb?9!@TKd3ouBL1OXx*>3~_IN?n(;M*1)Er2rwn21sHwlorBSVT++%)R@)(P594N4zzHM9N@zdN zYQN?^=way+j0DH2oHV!`KmjBG7S$zqJ8pkQ*dh!UNP$Tx#;^9omD{fMl{8(eayd-bST}I0Jwj5DW~t|oeZcbq z)A;j@++<{#4O6%x%b>*oU`a|yfDu*&?M6H>69eM{Ho(YxdU~d7Z9yug7HU*mM~CJ3 zcO3wkT75D4dwcC+0?cG&ZMvnd2fZ9Kj-lIrIqX1|uW2|xKQGnm*ar2{-fk!GVvT}L zDpFFtju0X)o8`$$Xw=uiqOV_&o4&q2XCJG>lMbW)WImVUO{C3&c_S$!<6t_9IGVfK zcBng!p2h_$)W7Jo<|&iLub)LuJy9xD26qYxF)CqUX6{SoyZA|9$w{1-mse^@zqqop zvbJWF;s|OkL1TO|NVer9y3)hI1XQNOHSiY)*0s>;`|8yz5|_==?Rp6gBv0PIPKH07 zSDOcUi3|pyd>sJ6y}dn98x2P4xNM#$M{Pqq!RPvGWaA|m9Jl>Kom zl&q|FKw6+xEJXgKY)DnbcSdX-lwZTyn#sWkbaZs&q@^v$YW*^9GG(=62$F*ID=oG6Eyn)075M}P*bE-znhJ2deC_z32- zprD{?O7UhyQh0be9X@8*`MGn)-Ivmgb<0zU7ayuhEX0fUFww59X5VtB%(JVV`aX(3 zRIbMj`-?ufTVBpy97cHIX$FDZt|+H_NSSvIzEq$hoRSD-pQtn zI^{?f0ol%3-R6kbhhpZ{c~pZxw<^J_C4g@#1c3lR5m@;CTBdb8oUSj=mY0{o+z?Y# zjcA&&bRCIVpB+VJvqH9u8XK=cp8-gY;2Z&Xng_t|8_S(-?LRS)A-IsTpVF zZo5`dR<*!tH_+3>Z5t5Wy4l;pOgfacpWcZk^(jyWewwr=`mD$^!~WRTg(uF>s;ag# zJJoeINm~DRHcLXMm;W#YnhAMCT^B{<)_W}MQrj}UAA%>7ZS3pgAm4%{)cf^_q<$Q! z@W#biVDkaQQQM@jg!pRr&2Ys;B1@7q>g~^qec2`!Gq6$C-bIV}Yy;-!D$y#@1}f}? 
z=(GVgr@K1IYa7Xn%Nm+b3#uDu3o`z^oJ2c>&;y!bpfYI88uy-^g9BTzY@(1%BAiTW zOdzGdf4?W43QG_ao}v?RFYh^KZ|cPxDk&-P5>tK|_(tDdR*ZlZ?~Iq~fp=PXy-+`) zR}lEah&)ZnCBk){!HQWP%O6Jsri%+5+xhtVf&$%ZaWb3hOD92b=D5rY1qJdi#_`eb zDPpvLh*|^A3VEZzEOjvBFc;jeOFjY1ye}M1CL+73C^dZPF3?>TPBy5Er7^Za+%zZ9 zJ%srVr~S=~7x%r25kHMKV2vo=Y=2$>0{IXH?nU3d$Z26CZ;k0^pz=wo%0OkO@uE0C zzv7FKEY2{icZ)VBA^YPDz_$ttx>#%E7N z!P%47N@(*KDgJm1i0$&J8Ua^`ZQV@yPXFuoVq&PRuOGMCIXD2EDRH-!utxn{llMF*@vS2GuN~aHDNyxU;#qF zlEul^DZ0zXT<*Pna(;0k`BPg34%UQK-u(Q(Ue)sxqm!Al@uvIux`mIEHW%gflk zpUA>Zj~-XwR9tM)myvHug^o-x$bCl~nv;?73d?_7q3^g(o31;3WeiqK!eyhsUm9&d zj^$rfRds|I1Gjybm0kDH&CD96`1XlS?R1Mbs!-auEA@CX#BbQK%DjsGWq2Vx5e*f>v13J z&1d7gZiLL(+i4&WLh~9Yrrt~l5yD$aZQyN zILhd`Snd$C)3ro~b?7A}qT5xqlS@|Ak3;`7ML<>b;*Q-!N0hPvZSlZF<+HWGesk@X zwu1|7ZzM~DW-rx7@W)bnyRzQ&n@gXY(dG=ih4?x-B(M!`6wh62TBp5WSDt?rWJy1| z*G&AHxDlHu5SAOX;4eQcv?<;*S{r~R3j6Xm#ew0KEC4IQSfydR0D+T z=ug_UF>#v(n~81Sk1Z3G710Cw>d~q<#=zv~BxdeK=UGp7tGRttp$N@%|6Ehk#t9Zd zt>4%aK4Drrp}*&>RAUMjOwVe5PkmpeFY#-F27$Uc{QXGC?=kT|7cAUbeO1OoNg2=N zcGWUZokBET=vD6$(jSLK_b0?>_{#F&v0uevgJSTGSR4J*yZt|ir+n0i{D*4ijMV>2 zOqMeFwylcs%iZhzLd${ML}mVCB2Oyni&0^Ietv|cUW-U8{NgB?k5Wc42_wBpPO>la z$qvJA4i}bs#CwSLPq|z8P)&@Zm$-3eYhgOLDF{*F^AU=YWrp<5b&QQnE{a zxqBap($0{I>M+eW>w8TZ#b7We*UGGq|8wF|!}A0ZrPxOzO1Y$@q-JZSpQG{FK>6hu zJlt~YN755YLEJEaNANtAc{UnQ#=-VI-~z3vywL z(5qAw;8$QOK8#8cQ$9;gO)m)~_EV?cCWt(|WWCoKnWg=LySTXcU8h;d*pS=mpp{sn zS|&zMo%eV#jcP|_zAXpIzAE#&dh?YKNdf~B>AJblculxqd)wX5Fwj9axWcuJA( zo|T+tp$vDJBCl9Tf3aC4H{F_RGRa9NIw3BFAr>W^6aHEIs%nGOO86$PwveMGN(~vwc{8tF}gkSEb z{>>57pIq$+6;VuICv1YYr{(vO|Aol~$E(l!?7%9dyBi)PZn_0WNw(lpIMd={x)%oV z8QgR=H8tUiU}oZ*&ml{9h;|M%gj$UH^)_X5U4NYaMDa&x<2{=`4b+nlcVy}A zYKcZEM=kU!nA~bW-Z*jB<*1PP+MP4oE$rup-aaJHg>w#d_D;=&!Mw}3oM0hU(NE$$ z8E~7^dFt(`)bjE&pWA_h2UA=Zn}>#mB3Rm5Y!tib_p-z`#~MxOTtu@JL9#kCFPt%g z>-5^7xhjMH<7&a(X~w*gMId(qlD=Kad5Zey`fCU^C%(otk0{b>NU5T%+KN+7_r{~ z0SMoMG~%iJDJ>XL6y`DOaPH&g?DG=UUa`ehhizO!t@z^P^AO&Y>A{se7Rpi>9A@A8 zX2N`A=R8-y&J&-o!V#joL&U3ON%!&AVGeP<`g$5fU3L9;(K~*<=AJrDTwHwYTt6}_ zmu8(d`f?4LQL9be!&{biWN)(K_U|d@>1L^T%`aNG-ru-oupZxa=}IqIaq; zIPeV!Z*>9Bc>*P77$q{W?x&=2C-7@mVuv>80az~{4NR$W+62!>d zEM4{7?$!H*GlC=qB{Oy5=ySoHXjyC>BiP-&wL3jcBB3rO-v}lDA5u>GLVtZV*Gtfmh8rVpN znKAo>(H;a7QmC?Dm5}lwK)eR`9|5KeLFEwnG;eS3_zb5S{gt9>{YyvNMBc=3n6k<@ zC!be0N-z$2{C<~lBA9O^lb};=-o9%8&87((c~JO{g*M)*O7Aq2eK|Z|Mk7L?fS^H7 zyH7*{kGc~2#9NUiX@XVt4+BX%j-9C z{{APS7$I@*{ygu(S?PF2JWBa8^4>tjyDVJ(p|?zOh3Jb>$_++^gy5#wnq>ODv=_~| zOEXUSz3uF1P)$WGN~S3C?(;$rD5Ag&Y=4mJ_2RKq5 zKyt!kS^W-+=VE8Rsm9lal;mK#t zQ(e-IMaTP1lU${_=xAqJzkK-;v3;dQKOy5(usne2j_G0|Sxc+SJuEymYJhM5IgNb_ zY!`Ceqo`2M^AZ=LZr^8$E1EfNKRe%v=J<@GNMy^l<}QGMn-DSvcaH~%qf&%vZibng zo0|`#PPB^AwKWTC#bz&+qobmlnz##}2IUt$Y@9i)w#xezX#gG8?ZXMRjK4EKLF@Hp zZZxXE1Vw*L&iX0?KeGaVh6|b+%uDZhC%xXN1AyHzj39k~0*jWO8!yr~+pU_h)k}nJwk`E@nfX`4c;;I&>I7 z_Zp|*3D#cr*a%EDjeknTVjPV}85nFiSmKCY{8q@!Iim&I1W&pgI8Wk~;@JT1qKjmBWws^ZrP z#;=H-+3OQ&MI=J9t?=h-->;O3d{^!DDWhtV2$?8r@fz(JBn=ov!JnQNY$<^WiKk^$ zTNkI}_vYN2ukRmN-*eM5JMYv0{$B5OohWvBj>WKY5@8Uc;RtKvpJzWV`8gTHvh%2! 
zgnA@+AHJ-0@p2T#e(bz-f(GgH^B_q`pK9Y6SPlr`qX=19k9--}OPXqqlCvRLQos(mc~ zRARAC?BmwQiC!DD>kM^YN~7wJ(_Wfo4ntij&OfU=(-5A7)tnO-N5U;`^yRP2q^|zQLBKMGaE9at%N8KCJ%SQ%Hy=RXw1go9Thw+8BdH z6(Ky_2RC9RhQ9&4%;pV*V1}sW(IT2$|`I{Q%f{l@s3l(?K5w4E^JAH4O{SC zmzy_tpNMkHRm+9WLop4(bleLpE{^78Vo(kX*+_H_QZ*49=4&Pi%P)HyE13Y62&5V{NO zFqn4UB>V)(m2kTGWF$^n#QSL53KR?iq2^*0%m55^NF?pPjrj{0yuf8uQduU^1 z^E&a3fVWy%g>+{C1cTvi1f!%y<>oXqw<5t76>?d{G1GS4&m*pQ4 zrR2NiRVtse^e?ZfiHTuU>Dn<#-R@aAKUX*Q-<+DCkKxVMKq*EPn-c}AGI^C&L`u!# zlai7S(sTTBl)RGY>gX6dRt~lCNgMHKWAP9?iNeu%8F>~jOrZFLm6l%B&26Pf=8Z7R zwTH|AkGhJnv*SG;aDBcylJg?nr|PXpvH1g|%Y9Sh^F`_>J-ky-=BZ?00x=(0>ayCX z$9OW&*mCRhP%>9bM*rs9x8fZ1K)OT5Kyy;9aj&g?XTl#j;^h9dq2MS^%y#mQt*vdd zyB-mvSH|@+(TV5TU8W`}*@)H8SjS~FZcI#%2s#ud$%o|RXXq;Zc$aAUz#RQE zTl3+=g#Ov&{<@qaRQUoSQxe!J9kRbz*;`&5U!Id5z0Yr=RS50UE*VIX5?#kpZISrn?7!qE%B(43P{Gj&k!>gVl8)#5Za8JSm=olg4oL4g#~ z1oq~vsLUydaqHvRoeeJC`8r!;O_rHopyXLFCD7An4lL<)J^%cd_;_`dJZS2v-aK%y zad&G))4Nqm!B^Da_qyJ`(9P9#V5r*FX`t6(91;@RIXC*!Ze)DjSo3hxZeLYt=_LE( z=eEvbFmBKQ%l&dzf97d}znfU@-?{D*+-L5I?S+u*Ue`e_7@V~FAE^P=R{7Gr%Vx3R zWWI6d{R{&`X-{kg=x=3l*ESt?&Sf5i@d|j+w&7nA%)Vi;^i{M}LG5W(&-Q#dzos4> zE!ucb%0-~qnfyJkNLNZOj2PRTtR2(}Fcu`{m5fLC?O)2g1{cKl&;lDrSW$j*wWooe z9zPwA`TdG&ViNVj-b9HAP>rlX=%5#x*`ru zgB$0%mmzCkQEqOKT5wrOiBiOC`~=|3*;~1l_p+IePVkASlF8(Rvg`T#{*`nhcu)Bi zJ|C8rRTN?C!bmX(fyMJ{Z549mH;UwJv~v2~$Bj9Qt7OXBOs_dL7+pi2jlmy!!Hu^; zF{Pd$L>tfc6nZ~3*Y%QeV&=%$wJJwwPm5AfoHyJb5tkw5DuJ1J&$v&XSv-+kVlGR2Q%Kb|d9K(9%-DRR$`gV#8T9&!JZyYPwHyE zmr423ga7buesFYHu#_-@{P&XeT9g95WTbmi)C*Z*-6FMd%ur{g*Ca)+#<`?%jK@bY zf{BjSP7loc#>900lr{M{OAv8-7#-F6bT-nZW~h(3Wgv&kHswwe z9cJf}aCff$tMTzFNfJziPf3tD=>5$KF(BIAzfXh?4zn_dQ6V8Apz6zTzd;3>#lPTw z6&&2ItrHDRstmiuua0TEGvSK-k!pkO%gzgJ7qf4DXlaZyYubJ`iP@?-;o{;hv_Z4C z%rhq;!wiZdygow6vw>%LZr^=fSa`XrI9@PI;d-ZYMA@)=!laq3qP3pFg(!$jf3bdX zM`Ooj0kaf7Xc_t|K7&9?OM!t&tJ0a+5;uJJih!}?2;jh_ThE@g0;mGnex*wK+=gWxwii%$_amV_6k>-j*=~X_pVtE7<6g<90f;izS(je|&YG`ZQ<Q7eI(~}QF97==xrK=<|%5~Mw6Q{D>oNp&Q4_LIX1s~tc76Am{^GXlS-4XmVDO* z@Lx%k6)zQL2wH5T7_Hsj*4R_-cc^p&VEt|~MacC2H_rbb{}?F0O9^h7S3fx`7PYp; zo~9N|#nl`z4Xwty@S*3zwe=bO)Wii z{NupleKN;sDnk*Yf$9lI_wR?*m$GQmd>ARqG3rk9TX<15hC{G95neE%!A#y>BtMp3 z{=1JiZD;-G^~YNuh7(pW4`6#i7A%)~?N9nJrx)q6_E;TuUV7xQtus}&Q7b^>r+LH} zV1PT_uIknENoU(*)CX6WKZr@edo4VV`x*A60F%b(X?0nNN$PpvvAjwhBG3G6l_TEe z)Vgl%!>jDP!AA?p@YA)zE_lKiw= zM3t!lI?``4YHA(zm2KD(z#sON1(2MSlcOB*Bz_W_KA=EJS_{ebG&eH(G<~_Q!zR^g z`LVECqI#wEs4(8+yxFIvi9ZV@E;F&BHS#5*!3gN=3Q zSp#JUCFZ`HR0-gMW7QGG;7%IE1GJj)cJ`{T=9A-Jbv8Oj8UM*5Vp4he zWUZvUl$SSZnyajV_>|!w2w3)Hdu-0vyJgbp*beXfP6)xMg~Edr%?yH6X7C{%5h-h@9ojo#Zz5st~1vf0IZGsYcmE3DRpfM;FWskwc{aarJk`uAcc$Y~4A1I$znSU;_ z85~HJ(~-ETsi^|!1B$!zzh;;#Eg><)Ab)+O)6VIRgd+7DW*$Ig%iitGJUq{a^+<5i z^O)ZS`!%fCoKR ziJU$5!!rv1^c4(>s3y2oTuDl#0@@?q8pQ?@NopX~JU~YiNVnczlhN?DLp z=Xx+MCK%aVfbQR|U-5)Co((jD_dSp(Q2W5XhZL3b6IxQ9=yYS;tgPCSFeY7^wKP&7u;|>=Mx@8mt1j>NV=*@`Pr4* zFt5Wz%M*h?$pwNAT%;6hNylH@d0(P=&oc-0nw*d;guVv0;12401H8QczzG3Mg%i1b zx0u+UFAtULkKxp39s&|GgGt|p;Z z&=Mt)l-T-o`mi$}z%c@)GT%JD6uJc`Cnp0LY{O-!pt3rD&oJu>E>tP2*jO;8^Ok1? z6=os~p1ve%n3a8(klI)N2X%P_q6O9hU6}3GCH85>{*k8)(AhGtfe6x_0Y4}!Ca!i) zqd5Hdg^yb~pM_({-MThMu)xUbNP7y!E;xtN|Gy=d|KCq+dg!ezb@gaFhEu<(B&T*e JU)JEoe*umleFy*m literal 0 HcmV?d00001 diff --git a/examples/transformer/grouped_matmul/CMakeLists.txt b/examples/transformer/grouped_matmul/CMakeLists.txt new file mode 100644 index 00000000..5afd23c1 --- /dev/null +++ b/examples/transformer/grouped_matmul/CMakeLists.txt @@ -0,0 +1,57 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. 
+# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+# ======================================================================================================================
+
+add_executable(grouped_matmul_v2
+    test_grouped_matmul_v2.cpp
+    grouped_matmul_utils.cpp
+)
+add_execute_example(
+    TARGET_NAME grouped_matmul_v2
+    SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/run_grouped_matmul_case.sh
+    TEST_CASE grouped_matmul_v2
+    ACLNN_FUNC "aclnnGroupedMatmulV2"
+)
+
+add_executable(grouped_matmul_v3
+    test_grouped_matmul_v3.cpp
+    grouped_matmul_utils.cpp
+)
+add_execute_example(
+    TARGET_NAME grouped_matmul_v3
+    SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/run_grouped_matmul_case.sh
+    TEST_CASE grouped_matmul_v3
+    ACLNN_FUNC "aclnnGroupedMatmulV3"
+)
+
+add_executable(grouped_matmul_v4
+    test_grouped_matmul_v4.cpp
+    grouped_matmul_utils.cpp
+)
+add_execute_example(
+    TARGET_NAME grouped_matmul_v4
+    SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/run_grouped_matmul_case.sh
+    TEST_CASE grouped_matmul_v4
+    ACLNN_FUNC "aclnnGroupedMatmulV4"
+)
+
+install(TARGETS grouped_matmul_v2 grouped_matmul_v3 grouped_matmul_v4
+    LIBRARY DESTINATION lib
+    ARCHIVE DESTINATION lib
+    RUNTIME DESTINATION bin
+    OPTIONAL
+)
+target_link_libraries(grouped_matmul_v2 PRIVATE
+    -lc_sec
+)
+target_link_libraries(grouped_matmul_v3 PRIVATE
+    -lc_sec
+)
+target_link_libraries(grouped_matmul_v4 PRIVATE
+    -lc_sec
+)
\ No newline at end of file
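Note: add_execute_example is a repository-local CMake helper rather than a CMake built-in; as used above it registers run_grouped_matmul_case.sh as the launcher for each sample, passing along the test-case name and the aclnn entry point. Assuming a conventional out-of-tree CMake build from the repository root (the project's actual top-level configure options are not shown in this patch), building the three samples would look roughly like:

    cmake -S . -B build
    cmake --build build --target grouped_matmul_v2 grouped_matmul_v3 grouped_matmul_v4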
diff --git a/examples/transformer/grouped_matmul/grouped_matmul_generate_data.py b/examples/transformer/grouped_matmul/grouped_matmul_generate_data.py
new file mode 100644
index 00000000..36647499
--- /dev/null
+++ b/examples/transformer/grouped_matmul/grouped_matmul_generate_data.py
@@ -0,0 +1,38 @@
+#!/usr/bin/env python3
+# coding: utf-8
+# Copyright (c) 2024 Huawei Technologies Co., Ltd.
+# This file is a part of the CANN Open Software.
+# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+# ======================================================================================================================
+
+import sys
+import numpy as np
+
+def gen_tensor_list(low, high, shape_list, name, dtype=np.float16):
+    # Dump each random tensor to <case>_<name>_<index>.bin for the C++ sample to read.
+    testCase = sys.argv[1]
+    for index, shape in enumerate(shape_list):
+        tensor = np.random.uniform(low, high, shape).astype(dtype)
+        tensor.tofile(f"{testCase}_{name}_{index}.bin")
+
+if __name__ == '__main__':
+    testCase = sys.argv[1]
+    if testCase == 'grouped_matmul_v2':
+        gen_tensor_list(-0.1, 0.1, [(1, 16), (4, 32)], 'x')
+        gen_tensor_list(-0.1, 0.1, [(16, 24), (32, 16)], 'weight')
+        gen_tensor_list(-0.1, 0.1, [(24,), (16,)], 'bias')
+    if testCase == 'grouped_matmul_v3':
+        gen_tensor_list(-1, 1, [(16, 128)], 'x', np.float16)
+        gen_tensor_list(-128, 128, [(4, 128, 1024)], 'weight', np.int8)
+        gen_tensor_list(-0.5, 0.5, [(4, 1024)], 'bias', np.float16)
+        gen_tensor_list(-0.05, 0.05, [(4, 1024)], 'antiquant_scale', np.float16)
+        gen_tensor_list(-2, 2, [(4, 1024)], 'antiquant_offset', np.float16)
+    if testCase == 'grouped_matmul_v4':
+        gen_tensor_list(-128, 128, [(32, 5)], 'x', np.int8)
+        gen_tensor_list(-128, 128, [(2, 5, 10)], 'weight', np.int8)
+        gen_tensor_list(-256, 256, [(2, 10)], 'bias', np.int32)
+        gen_tensor_list(-0.05, 0.05, [(2, 10)], 'scale', np.float32)
+        gen_tensor_list(-0.1, 0.1, [(32,)], 'pertoken_scale', np.float32)
\ No newline at end of file
diff --git a/examples/transformer/grouped_matmul/grouped_matmul_print_result.py b/examples/transformer/grouped_matmul/grouped_matmul_print_result.py
new file mode 100644
index 00000000..6e34a23f
--- /dev/null
+++ b/examples/transformer/grouped_matmul/grouped_matmul_print_result.py
@@ -0,0 +1,27 @@
+#!/usr/bin/env python3
+# coding: utf-8
+# Copyright (c) 2024 Huawei Technologies Co., Ltd.
+# This file is a part of the CANN Open Software.
+# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+# ======================================================================================================================
+
+import numpy as np
+import sys
+import os
+
+if __name__ == '__main__':
+    test_case = sys.argv[1]
+    print_head = "====================== Show sample " + test_case + " result start ================="
+    print(print_head)
+    file = test_case + '_y_0.bin'
+    if not os.path.isfile(file):
+        raise RuntimeError(f"Invalid case name: {test_case}")
+    y = np.fromfile(file, dtype=np.float16)
+    print(f"{test_case} output[0]: ", y)
+    file = test_case + '_y_1.bin'
+    if os.path.isfile(file):
+        y = np.fromfile(file, dtype=np.float16)
+        print(f"{test_case} output[1]: ", y)
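Both helper scripts key off the same test-case name passed as argv[1]: the generator writes <case>_<input>_<index>.bin files for the C++ sample to load, and the printer dumps the <case>_y_*.bin outputs the sample writes back. A manual round trip for one case (normally orchestrated by run_grouped_matmul_case.sh; the sample binary's exact command line is an assumption here) would look roughly like:

    python3 grouped_matmul_generate_data.py grouped_matmul_v2
    ./grouped_matmul_v2            # hypothetical direct invocation of the built sample
    python3 grouped_matmul_print_result.py grouped_matmul_v2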
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file grouped_matmul_utils.cpp
+ * \brief
+ */
+
+#include "grouped_matmul_utils.h"
+
+int64_t grouped_matmul_example::GetShapeSize(const std::vector<int64_t> &shape)
+{
+    int64_t shapeSize = 1;
+    for (auto i : shape) {
+        shapeSize *= i;
+    }
+    return shapeSize;
+}
+
+int grouped_matmul_example::Init(int32_t deviceId, aclrtStream *stream)
+{
+    // (Fixed writing) Initialize AscendCL.
+    auto ret = aclInit(nullptr);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret);
+    ret = aclrtSetDevice(deviceId);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret); return ret);
+    ret = aclrtCreateStream(stream);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret);
+    return 0;
+}
+
+int grouped_matmul_example::ReadBinFileNNop(const std::string &filePath, void *buffer, size_t bufferSize)
+{
+    struct stat sBuf;
+    int fileStatus = stat(filePath.data(), &sBuf);
+    CHECK_RET(fileStatus == ACL_SUCCESS, LOG_PRINT("Failed to get file %s\n", filePath.c_str()); return -1);
+
+    std::ifstream file;
+    file.open(filePath, std::ios::binary);
+    CHECK_RET(file.is_open(), LOG_PRINT("Open file failed.\n"); return -1);
+
+    file.seekg(0, file.end);
+    uint64_t binFileBufferLen = file.tellg();
+    CHECK_RET(binFileBufferLen == bufferSize, LOG_PRINT("Check file size failed.\n"); file.close(); return -1);
+
+    file.seekg(0, file.beg);
+    file.read(static_cast<char *>(buffer), binFileBufferLen);
+    file.close();
+    return ACL_SUCCESS;
+}
+
+int grouped_matmul_example::CreateAclTensor(const std::string& filePath, const std::vector<int64_t> &shape,
+                                            void **deviceAddr, aclDataType dataType, aclTensor **tensor)
+{
+    auto size = GetShapeSize(shape) * aclDataTypeSize(dataType);
+    // Call aclrtMalloc to allocate memory on the device.
+    auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret);
+
+    // Malloc host memory
+    void *binBufferHost = nullptr;
+    ret = aclrtMallocHost(&binBufferHost, size);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMallocHost failed. ERROR: %d\n", ret); return ret);
+
+    // Read input data file
+    ret = ReadBinFileNNop(filePath, binBufferHost, size);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("ReadBinFileNNop failed. ERROR: %d\n", ret);
+              (void)aclrtFreeHost(binBufferHost); return ret);
+
+    // Call aclrtMemcpy to copy the data on the host to the memory on the device.
+    ret = aclrtMemcpy(*deviceAddr, size, binBufferHost, size, ACL_MEMCPY_HOST_TO_DEVICE);
+    (void)aclrtFreeHost(binBufferHost);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);
+
+    // Compute the strides of the contiguous tensor.
+    std::vector<int64_t> strides(shape.size(), 1);
+    for (int64_t i = shape.size() - 2; i >= 0; i--) {
+        strides[i] = shape[i + 1] * strides[i + 1];
+    }
+
+    // Call aclCreateTensor to create an aclTensor.
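+    // aclCreateTensor takes the view shape, the strides computed above, a storage offset (0 here),
+    // the ND format, the storage shape and the device address. For a contiguous tensor each stride
+    // is the product of all later axis sizes; e.g. (illustrative numbers only) shape {4, 1024}
+    // gives strides {1024, 1}, so element (i, j) lives at flat offset i * 1024 + j.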
+    *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND,
+                              shape.data(), shape.size(), *deviceAddr);
+    return 0;
+}
+
+int grouped_matmul_example::CreateAclTensorList(const std::string& filePath,
+                                                const std::vector<std::vector<int64_t>> &shapes,
+                                                void **deviceAddr, aclDataType dataType, aclTensorList **tensor)
+{
+    int size = shapes.size();
+    aclTensor *tensors[size];
+    for (int i = 0; i < size; i++) {
+        std::string tensorPath = filePath + "_" + std::to_string(i) + ".bin";
+        int ret = CreateAclTensor(tensorPath, shapes[i], deviceAddr + i, dataType, tensors + i);
+        CHECK_RET(ret == ACL_SUCCESS, return ret);
+    }
+    *tensor = aclCreateTensorList(tensors, size);
+    return ACL_SUCCESS;
+}
+
+void grouped_matmul_example::FreeParam(GroupedMatmulParams &params)
+{
+    if (params.x != nullptr) {
+        aclDestroyTensorList(params.x);
+    }
+    if (params.weight != nullptr) {
+        aclDestroyTensorList(params.weight);
+    }
+    if (params.bias != nullptr) {
+        aclDestroyTensorList(params.bias);
+    }
+    if (params.scale != nullptr) {
+        aclDestroyTensorList(params.scale);
+    }
+    if (params.offset != nullptr) {
+        aclDestroyTensorList(params.offset);
+    }
+    if (params.antiquantScale != nullptr) {
+        aclDestroyTensorList(params.antiquantScale);
+    }
+    if (params.antiquantOffset != nullptr) {
+        aclDestroyTensorList(params.antiquantOffset);
+    }
+    if (params.perTokenScale != nullptr) {
+        aclDestroyTensorList(params.perTokenScale);
+    }
+    if (params.groupList != nullptr) {
+        aclDestroyIntArray(params.groupList);
+    }
+    if (params.groupListTensor != nullptr) {
+        aclDestroyTensor(params.groupListTensor);
+    }
+    if (params.y != nullptr) {
+        aclDestroyTensorList(params.y);
+    }
+}
+
+void grouped_matmul_example::FreeAddr(GroupedMatmulDevAddr &addrs)
+{
+    int size = TENSOR_SIZE;
+    for (int i = 0; i < size; i++) {
+        if (addrs.x[i] != nullptr) {
+            aclrtFree(addrs.x[i]);
+        }
+        if (addrs.weight[i] != nullptr) {
+            aclrtFree(addrs.weight[i]);
+        }
+        if (addrs.bias[i] != nullptr) {
+            aclrtFree(addrs.bias[i]);
+        }
+        if (addrs.scale[i] != nullptr) {
+            aclrtFree(addrs.scale[i]);
+        }
+        if (addrs.offset[i] != nullptr) {
+            aclrtFree(addrs.offset[i]);
+        }
+        if (addrs.antiquantScale[i] != nullptr) {
+            aclrtFree(addrs.antiquantScale[i]);
+        }
+        if (addrs.antiquantOffset[i] != nullptr) {
+            aclrtFree(addrs.antiquantOffset[i]);
+        }
+        if (addrs.perTokenScale[i] != nullptr) {
+            aclrtFree(addrs.perTokenScale[i]);
+        }
+        if (addrs.groupList[i] != nullptr) {
+            aclrtFree(addrs.groupList[i]);
+        }
+        if (addrs.groupListTensor[i] != nullptr) {
+            aclrtFree(addrs.groupListTensor[i]);
+        }
+        if (addrs.y[i] != nullptr) {
+            aclrtFree(addrs.y[i]);
+        }
+    }
+    // Free the workspace once, outside the per-tensor loop, to avoid freeing the same pointer twice.
+    if (addrs.workspaceAddr != nullptr) {
+        aclrtFree(addrs.workspaceAddr);
+    }
+}
+
+void grouped_matmul_example::FreeResource(GroupedMatmulParams &params, GroupedMatmulDevAddr &addrs,
+                                          int32_t deviceId, aclrtStream *stream)
+{
+    FreeParam(params);
+    FreeAddr(addrs);
+    if (stream != nullptr) {
+        aclrtDestroyStream(*stream);
+    }
+    aclrtResetDevice(deviceId);
+    aclFinalize();
+}
\ No newline at end of file
diff --git a/examples/transformer/grouped_matmul/grouped_matmul_utils.h b/examples/transformer/grouped_matmul/grouped_matmul_utils.h
new file mode 100644
index 00000000..fc1e21ff
--- /dev/null
+++ b/examples/transformer/grouped_matmul/grouped_matmul_utils.h
@@ -0,0 +1,150 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file grouped_matmul_utils.h
+ * \brief
+ */
+
+#ifndef EXAMPLE_GROUPED_MATMUL_UTILS_H
+#define EXAMPLE_GROUPED_MATMUL_UTILS_H
+
+#include <sys/stat.h>
+#include <cstdio>
+#include <fstream>
+#include <string>
+#include <vector>
+#include "acl/acl.h"
+#include "aclnn/aclnn_base.h"
+
+namespace grouped_matmul_example {
+#define CHECK_RET(cond, return_expr) \
+    do {                             \
+        if (!(cond)) {               \
+            return_expr;             \
+        }                            \
+    } while (0)
+
+#define LOG_PRINT(message, ...)         \
+    do {                                \
+        printf(message, ##__VA_ARGS__); \
+    } while (0)
+
+int64_t GetShapeSize(const std::vector<int64_t> &shape);
+
+template <typename T>
+void SaveOutResult(std::string &fileName, std::vector<int64_t> &shape, void **deviceAddr, aclDataType dataType)
+{
+    auto size = GetShapeSize(shape);
+    auto dtypeSize = aclDataTypeSize(dataType);
+    std::vector<T> resultData(size, 0);
+    auto ret = aclrtMemcpy(resultData.data(), resultData.size() * dtypeSize, *deviceAddr,
+                           size * dtypeSize, ACL_MEMCPY_DEVICE_TO_HOST);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return);
+    std::ofstream file(fileName, std::ios::binary);
+    // Save data to file
+    file.write(static_cast<char *>((void *)resultData.data()), size * dtypeSize);
+    file.close();
+}
+
+int Init(int32_t deviceId, aclrtStream *stream);
+
+int ReadBinFileNNop(const std::string &filePath, void *buffer, size_t bufferSize);
+
+template <typename T>
+int CreateAclTensor(const std::vector<T> &hostData, const std::vector<int64_t> &shape, void **deviceAddr,
+                    aclDataType dataType, aclTensor **tensor)
+{
+    auto size = GetShapeSize(shape) * sizeof(T);
+    // Call aclrtMalloc to allocate memory on the device.
+    auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret);
+
+    // Call aclrtMemcpy to copy the data on the host to the memory on the device.
+    ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);
+
+    // Compute the strides of the contiguous tensors.
+    std::vector<int64_t> strides(shape.size(), 1);
+    for (int64_t i = shape.size() - 2; i >= 0; i--) {
+        strides[i] = shape[i + 1] * strides[i + 1];
+    }
+
+    // Call aclCreateTensor to create an aclTensor.
+    *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND,
+                              shape.data(), shape.size(), *deviceAddr);
+    return 0;
+}
+
+template <typename T>
+int CreateAclTensorList(const std::vector<std::vector<T>> &hostData, const std::vector<std::vector<int64_t>> &shapes,
+                        void **deviceAddr, aclDataType dataType, aclTensorList **tensor)
+{
+    int size = shapes.size();
+    aclTensor *tensors[size];
+    for (int i = 0; i < size; i++) {
+        int ret = CreateAclTensor(hostData[i], shapes[i], deviceAddr + i, dataType, tensors + i);
+        CHECK_RET(ret == ACL_SUCCESS, return ret);
+    }
+    *tensor = aclCreateTensorList(tensors, size);
+    return ACL_SUCCESS;
+}
+
+int CreateAclTensor(const std::string& filePath, const std::vector<int64_t> &shape, void **deviceAddr,
+                    aclDataType dataType, aclTensor **tensor);
+
+int CreateAclTensorList(const std::string& filePath, const std::vector<std::vector<int64_t>> &shapes,
+                        void **deviceAddr, aclDataType dataType, aclTensorList **tensor);
+
+struct GroupedMatmulParams {
+    aclTensorList *x = nullptr;
+    aclTensorList *weight = nullptr;
+    aclTensorList *bias = nullptr;
+    aclTensorList *scale = nullptr;
+    aclTensorList *offset = nullptr;
+    aclTensorList *antiquantScale = nullptr;
+    aclTensorList *antiquantOffset = nullptr;
+    aclTensorList *perTokenScale = nullptr;
+    aclIntArray *groupList = nullptr;
+    aclTensor *groupListTensor = nullptr;
+    aclTensorList *y = nullptr;
+
+    // only nullptr is supported for these parameters in this example
+    aclTensorList *activationInput = nullptr;
+    aclTensorList *activationQuantScale = nullptr;
+    aclTensorList *activationQuantOffset = nullptr;
+    aclTensorList *activationFeatureOut = nullptr;
+    aclTensorList *dynQuantScaleOut = nullptr;
+};
+
+constexpr uint16_t TENSOR_SIZE = 2;
+struct GroupedMatmulDevAddr {
+    void *x[TENSOR_SIZE] = {nullptr, nullptr};
+    void *weight[TENSOR_SIZE] = {nullptr, nullptr};
+    void *bias[TENSOR_SIZE] = {nullptr, nullptr};
+    void *scale[TENSOR_SIZE] = {nullptr, nullptr};
+    void *offset[TENSOR_SIZE] = {nullptr, nullptr};
+    void *antiquantScale[TENSOR_SIZE] = {nullptr, nullptr};
+    void *antiquantOffset[TENSOR_SIZE] = {nullptr, nullptr};
+    void *perTokenScale[TENSOR_SIZE] = {nullptr, nullptr};
+    void *groupList[TENSOR_SIZE] = {nullptr, nullptr};
+    void *groupListTensor[TENSOR_SIZE] = {nullptr, nullptr};
+    void *y[TENSOR_SIZE] = {nullptr, nullptr};
+    void *workspaceAddr = nullptr;
+};
+
+void FreeParam(GroupedMatmulParams &params);
+
+void FreeAddr(GroupedMatmulDevAddr &addrs);
+
+void FreeResource(GroupedMatmulParams &params, GroupedMatmulDevAddr &addrs, int32_t deviceId, aclrtStream *stream);
+}  // namespace grouped_matmul_example
+
+#endif
\ No newline at end of file
diff --git a/examples/transformer/grouped_matmul/run_grouped_matmul_case.sh b/examples/transformer/grouped_matmul/run_grouped_matmul_case.sh
new file mode 100644
index 00000000..5b6b12fc
--- /dev/null
+++ b/examples/transformer/grouped_matmul/run_grouped_matmul_case.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+# Copyright (c) 2024 Huawei Technologies Co., Ltd.
+# This file is a part of the CANN Open Software.
+# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+# ======================================================================================================================
+
+set -e
+
+CURRENT_DIR=$(dirname $(readlink -f ${BASH_SOURCE[0]}))
+
+case_name=$1
+test_program=$2
+test_case=$3
+
+echo "=========================================== Run $case_name ===================================="
+python3 ${CURRENT_DIR}/grouped_matmul_generate_data.py $test_case
+echo "=========================================== Execute $case_name sample start ===================="
+# Execute the test program; test the command directly so the failure branch stays reachable under `set -e`.
+if ! ${test_program} $test_case; then
+    echo "Error: Execute ${test_program} failed."
+    exit 1
+fi
+echo "=========================================== Execute $case_name sample end ======================"
+python3 ${CURRENT_DIR}/grouped_matmul_print_result.py $test_case
+rm -f *.bin
+echo "=========================================== Run $case_name success =============================="
+exit 0
\ No newline at end of file
diff --git a/examples/transformer/grouped_matmul/test_grouped_matmul_v2.cpp b/examples/transformer/grouped_matmul/test_grouped_matmul_v2.cpp
new file mode 100644
index 00000000..7d4949b1
--- /dev/null
+++ b/examples/transformer/grouped_matmul/test_grouped_matmul_v2.cpp
@@ -0,0 +1,115 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file test_grouped_matmul_v2.cpp
+ * \brief
+ */
+
+#include "grouped_matmul_utils.h"
+#include "aclnnop/aclnn_grouped_matmul_v2.h"
+
+namespace grouped_matmul_example {
+std::vector<std::vector<int64_t>> yShape;
+
+int PrepareFloat16TensorList(const std::string &testCase, const std::string &currentPath, GroupedMatmulParams &params,
+                             GroupedMatmulDevAddr &addrs)
+{
+    // The shape values must match those generated by ./grouped_matmul_generate_data.py.
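+    // Reference semantics of this case, sketched for orientation: with splitItem = 0 and
+    // groupType = -1 every group is an independent tensor triple, so the expected result is
+    // y[i] = x[i] @ weight[i] + bias[i] per group, e.g. {1, 16} x {16, 24} + {24} -> {1, 24}.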
+    std::vector<std::vector<int64_t>> xShape = {{1, 16}, {4, 32}};
+    std::vector<std::vector<int64_t>> weightShape = {{16, 24}, {32, 16}};
+    std::vector<std::vector<int64_t>> biasShape = {{24}, {16}};
+    yShape = {{1, 24}, {4, 16}};
+    // uint16_t is used here as a 2-byte host stand-in for float16 buffers (an assumption of this example).
+    std::vector<std::vector<uint16_t>> yData;
+    for (auto &i : yShape) {
+        yData.push_back(std::vector<uint16_t>(GetShapeSize(i), 0));
+    }
+
+    std::string xPath = currentPath + testCase + "_x";
+    auto ret = CreateAclTensorList(xPath, xShape, addrs.x, aclDataType::ACL_FLOAT16, &params.x);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Prepare Tensor x failed\n"); return ret);
+    std::string weightPath = currentPath + testCase + "_weight";
+    ret = CreateAclTensorList(weightPath, weightShape, addrs.weight, aclDataType::ACL_FLOAT16, &params.weight);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Prepare Tensor weight failed\n"); return ret);
+    std::string biasPath = currentPath + testCase + "_bias";
+    ret = CreateAclTensorList(biasPath, biasShape, addrs.bias, aclDataType::ACL_FLOAT16, &params.bias);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Prepare Tensor bias failed\n"); return ret);
+    ret = CreateAclTensorList(yData, yShape, addrs.y, aclDataType::ACL_FLOAT16, &params.y);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Prepare Tensor y failed\n"); return ret);
+
+    return ACL_SUCCESS;
+}
+}  // namespace grouped_matmul_example
+using namespace grouped_matmul_example;
+
+int main(int argc, char **argv)
+{
+    if (argv == nullptr || argc < 2) {  // 2: Input num, exeFile and testCase.
+        LOG_PRINT("Invalid number of input parameters, expected >= 2 but got %d.\n", argc);
+        return 0;
+    }
+    std::string exeFile(argv[0]);
+    std::string currentPath = std::string(exeFile.substr(0, exeFile.rfind('/')) + "/");
+    std::string testCase(argv[1]);
+    // 1. (Fixed writing) Initialize the device and stream. For details, see the list of external AscendCL APIs.
+    // Set the device ID in use.
+    int32_t deviceId = 0;
+    aclrtStream stream;
+    auto ret = Init(deviceId, &stream);
+    // Use CHECK as required.
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Init acl failed. ERROR: %d\n", ret); return ret);
+
+    // 2. Construct the input and output based on the API.
+    int64_t splitItem = 0;
+    int64_t groupType = -1;
+    GroupedMatmulParams params;
+    GroupedMatmulDevAddr addrs;
+    ret = PrepareFloat16TensorList(testCase, currentPath, params, addrs);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Prepare failed. ERROR: %d\n", ret);
+              FreeResource(params, addrs, deviceId, &stream); return ret);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor *executor;
+
+    // 3. Call the CANN operator library API.
+    // Call the first-phase API of aclnnGroupedMatmulV2.
+    ret = aclnnGroupedMatmulV2GetWorkspaceSize(params.x, params.weight, params.bias, params.scale, params.offset,
+                                               params.antiquantScale, params.antiquantOffset, params.groupList,
+                                               splitItem, groupType, params.y, &workspaceSize, &executor);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnGroupedMatmulV2GetWorkspaceSize failed. ERROR: %d\n", ret);
+              FreeResource(params, addrs, deviceId, &stream); return ret);
+
+    // Malloc device memory for workspace based on the workspaceSize calculated from the first interface
+    if (workspaceSize > 0) {
+        ret = aclrtMalloc(&addrs.workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
+        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret);
+                  FreeResource(params, addrs, deviceId, &stream); return ret);
+    }
+
+    // Call the second-phase API of aclnnGroupedMatmulV2.
+    ret = aclnnGroupedMatmulV2(addrs.workspaceAddr, workspaceSize, executor, stream);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnGroupedMatmulV2 failed. ERROR: %d\n", ret);
+              FreeResource(params, addrs, deviceId, &stream); return ret);
+
+    // 4. (Fixed writing) Wait until the task execution is complete.
+    ret = aclrtSynchronizeStream(stream);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSynchronizeStream failed. ERROR: %d\n", ret);
+              FreeResource(params, addrs, deviceId, &stream); return ret);
+
+    // 5. Copy output from device to host and then save to file
+    for (size_t i = 0; i < yShape.size(); ++i) {
+        std::string outFile = testCase + "_y_" + std::to_string(i) + ".bin";
+        SaveOutResult<uint16_t>(outFile, yShape[i], &addrs.y[i], aclDataType::ACL_FLOAT16);
+    }
+
+    // 6. Release aclTensor and device resource.
+    FreeResource(params, addrs, deviceId, &stream);
+    return 0;
+}
\ No newline at end of file
diff --git a/examples/transformer/grouped_matmul/test_grouped_matmul_v3.cpp b/examples/transformer/grouped_matmul/test_grouped_matmul_v3.cpp
new file mode 100644
index 00000000..75ffe4fa
--- /dev/null
+++ b/examples/transformer/grouped_matmul/test_grouped_matmul_v3.cpp
@@ -0,0 +1,128 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file test_grouped_matmul_v3.cpp
+ * \brief
+ */
+
+#include "grouped_matmul_utils.h"
+#include "aclnnop/aclnn_grouped_matmul_v3.h"
+
+namespace grouped_matmul_example {
+std::vector<std::vector<int64_t>> yShape;
+
+int PrepareAntiquant(const std::string &testCase, const std::string &currentPath, GroupedMatmulParams &params,
+                     GroupedMatmulDevAddr &addrs)
+{
+    // The shape values must match those generated by ./grouped_matmul_generate_data.py.
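+    // Reference semantics of this A16W8 antiquant case, assuming the usual weight dequantization
+    // convention: w_fp16[g] = (w_int8[g] + antiquantOffset[g]) * antiquantScale[g], then
+    // y = x @ w_fp16[g] + bias[g] per group. groupList holds cumulative row offsets into x, so
+    // {4, 8, 12, 16} splits the 16 rows of x into four groups of 4 rows, one per slice of the
+    // {4, 128, 1024} weight tensor.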
+    std::vector<std::vector<int64_t>> xShape = {{16, 128}};
+    std::vector<std::vector<int64_t>> weightShape = {{4, 128, 1024}};
+    std::vector<std::vector<int64_t>> biasShape = {{4, 1024}};
+    std::vector<std::vector<int64_t>> scaleShape = {{4, 1024}};
+    std::vector<std::vector<int64_t>> offsetShape = {{4, 1024}};
+    std::vector<int64_t> groupShape = {4};
+    std::vector<int64_t> groupList = {4, 8, 12, 16};
+    yShape = {{16, 1024}};
+    // uint16_t is used here as a 2-byte host stand-in for float16 buffers (an assumption of this example).
+    std::vector<std::vector<uint16_t>> yData;
+    for (auto &i : yShape) {
+        yData.push_back(std::vector<uint16_t>(GetShapeSize(i), 0));
+    }
+
+    std::string xPath = currentPath + testCase + "_x";
+    auto ret = CreateAclTensorList(xPath, xShape, addrs.x, aclDataType::ACL_FLOAT16, &params.x);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Prepare TensorList x failed\n"); return ret);
+    std::string weightPath = currentPath + testCase + "_weight";
+    ret = CreateAclTensorList(weightPath, weightShape, addrs.weight, aclDataType::ACL_INT8, &params.weight);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Prepare TensorList weight failed\n"); return ret);
+    std::string biasPath = currentPath + testCase + "_bias";
+    ret = CreateAclTensorList(biasPath, biasShape, addrs.bias, aclDataType::ACL_FLOAT16, &params.bias);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Prepare TensorList bias failed\n"); return ret);
+    std::string scalePath = currentPath + testCase + "_antiquant_scale";
+    ret = CreateAclTensorList(scalePath, scaleShape, addrs.antiquantScale, aclDataType::ACL_FLOAT16,
+                              &params.antiquantScale);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Prepare TensorList antiquantScale failed\n"); return ret);
+    std::string offsetPath = currentPath + testCase + "_antiquant_offset";
+    ret = CreateAclTensorList(offsetPath, offsetShape, addrs.antiquantOffset, aclDataType::ACL_FLOAT16,
+                              &params.antiquantOffset);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Prepare TensorList antiquantOffset failed\n"); return ret);
+    ret = CreateAclTensor(groupList, groupShape, addrs.groupListTensor, aclDataType::ACL_INT64, &params.groupListTensor);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Prepare Tensor groupList failed\n"); return ret);
+    ret = CreateAclTensorList(yData, yShape, addrs.y, aclDataType::ACL_FLOAT16, &params.y);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Prepare TensorList y failed\n"); return ret);
+
+    return ACL_SUCCESS;
+}
+}  // namespace grouped_matmul_example
+using namespace grouped_matmul_example;
+
+int main(int argc, char **argv)
+{
+    if (argv == nullptr || argc < 2) {  // 2: Input num, exeFile and testCase.
+        LOG_PRINT("Invalid number of input parameters, expected >= 2 but got %d.\n", argc);
+        return 0;
+    }
+    std::string exeFile(argv[0]);
+    std::string currentPath = std::string(exeFile.substr(0, exeFile.rfind('/')) + "/");
+    std::string testCase(argv[1]);
+    // 1. (Fixed writing) Initialize the device and stream. For details, see the list of external AscendCL APIs.
+    // Set the device ID in use.
+    int32_t deviceId = 0;
+    aclrtStream stream;
+    auto ret = Init(deviceId, &stream);
+    // Use CHECK as required.
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Init acl failed. ERROR: %d\n", ret); return ret);
+
+    // 2. Construct the input and output based on the API.
+    int64_t splitItem = 3;
+    int64_t groupType = 0;
+    GroupedMatmulParams params;
+    GroupedMatmulDevAddr addrs;
+    ret = PrepareAntiquant(testCase, currentPath, params, addrs);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Prepare failed. ERROR: %d\n", ret);
+              FreeResource(params, addrs, deviceId, &stream); return ret);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor *executor;
+
+    // 3. Call the CANN operator library API.
+    // Call the first-phase API of aclnnGroupedMatmulV3.
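+    // aclnn APIs follow a two-phase contract: the *GetWorkspaceSize call validates the arguments,
+    // computes the required device workspace size and returns an executor; the second-phase call
+    // then launches the kernel on the stream using that executor and workspace.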
+    ret = aclnnGroupedMatmulV3GetWorkspaceSize(params.x, params.weight, params.bias, params.scale, params.offset,
+                                               params.antiquantScale, params.antiquantOffset, params.groupListTensor,
+                                               splitItem, groupType, params.y, &workspaceSize, &executor);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnGroupedMatmulV3GetWorkspaceSize failed. ERROR: %d\n", ret);
+              FreeResource(params, addrs, deviceId, &stream); return ret);
+
+    // Malloc device memory for workspace based on the workspaceSize calculated from the first interface
+    if (workspaceSize > 0) {
+        ret = aclrtMalloc(&addrs.workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
+        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret);
+                  FreeResource(params, addrs, deviceId, &stream); return ret);
+    }
+    // Call the second-phase API of aclnnGroupedMatmulV3.
+    ret = aclnnGroupedMatmulV3(addrs.workspaceAddr, workspaceSize, executor, stream);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnGroupedMatmulV3 failed. ERROR: %d\n", ret);
+              FreeResource(params, addrs, deviceId, &stream); return ret);
+
+    // 4. (Fixed writing) Wait until the task execution is complete.
+    ret = aclrtSynchronizeStream(stream);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSynchronizeStream failed. ERROR: %d\n", ret);
+              FreeResource(params, addrs, deviceId, &stream); return ret);
+
+    // 5. Copy output from device to host and then save to file
+    for (size_t i = 0; i < yShape.size(); ++i) {
+        std::string outFile = testCase + "_y_" + std::to_string(i) + ".bin";
+        SaveOutResult<uint16_t>(outFile, yShape[i], &addrs.y[i], aclDataType::ACL_FLOAT16);
+    }
+
+    // 6. Release aclTensor and device resource.
+    FreeResource(params, addrs, deviceId, &stream);
+    return 0;
+}
\ No newline at end of file
diff --git a/examples/transformer/grouped_matmul/test_grouped_matmul_v4.cpp b/examples/transformer/grouped_matmul/test_grouped_matmul_v4.cpp
new file mode 100644
index 00000000..08441b33
--- /dev/null
+++ b/examples/transformer/grouped_matmul/test_grouped_matmul_v4.cpp
@@ -0,0 +1,135 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file test_grouped_matmul_v4.cpp
+ * \brief
+ */
+
+#include "grouped_matmul_utils.h"
+#include "aclnnop/aclnn_grouped_matmul_v4.h"
+
+namespace grouped_matmul_example {
+std::vector<std::vector<int64_t>> yShape;
+
+int PrepareQuant(const std::string &testCase, const std::string &currentPath, GroupedMatmulParams &params,
+                 GroupedMatmulDevAddr &addrs)
+{
+    // The shape values must match those generated by ./grouped_matmul_generate_data.py.
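+    // Reference semantics of this quant case, assuming the usual per-channel/per-token dequant
+    // convention: y[m][n] = (sum_k x_int8[m][k] * w_int8[g][k][n] + bias_int32[g][n])
+    //                       * scale[g][n] * pertokenScale[m],
+    // accumulated in int32 and written back as float16, where g is the group that row m falls into.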
+    std::vector<std::vector<int64_t>> xShape = {{32, 5}};
+    std::vector<std::vector<int64_t>> weightShape = {{2, 5, 10}};
+    std::vector<std::vector<int64_t>> biasShape = {{2, 10}};
+    std::vector<std::vector<int64_t>> scaleShape = {{2, 10}};
+    std::vector<std::vector<int64_t>> offsetShape = {{2, 10}};
+    std::vector<std::vector<int64_t>> pertokenShape = {{32}};
+    std::vector<int64_t> groupShape = {2};
+    std::vector<int64_t> groupList = {16, 16};
+    yShape = {{32, 10}};
+    // uint16_t is used here as a 2-byte host stand-in for float16 buffers (an assumption of this example).
+    std::vector<std::vector<uint16_t>> yData;
+    for (auto &i : yShape) {
+        yData.push_back(std::vector<uint16_t>(GetShapeSize(i), 0));
+    }
+
+    std::string xPath = currentPath + testCase + "_x";
+    auto ret = CreateAclTensorList(xPath, xShape, addrs.x, aclDataType::ACL_INT8, &params.x);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Prepare TensorList x failed\n"); return ret);
+    std::string weightPath = currentPath + testCase + "_weight";
+    ret = CreateAclTensorList(weightPath, weightShape, addrs.weight, aclDataType::ACL_INT8, &params.weight);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Prepare TensorList weight failed\n"); return ret);
+    std::string biasPath = currentPath + testCase + "_bias";
+    ret = CreateAclTensorList(biasPath, biasShape, addrs.bias, aclDataType::ACL_INT32, &params.bias);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Prepare TensorList bias failed\n"); return ret);
+    std::string scalePath = currentPath + testCase + "_scale";
+    ret = CreateAclTensorList(scalePath, scaleShape, addrs.scale, aclDataType::ACL_FLOAT,
+                              &params.scale);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Prepare TensorList scale failed\n"); return ret);
+    std::string pertokenPath = currentPath + testCase + "_pertoken_scale";
+    ret = CreateAclTensorList(pertokenPath, pertokenShape, addrs.perTokenScale, aclDataType::ACL_FLOAT,
+                              &params.perTokenScale);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Prepare TensorList perTokenScale failed\n"); return ret);
+    ret = CreateAclTensor(groupList, groupShape, addrs.groupListTensor, aclDataType::ACL_INT64, &params.groupListTensor);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Prepare Tensor groupList failed\n"); return ret);
+    ret = CreateAclTensorList(yData, yShape, addrs.y, aclDataType::ACL_FLOAT16, &params.y);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Prepare TensorList y failed\n"); return ret);
+
+    return ACL_SUCCESS;
+}
+}  // namespace grouped_matmul_example
+using namespace grouped_matmul_example;
+
+int main(int argc, char **argv)
+{
+    if (argv == nullptr || argc < 2) {  // 2: Input num, exeFile and testCase.
+        LOG_PRINT("Invalid number of input parameters, expected >= 2 but got %d.\n", argc);
+        return 0;
+    }
+    std::string exeFile(argv[0]);
+    std::string currentPath = std::string(exeFile.substr(0, exeFile.rfind('/')) + "/");
+    std::string testCase(argv[1]);
+    // 1. (Fixed writing) Initialize the device and stream. For details, see the list of external AscendCL APIs.
+    // Set the device ID in use.
+    int32_t deviceId = 0;
+    aclrtStream stream;
+    auto ret = Init(deviceId, &stream);
+    // Use CHECK as required.
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Init acl failed. ERROR: %d\n", ret); return ret);
+
+    // 2. Construct the input and output based on the API.
+    int64_t splitItem = 3;
+    int64_t groupType = 0;
+    int64_t groupListType = 1;
+    int64_t actType = 0;
+    GroupedMatmulParams params;
+    GroupedMatmulDevAddr addrs;
+    ret = PrepareQuant(testCase, currentPath, params, addrs);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Prepare failed. ERROR: %d\n", ret);
+              FreeResource(params, addrs, deviceId, &stream); return ret);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor *executor;
+
+    // 3. Call the CANN operator library API.
+    // Call the first-phase API of aclnnGroupedMatmulV4.
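+    // Parameter recap, sketching the intended meanings: splitItem = 3 treats x and y as single
+    // fused tensors, groupType = 0 groups along the M axis, groupListType = 1 means groupList
+    // carries per-group sizes ({16, 16}) rather than cumulative offsets, and actType = 0 applies
+    // no activation.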
+    ret = aclnnGroupedMatmulV4GetWorkspaceSize(params.x, params.weight, params.bias, params.scale, params.offset,
+                                               params.antiquantScale, params.antiquantOffset, params.perTokenScale,
+                                               params.groupListTensor, params.activationInput,
+                                               params.activationQuantScale, params.activationQuantOffset,
+                                               splitItem, groupType, groupListType, actType, params.y,
+                                               params.activationFeatureOut, params.dynQuantScaleOut,
+                                               &workspaceSize, &executor);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnGroupedMatmulV4GetWorkspaceSize failed. ERROR: %d\n", ret);
+              FreeResource(params, addrs, deviceId, &stream); return ret);
+
+    // Malloc device memory for workspace based on the workspaceSize calculated from the first interface
+    if (workspaceSize > 0) {
+        ret = aclrtMalloc(&addrs.workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
+        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret);
+                  FreeResource(params, addrs, deviceId, &stream); return ret);
+    }
+    // Call the second-phase API of aclnnGroupedMatmulV4.
+    ret = aclnnGroupedMatmulV4(addrs.workspaceAddr, workspaceSize, executor, stream);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnGroupedMatmulV4 failed. ERROR: %d\n", ret);
+              FreeResource(params, addrs, deviceId, &stream); return ret);
+
+    // 4. (Fixed writing) Wait until the task execution is complete.
+    ret = aclrtSynchronizeStream(stream);
+    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSynchronizeStream failed. ERROR: %d\n", ret);
+              FreeResource(params, addrs, deviceId, &stream); return ret);
+
+    // 5. Copy output from device to host and then save to file
+    for (size_t i = 0; i < yShape.size(); ++i) {
+        std::string outFile = testCase + "_y_" + std::to_string(i) + ".bin";
+        SaveOutResult<uint16_t>(outFile, yShape[i], &addrs.y[i], aclDataType::ACL_FLOAT16);
+    }
+
+    // 6. Release aclTensor and device resource.
+    FreeResource(params, addrs, deviceId, &stream);
+    return 0;
+}
\ No newline at end of file
diff --git a/src/transformer/grouped_matmul/grouped_matmul.cpp b/src/transformer/grouped_matmul/grouped_matmul.cpp
new file mode 100644
index 00000000..e2c502b4
--- /dev/null
+++ b/src/transformer/grouped_matmul/grouped_matmul.cpp
@@ -0,0 +1,202 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file grouped_matmul.cpp + * \brief + */ +#include "grouped_matmul_utils.h" +#include "grouped_matmul_antiquant.h" +#include "grouped_matmul_vector.h" +#include "grouped_matmul.h" + +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220 + +#include "grouped_matmul_antiquant_a16w8_msd.h" +#include "grouped_matmul_quant_mixcore.h" +#endif + + +using namespace AscendC; +using namespace matmul; +using namespace GROUPED_MATMUL; + +#ifndef FORMAT_FRACTAL_NZ + #define FORMAT_FRACTAL_NZ +#endif + +#if defined(FORMAT_WEIGHT) && FORMAT_WEIGHT == FORMAT_FRACTAL_NZ +constexpr CubeFormat wFormat = CubeFormat::NZ; +constexpr MatmulConfig matmulCFG = NZ_CFG_MDL; +#else +constexpr CubeFormat wFormat = CubeFormat::ND; +constexpr MatmulConfig matmulCFG = CFG_MDL; +#endif + +template +using xType = MatmulType; + +template +using xTypeMSD = MatmulType; + +template +using weightType = MatmulType; + +template +using weightTypeMSD = MatmulType; + +using yType = MatmulType; + +using yTypeMSD = MatmulType; + +using biasType = MatmulType; + +#define GMM_IMP(computeClass, processClass, transA, transB, sync, cfg) \ + do { \ + using matmulType = MMType, weightType, yType, biasType, cfg>; \ + matmulType::MT mm; \ + GET_TILING_DATA_MEMBER(GMMTilingData, gmmBaseParams, gmmBaseParams_, tiling); \ + GET_TILING_DATA_MEMBER(GMMTilingData, mmTilingData, mmTilingData_, tiling); \ + REGIST_MATMUL_OBJ(&tPipe, GetSysWorkSpacePtr(), mm, &mmTilingData_); \ + computeClass computeOp(mm); \ + computeOp.Init(x, weight, bias, scale, offset, antiquantScale, antiquantOffset, groupList, perTokenScale, \ + y, user1, &gmmBaseParams_, &mmTilingData_, &tPipe); \ + processClass op(computeOp); \ + op.Init(&gmmBaseParams_, &mmTilingData_, groupList, tiling); \ + op.Process(); \ + } while (0) + +#define GMM_CUBE_IMP(transA, transB, sync, cfg) \ + do { \ + if ASCEND_IS_AIV { \ + return; \ + } \ + using matmulType = MMImplType, weightType, yType, biasType, cfg>; \ + matmulType::MT mm; \ + GET_TILING_DATA_MEMBER(GMMTilingData, gmmBaseParams, gmmBaseParams_, tiling); \ + GET_TILING_DATA_MEMBER(GMMTilingData, mmTilingData, mmTilingData_, tiling); \ + mm.SetSubBlockIdx(0); \ + mm.Init(&mmTilingData_, &tPipe); \ + GMMCompute computeOp(mm); \ + computeOp.Init(x, weight, bias, scale, offset, antiquantScale, antiquantOffset, groupList, perTokenScale, \ + y, user1, &gmmBaseParams_, &mmTilingData_, &tPipe); \ + GMMProcess op(computeOp); \ + op.Init(&gmmBaseParams_, &mmTilingData_, groupList, tiling); \ + op.Process(); \ + } while (0) + +#define GMM_CV_SPLIT_IMP(computeClass, processClass, transA, transB, sync, cfg, aType, bType, cType) \ + do { \ + using matmulType = MMImplType, bType, cType, biasType, cfg>; \ + matmulType::MT mm; \ + GET_TILING_DATA_MEMBER(GMMTilingData, gmmBaseParams, gmmBaseParams_, tiling); \ + GET_TILING_DATA_MEMBER(GMMTilingData, mmTilingData, mmTilingData_, tiling); \ + if ASCEND_IS_AIC { \ + mm.SetSubBlockIdx(0); \ + mm.Init(&mmTilingData_, &tPipe); \ + } \ + computeClass computeOp(mm); \ + computeOp.Init(x, weight, bias, scale, offset, antiquantScale, antiquantOffset, groupList, perTokenScale, \ + y, user1, &gmmBaseParams_, &mmTilingData_, &tPipe); \ + processClass op(computeOp); \ + op.Init(&gmmBaseParams_, &mmTilingData_, groupList, tiling); \ + op.Process(); \ + } while (0) + +extern "C" __global__ __aicore__ void grouped_matmul(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR scale, + GM_ADDR offset, GM_ADDR antiquantScale, GM_ADDR antiquantOffset, + GM_ADDR groupList, GM_ADDR perTokenScale, GM_ADDR y, + GM_ADDR workspace, 
GM_ADDR tiling) {
+    TPipe tPipe;
+    AscendCUtils::SetOverflow(1);
+    KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIC_ONLY);
+    GM_ADDR user1 = GetUserWorkspace(workspace);
+
+#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220
+#if defined(GMM_ANTI_QUANT)
+    if (TILING_KEY_IS(0)) {
+        KERNEL_TASK_TYPE(0, KERNEL_TYPE_MIX_AIC_1_1);
+        GMM_IMP(GMMAntiquantComputeNorm, GMMAntiquantProcess, false, false, false, matmulCFG);
+    } else if (TILING_KEY_IS(2)) { // weight transposed
+        KERNEL_TASK_TYPE(2, KERNEL_TYPE_MIX_AIC_1_1);
+        GMM_IMP(GMMAntiquantComputeNorm, GMMAntiquantProcess, false, true, false, matmulCFG);
+    } else if (TILING_KEY_IS(3)) { // antiquant performance
+        KERNEL_TASK_TYPE(3, KERNEL_TYPE_MIX_AIC_1_2);
+        GMM_IMP(GMMAntiquantComputePerformance, GMMAntiquantProcess, false, false, false, matmulCFG);
+    }
+    #if defined(ORIG_DTYPE_WEIGHT) && defined(DT_INT8) && ORIG_DTYPE_WEIGHT == DT_INT8
+    if (TILING_KEY_IS(6)) { // antiquant msd
+        KERNEL_TASK_TYPE(6, KERNEL_TYPE_MIX_AIC_1_1);
+        GMM_CV_SPLIT_IMP(GMMA16W8MSDCompute, GMMA16W8MSDProcess, false, false, false, matmulCFG,
+                         xTypeMSD, weightTypeMSD, yTypeMSD);
+    } else if (TILING_KEY_IS(7)) { // antiquant msd weight transposed
+        KERNEL_TASK_TYPE(7, KERNEL_TYPE_MIX_AIC_1_1);
+        GMM_CV_SPLIT_IMP(GMMA16W8MSDCompute, GMMA16W8MSDProcess, false, true, false, matmulCFG,
+                         xTypeMSD, weightTypeMSD, yTypeMSD);
+    }
+    #endif
+#elif defined(GMM_QUANT_BF16) || defined(GMM_QUANT_FLOAT16)
+    if (TILING_KEY_IS(0)) {
+        KERNEL_TASK_TYPE(0, KERNEL_TYPE_MIX_AIC_1_1);
+        GMM_CV_SPLIT_IMP(GMMQuantMixCoreCompute, GMMProcess, false, false, false, matmulCFG, xType, weightType, yType);
+    } else if (TILING_KEY_IS(2)) { // weight transposed
+        KERNEL_TASK_TYPE(2, KERNEL_TYPE_MIX_AIC_1_1);
+        GMM_CV_SPLIT_IMP(GMMQuantMixCoreCompute, GMMProcess, false, true, false, matmulCFG, xType, weightType, yType);
+    } else if (TILING_KEY_IS(4)) {
+        KERNEL_TASK_TYPE(4, KERNEL_TYPE_MIX_AIC_1_2);
+        GMM_CV_SPLIT_IMP(GMMQuantMixCoreCompute, GMMProcess, false, false, false, matmulCFG, xType, weightType, yType);
+    } else if (TILING_KEY_IS(5)) { // weight transposed
+        KERNEL_TASK_TYPE(5, KERNEL_TYPE_MIX_AIC_1_2);
+        GMM_CV_SPLIT_IMP(GMMQuantMixCoreCompute, GMMProcess, false, true, false, matmulCFG, xType, weightType, yType);
+    }
+#else
+    if (TILING_KEY_IS(0)) {
+        KERNEL_TASK_TYPE(0, KERNEL_TYPE_MIX_AIC_1_0);
+        GMM_CUBE_IMP(false, false, false, matmulCFGUnitFlag);
+    } else if (TILING_KEY_IS(2)) { // weight transposed
+        KERNEL_TASK_TYPE(2, KERNEL_TYPE_MIX_AIC_1_0);
+        GMM_CUBE_IMP(false, true, false, matmulCFGUnitFlag);
+    }
+#endif
+
+#if defined(GMM_FLOAT)
+    if (TILING_KEY_IS(1)) { // x transposed
+        KERNEL_TASK_TYPE(1, KERNEL_TYPE_MIX_AIC_1_1);
+        if ASCEND_IS_AIV {
+            GET_TILING_DATA(tilingData, tiling);
+            EmptyTensorCompute(groupList, y, &tilingData);
+        }
+        if ASCEND_IS_AIC {
+            GMM_CUBE_IMP(true, false, false, matmulCFG);
+        }
+    }
+#endif
+#endif
+
+#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200
+#if defined(GMM_FLOAT)
+    if (TILING_KEY_IS(0)) {
+        GMM_CUBE_IMP(false, false, false, matmulCFG);
+    } else if (TILING_KEY_IS(1)) { // x transposed
+        KERNEL_TASK_TYPE(1, KERNEL_TYPE_MIX_AIC_1_1);
+        if ASCEND_IS_AIV {
+            GET_TILING_DATA(tilingData, tiling);
+            EmptyTensorCompute(groupList, y, &tilingData);
+        }
+        if ASCEND_IS_AIC {
+            GMM_CUBE_IMP(true, false, false, matmulCFG);
+        }
+    } else if (TILING_KEY_IS(2)) { // weight transposed
+        GMM_CUBE_IMP(false, true, false, matmulCFG);
+    }
+
+#endif
+#endif
+}
diff --git a/src/transformer/grouped_matmul/grouped_matmul.h b/src/transformer/grouped_matmul/grouped_matmul.h
new file mode 100644
index 00000000..5000b0b2
--- /dev/null
+++ b/src/transformer/grouped_matmul/grouped_matmul.h
@@ -0,0 +1,453 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file grouped_matmul.h
+ * \brief
+ */
+#ifndef ASCENDC_GROUPED_MATMUL_H
+#define ASCENDC_GROUPED_MATMUL_H
+
+#include "grouped_matmul_utils.h"
+
+namespace GROUPED_MATMUL {
+
+constexpr uint32_t thresholdBlockNum = 8;  // 8 is obtained by tests, indicating the threshold of basic block numbers
+                                           // in both directions when assigning data blocks to cube cores with the
+                                           // diagonal strategy
+#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200
+constexpr uint32_t thresholdDimM = 1;  // no special strategy is needed
+#else
+constexpr uint32_t thresholdDimM = 5;  // 5 is obtained by tests, indicating the threshold for distinguishing
+                                       // strategies for large/small shapes
+#endif
+
+/** @brief store variables for core split configuration
+*/
+struct MNConfig {
+    uint32_t m = 0;
+    uint32_t k = 0;
+    uint32_t n = 0;
+    uint32_t baseM = 0;
+    uint32_t baseN = 0;
+    uint32_t mIdx = 0;
+    uint32_t nIdx = 0;
+    uint32_t blockDimM = 0;
+    uint32_t blockDimN = 0;
+    uint32_t singleM = 0;
+    uint32_t singleN = 0;
+    uint64_t wBaseOffset = 0;
+    uint64_t nAxisBaseOffset = 0;
+    uint64_t mAxisBaseOffset = 0;
+    uint64_t xBaseOffset = 0;
+    uint64_t yBaseOffset = 0;
+    uint64_t wOutOffset = 0;
+    uint64_t workSpaceOffset = 0;
+};
+
+template <typename T>
+__aicore__ inline void DataCopyPad2D(const LocalTensor<T> dst, const GlobalTensor<T> src, uint32_t dim1, uint32_t dim0,
+                                     uint32_t fullDim0) {
+    DataCopyExtParams params;
+    params.blockCount = dim1;
+    params.blockLen = dim0 * sizeof(T);
+    params.srcStride = (fullDim0 - dim0) * sizeof(T);
+    params.dstStride = Ceil(dim0 * sizeof(T), UB_BLOCK_DOUBLE_UNIT_SIZE) * 2 - \
+                       Ceil(dim0 * sizeof(T), UB_BLOCK_UNIT_SIZE);
+
+    DataCopyPadExtParams<T> padParams;
+    padParams.isPad = true;
+    padParams.rightPadding = 0;
+    padParams.leftPadding = 0;
+    padParams.paddingValue = 0;
+    DataCopyPad(dst, src, params, padParams);
+}
+
+template <typename T>
+__aicore__ inline void DataCopyPad2D(const GlobalTensor<T> dst, const LocalTensor<T> src, uint32_t dim1, uint32_t dim0,
+                                     uint32_t srcFullDim0, uint32_t dstFullDim0) {
+    DataCopyExtParams params;
+    params.blockCount = dim1;
+    params.blockLen = dim0 * sizeof(T);
+    params.srcStride = static_cast<uint32_t>((srcFullDim0 - dim0) * sizeof(T) / UB_BLOCK_UNIT_SIZE);
+    params.dstStride = (dstFullDim0 - dim0) * sizeof(T);
+    DataCopyPad(dst, src, params);
+}
+
+/** @brief GroupMatmul operator Class
+*/
+template <typename ComputeType>
+class GMMProcess {
+ protected:
+    using B = typename ComputeType::B;
+    ComputeType& computeOp;  // internal computation operator
+    const GMMBaseParams* __restrict gmmBaseParams;
+    const TCubeTiling* __restrict mmTilingData;
+
+    uint32_t blockIdx;
+    uint32_t coreIdx;
+    uint32_t groupNum;
+    int32_t preOffset;
+    GM_ADDR groupListPtr;
+    GlobalTensor<int64_t> groupListGm;
+    GlobalTensor<int32_t> mListGm;
+    GlobalTensor<int32_t> kListGm;
+    GlobalTensor<int32_t> nListGm;
+
+ public:
+    /** @brief constructor
*/ + __aicore__ inline GMMProcess(ComputeType& computeOp_) : computeOp(computeOp_) {} + + __aicore__ inline void Init(const GMMBaseParams* __restrict gmmBaseParamsIn, + const TCubeTiling* __restrict mmTilingDataIn, GM_ADDR groupList, GM_ADDR tiling); + + __aicore__ inline void Process(); + + protected: + __aicore__ inline void SetMNConfig(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig); + + __aicore__ inline void SetMKN(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig); + + __aicore__ inline void UpdateMnConfig(MNConfig &mnConfig); + + __aicore__ inline void MNBlockIdxCompute(MNConfig &mnConfig, const uint32_t curBlock, const uint32_t count, + const uint32_t thresholdM_dimN); +}; + +template + __aicore__ inline void GMMProcess::Init(const GMMBaseParams* __restrict gmmBaseParamsIn, + const TCubeTiling* __restrict mmTilingDataIn, GM_ADDR groupList, GM_ADDR tiling) { + blockIdx = GetBlockIdx(); + coreIdx = blockIdx; + int64_t coreRation = GetTaskRation(); + if (coreRation > 1) { + coreIdx /= coreRation; + } + gmmBaseParams = gmmBaseParamsIn; + mmTilingData = mmTilingDataIn; + groupNum = gmmBaseParams->groupNum; + groupListPtr = groupList; + if (groupListPtr != nullptr) { + groupListGm.SetGlobalBuffer((__gm__ int64_t*)groupList); + } + GET_TILING_DATA_MEMBER_ADDR(GMMTilingData, gmmArray, gmmArrayAddr, tiling); // custom macro + mListGm.SetGlobalBuffer((__gm__ int32_t*)gmmArrayAddr); + kListGm.SetGlobalBuffer((__gm__ int32_t*)(gmmArrayAddr + sizeof(int32_t) * MKN_LIST_LEN)); + nListGm.SetGlobalBuffer((__gm__ int32_t*)(gmmArrayAddr + sizeof(int32_t) * MKN_LIST_LEN * 2)); +} + +template +__aicore__ inline void GMMProcess::SetMNConfig(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig) { + SetMKN(splitValue, groupIdx, mnConfig); + mnConfig.baseM = mmTilingData->baseM; + mnConfig.baseN = mmTilingData->baseN; + mnConfig.singleM = mnConfig.baseM; + mnConfig.singleN = mnConfig.baseN; +#if defined(GMM_QUANT_BF16) || defined(GMM_QUANT_FLOAT16) + if (gmmBaseParams->singleN > 0) { // not sequential write + mnConfig.singleN = gmmBaseParams->singleN; + } +#endif +} + +template +__aicore__ inline void GMMProcess::SetMKN(const int32_t splitValue, const uint32_t groupIdx, + MNConfig &mnConfig) { + uint32_t singleWeight = gmmBaseParams->singleWeight; + uint32_t singleX = gmmBaseParams->singleX; + uint32_t singleY = gmmBaseParams->singleY; + bool isAllSingleTensor = singleWeight == 1 && singleX == 1 && singleY == 1; + if (gmmBaseParams->groupType == 0) { + mnConfig.m = splitValue; + mnConfig.k = kListGm.GetValue(isAllSingleTensor ? 0 : groupIdx); + mnConfig.n = nListGm.GetValue(isAllSingleTensor ? 0 : groupIdx); + return; + } + + if (gmmBaseParams->groupType == 2) { + mnConfig.m = mListGm.GetValue(isAllSingleTensor ? 0 :groupIdx); + mnConfig.k = splitValue; + mnConfig.n = nListGm.GetValue(isAllSingleTensor ? 
0 :groupIdx); + return; + } + + mnConfig.m = mListGm.GetValue(groupIdx); + mnConfig.k = kListGm.GetValue(groupIdx); + mnConfig.n = nListGm.GetValue(groupIdx); + return; +} + +template +__aicore__ inline void GMMProcess::UpdateMnConfig(MNConfig &mnConfig) { + if constexpr (B::format == CubeFormat::NZ) { + mnConfig.wBaseOffset += AlignUp<16>(mnConfig.k) * AlignUp<16>(mnConfig.n); // 16: nz format last two dim size + } else { + mnConfig.wBaseOffset += mnConfig.k * mnConfig.n; + } + mnConfig.nAxisBaseOffset += mnConfig.n; + mnConfig.mAxisBaseOffset += mnConfig.m; + mnConfig.xBaseOffset += mnConfig.m * mnConfig.k; + mnConfig.yBaseOffset += mnConfig.m * mnConfig.n; +} + +template +__aicore__ inline void GMMProcess::MNBlockIdxCompute(MNConfig &mnConfig, const uint32_t curBlock, + const uint32_t count, const uint32_t thresholdM_dimN) { + if (mnConfig.blockDimM <= thresholdDimM || thresholdDimM == 1) { + mnConfig.mIdx = (curBlock - count) / mnConfig.blockDimN; + mnConfig.nIdx = (curBlock - count) % mnConfig.blockDimN; + } else { + uint32_t relativeBlock = curBlock - count; + uint32_t curThresholdM = relativeBlock >= AlignDown(mnConfig.blockDimM * mnConfig.blockDimN, thresholdM_dimN) ? + mnConfig.blockDimM % thresholdBlockNum : thresholdBlockNum; + uint32_t curThresholdM_thresholdN = curThresholdM * thresholdBlockNum; + uint32_t curThresholdN = relativeBlock % thresholdM_dimN >=AlignDown(curThresholdM * mnConfig.blockDimN, + curThresholdM_thresholdN) ? mnConfig.blockDimN % thresholdBlockNum : thresholdBlockNum; + + uint32_t localRelativeBlock = relativeBlock % thresholdM_dimN % curThresholdM_thresholdN; + mnConfig.mIdx = localRelativeBlock % curThresholdM + relativeBlock / thresholdM_dimN * thresholdBlockNum; + mnConfig.nIdx = (localRelativeBlock + localRelativeBlock / + LeastCommonMultiple(curThresholdM, curThresholdN)) % curThresholdN + relativeBlock % + thresholdM_dimN / curThresholdM_thresholdN * thresholdBlockNum; + } +} + +template +__aicore__ inline void GMMProcess::Process() { + MNConfig mnConfig; + if (gmmBaseParams->groupType != -1) { // -1: no split + if (unlikely(groupListPtr == nullptr)) { + return; + } + preOffset = 0; + } + for (uint32_t groupIdx = 0, count = 0; groupIdx < groupNum; ++groupIdx) { + UpdateMnConfig(mnConfig); + int32_t splitValue = GetSplitValueFromGroupList(groupIdx, preOffset, gmmBaseParams, groupListGm); + SetMNConfig(splitValue, groupIdx, mnConfig); + if (mnConfig.m <= 0 || mnConfig.k <= 0 || mnConfig.n <= 0) { + continue; + } + mnConfig.blockDimM = Ceil(mnConfig.m, mnConfig.singleM); + mnConfig.blockDimN = Ceil(mnConfig.n, mnConfig.singleN); + + uint32_t curCount = count + mnConfig.blockDimM * mnConfig.blockDimN; + uint32_t curBlock = coreIdx >= count ? 
coreIdx : coreIdx + gmmBaseParams->coreNum; + uint32_t thresholdM_dimN = thresholdBlockNum * mnConfig.blockDimN; + + while (curBlock < curCount) { + MNBlockIdxCompute(mnConfig, curBlock, count, thresholdM_dimN); + computeOp.MMCompute(groupIdx, mnConfig, coreIdx); + computeOp.VectorCompute(mnConfig); + curBlock += gmmBaseParams->coreNum; + } + count = curCount % gmmBaseParams->coreNum; + } + computeOp.PostCompute(); +} + +/** @brief intenal computation class +*/ +template +class GMMCompute { + public: + using AT = typename mmType::AT::T; + using BT = typename mmType::BT::T; + using B = typename mmType::BT; + using CT = typename mmType::CT::T; + using BiasT = typename mmType::BiasT::T; + using WT = DTYPE_WEIGHT; + constexpr static bool transposeX = mmType::AT::isTrans; + constexpr static bool transposeW = mmType::BT::isTrans; + + /** @brief constructor */ + __aicore__ inline GMMCompute(typename mmType::MT& mm_) : mm(mm_) {} + + __aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR scale, GM_ADDR offset, + GM_ADDR antiquantScale, GM_ADDR antiquantOffset, GM_ADDR groupList, + GM_ADDR perTokenScale, GM_ADDR y, GM_ADDR workspace, + const GMMBaseParams* __restrict gmmBaseParams, + const TCubeTiling* __restrict mmTilingData, TPipe* tPipe); + + __aicore__ inline void MMCompute(uint32_t groupIdx, MNConfig& mnConfig, uint32_t coreIdx); + + __aicore__ inline void VectorCompute(MNConfig& mnConfig) {} + + __aicore__ inline void PostCompute() {} + + protected: + __aicore__ inline void SetGlobalBufferBias(uint32_t groupIdx, uint32_t tailN, const MNConfig mnConfig); + + __aicore__ inline GlobalTensor SetGlobalBufferW(uint32_t groupIdx, uint32_t tailN, MNConfig& mnConfig); + + __aicore__ inline uint64_t SetWOffset(uint32_t tailN, uint32_t k); + + protected: + TPipe* pipe; + typename mmType::MT& mm; // matmul operator + bool hasBias = false; + GM_ADDR xTensorPtr; + GM_ADDR weightTensorPtr; + GM_ADDR biasTensorPtr; + GM_ADDR yTensorPtr; + GlobalTensor xGm; + GlobalTensor weightGm; + GlobalTensor biasGm; + GlobalTensor yGm; +#if defined(GMM_QUANT_INT8) + GM_ADDR scaleTensorPtr; + GlobalTensor scaleGm; +#endif + uint32_t ubBaseN; + uint32_t ubBaseK; + uint32_t ubCalSize; + uint32_t singleWeight; + uint32_t singleX; + uint32_t singleY; + uint32_t coreNum; + uint32_t subBlockIdx; + bool mmWaitStatus; + uint32_t activeType; +}; + +template +__aicore__ inline void GMMCompute::Init(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR scale, + GM_ADDR offset, GM_ADDR antiquantScale, GM_ADDR antiquantOffset, + GM_ADDR groupList, GM_ADDR perTokenScale, GM_ADDR y, + GM_ADDR workspace, const GMMBaseParams* __restrict gmmBaseParams, + const TCubeTiling* __restrict mmTilingData, + TPipe* tPipe) { + xTensorPtr = x; + weightTensorPtr = weight; + biasTensorPtr = bias; + yTensorPtr = y; + pipe = tPipe; + ubBaseN = gmmBaseParams->ubBaseN; + ubBaseK = gmmBaseParams->ubBaseK; + ubCalSize = gmmBaseParams->ubCalSize; + singleWeight = gmmBaseParams->singleWeight; + singleX = gmmBaseParams->singleX; + singleY = gmmBaseParams->singleY; + coreNum = gmmBaseParams->coreNum; + subBlockIdx = GetSubBlockIdx(); + hasBias = mmTilingData->isBias != 0; + activeType = gmmBaseParams->activeType; + mmWaitStatus = false; +#if defined(GMM_QUANT_INT8) + scaleTensorPtr = scale; +#endif +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + TBuf<> ubBuf; + pipe->InitBuffer(ubBuf, TOTAL_UB_SIZE / 2); + LocalTensor buf = ubBuf.template Get(); + mm.SetLocalWorkspace(buf); +#endif +} + +template +__aicore__ inline void 
GMMCompute::SetGlobalBufferBias(uint32_t groupIdx, + uint32_t tailN, const MNConfig mnConfig) { + if (hasBias) { + if (singleWeight == 0) { + biasGm.SetGlobalBuffer(GetTensorAddr(groupIdx, biasTensorPtr)); + } else { + biasGm.SetGlobalBuffer(GetTensorAddr(0, biasTensorPtr) + mnConfig.nAxisBaseOffset); + } + mm.SetBias(biasGm[tailN]); + } +} + +template +__aicore__ inline uint64_t GMMCompute::SetWOffset(uint32_t tailN, uint32_t k) { + uint64_t wOffset = 0; + if constexpr (mmType::BT::format == CubeFormat::NZ && transposeW) { + wOffset = tailN * (UB_BLOCK_UNIT_SIZE / sizeof(BT)); // 32: quant is 32, float16 is 16 + } else if constexpr (mmType::BT::format == CubeFormat::NZ) { + wOffset = tailN * AlignUp<16>(k); // 16: nz format last two dim size + } else if constexpr (transposeW) { + wOffset = tailN * k; + } else { + wOffset = tailN; + } + return wOffset; +} + +template +__aicore__ inline GlobalTensor GMMCompute::SetGlobalBufferW( + uint32_t groupIdx, uint32_t tailN, MNConfig& mnConfig) { + uint64_t wOffset = SetWOffset(tailN, mnConfig.k); +#if defined(GMM_ANTI_QUANT) + return weightGm[transposeW ? mnConfig.workSpaceOffset - tailN + wOffset : mnConfig.workSpaceOffset]; +#else + GlobalTensor weightGmLocal; + if (singleWeight == 0) { + weightGmLocal.SetGlobalBuffer(GetTensorAddr(groupIdx, weightTensorPtr) + wOffset); + } else { + weightGmLocal.SetGlobalBuffer(GetTensorAddr(0, weightTensorPtr) + mnConfig.wBaseOffset + wOffset); + } + #if !(defined(ASCENDC_OOM) && ASCENDC_OOM == 1) + if (mnConfig.blockDimM == 1) { + weightGmLocal.SetL2CacheHint(CacheMode::CACHE_MODE_DISABLE); + } + #endif + return weightGmLocal; +#endif +} + +template +__aicore__ inline void GMMCompute::MMCompute(uint32_t groupIdx, MNConfig& mnConfig, uint32_t coreIdx) { + if (subBlockIdx != 0) { + return; + } + uint32_t tailN = mnConfig.nIdx * mnConfig.singleN; + uint32_t curSingleN = mnConfig.nIdx < mnConfig.blockDimN - 1 ? mnConfig.singleN : mnConfig.n - tailN; + uint32_t curSingleM = mnConfig.mIdx < mnConfig.blockDimM - 1 ? 
mnConfig.singleM + : mnConfig.m - mnConfig.mIdx * mnConfig.singleM; + uint64_t xOffset = mnConfig.mIdx * mnConfig.singleM * mnConfig.k; + if constexpr (transposeX) { + xOffset = mnConfig.mIdx * mnConfig.singleM; + } + uint64_t outOffset = mnConfig.mIdx * mnConfig.singleM * mnConfig.n + tailN; + // init global buffer + if (singleX == 0) { + xGm.SetGlobalBuffer(GetTensorAddr(groupIdx, xTensorPtr)); + } else { + xGm.SetGlobalBuffer(GetTensorAddr(0, xTensorPtr) + mnConfig.xBaseOffset); + } + GlobalTensor weightGmLocal = SetGlobalBufferW(groupIdx, tailN, mnConfig); + mm.SetOrgShape(mnConfig.m, mnConfig.n, mnConfig.k); + mm.SetSingleShape(curSingleM, curSingleN, mnConfig.k); + mm.SetTensorA(xGm[xOffset], transposeX); + mm.SetTensorB(weightGmLocal, transposeW); +#if defined(GMM_QUANT_INT8) + if (singleWeight == 0) { + scaleGm.SetGlobalBuffer(GetTensorAddr(groupIdx, scaleTensorPtr)); + } else { + scaleGm.SetGlobalBuffer(GetTensorAddr(0, scaleTensorPtr) + mnConfig.nAxisBaseOffset); + } + mm.SetQuantVector(scaleGm[tailN]); +#endif + SetGlobalBufferBias(groupIdx, tailN, mnConfig); + if (singleY == 0) { + yGm.SetGlobalBuffer(GetTensorAddr(groupIdx, yTensorPtr)); + } else { + yGm.SetGlobalBuffer(GetTensorAddr(0, yTensorPtr) + mnConfig.yBaseOffset); + } + #if defined(GMM_ANTI_QUANT) + mm.template IterateAll(yGm[outOffset], 0, false, true); + mmWaitStatus = true; + #else + mm.template IterateAll(yGm[outOffset], 0); + #endif +} + +} // namespace GROUPED_MATMUL + +#endif // ASCENDC_GROUPED_MATMUL_H diff --git a/src/transformer/grouped_matmul/grouped_matmul_antiquant.h b/src/transformer/grouped_matmul/grouped_matmul_antiquant.h new file mode 100644 index 00000000..1e077f9f --- /dev/null +++ b/src/transformer/grouped_matmul/grouped_matmul_antiquant.h @@ -0,0 +1,544 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file grouped_matmul_antiquant.h + * \brief + */ +#ifndef ASCENDC_GROUPED_MATMUL_ANTIQUANT_H +#define ASCENDC_GROUPED_MATMUL_ANTIQUANT_H + +#include "grouped_matmul.h" + +#ifdef GMM_ANTI_QUANT +namespace GROUPED_MATMUL { + +constexpr uint32_t CAST_THRESHOLD_CACHE_BIG = 16 * 1024 * 1024; // 16M is obtained by tests +constexpr uint32_t CAST_THRESHOLD_CACHE_SMALL = 10 * 1024 * 1024; // 10M is obtained by tests +constexpr uint32_t CAST_PERFORMANCE_MAX_N = 5120; +constexpr uint32_t CAST_MIN_SINGLE_K = 8; +constexpr int32_t BEST_UB_BASEN = 512; + +/*@brief store variables for core split configuration +*/ +struct CastWeightConfig { + uint32_t coreNum = 0; + uint32_t nUsedCore = 0; + uint32_t curDimN = 0; + uint32_t castRoundIdx = 0; + uint32_t workSpaceIdx = 0; + uint64_t wInNOffset = 0; + uint32_t wInKOffset = 0; + uint32_t curSingleN = 0; + uint32_t curSingleK = 0; + uint32_t tailN = 0; +}; + +/** @brief GroupMatmul Antiquant operator Class +*/ +template +class GMMAntiquantProcess : public GMMProcess{ + protected: + constexpr static bool antiquantPerformance = ComputeType::antiquantPerformanceFlag; + public: + /** @brief constructor */ + __aicore__ inline GMMAntiquantProcess(ComputeType& computeOp_) : GMMProcess(computeOp_) {} + + __aicore__ inline void Process(); + + private: + __aicore__ inline void SetAntiquantMNConfig(const uint64_t singleWorkSpaceSize, const uint32_t curBlock, bool& validCore, + CastWeightConfig& castConfig, MNConfig &mnConfig); + + __aicore__ inline void SetAntiquantCastConfig(uint32_t& curCount, MNConfig mnConfig, + CastWeightConfig& castConfig); + __aicore__ inline void AntiquantUpdateSingleM(MNConfig& mnConfig, uint32_t& dimM, uint32_t dimN); +}; + +template +__aicore__ inline void GMMAntiquantProcess::SetAntiquantMNConfig(const uint64_t singleWorkSpaceSize, + const uint32_t curBlock, bool& validCore, CastWeightConfig& castConfig, MNConfig &mnConfig) { + mnConfig.workSpaceOffset = castConfig.workSpaceIdx * singleWorkSpaceSize; + castConfig.workSpaceIdx = castConfig.workSpaceIdx == 0 ? 1 : 0; // next round use another workspace + castConfig.castRoundIdx = Ceil(curBlock + 1, castConfig.coreNum) - 1; // +1: let curBlock start from 1,-1: castRoundIdx start from 0 + castConfig.curDimN = castConfig.nUsedCore; + if (castConfig.castRoundIdx == Ceil(mnConfig.blockDimN, castConfig.nUsedCore) - 1) { // -1 last round + castConfig.curDimN = mnConfig.blockDimN - castConfig.castRoundIdx * castConfig.nUsedCore; + } + // compute dimM + uint32_t dimM = Max(castConfig.coreNum / castConfig.curDimN, 1); // 1: The minimum value of dimM is 1 + dimM = Min(Ceil(mnConfig.m, this->mmTilingData->baseM), dimM); + mnConfig.singleM = Ceil(mnConfig.m, dimM); + mnConfig.blockDimM = dimM; + mnConfig.mIdx = this->coreIdx / castConfig.curDimN; + mnConfig.nIdx = this->coreIdx % castConfig.curDimN; + validCore = this->coreIdx < dimM * castConfig.curDimN; +} + +template +__aicore__ inline void GMMAntiquantProcess::SetAntiquantCastConfig(uint32_t& curCount, + MNConfig mnConfig, + CastWeightConfig& castConfig) { + if (mnConfig.blockDimM > 0 && mnConfig.blockDimN > 0) { + // 16M and 10M is obtained by tests. When N is greater than 5120, the cache uses 10 MB for better performance + uint32_t cacheThreshold = mnConfig.n > CAST_PERFORMANCE_MAX_N ? CAST_THRESHOLD_CACHE_SMALL : CAST_THRESHOLD_CACHE_BIG; + // 16M/k is the length of N that needs to be calculated for single round. + // 16M/k/baseN is the coreNum required for single round calculation of the N-axis. 
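+ // As a reference, the selection below reduces to this stand-alone sketch (plain C++;
+ // CeilDiv/PickNUsedCore and the sample numbers are illustrative, not kernel symbols):
+ //   #include <algorithm>
+ //   #include <cstdint>
+ //   uint32_t CeilDiv(uint64_t a, uint64_t b) { return static_cast<uint32_t>((a + b - 1) / b); }
+ //   uint32_t PickNUsedCore(uint32_t n, uint64_t k, uint32_t baseN, uint32_t coreNum, uint32_t blockDimN) {
+ //       uint32_t cacheThreshold = n > 5120 ? 10u * 1024 * 1024 : 16u * 1024 * 1024;
+ //       uint32_t nUsedCore = std::min(CeilDiv(cacheThreshold, k * baseN), coreNum);  // cores per cast round
+ //       return std::min(nUsedCore, blockDimN);  // e.g. k=4096, baseN=256, 20 cores -> min(16, 20, blockDimN)
+ //   }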
+ castConfig.nUsedCore = Min(Ceil(cacheThreshold, mnConfig.k * this->mmTilingData->baseN), castConfig.coreNum);
+ castConfig.nUsedCore = Min(castConfig.nUsedCore, mnConfig.blockDimN);
+ curCount = Ceil(mnConfig.blockDimN, castConfig.nUsedCore) * castConfig.coreNum;
+ }
+}
+
+template
+__aicore__ inline void GMMAntiquantProcess::AntiquantUpdateSingleM(MNConfig& mnConfig,
+ uint32_t& dimM, uint32_t dimN) {
+ if (dimM > 1 && dimN < this->gmmBaseParams->coreNum) {
+ uint32_t restCores = this->gmmBaseParams->coreNum / dimN;
+ if (dimM > restCores) {
+ mnConfig.singleM = Ceil(mnConfig.m, restCores);
+ dimM = Ceil(mnConfig.m, mnConfig.singleM);
+ }
+ }
+}
+
+template
+__aicore__ inline void GMMAntiquantProcess::Process() {
+ MNConfig mnConfig;
+ CastWeightConfig castConfig;
+ castConfig.coreNum = this->gmmBaseParams->coreNum;
+ bool validCore = true;
+ uint64_t singleWorkSpaceSize = this->gmmBaseParams->workspaceSize / 2; // 2: antiQuantNormal uses 2 workspace blocks
+ if (this->gmmBaseParams->groupType != -1) { // -1: no need to split
+ this->preOffset = 0;
+ if (unlikely(this->groupListPtr == nullptr)) {this->groupNum = 0;} // do not continue Process
+ }
+ for (uint32_t groupIdx = 0, count = 0; groupIdx < this->groupNum; ++groupIdx) {
+ int32_t splitValue = GetSplitValueFromGroupList(groupIdx, this->preOffset, this->gmmBaseParams, this->groupListGm);
+ this->SetMNConfig(splitValue, groupIdx, mnConfig);
+ uint32_t dimM = Ceil(mnConfig.m, mnConfig.singleM);
+ uint32_t dimN = Ceil(mnConfig.n, mnConfig.singleN);
+ if constexpr (!antiquantPerformance) {
+ AntiquantUpdateSingleM(mnConfig, dimM, dimN);
+ }
+ mnConfig.blockDimM = dimM;
+ mnConfig.blockDimN = dimN;
+ uint32_t curCount = count + dimM * dimN;
+ uint32_t curBlock = this->coreIdx >= count ? this->coreIdx : this->coreIdx + this->gmmBaseParams->coreNum;
+ uint32_t thresholdM_dimN = thresholdBlockNum * dimN;
+
+ if constexpr (antiquantPerformance) {
+ SetAntiquantCastConfig(curCount, mnConfig, castConfig);
+ }
+
+ while (curBlock < curCount) {
+ if constexpr (antiquantPerformance) { // performance version, will split dimN
+ SetAntiquantMNConfig(singleWorkSpaceSize, curBlock, validCore, castConfig, mnConfig);
+ } else {
+ mnConfig.workSpaceOffset = mnConfig.wBaseOffset;
+ this->MNBlockIdxCompute(mnConfig, curBlock, count, thresholdM_dimN);
+ }
+ this->computeOp.PreCompute(groupIdx, this->coreIdx, mnConfig, castConfig);
+ this->computeOp.MMSync();
+ if (validCore) {
+ mnConfig.workSpaceOffset += mnConfig.nIdx * mnConfig.singleN;
+ if constexpr (antiquantPerformance) {
+ mnConfig.nIdx += castConfig.castRoundIdx * castConfig.nUsedCore;
+ }
+ this->computeOp.MMCompute(groupIdx, mnConfig, this->coreIdx);
+ }
+ curBlock += this->gmmBaseParams->coreNum;
+ }
+ this->UpdateMnConfig(mnConfig);
+ count = curCount % this->gmmBaseParams->coreNum;
+ }
+}
+
+
+/** @brief internal computation class
+*/
+template
+class GMMAntiquantCompute : public GMMCompute {
+ public:
+ using AT = typename mmType::AT::T;
+ using BT = typename mmType::BT::T;
+ using B = typename mmType::BT;
+ using CT = typename mmType::CT::T;
+ using BiasT = typename mmType::BiasT::T;
+ using WT = DTYPE_WEIGHT;
+ constexpr static bool transposeX = mmType::AT::isTrans;
+ constexpr static bool transposeW = mmType::BT::isTrans;
+ constexpr static bool antiquantPerformanceFlag = antiquantPerformance;
+
+ __aicore__ inline GMMAntiquantCompute(typename mmType::MT& mm_) : GMMCompute(mm_) {}
+
+ __aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR scale,
+ GM_ADDR offset, GM_ADDR
antiquantScale, GM_ADDR antiquantOffset, GM_ADDR groupList, GM_ADDR perTokenScale,
+ GM_ADDR y, GM_ADDR workspace, const GMMBaseParams* __restrict gmmBaseParams,
+ const TCubeTiling* __restrict mmTilingData, TPipe* tPipe);
+
+ __aicore__ inline void PreCompute(uint32_t groupIdx,
+ uint32_t coreIdx, MNConfig& mnConfig, CastWeightConfig& castConfig);
+
+ __aicore__ inline void MMSync();
+
+ private:
+
+ __aicore__ inline void CastWeightProcess(MNConfig& mnConfig, CastWeightConfig& castConfig);
+ __aicore__ inline void SetAntiQuantGlobalBuffer(uint32_t groupIdx, const MNConfig mnConfig);
+ __aicore__ inline void SetGmToUbDataCopyParams(const uint32_t curBaseN, const uint32_t curBaseK,
+ const MNConfig& mnConfig, DataCopyExtParams& intriParams);
+ __aicore__ inline void SetUbToGmDataCopyParams(const uint32_t curBaseN, const uint32_t alignRowLen,
+ const uint32_t curBaseK, const MNConfig& mnConfig,
+ DataCopyExtParams& intriParams);
+ __aicore__ inline void CastWeightCompute(uint32_t curCalcK, uint32_t curCalcAlignN);
+ __aicore__ inline void DataCopyScaleAndOffset(uint32_t curBaseN, uint32_t alignBaseN,
+ uint64_t realScaleOffset);
+ __aicore__ inline void DataCopyScale(uint32_t curBaseN, uint32_t alignBaseN, uint64_t scaleOffset);
+ __aicore__ inline void DataCopyPerTokenScale(uint32_t curBaseM, uint64_t perTokenScaleOffset);
+ __aicore__ inline void PerTokenDequant(uint32_t curBaseM, uint32_t alignBaseN);
+ __aicore__ inline void SetPerTokenQuantRefreshedBuffer(const MNConfig mnConfig);
+ __aicore__ inline void ComputeUbBaseK(uint32_t curSingleK, uint32_t offsetK, uint32_t newBaseK,
+ uint32_t& curUsedGroupSize, uint32_t& curBaseK);
+ __aicore__ inline void FreeScaleAndOffset(bool& firstLoop);
+
+ GlobalTensor weightAntiQuantGm;
+ GM_ADDR antiScaleTensorPtr;
+ GM_ADDR antiOffsetTensorPtr;
+ LocalTensor scaleInUb;
+ LocalTensor offsetInUb;
+ GlobalTensor antiScaleGM;
+ GlobalTensor antiOffsetGM;
+ // define the que
+ TQue vecInQueue;
+ TQue vecOutQueue;
+ TQue scaleInQueue;
+ TQue offsetInQueue;
+ TBuf tmpBuff;
+ LocalTensor tmpUb;
+ bool isPerGroup = false;
+ uint32_t perGroupSize;
+};
+
+template
+__aicore__ inline void
+GMMAntiquantCompute::Init(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR scale,
+ GM_ADDR offset, GM_ADDR antiquantScale, GM_ADDR antiquantOffset, GM_ADDR groupList, GM_ADDR perTokenScale,
+ GM_ADDR y, GM_ADDR workspace, const GMMBaseParams* __restrict gmmBaseParams,
+ const TCubeTiling* __restrict mmTilingData, TPipe* tPipe) {
+ this->GMMCompute::Init(x, weight, bias, scale, offset, antiquantScale, antiquantOffset, groupList,
+ perTokenScale, y, workspace, gmmBaseParams, mmTilingData, tPipe);
+ antiScaleTensorPtr = antiquantScale;
+ antiOffsetTensorPtr = antiquantOffset;
+ perGroupSize = gmmBaseParams->quantParam;
+ isPerGroup = perGroupSize > 0;
+ this->weightGm.SetGlobalBuffer((__gm__ BT*)workspace);
+ uint32_t maxUbBaseN = BEST_UB_BASEN;
+ if constexpr (transposeW) {
+ maxUbBaseN = this->ubBaseN;
+ }
+ // scale buffer should be bigger than singleN; 32-byte alignment is required
+ this->pipe->InitBuffer(scaleInQueue, 2, maxUbBaseN * sizeof(BT));
+ this->pipe->InitBuffer(offsetInQueue, 2, maxUbBaseN * sizeof(BT));
+ this->pipe->InitBuffer(vecInQueue, 2, this->ubCalSize * GetTypeBits() / INT8_BITS);
+ this->pipe->InitBuffer(vecOutQueue, 2, this->ubCalSize * sizeof(BT));
+ this->pipe->InitBuffer(tmpBuff, gmmBaseParams->ubRestBytes);
+ tmpUb = tmpBuff.Get();
+}
+
+template
+__aicore__ inline void GMMAntiquantCompute::PreCompute(uint32_t groupIdx,
+ uint32_t coreIdx, MNConfig& mnConfig,
CastWeightConfig& castConfig) {
+ if constexpr (!antiquantPerformance) {
+ if (this->subBlockIdx != 0) {
+ return;
+ }
+ }
+ castConfig.curSingleN = 0;
+ castConfig.curSingleK = 0;
+ castConfig.wInKOffset = 0;
+ castConfig.wInNOffset = 0;
+ mnConfig.wOutOffset = mnConfig.workSpaceOffset;
+ castConfig.tailN = 0;
+ if constexpr (antiquantPerformance) { // antiquant performance version
+ uint32_t blockDimK = Min(this->coreNum, Ceil(mnConfig.k, CAST_MIN_SINGLE_K));
+ if (coreIdx >= blockDimK) { return; }
+ castConfig.curSingleK = Ceil(mnConfig.k, blockDimK);
+ castConfig.tailN = castConfig.castRoundIdx * castConfig.nUsedCore * mnConfig.singleN;
+ castConfig.wInNOffset = castConfig.tailN;
+ castConfig.wInKOffset = coreIdx * castConfig.curSingleK;
+ if (coreIdx == blockDimK - 1) { // -1: last dimK
+ castConfig.curSingleK = mnConfig.k - castConfig.curSingleK * coreIdx;
+ }
+ mnConfig.wOutOffset += castConfig.wInKOffset * mnConfig.n;
+ castConfig.curSingleN = castConfig.curDimN * mnConfig.singleN;
+ if (castConfig.castRoundIdx == Ceil(mnConfig.blockDimN, castConfig.nUsedCore) - 1) { // -1: last round
+ castConfig.curSingleN = mnConfig.n - castConfig.castRoundIdx * castConfig.nUsedCore * mnConfig.singleN;
+ }
+ } else { // antiquant generalized version
+ castConfig.curSingleN = mnConfig.singleN;
+ castConfig.curSingleK = mnConfig.k;
+ castConfig.tailN = mnConfig.nIdx * mnConfig.singleN;
+ castConfig.wInNOffset = this->transposeW ? castConfig.tailN * mnConfig.k : castConfig.tailN;
+ mnConfig.wOutOffset += castConfig.wInNOffset;
+ if (mnConfig.nIdx == mnConfig.blockDimN - 1) {
+ castConfig.curSingleN = mnConfig.n - mnConfig.nIdx * mnConfig.singleN;
+ }
+ }
+ SetAntiQuantGlobalBuffer(groupIdx, mnConfig);
+ CastWeightProcess(mnConfig, castConfig);
+}
+
+template
+__aicore__ inline void GMMAntiquantCompute::MMSync() {
+ if (this->mmWaitStatus) {
+ this->mm.WaitIterateAll();
+ this->mmWaitStatus = false;
+ }
+ if constexpr (antiquantPerformance) {
+ SyncAll();
+ }
+}
+
+template
+__aicore__ inline void
+GMMAntiquantCompute::SetAntiQuantGlobalBuffer(uint32_t groupIdx,
+ const MNConfig mnConfig) {
+ if (this->singleWeight == 0) {
+ weightAntiQuantGm.SetGlobalBuffer(GetTensorAddr(groupIdx, this->weightTensorPtr));
+ antiScaleGM.SetGlobalBuffer(GetTensorAddr(groupIdx, antiScaleTensorPtr));
+ antiOffsetGM.SetGlobalBuffer(GetTensorAddr(groupIdx, antiOffsetTensorPtr));
+ } else {
+ weightAntiQuantGm.SetGlobalBuffer(GetTensorAddr(0, this->weightTensorPtr) + mnConfig.wBaseOffset * GetTypeBits() / INT8_BITS);
+ uint64_t antiquantParamsOffset = mnConfig.nAxisBaseOffset;
+ if (isPerGroup) {
+ antiquantParamsOffset *= (mnConfig.k / perGroupSize);
+ }
+ antiScaleGM.SetGlobalBuffer(GetTensorAddr(0, antiScaleTensorPtr) + antiquantParamsOffset);
+ antiOffsetGM.SetGlobalBuffer(GetTensorAddr(0, antiOffsetTensorPtr) + antiquantParamsOffset);
+ }
+}
+
+
+template
+__aicore__ inline void GMMAntiquantCompute::ComputeUbBaseK(
+ uint32_t curSingleK, uint32_t offsetK, uint32_t newBaseK, uint32_t& curUsedGroupSize, uint32_t& curBaseK) {
+ if (unlikely(offsetK + newBaseK >= curUsedGroupSize)) {
+ curBaseK = curUsedGroupSize - offsetK;
+ curUsedGroupSize += perGroupSize;
+ if (offsetK + curBaseK > curSingleK) {
+ curBaseK = curSingleK - offsetK;
+ }
+ } else if (unlikely(offsetK + newBaseK > curSingleK)) {
+ curBaseK = curSingleK - offsetK;
+ } else {
+ curBaseK = newBaseK;
+ }
+}
+
+
+template
+__aicore__ inline void GMMAntiquantCompute::FreeScaleAndOffset(bool& firstLoop) {
+ if (firstLoop) {
+ firstLoop = false;
+ } else { +
scaleInQueue.FreeTensor(scaleInUb); + offsetInQueue.FreeTensor(offsetInUb); + } +} + +template +__aicore__ inline void GMMAntiquantCompute::CastWeightProcess( + MNConfig& mnConfig, CastWeightConfig& castConfig) { + uint64_t wInOffset = castConfig.wInNOffset + static_cast(castConfig.wInKOffset) * mnConfig.n; + const uint32_t& curSingleK = castConfig.curSingleK; + const uint32_t& curSingleN = castConfig.curSingleN; + const uint32_t& scaleOffset = castConfig.tailN; + uint32_t newBaseK = this->ubBaseK; + uint32_t newBaseN = this->ubBaseN; + uint32_t usedGroupSize = mnConfig.k; + if (isPerGroup) { + newBaseK = Min(this->ubBaseK, perGroupSize); + if (!transposeW && newBaseK < perGroupSize && newBaseK > perGroupSize / 2 && mnConfig.n % newBaseN != 0) { + uint32_t tempUbBaseN = AlignDown(this->ubBaseK * this->ubBaseN / Ceil(perGroupSize, 2), 32); // 32:a factor + // ubBaseN cannot be larger than BEST_UB_BASEN, due to offset/scale queue size + if (tempUbBaseN <= BEST_UB_BASEN && mnConfig.n % tempUbBaseN == 0) { + newBaseK = Ceil(perGroupSize, 2); + newBaseN = tempUbBaseN; + } + } + usedGroupSize = perGroupSize + AlignDown(castConfig.wInKOffset, perGroupSize); + } + DataCopyPadExtParams padParams; + for (uint32_t offsetN(0), curBaseN(newBaseN), nCount(0); offsetN < curSingleN; offsetN += newBaseN) { + if (unlikely(offsetN + newBaseN > curSingleN)) { + curBaseN = curSingleN - offsetN; + } + uint32_t alignBaseN = AlignUp(curBaseN, UB_BLOCK_UNIT_SIZE * INT8_BITS / GetTypeBits()); + if (!isPerGroup) { + DataCopyScaleAndOffset(curBaseN, alignBaseN, scaleOffset + offsetN); + } + uint32_t curBaseK = newBaseK; + uint32_t curUsedGroupSize = usedGroupSize - castConfig.wInKOffset; + bool firstKLoop = true; + int32_t prePergroupIdx = -1; + int32_t curPergroupIdx = 0; + for (uint32_t offsetK(0), subCoreCount(nCount); offsetK < curSingleK; offsetK += curBaseK) { + ComputeUbBaseK(curSingleK, offsetK, newBaseK, curUsedGroupSize, curBaseK); + if constexpr (antiquantPerformance) { + if (this->subBlockIdx == (++subCoreCount) % 2) { // 2: two vectors + continue; + } + } + if (isPerGroup) { + curPergroupIdx = (offsetK + castConfig.wInKOffset) / perGroupSize; + if (firstKLoop || curPergroupIdx > prePergroupIdx) { // load new group + FreeScaleAndOffset(firstKLoop); + DataCopyScaleAndOffset(curBaseN, alignBaseN, scaleOffset + offsetN + curPergroupIdx * mnConfig.n); + prePergroupIdx = curPergroupIdx; + } + } + LocalTensor inLocal = vecInQueue.AllocTensor(); + DataCopyExtParams gmToUbIntriParams; + SetGmToUbDataCopyParams(curBaseN, curBaseK, mnConfig, gmToUbIntriParams); + uint64_t weightInOffset = transposeW ? offsetK + static_cast(offsetN) * mnConfig.k : + static_cast(offsetK) * mnConfig.n + offsetN; + DataCopyPad(inLocal, weightAntiQuantGm[(weightInOffset + wInOffset) * GetTypeBits() / INT8_BITS], gmToUbIntriParams, padParams); + vecInQueue.EnQue(inLocal); + + DataCopyExtParams ubToGmIntriParams; + if constexpr (transposeW) { + uint32_t alignBaseK = AlignUp(curBaseK, UB_BLOCK_UNIT_SIZE * INT8_BITS / GetTypeBits()); + CastWeightCompute(alignBaseK, alignBaseN); + SetUbToGmDataCopyParams(curBaseN, alignBaseK, curBaseK, mnConfig, ubToGmIntriParams); + } else { + CastWeightCompute(curBaseK, alignBaseN); + SetUbToGmDataCopyParams(curBaseN, alignBaseN, curBaseK, mnConfig, ubToGmIntriParams); + } + + // ResultCopy2GM + LocalTensor wResUb = vecOutQueue.DeQue(); + uint64_t weightOutOffset = transposeW ? 
mnConfig.wOutOffset + offsetK + offsetN * mnConfig.k :
+ mnConfig.wOutOffset + offsetK * mnConfig.n + offsetN;
+ DataCopyPad(this->weightGm[weightOutOffset], wResUb, ubToGmIntriParams);
+ vecOutQueue.FreeTensor(wResUb);
+ }
+ nCount = nCount == 0 ? 1 : 0;
+ if (!(isPerGroup && firstKLoop)) {
+ scaleInQueue.FreeTensor(scaleInUb);
+ offsetInQueue.FreeTensor(offsetInUb);
+ }
+ }
+
+ event_t eventIdMTE3ToS = static_cast(this->pipe->FetchEventID(HardEvent::MTE3_S));
+ SetFlag(eventIdMTE3ToS);
+ WaitFlag(eventIdMTE3ToS);
+}
+
+template
+__aicore__ inline void
+GMMAntiquantCompute::CastWeightCompute(uint32_t curCalcK, uint32_t curCalcAlignN) {
+ LocalTensor wInUb = vecInQueue.DeQue();
+ wInUb.SetSize(curCalcK * curCalcAlignN);
+ LocalTensor wResUb = vecOutQueue.AllocTensor();
+ LocalTensor tmpLocal = tmpUb.template ReinterpretCast();
+
+ AntiQuantShapeInfo shapeInfo;
+ if constexpr (transposeW) {
+ shapeInfo.offsetHeight = curCalcAlignN;
+ shapeInfo.offsetWidth = 1;
+ shapeInfo.scaleHeight = curCalcAlignN;
+ shapeInfo.scaleWidth = 1;
+ event_t eventId = static_cast(this->pipe->FetchEventID(HardEvent::MTE2_S));
+ SetFlag(eventId);
+ WaitFlag(eventId);
+ } else {
+ shapeInfo.offsetHeight = 1;
+ shapeInfo.offsetWidth = curCalcAlignN;
+ shapeInfo.scaleHeight = 1;
+ shapeInfo.scaleWidth = curCalcAlignN;
+ }
+ // fp16 tempbuff is 0, bf16 tempbuff = offset.GetSize() * 2 * sizeof(float) + 64 * K * sizeof(float)
+ AscendAntiQuant(wResUb, wInUb, offsetInUb, scaleInUb, tmpLocal, curCalcK, shapeInfo);
+
+ vecInQueue.FreeTensor(wInUb);
+ vecOutQueue.EnQue(wResUb);
+}
+
+template
+__aicore__ inline void
+GMMAntiquantCompute::SetGmToUbDataCopyParams(const uint32_t curBaseN,
+ const uint32_t curBaseK, const MNConfig& mnConfig, DataCopyExtParams& intriParams) {
+ if constexpr (transposeW) {
+ intriParams.blockLen = Ceil(curBaseK * GetTypeBits(), INT8_BITS);
+ intriParams.blockCount = curBaseN;
+ intriParams.srcStride = Ceil((mnConfig.k - curBaseK) * GetTypeBits(), INT8_BITS);
+ intriParams.dstStride = 0;
+ } else {
+ intriParams.blockLen = Ceil(curBaseN * GetTypeBits(), INT8_BITS);
+ intriParams.blockCount = curBaseK;
+ intriParams.srcStride = Ceil((mnConfig.n - curBaseN) * GetTypeBits(), INT8_BITS);
+ intriParams.dstStride = 0;
+ }
+}
+
+template
+__aicore__ inline void
+GMMAntiquantCompute::SetUbToGmDataCopyParams(const uint32_t curBaseN,
+ const uint32_t alignRowLen, const uint32_t curBaseK, const MNConfig& mnConfig, DataCopyExtParams& intriParams) {
+ if constexpr (transposeW) {
+ uint32_t alignBaseK = AlignUp(curBaseK, UB_BLOCK_UNIT_SIZE);
+ intriParams.blockLen = curBaseK * sizeof(BT);
+ intriParams.blockCount = curBaseN;
+ intriParams.srcStride = (alignRowLen - curBaseK) / (UB_BLOCK_UNIT_SIZE / sizeof(BT));
+ intriParams.dstStride = (mnConfig.k - curBaseK) * sizeof(BT);
+ } else {
+ intriParams.blockLen = curBaseN * sizeof(BT);
+ intriParams.blockCount = curBaseK;
+ intriParams.srcStride = (alignRowLen - curBaseN) / (UB_BLOCK_UNIT_SIZE / sizeof(BT));
+ intriParams.dstStride = (mnConfig.n - curBaseN) * sizeof(BT);
+ }
+}
+
+template
+__aicore__ inline void
+GMMAntiquantCompute::DataCopyScaleAndOffset(uint32_t curBaseN, uint32_t alignBaseN,
+ uint64_t realScaleOffset) {
+ // copy scale and offset from GM
+ DataCopyPadParams padParams;
+ DataCopyParams scaleParams;
+ scaleParams.blockLen = curBaseN * sizeof(BT);
+ scaleParams.blockCount = 1;
+ scaleParams.srcStride = 0;
+ scaleParams.dstStride = 0;
+ LocalTensor scaleLocal = scaleInQueue.AllocTensor();
+ DataCopyPad(scaleLocal, antiScaleGM[realScaleOffset],
scaleParams, padParams); + scaleInQueue.EnQue(scaleLocal); + + LocalTensor offsetLocal = offsetInQueue.AllocTensor(); + DataCopyPad(offsetLocal, antiOffsetGM[realScaleOffset], scaleParams, padParams); + offsetInQueue.EnQue(offsetLocal); + + scaleInUb = scaleInQueue.DeQue(); + scaleInUb.SetSize(alignBaseN); + offsetInUb = offsetInQueue.DeQue(); + offsetInUb.SetSize(alignBaseN); +} + +template +using GMMAntiquantComputePerformance = GMMAntiquantCompute; + +template +using GMMAntiquantComputeNorm = GMMAntiquantCompute; + +} // namespace GROUPED_MATMUL + +#endif // GMM_ANTI_QUANT +#endif // ASCENDC_GROUPED_MATMUL_ANTIQUANT_H diff --git a/src/transformer/grouped_matmul/grouped_matmul_antiquant_a16w8_msd.h b/src/transformer/grouped_matmul/grouped_matmul_antiquant_a16w8_msd.h new file mode 100644 index 00000000..aa445adb --- /dev/null +++ b/src/transformer/grouped_matmul/grouped_matmul_antiquant_a16w8_msd.h @@ -0,0 +1,950 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file grouped_matmul_antiquant_a16w8_msd.h + * \brief + */ +#ifndef ASCENDC_GROUPED_MATMUL_ANTIQUANT_A16W8_MSD_H +#define ASCENDC_GROUPED_MATMUL_ANTIQUANT_A16W8_MSD_H + +#include "grouped_matmul_utils.h" +#include "grouped_matmul.h" + + +#if defined(GMM_ANTI_QUANT) && defined(ORIG_DTYPE_WEIGHT) && defined(DT_INT8) && \ + ORIG_DTYPE_WEIGHT == DT_INT8 +namespace GROUPED_MATMUL { +static constexpr uint32_t A16W8_MSD_STEP = 2; +static constexpr uint32_t A16W8_MSD_PREPROCESS_MAX_GROUP = 12; +static constexpr uint32_t FACTOR_FOR_FLOAT_ALIGN_TO_32 = 8; +static constexpr uint32_t ALIGN_UB_BASE_K = 128; +static constexpr uint32_t POST_SKIP_ITER_NUM = 3; + +struct PreMNConfig { + uint32_t m = 0; + uint32_t k = 0; + uint32_t baseM = 0; + uint32_t baseK = 0; + uint32_t mIdx = 0; + uint32_t kIdx = 0; + uint32_t blockDimM = 0; + uint32_t blockDimK = 0; + uint32_t singleM = 0; + uint32_t singleMTail = 0; + uint64_t mAxisBaseOffset = 0; +}; + +struct PreBaseMNConfig { + uint32_t m = 0; + uint32_t k = 0; + uint64_t mAxisBaseOffset = 0; +}; + +/** @brief GroupMatmul operator Class +*/ +template +class GMMA16W8MSDProcess{ + protected: + using B = typename ComputeType::B; + ComputeType& computeOp; // internal computation operator + const GMMBaseParams* __restrict gmmBaseParams; + const TCubeTiling* __restrict mmTilingData; + + uint32_t blockIdx; + uint32_t coreIdx; + uint32_t groupNum; + uint32_t coreNum; + uint32_t ubCalSize; + int32_t preOffset; + int32_t preOffsetPre; + GM_ADDR groupListPtr; + GlobalTensor groupListGm; + GlobalTensor kListGm; + GlobalTensor nListGm; + + public: + /** @brief constructor */ + __aicore__ inline GMMA16W8MSDProcess(ComputeType& computeOp_) : computeOp(computeOp_) {} + + __aicore__ inline void Init(const GMMBaseParams* __restrict gmmBaseParamsIn, + const TCubeTiling* __restrict mmTilingDataIn, GM_ADDR groupList, GM_ADDR tiling); + + __aicore__ inline void Process(); + + private: + __aicore__ inline void PreProcess(PreBaseMNConfig &preBaseMNConfig, 
MNConfig &mnConfig,
+ uint32_t &preGroupIdx, uint32_t &preCoreCount, bool &isPreRequired);
+
+ __aicore__ inline void TailProcess(MNConfig &mnConfig, uint32_t secondHalfIterCount);
+
+ __aicore__ inline void SetMNConfigs(PreBaseMNConfig &preBaseMNConfig, MNConfig &mnConfig);
+
+ __aicore__ inline void UpdateMnConfig(MNConfig &mnConfig);
+};
+
+template
+ __aicore__ inline void GMMA16W8MSDProcess::Init(const GMMBaseParams* __restrict gmmBaseParamsIn,
+ const TCubeTiling* __restrict mmTilingDataIn, GM_ADDR groupList, GM_ADDR tiling) {
+ blockIdx = GetBlockIdx();
+ coreIdx = blockIdx;
+ int64_t coreRation = GetTaskRation();
+ if (coreRation > 1) {
+ coreIdx /= coreRation;
+ }
+ gmmBaseParams = gmmBaseParamsIn;
+ mmTilingData = mmTilingDataIn;
+ ubCalSize = gmmBaseParams->ubCalSize;
+ groupNum = gmmBaseParams->groupNum;
+ coreNum = gmmBaseParams->coreNum;
+ groupListPtr = groupList;
+ preOffset = 0;
+ preOffsetPre = 0;
+ if (groupListPtr != nullptr) {
+ groupListGm.SetGlobalBuffer((__gm__ int64_t*)groupList);
+ }
+ GET_TILING_DATA_MEMBER_ADDR(GMMTilingData, gmmArray, gmmArrayAddr, tiling); // custom macro
+ kListGm.SetGlobalBuffer((__gm__ int32_t*)(gmmArrayAddr + sizeof(int32_t) * MKN_LIST_LEN));
+ nListGm.SetGlobalBuffer((__gm__ int32_t*)(gmmArrayAddr + sizeof(int32_t) * MKN_LIST_LEN * 2));
+}
+
+template
+__aicore__ inline void GMMA16W8MSDProcess::SetMNConfigs(
+ PreBaseMNConfig &preBaseMNConfig, MNConfig &mnConfig) {
+ preBaseMNConfig.k = kListGm.GetValue(0);
+
+ mnConfig.k = preBaseMNConfig.k;
+ mnConfig.n = nListGm.GetValue(0);
+ mnConfig.baseM = mmTilingData->baseM;
+ mnConfig.baseN = mmTilingData->baseN;
+ mnConfig.singleM = mnConfig.baseM;
+ // 2048: smallest n for which singleN 1024 is enabled;
+ // 2: according to experiments, when n / k >= 2, singleN 1024 has better performance
+ // 1024: a larger singleN reduces the number of preprocess iterations and SyncAll calls, giving better performance
+ // 4: when the requirements above are satisfied, singleN is a quarter of n, aligned up to 1024
+ mnConfig.singleN = mnConfig.n >= 2048 && mnConfig.n / mnConfig.k >= 2 ?
+ 1024 * Ceil(mnConfig.n / 4, 1024) : mnConfig.baseN;
+ mnConfig.singleN = mnConfig.singleN <= ubCalSize ?
mnConfig.singleN : ubCalSize; +} + +template +__aicore__ inline void GMMA16W8MSDProcess::UpdateMnConfig(MNConfig &mnConfig) { + if constexpr (B::format == CubeFormat::NZ) { + mnConfig.wBaseOffset += AlignUp<16>(mnConfig.k) * AlignUp<16>(mnConfig.n); // 16: nz format last two dim size + } else { + mnConfig.wBaseOffset += mnConfig.k * mnConfig.n; + } + mnConfig.mAxisBaseOffset += mnConfig.m; + mnConfig.nAxisBaseOffset += mnConfig.n; + mnConfig.xBaseOffset += mnConfig.m * mnConfig.k; + mnConfig.yBaseOffset += mnConfig.m * mnConfig.n; +} + +template +__aicore__ inline void GMMA16W8MSDProcess::PreProcess( + PreBaseMNConfig &preBaseMNConfig, MNConfig &mnConfig, uint32_t &preGroupIdx, uint32_t &preCoreCount, + bool &isPreRequired) { + PreBaseMNConfig preBaseMNConfigs[A16W8_MSD_PREPROCESS_MAX_GROUP]; + uint32_t preValidGroupCount = 0; + while (preCoreCount < coreNum && preValidGroupCount < A16W8_MSD_PREPROCESS_MAX_GROUP && preGroupIdx < groupNum) { + preBaseMNConfig.mAxisBaseOffset += preBaseMNConfig.m; + preBaseMNConfig.m = GetSplitValueFromGroupList(preGroupIdx, preOffsetPre, gmmBaseParams, groupListGm); + preGroupIdx++; + if (preBaseMNConfig.m <= 0) { + continue; + } + preBaseMNConfigs[preValidGroupCount] = preBaseMNConfig; + preValidGroupCount++; + preCoreCount += Ceil(A16W8_MSD_STEP * preBaseMNConfig.m, mnConfig.singleM) * + Ceil(mnConfig.n, mnConfig.singleN); + } + if (preValidGroupCount == 0) { + isPreRequired = false; + return; + } + computeOp.PreProcess(preBaseMNConfigs, preValidGroupCount, mmTilingData->baseM / 2); + preCoreCount = preCoreCount % coreNum; +} + +template +__aicore__ inline void GMMA16W8MSDProcess::TailProcess(MNConfig &mnConfig, uint32_t secondHalfIterCount) { + uint32_t resPostLoop = POST_SKIP_ITER_NUM; + if (secondHalfIterCount < POST_SKIP_ITER_NUM) { + resPostLoop = secondHalfIterCount; + secondHalfIterCount = POST_SKIP_ITER_NUM; + } + for (uint32_t resIdx = 0; resIdx < resPostLoop; ++resIdx) { + computeOp.PostProcess(mnConfig, true, secondHalfIterCount); + secondHalfIterCount++; + } +} + +template +__aicore__ inline void GMMA16W8MSDProcess::Process() { + PreBaseMNConfig preBaseMNConfig; + MNConfig mnConfig; + uint32_t preValidGroupCount = 0; + uint32_t preGroupIdx = 0; + bool isPreRequired = false; + uint32_t secondHalfIterCount = 0; + SetMNConfigs(preBaseMNConfig, mnConfig); + if (mnConfig.k <= 0 || mnConfig.n <= 0) { + return; + } + for (uint32_t groupIdx(0), count(0), curBlock(0), curCount(0), preCoreCount(0); + groupIdx < groupNum; ++groupIdx) { + isPreRequired = preGroupIdx == groupIdx; + if (isPreRequired) { + PreProcess(preBaseMNConfig, mnConfig, preGroupIdx, preCoreCount, isPreRequired); + } + if (groupIdx > 0) { + UpdateMnConfig(mnConfig); + } + mnConfig.m = GetSplitValueFromGroupList(groupIdx, preOffset, gmmBaseParams, groupListGm); + if ASCEND_IS_AIC { + if (isPreRequired) { + CrossCoreWaitFlag(SYNC_AIV_AIC_FLAG); + } + } + if (mnConfig.m <= 0) { + continue; + } + mnConfig.blockDimM = Ceil(A16W8_MSD_STEP * mnConfig.m, mnConfig.singleM); + mnConfig.blockDimN = Ceil(mnConfig.n, mnConfig.singleN); + curCount = count + mnConfig.blockDimM * mnConfig.blockDimN; + curBlock = coreIdx >= count ? 
coreIdx : coreIdx + coreNum;
+ while (curBlock < curCount) {
+ mnConfig.mIdx = (curBlock - count) / mnConfig.blockDimN;
+ mnConfig.nIdx = (curBlock - count) % mnConfig.blockDimN;
+ computeOp.MMCompute(mnConfig);
+ computeOp.PostProcess(mnConfig, false, secondHalfIterCount);
+ secondHalfIterCount++;
+ curBlock += coreNum;
+ }
+ count = curCount % coreNum;
+ }
+ TailProcess(mnConfig, secondHalfIterCount);
+}
+
+/** @brief internal computation class
+*/
+template
+class GMMA16W8MSDCompute {
+ public:
+ using AT = typename mmType::AT::T;
+ using BT = typename mmType::BT::T;
+ using B = typename mmType::BT;
+ using CT = typename mmType::CT::T;
+ using WT = DTYPE_WEIGHT;
+ constexpr static bool transposeX = mmType::AT::isTrans;
+ constexpr static bool transposeW = mmType::BT::isTrans;
+ /** @brief constructor */
+ __aicore__ inline GMMA16W8MSDCompute(typename mmType::MT& mm_) : mm(mm_) {}
+
+ __aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR scale, GM_ADDR offset,
+ GM_ADDR antiquantScale, GM_ADDR antiquantOffset, GM_ADDR group_list,
+ GM_ADDR perTokenScale, GM_ADDR y, GM_ADDR workspace,
+ const GMMBaseParams* __restrict gmmBaseParams,
+ const TCubeTiling* __restrict mmTilingData, TPipe* tPipe);
+
+ __aicore__ inline void MMCompute(MNConfig& mnConfig);
+
+ __aicore__ inline void PreProcess(PreBaseMNConfig *preBaseMNConfigs, uint32_t preValidGroupCount,
+ uint32_t maxMPerGroup);
+
+ __aicore__ inline void PostProcess(MNConfig& mnConfig, bool isLastGroup, uint32_t secondHalfIterCount);
+
+ private:
+ __aicore__ inline void InitLocalTensor();
+
+ __aicore__ inline void InitWorkspace(GM_ADDR workspace);
+
+ __aicore__ inline void PreProcessTiling(uint32_t m, uint32_t curCoreNum, uint32_t startCoreIdx,
+ PreBaseMNConfig &preBaseMNConfig, PreMNConfig &preMNConfig);
+
+ __aicore__ inline void PreProcessCalc(uint32_t curCoreNum, uint32_t startCoreIdx, uint32_t &resSyncCount,
+ PreMNConfig &preMNConfig);
+
+ __aicore__ inline void PreProcessSync(uint32_t preValidGroupCount, uint32_t &resSyncCount);
+
+ __aicore__ inline void CopyOriginInput(uint32_t k, uint32_t curBaseM, uint32_t curBaseK, uint64_t xGmOffset);
+
+ __aicore__ inline void CalcReduceSum(uint32_t curBaseM, uint32_t curBaseK, uint64_t gmReduceSumOffset);
+
+ __aicore__ inline void CalcAMax(uint32_t curBaseM, uint32_t curBaseK, uint64_t gmReduceMaxOffset);
+
+ __aicore__ inline void CopyInAmax(uint32_t curBaseM, uint64_t gmReduceMaxOffset);
+
+ __aicore__ inline void CalcAMatrix(PreMNConfig &preMNConfig, uint32_t curBaseM, uint32_t curBaseK,
+ uint64_t gmReduceMaxOffset, uint64_t aOffsetGm);
+
+ __aicore__ inline void CalcASum(MNConfig& postMNConfig, uint32_t curBaseM, uint32_t curBaseN, uint32_t offsetM,
+ uint64_t offsetAndScaleOffset);
+
+ __aicore__ inline void ProcessScaleAndBias(uint32_t n, uint32_t curBaseN, uint64_t offsetAndScaleOffset);
+
+ __aicore__ inline void ProcessC1C2(MNConfig& postMNConfig, uint32_t curBaseM, uint32_t curBaseN, uint32_t offsetM,
+ uint32_t curSingleM);
+
+ __aicore__ inline void CalcCMatrix(MNConfig& postMNConfig, uint32_t curBaseM, uint32_t curBaseN, uint32_t offsetM);
+
+ __aicore__ inline void CopyOutFinalResult(uint32_t n, uint32_t curBaseM, uint32_t curBaseN, uint64_t yOffset);
+
+ __aicore__ inline GlobalTensor SetGlobalBufferW(uint32_t tailN, MNConfig& mnConfig);
+
+ __aicore__ inline uint64_t SetWOffset(uint32_t tailN, uint32_t k);
+
+ TPipe* pipe;
+ typename mmType::MT& mm; // matmul operator
+ bool hasBias = false;
+ GM_ADDR weightTensorPtr;
+ GlobalTensor xGm; +
GlobalTensor biasGm; + GlobalTensor scaleGm; + GlobalTensor offsetGm; + GlobalTensor mmOutGm; + GlobalTensor aMatrixGm; + GlobalTensor globalMaxGm; + GlobalTensor localSumGm; + GlobalTensor yGm; + + // define the que + TQue vecInQueue; + TQue vecOutQueue; + TQue ReduceResultInQueue; + TBuf tmpBuff; + // LocalTensor used in stage1 (preprocess) + LocalTensor s1MiddleResult1; + LocalTensor s1MiddleResult2; + LocalTensor s1TmpBuf; + LocalTensor s1A1A2FP16; + // LocalTensor used in stage2 and stage3 (postprocess) + LocalTensor s23MiddleResult1; + LocalTensor s23MiddleResult2; + LocalTensor s23MiddleResult3; + LocalTensor cTmp; + LocalTensor processedScale; + LocalTensor processedBias; + LocalTensor globalReduceSum; + LocalTensor s23TmpBuf; + LocalTensor aMaxInUb; + + uint32_t cubeBaseM; + uint32_t ubCalSizeS1; + uint32_t ubCalSizeS2; + uint32_t coreNum; + uint32_t totalM; + uint32_t aicIdx; + uint32_t aivIdx; + uint32_t ubRestBytes; + uint32_t aMatrixSize; + MNConfig mnConfigs[POST_SKIP_ITER_NUM + 1]; +}; + +template +__aicore__ inline void GMMA16W8MSDCompute::Init(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, + GM_ADDR scale, GM_ADDR offset, GM_ADDR antiquantScale, + GM_ADDR antiquantOffset, GM_ADDR group_list, + GM_ADDR perTokenScale, GM_ADDR y, GM_ADDR workspace, + const GMMBaseParams* __restrict gmmBaseParams, + const TCubeTiling* __restrict mmTilingData, + TPipe* tPipe) { + weightTensorPtr = weight; + pipe = tPipe; + cubeBaseM = mmTilingData->baseM; + ubCalSizeS1 = 2 * gmmBaseParams->ubCalSize; + ubCalSizeS2 = gmmBaseParams->ubCalSize; + totalM = gmmBaseParams->m; + coreNum = gmmBaseParams->coreNum; + ubRestBytes = gmmBaseParams->ubRestBytes; + aMatrixSize = gmmBaseParams->workspaceSize; + hasBias = gmmBaseParams->hasBias == 1; + aicIdx = GetBlockIdx() / GetTaskRation(); + aivIdx = GetBlockIdx(); + + xGm.SetGlobalBuffer(GetTensorAddr(0, x)); + scaleGm.SetGlobalBuffer(GetTensorAddr(0, antiquantScale)); + offsetGm.SetGlobalBuffer(GetTensorAddr(0, antiquantOffset)); + yGm.SetGlobalBuffer(GetTensorAddr(0, y)); + if (hasBias) { + biasGm.SetGlobalBuffer(GetTensorAddr(0, bias)); + } + InitLocalTensor(); + InitWorkspace(workspace); +} + +template +__aicore__ inline void GMMA16W8MSDCompute::InitLocalTensor() { + if ASCEND_IS_AIC { + return; + } + uint32_t alignedCoreNum = AlignUp<8>(coreNum); + pipe->InitBuffer(vecInQueue, 1, ubCalSizeS1 * sizeof(half)); + pipe->InitBuffer(vecOutQueue, 1, ubCalSizeS1 * sizeof(int8_t)); + pipe->InitBuffer(ReduceResultInQueue, 1, cubeBaseM / 2 * alignedCoreNum * sizeof(float)); + pipe->InitBuffer(tmpBuff, ubRestBytes); + uint32_t s1TmpUbOffset = 0; + uint32_t s23TmpUbOffset = 0; + // local tensor for stage1 + s1MiddleResult1 = tmpBuff.GetWithOffset(ubCalSizeS1, 0); + s1TmpUbOffset += ubCalSizeS1 * sizeof(float); + s1MiddleResult2 = tmpBuff.GetWithOffset(ubCalSizeS1, s1TmpUbOffset); + s1TmpUbOffset += ubCalSizeS1 * sizeof(float); + s1TmpBuf = tmpBuff.GetWithOffset(ubCalSizeS1, s1TmpUbOffset); + s1A1A2FP16 = tmpBuff.GetWithOffset(ubCalSizeS1, s1TmpUbOffset); + // local tensor for stage2 and stage3 + s23MiddleResult1 = tmpBuff.GetWithOffset(ubCalSizeS2, 0); + s23TmpUbOffset += ubCalSizeS2 * sizeof(float); + cTmp = tmpBuff.GetWithOffset(ubCalSizeS2, s23TmpUbOffset); + s23TmpUbOffset += ubCalSizeS2 * sizeof(float); + processedScale = tmpBuff.GetWithOffset(ubCalSizeS2, s23TmpUbOffset); + s23TmpUbOffset += ubCalSizeS2 * sizeof(float); + processedBias = tmpBuff.GetWithOffset(ubCalSizeS2, s23TmpUbOffset); + s23TmpUbOffset += ubCalSizeS2 * sizeof(float); + globalReduceSum = 
tmpBuff.GetWithOffset(ubCalSizeS2, s23TmpUbOffset);
+ s23MiddleResult2 = tmpBuff.GetWithOffset(ubCalSizeS2, s23TmpUbOffset);
+ s23TmpUbOffset += ubCalSizeS2 * sizeof(float);
+ s23MiddleResult3 = tmpBuff.GetWithOffset(ubCalSizeS2, s23TmpUbOffset);
+}
+
+template
+__aicore__ inline void GMMA16W8MSDCompute::InitWorkspace(GM_ADDR workspace) {
+ uint32_t usedWorkspaceSize = 0;
+ globalMaxGm.SetGlobalBuffer((__gm__ float *)(workspace));
+ if (aivIdx == 0) { // 0: use aiv 0 to init gm for amax
+ InitOutput(globalMaxGm, totalM * FACTOR_FOR_FLOAT_ALIGN_TO_32);
+ }
+ usedWorkspaceSize += totalM * sizeof(float) * FACTOR_FOR_FLOAT_ALIGN_TO_32;
+ localSumGm.SetGlobalBuffer((__gm__ float *)(workspace + usedWorkspaceSize));
+ if (aivIdx == 1) { // 1: use aiv 1 to init gm for asum
+ InitOutput(localSumGm, totalM * coreNum);
+ }
+ usedWorkspaceSize += totalM * coreNum * sizeof(float);
+ aMatrixGm.SetGlobalBuffer((__gm__ int8_t *)(workspace + usedWorkspaceSize));
+ usedWorkspaceSize += aMatrixSize * sizeof(int8_t);
+ mmOutGm.SetGlobalBuffer((__gm__ int32_t *)(workspace + usedWorkspaceSize));
+ if ASCEND_IS_AIV {
+ SyncAll();
+ }
+}
+
+template
+__aicore__ inline uint64_t GMMA16W8MSDCompute::SetWOffset(uint32_t tailN, uint32_t k) {
+ uint64_t wOffset = 0;
+ if constexpr (mmType::BT::format == CubeFormat::NZ && transposeW) {
+ wOffset = tailN * (UB_BLOCK_UNIT_SIZE / sizeof(BT)); // 32: quant is 32, float16 is 16
+ } else if constexpr (mmType::BT::format == CubeFormat::NZ) {
+ wOffset = tailN * AlignUp<16>(k); // 16: nz format last two dim size
+ } else if constexpr (transposeW) {
+ wOffset = k * tailN;
+ } else {
+ wOffset = tailN;
+ }
+ return wOffset;
+}
+
+template
+__aicore__ inline GlobalTensor GMMA16W8MSDCompute::SetGlobalBufferW(
+ uint32_t tailN, MNConfig& mnConfig) {
+ uint64_t wOffset = SetWOffset(tailN, mnConfig.k);
+ GlobalTensor weightGmLocal;
+ weightGmLocal.SetGlobalBuffer(GetTensorAddr(0, weightTensorPtr) + mnConfig.wBaseOffset + wOffset);
+ if (mnConfig.blockDimM == 1) {
+ weightGmLocal.SetL2CacheHint(CacheMode::CACHE_MODE_DISABLE);
+ }
+ return weightGmLocal;
+}
+
+template
+__aicore__ inline void GMMA16W8MSDCompute::PreProcess(
+ PreBaseMNConfig *preBaseMNConfigs, uint32_t preValidGroupCount, uint32_t maxMPerGroup) {
+ if ASCEND_IS_AIC {
+ return;
+ }
+ PreMNConfig preMNConfig;
+ uint32_t usedCoreNum = 0; // start index of the next core to allocate
+ uint32_t curCoreNum = 0;
+ uint32_t tokenNumEachPreIter = 0;
+ uint32_t splitGroupNum = 0;
+ // due to the limitation on matmul baseM, preprocess data blocks cannot be split by groupIdx alone. If m of
+ // a group is larger than half of matmul baseM, the group is split further in preprocess, so the group num
+ // after splitting must be obtained first.
+ for (uint32_t gIdx = 0; gIdx < preValidGroupCount; ++gIdx) {
+ // ensure each group after splitting has at least one core during preprocess.
+ splitGroupNum += Ceil(preBaseMNConfigs[gIdx].m, maxMPerGroup);
+ tokenNumEachPreIter += preBaseMNConfigs[gIdx].m;
+ }
+ // num of spare cores to allocate across the groups
+ uint32_t unAllocatedCoreNum = splitGroupNum <= coreNum ? coreNum - splitGroupNum : 0;
+ uint32_t resSyncCount = Ceil(splitGroupNum, coreNum);
+ uint32_t resTokenNum = tokenNumEachPreIter; // num of tokens not yet assigned to a core
+ for (uint32_t gIdx = 0; gIdx < preValidGroupCount; ++gIdx) {
+ uint32_t curGroupResTokenNum = preBaseMNConfigs[gIdx].m;
+ // if m of the group is larger than half of matmul baseM, preprocess the data step by step;
+ // each step accepts half of matmul baseM * k of data.
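+ // The loop below amounts to proportional sharing: every block gets one guaranteed core plus a
+ // share of the spare cores proportional to its token count. A stand-alone sketch (plain C++,
+ // floor division; AllocCores and the sample sizes are illustrative, not kernel symbols):
+ //   #include <cstdint>
+ //   #include <vector>
+ //   std::vector<uint32_t> AllocCores(const std::vector<uint32_t>& tokens, uint32_t spareCores) {
+ //       uint32_t resTokens = 0;
+ //       for (uint32_t t : tokens) { resTokens += t; }
+ //       std::vector<uint32_t> cores;
+ //       for (uint32_t t : tokens) {
+ //           uint32_t c = spareCores * t / resTokens + 1;  // +1: the guaranteed core
+ //           cores.push_back(c);
+ //           spareCores -= c - 1;
+ //           resTokens -= t;
+ //       }
+ //       return cores;  // e.g. tokens {96, 32, 32} with 13 spare cores -> {8, 4, 4}
+ //   }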
+ while (curGroupResTokenNum > maxMPerGroup) { + curCoreNum = Ceil(unAllocatedCoreNum * maxMPerGroup, resTokenNum) + 1; + PreProcessTiling(maxMPerGroup, curCoreNum, usedCoreNum, preBaseMNConfigs[gIdx], preMNConfig); + PreProcessCalc(curCoreNum, usedCoreNum, resSyncCount, preMNConfig); + unAllocatedCoreNum -= (curCoreNum - 1); + resTokenNum -= maxMPerGroup; + usedCoreNum = (usedCoreNum + curCoreNum) % coreNum; + curGroupResTokenNum -= maxMPerGroup; + preBaseMNConfigs[gIdx].mAxisBaseOffset += maxMPerGroup; + } + curCoreNum = (unAllocatedCoreNum * curGroupResTokenNum / resTokenNum) + 1; + PreProcessTiling(curGroupResTokenNum, curCoreNum, usedCoreNum, preBaseMNConfigs[gIdx], preMNConfig); + PreProcessCalc(curCoreNum, usedCoreNum, resSyncCount, preMNConfig); + unAllocatedCoreNum -= (curCoreNum - 1); + resTokenNum -= curGroupResTokenNum; + usedCoreNum = (usedCoreNum + curCoreNum) % coreNum; + } + PreProcessSync(preValidGroupCount, resSyncCount); +} + +template +__aicore__ inline void GMMA16W8MSDCompute::PreProcessSync( + uint32_t preValidGroupCount, uint32_t &resSyncCount) { + while (resSyncCount > 0) { + SyncAll(); + resSyncCount -= 1; + } + SyncAll(); + CrossCoreSetFlag(SYNC_AIV_AIC_FLAG); +} + +template +__aicore__ inline void GMMA16W8MSDCompute::PreProcessTiling( + uint32_t m, uint32_t curCoreNum, uint32_t startCoreIdx, + PreBaseMNConfig &preBaseMNConfig, PreMNConfig &preMNConfig) { + if (aivIdx < startCoreIdx || aivIdx >= curCoreNum + startCoreIdx) { + return; + } + preMNConfig.m = m; + preMNConfig.k = preBaseMNConfig.k; + preMNConfig.mAxisBaseOffset = preBaseMNConfig.mAxisBaseOffset; + if (m < curCoreNum) { + preMNConfig.baseK = preMNConfig.k / curCoreNum; + preMNConfig.baseK = AlignUp(preMNConfig.baseK, ALIGN_UB_BASE_K); + preMNConfig.blockDimK = Ceil(preMNConfig.k, preMNConfig.baseK); + preMNConfig.blockDimM = curCoreNum / preMNConfig.blockDimK; + } else { + preMNConfig.baseK = preMNConfig.k ; + preMNConfig.blockDimK = 1; + preMNConfig.blockDimM = curCoreNum; + } + preMNConfig.singleM = Ceil(m, preMNConfig.blockDimM); + preMNConfig.blockDimM = Ceil(m, preMNConfig.singleM); // prevent wrapping when calc singleMTail + preMNConfig.singleMTail = m - (preMNConfig.blockDimM - 1) * preMNConfig.singleM; + preMNConfig.baseM = ubCalSizeS1 / preMNConfig.baseK; + preMNConfig.baseM = preMNConfig.baseM < preMNConfig.singleM ? preMNConfig.baseM : preMNConfig.singleM; + preMNConfig.mIdx = (aivIdx - startCoreIdx) / preMNConfig.blockDimK; + preMNConfig.kIdx = (aivIdx - startCoreIdx) % preMNConfig.blockDimK; +} + +template +__aicore__ inline void GMMA16W8MSDCompute::PreProcessCalc( + uint32_t curCoreNum, uint32_t startCoreIdx, uint32_t &resSyncCount, PreMNConfig &preMNConfig) { + if (aivIdx < startCoreIdx || aivIdx >= startCoreIdx + curCoreNum) { + return; + } + if (aivIdx < startCoreIdx + curCoreNum && aivIdx >= startCoreIdx + preMNConfig.blockDimM * preMNConfig.blockDimK) { + return; + } + uint32_t curBaseK = preMNConfig.kIdx < preMNConfig.blockDimK - 1 ? + preMNConfig.baseK : preMNConfig.k - preMNConfig.kIdx * preMNConfig.baseK; + uint32_t curBaseM = preMNConfig.baseM; + uint32_t curSingleM = preMNConfig.mIdx < preMNConfig.blockDimM - 1 ? 
+ preMNConfig.singleM : preMNConfig.m - preMNConfig.mIdx * preMNConfig.singleM;
+ for (uint32_t offsetM = 0; offsetM < curSingleM; offsetM += preMNConfig.baseM) {
+ if (offsetM + preMNConfig.baseM >= curSingleM) {
+ curBaseM = curSingleM - offsetM;
+ }
+ uint64_t offsetBase = preMNConfig.mAxisBaseOffset + preMNConfig.mIdx * preMNConfig.singleM + offsetM;
+ uint64_t xGmOffset = offsetBase * preMNConfig.k + preMNConfig.kIdx * preMNConfig.baseK;
+ CopyOriginInput(preMNConfig.k, curBaseM, curBaseK, xGmOffset);
+ uint64_t gmReduceMaxOffset = offsetBase * FACTOR_FOR_FLOAT_ALIGN_TO_32;
+ uint64_t gmReduceSumOffset = offsetBase * coreNum + (aivIdx - startCoreIdx);
+ CalcAMax(curBaseM, curBaseK, gmReduceMaxOffset);
+ CalcReduceSum(curBaseM, curBaseK, gmReduceSumOffset);
+ }
+ SyncAll();
+ resSyncCount -= 1;
+ curBaseM = preMNConfig.baseM;
+ for (uint32_t offsetM = 0; offsetM < curSingleM; offsetM += preMNConfig.baseM) {
+ if (offsetM + preMNConfig.baseM >= curSingleM) {
+ curBaseM = curSingleM - offsetM;
+ }
+ uint64_t offsetBase = preMNConfig.mAxisBaseOffset + preMNConfig.mIdx * preMNConfig.singleM + offsetM;
+ uint64_t xGmOffset = offsetBase * preMNConfig.k + preMNConfig.kIdx * preMNConfig.baseK;
+ uint64_t gmReduceMaxOffset = offsetBase * FACTOR_FOR_FLOAT_ALIGN_TO_32;
+ uint64_t aOffsetGm =
+ (preMNConfig.mAxisBaseOffset * 2 + preMNConfig.mIdx * preMNConfig.singleM + offsetM) * preMNConfig.k +
+ preMNConfig.kIdx * preMNConfig.baseK;
+ CopyOriginInput(preMNConfig.k, curBaseM, curBaseK, xGmOffset);
+ CalcAMatrix(preMNConfig, curBaseM, curBaseK, gmReduceMaxOffset, aOffsetGm);
+ }
+}
+
+template
+__aicore__ inline void GMMA16W8MSDCompute::CopyOriginInput(
+ uint32_t k, uint32_t curBaseM, uint32_t curBaseK, uint64_t xGmOffset) {
+ uint32_t alignedBaseK = AlignUp<32>(curBaseK);
+ LocalTensor xLocal = vecInQueue.AllocTensor();
+ DataCopyPad2D(xLocal, xGm[xGmOffset], curBaseM, curBaseK, k);
+ vecInQueue.EnQue(xLocal);
+ LocalTensor xFP16InUb = vecInQueue.DeQue();
+ Cast(s1MiddleResult1, xFP16InUb, RoundMode::CAST_NONE, curBaseM * alignedBaseK);
+ PipeBarrier();
+ vecInQueue.FreeTensor(xFP16InUb);
+ }
+
+template
+__aicore__ inline void GMMA16W8MSDCompute::CalcReduceSum(
+ uint32_t curBaseM, uint32_t curBaseK, uint64_t gmReduceSumOffset) {
+ uint32_t alignedBaseK = AlignUp<32>(curBaseK);
+ LocalTensor blockReduceSumInUb = vecOutQueue.AllocTensor();
+ for (uint32_t idxM = 0; idxM < curBaseM; ++idxM) {
+ ReduceSum(blockReduceSumInUb[idxM * FACTOR_FOR_FLOAT_ALIGN_TO_32], s1MiddleResult1[idxM * alignedBaseK],
+ s1TmpBuf[idxM * alignedBaseK], curBaseK);
+ }
+ PipeBarrier();
+ vecOutQueue.EnQue(blockReduceSumInUb);
+ LocalTensor blockReduceSum = vecOutQueue.DeQue();
+ DataCopyExtParams aSumOutParams;
+ aSumOutParams.blockLen = sizeof(float);
+ aSumOutParams.blockCount = curBaseM;
+ aSumOutParams.srcStride = 0;
+ aSumOutParams.dstStride = (coreNum - 1) * sizeof(float);
+ DataCopyPad(localSumGm[gmReduceSumOffset], blockReduceSum, aSumOutParams);
+ vecOutQueue.FreeTensor(blockReduceSum);
+}
+
+template
+__aicore__ inline void GMMA16W8MSDCompute::CalcAMax(
+ uint32_t curBaseM, uint32_t curBaseK, uint64_t gmReduceMaxOffset) {
+ uint32_t alignedBaseK = AlignUp<32>(curBaseK);
+ Abs(s1MiddleResult2, s1MiddleResult1, curBaseM * alignedBaseK);
+ PipeBarrier();
+
+ // compute the per-row ReduceMax
+ LocalTensor blockReduceMaxInUb = vecOutQueue.AllocTensor();
+ for (uint32_t idxM = 0; idxM < curBaseM; ++idxM) {
+ ReduceMax(blockReduceMaxInUb[idxM * FACTOR_FOR_FLOAT_ALIGN_TO_32], s1MiddleResult2[idxM * alignedBaseK],
+ s1TmpBuf[idxM *
alignedBaseK], curBaseK, false); + } + PipeBarrier(); + vecOutQueue.EnQue(blockReduceMaxInUb); + LocalTensor blockReduceMax = vecOutQueue.DeQue(); + SetAtomicMax(); + DataCopyExtParams aMaxOutParams; + aMaxOutParams.blockLen = FACTOR_FOR_FLOAT_ALIGN_TO_32 * sizeof(float); + aMaxOutParams.blockCount = curBaseM; + aMaxOutParams.srcStride = 0; + aMaxOutParams.dstStride = 0; + DataCopyPad(globalMaxGm[gmReduceMaxOffset], blockReduceMax, aMaxOutParams); + SetAtomicNone(); + vecOutQueue.FreeTensor(blockReduceMax); +} + +template +__aicore__ inline void GMMA16W8MSDCompute::CopyInAmax(uint32_t curBaseM, uint64_t gmReduceMaxOffset) { + // copy amax from gm + LocalTensor aMaxLocal = ReduceResultInQueue.AllocTensor(); + DataCopyPadExtParams padParams; + DataCopyExtParams aMaxInParams; + aMaxInParams.blockLen = FACTOR_FOR_FLOAT_ALIGN_TO_32 * sizeof(float); + aMaxInParams.blockCount = curBaseM; + aMaxInParams.srcStride = 0; + aMaxInParams.dstStride = 0; + DataCopyPad(aMaxLocal, globalMaxGm[gmReduceMaxOffset], aMaxInParams, padParams); + ReduceResultInQueue.EnQue(aMaxLocal); + aMaxInUb = ReduceResultInQueue.DeQue(); +} + +template +__aicore__ inline void GMMA16W8MSDCompute::CalcAMatrix( + PreMNConfig &preMNConfig, uint32_t curBaseM, uint32_t curBaseK, uint64_t gmReduceMaxOffset, uint64_t aOffsetGm) { + uint32_t alignedBaseK = AlignUp<32>(curBaseK); + CopyInAmax(curBaseM, gmReduceMaxOffset); + event_t eventIdMTE2ToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_S)); + SetFlag(eventIdMTE2ToS); + WaitFlag(eventIdMTE2ToS); + // calc a_tmp = 127 * x / amax for each row + for (uint32_t idxM = 0; idxM < curBaseM; ++idxM) { + float invertAMaxPerRow = 127.0f / aMaxInUb(idxM * FACTOR_FOR_FLOAT_ALIGN_TO_32); + event_t eventIdSToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); + SetFlag(eventIdSToV); + WaitFlag(eventIdSToV); + Muls(s1MiddleResult2[idxM * alignedBaseK], s1MiddleResult1[idxM * alignedBaseK], invertAMaxPerRow, + alignedBaseK); // a_tmp + } + PipeBarrier(); + ReduceResultInQueue.FreeTensor(aMaxInUb); + // calc a1 + LocalTensor a1Int8InUb = vecOutQueue.AllocTensor(); + + Cast(s1MiddleResult1, s1MiddleResult2, RoundMode::CAST_ROUND, curBaseM * alignedBaseK); // a1 + PipeBarrier(); + Cast(s1A1A2FP16, s1MiddleResult1, RoundMode::CAST_NONE, curBaseM * alignedBaseK); + PipeBarrier(); + Cast(a1Int8InUb, s1A1A2FP16, RoundMode::CAST_NONE, curBaseM * alignedBaseK); + vecOutQueue.EnQue(a1Int8InUb); + LocalTensor a1Int8 = vecOutQueue.DeQue(); + DataCopyPad2D(aMatrixGm[aOffsetGm], a1Int8, curBaseM, curBaseK, alignedBaseK, preMNConfig.k); + + // calc a2 + PipeBarrier(); + Sub(s1TmpBuf, s1MiddleResult2, s1MiddleResult1, curBaseM * alignedBaseK); // a_tmp - a1 + PipeBarrier(); + Muls(s1MiddleResult1, s1TmpBuf, static_cast(254), curBaseM * alignedBaseK); // 254 * (a_tmp - a1) + PipeBarrier(); + Cast(s1MiddleResult2, s1MiddleResult1, RoundMode::CAST_ROUND, curBaseM * alignedBaseK); // a2 + PipeBarrier(); + Cast(s1A1A2FP16, s1MiddleResult2, RoundMode::CAST_NONE, curBaseM * alignedBaseK); + vecOutQueue.FreeTensor(a1Int8); + LocalTensor a2Int8InUb = vecOutQueue.AllocTensor(); + PipeBarrier(); + Cast(a2Int8InUb, s1A1A2FP16, RoundMode::CAST_NONE, curBaseM * alignedBaseK); + vecOutQueue.EnQue(a2Int8InUb); + LocalTensor a2Int8 = vecOutQueue.DeQue(); + DataCopyPad2D(aMatrixGm[aOffsetGm + preMNConfig.m * preMNConfig.k], a2Int8, + curBaseM, curBaseK, alignedBaseK, preMNConfig.k); + vecOutQueue.FreeTensor(a2Int8); +} + +template +__aicore__ inline void GMMA16W8MSDCompute::MMCompute(MNConfig& mnConfig) { + 
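// C1 = A1*W and C2 = A2*W come from a single matmul because PreProcess stacked the two int8
+ // slices along M (rows [0, m) hold A1, rows [m, 2m) hold A2). The per-element decomposition can
+ // be checked stand-alone (plain C++; values illustrative, nearbyint approximating CAST_ROUND):
+ //   #include <cmath>
+ //   #include <cstdio>
+ //   int main() {
+ //       float x = 0.7131f, amax = 1.9f;                   // one activation and its row max
+ //       float aTmp = 127.0f * x / amax;
+ //       float a1 = std::nearbyint(aTmp);                  // first int8 slice
+ //       float a2 = std::nearbyint(254.0f * (aTmp - a1));  // residual slice
+ //       std::printf("%f ~ %f\n", x, (a1 + a2 / 254.0f) * amax / 127.0f);
+ //       return 0;
+ //   }
+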
uint32_t tailN = mnConfig.nIdx * mnConfig.singleN; + uint64_t outOffset = mnConfig.mIdx * mnConfig.singleM * mnConfig.n + tailN; + + mnConfig.workSpaceOffset = outOffset + A16W8_MSD_STEP * mnConfig.yBaseOffset; + if ASCEND_IS_AIC { + uint32_t curSingleN = mnConfig.nIdx < mnConfig.blockDimN - 1 ? mnConfig.singleN : mnConfig.n - tailN; + uint32_t curSingleM = mnConfig.mIdx < mnConfig.blockDimM - 1 ? mnConfig.singleM : + A16W8_MSD_STEP * mnConfig.m - mnConfig.mIdx * mnConfig.singleM; + uint64_t xOffset = mnConfig.mIdx * mnConfig.singleM * mnConfig.k + A16W8_MSD_STEP * mnConfig.xBaseOffset; + // init global buffer + GlobalTensor weightGm = SetGlobalBufferW(tailN, mnConfig); + mm.SetOrgShape(A16W8_MSD_STEP * mnConfig.m, mnConfig.n, mnConfig.k); + mm.SetSingleShape(curSingleM, curSingleN, mnConfig.k); + mm.SetTensorA(aMatrixGm[xOffset], transposeX); + mm.SetTensorB(weightGm, transposeW); + while (mm.Iterate()) { + mm.GetTensorC(mmOutGm[mnConfig.workSpaceOffset]); + } + CrossCoreSetFlag(SYNC_AIC_AIV_FLAG); + } +} + +template +__aicore__ inline void GMMA16W8MSDCompute::PostProcess( + MNConfig& mnConfig, bool isLastGroup, uint32_t secondHalfIterCount) { + if ASCEND_IS_AIC { + return; + } + if (!isLastGroup) { + mnConfigs[secondHalfIterCount % (POST_SKIP_ITER_NUM + 1)] = mnConfig; + if (secondHalfIterCount < POST_SKIP_ITER_NUM) { + return; + } + } + MNConfig postMNConfig = mnConfigs[(secondHalfIterCount - POST_SKIP_ITER_NUM) % (POST_SKIP_ITER_NUM + 1)]; + uint32_t tailN = postMNConfig.nIdx * postMNConfig.singleN; + uint32_t curCubeSingleM = postMNConfig.mIdx < postMNConfig.blockDimM - 1 ? + postMNConfig.singleM : A16W8_MSD_STEP * postMNConfig.m - postMNConfig.mIdx * postMNConfig.singleM; + uint32_t curBaseN = postMNConfig.nIdx < postMNConfig.blockDimN - 1 ? + postMNConfig.singleN : postMNConfig.n - tailN; + uint32_t alignedBaseN = AlignUp<32>(curBaseN); + uint32_t curBaseM = ubCalSizeS2 / alignedBaseN; + uint32_t curSingleM = curCubeSingleM / 2; + curBaseM = curBaseM < curSingleM ? 
curBaseM : curSingleM; + postMNConfig.singleM /= 2; + + uint64_t offsetAndScaleOffset = postMNConfig.nAxisBaseOffset + tailN; + ProcessScaleAndBias(postMNConfig.n, curBaseN, offsetAndScaleOffset); + for (uint32_t offsetM = 0; offsetM < curSingleM; offsetM += curBaseM) { + if (offsetM + curBaseM >= curSingleM) { + curBaseM = curSingleM - offsetM; + } + CalcASum(postMNConfig, curBaseM, curBaseN, offsetM, offsetAndScaleOffset); + if (offsetM == 0) { // only first iter need to wait for cube + CrossCoreWaitFlag(SYNC_AIC_AIV_FLAG); + } + ProcessC1C2(postMNConfig, curBaseM, curBaseN, offsetM, curSingleM); + CalcCMatrix(postMNConfig, curBaseM, curBaseN, offsetM); + uint64_t yOffset = (postMNConfig.mIdx * postMNConfig.singleM + offsetM) * postMNConfig.n + \ + postMNConfig.nIdx * postMNConfig.singleN + postMNConfig.yBaseOffset; + CopyOutFinalResult(postMNConfig.n, curBaseM, curBaseN, yOffset); + } +} + +template +__aicore__ inline void GMMA16W8MSDCompute::CalcASum( + MNConfig& postMNConfig, uint32_t curBaseM, uint32_t curBaseN, uint32_t offsetM, uint64_t offsetAndScaleOffset) { + uint32_t alignedBaseN = AlignUp<32>(curBaseN); + uint32_t alignedCoreNum = AlignUp<8>(coreNum); + // process offset + LocalTensor offsetF16 = vecInQueue.AllocTensor(); + DataCopyPad2D(offsetF16, offsetGm[offsetAndScaleOffset], 1, curBaseN, postMNConfig.n); + vecInQueue.EnQue(offsetF16); + LocalTensor offsetF16InUb = vecInQueue.DeQue(); + Cast(s23MiddleResult1, offsetF16InUb, RoundMode::CAST_NONE, alignedBaseN); + PipeBarrier(); + vecInQueue.FreeTensor(offsetF16InUb); + + // calc global sum and mul with offset + LocalTensor localSum = ReduceResultInQueue.AllocTensor(); + DataCopyExtParams params; + params.blockCount = curBaseM; + params.blockLen = coreNum * sizeof(float); + params.srcStride = 0; + params.dstStride = 0; + + DataCopyPadExtParams padParams; + padParams.isPad = true; + padParams.rightPadding = alignedCoreNum - coreNum; + padParams.leftPadding = 0; + padParams.paddingValue = 0; + uint32_t gmReduceSumOffset = (postMNConfig.mAxisBaseOffset + postMNConfig.mIdx * postMNConfig.singleM + offsetM) * + coreNum; + DataCopyPad(localSum, localSumGm[gmReduceSumOffset], params, padParams); + ReduceResultInQueue.EnQue(localSum); + LocalTensor localSumInUb = ReduceResultInQueue.DeQue(); + for (uint32_t idxM = 0; idxM < curBaseM; ++idxM) { + ReduceSum(globalReduceSum[idxM * FACTOR_FOR_FLOAT_ALIGN_TO_32], localSumInUb[idxM * alignedCoreNum], + s23MiddleResult3[idxM * alignedBaseN], alignedCoreNum); + event_t eventIdVToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); + SetFlag(eventIdVToS); + WaitFlag(eventIdVToS); + float aSumPerRow = globalReduceSum.GetValue(idxM * FACTOR_FOR_FLOAT_ALIGN_TO_32); + event_t eventIdSToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); + SetFlag(eventIdSToV); + WaitFlag(eventIdSToV); + Muls(cTmp[idxM * alignedBaseN], s23MiddleResult1, aSumPerRow, alignedBaseN); // c_tmp = offset * asum + } + PipeBarrier(); + ReduceResultInQueue.FreeTensor(localSumInUb); +} + +template +__aicore__ inline void GMMA16W8MSDCompute::ProcessScaleAndBias( + uint32_t n, uint32_t curBaseN, uint64_t offsetAndScaleOffset) { + uint32_t alignedBaseN = AlignUp<32>(curBaseN); + LocalTensor scaleF16 = vecInQueue.AllocTensor(); + DataCopyPad2D(scaleF16, scaleGm[offsetAndScaleOffset], 1, curBaseN, n); + vecInQueue.EnQue(scaleF16); + LocalTensor scaleF16InUb = vecInQueue.DeQue(); + Cast(processedScale, scaleF16InUb, RoundMode::CAST_NONE, alignedBaseN); + PipeBarrier(); + vecInQueue.FreeTensor(scaleF16InUb); + 
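// Scale is widened to fp32 because the final dequant applies it per channel on top of the
+ // integer GEMM result: y = ((c1 + c2/254) * amax/127 + asum * offset) * scale + bias, which
+ // matches x*w for w = (w_int8 + offset) * scale. A scalar sketch (plain C++, names illustrative):
+ //   float wInt = -3.0f, offs = 1.5f, scal = 0.02f, x = 0.8f;
+ //   float viaInt = (x * wInt + x * offs) * scal;  // integer path plus the asum*offset term
+ //   float direct = x * ((wInt + offs) * scal);    // dequantize-then-multiply path
+ //   // viaInt and direct agree up to float rounding
+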
if (hasBias) { + #if ORIG_DTYPE_X == DT_FLOAT16 + LocalTensor biasF16 = vecInQueue.AllocTensor(); + DataCopyPad2D(biasF16, biasGm[offsetAndScaleOffset], 1, curBaseN, n); + vecInQueue.EnQue(biasF16); + LocalTensor biasInUb = vecInQueue.DeQue(); + Cast(processedBias, biasInUb, RoundMode::CAST_NONE, alignedBaseN); + vecInQueue.FreeTensor(biasInUb); + #else + event_t eventIdVToMte2 = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2)); + SetFlag(eventIdVToMte2); + WaitFlag(eventIdVToMte2); + DataCopyPad2D(processedBias, biasGm[offsetAndScaleOffset], 1, curBaseN, n); + event_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V)); + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + #endif + } +} + +template +__aicore__ inline void GMMA16W8MSDCompute::ProcessC1C2( + MNConfig& postMNConfig, uint32_t curBaseM, uint32_t curBaseN, uint32_t offsetM, uint32_t curSingleM) { + uint32_t alignedBaseN = AlignUp<32>(curBaseN); + LocalTensor c2S32 = vecInQueue.AllocTensor(); + uint64_t c2Offset = postMNConfig.workSpaceOffset + (curSingleM + offsetM) * postMNConfig.n; + DataCopyPad2D(c2S32, mmOutGm[c2Offset], curBaseM, curBaseN, postMNConfig.n); + vecInQueue.EnQue(c2S32); + LocalTensor c2S32InUb = vecInQueue.DeQue(); + Cast(s23MiddleResult2, c2S32InUb, RoundMode::CAST_NONE, curBaseM * alignedBaseN); // c2 + PipeBarrier(); + vecInQueue.FreeTensor(c2S32InUb); + + LocalTensor c1S32 = vecInQueue.AllocTensor(); + uint64_t c1Offset = postMNConfig.workSpaceOffset + offsetM * postMNConfig.n; + DataCopyPad2D(c1S32, mmOutGm[c1Offset], curBaseM, curBaseN, postMNConfig.n); + vecInQueue.EnQue(c1S32); + Muls(s23MiddleResult1, s23MiddleResult2, static_cast(1.0 / 254), curBaseM * alignedBaseN); // c2 / 254 + PipeBarrier(); + LocalTensor c1S32InUb = vecInQueue.DeQue(); + Cast(s23MiddleResult2, c1S32InUb, RoundMode::CAST_NONE, curBaseM * alignedBaseN); // c1 + PipeBarrier(); + vecInQueue.FreeTensor(c1S32InUb); + } + +template +__aicore__ inline void GMMA16W8MSDCompute::CalcCMatrix( + MNConfig& postMNConfig, uint32_t curBaseM, uint32_t curBaseN, uint32_t offsetM) { + uint32_t alignedBaseN = AlignUp<32>(curBaseN); + + Add(s23MiddleResult3, s23MiddleResult2, s23MiddleResult1, curBaseM * alignedBaseN); // c1 + c2 / 254 + PipeBarrier(); + uint32_t gmReduceMaxOffset = (postMNConfig.mAxisBaseOffset + postMNConfig.mIdx * postMNConfig.singleM + offsetM) * + FACTOR_FOR_FLOAT_ALIGN_TO_32; + CopyInAmax(curBaseM, gmReduceMaxOffset); + event_t eventIdMTE2ToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_S)); + SetFlag(eventIdMTE2ToS); + WaitFlag(eventIdMTE2ToS); + for (uint32_t idxM = 0; idxM < curBaseM; ++idxM) { + float aMaxPerRow = + aMaxInUb(idxM * FACTOR_FOR_FLOAT_ALIGN_TO_32) / 127.0f; + event_t eventIdSToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); + SetFlag(eventIdSToV); + WaitFlag(eventIdSToV); + // c = (c1 + c2 / 254) * amax / 127 + Muls(s23MiddleResult2[idxM * alignedBaseN], s23MiddleResult3[idxM * alignedBaseN], aMaxPerRow, alignedBaseN); + } + PipeBarrier(); + ReduceResultInQueue.FreeTensor(aMaxInUb); + Add(s23MiddleResult1, s23MiddleResult2, cTmp, curBaseM * alignedBaseN); // c + c_tmp + PipeBarrier(); + for (uint32_t idxM = 0; idxM < curBaseM; ++idxM) { + Mul(s23MiddleResult2[idxM * alignedBaseN], s23MiddleResult1[idxM * alignedBaseN], processedScale, + alignedBaseN); // (c + c_tmp) * scale + PipeBarrier(); + if (hasBias) { + Add(s23MiddleResult2[idxM * alignedBaseN], s23MiddleResult2[idxM * alignedBaseN], processedBias, + alignedBaseN); + } + } + PipeBarrier(); 
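+ // The SetFlag/WaitFlag pairs above order the scalar reads of aMaxInUb against the vector Muls
+ // that consume the same rows; the schematic Ascend C pattern is (event handling only, names
+ // illustrative):
+ //   event_t ev = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
+ //   SetFlag<HardEvent::S_V>(ev);   // scalar producer done
+ //   WaitFlag<HardEvent::S_V>(ev);  // vector consumer may proceed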
+}
+
+template
+__aicore__ inline void GMMA16W8MSDCompute::CopyOutFinalResult(
+    uint32_t n, uint32_t curBaseM, uint32_t curBaseN, uint64_t yOffset) {
+    uint32_t alignedBaseN = AlignUp<32>(curBaseN);
+    LocalTensor outputInUb = vecOutQueue.AllocTensor();
+    #if ORIG_DTYPE_X == DT_FLOAT16
+    Cast(outputInUb, s23MiddleResult2, RoundMode::CAST_NONE, curBaseM * alignedBaseN);
+    #else
+    Cast(outputInUb, s23MiddleResult2, RoundMode::CAST_RINT, curBaseM * alignedBaseN);
+    #endif
+    vecOutQueue.EnQue(outputInUb);
+    LocalTensor output = vecOutQueue.DeQue();
+    DataCopyPad2D(yGm[yOffset], output, curBaseM, curBaseN, alignedBaseN, n);
+    vecOutQueue.FreeTensor(output);
+}
+
+}  // namespace GROUPED_MATMUL
+
+#endif
+#endif  // ASCENDC_GROUPED_MATMUL_ANTIQUANT_A16W8_MSD_H
diff --git a/src/transformer/grouped_matmul/grouped_matmul_quant_mixcore.h b/src/transformer/grouped_matmul/grouped_matmul_quant_mixcore.h
new file mode 100644
index 00000000..f6487d5b
--- /dev/null
+++ b/src/transformer/grouped_matmul/grouped_matmul_quant_mixcore.h
@@ -0,0 +1,419 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file grouped_matmul_quant_mixcore.h
+ * \brief
+ */
+#ifndef ASCENDC_GROUPED_MATMUL_QUANT_MIXCORE_H
+#define ASCENDC_GROUPED_MATMUL_QUANT_MIXCORE_H
+
+#include "grouped_matmul_utils.h"
+#include "grouped_matmul.h"
+
+#if defined(GMM_QUANT_BF16) || defined(GMM_QUANT_FLOAT16)
+namespace GROUPED_MATMUL {
+/* @brief store variables for core split configuration
+*/
+constexpr int32_t PIPELINE_NUM = 4;
+constexpr uint32_t BROADCAST_DIM = 2;
+
+/** @brief internal computation class
+*/
+template
+class GMMQuantMixCoreCompute : public GMMCompute {
+  public:
+    using AT = typename mmType::AT::T;
+    using BT = typename mmType::BT::T;
+    using B = typename mmType::BT;
+    using CT = typename mmType::CT::T;
+    using BiasT = typename mmType::BiasT::T;
+    using WT = DTYPE_WEIGHT;
+    constexpr static bool transposeX = mmType::AT::isTrans;
+    constexpr static bool transposeW = mmType::BT::isTrans;
+
+    /** @brief constructor */
+    __aicore__ inline GMMQuantMixCoreCompute(typename mmType::MT& mm_) : GMMCompute(mm_) {}
+
+    __aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR scale, GM_ADDR offset,
+                                GM_ADDR antiquantScale, GM_ADDR antiquantOffset, GM_ADDR groupList,
+                                GM_ADDR perTokenScale, GM_ADDR y, GM_ADDR workspace,
+                                const GMMBaseParams* __restrict gmmBaseParams,
+                                const TCubeTiling* __restrict mmTilingData, TPipe* tPipe);
+
+    __aicore__ inline void MMCompute(uint32_t groupIdx, MNConfig& mnConfig, uint32_t coreIdx);
+
+    __aicore__ inline void VectorCompute(MNConfig& mnConfig);
+
+    __aicore__ inline void PostCompute();
+
+  private:
+    __aicore__ inline void Dequant(MNConfig& mnConfig);
+
+    __aicore__ inline void SetPerTokenQuantStaticBuffer(const GMMBaseParams* __restrict gmmBaseParams,
+                                                        const TCubeTiling* __restrict mmTilingData, GM_ADDR workspace);
+
+    __aicore__ inline void DataCopyScale(uint32_t curBaseN, uint32_t alignBaseN, uint64_t
scaleOffset); + + __aicore__ inline void DataCopyPerTokenScaleAndBrcb(MNConfig& mnConfig, uint32_t curBaseM, uint32_t alignBaseN, + uint32_t offsetM); + + __aicore__ inline void SetPerTokenQuantRefreshedBuffer(const MNConfig mnConfig); + + __aicore__ inline void ActivationCompute(uint32_t computeSize, LocalTensor preResUb, + LocalTensor actTmpLocal); + + __aicore__ inline void ComputeDequantAndActivate(MNConfig& mnConfig, uint32_t curVecBaseM, uint32_t alignBaseN, uint32_t curVecBaseN, + uint32_t offsetM); + + __aicore__ inline void DataCopyOut(MNConfig& mnConfig, uint32_t curVecBaseM, uint32_t curVecBaseN, + uint32_t alignBaseN, uint64_t outOffset); + + __aicore__ inline void VectorTilingCalc(MNConfig& mnConfig, uint32_t& curCubeSingleN, uint32_t& curCubeSingleM, + uint32_t& vecBaseN, uint32_t& vecBaseM); + + GM_ADDR scaleTensorPtr; + GM_ADDR perTokenScaleTensorPtr; + GlobalTensor scaleGm; + GlobalTensor perTokenScaleGm; + GlobalTensor mmOutGm; + // define the que + TQue vecInQueue; + TQue vecOutQueue; + TQue scaleInQueue; + TQue perTokenScaleInQueue; + TBuf tmpBuff; + LocalTensor mmOutInUb; + LocalTensor scaleInUb; + LocalTensor perTokenScaleInUb; + LocalTensor dequantMiddleResult; + LocalTensor sharedTmpLocal; + LocalTensor mulsResultLocal; + LocalTensor pertokenBrcbLocal; + LocalTensor actResultLocal; + bool sequentialWrite = true; + bool isPerTokenQuant; + uint32_t cubeNum; // Matmul completions on the kernel +}; + +template +__aicore__ inline void GMMQuantMixCoreCompute::Init(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, + GM_ADDR scale, GM_ADDR offset, GM_ADDR antiquantScale, + GM_ADDR antiquantOffset, GM_ADDR groupList, + GM_ADDR perTokenScale, GM_ADDR y, GM_ADDR workspace, + const GMMBaseParams* __restrict gmmBaseParams, + const TCubeTiling* __restrict mmTilingData, + TPipe* tPipe) { + this->GMMCompute::Init(x, weight, bias, scale, offset, antiquantScale, antiquantOffset, groupList, + perTokenScale, y, workspace, gmmBaseParams, mmTilingData, tPipe); + isPerTokenQuant = gmmBaseParams->quantParam == 1; + sequentialWrite = gmmBaseParams->singleN == 0; + scaleTensorPtr = scale; + perTokenScaleTensorPtr = perTokenScale; + cubeNum = 0; + SetPerTokenQuantStaticBuffer(gmmBaseParams, mmTilingData, workspace); +} + +template +__aicore__ inline void GMMQuantMixCoreCompute::PostCompute() { + if ASCEND_IS_AIC { + for (int32_t idx = 0; idx < Min(cubeNum, PIPELINE_NUM); ++idx) { + CrossCoreWaitFlag(SYNC_AIV_AIC_FLAG); + } + } +} + +template +__aicore__ inline void GMMQuantMixCoreCompute::MMCompute(uint32_t groupIdx, MNConfig& mnConfig, + uint32_t coreIdx) { + uint32_t tailN = mnConfig.nIdx * mnConfig.singleN; + uint32_t curSingleN = mnConfig.nIdx < mnConfig.blockDimN - 1 ? mnConfig.singleN : mnConfig.n - tailN; + uint32_t curSingleM = mnConfig.mIdx < mnConfig.blockDimM - 1 ? 
mnConfig.singleM + : mnConfig.m - mnConfig.mIdx * mnConfig.singleM; + uint64_t xOffset = mnConfig.mIdx * mnConfig.singleM * mnConfig.k; + if constexpr (transposeX) { + xOffset = mnConfig.mIdx * mnConfig.singleM; + } + uint64_t outOffset = mnConfig.mIdx * mnConfig.singleM * mnConfig.n + tailN; + // init global buffer + if (this->singleX == 0) { + this->xGm.SetGlobalBuffer(GetTensorAddr(groupIdx, this->xTensorPtr)); + } else { + this->xGm.SetGlobalBuffer(GetTensorAddr(0, this->xTensorPtr) + mnConfig.xBaseOffset); + } + GlobalTensor weightGm = this->SetGlobalBufferW(groupIdx, tailN, mnConfig); + if (sequentialWrite) { + mnConfig.workSpaceOffset = mnConfig.baseN * mnConfig.baseM * \ + (coreIdx + (cubeNum % PIPELINE_NUM) * this->coreNum); + } else { + mnConfig.workSpaceOffset = outOffset + mnConfig.yBaseOffset; + } + if ASCEND_IS_AIC { + if (this->cubeNum >= PIPELINE_NUM) { + CrossCoreWaitFlag(SYNC_AIV_AIC_FLAG); + } + this->mm.SetOrgShape(mnConfig.m, mnConfig.n, mnConfig.k); + this->mm.SetSingleShape(curSingleM, curSingleN, mnConfig.k); + this->mm.SetTensorA(this->xGm[xOffset], transposeX); + this->mm.SetTensorB(weightGm, transposeW); + this->SetGlobalBufferBias(groupIdx, tailN, mnConfig); + while (this->mm.Iterate()) { + this->mm.GetTensorC(mmOutGm[mnConfig.workSpaceOffset], 0, sequentialWrite); + } + CrossCoreSetFlag(SYNC_AIC_AIV_FLAG); + } + cubeNum++; +} + +template +__aicore__ inline void GMMQuantMixCoreCompute::VectorCompute(MNConfig& mnConfig) { + if ASCEND_IS_AIV { + CrossCoreWaitFlag(SYNC_AIC_AIV_FLAG); + SetPerTokenQuantRefreshedBuffer(mnConfig); + Dequant(mnConfig); + CrossCoreSetFlag(SYNC_AIV_AIC_FLAG); + } +} + +template +__aicore__ inline void GMMQuantMixCoreCompute::ComputeDequantAndActivate(MNConfig& mnConfig, + uint32_t curVecBaseM, uint32_t alignBaseN, uint32_t curVecBaseN, uint32_t offsetM) { + DataCopyPerTokenScaleAndBrcb(mnConfig, curVecBaseM, alignBaseN, offsetM); + mmOutInUb = vecInQueue.DeQue(); + LocalTensor yLocalInUb = vecOutQueue.AllocTensor(); + + #if defined(GMM_QUANT_BF16) + if (!isPerTokenQuant && this->activeType == 0) { // BF16 static quantization without activation. 
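+        // Fast path: AscendDequant writes the bf16 result straight into the output tensor,
+        // so the per-token multiply, activation and final Cast below are all skipped.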
+ AscendDequant(yLocalInUb, mmOutInUb, scaleInUb, sharedTmpLocal, {curVecBaseM, alignBaseN, curVecBaseN}); + vecInQueue.FreeTensor(mmOutInUb); + vecOutQueue.EnQue(yLocalInUb); + return; + } + #endif + AscendDequant(dequantMiddleResult, mmOutInUb, scaleInUb, sharedTmpLocal, {curVecBaseM, alignBaseN, curVecBaseN}); + PipeBarrier(); + LocalTensor preResUb = dequantMiddleResult; + LocalTensor yFP32LocalInUb = dequantMiddleResult; + LocalTensor actTmpLocal = sharedTmpLocal; + // pertoken antiquant + if (isPerTokenQuant) { + Mul(mulsResultLocal, dequantMiddleResult, pertokenBrcbLocal, curVecBaseM * alignBaseN); + PipeBarrier(); + preResUb = mulsResultLocal; + yFP32LocalInUb = mulsResultLocal; + actTmpLocal = tmpBuff.GetWithOffset(2 * this->ubCalSize * sizeof(float), 0); + } + // activation function + if (this->activeType != 0) { + uint32_t computeSize = curVecBaseM * alignBaseN; + ActivationCompute(computeSize, preResUb, actTmpLocal); + yFP32LocalInUb = actResultLocal; + } + // get final output after Cast + #if defined(GMM_QUANT_BF16) + Cast(yLocalInUb, yFP32LocalInUb, RoundMode::CAST_RINT, curVecBaseM * alignBaseN); + #elif defined(GMM_QUANT_FLOAT16) + Cast(yLocalInUb, yFP32LocalInUb, RoundMode::CAST_NONE, curVecBaseM * alignBaseN); + #endif + PipeBarrier(); + vecInQueue.FreeTensor(mmOutInUb); + vecOutQueue.EnQue(yLocalInUb); + return; +} + +template +__aicore__ inline void GMMQuantMixCoreCompute::VectorTilingCalc( + MNConfig& mnConfig, uint32_t& curCubeSingleN, uint32_t& curCubeSingleM, uint32_t& vecBaseN, + uint32_t& vecBaseM) { + curCubeSingleN = mnConfig.nIdx == mnConfig.blockDimN - 1 ? + mnConfig.n - mnConfig.nIdx * mnConfig.singleN : mnConfig.singleN; + curCubeSingleM = mnConfig.mIdx == mnConfig.blockDimM - 1 ? + mnConfig.m - mnConfig.mIdx * mnConfig.singleM : mnConfig.singleM; + vecBaseN = sequentialWrite ? curCubeSingleN : mnConfig.baseN; + vecBaseM = this->ubCalSize / AlignUp(vecBaseN, static_cast(UB_BLOCK_DOUBLE_UNIT_SIZE / sizeof(int32_t))); + vecBaseM = Min(vecBaseM, curCubeSingleM); +} + +template +__aicore__ inline void GMMQuantMixCoreCompute::Dequant(MNConfig& mnConfig) { + uint32_t curCubeSingleN; + uint32_t curCubeSingleM; + uint32_t vecBaseN; + uint32_t vecBaseM; + VectorTilingCalc(mnConfig, curCubeSingleN, curCubeSingleM, vecBaseN, vecBaseM); + uint32_t curVecBaseN = vecBaseN; + uint32_t curVecBaseM; + uint32_t vecCount = 0; + uint32_t rowLength = sequentialWrite ? 
curCubeSingleN : mnConfig.n; + uint32_t taskRation = GetTaskRation(); + for (uint32_t offsetN = 0; offsetN < curCubeSingleN; offsetN += vecBaseN) { + if (unlikely(offsetN + vecBaseN >= curCubeSingleN)) { curVecBaseN = curCubeSingleN - offsetN; } + uint32_t alignBaseN = AlignUp(curVecBaseN, static_cast(UB_BLOCK_DOUBLE_UNIT_SIZE / sizeof(int32_t))); + uint64_t scaleOffset = mnConfig.nIdx * mnConfig.singleN + offsetN; + DataCopyScale(curVecBaseN, alignBaseN, scaleOffset); + curVecBaseM = vecBaseM; + for (uint32_t offsetM = 0; offsetM < curCubeSingleM; offsetM += vecBaseM) { + vecCount++; + if (vecCount % taskRation != this->subBlockIdx) { + continue; + } + if (unlikely(offsetM + vecBaseM >= curCubeSingleM)) { + curVecBaseM = curCubeSingleM - offsetM; + } + // use AscendDequant interface to do perchannel dequant + uint64_t mmOutOffset = mnConfig.workSpaceOffset + offsetM * static_cast(rowLength) + offsetN; + LocalTensor mmOutLocal = vecInQueue.AllocTensor(); + DataCopyPad2D(mmOutLocal, mmOutGm[mmOutOffset], curVecBaseM, curVecBaseN, rowLength); + vecInQueue.EnQue(mmOutLocal); + ComputeDequantAndActivate(mnConfig, curVecBaseM, alignBaseN, curVecBaseN, offsetM); + uint64_t outOffset = (mnConfig.mIdx * mnConfig.singleM + offsetM) * mnConfig.n + \ + mnConfig.nIdx * mnConfig.singleN + offsetN; + DataCopyOut(mnConfig, curVecBaseM, curVecBaseN, alignBaseN, outOffset); + } + scaleInQueue.FreeTensor(scaleInUb); + } +} + +template +__aicore__ inline void GMMQuantMixCoreCompute::DataCopyOut(MNConfig& mnConfig, uint32_t curVecBaseM, + uint32_t curVecBaseN, uint32_t alignBaseN, + uint64_t outOffset) { + // Copy the result of vector to yGm. + LocalTensor yLocal = vecOutQueue.DeQue(); + DataCopyPad2D(this->yGm[outOffset], yLocal, curVecBaseM, curVecBaseN, alignBaseN, mnConfig.n); + vecOutQueue.FreeTensor(yLocal); +} + +template +__aicore__ inline void GMMQuantMixCoreCompute::ActivationCompute(uint32_t computeSize, + LocalTensor preResUb, + LocalTensor actTmpLocal) { + ActiveType active = ActiveType(this->activeType); + if (active == ActiveType::FASTGELU) { + FasterGelu(actResultLocal, preResUb, actTmpLocal, computeSize); + } else if (active == ActiveType::RELU) { + Relu(actResultLocal, preResUb, computeSize); + } else if (active == ActiveType::SILU) { + Silu(actResultLocal, preResUb, computeSize); + } else if (active == ActiveType::GELU_TANH) { + Gelu(actResultLocal, preResUb, actTmpLocal, computeSize); + } + PipeBarrier(); +} + +template +__aicore__ inline void GMMQuantMixCoreCompute::SetPerTokenQuantStaticBuffer( + const GMMBaseParams* __restrict gmmBaseParams, const TCubeTiling* __restrict mmTilingData, GM_ADDR workspace) { + // Initialize ub and gm memories that do not need to be reinitialized due to changes in groupidx. + if ASCEND_IS_AIV { + // 2: enabling double buffer, occupying two buffer. + this->pipe->InitBuffer(scaleInQueue, 2, mmTilingData->baseN * sizeof(DTYPE_SCALE)); + if (isPerTokenQuant) { + // 2: enabling double buffer, occupying two buffer. + this->pipe->InitBuffer(perTokenScaleInQueue, 2, mmTilingData->baseM * sizeof(float)); + } + // 2: enabling double buffer, occupying two buffer. + this->pipe->InitBuffer(vecInQueue, 2, this->ubCalSize * sizeof(CT)); + // 2: enabling double buffer, occupying two buffer. 
+        this->pipe->InitBuffer(vecOutQueue, 2, this->ubCalSize * sizeof(DTYPE_Y));
+        this->pipe->InitBuffer(tmpBuff, gmmBaseParams->ubRestBytes);
+        dequantMiddleResult = tmpBuff.GetWithOffset(this->ubCalSize, 0);
+        #if defined(GMM_QUANT_FLOAT16)
+        uint32_t factor = 1;
+        #else
+        uint32_t factor = 0;
+        #endif
+        // 2: Indicates the first two blocks of ub are already occupied.
+        factor = !isPerTokenQuant && this->activeType == 0 ? factor : 2;
+        uint32_t ubCalSizeFloat = this->ubCalSize * sizeof(float);
+        uint32_t offset = factor * ubCalSizeFloat;
+        // 2: Indicates a temporary space twice the size is needed.
+        sharedTmpLocal = tmpBuff.GetWithOffset(2 * ubCalSizeFloat, offset);
+        if (isPerTokenQuant) {
+            // 2: Indicates the first two blocks of ub are already occupied.
+            mulsResultLocal = tmpBuff.GetWithOffset(this->ubCalSize, 2 * ubCalSizeFloat);
+            pertokenBrcbLocal = tmpBuff.GetWithOffset(this->ubCalSize, ubCalSizeFloat);
+        }
+        if (this->activeType != 0) {
+            // 3: Indicates the first three blocks of ub are already occupied.
+            uint32_t offsetAct = !isPerTokenQuant ? ubCalSizeFloat : 3 * ubCalSizeFloat;
+            actResultLocal = tmpBuff.GetWithOffset(this->ubCalSize, offsetAct);
+        }
+    }
+    mmOutGm.SetGlobalBuffer((__gm__ MM_DTYPE_Y *)workspace);
+}
+
+template
+__aicore__ inline void GMMQuantMixCoreCompute::SetPerTokenQuantRefreshedBuffer(const MNConfig mnConfig) {
+    // Initialize gm memories that need to be reinitialized due to changes in groupIdx.
+    // Currently, per-token quant only supports single-tensor mode,
+    // hence set according to x and weight single-tensor mode.
+    // Add an if branch if multi-tensor mode for weight is required.
+    scaleGm.SetGlobalBuffer(GetTensorAddr(0, scaleTensorPtr) + mnConfig.nAxisBaseOffset);
+    if (isPerTokenQuant) {
+        perTokenScaleGm.SetGlobalBuffer((__gm__ float *)perTokenScaleTensorPtr + mnConfig.mAxisBaseOffset);
+    }
+    // Add an if branch if multi-tensor mode for y is required.
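+    // Sketch of the addressing assumption here: the per-group base offsets
+    // (mAxisBaseOffset, nAxisBaseOffset, yBaseOffset) accumulate as groupIdx advances,
+    // so each SetGlobalBuffer call re-bases the GM view onto the current group.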
+ this->yGm.SetGlobalBuffer(GetTensorAddr(0, this->yTensorPtr) + mnConfig.yBaseOffset); +} + +template +__aicore__ inline void GMMQuantMixCoreCompute::DataCopyScale(uint32_t curBaseN, + uint32_t alignBaseN, + uint64_t scaleOffset) +{ + // GM copy scale + DataCopyPadExtParams padParams; + DataCopyExtParams scaleParams; + scaleParams.blockLen = curBaseN * sizeof(DTYPE_SCALE); + scaleParams.blockCount = 1; + scaleParams.srcStride = 0; + scaleParams.dstStride = 0; + LocalTensor scaleLocal = scaleInQueue.AllocTensor(); + DataCopyPad(scaleLocal, scaleGm[scaleOffset], scaleParams, padParams); + scaleInQueue.EnQue(scaleLocal); + + scaleInUb = scaleInQueue.DeQue(); + scaleInUb.SetSize(alignBaseN); +} + +template +__aicore__ inline void GMMQuantMixCoreCompute::DataCopyPerTokenScaleAndBrcb(MNConfig& mnConfig, + uint32_t curBaseM, + uint32_t alignBaseN, + uint32_t offsetM) +{ + if (!isPerTokenQuant) { + return; + } + uint64_t perTokenScaleOffset = mnConfig.mIdx * mnConfig.singleM + offsetM; + // GM copy per token scale + DataCopyPadExtParams padParams; + DataCopyExtParams perTokenScaleParams; + perTokenScaleParams.blockLen = curBaseM * sizeof(float); + perTokenScaleParams.blockCount = 1; + perTokenScaleParams.srcStride = 0; + perTokenScaleParams.dstStride = 0; + LocalTensor perTokenScaleLocal = perTokenScaleInQueue.AllocTensor(); + DataCopyPad(perTokenScaleLocal, perTokenScaleGm[perTokenScaleOffset], perTokenScaleParams, padParams); + perTokenScaleInQueue.EnQue(perTokenScaleLocal); + + perTokenScaleInUb = perTokenScaleInQueue.DeQue(); + const uint32_t broadCastDst[BROADCAST_DIM] = {curBaseM, alignBaseN}; + const uint32_t broadCastSrc[BROADCAST_DIM] = {curBaseM, 1}; + BroadCast(pertokenBrcbLocal, perTokenScaleInUb, broadCastDst, broadCastSrc, + sharedTmpLocal); + perTokenScaleInQueue.FreeTensor(perTokenScaleInUb); +} + +} // namespace GROUPED_MATMUL + +#endif +#endif // ASCENDC_GROUPED_MATMUL_QUANT_MIXCORE_H diff --git a/src/transformer/grouped_matmul/grouped_matmul_utils.h b/src/transformer/grouped_matmul/grouped_matmul_utils.h new file mode 100644 index 00000000..f66d0977 --- /dev/null +++ b/src/transformer/grouped_matmul/grouped_matmul_utils.h @@ -0,0 +1,210 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file grouped_matmul_utils.h
+ * \brief
+ */
+#ifndef ASCENDC_GROUPED_MATMUL_UTILS_H
+#define ASCENDC_GROUPED_MATMUL_UTILS_H
+
+#include "kernel_tiling/kernel_tiling.h"
+#include "kernel_operator.h"
+#include "lib/matmul_intf.h"
+
+#if defined(ORIG_DTYPE_X) && defined(ORIG_DTYPE_WEIGHT) && defined(ORIG_DTYPE_Y) && defined(DT_INT8) && \
+    defined(DT_BF16)
+    #if ORIG_DTYPE_X == ORIG_DTYPE_WEIGHT
+        #if ORIG_DTYPE_X == DT_INT8
+            #if ORIG_DTYPE_Y == DT_BF16
+                #define GMM_QUANT_BF16
+                #define MM_DTYPE_Y int32_t
+            #elif ORIG_DTYPE_Y == DT_FLOAT16
+                #define GMM_QUANT_FLOAT16
+                #define MM_DTYPE_Y int32_t
+            #else
+                #define GMM_QUANT_INT8
+            #endif
+        #else
+            #define GMM_FLOAT
+        #endif
+    #else
+        #define GMM_ANTI_QUANT
+    #endif
+#endif
+
+#if defined(DTYPE_Y) && !defined(MM_DTYPE_Y)
+    #define MM_DTYPE_Y DTYPE_Y
+#endif
+
+namespace GROUPED_MATMUL {
+using namespace AscendC;
+
+constexpr uint32_t INT8_BITS = 8;  // an int8 number has 8 bits
+constexpr int32_t MKN_LIST_LEN = 128;  // 128: predefined array length
+constexpr uint32_t UB_BLOCK_UNIT_SIZE = 32;  // 32: a block has 32 bytes data
+constexpr uint32_t UB_BLOCK_DOUBLE_UNIT_SIZE = 64;  // 64: a block has 64 bytes data
+constexpr uint32_t HALF_UB_BLOCK_UNIT_SIZE = UB_BLOCK_UNIT_SIZE / 2;  // 2: a float16 data has two bytes
+constexpr MatmulConfig NZ_CFG_MDL = GetMDLConfig(false, false, 0, true, false, false, false);
+constexpr MatmulConfig matmulCFGUnitFlag{false, false, true, 0, 0, 0, false, false, false, false, false, 0, 0, 0,
+                                         0, 0, 0, 0, true};
+
+constexpr uint64_t SYNC_AIV_AIC_FLAG = 3;
+constexpr uint64_t SYNC_AIC_AIV_FLAG = 5;
+constexpr uint64_t SYNC_MODE2 = 2;
+
+template <class AT_, class BT_, class CT_, class BiasT_, const MatmulConfig& MM_CFG>
+struct MMType {
+    using AT = AT_;
+    using BT = BT_;
+    using CT = CT_;
+    using BiasT = BiasT_;
+    using MT = matmul::Matmul<AT_, BT_, CT_, BiasT_, MM_CFG>;
+};
+
+template <class AT_, class BT_, class CT_, class BiasT_, const MatmulConfig& MM_CFG>
+struct MMImplType {
+    using AT = AT_;
+    using BT = BT_;
+    using CT = CT_;
+    using BiasT = BiasT_;
+    using MT = matmul::MatmulImpl<AT_, BT_, CT_, BiasT_, MM_CFG>;
+};
+
+enum class ActiveType {
+    INVALID_TYPE = 0,
+    RELU,
+    GELU_TANH,
+    GELU_ERR_FUNC,
+    FASTGELU,
+    SILU
+};
+
+template <typename T>
+__aicore__ inline T GreatestCommonDivisor(T a, T b) {
+    T c = a;
+    if (a < b) {
+        a = b;
+        b = c;
+    }
+    while (b != 0) {
+        c = a;
+        a = b;
+        b = c % b;
+    }
+    return a;
+}
+
+template <typename T>
+__aicore__ inline T LeastCommonMultiple(T a, T b) {
+    return a * b / GreatestCommonDivisor(a, b);
+}
+
+template <typename T>
+__aicore__ inline T Max(T a, T b) {
+    return a > b ? a : b;
+}
+
+template <typename T>
+__aicore__ inline T Min(T a, T b) {
+    return a > b ? b : a;
+}
+
+template <uint32_t base, typename T>
+__aicore__ inline T AlignUp(T a) {
+    return (a + base - 1) / base * base;
+}
+
+template <typename T>
+__aicore__ inline T AlignUp(T a, T base) {
+    return (a + base - 1) / base * base;
+}
+
+template <typename T>
+__aicore__ inline T AlignDown(T a, T base) {
+    if (unlikely(base == 0)) {
+        return a;
+    }
+    return a / base * base;
+}
+
+template <>
+__aicore__ inline uint32_t AlignUp<4, uint32_t>(uint32_t a) {
+    // To be a multiple of 4, the result must end in binary b(xxxx,xx00),
+    // i.e. its last two bits must be zero: result = num & b(1111,1100) = num & (~3).
+    // The & (~3) operation may reduce num to any value in the range [num - 3, num].
+    // Since the result must be no less than a (result >= a), the worst case requires
+    // num - 3 >= a, i.e. num >= a + 3. On the other hand, num must be less than a + 4,
+    // otherwise the result would not be the least multiple of 4 that is >= a.
+    // Choosing num = a + 3 satisfies both conditions.
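+    // Worked example: a = 5 gives (5 + 3) & ~3 = 8 & ~3 = 8; a = 8 gives
+    // (8 + 3) & ~3 = 11 & ~3 = 8, so already-aligned values are left unchanged.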
+    return (a + 3) & ~3;  // & ~3: set last two bits of (a+3) to zero
+}
+
+template <>
+__aicore__ inline uint32_t AlignUp<8, uint32_t>(uint32_t a) {
+    // In general, to get the least multiple of b (b is a power of 2) no less than a,
+    // the conclusion from the comment above is: result = (a + (b - 1)) & ~(b - 1).
+    return (a + 7) & ~7;  // & ~7: set last three bits of (a+7) to zero
+}
+
+template <>
+__aicore__ inline uint32_t AlignUp<16, uint32_t>(uint32_t a) {
+    // In general, to get the least multiple of b (b is a power of 2) no less than a,
+    // the conclusion from the comment above is: result = (a + (b - 1)) & ~(b - 1).
+    return (a + 15) & ~15;  // & ~15: set last four bits of (a+15) to zero
+}
+
+template <>
+__aicore__ inline uint32_t AlignUp<32, uint32_t>(uint32_t a) {
+    // refer to the above comments.
+    return (a + 31) & ~31;  // & ~31: set last five bits of (a+31) to zero
+}
+
+template <typename T>
+__aicore__ inline __gm__ T* GetTensorAddr(uint16_t index, GM_ADDR tensorPtr) {
+    __gm__ uint64_t* dataAddr = reinterpret_cast<__gm__ uint64_t*>(tensorPtr);
+    uint64_t tensorPtrOffset = *dataAddr;  // The offset of the data address from the first address.
+    // Moving 3 bits to the right means dividing by sizeof(uint64_t).
+    __gm__ uint64_t* retPtr = dataAddr + (tensorPtrOffset >> 3);
+    return reinterpret_cast<__gm__ T*>(*(retPtr + index));
+}
+
+#define GET_TILING_DATA_MEMBER_ADDR(tilingType, member, var, tiling) \
+    size_t offset##var = (size_t)(&((tilingType*)0)->member);        \
+    __gm__ uint8_t* (var) = (tiling) + (offset##var)
+
+__aicore__ inline int32_t GetSplitValueFromGroupList(uint32_t groupIdx, int32_t &preOffset,
+                                                     const GMMBaseParams* __restrict &gmmBaseParams,
+                                                     const GlobalTensor<int64_t> &groupListGm) {
+    int32_t splitValue = 0;
+    if (likely(gmmBaseParams->groupType != -1)) {  // -1: no need to split
+        if (gmmBaseParams->groupListType == 0) {
+            int32_t offset = static_cast<int32_t>(groupListGm.GetValue(groupIdx));
+            splitValue = offset - preOffset;
+            preOffset = offset;
+        } else {
+            splitValue = static_cast<int32_t>(groupListGm.GetValue(groupIdx));
+        }
+    }
+    return splitValue;
+}
+
+template <typename T>
+__aicore__ inline constexpr uint32_t GetTypeBits() {
+    if constexpr (IsSameType<T, int4b_t>::value) {
+        return 4;  // 4: int4 bits number
+    }
+    return sizeof(T) * INT8_BITS;
+}
+
+}  // namespace GROUPED_MATMUL
+
+#endif  // ASCENDC_GROUPED_MATMUL_UTILS_H
diff --git a/src/transformer/grouped_matmul/grouped_matmul_vector.h b/src/transformer/grouped_matmul/grouped_matmul_vector.h
new file mode 100644
index 00000000..bd1523fa
--- /dev/null
+++ b/src/transformer/grouped_matmul/grouped_matmul_vector.h
@@ -0,0 +1,76 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file grouped_matmul_vector.h
+ * \brief
+ */
+#ifndef ASCENDC_GROUPED_MATMUL_VECTOR_H
+#define ASCENDC_GROUPED_MATMUL_VECTOR_H
+
+#include "grouped_matmul_utils.h"
+
+namespace GROUPED_MATMUL {
+
+template
+__aicore__ inline void EmptyTensorCompute(GM_ADDR groupListPtr, GM_ADDR y, const GMMTilingData* __restrict tiling) {
+    const GMMBaseParams* __restrict gmmBaseParams = &tiling->gmmBaseParams;
+    // In the V2 interface, groupType is -1 because grouping has already been done on the host.
+    // Thus, groupType here can be either -1 or 2.
+    if (groupListPtr == nullptr || gmmBaseParams->groupType == 0) {
+        return;
+    }
+
+    GlobalTensor<DTYPE_Y> yGm;
+    GlobalTensor<int64_t> groupListGm;
+    yGm.SetGlobalBuffer(GetTensorAddr<DTYPE_Y>(0, y));
+    if (groupListPtr != nullptr) {
+        groupListGm.SetGlobalBuffer((__gm__ int64_t*)groupListPtr);
+    }
+    uint64_t yBaseOffset = 0;
+    int32_t preOffset = 0;
+    uint32_t singleWeight = gmmBaseParams->singleWeight;
+    uint32_t singleX = gmmBaseParams->singleX;
+    uint32_t singleY = gmmBaseParams->singleY;
+    bool isAllSingleTensor = singleWeight == 1 && singleX == 1 && singleY == 1;
+
+    const int32_t *ubM = tiling->gmmArray.mList;
+    const int32_t *ubK = tiling->gmmArray.kList;
+    const int32_t *ubN = tiling->gmmArray.nList;
+    int64_t coreIdx = GetBlockIdx();
+    int64_t coreRation = GetTaskRation();
+    if (coreRation > 1) {
+        coreIdx /= coreRation;
+    }
+
+    for (uint32_t groupIdx = 0; groupIdx < gmmBaseParams->groupNum; ++groupIdx) {
+        int32_t splitValue = GetSplitValueFromGroupList(groupIdx, preOffset, gmmBaseParams, groupListGm);
+        uint32_t m = isAllSingleTensor && gmmBaseParams->groupType == 2 ? *ubM : *(ubM + groupIdx);
+        uint32_t k = *ubK < 0 && gmmBaseParams->groupType == 2 ? splitValue : *(ubK + groupIdx);
+        uint32_t n = isAllSingleTensor ? *ubN : *(ubN + groupIdx);
+
+        if (k == 0) {
+            uint32_t singleM = Ceil(m, gmmBaseParams->coreNum);
+            singleM = AlignUp(singleM);
+            uint32_t curSingleM = singleM;
+            if (singleM * coreIdx >= m) {
+                yBaseOffset += m * n;
+                continue;
+            } else if (m - singleM * coreIdx < singleM) {
+                curSingleM = m - singleM * coreIdx;
+            }
+            InitOutput(yGm[yBaseOffset + coreIdx * singleM * n], curSingleM * n, 0);
+        }
+        yBaseOffset += m * n;
+    }
+}
+
+}  // namespace GROUPED_MATMUL
+
+#endif  // ASCENDC_GROUPED_MATMUL_VECTOR_H
diff --git a/src/transformer/grouped_matmul/ophost/CMakeLists.txt b/src/transformer/grouped_matmul/ophost/CMakeLists.txt
new file mode 100644
index 00000000..bd368b00
--- /dev/null
+++ b/src/transformer/grouped_matmul/ophost/CMakeLists.txt
@@ -0,0 +1,55 @@
+# Copyright (c) 2024 Huawei Technologies Co., Ltd.
+# This file is a part of the CANN Open Software.
+# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+# ====================================================================================================================== + +add_ops_compile_options( + OP_NAME GroupedMatmul + OPTIONS --cce-auto-sync=off + -Wno-deprecated-declarations + -Werror +) + +target_sources(op_host_aclnnExc PRIVATE + grouped_matmul_def.cpp +) + +target_sources(opapi PRIVATE + grouped_matmul.cpp + aclnn_grouped_matmul.cpp +) + +if (NOT BUILD_OPEN_PROJECT) + target_sources(aclnn_ops_train PRIVATE + grouped_matmul.cpp + aclnn_grouped_matmul.cpp + ) + + target_sources(aclnn_ops_infer PRIVATE + grouped_matmul.cpp + aclnn_grouped_matmul.cpp + ) +endif () + +target_sources(optiling PRIVATE + grouped_matmul_tiling.cpp + fallback_grouped_matmul.cpp +) + +target_include_directories(optiling PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_sources(opsproto PRIVATE + grouped_matmul_proto.cpp +) + +file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_grouped_matmul*.h") + +install(FILES ${_GMM_Aclnn_header} + DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL +) diff --git a/src/transformer/grouped_matmul/ophost/aclnn_grouped_matmul.cpp b/src/transformer/grouped_matmul/ophost/aclnn_grouped_matmul.cpp new file mode 100644 index 00000000..994c5617 --- /dev/null +++ b/src/transformer/grouped_matmul/ophost/aclnn_grouped_matmul.cpp @@ -0,0 +1,1724 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#include "aclnn_grouped_matmul.h" +#include "aclnn_grouped_matmul_v2.h" +#include "aclnn_grouped_matmul_v3.h" +#include "aclnn_grouped_matmul_v4.h" + +#include +#include + +#include "aclnn_kernels/transdata.h" +#include "grouped_matmul.h" +#include "aclnn_kernels/contiguous.h" + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_kernels/common/op_error_check.h" +#include "opdev/common_types.h" +#include "opdev/data_type_utils.h" +#include "opdev/format_utils.h" +#include "opdev/op_dfx.h" +#include "opdev/op_executor.h" +#include "opdev/op_log.h" +#include "opdev/platform.h" +#include "opdev/shape_utils.h" +#include "opdev/tensor_view_utils.h" +#include "opdev/make_op_executor.h" + +using namespace op; + +#ifdef __cplusplus +extern "C" { +#endif + +static constexpr int64_t X_Y_SEPARATED = 0; // x,y no split +static constexpr int64_t Y_SEPARATED = 1; // x split +static constexpr int64_t X_SEPARATED = 2; // y split +static constexpr int64_t NO_SEPARATED = 3; // x,y split +static constexpr int64_t MAX_GROUP_LIST_SIZE_ARRAY = 128; +static constexpr int64_t MAX_GROUP_LIST_SIZE_TENSOR = 1024; +static constexpr int64_t NO_SPLIT = -1; +static constexpr int64_t SPLIT_M = 0; +static constexpr int64_t SPLIT_K = 2; +static constexpr int64_t SPLIT_N = 1; +static constexpr int64_t MAX_INNER_AXIS = 65535; + +static constexpr size_t MAX_FM_DIM = 6; +static constexpr size_t MIN_FM_DIM = 2; +static constexpr size_t SEPARATED_WEIGHT_DIM = 2; +static constexpr size_t SPLIT_M_SINGLE_WEIGHT_DIM = 3; +static constexpr size_t SPLIT_K_SINGLE_WEIGHT_DIM = 2; +static constexpr int64_t END_ACT_TYPE_ENUM = 6; + +enum class GMMApiVersion { + V1, + V2, + V3, + V4 +}; + +const std::map BIAS_DTYPE { + {DataType::DT_FLOAT16, aclDataType::ACL_FLOAT16}, + {DataType::DT_BF16, aclDataType::ACL_FLOAT}, + {DataType::DT_INT8, aclDataType::ACL_INT32}, + {DataType::DT_FLOAT, aclDataType::ACL_FLOAT} +}; + +struct GroupedMatmulParams { + const aclTensorList *x = nullptr; + const aclTensorList *weight = nullptr; + const aclTensorList *biasOptional = nullptr; + const aclIntArray *groupListOptional = nullptr; + const aclTensor *groupTensorOptional = nullptr; + const aclTensorList *scaleOptional = nullptr; + const aclTensorList *offsetOptional = nullptr; + const aclTensorList *antiquantScaleOptional = nullptr; + const aclTensorList *antiquantOffsetOptional = nullptr; + const aclTensorList *perTokenScaleOptional = nullptr; + const aclTensorList *activationInputOptional = nullptr; + const aclTensorList *activationQuantScaleOptional = nullptr; + const aclTensorList *activationQuantOffsetOptional = nullptr; + int64_t splitItem = 0; + int64_t groupListType = 0; + int64_t activeType = 0; + bool transposeWeight = false; + bool transposeX = false; + bool isSingleWeight = false; + GMMApiVersion apiVersion = GMMApiVersion::V1; + int64_t groupType = -1; + const aclTensorList *y = nullptr; + const aclTensorList *activationFeatureOutOptional = nullptr; + const aclTensorList *dynQuantScaleOutOptional = nullptr; + DataType xDtype = DataType::DT_FLOAT16; +}; + +static bool IsTransposeLastTwoDims(const aclTensor *tensor) { + auto shape = tensor->GetViewShape(); + int64_t dim1 = shape.GetDimNum() - 1; + int64_t dim2 = shape.GetDimNum() - 2; + auto strides = tensor->GetViewStrides(); + if (strides[dim2] == 1 && strides[dim1] == shape.GetDim(dim2)) { + int64_t tmpNxD = shape.GetDim(dim1) * shape.GetDim(dim2); + for (int64_t batchDim = shape.GetDimNum() - 3; batchDim >= 0;batchDim--) { + if(strides[batchDim] != tmpNxD) { + return 
false;
+            }
+            tmpNxD *= shape.GetDim(batchDim);
+        }
+        return true;
+    }
+    return false;
+}
+
+static aclnnStatus CheckShapeSameLengthTensorList(const aclTensorList *tensorList1, const aclTensorList *tensorList2,
+                                                  const std::vector<size_t>& dimIds, const int64_t innerAxisDimId,
+                                                  const std::vector<std::string>& tensorType) {
+    // Verify that a specified dimension has consistent values across each tensor pair of two equal-length tensor lists.
+    uint64_t groupNum = tensorList1->Size();
+    for (uint64_t i = 0; i < groupNum; i++) {
+        int64_t dimValue1 = (*tensorList1)[i]->GetViewShape().GetDim(dimIds[0]);
+        // tensorType[2] indicates whether to verify innerAxisDimId of tensorList1; if so, check that it is less than or equal to 65535.
+        if (tensorType[2] == "true" && innerAxisDimId > -1) {
+            int64_t innerAxisValue = (*tensorList1)[i]->GetViewShape().GetDim(innerAxisDimId);
+            CHECK_COND(innerAxisValue <= MAX_INNER_AXIS, ACLNN_ERR_PARAM_INVALID,
+                       "Dim %ld value of %s[%lu] should be less than or equal to 65535, but now is %ld.",
+                       innerAxisDimId, tensorType[0].c_str(), i, innerAxisValue);
+        }
+        int64_t dimValue2 = (*tensorList2)[i]->GetViewShape().GetDim(dimIds[1]);
+        CHECK_COND(dimValue1 == dimValue2, ACLNN_ERR_PARAM_INVALID,
+                   "Dim %lu value of %s[%lu] should be equal to dim %lu value of %s[%lu], but now is %ld and %ld respectively.",
+                   dimIds[0], tensorType[0].c_str(), i, dimIds[1], tensorType[1].c_str(), i, dimValue1, dimValue2);
+    }
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus CheckShapeDiffLengthTensorList(const aclTensorList *longTensorList,
+                                                  const aclTensorList *singleTensorList,
+                                                  const std::vector<size_t>& dimIds,
+                                                  const int64_t innerAxisdimId,
+                                                  const std::vector<std::string>& tensorType) {
+    // Check whether the values of a specified axis in a multi-tensor list
+    // match those in a single-tensor list.
+    // The specified axis is not a split axis.
+    int64_t dimValueSingle = (*singleTensorList)[0]->GetViewShape().GetDim(dimIds[1]);
+    // tensorType[2] indicates whether to verify innerAxisdimId of singleTensorList; if so, check that it is less than or equal to 65535.
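+    // For example, a tensor whose inner (contiguous) axis has 70000 elements would
+    // exceed MAX_INNER_AXIS = 65535 and be rejected by the check below.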
+    if (tensorType[2] == "true" && innerAxisdimId > -1) {
+        int64_t dimValue = (*singleTensorList)[0]->GetViewShape().GetDim(innerAxisdimId);
+        CHECK_COND(dimValue <= MAX_INNER_AXIS, ACLNN_ERR_PARAM_INVALID,
+                   "Dim %ld value of %s[0] should be less than or equal to 65535, but now is %ld.",
+                   innerAxisdimId, tensorType[1].c_str(), dimValue);
+    }
+    uint64_t groupNum = longTensorList->Size();
+    for (uint64_t i = 0; i < groupNum; i++) {
+        int64_t dimValueLong = (*longTensorList)[i]->GetViewShape().GetDim(dimIds[0]);
+        CHECK_COND(dimValueLong == dimValueSingle, ACLNN_ERR_PARAM_INVALID,
+                   "Dim %lu value of %s[%lu] %ld should be equal to dim %lu value of %s[0] %ld.",
+                   dimIds[0], tensorType[0].c_str(), i, dimValueLong,
+                   dimIds[1], tensorType[1].c_str(), dimValueSingle);
+    }
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus CheckFormat(const aclTensor *tensor, const std::string& tensorType, size_t idx) {
+    bool isWeightTensor = tensorType == "weight";
+    op::Format tensorFormat = tensor->GetStorageFormat();
+    CHECK_COND(tensorFormat < Format::FORMAT_END, ACLNN_ERR_PARAM_INVALID, "Format of %s[%lu] %s is invalid.",
+               tensorType.c_str(), idx, op::ToString(tensorFormat).GetString());
+    if (isWeightTensor) {  // 310P weight needs to be NZ
+        CHECK_COND(!op::IsPrivateFormat(tensorFormat) || tensorFormat == Format::FORMAT_FRACTAL_NZ,
+                   ACLNN_ERR_PARAM_INVALID, "Format of %s[%lu] %s is invalid.", tensorType.c_str(), idx,
+                   op::ToString(tensorFormat).GetString());
+    } else {
+        CHECK_COND(!op::IsPrivateFormat(tensorFormat),
+                   ACLNN_ERR_PARAM_INVALID, "Format of %s[%lu] %s is invalid.", tensorType.c_str(), idx,
+                   op::ToString(tensorFormat).GetString());
+    }
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus CheckShapeDiffLengthTensorListSplitAxis(const aclTensorList *longTensorList,
+                                                           const aclTensorList *singleTensorList,
+                                                           const size_t dimIdxLongTensorList,
+                                                           const size_t dimIdxSingleTensorList,
+                                                           const std::vector<std::string>& tensorType) {
+    // Check whether the sum of values along a specified axis in a multi-tensor list equals
+    // the corresponding axis value in a single-tensor list.
+    // The specified axis is the split axis.
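+    // For example (hypothetical shapes): a multi-tensor list with dim values 32 and 64
+    // along the split axis must pair with a single tensor whose corresponding dim is 96.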
+    int64_t dimValueSingle = (*singleTensorList)[0]->GetViewShape().GetDim(dimIdxSingleTensorList);
+    uint64_t groupNum = longTensorList->Size();
+    int64_t preOffset = 0;
+    for (uint64_t i = 0; i < groupNum; i++) {
+        int64_t dimValueLong = (*longTensorList)[i]->GetViewShape().GetDim(dimIdxLongTensorList);
+        preOffset += dimValueLong;
+    }
+    CHECK_COND(preOffset == dimValueSingle, ACLNN_ERR_PARAM_INVALID,
+               "Sum of dim %lu value of %s %ld should be equal to dim %lu value of %s[0] %ld.",
+               dimIdxLongTensorList, tensorType[0].c_str(), preOffset,
+               dimIdxSingleTensorList, tensorType[1].c_str(), dimValueSingle);
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus CheckDimNumAndFormat(const GroupedMatmulParams &gmmParams, const aclTensorList *tensorList,
+                                        const size_t expectedDimNum, const std::string& tensorType) {
+    uint64_t tensorListLength = tensorList->Size();
+    for (size_t i = 0; i < tensorListLength; ++i) {
+        CHECK_COND((*tensorList)[i] != nullptr, ACLNN_ERR_PARAM_INVALID,
+                   "%s[%lu] is null, which is not supported.", tensorType.c_str(), i);
+        CHECK_COND(CheckFormat((*tensorList)[i], tensorType, i) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+                   "Invalid format.");
+        size_t dimNum = (*tensorList)[i]->GetViewShape().GetDimNum();
+        CHECK_COND(dimNum == expectedDimNum, ACLNN_ERR_PARAM_INVALID,
+                   "%s[%lu] dim num should be %lu in this case, but now is %lu.",
+                   tensorType.c_str(), i, expectedDimNum, dimNum);
+        if (tensorType == "weight") {
+            CHECK_COND(IsTransposeLastTwoDims((*gmmParams.weight)[i]) == gmmParams.transposeWeight, ACLNN_ERR_PARAM_INVALID,
+                       "The transpose state must be the same for each tensor in weight.");
+        }
+    }
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus CheckDimNumAndGroupListNoSplitAndFormat(const GroupedMatmulParams &gmmParams) {
+    // When groupType is -1 and the interface is not V1, groupList must be empty.
+    if (gmmParams.apiVersion != GMMApiVersion::V1) {
+        CHECK_COND(gmmParams.groupListOptional == nullptr, ACLNN_ERR_PARAM_INVALID,
+                   "groupListOptional should be nullptr when groupType is -1.");
+    }
+    size_t tensorListLength = gmmParams.x->Size();
+    // Check that the length of groupList is consistent with x when groupList is not empty.
+    if (gmmParams.groupListOptional != nullptr) {
+        CHECK_COND(gmmParams.groupListOptional->Size() == tensorListLength, ACLNN_ERR_PARAM_INVALID,
+                   "Size of groupListOptional %lu should be equal to size of x %lu.",
+                   gmmParams.groupListOptional->Size(), tensorListLength);
+    }
+    if (gmmParams.groupTensorOptional != nullptr) {
+        CHECK_COND(gmmParams.groupTensorOptional->GetViewShape().GetDim(0) == static_cast<int64_t>(tensorListLength),
+                   ACLNN_ERR_PARAM_INVALID, "Size of groupListOptional(tensor) %ld should be equal to size of x %zu.",
+                   gmmParams.groupTensorOptional->GetViewShape().GetDim(0), tensorListLength);
+    }
+    int64_t preGroupList = 0;
+    for (size_t i = 0; i < tensorListLength; ++i) {
+        // Check dims
+        CHECK_COND((*gmmParams.x)[i] != nullptr, ACLNN_ERR_PARAM_INVALID, "x[%lu] is null, which is not supported.", i);
+        CHECK_COND((*gmmParams.weight)[i] != nullptr, ACLNN_ERR_PARAM_INVALID, "weight[%lu] is null, which is not supported.", i);
+        CHECK_COND(IsTransposeLastTwoDims((*gmmParams.weight)[i]) == gmmParams.transposeWeight, ACLNN_ERR_PARAM_INVALID,
+                   "The transpose state must be the same for each tensor in weight.");
+        CHECK_COND((*gmmParams.y)[i] != nullptr, ACLNN_ERR_PARAM_INVALID, "y[%lu] is null, which is not supported.", i);
+        CHECK_COND(CheckFormat((*gmmParams.x)[i], "x", i) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, "Invalid format.");
+        CHECK_COND(CheckFormat((*gmmParams.weight)[i], "weight", i) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, "Invalid format.");
+        CHECK_COND(CheckFormat((*gmmParams.y)[i], "y", i) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, "Invalid format.");
+        size_t xDimNum = (*gmmParams.x)[i]->GetViewShape().GetDimNum();
+        size_t weightDimNum = (*gmmParams.weight)[i]->GetViewShape().GetDimNum();
+        size_t yDimNum = (*gmmParams.y)[i]->GetViewShape().GetDimNum();
+        CHECK_COND(xDimNum <= MAX_FM_DIM && xDimNum >= MIN_FM_DIM, ACLNN_ERR_PARAM_INVALID,
+                   "x[%lu] dimNum is %lu, but only 2-6 is supported.", i, xDimNum);
+        CHECK_COND(weightDimNum == SEPARATED_WEIGHT_DIM, ACLNN_ERR_PARAM_INVALID,
+                   "weight[%lu] dimNum is %lu, but only 2 is supported when weight is separated.", i, weightDimNum);
+        CHECK_COND(xDimNum == yDimNum, ACLNN_ERR_PARAM_INVALID,
+                   "y[%lu] dimNum %lu should be equal to x[%lu] dimNum %lu.", i, yDimNum, i, xDimNum);
+        // If the interface is not V1 and x dim num > 2, groupList must be empty.
+        if (xDimNum > MIN_FM_DIM) {
+            CHECK_COND(gmmParams.groupListOptional == nullptr, ACLNN_ERR_PARAM_INVALID,
+                       "groupListOptional should be nullptr when x, y both separated and dim num larger than 2.");
+        }
+        if (xDimNum == MIN_FM_DIM && gmmParams.groupListOptional != nullptr) {
+            int64_t xMDimValue = (*gmmParams.x)[i]->GetViewShape().GetDim(0);
+            std::string errorMessage = i == 0 ?
"groupListOptional[0]" : + "groupListOptional[" + std::to_string(i) + "] - groupListOptional[" + std::to_string(i - 1) + "]"; + CHECK_COND(xMDimValue == (*gmmParams.groupListOptional)[i] - preGoupList, ACLNN_ERR_PARAM_INVALID, + "x[%lu] dim 0 value %ld should be equal to %s %ld.", + i, xMDimValue, errorMessage.c_str(), (*gmmParams.groupListOptional)[i] - preGoupList); + preGoupList = (*gmmParams.groupListOptional)[i]; + } + } + return ACLNN_SUCCESS; +} + +static aclnnStatus CheckNotNull(const aclTensorList *x, const aclTensorList *weight, const aclTensorList *y) { + CHECK_COND(x != nullptr, ACLNN_ERR_PARAM_NULLPTR, "x must not be nullptr."); + CHECK_COND(weight != nullptr, ACLNN_ERR_PARAM_NULLPTR, "weight must not be nullptr."); + CHECK_COND(y != nullptr, ACLNN_ERR_PARAM_NULLPTR, "y must not be nullptr."); + return ACLNN_SUCCESS; +} + +static aclnnStatus CheckGroupListCommonIntArray(const GroupedMatmulParams &gmmParams, const bool isRequiredGroupList, + const size_t groupNum, int64_t &groupListLastValue) { + // Must pass groupList scenario, check groupList is not empty. + CHECK_COND(gmmParams.groupListOptional != nullptr || !isRequiredGroupList, ACLNN_ERR_PARAM_NULLPTR, + "groupListOptional required in this case, but get nullptr."); + if (gmmParams.groupListOptional != nullptr) { + // groupList must be an ascending sequence. + uint64_t groupListSize = gmmParams.groupListOptional->Size(); + CHECK_COND(groupListSize <= MAX_GROUP_LIST_SIZE_ARRAY, ACLNN_ERR_PARAM_INVALID, + "When groupList type is int array, size of groupList %lu should be less than or equal to %ld.", + groupListSize, MAX_GROUP_LIST_SIZE_ARRAY); + int64_t preGoupList = 0; + for (size_t i = 0; i < groupListSize; i++) { + CHECK_COND((*gmmParams.groupListOptional)[i] >= preGoupList, ACLNN_ERR_PARAM_INVALID, + "groupListOptional should be non-negative and incremental."); + preGoupList = (*gmmParams.groupListOptional)[i]; + } + // Check groupList length matches other tensor lists. 
+        CHECK_COND((groupListSize == groupNum && groupNum > 1) || groupNum == 1, ACLNN_ERR_PARAM_INVALID,
+                   "When groupList is not null, size of groupList %lu should be equal to groupNum %lu.",
+                   groupListSize, groupNum);
+        groupListLastValue = preGroupList;
+    }
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus CheckGroupListCommonTensor(const GroupedMatmulParams &gmmParams, const bool isRequiredGroupList,
+                                              const size_t groupNum) {
+    CHECK_COND(!(gmmParams.groupTensorOptional == nullptr && isRequiredGroupList), ACLNN_ERR_PARAM_INVALID,
+               "groupListOptional(tensor) is required in this case, but got nullptr.");
+    if (gmmParams.groupTensorOptional != nullptr) {
+        int64_t groupListSize = gmmParams.groupTensorOptional->GetViewShape().GetDim(0);
+        CHECK_COND(groupListSize <= MAX_GROUP_LIST_SIZE_TENSOR, ACLNN_ERR_PARAM_INVALID,
+                   "When groupList type is tensor, size of groupList %ld should be less than or equal to %ld.",
+                   groupListSize, MAX_GROUP_LIST_SIZE_TENSOR);
+        CHECK_COND((groupListSize == static_cast<int64_t>(groupNum) && groupNum > 1) || groupNum == 1, ACLNN_ERR_PARAM_INVALID,
+                   "When groupList is not null, size of groupList(tensor) %ld should be equal to groupNum %lu.",
+                   groupListSize, groupNum);
+        CHECK_COND(gmmParams.groupTensorOptional->GetDataType() == DataType::DT_INT64, ACLNN_ERR_PARAM_INVALID,
+                   "Invalid dtype: Only int64 is supported for groupList.");
+    }
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus CheckGroupListSplitK(const GroupedMatmulParams &gmmParams, const bool isRequiredGroupList,
+                                        const bool xSeparated, const bool weightSeparated, const size_t groupNum) {
+    int64_t groupListLastValue = 0;
+    if (gmmParams.apiVersion == GMMApiVersion::V4 || gmmParams.apiVersion == GMMApiVersion::V3) {
+        CHECK_COND(CheckGroupListCommonTensor(gmmParams, isRequiredGroupList, groupNum) == ACLNN_SUCCESS,
+                   ACLNN_ERR_PARAM_INVALID, "CheckGroupListCommonTensor failed.");
+        return ACLNN_SUCCESS;
+    }
+    CHECK_COND(
+        CheckGroupListCommonIntArray(gmmParams, isRequiredGroupList, groupNum, groupListLastValue) == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "CheckGroupListCommonIntArray failed.");
+    if (gmmParams.groupListOptional != nullptr) {
+        if (xSeparated) {
+            int64_t preOffset = 0;
+            // Check that the increment in groupList matches x's k.
+            for (size_t i = 0; i < groupNum; i++) {
+                int64_t xKDimValue = (*gmmParams.x)[i]->GetViewShape().GetDim(1);
+                std::string errorMessage = i == 0 ? "groupListOptional[0]" :
+                    "groupListOptional[" + std::to_string(i) + "] - groupListOptional[" + std::to_string(i - 1) + "]";
+                CHECK_COND(xKDimValue == (*gmmParams.groupListOptional)[i] - preOffset, ACLNN_ERR_PARAM_INVALID,
+                           "x[%lu] dim 1 value %ld should be equal to %s %ld.",
+                           i, xKDimValue, errorMessage.c_str(), (*gmmParams.groupListOptional)[i] - preOffset);
+                preOffset = (*gmmParams.groupListOptional)[i];
+            }
+        } else if (weightSeparated) {
+            int64_t preOffset = 0;
+            // Check that the increment in groupList matches weight's k.
+            for (size_t i = 0; i < groupNum; i++) {
+                int64_t weightKDimValue = (*gmmParams.weight)[i]->GetViewShape().GetDim(0);
+                std::string errorMessage = i == 0 ?
"groupListOptional[0]" : + "groupListOptional[" + std::to_string(i) + "] - groupListOptional[" + std::to_string(i - 1) + "]"; + CHECK_COND(weightKDimValue == (*gmmParams.groupListOptional)[i] - preOffset, ACLNN_ERR_PARAM_INVALID, + "weight[%lu] dim 0 %lu value should be equal to %s %ld.", + i, weightKDimValue, errorMessage.c_str(), (*gmmParams.groupListOptional)[i] - preOffset); + preOffset = (*gmmParams.groupListOptional)[i]; + } + } else { + CHECK_COND((*gmmParams.x)[0]->GetViewShape().GetDim(1) == groupListLastValue, ACLNN_ERR_PARAM_INVALID, + "When splited axis is K, the last value of group list(%ld) must equal with x shape[1] (%lu).", + groupListLastValue, (*gmmParams.x)[0]->GetViewShape().GetDim(1)); + } + } + return ACLNN_SUCCESS; +} + +static aclnnStatus CheckGroupListSplitM(const GroupedMatmulParams &gmmParams, const bool isRequiredGroupList, + const bool xSeparated, const bool ySeparated, const size_t groupNum) { + int64_t groupListLastValue = 0; + if (gmmParams.apiVersion == GMMApiVersion::V3 || gmmParams.apiVersion == GMMApiVersion::V4) { + CHECK_COND(CheckGroupListCommonTensor(gmmParams, isRequiredGroupList, groupNum) == ACLNN_SUCCESS, + ACLNN_ERR_PARAM_INVALID, "CheckGroupListCommonTensor failed."); + return ACLNN_SUCCESS; + } + CHECK_COND( + CheckGroupListCommonIntArray(gmmParams, isRequiredGroupList, groupNum, groupListLastValue) == ACLNN_SUCCESS, + ACLNN_ERR_PARAM_INVALID, "CheckGroupListCommonIntArray failed."); + if (gmmParams.groupListOptional != nullptr) { + if (xSeparated) { + int64_t preGoupList = 0; + // Check the increment in groupList matches x's m. + for (size_t i = 0; i < groupNum; i++) { + int64_t xMDimValue = (*gmmParams.x)[i]->GetViewShape().GetDim(0); + std::string errorMessage = i == 0 ? "groupListOptional[0]" : + "groupListOptional[" + std::to_string(i) + "] - groupListOptional[" + std::to_string(i - 1) + "]"; + CHECK_COND(xMDimValue == (*gmmParams.groupListOptional)[i] - preGoupList, ACLNN_ERR_PARAM_INVALID, + "x[%lu] dim 0 value %lu should be equal to %s %ld.", + i, xMDimValue, errorMessage.c_str(), (*gmmParams.groupListOptional)[i] - preGoupList); + preGoupList = (*gmmParams.groupListOptional)[i]; + } + } else if (ySeparated) { + int64_t preGoupList = 0; + // Check the increment in groupList matches y's m. + for (size_t i = 0; i < groupNum; i++) { + int64_t yMDimValue = (*gmmParams.y)[i]->GetViewShape().GetDim(0); + std::string errorMessage = i == 0 ? "groupListOptional[0]" : + "groupListOptional[" + std::to_string(i) + "] - groupListOptional[" + std::to_string(i - 1) + "]"; + CHECK_COND(yMDimValue == (*gmmParams.groupListOptional)[i] - preGoupList, ACLNN_ERR_PARAM_INVALID, + "y[%lu] dim 0 value %lu should be equal to %s %ld.", + i, yMDimValue, errorMessage.c_str(), (*gmmParams.groupListOptional)[i] - preGoupList); + preGoupList = (*gmmParams.groupListOptional)[i]; + } + } else { + CHECK_COND((*gmmParams.x)[0]->GetViewShape().GetDim(0) == groupListLastValue, ACLNN_ERR_PARAM_INVALID, + "When splited axis is M, the last value of group list(%ld) must equal with x shape[0] (%lu).", + groupListLastValue, (*gmmParams.x)[0]->GetViewShape().GetDim(0)); + } + } + return ACLNN_SUCCESS; +} + +static uint64_t GetGroupSize(const GroupedMatmulParams &gmmParams) { + // When X is already split, or in scenarios where splititem is 0 or 2, X input is pre-grouped, + // and group size can be obtained from X. 
+    if (gmmParams.x->Size() > 1) {
+        return gmmParams.x->Size();
+    }
+    if (gmmParams.weight->Size() > 1) {
+        return gmmParams.weight->Size();
+    }
+    if (gmmParams.y->Size() > 1) {
+        return gmmParams.y->Size();
+    }
+    if (gmmParams.groupListOptional != nullptr) {
+        return gmmParams.groupListOptional->Size();
+    }
+    if (gmmParams.groupTensorOptional != nullptr) {
+        return gmmParams.groupTensorOptional->GetViewShape().GetDim(0);
+    }
+    // If groupList is also null, weight must provide the split info for x, which must then be grouped along k.
+    return 1;
+}
+
+static aclnnStatus CheckDimNumAndPerGroupNum(bool isAntiquantInt4, std::tuple<size_t, size_t> tensorDimNums,
+                                             const gert::Shape& tensorShape, const std::string& tensorType, int64_t weightKDimValue) {
+    size_t tensorDimNum = std::get<0>(tensorDimNums);
+    size_t expectedDimNum = std::get<1>(tensorDimNums);  // 1: the second element
+    if (isAntiquantInt4) {
+        if (tensorDimNum == expectedDimNum) {
+            int64_t perGroupNum = tensorShape.GetDim(tensorDimNum - 2);
+            CHECK_COND(perGroupNum > 0 && weightKDimValue % perGroupNum == 0, ACLNN_ERR_PARAM_INVALID,
+                       "perGroupNum must be larger than 0, and K[%ld] must be evenly divisible by it in the A16W4-pergroup case,"
+                       " but now perGroupNum is %ld.", weightKDimValue, perGroupNum);
+        } else {
+            CHECK_COND(tensorDimNum == expectedDimNum - 1, ACLNN_ERR_PARAM_INVALID,
+                       "%s Dim must be %zu in A16W4-perchannel case, but now is %zu.",
+                       tensorType.c_str(), expectedDimNum - 1, tensorDimNum);
+        }
+    } else {
+        CHECK_COND(tensorDimNum == expectedDimNum - 1, ACLNN_ERR_PARAM_INVALID,
+                   "%s Dim must be %zu, but now is %zu.",
+                   tensorType.c_str(), expectedDimNum - 1, tensorDimNum);
+    }
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus CheckOptionalTensorList(const GroupedMatmulParams &gmmParams, const aclTensorList *tensorList,
+                                           const std::string& tensorType) {
+    // Check the lengths, dims and shapes of bias, scale, antiquant scale and antiquant offset.
+    uint64_t numTotal = GetGroupSize(gmmParams);
+    uint64_t tensorSize = tensorList->Size();
+    uint64_t weightGroupedSize = gmmParams.weight->Size();
+    auto w0Shape = (*gmmParams.weight)[0]->GetViewShape();
+    uint64_t weightNDimIdx = w0Shape.GetDimNum() - 1;
+    int64_t weightKDimValue = w0Shape.GetDim(w0Shape.GetDimNum() - 2);  // -2: k axis offset
+    // Check tensorList length matches weight.
+    CHECK_COND(tensorSize == weightGroupedSize, ACLNN_ERR_PARAM_INVALID, "%s size[%lu] must be "
+               "equal to weight size[%lu].", tensorType.c_str(), tensorSize, weightGroupedSize);
+    DataType w0Dtype = (*gmmParams.weight)[0]->GetDataType();
+    bool isAntiquantInt4 = (w0Dtype == DataType::DT_INT4 && tensorType.find("antiquant") != std::string::npos);
+    if (gmmParams.isSingleWeight) {
+        // If weight is a single tensor, the optional tensor must also be a single tensor following weight.
+        // Its dim num is checked below: 3 for A16W4 per-group (E,G,N), otherwise 2 (E,N).
+        CHECK_COND((*tensorList)[0] != nullptr, ACLNN_ERR_PARAM_INVALID,
+                   "%s[0] must not be nullptr, but now is nullptr.", tensorType.c_str());
+        auto tensor0Shape = (*tensorList)[0]->GetViewShape();
+        size_t tensorDimNum = tensor0Shape.GetDimNum();
+        // 3: shape is (E,G,N), G is the perGroupNum
+        CHECK_COND(CheckDimNumAndPerGroupNum(isAntiquantInt4, {tensorDimNum, 3}, tensor0Shape, tensorType, weightKDimValue)
+                   == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, "CheckDimNumAndPerGroupNum failed.");
+        // Check the first dimension: batch size must match the group size.
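+        // For example (hypothetical A16W4 per-group case): a single weight of shape
+        // (E, k, n) pairs with an antiquant scale of shape (E, G, n), where G must
+        // divide k evenly; the per-channel variant instead uses shape (E, n).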
+        uint64_t batchSize = tensor0Shape.GetDim(0);
+        CHECK_COND(batchSize == numTotal, ACLNN_ERR_PARAM_INVALID, "%s batch size[%lu] should be equal "
+                   "to groupList length[%lu].", tensorType.c_str(), batchSize, numTotal);
+        // Check that the tensor's N dim matches weight's N dim.
+        int64_t weightNDimValue = w0Shape.GetDim(weightNDimIdx);
+        int64_t tensorNDimValue = tensor0Shape.GetDim(tensorDimNum - 1);
+        CHECK_COND(tensorNDimValue == weightNDimValue, ACLNN_ERR_PARAM_INVALID,
+                   "NDim[%ld] of %s should be equal to NDim[%ld] of weight.",
+                   tensorNDimValue, tensorType.c_str(), weightNDimValue);
+    } else {
+        for (uint64_t i = 0; i < numTotal; i++) {
+            CHECK_COND((*tensorList)[i] != nullptr, ACLNN_ERR_PARAM_INVALID,
+                       "%s[%lu] must not be nullptr, but now is nullptr.", tensorType.c_str(), i);
+            // If weight is not a single tensor, each tensor's dim num is 2 for A16W4 per-group (G,N), otherwise 1 (N).
+            auto tensorShape = (*tensorList)[i]->GetViewShape();
+            size_t tensorDimNum = tensorShape.GetDimNum();
+            auto wShape = (*gmmParams.weight)[i]->GetViewShape();
+            // 2: shape is (G,N), G is the perGroupNum
+            CHECK_COND(CheckDimNumAndPerGroupNum(isAntiquantInt4, {tensorDimNum, 2}, tensorShape, tensorType,
+                       wShape.GetDim(0)) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, "CheckDimNumAndPerGroupNum failed.");
+            // Check that the N dim of each group's tensor matches the N dim of the same group's weight.
+            int64_t weightNDimValue = wShape.GetDim(weightNDimIdx);
+            int64_t tensorNDimValue = tensorShape.GetDim(tensorDimNum - 1);
+            CHECK_COND(tensorNDimValue == weightNDimValue, ACLNN_ERR_PARAM_INVALID,
+                       "NDim[%ld] of %s[%lu] should be equal to NDim[%ld] of weight[%lu].",
+                       tensorNDimValue, tensorType.c_str(), i, weightNDimValue, i);
+        }
+    }
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus CheckPerTokenScale(const GroupedMatmulParams &gmmParams) {
+    // Check per-token scale length, tensor dims and shape.
+    uint64_t perTokenScaleSize = gmmParams.perTokenScaleOptional->Size();
+    uint64_t xGroupedSize = gmmParams.x->Size();
+    uint64_t weightGroupedSize = gmmParams.weight->Size();
+    uint64_t yGroupedSize = gmmParams.y->Size();
+    uint64_t xMDimIdx = 0;
+    // Check that the length of per-token scale matches x.
+    if (xGroupedSize == 1 && weightGroupedSize == 1 && yGroupedSize == 1) {
+        CHECK_COND(perTokenScaleSize == xGroupedSize && perTokenScaleSize == 1, ACLNN_ERR_PARAM_INVALID,
+                   "perTokenScaleOptional size[%zu] must be 1 and equal to x size[%zu].",
+                   perTokenScaleSize, xGroupedSize);
+        CHECK_COND((*gmmParams.perTokenScaleOptional)[0] != nullptr, ACLNN_ERR_PARAM_INVALID,
+                   "perTokenScaleOptional[0] must not be nullptr, but now is nullptr.");
+        // If x is a single tensor, per-token scale must also be a single tensor following x,
+        // and its dim num must be 1.
+        size_t tensorDimNum = (*gmmParams.perTokenScaleOptional)[0]->GetViewShape().GetDimNum();
+        CHECK_COND(tensorDimNum == 1, ACLNN_ERR_PARAM_INVALID,
+                   "perTokenScaleOptional dim num must be 1 when x is single tensor, but now is %zu.", tensorDimNum);
+        // Check that the shape of per-token scale matches x's M dim.
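+        // For example, x[0] of shape (M, K) requires perTokenScaleOptional[0] of
+        // shape (M,): one float scale per token row.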
+        int64_t xMDimValue = (*gmmParams.x)[0]->GetViewShape().GetDim(xMDimIdx);
+        int64_t tensorMDimValue = (*gmmParams.perTokenScaleOptional)[0]->GetViewShape().GetDim(tensorDimNum - 1);
+        CHECK_COND(tensorMDimValue == xMDimValue, ACLNN_ERR_PARAM_INVALID,
+            "MDim[%ld] of perTokenScaleOptional should be equal to MDim[%ld] of x.",
+            tensorMDimValue, xMDimValue);
+    } else {
+        OP_LOGE(ACLNN_ERR_PARAM_INVALID, "the per-token quant case is only supported "
+            "when x, weight and y are all single tensors, but now x size is %zu, weight size is %zu, y size is %zu",
+            xGroupedSize, weightGroupedSize, yGroupedSize);
+        return ACLNN_ERR_PARAM_INVALID;
+    }
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus CheckTensorListDataType(const aclTensorList *tensorList, const DataType dtype) {
+    for (size_t i = 0; i < tensorList->Size(); i++) {
+        const aclTensor* tensor = (*tensorList)[i];
+        OP_CHECK_NULL(tensor, continue);
+        OP_CHECK_DTYPE_NOT_MATCH(tensor, dtype, return ACLNN_ERR_PARAM_INVALID);
+    }
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus CheckMatmulDataType(const GroupedMatmulParams &gmmParams, const DataType xDtype,
+    const DataType weightDtype, const DataType yDtype, const DataType biasDtype) {
+    CHECK_COND(CheckTensorListDataType(gmmParams.x, xDtype) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "GMM: x dtype does not match the required dtype[%s].", op::ToString(xDtype).GetString());
+    CHECK_COND(CheckTensorListDataType(gmmParams.weight, weightDtype) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "GMM: weight dtype does not match the required dtype[%s].", op::ToString(weightDtype).GetString());
+    CHECK_COND(CheckTensorListDataType(gmmParams.y, yDtype) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "GMM: y dtype does not match the required dtype[%s].", op::ToString(yDtype).GetString());
+    if (gmmParams.biasOptional != nullptr) {
+        CHECK_COND(CheckTensorListDataType(gmmParams.biasOptional, biasDtype) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+            "GMM: bias dtype does not match the required dtype[%s].", op::ToString(biasDtype).GetString());
+    }
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus IsGmmQuantEmpty(const GroupedMatmulParams &gmmParams) {
+    CHECK_RET(gmmParams.scaleOptional == nullptr, ACLNN_ERR_PARAM_INVALID);
+    CHECK_RET(gmmParams.offsetOptional == nullptr, ACLNN_ERR_PARAM_INVALID);
+    CHECK_RET(gmmParams.perTokenScaleOptional == nullptr, ACLNN_ERR_PARAM_INVALID);
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus IsGmmAntiQuantEmpty(const GroupedMatmulParams &gmmParams) {
+    CHECK_RET(gmmParams.antiquantScaleOptional == nullptr, ACLNN_ERR_PARAM_INVALID);
+    CHECK_RET(gmmParams.antiquantOffsetOptional == nullptr, ACLNN_ERR_PARAM_INVALID);
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus CheckNonQuant(const GroupedMatmulParams &gmmParams) {
+    CHECK_COND(IsGmmQuantEmpty(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "Detected nonquant, but the quant inputs are not empty!");
+    CHECK_COND(IsGmmAntiQuantEmpty(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "Detected nonquant, but the antiquant inputs are not empty!");
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus CheckQuantParamsDtype(const GroupedMatmulParams &gmmParams, bool isPerTokenQuant) {
+    DataType yDtype = (*gmmParams.y)[0]->GetDataType();
+    for (size_t i = 0; i < gmmParams.scaleOptional->Size(); i++) {
+        DataType scaleDtype = (*gmmParams.scaleOptional)[i]->GetDataType();
+        if (isPerTokenQuant) {
+            bool isOutputBF16 = scaleDtype == DataType::DT_BF16 && yDtype == DataType::DT_BF16;
+            bool isOutputFloat16 = scaleDtype == DataType::DT_FLOAT && yDtype == DataType::DT_FLOAT16;
+            CHECK_COND(isOutputBF16 || isOutputFloat16, ACLNN_ERR_PARAM_INVALID,
+                "the per-token quant case only supports scale of data type bfloat16 with output data type bfloat16, "
+                "or scale of data type float32 with output data type float16,"
+                " but now scale[%zu] has data type %s and the output has data type %s!",
+                i, op::ToString(scaleDtype).GetString(), op::ToString(yDtype).GetString());
+        } else {
+            bool isOutputInt8 = (scaleDtype == DataType::DT_INT64 || scaleDtype == DataType::DT_UINT64) &&
+                yDtype == DataType::DT_INT8;
+            bool isOutputBF16 = scaleDtype == DataType::DT_BF16 && yDtype == DataType::DT_BF16;
+            bool isOutputFP16 = scaleDtype == DataType::DT_FLOAT && yDtype == DataType::DT_FLOAT16;
+            CHECK_COND(isOutputInt8 || isOutputBF16 || isOutputFP16, ACLNN_ERR_PARAM_INVALID,
+                "the per-channel quant case only supports scale of data type int64/uint64 when the output is int8, "
+                "or data type bfloat16 when the output is bfloat16, "
+                "or data type float32 when the output is float16, "
+                "but scale[%zu] has data type %s and the output has data type %s!",
+                i, op::ToString(scaleDtype).GetString(), op::ToString(yDtype).GetString());
+        }
+    }
+    if (isPerTokenQuant) {
+        for (size_t i = 0; i < gmmParams.perTokenScaleOptional->Size(); i++) {
+            DataType perTokenScaleDtype = (*gmmParams.perTokenScaleOptional)[i]->GetDataType();
+            CHECK_COND(perTokenScaleDtype == DataType::DT_FLOAT, ACLNN_ERR_PARAM_INVALID,
+                "the per-token quant case only supports perTokenScale with data type float32, "
+                "but perTokenScale[%zu] has data type %s!", i, op::ToString(perTokenScaleDtype).GetString());
+        }
+    }
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus CheckGroupedMatmulQuant(const GroupedMatmulParams &gmmParams) {
+    bool is310P = GetCurrentPlatformInfo().GetSocVersion() == SocVersion::ASCEND310P;
+    CHECK_COND(!is310P, ACLNN_ERR_PARAM_INVALID,
+        "GMM: quant cases are not supported on Ascend310P.");
+    CHECK_COND(gmmParams.groupType != SPLIT_K, ACLNN_ERR_PARAM_INVALID,
+        "GMM: quant cases do not support splitting along the K axis.");
+    CHECK_COND(gmmParams.offsetOptional == nullptr, ACLNN_ERR_PARAM_INVALID,
+        "GMM: offset must be nullptr in quant, but now is not nullptr.");
+    CHECK_COND(gmmParams.scaleOptional != nullptr, ACLNN_ERR_PARAM_INVALID,
+        "GMM: scale must not be nullptr in quant, but now is nullptr.");
+    CHECK_COND(CheckOptionalTensorList(gmmParams, gmmParams.scaleOptional, "scale") == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "Invalid scale.");
+    bool isPerTokenQuant = gmmParams.perTokenScaleOptional != nullptr;
+    CHECK_COND(CheckQuantParamsDtype(gmmParams, isPerTokenQuant) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "Check quant params data type failed!");
+    if (isPerTokenQuant) {
+        CHECK_COND(CheckPerTokenScale(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+            "Check perTokenScale failed!");
+    }
+    CHECK_COND(IsGmmAntiQuantEmpty(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "Detected quant, but the antiquant inputs are not empty!");
+    return ACLNN_SUCCESS;
+}
+
+static int64_t GetPergroupSize(const GroupedMatmulParams &gmmParams, size_t w0DimNum, const gert::Shape& wShape, const gert::Shape& shape) {
+    int64_t pergroupSize = 0;
+    size_t shapeDimNum = shape.GetDimNum();
+    if (gmmParams.isSingleWeight && w0DimNum > SEPARATED_WEIGHT_DIM) {  // antiquant param shape (E, N) or (E, G, N)
+        if (shapeDimNum > SEPARATED_WEIGHT_DIM) {
+            int64_t k = wShape.GetDim(1);
+            pergroupSize = k / shape.GetDim(shapeDimNum - 2);  // 2: the second-to-last index
+        }
+    } else {  // antiquant param shape (N) or (G, N)
+        if (shapeDimNum > 1) {
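+            // Illustrative example (editor's sketch, values not from the original patch): with a
+            // separated weight of viewShape (K, N) = (1024, 2048) and an antiquant param of viewShape
+            // (G, N) = (8, 2048), this branch returns pergroupSize = 1024 / 8 = 128, i.e. each
+            // scale/offset row covers 128 consecutive elements along K.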
+            int64_t k = wShape.GetDim(0);
+            pergroupSize = k / shape.GetDim(shapeDimNum - 2);  // 2: the second-to-last index
+        }
+    }
+    return pergroupSize;
+}
+
+static aclnnStatus CheckGroupedMatmulAntiQuant(const GroupedMatmulParams &gmmParams) {
+    CHECK_COND(GetCurrentPlatformInfo().GetSocVersion() != SocVersion::ASCEND310P, ACLNN_ERR_PARAM_INVALID,
+        "GMM: antiquant cases are not supported on Ascend310P.");
+    CHECK_COND(gmmParams.groupType != SPLIT_K, ACLNN_ERR_PARAM_INVALID,
+        "GMM: antiquant cases do not support splitting along the K axis.");
+    CHECK_COND(gmmParams.antiquantScaleOptional != nullptr, ACLNN_ERR_PARAM_INVALID,
+        "GMM: antiquantScale must not be nullptr in antiquant, but now is nullptr.");
+    CHECK_COND(gmmParams.antiquantOffsetOptional != nullptr, ACLNN_ERR_PARAM_INVALID,
+        "GMM: antiquantOffset must not be nullptr in antiquant, but now is nullptr.");
+    // Check the shape of antiquantScale and antiquantOffset.
+    CHECK_COND(CheckOptionalTensorList(gmmParams, gmmParams.antiquantScaleOptional, "antiquantScale") == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "Invalid antiquantScale");
+    CHECK_COND(CheckOptionalTensorList(gmmParams, gmmParams.antiquantOffsetOptional, "antiquantOffset") == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "Invalid antiquantOffset");
+    DataType w0Dtype = (*gmmParams.weight)[0]->GetDataType();
+    // Check perGroupNum.
+    bool isAntiquantInt4 = w0Dtype == DataType::DT_INT4;
+    if (isAntiquantInt4) {
+        auto antiquantScale0Shape = (*gmmParams.antiquantScaleOptional)[0]->GetViewShape();
+        size_t antiquantScale0DimNum = antiquantScale0Shape.GetDimNum();
+        auto w0Shape = (*gmmParams.weight)[0]->GetViewShape();
+        size_t w0DimNum = w0Shape.GetDimNum();
+        int64_t pergroupSize = GetPergroupSize(gmmParams, w0DimNum, w0Shape, antiquantScale0Shape);
+        CHECK_COND(!gmmParams.transposeWeight || pergroupSize % 2 == 0, ACLNN_ERR_PARAM_INVALID,  // 2: a factor
+            "pergroupSize should be even when weight is transposed in the A16W4-pergroup case, but now is %ld", pergroupSize);
+        for (size_t i = 0; i < gmmParams.antiquantScaleOptional->Size(); ++i) {
+            auto antiquantScaleShape = (*gmmParams.antiquantScaleOptional)[i]->GetViewShape();
+            auto antiquantOffsetShape = (*gmmParams.antiquantOffsetOptional)[i]->GetViewShape();
+            size_t antiquantScaleDimNum = antiquantScaleShape.GetDimNum();
+            size_t antiquantOffsetDimNum = antiquantOffsetShape.GetDimNum();
+            CHECK_COND(antiquantScaleDimNum == antiquantScale0DimNum && antiquantScale0DimNum == antiquantOffsetDimNum,
+                ACLNN_ERR_PARAM_INVALID, "antiquantScale[%zu]'s dim num[%zu] is not equal to the first tensor's dim"
+                " num[%zu], or antiquantOffset[%zu]'s dim num[%zu] is not equal to antiquantScale[0]'s dim num[%zu]",
+                i, antiquantScaleDimNum, antiquantScale0DimNum, i, antiquantOffsetDimNum, antiquantScale0DimNum);
+            auto wShape = (*gmmParams.weight)[i]->GetViewShape();
+            int64_t pergroupSizeOfScale = GetPergroupSize(gmmParams, w0DimNum, wShape, antiquantScaleShape);
+            int64_t pergroupSizeOfOffset = GetPergroupSize(gmmParams, w0DimNum, wShape, antiquantOffsetShape);
+            CHECK_COND(pergroupSizeOfScale == pergroupSize && pergroupSizeOfOffset == pergroupSize, ACLNN_ERR_PARAM_INVALID,
+                "antiquantScale[%zu]'s pergroup size[%ld] or antiquantOffset[%zu]'s pergroup size[%ld] "
+                "is not the required value[%ld]", i, pergroupSizeOfScale, i, pergroupSizeOfOffset, pergroupSize);
+        }
+    }
+    CHECK_COND(CheckTensorListDataType(gmmParams.antiquantScaleOptional, gmmParams.xDtype) == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "GMM: antiquantScale dtype does not match x dtype[%s].",
+        op::ToString(gmmParams.xDtype).GetString());
+    CHECK_COND(CheckTensorListDataType(gmmParams.antiquantOffsetOptional, gmmParams.xDtype) == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "GMM: antiquantOffset dtype does not match x dtype[%s].",
+        op::ToString(gmmParams.xDtype).GetString());
+    CHECK_COND(IsGmmQuantEmpty(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "Detected antiquant, but the quant inputs are not empty!");
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus CheckFunctionQuantParams(const GroupedMatmulParams &gmmParams) {
+    DataType yDtypeOrg = (*gmmParams.y)[0]->GetDataType();
+    for (size_t i = 0; i < gmmParams.y->Size(); i++) {
+        const aclTensor* yTensor = (*gmmParams.y)[i];
+        OP_CHECK_NULL(yTensor, continue);
+        DataType yDtype = yTensor->GetDataType();
+        CHECK_COND(yDtype == yDtypeOrg, ACLNN_ERR_PARAM_INVALID,
+            "the output tensor list has different data types: y[0] data type is %s, and y[%zu] data type is %s.",
+            op::ToString(yDtypeOrg).GetString(), i, op::ToString(yDtype).GetString());
+        if (!(yDtype == DataType::DT_INT8 || yDtype == DataType::DT_BF16 || yDtype == DataType::DT_FLOAT16)) {
+            OP_LOGE(ACLNN_ERR_PARAM_INVALID, "Expect yDtype to be int8, float16 or bfloat16 in the quant case, "
+                "but now y[%zu] dtype is %s", i, op::ToString(yDtype).GetString());
+            return ACLNN_ERR_PARAM_INVALID;
+        }
+    }
+    if (gmmParams.biasOptional != nullptr) {
+        CHECK_COND(CheckTensorListDataType(gmmParams.biasOptional, DataType::DT_INT32) == ACLNN_SUCCESS,
+            ACLNN_ERR_PARAM_INVALID, "GMM: bias dtype does not match the required dtype int32.");
+    }
+    CHECK_COND(CheckGroupedMatmulQuant(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "CheckGroupedMatmulQuant failed.");
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus CheckFunctionParams(const GroupedMatmulParams &gmmParams) {
+    DataType weightDtype = (*gmmParams.weight)[0]->GetDataType();
+    bool isNoActivation = gmmParams.activeType == GMMActType::GMM_ACT_TYPE_NONE;
+
+    if (GetCurrentPlatformInfo().GetSocVersion() == SocVersion::ASCEND310P) {
+        bool isAllInputFP16 = gmmParams.xDtype == DataType::DT_FLOAT16 && weightDtype == DataType::DT_FLOAT16;
+        if (gmmParams.biasOptional != nullptr) {
+            isAllInputFP16 = isAllInputFP16 && (*gmmParams.biasOptional)[0]->GetDataType() == DataType::DT_FLOAT16;
+        }
+        CHECK_COND(isAllInputFP16, ACLNN_ERR_PARAM_INVALID, "Only float16 is supported on Ascend310P platforms.");
+        CHECK_COND(isNoActivation, ACLNN_ERR_PARAM_INVALID, "Activation is not supported on Ascend310P platforms.");
+    }
+
+    if ((gmmParams.xDtype == DataType::DT_BF16 || gmmParams.xDtype == DataType::DT_FLOAT16 ||
+        gmmParams.xDtype == DataType::DT_FLOAT) && gmmParams.xDtype == weightDtype) {
+        if (gmmParams.apiVersion == GMMApiVersion::V1) {
+            CHECK_COND(gmmParams.xDtype != DataType::DT_FLOAT, ACLNN_ERR_PARAM_INVALID,
+                "aclnnGroupedMatmul does not support x or weight dtype float32.");
+        }
+        DataType biasDtype = gmmParams.xDtype == DataType::DT_BF16 ? DataType::DT_FLOAT : gmmParams.xDtype;
+        CHECK_RET(CheckMatmulDataType(gmmParams, gmmParams.xDtype, weightDtype, gmmParams.xDtype, biasDtype) == ACLNN_SUCCESS,
+            ACLNN_ERR_PARAM_INVALID);
+        CHECK_COND(isNoActivation, ACLNN_ERR_PARAM_INVALID, "the non-quant case does not support activation.");
+        return CheckNonQuant(gmmParams);
+    }
+    if (gmmParams.xDtype == DataType::DT_INT8 && weightDtype == DataType::DT_INT8) {
+        // quant
+        DataType yDtype = (*gmmParams.y)[0]->GetDataType();
+        CHECK_COND(isNoActivation || yDtype != DataType::DT_INT8,
+            ACLNN_ERR_PARAM_INVALID, "the quant case with output dtype int8 does not support activation.");
+        CHECK_COND(CheckFunctionQuantParams(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+            "CheckFunctionQuantParams failed.");
+        return ACLNN_SUCCESS;
+    }
+    if ((gmmParams.xDtype == DataType::DT_BF16 || gmmParams.xDtype == DataType::DT_FLOAT16)
+        && (weightDtype == DataType::DT_INT8 || weightDtype == DataType::DT_INT4)) {
+        // antiquant
+        DataType biasDtype = gmmParams.xDtype == DataType::DT_BF16 ? DataType::DT_FLOAT : DataType::DT_FLOAT16;
+        CHECK_RET(
+            CheckMatmulDataType(gmmParams, gmmParams.xDtype, weightDtype, gmmParams.xDtype, biasDtype) == ACLNN_SUCCESS,
+            ACLNN_ERR_PARAM_INVALID);
+        CHECK_COND(isNoActivation, ACLNN_ERR_PARAM_INVALID, "the antiquant case does not support activation.");
+        return CheckGroupedMatmulAntiQuant(gmmParams);
+    }
+    OP_LOGE(ACLNN_ERR_PARAM_INVALID, "GMM: there is no matching xDtype and weightDtype pattern. "
+        "The case with x dtype %s and weight dtype %s is not supported.",
+        op::ToString(gmmParams.xDtype).GetString(), op::ToString(weightDtype).GetString());
+    return ACLNN_ERR_PARAM_INVALID;
+}
+
+static aclnnStatus CheckWeightShapeInnerAxisEven(const aclTensorList *tensorList, const size_t weightSize,
+    const int64_t innerAxisDimId) {
+    if ((*tensorList)[0]->GetDataType() == DataType::DT_INT4) {
+        for (size_t i = 0; i < weightSize; ++i) {
+            int64_t n = (*tensorList)[i]->GetViewShape().GetDim(innerAxisDimId);
+            // 2: an even factor
+            CHECK_COND(n % 2 == 0, ACLNN_ERR_PARAM_INVALID, "weight's inner axis size[%ld] is not even!", n);
+        }
+    }
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus SplitMSingleXSingleWeightSingleY(const GroupedMatmulParams &gmmParams) {
+    static const std::vector<std::string> TENSOR_X_WEIGHT{"x", "weight", "true"};
+    static const std::vector<std::string> TENSOR_X_Y{"x", "y", "false"};
+    static const std::vector<std::string> TENSOR_WEIGHT_Y{"weight", "y", "true"};
+    CHECK_COND(gmmParams.splitItem == X_SEPARATED || gmmParams.splitItem == NO_SEPARATED, ACLNN_ERR_PARAM_INVALID,
+        "When y is not separated, splitItem should be 2/3, but the current splitItem is %ld.", gmmParams.splitItem);
+    // check dim
+    CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.x, MIN_FM_DIM, "x") == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "Dim num or format of tensor in tensor list x is invalid.");
+    CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.weight, SPLIT_M_SINGLE_WEIGHT_DIM, "weight") == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "Dim num or format of tensor in tensor list weight is invalid.");
+    CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.y, MIN_FM_DIM, "y") == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "Dim num or format of tensor in tensor list y is invalid.");
+    // check shape, x(m,k), weight(b,k,n), y(m,n)
+    int64_t innerAxisDimId = 1;  // x is never transposed, check the K axis
+    CHECK_COND(CheckShapeSameLengthTensorList(gmmParams.x, gmmParams.weight, {1, 1}, innerAxisDimId, TENSOR_X_WEIGHT) == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "k dim value of x and weight does not match.");
+    CHECK_COND(CheckShapeSameLengthTensorList(gmmParams.x, gmmParams.y, {0, 0}, -1, TENSOR_X_Y) == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "m dim value of x and y does not match.");
+    // 2: N axis index of weight. If w is not transposed, check the N axis; otherwise, check the K axis, which can be skipped.
+    innerAxisDimId = !gmmParams.transposeWeight ? 2 : -1;
+    CHECK_COND(CheckShapeSameLengthTensorList(gmmParams.weight, gmmParams.y, {2, 1}, innerAxisDimId, TENSOR_WEIGHT_Y) == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "n dim value of weight and y does not match.");
+    CHECK_COND(CheckWeightShapeInnerAxisEven(gmmParams.weight, gmmParams.weight->Size(),
+        gmmParams.transposeWeight ? 1 : 2) == ACLNN_SUCCESS,  // 2: axis index
+        ACLNN_ERR_PARAM_INVALID, "w inner axis size should be even when weight is int4 dtype.");
+    // check groupList
+    size_t batchSizeWeight = (*gmmParams.weight)[0]->GetViewShape().GetDim(0);
+    CHECK_COND(CheckGroupListSplitM(gmmParams, true, false, false, batchSizeWeight) == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "Invalid groupList.");
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus SplitMSingleXSeparatedWeightSingleY(const GroupedMatmulParams &gmmParams) {
+    size_t weightSize = gmmParams.weight->Size();
+    static const std::vector<std::string> TENSOR_WEIGHT_X{"Weight", "x", "true"};
+    static const std::vector<std::string> TENSOR_X_Y{"x", "y", "false"};
+    static const std::vector<std::string> TENSOR_WEIGHT_Y{"Weight", "y", "true"};
+    std::string errorMessage = gmmParams.apiVersion != GMMApiVersion::V2 ? "When the split axis is M" : "When groupType is 0";
+    CHECK_COND(gmmParams.splitItem == X_SEPARATED || gmmParams.splitItem == NO_SEPARATED, ACLNN_ERR_PARAM_INVALID,
+        "When y is not separated, splitItem should be 2/3, but the current splitItem is %ld.", gmmParams.splitItem);
+    // check dim
+    CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.x, MIN_FM_DIM, "x") == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "Dim num or format of tensor in tensor list x is invalid.");
+    CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.weight, SEPARATED_WEIGHT_DIM, "weight") == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "Dim num or format of tensor in tensor list weight is invalid.");
+    CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.y, MIN_FM_DIM, "y") == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "Dim num or format of tensor in tensor list y is invalid.");
+    // check shape, x(m,k), weight(k,n), y(m,n)
+    int64_t innerAxisDimId = 1;  // x is never transposed, check the K axis
+    CHECK_COND(CheckShapeDiffLengthTensorList(gmmParams.weight, gmmParams.x, {0, 1}, innerAxisDimId, TENSOR_WEIGHT_X) == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "k dim value of x and weight does not match.");
+    CHECK_COND(CheckShapeSameLengthTensorList(gmmParams.x, gmmParams.y, {0, 0}, -1, TENSOR_X_Y) == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "m dim value of x and y does not match.");
+    innerAxisDimId = !gmmParams.transposeWeight ? 1 : -1;  // if w is not transposed, check the N axis; otherwise, check the K axis, which can be skipped
+    CHECK_COND(CheckShapeDiffLengthTensorList(gmmParams.weight, gmmParams.y, {1, 1}, innerAxisDimId, TENSOR_WEIGHT_Y) == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "n dim value of weight and y does not match.");
+    CHECK_COND(CheckWeightShapeInnerAxisEven(gmmParams.weight, weightSize, gmmParams.transposeWeight ? 0 : 1) == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "w inner axis size should be even when weight is int4 dtype.");
+    // check groupList
+    CHECK_COND(CheckGroupListSplitM(gmmParams, true, false, false, weightSize) == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "Invalid groupList.");
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus SplitMSingleXSeparatedWeightSeparatedY(const GroupedMatmulParams &gmmParams) {
+    size_t ySize = gmmParams.y->Size();
+    size_t weightSize = gmmParams.weight->Size();
+    static const std::vector<std::string> TENSOR_WEIGHT_X{"Weight", "x", "true"};
+    static const std::vector<std::string> TENSOR_Y_X{"y", "x", "false"};
+    static const std::vector<std::string> TENSOR_WEIGHT_Y{"Weight", "y", "true"};
+    std::string errorMessage = gmmParams.apiVersion == GMMApiVersion::V1 ? "When the split axis is M" : "When groupType is 0";
+    CHECK_COND(gmmParams.splitItem == X_Y_SEPARATED || gmmParams.splitItem == Y_SEPARATED, ACLNN_ERR_PARAM_INVALID,
+        "When y is separated, splitItem should be 0/1, but the current splitItem is %ld.", gmmParams.splitItem);
+    CHECK_COND(ySize == weightSize, ACLNN_ERR_PARAM_INVALID,
+        "When y and weight are separated, size of y %lu should equal size of weight %lu.",
+        ySize, weightSize);
+    // check dim
+    CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.x, MIN_FM_DIM, "x") == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "Dim num or format of tensor in tensor list x is invalid.");
+    CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.weight, SEPARATED_WEIGHT_DIM, "weight") == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "Dim num or format of tensor in tensor list weight is invalid.");
+    CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.y, MIN_FM_DIM, "y") == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "Dim num or format of tensor in tensor list y is invalid.");
+    // check shape, x(m,k), weight(k,n), y(m,n)
+    int64_t innerAxisDimId = 1;  // x is never transposed, check the K axis
+    CHECK_COND(CheckShapeDiffLengthTensorList(gmmParams.weight, gmmParams.x, {0, 1}, innerAxisDimId, TENSOR_WEIGHT_X) == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "k dim value of x and weight does not match.");
+    CHECK_COND(CheckShapeDiffLengthTensorListSplitAxis(gmmParams.y, gmmParams.x, 0, 0, TENSOR_Y_X) == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "m dim value of x and y does not match.");
+    innerAxisDimId = !gmmParams.transposeWeight ? 1 : -1;  // if w is not transposed, check the N axis; otherwise, check the K axis, which can be skipped
+    CHECK_COND(CheckShapeSameLengthTensorList(gmmParams.weight, gmmParams.y, {1, 1}, innerAxisDimId, TENSOR_WEIGHT_Y) == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "n dim value of weight and y does not match.");
+    CHECK_COND(CheckWeightShapeInnerAxisEven(gmmParams.weight, weightSize, gmmParams.transposeWeight ? 0 : 1) == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "w inner axis size should be even when weight is int4 dtype.");
+    // check groupList
+    CHECK_COND(CheckGroupListSplitM(gmmParams, true, false, true, ySize) == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "Invalid groupList.");
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus SplitMSeparatedXSeparatedWeightSingleY(const GroupedMatmulParams &gmmParams) {
+    size_t xSize = gmmParams.x->Size();
+    size_t weightSize = gmmParams.weight->Size();
+    static const std::vector<std::string> TENSOR_WEIGHT_X{"Weight", "x", "true"};
+    static const std::vector<std::string> TENSOR_X_Y{"x", "y", "false"};
+    static const std::vector<std::string> TENSOR_WEIGHT_Y{"Weight", "y", "true"};
+    std::string errorMessage = gmmParams.apiVersion != GMMApiVersion::V2 ? "When the split axis is M" : "When groupType is 0";
+    CHECK_COND(gmmParams.splitItem == X_SEPARATED || gmmParams.splitItem == NO_SEPARATED, ACLNN_ERR_PARAM_INVALID,
+        "When y is not separated, splitItem should be 2/3, but the current splitItem is %ld.", gmmParams.splitItem);
+    CHECK_COND(xSize == weightSize, ACLNN_ERR_PARAM_INVALID,
+        "When x and weight are separated, size of x %lu should equal size of weight %lu.",
+        xSize, weightSize);
+    // check dim
+    CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.x, MIN_FM_DIM, "x") == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "Dim num or format of tensor in tensor list x is invalid.");
+    CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.weight, SEPARATED_WEIGHT_DIM, "weight") == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "Dim num or format of tensor in tensor list weight is invalid.");
+    CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.y, MIN_FM_DIM, "y") == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "Dim num or format of tensor in tensor list y is invalid.");
+    // check shape, x(m,k), weight(k,n), y(m,n)
+    int64_t innerAxisDimId = 0;  // 0: the index of weight's K axis. x is never transposed, check the K axis
+    CHECK_COND(CheckShapeSameLengthTensorList(gmmParams.weight, gmmParams.x, {0, 1}, innerAxisDimId, TENSOR_WEIGHT_X) == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "k dim value of x and weight does not match.");
+    CHECK_COND(CheckShapeDiffLengthTensorListSplitAxis(gmmParams.x, gmmParams.y, 0, 0, TENSOR_X_Y) == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "m dim value of x and y does not match.");
+    innerAxisDimId = !gmmParams.transposeWeight ? 1 : -1;  // if w is not transposed, check the N axis; otherwise, check the K axis, which can be skipped
+    CHECK_COND(CheckShapeDiffLengthTensorList(gmmParams.weight, gmmParams.y, {1, 1}, innerAxisDimId, TENSOR_WEIGHT_Y) == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "n dim value of weight and y does not match.");
+    CHECK_COND(CheckWeightShapeInnerAxisEven(gmmParams.weight, weightSize, gmmParams.transposeWeight ? 0 : 1) == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "w inner axis size should be even when weight is int4 dtype.");
+    // check groupList
+    CHECK_COND(CheckGroupListSplitM(gmmParams, false, true, false, xSize) == ACLNN_SUCCESS,
+        ACLNN_ERR_PARAM_INVALID, "Invalid groupList.");
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus CheckCaseSplitM(const GroupedMatmulParams &gmmParams) {
+    size_t xSize = gmmParams.x->Size();
+    size_t ySize = gmmParams.y->Size();
+    size_t weightSize = gmmParams.weight->Size();
+    if (xSize == 1 && weightSize == 1 && ySize == 1) {
+        CHECK_COND(SplitMSingleXSingleWeightSingleY(gmmParams) == ACLNN_SUCCESS,
+            ACLNN_ERR_PARAM_INVALID, "Split m, single x, single weight, single y case failed.");
+        return ACLNN_SUCCESS;
+    }
+    if (xSize == 1 && weightSize > 1 && ySize == 1) {
+        CHECK_COND(SplitMSingleXSeparatedWeightSingleY(gmmParams) == ACLNN_SUCCESS,
+            ACLNN_ERR_PARAM_INVALID, "Split m, single x, separated weight, single y case failed.");
+        return ACLNN_SUCCESS;
+    }
+    if (xSize == 1 && weightSize > 1 && ySize > 1) {
+        CHECK_COND(!(gmmParams.apiVersion == GMMApiVersion::V3 || gmmParams.apiVersion == GMMApiVersion::V4),
+            ACLNN_ERR_PARAM_INVALID,
+            "When groupList is a tensor, split m, single x, separated weight, separated y cases are not supported.");
+        CHECK_COND(SplitMSingleXSeparatedWeightSeparatedY(gmmParams) == ACLNN_SUCCESS,
+            ACLNN_ERR_PARAM_INVALID, "Split m, single x, separated weight, separated y case failed.");
+        return ACLNN_SUCCESS;
+    }
+    if (xSize > 1 && weightSize > 1 && ySize == 1) {
+        CHECK_COND(SplitMSeparatedXSeparatedWeightSingleY(gmmParams) == ACLNN_SUCCESS,
+            ACLNN_ERR_PARAM_INVALID, "Split m, separated x, separated weight, single y case failed.");
+        return ACLNN_SUCCESS;
+    }
+    std::string errorMessage = gmmParams.apiVersion != GMMApiVersion::V2 ? "When the split axis is M" : "When groupType is 0";
+    if ((gmmParams.apiVersion == GMMApiVersion::V3 || gmmParams.apiVersion == GMMApiVersion::V4) && gmmParams.isSingleWeight) {
+        errorMessage = "When groupType is 0";
+    }
+    std::string xStatus = xSize > 1 ? "separated" : "not separated";
+    std::string weightStatus = weightSize > 1 ? "separated" : "not separated";
+    std::string yStatus = ySize > 1 ? "separated" : "not separated";
+    OP_LOGE(ACLNN_ERR_PARAM_INVALID, "%s, the current case with x %s, weight %s, y %s is not supported.",
+        errorMessage.c_str(), xStatus.c_str(), weightStatus.c_str(), yStatus.c_str());
+    return ACLNN_ERR_PARAM_INVALID;
+}
+
+static aclnnStatus CheckCaseSplitK(const GroupedMatmulParams &gmmParams) {
+    static const std::vector<std::string> TENSOR_X_WEIGHT{"x", "weight", "true"};
+    static const std::vector<std::string> TENSOR_X_Y{"x", "y", "false"};
+    static const std::vector<std::string> TENSOR_WEIGHT_Y{"Weight", "y", "true"};
+    size_t xSize = gmmParams.x->Size();
+    size_t ySize = gmmParams.y->Size();
+    size_t weightSize = gmmParams.weight->Size();
+    if (xSize == 1 && ySize == 1 && weightSize == 1) {
+        // The left matrix must be transposed.
+        CHECK_COND(gmmParams.transposeX, ACLNN_ERR_PARAM_INVALID,
+            "When groupType is 2 and x is not separated, the tensor in x should be transposed.");
+        // When groupType is 2, splitItem must be 2/3.
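+        // Illustrative example (editor's sketch, shapes not from the original patch): with a
+        // transposed x of viewShape (m, k) = (64, 4096), weight (k, n) = (4096, 256) and a groupList
+        // of length b = 4 partitioning k, y must be (b, m, n) = (4, 64, 256): each group contributes
+        // one partial (m, n) product computed over its own slice of k.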
+        CHECK_COND(gmmParams.splitItem == X_SEPARATED || gmmParams.splitItem == NO_SEPARATED, ACLNN_ERR_PARAM_INVALID,
+            "When groupType is 2, splitItem should be 2/3, but the current splitItem is %ld", gmmParams.splitItem);
+        // check dim
+        CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.x, MIN_FM_DIM, "x") == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+            "Dim num or format of tensor in tensor list x is invalid.");
+        CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.weight, SPLIT_K_SINGLE_WEIGHT_DIM, "weight") == ACLNN_SUCCESS,
+            ACLNN_ERR_PARAM_INVALID,
+            "Dim num or format of tensor in tensor list weight is invalid.");
+        // 3: y has 3 dims in the single-tensor case when splitting K.
+        CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.y, 3, "y") == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+            "Dim num or format of tensor in tensor list y is invalid.");
+        // check shape, x(m,k), weight(k,n), y(b,m,n)
+        int64_t innerAxisDimId = 0;  // x is always transposed, check the M axis
+        CHECK_COND(CheckShapeSameLengthTensorList(gmmParams.x, gmmParams.weight, {1, 0}, innerAxisDimId, TENSOR_X_WEIGHT) == ACLNN_SUCCESS,
+            ACLNN_ERR_PARAM_INVALID, "k dim value of x and weight does not match.");
+        CHECK_COND(CheckShapeSameLengthTensorList(gmmParams.x, gmmParams.y, {0, 1}, -1, TENSOR_X_Y) == ACLNN_SUCCESS,
+            ACLNN_ERR_PARAM_INVALID, "m dim value of x and y does not match.");
+        innerAxisDimId = 1;  // w is never transposed, check the N axis
+        // 2: N axis index of y
+        CHECK_COND(CheckShapeSameLengthTensorList(gmmParams.weight, gmmParams.y, {1, 2}, innerAxisDimId, TENSOR_WEIGHT_Y) == ACLNN_SUCCESS,
+            ACLNN_ERR_PARAM_INVALID, "n dim value of weight and y does not match.");
+        // check groupList
+        size_t batchSizeY = (*gmmParams.y)[0]->GetViewShape().GetDim(0);
+        CHECK_COND(CheckGroupListSplitK(gmmParams, true, false, false, batchSizeY) == ACLNN_SUCCESS,
+            ACLNN_ERR_PARAM_INVALID, "Invalid groupList.");
+        return ACLNN_SUCCESS;
+    }
+    OP_LOGE(ACLNN_ERR_PARAM_INVALID,
+        "When groupType is 2, only the case with unseparated x, weight and y is supported, "
+        "but now x size is %lu, weight size is %lu, y size is %lu.", xSize, weightSize, ySize);
+    return ACLNN_ERR_PARAM_INVALID;
+}
+
+static aclnnStatus CheckCaseNoSplit(const GroupedMatmulParams &gmmParams) {
+    // When groupType is -1, splitItem must be 0/1.
+    CHECK_COND(gmmParams.splitItem == X_Y_SEPARATED || gmmParams.splitItem == Y_SEPARATED, ACLNN_ERR_PARAM_INVALID,
+        "When y is separated, splitItem should be 0/1, but the current splitItem is %ld.", gmmParams.splitItem);
+    // check group num
+    size_t xSize = gmmParams.x->Size();
+    size_t ySize = gmmParams.y->Size();
+    size_t weightSize = gmmParams.weight->Size();
+    CHECK_COND(xSize == ySize, ACLNN_ERR_PARAM_INVALID,
+        "When y is separated, size of x %lu should equal size of y %lu.", xSize, ySize);
+    CHECK_COND(xSize == weightSize, ACLNN_ERR_PARAM_INVALID,
+        "When x and weight are separated, size of x %lu should equal size of weight %lu.",
+        xSize, weightSize);
+    // check dim
+    CHECK_COND(CheckDimNumAndGroupListNoSplitAndFormat(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "Dim num or format of tensor in tensor lists or groupList is invalid.");
+    // check shape
+    for (size_t i = 0; i < xSize; i++) {
+        size_t xDimNum = (*gmmParams.x)[i]->GetViewShape().GetDimNum();
+        // 2: validate up to the second-to-last dimension; x and y must be equal in every dimension except the last one.
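+        // Illustrative example (editor's sketch, shapes not from the original patch): for x[i] of
+        // viewShape (8, 128, 256), the loop below requires y[i] to match on the leading dims
+        // (8, 128); the last dim of x[i] (256) must equal weight[i] dim 0, and the last dim of y[i]
+        // must equal weight[i] dim 1.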
+        for (size_t dimIdx = 0; dimIdx < xDimNum - 2; dimIdx++) {
+            size_t xDimValue = (*gmmParams.x)[i]->GetViewShape().GetDim(dimIdx);
+            size_t yDimValue = (*gmmParams.y)[i]->GetViewShape().GetDim(dimIdx);
+            CHECK_COND(xDimValue == yDimValue, ACLNN_ERR_PARAM_INVALID,
+                "y[%lu] dim %lu value %lu should equal x[%lu] dim %lu value %lu.",
+                i, dimIdx, xDimValue, i, dimIdx, yDimValue);
+        }
+        // check that the inner dim of x is not larger than 65535
+        size_t xKDimValue = (*gmmParams.x)[i]->GetViewShape().GetDim(xDimNum - 1);  // x is never transposed
+        CHECK_COND(xKDimValue <= MAX_INNER_AXIS, ACLNN_ERR_PARAM_INVALID,
+            "x[%lu] dim %lu value %lu should be less than or equal to 65535.", i, xDimNum - 1, xKDimValue);
+        size_t weightKDimValue = (*gmmParams.weight)[i]->GetViewShape().GetDim(0);
+        CHECK_COND(xKDimValue == weightKDimValue, ACLNN_ERR_PARAM_INVALID,
+            "x[%lu] dim %lu value %lu should equal weight[%lu] dim 0 value %lu.",
+            i, xDimNum - 1, xKDimValue, i, weightKDimValue);
+        size_t weightNDimValue = (*gmmParams.weight)[i]->GetViewShape().GetDim(1);
+        if (!gmmParams.transposeWeight) {  // if weight is not transposed, check the N axis; otherwise, check the K axis, which can be skipped
+            CHECK_COND(weightNDimValue <= MAX_INNER_AXIS, ACLNN_ERR_PARAM_INVALID,
+                "w[%lu] dim %d value %lu should be less than or equal to 65535.", i, 1, weightNDimValue);
+        }
+        if ((*gmmParams.weight)[0]->GetDataType() == DataType::DT_INT4) {
+            CHECK_COND(weightNDimValue % 2 == 0, ACLNN_ERR_PARAM_INVALID,  // 2: an even factor
+                "w[%lu] dim %d value %lu should be even when weight is int4 dtype.", i, 1, weightNDimValue);
+        }
+        // check y[n] == weight[n]
+        size_t yNDimValue = (*gmmParams.y)[i]->GetViewShape().GetDim(xDimNum - 1);
+        CHECK_COND(yNDimValue == weightNDimValue, ACLNN_ERR_PARAM_INVALID,
+            "y[%lu] dim %lu value %lu should equal weight[%lu] dim 1 value %lu.",
+            i, xDimNum - 1, yNDimValue, i, weightNDimValue);
+    }
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus CheckParamDifferentGroupType(const GroupedMatmulParams &gmmParams) {
+    CHECK_COND(!(gmmParams.transposeX && gmmParams.transposeWeight), ACLNN_ERR_PARAM_INVALID,
+        "x and weight can not both be transposed at the same time.");
+    CHECK_COND((gmmParams.groupListOptional == nullptr || gmmParams.groupListOptional->Size() != 1) &&
+        (gmmParams.groupTensorOptional == nullptr || gmmParams.groupTensorOptional->GetViewShape().GetDim(0) != 1),
+        ACLNN_ERR_PARAM_INVALID, "size of groupList can not be 1. "
+        "If the expected group num is 1, groupList should be nullptr.");
+    if (GetCurrentPlatformInfo().GetSocVersion() == SocVersion::ASCEND310P && gmmParams.transposeWeight) {
+        CHECK_COND(gmmParams.groupType == SPLIT_M && gmmParams.x->Size() == 1 && gmmParams.weight->Size() == 1
+            && gmmParams.y->Size() == 1, ACLNN_ERR_PARAM_INVALID,
+            "When weight is transposed, ASCEND310P only supports split m, single x, single weight, single y.");
+    }
+    if (gmmParams.groupType == NO_SPLIT) {
+        CHECK_COND(!gmmParams.transposeX, ACLNN_ERR_PARAM_INVALID,
+            "When x, weight and y are all separated, x can not be transposed.");
+        CHECK_COND(!(gmmParams.apiVersion == GMMApiVersion::V1 && gmmParams.transposeWeight), ACLNN_ERR_PARAM_INVALID,
+            "in this version, when x, weight and y are all separated, weight can not be transposed.");
+        CHECK_COND(CheckCaseNoSplit(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+            "Invalid inputs!");
+    } else if (gmmParams.groupType == SPLIT_M) {
+        std::string errorMessage = gmmParams.apiVersion != GMMApiVersion::V2 && !gmmParams.isSingleWeight
"When splited axis is M" : "When groupType is 0"; + CHECK_COND(!gmmParams.transposeX, ACLNN_ERR_PARAM_INVALID, + "%s, x can not be transposed.", errorMessage.c_str()); + CHECK_COND(CheckCaseSplitM(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, + "Invalid inputs!"); + } else if (gmmParams.groupType == SPLIT_K) { + CHECK_COND(gmmParams.biasOptional == nullptr, ACLNN_ERR_PARAM_INVALID, + "When groupType is 2, bias must be empty."); + CHECK_COND(CheckCaseSplitK(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, + "Invalid inputs!"); + } + if (gmmParams.biasOptional != nullptr) { + CHECK_COND(CheckOptionalTensorList(gmmParams, gmmParams.biasOptional, "bias") == ACLNN_SUCCESS, + ACLNN_ERR_PARAM_INVALID, "Invalid bias!"); + } + return ACLNN_SUCCESS; +} + +static aclnnStatus CheckTensorListLength(const aclTensorList *tensorList) { + size_t groupSize = 0; + if (tensorList != nullptr) { + groupSize = tensorList->Size(); + } + CHECK_COND(groupSize <= MAX_GROUP_LIST_SIZE_ARRAY, ACLNN_ERR_PARAM_INVALID, + "Length of tensorList should not exceed %ld, but actually got %ld.", + MAX_GROUP_LIST_SIZE_ARRAY, groupSize); + return ACLNN_SUCCESS; +} + +static aclnnStatus CheckGroupSize(const GroupedMatmulParams &gmmParams) { + // Only groupSizes of necessary inputs will be checked here. + // The groupSizes of optional inputs and output will be checked in subsequent steps. + CHECK_COND(CheckTensorListLength(gmmParams.x) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, + "Invalid length of tensorList x."); + CHECK_COND(CheckTensorListLength(gmmParams.weight) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, + "Invalid length of tensorList weight."); + + return ACLNN_SUCCESS; +} + +static aclnnStatus CheckUnusedParams(const GroupedMatmulParams &gmmParams) { + // Check currently disabled parameters, delete accordingly when parameter functionality is supported. 
+    CHECK_COND(gmmParams.activationInputOptional == nullptr, ACLNN_ERR_PARAM_INVALID,
+        "activationInputOptional must be nullptr.");
+    CHECK_COND(gmmParams.activationQuantScaleOptional == nullptr, ACLNN_ERR_PARAM_INVALID,
+        "activationQuantScaleOptional must be nullptr.");
+    CHECK_COND(gmmParams.activationQuantOffsetOptional == nullptr, ACLNN_ERR_PARAM_INVALID,
+        "activationQuantOffsetOptional must be nullptr.");
+    CHECK_COND(gmmParams.activationFeatureOutOptional == nullptr, ACLNN_ERR_PARAM_INVALID,
+        "activationFeatureOutOptional must be nullptr.");
+    CHECK_COND(gmmParams.dynQuantScaleOutOptional == nullptr, ACLNN_ERR_PARAM_INVALID,
+        "dynQuantScaleOutOptional must be nullptr.");
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus CheckParam(const GroupedMatmulParams &gmmParams) {
+    CHECK_COND(CheckUnusedParams(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, "Invalid unused params.");
+    CHECK_RET(CheckFunctionParams(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID);
+    CHECK_RET(CheckParamDifferentGroupType(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID);
+    CHECK_RET(CheckGroupSize(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID);
+    return ACLNN_SUCCESS;
+}
+
+static void CreateContiguousTensorList(const aclTensorList *tensorList,
+    std::vector<const aclTensor*> &newTensorList,
+    aclOpExecutor *executor) {
+    op::Shape shape;
+    for (uint64_t idx = 0; idx < (*tensorList).Size(); idx++) {
+        const aclTensor *inputTensor = (*tensorList)[idx];
+        op::Shape viewShape = inputTensor->GetViewShape();
+        uint32_t viewShapeDimsNum = viewShape.GetDimNum();
+        shape.SetScalar();
+        // 2: the second-to-last dimension; dimensions before it remain unchanged in the loop below.
+        for (uint32_t i = 0; i < viewShapeDimsNum - 2; ++i) {
+            shape.AppendDim(viewShape.GetDim(i));
+        }
+        // viewShapeDimsNum - 1: the last dim value. viewShapeDimsNum - 2: the second-to-last dim value.
+        shape.AppendDim(viewShape.GetDim(viewShapeDimsNum - 1));
+        shape.AppendDim(viewShape.GetDim(viewShapeDimsNum - 2));  // 2: the second-to-last dim.
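+        // Illustrative example (editor's sketch, values not from the original patch): for a
+        // viewShape of (4, 1024, 2048), the shape built above is (4, 2048, 1024), i.e. the last two
+        // dimensions are swapped before creating the view on the transposed tensor.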
+ aclTensor *tensor = executor->CreateView(inputTensor, shape, inputTensor->GetViewOffset()); // use executor to create tensor + tensor->SetStorageFormat(inputTensor->GetStorageFormat()); + newTensorList.emplace_back(tensor); + } +} + +static void CheckOptionalTensorListEmpty(const aclTensorList *&tensorList) { + if (tensorList != nullptr) { + if (tensorList->Size() == 0) { + tensorList = nullptr; + } else if ((*tensorList)[0] == nullptr) { + tensorList = nullptr; + } else if (tensorList->Size() == 1) { + op::Shape shape = (*tensorList)[0]->GetViewShape(); + if (shape.GetDimNum() == 1 && shape.GetDim(0) == 0) { + tensorList = nullptr; + } + } + } +} + +static void ResetEmptyTensor(GroupedMatmulParams &gmmParams) { + // set the empty tensor list to nullptr + if (gmmParams.groupListOptional != nullptr && gmmParams.groupListOptional->Size() == 0) { + gmmParams.groupListOptional = nullptr; + } + CheckOptionalTensorListEmpty(gmmParams.biasOptional); + CheckOptionalTensorListEmpty(gmmParams.scaleOptional); + CheckOptionalTensorListEmpty(gmmParams.offsetOptional); + CheckOptionalTensorListEmpty(gmmParams.antiquantScaleOptional); + CheckOptionalTensorListEmpty(gmmParams.antiquantOffsetOptional); + CheckOptionalTensorListEmpty(gmmParams.perTokenScaleOptional); + CheckOptionalTensorListEmpty(gmmParams.activationInputOptional); + CheckOptionalTensorListEmpty(gmmParams.activationQuantScaleOptional); + CheckOptionalTensorListEmpty(gmmParams.activationQuantOffsetOptional); + CheckOptionalTensorListEmpty(gmmParams.activationFeatureOutOptional); + CheckOptionalTensorListEmpty(gmmParams.dynQuantScaleOutOptional); +} + +static void CreateEmptyTensor(const aclDataType dataType, const aclTensorList *&gmmTensorList, + aclTensorList *&tensorList, aclOpExecutor *executor) { + // if tensor list is nullptr, convert tensorlist to a tensorlist containing a tensor with shape 0. 
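+    // Illustrative example (editor's sketch): if biasOptional arrives as nullptr, the branch below
+    // replaces it with a single-element tensor list holding one tensor of shape {0} in the given
+    // dtype, so later l0 calls can treat every optional input uniformly as a tensor list.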
+    if (gmmTensorList == nullptr) {
+        FVector<aclTensor*> emptyTensors;
+        aclTensor *emptyTensor = executor->AllocTensor({0}, static_cast<op::DataType>(dataType));
+        emptyTensors.emplace_back(emptyTensor);
+        tensorList = executor->AllocTensorList(emptyTensors.data(), emptyTensors.size());
+        gmmTensorList = tensorList;
+    }
+}
+
+static aclnnStatus DataContiguous(const aclTensorList *&tensors, aclOpExecutor *executor) {
+    std::vector<const aclTensor*> tensorsVec;
+    const aclTensor *contiguousTensor = nullptr;
+    for (size_t i = 0; i < tensors->Size(); ++i) {
+        const aclTensor *tensor = (*tensors)[i];
+        contiguousTensor = l0op::Contiguous(tensor, executor);
+        CHECK_RET(contiguousTensor != nullptr, ACLNN_ERR_INNER_NULLPTR);
+        tensorsVec.push_back(contiguousTensor);
+    }
+    tensors = executor->AllocTensorList(tensorsVec.data(), tensorsVec.size());
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus DataContiguousAndTransFormat(const aclTensor *tensor, const aclTensor *&reformatedTensor,
+    const op::Format requiredFormat, aclOpExecutor *executor) {
+    if (!op::IsPrivateFormat(tensor->GetStorageFormat()) && tensor->GetStorageFormat() != op::Format::FORMAT_ND) {
+        tensor = l0op::ReFormat(tensor, op::Format::FORMAT_ND, executor);
+    }
+    if (tensor == nullptr || tensor->GetViewShape().GetDimNum() == 1) {
+        OP_LOGD("No need to do contiguous process");
+        reformatedTensor = tensor;
+    } else {
+        reformatedTensor = l0op::Contiguous(tensor, executor);
+    }
+    CHECK_RET(reformatedTensor != nullptr, ACLNN_ERR_INNER_NULLPTR);
+    reformatedTensor = l0op::TransData(reformatedTensor, requiredFormat, 1, executor);
+    CHECK_RET(reformatedTensor != nullptr, ACLNN_ERR_INNER_NULLPTR);
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus TransWeightToNz(GroupedMatmulParams &gmmParams, aclOpExecutor *executor) {
+    bool is310p = GetCurrentPlatformInfo().GetSocVersion() == SocVersion::ASCEND310P;
+    if (is310p) {
+        const aclTensorList *&weights = gmmParams.weight;
+        size_t wLength = weights->Size();
+        std::vector<const aclTensor*> reformatedWeightVec;
+        const aclTensor* reformatedWeight = nullptr;
+        // trans weight format
+        for (size_t i(0); i < wLength; ++i) {
+            const aclTensor* weight = (*weights)[i];
+            op::Shape shape = weight->GetViewShape();
+            // 2: When weight is transposed, n is the second-to-last axis.
+            int64_t n = gmmParams.transposeWeight ? shape.GetDim(shape.GetDimNum() - 2) : shape.GetDim(shape.GetDimNum() - 1);
+            DataType dtype = weight->GetDataType();
+            // 32: the matmul api requires the n axis to align with 32 bytes
+            CHECK_COND(n % static_cast<int64_t>(32 / std::max(1, op::TypeSize(dtype))) == 0, ACLNN_ERR_PARAM_INVALID,
+                "output n axis should align with 32 Bytes, but now is %ld", n);
+            aclnnStatus ret = DataContiguousAndTransFormat(weight, reformatedWeight,
+                Format::FORMAT_FRACTAL_NZ, executor);
+            CHECK_RET(ret == ACLNN_SUCCESS, ret);
+            reformatedWeightVec.push_back(reformatedWeight);
+        }
+        weights = executor->AllocTensorList(reformatedWeightVec.data(), reformatedWeightVec.size());
+    } else {  // 910
+        const aclTensorList *&weights = gmmParams.weight;
+        size_t wLength = weights->Size();
+        for (size_t i(0); i < wLength; ++i) {
+            const aclTensor* weight = (*weights)[i];
+            if (weight->GetStorageFormat() != op::Format::FORMAT_FRACTAL_NZ) {
+                break;
+            }
+            op::Shape shape = weight->GetViewShape();
+            // 2: When weight is transposed, n is the second-to-last axis.
+            int64_t n = gmmParams.transposeWeight ? shape.GetDim(shape.GetDimNum() - 2) : shape.GetDim(shape.GetDimNum() - 1);
+            DataType dtype = weight->GetDataType();
+            // 32: the matmul api requires the n axis to align with 32 bytes
+            CHECK_COND(n % static_cast<int64_t>(32 / std::max(1, op::TypeSize(dtype))) == 0, ACLNN_ERR_PARAM_INVALID,
+                "output n axis should align with 32 Bytes, but now is %ld", n);
+        }
+    }
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus CheckZeroShape(GroupedMatmulParams &params, uint64_t *workspaceSize) {
+    bool isEmpty = true;
+    for (size_t i = 0; i < params.x->Size(); ++i) {
+        if (!((*params.x)[i]->IsEmpty())) {
+            isEmpty = false;
+            break;
+        }
+    }
+    if (isEmpty) {
+        *workspaceSize = 0;
+        return ACLNN_ERR_PARAM_INVALID;
+    }
+    return ACLNN_SUCCESS;
+}
+
+static void SetParamsTensorEmpty(GroupedMatmulParams &params, aclOpExecutor *executor) {
+    aclTensorList *emptyBiasList = nullptr;
+    CreateEmptyTensor(BIAS_DTYPE.at(params.xDtype), params.biasOptional, emptyBiasList, executor);
+
+    aclTensorList *emptyScaleList = nullptr;
+    CreateEmptyTensor(aclDataType::ACL_UINT64, params.scaleOptional, emptyScaleList, executor);
+
+    aclTensorList *emptyOffsetList = nullptr;
+    CreateEmptyTensor(aclDataType::ACL_FLOAT, params.offsetOptional, emptyOffsetList, executor);
+
+    aclTensorList *emptyAntiquantScaleList = nullptr;
+    CreateEmptyTensor(aclDataType::ACL_FLOAT16, params.antiquantScaleOptional, emptyAntiquantScaleList, executor);
+
+    aclTensorList *emptyAntiquantOffsetList = nullptr;
+    CreateEmptyTensor(aclDataType::ACL_FLOAT16, params.antiquantOffsetOptional, emptyAntiquantOffsetList, executor);
+
+    aclTensorList *emptyPerTokenScaleList = nullptr;
+    CreateEmptyTensor(aclDataType::ACL_FLOAT, params.perTokenScaleOptional, emptyPerTokenScaleList, executor);
+
+    aclTensorList *emptyActivationInputList = nullptr;
+    CreateEmptyTensor(aclDataType::ACL_FLOAT, params.activationInputOptional, emptyActivationInputList, executor);
+
+    aclTensorList *emptyActivationQuantScaleList = nullptr;
+    CreateEmptyTensor(aclDataType::ACL_FLOAT, params.activationQuantScaleOptional, emptyActivationQuantScaleList, executor);
+
+    aclTensorList *emptyActivationQuantOffsetList = nullptr;
+    CreateEmptyTensor(aclDataType::ACL_FLOAT, params.activationQuantOffsetOptional, emptyActivationQuantOffsetList, executor);
+
+    aclTensorList *emptyActivationFeatureOutList = nullptr;
+    CreateEmptyTensor(aclDataType::ACL_FLOAT, params.activationFeatureOutOptional, emptyActivationFeatureOutList, executor);
+
+    aclTensorList *emptyDynQuantScaleOutList = nullptr;
+    CreateEmptyTensor(aclDataType::ACL_FLOAT, params.dynQuantScaleOutOptional, emptyDynQuantScaleOutList, executor);
+}
+
+static aclnnStatus CheckOutputShape(const aclTensorList* l0Res, const aclTensorList* y) {
+    CHECK_COND(l0Res->Size() == y->Size(), ACLNN_ERR_PARAM_INVALID, "Output tensor list length is not right.");
+    for (size_t i = 0; i < y->Size(); ++i) {
+        auto const &resShape = (*l0Res)[i]->GetViewShape();
+        auto const &yShape = (*y)[i]->GetViewShape();
+        if (resShape != yShape) {
+            if (!(resShape.GetShapeSize() == 1 && yShape.GetShapeSize() == 1)) {
+                OP_LOGE(ACLNN_ERR_PARAM_INVALID, "Output tensor's shape[%s] is not equal to the inferred output's shape[%s].",
+                    op::ToString(yShape).GetString(), op::ToString(resShape).GetString());
+                return ACLNN_ERR_PARAM_INVALID;
+            }
+        }
+    }
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus ParamsDataContiguous(GroupedMatmulParams &params, aclOpExecutor *executorPtr) {
+    CHECK_COND(DataContiguous(params.x, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "Contiguous x failed.");  // make x contiguous
+    CHECK_COND(DataContiguous(params.weight, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "Contiguous weight failed.");  // make w contiguous
+    CHECK_COND(DataContiguous(params.biasOptional, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "Contiguous biasOptional failed.");
+    CHECK_COND(DataContiguous(params.scaleOptional, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "Contiguous scaleOptional failed.");
+    CHECK_COND(DataContiguous(params.offsetOptional, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "Contiguous offsetOptional failed.");
+    CHECK_COND(DataContiguous(params.antiquantScaleOptional, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "Contiguous antiquantScaleOptional failed.");
+    CHECK_COND(DataContiguous(params.antiquantOffsetOptional, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "Contiguous antiquantOffsetOptional failed.");
+    CHECK_COND(DataContiguous(params.perTokenScaleOptional, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "Contiguous perTokenScaleOptional failed.");
+    if (params.groupTensorOptional != nullptr) {
+        params.groupTensorOptional = l0op::Contiguous(params.groupTensorOptional, executorPtr);
+        CHECK_COND(params.groupTensorOptional != nullptr, ACLNN_ERR_PARAM_INVALID,
+            "Contiguous groupTensorOptional failed.");
+    }
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus GetGMMResultByL0Api(GroupedMatmulParams &params, uint64_t *workspaceSize, aclOpExecutor **executor) {
+    auto uniqueExecutor = CREATE_EXECUTOR();  // standard boilerplate: create the OpExecutor
+    aclOpExecutor *executorPtr = uniqueExecutor.get();
+    CHECK_RET(executorPtr != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR);
+    CHECK_COND(BIAS_DTYPE.find(params.xDtype) != BIAS_DTYPE.cend(), ACLNN_ERR_PARAM_INVALID,
+        "GMM: Cannot find a bias dtype matching xDtype[%s]", op::ToString(params.xDtype).GetString());
+
+    SetParamsTensorEmpty(params, executorPtr);  // create empty tensorLists
+
+    if (params.transposeX) {
+        std::vector<const aclTensor*> xTensorList;
+        CreateContiguousTensorList(params.x, xTensorList, executorPtr);
+        params.x = executorPtr->AllocTensorList(xTensorList.data(), xTensorList.size());
+    }
+    if (params.transposeWeight) {
+        std::vector<const aclTensor*> weightTensorList;
+        CreateContiguousTensorList(params.weight, weightTensorList, executorPtr);
+        params.weight = executorPtr->AllocTensorList(weightTensorList.data(), weightTensorList.size());
+    }
+    CHECK_COND(ParamsDataContiguous(params, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "ParamsDataContiguous failed.");
+    if (CheckZeroShape(params, workspaceSize) != ACLNN_SUCCESS) {
+        uniqueExecutor.ReleaseTo(executor);
+        return ACLNN_SUCCESS;
+    }
+    CHECK_COND(TransWeightToNz(params, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "TransWeightToNz failed.");
+
+    if (params.groupListOptional != nullptr) {
+        params.groupTensorOptional = uniqueExecutor->ConvertToTensor(params.groupListOptional, op::ToOpDataType(ACL_INT64));
+    }
+    auto perTokenScaleOptional = (*params.perTokenScaleOptional)[0]->IsEmpty() ? nullptr : (*params.perTokenScaleOptional)[0];
+    // Invoke the l0 operator GroupedMatmul for calculation.
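+    // Shape sketch (editor's illustration, not from the original patch): in the plain split-M
+    // single-tensor case, x (m, k) and weight (e, k, n) with a groupList of length e yield a single
+    // y tensor of shape (m, n); the l0 call below receives the already-contiguous, possibly
+    // NZ-formatted inputs prepared above, and its result is view-copied into params.y afterwards.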
+    auto result = l0op::GroupedMatmul(params.x, params.weight, params.biasOptional, params.scaleOptional,
+        params.offsetOptional, params.antiquantScaleOptional, params.antiquantOffsetOptional,
+        params.groupTensorOptional, perTokenScaleOptional, params.splitItem,
+        (*params.y)[0]->GetDataType(), params.transposeWeight, params.transposeX, params.groupType,
+        params.groupListType, params.activeType, params.y->Size(), executorPtr);
+    CHECK_RET(result != nullptr, ACLNN_ERR_INNER_NULLPTR);
+
+    CHECK_COND(CheckOutputShape(result, params.y) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+        "Check output shape failed.");
+    // If the output tensor is non-contiguous, convert the calculated contiguous tensor to non-contiguous.
+    for (size_t i(0); i < params.y->Size(); ++i) {
+        auto viewCopyResult = l0op::ViewCopy((*result)[i], (*params.y)[i], executorPtr);
+        CHECK_RET(viewCopyResult != nullptr, ACLNN_ERR_INNER_NULLPTR);
+    }
+
+    // Standard syntax: get the size of the workspace needed during computation.
+    *workspaceSize = uniqueExecutor->GetWorkspaceSize();
+    uniqueExecutor.ReleaseTo(executor);
+    return ACLNN_SUCCESS;
+}
+
+static aclnnStatus PreCheckGroupType(int64_t splitItem, int64_t groupType) {
+    // Intercept currently unsupported groupType values.
+    CHECK_COND(groupType != SPLIT_N, ACLNN_ERR_PARAM_INVALID, "Splitting the n dim is not supported now, groupType can not be 1.");
+    CHECK_COND(groupType == SPLIT_M || groupType == SPLIT_K || groupType == NO_SPLIT, ACLNN_ERR_PARAM_INVALID,
+        "groupType only supports -1/0/2 now, but the given groupType is %ld", groupType);
+    if (splitItem == X_SEPARATED || splitItem == NO_SEPARATED) {
+        CHECK_COND(groupType != NO_SPLIT, ACLNN_ERR_PARAM_INVALID, "When splitItem is 2/3, groupType can not be -1.");
+    }
+    return ACLNN_SUCCESS;
+}
+
+static int64_t CorrectSplitItem(const aclTensorList *x, const aclTensorList *y, int64_t splitItem) {
+    int64_t splitItemCorrected = splitItem;
+    // Adjust splitItem based on its range and the group type.
+    if (splitItem == X_Y_SEPARATED || splitItem == Y_SEPARATED) {
+        // If x and y have the same size, both x and y must be grouped.
+        splitItemCorrected = x->Size() == y->Size() ? X_Y_SEPARATED : Y_SEPARATED;
+    }
+    if (splitItem == X_SEPARATED || splitItem == NO_SEPARATED) {
+        splitItemCorrected = x->Size() == 1 ? NO_SEPARATED : X_SEPARATED;
+    }
+    return splitItemCorrected;
+}
+
+static aclnnStatus CheckTransposeStatus(const aclTensorList *x, const aclTensorList *weight, bool &transposeX,
+    bool &transposeWeight, int64_t groupType) {
+    CHECK_COND((*x)[0] != nullptr, ACLNN_ERR_PARAM_INVALID, "x[0] is nullptr!");
+    CHECK_COND((*weight)[0] != nullptr, ACLNN_ERR_PARAM_INVALID, "weight[0] is nullptr!");
+    transposeX = IsTransposeLastTwoDims((*x)[0]);  // check whether x is transposed
+    // If the last two axes' shape is (1, 1), the IsTransposeLastTwoDims() api reports x as not
+    // transposed, but when groupType is 2, x is required to be transposed. To ensure this case
+    // executes normally, transposeX is set to true manually.
+    if (groupType == SPLIT_K) {
+        size_t x0DimNum = (*x)[0]->GetViewShape().GetDimNum();
+        size_t checkedAxisNum = x0DimNum > 1 ? 2 : 1;  // 2: need to check the last two axes' shape
2 : 1; // 2:need to check last two axis' shape + size_t lastAxisSize = 1; + for (size_t i = 1; i <= checkedAxisNum; ++i) { + lastAxisSize *= (*x)[0]->GetViewShape().GetDim(x0DimNum - i); + } + transposeX = transposeX || lastAxisSize == 1; + } + transposeWeight = IsTransposeLastTwoDims((*weight)[0]); // check is transpose w + return ACLNN_SUCCESS; +} + +static aclnnStatus aclnnGroupedMatmulGetWorkspaceSizeCommon(const aclTensorList *x, const aclTensorList *weight, + const aclTensorList *biasOptional, const aclTensorList *scaleOptional, const aclTensorList *offsetOptional, + const aclTensorList *antiquantScaleOptional, const aclTensorList *antiquantOffsetOptional, + const aclTensorList *perTokenScaleOptional, const aclIntArray *groupListOptional, + const aclTensor *groupTensorOptional, const aclTensorList *activationInputOptional, + const aclTensorList *activationQuantScaleOptional, const aclTensorList *activationQuantOffsetOptional, + int64_t splitItem, int64_t groupType, int64_t groupListType, int64_t actType, GMMApiVersion apiVersion, + const aclTensorList *y, const aclTensorList *activationFeatureOutOptional, + const aclTensorList *dynQuantScaleOutOptional, uint64_t *workspaceSize, aclOpExecutor **executor) { + DataType xDtype = DataType::DT_UNDEFINED; + for (size_t i = 0; i < x->Size(); ++i) { + if ((*x)[i] != nullptr) { + xDtype = (*x)[i]->GetDataType(); + break; + } + } + bool isSingleWeight = (weight->Size() == 1 && groupType != NO_SPLIT); + bool transposeX; + bool transposeWeight; + CHECK_COND(CheckTransposeStatus(x, weight, transposeX, transposeWeight, groupType) == ACLNN_SUCCESS, + ACLNN_ERR_PARAM_INVALID, "CheckTransposeStatus failed!"); + GroupedMatmulParams gmmParams{x, weight, biasOptional, groupListOptional, groupTensorOptional, scaleOptional, + offsetOptional, antiquantScaleOptional, antiquantOffsetOptional, perTokenScaleOptional, + activationInputOptional, activationQuantScaleOptional, activationQuantOffsetOptional, + splitItem, groupListType, actType, transposeWeight, transposeX, isSingleWeight, + apiVersion, groupType, y, activationFeatureOutOptional, dynQuantScaleOutOptional, + xDtype}; + if (gmmParams.scaleOptional != nullptr) { + for (size_t i = 0; i < gmmParams.scaleOptional->Size(); i++) { + if ((*gmmParams.scaleOptional)[i]->GetDataType() == DataType::DT_INT64) { + (void)const_cast((*gmmParams.scaleOptional)[i])->SetDataType(op::DataType::DT_UINT64); + } + } + } + ResetEmptyTensor(gmmParams); // make empty tensor/tensorList nullptr + CHECK_RET(CheckParam(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID); + gmmParams.splitItem = CorrectSplitItem(x, y, splitItem); + + aclnnStatus ret = GetGMMResultByL0Api(gmmParams, workspaceSize, executor); + + return ret; +} + +aclnnStatus aclnnGroupedMatmulV4GetWorkspaceSize(const aclTensorList *x, const aclTensorList *weight, + const aclTensorList *biasOptional, const aclTensorList *scaleOptional, const aclTensorList *offsetOptional, + const aclTensorList *antiquantScaleOptional, const aclTensorList *antiquantOffsetOptional, + const aclTensorList *perTokenScaleOptional, const aclTensor *groupListOptional, + const aclTensorList *activationInputOptional, const aclTensorList *activationQuantScaleOptional, + const aclTensorList *activationQuantOffsetOptional, int64_t splitItem, int64_t groupType, int64_t groupListType, + int64_t actType, aclTensorList *out, aclTensorList *activationFeatureOutOptional, + aclTensorList *dynQuantScaleOutOptional, uint64_t *workspaceSize, aclOpExecutor **executor) { + CHECK_COND(CheckNotNull(x, 
+
+aclnnStatus aclnnGroupedMatmulV4GetWorkspaceSize(const aclTensorList *x, const aclTensorList *weight,
+    const aclTensorList *biasOptional, const aclTensorList *scaleOptional, const aclTensorList *offsetOptional,
+    const aclTensorList *antiquantScaleOptional, const aclTensorList *antiquantOffsetOptional,
+    const aclTensorList *perTokenScaleOptional, const aclTensor *groupListOptional,
+    const aclTensorList *activationInputOptional, const aclTensorList *activationQuantScaleOptional,
+    const aclTensorList *activationQuantOffsetOptional, int64_t splitItem, int64_t groupType, int64_t groupListType,
+    int64_t actType, aclTensorList *out, aclTensorList *activationFeatureOutOptional,
+    aclTensorList *dynQuantScaleOutOptional, uint64_t *workspaceSize, aclOpExecutor **executor) {
+    CHECK_COND(CheckNotNull(x, weight, out) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_NULLPTR,
+               "one of required inputs is nullptr.");
+    // Standard syntax: check parameters.
+    L2_DFX_PHASE_1(aclnnGroupedMatmulV4,
+                   DFX_IN(x, weight, biasOptional, scaleOptional, offsetOptional,
+                          antiquantScaleOptional, antiquantOffsetOptional, perTokenScaleOptional, activationInputOptional,
+                          activationQuantScaleOptional, activationQuantOffsetOptional,
+                          groupListOptional, splitItem, groupType, groupListType, actType),
+                   DFX_OUT(out, activationFeatureOutOptional, dynQuantScaleOutOptional));
+    bool is310P = GetCurrentPlatformInfo().GetSocVersion() == SocVersion::ASCEND310P;
+    bool supportedCaseOn310P = x->Size() == 1 && out->Size() == 1 && weight->Size() == 1 && groupType == 0;
+    CHECK_COND((is310P && supportedCaseOn310P) || !is310P, ACLNN_ERR_PARAM_INVALID,
+               "Only the case where x, y and weight are not separated and groupType is 0 is supported on ASCEND310P.");
+    CHECK_COND(PreCheckGroupType(splitItem, groupType) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+               "PreCheckGroupType failed, groupType is invalid.");
+    CHECK_COND(groupListOptional == nullptr || groupListOptional->GetViewShape().GetDimNum() == 1,
+               ACLNN_ERR_PARAM_INVALID, "When groupList is a tensor, only a 1-dimensional groupList is supported, but the given dim num is %zu.",
+               groupListOptional->GetViewShape().GetDimNum());
+    CHECK_COND(actType >= 0, ACLNN_ERR_PARAM_INVALID, "actType must be larger than or equal to 0.");
+    if (actType != GMMActType::GMM_ACT_TYPE_NONE) {
+        CHECK_COND(actType != GMMActType::GMM_ACT_TYPE_GELU_ERR_FUNC, ACLNN_ERR_PARAM_INVALID,
+                   "Activation function GELU_ERR_FUNC is not supported yet.");
+        CHECK_COND(actType < END_ACT_TYPE_ENUM, ACLNN_ERR_PARAM_INVALID,
+                   "Activation function only supports RELU/GELU_TANH/FASTGELU/SILU.");
+    }
+    CHECK_COND(groupListType == 0 || groupListType == 1, ACLNN_ERR_PARAM_INVALID, "groupListType should be 0 or 1.");
+    return aclnnGroupedMatmulGetWorkspaceSizeCommon(x, weight, biasOptional, scaleOptional, offsetOptional,
+                                                    antiquantScaleOptional, antiquantOffsetOptional,
+                                                    perTokenScaleOptional, nullptr, groupListOptional,
+                                                    activationInputOptional, activationQuantScaleOptional,
+                                                    activationQuantOffsetOptional, splitItem, groupType, groupListType,
+                                                    actType, GMMApiVersion::V4, out, activationFeatureOutOptional,
+                                                    dynQuantScaleOutOptional, workspaceSize, executor);
+}
+
+aclnnStatus aclnnGroupedMatmulV3GetWorkspaceSize(const aclTensorList *x, const aclTensorList *weight,
+    const aclTensorList *biasOptional, const aclTensorList *scaleOptional, const aclTensorList *offsetOptional,
+    const aclTensorList *antiquantScaleOptional, const aclTensorList *antiquantOffsetOptional,
+    const aclTensor *groupListOptional, int64_t splitItem, int64_t groupType, const aclTensorList *y,
+    uint64_t *workspaceSize, aclOpExecutor **executor) {
+    CHECK_COND(CheckNotNull(x, weight, y) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_NULLPTR,
+               "one of required inputs is nullptr.");
+    // Standard syntax: check parameters.
+    L2_DFX_PHASE_1(aclnnGroupedMatmulV3,
+                   DFX_IN(x, weight, biasOptional, scaleOptional, offsetOptional,
+                          antiquantScaleOptional, antiquantOffsetOptional, groupListOptional,
+                          splitItem, groupType),
+                   DFX_OUT(y));
+    bool is310P = GetCurrentPlatformInfo().GetSocVersion() == SocVersion::ASCEND310P;
+    bool supportedCaseOn310P = x->Size() == 1 && y->Size() == 1 && weight->Size() == 1 && groupType == 0;
+    CHECK_COND((is310P && supportedCaseOn310P) || !is310P, ACLNN_ERR_PARAM_INVALID,
+               "Only the case where x, y and weight are not separated and groupType is 0 is supported on ASCEND310P.");
+    CHECK_COND(PreCheckGroupType(splitItem, groupType) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+               "PreCheckGroupType failed, groupType is invalid.");
+    CHECK_COND(groupListOptional == nullptr || groupListOptional->GetViewShape().GetDimNum() == 1,
+               ACLNN_ERR_PARAM_INVALID, "When groupList is a tensor, only a 1-dimensional groupList is supported, but the given dim num is %zu.",
+               groupListOptional->GetViewShape().GetDimNum());
+    return aclnnGroupedMatmulGetWorkspaceSizeCommon(x, weight, biasOptional, scaleOptional, offsetOptional,
+                                                    antiquantScaleOptional, antiquantOffsetOptional, nullptr, nullptr,
+                                                    groupListOptional, nullptr, nullptr, nullptr, splitItem, groupType,
+                                                    0, 0, GMMApiVersion::V3, y, nullptr, nullptr, workspaceSize,
+                                                    executor);
+}
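+
+// Editor's note (illustrative only): how the optional group list describes the grouping along
+// the grouped axis. For three groups of 2, 3 and 4 rows along M:
+//   groupListType == 0 (cumulative sums): groupList = {2, 5, 9}
+//   groupListType == 1 (per-group sizes): groupList = {2, 3, 4}
+// V2 and V3 hard-code groupListType to 0 when calling the common path, so they take only the
+// cumulative-sum form; groupListType is configurable starting with V4.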
+
+aclnnStatus aclnnGroupedMatmulV2GetWorkspaceSize(const aclTensorList *x, const aclTensorList *weight,
+    const aclTensorList *biasOptional, const aclTensorList *scaleOptional, const aclTensorList *offsetOptional,
+    const aclTensorList *antiquantScaleOptional, const aclTensorList *antiquantOffsetOptional,
+    const aclIntArray *groupListOptional, int64_t splitItem, int64_t groupType, const aclTensorList *y,
+    uint64_t *workspaceSize, aclOpExecutor **executor) {
+    CHECK_COND(CheckNotNull(x, weight, y) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_NULLPTR,
+               "one of required inputs is nullptr.");
+    bool is310P = GetCurrentPlatformInfo().GetSocVersion() == SocVersion::ASCEND310P;
+    CHECK_COND(!is310P, ACLNN_ERR_PARAM_INVALID,
+               "Only aclnnGroupedMatmulV3GetWorkspaceSize is supported on ASCEND310P.");
+    // Standard syntax: check parameters.
+    L2_DFX_PHASE_1(aclnnGroupedMatmulV2,
+                   DFX_IN(x, weight, biasOptional, scaleOptional, offsetOptional,
+                          antiquantScaleOptional, antiquantOffsetOptional, groupListOptional,
+                          splitItem, groupType),
+                   DFX_OUT(y));
+    CHECK_COND(PreCheckGroupType(splitItem, groupType) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
+               "PreCheckGroupType failed, groupType is invalid.");
+    return aclnnGroupedMatmulGetWorkspaceSizeCommon(x, weight, biasOptional, scaleOptional, offsetOptional,
+                                                    antiquantScaleOptional, antiquantOffsetOptional, nullptr,
+                                                    groupListOptional, nullptr, nullptr, nullptr, nullptr, splitItem,
+                                                    groupType, 0, 0, GMMApiVersion::V2, y, nullptr, nullptr,
+                                                    workspaceSize, executor);
+}
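+
+// Editor's note (illustrative only): aclnnGroupedMatmul (V1, below) has no groupType parameter
+// and derives it from the tensor-list lengths. Assuming the counts denote aclTensorList sizes:
+//   x: 4 tensors, weight: 4 tensors, y: 4 tensors -> groupType = -1 (x and y already separated)
+//   x: 1 tensor,  weight: 4 tensors, y: 1 tensor  -> groupType =  0 (split along the M axis)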
+
+aclnnStatus aclnnGroupedMatmulGetWorkspaceSize(const aclTensorList *x, const aclTensorList *weight,
+    const aclTensorList *biasOptional, const aclTensorList *scaleOptional, const aclTensorList *offsetOptional,
+    const aclTensorList *antiquantScaleOptional, const aclTensorList *antiquantOffsetOptional,
+    const aclIntArray *groupListOptional, int64_t splitItem, const aclTensorList *y, uint64_t *workspaceSize,
+    aclOpExecutor **executor) {
+    CHECK_COND(CheckNotNull(x, weight, y) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_NULLPTR,
+               "one of required inputs is nullptr.");
+    bool is310P = GetCurrentPlatformInfo().GetSocVersion() == SocVersion::ASCEND310P;
+    CHECK_COND(!is310P, ACLNN_ERR_PARAM_INVALID,
+               "Only aclnnGroupedMatmulV3GetWorkspaceSize is supported on ASCEND310P.");
+    // Standard syntax: check parameters.
+    L2_DFX_PHASE_1(aclnnGroupedMatmul,
+                   DFX_IN(x, weight, biasOptional, scaleOptional, offsetOptional,
+                          antiquantScaleOptional, antiquantOffsetOptional, groupListOptional,
+                          splitItem),
+                   DFX_OUT(y));
+    int64_t groupType = 0;
+    // A weight list of length 1 is accepted only when x and y are single tensors as well.
+    if (weight->Size() == 1) {
+        CHECK_COND(x->Size() == 1 && y->Size() == 1, ACLNN_ERR_PARAM_INVALID,
+                   "Only separated weight is accepted, but the input weight is not separated.");
+    }
+    bool xYSeparated = (x->Size() > 1 && y->Size() > 1) ||
+                       (x->Size() == 1 && y->Size() == 1 && weight->Size() == 1);
+    // groupType is -1 only when both x and y are already separated into groups.
+    if (xYSeparated) {
+        groupType = -1;
+    }
+    if (GetCurrentPlatformInfo().GetSocVersion() != SocVersion::ASCEND310P) {
+        bool isSingleWeight = (weight->Size() == 1) && !(x->Size() == 1 && xYSeparated);
+        CHECK_COND(!isSingleWeight, ACLNN_ERR_PARAM_INVALID,
+                   "Only separated weight is accepted, but the input weight is not separated.");
+    }
+    return aclnnGroupedMatmulGetWorkspaceSizeCommon(x, weight, biasOptional, scaleOptional, offsetOptional,
+                                                    antiquantScaleOptional, antiquantOffsetOptional, nullptr,
+                                                    groupListOptional, nullptr, nullptr, nullptr, nullptr, splitItem,
+                                                    groupType, 0, 0, GMMApiVersion::V1, y, nullptr, nullptr,
+                                                    workspaceSize, executor);
+}
+
+aclnnStatus aclnnGroupedMatmul(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor,
+                               aclrtStream stream) {
+    L2_DFX_PHASE_2(aclnnGroupedMatmul);
+    CHECK_COND(CommonOpExecutorRun(workspace, workspaceSize, executor, stream) == ACLNN_SUCCESS, ACLNN_ERR_INNER,
+               "GroupedMatmul aicore kernel launch failed.");
+    return ACLNN_SUCCESS;
+}
+
+aclnnStatus aclnnGroupedMatmulV2(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor,
+                                 aclrtStream stream) {
+    L2_DFX_PHASE_2(aclnnGroupedMatmulV2);
+    CHECK_COND(CommonOpExecutorRun(workspace, workspaceSize, executor, stream) == ACLNN_SUCCESS, ACLNN_ERR_INNER,
+               "GroupedMatmul aicore kernel launch failed.");
+    return ACLNN_SUCCESS;
+}
+
+aclnnStatus aclnnGroupedMatmulV3(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor,
+                                 aclrtStream stream) {
+    L2_DFX_PHASE_2(aclnnGroupedMatmulV3);
+    CHECK_COND(CommonOpExecutorRun(workspace, workspaceSize, executor, stream) == ACLNN_SUCCESS, ACLNN_ERR_INNER,
+               "GroupedMatmul aicore kernel launch failed.");
+    return ACLNN_SUCCESS;
+}
+
+aclnnStatus aclnnGroupedMatmulV4(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor,
+                                 aclrtStream stream) {
+    L2_DFX_PHASE_2(aclnnGroupedMatmulV4);
+    CHECK_COND(CommonOpExecutorRun(workspace, workspaceSize, executor, stream) == ACLNN_SUCCESS, ACLNN_ERR_INNER,
+               "GroupedMatmul aicore kernel launch failed.");
+    return ACLNN_SUCCESS;
+}
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/src/transformer/grouped_matmul/ophost/aclnn_grouped_matmul.h b/src/transformer/grouped_matmul/ophost/aclnn_grouped_matmul.h
new file mode 100644
index 00000000..8f508163
--- /dev/null
+++ b/src/transformer/grouped_matmul/ophost/aclnn_grouped_matmul.h
@@ -0,0 +1,63 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef OP_API_INC_GROUPED_MATMUL_H
+#define OP_API_INC_GROUPED_MATMUL_H
+#include "aclnn/aclnn_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * @brief First-phase interface of aclnnGroupedMatmul: computes the workspace size required by the
+ * given computation flow.
+ * @domain aclnn_ops_infer
+ *
+ * @param [in] x: x in the formula; supported data types are FLOAT16, BFLOAT16, INT8 and FLOAT32,
+ * data format ND, maximum list length 128.
+ * @param [in] weight: weight in the formula; supported data types are FLOAT16, BFLOAT16, INT8 and FLOAT32,
+ * data format ND, maximum list length 128.
+ * @param [in] biasOptional: bias in the formula; supported data types are FLOAT16, FLOAT32 and INT32,
+ * data format ND, maximum list length 128.
+ * @param [in] scaleOptional: quantization parameter; supported data type is UINT64, data format ND,
+ * maximum list length 128.
+ * @param [in] offsetOptional: quantization parameter; supported data type is FLOAT32, data format ND,
+ * maximum list length 128.
+ * @param [in] antiquantScaleOptional: antiquant (weight dequantization) parameter; supported data types are
+ * FLOAT16 and BFLOAT16, data format ND, maximum list length 128.
+ * @param [in] antiquantOffsetOptional: antiquant (weight dequantization) parameter; supported data types are
+ * FLOAT16 and BFLOAT16, data format ND, maximum list length 128.
+ * @param [in] groupListOptional: optional; describes the group boundaries of the inputs and outputs along
+ * the M axis; supported data type is INT64, maximum length 128.
+ * @param [in] splitItem: integer; indicates whether the output is split into tensors. 0/1 mean the output is
+ * a list of tensors; 2/3 mean the output is a single tensor. Defaults to 0.
+ * @param [out] y: y in the formula; supported data types are FLOAT16, BFLOAT16, INT8 and FLOAT32,
+ * data format ND, maximum list length 128.
+ * @param [out] workspaceSize: returns the workspace size that the caller must allocate on the NPU device.
+ * @param [out] executor: returns the op executor, which holds the computation flow.
+ * @return aclnnStatus: status code.
+ */
+__attribute__((visibility("default"))) aclnnStatus aclnnGroupedMatmulGetWorkspaceSize(
+    const aclTensorList* x, const aclTensorList* weight, const aclTensorList* biasOptional,
+    const aclTensorList* scaleOptional, const aclTensorList* offsetOptional,
+    const aclTensorList* antiquantScaleOptional, const aclTensorList* antiquantOffsetOptional,
+    const aclIntArray* groupListOptional, int64_t splitItem, const aclTensorList* y, uint64_t* workspaceSize,
+    aclOpExecutor** executor);
+
+/**
+ * @brief Second-phase interface of aclnnGroupedMatmul: performs the computation.
+ * @param [in] workspace: start address of the workspace allocated on the NPU device.
+ * @param [in] workspaceSize: size of the workspace allocated on the NPU device, obtained from the first-phase
+ * interface aclnnGroupedMatmulGetWorkspaceSize.
+ * @param [in] stream: the acl stream.
+ * @param [in] executor: the op executor, which holds the computation flow.
+ * @return aclnnStatus: status code.
+ */
+__attribute__((visibility("default"))) aclnnStatus aclnnGroupedMatmul(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor,
+                                                                      aclrtStream stream);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
\ No newline at end of file
diff --git a/src/transformer/grouped_matmul/ophost/aclnn_grouped_matmul_v2.h b/src/transformer/grouped_matmul/ophost/aclnn_grouped_matmul_v2.h
new file mode 100644
index 00000000..3de0d3c2
--- /dev/null
+++ b/src/transformer/grouped_matmul/ophost/aclnn_grouped_matmul_v2.h
@@ -0,0 +1,65 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef OP_API_INC_GROUPED_MATMUL_V2_H
+#define OP_API_INC_GROUPED_MATMUL_V2_H
+#include "aclnn/aclnn_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * @brief First-phase interface of aclnnGroupedMatmulV2: computes the workspace size required by the
+ * given computation flow.
+ * @domain aclnn_ops_infer
+ *
+ * @param [in] x: x in the formula; supported data types are FLOAT16, BFLOAT16, INT8 and FLOAT32,
+ * data format ND, maximum list length 128.
+ * @param [in] weight: weight in the formula; supported data types are FLOAT16, BFLOAT16, INT8 and FLOAT32,
+ * data format ND, maximum list length 128.
+ * @param [in] biasOptional: bias in the formula; supported data types are FLOAT16, FLOAT32 and INT32,
+ * data format ND, maximum list length 128.
+ * @param [in] scaleOptional: quantization parameter; supported data type is UINT64, data format ND,
+ * maximum list length 128.
+ * @param [in] offsetOptional: quantization parameter; supported data type is FLOAT32, data format ND,
+ * maximum list length 128.
+ * @param [in] antiquantScaleOptional: antiquant (weight dequantization) parameter; supported data types are
+ * FLOAT16 and BFLOAT16, data format ND, maximum list length 128.
+ * @param [in] antiquantOffsetOptional: antiquant (weight dequantization) parameter; supported data types are
+ * FLOAT16 and BFLOAT16, data format ND, maximum list length 128.
+ * @param [in] groupListOptional: optional; describes the group boundaries of the inputs and outputs along
+ * the grouped axis; supported data type is INT64, maximum length 128.
+ * @param [in] splitItem: integer; indicates whether the output is split into tensors. 0/1 mean the output is
+ * a list of tensors; 2/3 mean the output is a single tensor. Defaults to 0.
+ * @param [in] groupType: integer; the axis to be split. -1 means no splitting; 0 splits the M axis;
+ * 1 splits the N axis; 2 splits the K axis.
+ * @param [out] y: the output of the formula; supported data types are FLOAT16, BFLOAT16, INT8 and FLOAT32,
+ * data format ND, maximum list length 128.
+ * @param [out] workspaceSize: returns the workspace size that the caller must allocate on the NPU device.
+ * @param [out] executor: returns the op executor, which holds the computation flow.
+ * @return aclnnStatus: status code.
+ */
+__attribute__((visibility("default"))) aclnnStatus aclnnGroupedMatmulV2GetWorkspaceSize(
+    const aclTensorList* x, const aclTensorList* weight, const aclTensorList* biasOptional,
+    const aclTensorList* scaleOptional, const aclTensorList* offsetOptional,
+    const aclTensorList* antiquantScaleOptional, const aclTensorList* antiquantOffsetOptional,
+    const aclIntArray* groupListOptional, int64_t splitItem, int64_t groupType, const aclTensorList* y,
+    uint64_t* workspaceSize, aclOpExecutor** executor);
+
+/**
+ * @brief Second-phase interface of aclnnGroupedMatmulV2: performs the computation.
+ * @param [in] workspace: start address of the workspace allocated on the NPU device.
+ * @param [in] workspaceSize: size of the workspace allocated on the NPU device, obtained from the first-phase
+ * interface aclnnGroupedMatmulV2GetWorkspaceSize.
+ * @param [in] stream: the acl stream.
+ * @param [in] executor: the op executor, which holds the computation flow.
+ * @return aclnnStatus: status code.
+ */
+__attribute__((visibility("default"))) aclnnStatus aclnnGroupedMatmulV2(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor,
+                                                                        aclrtStream stream);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
\ No newline at end of file
diff --git a/src/transformer/grouped_matmul/ophost/aclnn_grouped_matmul_v3.h b/src/transformer/grouped_matmul/ophost/aclnn_grouped_matmul_v3.h
new file mode 100644
index 00000000..0812f1e3
--- /dev/null
+++ b/src/transformer/grouped_matmul/ophost/aclnn_grouped_matmul_v3.h
@@ -0,0 +1,65 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef OP_API_INC_GROUPED_MATMUL_V3_H
+#define OP_API_INC_GROUPED_MATMUL_V3_H
+#include "aclnn/aclnn_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * @brief First-phase interface of aclnnGroupedMatmulV3: computes the workspace size required by the
+ * given computation flow.
+ * @domain aclnn_ops_infer
+ *
+ * @param [in] x: x in the formula; supported data types are FLOAT16, BFLOAT16, INT8 and FLOAT32,
+ * data format ND, maximum list length 128.
+ * @param [in] weight: weight in the formula; supported data types are FLOAT16, BFLOAT16, INT8 and FLOAT32,
+ * data format ND, maximum list length 128.
+ * @param [in] biasOptional: bias in the formula; supported data types are FLOAT16, FLOAT32 and INT32,
+ * data format ND, maximum list length 128.
+ * @param [in] scaleOptional: quantization parameter; supported data type is UINT64, data format ND,
+ * maximum list length 128.
+ * @param [in] offsetOptional: quantization parameter; supported data type is FLOAT32, data format ND,
+ * maximum list length 128.
+ * @param [in] antiquantScaleOptional: antiquant (weight dequantization) parameter; supported data types are
+ * FLOAT16 and BFLOAT16, data format ND, maximum list length 128.
+ * @param [in] antiquantOffsetOptional: antiquant (weight dequantization) parameter; supported data types are
+ * FLOAT16 and BFLOAT16, data format ND, maximum list length 128.
+ * @param [in] groupListOptional: optional; describes the group boundaries of the inputs and outputs along
+ * the grouped axis; supported data type is INT64, maximum length 128.
+ * @param [in] splitItem: integer; indicates whether the output is split into tensors. 0/1 mean the output is
+ * a list of tensors; 2/3 mean the output is a single tensor. Defaults to 0.
+ * @param [in] groupType: integer; the axis to be split. -1 means no splitting; 0 splits the M axis;
+ * 1 splits the N axis; 2 splits the K axis.
+ * @param [out] y: the output of the formula; supported data types are FLOAT16, BFLOAT16, INT8 and FLOAT32,
+ * data format ND, maximum list length 128.
+ * @param [out] workspaceSize: returns the workspace size that the caller must allocate on the NPU device.
+ * @param [out] executor: returns the op executor, which holds the computation flow.
+ * @return aclnnStatus: status code.
+ */
+__attribute__((visibility("default"))) aclnnStatus aclnnGroupedMatmulV3GetWorkspaceSize(
+    const aclTensorList* x, const aclTensorList* weight, const aclTensorList* biasOptional,
+    const aclTensorList* scaleOptional, const aclTensorList* offsetOptional,
+    const aclTensorList* antiquantScaleOptional, const aclTensorList* antiquantOffsetOptional,
+    const aclTensor* groupListOptional, int64_t splitItem, int64_t groupType, const aclTensorList* y,
+    uint64_t* workspaceSize, aclOpExecutor** executor);
+
+/**
+ * @brief Second-phase interface of aclnnGroupedMatmulV3: performs the computation.
+ * @param [in] workspace: start address of the workspace allocated on the NPU device.
+ * @param [in] workspaceSize: size of the workspace allocated on the NPU device, obtained from the first-phase
+ * interface aclnnGroupedMatmulV3GetWorkspaceSize.
+ * @param [in] stream: the acl stream.
+ * @param [in] executor: the op executor, which holds the computation flow.
+ * @return aclnnStatus: status code.
+ */
+__attribute__((visibility("default"))) aclnnStatus aclnnGroupedMatmulV3(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor,
+                                                                        aclrtStream stream);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
\ No newline at end of file
diff --git a/src/transformer/grouped_matmul/ophost/aclnn_grouped_matmul_v4.h b/src/transformer/grouped_matmul/ophost/aclnn_grouped_matmul_v4.h
new file mode 100644
index 00000000..48b44708
--- /dev/null
+++ b/src/transformer/grouped_matmul/ophost/aclnn_grouped_matmul_v4.h
@@ -0,0 +1,90 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef OP_API_INC_GROUPED_MATMUL_V4_H
+#define OP_API_INC_GROUPED_MATMUL_V4_H
+#include "aclnn/aclnn_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+typedef enum {
+    GMM_ACT_TYPE_NONE = 0LL,
+    GMM_ACT_TYPE_RELU = 1LL,
+    GMM_ACT_TYPE_GELU_TANH = 2LL,
+    GMM_ACT_TYPE_GELU_ERR_FUNC = 3LL,
+    GMM_ACT_TYPE_FAST_GELU = 4LL,
+    GMM_ACT_TYPE_SILU = 5LL,
+} GMMActType;
+
+/**
+ * @brief First-phase interface of aclnnGroupedMatmulV4: computes the workspace size required by the
+ * given computation flow.
+ * @domain aclnn_ops_infer
+ *
+ * @param [in] x: x in the formula; supported data types are FLOAT16, BFLOAT16, INT8 and FLOAT32,
+ * data format ND, maximum list length 128.
+ * @param [in] weight: weight in the formula; supported data types are FLOAT16, BFLOAT16, INT8, FLOAT32 and INT4,
+ * data format ND, maximum list length 128.
+ * @param [in] biasOptional: bias in the formula; supported data types are FLOAT16, FLOAT32 and INT32,
+ * data format ND, maximum list length 128.
+ * @param [in] scaleOptional: quantization parameter; supported data types are UINT64, BFLOAT16 and FLOAT32,
+ * data format ND, maximum list length 128.
+ * @param [in] offsetOptional: quantization parameter; supported data type is FLOAT32, data format ND,
+ * maximum list length 128.
+ * @param [in] antiquantScaleOptional: antiquant (weight dequantization) parameter; supported data types are
+ * FLOAT16 and BFLOAT16, data format ND, maximum list length 128.
+ * @param [in] antiquantOffsetOptional: antiquant (weight dequantization) parameter; supported data types are
+ * FLOAT16 and BFLOAT16, data format ND, maximum list length 128.
+ * @param [in] perTokenScaleOptional: per-token quantization parameter; supported data type is FLOAT32,
+ * data format ND, maximum list length 1.
+ * @param [in] groupListOptional: optional; describes the group boundaries of the inputs and outputs along
+ * the grouped axis; supported data type is INT64. In some scenarios the maximum supported length is 1024
+ * (see the constraints in the interface documentation); in the remaining scenarios it is 128.
+ * @param [in] activationInputOptional: optional; the backward input of the activation function.
+ * @param [in] activationQuantScaleOptional: optional; reserved parameter.
+ * @param [in] activationQuantOffsetOptional: optional; reserved parameter.
+ * @param [in] splitItem: integer; indicates whether the output is split into tensors. 0/1 mean the output is
+ * a list of tensors; 2/3 mean the output is a single tensor. Defaults to 0.
+ * @param [in] groupType: integer; the axis to be split. -1 means no splitting; 0 splits the M axis;
+ * 1 splits the N axis; 2 splits the K axis.
+ * @param [in] groupListType: integer, 0 or 1. 0 means the values in groupListOptional are the cumulative sums
+ * (cumsum) of the group sizes along the grouped axis; 1 means the values are the per-group sizes.
+ * @param [in] actType: integer; the activation function type; see the GMMActType enum for the supported values.
+ * @param [out] out: the output of the formula; supported data types are FLOAT16, BFLOAT16, INT8 and FLOAT32,
+ * data format ND, maximum list length 128.
+ * @param [out] activationFeatureOutOptional: the input data of the activation function.
+ * @param [out] dynQuantScaleOutOptional: reserved parameter.
+ * @param [out] workspaceSize: returns the workspace size that the caller must allocate on the NPU device.
+ * @param [out] executor: returns the op executor, which holds the computation flow.
+ * @return aclnnStatus: status code.
+ */
+__attribute__((visibility("default"))) aclnnStatus aclnnGroupedMatmulV4GetWorkspaceSize(
+    const aclTensorList *x, const aclTensorList *weight, const aclTensorList *biasOptional,
+    const aclTensorList *scaleOptional, const aclTensorList *offsetOptional,
+    const aclTensorList *antiquantScaleOptional, const aclTensorList *antiquantOffsetOptional,
+    const aclTensorList *perTokenScaleOptional, const aclTensor *groupListOptional,
+    const aclTensorList *activationInputOptional, const aclTensorList *activationQuantScaleOptional,
+    const aclTensorList *activationQuantOffsetOptional, int64_t splitItem, int64_t groupType,
+    int64_t groupListType, int64_t actType, aclTensorList *out, aclTensorList *activationFeatureOutOptional,
+    aclTensorList *dynQuantScaleOutOptional, uint64_t *workspaceSize,
+    aclOpExecutor **executor);
+
+/**
+ * @brief Second-phase interface of aclnnGroupedMatmulV4: performs the computation.
+ * @param [in] workspace: start address of the workspace allocated on the NPU device.
+ * @param [in] workspaceSize: size of the workspace allocated on the NPU device, obtained from the first-phase
+ * interface aclnnGroupedMatmulV4GetWorkspaceSize.
+ * @param [in] stream: the acl stream.
+ * @param [in] executor: the op executor, which holds the computation flow.
+ * @return aclnnStatus: status code.
+ */
+__attribute__((visibility("default"))) aclnnStatus aclnnGroupedMatmulV4(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor,
+                                                                        aclrtStream stream);
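+
+/*
+ * Editor's note (illustrative only, not part of the original header): a typical V4 parameter
+ * combination for a grouped MoE FFN matmul; the values below merely restate the parameter
+ * semantics documented above.
+ *
+ *   int64_t splitItem = 3;                // single fused x and single fused out
+ *   int64_t groupType = 0;                // grouped along the M axis
+ *   int64_t groupListType = 1;            // groupList holds per-group sizes, not cumulative sums
+ *   int64_t actType = GMM_ACT_TYPE_SILU;  // fuse SiLU into the matmul epilogue
+ */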
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
\ No newline at end of file
diff --git a/src/transformer/grouped_matmul/ophost/fallback_grouped_matmul.cpp b/src/transformer/grouped_matmul/ophost/fallback_grouped_matmul.cpp
new file mode 100644
index 00000000..c1e9a14b
--- /dev/null
+++ b/src/transformer/grouped_matmul/ophost/fallback_grouped_matmul.cpp
@@ -0,0 +1,235 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+#include "fallback_comm.h"
+#include "fallback.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+namespace fallback {
+using namespace ge;
+using namespace gert;
+
+constexpr size_t INDEX_GMM_INPUT_X = 0;
+constexpr size_t INDEX_GMM_INPUT_WEIGHT = 1;
+constexpr size_t INDEX_GMM_INPUT_BIAS = 2;
+constexpr size_t INDEX_GMM_INPUT_SCALE = 3;
+constexpr size_t INDEX_GMM_INPUT_OFFSET = 4;
+constexpr size_t INDEX_GMM_INPUT_ANTIQUANT_SCALE = 5;
+constexpr size_t INDEX_GMM_INPUT_ANTIQUANT_OFFSET = 6;
+constexpr size_t INDEX_GMM_INPUT_GROUP_LIST = 7;
+constexpr size_t INDEX_GMM_INPUT_PER_TOKEN_SCALE = 8;
+constexpr size_t INDEX_GMM_OUTPUT_Y = 0;
+constexpr size_t INDEX_GMM_ATTR_SPLIT_ITEM = 0;
+constexpr size_t INDEX_GMM_ATTR_TRANSPOSE_WEIGHT = 2;
+constexpr size_t INDEX_GMM_ATTR_TRANSPOSE_X = 3;
+constexpr size_t INDEX_GMM_ATTR_GROUP_TYPE = 4;
+constexpr size_t INDEX_GMM_ATTR_GROUP_LIST_TYPE = 5;
+constexpr size_t INDEX_GMM_ATTR_ACT_TYPE = 6;
+
+inline aclTensorList* ConvertType(aclTensorList* geTensorList) {
+    return geTensorList;
+}
+
+inline aclTensor* GeTensor2AclTensor(const gert::Tensor* geTensor, bool enableTranspose, bool enableNZ = false) {
+    if (geTensor == nullptr) {
+        return nullptr;
+    }
+    auto storageShape = geTensor->GetStorageShape();
+    if (storageShape.GetDimNum() <= 1) {
+        return ConvertType(geTensor);
+    }
+    std::vector<int64_t> storageShapeVec;
+    for (size_t i = 0; i < storageShape.GetDimNum(); ++i) {
+        storageShapeVec.push_back(storageShape.GetDim(i));
+    }
+
+    static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor);
+    OPS_CHECK(aclCreateTensor == nullptr, OPS_LOG_E("aclnnfallback", "aclCreateTensor nullptr"), return nullptr);
+
+    void* deviceAddr = (void*)geTensor->GetAddr();
+    // convert data type
+    auto dataType_ge = geTensor->GetDataType();
+    auto dataType = ToAclDataType(dataType_ge);
+
+    // convert view shape
+    auto origin_shape = geTensor->GetOriginShape();
+    std::vector<int64_t> viewShape;
+    for (size_t i = 0; i < origin_shape.GetDimNum(); ++i) {
+        viewShape.push_back(origin_shape.GetDim(i));
+    }
+    // Compute the strides of a contiguous tensor.
+    std::vector<int64_t> strides(viewShape.size(), 1);
+    for (int64_t i = viewShape.size() - 2; i >= 0; i--) {
+        strides[i] = viewShape[i + 1] * strides[i + 1];
+    }
+    // When the tensor is transposed, the last two dims of strides and viewShape are swapped.
+    if (enableTranspose) {
+        // dimM: the second-to-last dim; dimN: the last dim
+        auto dimM = viewShape.size() - 2;
+        auto dimN = viewShape.size() - 1;
+        auto swap = strides[dimN];
+        strides[dimN] = strides[dimM];
+        strides[dimM] = swap;
+        // swap viewShape
+        swap = viewShape[dimN];
+        viewShape[dimN] = viewShape[dimM];
+        viewShape[dimM] = swap;
+    }
+    auto aclFormat = aclFormat::ACL_FORMAT_ND;
+    if (enableNZ && GetPrimaryFormat(geTensor->GetStorageFormat()) == ge::Format::FORMAT_FRACTAL_NZ) {
+        aclFormat = aclFormat::ACL_FORMAT_FRACTAL_NZ;
+    }
+    aclTensor* out = aclCreateTensor(viewShape.data(), viewShape.size(), dataType, strides.data(),
+                                     0, aclFormat, storageShapeVec.data(), storageShapeVec.size(), deviceAddr);
+    OPS_CHECK(out == nullptr, OPS_LOG_E("aclnnfallback", "out nullptr"), return nullptr);
+
+    return out;
+}
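+
+// Editor's note (illustrative only): stride handling in GeTensor2AclTensor for a 2-D tensor.
+// A contiguous view shape [4, 8] gets strides [8, 1]; with enableTranspose the last two dims of
+// both arrays are swapped, yielding view shape [8, 4] with strides [1, 8], i.e. a transposed
+// view over the same storage.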
+
+graphStatus PrepareGeTensorVector(OpExecuteContext* host_api_ctx, std::vector<const gert::Tensor*>& tensorVector,
+                                  size_t index) {
+    size_t cnt = 0;
+    while (true) {
+        auto inputGe = host_api_ctx->GetDynamicInputTensor(index, cnt);
+        if (inputGe == nullptr) {
+            break;
+        }
+        tensorVector.push_back(inputGe);
+        cnt++;
+    }
+    return GRAPH_SUCCESS;
+}
+
+graphStatus PrepareAclTensorVector(OpExecuteContext* host_api_ctx, std::vector<const aclTensor*>& tensorVector,
+                                   size_t index, bool enableTranspose, bool enableNZ) {
+    size_t cnt = 0;
+    while (true) {
+        auto inputGe = host_api_ctx->GetDynamicInputTensor(index, cnt);
+        if (inputGe == nullptr) {
+            break;
+        }
+        auto inputAcl = GeTensor2AclTensor(inputGe, enableTranspose, enableNZ);
+        tensorVector.push_back(inputAcl);
+        cnt++;
+    }
+    return GRAPH_SUCCESS;
+}
+
+graphStatus PrepareOutputTensorVector(OpExecuteContext* host_api_ctx, std::vector<gert::Tensor*>& tensorVector,
+                                      size_t index, size_t numGeWeight, int32_t splitItem) {
+    size_t numGeY = 0;
+    if (0 == splitItem || 1 == splitItem) {  // Length of tensorListY equals that of tensorListWeight when split_item is 0/1.
+        numGeY = numGeWeight;
+    } else if (2 == splitItem || 3 == splitItem) {  // Length of tensorListY equals 1 when split_item is 2/3.
+        numGeY = 1;
+    } else {
+        OPS_LOG_E("aclnnfallback", "Invalid value of split_item: %d, which must be one of 0/1/2/3.", splitItem);
+        return GRAPH_FAILED;
+    }
+
+    for (size_t k = 0; k < numGeY; k++) {
+        auto outputGe = host_api_ctx->GetOutputTensor(index + k);
+        if (outputGe == nullptr) {
+            return GRAPH_FAILED;
+        }
+        tensorVector.push_back(outputGe);
+    }
+    return GRAPH_SUCCESS;
+}
+
+static graphStatus GroupedMatmulExecuteFunc(OpExecuteContext* host_api_ctx)
+{
+    OPS_CHECK(host_api_ctx == nullptr, OPS_LOG_E("aclnnfallback", "host_api_ctx is null"), return GRAPH_FAILED);
+
+    auto attrs = host_api_ctx->GetAttrs();
+    OPS_CHECK(attrs == nullptr, OPS_LOG_E("aclnnfallback", "attrs is null"), return GRAPH_FAILED);
+    const int64_t* splitItemGe = attrs->GetAttrPointer<int64_t>(INDEX_GMM_ATTR_SPLIT_ITEM);
+    OPS_CHECK(splitItemGe == nullptr, OPS_LOG_E("aclnnfallback", "splitItemGe is null"), return GRAPH_FAILED);
+    const bool* isWeightTransposed = attrs->GetAttrPointer<bool>(INDEX_GMM_ATTR_TRANSPOSE_WEIGHT);
+    OPS_CHECK(isWeightTransposed == nullptr, OPS_LOG_E("aclnnfallback", "isWeightTransposed is null"),
+              return GRAPH_FAILED);
+    const bool* isXTransposed = attrs->GetAttrPointer<bool>(INDEX_GMM_ATTR_TRANSPOSE_X);
+    OPS_CHECK(isXTransposed == nullptr, OPS_LOG_E("aclnnfallback", "isXTransposed is null"), return GRAPH_FAILED);
+    const int64_t* groupTypeGe = attrs->GetAttrPointer<int64_t>(INDEX_GMM_ATTR_GROUP_TYPE);
+    OPS_CHECK(groupTypeGe == nullptr, OPS_LOG_E("aclnnfallback", "groupTypeGe is null"), return GRAPH_FAILED);
+    const int64_t* groupListTypeGe = attrs->GetAttrPointer<int64_t>(INDEX_GMM_ATTR_GROUP_LIST_TYPE);
+    OPS_CHECK(groupListTypeGe == nullptr, OPS_LOG_E("aclnnfallback", "groupListTypeGe is null"), return GRAPH_FAILED);
+    const int64_t* actTypeGe = attrs->GetAttrPointer<int64_t>(INDEX_GMM_ATTR_ACT_TYPE);
+    OPS_CHECK(actTypeGe == nullptr, OPS_LOG_E("aclnnfallback", "actTypeGe is null"), return GRAPH_FAILED);
+
+    static const auto aclCreateTensorList = GET_OP_API_FUNC(aclCreateTensorList);
+    OPS_CHECK(aclCreateTensorList == nullptr,
+              OPS_LOG_E("aclnnfallback", "Get opapi func aclCreateTensorList failed"), return GRAPH_FAILED);
+
+    std::vector<const aclTensor*> aclTensorVectorX;
+    PrepareAclTensorVector(host_api_ctx, aclTensorVectorX, INDEX_GMM_INPUT_X, *isXTransposed, false);
+    auto aclTensorListX = aclCreateTensorList(aclTensorVectorX.data(), aclTensorVectorX.size());
+
+    std::vector<const aclTensor*> aclTensorVectorWeight;
+    PrepareAclTensorVector(host_api_ctx, aclTensorVectorWeight, INDEX_GMM_INPUT_WEIGHT, *isWeightTransposed, true);
+    size_t numGeWeight = aclTensorVectorWeight.size();
+    auto aclTensorListWeight = aclCreateTensorList(aclTensorVectorWeight.data(), aclTensorVectorWeight.size());
+
+    std::vector<const gert::Tensor*> geTensorVectorBias;
+    PrepareGeTensorVector(host_api_ctx, geTensorVectorBias, INDEX_GMM_INPUT_BIAS);
+
+    std::vector<const gert::Tensor*> geTensorVectorScale;
+    PrepareGeTensorVector(host_api_ctx, geTensorVectorScale, INDEX_GMM_INPUT_SCALE);
+
+    std::vector<const gert::Tensor*> geTensorVectorOffset;
+    PrepareGeTensorVector(host_api_ctx, geTensorVectorOffset, INDEX_GMM_INPUT_OFFSET);
+
+    std::vector<const gert::Tensor*> geTensorVectorAntiquantScale;
+    PrepareGeTensorVector(host_api_ctx, geTensorVectorAntiquantScale, INDEX_GMM_INPUT_ANTIQUANT_SCALE);
+
+    std::vector<const gert::Tensor*> geTensorVectorAntiquantOffset;
+    PrepareGeTensorVector(host_api_ctx, geTensorVectorAntiquantOffset, INDEX_GMM_INPUT_ANTIQUANT_OFFSET);
+
+    auto groupListTensor = host_api_ctx->GetOptionalInputTensor(INDEX_GMM_INPUT_GROUP_LIST);
+
+    auto perTokenScale = ConvertType(host_api_ctx->GetOptionalInputTensor(INDEX_GMM_INPUT_PER_TOKEN_SCALE));
+    if (perTokenScale == nullptr) {
+        std::vector<int64_t> shape{0};
+        static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor);
+        OPS_CHECK(aclCreateTensor == nullptr, OPS_LOG_E("aclnnfallback", "aclCreateTensor nullptr"),
+                  return GRAPH_FAILED);
+        perTokenScale = aclCreateTensor(shape.data(), shape.size(), aclDataType::ACL_FLOAT, shape.data(),
+                                        0, aclFormat::ACL_FORMAT_ND, shape.data(), shape.size(), nullptr);
+        OPS_CHECK(perTokenScale == nullptr, OPS_LOG_E("aclnnfallback", "perTokenScale nullptr"), return GRAPH_FAILED);
+    }
+    std::vector<const aclTensor*> geTensorVectorPerTokenScale{perTokenScale};
+    auto aclTensorListPerTokenScale = aclCreateTensorList(geTensorVectorPerTokenScale.data(),
+                                                          geTensorVectorPerTokenScale.size());
+
+    std::vector<gert::Tensor*> geTensorVectorY;
+    PrepareOutputTensorVector(host_api_ctx, geTensorVectorY, INDEX_GMM_OUTPUT_Y, numGeWeight, *splitItemGe);
+
+    aclTensorList* activationInputOptional = nullptr;
+    aclTensorList* activationQuantScaleOptional = nullptr;
+    aclTensorList* activationQuantOffsetOptional = nullptr;
+    aclTensorList* actFeatureOutOptional = nullptr;
+    aclTensorList* dynQuantScaleOutOptional = nullptr;
+
+    // execute opapi
+    auto api_ret = EXEC_OPAPI_CMD(aclnnGroupedMatmulV4, aclTensorListX, aclTensorListWeight, geTensorVectorBias,
+                                  geTensorVectorScale, geTensorVectorOffset, geTensorVectorAntiquantScale,
+                                  geTensorVectorAntiquantOffset, aclTensorListPerTokenScale, groupListTensor,
+                                  activationInputOptional, activationQuantScaleOptional, activationQuantOffsetOptional,
+                                  *splitItemGe, *groupTypeGe, *groupListTypeGe, *actTypeGe,
+                                  geTensorVectorY, actFeatureOutOptional, dynQuantScaleOutOptional);
+    OPS_CHECK(api_ret != GRAPH_SUCCESS, OPS_LOG_E("aclnnfallback", "api_ret failed:%u", api_ret), return GRAPH_FAILED);
+    return GRAPH_SUCCESS;
+}
+
+IMPL_OP(GroupedMatmul).OpExecuteFunc(GroupedMatmulExecuteFunc);
+
+}  // namespace fallback
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/src/transformer/grouped_matmul/ophost/grouped_matmul.cpp b/src/transformer/grouped_matmul/ophost/grouped_matmul.cpp
new file mode 100644
index 00000000..755d21b9
--- /dev/null
+++ b/src/transformer/grouped_matmul/ophost/grouped_matmul.cpp
@@ -0,0 +1,76 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+#include "grouped_matmul.h"
+#include "opdev/op_log.h"
+#include "opdev/op_dfx.h"
+#include "opdev/shape_utils.h"
+#include "opdev/make_op_executor.h"
+
+using namespace op;
+
+namespace l0op {
+OP_TYPE_REGISTER(GroupedMatmul);
+
+const aclTensorList *GroupedMatmul(const aclTensorList *x,
+                                   const aclTensorList *weight,
+                                   const aclTensorList *biasOptional,
+                                   const aclTensorList *scaleOptional,
+                                   const aclTensorList *offsetOptional,
+                                   const aclTensorList *antiquantScaleOptional,
+                                   const aclTensorList *antiquantOffsetOptional,
+                                   const aclTensor *groupListOptional,
+                                   const aclTensor *perTokenScaleOptional,
+                                   int64_t splitItem,
+                                   op::DataType yDtype,
+                                   bool transposeWeight,
+                                   bool transposeX,
+                                   int64_t groupType,
+                                   int64_t groupListType,
+                                   int64_t actType,
+                                   size_t outLength,
+                                   aclOpExecutor *executor) {
+    L0_DFX(GroupedMatmul, x, weight, biasOptional, scaleOptional, offsetOptional, antiquantScaleOptional,
+           antiquantOffsetOptional, groupListOptional, perTokenScaleOptional, splitItem, yDtype,
+           transposeWeight, transposeX, groupType, groupListType, actType, outLength);
+    std::vector<const aclTensor*> tensorsVec;
+    const aclTensor *x0 = x->Size() > 0 ?
(*x)[0] : nullptr; + if (x0 == nullptr) { + OP_LOGE(ACLNN_ERR_PARAM_INVALID, "(*x)[0] is nullptr."); + return nullptr; + } + for (size_t i(0); i < outLength; ++i) { + tensorsVec.emplace_back(executor->AllocTensor(yDtype, x0->GetStorageFormat(), x0->GetOriginalFormat())); + } + auto out = executor->AllocTensorList(tensorsVec.data(), outLength); + auto ret = INFER_SHAPE(GroupedMatmul, + OP_INPUT(x, weight, biasOptional, scaleOptional, offsetOptional, antiquantScaleOptional, + antiquantOffsetOptional, groupListOptional, perTokenScaleOptional), + OP_OUTPUT(out), + OP_ATTR(splitItem, -1, transposeWeight, transposeX, groupType, groupListType, actType)); + if (ret != ACLNN_SUCCESS) { + OP_LOGE(ACLNN_ERR_PARAM_INVALID, "InferShape failed."); + return nullptr; + } + ret = ADD_TO_LAUNCHER_LIST_AICORE(GroupedMatmul, + OP_INPUT(x, weight, biasOptional, scaleOptional, offsetOptional, + antiquantScaleOptional, antiquantOffsetOptional, groupListOptional, + perTokenScaleOptional), + OP_OUTPUT(out), + OP_ATTR(splitItem, -1, transposeWeight, transposeX, groupType, groupListType, + actType)); + if (ret != ACLNN_SUCCESS) { + OP_LOGE(ACLNN_ERR_PARAM_INVALID, "ADD_TO_LAUNCHER_LIST_AICORE failed."); + return nullptr; + } + return out; +} + +} // namespace l0op diff --git a/src/transformer/grouped_matmul/ophost/grouped_matmul.h b/src/transformer/grouped_matmul/ophost/grouped_matmul.h new file mode 100644 index 00000000..fd042e4e --- /dev/null +++ b/src/transformer/grouped_matmul/ophost/grouped_matmul.h @@ -0,0 +1,36 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef OP_API_INC_LEVEL0_OP_GROUPED_MATMUL_OP_H +#define OP_API_INC_LEVEL0_OP_GROUPED_MATMUL_OP_H + +#include "opdev/op_executor.h" + +namespace l0op { +const aclTensorList *GroupedMatmul(const aclTensorList *x, + const aclTensorList *weight, + const aclTensorList *biasOptional, + const aclTensorList *scaleOptional, + const aclTensorList *offsetOptional, + const aclTensorList *antiquantScaleOptional, + const aclTensorList *antiquantOffsetOptional, + const aclTensor *groupListOptional, + const aclTensor *perTokenScaleOptional, + int64_t splitItem, + op::DataType yDtype, + bool transposeWeight, + bool transposeX, + int64_t groupType, + int64_t groupListType, + int64_t actType, + size_t outLength, + aclOpExecutor *executor); +} + +#endif \ No newline at end of file diff --git a/src/transformer/grouped_matmul/ophost/grouped_matmul_def.cpp b/src/transformer/grouped_matmul/ophost/grouped_matmul_def.cpp new file mode 100644 index 00000000..27cf2dcd --- /dev/null +++ b/src/transformer/grouped_matmul/ophost/grouped_matmul_def.cpp @@ -0,0 +1,142 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file grouped_matmul.cpp + * \brief + */ + +#include +#include "register/op_def_registry.h" +namespace ops { +class GroupedMatmul : public OpDef { +public: + explicit GroupedMatmul(const char* name) : OpDef(name) + { + this->Input("x") + .ParamType(DYNAMIC) + .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_INT8, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("weight") + .ParamType(DYNAMIC) + .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_FLOAT, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT4, ge::DT_INT4}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("bias") + .ParamType(DYNAMIC) + .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32, ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_FLOAT16, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("scale") + .ParamType(DYNAMIC) + .DataType({ge::DT_UINT64, ge::DT_UINT64, ge::DT_UINT64, ge::DT_UINT64, ge::DT_UINT64, ge::DT_UINT64, ge::DT_BF16, ge::DT_BF16, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_UINT64, ge::DT_UINT64}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("offset") + .ParamType(DYNAMIC) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("antiquant_scale") + .ParamType(DYNAMIC) + .DataType({ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("antiquant_offset") + .ParamType(DYNAMIC) + .DataType({ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + 
this->Input("group_list") + .ParamType(OPTIONAL) + .DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("per_token_scale") + .ParamType(OPTIONAL) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("y") + .ParamType(DYNAMIC) + .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_INT8, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT, ge::DT_BF16, ge::DT_BF16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Attr("split_item").AttrType(OPTIONAL).Int(0); + this->Attr("dtype").AttrType(OPTIONAL).Int(0); // Reserved parameter, indicate output data type, not enabled. + this->Attr("transpose_weight").AttrType(OPTIONAL).Bool(false); // Reserved parameter, indicate wether input weight is transposed, not enabled. + this->Attr("transpose_x").AttrType(OPTIONAL).Bool(false); // Reserved parameter, indicate wether input x is transposed, not enabled. + this->Attr("group_type").AttrType(OPTIONAL).Int(-1); // Indicates the splited dimension. + this->Attr("group_list_type").AttrType(OPTIONAL).Int(0); // Indicates whether the value in group_dist is cumsum or count. + this->Attr("act_type").AttrType(OPTIONAL).Int(0); // Indicate activation function type. 
+ + OpAICoreConfig aicore_config; + aicore_config.DynamicCompileStaticFlag(true) + .DynamicFormatFlag(true) + .DynamicRankSupportFlag(true) + .DynamicShapeSupportFlag(true) + .NeedCheckSupportFlag(false) + .PrecisionReduceFlag(true) + .ExtendCfgInfo("prebuildPattern.value", "Opaque") + .ExtendCfgInfo("coreType.value", "AiCore") + .ExtendCfgInfo("aclnnSupport.value", "support_aclnn") + .ExtendCfgInfo("jitCompile.flag", "static_false,dynamic_false"); + + this->AICore().AddConfig("ascend910b", aicore_config); + + OpAICoreConfig config310P; + config310P.Input("x") + .ParamType(DYNAMIC) + .DataType({ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND}); + config310P.Input("weight") + .ParamType(DYNAMIC) + .DataType({ge::DT_FLOAT16}) + .Format({ge::FORMAT_FRACTAL_NZ}); + config310P.Input("bias") + .ParamType(DYNAMIC) + .DataType({ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND}); + config310P.Input("scale") + .ParamType(DYNAMIC) + .DataType({ge::DT_UINT64}) + .Format({ge::FORMAT_ND}); + config310P.Input("offset") + .ParamType(DYNAMIC) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}); + config310P.Input("antiquant_scale") + .ParamType(DYNAMIC) + .DataType({ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND}); + config310P.Input("antiquant_offset") + .ParamType(DYNAMIC) + .DataType({ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND}); + config310P.Input("group_list") + .ParamType(OPTIONAL) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}); + config310P.Input("per_token_scale") + .ParamType(OPTIONAL) + .DataType({ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND}); + config310P.Output("y") + .ParamType(DYNAMIC) + .DataType({ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND}); + config310P.DynamicCompileStaticFlag(true) + .DynamicFormatFlag(true) + .DynamicRankSupportFlag(true) + .DynamicShapeSupportFlag(true) + .NeedCheckSupportFlag(false) + .PrecisionReduceFlag(true) + .ExtendCfgInfo("prebuildPattern.value", "Opaque") + .ExtendCfgInfo("coreType.value", "AiCore") + .ExtendCfgInfo("aclnnSupport.value", "support_aclnn") + .ExtendCfgInfo("jitCompile.flag", "static_false,dynamic_false"); + + this->AICore().AddConfig("ascend310p", config310P); + } +}; + +OP_ADD(GroupedMatmul); +} diff --git a/src/transformer/grouped_matmul/ophost/grouped_matmul_proto.cpp b/src/transformer/grouped_matmul/ophost/grouped_matmul_proto.cpp new file mode 100644 index 00000000..f53c42c6 --- /dev/null +++ b/src/transformer/grouped_matmul/ophost/grouped_matmul_proto.cpp @@ -0,0 +1,1492 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file grouped_matmul.cc + * \brief + */ +#include "register/op_impl_registry.h" +#include "log/ops_log.h" +#include "platform/platform_info.h" + +using namespace ge; +namespace ops { + +static constexpr size_t INDEX_IN_X = 0; +static constexpr size_t INDEX_IN_WEIGHT = 1; +static constexpr size_t INDEX_IN_BIAS = 2; +static constexpr size_t INDEX_IN_SCALE = 3; +static constexpr size_t INDEX_IN_OFFSET = 4; +static constexpr size_t INDEX_IN_ANTIQUANT_SCALE = 5; +static constexpr size_t INDEX_IN_ANTIQUANT_OFFSET = 6; +static constexpr size_t INDEX_IN_GROUP_LIST = 7; +static constexpr size_t INDEX_IN_PERTOKEN_SCALE = 8; + +static constexpr size_t INDEX_OUT_Y = 0; +static constexpr size_t INDEX_ATTR_SPLIT_ITEM = 0; +static constexpr size_t INDEX_ATTR_OUTPUT_DTYPE = 1; +static constexpr size_t INDEX_ATTR_TRANSPOSE_W = 2; +static constexpr size_t INDEX_ATTR_TRANSPOSE_X = 3; +static constexpr size_t INDEX_ATTR_GROUP_TYPE = 4; +static constexpr size_t INDEX_ATTR_GROUP_LIST_TYPE = 5; +static constexpr size_t INDEX_ATTR_ACT_TYPE = 6; + +static constexpr int64_t X_Y_SEPARATED = 0; // x,y have been separated +static constexpr int64_t Y_SEPARATED = 1; // y has been separated +static constexpr int64_t X_SEPARATED = 2; // x has been separated +static constexpr int64_t NO_SEPARATED = 3; // x,y have not been separated + +static constexpr int64_t NO_SPLIT = -1; +static constexpr int64_t SPLIT_M = 0; +static constexpr int64_t SPLIT_N = 1; +static constexpr int64_t SPLIT_K = 2; + +static constexpr int64_t MAX_GROUP_LIST_SIZE_ARRAY = 128; +static constexpr int64_t MAX_GROUP_LIST_SIZE_TENSOR = 1024; +static constexpr int64_t MAX_INNER_AXIS = 65535; +static constexpr size_t MAX_FM_DIM = 6; +static constexpr size_t MIN_FM_DIM = 2; +static constexpr size_t SEPARATED_WEIGHT_DIM = 2; +static constexpr size_t SPLIT_M_SINGLE_WEIGHT_DIM = 3; +static constexpr size_t SPLIT_K_SINGLE_WEIGHT_DIM = 2; + +enum class PlatformID { + UNKNOWN, + ASCEND310P, + ASCEND910B +}; + +enum class GMMActType : int64_t { + GMM_ACT_TYPE_NONE, + GMM_ACT_TYPE_RELU, + GMM_ACT_TYPE_GELU_TANH, + GMM_ACT_TYPE_GELU_ERR_FUNC, + GMM_ACT_TYPE_FAST_GELU, + GMM_ACT_TYPE_SILU, + END_ACT_TYPE_ENUM +}; + +static const std::map OUTPUT_DTYPE_MAP = { + {0, DataType::DT_FLOAT16}, + {1, DataType::DT_BF16}, + {-1, DataType::DT_INT8} +}; + +struct GMMParamsInfo { + size_t numX; + size_t numWeight; + size_t numY; + int64_t lenGroupList; + size_t groupNum; + size_t numScale; + size_t numOffset; + size_t numAntiquantScale; + size_t numAntiquantOffset; + PlatformID platform; +}; + +struct GMMAttrs { + int64_t splitItem; + int64_t groupType; + bool transposeX; + bool transposeWeight; + int64_t activeType; +}; + +static inline std::string ToString(const std::int64_t value) { + return std::to_string(value); +} + +static ge::graphStatus CheckSplitItem(int64_t splitItem) { + if (splitItem == X_Y_SEPARATED || splitItem == NO_SEPARATED || + splitItem == X_SEPARATED || splitItem == Y_SEPARATED) { + return GRAPH_SUCCESS; + } else { + return GRAPH_FAILED; + } +} + +static bool IsTensorListNullOrEmpty(const gert::InferShapeContext* context, size_t index) { + auto shape = context->GetDynamicInputShape(index, 0); + if (shape == nullptr) { + return true; + } + if (shape->GetDimNum() == 0 || (shape->GetDimNum() == 1 && shape->GetDim(0) == 0)) { + if (context->GetDynamicInputShape(index, 1) == nullptr) { + return true; + } + } + return false; +} + +static ge::graphStatus CheckGroupType(const gert::InferShapeContext* context, int64_t groupType) { + if (groupType == NO_SPLIT || 
groupType == SPLIT_M || groupType == SPLIT_K) { + return GRAPH_SUCCESS; + } else if (groupType == SPLIT_N) { + OPS_LOG_E(context->GetNodeName(), "Splitting tensor along the N-axis is not supported yet."); + return GRAPH_FAILED; + } else { + OPS_LOG_E(context->GetNodeName(), "GroupType can only be -1/0/2 now, but actually %ld is given.", groupType); + return GRAPH_FAILED; + } +} + +static ge::graphStatus UpdateShapeYMultiDim(gert::InferShapeContext* context, size_t idxY, const gert::Shape* xShape, + const gert::Shape* weightShape) { + gert::Shape* yShape = context->GetOutputShape(idxY); + OPS_LOG_E_IF_NULL(context, yShape, return ge::GRAPH_FAILED); + *yShape = *xShape; + size_t dimY = yShape->GetDimNum(); + const gert::RuntimeAttrs* attrs = context->GetAttrs(); + OPS_LOG_E_IF_NULL(context, attrs, return ge::GRAPH_FAILED); + const bool* transposeWPtr = attrs->GetAttrPointer(INDEX_ATTR_TRANSPOSE_W); + const bool* transposeXPtr = attrs->GetAttrPointer(INDEX_ATTR_TRANSPOSE_X); + + OPS_LOG_E_IF_NULL(context, weightShape, return ge::GRAPH_FAILED); + if (transposeWPtr != nullptr && *transposeWPtr) { + yShape->SetDim(dimY - 1, weightShape->GetDim(weightShape->GetDimNum() - 2)); // -2: transpose weight + } else { + yShape->SetDim(dimY - 1, weightShape->GetDim(weightShape->GetDimNum() - 1)); + } + if (transposeXPtr != nullptr && *transposeXPtr) { + yShape->SetDim(dimY - 2, xShape->GetDim(xShape->GetDimNum() - 1)); // -2: last two dim of Y + } + return GRAPH_SUCCESS; +} + +static ge::graphStatus UpdateShapeY(gert::InferShapeContext* context, size_t idxY, std::vector yDims) { + gert::Shape* yShape = context->GetOutputShape(idxY); + OPS_LOG_E_IF_NULL(context, yShape, return ge::GRAPH_FAILED); + yShape->SetDimNum(yDims.size()); + for (size_t dim = 0; dim < yDims.size(); ++dim) { + yShape->SetDim(dim, yDims[dim]); + } + return GRAPH_SUCCESS; +} + +static ge::graphStatus UpdateMultipleShapeY(gert::InferShapeContext* context, const gert::Tensor* groupListTensor, + size_t weightDimN, bool isXTransposed, size_t xDimM) { + auto groupListData = groupListTensor->GetData(); + OPS_CHECK(groupListData == nullptr, + OPS_LOG_E(context->GetNodeName(), "Failed to obtain necessary data from groupListTensor."), + return GRAPH_FAILED); + const gert::RuntimeAttrs* attrs = context->GetAttrs(); + OPS_LOG_E_IF_NULL(context, attrs, return ge::GRAPH_FAILED); + const int64_t* groupListTypePtr = attrs->GetAttrPointer(INDEX_ATTR_GROUP_LIST_TYPE); + OPS_LOG_E_IF_NULL(context, groupListTypePtr, return ge::GRAPH_FAILED); + const gert::Shape* x0Shape = context->GetDynamicInputShape(INDEX_IN_X, 0); + OPS_LOG_E_IF_NULL(context, x0Shape, return ge::GRAPH_FAILED); + const gert::Shape* weight0Shape = context->GetDynamicInputShape(INDEX_IN_WEIGHT, 0); + OPS_LOG_E_IF_NULL(context, weight0Shape, return ge::GRAPH_FAILED); + int64_t preOffset = 0; + for (int idx = 0; idx < groupListTensor->GetShapeSize(); ++idx) { + const gert::Shape* weightShape = context->GetDynamicInputShape(INDEX_IN_WEIGHT, idx); + if (weightShape == nullptr) { + weightShape = weight0Shape; + } + if (isXTransposed) { + const gert::Shape* xShape = context->GetDynamicInputShape(INDEX_IN_X, idx); + if (xShape == nullptr) { + xShape = x0Shape; + } + std::vector yDims = {xShape->GetDim(xDimM), weightShape->GetDim(weightDimN)}; + OPS_CHECK(UpdateShapeY(context, INDEX_OUT_Y + idx, yDims) != GRAPH_SUCCESS, OPS_LOG_E(context->GetNodeName(), + "Failed to update shape of y."), return GRAPH_FAILED); + } else { + std::vector yDims; + if (*groupListTypePtr == 0) { + yDims = 
{groupListData[idx] - preOffset, weightShape->GetDim(weightDimN)}; + } else if (*groupListTypePtr == 1) { + yDims = {groupListData[idx], weightShape->GetDim(weightDimN)}; + } else { + OPS_LOG_E(context->GetNodeName(), "Invalid groupListType = %ld", *groupListTypePtr); + return GRAPH_FAILED; + } + OPS_CHECK(UpdateShapeY(context, INDEX_OUT_Y + idx, yDims) != GRAPH_SUCCESS, OPS_LOG_E(context->GetNodeName(), + "Failed to update shape of y."), return GRAPH_FAILED); + preOffset = groupListData[idx]; + } + } + + return GRAPH_SUCCESS; +} + +static ge::graphStatus MultiInMultiOutWithoutGroupList(gert::InferShapeContext* context) { + size_t idx = 0; + size_t idw = 0; + const gert::Shape* w0Shape = context->GetDynamicInputShape(INDEX_IN_WEIGHT, 0); + OPS_LOG_E_IF_NULL(context, w0Shape, return ge::GRAPH_FAILED); + while (true) { + const gert::Shape* xShape = context->GetDynamicInputShape(INDEX_IN_X, idx); + if (xShape == nullptr) { + break; + } + ++idx; + const gert::Shape* wShape = context->GetDynamicInputShape(INDEX_IN_WEIGHT, idw); + if (wShape) { + ++idw; + } else { + wShape = w0Shape; + } + OPS_CHECK(UpdateShapeYMultiDim(context, INDEX_OUT_Y + idx - 1, xShape, wShape) != GRAPH_SUCCESS, + OPS_LOG_E(context->GetNodeName(), "Failed to update shape of y."), return GRAPH_FAILED); + } + const gert::RuntimeAttrs* attrs = context->GetAttrs(); + OPS_LOG_E_IF_NULL(context, attrs, return ge::GRAPH_FAILED); + const int64_t* groupTypePtr = attrs->GetAttrPointer(INDEX_ATTR_GROUP_TYPE); + bool success = true; + if (w0Shape->GetDimNum() == 2) { // 2 two-dim weight tensor + if (groupTypePtr != nullptr && *groupTypePtr == 2) { + success = true; + } else { + success = idx == idw; + } + } else { + success = static_cast(idx) == w0Shape->GetDim(0); + } + OPS_CHECK(!success, + OPS_LOG_E(context->GetNodeName(), + "x tensorList's length[%zu] != weight tensor's first dim[%ld] and length[%zu]", + idx, w0Shape->GetDim(0), idw), + return GRAPH_FAILED); + return GRAPH_SUCCESS; +} + +static ge::graphStatus GetAttrs(gert::InferShapeContext* context, GMMAttrs& gmmAttrs) { + const gert::RuntimeAttrs* attrs = context->GetAttrs(); + OPS_LOG_E_IF_NULL(context, attrs, return ge::GRAPH_FAILED); + + const int64_t* splitItemPtr = attrs->GetAttrPointer(INDEX_ATTR_SPLIT_ITEM); + OPS_LOG_E_IF_NULL(context, splitItemPtr, return ge::GRAPH_FAILED); + gmmAttrs.splitItem = *splitItemPtr; + OPS_CHECK(CheckSplitItem(gmmAttrs.splitItem) != GRAPH_SUCCESS, OPS_LOG_E(context->GetNodeName(), + "Invalid splitItem, which can only be one of 0/1/2/3."), return GRAPH_FAILED); + OPS_LOG_I(context->GetNodeName(), "splitItem = %ld", gmmAttrs.splitItem); + + const int64_t* groupTypePtr = attrs->GetAttrPointer(INDEX_ATTR_GROUP_TYPE); + OPS_LOG_E_IF_NULL(context, groupTypePtr, return ge::GRAPH_FAILED); + gmmAttrs.groupType = *groupTypePtr; + OPS_CHECK(CheckGroupType(context, gmmAttrs.groupType) != GRAPH_SUCCESS, + OPS_LOG_E(context->GetNodeName(), "Invalid groupType."), return GRAPH_FAILED); + OPS_LOG_I(context->GetNodeName(), "groupType = %ld", gmmAttrs.groupType); + + const bool* transposeWPtr = attrs->GetAttrPointer(INDEX_ATTR_TRANSPOSE_W); + OPS_LOG_E_IF_NULL(context, transposeWPtr, return ge::GRAPH_FAILED); + gmmAttrs.transposeWeight = *transposeWPtr; + OPS_LOG_I(context->GetNodeName(), "isWeightTransposed = %d", gmmAttrs.transposeWeight); + + const bool* transposeXPtr = attrs->GetAttrPointer(INDEX_ATTR_TRANSPOSE_X); + OPS_LOG_E_IF_NULL(context, transposeXPtr, return ge::GRAPH_FAILED); + gmmAttrs.transposeX = *transposeXPtr; + 
OPS_LOG_I(context->GetNodeName(), "isXTransposed = %d", gmmAttrs.transposeX); + + const int64_t* activeType = attrs->GetInt(INDEX_ATTR_ACT_TYPE); + OPS_LOG_E_IF_NULL(context, activeType, return ge::GRAPH_FAILED); + OPS_CHECK(*activeType < 0 || *activeType >= static_cast(GMMActType::END_ACT_TYPE_ENUM), + OPS_LOG_E(context->GetNodeName(), "activeType must be no less than 0 and smaller than 6"), + return GRAPH_FAILED); + OPS_CHECK(*activeType == static_cast(GMMActType::GMM_ACT_TYPE_GELU_ERR_FUNC), + OPS_LOG_E(context->GetNodeName(), "Activation function not support GELU_ERR_FUNC now."), + return GRAPH_FAILED); + gmmAttrs.activeType = *activeType; + OPS_LOG_I(context->GetNodeName(), "activeType = %ld", gmmAttrs.activeType); + + return GRAPH_SUCCESS; +} + +static ge::graphStatus GetNumOfInputs(const gert::InferShapeContext* context, size_t& numX, + size_t& numWeight, int64_t& lenGroupList) { + ge::graphStatus res = GRAPH_SUCCESS; + const gert::Shape* shape = nullptr; + while (true) { + shape = context->GetDynamicInputShape(INDEX_IN_X, numX); + if (shape == nullptr) { // last shape + break; + } + for (size_t i = 0; i < shape->GetDimNum(); ++i) { + if (shape->GetDim(i) < 0) { // shape dim cannot be smaller than 0 + res = GRAPH_FAILED; + break; + } + } + ++numX; + } + OPS_LOG_I(context->GetNodeName(), "numX = %lu", numX); + + while (true) { + shape = context->GetDynamicInputShape(INDEX_IN_WEIGHT, numWeight); + if (shape == nullptr) { // last shape + break; + } + for (size_t i = 0; i < shape->GetDimNum(); ++i) { + if (shape->GetDim(i) < 0) { // shape dim cannot be smaller than 0 + res = GRAPH_FAILED; + break; + } + } + ++numWeight; + } + OPS_LOG_I(context->GetNodeName(), "numWeight = %lu", numWeight); + + const gert::Tensor* groupListTensor = context->GetOptionalInputTensor(INDEX_IN_GROUP_LIST); + if (groupListTensor != nullptr) { + lenGroupList = groupListTensor->GetShapeSize(); + if (lenGroupList < 0) { // lenGroupList cannot be smaller than 0 + res = GRAPH_FAILED; + } + } + OPS_LOG_I(context->GetNodeName(), "lenGroupList = %lu", lenGroupList); + + return res; +} + +static int64_t GetDim0(const gert::InferShapeContext* context, bool isXTransposed, size_t numX, size_t xDimM) { + int64_t dim0 = 0; + if (isXTransposed) { + const gert::Shape* x0Shape = context->GetDynamicInputShape(INDEX_IN_X, 0); + dim0 = (x0Shape == nullptr ? 0 : x0Shape->GetDim(xDimM)); + } else { + for (size_t idx = 0; idx < numX; ++idx) { + const gert::Shape* xShape = context->GetDynamicInputShape(INDEX_IN_X, idx); + dim0 += (xShape == nullptr ? 
+        }
+    }
+
+    return dim0;
+}
+
+static inline bool IsNonEmpty(const gert::Shape* shape) {
+    return (shape != nullptr && !(shape->GetDimNum() == 1 && shape->GetDim(0) == 0));
+}
+
+static ge::graphStatus IsGmmAntiQuantEmpty(gert::InferShapeContext* context) {
+    OPS_CHECK(!IsTensorListNullOrEmpty(context, INDEX_IN_ANTIQUANT_SCALE),
+              OPS_LOG_E(context->GetNodeName(), "antiquantScale must be null or empty here, but it is not!"),
+              return GRAPH_FAILED);
+    OPS_CHECK(!IsTensorListNullOrEmpty(context, INDEX_IN_ANTIQUANT_OFFSET),
+              OPS_LOG_E(context->GetNodeName(), "antiquantOffset must be null or empty here, but it is not!"),
+              return GRAPH_FAILED);
+    return GRAPH_SUCCESS;
+}
+
+static ge::graphStatus IsGmmQuantEmpty(gert::InferShapeContext* context) {
+    OPS_CHECK(!IsTensorListNullOrEmpty(context, INDEX_IN_SCALE),
+              OPS_LOG_E(context->GetNodeName(), "scale must be null or empty here, but it is not!"),
+              return GRAPH_FAILED);
+    OPS_CHECK(!IsTensorListNullOrEmpty(context, INDEX_IN_OFFSET),
+              OPS_LOG_E(context->GetNodeName(), "offset must be null or empty here, but it is not!"),
+              return GRAPH_FAILED);
+    const gert::Shape* pertokenQuantScale0Shape = context->GetOptionalInputShape(INDEX_IN_PERTOKEN_SCALE);
+    OPS_CHECK(IsNonEmpty(pertokenQuantScale0Shape),
+              OPS_LOG_E(context->GetNodeName(), "perTokenScale must be null or empty here, but it is not!"),
+              return GRAPH_FAILED);
+    return GRAPH_SUCCESS;
+}
+
+static ge::graphStatus CheckNonQuant(gert::InferShapeContext* context) {
+    OPS_CHECK(IsGmmQuantEmpty(context) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "Detected nonquant, but quant inputs are not empty!"),
+              return GRAPH_FAILED);
+    OPS_CHECK(IsGmmAntiQuantEmpty(context) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "Detected nonquant, but antiquant inputs are not empty!"),
+              return GRAPH_FAILED);
+    return GRAPH_SUCCESS;
+}
+
+static ge::graphStatus GetGroupSize(const gert::InferShapeContext* context, GMMParamsInfo& paramsInfo) {
+    size_t groupNum = 1;
+    size_t maxGroupNum = MAX_GROUP_LIST_SIZE_ARRAY;  // init max value
+    if (paramsInfo.numX > 1) {
+        groupNum = paramsInfo.numX;
+    } else if (paramsInfo.numWeight > 1) {
+        groupNum = paramsInfo.numWeight;
+    } else if (paramsInfo.numY > 1) {
+        groupNum = paramsInfo.numY;
+    } else if (paramsInfo.lenGroupList > 0) {
+        groupNum = paramsInfo.lenGroupList;
+        maxGroupNum = MAX_GROUP_LIST_SIZE_TENSOR;  // only this case allows MAX_GROUP_LIST_SIZE_TENSOR size
+    }
+    OPS_CHECK(groupNum > maxGroupNum,
+              OPS_LOG_E(context->GetNodeName(), "groupNum[%zu] is larger than %zu.", groupNum, maxGroupNum),
+              return GRAPH_FAILED);
+    paramsInfo.groupNum = groupNum;
+    return GRAPH_SUCCESS;
+}
+
+static graphStatus CheckDimNumAndPerGroupNum(const gert::InferShapeContext* context, bool isAntiquantInt4,
+        const std::tuple<size_t, size_t, int64_t>& dimData, const gert::Shape* tensorShape, const std::string& tensorType) {
+    size_t tensorDimNum = std::get<0>(dimData);
+    size_t expectedDimNum = std::get<1>(dimData);  // 1: the second element
+    int64_t weightKDimValue = std::get<2>(dimData);  // 2: the third element
+    if (isAntiquantInt4) {
+        if (tensorDimNum == expectedDimNum) {
+            int64_t perGroupNum = tensorShape->GetDim(tensorDimNum - 2);  // 2: the second-to-last index
+            OPS_CHECK(!(perGroupNum > 0 && weightKDimValue % perGroupNum == 0),
+                      OPS_LOG_E(context->GetNodeName(), "perGroupNum must be larger than 0 and must evenly "
+                                "divide K[%ld] in the A16W4 per-group case, but perGroupNum is %ld.",
+                                weightKDimValue, perGroupNum),
+                      return GRAPH_FAILED);
+        } else {
+            OPS_CHECK(tensorDimNum != expectedDimNum - 1,
perchannel case or " + "%zu for pergroup case in A16W4, but now is %zu.", + tensorType.c_str(), expectedDimNum - 1, expectedDimNum, tensorDimNum), + return GRAPH_FAILED); + } + } else { + OPS_CHECK(tensorDimNum != expectedDimNum - 1, + OPS_LOG_E(context->GetNodeName(), "%s Dim must be %zu, but now is %zu.", + tensorType.c_str(), expectedDimNum - 1, tensorDimNum), + return GRAPH_FAILED); + } + return GRAPH_SUCCESS; +} + +static ge::graphStatus CheckOptionalTensorList(gert::InferShapeContext* context, const std::string tensorType, + const GMMParamsInfo& paramsInfo, const GMMAttrs& gmmAttrs, size_t nodeIdx) { + // check bias,scale, antiquant scale or antiquant offset's size,tensor dimension and shape. + const size_t& groupNum = paramsInfo.groupNum; + size_t tensorSize = 0; + while (context->GetDynamicInputShape(nodeIdx, tensorSize) != nullptr) { + ++tensorSize; + } + uint64_t weightGroupedSize = static_cast(paramsInfo.numWeight); + const int64_t& groupType = gmmAttrs.groupType; + auto shape = context->GetDynamicInputShape(INDEX_IN_WEIGHT, 0); + OPS_LOG_E_IF_NULL(context, shape, return ge::GRAPH_FAILED); + uint64_t weightNDimIdx = shape->GetDimNum() - (gmmAttrs.transposeWeight ? 2 : 1); + auto tensor0Shape = context->GetDynamicInputShape(nodeIdx, 0); + // tensorList size should equals with weight's size + OPS_CHECK(tensorSize != weightGroupedSize, OPS_LOG_E(context->GetNodeName(), + "%s size[%lu] must be equal with weight size[%lu].", tensorType.c_str(), tensorSize, weightGroupedSize), return GRAPH_FAILED); + bool isSingleWeight = (weightGroupedSize == 1 && groupType != NO_SPLIT); + auto w0Desc = context->GetDynamicInputDesc(INDEX_IN_WEIGHT, 0); + OPS_LOG_E_IF_NULL(context, w0Desc, return ge::GRAPH_FAILED); + bool isAntiquantInt4 = (w0Desc->GetDataType() == DT_INT4 && tensorType.find("antiquant") != std::string::npos); + if (isSingleWeight) { // In this case, nodeIdx should have only single tensor, its dim should be 2. + OPS_CHECK(IsTensorListNullOrEmpty(context, nodeIdx), OPS_LOG_E(context->GetNodeName(), + "%s must not be nullptr or empty, but now is nullptr or empty.", tensorType.c_str()), return GRAPH_FAILED); + size_t tensorDimNum = tensor0Shape->GetDimNum(); + int64_t k = shape->GetDim(shape->GetDimNum() - (gmmAttrs.transposeWeight ? 1 : 2)); // 2: axis index + // 3: shape is (E,G,N),G is the perGroupNum + OPS_CHECK(CheckDimNumAndPerGroupNum(context, isAntiquantInt4, {tensorDimNum, 3, k}, tensor0Shape, tensorType) != GRAPH_SUCCESS, + OPS_LOG_E(context->GetNodeName(), "CheckDimNumAndPerGroupNum failed."), return GRAPH_FAILED); + OPS_CHECK(static_cast(tensor0Shape->GetDim(0)) != groupNum, OPS_LOG_E(context->GetNodeName(), "%s batch size[%ld] should be " + "euqal with groupList length[%lu].", tensorType.c_str(), tensor0Shape->GetDim(0), groupNum), return GRAPH_FAILED); + // tensor's N axis size should equal with weight's N axis. 
+        int64_t weightNDimValue = context->GetDynamicInputShape(INDEX_IN_WEIGHT, 0)->GetDim(weightNDimIdx);
+        int64_t tensorNDimValue = tensor0Shape->GetDim(tensorDimNum - 1);
+        OPS_CHECK(tensorNDimValue != weightNDimValue,
+                  OPS_LOG_E(context->GetNodeName(), "NDim[%ld] of %s should be equal to NDim[%ld] of weight.",
+                            tensorNDimValue, tensorType.c_str(), weightNDimValue),
+                  return GRAPH_FAILED);
+    } else {
+        for (uint64_t i = 0; i < groupNum; i++) {
+            auto tensorShape = context->GetDynamicInputShape(nodeIdx, i);
+            OPS_CHECK(tensorShape == nullptr,
+                      OPS_LOG_E(context->GetNodeName(), "%s[%lu] must not be nullptr, but now it is.",
+                                tensorType.c_str(), i),
+                      return GRAPH_FAILED);
+            // check each tensor's dim num (and the per-group size in the A16W4 case)
+            size_t tensorDimNum = tensorShape->GetDimNum();
+            auto wShape = context->GetDynamicInputShape(INDEX_IN_WEIGHT, i);
+            OPS_LOG_E_IF_NULL(context, wShape, return ge::GRAPH_FAILED);
+            int64_t k = wShape->GetDim(wShape->GetDimNum() - (gmmAttrs.transposeWeight ? 1 : 2));  // 2: axis index
+            // 2: shape is (G, N), G is the perGroupNum
+            OPS_CHECK(CheckDimNumAndPerGroupNum(context, isAntiquantInt4, {tensorDimNum, 2, k}, tensorShape, tensorType) != GRAPH_SUCCESS,
+                      OPS_LOG_E(context->GetNodeName(), "CheckDimNumAndPerGroupNum failed."), return GRAPH_FAILED);
+            int64_t weightNDimValue = wShape->GetDim(weightNDimIdx);
+            int64_t tensorNDimValue = tensorShape->GetDim(tensorDimNum - 1);
+            OPS_CHECK(tensorNDimValue != weightNDimValue,
+                      OPS_LOG_E(context->GetNodeName(), "NDim[%ld] of %s[%lu] should be equal to "
+                                "NDim[%ld] of weight[%lu].",
+                                tensorNDimValue, tensorType.c_str(), i, weightNDimValue, i),
+                      return GRAPH_FAILED);
+        }
+    }
+    return GRAPH_SUCCESS;
+}
+
+static ge::graphStatus CheckPerTokenScale(const gert::InferShapeContext* context, const GMMParamsInfo& paramsInfo) {
+    // check perTokenScale's size, tensor dimension, and shape
+    const size_t& xGroupedSize = paramsInfo.numX;
+    const size_t& weightGroupedSize = paramsInfo.numWeight;
+    const size_t& yGroupedSize = paramsInfo.numY;
+    uint64_t xMDimIdx = 0;
+    // check that perTokenScale's size equals x's
+    if (xGroupedSize == 1 && weightGroupedSize == 1 && yGroupedSize == 1) {
+        auto perTokenScale0Shape = context->GetOptionalInputShape(INDEX_IN_PERTOKEN_SCALE);
+        OPS_CHECK(perTokenScale0Shape == nullptr,
+                  OPS_LOG_E(context->GetNodeName(), "perTokenScaleOptional must not be nullptr, but now it is."),
+                  return GRAPH_FAILED);
+        // the tensor dimension of perTokenScale should be 1
+        size_t tensorDimNum = perTokenScale0Shape->GetDimNum();
+        OPS_CHECK(tensorDimNum != 1,
+                  OPS_LOG_E(context->GetNodeName(),
+                            "perTokenScaleOptional dim num must be 1 when x is a single tensor, but now is %zu.",
+                            tensorDimNum),
+                  return GRAPH_FAILED);
+        // check that perTokenScale's shape size equals the M axis size of x
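+        // Sketch of the expected layout in this branch: x is a single tensor (M, K), so
+        // perTokenScaleOptional must be 1-D with exactly M elements (one scale per token row).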
+        auto xShape = context->GetDynamicInputShape(INDEX_IN_X, 0);
+        OPS_LOG_E_IF_NULL(context, xShape, return ge::GRAPH_FAILED);
+        int64_t xMDimValue = xShape->GetDim(xMDimIdx);
+        int64_t tensorMDimValue = perTokenScale0Shape->GetDim(tensorDimNum - 1);
+        OPS_CHECK(tensorMDimValue != xMDimValue,
+                  OPS_LOG_E(context->GetNodeName(),
+                            "MDim[%ld] of perTokenScaleOptional should be equal to MDim[%ld] of x.",
+                            tensorMDimValue, xMDimValue),
+                  return GRAPH_FAILED);
+    } else {
+        OPS_LOG_E(context->GetNodeName(), "The per-token quant case is only supported when x, weight and y "
+                  "are all single tensors, but now x size is %zu, weight size is %zu, y size is %zu.",
+                  xGroupedSize, weightGroupedSize, yGroupedSize);
+        return GRAPH_FAILED;
+    }
+    return GRAPH_SUCCESS;
+}
+
+static ge::graphStatus CheckGroupedMatmulQuant(gert::InferShapeContext* context, const GMMAttrs& gmmAttrs,
+                                               const GMMParamsInfo& paramsInfo) {
+    OPS_CHECK(paramsInfo.platform == PlatformID::ASCEND310P,
+              OPS_LOG_E(context->GetNodeName(), "Quant cases are not supported on Ascend310P."),
+              return GRAPH_FAILED);
+    OPS_CHECK(gmmAttrs.groupType == SPLIT_K,
+              OPS_LOG_E(context->GetNodeName(), "Quant cases do not support splitting along the K axis."),
+              return GRAPH_FAILED);
+    OPS_CHECK(IsTensorListNullOrEmpty(context, INDEX_IN_SCALE),
+              OPS_LOG_E(context->GetNodeName(), "scale must not be nullptr in quant, but now it is."),
+              return GRAPH_FAILED);
+    OPS_CHECK(!IsTensorListNullOrEmpty(context, INDEX_IN_OFFSET),
+              OPS_LOG_E(context->GetNodeName(), "offset must be nullptr in quant, but now it is not."),
+              return GRAPH_FAILED);
+    OPS_CHECK(CheckOptionalTensorList(context, "scale", paramsInfo, gmmAttrs, INDEX_IN_SCALE) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "Invalid scale."),
+              return GRAPH_FAILED);
+    bool isPerTokenQuant = context->GetOptionalInputShape(INDEX_IN_PERTOKEN_SCALE) != nullptr;
+    if (isPerTokenQuant) {
+        OPS_CHECK(CheckPerTokenScale(context, paramsInfo) != GRAPH_SUCCESS,
+                  OPS_LOG_E(context->GetNodeName(), "Check perTokenScale failed!"),
+                  return GRAPH_FAILED);
+    }
+    OPS_CHECK(IsGmmAntiQuantEmpty(context) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "Detected quant, but antiquant inputs are not empty!"),
+              return GRAPH_FAILED);
+    return GRAPH_SUCCESS;
+}
+
+static int64_t GetPergroupSize(const GMMAttrs& gmmAttrs, bool isSingleWeight,
+                               const gert::Shape* wShape, const gert::Shape* shape) {
+    int64_t pergroupSize = 0;
+    size_t shapeDimNum = shape->GetDimNum();
+    if (isSingleWeight) {  // antiquant param shape: (E, N) or (E, G, N)
+        if (shapeDimNum > SEPARATED_WEIGHT_DIM) {
+            int64_t k = gmmAttrs.transposeWeight ? wShape->GetDim(2) : wShape->GetDim(1);  // 2: the K axis index
+            pergroupSize = k / shape->GetDim(shapeDimNum - 2);  // 2: the second-to-last index
+        }
+    } else {  // antiquant param shape: (N) or (G, N)
+        if (shapeDimNum > 1) {
+            int64_t k = gmmAttrs.transposeWeight ? wShape->GetDim(1) : wShape->GetDim(0);
+            pergroupSize = k / shape->GetDim(shapeDimNum - 2);  // 2: the second-to-last index
+        }
+    }
+    return pergroupSize;
+}
+
+static ge::graphStatus CheckGroupedMatmulAntiQuantForShape(gert::InferShapeContext* context, const GMMAttrs& gmmAttrs,
+                                                           const GMMParamsInfo& paramsInfo) {
+    OPS_CHECK(paramsInfo.platform == PlatformID::ASCEND310P,
+              OPS_LOG_E(context->GetNodeName(), "Antiquant cases are not supported on Ascend310P."),
+              return GRAPH_FAILED);
+    OPS_CHECK(gmmAttrs.groupType == SPLIT_K,
+              OPS_LOG_E(context->GetNodeName(), "Antiquant cases do not support splitting along the K axis."),
+              return GRAPH_FAILED);
+    OPS_CHECK(IsTensorListNullOrEmpty(context, INDEX_IN_ANTIQUANT_SCALE),
+              OPS_LOG_E(context->GetNodeName(), "antiquantScale must not be nullptr or empty in antiquant, but now it is."),
+              return GRAPH_FAILED);
+    OPS_CHECK(IsTensorListNullOrEmpty(context, INDEX_IN_ANTIQUANT_OFFSET),
+              OPS_LOG_E(context->GetNodeName(), "antiquantOffset must not be nullptr or empty in antiquant, but now it is."),
+              return GRAPH_FAILED);
+    // check the tensor shapes of antiquantScale and antiquantOffset
+    OPS_CHECK(CheckOptionalTensorList(context, "antiquantScale", paramsInfo, gmmAttrs, INDEX_IN_ANTIQUANT_SCALE) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "Invalid antiquantScale."),
+              return GRAPH_FAILED);
+    OPS_CHECK(CheckOptionalTensorList(context, "antiquantOffset", paramsInfo, gmmAttrs, INDEX_IN_ANTIQUANT_OFFSET) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "Invalid antiquantOffset."),
+              return GRAPH_FAILED);
+    // check perGroupSize
+    auto w0Desc = context->GetDynamicInputDesc(INDEX_IN_WEIGHT, 0);
+    OPS_LOG_E_IF_NULL(context, w0Desc, return ge::GRAPH_FAILED);
+    if (w0Desc->GetDataType() == DT_INT4) {
+        auto antiquantScale0Shape = context->GetDynamicInputShape(INDEX_IN_ANTIQUANT_SCALE, 0);
+        OPS_LOG_E_IF_NULL(context, antiquantScale0Shape, return ge::GRAPH_FAILED);
+        auto dimNum = antiquantScale0Shape->GetDimNum();
+        bool isSingleWeight = (paramsInfo.numWeight == 1 && gmmAttrs.groupType != NO_SPLIT);
+        int64_t pergroupSize = GetPergroupSize(gmmAttrs, isSingleWeight,
+                                               context->GetDynamicInputShape(INDEX_IN_WEIGHT, 0), antiquantScale0Shape);
+        OPS_CHECK(gmmAttrs.transposeWeight && pergroupSize % 2 != 0,  // 2: even factor
+                  OPS_LOG_E(context->GetNodeName(), "pergroupSize should be even when weight is transposed "
+                            "in the A16W4 per-group case, but now is %ld.", pergroupSize),
+                  return GRAPH_FAILED);
+        for (size_t i = 0; ; ++i) {
+            auto antiquantScaleShape = context->GetDynamicInputShape(INDEX_IN_ANTIQUANT_SCALE, i);
+            auto antiquantOffsetShape = context->GetDynamicInputShape(INDEX_IN_ANTIQUANT_OFFSET, i);
+            if (antiquantScaleShape == nullptr || antiquantOffsetShape == nullptr) {
+                break;
+            }
+            size_t antiquantScaleDimNum = antiquantScaleShape->GetDimNum();
+            size_t antiquantOffsetDimNum = antiquantOffsetShape->GetDimNum();
+            OPS_CHECK(antiquantScaleDimNum != dimNum || antiquantOffsetDimNum != dimNum,
+                      OPS_LOG_E(context->GetNodeName(), "antiquantScale[%zu] dim num[%zu] or antiquantOffset[%zu] "
+                                "dim num[%zu] is not equal to %zu.",
+                                i, antiquantScaleDimNum, i, antiquantOffsetDimNum, dimNum),
+                      return GRAPH_FAILED);
+            auto wShape = context->GetDynamicInputShape(INDEX_IN_WEIGHT, i);
+            int64_t pergroupSizeOfScale = GetPergroupSize(gmmAttrs, isSingleWeight, wShape, antiquantScaleShape);
+            int64_t pergroupSizeOfOffset = GetPergroupSize(gmmAttrs, isSingleWeight, wShape, antiquantOffsetShape);
+            OPS_CHECK(pergroupSizeOfScale != pergroupSize || pergroupSizeOfOffset != pergroupSize,
+                      OPS_LOG_E(context->GetNodeName(), "antiquantScale[%zu]'s pergroup size[%ld] or "
+                                "antiquantOffset[%zu]'s pergroup size[%ld] is not the required value[%ld].",
+                                i, pergroupSizeOfScale, i, pergroupSizeOfOffset, pergroupSize),
+                      return GRAPH_FAILED);
value[%ld]", i, pergroupSizeOfScale, i, pergroupSizeOfOffset, pergroupSize), + return GRAPH_FAILED); + } + } + OPS_CHECK(IsGmmQuantEmpty(context) != GRAPH_SUCCESS, OPS_LOG_E(context->GetNodeName(), + "Detected antiquant, but quant inputs is not empty!"), return GRAPH_FAILED); + return GRAPH_SUCCESS; +} + +static ge::graphStatus CheckFunctionParamsForShape(gert::InferShapeContext* context, const GMMAttrs& gmmAttrs, + GMMParamsInfo& paramsInfo) { + if (context == nullptr) { + return GRAPH_FAILED; + } + fe::PlatformInfo platformInfo; + fe::OptionalInfo optionalInfo; + auto ret = fe::PlatformInfoManager::Instance().GetPlatformInfoWithOutSocVersion(platformInfo, optionalInfo); + if (ret != ge::GRAPH_SUCCESS) { + paramsInfo.platform = PlatformID::UNKNOWN; + OPS_LOG_W(context->GetNodeName(), "Cannot get platform info!"); + return GRAPH_SUCCESS; + } else { + paramsInfo.platform = optionalInfo.soc_version.find("310P") != std::string::npos ? + PlatformID::ASCEND310P : PlatformID::ASCEND910B; + } + auto x0Desc = context->GetDynamicInputDesc(INDEX_IN_X, 0); + OPS_LOG_E_IF_NULL(context, x0Desc, return ge::GRAPH_FAILED); + DataType xDtype = x0Desc->GetDataType(); + auto w0Desc = context->GetDynamicInputDesc(INDEX_IN_WEIGHT, 0); + OPS_LOG_E_IF_NULL(context, w0Desc, return ge::GRAPH_FAILED); + DataType weightDtype = w0Desc->GetDataType(); + if ((xDtype == DataType::DT_BF16 || xDtype == DataType::DT_FLOAT16 || + xDtype == DataType::DT_FLOAT) && xDtype == weightDtype) { + // nonquant + return CheckNonQuant(context); + } + if (xDtype == DataType::DT_INT8 && weightDtype == DataType::DT_INT8) { + // quant + return CheckGroupedMatmulQuant(context, gmmAttrs, paramsInfo); + } + if ((xDtype == DataType::DT_BF16 || xDtype == DataType::DT_FLOAT16) && + (weightDtype == DataType::DT_INT8 || weightDtype == DataType::DT_INT4)) { + // antiquant + return CheckGroupedMatmulAntiQuantForShape(context, gmmAttrs, paramsInfo); + } + return GRAPH_FAILED; +} + +static ge::graphStatus CheckDimNumAndGroupListNoSplitAndFormat(const gert::InferShapeContext* context, + uint64_t tensorListLength, const size_t numWeight) { + // when groupList is not empty, check its size equal with the length of x. 
+    auto groupTensorOptionalShape = context->GetOptionalInputShape(INDEX_IN_GROUP_LIST);
+    if (groupTensorOptionalShape != nullptr) {
+        OPS_CHECK(groupTensorOptionalShape->GetDim(0) != static_cast<int64_t>(tensorListLength),
+                  OPS_LOG_E(context->GetNodeName(), "Size of groupList(tensor) %ld should be equal to size of x %lu.",
+                            groupTensorOptionalShape->GetDim(0), tensorListLength),
+                  return GRAPH_FAILED);
+    }
+    auto wShape = context->GetDynamicInputShape(INDEX_IN_WEIGHT, 0);
+    OPS_LOG_E_IF_NULL(context, wShape, return ge::GRAPH_FAILED);
+    // check dimension
+    for (size_t i = 0; i < tensorListLength; ++i) {
+        auto xShape = context->GetDynamicInputShape(INDEX_IN_X, i);
+        OPS_CHECK(xShape == nullptr,
+                  OPS_LOG_E(context->GetNodeName(), "x[%lu] is null, which is not supported.", i),
+                  return GRAPH_FAILED);
+        if (numWeight > 1) {
+            wShape = context->GetDynamicInputShape(INDEX_IN_WEIGHT, i);
+            OPS_LOG_E_IF_NULL(context, wShape, return ge::GRAPH_FAILED);
+            size_t weightDimNum = wShape->GetDimNum();
+            OPS_CHECK(weightDimNum != SEPARATED_WEIGHT_DIM,
+                      OPS_LOG_E(context->GetNodeName(),
+                                "weight[%lu] dimNum is %lu, but only 2 is supported when weight is separated.",
+                                i, weightDimNum),
+                      return GRAPH_FAILED);
+        }
+        size_t xDimNum = xShape->GetDimNum();
+        OPS_CHECK(xDimNum > MAX_FM_DIM || xDimNum < MIN_FM_DIM,
+                  OPS_LOG_E(context->GetNodeName(), "x[%lu] dimNum is %lu, but only 2-6 is supported.", i, xDimNum),
+                  return GRAPH_FAILED);
+    }
+    return GRAPH_SUCCESS;
+}
+
+static ge::graphStatus TensorType2NodeId(const std::vector<std::string>& tensorType, std::vector<int64_t>& nodeIdx) {
+    if (nodeIdx.size() > tensorType.size()) {
+        return GRAPH_FAILED;
+    }
+    for (size_t i = 0; i < nodeIdx.size(); ++i) {
+        if (tensorType[i] == "x") {
+            nodeIdx[i] = INDEX_IN_X;
+        } else if (tensorType[i] == "weight") {
+            nodeIdx[i] = INDEX_IN_WEIGHT;
+        } else if (tensorType[i] == "y") {
+            nodeIdx[i] = INDEX_OUT_Y;
+        } else {
+            return GRAPH_FAILED;
+        }
+    }
+    return GRAPH_SUCCESS;
+}
+
+static ge::graphStatus CheckDimNum(gert::InferShapeContext* context, uint64_t tensorListLength,
+                                   const size_t expectedDimNum, const std::string tensorType) {
+    int64_t nodeIdx = 0;
+    if (tensorType == "x") {
+        nodeIdx = INDEX_IN_X;
+    } else if (tensorType == "weight") {
+        nodeIdx = INDEX_IN_WEIGHT;
+    } else if (tensorType == "y") {
+        nodeIdx = INDEX_OUT_Y;
+    } else {
+        return GRAPH_FAILED;
+    }
+    const gert::Shape* shape;
+    for (size_t i = 0; i < tensorListLength; ++i) {
+        if (tensorType == "y") {
+            shape = context->GetOutputShape(nodeIdx + i);
+        } else {
+            shape = context->GetDynamicInputShape(nodeIdx, i);
+        }
+        OPS_CHECK(shape == nullptr,
+                  OPS_LOG_E(context->GetNodeName(), "%s[%lu] is null, which is not supported.", tensorType.c_str(), i),
+                  return GRAPH_FAILED);
+        size_t dimNum = shape->GetDimNum();
+        OPS_CHECK(dimNum != expectedDimNum,
+                  OPS_LOG_E(context->GetNodeName(), "%s[%lu] dim num should be %lu in this case, but now is %lu.",
+                            tensorType.c_str(), i, expectedDimNum, dimNum),
+                  return GRAPH_FAILED);
+    }
+    return GRAPH_SUCCESS;
+}
+
+static ge::graphStatus CheckWeightShapeInnerAxisEven(const gert::InferShapeContext* context, const size_t weightSize,
+                                                     const int64_t innerAxisDimId) {
+    auto w0Desc = context->GetDynamicInputDesc(INDEX_IN_WEIGHT, 0);
+    OPS_LOG_E_IF_NULL(context, w0Desc, return ge::GRAPH_FAILED);
+    DataType wDtype = w0Desc->GetDataType();
+    if (wDtype == DataType::DT_INT4) {
+        for (size_t i = 0; i < weightSize; ++i) {
+            auto wShape = context->GetDynamicInputShape(INDEX_IN_WEIGHT, i);
+            OPS_LOG_E_IF_NULL(context, wShape, return ge::GRAPH_FAILED);
+            int64_t n = wShape->GetDim(innerAxisDimId);
+            OPS_CHECK(n % 2 != 0,  // 2: even factor
+                      OPS_LOG_E(context->GetNodeName(), "w[%zu] dim %ld value %ld should be even when weight is int4 dtype.",
+                                i, innerAxisDimId, n),
+                      return GRAPH_FAILED);
+        }
+    }
+    return GRAPH_SUCCESS;
+}
+
+static ge::graphStatus IsxSizeEqualWithWeightKAxis(const gert::InferShapeContext* context,
+                                                   const GMMParamsInfo& paramsInfo, const gert::Shape* wShape,
+                                                   size_t& wKDimIdx, size_t& wNDimIdx) {
+    if (paramsInfo.numWeight == 1 && wShape->GetDimNum() > 2) {  // 2: separated tensor's dim
+        wKDimIdx += 1;
+        wNDimIdx += 1;
+        OPS_CHECK(paramsInfo.numX != static_cast<size_t>(wShape->GetDim(0)),
+                  OPS_LOG_E(context->GetNodeName(), "When x and y are separated and weight is not separated, "
+                            "size of x %zu should equal the first dim of the weight tensor %ld.",
+                            paramsInfo.numX, wShape->GetDim(0)),
+                  return GRAPH_FAILED);
+    }
+    return GRAPH_SUCCESS;
+}
+
+static ge::graphStatus CheckCaseNoSplit(gert::InferShapeContext* context, bool transposeWeight,
+                                        const GMMParamsInfo& paramsInfo) {
+    const size_t& xSize = paramsInfo.numX;
+    const size_t& weightSize = paramsInfo.numWeight;
+    // check group num
+    OPS_CHECK(xSize != paramsInfo.numY,
+              OPS_LOG_E(context->GetNodeName(), "When y is separated, size of x %lu should equal size of y %lu.",
+                        xSize, paramsInfo.numY),
+              return GRAPH_FAILED);
+    OPS_CHECK(weightSize != 1 && xSize != weightSize,
+              OPS_LOG_E(context->GetNodeName(), "When x and weight are separated, size of x %lu should equal "
+                        "size of weight %lu.", xSize, weightSize),
+              return GRAPH_FAILED);
+    // check dimension
+    OPS_CHECK(CheckDimNumAndGroupListNoSplitAndFormat(context, xSize, weightSize) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "Dim num or format of tensor in tensor lists or groupList is invalid."),
+              return GRAPH_FAILED);
+    // check shape
+    auto wShape = context->GetDynamicInputShape(INDEX_IN_WEIGHT, 0);
+    OPS_LOG_E_IF_NULL(context, wShape, return ge::GRAPH_FAILED);
+    size_t wKDimIdx = transposeWeight ? 1 : 0;
+    size_t wNDimIdx = transposeWeight ? 0 : 1;
+    OPS_CHECK(IsxSizeEqualWithWeightKAxis(context, paramsInfo, wShape, wKDimIdx, wNDimIdx) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "IsxSizeEqualWithWeightKAxis failed."), return GRAPH_FAILED);
+    int64_t weightKDimValue = wShape->GetDim(wKDimIdx);
+    int64_t weightNDimValue = wShape->GetDim(wNDimIdx);
+    auto w0Desc = context->GetDynamicInputDesc(INDEX_IN_WEIGHT, 0);
+    OPS_LOG_E_IF_NULL(context, w0Desc, return ge::GRAPH_FAILED);
+    DataType wDtype = w0Desc->GetDataType();
+    // 2: an even factor
+    OPS_CHECK(wDtype == DataType::DT_INT4 && weightNDimValue % 2 != 0,
+              OPS_LOG_E(context->GetNodeName(), "w[0] dim %lu value %ld should be even when weight is int4 dtype.",
+                        wNDimIdx, weightNDimValue),
+              return GRAPH_FAILED);
+    for (size_t i = 0; i < xSize; i++) {
+        auto xShape = context->GetDynamicInputShape(INDEX_IN_X, i);
+        OPS_LOG_E_IF_NULL(context, xShape, return ge::GRAPH_FAILED);
+        size_t xDimNum = xShape->GetDimNum();
+        // check the inner axis of x, which should not be larger than 65535
+        int64_t xKDimValue = xShape->GetDim(xDimNum - 1);  // x is never transposed here
+        OPS_CHECK(xKDimValue > MAX_INNER_AXIS,
+                  OPS_LOG_E(context->GetNodeName(), "x[%lu] dim %lu value %ld should be less than or equal to %ld.",
+                            i, xDimNum - 1, xKDimValue, MAX_INNER_AXIS),
+                  return GRAPH_FAILED);
+        if (weightSize > 1) {
+            wShape = context->GetDynamicInputShape(INDEX_IN_WEIGHT, i);
+            weightKDimValue = wShape->GetDim(wKDimIdx);
+            weightNDimValue = wShape->GetDim(wNDimIdx);
+            // 2: an even factor
+            OPS_CHECK(i > 0 && wDtype == DataType::DT_INT4 && weightNDimValue % 2 != 0,
+                      OPS_LOG_E(context->GetNodeName(), "w[%lu] dim %lu value %ld should be even when weight is int4 dtype.",
+                                i, wNDimIdx, weightNDimValue),
+                      return GRAPH_FAILED);
+        }
+        OPS_CHECK(xKDimValue != weightKDimValue,
+                  OPS_LOG_E(context->GetNodeName(), "x[%lu] dim %lu value %ld should be equal to weight[%lu]'s K dim value %ld.",
+                            i, xDimNum - 1, xKDimValue, i, weightKDimValue),
+                  return GRAPH_FAILED);
+        // if weight is not transposed, check the N axis; otherwise the inner axis is K, which has been checked already
+        OPS_CHECK(!transposeWeight && weightNDimValue > MAX_INNER_AXIS,
+                  OPS_LOG_E(context->GetNodeName(), "w[%zu] dim %zu value %ld should be less than or equal to %ld.",
+                            i, wNDimIdx, weightNDimValue, MAX_INNER_AXIS),
+                  return GRAPH_FAILED);
+    }
+    return GRAPH_SUCCESS;
+}
+
+static ge::graphStatus CheckInnerAxisOfTensorList(const gert::InferShapeContext* context, size_t nodeId,
+                                                  int64_t innerAxisDimId, size_t checkNum) {
+    if (innerAxisDimId < 0) {  // -1: no inner axis needs checking in this case
+        return GRAPH_SUCCESS;
+    }
+    for (size_t i = 0; i < checkNum; i++) {
+        auto shape = context->GetDynamicInputShape(nodeId, i);
+        OPS_LOG_E_IF_NULL(context, shape, return ge::GRAPH_FAILED);
+        int64_t innerAxisValue = shape->GetDim(innerAxisDimId);
+        OPS_CHECK(innerAxisValue > MAX_INNER_AXIS,
+                  OPS_LOG_E(context->GetNodeName(), "Dim %ld value of the %zu-th shape should be less than or "
+                            "equal to %ld, but now is %ld.", innerAxisDimId, i, MAX_INNER_AXIS, innerAxisValue),
+                  return GRAPH_FAILED);
+    }
+    return GRAPH_SUCCESS;
+}
+
+static ge::graphStatus CheckShapeSameLengthTensorList(gert::InferShapeContext* context,
+                                                      const std::vector<size_t>& dimIds, const int64_t innerAxisDimId,
+                                                      const std::vector<std::string> tensorType, uint64_t groupNum) {
+    std::vector<int64_t> nodeIdx = {0, 0};
+    OPS_CHECK(TensorType2NodeId(tensorType, nodeIdx) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "TensorType2NodeId failed."),
+              return GRAPH_FAILED);
+    // check that the two tensorLists have the same size and that their tensors have consistent dimensions
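+    // Example of how this helper is driven (see SplitMSingleXSingleWeightSingleY below):
+    // CheckShapeSameLengthTensorList(ctx, {1, kAxisOfWeight}, 1, {"x", "weight", "true"}, numX)
+    // walks x[i] and weight[i] in lockstep and requires x[i].dim[1] (K) == weight[i]'s K dim,
+    // additionally bounding the inner axis by MAX_INNER_AXIS when tensorType[2] == "true".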
+    const gert::Shape* shape;
+    for (uint64_t i = 0; i < groupNum; i++) {
+        shape = context->GetDynamicInputShape(nodeIdx[0], i);
+        OPS_LOG_E_IF_NULL(context, shape, return ge::GRAPH_FAILED);
+        int64_t dimValue1 = shape->GetDim(dimIds[0]);
+        // tensorType[2] indicates whether to check tensorList0's inner axis (innerAxisDimId)
+        if (tensorType[2] == "true" && innerAxisDimId > -1) {
+            auto shape0 = context->GetDynamicInputShape(nodeIdx[0], i);
+            OPS_LOG_E_IF_NULL(context, shape0, return ge::GRAPH_FAILED);
+            int64_t innerAxisValue = shape0->GetDim(innerAxisDimId);
+            OPS_CHECK(innerAxisValue > MAX_INNER_AXIS,
+                      OPS_LOG_E(context->GetNodeName(), "Dim %ld value of %s[%lu] should be less than or equal "
+                                "to %ld, but now is %ld.",
+                                innerAxisDimId, tensorType[0].c_str(), i, MAX_INNER_AXIS, innerAxisValue),
+                      return GRAPH_FAILED);
+        }
+        if (tensorType[1] == "y") {
+            shape = context->GetOutputShape(nodeIdx[1] + i);
+        } else {
+            shape = context->GetDynamicInputShape(nodeIdx[1], i);
+        }
+        OPS_LOG_E_IF_NULL(context, shape, return ge::GRAPH_FAILED);
+        int64_t dimValue2 = shape->GetDim(dimIds[1]);
+        OPS_CHECK(dimValue1 != dimValue2,
+                  OPS_LOG_E(context->GetNodeName(),
+                            "Dim %lu value of %s[%lu] should be equal to dim %lu value of %s[%lu], "
+                            "but now they are %ld and %ld respectively.",
+                            dimIds[0], tensorType[0].c_str(), i, dimIds[1], tensorType[1].c_str(), i,
+                            dimValue1, dimValue2),
+                  return GRAPH_FAILED);
+    }
+    return GRAPH_SUCCESS;
+}
+
+static ge::graphStatus CheckShapeDiffLengthTensorList(gert::InferShapeContext* context,
+                                                      const std::vector<size_t>& dimIds,
+                                                      const int64_t innerAxisdimId,
+                                                      const std::vector<std::string> tensorType,
+                                                      uint64_t groupNum) {
+    std::vector<int64_t> nodeIdx = {0, 0};
+    OPS_CHECK(TensorType2NodeId(tensorType, nodeIdx) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "TensorType2NodeId failed."),
+              return GRAPH_FAILED);
+    // check that the selected dimension of each tensor in the multi-tensor tensorList equals the
+    // selected dimension of the tensor in the single-tensor tensorList;
+    // the selected axis is not the split axis
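+    // Example of how this helper is driven (see SplitMSingleXSeparatedWeightSingleY below):
+    // CheckShapeDiffLengthTensorList(ctx, {kAxisOfWeight, 1}, 1, {"weight", "x", "true"}, numWeight)
+    // compares every weight[i]'s K dim against the K dim (dim 1) of the single x[0].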
+    const gert::Shape* singleTensor0;
+    if (tensorType[1] == "y") {
+        singleTensor0 = context->GetOutputShape(nodeIdx[1]);
+    } else {
+        singleTensor0 = context->GetDynamicInputShape(nodeIdx[1], 0);
+    }
+    OPS_LOG_E_IF_NULL(context, singleTensor0, return ge::GRAPH_FAILED);
+    int64_t dimValueSingle = singleTensor0->GetDim(dimIds[1]);
+    // tensorType[2] indicates whether to check the single tensorList's inner axis (innerAxisdimId)
+    if (tensorType[2] == "true" && innerAxisdimId > -1) {
+        int64_t dimValue = singleTensor0->GetDim(innerAxisdimId);
+        OPS_CHECK(dimValue > MAX_INNER_AXIS,
+                  OPS_LOG_E(context->GetNodeName(),
+                            "Dim %ld value of %s[0] should be less than or equal to %ld, but now is %ld.",
+                            innerAxisdimId, tensorType[1].c_str(), MAX_INNER_AXIS, dimValue),
+                  return GRAPH_FAILED);
+    }
+    const gert::Shape* longTensor;
+    for (uint64_t i = 0; i < groupNum; i++) {
+        if (tensorType[0] == "y") {
+            longTensor = context->GetOutputShape(nodeIdx[0] + i);
+        } else {
+            longTensor = context->GetDynamicInputShape(nodeIdx[0], i);
+        }
+        OPS_LOG_E_IF_NULL(context, longTensor, return ge::GRAPH_FAILED);
+        int64_t dimValueLong = longTensor->GetDim(dimIds[0]);
+        OPS_CHECK(dimValueLong != dimValueSingle,
+                  OPS_LOG_E(context->GetNodeName(),
+                            "Dim %lu value of %s[%lu] %ld should be equal to dim %lu value of %s[0] %ld.",
+                            dimIds[0], tensorType[0].c_str(), i, dimValueLong,
+                            dimIds[1], tensorType[1].c_str(), dimValueSingle),
+                  return GRAPH_FAILED);
+    }
+    return GRAPH_SUCCESS;
+}
+
+static ge::graphStatus CheckGroupListCommonTensor(const gert::InferShapeContext* context,
+                                                  const bool isRequiredGroupList, const int64_t groupNum) {
+    auto groupTensorOptionalShape = context->GetOptionalInputShape(INDEX_IN_GROUP_LIST);
+    bool isNull = groupTensorOptionalShape == nullptr;
+    OPS_CHECK(isNull && isRequiredGroupList,
+              OPS_LOG_E(context->GetNodeName(), "groupListOptional(tensor) is required in this case, but got nullptr."),
+              return GRAPH_FAILED);
+    if (isNull) {
+        return GRAPH_SUCCESS;
+    }
+    int64_t groupListSize = groupTensorOptionalShape->GetDim(0);
+    OPS_CHECK(groupListSize > MAX_GROUP_LIST_SIZE_TENSOR,
+              OPS_LOG_E(context->GetNodeName(),
+                        "When groupList type is tensor, size of groupList %ld should be less than or equal to %ld.",
+                        groupListSize, MAX_GROUP_LIST_SIZE_TENSOR),
+              return GRAPH_FAILED);
+    OPS_CHECK(!((groupListSize == groupNum && groupNum > 1) || groupNum == 1),
+              OPS_LOG_E(context->GetNodeName(),
+                        "When groupList is not null, size of groupList(tensor) %ld should be equal to groupNum %ld.",
+                        groupListSize, groupNum),
+              return GRAPH_FAILED);
+    auto groupListDesc = context->GetOptionalInputDesc(INDEX_IN_GROUP_LIST);
+    OPS_LOG_E_IF_NULL(context, groupListDesc, return ge::GRAPH_FAILED);
+    OPS_CHECK(groupListDesc->GetDataType() != DataType::DT_INT64,
+              OPS_LOG_E(context->GetNodeName(), "Invalid dtype: only int64 is supported for groupList, but now is %s.",
+                        ToString(groupListDesc->GetDataType()).data()),
+              return GRAPH_FAILED);
+    return GRAPH_SUCCESS;
+}
+
+static ge::graphStatus SplitMSingleXSingleWeightSingleY(gert::InferShapeContext* context, bool transposeWeight,
+                                                        const GMMParamsInfo& paramsInfo) {
+    std::vector<std::string> tensorXAndWeight{"x", "weight", "true"};
+    // check dimension
+    OPS_CHECK(CheckDimNum(context, paramsInfo.numX, MIN_FM_DIM, "x") != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "Dim num or format of tensor in tensor list x is invalid."),
+              return GRAPH_FAILED);
+    OPS_CHECK(CheckDimNum(context, paramsInfo.numWeight, SPLIT_M_SINGLE_WEIGHT_DIM, "weight") != GRAPH_SUCCESS,
OPS_LOG_E(context->GetNodeName(), "Dim num or format of tensor in tensor list weight is invalid."), + return GRAPH_FAILED); + // check shape, x(m,k), weight(b,k,n), y(m,n) + int64_t innerAxisDimId = 1; // x always is not transposed, check K axis + size_t kAxisOfWeight = transposeWeight ? 2 : 1; // if weight is transposed, 2 is the k axis idx of the weight, otherwise is 1 + OPS_CHECK(CheckShapeSameLengthTensorList(context, {1, kAxisOfWeight}, innerAxisDimId, tenorXAndWeight, paramsInfo.numX) != GRAPH_SUCCESS, + OPS_LOG_E(context->GetNodeName(), "k dim value of x and weight is not matched."), + return GRAPH_FAILED); + innerAxisDimId = !transposeWeight ? 2 : -1; // If w is not transposed, check N(2) asix; otherwise, check k axis, which can be skiped + OPS_CHECK(CheckInnerAxisOfTensorList(context, INDEX_IN_WEIGHT, innerAxisDimId, paramsInfo.numWeight) != GRAPH_SUCCESS, + OPS_LOG_E(context->GetNodeName(), "inner axis size of weight is larger than %ld!", MAX_INNER_AXIS), + return GRAPH_FAILED); + OPS_CHECK(CheckWeightShapeInnerAxisEven(context, paramsInfo.numWeight, 2) != GRAPH_SUCCESS, + OPS_LOG_E(context->GetNodeName(), "weight's N axis size should be even when it is int4 dtype."), + return GRAPH_FAILED); + // check groupList + OPS_CHECK(CheckGroupListCommonTensor(context, true, context->GetDynamicInputShape(INDEX_IN_WEIGHT, 0)->GetDim(0)) != GRAPH_SUCCESS, + OPS_LOG_E(context->GetNodeName(), "Invalid groupList."), + return GRAPH_FAILED); + return GRAPH_SUCCESS; +} + +static ge::graphStatus SplitMSingleXSeparatedWeightSingleY(gert::InferShapeContext* context, bool transposeWeight, + const GMMParamsInfo& paramsInfo) { + std::vector tenorWeightAndX{"weight", "x", "true"}; + // check dimension + OPS_CHECK(CheckDimNum(context, paramsInfo.numX, MIN_FM_DIM, "x") != GRAPH_SUCCESS, + OPS_LOG_E(context->GetNodeName(), "Dim num or format of tensor in tensor list x is invalid."), + return GRAPH_FAILED); + OPS_CHECK(CheckDimNum(context, paramsInfo.numWeight, SEPARATED_WEIGHT_DIM, "weight") != GRAPH_SUCCESS, + OPS_LOG_E(context->GetNodeName(), "Dim num or format of tensor in tensor list weight is invalid."), + return GRAPH_FAILED); + // check shape, x(m,k), weight(k,n), y(m,n) + int64_t innerAxisDimId = 1; // x always is not transposed, check K axis + size_t kAxisOfWeight = transposeWeight ? 1 : 0; + OPS_CHECK(CheckShapeDiffLengthTensorList(context, {kAxisOfWeight, 1}, innerAxisDimId, tenorWeightAndX, paramsInfo.numWeight) != GRAPH_SUCCESS, + OPS_LOG_E(context->GetNodeName(), "k dim value of x and weight is not matched."), + return GRAPH_FAILED); + innerAxisDimId = !transposeWeight ? 
+    OPS_CHECK(CheckInnerAxisOfTensorList(context, INDEX_IN_WEIGHT, innerAxisDimId, 1) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "Inner axis size of weight is larger than %ld!", MAX_INNER_AXIS),
+              return GRAPH_FAILED);
+    OPS_CHECK(CheckWeightShapeInnerAxisEven(context, paramsInfo.numWeight, 1) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "Weight's N axis size should be even when it is int4 dtype."),
+              return GRAPH_FAILED);
+    // check groupList
+    OPS_CHECK(CheckGroupListCommonTensor(context, true, paramsInfo.numWeight) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "Invalid groupList."),
+              return GRAPH_FAILED);
+    return GRAPH_SUCCESS;
+}
+
+static ge::graphStatus SplitMSeparatedXSeparatedWeightSingleY(gert::InferShapeContext* context,
+                                                              bool transposeWeight, const GMMParamsInfo& paramsInfo) {
+    const size_t& xSize = paramsInfo.numX;
+    const size_t& weightSize = paramsInfo.numWeight;
+    std::vector<std::string> tensorWeightAndX{"weight", "x", "true"};
+    OPS_CHECK(xSize != weightSize,
+              OPS_LOG_E(context->GetNodeName(),
+                        "When x and weight are separated, size of x %lu should equal size of weight %lu.",
+                        xSize, weightSize),
+              return GRAPH_FAILED);
+    // check dimension
+    OPS_CHECK(CheckDimNum(context, xSize, MIN_FM_DIM, "x") != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "Dim num or format of tensor in tensor list x is invalid."),
+              return GRAPH_FAILED);
+    OPS_CHECK(CheckDimNum(context, weightSize, SEPARATED_WEIGHT_DIM, "weight") != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "Dim num or format of tensor in tensor list weight is invalid."),
+              return GRAPH_FAILED);
+    // check shape: x(m,k), weight(k,n), y(m,n)
+    int64_t innerAxisDimId = 1;  // inner axis of the weight's original shape
+    size_t kAxisOfWeight = transposeWeight ? 1 : 0;
+    OPS_CHECK(CheckShapeSameLengthTensorList(context, {kAxisOfWeight, 1}, innerAxisDimId, tensorWeightAndX, weightSize) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "K dim value of x and weight does not match."),
+              return GRAPH_FAILED);
+    innerAxisDimId = !transposeWeight ? 1 : -1;  // if w is not transposed, its N axis has been checked above; still check x's inner axis (K; x is never transposed here)
+    OPS_CHECK(CheckInnerAxisOfTensorList(context, INDEX_IN_X, innerAxisDimId, 1) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "Inner axis size of x is larger than %ld!", MAX_INNER_AXIS),
+              return GRAPH_FAILED);
+    OPS_CHECK(CheckWeightShapeInnerAxisEven(context, weightSize, 1) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "Weight's N axis size should be even when it is int4 dtype."),
+              return GRAPH_FAILED);
+    // check groupList
+    OPS_CHECK(CheckGroupListCommonTensor(context, false, xSize) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "Invalid groupList."),
+              return GRAPH_FAILED);
+    return GRAPH_SUCCESS;
+}
+
+static ge::graphStatus CheckCaseSplitM(gert::InferShapeContext* context, bool transposeWeight,
+                                       const GMMParamsInfo& paramsInfo) {
+    const size_t& xSize = paramsInfo.numX;
+    const size_t& weightSize = paramsInfo.numWeight;
+    const size_t& ySize = paramsInfo.numY;
+    if (xSize == 1 && weightSize == 1 && ySize == 1) {
+        OPS_CHECK(SplitMSingleXSingleWeightSingleY(context, transposeWeight, paramsInfo) != GRAPH_SUCCESS,
+                  OPS_LOG_E(context->GetNodeName(), "Split-M, single x, single weight, single y case failed."),
+                  return GRAPH_FAILED);
+        return GRAPH_SUCCESS;
+    }
+    if (xSize == 1 && weightSize > 1 && ySize == 1) {
+        OPS_CHECK(weightSize != paramsInfo.groupNum,
+                  OPS_LOG_E(context->GetNodeName(), "Weight size[%zu] does not equal groupNum[%zu].",
+                            weightSize, paramsInfo.groupNum),
+                  return GRAPH_FAILED);
+        OPS_CHECK(SplitMSingleXSeparatedWeightSingleY(context, transposeWeight, paramsInfo) != GRAPH_SUCCESS,
+                  OPS_LOG_E(context->GetNodeName(), "Split-M, single x, separated weight, single y case failed."),
+                  return GRAPH_FAILED);
+        return GRAPH_SUCCESS;
+    }
+    if (xSize == 1 && weightSize > 1 && ySize > 1) {
+        const gert::Tensor* groupListTensor = context->GetOptionalInputTensor(INDEX_IN_GROUP_LIST);
+        OPS_CHECK(groupListTensor == nullptr || groupListTensor->GetData<int64_t>() == nullptr,
" + "When grouplist is an invalid tensor, split m, single x, separated weight, separated y cases do not support."), + return GRAPH_FAILED); + return GRAPH_SUCCESS; // skip the check + } + if (xSize > 1 && weightSize > 1 && ySize == 1) { + OPS_CHECK(weightSize != paramsInfo.groupNum, OPS_LOG_E(context->GetNodeName(), + "weight Size [%zu] does not equal with groupNum %zu", weightSize, paramsInfo.groupNum), + return GRAPH_FAILED); + OPS_CHECK(SplitMSeparatedXSeparatedWeightSingleY(context, transposeWeight, paramsInfo) != GRAPH_SUCCESS, + OPS_LOG_E(context->GetNodeName(), "Split m, separated x, separated weight, single y case failed."), + return GRAPH_FAILED); + return GRAPH_SUCCESS; + } + OPS_LOG_E(context->GetNodeName(), "When groupType is 0, current case with x %zu, weight %zu, y %zu is not supported.", + xSize, weightSize, ySize); + return GRAPH_FAILED; +} + +static ge::graphStatus CheckCaseSplitK(gert::InferShapeContext* context, bool transposeX, bool transposeWeight, + const GMMParamsInfo& paramsInfo) { + std::vector tenorXAndWeight{"x", "weight", "true"}; + const size_t& xSize = paramsInfo.numX; + const size_t& weightSize = paramsInfo.numWeight; + const size_t& ySize = paramsInfo.numY; + if (xSize == 1 && ySize == 1 && weightSize == 1) { + OPS_CHECK(!transposeX, + OPS_LOG_E(context->GetNodeName(), + "When groupType is 2 and x is not separated, tensor in x should be transposed."), + return GRAPH_FAILED); + // check dimension + OPS_CHECK(CheckDimNum(context, xSize, MIN_FM_DIM, "x") != GRAPH_SUCCESS, + OPS_LOG_E(context->GetNodeName(), "Dim num or format of tensor in tensor list x is invalid."), + return GRAPH_FAILED); + OPS_CHECK(CheckDimNum(context, weightSize, SPLIT_K_SINGLE_WEIGHT_DIM, "weight") != GRAPH_SUCCESS, + OPS_LOG_E(context->GetNodeName(), "Dim num or format of tensor in tensor list weight is invalid."), + return GRAPH_FAILED); + // check shape, x(m,k), weight(k,n), y(b,m,n) + int64_t innerAxisDimId = 1; // x always is transposed, and the inner axis is always the last axis, M axis. + size_t kAxisOfWeight = transposeWeight ? 1 : 0; + OPS_CHECK(CheckShapeSameLengthTensorList(context, {0, kAxisOfWeight}, innerAxisDimId, tenorXAndWeight, xSize) != GRAPH_SUCCESS, + OPS_LOG_E(context->GetNodeName(), "k dim value of x and weight is not matched."), + return GRAPH_FAILED); + innerAxisDimId = 1; // w always is not transposed, and the inner axis is always the last axis, N axis. 
+        OPS_CHECK(CheckInnerAxisOfTensorList(context, INDEX_IN_WEIGHT, innerAxisDimId, weightSize) != GRAPH_SUCCESS,
+                  OPS_LOG_E(context->GetNodeName(), "Inner axis size of weight is larger than %ld!", MAX_INNER_AXIS),
+                  return GRAPH_FAILED);
+        // check groupList
+        OPS_CHECK(CheckGroupListCommonTensor(context, true, 1) != GRAPH_SUCCESS,
+                  OPS_LOG_E(context->GetNodeName(), "Invalid groupList."),
+                  return GRAPH_FAILED);
+        return GRAPH_SUCCESS;
+    }
+    OPS_LOG_E(context->GetNodeName(),
+              "When groupType is 2, only the case with unseparated x, weight and y is supported, "
+              "but now x size is %lu, weight size is %lu, y size is %lu.", xSize, weightSize, ySize);
+    return GRAPH_FAILED;
+}
+
+static ge::graphStatus CheckParamDifferentGroupType(gert::InferShapeContext* context, const GMMAttrs& gmmAttrs,
+                                                    const GMMParamsInfo& paramsInfo) {
+    OPS_CHECK(paramsInfo.platform == PlatformID::UNKNOWN,
+              OPS_LOG_W(context->GetNodeName(), "Cannot get platform info!"), return GRAPH_SUCCESS);
+    const int64_t& groupType = gmmAttrs.groupType;
+    const bool& transposeX = gmmAttrs.transposeX;
+    const bool& transposeWeight = gmmAttrs.transposeWeight;
+    OPS_CHECK(transposeX && transposeWeight,
+              OPS_LOG_E(context->GetNodeName(), "x and weight cannot be transposed at the same time."),
+              return GRAPH_FAILED);
+    auto groupTensorOptionalShape = context->GetOptionalInputShape(INDEX_IN_GROUP_LIST);
+    OPS_CHECK(groupTensorOptionalShape != nullptr && (groupTensorOptionalShape->GetDimNum() > 1 ||
+              groupTensorOptionalShape->GetDim(0) <= 1),
+              OPS_LOG_E(context->GetNodeName(), "When groupList is a tensor, its dim num must be 1 and its "
+                        "element count must be larger than 1, but now they are %zu and %ld, respectively.",
+                        groupTensorOptionalShape->GetDimNum(), groupTensorOptionalShape->GetDim(0)),
+              return GRAPH_FAILED);
+    OPS_CHECK(paramsInfo.platform == PlatformID::ASCEND310P && !(groupType == SPLIT_M && paramsInfo.numX == 1 &&
+              paramsInfo.numWeight == 1 && paramsInfo.numY == 1),
+              OPS_LOG_E(context->GetNodeName(),
+                        "On Ascend310P, only the split-M, single x, single weight, single y case is supported."),
+              return GRAPH_FAILED);
+
+    if (groupType == NO_SPLIT) {
+        OPS_CHECK(transposeX,
+                  OPS_LOG_E(context->GetNodeName(), "When x, weight and y are all separated, x cannot be transposed."),
+                  return GRAPH_FAILED);
+        OPS_CHECK(CheckCaseNoSplit(context, transposeWeight, paramsInfo) != GRAPH_SUCCESS,
+                  OPS_LOG_E(context->GetNodeName(), "Invalid inputs!"), return GRAPH_FAILED);
+    } else if (groupType == SPLIT_M) {
+        OPS_CHECK(transposeX,
+                  OPS_LOG_E(context->GetNodeName(), "When groupType is 0, x cannot be transposed."),
+                  return GRAPH_FAILED);
+        OPS_CHECK(CheckCaseSplitM(context, transposeWeight, paramsInfo) != GRAPH_SUCCESS,
+                  OPS_LOG_E(context->GetNodeName(), "Invalid inputs!"), return GRAPH_FAILED);
+    } else if (groupType == SPLIT_K) {
+        OPS_CHECK(!IsTensorListNullOrEmpty(context, INDEX_IN_BIAS),
+                  OPS_LOG_E(context->GetNodeName(), "When groupType is 2, bias must be empty."), return GRAPH_FAILED);
+        OPS_CHECK(CheckCaseSplitK(context, transposeX, transposeWeight, paramsInfo) != GRAPH_SUCCESS,
+                  OPS_LOG_E(context->GetNodeName(), "Invalid inputs!"), return GRAPH_FAILED);
+    }
+    if (!IsTensorListNullOrEmpty(context, INDEX_IN_BIAS)) {
+        OPS_CHECK(CheckOptionalTensorList(context, "bias", paramsInfo, gmmAttrs, INDEX_IN_BIAS) != GRAPH_SUCCESS,
+                  OPS_LOG_E(context->GetNodeName(), "Invalid bias!"), return GRAPH_FAILED);
+    }
+    return GRAPH_SUCCESS;
+}
+
+static ge::graphStatus XNotSingleYSeparated(gert::InferShapeContext* context,
+                                            size_t weightDimN, bool isXTransposed, size_t xDimM) {
+    const gert::Tensor* groupListTensor = context->GetOptionalInputTensor(INDEX_IN_GROUP_LIST);
+    if (groupListTensor != nullptr) {
+        OPS_CHECK(UpdateMultipleShapeY(context, groupListTensor, weightDimN, isXTransposed, xDimM) != GRAPH_SUCCESS,
+                  OPS_LOG_E(context->GetNodeName(), "Failed to update shape of y."), return GRAPH_FAILED);
+    } else {
+        OPS_CHECK(MultiInMultiOutWithoutGroupList(context) != GRAPH_SUCCESS,
+                  OPS_LOG_E(context->GetNodeName(), "Failed to process the multi-in-multi-out case without groupList."),
+                  return GRAPH_FAILED);
+    }
+    return GRAPH_SUCCESS;
+}
+
+static ge::graphStatus XSingleYSeparated(gert::InferShapeContext* context,
+                                         size_t weightDimN, bool isXTransposed, size_t xDimM) {
+    const gert::Tensor* groupListTensor = context->GetOptionalInputTensor(INDEX_IN_GROUP_LIST);
+    OPS_CHECK(groupListTensor == nullptr,
+              OPS_LOG_E(context->GetNodeName(), "groupList is required when x is a single tensor while y is not."),
+              return GRAPH_FAILED);
+    OPS_CHECK(UpdateMultipleShapeY(context, groupListTensor, weightDimN, isXTransposed, xDimM) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "Failed to update shape of y."),
+              return GRAPH_FAILED);
+    return GRAPH_SUCCESS;
+}
+
+static ge::graphStatus InferShape4GroupedMatmul(gert::InferShapeContext* context) {
+    GMMAttrs gmmAttrs{X_Y_SEPARATED, NO_SPLIT, false, false, 0};
+    OPS_CHECK(GetAttrs(context, gmmAttrs) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "Failed to get attrs."), return GRAPH_FAILED);
+
+    size_t numX = 0;  // init numX
+    size_t numWeight = 0;  // init numWeight
+    int64_t lenGroupList = 0;  // init lenGroupList
+    size_t numY = context->GetComputeNodeOutputNum();
+    if (GetNumOfInputs(context, numX, numWeight, lenGroupList) == GRAPH_SUCCESS) {  // checks input shape values inside
+        GMMParamsInfo paramsInfo{numX, numWeight, numY, lenGroupList, 0, 0, 0, 0, 0, PlatformID::UNKNOWN};
+        OPS_CHECK(GetGroupSize(context, paramsInfo) != GRAPH_SUCCESS,
+                  OPS_LOG_E(context->GetNodeName(), "Check groupNum failed."), return GRAPH_FAILED);
+        OPS_CHECK(CheckFunctionParamsForShape(context, gmmAttrs, paramsInfo) != GRAPH_SUCCESS,
+                  OPS_LOG_E(context->GetNodeName(), "CheckFunctionParamsForShape failed."), return GRAPH_FAILED);
+        OPS_CHECK(CheckParamDifferentGroupType(context, gmmAttrs, paramsInfo) != GRAPH_SUCCESS,
+                  OPS_LOG_E(context->GetNodeName(), "CheckParamDifferentGroupType failed."), return GRAPH_FAILED);
+    } else {
+        OPS_CHECK(CheckDimNum(context, numX, MIN_FM_DIM, "x") != GRAPH_SUCCESS,  // check dim number of tensors
+                  OPS_LOG_E(context->GetNodeName(), "Dim num of tensor in tensorList x is invalid."),
+                  return GRAPH_FAILED);
+    }
+
+    const gert::Shape* x0Shape = context->GetDynamicInputShape(INDEX_IN_X, 0);
+    OPS_LOG_E_IF_NULL(context, x0Shape, return ge::GRAPH_FAILED);
+    size_t xDimNum = x0Shape->GetDimNum();
+    const gert::Shape* w0Shape = context->GetDynamicInputShape(INDEX_IN_WEIGHT, 0);
+    OPS_LOG_E_IF_NULL(context, w0Shape, return ge::GRAPH_FAILED);
+    size_t weightDimNum = w0Shape->GetDimNum();
+    bool isSingleX = numX == 1 && gmmAttrs.groupType != NO_SPLIT;
+    bool isSingleY = numY == 1 && gmmAttrs.groupType != NO_SPLIT;
+    size_t xDimM = gmmAttrs.transposeX ? xDimNum - 1 : xDimNum - 2;
+    size_t weightDimN = gmmAttrs.transposeWeight ? weightDimNum - 2 : weightDimNum - 1;
+    // set y shape
+    if (isSingleX && !isSingleY) {
+        OPS_CHECK(XSingleYSeparated(context, weightDimN, gmmAttrs.transposeX, xDimM) != GRAPH_SUCCESS,
+                  OPS_LOG_E(context->GetNodeName(), "Failed to update shape of y."), return GRAPH_FAILED);
+    } else if (isSingleX && isSingleY) {
+        OPS_CHECK(gmmAttrs.groupType != SPLIT_M && gmmAttrs.groupType != SPLIT_K,
+                  OPS_LOG_E(context->GetNodeName(),
+                            "When x is a single tensor, input tensors can only be split along the M or K axis."),
+                  return GRAPH_FAILED);
+        std::vector<int64_t> yDims = {x0Shape->GetDim(xDimM), w0Shape->GetDim(weightDimN)};
+        if (gmmAttrs.groupType == SPLIT_K) {
+            yDims.insert(yDims.begin(), numWeight == 1 ? lenGroupList : static_cast<int64_t>(numWeight));
+        }
+        OPS_CHECK(UpdateShapeY(context, INDEX_OUT_Y, yDims) != GRAPH_SUCCESS,
+                  OPS_LOG_E(context->GetNodeName(), "Failed to update y shape."), return GRAPH_FAILED);
+    } else if (!isSingleX && !isSingleY) {
+        OPS_CHECK(XNotSingleYSeparated(context, weightDimN, gmmAttrs.transposeX, xDimM) != GRAPH_SUCCESS,
+                  OPS_LOG_E(context->GetNodeName(), "Failed to update shape of y."), return GRAPH_FAILED);
+    } else if (!isSingleX && isSingleY) {
+        std::vector<int64_t> yDims = {GetDim0(context, gmmAttrs.transposeX, numX, xDimM), w0Shape->GetDim(weightDimN)};
+        OPS_CHECK(UpdateShapeY(context, INDEX_OUT_Y, yDims) != GRAPH_SUCCESS,
+                  OPS_LOG_E(context->GetNodeName(), "Failed to update shape of y."), return GRAPH_FAILED);
+    }
+    return GRAPH_SUCCESS;
+}
+
+// =========================================================================================
+// =========================================================================================
+static graphStatus CheckTensorListDataType(const gert::InferDataTypeContext* context, uint32_t index,
+                                           const DataType dtype) {
+    size_t inIdx = 0;
+    while (true) {
+        auto iDtype = context->GetDynamicInputDataType(index, inIdx);
+        if (iDtype == DT_UNDEFINED) {
+            break;
+        }
+        OPS_CHECK(iDtype != dtype,
+                  OPS_LOG_E(context->GetNodeName(), "data type of tensors in a tensorList should all be the same!"),
+                  return GRAPH_FAILED);
+        ++inIdx;
+    }
+    return GRAPH_SUCCESS;
+}
+
+static graphStatus CheckMatmulDataType(gert::InferDataTypeContext* context, const DataType xDtype,
+                                       const DataType weightDtype, const DataType biasDtype) {
+    OPS_CHECK(CheckTensorListDataType(context, INDEX_IN_X, xDtype) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "x dtype does not match required dtype[%s].",
+                        ToString(xDtype).data()),
+              return GRAPH_FAILED);
+    OPS_CHECK(CheckTensorListDataType(context, INDEX_IN_WEIGHT, weightDtype) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "weight dtype does not match required dtype[%s].",
+                        ToString(weightDtype).data()),
+              return GRAPH_FAILED);
+    OPS_CHECK(CheckTensorListDataType(context, INDEX_IN_BIAS, biasDtype) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "bias dtype does not match required dtype[%s].",
+                        ToString(biasDtype).data()),
+              return GRAPH_FAILED);
+    return GRAPH_SUCCESS;
+}
+
+static graphStatus CheckFunctionQuantParams(gert::InferDataTypeContext* context) {
+    OPS_CHECK(CheckTensorListDataType(context, INDEX_IN_X, DataType::DT_INT8) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "x dtype does not match required dtype[INT8]."),
+              return GRAPH_FAILED);
+    OPS_CHECK(CheckTensorListDataType(context, INDEX_IN_WEIGHT, DataType::DT_INT8) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "weight dtype does not match required dtype[INT8]."),
+              return GRAPH_FAILED);
+    OPS_CHECK(CheckTensorListDataType(context, INDEX_IN_BIAS, DataType::DT_INT32) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "bias dtype does not match required dtype[INT32]."),
+              return GRAPH_FAILED);
+    auto scale0Dtype = context->GetDynamicInputDataType(INDEX_IN_SCALE, 0);
+    // At this point we cannot tell whether this is a per-token quant case, so the scale/offset
+    // dtype check is deferred to the InferShape stage.
+    OPS_CHECK(CheckTensorListDataType(context, INDEX_IN_SCALE, scale0Dtype) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "dtypes of scales in the tensorList should all be the same."),
+              return GRAPH_FAILED);
+    auto offset0Dtype = context->GetDynamicInputDataType(INDEX_IN_OFFSET, 0);
+    OPS_CHECK(CheckTensorListDataType(context, INDEX_IN_OFFSET, offset0Dtype) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "dtypes of offsets in the tensorList should all be the same."),
+              return GRAPH_FAILED);
+    return GRAPH_SUCCESS;
+}
+
+static graphStatus CheckGroupedMatmulAntiQuantForDtype(gert::InferDataTypeContext* context) {
+    auto xDtype = context->GetDynamicInputDataType(INDEX_IN_X, 0);
+    OPS_CHECK(CheckTensorListDataType(context, INDEX_IN_ANTIQUANT_SCALE, xDtype) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "antiquantScale dtype does not match x dtype[%s].",
+                        ToString(xDtype).data()),
+              return GRAPH_FAILED);
+    OPS_CHECK(CheckTensorListDataType(context, INDEX_IN_ANTIQUANT_OFFSET, xDtype) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "antiquantOffset dtype does not match x dtype[%s].",
+                        ToString(xDtype).data()),
+              return GRAPH_FAILED);
+    return GRAPH_SUCCESS;
+}
+
+static graphStatus CheckFunctionParamsForDtype(gert::InferDataTypeContext* context) {
+    fe::PlatformInfo platformInfo;
+    fe::OptionalInfo optionalInfo;
+    graphStatus ret = fe::PlatformInfoManager::Instance().GetPlatformInfoWithOutSocVersion(platformInfo, optionalInfo);
+    PlatformID platform = PlatformID::UNKNOWN;
+    if (ret != ge::GRAPH_SUCCESS) {
+        OPS_LOG_W(context->GetNodeName(), "Cannot get platform info.");
+        return GRAPH_SUCCESS;
+    } else {
+        platform = optionalInfo.soc_version.find("310P") == std::string::npos ?
+                   PlatformID::ASCEND910B : PlatformID::ASCEND310P;
+    }
+    DataType xDtype = context->GetDynamicInputDataType(INDEX_IN_X, 0);
+    DataType weightDtype = context->GetDynamicInputDataType(INDEX_IN_WEIGHT, 0);
+    if (platform == PlatformID::ASCEND310P) {
+        bool isAllInputFP16 = xDtype == DataType::DT_FLOAT16 && weightDtype == DataType::DT_FLOAT16;
+        OPS_CHECK(!isAllInputFP16,
+                  OPS_LOG_E(context->GetNodeName(), "Only float16 is supported on Ascend310P platforms."),
+                  return GRAPH_FAILED);
+        auto biasDtype = context->GetOptionalInputDataType(INDEX_IN_BIAS);
+        OPS_CHECK(biasDtype != ge::DT_UNDEFINED && biasDtype != DataType::DT_FLOAT16,
+                  OPS_LOG_E(context->GetNodeName(), "Only float16 bias is supported on Ascend310P platforms."),
+                  return GRAPH_FAILED);
+    }
+    if ((xDtype == DataType::DT_BF16 || xDtype == DataType::DT_FLOAT16 || xDtype == DataType::DT_FLOAT) &&
+        xDtype == weightDtype) {  // nonquant
+        DataType biasDtype = xDtype == DataType::DT_BF16 ? DataType::DT_FLOAT : xDtype;
+
+static graphStatus CheckFunctionParamsForDtype(gert::InferDataTypeContext* context) {
+    fe::PlatformInfo platformInfo;
+    fe::OptionalInfo optionalInfo;
+    graphStatus ret = fe::PlatformInfoManager::Instance().GetPlatformInfoWithOutSocVersion(platformInfo, optionalInfo);
+    PlatformID platform = PlatformID::UNKNOWN;
+    if (ret != ge::GRAPH_SUCCESS) {
+        OPS_LOG_W(context->GetNodeName(), "Cannot get platform info.");
+        return GRAPH_SUCCESS;
+    } else {
+        platform = optionalInfo.soc_version.find("310P") == std::string::npos ?
+            PlatformID::ASCEND910B : PlatformID::ASCEND310P;
+    }
+    DataType xDtype = context->GetDynamicInputDataType(INDEX_IN_X, 0);
+    DataType weightDtype = context->GetDynamicInputDataType(INDEX_IN_WEIGHT, 0);
+    if (platform == PlatformID::ASCEND310P) {
+        bool isAllInputFP16 = xDtype == DataType::DT_FLOAT16 && weightDtype == DataType::DT_FLOAT16;
+        OPS_CHECK(!isAllInputFP16, OPS_LOG_E(context->GetNodeName(),
+            "Only float16 is supported on Ascend310P platforms."), return GRAPH_FAILED);
+        auto biasDtype = context->GetOptionalInputDataType(INDEX_IN_BIAS);
+        OPS_CHECK(biasDtype != ge::DT_UNDEFINED && biasDtype != DataType::DT_FLOAT16, OPS_LOG_E(context->GetNodeName(),
+            "Only float16 bias is supported on Ascend310P platforms."), return GRAPH_FAILED);
+    }
+    if ((xDtype == DataType::DT_BF16 || xDtype == DataType::DT_FLOAT16 || xDtype == DataType::DT_FLOAT) &&
+        xDtype == weightDtype) {  // nonquant
+        DataType biasDtype = xDtype == DataType::DT_BF16 ? DataType::DT_FLOAT : xDtype;
+        OPS_CHECK(CheckMatmulDataType(context, xDtype, weightDtype, biasDtype) != GRAPH_SUCCESS,
+                  OPS_LOG_E(context->GetNodeName(), "case with x dtype %s and weight dtype %s is not supported!",
+                            ToString(xDtype).data(), ToString(weightDtype).data()),
+                  return GRAPH_FAILED);
+        return GRAPH_SUCCESS;
+    }
+    if (xDtype == DataType::DT_INT8 && weightDtype == DataType::DT_INT8) {
+        // quant
+        OPS_CHECK(CheckFunctionQuantParams(context) != GRAPH_SUCCESS,
+                  OPS_LOG_E(context->GetNodeName(), "CheckFunctionQuantParams failed."),
+                  return GRAPH_FAILED);
+        return GRAPH_SUCCESS;
+    }
+    if ((xDtype == DataType::DT_BF16 || xDtype == DataType::DT_FLOAT16) &&
+        (weightDtype == DataType::DT_INT8 || weightDtype == DataType::DT_INT4)) {
+        // antiquant
+        DataType biasDtype = xDtype == DataType::DT_BF16 ? DataType::DT_FLOAT : DataType::DT_FLOAT16;
+        OPS_CHECK(CheckMatmulDataType(context, xDtype, weightDtype, biasDtype) != GRAPH_SUCCESS,
+                  OPS_LOG_E(context->GetNodeName(), "case with x dtype %s and weight dtype %s is not supported!",
+                            ToString(xDtype).data(), ToString(weightDtype).data()),
+                  return GRAPH_FAILED);
+        return CheckGroupedMatmulAntiQuantForDtype(context);
+    }
+    OPS_LOG_E(context->GetNodeName(), "GMM: there is no matching xDtype and weightDtype pattern. "
+              "case with x dtype %s and weight dtype %s is not supported.",
+              ToString(xDtype).data(), ToString(weightDtype).data());
+    return GRAPH_FAILED;
+}
+
+static graphStatus CheckQuantParamsDtype(const gert::InferDataTypeContext* context, const int64_t outputDtype,
+                                         const DataType yDtype) {
+    size_t i = 0;
+    auto scale0Dtype = context->GetDynamicInputDataType(INDEX_IN_SCALE, 0);
+    OPS_CHECK(scale0Dtype == ge::DT_UNDEFINED, OPS_LOG_E(context->GetNodeName(), "scale is undefined!"),
+              return GRAPH_FAILED);
+    auto perTokenScale0Dtype = context->GetDynamicInputDataType(INDEX_IN_PERTOKEN_SCALE, 0);
+    bool isPerTokenQuant = perTokenScale0Dtype != ge::DT_UNDEFINED;
+    if (isPerTokenQuant) {
+        bool isOutputBF16 = scale0Dtype == DataType::DT_BF16 && outputDtype == 1;
+        bool isOutputFloat16 = scale0Dtype == DataType::DT_FLOAT && outputDtype == 0;
+        OPS_CHECK(!isOutputBF16 && !isOutputFloat16,
+                  OPS_LOG_E(context->GetNodeName(), "per-token quant case only supports scale with data type bfloat16 "
+                            "when output data type is bfloat16, or scale with data type float32 when output is "
+                            "float16, but now scale[%zu] has data type %s and output has data type %s!",
+                            i, ToString(scale0Dtype).data(), ToString(yDtype).data()),
+                  return GRAPH_FAILED);
+    } else {
+        bool isOutputInt8 = scale0Dtype == DataType::DT_UINT64 && outputDtype == -1;
+        bool isOutputBF16 = scale0Dtype == DataType::DT_BF16 && outputDtype == 1;
+        bool isOutputFP16 = scale0Dtype == DataType::DT_FLOAT && outputDtype == 0;
+        OPS_CHECK(!isOutputInt8 && !isOutputBF16 && !isOutputFP16,
+                  OPS_LOG_E(context->GetNodeName(), "per-channel quant case only supports scale with data type uint64 "
+                            "when output is int8, or data type bfloat16 when output is bfloat16, or data type float32 "
+                            "when output is float16, but scale[%zu] has data type %s and output has data type %s!",
+                            i, ToString(scale0Dtype).data(), ToString(yDtype).data()),
+                  return GRAPH_FAILED);
+    }
+    if (isPerTokenQuant) {
+        OPS_CHECK(perTokenScale0Dtype != DataType::DT_FLOAT,
+                  OPS_LOG_E(context->GetNodeName(), "per-token quant case only supports perTokenScale with dtype "
+                            "float32, but perTokenScale[%zu] has data type %s!",
+                            i, ToString(perTokenScale0Dtype).data()),
+                  return GRAPH_FAILED);
+    }
+    return GRAPH_SUCCESS;
+}
+
+static graphStatus InferDataType4GroupedMatmul(gert::InferDataTypeContext* context) {
+    OPS_CHECK(CheckFunctionParamsForDtype(context) != GRAPH_SUCCESS,
+              OPS_LOG_E(context->GetNodeName(), "CheckFunctionParamsForDtype failed!"),
+              return GRAPH_FAILED);
+
+    auto x0Dtype = context->GetDynamicInputDataType(INDEX_IN_X, 0);
+    auto weight0Dtype = context->GetDynamicInputDataType(INDEX_IN_WEIGHT, 0);
+    size_t numY = context->GetComputeNodeOutputNum();
+    auto attrs = context->GetAttrs();
+    OPS_LOG_E_IF_NULL(context, attrs, return ge::GRAPH_FAILED);
+    bool isQuantCase = x0Dtype == ge::DT_INT8 && weight0Dtype == ge::DT_INT8;
+    const int64_t* outputDtype = attrs->GetInt(INDEX_ATTR_OUTPUT_DTYPE);
+    DataType yDtype = x0Dtype;
+    if (isQuantCase && outputDtype != nullptr) {
+        auto it = OUTPUT_DTYPE_MAP.find(*outputDtype);
+        OPS_CHECK(it == OUTPUT_DTYPE_MAP.end(),
+                  OPS_LOG_E(context->GetNodeName(),
+                            "value of attr dtype only supports -1/0/1, but now is %ld.", *outputDtype),
+                  return GRAPH_FAILED);
+        yDtype = it->second;
+        OPS_CHECK(CheckQuantParamsDtype(context, *outputDtype, yDtype) != GRAPH_SUCCESS,
+                  OPS_LOG_E(context->GetNodeName(), "Check quant params data type failed!"),
+                  return GRAPH_FAILED);
+    }
+    for (size_t k = 0; k < numY; k++) {
+        context->SetOutputDataType(INDEX_OUT_Y + k, yDtype);
+    }
+    return GRAPH_SUCCESS;
+}
+
+IMPL_OP_INFERSHAPE(GroupedMatmul)
+    .InferShape(InferShape4GroupedMatmul)
+    .InferDataType(InferDataType4GroupedMatmul);
+}  // namespace ops
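A note on the quant output-dtype mapping used by InferDataType4GroupedMatmul above: the attr `dtype` selects the y dtype through OUTPUT_DTYPE_MAP. The following standalone sketch restates that mapping; its contents are inferred from CheckQuantParamsDtype in this patch, and the names below are illustrative only:

    #include <cstdint>
    #include <cstdio>
    #include <map>
    #include <string>

    int main() {
        // attr dtype -1 -> int8, 0 -> float16, 1 -> bfloat16 (quant case only).
        const std::map<int64_t, std::string> outputDtypeMap = {
            {-1, "int8"}, {0, "float16"}, {1, "bfloat16"}};
        for (const auto &kv : outputDtypeMap) {
            std::printf("attr dtype %ld -> y dtype %s\n", kv.first, kv.second.c_str());
        }
        return 0;
    }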
diff --git a/src/transformer/grouped_matmul/ophost/grouped_matmul_tiling.cpp b/src/transformer/grouped_matmul/ophost/grouped_matmul_tiling.cpp
new file mode 100644
index 00000000..c9d22e93
--- /dev/null
+++ b/src/transformer/grouped_matmul/ophost/grouped_matmul_tiling.cpp
@@ -0,0 +1,999 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file grouped_matmul_tiling.cpp
+ * \brief
+ */
+#include "grouped_matmul_tiling.h"
+
+#include <climits>
+#include <algorithm>
+#include "register/op_impl_registry.h"
+#include "log/ops_log.h"
+#include "error/ops_error.h"
+#include "tiling/tiling_base.h"
+using namespace ge;
+using namespace AscendC;
+
+template <typename T1, typename T2>
+static T1 CeilDiv(T1 a, T2 b)
+{
+    if (b == 0) {
+        return 0;
+    }
+    return (a + b - 1) / b;
+}
+
+namespace optiling {
+constexpr uint32_t X_INDEX = 0;
+constexpr uint32_t WEIGHT_INDEX = 1;
+constexpr uint32_t BIAS_INDEX = 2;
+constexpr uint32_t ANTIQUANT_SCALE_INDEX = 5;
+constexpr uint32_t GROUPLIST_INDEX = 7;
+constexpr uint32_t PER_TOKEN_SCALE_INDEX = 8;
+constexpr uint32_t SCALE_INDEX = 3;
+constexpr uint32_t Y_INDEX = 0;
+constexpr int64_t BEST_L1_PARTA = 256 * 1024;
+constexpr int64_t BEST_L1_PARTB = 128 * 1024;
+constexpr int32_t BEST_BASEN = 256;
+constexpr int32_t BEST_BASEN_MSD = 512;
+constexpr int32_t BEST_UB_BASEK = 256;
+constexpr int32_t BEST_UB_BASEN = 512;
+constexpr int32_t MAX_BASEM = 256;
+constexpr uint32_t A16W8_MSD_STEP = 2;
+constexpr uint32_t A16W8_MSD_KN_BASE_BLOCK = 128;
+constexpr uint32_t A16W8_MSD_AVERAGE_TOKEN_NUM = 64;
+constexpr uint32_t A16W8_MSD_MAX_K = 12 * 1024;
+constexpr uint32_t A16W8_MSD_MIN_N = 1024;
+constexpr uint32_t UB_BLOCK_UNIT_SIZE = 32;  // 32: a block holds 32 bytes of data
+constexpr uint32_t UB_ANTIQUANT_PER_BLOCK_ALIGN = 4 * 1024;
+constexpr uint32_t UB_A16W8_BLOCK_NUM_FP16 = 6;  // 2 * sizeof(int8) + 2 * sizeof(half)
+constexpr uint32_t UB_A16W8_IO_USED_BLOCK_FP16 = 6;
+constexpr uint32_t UB_A16W8_BLOCK_NUM_BF16 = 8;  // tmpUb uses 2 blocks
+constexpr uint32_t UB_A16W8_IO_USED_BLOCK_BF16 = 6;
+constexpr uint32_t UB_A16W4_BLOCK_NUM_FP16 = 5;  // 2 * sizeof(int4) + 2 * sizeof(half)
+constexpr uint32_t UB_A16W4_IO_USED_BLOCK_FP16 = 5;
+constexpr uint32_t UB_A16W4_BLOCK_NUM_BF16 = 7;  // tmpUb uses 2 blocks
+constexpr uint32_t UB_A16W4_IO_USED_BLOCK_BF16 = 5;
+constexpr uint32_t UB_DYNAMIC_QUANT_BLOCK_NUM = 28;
+constexpr uint32_t UB_DYNAMIC_QUANT_IO_USED_BLOCK = 12;
+constexpr uint32_t UB_QUANT_BLOCK_ALIGN = 2 * 1024;
+constexpr uint32_t UB_A16W8_MSD_BLOCK_NUM = 30;
+constexpr uint32_t UB_A16W8_MSD_IO_USED_BLOCK = 6;
+constexpr uint32_t UB_A16W8_MSD_BLOCK_ALIGN = 512;
+constexpr uint32_t UB_STATIC_QUANT_BLOCK_NUM_BF16 = 20;
+constexpr uint32_t UB_STATIC_QUANT_BLOCK_NUM_FP16 = 24;
+constexpr uint32_t UB_STATIC_QUANT_IO_USED_BLOCK = 12;
+constexpr uint32_t QUEUE_DOUBLE_BUFFER = 2;
+constexpr uint32_t FP32_DATATYPE_SIZE = 4;
+constexpr uint64_t TILING_KEY = 0;
+constexpr uint64_t TILING_KEY_TRANS_X = 1;
+constexpr uint64_t TILING_KEY_TRANS_W = 2;
+constexpr uint64_t TILING_KEY_ANTIQUANT_PERFORMANCE = 3;
+constexpr uint64_t TILING_KEY_QUANT_2VECTOR = 4;
+constexpr uint64_t TILING_KEY_QUANT_2VECTOR_TRANS_W = 5;
+constexpr uint64_t TILING_KEY_A16W8_MSD = 6;
+constexpr uint64_t TILING_KEY_A16W8_MSD_TRANS_W = 7;
+constexpr uint64_t ATTR_INDEX_SPLIT_ITEM = 0;
+constexpr uint64_t ATTR_INDEX_TRANS_W = 2;
+constexpr uint64_t ATTR_INDEX_TRANS_X = 3;
+constexpr uint64_t ATTR_INDEX_GROUPTYPE = 4;
+constexpr uint32_t ATTR_INDEX_GROUP_LIST_TYPE = 5;
+constexpr uint64_t ATTR_INDEX_ACT_TYPE = 6;
+constexpr uint64_t DOUBLE_BUFFER_L0A_L0B = 2;
+constexpr uint64_t DOUBLE_BUFFER_STEPKA_STEPKB = 2;
+constexpr uint32_t SYS_WORKSPACE_SIZE = 16 * 1024 * 1024;
+constexpr int32_t NO_SPLIT = -1;
+constexpr int32_t SPLIT_M = 0;
+constexpr int32_t SPLIT_K = 2;
+// threshold for entering the antiquant performance branch, determined by experiment
+constexpr int64_t ANTIQUANT_PERFORMANCE_THRESHOLD = 5 * 1024 * 1024;
+constexpr int64_t ACT_TYPE_GELU = 2;
+constexpr uint16_t MAX_TENSOR_CONT = 128;
+
+static inline uint32_t SixteenAlign(uint32_t a, bool up = false) {
+    if (up) {
+        a += 15;  // 15: 16-element up-align
+    }
+    return a & ~15;  // ~15: 16-element down-align
+}
+
+struct GMMCompileInfo {
+    uint32_t aicNum;
+    uint32_t aivNum;
+    uint64_t ubSize;
+    uint64_t l1Size;
+    uint64_t l2Size;
+    uint64_t l0CSize;
+    uint64_t l0ASize;
+    uint64_t l0BSize;
+    platform_ascendc::SocVersion socVersion;
+};
+
+class GMMTiling {
+ public:
+    GMMTilingData tilingData;
+    ge::graphStatus Init(const gert::TilingContext* context);
+    ge::graphStatus RunFusionKernelTiling(gert::TilingContext* context);
+
+ protected:
+    ge::graphStatus CalMMTiling(const gert::TilingContext* context, const GMMCompileInfo* compileInfoPtr);
+    ge::graphStatus GMMSetMMTiling(const gert::TilingContext* context, const GMMCompileInfo* compileInfoPtr);
+    void GMMSetTilingKey(gert::TilingContext* context) const;
+    ge::graphStatus GMMGetAttrs(const gert::TilingContext* context);
+    ge::graphStatus GMMSetUbDivideBlk();
+    ge::graphStatus GMMSetUbDivideBlkAntiquant();
+    ge::graphStatus GMMSetUbDivideBlkQuant();
+    ge::graphStatus GMMCalUbSize(const gert::TilingContext* context, uint32_t ubSize);
+    int64_t GMMGetBS(const gert::Shape xShape) const;
+    ge::graphStatus PrepareTilingData(const gert::TilingContext* context);
+    ge::graphStatus CheckWeightNZShape(const gert::TilingContext* context, int64_t numInOneBlk) const;
+    ge::graphStatus GMMGetTensorShapeSplitM(const gert::TilingContext* context, const gert::Shape xShape,
+                                            const gert::Shape wShape);
+    ge::graphStatus GMMGetTensorShapeSplitK(const gert::TilingContext* context, const gert::Shape xShape,
+                                            const gert::Shape wShape);
+    ge::graphStatus SplitMSingleXSingleWeightSingleY(const gert::Shape xShape, const gert::Shape wShape);
+    ge::graphStatus SplitMSingleXSeparatedWeight(const gert::TilingContext* context, const gert::Shape xShape);
+    ge::graphStatus SeparatedXSeparatedWeight(const gert::TilingContext* context);
+    ge::graphStatus SeparatedXSingleWeight(const gert::TilingContext* context, const gert::Shape wShape);
+    ge::graphStatus SplitKSingleXSingleWeightSingleY(const gert::TilingContext* context, const gert::Shape xShape,
+                                                     const gert::Shape wShape);
+    ge::graphStatus DivideUbAndSetWorkspace(gert::TilingContext* context, const uint32_t& aicNum);
+    void DivideUbAndSetWorkspaceAntiquant(size_t* workspaces, const uint32_t& aicNum, uint32_t &ubSize);
+    ge::graphStatus SetBias(const gert::TilingContext* context, matmul_tiling::MultiCoreMatmulTiling& mm) const;
+    int32_t FindBestSingleNPertoken(const uint32_t aicNum) const;
+    ge::graphStatus SetWorkspacesPerTokenQuant(const uint32_t aicNum, size_t* workspaces);
+    void SetTilingDataIsSingleTensor();
+    ge::graphStatus GetPerGroupNum(const gert::TilingContext* context);
+    ge::graphStatus CheckMKN(const gert::TilingContext* context);
+
+ private:
+    int32_t mList_[MAX_TENSOR_CONT] = {0};
+    int32_t kList_[MAX_TENSOR_CONT] = {0};
+    int32_t nList_[MAX_TENSOR_CONT] = {0};
+    int64_t maxM_ = 0;
+    int64_t maxN_ = 0;
+    int64_t maxK_ = 0;
+    int32_t minK_ = INT32_MAX;
+    int32_t baseM_;
+    int32_t baseN_;
+    int32_t baseK_;
+    uint64_t ubSize_;
+    uint32_t mmDataTypeSize_;
+    uint32_t ubDivideBlkNum_;
+    uint32_t ubIoBlkNum_;
+    uint32_t ubBlockAlign_;
+    uint64_t workspacesSize_ = 0;  // for antiquant
+    uint32_t groupNum_ = 0;
+    bool transposeWeight_;
+    bool transposeX_;
+    bool isSingleWeight_;
+    bool isSingleX_;
+    bool isSingleY_;
+    bool isAllSingleTensor_;
+    bool hasBias_;
+    int32_t groupType_;
+    int64_t splitItem_;
+    uint32_t groupListType_;
+    uint32_t xKDim_;
+    uint32_t weightNDim_;
+    uint32_t xDimNum_;
+    bool antiquantPerformance_ = false;
+    uint32_t actType_;
+
+    ge::DataType xDType_ = ge::DT_UNDEFINED;
+    ge::DataType mmDType_ = ge::DT_UNDEFINED;
+    ge::DataType weightDtype_ = ge::DT_UNDEFINED;
+    ge::DataType scaleDtype_ = ge::DT_UNDEFINED;
+    ge::DataType yDtype_ = ge::DT_UNDEFINED;
+    // in the quant case this is a per-token flag; in the antiquant case it is the per-group size
+    uint32_t perTokenOrPerGroupSize_ = 0;
+    bool isA16W8Msd_ = false;
+    uint32_t totalM_ = 0;
+    matmul_tiling::CubeFormat wFormat_;
+    int32_t nzFactor_;  // for weight nz format
+};
+
+ge::graphStatus GMMTiling::CheckWeightNZShape(const gert::TilingContext* context, int64_t numInOneBlk) const {
+    OPS_ERR_IF(numInOneBlk <= 0, OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "numInOneBlk, the "
+               "input of CheckWeightNZShape, has an invalid value %ld", numInOneBlk), return ge::GRAPH_FAILED);
+    size_t i = 0;
+    while (true) {
+        auto wTensor = context->GetDynamicInputTensor(WEIGHT_INDEX, i++);
+        if (wTensor == nullptr) { break; }
+        gert::Shape wOriginShape = wTensor->GetOriginShape();
+        int64_t lastDimValue = wOriginShape.GetDim(wOriginShape.GetDimNum() - 1);  // inner axis
+        OPS_ERR_IF(lastDimValue % numInOneBlk != 0,
+                   OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(),
+                       "the inner axis size of nz weight is expected to be a multiple of 32B, "
+                       "but now the inner axis size is %ld.", lastDimValue),
+                   return ge::GRAPH_FAILED);
+    }
+    return ge::GRAPH_SUCCESS;
+}
+
+ge::graphStatus GMMTiling::CheckMKN(const gert::TilingContext* context) {
+    mmDataTypeSize_ = GetSizeByDataType(mmDType_);
+    OPS_ERR_IF(mmDataTypeSize_ == 0, OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(),
+               "GMM got mm dtype[%s] whose size is 0.", TypeUtils::DataTypeToAscendString(mmDType_).GetString()),
+               return ge::GRAPH_FAILED);
+    uint32_t numInOneBlk = ONE_BLK_SIZE / mmDataTypeSize_;
+    OPS_ERR_IF(numInOneBlk == 0, OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(),
+               "GMM numInOneBlk cannot be 0."), return ge::GRAPH_FAILED);
+    int64_t maxMKN = INT_MAX / numInOneBlk * numInOneBlk;
+    OPS_ERR_IF(maxM_ > maxMKN || maxN_ > maxMKN || maxK_ > maxMKN,
+               OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(),
+                   "the 32B-aligned m, n or k axis is out of the int32 range!"),
+               return ge::GRAPH_FAILED);
+    return ge::GRAPH_SUCCESS;
+}
+
+void GMMTiling::SetTilingDataIsSingleTensor() {
+    tilingData.gmmBaseParams.set_singleWeight(static_cast<uint32_t>(isSingleWeight_));
+    tilingData.gmmBaseParams.set_singleX(static_cast<uint32_t>(isSingleX_));
+    tilingData.gmmBaseParams.set_singleY(static_cast<uint32_t>(isSingleY_));
+}
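+
+// Worked example for the CeilDiv and SixteenAlign helpers defined near the top of this file
+// (values are illustrative only):
+//   CeilDiv(100, 32)        -> 4   (ceil(100 / 32))
+//   SixteenAlign(100)       -> 96  (round down to a multiple of 16)
+//   SixteenAlign(100, true) -> 112 (round up to a multiple of 16)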
+
+ge::graphStatus GMMTiling::PrepareTilingData(const gert::TilingContext* context) {
+    // get transpose and groupType
+    OPS_ERR_IF(GMMGetAttrs(context) != ge::GRAPH_SUCCESS,
+               OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "GMMGetAttrs failed"),
+               return ge::GRAPH_FAILED);
+    // get the first tensor's shape of weight and x
+    auto xTensor = context->GetDynamicInputTensor(X_INDEX, 0);  // 0: get the first tensor
+    OPS_LOG_E_IF_NULL(context, xTensor, return ge::GRAPH_FAILED);
+    gert::Shape xShape = xTensor->GetStorageShape();
+    xDimNum_ = static_cast<uint32_t>(xShape.GetDimNum());
+    // 0: when x is transposed, the first dim is k; otherwise, the last dim is k
+    xKDim_ = transposeX_ ? 0 : xDimNum_ - 1;
+
+    auto wTensor = context->GetDynamicInputTensor(WEIGHT_INDEX, 0);
+    OPS_LOG_E_IF_NULL(context, wTensor, return ge::GRAPH_FAILED);
+    gert::Shape wShape = wTensor->GetOriginShape();
+    uint32_t wDimNum = static_cast<uint32_t>(wShape.GetDimNum());
+    // -2: when w is transposed, the second-to-last dim is n; -1: otherwise, the last dim is n
+    weightNDim_ = transposeWeight_ ? wDimNum - 2 : wDimNum - 1;
+    nzFactor_ = 1;  // init
+    if (wFormat_ == matmul_tiling::CubeFormat::NZ) {
+        uint32_t numInOneBlk = UB_BLOCK_UNIT_SIZE / std::max(1, GetSizeByDataType(weightDtype_));
+        if (wDimNum >= 4) {  // 4: least dim num of an nz format tensor
+            // -3: when w is transposed, the third-to-last dim is n/nzFactor;
+            // -4: when w has nz format and is not transposed, the fourth-to-last dim is n/nzFactor
+            weightNDim_ = transposeWeight_ ? wDimNum - 3 : wDimNum - 4;
+            // nzFactor_ is a factor used to compute the n axis size. If weight is transposed, nzFactor_ is 16;
+            // otherwise nzFactor_ is 16 for bf16 and 32 for int8.
+            nzFactor_ = transposeWeight_ ? 16 : static_cast<int32_t>(numInOneBlk);
+        } else {
+            OPS_ERR_IF(CheckWeightNZShape(context, static_cast<int64_t>(numInOneBlk)) != ge::GRAPH_SUCCESS,
+                       OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "the shape of nz weight is invalid."),
+                       return ge::GRAPH_FAILED);
+        }
+    }
+    isSingleWeight_ = (context->GetDynamicInputTensor(WEIGHT_INDEX, 1) == nullptr);
+    isSingleX_ = (context->GetDynamicInputTensor(X_INDEX, 1) == nullptr);
+    // 2: multi-tensor x with single-tensor y; 3: single-tensor x with single-tensor y
+    isSingleY_ = (splitItem_ == 2 || splitItem_ == 3);
+    SetTilingDataIsSingleTensor();
+
+    if (groupType_ == SPLIT_M) {
+        return GMMGetTensorShapeSplitM(context, xShape, wShape);
+    }
+    if (groupType_ == SPLIT_K) {
+        return GMMGetTensorShapeSplitK(context, xShape, wShape);
+    }
+    if (groupType_ == NO_SPLIT) {  // no axis is split
+        if (isSingleWeight_ && wDimNum > 2) {  // 2: dim num of a separated weight tensor
+            return SeparatedXSingleWeight(context, wShape);
+        }
+        return SeparatedXSeparatedWeight(context);
+    }
+    OPS_LOG_E(context->GetNodeName(),
+              "GMM_tiling: unsupported combination: groupType_=%d, isSingleWeight_=%d, isSingleX_=%d, isSingleY_=%d",
+              groupType_, isSingleWeight_, isSingleX_, isSingleY_);
+    return ge::GRAPH_FAILED;
+}
+
+ge::graphStatus GMMTiling::GMMGetTensorShapeSplitM(const gert::TilingContext* context, const gert::Shape xShape,
+                                                   const gert::Shape wShape) {
+    if (isSingleX_ && isSingleWeight_ && isSingleY_) {  // split M, s-s-s
+        return SplitMSingleXSingleWeightSingleY(xShape, wShape);
+    }
+    if (isSingleX_ && !isSingleWeight_ && isSingleY_) {  // split M, s-m-s
+        return SplitMSingleXSeparatedWeight(context, xShape);
+    }
+    if (isSingleX_ && !isSingleWeight_ && !isSingleY_) {  // split M, s-m-m
+        return SplitMSingleXSeparatedWeight(context, xShape);
+    }
+    if (!isSingleX_ && !isSingleWeight_ && isSingleY_) {  // split M, m-m-s
+        return SeparatedXSeparatedWeight(context);
+    }
+    if (!isSingleX_ && isSingleWeight_) {  // split M, m-s-m/m-s-s
+        return SeparatedXSingleWeight(context, wShape);
+    }
+    if (!isSingleX_ && !isSingleWeight_ && !isSingleY_) {  // split M, m-m-m
+        return SeparatedXSeparatedWeight(context);
+    }
+    OPS_LOG_E(context->GetNodeName(),
+              "GMM_tiling: unsupported combination: groupType_=%d, isSingleWeight_=%d, isSingleX_=%d, isSingleY_=%d",
+              groupType_, isSingleWeight_, isSingleX_, isSingleY_);
+    return ge::GRAPH_FAILED;
+}
+
+ge::graphStatus GMMTiling::GMMGetTensorShapeSplitK(const gert::TilingContext* context, const gert::Shape xShape,
+                                                   const gert::Shape wShape) {
+    if (isSingleX_ && isSingleWeight_ && isSingleY_) {  // split K, s-s-s
+        return SplitKSingleXSingleWeightSingleY(context, xShape, wShape);
+    }
+    if (!isSingleX_ && isSingleWeight_) {  // split K, m-s-m/m-s-s
+        return SeparatedXSingleWeight(context, wShape);
+    }
+    OPS_LOG_E(context->GetNodeName(),
+              "GMM_tiling: unsupported combination: groupType_=%d, isSingleWeight_=%d, isSingleX_=%d, isSingleY_=%d",
+              groupType_, isSingleWeight_, isSingleX_, isSingleY_);
+    return ge::GRAPH_FAILED;
+}
+
+/** @brief split M: single-single-single (s-s-s)
+*/
+ge::graphStatus GMMTiling::SplitMSingleXSingleWeightSingleY(const gert::Shape xShape, const gert::Shape wShape) {
+    groupNum_ = static_cast<uint32_t>(wShape.GetDim(0));
+    int64_t m = GMMGetBS(xShape);
+    int64_t k = xShape.GetDim(xKDim_);
+    int64_t n = wShape.GetDim(weightNDim_) * static_cast<int64_t>(nzFactor_);
+    kList_[0] = static_cast<int32_t>(k);  // if the M axis is split, all K axis values of the x tensorList are the same
+    nList_[0] = static_cast<int32_t>(n);
+    mList_[0] = -1;
+    maxM_ = m;
+    maxK_ = k;
+    maxN_ = n;
+    totalM_ = m;
+    return ge::GRAPH_SUCCESS;
+}
+
+/** @brief split M: single-multi-single (s-m-s) / single-multi-multi (s-m-m); both share this function.
+*/
+ge::graphStatus GMMTiling::SplitMSingleXSeparatedWeight(const gert::TilingContext* context, const gert::Shape xShape) {
+    int64_t m = GMMGetBS(xShape);
+    int64_t k = xShape.GetDim(xKDim_);
+    for (uint32_t i = 0; i < MAX_TENSOR_CONT; i++) {
+        auto wTensor = context->GetDynamicInputTensor(WEIGHT_INDEX, i);
+        if (wTensor == nullptr) { break; }  // stop at the first missing weight tensor
+        auto wShape = wTensor->GetOriginShape();
+
+        groupNum_ += 1;
+        kList_[i] = static_cast<int32_t>(k);
+        int64_t n = wShape.GetDim(weightNDim_) * nzFactor_;
+        nList_[i] = static_cast<int32_t>(n);
+        maxN_ = std::max(maxN_, n);
+    }
+    mList_[0] = -1;  // mList is unknown right now
+    maxM_ = m;
+    maxK_ = k;
+    totalM_ = m;
+
+    return ge::GRAPH_SUCCESS;
+}
+
+/** @brief split M: multi-multi-single (m-m-s); no split: multi-multi-multi (m-m-m); both share this function.
+*/
+ge::graphStatus GMMTiling::SeparatedXSeparatedWeight(const gert::TilingContext* context) {
+    for (uint32_t i = 0; i < MAX_TENSOR_CONT; i++) {
+        auto wTensor = context->GetDynamicInputTensor(WEIGHT_INDEX, i);
+        auto xTensor = context->GetDynamicInputTensor(X_INDEX, i);
+        if (wTensor == nullptr || xTensor == nullptr) { break; }
+        auto wShape = wTensor->GetOriginShape();
+        auto xShape = xTensor->GetStorageShape();
+        groupNum_ += 1;
+        int64_t m = GMMGetBS(xShape);
+        int64_t k = xShape.GetDim(xKDim_);
+        int64_t n = wShape.GetDim(weightNDim_) * nzFactor_;
+        mList_[i] = static_cast<int32_t>(m);
+        kList_[i] = static_cast<int32_t>(k);
+        nList_[i] = static_cast<int32_t>(n);
+        maxM_ = std::max(maxM_, m);
+        maxK_ = std::max(maxK_, k);
+        maxN_ = std::max(maxN_, n);
+        totalM_ += m;
+    }
+    groupType_ = NO_SPLIT;
+    return ge::GRAPH_SUCCESS;
+}
+
+/** @brief split M: multi-single-multi (m-s-m); split K: multi-single-multi (m-s-m); both share this function.
+*/
+ge::graphStatus GMMTiling::SeparatedXSingleWeight(const gert::TilingContext* context, const gert::Shape wShape) {
+    int64_t n = wShape.GetDim(weightNDim_) * nzFactor_;
+    for (uint32_t i = 0; i < MAX_TENSOR_CONT; i++) {
+        auto xTensor = context->GetDynamicInputTensor(X_INDEX, i);
+        if (xTensor == nullptr) { break; }  // stop at the first missing x tensor
+        auto xShape = xTensor->GetStorageShape();
+        groupNum_ += 1;
+        int64_t m = GMMGetBS(xShape);
+        int64_t k = xShape.GetDim(xKDim_);
+        mList_[i] = static_cast<int32_t>(m);
+        kList_[i] = static_cast<int32_t>(k);
+        nList_[i] = static_cast<int32_t>(n);
+        maxM_ = std::max(maxM_, m);
+        maxK_ = std::max(maxK_, k);
+        totalM_ += m;
+    }
+    maxN_ = n;
+    groupType_ = NO_SPLIT;
+    return ge::GRAPH_SUCCESS;
+}
+
+/** @brief split K: single-single-single (s-s-s)
+*/
+ge::graphStatus GMMTiling::SplitKSingleXSingleWeightSingleY(const gert::TilingContext* context,
+                                                            const gert::Shape xShape, const gert::Shape wShape) {
+    int64_t m = GMMGetBS(xShape);
+    int64_t k = xShape.GetDim(xKDim_);
+    int64_t n = wShape.GetDim(weightNDim_) * nzFactor_;
+
+    auto groupListTensor = context->GetDynamicInputTensor(GROUPLIST_INDEX, 0);
+    if (groupListTensor == nullptr) {
+        OPS_LOG_E(context->GetNodeName(), "groupListTensor is nullptr");
+        return ge::GRAPH_FAILED;
+    }
+    gert::Shape groupListShape = groupListTensor->GetStorageShape();
+    groupNum_ = static_cast<uint32_t>(groupListShape.GetDim(0));  // 0: the first dim of groupList is groupNum
+    mList_[0] = static_cast<int32_t>(m);
+    nList_[0] = static_cast<int32_t>(n);
+    kList_[0] = -1;
+    maxM_ = m;
+    maxN_ = n;
+    maxK_ = k;
+    totalM_ = m;
+    return ge::GRAPH_SUCCESS;
+}
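+
+// Worked example for the shape-list setup above (assumed shapes, split-M s-s-s path):
+//   x = [1024, 512], weight = [8, 512, 256] (ND format, no transpose)
+//   -> groupNum_ = 8, kList_[0] = 512, nList_[0] = 256, mList_[0] = -1
+//      (the per-group m is only known at runtime from groupList), maxM_ = totalM_ = 1024.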
+
+ge::graphStatus GMMTiling::Init(const gert::TilingContext* context) {
+    OPS_ERR_IF(PrepareTilingData(context) != ge::GRAPH_SUCCESS,
+               OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "GMM PrepareTilingData failed."),
+               return ge::GRAPH_FAILED);
+    auto compileInfoPtr = context->GetCompileInfo<GMMCompileInfo>();
+    OPS_LOG_E_IF_NULL(context, compileInfoPtr, return ge::GRAPH_FAILED);  // check compileInfoPtr is not null
+    // check whether x, weight and y are all single tensors
+    isAllSingleTensor_ = isSingleX_ && isSingleWeight_ && isSingleY_;
+    bool isA16W8 = (xDType_ == ge::DT_FLOAT16 || xDType_ == ge::DT_BF16) && weightDtype_ == ge::DT_INT8;
+    // check whether k and n are supported in msd
+    bool isKNForA16W8MSD = maxN_ % A16W8_MSD_KN_BASE_BLOCK == 0 && maxK_ % A16W8_MSD_KN_BASE_BLOCK == 0 &&
+                           maxK_ <= A16W8_MSD_MAX_K && maxN_ >= A16W8_MSD_MIN_N;
+    // check whether the total token num and average token num are supported in msd
+    bool isMForA16W8MSD = totalM_ <= A16W8_MSD_AVERAGE_TOKEN_NUM * groupNum_;
+    isA16W8Msd_ = isAllSingleTensor_ && groupType_ == SPLIT_M && isA16W8 && isKNForA16W8MSD && isMForA16W8MSD;
+    mmDType_ = isA16W8Msd_ ? ge::DT_INT8 : xDType_;
+    OPS_ERR_IF(CheckMKN(context) != ge::GRAPH_SUCCESS,
+               OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "GMM CheckMKN failed."),
+               return ge::GRAPH_FAILED);
+    auto biasPtr = context->GetDynamicInputTensor(BIAS_INDEX, 0);  // 0: obtain the first tensor of the tensorList
+    hasBias_ = !(biasPtr == nullptr || biasPtr->GetStorageShape().GetShapeSize() == 0);
+
+    tilingData.gmmArray.set_mList(mList_);
+    tilingData.gmmArray.set_kList(kList_);
+    tilingData.gmmArray.set_nList(nList_);
+    tilingData.gmmBaseParams.set_groupNum(groupNum_);
+    tilingData.gmmBaseParams.set_m(totalM_);
+    tilingData.gmmBaseParams.set_hasBias(static_cast<uint32_t>(hasBias_));
+    tilingData.gmmBaseParams.set_groupType(static_cast<int32_t>(groupType_));
+    tilingData.gmmBaseParams.set_activeType(actType_);
+    tilingData.gmmBaseParams.set_quantParam(perTokenOrPerGroupSize_);
+    tilingData.gmmBaseParams.set_groupListType(groupListType_);
+    OPS_LOG_I(context->GetNodeName(), "GMM_tiling: groupNum_ is %u, maxM_ is %ld, maxK_ is %ld, maxN_ is %ld.",
+              groupNum_, maxM_, maxK_, maxN_);
+    return ge::GRAPH_SUCCESS;
+}
+
+ge::graphStatus GMMTiling::GetPerGroupNum(const gert::TilingContext* context) {
+    auto antiquantScale = context->GetDynamicInputTensor(ANTIQUANT_SCALE_INDEX, 0);
+    OPS_LOG_E_IF_NULL(context, antiquantScale, return ge::GRAPH_FAILED);
+    auto antiquantScaleShape = antiquantScale->GetStorageShape();
+    int64_t dimNum = antiquantScaleShape.GetDimNum();
+    auto wTensor = context->GetDynamicInputTensor(WEIGHT_INDEX, 0);
+    OPS_LOG_E_IF_NULL(context, wTensor, return ge::GRAPH_FAILED);
+    gert::Shape wShape = wTensor->GetOriginShape();
+    size_t wDimNum = wShape.GetDimNum();
+    if ((isSingleWeight_ && wDimNum > 2 && dimNum == 3) || (!isSingleWeight_ && dimNum == 2)) {  // 2 and 3: dim thresholds
+        int64_t g = antiquantScaleShape.GetDim(dimNum - 2);
+        perTokenOrPerGroupSize_ = g > 1 ? kList_[0] / g : 0;
+        tilingData.gmmBaseParams.set_quantParam(perTokenOrPerGroupSize_);
+    }
+    return ge::GRAPH_SUCCESS;
+}
+
+void GMMTiling::DivideUbAndSetWorkspaceAntiquant(size_t* workspaces, const uint32_t& aicNum, uint32_t &ubSize) {
+    if (isA16W8Msd_) {
+        // the whole workspace for the a16w8 msd scene combines the workspace for the global max of each row (m * 8),
+        // the workspace for the local reduce sum of each row (m * aicNum), the workspace for the data after
+        // preprocessing x (2 * m * k), and the workspace for the matmul output (2 * m * n)
+        // 32: 32 bytes are needed to store the global max
+        workspaces[0] += (aicNum * sizeof(float) + 32 +
+                          A16W8_MSD_STEP * (maxK_ * sizeof(int8_t) + maxN_ * sizeof(int32_t))) * totalM_;
+        // 7: align aicNum up to a multiple of 8
+        uint32_t alignedAicNum = (aicNum + 7) & (~7);
+        ubSize = static_cast<uint32_t>(ubSize_ - (baseM_ / A16W8_MSD_STEP) * alignedAicNum * sizeof(float));
+        // workspaceSize in GMMBaseParams is the size of the left input matrix of matmul
+        workspacesSize_ += A16W8_MSD_STEP * static_cast<uint64_t>(maxK_) * static_cast<uint64_t>(maxM_);
+    } else {
+        for (uint32_t i = 0; i < groupNum_; i++) {
+            bool isAllSingleTensor = isSingleX_ && isSingleWeight_ && isSingleY_;
+            int32_t kInList = isAllSingleTensor ? kList_[0] : kList_[i];  // s-s-s case: k only exists in the first entry
+            int32_t nInList = isAllSingleTensor ? nList_[0] : nList_[i];  // s-s-s case: n only exists in the first entry
+            int32_t k = kList_[0] == -1 ? static_cast<int32_t>(maxK_) : kInList;
+            int32_t n = nList_[0] == -1 ? static_cast<int32_t>(maxN_) : nInList;
+            minK_ = std::min(minK_, k);
+            workspacesSize_ += static_cast<uint64_t>(k) * static_cast<uint64_t>(n);
+        }
+        // when minK * baseN * coreNum * sizeof(float16) > 12M, it goes into the antiquantPerformance branch
+        // (the 12M threshold was obtained by testing)
+        int32_t dimMN = CeilDiv(CeilDiv(maxM_, groupNum_), baseM_) * CeilDiv(maxN_, baseN_);
+        bool goodCubeUtility = dimMN * (xDType_ == ge::DT_BF16 ? 2 : 1) >= static_cast<int32_t>(aicNum * 0.4);  // 0.4: an empirical factor
+        antiquantPerformance_ =
+            goodCubeUtility && static_cast<int64_t>(minK_) * baseN_ * aicNum >= ANTIQUANT_PERFORMANCE_THRESHOLD;
+        uint32_t maxUbBaseN = BEST_UB_BASEN;
+        if (transposeWeight_) {
+            maxUbBaseN = baseN_;
+        } else if (antiquantPerformance_) {
+            // 2: use 2 pieces of workspace in the antiquantPerformance branch
+            workspacesSize_ = static_cast<uint64_t>(maxN_) * maxK_ * 2;
+        }
+        // 2: 2 InQueues (antiquant_scale, antiquant_offset)
+        ubSize = static_cast<uint32_t>(ubSize_ - maxUbBaseN * mmDataTypeSize_ * QUEUE_DOUBLE_BUFFER * 2);
+        workspaces[0] += workspacesSize_ * mmDataTypeSize_;
+    }
+}
+
+ge::graphStatus GMMTiling::DivideUbAndSetWorkspace(gert::TilingContext* context, const uint32_t& aicNum) {
+    size_t* workspaces = context->GetWorkspaceSizes(1);  // 1: request one workspace-size entry
+    OPS_LOG_E_IF_NULL(context, workspaces, return ge::GRAPH_FAILED);  // check workspaces is not null
+    workspaces[0] = SYS_WORKSPACE_SIZE;  // default size
+    if (weightDtype_ != ge::DT_INT8 && weightDtype_ != ge::DT_INT4) {
+        return ge::GRAPH_SUCCESS;
+    }
+    uint32_t ubSize = static_cast<uint32_t>(ubSize_);
+    if ((xDType_ == ge::DT_BF16 || xDType_ == ge::DT_FLOAT16)) {
+        DivideUbAndSetWorkspaceAntiquant(workspaces, aicNum, ubSize);
+        OPS_ERR_IF(GetPerGroupNum(context) != ge::GRAPH_SUCCESS, OPS_REPORT_VECTOR_INNER_ERR(
+                   context->GetNodeName(), "GetPerGroupNum failed."), return ge::GRAPH_FAILED);
+    } else if (xDType_ == ge::DT_INT8) {
+        uint32_t scaleDataTypeSize = GetSizeByDataType(scaleDtype_);
+        ubSize = perTokenOrPerGroupSize_ == 1 ?  // is per-token
+            static_cast<uint32_t>(ubSize_ -
+                (baseN_ * scaleDataTypeSize + baseM_ * sizeof(float)) * QUEUE_DOUBLE_BUFFER) :
+            static_cast<uint32_t>(ubSize_ - baseN_ * scaleDataTypeSize * QUEUE_DOUBLE_BUFFER);
+        OPS_ERR_IF(SetWorkspacesPerTokenQuant(aicNum, workspaces) != ge::GRAPH_SUCCESS,
+                   OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "SetWorkspacesPerTokenQuant failed."),
+                   return ge::GRAPH_FAILED);
+    }
+    OPS_ERR_IF(GMMSetUbDivideBlk() != ge::GRAPH_SUCCESS,
+               OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "GMMSetUbDivideBlk failed."),
+               return ge::GRAPH_FAILED);
+    OPS_ERR_IF(GMMCalUbSize(context, ubSize) != ge::GRAPH_SUCCESS,
+               OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "GMMCalUbSize failed."),
+               return ge::GRAPH_FAILED);
+    return ge::GRAPH_SUCCESS;
+}
+
+int32_t GMMTiling::FindBestSingleNPertoken(const uint32_t aicNum) const {
+    if (CeilDiv(maxN_, baseN_) * groupNum_ <= aicNum) {  // if all matmuls only occupy a part of the cores
+        return baseN_;
+    }
+    if (maxN_ >= 2048) {  // 2048: a threshold
+        return 1024;  // 1024: max singleN
+    }
+    int32_t bestSingleN = baseN_;  // init bestSingleN
+    uint32_t bestLastCycleCoreNum = (groupNum_ * CeilDiv(maxN_, bestSingleN)) % aicNum;  // init lastCycleCoreNum
+    // 1024: max singleN
+    for (int32_t tempSingleN = 1024 / baseN_ * baseN_; tempSingleN > baseN_; tempSingleN -= baseN_) {
+        uint32_t lastCycleCoreNum = (groupNum_ * CeilDiv(maxN_, tempSingleN)) % aicNum;
+        if (lastCycleCoreNum == 0) {
+            bestSingleN = tempSingleN;
+            break;
+        }
+        if (lastCycleCoreNum > bestLastCycleCoreNum ||
+            (lastCycleCoreNum == bestLastCycleCoreNum && maxN_ % tempSingleN == 0)) {
+            bestSingleN = tempSingleN;
+            bestLastCycleCoreNum = lastCycleCoreNum;
+        }
+    }
+    return bestSingleN;
+}
+
+ge::graphStatus GMMTiling::SetWorkspacesPerTokenQuant(const uint32_t aicNum, size_t* workspaces) {
+    if (aicNum == 0) {  // invalid value
+        return ge::GRAPH_FAILED;
+    }
+    bool opt = (maxM_ <= 32 * groupNum_ && wFormat_ == matmul_tiling::CubeFormat::NZ) &&
+               (!transposeWeight_ || maxN_ >= 2048);  // 32: a factor, 2048: a threshold
+    if (opt) {  // non-basic strategy: matmul outputs in non-contiguous mode with singleN >= baseN
+        workspaces[0] += maxM_ * maxN_ * sizeof(int32_t);
+        int32_t bestSingleN = FindBestSingleNPertoken(aicNum);
+        tilingData.gmmBaseParams.set_singleN(bestSingleN);
+    } else {
+        // 4: when doing cv parallelism, four pieces of workspace store four cycles of matmul output
+        workspaces[0] += 4 * baseM_ * baseN_ * aicNum * sizeof(int32_t);
+    }
+    return ge::GRAPH_SUCCESS;
+}
+
+ge::graphStatus GMMTiling::RunFusionKernelTiling(gert::TilingContext* context) {
+    OPS_LOG_I(context->GetNodeName(), "Begin Run GMM Tiling");
+    auto compileInfoPtr = context->GetCompileInfo<GMMCompileInfo>();
+    OPS_LOG_E_IF_NULL(context, compileInfoPtr, return ge::GRAPH_FAILED);  // check compileInfoPtr is not null
+
+    ubSize_ = compileInfoPtr->ubSize;  // get ubSize from compileInfo
+    const uint32_t& aicNum = compileInfoPtr->aicNum;  // get aicNum from compileInfo
+    context->SetBlockDim(aicNum);  // block dim is the number of AI Cube cores
+
+    OPS_ERR_IF(CalMMTiling(context, compileInfoPtr) != ge::GRAPH_SUCCESS,
+               OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "GMM CalMMTiling failed"),
+               return ge::GRAPH_FAILED);
+
+    OPS_ERR_IF(GMMSetMMTiling(context, compileInfoPtr) != ge::GRAPH_SUCCESS,
+               OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "GMM GMMSetMMTiling failed"),
+               return ge::GRAPH_FAILED);
+    tilingData.mmTilingData.set_usedCoreNum(aicNum);  // usedCoreNum is the AI Core count
+    tilingData.gmmBaseParams.set_coreNum(aicNum);  // AI Cube core number
+    tilingData.gmmBaseParams.set_singleN(0);  // 0 is the default value
+    OPS_ERR_IF(DivideUbAndSetWorkspace(context, aicNum) != ge::GRAPH_SUCCESS,
+               OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "GMM DivideUbAndSetWorkspace failed"),
+               return ge::GRAPH_FAILED);
+    tilingData.gmmBaseParams.set_workspaceSize(workspacesSize_);
+    GMMSetTilingKey(context);  // set the tiling key
+    tilingData.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
+    context->GetRawTilingData()->SetDataSize(tilingData.GetDataSize());
+    OPS_LOG_I(context->GetNodeName(), "End Run GMM Tiling");
+    return ge::GRAPH_SUCCESS;
+}
+
+ge::graphStatus GMMTiling::GMMCalUbSize(const gert::TilingContext* context, uint32_t ubSize) {
+    OPS_ERR_IF((ubDivideBlkNum_ == 0 || ubBlockAlign_ == 0),
+               OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "ubDivideBlkNum and ubBlockAlign cannot be 0"),
+               return ge::GRAPH_FAILED);
+    uint32_t ubCalSize = ubSize / ubDivideBlkNum_;  // divide the UB into ubDivideBlkNum_ pieces
+    ubCalSize = ubCalSize / ubBlockAlign_ * ubBlockAlign_;  // 16k/8k/4k align
+    uint32_t ubRestBytes = ubSize - ubCalSize * ubIoBlkNum_;  // compute the remaining memory in the UB space
+    ubRestBytes = ubRestBytes / UB_BLOCK_UNIT_SIZE * UB_BLOCK_UNIT_SIZE;  // 32B align
+    OPS_ERR_IF((ubCalSize == 0 || ubRestBytes == 0),
+               OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "ubCalSize and ubRestBytes cannot be 0"),
+               return ge::GRAPH_FAILED);
+    uint32_t ubBaseN = 0;  // init
+    uint32_t ubBaseK = 0;  // init
+    if (transposeWeight_) {
+        ubBaseK = BEST_UB_BASEK;
+        ubBaseN = ubCalSize / ubBaseK;
+        uint32_t alignFactor = UB_BLOCK_UNIT_SIZE;
+        if (weightDtype_ == ge::DT_INT4) {
+            alignFactor <<= 1;  // int4 needs 64-element alignment
+        }
+        ubBaseN = ubBaseN / alignFactor * alignFactor;
+    } else {
+        if ((xDType_ == ge::DT_BF16 || xDType_ == ge::DT_FLOAT16) &&
+            (weightDtype_ == ge::DT_INT8 || weightDtype_ == ge::DT_INT4)) {
+            if (perTokenOrPerGroupSize_ > 0) {
+                ubBaseK = perTokenOrPerGroupSize_;
+                static const uint32_t MIN_UB_BASEN = 128;  // a threshold
+                ubBaseN = std::min<uint32_t>(BEST_UB_BASEN,
+                    std::max(MIN_UB_BASEN, (ubCalSize / ubBaseK + MIN_UB_BASEN - 1) / MIN_UB_BASEN * MIN_UB_BASEN));
+            } else if (antiquantPerformance_) {
+                ubBaseN = BEST_UB_BASEN;
+            } else {
+                ubBaseN = baseN_;
+            }
+        } else {
+            ubBaseN = baseN_;
+        }
+        ubBaseK = ubCalSize / ubBaseN;  // ubCalSize is a number of elements, not a byte count
+    }
+    if (xDType_ == ge::DT_BF16 && (weightDtype_ == ge::DT_INT8 || weightDtype_ == ge::DT_INT4) && !isA16W8Msd_) {
+        OPS_ERR_IF(ubBaseK == 0 || ubBaseN == 0,
+                   OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "ubBaseK or ubBaseN cannot be 0"),
+                   return ge::GRAPH_FAILED);
+    }
+    tilingData.gmmBaseParams.set_ubCalSize(ubCalSize);
+    tilingData.gmmBaseParams.set_ubRestBytes(ubRestBytes);  // in byte units
+    tilingData.gmmBaseParams.set_ubBaseK(ubBaseK);
+    tilingData.gmmBaseParams.set_ubBaseN(ubBaseN);
+    return ge::GRAPH_SUCCESS;
+}
+
+int64_t GMMTiling::GMMGetBS(const gert::Shape xShape) const {
+    int64_t bs = 0;  // init bs
+    if (transposeX_) {
+        bs = xShape.GetDim(1);  // x shape is [k, m] if x is transposed
+    } else {
+        if (groupType_ == -1) {  // -1: no-group case; the product of multiple dims may equal bs
+            bs = xShape.GetDim(0);  // 0: x first dim
+            size_t bsDimNum = xDimNum_ >= 1 ? xDimNum_ - 1 : 0;  // 1: the last dim of x is k, the other dims are bs
+            for (size_t i = 1; i < bsDimNum; i++) {
+                bs *= xShape.GetDim(i);
+            }
+        } else {
+            bs = xShape.GetDim(0);  // in the group case, x's shape is [m, k]; 0 is the m axis
+        }
+    }
+    return bs;
+}
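+
+// Worked example for GMMGetBS above (assumed shapes): with groupType_ == -1, x not transposed
+// and x shape [4, 8, 512], bs = 4 * 8 = 32 (all dims except the last k dim are multiplied);
+// with transposeX_ and x shape [512, 96], bs = 96 (dim 1).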
+
+void GMMTiling::GMMSetTilingKey(gert::TilingContext* context) const {
+    bool transposeXSupportDtype = (weightDtype_ == ge::DT_FLOAT16 || weightDtype_ == ge::DT_BF16 ||
+                                   weightDtype_ == ge::DT_FLOAT);
+    if (isA16W8Msd_) {
+        context->SetScheduleMode(1);  // set as batch mode for templates using SyncAll
+        context->SetTilingKey(transposeWeight_ ? TILING_KEY_A16W8_MSD_TRANS_W : TILING_KEY_A16W8_MSD);
+        return;
+    }
+    if (xDType_ == ge::DT_INT8 && weightDtype_ == ge::DT_INT8 && actType_ == ACT_TYPE_GELU) {
+        if (transposeWeight_) {
+            context->SetTilingKey(TILING_KEY_QUANT_2VECTOR_TRANS_W);
+        } else {
+            context->SetTilingKey(TILING_KEY_QUANT_2VECTOR);
+        }
+        return;
+    }
+    if (transposeWeight_) {
+        context->SetTilingKey(TILING_KEY_TRANS_W);
+    } else if (transposeX_ && transposeXSupportDtype) {
+        context->SetTilingKey(TILING_KEY_TRANS_X);
+    } else if (antiquantPerformance_) {
+        context->SetTilingKey(TILING_KEY_ANTIQUANT_PERFORMANCE);
+        context->SetScheduleMode(1);  // set as batch mode for templates using SyncAll
+    } else {
+        context->SetTilingKey(TILING_KEY);
+    }
+}
+
+ge::graphStatus GMMTiling::GMMGetAttrs(const gert::TilingContext* context) {
+    auto attr = context->GetAttrs();
+    OPS_LOG_E_IF_NULL(context, attr, return ge::GRAPH_FAILED);  // check attr is not null
+    const bool* transposeWeightPtr = attr->GetAttrPointer<bool>(ATTR_INDEX_TRANS_W);
+    const bool* transposeXPtr = attr->GetAttrPointer<bool>(ATTR_INDEX_TRANS_X);
+    const int32_t* groupTypePtr = attr->GetAttrPointer<int32_t>(ATTR_INDEX_GROUPTYPE);
+    const int64_t* splitItemPtr = attr->GetAttrPointer<int64_t>(ATTR_INDEX_SPLIT_ITEM);
+    const int64_t* actTypePtr = attr->GetAttrPointer<int64_t>(ATTR_INDEX_ACT_TYPE);
+    const uint32_t* groupListTypePtr = attr->GetAttrPointer<uint32_t>(ATTR_INDEX_GROUP_LIST_TYPE);
+    transposeWeight_ = transposeWeightPtr != nullptr ? *transposeWeightPtr : false;
+    transposeX_ = transposeXPtr != nullptr ? *transposeXPtr : false;
+    groupType_ = groupTypePtr != nullptr ? *groupTypePtr : NO_SPLIT;
+    splitItem_ = splitItemPtr != nullptr ? *splitItemPtr : 0;  // 0: default split_item
+    actType_ = actTypePtr != nullptr ? *actTypePtr : 0;
+    groupListType_ = groupListTypePtr != nullptr ? *groupListTypePtr : 0;
+
+    auto xDesc = context->GetDynamicInputDesc(X_INDEX, 0);
+    OPS_LOG_E_IF_NULL(context, xDesc, return ge::GRAPH_FAILED);  // check xDesc is not null
+    xDType_ = xDesc->GetDataType();
+    auto w0Desc = context->GetDynamicInputDesc(WEIGHT_INDEX, 0);
+    OPS_LOG_E_IF_NULL(context, w0Desc, return ge::GRAPH_FAILED);
+    weightDtype_ = w0Desc->GetDataType();
+    auto perTokenScalePtr = context->GetOptionalInputTensor(PER_TOKEN_SCALE_INDEX);
+    if (perTokenScalePtr != nullptr && perTokenScalePtr->GetStorageShape().GetShapeSize() != 0) {
+        perTokenOrPerGroupSize_ = 1;
+    }
+    tilingData.gmmBaseParams.set_quantParam(perTokenOrPerGroupSize_);
+    if (weightDtype_ == ge::DT_INT8 && xDType_ == ge::DT_INT8) {
+        auto scale0Desc = context->GetDynamicInputDesc(SCALE_INDEX, 0);
+        OPS_LOG_E_IF_NULL(context, scale0Desc, return ge::GRAPH_FAILED);
+        scaleDtype_ = scale0Desc->GetDataType();
+        auto yDesc = context->GetOutputDesc(Y_INDEX);
+        OPS_LOG_E_IF_NULL(context, yDesc, return ge::GRAPH_FAILED);
+        yDtype_ = yDesc->GetDataType();
+    }
+    auto wFormat0 = static_cast<ge::Format>(ge::GetPrimaryFormat(w0Desc->GetStorageFormat()));
+    wFormat_ = wFormat0 == ge::FORMAT_FRACTAL_NZ ? matmul_tiling::CubeFormat::NZ : matmul_tiling::CubeFormat::ND;
+    return ge::GRAPH_SUCCESS;
+}
+
+ge::graphStatus GMMTiling::GMMSetUbDivideBlkAntiquant() {
+    if (isA16W8Msd_) {
+        ubDivideBlkNum_ = UB_A16W8_MSD_BLOCK_NUM;
+        ubIoBlkNum_ = UB_A16W8_MSD_IO_USED_BLOCK;
+        ubBlockAlign_ = UB_A16W8_MSD_BLOCK_ALIGN;
+        return ge::GRAPH_SUCCESS;
+    }
+    if (xDType_ == ge::DT_FLOAT16 && (weightDtype_ == ge::DT_INT8 || weightDtype_ == ge::DT_INT4)) {
+        if (weightDtype_ == ge::DT_INT8) {
+            ubDivideBlkNum_ = UB_A16W8_BLOCK_NUM_FP16;
+            ubIoBlkNum_ = UB_A16W8_IO_USED_BLOCK_FP16;
+        } else {  // int4
+            ubDivideBlkNum_ = UB_A16W4_BLOCK_NUM_FP16;
+            ubIoBlkNum_ = UB_A16W4_IO_USED_BLOCK_FP16;
+        }
+        ubBlockAlign_ = UB_ANTIQUANT_PER_BLOCK_ALIGN;
+        return ge::GRAPH_SUCCESS;
+    }
+    if (xDType_ == ge::DT_BF16 && (weightDtype_ == ge::DT_INT8 || weightDtype_ == ge::DT_INT4)) {
+        if (weightDtype_ == ge::DT_INT8) {
+            ubDivideBlkNum_ = UB_A16W8_BLOCK_NUM_BF16;
+            ubIoBlkNum_ = UB_A16W8_IO_USED_BLOCK_BF16;
+        } else {
+            ubDivideBlkNum_ = UB_A16W4_BLOCK_NUM_BF16;
+            ubIoBlkNum_ = UB_A16W4_IO_USED_BLOCK_BF16;
+        }
+        ubBlockAlign_ = UB_ANTIQUANT_PER_BLOCK_ALIGN;
+        return ge::GRAPH_SUCCESS;
+    }
+    return ge::GRAPH_FAILED;
+}
+
+ge::graphStatus GMMTiling::GMMSetUbDivideBlkQuant() {
+    if (weightDtype_ == ge::DT_INT8 && (perTokenOrPerGroupSize_ == 1 || actType_ != 0)) {
+        // covers per-token without activation, per-token with activation and per-tensor with activation
+        ubDivideBlkNum_ = UB_DYNAMIC_QUANT_BLOCK_NUM;
+        ubIoBlkNum_ = UB_DYNAMIC_QUANT_IO_USED_BLOCK;
+        ubBlockAlign_ = UB_QUANT_BLOCK_ALIGN;
+        return ge::GRAPH_SUCCESS;
+    }
+    if (weightDtype_ == ge::DT_INT8 && perTokenOrPerGroupSize_ != 1) {
+        // covers per-tensor without activation
+        if (yDtype_ == ge::DT_FLOAT16) {
+            ubDivideBlkNum_ = UB_STATIC_QUANT_BLOCK_NUM_FP16;
+        } else {
+            ubDivideBlkNum_ = UB_STATIC_QUANT_BLOCK_NUM_BF16;
+        }
+        ubIoBlkNum_ = UB_STATIC_QUANT_IO_USED_BLOCK;
+        ubBlockAlign_ = UB_QUANT_BLOCK_ALIGN;
+        return ge::GRAPH_SUCCESS;
+    }
+    return ge::GRAPH_FAILED;
+}
+
+ge::graphStatus GMMTiling::GMMSetUbDivideBlk() {
+    ubDivideBlkNum_ = 0;  // init ubDivideBlkNum_
+    ubIoBlkNum_ = 0;  // init ubIoBlkNum_
+    ubBlockAlign_ = 0;  // init ubBlockAlign_
+    if (xDType_ == ge::DT_INT8) {
+        return GMMSetUbDivideBlkQuant();
+    } else {
+        return GMMSetUbDivideBlkAntiquant();
+    }
+}
+
+ge::graphStatus GMMTiling::SetBias(const gert::TilingContext* context, matmul_tiling::MultiCoreMatmulTiling& mm) const {
+    if (!hasBias_ || isA16W8Msd_) {
+        mm.SetBias(false);
+    } else {
+        mm.SetBias(true);
+        auto biasTensor = context->GetDynamicInputTensor(BIAS_INDEX, 0);
+        OPS_ERR_IF(biasTensor == nullptr,
+                   OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "Get bias tensor failed."),
+                   return ge::GRAPH_FAILED);
+        mm.SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND,
+                       static_cast<matmul_tiling::DataType>(biasTensor->GetDataType()));
+    }
+    return ge::GRAPH_SUCCESS;
+}
+
+static void InitPlatformInfo(const GMMCompileInfo* compileInfoPtr, matmul_tiling::PlatformInfo& platformInfo) {
+    platformInfo.socVersion = compileInfoPtr->socVersion;
+    platformInfo.l1Size = compileInfoPtr->l1Size;
+    platformInfo.l0CSize = compileInfoPtr->l0CSize;
+    platformInfo.ubSize = compileInfoPtr->ubSize;
+    platformInfo.l0ASize = compileInfoPtr->l0ASize;
+    platformInfo.l0BSize = compileInfoPtr->l0BSize;
+}
+
+ge::graphStatus GMMTiling::GMMSetMMTiling(const gert::TilingContext* context, const GMMCompileInfo* compileInfoPtr) {
+    matmul_tiling::DataType matmulDtype = static_cast<matmul_tiling::DataType>(mmDType_);
+    matmul_tiling::PlatformInfo platformInfo;
+    InitPlatformInfo(compileInfoPtr, platformInfo);
+    matmul_tiling::MultiCoreMatmulTiling mm(platformInfo);
+    int64_t mInMM = isA16W8Msd_ ? A16W8_MSD_STEP * maxM_ : maxM_;
+    mm.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmulDtype, false);
+    mm.SetBType(matmul_tiling::TPosition::GM, wFormat_, matmulDtype, false);
+    mm.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND_ALIGN, matmul_tiling::DataType::DT_FLOAT16);
+    OPS_ERR_IF(SetBias(context, mm) != ge::GRAPH_SUCCESS,
+               OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "SetBias failed."), return ge::GRAPH_FAILED);
+    mm.SetOrgShape(mInMM, maxN_, maxK_);
+    mm.SetShape(mInMM, baseN_, maxK_);
+    mm.SetFixSplit(baseM_, baseN_, baseK_);
+    mm.SetBufferSpace(compileInfoPtr->l1Size, compileInfoPtr->l0CSize, ubSize_);
+    OPS_ERR_IF(mm.GetTiling(tilingData.mmTilingData) == -1,
+               OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "matmul getTiling failed."),
+               return ge::GRAPH_FAILED);
+    // according to double buffering, recompute the params used for data movement from GM to L1
+    uint32_t mmStepKa = (BEST_L1_PARTB >> 1) / (baseM_ * baseK_ * mmDataTypeSize_);
+    if (compileInfoPtr->socVersion == platform_ascendc::SocVersion::ASCEND310P &&
+        wFormat_ == matmul_tiling::CubeFormat::NZ && mInMM <= baseM_) {
+        // 128: nz inner block size; in practice, baseK_ * mmStepKa = 128 gives better performance
+        mmStepKa = std::min<uint32_t>(mmStepKa, std::max(1, 128 / baseK_));
+    }
+    uint32_t mmStepKb = (BEST_L1_PARTA >> 1) / (baseN_ * baseK_ * mmDataTypeSize_);
+
+    if (mmStepKa > mmStepKb) {
+        mmStepKa = mmStepKa / mmStepKb * mmStepKb;
+    } else if (mmStepKa < mmStepKb) {
+        mmStepKb = mmStepKb / mmStepKa * mmStepKa;
+    }
+    constexpr uint32_t stepM = 1;  // 1: stepM is set to the fixed value 1
+    constexpr uint32_t stepN = 1;  // 1: stepN is set to the fixed value 1
+    uint32_t mmDepthA1 = mmStepKa * DOUBLE_BUFFER_STEPKA_STEPKB * stepM;
+    uint32_t mmDepthB1 = mmStepKb * DOUBLE_BUFFER_STEPKA_STEPKB * stepN;
+    tilingData.mmTilingData.set_shareMode(0);
+    if (compileInfoPtr->socVersion == platform_ascendc::SocVersion::ASCEND310P) {
+        tilingData.mmTilingData.set_shareUbSize(0);
+        tilingData.mmTilingData.set_transLength(131072);  // 131072: 128KB size
+    }
+    tilingData.mmTilingData.set_dbL0C(1);  // disable double buffering for L0C
+    tilingData.mmTilingData.set_baseM(baseM_);  // set precomputed baseM
+    tilingData.mmTilingData.set_baseN(baseN_);  // set precomputed baseN
+    tilingData.mmTilingData.set_baseK(baseK_);  // set precomputed baseK
+    tilingData.mmTilingData.set_stepKa(mmStepKa);  // set precomputed mmStepKa
+    tilingData.mmTilingData.set_depthA1(mmDepthA1);  // set precomputed mmDepthA1
+    tilingData.mmTilingData.set_stepKb(mmStepKb);  // set precomputed mmStepKb
+    tilingData.mmTilingData.set_depthB1(mmDepthB1);  // set precomputed mmDepthB1
+    tilingData.mmTilingData.set_stepM(stepM);  // set precomputed stepM
+    tilingData.mmTilingData.set_stepN(stepN);  // set precomputed stepN
+    OPS_LOG_I(context->GetNodeName(), "GMM_tiling: baseM is %d, baseK is %d, baseN is %d.", baseM_, baseK_, baseN_);
+    return ge::GRAPH_SUCCESS;
+}
+
+ge::graphStatus GMMTiling::CalMMTiling(const gert::TilingContext* context, const GMMCompileInfo* compileInfoPtr) {
+    // 2048: min n for a16w8 msd to set baseN to 512
+    baseN_ = isA16W8Msd_ && maxN_ >= 2048 && !transposeWeight_ ? BEST_BASEN_MSD : BEST_BASEN;
+    // according to the double-buffer-enabled L0B, compute baseK
+    baseK_ = (compileInfoPtr->l0BSize / DOUBLE_BUFFER_L0A_L0B) / (baseN_ * mmDataTypeSize_);
+    baseK_ = SixteenAlign(baseK_);
+    OPS_ERR_IF(baseK_ == 0,
+               OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "baseK_ cannot be 0."),
+               return ge::GRAPH_FAILED);
+    // according to the double-buffer-enabled L0A/L0C, compute baseM (cube)
+    uint32_t maxBaseM = compileInfoPtr->l0CSize / (baseN_ * FP32_DATATYPE_SIZE);
+    baseM_ = std::min<uint32_t>((compileInfoPtr->l0ASize / DOUBLE_BUFFER_L0A_L0B) / (baseK_ * mmDataTypeSize_),
+                                maxBaseM);
+
+    if (!isA16W8Msd_) {
+        baseM_ = baseM_ > maxM_ ? SixteenAlign(maxM_, true) : SixteenAlign(baseM_);
+    } else {
+        baseM_ = baseM_ > A16W8_MSD_STEP * maxM_ ? SixteenAlign(A16W8_MSD_STEP * maxM_, true) : SixteenAlign(baseM_);
+    }
+    if (baseM_ > MAX_BASEM) {
+        baseM_ = MAX_BASEM;
+    }
+    OPS_ERR_IF(baseM_ == 0,
+               OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "baseM_ cannot be 0."),
+               return ge::GRAPH_FAILED);
+
+    return ge::GRAPH_SUCCESS;
+}
+
+ASCENDC_EXTERN_C ge::graphStatus TilingGMM(gert::TilingContext* context) {
+    GMMTiling tiling;
+    OPS_ERR_IF(tiling.Init(context) != ge::GRAPH_SUCCESS,
+               OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "GMM tiling init failed"),
+               return ge::GRAPH_FAILED);
+    return tiling.RunFusionKernelTiling(context);
+}
+
+ASCENDC_EXTERN_C ge::graphStatus TilingPrepareForGMM(gert::TilingParseContext* context) {
+    fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo();
+    OPS_LOG_E_IF_NULL(context, platformInfoPtr, return ge::GRAPH_FAILED);
+    auto compileInfoPtr = context->GetCompiledInfo<GMMCompileInfo>();
+    OPS_LOG_E_IF_NULL(context, compileInfoPtr, return ge::GRAPH_FAILED);
+
+    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
+    compileInfoPtr->aicNum = ascendcPlatform.GetCoreNumAic();
+    compileInfoPtr->aivNum = ascendcPlatform.GetCoreNumAiv();
+    compileInfoPtr->socVersion = ascendcPlatform.GetSocVersion();
+    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, compileInfoPtr->ubSize);
+    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::L1, compileInfoPtr->l1Size);
+    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::L0_A, compileInfoPtr->l0ASize);
+    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::L0_B, compileInfoPtr->l0BSize);
+    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::L0_C, compileInfoPtr->l0CSize);
+    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::L2, compileInfoPtr->l2Size);
+
+    OPS_ERR_IF((compileInfoPtr->aicNum == 0 || compileInfoPtr->aivNum == 0 || compileInfoPtr->ubSize == 0 || \
+                compileInfoPtr->l1Size == 0 || compileInfoPtr->l0CSize == 0 || compileInfoPtr->l0ASize == 0 || \
+                compileInfoPtr->l0BSize == 0),
+               OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(),
+                   "platform info is invalid, aicNum=%u, aivNum=%u, ubSize=%lu, l1Size=%lu, l0CSize=%lu, l0ASize=%lu, l0BSize=%lu",
+                   compileInfoPtr->aicNum, compileInfoPtr->aivNum, compileInfoPtr->ubSize, compileInfoPtr->l1Size,
+                   compileInfoPtr->l0CSize, compileInfoPtr->l0ASize, compileInfoPtr->l0BSize),
+               return ge::GRAPH_FAILED);
+
+    OPS_LOG_I(context->GetNodeName(), "Parse compile info success, soc: %d",
+              static_cast<int>(compileInfoPtr->socVersion));
+    return ge::GRAPH_SUCCESS;
+}
+
+IMPL_OP_OPTILING(GroupedMatmul)
+    .Tiling(TilingGMM)
+    .TilingParse<GMMCompileInfo>(TilingPrepareForGMM);  // register into the framework
+}  // namespace optiling
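As a sanity check on the L1 data-movement parameters computed in GMMSetMMTiling above, here is a small standalone sketch with assumed values (fp16 with 2-byte elements, baseM = baseN = 256, baseK = 64; BEST_L1_PARTA/PARTB as defined in the file). It is illustrative only, not part of the patch:

    #include <cstdint>
    #include <cstdio>

    int main() {
        // stepKa uses half of BEST_L1_PARTB (128 KB); stepKb uses half of BEST_L1_PARTA (256 KB).
        const uint32_t dtypeSize = 2, baseM = 256, baseN = 256, baseK = 64;
        uint32_t stepKa = ((128 * 1024) >> 1) / (baseM * baseK * dtypeSize);  // -> 2
        uint32_t stepKb = ((256 * 1024) >> 1) / (baseN * baseK * dtypeSize);  // -> 4
        // make the smaller step divide the larger one, as in GMMSetMMTiling
        if (stepKa > stepKb) { stepKa = stepKa / stepKb * stepKb; }
        else if (stepKa < stepKb) { stepKb = stepKb / stepKa * stepKa; }
        // depth = step * 2 (DOUBLE_BUFFER_STEPKA_STEPKB) with stepM = stepN = 1
        std::printf("stepKa=%u stepKb=%u depthA1=%u depthB1=%u\n",
                    stepKa, stepKb, stepKa * 2, stepKb * 2);  // -> 2 4 4 8
        return 0;
    }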
diff --git a/src/transformer/grouped_matmul/ophost/grouped_matmul_tiling.h b/src/transformer/grouped_matmul/ophost/grouped_matmul_tiling.h
new file mode 100644
index 00000000..048377b1
--- /dev/null
+++ b/src/transformer/grouped_matmul/ophost/grouped_matmul_tiling.h
@@ -0,0 +1,59 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file grouped_matmul_tiling.h
+ * \brief
+ */
+#ifndef AIR_CXX_RUNTIME_V2_OP_IMPL_GROUPED_MATMUL_H
+#define AIR_CXX_RUNTIME_V2_OP_IMPL_GROUPED_MATMUL_H
+
+#include "register/tilingdata_base.h"
+#include "tiling/tiling_api.h"
+
+namespace optiling {
+BEGIN_TILING_DATA_DEF(GMMBaseParams)
+    TILING_DATA_FIELD_DEF(uint32_t, groupNum);
+    TILING_DATA_FIELD_DEF(uint32_t, coreNum);
+    TILING_DATA_FIELD_DEF(uint32_t, activeType);
+    TILING_DATA_FIELD_DEF(uint32_t, ubBaseK);
+    TILING_DATA_FIELD_DEF(uint32_t, ubBaseN);
+    TILING_DATA_FIELD_DEF(uint32_t, ubCalSize);
+    TILING_DATA_FIELD_DEF(uint32_t, ubRestBytes);
+    TILING_DATA_FIELD_DEF(uint32_t, singleWeight);
+    TILING_DATA_FIELD_DEF(uint32_t, singleX);
+    TILING_DATA_FIELD_DEF(uint32_t, singleY);
+    TILING_DATA_FIELD_DEF(int32_t, groupType);
+    TILING_DATA_FIELD_DEF(uint32_t, singleN);  // if writing sequentially, the value should be zero!
+    TILING_DATA_FIELD_DEF(uint32_t, quantParam);  // quant case: 1 means per-token; antiquant case: the per-group size
+    TILING_DATA_FIELD_DEF(uint32_t, groupListType);
+    TILING_DATA_FIELD_DEF(uint32_t, m);
+    TILING_DATA_FIELD_DEF(uint32_t, hasBias);
+    TILING_DATA_FIELD_DEF(uint64_t, workspaceSize);
+END_TILING_DATA_DEF;
+REGISTER_TILING_DATA_CLASS(GMMBaseParamsOp, GMMBaseParams)
+
+BEGIN_TILING_DATA_DEF(GMMArray)
+    TILING_DATA_FIELD_DEF_ARR(int32_t, 128, mList);  // 128: MAX_TENSOR_CONT
+    TILING_DATA_FIELD_DEF_ARR(int32_t, 128, kList);
+    TILING_DATA_FIELD_DEF_ARR(int32_t, 128, nList);
+END_TILING_DATA_DEF;
+REGISTER_TILING_DATA_CLASS(GMMArrayOp, GMMArray)
+
+BEGIN_TILING_DATA_DEF(GMMTilingData)
+    TILING_DATA_FIELD_DEF_STRUCT(GMMBaseParams, gmmBaseParams);
+    TILING_DATA_FIELD_DEF_STRUCT(GMMArray, gmmArray);
+    TILING_DATA_FIELD_DEF_STRUCT(TCubeTiling, mmTilingData);
+END_TILING_DATA_DEF;
+
+REGISTER_TILING_DATA_CLASS(GroupedMatmul, GMMTilingData)
+}  // namespace optiling
+
+#endif  // AIR_CXX_RUNTIME_V2_OP_IMPL_GROUPED_MATMUL_H
\ No newline at end of file
diff --git a/tests/ut/ops_test/framework/utils/inc/tests/utils/aclnn_tensor_list.h b/tests/ut/ops_test/framework/utils/inc/tests/utils/aclnn_tensor_list.h
new file mode 100644
index 00000000..6eafe8d8
--- /dev/null
+++ b/tests/ut/ops_test/framework/utils/inc/tests/utils/aclnn_tensor_list.h
@@ -0,0 +1,55 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file aclnn_tensor_list.h
+ * \brief Wraps an ACLNN TensorList to simplify TensorList operations in the Tiling and Kernel stages.
+ */
+
+#pragma once
+
+#include <vector>
+#include "tests/utils/tensor_list.h"
+
+namespace ops::adv::tests::utils {
+
+class AclnnTensorList : public ops::adv::tests::utils::TensorIntf {
+public:
+    AclnnTensorList() = default;
+    AclnnTensorList(const char *name, const std::vector<std::vector<int64_t>> &shape, const char *shapeType,
+                    ge::DataType dType, ge::Format format,
+                    TensorType type = TensorType::REQUIRED_INPUT, bool isTrans = false);
+    explicit AclnnTensorList(const TensorList &t, bool isTrans = false);
+    AclnnTensorList(const AclnnTensorList &o) = default;
+    AclnnTensorList &operator=(const AclnnTensorList &o) = default;
+    ~AclnnTensorList() override;
+
+    aclDataType GetAclDataType() const;
+    aclTensorList *GetAclTensorList() const;
+
+    uint8_t *AllocDevData(int32_t initVal, int64_t minSize) override;
+    void FreeDevData() override;
+
+protected:
+    aclDataType aclDataType_ = ACL_DT_UNDEFINED;
+    std::vector<std::vector<int64_t>> aclTensorListDataStrides_;
+    aclTensorList *aclTensorList_ = nullptr;
+
+    uint8_t *AllocDevDataImpl(int64_t size) override;
+    void FreeDevDataImpl(uint8_t *devPtr) override;
+    bool MemSetDevDataImpl(uint8_t *devPtr, int64_t devMax, int32_t val, int64_t cnt) override;
+    bool MemCpyHostToDevDataImpl(uint8_t *devPtr, int64_t devMax, const void *hostPtr, int64_t cnt) override;
+    bool MemCpyDevDataToHostImpl(void *hostPtr, int64_t hostMax, const uint8_t *devPtr, int64_t cnt) override;
+
+private:
+    void Destroy();
+};
+
+}  // namespace ops::adv::tests::utils
diff --git a/tests/ut/ops_test/framework/utils/inc/tests/utils/tensor_intf.h b/tests/ut/ops_test/framework/utils/inc/tests/utils/tensor_intf.h
index 5f7d3488..b8aea710 100644
--- a/tests/ut/ops_test/framework/utils/inc/tests/utils/tensor_intf.h
+++ b/tests/ut/ops_test/framework/utils/inc/tests/utils/tensor_intf.h
@@ -61,6 +61,8 @@ public:
     const std::string &Name() const;
     const gert::Shape &Shape() const;
     const std::vector<int64_t> &ShapeView() const;
+    const std::vector<std::vector<int64_t>> &ShapesView() const;
+
     const std::string &ShapeType() const;
     ge::DataType GetDataType() const;
     ge::Format GetFormat() const;
@@ -70,6 +72,7 @@ public:
     uint8_t *GetDevData() const;
     int64_t GetDevDataSize() const;
 
+    virtual uint8_t *AllocDevDataNz(int32_t initVal, int64_t minSize);
     virtual uint8_t *AllocDevData(int32_t initVal, int64_t minSize);
     virtual void FreeDevData();
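The constructor in the next file builds row-major strides for each tensor and, when isTrans is set, swaps the strides of the innermost two dims. A standalone sketch of that stride computation with an assumed shape [2, 3, 4] (illustrative only):

    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<long> shape = {2, 3, 4};
        std::vector<long> strides(shape.size(), 1);
        // row-major strides, computed innermost-out: {12, 4, 1}
        for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
            strides[i] = shape[i + 1] * strides[i + 1];
        }
        bool isTrans = true;
        if (isTrans && shape.size() >= 2) {
            // transposed layout: the second-to-last dim becomes contiguous -> {12, 1, 3}
            strides[shape.size() - 2] = 1;
            strides[shape.size() - 1] = shape[shape.size() - 2];
        }
        for (auto s : strides) { std::printf("%ld ", s); }
        std::printf("\n");
        return 0;
    }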
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file aclnn_tensor_list.cpp
+ * \brief Wraps the ACLNN TensorList to simplify TensorList handling in the tiling and kernel phases.
+ */
+
+#include "tests/utils/aclnn_tensor_list.h"
+#include <cstdlib>
+#include <map>
+#include <acl/acl.h>
+#include "tests/utils/log.h"
+#include "tests/utils/io.h"
+
+namespace {
+std::map<ge::DataType, aclDataType> geDtype2AclDtypeMap = {
+    {ge::DataType::DT_FLOAT16, ACL_FLOAT16}, {ge::DataType::DT_BF16, ACL_BF16},   {ge::DataType::DT_FLOAT, ACL_FLOAT},
+    {ge::DataType::DT_BOOL, ACL_BOOL},       {ge::DataType::DT_UINT8, ACL_UINT8}, {ge::DataType::DT_INT4, ACL_INT4},
+    {ge::DataType::DT_INT8, ACL_INT8},       {ge::DataType::DT_INT32, ACL_INT32}, {ge::DataType::DT_INT64, ACL_INT64}};
+}
+
+using namespace ops::adv::tests::utils;
+
+AclnnTensorList::AclnnTensorList(const char *name, const std::vector<std::vector<int64_t>> &shapes,
+                                 const char *shapeType, ge::DataType dType, ge::Format format, TensorType type,
+                                 bool isTrans)
+    : TensorIntf(name, {}, shapeType, dType, format, type), aclDataType_(ACL_DT_UNDEFINED),
+      aclTensorListDataStrides_({}), aclTensorList_(nullptr)
+{
+    for (const auto &shape : shapes) {
+        std::vector<int64_t> shapeView{};
+        gert::Shape myShape;
+        for (auto dim : shape) {
+            shapeView.push_back(dim);
+            myShape.AppendDim(dim);
+        }
+        std::vector<int64_t> aclTensorDataStrides{};
+        aclTensorDataStrides.resize(myShape.GetDimNum(), 1);
+        auto dim1 = static_cast<int64_t>(myShape.GetDimNum()) - 1;
+        auto dim2 = static_cast<int64_t>(myShape.GetDimNum()) - 2;
+        for (auto i = dim2; i >= 0; i--) {
+            aclTensorDataStrides[i] = myShape[i + 1] * aclTensorDataStrides[i + 1];
+        }
+        if (isTrans && dim2 >= 0) {
+            aclTensorDataStrides[dim2] = 1;
+            aclTensorDataStrides[dim1] = myShape[dim2];
+        }
+        this->shapes_.push_back(myShape);
+        this->shapesView_.push_back(shapeView);
+        aclTensorListDataStrides_.push_back(aclTensorDataStrides);
+    }
+    this->isArray_ = true;
+    /* Look up the aclDataType matching the ge::DataType */
+    auto iter = geDtype2AclDtypeMap.find(dType);
+    if (iter == geDtype2AclDtypeMap.end()) {
+        LOG_ERR("TensorList(%s), Unknown dtype(%s)", name_.c_str(),
+                ge::TypeUtils::DataTypeToSerialString(dType).c_str());
+    } else {
+        aclDataType_ = iter->second;
+    }
+}
+
+AclnnTensorList::AclnnTensorList(const TensorList &t, bool isTrans)
+    : AclnnTensorList(t.Name().c_str(), t.ShapesView(), t.ShapeType().c_str(), t.GetDataType(), t.GetFormat(),
+                      t.GetTensorType(), isTrans)
+{
+}
+
+AclnnTensorList::~AclnnTensorList()
+{
+    this->Destroy();
+}
+
+aclDataType AclnnTensorList::GetAclDataType() const
+{
+    return aclDataType_;
+}
+
+aclTensorList *AclnnTensorList::GetAclTensorList() const
+{
+    return aclTensorList_;
+}
+
+uint8_t *AclnnTensorList::AllocDevData(int32_t initVal, int64_t minSize)
+{
+    auto size = static_cast<int>(this->shapesView_.size());
+    if (size <= 0) {
+        return nullptr;
+    }
+    aclTensor **tensors = reinterpret_cast<aclTensor **>(malloc(size * sizeof(aclTensor *)));
+    if (tensors == nullptr) {
+        return nullptr;
+    }
+    for (int i = 0; i < size; i++) {
+        aclTensor **tmpTensor = tensors + i;
+        this->shape_ = this->shapes_[i];
+        if (TensorIntf::AllocDevData(initVal, minSize) == nullptr) {
+            free(tensors);
+            return nullptr;
+        }
+        /* Create the aclTensor via aclCreateTensor */
+        *tmpTensor = aclCreateTensor(shapesView_[i].data(), shapesView_[i].size(), aclDataType_,
+                                     aclTensorListDataStrides_[i].data(), 0, aclFormat::ACL_FORMAT_ND,
+                                     shapesView_[i].data(), shapesView_[i].size(), devData_);
+        if (*tmpTensor == nullptr) {
+            LOG_ERR("aclCreateTensor failed, Tensor(%s)", name_.c_str());
+            this->FreeDevData();
+            free(tensors);
+            return nullptr;
+        }
+    }
+    aclTensorList_ = aclCreateTensorList(tensors, size);
+    free(tensors);
+    return devData_;
+}
+
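+/*
+ * Minimal usage sketch (illustrative only; the shapes and names below are assumptions,
+ * not taken from the tests): build the list from per-tensor shapes, allocate device
+ * buffers, then hand the aclTensorList to an aclnn API.
+ *
+ *     AclnnTensorList x("x", {{256, 256}, {1024, 256}}, "", ge::DataType::DT_FLOAT16, ge::FORMAT_ND);
+ *     uint8_t *dev = x.AllocDevData(0, 0);         // also creates the aclTensorList
+ *     aclTensorList *list = x.GetAclTensorList();  // pass to aclnnGroupedMatmul*GetWorkspaceSize
+ *     x.FreeDevData();                             // destroys the list and frees device memory
+ */
+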
+void AclnnTensorList::FreeDevData()
+{
+    this->Destroy();
+}
+
+uint8_t *AclnnTensorList::AllocDevDataImpl(int64_t size)
+{
+    /* Allocate device-side memory via aclrtMalloc */
+    void *devPtr = nullptr;
+    auto ret = aclrtMalloc(&devPtr, size, ACL_MEM_MALLOC_HUGE_FIRST);
+    if (ret != ACL_SUCCESS) {
+        LOG_ERR("aclrtMalloc failed, ERROR: %d, Tensor(%s), Size(%ld)", ret, name_.c_str(), size);
+    }
+    return (uint8_t *)devPtr;
+}
+
+void AclnnTensorList::FreeDevDataImpl(uint8_t *devPtr)
+{
+    auto ret = aclrtFree(devPtr);
+    LOG_IF(ret != ACL_SUCCESS, LOG_ERR("aclrtFree failed, ERROR: %d, Tensor(%s)", ret, name_.c_str()));
+}
+
+bool AclnnTensorList::MemSetDevDataImpl(uint8_t *devPtr, int64_t devMax, int32_t val, int64_t cnt)
+{
+    /* Set the initial value via aclrtMemset */
+    auto ret = aclrtMemset(devPtr, devMax, val, cnt);
+    LOG_IF(ret != ACL_SUCCESS, LOG_ERR("aclrtMemset failed, ERROR: %d, Tensor(%s)", ret, name_.c_str()));
+    return ret == ACL_SUCCESS;
+}
+
+bool AclnnTensorList::MemCpyHostToDevDataImpl(uint8_t *devPtr, int64_t devMax, const void *hostPtr, int64_t cnt)
+{
+    /* Copy host to device via aclrtMemcpy */
+    auto ret = aclrtMemcpy(devPtr, devMax, hostPtr, cnt, ACL_MEMCPY_HOST_TO_DEVICE);
+    LOG_IF(ret != ACL_SUCCESS, LOG_ERR("aclrtMemcpy failed, ERROR: %d, Tensor(%s)", ret, name_.c_str()));
+    return ret == ACL_SUCCESS;
+}
+
+bool AclnnTensorList::MemCpyDevDataToHostImpl(void *hostPtr, int64_t hostMax, const uint8_t *devPtr, int64_t cnt)
+{
+    /* Copy device to host via aclrtMemcpy */
+    auto ret = aclrtMemcpy(hostPtr, hostMax, devPtr, cnt, ACL_MEMCPY_DEVICE_TO_HOST);
+    LOG_IF(ret != ACL_SUCCESS, LOG_ERR("aclrtMemcpy failed, ERROR: %d, Tensor(%s)", ret, name_.c_str()));
+    return ret == ACL_SUCCESS;
+}
+
+void AclnnTensorList::Destroy()
+{
+    if (aclTensorList_ != nullptr) {
+        auto ret = aclDestroyTensorList(aclTensorList_);
+        LOG_IF(ret != ACL_SUCCESS,
+               LOG_ERR("aclDestroyTensorList failed, ERROR: %d, TensorList(%s)", ret, name_.c_str()));
+        aclTensorList_ = nullptr;
+    }
+    TensorIntf::FreeDevData();
+}
diff --git a/tests/ut/ops_test/framework/utils/src/tensor_intf.cpp b/tests/ut/ops_test/framework/utils/src/tensor_intf.cpp
index 389d9758..279b463f 100644
--- a/tests/ut/ops_test/framework/utils/src/tensor_intf.cpp
+++ b/tests/ut/ops_test/framework/utils/src/tensor_intf.cpp
@@ -83,7 +83,7 @@ std::string TensorIntf::GetTilingStr() const
               R"(", )" + R"("shape": )" + TensorIntf::ToString(this->shapes_[i]) + R"(, )" + R"("ori_shape": )" +
               TensorIntf::ToString(this->shapes_[i]) + R"( })";
         finalStr += str;
-        finalStr += ";";
+        finalStr += ",";
     }
     str = R"({ )"
           R"("name": ")" +
@@ -112,6 +112,11 @@ const std::vector<int64_t> &TensorIntf::ShapeView() const
     return shapeView_;
 }
 
+const std::vector<std::vector<int64_t>> &TensorIntf::ShapesView() const
+{
+    return shapesView_;
+}
+
 const std::string &TensorIntf::ShapeType() const
 {
     return shapeType_;
@@ -163,6 +168,25 @@ int64_t TensorIntf::GetDevDataSize() const
     return devDataSize_;
 }
 
+uint8_t *TensorIntf::AllocDevDataNz(int32_t initVal, int64_t minSize)
+{
+    if (devData_ != nullptr) {
+        return devData_;
+    }
+    devDataSize_ = std::max(this->GetExpDataSize(), minSize);
+    devData_ = this->AllocDevDataImpl(devDataSize_);
+    if (devData_ == nullptr) {
+        goto ErrRet;
+    }
+    if (!this->MemSetDevDataImpl(devData_, devDataSize_, initVal, devDataSize_)) {
+        goto ErrRet;
+    }
+    return devData_;
+ErrRet:
+    this->FreeDevData();
+    return nullptr;
+}
+
 uint8_t *TensorIntf::AllocDevData(int32_t initVal, int64_t minSize)
 {
     if (devData_ != nullptr) {
diff --git a/tests/ut/ops_test/src/transformer/CMakeLists.txt b/tests/ut/ops_test/src/transformer/CMakeLists.txt
index 5e2aa44b..d072d7a0 100644
--- a/tests/ut/ops_test/src/transformer/CMakeLists.txt
+++ b/tests/ut/ops_test/src/transformer/CMakeLists.txt
@@ -10,6 +10,7 @@
 set(flash_attention_score_alias flash_attention)
 set(flash_attention_score_grad_alias flash_attention)
 set(ffn_alias ffn)
+set(grouped_matmul_alias grouped_matmul)
 set(incre_flash_attention incre_flash_attention)
 set(prompt_flash_attention prompt_flash_attention)
 set(fused_infer_attention_score fused_infer_attention_score)
diff --git a/tests/ut/ops_test/src/transformer/grouped_matmul/CMakeLists.txt b/tests/ut/ops_test/src/transformer/grouped_matmul/CMakeLists.txt
new file mode 100644
index 00000000..b48df4c5
--- /dev/null
+++ b/tests/ut/ops_test/src/transformer/grouped_matmul/CMakeLists.txt
@@ -0,0 +1,73 @@
+# Copyright (c) 2024 Huawei Technologies Co., Ltd.
+# This file is a part of the CANN Open Software.
+# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+# ======================================================================================================================
+
+########################################################################################################################
+# Invoke the build helpers to generate the corresponding build targets
+########################################################################################################################
+set(_GMM_TilingSourcesExt
+    ${CMAKE_SOURCE_DIR}/src/transformer/grouped_matmul/ophost/grouped_matmul_tiling.cpp
+)
+
+set(_GMM_TilingPrivateIncludesExt
+    ${CMAKE_SOURCE_DIR}/src/transformer/grouped_matmul/ophost
+)
+
+set(_GMM_KernelTilingDataDefH
+    ${CMAKE_SOURCE_DIR}/src/transformer/grouped_matmul/ophost/grouped_matmul_tiling.h
+)
+
+set(_GMM_OpApiSourcesExt
+    ${CMAKE_SOURCE_DIR}/src/transformer/grouped_matmul/ophost/grouped_matmul.cpp
+    ${CMAKE_SOURCE_DIR}/src/transformer/grouped_matmul/ophost/aclnn_grouped_matmul.cpp
+)
+
+set(_GMM_OpProtoSourcesExt
+    ${CMAKE_SOURCE_DIR}/src/transformer/grouped_matmul/ophost/grouped_matmul_proto.cpp
+)
+
+aux_source_directory(${CMAKE_SOURCE_DIR}/src/transformer/grouped_matmul _GMM_KernelSourcesExt)
+set(_GMM_KernelPrivateCompileDefinitionsExt
+    KernelCtrlParam grouped_matmul bf16 DTYPE_X=bfloat16_t DTYPE_WEIGHT=bfloat16_t DTYPE_BIAS=float DTYPE_Y=bfloat16_t ORIG_DTYPE_X=DT_BF16 ORIG_DTYPE_WEIGHT=DT_BF16 ORIG_DTYPE_Y=DT_BF16
+    KernelCtrlParam grouped_matmul fp16 DTYPE_X=half DTYPE_WEIGHT=half DTYPE_BIAS=half DTYPE_Y=half ORIG_DTYPE_X=DT_FLOAT16 ORIG_DTYPE_WEIGHT=DT_FLOAT16 ORIG_DTYPE_Y=DT_FLOAT16
+    KernelCtrlParam grouped_matmul fp32 DTYPE_X=float DTYPE_WEIGHT=float DTYPE_BIAS=float DTYPE_Y=float ORIG_DTYPE_X=DT_FLOAT ORIG_DTYPE_WEIGHT=DT_FLOAT ORIG_DTYPE_Y=DT_FLOAT
+    KernelCtrlParam grouped_matmul quant_int8 DTYPE_X=int8_t DTYPE_WEIGHT=int8_t DTYPE_BIAS=int32_t DTYPE_SCALE=uint64_t DTYPE_Y=int8_t ORIG_DTYPE_X=DT_INT8 ORIG_DTYPE_WEIGHT=DT_INT8 ORIG_DTYPE_Y=DT_INT8
+    KernelCtrlParam grouped_matmul quant_bf16 DTYPE_X=int8_t DTYPE_WEIGHT=int8_t DTYPE_BIAS=int32_t DTYPE_SCALE=bfloat16_t DTYPE_Y=bfloat16_t ORIG_DTYPE_X=DT_INT8 ORIG_DTYPE_WEIGHT=DT_INT8 ORIG_DTYPE_Y=DT_BF16
+    KernelCtrlParam grouped_matmul quant_fp16 DTYPE_X=int8_t DTYPE_WEIGHT=int8_t DTYPE_BIAS=int32_t DTYPE_SCALE=float DTYPE_Y=half ORIG_DTYPE_X=DT_INT8 ORIG_DTYPE_WEIGHT=DT_INT8 ORIG_DTYPE_Y=DT_FLOAT16
+    KernelCtrlParam grouped_matmul a16w8_bf16 DTYPE_X=bfloat16_t DTYPE_WEIGHT=int8_t DTYPE_BIAS=float DTYPE_Y=bfloat16_t ORIG_DTYPE_X=DT_BF16 ORIG_DTYPE_WEIGHT=DT_INT8 ORIG_DTYPE_Y=DT_BF16
+    KernelCtrlParam grouped_matmul a16w8_fp16 DTYPE_X=half DTYPE_WEIGHT=int8_t DTYPE_BIAS=half DTYPE_Y=half ORIG_DTYPE_X=DT_FLOAT16 ORIG_DTYPE_WEIGHT=DT_INT8 ORIG_DTYPE_Y=DT_FLOAT16
+    KernelCtrlParam grouped_matmul a16w4_bf16 DTYPE_X=bfloat16_t DTYPE_WEIGHT=int4b_t DTYPE_BIAS=float DTYPE_Y=bfloat16_t ORIG_DTYPE_X=DT_BF16 ORIG_DTYPE_WEIGHT=DT_INT4 ORIG_DTYPE_Y=DT_BF16
+    KernelCtrlParam grouped_matmul a16w4_fp16 DTYPE_X=half DTYPE_WEIGHT=int4b_t DTYPE_BIAS=half DTYPE_Y=half ORIG_DTYPE_X=DT_FLOAT16 ORIG_DTYPE_WEIGHT=DT_INT4 ORIG_DTYPE_Y=DT_FLOAT16
+)
+
+set(_GMM_TargetPrivateIncludeExt
+    ${CMAKE_SOURCE_DIR}/src/transformer/grouped_matmul/ophost
+)
+
+set(_GMM_TilingPrivateLinkLibrariesExt
+)
+
+set(_GMM_TargetPrivateLinkLibrariesExt
+    ops_utils_tiling_headers
+)
+
+OpsTest_Level2_AddOp(
+    SUB_SYSTEM transformer
+    BRIEF GMM
+    SNAKE grouped_matmul
+    OPAPI_SOURCES_EXT ${_GMM_OpApiSourcesExt}
+    PROTO_SOURCES_EXT ${_GMM_OpProtoSourcesExt}
+    TILING_SOURCES_EXT ${_GMM_TilingSourcesExt}
+    TILING_PRIVATE_INCLUDES_EXT ${_GMM_TilingPrivateIncludesExt}
+    TILING_PRIVATE_LINK_LIBRARIES_EXT ${_GMM_TilingPrivateLinkLibrariesExt}
+    KERNEL_SOURCES_EXT ${_GMM_KernelSourcesExt}
+    KERNEL_TILING_DATA_DEF_H ${_GMM_KernelTilingDataDefH}
+    KERNEL_PRIVATE_COMPILE_DEFINITIONS_EXT ${_GMM_KernelPrivateCompileDefinitionsExt}
+    UTEST_COMMON_PRIVATE_INCLUDES_EXT ${_GMM_TargetPrivateIncludeExt}
+    UTEST_COMMON_PRIVATE_LINK_LIBRARIES_EXT ${_GMM_TargetPrivateLinkLibrariesExt}
+)
\ No newline at end of file
diff --git a/tests/ut/ops_test/src/transformer/grouped_matmul/comm/inc/aclnn_grouped_matmul_case.h b/tests/ut/ops_test/src/transformer/grouped_matmul/comm/inc/aclnn_grouped_matmul_case.h
new file mode 100644
index 00000000..489ded58
--- /dev/null
+++ b/tests/ut/ops_test/src/transformer/grouped_matmul/comm/inc/aclnn_grouped_matmul_case.h
@@ -0,0 +1,52 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file aclnn_grouped_matmul_case.h
+ * \brief GroupedMatmul Aclnn test cases.
+ */
+
+#ifndef UTEST_ACLNN_GROUPED_MATMUL_CASE_H
+#define UTEST_ACLNN_GROUPED_MATMUL_CASE_H
+
+#include "grouped_matmul_case.h"
+#include "tests/utils/op_info.h"
+#include "tests/utils/aclnn_context.h"
+#include "aclnn_grouped_matmul_param.h"
+
+namespace ops::adv::tests::grouped_matmul {
+using AclnnGroupedMatmulVersion = ops::adv::tests::grouped_matmul::AclnnGroupedMatmulParam::AclnnGroupedMatmulVersion;
+
+class AclnnGroupedMatmulCase : public ops::adv::tests::grouped_matmul::GroupedMatmulCase {
+public:
+    using AclnnContext = ops::adv::tests::utils::AclnnContext;
+
+public:
+    /* Operator control info */
+    AclnnContext mAclnnCtx;
+
+    /* Input/output parameters */
+    AclnnGroupedMatmulParam mAclnnParam;
+
+public:
+    AclnnGroupedMatmulCase();
+    AclnnGroupedMatmulCase(const char *name, bool enable, const char *dbgInfo, OpInfo opInfo,
+                           AclnnGroupedMatmulParam param,
+                           int32_t tilingTemplatePriority = kTilingTemplatePriority_Invalid);
+    bool Run() override;
+
+protected:
+    bool InitParam() override;
+    bool InitOpInfo() override;
+    bool InitCurrentCasePtr() override;
+};
+
+} // namespace ops::adv::tests::grouped_matmul
+#endif // UTEST_ACLNN_GROUPED_MATMUL_CASE_H
diff --git a/tests/ut/ops_test/src/transformer/grouped_matmul/comm/inc/aclnn_grouped_matmul_param.h b/tests/ut/ops_test/src/transformer/grouped_matmul/comm/inc/aclnn_grouped_matmul_param.h
new file mode 100644
index 00000000..8f330de3
--- /dev/null
+++ b/tests/ut/ops_test/src/transformer/grouped_matmul/comm/inc/aclnn_grouped_matmul_param.h
@@ -0,0 +1,70 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file aclnn_grouped_matmul_param.h
+ * \brief GroupedMatmul Aclnn parameter info.
+ */
+
+#ifndef UTEST_ACLNN_GROUPED_MATMUL_PARAM_H
+#define UTEST_ACLNN_GROUPED_MATMUL_PARAM_H
+
+#include "grouped_matmul_case.h"
+#include "tests/utils/aclnn_tensor.h"
+#include "tests/utils/aclnn_tensor_list.h"
+
+namespace ops::adv::tests::grouped_matmul {
+
+class AclnnGroupedMatmulParam : public ops::adv::tests::grouped_matmul::Param {
+public:
+    using AclnnTensor = ops::adv::tests::utils::AclnnTensor;
+    using AclnnTensorList = ops::adv::tests::utils::AclnnTensorList;
+
+public:
+    enum class FunctionType {
+        NO_QUANT,
+        QUANT,
+        ANTIQUANT,
+        QUANT_PERTOKEN
+    };
+
+    enum class AclnnGroupedMatmulVersion {
+        V1,
+        V2,
+        V3,
+        V4
+    };
+
+public:
+    FunctionType mFunctionType = FunctionType::NO_QUANT;
+    AclnnGroupedMatmulVersion mAclnnGroupedMatmulVersion = AclnnGroupedMatmulVersion::V1;
+    /* Inputs/outputs */
+    AclnnTensorList aclnnX, aclnnWeight, aclnnBias, aclnnScale, aclnnOffset, aclnnAntiquantScale, aclnnAntiquantOffset,
+        aclnnPerTokenScale, aclnnY;
+    AclnnTensor aclnnGroupListTensor;
+    aclIntArray *aclnnGroupListIntAry = nullptr;
+
+public:
+    AclnnGroupedMatmulParam() = default;
+    AclnnGroupedMatmulParam(std::vector<TensorList> inputs, Tensor groupList, std::vector<int64_t> groupListData,
+                            std::int32_t splitItem, int32_t dtype, bool transposeWeight, bool transposeX,
+                            int32_t groupType, int32_t groupListType, int32_t actType, FunctionType functionType,
+                            AclnnGroupedMatmulVersion aclnnGroupedMatmulVersion);
+
+    ~AclnnGroupedMatmulParam();
+
+    bool Init();
+
+private:
+    bool InitGroupList();
+};
+
+} // namespace ops::adv::tests::grouped_matmul
+#endif // UTEST_ACLNN_GROUPED_MATMUL_PARAM_H
diff --git a/tests/ut/ops_test/src/transformer/grouped_matmul/comm/inc/grouped_matmul_case.h b/tests/ut/ops_test/src/transformer/grouped_matmul/comm/inc/grouped_matmul_case.h
new file mode 100644
index 00000000..f376f096
--- /dev/null
+++ b/tests/ut/ops_test/src/transformer/grouped_matmul/comm/inc/grouped_matmul_case.h
@@ -0,0 +1,54 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file grouped_matmul_case.h
+ * \brief GroupedMatmul test cases.
+ */
+
+#ifndef UTEST_GROUPED_MATMUL_CASE_H
+#define UTEST_GROUPED_MATMUL_CASE_H
+
+#include <cstdint>
+#include <vector>
+
+#include "tests/utils/case.h"
+#include "tests/utils/tensor.h"
+#include "tests/utils/context.h"
+#include "tests/utils/op_info.h"
+#include "grouped_matmul_param.h"
+
+namespace ops::adv::tests::grouped_matmul {
+
+using ops::adv::tests::grouped_matmul::Param;
+using ops::adv::tests::utils::Context;
+using ops::adv::tests::utils::OpInfo;
+
+class GroupedMatmulCase : public ops::adv::tests::utils::Case {
+public:
+    OpInfo mOpInfo;
+    Context mCtx;
+    Param mParam;
+
+public:
+    GroupedMatmulCase();
+    GroupedMatmulCase(const char *name, bool enable, const char *dbgInfo, OpInfo opInfo, Param param,
+                      int32_t tilingTemplatePriority);
+
+    bool Run() override;
+
+protected:
+    bool InitParam() override;
+    bool InitOpInfo() override;
+    bool InitCurrentCasePtr() override;
+};
+
+} // namespace ops::adv::tests::grouped_matmul
+#endif // UTEST_GROUPED_MATMUL_CASE_H
\ No newline at end of file
diff --git a/tests/ut/ops_test/src/transformer/grouped_matmul/comm/inc/grouped_matmul_param.h b/tests/ut/ops_test/src/transformer/grouped_matmul/comm/inc/grouped_matmul_param.h
new file mode 100644
index 00000000..9a15b993
--- /dev/null
+++ b/tests/ut/ops_test/src/transformer/grouped_matmul/comm/inc/grouped_matmul_param.h
@@ -0,0 +1,56 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file grouped_matmul_param.h
+ * \brief GroupedMatmul parameter info.
+ */
+
+#ifndef UTEST_GROUPED_MATMUL_PARAM_H
+#define UTEST_GROUPED_MATMUL_PARAM_H
+
+#include <map>
+#include "tests/utils/tensor.h"
+#include "tests/utils/tensor_list.h"
+
+namespace ops::adv::tests::grouped_matmul {
+
+using ops::adv::tests::utils::Tensor;
+using ops::adv::tests::utils::TensorList;
+
+class Param {
+public:
+    std::map<std::string, TensorList> mTensorLists;
+    Tensor mGroupList;
+    Tensor mPerTokenScale;
+    std::vector<int64_t> mGroupListData = {};
+    int32_t mSplitItem = 0;
+    int32_t mDtype = 0;
+    bool mTransposeWeight = false;
+    bool mTransposeX = false;
+    int32_t mGroupType = 0;
+    int32_t mGroupListType = 0;
+    int32_t mActType = 0;
+
+public:
+    Param() = default;
+    Param(std::vector<TensorList> inputs, Tensor perTokenScale, Tensor groupList,
+          std::vector<int64_t> groupListData, int32_t splitItem, int32_t dType,
+          bool transposeWeight, bool transposeX, int32_t groupType, int32_t groupListType, int32_t actType);
+};
+
+Tensor GenTensor(const char *name, const std::initializer_list<int64_t> &shape, ge::DataType dType,
+                 ge::Format format = ge::FORMAT_ND);
+
+TensorList GenTensorList(const char *name, const std::vector<std::vector<int64_t>> &shapes, ge::DataType dType,
+                         ge::Format format = ge::FORMAT_ND);
+
+} // namespace ops::adv::tests::grouped_matmul
+#endif // UTEST_GROUPED_MATMUL_PARAM_H
\ No newline at end of file
diff --git a/tests/ut/ops_test/src/transformer/grouped_matmul/comm/src/aclnn_grouped_matmul_case.cpp b/tests/ut/ops_test/src/transformer/grouped_matmul/comm/src/aclnn_grouped_matmul_case.cpp
new file mode 100644
index 00000000..0766c0cb
--- /dev/null
+++ b/tests/ut/ops_test/src/transformer/grouped_matmul/comm/src/aclnn_grouped_matmul_case.cpp
@@ -0,0 +1,141 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file aclnn_grouped_matmul_case.cpp
+ * \brief GroupedMatmul Aclnn test cases.
+ */
+
+#include <utility>
+#include "tests/utils/log.h"
+#include "aclnn_grouped_matmul.h"
+#include "aclnn_grouped_matmul_v2.h"
+#include "aclnn_grouped_matmul_v3.h"
+#include "aclnn_grouped_matmul_v4.h"
+#include "aclnn_grouped_matmul_case.h"
+
+using namespace ops::adv::tests::grouped_matmul;
+
+bool GroupedMatmulTilingRunCbf(void *curCase, uint64_t *workSpaceSize, aclOpExecutor **opExecutor)
+{
+    auto *cs = static_cast<AclnnGroupedMatmulCase *>(curCase);
+    auto *aclnnParam = &cs->mAclnnParam;
+
+    aclnnStatus ret = ACL_SUCCESS;
+    if (aclnnParam->mAclnnGroupedMatmulVersion == AclnnGroupedMatmulVersion::V1) {
+        ret = aclnnGroupedMatmulGetWorkspaceSize(
+            aclnnParam->aclnnX.GetAclTensorList(), aclnnParam->aclnnWeight.GetAclTensorList(),
+            aclnnParam->aclnnBias.GetAclTensorList(), aclnnParam->aclnnScale.GetAclTensorList(),
+            aclnnParam->aclnnOffset.GetAclTensorList(), aclnnParam->aclnnAntiquantScale.GetAclTensorList(),
+            aclnnParam->aclnnAntiquantOffset.GetAclTensorList(), aclnnParam->aclnnGroupListIntAry,
+            aclnnParam->mSplitItem, aclnnParam->aclnnY.GetAclTensorList(), workSpaceSize, opExecutor);
+    } else if (aclnnParam->mAclnnGroupedMatmulVersion == AclnnGroupedMatmulVersion::V2) {
+        ret = aclnnGroupedMatmulV2GetWorkspaceSize(
+            aclnnParam->aclnnX.GetAclTensorList(), aclnnParam->aclnnWeight.GetAclTensorList(),
+            aclnnParam->aclnnBias.GetAclTensorList(), aclnnParam->aclnnScale.GetAclTensorList(),
+            aclnnParam->aclnnOffset.GetAclTensorList(), aclnnParam->aclnnAntiquantScale.GetAclTensorList(),
+            aclnnParam->aclnnAntiquantOffset.GetAclTensorList(), aclnnParam->aclnnGroupListIntAry,
+            aclnnParam->mSplitItem, aclnnParam->mGroupType, aclnnParam->aclnnY.GetAclTensorList(),
+            workSpaceSize, opExecutor);
+    } else if (aclnnParam->mAclnnGroupedMatmulVersion == AclnnGroupedMatmulVersion::V3) {
+        ret = aclnnGroupedMatmulV3GetWorkspaceSize(
+            aclnnParam->aclnnX.GetAclTensorList(), aclnnParam->aclnnWeight.GetAclTensorList(),
+            aclnnParam->aclnnBias.GetAclTensorList(), aclnnParam->aclnnScale.GetAclTensorList(),
+            aclnnParam->aclnnOffset.GetAclTensorList(), aclnnParam->aclnnAntiquantScale.GetAclTensorList(),
+            aclnnParam->aclnnAntiquantOffset.GetAclTensorList(), aclnnParam->aclnnGroupListTensor.GetAclTensor(),
+            aclnnParam->mSplitItem, aclnnParam->mGroupType, aclnnParam->aclnnY.GetAclTensorList(),
+            workSpaceSize, opExecutor);
+    } else if (aclnnParam->mAclnnGroupedMatmulVersion == AclnnGroupedMatmulVersion::V4) {
+        ret = aclnnGroupedMatmulV4GetWorkspaceSize(
+            aclnnParam->aclnnX.GetAclTensorList(), aclnnParam->aclnnWeight.GetAclTensorList(),
+            aclnnParam->aclnnBias.GetAclTensorList(), aclnnParam->aclnnScale.GetAclTensorList(),
+            aclnnParam->aclnnOffset.GetAclTensorList(), aclnnParam->aclnnAntiquantScale.GetAclTensorList(),
+            aclnnParam->aclnnAntiquantOffset.GetAclTensorList(), aclnnParam->aclnnPerTokenScale.GetAclTensorList(),
+            aclnnParam->aclnnGroupListTensor.GetAclTensor(), nullptr, nullptr, nullptr,
+            aclnnParam->mSplitItem, aclnnParam->mGroupType, aclnnParam->mGroupListType, aclnnParam->mActType,
+            aclnnParam->aclnnY.GetAclTensorList(), nullptr, nullptr, workSpaceSize, opExecutor);
+    }
+    return ret == ACL_SUCCESS;
+}
+
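+/*
+ * GroupedMatmulTilingRunCbf/GroupedMatmulKernelRunCbf split the standard two-phase aclnn
+ * calling pattern across the test framework. A minimal standalone sketch of that pattern
+ * (V1 signature as used above; the tensor lists, stream creation and error handling are
+ * assumed to exist elsewhere):
+ *
+ *     uint64_t wsSize = 0;
+ *     aclOpExecutor *executor = nullptr;
+ *     aclnnStatus ret = aclnnGroupedMatmulGetWorkspaceSize(x, weight, bias, scale, offset,
+ *         antiquantScale, antiquantOffset, groupList, splitItem, y, &wsSize, &executor);
+ *     void *ws = nullptr;
+ *     if (ret == ACL_SUCCESS && wsSize > 0) {
+ *         ret = aclrtMalloc(&ws, wsSize, ACL_MEM_MALLOC_HUGE_FIRST);
+ *     }
+ *     if (ret == ACL_SUCCESS) {
+ *         ret = aclnnGroupedMatmul(ws, wsSize, executor, stream);
+ *     }
+ */
+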
+bool GroupedMatmulKernelRunCbf(void *curCase)
+{
+    auto *cs = static_cast<AclnnGroupedMatmulCase *>(curCase);
+    auto *aclnnParam = &cs->mAclnnParam;
+    auto *aclnnCtx = &cs->mAclnnCtx;
+
+    aclnnStatus ret = ACL_SUCCESS;
+    if (aclnnParam->mAclnnGroupedMatmulVersion == AclnnGroupedMatmulVersion::V1) {
+        ret = aclnnGroupedMatmul(aclnnCtx->GetWorkspacePtr(), aclnnCtx->GetWorkspaceSize(),
+                                 aclnnCtx->GetAclOpExecutor(), aclnnCtx->GetAclRtStream());
+    } else if (aclnnParam->mAclnnGroupedMatmulVersion == AclnnGroupedMatmulVersion::V2) {
+        ret = aclnnGroupedMatmulV2(aclnnCtx->GetWorkspacePtr(), aclnnCtx->GetWorkspaceSize(),
+                                   aclnnCtx->GetAclOpExecutor(), aclnnCtx->GetAclRtStream());
+    } else if (aclnnParam->mAclnnGroupedMatmulVersion == AclnnGroupedMatmulVersion::V3) {
+        ret = aclnnGroupedMatmulV3(aclnnCtx->GetWorkspacePtr(), aclnnCtx->GetWorkspaceSize(),
+                                   aclnnCtx->GetAclOpExecutor(), aclnnCtx->GetAclRtStream());
+    } else if (aclnnParam->mAclnnGroupedMatmulVersion == AclnnGroupedMatmulVersion::V4) {
+        ret = aclnnGroupedMatmulV4(aclnnCtx->GetWorkspacePtr(), aclnnCtx->GetWorkspaceSize(),
+                                   aclnnCtx->GetAclOpExecutor(), aclnnCtx->GetAclRtStream());
+    }
+    LOG_IF(ret != ACL_SUCCESS, LOG_ERR("aclnnGroupedMatmul failed, ERROR: %d", ret));
+
+    return ret == ACL_SUCCESS;
+}
+
+AclnnGroupedMatmulCase::AclnnGroupedMatmulCase()
+    : GroupedMatmulCase(), mAclnnCtx(AclnnContext()), mAclnnParam(AclnnGroupedMatmulParam())
+{
+}
+
+AclnnGroupedMatmulCase::AclnnGroupedMatmulCase(const char *name, bool enable, const char *dbgInfo, OpInfo opInfo,
+                                               AclnnGroupedMatmulParam aclnnParam, int32_t tilingTemplatePriority)
+    : GroupedMatmulCase(name, enable, dbgInfo, std::move(opInfo), Param(), tilingTemplatePriority),
+      mAclnnParam(std::move(aclnnParam))
+{
+}
+
+bool AclnnGroupedMatmulCase::InitParam()
+{
+    return mAclnnParam.Init();
+}
+
+bool AclnnGroupedMatmulCase::InitOpInfo()
+{
+    if (!GroupedMatmulCase::InitOpInfo()) {
+        return false;
+    }
+
+    auto rst = mAclnnCtx.SetOpName(this->mOpInfo.mName.c_str());
+    rst = rst && mAclnnCtx.SetTilingRunCbf(GroupedMatmulTilingRunCbf);
+    rst = rst && mAclnnCtx.SetKernelRunCbf(GroupedMatmulKernelRunCbf);
+    rst = rst && mOpInfo.SetContext(&mAclnnCtx);
+    return rst;
+}
+
+bool AclnnGroupedMatmulCase::InitCurrentCasePtr()
+{
+    Case::mCurrentCasePtr = this;
+    return true;
+}
+
+bool AclnnGroupedMatmulCase::Run()
+{
+    if (!mEnable) {
+        return true;
+    }
+    if (!mOpInfo.ProcessTiling(mName)) {
+        return false;
+    }
+    if (!mOpInfo.ProcessKernel(mName)) {
+        return false;
+    }
+    return true;
+}
diff --git a/tests/ut/ops_test/src/transformer/grouped_matmul/comm/src/aclnn_grouped_matmul_param.cpp b/tests/ut/ops_test/src/transformer/grouped_matmul/comm/src/aclnn_grouped_matmul_param.cpp
new file mode 100644
index 00000000..86790c4c
--- /dev/null
+++ b/tests/ut/ops_test/src/transformer/grouped_matmul/comm/src/aclnn_grouped_matmul_param.cpp
@@ -0,0 +1,140 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file aclnn_grouped_matmul_param.cpp
+ * \brief GroupedMatmul Aclnn parameter info.
+ */
+
+#include "aclnn_grouped_matmul_param.h"
+#include <utility>
+#include "tests/utils/case.h"
+#include "tests/utils/io.h"
+#include "tests/utils/log.h"
+
+using ops::adv::tests::utils::ReadFile;
+using ops::adv::tests::utils::WriteFile;
+using ops::adv::tests::utils::TensorIntf;
+
+namespace {
+template <typename T> bool InitAclIntArray(aclIntArray **intArray, std::vector<T> &hostData)
+{
+    if (intArray == nullptr) {
+        LOG_ERR("intArray nil.");
+        return false;
+    }
+    if (*intArray != nullptr) {
+        auto ret = aclDestroyIntArray(*intArray);
+        LOG_IF_EXPR(ret != ACL_SUCCESS, LOG_ERR("aclDestroyIntArray failed, ERROR: %d", ret), *intArray = nullptr);
+    }
+    if (hostData.empty()) {
+        return true;
+    }
+    *intArray = aclCreateIntArray(hostData.data(), hostData.size());
+    if (*intArray == nullptr) {
+        LOG_ERR("aclCreateIntArray failed.");
+        return false;
+    }
+    return true;
+}
+} // namespace
+
+using namespace ops::adv::tests::grouped_matmul;
+
+AclnnGroupedMatmulParam::AclnnGroupedMatmulParam(std::vector<TensorList> inputs, Tensor groupList,
+    std::vector<int64_t> groupListData, std::int32_t splitItem, int32_t dtype, bool transposeWeight,
+    bool transposeX, int32_t groupType, int32_t groupListType, int32_t actType, FunctionType functionType,
+    AclnnGroupedMatmulVersion aclnnGMMVersion)
+    : Param(std::move(inputs), Tensor(), std::move(groupList), std::move(groupListData), splitItem, dtype,
+            transposeWeight, transposeX, groupType, groupListType, actType),
+      mFunctionType(functionType), mAclnnGroupedMatmulVersion(aclnnGMMVersion)
+{
+}
+
+AclnnGroupedMatmulParam::~AclnnGroupedMatmulParam()
+{
+    if (aclnnGroupListIntAry != nullptr) {
+        auto ret = aclDestroyIntArray(aclnnGroupListIntAry);
+        LOG_IF_EXPR(ret != ACL_SUCCESS, LOG_ERR("aclDestroyIntArray failed, ERROR: %d", ret),
+                    aclnnGroupListIntAry = nullptr);
+    }
+}
+
+bool AclnnGroupedMatmulParam::Init()
+{
+    aclnnX = ops::adv::tests::utils::AclnnTensorList(mTensorLists["x"], mTransposeX);
+    aclnnWeight = ops::adv::tests::utils::AclnnTensorList(mTensorLists["weight"], mTransposeWeight);
+    auto iter = mTensorLists.find("bias");
+    if (iter != mTensorLists.end()) {
+        aclnnBias = ops::adv::tests::utils::AclnnTensorList(mTensorLists["bias"]);
+    }
+    aclnnY = ops::adv::tests::utils::AclnnTensorList(mTensorLists["y"]);
+    if (mFunctionType == FunctionType::QUANT) {
+        aclnnScale = ops::adv::tests::utils::AclnnTensorList(mTensorLists["scale"]);
+        iter = mTensorLists.find("offset");
+        if (iter != mTensorLists.end()) {
+            aclnnOffset = ops::adv::tests::utils::AclnnTensorList(mTensorLists["offset"]);
+        }
+    } else if (mFunctionType == FunctionType::ANTIQUANT) {
+        aclnnAntiquantScale = ops::adv::tests::utils::AclnnTensorList(mTensorLists["antiquant_scale"]);
+        aclnnAntiquantOffset = ops::adv::tests::utils::AclnnTensorList(mTensorLists["antiquant_offset"]);
+    } else if (mFunctionType == FunctionType::QUANT_PERTOKEN) {
+        aclnnScale = ops::adv::tests::utils::AclnnTensorList(mTensorLists["scale"]);
+        aclnnPerTokenScale = ops::adv::tests::utils::AclnnTensorList(mTensorLists["pertoken_scale"]);
+        iter = mTensorLists.find("offset");
+        if (iter != mTensorLists.end()) {
+            aclnnOffset = ops::adv::tests::utils::AclnnTensorList(mTensorLists["offset"]);
+        }
+    }
+    auto ret = InitGroupList();
+    LOG_IF_EXPR(!ret, LOG_ERR("InitGroupList failed"), return false);
+    auto *cs = static_cast<ops::adv::tests::utils::Case *>(ops::adv::tests::utils::Case::GetCurrentCase());
+    LOG_IF_EXPR(cs == nullptr, LOG_ERR("Can't get current case"), return false);
+    for (auto *t : std::vector<TensorIntf *>{&aclnnX, &aclnnWeight, &aclnnBias, &aclnnScale, &aclnnOffset,
+                                             &aclnnAntiquantScale, &aclnnAntiquantOffset, &aclnnGroupListTensor,
+                                             &aclnnPerTokenScale, &aclnnY}) {
+        t->FreeDevData();
+        if (t->GetExpDataSize() <= 0) {
+            continue;
+        }
+        auto *devData = t->AllocDevData(0, 0);
+        if (devData == nullptr) {
+            return false;
+        }
+        std::string filePath = std::string(cs->GetRootPath()) + t->Name() + ".bin";
+        if (ops::adv::tests::utils::FileExist(filePath)) {
+            if (!t->LoadFileToDevData(filePath)) {
+                return false;
+            }
+        }
+    }
+    return true;
+}
+
+bool AclnnGroupedMatmulParam::InitGroupList()
+{
+    if (mGroupListData.empty()) {
+        return true;
+    }
+
+    if (mAclnnGroupedMatmulVersion == AclnnGroupedMatmulVersion::V3 ||
+        mAclnnGroupedMatmulVersion == AclnnGroupedMatmulVersion::V4) {
+        size_t dataSize = mGroupListData.size() * sizeof(int64_t);
+        std::string fileName = "groupList.bin";
+        if (!WriteFile(fileName, mGroupListData.data(), dataSize)) {
+            LOG_ERR("Write groupList data to file[%s] failed", fileName.c_str());
+            return false;
+        }
+        aclnnGroupListTensor = ops::adv::tests::utils::AclnnTensor(mGroupList);
+    } else if (!InitAclIntArray(&aclnnGroupListIntAry, mGroupListData)) {
+        return false;
+    }
+
+    return true;
+}
\ No newline at end of file
diff --git a/tests/ut/ops_test/src/transformer/grouped_matmul/comm/src/grouped_matmul_case.cpp b/tests/ut/ops_test/src/transformer/grouped_matmul/comm/src/grouped_matmul_case.cpp
new file mode 100644
index 00000000..93638538
--- /dev/null
+++ b/tests/ut/ops_test/src/transformer/grouped_matmul/comm/src/grouped_matmul_case.cpp
@@ -0,0 +1,213 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file grouped_matmul_case.cpp
+ * \brief GroupedMatmul test cases.
+ */
+
+#include "grouped_matmul_case.h"
+#include <cstdint>
+#include <string>
+#include <utility>
+#include <vector>
+#include "tests/utils/log.h"
+#include "tests/utils/io.h"
+#include "tests/utils/platform.h"
+#include "tiling/gmm/tiling_data.h"
+#include "tiling/tiling_templates_registry.h"
+
+using Case = ops::adv::tests::utils::Case;
+using GroupedMatmulCase = ops::adv::tests::grouped_matmul::GroupedMatmulCase;
+using ops::adv::tests::utils::ReadFile;
+using ops::adv::tests::utils::WriteFile;
+
+/**
+ * The following kernel entry declarations must stay consistent with the entries controlled by the
+ * KERNEL_PRIVATE_COMPILE_DEFINITIONS_EXT argument passed to OpsTest_Level2_AddOp in CMakeLists.txt.
+ */
+
+#define GROUPEDMATMUL_KERNEL_PARAM                                                                                     \
+    (GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR scale,                                                           \
+     GM_ADDR offset, GM_ADDR antiquantScale, GM_ADDR antiquantOffset,                                                  \
+     GM_ADDR groupList, GM_ADDR perTokenScale, GM_ADDR y,                                                              \
+     GM_ADDR workspace, GM_ADDR tiling)
+
+using GroupedMatmulKernelFunc = void (*) GROUPEDMATMUL_KERNEL_PARAM;
+
+extern "C" __global__ __aicore__ void grouped_matmul_fp16 GROUPEDMATMUL_KERNEL_PARAM;
+extern "C" __global__ __aicore__ void grouped_matmul_bf16 GROUPEDMATMUL_KERNEL_PARAM;
+extern "C" __global__ __aicore__ void grouped_matmul_fp32 GROUPEDMATMUL_KERNEL_PARAM;
+extern "C" __global__ __aicore__ void grouped_matmul_quant_int8 GROUPEDMATMUL_KERNEL_PARAM;
+extern "C" __global__ __aicore__ void grouped_matmul_quant_bf16 GROUPEDMATMUL_KERNEL_PARAM;
+extern "C" __global__ __aicore__ void grouped_matmul_quant_fp16 GROUPEDMATMUL_KERNEL_PARAM;
+extern "C" __global__ __aicore__ void grouped_matmul_a16w8_bf16 GROUPEDMATMUL_KERNEL_PARAM;
+extern "C" __global__ __aicore__ void grouped_matmul_a16w8_fp16 GROUPEDMATMUL_KERNEL_PARAM;
+extern "C" __global__ __aicore__ void grouped_matmul_a16w4_bf16 GROUPEDMATMUL_KERNEL_PARAM;
+extern "C" __global__ __aicore__ void grouped_matmul_a16w4_fp16 GROUPEDMATMUL_KERNEL_PARAM;
+
+using namespace ops::adv::tests::grouped_matmul;
+using ops::adv::tests::utils::Platform;
+using ops::adv::tests::utils::TensorList;
+using ops::adv::tests::utils::TensorIntf;
+
+enum class KernelParams {
+    X = 0,
+    WEIGHT,
+    BIAS,
+    SCALE,
+    OFFSET,
+    ANTIQUANT_SCALE,
+    ANTIQUANT_OFFSET,
+    GROUP_LIST,
+    PER_TOKEN_SCALE
+};
+
+bool RunGroupedMatmul(void *func, uint64_t tilingKey, int64_t blockDim, std::vector<TensorIntf *> &inputs,
+                      std::vector<TensorIntf *> &output, uint8_t *workspace, uint8_t *tilingData)
+{
+    (void)blockDim;
+    // Run the kernel on the CPU-side simulator
+    auto kernelFunc = (GroupedMatmulKernelFunc)func;
+    ICPU_SET_TILING_KEY(tilingKey);
+    ICPU_RUN_KF(kernelFunc, 1,
+                inputs[static_cast<size_t>(KernelParams::X)]->GetDevData(),
+                inputs[static_cast<size_t>(KernelParams::WEIGHT)]->GetDevData(),
+                inputs[static_cast<size_t>(KernelParams::BIAS)]->GetDevData(),
+                inputs[static_cast<size_t>(KernelParams::SCALE)]->GetDevData(),
+                inputs[static_cast<size_t>(KernelParams::OFFSET)]->GetDevData(),
+                inputs[static_cast<size_t>(KernelParams::ANTIQUANT_SCALE)]->GetDevData(),
+                inputs[static_cast<size_t>(KernelParams::ANTIQUANT_OFFSET)]->GetDevData(),
+                inputs[static_cast<size_t>(KernelParams::GROUP_LIST)]->GetDevData(),
+                inputs[static_cast<size_t>(KernelParams::PER_TOKEN_SCALE)]->GetDevData(),
+                output[0]->GetDevData(),
+                workspace, tilingData);
+    return true;
+}
+
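+// The positional contract above is what mCtx.SetInputs(...) in InitOpInfo() below relies on:
+// x, weight, bias, scale, offset, antiquant_scale, antiquant_offset, group_list,
+// per_token_scale; nine inputs, indexed by KernelParams. A compile-time check keeps the
+// enum and the SetInputs() order visibly coupled (the value follows the enum defined above):
+static_assert(static_cast<size_t>(KernelParams::PER_TOKEN_SCALE) == 8U,
+              "KernelParams order must stay in sync with SetInputs() in InitOpInfo()");
+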
+GroupedMatmulCase::GroupedMatmulCase() : GroupedMatmulCase("Undefined", true, "", OpInfo(), Param(), 0)
+{
+}
+
+GroupedMatmulCase::GroupedMatmulCase(const char *name, bool enable, const char *dbgInfo, OpInfo opInfo, Param param,
+                                     int32_t tilingTemplatePriority)
+    : Case(name, enable, dbgInfo, tilingTemplatePriority), mOpInfo(std::move(opInfo)), mParam(std::move(param))
+{
+    this->mOpInfo.mName = "GroupedMatmul";
+}
+
+bool GroupedMatmulCase::Run()
+{
+    if (!mEnable) {
+        return true;
+    }
+    if (!mOpInfo.ProcessTiling(mName)) {
+        return false;
+    }
+    auto *groupedMatmulTiling = const_cast<GMMTilingData *>((const GMMTilingData *)(mCtx.GetTilingData()));
+    if (groupedMatmulTiling == nullptr) {
+        LOG_ERR("Tiling failed!");
+        return false;
+    }
+    if (!mOpInfo.ProcessKernel(mName)) {
+        return false;
+    }
+    return true;
+}
+
+bool GroupedMatmulCase::InitParam()
+{
+    if (!mParam.mGroupListData.empty()) {
+        size_t dataSize = mParam.mGroupListData.size() * sizeof(int64_t);
+        uint8_t *addr = mParam.mGroupList.AllocDevData(0, dataSize);
+        if (addr == nullptr) {
+            LOG_ERR("TensorList(%s, %zu) AllocDevData Failed.", mParam.mGroupList.Name().c_str(), dataSize);
+            return false;
+        }
+        std::string fileName = this->mName + "_groupList.bin";
+        if (!WriteFile(fileName, mParam.mGroupListData.data(), dataSize)) {
+            LOG_ERR("Write groupList data to file[%s] failed", fileName.c_str());
+            return false;
+        }
+        if (!ReadFile(fileName, dataSize, addr, dataSize)) {
+            LOG_ERR("Read groupList data[%s] to tensor failed", fileName.c_str());
+            return false;
+        }
+    }
+    return true;
+}
+
+void *GetGroupedMatmulKernelFunc(Param &param)
+{
+    auto *groupedMatmulKernelFunc = (void *)grouped_matmul_fp16;
+    if (param.mTensorLists["x"].GetDataType() == param.mTensorLists["weight"].GetDataType()) {
+        if (param.mTensorLists["weight"].GetDataType() == ge::DataType::DT_INT8) { // quant
+            if (param.mTensorLists["y"].GetDataType() == ge::DataType::DT_INT8) {
+                groupedMatmulKernelFunc = (void *)grouped_matmul_quant_int8;
+            } else if (param.mTensorLists["y"].GetDataType() == ge::DataType::DT_FLOAT16) {
+                groupedMatmulKernelFunc = (void *)grouped_matmul_quant_fp16;
+            } else {
+                groupedMatmulKernelFunc = (void *)grouped_matmul_quant_bf16;
+            }
+        } else { // non-quant
+            if (param.mTensorLists["weight"].GetDataType() == ge::DataType::DT_FLOAT16) {
+                groupedMatmulKernelFunc = (void *)grouped_matmul_fp16;
+            } else if (param.mTensorLists["weight"].GetDataType() == ge::DataType::DT_BF16) {
+                groupedMatmulKernelFunc = (void *)grouped_matmul_bf16;
+            } else {
+                groupedMatmulKernelFunc = (void *)grouped_matmul_fp32;
+            }
+        }
+    } else { // weight-only (pseudo) quant
+        if (param.mTensorLists["weight"].GetDataType() == ge::DataType::DT_INT8) {
+            if (param.mTensorLists["x"].GetDataType() == ge::DataType::DT_FLOAT16) {
+                groupedMatmulKernelFunc = (void *)grouped_matmul_a16w8_fp16;
+            } else {
+                groupedMatmulKernelFunc = (void *)grouped_matmul_a16w8_bf16;
+            }
+        } else {
+            if (param.mTensorLists["x"].GetDataType() == ge::DataType::DT_FLOAT16) {
+                groupedMatmulKernelFunc = (void *)grouped_matmul_a16w4_fp16;
+            } else {
+                groupedMatmulKernelFunc = (void *)grouped_matmul_a16w4_bf16;
+            }
+        }
+    }
+    return groupedMatmulKernelFunc;
+}
+
+bool GroupedMatmulCase::InitOpInfo()
+{
+    auto *groupedMatmulKernelFunc = GetGroupedMatmulKernelFunc(mParam);
+    bool rst = mCtx.SetOpName(mOpInfo.mName.c_str());
+    rst = rst && mCtx.SetDeterministic(mOpInfo.mCtr.mDeterministic);
+    rst = rst && mCtx.SetInputs({&mParam.mTensorLists["x"], &mParam.mTensorLists["weight"],
+                                 &mParam.mTensorLists["bias"], &mParam.mTensorLists["scale"],
+                                 &mParam.mTensorLists["offset"], &mParam.mTensorLists["antiquant_scale"],
+                                 &mParam.mTensorLists["antiquant_offset"], &mParam.mGroupList,
+                                 &mParam.mPerTokenScale});
+    rst = rst && mCtx.SetOutputs({&mParam.mTensorLists["y"]});
+    rst = rst && mCtx.SetAttrs({{"split_item", mParam.mSplitItem},
+                                {"dtype", mParam.mDtype},
+                                {"transpose_weight", mParam.mTransposeWeight},
+                                {"transpose_x", mParam.mTransposeX},
+                                {"group_type", mParam.mGroupType},
+                                {"group_list_type", mParam.mGroupListType},
+                                {"act_type", mParam.mActType}});
+    rst = rst && mCtx.SetKernelRunCbf(RunGroupedMatmul);
+    rst = rst && mCtx.SetKernelMainFunc((void *)groupedMatmulKernelFunc);
+    rst = rst && mOpInfo.SetContext(&mCtx);
+    return rst;
+}
+
+bool GroupedMatmulCase::InitCurrentCasePtr()
+{
+    Case::mCurrentCasePtr = this;
+    return true;
+}
+
diff --git a/tests/ut/ops_test/src/transformer/grouped_matmul/comm/src/grouped_matmul_param.cpp b/tests/ut/ops_test/src/transformer/grouped_matmul/comm/src/grouped_matmul_param.cpp
new file mode 100644
index 00000000..c7f624c6
--- /dev/null
+++ b/tests/ut/ops_test/src/transformer/grouped_matmul/comm/src/grouped_matmul_param.cpp
@@ -0,0 +1,43 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file grouped_matmul_param.cpp
+ * \brief GroupedMatmul parameter info.
+ */
+
+#include "grouped_matmul_param.h"
+
+using namespace ops::adv::tests::grouped_matmul;
+
+Param::Param(std::vector<TensorList> inputs, Tensor perTokenScale, Tensor groupList,
+             std::vector<int64_t> groupListData, int32_t splitItem, int32_t dType,
+             bool transposeWeight, bool transposeX, int32_t groupType, int32_t groupListType, int32_t actType)
+    : mGroupList(std::move(groupList)), mPerTokenScale(std::move(perTokenScale)),
+      mGroupListData(std::move(groupListData)), mSplitItem(splitItem),
+      mDtype(dType), mTransposeWeight(transposeWeight), mTransposeX(transposeX),
+      mGroupType(groupType), mGroupListType(groupListType), mActType(actType)
+{
+    for (auto &tensorList : inputs) {
+        mTensorLists[tensorList.Name()] = tensorList;
+    }
+}
+
+Tensor ops::adv::tests::grouped_matmul::GenTensor(const char *name, const std::initializer_list<int64_t> &shape,
+                                                  ge::DataType dType, ge::Format format)
+{
+    return Tensor(name, shape, "", dType, format);
+}
+
+TensorList ops::adv::tests::grouped_matmul::GenTensorList(const char *name,
+                                                          const std::vector<std::vector<int64_t>> &shapes,
+                                                          ge::DataType dType, ge::Format format)
+{
+    return TensorList(name, shapes, "", dType, format);
+}
\ No newline at end of file
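(These helpers are what the kernel UTest cases below are assembled from. A minimal sketch of a non-quant Param, mirroring the shape pattern of GroupedMatmul_Case0 below; the empty per_token_scale and grouped_list tensors mean those optional inputs are absent:

    Param p({GenTensorList("x", {{256, 256}}, ge::DataType::DT_FLOAT16),
             GenTensorList("weight", {{256, 256}}, ge::DataType::DT_FLOAT16),
             GenTensorList("y", {{256, 256}}, ge::DataType::DT_FLOAT16)},
            GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT),
            GenTensor("grouped_list", {}, ge::DataType::DT_INT64),
            {}, /* splitItem */ 0, /* dtype */ -1, /* transposeWeight */ false,
            /* transposeX */ false, /* groupType */ -1, /* groupListType */ 1, /* actType */ 0);
)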
diff --git a/tests/ut/ops_test/src/transformer/grouped_matmul/utest/ts_grouped_matmul.h b/tests/ut/ops_test/src/transformer/grouped_matmul/utest/ts_grouped_matmul.h
new file mode 100644
index 00000000..23415f1e
--- /dev/null
+++ b/tests/ut/ops_test/src/transformer/grouped_matmul/utest/ts_grouped_matmul.h
@@ -0,0 +1,29 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file ts_grouped_matmul.h
+ * \brief Base class definitions for the GroupedMatmul UTest.
+ */
+
+#ifndef UTEST_TS_GROUPEDMATMUL_H
+#define UTEST_TS_GROUPEDMATMUL_H
+
+#include "tests/utest/ts.h"
+#include "grouped_matmul_case.h"
+
+using ops::adv::tests::grouped_matmul::GroupedMatmulCase;
+using ops::adv::tests::grouped_matmul::GenTensor;
+using ops::adv::tests::grouped_matmul::GenTensorList;
+using ops::adv::tests::grouped_matmul::Param;
+
+class Ts_GroupedMatmul_WithParam_Ascend910B3 : public Ts_WithParam_Ascend910B3<GroupedMatmulCase> {};
+
+#endif // UTEST_TS_GROUPEDMATMUL_H
\ No newline at end of file
diff --git a/tests/ut/ops_test/src/transformer/grouped_matmul/utest/ts_grouped_matmul_kernel.cpp b/tests/ut/ops_test/src/transformer/grouped_matmul/utest/ts_grouped_matmul_kernel.cpp
new file mode 100644
index 00000000..66b1da1c
--- /dev/null
+++ b/tests/ut/ops_test/src/transformer/grouped_matmul/utest/ts_grouped_matmul_kernel.cpp
@@ -0,0 +1,638 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file ts_grouped_matmul_kernel.cpp
+ * \brief GroupedMatmul kernel test cases.
+ */
+
+#include "ts_grouped_matmul.h"
+
+namespace {
+TEST_P(Ts_GroupedMatmul_WithParam_Ascend910B3, Tc_Kernel_GroupedMatmul)
+{
+    ASSERT_TRUE(case_->Init());
+    ASSERT_EQ(case_->Run(), case_->mOpInfo.mExp.mSuccess);
+}
+
+const auto Tc_GroupedMatmul_Kernel_Case = ::testing::Values(
+    GroupedMatmulCase( /* no split, m-m-m */
+        "GroupedMatmul_Case0", true, "", /* CaseName, Enable, DebugInfo */
+        OpInfo(ControlInfo(true, true),
+               ExpectInfo(true, 0,
+                          ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */
+        Param({GenTensorList("x", {{256, 256}, {1024, 256}}, ge::DataType::DT_FLOAT16),
+               GenTensorList("weight", {{256, 256}, {256, 1024}}, ge::DataType::DT_FLOAT16),
+               GenTensorList("bias", {{256}, {1024}}, ge::DataType::DT_FLOAT16),
+               GenTensorList("y", {{256, 256}, {1024, 1024}}, ge::DataType::DT_FLOAT16)},
+              GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT),
+              GenTensor("grouped_list", {}, ge::DataType::DT_INT64),
+              {}, 0, -1, false, false, -1, 1, 0),
+        0),
+    GroupedMatmulCase( /* split M, m-m-m */
+        "GroupedMatmul_Case1", true, "", /* CaseName, Enable, DebugInfo */
+        OpInfo(ControlInfo(true, true),
+               ExpectInfo(true, 0,
+                          ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */
+        Param({GenTensorList("x", {{128, 256}, {1024, 256}}, ge::DataType::DT_FLOAT16),
+               GenTensorList("weight", {{256, 256}, {256, 1024}}, ge::DataType::DT_FLOAT16),
+               GenTensorList("bias", {{256}, {1024}}, ge::DataType::DT_FLOAT16),
+               GenTensorList("y", {{128, 256}, {1024, 1024}}, ge::DataType::DT_FLOAT16)},
+              GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT),
+              GenTensor("grouped_list", {}, ge::DataType::DT_INT64),
+              {}, 0, -1, false, false, 0, 1, 0),
+        0),
+    GroupedMatmulCase( /* split M, s-m-s */
+        "GroupedMatmul_Case2", true, "", /* CaseName, Enable, DebugInfo */
+        OpInfo(ControlInfo(true, true),
+               ExpectInfo(true, 0,
+                          ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey,
ExpectTilingBlockDim */ + Param({GenTensorList("x", {{256, 256}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{256, 256}, {256, 256}}, ge::DataType::DT_FLOAT16), + GenTensorList("bias", {{256}, {256}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{256, 256}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {128, 128}, 3, -1, false, false, 0, 1, 0), + 0), + GroupedMatmulCase( /* split M, s-m-m */ + "GroupedMatmul_Case3", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{256, 256}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{256, 256}, {256, 256}}, ge::DataType::DT_FLOAT16), + GenTensorList("bias", {{256}, {256}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{128, 256}, {128, 256}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {128, 128}, 0, -1, false, false, 0, 1, 0), + 0), + GroupedMatmulCase( /* split K s-s-s + transpose x */ + "GroupedMatmul_Case4", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 1, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{256, 768}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{768, 256}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{256, 256}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {256, 512}, 3, -1, false, true, 2, 1, 0), + 0), + GroupedMatmulCase( /* split K m-s-m + transpose x */ + "GroupedMatmul_Case5", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 1, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{256, 256}, {256, 512}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{768, 256}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{256, 256}, {256, 256}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {}, ge::DataType::DT_INT64), + {}, 0, -1, false, true, 2, 1, 0), + 0), + GroupedMatmulCase( /* single tensor + transpose weight */ + "GroupedMatmul_Case6", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 2, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{256, 256}, {1024, 256}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{2, 256, 256}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{256, 256}, {1024, 256}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {}, ge::DataType::DT_INT64), + {}, 0, -1, true, false, 0, 1, 0), + 0), + GroupedMatmulCase( /* single tensor + w nz */ + "GroupedMatmul_Case7", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{64, 64}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{2, 4, 4, 16, 16}}, 
ge::DataType::DT_FLOAT16, ge::FORMAT_FRACTAL_NZ), + GenTensorList("bias", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{32, 64}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {32, 32}, 3, -1, false, false, 0, 1, 0), + 0), + GroupedMatmulCase( /* per token quant */ + "GroupedMatmul_Case8", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{16, 5}}, ge::DataType::DT_INT8), + GenTensorList("weight", {{2, 5, 10}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{0}}, ge::DataType::DT_INT32), + GenTensorList("scale", {{2, 10}}, ge::DataType::DT_BF16), + GenTensorList("offset", {{0}}, ge::DataType::DT_FLOAT), + GenTensorList("antiquant_scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{16, 10}}, ge::DataType::DT_BF16)}, + GenTensor("per_token_scale", {16}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {8, 8}, 3, 1, false, false, 0, 1, 0), + 0), + GroupedMatmulCase( /* per token quant */ + "GroupedMatmul_Case9", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, false), + ExpectInfo(true, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{360, 1024}}, ge::DataType::DT_INT8), + GenTensorList("weight", {{16, 1024, 8192}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{0}}, ge::DataType::DT_INT32), + GenTensorList("scale", {{16, 8192}}, ge::DataType::DT_BF16), + GenTensorList("offset", {{0}}, ge::DataType::DT_FLOAT), + GenTensorList("antiquant_scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{360, 8192}}, ge::DataType::DT_BF16)}, + GenTensor("per_token_scale", {360}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {16}, ge::DataType::DT_INT64), + {40, 40, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20}, 2, 1, false, false, 0, 1, 0), + 0), + GroupedMatmulCase( /* per token quant */ + "GroupedMatmul_Case10", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, false), + ExpectInfo(true, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{360, 8192}}, ge::DataType::DT_INT8), + GenTensorList("weight", {{16, 8192, 1024}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{0}}, ge::DataType::DT_INT32), + GenTensorList("scale", {{16, 1024}}, ge::DataType::DT_BF16), + GenTensorList("offset", {{0}}, ge::DataType::DT_FLOAT), + GenTensorList("antiquant_scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{360, 1024}}, ge::DataType::DT_BF16)}, + GenTensor("per_token_scale", {360}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {16}, ge::DataType::DT_INT64), + {40, 40, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20}, 2, 1, false, false, 0, 1, 0), + 0), + GroupedMatmulCase( 
/* per token quant act */ + "GroupedMatmul_Case11", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{16, 10}}, ge::DataType::DT_INT8), + GenTensorList("weight", {{2, 10, 10}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{0}}, ge::DataType::DT_INT32), + GenTensorList("scale", {{2, 10}}, ge::DataType::DT_BF16), + GenTensorList("offset", {{0}}, ge::DataType::DT_FLOAT), + GenTensorList("antiquant_scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{16, 10}}, ge::DataType::DT_BF16)}, + GenTensor("per_token_scale", {16}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {8, 8}, 3, 1, false, false, 0, 1, 1), + 0), + GroupedMatmulCase( /* quant act */ + "GroupedMatmul_Case12", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{16, 20}}, ge::DataType::DT_INT8), + GenTensorList("weight", {{2, 20, 10}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{0}}, ge::DataType::DT_INT32), + GenTensorList("scale", {{2, 10}}, ge::DataType::DT_BF16), + GenTensorList("offset", {{0}}, ge::DataType::DT_FLOAT), + GenTensorList("antiquant_scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{16, 10}}, ge::DataType::DT_BF16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {8, 8}, 3, 1, false, false, 0, 1, 1), + 0), + GroupedMatmulCase( /* pertoken quant fp16*/ + "GroupedMatmul_Case13", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{16, 5}}, ge::DataType::DT_INT8), + GenTensorList("weight", {{2, 5, 10}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{0}}, ge::DataType::DT_INT32), + GenTensorList("scale", {{2, 10}}, ge::DataType::DT_FLOAT), + GenTensorList("offset", {{0}}, ge::DataType::DT_FLOAT), + GenTensorList("antiquant_scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{16, 10}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {16}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {8, 8}, 3, 0, false, false, 0, 1, 0), + 0), + GroupedMatmulCase( /* per token quant weight NZ */ + "GroupedMatmul_Case14", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{16, 128}}, ge::DataType::DT_INT8), + GenTensorList("weight", {{4, 8, 8, 16, 32}}, ge::DataType::DT_INT8, ge::FORMAT_FRACTAL_NZ), + GenTensorList("bias", {{0}}, ge::DataType::DT_INT32), + GenTensorList("scale", {{4, 256}}, ge::DataType::DT_BF16), + GenTensorList("offset", {{0}}, ge::DataType::DT_FLOAT), + GenTensorList("antiquant_scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{16, 256}}, 
ge::DataType::DT_BF16)}, + GenTensor("per_token_scale", {16}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {4}, ge::DataType::DT_INT64), + {4, 4, 4, 4}, 3, 1, false, false, 0, 1, 0), + 0), + GroupedMatmulCase( /* pertoken quant fp16*/ + "GroupedMatmul_Case15", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{32, 5}}, ge::DataType::DT_INT8), + GenTensorList("weight", {{2, 5, 10}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{0}}, ge::DataType::DT_INT32), + GenTensorList("scale", {{2, 10}}, ge::DataType::DT_FLOAT), + GenTensorList("offset", {{0}}, ge::DataType::DT_FLOAT), + GenTensorList("antiquant_scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{32, 10}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {32}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {16, 16}, 3, 0, false, false, 0, 1, 1), + 0), + GroupedMatmulCase( /* pertoken quant fp16*/ + "GroupedMatmul_Case16", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 4, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{32, 5}}, ge::DataType::DT_INT8), + GenTensorList("weight", {{2, 5, 20}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{0}}, ge::DataType::DT_INT32), + GenTensorList("scale", {{2, 20}}, ge::DataType::DT_FLOAT), + GenTensorList("offset", {{0}}, ge::DataType::DT_FLOAT), + GenTensorList("antiquant_scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{32, 20}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {32}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {16, 16}, 3, 0, false, false, 0, 1, 2), + 0), + GroupedMatmulCase( /* pertoken quant fp16*/ + "GroupedMatmul_Case17", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{32, 5}}, ge::DataType::DT_INT8), + GenTensorList("weight", {{2, 5, 30}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{0}}, ge::DataType::DT_INT32), + GenTensorList("scale", {{2, 30}}, ge::DataType::DT_FLOAT), + GenTensorList("offset", {{0}}, ge::DataType::DT_FLOAT), + GenTensorList("antiquant_scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{32, 30}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {32}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {16, 16}, 3, 0, false, false, 0, 1, 4), + 0), + GroupedMatmulCase( /* pertoken quant fp16*/ + "GroupedMatmul_Case18", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{32, 5}}, ge::DataType::DT_INT8), + GenTensorList("weight", {{2, 5, 40}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{0}}, ge::DataType::DT_INT32), + GenTensorList("scale", {{2, 40}}, ge::DataType::DT_FLOAT), + GenTensorList("offset", 
{{0}}, ge::DataType::DT_FLOAT), + GenTensorList("antiquant_scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{32, 40}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {32}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {16, 16}, 3, 0, false, false, 0, 1, 5), + 0), + GroupedMatmulCase( /* pertoken quant fp16 + transpose weight */ + "GroupedMatmul_Case19", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 5, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{16, 5}}, ge::DataType::DT_INT8), + GenTensorList("weight", {{2, 10, 5}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{0}}, ge::DataType::DT_INT32), + GenTensorList("scale", {{2, 10}}, ge::DataType::DT_FLOAT), + GenTensorList("offset", {{0}}, ge::DataType::DT_FLOAT), + GenTensorList("antiquant_scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{16, 10}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {16}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {8, 8}, 3, 0, true, false, 0, 1, 2), + 0), + GroupedMatmulCase( /* quant + transpose weight*/ + "GroupedMatmul_Case20", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 2, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{16, 64}}, ge::DataType::DT_INT8), + GenTensorList("weight", {{4, 64, 32}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{4, 32}}, ge::DataType::DT_INT32), + GenTensorList("scale", {{4, 32}}, ge::DataType::DT_FLOAT), + GenTensorList("offset", {{0}}, ge::DataType::DT_FLOAT), + GenTensorList("antiquant_scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{16, 32}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {4}, ge::DataType::DT_INT64), + {4, 4, 4, 4}, 3, 1, true, false, -1, 1, 0), + 0), + GroupedMatmulCase( /* quant, known issue: case disabled */ + "GroupedMatmul_Case21", false, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{256, 256}, {1024, 256}}, ge::DataType::DT_INT8), + GenTensorList("weight", {{256, 256}, {256, 1024}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{256}, {1024}}, ge::DataType::DT_INT32), + GenTensorList("scale", {{256}, {1024}}, ge::DataType::DT_UINT64), + GenTensorList("offset", {{256}, {1024}}, ge::DataType::DT_FLOAT), + GenTensorList("antiquant_scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{256, 256}, {1024, 1024}}, ge::DataType::DT_INT8)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {}, ge::DataType::DT_INT64), + {}, 0, -1, false, false, -1, 1, 0), + 0), + GroupedMatmulCase( /* a16w8 fp16 */ + "GroupedMatmul_Case22", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ +
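/* Reading aid (assumed from usage in this file; the Param definition is outside this hunk): the + trailing scalar arguments of Param in each case appear to be the per-group size list, then + split_item, dtype, transpose_weight, transpose_x, group_type, group_list_type and act_type. */ +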
Param({GenTensorList("x", {{256, 256}, {128, 256}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{256, 256}, {256, 128}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{256}, {128}}, ge::DataType::DT_FLOAT16), + GenTensorList("scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_scale", {{256}, {128}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{256}, {128}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{256, 256}, {128, 128}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {}, ge::DataType::DT_INT64), + {}, 0, -1, false, false, -1, 1, 0), + 0), + GroupedMatmulCase( /* a16w8 bf16 */ + "GroupedMatmul_Case23", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{256, 256}, {64, 256}}, ge::DataType::DT_BF16), + GenTensorList("weight", {{256, 256}, {256, 64}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{256}, {64}}, ge::DataType::DT_FLOAT), + GenTensorList("scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_scale", {{256}, {64}}, ge::DataType::DT_BF16), + GenTensorList("antiquant_offset", {{256}, {64}}, ge::DataType::DT_BF16), + GenTensorList("y", {{256, 256}, {64, 64}}, ge::DataType::DT_BF16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {}, ge::DataType::DT_INT64), + {}, 0, -1, false, false, -1, 1, 0), + 0), + GroupedMatmulCase( /* a16w8 bf16 */ + "GroupedMatmul_Case24", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{128, 256}}, ge::DataType::DT_BF16), + GenTensorList("weight", {{256, 256}, {256, 256}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{2, 256}}, ge::DataType::DT_FLOAT), + GenTensorList("scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_scale", {{2, 256}}, ge::DataType::DT_BF16), + GenTensorList("antiquant_offset", {{2, 256}}, ge::DataType::DT_BF16), + GenTensorList("y", {{128, 256}}, ge::DataType::DT_BF16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {8, 8}, 3, -1, false, false, 0, 1, 0), + 0), + GroupedMatmulCase( /* a16w4 fp16 */ + "GroupedMatmul_Case25", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{256, 256},}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{4, 256, 256}}, ge::DataType::DT_INT4), + GenTensorList("bias", {{4, 256}}, ge::DataType::DT_FLOAT16), + GenTensorList("scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_scale", {{4, 256}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{4, 256}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{256, 256}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {4}, 
ge::DataType::DT_INT64), + {64, 64, 64, 64}, 3, -1, false, false, 0, 1, 0), + 0), + GroupedMatmulCase( /* a16w4 fp16 transpose weight*/ + "GroupedMatmul_Case26", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 2, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{256, 256},}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{8, 256, 256}}, ge::DataType::DT_INT4), + GenTensorList("bias", {{8, 256}}, ge::DataType::DT_FLOAT16), + GenTensorList("scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_scale", {{8, 256}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{8, 256}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{256, 256}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {4}, ge::DataType::DT_INT64), + {32, 32, 32, 32, 32, 32, 32, 32}, 3, -1, true, false, 0, 1, 0), + 0), + GroupedMatmulCase( /* a16w4 bf16 */ + "GroupedMatmul_Case27", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{256, 256},}, ge::DataType::DT_BF16), + GenTensorList("weight", {{4, 256, 256}}, ge::DataType::DT_INT4), + GenTensorList("bias", {{4, 256}}, ge::DataType::DT_FLOAT), + GenTensorList("scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_scale", {{4, 256}}, ge::DataType::DT_BF16), + GenTensorList("antiquant_offset", {{4, 256}}, ge::DataType::DT_BF16), + GenTensorList("y", {{256, 256}}, ge::DataType::DT_BF16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {4}, ge::DataType::DT_INT64), + {64, 64, 64, 64}, 3, -1, false, false, 0, 1, 0), + 0), + GroupedMatmulCase( /* a16w8 fp16 + transpose weight */ + "GroupedMatmul_Case28", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 2, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{256, 256}, {128, 256}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{256, 256}, {128, 256}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{256}, {128}}, ge::DataType::DT_FLOAT16), + GenTensorList("scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_scale", {{256}, {128}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{256}, {128}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{256, 256}, {128, 128}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {}, ge::DataType::DT_INT64), + {}, 0, -1, true, false, -1, 1, 0), + 0), + GroupedMatmulCase( /* antiquant performance */ + "GroupedMatmul_Case29", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 3, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{1024, 2048}, {1024, 2048}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{2048, 1000}, {2048, 1000}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{1000}, {1000}}, ge::DataType::DT_FLOAT16), +
GenTensorList("scale", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("offset", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_scale", {{1000}, {1000}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{1000}, {1000}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{1024, 1000}, {1024, 1000}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {4}, ge::DataType::DT_INT64), + {256, 256, 256, 256}, 0, -1, false, false, -1, 1, 0), + 0), + GroupedMatmulCase( /* antiquant msd */ + "GroupedMatmul_Case30", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 6, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{16, 128}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{4, 128, 1024}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{4, 1024}}, ge::DataType::DT_FLOAT16), + GenTensorList("scale", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("offset", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_scale", {{4, 1024}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{4, 1024}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{16, 1024}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {4}, ge::DataType::DT_INT64), + {4, 4, 4, 4}, 3, -1, false, false, 0, 1, 0), + 0), + GroupedMatmulCase( /* antiquant msd + transpsoe weight */ + "GroupedMatmul_Case31", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 7, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{16, 128}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{2, 1024, 128}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{2, 1024}}, ge::DataType::DT_FLOAT16), + GenTensorList("scale", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("offset", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_scale", {{2, 1024}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{2, 1024}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{16, 1024}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {8, 8}, 3, -1, true, false, 0, 1, 0), + 0), + GroupedMatmulCase( /* antiquant msd bf16 */ + "GroupedMatmul_Case32", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 6, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{16, 128}}, ge::DataType::DT_BF16), + GenTensorList("weight", {{4, 128, 1024}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{4, 1024}}, ge::DataType::DT_FLOAT), + GenTensorList("scale", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("offset", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_scale", {{4, 1024}}, ge::DataType::DT_BF16), + GenTensorList("antiquant_offset", {{4, 1024}}, ge::DataType::DT_BF16), + GenTensorList("y", {{16, 1024}}, ge::DataType::DT_BF16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {4}, ge::DataType::DT_INT64), + {4, 4, 4, 4}, 3, -1, false, false, 0, 1, 0), + 0), + GroupedMatmulCase( /* antiquant msd bf16 + transpsoe weight */ + "GroupedMatmul_Case33", true, "", /* CaseName, Enable, 
DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 7, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{16, 128}}, ge::DataType::DT_BF16), + GenTensorList("weight", {{2, 1024, 128}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{2, 1024}}, ge::DataType::DT_FLOAT), + GenTensorList("scale", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("offset", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_scale", {{2, 1024}}, ge::DataType::DT_BF16), + GenTensorList("antiquant_offset", {{2, 1024}}, ge::DataType::DT_BF16), + GenTensorList("y", {{16, 1024}}, ge::DataType::DT_BF16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {8, 8}, 3, -1, true, false, 0, 1, 0), + 0), + GroupedMatmulCase( /* antiquant pergroup */ + "GroupedMatmul_Case34", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{16, 16}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{2, 16, 10}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{2, 10}}, ge::DataType::DT_FLOAT16), + GenTensorList("scale", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("offset", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_scale", {{2, 2, 10}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{2, 2, 10}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{16, 10}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {8, 8}, 0, -1, false, false, -1, 1, 0), + 0), + GroupedMatmulCase( /* split K s-s-s + transpose x */ + "GroupedMatmul_Case35", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 1, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{256, 768}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{768, 256}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{256, 256}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {4}, ge::DataType::DT_INT64), + {256, 512, 512, 768}, 3, -1, false, true, 2, 0, 0), + 0), + GroupedMatmulCase( /* quant int8*/ + "GroupedMatmul_Case36", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{16, 5}}, ge::DataType::DT_INT8), + GenTensorList("weight", {{5, 10}, {5, 10}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{0}}, ge::DataType::DT_INT32), + GenTensorList("scale", {{2, 10}}, ge::DataType::DT_UINT64), + GenTensorList("offset", {{0}}, ge::DataType::DT_FLOAT), + GenTensorList("antiquant_scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{16, 10}}, ge::DataType::DT_INT8)}, + GenTensor("per_token_scale", {16}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {8, 8}, 3, -1, false, false, 0, 1, 0), + 0), + GroupedMatmulCase( /* quant int8*/ + "GroupedMatmul_Case37", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, true), + ExpectInfo(true, 0, 
+ ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{16, 5}}, ge::DataType::DT_INT8), + GenTensorList("weight", {{2, 5, 10}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{0}}, ge::DataType::DT_INT32), + GenTensorList("scale", {{2, 10}}, ge::DataType::DT_UINT64), + GenTensorList("offset", {{0}}, ge::DataType::DT_FLOAT), + GenTensorList("antiquant_scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{16, 10}}, ge::DataType::DT_INT8)}, + GenTensor("per_token_scale", {16}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {8, 8}, 3, -1, false, false, 0, 1, 0), + 0)); + +INSTANTIATE_TEST_SUITE_P(GroupedMatmul, Ts_GroupedMatmul_WithParam_Ascend910B3, Tc_GroupedMatmul_Kernel_Case); +} // namespace \ No newline at end of file diff --git a/tests/ut/ops_test/src/transformer/grouped_matmul/utest/ts_grouped_matmul_tiling.cpp b/tests/ut/ops_test/src/transformer/grouped_matmul/utest/ts_grouped_matmul_tiling.cpp new file mode 100644 index 00000000..fd627327 --- /dev/null +++ b/tests/ut/ops_test/src/transformer/grouped_matmul/utest/ts_grouped_matmul_tiling.cpp @@ -0,0 +1,233 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file ts_grouped_matmul_tiling.cpp + * \brief GroupedMatmul tiling test cases.
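+ * Note: ControlInfo(true, false) in these cases runs tiling only (the flags are RunTiling, RunKernel). + * The trailing scalar arguments of Param are assumed to use the same layout as the kernel cases in + * ts_grouped_matmul_kernel.cpp; the Param definition itself is outside this patch excerpt.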
+ */ + + #include "ts_grouped_matmul.h" + +namespace { +TEST_P(Ts_GroupedMatmul_WithParam_Ascend910B3, Tc_Tiling_GroupedMatmul) +{ + ASSERT_TRUE(case_->Init()); + ASSERT_EQ(case_->Run(), case_->mOpInfo.mExp.mSuccess); +} + +const auto Tc_GroupedMatmul_Tiling_Case = ::testing::Values( + GroupedMatmulCase( + "GroupedMatmul_Case0", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, false), + ExpectInfo(true, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{256, 512}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{512, 128}}, ge::DataType::DT_FLOAT16), + GenTensorList("bias", {{128}}, ge::DataType::DT_FLOAT16), + GenTensorList("scale", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("offset", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_scale", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{256, 128}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {4}, ge::DataType::DT_INT64), + {64, 64, 64, 64}, 0, -1, false, false, -1, 1, 0), + 0), + GroupedMatmulCase( + "GroupedMatmul_Case1", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, false), + ExpectInfo(false, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{256, 300}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{300, 128}}, ge::DataType::DT_FLOAT16), + GenTensorList("bias", {{128}}, ge::DataType::DT_FLOAT16), + GenTensorList("scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{256, 128}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {4}, ge::DataType::DT_INT64), + {64, 64, 64, 64}, 0, -1, false, false, 1, 1, 0), + 0), + GroupedMatmulCase( + "GroupedMatmul_Case2", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, false), + ExpectInfo(false, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{256, 400}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{400, 128}}, ge::DataType::DT_FLOAT16), + GenTensorList("bias", {{128}}, ge::DataType::DT_FLOAT16), + GenTensorList("scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{256, 128}, {256, 128}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {128, 128}, 0, -1, false, false, 0, 1, 0), + 0), + GroupedMatmulCase( + "GroupedMatmul_Case3", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, false), + ExpectInfo(false, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{256, 300}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{300, 128}, {300, 128}}, ge::DataType::DT_FLOAT16), + GenTensorList("bias", {{128}}, ge::DataType::DT_FLOAT16), + 
GenTensorList("scale", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("offset", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_scale", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{256, 128}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {128, 128}, 3, -1, false, false, 2, 1, 0), + 0), + GroupedMatmulCase( /* single tensor split m */ + "GroupedMatmul_Case4", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, false), + ExpectInfo(true, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{256, 512}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{4, 512, 128}}, ge::DataType::DT_FLOAT16), + GenTensorList("bias", {{4, 128}}, ge::DataType::DT_FLOAT16), + GenTensorList("scale", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("offset", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_scale", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{256, 128}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {4}, ge::DataType::DT_INT64), + {64, 64, 64, 64}, 3, -1, false, false, 0, 1, 0), + 0), + GroupedMatmulCase( /* single tensor + transpose w */ + "GroupedMatmul_Case5", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, false), + ExpectInfo(true, 2, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{256, 512}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{4, 128, 512}}, ge::DataType::DT_FLOAT16), + GenTensorList("bias", {{4, 128}}, ge::DataType::DT_FLOAT16), + GenTensorList("scale", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("offset", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_scale", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{256, 128}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {4}, ge::DataType::DT_INT64), + {64, 64, 64, 64}, 3, -1, true, false, 0, 1, 0), + 0), + GroupedMatmulCase( /* single tensor + transpose x */ + "GroupedMatmul_Case6", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, false), + ExpectInfo(true, 1, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{512, 256}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{4, 512, 128}}, ge::DataType::DT_FLOAT16), + GenTensorList("bias", {{4, 128}}, ge::DataType::DT_FLOAT16), + GenTensorList("scale", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("offset", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_scale", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{256, 128}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {4}, ge::DataType::DT_INT64), + {64, 64, 64, 64}, 3, -1, false, true, 0, 1, 0), + 0), + GroupedMatmulCase( /* antiquant */ + "GroupedMatmul_Case7", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, false), + ExpectInfo(true, 0, + 
ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{16, 5}, {16, 5}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{5, 10}, {5, 10}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{10}, {10}}, ge::DataType::DT_FLOAT16), + GenTensorList("scale", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("offset", {{}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_scale", {{10}, {10}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{10}, {10}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{16, 10}, {16, 10}}, ge::DataType::DT_FLOAT16)}, + GenTensor("per_token_scale", {}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {}, ge::DataType::DT_INT64), + {}, 0, -1, false, false, -1, 0, 0), + 0), + GroupedMatmulCase( /* per token quant weight NZ */ + "GroupedMatmul_Case8", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, false), + ExpectInfo(false, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{360, 1024}}, ge::DataType::DT_INT8), + GenTensorList("weight", {{16, 1024, 8191}}, ge::DataType::DT_INT8, ge::FORMAT_FRACTAL_NZ), + GenTensorList("bias", {{0}}, ge::DataType::DT_INT32), + GenTensorList("scale", {{16, 8191}}, ge::DataType::DT_BF16), + GenTensorList("offset", {{0}}, ge::DataType::DT_FLOAT), + GenTensorList("antiquant_scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{360, 8191}}, ge::DataType::DT_BF16)}, + GenTensor("per_token_scale", {360}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {16}, ge::DataType::DT_INT64), + {40, 40, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20}, 2, 1, false, false, 0, 1, 0), + 0), + GroupedMatmulCase( /* per token quant weight NZ */ + "GroupedMatmul_Case9", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, false), + ExpectInfo(true, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{360, 1024}}, ge::DataType::DT_INT8), + GenTensorList("weight", {{16, 256, 64, 16, 32}}, ge::DataType::DT_INT8, ge::FORMAT_FRACTAL_NZ), + GenTensorList("bias", {{0}}, ge::DataType::DT_INT32), + GenTensorList("scale", {{16, 8192}}, ge::DataType::DT_BF16), + GenTensorList("offset", {{0}}, ge::DataType::DT_FLOAT), + GenTensorList("antiquant_scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{360, 8192}}, ge::DataType::DT_BF16)}, + GenTensor("per_token_scale", {360}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {16}, ge::DataType::DT_INT64), + {20, 20, 40, 40, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20}, 2, 1, false, false, 0, 1, 0), + 0), + GroupedMatmulCase( /* per token quant weight NZ */ + "GroupedMatmul_Case10", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, false), + ExpectInfo(true, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{360, 8192}}, ge::DataType::DT_INT8), + GenTensorList("weight", {{32, 32, 512, 16, 32}}, ge::DataType::DT_INT8, ge::FORMAT_FRACTAL_NZ), + GenTensorList("bias", {{0}}, ge::DataType::DT_INT32), + GenTensorList("scale", {{32, 1024}}, ge::DataType::DT_BF16), + GenTensorList("offset", {{0}}, ge::DataType::DT_FLOAT), + GenTensorList("antiquant_scale", {{0}}, 
ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{360, 1024}}, ge::DataType::DT_BF16)}, + GenTensor("per_token_scale", {360}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {32}, ge::DataType::DT_INT64), + {11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 11, 12, 11, 12, 11, 12, + 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 11, 12, 11, 12, 11, 12}, 2, 1, false, false, 0, 1, 0), + 0), + GroupedMatmulCase( /* per token quant weight NZ */ + "GroupedMatmul_Case11", true, "", /* CaseName, Enable, DebugInfo */ + OpInfo(ControlInfo(true, false), + ExpectInfo(true, 0, + ExpectInfo::kFullTilingBlockDim)), /* ExpectSuccess, ExpectTilingKey, ExpectTilingBlockDim */ + Param({GenTensorList("x", {{120, 8192}}, ge::DataType::DT_INT8), + GenTensorList("weight", {{4, 32, 512, 16, 32}}, ge::DataType::DT_INT8, ge::FORMAT_FRACTAL_NZ), + GenTensorList("bias", {{0}}, ge::DataType::DT_INT32), + GenTensorList("scale", {{4, 1024}}, ge::DataType::DT_BF16), + GenTensorList("offset", {{0}}, ge::DataType::DT_FLOAT), + GenTensorList("antiquant_scale", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{0}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{120, 1024}}, ge::DataType::DT_BF16)}, + GenTensor("per_token_scale", {120}, ge::DataType::DT_FLOAT), + GenTensor("grouped_list", {4}, ge::DataType::DT_INT64), + {30, 30, 30, 30}, 2, 1, false, false, 0, 1, 0), + 0)); + +INSTANTIATE_TEST_SUITE_P(GroupedMatmul, Ts_GroupedMatmul_WithParam_Ascend910B3, Tc_GroupedMatmul_Tiling_Case); +} // namespace \ No newline at end of file diff --git a/tests/ut/ops_test/src/transformer/grouped_matmul/utest_aclnn/ts_aclnn_grouped_matmul.cpp b/tests/ut/ops_test/src/transformer/grouped_matmul/utest_aclnn/ts_aclnn_grouped_matmul.cpp new file mode 100644 index 00000000..ecdc1dab --- /dev/null +++ b/tests/ut/ops_test/src/transformer/grouped_matmul/utest_aclnn/ts_aclnn_grouped_matmul.cpp @@ -0,0 +1,385 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file ts_aclnn_grouped_matmul.cpp + * \brief GroupedMatmul ACLNN 测试用例. 
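+ * Note: each case pairs a FunctionType (NO_QUANT, QUANT, QUANT_PERTOKEN or ANTIQUANT) with an aclnn + * interface version (V1-V4). Cases constructed with ExpectInfo(false, ...) are negative cases that + * are expected to be rejected; kInvalidTilingKey and kInvalidTilingBlockDim presumably mean that the + * tiling key and block dim are not asserted for these cases.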
+ */ + +#include "ts_aclnn_grouped_matmul.h" + +namespace { +TEST_P(Ts_Aclnn_GroupedMatmul_WithParam_Ascend910B3, Tc_Aclnn_GroupedMatmul) +{ + ASSERT_TRUE(case_->Init()); + ASSERT_EQ(case_->Run(), case_->mOpInfo.mExp.mSuccess); +} + +const auto Tc_GroupedMatmul_Aclnn_Case = ::testing::Values( + + AclnnGroupedMatmulCase("Test_GMMV1_SPLIT_0", true, "", /* CaseName,Enable,DebugInfo */ + OpInfo(ControlInfo(true, true), /* RunTiling,RunKernel */ + ExpectInfo(true, /* ExpectSuccess */ + ExpectInfo::kInvalidTilingKey, /* ExpectTilingKey */ + ExpectInfo::kInvalidTilingBlockDim)), /* ExpectTilingBlockDim */ + AclnnGroupedMatmulParam({GenTensorList("x", {{50, 5120}, {65, 5120}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{5120, 2560}, {5120, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("bias", {{2560}, {2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{50, 2560}, {65, 2560}}, ge::DataType::DT_FLOAT16)}, + GenTensor("grouped_list", {0}, ge::DataType::DT_INT64), + {}, 0, -1, false, false, -1, 1, 0, FunctionType::NO_QUANT, + AclnnGroupedMatmulVersion::V1)), + AclnnGroupedMatmulCase("Test_GMMV1_SPLIT_3", true, "", /* CaseName,Enable,DebugInfo */ + OpInfo(ControlInfo(true, true), /* RunTiling,RunKernel */ + ExpectInfo(true, /* ExpectSuccess */ + ExpectInfo::kInvalidTilingKey, /* ExpectTilingKey */ + ExpectInfo::kInvalidTilingBlockDim)), /* ExpectTilingBlockDim */ + AclnnGroupedMatmulParam({GenTensorList("x", {{50, 5120}, {15, 5120}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{5120, 2560}, {5120, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("bias", {{2560}, {2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{65, 2560}}, ge::DataType::DT_FLOAT16)}, + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {50, 65}, 2, -1, false, false, 0, 1, 0, FunctionType::NO_QUANT, + AclnnGroupedMatmulVersion::V2)), + AclnnGroupedMatmulCase("Test_GMMV2_01", true, "", /* CaseName,Enable,DebugInfo */ + OpInfo(ControlInfo(true, true), /* RunTiling,RunKernel */ + ExpectInfo(true, /* ExpectSuccess */ + ExpectInfo::kInvalidTilingKey, /* ExpectTilingKey */ + ExpectInfo::kInvalidTilingBlockDim)), /* ExpectTilingBlockDim */ + AclnnGroupedMatmulParam({GenTensorList("x", {{84, 5120}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{4, 5120, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("bias", {{4, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{84, 2560}}, ge::DataType::DT_FLOAT16)}, + GenTensor("grouped_list", {4}, ge::DataType::DT_INT64), + {50, 65, 69, 84}, 3, -1, false, false, 0, 1, 0, FunctionType::NO_QUANT, + AclnnGroupedMatmulVersion::V3)), + AclnnGroupedMatmulCase("Test_GMMV4_01", true, "", /* CaseName,Enable,DebugInfo */ + OpInfo(ControlInfo(true, true), /* RunTiling,RunKernel */ + ExpectInfo(true, /* ExpectSuccess */ + ExpectInfo::kInvalidTilingKey, /* ExpectTilingKey */ + ExpectInfo::kInvalidTilingBlockDim)), /* ExpectTilingBlockDim */ + AclnnGroupedMatmulParam({GenTensorList("x", {{84, 5120}}, ge::DataType::DT_INT8), + GenTensorList("weight", {{4, 5120, 2560}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{4, 2560}}, ge::DataType::DT_INT32), + GenTensorList("scale", {{4, 2560}}, ge::DataType::DT_INT64), + GenTensorList("y", {{84, 2560}}, ge::DataType::DT_INT8)}, + GenTensor("grouped_list", {4}, ge::DataType::DT_INT64), + {50, 65, 69, 84}, 3, -1, false, false, 0, 1, 0, FunctionType::QUANT, + AclnnGroupedMatmulVersion::V4)), + AclnnGroupedMatmulCase("Test_GMMV4_02", true, "", /* CaseName,Enable,DebugInfo */ + OpInfo(ControlInfo(true, true), /* 
RunTiling,RunKernel */ + ExpectInfo(true, /* ExpectSuccess */ + ExpectInfo::kInvalidTilingKey, /* ExpectTilingKey */ + ExpectInfo::kInvalidTilingBlockDim)), /* ExpectTilingBlockDim */ + AclnnGroupedMatmulParam({GenTensorList("x", {{84, 5120}}, ge::DataType::DT_INT8), + GenTensorList("weight", {{4, 5120, 2560}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{4, 2560}}, ge::DataType::DT_INT32), + GenTensorList("scale", {{4, 2560}}, ge::DataType::DT_FLOAT), + GenTensorList("y", {{84, 2560}}, ge::DataType::DT_FLOAT16)}, + GenTensor("grouped_list", {4}, ge::DataType::DT_INT64), + {50, 65, 69, 84}, 3, -1, true, false, 0, 1, 0, FunctionType::QUANT, + AclnnGroupedMatmulVersion::V4)), + AclnnGroupedMatmulCase("Test_GMMV4_03", true, "", /* CaseName,Enable,DebugInfo */ + OpInfo(ControlInfo(true, true), /* RunTiling,RunKernel */ + ExpectInfo(true, /* ExpectSuccess */ + ExpectInfo::kInvalidTilingKey, /* ExpectTilingKey */ + ExpectInfo::kInvalidTilingBlockDim)), /* ExpectTilingBlockDim */ + AclnnGroupedMatmulParam({GenTensorList("x", {{84, 5120}}, ge::DataType::DT_INT8), + GenTensorList("weight", {{4, 5120, 2560}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{4, 2560}}, ge::DataType::DT_INT32), + GenTensorList("scale", {{4, 2560}}, ge::DataType::DT_BF16), + GenTensorList("y", {{84, 2560}}, ge::DataType::DT_BF16)}, + GenTensor("grouped_list", {4}, ge::DataType::DT_INT64), + {50, 65, 69, 84}, 3, -1, false, false, 0, 1, 0, FunctionType::QUANT, + AclnnGroupedMatmulVersion::V4)), + AclnnGroupedMatmulCase("Test_GMMV4_04", true, "", /* CaseName,Enable,DebugInfo */ + OpInfo(ControlInfo(true, true), /* RunTiling,RunKernel */ + ExpectInfo(true, /* ExpectSuccess */ + ExpectInfo::kInvalidTilingKey, /* ExpectTilingKey */ + ExpectInfo::kInvalidTilingBlockDim)), /* ExpectTilingBlockDim */ + AclnnGroupedMatmulParam({GenTensorList("x", {{84, 5120}}, ge::DataType::DT_INT8), + GenTensorList("weight", {{4, 5120, 2560}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{4, 2560}}, ge::DataType::DT_INT32), + GenTensorList("scale", {{4, 2560}}, ge::DataType::DT_BF16), + GenTensorList("pertoken_scale", {{84}}, ge::DataType::DT_FLOAT), + GenTensorList("y", {{84, 2560}}, ge::DataType::DT_BF16)}, + GenTensor("grouped_list", {4}, ge::DataType::DT_INT64), + {50, 65, 69, 84}, 3, -1, false, false, 0, 1, 0, FunctionType::QUANT_PERTOKEN, + AclnnGroupedMatmulVersion::V4)), + AclnnGroupedMatmulCase("Test_GMMV4_05", true, "", /* CaseName,Enable,DebugInfo */ + OpInfo(ControlInfo(true, true), /* RunTiling,RunKernel */ + ExpectInfo(true, /* ExpectSuccess */ + ExpectInfo::kInvalidTilingKey, /* ExpectTilingKey */ + ExpectInfo::kInvalidTilingBlockDim)), /* ExpectTilingBlockDim */ + AclnnGroupedMatmulParam({GenTensorList("x", {{84, 5120}}, ge::DataType::DT_INT8), + GenTensorList("weight", {{4, 5120, 2560}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{4, 2560}}, ge::DataType::DT_INT32), + GenTensorList("scale", {{4, 2560}}, ge::DataType::DT_FLOAT), + GenTensorList("pertoken_scale", {{84}}, ge::DataType::DT_FLOAT), + GenTensorList("y", {{84, 2560}}, ge::DataType::DT_FLOAT16)}, + GenTensor("grouped_list", {4}, ge::DataType::DT_INT64), + {50, 65, 69, 84}, 3, -1, false, false, 0, 1, 0, FunctionType::QUANT_PERTOKEN, + AclnnGroupedMatmulVersion::V4)), + AclnnGroupedMatmulCase("Test_GMMV4_06", true, "", /* CaseName,Enable,DebugInfo */ + OpInfo(ControlInfo(true, true), /* RunTiling,RunKernel */ + ExpectInfo(true, /* ExpectSuccess */ + ExpectInfo::kInvalidTilingKey, /* ExpectTilingKey */ + ExpectInfo::kInvalidTilingBlockDim)), /* 
ExpectTilingBlockDim */ + AclnnGroupedMatmulParam({GenTensorList("x", {{84, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{4, 2560, 2560}}, ge::DataType::DT_INT8), + GenTensorList("bias", {{4, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_scale", {{4, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{4, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{84, 2560}}, ge::DataType::DT_FLOAT16)}, + GenTensor("grouped_list", {4}, ge::DataType::DT_INT64), + {50, 65, 69, 84}, 3, -1, false, false, 0, 1, 0, FunctionType::ANTIQUANT, + AclnnGroupedMatmulVersion::V4)), + AclnnGroupedMatmulCase("Test_GMMV2_07", true, "", /* CaseName,Enable,DebugInfo */ + OpInfo(ControlInfo(true, true), /* RunTiling,RunKernel */ + ExpectInfo(true, /* ExpectSuccess */ + ExpectInfo::kInvalidTilingKey, /* ExpectTilingKey */ + ExpectInfo::kInvalidTilingBlockDim)), /* ExpectTilingBlockDim */ + AclnnGroupedMatmulParam({GenTensorList("x", {{50, 5120}, {15, 5120}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{5120, 2560}, {5120, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("bias", {{2560}, {2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{50, 2560}, {15, 2560}}, ge::DataType::DT_FLOAT16)}, + GenTensor("grouped_list", {0}, ge::DataType::DT_INT64), + {}, 0, -1, false, false, -1, 1, 0, FunctionType::NO_QUANT, + AclnnGroupedMatmulVersion::V2)), + AclnnGroupedMatmulCase("Test_GMMV1_08", true, "", /* CaseName,Enable,DebugInfo */ + OpInfo(ControlInfo(true, true), /* RunTiling,RunKernel */ + ExpectInfo(true, /* ExpectSuccess */ + ExpectInfo::kInvalidTilingKey, /* ExpectTilingKey */ + ExpectInfo::kInvalidTilingBlockDim)), /* ExpectTilingBlockDim */ + AclnnGroupedMatmulParam({GenTensorList("x", {{50, 5120}, {15, 5120}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{5120, 2560}, {5120, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("bias", {{2560}, {2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{65, 2560}}, ge::DataType::DT_FLOAT16)}, + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {50, 65}, 3, -1, false, false, 0, 0, 0, FunctionType::NO_QUANT, + AclnnGroupedMatmulVersion::V1)), + AclnnGroupedMatmulCase("Test_GMMV4_09", true, "", /* CaseName,Enable,DebugInfo */ + OpInfo(ControlInfo(true, true), /* RunTiling,RunKernel */ + ExpectInfo(true, /* ExpectSuccess */ + ExpectInfo::kInvalidTilingKey, /* ExpectTilingKey */ + ExpectInfo::kInvalidTilingBlockDim)), /* ExpectTilingBlockDim */ + AclnnGroupedMatmulParam({GenTensorList("x", {{65, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{2560, 2560}, {2560, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{65, 2560}}, ge::DataType::DT_FLOAT16)}, + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {50, 65}, 3, -1, false, false, 0, 0, 0, FunctionType::NO_QUANT, + AclnnGroupedMatmulVersion::V4)), + AclnnGroupedMatmulCase("Test_GMMV4_10", true, "", /* CaseName,Enable,DebugInfo */ + OpInfo(ControlInfo(true, true), /* RunTiling,RunKernel */ + ExpectInfo(true, /* ExpectSuccess */ + ExpectInfo::kInvalidTilingKey, /* ExpectTilingKey */ + ExpectInfo::kInvalidTilingBlockDim)), /* ExpectTilingBlockDim */ + AclnnGroupedMatmulParam({GenTensorList("x", {{65, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{2560, 2560}, {2560, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{50, 2560}, {15, 2560}}, ge::DataType::DT_FLOAT16)}, + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {50, 65}, 0, -1, false, false, 0, 0, 0, 
FunctionType::NO_QUANT, + AclnnGroupedMatmulVersion::V1)), + AclnnGroupedMatmulCase("Test_GMMV4_11", true, "", /* CaseName,Enable,DebugInfo */ + OpInfo(ControlInfo(true, true), /* RunTiling,RunKernel */ + ExpectInfo(true, /* ExpectSuccess */ + ExpectInfo::kInvalidTilingKey, /* ExpectTilingKey */ + ExpectInfo::kInvalidTilingBlockDim)), /* ExpectTilingBlockDim */ + AclnnGroupedMatmulParam({GenTensorList("x", {{1280, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{2560, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{2, 1280, 2560}}, ge::DataType::DT_FLOAT16)}, + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {1280, 1280}, 3, -1, false, true, 2, 1, 0, FunctionType::NO_QUANT, + AclnnGroupedMatmulVersion::V4)), + AclnnGroupedMatmulCase("Test_GMMV4_12", true, "", /* CaseName,Enable,DebugInfo */ + OpInfo(ControlInfo(true, true), /* RunTiling,RunKernel */ + ExpectInfo(true, /* ExpectSuccess */ + ExpectInfo::kInvalidTilingKey, /* ExpectTilingKey */ + ExpectInfo::kInvalidTilingBlockDim)), /* ExpectTilingBlockDim */ + AclnnGroupedMatmulParam({GenTensorList("x", {{1280, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{2560, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{2, 1280, 2560}}, ge::DataType::DT_FLOAT16)}, + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {1280, 2560}, 3, -1, false, true, 2, 1, 0, FunctionType::NO_QUANT, + AclnnGroupedMatmulVersion::V2)), + AclnnGroupedMatmulCase("Test_GMMV4_13", true, "", /* CaseName,Enable,DebugInfo */ + OpInfo(ControlInfo(true, true), /* RunTiling,RunKernel */ + ExpectInfo(true, /* ExpectSuccess */ + ExpectInfo::kInvalidTilingKey, /* ExpectTilingKey */ + ExpectInfo::kInvalidTilingBlockDim)), /* ExpectTilingBlockDim */ + AclnnGroupedMatmulParam({GenTensorList("x", {{16, 16}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{2, 16, 10}}, ge::DataType::DT_INT4), + GenTensorList("bias", {{2, 10}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_scale", {{2, 2, 10}}, ge::DataType::DT_FLOAT16), + GenTensorList("antiquant_offset", {{2, 2, 10}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{16, 10}}, ge::DataType::DT_FLOAT16)}, + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {8, 8}, 3, -1, false, false, 0, 1, 0, FunctionType::ANTIQUANT, + AclnnGroupedMatmulVersion::V4)), + AclnnGroupedMatmulCase("Test_GMMV4_SPLIT_0", true, "", /* CaseName,Enable,DebugInfo */ + OpInfo(ControlInfo(true, true), /* RunTiling,RunKernel */ + ExpectInfo(false, /* ExpectSuccess */ + ExpectInfo::kInvalidTilingKey, /* ExpectTilingKey */ + ExpectInfo::kInvalidTilingBlockDim)), /* ExpectTilingBlockDim */ + AclnnGroupedMatmulParam({GenTensorList("x", {{50, 5120}, {65, 5120}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{5120, 2560}, {5120, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("bias", {{2560}, {2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{50, 2560}, {65, 2560}}, ge::DataType::DT_FLOAT16)}, + GenTensor("grouped_list", {0}, ge::DataType::DT_INT64), + {}, 0, -1, false, false, -1, 1, 6, FunctionType::NO_QUANT, + AclnnGroupedMatmulVersion::V4)), + AclnnGroupedMatmulCase("Test_GMMV1_Error_0", true, "", /* CaseName,Enable,DebugInfo */ + OpInfo(ControlInfo(true, true), /* RunTiling,RunKernel */ + ExpectInfo(false, /* ExpectSuccess */ + ExpectInfo::kInvalidTilingKey, /* ExpectTilingKey */ + ExpectInfo::kInvalidTilingBlockDim)), /* ExpectTilingBlockDim */ + AclnnGroupedMatmulParam({GenTensorList("x", {{65, 5120}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{2,
5120, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("bias", {{2, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{65, 2560}}, ge::DataType::DT_FLOAT16)}, + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {50, 65}, 1, -1, false, false, -1, 1, 0, FunctionType::NO_QUANT, + AclnnGroupedMatmulVersion::V1)), + AclnnGroupedMatmulCase("Test_GMMV1_Error_1", true, "", /* CaseName,Enable,DebugInfo */ + OpInfo(ControlInfo(true, true), /* RunTiling,RunKernel */ + ExpectInfo(false, /* ExpectSuccess */ + ExpectInfo::kInvalidTilingKey, /* ExpectTilingKey */ + ExpectInfo::kInvalidTilingBlockDim)), /* ExpectTilingBlockDim */ + AclnnGroupedMatmulParam({GenTensorList("x", {{50, 5120}, {65, 5120}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{5120, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("bias", {{2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{50, 2560}, {65, 2560}}, ge::DataType::DT_FLOAT16)}, + GenTensor("grouped_list", {0}, ge::DataType::DT_INT64), + {}, 0, -1, false, false, -1, 1, 0, FunctionType::NO_QUANT, + AclnnGroupedMatmulVersion::V1)), + AclnnGroupedMatmulCase("Test_GMMV2_Error_0", true, "", /* CaseName,Enable,DebugInfo */ + OpInfo(ControlInfo(true, true), /* RunTiling,RunKernel */ + ExpectInfo(false, /* ExpectSuccess */ + ExpectInfo::kInvalidTilingKey, /* ExpectTilingKey */ + ExpectInfo::kInvalidTilingBlockDim)), /* ExpectTilingBlockDim */ + AclnnGroupedMatmulParam({GenTensorList("x", {{50, 512}, {15, 512}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{512, 256}, {512, 256}}, ge::DataType::DT_FLOAT16), + GenTensorList("bias", {{256}, {256}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{65, 256}}, ge::DataType::DT_FLOAT16)}, + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {50, 65}, 2, -1, false, false, 1, 1, 0, FunctionType::NO_QUANT, + AclnnGroupedMatmulVersion::V2)), + AclnnGroupedMatmulCase("Test_GMMV2_Error_1", true, "", /* CaseName,Enable,DebugInfo */ + OpInfo(ControlInfo(true, true), /* RunTiling,RunKernel */ + ExpectInfo(false, /* ExpectSuccess */ + ExpectInfo::kInvalidTilingKey, /* ExpectTilingKey */ + ExpectInfo::kInvalidTilingBlockDim)), /* ExpectTilingBlockDim */ + AclnnGroupedMatmulParam({GenTensorList("x", {{2, 128, 256}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{256, 256}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{2, 128, 256}}, ge::DataType::DT_FLOAT16)}, + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {128, 256}, 3, -1, false, true, 2, 1, 0, FunctionType::NO_QUANT, + AclnnGroupedMatmulVersion::V2)), + AclnnGroupedMatmulCase("Test_GMMV2_Error_2", true, "", /* CaseName,Enable,DebugInfo */ + OpInfo(ControlInfo(true, true), /* RunTiling,RunKernel */ + ExpectInfo(false, /* ExpectSuccess */ + ExpectInfo::kInvalidTilingKey, /* ExpectTilingKey */ + ExpectInfo::kInvalidTilingBlockDim)), /* ExpectTilingBlockDim */ + AclnnGroupedMatmulParam({GenTensorList("x", {{1280, 1280}, {1280, 1280}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{2560, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{2, 1280, 2560}}, ge::DataType::DT_FLOAT16)}, + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {1280, 2560}, 3, -1, false, true, 2, 1, 0, FunctionType::NO_QUANT, + AclnnGroupedMatmulVersion::V2)), + AclnnGroupedMatmulCase("Test_GMMV2_Error_3", true, "", /* CaseName,Enable,DebugInfo */ + OpInfo(ControlInfo(true, true), /* RunTiling,RunKernel */ + ExpectInfo(false, /* ExpectSuccess */ + ExpectInfo::kInvalidTilingKey, /* ExpectTilingKey */ + 
ExpectInfo::kInvalidTilingBlockDim)), /* ExpectTilingBlockDim */ + AclnnGroupedMatmulParam({GenTensorList("x", {{65, 5120}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{2, 5120, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("bias", {{2, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{65, 2560}}, ge::DataType::DT_FLOAT16)}, + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {40, 50, 65}, 2, -1, false, false, 0, 1, 0, FunctionType::NO_QUANT, + AclnnGroupedMatmulVersion::V2)), + AclnnGroupedMatmulCase("Test_GMMV2_Error_4", true, "", /* CaseName,Enable,DebugInfo */ + OpInfo(ControlInfo(true, true), /* RunTiling,RunKernel */ + ExpectInfo(false, /* ExpectSuccess */ + ExpectInfo::kInvalidTilingKey, /* ExpectTilingKey */ + ExpectInfo::kInvalidTilingBlockDim)), /* ExpectTilingBlockDim */ + AclnnGroupedMatmulParam({GenTensorList("x", {{50, 5120}, {15, 5120}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{5120, 2560}, {5120, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("bias", {{2560}, {2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{40, 2560}, {25, 2560}}, ge::DataType::DT_FLOAT16)}, + GenTensor("grouped_list", {}, ge::DataType::DT_INT64), + {}, 0, -1, false, false, -1, 1, 0, FunctionType::NO_QUANT, + AclnnGroupedMatmulVersion::V2)), + AclnnGroupedMatmulCase("Test_GMMV2_Error_5", true, "", /* CaseName,Enable,DebugInfo */ + OpInfo(ControlInfo(true, true), /* RunTiling,RunKernel */ + ExpectInfo(false, /* ExpectSuccess */ + ExpectInfo::kInvalidTilingKey, /* ExpectTilingKey */ + ExpectInfo::kInvalidTilingBlockDim)), /* ExpectTilingBlockDim */ + AclnnGroupedMatmulParam({GenTensorList("x", {{50, 5120}, {15, 5120}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{5120, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("bias", {{2560}, {2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{50, 2560}, {15, 2560}}, ge::DataType::DT_FLOAT16)}, + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {50, 65}, 2, -1, false, false, 0, 1, 0, FunctionType::NO_QUANT, + AclnnGroupedMatmulVersion::V2)), + AclnnGroupedMatmulCase("Test_GMMV2_Error_6", true, "", /* CaseName,Enable,DebugInfo */ + OpInfo(ControlInfo(true, true), /* RunTiling,RunKernel */ + ExpectInfo(false, /* ExpectSuccess */ + ExpectInfo::kInvalidTilingKey, /* ExpectTilingKey */ + ExpectInfo::kInvalidTilingBlockDim)), /* ExpectTilingBlockDim */ + AclnnGroupedMatmulParam({GenTensorList("x", {{20, 512}, {20, 512}, {25, 512}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{512, 2560}, {512, 2560}, {512, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("bias", {{2560}, {2560}, {2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{65, 2560}}, ge::DataType::DT_FLOAT16)}, + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {20, 50, 65}, 2, -1, false, false, 0, 1, 0, FunctionType::NO_QUANT, + AclnnGroupedMatmulVersion::V2)), + AclnnGroupedMatmulCase("Test_GMMV4_Error_0", true, "", /* CaseName,Enable,DebugInfo */ + OpInfo(ControlInfo(true, true), /* RunTiling,RunKernel */ + ExpectInfo(false, /* ExpectSuccess */ + ExpectInfo::kInvalidTilingKey, /* ExpectTilingKey */ + ExpectInfo::kInvalidTilingBlockDim)), /* ExpectTilingBlockDim */ + AclnnGroupedMatmulParam({GenTensorList("x", {{2560, 1280}}, ge::DataType::DT_FLOAT16), + GenTensorList("weight", {{2, 2560, 2560}}, ge::DataType::DT_FLOAT16), + GenTensorList("y", {{1280, 2560}}, ge::DataType::DT_FLOAT16)}, + GenTensor("grouped_list", {2}, ge::DataType::DT_INT64), + {1280, 1280}, 3, -1, false, false, 2, 1, 0, 
+                                AclnnGroupedMatmulVersion::V4)),
+    AclnnGroupedMatmulCase("Test_GMMV4_Error_1", true, "",            /* CaseName,Enable,DebugInfo */
+        OpInfo(ControlInfo(true, true),                               /* RunTiling,RunKernel */
+               ExpectInfo(false,                                      /* ExpectSuccess */
+                          ExpectInfo::kInvalidTilingKey,              /* ExpectTilingKey */
+                          ExpectInfo::kInvalidTilingBlockDim)),       /* ExpectTilingBlockDim */
+        AclnnGroupedMatmulParam({GenTensorList("x", {{84, 512}}, ge::DataType::DT_INT8),
+                                 GenTensorList("weight", {{512, 2560}, {512, 2560}}, ge::DataType::DT_INT8),
+                                 GenTensorList("bias", {{2, 2560}}, ge::DataType::DT_INT32),
+                                 GenTensorList("scale", {{2, 2560}}, ge::DataType::DT_FLOAT),
+                                 GenTensorList("pertoken_scale", {{84}}, ge::DataType::DT_FLOAT),
+                                 GenTensorList("y", {{84, 2560}}, ge::DataType::DT_FLOAT16)},
+                                GenTensor("grouped_list", {2}, ge::DataType::DT_INT64),
+                                {50, 84}, 3, -1, false, false, 0, 1, 0, FunctionType::QUANT_PERTOKEN,
+                                AclnnGroupedMatmulVersion::V4)),
+    AclnnGroupedMatmulCase("Test_GMMV4_Error_2", true, "",            /* CaseName,Enable,DebugInfo */
+        OpInfo(ControlInfo(true, true),                               /* RunTiling,RunKernel */
+               ExpectInfo(false,                                      /* ExpectSuccess */
+                          ExpectInfo::kInvalidTilingKey,              /* ExpectTilingKey */
+                          ExpectInfo::kInvalidTilingBlockDim)),       /* ExpectTilingBlockDim */
+        AclnnGroupedMatmulParam({GenTensorList("x", {{84, 5120}}, ge::DataType::DT_INT8),
+                                 GenTensorList("weight", {{4, 5120, 2560}}, ge::DataType::DT_INT8),
+                                 GenTensorList("bias", {{4, 2560}}, ge::DataType::DT_INT32),
+                                 GenTensorList("scale", {{4, 2560}}, ge::DataType::DT_BF16),
+                                 GenTensorList("y", {{84, 2560}}, ge::DataType::DT_FLOAT)},
+                                GenTensor("grouped_list", {4}, ge::DataType::DT_INT64),
+                                {50, 65, 69, 84}, 3, -1, false, false, 0, 1, 0, FunctionType::QUANT,
+                                AclnnGroupedMatmulVersion::V4))
+);
+
+INSTANTIATE_TEST_SUITE_P(GroupedMatmul, Ts_Aclnn_GroupedMatmul_WithParam_Ascend910B3, Tc_GroupedMatmul_Aclnn_Case);
+} // namespace
\ No newline at end of file
diff --git a/tests/ut/ops_test/src/transformer/grouped_matmul/utest_aclnn/ts_aclnn_grouped_matmul.h b/tests/ut/ops_test/src/transformer/grouped_matmul/utest_aclnn/ts_aclnn_grouped_matmul.h
new file mode 100644
index 00000000..f3e50250
--- /dev/null
+++ b/tests/ut/ops_test/src/transformer/grouped_matmul/utest_aclnn/ts_aclnn_grouped_matmul.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file ts_aclnn_grouped_matmul.h
+ * \brief Base class definitions for the GroupedMatmul aclnn UTest.
+ */
+
+#ifndef UTEST_TS_ACLNN_GROUPEDMATMUL_H
+#define UTEST_TS_ACLNN_GROUPEDMATMUL_H
+
+#include "tests/utest/ts.h"
+#include "grouped_matmul_case.h"
+#include "aclnn_grouped_matmul_case.h"
+
+using ops::adv::tests::grouped_matmul::AclnnGroupedMatmulCase;
+using ops::adv::tests::grouped_matmul::GenTensor;
+using ops::adv::tests::grouped_matmul::GenTensorList;
+using AclnnGroupedMatmulParam = ops::adv::tests::grouped_matmul::AclnnGroupedMatmulParam;
+using FunctionType = ops::adv::tests::grouped_matmul::AclnnGroupedMatmulParam::FunctionType;
+using AclnnGroupedMatmulVersion = ops::adv::tests::grouped_matmul::AclnnGroupedMatmulParam::AclnnGroupedMatmulVersion;
+
+class Ts_Aclnn_GroupedMatmul_WithParam_Ascend910B3 : public Ts_WithParam_Ascend910B3 {};
+
+#endif // UTEST_TS_ACLNN_GROUPEDMATMUL_H
\ No newline at end of file
-- 
Gitee
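
Note for reviewers extending this suite: a new parameterized case is just one more AclnnGroupedMatmulCase appended to Tc_GroupedMatmul_Aclnn_Case before the INSTANTIATE_TEST_SUITE_P call. A minimal sketch follows; the case name, shapes, and trailing attribute values are illustrative assumptions (the attributes mirror Test_GMMV4_Error_0) and are not part of this patch:

    // Hypothetical V4 negative case (illustrative only): grouped_list declares 2 groups
    // while the 3D weight carries 3, so tiling is expected to be rejected.
    AclnnGroupedMatmulCase("Test_GMMV4_Error_3", true, "",            /* CaseName,Enable,DebugInfo */
        OpInfo(ControlInfo(true, true),                               /* RunTiling,RunKernel */
               ExpectInfo(false,                                      /* ExpectSuccess */
                          ExpectInfo::kInvalidTilingKey,              /* ExpectTilingKey */
                          ExpectInfo::kInvalidTilingBlockDim)),       /* ExpectTilingBlockDim */
        AclnnGroupedMatmulParam({GenTensorList("x", {{64, 512}}, ge::DataType::DT_FLOAT16),
                                 GenTensorList("weight", {{3, 512, 256}}, ge::DataType::DT_FLOAT16),
                                 GenTensorList("y", {{64, 256}}, ge::DataType::DT_FLOAT16)},
                                GenTensor("grouped_list", {2}, ge::DataType::DT_INT64),
                                {32, 64}, 3, -1, false, false, 2, 1, 0, FunctionType::NO_QUANT,
                                AclnnGroupedMatmulVersion::V4))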