1 Star 3 Fork 3

Hauk Zero/Transformer Demo

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
utils.py 2.42 KB
一键复制 编辑 原始数据 按行查看 历史
Hauk Zero 提交于 8个月前 . update
import json
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
def load_config(filename='model/config.json'):
    """Read a JSON configuration file and return the parsed object.

    Args:
        filename: Path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # JSON text is UTF-8 by specification; naming the encoding keeps the
    # result independent of the platform's locale default.
    with open(filename, 'r', encoding='utf-8') as f:
        return json.load(f)
def load_vocab(filename='model/vocab.json'):
    """Read a JSON vocabulary file and return the parsed object.

    Args:
        filename: Path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Vocabulary keys are tokens and may be non-ASCII; decode explicitly as
    # UTF-8 instead of relying on the platform's locale default.
    with open(filename, 'r', encoding='utf-8') as f:
        return json.load(f)
def sen2vec(sen, vocab, mex_len):
    """Encode sentences as a tensor of token ids, right-padded with '<pad>'.

    Args:
        sen: Iterable of sentence strings whose tokens are separated by
            whitespace.
        vocab: Mapping from token string to integer id. It must contain
            '<pad>' whenever a sentence is shorter than ``mex_len``.
        mex_len: Minimum number of ids per row; shorter sentences are padded
            up to this length.

    Returns:
        A ``torch.LongTensor`` with one row per sentence.

    Raises:
        KeyError: If a token (or '<pad>', when padding is needed) is missing
            from ``vocab``.

    Note:
        Sentences with more than ``mex_len`` tokens are not truncated, so
        rows of unequal length cannot be converted to a tensor.
    """
    vec = []
    for sentence in sen:
        # split() without an argument collapses runs of whitespace and
        # ignores leading/trailing blanks, so no empty-string "tokens" are
        # looked up in the vocabulary.
        ids = [vocab[token] for token in sentence.split()]
        missing = mex_len - len(ids)
        if missing > 0:
            ids.extend([vocab['<pad>']] * missing)
        vec.append(ids)
    return torch.LongTensor(vec)
def vec2sen(x, vocab):
    """Decode rows of token ids back into space-joined sentences.

    Args:
        x: Iterable of rows, each an iterable of integer ids (plain ints or
            0-dim integer tensors).
        vocab: Mapping from token string to integer id.

    Returns:
        A list with one string per row. '<pad>' tokens are dropped, and
        decoding of a row stops after the first '<end>' token (which is kept).

    Raises:
        KeyError: If a row contains an id that is not a value of ``vocab``.
    """
    # Invert the mapping once, by id rather than by key position, so decoding
    # is the inverse of the vocab[token] lookup used for encoding and no key
    # list is rebuilt for every token.
    id2token = {int(idx): token for token, idx in vocab.items()}
    sen = []
    for line in x:
        words = []
        for idx in line:
            token = id2token[int(idx)]
            if token != '<pad>':
                words.append(token)
            if token == '<end>':
                break
        sen.append(' '.join(words))
    return sen
def get_pad_mask(seq_q, seq_k, pad_token=0):
    """Build a padding mask that flags padded key positions.

    Args:
        seq_q: 2-D tensor of query token ids, shape (batch_size, len_q).
        seq_k: Tensor of key token ids whose second dimension is len_k.
        pad_token: Id that marks padding.

    Returns:
        A uint8 tensor of shape (batch_size, len_q, len_k) holding 1 where
        the key token equals ``pad_token`` and 0 elsewhere.
    """
    batch_size, len_q = seq_q.shape
    len_k = seq_k.shape[1]
    # Start from shape (batch_size, 1, len_k). The mask has to come from the
    # keys: building it from seq_q only expands when len_q == len_k and would
    # flag the wrong positions in cross-attention.
    mask = (seq_k == pad_token).unsqueeze(1)
    # Broadcast over the query axis to (batch_size, len_q, len_k).
    return mask.expand(batch_size, len_q, len_k).byte()
def get_attn_mask(seq):
    """Build a mask that is 1 strictly above the main diagonal.

    Args:
        seq: Object exposing ``shape``; only ``shape[0]`` and ``shape[1]``
            are read.

    Returns:
        A uint8 tensor of shape (shape[0], shape[1], shape[1]) whose last two
        dimensions form a strictly upper-triangular matrix of ones.
    """
    batch_size, n_seq = seq.shape[0], seq.shape[1]
    ones = np.ones([batch_size, n_seq, n_seq])
    # k=1 keeps only entries above the diagonal; triu acts on the last two
    # axes, so every batch item gets the same (n_seq, n_seq) pattern.
    upper = np.triu(ones, k=1)
    return torch.from_numpy(upper).byte()
def bool_mask(pad_mask, attn_mask=None):
    """Combine a padding mask with an optional second mask.

    Args:
        pad_mask: Tensor mask.
        attn_mask: Optional tensor that can be added to ``pad_mask``; treated
            as all zeros when omitted.

    Returns:
        The element-wise result of ``(pad_mask + attn_mask) > 0``.
    """
    if attn_mask is None:
        attn_mask = torch.zeros_like(pad_mask)
    # A position is masked when either input marks it.
    return (pad_mask + attn_mask) > 0
def attention(Q, K, V, mask):
    """Scaled dot-product attention with masking.

    Args:
        Q: Query tensor; ``Q.shape[2]`` is taken as the query length.
        K: Key tensor; ``K.shape[2]`` is the key length and ``K.shape[-1]``
            the scaling dimension.
        V: Value tensor, multiplied by the attention weights.
        mask: 4-D mask; positions that are set are excluded from attention.
            It is cropped to the query/key lengths before use.

    Returns:
        ``softmax(Q @ K^T / sqrt(d_k)) @ V`` with masked scores replaced
        by -1e10 before the softmax.

    Note:
        The indexing suggests a (batch, heads, length, d_k) layout for Q and
        K -- TODO confirm against the callers.
    """
    len_q, len_k = Q.shape[2], K.shape[2]
    scale = np.sqrt(K.shape[-1])
    scores = (Q @ K.transpose(-1, -2)) / scale
    cropped = mask[:, :, :len_q, :len_k]
    # Use -1e10 rather than -inf so the softmax cannot produce NaN.
    scores.masked_fill_(cropped, -1e10)
    weights = torch.softmax(scores, dim=-1)
    return weights @ V
def draw_loss(losses):
    """Plot loss values against their index and show the figure.

    Args:
        losses: Sized sequence of loss values; the x axis runs over
            0 .. len(losses) - 1.
    """
    steps = np.arange(len(losses))
    plt.plot(steps, losses)
    plt.title('Loss Curve')
    plt.xlabel('Iter')
    plt.ylabel('Loss')
    plt.show()
if __name__ == '__main__':
    # Round-trip demo: encode two sample sentences, decode them again and
    # print whether the decoded text equals the input.
    sentences = ['<sta> I very very very love you <end>', 'I very love you']
    vocab = load_vocab()

    vec = sen2vec(sentences, vocab, 16)
    print(vec)

    sen = vec2sen(vec, vocab)
    print(sen == sentences)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/haukzero/transformer-demo.git
git@gitee.com:haukzero/transformer-demo.git
haukzero
transformer-demo
Transformer Demo
moe

搜索帮助

371d5123 14472233 46e8bd33 14472233