From a4498d976b811dfd61f292d0a5922c8e92993bc1 Mon Sep 17 00:00:00 2001
From: jiaorui
Date: Fri, 25 Jul 2025 16:29:40 +0800
Subject: [PATCH] unify custom reg

---
 README.md                                     |  2 +-
 ccsrc/base/module.h                           | 41 ++++++++++---------
 .../graphmode/internal_kernel_mod.h           |  2 +
 ccsrc/ops/ascendc/add.cc                      |  9 ++--
 .../ms_kernels_internal/reshape_and_cache.cc  | 16 ++++----
 5 files changed, 34 insertions(+), 36 deletions(-)

diff --git a/README.md b/README.md
index 8ee84ca..9bddec2 100644
--- a/README.md
+++ b/README.md
@@ -322,7 +322,7 @@ protected:
 
 // Register the operator with the MindSpore framework
 // Register the operator name mapping (external interface my_op, internal op library name internal::kInternalMyOpName, connected KernelMod CustomMyOp)
-MS_CUSTOM_OPS_REGISTER(my_op, internal::kInternalMyOpName,
+REG_GRAPH_MODE_OP(my_op, internal::kInternalMyOpName,
                        CustomMyOp);
 ```
 
diff --git a/ccsrc/base/module.h b/ccsrc/base/module.h
index c9cb414..da9fa9f 100644
--- a/ccsrc/base/module.h
+++ b/ccsrc/base/module.h
@@ -17,6 +17,7 @@
 #ifndef MS_CUSTOM_OPS_MODULE_H_
 #define MS_CUSTOM_OPS_MODULE_H_
 
+#include "ms_extension/api.h"
 #include "plugin/device/ascend/kernel/custom/custom_kernel_factory.h"
 #include 
 #include 
@@ -36,13 +37,18 @@ public:
   }
 
   // Register a module function
-  void Register(const ModuleRegisterFunction &func) {
-    functions_.push_back(func);
+  void Register(ModuleRegisterFunction func, bool pynative_node = true) {
+    auto &target =
+        pynative_node ? pynative_reg_functions_ : graph_reg_functions_;
+    target.emplace_back(std::move(func));
   }
 
   // Call all registered module functions
   void RegisterAll(pybind11::module_ &m) {
-    for (const auto &func : functions_) {
+    for (const auto &func : pynative_reg_functions_) {
+      func(m);
+    }
+    for (const auto &func : graph_reg_functions_) {
       func(m);
     }
   }
@@ -56,14 +62,22 @@ private:
   ModuleRegistry &operator=(const ModuleRegistry &) = delete;
 
   // Store all registered functions
-  std::vector<ModuleRegisterFunction> functions_;
+  std::vector<ModuleRegisterFunction> pynative_reg_functions_;
+  std::vector<ModuleRegisterFunction> graph_reg_functions_;
 };
 
-#define REG_GRAPH_MODE_OP(op) \
+#define REG_GRAPH_MODE_OP(op, OpFuncImplClass, KernelClass) \
+  MS_CUSTOM_OPS_REGISTER(op, OpFuncImplClass, KernelClass); \
   static void op##_func() {} \
-  static void op##_register(pybind11::module_ &m) { m.def(#op, &op##_func); } \
+  static void op##_register(pybind11::module_ &m) { \
+    if (!pybind11::hasattr(m, #op)) { \
+      m.def(#op, &op##_func); \
+    } \
+  } \
   struct op##_registrar { \
-    op##_registrar() { ModuleRegistry::Instance().Register(op##_register); } \
+    op##_registrar() { \
+      ModuleRegistry::Instance().Register(op##_register, false); \
+    } \
   }; \
   static op##_registrar registrar_instance
 
@@ -84,17 +98,4 @@ private:
   } \
   static void CONCATENATE(func_register_, __LINE__)(pybind11::module_ & m)
 
-#define MS_CUSTOM_OPS_REGISTER(OpName, OpFuncImplClass, KernelClass) \
-  namespace mindspore { \
-  namespace ops { \
-  static OpFuncImplClass g##OpName##FuncImplReal; \
-  OpFuncImpl &gCustom_##OpName##FuncImpl = g##OpName##FuncImplReal; \
-  } /* namespace ops */ \
-  } /* namespace mindspore */ \
-  \
-  namespace ms_custom_ops { \
-  using namespace mindspore::ops; \
-  using namespace mindspore::kernel; \
-  MS_CUSTOM_KERNEL_FACTORY_REG("Custom_" #OpName, KernelClass); \
-  } /* namespace ms_custom_ops */
 #endif // MS_CUSTOM_OPS_MODULE_H_
diff --git a/ccsrc/base/ms_kernels_internal/graphmode/internal_kernel_mod.h b/ccsrc/base/ms_kernels_internal/graphmode/internal_kernel_mod.h
index 1bb677f..3c086d5 100644
--- a/ccsrc/base/ms_kernels_internal/graphmode/internal_kernel_mod.h
+++ b/ccsrc/base/ms_kernels_internal/graphmode/internal_kernel_mod.h
@@ -31,6 +31,8 @@
#include "module.h" namespace ms_custom_ops { +using namespace mindspore::ops; + class InternalKernelMod : public KernelMod { public: InternalKernelMod() { diff --git a/ccsrc/ops/ascendc/add.cc b/ccsrc/ops/ascendc/add.cc index 98594b6..01308a4 100644 --- a/ccsrc/ops/ascendc/add.cc +++ b/ccsrc/ops/ascendc/add.cc @@ -24,8 +24,7 @@ #include #include -namespace mindspore { -namespace ops { +namespace ms_custom_ops { class OPS_API AddCustomOpFuncImpl : public OpFuncImpl { public: ShapeArray InferShape(const PrimitivePtr &primitive, @@ -41,10 +40,7 @@ public: bool GeneralInferRegistered() const override { return true; } }; -} // namespace ops -} // namespace mindspore -namespace ms_custom_ops { class AddCustomAscend : public AscendCKernelMod { public: AddCustomAscend() : AscendCKernelMod(std::move("aclnnAddCustom")) {} @@ -69,7 +65,8 @@ private: }; } // namespace ms_custom_ops -MS_CUSTOM_OPS_REGISTER(add, AddCustomOpFuncImpl, AddCustomAscend); +REG_GRAPH_MODE_OP(add, ms_custom_ops::AddCustomOpFuncImpl, + ms_custom_ops::AddCustomAscend); // ============================================================================= // PYBOOST MODE IMPLEMENTATION diff --git a/ccsrc/ops/ms_kernels_internal/reshape_and_cache.cc b/ccsrc/ops/ms_kernels_internal/reshape_and_cache.cc index 764fddd..b6bfc7f 100644 --- a/ccsrc/ops/ms_kernels_internal/reshape_and_cache.cc +++ b/ccsrc/ops/ms_kernels_internal/reshape_and_cache.cc @@ -33,8 +33,7 @@ #include #include -namespace mindspore { -namespace ops { +namespace ms_custom_ops { class OPS_API CustomReshapeAndCacheOpFuncImpl : public OpFuncImpl { public: ShapeArray InferShape(const PrimitivePtr &primitive, @@ -49,10 +48,8 @@ public: bool GeneralInferRegistered() const override { return true; } }; -} // namespace ops -} // namespace mindspore -namespace ms_custom_ops { + constexpr size_t kInputKeyIndex = 0; constexpr size_t kInputValueIndex = 1; constexpr size_t kInputKeyCacheIndex = 2; @@ -66,8 +63,9 @@ public: ~CustomReshapeAndCache() = default; void InitKernelInputsOutputsIndex() override { - kernel_inputs_index_ = {kInputKeyIndex, kInputValueIndex, kInputKeyCacheIndex, - kInputValueCacheIndex, kInputSlotMappingIndex}; + kernel_inputs_index_ = {kInputKeyIndex, kInputValueIndex, + kInputKeyCacheIndex, kInputValueCacheIndex, + kInputSlotMappingIndex}; kernel_outputs_index_ = {kOutputIndex}; } @@ -83,8 +81,8 @@ protected: }; } // namespace ms_custom_ops -MS_CUSTOM_OPS_REGISTER(reshape_and_cache, CustomReshapeAndCacheOpFuncImpl, - CustomReshapeAndCache); +REG_GRAPH_MODE_OP(reshape_and_cache, ms_custom_ops::CustomReshapeAndCacheOpFuncImpl, + ms_custom_ops::CustomReshapeAndCache); // ============================================================================= // PYBOOST MODE IMPLEMENTATION -- Gitee