From 60b4aef80ba93fa3599998f87ad47d52418573cd Mon Sep 17 00:00:00 2001 From: polyhedral Date: Thu, 19 Aug 2021 15:49:29 +0800 Subject: [PATCH] support gnn --- python/akg/composite/topi.py | 150 +++++++++++++++++- ...ther_gather_add_mul_max_exp_scatter_add.py | 111 +++++++++++++ .../array_gpu/fused_gather_mul_scatter_add.py | 86 ++++++++++ ..._nd_reduce_sum_mul_unsorted_segment_sum.py | 34 ++++ python/akg/ops/array_gpu/gather.py | 26 +++ python/akg/ops/array_gpu/gather_nd.py | 23 +++ .../akg/ops/array_gpu/tensor_scatter_add.py | 22 +++ .../akg/ops/array_gpu/unsorted_segment_sum.py | 22 +++ python/akg/utils/dsl_create.py | 12 ++ src/codegen/build_module.cc | 3 +- src/composite/optimize/elim_reshape.cc | 43 +++++ src/composite/util.cc | 9 +- src/composite/util.h | 3 +- src/pass/rewrite_tensor_index.cc | 50 ++++++ src/pass/tensor_access_rewrite.cc | 8 +- src/poly/gpu_emit/gpu_isl_emitter.cc | 120 +++++++++++++- src/poly/gpu_emit/gpu_isl_emitter.h | 46 ++++++ src/poly/gpu_emit/gpu_isl_emitter_reduce.h | 19 --- src/poly/gpu_emit/gpu_reduce_emit_pass.cc | 103 ------------ src/poly/isl_emitter.cc | 10 ++ src/poly/isl_emitter.h | 1 + src/poly/schedule_pass.h | 5 +- src/poly/schedule_pass/init_schedule.cc | 5 + src/poly/schedule_pass/rm_self_dep.cc | 93 ++++++----- src/poly/schedule_pass/tile_outer_band.cc | 81 +++++----- src/poly/schedule_pass/tile_outer_band.h | 6 +- .../schedule_pass_gpu/gpu_dma_analysis.cc | 1 + .../schedule_pass_gpu/mapping_outer_band.cc | 70 +++----- .../schedule_pass_gpu/mapping_outer_band.h | 6 +- .../operator_mapping_strategy.cc | 85 ++++++++++ .../operator_mapping_strategy.h | 12 ++ .../operator_shared_strategy.cc | 20 +++ .../operator_shared_strategy.h | 1 + .../register_memory_manager.cc | 11 +- .../shared_memory_manager.cc | 50 +++++- .../schedule_pass_gpu/shared_memory_manager.h | 3 + src/poly/schedule_tree_util.cc | 108 ++++++++++++- src/poly/schedule_tree_util.h | 9 +- src/poly/scop_info.h | 36 ++++- src/poly/scop_make_schedule_tree.cc | 20 
++- src/poly/tiling/tiling_analyzer.cc | 12 +- src/poly/tiling/tiling_analyzer.h | 24 +-- src/poly/tiling/tiling_strategy_manager.h | 23 ++- .../tiling/tiling_strategy_manager_gpu.cc | 81 +++++++++- .../tiling/tiling_strategy_manager_npu.cc | 17 +- tests/common/gen_json_data.py | 66 +++++++- tests/common/gen_random.py | 69 ++++++++ tests/common/test_utils.py | 70 +++++++- tests/operators/gpu/__init__.py | 20 +++ tests/operators/gpu/test_all.py | 51 ++++-- ...ther_gather_add_mul_max_exp_scatter_add.py | 77 +++++++++ .../gpu/test_fused_gather_mul_scatter_add.py | 82 ++++++++++ ..._nd_reduce_sum_mul_unsorted_segment_sum.py | 76 +++++++++ tests/operators/gpu/test_ms_gather.py | 56 +++++++ tests/operators/gpu/test_ms_gather_nd.py | 64 ++++++++ .../gpu/test_ms_tensor_scatter_add.py | 69 ++++++++ .../gpu/test_ms_unsorted_segment_sum.py | 64 ++++++++ third_party/incubator-tvm/include/tvm/ir.h | 2 + 58 files changed, 2095 insertions(+), 351 deletions(-) create mode 100644 python/akg/ops/array_gpu/fused_gather_gather_add_mul_max_exp_scatter_add.py create mode 100644 python/akg/ops/array_gpu/fused_gather_mul_scatter_add.py create mode 100644 python/akg/ops/array_gpu/fused_gather_nd_reduce_sum_mul_unsorted_segment_sum.py create mode 100644 python/akg/ops/array_gpu/gather.py create mode 100644 python/akg/ops/array_gpu/gather_nd.py create mode 100644 python/akg/ops/array_gpu/tensor_scatter_add.py create mode 100644 python/akg/ops/array_gpu/unsorted_segment_sum.py create mode 100644 tests/operators/gpu/test_fused_gather_gather_add_mul_max_exp_scatter_add.py create mode 100644 tests/operators/gpu/test_fused_gather_mul_scatter_add.py create mode 100644 tests/operators/gpu/test_fused_gather_nd_reduce_sum_mul_unsorted_segment_sum.py create mode 100644 tests/operators/gpu/test_ms_gather.py create mode 100644 tests/operators/gpu/test_ms_gather_nd.py create mode 100644 tests/operators/gpu/test_ms_tensor_scatter_add.py create mode 100644 
tests/operators/gpu/test_ms_unsorted_segment_sum.py diff --git a/python/akg/composite/topi.py b/python/akg/composite/topi.py index 4ab8b251..1f3c91b4 100644 --- a/python/akg/composite/topi.py +++ b/python/akg/composite/topi.py @@ -484,4 +484,152 @@ def user_defined(inputs, attrs): inputs = list(inputs) + op_attrs output = func_kernel(*inputs) - return output \ No newline at end of file + return output + +@tvm.register_func("GatherNd") +def gather_nd(inputs, attrs): + if len(inputs) != 2: + raise ValueError(f"2 inputs expected, but got {len(inputs)}") + data, indices = inputs + + data_shape = list(data.shape) + indices_shape = list(indices.shape) + indices_last_dim = len(indices_shape) - 1 + left_shape = indices_shape[:indices_last_dim] + right_shape = data_shape[int(indices_shape[indices_last_dim]):] + def gen_ir(data, indices, out): + ib = tvm.ir_builder.create() + with ib.for_range_n(left_shape, 'i') as i: + with ib.for_range_n(right_shape, 'j') as j: + read_idx = [] + inbound = True + for k in range(0, int(indices_shape[-1])): + temp_idx = ib.load(indices, i + [k]) + if k == 0: + inbound = tvm.all((temp_idx >= 0), (temp_idx < data_shape[k])) + else: + inbound = tvm.all(inbound, (temp_idx >= 0), (temp_idx < data_shape[k])) + read_idx.append(temp_idx) + with ib.if_scope(inbound): + ib.store(out, i + j, ib.load(data, read_idx + j)) + with ib.else_scope(): + ib.store(out, i + j, tvm.const(0, data.dtype)) + return ib.get() + + output_name = "T_gathernd_" + data.op.name + "_" + indices.op.name + output_shape = left_shape + right_shape + out_buf = tvm.decl_buffer(output_shape, data.dtype, output_name) + return tvm.extern([output_shape], [data, indices], lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]), + dtype=data.dtype, out_buffers=[out_buf], name=output_name) + +@tvm.register_func("TensorScatterAdd") +def tensor_scatter_add(inputs, attrs): + if len(inputs) != 3: + raise ValueError(f"3 inputs expected, but got {len(inputs)}") + data, indices, updates = inputs + 
data_shape = list(data.shape) + indices_shape = list(indices.shape) + is_1d_indices = False + if len(indices_shape) == 1: + indices_shape.append(1) + is_1d_indices = True + left_shape = indices_shape[:-1] + right_shape = data_shape[int(indices_shape[-1]):] + def gen_ir(data, indices, updates, out): + ib = tvm.ir_builder.create() + with ib.for_range_n(left_shape, "i") as i: + with ib.for_range_n(right_shape, "j") as j: + index_read = i + j + index_write = [] + inbound = True + if is_1d_indices: + temp_idx = ib.load(indices, i) + inbound = tvm.all((temp_idx >= 0), (temp_idx < data_shape[0])) + index_write.append(temp_idx) + else: + for k in range(0, int(indices_shape[-1])): + temp_idx = ib.load(indices, i+[k]) + if k == 0: + inbound = tvm.all((temp_idx >= 0), (temp_idx < data_shape[k])) + else: + inbound = tvm.all(inbound, (temp_idx >= 0), (temp_idx < data_shape[k])) + index_write.append(temp_idx) + index_write = index_write + j + with ib.if_scope(inbound): + temp = ib.load(updates, index_read) + ib.load(out, index_write) + ib.store(out, index_write, temp) + return ib.get() + + output_name = "T_tsa_" + data.op.name + "_" + indices.op.name + "_" + updates.op.name + out_buf = tvm.decl_buffer(data.shape, data.dtype, output_name) + return tvm.extern([data.shape], [data, indices, updates], lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]), + dtype=data.dtype, out_buffers=[out_buf], name=output_name) + +@tvm.register_func("UnsortedSegmentSum") +def tensor_unsorted_segment_sum(inputs, attrs): + attrs = {k: v for k, v in attrs.items()} + num = attrs['num_segments'] + op_id = attrs['op_id'] if 'op_id' in attrs else 0 + if len(inputs) != 2: + raise ValueError(f"2 inputs expected, but got {len(inputs)}") + data, indices = inputs + data_shape = list(data.shape) + indices_shape = list(indices.shape) + segment_len = len(data_shape) - len(indices_shape) + if segment_len < 0: + raise ValueError(f'input rank should not be less than segment_id rank') + for i, v in 
enumerate(indices_shape): + if int(v) != int(data_shape[i]): + raise ValueError(f'input shape at dim {i} is not equal to segment_id shape at dim {i}') + output_shape = [num] + if segment_len > 0: + output_shape += data_shape[len(indices_shape):] + if len(indices_shape) > 1: + raise ValueError('only 1-D segment currently supported') + + def gen_ir(data, indices, out): + ib = tvm.ir_builder.create() + with ib.for_range_n(indices_shape, "i") as i: + read_idx = ib.load(indices, i) + # 1-D segment + with ib.for_range_n(data_shape[1:], 'j') as j: + inbound = tvm.all((read_idx >= 0), (read_idx < num)) + with ib.if_scope(inbound): + val = ib.load(data, i + j) + ib.load(out, [read_idx] + j) + ib.store(out, [read_idx] + j, val) + return ib.get() + + output_name = "T_uss_" + data.op.name + "_" + indices.op.name + out_buf = tvm.decl_buffer(output_shape, data.dtype, output_name) + return tvm.extern([data.shape], [data, indices], lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]), + dtype=data.dtype, out_buffers=[out_buf], name=output_name) + +@tvm.register_func("Gather") +def gather(inputs, attrs): + attrs = {k: v for k, v in attrs.items()} + axis = int(attrs["axis"][0]) if "axis" in attrs else 0 + if len(inputs) != 2: + raise ValueError(f"2 inputs expected, but got {len(inputs)}") + data, indices = inputs + data_shape = list(data.shape) + indices_shape = list(indices.shape) + output_shape = data_shape[: axis] + indices_shape + data_shape[axis + 1:] + + def gen_ir(data, indices, out): + ib = tvm.ir_builder.create() + with ib.for_range_n(data_shape[: axis], "i") as i: + with ib.for_range_n(indices_shape, "j") as j: + load_idx = ib.load(indices, j) + inbound = tvm.all(load_idx >= 0, load_idx < data_shape[axis]) + read_idx = i + [load_idx] + with ib.for_range_n(data_shape[axis + 1:], "k") as k: + with ib.if_scope(inbound): + ib.store(out, i + j + k, ib.load(data, read_idx + k)) + with ib.else_scope(): + ib.store(out, i + j + k, tvm.const(0, data.dtype)) + return ib.get() + + 
output_name = "T_gather_" + data.op.name + "_" + indices.op.name + "_" + str(axis) + out_buf = tvm.decl_buffer(output_shape, data.dtype, output_name) + return tvm.extern([data.shape], [data, indices], lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]), + dtype=data.dtype, out_buffers=[out_buf], name=output_name) diff --git a/python/akg/ops/array_gpu/fused_gather_gather_add_mul_max_exp_scatter_add.py b/python/akg/ops/array_gpu/fused_gather_gather_add_mul_max_exp_scatter_add.py new file mode 100644 index 00000000..ab9a623e --- /dev/null +++ b/python/akg/ops/array_gpu/fused_gather_gather_add_mul_max_exp_scatter_add.py @@ -0,0 +1,111 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""operator dsl function: fused_gather_gather_add_mul_max_exp_scatter_add""" +import akg.tvm as tvm +from akg.tvm.hybrid import script +from akg.utils import validation_check as vc_util +from akg.utils.format_transform import get_shape +from akg.utils.dsl_create import get_broadcast_shape +from .tensor_scatter_add import tensor_scatter_add +from akg.ops.math_gpu.add import add +from akg.ops.math_gpu.mul import mul +from akg.ops.math_gpu.maximum import maximum +from akg.ops.math_gpu.exp import exp + +@vc_util.check_input_type(tvm.tensor.Tensor, tvm.tensor.Tensor, int, str) +def gather(data, indices, axis, flag): + """Only support axis=0.""" + ndim = len(data.shape) + axis = axis + ndim if axis < 0 else axis + assert axis >= 0 + assert axis < ndim + + data_shape = list(data.shape) + indices_shape = list(indices.shape) + output_shape = data_shape[:axis] + indices_shape + data_shape[axis+1:] + left_shape = output_shape[:1] + right_shape = output_shape[1:] + + def gen_ir(data, indices, out): + ib = tvm.ir_builder.create() + with ib.for_range_n(left_shape, 'i') as i: + with ib.for_range_n(right_shape, 'j') as j: + read_idx = [ib.load(indices, i)] + val = ib.load(data, read_idx + j) + ib.store(out, i + j, val) + return ib.get() + + out_buf = tvm.decl_buffer(output_shape, data.dtype, "out_buf") + + return tvm.extern( + [output_shape], + [data, indices], + lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]), + dtype=data.dtype, + out_buffers=[out_buf], + name="fused_gather" + flag, + ) + +@vc_util.check_input_type(tvm.tensor.Tensor, tvm.tensor.Tensor, tvm.tensor.Tensor) +def scatter_add(data, indices, updates): + """ + Args: + data: [x, y, z] + indices: [n] + updates: [n, y, z] + Output: + [x, y, z] + """ + left_shape = list(updates.shape[:1]) + right_shape = list(updates.shape[1:]) + + def gen_ir(data, indices, updates, out): + ib = tvm.ir_builder.create() + with ib.for_range_n(left_shape, "i") as i: + with ib.for_range_n(right_shape, "j") as j: + idx_updates = i + j + 
idx_data = [ib.load(indices, i)] + j + temp = ib.load(updates, idx_updates) + ib.load(out, idx_data) + ib.store(out, idx_data, temp) + return ib.get() + + out_buf = tvm.decl_buffer(data.shape, data.dtype, "out_buf") + return tvm.extern( + [data.shape], + [data, indices, updates], + lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]), + dtype=data.dtype, + out_buffers=[out_buf], + name="fused_scatter_add", + ) + +@vc_util.check_input_type(tvm.tensor.Tensor, tvm.tensor.Tensor, tvm.tensor.Tensor, + tvm.tensor.Tensor, int) +def fused_gather_gather_add_mul_max_exp_scatter_add(inp1, inp2, inp3, inp4, axis): + ndim = len(inp1.shape) + axis = axis + ndim if axis < 0 else axis + assert axis >= 0 + assert axis < ndim + + gather_out1 = gather(inp1, inp2, axis, "1") + gather_out2 = gather(inp1, inp2, axis, "2") + + add_out = add(gather_out1, gather_out2) + mul_out = mul(add_out, inp3) + max_out = maximum(add_out, mul_out) + exp_out = exp(max_out) + scatter_out = scatter_add(inp1, inp4, exp_out) + + return exp_out, scatter_out diff --git a/python/akg/ops/array_gpu/fused_gather_mul_scatter_add.py b/python/akg/ops/array_gpu/fused_gather_mul_scatter_add.py new file mode 100644 index 00000000..d43e920b --- /dev/null +++ b/python/akg/ops/array_gpu/fused_gather_mul_scatter_add.py @@ -0,0 +1,86 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""operator dsl function: gather_mul_scatter_add""" +import akg.tvm +from akg.tvm.hybrid import script +from akg.utils import validation_check as vc_util +from akg.utils.format_transform import get_shape +from akg.utils.dsl_create import get_broadcast_shape + + +@vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, int) +def fused_gather_mul_scatter_add(input1, input2, input3, input4, axis=0): + if axis < 0: + axis += len(input1.shape) + assert axis >= 0 + assert axis < len(input1.shape) + + axis_var = akg.tvm.const(0, dtype="int32") + if len(input1.shape) == 3: + gather_out_shape = get_shape(input1)[:axis] + [get_shape(input2)[0],] + get_shape(input1)[axis + 1:] + broadcast_shape = get_broadcast_shape(gather_out_shape, get_shape(input3)) + dim2_size = broadcast_shape[2] + dtype = input3.dtype + + @script(capture=locals()) + def _gather_out(input1_, input2_): + # gather + gather_out_ = output_tensor(broadcast_shape, input1_.dtype) + for i in range(broadcast_shape[0]): + for j in range(broadcast_shape[1]): + for k in range(broadcast_shape[2]): + if axis == 0: + gather_out_[i, j, k] = input1_[input2_[i], j, k] + elif axis == 1: + gather_out_[i, j, k] = input1_[i, input2_[j], k] + else: + gather_out_[i, j, k] = input1_[i, j, input2_[k]] + return gather_out_ + + gather_out = _gather_out(input1, input2) + + @script(capture=locals()) + def _mul_out(gather_out_, input3_): + # mul + mul_out_ = output_tensor(broadcast_shape, input3_.dtype) + + for i in range(input3_.shape[0]): + i1 = i if gather_out_.shape[0] == broadcast_shape[0] else 0 + i2 = i if input3_.shape[0] == broadcast_shape[0] else 0 + for j in range(input3_.shape[1]): + j1 = j if gather_out_.shape[1] == broadcast_shape[1] else 0 + j2 = j if input3_.shape[1] == broadcast_shape[1] else 0 + for k in range(dim2_size): + k1 = k if gather_out_.shape[2] == broadcast_shape[2] else 0 + k2 = k if input3_.shape[2] == broadcast_shape[2] else 0 + mul_out_[i, 
j, k] = gather_out_[i1, j1, k1] * input3_[i2, j2, k2] + return mul_out_ + + mul_out = _mul_out(gather_out, input3) + + @script(capture=locals()) + def _scatter_add(input1_, mul_out_, input4_): + # scatter_add + scatter_add_ = output_tensor(input1_.shape, input1_.dtype) + for i in range(broadcast_shape[0]): + for j in range(broadcast_shape[1]): + for k in range(broadcast_shape[2]): + scatter_add_[input4_[i, 0], j, k] += mul_out_[i, j, k] + + return scatter_add_ + + return _scatter_add(input1, mul_out, input4) + + raise ValueError("scatter_add only support for 3 dimensions") diff --git a/python/akg/ops/array_gpu/fused_gather_nd_reduce_sum_mul_unsorted_segment_sum.py b/python/akg/ops/array_gpu/fused_gather_nd_reduce_sum_mul_unsorted_segment_sum.py new file mode 100644 index 00000000..98cf3941 --- /dev/null +++ b/python/akg/ops/array_gpu/fused_gather_nd_reduce_sum_mul_unsorted_segment_sum.py @@ -0,0 +1,34 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""operator dsl function: scatter_add""" +import akg.tvm +from akg.utils import validation_check as vc_util +from .gather_nd import gather_nd +from .unsorted_segment_sum import unsorted_segment_sum +from ..math_gpu.mul import mul +from ..math_gpu.reduce_sum import reduce_sum + +@vc_util.check_input_type( + akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, int, bool, int) +def fused_gather_nd_reduce_sum_mul_unsorted_segment_sum(input1, input2, input3, input4, input5, axis=0, keepdims=False, num=0): + item_get = gather_nd(input1, input2) + sum_axis = reduce_sum(item_get, axis, keepdims) + prod = mul(sum_axis, input3) + res1 = unsorted_segment_sum(prod, input4, num, op_id=0) + res2 = unsorted_segment_sum(prod, input5, num, op_id=1) + return res1, res2 + + + diff --git a/python/akg/ops/array_gpu/gather.py b/python/akg/ops/array_gpu/gather.py new file mode 100644 index 00000000..87be0479 --- /dev/null +++ b/python/akg/ops/array_gpu/gather.py @@ -0,0 +1,26 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""operator dsl function: scatter_add""" +import akg.tvm +from akg.utils import validation_check as vc_util +from ...composite import gather as cuda_gather + +@vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, int) +def gather(data, indices, axis): + dim = data.ndim + if axis < -dim or axis >= dim: + raise ValueError(f'axis {axis} is out of bounds for array with dim {dim}') + axis = axis % dim + return cuda_gather((data, indices), {'axis': [axis]}) diff --git a/python/akg/ops/array_gpu/gather_nd.py b/python/akg/ops/array_gpu/gather_nd.py new file mode 100644 index 00000000..3b987bc2 --- /dev/null +++ b/python/akg/ops/array_gpu/gather_nd.py @@ -0,0 +1,23 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""operator dsl function: gather_nd""" +import akg.tvm +from akg.utils import validation_check as vc_util +from ...composite import gather_nd as cuda_gather_nd + + +@vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor) +def gather_nd(data, indices): + return cuda_gather_nd((data, indices), {}) diff --git a/python/akg/ops/array_gpu/tensor_scatter_add.py b/python/akg/ops/array_gpu/tensor_scatter_add.py new file mode 100644 index 00000000..e9210870 --- /dev/null +++ b/python/akg/ops/array_gpu/tensor_scatter_add.py @@ -0,0 +1,22 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""operator dsl function: scatter_add""" +import akg.tvm +from akg.utils import validation_check as vc_util +from ...composite import tensor_scatter_add as cuda_tensor_scatter_add + +@vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor) +def tensor_scatter_add(data, indices, updates): + return cuda_tensor_scatter_add((data, indices, updates), {}) diff --git a/python/akg/ops/array_gpu/unsorted_segment_sum.py b/python/akg/ops/array_gpu/unsorted_segment_sum.py new file mode 100644 index 00000000..5d500773 --- /dev/null +++ b/python/akg/ops/array_gpu/unsorted_segment_sum.py @@ -0,0 +1,22 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""operator dsl function: scatter_add""" +import akg.tvm +from akg.utils import validation_check as vc_util +from ...composite import tensor_unsorted_segment_sum as tensor_unsorted_segment_sum + +@vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, int, int) +def unsorted_segment_sum(data, indices, num, op_id=0): + return tensor_unsorted_segment_sum((data, indices), {'num_segments': num, 'op_id': op_id}) diff --git a/python/akg/utils/dsl_create.py b/python/akg/utils/dsl_create.py index 8144c0d4..a6fd1298 100644 --- a/python/akg/utils/dsl_create.py +++ b/python/akg/utils/dsl_create.py @@ -15,6 +15,8 @@ # limitations under the License. 
"""dsl create helping function""" +import collections +import itertools import logging import math import akg @@ -327,6 +329,16 @@ def broadcast_gradient_args(x, y): return rx, ry +def get_broadcast_shape(*shapes): + shape_out = collections.deque() + reversed_shapes = map(reversed, shapes) + for items in itertools.zip_longest(*reversed_shapes, fillvalue=1): + max_size = 0 if 0 in items else max(items) + if any(item not in (1, max_size) for item in items): + raise ValueError(f'operands could not be broadcast together with shapes {*shapes,}') + shape_out.appendleft(max_size) + return list(shape_out) + def zero_const(dtype): return akg.tvm.const(0, dtype) diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 2b18997c..633884eb 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -523,6 +523,7 @@ NodeRef LowerStmt(Schedule sch, const Array &in_args, const Array help_tiling_level["None"]) { if (tuning) { @@ -533,11 +534,11 @@ NodeRef LowerStmt(Schedule sch, const Array &in_args, const Array poly_res = NEXT_PASS(AutoPoly, stmt, *binds_0, target, g_attrs, false, false, new_sch); CHECK_EQ(poly_res.size(), 2); stmt = air::Downcast(poly_res[0]); g_attrs.Set(kEnablePolySch, air::make_const(Int(32), true)); + stmt = NEXT_PASS(LowerWith, stmt); } else { g_attrs.Set(kEnablePolySch, air::make_const(Int(32), false)); } diff --git a/src/composite/optimize/elim_reshape.cc b/src/composite/optimize/elim_reshape.cc index c40c3cc7..4d841dcc 100644 --- a/src/composite/optimize/elim_reshape.cc +++ b/src/composite/optimize/elim_reshape.cc @@ -336,8 +336,50 @@ std::string GetId(const std::string &name, int count) { return id.str(); } +class TSA : public IRMutator { + public: + explicit TSA(AnalysisResult &result) : result_(result){}; + ~TSA() = default; + + private: + Stmt Mutate_(const Provide *op, const Stmt &s) final { + auto op_name = GetOpName(op); + auto call = op->value.as(); + CHECK(call); + if (IsTransform(op_name)) { + 
CHECK(call->args[0].as()); + reshape_[op->func] = call->args[0]; + p_.emplace_back(op); + } + if (op_name == "TensorScatterAdd") { + CHECK(call->args.size() == 3); + CHECK(call->args[1].as()); + auto arg1 = call->args[1].as()->func; + if (reshape_.count(arg1) && call->args[1].as()->args.size() == 2) { + for (auto &it : p_) { + if (it->func == arg1) { + result_.to_be_removed.insert(it); + } + } + auto new_call = Call::make(call->type, op_name, {call->args[0], reshape_[arg1], call->args[2]}, call->call_type, + call->func, call->value_index); + return Provide::make(op->func, op->value_index, new_call, op->args); + } + } + return IRMutator::Mutate_(op, s); + } + + AnalysisResult &result_; + FuncExprMap reshape_; + std::vector p_; +}; + Stmt ElimReshapeBackward::Run(const Stmt &stmt) { auto s = stmt; + AnalysisResult as; + s = TSA(as).Mutate(s); + s = AnalysisResultMutator(as).Mutate(s); + auto checker = ElimReshapeOpChecker(); checker.Visit(s); if (!checker.can_elim) return s; @@ -357,6 +399,7 @@ Stmt ElimReshapeBackward::Run(const Stmt &stmt) { LOG(WARNING) << "ElimReshapeBackward reach to max_try_count!"; return s; } + Stmt ElimReshapeForward::Run(const Stmt &stmt) { auto s = stmt; auto checker = ElimReshapeOpChecker(); diff --git a/src/composite/util.cc b/src/composite/util.cc index 7ff50d90..f81af59a 100644 --- a/src/composite/util.cc +++ b/src/composite/util.cc @@ -74,10 +74,11 @@ bool IsInplaceAssign(const std::string &op_name) { return op_name == "InplaceAss bool IsAssign(const std::string &op_name) { return op_name == "Assign"; } bool IsOtherOp(const std::string &op_name) { // if topi support more, add to this list - std::unordered_set elems = {"MatMul", "BatchMatMul", "Conv", "Transpose", "Tile", - "Assign", "InplaceAssign", "EquivFormat", "TransData", "AddMinValue", - "BroadcastTo", "PadAkg", "UnPadAkg", "Conv2D", "CumSum", - "CumProd", "StridedSlice", "UserDefined"}; + std::unordered_set elems = {"MatMul", "BatchMatMul", "Conv", "Transpose", "Tile", + 
"Assign", "InplaceAssign", "EquivFormat", "TransData", "AddMinValue", + "BroadcastTo", "PadAkg", "UnPadAkg", "Conv2D", "CumSum", + "CumProd", "StridedSlice", "UserDefined", "GatherNd", "TensorScatterAdd", + "UnsortedSegmentSum", "Gather"}; return elems.find(op_name) != elems.end(); } bool IsElemwise(const std::string &op_name) { diff --git a/src/composite/util.h b/src/composite/util.h index eaa0a474..9bba63ab 100644 --- a/src/composite/util.h +++ b/src/composite/util.h @@ -16,6 +16,7 @@ #ifndef COMPOSITE_UTIL_H_ #define COMPOSITE_UTIL_H_ #include +#include #include "tvm.h" #include "picojson.h" @@ -366,7 +367,7 @@ struct AnalysisResult { class AnalysisResultMutator : public IRMutator { public: - explicit AnalysisResultMutator(AnalysisResult result, const std::string &id) + explicit AnalysisResultMutator(AnalysisResult result, std::string id="0") : result_(std::move(result)), id_(std::move(id)){}; private: diff --git a/src/pass/rewrite_tensor_index.cc b/src/pass/rewrite_tensor_index.cc index 3f726cde..186e42e0 100644 --- a/src/pass/rewrite_tensor_index.cc +++ b/src/pass/rewrite_tensor_index.cc @@ -88,6 +88,8 @@ class RewriteTensorIdx : public IRMutator { // remake provide now if (!lhs_tensor_idx_.empty()) { + is_tensor_of_tensor_ = true; + tensors_not_promote_.insert(op->func->func_name()); // build a new value Array idx_args; Expr extent = Expr(0); @@ -120,6 +122,7 @@ class RewriteTensorIdx : public IRMutator { idx_args.push_back(Call::make(type_, "orig", {new_op->value}, Call::PureIntrinsic)); Expr val = Call::make(type_, "with", idx_args, Call::PureIntrinsic); stmt = Provide::make(new_op->func, new_op->value_index, val, new_args); + stmt = AddAttrForAtomicToT(new_op, op, stmt); } lhs_tensor_idx_.clear(); @@ -131,6 +134,7 @@ class RewriteTensorIdx : public IRMutator { if (in_args_ && op->call_type == Call::Halide) { halide_call_ = true; if (cache_idx_.count(op->func.get()) == 0) { + inner_tensors_.insert(op->func->func_name()); cache_idx_[op->func.get()] = i_; i_ 
= i_ + 2; } @@ -150,6 +154,8 @@ class RewriteTensorIdx : public IRMutator { // for call not in provide, rhs always if (!rhs_tensor_idx_.empty()) { + is_tensor_of_tensor_ = true; + tensors_not_promote_.insert(op->func->func_name()); Array idx_args; Expr ne = e; @@ -182,6 +188,32 @@ class RewriteTensorIdx : public IRMutator { return cache_idx_[op->func.get()]; } + Stmt AddAttrForAtomicToT(const Provide *new_op, const Provide *op, Stmt stmt) { + auto Get = [new_op, op, stmt](const Expr a, const Expr b, std::string op_type) -> Stmt { + auto func_name = op->func->func_name(); + auto call_a = a.as(); + auto call_b = b.as(); + if (call_a && call_b && (call_a->name == func_name || call_b->name == func_name)) { + return AttrStmt::make(new_op->func, "atomic_tot", Expr(op_type), stmt); + } + return stmt; + }; + + if (auto atomic_op = op->value.as()) { + stmt = Get(atomic_op->a, atomic_op->b, "MaxOp"); + } else if (auto atomic_op = op->value.as()) { + stmt = Get(atomic_op->a, atomic_op->b, "MinOp"); + } else if (auto atomic_op = op->value.as()) { + stmt = Get(atomic_op->a, atomic_op->b, "AndOp"); + } else if (auto atomic_op = op->value.as()) { + stmt = Get(atomic_op->a, atomic_op->b, "OrOp"); + } else if (auto atomic_op = op->value.as()) { + stmt = Get(atomic_op->a, atomic_op->b, "SumOp"); + } + + return stmt; + } + std::unordered_map lhs_tensor_idx_; std::unordered_map rhs_tensor_idx_; std::unordered_map realize_type_; @@ -196,11 +228,29 @@ class RewriteTensorIdx : public IRMutator { public: bool has_invalid_tensor_expr_{false}; + std::unordered_set tensors_not_promote_; + std::unordered_set inner_tensors_; + bool is_tensor_of_tensor_{false}; }; Stmt RewriteTensorIndex(const Stmt stmt) { auto mutator = RewriteTensorIdx(); auto new_stmt = mutator.Mutate(stmt); + if (!mutator.tensors_not_promote_.empty()) { + for (auto &t : mutator.tensors_not_promote_) { + new_stmt = AttrStmt::make(Expr("INFO"), "TENSOR_NOT_PROMOTE", Expr(t), new_stmt); + } + } + + if 
(!mutator.inner_tensors_.empty()) { + for (auto &t : mutator.inner_tensors_) { + new_stmt = AttrStmt::make(Expr("INFO"), "INNER_TENSOR", Expr(t), new_stmt); + } + } + + if (mutator.is_tensor_of_tensor_) { + new_stmt = AttrStmt::make(Expr("INFO"), "TENSOR_OF_TENSOR", Expr("TENSOR_OF_TENSOR"), new_stmt); + } return mutator.has_invalid_tensor_expr_ ? stmt : new_stmt; } diff --git a/src/pass/tensor_access_rewrite.cc b/src/pass/tensor_access_rewrite.cc index f68c4132..8943c266 100644 --- a/src/pass/tensor_access_rewrite.cc +++ b/src/pass/tensor_access_rewrite.cc @@ -33,7 +33,9 @@ class TensorAccessRewriter : public IRMutator { } Expr Mutate_(const Call *op, const Expr &e) override { - if (op->name == "tensor_load") { + Expr expr = IRMutator::Mutate_(op, e); + op = expr.as(); + if (op != nullptr && op->name == "tensor_load") { auto it = tensors_.find(op->args[0].as()); CHECK(it != tensors_.end()); Tensor t = it->second; @@ -42,8 +44,8 @@ class TensorAccessRewriter : public IRMutator { args.push_back(op->args[i]); } return Call::make(t->dtype, t->op->name, args, Call::CallType::Halide, t->op, t->value_index); - } - return IRMutator::Mutate_(op, e); + } + return expr; } Stmt Mutate_(const Evaluate *op, const Stmt &s) override { diff --git a/src/poly/gpu_emit/gpu_isl_emitter.cc b/src/poly/gpu_emit/gpu_isl_emitter.cc index be4667a8..2a6f3f0b 100644 --- a/src/poly/gpu_emit/gpu_isl_emitter.cc +++ b/src/poly/gpu_emit/gpu_isl_emitter.cc @@ -16,6 +16,7 @@ #include "gpu_isl_emitter.h" #include "emit_pass.h" +#include "ir_pass.h" #include #include @@ -137,7 +138,16 @@ Stmt GpuIslEmitter::EmitStmt(const isl::ast_node_user &node) { } else if (info_.IsSync(stmt_id)) { return EmitSync(); } else { - return EmitUserStmt(node); + Stmt stmt = EmitUserStmt(node); + auto tot = info_.analysis_result_.GetTensorOfTensorStmt(); + auto id_name = stmt_id.get_name(); + if (tot.count(id_name)) { + std::string marker_name = ATOMIC_MARKER; + marker_name += "_"; + marker_name += tot[id_name]; + stmt = 
AttrStmt::make(Expr("INFO"), marker_name, StringImm::make(marker_name), stmt); + } + return stmt; } } @@ -319,6 +329,12 @@ Stmt GpuIslEmitter::Emit(const isl::ast_node &node) { // emit realize for temporary tensor stmt = EmitRealizeForGlobalTensor(stmt); + if (!info_.analysis_result_.GetTensorOfTensorStmt().empty()) { + stmt = LowerWith(stmt); + stmt = AtomicReturnStmtEmit(info_).Mutate(stmt); + stmt = AttrStmt::make(Expr("INFO"), REDUCE_LIB_TYPE_FLAG, info_.user_config_.GetReduceLibType(), stmt); + } + // iter var node attr emit std::map::iterator it; for (it = iter_name_map_.begin(); it != iter_name_map_.end(); ++it) { @@ -392,7 +408,7 @@ Stmt GpuIslEmitter::EmitMark(const isl::ast_node_mark &node) { Stmt stmt; if ((mark == PROMOTE_VECTORIZATION) || (mark == PROMOTE_REGISTER_TO_GLOBAL) || (mark == PROMOTE_REGISTER_TO_SHARED) || - (mark == PROMOTE_SHARED_TO_GLOBAL)) { + (mark == PROMOTE_SHARED_TO_GLOBAL) || IsStartsWith(mark, REDUCE_ATOMIC_FLAG)) { stmt = EmitAst(node.get_node()); if (!stmt.defined()) { return Stmt(); @@ -588,6 +604,106 @@ Stmt GpuIslEmitter::EmitAccessNodeFromPromoteAcsProvide(isl::id var, const Node return s; } +Stmt AtomicReturnStmtEmit::Mutate_(const AttrStmt *op, const Stmt &s) { + auto key = op->attr_key; + if (IsStartsWith(key, REDUCE_ATOMIC_FLAG)) { + in_atomic_area_ = true; + std::vector strs = common::Split(key, "_"); + CHECK_EQ(strs.size(), REDUCE_ATOMIC_FLAG_SIZE) << "atomic mark format is not right!."; + atomic_data_.reduce_op_.clear(); + if (AkgSupportedReduceOp.count(strs[REDUCE_ATOMIC_FLAG_TYPE_POS])) { + atomic_data_.reduce_op_ = AKG_REDUCE_LIB_SPACE; + atomic_data_.reduce_op_ += "::"; + atomic_data_.reduce_op_ += strs[REDUCE_ATOMIC_FLAG_TYPE_POS]; + } else { + CHECK(false) << "reduce op type is not supported!"; + } + } + return IRMutator::Mutate_(op, s); +} + +Stmt AtomicReturnStmtEmit::Mutate_(const Provide *op, const Stmt &s) { + if (in_atomic_area_) { + in_atomic_area_ = false; + Stmt stmt = IRMutator::Mutate_(op, s); + 
atomic_data_.gm_write_stmt_ = stmt; + auto op = stmt.as(); + CHECK(op); + auto value = op->value; + auto value_call = value.as(); + auto value_add = value.as(); + if (value_call) { + atomic_data_.atomic_rhs_ = op->value; + } + if (value_add) { + auto a = value_add->a.as(); + auto b = value_add->b.as(); + if (a && a->name != op->func->func_name()) { + atomic_data_.atomic_rhs_ = value_add->a; + } else if (b && b->name != op->func->func_name()) { + atomic_data_.atomic_rhs_ = value_add->b; + } else { + CHECK(false) << "no support atomic return type"; + } + } + CHECK(atomic_data_.atomic_rhs_.defined()) << "atomic_data_.atomic_rhs_ is not defined"; + atomic_data_.output_tensor_data_type_info_ = scop_info_.GetDtypeOf(op->func->func_name()); + + ConstructAtomicReturnFuncName(); + return MakeAtomicStmt(); + } + return IRMutator::Mutate_(op, s); +} + +void AtomicReturnStmtEmit::ConstructAtomicReturnFuncName() { + std::string reduce_lib_namespace = ""; + std::string reduce_return_name = ""; + if (scop_info_.user_config_.GetReduceLibType() == REDUCE_LIB_TYPE_ORIGIN) { + reduce_lib_namespace = AKG_REDUCE_LIB_SPACE; + reduce_return_name = AKG_REDUCE_RETURN_NAME; + } else if (scop_info_.user_config_.GetReduceLibType() == REDUCE_LIB_TYPE_PARIS) { + reduce_lib_namespace = PARIS_REDUCE_LIB_SPACE; + reduce_return_name = PARIS_REDUCE_RETURN_NAME; + } else { + CHECK(false) << "reduce lib type is invalid!"; + } + std::string ret = ""; + ret += reduce_lib_namespace; + ret += "::"; + ret += reduce_return_name; + + atomic_data_.akg_atomic_api_ = ret; + ret = ""; + + std::string op = atomic_data_.reduce_op_; + ret += op; + + atomic_data_.akg_atomic_template_arg_ = ret; +} + +Stmt AtomicReturnStmtEmit::MakeAtomicStmt() { + std::string func_name = atomic_data_.akg_atomic_api_; + + Expr template_arg0 = make_const(atomic_data_.output_tensor_data_type_info_, 1); + CHECK(!atomic_data_.akg_atomic_template_arg_.empty()); + Expr template_arg1 = StringImm::make(atomic_data_.akg_atomic_template_arg_); 
+ + Expr a1 = atomic_data_.atomic_rhs_; + + auto p = atomic_data_.gm_write_stmt_.as(); + CHECK(p); + + Expr a2 = Call::make(p->value.type(), p->func->func_name(), p->args, Call::Halide, p->func, 0); + a2 = Call::make(a2.type(), "&", {a2}, Call::Extern); + + std::string op_info = atomic_data_.reduce_op_ + "()"; + + Array args; + Expr a3 = Call::make(Int(32), atomic_data_.reduce_op_, args, Call::Extern); + + return Evaluate::make(Call::make(Int(32), func_name, {template_arg0, template_arg1, a1, a2, a3}, Call::Extern)); +} + } // namespace poly } // namespace ir } // namespace akg diff --git a/src/poly/gpu_emit/gpu_isl_emitter.h b/src/poly/gpu_emit/gpu_isl_emitter.h index 230970d2..a824f023 100644 --- a/src/poly/gpu_emit/gpu_isl_emitter.h +++ b/src/poly/gpu_emit/gpu_isl_emitter.h @@ -31,6 +31,25 @@ constexpr auto MIND_TRICKS_SWIZZLE_PRAGMA = "pragma_swizzle_kernel"; constexpr auto ORIGIN_THREAD_DIM_X = "bind_thread_x"; constexpr auto SHARED_MEM_PROMOTED_COMPLETE = "shared_mem_promoted_complete"; +// example: +// atomic_SumOp +constexpr auto REDUCE_ATOMIC_FLAG_SIZE = 2; +constexpr auto REDUCE_ATOMIC_FLAG = "atomic"; +constexpr auto REDUCE_ATOMIC_FLAG_POS = 0; +constexpr auto REDUCE_ATOMIC_FLAG_TYPE_POS = 1; + +constexpr auto REDUCE_LIB_TYPE_ORIGIN = "origin"; +constexpr auto REDUCE_LIB_TYPE_PARIS = "paris"; +constexpr auto AKG_REDUCE_LIB_SPACE = "akg_reduce"; +constexpr auto AKG_REDUCE_LIB_NAME = "AkgReduce"; +constexpr auto AKG_KAHAN_LIB_NAME = "AkgKahanAccumulation"; +constexpr auto PARIS_REDUCE_LIB_SPACE = "paris_reduce"; +constexpr auto PARIS_REDUCE_LIB_NAME = "ParisReduce"; +constexpr auto AKG_REDUCE_RETURN_NAME = "AkgAtomicReturn"; +constexpr auto PARIS_REDUCE_RETURN_NAME = "ParisReturn"; +constexpr auto REDUCE_LIB_TYPE_FLAG = "reduceLibType"; +constexpr auto REDUCE_INIT_FLAG = "InitStmt"; + class GpuIslEmitter : public IslEmitter { public: GpuIslEmitter(ScopInfo &info, const NodeInfoRepo &n, const isl::id_list &i) : IslEmitter(info, n, i) {} @@ -82,6 +101,33 @@ 
class GpuIslEmitter : public IslEmitter { std::unordered_map stride_modify_iter_map_; }; +struct AtomicReturnData { + std::string reduce_op_; + std::string akg_atomic_api_; + std::string akg_atomic_template_arg_; + Type output_tensor_data_type_info_; + Expr atomic_rhs_; + Stmt gm_write_stmt_; +}; + +class AtomicReturnStmtEmit : public IRMutator { + public: + explicit AtomicReturnStmtEmit(ScopInfo &scop_info) : scop_info_(scop_info) {} + + Stmt Mutate_(const AttrStmt *op, const Stmt &s); + + Stmt Mutate_(const Provide *op, const Stmt &s); + + void ConstructAtomicReturnFuncName(); + + Stmt MakeAtomicStmt(); + + private: + ScopInfo &scop_info_; + AtomicReturnData atomic_data_; + bool in_atomic_area_{false}; +}; + } // namespace poly } // namespace ir } // namespace akg diff --git a/src/poly/gpu_emit/gpu_isl_emitter_reduce.h b/src/poly/gpu_emit/gpu_isl_emitter_reduce.h index 702637fb..be720685 100644 --- a/src/poly/gpu_emit/gpu_isl_emitter_reduce.h +++ b/src/poly/gpu_emit/gpu_isl_emitter_reduce.h @@ -38,13 +38,6 @@ constexpr auto REDUCE_FLAG_STMT_PREFIX_POS = 3; constexpr auto REDUCE_FLAG_STMT_NUM_POS = 4; constexpr auto REDUCE_FLAG_REDUCE_INDEX = 5; -// example: -// atomic_SumOp -constexpr auto REDUCE_ATOMIC_FLAG_SIZE = 2; -constexpr auto REDUCE_ATOMIC_FLAG = "atomic"; -constexpr auto REDUCE_ATOMIC_FLAG_POS = 0; -constexpr auto REDUCE_ATOMIC_FLAG_TYPE_POS = 1; - constexpr auto DEFAULT_TENSOR_INDEX = "[0]"; constexpr auto USELESS_INDEX = "0"; @@ -56,18 +49,6 @@ constexpr auto SCALAR_KHC_PREFIX = "kahan_c"; constexpr auto SHARED_MEMORY_PREFIX = "__shared__"; constexpr auto SHARED_TENSOR_PREFIX = "red_buf"; -constexpr auto REDUCE_LIB_TYPE_ORIGIN = "origin"; -constexpr auto REDUCE_LIB_TYPE_PARIS = "paris"; -constexpr auto AKG_REDUCE_LIB_SPACE = "akg_reduce"; -constexpr auto AKG_REDUCE_LIB_NAME = "AkgReduce"; -constexpr auto AKG_KAHAN_LIB_NAME = "AkgKahanAccumulation"; -constexpr auto PARIS_REDUCE_LIB_SPACE = "paris_reduce"; -constexpr auto PARIS_REDUCE_LIB_NAME = 
"ParisReduce"; -constexpr auto AKG_REDUCE_RETURN_NAME = "AkgAtomicReturn"; -constexpr auto PARIS_REDUCE_RETURN_NAME = "ParisReturn"; -constexpr auto REDUCE_LIB_TYPE_FLAG = "reduceLibType"; -constexpr auto REDUCE_INIT_FLAG = "InitStmt"; - constexpr auto MEM_TYPE_SHARED = "shared"; constexpr auto MEM_TYPE_LOCAL = "local"; diff --git a/src/poly/gpu_emit/gpu_reduce_emit_pass.cc b/src/poly/gpu_emit/gpu_reduce_emit_pass.cc index 690fc66a..8a26da33 100644 --- a/src/poly/gpu_emit/gpu_reduce_emit_pass.cc +++ b/src/poly/gpu_emit/gpu_reduce_emit_pass.cc @@ -474,109 +474,6 @@ class ReduceStmtEmit : public IRMutator { Stmt rest_part_; }; -struct AtomicReturnData { - std::string reduce_op_; - std::string akg_atomic_api_; - std::string akg_atomic_template_arg_; - Type output_tensor_data_type_info_; - Expr atomic_rhs_; - Stmt gm_write_stmt_; -}; - -class AtomicReturnStmtEmit : public IRMutator { - public: - explicit AtomicReturnStmtEmit(ScopInfo &scop_info) : scop_info_(scop_info) {} - - Stmt Mutate_(const AttrStmt *op, const Stmt &s) { - auto key = op->attr_key; - if (IsStartsWith(key, REDUCE_ATOMIC_FLAG)) { - in_atomic_area_ = true; - std::vector strs = common::Split(key, "_"); - CHECK_EQ(strs.size(), REDUCE_ATOMIC_FLAG_SIZE) << "atomic mark format is not right!."; - atomic_data_.reduce_op_.clear(); - if (AkgSupportedReduceOp.count(strs[REDUCE_ATOMIC_FLAG_TYPE_POS])) { - atomic_data_.reduce_op_ = AKG_REDUCE_LIB_SPACE; - atomic_data_.reduce_op_ += "::"; - atomic_data_.reduce_op_ += strs[REDUCE_ATOMIC_FLAG_TYPE_POS]; - } else { - CHECK(false) << "reduce op type is not supported!"; - } - } - return IRMutator::Mutate_(op, s); - } - - Stmt Mutate_(const Provide *op, const Stmt &s) { - if (in_atomic_area_) { - in_atomic_area_ = false; - Stmt stmt = IRMutator::Mutate_(op, s); - atomic_data_.gm_write_stmt_ = stmt; - auto op = stmt.as(); - CHECK(op); - atomic_data_.atomic_rhs_ = op->value; - atomic_data_.output_tensor_data_type_info_ = scop_info_.GetDtypeOf(op->func->func_name()); - - 
ConstructAtomicReturnFuncName(); - return MakeAtomicStmt(); - } - return IRMutator::Mutate_(op, s); - } - - void ConstructAtomicReturnFuncName() { - std::string reduce_lib_namespace = ""; - std::string reduce_return_name = ""; - if (scop_info_.user_config_.GetReduceLibType() == REDUCE_LIB_TYPE_ORIGIN) { - reduce_lib_namespace = AKG_REDUCE_LIB_SPACE; - reduce_return_name = AKG_REDUCE_RETURN_NAME; - } else if (scop_info_.user_config_.GetReduceLibType() == REDUCE_LIB_TYPE_PARIS) { - reduce_lib_namespace = PARIS_REDUCE_LIB_SPACE; - reduce_return_name = PARIS_REDUCE_RETURN_NAME; - } else { - CHECK(false) << "reduce lib type is invalid!" - << "\n"; - } - std::string ret = ""; - ret += reduce_lib_namespace; - ret += "::"; - ret += reduce_return_name; - - atomic_data_.akg_atomic_api_ = ret; - ret = ""; - - std::string op = atomic_data_.reduce_op_; - ret += op; - - atomic_data_.akg_atomic_template_arg_ = ret; - } - - Stmt MakeAtomicStmt() { - std::string func_name = atomic_data_.akg_atomic_api_; - - Expr template_arg0 = make_const(atomic_data_.output_tensor_data_type_info_, 1); - CHECK(!atomic_data_.akg_atomic_template_arg_.empty()); - Expr template_arg1 = StringImm::make(atomic_data_.akg_atomic_template_arg_); - - Expr a1 = atomic_data_.atomic_rhs_; - - auto p = atomic_data_.gm_write_stmt_.as(); - CHECK(p); - - Expr a2 = Call::make(p->value.type(), p->func->func_name(), p->args, Call::Halide, p->func, 0); - a2 = Call::make(a2.type(), "&", {a2}, Call::Extern); - - std::string op_info = atomic_data_.reduce_op_ + "()"; - - Array args; - Expr a3 = Call::make(Int(32), atomic_data_.reduce_op_, args, Call::Extern); - - return Evaluate::make(Call::make(Int(32), func_name, {template_arg0, template_arg1, a1, a2, a3}, Call::Extern)); - } - - private: - ScopInfo &scop_info_; - AtomicReturnData atomic_data_; - bool in_atomic_area_{false}; -}; - class ConditionExprMod : public air::ir::IRMutator { public: explicit ConditionExprMod(bool &is_found) : is_found_(is_found) {} diff --git 
a/src/poly/isl_emitter.cc b/src/poly/isl_emitter.cc index fbeee759..eb8b8f31 100644 --- a/src/poly/isl_emitter.cc +++ b/src/poly/isl_emitter.cc @@ -613,6 +613,12 @@ Stmt IslEmitter::EmitUserStmtContent(const Block *block_node) { return stmt; } +Stmt IslEmitter::EmitUserStmtContent(const AttrStmt *attr_node) { + Stmt stmt = EmitUserStmtContent(attr_node->body.get()); + stmt = AttrStmt::make(attr_node->node, attr_node->_type_key, attr_node->value, stmt); + return stmt; +} + Stmt IslEmitter::EmitUserStmtContent(const Node *node) { if (node->IsInstance()) { const auto op = static_cast(node); @@ -632,6 +638,10 @@ Stmt IslEmitter::EmitUserStmtContent(const Node *node) { LOG(WARNING) << "found Evaluate in isl::ast_node_user"; const auto op = static_cast(node); return EmitUserStmtContent(op); + } else if (node->IsInstance()) { + LOG(WARNING) << "found AttrStmt in isl::ast_node_user"; + const auto op = static_cast(node); + return EmitUserStmtContent(op); } else { CHECK(false) << "unknown node type in isl::ast_node_user: " << node << " " << node->_type_key; return Stmt(); diff --git a/src/poly/isl_emitter.h b/src/poly/isl_emitter.h index b4625e17..a1c50d5f 100644 --- a/src/poly/isl_emitter.h +++ b/src/poly/isl_emitter.h @@ -97,6 +97,7 @@ class IslEmitter { virtual Stmt EmitUserStmtContent(const IfThenElse *if_node); virtual Stmt EmitUserStmtContent(const For *for_node); virtual Stmt EmitUserStmtContent(const Block *block_node); + virtual Stmt EmitUserStmtContent(const AttrStmt *stmt_node); // Loop isl iters info virtual void PushIter(const Variable *iter); diff --git a/src/poly/schedule_pass.h b/src/poly/schedule_pass.h index 1d863924..d4dc9298 100644 --- a/src/poly/schedule_pass.h +++ b/src/poly/schedule_pass.h @@ -64,7 +64,7 @@ isl::schedule_constraints MakeScheduleConstraints(const isl::schedule &schedule, isl::union_map RemoveReduceOpSelfDependence(ScopInfo &scop_info, PassInfo &pass_info); -isl::union_map RemoveSelfDependence(PassInfo &pass_info); +isl::union_map 
RemoveSelfDependence(PassInfo &pass_info, std::map tensor_name_map = {}); isl::union_map RemoveInvariantDependence(const isl::schedule &schedule, PassInfo &pass_info, ScopInfo &scop_info); @@ -101,8 +101,7 @@ std::vector GetTileSizeOfLevel(const int member_size, const int dim_size, c /* * Obtain the information needed during the data promotion phase. */ -std::string GetPromotionTensorName(const isl::schedule_node &node, - const std::vector &buffer_def_infos); +std::string GetPromotionTensorName(const isl::schedule_node &node, const std::vector &buffer_def_infos); bool IsReadOrWriteTensor(const isl::schedule_node &node, const std::string read_name, const std::string write_name); diff --git a/src/poly/schedule_pass/init_schedule.cc b/src/poly/schedule_pass/init_schedule.cc index 04cfb9a7..2389ef3d 100644 --- a/src/poly/schedule_pass/init_schedule.cc +++ b/src/poly/schedule_pass/init_schedule.cc @@ -138,6 +138,11 @@ isl::schedule InitSchedule::Run(isl::schedule sch) { ForceDepBetweenLiveouts(sinks); pass_info_.dependences_ = pass_info_.dependences_.unite(pass_info_.force_dependences_); } + + auto tot_stmt = scop_info_.analysis_result_.GetTensorOfTensorStmt(); + if (!tot_stmt.empty()) { + pass_info_.dependences_ = RemoveSelfDependence(pass_info_, tot_stmt); + } } pass_info_.orig_dependences_ = pass_info_.dependences_; diff --git a/src/poly/schedule_pass/rm_self_dep.cc b/src/poly/schedule_pass/rm_self_dep.cc index 22221bc2..a40fb4ef 100644 --- a/src/poly/schedule_pass/rm_self_dep.cc +++ b/src/poly/schedule_pass/rm_self_dep.cc @@ -660,73 +660,80 @@ isl::union_map RemoveReduceOpSelfDependence(ScopInfo &scop_info, PassInfo &pass_ * value: reduce axis no. of this reduce stmt, if the no. 
>= 1, it is the reduce statement * *********************************************/ std::unordered_map is_tuple_reduce_op; - pass_info.dependences_.foreach_map([&scop_info, &pass_info, &preserved_dependences, - &is_tuple_reduce_op](const isl::map &m) -> void { - if (m.domain().get_tuple_id() != m.range().get_tuple_id()) { - preserved_dependences = preserved_dependences.add_map(m); - } else { // self dependence - isl::id tuple_id = m.domain().get_tuple_id(); - std::string tuple_id_key = tuple_id.get_name(); - if (is_tuple_reduce_op.count(tuple_id_key) == 0) { - std::vector reduce_axis_list; - ReduceOp res = std::make_pair(false, ""); - ReduceAxisInfo reduce_axis_info = IsMultiAxisSelfDependence(pass_info.dependences_, tuple_id); - is_tuple_reduce_op[tuple_id_key] = reduce_axis_info.second; - if (reduce_axis_info.first) { - res = CheckIsStmtReduceOp(scop_info.analysis_result_.GetReads(), scop_info.analysis_result_.GetWrites(), - tuple_id, reduce_axis_list); - if (!(res.first || CheckIsStmtReduceOp(pass_info.dependences_, tuple_id, reduce_axis_list))) { - is_tuple_reduce_op[tuple_id_key] = 0; + pass_info.dependences_.foreach_map( + [&scop_info, &pass_info, &preserved_dependences, &is_tuple_reduce_op](const isl::map &m) -> void { + if (m.domain().get_tuple_id() != m.range().get_tuple_id()) { + preserved_dependences = preserved_dependences.add_map(m); + } else { // self dependence + isl::id tuple_id = m.domain().get_tuple_id(); + std::string tuple_id_key = tuple_id.get_name(); + if (is_tuple_reduce_op.count(tuple_id_key) == 0) { + std::vector reduce_axis_list; + ReduceOp res = std::make_pair(false, ""); + ReduceAxisInfo reduce_axis_info = IsMultiAxisSelfDependence(pass_info.dependences_, tuple_id); + is_tuple_reduce_op[tuple_id_key] = reduce_axis_info.second; + if (reduce_axis_info.first) { + res = CheckIsStmtReduceOp(scop_info.analysis_result_.GetReads(), scop_info.analysis_result_.GetWrites(), + tuple_id, reduce_axis_list); + if (!(res.first || 
CheckIsStmtReduceOp(pass_info.dependences_, tuple_id, reduce_axis_list))) { + is_tuple_reduce_op[tuple_id_key] = 0; + } } - } - if (is_tuple_reduce_op[tuple_id_key] >= 2) { - ReduceTensorInfo reduce_tensor_info; - reduce_tensor_info.axis_vec = reduce_axis_list; - reduce_tensor_info.stmt_map = isl::union_map::empty(isl::space(scop_info.ctx_, 0)); - scop_info.analysis_result_.RecordReduceTensorInfoMap(tuple_id, reduce_tensor_info); - } + if (is_tuple_reduce_op[tuple_id_key] >= 2) { + ReduceTensorInfo reduce_tensor_info; + reduce_tensor_info.axis_vec = reduce_axis_list; + reduce_tensor_info.stmt_map = isl::union_map::empty(isl::space(scop_info.ctx_, 0)); + scop_info.analysis_result_.RecordReduceTensorInfoMap(tuple_id, reduce_tensor_info); + } - /*************************************************** - * New flow of atomic add optimization on poly npu - * will store the reduce tensor info for npu isl emitter. - ****************************************************/ - if (is_tuple_reduce_op[tuple_id_key] >= 1 && scop_info.user_config_.GetEnableAtomicAdd() && - !res.second.empty()) { - scop_info.analysis_result_.RecordReduceOutTensors(res.second); + /*************************************************** + * New flow of atomic add optimization on poly npu + * will store the reduce tensor info for npu isl emitter. 
+ ****************************************************/ + if (is_tuple_reduce_op[tuple_id_key] >= 1 && scop_info.user_config_.GetEnableAtomicAdd() && + !res.second.empty()) { + scop_info.analysis_result_.RecordReduceOutTensors(res.second); + } } - } - // for reduce axis number is smaller than one, keep the dependences relation - if (is_tuple_reduce_op[tuple_id_key] <= 1) { - preserved_dependences = preserved_dependences.add_map(m); + // for reduce axis number is smaller than one, keep the dependences relation + if (is_tuple_reduce_op[tuple_id_key] <= 1) { + preserved_dependences = preserved_dependences.add_map(m); + } } - } - }); + }); return preserved_dependences; } /* * Removes all self dependences in the program. Use with special care. + * If tensor_name_map is not empty, only the self-dependency of tensor in tensor_name_map is deleted. */ -isl::union_map RemoveSelfDependence(PassInfo &pass_info) { +isl::union_map RemoveSelfDependence(PassInfo &pass_info, std::map tensor_name_map) { isl::union_map preserved = isl::union_map::empty(pass_info.dependences_.get_space()); isl::union_map removed = isl::union_map::empty(pass_info.dependences_.get_space()); - pass_info.dependences_.foreach_map([&](const isl::map &m) -> void { - if (m.domain().get_tuple_id() != m.range().get_tuple_id()) { + pass_info.dependences_.foreach_map([&preserved, &removed, tensor_name_map](const isl::map &m) -> void { + auto domian_id = m.domain().get_tuple_id(); + if (domian_id != m.range().get_tuple_id()) { preserved = preserved.add_map(m); } else { - removed = removed.add_map(m); + if (!tensor_name_map.empty() && tensor_name_map.count(domian_id.get_name()) == 0) { + preserved = preserved.add_map(m); + } else { + removed = removed.add_map(m); + } } }); if (!removed.is_empty()) LOG(INFO) << "force remove self dependence: " << removed; return preserved; } -static bool HasAllReduce(std::unordered_map &reduce_repo, OperatorDomainMap &domain_map) { +static bool HasAllReduce(std::unordered_map 
&reduce_repo, + OperatorDomainMap &domain_map) { for (auto item : domain_map) { auto dim = item.second.param_space.dim(isl_dim_param); - if ( reduce_repo.count(item.first) > 0 && dim - reduce_repo[item.first] == 0) { + if (reduce_repo.count(item.first) > 0 && dim - reduce_repo[item.first] == 0) { return true; } } diff --git a/src/poly/schedule_pass/tile_outer_band.cc b/src/poly/schedule_pass/tile_outer_band.cc index 6ed7a8d1..59fc228e 100644 --- a/src/poly/schedule_pass/tile_outer_band.cc +++ b/src/poly/schedule_pass/tile_outer_band.cc @@ -19,6 +19,7 @@ #include "poly/scop.h" #include "poly/schedule_pass/transfer_stmt.h" #include "poly/schedule_pass/try_mark_scalar_stmt.h" +#include "poly/schedule_tree_util.h" #include "poly/reduce_manager.h" #include @@ -203,12 +204,12 @@ void TileOuterBand::InitDimensionInfo(const isl::schedule &sch_init) { return; } - int dim_info_entry_size = 4; + int dim_info_entry_size = DIM_SIZE; const std::vector thread_block_list = {T0, T1, T2, B0, B1, B2}; for (auto i : thread_block_list) { if (dim.find(i) != std::string::npos) { scop_info_.analysis_result_.SetIsCustomMapping(true); - dim_info_entry_size = 6; + dim_info_entry_size = CUSTOM_DIM_SIZE; break; } } @@ -246,9 +247,27 @@ void TileOuterBand::InitDimensionInfo(const isl::schedule &sch_init) { scop_info_.analysis_result_.InsertDimensionInfo(dim_info); if (scop_info_.analysis_result_.GetIsCustomMapping()) { - CustomMappingConfig(str, i); + CHECK(str.size() >= CUSTOM_DIM_SIZE) + << "The configuration length of custom mapping must not be less than " << CUSTOM_DIM_SIZE << "!"; + int axis_number = static_cast(WrappedStrtol(str[i + 1])); + std::string outer_mapping = str[i + 4]; + if (outer_mapping != "-") { + scop_info_.user_config_.RecordCustomOuterMapping(axis_number, outer_mapping); + } + + std::string inner_mapping = str[i + 5]; + if (inner_mapping != "-") { + scop_info_.user_config_.RecordCustomInnerMapping(axis_number, inner_mapping); + } } } + + if 
(scop_info_.analysis_result_.GetIsCustomMapping()) { + CheckCustomMapping(scop_info_.user_config_.GetCustomInnerMapping()); + CheckCustomMapping(scop_info_.user_config_.GetCustomOuterMapping()); + scop_info_.user_config_.RecordCustomInnerMapping(-1, ""); + scop_info_.user_config_.RecordCustomOuterMapping(-1, ""); + } } void TileOuterBand::MergeTilingInfo() { @@ -396,48 +415,30 @@ isl::schedule_node TileOuterBand::MarkOuterPermutableNpu(isl::schedule_node node return node; } -void TileOuterBand::CustomMappingConfig(const std::vector &str, const int index) { - CHECK(str.size() >= 6) << "The configuration length of custom mapping must not be less than 6."; - int axis_number = static_cast(WrappedStrtol(str[index + 1])); - auto CheckCustomMapping = [this](std::unordered_map custom_mapping_map) -> void { - const std::unordered_set thread_set = {T0, T1, T2}; - const std::unordered_set block_set = {B0, B1, B2}; - - size_t thread_prefix = 0; - size_t block_prefix = 0; - for (auto custom_mapping : custom_mapping_map) { - if (thread_set.find(custom_mapping.second) != thread_set.end()) { - ++thread_prefix; - } else if (block_set.find(custom_mapping.second) != block_set.end()) { - ++block_prefix; - } else { - LOG(FATAL) << "The custom configuration must be t0, t1, t2, b0, b1 and b2."; - } - } - if (thread_prefix != custom_mapping_map.size() && block_prefix != custom_mapping_map.size()) { - LOG(FATAL) << "All of the inner configuration or the outer configuration must be threads or blocks."; - } - - if (thread_prefix == custom_mapping_map.size()) { - scop_info_.analysis_result_.SetIsOuterBlockMapping(false); +void TileOuterBand::CheckCustomMapping(const std::unordered_map &custom_mapping_map) { + const std::unordered_set thread_set = {T0, T1, T2}; + const std::unordered_set block_set = {B0, B1, B2}; + + size_t thread_prefix = 0; + size_t block_prefix = 0; + for (auto custom_mapping : custom_mapping_map) { + if (thread_set.find(custom_mapping.second) != thread_set.end()) { + 
++thread_prefix; + } else if (block_set.find(custom_mapping.second) != block_set.end()) { + ++block_prefix; } else { - scop_info_.analysis_result_.SetIsOuterBlockMapping(true); + LOG(FATAL) << "The custom configuration must be t0, t1, t2, b0, b1 and b2."; } - }; - std::string outer_mapping = str[index + 4]; - if (outer_mapping != "-") { - scop_info_.user_config_.RecordCustomOuterMapping(axis_number, outer_mapping); } - - std::string inner_mapping = str[index + 5]; - if (inner_mapping != "-") { - scop_info_.user_config_.RecordCustomInnerMapping(axis_number, inner_mapping); + if (thread_prefix != custom_mapping_map.size() && block_prefix != custom_mapping_map.size()) { + LOG(FATAL) << "All of the inner configuration or the outer configuration must be threads or blocks."; } - CheckCustomMapping(scop_info_.user_config_.GetCustomInnerMapping()); - CheckCustomMapping(scop_info_.user_config_.GetCustomOuterMapping()); - scop_info_.user_config_.RecordCustomInnerMapping(-1, ""); - scop_info_.user_config_.RecordCustomOuterMapping(-1, ""); + if (thread_prefix == custom_mapping_map.size()) { + scop_info_.analysis_result_.SetIsOuterBlockMapping(false); + } else { + scop_info_.analysis_result_.SetIsOuterBlockMapping(true); + } } std::vector> TileOuterBand::AddTileInfo(const std::vector> &partition_info) { diff --git a/src/poly/schedule_pass/tile_outer_band.h b/src/poly/schedule_pass/tile_outer_band.h index b83d47dc..3359c802 100644 --- a/src/poly/schedule_pass/tile_outer_band.h +++ b/src/poly/schedule_pass/tile_outer_band.h @@ -24,6 +24,8 @@ namespace ir { namespace poly { constexpr auto KH_KW_DEPTH = 2; +constexpr auto DIM_SIZE = 4; +constexpr auto CUSTOM_DIM_SIZE = 6; /* * Tile the outer band accoding to TilingInfo. 
In this pass, we get the out-most band, @@ -99,11 +101,9 @@ class TileOuterBand : public SchedulePass { isl::schedule_node InsertPromoteMarker(const isl::schedule_node node); void ResetWarpMappingConfig(); isl::schedule_node MatmulTile(const isl::schedule_node &node); - void CustomMappingConfig(const std::vector &str, const int index); + void CheckCustomMapping(const std::unordered_map &custom_mapping_map); bool IsMatrixCPromoteToShared(); - void CheckVectorizedForElemwiseOp(isl::schedule_node node); - private: PassInfo &pass_info_; ScopInfo &scop_info_; diff --git a/src/poly/schedule_pass_gpu/gpu_dma_analysis.cc b/src/poly/schedule_pass_gpu/gpu_dma_analysis.cc index 985b6f86..8daf1508 100644 --- a/src/poly/schedule_pass_gpu/gpu_dma_analysis.cc +++ b/src/poly/schedule_pass_gpu/gpu_dma_analysis.cc @@ -14,6 +14,7 @@ * limitations under the License. */ +#include "poly/schedule_tree_util.h" #include "gpu_dma_analysis.h" #include "poly/scop.h" diff --git a/src/poly/schedule_pass_gpu/mapping_outer_band.cc b/src/poly/schedule_pass_gpu/mapping_outer_band.cc index e14632a5..9332c41e 100644 --- a/src/poly/schedule_pass_gpu/mapping_outer_band.cc +++ b/src/poly/schedule_pass_gpu/mapping_outer_band.cc @@ -380,6 +380,16 @@ isl::schedule MappingOuterBand::DoThreadMapping(const isl::schedule &sch) { node = AdjustConvScheduleTreeStructure(node, false); } mapped_threads = bmm_op.MapThreadHelper(node); + } else if (scop_info_.analysis_result_.GetTensorOfTensor()) { + // tensor of tensor + int last_axis = scop_info_.analysis_result_.GetLastAxisInScheduleTree(); + if (last_axis < 0 || last_axis >= static_cast(node.as().n_member())) { + OperatorMappingStrategy others_op(pass_info_, scop_info_); + mapped_threads = others_op.MapThreadHelper(node, false); + } else { + TOTMappingStrategy tot_op(pass_info_, scop_info_); + mapped_threads = tot_op.MapThreadHelper(node); + } } else { // others operator OperatorMappingStrategy others_op(pass_info_, scop_info_); @@ -476,6 +486,16 @@ isl::schedule 
MappingOuterBand::DoBlockMapping(const isl::schedule &sch) { // conv operator ConvMappingStrategy conv_op(pass_info_, scop_info_); node = conv_op.ResetConvBlockMappingConfig(node, block_cfg, map_idx_shift.empty()); + } else if (scop_info_.analysis_result_.GetTensorOfTensor()) { + // tensor of tensor + int last_axis = scop_info_.analysis_result_.GetLastAxisInScheduleTree(); + if (last_axis < 0 || last_axis >= static_cast(n_block_map)) { + OperatorMappingStrategy others_op(pass_info_, scop_info_); + node = others_op.MapBlockHelper(node, block_cfg, n_block_map, map_idx_shift.empty(), map_idx_shift); + } else { + TOTMappingStrategy tot_op(pass_info_, scop_info_); + node = tot_op.MapBlockHelper(node, block_cfg, n_block_map, map_idx_shift.empty(), map_idx_shift); + } } else { // others operator OperatorMappingStrategy others_op(pass_info_, scop_info_); @@ -486,52 +506,6 @@ isl::schedule MappingOuterBand::DoBlockMapping(const isl::schedule &sch) { return final_schedule; } -isl::schedule_node MappingOuterBand::InsertCustomMappingFilter(const isl::schedule_node &node, - isl::union_pw_aff_list upa_list, MappingCfg *mapping_cfg, - Mapping &mapping, - std::unordered_map custom_mapping, - std::unordered_set outer_mapping_cfg) { - isl::union_set domain = node.get_schedule().get_domain(); - - std::unordered_set current_mapping_cfg; - for (size_t i = 0; i < upa_list.size(); ++i) { - if (custom_mapping.count(static_cast(i)) == 0) { - continue; - } - auto mapping_i = custom_mapping[static_cast(i)]; - current_mapping_cfg.emplace(mapping_i); - std::pair cfg = mapping_cfg->GetAt(mapping_i); - - auto upa = upa_list.get_at(i); - CHECK_GT(cfg.second, 0); - upa = upa.mod(isl::val(node.ctx(), cfg.second)); - auto id = isl::id(node.ctx(), cfg.first); - mapping[id] = upa; - domain = upa.domain(); - } - - // Set other configurations to 0. 
- if (!outer_mapping_cfg.empty()) { - for (size_t i = 0; i < mapping_cfg->bound; ++i) { - CHECK(!domain.is_null()); - auto universe = domain.universe(); - // Remove the configuration that has been mapped. - if (current_mapping_cfg.find(mapping_cfg->GetAt(i).first) != current_mapping_cfg.end()) { - continue; - } - // Remove the configuration in the outer mapping. - if (outer_mapping_cfg.find(mapping_cfg->GetAt(i).first) != outer_mapping_cfg.end()) { - continue; - } - std::pair cfg = mapping_cfg->GetAt(i); - auto id = isl::id(node.ctx(), cfg.first); - mapping[id] = isl::union_pw_aff(universe, isl::val::zero(domain.ctx())); - } - } - - return InsertMapFilter(node, false, mapping); -} - // Map the inner and outer bands to the inner and outer mapping configuration. isl::schedule_node MappingOuterBand::MapCustomHelper(const isl::schedule_node orig_node, const bool is_inner, MappingCfg *mapping_cfg) { @@ -560,7 +534,7 @@ isl::schedule_node MappingOuterBand::MapCustomHelper(const isl::schedule_node or for (auto outer_mapping : scop_info_.user_config_.GetCustomOuterMapping()) { outer_mapping_cfg.emplace(outer_mapping.second); } - node = InsertCustomMappingFilter(node, upa_list, mapping_cfg, mapping, custom_mapping_cfg, outer_mapping_cfg); + node = InsertRequiredMappingFilter(node, upa_list, mapping_cfg, mapping, custom_mapping_cfg, outer_mapping_cfg); } else { custom_mapping_cfg = scop_info_.user_config_.GetCustomOuterMapping(); @@ -574,7 +548,7 @@ isl::schedule_node MappingOuterBand::MapCustomHelper(const isl::schedule_node or } node = CheckMapSizeAndApplyTile(node, range_aff_list, mapping_cfg, true, custom_mapping_cfg); node = node.insert_mark(isl::id(node.ctx(), BLOCK_MARKER)).child(0); - node = InsertCustomMappingFilter(node, upa_list, mapping_cfg, mapping, custom_mapping_cfg); + node = InsertRequiredMappingFilter(node, upa_list, mapping_cfg, mapping, custom_mapping_cfg); } scop_info_.upa_node_mapping_.emplace_back(std::make_pair(node.parent(), mapping)); diff --git 
a/src/poly/schedule_pass_gpu/mapping_outer_band.h b/src/poly/schedule_pass_gpu/mapping_outer_band.h index 9762754c..83f58820 100644 --- a/src/poly/schedule_pass_gpu/mapping_outer_band.h +++ b/src/poly/schedule_pass_gpu/mapping_outer_band.h @@ -41,12 +41,10 @@ class MappingOuterBand : public SchedulePass { isl::schedule DoBlockMapping(const isl::schedule &sch); + // custom mapping isl::schedule DoCustomMapping(const isl::schedule &sch); isl::schedule_node MapCustomHelper(const isl::schedule_node orig_node, const bool is_inner, MappingCfg *mapping_cfg); - isl::schedule_node InsertCustomMappingFilter(const isl::schedule_node &node, isl::union_pw_aff_list upa_list, - MappingCfg *mapping_cfg, Mapping &mapping, - std::unordered_map custom_mapping, - std::unordered_set outer_mapping_cfg = {}); + std::unordered_map GetRequiredMappingCfg(MappingCfg *mapping_cfg); size_t NumMappedDescendant(const RoadMap &thread_roadmap, const isl::schedule_node parent); diff --git a/src/poly/schedule_pass_gpu/operator_mapping_strategy.cc b/src/poly/schedule_pass_gpu/operator_mapping_strategy.cc index 9ee2b1ee..3c8a2fb4 100644 --- a/src/poly/schedule_pass_gpu/operator_mapping_strategy.cc +++ b/src/poly/schedule_pass_gpu/operator_mapping_strategy.cc @@ -344,6 +344,91 @@ isl::schedule ConvMappingStrategy::MoveKernelHWBand(isl::schedule sch) { return sch; } +size_t TOTMappingStrategy::MapThreadHelper(isl::schedule_node &thread_root) { + auto thread_cfg = scop_info_.user_config_.GetThreadConfig(); + CHECK(thread_cfg != nullptr) << "thread config is null"; + if (thread_cfg->bound < 1 || !thread_root.isa()) { + return 0; + } + + int start_node_depth = thread_root.get_tree_depth(); + // Determine max num dimension of threads that can be mapped. 
+ auto n_thread_map = CountConsecutiveCoincident(thread_root); + if (n_thread_map < 1) { + return 0; + } + + if (n_thread_map < thread_root.as().n_member()) { + thread_root = thread_root.as().split(n_thread_map); + } + + // Map band under thread_root from inner dim to outer dim. + auto band_node = thread_root.as(); + auto partial_schedule = band_node.get_partial_schedule(); + auto upa_list = partial_schedule.get_union_pw_aff_list(); + + std::unordered_map tot_mapping = GetRequiredMappingCfg(thread_cfg); + auto prefix_upa_list = GetPrefixPartialSchedule(partial_schedule, thread_root, true); + thread_root = CheckMapSizeAndApplyTile(thread_root, prefix_upa_list, thread_cfg, true, tot_mapping); + thread_root = thread_root.insert_mark(isl::id(thread_root.ctx(), THREAD_MARKER)).child(0); + + std::unordered_set outer_mapping_cfg = {SKIP_MARKER}; + Mapping mapping; + thread_root = InsertRequiredMappingFilter(thread_root, upa_list, thread_cfg, mapping, tot_mapping, outer_mapping_cfg); + + scop_info_.upa_node_mapping_.emplace_back(std::make_pair(thread_root.parent(), mapping)); + + int end_node_depth = thread_root.get_tree_depth() - start_node_depth; + thread_root = thread_root.ancestor(end_node_depth); + return thread_cfg->bound; +} + +isl::schedule_node TOTMappingStrategy::MapBlockHelper(const isl::schedule_node &orig_node, MappingCfg *block_cfg, + size_t n_block_map, bool check_extent, + std::unordered_map &map_idx_shift) { + auto node = orig_node; + auto band_node = node.as(); + if (!band_node || !band_node.permutable()) { + LOG(WARNING) << "No permutable outer band node to map block."; + return node; + } + + auto partial_schedule = band_node.get_partial_schedule(); + auto upa_list = partial_schedule.get_union_pw_aff_list(); + + auto domain = band_node.get_schedule().get_domain(); + isl::union_pw_aff_list range_aff_list(band_node.ctx(), static_cast(upa_list.size())); + for (int i = upa_list.size() - 1; i >= 0; --i) { + auto range = 
upa_list.get_at(i).intersect_domain(domain); + range_aff_list = range_aff_list.add(range); + } + + std::unordered_map tot_mapping = GetRequiredMappingCfg(block_cfg); + node = CheckMapSizeAndApplyTile(node, range_aff_list, block_cfg, true, tot_mapping); + node = node.insert_mark(isl::id(node.ctx(), BLOCK_MARKER)).child(0); + Mapping mapping; + node = InsertRequiredMappingFilter(node, upa_list, block_cfg, mapping, tot_mapping); + + scop_info_.upa_node_mapping_.emplace_back(std::make_pair(node.parent(), mapping)); + return node; +} + +std::unordered_map TOTMappingStrategy::GetRequiredMappingCfg(MappingCfg *mapping_cfg) { + CHECK(mapping_cfg != nullptr) << "mapping config is null"; + std::unordered_map tot_mapping = {}; + int last_axis = scop_info_.analysis_result_.GetLastAxisInScheduleTree(); + CHECK(last_axis != -1) << "last axis is -1"; + tot_mapping[last_axis] = mapping_cfg->GetAt(0).first; + for (int i = mapping_cfg->bound - 1, j = 1; i >= 0; --i) { + if (i == last_axis) { + continue; + } + tot_mapping[i] = mapping_cfg->GetAt(j).first; + ++j; + } + return tot_mapping; +} + } // namespace poly } // namespace ir } // namespace akg diff --git a/src/poly/schedule_pass_gpu/operator_mapping_strategy.h b/src/poly/schedule_pass_gpu/operator_mapping_strategy.h index d8b4d38c..9dc69296 100644 --- a/src/poly/schedule_pass_gpu/operator_mapping_strategy.h +++ b/src/poly/schedule_pass_gpu/operator_mapping_strategy.h @@ -76,6 +76,18 @@ class ConvMappingStrategy : public OperatorMappingStrategy { isl::schedule MoveKernelHWBand(isl::schedule sch); }; +class TOTMappingStrategy : public OperatorMappingStrategy { + public: + explicit TOTMappingStrategy(PassInfo &pass_info, ScopInfo &scop_info) + : OperatorMappingStrategy(pass_info, scop_info) {} + ~TOTMappingStrategy() {} + + size_t MapThreadHelper(isl::schedule_node &thread_root); + isl::schedule_node MapBlockHelper(const isl::schedule_node &orig_node, MappingCfg *block_cfg, size_t n_block_map, + bool check_extent, 
std::unordered_map &map_idx_shift); + std::unordered_map GetRequiredMappingCfg(MappingCfg *mapping_cfg); +}; + } // namespace poly } // namespace ir } // namespace akg diff --git a/src/poly/schedule_pass_gpu/operator_shared_strategy.cc b/src/poly/schedule_pass_gpu/operator_shared_strategy.cc index ef0e6f88..4a6eb2ee 100644 --- a/src/poly/schedule_pass_gpu/operator_shared_strategy.cc +++ b/src/poly/schedule_pass_gpu/operator_shared_strategy.cc @@ -52,6 +52,13 @@ std::set OperatorSharedStrategy::GetInitPromotedTensor() { ********************************************************/ std::set_difference(read_sets.begin(), read_sets.end(), write_sets.begin(), write_sets.end(), std::inserter(id_sets, id_sets.begin())); + + if (scop_info_.analysis_result_.GetTensorOfTensor()) { + id_sets.clear(); + std::set_union(read_sets.begin(), read_sets.end(), write_sets.begin(), write_sets.end(), + std::inserter(id_sets, id_sets.begin())); + } + return id_sets; } @@ -106,9 +113,22 @@ void OperatorSharedStrategy::RecordCustomPromotedTensors(std::set & } } +void OperatorSharedStrategy::DeleteNotPromotedTensors(std::set &id_sets) { + if (scop_info_.analysis_result_.GetTensorsNotPromote().empty()) { + return; + } + std::unordered_set tensors = scop_info_.analysis_result_.GetTensorsNotPromote(); + for (const auto &item : tensors) { + if (id_sets.count(item)) { + id_sets.erase(item); + } + } +} + void OperatorSharedStrategy::CreateClusterList(const isl::schedule_node &node, const isl::union_map &outer_sch) { std::set id_sets = GetInitPromotedTensor(); RecordCustomPromotedTensors(id_sets); + DeleteNotPromotedTensors(id_sets); RecordPromotedTensorInfo(node, outer_sch, id_sets); } diff --git a/src/poly/schedule_pass_gpu/operator_shared_strategy.h b/src/poly/schedule_pass_gpu/operator_shared_strategy.h index 43a93c85..0879f79a 100644 --- a/src/poly/schedule_pass_gpu/operator_shared_strategy.h +++ b/src/poly/schedule_pass_gpu/operator_shared_strategy.h @@ -33,6 +33,7 @@ class 
OperatorSharedStrategy { const std::set &id_sets); void CreateClusterList(const isl::schedule_node &node, const isl::union_map &outer_sch); void RecordCustomPromotedTensors(std::set &id_sets); + void DeleteNotPromotedTensors(std::set &id_sets); protected: ScopInfo &scop_info_; diff --git a/src/poly/schedule_pass_gpu/register_memory_manager.cc b/src/poly/schedule_pass_gpu/register_memory_manager.cc index fb9aeca0..380317f5 100644 --- a/src/poly/schedule_pass_gpu/register_memory_manager.cc +++ b/src/poly/schedule_pass_gpu/register_memory_manager.cc @@ -205,9 +205,16 @@ void RegisterMemoryManager::CreateTensorCluster(const isl::schedule_node &node, if (scop_info_.user_config_.GetEnableMatmul()) { tensor_list.push_back(item); } else { - if (!shared_dst_tensor_ids.count(item.get_name() + SHARE_SUFFIX)) { - tensor_list.push_back(item); + if (shared_dst_tensor_ids.count(item.get_name() + SHARE_SUFFIX)) { + continue; + } + + std::unordered_set tensors = scop_info_.analysis_result_.GetTensorsNotPromote(); + if (tensors.count(item.get_name())) { + continue; } + + tensor_list.push_back(item); } } diff --git a/src/poly/schedule_pass_gpu/shared_memory_manager.cc b/src/poly/schedule_pass_gpu/shared_memory_manager.cc index b5bcc8de..048aedc2 100644 --- a/src/poly/schedule_pass_gpu/shared_memory_manager.cc +++ b/src/poly/schedule_pass_gpu/shared_memory_manager.cc @@ -426,9 +426,49 @@ void SharedMemoryManager::GatherBufferFootprintDefInfo(const isl::schedule_node isl::schedule_node SharedMemoryManager::HoistClusters(const isl::schedule_node &root_node, const isl::schedule_node &node) { auto partial_sched_mupa = ShortScheduleMupa(root_node, node); - auto res_node = node; + + std::vector buffer_def_infos_origin; + std::vector buffer_def_infos_temp; + auto origin_binds = scop_info_.user_config_.GetOriginBind(); + std::unordered_set tensor_name; + + for (auto i : origin_binds) { + if (!i.first.defined()) continue; + tensor_name.insert(i.first->op->name); + } + for (size_t index = 0; 
index < scop_info_.analysis_result_.buffer_def_infos_.size(); index++) { BufferDefInfo &buffer_info = scop_info_.analysis_result_.buffer_def_infos_[index]; + if (tensor_name.count(buffer_info.tensor_id.get_name())) { + buffer_def_infos_origin.push_back(buffer_info); + } else { + buffer_def_infos_temp.push_back(buffer_info); + } + } + + auto res_node = node; + if (scop_info_.analysis_result_.GetTensorOfTensor()) { + SharedPromotion(buffer_def_infos_temp, res_node, root_node, node, partial_sched_mupa); + SharedPromotion(buffer_def_infos_origin, res_node, root_node, node, partial_sched_mupa); + + scop_info_.analysis_result_.buffer_def_infos_.clear(); + for (auto &b : buffer_def_infos_temp) { + scop_info_.analysis_result_.buffer_def_infos_.push_back(b); + } + for (auto &b : buffer_def_infos_origin) { + scop_info_.analysis_result_.buffer_def_infos_.push_back(b); + } + } else { + SharedPromotion(scop_info_.analysis_result_.buffer_def_infos_, res_node, root_node, node, partial_sched_mupa); + } + return res_node; +} + +void SharedMemoryManager::SharedPromotion(std::vector &bd, isl::schedule_node &res_node, + const isl::schedule_node &root_node, const isl::schedule_node &node, + const isl::multi_union_pw_aff &partial_sched_mupa) { + for (size_t index = 0; index < bd.size(); index++) { + BufferDefInfo &buffer_info = bd[index]; auto fp_cluster = buffer_info.GetFootPrintClusterGPU(node); if ((fp_cluster == nullptr || !fp_cluster->foot_print_.box.is_valid())) { continue; @@ -472,7 +512,6 @@ isl::schedule_node SharedMemoryManager::HoistClusters(const isl::schedule_node & buffer_info.find_buffer = true; } } - return res_node; } isl::schedule_node SharedMemoryManager::HoistToBlockThreadMemory(isl::schedule_node &tree, GpuMemType type, @@ -520,12 +559,7 @@ bool SharedMemoryManager::CoalescingAccessWay(const isl::schedule_node &root, co auto schedule = ShortSchedule(inner_band); auto schedule_access = local_access.apply_domain(schedule); for (auto access : 
schedule_access.get_map_list()) { - auto schedule_space = access.get_space().domain(); - auto tensor_space = access.get_space().range(); - auto element_next = CreateMapIncreaseDim(tensor_space, tensor_dim - 1); - auto schedule_next = CreateMapIncreaseDim(schedule_space, inner_depth - 1); - auto access_by_adjacent_inner = schedule_next.apply_domain(access).apply_range(access); - if (!access_by_adjacent_inner.is_subset(element_next)) { + if (!IsSubsetForIncreaseDim(access, tensor_dim - 1, inner_depth - 1)) { return true; } } diff --git a/src/poly/schedule_pass_gpu/shared_memory_manager.h b/src/poly/schedule_pass_gpu/shared_memory_manager.h index 0257974e..d20acbb7 100644 --- a/src/poly/schedule_pass_gpu/shared_memory_manager.h +++ b/src/poly/schedule_pass_gpu/shared_memory_manager.h @@ -77,6 +77,9 @@ class SharedMemoryManager : public SchedulePass { isl::schedule_node HoistSharedMemoryOnMark(const isl::schedule_node &root); void PrepareInfoForPromotion(const isl::schedule_node &root); + void SharedPromotion(std::vector &bd, isl::schedule_node &res_node, + const isl::schedule_node &root_node, const isl::schedule_node &node, + const isl::multi_union_pw_aff &partial_sched_mupa); private: ScopInfo &scop_info_; diff --git a/src/poly/schedule_tree_util.cc b/src/poly/schedule_tree_util.cc index c488fc26..3ce5b940 100644 --- a/src/poly/schedule_tree_util.cc +++ b/src/poly/schedule_tree_util.cc @@ -447,6 +447,51 @@ isl::schedule_node InsertMapFilter(const isl::schedule_node &node, const bool is return map_filter_node; } +isl::schedule_node InsertRequiredMappingFilter(const isl::schedule_node &node, isl::union_pw_aff_list upa_list, + MappingCfg *mapping_cfg, Mapping &mapping, + std::unordered_map required_mapping, + std::unordered_set outer_mapping_cfg) { + isl::union_set domain = node.get_schedule().get_domain(); + + std::unordered_set current_mapping_cfg; + for (size_t i = 0; i < upa_list.size(); ++i) { + if (required_mapping.count(static_cast(i)) == 0) { + continue; + } + 
auto mapping_i = required_mapping[static_cast(i)]; + std::pair cfg = mapping_cfg->GetAt(mapping_i); + current_mapping_cfg.emplace(cfg.first); + + auto upa = upa_list.get_at(i); + CHECK_GT(cfg.second, 0); + upa = upa.mod(isl::val(node.ctx(), cfg.second)); + auto id = isl::id(node.ctx(), cfg.first); + mapping[id] = upa; + domain = upa.domain(); + } + + // Set other configurations to 0. + if (!outer_mapping_cfg.empty()) { + for (size_t i = 0; i < mapping_cfg->bound; ++i) { + CHECK(!domain.is_null()); + auto universe = domain.universe(); + // Remove the configuration that has been mapped. + if (current_mapping_cfg.find(mapping_cfg->GetAt(i).first) != current_mapping_cfg.end()) { + continue; + } + // Remove the configuration in the outer mapping. + if (outer_mapping_cfg.find(mapping_cfg->GetAt(i).first) != outer_mapping_cfg.end()) { + continue; + } + std::pair cfg = mapping_cfg->GetAt(i); + auto id = isl::id(node.ctx(), cfg.first); + mapping[id] = isl::union_pw_aff(universe, isl::val::zero(domain.ctx())); + } + } + + return InsertMapFilter(node, false, mapping); +} + /* * When mapping size is smaller than the extent of corresponding axis, we may encounter several problems if the axis * is not tiled. 
Firstly, for case that extent is multiplies of mapping sizes, directly mapping the axis will @@ -458,7 +503,7 @@ isl::schedule_node InsertMapFilter(const isl::schedule_node &node, const bool is isl::schedule_node CheckMapSizeAndApplyTile(const isl::schedule_node &mapping_root, const isl::union_pw_aff_list &aff_list, MappingCfg *mapping_cfg, const bool need_reverse, - std::unordered_map custom_mapping) { + std::unordered_map required_mapping) { bool need_tile = false; std::vector mapping_sizes; CHECK(mapping_cfg != nullptr) << "mapping config is null"; @@ -483,10 +528,10 @@ isl::schedule_node CheckMapSizeAndApplyTile(const isl::schedule_node &mapping_ro extent = aff.max_val().get_num_si() + 1; map_size = extent; // custom mapping - if (!custom_mapping.empty()) { - bool is_config = (custom_mapping.count(static_cast(i)) != 0); + if (!required_mapping.empty()) { + bool is_config = (required_mapping.count(static_cast(i)) != 0); if (is_config) { - auto mapping_i = custom_mapping[static_cast(i)]; + auto mapping_i = required_mapping[static_cast(i)]; map_size = mapping_cfg->GetAt(mapping_i).second; } RecordMappingSizes(is_config, false); @@ -703,6 +748,61 @@ isl::map CreateMapIncreaseDim(isl::space space, unsigned dim) { return isl::map(identity); } +bool IsSubsetForIncreaseDim(const isl::map access, size_t tensor_dim, size_t node_dim) { + auto schedule_space = access.get_space().domain(); + auto tensor_space = access.get_space().range(); + auto element_next = CreateMapIncreaseDim(tensor_space, tensor_dim); + + auto schedule_next = CreateMapIncreaseDim(schedule_space, node_dim); + auto access_by_adjacent_inner = schedule_next.apply_domain(access).apply_range(access); + if (!access_by_adjacent_inner.is_subset(element_next)) { + return false; + } + return true; +} + +int GetLastAxis(const isl::schedule_node node, isl::union_map original_access, + std::unordered_set skip_tensors) { + if (!node.isa()) { + return -1; + } + // Get current node information. 
+ auto active_domains = CollectDomain(node); + auto local_access = original_access.intersect_domain(active_domains); + auto schedule = LocalSchedule(node); + auto schedule_access = local_access.apply_domain(schedule); + + int node_depth = static_cast(node.as().n_member()); + for (auto access : schedule_access.get_map_list()) { + // Skip the related tensor in tensor of tensor. + auto tensor_name = access.range().get_tuple_name(); + if (skip_tensors.count(tensor_name) != 0) { + continue; + } + + int tensor_dim = -1; + for (int i = static_cast(access.range().n_dim()) - 1; i >= 0; --i) { + auto axis_i = access.range().dim_max(i); + if (!axis_i.is_equal(isl::pw_aff(axis_i.domain(), isl::val(axis_i.ctx(), 0)))) { + tensor_dim = i; + break; + } + } + + if (tensor_dim < 0) { + continue; + } + + for (int i = node_depth - 1; i >= 0; --i) { + if (!IsSubsetForIncreaseDim(access, tensor_dim, i)) { + continue; + } + return i; + } + } + return -1; +} + std::vector CollectFnNode(const std::function &fn, const isl::schedule_node &root) { std::vector res_nodes; diff --git a/src/poly/schedule_tree_util.h b/src/poly/schedule_tree_util.h index 402c72b5..70788aea 100644 --- a/src/poly/schedule_tree_util.h +++ b/src/poly/schedule_tree_util.h @@ -101,11 +101,15 @@ isl::schedule_node AnalysisNodeAndInsertMapFilter(const isl::schedule_node &node isl::union_pw_aff_list upa_list, MappingCfg *mapping_cfg, Mapping &mapping, std::unordered_map map_idx_shift = {}); +isl::schedule_node InsertRequiredMappingFilter(const isl::schedule_node &node, isl::union_pw_aff_list upa_list, + MappingCfg *mapping_cfg, Mapping &mapping, + std::unordered_map required_mapping, + std::unordered_set outer_mapping_cfg = {}); isl::schedule_node InsertMapFilter(const isl::schedule_node &node, const bool is_promotion, Mapping &mapping); isl::schedule_node CheckMapSizeAndApplyTile(const isl::schedule_node &thread_root, const isl::union_pw_aff_list &aff_list, MappingCfg *mapping_cfg, const bool need_reverse, - 
std::unordered_map custom_mapping = {}); + std::unordered_map required_mapping = {}); bool IsEqualNode(const isl::schedule_node node1, const isl::schedule_node node2); isl::multi_union_pw_aff MapDomainToThread(const isl::schedule_node &node, MappingCfg *mapping_cfg, @@ -114,6 +118,9 @@ isl::multi_union_pw_aff MapDomainAllWithType(const isl::schedule_node &node, Map const UpaNodeMapping &upa_node_mapping, const std::string &map_type); isl::map CreateMapIncreaseDim(isl::space space, unsigned dim); +bool IsSubsetForIncreaseDim(const isl::map access, size_t tensor_dim, size_t node_dim); +int GetLastAxis(const isl::schedule_node node, isl::union_map original_access, + std::unordered_set skip_tensors); std::vector CollectFnNode(const std::function &fn, const isl::schedule_node &root); diff --git a/src/poly/scop_info.h b/src/poly/scop_info.h index b722a05b..48269bc5 100644 --- a/src/poly/scop_info.h +++ b/src/poly/scop_info.h @@ -105,14 +105,14 @@ struct MappingCfg { } std::pair GetAt(std::string cfg_name) { std::pair fixed_pos_cfg = {}; - if (cfg_name == T0 || cfg_name == B0) { + if (cfg_name.find(T0) != std::string::npos || cfg_name.find(B0) != std::string::npos) { fixed_pos_cfg = GetX(); - } else if (cfg_name == T1 || cfg_name == B1) { + } else if (cfg_name.find(T1) != std::string::npos || cfg_name.find(B1) != std::string::npos) { fixed_pos_cfg = GetY(); - } else if (cfg_name == T2 || cfg_name == B2) { + } else if (cfg_name.find(T2) != std::string::npos || cfg_name.find(B2) != std::string::npos) { fixed_pos_cfg = GetZ(); } else { - LOG(FATAL) << "Mapping config can only contain t0, t1, t2, b0, b1 and b2."; + LOG(FATAL) << "Mapping config can contain t0, t1, t2, b0, b1 and b2."; }; return fixed_pos_cfg; } @@ -908,6 +908,25 @@ class AnalysisResult { tiling_constraints_ = std::move(tiling_constraints); } + std::map GetTensorOfTensorStmt() const { return tensor_of_tensor_stmt_; } + void RecordTensorOfTensorStmt(const std::string &id_name, const std::string &op_type) { + 
tensor_of_tensor_stmt_[id_name] = op_type; + } + + bool GetTensorOfTensor() const { return is_tensor_of_tensor_; } + void SetTensorOfTensor(const bool &is_tensor_of_tensor) { is_tensor_of_tensor_ = is_tensor_of_tensor; } + + int GetLastAxisInScheduleTree() const { return last_axis_in_schedule_tree_; } + void SetLastAxisInScheduleTree(const int last_axis_in_schedule_tree) { + last_axis_in_schedule_tree_ = last_axis_in_schedule_tree; + } + + std::unordered_set GetTensorsNotPromote() const { return tensors_not_promote_; } + void RecordTensorsNotPromote(const std::string &tensor_name) { tensors_not_promote_.insert(tensor_name); } + + std::unordered_set GetInnerTensor() const { return inner_tensor_; } + void RecordInnerTensor(const std::string &tensor_name) { inner_tensor_.insert(tensor_name); } + // dump all data void DumpScopDataBasics(std::ofstream &of); @@ -1024,6 +1043,15 @@ class AnalysisResult { // custom mapping bool is_custom_mapping_{false}; bool is_outer_block_mapping_{false}; + + // All axis of each tensor + std::unordered_map> tensor_all_axis_; + // tensor_of_tensor + std::map tensor_of_tensor_stmt_; + std::unordered_set tensors_not_promote_; + std::unordered_set inner_tensor_; + int last_axis_in_schedule_tree_{-1}; + bool is_tensor_of_tensor_{false}; }; class CubeInfo { diff --git a/src/poly/scop_make_schedule_tree.cc b/src/poly/scop_make_schedule_tree.cc index 80dafad0..ea049c82 100644 --- a/src/poly/scop_make_schedule_tree.cc +++ b/src/poly/scop_make_schedule_tree.cc @@ -883,7 +883,25 @@ class ScopMakeScheduleTree final : protected IRVisitor { } void Visit_(const AttrStmt *op) final { - if (op->attr_key == air::ir::attr::reduce_update) { + if (op->attr_key == air::ir::attr::atomic_tot) { + size_t stmt_index = scop_info_.analysis_result_.GetStatementMap().size(); + isl::id id(set.ctx(), macro_stmt >= 0 ? 
kStatementLabel + std::to_string(macro_stmt) + : kStatementLabel + std::to_string(stmt_index)); + CHECK(op->value.as()); + scop_info_.analysis_result_.RecordTensorOfTensorStmt(id.get_name(), op->value.as()->value); + sch = MakeScheduleTreeHelper(op->body, scop_info_, set, outer, macro_stmt); + } else if (op->attr_key == "TENSOR_OF_TENSOR") { + scop_info_.analysis_result_.SetTensorOfTensor(true); + sch = MakeScheduleTreeHelper(op->body, scop_info_, set, outer, macro_stmt); + } else if (op->attr_key == "TENSOR_NOT_PROMOTE") { + CHECK(op->value.as()); + scop_info_.analysis_result_.RecordTensorsNotPromote(op->value.as()->value); + sch = MakeScheduleTreeHelper(op->body, scop_info_, set, outer, macro_stmt); + } else if (op->attr_key == "INNER_TENSOR") { + CHECK(op->value.as()); + scop_info_.analysis_result_.RecordInnerTensor(op->value.as()->value); + sch = MakeScheduleTreeHelper(op->body, scop_info_, set, outer, macro_stmt); + } else if (op->attr_key == air::ir::attr::reduce_update) { Array red = Downcast>(op->node); const auto pro = op->body.as(); if (pro) { diff --git a/src/poly/tiling/tiling_analyzer.cc b/src/poly/tiling/tiling_analyzer.cc index 8f6593af..c40c2fce 100644 --- a/src/poly/tiling/tiling_analyzer.cc +++ b/src/poly/tiling/tiling_analyzer.cc @@ -54,7 +54,7 @@ TileAxis::TileAxis(TileAxis *p, int i, int da, bool mc, const std::pair TilingAnalyzer::GetAxesContainsAttr(const std::string attr_key) const { +std::vector TilingAnalyzer::GetAxesContainsAttr(const std::string &attr_key) const { std::vector axes; auto AddAxisWithAttr = [&attr_key, &axes](TileAxis *a) { for (const auto &attr : a->attrs) { @@ -1255,7 +1255,7 @@ std::vector TilingAnalyzer::GetAxesContainsAttr(const std::string at return axes; } -std::vector TilingAnalyzer::GetAxesOfAttr(const std::string attr_key) const { +std::vector TilingAnalyzer::GetAxesOfAttr(const std::string &attr_key) const { std::vector axes; auto AddAxisWithAttr = [&attr_key, &axes](TileAxis *a) { for (const auto &attr : 
a->attrs) { @@ -1269,7 +1269,7 @@ std::vector TilingAnalyzer::GetAxesOfAttr(const std::string attr_key return axes; } -std::vector TilingAnalyzer::GetAxesOfAttr(const AttrInfo attr_info) const { +std::vector TilingAnalyzer::GetAxesOfAttr(const AttrInfo &attr_info) const { std::vector axes; auto AddAxisWithAttr = [&attr_info, &axes](TileAxis *a) { for (const auto &attr : a->attrs) { @@ -1355,6 +1355,7 @@ void TilingAnalyzer::AddPostTilingConstraints() { if (scop_info_.user_config_.GetTarget() == TARGET_CUDA) { ReduceStrategy reduce_strategy(this); ModStrategy mod_strategy(this); + ShiftAxisStrategy shift_strategy(this); GemmStrategy gemm_strategy(this); ConvStrategy conv_strategy(this); GpuDmaAnalysisStrategy dma_analysis_strategy(this); @@ -1368,6 +1369,7 @@ void TilingAnalyzer::AddPostTilingConstraints() { } actived_strategies.push_back(&reduce_strategy); actived_strategies.push_back(&mod_strategy); + actived_strategies.push_back(&shift_strategy); actived_strategies.push_back(&gemm_strategy); actived_strategies.push_back(&conv_strategy); actived_strategies.push_back(&gpu_strategy); @@ -1452,7 +1454,7 @@ void TilingAnalyzer::AddTilingConstraints() { bool TilingAnalyzer::Prepare() { logger_ = std::unique_ptr(new (std::nothrow) TileLogger( - scop_info_.AddDumpDir("tiling.log"), !scop_info_.user_config_.GetDumpPolyDir().empty())); + scop_info_.AddDumpDir("tiling.log"), !scop_info_.user_config_.GetDumpPolyDir().empty())); CHECK(logger_) << "memory alloc fail."; // Stage 1: Analyze schedule tree. ScheduleTreeAnalyzer sch_ana(this, this->sch_); diff --git a/src/poly/tiling/tiling_analyzer.h b/src/poly/tiling/tiling_analyzer.h index ada31bc0..e3376e7e 100755 --- a/src/poly/tiling/tiling_analyzer.h +++ b/src/poly/tiling/tiling_analyzer.h @@ -54,7 +54,6 @@ constexpr auto MAX_REPEAT = 255; constexpr auto MIN_CORE_GRANULARITY = 256; constexpr auto DESIRE_CORE_GRANULARITY = 8192; - // Controlled by custom tiling. 
constexpr auto ALLOCATION_PERCENTAGE = 0.5; // reserved for double buffer in default @@ -171,7 +170,7 @@ class TilingAnalyzer; class TileAxis { public: TileAxis(TileAxis *p, int i, int da, bool mc, const std::pair &ds, bool inner, TilingAnalyzer *ta); - TileAxis(const Expr &l1_size, Expr l0_size, std::string at, TilingAnalyzer *ta, bool inner = false); + TileAxis(const Expr &l1_size, const Expr &l0_size, const std::string &at, TilingAnalyzer *ta, bool inner = false); ~TileAxis() {} struct Constraint { Expr tile_mod_{MIN_TILE}; @@ -265,12 +264,12 @@ class TilingAnalyzer { sch_(sch), scop_info_(scop_info), is_retry_(!g_attrs.GetStr(kErrorInfo, "").empty()) { - if (scop_info.mmu_info_.IsGemm()) { - op_type_ = GEMM_OP; - } else if (scop_info.mmu_info_.IsConv()) { - op_type_ = CONV_OP; - } else { - op_type_ = VECTOR_OP; + if (scop_info.mmu_info_.IsGemm()) { + op_type_ = GEMM_OP; + } else if (scop_info.mmu_info_.IsConv()) { + op_type_ = CONV_OP; + } else { + op_type_ = VECTOR_OP; } } @@ -306,9 +305,9 @@ class TilingAnalyzer { inline Stmt Halide() const { return body_; } - std::vector GetAxesContainsAttr(std::string attr_key) const; - std::vector GetAxesOfAttr(std::string attr_key) const; - std::vector GetAxesOfAttr(AttrInfo attr_info) const; + std::vector GetAxesContainsAttr(const std::string &attr_key) const; + std::vector GetAxesOfAttr(const std::string &attr_key) const; + std::vector GetAxesOfAttr(const AttrInfo &attr_info) const; TileAxis *Axis(const For *loop) const { auto it = tile_axis_.find(loop); @@ -339,7 +338,8 @@ class TilingAnalyzer { std::unordered_map> buffer_usage_timetable_; std::unordered_map> buf_info_; bool is_retry_{false}; - std::vector binding_spaces_; // [thread.x[min, max, mod], thread.y, thread.z, block.x, block.y, block.z] + std::vector + binding_spaces_; // [thread.x[min, max, mod], thread.y, thread.z, block.x, block.y, block.z] private: void AddTilingConstraints(); void AddPostTilingConstraints(); diff --git 
a/src/poly/tiling/tiling_strategy_manager.h b/src/poly/tiling/tiling_strategy_manager.h index 99aa28bc..884801b6 100755 --- a/src/poly/tiling/tiling_strategy_manager.h +++ b/src/poly/tiling/tiling_strategy_manager.h @@ -244,8 +244,27 @@ class ShiftAxisStrategy : public TilingStrategy { void AddNpuConstraint(); void AddGpuConstraint(); + void TileEntirely() { + auto interested_info = GetInterestedInfo(interested_attr_key); + for (auto it : interested_info) { + TileAxis *axis = it.first; + int64_t const_extent = axis->GetConstExtent(); + if (const_extent == -1) { + continue; + } + shifted_axes_.insert(axis); + for (const auto &attr : it.second) { + CHECK_NE(attr.attr_value, ""); + auto share_time = static_cast(std::strtol(attr.attr_value.c_str(), nullptr, 10)); + axis->TileRestrainToSingleValue(const_extent * (share_time + 1), CACHE1); + break; + } + } + } + std::string interested_attr_key = AT_SHIFT; -}; + std::unordered_set shifted_axes_; +}; // namespace poly class ModShiftAxisStrategy : public TilingStrategy { public: @@ -401,6 +420,8 @@ class GpuStrategy : public TilingStrategy { void MarkMappingInRootAxis(); int GetLocalAllocBufCount(); + bool NeedModifyOrderOfAxis(); + void SetCoalescedAccess(); Template template_{Template::DEFAULT}; bool is_reduce_op_[TEMPLATE_BULK] = {false, false, true, true, true, false}; diff --git a/src/poly/tiling/tiling_strategy_manager_gpu.cc b/src/poly/tiling/tiling_strategy_manager_gpu.cc index 1eeaca58..22f42961 100755 --- a/src/poly/tiling/tiling_strategy_manager_gpu.cc +++ b/src/poly/tiling/tiling_strategy_manager_gpu.cc @@ -292,7 +292,7 @@ void ReduceStrategy::AddGpuConstraint() { has_transpose_ = std::any_of(axis->attrs.begin(), axis->attrs.end(), HasTranspose); } - if (axis == analyzer_->RootAxis()) { + if (axis == analyzer_->RootAxis() || axis->is_inner) { return; } ++depth; @@ -371,7 +371,7 @@ void ReduceStrategy::AkgReduceLibStrategyOnGpu() { } } - bool square_thread = 
analyzer_->scop_info_.analysis_result_.GetReduceDirection() == Y_DIRECTION && + bool square_thread = analyzer_->scop_info_.analysis_result_.GetReduceDirection() == Y_DIRECTION && analyzer_->scop_info_.user_config_.GetEnableAkgReduceLib(); int64_t total_reduce_size = 1; int64_t total_injective_size = 1; @@ -783,6 +783,63 @@ void GpuStrategy::ShowOptions() { analyzer_->GetTileLogger().AppendLog(GPU_MAPPING, ss); } +bool GpuStrategy::NeedModifyOrderOfAxis() { + int last_axis = analyzer_->scop_info_.analysis_result_.GetLastAxisInScheduleTree(); + if (last_axis < 0 || last_axis >= static_cast(pending_axes_.size())) { + return false; + } + + int real_pos = static_cast(pending_axes_.size()) - 1 - last_axis; + if (real_pos == 0) { + return false; + } + + TileAxis *axis; + int64_t shape; + std::tie(axis, shape) = pending_axes_[real_pos]; + pending_axes_.erase(pending_axes_.begin() + real_pos, pending_axes_.begin() + real_pos + 1); + pending_axes_.push_front(std::make_pair(axis, shape)); + return true; +} + +// For the tensor of tensor operator, confirm whether coalesced access is required in the calculation phase. +void GpuStrategy::SetCoalescedAccess() { + isl::schedule_node root = analyzer_->sch_.get_root(); + isl::schedule_node node = GetOuterBand(root); + if (!node.isa()) { + return; + } + + if (analyzer_->scop_info_.user_config_.GetEnableAkgReduceLib() || analyzer_->scop_info_.user_config_.GetEnableMatmul()) { + return; + } + + auto band_node = node.as(); + auto n_parallel_axis = CountConsecutiveCoincident(band_node); + node = band_node.split(n_parallel_axis); + + std::unordered_set skip_tensors = analyzer_->scop_info_.analysis_result_.GetTensorsNotPromote(); + for (auto inner_tensor : analyzer_->scop_info_.analysis_result_.GetInnerTensor()) { + skip_tensors.emplace(inner_tensor); + } + + // Get read and write tensor information. 
+ auto reads_access = analyzer_->scop_info_.analysis_result_.GetReads().domain_factor_domain(); + int last_axis = GetLastAxis(node, reads_access, skip_tensors); + if (last_axis != -1) { + analyzer_->scop_info_.analysis_result_.SetLastAxisInScheduleTree(last_axis); + return; + } + + auto write_access = analyzer_->scop_info_.analysis_result_.GetWrites().domain_factor_domain(); + last_axis = GetLastAxis(node, write_access, skip_tensors); + if (last_axis != -1) { + analyzer_->scop_info_.analysis_result_.SetLastAxisInScheduleTree(last_axis); + return; + } +} + + void GpuStrategy::AddGpuConstraint() { ShowOptions(); InitMappingLimit(); @@ -809,8 +866,16 @@ void GpuStrategy::AddGpuConstraint() { } return; } + // tensor of tensor + bool need_injective_speed_up = true; + if ((template_ == Template::PURE_ELEM || template_ == Template::BROADCAST_OP) && + analyzer_->scop_info_.analysis_result_.GetTensorOfTensor()) { + SetCoalescedAccess(); + need_injective_speed_up = !NeedModifyOrderOfAxis(); + } + InnerThreadOuterBlock(); - if (template_ == Template::PURE_ELEM) { + if (template_ == Template::PURE_ELEM && need_injective_speed_up) { InjectiveSpeedup(); } @@ -2245,6 +2310,14 @@ std::pair ConvStrategy::GetDivisibleFactorForMN(int64_t shape_ return std::make_pair(1, 1); } +void ShiftAxisStrategy::AddGpuConstraint() { + TileEntirely(); + for (auto axis : shifted_axes_) { + axis->block_constraints.map_extent_ = 1; + axis->thread_constraints.map_extent_ = 1; + } +} + // No constraint found in cuda void ModStrategy::AddGpuConstraint() {} @@ -2263,8 +2336,6 @@ void DynamicShapeLimitStrategy::AddGpuConstraint() {} void DynamicBoundStrategy::AddGpuConstraint() {} -void ShiftAxisStrategy::AddGpuConstraint() {} - void ModShiftAxisStrategy::AddGpuConstraint() {} // end of null constraint diff --git a/src/poly/tiling/tiling_strategy_manager_npu.cc b/src/poly/tiling/tiling_strategy_manager_npu.cc index 7ac58822..348e9554 100644 --- a/src/poly/tiling/tiling_strategy_manager_npu.cc +++ 
b/src/poly/tiling/tiling_strategy_manager_npu.cc @@ -281,22 +281,7 @@ void DynamicBoundStrategy::AddNpuConstraint() { } } -void ShiftAxisStrategy::AddNpuConstraint() { - auto interested_info = GetInterestedInfo(interested_attr_key); - for (auto it : interested_info) { - TileAxis *axis = it.first; - int64_t const_extent = axis->GetConstExtent(); - if (const_extent == -1) { - continue; - } - for (const auto &attr : it.second) { - CHECK_NE(attr.attr_value, ""); - auto share_time = static_cast(std::strtol(attr.attr_value.c_str(), nullptr, 10)); - axis->TileRestrainToSingleValue(const_extent * (share_time + 1), CACHE1); - break; - } - } -} +void ShiftAxisStrategy::AddNpuConstraint() { TileEntirely(); } void ModShiftAxisStrategy::AddNpuConstraint() { auto interested_info = GetInterestedInfo(interested_attr_key); diff --git a/tests/common/gen_json_data.py b/tests/common/gen_json_data.py index 1a4d8237..b1459f6d 100644 --- a/tests/common/gen_json_data.py +++ b/tests/common/gen_json_data.py @@ -18,10 +18,13 @@ import tempfile import json import logging import inspect +from collections import namedtuple + import numpy as np +import scipy as sp from akg.global_configs import get_ascend_meta_path, get_cuda_meta_path -from tests.common.gen_random import random_gaussian -from tests.common.test_utils import precheck +from tests.common.gen_random import random_gaussian, gen_indices +from tests.common.test_utils import precheck, tensor_scatter_add_np, gather_np def get_attr(attr_desc, attr_type): @@ -381,6 +384,18 @@ op_dsl = { "RealDiv": lambda inputs, output, attr: "%s = np.divide(%s, %s)" % (output[0]['tensor_name'], get_input( inputs[0][0]), get_input(inputs[1][0])), + "Div": lambda inputs, output, attr: "%s = np.divide(%s, %s)" % + (output[0]['tensor_name'], get_input( + inputs[0][0]), get_input(inputs[1][0])), + "FloorDiv": lambda inputs, output, attr: "%s = np.floor_divide(%s, %s)" % + (output[0]['tensor_name'], get_input( + inputs[0][0]), get_input(inputs[1][0])), + "Mod": 
lambda inputs, output, attr: "%s = np.fmod(%s, %s)" % + (output[0]['tensor_name'], get_input( + inputs[0][0]), get_input(inputs[1][0])), + "FloorMod": lambda inputs, output, attr: "%s = np.mod(%s, %s)" % + (output[0]['tensor_name'], get_input( + inputs[0][0]), get_input(inputs[1][0])), "Minimum": lambda inputs, output, attr: "%s = np.minimum(%s, %s)" % (output[0]['tensor_name'], get_input( inputs[0][0]), get_input(inputs[1][0])), @@ -410,6 +425,9 @@ op_dsl = { "Equal": lambda inputs, output, attr: "%s = np.equal(%s, %s)" % (output[0]['tensor_name'], get_input( inputs[0][0]), get_input(inputs[1][0])), + "NotEqual": lambda inputs, output, attr: "%s = np.not_equal(%s, %s)" % + (output[0]['tensor_name'], get_input( + inputs[0][0]), get_input(inputs[1][0])), "GreaterEqual": lambda inputs, output, attr: "%s = np.greater_equal(%s, %s)" % (output[0]['tensor_name'], get_input( inputs[0][0]), get_input(inputs[1][0])), @@ -461,6 +479,20 @@ op_dsl = { (output[0]['tensor_name'], get_input(inputs[0][0])), "LogicalAnd": lambda inputs, output, attr: "%s = np.logical_and(%s, %s)" % (output[0]['tensor_name'], get_input(inputs[0][0]), get_input(inputs[1][0])), + "LogicalOr": lambda inputs, output, attr: "%s = np.logical_or(%s, %s)" % + (output[0]['tensor_name'], get_input(inputs[0][0]), get_input(inputs[1][0])), + "Erf": lambda inputs, output, attr: "%s = sp.special.erf(%s)" % + (output[0]['tensor_name'], get_input(inputs[0][0])), + "TensorScatterAdd": lambda inputs, output, attr: "%s = tensor_scatter_add_np(%s, %s, %s)" % + (output[0]['tensor_name'], get_input(inputs[0][0]), get_input(inputs[1][0]), + get_input(inputs[2][0])), + "GatherNd": lambda inputs, output, attr: "%s = %s[tuple(%s.transpose().tolist())]" % + (output[0]['tensor_name'], get_input(inputs[0][0]), get_input(inputs[1][0])), + "UnsortedSegmentSum": lambda inputs, output, attr: "%s = np.zeros([%s,] + %s[%s:]); np.add.at(%s, %s, %s)" % + (output[0]['tensor_name'], get_attr(attr, 'num_segments'), inputs[0][0]['shape'], 
len(inputs[1][0]['shape']), + output[0]['tensor_name'], get_input(inputs[1][0]), get_input(inputs[0][0])), + "Gather": lambda inputs, output, attr: "%s = gather_np(%s, %s, %s)" % + (output[0]['tensor_name'], get_input(inputs[0][0]), get_input(inputs[1][0]), get_attr(attr, "axis")) } def conv_2d_str(inputs, output, attr): @@ -573,11 +605,21 @@ def gen_json_data(op_desc, with_compute=True): p = CodePrinter(uni_file_name) idx = 0 - # Collect input which should be processed by atomic clean. + # Collect input which should be processed by atomic clean / or should be indices. clean_input = [] + indices_input = {} + MakeIndices = namedtuple("MakeIndices", "name data_shape indices_shape indices_dtype attrs") sum_out = None for op in desc["op_desc"]: - if op["name"] == "ReduceSum": + if op["name"] in ("ReduceSum", "UnsortedSegmentSum"): + if op["name"] == "UnsortedSegmentSum": + assert op["input_desc"][1][0]["data_type"] == "int32", "Default indices type should be int32" + assert op["attr"][1]["name"] == "num_segments", "UnsortedSegmentSum only accepts num_segments attribute." 
+ indices_input[op["input_desc"][1][0]["tensor_name"]] = \ + MakeIndices(name=op["name"], data_shape=op["input_desc"][0][0]["shape"], + indices_shape=op["input_desc"][1][0]["shape"], + indices_dtype=op["input_desc"][1][0]["data_type"], + attrs=op["attr"][1]["value"]) for a in op["attr"]: if a["name"] == "enable_atomic_add": sum_out = op["output_desc"][0]["tensor_name"] @@ -587,6 +629,20 @@ def gen_json_data(op_desc, with_compute=True): continue if op["input_desc"][1][0]["tensor_name"] == sum_out: clean_input.append(op["input_desc"][0][0]["tensor_name"]) + elif op["name"] in ("TensorScatterAdd", "Gather", "GatherNd"): + assert op["input_desc"][1][0]["data_type"] == "int32", "Default indices type should be int32" + indices_input[op["input_desc"][1][0]["tensor_name"]] = \ + MakeIndices(name=op["name"], data_shape=op["input_desc"][0][0]["shape"], + indices_shape=op["input_desc"][1][0]["shape"], + indices_dtype=op["input_desc"][1][0]["data_type"], + attrs=None) + if op["name"] == "Gather": + assert op["attr"][0]["name"] == "axis", "Gather only accepts axis attribute." 
+ indices_input[op["input_desc"][1][0]["tensor_name"]] = \ + MakeIndices(name=op["name"], data_shape=op["input_desc"][0][0]["shape"], + indices_shape=op["input_desc"][1][0]["shape"], + indices_dtype=op["input_desc"][1][0]["data_type"], + attrs=op["attr"][0]["value"]) input_mean_value = precheck(desc) for input_desc in desc["input_desc"] if desc["input_desc"] is not None else []: @@ -595,6 +651,8 @@ def gen_json_data(op_desc, with_compute=True): tensor_name = input_desc[0]["tensor_name"] if tensor_name in clean_input: item = np.zeros(shape).astype(dtype) + elif tensor_name in indices_input.keys(): + item = gen_indices(indices_input[tensor_name]) else: item = random_gaussian(shape, miu=input_mean_value, sigma=0.1).astype(dtype) input_for_mod.append(item) diff --git a/tests/common/gen_random.py b/tests/common/gen_random.py index 978ce0c1..d011c0a8 100644 --- a/tests/common/gen_random.py +++ b/tests/common/gen_random.py @@ -114,3 +114,72 @@ def random_gaussian(size, miu=0, sigma=8, epsilon=0, seed=None): def gen_epsilon(dtype): """Generate suggested epsilon according to data type.""" return 1e-7 if dtype == np.float32 else 1e-3 + +def gen_indices_tensor_scatter_add(shape1, shape2, dtype2): + assert dtype2 == "int32", "Currently only support int32 indices" + indices = np.zeros(shape2, dtype2) + indices = indices.reshape(-1, indices.shape[-1]) + for i in range(indices.shape[0]): + update_idx = [] + for j in range(indices.shape[1]): + # add outbounds situation + if np.random.random() < 0.1: + if np.random.random() < 0.5: + # less than 0 + indices[i][j] = -1 + else: + # larger than original shape + indices[i][j] = shape1[j] + 10 + else: + indices[i][j] = np.random.randint(shape1[j], size=()) + indices = indices.reshape(shape2) + return indices + +def gen_indices_gather(shape1, shape2, dtype2, axis): + assert dtype2 == "int32", "Currently only support int32 indices" + indices = np.random.randint(low=0, high=shape1[axis], size=shape2).astype(dtype2) + offset = 
np.random.choice((0, 0, 10, -10), shape2).astype(dtype2) + indices += offset + return indices + +def gen_indices_unsorted_segment_sum(shape1, shape2, dtype2, num): + # currently only support 1D + assert dtype2 == "int32", "Currently only support int32 indices" + return np.random.randint(low=0, high=num, size=shape2).astype(dtype2) + +def gen_indices_gather_nd(shape1, shape2, dtype2): + out_dim1 = 1 + for i in range(len(shape2) - 1): + out_dim1 = out_dim1 * shape2[i] + assert dtype2 == "int32", "Currently only support int32 indices" + indices = np.zeros([shape2[-1], out_dim1]).astype(dtype2) + for i in range(shape2[-1]): + # add outbounds situation + if np.random.random() < 0.1: + if np.random.random() < 0.5: + # less than 0 + indices[i] = np.random.randint(low=0, high=shape1[i], size=out_dim1) - 10 + else: + # larger than original shape + indices[i] = np.random.randint(low=0, high=shape1[i], size=out_dim1) + 10 + else: + indices[i] = np.random.randint(low=0, high=shape1[i], size=out_dim1) + + indices = indices.transpose() + indices = indices.reshape(shape2) + return indices + +def gen_indices(indices_argument): + op_name = indices_argument.name + data_shape = indices_argument.data_shape + indices_shape = indices_argument.indices_shape + indices_dtype = indices_argument.indices_dtype + attrs = indices_argument.attrs + if op_name == "Gather": + return gen_indices_gather(data_shape, indices_shape, indices_dtype, attrs) + elif op_name == "GatherNd": + return gen_indices_gather_nd(data_shape, indices_shape, indices_dtype) + elif op_name == "UnsortedSegmentSum": + return gen_indices_unsorted_segment_sum(data_shape, indices_shape, indices_dtype, attrs) + assert op_name == "TensorScatterAdd", "Input OP Name Not Known!" 
+ return gen_indices_tensor_scatter_add(data_shape, indices_shape, indices_dtype) diff --git a/tests/common/test_utils.py b/tests/common/test_utils.py index 71c1795d..54962fcd 100644 --- a/tests/common/test_utils.py +++ b/tests/common/test_utils.py @@ -17,6 +17,7 @@ import math import random import logging +from copy import deepcopy import numpy as np import akg.tvm from akg.utils.validation_check import MAX_DATA_SIZE @@ -207,4 +208,71 @@ def precheck(desc): return 1 logging.info( "Input data with mean value {} is generated".format(initial_input)) - return initial_input \ No newline at end of file + return initial_input + + +def gather_nd_np(data, indices): + data_shape = data.shape + indices_shape = indices.shape + new_indices = indices.reshape(-1, indices.shape[-1]) + left_shape = indices_shape[:-1] + right_shape = data_shape[int(indices_shape[-1]):] + out_shape = left_shape + right_shape + out = np.zeros(out_shape, np.float32).reshape(new_indices.shape[0], -1) + new_data = deepcopy(data).reshape(-1, int(np.prod(data_shape[int(indices_shape[-1]):]))) + for i in range(new_indices.shape[0]): + for j in range(out.shape[1]): + index_read = [i,0] + index_write = [0,0] + inbound = True + for k in range(0, int(indices_shape[-1])): + temp_idx = new_indices[i, k] + inbound = np.all((inbound, (temp_idx >= 0), (temp_idx < data_shape[k]))) + index_write[0] += int(temp_idx * data.strides[k] / data.itemsize) + index_write[0] = int(index_write[0] / out.shape[1]) + index_read[1] = j + if inbound: + index_write[1] = j + out[tuple(index_read)] = new_data[tuple(index_write)] + return out.reshape(out_shape) + + +def tensor_scatter_add_np(data, indices, updates): + data_shape = data.shape + indices_shape = indices.shape + updates_shape = updates.shape + if indices.ndim > 1: + new_indices = indices.reshape(-1, indices.shape[-1]) + out = deepcopy(data).reshape(-1, int(np.prod(data_shape[int(indices_shape[-1]):]))) + else: + new_indices = indices.reshape(-1, 1) + out = 
deepcopy(data).reshape(-1, int(np.prod(data_shape[1:]))) + new_updates = updates.reshape(new_indices.shape[0], -1) + for i in range(new_indices.shape[0]): + for j in range(out.shape[1]): + index_read = [i,0] + index_write = [0,0] + inbound = True + for k in range(0, int(indices_shape[-1])): + temp_idx = new_indices[i, k] + inbound = np.all((inbound, (temp_idx >= 0), (temp_idx < data_shape[k]))) + index_write[0] += int(temp_idx * data.strides[k] / data.itemsize) + index_write[0] = int(index_write[0] / out.shape[1]) + index_read[1] = j + if inbound: + index_write[1] = j + temp = new_updates[tuple(index_read)] + out[tuple(index_write)] + out[tuple(index_write)] = temp + return out.reshape(data_shape) + + +def gather_np(data, indices, axis): + Ni, Nk = data.shape[:axis], data.shape[axis + 1:] + Nj = indices.shape + expect = np.zeros(Ni + Nj + Nk, data.dtype) + for i in np.ndindex(Ni): + for j in np.ndindex(Nj): + for k in np.ndindex(Nk): + if 0 <= indices[j] < data.shape[axis]: + expect[i + j + k] = data[i + (indices[j],) + k] + return expect diff --git a/tests/operators/gpu/__init__.py b/tests/operators/gpu/__init__.py index e69de29b..39a047ba 100644 --- a/tests/operators/gpu/__init__.py +++ b/tests/operators/gpu/__init__.py @@ -0,0 +1,20 @@ +from .test_ms_tensor_scatter_add import tensor_scatter_add_np,\ + gen_indices_tensor_scatter_add +from .test_ms_gather import gather_np, gen_indices_gather +from .test_ms_unsorted_segment_sum import gen_indices_unsorted_segment_sum +from .test_ms_gather_nd import gen_indices_gather_nd + +def gen_indices(indices_argument): + op_name = indices_argument.name + data_shape = indices_argument.data_shape + indices_shape = indices_argument.indices_shape + indices_dtype = indices_argument.indices_dtype + attrs = indices_argument.attrs + if op_name == "Gather": + return gen_indices_gather(data_shape, indices_shape, indices_dtype, attrs) + elif op_name == "GatherNd": + return gen_indices_gather_nd(data_shape, indices_shape, indices_dtype) + 
elif op_name == "UnsortedSegmentSum": + return gen_indices_unsorted_segment_sum(data_shape, indices_shape, indices_dtype, attrs) + assert op_name == "TensorScatterAdd", "Input OP Name Not Known!" + return gen_indices_tensor_scatter_add(data_shape, indices_shape, indices_dtype) diff --git a/tests/operators/gpu/test_all.py b/tests/operators/gpu/test_all.py index bf2169ac..0ab75cdb 100644 --- a/tests/operators/gpu/test_all.py +++ b/tests/operators/gpu/test_all.py @@ -52,6 +52,13 @@ from tests.operators.gpu.test_ms_reduce_or import test_ms_reduce_or from tests.operators.gpu.test_ms_cumsum import test_ms_cumsum from tests.operators.gpu.test_ms_cumprod import test_ms_cumprod from tests.operators.gpu.test_ms_conv import test_ms_conv +from tests.operators.gpu.test_ms_gather import test_ms_gather +from tests.operators.gpu.test_ms_gather_nd import test_ms_gather_nd +from tests.operators.gpu.test_ms_tensor_scatter_add import test_ms_tensor_scatter_add +from tests.operators.gpu.test_fused_gather_mul_scatter_add import test_fused_gather_mul_scatter_add +from tests.operators.gpu.test_ms_unsorted_segment_sum import test_ms_unsorted_segment_sum +from tests.operators.gpu.test_fused_gather_nd_reduce_sum_mul_unsorted_segment_sum import test_fused_gather_nd_reduce_sum_mul_unsorted_segment_sum + from tests.operators.gpu.test_fused_pad import test_fused_pad from tests.operators.gpu.test_fused_bn_reduce import test_fused_bn_reduce from tests.operators.gpu.test_fused_bn_update import test_fused_bn_update @@ -68,7 +75,7 @@ from tests.operators.gpu.test_fused_relu_grad_bn_double_update_grad import test_ from tests.operators.gpu.test_fused_relu_grad import test_fused_relu_grad from tests.operators.gpu.test_fused_bn_update_grad import test_fused_bn_update_grad from tests.operators.gpu.test_fused_mul_div_rsqrt_mul_isfinite_red import test_fused_mul_div_rsqrt_mul_isfinite_red - +from tests.operators.gpu.test_fused_gather_gather_add_mul_max_exp_scatter_add import 
test_fused_gather_gather_add_mul_max_exp_scatter_add def add(poly_sch, fuzz_shape=None, mind_trick_str=''): if fuzz_shape: @@ -344,16 +351,30 @@ def reduce_or(poly_sch, fuzz_shape=None, mind_trick_str=''): test_ms_reduce_or((1024, 1024), 'bool', axis=1, keepdims=True, poly_sch=poly_sch) +def gather(poly_sch, fuzz_shape=None, mind_trick_str=''): + test_ms_gather((19717, 8, 1), 'float32', (108365, ), 'int32', 0, poly_sch=True) + +def gather_nd(poly_sch, fuzz_shape=None, mind_trick_str=''): + test_ms_gather_nd((19717, 1, 3), 'float32', (108365, 1), 'int32', 0, poly_sch=True) + +def tensor_scatter_add(poly_sch, fuzz_shape=None, mind_trick_str=''): + test_ms_tensor_scatter_add((19717, 8, 1), 'float32', (108365, 1), 'int32', 0, poly_sch=True, + attrs={"dim": "0 0 8 8 0 1 128 128", "bind_block": "847 1", "bind_thread": "128 8"}) -def cumsum(poly_sch, fuzz_shape=None, mind_trick_str=''): - test_ms_cumsum((65, 49, 21), "float32", axis=2, poly_sch=poly_sch) - test_ms_cumsum((65, 49, 21), "float16", axis=0, poly_sch=poly_sch) +def unsorted_segment_sum(poly_sch, fuzz_shape=None, mind_trick_str=''): + test_ms_unsorted_segment_sum((108365, 8, 1), 'float32', (108365,), 'int32', 19717, poly_sch=True) +def fused_gather_mul_scatter_add(poly_sch, fuzz_shape=None, mind_trick_str=''): + test_fused_gather_mul_scatter_add((19717, 8, 8), (108365, ), (108365, 8, 8), (108365, 1), 'float32', 'int32', 0, poly_sch=True, + attrs={"dim": "0 0 16 16 0 1 8 8 0 2 8 8", "bind_block": "1 1 6773", "bind_thread": "8 8 16"}) -def cumprod(poly_sch, fuzz_shape=None, mind_trick_str=''): - test_ms_cumprod((65, 49, 21), "float32", axis=2, poly_sch=poly_sch) - test_ms_cumprod((65, 49, 21), "float16", axis=0, poly_sch=poly_sch) +def fused_gather_nd_reduce_sum_mul_unsorted_segment_sum(poly_sch, fuzz_shape=None, mind_trick_str=''): + test_fused_gather_nd_reduce_sum_mul_unsorted_segment_sum( + (19717, 8, 8), (108365, 1), (108365, 8, 1), (108365,), 'float32', 'int32', -1, True, 19717, poly_sch=True) +def 
fused_gather_gather_add_mul_max_exp_scatter_add(poly_sch, fuzz_shape=None, mind_trick_str=''): + test_fused_gather_gather_add_mul_max_exp_scatter_add((19717, 8, 1), (108365, ), (1,), (108365, ), + 'float32', 'int32', 0, poly_sch=True) def fused_pad(poly_sch, fuzz_shape=None, mind_trick_str=''): test_fused_pad((7, 7, 3, 64), (0, 0, 0, 0), (0, 0, 1, 0), @@ -468,10 +489,14 @@ if __name__ == '__main__': "equal": equal, "exp": exp, "greater_equal": greater_equal, "less_equal": less_equal, "log": log, "max": maximum, "min": minimum, "mul": mul, "neg": neg, "pow": pow, "reciprocal": reciprocal, "round": round, "rsqrt": rsqrt, "select": select, "sqrt": sqrt, - "sub": sub, "reduce_max": reduce_max, "reduce_min": reduce_min, - "reduce_sum": reduce_sum, "cumsum": cumsum, "cumprod": cumprod, - "expand_dims": expand_dims, "one_hot": one_hot, - "reshape": reshape, "tile": tile, "trans_data": trans_data, "conv": conv, + "sub": sub, "reduce_max": reduce_max, "reduce_min": reduce_min, "reduce_and":reduce_and, + "reduce_or":reduce_or, "reduce_sum": reduce_sum, "expand_dims": expand_dims, "one_hot": one_hot, + "reshape": reshape, "tile": tile, "trans_data": trans_data, + "conv": conv, "gather":gather, "gather_nd":gather_nd, + "tensor_scatter_add":tensor_scatter_add, + "unsorted_segment_sum": unsorted_segment_sum, + "fused_gather_mul_scatter_add":fused_gather_mul_scatter_add, + "fused_gather_nd_reduce_sum_mul_unsorted_segment_sum": fused_gather_nd_reduce_sum_mul_unsorted_segment_sum, "fused_pad": fused_pad, "fused_bn_reduce": fused_bn_reduce, "fused_bn_update": fused_bn_update, @@ -487,7 +512,9 @@ if __name__ == '__main__': "fused_relu_grad_bn_double_update_grad": fused_relu_grad_bn_double_update_grad, "fused_relu_grad": fused_relu_grad, "fused_bn_update_grad": fused_bn_update_grad, - "fused_mul_div_rsqrt_mul_isfinite_red": fused_mul_div_rsqrt_mul_isfinite_red + "fused_mul_div_rsqrt_mul_isfinite_red": fused_mul_div_rsqrt_mul_isfinite_red, + "fused_gather_mul_scatter_add": 
fused_gather_mul_scatter_add, + "fused_gather_gather_add_mul_max_exp_scatter_add": fused_gather_gather_add_mul_max_exp_scatter_add, } all_f = list(op_map.values()) op_map["all"] = all_f diff --git a/tests/operators/gpu/test_fused_gather_gather_add_mul_max_exp_scatter_add.py b/tests/operators/gpu/test_fused_gather_gather_add_mul_max_exp_scatter_add.py new file mode 100644 index 00000000..13ae9ad9 --- /dev/null +++ b/tests/operators/gpu/test_fused_gather_gather_add_mul_max_exp_scatter_add.py @@ -0,0 +1,77 @@ +# Copyright 2020-2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License + +from tests.common.base import get_rtol_atol +from tests.common.tensorio import compare_tensor +from akg.utils import kernel_exec as utils +from akg.ops.array_gpu.fused_gather_gather_add_mul_max_exp_scatter_add import fused_gather_gather_add_mul_max_exp_scatter_add +from akg.utils.result_analysis import gpu_profiling +from akg.utils.format_transform import to_tvm_nd_array +from tests.common.gen_random import random_gaussian +import numpy as np +from copy import deepcopy + +def gen_data(shape1, shape2, shape3, shape4, dtype1, dtype2, axis): + input1 = random_gaussian(shape1).astype(dtype1) + input2 = np.random.randint(low=0, high=shape1[axis], size=shape2).astype(dtype2) + input3 = np.full(shape3, 0.2).astype(dtype1) + input4 = np.random.randint(shape1[axis], size=shape4).astype(dtype2) + + gather_out1 = np.take(input1, input2, axis=axis) + gather_out2 = np.take(input1, input2, axis=axis) + add_out = np.add(gather_out1, gather_out2) + mul_out = np.multiply(input3, add_out) + max_out = np.maximum(add_out, mul_out) + exp_out = np.exp(max_out) + + scatter_out = deepcopy(input1) + np.add.at(scatter_out, input4, exp_out) + + return input1, input2, input3, input4, exp_out, scatter_out + + +def test_fused_gather_gather_add_mul_max_exp_scatter_add(input1_shape, input2_shape, input3_shape, input4_shape, + data_dtype, indices_type, axis, poly_sch=False, attrs=None): + op_attrs = [axis] + default_attrs = {"target": "cuda"} + if attrs: + default_attrs.update(attrs) + if poly_sch: + mod = utils.op_build_test(fused_gather_gather_add_mul_max_exp_scatter_add, + [input1_shape, input2_shape, input3_shape, input4_shape], + [data_dtype, indices_type, data_dtype, indices_type], + op_attrs=op_attrs, attrs=default_attrs, + kernel_name="fused_gather_gather_add_mul_max_exp_scatter_add", ) + + # gen data + input1, input2, input3, input4, expect1, expect2 = gen_data(input1_shape, 
input2_shape, input3_shape, input4_shape, + data_dtype, indices_type, axis) + + output1 = np.zeros(expect1.shape, expect1.dtype) + output2 = deepcopy(input1) + output1, output2 = utils.mod_launch(mod, (input1, input2, input3, input4, output1, output2), + outputs=(-2, -1)) + + atol, rtol = get_rtol_atol("fused_gather_gather_add_mul_max_exp_scatter_add", data_dtype) + res = compare_tensor(output1, expect1, rtol=rtol, atol=atol) + res &= compare_tensor(output2, expect2, rtol=rtol, atol=atol) + print("Test {}".format("Pass" if res else "Failed")) + if not res: + print("Error cuda:========================") + print(mod.imported_modules[0].get_source()) + raise AssertionError("Test fail") + + inputs = to_tvm_nd_array([input1, input2, input3, input4]) + expects = to_tvm_nd_array([expect1, expect2]) + gpu_profiling(mod, *inputs, *expects, repeat_time=400) diff --git a/tests/operators/gpu/test_fused_gather_mul_scatter_add.py b/tests/operators/gpu/test_fused_gather_mul_scatter_add.py new file mode 100644 index 00000000..00dd6293 --- /dev/null +++ b/tests/operators/gpu/test_fused_gather_mul_scatter_add.py @@ -0,0 +1,82 @@ +# Copyright 2020-2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License + +from tests.common.base import get_rtol_atol +from tests.common.tensorio import compare_tensor +from akg.utils import kernel_exec as utils +from akg.ops.array_gpu import fused_gather_mul_scatter_add +from akg.utils.result_analysis import gpu_profiling +from akg.utils.format_transform import to_tvm_nd_array +from tests.common.gen_random import random_gaussian +import numpy as np +from copy import deepcopy + +def gen_data(shape1, shape2, shape3, shape4, dtype1, dtype2, axis): + # gather + input1 = random_gaussian(shape1).astype(dtype1) + input2 = np.random.randint(low=0, high=shape1[axis], size=shape2).astype(dtype2) + gather_out = np.take(input1, input2, axis=axis) + + # mul + input3 = random_gaussian(shape3).astype(dtype1) + mul_out = np.multiply(gather_out, input3) + + # scatter_add + params = np.zeros(shape1, dtype1) + #params = random_gaussian(shape1).astype(dtype1) + indices = np.zeros(shape4, dtype2) + original_shape = indices.shape + indices = indices.reshape(-1, indices.shape[-1]) + for i in range(indices.shape[0]): + for j in range(indices.shape[1]): + indices[i][j] = np.random.randint(shape1[j], size=()) + + indices = indices.reshape(original_shape) + expect = deepcopy(params) + np.add.at(expect, tuple(indices.T.tolist()), mul_out) + indices = indices.reshape(shape4) + return input1, input2, input3, indices, expect + +def test_fused_gather_mul_scatter_add(input1_shape, input2_shape, input3_shape, input4_shape, data_dtype, indices_type, axis, poly_sch=False, attrs=None): + op_attrs = [axis] + default_attrs = {"target": "cuda"} + if attrs: + default_attrs.update(attrs) + + if poly_sch: + mod = utils.op_build_test(fused_gather_mul_scatter_add.fused_gather_mul_scatter_add, + [input1_shape, input2_shape, input3_shape, input4_shape], [data_dtype, indices_type, data_dtype, indices_type], op_attrs=op_attrs, + attrs=default_attrs, 
kernel_name="fused_gather_mul_scatter_add", ) + + # gen data + input1, input2, input3, input4, expect = gen_data(input1_shape, input2_shape, input3_shape, input4_shape, data_dtype, indices_type, axis) + output_shape = expect.shape + + if len(expect.shape) == 0: + output_shape = (1, ) + #output = np.full(output_shape, np.nan, expect.dtype) + output = np.zeros(output_shape, expect.dtype) + output = utils.mod_launch(mod, (input1, input2, input3, input4, output), expect = expect) + + atol, rtol = get_rtol_atol("fused_gather_mul_scatter_add", data_dtype) + res = compare_tensor(output, expect, rtol=rtol, atol=atol) + print("Test {}".format("Pass" if res else "Failed")) + if not res: + print("Error cuda:========================") + print(mod.imported_modules[0].get_source()) + raise AssertionError("Test fail") + + input1, input2, input3, input4, output, expect = to_tvm_nd_array( + [input1, input2, input3, input4, output, expect]) + gpu_profiling(mod, input1, input2, input3, input4, output, expect, repeat_time=400) diff --git a/tests/operators/gpu/test_fused_gather_nd_reduce_sum_mul_unsorted_segment_sum.py b/tests/operators/gpu/test_fused_gather_nd_reduce_sum_mul_unsorted_segment_sum.py new file mode 100644 index 00000000..1a2046ef --- /dev/null +++ b/tests/operators/gpu/test_fused_gather_nd_reduce_sum_mul_unsorted_segment_sum.py @@ -0,0 +1,76 @@ +# Copyright 2020-2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License + +from tests.common.base import get_rtol_atol +from tests.common.tensorio import compare_tensor +from akg.utils import kernel_exec as utils +from akg.ops.array_gpu import fused_gather_nd_reduce_sum_mul_unsorted_segment_sum +from akg.utils.result_analysis import gpu_profiling +from akg.utils.format_transform import to_tvm_nd_array +from tests.common.gen_random import random_gaussian +import numpy as np +from copy import deepcopy + +def gen_data(shape1, shape2, shape3, shape4, data_type, indices_type, axis, keepdims, num): + input1 = random_gaussian(shape1).astype(data_type) + out_dim1 = 1 + for i in range(len(shape2) - 1): + out_dim1 = out_dim1 * shape2[i] + input2 = np.zeros([shape2[-1], out_dim1]).astype(indices_type) + for i in range(shape2[-1]): + input2[i] = np.random.randint(low=0, high=shape1[i], size=out_dim1) + input3 = random_gaussian(shape3).astype(data_type) + prod = np.sum(input1[tuple(input2.tolist())], axis=axis, keepdims=keepdims) * input3 + + input4 = np.random.randint(low=0, high=10, size=shape4).astype(indices_type) + input5 = np.random.randint(low=0, high=10, size=shape4).astype(indices_type) + expect1 = np.zeros((num,) + shape3[len(shape4):]).astype(data_type) + expect2 = np.zeros((num,) + shape3[len(shape4):]).astype(data_type) + np.add.at(expect1, input4, prod) + np.add.at(expect2, input5, prod) + + input2 = input2.transpose() + input2 = input2.reshape(shape2) + return input1, input2, input3, input4, input5, expect1, expect2 + +def test_fused_gather_nd_reduce_sum_mul_unsorted_segment_sum( + input1_shape, input2_shape, input3_shape, input4_shape, data_dtype, indices_type, axis, keepdims, num, poly_sch=False, attrs=None): + op_attrs = [axis, keepdims, num] + default_attrs = {"target": "cuda"} + if attrs: + default_attrs.update(attrs) + + if poly_sch: + mod = 
utils.op_build_test(fused_gather_nd_reduce_sum_mul_unsorted_segment_sum.fused_gather_nd_reduce_sum_mul_unsorted_segment_sum, + [input1_shape, input2_shape, input3_shape, input4_shape, input4_shape], [data_dtype, indices_type, data_dtype, indices_type, indices_type], op_attrs=op_attrs, + attrs=default_attrs, kernel_name="fused_gather_nd_reduce_sum_mul_unsorted_segment_sum", ) + + input1, input2, input3, input4, input5, expect1, expect2 = gen_data( + input1_shape, input2_shape, input3_shape, input4_shape, data_dtype, indices_type, axis, keepdims, num) + + output1 = np.zeros(expect1.shape, expect1.dtype) + output2 = np.zeros(expect2.shape, expect2.dtype) + output = utils.mod_launch(mod, (input1, input2, input3, input4, input5, output1, output2), outputs=(-2, -1), expect=(expect1, expect2)) + + atol, rtol = get_rtol_atol("fused_gather_nd_reduce_sum_mul_unsorted_segment_sum", data_dtype) + res = compare_tensor(output[0], expect1, rtol=rtol, atol=atol) and compare_tensor(output[1], expect2, rtol=rtol, atol=atol) + print("Test {}".format("Pass" if res else "Failed")) + if not res: + print("Error cuda:========================") + print(mod.imported_modules[0].get_source()) + raise AssertionError("Test fail") + + input1, input2, input3, input4, input5, output1, output2, expect1, expect2 = to_tvm_nd_array( + [input1, input2, input3, input4, input5, output1, output2, expect1, expect2]) + gpu_profiling(mod, input1, input2, input3, input4, input5, output1, output2, expect1, expect2, repeat_time=400) diff --git a/tests/operators/gpu/test_ms_gather.py b/tests/operators/gpu/test_ms_gather.py new file mode 100644 index 00000000..370e0f95 --- /dev/null +++ b/tests/operators/gpu/test_ms_gather.py @@ -0,0 +1,56 @@ +# Copyright 2020-2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +from tests.common.base import get_rtol_atol +from tests.common.tensorio import compare_tensor +from akg.utils import kernel_exec as utils +from akg.ops.array_gpu import gather +from akg.utils.result_analysis import gpu_profiling +from akg.utils.format_transform import to_tvm_nd_array +from tests.common.gen_random import random_gaussian, gen_indices_gather +from tests.common.test_utils import gather_np +import numpy as np + + +def gen_data(shape1, dtype1, shape2, dtype2, axis): + params = random_gaussian(shape1).astype(dtype1) + indices = gen_indices_gather(shape1, shape2, dtype2, axis) + expect = gather_np(params, indices, axis) + return params, indices, expect + +def test_ms_gather(shape1, dtype1, shape2, dtype2, axis, poly_sch=False): + op_attrs = [axis] + if poly_sch: + mod = utils.op_build_test(gather.gather, [shape1, shape2], [dtype1, dtype2], op_attrs=op_attrs, + attrs={"target": "cuda"}, kernel_name="gather") + + # gen data + params, indices, expect = gen_data(shape1, dtype1, shape2, dtype2, axis) + output_shape = expect.shape + + if len(expect.shape) == 0: + output_shape = (1, ) + output = np.zeros(output_shape, expect.dtype) + output = utils.mod_launch(mod, (params, indices, output), expect = expect) + atol, rtol = get_rtol_atol("gather", dtype1) + res = compare_tensor(output, expect, rtol=rtol, atol=atol) + print("Test {}".format("Pass" if res else "Failed")) + if not res: + print("Error cuda:========================") + print(mod.imported_modules[0].get_source()) + raise AssertionError("Test fail") + + params, indices, output, expect = 
to_tvm_nd_array( + [params, indices, output, expect]) + gpu_profiling(mod, params, indices, output, expect, repeat_time=400) diff --git a/tests/operators/gpu/test_ms_gather_nd.py b/tests/operators/gpu/test_ms_gather_nd.py new file mode 100644 index 00000000..243fda1a --- /dev/null +++ b/tests/operators/gpu/test_ms_gather_nd.py @@ -0,0 +1,64 @@ +# Copyright 2020-2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License +from copy import deepcopy + +from tests.common.base import get_rtol_atol +from tests.common.tensorio import compare_tensor +from akg.utils import kernel_exec as utils +from akg.ops.array_gpu import gather_nd +from akg.utils.result_analysis import gpu_profiling +from akg.utils.format_transform import to_tvm_nd_array +from tests.common.gen_random import random_gaussian, gen_indices_gather_nd +from tests.common.test_utils import gather_nd_np +import numpy as np + + +def gen_data(shape1, dtype1, shape2, dtype2): + params = random_gaussian(shape1).astype(dtype1) + out_dim1 = 1 + for i in range(len(shape2) - 1): + out_dim1 = out_dim1 * shape2[i] + + indices = gen_indices_gather_nd(shape1, shape2, dtype2) + expect = gather_nd_np(params, indices) + + return params, indices, expect + +def test_ms_gather_nd(shape1, dtype1, shape2, dtype2, axis, poly_sch=False): + # op_attrs = [axis] + if poly_sch: + mod = utils.op_build_test(gather_nd.gather_nd, + [shape1, shape2], [dtype1, dtype2], #op_attrs=op_attrs, + attrs={"target": "cuda"}, 
kernel_name="gather_nd") + + # gen data + params, indices, expect = gen_data(shape1, dtype1, shape2, dtype2) + output_shape = expect.shape + + if len(expect.shape) == 0: + output_shape = (1, ) + output = np.zeros(output_shape, expect.dtype) + output = utils.mod_launch(mod, (params, indices, output), expect = expect) + + atol, rtol = get_rtol_atol("gather_nd", dtype1) + res = compare_tensor(output, expect, rtol=rtol, atol=atol) + print("Test {}".format("Pass" if res else "Failed")) + if not res: + print("Error cuda:========================") + print(mod.imported_modules[0].get_source()) + raise AssertionError("Test fail") + + params, indices, output, expect = to_tvm_nd_array( + [params, indices, output, expect]) + gpu_profiling(mod, params, indices, output, expect, repeat_time=400) diff --git a/tests/operators/gpu/test_ms_tensor_scatter_add.py b/tests/operators/gpu/test_ms_tensor_scatter_add.py new file mode 100644 index 00000000..9a7207c5 --- /dev/null +++ b/tests/operators/gpu/test_ms_tensor_scatter_add.py @@ -0,0 +1,69 @@ +# Copyright 2020-2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License + +from tests.common.base import get_rtol_atol +from tests.common.tensorio import compare_tensor +from akg.utils import kernel_exec as utils +from akg.ops.array_gpu import tensor_scatter_add +from akg.utils.result_analysis import gpu_profiling +from akg.utils.format_transform import to_tvm_nd_array +from tests.common.gen_random import random_gaussian, gen_indices_tensor_scatter_add +from tests.common.test_utils import tensor_scatter_add_np +import numpy as np + + +def gen_data(shape1, dtype1, shape2, dtype2): + params = np.zeros(shape1, dtype1) + update_shape = shape2[:-1] + shape1[shape2[-1]:] + updates = np.random.random(update_shape).astype(dtype1) + indices = gen_indices_tensor_scatter_add(shape1, shape2, dtype2) + expect = tensor_scatter_add_np(params, indices, updates) + return params, indices, updates, expect + +def test_ms_tensor_scatter_add(data_shape, data_type, indices_shape, indices_type, axis, poly_sch=False, attrs=None): + op_attrs = [axis] + default_attrs = {"target": "cuda"} + if attrs: + default_attrs.update(attrs) + if len(indices_shape) > 1: + updates_shape = indices_shape[:-1] + data_shape[indices_shape[-1]:] + else: + updates_shape = indices_shape + data_shape[1:] + + if poly_sch: + mod = utils.op_build_test(tensor_scatter_add.tensor_scatter_add, + [data_shape, indices_shape, updates_shape], [data_type, indices_type, data_type], + attrs=default_attrs, kernel_name="tensor_scatter_add", ) + + # gen data + indices_shape = indices_shape + (1,) if len(indices_shape) == 1 else indices_shape + params, indices, updates, expect = gen_data(data_shape, data_type, indices_shape, indices_type) + output_shape = expect.shape + + if len(expect.shape) == 0: + output_shape = (1, ) + output = np.zeros(output_shape, expect.dtype) + output = utils.mod_launch(mod, (params, indices, updates, output), expect = expect) + + atol, rtol = 
get_rtol_atol("tensor_scatter_add", data_type) + res = compare_tensor(output, expect, rtol=rtol, atol=atol) + print("Test {}".format("Pass" if res else "Failed")) + if not res: + print("Error cuda:========================") + print(mod.imported_modules[0].get_source()) + raise AssertionError("Test fail") + + params, indices, updates, output, expect = to_tvm_nd_array( + [params, indices, updates, output, expect]) + gpu_profiling(mod, params, indices, updates, output, expect, repeat_time=400) diff --git a/tests/operators/gpu/test_ms_unsorted_segment_sum.py b/tests/operators/gpu/test_ms_unsorted_segment_sum.py new file mode 100644 index 00000000..334ebae7 --- /dev/null +++ b/tests/operators/gpu/test_ms_unsorted_segment_sum.py @@ -0,0 +1,64 @@ +# Copyright 2020-2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License + +import numpy as np +from copy import deepcopy +from tests.common.base import get_rtol_atol +from tests.common.tensorio import compare_tensor +from akg.utils import kernel_exec as utils +from akg.ops.array_gpu import unsorted_segment_sum +from akg.utils.result_analysis import gpu_profiling +from akg.utils.format_transform import to_tvm_nd_array +from tests.common.gen_random import random_gaussian, gen_indices_unsorted_segment_sum + + +def gen_data(shape1, dtype1, shape2, dtype2, num): + input1 = random_gaussian(shape1).astype(dtype1) + input2 = gen_indices_unsorted_segment_sum(shape1, shape2, dtype2, num) + expect = np.zeros((num,) + shape1[len(shape2):]).astype(dtype1) + np.add.at(expect, input2, input1) + return input1, input2, expect + +def test_ms_unsorted_segment_sum(data_shape, data_type, indices_shape, indices_type, num, poly_sch=False, attrs=None): + op_attrs = [num] + default_attrs = {"target": "cuda"} + if attrs: + default_attrs.update(attrs) + + if poly_sch: + mod = utils.op_build_test(unsorted_segment_sum.unsorted_segment_sum, + [data_shape, indices_shape], [data_type, indices_type, data_type], op_attrs=op_attrs, + attrs=default_attrs, kernel_name="unsorted_segment_sum", ) + + # gen data + input1, input2, expect = gen_data(data_shape, data_type, indices_shape, indices_type, num) + output_shape = expect.shape + + if len(expect.shape) == 0: + output_shape = (1, ) + #output = np.full(output_shape, np.nan, expect.dtype) + output = np.zeros(output_shape, expect.dtype) + output = utils.mod_launch(mod, (input1, input2, output), expect = expect) + + atol, rtol = get_rtol_atol("unsorted_segment_sum", data_type) + res = compare_tensor(output, expect, rtol=rtol, atol=atol) + print("Test {}".format("Pass" if res else "Failed")) + if not res: + print("Error cuda:========================") + print(mod.imported_modules[0].get_source()) + raise AssertionError("Test 
fail") + + input1, input2, output, expect = to_tvm_nd_array( + [input1, input2, output, expect]) + gpu_profiling(mod, input1, input2, output, expect, repeat_time=400) diff --git a/third_party/incubator-tvm/include/tvm/ir.h b/third_party/incubator-tvm/include/tvm/ir.h index 9e0311ca..9eb098b9 100644 --- a/third_party/incubator-tvm/include/tvm/ir.h +++ b/third_party/incubator-tvm/include/tvm/ir.h @@ -1278,6 +1278,8 @@ constexpr const char* promote_vectorization = "promote_vectorization"; constexpr const char* bind_thread_x = "bind_thread_x"; /*! \brief Mark for tensorcore interface*/ constexpr const char* wmma_scope = "wmma_scope"; +/*! \brief Mark of tensor_of_tensor */ +constexpr const char* atomic_tot = "atomic_tot"; /*! * \brief Mark of prefetch scope, value=offset, -- Gitee