224 Star 1.3K Fork 1.1K

Ascend/samples

加入 Gitee
与超过 1400万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
data_utils.h 6.15 KB
一键复制 编辑 原始数据 按行查看 历史
诸葛文洵 提交于 2024-08-14 10:22 +08:00 . change names from MatMul to Matmul
/**
* @file data_utils.h
*
* Copyright (C) 2023-2024. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
*/
#ifndef DATA_UTILS_H
#define DATA_UTILS_H
#include <fcntl.h>
#include <sys/stat.h>
#include <unistd.h>
#include <cassert>
#include <cerrno>
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <string>
#include <vector>
#include "acl/acl.h"
/**
 * @brief Element type selector understood by PrintData.
 *
 * The enumerators are unscoped, so they are visible as plain names (FLOAT, HALF, ...)
 * and convert implicitly to int.
 *
 * NOTE(review): the numeric values (including the gaps at 5, 14-15 and 18-26) look like
 * they mirror aclDataType - confirm before renumbering anything.
 */
enum printDataType {
    DT_UNDEFINED = -1,
    FLOAT = 0,
    HALF = 1,
    INT8_T = 2,
    INT32_T = 3,
    UINT8_T = 4,
    INT16_T = 6,
    UINT16_T = 7,
    UINT32_T = 8,
    INT64_T = 9,
    UINT64_T = 10,
    DOUBLE = 11,
    BOOL = 12,
    STRING = 13,
    COMPLEX64 = 16,
    COMPLEX128 = 17,
    BF16 = 27
};
// Logging helpers: printf-style format string plus optional arguments, written to stdout
// with a severity tag in front and a newline appended. `fmt` must be a string literal
// because it is concatenated with the tag at compile time. `args...` / `##args` is the
// GNU named-variadic-macro extension.
#define INFO_LOG(fmt, args...) fprintf(stdout, "[INFO] " fmt "\n", ##args)
#define WARN_LOG(fmt, args...) fprintf(stdout, "[WARN] " fmt "\n", ##args)
#define ERROR_LOG(fmt, args...) fprintf(stdout, "[ERROR] " fmt "\n", ##args)

// Evaluates the ACL call `x` exactly once and, when the result is not ACL_ERROR_NONE,
// prints the source location and the error code to std::cerr. Execution continues in
// either case; the macro only reports.
//
// The expansion is a single `do { } while (0)` statement WITHOUT a trailing semicolon, so
// the macro is used as `CHECK_ACL(call);` and composes correctly with an unbraced if/else.
// The local name avoids the reserved `__` prefix, and `x` is parenthesised so any
// expression can be passed.
#define CHECK_ACL(x)                                                                               \
    do {                                                                                           \
        const aclError checkAclRet_ = (x);                                                         \
        if (checkAclRet_ != ACL_ERROR_NONE) {                                                      \
            std::cerr << __FILE__ << ":" << __LINE__ << " aclError:" << checkAclRet_ << std::endl; \
        }                                                                                          \
    } while (0)
/**
 * @brief Read a whole regular file into a caller-provided buffer.
 * @param [in] filePath: path of the file to read
 * @param [out] fileSize: number of bytes read; written only when the call succeeds
 * @param [out] buffer: destination memory, must not be nullptr
 * @param [in] bufferSize: capacity of buffer in bytes
 * @return true when the complete file was copied into buffer; false when buffer is
 *         nullptr, the path cannot be stat'ed, is not a regular file, cannot be opened,
 *         is empty, does not fit into bufferSize, or could not be read completely
 *
 * Declared inline because the definition lives in a header that may be included from
 * several translation units.
 */
