name | about | labels
---|---|---
Bug Report | topi.nn.softmax() fails at runtime | kind/bug
When I run softmax in AKG on GPU, the test fails: the kernel returns values such as 11.5, which a softmax can never produce. It looks like the generated CUDA kernel uses shared memory incorrectly. Please check this bug.
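For reference, the expected result is the row-wise, numerically stable softmax (this is what `scipy.special.softmax` in the repro below computes); a minimal NumPy sketch, with names of my own choosing:

```python
import numpy as np

def softmax_reference(x, axis=-1):
    # Subtract the row max before exponentiating so exp() cannot overflow;
    # every output lies in (0, 1) and each row sums to 1.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / np.sum(e, axis=axis, keepdims=True)
```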
Hardware Environment (Ascend / GPU / CPU):

/device gpu
Code to reproduce:

```python
import numpy as np
import scipy.special

from akg import topi
from akg.utils import kernel_exec as utils
from akg.utils.result_analysis import gpu_profiling
from akg.utils.format_transform import to_tvm_nd_array


def softmax(data1):
    return topi.nn.softmax(data1)


def gen_data(shape, dtype):
    np.random.seed(0)
    in_data = np.random.uniform(-10, 10, size=shape).astype(dtype)
    expect = scipy.special.softmax(in_data)
    # import tensorflow as tf
    # tf.compat.v1.enable_eager_execution()
    # expect = tf.math.softmax(in_data).numpy()
    output = np.full(expect.shape, np.nan, dtype)
    return in_data, output, expect


def test_ms_softmax(shape, dtype, poly_sch=False):
    # Only the polyhedral scheduling path is exercised here.
    if poly_sch:
        mod = utils.op_build_test(
            softmax, (shape,), (dtype,), kernel_name="softmax",
            attrs={"target": "cuda",
                   "enable_akg_reduce_lib": True,
                   "dim": "0 0 1 1 0 1 1001 1",
                   "bind_block": "1 1",
                   "bind_thread": "1001 1"})
    in_data, output, expect = gen_data(shape, dtype)
    output = utils.mod_launch(mod, (in_data, output), expect=expect)
    res = np.allclose(output, expect, rtol=1e-3, atol=1e-3)
    print("output")
    print(output)
    print("expect")
    print(expect)
    print("Test {}".format("Pass" if res else "Fail"))
    if not res:
        print("Error cuda:========================")
        print(mod.imported_modules[0].get_source())
        raise AssertionError("Test fail")


if __name__ == "__main__":
    test_ms_softmax((1, 1001), "float32", True)
```
Hey yiyanzhi_akane, welcome to the MindSpore Community.
All of the projects in the MindSpore Community are maintained by @mindspore-ci-bot, which means developers can comment below every pull request or issue to trigger bot commands.
Please follow the instructions at https://gitee.com/mindspore/community/blob/master/command.md for details.
What is the current behavior? Is there any screenshot of the error?
```
[INFO] AKG:2021-05-28-17:22:53.300.687 [poly.cc:46] [poly] [ Polyhedral exec time ], GenIsl spent 4.14875 ms
[INFO] AKG:2021-05-28-17:22:53.300.741 [scop.cc:144] [poly] ====== Reduce op type ========
[INFO] AKG:2021-05-28-17:22:53.300.830 [scop.cc:146] [poly] S_3 -> SumOp
[INFO] AKG:2021-05-28-17:22:53.300.896 [scop.cc:146] [poly] S_1 -> MaxOp
[INFO] AKG:2021-05-28-17:22:53.301.074 [schedule_pass_mgr.cc:55] [poly] Running poly pass InitSchedule
[INFO] AKG:2021-05-28-17:22:53.309.864 [schedule_pass_mgr.cc:71] [poly] [ Polyhedral exec time ], InitSchedule spent 8.72858 ms
[INFO] AKG:2021-05-28-17:22:53.311.121 [schedule_pass_mgr.cc:55] [poly] Running poly pass ConstrainSchedule
[INFO] AKG:2021-05-28-17:22:53.314.503 [schedule_pass_mgr.cc:71] [poly] [ Polyhedral exec time ], ConstrainSchedule spent 3.32793 ms
[INFO] AKG:2021-05-28-17:22:53.315.497 [schedule_pass_mgr.cc:55] [poly] Running poly pass ComputeSchedule
[INFO] AKG:2021-05-28-17:22:53.316.871 [schedule_pass_mgr.cc:71] [poly] [ Polyhedral exec time ], ComputeSchedule spent 1.32603 ms
[INFO] AKG:2021-05-28-17:22:53.317.809 [schedule_pass_mgr.cc:55] [poly] Running poly pass GpuDmaAnalysis
[INFO] AKG:2021-05-28-17:22:53.321.258 [scop.cc:244] [poly] [ Polyhedral exec time ], NodeFrom spent 2.684 ms
[INFO] AKG:2021-05-28-17:22:53.323.500 [scop.cc:279] [poly] [ Polyhedral exec time ], IslEmitter spent 2.19311 ms
[INFO] AKG:2021-05-28-17:22:53.329.594 [tiling.cc:447] [tiling] This dim is generated by auto tiling
[INFO] AKG:2021-05-28-17:22:53.333.361 [schedule_pass_mgr.cc:71] [poly] [ Polyhedral exec time ], GpuDmaAnalysis spent 15.4926 ms
[INFO] AKG:2021-05-28-17:22:53.334.421 [schedule_pass_mgr.cc:55] [poly] Running poly pass TileOuterBand
[INFO] AKG:2021-05-28-17:22:53.335.460 [tile_outer_band.cc:263] [schedule_pass] No: 0, tiling_flag: 0
[INFO] AKG:2021-05-28-17:22:53.335.480 [tile_outer_band.cc:272] [schedule_pass] index: 0, axis: 0, c1_size: 1, c0_size: 1, seq: 0, is inner: 0
[INFO] AKG:2021-05-28-17:22:53.335.492 [tile_outer_band.cc:272] [schedule_pass] index: 0, axis: 1, c1_size: 1001, c0_size: 1, seq: 1, is inner: 0
[INFO] AKG:2021-05-28-17:22:53.336.307 [schedule_pass_mgr.cc:71] [poly] [ Polyhedral exec time ], TileOuterBand spent 1.8231 ms
[INFO] AKG:2021-05-28-17:22:53.337.257 [schedule_pass_mgr.cc:55] [poly] Running poly pass MappingOuterBand
[INFO] AKG:2021-05-28-17:22:53.342.102 [schedule_pass_mgr.cc:71] [poly] [ Polyhedral exec time ], MappingOuterBand spent 4.78755 ms
[INFO] AKG:2021-05-28-17:22:53.343.192 [schedule_pass_mgr.cc:55] [poly] Running poly pass SharedMemoryManager
[INFO] AKG:2021-05-28-17:22:53.364.321 [schedule_pass_mgr.cc:71] [poly] [ Polyhedral exec time ], SharedMemoryManager spent 21.0531 ms
[INFO] AKG:2021-05-28-17:22:53.366.172 [schedule_pass_mgr.cc:55] [poly] Running poly pass RegisterMemoryManager
[INFO] AKG:2021-05-28-17:22:53.383.362 [schedule_pass_mgr.cc:71] [poly] [ Polyhedral exec time ], RegisterMemoryManager spent 17.0989 ms
[INFO] AKG:2021-05-28-17:22:53.385.426 [schedule_pass_mgr.cc:55] [poly] Running poly pass RealizeManager
[INFO] AKG:2021-05-28-17:22:53.386.254 [schedule_pass_mgr.cc:71] [poly] [ Polyhedral exec time ], RealizeManager spent 0.772184 ms
[INFO] AKG:2021-05-28-17:22:53.388.084 [poly.cc:51] [poly] [ Polyhedral exec time ], Transform spent 87.3507 ms
[INFO] AKG:2021-05-28-17:22:53.413.464 [scop.cc:244] [poly] [ Polyhedral exec time ], NodeFrom spent 25.2141 ms
[INFO] AKG:2021-05-28-17:22:53.429.559 [scop.cc:279] [poly] [ Polyhedral exec time ], IslEmitter spent 16.0373 ms
[INFO] AKG:2021-05-28-17:22:53.429.950 [poly.cc:56] [poly] [ Polyhedral exec time ], GenHalide spent 41.8455 ms
[INFO] AKG:2021-05-28-17:22:53.441.755 [swizzle_gpu.cc:729] [pass] BEGIN_PASS SwizzleGPU on softmax_float32_1_1001_0
[DEBUG] AKG:2021-05-28-17:22:53.441.826 [swizzle_gpu.cc:435] [pass] Visit statement
[DEBUG] AKG:2021-05-28-17:22:53.441.839 [swizzle_gpu.cc:42] [pass] Thread extent (threadIdx.z) : 1
[DEBUG] AKG:2021-05-28-17:22:53.441.859 [swizzle_gpu.cc:232] [pass] Allocate : T_softmax_maxelem_shared size 1
[DEBUG] AKG:2021-05-28-17:22:53.441.891 [swizzle_gpu.cc:232] [pass] Allocate : T_softmax_expsum_shared size 1
[DEBUG] AKG:2021-05-28-17:22:53.441.908 [swizzle_gpu.cc:232] [pass] Allocate : T_softmax_expsum_shared size 1
[DEBUG] AKG:2021-05-28-17:22:53.441.922 [swizzle_gpu.cc:232] [pass] Allocate : T_softmax_maxelem_shared size 1
[DEBUG] AKG:2021-05-28-17:22:53.441.935 [swizzle_gpu.cc:42] [pass] Thread extent (threadIdx.y) : 1
[DEBUG] AKG:2021-05-28-17:22:53.441.949 [swizzle_gpu.cc:42] [pass] Thread extent (threadIdx.x) : 1001
[DEBUG] AKG:2021-05-28-17:22:53.441.960 [swizzle_gpu.cc:42] [pass] Thread extent (blockIdx.z) : 1
[DEBUG] AKG:2021-05-28-17:22:53.441.972 [swizzle_gpu.cc:42] [pass] Thread extent (blockIdx.y) : 1
[DEBUG] AKG:2021-05-28-17:22:53.441.984 [swizzle_gpu.cc:42] [pass] Thread extent (blockIdx.x) : 1
[DEBUG] AKG:2021-05-28-17:22:53.442.017 [swizzle_gpu.cc:227] [pass] End Store T_softmax_maxelem_shared
[DEBUG] AKG:2021-05-28-17:22:53.442.030 [swizzle_gpu.cc:227] [pass] End Store T_softmax_maxelem_shared
[DEBUG] AKG:2021-05-28-17:22:53.442.043 [swizzle_gpu.cc:227] [pass] End Store T_softmax_expsum_shared
[DEBUG] AKG:2021-05-28-17:22:53.442.053 [swizzle_gpu.cc:227] [pass] End Store T_softmax_expsum_shared
[DEBUG] AKG:2021-05-28-17:22:53.442.073 [swizzle_gpu.cc:227] [pass] End Store T_softmax_expsum_shared
[DEBUG] AKG:2021-05-28-17:22:53.442.086 [swizzle_gpu.cc:227] [pass] End Store T_softmax_norm
[INFO] AKG:2021-05-28-17:22:53.442.100 [swizzle_gpu.cc:444] [pass] Total swizzled loops for softmax_float32_1_1001_0 : 0
[INFO] AKG:2021-05-28-17:22:53.442.110 [swizzle_gpu.cc:733] [pass] END_PASS
func_time_required func:op_build_test, running:0.721295 seconds
func_time_required func:mod_launch, running:0.412954 seconds
output
[[4.1422942e-01 1.1544185e+01 1.2185507e+00 ... 6.8597146e-04
5.3936224e+00 1.0000000e+00]]
expect
[[2.2479408e-06 6.2648018e-05 6.6128327e-06 ... 3.7226335e-09
2.9270133e-05 5.4268007e-06]]
Test Fail
Error cuda:========================
#ifndef __CUDA_ARCH__
#define __CUDA_ARCH__ 700
#endif
// built-in for half swizzle
#include <cuda_fp16.h>
struct __device_builtin__ __align__(8) half4 { half x, y, z, w; };
#if defined(__CUDACC_RTC__)
#define __CUDA_FP16_DECL__ __host__ __device__
#else
#define __CUDA_FP16_DECL__ static __device__ __inline__
#endif
// half4 ldg function support
#if (defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)
#define __LDG_PTR "l"
#else
// not sure about this one, it was copied from the half2 ldg() function
#define __LDG_PTR "r"
#endif /*(defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)*/
#define __HALF4_TO_UI(var) *(reinterpret_cast<unsigned long *>(&(var)))
__CUDA_FP16_DECL__ half4 __ldg(const half4 *ptr)
{
  half4 ret;
  asm ("ld.global.nc.b64 %0, [%1];" : "=l"(__HALF4_TO_UI(ret)) : __LDG_PTR(ptr));
  return ret;
}
extern "C" __global__ void softmax_float32_1_1001_kernel0( float* __restrict__ input_1, float* __restrict__ T_softmax_norm) {
  __shared__ float T_softmax_maxelem_shared[1];
  __shared__ float T_softmax_expsum_shared[1];
  __shared__ float T_softmax_expsum_shared1[1];
  __shared__ float T_softmax_maxelem_shared1[1];
  for (int cc0 = 0; cc0 < 2000; ++cc0) {
    if (((int)threadIdx.x) == 0) {
      if (cc0 == 0) {
        T_softmax_maxelem_shared[0] = -3.402823e+38f;
      }
      if (cc0 <= 1000) {
        T_softmax_maxelem_shared[0] = max(T_softmax_maxelem_shared[0], input_1[cc0]);
      }
      if (1000 <= cc0) {
        T_softmax_expsum_shared[0] = (T_softmax_expsum_shared[0] + __expf((input_1[(cc0 - 1000)] - T_softmax_maxelem_shared[0])));
      } else {
        if (cc0 == 0) {
          T_softmax_expsum_shared[0] = 0.000000e+00f;
        }
      }
    }
    __syncthreads();
  }
  for (int cc01 = 2000; cc01 < 3001; ++cc01) {
    if (((int)threadIdx.x) == 0) {
      if (cc01 == 2000) {
        T_softmax_expsum_shared1[0] = (T_softmax_expsum_shared1[0] + __expf((input_1[1000] - T_softmax_maxelem_shared1[0])));
      }
      T_softmax_norm[(cc01 - 2000)] = (__expf((input_1[(cc01 - 2000)] - T_softmax_maxelem_shared1[0])) / T_softmax_expsum_shared1[0]);
    }
    __syncthreads();
  }
}
Traceback (most recent call last):
  File "issue.py", line 45, in <module>
    test_ms_softmax((1, 1001), "float32", True)
  File "issue.py", line 42, in test_ms_softmax
    raise AssertionError("Test fail")
AssertionError: Test fail
```
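Looking at the emitted kernel: the first loop reduces into `T_softmax_maxelem_shared` / `T_softmax_expsum_shared`, but the second loop normalizes with `T_softmax_maxelem_shared1` / `T_softmax_expsum_shared1`, which are never written before being read, so the output is computed from whatever happens to be in shared memory and the whole first reduction is dead code. Note also that every statement is guarded by `threadIdx.x == 0`, so although `bind_thread` launches 1001 threads, the computation runs serially on one thread. A straight-line Python transcription of the thread-0 logic (my own sketch, not AKG output; NaN stands in for uninitialized shared memory) makes the dead reduction visible:

```python
import numpy as np

def emulate_kernel(input_1):
    # Hand transcription of softmax_float32_1_1001_kernel0's thread-0 logic.
    maxelem = -3.402823e+38   # T_softmax_maxelem_shared[0]
    expsum = 0.0              # T_softmax_expsum_shared[0]
    # The *_shared1 buffers are read before any write in the kernel;
    # on real hardware they hold arbitrary garbage, modeled here as NaN.
    maxelem1 = float("nan")   # T_softmax_maxelem_shared1[0]
    expsum1 = float("nan")    # T_softmax_expsum_shared1[0]
    out = np.empty(1001, np.float32)
    for cc0 in range(2000):
        if cc0 <= 1000:
            maxelem = max(maxelem, input_1[cc0])
        if cc0 >= 1000:
            expsum += np.exp(input_1[cc0 - 1000] - maxelem)
    # maxelem and expsum are never read again: the reduction above is dead.
    for cc01 in range(2000, 3001):
        if cc01 == 2000:
            expsum1 += np.exp(input_1[1000] - maxelem1)
        out[cc01 - 2000] = np.exp(input_1[cc01 - 2000] - maxelem1) / expsum1
    return out
```

A correct kernel would have to normalize with the same buffers it reduced into (or initialize the `*_shared1` pair before reading them).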