diff --git a/python/akg/composite/build_module.py b/python/akg/composite/build_module.py
index 19c1634b403f7a986b2ba6287899f3633a2153ed..5bd79addf0b9eae70863838fb13bc7e88f44feb3 100644
--- a/python/akg/composite/build_module.py
+++ b/python/akg/composite/build_module.py
@@ -348,6 +348,8 @@ def _update_attrs_cpu(all_ops, attrs, poly):
     if any([i in all_ops for i in ["Conv2D"]]):
         attrs["enable_auto_fuse"] = False
         attrs["pragma_enable_conv2d_direct"] = True
+    if any([i in all_ops for i in ["Pool2D"]]):
+        attrs["enable_auto_fuse"] = False
     if "feature" not in attrs.keys() and any([i in all_ops for i in ["BatchMatMul", "MatMul"]]):
         attrs["feature"] = "avx"
     return attrs
diff --git a/python/akg/composite/topi.py b/python/akg/composite/topi.py
index b0736d5a21b399fca29a3a1b7ecbf1debf4c5de2..0776ba4cfaf14c5aa6df384153bfff73020d69a2 100644
--- a/python/akg/composite/topi.py
+++ b/python/akg/composite/topi.py
@@ -362,6 +362,50 @@ def trans_data(inputs, attrs):
         raise ValueError("TransData for src_format %s and dst_format %s is not supported" % (src_format, dst_format))
 
 
+@tvm.register_func("Pool2D")
+def pool2d(inputs, attrs):
+    """Build the Pool2D compute: a global reduction or a windowed ``nn.pool``.
+
+    Reads "global", "pool_type", "data_layout" and, for the windowed case,
+    "kernel_size", "strides", "round_mode" and the optional "pad" from attrs.
+    """
+    if attrs["global"].value:
+        return global_pool2d(inputs, attrs)
+
+    pool_type = attrs["pool_type"].value
+    data_layout = attrs["data_layout"].value
+    kernels = [get_const(x) for x in attrs["kernel_size"]]
+    strides = [get_const(x) for x in attrs["strides"]]
+    # "pad" is optional; use zero padding on all four sides when it is absent.
+    padding = [get_const(x) for x in attrs["pad"]] if "pad" in attrs else [0, 0, 0, 0]
+    ceil_mode = attrs["round_mode"].value
+    return akg_topi.nn.pool(inputs[0], kernels, strides, padding, pool_type, ceil_mode, data_layout)
+
+
+def global_pool2d(inputs, attrs):
+    """Reduce inputs[0] over its two spatial axes (max or avg), keeping dims.
+
+    Attribute values may be plain Python values or nodes exposing ``.value``.
+    """
+    # Unwrap node-typed attributes so the string comparisons below work for
+    # both the registered Pool2D entry point and direct callers.
+    data_layout = getattr(attrs["data_layout"], "value", attrs["data_layout"])
+    pool_type = getattr(attrs["pool_type"], "value", attrs["pool_type"])
+    if pool_type not in ("max", "avg"):
+        raise ValueError(
+            "pool_type should be max/avg, current pool_type is {}".format(pool_type))
+
+    # NHWC reduces axes 1 and 2; every other layout (NCHW, NCHWc) reduces axes 2 and 3.
+    red_axis = (1, 2) if data_layout == "NHWC" else (2, 3)
+    if pool_type == "max":
+        return akg_topi.max(inputs[0], axis=red_axis, keepdims=True)
+
+    out = akg_topi.sum(inputs[0], axis=red_axis, keepdims=True)
+    count = 1
+    for i in red_axis:
+        count *= inputs[0].shape[i]
+    return akg_topi.divide(out, count)
+
 @tvm.register_func("LayoutTransform")
 def layout_transform(inputs, attrs):
 
diff --git a/python/akg/ops/nn/cpu/__init__.py b/python/akg/ops/nn/cpu/__init__.py
index 2de9f0124320caf4ed280316ae62647f32cfa745..2a3c6ab05f16830b543650a9950a9fcce7f40ae8 100644
--- a/python/akg/ops/nn/cpu/__init__.py
+++ b/python/akg/ops/nn/cpu/__init__.py
@@ -18,3 +18,5 @@ from .layout_transform_utils import get_layout_list, get_alpha_only, \
 from .conv2d import conv2d_nchwc
 from .depthwise_conv2d import depthwise_conv2d_nchwc
 from .layout_transform import layout_transform
