diff --git a/ops/framework/triton/graphmode/triton_kernel_mod.cc b/ops/framework/triton/graphmode/triton_kernel_mod.cc
new file mode 100644
index 0000000000000000000000000000000000000000..529c2d85d3e69c7609d10a2273d77dde40fdfb31
--- /dev/null
+++ b/ops/framework/triton/graphmode/triton_kernel_mod.cc
@@ -0,0 +1,103 @@
+/**
+ * Copyright 2025 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ops/framework/triton/graphmode/triton_kernel_mod.h"
+
+namespace ms_custom_ops {
+namespace {
+// This function is derived from:
+// https://gitee.com/mindspore/mindspore/blob/master/mindspore/ccsrc/frontend/operator/py_execute_py.h
+// original function: GetValueByAbstract
+ValuePtr GetValueByKernelTensor(const kernel::KernelTensor *kernel_tensor) {
+  MS_EXCEPTION_IF_NULL(kernel_tensor);
+  if (kernel_tensor->GetValueTrack() != nullptr && !kernel_tensor->GetValueTrack()->isa<ValueAny>()) {
+    return kernel_tensor->GetValueTrack();
+  } else if (IsShapeEmpty(kernel_tensor->GetShapeVector())) {
+    auto type_id =
+      (kernel_tensor->dtype_id() == TypeId::kTypeUnknown ? TypeId::kNumberTypeInt64 : kernel_tensor->dtype_id());
+    return tensor::from_spec(type_id, kernel_tensor->GetShapeVector(), device::DeviceType::kAscend);
+  }
+
+  MS_LOG(DEBUG) << "Type:" << kernel_tensor->dtype_id() << " shape:" << kernel_tensor->GetShapeVector()
+                << " size:" << kernel_tensor->size();
+  auto real_value = kernel_tensor->GetValue();
+  MS_EXCEPTION_IF_NULL(real_value);
+  if (!real_value->isa<KernelTensorValue>()) {
+    MS_LOG(EXCEPTION) << "Invalid kernel tensor value: " << real_value->ToString();
+  }
+
+  if (kernel_tensor->GetType() != nullptr && kernel_tensor->GetType()->isa<Number>()) {
+    auto kernel_tensor_value = real_value->cast<KernelTensorValuePtr>();
+    MS_EXCEPTION_IF_NULL(kernel_tensor_value);
+    return common::AnfAlgo::ValueToScalar(kernel_tensor_value, kernel_tensor->GetType()->type_id());
+  }
+  return std::make_shared<tensor::Tensor>(kernel_tensor->dtype_id(), kernel_tensor->GetShapeVector(),
+                                          kernel_tensor->device_address());
+}
+}  // namespace
+
+bool TritonKernelMod::Init(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) {
+  func_id_ = GetValue<int64_t>(primitive_->GetAttr("fn_id"));
+  return true;
+}
+
+bool TritonKernelMod::Launch(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &,
+                             const std::vector<KernelTensor *> &, void *stream_ptr) {
+  // Fetch the Python function lazily: the registry is only guaranteed to be
+  // populated once the decorator has run on the Python side.
+  if (!init_) {
+    py_func_ = GetPythonFunc();
+    init_ = true;
+  }
+  return ExecuteKernel(inputs);
+}
+
+bool TritonKernelMod::ExecuteKernel(const std::vector<KernelTensor *> &inputs) {
+  if (!Py_IsInitialized()) {
+    MS_LOG(ERROR) << "The Python interpreter is not initialized.";
+    return false;
+  }
+
+  py::gil_scoped_acquire gil_acquire;
+  if (inputs.size() > 0) {
+    py::tuple args(inputs.size());
+    for (size_t i = 0; i < inputs.size(); i++) {
+      args[i] = ValueToPyData(GetValueByKernelTensor(inputs[i]));
+    }
+    py_func_(*args);
+  } else {
+    py_func_();
+  }
+  return true;
+}
+
+py::function TritonKernelMod::GetPythonFunc() const {
+  py::gil_scoped_acquire gil_acquire;
+  static const std::string module_name = "mindspore.ops.operations._pyfunc_registry";
+  static const std::string entrance = "get_pyfunc";
+  py::module mod = py::module::import(module_name.c_str());
+  py::object get_pyfunc_obj = mod.attr(entrance.c_str());
+  if (get_pyfunc_obj.is_none()) {
+    MS_LOG(EXCEPTION) << "Cannot find a Python function named '" << entrance << "' in module '" << module_name << "'.";
+  }
+
+  py::function get_pyfunc = get_pyfunc_obj.cast<py::function>();
+  py::object py_func_obj = get_pyfunc(py::int_(func_id_));
+  if (py_func_obj.is_none()) {
+    MS_LOG(EXCEPTION) << "Cannot find a Python function with id: " << func_id_;
+  }
+
+  return py_func_obj.cast<py::function>();
+}
+}  // namespace ms_custom_ops
diff --git a/ops/framework/triton/graphmode/triton_kernel_mod.h b/ops/framework/triton/graphmode/triton_kernel_mod.h
new file mode 100644
index 0000000000000000000000000000000000000000..1b029d2501da87a4c6078804ec398e035f69fc10
--- /dev/null
+++ b/ops/framework/triton/graphmode/triton_kernel_mod.h
@@ -0,0 +1,63 @@
+/**
+ * Copyright 2025 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MS_CUSTOM_OPS_FRAMEWORK_TRITON_GRAPHMODE_TRITON_KERNEL_MOD_H_
+#define MS_CUSTOM_OPS_FRAMEWORK_TRITON_GRAPHMODE_TRITON_KERNEL_MOD_H_
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "custom_op_api.h"
+#include "ops/framework/module.h"
+
+namespace ms_custom_ops {
+using namespace mindspore;
+using namespace mindspore::ops;
+using namespace mindspore::kernel;
+
+class TritonKernelMod : public KernelMod {
+ public:
+  TritonKernelMod() = default;
+  ~TritonKernelMod() override = default;
+
+  bool Init(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) override;
+
+  // Construct arguments from raw memory and invoke the Python function.
+  bool Launch(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &,
+              const std::vector<KernelTensor *> &, void *stream_ptr) override;
+
+  std::vector<KernelAttr> GetOpSupport() override {
+    MS_LOG(EXCEPTION) << "This interface is not supported in the Triton kernel.";
+  }
+  void set_fullname(const std::string &fullname) override { fullname_ = fullname; }
+
+ protected:
+  // Get the Python function registered under `func_id_`.
+  py::function GetPythonFunc() const;
+  bool ExecuteKernel(const std::vector<KernelTensor *> &args);
+
+  // A Python object is not acceptable as a `Primitive` attribute, so we pass a unique key instead of the Python
+  // function. mindspore.ops.operations.PyFunc stores the Python function in a dict and passes the key to the
+  // backend kernel. The kernel fetches the Python function from the dict by this key when it is first invoked.
+  int64_t func_id_{0};
+  py::function py_func_;
+  bool init_{false};
+  std::string fullname_;
+  // The kernel could hold the input tensors during execution to avoid dynamically allocating/freeing host memory.
+  // std::vector<std::shared_ptr<tensor::Tensor>> input_tensors_;
+};
+}  // namespace ms_custom_ops
+#endif  // MS_CUSTOM_OPS_FRAMEWORK_TRITON_GRAPHMODE_TRITON_KERNEL_MOD_H_
diff --git a/ops/framework/triton/graphmode/triton_op_def.cc b/ops/framework/triton/graphmode/triton_op_def.cc
new file mode 100644
index 0000000000000000000000000000000000000000..269447b18ff569740937f0086eef0be0769f2394
--- /dev/null
+++ b/ops/framework/triton/graphmode/triton_op_def.cc
@@ -0,0 +1,163 @@
+/**
+ * Copyright 2025 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include <stdexcept>
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include "custom_op_api.h"
+#include "ops/framework/triton/graphmode/triton_op_infer_impl.h"
+#include "ops/framework/triton/graphmode/triton_kernel_mod.h"
+
+namespace mindspore::ops {
+OP_DTYPE TypeStrToEnum(const std::string &s) {
+  static const std::unordered_map<std::string, OP_DTYPE> kMap = {
+    // Basic types.
+    {"bool", DT_BOOL},
+    {"int", DT_INT},
+    {"float", DT_FLOAT},
+    {"number", DT_NUMBER},
+    {"tensor", DT_TENSOR},
+    {"str", DT_STR},
+    {"any", DT_ANY},
+    {"type", DT_TYPE},
+    {"none", DT_NONE},
+
+    // tuple_*
+    {"tuple_bool", DT_TUPLE_BOOL},
+    {"tuple_int", DT_TUPLE_INT},
+    {"tuple_float", DT_TUPLE_FLOAT},
+    {"tuple_number", DT_TUPLE_NUMBER},
+    {"tuple_tensor", DT_TUPLE_TENSOR},
+    {"tuple_str", DT_TUPLE_STR},
+    {"tuple_any", DT_TUPLE_ANY},
+
+    // list_*
+    {"list_bool", DT_LIST_BOOL},
+    {"list_int", DT_LIST_INT},
+    {"list_float", DT_LIST_FLOAT},
+    {"list_number", DT_LIST_NUMBER},
+    {"list_tensor", DT_LIST_TENSOR},
+    {"list_str", DT_LIST_STR},
+    {"list_any", DT_LIST_ANY},
+  };
+
+  if (auto it = kMap.find(s); it != kMap.end()) {
+    return it->second;
+  }
+  throw std::invalid_argument("TypeStrToEnum: unknown dtype string: '" + s + "'");
+}
+
+ms_custom_ops::TritonInferImpl gCustom_TritonFuncImpl;
+static const OpDef gOpDefTemplate = {
+  /* name_ = */ "",
+  /* args_ = */ {},
+  /* returns_ = */ {},
+  /* signatures_ = */ {},
+  /* indexes_ = */ {},
+  /* func_impl_ = */ gCustom_TritonFuncImpl,
+  /* enable_dispatch_ = */ false,
+  /* is_view_ = */ false,
+  /* is_graph_view_ = */ false,
+};
+
+// Only provides: constructor + Arg(role='input'|'output') + Register().
+class TritonOpDefBuilder {
+ public:
+  explicit TritonOpDefBuilder(const std::string &name) : name_("CustomTriton_" + name), op_def_(NewOp()) {
+    op_def_->name_ = name_;
+  }
+
+  // role:
+  //   - "input": the argument is an input only.
+  //   - "output": the argument is both an input and an output (in place). Implemented by first appending it to
+  //     args_, then appending an output with the same name to returns_, bound to its own input index.
+  TritonOpDefBuilder &Arg(const std::string &arg_name, const std::string &role, const std::string &obj_type) {
+    if (role == "input") {
+      (void)AddInputImpl(arg_name, obj_type);
+    } else if (role == "output") {
+      const size_t idx = AddInputImpl(arg_name, obj_type);
+      AddOutputImpl(arg_name, obj_type, static_cast<int64_t>(idx));
+    } else {
+      MS_LOG(EXCEPTION) << "Unsupported role: '" << role << "'. Expect 'input' or 'output'.";
+    }
+    return *this;
+  }
+
+  void Register() {
+    if (mindspore::ops::GetOpDef(name_) != nullptr) {
+      MS_LOG(EXCEPTION) << "OpDef named '" << name_ << "' already exists.";
+    }
+    mindspore::ops::AddOpDef(name_, op_def_);
+
+    mindspore::kernel::CustomKernelFactory::Instance().Register(
+      name_, []() { return std::make_shared<ms_custom_ops::TritonKernelMod>(); });
+  }
+
+ private:
+  size_t AddInputImpl(const std::string &arg_name, const std::string &obj_type) {
+    if (op_def_->indexes_.find(arg_name) != op_def_->indexes_.end()) {
+      MS_LOG(EXCEPTION) << "Input arg '" << arg_name << "' already exists.";
+    }
+    OpInputArg arg;
+    arg.arg_name_ = arg_name;
+    arg.arg_dtype_ = TypeStrToEnum(obj_type);
+    arg.as_init_arg_ = false;
+    arg.arg_handler_.clear();
+    arg.cast_dtype_.clear();
+    arg.is_optional_ = false;
+    const size_t idx = op_def_->args_.size();
+    op_def_->args_.push_back(std::move(arg));
+    op_def_->indexes_[arg_name] = idx;
+    op_def_->signatures_.emplace_back(Signature(arg_name, SignatureEnumRW::kRWDefault,
+                                                SignatureEnumKind::kKindPositionalKeyword, nullptr,
+                                                SignatureEnumDType::kDTypeEmptyDefaultValue));
+    return idx;
+  }
+
+  void AddOutputImpl(const std::string &arg_name, const std::string &obj_type, int64_t input_index) {
+    OpOutputArg arg;
+    arg.arg_name_ = arg_name;
+    arg.arg_dtype_ = TypeStrToEnum(obj_type);
+    arg.inplace_input_index_ = input_index;  // Bind the output to its own input index to make it in place.
+    op_def_->returns_.push_back(std::move(arg));
+    op_def_->signatures_[input_index].rw = SignatureEnumRW::kRWWrite;
+  }
+
+  static OpDef *NewOp() {
+    // The pool keeps every OpDef alive for the whole process lifetime, because the
+    // global op registry stores raw pointers into it.
+    static std::vector<std::unique_ptr<OpDef>> pool;
+    pool.emplace_back(std::make_unique<OpDef>(gOpDefTemplate));
+    return pool.back().get();
+  }
+
+  std::string name_;
+  OpDef *op_def_;
+};
+}  // namespace mindspore::ops
+
+MS_CUSTOM_OPS_EXTENSION_MODULE(m) {
+  py::class_<mindspore::ops::TritonOpDefBuilder>(m, "TritonOpDefBuilder")
+    .def(py::init<const std::string &>(), py::arg("name"))
+    .def(
+      "arg",
+      [](mindspore::ops::TritonOpDefBuilder &self, const std::string &arg_name, const std::string &role,
+         const std::string &obj_type) -> mindspore::ops::TritonOpDefBuilder & {
+        return self.Arg(arg_name, role, obj_type);
+      },
+      py::arg("arg_name"), py::arg("role"), py::arg("obj_type"), py::return_value_policy::reference_internal,
+      "Add an argument. role is one of {'input', 'output'}; obj_type is a canonical type string.")
+    .def("register_op", &mindspore::ops::TritonOpDefBuilder::Register, "Register the OpDef and its Triton kernel.");
+}
diff --git a/ops/framework/triton/graphmode/triton_op_infer_impl.h b/ops/framework/triton/graphmode/triton_op_infer_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..4f56e9bf1e575317f5532b00b2685f1551fbd5ba
--- /dev/null
+++ b/ops/framework/triton/graphmode/triton_op_infer_impl.h
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2025 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MS_CUSTOM_OPS_FRAMEWORK_TRITON_GRAPHMODE_TRITON_OP_INFER_IMPL_H_
+#define MS_CUSTOM_OPS_FRAMEWORK_TRITON_GRAPHMODE_TRITON_OP_INFER_IMPL_H_
+
+#include "custom_op_api.h"
+#include "ops/framework/module.h"
+
+using namespace mindspore;
+using namespace mindspore::ops;
+namespace ms_custom_ops {
+// The Triton op writes its results into in-place output arguments, so shape/type
+// inference only needs to return a placeholder value.
+class TritonInferImpl : public OpFuncImpl {
+ public:
+  ShapeArray InferShape(const PrimitivePtr &, const InferInfoPtrList &) const override { return {{1}}; }
+
+  std::vector<TypeId> InferType(const PrimitivePtr &, const InferInfoPtrList &) const override {
+    return {TypeId::kNumberTypeInt64};
+  }
+
+  bool GeneralInferRegistered() const override { return true; }
+};
+}  // namespace ms_custom_ops
+#endif  // MS_CUSTOM_OPS_FRAMEWORK_TRITON_GRAPHMODE_TRITON_OP_INFER_IMPL_H_
diff --git a/python/ms_custom_ops/triton_utils.py b/python/ms_custom_ops/triton_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c8d6841388fecdce255088906ef2fe331bb7e3d
--- /dev/null
+++ b/python/ms_custom_ops/triton_utils.py
@@ -0,0 +1,49 @@
+import threading
+
+from mindspore.ops.primitive import Primitive, prim_arg_register
+from mindspore.ops import signature as sig
+from mindspore.ops.operations._pyfunc_registry import add_pyfunc
+from ms_custom_ops import TritonOpDefBuilder
+
+_lock = threading.Lock()
+_reg = set()
+
+
+def _make_sigs(arg_specs):
+    # Inputs get the default (read-only) signature; outputs are marked RW_WRITE so the
+    # framework treats them as in-place arguments.
+    return tuple(sig.make_sig(name) if io == "input" else sig.make_sig(name, sig.sig_rw.RW_WRITE)
+                 for name, io, _ in arg_specs)
+
+
+def ms_triton(*arg_specs):
+    """Register a Triton launch function as a MindSpore primitive.

+    Each spec is a (name, role, obj_type) triple; role is 'input' or 'output'.
+    """
+    def decorator(triton_fn):
+        op_name = triton_fn.__name__
+        prim_cls_name = f"CustomTriton_{op_name}"
+
+        with _lock:
+            if op_name not in _reg:
+                builder = TritonOpDefBuilder(op_name)
+                for name, role, obj_type in arg_specs:
+                    builder.arg(name, role, obj_type)
+                builder.register_op()
+                _reg.add(op_name)
+
+        # A Python function cannot be stored as a Primitive attribute, so register it
+        # in the PyFunc registry and pass only its id to the backend kernel.
+        func_id = id(triton_fn)
+        add_pyfunc(func_id, triton_fn)
+
+        def init(self):
+            super(self.__class__, self).__init__(self.__class__.__name__)
+            self.add_prim_attr("side_effect_mem", True)
+            self.add_prim_attr("fn_id", func_id)
+
+        prim_cls = type(
+            prim_cls_name,
+            (Primitive,),
+            {
+                "__mindspore_signature__": _make_sigs(arg_specs),
+                "__init__": prim_arg_register(init),
+            },
+        )
+        return prim_cls()
+
+    return decorator
diff --git a/tests/st/triton_cases/test_decorator.py b/tests/st/triton_cases/test_decorator.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6db8abcf3ee8c0d88f0729ee1712fe4f5f41b39
--- /dev/null
+++ b/tests/st/triton_cases/test_decorator.py
@@ -0,0 +1,27 @@
+import numpy as np
+
+import mindspore as ms
+from ms_custom_ops.triton_utils import ms_triton
+
+
+@ms_triton(("x", "input", "tensor"), ("y", "input", "tensor"), ("out", "output", "tensor"))
+def custom_add(x, y, out):
+    # Placeholder body: this test only checks that the decorated function is
+    # invoked from the compiled graph.
+    print("=============================")
+    print("=============================")
+
+
+class MyNet(ms.nn.Cell):
+    def construct(self, x, y):
+        return custom_add(x, y, x)
+
+
+ms.set_device("Ascend")
+ms.set_context(mode=ms.GRAPH_MODE, save_graphs=True, save_graphs_path="./graphs")
+
+x = np.array([1, 2, 3], dtype=np.float16)
+y = np.array([4, 5, 6], dtype=np.float16)
+MyNet()(ms.Tensor(x), ms.Tensor(y))
diff --git a/tests/st/triton_cases/test_triton_add.py b/tests/st/triton_cases/test_triton_add.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ca406d249795bc2ff17e0235abfa465c840603b
--- /dev/null
+++ b/tests/st/triton_cases/test_triton_add.py
@@ -0,0 +1,23 @@
+import numpy as np
+
+import mindspore as ms
+import triton_ops
+
+
+class MyNet(ms.nn.Cell):
+    def construct(self, x, y):
+        out = ms.mint.empty_like(x)
+        # triton_ops.add writes x + y into `out` in place.
+        triton_ops.add(x, y, out)
+        out2 = out * out
+        return out2
+
+
+ms.set_device("Ascend")
+ms.set_context(mode=ms.GRAPH_MODE, save_graphs=True, save_graphs_path="./graphs")
+
+x = np.array([1, 2, 3], dtype=np.float16)
+y = np.array([4, 5, 6], dtype=np.float16)
+ret = MyNet()(ms.Tensor(x), ms.Tensor(y))
+print(ret)
diff --git a/tests/st/triton_cases/triton_ops.py b/tests/st/triton_cases/triton_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccaa0a076c90592e6f77e5400874d11f914eaec4
--- /dev/null
+++ b/tests/st/triton_cases/triton_ops.py
@@ -0,0 +1,39 @@
+import triton
+import triton.language as tl
+
+from ms_custom_ops.triton_utils import ms_triton
+
+
+@triton.jit
+def add_kernel(x_ptr,  # *Pointer* to the first input vector.
+               y_ptr,  # *Pointer* to the second input vector.
+               output_ptr,  # *Pointer* to the output vector.
+               n_elements,  # Size of the vector.
+               BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
+               ):
+    # There are multiple 'programs' processing different data. We identify which
+    # program we are here:
+    pid = tl.program_id(axis=0)  # We use a 1D launch grid, so the axis is 0.
+    # This program will process inputs that are offset from the initial data.
+    # For instance, with a vector of length 256 and a block size of 64, the programs
+    # would each access the elements [0:64, 64:128, 128:192, 192:256].
+    # Note that `offsets` is a list of pointers:
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    # Create a mask to guard memory operations against out-of-bounds accesses.
+    mask = offsets < n_elements
+    # Load x and y from DRAM, masking out any extra elements in case the input is
+    # not a multiple of the block size.
+    x = tl.load(x_ptr + offsets, mask=mask)
+    y = tl.load(y_ptr + offsets, mask=mask)
+    output = x + y
+    # Write x + y back to DRAM.
+    tl.store(output_ptr + offsets, output, mask=mask)
+
+
+@ms_triton(("x", "input", "tensor"), ("y", "input", "tensor"), ("out", "output", "tensor"))
+def add(x, y, out):
+    n_elements = out.numel()
+
+    def grid(meta):
+        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+
+    add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=1024)
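
For reference, a minimal sketch of registering a second element-wise op with the same decorator, assuming the builder semantics above. It is not part of this change; the names mul_kernel and mul are illustrative. The arg-spec triples follow the (name, role, obj_type) convention defined in triton_utils.py, and "out" is declared with role "output" so the registered primitive marks it RW_WRITE and the kernel writes the result into it in place.

    import triton
    import triton.language as tl

    from ms_custom_ops.triton_utils import ms_triton


    @triton.jit
    def mul_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        # Element-wise multiply with out-of-bounds masking, mirroring add_kernel.
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        tl.store(output_ptr + offsets, x * y, mask=mask)


    @ms_triton(("x", "input", "tensor"), ("y", "input", "tensor"), ("out", "output", "tensor"))
    def mul(x, y, out):
        n_elements = out.numel()

        def grid(meta):
            return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

        mul_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=1024)

As with triton_ops.add, the decorated mul is a primitive singleton: a Cell's construct would allocate out (e.g. via ms.mint.empty_like) and call mul(x, y, out), relying on the side_effect_mem attribute to keep the in-place write ordered in the graph.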