# --- python/akg/composite/build_module.py: TensorCore detection helpers ---

def should_enable_tensor_core(kernel_info):
    """Return True if the fused kernel contains an op lowered via TensorCore.

    Args:
        kernel_info (dict): composite kernel description; "op_desc" is the
            list of op dicts, each with a "name" key.
    """
    # MatMul and Conv2D are the only ops routed through the TensorCore path.
    for op in kernel_info["op_desc"]:
        if op["name"] in ("MatMul", "Conv2D"):
            return True
    return False


def should_enable_conv_tensor_core(kernel_info):
    """Return True if the fused kernel contains a Conv2D op."""
    for op in kernel_info["op_desc"]:
        if op["name"] == "Conv2D":
            return True
    return False


# --- python/akg/composite/topi.py: composite Conv2D compute ---

@tvm.register_func("Conv2D")
def conv2d_nhwc(inputs, attrs):
    """NHWC direct convolution (no padding / no dilation) for composite ops.

    Args:
        inputs: [data, weight]; data is NHWC fp16, weight is OHWI fp16.
        attrs: op attribute map; must contain "stride".

    Returns:
        tvm.Tensor of shape (N, OH, OW, OC).

    Raises:
        ValueError: on wrong input count, missing "stride", or non-4D shapes.
    """
    # Convert the tvm attribute map into a plain dict for lookups.
    attrs = {k: v for k, v in attrs.items()}
    if len(inputs) != 2:
        raise ValueError("length of inputs should be 2, but got %d." % len(inputs))
    if "stride" not in attrs:
        raise ValueError("stride not found in the attrs")
    data, weight = inputs
    output_name = "T_conv2d_nhwc_" + data.op.name + "_" + weight.op.name
    stride = attrs["stride"]
    # TensorCore conv requires fp16 inputs.
    vc_util.ops_dtype_check(data.dtype, vc_util.DtypeForDavinci.FLOAT16)
    vc_util.ops_dtype_check(weight.dtype, vc_util.DtypeForDavinci.FLOAT16)
    if len(data.shape) != 4 or len(weight.shape) != 4:
        raise ValueError("shape of data and weight should be 4-dim, but got %d and %d."
                         % (len(data.shape), len(weight.shape)))
    n, in_h, in_w, in_c = data.shape
    # Weight is OHWI; its input-channel extent must equal data's in_c, so it
    # is discarded here instead of silently shadowing in_c as before.
    out_c, k_h, k_w, _ = weight.shape
    # NOTE(review): stride is assumed to be a 4-element NHWC-style list whose
    # last two entries are (stride_h, stride_w) -- confirm against callers.
    _, _, s_h, s_w = stride
    o_h = (in_h - k_h) // s_h + 1
    o_w = (in_w - k_w) // s_w + 1
    rc = tvm.reduce_axis((0, in_c), name="rc")
    rh = tvm.reduce_axis((0, k_h), name="rh")
    rw = tvm.reduce_axis((0, k_w), name="rw")
    output = tvm.compute(
        (n, o_h, o_w, out_c),
        lambda nn, h, w, o: tvm.sum(
            data[nn, h * s_h + rh, w * s_w + rw, rc] * weight[o, rh, rw, rc],
            axis=[rc, rh, rw]),
        name=output_name)
    return output
"""operator dsl function: conv (NCHW direct convolution)"""
import numpy as np
import akg.topi as topi
import akg.tvm as tvm
from akg.utils import validation_check as vc_util


def conv(data, weight, stride=(1, 1), pad=(0, 0, 0, 0), dilation=(1, 1), name="out"):
    """NCHW direct convolution with explicit zero padding.

    Args:
        data: tvm.Tensor in NCHW layout.
        weight: tvm.Tensor in OIHW layout.
        stride: (stride_h, stride_w).
        pad: (pad_left, pad_right, pad_top, pad_bottom).
        dilation: currently unused; kept for interface compatibility.
        name: name for the output compute op.

    Returns:
        tvm.Tensor of shape (N, OC, OH, OW).
    """
    # Immutable tuple defaults replace the original mutable list defaults.
    # Weight's input-channel extent is authoritative for the reduction; the
    # data channel extent is discarded rather than shadowed as before.
    batch, _, in_h, in_w = data.shape
    out_c, in_c, k_h, k_w = weight.shape
    pad_left, pad_right, pad_top, pad_bottom = pad
    s_h, s_w = stride
    o_h = (in_h + pad_top + pad_bottom - k_h) // s_h + 1
    o_w = (in_w + pad_left + pad_right - k_w) // s_w + 1

    # Zero-pad only the spatial dims; N and C are untouched.
    data_pad = topi.nn.pad(data, [0, 0, pad_top, pad_left],
                           [0, 0, pad_bottom, pad_right], 0.0)

    rc = tvm.reduce_axis((0, in_c), name="rc")
    rh = tvm.reduce_axis((0, k_h), name="rh")
    rw = tvm.reduce_axis((0, k_w), name="rw")

    out = tvm.compute(
        (batch, out_c, o_h, o_w),
        lambda n, c, h, w: tvm.sum(
            data_pad[n, rc, h * s_h + rh, w * s_w + rw] * weight[c, rc, rh, rw],
            axis=[rc, rh, rw]),
        name=name)
    # use for relu condition
    # out = tvm.compute(out.shape, lambda *i: tvm.max(out(*i), tvm.const(0, out.dtype)), name="relu")
    return out


def conv_fusion(data, weight1, weight2, stride1=(1, 1), stride2=(1, 1),
                pad1=(0, 0, 0, 0), pad2=(0, 0, 0, 0),
                dilation1=(1, 1), dilation2=(1, 1)):
    """Two back-to-back convolutions: conv(conv(data, weight1), weight2)."""
    mid = conv(data, weight1, stride1, pad1, dilation1)
    return conv(mid, weight2, stride2, pad2, dilation2, "out2")
"""operator dsl function: conv using tensorcore"""
import numpy as np
import akg.topi as topi
import akg.tvm as tvm
from akg.utils import validation_check as vc_util


def conv_tc(data, weight, stride=(1, 1), pad=(0, 0, 0, 0), dilation=(1, 1),
            out_dtype="float32", name="out"):
    """NHWC convolution intended for TensorCore codegen.

    Args:
        data: tvm.Tensor in NHWC layout (fp16 expected).
        weight: tvm.Tensor in OHWI layout.
        stride: (stride_h, stride_w).
        pad: (pad_left, pad_right, pad_top, pad_bottom).
        dilation: currently unused; kept for interface compatibility.
        out_dtype: "float32" accumulates in fp32; any other value keeps the
            input dtype through the multiply-accumulate.
        name: name for the output compute op.

    Returns:
        tvm.Tensor of shape (N, OH, OW, OC).
    """
    batch, in_h, in_w, in_c = data.shape
    out_c, k_h, k_w, _ = weight.shape
    pad_left, pad_right, pad_top, pad_bottom = pad
    s_h, s_w = stride
    o_h = (in_h + pad_top + pad_bottom - k_h) // s_h + 1
    o_w = (in_w + pad_left + pad_right - k_w) // s_w + 1

    has_pad = not (pad_left == 0 and pad_right == 0 and
                   pad_top == 0 and pad_bottom == 0)

    if has_pad:
        # Explicit zero padding of the spatial dims.  The valid region is
        #   pad_top  <= h < in_h + pad_top
        #   pad_left <= w < in_w + pad_left
        # matching the data[h - pad_top, w - pad_left] access below.
        # BUGFIX: the original guard tested `h - pad_bottom < in_h` and
        # `w - pad_right < in_w`, which reads out of bounds (or zeroes valid
        # rows) whenever padding is asymmetric.
        data_pad = tvm.compute(
            (batch, in_h + pad_top + pad_bottom,
             in_w + pad_left + pad_right, in_c),
            lambda n, h, w, i: tvm.if_then_else(
                tvm.all(h >= pad_top, h - pad_top < in_h,
                        w >= pad_left, w - pad_left < in_w),
                data[n, h - pad_top, w - pad_left, i],
                # Pad with zeros of the input dtype (was hard-coded float16).
                tvm.const(0.0, data.dtype)),
            name="Pad")
    else:
        data_pad = data

    rc = tvm.reduce_axis((0, in_c), name="rc")
    rh = tvm.reduce_axis((0, k_h), name="rh")
    rw = tvm.reduce_axis((0, k_w), name="rw")

    if out_dtype == "float32":
        # Cast operands up so the reduction accumulates in fp32
        # (the TensorCore f32 accumulator path).
        out = tvm.compute(
            (batch, o_h, o_w, out_c),
            lambda n, h, w, o: tvm.sum(
                data_pad[n, h * s_h + rh, w * s_w + rw, rc].astype("float32") *
                weight[o, rh, rw, rc].astype("float32"),
                axis=[rc, rh, rw]),
            name=name)
    else:
        out = tvm.compute(
            (batch, o_h, o_w, out_c),
            lambda n, h, w, o: tvm.sum(
                data_pad[n, h * s_h + rh, w * s_w + rw, rc] *
                weight[o, rh, rw, rc],
                axis=[rc, rh, rw]),
            name=name)

    return out
- */ - -#ifndef __WMMA_M16N16K4_HPP__ -#define __WMMA_M16N16K4_HPP__ -#include -#include - -namespace akg { -namespace wmma { - -inline __device__ unsigned get_lane_id() { - unsigned lane_id; - asm(R"({mov.s32 %0, %laneid;})" : "=r"(lane_id)); - return lane_id; -} - -template -struct __align__(4) __frag_base { - T x[size]; - enum { num_elements = size }; -}; - -template -class fragment; -template <> -class fragment : public __frag_base {}; -template <> -class fragment : public __frag_base {}; -template <> -class fragment : public __frag_base {}; -template <> -class fragment : public __frag_base {}; -template <> -class fragment : public __frag_base {}; -template <> -class fragment : public __frag_base {}; - -template -__device__ inline void fill_fragment(__frag_base &f, const T v) { -#pragma unroll - for (unsigned i = 0; i < f.num_elements; i++) { - f.x[i] = v; - } -} - -template -__device__ inline void load_matrix_sync( - akg::wmma::fragment &f, const T *const p, - const unsigned ldm) { - const unsigned lane_id = get_lane_id(); - const unsigned row = lane_id & 0x3; - const unsigned col = (lane_id & 0x4) + ((lane_id >> 4) << 3); - const unsigned offset = row * ldm + col; - - float2 *src = (float2 *)(p + offset); - float2 *dst = (float2 *)f.x; - dst[0] = src[0]; -} - -template -__device__ inline void load_matrix_sync( - akg::wmma::fragment &f, const T *const p, - const unsigned ldm) { - const unsigned lane_id = get_lane_id(); - const unsigned row = (lane_id & 0x7) + ((lane_id >> 4) << 3); // (l and (1111)b + l div 16 * 8) - const unsigned offset = row * ldm; - float2 *src = (float2 *)(p + offset); - float2 *dst = (float2 *)f.x; - dst[0] = src[0]; -} - -template -__device__ inline void load_matrix_sync( - akg::wmma::fragment &f, const T *const p, - const unsigned ldm) { - const unsigned lane_id = get_lane_id(); - const unsigned row = (lane_id & 0x3) + ((lane_id & 0x18) >> 1); - const unsigned offset = row * ldm; - - float2 *src = (float2 *)(p + offset); - float2 *dst 
= (float2 *)f.x; - dst[0] = src[0]; -} - -template -__device__ inline void load_matrix_sync( - akg::wmma::fragment &f, const T *const p, - const unsigned ldm) { - const unsigned lane_id = get_lane_id(); - const unsigned row = lane_id & 0x3; - const unsigned col = (lane_id >> 3) << 2; - const unsigned offset = row * ldm + col; - - float2 *src = (float2 *)(p + offset); - float2 *dst = (float2 *)f.x; - dst[0] = src[0]; -} - -template -__device__ inline void load_matrix_sync(akg::wmma::fragment &f, - const T *const p, const unsigned ldm, const nvcuda::wmma::layout_t layout) { - const unsigned lane_id = get_lane_id(); - const unsigned row = (lane_id & 0x7) + ((lane_id >> 4) << 3); - const unsigned col = ((lane_id & 0xf) >> 3) << 2; - if (layout == nvcuda::wmma::mem_col_major) { - const int offset = row * ldm + col; - f.x[0] = static_cast(p[offset]); - f.x[1] = static_cast(p[offset + ldm]); - f.x[2] = static_cast(p[offset + 2 * ldm]); - f.x[3] = static_cast(p[offset + 3 + ldm]); - f.x[4] = static_cast(p[offset + 8 * ldm]); - f.x[5] = static_cast(p[offset + 9 * ldm]); - f.x[6] = static_cast(p[offset + 10 * ldm]); - f.x[7] = static_cast(p[offset + 11 * ldm]); - } else { - const int offset = row * ldm + col; - float2 *src = (float2 *)(p + offset); - float2 *dst = (float2 *)f.x; - dst[0] = src[0]; - dst[1] = src[2]; - } -} - -template -__device__ inline void store_matrix_sync(T *const p, - const fragment &f, - const unsigned ldm, const nvcuda::wmma::layout_t layout) { - const unsigned lane_id = get_lane_id(); - const unsigned row = (lane_id & 0x7) + ((lane_id >> 4) << 3); - const unsigned col = ((lane_id & 0xf) >> 3) << 2; - if (layout == nvcuda::wmma::mem_col_major) { - const int offset = row * ldm + col; - p[offset + 0] = static_cast(f.x[0]); - p[offset + ldm] = static_cast(f.x[1]); - p[offset + 2 * ldm] = static_cast(f.x[2]); - p[offset + 3 * ldm] = static_cast(f.x[3]); - p[offset + 8 * ldm] = static_cast(f.x[4]); - p[offset + 9 * ldm] = static_cast(f.x[5]); - p[offset + 
10 * ldm] = static_cast(f.x[6]); - p[offset + 11 * ldm] = static_cast(f.x[7]); - } else { - const int offset = row * ldm + col; - float2 *dst = (float2 *)(p + offset); - float2 *src = (float2 *)f.x; - dst[0] = src[0]; - dst[2] = src[1]; - } -} - -// ptx_isa_7.1.pdf page277 -template -__device__ inline void load_matrix_sync(fragment &f, T *const p, - const unsigned ldm, const nvcuda::wmma::layout_t layout) { - const unsigned lane_id = get_lane_id(); - const unsigned row = (lane_id & 0x5) + ((lane_id >> 4) << 3); - const unsigned col = ((lane_id & 0x2)) + ((lane_id & 0x8) >> 1); - if (layout == nvcuda::wmma::mem_col_major) { - const int offset = row * ldm + col; - f.x[0] = static_cast(p[offset]); - f.x[1] = static_cast(p[offset + ldm]); - f.x[2] = static_cast(p[offset + 2]); - f.x[3] = static_cast(p[offset + 2 + ldm]); - f.x[4] = static_cast(p[offset + 8 * ldm]); - f.x[5] = static_cast(p[offset + 9 * ldm]); - f.x[6] = static_cast(p[offset + 2 + 8 * ldm]); - f.x[7] = static_cast(p[offset + 2 + 9 * ldm]); - } else { - const int offset = row * ldm + col; - float2 *src = (float2 *)(p + offset); - float2 *dst = (float2 *)f.x; - dst[0] = src[0]; - dst[1] = src[ldm]; - dst[2] = src[4]; - dst[3] = src[ldm + 4]; - } -} - -template <> -__device__ inline void load_matrix_sync(fragment &f, half *const p, - const unsigned ldm, const nvcuda::wmma::layout_t layout) { - const unsigned lane_id = get_lane_id(); - const unsigned row = (lane_id & 0x5) + ((lane_id >> 4) << 3); - const unsigned col = ((lane_id & 0x2)) + ((lane_id & 0x8) >> 1); - if (layout == nvcuda::wmma::mem_col_major) { - const int offset = row * ldm + col; - f.x[0] = static_cast(p[offset]); - f.x[1] = static_cast(p[offset + ldm]); - f.x[2] = static_cast(p[offset + 2]); - f.x[3] = static_cast(p[offset + 2 + ldm]); - f.x[4] = static_cast(p[offset + 8 * ldm]); - f.x[5] = static_cast(p[offset + 9 * ldm]); - f.x[6] = static_cast(p[offset + 2 + 8 * ldm]); - f.x[7] = static_cast(p[offset + 2 + 9 * ldm]); - } else { - const 
int offset = row * ldm + col; - half2 *src = (half2 *)(p + offset); - float2 *dst = (float2 *)f.x; - dst[0] = __half22float2(src[0]); - dst[1] = __half22float2(src[ldm]); - dst[2] = __half22float2(src[4]); - dst[3] = __half22float2(src[ldm + 4]); - } -} - -template -__device__ inline void store_matrix_sync(T *const p, fragment &f, - const unsigned ldm, const nvcuda::wmma::layout_t layout) { - const unsigned lane_id = get_lane_id(); - const unsigned row = (lane_id & 0x5) + ((lane_id >> 4) << 3); - const unsigned col = ((lane_id & 0x2)) + ((lane_id & 0x8) >> 1); - - if (layout == nvcuda::wmma::mem_col_major) { - const int offset = row * ldm + col; - p[offset + 0] = static_cast(f.x[0]); - p[offset + ldm] = static_cast(f.x[1]); - p[offset + 2] = static_cast(f.x[2]); - p[offset + 2 + ldm] = static_cast(f.x[3]); - p[offset + 8 * ldm] = static_cast(f.x[4]); - p[offset + 9 * ldm] = static_cast(f.x[5]); - p[offset + 2 + 8 * ldm] = static_cast(f.x[6]); - p[offset + 2 + 9 * ldm] = static_cast(f.x[7]); - } else { - const int offset = row * ldm + col; - float2 *dst = (float2 *)(p + offset); - float2 *src = (float2 *)f.x; - dst[0] = src[0]; - dst[ldm] = src[1]; - dst[4] = src[2]; - dst[ldm + 4] = src[3]; - } -} - -template <> -__device__ inline void store_matrix_sync(half *const p, fragment &f, - const unsigned ldm, const nvcuda::wmma::layout_t layout) { - const unsigned lane_id = get_lane_id(); - const unsigned row = (lane_id & 0x5) + ((lane_id >> 4) << 3); - const unsigned col = ((lane_id & 0x2)) + ((lane_id & 0x8) >> 1); - - if (layout == nvcuda::wmma::mem_col_major) { - const int offset = row * ldm + col; - p[offset + 0] = __float2half(f.x[0]); - p[offset + ldm] = __float2half(f.x[1]); - p[offset + 2] = __float2half(f.x[2]); - p[offset + 2 + ldm] = __float2half(f.x[3]); - p[offset + 8 * ldm] = __float2half(f.x[4]); - p[offset + 9 * ldm] = __float2half(f.x[5]); - p[offset + 2 + 8 * ldm] = __float2half(f.x[6]); - p[offset + 2 + 9 * ldm] = __float2half(f.x[7]); - } else { - 
const int offset = row * ldm + col; - half2 *dst = (half2 *)(p + offset); - float2 *src = (float2 *)f.x; - dst[0] = __float22half2_rn(src[0]); - dst[ldm] = __float22half2_rn(src[1]); - dst[4] = __float22half2_rn(src[2]); - dst[ldm + 4] = __float22half2_rn(src[3]); - } -} - -#define MMA_M16N16K4_F32_F32(A_LAYOUT, B_LAYOUT) \ -__device__ inline void mma_sync( \ - fragment & d, \ - const fragment & a, \ - const fragment & b, \ - const fragment & c){ \ - asm("{mma.sync.aligned.m8n8k4." #A_LAYOUT "." #B_LAYOUT" \ - .f32.f16.f16.f32 {%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}," \ - "{%10, %11}, {%12, %13, %14, %15, %16, %17, %18, %19};}" \ - :"=f"(d.x[0]), \ - "=f"(d.x[1]), \ - "=f"(d.x[2]), \ - "=f"(d.x[3]), \ - "=f"(d.x[4]), \ - "=f"(d.x[5]), \ - "=f"(d.x[6]), \ - "=f"(d.x[7]) \ - :"r"(*reinterpret_cast(a.x)), \ - "r"(*reinterpret_cast(a.x + 2)), \ - "r"(*reinterpret_cast(b.x)), \ - "r"(*reinterpret_cast(b.x + 2)), \ - "f"(c.x[0]), \ - "f"(c.x[1]), \ - "f"(c.x[2]), \ - "f"(c.x[3]), \ - "f"(c.x[4]), \ - "f"(c.x[5]), \ - "f"(c.x[6]), \ - "f"(c.x[7])); \ - } \ - -MMA_M16N16K4_F32_F32(col, col); -MMA_M16N16K4_F32_F32(row, col); -MMA_M16N16K4_F32_F32(col, row); -MMA_M16N16K4_F32_F32(row, row); - -#define MMA_M16N16K4_F16_F16(A_LAYOUT, B_LAYOUT) \ -__device__ inline void mma_sync( \ - fragment & d, \ - const fragment & a, \ - const fragment & b, \ - const fragment & c){ \ - asm("{mma.sync.aligned.m8n8k4." #A_LAYOUT "." 
#B_LAYOUT" \ - .f16.f16.f16.f16 {%0, %1, %2, %3}, {%4, %5}," \ - "{%6, %7}, {%8, %9, %10, %11};}" \ - :"=r"(*reinterpret_cast(d.x)), \ - "=r"(*reinterpret_cast(d.x + 2)), \ - "=r"(*reinterpret_cast(d.x + 4)), \ - "=r"(*reinterpret_cast(d.x + 6)) \ - :"r"(*reinterpret_cast(a.x)), \ - "r"(*reinterpret_cast(a.x + 2)), \ - "r"(*reinterpret_cast(b.x)), \ - "r"(*reinterpret_cast(b.x + 2)), \ - "r"(*reinterpret_cast(c.x)), \ - "r"(*reinterpret_cast(c.x + 2)), \ - "r"(*reinterpret_cast(c.x + 4)), \ - "r"(*reinterpret_cast(c.x + 6))); \ - } \ - -MMA_M16N16K4_F16_F16(col, col); -MMA_M16N16K4_F16_F16(row, col); -MMA_M16N16K4_F16_F16(col, row); -MMA_M16N16K4_F16_F16(row, row); - -template -__device__ inline void print_fragment(const akg::wmma - ::fragment &frag, const char *name = "") { - if ((threadIdx.x & 0x1f) == 0) { - if (name[0] != '\0') { - printf("%s = \n", name); - } - } - - for (unsigned i = 0; i < warpSize; i++) { - if (i == (threadIdx.x & 0x1f)) { - printf("threadIdx.x = %d", threadIdx.x); - for (unsigned j = 0; j < frag.num_elements; j++) { - float v; - if (sizeof(T) == 2) { - v = __half2float(frag.x[j]); - } else { - v = frag.x[j]; - } - if (v == 0.0f) { - printf(" %f ", 0.0f); - } else if (v > 0) { - printf(" %f ", v); - } else { - printf("%f ", v); - } - } - printf("\n"); - } - __syncthreads(); - } -} - -} // namespace akg -} // namespace wmma - -#endif // __WMMA_M16N16_K4_HPP__ diff --git a/src/akg_mma_lib/wmma.hpp b/src/akg_mma_lib/wmma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..86ff3eb1c53f96d2ad6d9d296882e2200ed0a189 --- /dev/null +++ b/src/akg_mma_lib/wmma.hpp @@ -0,0 +1,678 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * WMMA API Extension + * CUDA provides an experimental PTX instruction mma.m8n8k4 which compute matrix FMA use Tensor Core + * See detail: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions + * This extension provides its C++ interface + * + * Sample: + * #include "wmma.hpp" + * __global__ void wmma_kernel(float *c_ptr, const half *a_ptr, const half *b_ptr) { + * akg::wmma::fragment frag_a; + * akg::wmma::fragment frag_b; + * akg::wmma::fragment frag_c; + * akg::wmma::fragment frag_d; + * + * akg::wmma::fill_fragment(frag_c, 0.0f); + * akg::wmma::load_matrix_sync(frag_a, a_ptr, 16); + * akg::wmma::load_matrix_sync(frag_b, b_ptr, 16); + * + * akg::wmma::mma_sync(frag_d, frag_a, frag_b, frag_c); + * akg::wmma::store_matrix_sync(c_ptr, frag_d, N, nvcuda::wmma::mem_row_major); + * } + */ + +#ifndef __WMMA_HPP__ +#define __WMMA_HPP__ +#include +#include + +namespace akg { +namespace wmma { + +template +class Vector2; + +template <> +class Vector2 { + public: + typedef half2 Vector2Type; +}; + +template <> +class Vector2 { + public: + typedef float2 Vector2Type; +}; + +template +class CastTypeFun; + +template +class CastTypeFun { + public: + typedef T1 CastType; +}; + +template +class CastTypeFun { + public: + typedef T2 CastType; +}; + +template +class CastValueType { + public: + typedef typename CastTypeFun<(k % 8 == 0), T1, T2>::CastType CastType; +}; + +template +inline __device__ T cast(const float src) { + return src; +} + +template <> +inline __device__ half cast(const float src) { + 
return __float2half_rn(src); +} + +template +inline __device__ T cast(const float2 src) { + return src; +} + +template <> +inline __device__ half2 cast(const float2 src) { + return __float22half2_rn(src); +} + +template +inline __device__ T cast(const half src) { + return src; +} + +template <> +inline __device__ float cast(const half src) { + return __half2float(src); +} + +template +inline __device__ T cast(const half2 src) { + return src; +} + +template <> +inline __device__ float2 cast(const half2 src) { + return __half22float2(src); +} + +inline __device__ unsigned get_lane_id() { + unsigned lane_id; + asm volatile (R"({mov.s32 %0, %laneid;})" : "=r"(lane_id)); + return lane_id; +} + +template +struct __align__(4) __frag_base { + T x[size]; + enum { num_elements = size }; +}; + +template +class fragment; +template +class fragment : public __frag_base {}; +template +class fragment : public __frag_base {}; +template <> +class fragment : public __frag_base {}; +template <> +class fragment : public __frag_base {}; +template <> +class fragment : public __frag_base {}; + +template +__device__ inline void fill_fragment(__frag_base &f, const S v) { + #pragma unroll + for (unsigned i = 0; i < f.num_elements; i++) { + f.x[i] = cast(v); + } +} + +template +__device__ inline void load_matrix_sync(fragment &f, + const T *const p, const unsigned ldm) { + const unsigned lane_id = get_lane_id(); + const unsigned row = lane_id & 0x3; + const unsigned col = (lane_id & 0x4) + ((lane_id >> 4) << 3); + const unsigned offset = row * ldm + col; + + using Type = typename CastValueType::CastType; + Type *src = (Type *)(p + offset); + Type *dst = (Type *)f.x; + #pragma unroll + for (int i = 0; i < k / 4; i++) { + dst[i] = src[i * ldm]; + } +} + +template +__device__ inline void load_matrix_sync(fragment &f, + const T *const p, const unsigned ldm) { + const unsigned lane_id = get_lane_id(); + const unsigned row = lane_id & 0x3; + const unsigned col = (lane_id & 0x8) + (lane_id & 0x10); 
+ const unsigned offset = row * ldm + col; + + using Type = float4; + Type *src = (Type *)(p + offset); + Type *dst = (Type *)f.x; + dst[0] = src[0]; +} + +template +__device__ inline void load_matrix_sync(fragment &f, + const T *const p, const unsigned ldm) { + const unsigned lane_id = get_lane_id(); + const unsigned row = (lane_id & 0x7) + ((lane_id >> 4) << 3); + const unsigned offset = row * ldm; + + using Type = typename CastValueType::CastType; + Type *src = (Type *)(p + offset); + Type *dst = (Type *)f.x; + dst[0] = src[0]; +} + +template +__device__ inline void load_matrix_sync(fragment &f, + const T *const p, const unsigned ldm) { + const unsigned lane_id = get_lane_id(); + const unsigned row = (lane_id & 0x3) + ((lane_id & 0x18) >> 1); + const unsigned offset = row * ldm; + + using Type = typename CastValueType::CastType; + Type *src = (Type *)(p + offset); + Type *dst = (Type *)f.x; + dst[0] = src[0]; +} + +template +__device__ inline void load_matrix_sync(fragment &f, + const T *const p, const unsigned ldm) { + const unsigned lane_id = get_lane_id(); + const unsigned row = lane_id & 0x3; + const unsigned col = (lane_id >> 3) << 2; + const unsigned offset = row * ldm + col; + + using Type = typename CastValueType::CastType; + Type *src = (Type *)(p + offset); + Type *dst = (Type *)f.x; + #pragma unroll + for (int i = 0; i < k / 4; i++) { + dst[i] = src[i * ldm]; + } +} + +template +__device__ inline void load_matrix_sync(fragment &f, + const T *const p, const unsigned ldm) { + const unsigned lane_id = get_lane_id(); + const unsigned row = lane_id & 0x3; + const unsigned col = ((lane_id & 0x4) << 1) + (lane_id & 0x10); + const unsigned offset = row * ldm + col; + + using Type = float4; + Type *src = (Type *)(p + offset); + Type *dst = (Type *)f.x; + dst[0] = src[0]; +} + +template +__device__ inline void load_matrix_sync(fragment &f, + const T *const p, const unsigned ldm, const nvcuda::wmma::layout_t layout) { + const unsigned lane_id = get_lane_id(); + 
const unsigned row = (lane_id & 0x7) + ((lane_id >> 4) << 3); + const unsigned col = ((lane_id & 0xf) >> 3) << 2; + if (layout == nvcuda::wmma::mem_col_major) { + const int offset = col * ldm + row; + f.x[0] = cast(p[offset]); + f.x[1] = cast(p[offset + ldm]); + f.x[2] = cast(p[offset + 2 * ldm]); + f.x[3] = cast(p[offset + 3 + ldm]); + f.x[4] = cast(p[offset + 8 * ldm]); + f.x[5] = cast(p[offset + 9 * ldm]); + f.x[6] = cast(p[offset + 10 * ldm]); + f.x[7] = cast(p[offset + 11 * ldm]); + } else { + const int offset = row * ldm + col; + float2 *src = (float2 *)(p + offset); + float2 *dst = (float2 *)f.x; + dst[0] = src[0]; + dst[1] = src[2]; + } +} + +template +__device__ inline void load_matrix_sync(fragment &f, + const T *const p, const unsigned ldm, const nvcuda::wmma::layout_t layout) { + const unsigned lane_id = get_lane_id(); + const unsigned row = (lane_id & 0x5) + ((lane_id >> 4) << 3); + const unsigned col = ((lane_id & 0x2)) + ((lane_id & 0x8) >> 1); + if (layout == nvcuda::wmma::mem_col_major) { + const int offset = col * ldm + row; + f.x[0] = cast(p[offset]); + f.x[1] = cast(p[offset + ldm]); + f.x[2] = cast(p[offset + 2]); + f.x[3] = cast(p[offset + 2 + ldm]); + f.x[4] = cast(p[offset + 8 * ldm]); + f.x[5] = cast(p[offset + 9 * ldm]); + f.x[6] = cast(p[offset + 2 + 8 * ldm]); + f.x[7] = cast(p[offset + 2 + 9 * ldm]); + } else { + using SrcType = typename Vector2::Vector2Type; + const int offset = row * ldm + col; + SrcType *src = (SrcType *)(p + offset); + float2 *dst = (float2 *)f.x; + dst[0] = cast(src[0]); + dst[1] = cast(src[ldm]); + dst[2] = cast(src[4]); + dst[3] = cast(src[ldm + 4]); + } +} + +template +__device__ inline void load_matrix_sync(fragment &f, + const T *const p, const unsigned ldm, const nvcuda::wmma::layout_t layout) { + const unsigned lane_id = get_lane_id(); + const unsigned row = (lane_id & 0x1) + (lane_id & 0x18); + const unsigned col = ((lane_id & 0x2)) + ((lane_id & 0x4) << 1); + + if (layout == nvcuda::wmma::mem_col_major) { 
+ const int offset = col * ldm + row; + f.x[0] = cast(p[offset + 0]); + f.x[1] = cast(p[offset + ldm]); + f.x[2] = cast(p[offset + 2]); + f.x[3] = cast(p[offset + 2 + ldm]); + f.x[4] = cast(p[offset + 16 * ldm]); + f.x[5] = cast(p[offset + 17 * ldm]); + f.x[6] = cast(p[offset + 16 * ldm + 2]); + f.x[7] = cast(p[offset + 17 * ldm + 2]); + f.x[8] = cast(p[offset + 4 * ldm]); + f.x[9] = cast(p[offset + 5 * ldm]); + f.x[10] = cast(p[offset + 4 * ldm + 2]); + f.x[11] = cast(p[offset + 5 * ldm + 2]); + f.x[12] = cast(p[offset + 20 * ldm]); + f.x[13] = cast(p[offset + 21 * ldm]); + f.x[14] = cast(p[offset + 20 * ldm + 2]); + f.x[15] = cast(p[offset + 21 * ldm + 2]); + f.x[16] = cast(p[offset + 4]); + f.x[17] = cast(p[offset + ldm + 4]); + f.x[18] = cast(p[offset + 6]); + f.x[19] = cast(p[offset + ldm + 6]); + f.x[20] = cast(p[offset + 16 * ldm + 4]); + f.x[21] = cast(p[offset + 17 * ldm + 4]); + f.x[22] = cast(p[offset + 16 * ldm + 6]); + f.x[23] = cast(p[offset + 17 * ldm + 6]); + f.x[24] = cast(p[offset + 4 * ldm + 4]); + f.x[25] = cast(p[offset + 5 * ldm + 4]); + f.x[26] = cast(p[offset + 4 * ldm + 6]); + f.x[27] = cast(p[offset + 5 * ldm + 6]); + f.x[28] = cast(p[offset + 20 * ldm + 4]); + f.x[29] = cast(p[offset + 21 * ldm + 4]); + f.x[30] = cast(p[offset + 20 * ldm + 6]); + f.x[31] = cast(p[offset + 21 * ldm + 6]); + } else { + using SrcType = typename Vector2::Vector2Type; + const int offset = row * ldm + col; + SrcType *src = (SrcType *)(p + offset); + float2 *dst = (float2 *)f.x; + dst[0] = cast(src[0]); + dst[ldm] = cast(src[1]); + dst[8] = cast(src[2]); + dst[ldm + 8] = cast(src[3]); + dst[2] = cast(src[4]); + dst[ldm + 2] = cast(src[5]); + dst[10] = cast(src[6]); + dst[ldm + 10] = cast(src[7]); + dst[2 * ldm] = cast(src[8]); + dst[3 * ldm] = cast(src[9]); + dst[2 * ldm + 8] = cast(src[10]); + dst[3 * ldm + 8] = cast(src[11]); + dst[2 * ldm + 2] = cast(src[12]); + dst[3 * ldm + 2] = cast(src[13]); + dst[2 * ldm + 10] = cast(src[14]); + dst[3 * ldm + 10] = 
cast(src[15]); + } +} + +template +__device__ inline void store_matrix_sync(T *const p, + const fragment &f, + const unsigned ldm, const nvcuda::wmma::layout_t layout) { + const unsigned lane_id = get_lane_id(); + const unsigned row = (lane_id & 0x7) + ((lane_id >> 4) << 3); + const unsigned col = ((lane_id & 0xf) >> 3) << 2; + if (layout == nvcuda::wmma::mem_col_major) { + const int offset = col * ldm + row; + p[offset + 0] = cast(f.x[0]); + p[offset + ldm] = cast(f.x[1]); + p[offset + 2 * ldm] = cast(f.x[2]); + p[offset + 3 * ldm] = cast(f.x[3]); + p[offset + 8 * ldm] = cast(f.x[4]); + p[offset + 9 * ldm] = cast(f.x[5]); + p[offset + 10 * ldm] = cast(f.x[6]); + p[offset + 11 * ldm] = cast(f.x[7]); + } else { + const int offset = row * ldm + col; + float2 *dst = (float2 *)(p + offset); + float2 *src = (float2 *)f.x; + dst[0] = src[0]; + dst[2] = src[1]; + } +} + +template +__device__ inline void store_matrix_sync(T *const p, fragment &f, + const unsigned ldm, const nvcuda::wmma::layout_t layout) { + const unsigned lane_id = get_lane_id(); + const unsigned row = (lane_id & 0x5) + ((lane_id >> 4) << 3); + const unsigned col = ((lane_id & 0x2)) + ((lane_id & 0x8) >> 1); + + if (layout == nvcuda::wmma::mem_col_major) { + const int offset = col * ldm + row; + p[offset + 0] = cast(f.x[0]); + p[offset + ldm] = cast(f.x[1]); + p[offset + 2] = cast(f.x[2]); + p[offset + 2 + ldm] = cast(f.x[3]); + p[offset + 8 * ldm] = cast(f.x[4]); + p[offset + 9 * ldm] = cast(f.x[5]); + p[offset + 2 + 8 * ldm] = cast(f.x[6]); + p[offset + 2 + 9 * ldm] = cast(f.x[7]); + } else { + using DstType = typename Vector2::Vector2Type; + const int offset = row * ldm + col; + DstType *dst = (DstType *)(p + offset); + float2 *src = (float2 *)f.x; + dst[0] = cast(src[0]); + dst[ldm] = cast(src[1]); + dst[4] = cast(src[2]); + dst[ldm + 4] = cast(src[3]); + } +} + +template +__device__ inline void store_matrix_sync(T *const p, fragment &f, + const unsigned ldm, const nvcuda::wmma::layout_t layout) { + 
const unsigned lane_id = get_lane_id(); + const unsigned row = (lane_id & 0x1) + (lane_id & 0x18); + const unsigned col = ((lane_id & 0x2)) + ((lane_id & 0x4) << 1); + + if (layout == nvcuda::wmma::mem_col_major) { + const int offset = col * ldm + row; + p[offset + 0] = cast(f.x[0]); + p[offset + ldm] = cast(f.x[1]); + p[offset + 2] = cast(f.x[2]); + p[offset + 2 + ldm] = cast(f.x[3]); + p[offset + 16 * ldm] = cast(f.x[4]); + p[offset + 17 * ldm] = cast(f.x[5]); + p[offset + 16 * ldm + 2] = cast(f.x[6]); + p[offset + 17 * ldm + 2] = cast(f.x[7]); + p[offset + 4 * ldm] = cast(f.x[8]); + p[offset + 5 * ldm] = cast(f.x[9]); + p[offset + 4 * ldm + 2] = cast(f.x[10]); + p[offset + 5 * ldm + 2] = cast(f.x[11]); + p[offset + 20 * ldm] = cast(f.x[12]); + p[offset + 21 * ldm] = cast(f.x[13]); + p[offset + 20 * ldm + 2] = cast(f.x[14]); + p[offset + 21 * ldm + 2] = cast(f.x[15]); + p[offset + 4] = cast(f.x[16]); + p[offset + ldm + 4] = cast(f.x[17]); + p[offset + 6] = cast(f.x[18]); + p[offset + ldm + 6] = cast(f.x[19]); + p[offset + 16 * ldm + 4] = cast(f.x[20]); + p[offset + 17 * ldm + 4] = cast(f.x[21]); + p[offset + 16 * ldm + 6] = cast(f.x[22]); + p[offset + 17 * ldm + 6] = cast(f.x[23]); + p[offset + 4 * ldm + 4] = cast(f.x[24]); + p[offset + 5 * ldm + 4] = cast(f.x[25]); + p[offset + 4 * ldm + 6] = cast(f.x[26]); + p[offset + 5 * ldm + 6] = cast(f.x[27]); + p[offset + 20 * ldm + 4] = cast(f.x[28]); + p[offset + 21 * ldm + 4] = cast(f.x[29]); + p[offset + 20 * ldm + 6] = cast(f.x[30]); + p[offset + 21 * ldm + 6] = cast(f.x[31]); + } else { + using DstType = typename Vector2::Vector2Type; + const int offset = row * ldm + col; + DstType *dst = (DstType *)(p + offset); + float2 *src = (float2 *)f.x; + dst[0] = cast(src[0]); + dst[ldm] = cast(src[1]); + dst[8] = cast(src[2]); + dst[ldm + 8] = cast(src[3]); + dst[2] = cast(src[4]); + dst[ldm + 2] = cast(src[5]); + dst[10] = cast(src[6]); + dst[ldm + 10] = cast(src[7]); + dst[2 * ldm] = cast(src[8]); + dst[3 * ldm] = 
cast(src[9]); + dst[2 * ldm + 8] = cast(src[10]); + dst[3 * ldm + 8] = cast(src[11]); + dst[2 * ldm + 2] = cast(src[12]); + dst[3 * ldm + 2] = cast(src[13]); + dst[2 * ldm + 10] = cast(src[14]); + dst[3 * ldm + 10] = cast(src[15]); + } +} + +/* + * FP32 MMA functions for shape 16x16xk + */ +#define MMA_M16N16_F32_F32(A_LAYOUT, B_LAYOUT, K) \ + __device__ inline void mma_sync( \ + fragment &d, \ + const fragment &a, \ + const fragment &b, \ + const fragment &c) { \ + asm volatile ("{mma.sync.aligned.m8n8k4." #A_LAYOUT "." #B_LAYOUT ".f32.f16.f16.f32" \ + "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}," \ + "{%10, %11}, {%12, %13, %14, %15, %16, %17, %18, %19};}" \ + : "=f"(d.x[0]), "=f"(d.x[1]), "=f"(d.x[2]), "=f"(d.x[3]), \ + "=f"(d.x[4]), "=f"(d.x[5]), "=f"(d.x[6]), "=f"(d.x[7]) \ + : "r"(*reinterpret_cast(a.x)), \ + "r"(*reinterpret_cast(a.x + 2)), \ + "r"(*reinterpret_cast(b.x)), \ + "r"(*reinterpret_cast(b.x + 2)), "f"(c.x[0]), \ + "f"(c.x[1]), "f"(c.x[2]), "f"(c.x[3]), \ + "f"(c.x[4]), "f"(c.x[5]), "f"(c.x[6]), "f"(c.x[7])); \ + for (int k = 4; k < K; k += 4) { \ + asm volatile ("{mma.sync.aligned.m8n8k4." #A_LAYOUT "." 
#B_LAYOUT ".f32.f16.f16.f32" \ + "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}," \ + "{%10, %11}, {%12, %13, %14, %15, %16, %17, %18, %19};}" \ + : "=f"(d.x[0]), "=f"(d.x[1]), "=f"(d.x[2]), "=f"(d.x[3]), \ + "=f"(d.x[4]), "=f"(d.x[5]), "=f"(d.x[6]), "=f"(d.x[7]) \ + : "r"(*reinterpret_cast(a.x + k)), \ + "r"(*reinterpret_cast(a.x + k + 2)), \ + "r"(*reinterpret_cast(b.x + k)), \ + "r"(*reinterpret_cast(b.x + k + 2)), \ + "f"(d.x[0]), "f"(d.x[1]), "f"(d.x[2]), "f"(d.x[3]), \ + "f"(d.x[4]), "f"(d.x[5]), "f"(d.x[6]) "f"(d.x[7])); \ + } \ + } + +MMA_M16N16_F32_F32(col, col, 4); +MMA_M16N16_F32_F32(row, col, 4); +MMA_M16N16_F32_F32(col, row, 4); +MMA_M16N16_F32_F32(row, row, 4); +MMA_M16N16_F32_F32(col, col, 8); +MMA_M16N16_F32_F32(row, col, 8); +MMA_M16N16_F32_F32(col, row, 8); +MMA_M16N16_F32_F32(row, row, 8); + +/* + * FP16 MMA functions for shape 16x16xk + */ +#define MMA_M16N16_F16_F16(A_LAYOUT, B_LAYOUT, K) \ + __device__ inline void mma_sync( \ + fragment &d, \ + const fragment &a, \ + const fragment &b, \ + const fragment &c) { \ + asm volatile ("{mma.sync.aligned.m8n8k4." #A_LAYOUT "." #B_LAYOUT ".f16.f16.f16.f16" \ + "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%8, %9, %10, %11};}" \ + : "=r"(*reinterpret_cast(d.x)), \ + "=r"(*reinterpret_cast(d.x + 2)), \ + "=r"(*reinterpret_cast(d.x + 4)), \ + "=r"(*reinterpret_cast(d.x + 6)) \ + : "r"(*reinterpret_cast(a.x)), \ + "r"(*reinterpret_cast(a.x + 2)), \ + "r"(*reinterpret_cast(b.x)), \ + "r"(*reinterpret_cast(b.x + 2)), \ + "r"(*reinterpret_cast(c.x)), \ + "r"(*reinterpret_cast(c.x + 2)), \ + "r"(*reinterpret_cast(c.x + 4)), \ + "r"(*reinterpret_cast(c.x + 6))); \ + for (int k = 4; k < K; k += 4) { \ + asm volatile ("{mma.sync.aligned.m8n8k4." #A_LAYOUT "." 
#B_LAYOUT ".f16.f16.f16.f16" \ + "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%8, %9, %10, %11};}" \ + : "=r"(*reinterpret_cast(d.x)), \ + "=r"(*reinterpret_cast(d.x + 2)), \ + "=r"(*reinterpret_cast(d.x + 4)), \ + "=r"(*reinterpret_cast(d.x + 6)) \ + : "r"(*reinterpret_cast(a.x + k)), \ + "r"(*reinterpret_cast(a.x + k + 2)), \ + "r"(*reinterpret_cast(b.x + k)), \ + "r"(*reinterpret_cast(b.x + k + 2)), \ + "r"(*reinterpret_cast(d.x)), \ + "r"(*reinterpret_cast(d.x + 2)), \ + "r"(*reinterpret_cast(d.x + 4)), \ + "r"(*reinterpret_cast(d.x + 6))); \ + } \ + } + +MMA_M16N16_F16_F16(col, col, 4); +MMA_M16N16_F16_F16(row, col, 4); +MMA_M16N16_F16_F16(col, row, 4); +MMA_M16N16_F16_F16(row, row, 4); +MMA_M16N16_F16_F16(col, col, 8); +MMA_M16N16_F16_F16(row, col, 8); +MMA_M16N16_F16_F16(col, row, 8); +MMA_M16N16_F16_F16(row, row, 8); + +#define MMA_M32N32K4_F32_F32_(A_LAYOUT, B_LAYOUT, STEP) \ + asm volatile ("{mma.sync.aligned.m8n8k4." #A_LAYOUT "." #B_LAYOUT ".f32.f16.f16.f32" \ + "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}," \ + "{%10, %11}, {%12, %13, %14, %15, %16, %17, %18, %19};}" \ + : "=f"(d.x[0 + (STEP << 3)]), "=f"(d.x[1 + (STEP << 3)]), \ + "=f"(d.x[2 + (STEP << 3)]), "=f"(d.x[3 + (STEP << 3)]), \ + "=f"(d.x[4 + (STEP << 3)]), "=f"(d.x[5 + (STEP << 3)]), \ + "=f"(d.x[6 + (STEP << 3)]), "=f"(d.x[7 + (STEP << 3)]) \ + : "r"(*reinterpret_cast(a.x + ((STEP & 0x2) << 1))), \ + "r"(*reinterpret_cast(a.x + ((STEP & 0x2) << 1) + 2)), \ + "r"(*reinterpret_cast(b.x + ((STEP & 0x1) << 2))), \ + "r"(*reinterpret_cast(b.x + ((STEP & 0x1) << 2) + 2)), \ + "f"(c.x[0 + (STEP << 3)]), "f"(c.x[1 + (STEP << 3)]), \ + "f"(c.x[2 + (STEP << 3)]), "f"(c.x[3 + (STEP << 3)]), \ + "f"(c.x[4 + (STEP << 3)]), "f"(c.x[5 + (STEP << 3)]), \ + "f"(c.x[6 + (STEP << 3)]), "f"(c.x[7 + (STEP << 3)])); + +/* + * FP32 MMA functions for shape 32x32x4 + */ +#define MMA_M32N32K4_F32_F32(A_LAYOUT, B_LAYOUT) \ + __device__ inline void mma_sync( \ + fragment &d, \ + const fragment &a, \ + const fragment &b, \ 
+ const fragment &c) { \ + MMA_M32N32K4_F32_F32_(A_LAYOUT, B_LAYOUT, 0) \ + MMA_M32N32K4_F32_F32_(A_LAYOUT, B_LAYOUT, 1) \ + MMA_M32N32K4_F32_F32_(A_LAYOUT, B_LAYOUT, 2) \ + MMA_M32N32K4_F32_F32_(A_LAYOUT, B_LAYOUT, 3) \ + } + +MMA_M32N32K4_F32_F32(col, row); + +template +__device__ inline void fragment_add(__frag_base &c, const __frag_base &a, const __frag_base &b) { + #pragma unroll + for (unsigned i = 0; i < c.num_elements; i++) { + c.x[i] = a.x[i] + b.x[i]; + } +} + +template +__device__ inline void fragment_sub(__frag_base &c, const __frag_base &a, const __frag_base &b) { + #pragma unroll + for (unsigned i = 0; i < c.num_elements; i++) { + c.x[i] = a.x[i] - b.x[i]; + } +} + +template +__device__ inline void fragment_mul(__frag_base &c, const __frag_base &a, const __frag_base &b) { + #pragma unroll + for (unsigned i = 0; i < c.num_elements; i++) { + c.x[i] = a.x[i] * b.x[i]; + } +} + +template +__device__ inline void fragment_div(__frag_base &c, const __frag_base &a, const __frag_base &b) { + #pragma unroll + for (unsigned i = 0; i < c.num_elements; i++) { + c.x[i] = a.x[i] / b.x[i]; + } +} + +template +__device__ inline void fragment_add(__frag_base &c, const __frag_base &a, const T b) { + #pragma unroll + for (unsigned i = 0; i < c.num_elements; i++) { + c.x[i] = a.x[i] + b; + } +} + +template +__device__ inline void fragment_sub(__frag_base &c, const __frag_base &a, const T b) { + #pragma unroll + for (unsigned i = 0; i < c.num_elements; i++) { + c.x[i] = a.x[i] - b; + } +} + +template +__device__ inline void fragment_mul(__frag_base &c, const __frag_base &a, const T b) { + #pragma unroll + for (unsigned i = 0; i < c.num_elements; i++) { + c.x[i] = a.x[i] * b; + } +} + +template +__device__ inline void fragment_div(__frag_base &c, const __frag_base &a, const T b) { + #pragma unroll + for (unsigned i = 0; i < c.num_elements; i++) { + c.x[i] = a.x[i] / b; + } +} + +} // namespace wmma +} // namespace akg + +#endif // __WMMA_HPP__ diff --git 
a/src/api/api_pass.cc b/src/api/api_pass.cc index 4186a3d3ce6cb4d6e3917f4778f6882289b9a285..b01084305938fb2bc296c3cd17a0d800f9419aae 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -131,5 +131,6 @@ REGISTER_PASS(SinkAllocate); REGISTER_PASS(StrideKernelOp); REGISTER_PASS(UnifyLoopVars); REGISTER_PASS(TileCoverCorrect); +REGISTER_PASS(ReconstructLayout); } // namespace ir } // namespace akg diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index f5cb3b01007b6c89a0c2466def352a8c990f6264..0da197a2532a80d3035633a0c4cd1ff7e79f72f9 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -542,6 +542,7 @@ NodeRef LowerStmt(Schedule sch, const Array &in_args, const Arrayinstrument_bound_checkers); @@ -557,11 +558,11 @@ NodeRef LowerStmt(Schedule sch, const Array &in_args, const Arraydouble_buffer_split_loop, - g_attrs.GetBool(kEnableTransferBuffer, false)); + g_attrs.GetBool(kEnableDoubleBuffer, false)); stmt = NEXT_PASS(StorageRewrite, stmt); if (target_platform->device_type == kDLGPU && polyhedral) { diff --git a/src/composite/composite.cc b/src/composite/composite.cc index ac0b6840973c03ea76f55d479e0d79edc09651f4..29f6fcbc0db84d5c4e0a1cfc0b381f699ce4af31 100644 --- a/src/composite/composite.cc +++ b/src/composite/composite.cc @@ -54,9 +54,14 @@ class Emitter : public IRVisitor { real_input.push_back(input); } } + if (op_name == "MatMul") { + op_name = "BatchMatMul"; + } const auto *topi_f = air::runtime::Registry::Get(op_name); - if (topi_f == nullptr) { - topi_f = air::runtime::Registry::Get(opt_.target + '_' + op_name); + if (topi_f == nullptr && opt_.target != "") { + std::string target = opt_.target; + target[0] = std::toupper(target[0]); + topi_f = air::runtime::Registry::Get(target + op_name); } CHECK(topi_f) << "Akg topi has no op: " << op_name; if (op_name == "Reshape") { // reshape's attr may have shape [-1], it will cause error. 
diff --git a/src/composite/composite_topi.cc b/src/composite/composite_topi.cc index a66621d9c01eb9a239d5a397576a56f7c4b64d7b..88a79a4118e89a207f3440e43b426139fda0aaa7 100644 --- a/src/composite/composite_topi.cc +++ b/src/composite/composite_topi.cc @@ -607,7 +607,7 @@ TVM_REGISTER_GLOBAL("BroadcastTo").set_body([](TVMArgs args, TVMRetValue *rv) { } }); -TVM_REGISTER_GLOBAL("cuda_BatchMatMul").set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("CudaBatchMatMul").set_body([](TVMArgs args, TVMRetValue *rv) { CHECK_GE(args.size(), 2); auto inputs = args[0].operator Array(); auto attrs = args[1].operator OpAttr(); @@ -618,6 +618,8 @@ TVM_REGISTER_GLOBAL("cuda_BatchMatMul").set_body([](TVMArgs args, TVMRetValue *r auto right_matrix = Downcast(inputs[1]); CHECK(attrs.count("transpose_a")); CHECK(attrs.count("transpose_b")); + CHECK(attrs.count("dst_type")); + auto dst_type = GetString(attrs["dst_type"]); bool transpose_a = static_cast(ir::GetInt32Const(Downcast(attrs["transpose_a"]))); bool transpose_b = static_cast(ir::GetInt32Const(Downcast(attrs["transpose_b"]))); auto left_shape = left_matrix->shape; @@ -625,8 +627,8 @@ TVM_REGISTER_GLOBAL("cuda_BatchMatMul").set_body([](TVMArgs args, TVMRetValue *r CHECK_EQ(left_shape.size(), right_shape.size()); auto type_checker = [](const Tensor &input_data, const std::string name) { - if (input_data->dtype != Float(16) && input_data->dtype != Float(32)) { - LOG(FATAL) << "dtype of " << name << " should be float16 or float32"; + if (input_data->dtype != Float(16)) { + LOG(FATAL) << "dtype of input tensor " << name << " should be float16"; } }; @@ -655,7 +657,7 @@ TVM_REGISTER_GLOBAL("cuda_BatchMatMul").set_body([](TVMArgs args, TVMRetValue *r size_t batch_dim = 0; IterVar reduce_k; auto fcompute = [&left_matrix, &right_matrix, &transpose_a, &transpose_b, &reduce_k, - &batch_dim](const Array &indices) { + &batch_dim, &dst_type](const Array &indices) { Array left_indice; Array right_indice; for (size_t i = 0; i < 
batch_dim; ++i) { @@ -684,6 +686,11 @@ TVM_REGISTER_GLOBAL("cuda_BatchMatMul").set_body([](TVMArgs args, TVMRetValue *r Expr right_buffer = Call::make(right_matrix->dtype, right_matrix->op->name, right_indice, Call::CallType::Halide, right_matrix->op, right_matrix->value_index); + if (dst_type == "float32") { + left_buffer = Cast::make(Float(32), left_buffer); + right_buffer = Cast::make(Float(32), right_buffer); + } + auto matrix_mul = Mul::make(left_buffer, right_buffer); Array reduces; reduces.push_back(reduce_k); @@ -701,7 +708,7 @@ TVM_REGISTER_GLOBAL("cuda_BatchMatMul").set_body([](TVMArgs args, TVMRetValue *r }); // only support fractal_zN: [ko mo mi ki] * [no ko ki ni] = [no mo mi ni] -TVM_REGISTER_GLOBAL("aicore_BatchMatMul").set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("AicoreBatchMatMul").set_body([](TVMArgs args, TVMRetValue *rv) { CHECK_GE(args.size(), 2); auto attrs = args[1].operator OpAttr(); CHECK(attrs.count("transpose_a")); diff --git a/src/composite/optimize/delete_cast.cc b/src/composite/optimize/delete_cast.cc new file mode 100644 index 0000000000000000000000000000000000000000..4a89613f0528fe94688d49680e15e367d8d65c52 --- /dev/null +++ b/src/composite/optimize/delete_cast.cc @@ -0,0 +1,157 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "composite/optimize/delete_cast.h" + +namespace akg { +class DeleteCastMatcher : public IRVisitor { + public: + void Visit_(const Provide *op) { + auto call = op->value.as(); + if (call == nullptr) { + return IRVisitor::Visit_(op); + } + if (call->name == "Cast") { + auto in = call->args[0].as(); + if (in && in->name == matmul_output_) { + cast_matmul_output_ = true; + cast_map_[op->func.as()->name] = in->name; + cast_func_map_[op->func.as()->name] = in->func; + cast_dtype_[in->name] = op->func.as()->dtype; + } + } else if (call->name == "BatchMatMul") { + if (auto gemm = op->func.as()) { + matmul_output_ = gemm->name; + } + } + return IRVisitor::Visit_(op); + } + + inline bool Matched() { return cast_matmul_output_; } + + friend class DeleteCastMutator; + + private: + bool cast_matmul_output_{false}; + std::unordered_map cast_map_; + std::unordered_map cast_func_map_; + std::unordered_map cast_dtype_; + std::string matmul_output_; +}; + +// delete cast for MatMul / BatchMatMul fusion op +class DeleteCastMutator : public IRMutator { + public: + explicit DeleteCastMutator(const DeleteCastMatcher &deletecast_matcher) + : cast_map_(deletecast_matcher.cast_map_), + cast_func_map_(deletecast_matcher.cast_func_map_), + cast_dtype_(deletecast_matcher.cast_dtype_), + matmul_output_(deletecast_matcher.matmul_output_) {} + ~DeleteCastMutator() override = default; + + Stmt Mutate_(const AttrStmt *op, const Stmt &s) { + auto attrs = Downcast>(op->node); + if (attrs.find("is_backed_cast") != attrs.end()) { + for (auto &val : attrs) { + std::string key = val.first; + auto pos = key.find("_format"); + if (pos != std::string::npos) { + std::string src_tensor = key.substr(0, pos); + if (src_tensor == matmul_output_) { + return IRMutator::Mutate(op->body); + } + } + } + } + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as(); + if (op == nullptr) { + return stmt; + } + + for (auto &val : attrs) { + std::string key = val.first; + auto pos = key.find("_format"); + 
if (pos != std::string::npos) { + std::string src_tensor = key.substr(0, pos); + if (cast_map_.find(src_tensor) != cast_map_.end()) { + std::string dst_tensor = cast_map_[src_tensor] + "_format"; + attrs.Set(dst_tensor, val.second); + } + } + if (val.first == "Akg") { + attrs.Set("dst_type", Expr("float32")); + } + } + return AttrStmt::make(attrs, op->attr_key, op->value, op->body); + } + + Stmt Mutate_(const Block *op, const Stmt &s) { + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as(); + if (op == nullptr) { + return stmt; + } + auto pro = op->first.as(); + if (pro == nullptr) { + return stmt; + } + if (auto call = pro->value.as()) { + CHECK(call->args[0].as() != nullptr) << "Cast tensor is not Call op"; + if (call->name == "Cast" && cast_dtype_.find(call->args[0].as()->name) != cast_dtype_.end()) { + return op->rest; + } + } + return stmt; + } + + Expr Mutate_(const Call *op, const Expr &e) final { + Array args; + for (const auto &arg : op->args) { + if (auto tensor = arg.as()) { + if (cast_map_.find(tensor->name) != cast_map_.end()) { + args.push_back(Call::make( + cast_dtype_.at(cast_map_.at(tensor->name)), cast_map_.at(tensor->name), tensor->args, + tensor->call_type, cast_func_map_.at(tensor->name))); + } else { + args.push_back(arg); + } + } else { + args.push_back(arg); + } + } + if (cast_map_.find(op->name) != cast_map_.end()) { + return Call::make(cast_dtype_.at(cast_map_.at(op->name)), cast_map_.at(op->name), + args, op->call_type, cast_func_map_.at(op->name), op->value_index); + } + return Call::make(op->type, op->name, args, op->call_type, op->func); + } + + private: + std::unordered_map cast_map_; + std::unordered_map cast_func_map_; + std::unordered_map cast_dtype_; + std::string matmul_output_; +}; + +Stmt DeleteCast::Run(const Stmt &s) { + DeleteCastMatcher deletecast_matcher; + deletecast_matcher.Visit(s); + if (!deletecast_matcher.Matched()) { + return s; + } + return DeleteCastMutator(deletecast_matcher).Mutate(s); +} +} // namespace akg 
\ No newline at end of file diff --git a/src/composite/optimize/delete_cast.h b/src/composite/optimize/delete_cast.h new file mode 100644 index 0000000000000000000000000000000000000000..c77a1c51b0986d9bbd9d4a0c82d0b57b5712cd23 --- /dev/null +++ b/src/composite/optimize/delete_cast.h @@ -0,0 +1,28 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef COMPOSITE_OPTIMIZE_DELETE_CAST_H_ +#define COMPOSITE_OPTIMIZE_DELETE_CAST_H_ +#include "composite/optimize/optimize.h" + +namespace akg { +class DeleteCast : public CompositeOptPass { + public: + DeleteCast() { pass_name_ = __FUNCTION__; } + ~DeleteCast() = default; + Stmt Run(const Stmt &s) override; +}; +} // namespace akg +#endif // COMPOSITE_OPTIMIZE_DELETE_CAST_H_ diff --git a/src/composite/optimize/optimize.cc b/src/composite/optimize/optimize.cc index 92867bf94269e3c8d9ac1c68ecfa2dc887437d33..0cf0030b3cd2ba0d5f39a94516acf7b5c371565c 100644 --- a/src/composite/optimize/optimize.cc +++ b/src/composite/optimize/optimize.cc @@ -26,6 +26,7 @@ #include "composite/optimize/ops_combine.h" #include "composite/optimize/intrin_rewriter.h" #include "composite/optimize/complex_expander.h" +#include "composite/optimize/delete_cast.h" namespace akg { Stmt Optimize(Stmt &s, BuildInfo &info) { @@ -58,6 +59,10 @@ Stmt Optimize(Stmt &s, BuildInfo &info) { } // rename MatMul to BatchMatMul pm.RegisterPass(std::make_shared()); + // delete cast for MatMul fusion + 
if (info.opt.target == "cuda") { + pm.RegisterPass(std::make_shared()); + } // intrin rewrite in ascend if (info.opt.target == "aicore") { pm.RegisterPass(std::make_shared()); diff --git a/src/composite/optimize/reshape_tensor.cc b/src/composite/optimize/reshape_tensor.cc index aae4540c8e53d7aec9f99e6a125260e82a24f349..b3ddd8454ea8ded0a07166f1b3e81150825526d5 100644 --- a/src/composite/optimize/reshape_tensor.cc +++ b/src/composite/optimize/reshape_tensor.cc @@ -281,7 +281,7 @@ class ReshapeTensorMutator : public IRMutator { } } - static Tensor RecoverTensor(const Expr &e) { + Tensor RecoverTensor(const Expr &e) { Tensor ret; auto call = e.as(); if (call == nullptr || call->call_type != Call::CallType::Halide) { diff --git a/src/composite/util.cc b/src/composite/util.cc index 4bc07e2d040d4c0f31025da0a237bc640795cc14..7d24a8d244573b4561634dac383dd970d3b6134e 100644 --- a/src/composite/util.cc +++ b/src/composite/util.cc @@ -65,7 +65,7 @@ bool IsOtherOp(const std::string &op_name) { // if topi support more, add to this list std::unordered_set elems = {"MatMul", "BatchMatMul", "Conv", "Transpose", "Tile", "Assign", "InplaceAssign", "EquivFormat", "TransData", "AddMinValue", - "BroadcastTo", "PadAkg", "UnPadAkg"}; + "BroadcastTo", "PadAkg", "UnPadAkg", "Conv2D"}; return elems.find(op_name) != elems.end(); } bool IsElemwise(const std::string &op_name) { diff --git a/src/include/ir_pass.h b/src/include/ir_pass.h index 286ac952b03439768fc368d8e3f111aecbafc085..ebfea185608e0e50c5a1d00518594e457ca30fe1 100644 --- a/src/include/ir_pass.h +++ b/src/include/ir_pass.h @@ -84,6 +84,14 @@ Stmt InjectDoubleBufferScopeOnGpu(Stmt stmt); */ Stmt InjectTransferBufferScope(Stmt stmt); +/*! + * \brief Rearrange the buffer of shared memory to eliminate the bank conflict. + * + * \param expr The statement to be rearranged. + * \return The statement after rearranged. 
+ */ +Stmt ReconstructLayout(const Stmt &stmt); + Stmt ElementwiseFlatten(Stmt stmt, const Map &extern_buffer, const Map &new_extern_buffer); diff --git a/src/pass/fuse_axis.cc b/src/pass/fuse_axis.cc index 9a70daa3a381704fb95c7ac07ec013c2c4a646d7..6552a8292a827b08f08ee4a2acf2a66a2134c18b 100644 --- a/src/pass/fuse_axis.cc +++ b/src/pass/fuse_axis.cc @@ -23,6 +23,7 @@ #include #include #include +#include namespace akg { namespace ir { @@ -60,13 +61,6 @@ namespace ir { using VarPair = std::pair; using IterVarPair = std::pair; -struct PairHash { - template - size_t operator()(const std::pair &a) const { - return dmlc::HashCombine(std::hash()(a.first), std::hash()(a.second)); - } -}; - struct ArrayIterVarHash { size_t operator()(const Array &arr) const { size_t ret = 0; diff --git a/src/pass/inject_transfer_buffer_scope.cc b/src/pass/inject_transfer_buffer_scope.cc index 425bf6b962819a21fbbbc4d5934121d8f550934f..b79c49922ce5a99470909886de546bab00e8f311 100644 --- a/src/pass/inject_transfer_buffer_scope.cc +++ b/src/pass/inject_transfer_buffer_scope.cc @@ -21,6 +21,8 @@ #include #include #include "build_module.h" +#include "utils.h" +#include "common/common_util.h" namespace akg { namespace ir { @@ -28,9 +30,7 @@ constexpr auto PREFETCH_SCOPE = "double_buffer_scope"; constexpr auto THREAD_GROUP_OFFSET = "thread_group_offset"; constexpr auto WMMA_FACTOR_AB = 16; constexpr auto WMMA_FACTOR_C = 32; -constexpr auto BITS_PER_BYTE = 8; constexpr int REGISTER_FILE_SIZE_PER_SM = 256 * 1024; -constexpr int MAX_SHARED_USAGE = 48 * 1024; constexpr int TOTAL_THREAD_NUM_PER_BLOCK = 1024; constexpr int MIN_OUTER_LOOP = 2; constexpr int MAX_OUTER_LOOP = 64; @@ -60,37 +60,51 @@ class PrefetchScopeInjector : public IRMutator { Stmt Mutate_(const AttrStmt *op, const Stmt &s) final { if (op->attr_key == air::ir::attr::storage_scope && op->value.as()->value == "shared") { touched_.insert(op->node.as()); - } - if (op->attr_key == "shared_mem_promoted_complete") { + } else if (op->attr_key 
== "shared_mem_promoted_complete") { if_shared_promoted_ = true; - } - if (op->attr_key == "promote_vectorization") { - is_vectorize_ = true; - auto res = IRMutator::Mutate_(op, s); - if (need_prefetch_) { - res = AttrStmt::make(prefetch_var_, PREFETCH_SCOPE, 1, res); - if_prefetch_injected_ = true; - need_prefetch_ = false; + } else if (op->attr_key == "promote_register_to_global" || op->attr_key == "promote_register_to_shared") { + if_shared_finished_ = true; + } else if (op->attr_key == "promote_vectorization") { + if (IsPrefetchBlock(s) && HasOuterLoop()) { + if (loop_nest_.back()->extent.as() != nullptr) { + prefetch_outer_loop_ = (loop_nest_.back()->extent).as()->value; + if_prefetch_injected_ = true; + return AttrStmt::make(prefetch_var_, PREFETCH_SCOPE, 1, s); + } } - is_vectorize_ = false; - return res; + return s; } return IRMutator::Mutate_(op, s); } Stmt Mutate_(const Evaluate *op, const Stmt &s) final { if (is_const(op->value)) return IRMutator::Mutate_(op, s); - const Call *call = op->value.as(); - if (call) { + if (const auto call = op->value.as()) { if (call->is_intrinsic(air::ir::intrinsic::tvm_storage_sync)) { if (if_prefetch_injected_ && (!if_shared_promoted_)) { - return AttrStmt::make(Var(""), "delete_this_sync", Expr("delete_this_sync"), s); + return AttrStmt::make(Var(""), "delete_this_sync", 1, s); + } else if (if_shared_promoted_ && (!if_shared_finished_)) { + return AttrStmt::make(Var(""), "delete_this_sync_for_db", 1, s); } } } return IRMutator::Mutate_(op, s); } + Stmt Mutate_(const IfThenElse *op, const Stmt &s) final { + Stmt stmt = IRMutator::Mutate_(op, s); + if (const auto ifthenelse = stmt.as()) { + if (auto attr = ifthenelse->then_case.as()) { + if (attr->attr_key == PREFETCH_SCOPE) { + Stmt rotated_stmt = IfThenElse::make(ifthenelse->condition, attr->body, ifthenelse->else_case); + rotated_stmt = AttrStmt::make(attr->node, attr->attr_key, attr->value, rotated_stmt); + return rotated_stmt; + } + } + } + return stmt; + } + bool 
IsPrefetchBlock(const Stmt &s) { if (auto store = s.as()) { auto it = touched_.find(store->buffer_var.get()); @@ -106,34 +120,37 @@ class PrefetchScopeInjector : public IRMutator { if (IsPrefetchBlock(attr->body)) { return true; } + } else if (auto cond = s.as()) { + if (IsPrefetchBlock(cond->then_case)) { + return true; + } } return false; } bool HasOuterLoop() { return !loop_nest_.empty(); } Stmt Mutate_(const For *op, const Stmt &s) final { if (IsPrefetchBlock(s) && HasOuterLoop()) { - prefetch_outer_loop_ = (loop_nest_.back()->extent).as()->value; - if (is_vectorize_) { - need_prefetch_ = true; - return s; + if (loop_nest_.back()->extent.as() != nullptr) { + prefetch_outer_loop_ = (loop_nest_.back()->extent).as()->value; + if_prefetch_injected_ = true; + return AttrStmt::make(prefetch_var_, PREFETCH_SCOPE, 1, s); } - if_prefetch_injected_ = true; - return AttrStmt::make(prefetch_var_, PREFETCH_SCOPE, 1, s); - } else { - loop_nest_.push_back(op); - auto stmt = IRMutator::Mutate_(op, s); - loop_nest_.pop_back(); - return stmt; } + loop_nest_.push_back(op); + auto stmt = IRMutator::Mutate_(op, s); + loop_nest_.pop_back(); + return stmt; } Stmt Mutate_(const Store *op, const Stmt &s) final { if (IsPrefetchBlock(s) && HasOuterLoop()) { - if_prefetch_injected_ = true; - return AttrStmt::make(prefetch_var_, PREFETCH_SCOPE, 1, s); - } else { - return IRMutator::Mutate_(op, s); + if (loop_nest_.back()->extent.as() != nullptr) { + prefetch_outer_loop_ = (loop_nest_.back()->extent).as()->value; + if_prefetch_injected_ = true; + return AttrStmt::make(prefetch_var_, PREFETCH_SCOPE, 1, s); + } } + return IRMutator::Mutate_(op, s); } const bool GetIfPrefetchInjected() { return if_prefetch_injected_; } @@ -148,6 +165,7 @@ class PrefetchScopeInjector : public IRMutator { bool is_vectorize_{false}; bool if_prefetch_injected_{false}; bool if_shared_promoted_{false}; + bool if_shared_finished_{false}; int prefetch_outer_loop_; }; @@ -169,26 +187,27 @@ class IfResouceIsEnough : 
public IRVisitor { } return IRVisitor::Visit_(op); } else if (op->attr_key == air::ir::attr::storage_scope) { - auto alloc = op->body.as(); - if (!promote_local_usage_.defined()) { - promote_local_usage_ = make_const(alloc->extents[0].type(), 0); - } - Expr dtype_factor = alloc->type.bits() / BITS_PER_BYTE; - if (op->value.as()->value == "shared") { - if (!shared_usage_.defined()) { - shared_usage_ = make_const(alloc->extents[0].type(), 0); + if (auto alloc = op->body.as()) { + if (!promote_local_usage_.defined()) { + promote_local_usage_ = make_const(alloc->extents[0].type(), 0); + } + if (op->value.as()->value == "shared") { + if (!shared_usage_.defined()) { + shared_usage_ = make_const(alloc->extents[0].type(), 0); + } + shared_usage_ += + air::arith::ComputeReduce(alloc->extents, Expr()) * alloc->type.lanes() * alloc->type.bytes(); + } else if (op->value.as()->value == "local") { + promote_local_usage_ += + air::arith::ComputeReduce(alloc->extents, Expr()) * alloc->type.lanes() * alloc->type.bytes(); + } else if (op->value.as()->value == "wmma.accumulator") { + promote_local_usage_ += air::arith::ComputeReduce(alloc->extents, Expr()) * alloc->type.lanes() / + Expr(WMMA_FACTOR_C) * alloc->type.bytes(); + } else if (op->value.as()->value == "wmma.matrix_b" || + op->value.as()->value == "wmma.matrix_a") { + promote_local_usage_ += air::arith::ComputeReduce(alloc->extents, Expr()) * alloc->type.lanes() / + Expr(WMMA_FACTOR_AB) * alloc->type.bytes(); } - shared_usage_ += air::arith::ComputeReduce(alloc->extents, Expr()) * alloc->type.lanes() * dtype_factor; - } else if (op->value.as()->value == "local") { - promote_local_usage_ += - air::arith::ComputeReduce(alloc->extents, Expr()) * alloc->type.lanes() * dtype_factor; - } else if (op->value.as()->value == "wmma.accumulator") { - promote_local_usage_ += air::arith::ComputeReduce(alloc->extents, Expr()) * alloc->type.lanes() / - Expr(WMMA_FACTOR_C) * dtype_factor; - } else if (op->value.as()->value == "wmma.matrix_b" 
|| - op->value.as()->value == "wmma.matrix_a") { - promote_local_usage_ += air::arith::ComputeReduce(alloc->extents, Expr()) * alloc->type.lanes() / - Expr(WMMA_FACTOR_AB) * dtype_factor; } return IRVisitor::Visit_(op); } else if (op->attr_key == PREFETCH_SCOPE) { @@ -201,8 +220,7 @@ class IfResouceIsEnough : public IRVisitor { for (unsigned i = 0; i < transfer_loop_nest_.size(); i++) { current_local_usage *= transfer_loop_nest_[i]->extent - transfer_loop_nest_[i]->min; } - Expr dtype_factor = prefetch_data_type_.bits() / BITS_PER_BYTE; - prefetch_local_usage_ += current_local_usage * dtype_factor; + prefetch_local_usage_ += current_local_usage * prefetch_data_type_.bytes(); transfer_loop_nest_.clear(); in_prefetch_buffer_scope_ = false; } else { @@ -216,7 +234,11 @@ class IfResouceIsEnough : public IRVisitor { } void Visit_(const Store *op) { - if (in_prefetch_buffer_scope_) prefetch_data_type_ = op->value.as()->type; + if (in_prefetch_buffer_scope_) { + if (const auto load = op->value.as()) { + prefetch_data_type_ = load->type; + } + } return IRVisitor::Visit_(op); } @@ -272,7 +294,8 @@ class ThreadGroupScopeInjector : public IRMutator { Stmt body = Mutate(op->body); return AttrStmt::make(op->node, op->attr_key, op->value, AttrStmt::make(thread_var_, THREAD_GROUP_OFFSET, thread_offset_, body)); - } else if (op->attr_key == "promote_local_to_global" || op->attr_key == "shared_mem_promoted_complete") { + } else if (op->attr_key == "promote_register_to_shared" || op->attr_key == "promote_shared_to_global" || + op->attr_key == "promote_register_to_global" || op->attr_key == "shared_mem_promoted_complete") { Stmt body = Mutate(op->body); return AttrStmt::make( op->node, op->attr_key, op->value, @@ -305,12 +328,12 @@ Stmt InjectTransferBufferScope(Stmt stmt) { if (tuning) { // tuning: manually control by attrs enable_double_buffer = g_attrs.GetBool(kEnableDoubleBuffer, false); - enable_transfer_buffer = g_attrs.GetBool(kEnableTransferBuffer, false); + 
enable_transfer_buffer = g_attrs.GetBool(kEnableTransferBuffer, true); enable_thread_group = g_attrs.GetBool(kEnableThreadGroup, false); } else { // not tuning: auto-analyse const int total_shared_usage = resource_calc.GetTotalSharedUsage(); - float shared_mem_rate = float(total_shared_usage * 2) / float(MAX_SHARED_USAGE); + float shared_mem_rate = float(total_shared_usage * 2) / float(common::ADVANCED_SHARED_MEMORY_SIZE); const int bind_thread_num = resource_calc.GetBindThreadNum(); const int total_local_usage = resource_calc.GetTotalLocalUsage(); float local_mem_rate = float(total_local_usage * bind_thread_num) / float(REGISTER_FILE_SIZE_PER_SM); @@ -327,22 +350,34 @@ Stmt InjectTransferBufferScope(Stmt stmt) { } } } - if (enable_transfer_buffer || enable_double_buffer) { - if (enable_thread_group) { + // avoid enabling two modes + if (enable_double_buffer) { + enable_transfer_buffer = false; + } + Stmt stmt_after_prefetch = stmt; + if (enable_transfer_buffer || enable_double_buffer) { + if (enable_thread_group) { const Var thread_group_var = resource_calc.GetThreadGroupVar(); const Expr thread_group_offset = resource_calc.GetThreadGoupOffset(); - return ThreadGroupScopeInjector().Inject(new_stmt, thread_group, thread_group_var, thread_group_offset); - } else { - return new_stmt; - } - } else { - return stmt; - } + stmt_after_prefetch = + ThreadGroupScopeInjector().Inject(new_stmt, thread_group, thread_group_var, thread_group_offset); + } else { + stmt_after_prefetch = new_stmt; + } + } + // add an attr of prefetch_mode + int prefetch_mode = static_cast(PrefetchMode::DEFAULT); + if (enable_double_buffer && enable_thread_group) { + prefetch_mode = static_cast(PrefetchMode::DOUBLEBUFFER_THREADGROUP); + } else if (enable_transfer_buffer && enable_thread_group) { + prefetch_mode = static_cast(PrefetchMode::TRANSFERBUFFER_THREADGROUP); + } else if (enable_double_buffer) { + prefetch_mode = static_cast(PrefetchMode::DOUBLEBUFFER); + } else if (enable_transfer_buffer) { 
+ prefetch_mode = static_cast(PrefetchMode::TRANSFERBUFFER); + } + return AttrStmt::make(Expr("INFO"), ATTR_PREFETCH_MODE, prefetch_mode, stmt_after_prefetch); } } // namespace ir -} // namespace akg - - - - \ No newline at end of file +} // namespace akg \ No newline at end of file diff --git a/src/pass/reconstruct_layout.cc b/src/pass/reconstruct_layout.cc new file mode 100644 index 0000000000000000000000000000000000000000..fc0ca2045571e5fba8b5d789b06bc6e3ac09087e --- /dev/null +++ b/src/pass/reconstruct_layout.cc @@ -0,0 +1,323 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "common/common_util.h" +#include "pass/utils.h" +#include "ir_pass.h" + +namespace akg { +namespace ir { + +class TensorCoreMatcher : public IRVisitor { + public: + void Visit_(const AttrStmt *op) final { + if (op->attr_key == air::ir::attr::pragma_tensor_core) { + tensor_core_on_ = true; + } else if (op->attr_key == air::ir::attr::realize_scope) { + auto pos = op->value.as()->value.find("wmma.matrix_"); + if (pos != std::string::npos) { + wmma_matrix_.insert(std::make_pair( + akg::common::GetGlobalName(op->node.as()->name), op->value.as()->value)); + } + } else if (op->attr_key == "batch_axis_num") { + batch_axis_num_ = op->value.as()->value; + } + IRVisitor::Visit_(op); + } + + void Visit_(const Realize *op) final { + if (tensor_core_on_ && op->func->func_name().find("shared") != std::string::npos) { + std::vector tmp; + tmp.reserve(op->bounds.size() - batch_axis_num_); + for (size_t i = batch_axis_num_; i < op->bounds.size(); i++) { + tmp.push_back(op->bounds[i]->extent); + } + shared_bound_[akg::common::GetGlobalName(op->func->func_name())] = tmp; + } + IRVisitor::Visit_(op); + } + + void Visit_(const Evaluate *op) final { + if (const auto call = op->value.as()) { + if (tensor_core_on_ && call->is_intrinsic(air::ir::intrinsic::tvm_load_matrix_sync)) { + Expr warp_tile_m = call->args[1]; + Expr warp_tile_n = call->args[2]; + Expr warp_tile_k = call->args[3]; + auto it_matrix = wmma_matrix_.find(akg::common::GetGlobalName(call->args[0].as()->name_hint)); + if (it_matrix != wmma_matrix_.end()) { + wmma_layout_.insert(std::make_pair(it_matrix->second, call->args[7].as()->value)); + if (warp_tile_m.as()->value == 16 && warp_tile_n.as()->value == 16 && + warp_tile_k.as()->value == 8) { + auto pair_name = std::pair(it_matrix->second, call->args[7].as()->value); + std::vector tmp; + tmp.reserve(2); + if 
(it_matrix->second == "wmma.matrix_a" && call->args[7].as()->value == "row_major") { + tmp.emplace_back(warp_tile_m); + tmp.emplace_back(warp_tile_k); + } else if (it_matrix->second == "wmma.matrix_a" && call->args[7].as()->value == "col_major") { + tmp.emplace_back(warp_tile_k); + tmp.emplace_back(warp_tile_m); + } else if (it_matrix->second == "wmma.matrix_b" && call->args[7].as()->value == "row_major") { + tmp.emplace_back(warp_tile_k); + tmp.emplace_back(warp_tile_n); + } else if (it_matrix->second == "wmma.matrix_b" && call->args[7].as()->value == "col_major") { + tmp.emplace_back(warp_tile_n); + tmp.emplace_back(warp_tile_k); + } else { + LOG(FATAL) << "Not supported layout " << call->args[7].as()->value << " for " << it_matrix->second; + } + tile_size_[pair_name] = tmp; + } + } + } + } + + IRVisitor::Visit_(op); + } + + inline bool Matched() { return tensor_core_on_;} + + friend class SharedReconstruction; + + private: + bool tensor_core_on_{false}; + unsigned int batch_axis_num_{0}; + std::unordered_map wmma_matrix_; + std::unordered_map wmma_layout_; + std::unordered_map> shared_bound_; + std::unordered_map, std::vector, PairHash> tile_size_; +}; + +class SharedReconstruction : public IRMutator { + public: + explicit SharedReconstruction(const TensorCoreMatcher &tensorcore_matcher) + : batch_axis_num_(tensorcore_matcher.batch_axis_num_), + wmma_matrix_(tensorcore_matcher.wmma_matrix_), + wmma_layout_(tensorcore_matcher.wmma_layout_), + shared_bound_(tensorcore_matcher.shared_bound_), + tile_size_(tensorcore_matcher.tile_size_) {} + + Stmt Mutate_(const Provide *op, const Stmt &s) final { + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as(); + auto it_matrix = wmma_matrix_.find(akg::common::GetGlobalName(op->func->func_name())); + if (it_matrix != wmma_matrix_.end() && op->func->func_name().find("shared") != std::string::npos) { + auto it_layout = wmma_layout_.find(it_matrix->second); + auto pair_name = std::pair(it_layout->first, it_layout->second); + 
auto it_tile = tile_size_.find(pair_name); + auto it_bound = shared_bound_.find(akg::common::GetGlobalName(op->func->func_name())); + Array fuse_args; + if (it_tile != tile_size_.end() && it_bound != shared_bound_.end()) { + Array split_args; + split_args.push_back(Div::make(op->args[batch_axis_num_], it_tile->second[0])); + split_args.push_back(Div::make(op->args[op->args.size() - 1], it_tile->second[1])); + split_args.push_back(Mod::make(op->args[batch_axis_num_], it_tile->second[0])); + split_args.push_back(Mod::make(op->args[op->args.size() - 1], it_tile->second[1])); + for (size_t i = 0; i < op->args.size(); i++) { + Expr new_arg = op->args[i]; + if (i == batch_axis_num_) { + new_arg = split_args[0]; + } + if (i == op->args.size() - 2) { + new_arg = Add::make(Mul::make(new_arg, + Div::make(it_bound->second.back(), it_tile->second[1])), split_args[1]); + } + if (i == op->args.size() - 1) { + new_arg = Add::make(Mul::make(split_args[2], it_tile->second[1]), split_args[3]); + } + fuse_args.push_back(new_arg); + } + } else { + for (size_t i = 0; i < op->args.size(); i++) { + fuse_args.push_back(op->args[i]); + } + } + return Provide::make(op->func, op->value_index, op->value, fuse_args); + } + return stmt; + } + + Stmt Mutate_(const Evaluate *op, const Stmt &s) final { + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as(); + auto wmma_op = op->value.as(); + if (wmma_op->is_intrinsic(air::ir::intrinsic::tvm_load_matrix_sync)) { + auto it_matrix = wmma_matrix_.find(akg::common::GetGlobalName(wmma_op->args[0].as()->name_hint)); + if (it_matrix != wmma_matrix_.end()) { + auto it_layout = wmma_layout_.find(it_matrix->second); + auto it_bound = shared_bound_.find(akg::common::GetGlobalName(wmma_op->args[0].as()->name_hint)); + if (it_bound == shared_bound_.end()) { + LOG(FATAL) << "Insufficient arguments for shared memory tensor " << + akg::common::GetGlobalName(wmma_op->args[0].as()->name_hint); + } + Expr inner_bound = it_bound->second.back(); + for (size_t i = 1; 
i < it_bound->second.size() - 1; i++) { + inner_bound = inner_bound * it_bound->second[i]; + } + if (wmma_op->args[1].as()->value == 16 && wmma_op->args[2].as()->value == 16 && + wmma_op->args[3].as()->value == 8) { + if ((it_layout->first == "wmma.matrix_a" && it_layout->second == "row_major") || + (it_layout->first == "wmma.matrix_b" && it_layout->second == "col_major")) { + if (inner_bound.as()->value <= 16) { + shared_offset_[akg::common::GetGlobalName(wmma_op->args[0].as()->name_hint)] = IntImm::make(Int(32), 32); + } else if (inner_bound.as()->value <= 40) { + shared_offset_[akg::common::GetGlobalName(wmma_op->args[0].as()->name_hint)] = IntImm::make(Int(32), 16); + } else { + shared_offset_[akg::common::GetGlobalName(wmma_op->args[0].as()->name_hint)] = IntImm::make(Int(32), 8); + } + offset_expr_ = IntImm::make(Int(32), 8); + } else if ((it_layout->first == "wmma.matrix_a" && it_layout->second == "col_major") || + (it_layout->first == "wmma.matrix_b" && it_layout->second == "row_major")) { + if (inner_bound.as()->value <= 32) { + shared_offset_[akg::common::GetGlobalName(wmma_op->args[0].as()->name_hint)] = IntImm::make(Int(32), 32); + } else { + shared_offset_[akg::common::GetGlobalName(wmma_op->args[0].as()->name_hint)] = IntImm::make(Int(32), 16); + } + offset_expr_ = IntImm::make(Int(32), 16); + } else { + LOG(FATAL) << "Not supported layout " << it_layout->second << " for " << it_layout->first; + } + } else { + shared_offset_[akg::common::GetGlobalName(wmma_op->args[0].as()->name_hint)] = IntImm::make(Int(32), 16); + offset_expr_ = IntImm::make(Int(32), wmma_op->args[6].as()->value + 16); + } + auto shared_op = wmma_op->args[5].as(); + auto call_op = shared_op->args[0].as(); + Array fuse_args; + auto pair_name = std::pair(it_layout->first, it_layout->second); + auto it_tile = tile_size_.find(pair_name); + if (it_tile != tile_size_.end() && it_bound != shared_bound_.end()) { + Array split_args; + 
split_args.push_back(Div::make(call_op->args[batch_axis_num_], it_tile->second[0])); + split_args.push_back(Div::make(call_op->args[call_op->args.size() - 1], it_tile->second[1])); + split_args.push_back(Mod::make(call_op->args[batch_axis_num_], it_tile->second[0])); + split_args.push_back(Mod::make(call_op->args[call_op->args.size() - 1], it_tile->second[1])); + for (size_t i = 0; i < call_op->args.size(); i++) { + Expr new_arg = call_op->args[i]; + if (i == batch_axis_num_) { + new_arg = split_args[0]; + } + if (i == call_op->args.size() - 2) { + new_arg = Add::make(Mul::make(new_arg, + Div::make(it_bound->second.back(), it_tile->second[1])), split_args[1]); + } + if (i == call_op->args.size() - 1) { + new_arg = Add::make(Mul::make(split_args[2], it_tile->second[1]), split_args[3]); + } + fuse_args.push_back(new_arg); + } + } else { + for (size_t i = 0; i < call_op->args.size(); i++) { + fuse_args.push_back(call_op->args[i]); + } + } + Array split_args_in; + split_args_in.push_back( + Call::make(call_op->type, call_op->name, fuse_args, Call::CallType::Halide, call_op->func, call_op->value_index)); + auto new_shared_op = Call::make( + shared_op->type, shared_op->name, split_args_in, shared_op->call_type, shared_op->func, shared_op->value_index); + return Evaluate::make( + Call::make(Handle(), air::ir::intrinsic::tvm_load_matrix_sync, { + wmma_op->args[0], wmma_op->args[1], wmma_op->args[2], + wmma_op->args[3], wmma_op->args[4], new_shared_op, + offset_expr_, wmma_op->args[7]}, Call::Intrinsic)); + } + } + return stmt; + } + + Stmt Mutate_(const Realize *op, const Stmt &s) final { + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as(); + auto it_matrix = wmma_matrix_.find(akg::common::GetGlobalName(op->func->func_name())); + if (op != nullptr && it_matrix != wmma_matrix_.end() && op->func->func_name().find("shared") != std::string::npos) { + auto it_layout = wmma_layout_.find(it_matrix->second); + auto pair_name = std::pair(it_layout->first, it_layout->second); 
+ auto offset = shared_offset_.find(akg::common::GetGlobalName(op->func->func_name())); + auto it_tile = tile_size_.find(pair_name); + if (it_tile != tile_size_.end()) { + Region new_bounds; + for (size_t i = 0; i < op->bounds.size(); i++) { + Expr new_extent = op->bounds[i]->extent; + if (i == batch_axis_num_) { + new_extent = new_extent / it_tile->second[0]; + } + if (i == op->bounds.size() - 2) { + new_extent = new_extent * (op->bounds[op->bounds.size() - 1]->extent / it_tile->second[1]); + } + if (i == op->bounds.size() - 1) { + new_extent = offset->second + (it_tile->second[0] * it_tile->second[1]); + } + new_bounds.push_back(Range::make_by_min_extent(op->bounds[i]->min, new_extent)); + } + return Realize::make(op->func, op->value_index, op->type, new_bounds, op->condition, op->body); + } else { + Region new_bounds; + for (size_t i = 0; i < op->bounds.size() - 1; ++i) { + new_bounds.push_back(Range::make_by_min_extent(op->bounds[i]->min, op->bounds[i]->extent)); + } + new_bounds.push_back( + Range::make_by_min_extent(op->bounds[op->bounds.size() - 1]->min, + op->bounds[op->bounds.size() - 1]->extent + offset->second)); + return Realize::make(op->func, op->value_index, op->type, new_bounds, op->condition, op->body); + } + } + return stmt; + } + + Stmt Mutate_(const AttrStmt *op, const Stmt &s) final { + if (op->attr_key == "batch_axis_num") { + return this->Mutate(op->body); + } + return IRMutator::Mutate_(op, s); + } + + private: + unsigned int batch_axis_num_{0}; + std::unordered_map wmma_matrix_; + std::unordered_map wmma_layout_; + std::unordered_map> shared_bound_; + std::unordered_map, std::vector, PairHash> tile_size_; + std::unordered_map shared_offset_; + Expr offset_expr_; +}; + +Stmt ReconstructLayout(const Stmt &stmt) { + TensorCoreMatcher tensorcore_matcher; + tensorcore_matcher.Visit(stmt); + if (!tensorcore_matcher.Matched()) { + return stmt; + } + return SharedReconstruction(tensorcore_matcher).Mutate(stmt); +} + +} // namespace ir +} // 
namespace akg diff --git a/src/poly/dma_inject.cc b/src/poly/dma_inject.cc index aa332ceb51328c8d9e7bb341188bf45d0fcc376e..322e0ca72d6d4727b0f8cd44fb3de850f141ae40 100644 --- a/src/poly/dma_inject.cc +++ b/src/poly/dma_inject.cc @@ -1331,55 +1331,37 @@ isl::schedule_node InsertExtensionBeforeOrAfter(ScopInfo &scop_info, isl::schedu int index = tree.parent().get_ancestor_child_position(tree.ancestor(2)); if (tree.parent().isa()) { + int children_number = tree.ancestor(2).n_children(); + CHECK(children_number > 0) << "sequence node must have children"; if (isl_bool_true == before) { tree = tree.ancestor(2).child(0).child(0); } else { - int size = tree.ancestor(2).n_children(); - tree = tree.ancestor(2).child(size - 1).child(0); - } - } - - if (scop_info.user_config_.GetTarget() == TARGET_CUDA && USE_SIMPLE_EXTENSION) { - if (auto graft_band = graft.child(0).as()) { - auto graft_domain = graft_band.get_partial_schedule().domain(); - bool is_compute_shared = false; - graft_domain.foreach_set([&is_compute_shared](isl::set s) { - if (s.get_tuple_name() == SHARED_WRITE_ID_NAME) { - is_compute_shared = true; + auto domain = schedule.domain(); + bool is_promoted_shared = false; + domain.foreach_set([&is_promoted_shared](const isl::set &set) -> void { + if (set.get_tuple_name() == SHARED_WRITE_ID_NAME) { + is_promoted_shared = true; } }); - - if (!is_compute_shared) { - return InsertExtensionSimple(tree, graft, before, index); - } - - auto IsComputePromotion = [](const isl::schedule_node &node) -> bool { - if (!node.isa()) { - return false; - } - auto filter_node = node.as(); - isl::union_set uset = filter_node.get_filter(); - bool is_compute_gm = false; - uset.foreach_set([&is_compute_gm](isl::set s) { - if (s.get_tuple_name() == WRITE_ID_NAME) { - is_compute_gm = true; + int size = children_number - 1; + if (is_promoted_shared) { + for (int i = size; i >= 0; --i) { + auto filter_node = tree.ancestor(2).child(i).as(); + isl::union_set uset = filter_node.get_filter(); + 
std::vector vset; + uset.foreach_set([&vset](isl::set s) { vset.push_back(s); }); + if (vset.empty() || vset[0].get_tuple_name() != WRITE_ID_NAME) { + continue; } - }); - return is_compute_gm; - }; - - if (!tree.has_parent()) { - return InsertExtensionSimple(tree, graft, before, index); - } - auto gm_node = tree.parent(); - bool is_wrong_order = gm_node.has_previous_sibling() && gm_node.previous_sibling().has_previous_sibling() && - IsComputePromotion(gm_node.previous_sibling()); - if (!is_wrong_order) { - return InsertExtensionSimple(tree, graft, before, index); + size = (i == 0) ? 0 : i - 1; + break; + } } - gm_node = gm_node.previous_sibling().previous_sibling().child(0); - tree = gm_node; + tree = tree.ancestor(2).child(size).child(0); } + } + + if (scop_info.user_config_.GetTarget() == TARGET_CUDA && USE_SIMPLE_EXTENSION) { return InsertExtensionSimple(tree, graft, before, index); } diff --git a/src/poly/dma_inject.h b/src/poly/dma_inject.h index 07f24ff7f40039c0a34764575a161ea0346157b7..80a29c4c9229ea54a189570adb1045403c75f130 100644 --- a/src/poly/dma_inject.h +++ b/src/poly/dma_inject.h @@ -26,6 +26,7 @@ namespace akg { namespace ir { namespace poly { enum class ReferenceType : int16_t { Read, Write }; +constexpr auto SYNC_NUMBER_BEFORE_GMWRITE = 3; struct ScopedFootprint { size_t GetBoxDim() const { return box.get_size().size(); } diff --git a/src/poly/dsa_utils.h b/src/poly/dsa_utils.h index 99a45bc33f3ad0c29991549b59d9f84a8f361d9f..8dea447934f19f2ebead81ad79a7a425c6b7d3a7 100644 --- a/src/poly/dsa_utils.h +++ b/src/poly/dsa_utils.h @@ -75,7 +75,7 @@ extern const char *const PRAGMA_MMU_C1WRITE; extern const char *const K_C1; extern const char *const PRAGMA_GEMM_C0; -enum MemType { DDR = 1, C1_, BUF_, C0A_, C0B_, C0C_, BUF_C0_, BUF_C1_, SHARED_, LOCAL_ }; +enum MemType { DDR = 1, C1_, BUF_, C0A_, C0B_, C0C_, BUF_C0_, BUF_C1_, SHARED_, LOCAL_, DDR_LOCAL_ }; using DataFlowAttrs = std::vector>; extern const DataFlowAttrs Mmu_Conv_A; diff --git 
a/src/poly/gpu_emit/emit_pass.h b/src/poly/gpu_emit/emit_pass.h index 79280f7130b80714f6533564371f61c15c5ac971..e22fa25f889db77be99c684209991fde38eae9ba 100644 --- a/src/poly/gpu_emit/emit_pass.h +++ b/src/poly/gpu_emit/emit_pass.h @@ -17,12 +17,15 @@ #ifndef EMIT_PASS_H_ #define EMIT_PASS_H_ #include "../isl_emitter.h" -#include "../gpu_isl_emitter.h" +#include "gpu_isl_emitter.h" +#include "gpu_isl_emitter_reduce.h" +#include "gpu_isl_emitter_tensor_core.h" namespace akg { namespace ir { namespace poly { -Stmt EmitForTensorCore(Stmt stmt, TensorCoreInfo &info); +Stmt EmitForTensorCore(Stmt stmt, TensorCoreInfo &info, ScopInfo &scop_info); +Stmt EmitForReduce(Stmt stmt, ScopInfo &scop_info); Stmt EmitForTensorCoreDesignOne(Stmt stmt, TensorCoreInfo &info); } // namespace poly } // namespace ir diff --git a/src/poly/gpu_emit/gpu_emit_tensor_core.cc b/src/poly/gpu_emit/gpu_emit_tensor_core.cc deleted file mode 100644 index 22b1d52040deacb31549ab0157ffa1634be4ffc0..0000000000000000000000000000000000000000 --- a/src/poly/gpu_emit/gpu_emit_tensor_core.cc +++ /dev/null @@ -1,672 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*! 
- * \file gpu_emit_tensor_core.cc - */ - -#include "emit_pass.h" -#include - -namespace akg { -namespace ir { -namespace poly { - -class CheckCast : public IRVisitor { - public: - explicit CheckCast() {} - using IRVisitor::Visit_; - - void Visit_(const AttrStmt *op) final { - if (op->attr_key == CAST_FLAG) { - std::string mode = op->value.as()->value; - if (mode == CAST_MODE_1) { - origin_type_ = Float(32); - cast_type_ = Float(16); - } - is_cast_ = true; - IRVisitor::Visit_(op); - return; - } - IRVisitor::Visit_(op); - } - - void Visit_(const Call *op) final { - if (op->is_intrinsic(air::ir::intrinsic::tvm_mma_sync)) { - CHECK_EQ(op->args.size(), 8U); - Expr arg2 = op->args[2]; - Expr arg4 = op->args[4]; - const Variable *a2 = arg2.as(); - CHECK(a2); - const Variable *a4 = arg4.as(); - CHECK(a4); - cast_tensors_.insert(SimplifyName(a2->name_hint)); - cast_tensors_.insert(SimplifyName(a4->name_hint)); - } - IRVisitor::Visit_(op); - } - - bool IsCastAdapt() { return is_cast_; } - friend class CollectInfoToAdaptCast; - - private: - Type origin_type_; - Type cast_type_; - bool is_cast_{false}; - std::set cast_tensors_; -}; - -class CollectInfoToAdaptCast : public IRVisitor { - public: - explicit CollectInfoToAdaptCast(CheckCast &check_cast) - : origin_type_(check_cast.origin_type_), - cast_type_(check_cast.cast_type_), - cast_tensors_(check_cast.cast_tensors_) {} - using IRVisitor::Visit_; - - void Visit_(const AttrStmt *op) final { - if (op->attr_key == GMREAD_FLAG) { - is_global_to_shared_ = true; - IRVisitor::Visit_(op); - is_global_to_shared_ = false; - return; - } - IRVisitor::Visit_(op); - } - - void Visit_(const Provide *op) final { - if (is_global_to_shared_) { - global_to_shared_.insert(op); - } - IRVisitor::Visit_(op); - } - - void Visit_(const Realize *op) final { - std::string name = op->func->func_name(); - if (IsEndsWith(name, SHARE_SUFFIX) && cast_tensors_.count(SimplifyName(name))) { - realize_need_cast_shared_.insert(name); - } else if 
(IsEndsWith(name, LOCAL_SUFFIX) && cast_tensors_.count(SimplifyName(name))) { - realize_need_cast_local_.insert(name); - } - IRVisitor::Visit_(op); - } - - friend class AdaptCast; - - private: - Type origin_type_; - Type cast_type_; - bool is_global_to_shared_{false}; - std::set cast_tensors_; - - std::set global_to_shared_; - std::set realize_need_cast_shared_; - std::set realize_need_cast_local_; -}; - -class AdaptCast : public IRMutator { - public: - explicit AdaptCast(CollectInfoToAdaptCast &info) - : realize_need_cast_shared_(info.realize_need_cast_shared_), - realize_need_cast_local_(info.realize_need_cast_local_), - global_to_shared_(info.global_to_shared_), - origin_type_(info.origin_type_), - cast_type_(info.cast_type_) {} - - Stmt Mutate_(const AttrStmt *op, const Stmt &s) final { - if (op->attr_key == air::ir::attr::buffer_bind_scope) { - Array arr = Downcast>(op->node); - CHECK_EQ(arr.size(), 2U); - const BufferNode *buffer = arr[0].as(); - const TensorNode *tensor = arr[1].as(); - const Call *tuple = op->value.as(); - CHECK(buffer && tensor); - CHECK(tuple); - if (realize_need_cast_local_.count(buffer->name)) { - NodePtr buffer_node = make_node(); - buffer_node->data = buffer->data; - buffer_node->name = buffer->name; - buffer_node->scope = buffer->scope; - buffer_node->dtype = cast_type_; - buffer_node->shape = buffer->shape; - buffer_node->strides = buffer->strides; - buffer_node->data_alignment = buffer->data_alignment; - buffer_node->elem_offset = buffer->elem_offset; - buffer_node->offset_factor = buffer->offset_factor; - - Buffer buffer_new(buffer_node); - NodePtr tensor_node = make_node(); - tensor_node->value_index = tensor->value_index; - tensor_node->op = tensor->op; - tensor_node->shape = tensor->shape; - tensor_node->dtype = cast_type_; - Tensor tensor_new(tensor_node); - - Array node = {buffer_new, tensor_new}; - Stmt body = this->Mutate(op->body); - return AttrStmt::make(node, op->attr_key, op->value, body); - } - } - return 
IRMutator::Mutate_(op, s); - } - - Stmt Mutate_(const Realize *op, const Stmt &s) final { - std::string tensor_name = op->func->func_name(); - Stmt stmt = IRMutator::Mutate_(op, s); - op = stmt.as(); - if (op != nullptr) { - if (!realize_need_cast_shared_.count(tensor_name) && !realize_need_cast_local_.count(tensor_name)) { - return stmt; - } - - return Realize::make(op->func, op->value_index, cast_type_, op->bounds, op->condition, op->body); - } - return stmt; - } - - Stmt Mutate_(const Provide *op, const Stmt &s) final { - if (global_to_shared_.count(op)) { - auto value = op->value; - auto call = value.as(); - CHECK(call); - CHECK(call->type == origin_type_); - value = Cast::make(cast_type_, value); - return Provide::make(op->func, op->value_index, value, op->args); - } - return IRMutator::Mutate_(op, s); - } - - private: - std::set realize_need_cast_shared_; - std::set realize_need_cast_local_; - std::set global_to_shared_; - Type origin_type_; - Type cast_type_; -}; - -class AdaptCastDesignOne : public IRMutator { - public: - explicit AdaptCastDesignOne(TensorCoreInfo &info) : cast_tensors_(info.cast_tensors_) {} - - Stmt Mutate_(const Realize *op, const Stmt &s) final { - std::string tensor_name = op->func->func_name(); - Stmt stmt = IRMutator::Mutate_(op, s); - op = stmt.as(); - if (op != nullptr) { - if (!cast_tensors_.count(tensor_name)) { - return stmt; - } - - return Realize::make(op->func, op->value_index, Float(16), op->bounds, op->condition, op->body); - } - return stmt; - } - - private: - std::unordered_set cast_tensors_; -}; - -class ModifySizeOfLocal : public IRMutator { - public: - explicit ModifySizeOfLocal(TensorCoreInfo &info) : info_(info) { - m_size_ = Expr(info_.warp_tile_.m); - m_size_ = Mul::make(m_size_, info_.fragment_m_.defined() ? info_.fragment_m_ : make_const(Int(32), 1)); - n_size_ = Expr(info_.warp_tile_.n); - n_size_ = Mul::make(n_size_, info_.fragment_n_.defined() ? 
info_.fragment_n_ : make_const(Int(32), 1)); - } - Stmt Mutate_(const AttrStmt *op, const Stmt &s) final { - if (op->attr_key == air::ir::attr::buffer_bind_scope) { - Array arr = Downcast>(op->node); - CHECK_EQ(arr.size(), 2U); - const BufferNode *buffer = arr[0].as(); - const TensorNode *tensor = arr[1].as(); - const Call *tuple = op->value.as(); - CHECK(buffer && tensor); - CHECK(tuple); - NodePtr buffer_node = make_node(); - buffer_node->data = buffer->data; - buffer_node->name = buffer->name; - buffer_node->scope = buffer->scope; - buffer_node->dtype = buffer->dtype; - - auto old_shape = buffer->shape; - size_t len = old_shape.size(); - CHECK_GE(len, 2); - - std::string base_name = SimplifyName(buffer->name); - auto matrix = info_.matrix_abc_[base_name]; - auto major = info_.matrix_major_[base_name]; - - int mod_index = -1; - Expr shape_mod; - bool is_c_matrix = false; - - if (matrix == MATRIX_A) { - if (major == ROW_MAJOR) { - mod_index = len - 2; - } else if (major == COL_MAJOR) { - mod_index = len - 1; - } - shape_mod = m_size_; - } else if (matrix == MATRIX_B) { - if (major == ROW_MAJOR) { - mod_index = len - 1; - } else if (major == COL_MAJOR) { - mod_index = len - 2; - } - shape_mod = n_size_; - } else if (matrix == MATRIX_C) { - is_c_matrix = true; - } - - Array new_shape; - if (!is_c_matrix) { - for (size_t i = 0; i < len; ++i) { - if (i == static_cast(mod_index)) { - CHECK(shape_mod.defined()); - new_shape.push_back(shape_mod); - continue; - } - new_shape.push_back(old_shape[i]); - } - buffer_node->shape = new_shape; - - } else { - int len = buffer->shape.size(); - new_shape = ModifyCShape(buffer, len); - buffer_node->shape = new_shape; - } - - Array strides; - for (size_t i = 1; i < new_shape.size(); ++i) { - Expr stride = IntImm::make(Int(32), 1); - for (size_t j = new_shape.size() - 1; j >= i; --j) { - stride = Mul::make(stride, new_shape[j]); - } - strides.push_back(stride); - } - strides.push_back(make_const(Int(32), 1)); - - buffer_node->strides 
= strides; - buffer_node->data_alignment = buffer->data_alignment; - - // elem_offset need modify now - auto old_args = tuple->args; - Array new_args; - if (!is_c_matrix) { - size_t mod_index_arg = 2 * mod_index + 1; - for (size_t i = 0; i < old_args.size(); ++i) { - if (i == mod_index_arg) { - new_args.push_back(shape_mod); - continue; - } - new_args.push_back(old_args[i]); - } - } else { - // C Matrix modify - size_t mod_index_n = 2 * (len - 1) + 1; - size_t mod_index_m = 2 * (len - 2) + 1; - for (size_t i = 0; i < old_args.size(); ++i) { - if (i == mod_index_n) { - new_args.push_back(n_size_); - } else if (i == mod_index_m) { - new_args.push_back(m_size_); - } else { - new_args.push_back(old_args[i]); - } - } - } - - Array call_args; - for (size_t i = 0; i < new_args.size();) { - call_args.push_back(new_args[i]); - i += 2; - } - - Expr elem_offset_new = IntImm::make(Int(32), 0); - auto min_bound = info_.min_bounds_[buffer->name]; - CHECK(min_bound.defined()) << "min_bound should be defined"; - CHECK_EQ(call_args.size(), min_bound.size()); - for (size_t i = 0; i < min_bound.size(); i++) { - elem_offset_new = Add::make(elem_offset_new, Mul::make(strides[i], Sub::make(call_args[i], min_bound[i]))); - } - - buffer_node->elem_offset = elem_offset_new; - buffer_node->offset_factor = buffer->offset_factor; - - Buffer buffer_new(buffer_node); - NodePtr tensor_node = make_node(); - tensor_node->value_index = tensor->value_index; - tensor_node->op = tensor->op; - tensor_node->shape = new_shape; - tensor_node->dtype = tensor->dtype; - Tensor tensor_new(tensor_node); - - Array node = {buffer_new, tensor_new}; - - auto tuple_new = - Call::make(tuple->type, tuple->name, new_args, tuple->call_type, tuple->func, tuple->value_index); - - Stmt body = this->Mutate(op->body); - return AttrStmt::make(node, op->attr_key, tuple_new, body); - } - return IRMutator::Mutate_(op, s); - } - - Stmt Mutate_(const Realize *op, const Stmt &s) final { - std::string tensor_name = 
op->func->func_name(); - Stmt stmt = IRMutator::Mutate_(op, s); - op = stmt.as(); - if (op != nullptr) { - if (!info_.frag_reg_.count(tensor_name)) { - return stmt; - } - - std::string base_name = SimplifyName(tensor_name); - auto matrix = info_.matrix_abc_[base_name]; - auto major = info_.matrix_major_[base_name]; - - Region new_bounds; - size_t len = op->bounds.size(); - CHECK_GE(len, 2) << "bounds size should be greater than 2"; - - int mod_index = -1; - - if (matrix == MATRIX_A) { - if (major == ROW_MAJOR) { - mod_index = len - 2; - } else if (major == COL_MAJOR) { - mod_index = len - 1; - } - new_bounds = ModifyAbRegion(op, mod_index, m_size_); - } else if (matrix == MATRIX_B) { - if (major == ROW_MAJOR) { - mod_index = len - 1; - } else if (major == COL_MAJOR) { - mod_index = len - 2; - } - new_bounds = ModifyAbRegion(op, mod_index, n_size_); - } else if (matrix == MATRIX_C) { - new_bounds = ModifyCRegion(op, len); - } - return Realize::make(op->func, op->value_index, op->type, new_bounds, op->condition, op->body); - } - return stmt; - } - - Region ModifyAbRegion(const Realize *op, int mod_index, Expr mod_size) { - Region new_bounds; - for (size_t i = 0; i < op->bounds.size(); ++i) { - if (i == static_cast(mod_index)) { - new_bounds.push_back(Range::make_by_min_extent(op->bounds[i]->min, mod_size)); - continue; - } else { - new_bounds.push_back(op->bounds[i]); - } - } - return new_bounds; - } - - Region ModifyCRegion(const Realize *op, int len) { - Region new_bounds; - for (size_t i = 0; i < op->bounds.size(); ++i) { - if (i == static_cast(len - 1)) { - new_bounds.push_back(Range::make_by_min_extent(op->bounds[i]->min, n_size_)); - continue; - } else if (i == static_cast(len - 2)) { - new_bounds.push_back(Range::make_by_min_extent(op->bounds[i]->min, m_size_)); - continue; - } else { - new_bounds.push_back(op->bounds[i]); - } - } - return new_bounds; - } - Array ModifyCShape(const BufferNode *op, int len) { - Array new_shape; - for (size_t i = 0; i < 
op->shape.size(); ++i) { - if (i == static_cast(len - 1)) { - CHECK(n_size_.defined()); - new_shape.push_back(n_size_); - continue; - } else if (i == static_cast(len - 2)) { - CHECK(m_size_.defined()); - new_shape.push_back(m_size_); - continue; - } else { - new_shape.push_back(op->shape[i]); - } - } - return new_shape; - } - - private: - TensorCoreInfo info_; - Expr m_size_; - Expr n_size_; -}; - -class ModifyTheLocalOffset : public IRMutator { - public: - explicit ModifyTheLocalOffset(TensorCoreInfo &info) : info_(info) {} - - Expr Mutate_(const Call *op, const Expr &e) final { - if (op->is_intrinsic(air::ir::intrinsic::tvm_fill_fragment)) { - CHECK_EQ(op->args.size(), 6U); - Array args = op->args; - auto a0 = args[0].as(); - CHECK(a0); - std::string cur_tensor_name = a0->name_hint; - std::string cur_base_name = SimplifyName(cur_tensor_name); - - auto a4 = args[4]; - a4 = ChangeTensorIndex(a4, cur_base_name); - - Array new_args; - for (unsigned int i = 0; i < args.size(); ++i) { - if (i == 4) { - new_args.push_back(a4); - continue; - } - new_args.push_back(args[i]); - } - return Call::make(op->type, op->name, new_args, op->call_type, op->func, op->value_index); - - } else if (op->is_intrinsic(air::ir::intrinsic::tvm_load_matrix_sync)) { - CHECK_EQ(op->args.size(), 8U); - Array args = op->args; - auto a0 = args[0].as(); - CHECK(a0); - std::string cur_tensor_name = a0->name_hint; - std::string cur_base_name = SimplifyName(cur_tensor_name); - auto a4 = args[4]; - a4 = ChangeTensorIndex(a4, cur_base_name); - - Array new_args; - for (unsigned int i = 0; i < args.size(); ++i) { - if (i == 4) { - new_args.push_back(a4); - continue; - } - new_args.push_back(args[i]); - } - return Call::make(op->type, op->name, new_args, op->call_type, op->func, op->value_index); - - } else if (op->is_intrinsic(air::ir::intrinsic::tvm_store_matrix_sync)) { - CHECK_EQ(op->args.size(), 8U); - Array args = op->args; - auto a0 = args[0].as(); - CHECK(a0); - std::string cur_tensor_name = 
a0->name_hint; - std::string cur_base_name = SimplifyName(cur_tensor_name); - auto a4 = args[4]; - a4 = ChangeTensorIndex(a4, cur_base_name); - - Array new_args; - for (unsigned int i = 0; i < args.size(); ++i) { - if (i == 4) { - new_args.push_back(a4); - continue; - } - new_args.push_back(args[i]); - } - return Call::make(op->type, op->name, new_args, op->call_type, op->func, op->value_index); - - } else if (op->is_intrinsic(air::ir::intrinsic::tvm_mma_sync)) { - CHECK_EQ(op->args.size(), 8U); - Array args = op->args; - auto a0 = args[0].as(); - CHECK(a0); - std::string a0_name = a0->name_hint; - std::string a0_base_name = SimplifyName(a0_name); - auto a1 = args[1]; - a1 = ChangeTensorIndex(a1, a0_base_name); - - auto a2 = args[2].as(); - CHECK(a2); - std::string a2_name = a2->name_hint; - std::string a2_base_name = SimplifyName(a2_name); - auto a3 = args[3]; - a3 = ChangeTensorIndex(a3, a2_base_name); - - auto a4 = args[4].as(); - CHECK(a4); - std::string a4_name = a4->name_hint; - std::string a4_base_name = SimplifyName(a4_name); - auto a5 = args[5]; - a5 = ChangeTensorIndex(a5, a4_base_name); - - auto a6 = args[6].as(); - CHECK(a6); - std::string a6_name = a6->name_hint; - std::string a6_base_name = SimplifyName(a6_name); - auto a7 = args[7]; - a7 = ChangeTensorIndex(a7, a6_base_name); - - Array new_args; - new_args.push_back(args[0]); - new_args.push_back(a1); - new_args.push_back(args[2]); - new_args.push_back(a3); - new_args.push_back(args[4]); - new_args.push_back(a5); - new_args.push_back(args[6]); - new_args.push_back(a7); - - return Call::make(op->type, op->name, new_args, op->call_type, op->func, op->value_index); - } else { - return IRMutator::Mutate_(op, e); - } - } - - private: - Expr ChangeTensorIndex(Expr e, std::string name) { - auto matrix_map_info = info_.matrix_abc_; - if ((matrix_map_info[name] == MATRIX_A) || (matrix_map_info[name] == MATRIX_B)) { - if (e.as()) { - return e; - } - - if (e.as()) { - CHECK_EQ(e.as()->value, 0) << "A B matrix 
index should be 0"; - return e; - } - - if (e.as()) { - auto mul = e.as(); - auto a = mul->a; - auto b = mul->b; - CHECK(b.as()) << "A B matrix index format error"; - return a; - } - CHECK(false) << "A B matrix index error"; - - } else if (matrix_map_info[name] == MATRIX_C) { - if (e.as()) { - return e; - } - - if (e.as()) { - CHECK_EQ(e.as()->value, 0) << "C matrix index should be 0"; - return e; - } - - if (e.as()) { - auto add = e.as(); - auto a = add->a; - auto b = add->b; - CHECK(a.as()) << "A B matrix index format error"; - auto mul = a.as(); - auto mul_a = mul->a; - auto mul_b = mul->b; - CHECK(mul_b.as()) << "C matrix index mul format error"; - CHECK(info_.fragment_n_.defined()); - a = Mul::make(mul_a, info_.fragment_n_); - if (b.as()) { - auto b_mul = b.as(); - auto b_a = b_mul->a; - auto b_b = b_mul->b; - CHECK(b_b.as()) << "C matrix index b_b mul format error"; - b = b_a; - } - e = Add::make(a, b); - return e; - } - - if (e.as()) { - auto mul = e.as(); - auto a = mul->a; - auto b = mul->b; - CHECK(b.as()) << "C matrix index format error"; - return a; - } - CHECK(false) << "C matrix index error"; - } - - return e; - } - TensorCoreInfo info_; -}; - -class DeleteUselessAttr : public IRMutator { - public: - explicit DeleteUselessAttr() {} - Stmt Mutate_(const AttrStmt *op, const Stmt &s) final { - if ((op->attr_key == GMREAD_FLAG) || (op->attr_key == MATRIX_A) || (op->attr_key == MATRIX_B) || - (op->attr_key == MMA_C) || (op->attr_key == MMA_SYNC) || (op->attr_key == FRAGMENT_A) || - (op->attr_key == FRAGMENT_B)) { - return IRMutator::Mutate(op->body); - } - return IRMutator::Mutate_(op, s); - } -}; - -Stmt EmitForTensorCoreDesignOne(Stmt stmt, TensorCoreInfo &info) { - AdaptCastDesignOne adapt(info); - stmt = adapt.Mutate(stmt); - return stmt; -} - -Stmt EmitForTensorCore(Stmt stmt, TensorCoreInfo &info) { - stmt = ModifySizeOfLocal(info).Mutate(stmt); - stmt = ModifyTheLocalOffset(info).Mutate(stmt); - stmt = DeleteUselessAttr().Mutate(stmt); - - return 
stmt; -} -} // namespace poly -} // namespace ir -} // namespace akg diff --git a/src/poly/gpu_emit/gpu_isl_emitter.cc b/src/poly/gpu_emit/gpu_isl_emitter.cc new file mode 100644 index 0000000000000000000000000000000000000000..6063b7b31152de010d6828d8739862a2ee5a9dee --- /dev/null +++ b/src/poly/gpu_emit/gpu_isl_emitter.cc @@ -0,0 +1,581 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "gpu_isl_emitter.h" +#include "emit_pass.h" +#include +#include + +namespace akg { +namespace ir { +namespace poly { + +Expr GpuIslEmitter::EmitLoad(const isl::ast_expr &expr, const Type type) { + if (PRINT_EMITTER) { + LOG(INFO) << ">>>>>>>>>>>>INPUT AST_NODE[LOAD]<<<<<<<<<<<<<<\n" << expr; + } + if (auto op = expr.as()) { + if (auto access = op.as()) { + CHECK(op.get_arg(0).as()); + auto var = op.get_arg(0).as().get_id(); + Array local_args; + for (unsigned int i = 1; i < op.get_n_arg(); ++i) { + local_args.push_back(Interpret(op.get_arg(i))); + } + + Tensor t = info_.FindTensor(var); + auto call = Call::make(type, t->op->name, local_args, Call::CallType::Halide, t->op, t->value_index); + if (PRINT_EMITTER) { + LOG(INFO) << ">>>>>>>>>>>>OUTPUT STMT<<<<<<<<<<<<\n" << call; + } + return call; + } + } + return Expr(); +} + +Stmt GpuIslEmitter::EmitRead(const isl::ast_node_user &node) { + isl::id node_id = node.get_annotation(); + isl::pw_multi_aff iterator_map = 
node_info_map_.at(node_id).iterator_map; + isl::pw_multi_aff hoisted = iterator_map.range_factor_range(); + isl::pw_multi_aff original = iterator_map.range_factor_domain().range_factor_range(); + + isl::id original_tensor = original.get_tuple_id(isl_dim_out); + + auto build = node_info_map_.at(node_id).build; + auto lhs = build.access_from(isl::multi_pw_aff(hoisted)); + auto rhs = build.access_from(isl::multi_pw_aff(original)); + + Type type = info_.GetDtypeOf(rhs); + if (auto op = lhs.as()) { + if (auto access = op.as()) { + Expr value = EmitLoad(rhs, type); + auto var = op.get_arg(0).as().get_id(); + + Array local_args; + for (unsigned int i = 1; i < op.get_n_arg(); ++i) { + local_args.push_back(Interpret(op.get_arg(i))); + } + + Tensor t = info_.FindTensor(var); + CHECK(t.defined()); + return Provide::make(t->op, 0, value, local_args); + } + } + return Stmt(); +} + +Stmt GpuIslEmitter::EmitWrite(const isl::ast_node_user &node) { + auto node_id = node.get_annotation(); + CHECK_GT(node_info_map_.count(node_id), 0); + auto iterator_map = node_info_map_.at(node_id).iterator_map; + auto hoisted = iterator_map.range_factor_range(); + auto original = iterator_map.range_factor_domain().range_factor_range(); + + auto build = node_info_map_.at(node_id).build; + auto rhs = build.access_from(isl::multi_pw_aff(hoisted)); + auto lhs = build.access_from(isl::multi_pw_aff(original)); + Type type = info_.GetDtypeOf(lhs); + + if (auto op = lhs.as()) { + if (auto access = op.as()) { + Expr value = EmitLoad(rhs, type); + auto var = op.get_arg(0).as().get_id(); + + Array local_args; + for (unsigned int i = 1; i < op.get_n_arg(); ++i) { + local_args.push_back(Interpret(op.get_arg(static_cast(i)))); + } + + Tensor t = info_.FindTensor(var); + CHECK(t.defined()); + + return Provide::make(t->op, 0, value, local_args); + } + } + return Stmt(); +} + +Stmt GpuIslEmitter::EmitSync() { + return Evaluate::make(Call::make(Int(32), STORAGE_SYNC, {StringImm::make(SYNC_SCOP_SHARED)}, 
Call::Intrinsic)); +} + +Stmt GpuIslEmitter::EmitStmt(const isl::ast_node_user &node) { + CHECK(node.get_expr().isa()); + isl::ast_expr_op usr_expr = node.get_expr().as(); + CHECK(usr_expr); + auto stmt_id = usr_expr.get_arg(0).as().get_id(); + + if (info_.IsRead(stmt_id)) { + Stmt s; + s = EmitRead(node); + s = AttrStmt::make(Expr(""), GMREAD_FLAG, StringImm::make(GMREAD_FLAG), s); + return s; + } else if (info_.IsWrite(stmt_id)) { + return EmitWrite(node); + } else if (info_.IsSync(stmt_id)) { + return EmitSync(); + } else { + return EmitUserStmt(node); + } +} + +bool GpuIslEmitter::NoNeedToEmitForTempTensor(const isl::id &id) { + bool no_need = true; + auto origin_binds = info_.user_config_.GetOriginBind(); + for (auto i : origin_binds) { + if (!i.first.defined()) continue; + std::string name = i.first->op->name; + if (name == id.name()) { + no_need = false; + break; + } + } + return no_need; +} + +Stmt GpuIslEmitter::EmitBlock(const isl::ast_node_block &block_node) { + std::vector stmts; + + int num = block_node.get_children().size(); + int last_num = 0; + for (int i = num - 1; i >= 0; --i) { + auto child = block_node.get_children().at(i); + + if (auto node = child.as()) { + CHECK(node.get_expr().isa()); + isl::ast_expr_op usr_expr = node.get_expr().as(); + CHECK(usr_expr); + auto stmt_id = usr_expr.get_arg(0).as().get_id(); + if (info_.IsRealize(stmt_id)) { + isl::id new_stmt_id = isl::id(stmt_id.ctx(), stmt_id.name().substr(REALIZE_PREFIX_LEN)); + int stmt_num = stmts.size(); + CHECK_NE(stmt_num, 0) << "when stmt_num is zero, no realize should be emitted!."; + if (stmt_num == 1) { + stmts[0] = InsertRealize(stmts[0], new_stmt_id); + } else { + if (stmt_num - last_num == 1) { + stmts[0] = InsertRealize(stmts[0], new_stmt_id); + } else { + for (int index = stmt_num - 2 - last_num; index >= 0; --index) { + auto p_index = static_cast(index); + stmts[p_index] = Block::make(stmts[p_index], stmts[p_index + 1]); + } + stmts[0] = InsertRealize(stmts[0], new_stmt_id); 
+ } + } + last_num = stmt_num - 1; + continue; + } + } + + Stmt body = EmitAst(child); + if (!body.defined()) continue; + stmts.insert(stmts.begin(), body); + } + + int len = stmts.size(); + + if (len == 0) { + return Stmt(); + } + + if (last_num == len - 1) { + return stmts[0]; + } else { + for (int index = len - 2 - last_num; index >= 0; --index) { + auto p_index = static_cast(index); + stmts[p_index] = Block::make(stmts[p_index], stmts[p_index + 1]); + } + return stmts[0]; + } +} + +Stmt GpuIslEmitter::EmitFor(const isl::ast_node_for &node) { + isl::id isl_iter_id = node.get_iterator().as().get_id(); + VarExpr iter_expr(isl_iter_id.to_str()); + PushIter(iter_expr.get()); + + Expr init_expr = Interpret(node.get_init()); + + auto isl_cond = node.get_cond().as(); + CHECK(isl_cond.as() || isl_cond.as()); + auto cond_lhs = isl_cond.get_arg(0).as(); + CHECK(cond_lhs); + CHECK_EQ(cond_lhs.get_id(), isl_iter_id); + Expr cond_expr = Interpret(isl_cond.get_arg(1)); + + int64_t inc = static_cast(WrappedStrtol(node.get_inc().to_C_str())); + CHECK_NE(inc, 0) << "stride should not be zero!."; + + bool need_to_modify_inc_ = false; + if (inc != 1) { + need_to_modify_inc_ = true; + Expr original_init_expr = init_expr; + init_expr = ModifyTheInitExpr(init_expr); + cond_expr = ModifyTheCondExpr(cond_expr, static_cast(inc)); + Expr modify_iter = ModifyTheIterExpr(iter_expr, static_cast(inc), original_init_expr); + stride_modify_iter_map_[iter_expr.get()] = modify_iter; + } + + if (isl_cond.as()) { + cond_expr = Simplify(cond_expr + 1); + } + + cond_expr = Simplify(cond_expr - init_expr); + + Stmt body_stmt = EmitAst(node.get_body()); + + if (!body_stmt.defined()) { + PopIter(iter_expr.get()); + return Stmt(); + } + + if (need_to_modify_inc_) { + stride_modify_iter_map_.erase(iter_expr.get()); + } + PopIter(iter_expr.get()); + Stmt stmt = For::make(iter_expr, init_expr, cond_expr, ForType::Serial, DeviceAPI::None, body_stmt); + return stmt; +} + +Stmt GpuIslEmitter::EmitIf(const 
isl::ast_node_if &node) { + Expr cond_expr = Interpret(node.get_cond()); + cur_if_list_.push_back(cond_expr.get()); + Stmt then_case = EmitAst(node.get_then_node()); + if (!then_case.defined()) { + return Stmt(); + } + Stmt else_case; + if (node.has_else_node()) { + else_case = EmitAst(node.get_else_node()); + } + cur_if_list_.pop_back(); + + Stmt s; + if (!cond_expr.defined()) { + s = then_case; + } else { + s = IfThenElse::make(cond_expr, then_case, else_case); + } + + return s; +} + +Expr GpuIslEmitter::ModifyTheInitExpr(const Expr &e) { return 0; } + +Expr GpuIslEmitter::ModifyTheCondExpr(const Expr &e, int inc) { return e / Expr(inc); } + +Expr GpuIslEmitter::ModifyTheIterExpr(const VarExpr &iter, int inc, const Expr &init) { + return Simplify(iter * inc + init); +} + +int GpuIslEmitter::GetThreadExtent(const std::string &name) { + if (name == BLOCK_IDX_X || name == BLOCK_IDX_Y || name == BLOCK_IDX_Z) { + auto block_cfg = info_.user_config_.GetBlockConfig(); + CHECK(block_cfg) << "block config is null."; + return name == BLOCK_IDX_X ? block_cfg->GetX().second + : (name == BLOCK_IDX_Y ? block_cfg->GetY().second : block_cfg->GetZ().second); + } + + if (name == THREAD_IDX_X || name == THREAD_IDX_Y || name == THREAD_IDX_Z) { + auto thread_cfg = info_.user_config_.GetThreadConfig(); + CHECK(thread_cfg) << "thread config is null."; + if (info_.user_config_.GetEnableOneDimThread()) { + return name == THREAD_IDX_X ? (thread_cfg->GetX().second * thread_cfg->GetY().second * thread_cfg->GetZ().second) + : 1; + } + return name == THREAD_IDX_X ? thread_cfg->GetX().second + : (name == THREAD_IDX_Y ? 
thread_cfg->GetY().second : thread_cfg->GetZ().second); + } + LOG(WARNING) << "Unrecognized thread name " << name; + return 1; +} + +Stmt GpuIslEmitter::Emit(const isl::ast_node &node) { + Stmt stmt = EmitAst(node); + + // emit realize for temporary tensor + stmt = EmitRealizeForGlobalTensor(stmt); + + // iter var node attr emit + std::map::iterator it; + for (it = iter_name_map_.begin(); it != iter_name_map_.end(); it++) { + IterVar axis = IterVarNode::make(Range(), it->second, air::kThreadIndex, it->second->name_hint); + stmt = AttrStmt::make(axis, air::ir::attr::thread_extent, Expr(GetThreadExtent(it->second->name_hint)), stmt); + } + + // attr for one dimension mapping + if (info_.user_config_.GetEnableOneDimThread()) { + auto thread_cfg = info_.user_config_.GetThreadConfig(); + CHECK(thread_cfg) << "thread config is null."; + int tx = thread_cfg->GetX().second; + stmt = AttrStmt::make(Expr(""), ORIGIN_THREAD_DIM_X, Expr(tx), stmt); + } + + return stmt; +} + +Stmt GpuIslEmitter::EmitRealizeForGlobalTensor(Stmt stmt) { + auto binds = info_.user_config_.GetBind(); + auto origin_binds = info_.user_config_.GetOriginBind(); + std::unordered_set tensor_name; + + for (auto i : binds) { + if (!i.first.defined()) continue; + tensor_name.insert(i.first->op->name); + } + + for (auto i : binds) { + if (!i.first.defined()) continue; + // input and output tensor, no need to emit realize + if (origin_binds.find(i.first) != origin_binds.end()) { + continue; + } + + // promoted tensor, the realize info already emitted before + std::string name = i.first->op->name; + if (IsEndsWith(name, MEM_TYPE_SHARED) || IsEndsWith(name, MEM_TYPE_LOCAL)) { + continue; + } + + // if the tensor is temporary, but has already promoted, there is no need to emit realize + if (tensor_name.find(name + "_" + MEM_TYPE_SHARED) != tensor_name.end() || + tensor_name.find(name + "_" + MEM_TYPE_LOCAL) != tensor_name.end()) { + continue; + } + + // if the tensor is temporary and it is not promoted, it needs 
to emit realize + stmt = InsertRealize(stmt, isl::id(info_.GetCtx(), name)); + } + return stmt; +} + +Stmt GpuIslEmitter::EmitMark(const isl::ast_node_mark &node) { + std::string mark = node.get_id().get_name(); + + // add for prefetch pass + if (mark == PROMOTE_GLOBAL_TO_SHARED_AB) { + Stmt stmt = EmitAst(node.get_node()); + if (!stmt.defined()) { + return Stmt(); + } + return AttrStmt::make(Expr("INFO"), SHARED_MEM_PROMOTED_COMPLETE, StringImm::make(SHARED_MEM_PROMOTED_COMPLETE), + stmt); + } + + Stmt stmt; + + if ((mark == PROMOTE_VECTORIZATION) || (mark == PROMOTE_REGISTER_TO_GLOBAL) || (mark == PROMOTE_REGISTER_TO_SHARED) || + (mark == PROMOTE_SHARED_TO_GLOBAL)) { + stmt = EmitAst(node.get_node()); + if (!stmt.defined()) { + return Stmt(); + } + stmt = AttrStmt::make(Expr("INFO"), mark, StringImm::make(mark), stmt); + } else { + stmt = EmitAst(node.get_node()); + } + + return stmt; +} + +std::string GpuIslEmitter::FindRealizeScopeToString(const isl::id &var) { + if (info_.analysis_result_.CountBufferDefInfo(var)) { + auto tensor_info = info_.analysis_result_.GetBufferDefInfo(var); + MemType mem_type = tensor_info.DstMemType(); + + switch (mem_type) { + case MemType::SHARED_: + return MEM_TYPE_SHARED; + case MemType::LOCAL_: + return MEM_TYPE_LOCAL; + default: + LOG(FATAL) << "unexpected mem_type of var " << var; + return "ERROR"; + } + } + return ""; +} + +Expr GpuIslEmitter::FindRealizeScope(const isl::id &var) { return Expr(FindRealizeScopeToString(var)); } + +Stmt GpuIslEmitter::InsertRealize(Stmt stmt, const isl::id &var) { + stmt = FindInnerRealize(var.get_name()).Mutate(stmt); + + // A tensor may be defined multiple times in BufferDefInfo due to nested realize. + // Because we cannot determine which one we actually want, we have to be conservative here + // and allocate space for the largest shape to avoid overflow. 
+ Tensor t = info_.FindTensorWithLargestShape(var); + Region bounds; + + // no isolate + if (bounds.empty()) { + for (auto j : t->shape) { + bounds.push_back(Range::make_by_min_extent(Expr(0), j)); + } + } + + // If isolate, make a new buffer + auto buf = info_.user_config_.GetBind().at(t); + + auto tt = placeholder(t->shape, t->dtype, t->op->name); + + stmt = TensorSubstitute(stmt, t->op, tt->op, tt->value_index); + t = tt; + if (info_.analysis_result_.CountBufferDefInfo(var)) { + auto decl = info_.analysis_result_.GetBufferDefInfo(var); + decl.tensor = t; + } + info_.user_config_.SetBind(t, buf); + stmt = TensorSubstitute2(stmt, t->op->func_name(), t->op, t->value_index); + stmt = Realize::make(t->op, t->value_index, t->dtype, bounds, const_true(1), stmt); + stmt = AttrStmt::make(t->op, air::ir::attr::realize_scope, FindRealizeScope(var), stmt); + + return stmt; +} + +Expr GpuIslEmitter::IterNameAdaptor(std::string name) { + if (iter_name_map_.find(name) != iter_name_map_.end()) { + return iter_name_map_[name]; + } else if (name.find(REPLACE) != std::string::npos) { + name = name.substr(strlen(REPLACE)); + return AdaptPolyNewVar(name); + } else { + return VarExpr(name); + } +} + +// if new var is added in poly process, modify the logic here. 
+// another modify pos is IterNameAdaptor interface +Expr GpuIslEmitter::AdaptPolyNewVar(std::string name) { + Expr e; + std::string t0_string = T0; + int suffix_len = t0_string.size() + 1; + auto tensor_name = name.substr(0, name.size() - suffix_len); + if (!info_.user_config_.GetReplaceConfig().count(tensor_name)) { + return e; + } + auto mapping_cfg = (info_.user_config_.GetReplaceConfig()[tensor_name]); + CHECK(mapping_cfg) << "mapping config is null."; + if (mapping_cfg->type == MappingType::REPLACE_THREADS) { + e = AdaptThreadNewVar(name, mapping_cfg); + } else { + e = AdaptBlockNewVar(name, mapping_cfg); + } + CHECK(e.defined()) << "new var is null"; + return e; +} + +Expr GpuIslEmitter::AdaptBlockNewVar(const std::string name, MappingCfg *mapping_cfg) { + Expr e; + if (name.find(CONV_H_W) != std::string::npos) { + int mx = mapping_cfg->GetX().second; + if (name.find(B0) != std::string::npos) { + e = Mod::make(iter_name_map_[B1], mx); + return e; + } else if (name.find(B1) != std::string::npos) { + e = Div::make(iter_name_map_[B1], mx); + return e; + } + } else if (name.find(CONV_N) != std::string::npos) { + return iter_name_map_[B2]; + } else if (name.find(CONV_O) != std::string::npos) { + return iter_name_map_[B0]; + } + return e; +} + +Expr GpuIslEmitter::AdaptThreadNewVar(const std::string name, MappingCfg *mapping_cfg) { + Expr e; + int mx = mapping_cfg->GetX().second; + if (name.find(WARP_COMPUTE) != std::string::npos) { + if (name.find(T0) != std::string::npos) { + e = Div::make(iter_name_map_[T0], WARP_SIZE); + e = Mod::make(e, mx); + return e; + } else if (name.find(T1) != std::string::npos) { + e = Div::make(iter_name_map_[T0], WARP_SIZE); + e = Div::make(e, mx); + return e; + } + } else { + Expr div_e = iter_name_map_[T0]; + for (size_t i = 0; i < mapping_cfg->bound; ++i) { + std::string thread_id_name = "t" + std::to_string(i); + if (name.find(thread_id_name) == std::string::npos) { + continue; + } + + e = iter_name_map_[T0]; + int 
thread_id_number = mapping_cfg->GetAt(i).second; + + if (i == 0) { + e = Mod::make(e, thread_id_number); + return e; + } + + for (size_t j = 0; j < i; ++j) { + thread_id_number = mapping_cfg->GetAt(j).second; + e = Div::make(e, thread_id_number); + } + + thread_id_number = mapping_cfg->GetAt(i).second; + e = Mod::make(e, thread_id_number); + return e; + } + } + return e; +} + +Expr GpuIslEmitter::Interpret(const isl::ast_expr &e) { + if (auto int_expr = e.as()) { + return Expr(IslExprToSInt(int_expr)); + } else if (auto id_expr = e.as()) { + // If this variable is defined by loop index, we need sharing it. + const Variable *var = GetIterByName(id_expr.get_id().get_name()); + if (var) { + if (stride_modify_iter_map_.find(var) != stride_modify_iter_map_.end()) { + return stride_modify_iter_map_[var]; + } + return VarExpr(GetObjPtr(var)); + } else { + return IterNameAdaptor(id_expr.get_id().to_str()); + } + } else if (auto op_expr = e.as()) { + return InterpretOp(op_expr); + } else { + LOG(FATAL) << "NYI " << e; + return 0; + } +} + +Stmt GpuIslEmitter::EmitAccessNodeFromPromoteAcsCall(isl::id var, const Node *node, Array &args) { + const Call *call = static_cast(node); + Tensor t = info_.FindTensor(var); + return Evaluate::make(Call::make(call->type, var.get_name(), args, call->call_type, t->op, t->value_index)); +} + +Stmt GpuIslEmitter::EmitAccessNodeFromPromoteAcsProvide(isl::id var, const Node *node, Array &args) { + const auto provide = static_cast(node); + Tensor t = info_.FindTensor(var); + Stmt s = Provide::make(t->op, 0, provide->value, args); + return s; +} + +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/gpu_emit/gpu_isl_emitter.h b/src/poly/gpu_emit/gpu_isl_emitter.h new file mode 100644 index 0000000000000000000000000000000000000000..9a7416730ddefbd98b9376da104b9aa190b6fad6 --- /dev/null +++ b/src/poly/gpu_emit/gpu_isl_emitter.h @@ -0,0 +1,88 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed 
under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef POLY_GPU_ISL_EMITTER_H_ +#define POLY_GPU_ISL_EMITTER_H_ + +#include "poly/isl_emitter.h" +#include "pass/utils.h" + +namespace akg { +namespace ir { +namespace poly { + +// add for mind tricks swizzle +constexpr auto MIND_TRICKS_SWIZZLE_MARKER = "mind_trick_swizzle_marker"; +constexpr auto MIND_TRICKS_SWIZZLE_PRAGMA = "pragma_swizzle"; + +// add for one dimension mapping +constexpr auto ORIGIN_THREAD_DIM_X = "bind_thread_x"; +constexpr auto SHARED_MEM_PROMOTED_COMPLETE = "shared_mem_promoted_complete"; + +class GpuIslEmitter : public IslEmitter { + public: + GpuIslEmitter(ScopInfo &info, const NodeInfoRepo &n, const isl::id_list &i) : IslEmitter(info, n, i) {} + ~GpuIslEmitter() override = default; + + bool NoNeedToEmitForTempTensor(const isl::id &id); + Stmt Emit(const isl::ast_node &node) override; + Expr Interpret(const isl::ast_expr &e); + Stmt EmitStmt(const isl::ast_node_user &node) override; + Stmt EmitMark(const isl::ast_node_mark &node_id) override; + virtual Expr AdaptPolyNewVar(std::string name); + Expr AdaptThreadNewVar(const std::string name, MappingCfg *mapping_cfg); + Expr AdaptBlockNewVar(const std::string name, MappingCfg *mapping_cfg); + int GetThreadExtent(const std::string &name); + virtual Expr IterNameAdaptor(std::string name); + std::map iter_name_map_{{B0, VarExpr(BLOCK_IDX_X)}, {B1, VarExpr(BLOCK_IDX_Y)}, + {B2, VarExpr(BLOCK_IDX_Z)}, {T0, VarExpr(THREAD_IDX_X)}, + {T1, 
VarExpr(THREAD_IDX_Y)}, {T2, VarExpr(THREAD_IDX_Z)}}; + + private: + // override emitters for GPU + Stmt EmitBlock(const isl::ast_node_block &node) final; + Stmt EmitFor(const isl::ast_node_for &node) final; + Stmt EmitIf(const isl::ast_node_if &node) final; + + // DMA emitters for GPU + Expr EmitLoad(const isl::ast_expr &lhs, Type type); + Stmt EmitRead(const isl::ast_node_user &node); + Stmt EmitWrite(const isl::ast_node_user &node); + + Stmt EmitAccessNodeFromPromoteAcsCall(isl::id var, const Node *node, Array &args); + Stmt EmitAccessNodeFromPromoteAcsProvide(isl::id var, const Node *node, Array &args); + + Stmt EmitSync(); + Stmt EmitAttr(); // thread_extent, virtual_thread + + Expr FindRealizeScope(const isl::id &var); + std::string FindRealizeScopeToString(const isl::id &var); + Stmt InsertRealize(Stmt stmt, const isl::id &var); + + Expr SingleConfigToMultiBand(std::string name); + + Expr ModifyTheInitExpr(const Expr &e); + Expr ModifyTheCondExpr(const Expr &e, int inc); + Expr ModifyTheIterExpr(const VarExpr &iter, int inc, const Expr &init); + + Stmt EmitRealizeForGlobalTensor(Stmt stmt); + + std::unordered_map stride_modify_iter_map_; +}; + +} // namespace poly +} // namespace ir +} // namespace akg +#endif // POLY_GPU_ISL_EMITTER_H_ diff --git a/src/poly/gpu_emit/gpu_isl_emitter_reduce.cc b/src/poly/gpu_emit/gpu_isl_emitter_reduce.cc new file mode 100644 index 0000000000000000000000000000000000000000..85a6cd2bd6d36cc27d73a56a37ef85e61019809c --- /dev/null +++ b/src/poly/gpu_emit/gpu_isl_emitter_reduce.cc @@ -0,0 +1,108 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "emit_pass.h" +#include "gpu_isl_emitter_reduce.h" + +namespace akg { +namespace ir { +namespace poly { + +Stmt GpuIslEmitterReduce::Emit(const isl::ast_node &node) { + Stmt stmt = GpuIslEmitter::Emit(node); + + stmt = EmitForReduce(stmt, info_); + + return stmt; +} + +Stmt GpuIslEmitterReduce::EmitMark(const isl::ast_node_mark &node) { + std::string mark = node.get_id().get_name(); + if (IsStartsWith(mark, REDUCE_ATOMIC_FLAG) || mark == REDUCE_AREA_FLAG) { + Stmt stmt = EmitAst(node.get_node()); + if (!stmt.defined()) { + return Stmt(); + } + return AttrStmt::make(Expr("INFO"), mark, StringImm::make(mark), stmt); + } + return GpuIslEmitter::EmitMark(node); +} + +Stmt GpuIslEmitterReduce::EmitStmt(const isl::ast_node_user &node) { + CHECK(node.get_expr().isa()); + isl::ast_expr_op usr_expr = node.get_expr().as(); + CHECK(usr_expr); + auto stmt_id = usr_expr.get_arg(0).as().get_id(); + auto node_id = node.get_annotation(); + + if (info_.IsWrite(stmt_id)) { + if (info_.IsGMWrite(stmt_id)) { + auto iterator_map = node_info_map_.at(node_id).iterator_map; + auto original = iterator_map.range_factor_domain().range_factor_range(); + auto srcid = original.get_tuple_id(isl_dim_out); + bool no_need_to_emit = GpuIslEmitter::NoNeedToEmitForTempTensor(srcid); + if (no_need_to_emit) return Stmt(); + } + } else if (info_.IsReduceInit(stmt_id) || info_.IsReduceUpdate(stmt_id)) { + return EmitFilter(stmt_id.get_name()); + } + return GpuIslEmitter::EmitStmt(node); +} + +Stmt GpuIslEmitterReduce::EmitFilter(std::string name) { + return 
Evaluate::make(Call::make(Int(32), name, {}, Call::Extern)); +} + +Stmt GpuIslEmitterReduce::EmitUserStmt(const isl::ast_node_user &node) { + CHECK(node.get_expr().isa()); + isl::ast_expr_op usr_expr = node.get_expr().as(); + stmt_id_ = usr_expr.get_arg(0).as().get_id(); + node_id_ = node.get_annotation(); + const Node *stmt_node = info_.analysis_result_.GetStatementMap().at(stmt_id_); + CHECK(stmt_node); + // compute VarMap to replace old iterators + auto build = node_info_map_.at(node_id_).build; + auto tuple = info_.analysis_result_.GetOperatorDomainMap().at(stmt_id_).tuple; + auto iterator_map = node_info_map_.at(node_id_).iterator_map; + + bool init_stmt_emit = false; + auto ids = info_.analysis_result_.GetReduceInitIds(); + for (auto &i : ids) { + if (i.get_name() == stmt_id_.get_name()) { + init_stmt_emit = true; + break; + } + } + + var_map_.clear(); + for (unsigned int i = 0; i < tuple.size(); ++i) { + isl::id isl_old_iter = tuple.get_id(i); + auto isl_expr = build.expr_from(iterator_map.get_pw_aff(i)); + Expr halide_new_iter = Interpret(isl_expr); + var_map_.emplace(isl_old_iter, halide_new_iter); + } + + Stmt stmt = EmitUserStmtContent(stmt_node); + + if (init_stmt_emit) { + stmt = AttrStmt::make(Expr("INFO"), REDUCE_INIT_FLAG, StringImm::make(""), stmt); + } + return stmt; +} + +} // namespace poly +} // namespace ir +} // namespace akg \ No newline at end of file diff --git a/src/poly/gpu_emit/gpu_isl_emitter_reduce.h b/src/poly/gpu_emit/gpu_isl_emitter_reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..702637fb870282428c9cacc2a5a3c27f3cd4456f --- /dev/null +++ b/src/poly/gpu_emit/gpu_isl_emitter_reduce.h @@ -0,0 +1,91 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef POLY_GPU_ISL_EMITTER_REDUCE_H_ +#define POLY_GPU_ISL_EMITTER_REDUCE_H_ + +#include "ir_pass.h" +#include "gpu_isl_emitter.h" + +namespace akg { +namespace ir { +namespace poly { + +/*! + * IslEmitter for GPU + */ +constexpr auto AKG_ALL_REDUCE = "akg_reduce::ALL_REDUCE"; +constexpr auto AKG_X_REDUCE = "akg_reduce::REDUCE2D_X"; +constexpr auto AKG_Y_REDUCE = "akg_reduce::REDUCE2D_Y"; + +// example: +// red_init_SumOp_S_1_0 +constexpr auto REDUCE_FLAG_SIZE = 6; +constexpr auto REDUCE_FLAG_TYPE_POS = 2; +constexpr auto REDUCE_FLAG_STMT_PREFIX_POS = 3; +constexpr auto REDUCE_FLAG_STMT_NUM_POS = 4; +constexpr auto REDUCE_FLAG_REDUCE_INDEX = 5; + +// example: +// atomic_SumOp +constexpr auto REDUCE_ATOMIC_FLAG_SIZE = 2; +constexpr auto REDUCE_ATOMIC_FLAG = "atomic"; +constexpr auto REDUCE_ATOMIC_FLAG_POS = 0; +constexpr auto REDUCE_ATOMIC_FLAG_TYPE_POS = 1; + +constexpr auto DEFAULT_TENSOR_INDEX = "[0]"; + +constexpr auto USELESS_INDEX = "0"; +constexpr auto USELESS_SHAPE_SIZE = "1"; +constexpr auto SCALAR_TENSOR_PREFIX = "acc_"; +constexpr auto SCALAR_KHT_PREFIX = "kahan_t"; +constexpr auto SCALAR_KHY_PREFIX = "kahan_y"; +constexpr auto SCALAR_KHC_PREFIX = "kahan_c"; +constexpr auto SHARED_MEMORY_PREFIX = "__shared__"; +constexpr auto SHARED_TENSOR_PREFIX = "red_buf"; + +constexpr auto REDUCE_LIB_TYPE_ORIGIN = "origin"; +constexpr auto REDUCE_LIB_TYPE_PARIS = "paris"; +constexpr auto AKG_REDUCE_LIB_SPACE = "akg_reduce"; +constexpr auto AKG_REDUCE_LIB_NAME = "AkgReduce"; +constexpr auto AKG_KAHAN_LIB_NAME = "AkgKahanAccumulation"; 
+constexpr auto PARIS_REDUCE_LIB_SPACE = "paris_reduce"; +constexpr auto PARIS_REDUCE_LIB_NAME = "ParisReduce"; +constexpr auto AKG_REDUCE_RETURN_NAME = "AkgAtomicReturn"; +constexpr auto PARIS_REDUCE_RETURN_NAME = "ParisReturn"; +constexpr auto REDUCE_LIB_TYPE_FLAG = "reduceLibType"; +constexpr auto REDUCE_INIT_FLAG = "InitStmt"; + +constexpr auto MEM_TYPE_SHARED = "shared"; +constexpr auto MEM_TYPE_LOCAL = "local"; + +class GpuIslEmitterReduce : public GpuIslEmitter { + public: + GpuIslEmitterReduce(ScopInfo &info, const NodeInfoRepo &n, const isl::id_list &i) : GpuIslEmitter(info, n, i) {} + ~GpuIslEmitterReduce() override = default; + + Stmt Emit(const isl::ast_node &node) final; + Stmt EmitUserStmt(const isl::ast_node_user &node); + + private: + Stmt EmitMark(const isl::ast_node_mark &node_id) final; + Stmt EmitStmt(const isl::ast_node_user &node) final; + Stmt EmitFilter(std::string name); +}; +} // namespace poly +} // namespace ir +} // namespace akg + +#endif // POLY_GPU_ISL_EMITTER_REDUCE_H_ \ No newline at end of file diff --git a/src/poly/gpu_emit/gpu_isl_emitter_tensor_core.cc b/src/poly/gpu_emit/gpu_isl_emitter_tensor_core.cc new file mode 100644 index 0000000000000000000000000000000000000000..56402b52ec42b0b8b0a8060b1bad344bb97776fd --- /dev/null +++ b/src/poly/gpu_emit/gpu_isl_emitter_tensor_core.cc @@ -0,0 +1,82 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "emit_pass.h" +#include "gpu_isl_emitter_reduce.h" + +namespace akg { +namespace ir { +namespace poly { +Stmt GpuIslEmitterTensorCore::Emit(const isl::ast_node &node) { + Stmt stmt = GpuIslEmitter::Emit(node); + + if (info_.user_config_.GetEnableTensorCoreUsePoly() && info_.user_config_.GetEnableEmitCore()) { + stmt = EmitForTensorCore(stmt, tensor_core_info_, info_); + } else { + tensor_core_info_.cast_tensors_ = info_.analysis_result_.GetCastTensors(); + stmt = EmitForTensorCoreDesignOne(stmt, tensor_core_info_); + } + + return stmt; +} + +Stmt GpuIslEmitterTensorCore::EmitStmt(const isl::ast_node_user &node) { + CHECK(node.get_expr().isa()); + isl::ast_expr_op usr_expr = node.get_expr().as(); + CHECK(usr_expr); + auto stmt_id = usr_expr.get_arg(0).as().get_id(); + auto node_id = node.get_annotation(); + + if (info_.IsGMWrite(stmt_id) || info_.IsGMLWrite(stmt_id)) { + auto iterator_map = node_info_map_.at(node_id).iterator_map; + auto original = iterator_map.range_factor_domain().range_factor_range(); + auto srcid = original.get_tuple_id(isl_dim_out); + bool no_need_to_emit = GpuIslEmitter::NoNeedToEmitForTempTensor(srcid); + if (no_need_to_emit) return Stmt(); + } + return GpuIslEmitter::EmitStmt(node); +} + +Stmt GpuIslEmitterTensorCore::EmitMark(const isl::ast_node_mark &node) { + std::string mark = node.get_id().get_name(); + // add for tensor core + if (mark == WARP_MARKER || mark == CONV_KHKW_OUTER) { + Stmt stmt = EmitAst(node.get_node()); + if (!stmt.defined()) { + return Stmt(); + } + return AttrStmt::make(Expr("INFO"), mark, StringImm::make(mark), stmt); + } + return GpuIslEmitter::EmitMark(node); +} + +void GetNameWithoutShared(isl::id &tensor_id, ScopInfo &info) { + if (info.user_config_.GetEnableMatmul()) { + size_t pos = tensor_id.get_name().find(SHARE_SUFFIX); + std::string substr = tensor_id.get_name().substr(0, pos); + if (pos != 0) tensor_id = isl::id(tensor_id.ctx(), substr); + } +} + +isl::multi_aff 
GpuIslEmitterTensorCore::TensorAccessMultAff(isl::id &tensor_id, const Array &tensor_index, + const isl::id &node_id) { + GetNameWithoutShared(tensor_id, info_); + return IslEmitter::TensorAccessMultAff(tensor_id, tensor_index, node_id); +} + +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/gpu_emit/gpu_isl_emitter_tensor_core.h b/src/poly/gpu_emit/gpu_isl_emitter_tensor_core.h new file mode 100644 index 0000000000000000000000000000000000000000..c1ce8a567fdf0b45f27ab5e912350336e71b62d3 --- /dev/null +++ b/src/poly/gpu_emit/gpu_isl_emitter_tensor_core.h @@ -0,0 +1,90 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef POLY_GPU_ISL_EMITTER_TENSOR_CORE_H_ +#define POLY_GPU_ISL_EMITTER_TENSOR_CORE_H_ + +#include "ir_pass.h" +#include "gpu_isl_emitter.h" + +namespace akg { +namespace ir { +namespace poly { + +// add for tensor core +constexpr auto MMA_A = "matrix_a"; +constexpr auto MMA_B = "matrix_b"; +constexpr auto MMA_C = "accumulator"; +constexpr auto MMA_SYNC = "matrix_sync"; +constexpr auto MMA_PREFIX = "matrix_"; +constexpr auto MMA_FILL_STMT_SERIAL = 2; +constexpr auto MMA_SYNC_STMT_SERIAL = 1; +constexpr auto CAST_FLAG = "CAST"; +constexpr auto CAST_MODE_1 = "mode1"; +constexpr auto GMREAD_FLAG = "GMRead"; +constexpr auto FRAGMENT_A = "fragment_a"; +constexpr auto FRAGMENT_B = "fragment_b"; +constexpr auto FRAGMENT_C = "fragment_c"; + +constexpr auto FOR_INFO_COLLECT_DEPTH = 3; +constexpr auto LOCAL_INDEX_POS = 4; +constexpr auto TENSOR_CORE_MODE_ONE = "1"; +constexpr auto TENSOR_CORE_MODE_TWO = "2"; +constexpr auto WARP_MARKER = "warp_marker"; + +constexpr auto DATA_LOAD_STORE_FOR_DEPTH = 2; +constexpr auto DATA_COMPUTE_FOR_DEPTH = 3; +constexpr auto CONV_OUTPUT_DIMENSION = 4; +constexpr auto CONV_MATRIXA_DIMENSION = 4; + +struct Tile { + int m{-1}; + int n{-1}; + int k{-1}; +}; + +class TensorCoreInfo { + public: + Tile warp_tile_; + + std::unordered_map matrix_major_; + std::unordered_map matrix_abc_; + std::unordered_map bounds_; + std::unordered_map> strides_; + std::set frag_reg_; + std::unordered_set cast_tensors_; + std::unordered_map> min_bounds_; + std::string wmma_scope_; +}; + +class GpuIslEmitterTensorCore : public GpuIslEmitter { + public: + GpuIslEmitterTensorCore(ScopInfo &info, const NodeInfoRepo &n, const isl::id_list &i) : GpuIslEmitter(info, n, i) {} + ~GpuIslEmitterTensorCore() override = default; + + Stmt Emit(const isl::ast_node &node) final; + + private: + Stmt EmitStmt(const isl::ast_node_user &node) final; + Stmt EmitMark(const isl::ast_node_mark &node_id) final; + isl::multi_aff TensorAccessMultAff(isl::id &tensor_id, const Array 
&subscripts, const isl::id &stmt_id); + TensorCoreInfo tensor_core_info_; +}; + +} // namespace poly +} // namespace ir +} // namespace akg + +#endif // POLY_GPU_ISL_EMITTER_TENSOR_CORE_H_ \ No newline at end of file diff --git a/src/poly/gpu_emit/gpu_reduce_emit_pass.cc b/src/poly/gpu_emit/gpu_reduce_emit_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..59d854fe10b7a07a638472eb94b1b3a4a2bcdceb --- /dev/null +++ b/src/poly/gpu_emit/gpu_reduce_emit_pass.cc @@ -0,0 +1,754 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! 
+ * \file gpu_reduce_emit_pass.cc + */ + +#include "emit_pass.h" +#include "gpu_isl_emitter_reduce.h" + +namespace akg { +namespace ir { +namespace poly { + +struct ReduceData { + Expr init_value; + const Provide *origin_reduce_stmt_; + + std::string reduce_stmt_index_; + std::string scalar_tensor_name_; + std::string scalar_kht_name_; + std::string scalar_khy_name_; + std::string scalar_khc_name_; + Expr input_tensor_expr_; + std::string shared_compute_name_; + std::string reduce_op_; + std::string promoted_tensor_name_for_reduce_; + std::string akg_reduce_api_; + std::string akg_reduce_template_arg_; + Type reduce_data_type_info_; + std::map scalar_tensor_; + Tensor shared_tensor_; + std::vector stmts_; +}; + +class ReduceInfoCollect : public IRVisitor { + public: + explicit ReduceInfoCollect(ScopInfo &scop_info) : scop_info_(scop_info) {} + using IRVisitor::Visit_; + + void Visit_(const Call *op) final { + std::string name = op->name; + if (ScopInfo::IsReduceInit(name)) { + reduce_valid_ = true; + in_reduce_area_ = true; + ReduceData reduce_data; + std::vector strs = common::Split(name, "_"); + CHECK_EQ(strs.size(), REDUCE_FLAG_SIZE) << "red update format is not right!."; + + reduce_data.reduce_stmt_index_ = strs[REDUCE_FLAG_REDUCE_INDEX]; + reduce_data.scalar_tensor_name_ = SCALAR_TENSOR_PREFIX; + reduce_data.scalar_tensor_name_ += reduce_data.reduce_stmt_index_; + + reduce_data.shared_compute_name_ = SHARED_TENSOR_PREFIX; + reduce_data.shared_compute_name_ += reduce_data.reduce_stmt_index_; + + if (AkgSupportedReduceOp.count(strs[REDUCE_FLAG_TYPE_POS])) { + reduce_data.reduce_op_ = AKG_REDUCE_LIB_SPACE; + reduce_data.reduce_op_ += "::"; + reduce_data.reduce_op_ += strs[REDUCE_FLAG_TYPE_POS]; + } + CHECK(!reduce_data.reduce_op_.empty()) << "reduce op should not be empty!"; + if (reduce_data.reduce_op_.find("SumOp") != std::string::npos) { + reduce_data.scalar_kht_name_ = SCALAR_KHT_PREFIX; + reduce_data.scalar_kht_name_ += reduce_data.reduce_stmt_index_; + 
reduce_data.scalar_khy_name_ = SCALAR_KHY_PREFIX; + reduce_data.scalar_khy_name_ += reduce_data.reduce_stmt_index_; + reduce_data.scalar_khc_name_ = SCALAR_KHC_PREFIX; + reduce_data.scalar_khc_name_ += reduce_data.reduce_stmt_index_; + } + cur_reduce_stmt_ = strs[REDUCE_FLAG_STMT_PREFIX_POS] + "_" + strs[REDUCE_FLAG_STMT_NUM_POS]; + + std::string origin_tensor_name = ""; + for (auto it : scop_info_.analysis_result_.GetReduceTensorInfoMap()) { + if (it.first.name() == cur_reduce_stmt_) { + origin_tensor_name = it.second.write_tensor_name; + reduce_data.reduce_data_type_info_ = it.second.write_dtype; + break; + } + } + CHECK(!origin_tensor_name.empty()) << "origin_tensor_name should not be empty!"; + + for (const auto &buffer : scop_info_.analysis_result_.active_buffer_footprints_) { + auto cluster_id = buffer.second.cluster_id; + auto buf_def = scop_info_.analysis_result_.GetBufferDefInfo(cluster_id); + if (buf_def.tensor_id.name() == origin_tensor_name) { + reduce_data.promoted_tensor_name_for_reduce_ = cluster_id.name(); + break; + } + } + + for (auto it : scop_info_.analysis_result_.GetReduceTensorInfoMap()) { + if (it.first.name() == cur_reduce_stmt_) { + reduce_data.init_value = it.second.init_value; + break; + } + } + + MakeAkgReduceFuncName(reduce_data); + SetScalarTensorBind(reduce_data, reduce_data.scalar_tensor_name_); + if (reduce_data.reduce_op_.find("SumOp") != std::string::npos) { + SetScalarTensorBind(reduce_data, reduce_data.scalar_kht_name_); + SetScalarTensorBind(reduce_data, reduce_data.scalar_khy_name_); + SetScalarTensorBind(reduce_data, reduce_data.scalar_khc_name_); + } + SetSharedTensorBind(reduce_data); + + reduce_datas_[cur_reduce_stmt_] = reduce_data; + } else if (ScopInfo::IsReduceUpdate(name)) { + in_reduce_area_ = false; + } + IRVisitor::Visit_(op); + } + + void Visit_(const Provide *op) { + if (in_reduce_area_) { + reduce_datas_[cur_reduce_stmt_].origin_reduce_stmt_ = op; + } + IRVisitor::Visit_(op); + } + + void 
MakeAkgReduceFuncName(ReduceData &reduce_data) { + auto thread_cfg = scop_info_.user_config_.GetThreadConfig(); + CHECK(thread_cfg) << "thread config is null."; + auto block_cfg = scop_info_.user_config_.GetBlockConfig(); + CHECK(block_cfg) << "thread config is null."; + int tx = thread_cfg->GetX().second; + int ty = thread_cfg->GetY().second; + int by = block_cfg->GetY().second; + std::string direction = scop_info_.analysis_result_.GetReduceDirection(); + CHECK(!direction.empty()) << "direction should not be empty!"; + std::string direction_size = ""; + if (direction == X_DIRECTION) { + direction_size = std::to_string(tx); + } else { + direction_size = std::to_string(ty); + } + + std::string reduce_lib_namespace = ""; + std::string reduce_lib_name = ""; + if (scop_info_.user_config_.GetReduceLibType() == REDUCE_LIB_TYPE_ORIGIN) { + reduce_lib_namespace = AKG_REDUCE_LIB_SPACE; + reduce_lib_name = AKG_REDUCE_LIB_NAME; + } else if (scop_info_.user_config_.GetReduceLibType() == REDUCE_LIB_TYPE_PARIS) { + reduce_lib_namespace = PARIS_REDUCE_LIB_SPACE; + reduce_lib_name = PARIS_REDUCE_LIB_NAME; + } else { + CHECK(false) << "reduce lib type is invalid!" 
+ << "\n"; + } + std::string ret = reduce_lib_namespace; + ret += "::"; + ret += reduce_lib_name; + + reduce_data.akg_reduce_api_ = ret; + ret = ""; + + std::string op = reduce_data.reduce_op_; + ret += op; + ret += ", "; + + ret += std::to_string(tx); + ret += ", "; + ret += std::to_string(ty); + std::string reduce_type = ""; + if (by == 1 && ty == 1) { + reduce_type = AKG_ALL_REDUCE; + } else if (direction == X_DIRECTION) { + reduce_type = AKG_X_REDUCE; + } else { + reduce_type = AKG_Y_REDUCE; + } + ret += ", "; + ret += reduce_type; + + reduce_data.akg_reduce_template_arg_ = ret; + } + void SetScalarTensorBind(ReduceData &reduce_data, std::string scalar_tensor_name) { + Array shapes; + shapes.push_back(Expr(1)); + Type type = reduce_data.reduce_data_type_info_; + + Tensor tensor = placeholder(shapes, type, scalar_tensor_name); + const Buffer buffer = decl_buffer(shapes, type, scalar_tensor_name); + reduce_data.scalar_tensor_[scalar_tensor_name] = tensor; + CHECK(reduce_data.scalar_tensor_[scalar_tensor_name].defined()); + + scop_info_.user_config_.SetBind(tensor, buffer); + } + + void SetSharedTensorBind(ReduceData &reduce_data) { + auto thread_cfg = scop_info_.user_config_.GetThreadConfig(); + CHECK(thread_cfg) << "thread config is null."; + int tx = thread_cfg->GetX().second; + int ty = thread_cfg->GetY().second; + + int size = tx * ty; + Array shapes; + shapes.push_back(Expr(size)); + Type type = reduce_data.reduce_data_type_info_; + std::string shared_tensor_name = reduce_data.shared_compute_name_; + + Tensor tensor = placeholder(shapes, type, shared_tensor_name); + const Buffer buffer = decl_buffer(shapes, type, shared_tensor_name); + reduce_data.shared_tensor_ = tensor; + + scop_info_.user_config_.SetBind(tensor, buffer); + } + + bool is_valid_reduce() { return reduce_valid_; } + + friend class ReduceStmtEmit; + + private: + bool in_reduce_area_{false}; + ScopInfo &scop_info_; + std::map reduce_datas_; + std::string cur_reduce_stmt_{""}; + bool 
reduce_valid_{false}; +}; + +class AkgReduceStmtChange : public air::ir::IRMutator { + public: + explicit AkgReduceStmtChange(Tensor t, Array args, std::string name) : t(t), args(args), name(name) {} + ~AkgReduceStmtChange() override = default; + + Expr Mutate_(const Call *op, const Expr &e) final { + if (op->name == name) { + return Call::make(op->type, t->op->func_name(), args, op->call_type, t->op, op->value_index); + } + return IRMutator::Mutate_(op, e); + } + + Stmt Mutate_(const Provide *op, const Stmt &s) final { + auto stmt = IRMutator::Mutate_(op, s); + auto new_op = stmt.as(); + CHECK(new_op); + if (new_op->func->func_name() == name) { + return Provide::make(t->op, new_op->value_index, new_op->value, args); + } + return stmt; + } + + private: + Tensor t; + Array args; + std::string name; +}; + +class ReduceStmtEmit : public IRMutator { + public: + explicit ReduceStmtEmit(ReduceInfoCollect &info, ScopInfo &scop_info) + : reduce_datas_(info.reduce_datas_), scop_info_(scop_info) {} + Stmt Mutate_(const AttrStmt *op, const Stmt &s) { + auto key = op->attr_key; + if (key == REDUCE_AREA_FLAG) { + Stmt stmt = IRMutator::Mutate_(op, s); + CHECK(!cur_reduce_stmt_.empty()); + auto reduce_data = reduce_datas_[cur_reduce_stmt_]; + stmt = InsertRealizeWithMemType(stmt, isl::id(scop_info_.ctx_, reduce_data.scalar_tensor_name_), MEM_TYPE_LOCAL); + if (reduce_data.reduce_op_.find("SumOp") != std::string::npos) { + stmt = InsertRealizeWithMemType(stmt, isl::id(scop_info_.ctx_, reduce_data.scalar_kht_name_), MEM_TYPE_LOCAL); + stmt = InsertRealizeWithMemType(stmt, isl::id(scop_info_.ctx_, reduce_data.scalar_khy_name_), MEM_TYPE_LOCAL); + stmt = InsertRealizeWithMemType(stmt, isl::id(scop_info_.ctx_, reduce_data.scalar_khc_name_), MEM_TYPE_LOCAL); + } + stmt = + InsertRealizeWithMemType(stmt, isl::id(scop_info_.ctx_, reduce_data.shared_compute_name_), MEM_TYPE_SHARED); + return stmt; + } + return IRMutator::Mutate_(op, s); + } + + Stmt Mutate_(const Evaluate *op, const Stmt 
&s) { + Expr value = op->value; + if (value.as()) { + auto call = value.as(); + auto name = call->name; + if (ScopInfo::IsReduceInit(name)) { + in_reduce_area_ = true; + std::vector strs = common::Split(name, "_"); + CHECK_EQ(strs.size(), REDUCE_FLAG_SIZE) << "red init format is not right!."; + + cur_reduce_stmt_ = strs[REDUCE_FLAG_STMT_PREFIX_POS] + "_" + strs[REDUCE_FLAG_STMT_NUM_POS]; + CHECK(reduce_datas_.find(cur_reduce_stmt_) != reduce_datas_.end()); + auto reduce_data = reduce_datas_[cur_reduce_stmt_]; + + Array args; + args.push_back(Expr(0)); + Stmt scalar_stmt = Provide::make(reduce_data.scalar_tensor_[reduce_data.scalar_tensor_name_]->op, 0, + reduce_data.init_value, args); + if (reduce_data.reduce_op_.find("SumOp") != std::string::npos) { + Stmt scalar_khc = Provide::make(reduce_data.scalar_tensor_[reduce_data.scalar_khc_name_]->op, 0, + reduce_data.init_value, args); + CHECK(scalar_khc.defined()); + scalar_stmt = Block::make(scalar_khc, scalar_stmt); + } + + scalar_stmt = AttrStmt::make(Expr("INFO"), name, Expr(""), scalar_stmt); + return scalar_stmt; + } else if (ScopInfo::IsReduceUpdate(name)) { + in_reduce_area_ = false; + auto reduce_data = reduce_datas_[cur_reduce_stmt_]; + return MakeReduceStmt(reduce_data); + } + } + return IRMutator::Mutate_(op, s); + } + + Stmt Mutate_(const Provide *op, const Stmt &s) { + if (in_reduce_area_) { + Array args_scalar; + auto reduce_data = reduce_datas_[cur_reduce_stmt_]; + args_scalar.push_back(Expr(0)); + + Stmt stmt = AkgReduceStmtChange(reduce_data.scalar_tensor_[reduce_data.scalar_tensor_name_], args_scalar, + reduce_data.promoted_tensor_name_for_reduce_) + .Mutate(s); + + if (reduce_data.reduce_op_.find("SumOp") != std::string::npos) { + auto pro = stmt.as(); + CHECK(pro); + auto value = pro->value; + auto add = value.as(); + CHECK(add); + auto add_a = add->a; + auto add_b = add->b; + reduce_data.input_tensor_expr_ = + (add->a.as() && add->a.as()->name == reduce_data.scalar_tensor_name_) ? 
add_b : add_a; + stmt = TransferToKaHanInterface(reduce_data); + } + + return stmt; + } + return IRMutator::Mutate_(op, s); + } + + Stmt InsertRealizeWithMemType(Stmt stmt, const isl::id &var, std::string mem) { + stmt = FindInnerRealize(var.get_name()).Mutate(stmt); + + Tensor t = scop_info_.FindTensorWithLargestShape(var); + Region bounds; + + // no isolate + if (bounds.empty()) { + for (auto j : t->shape) { + bounds.push_back(Range::make_by_min_extent(Expr(0), j)); + } + } + + // If isolate, make a new buffer + auto buf = scop_info_.user_config_.GetBind().at(t); + + auto tt = placeholder(t->shape, t->dtype, t->op->name); + + stmt = TensorSubstitute(stmt, t->op, tt->op, tt->value_index); + t = tt; + if (scop_info_.analysis_result_.CountBufferDefInfo(var)) { + auto decl = scop_info_.analysis_result_.GetBufferDefInfo(var); + decl.tensor = t; + } + scop_info_.user_config_.SetBind(t, buf); + stmt = TensorSubstitute2(stmt, t->op->func_name(), t->op, t->value_index); + stmt = Realize::make(t->op, t->value_index, t->dtype, bounds, const_true(1), stmt); + stmt = AttrStmt::make(t->op, air::ir::attr::realize_scope, Expr(mem), stmt); + + return stmt; + } + + Stmt MakeReduceStmt(ReduceData &reduce_data) { + std::string func_name = reduce_data.akg_reduce_api_; + std::string op_info = reduce_data.reduce_op_ + "()"; + + Expr template_arg0 = make_const(reduce_data.reduce_data_type_info_, 1); + CHECK(!reduce_data.akg_reduce_template_arg_.empty()); + Expr template_arg1 = StringImm::make(reduce_data.akg_reduce_template_arg_); + + Array args_a1; + Expr a1 = Call::make(Int(32), reduce_data.reduce_op_, args_a1, Call::Extern); + + auto p = reduce_data.origin_reduce_stmt_; + CHECK(p); + Expr a2 = Call::make(p->value.type(), p->func->func_name(), p->args, Call::Halide, p->func, 0); + a2 = Call::make(a2.type(), "&", {a2}, Call::Extern); + + Tensor tensor = scop_info_.FindTensor(reduce_data.shared_compute_name_); + auto bind = scop_info_.user_config_.GetBind(); + Buffer buffer; + for (auto 
&i : bind) { + if (!i.first.defined()) continue; + if (i.first == tensor) { + buffer = i.second; + } + } + + CHECK(buffer.defined()); + + Tensor tt = reduce_data.scalar_tensor_[reduce_data.scalar_tensor_name_]; + Array args; + args.push_back(Expr(0)); + Expr a4 = Call::make(tt->dtype, tt->op->func_name(), args, Call::Halide, tt->op, 0); + + auto thread_cfg = scop_info_.user_config_.GetThreadConfig(); + CHECK(thread_cfg); + int tx = thread_cfg->GetX().second; + int ty = thread_cfg->GetY().second; + Expr a5 = Expr(tx); + + Stmt stmt = Evaluate::make( + Call::make(Int(32), func_name, {template_arg0, template_arg1, a1, a2, buffer->data, a4, a5}, Call::Extern)); + + stmt = AttrStmt::make(Expr("INFO"), REDUCE_LIB_TYPE_FLAG, scop_info_.user_config_.GetReduceLibType(), stmt); + + int size = tx * ty; + stmt = AttrStmt::make(buffer->data, air::ir::attr::storage_scope, Expr(MEM_TYPE_SHARED), + Allocate::make(buffer->data, buffer->dtype, {Expr(size)}, const_true(), stmt)); + return stmt; + } + + Stmt TransferToKaHanInterface(ReduceData &reduce_data) { + std::string func_name = AKG_REDUCE_LIB_SPACE; + func_name += "::"; + func_name += AKG_KAHAN_LIB_NAME; + Expr template_arg0 = make_const(reduce_data.reduce_data_type_info_, 1); + + Array args; + args.push_back(Expr(0)); + + Tensor tt = reduce_data.scalar_tensor_[reduce_data.scalar_khy_name_]; + Expr a1 = Call::make(tt->dtype, tt->op->func_name(), args, Call::Halide, tt->op, 0); + a1 = Call::make(a1.type(), "&", {a1}, Call::Extern); + + tt = reduce_data.scalar_tensor_[reduce_data.scalar_kht_name_]; + Expr a2 = Call::make(tt->dtype, tt->op->func_name(), args, Call::Halide, tt->op, 0); + a2 = Call::make(a2.type(), "&", {a2}, Call::Extern); + + tt = reduce_data.scalar_tensor_[reduce_data.scalar_khc_name_]; + Expr a3 = Call::make(tt->dtype, tt->op->func_name(), args, Call::Halide, tt->op, 0); + a3 = Call::make(a3.type(), "&", {a3}, Call::Extern); + + tt = reduce_data.scalar_tensor_[reduce_data.scalar_tensor_name_]; + Expr a4 = 
Call::make(tt->dtype, tt->op->func_name(), args, Call::Halide, tt->op, 0); + a4 = Call::make(a4.type(), "&", {a4}, Call::Extern); + + CHECK(reduce_data.input_tensor_expr_.defined()); + Stmt stmt = Evaluate::make( + Call::make(Int(32), func_name, {template_arg0, a1, a2, a3, a4, reduce_data.input_tensor_expr_}, Call::Extern)); + + return stmt; + } + + private: + std::map reduce_datas_; + bool in_reduce_area_{false}; + bool collect_area_stmt_{false}; + std::string cur_reduce_stmt_{""}; + ScopInfo &scop_info_; + std::vector block_stmts_; + int block_depth_{1}; + bool reduce_start_{false}; + bool reduce_end_{false}; + Stmt rest_part_; +}; + +struct AtomicReturnData { + std::string reduce_op_; + std::string akg_atomic_api_; + std::string akg_atomic_template_arg_; + Type output_tensor_data_type_info_; + Expr atomic_rhs_; + Stmt gm_write_stmt_; +}; + +class AtomicReturnStmtEmit : public IRMutator { + public: + explicit AtomicReturnStmtEmit(ScopInfo &scop_info) : scop_info_(scop_info) {} + + Stmt Mutate_(const AttrStmt *op, const Stmt &s) { + auto key = op->attr_key; + if (IsStartsWith(key, REDUCE_ATOMIC_FLAG)) { + in_atomic_area_ = true; + std::vector strs = common::Split(key, "_"); + CHECK_EQ(strs.size(), REDUCE_ATOMIC_FLAG_SIZE) << "atomic mark format is not right!."; + atomic_data_.reduce_op_.clear(); + if (AkgSupportedReduceOp.count(strs[REDUCE_ATOMIC_FLAG_TYPE_POS])) { + atomic_data_.reduce_op_ = AKG_REDUCE_LIB_SPACE; + atomic_data_.reduce_op_ += "::"; + atomic_data_.reduce_op_ += strs[REDUCE_ATOMIC_FLAG_TYPE_POS]; + } else { + CHECK(false) << "reduce op type is not supported!"; + } + } + return IRMutator::Mutate_(op, s); + } + + Stmt Mutate_(const Provide *op, const Stmt &s) { + if (in_atomic_area_) { + in_atomic_area_ = false; + Stmt stmt = IRMutator::Mutate_(op, s); + atomic_data_.gm_write_stmt_ = stmt; + auto op = stmt.as(); + CHECK(op); + atomic_data_.atomic_rhs_ = op->value; + atomic_data_.output_tensor_data_type_info_ = 
scop_info_.user_config_.GetDataType(op->func->func_name()); + + ConstructAtomicReturnFuncName(); + return MakeAtomicStmt(); + } + return IRMutator::Mutate_(op, s); + } + + void ConstructAtomicReturnFuncName() { + std::string reduce_lib_namespace = ""; + std::string reduce_return_name = ""; + if (scop_info_.user_config_.GetReduceLibType() == REDUCE_LIB_TYPE_ORIGIN) { + reduce_lib_namespace = AKG_REDUCE_LIB_SPACE; + reduce_return_name = AKG_REDUCE_RETURN_NAME; + } else if (scop_info_.user_config_.GetReduceLibType() == REDUCE_LIB_TYPE_PARIS) { + reduce_lib_namespace = PARIS_REDUCE_LIB_SPACE; + reduce_return_name = PARIS_REDUCE_RETURN_NAME; + } else { + CHECK(false) << "reduce lib type is invalid!" + << "\n"; + } + std::string ret = ""; + ret += reduce_lib_namespace; + ret += "::"; + ret += reduce_return_name; + + atomic_data_.akg_atomic_api_ = ret; + ret = ""; + + std::string op = atomic_data_.reduce_op_; + ret += op; + + atomic_data_.akg_atomic_template_arg_ = ret; + } + + Stmt MakeAtomicStmt() { + std::string func_name = atomic_data_.akg_atomic_api_; + + Expr template_arg0 = make_const(atomic_data_.output_tensor_data_type_info_, 1); + CHECK(!atomic_data_.akg_atomic_template_arg_.empty()); + Expr template_arg1 = StringImm::make(atomic_data_.akg_atomic_template_arg_); + + Expr a1 = atomic_data_.atomic_rhs_; + + auto p = atomic_data_.gm_write_stmt_.as(); + CHECK(p); + + Expr a2 = Call::make(p->value.type(), p->func->func_name(), p->args, Call::Halide, p->func, 0); + a2 = Call::make(a2.type(), "&", {a2}, Call::Extern); + + std::string op_info = atomic_data_.reduce_op_ + "()"; + + Array args; + Expr a3 = Call::make(Int(32), atomic_data_.reduce_op_, args, Call::Extern); + + return Evaluate::make(Call::make(Int(32), func_name, {template_arg0, template_arg1, a1, a2, a3}, Call::Extern)); + } + + private: + ScopInfo &scop_info_; + AtomicReturnData atomic_data_; + bool in_atomic_area_{false}; +}; + +class ConditionExprMod : public air::ir::IRMutator { + public: + explicit 
ConditionExprMod(bool &is_found) : is_found_(is_found) {} + ~ConditionExprMod() override = default; + + Expr Mutate_(const And *op, const Expr &e) override { + auto o_a = op->a; + auto o_b = op->b; + auto a = air::ir::IRMutator::Mutate(op->a); + auto b = air::ir::IRMutator::Mutate(op->b); + if (!a.defined() && !b.defined()) return Expr(); + if (!a.defined()) return b; + if (!b.defined()) return a; + if (o_a.same_as(a) && o_b.same_as(b)) return e; + return And::make(a, b); + } + + Expr Mutate_(const Or *op, const Expr &e) override { + auto o_a = op->a; + auto o_b = op->b; + auto a = air::ir::IRMutator::Mutate(op->a); + auto b = air::ir::IRMutator::Mutate(op->b); + if (!a.defined() && !b.defined()) return Expr(); + if (!a.defined()) return b; + if (!b.defined()) return a; + if (o_a.same_as(a) && o_b.same_as(b)) return e; + return Or::make(a, b); + } + + Expr Mutate_(const EQ *op, const Expr &e) override { + Expr a = op->a; + Expr b = op->b; + + bool rh_zero = false; + bool lh_block = false; + if (b.as()) { + auto v = b.as(); + if (v->value == 0) rh_zero = true; + } + + if (a.as()) { + auto v = a.as(); + if (v->name_hint == BLOCK_IDX_X) { + lh_block = true; + } + } + + if (rh_zero && lh_block) { + is_found_ = true; + return Expr(); + } + return e; + } + + private: + bool &is_found_; +}; + +class InitStmtIndexModify : public IRMutator { + public: + explicit InitStmtIndexModify(ScopInfo &scop_info) : scop_info_(scop_info) {} + + Stmt Mutate_(const AttrStmt *op, const Stmt &s) { + auto key = op->attr_key; + if (key == REDUCE_INIT_FLAG) { + init_stmt_emit_ = true; + } + return IRMutator::Mutate_(op, s); + } + + Stmt Mutate_(const IfThenElse *op, const Stmt &s) { + Stmt stmt = IRMutator::Mutate_(op, s); + if (init_stmt_emit_) { + if (scop_info_.user_config_.GetEnableAtomicAdd() && !scop_info_.analysis_result_.GetAtomicMarkers().empty()) { + bool is_found = false; + auto op = s.as(); + CHECK(op); + auto condition = op->condition; + condition = 
ConditionExprMod(is_found).Mutate(condition); + if (is_found) { + init_stmt_emit_ = false; + } + return IfThenElse::make(condition, op->then_case, op->else_case); + } + } + return stmt; + } + + private: + ScopInfo &scop_info_; + bool init_stmt_emit_{false}; +}; + +class DeleteComplicatedSync : public IRMutator { + public: + DeleteComplicatedSync() {} + + Stmt Mutate_(const Block *op, const Stmt &s) { + Stmt first = this->Mutate(op->first); + if (first.as()) { + Expr value = first.as()->value; + if (value.as()) { + auto call = value.as(); + auto name = call->name; + if (name == STORAGE_SYNC) { + emit_sync_ = true; + } else { + emit_sync_ = false; + } + } + } else { + emit_sync_ = false; + } + + Stmt rest = this->Mutate(op->rest); + + if (!first.defined() && !rest.defined()) { + return Stmt(); + } + + if (!first.defined() && rest.defined()) { + return rest; + } + + if (first.defined() && !rest.defined()) { + return first; + } + + if (first.same_as(op->first) && rest.same_as(op->rest)) { + return s; + } else { + return Block::make(first, rest); + } + } + + Stmt Mutate_(const Evaluate *op, const Stmt &s) { + Expr value = op->value; + if (value.as()) { + auto call = value.as(); + auto name = call->name; + if (name == STORAGE_SYNC) { + if (emit_sync_) { + return Stmt(); + } + } + } + return IRMutator::Mutate_(op, s); + } + + private: + bool emit_sync_{false}; +}; + +Stmt EmitForReduce(Stmt stmt, ScopInfo &scop_info) { + ReduceInfoCollect col(scop_info); + col.Visit(stmt); + + if (!col.is_valid_reduce()) { + return stmt; + } + + stmt = ReduceStmtEmit(col, scop_info).Mutate(stmt); + stmt = AtomicReturnStmtEmit(scop_info).Mutate(stmt); + + if (scop_info.user_config_.GetEnableAtomicAdd()) { + stmt = InitStmtIndexModify(scop_info).Mutate(stmt); + } + + stmt = DeleteComplicatedSync().Mutate(stmt); + + return stmt; +} +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/gpu_emit/gpu_tensor_core_emit_pass.cc 
b/src/poly/gpu_emit/gpu_tensor_core_emit_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..a6b93443ac704855d94dbd9a8f2749da3ad1071b --- /dev/null +++ b/src/poly/gpu_emit/gpu_tensor_core_emit_pass.cc @@ -0,0 +1,1535 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! + * \file gpu_emit_tensor_core.cc + */ + +#include "emit_pass.h" +#include + +namespace akg { +namespace ir { +namespace poly { + +class CheckTensorCoreValid : public IRVisitor { + public: + explicit CheckTensorCoreValid() {} + using IRVisitor::Visit_; + + void Visit_(const AttrStmt *op) { + auto key = op->attr_key; + if (key == WARP_MARKER) { + warp_marker_ = true; + } + return IRVisitor::Visit_(op); + } + + bool IsValid() { return warp_marker_; } + + private: + bool warp_marker_{false}; +}; + +Array GetTileSize(TensorCoreInfo &tensor_core_info, const std::string &name) { + auto it = tensor_core_info.matrix_abc_.find(name); + auto it2 = tensor_core_info.matrix_major_.find(name); + CHECK(it != tensor_core_info.matrix_abc_.end() && it2 != tensor_core_info.matrix_major_.end()) + << "Cannot find matrix info for " << name; + Expr size0 = make_const(Int(32), 16); + Expr size1 = make_const(Int(32), 16); + if (it->second == MMA_A && it2->second == COL_MAJOR) { + size0 = make_const(Int(32), tensor_core_info.warp_tile_.k); + size1 = make_const(Int(32), tensor_core_info.warp_tile_.m); + } + if (it->second == MMA_A 
&& it2->second == ROW_MAJOR) { + size0 = make_const(Int(32), tensor_core_info.warp_tile_.m); + size1 = make_const(Int(32), tensor_core_info.warp_tile_.k); + } + if (it->second == MMA_B && it2->second == ROW_MAJOR) { + size0 = make_const(Int(32), tensor_core_info.warp_tile_.k); + size1 = make_const(Int(32), tensor_core_info.warp_tile_.n); + } + if (it->second == MMA_B && it2->second == COL_MAJOR) { + size0 = make_const(Int(32), tensor_core_info.warp_tile_.n); + size1 = make_const(Int(32), tensor_core_info.warp_tile_.k); + } + + if (it->second == MATRIX_C || it->second == MATRIX_ELSE) { + size0 = make_const(Int(32), tensor_core_info.warp_tile_.m); + size1 = make_const(Int(32), tensor_core_info.warp_tile_.n); + } + Array tile_size = {size0, size1}; + return tile_size; +} + +class DeleteUselessFor : public air::ir::IRMutator { + public: + explicit DeleteUselessFor() {} + ~DeleteUselessFor() override = default; + + Stmt Mutate_(const For *op, const Stmt &s) { + for_iters_.push_back(op->loop_var.get()); + Stmt stmt = IRMutator::Mutate_(op, s); + for_iters_.pop_back(); + return stmt.as()->body; + } + + Stmt Mutate_(const AttrStmt *op, const Stmt &s) override { + if (op->attr_key == air::ir::attr::buffer_bind_scope) { + Array arr = Downcast>(op->node); + CHECK_EQ(arr.size(), 2U); + const BufferNode *buffer = arr[0].as(); + const TensorNode *tensor = arr[1].as(); + CHECK(buffer && tensor); + auto e = buffer->elem_offset; + Expr ret = this->Mutate(e); + NodePtr buffer_node = make_node(); + buffer_node->data = buffer->data; + buffer_node->name = buffer->name; + buffer_node->scope = buffer->scope; + buffer_node->dtype = buffer->dtype; + buffer_node->strides = buffer->strides; + buffer_node->shape = buffer->shape; + buffer_node->data_alignment = buffer->data_alignment; + buffer_node->elem_offset = ret; + buffer_node->offset_factor = buffer->offset_factor; + + Buffer buffer_new(buffer_node); + Array node = {buffer_new, arr[1]}; + + auto value = this->Mutate(op->value); + auto 
body = this->Mutate(op->body); + + return AttrStmt::make(node, op->attr_key, value, body); + } + return IRMutator::Mutate_(op, s); + } + + Expr Mutate_(const EQ *op, const Expr &e) override { + Expr a = op->a; + Expr b = op->b; + auto for_var = a.as(); + if (for_var != nullptr) { + for (auto &i : for_iters_) { + if (i == for_var) { + return EQ::make(b, b); + } + } + } + return e; + } + + Expr Mutate_(const Variable *op, const Expr &e) { + bool be_zero = false; + for (auto &i : for_iters_) { + if (i == op) { + be_zero = true; + break; + } + } + + if (be_zero) { + return make_const(Int(32), 0); + } + + return e; + } + + Expr Mutate_(const Call *op, const Expr &e) final { + if (op->is_intrinsic(air::ir::intrinsic::tvm_fill_fragment)) { + CHECK_EQ(op->args.size(), 6U); + return DeleteUselessForIndex(op, e); + } else if (op->is_intrinsic(air::ir::intrinsic::tvm_load_matrix_sync)) { + CHECK_EQ(op->args.size(), 8U); + return DeleteUselessForIndex(op, e); + + } else if (op->is_intrinsic(air::ir::intrinsic::tvm_store_matrix_sync)) { + CHECK_EQ(op->args.size(), 8U); + return DeleteUselessForIndex(op, e); + + } else if (op->is_intrinsic(air::ir::intrinsic::tvm_mma_sync)) { + CHECK_EQ(op->args.size(), 8U); + return DeleteUselessForIndex(op, e); + } else { + return IRMutator::Mutate_(op, e); + } + } + + Expr DeleteUselessForIndex(const Call *op, const Expr &e) { + Array args = op->args; + for (unsigned int i = 0; i < args.size(); ++i) { + args.Set(i, Simplify(this->Mutate(args[i]))); + } + if (args.same_as(op->args)) { + return e; + } + return Call::make(op->type, op->name, args, op->call_type, op->func, op->value_index); + } + + private: + std::vector for_iters_; +}; + +struct DataForLoad { + Expr src; + Expr stride; + Expr major; + const Call *call; + const Provide *op; + NodePtr node; +}; + +struct DataForStore { + Expr dst; + Expr stride; + const Call *call; + NodePtr node; +}; + +struct DataForFill { + const Call *call; + const Provide *op; + NodePtr node; +}; + +struct 
DataForSync { + Expr a; + Expr b; + Expr c; + NodePtr node_a; + NodePtr node_b; + NodePtr node_c; +}; + +struct DataForElem { + Expr a; + Expr b; + Expr c; + NodePtr node_a; + NodePtr node_b; + NodePtr node_c; +}; + + +class EmitTensorCoreHelper { + public: + struct CompareExpr { + bool operator()(const Expr &lhs, const Expr &rhs) const { return Compare(lhs, rhs) < 0; } + }; + EmitTensorCoreHelper(TensorCoreInfo &info, ScopInfo &scop_info) : tensor_core_info_(info), scop_info_(scop_info) {} + ~EmitTensorCoreHelper(){}; + + void SetDataForLoad(Expr src, Expr stride, Expr major, const Call *call, const Provide *op, + NodePtr &node); + void SetDataForStore(Expr dst, Expr stride, const Call *call, NodePtr &node); + void SetDataForFill(const Provide *op, const Call *call, NodePtr &node); + void SetDataForSync(Expr a, Expr b, Expr c, NodePtr &node_a, NodePtr &node_b, + NodePtr &node_c); + void SetDataForElem(Expr a, Expr b, Expr c, NodePtr &node_a, NodePtr &node_b, + NodePtr &node_c); + + void PrepareDataCore(); + + Stmt MakeLoadTransform(); + Stmt MakeStoreTransform(); + Stmt MakeFillTransform(); + Stmt MakeSyncTransform(); + Stmt MakeFragmentElemTransform(Expr op_name); + + private: + Array node_; + Expr tuple_; + TensorCoreInfo &tensor_core_info_; + + DataForLoad data_for_load_; + DataForStore data_for_store_; + DataForFill data_for_fill_; + DataForSync data_for_sync_; + DataForElem data_for_elemwise_; + + air::ir::TensorKey key_; + const Call *call_; + NodePtr buffer_node_; + Type data_type_; + ScopInfo &scop_info_; + std::map fragment_offset_; +}; + +void EmitTensorCoreHelper::SetDataForLoad(Expr src, Expr stride, Expr major, const Call *call, const Provide *op, + NodePtr &node) { + data_for_load_.src = src; + data_for_load_.stride = stride; + data_for_load_.major = major; + data_for_load_.call = call; + data_for_load_.op = op; + data_for_load_.node = node; +} +void EmitTensorCoreHelper::SetDataForStore(Expr dst, Expr stride, const Call *call, NodePtr &node) { + 
data_for_store_.dst = dst; + data_for_store_.stride = stride; + data_for_store_.call = call; + data_for_store_.node = node; +} +void EmitTensorCoreHelper::SetDataForFill(const Provide *op, const Call *call, NodePtr &node) { + data_for_fill_.call = call; + data_for_fill_.op = op; + data_for_fill_.node = node; +} +void EmitTensorCoreHelper::SetDataForSync(Expr a, Expr b, Expr c, NodePtr &node_a, + NodePtr &node_b, NodePtr &node_c) { + data_for_sync_.a = a; + data_for_sync_.b = b; + data_for_sync_.c = c; + data_for_sync_.node_a = node_a; + data_for_sync_.node_b = node_b; + data_for_sync_.node_c = node_c; +} +void EmitTensorCoreHelper::SetDataForElem(Expr a, Expr b, Expr c, NodePtr &node_a, + NodePtr &node_b, NodePtr &node_c) { + data_for_elemwise_.a = a; + data_for_elemwise_.b = b; + data_for_elemwise_.c = c; + data_for_elemwise_.node_a = node_a; + data_for_elemwise_.node_b = node_b; + data_for_elemwise_.node_c = node_c; +} + +void EmitTensorCoreHelper::PrepareDataCore() { + auto it = tensor_core_info_.bounds_.find(key_); + CHECK(it != tensor_core_info_.bounds_.end()); + Array min_bound; + for (auto i : it->second) { + min_bound.push_back(i->min); + } + + CHECK_GE(it->second.size(), 2); + Array shape; + for (size_t i = 0; i < it->second.size(); ++i) { + shape.push_back(it->second[i]->extent); + } + + auto tile_size = GetTileSize(tensor_core_info_, akg::common::GetGlobalName(call_->name)); + tensor_core_info_.min_bounds_[call_->name] = min_bound; + + Array strides; + for (size_t i = 1; i < shape.size(); ++i) { + Expr stride = IntImm::make(Int(32), 1); + for (size_t j = shape.size() - 1; j >= i; --j) { + stride = Mul::make(stride, shape[j]); + } + strides.push_back(stride); + } + strides.push_back(make_const(Int(32), 1)); + + // compute the local offset for fragment + // example: (cc1, cc2) + Expr fragment_elem_offset = IntImm::make(Int(32), 0); + CHECK_EQ(call_->args.size(), min_bound.size()); + for (size_t i = 0; i < min_bound.size(); i++) { + auto arg = 
call_->args[i]; + arg = Simplify(arg); + auto stride_val = strides[i]; + // tile_size[1] is the innermost axis of the tensor. + // And this axis is used for wmma interface. + // The fragment offset computing should make a division of the parameter. + if (i != min_bound.size() - 1) { + stride_val = Div::make(stride_val, tile_size[1]); + } + fragment_elem_offset = Add::make(fragment_elem_offset, Mul::make(stride_val, Sub::make(arg, min_bound[i]))); + } + + Expr elem_offset = IntImm::make(Int(32), 0); + CHECK_EQ(call_->args.size(), min_bound.size()); + for (size_t i = 0; i < min_bound.size(); i++) { + auto arg = call_->args[i]; + arg = Simplify(arg); + elem_offset = Add::make(elem_offset, Mul::make(strides[i], Sub::make(arg, min_bound[i]))); + } + + elem_offset = Simplify(elem_offset); + + // insert the fragment offset information + fragment_offset_[elem_offset] = fragment_elem_offset; + + auto it2 = tensor_core_info_.matrix_abc_.find(akg::common::GetGlobalName(call_->name)); + CHECK(it2 != tensor_core_info_.matrix_abc_.end()) << "Cannot find matrix info for " << call_->name; + buffer_node_->data = Variable::make(Handle(), call_->name); + buffer_node_->name = call_->name; + std::string name = it2->second; + if (name == MATRIX_C || name == MATRIX_ELSE) { + name = MMA_C; + } + buffer_node_->scope = "wmma." 
+ name; + buffer_node_->dtype = data_type_; + buffer_node_->strides = strides; + buffer_node_->shape = shape; + buffer_node_->data_alignment = 1; + buffer_node_->elem_offset = Simplify(elem_offset); + buffer_node_->offset_factor = 1; + Buffer buffer(buffer_node_); + + NodePtr tensor_node = make_node(); + tensor_node->value_index = key_.value_index; + tensor_node->op = Downcast(key_.f); + tensor_node->shape = shape; + tensor_node->dtype = data_type_; + Tensor tensor(tensor_node); + + Array args; + for (size_t i = 0; i < call_->args.size(); ++i) { + auto arg = call_->args[i]; + arg = Simplify(arg); + args.push_back(arg); + args.push_back(shape[i]); + } + tuple_ = Call::make(Handle(), air::ir::intrinsic::tvm_tuple, args, Call::Intrinsic); + node_ = {buffer, tensor}; +} + +Stmt EmitTensorCoreHelper::MakeLoadTransform() { + key_ = air::ir::TensorKey{data_for_load_.op->func, data_for_load_.op->value_index}; + call_ = data_for_load_.call; + buffer_node_ = data_for_load_.node; + data_type_ = call_->type; + + PrepareDataCore(); + Buffer buffer = Downcast(node_[0]); + Stmt stmt = Evaluate::make(Call::make( + Handle(), air::ir::intrinsic::tvm_load_matrix_sync, + {buffer->data, tensor_core_info_.warp_tile_.m, tensor_core_info_.warp_tile_.n, tensor_core_info_.warp_tile_.k, + Simplify(fragment_offset_[buffer->elem_offset]), data_for_load_.src, data_for_load_.stride, data_for_load_.major}, + Call::Intrinsic)); + fragment_offset_.clear(); + return AttrStmt::make(node_, "buffer_bind_scope", tuple_, stmt); +} + +Stmt EmitTensorCoreHelper::MakeStoreTransform() { + key_ = air::ir::TensorKey{data_for_store_.call->func, data_for_store_.call->value_index}; + call_ = data_for_store_.call; + buffer_node_ = data_for_store_.node; + data_type_ = call_->type; + + PrepareDataCore(); + Buffer buffer = Downcast(node_[0]); + Stmt stmt = Evaluate::make(Call::make( + Handle(), air::ir::intrinsic::tvm_store_matrix_sync, + {buffer->data, tensor_core_info_.warp_tile_.m, tensor_core_info_.warp_tile_.n, 
tensor_core_info_.warp_tile_.k, + fragment_offset_[buffer->elem_offset], data_for_store_.dst, data_for_store_.stride, StringImm::make(ROW_MAJOR)}, + Call::Intrinsic)); + fragment_offset_.clear(); + return AttrStmt::make(node_, "buffer_bind_scope", tuple_, stmt); +} + +Stmt EmitTensorCoreHelper::MakeFillTransform() { + key_ = air::ir::TensorKey{data_for_fill_.call->func, data_for_fill_.call->value_index}; + call_ = data_for_fill_.call; + buffer_node_ = data_for_fill_.node; + data_type_ = call_->type; + + PrepareDataCore(); + Buffer buffer = Downcast(node_[0]); + Stmt stmt = Evaluate::make( + Call::make(Handle(), air::ir::intrinsic::tvm_fill_fragment, + {buffer->data, tensor_core_info_.warp_tile_.m, tensor_core_info_.warp_tile_.n, + tensor_core_info_.warp_tile_.k, fragment_offset_[buffer->elem_offset], data_for_fill_.op->value}, + Call::Intrinsic)); + fragment_offset_.clear(); + return AttrStmt::make(node_, "buffer_bind_scope", tuple_, stmt); +} + +Stmt EmitTensorCoreHelper::MakeSyncTransform() { + bool is_cast = false; + if (data_for_sync_.a.as()) { + auto call_a = data_for_sync_.a.as(); + key_ = air::ir::TensorKey{call_a->func, call_a->value_index}; + call_ = call_a; + buffer_node_ = data_for_sync_.node_a; + data_type_ = call_->type; + is_cast = true; + } else if (data_for_sync_.a.as()) { + auto cast_a = data_for_sync_.a.as(); + auto call_a = cast_a->value.as(); + CHECK(call_a); + key_ = air::ir::TensorKey{call_a->func, call_a->value_index}; + call_ = call_a; + buffer_node_ = data_for_sync_.node_a; + data_type_ = call_->type; + is_cast = true; + } + + PrepareDataCore(); + + auto tuple_a = tuple_; + auto node_a = node_; + + if (data_for_sync_.b.as()) { + auto call_b = data_for_sync_.b.as(); + key_ = air::ir::TensorKey{call_b->func, call_b->value_index}; + call_ = call_b; + buffer_node_ = data_for_sync_.node_b; + data_type_ = call_->type; + is_cast = false; + } else if (data_for_sync_.b.as()) { + auto cast_b = data_for_sync_.b.as(); + auto call_b = 
cast_b->value.as(); + CHECK(call_b); + key_ = air::ir::TensorKey{call_b->func, call_b->value_index}; + call_ = call_b; + buffer_node_ = data_for_sync_.node_b; + data_type_ = call_->type; + is_cast = true; + } + + PrepareDataCore(); + + auto tuple_b = tuple_; + auto node_b = node_; + + auto call_c = data_for_sync_.c.as(); + CHECK(call_c); + key_ = air::ir::TensorKey{call_c->func, call_c->value_index}; + call_ = call_c; + buffer_node_ = data_for_sync_.node_c; + data_type_ = call_->type; + + PrepareDataCore(); + + auto tuple_c = tuple_; + auto node_c = node_; + + Buffer buffer_a(data_for_sync_.node_a); + Buffer buffer_b(data_for_sync_.node_b); + Buffer buffer = Downcast(node_c[0]); + + Stmt stmt = Evaluate::make(Call::make( + Handle(), air::ir::intrinsic::tvm_mma_sync, + {buffer->data, fragment_offset_[buffer->elem_offset], buffer_a->data, fragment_offset_[buffer_a->elem_offset], + buffer_b->data, fragment_offset_[buffer_b->elem_offset], buffer->data, fragment_offset_[buffer->elem_offset]}, + Call::Intrinsic)); + fragment_offset_.clear(); + stmt = AttrStmt::make(node_c, "buffer_bind_scope", tuple_c, stmt); + stmt = AttrStmt::make(node_b, "buffer_bind_scope", tuple_b, stmt); + stmt = AttrStmt::make(node_a, "buffer_bind_scope", tuple_a, stmt); + + std::string cast_mode = CAST_MODE_1; + if (is_cast) { + stmt = AttrStmt::make(Expr("INFO"), CAST_FLAG, StringImm::make(cast_mode), stmt); + } + + return stmt; +} + +Stmt EmitTensorCoreHelper::MakeFragmentElemTransform(Expr op_name) { + auto call_a = data_for_elemwise_.a.as(); + key_ = air::ir::TensorKey{call_a->func, call_a->value_index}; + call_ = call_a; + buffer_node_ = data_for_elemwise_.node_a; + data_type_ = call_->type; + + PrepareDataCore(); + + auto tuple_a = tuple_; + auto node_a = node_; + + auto call_b = data_for_elemwise_.b.as(); + if (call_b) { + key_ = air::ir::TensorKey{call_b->func, call_b->value_index}; + call_ = call_b; + buffer_node_ = data_for_elemwise_.node_b; + data_type_ = call_->type; + + 
PrepareDataCore(); + + auto tuple_b = tuple_; + auto node_b = node_; + + auto call_c = data_for_elemwise_.c.as(); + CHECK(call_c); + key_ = air::ir::TensorKey{call_c->func, call_c->value_index}; + call_ = call_c; + buffer_node_ = data_for_elemwise_.node_c; + data_type_ = call_->type; + + PrepareDataCore(); + + auto tuple_c = tuple_; + auto node_c = node_; + + Buffer buffer_a(data_for_elemwise_.node_a); + Buffer buffer_b(data_for_elemwise_.node_b); + Buffer buffer = Downcast(node_c[0]); + + Stmt stmt = Evaluate::make( + Call::make(Handle(), air::ir::intrinsic::akg_fragment_elem, + {buffer->data, fragment_offset_[buffer->elem_offset], buffer_a->data, + fragment_offset_[buffer_a->elem_offset], buffer_b->data, fragment_offset_[buffer_b->elem_offset], + op_name}, Call::Intrinsic)); + fragment_offset_.clear(); + stmt = AttrStmt::make(node_c, "buffer_bind_scope", tuple_c, stmt); + stmt = AttrStmt::make(node_b, "buffer_bind_scope", tuple_b, stmt); + stmt = AttrStmt::make(node_a, "buffer_bind_scope", tuple_a, stmt); + return stmt; + } else { + auto call_c = data_for_elemwise_.c.as(); + CHECK(call_c); + key_ = air::ir::TensorKey{call_c->func, call_c->value_index}; + call_ = call_c; + buffer_node_ = data_for_elemwise_.node_c; + data_type_ = call_->type; + + PrepareDataCore(); + + auto tuple_c = tuple_; + auto node_c = node_; + + Buffer buffer_a(data_for_elemwise_.node_a); + Buffer buffer = Downcast(node_c[0]); + + Stmt stmt = Evaluate::make( + Call::make(Handle(), air::ir::intrinsic::akg_fragment_elem, + {buffer->data, fragment_offset_[buffer->elem_offset], buffer_a->data, + fragment_offset_[buffer_a->elem_offset], Expr(data_for_elemwise_.b), + op_name}, Call::Intrinsic)); + fragment_offset_.clear(); + stmt = AttrStmt::make(node_c, "buffer_bind_scope", tuple_c, stmt); + stmt = AttrStmt::make(node_a, "buffer_bind_scope", tuple_a, stmt); + return stmt; + } +} + +class AddMmaAttrFlag : public air::ir::IRMutator { + public: + explicit AddMmaAttrFlag(TensorCoreInfo t) : tt(t) {} + 
~AddMmaAttrFlag() override = default; + + Stmt Mutate_(const AttrStmt *op, const Stmt &s) override { + Stmt stmt = IRMutator::Mutate_(op, s); + if (op->attr_key == air::ir::attr::realize_scope) { + auto node = op->node.as(); + if (node != nullptr) { + if (!tt.frag_reg_.count(node->name)) { + return stmt; + } + + auto it = tt.matrix_abc_.find(akg::common::GetGlobalName(node->name)); + CHECK(it != tt.matrix_abc_.end()) << "Cannot find matrix info for " << node->name; + std::string name = it->second; + if (name == MATRIX_C || name == MATRIX_ELSE) { + name = MMA_C; + } + + auto matrix_abc = "wmma." + name; + Stmt body = Mutate(op->body); + return AttrStmt::make(op->node, op->attr_key, matrix_abc, body); + } + } + return stmt; + } + + private: + TensorCoreInfo tt; +}; + +class LocalTensorAnalyser : public IRVisitor { + public: + explicit LocalTensorAnalyser(TensorCoreInfo &info, ScopInfo &scop_info) + : matrix_abc_(info.matrix_abc_), matrix_major_(info.matrix_major_), frag_reg_(info.frag_reg_) { + for (auto kv : scop_info.user_config_.GetOriginBind()) { + BufferInfo bi; + bi.name = kv.second->name; + bi.dtype = kv.second->dtype; + bi.external = true; + buf_map_[air::ir::TensorKey{kv.first->op, kv.first->value_index}] = bi; + } + } + using IRVisitor::Visit_; + + void Visit_(const Provide *op) final { + IRVisitor::Visit_(op); + air::ir::TensorKey key{op->func, op->value_index}; + auto it = buf_map_.find(key); + CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key.f; + const BufferInfo &bi = it->second; + CHECK(!bi.released) << "Read a buffer that is already out of scope"; + + std::vector tile_size; + if (frag_reg_.count(bi.name)) { + Expr dst = Call::make(bi.dtype, bi.name, op->args, Call::Halide, op->func, 0); + frag_load_.insert(std::make_pair(op, dst)); + } + + const Call *value = op->value.as(); + + if (value != nullptr && frag_reg_.count(value->name)) { + Expr dst = Call::make(bi.dtype, bi.name, op->args, Call::Halide, op->func, 0); + 
frag_store_.insert(std::make_pair(op, dst)); + } + } + + void Visit_(const Realize *op) final { + air::ir::TensorKey key{op->func, op->value_index}; + if (buf_map_.count(key)) { + CHECK(buf_map_.at(key).external); + Visit(op->body); + } else { + BufferInfo bi; + bi.name = key.GetName(); + bi.dtype = op->type; + + buf_map_[key] = bi; + Visit(op->body); + buf_map_[key].released = true; + } + } + + private: + struct BufferInfo { + std::string name; + Type dtype; + bool external{false}; + bool released{false}; + }; + + std::unordered_map buf_map_; + std::unordered_map matrix_abc_; + std::unordered_map matrix_major_; + std::set frag_reg_; + + public: + std::unordered_map frag_load_; + std::unordered_map frag_store_; +}; + +class ExprUsedVarsVisitor : public IRVisitor { + public: + explicit ExprUsedVarsVisitor() {} + + void Visit_(const Variable *op) { + if (op->name_hint != THREAD_IDX_X) { + vars_.push_back(op); + } + } + + std::vector Run(Expr e) { + this->Visit(e); + return vars_; + } + + private: + std::vector vars_; +}; + +class ModifyTheLocalOffset : public IRMutator { + public: + explicit ModifyTheLocalOffset(TensorCoreInfo &info, ScopInfo &scop_info, LocalTensorAnalyser &local_analyser) + : tensor_core_info_(info), + scop_info_(scop_info), + frag_load_(local_analyser.frag_load_), + frag_store_(local_analyser.frag_store_) {} + + Stmt Mutate_(const Provide *op, const Stmt &s) { + auto it2 = frag_load_.find(op); + if (it2 != frag_load_.end()) { + if (op->value.as() != nullptr || op->value.as() != nullptr) { + Stmt stmt = ModifyTheOpIndexOfLoadFill(op, GetFragmentIndex(op)); + // The provide op is a new op. 
+ frag_load_new_.insert(stmt.as()); + return stmt; + } + const Call *value = op->value.as(); + if (value != nullptr) { + Stmt stmt = ModifyTheOpIndexOfLoadFill(op, GetFragmentIndex(op)); + frag_load_new_.insert(stmt.as()); + return stmt; + } + + Stmt stmt = ModifyTheOpIndexOfSync(op, GetFragmentIndex(op)); + frag_load_new_.insert(stmt.as()); + return stmt; + } + + auto it3 = frag_store_.find(op); + if (it3 != frag_store_.end()) { + auto value = op->value; + auto call = value.as(); + CHECK(call); + Stmt stmt; + if (scop_info_.user_config_.GetEnableConvTensorCore()) { + stmt = ModifyTheOpIndexOfStore(op, GetFragmentIndexConv(call)); + } else { + stmt = ModifyTheOpIndexOfStore(op, GetFragmentIndex(call)); + } + frag_store_new_.insert(stmt.as()); + return stmt; + } + return IRMutator::Mutate_(op, s); + } + + Stmt Mutate_(const For *op, const Stmt &s) { + vec_for_vars_.push_back(op); + Stmt stmt = IRMutator::Mutate_(op, s); + vec_for_vars_.pop_back(); + return stmt; + } + + Expr Mutate_(const Call *op, const Expr &e) { + if (sync_value_mod) { + Array real_index; + if (scop_info_.user_config_.GetEnableConvTensorCore()) { + real_index = GetFragmentIndexConv(op); + } else { + real_index = GetFragmentIndex(op); + } + return Call::make(op->type, op->name, real_index, op->call_type, op->func, op->value_index); + } + + return IRMutator::Mutate_(op, e); + } + + Array GetFragmentIndex(const Provide *op) { + auto call = frag_load_[op].as(); + CHECK(call); + if (scop_info_.user_config_.GetEnableConvTensorCore()) { + return GetFragmentIndexConv(call); + } + return GetFragmentIndex(call); + } + + Array GetFragmentIndex(const Call *call) { + auto args = call->args; + Array new_index; + for (auto &i : args) { + auto used_vars = ExprUsedVarsVisitor().Run(i); + Expr e = make_const(Int(32), 0); + int size = used_vars.size(); + // The last var is used for wmma interface. + // So in this place, delete the related var infor. 
+ // example: (cc1*16+cc2) cc2 is the var using for wmma interface + // cc1 is the var after warping + // This index will be converted to cc1 + // If the index is without cc1, this will be converted to 0 + for (int i = 0; i < size - 1; i++) { + auto u = used_vars[i]; + Expr temp = Expr(GetObjPtr(u)); + for (int j = i + 1; j < size - 1; j++) { + // The last var is used for wmma interface. + // So in this place, the extent of last var is not used. + temp = Mul::make(temp, FindExtentOfForVar(used_vars[j])); + } + e = Add::make(e, temp); + } + new_index.push_back(e); + } + return new_index; + } + + // for input_1 and output, the layout is n h w c + // for input_2, the layout is o kh kw c + // the h and w dimention is useful for the fragment compute. + // the kh and kw dimension is splitted outer. + Array GetFragmentIndexConv(const Call *call) { + auto args = call->args; + Array new_index; + int len = args.size(); + for (int i = 0; i < len; i++) { + auto used_vars = ExprUsedVarsVisitor().Run(args[i]); + Expr e = make_const(Int(32), 0); + int size = used_vars.size(); + // The last var is used for wmma interface. + // So in this place, delete the related var infor. + // example: (cc1*16+cc2) cc2 is the var using for wmma interface + // cc1 is the var after warping + // This index will be converted to cc1 + // If the index is without cc1, this will be converted to 0 + // The H&W dimensions are not used for warp mapping and wmma interface. + // So, for H and W dimension, this logic should be disabled. + constexpr auto H_DIMENSION_INDEX = 1; + constexpr auto W_DIMENSION_INDEX = 2; + int outer_size = size; + if (i != H_DIMENSION_INDEX && i != W_DIMENSION_INDEX) { + outer_size -= 1; + } + for (int i = 0; i < outer_size; i++) { + auto u = used_vars[i]; + Expr temp = Expr(GetObjPtr(u)); + for (int j = i + 1; j < size - 1; j++) { + // The last var is used for wmma interface. + // So in this place, the extent of last var is not used. 
+ temp = Mul::make(temp, FindExtentOfForVar(used_vars[j])); + } + e = Add::make(e, temp); + } + new_index.push_back(e); + } + return new_index; + } + + Stmt ModifyTheOpIndexOfLoadFill(const Provide *op, Array real_index) { + return Provide::make(op->func, op->value_index, op->value, real_index); + } + + Stmt ModifyTheOpIndexOfStore(const Provide *op, Array real_index) { + auto value = op->value; + auto call = value.as(); + CHECK(call); + value = Call::make(call->type, call->name, real_index, call->call_type, call->func, call->value_index); + return Provide::make(op->func, op->value_index, value, op->args); + } + + Stmt ModifyTheOpIndexOfSync(const Provide *op, Array real_index) { + auto value = op->value; + sync_value_mod = true; + value = this->Mutate(value); + sync_value_mod = false; + return Provide::make(op->func, op->value_index, value, real_index); + } + + Expr FindExtentOfForVar(const Variable *var) { + for (auto &v : vec_for_vars_) { + if (v->loop_var.get() == var) { + return v->extent; + } + } + return Expr(); + } + + friend class TensorCoreInterfaceEmit; + + private: + TensorCoreInfo &tensor_core_info_; + ScopInfo &scop_info_; + + std::unordered_map frag_load_; + std::unordered_map frag_store_; + std::unordered_set frag_load_new_; + std::unordered_set frag_store_new_; + std::vector vec_for_vars_; + int for_count_{0}; + bool sync_value_mod{false}; +}; + +class TensorCoreInterfaceEmit : public IRMutator { + public: + explicit TensorCoreInterfaceEmit(TensorCoreInfo &info, ScopInfo &scop_info, ModifyTheLocalOffset &warp) + : tensor_core_info_(info), + scop_info_(scop_info), + frag_load_(warp.frag_load_new_), + frag_store_(warp.frag_store_new_) {} + + Stmt Mutate_(const Provide *op, const Stmt &s) { + Stmt stmt = IRMutator::Mutate_(op, s); + auto it2 = frag_load_.find(op); + if (it2 != frag_load_.end()) { + if (op->value.as() != nullptr || op->value.as() != nullptr) { + for_count_ = DATA_COMPUTE_FOR_DEPTH; + return EmitFillStmt(stmt); + } + + const Call *value 
= op->value.as(); + if (value != nullptr) { + for_count_ = DATA_LOAD_STORE_FOR_DEPTH; + return EmitLoadStmt(stmt); + } + + if (Mma(stmt)) { + return EmitSyncStmt(stmt); + } + + Array elemwise = GetBinaryOpExprChildren(op->value); + if (!elemwise.empty()) { + for_count_ = DATA_COMPUTE_FOR_DEPTH; + return EmitFragmentElem(stmt); + } + + return stmt; + } + + auto it3 = frag_store_.find(op); + if (it3 != frag_store_.end()) { + for_count_ = DATA_LOAD_STORE_FOR_DEPTH; + return EmitStoreStmt(stmt); + } + + return IRMutator::Mutate_(op, s); + } + + bool Mma(Stmt stmt) { + auto op = stmt.as(); + if (op == nullptr) { + return false; + } + + auto add_op = op->value.as(); + if (add_op == nullptr) { + return false; + } + + auto tensor_c = add_op->a.as(); + if (tensor_c == nullptr) { + return false; + } + + Type tensor_c_type = tensor_c->type; + if (tensor_c_type != Float(16) && tensor_c_type != Float(32)) { + return false; + } + + auto mul_op = akg::common::SplitCast(add_op->b, tensor_c_type).as(); + if (mul_op == nullptr) { + return false; + } + + return true; + } + + Stmt Mutate_(const For *op, const Stmt &s) { + Stmt stmt = IRMutator::Mutate_(op, s); + if (for_count_ != 0) { + for_count_--; + if (for_count_ == 0) { + stmt = DeleteUselessFor().Mutate(stmt); + } + } + return stmt; + } + + Stmt EmitLoadStmt(Stmt stmt) { + auto op_new = stmt.as(); + CHECK(op_new); + const Call *call_value = op_new->value.as(); + CHECK(call_value != nullptr) << "Can only load fragment from a buffer"; + + auto left_expr = MakeLeftCallFromProvide(op_new); + auto left_call = left_expr.as(); + CHECK(left_call != nullptr) << "make right part call failed!"; + + auto it = tensor_core_info_.strides_.find(call_value->name); + CHECK(it != tensor_core_info_.strides_.end()) << "Cannot find stride for " << call_value->name; + auto strides = it->second; + CHECK_GE(strides.size(), 2); + Expr stride = strides[strides.size() - 2]; + // set the stride information for conv operator + // conv operator matrix a 
layout is "n h w ic" + // The wmma interface uses the data of n. So the stride computing + // should used the axises of h w ic. + if (scop_info_.user_config_.GetEnableConvTensorCore() && + tensor_core_info_.matrix_abc_[akg::common::GetGlobalName(call_value->name)] == MATRIX_A) { + CHECK_GE(strides.size(), CONV_MATRIXA_DIMENSION); + stride = strides[strides.size() - CONV_MATRIXA_DIMENSION]; + } + + std::string call_name = op_new->func->func_name(); + Expr src = Call::make(call_value->type, "&", {op_new->value}, Call::Extern); + + Expr matrix_major; + auto iter2 = tensor_core_info_.matrix_major_.find(akg::common::GetGlobalName(call_name)); + CHECK(iter2 != tensor_core_info_.matrix_major_.end()) << "Can not determine matrix major for " << call_name; + if (iter2->second == COL_MAJOR) { + matrix_major = StringImm::make(COL_MAJOR); + } else if (iter2->second == ROW_MAJOR) { + matrix_major = StringImm::make(ROW_MAJOR); + } else { + LOG(FATAL) << "invalid matrix major for " << call_name; + } + + NodePtr buffer_node = make_node(); + EmitTensorCoreHelper helper(tensor_core_info_, scop_info_); + helper.SetDataForLoad(src, stride, matrix_major, left_call, op_new, buffer_node); + return helper.MakeLoadTransform(); + } + + Stmt EmitSyncStmt(Stmt stmt) { + auto op = stmt.as(); + CHECK(op); + + auto left_expr = MakeLeftCallFromProvide(op); + Type type = scop_info_.user_config_.GetDataType(op->func->func_name()); + auto *add = op->value.as(); + CHECK(add) << "format error of bmm"; + auto mul = akg::common::SplitCast(add->b, type).as(); + CHECK(mul) << "format error of bmm"; + + auto load_a_expr = akg::common::SplitCast(mul->a, type); + auto load_b_expr = akg::common::SplitCast(mul->b, type); + + Expr a = load_a_expr; + Expr b = load_b_expr; + Expr c = left_expr; + + NodePtr buffer_node_a = make_node(); + NodePtr buffer_node_b = make_node(); + NodePtr buffer_node_c = make_node(); + + EmitTensorCoreHelper helper(tensor_core_info_, scop_info_); + helper.SetDataForSync(a, b, c, 
buffer_node_a, buffer_node_b, buffer_node_c); + return helper.MakeSyncTransform(); + } + + Stmt EmitFillStmt(Stmt stmt) { + auto op = stmt.as(); + auto left_expr = MakeLeftCallFromProvide(op); + auto left_call = left_expr.as(); + CHECK(left_call != nullptr) << "make right part call failed"; + + if (op->value.as() != nullptr || op->value.as() != nullptr) { + NodePtr buffer_node = make_node(); + EmitTensorCoreHelper helper(tensor_core_info_, scop_info_); + helper.SetDataForFill(op, left_call, buffer_node); + return helper.MakeFillTransform(); + } else { + CHECK(false) << "mma init stmt format error"; + } + return Stmt(); + } + + Stmt EmitStoreStmt(Stmt stmt) { + auto op = stmt.as(); + CHECK(op); + + auto lh_expr = MakeLeftCallFromProvide(op); + auto lh_call = lh_expr.as(); + CHECK(lh_call != nullptr) << "make right part call failed!"; + + auto it = tensor_core_info_.strides_.find(lh_call->name); + CHECK(it != tensor_core_info_.strides_.end()) << "Cannot find stride for " << lh_call->name; + auto strides = it->second; + CHECK_GE(strides.size(), 2); + Expr stride = strides[strides.size() - 2]; + + // set the stride information for conv operator + // conv operator output layout is "n h w o" + // The wmma interface uses the data of n. So the stride computing + // should used the axises of h w o. 
+ if (scop_info_.user_config_.GetEnableConvTensorCore()) { + CHECK_GE(strides.size(), CONV_OUTPUT_DIMENSION); + stride = strides[strides.size() - CONV_OUTPUT_DIMENSION]; + } + + Expr dst = lh_expr; + dst = Call::make(Handle(), "&", {dst}, Call::Extern); + + auto call = op->value.as(); + NodePtr buffer_node = make_node(); + EmitTensorCoreHelper helper(tensor_core_info_, scop_info_); + helper.SetDataForStore(dst, stride, call, buffer_node); + return helper.MakeStoreTransform(); + } + + Stmt EmitFragmentElem(Stmt stmt) { + auto op = stmt.as(); + CHECK(op); + + auto elem = GetBinaryOpExprChildren(op->value); + Expr op_name = GetBinaryOpName(op->value); + + Expr a = elem[0]; + Expr b = elem[1]; + auto left_expr = MakeLeftCallFromProvide(op); + Expr c = left_expr; + + NodePtr buffer_node_a = make_node(); + NodePtr buffer_node_b = make_node(); + NodePtr buffer_node_c = make_node(); + + EmitTensorCoreHelper helper(tensor_core_info_, scop_info_); + helper.SetDataForElem(a, b, c, buffer_node_a, buffer_node_b, buffer_node_c); + return helper.MakeFragmentElemTransform(op_name); + } + + Expr MakeLeftCallFromProvide(const Provide *op) { + std::string name = op->func->func_name(); + Type type = scop_info_.user_config_.GetDataType(name); + Expr dst = Call::make(type, name, op->args, Call::Halide, op->func, 0); + return dst; + } + + private: + TensorCoreInfo &tensor_core_info_; + ScopInfo &scop_info_; + bool load_stmt_{false}; + bool store_stmt_{false}; + bool sync_stmt_{false}; + std::unordered_set frag_load_; + std::unordered_set frag_store_; + std::stack st; + std::vector vec_for_vars_; + int for_count_{0}; +}; + +class CheckCast : public IRVisitor { + public: + explicit CheckCast() {} + using IRVisitor::Visit_; + + void Visit_(const AttrStmt *op) final { + if (op->attr_key == CAST_FLAG) { + std::string mode = op->value.as()->value; + if (mode == CAST_MODE_1) { + origin_type_ = Float(32); + cast_type_ = Float(16); + } + is_cast_ = true; + IRVisitor::Visit_(op); + return; + } + 
IRVisitor::Visit_(op); + } + + void Visit_(const Call *op) final { + if (op->is_intrinsic(air::ir::intrinsic::tvm_mma_sync)) { + CHECK_EQ(op->args.size(), 8U); + Expr arg2 = op->args[2]; + Expr arg4 = op->args[4]; + const Variable *a2 = arg2.as(); + CHECK(a2); + const Variable *a4 = arg4.as(); + CHECK(a4); + cast_tensors_.insert(akg::common::GetGlobalName(a2->name_hint)); + cast_tensors_.insert(akg::common::GetGlobalName(a4->name_hint)); + } + IRVisitor::Visit_(op); + } + + bool IsCastAdapt() { return is_cast_; } + friend class CollectInfoToAdaptCast; + + private: + Type origin_type_; + Type cast_type_; + bool is_cast_{false}; + std::set cast_tensors_; +}; + +class CollectInfoToAdaptCast : public IRVisitor { + public: + explicit CollectInfoToAdaptCast(CheckCast &check_cast) + : origin_type_(check_cast.origin_type_), + cast_type_(check_cast.cast_type_), + cast_tensors_(check_cast.cast_tensors_) {} + using IRVisitor::Visit_; + + void Visit_(const AttrStmt *op) final { + if (op->attr_key == GMREAD_FLAG) { + is_global_to_shared_ = true; + IRVisitor::Visit_(op); + is_global_to_shared_ = false; + return; + } + IRVisitor::Visit_(op); + } + + void Visit_(const Provide *op) final { + if (is_global_to_shared_) { + global_to_shared_.insert(op); + } + IRVisitor::Visit_(op); + } + + void Visit_(const Realize *op) final { + std::string name = op->func->func_name(); + if (IsEndsWith(name, SHARE_SUFFIX) && cast_tensors_.count(akg::common::GetGlobalName(name))) { + realize_need_cast_shared_.insert(name); + } else if (IsEndsWith(name, LOCAL_SUFFIX) && cast_tensors_.count(akg::common::GetGlobalName(name))) { + realize_need_cast_local_.insert(name); + } + IRVisitor::Visit_(op); + } + + friend class AdaptCast; + + private: + Type origin_type_; + Type cast_type_; + bool is_global_to_shared_{false}; + std::set cast_tensors_; + + std::set global_to_shared_; + std::set realize_need_cast_shared_; + std::set realize_need_cast_local_; +}; + +class AdaptCast : public IRMutator { + public: + 
explicit AdaptCast(CollectInfoToAdaptCast &info) + : realize_need_cast_shared_(info.realize_need_cast_shared_), + realize_need_cast_local_(info.realize_need_cast_local_), + global_to_shared_(info.global_to_shared_), + origin_type_(info.origin_type_), + cast_type_(info.cast_type_) {} + + Stmt Mutate_(const AttrStmt *op, const Stmt &s) final { + if (op->attr_key == air::ir::attr::buffer_bind_scope) { + Array arr = Downcast>(op->node); + CHECK_EQ(arr.size(), 2U); + const BufferNode *buffer = arr[0].as(); + const TensorNode *tensor = arr[1].as(); + const Call *tuple = op->value.as(); + CHECK(buffer && tensor); + CHECK(tuple); + if (realize_need_cast_local_.count(buffer->name)) { + NodePtr buffer_node = make_node(); + buffer_node->data = buffer->data; + buffer_node->name = buffer->name; + buffer_node->scope = buffer->scope; + buffer_node->dtype = cast_type_; + buffer_node->shape = buffer->shape; + buffer_node->strides = buffer->strides; + buffer_node->data_alignment = buffer->data_alignment; + buffer_node->elem_offset = buffer->elem_offset; + buffer_node->offset_factor = buffer->offset_factor; + + Buffer buffer_new(buffer_node); + NodePtr tensor_node = make_node(); + tensor_node->value_index = tensor->value_index; + tensor_node->op = tensor->op; + tensor_node->shape = tensor->shape; + tensor_node->dtype = cast_type_; + Tensor tensor_new(tensor_node); + + Array node = {buffer_new, tensor_new}; + Stmt body = this->Mutate(op->body); + return AttrStmt::make(node, op->attr_key, op->value, body); + } + } + return IRMutator::Mutate_(op, s); + } + + Stmt Mutate_(const Realize *op, const Stmt &s) final { + std::string tensor_name = op->func->func_name(); + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as(); + if (op != nullptr) { + if (!realize_need_cast_shared_.count(tensor_name) && !realize_need_cast_local_.count(tensor_name)) { + return stmt; + } + + return Realize::make(op->func, op->value_index, cast_type_, op->bounds, op->condition, op->body); + } + return stmt; + } + 
+ Stmt Mutate_(const Provide *op, const Stmt &s) final { + if (global_to_shared_.count(op)) { + auto value = op->value; + auto call = value.as(); + CHECK(call); + CHECK(call->type == origin_type_); + value = Cast::make(cast_type_, value); + return Provide::make(op->func, op->value_index, value, op->args); + } + return IRMutator::Mutate_(op, s); + } + + private: + std::set realize_need_cast_shared_; + std::set realize_need_cast_local_; + std::set global_to_shared_; + Type origin_type_; + Type cast_type_; +}; + +class AdaptCastDesignOne : public IRMutator { + public: + explicit AdaptCastDesignOne(TensorCoreInfo &info) : cast_tensors_(info.cast_tensors_) {} + + Stmt Mutate_(const Realize *op, const Stmt &s) final { + std::string tensor_name = op->func->func_name(); + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as(); + if (op != nullptr) { + if (!cast_tensors_.count(tensor_name)) { + return stmt; + } + + return Realize::make(op->func, op->value_index, Float(16), op->bounds, op->condition, op->body); + } + return stmt; + } + + private: + std::unordered_set cast_tensors_; +}; + +class DeleteUselessAttr : public IRMutator { + public: + explicit DeleteUselessAttr() {} + Stmt Mutate_(const AttrStmt *op, const Stmt &s) final { + if (op->attr_key == GMREAD_FLAG) { + return IRMutator::Mutate(op->body); + } + return IRMutator::Mutate_(op, s); + } +}; + +class AddNoUnrollAttr : public IRMutator { + public: + explicit AddNoUnrollAttr() {} + Stmt Mutate_(const AttrStmt *op, const Stmt &s) final { + if (op->attr_key == CONV_KHKW_OUTER) { + meet_khkw_outer_ = true; + return IRMutator::Mutate(op->body); + } + return IRMutator::Mutate_(op, s); + } + + Stmt Mutate_(const For *op, const Stmt &s) { + if (!meet_khkw_outer_) { + return AttrStmt::make(Expr("INFO"), "no_unroll", StringImm::make("no_unroll"), IRMutator::Mutate_(op, s)); + } + return IRMutator::Mutate_(op, s); + } + + private: + bool meet_khkw_outer_{false}; +}; + +Stmt EmitForTensorCoreDesignOne(Stmt stmt, 
TensorCoreInfo &info) { + AdaptCastDesignOne adapt(info); + stmt = adapt.Mutate(stmt); + return stmt; +} + +bool CheckTileValid(Tile tile, TensorCoreInfo &info) { + if (tile.m == 16 && tile.n == 16 && tile.k == 4) { + info.wmma_scope_ = "akg"; + return true; + } + if (tile.m == 16 && tile.n == 16 && tile.k == 8) { + info.wmma_scope_ = "akg"; + return true; + } + if (tile.m == 32 && tile.n == 32 && tile.k == 4) { + info.wmma_scope_ = "akg"; + return true; + } + if (tile.m == 16 && tile.n == 16 && tile.k == 16) { + info.wmma_scope_ = "nvcuda"; + return true; + } + if (tile.m == 8 && tile.n == 32 && tile.k == 16) { + info.wmma_scope_ = "nvcuda"; + return true; + } + if (tile.m == 32 && tile.n == 8 && tile.k == 16) { + info.wmma_scope_ = "nvcuda"; + return true; + } + return false; +} + +void PrepareDataForTensorCore(TensorCoreInfo &info, ScopInfo &scop_info) { + auto binds = scop_info.user_config_.GetBind(); + + auto thread_cfg = scop_info.user_config_.GetThreadConfig(); + CHECK(thread_cfg) << "thread config is null"; + int tx = thread_cfg->GetX().second; + int ty = thread_cfg->GetY().second; + int tz = thread_cfg->GetZ().second; + + if (scop_info.user_config_.GetEnableOneDimThread()) { + tx = tx * ty * tz; + ty = 1; + tz = 1; + } + + for (auto i : binds) { + if (!i.first.defined()) continue; + if (!i.second.defined()) continue; + auto t = i.first; + auto b = i.second; + + std::string name = t->op->name; + + air::ir::TensorKey key{t->op, t->value_index}; + Region bounds; + if (bounds.empty()) { + for (auto j : t->shape) { + bounds.push_back(Range::make_by_min_extent(Expr(0), j)); + } + } + + info.bounds_[key] = bounds; + + Array strides; + for (size_t i = 1; i < b->shape.size(); ++i) { + Expr stride = IntImm::make(Int(32), 1); + for (size_t j = b->shape.size() - 1; j >= i; --j) { + stride = Mul::make(stride, b->shape[j]); + } + strides.push_back(stride); + } + strides.push_back(make_const(Int(32), 1)); + info.strides_[name] = strides; + } + + auto mma = 
scop_info.analysis_result_.GetMmaMode(); + info.warp_tile_.m = mma.m; + info.warp_tile_.n = mma.n; + info.warp_tile_.k = mma.k; + + bool result = CheckTileValid(info.warp_tile_, info); + CHECK(result) << "tile set is not valid!"; + + info.matrix_abc_ = scop_info.analysis_result_.GetMatrixMatmulMap(); + info.matrix_major_ = scop_info.analysis_result_.GetMatrixMatmulMajor(); + + for (auto &i : info.matrix_abc_) { + info.frag_reg_.insert(i.first + LOCAL_SUFFIX); + } +} + +Stmt EmitForTensorCore(Stmt stmt, TensorCoreInfo &info, ScopInfo &scop_info) { + CheckTensorCoreValid check; + check.Visit(stmt); + if (!check.IsValid()) { + return stmt; + } + PrepareDataForTensorCore(info, scop_info); + stmt = AddMmaAttrFlag(info).Mutate(stmt); + LocalTensorAnalyser local_analyser(info, scop_info); + local_analyser.Visit(stmt); + + ModifyTheLocalOffset warp(info, scop_info, local_analyser); + stmt = warp.Mutate(stmt); + + stmt = TensorCoreInterfaceEmit(info, scop_info, warp).Mutate(stmt); + stmt = DeleteUselessAttr().Mutate(stmt); + + if (scop_info.user_config_.GetEnableConvTensorCore()) { + stmt = AddNoUnrollAttr().Mutate(stmt); + } + + if (scop_info.analysis_result_.GetBatchAxisNumForMatmul()) { + auto batch_axis_num = scop_info.analysis_result_.GetBatchAxisNumForMatmul(); + stmt = AttrStmt::make(Expr(""), "batch_axis_num", make_const(Int(32), batch_axis_num), stmt); + } + + // add tensor core plan two attr + if (scop_info.user_config_.GetEnableTensorCore()) { + if (scop_info.user_config_.GetEnableTensorCoreUsePoly()) { + stmt = AttrStmt::make(Expr(""), "pragma_tensor_core", StringImm::make(TENSOR_CORE_MODE_TWO), stmt); + stmt = AttrStmt::make(Expr("INFO"), "wmma_scope", StringImm::make(info.wmma_scope_), stmt); + } else { + stmt = AttrStmt::make(Expr(""), "pragma_tensor_core", StringImm::make(TENSOR_CORE_MODE_ONE), stmt); + } + } + + return stmt; +} +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/gpu_isl_emitter.cc b/src/poly/gpu_isl_emitter.cc 
deleted file mode 100644 index 10d4fbeb4aa0ce1c331b8e6439a8325fb6bc82db..0000000000000000000000000000000000000000 --- a/src/poly/gpu_isl_emitter.cc +++ /dev/null @@ -1,1899 +0,0 @@ -/** - * Copyright 2020-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "poly/gpu_isl_emitter.h" -#include "pass/utils.h" -#include "gpu_emit/emit_pass.h" -#include -#include - -namespace akg { -namespace ir { -namespace poly { - -Expr GpuIslEmitter::EmitLoad(const isl::ast_expr &expr, const Type type) { - if (PRINT_EMITTER) { - LOG(INFO) << ">>>>>>>>>>>>INPUT AST_NODE[LOAD]<<<<<<<<<<<<<<\n" << expr; - } - if (auto op = expr.as()) { - if (auto access = op.as()) { - CHECK(op.get_arg(0).as()); - auto var = op.get_arg(0).as().get_id(); - Array local_args; - for (unsigned int i = 1; i < op.get_n_arg(); ++i) { - local_args.push_back(Interpret(op.get_arg(i))); - } - - Tensor t = info_.FindTensor(var); - auto call = Call::make(type, t->op->name, local_args, Call::CallType::Halide, t->op, t->value_index); - if (PRINT_EMITTER) { - LOG(INFO) << ">>>>>>>>>>>>OUTPUT STMT<<<<<<<<<<<<\n" << call; - } - return call; - } - } - return Expr(); -} - -Stmt GpuIslEmitter::EmitRead(const isl::ast_node_user &node) { - isl::id node_id = node.get_annotation(); - isl::pw_multi_aff iterator_map = node_info_map_.at(node_id).iterator_map; - isl::pw_multi_aff hoisted = iterator_map.range_factor_range(); - isl::pw_multi_aff original = 
iterator_map.range_factor_domain().range_factor_range(); - - isl::id original_tensor = original.get_tuple_id(isl_dim_out); - - auto build = node_info_map_.at(node_id).build; - auto lhs = build.access_from(isl::multi_pw_aff(hoisted)); - auto rhs = build.access_from(isl::multi_pw_aff(original)); - - Type type = info_.GetDtypeOf(rhs); - if (auto op = lhs.as()) { - if (auto access = op.as()) { - Expr value = EmitLoad(rhs, type); - auto var = op.get_arg(0).as().get_id(); - - Array local_args; - for (unsigned int i = 1; i < op.get_n_arg(); ++i) { - local_args.push_back(Interpret(op.get_arg(i))); - } - - Tensor t = info_.FindTensor(var); - CHECK(t.defined()); - return Provide::make(t->op, 0, value, local_args); - } - } - return Stmt(); -} - -std::string SimplifyName(std::string input) { - auto pos_local = input.find(LOCAL_SUFFIX); - auto pos_shared = input.find(SHARE_SUFFIX); - std::string res = input; - if (pos_local != std::string::npos) { - res = input.substr(0, pos_local); - } - if (pos_shared != std::string::npos) { - res = res.substr(0, pos_shared); - } - return res; -} - -Stmt GpuIslEmitter::EmitReadCore(const isl::ast_node_user &node) { - isl::id node_id = node.get_annotation(); - isl::pw_multi_aff iterator_map = node_info_map_.at(node_id).iterator_map; - isl::pw_multi_aff hoisted = iterator_map.range_factor_range(); - isl::pw_multi_aff original = iterator_map.range_factor_domain().range_factor_range(); - - isl::id original_tensor = original.get_tuple_id(isl_dim_out); - - auto build = node_info_map_.at(node_id).build; - auto lhs = build.access_from(isl::multi_pw_aff(hoisted)); - auto rhs = build.access_from(isl::multi_pw_aff(original)); - - Type type = info_.GetDtypeOf(rhs); - if (auto op = lhs.as()) { - if (auto access = op.as()) { - Expr value = EmitLoad(rhs, type); - auto var = op.get_arg(0).as().get_id(); - - Array local_args; - for (unsigned int i = 1; i < op.get_n_arg(); ++i) { - local_args.push_back(Interpret(op.get_arg(i))); - } - - Tensor t = 
info_.FindTensor(var); - CHECK(t.defined()); - Stmt s = Provide::make(t->op, 0, value, local_args); - - auto op_new = s.as(); - CHECK(op_new); - const Call *call_value = op_new->value.as(); - CHECK(call_value != nullptr) << "Can only load fragment from a buffer"; - - auto left_expr = MakeLeftCallFromProvide(op_new); - auto left_call = left_expr.as(); - CHECK(left_call != nullptr) << "make right part call failed!"; - - auto it = tensor_core_info_.strides_.find(call_value->name); - CHECK(it != tensor_core_info_.strides_.end()) << "Cannot find stride for " << call_value->name; - auto strides = it->second; - CHECK_GE(strides.size(), 2); - Expr stride = strides[strides.size() - 2]; - - std::string call_name = op_new->func->func_name(); - Expr src = Call::make(call_value->type, "&", {op_new->value}, Call::Extern); - - Expr matrix_major; - auto iter2 = tensor_core_info_.matrix_major_.find(SimplifyName(call_name)); - CHECK(iter2 != tensor_core_info_.matrix_major_.end()) << "Can not determine matrix major for " << call_name; - if (iter2->second == COL_MAJOR) { - matrix_major = StringImm::make(COL_MAJOR); - } else if (iter2->second == ROW_MAJOR) { - matrix_major = StringImm::make(ROW_MAJOR); - } else { - LOG(FATAL) << "invalid matrix major for " << call_name; - } - - NodePtr buffer_node = make_node(); - EmitTensorCoreHelper helper(tensor_core_info_); - helper.SetDataForLoad(src, stride, matrix_major, left_call, op_new, buffer_node); - return helper.MakeLoadTransform(); - } - } - return Stmt(); -} - -Expr GpuIslEmitter::MakeLeftCallFromProvide(const Provide *op) { - std::string name = op->func->func_name(); - Type type = info_.GetDtypeOf(name); - Expr dst = Call::make(type, name, op->args, Call::Halide, op->func, 0); - return dst; -} - -Stmt GpuIslEmitter::EmitWrite(const isl::ast_node_user &node) { - auto node_id = node.get_annotation(); - CHECK_GT(node_info_map_.count(node_id), 0); - auto iterator_map = node_info_map_.at(node_id).iterator_map; - auto hoisted = 
iterator_map.range_factor_range(); - auto original = iterator_map.range_factor_domain().range_factor_range(); - - auto build = node_info_map_.at(node_id).build; - auto rhs = build.access_from(isl::multi_pw_aff(hoisted)); - auto lhs = build.access_from(isl::multi_pw_aff(original)); - Type type = info_.GetDtypeOf(lhs); - - if (auto op = lhs.as()) { - if (auto access = op.as()) { - Expr value = EmitLoad(rhs, type); - auto var = op.get_arg(0).as().get_id(); - - Array local_args; - for (unsigned int i = 1; i < op.get_n_arg(); ++i) { - local_args.push_back(Interpret(op.get_arg(static_cast(i)))); - } - - Tensor t = info_.FindTensor(var); - CHECK(t.defined()); - - return Provide::make(t->op, 0, value, local_args); - } - } - return Stmt(); -} - -Stmt GpuIslEmitter::EmitWriteCore(const isl::ast_node_user &node) { - auto node_id = node.get_annotation(); - CHECK_GT(node_info_map_.count(node_id), 0); - auto iterator_map = node_info_map_.at(node_id).iterator_map; - auto hoisted = iterator_map.range_factor_range(); - auto original = iterator_map.range_factor_domain().range_factor_range(); - - auto build = node_info_map_.at(node_id).build; - auto rhs = build.access_from(isl::multi_pw_aff(hoisted)); - auto lhs = build.access_from(isl::multi_pw_aff(original)); - Type type = info_.GetDtypeOf(lhs); - - if (auto op = lhs.as()) { - if (auto access = op.as()) { - Expr value = EmitLoad(rhs, type); - auto var = op.get_arg(0).as().get_id(); - - Array local_args; - for (unsigned int i = 1; i < op.get_n_arg(); ++i) { - local_args.push_back(Interpret(op.get_arg(static_cast(i)))); - } - - Tensor t = info_.FindTensor(var); - CHECK(t.defined()); - - Stmt s = Provide::make(t->op, 0, value, local_args); - - auto op = s.as(); - CHECK(op); - - auto lh_expr = MakeLeftCallFromProvide(op); - auto lh_call = lh_expr.as(); - CHECK(lh_call != nullptr) << "make right part call failed!"; - - auto it = tensor_core_info_.strides_.find(lh_call->name); - CHECK(it != tensor_core_info_.strides_.end()) << "Cannot 
find stride for " << lh_call->name; - auto strides = it->second; - CHECK_GE(strides.size(), 2); - Expr stride = strides[strides.size() - 2]; - - Expr dst = lh_expr; - dst = Call::make(Handle(), "&", {dst}, Call::Extern); - - auto call = op->value.as(); - NodePtr buffer_node = make_node(); - EmitTensorCoreHelper helper(tensor_core_info_); - helper.SetDataForStore(dst, stride, call, buffer_node); - return helper.MakeStoreTransform(); - } - } - return Stmt(); -} - -Stmt GpuIslEmitter::EmitWriteAtomic(const isl::ast_node_user &node) { - auto node_id = node.get_annotation(); - CHECK_GT(node_info_map_.count(node_id), 0); - auto iterator_map = node_info_map_.at(node_id).iterator_map; - auto hoisted = iterator_map.range_factor_range(); - auto original = iterator_map.range_factor_domain().range_factor_range(); - - auto build = node_info_map_.at(node_id).build; - auto rhs = build.access_from(isl::multi_pw_aff(hoisted)); - auto lhs = build.access_from(isl::multi_pw_aff(original)); - - auto opr = rhs.as(); - reduce_info_.output_promoted_tensor_name_for_atomic_ = opr.get_arg(0).as().get_id().name(); - reduce_info_.atomic_tensors_.insert(reduce_info_.output_promoted_tensor_name_for_atomic_); - - Type type = info_.GetDtypeOf(lhs); - reduce_info_.output_tensor_data_type_info_ = type; - - if (auto op = lhs.as()) { - if (auto access = op.as()) { - Expr value = EmitLoad(rhs, type); - reduce_info_.atomic_rhs_ = value; - auto var = op.get_arg(0).as().get_id(); - - Array local_args; - for (unsigned int i = 1; i < op.get_n_arg(); ++i) { - Expr arg = Interpret(op.get_arg(static_cast(i))); - local_args.push_back(arg); - } - - Tensor t = info_.FindTensor(var); - CHECK(t.defined()); - - return Provide::make(t->op, 0, value, local_args); - } - } - return Stmt(); -} - -Stmt GpuIslEmitter::EmitSync() { - return Evaluate::make(Call::make(Int(32), STORAGE_SYNC, {StringImm::make(SYNC_SCOP_SHARED)}, Call::Intrinsic)); -} - -void GpuIslEmitter::SetScalarTensorBind(std::string scalar_tensor_name) { - 
Array shapes; - shapes.push_back(Expr(1)); - Type type = reduce_info_.reduce_data_type_info_; - reduce_info_.added_tensors_.insert(scalar_tensor_name); - - Tensor tensor = placeholder(shapes, type, scalar_tensor_name); - const Buffer buffer = decl_buffer(shapes, type, scalar_tensor_name); - reduce_info_.scalar_tensor_[scalar_tensor_name] = tensor; - - info_.user_config_.SetBind(tensor, buffer); -} - -void GpuIslEmitter::SetSharedTensorBind() { - auto thread_cfg = info_.user_config_.GetThreadConfig(); - CHECK(thread_cfg) << "thread config is null."; - int tx = thread_cfg->GetX().second; - int ty = thread_cfg->GetY().second; - - int size = tx * ty; - Array shapes; - shapes.push_back(Expr(size)); - Type type = reduce_info_.reduce_data_type_info_; - std::string shared_tensor_name = reduce_info_.shared_compute_name_; - reduce_info_.added_tensors_.insert(shared_tensor_name); - - Tensor tensor = placeholder(shapes, type, shared_tensor_name); - const Buffer buffer = decl_buffer(shapes, type, shared_tensor_name); - reduce_info_.shared_tensor_ = tensor; - - info_.user_config_.SetBind(tensor, buffer); -} - -Stmt GpuIslEmitter::EmitReduceInit(const isl::ast_node_user &node) { - CHECK(node.get_expr().isa()); - isl::ast_expr_op usr_expr = node.get_expr().as(); - CHECK(usr_expr); - auto stmt_id = usr_expr.get_arg(0).as().get_id(); - - CHECK(!reduce_info_.scalar_tensor_name_.empty()) << "scalar tensor info should not be empty!"; - if (reduce_info_.reduce_op_.find("SumOp") != std::string::npos) { - CHECK(!reduce_info_.scalar_kht_name_.empty()) << "scalar tensor kht info should not be empty!"; - CHECK(!reduce_info_.scalar_khy_name_.empty()) << "scalar tensor khy info should not be empty!"; - CHECK(!reduce_info_.scalar_khc_name_.empty()) << "scalar tensor khc info should not be empty!"; - } - - std::vector strs = common::Split(stmt_id.name(), "_"); - CHECK_EQ(strs.size(), REDUCE_FLAG_SIZE) << "red init format is not right!."; - - std::string stmt_name = 
strs[REDUCE_FLAG_STMT_PREFIX_POS] + "_" + strs[REDUCE_FLAG_STMT_NUM_POS]; - Expr init_value; - for (auto it : info_.analysis_result_.GetReduceTensorInfoMap()) { - if (it.first.name() == stmt_name) { - init_value = it.second.init_value; - break; - } - } - - CHECK(reduce_info_.reduce_area_stmt_.defined()); - reduce_info_.stmts_.insert(reduce_info_.stmts_.begin(), reduce_info_.reduce_area_stmt_); - - Array args; - args.push_back(Expr(0)); - Stmt scalar_stmt = - Provide::make(reduce_info_.scalar_tensor_[reduce_info_.scalar_tensor_name_]->op, 0, init_value, args); - CHECK(scalar_stmt.defined()); - reduce_info_.stmts_.insert(reduce_info_.stmts_.begin(), scalar_stmt); - - if (reduce_info_.reduce_op_.find("SumOp") != std::string::npos) { - Stmt scalar_khc = - Provide::make(reduce_info_.scalar_tensor_[reduce_info_.scalar_khc_name_]->op, 0, init_value, args); - CHECK(scalar_khc.defined()); - reduce_info_.stmts_.insert(reduce_info_.stmts_.begin(), scalar_khc); - } - - MakeReduceStmt(); - - Stmt stmt = Block::make(reduce_info_.stmts_); - stmt = InsertRealizeWithMemType(stmt, isl::id(stmt_id.ctx(), reduce_info_.scalar_tensor_name_), MEM_TYPE_LOCAL); - if (reduce_info_.reduce_op_.find("SumOp") != std::string::npos) { - stmt = InsertRealizeWithMemType(stmt, isl::id(stmt_id.ctx(), reduce_info_.scalar_kht_name_), MEM_TYPE_LOCAL); - stmt = InsertRealizeWithMemType(stmt, isl::id(stmt_id.ctx(), reduce_info_.scalar_khy_name_), MEM_TYPE_LOCAL); - stmt = InsertRealizeWithMemType(stmt, isl::id(stmt_id.ctx(), reduce_info_.scalar_khc_name_), MEM_TYPE_LOCAL); - } - stmt = InsertRealizeWithMemType(stmt, isl::id(stmt_id.ctx(), reduce_info_.shared_compute_name_), MEM_TYPE_SHARED); - - ResetStatus(); - return stmt; -} - -Stmt GpuIslEmitter::EmitUserStmt(const isl::ast_node_user &node) { - CHECK(node.get_expr().isa()); - isl::ast_expr_op usr_expr = node.get_expr().as(); - stmt_id_ = usr_expr.get_arg(0).as().get_id(); - node_id_ = node.get_annotation(); - const Node *stmt_node = 
info_.analysis_result_.GetStatementMap().at(stmt_id_); - CHECK(stmt_node); - // compute VarMap to replace old iterators - auto build = node_info_map_.at(node_id_).build; - auto tuple = info_.analysis_result_.GetOperatorDomainMap().at(stmt_id_).tuple; - auto iterator_map = node_info_map_.at(node_id_).iterator_map; - - auto ids = info_.analysis_result_.GetReduceInitIds(); - for (auto &i : ids) { - if (i.get_name() == stmt_id_.get_name()) { - reduce_info_.init_stmt_emit_ = true; - break; - } - } - - var_map_.clear(); - for (unsigned int i = 0; i < tuple.size(); ++i) { - isl::id isl_old_iter = tuple.get_id(i); - auto isl_expr = build.expr_from(iterator_map.get_pw_aff(i)); - Expr halide_new_iter = Interpret(isl_expr); - var_map_.emplace(isl_old_iter, halide_new_iter); - } - - return EmitUserStmtContent(stmt_node); -} - -void GpuIslEmitter::ResetStatus() { - reduce_info_.stmts_.clear(); - reduce_info_.reduce_area_stmt_ = Stmt(); - reduce_info_.origin_reduce_stmt_ = Stmt(); - reduce_info_.gm_write_stmt_ = Stmt(); - reduce_info_.atomic_rhs_ = Expr(); - reduce_info_.input_tensor_expr_ = Expr(); - is_out_most_stmt_ = true; -} - -Stmt GpuIslEmitter::EmitReduceUpdate(const isl::ast_node_user &node) { - CHECK(node.get_expr().isa()); - isl::ast_expr_op usr_expr = node.get_expr().as(); - CHECK(usr_expr); - auto stmt_id = usr_expr.get_arg(0).as().get_id(); - - std::vector strs = common::Split(stmt_id.name(), "_"); - CHECK_EQ(strs.size(), REDUCE_FLAG_SIZE) << "red update format is not right!."; - - reduce_info_.reduce_stmt_index_ = strs[REDUCE_FLAG_REDUCE_INDEX]; - reduce_info_.scalar_tensor_name_ = SCALAR_TENSOR_PREFIX; - reduce_info_.scalar_tensor_name_ += reduce_info_.reduce_stmt_index_; - - reduce_info_.shared_compute_name_ = SHARED_TENSOR_PREFIX; - reduce_info_.shared_compute_name_ += reduce_info_.reduce_stmt_index_; - - if (AkgSupportedReduceOp.count(strs[REDUCE_FLAG_TYPE_POS])) { - reduce_info_.reduce_op_ = AKG_REDUCE_LIB_SPACE; - reduce_info_.reduce_op_ += "::"; - 
reduce_info_.reduce_op_ += strs[REDUCE_FLAG_TYPE_POS]; - } - CHECK(!reduce_info_.reduce_op_.empty()) << "reduce op should not be empty!"; - - if (reduce_info_.reduce_op_.find("SumOp") != std::string::npos) { - reduce_info_.scalar_kht_name_ = SCALAR_KHT_PREFIX; - reduce_info_.scalar_kht_name_ += reduce_info_.reduce_stmt_index_; - reduce_info_.scalar_khy_name_ = SCALAR_KHY_PREFIX; - reduce_info_.scalar_khy_name_ += reduce_info_.reduce_stmt_index_; - reduce_info_.scalar_khc_name_ = SCALAR_KHC_PREFIX; - reduce_info_.scalar_khc_name_ += reduce_info_.reduce_stmt_index_; - } - - std::string stmt_name = strs[REDUCE_FLAG_STMT_PREFIX_POS] + "_" + strs[REDUCE_FLAG_STMT_NUM_POS]; - std::string origin_tensor_name = ""; - for (auto it : info_.analysis_result_.GetReduceTensorInfoMap()) { - if (it.first.name() == stmt_name) { - origin_tensor_name = it.second.write_tensor_name; - reduce_info_.reduce_data_type_info_ = it.second.write_dtype; - break; - } - } - CHECK(!origin_tensor_name.empty()) << "origin_tensor_name should not be empty!"; - - for (const auto &buffer : info_.analysis_result_.active_buffer_footprints_) { - auto cluster_id = buffer.second.cluster_id; - auto buf_def = info_.analysis_result_.GetBufferDefInfo(cluster_id); - if (buf_def.tensor_id.name() == origin_tensor_name) { - reduce_info_.promoted_tensor_name_for_reduce_ = cluster_id.name(); - break; - } - } - - MakeAkgReduceFuncName(); - SetScalarTensorBind(reduce_info_.scalar_tensor_name_); - if (reduce_info_.reduce_op_.find("SumOp") != std::string::npos) { - SetScalarTensorBind(reduce_info_.scalar_kht_name_); - SetScalarTensorBind(reduce_info_.scalar_khy_name_); - SetScalarTensorBind(reduce_info_.scalar_khc_name_); - } - SetSharedTensorBind(); - - return Stmt(); -} - -Stmt GpuIslEmitter::TransferToKaHanInterface() { - std::string func_name = AKG_REDUCE_LIB_SPACE; - func_name += "::"; - func_name += AKG_KAHAN_LIB_NAME; - Expr template_arg0 = make_const(reduce_info_.reduce_data_type_info_, 1); - - Array args; - 
args.push_back(Expr(0)); - - Tensor tt = reduce_info_.scalar_tensor_[reduce_info_.scalar_khy_name_]; - Expr a1 = Call::make(tt->dtype, tt->op->func_name(), args, Call::Halide, tt->op, 0); - a1 = Call::make(a1.type(), "&", {a1}, Call::Extern); - - tt = reduce_info_.scalar_tensor_[reduce_info_.scalar_kht_name_]; - Expr a2 = Call::make(tt->dtype, tt->op->func_name(), args, Call::Halide, tt->op, 0); - a2 = Call::make(a2.type(), "&", {a2}, Call::Extern); - - tt = reduce_info_.scalar_tensor_[reduce_info_.scalar_khc_name_]; - Expr a3 = Call::make(tt->dtype, tt->op->func_name(), args, Call::Halide, tt->op, 0); - a3 = Call::make(a3.type(), "&", {a3}, Call::Extern); - - tt = reduce_info_.scalar_tensor_[reduce_info_.scalar_tensor_name_]; - Expr a4 = Call::make(tt->dtype, tt->op->func_name(), args, Call::Halide, tt->op, 0); - a4 = Call::make(a4.type(), "&", {a4}, Call::Extern); - - CHECK(reduce_info_.input_tensor_expr_.defined()); - Stmt stmt = Evaluate::make( - Call::make(Int(32), func_name, {template_arg0, a1, a2, a3, a4, reduce_info_.input_tensor_expr_}, Call::Extern)); - - return stmt; -} - -void GpuIslEmitter::MakeReduceStmt() { - std::string func_name = reduce_info_.akg_reduce_api_; - std::string op_info = reduce_info_.reduce_op_ + "()"; - - Expr template_arg0 = make_const(reduce_info_.reduce_data_type_info_, 1); - CHECK(!reduce_info_.akg_reduce_template_arg_.empty()); - Expr template_arg1 = StringImm::make(reduce_info_.akg_reduce_template_arg_); - - Array args_a1; - Expr a1 = Call::make(Int(32), reduce_info_.reduce_op_, args_a1, Call::Extern); - - auto p = reduce_info_.origin_reduce_stmt_.as(); - CHECK(p); - Expr a2 = Call::make(p->value.type(), p->func->func_name(), p->args, Call::Halide, p->func, 0); - a2 = Call::make(a2.type(), "&", {a2}, Call::Extern); - - Tensor tensor = info_.FindTensor(reduce_info_.shared_compute_name_); - auto bind = info_.user_config_.GetBind(); - Buffer buffer; - for (auto &i : bind) { - if (!i.first.defined()) continue; - if (i.first == 
tensor) { - buffer = i.second; - } - } - - CHECK(buffer.defined()); - - Tensor tt = reduce_info_.scalar_tensor_[reduce_info_.scalar_tensor_name_]; - Array args; - args.push_back(Expr(0)); - Expr a4 = Call::make(tt->dtype, tt->op->func_name(), args, Call::Halide, tt->op, 0); - - auto thread_cfg = info_.user_config_.GetThreadConfig(); - CHECK(thread_cfg); - int tx = thread_cfg->GetX().second; - int ty = thread_cfg->GetY().second; - Expr a5 = Expr(tx); - - Stmt stmt = Evaluate::make( - Call::make(Int(32), func_name, {template_arg0, template_arg1, a1, a2, buffer->data, a4, a5}, Call::Extern)); - - stmt = AttrStmt::make(Expr("INFO"), REDUCE_LIB_TYPE_FLAG, info_.user_config_.GetReduceLibType(), stmt); - - int size = tx * ty; - stmt = AttrStmt::make(buffer->data, air::ir::attr::storage_scope, Expr(MEM_TYPE_SHARED), - Allocate::make(buffer->data, buffer->dtype, {Expr(size)}, const_true(), stmt)); - reduce_info_.stmts_.insert(reduce_info_.stmts_.end(), stmt); - return; -} - -Stmt GpuIslEmitter::MakeAtomicStmt() { - std::string func_name = reduce_info_.akg_atomic_api_; - - Expr template_arg0 = make_const(reduce_info_.output_tensor_data_type_info_, 1); - CHECK(!reduce_info_.akg_atomic_template_arg_.empty()); - Expr template_arg1 = StringImm::make(reduce_info_.akg_atomic_template_arg_); - - Expr a1 = reduce_info_.atomic_rhs_; - - auto p = reduce_info_.gm_write_stmt_.as(); - CHECK(p); - - Expr a2 = Call::make(p->value.type(), p->func->func_name(), p->args, Call::Halide, p->func, 0); - a2 = Call::make(a2.type(), "&", {a2}, Call::Extern); - - std::string op_info = reduce_info_.reduce_op_ + "()"; - - Array args; - Expr a3 = Call::make(Int(32), reduce_info_.reduce_op_, args, Call::Extern); - - return Evaluate::make(Call::make(Int(32), func_name, {template_arg0, template_arg1, a1, a2, a3}, Call::Extern)); -} - -Stmt GpuIslEmitter::EmitReduceArea(const isl::ast_node_user &node) { - bool add_to_reduce_area = false; - if (in_reduce_area_ && is_out_most_stmt_) { - add_to_reduce_area = 
true; - is_out_most_stmt_ = false; - } - CHECK(node.get_expr().isa()); - isl::ast_expr_op usr_expr = node.get_expr().as(); - stmt_id_ = usr_expr.get_arg(0).as().get_id(); - node_id_ = node.get_annotation(); - const Node *stmt_node = info_.analysis_result_.GetStatementMap().at(stmt_id_); - CHECK(stmt_node); - // compute VarMap to replace old iterators - auto build = node_info_map_.at(node_id_).build; - auto tuple = info_.analysis_result_.GetOperatorDomainMap().at(stmt_id_).tuple; - auto iterator_map = node_info_map_.at(node_id_).iterator_map; - - var_map_.clear(); - for (unsigned int i = 0; i < tuple.size(); ++i) { - isl::id isl_old_iter = tuple.get_id(i); - auto isl_expr = build.expr_from(iterator_map.get_pw_aff(i)); - Expr halide_new_iter = Interpret(isl_expr); - var_map_.emplace(isl_old_iter, halide_new_iter); - } - - Stmt stmt = EmitUserStmtContent(stmt_node); - - CHECK(!reduce_info_.promoted_tensor_name_for_reduce_.empty()) - << "promoted_tensor_name_for_reduce_ should not be empty"; - reduce_info_.reduce_stmt_[reduce_info_.promoted_tensor_name_for_reduce_] = stmt; - reduce_info_.origin_reduce_stmt_ = stmt; - - Array args_scalar; - args_scalar.push_back(Expr(0)); - - stmt = AkgReduceStmtChange(reduce_info_.scalar_tensor_[reduce_info_.scalar_tensor_name_], args_scalar, - reduce_info_.promoted_tensor_name_for_reduce_) - .Mutate(stmt); - - if (reduce_info_.reduce_op_.find("SumOp") != std::string::npos) { - auto pro = stmt.as(); - CHECK(pro); - auto value = pro->value; - auto add = value.as(); - CHECK(add); - auto add_a = add->a; - auto add_b = add->b; - reduce_info_.input_tensor_expr_ = - (add->a.as() && add->a.as()->name == reduce_info_.scalar_tensor_name_) ? 
add_b : add_a; - stmt = TransferToKaHanInterface(); - } - - if (add_to_reduce_area) { - reduce_info_.reduce_area_stmt_ = stmt; - return Stmt(); - } - - return stmt; -} - -Stmt GpuIslEmitter::EmitUserStmtCore(const isl::ast_node_user &node) { - if (tensor_core_info_.matrix_info_[MMA_SYNC]) { - return EmitUserStmtCoreSync(node); - } - return Stmt(); -} - -Stmt GpuIslEmitter::EmitUserStmtCoreSync(const isl::ast_node_user &node) { - static int serial_number = MMA_SYNC_STMT_SERIAL; - CHECK(node.get_expr().isa()); - isl::ast_expr_op usr_expr = node.get_expr().as(); - stmt_id_ = usr_expr.get_arg(0).as().get_id(); - node_id_ = node.get_annotation(); - const Node *stmt_node = info_.analysis_result_.GetStatementMap().at(stmt_id_); - CHECK(stmt_node); - // compute VarMap to replace old iterators - auto build = node_info_map_.at(node_id_).build; - auto tuple = info_.analysis_result_.GetOperatorDomainMap().at(stmt_id_).tuple; - auto iterator_map = node_info_map_.at(node_id_).iterator_map; - - var_map_.clear(); - for (unsigned int i = 0; i < tuple.size(); ++i) { - isl::id isl_old_iter = tuple.get_id(i); - auto isl_expr = build.expr_from(iterator_map.get_pw_aff(i)); - Expr halide_new_iter = Interpret(isl_expr); - var_map_.emplace(isl_old_iter, halide_new_iter); - } - - Stmt s = EmitUserStmtContent(stmt_node); - - if (serial_number == MMA_SYNC_STMT_SERIAL) { - serial_number = MMA_FILL_STMT_SERIAL; - auto op = s.as(); - auto left_expr = MakeLeftCallFromProvide(op); - Type type = info_.GetDtypeOf(op->func->func_name()); - auto *add = op->value.as(); - CHECK(add) << "format error of bmm"; - auto mul = akg::common::SplitCast(add->b, type).as(); - CHECK(mul) << "format error of bmm"; - - auto load_a_expr = akg::common::SplitCast(mul->a, type); - auto load_b_expr = akg::common::SplitCast(mul->b, type); - - Expr a = load_a_expr; - Expr b = load_b_expr; - Expr c = left_expr; - - NodePtr buffer_node_a = make_node(); - NodePtr buffer_node_b = make_node(); - NodePtr buffer_node_c = 
make_node(); - - EmitTensorCoreHelper helper(tensor_core_info_); - helper.SetDataForSync(a, b, c, buffer_node_a, buffer_node_b, buffer_node_c); - return helper.MakeSyncTransform(); - } else if (serial_number == MMA_FILL_STMT_SERIAL) { - serial_number = MMA_SYNC_STMT_SERIAL; - auto op = s.as(); - auto left_expr = MakeLeftCallFromProvide(op); - auto left_call = left_expr.as(); - CHECK(left_call != nullptr) << "make right part call failed"; - - if (op->value.as() != nullptr || op->value.as() != nullptr) { - NodePtr buffer_node = make_node(); - EmitTensorCoreHelper helper(tensor_core_info_); - helper.SetDataForFill(op, left_call, buffer_node); - return helper.MakeFillTransform(); - } else { - CHECK(false) << "mma init stmt format error"; - } - } - - return Stmt(); -} - -Stmt GpuIslEmitter::EmitStmt(const isl::ast_node_user &node) { - CHECK(node.get_expr().isa()); - isl::ast_expr_op usr_expr = node.get_expr().as(); - CHECK(usr_expr); - auto stmt_id = usr_expr.get_arg(0).as().get_id(); - auto node_id = node.get_annotation(); - - if (info_.IsRead(stmt_id)) { - Stmt s; - is_sync_before_ = false; - if (tensor_core_info_.core_area_) { - s = EmitReadCore(node); - } else { - s = EmitRead(node); - s = AttrStmt::make(Expr(""), GMREAD_FLAG, StringImm::make(GMREAD_FLAG), s); - } - return s; - } else if (info_.IsWrite(stmt_id)) { - if (info_.IsGMWrite(stmt_id)) { - if (tensor_core_info_.core_area_) { - is_sync_before_ = false; - return EmitWriteCore(node); - } - auto iterator_map = node_info_map_.at(node_id).iterator_map; - auto original = iterator_map.range_factor_domain().range_factor_range(); - auto srcid = original.get_tuple_id(isl_dim_out); - bool no_need_to_emit = NoNeedToEmitForTempTensor(srcid); - if (no_need_to_emit) return Stmt(); - - if (reduce_info_.is_atomic) { - reduce_info_.gm_write_stmt_ = EmitWriteAtomic(node); - ConstructAtomicReturnFuncName(); - is_sync_before_ = false; - reduce_info_.is_atomic = false; - return MakeAtomicStmt(); - } - is_sync_before_ = false; - 
if (tensor_core_info_.core_area_) { - return EmitWriteCore(node); - } else { - return EmitWrite(node); - } - } - is_sync_before_ = false; - return EmitWrite(node); - } else if (info_.IsSync(stmt_id)) { - if (is_sync_before_) { - return Stmt(); - } - Stmt s = EmitSync(); - is_sync_before_ = true; - return s; - } else if (info_.IsReduceInit(stmt_id)) { - is_sync_before_ = false; - in_reduce_area_ = false; - return EmitReduceInit(node); - } else if (in_reduce_area_) { - is_sync_before_ = false; - return EmitReduceArea(node); - } else if (info_.IsReduceUpdate(stmt_id)) { - is_sync_before_ = false; - Stmt s = EmitReduceUpdate(node); - in_reduce_area_ = true; - return s; - } else { - is_sync_before_ = false; - Stmt s; - if (tensor_core_info_.core_area_) { - s = EmitUserStmtCore(node); - } else { - s = EmitUserStmt(node); - } - - return s; - } -} - -void GpuIslEmitter::MakeAkgReduceFuncName() { - auto thread_cfg = info_.user_config_.GetThreadConfig(); - CHECK(thread_cfg) << "thread config is null."; - auto block_cfg = info_.user_config_.GetBlockConfig(); - CHECK(block_cfg) << "thread config is null."; - int tx = thread_cfg->GetX().second; - int ty = thread_cfg->GetY().second; - int by = block_cfg->GetY().second; - std::string direction = info_.analysis_result_.GetReduceDirection(); - CHECK(!direction.empty()) << "direction should not be empty!"; - std::string direction_size = ""; - if (direction == X_DIRECTION) { - direction_size = std::to_string(tx); - } else { - direction_size = std::to_string(ty); - } - - std::string reduce_lib_namespace = ""; - std::string reduce_lib_name = ""; - if (info_.user_config_.GetReduceLibType() == REDUCE_LIB_TYPE_ORIGIN) { - reduce_lib_namespace = AKG_REDUCE_LIB_SPACE; - reduce_lib_name = AKG_REDUCE_LIB_NAME; - } else if (info_.user_config_.GetReduceLibType() == REDUCE_LIB_TYPE_PARIS) { - reduce_lib_namespace = PARIS_REDUCE_LIB_SPACE; - reduce_lib_name = PARIS_REDUCE_LIB_NAME; - } else { - CHECK(false) << "reduce lib type is invalid!" 
- << "\n"; - } - std::string ret = reduce_lib_namespace; - ret += "::"; - ret += reduce_lib_name; - - reduce_info_.akg_reduce_api_ = ret; - ret = ""; - - std::string op = reduce_info_.reduce_op_; - ret += op; - ret += ", "; - - ret += std::to_string(tx); - ret += ", "; - ret += std::to_string(ty); - std::string reduce_type = ""; - if (by == 1 && ty == 1) { - reduce_type = AKG_ALL_REDUCE; - } else if (direction == X_DIRECTION) { - reduce_type = AKG_X_REDUCE; - } else { - reduce_type = AKG_Y_REDUCE; - } - ret += ", "; - ret += reduce_type; - - reduce_info_.akg_reduce_template_arg_ = ret; -} - -void GpuIslEmitter::ConstructAtomicReturnFuncName() { - std::string reduce_lib_namespace = ""; - std::string reduce_return_name = ""; - if (info_.user_config_.GetReduceLibType() == REDUCE_LIB_TYPE_ORIGIN) { - reduce_lib_namespace = AKG_REDUCE_LIB_SPACE; - reduce_return_name = AKG_REDUCE_RETURN_NAME; - } else if (info_.user_config_.GetReduceLibType() == REDUCE_LIB_TYPE_PARIS) { - reduce_lib_namespace = PARIS_REDUCE_LIB_SPACE; - reduce_return_name = PARIS_REDUCE_RETURN_NAME; - } else { - CHECK(false) << "reduce lib type is invalid!" 
- << "\n"; - } - std::string ret = ""; - ret += reduce_lib_namespace; - ret += "::"; - ret += reduce_return_name; - - reduce_info_.akg_atomic_api_ = ret; - ret = ""; - - std::string op = reduce_info_.reduce_op_; - ret += op; - - reduce_info_.akg_atomic_template_arg_ = ret; -} - -bool GpuIslEmitter::NoNeedToEmitForTempTensor(const isl::id &id) { - bool no_need = true; - auto origin_binds = info_.user_config_.GetOriginBind(); - for (auto i : origin_binds) { - if (!i.first.defined()) continue; - std::string name = i.first->op->name; - if (name == id.name()) { - no_need = false; - break; - } - } - return no_need; -} - -Stmt GpuIslEmitter::EmitBlock(const isl::ast_node_block &block_node) { - bool add_to_reduce_area = false; - if (in_reduce_area_ && is_out_most_stmt_) { - add_to_reduce_area = true; - is_out_most_stmt_ = false; - } - - std::vector stmts; - - int num = block_node.get_children().size(); - int last_num = 0; - for (int i = num - 1; i >= 0; --i) { - auto child = block_node.get_children().at(i); - - if (auto node = child.as()) { - CHECK(node.get_expr().isa()); - isl::ast_expr_op usr_expr = node.get_expr().as(); - CHECK(usr_expr); - auto stmt_id = usr_expr.get_arg(0).as().get_id(); - if (info_.IsRealize(stmt_id)) { - isl::id new_stmt_id = isl::id(stmt_id.ctx(), stmt_id.name().substr(REALIZE_PREFIX_LEN)); - int stmt_num = stmts.size(); - CHECK_NE(stmt_num, 0) << "when stmt_num is zero, no realize should be emitted!."; - if (stmt_num == 1) { - stmts[0] = InsertRealize(stmts[0], new_stmt_id); - } else { - if (stmt_num - last_num == 1) { - stmts[0] = InsertRealize(stmts[0], new_stmt_id); - } else { - for (int index = stmt_num - 2 - last_num; index >= 0; --index) { - auto p_index = static_cast(index); - stmts[p_index] = Block::make(stmts[p_index], stmts[p_index + 1]); - } - stmts[0] = InsertRealize(stmts[0], new_stmt_id); - } - } - last_num = stmt_num - 1; - continue; - } - } - - Stmt body = EmitAst(child); - if (!body.defined()) continue; - 
stmts.insert(stmts.begin(), body); - } - - int len = stmts.size(); - - if (len == 0) { - return Stmt(); - } - - if (last_num == len - 1) { - if (add_to_reduce_area) { - reduce_info_.reduce_area_stmt_ = stmts[0]; - return Stmt(); - } - return stmts[0]; - } else { - for (int index = len - 2 - last_num; index >= 0; --index) { - auto p_index = static_cast(index); - stmts[p_index] = Block::make(stmts[p_index], stmts[p_index + 1]); - } - if (add_to_reduce_area) { - reduce_info_.reduce_area_stmt_ = stmts[0]; - return Stmt(); - } - return stmts[0]; - } -} - -Stmt GpuIslEmitter::EmitFor(const isl::ast_node_for &node) { - bool add_to_reduce_area = false; - if (in_reduce_area_ && is_out_most_stmt_) { - add_to_reduce_area = true; - is_out_most_stmt_ = false; - } - isl::id isl_iter_id = node.get_iterator().as().get_id(); - VarExpr iter_expr(isl_iter_id.to_str()); - PushIter(iter_expr.get()); - - Expr init_expr = Interpret(node.get_init()); - - auto isl_cond = node.get_cond().as(); - CHECK(isl_cond.as() || isl_cond.as()); - auto cond_lhs = isl_cond.get_arg(0).as(); - CHECK(cond_lhs); - CHECK_EQ(cond_lhs.get_id(), isl_iter_id); - Expr cond_expr = Interpret(isl_cond.get_arg(1)); - - int64_t inc = static_cast(WrappedStrtol(node.get_inc().to_C_str())); - CHECK_NE(inc, 0) << "stride should not be zero!."; - - bool need_to_modify_inc_ = false; - if (inc != 1) { - need_to_modify_inc_ = true; - Expr original_init_expr = init_expr; - init_expr = ModifyTheInitExpr(init_expr); - cond_expr = ModifyTheCondExpr(cond_expr, static_cast(inc)); - Expr modify_iter = ModifyTheIterExpr(iter_expr, static_cast(inc), original_init_expr); - stride_modify_iter_map_[iter_expr.get()] = modify_iter; - } - - if (isl_cond.as()) { - cond_expr = Simplify(cond_expr + 1); - } - - cond_expr = Simplify(cond_expr - init_expr); - - // add for tensor core - - if (tensor_core_info_.core_area_) { - tensor_core_info_.core_area_for_extent_[iter_expr] = cond_expr; - } - - if (tensor_core_info_.fragment_axis_begin_) { - if 
(tensor_core_info_.is_fragment_m_) { - tensor_core_info_.fragment_m_ = cond_expr; - } else if (tensor_core_info_.is_fragment_n_) { - tensor_core_info_.fragment_n_ = cond_expr; - } - } - - Stmt body_stmt = EmitAst(node.get_body()); - - if (!body_stmt.defined()) { - PopIter(iter_expr.get()); - if (tensor_core_info_.core_area_) { - tensor_core_info_.core_area_for_extent_.erase(iter_expr); - } - return Stmt(); - } - - if (need_to_modify_inc_) { - stride_modify_iter_map_.erase(iter_expr.get()); - } - PopIter(iter_expr.get()); - if (tensor_core_info_.core_area_) { - tensor_core_info_.core_area_for_extent_.erase(iter_expr); - } - Stmt stmt = For::make(iter_expr, init_expr, cond_expr, ForType::Serial, DeviceAPI::None, body_stmt); - if (add_to_reduce_area) { - reduce_info_.reduce_area_stmt_ = stmt; - return Stmt(); - } - return stmt; -} - -Stmt GpuIslEmitter::EmitIf(const isl::ast_node_if &node) { - bool add_to_reduce_area = false; - if (in_reduce_area_ && is_out_most_stmt_) { - add_to_reduce_area = true; - is_out_most_stmt_ = false; - } - - Expr cond_expr = Interpret(node.get_cond()); - cur_if_list_.push_back(cond_expr.get()); - Stmt then_case = EmitAst(node.get_then_node()); - if (!then_case.defined()) { - return Stmt(); - } - Stmt else_case; - if (node.has_else_node()) { - else_case = EmitAst(node.get_else_node()); - } - cur_if_list_.pop_back(); - if (reduce_info_.init_stmt_emit_) { - if (info_.user_config_.GetEnableAtomicAdd() && !info_.analysis_result_.GetAtomicMarkers().empty()) { - bool is_found = false; - cond_expr = ConditionExprMod(is_found).Mutate(cond_expr); - if (is_found) { - reduce_info_.init_stmt_emit_ = false; - } - } - } - - Stmt s; - if (!cond_expr.defined()) { - s = then_case; - } else { - s = IfThenElse::make(cond_expr, then_case, else_case); - } - - if (add_to_reduce_area) { - reduce_info_.reduce_area_stmt_ = s; - return Stmt(); - } - - return s; -} - -Expr GpuIslEmitter::ModifyTheInitExpr(const Expr &e) { return 0; } - -Expr 
GpuIslEmitter::ModifyTheCondExpr(const Expr &e, int inc) { return e / Expr(inc); } - -Expr GpuIslEmitter::ModifyTheIterExpr(const VarExpr &iter, int inc, const Expr &init) { - return Simplify(iter * inc + init); -} - -int GpuIslEmitter::GetThreadExtent(const std::string &name) { - if (name == BLOCK_IDX_X || name == BLOCK_IDX_Y || name == BLOCK_IDX_Z) { - auto block_cfg = info_.user_config_.GetBlockConfig(); - CHECK(block_cfg) << "block config is null."; - return name == BLOCK_IDX_X ? block_cfg->GetX().second - : (name == BLOCK_IDX_Y ? block_cfg->GetY().second : block_cfg->GetZ().second); - } - - if (name == THREAD_IDX_X || name == THREAD_IDX_Y || name == THREAD_IDX_Z) { - auto thread_cfg = info_.user_config_.GetThreadConfig(); - CHECK(thread_cfg) << "thread config is null."; - if (info_.user_config_.GetEnableOneDimThread()) { - return name == THREAD_IDX_X ? (thread_cfg->GetX().second * thread_cfg->GetY().second * thread_cfg->GetZ().second) - : 1; - } - return name == THREAD_IDX_X ? thread_cfg->GetX().second - : (name == THREAD_IDX_Y ? 
thread_cfg->GetY().second : thread_cfg->GetZ().second); - } - LOG(WARNING) << "Unrecognized thread name " << name; - return 1; -} - -void GpuIslEmitter::PrepareDataForTensorCore() { - auto binds = info_.user_config_.GetBind(); - - auto thread_cfg = info_.user_config_.GetThreadConfig(); - CHECK(thread_cfg) << "thread config is null"; - int tx = thread_cfg->GetX().second; - int ty = thread_cfg->GetY().second; - int tz = thread_cfg->GetZ().second; - - if (info_.user_config_.GetEnableOneDimThread()) { - tx = tx * ty * tz; - ty = 1; - tz = 1; - } - - for (auto i : binds) { - if (!i.first.defined()) continue; - if (!i.second.defined()) continue; - auto t = i.first; - auto b = i.second; - - std::string name = t->op->name; - - air::ir::TensorKey key{t->op, t->value_index}; - Region bounds; - if (bounds.empty()) { - for (auto j : t->shape) { - bounds.push_back(Range::make_by_min_extent(Expr(0), j)); - } - } - - tensor_core_info_.bounds_[key] = bounds; - - Array strides; - for (size_t i = 1; i < b->shape.size(); ++i) { - Expr stride = IntImm::make(Int(32), 1); - for (size_t j = b->shape.size() - 1; j >= i; --j) { - stride = Mul::make(stride, b->shape[j]); - } - strides.push_back(stride); - } - strides.push_back(make_const(Int(32), 1)); - tensor_core_info_.strides_[name] = strides; - } - - auto tile_size = info_.analysis_result_.GetTileSizes(); - CHECK_GE(tile_size.size(), 3) << "tile size should be greater to 3"; - int len = tile_size.size(); - tensor_core_info_.warp_tile_.m = tile_size[len - 3].c0_tiling_size; - tensor_core_info_.warp_tile_.n = tile_size[len - 2].c0_tiling_size; - tensor_core_info_.warp_tile_.k = tile_size[len - 1].c0_tiling_size; - - bool result = CheckTileValid(tensor_core_info_.warp_tile_); - CHECK(result) << "tile set is not valid!"; - - tensor_core_info_.thread_tile_.m = tensor_core_info_.warp_tile_.m / tx; - tensor_core_info_.thread_tile_.n = tx / 2; - tensor_core_info_.thread_tile_.k = tile_size[2].c0_tiling_size / tz; - - 
tensor_core_info_.matrix_abc_ = info_.analysis_result_.GetMatrixMatmulMap(); - tensor_core_info_.matrix_major_ = info_.analysis_result_.GetMatrixMatmulMajor(); - - for (auto &i : tensor_core_info_.matrix_abc_) { - tensor_core_info_.frag_reg_.insert(i.first + LOCAL_SUFFIX); - } - - tensor_core_info_.warp_threads_y_ = 32 / tx; - tensor_core_info_.warp_threads_x_ = tx; -} - -bool GpuIslEmitter::CheckTileValid(Tile tile) { - if (tile.m == 16 && tile.n == 16 && tile.k == 4) { - tensor_core_info_.wmma_scope_ = "akg"; - return true; - } - if (tile.m == 16 && tile.n == 16 && tile.k == 16) { - tensor_core_info_.wmma_scope_ = "nvcuda"; - return true; - } - if (tile.m == 8 && tile.n == 32 && tile.k == 16) { - tensor_core_info_.wmma_scope_ = "nvcuda"; - return true; - } - if (tile.m == 32 && tile.n == 8 && tile.k == 16) { - tensor_core_info_.wmma_scope_ = "nvcuda"; - return true; - } - return false; -} - -Stmt GpuIslEmitter::Emit(const isl::ast_node &node) { - Stmt stmt = EmitAst(node); - - // emit realize for temporary tensor - stmt = EmitRealizeForGlobalTensor(stmt); - - // iter var node attr emit - std::map::iterator it; - for (it = iter_name_map_.begin(); it != iter_name_map_.end(); it++) { - IterVar axis = IterVarNode::make(Range(), it->second, air::kThreadIndex, it->second->name_hint); - stmt = AttrStmt::make(axis, air::ir::attr::thread_extent, Expr(GetThreadExtent(it->second->name_hint)), stmt); - } - - // attr for one dimension mapping - if (info_.user_config_.GetEnableOneDimThread()) { - auto thread_cfg = info_.user_config_.GetThreadConfig(); - CHECK(thread_cfg) << "thread config is null."; - int tx = thread_cfg->GetX().second; - stmt = AttrStmt::make(Expr(""), ORIGIN_THREAD_DIM_X, Expr(tx), stmt); - } - - // add tensor core plan two attr - if (info_.user_config_.GetEnableTensorCore()) { - if (info_.user_config_.GetEnableTensorCoreUsePoly()) { - stmt = AttrStmt::make(Expr(""), "pragma_tensor_core", StringImm::make(TENSOR_CORE_MODE_TWO), stmt); - stmt = 
AttrStmt::make(Expr("INFO"), "wmma_scope", StringImm::make(tensor_core_info_.wmma_scope_), stmt); - } else { - stmt = AttrStmt::make(Expr(""), "pragma_tensor_core", StringImm::make(TENSOR_CORE_MODE_ONE), stmt); - } - } - - if (tensor_core_info_.is_tensor_core_ && info_.user_config_.GetEnableTensorCoreUsePoly()) { - stmt = AddMmaAttrFlag(tensor_core_info_).Mutate(stmt); - stmt = EmitForTensorCore(stmt, tensor_core_info_); - } else if (info_.user_config_.GetEnableTensorCore()) { - tensor_core_info_.cast_tensors_ = info_.analysis_result_.GetCastTensors(); - stmt = EmitForTensorCoreDesignOne(stmt, tensor_core_info_); - } - - return stmt; -} - -Stmt GpuIslEmitter::EmitRealizeForGlobalTensor(Stmt stmt) { - auto binds = info_.user_config_.GetBind(); - auto origin_binds = info_.user_config_.GetOriginBind(); - std::unordered_set tensor_name; - - for (auto i : binds) { - if (!i.first.defined()) continue; - tensor_name.insert(i.first->op->name); - } - - for (auto i : binds) { - if (!i.first.defined()) continue; - // input and output tensor, no need to emit realize - if (origin_binds.find(i.first) != origin_binds.end()) { - continue; - } - - // promoted tensor, the realize info already emitted before - std::string name = i.first->op->name; - if (IsEndsWith(name, MEM_TYPE_SHARED) || IsEndsWith(name, MEM_TYPE_LOCAL)) { - continue; - } - - // if the tensor is temporary, but has already promoted, there is no need to emit realize - if (tensor_name.find(name + "_" + MEM_TYPE_SHARED) != tensor_name.end() || - tensor_name.find(name + "_" + MEM_TYPE_LOCAL) != tensor_name.end()) { - continue; - } - - if (reduce_info_.added_tensors_.find(name) != reduce_info_.added_tensors_.end()) { - continue; - } - - // if the tensor is temporary and it is not promoted, it needs to emit realize - stmt = InsertRealize(stmt, isl::id(info_.GetCtx(), name)); - } - return stmt; -} - -Stmt GpuIslEmitter::EmitMark(const isl::ast_node_mark &node) { - bool add_to_reduce_area = false; - if (in_reduce_area_ && 
is_out_most_stmt_) { - add_to_reduce_area = true; - is_out_most_stmt_ = false; - } - - std::string mark = node.get_id().get_name(); - if (mark == MIND_TRICKS_SWIZZLE_MARKER) { - auto stmt = EmitAst(node.get_node()); - stmt = AttrStmt::make(make_zero(Int(32)), MIND_TRICKS_SWIZZLE_PRAGMA, Expr(1), stmt); - return stmt; - } - - if (IsStartsWith(mark, REDUCE_ATOMIC_FLAG)) { - std::vector strs = common::Split(mark, "_"); - CHECK_EQ(strs.size(), REDUCE_ATOMIC_FLAG_SIZE) << "atomic mark format is not right!."; - reduce_info_.reduce_op_.clear(); - if (AkgSupportedReduceOp.count(strs[REDUCE_ATOMIC_FLAG_TYPE_POS])) { - reduce_info_.reduce_op_ = AKG_REDUCE_LIB_SPACE; - reduce_info_.reduce_op_ += "::"; - reduce_info_.reduce_op_ += strs[REDUCE_ATOMIC_FLAG_TYPE_POS]; - } - CHECK(!reduce_info_.reduce_op_.empty()) << "reduce op should not be empty!"; - - if (strs[REDUCE_ATOMIC_FLAG_POS] == REDUCE_ATOMIC_FLAG) { - reduce_info_.is_atomic = true; - } - } - - // add for tensor core - if ((mark == MATRIX_A) || (mark == MATRIX_B) || (mark == MATRIX_C) || (mark == WARP_MARKER)) { - if (!tensor_core_info_.data_is_set_) { - PrepareDataForTensorCore(); - tensor_core_info_.data_is_set_ = true; - } - tensor_core_info_.fragment_axis_begin_ = false; - if (mark == WARP_MARKER) { - mark = MMA_SYNC; - } - if (mark == MATRIX_C) { - mark = MMA_C; - } - - if (!tensor_core_info_.data_is_set_) { - PrepareDataForTensorCore(); - tensor_core_info_.data_is_set_ = true; - } - - tensor_core_info_.is_tensor_core_ = true; - tensor_core_info_.matrix_info_[mark] = true; - tensor_core_info_.core_area_ = true; - - Stmt stmt = EmitAst(node.get_node()); - stmt = DeleteUselessFor().Mutate(stmt); - tensor_core_info_.matrix_info_[mark] = false; - tensor_core_info_.core_area_ = false; - return AttrStmt::make(Expr("INFO"), mark, StringImm::make(mark), stmt); - } - - if ((mark == FRAGMENT_A) || (mark == FRAGMENT_B)) { - tensor_core_info_.fragment_axis_begin_ = true; - if (mark == FRAGMENT_A) { - 
tensor_core_info_.is_fragment_m_ = true; - } else if (mark == FRAGMENT_B) { - tensor_core_info_.is_fragment_n_ = true; - } - Stmt stmt = EmitAst(node.get_node()); - tensor_core_info_.fragment_axis_begin_ = false; - tensor_core_info_.is_fragment_m_ = false; - tensor_core_info_.is_fragment_n_ = false; - if (!stmt.defined()) { - return Stmt(); - } - return AttrStmt::make(Expr("INFO"), mark, StringImm::make(mark), stmt); - } - - // add for prefetch pass - if (mark == PROMOTE_GLOBAL_TO_SHARED_AB) { - Stmt stmt = EmitAst(node.get_node()); - if (!stmt.defined()) { - return Stmt(); - } - return AttrStmt::make(Expr("INFO"), SHARED_MEM_PROMOTED_COMPLETE, StringImm::make(SHARED_MEM_PROMOTED_COMPLETE), - stmt); - } - - Stmt stmt; - - if ((mark == PROMOTE_VECTORIZATION) || (mark == PROMOTE_LOCAL_TO_GLOBAL)) { - stmt = EmitAst(node.get_node()); - if (!stmt.defined()) { - return Stmt(); - } - stmt = AttrStmt::make(Expr("INFO"), mark, StringImm::make(mark), stmt); - } else { - stmt = EmitAst(node.get_node()); - } - - if (add_to_reduce_area) { - reduce_info_.reduce_area_stmt_ = stmt; - return Stmt(); - } - return stmt; -} - -std::string GpuIslEmitter::FindRealizeScopeToString(const isl::id &var) { - if (info_.analysis_result_.CountBufferDefInfo(var)) { - auto tensor_info = info_.analysis_result_.GetBufferDefInfo(var); - MemType mem_type = tensor_info.DstMemType(); - - switch (mem_type) { - case MemType::SHARED_: - return MEM_TYPE_SHARED; - case MemType::LOCAL_: - return MEM_TYPE_LOCAL; - default: - LOG(FATAL) << "unexpected mem_type of var " << var; - return "ERROR"; - } - } - return ""; -} - -Expr GpuIslEmitter::FindRealizeScope(const isl::id &var) { return Expr(FindRealizeScopeToString(var)); } - -Stmt GpuIslEmitter::InsertRealize(Stmt stmt, const isl::id &var) { - stmt = FindInnerRealize(var.get_name()).Mutate(stmt); - - // A tensor may be defined multiple times in BufferDefInfo due to nested realize. 
- // Because we cannot determine which one we actually want, we have to be conservative here - // and allocate space for the largest shape to avoid overflow. - Tensor t = info_.FindTensorWithLargestShape(var); - Region bounds; - - // no isolate - if (bounds.empty()) { - for (auto j : t->shape) { - bounds.push_back(Range::make_by_min_extent(Expr(0), j)); - } - } - - // If isolate, make a new buffer - auto buf = info_.user_config_.GetBind().at(t); - - auto tt = placeholder(t->shape, t->dtype, t->op->name); - - stmt = TensorSubstitute(stmt, t->op, tt->op, tt->value_index); - if (tensor_core_info_.is_tensor_core_) { - stmt = TensorSubstituteTensorCore(t->op, tt->op, tt->value_index).Mutate(stmt); - } - t = tt; - if (info_.analysis_result_.CountBufferDefInfo(var)) { - auto decl = info_.analysis_result_.GetBufferDefInfo(var); - decl.tensor = t; - } - info_.user_config_.SetBind(t, buf); - stmt = TensorSubstitute2(stmt, t->op->func_name(), t->op, t->value_index); - stmt = Realize::make(t->op, t->value_index, t->dtype, bounds, const_true(1), stmt); - realized_.insert(t); - stmt = AttrStmt::make(t->op, air::ir::attr::realize_scope, FindRealizeScope(var), stmt); - - return stmt; -} - -Stmt GpuIslEmitter::InsertRealizeWithMemType(Stmt stmt, const isl::id &var, std::string mem) { - stmt = FindInnerRealize(var.get_name()).Mutate(stmt); - - Tensor t = info_.FindTensorWithLargestShape(var); - Region bounds; - - // no isolate - if (bounds.empty()) { - for (auto j : t->shape) { - bounds.push_back(Range::make_by_min_extent(Expr(0), j)); - } - } - - // If isolate, make a new buffer - auto buf = info_.user_config_.GetBind().at(t); - - auto tt = placeholder(t->shape, t->dtype, t->op->name); - - stmt = TensorSubstitute(stmt, t->op, tt->op, tt->value_index); - t = tt; - if (info_.analysis_result_.CountBufferDefInfo(var)) { - auto decl = info_.analysis_result_.GetBufferDefInfo(var); - decl.tensor = t; - } - info_.user_config_.SetBind(t, buf); - stmt = TensorSubstitute2(stmt, 
t->op->func_name(), t->op, t->value_index); - stmt = Realize::make(t->op, t->value_index, t->dtype, bounds, const_true(1), stmt); - realized_.insert(t); - stmt = AttrStmt::make(t->op, air::ir::attr::realize_scope, Expr(mem), stmt); - - return stmt; -} - -Expr GpuIslEmitter::IterNameAdaptor(std::string name) { - if (iter_name_map_.find(name) != iter_name_map_.end()) { - return iter_name_map_[name]; - } else if (name.find(REPLACE) != std::string::npos) { - name = name.substr(strlen(REPLACE)); - if (info_.user_config_.GetEnableTileC0()) { - return SingleConfigToMultiBand(name); - } - return AdaptPolyNewVar(name); - } else { - return VarExpr(name); - } -} - -Expr GpuIslEmitter::SingleConfigToMultiBand(std::string name) { - Expr e; - VarExpr original_id; - int rep_size = 1; - auto l0_block_size = info_.user_config_.GetC0BlockSize(); - if (name.find(B0) != std::string::npos) { - original_id = iter_name_map_[B0]; - rep_size = l0_block_size[0]; - } else if (name.find(B1) != std::string::npos) { - original_id = iter_name_map_[B1]; - rep_size = l0_block_size[1]; - } else { - original_id = iter_name_map_[B2]; - rep_size = l0_block_size[2]; - } - - if (rep_size < 0) { - return e; - } - - if (name.find(TILE_WITH_C0) != std::string::npos) { - e = Mod::make(original_id, rep_size); - } else if (name.find(TILE_WITH_C1) != std::string::npos) { - e = Div::make(original_id, rep_size); - } else { - LOG(FATAL) << "Unexpected binding id: " << name; - } - return e; -} - -// if new var is added in poly process, modify the logic here. 
-// another modify pos is IterNameAdaptor interface -Expr GpuIslEmitter::AdaptPolyNewVar(std::string name) { - Expr e; - std::string t0_string = T0; - int suffix_len = t0_string.size() + 1; - auto tensor_name = name.substr(0, name.size() - suffix_len); - if (!info_.user_config_.GetReplaceConfig().count(tensor_name)) { - return e; - } - auto mapping_cfg = (info_.user_config_.GetReplaceConfig()[tensor_name]); - CHECK(mapping_cfg) << "mapping config is null."; - int mx = mapping_cfg->GetX().second; - int my = mapping_cfg->GetY().second; - int mz = mapping_cfg->GetZ().second; - if (name.find(WARP_COMPUTE) != std::string::npos) { - if (name.find(T0) != std::string::npos) { - e = Div::make(iter_name_map_[T0], WARP_SIZE); - e = Mod::make(e, mx); - return e; - } else if (name.find(T1) != std::string::npos) { - e = Div::make(iter_name_map_[T0], WARP_SIZE); - e = Div::make(e, mx); - return e; - } - } else { - if (name.find(T0) != std::string::npos) { - e = Mod::make(iter_name_map_[T0], mx); - return e; - } else if (name.find(T1) != std::string::npos) { - e = Div::make(iter_name_map_[T0], mx); - if (mz == 1) { - return e; - } - e = Mod::make(e, my); - return e; - } else if (name.find(T2) != std::string::npos) { - e = Div::make(iter_name_map_[T0], mx); - e = Div::make(e, my); - return e; - } - } - return e; -} - -Expr GpuIslEmitter::Interpret(const isl::ast_expr &e) { - if (auto int_expr = e.as()) { - return Expr(IslExprToSInt(int_expr)); - } else if (auto id_expr = e.as()) { - // If this variable is defined by loop index, we need sharing it. 
- const Variable *var = GetIterByName(id_expr.get_id().get_name()); - if (var) { - if (stride_modify_iter_map_.find(var) != stride_modify_iter_map_.end()) { - return stride_modify_iter_map_[var]; - } - return VarExpr(GetObjPtr(var)); - } else { - return IterNameAdaptor(id_expr.get_id().to_str()); - } - } else if (auto op_expr = e.as()) { - return InterpretOp(op_expr); - } else { - LOG(FATAL) << "NYI " << e; - return 0; - } -} - -Stmt GpuIslEmitter::EmitAccessNodeFromPromoteAcsCall(isl::id var, const Node *node, Array &args) { - const Call *call = static_cast(node); - Tensor t = info_.FindTensor(var); - return Evaluate::make(Call::make(call->type, var.get_name(), args, call->call_type, t->op, t->value_index)); -} - -Stmt GpuIslEmitter::EmitAccessNodeFromPromoteAcsProvide(isl::id var, const Node *node, Array &args) { - const auto provide = static_cast(node); - Tensor t = info_.FindTensor(var); - Stmt s = Provide::make(t->op, 0, provide->value, args); - return s; -} - -void GetNameWithoutShared(isl::id &tensor_id, ScopInfo &info) { - if (info.user_config_.GetEnableMatmul()) { - size_t pos = tensor_id.get_name().find(SHARE_SUFFIX); - std::string substr = tensor_id.get_name().substr(0, pos); - if (pos != 0) tensor_id = isl::id(tensor_id.ctx(), substr); - } -} - -isl::multi_aff GpuIslEmitter::TensorAccessMultAff(isl::id &tensor_id, const Array &tensor_index, - const isl::id &node_id) { - GetNameWithoutShared(tensor_id, info_); - return IslEmitter::TensorAccessMultAff(tensor_id, tensor_index, node_id); -} - -Array EmitTensorCoreHelper::GetTileSize(const std::string &name) { - auto it = tensor_core_info_.matrix_abc_.find(name); - auto it2 = tensor_core_info_.matrix_major_.find(name); - CHECK(it != tensor_core_info_.matrix_abc_.end() && it2 != tensor_core_info_.matrix_major_.end()) - << "Cannot find matrix info for " << name; - Expr size0 = make_const(Int(32), 16); - Expr size1 = make_const(Int(32), 16); - if (it->second == MMA_A && it2->second == COL_MAJOR) { - size0 = 
make_const(Int(32), tensor_core_info_.warp_tile_.k); - size1 = make_const(Int(32), tensor_core_info_.warp_tile_.m); - } - if (it->second == MMA_A && it2->second == ROW_MAJOR) { - size0 = make_const(Int(32), tensor_core_info_.warp_tile_.m); - size1 = make_const(Int(32), tensor_core_info_.warp_tile_.k); - } - if (it->second == MMA_B && it2->second == ROW_MAJOR) { - size0 = make_const(Int(32), tensor_core_info_.warp_tile_.k); - size1 = make_const(Int(32), tensor_core_info_.warp_tile_.n); - } - if (it->second == MMA_B && it2->second == COL_MAJOR) { - size0 = make_const(Int(32), tensor_core_info_.warp_tile_.n); - size1 = make_const(Int(32), tensor_core_info_.warp_tile_.k); - } - - if (it->second == MATRIX_C) { - size0 = make_const(Int(32), tensor_core_info_.warp_tile_.m); - size1 = make_const(Int(32), tensor_core_info_.warp_tile_.n); - } - Array tile_size = {size0, size1}; - return tile_size; -} - -void EmitTensorCoreHelper::SetDataForLoad(Expr src, Expr stride, Expr major, const Call *call, const Provide *op, - NodePtr &node) { - data_for_load_.src = src; - data_for_load_.stride = stride; - data_for_load_.major = major; - data_for_load_.call = call; - data_for_load_.op = op; - data_for_load_.node = node; -} -void EmitTensorCoreHelper::SetDataForStore(Expr dst, Expr stride, const Call *call, NodePtr &node) { - data_for_store_.dst = dst; - data_for_store_.stride = stride; - data_for_store_.call = call; - data_for_store_.node = node; -} -void EmitTensorCoreHelper::SetDataForFill(const Provide *op, const Call *call, NodePtr &node) { - data_for_fill_.call = call; - data_for_fill_.op = op; - data_for_fill_.node = node; -} -void EmitTensorCoreHelper::SetDataForSync(Expr a, Expr b, Expr c, NodePtr &node_a, - NodePtr &node_b, NodePtr &node_c) { - data_for_sync_.a = a; - data_for_sync_.b = b; - data_for_sync_.c = c; - data_for_sync_.node_a = node_a; - data_for_sync_.node_b = node_b; - data_for_sync_.node_c = node_c; -} - -void EmitTensorCoreHelper::PrepareDataCore() { - auto it 
= tensor_core_info_.bounds_.find(key_); - CHECK(it != tensor_core_info_.bounds_.end()); - Array min_bound; - for (auto i : it->second) { - min_bound.push_back(i->min); - } - - CHECK_GE(it->second.size(), 2); - Array shape; - for (size_t i = 0; i < it->second.size() - 2; ++i) { - shape.push_back(it->second[i]->extent); - } - auto tile_size = GetTileSize(SimplifyName(call_->name)); - shape.push_back(tile_size[0]); - shape.push_back(tile_size[1]); - - tensor_core_info_.min_bounds_[call_->name] = min_bound; - - Array strides; - for (size_t i = 1; i < shape.size(); ++i) { - Expr stride = IntImm::make(Int(32), 1); - for (size_t j = shape.size() - 1; j >= i; --j) { - stride = Mul::make(stride, shape[j]); - } - strides.push_back(stride); - } - strides.push_back(make_const(Int(32), 1)); - - Expr elem_offset = IntImm::make(Int(32), 0); - CHECK_EQ(call_->args.size(), min_bound.size()); - for (size_t i = 0; i < min_bound.size(); i++) { - auto arg = call_->args[i]; - arg = DeleteThreadIdx().Mutate(arg); - arg = Simplify(arg); - elem_offset = Add::make(elem_offset, Mul::make(strides[i], Sub::make(arg, min_bound[i]))); - } - - auto it2 = tensor_core_info_.matrix_abc_.find(SimplifyName(call_->name)); - CHECK(it2 != tensor_core_info_.matrix_abc_.end()) << "Cannot find matrix info for " << call_->name; - buffer_node_->data = Variable::make(Handle(), call_->name); - buffer_node_->name = call_->name; - std::string name = it2->second; - if (name == MATRIX_C) { - name = MMA_C; - } - buffer_node_->scope = "wmma." 
+ name; - buffer_node_->dtype = data_type_; - buffer_node_->strides = strides; - buffer_node_->shape = shape; - buffer_node_->data_alignment = 1; - buffer_node_->elem_offset = Simplify(elem_offset); - buffer_node_->offset_factor = 1; - Buffer buffer(buffer_node_); - - NodePtr tensor_node = make_node(); - tensor_node->value_index = key_.value_index; - tensor_node->op = Downcast(key_.f); - tensor_node->shape = shape; - tensor_node->dtype = data_type_; - Tensor tensor(tensor_node); - - Array args; - for (size_t i = 0; i < call_->args.size(); ++i) { - auto arg = call_->args[i]; - arg = DeleteThreadIdx().Mutate(arg); - arg = Simplify(arg); - - args.push_back(arg); - args.push_back(shape[i]); - } - tuple_ = Call::make(Handle(), air::ir::intrinsic::tvm_tuple, args, Call::Intrinsic); - node_ = {buffer, tensor}; -} - -Stmt EmitTensorCoreHelper::MakeLoadTransform() { - key_ = air::ir::TensorKey{data_for_load_.op->func, data_for_load_.op->value_index}; - call_ = data_for_load_.call; - buffer_node_ = data_for_load_.node; - data_type_ = call_->type; - - PrepareDataCore(); - Buffer buffer = Downcast(node_[0]); - Stmt stmt = Evaluate::make(Call::make( - Handle(), air::ir::intrinsic::tvm_load_matrix_sync, - {buffer->data, tensor_core_info_.warp_tile_.m, tensor_core_info_.warp_tile_.n, tensor_core_info_.warp_tile_.k, - Simplify(buffer->elem_offset), data_for_load_.src, data_for_load_.stride, data_for_load_.major}, - Call::Intrinsic)); - return AttrStmt::make(node_, "buffer_bind_scope", tuple_, stmt); -} - -Stmt EmitTensorCoreHelper::MakeStoreTransform() { - key_ = air::ir::TensorKey{data_for_store_.call->func, data_for_store_.call->value_index}; - call_ = data_for_store_.call; - buffer_node_ = data_for_store_.node; - data_type_ = call_->type; - - PrepareDataCore(); - Buffer buffer = Downcast(node_[0]); - Stmt stmt = Evaluate::make(Call::make( - Handle(), air::ir::intrinsic::tvm_store_matrix_sync, - {buffer->data, tensor_core_info_.warp_tile_.m, tensor_core_info_.warp_tile_.n, 
tensor_core_info_.warp_tile_.k, - buffer->elem_offset, data_for_store_.dst, data_for_store_.stride, StringImm::make(ROW_MAJOR)}, - Call::Intrinsic)); - return AttrStmt::make(node_, "buffer_bind_scope", tuple_, stmt); -} - -Stmt EmitTensorCoreHelper::MakeFillTransform() { - key_ = air::ir::TensorKey{data_for_fill_.call->func, data_for_fill_.call->value_index}; - call_ = data_for_fill_.call; - buffer_node_ = data_for_fill_.node; - data_type_ = call_->type; - - PrepareDataCore(); - Buffer buffer = Downcast(node_[0]); - Stmt stmt = Evaluate::make(Call::make(Handle(), air::ir::intrinsic::tvm_fill_fragment, - {buffer->data, tensor_core_info_.warp_tile_.m, tensor_core_info_.warp_tile_.n, - tensor_core_info_.warp_tile_.k, buffer->elem_offset, data_for_fill_.op->value}, - Call::Intrinsic)); - return AttrStmt::make(node_, "buffer_bind_scope", tuple_, stmt); -} - -Stmt EmitTensorCoreHelper::MakeSyncTransform() { - bool is_cast = false; - if (data_for_sync_.a.as()) { - auto call_a = data_for_sync_.a.as(); - key_ = air::ir::TensorKey{call_a->func, call_a->value_index}; - call_ = call_a; - buffer_node_ = data_for_sync_.node_a; - data_type_ = call_->type; - is_cast = true; - } else if (data_for_sync_.a.as()) { - auto cast_a = data_for_sync_.a.as(); - auto call_a = cast_a->value.as(); - CHECK(call_a); - key_ = air::ir::TensorKey{call_a->func, call_a->value_index}; - call_ = call_a; - buffer_node_ = data_for_sync_.node_a; - data_type_ = call_->type; - is_cast = true; - } - - PrepareDataCore(); - - auto tuple_a = tuple_; - auto node_a = node_; - - if (data_for_sync_.b.as()) { - auto call_b = data_for_sync_.b.as(); - key_ = air::ir::TensorKey{call_b->func, call_b->value_index}; - call_ = call_b; - buffer_node_ = data_for_sync_.node_b; - data_type_ = call_->type; - is_cast = true; - } else if (data_for_sync_.b.as()) { - auto cast_b = data_for_sync_.b.as(); - auto call_b = cast_b->value.as(); - CHECK(call_b); - key_ = air::ir::TensorKey{call_b->func, call_b->value_index}; - call_ = 
call_b; - buffer_node_ = data_for_sync_.node_b; - data_type_ = call_->type; - is_cast = true; - } - - PrepareDataCore(); - - auto tuple_b = tuple_; - auto node_b = node_; - - auto call_c = data_for_sync_.c.as(); - CHECK(call_c); - key_ = air::ir::TensorKey{call_c->func, call_c->value_index}; - call_ = call_c; - buffer_node_ = data_for_sync_.node_c; - data_type_ = call_->type; - - PrepareDataCore(); - - auto tuple_c = tuple_; - auto node_c = node_; - - Buffer buffer_a(data_for_sync_.node_a); - Buffer buffer_b(data_for_sync_.node_b); - Buffer buffer = Downcast(node_c[0]); - - Stmt stmt = Evaluate::make(Call::make(Handle(), air::ir::intrinsic::tvm_mma_sync, - {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, - buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset}, - Call::Intrinsic)); - - stmt = AttrStmt::make(node_c, "buffer_bind_scope", tuple_c, stmt); - stmt = AttrStmt::make(node_b, "buffer_bind_scope", tuple_b, stmt); - stmt = AttrStmt::make(node_a, "buffer_bind_scope", tuple_a, stmt); - - std::string cast_mode = CAST_MODE_1; - if (is_cast) { - stmt = AttrStmt::make(Expr("INFO"), CAST_FLAG, StringImm::make(cast_mode), stmt); - } - - return stmt; -} - -} // namespace poly -} // namespace ir -} // namespace akg diff --git a/src/poly/gpu_isl_emitter.h b/src/poly/gpu_isl_emitter.h deleted file mode 100644 index f6b2b44b45f531edc0b4efc38603c805e6399607..0000000000000000000000000000000000000000 --- a/src/poly/gpu_isl_emitter.h +++ /dev/null @@ -1,592 +0,0 @@ -/** - * Copyright 2020-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef POLY_GPU_ISL_EMITTER_H_ -#define POLY_GPU_ISL_EMITTER_H_ - -#include "isl_emitter.h" - -namespace akg { -namespace ir { -namespace poly { -#define TENSOR_CORE_DEV true -/*! - * IslEmitter for GPU - */ -constexpr auto AKG_ALL_REDUCE = "akg_reduce::ALL_REDUCE"; -constexpr auto AKG_X_REDUCE = "akg_reduce::REDUCE2D_X"; -constexpr auto AKG_Y_REDUCE = "akg_reduce::REDUCE2D_Y"; - -constexpr auto MIND_TRICKS_SWIZZLE_MARKER = "mind_trick_swizzle_marker"; -constexpr auto MIND_TRICKS_SWIZZLE_PRAGMA = "pragma_swizzle"; - -// example: -// red_init_SumOp_S_1_0 -constexpr auto REDUCE_FLAG_SIZE = 6; -constexpr auto REDUCE_FLAG_TYPE_POS = 2; -constexpr auto REDUCE_FLAG_STMT_PREFIX_POS = 3; -constexpr auto REDUCE_FLAG_STMT_NUM_POS = 4; -constexpr auto REDUCE_FLAG_REDUCE_INDEX = 5; - -// example: -// atomic_SumOp -constexpr auto REDUCE_ATOMIC_FLAG_SIZE = 2; -constexpr auto REDUCE_ATOMIC_FLAG = "atomic"; -constexpr auto REDUCE_ATOMIC_FLAG_POS = 0; -constexpr auto REDUCE_ATOMIC_FLAG_TYPE_POS = 1; - -constexpr auto DEFAULT_TENSOR_INDEX = "[0]"; - -constexpr auto USELESS_INDEX = "0"; -constexpr auto USELESS_SHAPE_SIZE = "1"; -constexpr auto SCALAR_TENSOR_PREFIX = "acc_"; -constexpr auto SCALAR_KHT_PREFIX = "kahan_t"; -constexpr auto SCALAR_KHY_PREFIX = "kahan_y"; -constexpr auto SCALAR_KHC_PREFIX = "kahan_c"; -constexpr auto SHARED_MEMORY_PREFIX = "__shared__"; -constexpr auto SHARED_TENSOR_PREFIX = "red_buf"; - -constexpr auto REDUCE_LIB_TYPE_ORIGIN = "origin"; -constexpr auto REDUCE_LIB_TYPE_PARIS = "paris"; -constexpr auto AKG_REDUCE_LIB_SPACE = 
"akg_reduce"; -constexpr auto AKG_REDUCE_LIB_NAME = "AkgReduce"; -constexpr auto AKG_KAHAN_LIB_NAME = "AkgKahanAccumulation"; -constexpr auto PARIS_REDUCE_LIB_SPACE = "paris_reduce"; -constexpr auto PARIS_REDUCE_LIB_NAME = "ParisReduce"; -constexpr auto AKG_REDUCE_RETURN_NAME = "AkgAtomicReturn"; -constexpr auto PARIS_REDUCE_RETURN_NAME = "ParisReturn"; -constexpr auto REDUCE_LIB_TYPE_FLAG = "reduceLibType"; - -constexpr auto MEM_TYPE_SHARED = "shared"; -constexpr auto MEM_TYPE_LOCAL = "local"; - -// add for one dimension mapping -constexpr auto ORIGIN_THREAD_DIM_X = "bind_thread_x"; - -// add for tensor core -constexpr auto MMA_A = "matrix_a"; -constexpr auto MMA_B = "matrix_b"; -constexpr auto MMA_C = "accumulator"; -constexpr auto MMA_SYNC = "matrix_sync"; -constexpr auto MMA_PREFIX = "matrix_"; -constexpr auto MMA_FILL_STMT_SERIAL = 2; -constexpr auto MMA_SYNC_STMT_SERIAL = 1; -constexpr auto ENABLE_SCHEME_TWO = "EnableSchemeTwo"; -constexpr auto CAST_FLAG = "CAST"; -constexpr auto CAST_MODE_1 = "mode1"; -constexpr auto GMREAD_FLAG = "GMRead"; -constexpr auto SHARED_MEM_PROMOTED_COMPLETE = "shared_mem_promoted_complete"; -constexpr auto FRAGMENT_A = "fragment_a"; -constexpr auto FRAGMENT_B = "fragment_b"; -constexpr auto FRAGMENT_C = "fragment_c"; - -std::string SimplifyName(std::string input); -constexpr auto FOR_INFO_COLLECT_DEPTH = 3; -constexpr auto LOCAL_INDEX_POS = 4; -constexpr auto TENSOR_CORE_MODE_ONE = "1"; -constexpr auto TENSOR_CORE_MODE_TWO = "2"; -constexpr auto WARP_MARKER = "warp_marker"; - -class ReduceEmitInfo { - public: - std::string akg_reduce_api_; - std::string akg_reduce_template_arg_; - std::string output_promoted_tensor_name_for_atomic_; - std::string akg_atomic_api_; - std::string akg_atomic_template_arg_; - std::set atomic_tensors_; - - std::string promoted_tensor_name_for_reduce_; - std::map reduce_stmt_; - - std::string shared_compute_name_; - std::string scalar_tensor_name_; - std::string scalar_kht_name_; - std::string 
scalar_khy_name_; - std::string scalar_khc_name_; - Expr input_tensor_expr_; - - std::string reduce_op_; - std::string reduce_stmt_index_; - bool is_atomic{false}; - Type output_tensor_data_type_info_; - Type reduce_data_type_info_; - - std::set added_tensors_; - Stmt reduce_area_stmt_; - Stmt origin_reduce_stmt_; - std::map scalar_tensor_; - Tensor shared_tensor_; - std::vector stmts_; - Expr atomic_rhs_; - Stmt gm_write_stmt_; - - bool init_stmt_emit_{false}; -}; - -struct Tile { - int m{-1}; - int n{-1}; - int k{-1}; -}; - -class TensorCoreInfo { - public: - bool in_matrix_a_{false}; - bool in_matrix_b_{false}; - bool in_matrix_c_{false}; - bool in_matrix_sync_{false}; - - std::map matrix_info_{{MMA_A, false}, {MMA_B, false}, {MMA_C, false}, {MMA_SYNC, false}}; - bool core_area_{false}; - bool fragment_axis_begin_{false}; - bool is_fragment_m_{false}; - bool is_fragment_n_{false}; - Expr fragment_m_; - Expr fragment_n_; - int warp_threads_y_{-1}; - int warp_threads_x_{-1}; - Tile warp_tile_; - Tile thread_tile_; - - std::unordered_map matrix_major_; - std::unordered_map matrix_abc_; - std::unordered_map bounds_; - std::unordered_map> strides_; - bool data_is_set_{false}; - std::set frag_reg_; - bool is_tensor_core_{false}; - bool for_mod_pos_found_{false}; - std::unordered_set cast_tensors_; - std::unordered_map core_area_for_extent_; - std::unordered_map> min_bounds_; - - std::string wmma_scope_; -}; - -class GpuIslEmitter : public IslEmitter { - public: - GpuIslEmitter(ScopInfo &info, const NodeInfoRepo &n, const isl::id_list &i) : IslEmitter(info, n, i) {} - ~GpuIslEmitter() override = default; - - Stmt Emit(const isl::ast_node &node) final; - Expr Interpret(const isl::ast_expr &e); - - private: - // override emitters for GPU - Stmt EmitBlock(const isl::ast_node_block &node) final; - Stmt EmitStmt(const isl::ast_node_user &node) final; - Stmt EmitFor(const isl::ast_node_for &node) final; - Stmt EmitMark(const isl::ast_node_mark &node_id) override; - Stmt 
EmitIf(const isl::ast_node_if &node) final; - Stmt EmitUserStmt(const isl::ast_node_user &node) final; - - // DMA emitters for GPU - Expr EmitLoad(const isl::ast_expr &lhs, Type type); - Stmt EmitRead(const isl::ast_node_user &node); - Stmt EmitWrite(const isl::ast_node_user &node); - Stmt EmitWriteAtomic(const isl::ast_node_user &node); - - Stmt EmitAccessNodeFromPromoteAcsCall(isl::id var, const Node *node, Array &args); - Stmt EmitAccessNodeFromPromoteAcsProvide(isl::id var, const Node *node, Array &args); - isl::multi_aff TensorAccessMultAff(isl::id &tensor_id, const Array &subscripts, const isl::id &stmt_id); - - Stmt EmitSync(); - Stmt EmitReduceInit(const isl::ast_node_user &node); - Stmt EmitReduceUpdate(const isl::ast_node_user &node); - Stmt EmitReduceArea(const isl::ast_node_user &node); - Stmt EmitAttr(); // thread_extent, virtual_thread - - // add for tensor core - Stmt EmitUserStmtCore(const isl::ast_node_user &node); - Stmt EmitUserStmtCoreSync(const isl::ast_node_user &node); - Stmt EmitReadCore(const isl::ast_node_user &node); - Stmt EmitWriteCore(const isl::ast_node_user &node); - - Type GetTypeOfTensor(std::string name); - Expr MakeLeftCallFromProvide(const Provide *op); - void PrepareDataForTensorCore(); - bool CheckTileValid(Tile tile); - - Expr FindRealizeScope(const isl::id &var); - std::string FindRealizeScopeToString(const isl::id &var); - Stmt InsertRealize(Stmt stmt, const isl::id &var); - Stmt InsertRealizeWithMemType(Stmt stmt, const isl::id &var, std::string mem); - - Expr IterNameAdaptor(std::string name); - Expr SingleConfigToMultiBand(std::string name); - - Expr AdaptPolyNewVar(std::string name); - int GetThreadExtent(const std::string &name); - - Expr ModifyTheInitExpr(const Expr &e); - Expr ModifyTheCondExpr(const Expr &e, int inc); - Expr ModifyTheIterExpr(const VarExpr &iter, int inc, const Expr &init); - - Stmt EmitRealizeForGlobalTensor(Stmt stmt); - - bool NoNeedToEmitForTempTensor(const isl::id &id); - - void 
MakeAkgReduceFuncName(); - void ConstructAtomicReturnFuncName(); - void MakeReduceStmt(); - Stmt TransferToKaHanInterface(); - Stmt MakeAtomicStmt(); - - void SetScalarTensorBind(std::string scalar_tensor_name); - void SetSharedTensorBind(); - void ResetStatus(); - - std::set realized_; - - std::unordered_map stride_modify_iter_map_; - std::map iter_name_map_{{B0, VarExpr(BLOCK_IDX_X)}, {B1, VarExpr(BLOCK_IDX_Y)}, - {B2, VarExpr(BLOCK_IDX_Z)}, {T0, VarExpr(THREAD_IDX_X)}, - {T1, VarExpr(THREAD_IDX_Y)}, {T2, VarExpr(THREAD_IDX_Z)}}; - - bool in_reduce_area_{false}; - bool update_stmt_out_{false}; - bool init_stmt_out_{false}; - bool is_out_most_stmt_{true}; - ReduceEmitInfo reduce_info_; - TensorCoreInfo tensor_core_info_; - bool is_sync_before_{false}; -}; - -struct DataForLoad { - Expr src; - Expr stride; - Expr major; - const Call *call; - const Provide *op; - NodePtr node; -}; - -struct DataForStore { - Expr dst; - Expr stride; - const Call *call; - NodePtr node; -}; - -struct DataForFill { - const Call *call; - const Provide *op; - NodePtr node; -}; - -struct DataForSync { - Expr a; - Expr b; - Expr c; - NodePtr node_a; - NodePtr node_b; - NodePtr node_c; -}; - -class DeleteThreadIdx : public air::ir::IRMutator { - public: - explicit DeleteThreadIdx() {} - ~DeleteThreadIdx() override = default; - Expr Mutate_(const Variable *op, const Expr &e) { - if (op->name_hint == THREAD_IDX_X) { - return make_const(Int(32), 0); - } - - return e; - } -}; - -class EmitTensorCoreHelper { - public: - EmitTensorCoreHelper(TensorCoreInfo &info) : tensor_core_info_(info) {} - ~EmitTensorCoreHelper(){}; - - void SetDataForLoad(Expr src, Expr stride, Expr major, const Call *call, const Provide *op, - NodePtr &node); - void SetDataForStore(Expr dst, Expr stride, const Call *call, NodePtr &node); - void SetDataForFill(const Provide *op, const Call *call, NodePtr &node); - void SetDataForSync(Expr a, Expr b, Expr c, NodePtr &node_a, NodePtr &node_b, - NodePtr &node_c); - - void 
PrepareDataCore(); - - Stmt MakeLoadTransform(); - Stmt MakeStoreTransform(); - Stmt MakeFillTransform(); - Stmt MakeSyncTransform(); - - Array GetTileSize(const std::string &name); - - private: - Array node_; - Expr tuple_; - TensorCoreInfo &tensor_core_info_; - - DataForLoad data_for_load_; - DataForStore data_for_store_; - DataForFill data_for_fill_; - DataForSync data_for_sync_; - - air::ir::TensorKey key_; - const Call *call_; - NodePtr buffer_node_; - Type data_type_; -}; - -class AddMmaAttrFlag : public air::ir::IRMutator { - public: - explicit AddMmaAttrFlag(TensorCoreInfo t) : tt(t) {} - ~AddMmaAttrFlag() override = default; - - Stmt Mutate_(const AttrStmt *op, const Stmt &s) override { - Stmt stmt = IRMutator::Mutate_(op, s); - if (op->attr_key == air::ir::attr::realize_scope) { - auto node = op->node.as(); - if (node != nullptr) { - if (!tt.frag_reg_.count(node->name)) { - return stmt; - } - - auto it = tt.matrix_abc_.find(SimplifyName(node->name)); - CHECK(it != tt.matrix_abc_.end()) << "Cannot find matrix info for " << node->name; - std::string name = it->second; - if (name == MATRIX_C) { - name = MMA_C; - } - - auto matrix_abc = "wmma." 
+ name; - Stmt body = Mutate(op->body); - return AttrStmt::make(op->node, op->attr_key, matrix_abc, body); - } - } - return stmt; - } - - private: - TensorCoreInfo tt; -}; - -class TensorSubstituteTensorCore : public air::ir::IRMutator { - public: - explicit TensorSubstituteTensorCore(const FunctionRef &a, const FunctionRef &b, int b_value_index) - : a_(a), b_(b), b_value_index_(b_value_index) {} - ~TensorSubstituteTensorCore() override = default; - - Stmt Mutate_(const AttrStmt *op, const Stmt &s) override { - if (op->attr_key == air::ir::attr::buffer_bind_scope) { - Array arr = Downcast>(op->node); - CHECK_EQ(arr.size(), 2U); - const BufferNode *buffer = arr[0].as(); - const TensorNode *tensor = arr[1].as(); - CHECK(buffer && tensor); - if (tensor->op == a_) { - Tensor new_tensor = TensorNode::make(tensor->shape, tensor->dtype, Downcast(b_), b_value_index_); - Array node = {arr[0], new_tensor}; - return AttrStmt::make(node, op->attr_key, op->value, op->body); - } - } - return IRMutator::Mutate_(op, s); - } - - private: - FunctionRef a_, b_; - int b_value_index_{0}; -}; - -class DeleteUselessFor : public air::ir::IRMutator { - public: - explicit DeleteUselessFor() {} - ~DeleteUselessFor() override = default; - - Stmt Mutate_(const For *op, const Stmt &s) { - for_iters_.push_back(op->loop_var.get()); - Stmt stmt = IRMutator::Mutate_(op, s); - for_iters_.pop_back(); - return stmt.as()->body; - } - - Stmt Mutate_(const AttrStmt *op, const Stmt &s) override { - if (op->attr_key == air::ir::attr::buffer_bind_scope) { - Array arr = Downcast>(op->node); - CHECK_EQ(arr.size(), 2U); - const BufferNode *buffer = arr[0].as(); - const TensorNode *tensor = arr[1].as(); - CHECK(buffer && tensor); - auto e = buffer->elem_offset; - Expr ret = this->Mutate(e); - NodePtr buffer_node = make_node(); - buffer_node->data = buffer->data; - buffer_node->name = buffer->name; - buffer_node->scope = buffer->scope; - buffer_node->dtype = buffer->dtype; - buffer_node->strides = 
buffer->strides; - buffer_node->shape = buffer->shape; - buffer_node->data_alignment = buffer->data_alignment; - buffer_node->elem_offset = ret; - buffer_node->offset_factor = buffer->offset_factor; - - Buffer buffer_new(buffer_node); - Array node = {buffer_new, arr[1]}; - - auto value = this->Mutate(op->value); - auto body = this->Mutate(op->body); - - return AttrStmt::make(node, op->attr_key, value, body); - } - return IRMutator::Mutate_(op, s); - } - - Expr Mutate_(const Variable *op, const Expr &e) { - bool be_zero = false; - for (auto &i : for_iters_) { - if (i == op) { - be_zero = true; - break; - } - } - - if (be_zero) { - return make_const(Int(32), 0); - } - - return e; - } - - Expr Mutate_(const Call *op, const Expr &e) final { - if (op->is_intrinsic(air::ir::intrinsic::tvm_fill_fragment)) { - CHECK_EQ(op->args.size(), 6U); - return DeleteUselessForIndex(op, e); - } else if (op->is_intrinsic(air::ir::intrinsic::tvm_load_matrix_sync)) { - CHECK_EQ(op->args.size(), 8U); - return DeleteUselessForIndex(op, e); - - } else if (op->is_intrinsic(air::ir::intrinsic::tvm_store_matrix_sync)) { - CHECK_EQ(op->args.size(), 8U); - return DeleteUselessForIndex(op, e); - - } else if (op->is_intrinsic(air::ir::intrinsic::tvm_mma_sync)) { - CHECK_EQ(op->args.size(), 8U); - return DeleteUselessForIndex(op, e); - } else { - return IRMutator::Mutate_(op, e); - } - } - - Expr DeleteUselessForIndex(const Call *op, const Expr &e) { - Array args = op->args; - for (unsigned int i = 0; i < args.size(); ++i) { - args.Set(i, Simplify(this->Mutate(args[i]))); - } - if (args.same_as(op->args)) { - return e; - } - return Call::make(op->type, op->name, args, op->call_type, op->func, op->value_index); - } - - private: - std::vector for_iters_; -}; - -class AkgReduceStmtChange : public air::ir::IRMutator { - public: - explicit AkgReduceStmtChange(Tensor t, Array args, std::string name) : t(t), args(args), name(name) {} - ~AkgReduceStmtChange() override = default; - - Expr Mutate_(const Call 
*op, const Expr &e) final { - if (op->name == name) { - return Call::make(op->type, t->op->func_name(), args, op->call_type, t->op, op->value_index); - } - return IRMutator::Mutate_(op, e); - } - - Stmt Mutate_(const Provide *op, const Stmt &s) final { - auto stmt = IRMutator::Mutate_(op, s); - auto new_op = stmt.as(); - CHECK(new_op); - if (new_op->func->func_name() == name) { - return Provide::make(t->op, new_op->value_index, new_op->value, args); - } - return stmt; - } - - private: - Tensor t; - Array args; - std::string name; -}; - -class ConditionExprMod : public air::ir::IRMutator { - public: - explicit ConditionExprMod(bool &is_found) : is_found_(is_found) {} - ~ConditionExprMod() override = default; - - Expr Mutate_(const And *op, const Expr &e) override { - auto o_a = op->a; - auto o_b = op->b; - auto a = air::ir::IRMutator::Mutate(op->a); - auto b = air::ir::IRMutator::Mutate(op->b); - if (!a.defined() && !b.defined()) return Expr(); - if (!a.defined()) return b; - if (!b.defined()) return a; - if (o_a.same_as(a) && o_b.same_as(b)) return e; - return And::make(a, b); - } - - Expr Mutate_(const Or *op, const Expr &e) override { - auto o_a = op->a; - auto o_b = op->b; - auto a = air::ir::IRMutator::Mutate(op->a); - auto b = air::ir::IRMutator::Mutate(op->b); - if (!a.defined() && !b.defined()) return Expr(); - if (!a.defined()) return b; - if (!b.defined()) return a; - if (o_a.same_as(a) && o_b.same_as(b)) return e; - return Or::make(a, b); - } - - Expr Mutate_(const EQ *op, const Expr &e) override { - Expr a = op->a; - Expr b = op->b; - - bool rh_zero = false; - bool lh_block = false; - if (b.as()) { - auto v = b.as(); - if (v->value == 0) rh_zero = true; - } - - if (a.as()) { - auto v = a.as(); - if (v->name_hint == BLOCK_IDX_X) { - lh_block = true; - } - } - - if (rh_zero && lh_block) { - is_found_ = true; - return Expr(); - } - return e; - } - - private: - bool &is_found_; -}; - -} // namespace poly -} // namespace ir -} // namespace akg -#endif // 
POLY_GPU_ISL_EMITTER_H_ diff --git a/src/poly/gpu_mgr_strategy.cc b/src/poly/gpu_mgr_strategy.cc index 94e761509c8c8eb047ba773baced4b6255dca44d..667fe5cc7891fb5d1d9d1fd3fd9a79fe49295b80 100644 --- a/src/poly/gpu_mgr_strategy.cc +++ b/src/poly/gpu_mgr_strategy.cc @@ -45,7 +45,7 @@ void GPUMgrStrategy::RegisterPasses() { } RegisterPass(std::make_shared(pass_info_, scop_info_)); RegisterMemPromPasses(); - RegisterPass(std::make_shared()); + RegisterPass(std::make_shared(pass_info_, scop_info_)); } } // namespace poly diff --git a/src/poly/poly_util.h b/src/poly/poly_util.h index 5833a49615425896088645ef4ce556510736eac3..35b84bd94463e78afaf4df964e78ea62ee4b4742 100644 --- a/src/poly/poly_util.h +++ b/src/poly/poly_util.h @@ -331,10 +331,15 @@ constexpr auto T2 = "t2"; constexpr auto TILE_WITH_C1 = "C1"; constexpr auto TILE_WITH_C0 = "C0"; constexpr auto TILE_WITH_C0_C1 = "C0_C1"; +constexpr auto TILE_WITH_WARP_C1 = "WARP_C1"; constexpr auto REPLACE = "replace_"; constexpr auto COMPUTE = "compute"; constexpr auto PROMOTE = "promote_"; constexpr auto WARP_COMPUTE = "warp_compute"; +constexpr auto CONV_O = "conv_o"; +constexpr auto CONV_N = "conv_n"; +constexpr auto CONV_H_W = "conv_h_w"; + constexpr auto BLOCK_IDX_X = "blockIdx.x"; constexpr auto BLOCK_IDX_Y = "blockIdx.y"; constexpr auto BLOCK_IDX_Z = "blockIdx.z"; @@ -351,12 +356,14 @@ constexpr auto SYNC_SCOP_GLOBAL = "global"; constexpr auto ROW_MAJOR = "row_major"; constexpr auto COL_MAJOR = "col_major"; +constexpr auto REDUCE_AREA_FLAG = "reduce_area"; /****************************************************** * Following const is the mark tags for schedule tree ******************************************************/ constexpr auto REALIZE = "realize"; constexpr auto CONV_GEMM = "conv_gemm"; +constexpr auto CONV_KHKW_OUTER = "conv_khkw_outer"; constexpr auto FUSE_VECTOR = "fuse_vector"; constexpr auto MULTICORE_COINCIDENT = "multicore_coincident_"; @@ -367,15 +374,24 @@ constexpr auto CALL_IM2COL_UB = 
"cce_img2col_ub"; constexpr auto ATTR_IM2COL_KEY = "im2colKey"; constexpr auto MAPPING_INVALID_WARP = INT_MAX; +// promote marker for poly constexpr auto PROMOTE_GLOBAL_TO_SHARED_AB = "promote_global_to_shared_ab"; constexpr auto PROMOTE_GLOBAL_TO_SHARED_C = "promote_global_to_shared_c"; -constexpr auto PROMOTE_SHARED_TO_REGISTER = "promote_shared_to_register"; +constexpr auto PROMOTE_SHARED_TO_REGISTER_AB = "promote_shared_to_register_ab"; +constexpr auto PROMOTE_SHARED_TO_REGISTER_C = "promote_shared_to_register_c"; constexpr auto PROMOTE_GLOBAL_TO_REGISTER_C = "promote_global_to_register_c"; -constexpr auto PROMOTE_LOCAL_TO_GLOBAL = "promote_local_to_global"; +// promote marker for thread group +constexpr auto PROMOTE_REGISTER_TO_GLOBAL = "promote_register_to_global"; +constexpr auto PROMOTE_REGISTER_TO_SHARED = "promote_register_to_shared"; +constexpr auto PROMOTE_SHARED_TO_GLOBAL = "promote_shared_to_global"; + constexpr auto PROMOTE_VECTORIZATION = "promote_vectorization"; +constexpr auto SKIP_MARKER = "skip"; +constexpr auto MAP_TO_WARP = "map_to_warp"; constexpr auto THREAD_MARKER = "thread_marker"; constexpr auto BLOCK_MARKER = "block_marker"; constexpr auto WARP_MARKER = "warp_marker"; +constexpr auto KH_KW_MARKER = "kh_kw_marker"; constexpr auto VECTORIZATION_MARKER = "vectorization_marker"; constexpr auto REDUCE_MARKER = "reduce_marker_"; constexpr auto ATOMIC_MARKER = "atomic"; @@ -389,6 +405,8 @@ constexpr auto READ_ID_NAME = "GMread"; constexpr auto WRITE_ID_NAME = "GMwrite"; constexpr auto SHARED_READ_ID_NAME = "SHAREDread"; constexpr auto SHARED_WRITE_ID_NAME = "SHAREDwrite"; +constexpr auto GML_READ_ID_NAME = "GMLread"; +constexpr auto GML_WRITE_ID_NAME = "GMLwrite"; constexpr auto AKG_REDUCE_SUM = "SumOp"; constexpr auto AKG_REDUCE_MIN = "MinOp"; @@ -400,6 +418,7 @@ constexpr auto AKG_REDUCE_UNSUPPORTED = "X"; constexpr auto MATRIX_A = "matrix_a"; constexpr auto MATRIX_B = "matrix_b"; constexpr auto MATRIX_C = "matrix_c"; +constexpr auto 
MATRIX_ELSE = "matrix_else"; constexpr auto FRAGMENT = "fragment_"; constexpr auto LOCAL_SUFFIX = "_local"; constexpr auto SHARE_SUFFIX = "_shared"; diff --git a/src/poly/schedule_pass.cc b/src/poly/schedule_pass.cc index b91bf37139921ed2f7f667735ce4caa13398bce9..47b07918337ddcd900a596fb718bd3a41d0978cf 100644 --- a/src/poly/schedule_pass.cc +++ b/src/poly/schedule_pass.cc @@ -76,7 +76,28 @@ isl::schedule_node ReorderFilters(const isl::schedule_node &node, return isl::manage(new_node); } -isl::schedule_node InsertContextNode(isl::schedule_node &node, ScopInfo &scop_info) { +size_t CountConsecutiveCoincident(const isl::schedule_node &node) { + size_t count = 0; + if (!node.isa()) { + return count; + } + + isl::schedule_node_band band_node = node.as(); + while (count < band_node.n_member()) { + if (!band_node.member_get_coincident(static_cast(count))) { + break; + } + ++count; + } + return count; +} + +isl::schedule InsertContextNode(const isl::schedule &sch, ScopInfo &scop_info) { + auto node = sch.root().child(0); + if (node.isa()) { + node = node.del(); + } + // step1. get config std::unordered_map mapping_ids_with_sizes; auto block_cfg = scop_info.user_config_.GetBlockConfig(); @@ -117,7 +138,7 @@ isl::schedule_node InsertContextNode(isl::schedule_node &node, ScopInfo &scop_in scop_info.analysis_result_.RecordContextParams(context_set); // step3. 
insert context node = node.insert_context(context_set.from_params()); - return node; + return node.get_schedule(); } isl::union_map DependenceAnalysis(const isl::union_map &sources, const isl::union_map &targets, @@ -317,7 +338,8 @@ bool ReplaceScheduleTree(isl::schedule &schedule, ScopInfo &info) { } std::vector GetTileSizeOfLevel(const int member_size, const int dim_size, const std::string &tile_level, - TileSizes tile_sizes, const int count_coincident) { + TileSizes tile_sizes, const int count_coincident, + const std::vector warp_list) { std::vector tile_size(member_size, 0); for (auto i = 0; i < member_size; ++i) { if (i >= dim_size) { @@ -329,6 +351,8 @@ std::vector GetTileSizeOfLevel(const int member_size, const int dim_size, c tile_size[i] = static_cast(tile_sizes[i].c0_tiling_size); } else if (tile_level == TILE_WITH_C1) { tile_size[i] = static_cast(tile_sizes[i].c1_tiling_size); + } else if (tile_level == TILE_WITH_WARP_C1) { + tile_size[i] = warp_list[i]; } else { // The tiling size of n and m is warp_number times of c0_tiling_size, which is equivalent to extracting the for // loop generated during mapping.This avoids the if condition and facilitates isl_emitter. 
@@ -339,7 +363,8 @@ std::vector GetTileSizeOfLevel(const int member_size, const int dim_size, c return tile_size; } -std::string GetPromotionTensorName(const isl::schedule_node &node, const std::vector &buffer_def_infos) { +std::string GetPromotionTensorName(const isl::schedule_node &node, + const std::vector &buffer_def_infos) { std::string id_name = ""; if (!node.isa()) { return id_name; @@ -347,10 +372,15 @@ std::string GetPromotionTensorName(const isl::schedule_node &node, const std::ve for (size_t i = 0; i < buffer_def_infos.size(); ++i) { auto tensor_id = buffer_def_infos[i].tensor_id; isl::union_set id_domain = node.as().get_partial_schedule().domain(); + id_domain = id_domain.unwrap().range(); id_domain.foreach_set([tensor_id, &id_name](const isl::set &s) -> void { - if (s.to_str().find(tensor_id.get_name()) != std::string::npos) { - id_name = tensor_id.get_name(); + std::string node_tensor_name = s.get_tuple_name(); + size_t pos = 0; + if ((pos = node_tensor_name.find(LOCAL_SUFFIX)) != std::string::npos || + (pos = node_tensor_name.find(SHARE_SUFFIX)) != std::string::npos) { + node_tensor_name = node_tensor_name.erase(pos, node_tensor_name.size() - pos); } + id_name = (node_tensor_name == tensor_id.get_name()) ? node_tensor_name : id_name; }); if (!id_name.empty()) { diff --git a/src/poly/schedule_pass.h b/src/poly/schedule_pass.h index 57211d51d21d6740485bbac50b6d6c0dd606c14e..1d8639242d260791f79c31510a1a8ad52404fccb 100644 --- a/src/poly/schedule_pass.h +++ b/src/poly/schedule_pass.h @@ -82,7 +82,12 @@ isl::union_map ComputeFakeCopyin(const isl::schedule &schedule, const isl::union /* * Insert a context node beyond to determine bound block and thread sizes for Gpu. */ -isl::schedule_node InsertContextNode(isl::schedule_node &node, ScopInfo &scop_info); +isl::schedule InsertContextNode(const isl::schedule &sch, ScopInfo &scop_info); + +/* + * Get the number of axis whose coincidence is 1 in the current band node. 
+ */ +size_t CountConsecutiveCoincident(const isl::schedule_node &node); /* * Tile a node band based on given tile sizes. @@ -90,12 +95,14 @@ isl::schedule_node InsertContextNode(isl::schedule_node &node, ScopInfo &scop_in isl::schedule_node TileBand(isl::schedule_node node, const isl::multi_val &sizes); std::vector GetTileSizeOfLevel(const int member_size, const int dim_size, const std::string &tile_level, - TileSizes tile_sizes, const int count_coincident = -1); + TileSizes tile_sizes, const int count_coincident = -1, + const std::vector warp_list = {}); /* * Obtain the information needed during the data promotion phase. */ -std::string GetPromotionTensorName(const isl::schedule_node &node, const std::vector &buffer_def_infos); +std::string GetPromotionTensorName(const isl::schedule_node &node, + const std::vector &buffer_def_infos); bool IsReadOrWriteTensor(const isl::schedule_node &node, const std::string read_name, const std::string write_name); diff --git a/src/poly/schedule_pass/scheduling_mind_trick.cc b/src/poly/schedule_pass/scheduling_mind_trick.cc index 71ec409814afc4ded57984c232807ba2eedd80f9..b2fcd9e4d796501c32c6c2604fb5611a2b273056 100644 --- a/src/poly/schedule_pass/scheduling_mind_trick.cc +++ b/src/poly/schedule_pass/scheduling_mind_trick.cc @@ -27,7 +27,7 @@ #include "poly/isl_util.h" #include "poly/log_util.h" -#include "poly/gpu_isl_emitter.h" +#include "poly/gpu_emit/gpu_isl_emitter.h" namespace akg { namespace ir { diff --git a/src/poly/schedule_pass/tile_outer_band.cc b/src/poly/schedule_pass/tile_outer_band.cc index d9c17c179ef53a95df2024134327b8c3dd2ed1b8..267f8c1b98763f153d280300adfd6356cf165ec6 100644 --- a/src/poly/schedule_pass/tile_outer_band.cc +++ b/src/poly/schedule_pass/tile_outer_band.cc @@ -917,18 +917,18 @@ isl::schedule_node TileOuterBand::MarkOuterPermutableCuda(isl::schedule_node nod // get tile size node = SetTileSizeAndTile(node, TILE_WITH_C1); - if (scop_info_.user_config_.GetEnableTileC0()) { - node = 
SetTileSizeAndTile(node.child(0), TILE_WITH_C0); - node = node.parent(); - } - // tile matmul operator - if (!scop_info_.user_config_.GetEnableMatmul()) { - return node; + if (scop_info_.user_config_.GetEnableMatmul()) { + node = MatmulTile(node); } - size_t start_depth = node.get_tree_depth(); + return node; +} - isl::schedule_node_band band_node = node.as(); +isl::schedule_node TileOuterBand::MatmulTile(const isl::schedule_node &node) { + auto tile_node = node; + size_t start_depth = tile_node.get_tree_depth(); + + isl::schedule_node_band band_node = tile_node.as(); size_t count_coincident = 0; for (size_t i = 0; i < band_node.n_member(); ++i) { if (!band_node.member_get_coincident(i)) { @@ -938,50 +938,99 @@ isl::schedule_node TileOuterBand::MarkOuterPermutableCuda(isl::schedule_node nod } // split the k axis - node = band_node.split(count_coincident); - std::string shared_tensors = scop_info_.user_config_.GetSharedTensors(); - auto insert_marker = PROMOTE_GLOBAL_TO_REGISTER_C; - if (shared_tensors.find(COMPUTE) != std::string::npos) { - insert_marker = PROMOTE_GLOBAL_TO_SHARED_C; - } - node = node.child(0).insert_mark(isl::id(node.ctx(), insert_marker)); - auto sqlit_node = SplitBmmStatement(node.child(0)); - node = sqlit_node.is_equal(node.child(0)) ? sqlit_node : sqlit_node.child(0); - node = node.child(0).insert_mark(isl::id(node.ctx(), PROMOTE_GLOBAL_TO_SHARED_AB)); + tile_node = band_node.split(count_coincident); + tile_node = InsertPromoteMarker(tile_node); if (scop_info_.user_config_.GetEnableTensorCoreUsePoly()) { - // The second tile of tensor_core for mapping to warp. auto replace_cfg_map = scop_info_.user_config_.GetReplaceConfig(); if (replace_cfg_map.count(WARP_COMPUTE) == 0) { - auto thread_cfg = scop_info_.user_config_.GetThreadConfig(); - CHECK(thread_cfg != nullptr) << "thread config is null"; + ResetWarpMappingConfig(); + } + // The second tiling of tensor_core is to split the k-axis. 
+ tile_node = SetTileSizeAndTile(tile_node.child(0), TILE_WITH_C0_C1, count_coincident); + } - int total_warp = 1; - for (size_t j = 0; j < thread_cfg->bound; ++j) { - total_warp *= thread_cfg->GetAt(j).second; - } - total_warp = std::ceil(total_warp / WARP_SIZE); - size_t warp_dim_x = std::sqrt(total_warp); - size_t warp_dim_y = total_warp / warp_dim_x; - std::string new_warp_cfg = std::to_string(warp_dim_x) + " " + std::to_string(warp_dim_y); - scop_info_.user_config_.RecordReplaceConfig(WARP_COMPUTE, new_warp_cfg, MappingType::REPLACE_THREADS); + // The third tiling of tensor_core is to map to warp. + tile_node = SetTileSizeAndTile(tile_node.child(0), TILE_WITH_WARP_C1, count_coincident); + if (!scop_info_.user_config_.GetEnableTensorCoreUsePoly()) { + tile_node = tile_node.child(0); + } + tile_node = tile_node.insert_mark(isl::id(tile_node.ctx(), PROMOTE_SHARED_TO_REGISTER_AB)); + // Locate the band to be mapped. + tile_node = tile_node.child(0).insert_mark(MAP_TO_WARP).child(0); + tile_node = tile_node.child(0).insert_mark(SKIP_MARKER).child(0); + + // The last tiling of tensor_core is to calculate the size of fragment. 
+ tile_node = SetTileSizeAndTile(tile_node, TILE_WITH_C0); + tile_node = tile_node.child(0).insert_mark(SKIP_MARKER); + + if (scop_info_.user_config_.GetEnableConvTensorCore()) { + int child_depth = KH_KW_DEPTH; + while (tile_node.has_children() && child_depth != 0) { + --child_depth; + tile_node = tile_node.child(0); + } + if (tile_node.child(0).isa()) { + tile_node = tile_node.insert_mark(KH_KW_MARKER); } + } + + tile_node = tile_node.ancestor(tile_node.get_tree_depth() - start_depth); + return tile_node; +} - node = SetTileSizeAndTile(node.child(0), TILE_WITH_C0_C1, count_coincident); +void TileOuterBand::ResetWarpMappingConfig() { + auto thread_cfg = scop_info_.user_config_.GetThreadConfig(); + CHECK(thread_cfg != nullptr) << "thread config is null"; + + int total_warp = 1; + for (size_t j = 0; j < thread_cfg->bound; ++j) { + total_warp *= thread_cfg->GetAt(j).second; } + total_warp = std::ceil(total_warp / WARP_SIZE); + size_t warp_dim_x = std::sqrt(total_warp); + size_t warp_dim_y = total_warp / warp_dim_x; + std::string new_warp_cfg = std::to_string(warp_dim_x) + " " + std::to_string(warp_dim_y); + scop_info_.user_config_.RecordReplaceConfig(WARP_COMPUTE, new_warp_cfg, MappingType::REPLACE_THREADS); +} - // the last tile of tensor_core - node = SetTileSizeAndTile(node.child(0), TILE_WITH_C0); - if (!scop_info_.user_config_.GetEnableTensorCoreUsePoly()) { - node = node.child(0); +isl::schedule_node TileOuterBand::InsertPromoteMarker(const isl::schedule_node node) { + isl::schedule_node tile_node = node.child(0); + bool is_matrixc_promote_shared = IsMatrixCPromoteToShared(); + + // Add different promotion marks in different positions. 
+ if (is_matrixc_promote_shared) { + tile_node = tile_node.insert_mark(isl::id(tile_node.ctx(), PROMOTE_GLOBAL_TO_SHARED_C)).child(0); + tile_node = tile_node.insert_mark(isl::id(tile_node.ctx(), PROMOTE_SHARED_TO_REGISTER_C)).child(0); + } else { + tile_node = tile_node.insert_mark(isl::id(tile_node.ctx(), PROMOTE_GLOBAL_TO_REGISTER_C)).child(0); } - node = node.insert_mark(isl::id(node.ctx(), PROMOTE_SHARED_TO_REGISTER)); - node = node.ancestor(node.get_tree_depth() - start_depth); - return node; + tile_node = tile_node.child(0).insert_mark(isl::id(tile_node.ctx(), PROMOTE_GLOBAL_TO_SHARED_AB)); + return tile_node; +} + +bool TileOuterBand::IsMatrixCPromoteToShared() { + std::string shared_tensors = scop_info_.user_config_.GetSharedTensors(); + if (shared_tensors.empty()) { + return false; + } + + shared_tensors += " "; + auto pos = shared_tensors.find(" "); + while (pos != std::string::npos) { + std::string tensor = shared_tensors.substr(0, pos); + auto matmul_map = scop_info_.analysis_result_.GetMatrixMatmulMap(); + if (matmul_map.count(tensor) && (matmul_map[tensor] == MATRIX_C || matmul_map[tensor] == MATRIX_ELSE)) { + return true; + } + shared_tensors = shared_tensors.substr(pos + 1, shared_tensors.size()); + pos = shared_tensors.find(" "); + } + return false; } -isl::schedule_node TileOuterBand::SplitBmmStatement(const isl::schedule_node &node) { +isl::schedule_node TileOuterBand::SplitMatmulStatement(const isl::schedule_node &node) { isl::schedule_node tile_node = node; auto all_reduce_map = scop_info_.analysis_result_.GetReduceTensorInfoMap(); ReduceManager reduce_manager; @@ -1011,7 +1060,26 @@ isl::schedule_node TileOuterBand::SetTileSizeAndTile(const isl::schedule_node &n const unsigned int n_member = node.as().n_member(); auto title_size = static_cast(tile_sizes_.size()); unsigned int dim_num = (n_member <= title_size) ? 
n_member : title_size; - std::vector tile_size = GetTileSizeOfLevel(n_member, dim_num, tile_level, tile_sizes_, count_coincident); + std::vector tile_size; + auto replace_cfg_map = scop_info_.user_config_.GetReplaceConfig(); + if (tile_level == TILE_WITH_WARP_C1) { + std::vector warp_list; + CHECK_NE(replace_cfg_map.count(WARP_COMPUTE), 0) << "Can't find warpconfig"; + auto warp_cfg = replace_cfg_map[WARP_COMPUTE]; + for (size_t i = 0, j = 0; i < n_member; ++i) { + auto c1 = static_cast(tile_sizes_[i].c1_tiling_size); + auto c0 = static_cast(tile_sizes_[i].c0_tiling_size); + c1 = (static_cast(i) < count_coincident) ? c1 : c0; + if (c0 == scop_info_.analysis_result_.GetMmaMode().m && j < warp_cfg->bound) { + c1 = std::max(c1 / warp_cfg->GetAt(j).second, c0); + ++j; + } + warp_list.push_back(c1); + } + tile_size = GetTileSizeOfLevel(n_member, dim_num, tile_level, tile_sizes_, count_coincident, warp_list); + } else { + tile_size = GetTileSizeOfLevel(n_member, dim_num, tile_level, tile_sizes_, count_coincident); + } isl::multi_val sizes = ComputeBandTilesSizes(node, &tile_size[0]); return TileBand(node, sizes); } diff --git a/src/poly/schedule_pass/tile_outer_band.h b/src/poly/schedule_pass/tile_outer_band.h index 4258f23e2489c2676f264b1f7564ae5aa2797af7..3837e996360ac3c8c1457be2a0e4147ac0dd54e6 100644 --- a/src/poly/schedule_pass/tile_outer_band.h +++ b/src/poly/schedule_pass/tile_outer_band.h @@ -23,6 +23,8 @@ namespace akg { namespace ir { namespace poly { +constexpr auto KH_KW_DEPTH = 2; + /* * Tile the outer band accoding to TilingInfo. In this pass, we get the out-most band, * decide tile_size depending on the types of operators, and then start tiling. 
@@ -86,9 +88,13 @@ class TileOuterBand : public SchedulePass { void ComputeWInfo(int &w_base, bool &head, bool &tail, int &w_head, int &w_tail, int &win_w, int &win_cut_w); bool NeedIsolate(); bool BoolNot(bool b) { return !b; } - isl::schedule_node SplitBmmStatement(const isl::schedule_node &node); + isl::schedule_node SplitMatmulStatement(const isl::schedule_node &node); isl::schedule_node SetTileSizeAndTile(const isl::schedule_node &node, const std::string &tile_level, const int count_coincident = -1); + bool IsMatrixCPromoteToShared(); + isl::schedule_node InsertPromoteMarker(const isl::schedule_node node); + void ResetWarpMappingConfig(); + isl::schedule_node MatmulTile(const isl::schedule_node &node); private: PassInfo &pass_info_; diff --git a/src/poly/schedule_pass_gpu/mapping_outer_band.cc b/src/poly/schedule_pass_gpu/mapping_outer_band.cc index f699d2410917629d51d5d4513fdf92c159f59159..4b3b3721ca4593b58370b2f212697e15cd5a7350 100644 --- a/src/poly/schedule_pass_gpu/mapping_outer_band.cc +++ b/src/poly/schedule_pass_gpu/mapping_outer_band.cc @@ -15,259 +15,21 @@ */ #include "mapping_outer_band.h" +#include "poly/schedule_pass_gpu/operator_mapping_strategy.h" #include #include "poly/schedule_tree_util.h" #include "poly/sync_manager.h" #include "poly/scop.h" -#include "poly/gpu_isl_emitter.h" +#include "poly/gpu_emit/gpu_isl_emitter.h" namespace akg { namespace ir { namespace poly { -isl::multi_union_pw_aff MappingOuterBand::MapDomainToWarp(const isl::schedule_node &node, MappingCfg *mapping_cfg, - isl::multi_union_pw_aff domain_threads) { - isl::space space = isl::space(node.ctx(), 0); - auto block_space = space.add_named_tuple_id_ui(isl::id(node.ctx(), SYNC_BLOCK), mapping_cfg->bound); - auto bspace = block_space; - auto warp_space = space.add_named_tuple_id_ui(isl::id(node.ctx(), SYNC_WARP), 1); - - auto block_aff = isl_aff_zero_on_domain(isl_local_space_from_space(bspace.release())); - isl::aff aff = isl::manage(block_aff); - - auto identity = 
isl::multi_aff::identity(block_space.map_from_set()); - for (int i = mapping_cfg->bound - 1; i >= 0; --i) { - auto bi = mapping_cfg->GetAt(i); - aff = aff.scale(isl::val(node.ctx(), bi.second)); - aff = aff.add(identity.get_aff(i)); - } - - aff = aff.scale_down(isl::val(node.ctx(), WARP_SIZE)).floor(); - auto map_space = block_space.product(warp_space).unwrap(); - isl::multi_aff thread_warp = isl::multi_aff(map_space, isl::aff_list(aff)); - return domain_threads.apply(thread_warp); -} - -bool MappingOuterBand::IsOuterBandWithNoCoincident(const isl::schedule_node &node) { - int depth = node.get_tree_depth(); - isl::schedule_node ancestor_node; - - for (int i = 0; i < depth; ++i) { - ancestor_node = node.ancestor(depth - i); - if (auto band = ancestor_node.as()) { - auto n_coincident = CountConsecutiveCoincident(band); - if (band.n_member() > n_coincident) { - return true; - } - } - if (ancestor_node.isa()) { - return false; - } - } - - return false; -} - -size_t MappingOuterBand::GetReduceId() const { - static size_t reduce_count = 0; - return reduce_count++; -} - -std::string MappingOuterBand::GetMarkerName(const isl::schedule_node &node, std::string find_name) { - std::string reduce_marker_name = ""; - if (node.isa()) { - reduce_marker_name = node.as().get_id().get_name(); - if (reduce_marker_name.find(find_name) != std::string::npos) { - return reduce_marker_name; - } - reduce_marker_name = ""; - } - return reduce_marker_name; -} - -size_t MappingOuterBand::CountConsecutiveCoincident(const isl::schedule_node_band &band_node) { - size_t count = 0; - while (count < band_node.n_member()) { - if (!band_node.member_get_coincident(static_cast(count))) { - break; - } - ++count; - } - return count; -} - -isl::schedule_node MappingOuterBand::FillRemainingThreads(isl::schedule_node &node, size_t begin) { - auto thread_cfg = scop_info_.user_config_.GetThreadConfig(); - CHECK(thread_cfg != nullptr) << "threadconfig is null"; - size_t end = thread_cfg->bound; - if (begin == 
end) { - return node; - } - - CHECK(node.isa()) << "The child of set or sequence must be a filter!"; - node = node.child(0); - - isl::union_set domain = CollectDomain(node); - isl::space space = domain.get_space(); - space = space.set_from_params(); - isl::multi_val mv = isl::multi_val::zero(space); - isl::multi_union_pw_aff mupa = isl::multi_union_pw_aff(domain, mv); - node.insert_partial_schedule(mupa); - - isl::schedule_node_band band_node = node.as(); - Mapping mapping; - auto after_map_pair = MapInnerDimToThreads(band_node, false, thread_cfg, mapping, - scop_info_.analysis_result_.GetReduceDirection() == Y_DIRECTION); - auto after_map_node = after_map_pair.first; - scop_info_.upa_node_mapping_.emplace_back(std::make_pair(after_map_node, mapping)); - after_map_node = after_map_node.parent(); - return after_map_node; -} - -size_t MappingOuterBand::NumMappedDescendant(const RoadMap &thread_roadmap, const isl::schedule_node parent) { - size_t max_thread_size = 0; - for (const auto &record : thread_roadmap) { - auto child_node = record.first; - auto thread_size = record.second; - bool is_child = IsEqualNode(parent, child_node); - while (!is_child && child_node && child_node.has_parent()) { - child_node = child_node.parent(); - is_child = IsEqualNode(parent, child_node); - } - if (is_child) { - max_thread_size = std::max(max_thread_size, thread_size); - } - } - return max_thread_size; -} - -bool MappingOuterBand::CanBeMappedToThread(const isl::schedule_node node, const RoadMap &thread_record) { - auto IsInnerMostBand = [this, &thread_record](const isl::schedule_node node) { - auto band = node.as(); - return band && band.permutable() && NumMappedDescendant(thread_record, node) == 0; - }; - - auto HasMapped = [&thread_record](const isl::schedule_node node) -> bool { - for (size_t i = 0; i < thread_record.size(); ++i) { - if (IsEqualNode(thread_record[i].first, node)) { - return true; - } - } - return false; - }; - - if (!IsInnerMostBand(node)) { - return false; - } - 
- auto band = node.as(); - - // make sure a band node in a sequence node only be mapped when all its siblings can be mapped together - if (band.ancestor(2) && band.ancestor(2).isa()) { - auto seq = band.ancestor(2).as(); - for (size_t i = 0; i < seq.n_children(); ++i) { - auto filter = seq.child(i); - if (filter.child(0).isa()) { - continue; - } - if (!IsInnerMostBand(filter.child(0)) && !HasMapped(filter)) { - return false; - } - } - } - return true; -} - -isl::schedule MappingOuterBand::DoThreadMapping(const isl::schedule &sch) { - auto final_schedule = sch; - auto thread_cfg = scop_info_.user_config_.GetThreadConfig(); - CHECK(thread_cfg != nullptr) << "thread config is null"; - if (thread_cfg->bound < 1) { - return final_schedule; - } - - // Step 1. Find inner-most permutable band to map threads. - RoadMap thread_record; - bool is_reduce_stmt = false; - auto MapFromInner = [&thread_record, &is_reduce_stmt, thread_cfg, - this](isl::schedule_node node) -> isl::schedule_node { - if (scop_info_.user_config_.GetEnableAkgReduceLib() && node.has_parent() && - !GetMarkerName(node.parent(), REDUCE_MARKER).empty()) { - is_reduce_stmt = true; - } - - if (scop_info_.user_config_.GetEnableTensorCoreUsePoly() && node.get_tree_depth() >= 2 && - !GetMarkerName(node.ancestor(2), PROMOTE_SHARED_TO_REGISTER).empty()) { - return node; - } - - if (node.has_parent() && node.parent().isa()) { - const std::string &marker = node.parent().as().get_id().get_name(); - if (marker == MIND_TRICKS_SWIZZLE_MARKER) { - return node; - } - } - - size_t num_mapped_desc = NumMappedDescendant(thread_record, node); - - if (CanBeMappedToThread(node, thread_record)) { - auto node_bak = node; - auto mapped_threads = MapThreadHelper(node); - if (!node_bak.is_equal(node)) { - // if successfully mapped current node, we insert a map filter beyond and need to return to band node - node = node.parent(); - } - thread_record.emplace_back(std::make_pair(node, mapped_threads)); - return node; - } - - // deal with 
band that has children mapped to threads - if (node.n_children() > 1 && num_mapped_desc > 0) { - auto num_children = node.n_children(); - int start_node_depth = node.get_tree_depth(); - for (size_t i = 0; i < num_children; ++i) { - isl::schedule_node node_child = node.child(i); - for (const auto &record : thread_record) { - auto child_node = record.first; - auto thread_size = record.second; - if (child_node.has_parent() && child_node.parent().isa()) { - child_node = child_node.parent(); - } - bool is_child = IsEqualNode(node_child, child_node); - if (is_child) { - node_child = FillRemainingThreads(node_child, thread_size); - node = node_child.ancestor(node_child.get_tree_depth() - start_node_depth); - break; - } - } - } - - auto need_sync = node.isa(); - if (need_sync) { - if (is_reduce_stmt && node.has_parent() && !GetMarkerName(node.parent(), INSERT_SYNC).empty()) { - node = node.parent().del(); - node = DoThreadSynchronization(node); - } else if (!is_reduce_stmt) { - node = DoThreadSynchronization(node); - } - } - - auto band = node.as(); - if (band && CountConsecutiveCoincident(band) < band.n_member()) { - CHECK_EQ(num_mapped_desc, thread_cfg->bound) << "Must be mapped to all threads."; - auto sync_manager = scop_info_.sync_manager_; - sync_manager.InsertExtensionNode(band.child(0), SyncLevel::BLOCK, true); - } - } - return node; - }; - final_schedule = sch.get_root().map_descendant_bottom_up(MapFromInner).get_schedule(); - return final_schedule; -} - -isl::schedule_node MappingOuterBand::DoThreadSynchronization(const isl::schedule_node &node) { +isl::schedule_node MappingOuterBand::DoThreadSynchronization(const isl::schedule_node &node, + const std::vector other_mapping_cfg) { auto sync_node = node; auto sync_manager = scop_info_.sync_manager_; auto thread_cfg = scop_info_.user_config_.GetThreadConfig(); @@ -276,12 +38,11 @@ isl::schedule_node MappingOuterBand::DoThreadSynchronization(const isl::schedule // Step 1. 
prepare info bool is_outer = IsOuterBandWithNoCoincident(node); auto domain_thread = MapDomainToThread(node, thread_cfg, scop_info_.upa_node_mapping_); - - if (scop_info_.user_config_.GetEnableTensorCoreUsePoly()) { - auto warp_cfg = scop_info_.user_config_.GetReplaceConfig()[WARP_COMPUTE]; - CHECK(warp_cfg != nullptr) << "warp config is null"; - auto domain_thread_warp = MapDomainToThread(node, warp_cfg, scop_info_.upa_node_mapping_); - domain_thread = domain_thread.union_add(domain_thread_warp); + for (size_t i = 0; i < other_mapping_cfg.size(); ++i) { + auto mapping_cfg = other_mapping_cfg[i]; + CHECK(mapping_cfg != nullptr) << "mapping config is null"; + auto domain_other_mapping = MapDomainToThread(node, mapping_cfg, scop_info_.upa_node_mapping_); + domain_thread = domain_thread.union_add(domain_other_mapping); } auto domain_node = CollectDomain(node); bool sub_set = domain_node.is_subset(domain_thread.domain()); @@ -416,310 +177,238 @@ int MappingOuterBand::GetBestSyncStartPoint(bool is_outer) { return 0; } -isl::schedule_node MappingOuterBand::InsertReduceExtension(const isl::schedule_node &node) { - auto thread_cfg = scop_info_.user_config_.GetThreadConfig(); - CHECK(thread_cfg != nullptr) << "thread config is null"; - - isl::schedule_node insert_node = node; - isl::schedule_node parent_node = node; - isl::schedule_node ancestor_node = node; - if (insert_node.has_parent()) { - parent_node = parent_node.parent(); - if (parent_node.has_parent()) { - ancestor_node = parent_node.parent(); - } - } +isl::multi_union_pw_aff MappingOuterBand::MapDomainToWarp(const isl::schedule_node &node, MappingCfg *mapping_cfg, + isl::multi_union_pw_aff domain_threads) { + isl::space space = isl::space(node.ctx(), 0); + auto block_space = space.add_named_tuple_id_ui(isl::id(node.ctx(), SYNC_BLOCK), mapping_cfg->bound); + auto bspace = block_space; + auto warp_space = space.add_named_tuple_id_ui(isl::id(node.ctx(), SYNC_WARP), 1); - std::string reduce_marker_name = ""; - if 
(!GetMarkerName(parent_node, REDUCE_MARKER).empty()) { - reduce_marker_name = GetMarkerName(parent_node, REDUCE_MARKER); - insert_node = parent_node.del(); - } + auto block_aff = isl_aff_zero_on_domain(isl_local_space_from_space(bspace.release())); + isl::aff aff = isl::manage(block_aff); - if (!GetMarkerName(ancestor_node, REDUCE_MARKER).empty()) { - reduce_marker_name = GetMarkerName(ancestor_node, REDUCE_MARKER); - insert_node = ancestor_node.del(); + auto identity = isl::multi_aff::identity(block_space.map_from_set()); + for (int i = mapping_cfg->bound - 1; i >= 0; --i) { + auto bi = mapping_cfg->GetAt(i); + aff = aff.scale(isl::val(node.ctx(), bi.second)); + aff = aff.add(identity.get_aff(i)); } - if (reduce_marker_name.empty()) { - return node; - } + aff = aff.scale_down(isl::val(node.ctx(), WARP_SIZE)).floor(); + auto map_space = block_space.product(warp_space).unwrap(); + isl::multi_aff thread_warp = isl::multi_aff(map_space, isl::aff_list(aff)); + return domain_threads.apply(thread_warp); +} - reduce_marker_name.erase(0, strlen(REDUCE_MARKER)); - isl::id sync_id = isl::id(insert_node.ctx(), REDUCE_UPDATE + reduce_marker_name); - isl::id reduction_id = isl::id(insert_node.ctx(), REDUCE_INIT + reduce_marker_name); +bool MappingOuterBand::IsOuterBandWithNoCoincident(const isl::schedule_node &node) { + int depth = node.get_tree_depth(); + isl::schedule_node ancestor_node; - insert_node = InsertExtensionNodeBeforeOrAfter(insert_node, reduction_id, true); - insert_node = InsertExtensionNodeBeforeOrAfter(insert_node, sync_id, false).parent(); + for (int i = 0; i < depth; ++i) { + ancestor_node = node.ancestor(depth - i); + if (auto band = ancestor_node.as()) { + auto n_coincident = CountConsecutiveCoincident(band); + if (band.n_member() > n_coincident) { + return true; + } + } + if (ancestor_node.isa()) { + return false; + } + } - return insert_node; + return false; } -size_t MappingOuterBand::MapThreadHelper(isl::schedule_node &thread_root) { - 
isl::schedule_node_band band_node = thread_root.as(); +isl::schedule_node MappingOuterBand::FillRemainingThreads(isl::schedule_node &node, size_t begin) { auto thread_cfg = scop_info_.user_config_.GetThreadConfig(); - CHECK(thread_cfg != nullptr) << "thread config is null"; - if (thread_cfg->bound < 1) { - return 0; + CHECK(thread_cfg != nullptr) << "threadconfig is null"; + size_t end = thread_cfg->bound; + if (begin == end) { + return node; } - if (!band_node) { - LOG(WARNING) << "No permutable band to map thread."; - return 0; - } + CHECK(node.isa()) << "The child of set or sequence must be a filter!"; + node = node.child(0); - int start_node_depth = thread_root.get_tree_depth(); - // Step 1. Determine max num dimension of threads that can be mapped. - auto n_thread_map = CountConsecutiveCoincident(band_node); - - bool is_bmm_statement = false; - if (scop_info_.user_config_.GetEnableTensorCoreUsePoly() && thread_root.has_parent() && - !GetMarkerName(thread_root.parent(), PROMOTE_SHARED_TO_REGISTER).empty()) { - auto warp_cfg = scop_info_.user_config_.GetReplaceConfig()[WARP_COMPUTE]; - CHECK(warp_cfg != nullptr) << "warp config is null"; - thread_cfg = warp_cfg; - is_bmm_statement = true; - } + isl::schedule_node_band band_node = node.as(); + Mapping mapping; + auto after_map_node = MapInnerDimToThreads(band_node, false, thread_cfg, mapping, false); + bool is_tiled = GetMarkerName(after_map_node, THREAD_MARKER).empty(); + after_map_node = is_tiled ? 
after_map_node.child(0) : after_map_node; + scop_info_.upa_node_mapping_.emplace_back(std::make_pair(after_map_node, mapping)); + return after_map_node; +} - bool is_reduce_stmt = false; - std::string reduce_marker_name = ""; - if (band_node.has_parent()) { - reduce_marker_name = GetMarkerName(band_node.parent(), REDUCE_MARKER); - if (!reduce_marker_name.empty()) { - ++n_thread_map; - is_reduce_stmt = true; +size_t MappingOuterBand::NumMappedDescendant(const RoadMap &thread_roadmap, const isl::schedule_node parent) { + size_t max_thread_size = 0; + for (const auto &record : thread_roadmap) { + auto child_node = record.first; + auto thread_size = record.second; + bool is_child = IsEqualNode(parent, child_node); + while (!is_child && child_node && child_node.has_parent()) { + child_node = child_node.parent(); + is_child = IsEqualNode(parent, child_node); + } + if (is_child) { + max_thread_size = std::max(max_thread_size, thread_size); } } + return max_thread_size; +} - if (n_thread_map < 1) { - return 0; - } +bool MappingOuterBand::CanBeMappedToThread(const isl::schedule_node node, const RoadMap &thread_record) { + auto IsInnerMostBand = [this, &thread_record](const isl::schedule_node node) { + auto band = node.as(); + return band && band.permutable() && NumMappedDescendant(thread_record, node) == 0; + }; - // Step 2. Split band node according to mapping config and coincidence of band node. 
- if (n_thread_map > thread_cfg->bound) { - thread_root = band_node.split(n_thread_map - thread_cfg->bound); - thread_root = thread_root.child(0); - n_thread_map = thread_cfg->bound; - if (is_reduce_stmt) { - isl::schedule_node ancestor_node = thread_root.ancestor(2); - ancestor_node = ancestor_node.del().child(0); - thread_root = ancestor_node.insert_mark(reduce_marker_name); - thread_root = thread_root.child(0); + auto HasMapped = [&thread_record](const isl::schedule_node node) -> bool { + for (size_t i = 0; i < thread_record.size(); ++i) { + if (IsEqualNode(thread_record[i].first, node)) { + return true; + } } - band_node = thread_root.as(); - } + return false; + }; - // split to keep nodes with coincident equals to 1 - if (n_thread_map < band_node.n_member() && !scop_info_.user_config_.EnableStitchFusion()) { - thread_root = band_node.split(n_thread_map); - band_node = thread_root.as(); - } else { - n_thread_map = static_cast(band_node.n_member()); + if (!IsInnerMostBand(node)) { + return false; } - // Step 3. Map band under thread_root from inner dim to outer dim. - Mapping mapping; - bool is_y_reduce = - scop_info_.analysis_result_.GetReduceDirection() == Y_DIRECTION || scop_info_.user_config_.GetEnableTensorCore(); - auto after_map_pair = MapInnerDimToThreads(band_node, false, thread_cfg, mapping, is_y_reduce); - thread_root = after_map_pair.first; - if (is_bmm_statement && !GetMarkerName(thread_root, THREAD_MARKER).empty()) { - thread_root = thread_root.del().insert_mark(isl::id(thread_root.ctx(), WARP_MARKER)); - } - scop_info_.upa_node_mapping_.emplace_back(std::make_pair(thread_root, mapping)); - int end_node_depth = thread_root.get_tree_depth() - start_node_depth; - - if (is_reduce_stmt) { - // Split the reduce axis and non-reduce axis of the outer band. 
- if (thread_root.ancestor(2) && !GetMarkerName(thread_root.ancestor(2), REDUCE_MARKER).empty() && n_thread_map > 1) { - thread_root = thread_root.ancestor(2).del(); - band_node = thread_root.as(); - thread_root = band_node.split(n_thread_map - 1).child(0); - thread_root = thread_root.insert_mark(reduce_marker_name); - thread_root = thread_root.child(0); + auto band = node.as(); + + // make sure a band node in a sequence node only be mapped when all its siblings can be mapped together + if (band.ancestor(2) && band.ancestor(2).isa()) { + auto seq = band.ancestor(2).as(); + for (size_t i = 0; i < seq.n_children(); ++i) { + auto filter = seq.child(i); + if (filter.child(0).isa()) { + continue; + } + if (!IsInnerMostBand(filter.child(0)) && !HasMapped(filter)) { + return false; + } } - // Add the filter that initializes and calls the akg_reduce library for the reduce statement. - thread_root = InsertReduceExtension(thread_root); - end_node_depth = thread_root.get_tree_depth() - start_node_depth; - ++end_node_depth; } - thread_root = thread_root.ancestor(end_node_depth); + return true; +} - // Step 4. Do unroll if needed. 
- if (scop_info_.user_config_.GetMaxUnrollLoop() != 1) { - isl::schedule_node after_fix_node = thread_root.child(0); - if (!IsEqualNode(after_map_pair.second, after_map_pair.first)) { - after_fix_node = after_fix_node.parent(); +isl::schedule_node MappingOuterBand::MapSequenceNode(const isl::schedule_node &orig_node, + const RoadMap &thread_record) { + // deal with band that has children mapped to threads + auto node = orig_node; + auto num_children = node.n_children(); + int start_node_depth = node.get_tree_depth(); + for (size_t i = 0; i < num_children; ++i) { + isl::schedule_node node_child = node.child(i); + for (const auto &record : thread_record) { + auto child_node = record.first; + auto thread_size = record.second; + if (child_node.has_parent() && child_node.parent().isa()) { + child_node = child_node.parent(); + } + bool is_child = IsEqualNode(node_child, child_node); + if (is_child) { + node_child = FillRemainingThreads(node_child, thread_size); + node = node_child.ancestor(node_child.get_tree_depth() - start_node_depth); + break; + } } - thread_root = UnrollByMarkOptions(after_fix_node, scop_info_.user_config_.GetMaxUnrollLoop()); } - return thread_cfg->bound; + return node; } -isl::schedule MappingOuterBand::DetectAndMarkReduce(const isl::schedule &sch) { +isl::schedule MappingOuterBand::DoThreadMapping(const isl::schedule &sch) { auto final_schedule = sch; auto thread_cfg = scop_info_.user_config_.GetThreadConfig(); - CHECK(thread_cfg != nullptr) << "threadconfig is null"; - if (thread_cfg->bound == 0) { + CHECK(thread_cfg != nullptr) << "thread config is null"; + if (thread_cfg->bound < 1) { return final_schedule; } - auto all_reduce_map = scop_info_.analysis_result_.GetReduceTensorInfoMap(); - ReduceManager reduce_manager; - bool done_separate = false; - auto GetInnerMostBand = [&done_separate, &all_reduce_map, &reduce_manager, thread_cfg, - this](isl::schedule_node node) -> isl::schedule_node { - if (done_separate) { - return node; + // Step 1. 
Find inner-most permutable band to map threads. + RoadMap thread_record; + bool is_reduce_stmt = false; + + auto MapFromInner = [&thread_record, &is_reduce_stmt, thread_cfg, + this](isl::schedule_node node) -> isl::schedule_node { + // batch matmul operator + bool is_bmm_stmt = false; + if (scop_info_.user_config_.GetEnableTensorCoreUsePoly()) { + if (node.has_parent() && !GetMarkerName(node.parent(), SKIP_MARKER).empty()) { + node = node.parent().del(); + return node; + } + if ((node.has_parent() && !GetMarkerName(node.parent(), MAP_TO_WARP).empty())) { + node = node.parent().del(); + is_bmm_stmt = true; + } } - auto band_node = node.as(); - if (!band_node || !band_node.permutable()) { - return node; + + // swizzle + if (node.has_parent() && node.parent().isa()) { + const std::string &marker = node.parent().as().get_id().get_name(); + if (marker == MIND_TRICKS_SWIZZLE_MARKER) { + return node; + } } - auto band_node_domain = band_node.get_partial_schedule().domain(); - StatementMap all_statements = scop_info_.analysis_result_.GetStatementMap(); - isl::union_map reduce_statement_map = isl::union_map::empty(node.ctx()); - isl::union_set reduce_statements = isl::union_set::empty(node.ctx()); - - for (auto it = all_reduce_map.begin(); it != all_reduce_map.end();) { - reduce_statement_map = reduce_statement_map.unite(it->second.stmt_map); - auto this_reduce = reduce_manager.GetReduceStatements(band_node_domain, reduce_statement_map, all_statements); - if (!this_reduce.is_empty()) { - reduce_statements = reduce_statements.unite(this_reduce); - all_reduce_map.erase(it++); + if (CanBeMappedToThread(node, thread_record)) { + auto node_bak = node; + size_t mapped_threads = 0; + if (scop_info_.user_config_.GetEnableAkgReduceLib() && node.has_parent() && + !GetMarkerName(node.parent(), REDUCE_MARKER).empty()) { + // reduce operator + is_reduce_stmt = true; + ReduceMappingStrategy reduce_op(pass_info_, scop_info_); + mapped_threads = reduce_op.MapThreadHelper(node); + } else if 
(is_bmm_stmt) { + // batch matmul operator + BatchMatmulMappingStrategy bmm_op(pass_info_, scop_info_); + if (scop_info_.user_config_.GetEnableConvTensorCore()) { + // conv operator + node = AdjustConvScheduleTreeStructure(node, false); + } + mapped_threads = bmm_op.MapThreadHelper(node); } else { - ++it; + // others operator + OperatorMappingStrategy others_op(pass_info_, scop_info_); + mapped_threads = others_op.MapThreadHelper(node); } - } - if (reduce_statements.n_set() < 1) { + if (!node_bak.is_equal(node)) { + // if successfully mapped current node, we insert a map filter beyond and need to return to band node + node = node.parent(); + } + thread_record.emplace_back(std::make_pair(node, mapped_threads)); return node; } - isl::union_map dependences = pass_info_.dependences_.subtract(pass_info_.force_dependences_); - auto node_bak = node; - if (!reduce_manager.SplitReduceStatements(node, reduce_statements, dependences, true)) { - return node_bak; - } - done_separate = all_reduce_map.empty(); - return node; - }; - final_schedule = sch.get_root().map_descendant_bottom_up(GetInnerMostBand).get_schedule(); - if (done_separate) { - final_schedule = InsertReduceMarker(final_schedule); - } - return final_schedule; -} - -isl::schedule MappingOuterBand::InsertReduceMarker(const isl::schedule &sch) { - isl::schedule final_schedule = sch; - auto all_reduce_map = scop_info_.analysis_result_.GetReduceTensorInfoMap(); - auto InsertMarker = [&all_reduce_map, this](isl::schedule_node node) -> isl::schedule_node { - ReduceManager reduce_manager; - auto band_node = node.as(); - if (!band_node) { + if (node.n_children() <= 1 || NumMappedDescendant(thread_record, node) <= 0) { return node; } - - for (auto it = all_reduce_map.begin(); it != all_reduce_map.end();) { - isl::union_map reduce_statement_map = it->second.stmt_map; - isl::id reduce_id = it->first; - auto band_node_domain = band_node.get_partial_schedule().domain(); - auto op_type = 
scop_info_.analysis_result_.GetReduceOpType(reduce_id) + "_"; - - StatementMap all_statements = scop_info_.analysis_result_.GetStatementMap(); - isl::union_set reduce_statements = - reduce_manager.GetReduceStatements(band_node_domain, reduce_statement_map, all_statements); - if (reduce_statements.n_set() != 1) { - ++it; - continue; + node = MapSequenceNode(node, thread_record); + + auto need_sync = node.isa(); + if (need_sync) { + if (is_reduce_stmt && node.has_parent() && !GetMarkerName(node.parent(), INSERT_SYNC).empty()) { + node = node.parent().del(); + node = DoThreadSynchronization(node); + } else if (!is_reduce_stmt && scop_info_.user_config_.GetEnableTensorCoreUsePoly()) { + std::vector other_mapping_cfg; + other_mapping_cfg.push_back(scop_info_.user_config_.GetReplaceConfig()[WARP_COMPUTE]); + node = DoThreadSynchronization(node, other_mapping_cfg); + } else if (!is_reduce_stmt) { + node = DoThreadSynchronization(node); } - - all_reduce_map.erase(it++); - std::string reduce_marker_name = - REDUCE_MARKER + op_type + reduce_id.get_name() + "_" + std::to_string(GetReduceId()); - auto reduce_node = band_node.insert_mark(reduce_marker_name); - return reduce_node; - } - return band_node; - }; - final_schedule = final_schedule.get_root().map_descendant_bottom_up(InsertMarker).get_schedule(); - return final_schedule; -} - -std::pair MappingOuterBand::GetC1C0BlockConfig(size_t n_block_map, int member_size) { - auto block_cfg = scop_info_.user_config_.GetBlockConfig(); - CHECK(block_cfg != nullptr) << "block config is null"; - auto title_size = static_cast(pass_info_.tile_sizes_.size()); - auto dim_num = (member_size <= title_size) ? 
member_size : title_size; - std::vector c1_tile_size = GetTileSizeOfLevel(member_size, dim_num, TILE_WITH_C1, pass_info_.tile_sizes_); - std::vector c0_tile_size = GetTileSizeOfLevel(member_size, dim_num, TILE_WITH_C0, pass_info_.tile_sizes_); - CHECK_EQ(c1_tile_size.size(), c0_tile_size.size()); - - for (size_t i = 0; i < c1_tile_size.size(); ++i) { - auto c0 = c0_tile_size[i] <= 0 ? 1 : c0_tile_size[i]; - c0_tile_size[i] = std::max(1, c1_tile_size[i] / c0); - } - std::string c1_cfg = ""; - std::string c0_cfg = ""; - std::vector c1_cfg_list; - std::vector c0_cfg_list; - if (std::accumulate(c0_tile_size.begin(), c0_tile_size.end(), 1, std::multiplies()) > 1) { - for (size_t i = 0; i < n_block_map; ++i) { - auto c0 = i < c0_tile_size.size() ? c0_tile_size[i] : 1; - c0_cfg_list.emplace_back(c0); - auto block_idx = scop_info_.analysis_result_.GetReduceDirection() == Y_DIRECTION ? i : n_block_map - 1 - i; - auto c1 = block_cfg->GetAt(block_idx).second / c0; - c1_cfg_list.emplace_back(c1); - } - if (scop_info_.analysis_result_.GetReduceDirection() != Y_DIRECTION) { - std::reverse(c1_cfg_list.begin(), c1_cfg_list.end()); - std::reverse(c0_cfg_list.begin(), c0_cfg_list.end()); - } - scop_info_.user_config_.SetC0BlockSize(c0_cfg_list); - for (size_t i = 0; i < n_block_map; ++i) { - c1_cfg += (std::to_string(c1_cfg_list[i]) + " "); - c0_cfg += (std::to_string(c0_cfg_list[i]) + " "); } - } - return std::make_pair(c1_cfg, c0_cfg); -} -isl::schedule_node MappingOuterBand::MapBlockHelper(const isl::schedule_node &orig_node, MappingCfg *block_cfg, - size_t n_block_map, bool check_extent, - std::unordered_map map_idx_shift) { - auto node = orig_node; - auto band_node = node.as(); - if (!band_node || !band_node.permutable()) { - LOG(WARNING) << "No permutable outer band node to map block."; return node; - } - - auto partial_schedule = band_node.get_partial_schedule(); - auto upa_list = partial_schedule.get_union_pw_aff_list(); - - if (check_extent) { - auto domain = 
band_node.get_schedule().get_domain(); - isl::union_pw_aff_list range_aff_list(band_node.ctx(), static_cast(upa_list.size())); - for (int i = upa_list.size() - 1; i >= 0; --i) { - auto range = upa_list.get_at(i).intersect_domain(domain); - range_aff_list = range_aff_list.add(range); - } - node = CheckMapSizeAndApplyTile(node, range_aff_list, block_cfg, false); - } - - upa_list = upa_list.drop(n_block_map, upa_list.size() - n_block_map).reverse(); - - node = node.insert_mark(isl::id(node.ctx(), BLOCK_MARKER)); - node = node.child(0); - - Mapping mapping; - node = CreateAndInsertMapFilter(node, false, upa_list, block_cfg, mapping, map_idx_shift); - scop_info_.upa_node_mapping_.emplace_back(std::make_pair(node.parent(), mapping)); - - return node; + }; + final_schedule = sch.get_root().map_descendant_bottom_up(MapFromInner).get_schedule(); + return final_schedule; } isl::schedule MappingOuterBand::DoBlockMapping(const isl::schedule &sch) { @@ -766,97 +455,44 @@ isl::schedule MappingOuterBand::DoBlockMapping(const isl::schedule &sch) { } } - if (scop_info_.user_config_.GetEnableAtomicAdd() && NeedAtomicAdd(band_node, n_block_map)) { - MarkAtomicAddTensor(band_node); - } - - // Step 2. Separate original block config according to tile levels. - auto c1_block_cfg = block_cfg; - MappingCfg *c0_block_cfg = nullptr; - if (scop_info_.user_config_.GetEnableTileC0()) { - auto c1_c0_block_cfg = GetC1C0BlockConfig(n_block_map, band_node.n_member()); - if (!c1_c0_block_cfg.first.empty() && !c1_c0_block_cfg.second.empty()) { - scop_info_.user_config_.RecordReplaceConfig(TILE_WITH_C1, c1_c0_block_cfg.first, MappingType::REPLACE_BLOCKS); - scop_info_.user_config_.RecordReplaceConfig(TILE_WITH_C0, c1_c0_block_cfg.second, MappingType::REPLACE_BLOCKS); - auto rep_cfg = scop_info_.user_config_.GetReplaceConfig(); - c1_block_cfg = rep_cfg[TILE_WITH_C1]; - c0_block_cfg = rep_cfg[TILE_WITH_C0]; - } + // Step 2. 
Map outer-most band for c1 tile as usual (and do not check extent when c0 tile is applied manually). + if (scop_info_.user_config_.GetEnableAkgReduceLib()) { + // reduce operator + ReduceMappingStrategy reduce_op(pass_info_, scop_info_); + if (scop_info_.user_config_.GetEnableAtomicAdd() && reduce_op.NeedAtomicAdd(band_node, n_block_map)) { + reduce_op.MarkAtomicAddTensor(band_node); + } + node = reduce_op.MapBlockHelper(node, block_cfg, n_block_map, map_idx_shift.empty(), map_idx_shift); + } else if (scop_info_.user_config_.GetEnableConvTensorCore()) { + // conv operator + ConvMappingStrategy conv_op(pass_info_, scop_info_); + node = conv_op.ResetConvBlockMappingConfig(node, block_cfg, map_idx_shift.empty()); + } else { + // others operator + OperatorMappingStrategy others_op(pass_info_, scop_info_); + node = others_op.MapBlockHelper(node, block_cfg, n_block_map, map_idx_shift.empty(), map_idx_shift); } - // Step 3. Map outer-most band for c1 tile as usual (and do not check extent when c0 tile is applied manually). - auto map_c0_block = c0_block_cfg != nullptr; - bool check_extent = !map_c0_block && map_idx_shift.empty(); - node = MapBlockHelper(node, c1_block_cfg, n_block_map, check_extent, map_idx_shift); auto final_schedule = node.get_schedule(); - - // Step 4. Map middle-level band (i.e. c0 tile band). 
- if (map_c0_block) { - isl::schedule_node middle_node = GetOuterBand(final_schedule.get_root()).child(0); - middle_node = MapBlockHelper(middle_node, c0_block_cfg, n_block_map, false); - final_schedule = middle_node.get_schedule(); - } - return final_schedule; } -bool MappingOuterBand::NeedAtomicAdd(const isl::schedule_node_band &band, size_t n_block_map) { - if (!scop_info_.user_config_.GetEnableAkgReduceLib()) { - return false; - } - - auto non_coin_start_idx = CountConsecutiveCoincident(band); - bool is_all_reduce = - band.n_member() == 1 && scop_info_.analysis_result_.GetReduceDirection() == X_DIRECTION && non_coin_start_idx == 1; - if (is_all_reduce) { - non_coin_start_idx = 0; // Compare block size of position 0 to enable atomic add for all reduce ops - } - if (n_block_map < non_coin_start_idx) { - return false; - } - - auto block_cfg = scop_info_.user_config_.GetBlockConfig(); - CHECK(block_cfg != nullptr) << "block config is null"; - while (non_coin_start_idx < block_cfg->bound) { - auto idx = block_cfg->bound - non_coin_start_idx - 1; - if (block_cfg->GetAt(idx).second > 1) { - return true; - } - ++non_coin_start_idx; - } - return false; -} - -void MappingOuterBand::MarkAtomicAddTensor(const isl::schedule_node_band &band) { - auto target_stmt = scop_info_.analysis_result_.GetReduceWriteStmt(band); - auto tensor = target_stmt.range(); - std::unordered_set stmt_ids; - target_stmt.foreach_map( - [this, &stmt_ids](const isl::map m) { stmt_ids.insert(m.get_tuple_id(isl_dim_type::isl_dim_in)); }); - tensor.foreach_set([this, &stmt_ids](const isl::set &s) -> void { - for (auto it : scop_info_.analysis_result_.GetReduceTensorInfoMap()) { - auto provide = static_cast(it.second.stmt_node); - if (stmt_ids.count(it.first) == 0 || provide->func->func_name() != s.get_tuple_name()) { - continue; - } - auto type = scop_info_.analysis_result_.GetReduceOpType(it.first); - scop_info_.analysis_result_.RecordAtomicTensors(AtomicInfo{s.get_tuple_name(), type}); - } - }); -} - 
isl::schedule MappingOuterBand::Run(isl::schedule sch) { - auto node = sch.root().child(0); - node = InsertContextNode(node, scop_info_); - sch = node.schedule(); + sch = InsertContextNode(sch, scop_info_); if (scop_info_.user_config_.GetEnableAkgReduceLib()) { - sch = DetectAndMarkReduce(sch); + ReduceMappingStrategy reduce_op(pass_info_, scop_info_); + sch = reduce_op.DetectAndMarkReduce(sch); } sch = DoThreadMapping(sch); sch = DoBlockMapping(sch); + + if (scop_info_.user_config_.GetEnableConvTensorCore()) { + ConvMappingStrategy conv_op(pass_info_, scop_info_); + sch = conv_op.MoveKernelHWBand(sch); + } return sch; } diff --git a/src/poly/schedule_pass_gpu/mapping_outer_band.h b/src/poly/schedule_pass_gpu/mapping_outer_band.h index 49a2d3e286cf10abfde7100a099b8ec36a4332b7..85354389ef290c2c574293670e73b56623fecc79 100644 --- a/src/poly/schedule_pass_gpu/mapping_outer_band.h +++ b/src/poly/schedule_pass_gpu/mapping_outer_band.h @@ -38,24 +38,23 @@ class MappingOuterBand : public SchedulePass { virtual isl::schedule Run(isl::schedule sch); + isl::schedule DoThreadMapping(const isl::schedule &sch); + isl::schedule DoBlockMapping(const isl::schedule &sch); - std::pair GetC1C0BlockConfig(size_t n_block_map, int member_size); - bool NeedAtomicAdd(const isl::schedule_node_band &band, size_t n_block_map); - void MarkAtomicAddTensor(const isl::schedule_node_band &band); - isl::schedule_node MapBlockHelper(const isl::schedule_node &node, MappingCfg *block_cfg, size_t n_block_map, - bool check_extent, std::unordered_map map_idx_shift = {}); - isl::schedule DoThreadMapping(const isl::schedule &sch); - size_t MapThreadHelper(isl::schedule_node &thread_root); size_t NumMappedDescendant(const RoadMap &thread_roadmap, const isl::schedule_node parent); bool CanBeMappedToThread(const isl::schedule_node node, const RoadMap &thread_record); isl::schedule_node FillRemainingThreads(isl::schedule_node &node, size_t begin); - size_t CountConsecutiveCoincident(const 
isl::schedule_node_band &band_node); + isl::schedule_node MapSequenceNode(const isl::schedule_node &orig_node, const RoadMap &thread_record); - isl::schedule_node DoThreadSynchronization(const isl::schedule_node &node); + /* + * Functions related to synchronization. + */ + isl::schedule_node DoThreadSynchronization(const isl::schedule_node &node, + const std::vector other_mapping_cfg = {}); // preparation for synchronization isl::multi_union_pw_aff MapDomainToWarp(const isl::schedule_node &nod, MappingCfg *mapping_cfg, @@ -69,12 +68,6 @@ class MappingOuterBand : public SchedulePass { SyncCandidate *CountSyncNumberAmongLoop(SyncCandidate *head); int GetBestSyncStartPoint(bool is_outer); - size_t GetReduceId() const; - std::string GetMarkerName(const isl::schedule_node &node, std::string find_name); - isl::schedule DetectAndMarkReduce(const isl::schedule &sch); - isl::schedule InsertReduceMarker(const isl::schedule &sch); - isl::schedule_node InsertReduceExtension(const isl::schedule_node &node); - private: PassInfo &pass_info_; ScopInfo &scop_info_; diff --git a/src/poly/schedule_pass_gpu/operator_mapping_strategy.cc b/src/poly/schedule_pass_gpu/operator_mapping_strategy.cc new file mode 100644 index 0000000000000000000000000000000000000000..7af7ff820a3d95378102d2960c33afd001342695 --- /dev/null +++ b/src/poly/schedule_pass_gpu/operator_mapping_strategy.cc @@ -0,0 +1,446 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "operator_mapping_strategy.h" + +#include + +#include "poly/schedule_tree_util.h" +#include "poly/sync_manager.h" +#include "poly/scop.h" + +namespace akg { +namespace ir { +namespace poly { + +size_t OperatorMappingStrategy::GetFinalMappingThreadNumber(isl::schedule_node &node, const size_t thread_cfg_bound, + const size_t n_thread_map) { + auto final_n_thread_map = n_thread_map; + isl::schedule_node_band band_node = node.as(); + // Split band node according to mapping config and coincidence of band node. + if (final_n_thread_map > thread_cfg_bound) { + node = band_node.split(final_n_thread_map - thread_cfg_bound); + node = node.child(0); + final_n_thread_map = thread_cfg_bound; + band_node = node.as(); + } + + // Split to keep nodes with coincident equals to 1. + if (final_n_thread_map < band_node.n_member() && !scop_info_.user_config_.EnableStitchFusion()) { + node = band_node.split(final_n_thread_map); + } else { + final_n_thread_map = static_cast(band_node.n_member()); + } + return final_n_thread_map; +} + +size_t OperatorMappingStrategy::MapThreadHelper(isl::schedule_node &thread_root) { + auto thread_cfg = scop_info_.user_config_.GetThreadConfig(); + CHECK(thread_cfg != nullptr) << "thread config is null"; + if (thread_cfg->bound < 1 || !thread_root.isa()) { + return 0; + } + + int start_node_depth = thread_root.get_tree_depth(); + // Determine max num dimension of threads that can be mapped. + auto n_thread_map = CountConsecutiveCoincident(thread_root); + if (n_thread_map < 1) { + return 0; + } + n_thread_map = GetFinalMappingThreadNumber(thread_root, thread_cfg->bound, n_thread_map); + + // Map band under thread_root from inner dim to outer dim. + Mapping mapping; + thread_root = MapInnerDimToThreads(thread_root, false, thread_cfg, mapping, false); + auto tile_node = GetMarkerName(thread_root, THREAD_MARKER).empty() ? 
thread_root.child(0) : thread_root; + scop_info_.upa_node_mapping_.emplace_back(std::make_pair(tile_node, mapping)); + + // Do unroll if needed. + if (scop_info_.user_config_.GetMaxUnrollLoop() != 1) { + isl::schedule_node unroll_node = thread_root.child(0); + thread_root = UnrollByMarkOptions(unroll_node, scop_info_.user_config_.GetMaxUnrollLoop()); + } + + int end_node_depth = thread_root.get_tree_depth() - start_node_depth; + thread_root = thread_root.ancestor(end_node_depth); + return thread_cfg->bound; +} + +isl::schedule_node OperatorMappingStrategy::MapBlockHelper(const isl::schedule_node &orig_node, MappingCfg *block_cfg, + size_t n_block_map, bool check_extent, + std::unordered_map map_idx_shift) { + auto node = orig_node; + auto band_node = node.as(); + if (!band_node || !band_node.permutable()) { + LOG(WARNING) << "No permutable outer band node to map block."; + return node; + } + + auto partial_schedule = band_node.get_partial_schedule(); + auto upa_list = partial_schedule.get_union_pw_aff_list(); + + if (check_extent) { + auto domain = band_node.get_schedule().get_domain(); + isl::union_pw_aff_list range_aff_list(band_node.ctx(), static_cast(upa_list.size())); + for (int i = upa_list.size() - 1; i >= 0; --i) { + auto range = upa_list.get_at(i).intersect_domain(domain); + range_aff_list = range_aff_list.add(range); + } + node = CheckMapSizeAndApplyTile(node, range_aff_list, block_cfg, false); + } + + upa_list = upa_list.drop(n_block_map, upa_list.size() - n_block_map).reverse(); + + node = node.insert_mark(isl::id(node.ctx(), BLOCK_MARKER)); + node = node.child(0); + + Mapping mapping; + node = CreateAndInsertMapFilter(node, false, upa_list, block_cfg, mapping, map_idx_shift); + scop_info_.upa_node_mapping_.emplace_back(std::make_pair(node.parent(), mapping)); + + return node; +} + +size_t ReduceMappingStrategy::MapThreadHelper(isl::schedule_node &thread_root) { + auto thread_cfg = scop_info_.user_config_.GetThreadConfig(); + CHECK(thread_cfg != 
nullptr) << "thread config is null"; + if (thread_cfg->bound < 1 || !thread_root.isa()) { + return 0; + } + + int start_node_depth = thread_root.get_tree_depth(); + // Determine max num dimension of threads that can be mapped. + auto n_thread_map = CountConsecutiveCoincident(thread_root); + + std::string reduce_marker_name = ""; + if (thread_root.has_parent()) { + reduce_marker_name = GetMarkerName(thread_root.parent(), REDUCE_MARKER); + if (!reduce_marker_name.empty()) { + thread_root = thread_root.parent().del(); + ++n_thread_map; + } + } + + // When akg reduce lib is enabled, we can try to map other injective statements whose coincidence equals 0 + if (n_thread_map < thread_cfg->bound && scop_info_.user_config_.GetEnableAkgReduceLib()) { + n_thread_map = thread_cfg->bound; + } + + if (n_thread_map < 1) { + return 0; + } + n_thread_map = GetFinalMappingThreadNumber(thread_root, thread_cfg->bound, n_thread_map); + + // Map band under thread_root from inner dim to outer dim. + Mapping mapping; + bool is_y_reduce = scop_info_.analysis_result_.GetReduceDirection() == Y_DIRECTION; + thread_root = MapInnerDimToThreads(thread_root, false, thread_cfg, mapping, is_y_reduce); + + // If the current band is split during the mapping process, split the reduce axis and non-reduce axis of + // the outer band. + bool is_tiled = GetMarkerName(thread_root, THREAD_MARKER).empty(); + if (is_tiled && n_thread_map > 1) { + isl::schedule_node_band band_node = thread_root.as(); + thread_root = band_node.split(n_thread_map - 1).child(0); + } + thread_root = thread_root.insert_mark(reduce_marker_name); + thread_root = thread_root.child(0); + // Add the filter that initializes and calls the akg_reduce library for the reduce statement. + thread_root = InsertReduceExtension(thread_root); + // The band corresponding to the reduce statement has a REDUCE_MARKER that needs to be deleted at the beginning. 
+ int end_node_depth = thread_root.get_tree_depth() - start_node_depth + 1; + thread_root = thread_root.ancestor(end_node_depth); + scop_info_.upa_node_mapping_.emplace_back(std::make_pair(thread_root, mapping)); + return thread_cfg->bound; +} + +size_t ReduceMappingStrategy::GetReduceId() const { + static size_t reduce_count = 0; + return reduce_count++; +} + +isl::schedule_node ReduceMappingStrategy::InsertReduceExtension(const isl::schedule_node &node) { + auto thread_cfg = scop_info_.user_config_.GetThreadConfig(); + CHECK(thread_cfg != nullptr) << "thread config is null"; + + isl::schedule_node insert_node = node; + isl::schedule_node parent_node = node; + isl::schedule_node ancestor_node = node; + if (insert_node.has_parent()) { + parent_node = parent_node.parent(); + if (parent_node.has_parent()) { + ancestor_node = parent_node.parent(); + } + } + + std::string reduce_marker_name = ""; + if (!GetMarkerName(parent_node, REDUCE_MARKER).empty()) { + reduce_marker_name = GetMarkerName(parent_node, REDUCE_MARKER); + insert_node = parent_node.del(); + } + + if (!GetMarkerName(ancestor_node, REDUCE_MARKER).empty()) { + reduce_marker_name = GetMarkerName(ancestor_node, REDUCE_MARKER); + insert_node = ancestor_node.del(); + } + + if (reduce_marker_name.empty()) { + return node; + } + + reduce_marker_name.erase(0, strlen(REDUCE_MARKER)); + isl::id sync_id = isl::id(insert_node.ctx(), REDUCE_UPDATE + reduce_marker_name); + isl::id reduction_id = isl::id(insert_node.ctx(), REDUCE_INIT + reduce_marker_name); + + insert_node = InsertExtensionNodeBeforeOrAfter(insert_node, reduction_id, true); + insert_node = InsertExtensionNodeBeforeOrAfter(insert_node, sync_id, false).parent(); + insert_node = insert_node.parent().insert_mark(REDUCE_AREA_FLAG); + + return insert_node; +} + +isl::schedule ReduceMappingStrategy::DetectAndMarkReduce(const isl::schedule &sch) { + auto final_schedule = sch; + auto thread_cfg = scop_info_.user_config_.GetThreadConfig(); + CHECK(thread_cfg != 
nullptr) << "threadconfig is null"; + if (thread_cfg->bound == 0) { + return final_schedule; + } + + auto all_reduce_map = scop_info_.analysis_result_.GetReduceTensorInfoMap(); + ReduceManager reduce_manager; + bool done_separate = false; + auto GetInnerMostBand = [&done_separate, &all_reduce_map, &reduce_manager, thread_cfg, + this](isl::schedule_node node) -> isl::schedule_node { + if (done_separate) { + return node; + } + auto band_node = node.as(); + if (!band_node || !band_node.permutable()) { + return node; + } + + auto band_node_domain = band_node.get_partial_schedule().domain(); + StatementMap all_statements = scop_info_.analysis_result_.GetStatementMap(); + isl::union_map reduce_statement_map = isl::union_map::empty(node.ctx()); + isl::union_set reduce_statements = isl::union_set::empty(node.ctx()); + + for (auto it = all_reduce_map.begin(); it != all_reduce_map.end();) { + reduce_statement_map = reduce_statement_map.unite(it->second.stmt_map); + auto this_reduce = reduce_manager.GetReduceStatements(band_node_domain, reduce_statement_map, all_statements); + if (!this_reduce.is_empty()) { + reduce_statements = reduce_statements.unite(this_reduce); + all_reduce_map.erase(it++); + } else { + ++it; + } + } + + if (reduce_statements.n_set() < 1) { + return node; + } + + isl::union_map dependences = pass_info_.dependences_.subtract(pass_info_.force_dependences_); + auto node_bak = node; + if (!reduce_manager.SplitReduceStatements(node, reduce_statements, dependences, true)) { + return node_bak; + } + done_separate = all_reduce_map.empty(); + return node; + }; + final_schedule = sch.get_root().map_descendant_bottom_up(GetInnerMostBand).get_schedule(); + if (done_separate) { + final_schedule = InsertReduceMarker(final_schedule); + } + return final_schedule; +} + +isl::schedule ReduceMappingStrategy::InsertReduceMarker(const isl::schedule &sch) { + isl::schedule final_schedule = sch; + auto all_reduce_map = scop_info_.analysis_result_.GetReduceTensorInfoMap(); + 
auto InsertMarker = [&all_reduce_map, this](isl::schedule_node node) -> isl::schedule_node { + ReduceManager reduce_manager; + auto band_node = node.as(); + if (!band_node) { + return node; + } + + for (auto it = all_reduce_map.begin(); it != all_reduce_map.end();) { + isl::union_map reduce_statement_map = it->second.stmt_map; + isl::id reduce_id = it->first; + auto band_node_domain = band_node.get_partial_schedule().domain(); + auto op_type = scop_info_.analysis_result_.GetReduceOpType(reduce_id) + "_"; + + StatementMap all_statements = scop_info_.analysis_result_.GetStatementMap(); + isl::union_set reduce_statements = + reduce_manager.GetReduceStatements(band_node_domain, reduce_statement_map, all_statements); + if (reduce_statements.n_set() != 1) { + ++it; + continue; + } + + all_reduce_map.erase(it++); + std::string reduce_marker_name = + REDUCE_MARKER + op_type + reduce_id.get_name() + "_" + std::to_string(GetReduceId()); + auto reduce_node = band_node.insert_mark(reduce_marker_name); + return reduce_node; + } + return band_node; + }; + final_schedule = final_schedule.get_root().map_descendant_bottom_up(InsertMarker).get_schedule(); + return final_schedule; +} + +bool ReduceMappingStrategy::NeedAtomicAdd(const isl::schedule_node_band &band, size_t n_block_map) { + if (!scop_info_.user_config_.GetEnableAkgReduceLib()) { + return false; + } + + auto non_coin_start_idx = CountConsecutiveCoincident(band); + bool is_all_reduce = + band.n_member() == 1 && scop_info_.analysis_result_.GetReduceDirection() == X_DIRECTION && non_coin_start_idx == 1; + if (is_all_reduce) { + non_coin_start_idx = 0; // Compare block size of position 0 to enable atomic add for all reduce ops + } + if (n_block_map < non_coin_start_idx) { + return false; + } + + auto block_cfg = scop_info_.user_config_.GetBlockConfig(); + CHECK(block_cfg != nullptr) << "block config is null"; + while (non_coin_start_idx < block_cfg->bound) { + auto idx = block_cfg->bound - non_coin_start_idx - 1; + if 
(block_cfg->GetAt(idx).second > 1) { + return true; + } + ++non_coin_start_idx; + } + return false; +} + +void ReduceMappingStrategy::MarkAtomicAddTensor(const isl::schedule_node_band &band) { + auto target_stmt = scop_info_.analysis_result_.GetReduceWriteStmt(band); + auto tensor = target_stmt.range(); + std::unordered_set stmt_ids; + target_stmt.foreach_map( + [this, &stmt_ids](const isl::map m) { stmt_ids.insert(m.get_tuple_id(isl_dim_type::isl_dim_in)); }); + tensor.foreach_set([this, &stmt_ids](const isl::set &s) -> void { + for (auto it : scop_info_.analysis_result_.GetReduceTensorInfoMap()) { + auto provide = static_cast(it.second.stmt_node); + if (stmt_ids.count(it.first) == 0 || provide->func->func_name() != s.get_tuple_name()) { + continue; + } + auto type = scop_info_.analysis_result_.GetReduceOpType(it.first); + scop_info_.analysis_result_.RecordAtomicTensors(AtomicInfo{s.get_tuple_name(), type}); + } + }); +} + +size_t BatchMatmulMappingStrategy::MapThreadHelper(isl::schedule_node &thread_root) { + auto warp_cfg = scop_info_.user_config_.GetReplaceConfig()[WARP_COMPUTE]; + CHECK(warp_cfg != nullptr) << "warp config is null"; + if (warp_cfg->bound < 1 || !thread_root.isa()) { + return 0; + } + + int start_node_depth = thread_root.get_tree_depth(); + // Determine max num dimension of threads that can be mapped. + auto n_thread_map = CountConsecutiveCoincident(thread_root); + if (n_thread_map < 1) { + return 0; + } + n_thread_map = GetFinalMappingThreadNumber(thread_root, warp_cfg->bound, n_thread_map); + + // Map band under thread_root from inner dim to outer dim. + Mapping mapping; + thread_root = MapInnerDimToThreads(thread_root, false, warp_cfg, mapping, true); + bool is_tiled = GetMarkerName(thread_root, THREAD_MARKER).empty(); + thread_root = is_tiled ? 
thread_root.child(0) : thread_root; + thread_root = thread_root.del().insert_mark(isl::id(thread_root.ctx(), WARP_MARKER)); + + int end_node_depth = thread_root.get_tree_depth() - start_node_depth; + thread_root = thread_root.ancestor(end_node_depth); + scop_info_.upa_node_mapping_.emplace_back(std::make_pair(thread_root, mapping)); + return warp_cfg->bound; +} + +isl::schedule_node ConvMappingStrategy::ResetConvBlockMappingConfig(const isl::schedule_node &orig_node, + MappingCfg *block_cfg, const bool check_extent) { + if (!orig_node.isa()) { + return orig_node; + } + const unsigned outer_band_axis_size = 4; + auto node = orig_node; + CHECK_GE(node.as().n_member(), outer_band_axis_size); + + // For the convolution operator, n axis is mapped to blockIdx.z, h axis and w axis are mapped to blockIdx.y, o axis + // is mapped to blockIdx.x, + node = node.as().split(1); + auto new_cfg = std::to_string(block_cfg->GetZ().second); + scop_info_.user_config_.RecordReplaceConfig(CONV_N, new_cfg, MappingType::REPLACE_BLOCKS); + auto conv_o_block_cfg = scop_info_.user_config_.GetReplaceConfig()[CONV_N]; + node = MapBlockHelper(node, conv_o_block_cfg, 1, check_extent); + + node = node.child(0).child(0).as().split(2); + auto partial_schedule = node.as().get_partial_schedule(); + partial_schedule = partial_schedule.intersect_domain(node.get_domain()); + auto upa_list = partial_schedule.get_union_pw_aff_list(); + auto extent_h = upa_list.get_at(0).floor().max_val().get_num_si() + 1; + auto bind_block_h = std::min(static_cast(extent_h), block_cfg->GetY().second); + new_cfg = std::to_string(block_cfg->GetY().second / bind_block_h) + " " + std::to_string(bind_block_h); + scop_info_.user_config_.RecordReplaceConfig(CONV_H_W, new_cfg, MappingType::REPLACE_BLOCKS); + auto conv_h_w_block_cfg = scop_info_.user_config_.GetReplaceConfig()[CONV_H_W]; + node = MapBlockHelper(node, conv_h_w_block_cfg, 2, check_extent); + + node = node.child(0).child(0); + new_cfg = 
std::to_string(block_cfg->GetX().second); + scop_info_.user_config_.RecordReplaceConfig(CONV_O, new_cfg, MappingType::REPLACE_BLOCKS); + auto conv_n_block_cfg = scop_info_.user_config_.GetReplaceConfig()[CONV_O]; + node = MapBlockHelper(node, conv_n_block_cfg, 1, check_extent); + return node; +} + +isl::schedule ConvMappingStrategy::MoveKernelHWBand(isl::schedule sch) { + auto node = sch.root(); + isl::multi_union_pw_aff kh_mupa = isl::multi_union_pw_aff::zero(node.get_domain().get_space().set_from_params()); + isl::multi_union_pw_aff kw_mupa = kh_mupa; + auto MapFromInner = [this, &kh_mupa, &kw_mupa](isl::schedule_node node) -> isl::schedule_node { + if (!GetMarkerName(node, KH_KW_MARKER).empty()) { + node = node.child(0); + kh_mupa = node.as().get_partial_schedule(); + node = node.del(); + kw_mupa = node.as().get_partial_schedule(); + node = node.del(); + node = node.parent().del(); + return node; + } + if (!GetMarkerName(node, PROMOTE_GLOBAL_TO_SHARED_AB).empty()) { + node = node.insert_mark(CONV_KHKW_OUTER).child(0); + node = node.insert_partial_schedule(kw_mupa); + node = node.as().set_permutable(1); + node = node.insert_partial_schedule(kh_mupa); + node = node.as().set_permutable(1); + return node; + } + return node; + }; + sch = sch.get_root().map_descendant_bottom_up(MapFromInner).get_schedule(); + return sch; +} + +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/schedule_pass_gpu/operator_mapping_strategy.h b/src/poly/schedule_pass_gpu/operator_mapping_strategy.h new file mode 100644 index 0000000000000000000000000000000000000000..d5808589e4dc4c08b7f415c4ed165e35dbad1bac --- /dev/null +++ b/src/poly/schedule_pass_gpu/operator_mapping_strategy.h @@ -0,0 +1,85 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef POLY_OPRATOR_MAPPING_STRATEGY_H_ +#define POLY_OPRATOR_MAPPING_STRATEGY_H_ + +#include "poly/schedule_pass.h" +#include "poly/reduce_manager.h" + +namespace akg { +namespace ir { +namespace poly { + +class OperatorMappingStrategy { + public: + explicit OperatorMappingStrategy(PassInfo &pass_info, ScopInfo &scop_info) + : pass_info_(pass_info), scop_info_(scop_info) {} + ~OperatorMappingStrategy() {} + + size_t GetFinalMappingThreadNumber(isl::schedule_node &node, const size_t thread_cfg_bound, + const size_t n_thread_map); + virtual size_t MapThreadHelper(isl::schedule_node &thread_root); + virtual isl::schedule_node MapBlockHelper(const isl::schedule_node &orig_node, MappingCfg *block_cfg, + size_t n_block_map, bool check_extent, + std::unordered_map map_idx_shift = {}); + + protected: + PassInfo &pass_info_; + ScopInfo &scop_info_; +}; + +class ReduceMappingStrategy : public OperatorMappingStrategy { + public: + explicit ReduceMappingStrategy(PassInfo &pass_info, ScopInfo &scop_info) + : OperatorMappingStrategy(pass_info, scop_info) {} + ~ReduceMappingStrategy() {} + + size_t MapThreadHelper(isl::schedule_node &thread_root); + + bool NeedAtomicAdd(const isl::schedule_node_band &band, size_t n_block_map); + void MarkAtomicAddTensor(const isl::schedule_node_band &band); + size_t GetReduceId() const; + isl::schedule DetectAndMarkReduce(const isl::schedule &sch); + isl::schedule InsertReduceMarker(const isl::schedule &sch); + isl::schedule_node InsertReduceExtension(const isl::schedule_node &node); +}; + +class 
BatchMatmulMappingStrategy : public OperatorMappingStrategy { + public: + explicit BatchMatmulMappingStrategy(PassInfo &pass_info, ScopInfo &scop_info) + : OperatorMappingStrategy(pass_info, scop_info) {} + ~BatchMatmulMappingStrategy() {} + + size_t MapThreadHelper(isl::schedule_node &thread_root); +}; + +class ConvMappingStrategy : public OperatorMappingStrategy { + public: + explicit ConvMappingStrategy(PassInfo &pass_info, ScopInfo &scop_info) + : OperatorMappingStrategy(pass_info, scop_info) {} + ~ConvMappingStrategy() {} + + isl::schedule_node ResetConvBlockMappingConfig(const isl::schedule_node &orig_node, MappingCfg *block_cfg, + const bool check_extent); + isl::schedule MoveKernelHWBand(isl::schedule sch); +}; + +} // namespace poly +} // namespace ir +} // namespace akg + +#endif // POLY_OPRATOR_MAPPING_STRATEGY_H_ \ No newline at end of file diff --git a/src/poly/schedule_pass_gpu/realize_manager.cc b/src/poly/schedule_pass_gpu/realize_manager.cc index 18f1373a33b5ba5dd157085dbce1972a05b49b39..ebd58833f2f480e10280e463313f104f772b274d 100644 --- a/src/poly/schedule_pass_gpu/realize_manager.cc +++ b/src/poly/schedule_pass_gpu/realize_manager.cc @@ -29,7 +29,8 @@ isl::id RealizeManager::GetRealizeId(const isl::schedule_node &node, std::string return isl::id(node.ctx(), realize_id); } -isl::schedule_node RealizeManager::InsertExtensionNodeBefore(const isl::schedule_node &node, const std::string tensor_name) { +isl::schedule_node RealizeManager::InsertExtensionNodeBefore(const isl::schedule_node &node, + const std::string tensor_name) { auto space = GetExtensionSpace(node, tensor_name); isl::schedule_node graft = isl::schedule_node::from_extension(space); auto extension_node = node; @@ -49,6 +50,8 @@ isl::map RealizeManager::GetExtensionSpace(const isl::schedule_node &node, const isl::schedule_node RealizeManager::BreadthFirstTopDown(const isl::schedule_node &root, bool &end) { std::queue bfs_queue; bfs_queue.push(root); + std::unordered_set 
promotion_read_set = {READ_ID_NAME, SHARED_READ_ID_NAME, GML_READ_ID_NAME}; + std::unordered_set promotion_write_set = {WRITE_ID_NAME, SHARED_WRITE_ID_NAME, GML_WRITE_ID_NAME}; isl::schedule_node top; while (!bfs_queue.empty()) { @@ -65,8 +68,8 @@ isl::schedule_node RealizeManager::BreadthFirstTopDown(const isl::schedule_node } auto filter_node = top.as(); std::string filter_name = GetFilterName(filter_node); - if (filter_name != READ_ID_NAME && filter_name != WRITE_ID_NAME && - filter_name != SHARED_READ_ID_NAME && filter_name != SHARED_WRITE_ID_NAME) { + if (promotion_read_set.find(filter_name) == promotion_read_set.end() && + promotion_write_set.find(filter_name) == promotion_write_set.end()) { continue; } std::string tensor_name = GetTensorName(filter_node); @@ -74,19 +77,19 @@ isl::schedule_node RealizeManager::BreadthFirstTopDown(const isl::schedule_node continue; } // Insert realize node for read node - if (filter_name == READ_ID_NAME || filter_name == SHARED_READ_ID_NAME) { + if (promotion_read_set.find(filter_name) != promotion_read_set.end()) { top = InsertExtensionNodeBefore(top.child(0), tensor_name).parent(); names_set_.insert(tensor_name); break; } // Insert realize node for write node - if (filter_name == WRITE_ID_NAME || filter_name == SHARED_WRITE_ID_NAME) { + if (promotion_write_set.find(filter_name) != promotion_write_set.end()) { size_t i = 0; auto top_parent = top.parent(); for (; i < top_parent.n_children(); ++i) { auto tmp_name = GetFilterName(top_parent.child(i).as()); - if ((tmp_name != READ_ID_NAME) && (tmp_name != WRITE_ID_NAME) && - (tmp_name != SHARED_READ_ID_NAME) && (tmp_name != SHARED_WRITE_ID_NAME)) { + if (promotion_read_set.find(tmp_name) == promotion_read_set.end() && + promotion_write_set.find(tmp_name) == promotion_write_set.end()) { break; } } @@ -107,9 +110,7 @@ std::string RealizeManager::GetFilterName(const isl::schedule_node_filter &filte if (filter_node) { isl::union_set uset = filter_node.get_filter(); std::vector vset; 
- uset.foreach_set([&vset](isl::set s) { - vset.push_back(s); - }); + uset.foreach_set([&vset](isl::set s) { vset.push_back(s); }); if (!vset.empty()) { filter_name = vset[0].get_tuple_name(); } @@ -122,9 +123,7 @@ std::string RealizeManager::GetTensorName(const isl::schedule_node_filter &filte if (filter_node) { isl::union_set uset = filter_node.get_filter(); std::vector vset; - uset.foreach_set([&vset](isl::set s) { - vset.push_back(s); - }); + uset.foreach_set([&vset](isl::set s) { vset.push_back(s); }); if (!vset.empty()) { tensor_name = vset[0].unwrap().get_tuple_id(isl_dim_out).get_name(); } @@ -142,10 +141,12 @@ isl::schedule_node RealizeManager::InsertRealize(const isl::schedule_node &root) while (!end) { res_root = BreadthFirstTopDown(res_root, end); } + return res_root; } isl::schedule RealizeManager::Run(isl::schedule sch) { + sch = scop_info_.sync_manager_.InsertPromotionSync(sch); auto root = sch.get_root(); auto res_root = InsertRealize(root); names_set_.clear(); diff --git a/src/poly/schedule_pass_gpu/realize_manager.h b/src/poly/schedule_pass_gpu/realize_manager.h index 02bc1660708b4e1f9c9e05f8a020afb5edcfb76e..6af62b339c2963f4ae825cc6e3375ce4253b21f7 100644 --- a/src/poly/schedule_pass_gpu/realize_manager.h +++ b/src/poly/schedule_pass_gpu/realize_manager.h @@ -25,7 +25,9 @@ namespace poly { class RealizeManager : public SchedulePass { public: - explicit RealizeManager() { pass_name_ = __FUNCTION__; } + explicit RealizeManager(PassInfo &pass_info, ScopInfo &scop_info) : pass_info_(pass_info), scop_info_(scop_info) { + pass_name_ = __FUNCTION__; + }; ~RealizeManager() {} virtual isl::schedule Run(isl::schedule sch); @@ -33,6 +35,8 @@ class RealizeManager : public SchedulePass { isl::schedule_node InsertRealize(const isl::schedule_node &root); private: + PassInfo &pass_info_; + ScopInfo &scop_info_; std::set names_set_{}; isl::id GetRealizeId(const isl::schedule_node &node, std::string tensor_name) const; diff --git 
a/src/poly/schedule_pass_gpu/register_memory_manager.cc b/src/poly/schedule_pass_gpu/register_memory_manager.cc index 3592352e18335e498fe1c7f1745a3b64e1ebc14d..d8c22f9ea148cadc1f454cf20f5361db5ad76286 100644 --- a/src/poly/schedule_pass_gpu/register_memory_manager.cc +++ b/src/poly/schedule_pass_gpu/register_memory_manager.cc @@ -20,74 +20,38 @@ #include "poly/scop.h" #include "poly/dma_inject.h" -#include "poly/schedule_tree_util.h" #include "poly/poly_util.h" namespace akg { namespace ir { namespace poly { -isl::union_set RegisterMemoryManager::GatherMappingsTo(MappingCfg *cfg) { - isl::schedule_node root = schedule_.get_root(); - auto domain_node = root.as(); - auto domain = domain_node.domain(); - auto mapping_filters = CollectNode(schedule_); - - std::vector filters; - for (size_t idx = 0; idx < cfg->bound; ++idx) { - auto value = cfg->GetAt(idx); - auto id = isl::id(root.ctx(), value.first); - filters.push_back(id); - } - mapping_filters = FilterNode(mapping_filters, filters); - - auto mapping = isl::union_set::empty(domain.ctx()); - for (auto item : mapping_filters) { - if (item.isa()) { - auto filter = item.as(); - if (filter.has_parent() && !filter.parent().isa()) { - continue; - } - - isl::union_set uset = filter.get_filter(); - std::vector vset; - uset.foreach_set([&vset](isl::set s) { vset.push_back(s); }); - if (!vset.empty()) { - auto filter_name = vset[0].get_tuple_name(); - if (filter_name == READ_ID_NAME || filter_name == WRITE_ID_NAME) { - continue; - } - } - - mapping = mapping.unite(filter.filter()); - } - } - return mapping; -} - -void RegisterMemoryManager::SharedTensors() { +void RegisterMemoryManager::GetActualPromotedSharedTensors() { for (const auto &buffer : scop_info_.analysis_result_.active_buffer_footprints_) { auto cluster_id = buffer.second.cluster_id; - auto buf_def = scop_info_.analysis_result_.GetBufferDefInfo(cluster_id); - shared_tensors_ += buf_def.tensor_id.name() + " "; + shared_tensors_ += cluster_id.name() + " "; } } 
isl::schedule RegisterMemoryManager::HoistRegisterMemoryOnDepth(isl::schedule_node &node, size_t depth) { - auto block_cfg = scop_info_.user_config_.GetBlockConfig(); auto res_node = node; - auto block_mapping = GatherMappingsTo(block_cfg); - auto thread_cfg = scop_info_.user_config_.GetThreadConfig(); - auto mapping = GatherMappingsTo(thread_cfg).intersect(block_mapping); + isl::schedule_node root_node = node.get_schedule().get_root(); - auto partial_sched = LocalSchedule(node); - auto tmp_sched = partial_sched.intersect_domain(mapping); - if (scop_info_.user_config_.GetEnableMatmul() && scop_info_.user_config_.GetEnableTensorCoreUsePoly()) { - tmp_sched = partial_sched; + auto block_cfg = scop_info_.user_config_.GetBlockConfig(); + CHECK(block_cfg != nullptr) << "block config is null"; + auto replace_cfg = scop_info_.user_config_.GetReplaceConfig(); + auto block_mapping = GetBlockMappingFilterInfo(root_node, block_cfg, replace_cfg); + + auto thread_cfg = scop_info_.user_config_.GetThreadConfig(); + if (scop_info_.user_config_.GetEnableTensorCoreUsePoly()) { + thread_cfg = replace_cfg[WARP_COMPUTE]; } - CreateTensorCluster(node, tmp_sched); + CHECK(thread_cfg != nullptr) << "thread config is null"; + auto mapping = GatherMappingsTo(root_node, thread_cfg).intersect(block_mapping); - isl::schedule_node root_node = node.get_schedule().get_root(); + auto partial_sched = LocalSchedule(node); + partial_sched = partial_sched.intersect_domain(mapping); + CreateTensorCluster(node, partial_sched); isl::schedule sch = schedule_; if (memory_exceeding_) { @@ -115,11 +79,11 @@ isl::schedule RegisterMemoryManager::HoistRegisterMemoryOnDepth(isl::schedule_no if (scop_info_.user_config_.GetEnableMatmul() && !hoist_tensor_all_) { if (!hoist_compute_local_tensor_) { - if (buffer_info.dst_tensor_id.get_name() == local_tensor_c_ + LOCAL_SUFFIX) { + if (!IsTensorAB(buffer_info.dst_tensor_id.get_name(), scop_info_)) { continue; } } else { - if (buffer_info.dst_tensor_id.get_name() != 
local_tensor_c_ + LOCAL_SUFFIX) { + if (IsTensorAB(buffer_info.dst_tensor_id.get_name(), scop_info_)) { continue; } } @@ -185,18 +149,6 @@ isl::schedule RegisterMemoryManager::HoistRegisterMemoryOnDepth(isl::schedule_no } } -bool RegisterMemoryManager::UnrolledLoop(const TensorFootprintCluster &fp_cluster) { - auto box_sizes = fp_cluster.GetFixedBoxSizes(); - size_t tmp_size = 1; - for (auto size : box_sizes) { - tmp_size = tmp_size * size; - } - if (tmp_size != 1) { - return true; - } - return false; -} - /*Check if the given "group" can be promoted to registers for the given * mapping to thread identifiers and within the given outer schedule */ bool RegisterMemoryManager::IsPromote(const TensorFootprintCluster &fp_cluster, @@ -211,15 +163,6 @@ bool RegisterMemoryManager::IsPromote(const TensorFootprintCluster &fp_cluster, return thread_schedule_mapping.is_injective(); } -/* Check that whether the mapping relation between instance statement - * and outer schedule points and tensor elements pair is reusable. 
*/ -bool RegisterMemoryManager::ReuseTensorCluster(const TensorFootprintCluster &cluster, - const isl::multi_union_pw_aff &outer_pw_aff) { - /* compute the mapping relation between statement instance and outer schedule space and tensor elements pair */ - isl::union_map state_schedule_mapping = ScheduleTensorMapping(outer_pw_aff, cluster.OrigianlAccessRelations()); - return !state_schedule_mapping.is_injective(); -} - void RegisterMemoryManager::CreateTensorCluster(const isl::schedule_node &node, const isl::union_map &outer_sch) { isl::union_map reads = scop_info_.analysis_result_.GetReads(); isl::union_map writes = scop_info_.analysis_result_.GetWrites(); @@ -271,23 +214,14 @@ void RegisterMemoryManager::CreateTensorCluster(const isl::schedule_node &node, std::vector promoted_infos; - if (scop_info_.user_config_.GetEnableMatmul()) { - std::unordered_map matmul_map = scop_info_.analysis_result_.GetMatrixMatmulMap(); - for (auto i : matmul_map) { - if (i.second == MATRIX_C) { - local_tensor_c_ = i.first; - } - } - } - for (const auto &item : tensor_list) { if (scop_info_.user_config_.GetEnableMatmul() && !hoist_tensor_all_) { if (!hoist_compute_local_tensor_) { - if (item.get_name() == local_tensor_c_) { + if (!IsTensorAB(item.get_name(), scop_info_)) { continue; } } else { - if (item.get_name() != local_tensor_c_) { + if (IsTensorAB(item.get_name(), scop_info_)) { continue; } } @@ -354,7 +288,11 @@ void RegisterMemoryManager::IsOutofMemory(std::vector promoted_in auto box_sizes = promoted_info.footprints_cluster->GetFixedBoxSizes(); if (!box_sizes.empty()) { auto tensor_size = std::accumulate(box_sizes.begin(), box_sizes.end(), 1, std::multiplies()); - auto data_bytes = scop_info_.user_config_.GetDataType(promoted_info.tensor_id.get_name()); + if (scop_info_.user_config_.GetEnableTensorCoreUsePoly()) { + tensor_size = (promoted_info.tensor_id.get_name() == local_tensor_c_) ? 
(tensor_size / alloc_threads) + : (tensor_size * 2 / alloc_threads); + } + auto data_bytes = scop_info_.user_config_.GetDataBytes(promoted_info.tensor_id.get_name()); total_alloc_size += tensor_size * std::max(1, data_bytes / BYTES_PER_REGISTER); if (total_alloc_size * alloc_threads >= MAX_REGISTER_PER_THREAD_BLOCK * REGISTER_ALLOC_RATIO) { memory_exceeding_ = true; @@ -459,7 +397,8 @@ bool RegisterMemoryManager::IsReadOrWriteBand(isl::schedule_node node) { isl::schedule_node RegisterMemoryManager::GetRegisterPromotedNode(isl::schedule_node &root) { isl::schedule_node hoist_register_node = root; root.foreach_descendant_top_down([&hoist_register_node](const isl::schedule_node &node) -> bool { - if (auto sequence_node = node.as()) { + if (node.isa()) { + auto sequence_node = node.as(); if (sequence_node.parent().isa() && sequence_node.parent().parent().isa()) { hoist_register_node = sequence_node.parent().parent(); @@ -469,7 +408,8 @@ isl::schedule_node RegisterMemoryManager::GetRegisterPromotedNode(isl::schedule_ return false; } } - if (auto mark_node = node.as()) { + if (node.isa()) { + auto mark_node = node.as(); if (mark_node.get_id().get_name() == THREAD_MARKER && mark_node.parent().isa()) { hoist_register_node = mark_node.parent(); return false; @@ -480,30 +420,14 @@ isl::schedule_node RegisterMemoryManager::GetRegisterPromotedNode(isl::schedule_ return hoist_register_node; } -isl::schedule_node RegisterMemoryManager::CollectMarkNode(isl::schedule_node root, - const std::string local_position_mark) { - isl::schedule_node hoist_node; - root.foreach_descendant_top_down([&hoist_node, &local_position_mark](const isl::schedule_node &node) -> bool { - if (auto mark_node = node.as()) { - // ignore nested mark nodes - if (mark_node.get_id().get_name() == local_position_mark) { - hoist_node = mark_node; - return false; - } - } - return true; - }); - return hoist_node; -} - isl::schedule RegisterMemoryManager::HoistRegisterMemoryOnMark(isl::schedule_node root) { 
std::string config_shared_tensors = scop_info_.user_config_.GetSharedTensors(); auto c_mark = PROMOTE_GLOBAL_TO_REGISTER_C; if (config_shared_tensors.find(local_tensor_c_) != std::string::npos) { - c_mark = PROMOTE_GLOBAL_TO_SHARED_C; + c_mark = PROMOTE_SHARED_TO_REGISTER_C; } - auto mark_node = CollectMarkNode(root, c_mark); + auto mark_node = CollectMarkNodeOnPromotion(root, c_mark); auto tmp_hoist_node = mark_node.parent(); while (!tmp_hoist_node.isa()) { @@ -527,8 +451,8 @@ isl::schedule RegisterMemoryManager::HoistRegisterMemoryOnMark(isl::schedule_nod auto sch = HoistRegisterMemoryOnDepth(hoist_compute_node, depth); auto hoist_ab_root = sch.get_root(); - auto ab_mark = PROMOTE_SHARED_TO_REGISTER; - auto mark_ab_node = CollectMarkNode(hoist_ab_root, ab_mark); + auto ab_mark = PROMOTE_SHARED_TO_REGISTER_AB; + auto mark_ab_node = CollectMarkNodeOnPromotion(hoist_ab_root, ab_mark); auto hoist_ab_node = mark_ab_node.del().parent(); auto hoist_ab_depth = hoist_ab_node.schedule_depth(); hoist_compute_local_tensor_ = false; @@ -537,12 +461,19 @@ isl::schedule RegisterMemoryManager::HoistRegisterMemoryOnMark(isl::schedule_nod return sch; } -isl::schedule_node RegisterMemoryManager::MapPromotionTensorToWarps(isl::schedule_node &root) { - std::string write_name = WRITE_ID_NAME; +std::string RegisterMemoryManager::GetPromotedWriteName() { + std::string write_name = GML_WRITE_ID_NAME; std::string shared_tensors = shared_tensors_; if (shared_tensors.find(local_tensor_c_) != std::string::npos) { write_name = SHARED_WRITE_ID_NAME; } + return write_name; +} + +// According to the value of the conv interface, the size of the tensor is split to confirm the size of the fragment. 
+isl::schedule_node RegisterMemoryManager::TileTensorAccordingInterfaceValue(isl::schedule_node &root) { + CHECK(scop_info_.user_config_.GetReplaceConfig().count(WARP_COMPUTE)) << "Cannot map to warp."; + std::string write_name = GetPromotedWriteName(); auto CollectReadWriteFilter = [this, write_name](isl::schedule_node node) -> isl::schedule_node { if (!node.isa()) { return node; @@ -551,75 +482,36 @@ isl::schedule_node RegisterMemoryManager::MapPromotionTensorToWarps(isl::schedul if (!is_all_sets_read_or_write) { return node; } - auto band_node = GetCanMappingNode(node); - CHECK(scop_info_.user_config_.GetReplaceConfig().count(WARP_COMPUTE)) << "Cannot map to warp."; - auto mapping_cfg = scop_info_.user_config_.GetReplaceConfig()[WARP_COMPUTE]; - auto original_x = mapping_cfg->GetX().second; - auto original_y = mapping_cfg->GetY().second; + auto start_depth = node.get_tree_depth(); + + auto band_node = GetCanMappingNode(node); std::string id_name = GetPromotionTensorName(band_node, scop_info_.analysis_result_.buffer_def_infos_); if (id_name.empty() || !scop_info_.analysis_result_.GetMatrixMatmulMap().count(id_name) || !scop_info_.analysis_result_.GetMatrixMatmulMajor().count(id_name)) { - return band_node; + return node; } + bool is_conv = scop_info_.user_config_.GetEnableConvTensorCore(); + if (is_conv) { + band_node = AdjustConvScheduleTreeStructure(band_node); + } + + auto mapping_cfg = scop_info_.user_config_.GetReplaceConfig()[WARP_COMPUTE]; + CHECK(mapping_cfg != nullptr) << "mapping config is null"; // split member that does not involved in thread mapping - bool has_split = false; auto mem_size = band_node.as().n_member(); if (mem_size > mapping_cfg->bound) { band_node = band_node.as().split(mem_size - mapping_cfg->bound); band_node = band_node.child(0); - has_split = true; } - auto matrix_name = scop_info_.analysis_result_.GetMatrixMatmulMap()[id_name]; - auto matrix_major = scop_info_.analysis_result_.GetMatrixMatmulMajor()[id_name]; + std::string 
matrix_name = scop_info_.analysis_result_.GetMatrixMatmulMap()[id_name]; + std::string matrix_major = scop_info_.analysis_result_.GetMatrixMatmulMajor()[id_name]; isl::multi_val tile_size_val = GetRealTileSizeVal(band_node, matrix_name, matrix_major); band_node = TileBand(band_node, tile_size_val); - // In order to ensure that the data when promotion and calculation are consistent, map the m axis of MATRIX_A to - // w0, and map the n axis of MATRIX_B to w1. - bool need_coalesce = false; - if (matrix_name == MATRIX_A) { - need_coalesce = (matrix_major == ROW_MAJOR) ? true : false; - mapping_cfg->ModifySize(N_POSITION, MAPPING_INVALID_WARP); - } else if (matrix_name == MATRIX_B) { - need_coalesce = (matrix_major == ROW_MAJOR) ? true : false; - mapping_cfg->ModifySize(M_POSITION, MAPPING_INVALID_WARP); - } else { - need_coalesce = true; - } - - Mapping mapping; - auto after_map_pair = MapInnerDimToThreads(band_node, true, mapping_cfg, mapping, need_coalesce); - band_node = after_map_pair.first; - - if (matrix_name == MATRIX_A) { - need_coalesce = true; - mapping_cfg->ModifySize(N_POSITION, original_y); - } else if (matrix_name == MATRIX_B) { - mapping_cfg->ModifySize(M_POSITION, original_x); - } - - bool locate_is_child = false; - if (band_node.child(0).as()) { - band_node = band_node.child(0); - locate_is_child = true; - } - if (band_node.as()) { - auto marker_name = band_node.as().get_id().get_name(); - if (marker_name.find(THREAD_MARKER) != std::string::npos) { - band_node = band_node.del().insert_mark(isl::id(band_node.ctx(), matrix_name)); - } - } - band_node = locate_is_child ? band_node.parent() : band_node; - std::string fragment_mark = FRAGMENT; - fragment_mark += matrix_name.at(matrix_name.size() - 1); - band_node = band_node.insert_mark(fragment_mark); - - band_node = has_split ? 
band_node.parent() : band_node; - - node = band_node.parent(); + node = band_node.ancestor(band_node.get_tree_depth() - start_depth); return node; }; @@ -628,35 +520,31 @@ isl::schedule_node RegisterMemoryManager::MapPromotionTensorToWarps(isl::schedul isl::multi_val RegisterMemoryManager::GetRealTileSizeVal(const isl::schedule_node &node, const std::string &matrix_name, const std::string &matrix_major) { - auto title_size_count = static_cast(pass_info_.tile_sizes_.size()); auto ctx = node.ctx(); auto space = node.as().get_space(); isl::multi_val tile_size_val = isl::multi_val::zero(space); - auto init_number = title_size_count > M_N_K_COUNT ? title_size_count - M_N_K_COUNT : 0; - auto del_position = init_number; - bool need_coalesce = false; + int m = scop_info_.analysis_result_.GetMmaMode().m; + int n = scop_info_.analysis_result_.GetMmaMode().n; + int k = scop_info_.analysis_result_.GetMmaMode().k; + std::vector tile_size_number; + bool need_reverse = false; if (matrix_name == MATRIX_B) { - need_coalesce = (matrix_major == ROW_MAJOR) ? true : false; - del_position += M_POSITION; + need_reverse = (matrix_major == ROW_MAJOR) ? true : false; + tile_size_number.emplace_back(m); + tile_size_number.emplace_back(k); } else if (matrix_name == MATRIX_A) { - need_coalesce = (matrix_major == COL_MAJOR) ? true : false; - del_position += N_POSITION; + need_reverse = (matrix_major == COL_MAJOR) ? true : false; + tile_size_number.emplace_back(n); + tile_size_number.emplace_back(k); } else { - del_position += K_POSITION; - } - - std::vector tile_size_number; - for (auto i = init_number; i < title_size_count; ++i) { - if (i == del_position) { - continue; - } - tile_size_number.emplace_back(static_cast(pass_info_.tile_sizes_[i].c0_tiling_size)); + tile_size_number.emplace_back(m); + tile_size_number.emplace_back(n); } auto len = static_cast(tile_size_number.size()); for (auto i = 0; i < len; ++i) { - int pos = need_coalesce ? len - 1 - i : i; + int pos = need_reverse ? 
len - 1 - i : i; tile_size_val = tile_size_val.set_val(pos, isl::val(ctx, tile_size_number[i])); } @@ -664,30 +552,7 @@ isl::multi_val RegisterMemoryManager::GetRealTileSizeVal(const isl::schedule_nod } isl::schedule RegisterMemoryManager::Run(isl::schedule sch) { - auto GetGMWriteFilter = [this](isl::schedule_node node) -> isl::schedule_node { - if (!node.isa()) { - return node; - } - isl::union_set uset = node.as().get_filter(); - bool is_gm_write = false; - uset.foreach_set([&is_gm_write](isl::set s) { - if (s.get_tuple_name() == WRITE_ID_NAME) { - is_gm_write = true; - } - }); - if (is_gm_write && node.has_parent() && node.parent().isa()) { - node = node.child(0).insert_mark(PROMOTE_LOCAL_TO_GLOBAL); - node = node.parent(); - } - return node; - }; - - auto node = sch.root().child(0); - if (node.isa()) { - node = node.del(); - } - node = InsertContextNode(node, scop_info_); - sch = node.schedule(); + sch = InsertContextNode(sch, scop_info_); if (!scop_info_.user_config_.UseRegisterMemory()) { return sch; @@ -700,22 +565,29 @@ isl::schedule RegisterMemoryManager::Run(isl::schedule sch) { schedule_ = sch; auto root = sch.get_root(); + if (scop_info_.user_config_.GetEnableMatmul()) { + GetActualPromotedSharedTensors(); + sch = HoistRegisterMemoryOnMark(root); + if (scop_info_.user_config_.GetEnableTensorCoreUsePoly()) { + root = sch.get_root(); + sch = TileTensorAccordingInterfaceValue(root).get_schedule(); + } + std::string write_name = GetPromotedWriteName(); + std::string marker_name = PROMOTE_REGISTER_TO_GLOBAL; + if (write_name == SHARED_WRITE_ID_NAME) { + marker_name = PROMOTE_REGISTER_TO_SHARED; + } + sch = InsertMarkerForThreadGroup(sch, write_name, marker_name); + return sch; + } + auto res_node = GetRegisterPromotedNode(root); if (res_node.isa()) { auto depth = UpdateDepth(res_node); if (scop_info_.user_config_.GetRegisterDepth() >= 0) { depth = scop_info_.user_config_.GetRegisterDepth(); } - if (scop_info_.user_config_.GetEnableMatmul()) { - sch = 
HoistRegisterMemoryOnMark(root); - if (scop_info_.user_config_.GetEnableTensorCoreUsePoly()) { - root = sch.get_root(); - sch = MapPromotionTensorToWarps(root).get_schedule(); - } - sch = sch.get_root().map_descendant_bottom_up(GetGMWriteFilter).get_schedule(); - } else { - sch = HoistRegisterMemory(root, depth); - } + sch = HoistRegisterMemory(root, depth); } return sch; diff --git a/src/poly/schedule_pass_gpu/register_memory_manager.h b/src/poly/schedule_pass_gpu/register_memory_manager.h index c6e5593b97f6ec3329ab68d91b82bf9e742474f9..52d6b4ff3729846c0bd5787fcc509b1e629dd26e 100644 --- a/src/poly/schedule_pass_gpu/register_memory_manager.h +++ b/src/poly/schedule_pass_gpu/register_memory_manager.h @@ -18,6 +18,7 @@ #define REGISTER_MEMORY_MANAGER_H_ #include "poly/schedule_pass.h" +#include "poly/schedule_tree_util.h" namespace akg { namespace ir { @@ -26,10 +27,6 @@ namespace poly { constexpr auto MAX_REGISTER_PER_THREAD_BLOCK = 65536; constexpr auto BYTES_PER_REGISTER = 4; constexpr auto REGISTER_ALLOC_RATIO = 1.0; // percentage of local memory that allocated to tensors -constexpr auto M_N_K_COUNT = 3; -constexpr auto M_POSITION = 0; -constexpr auto N_POSITION = 1; -constexpr auto K_POSITION = 2; /* * Manager shared memory in GPU. 
@@ -42,6 +39,9 @@ class RegisterMemoryManager : public SchedulePass { if (!scop_info.user_config_.GetLocalTensors().empty()) { configed_tensors_ = Split(scop_info.user_config_.GetLocalTensors(), " "); } + if (scop_info_.user_config_.GetEnableMatmul()) { + local_tensor_c_ = GetMatmulTensorsName(scop_info)[MATRIX_C]; + } }; ~RegisterMemoryManager() {} @@ -49,14 +49,10 @@ class RegisterMemoryManager : public SchedulePass { isl::schedule HoistRegisterMemoryOnDepth(isl::schedule_node &node, size_t depth); - isl::union_set GatherMappingsTo(MappingCfg *cfg); - void CreateTensorCluster(const isl::schedule_node &node, const isl::union_map &outer_sch); void GatherBufferFootprintDefInfo(const isl::schedule_node &node, BufferDefInfo &tensor_info); - bool ReuseTensorCluster(const TensorFootprintCluster &cluster, const isl::multi_union_pw_aff &outer_pw_aff); - bool IsPromote(const TensorFootprintCluster &fp_cluster, const isl::multi_union_pw_aff &partial_sched_mupa, const isl::multi_union_pw_aff &thread_schedule); @@ -71,13 +67,12 @@ class RegisterMemoryManager : public SchedulePass { isl::schedule_node GetRegisterPromotedNode(isl::schedule_node &root); isl::schedule HoistRegisterMemoryOnMark(isl::schedule_node root); - isl::schedule_node CollectMarkNode(isl::schedule_node root, const std::string local_position_mark); - - isl::schedule_node MapPromotionTensorToWarps(isl::schedule_node &root); + isl::schedule_node TileTensorAccordingInterfaceValue(isl::schedule_node &root); isl::multi_val GetRealTileSizeVal(const isl::schedule_node &node, const std::string &matrix_name, const std::string &matrix_major); + std::string GetPromotedWriteName(); - void SharedTensors(); + void GetActualPromotedSharedTensors(); bool IsReadOrWriteBand(isl::schedule_node node); @@ -89,7 +84,7 @@ class RegisterMemoryManager : public SchedulePass { bool memory_exceeding_{false}; bool hoist_compute_local_tensor_{true}; bool hoist_tensor_all_{false}; - std::string local_tensor_c_{COMPUTE}; + std::string 
local_tensor_c_; std::string shared_tensors_; }; diff --git a/src/poly/schedule_pass_gpu/shared_memory_manager.cc b/src/poly/schedule_pass_gpu/shared_memory_manager.cc index a4eb4546bc02148d89b8f2b33664da1517081354..722a9bd6113a54fd25a907c316282da9bfc239cf 100644 --- a/src/poly/schedule_pass_gpu/shared_memory_manager.cc +++ b/src/poly/schedule_pass_gpu/shared_memory_manager.cc @@ -53,53 +53,38 @@ isl::schedule SharedMemoryManager::Run(isl::schedule sch) { shared_vector_align_ = scop_info_.user_config_.GetSharedVectorAlign(); // collect all bands at the given depth in the schedule tree - size_t remain_memory = share_memory_size_; + size_t remain_memory = common::SHARED_MEMORY_SIZE; if (scop_info_.user_config_.GetEnableMatmul()) { - remain_memory = tensor_core_share_memory_size_; - root = HoistSharedMemoryOnMark(root, remain_memory, depth_); + remain_memory =akg::common::ADVANCED_SHARED_MEMORY_SIZE; + root = HoistSharedMemoryOnMark(root, remain_memory, depth_).root(); } else { - root = HoistSharedMemoryOnDepth(root, remain_memory, depth_); + root = HoistSharedMemoryOnDepth(root, remain_memory, depth_).root(); } bool unroll_shared = scop_info_.user_config_.GetUnrollShared(); root = MapCopiesToThreads(root, unroll_shared); schedule_ = root.get_schedule(); - - auto node = schedule_.root().child(0); - if (node.as()) { - node = node.del(); + if (scop_info_.user_config_.GetEnableMatmul()) { + schedule_ = InsertMarkerForThreadGroup(schedule_, WRITE_ID_NAME, PROMOTE_SHARED_TO_GLOBAL); } - node = InsertContextNode(node, scop_info_); - schedule_ = node.schedule(); - return schedule_; -} + schedule_ = InsertContextNode(schedule_, scop_info_); -isl::schedule_node SharedMemoryManager::CollectMarkNode(isl::schedule_node root, const std::string mark) { - isl::schedule_node hoist_node; - root.foreach_descendant_top_down([&hoist_node, &mark](const isl::schedule_node &node) -> bool { - if (auto mark_node = node.as()) { - // ignore nested mark nodes - if (mark_node.get_id().get_name() 
== mark) { - hoist_node = mark_node; - return false; - } - } - return true; - }); - return hoist_node; + return schedule_; } isl::schedule_node SharedMemoryManager::HoistSharedMemoryOnMark(const isl::schedule_node &root, size_t &remain_memory, size_t depth) { - auto ab_mark_node = CollectMarkNode(root, PROMOTE_GLOBAL_TO_SHARED_AB); + auto ab_mark_node = CollectMarkNodeOnPromotion(root, PROMOTE_GLOBAL_TO_SHARED_AB); auto ab_promote_node = ab_mark_node.parent(); hoist_tensor_c_ = false; auto ab_res_node = ManageToShareBelow(this->schedule_, ab_promote_node, remain_memory); - if (find(configed_tensors_.begin(), configed_tensors_.end(), tensor_c_) != configed_tensors_.end()) { - auto c_mark_node = CollectMarkNode(ab_res_node.get_schedule().get_root(), PROMOTE_GLOBAL_TO_SHARED_C); + auto tensor_c_name = GetMatmulTensorsName(scop_info_)[MATRIX_C]; + if (find(configed_tensors_.begin(), configed_tensors_.end(), tensor_c_name) != configed_tensors_.end()) { + auto c_mark_node = CollectMarkNodeOnPromotion(ab_res_node.get_schedule().get_root(), PROMOTE_GLOBAL_TO_SHARED_C); auto c_promote_node = c_mark_node.parent(); hoist_tensor_c_ = true; + remain_memory = akg::common::ADVANCED_SHARED_MEMORY_SIZE; auto c_res_node = ManageToShareBelow(c_promote_node.get_schedule(), c_promote_node, remain_memory); return c_res_node; } @@ -135,31 +120,6 @@ isl::schedule_node SharedMemoryManager::HoistSharedMemoryOnDepth(const isl::sche return MapDescendantTopDown(root, fn); } -isl::union_set SharedMemoryManager::GatherMappingsTo(MappingCfg *cfg) { - isl::schedule_node root = schedule_.get_root(); - auto domain_node = root.as(); - auto domain = domain_node.domain(); - auto mapping_filters = CollectNode(schedule_); - - std::vector filters; - for (size_t idx = 0; idx < cfg->bound; ++idx) { - auto value = cfg->GetAt(idx); - auto id = isl::id(root.ctx(), value.first); - filters.push_back(id); - } - mapping_filters = FilterNode(mapping_filters, filters); - - auto mapping = 
isl::union_set::empty(domain.ctx()); - for (auto item : mapping_filters) { - if (item.isa()) { - auto filter = item.as(); - auto filter_domain = filter.filter().intersect(CollectDomain(item)); - mapping = mapping.unite(filter.filter()); - } - } - return mapping; -} - isl::schedule_node SharedMemoryManager::MapCopiesToThreads(isl::schedule_node &root, bool unroll) { auto CollectReadWriteFilter = [&unroll, this](isl::schedule_node node) -> isl::schedule_node { if (!node.isa()) { @@ -173,15 +133,7 @@ isl::schedule_node SharedMemoryManager::MapCopiesToThreads(isl::schedule_node &r auto band_node = GetCanMappingNode(node); std::string atomic_type = InAtomicTensors(node); - // split member that does not involved in thread mapping - bool has_split = false; auto thread_cfg = scop_info_.user_config_.GetThreadConfig(); - auto mem_size = band_node.as().n_member(); - if (mem_size > thread_cfg->bound) { - band_node = band_node.as().split(mem_size - thread_cfg->bound); - band_node = band_node.child(0); - has_split = true; - } if (shared_inversed_thread_map_) { // Pretille - To make a vectorize loop more apparent with only the information of the mapping @@ -231,9 +183,43 @@ isl::schedule_node SharedMemoryManager::MapCopiesToThreads(isl::schedule_node &r mapping_cfg = thread_cfg; } } + + // split member that does not involved in mapping_cfg + bool has_split = false; + auto mem_size = band_node.as().n_member(); + if (mem_size > mapping_cfg->bound) { + band_node = band_node.as().split(mem_size - mapping_cfg->bound); + band_node = band_node.child(0); + has_split = true; + } + + if (shared_inversed_thread_map_) { + // Pretille - To make a vectorize loop more apparent with only the information of the mapping + const auto &domain = band_node.as().get_partial_schedule().domain(); + const isl::id ¤t_computing_id_shared = domain.unwrap().range().set_list().get_at(0).get_tuple_id(); + + std::vector tensor_size; + for (BufferDefInfo &buffer_info : 
scop_info_.analysis_result_.buffer_def_infos_) { + if (current_computing_id_shared == buffer_info.dst_tensor_id) { + tensor_size = buffer_info.sizes; + } + } + // Reverse because thread is innermost map + std::reverse(tensor_size.begin(), tensor_size.end()); + + auto ctx = band_node.ctx(); + const auto &space = band_node.as().get_space(); + const auto n_member = band_node.as().n_member(); + isl::multi_val tile_size = isl::multi_val::zero(space); + for (size_t i = 0; i < n_member; ++i) { + const size_t size = tensor_size[i] / thread_cfg->GetAt(i).second; + tile_size = tile_size.set_val(n_member - 1 - i, isl::val(ctx, size != 0 ? size : 1)); + } + band_node = TileBand(band_node, tile_size); + } + Mapping mapping; - auto after_map_pair = MapInnerDimToThreads(band_node, true, mapping_cfg, mapping, false); - band_node = after_map_pair.first; + band_node = MapInnerDimToThreads(band_node, true, mapping_cfg, mapping, false); auto InsertAtomicMarker = [atomic_type, this](isl::schedule_node atomic_node) -> isl::schedule_node { if (atomic_type != "" && atomic_node.has_children() && atomic_node.child(0).isa()) { atomic_node = @@ -285,6 +271,20 @@ MappingCfg *SharedMemoryManager::GetCurrentConfig(isl::schedule_node &node) { int vectorization_loop = 0; if (enable_vectorization) { vectorization_loop = vector_load_type / shares_tensor_bits_map[id_name]; + + isl::multi_val tile_size; + auto ctx = node.ctx(); + auto space = node.as().get_space(); + tile_size = isl::multi_val::zero(space); + + auto n_member = node.as().n_member(); + for (size_t i = 0; i < n_member - 1; ++i) { + tile_size = tile_size.set_val(i, isl::val(ctx, 1)); + } + tile_size = tile_size.set_val(n_member - 1, isl::val(ctx, vectorization_loop)); + + node = TileBand(node, tile_size).child(0); + node = node.insert_mark(PROMOTE_VECTORIZATION).parent(); } auto replace_cfg_map = scop_info_.user_config_.GetReplaceConfig(); @@ -302,17 +302,14 @@ MappingCfg *SharedMemoryManager::GetCurrentConfig(isl::schedule_node &node) { 
} std::string new_cfg = ""; - int mapping_dim = std::min(static_cast(upa_list.size()), static_cast(thread_cfg->MaxDim())); + int mapping_dim = static_cast(upa_list.size()); for (int i = 0; i < mapping_dim; ++i) { - auto extend = upa_list.get_at(i).max_val().get_num_si() + 1; - if (extend > total_thread || (i == mapping_dim - 1 && extend < total_thread)) { + auto extend = upa_list.get_at(i).floor().max_val().get_num_si() + 1; + if (extend >= total_thread || (i == mapping_dim - 1 && extend < total_thread)) { new_cfg += (std::to_string(total_thread) + " "); break; } - if (i == 0 && enable_vectorization) { - extend /= vectorization_loop; - } total_thread /= extend; new_cfg += (std::to_string(extend) + " "); } @@ -320,25 +317,10 @@ MappingCfg *SharedMemoryManager::GetCurrentConfig(isl::schedule_node &node) { if (new_cfg.empty()) { return nullptr; } - scop_info_.user_config_.RecordReplaceConfig(id_name, new_cfg, MappingType::REPLACE_THREADS); + scop_info_.user_config_.RecordReplaceConfig(id_name, new_cfg, MappingType::REPLACE_THREADS, false); } auto mapping_cfg = scop_info_.user_config_.GetReplaceConfig()[id_name]; - if (enable_vectorization) { - isl::multi_val tile_size; - auto ctx = node.ctx(); - auto space = node.as().get_space(); - tile_size = isl::multi_val::zero(space); - - auto n_member = node.as().n_member(); - for (size_t i = 0; i < n_member - 1; ++i) { - tile_size = tile_size.set_val(i, isl::val(ctx, 1)); - } - tile_size = tile_size.set_val(n_member - 1, isl::val(ctx, vectorization_loop)); - - node = TileBand(node, tile_size).child(0); - node = node.insert_mark(PROMOTE_VECTORIZATION).parent(); - } return mapping_cfg; } @@ -357,21 +339,19 @@ isl::schedule_node SharedMemoryManager::ManageToShareBelow(const isl::schedule & auto cfg = it.second; if (cfg->type == MappingType::REPLACE_BLOCKS) { if (mapping.is_null()) { - mapping = GatherMappingsTo(cfg); + mapping = GatherMappingsTo(root_node, cfg); } else { - mapping = mapping.intersect(GatherMappingsTo(cfg)); + 
mapping = mapping.intersect(GatherMappingsTo(root_node, cfg)); } } } if (mapping.is_null()) { - mapping = GatherMappingsTo(cfg); + mapping = GatherMappingsTo(root_node, cfg); } auto out_sched = partial_sched.intersect_domain(mapping); CreateClusterList(node, out_sched); - auto new_node = HoistClusters(root_node, node, remaining_memory); - auto sync_manager = scop_info_.sync_manager_; - return sync_manager.InsertPromotionSync(new_node); + return HoistClusters(root_node, node, remaining_memory); } std::set SharedMemoryManager::AnalysisReduceTensors() { @@ -457,21 +437,12 @@ void SharedMemoryManager::CreateClusterList(const isl::schedule_node &node, cons } if (scop_info_.user_config_.GetEnableMatmul()) { - std::unordered_map matmul_map = scop_info_.analysis_result_.GetMatrixMatmulMap(); - for (auto i : matmul_map) { - if (i.second == MATRIX_C) { - tensor_c_ = i.first; - } else if (i.second == MATRIX_A) { - tensor_a_ = i.first; - } else if (i.second == MATRIX_B) { - tensor_b_ = i.first; - } + auto tensors = GetMatmulTensorsName(scop_info_); + if (id_sets.count(tensors[MATRIX_A]) == 0) { + id_sets.emplace(tensors[MATRIX_A]); } - if (id_sets.count(tensor_a_) == 0) { - id_sets.emplace(tensor_a_); - } - if (id_sets.count(tensor_b_) == 0) { - id_sets.emplace(tensor_b_); + if (id_sets.count(tensors[MATRIX_B]) == 0) { + id_sets.emplace(tensors[MATRIX_B]); } } @@ -482,13 +453,11 @@ void SharedMemoryManager::CreateClusterList(const isl::schedule_node &node, cons for (const auto &item : tensor_list) { if (scop_info_.user_config_.GetEnableMatmul()) { if (!hoist_tensor_c_) { - if (item.get_name().find(tensor_a_) == std::string::npos && - item.get_name().find(tensor_b_) == std::string::npos) { + if (!IsTensorAB(item.get_name(), scop_info_)) { continue; } } else { - if (item.get_name().find(tensor_a_) != std::string::npos || - item.get_name().find(tensor_b_) != std::string::npos) { + if (IsTensorAB(item.get_name(), scop_info_)) { continue; } } @@ -529,15 +498,17 @@ void 
SharedMemoryManager::GatherBufferFootprintDefInfo(const isl::schedule_node return; } sizes = fp_cluster->GetFixedBoxSizes(); - if (scop_info_.user_config_.GetEnableMatmul() && sizes.back() % 2 == 0) { - sizes.back() += 16; + + isl::id tensor_id = tensor_info.tensor_id; + + if (scop_info_.user_config_.GetEnableMatmul() && tensor_id.get_name() == GetMatmulTensorsName(scop_info_)[MATRIX_C]) { + sizes.back() += 8; } if (bank_conflict_) { sizes = OptimizeSharedDimension(sizes); } - isl::id tensor_id = tensor_info.tensor_id; isl::id cluster_id = tensor_info.dst_tensor_id; // build a Halide Node for cluster_id @@ -551,7 +522,7 @@ void SharedMemoryManager::GatherBufferFootprintDefInfo(const isl::schedule_node const Buffer buffer = decl_buffer(shapes, scop_info_.GetDtypeOf(tensor_id), cluster_id.get_name()); scop_info_.user_config_.SetBind(tensor, buffer); if (scop_info_.user_config_.GetVectorLoadType()) { - scop_info_.analysis_result_.RecoreSharedTensorBitsMap(tensor_id.get_name(), + scop_info_.analysis_result_.RecordSharedTensorBitsMap(tensor_id.get_name(), scop_info_.GetDtypeOf(tensor_id).bits()); } @@ -575,11 +546,11 @@ isl::schedule_node SharedMemoryManager::HoistClusters(const isl::schedule_node & if (scop_info_.user_config_.GetEnableMatmul()) { if (!hoist_tensor_c_) { - if (id.get_name().find(tensor_a_) == std::string::npos && id.get_name().find(tensor_b_) == std::string::npos) { + if (!IsTensorAB(id.get_name(), scop_info_)) { continue; } } else { - if (id.get_name().find(tensor_a_) != std::string::npos || id.get_name().find(tensor_b_) != std::string::npos) { + if (IsTensorAB(id.get_name(), scop_info_)) { continue; } } @@ -644,20 +615,6 @@ isl::schedule_node SharedMemoryManager::HoistToBlockThreadMemory(isl::schedule_n return res_node; } -bool SharedMemoryManager::ReuseTensorCluster(const TensorFootprintCluster &cluster, - const isl::multi_union_pw_aff &outer_pw_aff) { - isl::union_map state_schedule_mapping = ScheduleTensorMapping(outer_pw_aff, 
cluster.OrigianlAccessRelations()); - /* Here we use the property of bijective to decide whether promote this tensor to shared. - * For element wise operator, S -> tensor_schedule is bijective. - * It should not be promoted to shared memory. - * For reduced operator, S -> tensor_schedule is not bijective. - * It should be promoted to shared memory. - * For stencil operator in sciencetific computing, S -> tensor_schedule is not bijective. - * It should be promoted to shared memory. - * *******************************************************************************************/ - return !(state_schedule_mapping.is_bijective()); -} - bool SharedMemoryManager::CoalescingAccessWay(const isl::schedule_node &root, const isl::schedule_node &node, const TensorFootprintCluster &cluster) { isl::union_map original = cluster.OrigianlAccessRelations(); @@ -784,8 +741,8 @@ std::vector SharedMemoryManager::OptimizeSharedDimension(std::vector SharedMemoryManager::OptimizeBankConflict(std::vector sizes) { std::vector res = sizes; if (res.back() % 2 == 0) { - if (scop_info_.user_config_.GetEnableMatmul()) { - res.back() += 16; + if (bank_conflict_ && res.back() < 32) { + res.back() = 33; } else { res.back() += 1; } diff --git a/src/poly/schedule_pass_gpu/shared_memory_manager.h b/src/poly/schedule_pass_gpu/shared_memory_manager.h index 724283adbdc6d27f7ff9af5123f5ed359fc43a53..5bc5ab3a599b273878ecd14a7f50981b6a250c39 100644 --- a/src/poly/schedule_pass_gpu/shared_memory_manager.h +++ b/src/poly/schedule_pass_gpu/shared_memory_manager.h @@ -18,6 +18,7 @@ #define SHARED_MEMORY_MANAGER_H_ #include "poly/schedule_pass.h" +#include "common/common_util.h" namespace akg { namespace ir { @@ -32,9 +33,6 @@ class SharedMemoryManager : public SchedulePass { public: explicit SharedMemoryManager(ScopInfo &scop_info) : scop_info_(scop_info) { pass_name_ = __FUNCTION__; - // use 48KB in current GPU - share_memory_size_ = 49152; - tensor_core_share_memory_size_ = 61440; if 
(!scop_info.user_config_.GetSharedTensors().empty()) { configed_tensors_ = Split(scop_info.user_config_.GetSharedTensors(), " "); } @@ -46,8 +44,6 @@ class SharedMemoryManager : public SchedulePass { isl::schedule_node HoistSharedMemoryOnDepth(const isl::schedule_node &root, size_t &remain_memory, size_t depth); - isl::union_set GatherMappingsTo(MappingCfg *cfg); - isl::schedule_node MapCopiesToThreads(isl::schedule_node &root, bool unroll); MappingCfg *GetCurrentConfig(isl::schedule_node &node); @@ -65,8 +61,6 @@ class SharedMemoryManager : public SchedulePass { isl::schedule_node HoistToBlockThreadMemory(isl::schedule_node &tree, GpuMemType type, const isl::id &tensor_id, TensorFootprintCluster &cluster, bool force_last_extension_odd); - bool ReuseTensorCluster(const TensorFootprintCluster &cluster, const isl::multi_union_pw_aff &outer_pw_aff); - bool CoalescingAccessWay(const isl::schedule_node &root, const isl::schedule_node &node, const TensorFootprintCluster &cluster); @@ -86,24 +80,18 @@ class SharedMemoryManager : public SchedulePass { std::set AnalysisReduceTensors(); size_t Bytes(const isl::id tensor_id); - isl::schedule_node CollectMarkNode(isl::schedule_node root, const std::string mark); isl::schedule_node HoistSharedMemoryOnMark(const isl::schedule_node &root, size_t &remain_memory, size_t depth); private: ScopInfo &scop_info_; isl::schedule schedule_; - size_t share_memory_size_; - size_t tensor_core_share_memory_size_; int depth_{1}; bool use_config_{false}; std::vector configed_tensors_; bool unroll_copies_; bool bank_conflict_{false}; bool hoist_tensor_c_ = true; - std::string tensor_c_; - std::string tensor_a_; - std::string tensor_b_; bool shared_inversed_thread_map_{false}; int shared_vector_align_{0}; }; diff --git a/src/poly/schedule_tree_util.cc b/src/poly/schedule_tree_util.cc index 4f739de73ef4ac3e03811d0773d805fe38d2c586..bef715dd37781b10ad36fbaa6a10ee582e0261f2 100644 --- a/src/poly/schedule_tree_util.cc +++ 
b/src/poly/schedule_tree_util.cc @@ -229,8 +229,96 @@ std::vector BandsSplitAfterDepth(const std::vector isl::schedule_node { + if (!node.isa()) { + return node; + } + isl::union_set uset = node.as().get_filter(); + bool is_gm_write = false; + uset.foreach_set([&is_gm_write, write_name](isl::set s) { + if (s.get_tuple_name() == write_name) { + is_gm_write = true; + } + }); + if (is_gm_write && node.has_parent() && node.parent().isa()) { + node = node.child(0).insert_mark(marker_name); + node = node.parent(); + } + return node; + }; + auto final_sch = sch.get_root().map_descendant_bottom_up(GetPromotedWriteFilter).schedule(); + return final_sch; +} + +isl::schedule_node AdjustConvScheduleTreeStructure(const isl::schedule_node &orig_node, const bool is_promotion) { + auto node = orig_node; + if (!node.isa()) { + return node; + } + + auto band_node = node.as(); + auto orig_number = band_node.n_member(); + if (orig_number <= 2) { + return node; + } + + // original node + auto orig_partial_schedule = band_node.get_partial_schedule(); + bool orig_permutable = band_node.get_permutable(); + std::vector orig_coincident; + for (int i = 0; i < static_cast(orig_number); ++i) { + orig_coincident.push_back(band_node.member_get_coincident(i)); + } + + isl::union_pw_aff_list new_partial_schedule(node.ctx(), orig_number); + auto InsertPartialSchedule = [&new_partial_schedule](isl::schedule_node node) -> void { + auto partial_schedule = node.as().get_partial_schedule(); + for (int i = 0; i < static_cast(partial_schedule.size()); ++i) { + new_partial_schedule = new_partial_schedule.add(partial_schedule.get_at(i)); + } + }; + + // split n axis + node = node.as().split(1); + auto n_node = node; + node = node.del(); + + // split h and w axis + const int h_w_axis_size = 2; + int real_h_w_axis_size = is_promotion ? 
static_cast(orig_number) - h_w_axis_size : h_w_axis_size; + node = node.as().split(real_h_w_axis_size); + InsertPartialSchedule(node); + node = node.del(); + InsertPartialSchedule(n_node); + + // split o and other axis + InsertPartialSchedule(node); + + node = node.insert_partial_schedule(isl::multi_union_pw_aff(orig_partial_schedule.get_space(), new_partial_schedule)); + band_node = node.as(); + band_node = band_node.set_permutable(orig_permutable); + for (int i = 0; i < static_cast(orig_number); ++i) { + band_node = band_node.member_set_coincident(i, orig_coincident[i]); + } + return band_node; +} + +std::string GetMarkerName(const isl::schedule_node &node, std::string find_name) { + std::string marker_name = ""; + if (node.isa()) { + marker_name = node.as().get_id().get_name(); + if (marker_name.find(find_name) != std::string::npos) { + return marker_name; + } + marker_name = ""; + } + return marker_name; +} + isl::union_pw_aff_list GetUPAList(const isl::schedule_node &node, isl::multi_union_pw_aff &partial_schedule, - const bool is_promotion, bool need_coalesce) { + const bool is_promotion, const bool need_reverse) { if (is_promotion) { // we need to to get range of promoted band from extension node so that we can correctly fix stride auto parent = node; @@ -245,21 +333,20 @@ isl::union_pw_aff_list GetUPAList(const isl::schedule_node &node, isl::multi_uni auto upa_list = partial_schedule.get_union_pw_aff_list().reverse(); - if (need_coalesce) { + if (need_reverse) { upa_list = upa_list.reverse(); } return upa_list; } -std::pair MapInnerDimToThreads(const isl::schedule_node &node, - const bool is_promotion, MappingCfg *mapping_cfg, - Mapping &mapping, bool need_coalesce) { +isl::schedule_node MapInnerDimToThreads(const isl::schedule_node &node, const bool is_promotion, + MappingCfg *mapping_cfg, Mapping &mapping, const bool need_reverse) { CHECK(mapping_cfg != nullptr) << "thread config is null"; isl::schedule_node_band band_node = node.as(); size_t n_thread_map 
= std::min(static_cast(band_node.n_member()), mapping_cfg->bound); CHECK_LE(n_thread_map, mapping_cfg->MaxDim()) << "mapping to too many threads."; auto partial_schedule = band_node.get_partial_schedule(); - auto upa_list = GetUPAList(node, partial_schedule, is_promotion, need_coalesce); + auto upa_list = GetUPAList(node, partial_schedule, is_promotion, need_reverse); // append prefix to partial schedule for tiling auto add_prefix_schedule = partial_schedule; @@ -278,12 +365,12 @@ std::pair MapInnerDimToThreads(const isl } auto prefix_upa_list = add_prefix_schedule.get_union_pw_aff_list().reverse(); - if (need_coalesce) { + if (need_reverse) { prefix_upa_list = prefix_upa_list.reverse(); } - isl::schedule_node fix_node = CheckMapSizeAndApplyTile(node, prefix_upa_list, mapping_cfg, need_coalesce); - bool tiled = !fix_node.is_equal(node); + isl::schedule_node fix_node = CheckMapSizeAndApplyTile(node, prefix_upa_list, mapping_cfg, need_reverse); + bool is_tiled = !fix_node.is_equal(node); // drop un-mapped aff after tiling upa_list = upa_list.drop(n_thread_map, upa_list.size() - n_thread_map); @@ -294,15 +381,11 @@ std::pair MapInnerDimToThreads(const isl auto after_map_node = CreateAndInsertMapFilter(fix_node, is_promotion, upa_list, mapping_cfg, mapping); after_map_node = after_map_node.parent(); - if (is_promotion && tiled) { + if (is_tiled) { after_map_node = after_map_node.parent(); } - isl::schedule_node after_fix_node = after_map_node; - if (tiled && after_fix_node.has_parent()) { - after_fix_node = after_fix_node.parent(); - } - return std::make_pair(after_map_node, after_fix_node); + return after_map_node; } isl::schedule_node CreateAndInsertMapFilter(const isl::schedule_node &node, const bool is_promotion, @@ -370,13 +453,13 @@ isl::schedule_node CreateAndInsertMapFilter(const isl::schedule_node &node, cons */ isl::schedule_node CheckMapSizeAndApplyTile(const isl::schedule_node &mapping_root, const isl::union_pw_aff_list &aff_list, MappingCfg *mapping_cfg, - 
bool need_coalesce) { + const bool need_reverse) { bool need_tile = false; std::vector mapping_sizes; CHECK(mapping_cfg != nullptr) << "mapping config is null"; size_t block_count = 0; for (size_t i = 0; i < aff_list.size(); ++i) { - auto aff = aff_list.get_at(i); + auto aff = aff_list.get_at(i).floor(); auto extent = aff.max_val().get_num_si() + 1; if (mapping_cfg->type == MappingType::BLOCKS) { if (aff_list.size() - 1 - i < mapping_cfg->bound) { @@ -409,7 +492,7 @@ isl::schedule_node CheckMapSizeAndApplyTile(const isl::schedule_node &mapping_ro auto len = static_cast(mapping_sizes.size()); for (auto i = len - 1; i >= 0; --i) { - int pos = need_coalesce ? i : len - 1 - i; + int pos = need_reverse ? i : len - 1 - i; tile_size = tile_size.set_val(pos, isl::val(ctx, mapping_sizes[i])); } @@ -736,6 +819,127 @@ isl::schedule_node InsertExtensionNodeBeforeOrAfter(const isl::schedule_node &no return extension_node; } +isl::union_set GetBlockMappingFilterInfo(const isl::schedule_node node, MappingCfg *block_cfg, + std::unordered_map replace_cfg) { + isl::union_set mapping; + for (auto it : replace_cfg) { + auto cfg = it.second; + if (cfg->type == MappingType::REPLACE_BLOCKS) { + if (mapping.is_null()) { + mapping = GatherMappingsTo(node, cfg); + } else { + mapping = mapping.intersect(GatherMappingsTo(node, cfg)); + } + } + } + if (mapping.is_null()) { + mapping = GatherMappingsTo(node, block_cfg); + } + return mapping; +} + +isl::union_set GatherMappingsTo(const isl::schedule_node &root, MappingCfg *cfg) { + auto domain_node = root.as(); + auto domain = domain_node.domain(); + auto sch = root.get_schedule(); + auto mapping_filters = CollectNode(sch); + + std::vector filters; + for (size_t idx = 0; idx < cfg->bound; ++idx) { + auto value = cfg->GetAt(idx); + auto id = isl::id(root.ctx(), value.first); + filters.push_back(id); + } + mapping_filters = FilterNode(mapping_filters, filters); + + auto mapping = isl::union_set::empty(domain.ctx()); + for (auto item : 
mapping_filters) { + if (item.isa()) { + auto filter = item.as(); + if (filter.has_parent() && !filter.parent().isa()) { + continue; + } + + isl::union_set uset = filter.get_filter(); + std::vector vset; + uset.foreach_set([&vset](isl::set s) { vset.push_back(s); }); + if (!vset.empty()) { + auto filter_name = vset[0].get_tuple_name(); + if (filter_name == READ_ID_NAME || filter_name == WRITE_ID_NAME) { + continue; + } + } + + mapping = mapping.unite(filter.filter()); + } + } + return mapping; +} + +/* Check that whether the mapping relation between instance statement + * and outer schedule points and tensor elements pair is reusable. */ +bool ReuseTensorCluster(const TensorFootprintCluster &cluster, const isl::multi_union_pw_aff &outer_pw_aff) { + /* compute the mapping relation between statement instance and outer schedule space and tensor elements pair */ + /* Here we use the property of bijective to decide whether promote this tensor to shared. + * For element wise operator, S -> tensor_schedule is bijective. + * It should not be promoted to shared/local memory. + * For reduced operator, S -> tensor_schedule is not bijective. + * It should be promoted to shared/local memory. + * For stencil operator in sciencetific computing, S -> tensor_schedule is not bijective. + * It should be promoted to shared/local memory. 
+ * *******************************************************************************************/ + isl::union_map state_schedule_mapping = ScheduleTensorMapping(outer_pw_aff, cluster.OrigianlAccessRelations()); + return !state_schedule_mapping.is_injective(); +} + +isl::schedule_node CollectMarkNodeOnPromotion(isl::schedule_node root, const std::string mark) { + isl::schedule_node hoist_node; + root.foreach_descendant_top_down([&hoist_node, &mark](const isl::schedule_node &node) -> bool { + if (auto mark_node = node.as()) { + // ignore nested mark nodes + if (mark_node.get_id().get_name() == mark) { + hoist_node = mark_node; + return false; + } + } + return true; + }); + return hoist_node; +} + +std::unordered_map GetMatmulTensorsName(ScopInfo &scop_info) { + std::unordered_map tensors; + if (scop_info.user_config_.GetEnableMatmul()) { + std::unordered_map matmul_map = scop_info.analysis_result_.GetMatrixMatmulMap(); + for (auto i : matmul_map) { + if (i.second == MATRIX_C) { + tensors.emplace(MATRIX_C, i.first); + } else if (i.second == MATRIX_A) { + tensors.emplace(MATRIX_A, i.first); + } else if (i.second == MATRIX_B) { + tensors.emplace(MATRIX_B, i.first); + } else if (i.second == MATRIX_ELSE) { + tensors.emplace(MATRIX_ELSE, i.first); + } + } + } + return tensors; +} + +bool IsTensorAB(const std::string &item, ScopInfo &scop_info) { + auto tensors = GetMatmulTensorsName(scop_info); + size_t pos = 0; + std::string item_tensor_name = item; + if ((pos = item_tensor_name.find(LOCAL_SUFFIX)) != std::string::npos || + (pos = item_tensor_name.find(SHARE_SUFFIX)) != std::string::npos) { + item_tensor_name = item_tensor_name.erase(pos, item_tensor_name.size() - pos); + } + if (item_tensor_name != tensors[MATRIX_A] && item_tensor_name != tensors[MATRIX_B]) { + return false; + } + return true; +} + } // namespace poly } // namespace ir } // namespace akg diff --git a/src/poly/schedule_tree_util.h b/src/poly/schedule_tree_util.h index 
8a6d1e88ba92d88eb6c9168fab30c8f4f345c16e..c60445df675a26e9840d588c1520dce0cfb37a33 100644 --- a/src/poly/schedule_tree_util.h +++ b/src/poly/schedule_tree_util.h @@ -94,16 +94,15 @@ isl::schedule_node BandSplitAtDepth(isl::schedule_node &band, size_t depth); std::vector BandsSplitAfterDepth(const std::vector &bands, isl::schedule_node &root, size_t depth); isl::union_pw_aff_list GetUPAList(const isl::schedule_node &node, isl::multi_union_pw_aff &partial_schedule, - const bool is_promotion, bool need_coalesce); -std::pair MapInnerDimToThreads(const isl::schedule_node &node, - const bool is_promotion, MappingCfg *mapping_cfg, - Mapping &mapping, bool need_coalesce); + const bool is_promotion, const bool need_reverse); +isl::schedule_node MapInnerDimToThreads(const isl::schedule_node &node, const bool is_promotion, + MappingCfg *mapping_cfg, Mapping &mapping, const bool need_reverse); isl::schedule_node CreateAndInsertMapFilter(const isl::schedule_node &node, const bool is_promotion, isl::union_pw_aff_list upa_list, MappingCfg *mapping_cfg, Mapping &mapping, std::unordered_map map_idx_shift = {}); isl::schedule_node CheckMapSizeAndApplyTile(const isl::schedule_node &thread_root, const isl::union_pw_aff_list &aff_list, MappingCfg *mapping_cfg, - bool need_coalesce); + const bool need_reverse); bool IsEqualNode(const isl::schedule_node node1, const isl::schedule_node node2); isl::multi_union_pw_aff MapDomainToThread(const isl::schedule_node &node, MappingCfg *mapping_cfg, @@ -122,6 +121,24 @@ isl::schedule_node UnrollByMarkOptions(isl::schedule_node &node, uint64_t unroll isl::map GetExtensionSpace(const isl::schedule_node &node, const isl::id &id); isl::schedule_node InsertExtensionNodeBeforeOrAfter(const isl::schedule_node &node, const isl::id &id, bool before); +isl::schedule InsertMarkerForThreadGroup(const isl::schedule sch, const std::string write_name, + const std::string marker_name); +std::string GetMarkerName(const isl::schedule_node &node, std::string 
find_name); + +isl::union_set GetBlockMappingFilterInfo(const isl::schedule_node node, MappingCfg *block_cfg, + std::unordered_map replace_cfg); +isl::union_set GatherMappingsTo(const isl::schedule_node &root, MappingCfg *cfg); + +bool ReuseTensorCluster(const TensorFootprintCluster &cluster, const isl::multi_union_pw_aff &outer_pw_aff); + +isl::schedule_node CollectMarkNodeOnPromotion(isl::schedule_node root, const std::string mark); + +std::unordered_map GetMatmulTensorsName(ScopInfo &scop_info); + +bool IsTensorAB(const std::string &item, ScopInfo &scop_info); + +isl::schedule_node AdjustConvScheduleTreeStructure(const isl::schedule_node &orig_node, const bool is_promotion = true); + } // namespace poly } // namespace ir } // namespace akg diff --git a/src/poly/scop.cc b/src/poly/scop.cc index 005bf86b24f2f7c3f652ab4d5c429df42f2f70ce..3557732283b27f9d0e9c177b6c5c516c1caf8ce7 100644 --- a/src/poly/scop.cc +++ b/src/poly/scop.cc @@ -20,7 +20,9 @@ #include "poly/scop_builder.h" #include "poly/poly_util.h" #include "poly/npu_isl_emitter.h" -#include "poly/gpu_isl_emitter.h" +#include "poly/gpu_emit/gpu_isl_emitter.h" +#include "poly/gpu_emit/gpu_isl_emitter_reduce.h" +#include "poly/gpu_emit/gpu_isl_emitter_tensor_core.h" #include "poly/dsa_mgr_strategy.h" #include "poly/gpu_mgr_strategy.h" #include "poly/schedule_pass_mgr.h" @@ -137,12 +139,20 @@ isl::schedule Scop::Transform(const isl::schedule &input_schedule) { } } if (info_.user_config_.GetTarget() == TARGET_CUDA) { - auto reduce_st_map = info_.analysis_result_.GetReduceTensorInfoMap(); - info_.user_config_.SetEnableAkgReduceLib((!reduce_st_map.empty()) && (!info_.user_config_.GetEnableMatmul())); + auto reduce_tensor_info = info_.analysis_result_.GetReduceTensorInfoMap(); + bool is_reduce = !reduce_tensor_info.empty() && !info_.user_config_.GetEnableMatmul() && + info_.user_config_.GetEnableAkgReduceLib(); + bool is_matmul = !reduce_tensor_info.empty() && !info_.user_config_.GetEnableAkgReduceLib() && + 
info_.user_config_.GetEnableMatmul(); + bool is_tensor_core = !reduce_tensor_info.empty() && !info_.user_config_.GetEnableAkgReduceLib() && + info_.user_config_.GetEnableTensorCore(); + info_.user_config_.SetEnableAkgReduceLib(is_reduce); + info_.user_config_.SetEnableMatmul(is_matmul); + info_.user_config_.SetEnableTensorCore(is_tensor_core); if (info_.user_config_.GetEnableAkgReduceLib()) { bool has_supported_op = false; LOG(INFO) << "====== Reduce op type ========"; - for (auto it : reduce_st_map) { + for (auto it : reduce_tensor_info) { LOG(INFO) << it.first << " -> " << info_.analysis_result_.GetReduceOpType(it.first); auto type = info_.analysis_result_.GetReduceOpType(it.first); if (type == AKG_REDUCE_UNSUPPORTED) { @@ -262,7 +272,13 @@ Stmt GenHalide(ScopInfo &info, const isl::schedule &sch, bool used_for_tile_out_ stmt = NPUIslEmitter(info, node_info_repo, iters).Emit(ast_node); } else if (info.user_config_.GetTarget() == TARGET_CUDA) { PrintHeader("GpuIslEmitter"); - stmt = GpuIslEmitter(info, node_info_repo, iters).Emit(ast_node); + if (info.user_config_.GetEnableAkgReduceLib()) { + stmt = GpuIslEmitterReduce(info, node_info_repo, iters).Emit(ast_node); + } else if (info.user_config_.GetEnableTensorCore()) { + stmt = GpuIslEmitterTensorCore(info, node_info_repo, iters).Emit(ast_node); + } else { + stmt = GpuIslEmitter(info, node_info_repo, iters).Emit(ast_node); + } } } else { PrintHeader("IslEmitter"); @@ -272,7 +288,13 @@ Stmt GenHalide(ScopInfo &info, const isl::schedule &sch, bool used_for_tile_out_ if (info.user_config_.GetTarget() == TARGET_CCE) { stmt = NPUIslEmitter(info, node_info_repo, iters).Emit(ast_node); } else if (info.user_config_.GetTarget() == TARGET_CUDA) { - stmt = GpuIslEmitter(info, node_info_repo, iters).Emit(ast_node); + if (info.user_config_.GetEnableAkgReduceLib()) { + stmt = GpuIslEmitterReduce(info, node_info_repo, iters).Emit(ast_node); + } else if (info.user_config_.GetEnableTensorCore()) { + stmt = 
GpuIslEmitterTensorCore(info, node_info_repo, iters).Emit(ast_node); + } else { + stmt = GpuIslEmitter(info, node_info_repo, iters).Emit(ast_node); + } } } diff --git a/src/poly/scop_info.cc b/src/poly/scop_info.cc index 209a8abe6acb1752c402331f8a618ef54a03d6f4..871b9ddcd4a3d2d15130e7d5348e8f150eda2573 100644 --- a/src/poly/scop_info.cc +++ b/src/poly/scop_info.cc @@ -83,7 +83,7 @@ std::unordered_set AnalysisResult::ExtractWithStmtId() const { return res; } -int UserConfig::GetDataType(const std::string &name) const { +int UserConfig::GetDataBytes(const std::string &name) const { for (auto i : GetBind()) { if (i.first->op->name == name) { int size = i.first->dtype.bytes(); @@ -93,6 +93,16 @@ int UserConfig::GetDataType(const std::string &name) const { return 1; } +Type UserConfig::GetDataType(const std::string &name) const { + for (auto i : GetBind()) { + if (i.first->op->name == name) { + Type type = i.first->dtype; + return type; + } + } + CHECK(false) << "Get Data Type fail!"; + return Type(); +} std::string CubeInfo::ExtractStringFromAttrs(const std::string &name) const { for (auto i : analysis_result_.GetStmtOpInfoMap()) { if (!i.second.isMMU) { @@ -1241,9 +1251,11 @@ CondVarsMap AnalysisResult::GetCondVarsMap() { return cond_vars; } -const BufferDefInfo &AnalysisResult::GetBufferDefInfo(const isl::id &tensor_id) const { +const BufferDefInfo &AnalysisResult::GetBufferDefInfo(const isl::id &tensor_id, const bool is_dst_tensor_id) const { for (const auto &idx : BufferDefInfos()) { - if (idx.dst_tensor_id.get_name() == tensor_id.get_name()) { + bool is_contains_target_tensor = (is_dst_tensor_id && idx.dst_tensor_id.get_name() == tensor_id.get_name()) || + (!is_dst_tensor_id && idx.tensor_id.get_name() == tensor_id.get_name()); + if (is_contains_target_tensor) { return idx; } } @@ -1442,6 +1454,8 @@ static std::string MemTypeToString(const MemType &memType) { return "SHARED"; case MemType::LOCAL_: return "LOCAL"; + case MemType::DDR_LOCAL_: + return "GML"; 
default: return ""; } @@ -1449,8 +1463,14 @@ static std::string MemTypeToString(const MemType &memType) { std::string ScopInfo::GetIslReadName(const isl::id &cluster_id) { auto tensor_info = analysis_result_.GetBufferDefInfo(cluster_id); - MemType memType = tensor_info.SrcMemType(); - return MemTypeToString(memType) + "read"; + MemType src_memType = tensor_info.SrcMemType(); + if (user_config_.GetTarget() == TARGET_CUDA) { + MemType dst_memType = tensor_info.DstMemType(); + if (src_memType == MemType::DDR && dst_memType == MemType::LOCAL_) { + return MemTypeToString(MemType::DDR_LOCAL_) + "read"; + } + } + return MemTypeToString(src_memType) + "read"; } std::string ScopInfo::GetIslWriteName(const isl::id &cluster_id) { @@ -1459,6 +1479,15 @@ std::string ScopInfo::GetIslWriteName(const isl::id &cluster_id) { MemType memType = tensor_info.DstMemType(); return MemTypeToString(memType) + "write"; } + + if (user_config_.GetTarget() == TARGET_CUDA) { + auto tensor_info = analysis_result_.GetBufferDefInfo(cluster_id, false); + MemType src_memType = tensor_info.SrcMemType(); + MemType dst_memType = tensor_info.DstMemType(); + if (src_memType == MemType::DDR && dst_memType == MemType::LOCAL_) { + return MemTypeToString(MemType::DDR_LOCAL_) + "write"; + } + } return MemTypeToString(MemType::DDR) + "write"; } diff --git a/src/poly/scop_info.h b/src/poly/scop_info.h index db83ae11ba2a47ddd4b2ee7d669d45b32c4723be..3885e4d3d4ed022670e355efd33239f22453f610 100644 --- a/src/poly/scop_info.h +++ b/src/poly/scop_info.h @@ -53,11 +53,12 @@ struct MappingCfg { std::pair x; std::pair y; std::pair z; + std::vector> dim; public: MappingType type{NONE}; size_t bound{0}; - size_t MaxDim() { return 3; } + size_t MaxDim() { return std::max(dim.size(), static_cast(3)); } std::string GetPrefix(MappingType type) { CHECK_NE(type, MappingType::NONE); if (type == MappingType::BLOCKS || type == MappingType::REPLACE_BLOCKS) { @@ -66,17 +67,21 @@ struct MappingCfg { return "t"; } } - void 
BindFromStr(const std::string &cfg, const std::string &id_name = "") { + void BindFromStr(const std::string &cfg, const std::string &id_name = "", const bool enable_max_dim = true) { std::vector res = common::Split(cfg, " "); - CHECK_LE(res.size(), MaxDim()); + if (enable_max_dim) { + CHECK_LE(res.size(), MaxDim()); + } for (size_t i = 0; i < res.size(); ++i) { CHECK(!res[i].empty()); auto size = static_cast(std::strtol(res[i].c_str(), nullptr, 10)); - BindAt(i, size, id_name); + BindAt(i, size, id_name, enable_max_dim); } } - void BindAt(size_t pos, int size, const std::string &id_name = "") { - CHECK_LT(pos, MaxDim()); + void BindAt(size_t pos, int size, const std::string &id_name = "", const bool enable_max_dim = true) { + if (enable_max_dim) { + CHECK_LT(pos, MaxDim()); + } bound = std::max(bound, pos + 1); std::string id = ""; if (!id_name.empty()) { @@ -89,18 +94,27 @@ struct MappingCfg { } else if (pos == 1) { y.first = id; y.second = size; - } else { + } else if (pos == 2) { z.first = id; z.second = size; } + std::pair dim_pos; + dim_pos.first = id; + dim_pos.second = size; + dim.push_back(dim_pos); } std::pair GetAt(size_t pos) { if (pos == 0) { return GetX(); } else if (pos == 1) { return GetY(); - } else { + } else if (pos == 2) { return GetZ(); + } else { + CHECK_LT(pos, dim.size()); + auto res = dim[pos]; + res.second = res.second == 0 ? 
1 : res.second; + return res; } } std::pair GetX() { @@ -123,6 +137,7 @@ struct MappingCfg { x.second = 0; y.second = 0; z.second = 0; + dim.clear(); } void SwapConfig(size_t pos1, size_t pos2) { auto cfg1 = GetAt(pos1); @@ -237,8 +252,9 @@ class UserConfig { ParseBoolAttr(attrs, "enable_atomic_add", &enable_atomic_add_); if (GetTarget() == TARGET_CUDA) { - ParseBoolAttr(attrs, "enable_tile_c0", &enable_tile_c0_); ParseBoolAttr(attrs, "pragma_enable_tensor_core", &enable_tensor_core_); + ParseBoolAttr(attrs, "pragma_enable_emit_core", &pragma_enable_emit_core_); + ParseBoolAttr(attrs, "pragma_enable_conv_tensor_core", &enable_conv_tensor_core_); ParseBoolAttr(attrs, "pragma_enable_matmul", &enable_matmul_); ParseBoolAttr(attrs, "enable_tensor_core_use_poly", &enable_tensor_core_use_poly_); ParseBoolAttr(attrs, "enable_akg_reduce_lib", &enable_akg_reduce_lib_); @@ -286,11 +302,12 @@ class UserConfig { this->block_cfg_.BindFromStr(block_cfg); } void SetThreadConfig(const std::string &thread_cfg); - void RecordReplaceConfig(const std::string id, const std::string replace_cfg_str, const MappingType mapping_type) { + void RecordReplaceConfig(const std::string id, const std::string replace_cfg_str, const MappingType mapping_type, + const bool enable_max_dim = true) { MappingCfg *replace_cfg(new (std::nothrow) MappingCfg()); CHECK(replace_cfg) << "memory alloc fail."; replace_cfg->type = mapping_type; - replace_cfg->BindFromStr(replace_cfg_str, id); + replace_cfg->BindFromStr(replace_cfg_str, id, enable_max_dim); this->replace_cfg_[id] = replace_cfg; } void SetC0BlockSize(const std::vector c0_block_size) { c0_block_size_ = c0_block_size; } @@ -394,12 +411,12 @@ class UserConfig { std::string GetIterPrefix(bool is_spec_gemm = false) const { return is_spec_gemm ? 
kGemmIterNamePrefix : kIterNamePrefix; } - int GetDataType(const std::string &name) const; + int GetDataBytes(const std::string &name) const; + Type GetDataType(const std::string &name) const; // dump all info void DumpScopDataScheduleAttrs(std::ofstream &of); - bool GetEnableTileC0() { return enable_tile_c0_; } bool GetEnableAtomicAdd() { return enable_atomic_add_; } bool GetEnableAkgReduceLib() { return enable_akg_reduce_lib_; } @@ -409,14 +426,26 @@ class UserConfig { bool GetEnableMatmul() { return enable_matmul_; } void SetEnableMatmul(bool enable_matmul) { enable_matmul_ = enable_matmul; } - bool GetEnableTensorCore() { return enable_tensor_core_; } - void SetEnableTensorCore(bool use_tensor_core) { enable_tensor_core_ = use_tensor_core; } + bool GetEnableTensorCore() { + SetEnableTensorCore(enable_tensor_core_); + return enable_tensor_core_; + } + void SetEnableTensorCore(bool enable_tensor_core) { enable_tensor_core_ = enable_matmul_ && enable_tensor_core; } + + bool GetEnableEmitCore() { return pragma_enable_emit_core_; } + void SetEnableEmitCore(bool pragma_enable_emit_core) { pragma_enable_emit_core_ = pragma_enable_emit_core; } - bool GetEnableTensorCoreUsePoly() { return enable_tensor_core_use_poly_; } + bool GetEnableTensorCoreUsePoly() { + SetEnableTensorCoreUsePoly(enable_tensor_core_use_poly_); + return enable_tensor_core_use_poly_; + } void SetEnableTensorCoreUsePoly(bool enable_tensor_core_use_poly) { - enable_tensor_core_use_poly_ = enable_tensor_core_use_poly; + enable_tensor_core_use_poly_ = enable_tensor_core_ && enable_tensor_core_use_poly; } + bool GetEnableConvTensorCore() { return enable_conv_tensor_core_; } + void SetEnableConvTensorCore(bool enable_conv_tensor_core) { enable_conv_tensor_core_ = enable_conv_tensor_core; } + bool GetEnableOneDimThread() { return enable_one_dim_thread_; } void SetEnableOneDimThread(bool enable_one_dim_thread) { enable_one_dim_thread_ = enable_one_dim_thread; } @@ -426,12 +455,14 @@ class UserConfig { void 
SetUseRegisterMemory(bool use_register_memory) { use_register_memory_ = use_register_memory; } int GetRegisterDepth() { return register_depth_; } int GetSharedDepth() { return shared_depth_; } + void SetSharedTensors(std::string shared_tensors) { shared_tensors_ = shared_tensors; } std::string GetSharedTensors() { return shared_tensors_; } std::string GetReduceLibType() { return reduce_lib_type_; } std::string GetLocalTensors() { return local_tensors_; } void SetEnableBankConflict(bool enable_bank_conflict) { enable_bank_conflict_ = enable_bank_conflict; } bool GetEnableBankConflict() { return enable_bank_conflict_; } int GetVectorLoadType() { return vector_load_type_; } + void SetVectorLoadType(int vector_load_type) { vector_load_type_ = vector_load_type; } void SetSharedInversedThreadMap(bool shared_inversed_thread_map) { shared_inversed_thread_map_ = shared_inversed_thread_map; } @@ -554,12 +585,14 @@ class UserConfig { bool tile_size_is_var_{false}; bool outer_band_need_split_{false}; - bool enable_tile_c0_{false}; bool enable_atomic_add_{false}; // tensor_core config bool enable_matmul_{false}; bool enable_tensor_core_{false}; + bool pragma_enable_emit_core_{true}; bool enable_tensor_core_use_poly_{false}; + // conv config + bool enable_conv_tensor_core_{false}; // lib config bool enable_akg_reduce_lib_{true}; // memory config @@ -691,6 +724,20 @@ struct ReduceTensorInfo { using ReduceTensorInfoMap = std::unordered_map; +struct Mma { + int64_t m; + int64_t n; + int64_t k; +}; + +struct MmaConv { + int64_t m; + int64_t h; + int64_t w; + int64_t n; + int64_t k; +}; + class AnalysisResult { public: AnalysisResult() = default; @@ -722,27 +769,27 @@ class AnalysisResult { void RecordAtomicMarkers(const std::string &marker_name) { atomic_markers_.insert(marker_name); } void RecordReduceOutTensors(const std::string &tensor_name) { reduce_out_tensors_.insert(tensor_name); } void RecordContextParams(const isl::set &context_params) { context_params_ = context_params; } 
- void RecoreMatrixMatmulMap(const std::string matrix_name, const std::string matrix_position) { + void RecordMatrixMatmulMap(const std::string matrix_name, const std::string matrix_position) { matrix_matmul_map_.emplace(matrix_name, matrix_position); } void RecordCastTensors(const std::string tensor_name) { cast_tensors_.insert(tensor_name); } - void RecoreSharedTensorBitsMap(const std::string tensor_name, const int tensor_bits) { + void RecordSharedTensorBitsMap(const std::string tensor_name, const int tensor_bits) { shared_tensor_bits_map_.emplace(tensor_name, tensor_bits); } std::unordered_map GetSharedTensorBitsMap() const { return shared_tensor_bits_map_; } - void RecoreMatrixMatmulMajor(const std::string matrix_name, const std::string matrix_major) { - matrix_matmul_major_.emplace(matrix_name, matrix_major); + void RecordMatrixMatmulMajor(const std::string matrix_name, const std::string matrix_major) { + matrix_matmul_major_[matrix_name] = matrix_major; } + void SetMmaMode(Mma mma) { mma_ = mma; } std::unordered_map GetMatrixMatmulMap() const { return matrix_matmul_map_; } std::unordered_map GetMatrixMatmulMajor() const { return matrix_matmul_major_; } + Mma GetMmaMode() const { return mma_; } std::unordered_set GetCastTensors() const { return cast_tensors_; } isl::set GetContextParams() { return context_params_; } std::vector GetAtomicTensors() { return atomic_tensors_; } std::unordered_set GetAtomicMarkers() { return atomic_markers_; } std::unordered_set GetReduceOutTensors() { return reduce_out_tensors_; } isl::union_map GetReads() const { return reads_; } - std::unordered_set GetReduceAttrs() const { return reduce_attrs_; } - std::unordered_set GetNotReduceAttrs() const { return not_reduce_attrs_; } isl::union_map &GetWrites() { return writes_; } isl::union_map GetWrites() const { return writes_; } isl::union_map &GetCopyin() { return copyin_; } @@ -799,7 +846,7 @@ class AnalysisResult { int CountBufferDefInfo(const isl::id &tensor_id) const; const 
std::vector &BufferDefInfos() const { return buffer_def_infos_; } - const BufferDefInfo &GetBufferDefInfo(const isl::id &tensor_id) const; + const BufferDefInfo &GetBufferDefInfo(const isl::id &tensor_id, const bool is_dst_tensor_id = true) const; bool HasBufferDefInfo(const isl::id &tensor_id) const; const std::vector> &ActiveBufferFootprints() const { return active_buffer_footprints_; @@ -815,13 +862,28 @@ class AnalysisResult { void RecordReduceAttrs(const std::unordered_set &reduce_attrs) { reduce_attrs_ = std::move(reduce_attrs); } + std::unordered_set GetReduceAttrs() const { return reduce_attrs_; } void ClearReduceAttrs() { reduce_attrs_.clear(); } void RecordNotReduceAttrs(const std::unordered_set ¬_reduce_attrs) { not_reduce_attrs_ = std::move(not_reduce_attrs); } + std::unordered_set GetNotReduceAttrs() const { return not_reduce_attrs_; } void ClearNotReduceAttrs() { not_reduce_attrs_.clear(); } + void RecordReduceAxisForMatmul(const std::vector &reduce_axis) { + reduce_axis_ = std::move(reduce_axis); + } + std::vector GetReduceAxisForMatmul() const { return reduce_axis_; } + + void RecordNotReduceAxisForMatmul(const std::vector ¬_reduce_axis) { + not_reduce_axis_ = std::move(not_reduce_axis); + } + std::vector GetNotReduceAxisForMatmul() const { return not_reduce_axis_; } + + void RecordBatchAxisNumForMatmul(const unsigned int &batch_axis_num) { batch_axis_num_ = std::move(batch_axis_num); } + unsigned int GetBatchAxisNumForMatmul() const { return batch_axis_num_; } + void RecordReduceDirection(const std::string reduce_direction) { reduce_direction_ = reduce_direction; } std::string GetReduceDirection() const { return reduce_direction_; } @@ -857,9 +919,12 @@ class AnalysisResult { ReduceMap reduces_; ReduceTensorInfoMap reduce_tensor_info_; std::string reduce_direction_; + std::vector reduce_init_ids_; std::unordered_set reduce_attrs_; std::unordered_set not_reduce_attrs_; - std::vector reduce_init_ids_; + std::vector reduce_axis_; + std::vector 
not_reduce_axis_; + unsigned int batch_axis_num_; isl::union_map reads_; isl::union_map writes_; @@ -896,6 +961,7 @@ class AnalysisResult { std::unordered_map shared_tensor_bits_map_; TensorScheduleRepo tensor_schedule_repo_; std::unordered_map matrix_matmul_major_; + Mma mma_; }; class CubeInfo { @@ -1056,12 +1122,15 @@ class ScopInfo { std::string GetIslWriteName(const isl::id &cluster_id); static bool IsRead(const isl::id &id) { return IsEndsWith(id.get_name(), kReadSuffix); } static bool IsWrite(const isl::id &id) { return IsEndsWith(id.get_name(), kWriteSuffix); } + static bool IsGMLWrite(const isl::id &id) { return id.get_name() == std::string("GMLwrite"); } static bool IsGMWrite(const isl::id &id) { return id.get_name() == std::string("GMwrite"); } static bool IsGMRead(const isl::id &id) { return id.get_name() == std::string("GMread"); } static bool IsSync(const isl::id &id) { return IsStartsWith(id.name(), SYNC_FLAG); } static bool IsRealize(const isl::id &id) { return IsStartsWith(id.get_name(), "REALIZE"); } static bool IsReduceInit(const isl::id &id) { return IsStartsWith(id.get_name(), "red_init"); } static bool IsReduceUpdate(const isl::id &id) { return IsStartsWith(id.get_name(), "red_update"); } + static bool IsReduceInit(const std::string &name) { return IsStartsWith(name, "red_init"); } + static bool IsReduceUpdate(const std::string &name) { return IsStartsWith(name, "red_update"); } public: isl::ctx ctx_; diff --git a/src/poly/scop_make_schedule_tree.cc b/src/poly/scop_make_schedule_tree.cc index 3c311fd7de38fa33968b329d6a37b7a73e925557..b8b8fb8ffdd39668c2721804a4304fc779d5c8f1 100644 --- a/src/poly/scop_make_schedule_tree.cc +++ b/src/poly/scop_make_schedule_tree.cc @@ -17,6 +17,7 @@ #include "pass/utils.h" #include "construct_poly_accesses.h" #include "poly/scop_builder.h" +#include "poly/schedule_tree_util.h" namespace akg { namespace ir { @@ -171,14 +172,37 @@ class ScopMakeScheduleTree final : protected IRVisitor { auto domain = 
set.unbind_params(op_domain.tuple); sch = isl::schedule::from_domain(domain); - if ((scop_info_.user_config_.GetTarget() == TARGET_CUDA) && (scop_info_.user_config_.GetEnableMatmul())) { - scop_info_.user_config_.SetEnableAkgReduceLib(false); - CheckMatmul(op); + if (scop_info_.user_config_.GetTarget() == TARGET_CUDA && + (scop_info_.user_config_.GetEnableAkgReduceLib() || scop_info_.user_config_.GetEnableMatmul())) { RecordReduceInfo(op, op_domain, id); } - - if (scop_info_.user_config_.GetTarget() == TARGET_CUDA && (scop_info_.user_config_.GetEnableAkgReduceLib())) { - RecordReduceInfo(op, op_domain, id); + auto matmul_map = scop_info_.analysis_result_.GetMatrixMatmulMap(); + if (!matmul_map.empty()) { + std::string accumulator = ""; + auto mp = GetMatmulTensorsName(scop_info_); + if (mp.find(MATRIX_C) != mp.end()) { + accumulator = mp[MATRIX_C]; + } + CHECK(accumulator != "") << "MatMul info not enough!"; + Array elem_tensors = GetBinaryOpExprChildren(op->value); + if (!elem_tensors.empty()) { + auto left = elem_tensors[0].as(); + auto right = elem_tensors[1].as(); + if ((left || right) && (matmul_map.find(left->name) != matmul_map.end() || matmul_map.find(right->name) != matmul_map.end())) { + if (op->func->func_name() != accumulator) { + scop_info_.analysis_result_.RecordMatrixMatmulMap(op->func->func_name(), MATRIX_ELSE); + scop_info_.analysis_result_.RecordMatrixMatmulMajor(op->func->func_name(), ROW_MAJOR); + } + if (left && left->name != accumulator) { + scop_info_.analysis_result_.RecordMatrixMatmulMap(left->name, MATRIX_ELSE); + scop_info_.analysis_result_.RecordMatrixMatmulMajor(left->name, ROW_MAJOR); + } + if (right && right->name != accumulator) { + scop_info_.analysis_result_.RecordMatrixMatmulMap(right->name, MATRIX_ELSE); + scop_info_.analysis_result_.RecordMatrixMatmulMajor(right->name, ROW_MAJOR); + } + } + } } isl::union_map new_reads, new_writes, new_to_inner; @@ -271,8 +295,17 @@ class ScopMakeScheduleTree final : protected IRVisitor { 
reduce_tensor_info.stmt_map = upa; scop_info_.analysis_result_.RecordReduceTensorInfoMap(red_id, reduce_tensor_info); auto type = scop_info_.analysis_result_.GetReduceOpType(red_id); - if (scop_info_.user_config_.GetEnableAkgReduceLib() && AkgSupportedReduceOp.count(type) == 0) { - return; + + bool is_matmul = false; + if (AkgSupportedReduceOp.count(type) == 0) { + is_matmul = CheckMatmul(op); + if (!is_matmul) { + return; + } + } else { + scop_info_.user_config_.SetEnableMatmul(false); + scop_info_.user_config_.SetEnableTensorCore(false); + scop_info_.user_config_.SetEnableTensorCoreUsePoly(false); } reduce_tensor_info.write_tensor_name = op->func->func_name(); @@ -317,95 +350,76 @@ class ScopMakeScheduleTree final : protected IRVisitor { if (reduce_direction.empty()) { LOG(WARNING) << "Cannot identify reduce direction for stmt " << red_id; } - if (scop_info_.user_config_.GetEnableMatmul()) { + if (is_matmul) { reduce_direction = X_DIRECTION; } scop_info_.analysis_result_.RecordReduceDirection(reduce_direction); } - void GetRowColInfo() { - auto sch = scop_info_.user_config_.GetScheduleInfo(); - if (!sch.defined()) { - return; - } - for (air::Stage s : sch->stages) { - const ComputeOpNode *com = s->op.as(); - if (com == nullptr) continue; - - auto axis = com->axis; - auto reduce_axis = com->reduce_axis; - if (axis.size() < 2 || reduce_axis.size() != 1) continue; - - const Variable *axis_var[2]; - const Variable *reduce_axis_var; - axis_var[0] = axis[axis.size() - 2]->var.as(); - axis_var[1] = axis[axis.size() - 1]->var.as(); - reduce_axis_var = reduce_axis[0]->var.as(); - - class CollectInfoOfBody : public IRVisitor { - public: - CollectInfoOfBody() {} - using IRVisitor::Visit_; - - void Visit_(const Reduce *op) final { - auto *comm_add = op->combiner->result[0].as(); - if (comm_add == nullptr || op->combiner->result.size() > 1) { - return; - } - for (Expr source : op->source) { - auto mul_0 = akg::common::SplitCast(source, Float(32)).as(); - auto mul_1 = 
akg::common::SplitCast(source, Int(32)).as(); - if (mul_0 == nullptr && mul_1 == nullptr) { - continue; - } - - is_candidate_ = true; - IRVisitor::Visit(source); - } - } + bool GetRowColInfo(const Provide *op) { + auto axis = scop_info_.analysis_result_.GetNotReduceAxisForMatmul(); + auto reduce_axis = scop_info_.analysis_result_.GetReduceAxisForMatmul(); + auto batch_num_axis = scop_info_.analysis_result_.GetBatchAxisNumForMatmul(); + if (axis.size() < 2 || reduce_axis.size() < 1 || axis.size() <= batch_num_axis) return false; + + const Variable *axis_var[2]; + const Variable *reduce_axis_var; + axis_var[0] = axis[batch_num_axis].as(); + axis_var[1] = axis.back().as(); + reduce_axis_var = reduce_axis.back(); + + class CollectInfoOfBody : public IRVisitor { + public: + CollectInfoOfBody() {} + using IRVisitor::Visit_; + + void Visit_(const Call *op) final { + IRVisitor::Visit_(op); + args_.insert(std::make_pair(op->name, op->args)); + } - void Visit_(const Call *op) final { - IRVisitor::Visit_(op); - args_.insert(std::make_pair(op->name, op->args)); - } + std::unordered_map> GetArgs() { return args_; } - bool GetCandidate() { return is_candidate_; } - std::unordered_map> GetArgs() { return args_; } + private: + std::unordered_map> args_; + } collect_info_of_body; - private: - std::unordered_map> args_; - bool is_candidate_{false}; - } collect_info_of_body; + auto right = op->value; + auto add_op = right.as(); + CHECK(add_op); + auto tensor_c = add_op->a.as(); + if (tensor_c == nullptr) return false; - for (Expr expr : com->body) { - collect_info_of_body.Visit(expr); - } - if (!collect_info_of_body.GetCandidate()) { - continue; - } - for (auto iter : collect_info_of_body.GetArgs()) { - auto name = iter.first; - auto args = iter.second; - if (args.size() < 2) continue; - - const Variable *var0 = args[args.size() - 2].as(); - const Variable *var1 = args[args.size() - 1].as(); - if (var0 == nullptr || var1 == nullptr) continue; - - std::string major; - if ((var0 == 
reduce_axis_var) && (var1 == axis_var[0])) { - major = COL_MAJOR; - } else if ((var0 == reduce_axis_var) && (var1 == axis_var[1])) { - major = ROW_MAJOR; - } else if ((var0 == axis_var[0]) && (var1 == reduce_axis_var)) { - major = ROW_MAJOR; - } else if ((var0 == axis_var[1]) && (var1 == reduce_axis_var)) { - major = COL_MAJOR; - } - scop_info_.analysis_result_.RecoreMatrixMatmulMajor(name, major); + Type tensor_c_type; + if (!IsExistTensor(tensor_c->name, tensor_c_type)) return false; + + collect_info_of_body.Visit(add_op->b); + + for (auto iter : collect_info_of_body.GetArgs()) { + auto name = iter.first; + auto args = iter.second; + if (args.size() < 2) continue; + + const Variable *var0 = args[batch_num_axis].as(); + const Variable *var1 = args[args.size() - 1].as(); + if (var0 == nullptr || var1 == nullptr) continue; + + std::string major; + if ((var0 == reduce_axis_var) && (var1 == axis_var[0])) { + major = COL_MAJOR; + } else if ((var0 == reduce_axis_var) && (var1 == axis_var[1])) { + major = ROW_MAJOR; + } else if ((var0 == axis_var[0]) && (var1 == reduce_axis_var)) { + major = ROW_MAJOR; + } else if ((var0 == axis_var[1]) && (var1 == reduce_axis_var)) { + major = COL_MAJOR; + } else { + return false; } - scop_info_.analysis_result_.RecoreMatrixMatmulMajor(com->name, ROW_MAJOR); + scop_info_.analysis_result_.RecordMatrixMatmulMajor(name, major); } + scop_info_.analysis_result_.RecordMatrixMatmulMajor(op->func->func_name(), ROW_MAJOR); + return true; } bool IsExistTensor(const std::string tensor_name, Type &tensor_type) { @@ -426,7 +440,7 @@ class ScopMakeScheduleTree final : protected IRVisitor { return false; } - std::string GetTensorName(Expr tensor_data) { + std::string GetTensorName(Expr tensor_data, bool &enable_tensor_core) { std::string tensor_name = ""; if (tensor_data.as()) { auto tensor_data_p = tensor_data.as(); @@ -435,7 +449,7 @@ class ScopMakeScheduleTree final : protected IRVisitor { return tensor_name; } if ((tensor_type != Float(16)) && 
(tensor_type != Int(8))) { - return tensor_name; + enable_tensor_core = false; } tensor_name = tensor_data_p->name; } else if (tensor_data.as() && @@ -479,24 +493,71 @@ class ScopMakeScheduleTree final : protected IRVisitor { auto tensor_a = akg::common::SplitCast(mul_op->a, tensor_c_type); auto tensor_b = akg::common::SplitCast(mul_op->b, tensor_c_type); - std::string tensor_a_name = GetTensorName(tensor_a); - std::string tensor_b_name = GetTensorName(tensor_b); + std::string tensor_a_name = GetTensorName(tensor_a, enable_tensor_core); + std::string tensor_b_name = GetTensorName(tensor_b, enable_tensor_core); if (tensor_a_name.empty() || tensor_b_name.empty()) { return false; } - scop_info_.analysis_result_.RecoreMatrixMatmulMap(tensor_a_name, MATRIX_A); - scop_info_.analysis_result_.RecoreMatrixMatmulMap(tensor_b_name, MATRIX_B); - scop_info_.analysis_result_.RecoreMatrixMatmulMap(tensor_c->name, MATRIX_C); + scop_info_.analysis_result_.RecordMatrixMatmulMap(tensor_a_name, MATRIX_A); + scop_info_.analysis_result_.RecordMatrixMatmulMap(tensor_b_name, MATRIX_B); + scop_info_.analysis_result_.RecordMatrixMatmulMap(tensor_c->name, MATRIX_C); + + bool ret = GetRowColInfo(op); + if (!ret) { + return false; + } - GetRowColInfo(); - scop_info_.user_config_.SetEnableTensorCore(enable_tensor_core); + SetMmaModeForTensor(tensor_a_name, tensor_b_name); + scop_info_.user_config_.SetEnableMatmul(true); + scop_info_.user_config_.SetEnableTensorCore(true); + scop_info_.user_config_.SetEnableTensorCoreUsePoly(true); + scop_info_.user_config_.SetVectorLoadType(128); // Default vectorization access mode (128 bits). 
+ scop_info_.user_config_.SetEnableAkgReduceLib(false); + + if (tensor_c_type == Float(16)) { + std::string shared_tensors = tensor_a_name + " " + tensor_b_name + " " + tensor_c->name; + scop_info_.user_config_.SetSharedTensors(shared_tensors); + } return true; } + void SetMmaModeForTensor(const std::string tensor_a_name, const std::string tensor_b_name) { + std::string custom_dim = scop_info_.user_config_.GetBDim(); + if (!custom_dim.empty() && !scop_info_.user_config_.GetEnableConvTensorCore()) { + const int each_axis_size = 4; + const int m_axis_pos = 1; + const int n_axis_pos = 2; + const int k_axis_pos = 3; + + Mma mma; + std::vector dim_str = Split(custom_dim, " "); + int batch_number = static_cast(scop_info_.analysis_result_.GetBatchAxisNumForMatmul()) > 0 ? 1 : 0; + int real_m_axis_pos = (m_axis_pos + batch_number) * each_axis_size - 1; + int real_n_axis_pos = (n_axis_pos + batch_number) * each_axis_size - 1; + int real_k_axis_pos = (k_axis_pos + batch_number) * each_axis_size - 1; + mma.m = static_cast(WrappedStrtol(dim_str[real_m_axis_pos])); + mma.n = static_cast(WrappedStrtol(dim_str[real_n_axis_pos])); + mma.k = static_cast(WrappedStrtol(dim_str[real_k_axis_pos])); + + scop_info_.analysis_result_.SetMmaMode(mma); + return; + } + + Mma mma; + auto matrix_a_major = scop_info_.analysis_result_.GetMatrixMatmulMajor()[tensor_a_name]; + auto matrix_b_major = scop_info_.analysis_result_.GetMatrixMatmulMajor()[tensor_b_name]; + if (matrix_a_major == COL_MAJOR && matrix_b_major == ROW_MAJOR) { + mma = {32, 32, 4}; + } else { + mma = {16, 16, 8}; + } + scop_info_.analysis_result_.SetMmaMode(mma); + } + void Visit_(const Block *op) final { auto sch_first = MakeScheduleTreeHelper(op->first, scop_info_, set, outer, macro_stmt); auto sch_rest = MakeScheduleTreeHelper(op->rest, scop_info_, set, outer, macro_stmt); @@ -738,8 +799,7 @@ class ScopMakeScheduleTree final : protected IRVisitor { } } if (scop_info_.user_config_.GetTarget() == TARGET_CUDA && - 
(scop_info_.user_config_.GetEnableAkgReduceLib() || scop_info_.user_config_.GetEnableTensorCore() || - scop_info_.user_config_.GetEnableMatmul())) { + (scop_info_.user_config_.GetEnableAkgReduceLib() || scop_info_.user_config_.GetEnableMatmul())) { class ExtractReductionAttrs final : public IRVisitor { public: ExtractReductionAttrs(const Stmt stmt, std::unordered_set left_args) @@ -751,33 +811,69 @@ class ScopMakeScheduleTree final : protected IRVisitor { void Visit_(const Variable *op) final { if (!extract_left_args.count(op->name_hint)) { extract_reduce_attrs.insert(op->name_hint); + for (auto &i : extract_reduce_axis) { + if (i == op) return; + } + extract_reduce_axis.push_back(op); } } + void Visit_(const Call *op) final { + if (visited_axis.size() == 0) { + batch_axis_num = op->args.size(); + for (size_t i = 0; i < op->args.size(); i++) { + visited_axis.push_back(op->args[i]); + } + } else { + unsigned int same_axis_num = 0; + for (size_t i = 0; (i < op->args.size()) && (i < visited_axis.size()); i++) { + if (Equal(op->args[i], visited_axis[i])) { + same_axis_num++; + } else { + break; + } + } + if (batch_axis_num > same_axis_num) batch_axis_num = same_axis_num; + } + IRVisitor::Visit_(op); + } + public: std::unordered_set extract_reduce_attrs; std::unordered_set extract_left_args; + std::vector extract_reduce_axis; + std::vector visited_axis; + unsigned int batch_axis_num; }; const auto pro = op->body.as(); + CHECK(pro); for (auto i = 0u; i < pro->args.size(); ++i) { auto args_i = pro->args[i]; auto mod = args_i.as(); if (mod != nullptr && mod->a.as()) { left_args.insert(mod->a.as()->name_hint); + left_axis_for_matmul.push_back(Downcast(mod->a)); } auto div = args_i.as(); if (div != nullptr && div->a.as()) { left_args.insert(div->a.as()->name_hint); + left_axis_for_matmul.push_back(Downcast(div->a)); } if (mod == nullptr && div == nullptr && args_i.as()) { left_args.insert(args_i.as()->name_hint); + left_axis_for_matmul.push_back(Downcast(args_i)); } } 
ExtractReductionAttrs extract_reduce_attr(op->body, left_args); scop_info_.analysis_result_.RecordReduceAttrs(extract_reduce_attr.extract_reduce_attrs); scop_info_.analysis_result_.RecordNotReduceAttrs(left_args); + if (scop_info_.user_config_.GetEnableMatmul()) { + scop_info_.analysis_result_.RecordReduceAxisForMatmul(extract_reduce_attr.extract_reduce_axis); + scop_info_.analysis_result_.RecordNotReduceAxisForMatmul(left_axis_for_matmul); + scop_info_.analysis_result_.RecordBatchAxisNumForMatmul(extract_reduce_attr.batch_axis_num); + } sch = MakeScheduleTreeHelper(op->body, scop_info_, set, outer, macro_stmt); scop_info_.analysis_result_.ClearReduceAttrs(); scop_info_.analysis_result_.ClearNotReduceAttrs(); @@ -808,6 +904,7 @@ class ScopMakeScheduleTree final : protected IRVisitor { isl::id_list outer; ssize_t macro_stmt{-1}; std::unordered_set left_args; + std::vector left_axis_for_matmul; }; isl::schedule MakeScheduleTreeHelper(const NodeRef &s, ScopInfo &scop_info, const isl::set &set, @@ -821,4 +918,4 @@ isl::schedule MakeScheduleTreeHelper(const NodeRef &s, ScopInfo &scop_info, cons } // namespace poly } // namespace ir -} // namespace akg \ No newline at end of file +} // namespace akg diff --git a/src/poly/sync_manager.cc b/src/poly/sync_manager.cc index 9d4bb8eff491db64073243da38d49d2a862edb34..e7cfe65c162b688d1d53e95134e803a820d2df87 100644 --- a/src/poly/sync_manager.cc +++ b/src/poly/sync_manager.cc @@ -65,59 +65,118 @@ isl::map SyncManager::GetExtensionSpace(const isl::schedule_node &node, SyncLeve return extension_space; } -isl::schedule_node SyncManager::InsertPromotionSync(const isl::schedule_node &tree) { - if (!tree.has_parent() || !tree.parent().has_parent()) { - return tree; +bool SyncManager::IsRepeatSync(const isl::schedule_node orig_node) { + // Determine whether there are repeated sync. 
+ auto node = orig_node; + auto is_repeat_sync = false; + while (node.has_children()) { + node = node.child(node.n_children() - 1); } - auto seq_node = tree.parent().parent(); - if (!seq_node.isa()) { - return tree; + + if (node.has_parent() && node.parent().isa()) { + auto filter = node.parent().as().get_filter(); + filter.foreach_set([&is_repeat_sync](isl::set s) { + if (s.get_tuple_name().find_first_of(SYNC_PREFIX) == 0) { + is_repeat_sync = true; + return; + } + }); } + return is_repeat_sync; +} - bool find_read_or_write = false; - for (int i = seq_node.n_children() - 1; i >= 0; --i) { - auto filter_node = seq_node.child(i).as(); - CHECK(filter_node) << "Expected filters below sequence"; - // Transform isl::union_set to a vector of isl::set - isl::union_set uset = filter_node.get_filter(); - std::vector vset; - uset.foreach_set([&vset](isl::set s) { vset.push_back(s); }); - if (!vset.empty() && ((vset[0].get_tuple_name() == READ_ID_NAME) || (vset[0].get_tuple_name() == WRITE_ID_NAME))) { - find_read_or_write = true; - break; +isl::schedule SyncManager::InsertPromotionSync(const isl::schedule &sch) { + auto InsertSyncForSequence = [this](isl::schedule_node node) -> isl::schedule_node { + if (!node.isa()) { + return node; } - } - if (!find_read_or_write) { - return tree; - } - std::string cur_filter_name = ""; - std::string next_filter_name = ""; - for (int i = seq_node.n_children() - 1; i >= 0; --i) { - auto filter_node = seq_node.child(i).as(); - isl::union_set uset = filter_node.get_filter(); - std::vector vset; - uset.foreach_set([&vset](isl::set s) { vset.push_back(s); }); - // Get current filter name - if (!vset.empty()) { - cur_filter_name = vset[0].get_tuple_name(); + if (!node.has_parent() || !node.parent().isa()) { + return node; } - // Do not insert sync after the filter node - if (cur_filter_name == next_filter_name) { - continue; + + if (node.n_children() <= 1) { + return node; } - next_filter_name = cur_filter_name; - if ((cur_filter_name == 
READ_ID_NAME && next_filter_name == WRITE_ID_NAME) || - (cur_filter_name == WRITE_ID_NAME && next_filter_name == READ_ID_NAME)) { - continue; + auto GetCurrentFilterName = [this](isl::schedule_node node) -> std::string { + auto filter_node = node.as(); + CHECK(filter_node) << "Expected filters below sequence"; + // Transform isl::union_set to a vector of isl::set + isl::union_set uset = filter_node.get_filter(); + std::vector vset; + uset.foreach_set([&vset](isl::set s) { vset.push_back(s); }); + // Get current filter name + std::string cur_filter_name = ""; + if (!vset.empty()) { + cur_filter_name = vset[0].get_tuple_name(); + } + return cur_filter_name; + }; + + std::unordered_set shared_promotion_set = {READ_ID_NAME, WRITE_ID_NAME}; + bool find_read_or_write = false; + for (int i = node.n_children() - 1; i >= 0; --i) { + auto filter_node = node.child(i).as(); + CHECK(filter_node) << "Expected filters below sequence"; + std::string cur_filter_name = GetCurrentFilterName(filter_node); + if (!cur_filter_name.empty() && shared_promotion_set.find(cur_filter_name) != shared_promotion_set.end()) { + find_read_or_write = true; + break; + } } - // Insert sync after the filter node - seq_node = InsertExtensionNode(filter_node.child(0), SyncLevel::BLOCK, true).child(0); - } + if (!find_read_or_write) { + return node; + } - return seq_node; + std::string cur_filter_name = ""; + std::string next_filter_name = ""; + for (int i = node.n_children() - 1; i >= 0; --i) { + auto filter_node = node.child(i).as(); + cur_filter_name = GetCurrentFilterName(filter_node); + + // When the current filter and the next filter are the same, do not insert synchronization. + if (cur_filter_name == next_filter_name) { + continue; + } + + // When the current filter and the next filter are shared_read and shared_write at the same time, do not insert + // synchronization. 
+ if (shared_promotion_set.find(cur_filter_name) != shared_promotion_set.end() && + shared_promotion_set.find(next_filter_name) != shared_promotion_set.end()) { + continue; + } + + bool is_continue = false; + // When the current filter and the next filter have nothing to do with shared_read and shared_write, do not insert + // synchronizatio + if (shared_promotion_set.find(cur_filter_name) == shared_promotion_set.end() && + shared_promotion_set.find(next_filter_name) == shared_promotion_set.end()) { + is_continue = true; + // When the first filter is related to shared_read and shared_write, insert synchronization + if (i == static_cast(node.n_children() - 1) && + shared_promotion_set.find(GetCurrentFilterName(node.child(0))) != shared_promotion_set.end()) { + is_continue = false; + } + } + if (is_continue) { + continue; + } + + next_filter_name = cur_filter_name; + + if (IsRepeatSync(filter_node)) { + continue; + } + + // Insert sync after the filter node + node = InsertExtensionNode(filter_node.child(0), SyncLevel::BLOCK, true).child(0); + } + return node; + }; + auto final_sch = sch.get_root().map_descendant_bottom_up(InsertSyncForSequence).get_schedule(); + return final_sch; } } // namespace poly diff --git a/src/poly/sync_manager.h b/src/poly/sync_manager.h index ff24cce86fb6871eb39056c62e40903f05de9f54..245856b7fbd7ed3ee49be150835449d076a22f33 100644 --- a/src/poly/sync_manager.h +++ b/src/poly/sync_manager.h @@ -22,6 +22,7 @@ #include #include #include +#include #include namespace akg { @@ -129,11 +130,13 @@ struct SyncCandidate { std::cout << std::endl; std::cout << "Block level sync count: " << std::endl; for (const auto &p : node->num_block_sync_to) { - std::cout << "[No." << node->idx << "]" << " -> [No." << p.first->idx << "] : #" << p.second << " sync." << std::endl; + std::cout << "[No." << node->idx << "]" + << " -> [No." << p.first->idx << "] : #" << p.second << " sync." 
<< std::endl; } std::cout << "Warp level sync count: " << std::endl; for (const auto &p : node->num_warp_sync_to) { - std::cout << "[No." << node->idx << "]" << " -> [No." << p.first->idx << "] : #" << p.second << " sync." << std::endl; + std::cout << "[No." << node->idx << "]" + << " -> [No." << p.first->idx << "] : #" << p.second << " sync." << std::endl; } std::cout << "====================================================" << std::endl; }); @@ -146,7 +149,7 @@ class SyncManager { ~SyncManager() {} isl::schedule_node InsertExtensionNode(const isl::schedule_node &node, SyncLevel level, bool after); - isl::schedule_node InsertPromotionSync(const isl::schedule_node &tree); + isl::schedule InsertPromotionSync(const isl::schedule &sch); private: isl::ctx ctx_; @@ -157,6 +160,8 @@ class SyncManager { isl::id GetWarpSyncId() const; isl::map GetExtensionSpace(const isl::schedule_node &node, SyncLevel level); + + bool IsRepeatSync(const isl::schedule_node orig_node); }; } // namespace poly } // namespace ir diff --git a/src/poly/tiling/schtree_analyzer.cc b/src/poly/tiling/schtree_analyzer.cc index 4fb339a2b91d5f9b3b46f7eb5cece005c2e9f64c..58a689784364b647fbc0fc97921fe806f69b67cb 100644 --- a/src/poly/tiling/schtree_analyzer.cc +++ b/src/poly/tiling/schtree_analyzer.cc @@ -572,7 +572,7 @@ void ScheduleTreeAnalyzer::AddLoopDataSize() { if (it.first == nullptr) continue; std::vector pros = it.second; for (const Provide *p : pros) { - int data_size = analyzer_->scop_info_.user_config_.GetDataType(p->func->func_name()); + int data_size = analyzer_->scop_info_.user_config_.GetDataBytes(p->func->func_name()); VarNames related_name; auto ExtractName = [this, &related_name](const NodeRef &op) { if (const Call *call = op.as()) { diff --git a/src/poly/tiling/space_analyzer.cc b/src/poly/tiling/space_analyzer.cc index 30186c338dc7e5318c3c55fd1f677c50486b3780..715986ab8c2f4b819362cd4d440827cec2030e64 100644 --- a/src/poly/tiling/space_analyzer.cc +++ 
b/src/poly/tiling/space_analyzer.cc @@ -136,7 +136,7 @@ class SpaceVisitor : public IRVisitor { dst_tensor = MatchLoopByName(dst_tensor); dst_tensor.args = op->args; dst_tensor.band_index = band_count_; - dst_tensor.type_byte = analyzer_->scop_info_.user_config_.GetDataType(dst_tensor.name); + dst_tensor.type_byte = analyzer_->scop_info_.user_config_.GetDataBytes(dst_tensor.name); prov.basic_op_type = basic_op_type.empty() ? GetBasicOpType(dst_tensor, src_tensor) : basic_op_type; prov.flow = GetFlowFromBasicOpType(prov.basic_op_type); prov.band_index = band_count_; @@ -386,6 +386,9 @@ void SpaceAnalyzer::MarkGemmAxes(const ProvideEntry &pe) { auto EmplaceVarsInTensor = [](Tensor tensor, VarNames &var_list) -> void { for (const auto &vars_i : tensor.var_names) { for (const auto &name : vars_i) { + if (IsNum(name)) { + continue; + } var_list.emplace_back(name); } } @@ -427,8 +430,20 @@ void SpaceAnalyzer::MarkGemmAxes(const ProvideEntry &pe) { } // construct relationship between loop indices and loop type(b/m/n/k) and mark axis with corresponding attribute - std::string attr_key = analyzer_->op_type_ == CONV_OP ? 
AT_CONV : AT_GEMM; - auto loop_indices_map = ExtractLoopIndicesFromMatrices({mx_c, mx_a, mx_b}); + std::string attr_key = ""; + if (analyzer_->scop_info_.user_config_.GetEnableConvTensorCore()) { + attr_key = AT_CONV; + } else { + attr_key = AT_GEMM; + } + + std::unordered_map loop_indices_map; + if (analyzer_->scop_info_.user_config_.GetEnableConvTensorCore()) { + loop_indices_map = ExtractLoopIndicesFromMatricesConv({mx_c, mx_a, mx_b}); + } else { + loop_indices_map = ExtractLoopIndicesFromMatrices({mx_c, mx_a, mx_b}); + } + auto FindAxisAndMark = [this, &loop_indices_map, &attr_key](Band loops) { for (const auto &loop : loops) { auto index = loop->loop_var.get()->name_hint; @@ -941,8 +956,29 @@ void SpaceAnalyzer::IdentifyCustomTiling() { const auto mode = ctn->tile_mode.as(); CHECK(mode) << "Custom tiling mode must be set as string"; if (mode->value == "COMMON") { - if (ctn->mem_ratio != -1) { - analyzer_->RootAxis()->MarkWithAttr(AttrInfo{AT_MEM_RATIO, std::to_string(ctn->mem_ratio)}); + if (analyzer_->scop_info_.user_config_.GetTarget() == TARGET_CUDA) { + if (!ctn->thread_min.empty()) { + analyzer_->RootAxis()->MarkWithAttr(AttrInfo{AT_THREAD_MIN, ParseArrayExpr(ctn->thread_min)}); + } + if (!ctn->thread_max.empty()) { + analyzer_->RootAxis()->MarkWithAttr(AttrInfo{AT_THREAD_MAX, ParseArrayExpr(ctn->thread_max)}); + } + if (!ctn->thread_mod.empty()) { + analyzer_->RootAxis()->MarkWithAttr(AttrInfo{AT_THREAD_MOD, ParseArrayExpr(ctn->thread_mod)}); + } + if (!ctn->block_min.empty()) { + analyzer_->RootAxis()->MarkWithAttr(AttrInfo{AT_BLOCK_MIN, ParseArrayExpr(ctn->block_min)}); + } + if (!ctn->block_max.empty()) { + analyzer_->RootAxis()->MarkWithAttr(AttrInfo{AT_BLOCK_MAX, ParseArrayExpr(ctn->block_max)}); + } + if (!ctn->block_mod.empty()) { + analyzer_->RootAxis()->MarkWithAttr(AttrInfo{AT_BLOCK_MOD, ParseArrayExpr(ctn->block_mod)}); + } + } else { + if (ctn->mem_ratio != -1) { + analyzer_->RootAxis()->MarkWithAttr(AttrInfo{AT_MEM_RATIO, 
std::to_string(ctn->mem_ratio)}); + } } } else { std::string attr_value = ""; @@ -1028,6 +1064,15 @@ std::string SpaceAnalyzer::ParseAllTypeExpr(const Expr constraint) { } } +std::string SpaceAnalyzer::ParseArrayExpr(const Array constraint) { + std::stringstream ss; + for (auto val : constraint) { + ss << val; + ss << ","; + } + return ss.str(); +} + bool IsNameMatch(const std::string &match_from, const std::string &match_to) { std::vector pattern = akg::common::Split(match_to, "*"); bool fuzz = pattern.size() > 1U; diff --git a/src/poly/tiling/space_analyzer.h b/src/poly/tiling/space_analyzer.h index 2547877130909b5654b6c4bf39446c53c4577170..bf2afe369c4c92b0f03b7e9eca1f48cc79dc1ef8 100644 --- a/src/poly/tiling/space_analyzer.h +++ b/src/poly/tiling/space_analyzer.h @@ -90,6 +90,7 @@ class SpaceAnalyzer { void SetAttrForTensor(const std::string &tensor_name, int pos, const std::string &attr_key, const std::string &attr_value); std::string ParseAllTypeExpr(const Expr constraint); + std::string ParseArrayExpr(const Array constraint); }; } // namespace poly } // namespace ir diff --git a/src/poly/tiling/tiling_analyzer.cc b/src/poly/tiling/tiling_analyzer.cc index 787f6942cd4f00d1195fb36b9df75aa4b083c795..537777fda38fc634fcbb06e9685987395a587461 100644 --- a/src/poly/tiling/tiling_analyzer.cc +++ b/src/poly/tiling/tiling_analyzer.cc @@ -1353,6 +1353,7 @@ void TilingAnalyzer::AddPostTilingConstraints() { ReduceStrategy reduce_strategy(this); ModStrategy mod_strategy(this); GemmStrategy gemm_strategy(this); + ConvStrategy conv_strategy(this); GpuDmaAnalysisStrategy dma_analysis_strategy(this); CustomTilingStrategy custom_strategy(this); GpuStrategy gpu_strategy(this); @@ -1365,6 +1366,7 @@ void TilingAnalyzer::AddPostTilingConstraints() { actived_strategies.push_back(&reduce_strategy); actived_strategies.push_back(&mod_strategy); actived_strategies.push_back(&gemm_strategy); + actived_strategies.push_back(&conv_strategy); } actived_strategies.push_back(&gpu_strategy); } 
diff --git a/src/poly/tiling/tiling_solver.cc b/src/poly/tiling/tiling_solver.cc index 1f9c43e1e7c718e5b3ae06bc6e27b88a32808126..6aa4a8ded8295f3f8c2fd31168bbcc66b9aa25e4 100644 --- a/src/poly/tiling/tiling_solver.cc +++ b/src/poly/tiling/tiling_solver.cc @@ -847,15 +847,19 @@ TileCandidate *TraverseSolver::Solve() { } } - if (analyzer_.op_type_ == GEMM_OP) { + if (analyzer_.op_type_ == GEMM_OP || analyzer_.scop_info_.user_config_.GetTarget() == TARGET_CUDA) { for (TileAxis *axis : cand_.GetTileAxis()) { std::unique_ptr info(new (std::nothrow) TileInfo(axis, CACHE0, band)); CHECK(info) << "memory alloc fail"; if (IsTilable(info.get())) { - if (DoTiling(info.get())) break; + if (analyzer_.scop_info_.user_config_.GetEnableTensorCoreUsePoly() || DoTiling(info.get())) { + break; + } } } + } + if (analyzer_.op_type_ == GEMM_OP) { std::vector ko_axes = this->analyzer_.GetAxesOfAttr(AttrInfo{AT_GEMM, "ko"}); std::vector mo_axes = this->analyzer_.GetAxesOfAttr(AttrInfo{AT_GEMM, "mo"}); std::vector no_axes = this->analyzer_.GetAxesOfAttr(AttrInfo{AT_GEMM, "no"}); diff --git a/src/poly/tiling/tiling_strategy_manager.h b/src/poly/tiling/tiling_strategy_manager.h index 8c53f7dc358019fef6060c5df481e9934d1f14e6..35a51b6465c653ea8b16d65f966587a1cb4df10c 100755 --- a/src/poly/tiling/tiling_strategy_manager.h +++ b/src/poly/tiling/tiling_strategy_manager.h @@ -281,6 +281,29 @@ class ConvStrategy : public TilingStrategy { void RestrainH(TileAxis *axis); void RestrainW(TileAxis *axis); + + // gpu tensor core strategy steps + std::unique_ptr InitGemmShape(Mma mma); + std::pair CalculateNumOfWarps(Mma mma); + void CalculateMacroMma(MmaConv shape, Mma mma); + void SetFinalConfig(MmaConv macro_mma, Mma mma); + + // Return a combination of total factor that can be divisible by shape_m and shape_n. 
+ std::pair GetDivisibleFactorForMN(int64_t shape_m, int64_t shape_n, int64_t total_factor, Mma mma); + + int w0_for_m_{1}; + int w1_for_n_{1}; + TileAxis *m_axis_{nullptr}; + TileAxis *h_axis_{nullptr}; + TileAxis *w_axis_{nullptr}; + TileAxis *n_axis_{nullptr}; + TileAxis *k_axis_{nullptr}; + int sm_bytes_{1}; + int reg_bytes_{1}; + int64_t num_sm_{80}; + int64_t min_blocks_{400}; + int64_t default_num_warps_{1}; + MmaConv macro_mma_{128, 1, 1, 128, 32}; }; class GemmStrategy : public TilingStrategy { @@ -289,6 +312,32 @@ class GemmStrategy : public TilingStrategy { ~GemmStrategy() {} void AddNpuConstraint(); void AddGpuConstraint(); + + // gpu tensor core strategy steps + std::unique_ptr InitGemmShape(Mma mma); + std::pair CalculateNumOfWarps(Mma mma); + void CalculateMacroMma(Mma shape, Mma mma); + void SetFinalConfig(Mma macro_mma, Mma mma); + + // common utils + int EstimateSharedSize(Mma alloc, int dtype); + int EstimateRegisterSize(Mma alloc, int dtype); + // Return a combination of total factor that can be divisible by shape_m and shape_n. 
+ std::pair GetDivisibleFactorForMN(int64_t shape_m, int64_t shape_n, int64_t total_factor, Mma mma); + + std::string interested_attr_key = AT_GEMM; + int w0_for_m_{1}; + int w1_for_n_{1}; + TileAxis *m_axis_{nullptr}; + TileAxis *n_axis_{nullptr}; + TileAxis *k_axis_{nullptr}; + int sm_bytes_{1}; + int reg_bytes_{1}; + int64_t num_sm_{80}; + int64_t min_blocks_{2048}; + int64_t default_num_warps_{1}; + int64_t tile_stride_{32}; + Mma macro_mma_{128, 128, 32}; }; class GpuStrategy : public TilingStrategy { @@ -306,6 +355,7 @@ class GpuStrategy : public TilingStrategy { TRANSPOSE_OP, PAD_OP, CUSTOM_CONFIG, + CONV, TEMPLATE_BULK }; void AddNpuConstraint(); @@ -371,8 +421,9 @@ class GpuStrategy : public TilingStrategy { bool reverse_binding_{false}; int64_t fused_size_{1}; std::unordered_map template_map_ = { - {0, "DEFAULT"}, {1, "PURE_ELEM"}, {2, "BROADCAST_OP"}, {3, "REDUCTION"}, {4, "ALL_REDUCE"}, - {5, "BITWISE_REDUCTION"}, {6, "MATMUL"}, {7, "TRANSPOSE_OP"}, {8, "PAD_OP"}, {9, "CUSTOM_CONFIG"}}; + {0, "DEFAULT"}, {1, "PURE_ELEM"}, {2, "BROADCAST_OP"}, {3, "REDUCTION"}, + {4, "ALL_REDUCE"}, {5, "BITWISE_REDUCTION"}, {6, "MATMUL"}, {7, "TRANSPOSE_OP"}, + {8, "PAD_OP"}, {9, "CUSTOM_CONFIG"}, {10, "CONV"}}; }; class MulticoreStrategy { diff --git a/src/poly/tiling/tiling_strategy_manager_gpu.cc b/src/poly/tiling/tiling_strategy_manager_gpu.cc index 681dcf3f9ce2d975ffc6b0275c5810d470437f87..7f30ab28b8536dbb79990b66396bfc056db4f520 100755 --- a/src/poly/tiling/tiling_strategy_manager_gpu.cc +++ b/src/poly/tiling/tiling_strategy_manager_gpu.cc @@ -13,11 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ +#include "build_module.h" #include "tiling_strategy_manager.h" #include #include "tiling_analyzer.h" +#include "poly/schedule_pass_gpu/register_memory_manager.h" + namespace akg { namespace ir { namespace poly { @@ -30,24 +33,242 @@ void GpuDmaAnalysisStrategy::AddGpuConstraint() { void CastStrategy::AddGpuConstraint() { MarkDataSize(); } void GemmStrategy::AddGpuConstraint() { - if (!analyzer_->scop_info_.user_config_.GetEnableTensorCore()) { + if (!analyzer_->scop_info_.user_config_.GetEnableTensorCore() || + analyzer_->scop_info_.analysis_result_.GetIsGpuDmaAnalysed() || + analyzer_->scop_info_.user_config_.GetEnableConvTensorCore()) { return; } - auto interested_info = GetInterestedInfo(interested_attr_key); - for (auto it : interested_info) { - TileAxis *axis = it.first; - axis->TileRestrainToSingleValue(CastIntToExpr(64), TileLevel::CACHE1); - axis->TileRestrainToSingleValue(CastIntToExpr(16), TileLevel::CACHE0); - for (const auto &attr : it.second) { - if (attr.attr_value == "mi") { - axis->thread_constraints.map_min_ = warp_sizes_; - axis->thread_constraints.map_extent_ = warp_sizes_; - } else if (attr.attr_value == "ni") { - axis->thread_constraints.map_min_ = 4; - axis->thread_constraints.map_extent_ = 4; + + Mma mma = analyzer_->scop_info_.analysis_result_.GetMmaMode(); + + // Step 1. Collect Batch, M, N, K axis info. 
+ std::unique_ptr shape = InitGemmShape(mma); + if (shape == nullptr) { + return; + } + + Mma middle_band = {shape->m / mma.m, shape->n / mma.n, shape->k / mma.k}; + std::stringstream ss; + ss << "[Gemm] M = " << shape->m << " N = " << shape->n << " K = " << shape->k << ", middle band = [" << middle_band.m + << ", " << middle_band.n << ", " << middle_band.k << "]"; + analyzer_->GetTileLogger().AppendLog(GPU_MAPPING, ss); + + GpuInfo &gpu_info = GpuInfo::GetInstance(); + sm_bytes_ = gpu_info.GetMemoryLimitInScope(MEM_SCOPE_SHARED); + sm_bytes_ = sm_bytes_ / 3 * 4; + reg_bytes_ = MAX_REGISTER_PER_THREAD_BLOCK * REGISTER_ALLOC_RATIO; + + auto b_axes = analyzer_->GetAxesOfAttr(AttrInfo{AT_GEMM, "bi"}); + for (auto bo : analyzer_->GetAxesOfAttr(AttrInfo{AT_GEMM, "bo"})) { + b_axes.push_back(bo); + } + for (auto b_axis : b_axes) { + CHECK(b_axis->range_extent.as()) << "Dynamic shape is not supported in tensor core for now."; + b_axis->TileRestrainToSingleValue(CastIntToExpr(MIN_TILE), CACHE1); + b_axis->TileRestrainToSingleValue(CastIntToExpr(MIN_TILE), CACHE0); + b_axis->thread_constraints.map_min_ = MIN_TILE; + b_axis->thread_constraints.map_extent_ = MIN_TILE; + min_blocks_ /= b_axis->range_extent.as()->value; + ss << "Map batch axis " << b_axis->range_extent.as()->value << " to block."; + analyzer_->GetTileLogger().AppendLog(GPU_MAPPING, ss); + } + min_blocks_ = std::max(1, min_blocks_); + + // Step 2. Calculate macro M, N, K tile size. + CalculateMacroMma(*shape, mma); + + // Step 3. Calculate possible number of warps. + auto warp_sizes = CalculateNumOfWarps(mma); + std::tie(w0_for_m_, w1_for_n_) = warp_sizes; + middle_band.m /= w0_for_m_; + middle_band.n /= w1_for_n_; + std::string warp_cfg = std::to_string(w0_for_m_) + " " + std::to_string(w1_for_n_); + analyzer_->scop_info_.user_config_.RecordReplaceConfig(WARP_COMPUTE, warp_cfg, MappingType::REPLACE_THREADS); + + // Step 4. Set mapping and tiling config. 
+ SetFinalConfig(macro_mma_, mma); +} + +std::pair GemmStrategy::CalculateNumOfWarps(Mma mma) { + int w0 = 1; + int w1 = 1; + int use_local_group = (macro_mma_.m / mma.m) * (macro_mma_.n / mma.n); + CHECK_GE(use_local_group, 1); + if (use_local_group > 8) { + default_num_warps_ = 4; + } else if (use_local_group > 1) { + default_num_warps_ = 2; + } + std::tie(w0, w1) = GetDivisibleFactorForMN(macro_mma_.m, macro_mma_.n, default_num_warps_, mma); + std::stringstream ss; + ss << "[Gemm] Try warp " << default_num_warps_ << " -> " << w0 << " * " << w1; + analyzer_->GetTileLogger().AppendLog(GPU_MAPPING, ss); + return std::make_pair(w0, w1); +} + +std::unique_ptr GemmStrategy::InitGemmShape(Mma mma) { + auto m_axes = analyzer_->GetAxesOfAttr(AttrInfo{AT_GEMM, "mi"}); + auto n_axes = analyzer_->GetAxesOfAttr(AttrInfo{AT_GEMM, "ni"}); + auto k_axes = analyzer_->GetAxesOfAttr(AttrInfo{AT_GEMM, "ki"}); + if (m_axes.size() != 1U || n_axes.size() != 1U || k_axes.size() != 1U) { + return nullptr; + } + + m_axis_ = m_axes[0]; + n_axis_ = n_axes[0]; + k_axis_ = k_axes[0]; + if (m_axis_->range_extent.as() == nullptr || n_axis_->range_extent.as() == nullptr || + k_axis_->range_extent.as() == nullptr) { + return nullptr; + } + auto shape_m = m_axis_->range_extent.as()->value; + auto shape_n = n_axis_->range_extent.as()->value; + auto shape_k = k_axis_->range_extent.as()->value; + CHECK_EQ(shape_m % mma.m, 0) << "Shape m " << shape_m << " should be multiples of mma.m " << mma.m + << " to enable tensor core."; + CHECK_EQ(shape_n % mma.n, 0) << "Shape n " << shape_n << " should be multiples of mma.n " << mma.n + << " to enable tensor core."; + CHECK_EQ(shape_k % mma.k, 0) << "Shape k " << shape_k << " should be multiples of mma.k " << mma.k + << " to enable tensor core."; + + return std::unique_ptr(new (std::nothrow) Mma{shape_m, shape_n, shape_k}); +} + +int GemmStrategy::EstimateSharedSize(Mma alloc, int dtype) { + std::string a_major = ROW_MAJOR; + std::string b_major = ROW_MAJOR; + 
auto major_map = analyzer_->scop_info_.analysis_result_.GetMatrixMatmulMajor(); + auto matmul_map = analyzer_->scop_info_.analysis_result_.GetMatrixMatmulMap(); + for (auto i : matmul_map) { + if (i.second == MATRIX_A) { + CHECK(major_map.find(i.first) != major_map.end()); + a_major = major_map[i.first]; + } else if (i.second == MATRIX_B) { + CHECK(major_map.find(i.first) != major_map.end()); + b_major = major_map[i.first]; + } + } + + // bank conflit avoid strategy + auto matrix_a_size = a_major == ROW_MAJOR ? (alloc.m * (alloc.k + 16)) : ((alloc.m + 16) * alloc.k); + auto matrix_b_size = b_major == COL_MAJOR ? (alloc.n * (alloc.k + 16)) : ((alloc.n + 16) * alloc.k); + auto matrix_c_size = alloc.m * alloc.n; + auto alloc_shared = (matrix_a_size + matrix_b_size) * dtype; // single op does not alloc shared for matrix_c + std::stringstream ss; + ss << "[Shared] A(" << a_major << "), B(" << b_major << "); This config results matrix_a_size = " << matrix_a_size + << " matrix_b_size = " << matrix_b_size << " matrix_c_size = " << matrix_c_size + << " --> alloc shared = " << alloc_shared; + analyzer_->GetTileLogger().AppendLog(GPU_MAPPING, ss); + return alloc_shared; +} + +int GemmStrategy::EstimateRegisterSize(Mma alloc, int dtype) { + auto alloc_reg_unit = std::max(1, dtype / BYTES_PER_REGISTER); + auto matrix_a_size = alloc.m * alloc.k * 2; + auto matrix_b_size = alloc.n * alloc.k * 2; + auto matrix_c_size = alloc.m * alloc.n; + auto alloc_reg = (matrix_a_size + matrix_b_size + matrix_c_size) * alloc_reg_unit; + std::stringstream ss; + ss << "[Reg] This config results matrix_a_size = " << matrix_a_size << " matrix_b_size = " << matrix_b_size + << " matrix_c_size = " << matrix_c_size << " --> alloc reg = " << alloc_reg; + analyzer_->GetTileLogger().AppendLog(GPU_MAPPING, ss); + + return alloc_reg; +} + +void GemmStrategy::CalculateMacroMma(Mma shape, Mma mma) { + std::stringstream ss; + Mma default_macro_mma = macro_mma_; + Mma macro_mma = {std::min(macro_mma_.m, 
shape.m), std::min(macro_mma_.n, shape.n), + std::min(macro_mma_.k, shape.k)}; + ss << "[Init macro mma]: [" << macro_mma.m << ", " << macro_mma.n << ", " << macro_mma.k << "]"; + analyzer_->GetTileLogger().AppendLog(GPU_MAPPING, ss); + while (shape.m % macro_mma_.m != 0 && macro_mma_.m - tile_stride_ >= mma.m) { + macro_mma_.m -= tile_stride_; + } + while (shape.n % macro_mma_.n != 0 && macro_mma_.n - tile_stride_ >= mma.n) { + macro_mma_.n -= tile_stride_; + } + if (shape.m % macro_mma_.m != 0) { + macro_mma_.m /= 2; + } + if (shape.n % macro_mma_.n != 0) { + macro_mma_.n /= 2; + } + while (shape.k % macro_mma_.k != 0 && macro_mma_.k / 2 >= mma.k) { + macro_mma_.k /= 2; + } + while ((shape.m / macro_mma_.m) * (shape.n / macro_mma_.n) < min_blocks_ && macro_mma_.m == default_macro_mma.m && + macro_mma_.n == default_macro_mma.n) { + (shape.m < shape.n) ? macro_mma_.m /= 2 : macro_mma_.n /= 2; + } + if ((shape.m / macro_mma_.m) * (shape.n / macro_mma_.n) < min_blocks_ && shape.k % (macro_mma_.k * 2) == 0 && + shape.k / (macro_mma_.k * 2) > 1) { + macro_mma_.k *= 2; + } + if (shape.k == macro_mma_.k) { + g_attrs.Set(kEnableTransferBuffer, air::make_const(Int(32), false)); + } + ss << "[Final macro mma]: [" << macro_mma.m << ", " << macro_mma.n << ", " << macro_mma.k << "]"; + analyzer_->GetTileLogger().AppendLog(GPU_MAPPING, ss); +} + +void GemmStrategy::SetFinalConfig(Mma macro_mma, Mma mma) { + std::stringstream ss; + m_axis_->TileRestrainToSingleValue(CastIntToExpr(macro_mma.m), CACHE1); + m_axis_->thread_constraints.map_min_ = w0_for_m_ * w1_for_n_; + m_axis_->thread_constraints.map_extent_ = w0_for_m_ * w1_for_n_; + m_axis_->TileRestrainToSingleValue(CastIntToExpr(mma.m), CACHE0); + + n_axis_->TileRestrainToSingleValue(CastIntToExpr(macro_mma.n), CACHE1); + n_axis_->thread_constraints.map_min_ = warp_sizes_; + n_axis_->thread_constraints.map_extent_ = warp_sizes_; + n_axis_->TileRestrainToSingleValue(CastIntToExpr(mma.n), CACHE0); + + 
k_axis_->TileRestrainToSingleValue(CastIntToExpr(macro_mma.k), CACHE1); + k_axis_->thread_constraints.map_min_ = MIN_TILE; + k_axis_->thread_constraints.map_extent_ = MIN_TILE; + k_axis_->TileRestrainToSingleValue(CastIntToExpr(mma.k), CACHE0); + ss << "[Final config] : L1(M, N, K) = " << macro_mma.m << ", " << macro_mma.n << ", " << macro_mma.k; + ss << "; L0(M, N, K) = " << mma.m << ", " << mma.n << ", " << mma.k; + ss << "; Thread(W0, W1, TX) = " << w0_for_m_ << ", " << w1_for_n_ << ", " << warp_sizes_; + analyzer_->GetTileLogger().AppendLog(GPU_MAPPING, ss); +} + +std::pair GemmStrategy::GetDivisibleFactorForMN(int64_t shape_m, int64_t shape_n, + int64_t total_factor, Mma mma) { + auto TryCombination = [&shape_m, &shape_n, &mma](int64_t factor1, int64_t factor2) -> bool { + return (shape_m % factor1 == 0 && shape_n % factor2 == 0 && shape_m / factor1 >= mma.m && + shape_n / factor2 >= mma.n); + }; + auto SwapWarp = [&shape_m, &shape_n, &mma](int64_t w0, int64_t w1) -> std::pair { + int64_t max_w0 = shape_m / mma.m; + int64_t max_w1 = shape_n / mma.n; + if ((max_w0 - max_w1 > 0) ^ (w0 - w1 > 0)) { + return std::make_pair(w1, w0); + } + return std::make_pair(w0, w1); + }; + int64_t w0 = std::sqrt(total_factor); + int64_t w1 = total_factor / w0; + CHECK_EQ(w0 * w1, total_factor); + std::tie(w0, w1) = SwapWarp(w0, w1); + + if (TryCombination(w0, w1)) { + return std::make_pair(w0, w1); + } else { + while (total_factor > 1) { + total_factor /= 2; + w0 = std::sqrt(total_factor); + w1 = total_factor / w0; + CHECK_EQ(w0 * w1, total_factor); + std::tie(w0, w1) = SwapWarp(w0, w1); + if (TryCombination(w0, w1)) { + return std::make_pair(w0, w1); } } } + return std::make_pair(1, 1); } void ReduceStrategy::AddGpuConstraint() { @@ -530,7 +751,8 @@ void GpuStrategy::AddGpuConstraint() { InjectiveSpeedup(); } SetMappingConfig(); - if (template_ != Template::MATMUL || !analyzer_->scop_info_.user_config_.GetEnableTensorCore()) { + if (!((template_ == Template::MATMUL || template_ 
== Template::CONV) && + analyzer_->scop_info_.user_config_.GetEnableTensorCore())) { analyzer_->ForEachAxisTopDown([this](TileAxis *axis) { if (axis == analyzer_->RootAxis()) { return; @@ -604,6 +826,8 @@ void GpuStrategy::InitMappingLimit() { block_limit_ = {max_x_dim_block_, max_y_z_dim_block_, max_y_z_dim_block_}; } else if (template_ == Template::ALL_REDUCE && !analyzer_->scop_info_.user_config_.GetEnableAkgReduceLib()) { block_limit_ = {1}; + } else if (template_ == Template::CONV) { + block_limit_ = {max_x_dim_block_, max_y_z_dim_block_, max_y_z_dim_block_, max_y_z_dim_block_}; } else { block_limit_ = {max_x_dim_block_, max_y_z_dim_block_, max_y_z_dim_block_}; } @@ -622,7 +846,8 @@ void GpuStrategy::BuildAxesQueue() { return; } const auto r = axis->range_extent.as(); - if (r && r->value > 0) { + // For Conv, kh and kw are invalid for pending_axes + if (r && r->value > 0 && !axis->is_inner) { this->pending_axes_.push_front(std::make_pair(axis, r->value)); } @@ -645,6 +870,10 @@ void GpuStrategy::InnerThreadOuterBlock() { auto thread_dim = std::min(thread_limit_.size(), max_dim_); auto block_dim = std::min(block_limit_.size(), max_dim_); + if (analyzer_->scop_info_.user_config_.GetEnableConvTensorCore()) { + block_dim = block_limit_.size(); + } + // tile from inner to outer and map to thread analyzer_->GetTileLogger().AppendLine(GPU_MAPPING, "-----Map to thread-----"); ss << "[Thread Limit]: "; @@ -702,6 +931,13 @@ void GpuStrategy::InnerThreadOuterBlock() { SkipMapping(); continue; } + + // For Conv, hi and wi are invalid for thread mapping + if (axis->HasAttr(AttrInfo{AT_CONV, "hi"}) || axis->HasAttr(AttrInfo{AT_CONV, "wi"})) { + SkipMapping(); + continue; + } + if (rest_threads <= 1) { if (axis->mc_sup || (template_ == Template::REDUCTION && analyzer_->scop_info_.user_config_.GetEnableAkgReduceLib())) { @@ -855,14 +1091,33 @@ void GpuStrategy::SetMappingConfig() { } } - for (size_t i = block_cfg_.size(); i < 2; ++i) { - block_cfg_.emplace_back(1); - } - for 
(int i = block_cfg_.size() - 1; i >= 0; --i) { - if (i >= block_count_) { - continue; + if (analyzer_->scop_info_.user_config_.GetEnableConvTensorCore()) { + // For Conv, the H and W axis should mul to map + // The axis sequence is M H W OC + constexpr auto h_axis_index = 2; + constexpr auto w_axis_index = 1; + for (int i = block_cfg_.size() - 1; i >= 0; --i) { + if (i >= block_count_) { + continue; + } + if (i == h_axis_index) { + block_str += (std::to_string(block_cfg_[h_axis_index] * block_cfg_[w_axis_index]) + " "); + i = w_axis_index; + } else { + block_str += (std::to_string(block_cfg_[i]) + " "); + } + } + } else { + // pad binding to at least two dim to bind reduce axis at block y + for (size_t i = block_cfg_.size(); i < 2; ++i) { + block_cfg_.emplace_back(1); + } + for (int i = block_cfg_.size() - 1; i >= 0; --i) { + if (i >= block_count_) { + continue; + } + block_str += (std::to_string(block_cfg_[i]) + " "); } - block_str += (std::to_string(block_cfg_[i]) + " "); } ss << "Block config = " << block_str; @@ -874,6 +1129,10 @@ void GpuStrategy::SetMappingConfig() { if (axis == analyzer_->RootAxis()) { return; } + // For Conv, kh and kw are invalid for Tile + if (axis->is_inner) { + return; + } ss << axis->c1_constraints.tile_extent_ << ","; }); analyzer_->GetTileLogger().AppendLog(GPU_MAPPING, ss); @@ -1007,6 +1266,11 @@ void GpuStrategy::DetermineTemplate() { return; } + if (!analyzer_->GetAxesOfAttr(AT_CONV).empty()) { + template_ = Template::CONV; + return; + } + auto reduce_axes_ = analyzer_->GetAxesOfAttr(AT_REDUCE_AXIS); if (reduce_axes_.empty()) { @@ -1488,6 +1752,234 @@ void CustomTilingStrategy::AddGpuConstraint() { } } +void ConvStrategy::AddGpuConstraint() { + if (!analyzer_->scop_info_.user_config_.GetEnableTensorCore() || + analyzer_->scop_info_.analysis_result_.GetIsGpuDmaAnalysed() || + !analyzer_->scop_info_.user_config_.GetEnableConvTensorCore()) { + return; + } + + Mma mma = analyzer_->scop_info_.analysis_result_.GetMmaMode(); + + // Step 
1. Collect M, H, W, N, K axis info. + std::unique_ptr shape = InitGemmShape(mma); + if (shape == nullptr) { + return; + } + + MmaConv middle_band = {shape->m / mma.m, shape->h, shape->w, shape->n / mma.n, shape->k / mma.k}; + std::stringstream ss; + ss << "[Conv] M = " << shape->m << " H = " << shape->h << " W = " << shape->w << " N = " << shape->n + << " K = " << shape->k << ", middle band = [" << middle_band.m << ", " << middle_band.h << ", " << middle_band.w + << middle_band.n << ", " << middle_band.k << "]"; + analyzer_->GetTileLogger().AppendLog(GPU_MAPPING, ss); + + // Step 2. Calculate macro M, H, W, N, K tile size. + CalculateMacroMma(*shape, mma); + + // Step 3. Calculate possible number of warps. + auto warp_sizes = CalculateNumOfWarps(mma); + std::tie(w0_for_m_, w1_for_n_) = warp_sizes; + middle_band.m /= w0_for_m_; + middle_band.n /= w1_for_n_; + std::string warp_cfg = std::to_string(w0_for_m_) + " " + std::to_string(w1_for_n_); + analyzer_->scop_info_.user_config_.RecordReplaceConfig(WARP_COMPUTE, warp_cfg, MappingType::REPLACE_THREADS); + analyzer_->GetTileLogger().AppendLog(GPU_MAPPING, ss); + + // Step 4. Set mapping and tiling config. 
+ SetFinalConfig(macro_mma_, mma); +} + +std::pair ConvStrategy::CalculateNumOfWarps(Mma mma) { + int w0 = 1; + int w1 = 1; + // H and W do not participate in the calculation of the warp level + int use_local_group = (macro_mma_.m / mma.m) * (macro_mma_.n / mma.n); + CHECK_GE(use_local_group, 1); + if (use_local_group >= 8) { + default_num_warps_ = 4; + } else if (use_local_group > 1) { + default_num_warps_ = 2; + } + std::tie(w0, w1) = GetDivisibleFactorForMN(macro_mma_.m, macro_mma_.n, default_num_warps_, mma); + std::stringstream ss; + ss << "[Conv] Try warp " << default_num_warps_ << " -> " << w0 << " * " << w1; + analyzer_->GetTileLogger().AppendLog(GPU_MAPPING, ss); + return std::make_pair(w0, w1); +} + +std::unique_ptr ConvStrategy::InitGemmShape(Mma mma) { + auto m_axes = analyzer_->GetAxesOfAttr(AttrInfo{AT_CONV, "mi"}); + auto h_axes = analyzer_->GetAxesOfAttr(AttrInfo{AT_CONV, "hi"}); + auto w_axes = analyzer_->GetAxesOfAttr(AttrInfo{AT_CONV, "wi"}); + auto n_axes = analyzer_->GetAxesOfAttr(AttrInfo{AT_CONV, "oc"}); + auto k_axes = analyzer_->GetAxesOfAttr(AttrInfo{AT_CONV, "ic"}); + if (m_axes.size() != 1U || h_axes.size() != 1U || w_axes.size() != 1U || n_axes.size() != 1U || k_axes.size() != 1U) { + return nullptr; + } + + m_axis_ = m_axes[0]; + h_axis_ = h_axes[0]; + w_axis_ = w_axes[0]; + n_axis_ = n_axes[0]; + k_axis_ = k_axes[0]; + if (m_axis_->range_extent.as() == nullptr || h_axis_->range_extent.as() == nullptr || + w_axis_->range_extent.as() == nullptr || n_axis_->range_extent.as() == nullptr || + k_axis_->range_extent.as() == nullptr) { + return nullptr; + } + auto shape_m = m_axis_->range_extent.as()->value; + auto shape_h = h_axis_->range_extent.as()->value; + auto shape_w = w_axis_->range_extent.as()->value; + auto shape_n = n_axis_->range_extent.as()->value; + auto shape_k = k_axis_->range_extent.as()->value; + CHECK_EQ(shape_m % mma.m, 0) << "Shape m " << shape_m << " should be multiples of mma.m " << mma.m + << " to enable tensor core."; 
+ CHECK_EQ(shape_n % mma.n, 0) << "Shape n " << shape_n << " should be multiples of mma.n " << mma.n + << " to enable tensor core."; + CHECK_EQ(shape_k % mma.k, 0) << "Shape k " << shape_k << " should be multiples of mma.k " << mma.k + << " to enable tensor core."; + + return std::unique_ptr(new (std::nothrow) MmaConv{shape_m, shape_h, shape_w, shape_n, shape_k}); +} + +void ConvStrategy::CalculateMacroMma(MmaConv shape, Mma mma) { + std::stringstream ss; + MmaConv macro_mma = {std::min(macro_mma_.m, shape.m), std::min(macro_mma_.h, shape.h), + std::min(macro_mma_.w, shape.w), std::min(macro_mma_.n, shape.n), + std::min(macro_mma_.k, shape.k)}; + ss << "[Init macro mma]: [" << macro_mma.m << ", " << macro_mma.h << ", " << macro_mma.w << ", " << macro_mma.n + << ", " << macro_mma.k << "]"; + analyzer_->GetTileLogger().AppendLog(GPU_MAPPING, ss); + while (shape.m % macro_mma_.m != 0 && macro_mma_.m / 2 >= mma.m) { + macro_mma_.m /= 2; + } + while (shape.n % macro_mma_.n != 0 && macro_mma_.n / 2 >= mma.n) { + macro_mma_.n /= 2; + } + while (shape.k % macro_mma_.k != 0 && macro_mma_.k / 2 >= mma.k) { + macro_mma_.k /= 2; + } + + // Data volume in the M direction and data volume in the N direction should be close + if (macro_mma_.m > macro_mma_.n) { + while (macro_mma_.m > macro_mma_.n) { + macro_mma_.m /= 2; + } + } else if (macro_mma_.m < macro_mma_.n) { + // split h and w direction, increase the data volume + int temp_h = shape.h; + int temp_w = shape.w; + while (macro_mma_.m * macro_mma_.w * macro_mma_.h < macro_mma_.n) { + if (temp_w % 2 == 0) { + macro_mma_.w *= 2; + temp_w /= 2; + } else if (temp_h % 2 == 0) { + macro_mma_.h *= 2; + temp_h /= 2; + } else { + break; + } + } + } + + while ((shape.m / macro_mma_.m) * (shape.h / macro_mma_.h) * (shape.w / macro_mma_.w) * (shape.n / macro_mma_.n) < + min_blocks_ && + macro_mma_.m / mma.m > 4 && macro_mma_.n / mma.n > 4) { + // decrease h and increase the use of block + if (macro_mma_.h % 2 == 0) { + macro_mma_.h /= 2; 
+ continue; + } + + // decrease w and increase the use of block + if (macro_mma_.w % 2 == 0) { + macro_mma_.w /= 2; + continue; + } + + (shape.m < shape.n) ? macro_mma_.m /= 2 : macro_mma_.n /= 2; + } + + if ((shape.m / macro_mma_.m) * (shape.h / macro_mma_.h) * (shape.w / macro_mma_.w) * (shape.n / macro_mma_.n) < + min_blocks_ && + shape.k % (macro_mma_.k * 2) == 0 && shape.k / (macro_mma_.k * 2) > 1) { + macro_mma_.k *= 2; + } + ss << "[Final macro mma]: [" << macro_mma.m << ", " << macro_mma.h << ", " << macro_mma.w << ", " << macro_mma.n + << ", " << macro_mma.k << "]"; + analyzer_->GetTileLogger().AppendLog(GPU_MAPPING, ss); +} + +void ConvStrategy::SetFinalConfig(MmaConv macro_mma, Mma mma) { + std::stringstream ss; + m_axis_->TileRestrainToSingleValue(CastIntToExpr(macro_mma.m), CACHE1); + m_axis_->thread_constraints.map_min_ = w0_for_m_ * w1_for_n_; + m_axis_->thread_constraints.map_extent_ = w0_for_m_ * w1_for_n_; + m_axis_->TileRestrainToSingleValue(CastIntToExpr(mma.m), CACHE0); + + h_axis_->TileRestrainToSingleValue(CastIntToExpr(macro_mma.h), CACHE1); + h_axis_->thread_constraints.map_min_ = MIN_TILE; + h_axis_->thread_constraints.map_extent_ = MIN_TILE; + h_axis_->TileRestrainToSingleValue(CastIntToExpr(1), CACHE0); + + w_axis_->TileRestrainToSingleValue(CastIntToExpr(macro_mma.w), CACHE1); + w_axis_->thread_constraints.map_min_ = MIN_TILE; + w_axis_->thread_constraints.map_extent_ = MIN_TILE; + w_axis_->TileRestrainToSingleValue(CastIntToExpr(1), CACHE0); + + n_axis_->TileRestrainToSingleValue(CastIntToExpr(macro_mma.n), CACHE1); + n_axis_->thread_constraints.map_min_ = warp_sizes_; + n_axis_->thread_constraints.map_extent_ = warp_sizes_; + n_axis_->TileRestrainToSingleValue(CastIntToExpr(mma.n), CACHE0); + + k_axis_->TileRestrainToSingleValue(CastIntToExpr(macro_mma.k), CACHE1); + k_axis_->thread_constraints.map_min_ = MIN_TILE; + k_axis_->thread_constraints.map_extent_ = MIN_TILE; + k_axis_->TileRestrainToSingleValue(CastIntToExpr(mma.k), CACHE0); 
+ ss << "[Final config] : L1(M, H, W, N, K) = " << macro_mma.m << ", " << macro_mma.h << ", " << macro_mma.w << ", " + << macro_mma.n << ", " << macro_mma.k; + ss << "; L0(M, H, W, N, K) = " << mma.m << ", " << 1 << ", " << 1 << ", " << mma.n << ", " << mma.k; + ss << "; Thread(W0, W1, TX) = " << w0_for_m_ << ", " << w1_for_n_ << ", " << warp_sizes_; + analyzer_->GetTileLogger().AppendLog(GPU_MAPPING, ss); +} + +std::pair ConvStrategy::GetDivisibleFactorForMN(int64_t shape_m, int64_t shape_n, + int64_t total_factor, Mma mma) { + auto TryCombination = [&shape_m, &shape_n, &mma](int64_t factor1, int64_t factor2) -> bool { + return (shape_m % factor1 == 0 && shape_n % factor2 == 0 && shape_m / factor1 >= mma.m && + shape_n / factor2 >= mma.n); + }; + auto SwapWarp = [&shape_m, &shape_n, &mma](int64_t w0, int64_t w1) -> std::pair { + int64_t max_w0 = shape_m / mma.m; + int64_t max_w1 = shape_n / mma.n; + if ((max_w0 - max_w1 > 0) ^ (w0 - w1 > 0)) { + return std::make_pair(w1, w0); + } + return std::make_pair(w0, w1); + }; + int64_t w0 = std::sqrt(total_factor); + int64_t w1 = total_factor / w0; + CHECK_EQ(w0 * w1, total_factor); + std::tie(w0, w1) = SwapWarp(w0, w1); + + if (TryCombination(w0, w1)) { + return std::make_pair(w0, w1); + } else { + while (total_factor > 1) { + total_factor /= 2; + w0 = std::sqrt(total_factor); + w1 = total_factor / w0; + CHECK_EQ(w0 * w1, total_factor); + std::tie(w0, w1) = SwapWarp(w0, w1); + if (TryCombination(w0, w1)) { + return std::make_pair(w0, w1); + } + } + } + return std::make_pair(1, 1); +} + // No constraint found in cuda void ModStrategy::AddGpuConstraint() {} @@ -1510,8 +2002,6 @@ void ShiftAxisStrategy::AddGpuConstraint() {} void ModShiftAxisStrategy::AddGpuConstraint() {} -void ConvStrategy::AddGpuConstraint() {} - // end of null constraint } // namespace poly diff --git a/src/poly/tiling/tiling_utils.cc b/src/poly/tiling/tiling_utils.cc index 
570bfa371e6109cecc76b5d23fd917bc75fd2bcc..432adf81e6af4e1af9a7cfa934b87a86cd753195 100644 --- a/src/poly/tiling/tiling_utils.cc +++ b/src/poly/tiling/tiling_utils.cc @@ -189,6 +189,56 @@ std::unordered_map ExtractLoopIndicesFromMatrices(std: return cube_var_map; } +std::unordered_map ExtractLoopIndicesFromMatricesConv(std::vector var_names_list) { + CHECK_EQ(var_names_list.size(), 3) + << "Matmul should have exactly three matrices in C(output), A(lhs) and B(rhs) order."; + VarNames mx_c = var_names_list[0]; + VarNames mx_a = var_names_list[1]; + VarNames mx_b = var_names_list[2]; + + VarNames gemm_m, gemm_n, gemm_k; + std::unordered_set stack; + + for (const auto &var : mx_a) { + stack.insert(var); + } + + // 1. N = B_vars - A_vars; + // [B, K] = A_vars & B_vars + for (const auto &var : mx_b) { + auto it = stack.find(var); + if (it != stack.end()) { + gemm_k.emplace_back(var); + stack.erase(it); + } else { + gemm_n.emplace_back(var); + } + } + + // 2. M = A_vars - B - K + for (const auto &var : mx_a) { + if (stack.find(var) != stack.end()) { + gemm_m.emplace_back(var); + } + } + + CHECK_LE(gemm_m.size(), ConvFormatM.size()); + CHECK_LE(gemm_n.size(), ConvFormatN.size()); + CHECK_LE(gemm_k.size(), ConvFormatK.size()); + + std::unordered_map cube_var_map; + for (auto i = static_cast(gemm_m.size()) - 1; i >= 0; --i) { + cube_var_map[gemm_m[i]] = ConvFormatM[static_cast(gemm_m.size()) - 1 - i]; + } + for (auto i = static_cast(gemm_n.size()) - 1; i >= 0; --i) { + cube_var_map[gemm_n[i]] = ConvFormatN[static_cast(gemm_n.size()) - 1 - i]; + } + for (auto i = static_cast(gemm_k.size()) - 1; i >= 0; --i) { + cube_var_map[gemm_k[i]] = ConvFormatK[static_cast(gemm_k.size()) - 1 - i]; + } + return cube_var_map; +} + VarNames VisitVarNames(const air::Expr &arg, VarNames var_names, bool add_num) { if (const auto var = arg.as()) { var_names.emplace_back(var->name_hint); @@ -221,6 +271,15 @@ VarNames VisitVarNames(const air::Expr &arg, VarNames var_names, bool add_num) { return 
var_names; } +bool IsNum(const std::string &name) { + for (auto c : name) { + if (c > '9' || c < '0') { + return false; + } + } + return true; +}; + } // namespace poly } // namespace ir } // namespace akg diff --git a/src/poly/tiling/tiling_utils.h b/src/poly/tiling/tiling_utils.h index 35fefcec6540e4fc9ab0ec3ee20ae14d4ce386d6..e7f5eb57584797e018b89a535fbd7860179ba170 100644 --- a/src/poly/tiling/tiling_utils.h +++ b/src/poly/tiling/tiling_utils.h @@ -141,9 +141,12 @@ using Band = std::vector; using VarNames = std::vector; std::unordered_map ExtractLoopIndicesFromMatrices(std::vector var_names_list); +std::unordered_map ExtractLoopIndicesFromMatricesConv(std::vector var_names_list); VarNames VisitVarNames(const air::Expr &arg, VarNames var_names, bool add_num = true); +bool IsNum(const std::string &name); + /* Data format definition */ const VarNames DsaNCHW = {"N", "C", "H", "W", "C0"}; const VarNames DsaNHWCC0 = {"N", "H", "W", "C", "C0"}; @@ -161,6 +164,10 @@ const VarNames FormatN = {"ni", "no"}; const VarNames FormatK = {"ki", "ko"}; const VarNames FormatB = {"bi", "bo"}; +const VarNames ConvFormatM = {"wi", "hi", "mi"}; +const VarNames ConvFormatN = {"oc"}; +const VarNames ConvFormatK = {"ic", "kw", "kh"}; + } // namespace poly } // namespace ir } // namespace akg diff --git a/src/schedule/auto_fuse.cc b/src/schedule/auto_fuse.cc index 136c33ea75455fe71f080fded342e7ea40fcb637..52040ec78a1440f5c106aadb649627e5485fdab3 100644 --- a/src/schedule/auto_fuse.cc +++ b/src/schedule/auto_fuse.cc @@ -19,6 +19,7 @@ #include #include +#include "common/common_util.h" #include "pass/utils.h" struct FuncIndex { @@ -239,10 +240,17 @@ class FuseCheck { for (const auto &s : sch_->stages) { auto op = s->op; CHECK(op.defined()); + auto tensor = op.output(0); auto compute_op = op.as(); if (compute_op && !compute_op->reduce_axis.empty()) { // For the matmul, do not perform fuse if (IsMatmul(op)) { + IterVar fused_axis; + Array need_fused_axis; + for (size_t i = 0; i < 
compute_op->axis.size() - 2; ++i) { + need_fused_axis.push_back(compute_op->axis[i]); + } + sch_[tensor].fuse(need_fused_axis, &fused_axis); return false; } // Restrictions related to the Shared memory @@ -301,8 +309,8 @@ class FuseCheck { return false; } auto mul = source[0].as(); - auto left = mul->a.as(); - auto right = mul->b.as(); + auto left = akg::common::SplitCast(mul->a, compute_op->output_dtype(0)).as(); + auto right = akg::common::SplitCast(mul->b, compute_op->output_dtype(0)).as(); if (!left || !right || left->args.size() != right->args.size()) { return false; } diff --git a/tests/common/test_run/transdata_matmul_run.py b/tests/common/test_run/transdata_matmul_run.py index 0bdba0861e1b834080a7379b4662b7ab23e690e9..82df4e3920fc6d0ca7947cc2d969e9014c57483e 100644 --- a/tests/common/test_run/transdata_matmul_run.py +++ b/tests/common/test_run/transdata_matmul_run.py @@ -88,4 +88,4 @@ def transdata_matmul_execute(shape_x, shape_y, bias, left_format, right_format, # compare result rtol, atol = get_rtol_atol("matmul", dtype) compare_result = compare_tensor(output, bench_mark, rtol=rtol, atol=atol, equal_nan=True) - return (m_x, m_y), output, bench_mark, compare_result \ No newline at end of file + return (m_x, m_y), output, bench_mark, compare_result diff --git a/tests/operators/gpu/test_all.py b/tests/operators/gpu/test_all.py index 9e98bc33f396b22b0ac5ff695ed1475f740adb1b..ceb2c5b84e756ae98110d75e8030ed3a86590932 100644 --- a/tests/operators/gpu/test_all.py +++ b/tests/operators/gpu/test_all.py @@ -46,6 +46,8 @@ from tests.operators.gpu.test_ms_greater_equal import test_ms_greater_equal from tests.operators.gpu.test_ms_reciprocal import test_ms_reciprocal from tests.operators.gpu.test_ms_reduce_max import test_ms_reduce_max from tests.operators.gpu.test_ms_reduce_min import test_ms_reduce_min +from tests.operators.gpu.test_ms_conv import test_ms_conv +from tests.operators.gpu.test_ms_conv_tensorcore import test_ms_conv_tc from 
tests.operators.gpu.test_fused_pad import test_fused_pad from tests.operators.gpu.test_fused_bn_reduce import test_fused_bn_reduce from tests.operators.gpu.test_fused_bn_update import test_fused_bn_update @@ -80,83 +82,43 @@ def addn(poly_sch, fuzz_shape=None, mind_trick_str=''): def bmm(poly_sch, fuzz_shape=None, mind_trick_str=''): + # Test for FP32 MatMul (Non-TensorCore) + test_ms_bmm((768, 768), (768, 768), 'float32', 'float32', layout1='NHDT', layout2='NHDT', layout_out='NHDT', + shape_bias=(1, ), add_bias=False, tensor_core=False, poly_sch=poly_sch) + + # Test for FP16 MatMul (Enable TensorCore) + test_ms_bmm((768, 768), (768, 768), 'float16', 'float16', layout1='NHDT', layout2='NHTD', layout_out='NHDT', + shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, + attrs={"dim": "0 0 128 16 0 1 128 16 0 2 32 8", "bind_block": "6 6", "bind_thread": "128 1"}) + test_ms_bmm((768, 768), (768, 768), 'float16', 'float32', layout1='NHDT', layout2='NHDT', layout_out='NHDT', + shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, + attrs={"dim": "0 0 128 16 0 1 128 16 0 2 32 8", "bind_block": "6 6", "bind_thread": "128 1"}) + test_ms_bmm((32, 12, 128, 128), (32, 12, 128, 64), 'float16', 'float32', layout1='NHDT', layout2='NHTD', layout_out='NHDT', + shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, + attrs={"dim": "0 0 1 1 0 1 64 16 0 2 64 16 0 3 64 8", "bind_block": "1 384", "bind_thread": "128 1"}) test_ms_bmm((768, 768), (768, 768), 'float16', 'float16', layout1='NHDT', layout2='NHDT', layout_out='NHDT', - shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, - dim="0 0 128 128 0 1 128 128 0 2 64 4", bind_block="6 6", bind_thread="32 4") + shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, + attrs={"dim": "0 0 128 16 0 1 128 16 0 2 64 8", "bind_block": "6 6", "bind_thread": "128 1"}) test_ms_bmm((768, 768), (768, 768), 'float16', 'float16', layout1='NHTD', layout2='NHTD', 
layout_out='NHDT', - shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, - dim="0 0 128 128 0 1 128 128 0 2 64 4", bind_block="6 6", bind_thread="32 8") + shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, + attrs={"dim": "0 0 128 32 0 1 128 32 0 2 64 4", "bind_block": "6 6", "bind_thread": "256 1"}) test_ms_bmm((32, 12, 128, 128), (32, 12, 128, 64), 'float16', 'float16', layout1='NHDT', layout2='NHTD', layout_out='NHDT', - shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, - dim="0 0 1 1 0 1 1 1 0 2 64 64 0 3 64 64 0 4 64 4", bind_block="12 32", bind_thread="32 4") + shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, + attrs={"dim": "0 0 1 1 0 1 64 16 0 2 64 16 0 3 64 8", "bind_block": "1 384", "bind_thread": "128 1"}) test_ms_bmm((32, 12, 128, 64), (32, 12, 128, 64), 'float16', 'float16', layout1='NHDT', layout2='NHDT', layout_out='NHDT', - shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, - dim="0 0 1 1 0 1 1 1 0 2 64 64 0 3 64 64 0 4 64 4", bind_block="12 32", bind_thread="32 4") - """ Bert Batch64 - test_ms_bmm((8192, 768), (768, 768), 'float16', 'float16', layout1='NHDT', layout2='NHDT', layout_out='NHDT', - shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, - dim="0 0 128 128 0 1 128 128 0 2 64 4", bind_block="6 64", bind_thread="32 4") - test_ms_bmm((8192, 768), (8192, 768), 'float16', 'float16', layout1='NHTD', layout2='NHTD', layout_out='NHDT', - shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, - dim="0 0 64 64 0 1 64 64 0 2 64 4", bind_block="12 12", bind_thread="32 4") - test_ms_bmm((64, 12, 128, 128), (64, 12, 128, 64), 'float16', 'float16', layout1='NHDT', layout2='NHTD', layout_out='NHDT', - shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, - dim="0 0 1 1 0 1 1 1 0 2 64 64 0 3 64 64 0 4 64 4", bind_block="12 64", bind_thread="32 4") - test_ms_bmm((64, 12, 128, 64), (64, 12, 128, 64), 
'float16', 'float16', layout1='NHDT', layout2='NHDT', layout_out='NHDT', - shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, - dim="0 0 1 1 0 1 1 1 0 2 64 64 0 1 128 128 0 2 64 4", bind_block="2 12 64", bind_thread="32 4") - test_ms_bmm((8192, 768), (8192, 3072), 'float16', 'float16', layout1='NHTD', layout2='NHTD', layout_out='NHDT', - shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, - dim="0 0 128 128 0 1 128 128 0 2 64 4", bind_block="24 6", bind_thread="32 8") - test_ms_bmm((8192, 3072), (8192, 768), 'float16', 'float16', layout1='NHTD', layout2='NHTD', layout_out='NHDT', - shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, - dim="0 0 128 128 0 1 128 128 0 2 64 4", bind_block="6 24", bind_thread="32 8") - test_ms_bmm((8192, 768), (3072, 768), 'float16', 'float16', layout1='NHDT', layout2='NHDT', layout_out='NHDT', - shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, - dim="0 0 128 128 0 1 128 128 0 2 64 4", bind_block="24 64", bind_thread="32 4") - test_ms_bmm((8192, 3072), (768, 3072), 'float16', 'float16', layout1='NHDT', layout2='NHDT', layout_out='NHDT', - shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, - dim="0 0 128 128 0 1 128 128 0 2 64 4", bind_block="6 64", bind_thread="32 8") - test_ms_bmm((8192, 3072), (3072, 768), 'float16', 'float16', layout1='NHDT', layout2='NHTD', layout_out='NHDT', - shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, - dim="0 0 128 128 0 1 128 128 0 2 64 4", bind_block="6 64", bind_thread="32 8") - test_ms_bmm((8192, 768), (768, 3072), 'float16', 'float16', layout1='NHDT', layout2='NHTD', layout_out='NHDT', - shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, - dim="0 0 128 128 0 1 128 128 0 2 64 4", bind_block="24 64", bind_thread="32 4") - """ - - """ Bert Batch32 - test_ms_bmm((4096, 768), (768, 768), 'float16', 'float16', layout1='NHDT', layout2='NHDT', layout_out='NHDT', - 
shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, - dim="0 0 128 128 0 1 128 128 0 2 64 4", bind_block="6 32", bind_thread="32 4") + shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, + attrs={"dim": "0 0 1 1 0 1 64 16 0 2 64 16 0 3 64 8", "bind_block": "1 384", "bind_thread": "128 1"}) + + # Auto tiling pass cases for scheme two + test_ms_bmm((32, 12, 128, 128), (32, 12, 128, 64), 'float16', 'float16', layout1='NHDT', layout2='NHTD', layout_out='NHDT', - shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, - dim="0 0 1 1 0 1 1 1 0 2 64 64 0 3 64 64 0 4 64 4", bind_block="12 32", bind_thread="32 4") - test_ms_bmm((4096, 768), (4096, 768), 'float16', 'float16', layout1='NHTD', layout2='NHTD', layout_out='NHDT', - shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, - dim="0 0 128 128 0 1 128 128 0 2 64 4", bind_block="6 6", bind_thread="32 4") - test_ms_bmm((32, 12, 128, 64), (32, 12, 128, 64), 'float16', 'float16', layout1='NHDT', layout2='NHDT', layout_out='NHDT', - shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, - dim="0 0 1 1 0 1 1 1 0 2 32 32 0 3 64 64 0 4 32 4", bind_block="12 32", bind_thread="32 4") - test_ms_bmm((4096, 768), (4096, 3072), 'float16', 'float16', layout1='NHTD', layout2='NHTD', layout_out='NHDT', - shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, - dim="0 0 128 128 0 1 128 128 0 2 64 4", bind_block="24 6", bind_thread="32 8") - test_ms_bmm((4096, 3072), (768, 3072), 'float16', 'float16', layout1='NHDT', layout2='NHDT', layout_out='NHDT', - shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, - dim="0 0 128 128 0 1 128 128 0 2 64 4", bind_block="6 32", bind_thread="32 8") - test_ms_bmm((4096, 3072), (4096, 768), 'float16', 'float16', layout1='NHTD', layout2='NHTD', layout_out='NHDT', - shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, - dim="0 0 128 128 0 1 128 128 0 2 64 4", 
bind_block="6 24", bind_thread="32 8") - test_ms_bmm((4096, 3072), (3072, 768), 'float16', 'float16', layout1='NHDT', layout2='NHTD', layout_out='NHDT', - shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, - dim="0 0 128 128 0 1 128 128 0 2 64 4", bind_block="6 32", bind_thread="32 4") - test_ms_bmm((4096, 768), (3072, 768), 'float16', 'float16', layout1='NHDT', layout2='NHDT', layout_out='NHDT', - shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, - dim="0 0 128 128 0 1 128 128 0 2 64 4", bind_block="24 32", bind_thread="32 4") - test_ms_bmm((4096, 768), (768, 3072), 'float16', 'float16', layout1='NHDT', layout2='NHTD', layout_out='NHDT', - shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, - dim="0 0 128 128 0 1 128 128 0 2 64 4", bind_block="24 32", bind_thread="32 4") - """ + shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch) + test_ms_bmm((256, 128), (64, 128), 'float16', 'float16', layout1='NHDT', layout2='NHDT', layout_out='NHDT', + shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch) + test_ms_bmm((128, 32), (128, 512), 'float16', 'float16', layout1='NHTD', layout2='NHTD', layout_out='NHDT', + shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch) + test_ms_bmm((128, 64), (64, 32), 'float16', 'float16', layout1='NHDT', layout2='NHTD', layout_out='NHDT', + shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch) def cast(poly_sch, fuzz_shape=None, mind_trick_str=''): test_ms_cast((32, 32, 14, 14, 16), "float16", "float32", poly_sch=poly_sch) @@ -304,6 +266,28 @@ def reduce_sum(poly_sch, fuzz_shape=None, mind_trick_str=''): keepdims=True, poly_sch=poly_sch) +def conv(poly_sch, fuzz_shape=None, mind_trick_str=''): + test_ms_conv((32, 64, 56, 56), (64, 64, 3, 3), (1, 1), + (1, 1, 1, 1), (1, 1), "float32", "float32") + + +def conv_tc(poly_sch, fuzz_shape=None, mind_trick_str=''): + test_ms_conv_tc((16, 4, 4, 16), (16, 3, 3, 16), (1, 1), 
(0, 0, 0, 0), (1, 1), "float16", "float32", + attrs={"dim": "0 0 16 16 0 1 1 1 0 2 1 1 0 3 16 16 0 4 16 8", + "bind_block": "1 4 1", "bind_thread": "32 1"}) + + test_ms_conv_tc((16, 16, 16, 16), (16, 3, 3, 16), (1, 1), (0, 0, 0, 0), (1, 1), "float16", "float32", + attrs={"dim": "0 0 16 16 0 1 2 1 0 2 2 1 0 3 16 16 0 4 16 8", + "bind_block": "1 1 1", "bind_thread": "32 1"}) + + test_ms_conv_tc((64, 6, 6, 64), (64, 3, 3, 64), (1, 1), (0, 0, 0, 0), (1, 1), "float16", "float32", + attrs={"dim": "0 0 32 16 0 1 2 1 0 2 2 1 0 3 32 16 0 4 32 8", + "bind_block": "2 4 2", "bind_thread": "32 4"}) + + test_ms_conv_tc((64, 6, 6, 64), (64, 3, 3, 64), (1, 1), (0, 0, 0, 0), (1, 1), "float16", "float32", + attrs={"dim": "0 0 32 16 0 1 1 1 0 2 1 1 0 3 32 16 0 4 32 8", + "bind_block": "2 16 2", "bind_thread": "32 4"}) + def select(poly_sch, fuzz_shape=None, mind_trick_str=''): test_ms_select((2, ), (2, 2, 2), "int8", "float16", poly_sch=poly_sch) @@ -466,6 +450,7 @@ if __name__ == '__main__': "sub": sub, "reduce_max": reduce_max, "reduce_min": reduce_min, "reduce_sum": reduce_sum, "expand_dims": expand_dims, "one_hot": one_hot, "reshape": reshape, "tile": tile, "trans_data": trans_data, + "conv": conv, "conv_tc": conv_tc, "fused_pad": fused_pad, "fused_bn_reduce": fused_bn_reduce, "fused_bn_update": fused_bn_update, diff --git a/tests/operators/gpu/test_ms_batch_matmul.py b/tests/operators/gpu/test_ms_batch_matmul.py index d25cd4f15a44439e6a91a86433823770eb96a434..723f5dc846bdd43b9a0a35bd6f39792f92a86337 100644 --- a/tests/operators/gpu/test_ms_batch_matmul.py +++ b/tests/operators/gpu/test_ms_batch_matmul.py @@ -74,17 +74,19 @@ def gen_data(shape1, shape2, dtype, out_dtype="float32", layout1="NHDT", layout2 return lhs, rhs, bias, output, expect -def test_ms_bmm(shape1, shape2, dtype, out_dtype="float32", layout1="NHDT", layout2="NHDT", layout_out="NHDT", - shape_bias=None, add_bias=False, tensor_core=True, poly_sch=False, dim="", bind_block="", bind_thread=""): +def 
test_ms_bmm(shape1, shape2, dtype, out_dtype="float32", layout1="NHDT", layout2="NHDT", layout_out="NHDT", + shape_bias=None, add_bias=False, tensor_core=True, poly_sch=False, attrs=None): op_attrs = [out_dtype, layout1, layout2, layout_out, tensor_core, add_bias] if poly_sch: - attrs = {"target": "cuda", "use_shared_memory": True, "enable_auto_fuse": False, - "dim": dim, "bind_block": bind_block, "bind_thread": bind_thread} + default_attrs = {"target": "cuda"} + if attrs: + default_attrs.update(attrs) if tensor_core: - attrs.update({"pragma_enable_tensor_core": True, "vector_load_type": "float4", "pragma_enable_matmul": True}) + default_attrs.update({"pragma_enable_matmul": True, "enable_auto_inline": False}) + print(default_attrs) mod = utils.op_build_test(batch_matmul, (shape1, shape2, shape_bias), (dtype, dtype, out_dtype), op_attrs=op_attrs, - attrs=attrs, kernel_name="batch_matmul") + attrs=default_attrs, kernel_name="batch_matmul") lhs, rhs, bias, output, expect = gen_data( shape1, shape2, dtype, out_dtype, layout1, layout2, layout_out, shape_bias, add_bias) diff --git a/tests/operators/gpu/test_ms_conv.py b/tests/operators/gpu/test_ms_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..a271ab9a153979c38b50b9977442daa4cf785911 --- /dev/null +++ b/tests/operators/gpu/test_ms_conv.py @@ -0,0 +1,82 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License +import numpy as np +from akg.ops.math_gpu.conv import conv +from tests.common.gen_random import random_gaussian +from akg.utils import kernel_exec as utils +from akg.utils.result_analysis import gpu_profiling +from akg.utils.format_transform import to_tvm_nd_array + +def has_pad(padding): + p_l, p_r, p_t, p_b = padding + return not(p_l == 0 and p_r == 0 and p_t == 0 and p_b == 0) + +def gen_data(shape_data, shape_filter, stride, padding, dilation, dtype, out_dtype): + support_list = {"float16": np.float16, "float32": np.float32} + data = random_gaussian(shape_data, miu=1, sigma=0.1).astype(support_list[dtype]) + filter_ = random_gaussian(shape_filter, miu=1, sigma=0.1).astype(support_list[dtype]) + + n, c, h, w = shape_data + c_out, c, kh, kw = shape_filter + s_h, s_w = stride + d_h, d_w = dilation + p_l, p_r, p_t, p_b = padding + + out_h = (h + p_t + p_b - kh) // s_h + 1 + out_w = (w + p_l + p_r - kw) // s_w + 1 + out_shape = (n, c_out, out_h, out_w) + shape_data_pad = (n, c, h + p_t + p_b, w + p_l + p_r) + + """ + initialization data with padding + """ + data_pad = np.zeros(shape_data_pad).astype(support_list[dtype]) + if has_pad(padding): + data_pad[:,:,p_t:p_t+h,p_l:p_l+w] = data + else: + data_pad = data + + whd = (kh - 1) * d_h + 1 + wwd = (kw - 1) * d_w + 1 + expect = np.zeros(out_shape).astype(support_list[out_dtype]) + for f in range(c_out): + for i in range(out_h): + for j in range(out_w): + expect[:,f,i,j] = np.sum(data_pad[:,:,i*s_h:i*s_h+whd:d_h,j*s_w:j*s_w+wwd:d_w]*filter_[f,:,:,:], axis=(1,2,3)) + + output = np.full(expect.shape, np.nan, out_dtype) + print("expect shape is ", np.shape(expect)) + + return data, filter_, output, expect + +def test_ms_conv(shape_data, shape_filter, stride, padding, dilation, dtype, out_dtype="float32", poly_sch=True): + op_attrs = [stride, padding, dilation] + attrs = {"target":"cuda", "enable_auto_fuse":False, 
"shared_memory_tensors":"input_1 input_2", + "dim":" 0 0 1 1 0 1 1 1 0 2 8 8 0 3 56 56 0 4 2 2", "bind_block":"64 32", "bind_thread":"28 4"} + + if poly_sch: + mod = utils.op_build_test(conv, (shape_data, shape_filter), (dtype, dtype), op_attrs=op_attrs, attrs=attrs, kernel_name="conv_auto") + + data, weight, output, expect = gen_data(shape_data, shape_filter, stride, padding, dilation, dtype, out_dtype) + args = (data, weight, output) + output = utils.mod_launch(mod, args, expect=expect) + res = np.allclose(output, expect, rtol=5e-3, atol=1.e-8) + print("Test {}".format("Pass")) + if not res: + print("Error cuda:===================================") + print(mod.imported_modules[0].get_source()) + raise AssertionError("Test fail") + + data, weight, output, expect = to_tvm_nd_array([data, weight, output, expect]) + gpu_profiling(mod, data, weight, output, expect, repeat_time=2) diff --git a/tests/operators/gpu/test_ms_conv_fusion.py b/tests/operators/gpu/test_ms_conv_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..ab8921f0065dc205c22db74061adb6e04ee28edc --- /dev/null +++ b/tests/operators/gpu/test_ms_conv_fusion.py @@ -0,0 +1,120 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License +import numpy as np +from akg.ops.math_gpu.conv_fusion import conv_fusion +from tests.common.gen_random import random_gaussian +from akg.utils import kernel_exec as utils +from akg.utils.result_analysis import gpu_profiling +from akg.utils.format_transform import to_tvm_nd_array + +def has_pad(padding): + p_l, p_r, p_t, p_b = padding + return not(p_l == 0 and p_r == 0 and p_t == 0 and p_b == 0) + +def gen_data(shape_data, shape_filter, stride, padding, dilation, dtype, out_dtype): + support_list = {"float16": np.float16, "float32": np.float32} + data = random_gaussian(shape_data, miu=1, sigma=0.1).astype(support_list[dtype]) + filter_ = random_gaussian(shape_filter, miu=1, sigma=0.1).astype(support_list[dtype]) + + n, c, h, w = shape_data + c_out, c, kh, kw = shape_filter + s_h, s_w = stride + d_h, d_w = dilation + p_l, p_r, p_t, p_b = padding + + out_h = (h + p_t + p_b - kh) // s_h + 1 + out_w = (w + p_l + p_r - kw) // s_w + 1 + out_shape = (n, c_out, out_h, out_w) + shape_data_pad = (n, c, h + p_t + p_b, w + p_l + p_r) + + """ + initialization data with padding + """ + data_pad = np.zeros(shape_data_pad).astype(support_list[dtype]) + if has_pad(padding): + data_pad[:,:,p_t:p_t+h,p_l:p_l+w] = data + else: + data_pad = data + + whd = (kh - 1) * d_h + 1 + wwd = (kw - 1) * d_w + 1 + expect = np.zeros(out_shape).astype(support_list[out_dtype]) + for f in range(c_out): + for i in range(out_h): + for j in range(out_w): + expect[:,f,i,j] = np.sum(data_pad[:,:,i*s_h:i*s_h+whd:d_h,j*s_w:j*s_w+wwd:d_w]*filter_[f,:,:,:], axis=(1,2,3)) + + output = np.full(expect.shape, np.nan, out_dtype) + print("expect shape is ", np.shape(expect)) + + return data, filter_, output, expect + +def fusion_gen_data(shape_data, shape_filter1, shape_filter2, stride1, stride2, padding1, padding2, dilation1, dilation2, dtype, out_dtype): + data, filter_, output1, expect_data = 
gen_data(shape_data, shape_filter1, stride1, padding1, dilation1, dtype, out_dtype) + support_list = {"float16": np.float16, "float32": np.float32} + filter2 = random_gaussian(shape_filter2, miu=1, sigma=0.1).astype(support_list[dtype]) + + n, c, h, w = expect_data.shape + c_out, c, kh, kw = shape_filter2 + s_h, s_w = stride2 + d_h, d_w = dilation2 + p_l, p_r, p_t, p_b = padding2 + + out_h = (h + p_t + p_b - kh) // s_h + 1 + out_w = (w + p_l + p_r - kw) // s_w + 1 + out_shape = (n, c_out, out_h, out_w) + shape_data_pad = (n, c, h + p_t + p_b, w + p_l + p_r) + + """ + initialization data with padding + """ + data_pad = np.zeros(shape_data_pad).astype(support_list[dtype]) + if has_pad(padding2): + data_pad[:,:,p_t:p_t+h,p_l:p_l+w] = expect_data + else: + data_pad = expect_data + + whd = (kh - 1) * d_h + 1 + wwd = (kw - 1) * d_w + 1 + expect = np.zeros(out_shape).astype(support_list[out_dtype]) + for f in range(c_out): + for i in range(out_h): + for j in range(out_w): + expect[:,f,i,j] = np.sum(data_pad[:,:,i*s_h:i*s_h+whd:d_h,j*s_w:j*s_w+wwd:d_w]*filter_[f,:,:,:], axis=(1,2,3)) + + output = np.full(expect.shape, np.nan, out_dtype) + print("expect shape is ", np.shape(expect)) + + return data, filter_, filter2, output, expect + +def test_ms_conv_fusion(shape_data, shape_filter1, shape_filter2, stride1, stride2, padding1, padding2, dilation1, dilation2, dtype, out_dtype="float32", poly_sch=True): + op_attrs = [stride1, stride2, padding1, padding2, dilation1, dilation2] + attrs = {"target":"cuda", "enable_auto_fuse":False, "shared_memory_tensors":"out input_1 input_2 input_3", "pragma_disable_loop_fusion": True, + "dim": "3 0 1 1 3 1 1 1 3 2 4 4 3 3 52 52 3 4 64 64"} + + if poly_sch: + mod = utils.op_build_test(conv_fusion, (shape_data, shape_filter1, shape_filter2), (dtype, dtype, dtype), op_attrs=op_attrs, attrs=attrs, kernel_name="conv_fusion_auto") + + data, weight1, weight2, output, expect = fusion_gen_data(shape_data, shape_filter1, shape_filter2, stride1, 
stride2, padding1, padding2, dilation1, dilation2, dtype, out_dtype) + args = (data, weight1, weight2, output) + output = utils.mod_launch(mod, args, expect=expect) + res = np.allclose(output, expect, rtol=5e-3, atol=1.e-8) + print("Test {}".format("Pass")) + if not res: + print("Error cuda:===================================") + print(mod.imported_modules[0].get_source()) + raise AssertionError("Test fail") + + data, weight1, weight2, output, expect = to_tvm_nd_array([data, weight1, weight2, output, expect]) + gpu_profiling(mod, data, weight1, weight2, output, expect, repeat_time=2) diff --git a/tests/operators/gpu/test_ms_conv_tensorcore.py b/tests/operators/gpu/test_ms_conv_tensorcore.py new file mode 100644 index 0000000000000000000000000000000000000000..147e423396d0be65b5d9550fb64ba1297a533225 --- /dev/null +++ b/tests/operators/gpu/test_ms_conv_tensorcore.py @@ -0,0 +1,96 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License +import numpy as np +from akg.ops.math_gpu.tensorcore_conv import conv_tc +from tests.common.gen_random import random_gaussian +from akg.utils import kernel_exec as utils +from akg.utils.result_analysis import gpu_profiling +from akg.utils.format_transform import to_tvm_nd_array + + +def has_pad(padding): + p_l, p_r, p_t, p_b = padding + return not(p_l == 0 and p_r == 0 and p_t == 0 and p_b == 0) + + +def gen_data_im2col(shape_data, shape_filter, stride, padding, dilation, dtype, out_dtype): + support_list = {"float16": np.float16, "float32": np.float32} + n, h, w, c = shape_data + out_c, kh, kw, c = shape_filter + s_h, s_w = stride + d_h, d_w = dilation + p_l, p_r, p_t, p_b = padding + out_h = (h + p_t + p_b - kh) // s_h + 1 + out_w = (w + p_l + p_r - kw) // s_w + 1 + + out_shape = (n, out_h, out_w, out_c) + shape_data_pad = (n, h + p_t + p_b, w + p_l + p_r, c) + + data = random_gaussian(shape_data, miu=1, + sigma=0.1).astype(support_list[dtype]) + filter_ = random_gaussian(shape_filter, miu=1, + sigma=0.1).astype(support_list[dtype]) + + """ + initialization data with padding + """ + data_pad = np.zeros(shape_data_pad).astype(support_list[dtype]) + if has_pad(padding): + data_pad[:, p_t:p_t+h, p_l:p_l+w, :] = data + else: + data_pad = data + + whd = (kh - 1) * d_h + 1 + wwd = (kw - 1) * d_w + 1 + expect = np.zeros(out_shape).astype(support_list[out_dtype]) + for i in range(out_h): + for j in range(out_w): + for f in range(out_c): + expect[:, i, j, f] = np.sum( + data_pad[:, i*s_h:i*s_h+whd:d_h, j*s_w:j*s_w+wwd:d_w, :].astype("float32") * + filter_[f, :, :, :].astype("float32"), + axis=(1, 2, 3) + ) + + output = np.full(expect.shape, np.nan, out_dtype) + print("expect shape is ", np.shape(expect)) + return data, filter_, output, expect + + +def test_ms_conv_tc(shape_data, shape_filter, stride, padding, dilation, dtype, out_dtype="float32", poly_sch=True, 
attrs=None): + op_attrs = [stride, padding, dilation, out_dtype] + default_attrs = {"target": "cuda", "enable_auto_fuse": False} + default_attrs.update({"pragma_enable_matmul": True, "pragma_enable_conv_tensor_core": True}) + if attrs: + default_attrs.update(attrs) + + data, weight, output, expect = gen_data_im2col( + shape_data, shape_filter, stride, padding, dilation, dtype, out_dtype) + + if poly_sch: + mod = utils.op_build_test(conv_tc, (data.shape, weight.shape), ( + dtype, dtype), op_attrs=op_attrs, attrs=default_attrs, kernel_name="conv_tc_auto") + + args = (data, weight, output) + output = utils.mod_launch(mod, args, expect=expect) + res = np.allclose(output, expect, rtol=5e-3, atol=1.e-8) + print("Test {}".format("Pass" if res else "Fail")) + if not res: + print("Error cuda:===================================") + print(mod.imported_modules[0].get_source()) + raise AssertionError("Test fail") + + data, weight, output, expect = to_tvm_nd_array( + [data, weight, output, expect]) + gpu_profiling(mod, data, weight, output, expect, repeat_time=10000) diff --git a/tests/st/composite/stitch_adapt/Fused_Mul_ReduceSum_Sub_Mul_Mul__8814173056820319925.json b/tests/st/composite/stitch_adapt/Fused_Mul_ReduceSum_Sub_Mul_Mul__8814173056820319925.json index f5e9c447b6e49e06e019be6a513592c8e558c20a..23ba9b9d7222e4ddbae6fe7675539d89232b33dd 100644 --- a/tests/st/composite/stitch_adapt/Fused_Mul_ReduceSum_Sub_Mul_Mul__8814173056820319925.json +++ b/tests/st/composite/stitch_adapt/Fused_Mul_ReduceSum_Sub_Mul_Mul__8814173056820319925.json @@ -307,4 +307,4 @@ ], "platform": "AKG", "process": "cuda" -} \ No newline at end of file +} diff --git a/tests/st/composite/stitch_adapt/Fused_Reshape_Mul_ReduceSum_Neg_Mul.json b/tests/st/composite/stitch_adapt/Fused_Reshape_Mul_ReduceSum_Neg_Mul.json index b823a1083612852f88f53f2c500eaca469d373d2..0d1e71ac4ea4b7bd7bc691363377b63690a7b93e 100644 --- a/tests/st/composite/stitch_adapt/Fused_Reshape_Mul_ReduceSum_Neg_Mul.json +++ 
b/tests/st/composite/stitch_adapt/Fused_Reshape_Mul_ReduceSum_Neg_Mul.json @@ -271,3 +271,4 @@ "platform": "AKG", "process": "cuda" } + diff --git a/tests/st/ops/gpu/stitch_cases/Fused_Cast_LessEqual_Cast_Mul_TensorAdd_ReduceMax_Sub_Exp_ReduceSum_RealDiv_Mul.json b/tests/st/ops/gpu/stitch_cases/Fused_Cast_LessEqual_Cast_Mul_TensorAdd_ReduceMax_Sub_Exp_ReduceSum_RealDiv_Mul.json index 9c9547adfba0d1ff40648d3d462dc0cc6701d60c..589802144061789cd50bbd7e3669e992670bbdaf 100644 --- a/tests/st/ops/gpu/stitch_cases/Fused_Cast_LessEqual_Cast_Mul_TensorAdd_ReduceMax_Sub_Exp_ReduceSum_RealDiv_Mul.json +++ b/tests/st/ops/gpu/stitch_cases/Fused_Cast_LessEqual_Cast_Mul_TensorAdd_ReduceMax_Sub_Exp_ReduceSum_RealDiv_Mul.json @@ -675,3 +675,4 @@ "platform": "AKG", "process": "cuda" } + diff --git a/third_party/incubator-tvm/include/tvm/ir.h b/third_party/incubator-tvm/include/tvm/ir.h index 78d5a3dae83c88cad7f3cf1b494a1b225de52c36..9e0311ca46d2041838ddafd2e3a734487a37934e 100644 --- a/third_party/incubator-tvm/include/tvm/ir.h +++ b/third_party/incubator-tvm/include/tvm/ir.h @@ -30,6 +30,10 @@ * 2021.01.13 - Add mark of TensorCore, data access vectorization. */ +/* + * 2021.05.27 - Add intrinsic var for GEMM op fusion on TensorCore. + */ + #ifndef TVM_IR_H_ #define TVM_IR_H_ @@ -1637,6 +1641,21 @@ constexpr const char* tvm_fill_fragment = "tvm_fill_fragment"; * } */ constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync"; + +/*! + * \brief akg intrinsic for tensor core fragment operator fusion. 
+ * + * void akg_fragment_elem(Var fragment_c, Expr index_c, + * Var fragment_a, Expr index_a, + * Var fragment_b, Expr index_b, + * Expr op_name) { + * akg::wmma::fragment_add / fragment_sub / fragment_mul / fragment_div( + * fragment_c[index_c], fragment_a[index_a], + * fragment_b[index_b]); + * } + */ +constexpr const char* akg_fragment_elem = "akg_fragment_elem"; + constexpr const char* tvm_cce_string_print = "tvm_cce_string_print"; } // namespace intrinsic diff --git a/third_party/incubator-tvm/python/tvm/build_module.py b/third_party/incubator-tvm/python/tvm/build_module.py index 9e619a47e3d3cbc469e00dc7852f309f0ba8f061..3a3422f5d13e6ea65fbc4d531523f58835b0a29f 100644 --- a/third_party/incubator-tvm/python/tvm/build_module.py +++ b/third_party/incubator-tvm/python/tvm/build_module.py @@ -406,7 +406,7 @@ def lower(sch, else: stmt = ir_pass.VectorizeLoop(stmt) stmt = ir_pass.InjectVirtualThread(stmt) - stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop) + stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop, False) stmt = ir_pass.StorageRewrite(stmt) stmt = ir_pass.UnrollLoop( stmt, diff --git a/third_party/incubator-tvm/src/codegen/codegen_c.cc b/third_party/incubator-tvm/src/codegen/codegen_c.cc index aa0284e4afd740a543303fda3c5d15259f0d1391..faed4e054dccf437dcf28d91e4b10a918b0f53c3 100644 --- a/third_party/incubator-tvm/src/codegen/codegen_c.cc +++ b/third_party/incubator-tvm/src/codegen/codegen_c.cc @@ -704,13 +704,17 @@ void CodeGenC::VisitStmt_(const Store* op) { if (pos_call != std::string::npos) { value = value.substr(0, pos_call) + ")" + value.substr(pos_call); } - ref = ref.replace(ref.find(vectorize_var_), vectorize_var_.length(), "0"); + while (ref.find(vectorize_var_) != std::string::npos) { + ref = ref.replace(ref.find(vectorize_var_), vectorize_var_.length(), "0"); + } auto scale = std::to_string(vectorize_scale_.as()->value); auto pos_end = ref.find("]"); if (pos_end != std::string::npos) { ref = ref.substr(0, 
pos_end) + " / " + scale + ref.substr(pos_end); } - value = value.replace(value.find(vectorize_var_), vectorize_var_.length(), "0"); + while (value.find(vectorize_var_) != std::string::npos) { + value = value.replace(value.find(vectorize_var_), vectorize_var_.length(), "0"); + } pos_end = value.find("]"); if (pos_end != std::string::npos) { value = value.substr(0, pos_end) + " / " + scale + value.substr(pos_end); diff --git a/third_party/incubator-tvm/src/codegen/codegen_cuda.cc b/third_party/incubator-tvm/src/codegen/codegen_cuda.cc index a7957d1031382e1b4254874455bd3fd0d9806a15..6cd48373e89653298b7e180dae8e9d30c118cc81 100644 --- a/third_party/incubator-tvm/src/codegen/codegen_cuda.cc +++ b/third_party/incubator-tvm/src/codegen/codegen_cuda.cc @@ -54,6 +54,11 @@ * Print offset shared memory when use total shared_memory of VisitStmt_(const Allocate* op) */ +/* + * 2021.3.22 + * Refactor the function Simplify_name. + */ + /* * 2021.5.17 * Modify the functions: @@ -61,6 +66,12 @@ * for the reduce sum operator */ +/* + * 2021.5.27 + * Add function for GEMM op fusion on TensorCore. 
+ */ + +#include "codegen_cuda.h" #include #include @@ -68,6 +79,7 @@ #include #include #include +#include "common/common_util.h" #include #include "literal/cuda_half_t.h" #include "codegen_cuda.h" @@ -83,14 +95,6 @@ CodeGenCUDA::CodeGenCUDA() { restrict_keyword_ = "__restrict__"; } -std::string CodeGenCUDA::Simplify_name(std::string input) { - auto pos = input.find("_local"); - if (pos != std::string::npos) { - return input.substr(0, pos); - } - return input; -} - void CodeGenCUDA::Init(bool output_ssa) { CodeGenC::Init(output_ssa); vid_global_barrier_state_ = GetUniqueName(runtime::symbol::tvm_global_barrier_state); @@ -155,7 +159,7 @@ std::string CodeGenCUDA::Finish() { if (need_mma_h_) { if (wmma_scope == "akg") { - decl_stream << "#include \"akg_mma_lib/m16n16k4.hpp\"\n"; + decl_stream << "#include \"akg_mma_lib/wmma.hpp\"\n"; } else{ decl_stream << "#include \n"; } @@ -455,7 +459,7 @@ void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) { Expr new_args = op->args[4]; Expr warp_tile; auto var_node = op->args[0].as(); - auto it_matrix = matrix_abc.find(Simplify_name(var_node->name_hint)); + auto it_matrix = matrix_abc.find(akg::common::GetGlobalName(var_node->name_hint)); if (it_matrix != matrix_abc.end()) { if (it_matrix->second == "matrix_a") { if (op->args[7].as()->value == "row_major") { @@ -469,6 +473,12 @@ void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) { } else { warp_tile = warp_tile_n; } + } else if (it_matrix->second == "accumulator") { + if (op->args[7].as()->value == "row_major") { + warp_tile = warp_tile_n; + } else { + LOG(FATAL) << "Not support matrix to load fragment accumulator!"; + } } else { LOG(FATAL) << "Not support matrix to load !"; } @@ -486,6 +496,9 @@ void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) { this->PrintExpr(op->args[5], os); os << ", "; this->PrintExpr(op->args[6], os); + if (it_matrix != matrix_abc.end() && it_matrix->second == "accumulator") { + os << ", nvcuda::wmma::mem_row_major"; + 
} os << ")"; } else if (op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) { need_mma_h_ = true; @@ -538,6 +551,25 @@ void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) { } os << "]" << ((i < 3) ? ", ": ")"); } + } else if (op->is_intrinsic(intrinsic::akg_fragment_elem)) { + need_mma_h_ = true; + os << wmma_scope << "::wmma::fragment_" << op->args[op->args.size() - 1].as()->value << "("; + if (op->args.size() == 7) { + for (int i = 0; i < 3; ++i) { + this->PrintExpr(op->args[i * 2], os); + os << "["; + this->PrintExpr(op->args[i * 2 + 1], os); + os << "]" << ((i < 2) ? ", " : ")"); + } + } else if (op->args.size() == 6) { + for (int i = 0; i < 2; ++i) { + this->PrintExpr(op->args[i * 2], os); + os << "["; + this->PrintExpr(op->args[i * 2 + 1], os); + os << "]" << ((i < 1) ? ", " : ""); + } + os << ", " << op->args[4] << ")"; + } } else if ((op->call_type == Call::Extern) || (op->call_type == Call::PureExtern)) { if (op->name == "&") { CHECK_EQ(op->args.size(), 1); @@ -796,7 +828,7 @@ void CodeGenCUDA::VisitStmt_(const Allocate* op) { if (pos != std::string::npos) { matrix_scope = matrix_scope.substr(pos + 1); } - matrix_abc.insert(std::make_pair(Simplify_name(vid), matrix_scope)); + matrix_abc.insert(std::make_pair(akg::common::GetGlobalName(vid), matrix_scope)); if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { CHECK(op->type == Float(16) || op->type == Int(8) || op->type == UInt(8)) << "Matrix_a and matrix_b only support half or char or unsigned char type for now"; diff --git a/third_party/incubator-tvm/src/codegen/codegen_cuda.h b/third_party/incubator-tvm/src/codegen/codegen_cuda.h index 05d1a094fad94755f63ef4bffe7998556bc1e724..b9688cb7bbd466b3d34f6701348a3970582f5f95 100644 --- a/third_party/incubator-tvm/src/codegen/codegen_cuda.h +++ b/third_party/incubator-tvm/src/codegen/codegen_cuda.h @@ -43,6 +43,11 @@ */ +/* + * 2021.3.22 + * Refactor the function Simplify_name. 
+ */ + /* * 2021.05.17 * Add const akg_reduce::AkgKahanAccumulation for reduce @@ -72,7 +77,6 @@ constexpr auto PARIS_REDUCE_LIB = "paris"; class CodeGenCUDA final : public CodeGenC { public: CodeGenCUDA(); - std::string Simplify_name(std::string input); void Init(bool output_ssa); void AddFunction(LoweredFunc f); std::string Finish(); @@ -162,7 +166,7 @@ class CodeGenCUDA final : public CodeGenC { Expr matrix_b_major = StringImm::make("col_major"); std::unordered_map matrix_abc; // indicate which TensorCore interface - std::string wmma_scope; + std::string wmma_scope = "nvcuda"; std::unordered_map sm_offsets; std::unordered_map fragment_shapes; diff --git a/third_party/incubator-tvm/src/pass/inject_double_buffer.cc b/third_party/incubator-tvm/src/pass/inject_double_buffer.cc index a280c14648c55e05627caaaa5ff0cc27e7063b1c..61954489bae291aebc27c6a0116539ed7aedfe32 100644 --- a/third_party/incubator-tvm/src/pass/inject_double_buffer.cc +++ b/third_party/incubator-tvm/src/pass/inject_double_buffer.cc @@ -67,6 +67,51 @@ class DoubleBufferDetector : public IRVisitor { std::unordered_set touched_; }; +class StripSyncAndAllocs : public IRMutator { + public: + explicit StripSyncAndAllocs(bool use_double_shared) + : use_double_buffer_(use_double_shared) {} + + Stmt Mutate_(const Block* op, const Stmt& s) final { + Stmt first = Mutate(op->first); + Stmt rest = Mutate(op->rest); + if (const auto attr = first.as()) { + if (attr->attr_key == "delete_this_sync" + || (use_double_buffer_ && attr->attr_key == "delete_this_sync_for_db")) { + return rest; + } + } + return Block::make(first, rest); + } + + Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { + if (op->attr_key == attr::storage_scope) { + if (auto alloc = op->body.as()) { + int fragment_size_ = (arith::ComputeReduce(alloc->extents, Expr()) + * alloc->type.lanes() * alloc->type.bytes()).as()->value; + auto it = fragment_info_.find(alloc->buffer_var.get()); + if ( it == fragment_info_.end() + || ( it != 
fragment_info_.end() && (fragment_size_ > (it->second)) ) ) { + fragment_info_[alloc->buffer_var.as()] = fragment_size_; + fragment_allocs_.emplace_back( AttrStmt::make(op->node, op->attr_key, op->value, Evaluate::make(0)) ); + fragment_allocs_.emplace_back( + Allocate::make(alloc->buffer_var, alloc->type, alloc->extents, alloc->condition, Evaluate::make(0)) + ); + } + return Mutate(alloc->body); + } + } + return IRMutator::Mutate_(op, s); + } + + std::vector GetFragmentAllocs() { return fragment_allocs_; } + + private: + bool use_double_buffer_{false}; + std::vector fragment_allocs_; + std::unordered_map fragment_info_; +}; + class StripDoubleBufferWrite : public IRMutator { public: Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { @@ -78,11 +123,6 @@ class StripDoubleBufferWrite : public IRMutator { } }; -class StripTailAlloc : public IRMutator { - public: - Stmt Mutate_(const Allocate* op, const Stmt& s) final { return Mutate(op->body); } -}; - class StripTransferWriteIndex : public IRMutator { public: Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { @@ -192,8 +232,8 @@ class ThreadGroupInjector : public IRMutator { class DoubleBufferInjector : public IRMutator { public: - explicit DoubleBufferInjector(int split_loop, bool use_transfer_buffer) - : split_loop_(split_loop), use_transfer_buffer_(use_transfer_buffer) {} + explicit DoubleBufferInjector(int split_loop, bool use_double_shared) + : split_loop_(split_loop), use_double_buffer_(use_double_shared) {} Stmt Inject(const Stmt& stmt) { DoubleBufferDetector detector; @@ -212,14 +252,11 @@ class DoubleBufferInjector : public IRMutator { if (it != dbuffer_info_.end()) { it->second.scope = op->value.as()->value; return Mutate(op->body); - } else { - return IRMutator::Mutate_(op, s); } } else if (op->attr_key == attr::double_buffer_scope) { return MakeProducer(op, s); - } else { - return IRMutator::Mutate_(op, s); } + return IRMutator::Mutate_(op, s); } Stmt Mutate_(const Allocate* op, const Stmt& s) 
final { @@ -232,32 +269,32 @@ class DoubleBufferInjector : public IRMutator { it->second.transfer_buffer = Var(op->buffer_var->name_hint + "_transfer"); it->second.transfer_buffer_extents.push_back(make_const(op->extents[0].type(), 1)); Stmt stmt = IRMutator::Mutate_(op, s); - op = stmt.as(); - CHECK(it->second.loop != nullptr); - auto& alloc_nest = loop_allocs_[it->second.loop]; - alloc_nest.emplace_back(AttrStmt::make(op->buffer_var, attr::storage_scope, - StringImm::make(it->second.scope), Evaluate::make(0))); - if (!use_transfer_buffer_) { - Array new_extents{make_const(op->extents[0].type(), 2)}; - for (Expr e : op->extents) { - new_extents.push_back(e); + if (const auto alloc = stmt.as()) { + CHECK(it->second.loop != nullptr); + auto& alloc_nest = loop_allocs_[it->second.loop]; + alloc_nest.emplace_back(AttrStmt::make(alloc->buffer_var, attr::storage_scope, + StringImm::make(it->second.scope), Evaluate::make(0))); + if (use_double_buffer_) { + Array new_extents{make_const(alloc->extents[0].type(), 2)}; + for (Expr e : alloc->extents) { + new_extents.push_back(e); + } + alloc_nest.emplace_back(Allocate::make(alloc->buffer_var, alloc->type, new_extents, alloc->condition, + Evaluate::make(0))); + } else { + alloc_nest.emplace_back(Allocate::make(alloc->buffer_var, alloc->type, alloc->extents, alloc->condition, + Evaluate::make(0))); } - alloc_nest.emplace_back(Allocate::make(op->buffer_var, op->type, new_extents, op->condition, - Evaluate::make(0))); - } else { - alloc_nest.emplace_back(Allocate::make(op->buffer_var, op->type, op->extents, op->condition, - Evaluate::make(0))); alloc_nest.emplace_back( AttrStmt::make(it->second.transfer_buffer, air::ir::attr::storage_scope, - StringImm::make(it->second.transfer_buffer_scope), Evaluate::make(0))); + StringImm::make(it->second.transfer_buffer_scope), Evaluate::make(0))); alloc_nest.emplace_back(Allocate::make(it->second.transfer_buffer, it->second.type, - it->second.transfer_buffer_extents, - it->second.condition, 
Evaluate::make(0))); + it->second.transfer_buffer_extents, + it->second.condition, Evaluate::make(0))); + return alloc->body; } - return op->body; - } else { - return IRMutator::Mutate_(op, s); } + return IRMutator::Mutate_(op, s); } Stmt Mutate_(const For* op, const Stmt& s) final { @@ -266,14 +303,13 @@ class DoubleBufferInjector : public IRMutator { transfer_loop_nest_.push_back(op); } Stmt stmt = IRMutator::Mutate_(op, s); - if (use_transfer_buffer_) { - const For* orig_loop = stmt.as(); - auto iter = loop_transfer_.find(op); - if (iter != loop_transfer_.end()) { - stmt = - For::make(orig_loop->loop_var, orig_loop->min, orig_loop->extent, orig_loop->for_type, - orig_loop->device_api, Block::make(orig_loop->body, MergeSeq(iter->second))); - } + const For* orig_loop = stmt.as(); + auto iter = loop_transfer_.find(op); + std::vector fragment_allocs; + if (iter != loop_transfer_.end()) { + stmt = + For::make(orig_loop->loop_var, orig_loop->min, orig_loop->extent, orig_loop->for_type, + orig_loop->device_api, Block::make(orig_loop->body, MergeSeq(iter->second))); } auto it = loop_pre_.find(op); if (it != loop_pre_.end()) { @@ -291,25 +327,24 @@ class DoubleBufferInjector : public IRMutator { Var outer_var(old_loop->loop_var->name_hint + ".outer", old_loop->loop_var.type()); std::unordered_map vmap; std::vector loop_seq; + StripSyncAndAllocs body_remover(use_double_buffer_); + Stmt old_loop_body = body_remover.Mutate(old_loop->body); + fragment_allocs = body_remover.GetFragmentAllocs(); for (int32_t i = 0; i < split_loop_; ++i) { vmap[old_loop->loop_var.get()] = outer_var * factor + make_const(factor.type(), i); - loop_seq.emplace_back(Substitute(old_loop->body, vmap)); - } - Stmt loop; - if (use_transfer_buffer_) { - // Add syncthreads at the end of main loop - loop = For::make(outer_var, zero, outer_ext, old_loop->for_type, old_loop->device_api, - Block::make(MergeSeq(loop_seq), Evaluate::make( - Call::make(Int(32), "tvm_storage_sync", {StringImm::make("shared")}, 
Call::Intrinsic) - )) - ); - } else { - loop = For::make(outer_var, zero, outer_ext, old_loop->for_type, old_loop->device_api, MergeSeq(loop_seq)); + loop_seq.emplace_back(Substitute(old_loop_body, vmap)); } + // Add syncthreads at the end of main loop + Stmt loop = For::make(outer_var, zero, outer_ext, old_loop->for_type, old_loop->device_api, + Block::make(MergeSeq(loop_seq), Evaluate::make( + Call::make(Int(32), "tvm_storage_sync", {StringImm::make("shared")}, Call::Intrinsic) + )) + ); // tail std::vector tail_seq; - Stmt tail_body = StripDoubleBufferWrite().Mutate(old_loop->body); - tail_body = StripTailAlloc().Mutate(tail_body); + StripSyncAndAllocs tail_remover(false); + old_loop_body = tail_remover.Mutate(old_loop->body); + Stmt tail_body = StripDoubleBufferWrite().Mutate(old_loop_body); for (int32_t i = 0; i < split_loop_; ++i) { Expr idx = tail_base + make_const(tail_base.type(), i); vmap[old_loop->loop_var.get()] = idx; @@ -318,13 +353,14 @@ class DoubleBufferInjector : public IRMutator { } stmt = Block::make(loop, MergeSeq(tail_seq)); } + // Move fragment allocation statements to the top of the current loop + Stmt loop_pre_stmt = MergeNest(fragment_allocs, MergeSeq(it->second)); // Add syncthreads after the first prefetch - stmt = Block::make( - Block::make(MergeSeq(it->second), Evaluate::make( - Call::make(Int(32), "tvm_storage_sync", {StringImm::make("shared")}, Call::Intrinsic) - )), - stmt - ); + loop_pre_stmt = Block::make( + loop_pre_stmt, + Evaluate::make(Call::make(Int(32), "tvm_storage_sync", {StringImm::make("shared")}, Call::Intrinsic)) + ); + stmt = Block::make(loop_pre_stmt, stmt); } it = loop_allocs_.find(op); if (it != loop_allocs_.end()) { @@ -336,16 +372,11 @@ class DoubleBufferInjector : public IRMutator { Stmt Mutate_(const Store* op, const Stmt& s) final { Stmt stmt = IRMutator::Mutate_(op, s); - op = stmt.as(); - auto it = dbuffer_info_.find(op->buffer_var.get()); - if (it != dbuffer_info_.end()) { - StorageEntry& e = it->second; - 
CHECK(in_double_buffer_scope_); - if (!use_transfer_buffer_) { - CHECK(e.stride.defined()); - return Store::make(op->buffer_var, op->value, e.switch_write_var * e.stride + op->index, - op->predicate); - } else { + if (const auto store = stmt.as()) { + auto it = dbuffer_info_.find(store->buffer_var.get()); + if (it != dbuffer_info_.end()) { + StorageEntry& e = it->second; + CHECK(in_double_buffer_scope_); Expr transfer_index = make_const(e.loop->loop_var.type(), 0); Expr transfer_extent = make_const(e.transfer_buffer_extents[0].type(), 1); for (unsigned i = 0; i < transfer_loop_nest_.size(); i++) { @@ -359,14 +390,20 @@ class DoubleBufferInjector : public IRMutator { transfer_extent *= transfer_loop_nest_[i]->extent - transfer_loop_nest_[i]->min; } e.transfer_buffer_extents.push_back(transfer_extent); - air::DataType transfer_type = op->value.as()->type; + air::DataType transfer_type = store->value.as()->type; if (e.type != transfer_type) { e.type = transfer_type; } Stmt transfer_store = - Store::make(e.transfer_buffer, op->value, transfer_index, op->predicate); - transfer_store = - AttrStmt::make(op->buffer_var, TRANSFER_WRITE_INDEX, op->index, transfer_store); + Store::make(e.transfer_buffer, store->value, transfer_index, store->predicate); + if (use_double_buffer_) { + CHECK(e.stride.defined()); + transfer_store = + AttrStmt::make(store->buffer_var, TRANSFER_WRITE_INDEX, e.switch_write_var * e.stride + store->index, transfer_store); + } else { + transfer_store = + AttrStmt::make(store->buffer_var, TRANSFER_WRITE_INDEX, store->index, transfer_store); + } return transfer_store; } } @@ -375,31 +412,16 @@ class DoubleBufferInjector : public IRMutator { Expr Mutate_(const Load* op, const Expr& e) final { Expr expr = IRMutator::Mutate_(op, e); - op = expr.as(); - auto it = dbuffer_info_.find(op->buffer_var.get()); - if ((!use_transfer_buffer_ && it != dbuffer_info_.end())) { - const StorageEntry& e = it->second; - CHECK(e.stride.defined()); - 
CHECK(e.switch_read_var.defined()); - return Load::make(op->type, op->buffer_var, e.switch_read_var * e.stride + op->index, - op->predicate); - } else { - return expr; - } - } - - Stmt Mutate_(const Block* op, const Stmt& s) final { - Stmt first = Mutate(op->first); - Stmt rest = Mutate(op->rest); - if (const auto attr = op->first.as()) { - if (attr->attr_key == "delete_this_sync") { - return rest; + if (const auto load = expr.as()) { + auto it = dbuffer_info_.find(load->buffer_var.get()); + if ((use_double_buffer_ && it != dbuffer_info_.end())) { + const StorageEntry& e = it->second; + CHECK(e.stride.defined()); + CHECK(e.switch_read_var.defined()); + return Load::make(load->type, load->buffer_var, e.switch_read_var * e.stride + load->index, load->predicate); } } - if (first.same_as(op->first) && rest.same_as(op->rest)) { - return s; - } - return Block::make(first, rest); + return expr; } Expr Mutate_(const Variable* op, const Expr& e) final { @@ -430,21 +452,19 @@ class DoubleBufferInjector : public IRMutator { std::unordered_map vmap; vmap[e.loop->loop_var.get()] = zero; Stmt transfer_stmt; - if (!use_transfer_buffer_) { + if (use_double_buffer_) { vmap[e.switch_write_var.get()] = zero; - } else { - transfer_stmt = TransferBufferInjector().Mutate(body); - body = StripTransferWriteIndex().Mutate(body); - transfer_stmt = Substitute(transfer_stmt, vmap); } + transfer_stmt = TransferBufferInjector().Mutate(body); + body = StripTransferWriteIndex().Mutate(body); loop_pre_[e.loop].emplace_back(Substitute(body, vmap)); + loop_pre_[e.loop].emplace_back(Substitute(transfer_stmt, vmap)); vmap[e.loop->loop_var.get()] = loop_shift; - if (!use_transfer_buffer_) { + if (use_double_buffer_) { vmap[e.switch_write_var.get()] = indexmod(loop_shift, two); - } else { - loop_pre_[e.loop].emplace_back(transfer_stmt); - } + } body = Substitute(body, vmap); + transfer_stmt = Substitute(transfer_stmt, vmap); body = AttrStmt::make(buffer, air::ir::attr::double_buffer_write, 1, body); 
body = IfThenElse::make(loop_shift < e.loop->extent, body); transfer_stmt = AttrStmt::make(e.transfer_buffer, attr::double_buffer_write, 1, transfer_stmt); @@ -479,7 +499,7 @@ class DoubleBufferInjector : public IRMutator { // Whether split loop int32_t split_loop_; // Whether use transfer buffer to replace the second shared buffer - bool use_transfer_buffer_{false}; + bool use_double_buffer_{false}; // Whether we are inside double buffer scope. bool in_double_buffer_scope_{false}; // The current loop nest @@ -496,8 +516,8 @@ class DoubleBufferInjector : public IRMutator { std::vector transfer_loop_nest_; }; -Stmt InjectDoubleBuffer(Stmt stmt, int split_loop, bool use_transfer_buffer) { - Stmt new_stmt = DoubleBufferInjector(split_loop, use_transfer_buffer).Inject(stmt); +Stmt InjectDoubleBuffer(Stmt stmt, int split_loop, bool use_double_shared) { + Stmt new_stmt = DoubleBufferInjector(split_loop, use_double_shared).Inject(stmt); new_stmt = ThreadGroupInjector().Inject(new_stmt); return new_stmt; } diff --git a/third_party/incubator-tvm/src/pass/storage_rewrite.cc b/third_party/incubator-tvm/src/pass/storage_rewrite.cc index cae44d987734b410e920b3cf4be67a4761e24237..553cabcadc2101286d15ad11b81bd7924b736367 100644 --- a/third_party/incubator-tvm/src/pass/storage_rewrite.cc +++ b/third_party/incubator-tvm/src/pass/storage_rewrite.cc @@ -22,6 +22,12 @@ * \brief Memory access pattern analysis and optimization. * Re-write data access to enable memory sharing when possible. */ + +/*! + * 2021.06.07 + * Add function for Merging shared memory of the same life cycle. 
+ */ + #include #include #include @@ -79,6 +85,11 @@ class LinearAccessPatternFinder final : public IRVisitor { }; void Visit_(const Allocate* op) final { + auto it_shared = tensor_scope_.find(op->buffer_var->name_hint); + if (it_shared != tensor_scope_.end() && it_shared->second == "shared") { + shared_tensor_.emplace(op->buffer_var->name_hint); + tensor_bounds_[op->buffer_var->name_hint] = op->extents; + } size_t level = scope_.size(); const Variable* buf = op->buffer_var.get(); auto it = alloc_info_.find(buf); @@ -179,6 +190,10 @@ class LinearAccessPatternFinder final : public IRVisitor { const Variable* buf = op->node.as(); alloc_info_[buf].storage_scope = StorageScope::make(op->value.as()->value); + tensor_scope_[buf->name_hint] = op->value.as()->value; + IRVisitor::Visit_(op); + } else if (op->attr_key == air::ir::attr::pragma_tensor_core) { + tensor_core_on_ = true; IRVisitor::Visit_(op); } else { IRVisitor::Visit_(op); @@ -196,6 +211,76 @@ class LinearAccessPatternFinder final : public IRVisitor { VisitNewScope(op); } + inline bool Matched() { + if (!tensor_core_on_) { + return false; + } + LivenessAnalysis(linear_seq_); + return !trans_tensor_.empty(); + } + + void LivenessAnalysis(const std::vector &seq) { + std::unordered_set touched; + for (size_t i = seq.size(); i > 0; --i) { + const StmtEntry &s = seq[i - 1]; + for (const Variable *buffer : s.touched) { + if (!touched.count(buffer)) { + touched.insert(buffer); + event_map_[s.stmt].push_back(buffer); + } + } + } + touched.clear(); + int level = 0; + for (auto it = event_map_.begin(); it != event_map_.end(); ++it) { + if (it->second.size() > 1) { + for (size_t i = 0; i < it->second.size(); ++i) { + if (shared_tensor_.find(it->second[i]->name_hint) != shared_tensor_.end()) { + same_var_[level].push_back(GetRef(it->second[i])); + same_live_[level].push_back(it->second[i]->name_hint); + auto itstr = same_live_.find(level); + auto itvar = same_var_.find(level); + CHECK(itstr != same_live_.end()); + 
CHECK(itvar != same_var_.end()); + trans_tensor_[it->second[i]->name_hint] = itstr->second[0]; + auto bounds = tensor_bounds_.find(itstr->second[0]); + CHECK(bounds != tensor_bounds_.end()); + Expr offset = IntImm::make(Int(32), 1); + for (size_t bd = 0; bd < bounds->second.size(); ++bd) { + offset = Mul::make(offset, bounds->second[bd]); + } + shared_offset_[it->second[i]->name_hint] = offset; + shared_offset_[itstr->second[0]] = Expr(0); + trans_var_[it->second[i]->name_hint] = itvar->second[0]; + + auto itsrc = tensor_bounds_.find(itstr->second[0]); + CHECK(itsrc != tensor_bounds_.end()); + Expr tmpsrc = itsrc->second[0]; + for (size_t bd = 1; bd < itsrc->second.size(); ++bd) { + tmpsrc = tmpsrc * itsrc->second[bd]; + } + Array new_bounds; + if (it->second[i]->name_hint == itstr->second[0]) { + new_bounds.push_back(tmpsrc); + } else { + auto itdst = tensor_bounds_.find(it->second[i]->name_hint); + CHECK(itdst != tensor_bounds_.end()); + Expr tmpdst = itdst->second[0]; + for (size_t bd = 1; bd < itdst->second.size(); ++bd) { + tmpdst = tmpdst * itdst->second[bd]; + } + new_bounds.push_back(tmpsrc + tmpdst); + } + tensor_bounds_[itstr->second[0]] = new_bounds; + } + } + level++; + } + } + } + + friend class SharedMemRewriter; + // linearized access sequence. std::vector linear_seq_; // The storage scope of each buffer @@ -206,6 +291,104 @@ class LinearAccessPatternFinder final : public IRVisitor { bool in_thread_env_{false}; // The scope stack. 
std::vector scope_; + bool tensor_core_on_{false}; + std::unordered_map> event_map_; + std::unordered_map> same_live_; + std::unordered_map> same_var_; + std::unordered_map trans_tensor_; + std::unordered_map trans_var_; + std::unordered_map> tensor_bounds_; + std::unordered_map shared_offset_; + std::unordered_map tensor_scope_; + std::unordered_set shared_tensor_; +}; + +class SharedMemRewriter : public IRMutator { + public: + explicit SharedMemRewriter(const LinearAccessPatternFinder &finder) + : trans_tensor_(finder.trans_tensor_), + trans_var_(finder.trans_var_), + tensor_bounds_(finder.tensor_bounds_), + shared_offset_(finder.shared_offset_) {} + + Stmt Mutate_(const Store *op, const Stmt &s) final { + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as(); + if (op == nullptr) { + return stmt; + } + auto it = trans_tensor_.find(op->buffer_var->name_hint); + if (it != trans_tensor_.end() && it->first != it->second) { + auto itoffset = shared_offset_.find(it->first); + CHECK(itoffset != shared_offset_.end()); + Expr new_index = op->index + itoffset->second; + auto itvar = trans_var_.find(op->buffer_var->name_hint); + CHECK(itvar != trans_var_.end()); + return Store::make(itvar->second, op->value, new_index, op->predicate); + } + return stmt; + } + + Expr Mutate_(const Load *op, const Expr &e) final { + Expr expr = IRMutator::Mutate_(op, e); + op = expr.as(); + auto it = trans_tensor_.find(op->buffer_var->name_hint); + if (it != trans_tensor_.end() && it->first != it->second) { + auto itoffset = shared_offset_.find(it->first); + CHECK(itoffset != shared_offset_.end()); + Expr new_index = op->index + itoffset->second; + auto itvar = trans_var_.find(op->buffer_var->name_hint); + CHECK(itvar != trans_var_.end()); + return Load::make(op->type, itvar->second, new_index, op->predicate); + } + return expr; + } + + Expr Mutate_(const Variable *op, const Expr &e) final { + auto it = trans_tensor_.find(op->name_hint); + if (it != trans_tensor_.end() && it->first != 
it->second) { + return Variable::make(op->type, it->second); + } + return e; + } + + Stmt Mutate_(const AttrStmt *op, const Stmt &s) final { + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as(); + if (op != nullptr && op->attr_key == air::ir::attr::storage_scope) { + auto it = trans_tensor_.find(op->node.as()->name_hint); + if (it != trans_tensor_.end() && it->first != it->second) { + return op->body; + } + } + return stmt; + } + + Stmt Mutate_(const Allocate *op, const Stmt &s) final { + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as(); + if (op == nullptr) { + return stmt; + } + auto it = trans_tensor_.find(op->buffer_var->name_hint); + if (it != trans_tensor_.end()) { + if (it->first != it->second) { + return op->body; + } else { + auto bd = tensor_bounds_.find(it->second); + CHECK(bd != tensor_bounds_.end()); + return Allocate::make(op->buffer_var, op->type, bd->second, op->condition, + op->body, op->new_expr, op->free_function); + } + } + return stmt; + } + + private: + std::unordered_map trans_tensor_; + std::unordered_map trans_var_; + std::unordered_map> tensor_bounds_; + std::unordered_map shared_offset_; }; // Verify if the statement can be run safely via inplace fashion @@ -1010,6 +1193,11 @@ LoweredFunc PointerValueTypeRewrite(LoweredFunc f) { } Stmt StorageRewrite(Stmt stmt) { + LinearAccessPatternFinder finder; + finder.Visit(stmt); + if (finder.Matched()) { + stmt = SharedMemRewriter(finder).Mutate(stmt); + } stmt = StoragePlanRewriter().Rewrite(stmt, true); return VectorAllocRewriter().Mutate(stmt); } diff --git a/third_party/incubator-tvm/src/pass/unroll_loop.cc b/third_party/incubator-tvm/src/pass/unroll_loop.cc index 5829ada448a7a58c4adfeb420d8eb9b75f42acc6..c475c2280378f6e5a9ea6a03951921b76b67693e 100644 --- a/third_party/incubator-tvm/src/pass/unroll_loop.cc +++ b/third_party/incubator-tvm/src/pass/unroll_loop.cc @@ -62,6 +62,9 @@ class LoopUnroller : public IRMutator { } else if (op->attr_key == attr::promote_vectorization) { 
enable_vectorize_ = true; return IRMutator::Mutate_(op, stmt); + } else if (op->attr_key == "no_unroll") { + no_unroll_ = true; + return IRMutator::Mutate_(op, stmt); } else if (op->attr_key == "pragma_auto_unroll_max_step") { int value = 0; CHECK(arith::GetConstInt(op->value, &value)); @@ -89,6 +92,12 @@ class LoopUnroller : public IRMutator { op = stmt.as(); return For::make(op->loop_var, op->min, op->extent, ForType::Vectorized, op->device_api, op->body); } + + if (no_unroll_) { + no_unroll_ = false; + return IRMutator::Mutate_(op, s); + } + Stmt stmt = IRMutator::Mutate_(op, s); op = stmt.as(); int value = GetExtent(op); @@ -215,6 +224,7 @@ class LoopUnroller : public IRMutator { int step_count_{0}; // Flag for enable vectorization bool enable_vectorize_{false}; + bool no_unroll_{false}; };