cassuto/LLAMA-FPGA-Inference
ils.py (1.21 KB)
'''
Integer Lightweight Softmax implementation using python.
'''
import math
import numpy as np
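The functions below all take a quantized input `X_q` together with its scale `S`. As a reference for how such a pair might be produced, here is a minimal sketch of plain asymmetric uniform quantization; the helper name `quantize` and the sample values are illustrative, not part of this repository.

```python
import numpy as np

def quantize(x, b):
    """Hypothetical helper: map real values x onto b-bit unsigned integers.

    Returns (X_q, S) in the form the functions below expect.
    """
    S = (x.max() - x.min()) / (2**b - 1)  # scale: real units per integer step
    X_q = np.round((x - x.min()) / S).astype(np.int64)
    return X_q, S

x = np.array([-3.0, -1.0, 0.5, 2.0])
X_q, S = quantize(x, b=8)
# the minimum maps to 0 and the maximum to 2**b - 1 = 255
```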
def i_exp(X_q, S, z, n):
    '''
    Integer-only Exponential function (i-exp). [Algorithm 2 in paper]
    X_q = quantized input
    S = scale
    z,n = normalization parameters
    '''
    # coefficients of the second-order polynomial fit to exp(x) on (-ln2, 0]
    a = 0.3581
    b = 1.353
    c = 0.344
    # integer-domain constants (floor keeps them integers)
    q_b = int(np.floor(b / S))
    q_c = int(np.floor(c / (a * S**2)))
    S_l = a * (S**2)
    X_q_l = np.asarray((X_q + q_b)**2 + q_c).astype(np.int64)
    # exp(x) = 2**(-z) * exp(r): fold 2**(-z) into the 2**n fixed-point scale
    X_out = X_q_l << (n - np.asarray(z)).astype(np.int64)
    S_out = S_l / (2**n)
    return X_out, S_out
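The shift in `i_exp` relies on the identity exp(x) = 2**(-z) * exp(r) for x = -z*ln(2) + r, so the polynomial only has to approximate exp on the short interval (-ln2, 0]. A float sketch of that decomposition (my reading of the algorithm, independent of the quantized code above, reusing the same polynomial coefficients):

```python
import numpy as np

def exp_by_shift(x):
    """Float sketch: approximate exp(x) for x <= 0 via range reduction."""
    z = np.floor(x / -np.log(2))             # non-negative integer multiplier
    r = x + z * np.log(2)                    # remainder in (-ln2, 0]
    poly = 0.3581 * (r + 1.353)**2 + 0.344   # 2nd-order fit to exp on (-ln2, 0]
    return poly / 2.0**z                     # undo the ln2 factoring with a shift

x = np.linspace(-8.0, 0.0, 9)
err = np.max(np.abs(exp_by_shift(x) - np.exp(x)))  # approximation error stays small
```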
def norm_helper(X_q, S, b):
    '''
    Normalization function [Algorithm 1, function-I in paper]
    X_q = quantized input
    S = scale
    b = bit-width
    '''
    # shift by the largest representable value so every entry is <= 0
    X_q = X_q - (2**b - 1)
    n = 32
    # quantized ln(2) (a negative integer)
    X_q_ln = np.floor(-np.log(2) / S)
    # clip so the later shift amount n - z stays non-negative
    X_q = np.maximum(X_q, n * X_q_ln)
    # decompose X_q = z * X_q_ln + r with the remainder r in (X_q_ln, 0]
    z = np.floor(X_q / X_q_ln).astype(np.int64)
    X_n = X_q - z * X_q_ln
    return X_n, z, n
def ils(X_q, S, b):
    '''
    Integer Lightweight Softmax function. [Algorithm 1]
    X_q = quantized input
    S = scale
    b = bit-width
    '''
    X_q_temp, z, n = norm_helper(X_q, S, b)
    X_exp, S_iexp = i_exp(X_q_temp, S, z, n)
    ## TODO: ensure ints are 32 bit (< 2**31) or clip them
    ## TODO: try plotting function and compare to regular softmax
    # softmax via the log2 domain: log2(p_i) = log2(exp_i) - log2(sum(exp))
    X_out = 2.0 ** (np.log2(X_exp) - np.log2(np.sum(X_exp)))
    return X_out
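For the comparison suggested in the TODO above, a numerically stable float softmax makes a natural reference; this sketch (the name `softmax_ref` and the sample logits are mine, not from the repository) is what `ils` output could be plotted against.

```python
import numpy as np

def softmax_ref(x):
    """Reference float softmax for comparing against ils()."""
    e = np.exp(x - np.max(x))  # subtract the max for numerical stability
    return e / np.sum(e)

p = softmax_ref(np.array([1.0, 2.0, 3.0]))
# probabilities sum to 1 and preserve the ordering of the inputs
```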