1 Star 3 Fork 0

孙浩/3D-UNet-PyTorch

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
train.py 5.57 KB
一键复制 编辑 原始数据 按行查看 历史
孙浩 提交于 3年前 . 3D-UNet-Lung-Segmentation
import pandas as pd
import torchvision
from dataset.dataset_lits_val import Val_Dataset
from dataset.dataset_lits_train import Train_Dataset
from torch.utils.data import DataLoader
import torch
import torch.optim as optim
from tqdm import tqdm
import config
from models import UNet, ResUNet, KiUNet_min, SegNet
from utils import logger, weights_init, metrics, common, loss
import os
import numpy as np
from collections import OrderedDict
def val(model, val_loader, loss_func, n_labels):
    """Run one validation pass and return an OrderedDict of metrics.

    Returns keys: 'Valid_Loss', per-label 'V_dice_liver-<i>' scores, and
    their mean under 'V_dice_liver_avg'. Uses the module-level `device`.
    """
    model.eval()
    loss_meter = metrics.LossAverage()
    dice_meter = metrics.DiceAverage(n_labels)
    with torch.no_grad():
        for _, (data, target) in tqdm(enumerate(val_loader), total=len(val_loader)):
            data = data.float()
            target = common.to_one_hot_3d(target.long(), n_labels)
            data, target = data.to(device), target.to(device)
            output = model(data)
            batch_loss = loss_func(output, target)
            loss_meter.update(batch_loss.item(), data.size(0))
            dice_meter.update(output, target)
    # Build the log: loss first, then a placeholder average, then per-label dice.
    val_log = OrderedDict()
    val_log['Valid_Loss'] = float('%.4f' % loss_meter.avg)
    val_log['V_dice_liver_avg'] = 0
    for label in range(1, n_labels):
        val_log[f'V_dice_liver-{label}'] = float('%.4f' % dice_meter.avg[label])
    # The [2:] slice skips the loss and the placeholder, averaging only dice values.
    per_label_scores = list(val_log.values())[2:]
    val_log['V_dice_liver_avg'] = float('%.4f' % np.mean(per_label_scores))
    return val_log
def train(model, train_loader, optimizer, loss_func, n_labels, alpha):
    """Run one deeply-supervised training epoch; return an OrderedDict of metrics.

    The model is expected to return four outputs; the final one (output[3]) is
    the main prediction and the first three are auxiliary heads weighted by
    `alpha`. Only the main head contributes to the reported loss and dice.

    NOTE(review): reads the module-level globals `epoch` and `device` — only
    safe when called from the `__main__` loop below; confirm before reuse.
    """
    print("=======Epoch:{}=======lr:{}".format(epoch, optimizer.state_dict()['param_groups'][0]['lr']))
    model.train()
    loss_meter = metrics.LossAverage()
    dice_meter = metrics.DiceAverage(n_labels)
    for _, (data, target) in tqdm(enumerate(train_loader), total=len(train_loader)):
        data = data.float()
        target = common.to_one_hot_3d(target.long(), n_labels)
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        # Deep supervision: auxiliary heads 0..2 are down-weighted by alpha.
        aux_losses = [loss_func(output[i], target) for i in range(3)]
        main_loss = loss_func(output[3], target)
        total_loss = main_loss + alpha * (aux_losses[0] + aux_losses[1] + aux_losses[2])
        total_loss.backward()
        optimizer.step()
        # Track only the main-head loss and dice for reporting.
        loss_meter.update(main_loss.item(), data.size(0))
        dice_meter.update(output[3], target)
    train_log = OrderedDict()
    train_log['Train_Loss'] = float('%.4f' % loss_meter.avg)
    train_log['T_dice_liver_avg'] = 0
    for label in range(1, n_labels):
        train_log[f'T_dice_liver-{label}'] = float('%.4f' % dice_meter.avg[label])
    # The [2:] slice skips the loss and the placeholder, averaging only dice values.
    train_log['T_dice_liver_avg'] = float('%.4f' % np.mean(list(train_log.values())[2:]))
    return train_log
