Ai
1 Star 2 Fork 0

liuzhongkai/code-generator

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
gemm_avx512_nhwc_asm.c.in 8.62 KB
一键复制 编辑 原始数据 按行查看 历史
lzk 提交于 2021-11-21 17:03 +08:00 . code genarator
#include <x86intrin.h>
// nnacl gemm in x86 avx512 asm code
void nnacl_gemm_avx512_@{row_block}x@{col_block}_kernel_nhwc_fp32(float *dst, const float *src, const float *weight,
const float *bias, const size_t act_flag, const size_t row_block,
const size_t col_block, const size_t deep, const size_t src_stride,
const size_t dst_stride, const size_t inc_flag) {
@if row_block == 4:
const float *src_4 = src + 3 * src_stride;
const float *dst_4 = dst + 3 * dst_stride;
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
@col_split_num = col_block >> 4;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\\n"
"je 0f\\n"
@for row in range(0, min(row_block, 3)):
@for col in range(0, col_split_num):
@if row == 0:
"vmovups @{col * 64}(%[dst]), %%zmm@{row * col_split_num + col}\\n"
@else:
"vmovups @{col * 64}(%[dst], %[dst_stride], @{row}), %%zmm@{row * col_split_num + col}\\n"
@if row_block >= 4:
@for col in range(0, col_split_num):
"vmovups @{col * 64}(%[dst_4]), %%zmm@{3 * col_split_num + col}\\n"
@if row_block >= 5:
@for col in range(0, col_split_num):
"vmovups @{col * 64}(%[dst], %[dst_stride], 4), %%zmm@{4 * col_split_num + col}\\n"
@if row_block >= 6:
@for col in range(0, col_split_num):
"vmovups @{col * 64}(%[dst_4], %[dst_stride], 2), %%zmm@{5 * col_split_num + col}\\n"
"jmp 2f\\n"
"0:\\n"
"cmpq $0, %[bias]\\n"
"je 1f\\n"
@for row in range(0, row_block):
@for col in range(0, col_split_num):
"vmovaps @{col * 64}(%[bias]), %%zmm@{row * col_split_num + col}\\n"
"jmp 2f\\n"
"1:\\n"
@for row in range(0, row_block):
@for col in range(0, col_split_num):
"vxorps %%zmm@{row * col_split_num + col}, %%zmm@{row * col_split_num + col}, %%zmm@{row * col_split_num + col}\\n"
"2:\\n"
:
@list = ["[dst] \"r\"(dst)", "[bias] \"r\"(bias)", "[dst_stride] \"r\"(dst_stride_t)", "[inc_flag] \"r\"(inc_flag)"]
@if row_block == 4:
@list.append("[dst_4] \"r\"(dst_4)")
@print(" : " + ", ".join(list), file=OUT_STREAM)
@print(" : " + ", ".join(["\"%zmm" + str(i) + "\"" for i in range(0, row_block * col_block >>4)]), file=OUT_STREAM)
);
asm volatile(
"0:\\n"
@loop_count = 8
@for i in range(0, loop_count):
// block @{i}
@for col in range(0, col_split_num):
"vmovups @{col * 64 + i * col_block * 4}(%[weight]), %%zmm@{31 - col}\\n"
@if col_split_num == 6:
@if row_block == 4:
"vbroadcastss @{i * 4}(%[src]), %%zmm@{31 - col_split_num}\\n"
"vbroadcastss @{i * 4}(%[src], %[src_stride], 1), %%zmm@{31 - col_split_num - 1}\\n"
@for col in range(0, col_split_num):
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{0 * col_split_num + col}\\n"
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - 1}, %%zmm@{1 * col_split_num + col}\\n"
"vbroadcastss @{i * 4}(%[src], %[src_stride], 2), %%zmm@{31 - col_split_num}\\n"
"vbroadcastss @{i * 4}(%[src_4]), %%zmm@{31 - col_split_num - 1}\\n"
@for col in range(0, col_split_num):
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{2 * col_split_num + col}\\n"
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - 1}, %%zmm@{3 * col_split_num + col}\\n"
@else:
@for row in range(0, row_block):
@if row == 0:
"vbroadcastss @{i * 4}(%[src]), %%zmm@{31 - col_split_num - row}\\n"
@else:
"vbroadcastss @{i * 4}(%[src], %[src_stride], @{row}), %%zmm@{31 - col_split_num - row}\\n"
@for row in range(0, row_block):
@for col in range(0, col_split_num):
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - row}, %%zmm@{row * col_split_num + col}\\n"
@elif col_split_num == 5:
@if row_block == 5:
"vbroadcastss @{i * 4}(%[src]), %%zmm@{31 - col_split_num}\\n"
"vbroadcastss @{i * 4}(%[src], %[src_stride], 1), %%zmm@{31 - col_split_num - 1}\\n"
@for col in range(0, col_split_num):
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{0 * col_split_num + col}\\n"
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - 1}, %%zmm@{1 * col_split_num + col}\\n"
"vbroadcastss @{i * 4}(%[src], %[src_stride], 2), %%zmm@{31 - col_split_num}\\n"
"vbroadcastss @{i * 4}(%[src_4]), %%zmm@{31 - col_split_num - 1}\\n"
@for col in range(0, col_split_num):
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{2 * col_split_num + col}\\n"
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - 1}, %%zmm@{3 * col_split_num + col}\\n"
"vbroadcastss @{i * 4}(%[src], %[src_stride], 4), %%zmm@{31 - col_split_num}\\n"
@for col in range(0, col_split_num):
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{2 * col_split_num + col}\\n"
@else:
@for row in range(0, row_block):
@if row == 0:
"vbroadcastss @{i * 4}(%[src]), %%zmm@{31 - col_split_num - row}\\n"
@else:
"vbroadcastss @{i * 4}(%[src], %[src_stride], @{row}), %%zmm@{31 - col_split_num - row}\\n"
@for row in range(0, row_block):
@for col in range(0, col_split_num):
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - row}, %%zmm@{row * col_split_num + col}\\n"
"dec %[deep]\\n"
"add $@{col_block * 4 * 8}, %[weight]\\n"
"add $@{loop_count * 4}, %[src]\\n"
@if row_block == 4:
"add $32, %[src_4]\\n"
"jg 0b\\n"
"and $0x2, %[inc_flag]\\n"
"je 3f\\n"
"movq %[act_flag], %rax\\n"
"and $0x3, %eax\\n"
"je 3f\\n"
// relu
"vxorps %zmm31, %zmm31, %zmm31\\n"
@for col in range(0, col_split_num):
@for row in range(0, row_block):
"vmaxps %%zmm@{row + col * row_block}, %%zmm31, %%zmm@{row + col * row_block}\\n"
"and $0x1, %eax\\n"
"je 3f\\n"
// relu6
"mov $0x40C00000, %eax\\n"
"vmovd %eax, %xmm30\\n"
"vbroadcastss %xmm30, %zmm30\\n"
@for col in range(0, col_split_num):
@for row in range(0, row_block):
"vminps %%zmm@{row + col * row_block}, %%zmm30, %%zmm@{row + col * row_block}\\n"
"3:\\n"
@for row in range(0, min(row_block, 3)):
@for col in range(0, col_split_num):
@if row == 0:
"vmovups @{col * 64}(%[dst]), %%zmm@{row * col_split_num + col}\\n"
@else:
"vmovups @{col * 64}(%[dst], %[dst_stride], @{row}), %%zmm@{row * col_split_num + col}\\n"
@if row_block >= 4:
@for col in range(0, col_split_num):
"vmovups @{col * 64}(%[dst_4]), %%zmm@{(row + 1) * col_split_num + col}\\n"
@if row_block >= 5:
@for col in range(0, col_split_num):
"vmovups @{col * 64}(%[dst], %[dst_stride], 4), %%zmm@{(row + 1) * col_split_num + col}\\n"
:
@list = ["[src] \"r\"(src)", "[src_stride] \"r\"(src_stride_t)", "[weight] \"r\"(weight)", "[deep] \"r\"(deep_t)", "[inc_flag] \"r\"(inc_flag)", "[act_flag] \"r\"(act_flag)", "[dst] \"r\"(dst)", "[dst_stride] \"r\"(dst_stride_t)"]
@if row_block >= 4:
@list.append("[dst_4] \"r\"(dst_4)")
@list.append("[src_4] \"r\"(src_4)")
@print(" : " + ", ".join(list), file=OUT_STREAM)
@print(" : \"%rax\", " + ", ".join(["\"%zmm" + str(i) + "\"" for i in range(0, 32)]), file=OUT_STREAM)
);
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/lzkcode/code-generator.git
git@gitee.com:lzkcode/code-generator.git
lzkcode
code-generator
code-generator
code-ge

搜索帮助