From 43747edcaa49b92c59ed747a033265fb8cd495a7 Mon Sep 17 00:00:00 2001 From: zhangyujia77 Date: Thu, 25 Dec 2025 19:01:06 +0800 Subject: [PATCH 1/3] Add examples for PyTorch custom operator invocation using Pybind and torch.library. --- .../CMakeLists.txt | 64 +++++++++ .../28_simple_add_torch_library/README.md | 111 ++++++++++++++++ .../add_custom.asc | 123 ++++++++++++++++++ .../add_custom_test.py | 68 ++++++++++ operator/ascendc/0_introduction/README.md | 1 + 5 files changed, 367 insertions(+) create mode 100644 operator/ascendc/0_introduction/28_simple_add_torch_library/CMakeLists.txt create mode 100644 operator/ascendc/0_introduction/28_simple_add_torch_library/README.md create mode 100644 operator/ascendc/0_introduction/28_simple_add_torch_library/add_custom.asc create mode 100644 operator/ascendc/0_introduction/28_simple_add_torch_library/add_custom_test.py diff --git a/operator/ascendc/0_introduction/28_simple_add_torch_library/CMakeLists.txt b/operator/ascendc/0_introduction/28_simple_add_torch_library/CMakeLists.txt new file mode 100644 index 000000000..6fe405761 --- /dev/null +++ b/operator/ascendc/0_introduction/28_simple_add_torch_library/CMakeLists.txt @@ -0,0 +1,64 @@ +# ---------------------------------------------------------------------------------------------------------- +# This program is free software, you can redistribute it and/or modify it. +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+# ---------------------------------------------------------------------------------------------------------- + +cmake_minimum_required(VERSION 3.16) +find_package(ASC REQUIRED) +project(kernel_samples LANGUAGES ASC CXX) + +find_package(Python3 COMPONENTS Interpreter Development REQUIRED) +message("Python3_INCLUDE_DIRS is ${Python3_INCLUDE_DIRS}") + +execute_process( + COMMAND ${Python3_EXECUTABLE} -c "import torch; print(torch.utils.cmake_prefix_path)" + OUTPUT_VARIABLE TORCH_CMAKE_PATH + OUTPUT_STRIP_TRAILING_WHITESPACE +) +message("TORCH_CMAKE_PATH is ${TORCH_CMAKE_PATH}") +find_package(Torch REQUIRED HINTS ${TORCH_CMAKE_PATH}/Torch) +message("TORCH_INCLUDE_DIRS is ${TORCH_INCLUDE_DIRS}") +message("TORCH_LIBRARIES is ${TORCH_LIBRARIES}") + +execute_process( + COMMAND python3 -c "import os, torch_npu; print(os.path.dirname(torch_npu.__file__))" + OUTPUT_VARIABLE TORCH_NPU_PATH + OUTPUT_STRIP_TRAILING_WHITESPACE +) +message("TORCH_NPU_PATH is ${TORCH_NPU_PATH}") +set(TORCH_NPU_INCLUDE_DIRS ${TORCH_NPU_PATH}/include) +set(TORCH_NPU_LIBRARIES ${TORCH_NPU_PATH}/lib) +message("TORCH_NPU_INCLUDE_DIRS is ${TORCH_NPU_INCLUDE_DIRS}") +message("TORCH_NPU_LIBRARIES is ${TORCH_NPU_LIBRARIES}") + +add_library(custom_ops SHARED + add_custom.asc +) + +target_include_directories(custom_ops PRIVATE + ${Python3_INCLUDE_DIRS} + ${TORCH_INCLUDE_DIRS} + ${TORCH_NPU_INCLUDE_DIRS} +) + +target_link_libraries(custom_ops PRIVATE + torch_npu +) + +target_link_directories(custom_ops PRIVATE + ${TORCH_NPU_LIBRARIES} +) + +target_compile_definitions(custom_ops PRIVATE + _GLIBCXX_USE_CXX11_ABI=0 +) + +target_compile_options(custom_ops PRIVATE + $<$<COMPILE_LANGUAGE:ASC>:--npu-arch=dav-2201> +) \ No newline at end of file diff --git a/operator/ascendc/0_introduction/28_simple_add_torch_library/README.md b/operator/ascendc/0_introduction/28_simple_add_torch_library/README.md new file mode 100644 index 000000000..203333bbd --- /dev/null +++ b/operator/ascendc/0_introduction/28_simple_add_torch_library/README.md @@
-0,0 +1,111 @@ +# torch_library注册自定义算子直调样例 +本样例展示了如何使用PyTorch的torch.library机制注册自定义算子,并通过<<<>>>内核调用符调用核函数,以简单的Add算子为例,实现两个向量的逐元素相加。 + +## 目录结构介绍 +``` +├── 28_simple_add_torch_library +│   ├── CMakeLists.txt // 编译工程文件 +│   ├── add_custom_test.py // PyTorch调用自定义算子的测试脚本 +│   └── add_custom.asc // Ascend C算子实现 & 自定义算子注册 +``` + +## 代码实现介绍 +- kernel实现 + Add算子的数学表达式为: + ``` + z = x + y + ``` + 计算逻辑是:Ascend C提供的矢量计算接口的操作元素都为LocalTensor,输入数据需要先搬运进片上存储,然后使用计算接口完成两个输入参数相加,得到最终结果,再搬出到外部存储上。 + + Add算子的实现流程分为3个基本任务:CopyIn,Compute,CopyOut。CopyIn任务负责将Global Memory上的输入Tensor xGm和yGm搬运到Local Memory,分别存储在xLocal、yLocal,Compute任务负责对xLocal、yLocal执行加法操作,计算结果存储在zLocal中,CopyOut任务负责将输出数据从zLocal搬运至Global Memory上的输出Tensor zGm中。具体请参考[add_custom.asc](./add_custom.asc)。 + +- 自定义算子注册: + + PyTorch提供`TORCH_LIBRARY`宏作为自定义算子注册的核心接口,用于创建并初始化自定义算子库,注册后在Python侧可以通过`torch.ops.namespace.op_name`方式进行调用,例如: + ```c++ + TORCH_LIBRARY(ascendc_ops, m) { + m.def("ascendc_add(Tensor x, Tensor y) -> Tensor"); + } + ``` + 另外,若相同命名空间需要在多个文件中拆分注册,需要使用`TORCH_LIBRARY_FRAGMENT`扩展现有算子库,避免重复创建命名空间导致冲突。 + + `TORCH_LIBRARY_IMPL`用于将算子逻辑绑定到特定的DispatchKey(PyTorch设备调度标识)。针对NPU设备,需要将算子实现注册到PrivateUse1这一专属的DispatchKey上,例如: + ```c++ + TORCH_LIBRARY_IMPL(ascendc_ops, PrivateUse1, m) + { + m.impl("ascendc_add", TORCH_FN(ascendc_ops::ascendc_add)); + } + ``` + + 本样例在add_custom.asc中定义了一个名为ascendc_ops的命名空间,并在其中注册了ascendc_add函数,Python侧可以通过`torch.ops.ascendc_ops.ascendc_add`调用自定义的API。在ascendc_add函数中通过`c10_npu::getCurrentNPUStream()`函数获取当前NPU上的流,并通过内核调用符<<<>>>调用自定义的Kernel函数add_custom,在NPU上执行算子。 + +- Python测试脚本 + + 在add_custom_test.py中,首先通过`torch.ops.load_library`加载生成的自定义算子库,并定义一个仅包含单算子的PyTorch模型SingleOpModel,其前向计算直接调用自定义算子。在测试执行时,脚本通过torchair配置编译策略并利用`torch.compile`使能模型在NPU进行全图编译优化,同时启动`torch_npu.profiler`性能分析器来获取算子执行过程中NPU和CPU的性能数据,最终通过对比NPU输出与CPU标准加法结果来验证自定义算子的数值正确性。 + +## 支持的产品型号 +本样例支持如下产品型号: +- Atlas A2 训练系列产品/Atlas 800I A2 推理产品 + +## 运行样例算子 +- 安装PyTorch (这里以使用2.1.0版本为例) + + **aarch64:** + + ```bash + pip3 install torch==2.1.0 + ``` + +
**x86:** + + ```bash + pip3 install torch==2.1.0+cpu --index-url https://download.pytorch.org/whl/cpu + ``` + +- 安装torch-npu (以Pytorch2.1.0、python3.9、CANN版本8.0.RC1.alpha002为例) + ```bash + git clone https://gitee.com/ascend/pytorch.git -b v6.0.rc1.alpha002-pytorch2.1.0 + cd pytorch/ + bash ci/build.sh --python=3.9 + pip3 install dist/*.whl + ``` + +- 打开样例目录 +以命令行方式下载样例代码,master分支为例。 + ```bash + cd ${git_clone_path}/samples/operator/ascendc/0_introduction/27_simple_add_cpp_extensions + ``` + +- 配置环境变量 + 请根据当前环境上CANN开发套件包的[安装方式](https://hiascend.com/document/redirect/CannCommunityInstSoftware),选择对应配置环境变量的命令。 + - 默认路径,root用户安装CANN软件包 + ```bash + export ASCEND_INSTALL_PATH=/usr/local/Ascend/ascend-toolkit/latest + ``` + - 默认路径,非root用户安装CANN软件包 + ```bash + export ASCEND_INSTALL_PATH=$HOME/Ascend/ascend-toolkit/latest + ``` + - 指定路径install_path,安装CANN软件包 + ```bash + export ASCEND_INSTALL_PATH=${install_path}/ascend-toolkit/latest + ``` + 配置安装路径后,执行以下命令统一配置环境变量。 + ```bash + # 配置CANN环境变量 + source ${ASCEND_INSTALL_PATH}/bin/setenv.bash + # 添加AscendC CMake Module搜索路径至环境变量 + export CMAKE_PREFIX_PATH=${ASCEND_INSTALL_PATH}/compiler/tikcpp/ascendc_kernel_cmake:$CMAKE_PREFIX_PATH + ``` + +- 样例执行 + ```bash + mkdir -p build && cd build; # 创建并进入build目录 + cmake ..;make -j; # 编译工程 + python3 ../add_custom_test.py # 执行测试脚本 + ``` + +## 更新说明 +| 时间 | 更新事项 | +| ---------- | ------------ | +| 2025/12/25 | 新增本readme | diff --git a/operator/ascendc/0_introduction/28_simple_add_torch_library/add_custom.asc b/operator/ascendc/0_introduction/28_simple_add_torch_library/add_custom.asc new file mode 100644 index 000000000..765a616e6 --- /dev/null +++ b/operator/ascendc/0_introduction/28_simple_add_torch_library/add_custom.asc @@ -0,0 +1,123 @@ +/** +* Copyright (c) 2025 Huawei Technologies Co., Ltd. +* This program is free software, you can redistribute it and/or modify it under the terms and conditions of +* CANN Open Software License Agreement Version 2.0 (the "License"). 
+* Please refer to the License for details. You may not use this file except in compliance with the License. +* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +* See LICENSE in the root of the software repository for the full text of the License. +*/ + +/* ! + * \file add_custom.asc + * \brief + */ + +#include <torch/extension.h> +#include "torch_npu/csrc/core/npu/NPUStream.h" +#include "kernel_operator.h" + +constexpr uint32_t BUFFER_NUM = 2; //tensor num for each queue +class KernelAdd { +public: + __aicore__ inline KernelAdd() {} + __aicore__ inline void Init(GM_ADDR x, GM_ADDR y, GM_ADDR z, uint32_t totalLength) + { + this->blockLength = totalLength / AscendC::GetBlockNum(); + this->tileNum = 8; + this->tileLength = this->blockLength / this->tileNum / BUFFER_NUM; + xGm.SetGlobalBuffer((__gm__ float *)x + this->blockLength * AscendC::GetBlockIdx(), this->blockLength); + yGm.SetGlobalBuffer((__gm__ float *)y + this->blockLength * AscendC::GetBlockIdx(), this->blockLength); + zGm.SetGlobalBuffer((__gm__ float *)z + this->blockLength * AscendC::GetBlockIdx(), this->blockLength); + pipe.InitBuffer(inQueueX, BUFFER_NUM, this->tileLength * sizeof(float)); + pipe.InitBuffer(inQueueY, BUFFER_NUM, this->tileLength * sizeof(float)); + pipe.InitBuffer(outQueueZ, BUFFER_NUM, this->tileLength * sizeof(float)); + } + __aicore__ inline void Process() + { + int32_t loopCount = this->tileNum * BUFFER_NUM; + for (int32_t i = 0; i < loopCount; i++) { + CopyIn(i); + Compute(i); + CopyOut(i); + } + } + +private: + __aicore__ inline void CopyIn(int32_t progress) + { + AscendC::LocalTensor<float> xLocal = inQueueX.AllocTensor<float>(); + AscendC::LocalTensor<float> yLocal = inQueueY.AllocTensor<float>(); + AscendC::DataCopy(xLocal, xGm[progress * this->tileLength], this->tileLength); + AscendC::DataCopy(yLocal, yGm[progress * this->tileLength], this->tileLength); +
inQueueX.EnQue(xLocal); + inQueueY.EnQue(yLocal); + } + __aicore__ inline void Compute(int32_t progress) + { + AscendC::LocalTensor<float> xLocal = inQueueX.DeQue<float>(); + AscendC::LocalTensor<float> yLocal = inQueueY.DeQue<float>(); + AscendC::LocalTensor<float> zLocal = outQueueZ.AllocTensor<float>(); + AscendC::Add(zLocal, xLocal, yLocal, this->tileLength); + outQueueZ.EnQue(zLocal); + inQueueX.FreeTensor(xLocal); + inQueueY.FreeTensor(yLocal); + } + __aicore__ inline void CopyOut(int32_t progress) + { + AscendC::LocalTensor<float> zLocal = outQueueZ.DeQue<float>(); + AscendC::DataCopy(zGm[progress * this->tileLength], zLocal, this->tileLength); + outQueueZ.FreeTensor(zLocal); + } + +private: + AscendC::TPipe pipe; + AscendC::TQue<AscendC::TPosition::VECIN, BUFFER_NUM> inQueueX, inQueueY; + AscendC::TQue<AscendC::TPosition::VECOUT, BUFFER_NUM> outQueueZ; + AscendC::GlobalTensor<float> xGm; + AscendC::GlobalTensor<float> yGm; + AscendC::GlobalTensor<float> zGm; + uint32_t blockLength; + uint32_t tileNum; + uint32_t tileLength; +}; + +__global__ __vector__ void add_custom(GM_ADDR x, GM_ADDR y, GM_ADDR z, uint32_t totalLength) +{ + KernelAdd op; + op.Init(x, y, z, totalLength); + op.Process(); +} + +namespace ascendc_ops { +at::Tensor ascendc_add(const at::Tensor& x, const at::Tensor& y) +{ + auto acl_stream = c10_npu::getCurrentNPUStream().stream(false); + at::Tensor z = at::empty_like(x); + uint32_t blockDim = 8; + uint32_t totalLength = 1; + for (uint32_t size : x.sizes()) { + totalLength *= size; + } + add_custom<<<blockDim, nullptr, acl_stream>>>((uint8_t*)(x.mutable_data_ptr()), (uint8_t*)(y.mutable_data_ptr()), (uint8_t*)(z.mutable_data_ptr()), totalLength); + return z; +} +} // namespace ascendc_ops + +TORCH_LIBRARY(ascendc_ops, m) +{ + m.def("ascendc_add(Tensor x, Tensor y) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(ascendc_ops, PrivateUse1, m) +{ + m.impl("ascendc_add", TORCH_FN(ascendc_ops::ascendc_add)); +} + +TORCH_LIBRARY_IMPL(ascendc_ops, Meta, m) +{ + m.impl("ascendc_add", [](const torch::Tensor& a, const torch::Tensor& b) { + TORCH_CHECK(a.sizes() == b.sizes(), "Input size must match"); + return torch::empty_like(a,
torch::device(torch::kMeta)); + }); +} diff --git a/operator/ascendc/0_introduction/28_simple_add_torch_library/add_custom_test.py b/operator/ascendc/0_introduction/28_simple_add_torch_library/add_custom_test.py new file mode 100644 index 000000000..b3eb6116a --- /dev/null +++ b/operator/ascendc/0_introduction/28_simple_add_torch_library/add_custom_test.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 +# -*- coding: UTF-8 -*- +# ---------------------------------------------------------------------------- +# This program is free software, you can redistribute it and/or modify it. +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+# ---------------------------------------------------------------------------- + +import sys +import os +import torch +import torch_npu +import torchair +from torch_npu.testing.testcase import TestCase, run_tests +torch.ops.load_library("libcustom_ops.so") + + +class SingleOpModel(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.ops.ascendc_ops.ascendc_add(x, y) + + +class TestCustomAdd(TestCase): + def test_add_custom_ops(self): + config = torchair.CompilerConfig() + config.mode = "reduce-overhead" + npu_backend = torchair.get_npu_backend(compiler_config=config) + model = torch.compile(SingleOpModel().npu(), backend=npu_backend) + + experimental_config = torch_npu.profiler._ExperimentalConfig( + profiler_level=torch_npu.profiler.ProfilerLevel.Level1, + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization + ) + + profiler = torch_npu.profiler.profile( + activities=[ + torch_npu.profiler.ProfilerActivity.NPU, + torch_npu.profiler.ProfilerActivity.CPU, + ], + with_stack=False, + record_shapes=False, + profile_memory=False, + experimental_config=experimental_config, + schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=10, repeat=1, skip_first=0), + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./res") + ) + + length = [8, 2048] + x = torch.rand(length, device='cpu', dtype=torch.float) + y = torch.rand(length, device='cpu', dtype=torch.float) + profiler.start() + output = model(x.npu(), y.npu()).cpu() + profiler.stop() + + cpuout = torch.add(x, y) + self.assertRtolEqual(output, cpuout) + + +if __name__ == "__main__": + run_tests() diff --git a/operator/ascendc/0_introduction/README.md b/operator/ascendc/0_introduction/README.md index 0b9383e71..5cb4be5cf 100644 --- a/operator/ascendc/0_introduction/README.md +++ b/operator/ascendc/0_introduction/README.md @@ -42,6 +42,7 @@ | [25_simple_add](./25_simple_add) | Ascend C异构混合编程样例, 实现Add自定义Vector算子及调用, 支持host/device代码混合编程 | Atlas 
A2训练系列产品/Atlas 800I A2推理产品 | [26_simple_matmulleakyrelu](./26_simple_matmulleakyrelu) | Ascend C异构混合编程样例, 实现MatmulLeakyRelu自定义Cube+Vector算子及调用, 支持host/device代码混合编程 | Atlas A2训练系列产品/Atlas 800I A2推理产品 | [27_simple_add_cpp_extensions](./27_simple_add_cpp_extensions) | Ascend C异构混合编程样例, 实现Add自定义Vector算子动态库及pybind调用, 支持host/device代码混合编程 | Atlas A2训练系列产品/Atlas 800I A2推理产品 +| [28_simple_add_torch_library](./28_simple_add_torch_library) | Ascend C异构混合编程样例, 使用PyTorch的torch.library机制注册自定义算子, 支持host/device代码混合编程 | Atlas A2训练系列产品/Atlas 800I A2推理产品 ## 获取样例代码 可以使用以下两种方式下载,请选择其中一种进行源码准备。 -- Gitee From d57d774db36300218dd7d99bd02e81c255711740 Mon Sep 17 00:00:00 2001 From: zhangyujia77 Date: Sat, 27 Dec 2025 18:50:25 +0800 Subject: [PATCH 2/3] fix readme --- .../28_simple_add_torch_library/README.md | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/operator/ascendc/0_introduction/28_simple_add_torch_library/README.md b/operator/ascendc/0_introduction/28_simple_add_torch_library/README.md index 203333bbd..b4e100572 100644 --- a/operator/ascendc/0_introduction/28_simple_add_torch_library/README.md +++ b/operator/ascendc/0_introduction/28_simple_add_torch_library/README.md @@ -1,4 +1,4 @@ -# torch_library注册自定义算子直调样例 +# torch.library注册自定义算子直调样例 本样例展示了如何使用PyTorch的torch.library机制注册自定义算子,并通过<<<>>>内核调用符调用核函数,以简单的Add算子为例,实现两个向量的逐元素相加。 ## 目录结构介绍 @@ -27,7 +27,7 @@ m.def(ascendc_add"(Tensor x, Tensor y) -> Tensor"); } ``` - 另外,若相同命名空间需要在多个文件中拆分注册,需要使用`TORCH_LIBRARY_FRAGMENT`扩展现有算子库,避免重复创建命名空间导致冲突。 + 另外,若相同算子需要在多个文件中拆分注册,需要使用`TORCH_LIBRARY_FRAGMENT`扩展现有算子库,避免重复定义导致冲突。 `TORCH_LIBRARY_IMPL`用于将算子逻辑绑定到特定的DispatchKey(PyTorch设备调度标识)。针对NPU设备,需要将算子实现注册到PrivateUse1这一专属的DispatchKey上,例如: ```c++ @@ -36,6 +36,16 @@ m.impl("ascendc_add", TORCH_FN(ascendc_ops::ascendc_add)); } ``` + 另外,需要提供Meta Kernel实现,用于在图模式快速推导输出Tensor的形状、类型等信息,如: + ```c++ + TORCH_LIBRARY_IMPL(ascendc_ops, Meta, m) + { + m.impl("ascendc_add", [](const torch::Tensor& a, const torch::Tensor& b) { + 
TORCH_CHECK(a.sizes() == b.sizes(), "Input size must match"); + return torch::empty_like(a, torch::device(torch::kMeta)); + }); + } + ``` 本样例在add_custom.asc中定义了一个名为ascendc_ops的命名空间,并在其中注册了ascendc_add函数,Python侧可以通过`torch.ops.ascendc_ops.ascendc_add`调用自定义的API。在ascendc_add函数中通过`c10_npu::getCurrentNPUStream()`函数获取当前NPU上的流,并通过内核调用符<<<>>>调用自定义的Kernel函数add_custom,在NPU上执行算子。 -- Gitee From 45f39fc50615183dc7031ecdfbd2fb949815c844 Mon Sep 17 00:00:00 2001 From: zhangyujia77 Date: Tue, 30 Dec 2025 11:17:20 +0800 Subject: [PATCH 3/3] modify the way of invocation method --- .../CMakeLists.txt | 6 --- .../28_simple_add_torch_library/README.md | 43 +++---------------- .../add_custom.asc | 8 ---- .../add_custom_test.py | 37 +--------------- 4 files changed, 8 insertions(+), 86 deletions(-) diff --git a/operator/ascendc/0_introduction/28_simple_add_torch_library/CMakeLists.txt b/operator/ascendc/0_introduction/28_simple_add_torch_library/CMakeLists.txt index 6fe405761..c7d03b609 100644 --- a/operator/ascendc/0_introduction/28_simple_add_torch_library/CMakeLists.txt +++ b/operator/ascendc/0_introduction/28_simple_add_torch_library/CMakeLists.txt @@ -21,21 +21,15 @@ execute_process( OUTPUT_VARIABLE TORCH_CMAKE_PATH OUTPUT_STRIP_TRAILING_WHITESPACE ) -message("TORCH_CMAKE_PATH is ${TORCH_CMAKE_PATH}") find_package(Torch REQUIRED HINTS ${TORCH_CMAKE_PATH}/Torch) -message("TORCH_INCLUDE_DIRS is ${TORCH_INCLUDE_DIRS}") -message("TORCH_LIBRARIES is ${TORCH_LIBRARIES}") execute_process( COMMAND python3 -c "import os, torch_npu; print(os.path.dirname(torch_npu.__file__))" OUTPUT_VARIABLE TORCH_NPU_PATH OUTPUT_STRIP_TRAILING_WHITESPACE ) -message("TORCH_NPU_PATH is ${TORCH_NPU_PATH}") set(TORCH_NPU_INCLUDE_DIRS ${TORCH_NPU_PATH}/include) set(TORCH_NPU_LIBRARIES ${TORCH_NPU_PATH}/lib) -message("TORCH_NPU_INCLUDE_DIRS is ${TORCH_NPU_INCLUDE_DIRS}") -message("TORCH_NPU_LIBRARIES is ${TORCH_NPU_LIBRARIES}") add_library(custom_ops SHARED add_custom.asc diff --git 
a/operator/ascendc/0_introduction/28_simple_add_torch_library/README.md b/operator/ascendc/0_introduction/28_simple_add_torch_library/README.md index b4e100572..ae7e8ff4b 100644 --- a/operator/ascendc/0_introduction/28_simple_add_torch_library/README.md +++ b/operator/ascendc/0_introduction/28_simple_add_torch_library/README.md @@ -21,13 +21,14 @@ - 自定义算子注册: + 本样例在add_custom.asc中定义了一个名为ascendc_ops的命名空间,并在其中注册了ascendc_add函数。 + PyTorch提供`TORCH_LIBRARY`宏作为自定义算子注册的核心接口,用于创建并初始化自定义算子库,注册后在Python侧可以通过`torch.ops.namespace.op_name`方式进行调用,例如: ```c++ TORCH_LIBRARY(ascendc_ops, m) { m.def(ascendc_add"(Tensor x, Tensor y) -> Tensor"); } ``` - 另外,若相同算子需要在多个文件中拆分注册,需要使用`TORCH_LIBRARY_FRAGMENT`扩展现有算子库,避免重复定义导致冲突。 `TORCH_LIBRARY_IMPL`用于将算子逻辑绑定到特定的DispatchKey(PyTorch设备调度标识)。针对NPU设备,需要将算子实现注册到PrivateUse1这一专属的DispatchKey上,例如: ```c++ @@ -36,54 +37,24 @@ m.impl("ascendc_add", TORCH_FN(ascendc_ops::ascendc_add)); } ``` - 另外,需要提供Meta Kernel实现,用于在图模式快速推导输出Tensor的形状、类型等信息,如: - ```c++ - TORCH_LIBRARY_IMPL(ascendc_ops, Meta, m) - { - m.impl("ascendc_add", [](const torch::Tensor& a, const torch::Tensor& b) { - TORCH_CHECK(a.sizes() == b.sizes(), "Input size must match"); - return torch::empty_like(a, torch::device(torch::kMeta)); - }); - } - ``` - - 本样例在add_custom.asc中定义了一个名为ascendc_ops的命名空间,并在其中注册了ascendc_add函数,Python侧可以通过`torch.ops.ascendc_ops.ascendc_add`调用自定义的API。在ascendc_add函数中通过`c10_npu::getCurrentNPUStream()`函数获取当前NPU上的流,并通过内核调用符<<<>>>调用自定义的Kernel函数add_custom,在NPU上执行算子。 + 在ascendc_add函数中通过`c10_npu::getCurrentNPUStream()`函数获取当前NPU上的流,并通过内核调用符<<<>>>调用自定义的Kernel函数add_custom,在NPU上执行算子。 - Python测试脚本 - 在add_custom_test.py中,首先通过`torch.ops.load_library`加载生成的自定义算子库,并定义一个仅包含单算子的PyTorch模型SingleOpModel,其前向计算直接调用自定义算子。在测试执行时,脚本通过torchair配置编译策略并利用`torch.compile`使能模型在NPU进行全图编译优化,同时启动`torch_npu.profiler`性能分析器来获取算子执行过程中NPU和CPU的性能数据,最终通过对比NPU输出与CPU标准加法结果来验证自定义算子的数值正确性。 + 在add_custom_test.py中,首先通过`torch.ops.load_library`加载生成的自定义算子库,调用注册的ascendc_add函数,并通过对比NPU输出与CPU标准加法结果来验证自定义算子的数值正确性。 ## 支持的产品型号 
本样例支持如下产品型号: - Atlas A2 训练系列产品/Atlas 800I A2 推理产品 ## 运行样例算子 -- 安装PyTorch (这里以使用2.1.0版本为例) - - **aarch64:** - - ```bash - pip3 install torch==2.1.0 - ``` - - **x86:** - - ```bash - pip3 install torch==2.1.0+cpu --index-url https://download.pytorch.org/whl/cpu - ``` - -- 安装torch-npu (以Pytorch2.1.0、python3.9、CANN版本8.0.RC1.alpha002为例) - ```bash - git clone https://gitee.com/ascend/pytorch.git -b v6.0.rc1.alpha002-pytorch2.1.0 - cd pytorch/ - bash ci/build.sh --python=3.9 - pip3 install dist/*.whl - ``` +- 请参考与您当前使用的版本配套的[《Ascend Extension for PyTorch +软件安装指南》](https://www.hiascend.com/document/detail/zh/Pytorch/720/configandinstg/instg/insg_0001.html),获取PyTorch和torch_npu详细的安装步骤。 - 打开样例目录 以命令行方式下载样例代码,master分支为例。 ```bash - cd ${git_clone_path}/samples/operator/ascendc/0_introduction/27_simple_add_cpp_extensions + cd ${git_clone_path}/samples/operator/ascendc/0_introduction/28_simple_add_torch_library ``` - 配置环境变量 diff --git a/operator/ascendc/0_introduction/28_simple_add_torch_library/add_custom.asc b/operator/ascendc/0_introduction/28_simple_add_torch_library/add_custom.asc index 765a616e6..22fb3151c 100644 --- a/operator/ascendc/0_introduction/28_simple_add_torch_library/add_custom.asc +++ b/operator/ascendc/0_introduction/28_simple_add_torch_library/add_custom.asc @@ -113,11 +113,3 @@ TORCH_LIBRARY_IMPL(ascendc_ops, PrivateUse1, m) { m.impl("ascendc_add", TORCH_FN(ascendc_ops::ascendc_add)); } - -TORCH_LIBRARY_IMPL(ascendc_ops, Meta, m) -{ - m.impl("ascendc_add", [](const torch::Tensor& a, const torch::Tensor& b) { - TORCH_CHECK(a.sizes() == b.sizes(), "Input size must match"); - return torch::empty_like(a, torch::device(torch::kMeta)); - }); -} diff --git a/operator/ascendc/0_introduction/28_simple_add_torch_library/add_custom_test.py b/operator/ascendc/0_introduction/28_simple_add_torch_library/add_custom_test.py index b3eb6116a..be779f130 100644 --- a/operator/ascendc/0_introduction/28_simple_add_torch_library/add_custom_test.py +++ 
b/operator/ascendc/0_introduction/28_simple_add_torch_library/add_custom_test.py @@ -15,51 +15,16 @@ import sys import os import torch import torch_npu -import torchair from torch_npu.testing.testcase import TestCase, run_tests torch.ops.load_library("libcustom_ops.so") -class SingleOpModel(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - return torch.ops.ascendc_ops.ascendc_add(x, y) - - class TestCustomAdd(TestCase): def test_add_custom_ops(self): - config = torchair.CompilerConfig() - config.mode = "reduce-overhead" - npu_backend = torchair.get_npu_backend(compiler_config=config) - model = torch.compile(SingleOpModel().npu(), backend=npu_backend) - - experimental_config = torch_npu.profiler._ExperimentalConfig( - profiler_level=torch_npu.profiler.ProfilerLevel.Level1, - aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization - ) - - profiler = torch_npu.profiler.profile( - activities=[ - torch_npu.profiler.ProfilerActivity.NPU, - torch_npu.profiler.ProfilerActivity.CPU, - ], - with_stack=False, - record_shapes=False, - profile_memory=False, - experimental_config=experimental_config, - schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=10, repeat=1, skip_first=0), - on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./res") - ) - length = [8, 2048] x = torch.rand(length, device='cpu', dtype=torch.float) y = torch.rand(length, device='cpu', dtype=torch.float) - profiler.start() - output = model(x.npu(), y.npu()).cpu() - profiler.stop() - + output = torch.ops.ascendc_ops.ascendc_add(x.npu(), y.npu()).cpu() cpuout = torch.add(x, y) self.assertRtolEqual(output, cpuout) -- Gitee