代码拉取完成,页面将自动刷新
/*
Copyright (c) 2025 Huawei Technologies Co., Ltd.
This file is a part of the CANN Open Software.
Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
Please refer to the License for details. You may not use this file except in compliance with the License.
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
See LICENSE in the root of the software repository for the full text of the License.
*/
#include "acl/acl.h"
#include "../../include/custom_type.h"
#include "../data_utils.h"
#include "../run_main.h"
#include "./batchgemm_host.h"
#include <cerrno>
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <string>
int main ( int argc, char ** argv ) {
// 获取输入 deviceId layoutA layoutB zeroPaddingM zeroPaddingN zeroPaddingK batchCount
uint32_t deviceId = std::stoi( argv[1] );
layoutType layoutA = ( std::stoi( argv[2] ) == 0 ? RowMajor : ColumnMajor );
layoutType layoutB = ( std::stoi( argv[3] ) == 0 ? RowMajor : ColumnMajor );
uint32_t zeroPaddingM = std::stoi( argv[4] );
uint32_t zeroPaddingN = std::stoi( argv[5] );
uint32_t zeroPaddingK = std::stoi( argv[6] );
uint32_t batchCount = std::stoi( argv[7] );
// 打印输出 layoutA layoutB zeroPaddingM zeroPaddingN zeroPaddingK batchCount
printf("\nTesting kernel on device %d. \n"
"Getting test input: \n"
"layoutA: %s, layoutB: %s, "
"zeroPaddingM: %d, zeroPaddingN: %d, zeroPaddingK: %d, "
"batchCount: %d. \n",
deviceId,
layoutA == 0 ? "RowMajor" : "ColumnMajor", layoutB == 0 ? "RowMajor" : "ColumnMajor",
zeroPaddingM, zeroPaddingN, zeroPaddingK,
batchCount);
if(deviceId < 0 || deviceId > 7
|| !(layoutA == RowMajor || layoutA == ColumnMajor)
|| !(layoutB == RowMajor || layoutB == ColumnMajor)
|| zeroPaddingM <= 0
|| zeroPaddingN <= 0
|| zeroPaddingK <= 0
|| batchCount <= 0){
printf("Wrong input! \n");
return 0;
}
// acl初始化
const char *aclConfigPath = "../acl.json";
ACL_CHECK(aclInit(/*aclConfigPath*/nullptr));
ACL_CHECK(aclrtSetDevice(deviceId));
aclrtStream stream;
ACL_CHECK(aclrtCreateStream(&stream));
std::string src="./data/";
std::string pathA = src + "A.bin";
std::string pathB = src + "B.bin";
std::string pathC = src + "C.bin";
std::string pathAlpha = src + "alpha.bin";
std::string pathBeta = src + "beta.bin";
std::string pathMaskA = src + "maskA.bin";
std::string pathExpectResult = src + "expect_result.bin";
half* preOutput = nullptr;
half* curOutput = nullptr;
run_main(
batchgemm_host,
stream,
layoutA,
layoutB,
zeroPaddingM,
zeroPaddingN,
zeroPaddingK,
batchCount,
pathA,
pathB,
pathC,
pathAlpha,
pathBeta,
pathMaskA,
pathExpectResult,
preOutput,
curOutput,
0,
0
);
ACL_CHECK(aclrtFree(curOutput));
// 反初始化
ACL_CHECK(aclrtDestroyStream(stream));
ACL_CHECK(aclrtResetDevice(deviceId));
ACL_CHECK(aclFinalize());
return 0;
}
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。