+from .pooling import pooling
+from .global_pooling import global_pooling
diff --git a/python/akg/ops/nn/cpu/global_pooling.py b/python/akg/ops/nn/cpu/global_pooling.py
new file mode 100644
index 0000000000000000000000000000000000000000..45c999ec85bf754838b5580806e573fbaf1e19aa
--- /dev/null
+++ b/python/akg/ops/nn/cpu/global_pooling.py
@@ -0,0 +1,43 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""operator dsl function: global pooling"""
+import akg.topi as topi
+from akg.topi.util import get_const_tuple
+import akg.tvm as tvm
+
+
+def global_pooling(data, pool_type,
+                   data_layout="NCHW"):
+    """Global max/avg pooling of data over its two spatial axes, keeping the reduced dims."""
+    if data_layout == "NHWC":
+        red_axis = (1, 2)
+    else:
+        # data_layout is NCHW or NCHWc
+        red_axis = (2, 3)
+
+    if pool_type == "max":
+        out = topi.max(data, axis=red_axis, keepdims=True)
+    elif pool_type == "avg":
+        out = topi.sum(data, axis=red_axis, keepdims=True)
+
+        count = 1
+        for i in red_axis:
+            count *= data.shape[i]
+        out = topi.divide(out, count)
+    else:
+        raise ValueError(
+            "pool_type should be max/avg, current pool_type is {}".format(pool_type))
+
+    return out
diff --git a/python/akg/ops/nn/cpu/pooling.py b/python/akg/ops/nn/cpu/pooling.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7c94323510adc75784325bc69ee5b9c28060f30
--- /dev/null
+++ b/python/akg/ops/nn/cpu/pooling.py
@@ -0,0 +1,26 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""operator dsl function: pooling"""
+import akg.topi as topi
+from akg.topi.util import get_const_tuple
+import akg.tvm as tvm
+
+def pooling(data, kernel, stride, padding, pool_type,
+            ceil_mode, count_include_pad=True,
+            data_layout="NCHW"):
+    """Pooling op impl: forward every argument unchanged to ``topi.nn.pool``."""
+    return topi.nn.pool(data, kernel=kernel, stride=stride, padding=padding,
+                        pool_type=pool_type, ceil_mode=ceil_mode,
+                        layout=data_layout, count_include_pad=count_include_pad)
diff --git a/python/akg/utils/op_dsl.py b/python/akg/utils/op_dsl.py
index cadf1402f795e506dd85fdc3e774e93149182182..192898386b4f72ec0f7e1c02d49d76ae87f83406 100644
--- a/python/akg/utils/op_dsl.py
+++ b/python/akg/utils/op_dsl.py
@@ -718,6 +718,121 @@ def cummulative_str(inputs, outputs, attr, op_type):
     res += "{} = out\n".format(outputs[0]['tensor_name'])
     return res
 
