Mainstream AI computing frameworks such as MindSpore provide users with operators that are typically defined for understandability and ease of use. Each operator carries a different amount of computation and varies in computational complexity. From the hardware-execution point of view, however, this natural, user-oriented division of operator computation is inefficient and does not fully utilize the computational power of the hardware, which is mainly reflected in the following aspects:
In terms of AI framework design, the current industry mainstream adopts a layered implementation with separate graph and operator layers. The graph layer is responsible for fusing or regrouping the computational graph, and the operator layer is responsible for compiling the fused or regrouped operators into high-performance executable operators. The graph layer is usually processed and optimized with a Tensor-based high-level IR, while the operator layer is analyzed and optimized with a low-level IR based on computational instructions. This artificial separation significantly increases the difficulty of collaborative optimization across the graph and operator layers.
Over the past few years, MindSpore has adopted graph-kernel fusion to better address this problem. Typical networks in categories such as NLP and recommendation show significant training-speed gains after enabling graph-kernel fusion. One main reason is that these networks contain a large number of small-operator combinations, which offer more opportunities for fusion optimization.
The overall architecture of graph-kernel fusion is shown in the figure below. The main idea in the graph layer is to expand the composite operators, then perform cross-boundary aggregation and optimization, and finally perform kernel operator splitting. The main steps include:
The optimized computational graph is passed to MindSpore AKG as a subgraph for further back-end optimization and target code generation.
By following these steps, we can obtain two aspects of performance gains:
As mentioned earlier, in scenarios such as HPC and deep neural network training, graph-kernel fusion can bring substantial performance improvements. However, as the graph-kernel fusion capability grows, developing the fused operators themselves becomes the bottleneck for further improvement. Automatic generation of fused operators removes the high programming threshold of developing DSA-based fused operators, allowing programmers to focus on the implementation logic of operators rather than on back-end optimization, which greatly improves development efficiency. Automatic operator generation is especially critical for complex back-end hardware architectures and for scenarios involving complex and fused operators.
Therefore, MindSpore AKG accelerates the optimization and automatic generation of fused operators based on polyhedral compilation technology (the polyhedral model). It helps operators fused by the MindSpore graph-kernel fusion module automatically generate high-performance kernels on heterogeneous hardware platforms (GPU/Ascend), improving MindSpore's training performance.
The overall framework of MindSpore AKG is shown in the figure above:
The polyhedral model is a widely used loop-nest optimization method in the field of compiler optimization, and its theoretical basis is Presburger arithmetic. The polyhedral model allows us to analyze the read-write dependencies of statements in a program, which provides theoretical support for subsequent loop transformations. The core of polyhedral loop optimization is its scheduling algorithm, which can define optimization objectives based on hardware characteristics (such as parallelism and data locality) and convert the loop optimization problem into an integer programming problem. MindSpore AKG mainly uses the integer-linear-programming-based ISL scheduler to compute a new schedule for the input program. The ISL scheduler uses the Pluto algorithm as its primary strategy, supplemented by the Feautrier algorithm, and seeks a balance between program parallelism and locality.
What is tiling
Tiling is a widely used loop transformation that changes the order in which statement instances are executed. As in the code below, each iteration of a 1024 x 1024 loop nest can be thought of as a visit to a single point in this two-dimensional space. A 4 x 4 tiling of this loop changes the access order from traversing one large matrix to traversing many small 4 x 4 blocks.
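The transformation can be sketched in plain Python (a minimal illustration; a 64 x 64 extent is used here in place of the article's 1024 x 1024 for brevity, with the same 4 x 4 tile):

```python
import numpy as np

N, T = 64, 4   # loop extent and tile size (the article's example uses N = 1024)
A = np.zeros((N, N))

# Untiled order would visit all points of row i before moving to row i + 1.
# Tiled order iterates over 4 x 4 tiles first, then over the points inside
# each tile, so consecutive iterations stay within one small block.
for it in range(0, N, T):            # tile loops
    for jt in range(0, N, T):
        for i in range(it, it + T):  # point loops inside one tile
            for j in range(jt, jt + T):
                A[i, j] += 1
```

Every point of the iteration space is still visited exactly once; only the order changes.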
The value and challenge of tiling
Tiling, together with memory mapping of the data, allows data to be accessed through smaller, faster caches. When the amount of data exceeds the cache capacity, the original data must be tiled and staged in the cache to fit the hardware characteristics of the target architecture. To find a good tiling, developers need to understand the hardware's memory hierarchy and the code's logic, be able to analyze which data is logically reused and should be kept in cache, and even understand the caching mechanism of the hardware (e.g., CPU) and its parallelism mechanism (e.g., GPU), so that tiling improves hardware utilization and code performance.
MindSpore AKG automatic tiling solution
MindSpore AKG provides an Auto-Tiling module, whose main process contains:
Auto-Mapping refers to automatically mapping the data and statement instances, in execution order, to the multi-threaded processing units (such as the GPU's thread blocks and threads) of a multi-threaded hardware back end. With Auto-Mapping, we can:
Reduce code complexity
As shown in the figure below, for an operator with shape 8 * 12, Auto-Tiling will try the tiling in (a) to reduce the loop-boundary checks shown in blue in (b).
Next, Auto-Mapping also tries to allocate a thread size that evenly divides the tiled data to improve thread utilization, as in the following example with a thread size of 4.
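A minimal sketch of this divisibility heuristic (the function name and the simple descending search are illustrative, not AKG's actual implementation):

```python
def pick_thread_size(tile_extent, max_threads):
    # Prefer the largest thread count <= max_threads that divides the tiled
    # extent evenly: every thread then gets the same amount of work, no
    # thread is idle, and no boundary check is needed.
    for t in range(min(tile_extent, max_threads), 0, -1):
        if tile_extent % t == 0:
            return t
    return 1

# For the 8 x 12 example with at most 4 threads per axis, both tiled
# extents divide evenly by a thread size of 4.
print(pick_thread_size(8, 4), pick_thread_size(12, 4))  # prints "4 4"
```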
Optimize Block/Thread/Tile ratio
Auto-Mapping takes the GPU's hardware characteristics into account when assigning the block size, thread size, and tile size, and tunes the ratio of the three to improve hardware utilization, memory throughput, and access speed.
Across the wide variety of hardware back ends, the architecture usually contains multiple levels of buffers, which differ greatly in memory capacity, computation speed, and the types of computation they suit. Therefore, when placing programs on different hardware back ends, programmers must consider not only the computation instructions but also how the operator's data is divided among the different on-chip memories and how it flows between the on-chip buffers, in order to match the storage structures and enhance program parallelism.
Based on polyhedral technology, MindSpore AKG implements DMA data-flow identification and generation over the multi-level buffer structure. Automatic data movement analyzes the data flow and indicates which buffers the data should be placed in and the order in which it should move between buffers, further optimizing operator performance.
Taking the Davinci architecture, which has a relatively complex on-chip memory hierarchy, as an example, the steps of MindSpore AKG's automatic data-movement generation are as follows:
The following uses two types of computation as examples to describe how MindSpore AKG applies the above features to automatically generate and optimize complex operators.
Reduction Computation
Reduction is the accumulation operation over selected dimensions or all dimensions of a tensor. Common operators include Sum, ReduceMax/Min, ReduceAnd/Or, etc. A reduction over a large shape is usually divided into two steps:
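The two-step scheme can be sketched in NumPy (illustrative only; on a GPU each "block" below would be a thread block computing its partial result in parallel, and AKG fuses the second step into the same kernel via atomic add):

```python
import numpy as np

def two_step_sum(x, num_blocks):
    # Step 1: each block reduces its own slice to one partial result.
    chunks = np.array_split(x, num_blocks)
    partials = np.array([c.sum() for c in chunks])
    # Step 2: reduce the partial results to the final value.
    return partials.sum()

x = np.arange(1024)
assert two_step_sum(x, num_blocks=8) == x.sum()
```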
MindSpore AKG optimizes reduction through automatic axis fusion, polyhedral scheduling optimization, and the AKG-Reduce template library, and implements the two-step reduction in a single kernel via atomic add.
The process is shown in the figure above:
General matrix multiply and Convolution
MindSpore AKG uses the GPU's Tensor Core hardware computational unit and combines it with polyhedral compilation scheduling optimization and high-performance inline PTX libraries to accelerate general matrix multiply computations in mixed precision scenarios.
On this basis, MindSpore AKG uses implicit GEMM to handle mixed-precision convolution calculations. The convolution's two four-dimensional input matrices are converted into two-dimensional matrices during the movement from global memory to shared memory, which turns the computation into a matrix multiplication for optimization. This method avoids the data redundancy caused by image-to-column conversion (Im2col, a common convolution optimization that unrolls each receptive field into a contiguous column; the converted matrix occupies more global memory).
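To make the contrast concrete, here is an explicit Im2col version of a stride-1, no-padding convolution in NumPy. This is the memory-hungry variant that implicit GEMM avoids: it materializes the full column matrix in global memory, whereas implicit GEMM forms the equivalent tiles only in shared memory during the load (the function name and layout choices here are illustrative):

```python
import numpy as np

def conv2d_im2col(image, filters):
    # image: (H, W, C) feature map; filters: (O, Kh, Kw, C); stride 1, no pad.
    H, W, C = image.shape
    O, Kh, Kw, _ = filters.shape
    oh, ow = H - Kh + 1, W - Kw + 1
    # Im2col: unroll each receptive field into one row -> (oh*ow, Kh*Kw*C).
    # Overlapping windows are copied, so this matrix is larger than the
    # input; that duplication is the redundancy implicit GEMM avoids.
    cols = np.empty((oh * ow, Kh * Kw * C))
    for i in range(oh):
        for j in range(ow):
            cols[i * ow + j] = image[i:i + Kh, j:j + Kw, :].ravel()
    # The convolution is now a single GEMM with the flattened filters.
    out = cols @ filters.reshape(O, -1).T   # (oh*ow, O)
    return out.reshape(oh, ow, O)
```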
Taking the optimization of convolutional computation as an example, the process is as follows:
Users and developers can understand the optimization process of fused or complex operators by running MindSpore AKG's test cases. Take the convolution-operator code as an example:
The four-dimensional convolution (without Pad operation) is calculated as follows, where $N = 32, H = W = 28, Co = 128, Ci = 64, Hk = Wk = 5$.
$$Output(n, h, w, o)=\sum_{c=0}^{Ci-1} \sum_{rh=0}^{Hk-1} \sum_{rw=0}^{Wk-1} (Image(n, h+rh, w+rw, c)*Filter(o, rh, rw, c))$$
Based on this formula, the operator DSL can be written using tvm.compute:
import tvm

# DSL for the NHWC convolution above. The original snippet is the function
# body; it is wrapped in a function here for completeness, and the function
# name is illustrative.
def conv2d_nhwc(data, weight, stride, output_name="out"):
    n, in_h, in_w, in_c = data.shape       # data: (N, H, W, Ci)
    out_c, k_h, k_w, in_c = weight.shape   # weight: (Co, Hk, Wk, Ci)
    _, _, s_h, s_w = stride
    # Output spatial extents (no padding)
    o_h = (in_h - k_h) // s_h + 1
    o_w = (in_w - k_w) // s_w + 1
    # Reduction axes over the input channels and the filter window
    rc = tvm.reduce_axis((0, in_c), name="rc")
    rh = tvm.reduce_axis((0, k_h), name="rh")
    rw = tvm.reduce_axis((0, k_w), name="rw")
    output = tvm.compute(
        (n, o_h, o_w, out_c),
        lambda n, h, w, o: tvm.sum(
            data[n, (h * s_h + rh), (w * s_w + rw), rc]
            * weight[o, rh, rw, rc],
            axis=[rc, rh, rw]),
        name=output_name
    )
    return output
This generates the following initial schedule, containing seven nested for loops, which is computationally inefficient:
// attr [compute(out, 0x55c9185ce710)] realize_scope = ""
realize out<float16>([0, 32], [0, 28], [0, 28], [0, 128]) {
produce out {
for (n, 0, 32) {
for (h, 0, 28) {
for (w, 0, 28) {
for (o, 0, 128) {
out(n, h, w, o) = 0h
for (rc, 0, 64) {
for (rh, 0, 5) {
for (rw, 0, 5) {
// attr [[iter_var(rc, range(min=0, ext=64)), iter_var(rh, range(min=0, ext=5)), iter_var(rw, range(min=0, ext=5))]] reduce_update = ""
out(n, h, w, o) = (out(n, h, w, o) + (input_1(n, (h + rh), (w + rw), rc)*input_2(o, rh, rw, rc)))
}
}
}
}
}
}
}
}
}
After scheduling optimization by the Poly module, multiple back-end optimization passes, and code generation, the program's parallelism and data locality are greatly improved, yielding the final CUDA kernel executed on the GPU:
// Introduce the akg_mma_lib high-performance library
#include "akg_mma_lib/wmma.hpp"
extern "C" __global__ void conv_tc_auto_float16_32_32_32_64_float16_128_5_5_64_1_1_0_0_0_0_1_1_float16_kernel0( half* __restrict__ input_1, half* __restrict__ input_2, half* __restrict__ out) {
// Buffer assignment
akg::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 8, float> out_local[4];
half input_2_shared_transfer[32];
__shared__ half input_2_shared[13056];
half input_1_shared_transfer[16];
akg::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 8, half, nvcuda::wmma::col_major> input_2_local[2];
akg::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 8, half, nvcuda::wmma::row_major> input_1_local[2];
#pragma unroll
for (int cc5 = 0; cc5 < 5; ++cc5) {
// Preload the data to be used for this calculation from global memory to shared memory
// Vectorized reading with float4 pointers
#pragma unroll
for (int cc7 = 0; cc7 < 4; ++cc7) {
((float4*)input_2_shared_transfer)[((cc7 * 8) + 0) / 8] = ((float4*)input_2)[(((((cc7 * 51200) + ((((int)threadIdx.x) / 8) * 1600)) + (cc5 * 320)) + ((((int)threadIdx.x) % 8) * 8)) + 0) / 8];
}
#pragma unroll
for (int cc71 = 0; cc71 < 4; ++cc71) {
((float4*)input_2_shared)[(((((cc71 * 2176) + ((((int)threadIdx.x) / 128) * 1088)) + ((((int)threadIdx.x) % 8) * 136)) + (((((int)threadIdx.x) % 128) / 8) * 8)) + 0) / 8] = ((float4*)input_2_shared_transfer)[((cc71 * 8) + 0) / 8];
}
#pragma unroll
for (int cc72 = 0; cc72 < 2; ++cc72) {
((float4*)input_1_shared_transfer)[((cc72 * 8) + 0) / 8] = ((float4*)input_1)[(((((((cc72 * 1048576) + ((((int)threadIdx.x) / 16) * 65536)) + ((((int)blockIdx.y) / 14) * 2048)) + (cc5 * 2048)) + ((((int)blockIdx.y) % 14) * 128)) + ((((int)threadIdx.x) % 16) * 8)) + 0) / 8];
}
#pragma unroll
for (int cc73 = 0; cc73 < 2; ++cc73) {
((float4*)input_2_shared)[(((((cc73 * 2176) + ((((int)threadIdx.x) % 16) * 136)) + ((((int)threadIdx.x) / 16) * 8)) + 0) + 8704) / 8] = ((float4*)input_1_shared_transfer)[((cc73 * 8) + 0) / 8];
}
__syncthreads();
#pragma unroll
for (int cc6_outer = 0; cc6_outer < 4; ++cc6_outer) {
// Preload the data to be used for the next calculation from global memory into registers
#pragma unroll
for (int cc74 = 0; cc74 < 4; ++cc74) {
((float4*)input_2_shared_transfer)[((cc74 * 8) + 0) / 8] = ((float4*)input_2)[(((((((cc74 * 51200) + ((((int)threadIdx.x) / 8) * 1600)) + (cc5 * 320)) + (cc6_outer * 64)) + ((((int)threadIdx.x) % 8) * 8)) + 0) + 64) / 8];
}
#pragma unroll
for (int cc75 = 0; cc75 < 2; ++cc75) {
((float4*)input_1_shared_transfer)[((cc75 * 8) + 0) / 8] = ((float4*)input_1)[(((((((((cc75 * 1048576) + ((((int)threadIdx.x) / 16) * 65536)) + ((((int)blockIdx.y) / 14) * 2048)) + (cc5 * 2048)) + ((((int)blockIdx.y) % 14) * 128)) + (cc6_outer * 64)) + ((((int)threadIdx.x) % 16) * 8)) + 0) + 64) / 8];
}
// Call high performance interfaces for data movement, initialization and mma calculation
#pragma unroll
for (int cc11 = 0; cc11 < 8; ++cc11) {
#pragma unroll
for (int cc123 = 0; cc123 < 2; ++cc123) {
(void)akg::wmma::load_matrix_sync(input_2_local[cc123], &(input_2_shared[((((((int)threadIdx.x) / 64) * 2176) + (cc123 * 1088)) + (cc11 * 136))]), 8);
}
#pragma unroll
for (int cc124 = 0; cc124 < 2; ++cc124) {
(void)akg::wmma::load_matrix_sync(input_1_local[cc124], &(input_2_shared[((((((((int)threadIdx.x) % 64) / 32) * 2176) + (cc124 * 1088)) + (cc11 * 136)) + 8704)]), 8);
}
#pragma unroll
for (int cc21 = 0; cc21 < 2; ++cc21) {
#pragma unroll
for (int cc22 = 0; cc22 < 2; ++cc22) {
if (((cc5 == 0) && (cc6_outer == 0)) && (cc11 == 0)) {
(void)akg::wmma::fill_fragment(out_local[((cc21 * 2) + cc22)], 0.000000e+00f);
}
(void)akg::wmma::mma_sync(out_local[((cc21 * 2) + cc22)], input_1_local[cc21], input_2_local[cc22], out_local[((cc21 * 2) + cc22)]);
}
}
}
// Move the data to be used for the next calculation from registers to shared memory
__syncthreads();
#pragma unroll
for (int cc76 = 0; cc76 < 4; ++cc76) {
((float4*)input_2_shared)[(((((cc76 * 2176) + ((((int)threadIdx.x) / 128) * 1088)) + ((((int)threadIdx.x) % 8) * 136)) + (((((int)threadIdx.x) % 128) / 8) * 8)) + 0) / 8] = ((float4*)input_2_shared_transfer)[((cc76 * 8) + 0) / 8];
}
#pragma unroll
for (int cc77 = 0; cc77 < 2; ++cc77) {
((float4*)input_2_shared)[(((((cc77 * 2176) + ((((int)threadIdx.x) % 16) * 136)) + ((((int)threadIdx.x) / 16) * 8)) + 0) + 8704) / 8] = ((float4*)input_1_shared_transfer)[((cc77 * 8) + 0) / 8];
}
__syncthreads();
}
#pragma unroll
for (int cc111 = 0; cc111 < 8; ++cc111) {
#pragma unroll
for (int cc126 = 0; cc126 < 2; ++cc126) {
(void)akg::wmma::load_matrix_sync(input_2_local[cc126], &(input_2_shared[((((((int)threadIdx.x) / 64) * 2176) + (cc126 * 1088)) + (cc111 * 136))]), 8);
}
#pragma unroll
for (int cc127 = 0; cc127 < 2; ++cc127) {
(void)akg::wmma::load_matrix_sync(input_1_local[cc127], &(input_2_shared[((((((((int)threadIdx.x) % 64) / 32) * 2176) + (cc127 * 1088)) + (cc111 * 136)) + 8704)]), 8);
}
#pragma unroll
for (int cc211 = 0; cc211 < 2; ++cc211) {
#pragma unroll
for (int cc221 = 0; cc221 < 2; ++cc221) {
(void)akg::wmma::mma_sync(out_local[((cc211 * 2) + cc221)], input_1_local[cc211], input_2_local[cc221], out_local[((cc211 * 2) + cc221)]);
}
}
}
__syncthreads();
}
#pragma unroll
for (int cc4 = 0; cc4 < 2; ++cc4) {
#pragma unroll
for (int cc6 = 0; cc6 < 2; ++cc6) {
(void)akg::wmma::store_matrix_sync(&(input_2_shared[((((((((int)threadIdx.x) % 64) / 32) * 4352) + (cc4 * 136)) + ((((int)threadIdx.x) / 64) * 32)) + (cc6 * 16))]), out_local[((cc4 * 2) + cc6)], 272, nvcuda::wmma::mem_row_major);
}
}
// Move the calculation results out to the output buffer in global memory
__syncthreads();
#pragma unroll
for (int cc41 = 0; cc41 < 4; ++cc41) {
((float4*)out)[(((((cc41 * 802816) + ((((int)threadIdx.x) / 32) * 100352)) + (((int)blockIdx.y) * 256)) + ((((int)threadIdx.x) % 32) * 8)) + 0) / 8] = ((float4*)input_2_shared)[((((cc41 * 2176) + ((((int)threadIdx.x) / 16) * 136)) + ((((int)threadIdx.x) % 16) * 8)) + 0) / 8];
}
__syncthreads();
}
MindSpore AKG supports generating fused forward and backward scenarios for reduction, general matrix multiply, and convolution operators, ensuring the performance of fused operators while saving inter-operator I/O and memory consumption.