inline bool ReadFile(const std::string &filePath, size_t &fileSize, void *buffer, size_t bufferSize)
{
    if (buffer == nullptr) {
        ERROR_LOG("Read file failed. buffer is nullptr");
        return false;
    }

    struct stat sBuf;
    if (stat(filePath.c_str(), &sBuf) == -1) {
        ERROR_LOG("failed to get file");
        return false;
    }
    if (S_ISREG(sBuf.st_mode) == 0) {
        ERROR_LOG("%s is not a file, please enter a file", filePath.c_str());
        return false;
    }

    // The stream closes itself on every return path.
    std::ifstream file(filePath, std::ios::binary);
    if (!file.is_open()) {
        ERROR_LOG("Open file failed. path = %s", filePath.c_str());
        return false;
    }

    std::filebuf *buf = file.rdbuf();
    // pubseekoff reports failure as -1; keep the value signed until that has been ruled
    // out instead of letting it wrap around to a huge size_t.
    const std::streamoff endOffset = buf->pubseekoff(0, std::ios::end, std::ios::in);
    if (endOffset < 0) {
        ERROR_LOG("Get file size failed. path = %s", filePath.c_str());
        return false;
    }
    const size_t size = static_cast<size_t>(endOffset);
    if (size == 0) {
        ERROR_LOG("file size is 0");
        return false;
    }
    if (size > bufferSize) {
        ERROR_LOG("file size is larger than buffer size");
        return false;
    }

    buf->pubseekpos(0, std::ios::in);
    // sgetn may deliver fewer bytes than requested; a short read is a failure, not success.
    const std::streamsize readSize =
        buf->sgetn(static_cast<char *>(buffer), static_cast<std::streamsize>(size));
    if (readSize < 0 || static_cast<size_t>(readSize) != size) {
        ERROR_LOG("Read file failed. path = %s", filePath.c_str());
        return false;
    }

    fileSize = size;
    return true;
}
/**
 * @brief Write a memory buffer to a file, creating or truncating the file.
 * @param [in] filePath: file path
 * @param [in] buffer: data to write to file, must not be nullptr
 * @param [in] size: number of bytes to write
 * @return true when all size bytes were written; false when buffer is nullptr, the file
 *         cannot be opened, or a write fails
 *
 * The file is opened with O_CREAT | O_TRUNC and owner read/write permission. write() is
 * allowed to transfer fewer bytes than requested, so the data is written in a loop until
 * nothing remains. Declared inline because the definition lives in a header that may be
 * included from several translation units.
 */
inline bool WriteFile(const std::string &filePath, const void *buffer, size_t size)
{
    if (buffer == nullptr) {
        ERROR_LOG("Write file failed. buffer is nullptr");
        return false;
    }

    // S_IWUSR is the POSIX spelling of the obsolete S_IWRITE alias.
    const int fd = open(filePath.c_str(), O_RDWR | O_CREAT | O_TRUNC, S_IRUSR | S_IWUSR);
    if (fd < 0) {
        ERROR_LOG("Open file failed. path = %s", filePath.c_str());
        return false;
    }

    const char *cursor = static_cast<const char *>(buffer);
    size_t remaining = size;
    bool succeeded = true;
    while (remaining > 0) {
        // Keep the result signed: -1 signals an error and must not be folded into size_t.
        const ssize_t written = write(fd, cursor, remaining);
        if (written < 0) {
            if (errno == EINTR) {
                continue; // interrupted before any byte was transferred, try again
            }
            succeeded = false;
            break;
        }
        if (written == 0) {
            succeeded = false; // no progress; bail out rather than spin forever
            break;
        }
        cursor += written;
        remaining -= static_cast<size_t>(written);
    }

    (void)close(fd);
    if (!succeeded) {
        ERROR_LOG("Write file Failed.");
        return false;
    }
    return true;
}
/**
 * @brief Map an element to the value that is streamed to std::cout.
 *
 * Generic case: the element is streamed as it is.
 */
template <typename T> inline const T &PrintableValue(const T &value)
{
    return value;
}

/**
 * @brief signed char / unsigned char would be streamed as a character, so widen them to
 *        int to get the number. On common platforms int8_t and uint8_t are aliases of
 *        these two types. Plain char is left alone.
 */
inline int PrintableValue(signed char value)
{
    return value;
}

inline int PrintableValue(unsigned char value)
{
    return value;
}

/**
 * @brief Print count elements to std::cout, each right-aligned in a 10-character column.
 * @param [in] data: pointer to the first element
 * @param [in] count: number of elements to print
 * @param [in] elementsPerRow: a line break is emitted after every elementsPerRow elements;
 *                             must not be 0 (checked with assert)
 *
 * No line break is added after a final, incomplete row.
 */
