diff --git a/ccsrc/base/ascendc/pyboost/ascendc_pyboost_runner.h b/ccsrc/base/ascendc/pyboost/ascendc_pyboost_runner.h index 9cbc565c9a23fd6daf755c041fd58d444b991dce..029553f3b04f1037644b4ad0ac1a0139fe2829a2 100644 --- a/ccsrc/base/ascendc/pyboost/ascendc_pyboost_runner.h +++ b/ccsrc/base/ascendc/pyboost/ascendc_pyboost_runner.h @@ -16,49 +16,67 @@ #ifndef MS_CUSTOM_OPS_OP_DEF_ASCENDC_PYBOOST_ASCENDC_PYBOOST_RUNNER_H_ #define MS_CUSTOM_OPS_OP_DEF_ASCENDC_PYBOOST_ASCENDC_PYBOOST_RUNNER_H_ +#include "module.h" #include "ms_extension/all.h" #include #include #include -#include "module.h" - - -namespace ms_custom_ops { -using namespace mindspore; -using namespace ms::pynative; +namespace ms::pynative { using AscendCLaunchFunc = std::function; -using AscendCWorkSpaceFunc = std::function; -class AscendCOpRunner : public PyboostRunner { +class AscendCOpRunner final : public PyboostRunner { public: using PyboostRunner::PyboostRunner; void SetLaunchFunc(AscendCLaunchFunc func) { launch_func_ = func; } - void SetWorkSpaceFunc(AscendCWorkSpaceFunc func) { workspace_func_ = func; } protected: - size_t CalcWorkspace() override { - if (workspace_func_ != nullptr) { - return workspace_func_(); - } - return 0; - } - void LaunchKernel() override { - if (launch_func_ != nullptr) { - launch_func_(_device_context_, _stream_id_); - } + MS_EXCEPTION_IF_NULL(launch_func_); + launch_func_(_device_context_, _stream_id_); } + void _DispatchLaunchTask() override { LaunchKernel(); } AscendCLaunchFunc launch_func_{nullptr}; - AscendCWorkSpaceFunc workspace_func_{nullptr}; }; -#define LAUNCH_ASCENDC_FUNC(aclnn_api, ...) \ - [__VA_ARGS__](auto __device_context, auto __stream_id) { \ - LAUNCH_ACLNN(aclnn_api, __device_context, __stream_id, __VA_ARGS__); \ +inline mindspore::tensor::TensorPtr Tensor2Ptr(const ms::Tensor &t) { + return t.is_defined() ? t.tensor() : nullptr; +} + +inline std::vector +Tensor2Ptr(const std::vector &tensors) { + std::vector result; + result.reserve(tensors.size()); + for (const auto &t : tensors) { + result.push_back(t.tensor()); + } + return result; +} + +inline std::optional +Tensor2Ptr(const std::optional &opt_tensor) { + if (opt_tensor.has_value()) { + return Tensor2Ptr(opt_tensor.value()); } -} // namespace ms_custom_ops + return std::nullopt; +} + +template inline constexpr T Tensor2Ptr(const T &t) { return t; } + +#define LAUNCH_ASCENDC_FUNC(aclnn_api, ...) \ + [](auto &&... args) { \ + auto args_t = std::make_tuple( \ + ms::pynative::Tensor2Ptr(std::forward(args))...); \ + return [args_t](auto __dev_ctx, auto __stream_id) { \ + std::apply( \ + [&](auto &&... args) { \ + LAUNCH_ACLNN(aclnn_api, __dev_ctx, __stream_id, args...); \ + }, \ + args_t); \ + }; \ + }(__VA_ARGS__) +} // namespace ms::pynative #endif // MS_CUSTOM_OPS_OP_DEF_ASCENDC_PYBOOST_ASCENDC_PYBOOST_RUNNER_H_ diff --git a/ccsrc/ops/ascendc/add.cc b/ccsrc/ops/ascendc/add.cc index da7d5724ce368f1d5fde960fe3d625387986b6c6..98594b6defc630d6725598666faa803049475209 100644 --- a/ccsrc/ops/ascendc/add.cc +++ b/ccsrc/ops/ascendc/add.cc @@ -78,28 +78,23 @@ MS_CUSTOM_OPS_REGISTER(add, AddCustomOpFuncImpl, AddCustomAscend); #include "ascendc_pyboost_runner.h" namespace ms_custom_ops { -class AddRunner : public AscendCOpRunner { -public: - using AscendCOpRunner::AscendCOpRunner; -}; - +using namespace mindspore; +using namespace mindspore::device::ascend; ms::Tensor custom_add(const ms::Tensor &x, const ms::Tensor &y) { + // assume the shape of x and y is same. auto out = ms::Tensor(x.data_type(), x.shape()); - auto runner = std::make_shared("AddCustom"); - auto ms_x = x.tensor(); - auto ms_y = y.tensor(); - auto ms_out = out.tensor(); - runner->SetLaunchFunc( - LAUNCH_ASCENDC_FUNC(aclnnAddCustom, ms_x, ms_y, ms_out)); + auto runner = std::make_shared("AddCustom"); + runner->SetLaunchFunc(LAUNCH_ASCENDC_FUNC(aclnnAddCustom, x, y, out)); runner->Run({x, y}, {out}); return out; } -} // namespace ms_custom_ops -py::object pyboost_add(const ms::Tensor &x, const ms::Tensor &y) { - return ms_custom_ops::AddRunner::Call<1>(ms_custom_ops::custom_add, x, y); +auto pyboost_add(const ms::Tensor &x, const ms::Tensor &y) { + return ms::pynative::PyboostRunner::Call<1>(custom_add, x, y); } +} // namespace ms_custom_ops MS_CUSTOM_OPS_EXTENSION_MODULE(m) { - m.def("add", &pyboost_add, "add", pybind11::arg("x"), pybind11::arg("y")); + m.def("add", &ms_custom_ops::pyboost_add, "add", pybind11::arg("x"), + pybind11::arg("y")); }