8 Star 123 Fork 5

Ascend/cann-var-sequence-gemm

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
batchgemm_main.cpp 3.40 KB
一键复制 编辑 原始数据 按行查看 历史
/*
Copyright (c) 2025 Huawei Technologies Co., Ltd.
This file is a part of the CANN Open Software.
Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
Please refer to the License for details. You may not use this file except in compliance with the License.
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
See LICENSE in the root of the software repository for the full text of the License.
*/
#include "acl/acl.h"
#include "../../include/custom_type.h"
#include "../data_utils.h"
#include "../run_main.h"
#include "./batchgemm_host.h"
#include <cerrno>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <string>
// Parses a base-10 command-line argument into a uint32_t.
// Returns false (leaving `out` untouched) when the text is null/empty, contains
// trailing non-digit characters, is negative, or does not fit in uint32_t.
static bool parseUint32Arg(const char *text, uint32_t &out)
{
    if (text == nullptr || *text == '\0') {
        return false;
    }
    char *end = nullptr;
    errno = 0;
    const long long value = std::strtoll(text, &end, 10);
    if (errno != 0 || end == text || *end != '\0' || value < 0 ||
        value > static_cast<long long>(UINT32_MAX)) {
        return false;
    }
    out = static_cast<uint32_t>(value);
    return true;
}

// Host-side test driver for the batched GEMM kernel.
// Usage: <prog> deviceId layoutA layoutB zeroPaddingM zeroPaddingN zeroPaddingK batchCount
//   deviceId                  : device index in [0, 7]
//   layoutA / layoutB         : 0 = RowMajor, 1 = ColumnMajor
//   zeroPaddingM/N/K          : positive integers forwarded to run_main
//   batchCount                : positive integer forwarded to run_main
// Input/reference data is read from files under ./data/ (see the paths below).
// Returns 0 after the test has run; EXIT_FAILURE on missing or invalid arguments.
int main(int argc, char **argv)
{
    // Read the inputs: deviceId layoutA layoutB zeroPaddingM zeroPaddingN zeroPaddingK batchCount.
    constexpr int kExpectedArgc = 8;
    if (argc < kExpectedArgc) {
        printf("Usage: %s deviceId layoutA layoutB zeroPaddingM zeroPaddingN zeroPaddingK batchCount\n",
               argc > 0 ? argv[0] : "batchgemm");
        return EXIT_FAILURE;
    }

    uint32_t deviceId = 0;
    uint32_t layoutAValue = 0;
    uint32_t layoutBValue = 0;
    uint32_t zeroPaddingM = 0;
    uint32_t zeroPaddingN = 0;
    uint32_t zeroPaddingK = 0;
    uint32_t batchCount = 0;
    // Parse strictly instead of std::stoi: stoi throws (uncaught -> terminate) on
    // non-numeric text, and negative values would silently wrap when stored unsigned.
    if (!parseUint32Arg(argv[1], deviceId) ||
        !parseUint32Arg(argv[2], layoutAValue) ||
        !parseUint32Arg(argv[3], layoutBValue) ||
        !parseUint32Arg(argv[4], zeroPaddingM) ||
        !parseUint32Arg(argv[5], zeroPaddingN) ||
        !parseUint32Arg(argv[6], zeroPaddingK) ||
        !parseUint32Arg(argv[7], batchCount)) {
        printf("Wrong input! \n");
        return EXIT_FAILURE;
    }
    layoutType layoutA = (layoutAValue == 0 ? RowMajor : ColumnMajor);
    layoutType layoutB = (layoutBValue == 0 ? RowMajor : ColumnMajor);

    // Echo the parsed inputs: layoutA layoutB zeroPaddingM zeroPaddingN zeroPaddingK batchCount.
    printf("\nTesting kernel on device %u. \n"
           "Getting test input: \n"
           "layoutA: %s, layoutB: %s, "
           "zeroPaddingM: %u, zeroPaddingN: %u, zeroPaddingK: %u, "
           "batchCount: %u. \n",
           deviceId,
           layoutA == RowMajor ? "RowMajor" : "ColumnMajor",
           layoutB == RowMajor ? "RowMajor" : "ColumnMajor",
           zeroPaddingM, zeroPaddingN, zeroPaddingK,
           batchCount);

    // Values are unsigned after parsing, so "positive" means non-zero. The layout
    // codes are checked on the raw input: only 0 and 1 are meaningful.
    constexpr uint32_t kMaxDeviceId = 7;
    if (deviceId > kMaxDeviceId
        || layoutAValue > 1
        || layoutBValue > 1
        || zeroPaddingM == 0
        || zeroPaddingN == 0
        || zeroPaddingK == 0
        || batchCount == 0) {
        printf("Wrong input! \n");
        return EXIT_FAILURE;
    }

    // ACL initialization. nullptr: no JSON configuration file is supplied.
    ACL_CHECK(aclInit(nullptr));
    ACL_CHECK(aclrtSetDevice(deviceId));
    aclrtStream stream;
    ACL_CHECK(aclrtCreateStream(&stream));

    std::string src = "./data/";
    std::string pathA = src + "A.bin";
    std::string pathB = src + "B.bin";
    std::string pathC = src + "C.bin";
    std::string pathAlpha = src + "alpha.bin";
    std::string pathBeta = src + "beta.bin";
    std::string pathMaskA = src + "maskA.bin";
    std::string pathExpectResult = src + "expect_result.bin";

    // NOTE(review): run_main appears to fill curOutput (it is freed with aclrtFree
    // below); confirm whether preOutput also needs releasing — depends on run_main.
    half* preOutput = nullptr;
    half* curOutput = nullptr;
    run_main(
        batchgemm_host,
        stream,
        layoutA,
        layoutB,
        zeroPaddingM,
        zeroPaddingN,
        zeroPaddingK,
        batchCount,
        pathA,
        pathB,
        pathC,
        pathAlpha,
        pathBeta,
        pathMaskA,
        pathExpectResult,
        preOutput,
        curOutput,
        0,  // NOTE(review): meaning of the trailing 0 arguments is defined by run_main — confirm.
        0
    );
    ACL_CHECK(aclrtFree(curOutput));

    // Deinitialization.
    ACL_CHECK(aclrtDestroyStream(stream));
    ACL_CHECK(aclrtResetDevice(deviceId));
    ACL_CHECK(aclFinalize());
    return 0;
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/ascend/cann-var-sequence-gemm.git
git@gitee.com:ascend/cann-var-sequence-gemm.git
ascend
cann-var-sequence-gemm
cann-var-sequence-gemm
master

搜索帮助