From 96fc7b13f5c9df8ffc270b271b88c624c86c7002 Mon Sep 17 00:00:00 2001 From: Zichun Ye Date: Sat, 19 Mar 2022 18:18:22 +0800 Subject: [PATCH] fix topi conflict in custom op update compile attr add more checking --- python/akg/composite/build_module.py | 17 ++++++++++------- python/akg/composite/topi.py | 18 +++++++++++++++--- src/composite/lower_tree/json_leaf.cc | 9 +++++++-- src/composite/utils/util.cc | 10 ++-------- 4 files changed, 34 insertions(+), 20 deletions(-) diff --git a/python/akg/composite/build_module.py b/python/akg/composite/build_module.py index 06ec65c3..ae34f67a 100644 --- a/python/akg/composite/build_module.py +++ b/python/akg/composite/build_module.py @@ -17,6 +17,7 @@ """build module""" import os import json +from collections import Iterable import akg from akg import tvm from akg.utils.kernel_exec import ReturnType @@ -181,14 +182,16 @@ def _update_compile_attr(desc_d, attr): if desc_d['op_desc'] is None: return attr for op in desc_d['op_desc']: - if "compile_attr" not in op or op["compile_attr"] is None: + op_attrs = op.get("attr", {}) + if not isinstance(op_attrs, Iterable): continue - for i in op["compile_attr"]: - if isinstance(i, str): - attr.update({i: op["compile_attr"][i]}) - else: - raise ValueError("Currently all compile attrs' name for AKG should be type of str. But got \ - an attr name: {}, which type is: {}.".format(i, type(i))) + compile_attrs = list(item.get("value", "") for item in op_attrs if isinstance( + item, dict) and item.get("name", "") == "func_compile_attrs") + if compile_attrs: + attrs_dict = json.loads(compile_attrs[0]) + for key in attrs_dict: + attr.update({key: attrs_dict[key]}) + return attr diff --git a/python/akg/composite/topi.py b/python/akg/composite/topi.py index 11d12f1b..f802b197 100644 --- a/python/akg/composite/topi.py +++ b/python/akg/composite/topi.py @@ -19,6 +19,7 @@ import functools import itertools import operator import os +import sys import importlib.util from pathlib import Path @@ -28,7 +29,7 @@ from akg.utils.format_transform import get_const, get_shape from akg.utils.dsl_create import get_broadcast_shape from akg.utils import validation_check as vc_util -import akg.topi as topi +import akg.topi as akg_topi import akg.utils as utils @@ -84,7 +85,7 @@ def pad(inputs, attrs): raise ValueError( "Input dimensions and pad dimensions dismatch: %d vs %d vs %d" % (n, len(pad_before), len(pad_after))) output_name = "T_pad_" + in_tensor.op.name - return topi.nn.pad(in_tensor, pad_before, pad_after, pad_value, name=output_name) + return akg_topi.nn.pad(in_tensor, pad_before, pad_after, pad_value, name=output_name) @tvm.register_func("UnPadAkg") @@ -210,7 +211,7 @@ def StridedSlice(inputs, attrs): begin[i] = 1 end[i] = 1 strides[i] = 1 - in_tensor = topi.expand_dims(in_tensor, i, 1) + in_tensor = akg_topi.expand_dims(in_tensor, i, 1) i += 1 continue if i < len(shrink_axis_pos) and shrink_axis_pos[i] == '1': @@ -500,6 +501,17 @@ def _launch_kernel_from_source(inputs, op_attrs, source_str, real_inputs_num, is def _launch_kernel_from_path(inputs, op_attrs, func_type, op_imply_path, func_name): if not os.path.isfile(op_imply_path): raise ValueError("Can't find file under path: {}".format(str(op_imply_path))) + # here we need to drop some sys module with name tvm, akg and topi + # as they will lead to conflict with TBE on ascend when exec module + akg_key_list = [] + for key in sys.modules.keys(): + if "tvm" in key or "akg" in key or "topi" in key: + akg_key_list.append(key) + for key in akg_key_list: + sys.modules.pop(key) + # del akg related path in the sys.path + # these two paths are added when akg kernel compiler is launched + sys.path = sys.path[2:] custom_mod_name = Path(op_imply_path).resolve().stem mod_spec = importlib.util.spec_from_file_location( diff --git a/src/composite/lower_tree/json_leaf.cc b/src/composite/lower_tree/json_leaf.cc index c812d575..1704cd10 100644 --- a/src/composite/lower_tree/json_leaf.cc +++ b/src/composite/lower_tree/json_leaf.cc @@ -30,8 +30,13 @@ constexpr auto kKernelName = "kernel_name"; void JsonLowerLeaf::Lower(StageType s) { auto info = GenBuildInfo(attrs_); attrs_.Set(kOriginKernelName, Expr(origin_kernel_name_)); - data_ = LowerDataNode::make(GetScheduleWithBuildInfo(info), info.args, info.in_binds, attrs_, - GetProcess(String2Json(json_str_)), info.kernel_name, GetConfig(), polyhedral_); + std::string process = GetProcess(String2Json(json_str_)); + if (attrs_.count("target_option")) { + CHECK(attrs_["target_option"]->IsInstance()); + process = process + " " + attrs_["target_option"].as()->value; + } + data_ = LowerDataNode::make(GetScheduleWithBuildInfo(info), info.args, info.in_binds, attrs_, process, + info.kernel_name, GetConfig(), polyhedral_); current_stage_ = StageType::Begin; } diff --git a/src/composite/utils/util.cc b/src/composite/utils/util.cc index bf0881fc..8fcd58e8 100644 --- a/src/composite/utils/util.cc +++ b/src/composite/utils/util.cc @@ -67,19 +67,13 @@ std::string GetRealTarget(const std::string &target) { std::string GetProcess(const picojson::value &input_json) { const picojson::value::object &input_obj = input_json.get(); std::string target; - std::string target_option; auto iter = input_obj.find("process"); if (iter != input_obj.end()) { CHECK(iter->second.is()); target = iter->second.get(); } - iter = input_obj.find("target_option"); - if (iter != input_obj.end()) { - CHECK(iter->second.is()); - target_option = " " + iter->second.get(); - } - return GetRealTarget(target) + target_option; + return GetRealTarget(target); } picojson::value String2Json(const std::string &json_str) { @@ -283,7 +277,7 @@ bool HasFoundReshape(const Array &shape_change, const Expr &ori_size, size } bool CheckInputBroadcast(const std::string &type, const std::vector &index_groups, size_t i, - const Array &shape_ori) { + const Array &shape_ori) { auto index_group = index_groups[i]; auto indexs = index_groups[i].indexs; auto index_group_start = indexs[0]; -- Gitee