+def pool2d_str(inputs, output, attr):
+    """gen pool2d string"""
+    import re
+    import tvm
+    import akg.topi as akg_topi
+    from akg.topi.util import get_const_tuple
+
+    if get_attr(attr, "global"):
+        # global_pool2d python impl
+        return global_pool2d_str(inputs, output, attr)
+
+    kh, kw = get_attr(attr, "kernel_size")
+    sh, sw = get_attr(attr, "strides")
+    paddings = get_attr(attr, "pad")
+    # Fall back to zero padding when "pad" is absent or is not a 4-element list.
+    pt, pl, pb, pr = paddings if paddings is not None and len(paddings) == 4 else [0, 0, 0, 0]
+    pool_type = get_attr(attr, "pool_type")
+    ceil_mode = get_attr(attr, "round_mode")
+    count_include_pad = True
+    data_layout = get_attr(attr, "data_layout")
+
+    res = ""
+
+    # We temporarily reshape all format as NCHWc
+    if data_layout == "NCHW":
+        n, ic_out, ih, iw = inputs[0][0]['shape']
+        ic_in = 1
+    elif data_layout == "NHWC":
+        n, ih, iw, ic_in = inputs[0][0]['shape']
+        ic_out = 1
+    else:
+        # NCHWc
+        if re.match(r'NCHW\d*c', data_layout) is None:
+            raise ValueError("Invalid data_layout = {}".format(data_layout))
+        n, ic_out, ih, iw, ic_in = inputs[0][0]['shape']
+
+    tmp_shape = "({},{},{},{},{})".format(n, ic_out, ih, iw, ic_in)
+    res += "tmp_data = np.reshape({}, {})\n".format(inputs[0][0]["tensor_name"], tmp_shape)
+    res += "pt, pl, pb, pr = {}, {}, {}, {}\n".format(pt, pl, pb, pr)
+    res += "kh ,kw = {}, {}\n".format(kh, kw)
+    res += "sh, sw = {}, {}\n".format(sh, sw)
+    # The generated avg branch reads count_include_pad, so define it in the snippet.
+    res += "count_include_pad = {}\n".format(count_include_pad)
+    input0 = tvm.placeholder((n, ic_out, ih, iw, ic_in), name='a')
+    output0 = akg_topi.nn.pool(input0, kernel=[kh, kw], stride=[sh, sw], padding=(pt, pl, pb, pr),
+                               pool_type=pool_type, ceil_mode=ceil_mode,
+                               layout="NCHWc", count_include_pad=count_include_pad)
+    _, oc_out, oh, ow, oc_in = get_const_tuple(output0.shape)
+    res += "n, ic_out, ih, iw, ic_in = {}, {}, {}, {}, {}\n".format(
+        n, ic_out, ih, iw, ic_in)
+    res += "oc_out, oh, ow, oc_in = {}, {}, {}, {}\n".format(
+        oc_out, oh, ow, oc_in)
+    res += "dtype = np.{}\n".format(inputs[0][0]["data_type"])
+    res += "pad_np = np.zeros(shape=(n, ic_out, ih+pt+pb, iw+pl+pr, ic_in)).astype(dtype)\n"
+    res += "no_zero = (range(n), range(ic_out), (range(pt, ih+pt)), (range(pl, iw+pl)), range(ic_in))\n"
+    res += "pad_np[np.ix_(*no_zero)] = tmp_data\n"
+    res += "b_np = np.zeros(shape=(n, oc_out, oh, ow, oc_in)).astype(dtype)\n"
+
+    if pool_type == "avg":
+        res += "for i in range(oh):\n"
+        res += "    for j in range(ow):\n"
+        res += "        if count_include_pad:\n"
+        res += "            b_np[:, :, i, j, :] = np.mean(\n"
+        res += "                pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw, :], axis=(2, 3))\n"
+        res += "        else:\n"
+        res += "            pad_count = np.sum(\n"
+        res += "                pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw, :] > 0, axis=(2, 3))\n"
+        res += "            b_np[:, :, i, j, :] = np.sum(\n"
+        res += "                pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw, :], axis=(2, 3)) / np.maximum(pad_count, 1)\n"
+    elif pool_type == "max":
+        res += "for i in range(oh):\n"
+        res += "    for j in range(ow):\n"
+        res += "        b_np[:, :, i, j, :] = np.max(\n"
+        res += "            pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw, :], axis=(2, 3))\n"
+    else:
+        raise ValueError("Invalid pool_type=={}".format(pool_type))
+
+    # Remove only the axis that was inserted for the temporary NCHWc view;
+    # a bare np.squeeze would also drop any real axis of size 1 (e.g. n == 1).
+    if data_layout == "NCHW":
+        res += "{} = np.squeeze(b_np, axis=4)\n".format(output[0]["tensor_name"])
+    elif data_layout == "NHWC":
+        res += "{} = np.squeeze(b_np, axis=1)\n".format(output[0]["tensor_name"])
+    else:
+        res += "{} = b_np\n".format(output[0]["tensor_name"])
+
+    return res
+
+
+def global_pool2d_str(inputs, output, attr):
+    """gen global pool2d string"""
+    pool_type = get_attr(attr, "pool_type")
+    data_layout = get_attr(attr, "data_layout")
+    if data_layout == "NHWC":
+        pool_idxs = (1, 2)
+    elif data_layout[:4] == "NCHW":
+        # NCHW or NCHWc/ NCHW[x]c
+        pool_idxs = (2, 3)
+    else:
+        raise ValueError("Invalid data_layout = {}".format(data_layout))
+
+    if pool_type == 'avg':
+        np_func = "np.mean"
+    elif pool_type == 'max':
+        np_func = "np.max"
+    else:
+        raise ValueError("Invalid pool_type=={}".format(pool_type))
+
+    res = "pool_idxs = {}\n".format(pool_idxs)
+    res += "global_pool_input = {}\n".format(inputs[0][0]["tensor_name"])
+    res += "{} = {}(global_pool_input, axis=pool_idxs, keepdims=True)\n".format(
+        output[0]["tensor_name"], np_func)
+    return res
+
+
 
 def layout_transform_str(inputs, output, attr, op_type):
     """gen layout_transform string"""
@@ -1012,6 +1127,7 @@ op_dsl = {
     "Complex": lambda inputs, output, attr: "%s = np.vectorize(complex)(%s, %s)" %
     (output[0]['tensor_name'], get_input(inputs[0][0]),
      get_input(inputs[1][0])),
+    "Pool2D": lambda inputs, output, attr: pool2d_str(inputs, output, attr),
     "LayoutTransform": lambda inputs, output, attr: layout_transform_str(inputs, output, attr, "layout_transform"),
     "Concat": lambda inputs, output, attr: concat_str(inputs, output, attr)
 }
