1 Star 0 Fork 0

Fish/DeepLearning

Create your Gitee Account
Explore and code with more than 13.5 million developers. Free private repositories! :)
Sign up
文件
This repository doesn't specify a license. Please pay attention to the specific project description and its upstream code dependencies when using it.
Clone or Download
Res34UNet 4.75 KB
Copy Edit Raw Blame History
Fish authored 2 years ago: add UNet/Res34UNet.
import os
import warnings
from typing import Dict

import torch
import torch.nn as nn
from torchvision.models import resnet34
from torchvision.models._utils import IntermediateLayerGetter
class DoubleConv(nn.Sequential):
    """Two stacked ``Conv3x3 -> BatchNorm -> ReLU`` units.

    Both convolutions use ``padding=1`` with stride 1, so the spatial size is
    preserved and only the channel count changes
    (``in_channels -> mid_channels -> out_channels``). Convolution bias is
    disabled; each convolution is immediately followed by BatchNorm.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the second convolution.
        mid_channels: channels produced by the first convolution; falls back
            to ``out_channels`` when ``None``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        hidden = out_channels if mid_channels is None else mid_channels
        layers = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        super().__init__(*layers)
class Down(nn.Sequential):
    """Encoder step: 2x2 max-pool (halves H and W), then a ``DoubleConv``.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels):
        pool = nn.MaxPool2d(2, stride=2)
        convs = DoubleConv(in_channels, out_channels)
        super().__init__(pool, convs)
class Up(nn.Module):
    """Decoder step: upsample the deeper map, concatenate the skip map, fuse.

    The channel convention depends on ``nearest`` (this is how the layers
    are wired below):

    * ``nearest=True`` (default): ``x1`` is upsampled 2x with parameter-free
      nearest-neighbour interpolation, so its channel count is unchanged and
      ``in_channels`` must equal ``channels(x1) + channels(x2)``. This is the
      convention ``Res34UNet`` relies on.
    * ``nearest=False``: ``x1`` goes through a stride-2 ``ConvTranspose2d``
      mapping ``in_channels -> in_channels // 2``, so ``x1`` must have
      ``in_channels`` channels and ``x2`` must supply the remaining
      ``in_channels - in_channels // 2`` (classic U-Net convention).
      NOTE(review): this is not interchangeable with the ``nearest=True``
      channel arithmetic.

    Args:
        in_channels: see the mode-dependent convention above.
        out_channels: channels of the fused output.
        nearest: nearest-neighbour upsampling (True) or a learned transposed
            convolution (False).
    """

    def __init__(self, in_channels, out_channels, nearest=True):
        super(Up, self).__init__()
        if nearest:
            self.up = nn.Upsample(scale_factor=2, mode='nearest')
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        # Registered after `up`, as before, so state_dict ordering is unchanged.
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Upsample ``x1``, align it to ``x2``'s H/W, concatenate, and convolve.

        Args:
            x1: deeper (lower-resolution) feature map, (N, C, H, W).
            x2: skip-connection feature map, (N, C, H, W).

        Returns:
            A tensor with ``out_channels`` channels at ``x2``'s spatial size.
        """
        x1 = self.up(x1)
        # When the network input is not a multiple of the encoder's total
        # stride, a stride-2 layer can yield an odd size, so 2 * (deeper size)
        # may differ from the skip size by one pixel (e.g. 8 vs 7) and
        # torch.cat would fail. Pad or crop x1 to match; negative padding in
        # F.pad crops. Skipped when sizes already agree.
        diff_h = x2.shape[-2] - x1.shape[-2]
        diff_w = x2.shape[-1] - x1.shape[-1]
        if diff_h != 0 or diff_w != 0:
            x1 = nn.functional.pad(
                x1,
                [diff_w // 2, diff_w - diff_w // 2,
                 diff_h // 2, diff_h - diff_h // 2],
            )
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
class OutConv(nn.Sequential):
    """Output head: 2x nearest upsample, ``DoubleConv`` to 16 channels, then a
    3x3 convolution producing one raw (un-activated) map per class.

    Args:
        in_channels: channels of the incoming decoder feature map.
        num_classes: number of output channels.
    """

    def __init__(self, in_channels, num_classes):
        head_width = 16
        upsample = nn.Upsample(scale_factor=2, mode='nearest')
        refine = DoubleConv(in_channels, head_width)
        classify = nn.Conv2d(head_width, num_classes, kernel_size=3, padding=1)
        super().__init__(upsample, refine, classify)
class Res34UNet(nn.Module):
    """U-Net-style segmentation network with a torchvision ResNet-34 encoder.

    The encoder exposes five ResNet-34 feature maps (``relu`` after the stem,
    then ``layer1``..``layer4``; 64/64/128/256/512 channels) as
    ``stage0``..``stage4``. Four ``Up`` steps fuse them from the deepest
    upward, and ``OutConv`` upsamples 2x once more and maps to
    ``num_classes`` channels.

    Args:
        num_classes: number of output channels.
        pretrain_backbone: loading pretrained ResNet-34 weights is not
            implemented; the backbone is always randomly initialised and
            passing ``True`` only emits a ``UserWarning``.
    """

    def __init__(self, num_classes, pretrain_backbone: bool = False):
        super(Res34UNet, self).__init__()
        backbone = resnet34()
        # For single-channel input the stem could be replaced with
        # nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False).
        if pretrain_backbone:
            # TODO: load ResNet-34 weights, e.g. from
            # https://download.pytorch.org/models/resnet34-333f7ec4.pth
            # via backbone.load_state_dict(torch.load(path, map_location='cpu')).
            # Until then, make the unsupported flag visible instead of
            # silently ignoring it.
            warnings.warn(
                "pretrain_backbone=True is not supported yet; "
                "the ResNet-34 backbone is randomly initialised.",
                stacklevel=2,
            )
        stage_indices = ['relu', 'layer1', 'layer2', 'layer3', 'layer4']
        self.stage_out_channels = [64, 64, 128, 256, 512]
        return_layers = {name: f"stage{i}" for i, name in enumerate(stage_indices)}
        self.backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

        ch = self.stage_out_channels
        # In nearest mode, Up expects in_channels = channels(deeper) + channels(skip).
        self.up1 = Up(ch[4] + ch[3], ch[3])
        self.up2 = Up(ch[3] + ch[2], ch[2])
        self.up3 = Up(ch[2] + ch[1], ch[1])
        self.up4 = Up(ch[1] + ch[0], 32)
        self.outconv = OutConv(32, num_classes=num_classes)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run the network.

        Args:
            x: input batch, (N, C, H, W); C=3 with the unmodified ResNet-34 stem.

        Returns:
            ``{"out": y}`` where ``y`` has ``num_classes`` channels and the
            same H/W as ``x``.
        """
        in_h, in_w = x.shape[-2], x.shape[-1]
        features = self.backbone(x)
        y = self.up1(features['stage4'], features['stage3'])
        y = self.up2(y, features['stage2'])
        y = self.up3(y, features['stage1'])
        y = self.up4(y, features['stage0'])
        y = self.outconv(y)
        # The final 2x upsample can overshoot odd input sizes by one pixel;
        # crop so the prediction aligns with the input. No-op otherwise.
        if y.shape[-2] != in_h or y.shape[-1] != in_w:
            y = y[..., :in_h, :in_w]
        return {"out": y}
