1 Star 0 Fork 0

gitee-hc/LSTM

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
main.py 2.70 KB
一键复制 编辑 原始数据 按行查看 历史
import torch as t
import numpy as np
from torch.utils.data import DataLoader
from torch import optim
from torch import nn
from model import *
from torchnet import meter
import tqdm
from config import *
from test import *
def train():
if Config.use_gpu:
Config.device = t.device("cuda")
else:
Config.device = t.device("cpu")
device = Config.device
# 获取数据
datas = np.load("tang.npz", allow_pickle=True)
data = datas['data']
ix2word = datas['ix2word'].item()
word2ix = datas['word2ix'].item()
data = t.from_numpy(data)
dataloader = DataLoader(data,
batch_size=Config.batch_size,
shuffle=True,
num_workers=2)
# 定义模型
model = PoetryModel(len(word2ix),
embedding_dim=Config.embedding_dim,
hidden_dim=Config.hidden_dim)
Configimizer = optim.Adam(model.parameters(), lr=Config.lr)
criterion = nn.CrossEntropyLoss()
if Config.model_path:
model.load_state_dict(t.load(Config.model_path, map_location='cpu'))
# 转移到相应计算设备上
model.to(device)
loss_meter = meter.AverageValueMeter()
# 进行训练
f = open('result.txt', 'w')
for epoch in range(Config.epoch):
loss_meter.reset()
for li, data_ in tqdm.tqdm(enumerate(dataloader)):
# print(data_.shape)
data_ = data_.long().transpose(1, 0).contiguous()
# 注意这里,也转移到了计算设备上
data_ = data_.to(device)
Configimizer.zero_grad()
# n个句子,前n-1句作为输入,后n-1句作为输出,二者一一对应
input_, target = data_[:-1, :], data_[1:, :]
output, _ = model(input_)
# print("Here",output.shape)
# 这里为什么view(-1)
print(target.shape, target.view(-1).shape)
loss = criterion(output, target.view(-1))
loss.backward()
Configimizer.step()
loss_meter.add(loss.item())
# 进行可视化
if (1+li) % Config.plot_every == 0:
print("训练损失为%s" % (str(loss_meter.mean)))
f.write("训练损失为%s" % (str(loss_meter.mean)))
for word in list(u"春江花朝秋月夜"):
gen_poetry = ''.join(
generate(model, word, ix2word, word2ix))
print(gen_poetry)
f.write(gen_poetry)
f.write("\n\n\n")
f.flush()
t.save(model.state_dict(), '%s_%s.pth' % (Config.model_prefix, epoch))
if __name__ == '__main__':
train()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/gitee-hc/lstm.git
git@gitee.com:gitee-hc/lstm.git
gitee-hc
lstm
LSTM
master

搜索帮助