diff --git a/python/akg/composite/topi.py b/python/akg/composite/topi.py index 3bcff8760888f54831a5b785e2ec3ff43c42acc9..4ab8b251fde3a30458ea92b759c497813b62f592 100644 --- a/python/akg/composite/topi.py +++ b/python/akg/composite/topi.py @@ -17,9 +17,15 @@ """composite topi""" import akg.topi as topi from akg import tvm +from akg.tvm.hybrid import script from akg.utils.format_transform import get_const from akg.utils import validation_check as vc_util +import logging +import os +from pathlib import Path +import importlib.util + @tvm.register_func("ElemAny") def elem_any(inputs, attrs): @@ -405,4 +411,77 @@ def cumprod(inputs, attrs): ib.store(dst, idx_cur, ib.load(dst, idx_pre) * ib.load(data, idx_pre if exclusive else idx_cur)) return ib.get() return tvm.extern(shape, [in_tensor], lambda ins, outs : kernel_ir(ins[0], outs[0]), name=output_name, - dtype=in_tensor.dtype) \ No newline at end of file + dtype=in_tensor.dtype) + +@tvm.register_func("UserDefined") +def user_defined(inputs, attrs): + + op_desc_attr = [] + source_str = "" + op_imply_path = "" + func_name = "" + func_type = "" + + for ext_arg in attrs.items(): + attr_name = ext_arg[0] + if attr_name == "func_source_str": + source_str = ext_arg[1].value + elif attr_name == "op_imply_path": + op_imply_path = ext_arg[1].value + elif attr_name == "func_name": + func_name = ext_arg[1].value + elif attr_name == "func_type": + func_type = ext_arg[1].value + elif not (attr_name == "akg" or "_format" in attr_name): + # store the rest of op attr for op build + op_desc_attr.append(ext_arg) + + op_attrs = [] + ir_builder_attrs = {} + for ext_arg in op_desc_attr: + if func_type == "ir_builder": + # ir_builder functions take attrs as a dict/tvm.Map + ir_builder_attrs[ext_arg[0]] = ext_arg[1] + else: + op_attrs.append(ext_arg[1]) + + func_kernel = None + if len(source_str) > 0: + capture = locals() + capture["source_str"] = source_str + + func_mod = compile(source_str, "", "exec") + exec(func_mod) + func_kernel = locals()[func_name] + + elif len(op_imply_path) > 0: + if os.path.isfile(op_imply_path): + custom_mod_name = Path(op_imply_path).resolve().stem + mod_spec = importlib.util.spec_from_file_location( + custom_mod_name, op_imply_path) + custom_mod = importlib.util.module_from_spec(mod_spec) + mod_spec.loader.exec_module(custom_mod) + func_kernel = getattr(custom_mod, func_name, None) + else: + logging.error("Can't find file under path: %s", + str(op_imply_path)) + else: + logging.error( + "Neither source_str nor op_imply_path is provided in the json file") + + if func_kernel is None: + logging.error( + "Failed in compiling op function from userdefine op") + + output = None + if func_type == "hybrid": + hybrid_func = script(func_kernel, capture=capture) + inputs = list(inputs) + op_attrs + output = hybrid_func(*inputs) + elif func_type == "ir_builder": + output = func_kernel(inputs, ir_builder_attrs) + else: + inputs = list(inputs) + op_attrs + output = func_kernel(*inputs) + + return output \ No newline at end of file diff --git a/src/composite/util.cc b/src/composite/util.cc index 0022986640398d6859f6e2e5bd94241b2da8d856..7ff50d90f13a73e27677f1be01d2d131f5f02840 100644 --- a/src/composite/util.cc +++ b/src/composite/util.cc @@ -77,7 +77,7 @@ bool IsOtherOp(const std::string &op_name) { std::unordered_set elems = {"MatMul", "BatchMatMul", "Conv", "Transpose", "Tile", "Assign", "InplaceAssign", "EquivFormat", "TransData", "AddMinValue", "BroadcastTo", "PadAkg", "UnPadAkg", "Conv2D", "CumSum", - "CumProd", "StridedSlice"}; + "CumProd", "StridedSlice", "UserDefined"}; return elems.find(op_name) != elems.end(); } bool IsElemwise(const std::string &op_name) { diff --git a/third_party/incubator-tvm/python/tvm/hybrid/__init__.py b/third_party/incubator-tvm/python/tvm/hybrid/__init__.py index abe59a062f54a36e1c14d027b82d1450bb5b1326..594457d9d4706c80a34d988b3cddd0a85f97ecca 100644 --- a/third_party/incubator-tvm/python/tvm/hybrid/__init__.py +++ b/third_party/incubator-tvm/python/tvm/hybrid/__init__.py @@ -26,6 +26,7 @@ HalideIR. """ # 2019.12.30 - Enhance function script. +# 2021.08.17 - Update getting source logic to support UserDefined ops # TODO(@were): Make this module more complete. # 1. Support HalideIR dumping to Hybrid Script @@ -95,7 +96,10 @@ def script(pyfunc=None, intrinsics=None, capture=None): if _is_tvm_arg_types(args): _patch_intrins_to_calls(intrinsics=intrinsics) _patch_intrins_to_runtime(intrinsics=intrinsics) - src = _pruned_source(func) + if capture.get("source_str"): + src = capture["source_str"] + else: + src = _pruned_source(func) op = source_to_op(src, args, func.__globals__, closure_vars) _unpatch_intrins_from_runtime(intrinsics=intrinsics) _unpatch_intrins_from_calls(intrinsics=intrinsics)