From 2c503cda12ff0544506106568feb7734b12756d0 Mon Sep 17 00:00:00 2001 From: panzhihui Date: Wed, 5 Nov 2025 10:32:18 +0800 Subject: [PATCH] op plugin KernelInputInfo store tensor storage info --- .../ccsrc/plugin/cpu/cpu_device_context.cc | 4 +- .../kernel_select/kernel_select_cpu.cc | 2 +- .../ccsrc/pynative/utils/pynative_utils.cc | 4 +- .../custom_kernel_input_info.h | 25 ++--- .../custom_kernel_input_info_impl.h | 20 +++- .../custom_op_plugin_kernel.cc | 2 + .../kernel_mod_impl/custom_op_plugin_kernel.h | 4 +- .../custom/kernel_mod_impl/op_plugin_utils.cc | 2 + .../custom/kernel_mod_impl/op_plugin_utils.h | 2 + .../kernel_mod_impl/reg_op_plugin_kernels.cc | 2 + .../cpu/pyboost/pyboost_op_plugin_utils.h | 5 +- .../template/pyboost_cpu_call_template.tpl | 2 +- .../pyboost_cpu_customize_call_template.tpl | 2 +- .../mock_op_plugin/bool_tensor_iterator.h | 96 +++++++++++++++++++ .../mock_op_plugin/cumsum_ext_mock.cc | 5 +- .../mock_op_plugin/logical_and_mock.cc | 50 +++++++++- tests/st/ops/op_plugin/test_op_plugin.py | 51 ++++++++-- 17 files changed, 237 insertions(+), 41 deletions(-) create mode 100644 tests/st/ops/op_plugin/mock_op_plugin/bool_tensor_iterator.h diff --git a/mindspore/ccsrc/plugin/cpu/cpu_device_context.cc b/mindspore/ccsrc/plugin/cpu/cpu_device_context.cc index 9325b2ed23c..5382e421868 100644 --- a/mindspore/ccsrc/plugin/cpu/cpu_device_context.cc +++ b/mindspore/ccsrc/plugin/cpu/cpu_device_context.cc @@ -462,8 +462,8 @@ void CPUKernelExecutor::SetOperatorInfo(const KernelGraphPtr &graph) const { } kernel::KernelModPtr CPUKernelExecutor::CreateKernelMod(const std::string &op_name) const { - if (kernel::IsOpPluginKernel(op_name)) { - return kernel::Factory::Instance().Create(op_name); + if (kernel::op_plugin::IsOpPluginKernel(op_name)) { + return kernel::Factory::Instance().Create(op_name); } return kernel::Factory::Instance().Create(op_name); } diff --git a/mindspore/ccsrc/plugin/cpu/kernel_executor/kernel_select/kernel_select_cpu.cc b/mindspore/ccsrc/plugin/cpu/kernel_executor/kernel_select/kernel_select_cpu.cc index 23bb80b1001..a7552485dca 100644 --- a/mindspore/ccsrc/plugin/cpu/kernel_executor/kernel_select/kernel_select_cpu.cc +++ b/mindspore/ccsrc/plugin/cpu/kernel_executor/kernel_select/kernel_select_cpu.cc @@ -737,7 +737,7 @@ std::pair SetKernelInfoWithMsg(const CNodePtr &kerne static std::once_flag once; std::call_once(once, callback::CommonCallback::GetInstance().GetCallback( "RegisterOpPluginKernels")); // register op plugin kernels - static const auto &op_plugin_kernels = kernel::GetAllOpPluginKernelNames(); + static const auto &op_plugin_kernels = kernel::op_plugin::GetAllOpPluginKernelNames(); if (op_plugin_kernels.find(op_name) != op_plugin_kernels.end()) { UpdateCustomKernelBuildInfo(kernel_node, false); return {}; diff --git a/mindspore/ccsrc/pynative/utils/pynative_utils.cc b/mindspore/ccsrc/pynative/utils/pynative_utils.cc index 75fe4f1a093..2d6c35ce048 100644 --- a/mindspore/ccsrc/pynative/utils/pynative_utils.cc +++ b/mindspore/ccsrc/pynative/utils/pynative_utils.cc @@ -362,7 +362,9 @@ tensor::TensorPtr Common::ConvertStubNodeToTensor(const ValuePtr &v, bool need_c auto device_address = tensor->device_address(); MS_EXCEPTION_IF_NULL(device_address); const auto &device_target = device_address->GetDeviceType(); - if (device_target == device::DeviceType::kAscend) { + static const auto ms_op_plugin_path = common::EnvHelper::GetInstance()->GetEnv("MS_OP_PLUGIN_PATH"); + if (device_target == device::DeviceType::kAscend || + (ms_op_plugin_path != nullptr && device_target == device::DeviceType::kCPU)) { return tensor; } diff --git a/mindspore/ops/kernel/cpu/custom/kernel_mod_impl/custom_kernel_input_info.h b/mindspore/ops/kernel/cpu/custom/kernel_mod_impl/custom_kernel_input_info.h index 69e1cbed7b6..82ca87d18c1 100644 --- a/mindspore/ops/kernel/cpu/custom/kernel_mod_impl/custom_kernel_input_info.h +++ b/mindspore/ops/kernel/cpu/custom/kernel_mod_impl/custom_kernel_input_info.h @@ -19,12 +19,13 @@ #include #include +#include -namespace mindspore { -class CustomKernelData { - public: - CustomKernelData() = default; - virtual ~CustomKernelData() = default; +namespace mindspore::kernel { +namespace op_plugin { +struct OpPluginTensorStorageInfo { + std::vector strides; + size_t storage_offset; }; // KernelInputInfo is an interface class. @@ -44,13 +45,6 @@ class KernelInputInfo { void SetWorkSpace(const std::vector &workspace) { workspace_ = workspace; } const std::vector &WorkSpace() const { return workspace_; } - void SetKernelData(CustomKernelData *kernel_data) { kernel_data_ = kernel_data; } - const CustomKernelData *KernelData() const { return kernel_data_; } - - void DestructKernelData() { - delete kernel_data_; - kernel_data_ = nullptr; - } virtual size_t GetInputSize() = 0; virtual bool GetBoolInput(size_t idx) = 0; @@ -63,10 +57,9 @@ class KernelInputInfo { virtual std::vector> GetInt2DVecInput(size_t idx) = 0; virtual std::vector> GetFloat2DVecInput(size_t idx) = 0; virtual int GetInputTypeId(size_t idx) = 0; + virtual std::optional GetInputTensorLayout(size_t idx) = 0; std::vector workspace_; - - private: - CustomKernelData *kernel_data_{nullptr}; }; -} // namespace mindspore +} // namespace op_plugin +} // namespace mindspore::kernel #endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CUSTOM_CUSTOM_KERNEL_INPUT_INFO_H_ diff --git a/mindspore/ops/kernel/cpu/custom/kernel_mod_impl/custom_kernel_input_info_impl.h b/mindspore/ops/kernel/cpu/custom/kernel_mod_impl/custom_kernel_input_info_impl.h index a31e2d91c41..6f6ab1d4652 100644 --- a/mindspore/ops/kernel/cpu/custom/kernel_mod_impl/custom_kernel_input_info_impl.h +++ b/mindspore/ops/kernel/cpu/custom/kernel_mod_impl/custom_kernel_input_info_impl.h @@ -19,10 +19,12 @@ #include #include +#include #include "include/runtime/hardware_abstract/kernel_base/kernel_tensor.h" #include "kernel/cpu/custom/kernel_mod_impl/custom_kernel_input_info.h" -namespace mindspore { +namespace mindspore::kernel { +namespace op_plugin { class KernelInputInfoImpl : public KernelInputInfo { public: KernelInputInfoImpl() = default; @@ -53,8 +55,22 @@ class KernelInputInfoImpl : public KernelInputInfo { int GetInputTypeId(size_t idx) { return static_cast(inputs_[idx]->dtype_id()); } + std::optional GetInputTensorLayout(size_t idx) { + if (inputs_[idx]->type_id() != TypeId::kObjectTypeTensorType) { + return std::nullopt; + } + const auto &input = inputs_[idx]; + if (input->tensor_storage_info() == nullptr) { + return std::nullopt; + } + const auto &strides = input->tensor_storage_info()->strides; + const auto &storage_offset = input->tensor_storage_info()->storage_offset; + return std::make_optional(OpPluginTensorStorageInfo{strides, storage_offset}); + } + private: std::vector inputs_; }; -} // namespace mindspore +} // namespace op_plugin +} // namespace mindspore::kernel #endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CUSTOM_CUSTOM_KERNEL_INPUT_INFO_IMPL_H_ diff --git a/mindspore/ops/kernel/cpu/custom/kernel_mod_impl/custom_op_plugin_kernel.cc b/mindspore/ops/kernel/cpu/custom/kernel_mod_impl/custom_op_plugin_kernel.cc index 0aec50440da..8e38695d595 100644 --- a/mindspore/ops/kernel/cpu/custom/kernel_mod_impl/custom_op_plugin_kernel.cc +++ b/mindspore/ops/kernel/cpu/custom/kernel_mod_impl/custom_op_plugin_kernel.cc @@ -33,6 +33,7 @@ namespace mindspore { namespace kernel { +namespace op_plugin { void CustomOpPluginCpuKernelMod::SetKernelPath() { const char *op_plugin_path = common::EnvHelper::GetInstance()->GetEnv("MS_OP_PLUGIN_PATH"); @@ -152,5 +153,6 @@ int CustomOpPluginCpuKernelMod::Resize(const std::vector &inputs return static_cast(KRET_OK); } MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, CustomOpPlugin, CustomOpPluginCpuKernelMod); +} // namespace op_plugin } // namespace kernel } // namespace mindspore diff --git a/mindspore/ops/kernel/cpu/custom/kernel_mod_impl/custom_op_plugin_kernel.h b/mindspore/ops/kernel/cpu/custom/kernel_mod_impl/custom_op_plugin_kernel.h index d5ecef0e3c2..d134eb395c8 100644 --- a/mindspore/ops/kernel/cpu/custom/kernel_mod_impl/custom_op_plugin_kernel.h +++ b/mindspore/ops/kernel/cpu/custom/kernel_mod_impl/custom_op_plugin_kernel.h @@ -26,7 +26,7 @@ namespace mindspore { namespace kernel { - +namespace op_plugin { class OPS_HOST_API CustomOpPluginCpuKernelMod : public NativeCpuKernelMod { public: CustomOpPluginCpuKernelMod() = default; @@ -53,7 +53,7 @@ class OPS_HOST_API CustomOpPluginCpuKernelMod : public NativeCpuKernelMod { private: void SetKernelPath(); }; +} // namespace op_plugin } // namespace kernel } // namespace mindspore - #endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CUSTOM_CUSTOM_OP_PLUGIN_CPU_KERNEL_H_ diff --git a/mindspore/ops/kernel/cpu/custom/kernel_mod_impl/op_plugin_utils.cc b/mindspore/ops/kernel/cpu/custom/kernel_mod_impl/op_plugin_utils.cc index 673fa29eeb1..3fa52781452 100644 --- a/mindspore/ops/kernel/cpu/custom/kernel_mod_impl/op_plugin_utils.cc +++ b/mindspore/ops/kernel/cpu/custom/kernel_mod_impl/op_plugin_utils.cc @@ -54,6 +54,7 @@ #include "utils/log_adapter.h" namespace mindspore::kernel { +namespace op_plugin { void *GetOpPluginHandle() { static bool is_initialized = false; static void *handle = nullptr; @@ -195,4 +196,5 @@ int LaunchOpPluginKernel(const std::string &op_name, OpPluginKernelParam *param) param->shapes.data(), param->dtypes.data(), reinterpret_cast(¶m->kernel_info), param->stream); } +} // namespace op_plugin } // namespace mindspore::kernel diff --git a/mindspore/ops/kernel/cpu/custom/kernel_mod_impl/op_plugin_utils.h b/mindspore/ops/kernel/cpu/custom/kernel_mod_impl/op_plugin_utils.h index f395e691485..60446e35191 100644 --- a/mindspore/ops/kernel/cpu/custom/kernel_mod_impl/op_plugin_utils.h +++ b/mindspore/ops/kernel/cpu/custom/kernel_mod_impl/op_plugin_utils.h @@ -24,6 +24,7 @@ #include "kernel/cpu/utils/visible.h" namespace mindspore::kernel { +namespace op_plugin { struct OpPluginKernelParam { std::vector params; std::vector ndims; @@ -43,6 +44,7 @@ int LaunchOpPluginKernel(const std::string &op_name, OpPluginKernelParam *param) OpPluginKernelParam CreateOpPluginParam(const std::vector &inputs, const std::vector &outputs, const std::vector &workspace); +} // namespace op_plugin } // namespace mindspore::kernel #endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CUSTOM_OP_PLUGIN_UTILS_H_ diff --git a/mindspore/ops/kernel/cpu/custom/kernel_mod_impl/reg_op_plugin_kernels.cc b/mindspore/ops/kernel/cpu/custom/kernel_mod_impl/reg_op_plugin_kernels.cc index e77dcdbd0c1..bf7cebae40f 100644 --- a/mindspore/ops/kernel/cpu/custom/kernel_mod_impl/reg_op_plugin_kernels.cc +++ b/mindspore/ops/kernel/cpu/custom/kernel_mod_impl/reg_op_plugin_kernels.cc @@ -20,6 +20,7 @@ #include "include/utils/callback.h" namespace mindspore::kernel { +namespace op_plugin { void RegisterOpPluginKernels() { const auto &op_names = GetAllOpPluginKernelNames(); for (const auto &op_name : op_names) { @@ -29,4 +30,5 @@ void RegisterOpPluginKernels() { } } REGISTER_COMMON_CALLBACK(RegisterOpPluginKernels); +} // namespace op_plugin } // namespace mindspore::kernel diff --git a/mindspore/ops/kernel/cpu/pyboost/pyboost_op_plugin_utils.h b/mindspore/ops/kernel/cpu/pyboost/pyboost_op_plugin_utils.h index 2e64dbb19ef..306931c0c19 100644 --- a/mindspore/ops/kernel/cpu/pyboost/pyboost_op_plugin_utils.h +++ b/mindspore/ops/kernel/cpu/pyboost/pyboost_op_plugin_utils.h @@ -136,8 +136,9 @@ std::enable_if_t, std::vectorstream_id(), {op->output_abs()}, outputs); std::vector workspace_tensors; - auto op_plugin_param = CreateOpPluginParam(input_address_info.first, output_address_info.first, workspace_tensors); - auto ret = LaunchOpPluginKernel(op_name, &op_plugin_param); + auto op_plugin_param = + op_plugin::CreateOpPluginParam(input_address_info.first, output_address_info.first, workspace_tensors); + auto ret = op_plugin::LaunchOpPluginKernel(op_name, &op_plugin_param); if (ret != 0) { MS_LOG(EXCEPTION) << "Launch op plugin kernel failed, op name: " << op_name << ", return code: " << ret; } diff --git a/mindspore/ops/kernel/cpu/pyboost/template/pyboost_cpu_call_template.tpl b/mindspore/ops/kernel/cpu/pyboost/template/pyboost_cpu_call_template.tpl index 59018660e3f..072314a5655 100644 --- a/mindspore/ops/kernel/cpu/pyboost/template/pyboost_cpu_call_template.tpl +++ b/mindspore/ops/kernel/cpu/pyboost/template/pyboost_cpu_call_template.tpl @@ -1,6 +1,6 @@ MS_LOG(DEBUG) << op_name() << " call start"; -if (IsOpPluginKernel(op_name())) { +if (op_plugin::IsOpPluginKernel(op_name())) { outputs_ = PyboostLaunchOpPluginKernel<${inplace_indices}>(get_op(), ${call_args}); return ${return_values}; } diff --git a/mindspore/ops/kernel/cpu/pyboost/template/pyboost_cpu_customize_call_template.tpl b/mindspore/ops/kernel/cpu/pyboost/template/pyboost_cpu_customize_call_template.tpl index d2fa8fedb76..c032424a0b6 100644 --- a/mindspore/ops/kernel/cpu/pyboost/template/pyboost_cpu_customize_call_template.tpl +++ b/mindspore/ops/kernel/cpu/pyboost/template/pyboost_cpu_customize_call_template.tpl @@ -1,5 +1,5 @@ ProfileTrackerTask(); - if (IsOpPluginKernel(op_name())) { + if (op_plugin::IsOpPluginKernel(op_name())) { outputs_ = PyboostLaunchOpPluginKernel<${inplace_indices}>(get_op(), ${call_args}); return ${return_values}; } diff --git a/tests/st/ops/op_plugin/mock_op_plugin/bool_tensor_iterator.h b/tests/st/ops/op_plugin/mock_op_plugin/bool_tensor_iterator.h new file mode 100644 index 00000000000..38b5f7f2de7 --- /dev/null +++ b/tests/st/ops/op_plugin/mock_op_plugin/bool_tensor_iterator.h @@ -0,0 +1,96 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_TESTS_ST_OPS_OP_PLUGIN_MOCK_OP_PLUGIN_BOOL_TENSOR_ITERATOR_H_ +#define MINDSPORE_TESTS_ST_OPS_OP_PLUGIN_MOCK_OP_PLUGIN_BOOL_TENSOR_ITERATOR_H_ + +#include +#include + +// A simple row-major iterator over a strided boolean tensor. +class BoolTensorIterator { + public: + BoolTensorIterator(const bool *base_address, const std::vector &shape, + const std::vector &strides_in_elements, int64_t offset_in_elements) + : base_(base_address), + shape_(shape), + strides_elems_(strides_in_elements), + index_(shape.size(), 0), + dims_(static_cast(shape.size())), + total_elements_(ComputeNumElements(shape)), + consumed_elements_(0), + base_offset_elems_(offset_in_elements) {} + + bool has_next() const { return consumed_elements_ < total_elements_; } + + bool next() { + // Caller must ensure has_next() == true + const int64_t elem_offset = ComputeCurrentElementOffset(); + const bool value = *(base_ + elem_offset); + AdvanceIndex(); + ++consumed_elements_; + return value; + } + + void reset() { + std::fill(index_.begin(), index_.end(), 0); + consumed_elements_ = 0; + } + + private: + static int64_t ComputeNumElements(const std::vector &shape) { + if (shape.empty()) { + return 1; // scalar + } + int64_t num = 1; + for (int64_t dim : shape) { + num *= dim; + } + return num; + } + + int64_t ComputeCurrentElementOffset() const { + int64_t offset = base_offset_elems_; + for (int64_t i = 0; i < dims_; ++i) { + offset += index_[i] * strides_elems_[static_cast(i)]; + } + return offset; + } + + void AdvanceIndex() { + if (dims_ == 0) { + return; // scalar + } + for (int64_t d = dims_ - 1; d >= 0; --d) { + ++index_[static_cast(d)]; + if (index_[static_cast(d)] < shape_[static_cast(d)]) { + break; + } + index_[static_cast(d)] = 0; + } + } + + const bool *base_; + std::vector shape_; + std::vector strides_elems_; + std::vector index_; + int64_t dims_; + int64_t total_elements_; + int64_t consumed_elements_; + int64_t base_offset_elems_; +}; + +#endif // MINDSPORE_TESTS_ST_OPS_OP_PLUGIN_MOCK_OP_PLUGIN_BOOL_TENSOR_ITERATOR_H_ diff --git a/tests/st/ops/op_plugin/mock_op_plugin/cumsum_ext_mock.cc b/tests/st/ops/op_plugin/mock_op_plugin/cumsum_ext_mock.cc index fc3f37e1cd5..201b7f1989c 100644 --- a/tests/st/ops/op_plugin/mock_op_plugin/cumsum_ext_mock.cc +++ b/tests/st/ops/op_plugin/mock_op_plugin/cumsum_ext_mock.cc @@ -20,6 +20,8 @@ #include "custom_kernel_input_info.h" +using mindspore::kernel::op_plugin::KernelInputInfo; + extern "C" { // Mock implementation of the cumsum_ext operator. @@ -35,7 +37,7 @@ int CumsumExt(int nparam, void **params, int *ndims, int64_t **shapes, const cha const int64_t *x = static_cast(params[0]); int64_t *out = static_cast(params[nparam - 1]); - auto kernel_input_info = static_cast(extra); + auto kernel_input_info = static_cast(extra); if (kernel_input_info == nullptr) { std::cout << "Invalid kernel input info for cumsum_ext operator" << std::endl; return -1; @@ -62,5 +64,4 @@ int CumsumExt(int nparam, void **params, int *ndims, int64_t **shapes, const cha return 0; } - } // extern "C" diff --git a/tests/st/ops/op_plugin/mock_op_plugin/logical_and_mock.cc b/tests/st/ops/op_plugin/mock_op_plugin/logical_and_mock.cc index 360dd0f37df..6ac3cc2fbfc 100644 --- a/tests/st/ops/op_plugin/mock_op_plugin/logical_and_mock.cc +++ b/tests/st/ops/op_plugin/mock_op_plugin/logical_and_mock.cc @@ -18,6 +18,11 @@ #include #include +#include "custom_kernel_input_info.h" +#include "bool_tensor_iterator.h" + +using mindspore::kernel::op_plugin::KernelInputInfo; + extern "C" { // Mock implementation of the logical_and operator. // Not fully implemented, only for certain test cases. @@ -60,8 +65,51 @@ int LogicalAnd(int nparam, void **params, int *ndims, int64_t **shapes, const ch numel *= static_cast(d0); } + // Extract tensor layout (strides, storage_offset) from extra if available. + auto kernel_input_info = static_cast(extra); + std::vector shape_vec; + shape_vec.reserve(static_cast(dims)); + for (int i = 0; i < dims; ++i) { + shape_vec.push_back(shapes[0][i]); + } + + auto make_elem_strides = [&](size_t input_index) -> std::pair, int64_t> { + std::vector strides_elems(shape_vec.size(), 0); + int64_t offset_elems = 0; + if (kernel_input_info != nullptr) { + auto layout_opt = kernel_input_info->GetInputTensorLayout(input_index); + if (layout_opt.has_value()) { + const auto &layout = layout_opt.value(); + if (layout.strides.size() == shape_vec.size()) { + for (size_t i = 0; i < layout.strides.size(); ++i) { + strides_elems[i] = layout.strides[i]; + } + offset_elems = static_cast(layout.storage_offset); + return {strides_elems, offset_elems}; + } + } + } + // Fallback to contiguous layout in elements (row-major). + if (!shape_vec.empty()) { + strides_elems.back() = 1; + for (int64_t d = static_cast(shape_vec.size()) - 2; d >= 0; --d) { + strides_elems[static_cast(d)] = + strides_elems[static_cast(d + 1)] * shape_vec[static_cast(d + 1)]; + } + } + return {strides_elems, 0}; + }; + + auto [x_strides_elems, x_offset_elems] = make_elem_strides(0); + auto [y_strides_elems, y_offset_elems] = make_elem_strides(1); + + BoolTensorIterator it_x(x, shape_vec, x_strides_elems, x_offset_elems); + BoolTensorIterator it_y(y, shape_vec, y_strides_elems, y_offset_elems); + for (size_t i = 0; i < numel; ++i) { - out[i] = (x[i] || y[i]); // wrong implementation by purpose to ensure op plugin is used + const bool vx = it_x.next(); + const bool vy = it_y.next(); + out[i] = (vx || vy); // keep mocked behavior (intentional OR) } return 0; } diff --git a/tests/st/ops/op_plugin/test_op_plugin.py b/tests/st/ops/op_plugin/test_op_plugin.py index 0d2828f8fb2..d4797f13623 100644 --- a/tests/st/ops/op_plugin/test_op_plugin.py +++ b/tests/st/ops/op_plugin/test_op_plugin.py @@ -32,6 +32,9 @@ from mindspore.ops.auto_generate.gen_ops_prim import expand_dims_view_op def _configure_and_build_mock_plugin() -> str: """Configure and build the mock op plugin and return the built library path.""" + system = platform.system().lower() + if system == "windows": # windows is not supported for now + return "" this_dir = Path(__file__).resolve().parent plugin_src_dir = this_dir / "mock_op_plugin" build_dir = plugin_src_dir / "build" @@ -41,8 +44,6 @@ def _configure_and_build_mock_plugin() -> str: # include path for custom_kernel_input_info.h include_dir = os.path.join(repo_root, "include", "mindspore", "ops", "kernel", "cpu", "custom", "kernel_mod_impl") - system = platform.system().lower() - cmake_args = [ "cmake", "-S", @@ -115,7 +116,10 @@ def view_func(x): return out -def test_cumsum(mode): +@arg_mark(plat_marks=['cpu_linux', 'cpu_macos'], level_mark='level0', card_mark='onecard', + essential_mark='essential') +@pytest.mark.parametrize('mode', ['kbk', 'pynative']) +def test_normal_op(mode): """ Feature: op_plugin kernel Description: Test op_plugin kernel @@ -129,7 +133,10 @@ def test_cumsum(mode): assert np.allclose(output.asnumpy(), expect) -def test_logical_and(mode): +@arg_mark(plat_marks=['cpu_linux', 'cpu_macos'], level_mark='level0', card_mark='onecard', + essential_mark='essential') +@pytest.mark.parametrize('mode', ['kbk', 'pynative']) +def test_op_with_existing_cpu_kernelmod(mode): """ Feature: op_plugin kernel Description: Test op_plugin kernel when normal cpu kernelmod exists @@ -146,7 +153,10 @@ def test_logical_and(mode): assert np.allclose(output.asnumpy(), expect) -def test_inplace_relu(mode): +@arg_mark(plat_marks=['cpu_linux', 'cpu_macos'], level_mark='level0', card_mark='onecard', + essential_mark='essential') +@pytest.mark.parametrize('mode', ['kbk', 'pynative']) +def test_inplace_op(mode): """ Feature: op_plugin kernel Description: Test op_plugin kernel for inplace op @@ -158,8 +168,10 @@ def test_inplace_relu(mode): inplace_relu_forward_func(x) assert np.allclose(x.asnumpy(), expect) - -def test_view_feature(mode): +@arg_mark(plat_marks=['cpu_linux', 'cpu_macos'], level_mark='level0', card_mark='onecard', + essential_mark='essential') +@pytest.mark.parametrize('mode', ['pynative']) +def test_view_op(mode): """ Feature: op_plugin kernel Description: Test op_plugin kernel for view feature. Disabled for now @@ -170,7 +182,26 @@ def test_view_feature(mode): expected_x_after_inplace_relu = np.maximum(x.asnumpy(), 0.0) expect_view = expected_x_after_inplace_relu.reshape(6, 1) view = view_func(x) - # TODO: fix the issue of view feature in op plugin - # assert np.allclose(x.asnumpy(), expected_x_after_inplace_relu) - # assert np.allclose(view.asnumpy(), expect_view) + assert np.allclose(x.asnumpy(), expected_x_after_inplace_relu) + assert np.allclose(view.asnumpy(), expect_view) assert expect_view.shape == view.shape + +@arg_mark(plat_marks=['cpu_linux', 'cpu_macos'], level_mark='level0', card_mark='onecard', + essential_mark='essential') +@pytest.mark.parametrize('mode', ['pynative', 'kbk']) +def test_noncontiguous_input_op(mode): + """ + Feature: op_plugin kernel + Description: Test op_plugin kernel for noncontiguous input op + Expectation: Correct result. + """ + set_mode(mode) + orig_x = np.random.randint(0, 2, size=(4, 4)) == 1 + orig_y = np.random.randint(0, 2, size=(4, 4)) == 1 + x_np = orig_x[1:, ::2] + y_np = orig_y[1:, ::2] + x_noncontiguous = Tensor(orig_x, ms.bool_)[1:, ::2] + y_noncontiguous = Tensor(orig_y, ms.bool_)[1:, ::2] + expect = np.logical_or(x_np, y_np) + output = logical_and_forward_func(x_noncontiguous, y_noncontiguous) + assert np.allclose(output.asnumpy(), expect) -- Gitee