1 Star 0 Fork 0

evesystem/Inf-Net

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
MyTrain_MulClsLungInf_UNet.py 3.26 KB
一键复制 编辑 原始数据 按行查看 历史
Gepeng_Ji 提交于 2020-05-13 14:46 . Update info
# -*- coding: utf-8 -*-
"""Preview
Code for 'Inf-Net: Automatic COVID-19 Lung Infection Segmentation from CT Scans'
submitted to Transactions on Medical Imaging, 2020.
First Version: Created on 2020-05-13 (@author: Ge-Peng Ji)
"""
import os
import numpy as np
import torch.optim as optim
from Code.utils.dataloader_MulClsLungInf_UNet import LungDataset
from torchvision import transforms
# from LungData import test_dataloader, train_dataloader # pls change batch_size
from torch.utils.data import DataLoader
from Code.model_lung_infection.InfNet_UNet import *
def train(epo_num: int, num_classes: int, input_channels: int,
          batch_size: int, lr: float, save_path: str) -> None:
    """Train the multi-class lung-infection U-Net and save periodic checkpoints.

    The model is optimised with SGD (momentum 0.7) against a binary
    cross-entropy loss computed on the sigmoid of the network output.
    Weights are written every 10 epochs to
    ``./Snapshots/save_weights/<save_path>/unet_model_<epoch>.pkl``.

    Args:
        epo_num: Number of training epochs.
        num_classes: Forwarded to ``Inf_Net_UNet`` as its second argument.
        input_channels: Forwarded to ``Inf_Net_UNet`` as its first argument.
            The network is fed ``torch.cat((img, pseudo), dim=1)``, so this
            has to match the combined channel count of image and prior.
        batch_size: Mini-batch size of the training ``DataLoader``.
        lr: Learning rate of the SGD optimizer.
        save_path: Sub-directory of ``./Snapshots/save_weights/`` that
            receives the checkpoints. It is created if missing.
    """
    train_dataset = LungDataset(
        imgs_path='./Dataset/TrainingSet/MultiClassInfection-Train/Imgs/',
        # NOTES: prior is borrowed from the object-level label of train split
        pseudo_path='./Dataset/TrainingSet/MultiClassInfection-Train/Prior/',
        label_path='./Dataset/TrainingSet/MultiClassInfection-Train/GT/',
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]))
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    lung_model = Inf_Net_UNet(input_channels, num_classes)
    print(lung_model)
    lung_model = lung_model.to(device)

    criterion = nn.BCELoss().to(device)
    optimizer = optim.SGD(lung_model.parameters(), lr=lr, momentum=0.7)

    # Create the directory torch.save() actually writes into. The previous
    # code created an unrelated './checkpoints//UNet_Multi-Class-Semi' folder
    # on every epoch, so the first save failed whenever the snapshot folder
    # did not already exist. Doing it once, before training starts, also
    # surfaces path/permission problems before any compute is spent.
    save_dir = './Snapshots/save_weights/{}'.format(save_path)
    os.makedirs(save_dir, exist_ok=True)

    print("#" * 20, "\nStart Training (Inf-Net)\nThis code is written for 'Inf-Net: Automatic COVID-19 Lung "
                    "Infection Segmentation from CT Scans', 2020, arXiv.\n"
                    "----\nPlease cite the paper if you use this code and dataset. "
                    "And any questions feel free to contact me "
                    "via E-mail (gepengai.ji@163.com)\n----\n", "#" * 20)

    num_steps = len(train_dataloader)

    for epo in range(epo_num):
        lung_model.train()

        # The fourth item yielded by the dataset is not needed for training.
        for index, (img, pseudo, img_mask, _) in enumerate(train_dataloader):
            img = img.to(device)
            pseudo = pseudo.to(device)
            img_mask = img_mask.to(device)

            optimizer.zero_grad()
            # Image and prior are stacked along the channel axis.
            output = lung_model(torch.cat((img, pseudo), dim=1))
            # BCELoss expects probabilities, hence the explicit sigmoid.
            # NOTE(review): output is presumably (batch, num_classes, H, W)
            # and img_mask must have the same shape -- confirm against the
            # dataset/model definitions.
            output = torch.sigmoid(output)

            loss = criterion(output, img_mask)
            loss.backward()
            iter_loss = loss.item()
            optimizer.step()

            # Log every 20th iteration.
            if index % 20 == 0:
                print('Epoch: {}/{}, Step: {}/{}, Train loss is {}'.format(epo, epo_num, index, num_steps, iter_loss))

        # Persist the weights every 10 epochs (1-based epoch number in the name).
        if (epo + 1) % 10 == 0:
            torch.save(lung_model.state_dict(),
                       '{}/unet_model_{}.pkl'.format(save_dir, epo + 1))
            print('Saving checkpoints: unet_model_{}.pkl'.format(epo + 1))
if __name__ == "__main__":
    # Settings for the semi-supervised multi-class U-Net run; they are
    # handed to train() unchanged as keyword arguments.
    run_settings = {
        'epo_num': 200,
        'num_classes': 3,
        'input_channels': 6,
        'batch_size': 16,
        'lr': 1e-2,
        'save_path': 'Semi-Inf-Net_UNet',
    }
    train(**run_settings)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/evesystem/Inf-Net.git
git@gitee.com:evesystem/Inf-Net.git
evesystem
Inf-Net
Inf-Net
master

搜索帮助