Ai
1 Star 2 Fork 0

liuzhongkai/code-generator

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
gemm_fma_nc8hw8_asm.c.in 7.32 KB
一键复制 编辑 原始数据 按行查看 历史
lzk 提交于 2021-11-21 17:03 +08:00 . code genarator
#include <x86intrin.h>
// nnacl gemm in x86 fma asm code
//
// Generator template for one fp32 GEMM micro-kernel. The generator variables row_block and
// col_block select how many 8-float accumulator registers are used: one ymm register per
// row and per group of 8 columns, numbered row + col_group * row_block.
//
// Parameters of the generated function:
//   dst        - output tile. Read as the initial accumulator value when bit 0 of inc_flag is
//                set, and always written at the end.
//   src        - input tile. Addressed as 32 bytes per row and 4 bytes per depth step; advanced
//                by src_stride floats after every 8 depth steps.
//   weight     - packed weights, read with aligned loads, col_block floats per depth step.
//   bias       - optional bias, read with aligned loads. When it is null and bit 0 of inc_flag
//                is clear the accumulators start at zero.
//   act_flag   - activation selector, only used when bit 1 of inc_flag is set. If the low two
//                bits are zero nothing is applied. Otherwise results are clamped below at 0
//                (relu), and when bit 0 is set they are also clamped above at 6.0 (relu6).
//   row_block, col_block - not referenced by the generated body. The tile shape is fixed at
//                generation time.
//   deep       - depth. The main loop runs deep >> 3 times and its body always runs at least once.
//   src_stride, dst_stride - strides in floats. They are converted to bytes with << 2.
//   inc_flag   - bit 0: accumulate on top of dst. bit 1: apply the activation before storing.
//
// NOTE(review): the bit meanings above are read off the branches below. Confirm them against
// the caller.
void nnacl_gemm_fma_@{row_block}x@{col_block}_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight,
                                            const float *bias, const size_t act_flag, const size_t row_block,
                                            const size_t col_block, const size_t deep, const size_t src_stride,
                                            const size_t dst_stride, const size_t inc_flag) {
    @if col_block == 32:
        const float *dst_4 = dst + 3 * dst_stride;
    size_t deep_t = deep >> 3;
    size_t dst_stride_t = dst_stride << 2;
    size_t src_stride_t = src_stride << 2;
    asm volatile(
        // inc in deep
        // test sets the same flags as and but leaves the input-only operand untouched
        "test $0x1, %[inc_flag]\\n"
        "je 0f\\n"
        @for col in range(0, min((col_block >> 3), 3)):
            @for row in range(0, row_block):
                @if col == 0:
                    "vmovups @{row * 32}(%[dst]), %%ymm@{row + col * row_block}\\n"
                @else:
                    "vmovups @{row * 32}(%[dst], %[dst_stride], @{col}), %%ymm@{row + col * row_block}\\n"
        @if col_block == 32:
            @for row in range(0, row_block):
                "vmovups @{row * 32}(%[dst_4]), %%ymm@{row + 3 * row_block}\\n"
        "jmp 2f\\n"
        "0:\\n"
        "cmpq $0, %[bias]\\n"
        "je 1f\\n"
        @for col in range(0, col_block >> 3):
            @for row in range(0, row_block):
                "vmovaps @{col * 32}(%[bias]), %%ymm@{row + col * row_block}\\n"
        "jmp 2f\\n"
        "1:\\n"
        @for col in range(0, col_block >> 3):
            @for row in range(0, row_block):
                "vxorps %%ymm@{row + col * row_block}, %%ymm@{row + col * row_block}, %%ymm@{row + col * row_block}\\n"
        "2:\\n"
        :
        @in_operands = ["[dst] \"r\"(dst)", "[bias] \"r\"(bias)", "[dst_stride] \"r\"(dst_stride_t)", "[inc_flag] \"r\"(inc_flag)"]
        @if col_block == 32:
            @in_operands.append("[dst_4] \"r\"(dst_4)")
        @print("        : " + ", ".join(in_operands), file=OUT_STREAM)
        @print("        : \"memory\", " + ", ".join(["\"%ymm" + str(i) + "\"" for i in range(0, row_block * col_block >> 3)]), file=OUT_STREAM)
    );
    // NOTE(review): the accumulators are expected to stay live in the ymm registers between
    // the two asm statements. The compiler is not told about that dependency.
    asm volatile(
        "0:\\n"
        @for i in range(0, 8):
            // block @{i}
            // AT&T operand order puts the destination last, so the accumulator is the last operand
            @if col_block == 32:
                @for row in range(0, row_block):
                    "vbroadcastss @{row * 32 + i * 4}(%[src]), %%ymm@{15 - row}\\n"
                @for col in range(0, col_block >> 3):
                    "vmovaps @{col * 32 + i * col_block * 4}(%[weight]), %%ymm@{15 - row_block}\\n"
                    @for row in range(0, row_block):
                        "vfmadd231ps %%ymm@{15 - row_block}, %%ymm@{15 - row}, %%ymm@{row + col * row_block}\\n"
            @elif col_block == 24:
                @for col in range(0, col_block >> 3):
                    "vmovaps @{col * 32 + i * col_block * 4}(%[weight]), %%ymm@{15 - col}\\n"
                @for row in range(0, row_block):
                    "vbroadcastss @{row * 32 + i * 4}(%[src]), %%ymm@{15 - (col_block >> 3)}\\n"
                    @for col in range(0, col_block >> 3):
                        "vfmadd231ps %%ymm@{15 - col}, %%ymm@{15 - (col_block >> 3)}, %%ymm@{row + col * row_block}\\n"
            @elif col_block == 16:
                @for col in range(0, col_block >> 3):
                    "vmovaps @{col * 32 + i * col_block * 4}(%[weight]), %%ymm@{15 - col}\\n"
                @for row in range(0, row_block >> 1):
                    "vbroadcastss @{row * 64 + i * 4}(%[src]), %%ymm@{15 - (col_block >> 3)}\\n"
                    "vbroadcastss @{row * 64 + 32 + i * 4}(%[src]), %%ymm@{15 - (col_block >> 3) - 1}\\n"
                    @for col in range(0, col_block >> 3):
                        @for j in range(0, 2):
                            "vfmadd231ps %%ymm@{15 - col}, %%ymm@{15 - (col_block >> 3) - j}, %%ymm@{row * 2 + j + col * row_block}\\n"
                @for row in range(row_block >> 1 << 1, row_block):
                    "vbroadcastss @{row * 32 + i * 4}(%[src]), %%ymm@{15 - (col_block >> 3)}\\n"
                    @for col in range(0, col_block >> 3):
                        "vfmadd231ps %%ymm@{15 - col}, %%ymm@{15 - (col_block >> 3)}, %%ymm@{row + col * row_block}\\n"
            @else:
                @for col in range(0, col_block >> 3):
                    "vmovaps @{col * 32 + i * col_block * 4}(%[weight]), %%ymm@{15 - col}\\n"
                @split_num = 3
                @for row in range(0, int(row_block / split_num)):
                    @for j in range(0, split_num):
                        "vbroadcastss @{row * 96 + j * 32 + i * 4}(%[src]), %%ymm@{15 - (col_block >> 3) - j}\\n"
                    @for col in range(0, col_block >> 3):
                        @for j in range(0, split_num):
                            "vfmadd231ps %%ymm@{15 - col}, %%ymm@{15 - (col_block >> 3) - j}, %%ymm@{row * split_num + j + col * row_block}\\n"
                @for row in range(int(row_block / split_num) * split_num, row_block):
                    "vbroadcastss @{row * 32 + i * 4}(%[src]), %%ymm@{15 - (col_block >> 3) - (row - int(row_block / split_num) * split_num)}\\n"
                @for col in range(0, col_block >> 3):
                    @for row in range(int(row_block / split_num) * split_num, row_block):
                        "vfmadd231ps %%ymm@{15 - col}, %%ymm@{15 - (col_block >> 3) - (row - int(row_block / split_num) * split_num)}, %%ymm@{row + col * row_block}\\n"
        // advance the pointers first: add rewrites the flags, so dec has to be the last
        // flag-setting instruction before jg
        "add $@{col_block * 4 * 8}, %[weight]\\n"
        "add %[src_stride], %[src]\\n"
        "dec %[deep]\\n"
        "jg 0b\\n"
        "movq %[inc_flag], %%rax\\n"
        "and $0x2, %%eax\\n"
        "je 3f\\n"
        "movq %[act_flag], %%rax\\n"
        "and $0x3, %%eax\\n"
        "je 3f\\n"
        // relu
        "vxorps %%ymm15, %%ymm15, %%ymm15\\n"
        @for col in range(0, col_block >> 3):
            @for row in range(0, row_block):
                "vmaxps %%ymm@{row + col * row_block}, %%ymm15, %%ymm@{row + col * row_block}\\n"
        "and $0x1, %%eax\\n"
        "je 3f\\n"
        // relu6
        // 0x40C00000 is 6.0f. ymm15 is all zero here, so vpermps copies lane 0 to every lane
        "mov $0x40C00000, %%eax\\n"
        "vmovd %%eax, %%xmm14\\n"
        "vpermps %%ymm14, %%ymm15, %%ymm14\\n"
        @for col in range(0, col_block >> 3):
            @for row in range(0, row_block):
                "vminps %%ymm@{row + col * row_block}, %%ymm14, %%ymm@{row + col * row_block}\\n"
        "3:\\n"
        @for col in range(0, min((col_block >> 3), 3)):
            @for row in range(0, row_block):
                @if col == 0:
                    "vmovups %%ymm@{row + col * row_block}, @{row * 32}(%[dst])\\n"
                @else:
                    "vmovups %%ymm@{row + col * row_block}, @{row * 32}(%[dst], %[dst_stride], @{col})\\n"
        @if col_block == 32:
            @for row in range(0, row_block):
                "vmovups %%ymm@{row + 3 * row_block}, @{row * 32}(%[dst_4])\\n"
        @print("        : [src] \"+r\"(src), [weight] \"+r\"(weight), [deep] \"+r\"(deep_t)", file=OUT_STREAM)
        @in_operands = ["[src_stride] \"r\"(src_stride_t)", "[inc_flag] \"r\"(inc_flag)", "[act_flag] \"r\"(act_flag)", "[dst] \"r\"(dst)", "[dst_stride] \"r\"(dst_stride_t)"]
        @if col_block == 32:
            @in_operands.append("[dst_4] \"r\"(dst_4)")
        @print("        : " + ", ".join(in_operands), file=OUT_STREAM)
        @print("        : \"%rax\", \"memory\", " + ", ".join(["\"%ymm" + str(i) + "\"" for i in range(0, 16)]), file=OUT_STREAM)
    );
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/lzkcode/code-generator.git
git@gitee.com:lzkcode/code-generator.git
lzkcode
code-generator
code-generator
code-ge

搜索帮助