def convert_onnx(m, save_path, *, input_shape=(1, 3, 1504, 1504), opset_version=9):
    """Export model ``m`` to ONNX by tracing it on a random dummy input.

    ``m`` is switched to eval mode (and left there), so BatchNorm uses its
    running statistics while tracing.

    Args:
        m: the ``nn.Module`` to export.
        save_path: destination path or file-like object, forwarded to
            ``torch.onnx.export``.
        input_shape: shape of the dummy input. No dynamic axes are declared,
            so the exported graph is fixed to this shape.
        opset_version: ONNX opset to target.
    """
    m.eval()
    dummy_input = torch.rand(*input_shape, requires_grad=True)
    torch.onnx.export(
        m,                         # model being run
        dummy_input,               # example input used for tracing
        save_path,                 # file path or file-like object
        input_names=["input"],
        output_names=["output"],
        export_params=True,        # store trained weights in the file
        opset_version=opset_version,
    )
if __name__ == '__main__':
    model = Res34UNet(num_classes=1)

    # Export the ONNX model; create the target directory first so the export
    # does not fail when it is missing.
    onnx_path = "../save_weights/Res34UNet.onnx"
    os.makedirs(os.path.dirname(onnx_path), exist_ok=True)
    convert_onnx(model, onnx_path)

    # Report model size. Buffers (e.g. BatchNorm running statistics and
    # batch counters) are state, not parameters, so they are counted
    # separately. "M" means millions (1e6).
    total_params = sum(p.numel() for p in model.parameters())
    total_buffers = sum(b.numel() for b in model.buffers())
    total_trainable_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad)
    print(f'{total_params:,} total parameters.')
    print(f'{total_params / 1e6:.2f}M total parameters.')
    print(f'{total_buffers:,} buffer elements (not parameters).')
    print(f'{total_trainable_params:,} training parameters.')
    print(f'{total_trainable_params / 1e6:.2f}M training parameters.')
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/yurj0403/deep-learning.git
git@gitee.com:yurj0403/deep-learning.git
yurj0403
deep-learning
DeepLearning
master

Search