1 Star 0 Fork 0

Unluckyless / U-NET-for-LocalBrainAge-prediction-pytorch

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
test.py 2.82 KB
一键复制 编辑 原始数据 按行查看 历史
Unluckyless 提交于 2022-11-02 21:23 . First commit
import glob
import os

import cv2
import numpy as np
import SimpleITK as sitk
import torch
import torch.nn as nn
from numpy.testing._private.utils import print_assert_equal
from torch.autograd import Variable
from torch.nn.modules.activation import PReLU
from torch.nn.modules.distance import PairwiseDistance

from nets.net import conv_block, down_conv, upsample
from nets.unet import U_Net
from utils.config import unt_config as cfg
# Smoke test: push a dummy all-zeros batch through the U-Net and print the
# shape of what comes out. Falls back to the CPU so the rest of this script
# still runs on machines without CUDA (a bare .cuda() would crash there).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

softmax = nn.Softmax(dim=1)  # module-level helper; not used by the live code in this file

img = torch.zeros(1, 3, 52, 52, device=device)  # dummy (N, C, H, W) batch
net = U_Net(cfg, 'test').to(device)
# Inference only: eval() puts layers such as BatchNorm/Dropout (if U_Net has
# any) into inference mode -- in train mode BatchNorm rejects batch-size-1
# inputs that shrink to 1x1 -- and no_grad() skips building an autograd graph.
net.eval()
with torch.no_grad():
    output = net(img)
print(output.shape)
# NOTE(review): disabled experiment. The triple-quoted string below is a bare
# expression statement: Python evaluates and discards it, so none of the code
# inside runs. When enabled, it read slice 100 of the first T1 volume found by
# glob, fed it to `net` as a (1, 1, H, W) float tensor, took the argmax over
# dim 0 of the first output (presumably the class axis -- TODO confirm) and
# displayed it, scaled by 100, in an OpenCV window.
# If re-enabling it:
#   - `astype(np.uint8)` on raw intensities wraps any value above 255;
#     rescale to 0..255 first.
#   - `Variable` is deprecated in PyTorch; plain tensors work with autograd.
#   - `.cuda()` requires a GPU.
#   - This input has 1 channel while the smoke test above uses 3 --
#     TODO confirm which channel count U_Net expects.
'''
def read_img(img_path):
return sitk.GetArrayFromImage(sitk.ReadImage(img_path))
t1 = glob.glob(r'dataset/MICCAI_BraTS2020_TrainingData/*/*t1.nii.gz')
imgs = read_img(t1[0])
img = imgs[100].astype(np.uint8)
img = np.array([[img]])
print(img.shape)
img_tensor = Variable(torch.from_numpy(img).type(torch.FloatTensor)).cuda()
output = net(img_tensor)[0]
pre = output.argmax(dim=0)
print(output.shape)
print(pre.shape)
pre = pre.cpu().numpy()*100
pre_img = np.uint8(pre)
cv2.imshow('test',pre_img)
cv2.waitKey(0)
'''
# NOTE(review): disabled experiment (bare string, never executed). It checks
# that moving dim 1 of a (1, 5, 8, 6) tensor to the last position and
# flattening with view(-1, 5) yields a (48, 5) tensor: one row per spatial
# position, one column per channel.
'''
test = torch.zeros(1,5,8,6)
print(test.shape)
temp = test.transpose(1, 2).transpose(2, 3).contiguous().view(-1,5)
print(temp.shape)
print()
'''
num = 100  # index along the first array axis of the slice inspected below


def read_img(img_path):
    """Read an image file with SimpleITK and return its voxel data as a NumPy array.

    For a 3-D volume SimpleITK returns the array in (z, y, x) order, so the
    first index selects a slice.
    """
    volume = sitk.ReadImage(img_path)
    voxels = sitk.GetArrayFromImage(volume)
    return voxels
# Load one BraTS 2020 training case -- four MRI modalities plus the
# segmentation -- and take slice `num` of each volume for display.
DATA_ROOT = 'dataset/MICCAI_BraTS2020_TrainingData'
PATIENT_INDEX = 2  # position of the case to inspect, in sorted path order


def _to_uint8(slice_2d):
    """Linearly rescale an intensity slice onto 0..255 and return it as uint8.

    A bare ``astype(np.uint8)`` does not rescale: integer values outside
    0..255 wrap around (e.g. 300 -> 44) and out-of-range floats give
    undefined results, which scrambles any slice whose intensities exceed
    255. A constant slice has no contrast to stretch and maps to all zeros.
    """
    values = slice_2d.astype(np.float64)
    lo, hi = values.min(), values.max()
    if hi <= lo:
        return np.zeros(values.shape, dtype=np.uint8)
    return np.round((values - lo) * (255.0 / (hi - lo))).astype(np.uint8)


