代码拉取完成,页面将自动刷新
from dataloader import tr_dataset, te_dataset
from modeling import Model, simi_loss, metric
import torch
from tqdm import tqdm
import random
import os
from torch.utils.tensorboard import SummaryWriter
import numpy as np
# Presumably consumed by a HuggingFace `tokenizers` instance used in the
# dataloader/model (not visible here): turning its internal parallelism off
# avoids the fork warning/deadlock once DataLoader worker processes start.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# TensorBoard event files for this script are written under ./log/train.
writer = SummaryWriter('./log/train')

# Hyper-parameters: maximum number of epochs (early stopping may end the run
# sooner), training batch size, and the initial SGD learning rate.
Epoch, batchsize, lr = 100, 128, 0.001
def collate_fn(batch):
    """Regroup a list of samples field-by-field for the DataLoader.

    ``batch`` is a sequence of per-sample tuples, e.g.
    ``[(img0, caps0), (img1, caps1), ...]``.  The result is a tuple with one
    entry per field, e.g. ``((img0, img1, ...), (caps0, caps1, ...))``.
    Nothing is stacked or padded, so fields of differing lengths (such as the
    per-image caption sequences) pass through untouched.  An empty batch
    yields ``()``.
    """
    per_field = zip(*batch)  # i-th output groups the i-th element of every sample
    return tuple(per_field)
# Training loader. Reshuffle every epoch so SGD does not see the batches in
# the same fixed order each time. DataLoader raises an error for
# shuffle=True on an IterableDataset, so shuffling is only requested for
# map-style datasets.
tr_dl = torch.utils.data.DataLoader(
    tr_dataset,
    batch_size=batchsize,
    shuffle=not isinstance(tr_dataset, torch.utils.data.IterableDataset),
    num_workers=3,
    collate_fn=collate_fn,
)
# Evaluation loader: fixed order and larger batches. The evaluation loop runs
# under torch.no_grad(), so no activations are kept for backward.
te_dl = torch.utils.data.DataLoader(
    te_dataset,
    batch_size=200,
    num_workers=3,
    collate_fn=collate_fn,
)

model = Model().cuda()
loss = simi_loss()
opt = torch.optim.SGD(model.parameters(), lr=lr)
# NOTE(review): T_max=10 is much smaller than Epoch=100. CosineAnnealingLR
# keeps following the cosine past T_max, so the lr reaches eta_min after 10
# scheduler steps and then rises again; confirm that is intended.
# `verbose` is deprecated in recent PyTorch releases (get_last_lr() replaces it).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    opt, T_max=10, eta_min=0, last_epoch=-1, verbose=True
)

# Running similarity metrics for the train and test splits.
tr_m = metric()
te_m = metric()

step = 0  # training samples seen so far; x-axis of the per-batch scalars
# Start from -inf so the first evaluation always counts as an improvement,
# whatever the metric's range (the old -1000 sentinel assumed scores > -1000).
best_per = float("-inf")
not_improve = 0  # consecutive epochs without a new best test score
def _prepare_batch(img, caps):
    """Stack a collated batch onto the GPU and draw one caption per image.

    ``img`` and ``caps`` are the per-field tuples produced by ``collate_fn``:
    ``img`` holds equally-shaped image tensors, which are stacked along a new
    leading batch dimension. ``caps`` holds one sequence of candidate
    captions per image; one caption is picked uniformly at random from each.
    Returns ``(images_on_cuda, list_of_captions)``.
    """
    images = torch.stack(img).cuda()
    targets = [random.choice(cap) for cap in caps]
    return images, targets


for epoch in range(Epoch):
    # ---------------------------------------------------------------- train
    # NOTE(review): only the vision sub-model is switched back to train mode,
    # while `model.eval()` below switches the whole model to eval mode. Any
    # other sub-module therefore stays in eval mode from the second epoch on.
    # Fine if those parts are frozen; confirm.
    model.vision_model.train()
    # `reset` must be called: the previous bare `tr_m.reset` was a no-op
    # expression, so the metric was never cleared between epochs.
    tr_m.reset()
    for img, caps in tqdm(tr_dl):
        img, tar = _prepare_batch(img, caps)
        assert img.shape[0] == len(tar)
        img_vec, cap_vec = model([img, tar])
        batch_loss = loss(img_vec, cap_vec)
        opt.zero_grad()
        batch_loss.backward()
        opt.step()
        # Detach before copying to host so the copy is not tracked by autograd.
        acc = tr_m.compute(img_vec.detach().cpu(), cap_vec.detach().cpu())
        tr_m.update(acc)
        step += len(img)
        writer.add_scalar('loss/trloss', batch_loss.item(), step)
        writer.add_scalar('simi/tr_simi', acc, step)

    # Advance the cosine schedule once per epoch.
    scheduler.step()
    current_lr = scheduler.get_last_lr()[0]
    # Plot the last training batch's loss against the new learning rate. The
    # lr is used as the step axis, scaled by 1e6 so small rates stay
    # distinguishable.
    writer.add_scalar('lr/lr-loss', batch_loss.item(), current_lr * 1e+6)
    torch.cuda.empty_cache()

    # ------------------------------------------------------------- evaluate
    with torch.no_grad():
        model.eval()
        te_m.reset()  # was a bare `te_m.reset` (never called)
        te_loss_total = 0.0
        te_batches = 0
        for img, caps in tqdm(te_dl):
            # NOTE(review): captions are sampled at random here too, so the
            # test score varies between epochs even for identical weights.
            img, tar = _prepare_batch(img, caps)
            img_vec, cap_vec = model([img, tar])
            batch_loss = loss(img_vec, cap_vec)
            acc = te_m.compute(img_vec.detach().cpu(), cap_vec.detach().cpu())
            te_m.update(acc)
            te_loss_total += batch_loss.item()
            te_batches += 1
    # Log the mean of the per-batch test losses. Previously only the last
    # test batch's loss was written.
    writer.add_scalar('loss/teloss', te_loss_total / max(te_batches, 1), epoch)

    te_score = te_m()  # evaluate once so the logged/compared/saved values agree
    writer.add_scalar('simi/te_simi', te_score, epoch)
    print("epoch:{}; tr_simi:{:.4f}; te_simi:{:.4f};".format(epoch, tr_m(), te_score))

    # ------------------------------------------------- checkpoint / early stop
    if te_score > best_per:
        not_improve = 0
        best_per = te_score
        torch.save(model.state_dict(), 'best_parameter.pkl')
        # NOTE(review): despite its name, this file stores the *scheduler*
        # state, not the optimizer's (opt.state_dict()); confirm before using
        # it to resume training.
        torch.save(scheduler.state_dict(), "current_opt.pkl")
        print("best model saved")
    else:
        not_improve += 1
        print("Not Improve {}".format(not_improve))
        # Stop after 6 consecutive epochs without a new best test score.
        if not_improve > 5:
            break
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。