# NOTE(review): web-scrape artifact removed (page banner: "code pull complete,
# the page will refresh automatically") — it was never part of the source file.
import torch as t
import numpy as np
from torch.utils.data import DataLoader
from torch import optim
from torch import nn
from model import *
from torchnet import meter
import tqdm
from config import *
from test import *
def train():
    """Train PoetryModel on the Tang-poetry dataset.

    Loads ``tang.npz`` (index-encoded poems plus the two vocabulary maps),
    trains for ``Config.epoch`` epochs with Adam + cross-entropy, periodically
    logs the running loss and sample generations to ``result.txt``, and saves
    a checkpoint ``<model_prefix>_<epoch>.pth`` after every epoch.

    Side effects: mutates ``Config.device``, writes ``result.txt`` and one
    ``.pth`` file per epoch. Returns ``None``.
    """
    # Pick the compute device once and record it on the shared Config.
    Config.device = t.device("cuda") if Config.use_gpu else t.device("cpu")
    device = Config.device

    # Load the dataset. The vocab maps are stored as 0-d object arrays,
    # hence the .item() calls to unwrap the underlying dicts.
    datas = np.load("tang.npz", allow_pickle=True)
    data = datas['data']
    ix2word = datas['ix2word'].item()
    word2ix = datas['word2ix'].item()
    data = t.from_numpy(data)
    dataloader = DataLoader(data,
                            batch_size=Config.batch_size,
                            shuffle=True,
                            num_workers=2)

    # Model, optimizer and loss. (The optimizer was previously named
    # "Configimizer" — renamed for clarity; it is purely local.)
    model = PoetryModel(len(word2ix),
                        embedding_dim=Config.embedding_dim,
                        hidden_dim=Config.hidden_dim)
    optimizer = optim.Adam(model.parameters(), lr=Config.lr)
    criterion = nn.CrossEntropyLoss()
    if Config.model_path:
        # Resume from a checkpoint; load on CPU first, then move below.
        model.load_state_dict(t.load(Config.model_path, map_location='cpu'))
    model.to(device)

    loss_meter = meter.AverageValueMeter()

    # BUG FIX: the result file was opened with a bare open() and never
    # closed; a context manager guarantees flush/close even on exceptions.
    with open('result.txt', 'w') as f:
        for epoch in range(Config.epoch):
            loss_meter.reset()
            for li, data_ in tqdm.tqdm(enumerate(dataloader)):
                # Reshape to (seq_len, batch) — the layout the model expects —
                # and move the batch to the compute device.
                data_ = data_.long().transpose(1, 0).contiguous()
                data_ = data_.to(device)
                optimizer.zero_grad()
                # Teacher forcing: tokens [0..n-2] predict tokens [1..n-1].
                input_, target = data_[:-1, :], data_[1:, :]
                output, _ = model(input_)
                # output is presumably (seq*batch, vocab) — TODO confirm
                # against PoetryModel; target is flattened to match.
                # (Removed a leftover per-batch debug print of target.shape.)
                loss = criterion(output, target.view(-1))
                loss.backward()
                optimizer.step()
                loss_meter.add(loss.item())

                # Periodic progress report: running mean loss plus a sample
                # poem for each seed character.
                if (1 + li) % Config.plot_every == 0:
                    print("训练损失为%s" % (str(loss_meter.mean)))
                    f.write("训练损失为%s" % (str(loss_meter.mean)))
                    for word in list(u"春江花朝秋月夜"):
                        gen_poetry = ''.join(
                            generate(model, word, ix2word, word2ix))
                        print(gen_poetry)
                        f.write(gen_poetry)
                        f.write("\n\n\n")
                        f.flush()
            # Checkpoint after every epoch.
            t.save(model.state_dict(), '%s_%s.pth' % (Config.model_prefix, epoch))
# Script entry point: run training only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    train()
# NOTE(review): web-scrape artifact removed (hosting-site moderation footer
# about content review/appeal) — it was never part of the source file.