diff --git a/src/composite/utils/util.cc b/src/composite/utils/util.cc
index 2f950e78cfd7932320e1d547def7889d7a9f6cf4..a91856ac58dece5a025e9f21cd090671ab165adb 100644
--- a/src/composite/utils/util.cc
+++ b/src/composite/utils/util.cc
@@ -83,7 +83,6 @@ std::string GetProcess(const picojson::value &input_json) {
     } else if (feature == "avx512") {
       options = " -mcpu=skylake-avx512 -mattr=-avx512f";
     } else if (feature == "neon") {
-      // NOTE(yanzhi): there are many kinds of arm cpu, we should select options here.
options = " -target=aarch64-linux-gnu -mattr=+neon"; } } @@ -145,6 +144,7 @@ bool IsOtherOp(const std::string &op_name) { "COO2CSR", "ElemAny", "CSRMM", + "Pool2D", "LayoutTransform"}; return elems.find(op_name) != elems.end(); } diff --git a/src/poly/tiling/tiling_analyzer.h b/src/poly/tiling/tiling_analyzer.h index ec1f475af2a47749b4052cf77c0199e541c41351..21461dd2e4c1273d0bf3618ad721f30fabc71c38 100755 --- a/src/poly/tiling/tiling_analyzer.h +++ b/src/poly/tiling/tiling_analyzer.h @@ -51,7 +51,7 @@ constexpr auto MIN_CORE_GRANULARITY = 256; constexpr auto DESIRE_CORE_GRANULARITY = 8192; constexpr auto MIN_EXEC_NUM_PER_THREAD = 4096; constexpr auto BEST_PARALLEL_NUM = 192; -constexpr auto PARALLEL_DECREASE_VALUE = 8; +constexpr auto PARALLEL_DECREASE_VALUE = 1; constexpr auto BEST_UNROLL_NUM = 256; constexpr auto MIN_UNROLL_NUM = 8; constexpr auto MATMUL_BEST_FACTOR = 128; diff --git a/src/poly/tiling/tiling_strategy_manager_cpu.cc b/src/poly/tiling/tiling_strategy_manager_cpu.cc index f1bad9a935484f23659139a72ca7bb6a31468363..208969116e7d86315571461cc31ea5d9df23de75 100644 --- a/src/poly/tiling/tiling_strategy_manager_cpu.cc +++ b/src/poly/tiling/tiling_strategy_manager_cpu.cc @@ -98,9 +98,6 @@ void CpuStrategy::SetParallelTileValue(TileAxis *axis, const int64_t axis_size, parallel_num = std::min(axis_size, static_cast(best_parallel_num_)); } else if (evaluate_num > 1) { while (parallel_num > 0 && tile_size % parallel_num != 0) { - if (parallel_num < evaluate_num) { - break; - } parallel_num -= parallel_decrease_value_; } } else { @@ -137,6 +134,7 @@ void CpuStrategy::SetConv2dTileValue(int index) { TileAxis *batch_axis = nullptr; int64_t _; std::tie(batch_axis, _) = pending_axes_[index][p]; + CHECK(batch_axis != nullptr); batch_axis->TileRestrainToSingleValue(Expr((int64_t)1), TileLevel::CACHE1); batch_axis->TileRestrainToSingleValue(Expr((int64_t)1), TileLevel::CACHE0); p += 1; @@ -147,6 +145,7 @@ void CpuStrategy::SetConv2dTileValue(int index) { TileAxis 
*oc_out_axis = nullptr; int64_t _; std::tie(oc_out_axis, _) = pending_axes_[index][p]; + CHECK(oc_out_axis != nullptr); oc_out_axis->TileRestrainToSingleValue(Expr((int64_t)1), TileLevel::CACHE1); oc_out_axis->TileRestrainToSingleValue(Expr((int64_t)1), TileLevel::CACHE0); p += 1; @@ -157,6 +156,7 @@ void CpuStrategy::SetConv2dTileValue(int index) { TileAxis *oh_axis = nullptr; int64_t _; std::tie(oh_axis, _) = pending_axes_[index][p]; + CHECK(oh_axis != nullptr); oh_axis->TileRestrainToSingleValue(Expr((int64_t)1), TileLevel::CACHE1); oh_axis->TileRestrainToSingleValue(Expr((int64_t)1), TileLevel::CACHE0); p += 1; @@ -167,6 +167,7 @@ void CpuStrategy::SetConv2dTileValue(int index) { TileAxis *ow_axis = nullptr; int64_t ow_shape; std::tie(ow_axis, ow_shape) = pending_axes_[index][p]; + CHECK(ow_axis != nullptr); /* ow_inner should follow some strategy: 1. ow_shape % ow_tile == 0 @@ -189,6 +190,7 @@ void CpuStrategy::SetConv2dTileValue(int index) { TileAxis *oc_in_axis = nullptr; int64_t oc_in_shape; std::tie(oc_in_axis, oc_in_shape) = pending_axes_[index][p]; + CHECK(oc_in_axis != nullptr); oc_in_axis->TileRestrainToSingleValue(Expr(oc_in_shape), TileLevel::CACHE1); oc_in_axis->TileRestrainToSingleValue(Expr(oc_in_shape), TileLevel::CACHE0); p += 1; @@ -199,6 +201,7 @@ void CpuStrategy::SetConv2dTileValue(int index) { TileAxis *ic_out_axis = nullptr; int64_t ic_out_shape; std::tie(ic_out_axis, ic_out_shape) = pending_axes_[index][p]; + CHECK(ic_out_axis != nullptr); ic_out_axis->TileRestrainToSingleValue(Expr(ic_out_shape), TileLevel::CACHE1); ic_out_axis->TileRestrainToSingleValue(Expr((int64_t)1), TileLevel::CACHE0); p += 1; diff --git a/tests/common/test_run/cpu/__init__.py b/tests/common/test_run/cpu/__init__.py index a049c55e79b0b20ab775c3264e8a9da3f374fe5f..4b9a47c573f40c01e2cc1cef00e3591b38d5d69f 100644 --- a/tests/common/test_run/cpu/__init__.py +++ b/tests/common/test_run/cpu/__init__.py @@ -14,4 +14,6 @@ from .conv2d_run import conv2d_run from 
.depthwise_conv2d_run import depthwise_conv2d_run -from .layout_transform_run import layout_transform_run \ No newline at end of file +from .layout_transform_run import layout_transform_run +from .pooling_run import pooling_run +from .global_pooling_run import global_pooling_run \ No newline at end of file diff --git a/tests/common/test_run/cpu/conv2d_run.py b/tests/common/test_run/cpu/conv2d_run.py index e2bde312564ac89e951c0f58f93b1e6f574b5ec3..ae872e74e7289930d51656e532feabae1924f731 100644 --- a/tests/common/test_run/cpu/conv2d_run.py +++ b/tests/common/test_run/cpu/conv2d_run.py @@ -40,7 +40,6 @@ def gen_data(shape_data, shape_weight, stride, padding, dilation, dtype, data_la def compute_np_conv2d(data, weight, stride, padding, dilation, dtype, data_layout="NCHWc", output_layout="NCHWc"): if data_layout == "NCHWc": - # NOTE(yanzhi): we change NCHWc to NCHW, and do the compute data_nchw = unpack_nchwc_to_nchw_python(data, dtype) weight_nchw = unpack_kcrsxy_to_kcrs_python(weight, dtype) elif data_layout == "NCHW": diff --git a/tests/common/test_run/cpu/depthwise_conv2d_run.py b/tests/common/test_run/cpu/depthwise_conv2d_run.py index 43dbf357bc98653446e5cda761b1094964d1058b..d399b054ed87ca67b8191946ec872f011cf60197 100644 --- a/tests/common/test_run/cpu/depthwise_conv2d_run.py +++ b/tests/common/test_run/cpu/depthwise_conv2d_run.py @@ -40,7 +40,6 @@ def gen_data(shape_data, shape_weight, stride, padding, dilation, dtype, data_la def compute_np_depthwise_conv2d(data, weight, stride, padding, dilation, dtype, data_layout="NCHWc", output_layout="NCHWc"): if data_layout == "NCHWc": - # NOTE(yanzhi): we change NCHWc to NCHW, and do the compute data_nchw = unpack_nchwc_to_nchw_python(data, dtype) weight_nchw = unpack_kcrsxy_to_kcrs_python(weight, dtype) elif data_layout == "NCHW": diff --git a/tests/common/test_run/cpu/global_pooling_run.py b/tests/common/test_run/cpu/global_pooling_run.py new file mode 100644 index 
0000000000000000000000000000000000000000..44716b1abadedf6725a231ea7d7f61ccbbc52645
--- /dev/null
+++ b/tests/common/test_run/cpu/global_pooling_run.py
@@ -0,0 +1,88 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+from akg.ops.nn.cpu import global_pooling
+import akg
+import tvm
+import topi
+import numpy as np
+from akg.topi.util import get_const_tuple
+from akg.utils import kernel_exec as utils
+from akg.utils.result_analysis import target_profiling
+from akg.utils.format_transform import to_tvm_nd_array
+from tests.common.gen_random import random_gaussian
+
+support_list = {"float32": np.float32}
+support_layout_format = {"NCHW", "NCHWc"}
+
+
+def gen_data(shape_data, pool_type, dtype, data_layout):
+    """Return (input, zero-filled output buffer, numpy reference) for global pooling."""
+    if data_layout == "NHWC":
+        pool_idxs = (1, 2)
+    else:
+        # NCHW or NCHWc
+        pool_idxs = (2, 3)
+
+    a_np = random_gaussian(shape_data, miu=1, sigma=0.1).astype(
+        support_list[dtype])
+    if pool_type == 'avg':
+        b_np = np.mean(a_np, axis=pool_idxs, keepdims=True)
+    elif pool_type == 'max':
+        b_np = np.max(a_np, axis=pool_idxs, keepdims=True)
+    else:
+        raise ValueError(
+            "pool_type should be max/avg, current pool_type is {}".format(pool_type))
+
+    output_np = np.zeros(shape=b_np.shape).astype(dtype)
+
+    return a_np, output_np, b_np
+
+
+def global_pooling_run(shape_data, pool_type, dtype,
+                       data_layout="NCHWc", poly_sch=True, attrs=None):
+    """Build and launch global_pooling, compare with numpy, and profile only when asked."""
+    default_attrs = {"enable_auto_fuse": True}
+    attrs = {} if attrs is None else attrs
+    attrs.update(default_attrs)
+    attrs["target"] = attrs.get("target", "llvm")
+    op_attrs = [pool_type, data_layout]
+
+    mod = utils.op_build_test(global_pooling, (shape_data,), (dtype,),
+                              op_attrs=op_attrs, attrs=attrs,
+                              kernel_name="global_pooling_" + pool_type + "_auto", polyhedral=poly_sch)
+
+    data, output, expect = gen_data(
+        shape_data, pool_type, dtype, data_layout)
+    args = (data, output)
+    output = utils.mod_launch(mod, args, expect=expect)
+    rtol = 1e-3 if dtype == "float16" else 1e-4
+    atol = 1e-3 if dtype == "float16" else 1e-4
+    res = np.allclose(output, expect, rtol=rtol, atol=atol)
+    print("Test {}".format("Pass" if res else "Fail"))
+    target_name = attrs["target"].split()[0]
+    if not res:
+        mod_source = mod
+        if target_name != "llvm":
+            mod_source = mod.imported_modules[0]
+        print("Error {}:========================".format(target_name))
+        print(mod_source.get_source())
+        raise AssertionError("Test fail")
+
+    # Profiling is opt-in: it runs only when the caller sets attrs["profiling"].
+    if attrs.get("profiling", False):
+        data, output = to_tvm_nd_array(
+            [data, output], akg.tvm.context(target_name, 0))
+        target_profiling(mod, data, output,
+                         target=target_name, repeat_time=attrs.get("repeat_times", 1000))
+    return (data, ), output, expect, res
diff --git a/tests/common/test_run/cpu/pooling_run.py b/tests/common/test_run/cpu/pooling_run.py
new file mode 100644
index 0000000000000000000000000000000000000000..814a210df7ca28e8c0a342dd6784d8e638402738
--- /dev/null
+++ b/tests/common/test_run/cpu/pooling_run.py
@@ -0,0 +1,150 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+from akg.ops.nn.cpu import pooling
+import akg
+import tvm
+import topi
+import numpy as np
+from akg.topi.util import get_const_tuple
+from akg.utils import kernel_exec as utils
+from akg.utils.result_analysis import target_profiling
+from akg.utils.format_transform import to_tvm_nd_array
+from tests.common.gen_random import random_gaussian
+
+support_list = {"float32": np.float32}
+support_layout_format = {"NCHW", "NCHWc"}
+
+
+def gen_data(shape_data, kernel, stride, padding, pool_type, dtype, ceil_mode, count_include_pad, data_layout):
+    """Return (input, zero-filled output buffer, numpy reference) for windowed pooling."""
+    # kernel and stride are unpacked as (height, width), the same order in
+    # which pooling() forwards them unchanged to topi.nn.pool.
+    kh, kw = kernel
+    sh, sw = stride
+    pt, pl, pb, pr = padding
+    if pool_type not in ("avg", "max"):
+        raise ValueError("Invalid pool_type=={}".format(pool_type))
+    if data_layout == "NCHW":
+        n, ic, ih, iw = shape_data
+
+        input0 = tvm.placeholder((n, ic, ih, iw), name='input0')
+        output0 = topi.nn.pool(input0, kernel=[kh, kw], stride=[sh, sw], padding=padding,
+                               pool_type=pool_type, ceil_mode=ceil_mode,
+                               layout="NCHW", count_include_pad=count_include_pad)
+
+        a_np = random_gaussian((n, ic, ih, iw), miu=1, sigma=0.1).astype(support_list[dtype])
+        pad_np = np.zeros(shape=(n, ic, ih+pt+pb, iw+pl+pr)).astype(dtype)
+        no_zero = (range(n), range(ic), (range(pt, ih+pt)), (range(pl, iw+pl)))
+        pad_np[np.ix_(*no_zero)] = a_np
+        _, oc, oh, ow = get_const_tuple(output0.shape)
+        b_np = np.zeros(shape=(n, oc, oh, ow)).astype(dtype)
+
+        if pool_type == 'avg':
+            for i in range(oh):
+                for j in range(ow):
+                    if count_include_pad:
+                        b_np[:, :, i, j] = np.mean(
+                            pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2, 3))
+                    else:
+                        pad_count = np.sum(
+                            pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw] > 0, axis=(2, 3))
+                        b_np[:, :, i, j] = np.sum(
+                            pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2, 3)) / np.maximum(pad_count, 1)
+
+        elif pool_type == 'max':
+            for i in range(oh):
+                for j in range(ow):
+                    b_np[:, :, i, j] = np.max(
+                        pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2, 3))
+        output_np = np.zeros(shape=(n, oc, oh, ow)).astype(dtype)
+        return a_np, output_np, b_np
+
+    elif data_layout == "NHWC":
+        raise ValueError("Only layout NCHWc/NCHW supported on python dsl")
+    else:
+        # NCHWc
+        n, ic_out, ih, iw, ic_in = shape_data
+
+        input0 = tvm.placeholder((n, ic_out, ih, iw, ic_in), name='input0')
+        output0 = topi.nn.pool(input0, kernel=[kh, kw], stride=[sh, sw], padding=padding,
+                               pool_type=pool_type, ceil_mode=ceil_mode,
+                               layout="NCHWc", count_include_pad=count_include_pad)
+
+        a_np = random_gaussian((n, ic_out, ih, iw, ic_in), miu=1, sigma=0.1).astype(support_list[dtype])
+        pad_np = np.zeros(shape=(n, ic_out, ih+pt+pb,
+                                 iw+pl+pr, ic_in)).astype(dtype)
+        no_zero = (range(n), range(ic_out), (range(pt, ih+pt)),
+                   (range(pl, iw+pl)), range(ic_in))
+        pad_np[np.ix_(*no_zero)] = a_np
+        _, oc_out, oh, ow, oc_in = get_const_tuple(output0.shape)
+        b_np = np.zeros(shape=(n, oc_out, oh, ow, oc_in)).astype(dtype)
+
+        if pool_type == 'avg':
+            for i in range(oh):
+                for j in range(ow):
+                    if count_include_pad:
+                        b_np[:, :, i, j, :] = np.mean(
+                            pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw, :], axis=(2, 3))
+                    else:
+                        pad_count = np.sum(
+                            pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw, :] > 0, axis=(2, 3))
+                        b_np[:, :, i, j, :] = np.sum(
+                            pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw, :], axis=(2, 3)) / np.maximum(pad_count, 1)
+
+        elif pool_type == 'max':
+            for i in range(oh):
+                for j in range(ow):
+                    b_np[:, :, i, j, :] = np.max(
+                        pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw, :], axis=(2, 3))
+        output_np = np.zeros(shape=(n, oc_out, oh, ow, oc_in)).astype(dtype)
+        return a_np, output_np, b_np
+
+
+def pooling_run(shape_data, kernel, stride, padding, pool_type, dtype,
+                ceil_mode, count_include_pad=True,
+                data_layout="NCHWc", poly_sch=True, attrs=None):
+    """Build and launch pooling, compare with the numpy reference, and profile on request."""
+    default_attrs = {"enable_auto_fuse": False}
+    attrs = {} if attrs is None else attrs
+    attrs.update(default_attrs)
+    attrs["target"] = attrs.get("target", "llvm")
+    op_attrs = [kernel, stride, padding, pool_type,
+                ceil_mode, count_include_pad, data_layout]
+
+    mod = utils.op_build_test(pooling, (shape_data,), (dtype,),
+                              op_attrs=op_attrs, attrs=attrs,
+                              kernel_name="pooling_" + pool_type + "_auto", polyhedral=poly_sch)
+
+    data, output, expect = gen_data(
+        shape_data, kernel, stride, padding, pool_type, dtype, ceil_mode, count_include_pad, data_layout)
+    args = (data, output)
+    output = utils.mod_launch(mod, args, expect=expect)
+    rtol = 1e-3 if dtype == "float16" else 1e-4
+    atol = 1e-3 if dtype == "float16" else 1e-4
+    res = np.allclose(output, expect, rtol=rtol, atol=atol)
+    print("Test {}".format("Pass" if res else "Fail"))
+    target_name = attrs["target"].split()[0]
+    if not res:
+        mod_source = mod
+        if target_name != "llvm":
+            mod_source = mod.imported_modules[0]
+        print("Error {}:========================".format(target_name))
+        print(mod_source.get_source())
+        raise AssertionError("Test fail")
+
+    if attrs.get("profiling", False):
+        data, output = to_tvm_nd_array(
+            [data, output], akg.tvm.context(target_name, 0))
+        target_profiling(mod, data, output,
+                         target=target_name, repeat_time=attrs.get("repeat_times", 1000))
+    return (data, ), output, expect, res
diff --git a/tests/st/ops/cpu/test_pool2d.py b/tests/st/ops/cpu/test_pool2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bd07598bd97fbeb7dadbfc7c1363fa518aede57
--- /dev/null
+++ b/tests/st/ops/cpu/test_pool2d.py
@@ -0,0 +1,72 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import pytest
+import akg.utils as utils
+from tests.common.base import TestBase
+from tests.common.test_run.cpu import pooling_run
+from tests.common.test_run.cpu import global_pooling_run
+
+############################################################
+# TestCase= class: put to tests/*/
+############################################################
+
+
+class TestCase(TestBase):
+    def setup(self):
+        case_name = "cpu_pooling"
+        case_path = os.getcwd()
+
+        self.params_init(case_name, case_path)
+
+        self.args_default = [
+            # max/avg pooling2d
+            ("000_case", pooling_run, ((1, 256, 32, 32), (2, 2), (2, 2), (0, 0,
+             0, 0), "avg", "float32", False, True, "NCHW", True), ["level0"]),
+            ("001_case", pooling_run, ((1, 256, 32, 32), (2, 2), (2, 2), (0, 0,
+             0, 0), "max", "float32", False, True, "NCHW", True), ["level0"]),
+            ("002_case", pooling_run, ((1, 256, 31, 31), (3, 3), (3, 3), (0, 0,
+             0, 0), "avg", "float32", False, True, "NCHW", True), ["level0"]),
+            ("003_case", pooling_run, ((1, 256, 31, 31), (3, 3), (3, 3), (0, 0,
+             0, 0), "max", "float32", False, True, "NCHW", True), ["level0"]),
+            ("004_case", pooling_run, ((1, 16, 31, 31, 8), (3, 3), (3, 3), (0, 0,
+             0, 0), "max", "float32", False, True, "NCHW8c", True), ["level0"]),
+            ("005_case", pooling_run, ((1, 16, 31, 31, 8), (3, 3), (3, 3), (0, 0,
+             0, 0), "avg", "float32", False, True, "NCHW8c", True), ["level0"]),
+            ("006_case", pooling_run, ((1, 16, 31, 31, 8), (3, 3), (3, 3), (0, 0,
+             0, 0), "max", "float32", True, True, "NCHW8c", True), ["level0"]),
+
+            # max/avg global pooling2d
+            ("007_case", global_pooling_run, ((4, 1024, 7, 7),
+             "max", "float32", "NCHW", True), ["level0"]),
+            ("008_case", global_pooling_run, ((4, 1024, 7, 7),
+             "avg", "float32", "NCHW", True), ["level0"]),
+            ("009_case", global_pooling_run, ((4, 128, 7, 7, 8),
+             "max", "float32", "NCHW8c", True), ["level0"]),
+            ("010_case", global_pooling_run, ((4, 128, 7, 7, 8),
+             "avg", "float32", "NCHW8c", True), ["level0"]),
+        ]
+
+        return True
+
+    @pytest.mark.level0
+    @pytest.mark.platform_x86_cpu
+    @pytest.mark.env_onecard
+    def test_cpu_level0(self):
+        return self.run_cases(self.args_default, utils.LLVM, "level0")
+
+    def teardown(self):
+        self._log.info("{0} Teardown".format(self.casename))
+        super(TestCase, self).teardown()
+        return
diff --git a/third_party/incubator-tvm/topi/include/topi/nn/pooling.h b/third_party/incubator-tvm/topi/include/topi/nn/pooling.h
index dbb70d365308a1dcdcaca6646ed1fe0e927aa5ce..511acf6a64243cbd468fc5ce0fe4707d369f8d3e 100644
--- a/third_party/incubator-tvm/topi/include/topi/nn/pooling.h
+++ b/third_party/incubator-tvm/topi/include/topi/nn/pooling.h
@@ -21,6 +21,12 @@
  * \brief Pooling op constructions
  * \file nn/pooling.h
  */
+
+/*
+ * 2022.7.4
+ * Fix bug for the name of reduction axes in pool_impl.
+ */ + #ifndef TOPI_NN_POOLING_H_ #define TOPI_NN_POOLING_H_ @@ -105,8 +111,8 @@ inline Tensor pool_impl(const Tensor& x, auto out_width = air::ir::Simplify( indexdiv(width - kernel_width + pad_left + pad_right, stride_width) + 1); - auto dheight = air::reduce_axis(Range(0, kernel_height)); - auto dwidth = air::reduce_axis(Range(0, kernel_width)); + auto dheight = air::reduce_axis(Range(0, kernel_height), "red_h"); + auto dwidth = air::reduce_axis(Range(0, kernel_width), "red_w"); Array out_shape = x->shape; out_shape.Set(height_axis, out_height); @@ -288,6 +294,8 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, // output indices whose pooling windows cover current input element (can be out-of-bound) Array out_idx{inds.begin(), inds.end()}; + CHECK(stride_height.as()->value != 0) << "stride_height == 0"; + CHECK(stride_width.as()->value != 0) << "stride_width == 0"; out_idx.Set(height_axis, (pad_h_idx / stride_height - windowh)); out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww));