1 Star 1 Fork 1

oscarlin/imagecaption

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
newtrain.py 2.73 KB
一键复制 编辑 原始数据 按行查看 历史
SAKURA-CAT 提交于 2021-05-10 17:59 . update
from dataloader import tr_dataset, te_dataset
from modeling import Model, simi_loss, metric
import torch
from tqdm import tqdm
import random
import os
from torch.utils.tensorboard import SummaryWriter
import numpy as np
# TensorBoard writer; all scalars logged by this script go under ./log/train.
writer = SummaryWriter('./log/train')
# Presumably silences the HuggingFace tokenizers fork/parallelism warning that
# appears when tokenizers are used together with DataLoader worker processes
# -- TODO confirm that the model's text encoder uses `tokenizers`.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Upper bound on the number of training epochs (early stopping may end sooner).
Epoch = 100
# Batch size of the training DataLoader.
batchsize = 128
# Initial learning rate handed to the SGD optimizer.
lr = 1e-3
def collate_fn(batch):
    """Transpose a batch of per-sample tuples into per-field tuples.

    The default DataLoader collation would try to stack every field into a
    tensor; this version keeps each field as a plain tuple so that
    variable-length / non-tensor fields survive untouched.

    Args:
        batch: A sequence of samples, each sample being an iterable of fields
            (e.g. ``(image, captions)``).

    Returns:
        A tuple with one entry per field, each entry being a tuple holding
        that field for every sample in the batch. An empty batch yields ``()``.
    """
    return tuple(zip(*batch))
# Training / evaluation loaders. Both use collate_fn, so each batch arrives as
# a tuple of per-field tuples rather than stacked tensors.
# NOTE(review): no `shuffle=True` is passed for the training loader, so batches
# are drawn in dataset order every epoch -- confirm this is intended.
tr_dl = torch.utils.data.DataLoader(tr_dataset, batch_size = batchsize,num_workers = 3,collate_fn=collate_fn)
te_dl = torch.utils.data.DataLoader(te_dataset, batch_size = 200 ,num_workers = 3,collate_fn=collate_fn)
# Model lives on the GPU; a CUDA device is required to run this script.
model = Model().cuda()
# Loss object called as loss(img_vec, cap_vec) in the training loop.
loss = simi_loss()
# Plain SGD over every model parameter.
opt = torch.optim.SGD(model.parameters(), lr=lr)
# Cosine annealing from `lr` down to 0 with a half-period of 10 scheduler steps.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases -- confirm
# against the installed version.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10, eta_min=0, last_epoch=-1, verbose=True)
# Separate metric accumulators for the training and evaluation passes.
tr_m = metric()
te_m = metric()
# Running count of training samples seen; used as the TensorBoard x-axis.
step = 0
# Best evaluation score so far; starts far below any score expected from te_m().
best_per = -1000
# Number of consecutive epochs without an improvement (early-stopping counter).
not_improve = 0
# Main loop: one training pass, one evaluation pass, checkpointing and early
# stopping per epoch.
for epoch in range(Epoch):
    # ---- training pass ----
    # Only the vision sub-module is put back into train mode; after the first
    # evaluation pass the rest of the model stays in eval mode.
    # NOTE(review): confirm that leaving the other sub-modules in eval mode
    # from the second epoch onwards is intended.
    model.vision_model.train()
    # Fix: `reset` has to be *called*. The bare attribute access `tr_m.reset`
    # was a no-op, so the accumulated metric was never cleared between epochs.
    # (Assumes `metric.reset` is a regular method.)
    tr_m.reset()
    for img, caps in tqdm(tr_dl):
        # collate_fn delivers a tuple of per-sample image tensors; give each a
        # leading batch axis and stack them into a single batch tensor.
        img = torch.vstack([im[None, :, :, :] for im in img])
        img = img.cuda()
        # Each sample carries several captions; pick one at random as target.
        tar = [random.choice(cap) for cap in caps]
        assert img.shape[0] == len(tar)
        img_vec, cap_vec = model([img, tar])
        l = loss(img_vec, cap_vec)
        opt.zero_grad()
        l.backward()
        opt.step()
        acc = tr_m.compute(img_vec.cpu().detach(), cap_vec.cpu().detach())
        tr_m.update(acc)
        # x-axis of the per-batch curves is the number of samples seen so far.
        step += len(img)
        writer.add_scalar('loss/trloss', l.cpu().detach(), step)
        writer.add_scalar('simi/tr_simi', acc, step)

    # NOTE(review): the scheduler is stepped once per epoch here -- confirm
    # this matches the intended schedule (T_max=10 scheduler steps).
    scheduler.step()
    current_lr = scheduler.get_last_lr()[0]
    # Logs the last training-batch loss against the (scaled) learning rate.
    writer.add_scalar('lr/lr-loss', l.cpu().detach(), current_lr * 1e+6)
    torch.cuda.empty_cache()

    # ---- evaluation pass ----
    with torch.no_grad():
        model.eval()
        # Fix: call reset() so the evaluation score used for checkpointing and
        # early stopping reflects this epoch only.
        te_m.reset()
        for img, caps in tqdm(te_dl):
            img = torch.vstack([im[None, :, :, :] for im in img])
            img = img.cuda()
            tar = [random.choice(cap) for cap in caps]
            img_vec, cap_vec = model([img, tar])
            l = loss(img_vec, cap_vec)
            acc = te_m.compute(img_vec.cpu().detach(), cap_vec.cpu().detach())
            te_m.update(acc)

    # `l` is the loss of the last evaluation batch, not an epoch average.
    writer.add_scalar('loss/teloss', l.cpu().detach(), epoch)
    writer.add_scalar('simi/te_simi', te_m(), epoch)
    print("epoch:{}; tr_simi:{:.4f}; te_simi:{:.4f};".format(epoch, tr_m(), te_m()))

    # ---- checkpointing / early stopping ----
    if te_m() > best_per:
        not_improve = 0
        best_per = te_m()
        torch.save(model.state_dict(), 'best_parameter.pkl')
        # NOTE(review): the file is named "current_opt.pkl" but stores the
        # *scheduler* state, not the optimizer state -- confirm which was meant.
        torch.save(scheduler.state_dict(), "current_opt.pkl")
        print("best model saved")
    else:
        not_improve += 1
        print("Not Improve {}".format(not_improve))
    # Stop after more than five consecutive epochs without improvement.
    if not_improve > 5:
        break
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/oscarlin/imagecaption.git
git@gitee.com:oscarlin/imagecaption.git
oscarlin
imagecaption
imagecaption
master

搜索帮助

23e8dbc6 1850385 7e0993f3 1850385