Ai
223 Star 1.3K Fork 1.1K

Ascend/samples
暂停

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
add_custom.cpp 2.37 KB
一键复制 编辑 原始数据 按行查看 历史
诸葛文洵 提交于 2024-07-21 10:36 +08:00 . simplify samples
/**
* @file add_custom.cpp
*
* Copyright (C) 2023-2024. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
*/
#include "add_custom_tiling.h"

#include <cstdint>
#include <limits>

#include "register/op_def_registry.h"
namespace optiling {
// Block dimension passed to SetBlockDim for every launch of this operator.
constexpr uint32_t BLOCK_DIM = 8;
// Value stored in the tiling data's tileNum field. How the kernel consumes it is
// defined in the device code, which is not visible here (presumably tiles per core — confirm).
constexpr uint32_t TILE_NUM = 8;

/**
 * @brief Host-side tiling function for AddCustom.
 *
 * Reads the element count of input 0 (from its origin shape), records it together
 * with TILE_NUM in the operator's TilingData, serializes that structure into the
 * framework-provided raw tiling buffer, and requests zero bytes of user workspace.
 *
 * @param context Tiling context supplied by the framework.
 * @return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED when a required context
 *         object is missing, the element count does not fit in uint32_t, or the
 *         serialized tiling data would not fit in the raw tiling buffer.
 */
static ge::graphStatus TilingFunc(gert::TilingContext *context)
{
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }

    const auto *inputShape = context->GetInputShape(0);
    if (inputShape == nullptr) {
        return ge::GRAPH_FAILED;
    }
    // GetShapeSize() yields a signed 64-bit count; the tiling field is uint32_t, so
    // reject values that would otherwise be silently truncated or wrap around.
    const int64_t shapeSize = inputShape->GetOriginShape().GetShapeSize();
    if (shapeSize < 0 || static_cast<uint64_t>(shapeSize) > std::numeric_limits<uint32_t>::max()) {
        return ge::GRAPH_FAILED;
    }
    const uint32_t totalLength = static_cast<uint32_t>(shapeSize);

    context->SetBlockDim(BLOCK_DIM);

    TilingData tiling;
    tiling.set_totalLength(totalLength);
    tiling.set_tileNum(TILE_NUM);

    auto *rawTilingData = context->GetRawTilingData();
    if (rawTilingData == nullptr || tiling.GetDataSize() > rawTilingData->GetCapacity()) {
        return ge::GRAPH_FAILED;
    }
    tiling.SaveToBuffer(rawTilingData->GetData(), rawTilingData->GetCapacity());
    rawTilingData->SetDataSize(tiling.GetDataSize());

    // This operator needs no user workspace: request one workspace entry of size 0.
    size_t *currentWorkspace = context->GetWorkspaceSizes(1);
    if (currentWorkspace == nullptr) {
        return ge::GRAPH_FAILED;
    }
    currentWorkspace[0] = 0;
    return ge::GRAPH_SUCCESS;
}
} // namespace optiling
namespace ge {
/**
 * @brief Shape inference for AddCustom: output 0 takes the shape of input 0.
 *
 * Only input 0 is consulted; input 1 is not compared against it here.
 * NOTE(review): presumably both inputs are expected to share a shape — confirm.
 *
 * @param context Shape-inference context supplied by the framework.
 * @return GRAPH_SUCCESS on success; GRAPH_FAILED if the context, the input
 *         shape, or the output shape is unavailable.
 */
static graphStatus InferShape(gert::InferShapeContext *context)
{
    if (context == nullptr) {
        return GRAPH_FAILED;
    }
    const gert::Shape *x1Shape = context->GetInputShape(0);
    gert::Shape *yShape = context->GetOutputShape(0);
    // Guard against missing shape objects before dereferencing them.
    if (x1Shape == nullptr || yShape == nullptr) {
        return GRAPH_FAILED;
    }
    *yShape = *x1Shape;
    return GRAPH_SUCCESS;
}

/**
 * @brief Data-type inference for AddCustom: output 0 takes the data type of input 0.
 *
 * @param context Data-type-inference context supplied by the framework.
 * @return The status reported by SetOutputDataType, or GRAPH_FAILED if the
 *         context is unavailable.
 */
static graphStatus InferDataType(gert::InferDataTypeContext *context)
{
    if (context == nullptr) {
        return GRAPH_FAILED;
    }
    const auto inputDataType = context->GetInputDataType(0);
    // Propagate the setter's status instead of discarding it.
    return context->SetOutputDataType(0, inputDataType);
}
} // namespace ge
namespace ops {
class AddCustom : public OpDef {
public:
explicit AddCustom(const char *name) : OpDef(name)
{
this->Input("x")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16})
.Format({ge::FORMAT_ND});
this->Input("y")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16})
.Format({ge::FORMAT_ND});
this->Output("z")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16})
.Format({ge::FORMAT_ND});
this->SetInferShape(ge::InferShape).SetInferDataType(ge::InferDataType);
this->AICore()
.SetTiling(optiling::TilingFunc)
.AddConfig("ascend910")
.AddConfig("ascend310p")
.AddConfig("ascend310b")
.AddConfig("ascend910b");
}
};
OP_ADD(AddCustom);
} // namespace ops
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/ascend/samples.git
git@gitee.com:ascend/samples.git
ascend
samples
samples
8.0.RC3

搜索帮助