1 Star 1 Fork 0

zhugeliang1/UNET-ZOO

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
UNet.py 5.41 KB
一键复制 编辑 原始数据 按行查看 历史
andyzhu 提交于 2020-02-11 13:13 . Add files via upload
import torch.nn as nn
import torch
from torch import autograd
from functools import partial
import torch.nn.functional as F
from torchvision import models
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution stages, each followed by BN and ReLU.

    Both convolutions use a 3x3 kernel with ``padding=1``, so the spatial
    size of the feature map is preserved; only the channel count changes
    from ``in_ch`` to ``out_ch``.

    Args:
        in_ch: Channel count of the incoming feature map.
        out_ch: Channel count produced by both convolution stages.
    """

    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        # Layers are created in execution order so the Sequential indices
        # (and therefore the parameter names "conv.0", "conv.1", "conv.3",
        # "conv.4") stay exactly as before.
        stages = []
        for stage_in in (in_ch, out_ch):
            stages.append(nn.Conv2d(stage_in, out_ch, 3, padding=1))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, input):
        """Apply both conv/BN/ReLU stages to ``input`` and return the result."""
        return self.conv(input)
class Unet(nn.Module):
    """U-Net style encoder/decoder assembled from ``DoubleConv`` stages.

    The contracting path applies four ``DoubleConv`` + 2x2 max-pool steps
    (32, 64, 128, 256 channels) followed by a 512-channel bottleneck. The
    expanding path mirrors it with four transposed convolutions, each
    concatenated with the encoder feature map of the same level before a
    ``DoubleConv``. A final 1x1 convolution maps to ``out_ch`` channels and
    a sigmoid squashes the result into ``[0, 1]``.

    Args:
        in_ch: Number of channels of the input image.
        out_ch: Number of channels of the produced map.
    """

    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()
        # Contracting path.
        self.conv1 = DoubleConv(in_ch, 32)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(32, 64)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(64, 128)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(128, 256)
        self.pool4 = nn.MaxPool2d(2)
        # Bottleneck.
        self.conv5 = DoubleConv(256, 512)
        # Expanding path: every DoubleConv sees the up-sampled map
        # concatenated with the skip connection, hence the doubled inputs.
        self.up6 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv6 = DoubleConv(512, 256)
        self.up7 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv7 = DoubleConv(256, 128)
        self.up8 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv8 = DoubleConv(128, 64)
        self.up9 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.conv9 = DoubleConv(64, 32)
        # 1x1 projection to the requested number of output channels.
        self.conv10 = nn.Conv2d(32, out_ch, 1)

    @staticmethod
    def _match_skip(upsampled, skip):
        """Zero-pad ``upsampled`` so its height/width equal those of ``skip``.

        Max-pooling floors odd sizes and the transposed convolution only
        doubles, so the up-sampled map can be one pixel short of the skip
        connection. When the sizes already agree the tensor is returned
        untouched, which keeps the previous behaviour for inputs whose
        height and width are multiples of 16.
        """
        diff_h = skip.size(2) - upsampled.size(2)
        diff_w = skip.size(3) - upsampled.size(3)
        if diff_h == 0 and diff_w == 0:
            return upsampled
        # F.pad order for the last two dims: (left, right, top, bottom).
        return F.pad(
            upsampled,
            [diff_w // 2, diff_w - diff_w // 2,
             diff_h // 2, diff_h - diff_h // 2],
        )

    def forward(self, x):
        """Return the sigmoid-activated map for the input batch ``x``.

        The output has ``out_ch`` channels and the same height/width as
        ``x``. Each spatial dimension must survive four 2x2 poolings, i.e.
        be at least 16.
        """
        # Encoder: keep every pre-pool activation for the skip connections.
        c1 = self.conv1(x)
        c2 = self.conv2(self.pool1(c1))
        c3 = self.conv3(self.pool2(c2))
        c4 = self.conv4(self.pool3(c3))
        c5 = self.conv5(self.pool4(c4))

        # Decoder: up-sample, align with the skip map, concatenate, refine.
        up_6 = self._match_skip(self.up6(c5), c4)
        c6 = self.conv6(torch.cat([up_6, c4], dim=1))
        up_7 = self._match_skip(self.up7(c6), c3)
        c7 = self.conv7(torch.cat([up_7, c3], dim=1))
        up_8 = self._match_skip(self.up8(c7), c2)
        c8 = self.conv8(torch.cat([up_8, c2], dim=1))
        up_9 = self._match_skip(self.up9(c8), c1)
        c9 = self.conv9(torch.cat([up_9, c1], dim=1))

        c10 = self.conv10(c9)
        # Functional sigmoid avoids building a new nn.Sigmoid module per call.
        return torch.sigmoid(c10)
# Shared activation: ReLU applied in place, used wherever this file needs a
# functional (parameter-free) non-linearity.
nonlinearity = partial(F.relu, inplace=True)


class DecoderBlock(nn.Module):
    """Bottlenecked up-sampling block: 1x1 conv, stride-2 deconv, 1x1 conv.

    The input is first reduced to ``in_channels // 4`` channels, then
    up-sampled by a 3x3 transposed convolution with stride 2, and finally
    projected to ``n_filters`` channels. Every step is followed by batch
    normalisation and the shared in-place ReLU.

    Args:
        in_channels: Channel count of the incoming feature map.
        n_filters: Channel count of the produced feature map.
    """

    def __init__(self, in_channels, n_filters):
        super(DecoderBlock, self).__init__()
        squeezed = in_channels // 4  # width of the bottleneck

        # Step 1: channel reduction.
        self.conv1 = nn.Conv2d(in_channels, squeezed, 1)
        self.norm1 = nn.BatchNorm2d(squeezed)
        self.relu1 = nonlinearity

        # Step 2: spatial up-sampling inside the bottleneck.
        self.deconv2 = nn.ConvTranspose2d(
            squeezed, squeezed, 3, stride=2, padding=1, output_padding=1
        )
        self.norm2 = nn.BatchNorm2d(squeezed)
        self.relu2 = nonlinearity

        # Step 3: projection to the requested output width.
        self.conv3 = nn.Conv2d(squeezed, n_filters, 1)
        self.norm3 = nn.BatchNorm2d(n_filters)
        self.relu3 = nonlinearity

    def forward(self, x):
        """Run ``x`` through the three conv/BN/ReLU steps and return it."""
        reduced = self.relu1(self.norm1(self.conv1(x)))
        upsampled = self.relu2(self.norm2(self.deconv2(reduced)))
        return self.relu3(self.norm3(self.conv3(upsampled)))
class resnet34_unet(nn.Module):
    """Encoder/decoder segmentation network on a torchvision ResNet-34 encoder.

    The ResNet stem and its four residual stages act as the encoder. Three
    ``DecoderBlock`` stages up-sample and are summed element-wise with the
    matching encoder features; a fourth ``DecoderBlock`` and a small
    deconv/conv head produce ``num_classes`` channels, squashed by a sigmoid.

    Args:
        num_classes: Number of channels of the produced map.
        num_channels: Number of channels of the input image. When it differs
            from what the ResNet stem expects, a fresh (randomly initialised)
            stem convolution with the same geometry is created instead.
        pretrained: Forwarded to ``models.resnet34``.
    """

    def __init__(self, num_classes=1, num_channels=3,pretrained=True):
        super(resnet34_unet, self).__init__()
        filters = [64, 128, 256, 512]
        # NOTE(review): recent torchvision releases deprecate ``pretrained``
        # in favour of ``weights``; the keyword is kept so the file keeps
        # working with the torchvision version it was written against.
        resnet = models.resnet34(pretrained=pretrained)

        stem = resnet.conv1
        if num_channels == stem.in_channels:
            self.firstconv = stem
        else:
            # Previously ``num_channels`` was ignored and the stem was always
            # reused as-is. Build a stem with the same geometry that accepts
            # the requested number of input channels.
            self.firstconv = nn.Conv2d(
                num_channels,
                stem.out_channels,
                kernel_size=stem.kernel_size,
                stride=stem.stride,
                padding=stem.padding,
                bias=stem.bias is not None,
            )
        self.firstbn = resnet.bn1
        self.firstrelu = resnet.relu
        self.firstmaxpool = resnet.maxpool

        self.encoder1 = resnet.layer1
        self.encoder2 = resnet.layer2
        self.encoder3 = resnet.layer3
        self.encoder4 = resnet.layer4

        # Decoders are built once (the original constructed each one twice
        # and threw the first set away).
        self.decoder4 = DecoderBlock(filters[3], filters[2])
        self.decoder3 = DecoderBlock(filters[2], filters[1])
        self.decoder2 = DecoderBlock(filters[1], filters[0])
        self.decoder1 = DecoderBlock(filters[0], filters[0])

        # Output head: one more 2x up-sampling followed by two 3x3 convs.
        self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1)
        self.finalrelu1 = nonlinearity
        self.finalconv2 = nn.Conv2d(32, 32, 3, padding=1)
        self.finalrelu2 = nonlinearity
        self.finalconv3 = nn.Conv2d(32, num_classes, 3, padding=1)

    def forward(self, x):
        """Return the sigmoid-activated prediction for the input batch ``x``.

        NOTE(review): the element-wise skip additions need every decoder
        output to match the corresponding encoder feature map in size;
        presumably this requires input height/width to be multiples of 32 —
        confirm against the torchvision ResNet strides.
        """
        # Encoder
        x = self.firstconv(x)
        x = self.firstbn(x)
        x = self.firstrelu(x)
        x = self.firstmaxpool(x)
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)

        # Decoder with additive skip connections
        d4 = self.decoder4(e4) + e3
        d3 = self.decoder3(d4) + e2
        d2 = self.decoder2(d3) + e1
        d1 = self.decoder1(d2)

        # Output head
        out = self.finaldeconv1(d1)
        out = self.finalrelu1(out)
        out = self.finalconv2(out)
        out = self.finalrelu2(out)
        out = self.finalconv3(out)
        # Functional sigmoid avoids building a new nn.Sigmoid module per call.
        return torch.sigmoid(out)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/wsq_chd/UNET-ZOO.git
git@gitee.com:wsq_chd/UNET-ZOO.git
wsq_chd
UNET-ZOO
UNET-ZOO
master

搜索帮助