template <typename T> void DoPrintData(const T *data, size_t count, size_t elementsPerRow)
{
    assert(elementsPerRow != 0);
    for (size_t i = 0; i < count; ++i) {
        std::cout << std::setw(10) << PrintableValue(data[i]);
        if (i % elementsPerRow == elementsPerRow - 1) {
            std::cout << std::endl;
        }
    }
}
/**
 * @brief Print count aclFloat16 elements to std::cout after converting each one with
 *        aclFloat16ToFloat.
 * @param [in] data: pointer to the first element
 * @param [in] count: number of elements to print
 * @param [in] elementsPerRow: a line break is emitted after every elementsPerRow elements;
 *                             must not be 0 (checked with assert)
 *
 * Values are right-aligned in a 10-character column and printed with precision 6. The
 * precision std::cout had on entry is restored before returning. Declared inline because
 * the definition lives in a header that may be included from several translation units.
 */
inline void DoPrintHalfData(const aclFloat16 *data, size_t count, size_t elementsPerRow)
{
    assert(elementsPerRow != 0);
    // std::setprecision is sticky; remember the caller's setting so it is not clobbered.
    const std::streamsize callerPrecision = std::cout.precision();
    for (size_t i = 0; i < count; ++i) {
        std::cout << std::setw(10) << std::setprecision(6) << aclFloat16ToFloat(data[i]);
        if (i % elementsPerRow == elementsPerRow - 1) {
            std::cout << std::endl;
        }
    }
    std::cout.precision(callerPrecision);
}
/**
 * @brief Print a buffer to std::cout, interpreting it as an array of the given type.
 * @param [in] data: pointer to the first element; nullptr is reported and nothing is printed
 * @param [in] count: number of elements (not bytes) to print
 * @param [in] dataType: element type used to reinterpret data
 * @param [in] elementsPerRow: elements per output line, 16 by default
 *
 * HALF is routed to DoPrintHalfData; BOOL, the 8/16/32/64-bit integer types, FLOAT and
 * DOUBLE are routed to DoPrintData. Every other printDataType value is reported as
 * unsupported. A final line break is written in all cases except the nullptr early
 * return. Declared inline because the definition lives in a header that may be included
 * from several translation units.
 */
inline void PrintData(const void *data, size_t count, printDataType dataType, size_t elementsPerRow = 16)
{
    if (data == nullptr) {
        ERROR_LOG("Print data failed. data is nullptr");
        return;
    }
    switch (dataType) {
        case BOOL:
            DoPrintData(reinterpret_cast<const bool *>(data), count, elementsPerRow);
            break;
        case INT8_T:
            DoPrintData(reinterpret_cast<const int8_t *>(data), count, elementsPerRow);
            break;
        case UINT8_T:
            DoPrintData(reinterpret_cast<const uint8_t *>(data), count, elementsPerRow);
            break;
        case INT16_T:
            DoPrintData(reinterpret_cast<const int16_t *>(data), count, elementsPerRow);
            break;
        case UINT16_T:
            DoPrintData(reinterpret_cast<const uint16_t *>(data), count, elementsPerRow);
            break;
        case INT32_T:
            DoPrintData(reinterpret_cast<const int32_t *>(data), count, elementsPerRow);
            break;
        case UINT32_T:
            DoPrintData(reinterpret_cast<const uint32_t *>(data), count, elementsPerRow);
            break;
        case INT64_T:
            DoPrintData(reinterpret_cast<const int64_t *>(data), count, elementsPerRow);
            break;
        case UINT64_T:
            DoPrintData(reinterpret_cast<const uint64_t *>(data), count, elementsPerRow);
            break;
        case HALF:
            DoPrintHalfData(reinterpret_cast<const aclFloat16 *>(data), count, elementsPerRow);
            break;
        case FLOAT:
            DoPrintData(reinterpret_cast<const float *>(data), count, elementsPerRow);
            break;
        case DOUBLE:
            DoPrintData(reinterpret_cast<const double *>(data), count, elementsPerRow);
            break;
        default:
            // Cast explicitly so the argument type matches the %d conversion.
            ERROR_LOG("Unsupported type: %d", static_cast<int>(dataType));
            break;
    }
    std::cout << std::endl;
}
#endif // DATA_UTILS_H
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/ascend/samples.git
git@gitee.com:ascend/samples.git
ascend
samples
samples
8.0.RC3

搜索帮助