From b7df4dedf2c3ef5710bbedce8232563c20fbabc2 Mon Sep 17 00:00:00 2001 From: dayschan Date: Mon, 29 Sep 2025 18:59:12 +0800 Subject: [PATCH 1/3] add triton op supports --- ops/framework/triton/graphmode/test.py | 7 + .../triton/graphmode/triton_kernel_mod.cc | 109 ++++++++++++ .../triton/graphmode/triton_kernel_mod.h | 61 +++++++ .../triton/graphmode/triton_op_def.cc | 163 ++++++++++++++++++ .../triton/graphmode/triton_op_infer_impl.h | 37 ++++ tests/st/triton_cases/triton_ops.py | 36 ++++ 6 files changed, 413 insertions(+) create mode 100644 ops/framework/triton/graphmode/test.py create mode 100644 ops/framework/triton/graphmode/triton_kernel_mod.cc create mode 100644 ops/framework/triton/graphmode/triton_kernel_mod.h create mode 100644 ops/framework/triton/graphmode/triton_op_def.cc create mode 100644 ops/framework/triton/graphmode/triton_op_infer_impl.h create mode 100644 tests/st/triton_cases/triton_ops.py diff --git a/ops/framework/triton/graphmode/test.py b/ops/framework/triton/graphmode/test.py new file mode 100644 index 0000000..3c63802 --- /dev/null +++ b/ops/framework/triton/graphmode/test.py @@ -0,0 +1,7 @@ +import mindspore +from mindspore import ops + +my_ops = ops.CustomOpBuilder("triton_op", + ["triton_op_def.cc", "triton_kernel_mod.cc"], + include_paths=["/home/chendeshi/ms_custom_ops"], + backend="Ascend").load() diff --git a/ops/framework/triton/graphmode/triton_kernel_mod.cc b/ops/framework/triton/graphmode/triton_kernel_mod.cc new file mode 100644 index 0000000..dfdc5a3 --- /dev/null +++ b/ops/framework/triton/graphmode/triton_kernel_mod.cc @@ -0,0 +1,109 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/framework/triton/graphmode/triton_kernel_mod.h" + +namespace ms_custom_ops { +namespace { +// This function is derived from: +// https://gitee.com/mindspore/mindspore/blob/master/mindspore/ccsrc/frontend/operator/py_execute_py.h +// original function: GetValueByAbstract +ValuePtr GetValueByKernelTensor(const kernel::KernelTensor *kernel_tensor) { + MS_EXCEPTION_IF_NULL(kernel_tensor); + if (kernel_tensor->GetValueTrack() != nullptr && !kernel_tensor->GetValueTrack()->isa()) { + return kernel_tensor->GetValueTrack(); + } else if (IsShapeEmpty(kernel_tensor->GetShapeVector())) { + auto type_id = + (kernel_tensor->dtype_id() == TypeId::kTypeUnknown ? 
        TypeId::kNumberTypeInt64 : kernel_tensor->dtype_id());
+    return tensor::from_spec(type_id, kernel_tensor->GetShapeVector(), device::DeviceType::kAscend);
+  }
+
+  MS_LOG(DEBUG) << "Type:" << kernel_tensor->dtype_id() << " shape:" << kernel_tensor->GetShapeVector()
+                << " size:" << kernel_tensor->size();
+  auto real_value = kernel_tensor->GetValue();
+  MS_EXCEPTION_IF_NULL(real_value);
+  if (!real_value->isa<KernelTensorValue>()) {
+    MS_LOG(EXCEPTION) << "Invalid kernel tensor value:" << real_value->ToString();
+  }
+
+  if (kernel_tensor->GetType() != nullptr && kernel_tensor->GetType()->isa<Number>()) {
+    auto kernel_tensor_value = real_value->cast<KernelTensorValuePtr>();
+    MS_EXCEPTION_IF_NULL(kernel_tensor_value);
+    return common::AnfAlgo::ValueToScalar(kernel_tensor_value, kernel_tensor->GetType()->type_id());
+  }
+  return std::make_shared<tensor::Tensor>(kernel_tensor->dtype_id(), kernel_tensor->GetShapeVector(),
+                                          kernel_tensor->device_address());
+}
+}  // namespace
+
+bool TritonKernelMod::Init(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) {
+  func_id_ = GetValue<int64_t>(primitive_->GetAttr("fn_id"));
+  return true;
+}
+
+bool TritonKernelMod::Launch(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &,
+                             const std::vector<KernelTensor *> &, void *stream_ptr) {
+  if (!init_) {
+    py_func_ = GetPythonFunc();
+    init_ = true;
+  }
+  return ExecuteKernel(inputs);
+}
+
+bool TritonKernelMod::ExecuteKernel(const std::vector<KernelTensor *> &inputs) {
+  if (Py_IsInitialized() != true) {
+    MS_LOG(ERROR) << "Py_IsInitialized failed.";
+    return false;
+  }
+
+  py::gil_scoped_acquire gil_acquire;
+  py::object result;
+  if (inputs.size() > 0) {
+    py::tuple args(inputs.size());
+    for (size_t i = 0; i < inputs.size(); i++) {
+      args[i] = ValueToPyData(GetValueByKernelTensor(inputs[i]));
+    }
+    result = py_func_(*args);
+  } else {
+    result = py_func_();
+  }
+
+  if (!result.is_none()) {
+    MS_LOG(ERROR) << "This triton function should return none";
+    return false;
+  }
+  return true;
+}
+
+py::function TritonKernelMod::GetPythonFunc() const {
+  py::gil_scoped_acquire gil_acquire;
+  static const std::string &module_name = "mindspore.ops.operations._pyfunc_registry";
+  static const std::string &entrance = "get_pyfunc";
+  py::module mod = py::module::import(module_name.c_str());
+  py::object get_pyfunc_obj = mod.attr(entrance.c_str());
+  if (get_pyfunc_obj.is_none()) {
+    MS_LOG(EXCEPTION) << "Cannot find a python function named " << entrance << " in module " << module_name;
+  }
+
+  py::function get_pyfunc = get_pyfunc_obj.cast<py::function>();
+  py::object py_func_obj = get_pyfunc(py::int_(func_id_));
+  if (py_func_obj.is_none()) {
+    MS_LOG(EXCEPTION) << "Cannot find python func with id: " << func_id_;
+  }
+
+  return py_func_obj.cast<py::function>();
+}
+}  // namespace ms_custom_ops
diff --git a/ops/framework/triton/graphmode/triton_kernel_mod.h b/ops/framework/triton/graphmode/triton_kernel_mod.h
new file mode 100644
index 0000000..34238c2
--- /dev/null
+++ b/ops/framework/triton/graphmode/triton_kernel_mod.h
@@ -0,0 +1,61 @@
+/**
+ * Copyright 2025 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MS_CUSTOM_OPS_FRAMEWORK_TRITON_GRAPHMODE_TRITON_KERNEL_MOD_H_
+#define MS_CUSTOM_OPS_FRAMEWORK_TRITON_GRAPHMODE_TRITON_KERNEL_MOD_H_
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "custom_op_api.h"
+#include "ops/framework/module.h"
+
+namespace ms_custom_ops {
+using namespace mindspore;
+using namespace mindspore::ops;
+using namespace mindspore::kernel;
+
+class TritonKernelMod : public KernelMod {
+ public:
+  TritonKernelMod() = default;
+  ~TritonKernelMod() override = default;
+
+  bool Init(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) override;
+
+  // Construct arguments with raw memory and invoke the Python function.
+  bool Launch(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &,
+              const std::vector<KernelTensor *> &, void *stream_ptr) override;
+
+  std::vector<KernelAttr> GetOpSupport() override {
+    MS_LOG(EXCEPTION) << "This interface is not supported in the triton kernel.";
+  }
+
+ protected:
+  // Get the Python function from the registry.
+  py::function GetPythonFunc() const;
+  bool ExecuteKernel(const std::vector<KernelTensor *> &args);
+
+  // A Python object is not acceptable as a `Primitive` attribute, so we pass a unique key instead of the Python function.
+  // mindspore.ops.operations.PyFunc stores the Python function in a dict and passes the key to the backend kernel.
+  // The kernel gets the Python function by this key from the dict when it is first invoked.
+  int64_t func_id_{0};
+  py::function py_func_;
+  bool init_{false};
+  // The kernel holds the input tensors during execution to avoid dynamic malloc/free of host memory.
+  // std::vector> input_tensors_;
+};
+}  // namespace ms_custom_ops
+#endif  // MS_CUSTOM_OPS_FRAMEWORK_TRITON_GRAPHMODE_TRITON_KERNEL_MOD_H_
diff --git a/ops/framework/triton/graphmode/triton_op_def.cc b/ops/framework/triton/graphmode/triton_op_def.cc
new file mode 100644
index 0000000..269447b
--- /dev/null
+++ b/ops/framework/triton/graphmode/triton_op_def.cc
@@ -0,0 +1,163 @@
+/**
+ * Copyright 2025 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include <stdexcept>
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include "custom_op_api.h"
+#include "ops/framework/triton/graphmode/triton_op_infer_impl.h"
+#include "ops/framework/triton/graphmode/triton_kernel_mod.h"
+
+namespace mindspore::ops {
+OP_DTYPE TypeStrToEnum(const std::string &s) {
+  static const std::unordered_map<std::string, OP_DTYPE> kMap = {
+    // basic types
+    {"bool", DT_BOOL},
+    {"int", DT_INT},
+    {"float", DT_FLOAT},
+    {"number", DT_NUMBER},
+    {"tensor", DT_TENSOR},
+    {"str", DT_STR},
+    {"any", DT_ANY},
+    {"type", DT_TYPE},
+    {"none", DT_NONE},
+
+    // tuple_*
+    {"tuple_bool", DT_TUPLE_BOOL},
+    {"tuple_int", DT_TUPLE_INT},
+    {"tuple_float", DT_TUPLE_FLOAT},
+    {"tuple_number", DT_TUPLE_NUMBER},
+    {"tuple_tensor", DT_TUPLE_TENSOR},
+    {"tuple_str", DT_TUPLE_STR},
+    {"tuple_any", DT_TUPLE_ANY},
+
+    // list_*
+    {"list_bool", DT_LIST_BOOL},
+    {"list_int", DT_LIST_INT},
+    {"list_float", DT_LIST_FLOAT},
+    {"list_number", DT_LIST_NUMBER},
+    {"list_tensor", DT_LIST_TENSOR},
+    {"list_str", DT_LIST_STR},
+    {"list_any", DT_LIST_ANY},
+  };
+
+  if (auto it = kMap.find(s); it != kMap.end()) {
+    return it->second;
+  }
+  throw std::invalid_argument("TypeStrToEnum: unknown dtype string: '" + s + "'");
+}
+
+ms_custom_ops::TritonInferImpl gCustom_TritonFuncImpl;
+static const OpDef gOpDefTemplate = {
+  /* name_ = */ "",
+  /* args_ = */ {},
+  /* returns_ = */ {},
+  /* signatures_ = */ {},
+  /* indexes_ = */ {},
+  /* func_impl_ = */ gCustom_TritonFuncImpl,
+  /* enable_dispatch_ = */ false,
+  /* is_view_ = */ false,
+  /* is_graph_view_ = */ false,
+};
+
+// Provides only: constructor + Arg(role='input'|'output') + Register()
+class TritonOpDefBuilder {
+ public:
+  explicit TritonOpDefBuilder(const std::string &name) : name_("CustomTriton_" + name), op_def_(NewOp()) {
+    op_def_->name_ = name_;
+  }
+
+  // role:
+  //  - "input": used as an input only
+  //  - "output": both an input and an in-place output; it is first added to args_, then an output with the same name is appended to returns_ and bound to this input's index
+  TritonOpDefBuilder &Arg(const std::string &arg_name, const std::string &role, const std::string &obj_type) {
+    if (role == "input") {
+      (void)AddInputImpl(arg_name, obj_type);
+    } else if (role == "output") {
+      const size_t idx = AddInputImpl(arg_name, obj_type);
+      AddOutputImpl(arg_name, obj_type, static_cast<int64_t>(idx));
+    } else {
+      MS_LOG(EXCEPTION) << "Unsupported role: " << role << ". Expect 'input' or 'output'.";
+    }
+    return *this;
+  }
+
+  void Register() {
+    if (mindspore::ops::GetOpDef(name_) != nullptr) {
+      MS_LOG(EXCEPTION) << "OpDef named '" << name_ << "' already exists.";
+    }
+    mindspore::ops::AddOpDef(name_, op_def_);
+
+    mindspore::kernel::CustomKernelFactory::Instance().Register(
+      name_, []() { return std::make_shared<ms_custom_ops::TritonKernelMod>(); });
+  }
+
+ private:
+  size_t AddInputImpl(const std::string &arg_name, const std::string &obj_type) {
+    if (op_def_->indexes_.find(arg_name) != op_def_->indexes_.end()) {
+      MS_LOG(EXCEPTION) << "Input arg '" << arg_name << "' already exists.";
+    }
+    OpInputArg arg;
+    arg.arg_name_ = arg_name;
+    arg.arg_dtype_ = TypeStrToEnum(obj_type);
+    arg.as_init_arg_ = false;
+    arg.arg_handler_.clear();
+    arg.cast_dtype_.clear();
+    arg.is_optional_ = false;
+    const size_t idx = op_def_->args_.size();
+    op_def_->args_.push_back(std::move(arg));
+    op_def_->indexes_[arg_name] = idx;
+    op_def_->signatures_.emplace_back(Signature(arg_name, SignatureEnumRW::kRWDefault,
+                                                SignatureEnumKind::kKindPositionalKeyword, nullptr,
+                                                SignatureEnumDType::kDTypeEmptyDefaultValue));
+    return idx;
+  }
+
+  void AddOutputImpl(const std::string &arg_name, const std::string &obj_type, int64_t input_index) {
+    OpOutputArg arg;
+    arg.arg_name_ = arg_name;
+    arg.arg_dtype_ = TypeStrToEnum(obj_type);
+    arg.inplace_input_index_ = input_index;  // bind the output to its own input index, making it an in-place alias
+    op_def_->returns_.push_back(std::move(arg));
+    op_def_->signatures_[input_index].rw = SignatureEnumRW::kRWWrite;
+  }
+
+  static OpDef *NewOp() {
+    static std::vector<std::unique_ptr<OpDef>> pool;
+    pool.emplace_back(std::make_unique<OpDef>(gOpDefTemplate));
+    return pool.back().get();
+  }
+
+  std::string name_;
+  OpDef *op_def_;
+};
+}  // namespace mindspore::ops
+
+MS_CUSTOM_OPS_EXTENSION_MODULE(m) {
+  py::class_<mindspore::ops::TritonOpDefBuilder>(m, "TritonOpDefBuilder")
+    .def(py::init<const std::string &>(), py::arg("name"))
+    .def(
+      "arg",
+      [](mindspore::ops::TritonOpDefBuilder &self, const std::string &arg_name, const std::string &role,
+         const std::string &obj_type) -> mindspore::ops::TritonOpDefBuilder & {
+        return self.Arg(arg_name, role, obj_type);
+      },
+      py::arg("arg_name"), py::arg("role"), py::arg("obj_type"), py::return_value_policy::reference_internal,
+      "Add an argument. role is one of {'input', 'output'}; obj_type is a canonical type string.")
+    .def("register_op", &mindspore::ops::TritonOpDefBuilder::Register, "Register the OpDef and its Triton kernel");
+}
diff --git a/ops/framework/triton/graphmode/triton_op_infer_impl.h b/ops/framework/triton/graphmode/triton_op_infer_impl.h
new file mode 100644
index 0000000..4f56e9b
--- /dev/null
+++ b/ops/framework/triton/graphmode/triton_op_infer_impl.h
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2025 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MS_CUSTOM_OPS_FRAMEWORK_TRITON_GRAPHMODE_TRITON_OP_INFER_IMPL_H_
+#define MS_CUSTOM_OPS_FRAMEWORK_TRITON_GRAPHMODE_TRITON_OP_INFER_IMPL_H_
+
+#include "custom_op_api.h"
+#include "ops/framework/module.h"
+
+using namespace mindspore;
+using namespace mindspore::ops;
+namespace ms_custom_ops {
+class TritonInferImpl : public OpFuncImpl {
+ public:
+  ShapeArray InferShape(const PrimitivePtr &, const InferInfoPtrList &) const override { return {{1}}; }
+
+  std::vector<TypeId> InferType(const PrimitivePtr &, const InferInfoPtrList &) const override {
+    return {TypeId::kNumberTypeInt64};
+  }
+
+  bool GeneralInferRegistered() const override { return true; }
+};
+}  // namespace ms_custom_ops
+#endif  // MS_CUSTOM_OPS_FRAMEWORK_TRITON_GRAPHMODE_TRITON_OP_INFER_IMPL_H_
diff --git a/tests/st/triton_cases/triton_ops.py b/tests/st/triton_cases/triton_ops.py
new file mode 100644
index 0000000..45ff484
--- /dev/null
+++ b/tests/st/triton_cases/triton_ops.py
@@ -0,0 +1,36 @@
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def add_kernel(x_ptr,  # *Pointer* to first input vector.
+               y_ptr,  # *Pointer* to second input vector.
+               output_ptr,  # *Pointer* to output vector.
+               n_elements,  # Size of the vector.
+               BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
+               ):
+    # There are multiple 'programs' processing different data. We identify which program
+    # we are here:
+    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
+    # This program will process inputs that are offset from the initial data.
+    # For instance, if you had a vector of length 256 and block_size of 64, the programs
+    # would each access the elements [0:64, 64:128, 128:192, 192:256].
+    # Note that offsets is a list of pointers:
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    # Create a mask to guard memory operations against out-of-bounds accesses.
+    mask = offsets < n_elements
+    # Load x and y from DRAM, masking out any extra elements in case the input is not a
+    # multiple of the block size.
+    x = tl.load(x_ptr + offsets, mask=mask)
+    y = tl.load(y_ptr + offsets, mask=mask)
+    output = x + y
+    # Write x + y back to DRAM.
+    tl.store(output_ptr + offsets, output, mask=mask)
+
+
+# @ms_triton(("x", "input", "tensor"), ("y", "input", "tensor"), ("out", "output", "tensor"))
+def add(x, y, out):
+    n_elements = out.numel()
+    def grid(meta): return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
+    return add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=1024)
--
Gitee


From 610b97db197c9bcfeda650bfd55c23ead51d1382 Mon Sep 17 00:00:00 2001
From: jiaorui
Date: Sat, 11 Oct 2025 11:38:29 +0800
Subject: [PATCH 2/3] ms_triton

---
 ops/framework/triton/graphmode/test.py        |  4 +-
 .../triton/graphmode/triton_kernel_mod.h      |  2 +
 tests/st/triton_cases/test_decorator.py       | 27 ++++++++++
 tests/st/triton_cases/triton_decorator.py     | 50 +++++++++++++++++++
 tests/st/triton_cases/triton_ops.py           |  2 +-
 5 files changed, 82 insertions(+), 3 deletions(-)
 create mode 100644 tests/st/triton_cases/test_decorator.py
 create mode 100644 tests/st/triton_cases/triton_decorator.py

diff --git a/ops/framework/triton/graphmode/test.py b/ops/framework/triton/graphmode/test.py
index 3c63802..93d2f05 100644
--- a/ops/framework/triton/graphmode/test.py
+++ b/ops/framework/triton/graphmode/test.py
@@ -2,6 +2,6 @@ import mindspore
 from mindspore import ops
 
 my_ops = ops.CustomOpBuilder("triton_op",
-                             ["triton_op_def.cc", "triton_kernel_mod.cc"],
-                             include_paths=["/home/chendeshi/ms_custom_ops"],
+                             ["triton_op_def.cc", "triton_kernel_mod.cc","/home/jiaorui/ms_custom_ops/ops/framework/module.cc"],
+                             include_paths=["/home/jiaorui/ms_custom_ops"],
                              backend="Ascend").load()
diff --git a/ops/framework/triton/graphmode/triton_kernel_mod.h b/ops/framework/triton/graphmode/triton_kernel_mod.h
index 34238c2..1b029d2 100644
--- a/ops/framework/triton/graphmode/triton_kernel_mod.h
+++ b/ops/framework/triton/graphmode/triton_kernel_mod.h
@@ -42,6 +42,7 @@ class TritonKernelMod : public KernelMod {
   std::vector<KernelAttr> GetOpSupport() override {
     MS_LOG(EXCEPTION) << "This interface is not supported in the triton kernel.";
   }
+  void set_fullname(const std::string &fullname) override { fullname_ = fullname; }
 
  protected:
   // Get the Python function from the registry.
@@ -54,6 +55,7 @@ class TritonKernelMod : public KernelMod {
   int64_t func_id_{0};
   py::function py_func_;
   bool init_{false};
+  std::string fullname_;
   // The kernel holds the input tensors during execution to avoid dynamic malloc/free of host memory.
// std::vector> input_tensors_; }; diff --git a/tests/st/triton_cases/test_decorator.py b/tests/st/triton_cases/test_decorator.py new file mode 100644 index 0000000..8d30229 --- /dev/null +++ b/tests/st/triton_cases/test_decorator.py @@ -0,0 +1,27 @@ +import mindspore as ms +from triton_decorator import ms_triton +import numpy as np + + +@ms_triton(("x", "input", "tensor"), ("y", "input", "tensor"), ("out", "output", "tensor")) +def custom_add(x, y, out): + print("=============================") + print("=============================") + + +class MyNet(ms.nn.Cell): + def __init__(self): + super(MyNet, self).__init__() + # self.add =custom_add + # print(self.add) + + def construct(self, x, y): + return custom_add(x, y, x) + + +ms.set_device("Ascend") +ms.set_context(mode=ms.GRAPH_MODE, save_graphs=True, save_graphs_path="./graphs") + +x = np.array([1, 2, 3], dtype=np.float16) +y = np.array([4, 5, 6], dtype=np.float16) +MyNet()(ms.Tensor(x), ms.Tensor(y)) \ No newline at end of file diff --git a/tests/st/triton_cases/triton_decorator.py b/tests/st/triton_cases/triton_decorator.py new file mode 100644 index 0000000..f53e234 --- /dev/null +++ b/tests/st/triton_cases/triton_decorator.py @@ -0,0 +1,50 @@ +# ms_triton.py +import functools +import threading +from mindspore.ops.primitive import Primitive, prim_arg_register +from mindspore.ops import signature as sig +from triton_op import TritonOpDefBuilder +from mindspore.ops.operations._pyfunc_registry import add_pyfunc + +_lock = threading.Lock() +_reg = set() + + +def _make_sigs(arg_specs): + return tuple(sig.make_sig(name, dtype=sig.sig_dtype.T) + for name, _, _ in arg_specs) + + +def ms_triton(*arg_specs): + def decorator(triton_fn): + op_name = triton_fn.__name__ + prim_cls_name = f"CustomTriton_{op_name}" + + with _lock: + if op_name not in _reg: + builder = TritonOpDefBuilder(op_name) + for name, role, obj_type in arg_specs: + builder.arg(name, role, obj_type) + builder.register_op() + _reg.add(op_name) + + func_id = id(triton_fn) + add_pyfunc(func_id, triton_fn) + + def init(self): + super(self.__class__, self).__init__(self.__class__.__name__) + self.add_prim_attr("fn_id", func_id) + + _PrimCls = type( + prim_cls_name, + (Primitive,), + { + "__mindspore_signature__": _make_sigs(arg_specs), + "__init__": prim_arg_register(init), + }, + ) + _prim_singleton = _PrimCls() + + return _prim_singleton + + return decorator diff --git a/tests/st/triton_cases/triton_ops.py b/tests/st/triton_cases/triton_ops.py index 45ff484..d811396 100644 --- a/tests/st/triton_cases/triton_ops.py +++ b/tests/st/triton_cases/triton_ops.py @@ -29,7 +29,7 @@ def add_kernel(x_ptr, # *Pointer* to first input vector. 
tl.store(output_ptr + offsets, output, mask=mask) -# @ms_triton(("x", "input", "tensor"), ("y", "input", "tensor"), ("out", "output", "tensor")) +@ms_triton(("x", "input", "tensor"), ("y", "input", "tensor"), ("out", "output", "tensor")) def add(x, y, out): n_elements = out.numel() def grid(meta): return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) -- Gitee From 13e6d4c34d21ab376b9019f5518ef8e5947d18f8 Mon Sep 17 00:00:00 2001 From: dayschan Date: Sat, 11 Oct 2025 18:10:03 +0800 Subject: [PATCH 3/3] test triton_add --- ops/framework/triton/graphmode/test.py | 7 ------ .../triton/graphmode/triton_kernel_mod.cc | 10 ++------ .../ms_custom_ops/triton_utils.py | 9 ++++---- tests/st/triton_cases/test_decorator.py | 4 ++-- tests/st/triton_cases/test_triton_add.py | 23 +++++++++++++++++++ tests/st/triton_cases/triton_ops.py | 5 +++- 6 files changed, 35 insertions(+), 23 deletions(-) delete mode 100644 ops/framework/triton/graphmode/test.py rename tests/st/triton_cases/triton_decorator.py => python/ms_custom_ops/triton_utils.py (81%) create mode 100644 tests/st/triton_cases/test_triton_add.py diff --git a/ops/framework/triton/graphmode/test.py b/ops/framework/triton/graphmode/test.py deleted file mode 100644 index 93d2f05..0000000 --- a/ops/framework/triton/graphmode/test.py +++ /dev/null @@ -1,7 +0,0 @@ -import mindspore -from mindspore import ops - -my_ops = ops.CustomOpBuilder("triton_op", - ["triton_op_def.cc", "triton_kernel_mod.cc","/home/jiaorui/ms_custom_ops/ops/framework/module.cc"], - include_paths=["/home/jiaorui/ms_custom_ops"], - backend="Ascend").load() diff --git a/ops/framework/triton/graphmode/triton_kernel_mod.cc b/ops/framework/triton/graphmode/triton_kernel_mod.cc index dfdc5a3..529c2d8 100644 --- a/ops/framework/triton/graphmode/triton_kernel_mod.cc +++ b/ops/framework/triton/graphmode/triton_kernel_mod.cc @@ -70,20 +70,14 @@ bool TritonKernelMod::ExecuteKernel(const std::vector &inputs) { } py::gil_scoped_acquire gil_acquire; - py::object result; if (inputs.size() > 0) { py::tuple args(inputs.size()); for (size_t i = 0; i < inputs.size(); i++) { args[i] = ValueToPyData(GetValueByKernelTensor(inputs[i])); } - result = py_func_(*args); + py_func_(*args); } else { - result = py_func_(); - } - - if (!result.is_none()) { - MS_LOG(ERROR) << "This triton function should return none"; - return false; + py_func_(); } return true; } diff --git a/tests/st/triton_cases/triton_decorator.py b/python/ms_custom_ops/triton_utils.py similarity index 81% rename from tests/st/triton_cases/triton_decorator.py rename to python/ms_custom_ops/triton_utils.py index f53e234..5c8d684 100644 --- a/tests/st/triton_cases/triton_decorator.py +++ b/python/ms_custom_ops/triton_utils.py @@ -1,18 +1,16 @@ -# ms_triton.py -import functools import threading from mindspore.ops.primitive import Primitive, prim_arg_register from mindspore.ops import signature as sig -from triton_op import TritonOpDefBuilder from mindspore.ops.operations._pyfunc_registry import add_pyfunc +from ms_custom_ops import TritonOpDefBuilder _lock = threading.Lock() _reg = set() def _make_sigs(arg_specs): - return tuple(sig.make_sig(name, dtype=sig.sig_dtype.T) - for name, _, _ in arg_specs) + return tuple(sig.make_sig(name) if io == "input" else sig.make_sig(name, sig.sig_rw.RW_WRITE) + for name, io, _ in arg_specs) def ms_triton(*arg_specs): @@ -33,6 +31,7 @@ def ms_triton(*arg_specs): def init(self): super(self.__class__, self).__init__(self.__class__.__name__) + self.add_prim_attr("side_effect_mem", True) self.add_prim_attr("fn_id", 
func_id) _PrimCls = type( diff --git a/tests/st/triton_cases/test_decorator.py b/tests/st/triton_cases/test_decorator.py index 8d30229..b6db8ab 100644 --- a/tests/st/triton_cases/test_decorator.py +++ b/tests/st/triton_cases/test_decorator.py @@ -1,5 +1,5 @@ import mindspore as ms -from triton_decorator import ms_triton +from ms_custom_ops.triton_utils import ms_triton import numpy as np @@ -24,4 +24,4 @@ ms.set_context(mode=ms.GRAPH_MODE, save_graphs=True, save_graphs_path="./graphs" x = np.array([1, 2, 3], dtype=np.float16) y = np.array([4, 5, 6], dtype=np.float16) -MyNet()(ms.Tensor(x), ms.Tensor(y)) \ No newline at end of file +MyNet()(ms.Tensor(x), ms.Tensor(y)) diff --git a/tests/st/triton_cases/test_triton_add.py b/tests/st/triton_cases/test_triton_add.py new file mode 100644 index 0000000..2ca406d --- /dev/null +++ b/tests/st/triton_cases/test_triton_add.py @@ -0,0 +1,23 @@ +import mindspore as ms +import numpy as np +import triton_ops + + +class MyNet(ms.nn.Cell): + def __init__(self): + super(MyNet, self).__init__() + + def construct(self, x, y): + out = ms.mint.empty_like(x) + t = triton_ops.add(x, y, out) + out2 = out * out + return out2 + + +ms.set_device("Ascend") +ms.set_context(mode=ms.GRAPH_MODE, save_graphs=True, save_graphs_path="./graphs") + +x = np.array([1, 2, 3], dtype=np.float16) +y = np.array([4, 5, 6], dtype=np.float16) +ret = MyNet()(ms.Tensor(x), ms.Tensor(y)) +print(ret) diff --git a/tests/st/triton_cases/triton_ops.py b/tests/st/triton_cases/triton_ops.py index d811396..ccaa0a0 100644 --- a/tests/st/triton_cases/triton_ops.py +++ b/tests/st/triton_cases/triton_ops.py @@ -1,6 +1,8 @@ import triton import triton.language as tl +from ms_custom_ops.triton_utils import ms_triton + @triton.jit def add_kernel(x_ptr, # *Pointer* to first input vector. @@ -33,4 +35,5 @@ def add_kernel(x_ptr, # *Pointer* to first input vector. def add(x, y, out): n_elements = out.numel() def grid(meta): return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) - return add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=1024) + add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=1024) + return None -- Gitee
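
For illustration, a second kernel could be registered through the same ms_triton decorator introduced in this series; the mul kernel below is a hypothetical sketch (kernel name, argument names, and the BLOCK_SIZE value are assumptions), and only the decorator, the (name, role, obj_type) argument specs, and the launch pattern follow tests/st/triton_cases/triton_ops.py.

import triton
import triton.language as tl

from ms_custom_ops.triton_utils import ms_triton


@triton.jit
def mul_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Same launch pattern as add_kernel: each program handles one BLOCK_SIZE chunk.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(output_ptr + offsets, x * y, mask=mask)


# "out" uses the "output" role, so the generated primitive marks it RW_WRITE and
# the kernel writes the result into it in place; the wrapper itself returns None.
@ms_triton(("x", "input", "tensor"), ("y", "input", "tensor"), ("out", "output", "tensor"))
def mul(x, y, out):
    n_elements = out.numel()
    def grid(meta): return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    mul_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=1024)
    return None

In graph mode it would then be called the same way as triton_ops.add in tests/st/triton_cases/test_triton_add.py: pre-allocate the output (e.g. with ms.mint.empty_like) and pass it as the in-place "out" argument.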