# DispNetS.py (repo: lhx/simplified_struct2depth)
# necroen, committed 2019-05-17 19:42, commit message: init
# The disp network does not care whether the input is the ref or the tgt frame:
# given a single image, it produces a single disparity map.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, zeros_
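# Architecture note: this is a DispNet-S style encoder-decoder, as used in
# SfMLearner / struct2depth style self-supervised depth pipelines: a 7-stage
# strided-conv encoder, a transposed-conv decoder with skip connections, and
# disparity predictions at four scales.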
def downsample_conv(in_planes, out_planes, kernel_size=3):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=(kernel_size-1)//2),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size-1)//2),
        nn.ReLU(inplace=True)
    )
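# Quick shape check (a sketch, not part of the original): the stride-2 conv
# halves each spatial dim for even inputs, and the second conv is size-preserving,
# so downsample_conv(3, 32)(torch.rand(1, 3, 128, 416)) -> [1, 32, 64, 208].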
def predict_disp(in_planes):
    return nn.Sequential(
        nn.Conv2d(in_planes, 1, kernel_size=3, padding=1),
        nn.Sigmoid()
    )
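# The sigmoid keeps the raw prediction in (0, 1); forward() rescales it to
# alpha * sigmoid + beta, i.e. the range (beta, alpha + beta) = (0.01, 10.01)
# with the defaults below, which bounds disparity away from zero.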
def conv(in_planes, out_planes):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1),
        nn.ReLU(inplace=True)
    )
def upconv(in_planes, out_planes):
    return nn.Sequential(
        nn.ConvTranspose2d(in_planes, out_planes, kernel_size=3, stride=2, padding=1, output_padding=1),
        nn.ReLU(inplace=True)
    )
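# With these settings, ConvTranspose2d gives H_out = (H-1)*2 - 2*1 + 3 + 1 = 2H,
# so each upconv exactly doubles the spatial resolution.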
def crop_like(input, ref):
    assert input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3)
    return input[:, :, :ref.size(2), :ref.size(3)]
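# crop_like trims an upsampled feature map to the size of its skip connection.
# It is needed whenever an encoder feature has an odd spatial size: for a
# 416-wide input, out_conv6 is 7 wide, upconv6 doubles that to 14, while
# out_conv5 is only 13 wide, so the extra column is cropped off.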
class DispNetS(nn.Module):
    def __init__(self, alpha=10, beta=0.01):
        super(DispNetS, self).__init__()
        self.alpha = alpha
        self.beta = beta

        conv_planes = [32, 64, 128, 256, 512, 512, 512]
        self.conv1 = downsample_conv(3, conv_planes[0], kernel_size=7)
        self.conv2 = downsample_conv(conv_planes[0], conv_planes[1], kernel_size=5)
        self.conv3 = downsample_conv(conv_planes[1], conv_planes[2])
        self.conv4 = downsample_conv(conv_planes[2], conv_planes[3])
        self.conv5 = downsample_conv(conv_planes[3], conv_planes[4])
        self.conv6 = downsample_conv(conv_planes[4], conv_planes[5])
        self.conv7 = downsample_conv(conv_planes[5], conv_planes[6])

        upconv_planes = [512, 512, 256, 128, 64, 32, 16]
        self.upconv7 = upconv(conv_planes[6], upconv_planes[0])
        self.upconv6 = upconv(upconv_planes[0], upconv_planes[1])
        self.upconv5 = upconv(upconv_planes[1], upconv_planes[2])
        self.upconv4 = upconv(upconv_planes[2], upconv_planes[3])
        self.upconv3 = upconv(upconv_planes[3], upconv_planes[4])
        self.upconv2 = upconv(upconv_planes[4], upconv_planes[5])
        self.upconv1 = upconv(upconv_planes[5], upconv_planes[6])

        self.iconv7 = conv(upconv_planes[0] + conv_planes[5], upconv_planes[0])
        self.iconv6 = conv(upconv_planes[1] + conv_planes[4], upconv_planes[1])
        self.iconv5 = conv(upconv_planes[2] + conv_planes[3], upconv_planes[2])
        self.iconv4 = conv(upconv_planes[3] + conv_planes[2], upconv_planes[3])
        self.iconv3 = conv(1 + upconv_planes[4] + conv_planes[1], upconv_planes[4])
        self.iconv2 = conv(1 + upconv_planes[5] + conv_planes[0], upconv_planes[5])
        self.iconv1 = conv(1 + upconv_planes[6], upconv_planes[6])

        self.predict_disp4 = predict_disp(upconv_planes[3])
        self.predict_disp3 = predict_disp(upconv_planes[4])
        self.predict_disp2 = predict_disp(upconv_planes[5])
        self.predict_disp1 = predict_disp(upconv_planes[6])
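    # Scale note: the four heads predict disparity at 1/8, 1/4, 1/2 and full
    # input resolution (predict_disp4 through predict_disp1), matching the
    # sizes recorded in the comments at the end of forward().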
    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                xavier_uniform_(m.weight)
                if m.bias is not None:
                    zeros_(m.bias)
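    # Note: init_weights() is not called from __init__, so callers must invoke
    # it explicitly (as the __main__ block below does) to get Xavier init.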
    def forward(self, x):
        # Encoder: seven stride-2 stages, each halving the spatial resolution.
        out_conv1 = self.conv1(x)
        out_conv2 = self.conv2(out_conv1)
        out_conv3 = self.conv3(out_conv2)
        out_conv4 = self.conv4(out_conv3)
        out_conv5 = self.conv5(out_conv4)
        out_conv6 = self.conv6(out_conv5)
        out_conv7 = self.conv7(out_conv6)

        # Decoder: upsample, concatenate the matching encoder feature (skip
        # connection), fuse with iconv; from scale 4 down, also concatenate
        # the upsampled coarser disparity and emit a prediction.
        out_upconv7 = crop_like(self.upconv7(out_conv7), out_conv6)
        concat7 = torch.cat((out_upconv7, out_conv6), 1)
        out_iconv7 = self.iconv7(concat7)

        out_upconv6 = crop_like(self.upconv6(out_iconv7), out_conv5)
        concat6 = torch.cat((out_upconv6, out_conv5), 1)
        out_iconv6 = self.iconv6(concat6)

        out_upconv5 = crop_like(self.upconv5(out_iconv6), out_conv4)
        concat5 = torch.cat((out_upconv5, out_conv4), 1)
        out_iconv5 = self.iconv5(concat5)

        out_upconv4 = crop_like(self.upconv4(out_iconv5), out_conv3)
        concat4 = torch.cat((out_upconv4, out_conv3), 1)
        out_iconv4 = self.iconv4(concat4)
        disp4 = self.alpha * self.predict_disp4(out_iconv4) + self.beta

        out_upconv3 = crop_like(self.upconv3(out_iconv4), out_conv2)
        disp4_up = crop_like(F.interpolate(disp4, scale_factor=2, mode='bilinear', align_corners=False), out_conv2)
        concat3 = torch.cat((out_upconv3, out_conv2, disp4_up), 1)
        out_iconv3 = self.iconv3(concat3)
        disp3 = self.alpha * self.predict_disp3(out_iconv3) + self.beta

        out_upconv2 = crop_like(self.upconv2(out_iconv3), out_conv1)
        disp3_up = crop_like(F.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
        concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1)
        out_iconv2 = self.iconv2(concat2)
        disp2 = self.alpha * self.predict_disp2(out_iconv2) + self.beta

        out_upconv1 = crop_like(self.upconv1(out_iconv2), x)
        disp2_up = crop_like(F.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x)
        concat1 = torch.cat((out_upconv1, disp2_up), 1)
        out_iconv1 = self.iconv1(concat1)
        disp1 = self.alpha * self.predict_disp1(out_iconv1) + self.beta

        # print("disp1 size:", disp1.size())  # torch.Size([1, 1, 128, 416])
        # print("disp2 size:", disp2.size())  # torch.Size([1, 1, 64, 208])
        # print("disp3 size:", disp3.size())  # torch.Size([1, 1, 32, 104])
        # print("disp4 size:", disp4.size())  # torch.Size([1, 1, 16, 52])
        return [disp1, disp2, disp3, disp4], out_conv7
        # multiscale_disps_i, bottleneck
        # [1,1,128,416], [1,1,64,208], [1,1,32,104], [1,1,16,52] when batch size = 1
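# The second return value (out_conv7) is the encoder bottleneck; the caller's
# naming ("bottleneck") suggests it feeds another network in the training
# pipeline, which is why it is exposed alongside the multi-scale disparities.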
if __name__ == '__main__':
    disp_net = DispNetS()
    disp_net.init_weights()
    disp_net.train()

    inputs = torch.rand([1, 3, 128, 416])
    disp_lists, out_conv7 = disp_net(inputs)
    print("output size:", disp_lists[0].size())
    # To export the graph for TensorBoard:
    # from tensorboardX import SummaryWriter
    # with SummaryWriter(comment='depth_net') as w:
    #     w.add_graph(disp_net, (inputs,))
    # Then run: tensorboard --logdir=./runs --port=6006  (on Windows as well)
    # and open http://localhost:6006/