if __name__ == '__main__':
    args = config.args
    save_path = os.path.join('./experiments', args.save)
    # makedirs(exist_ok=True) also creates the './experiments' parent;
    # the original os.mkdir crashed when the parent directory was missing.
    os.makedirs(save_path, exist_ok=True)
    device = torch.device('cpu' if args.cpu else 'cuda')

    # Data loaders: shuffled mini-batches for training, batch size 1 for validation.
    train_loader = DataLoader(dataset=Train_Dataset(args), batch_size=args.batch_size,
                              num_workers=args.n_threads, shuffle=True)
    val_loader = DataLoader(dataset=Val_Dataset(args), batch_size=1,
                            num_workers=args.n_threads, shuffle=False)

    # Model: resume from the best checkpoint if one exists, otherwise start fresh.
    if os.path.exists(f'{save_path}/best_model.pth'):
        model = ResUNet(in_channel=1, out_channel=args.n_labels, training=False).to(device)
        model = torch.nn.DataParallel(model, device_ids=args.gpu_id)  # multi-GPU
        ckpt = torch.load('{}/best_model.pth'.format(save_path))
        model.load_state_dict(ckpt['net'])
        log = logger.Train_Logger(save_path, "train_log", pd.read_csv(f'{save_path}/train_log.csv'))
        info = [ckpt['last_epoch'] + 1, ckpt['best_epoch'], ckpt['val_dice']]  # [start_epoch, best_epoch, best_dice]
    else:
        model = ResUNet(in_channel=1, out_channel=args.n_labels, training=True).to(device)
        model.apply(weights_init.init_model)
        model = torch.nn.DataParallel(model, device_ids=args.gpu_id)  # multi-GPU
        log = logger.Train_Logger(save_path, "train_log", None)
        info = [1, 1, 0]  # [start_epoch, best_epoch, best_dice]

    optimizer = optim.SGD(model.parameters(), lr=args.lr)
    common.print_network(model)
    loss = loss.TverskyLoss()

    trigger = 0   # early-stop counter: epochs since best validation dice improved
    alpha = 0.4   # initial deep-supervision weight for the auxiliary heads

    for epoch in range(info[0], args.epochs + info[0]):
        common.adjust_learning_rate(optimizer, epoch, args)
        train_log = train(model, train_loader, optimizer, loss, args.n_labels, alpha)
        val_log = val(model, val_loader, loss, args.n_labels)
        log.update(epoch, train_log, val_log)

        # Save checkpoint.
        state = {'net': model.state_dict(),
                 'optimizer': optimizer.state_dict(),
                 'last_epoch': epoch,
                 'best_epoch': info[1],
                 'val_dice': val_log['V_dice_liver_avg']}
        torch.save(state, os.path.join(save_path, 'latest_model.pth'))
        trigger += 1
        if val_log['V_dice_liver_avg'] > info[2]:
            print('Saving best model')
            torch.save(state, os.path.join(save_path, 'best_model.pth'))
            info[1] = epoch
            info[2] = val_log['V_dice_liver_avg']
            trigger = 0
        print('Best performance at Epoch: {} | {}'.format(info[1], info[2]))

        # Deep-supervision weight decay.
        # BUG FIX: the original condition `epoch % 5 == 10` can never be true,
        # so alpha was never decayed. Decay every 5 epochs to match the
        # "decay schedule" intent of the original comment.
        if epoch % 5 == 0:
            alpha *= 0.8

        # Early stopping: quit when the best dice has not improved for
        # `args.early_stop` consecutive epochs.
        if args.early_stop is not None:
            if trigger >= args.early_stop:
                print("=> early stopping")
                break
        torch.cuda.empty_cache()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/SunHao-AI/3d-unet-pytorch.git
git@gitee.com:SunHao-AI/3d-unet-pytorch.git
SunHao-AI
3d-unet-pytorch
3D-UNet-PyTorch
master

搜索帮助