# glob makes no ordering guarantee (the order follows the file system), so
# every list is sorted; otherwise the same index could select a different
# case for each modality.
t1 = sorted(glob.glob(DATA_ROOT + '/*/*t1.nii.gz'))
t2 = sorted(glob.glob(DATA_ROOT + '/*/*t2.nii.gz'))
flair = sorted(glob.glob(DATA_ROOT + '/*/*flair.nii.gz'))
t1ce = sorted(glob.glob(DATA_ROOT + '/*/*t1ce.nii.gz'))
seg = sorted(glob.glob(DATA_ROOT + '/*/*seg.nii.gz'))

modality_lists = (t1, t2, flair, t1ce, seg)
if any(len(paths) <= PATIENT_INDEX for paths in modality_lists):
    raise FileNotFoundError(
        f'need at least {PATIENT_INDEX + 1} files per modality under {DATA_ROOT!r}; '
        f'found {[len(paths) for paths in modality_lists]}'
    )
# Sorting alone does not help if some case is missing a file (the lists then
# shift against each other), so refuse to mix volumes from different cases.
case_files = [paths[PATIENT_INDEX] for paths in modality_lists]
if len({os.path.dirname(path) for path in case_files}) != 1:
    raise RuntimeError(f'selected files belong to different cases: {case_files}')

t1_imgs = read_img(t1[PATIENT_INDEX])
t2_imgs = read_img(t2[PATIENT_INDEX])
flair_imgs = read_img(flair[PATIENT_INDEX])
t1ce_imgs = read_img(t1ce[PATIENT_INDEX])
seg_imgs = read_img(seg[PATIENT_INDEX])
print(len(t1_imgs))
t1_img = _to_uint8(t1_imgs[num])
print(t1_img.shape)
t2_img = _to_uint8(t2_imgs[num])
flair_img = _to_uint8(flair_imgs[num])
t1ce_img = _to_uint8(t1ce_imgs[num])
# Labels are multiplied by 100 so they are visible. Do the arithmetic in
# int32 and clip: in uint8 a label of 3 or more overflows (4 * 100 = 400
# would wrap to 144 and appear darker than label 2).
seg_img = np.clip(seg_imgs[num].astype(np.int32) * 100, 0, 255).astype(np.uint8)
# NOTE(review): disabled scratch code (a bare string, never executed). It
# defines `crop_ceter` (sic), which cuts a croph x cropw window centred on a
# 2-D image but shifted 5 rows down by the `+5`. It then applies that crop to
# `img` with a 160 x 160 window.
'''
def crop_ceter(img,croph,cropw):
#for n_slice in range(img.shape[0]):
height,width = img.shape
starth = height//2-(croph//2)+5
startw = width//2-(cropw//2)
return img[starth:starth+croph,startw:startw+cropw]
img = crop_ceter(img,160,160)
'''
# The four modality slices loaded above, in the order t1, t2, flair, t1ce.
# `stacked` below uses only the first three of them.
input_img = [t1_img,t2_img,flair_img,t1ce_img]
def stacked(imgs):
    """Merge the first three slices of *imgs* into one 3-channel image.

    Each slice becomes one channel along a new trailing axis, so three
    (H, W) slices give an (H, W, 3) array suitable for cv2.imshow (which
    reads the channels as B, G, R). Entries after the third are ignored.
    """
    channels = [imgs[k] for k in (0, 1, 2)]
    return np.stack(channels, axis=-1)
# Show every modality slice, the segmentation and the 3-channel composite in
# their own OpenCV windows, in that order, and wait for a key press.
windows = [
    ('t1', t1_img),
    ('t2', t2_img),
    ('flair', flair_img),
    ('t1ce', t1ce_img),
    ('seg_img', seg_img),
    ('test', stacked(input_img)),
]
for title, image in windows:
    cv2.imshow(title, image)
cv2.waitKey(0)
# Release the HighGUI windows explicitly instead of leaving them open until
# the process exits.
cv2.destroyAllWindows()
1
https://gitee.com/unluckyless/u-net-for-local-brain-age-prediction-pytorch.git
git@gitee.com:unluckyless/u-net-for-local-brain-age-prediction-pytorch.git
unluckyless
u-net-for-local-brain-age-prediction-pytorch
U-NET-for-LocalBrainAge-prediction-pytorch
master

搜索帮助