1 Star 0 Fork 0

winday00/pytorch-pwc

加入 Gitee
与超过 1400万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
run.py 14.01 KB
一键复制 编辑 原始数据 按行查看 历史
Simon Niklaus 提交于 2021-07-23 23:51 +08:00 . no message
#!/usr/bin/env python
import torch
import getopt
import math
import numpy
import os
import PIL
import PIL.Image
import sys
# Import the custom cost volume layer. The relative import only works when this
# file is loaded as part of a package; when it is executed as a plain script the
# relative import fails and the layer is loaded from the ./correlation folder.
try:
    from .correlation import correlation # the custom cost volume layer
except (ImportError, SystemError, ValueError):
    # a failed relative import raises ImportError on current python versions,
    # older interpreters raised SystemError or ValueError instead; anything
    # else (for example KeyboardInterrupt) is deliberately not swallowed
    sys.path.insert(0, './correlation')
    import correlation # you should consider upgrading python
# end
##########################################################
# requires at least pytorch version 1.3.0; the (major, minor) pair is compared
# as integers so that the check orders versions correctly instead of relying on
# the concatenated digits of the version string
assert(tuple(int(strPart) for strPart in torch.__version__.split('.')[0:2]) >= (1, 3))

torch.set_grad_enabled(False) # make sure to not compute gradients for computational performance

torch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance
##########################################################
arguments_strModel = 'default' # 'default', or 'chairs-things'
arguments_strOne = './images/one.png'
arguments_strTwo = './images/two.png'
arguments_strOut = './out.flo'
for strOption, strArgument in getopt.getopt(sys.argv[1:], '', [ strParameter[2:] + '=' for strParameter in sys.argv[1::2] ])[0]:
if strOption == '--model' and strArgument != '': arguments_strModel = strArgument # which model to use
if strOption == '--one' and strArgument != '': arguments_strOne = strArgument # path to the first frame
if strOption == '--two' and strArgument != '': arguments_strTwo = strArgument # path to the second frame
if strOption == '--out' and strArgument != '': arguments_strOut = strArgument # path to where the output should be stored
# end
##########################################################
# caches for backwarp(), keyed by the shape, device and dtype of the flow tensor
backwarp_tenGrid = {} # normalized sampling grids of shape (1, 2, height, width)
backwarp_tenPartial = {} # all-ones channels used to detect out-of-bounds samples

def backwarp(tenInput, tenFlow):
    """Sample tenInput at the locations given by the identity grid plus tenFlow.

    tenInput: tensor of shape (batch, channels, height, width).
    tenFlow: tensor of shape (batch, 2, height, width); channel 0 is the
        horizontal and channel 1 the vertical displacement, which are divided
        by half of (size - 1) of tenInput to obtain grid_sample coordinates.

    Returns a tensor with the shape of tenInput in which every location whose
    sample touched the zero padding is set to zero.

    The tensors are processed on whatever device tenFlow lives on, the cached
    grid and ones channel are created per (shape, device, dtype) combination.
    """

    strKey = str(tenFlow.shape) + str(tenFlow.device) + str(tenFlow.dtype)

    if strKey not in backwarp_tenGrid:
        intHeight = tenFlow.shape[2]
        intWidth = tenFlow.shape[3]

        # coordinates of the pixel centers in the [-1, 1] range used by grid_sample with align_corners=False
        tenHor = torch.linspace(-1.0 + (1.0 / intWidth), 1.0 - (1.0 / intWidth), intWidth).view(1, 1, 1, -1).expand(-1, -1, intHeight, -1)
        tenVer = torch.linspace(-1.0 + (1.0 / intHeight), 1.0 - (1.0 / intHeight), intHeight).view(1, 1, -1, 1).expand(-1, -1, -1, intWidth)

        backwarp_tenGrid[strKey] = torch.cat([ tenHor, tenVer ], 1).to(device=tenFlow.device, dtype=tenFlow.dtype)
    # end

    if strKey not in backwarp_tenPartial:
        backwarp_tenPartial[strKey] = tenFlow.new_ones([ tenFlow.shape[0], 1, tenFlow.shape[2], tenFlow.shape[3] ])
    # end

    # NOTE(review): the normalization by (size - 1) / 2 is kept exactly as is, the
    # pretrained weights were presumably obtained with it - confirm before changing
    tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1)

    # the appended channel of ones is warped together with the input, it drops
    # below one wherever the bilinear sample reached into the zero padding
    tenInput = torch.cat([ tenInput, backwarp_tenPartial[strKey] ], 1)

    tenOutput = torch.nn.functional.grid_sample(input=tenInput, grid=(backwarp_tenGrid[strKey] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=False)

    # binarize the warped ones channel: 1.0 where the sample was (almost) fully inside, 0.0 elsewhere
    tenMask = (tenOutput[:, -1:, :, :] > 0.999).to(tenOutput.dtype)

    return tenOutput[:, :-1, :, :] * tenMask
# end
##########################################################
class Network(torch.nn.Module):
    """Optical flow network made of a feature extractor, five decoders and a refiner.

    The extractor turns each frame into six feature maps, the decoders estimate
    the flow from the last feature map to the second one while passing their
    result on to the next decoder, and the refiner adds a correction to the
    flow of the final decoder.

    Constructing the network loads the weights selected by the module-level
    arguments_strModel through torch.hub.load_state_dict_from_url.
    """

    def __init__(self):
        super().__init__()

        def make_activation():
            # the activation that follows every convolution except the ones that output a flow
            return torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
        # end

        class Extractor(torch.nn.Module):
            """Six consecutive stages, each starting with a stride-2 convolution."""

            def __init__(self):
                super().__init__()

                def make_stage(intIn, intOut):
                    # one stride-2 convolution followed by two stride-1 convolutions
                    return torch.nn.Sequential(
                        torch.nn.Conv2d(in_channels=intIn, out_channels=intOut, kernel_size=3, stride=2, padding=1),
                        make_activation(),
                        torch.nn.Conv2d(in_channels=intOut, out_channels=intOut, kernel_size=3, stride=1, padding=1),
                        make_activation(),
                        torch.nn.Conv2d(in_channels=intOut, out_channels=intOut, kernel_size=3, stride=1, padding=1),
                        make_activation()
                    )
                # end

                self.netOne = make_stage(3, 16)
                self.netTwo = make_stage(16, 32)
                self.netThr = make_stage(32, 64)
                self.netFou = make_stage(64, 96)
                self.netFiv = make_stage(96, 128)
                self.netSix = make_stage(128, 196)
            # end

            def forward(self, tenInput):
                """Return the outputs of all six stages, from the first to the last stage."""

                tenPyramid = []

                for netStage in [ self.netOne, self.netTwo, self.netThr, self.netFou, self.netFiv, self.netSix ]:
                    tenInput = netStage(tenInput)
                    tenPyramid.append(tenInput)
                # end

                return tenPyramid
            # end
        # end

        class Decoder(torch.nn.Module):
            """Flow estimator of one level; intLevel ranges from 6 (first to run) to 2 (last to run)."""

            def __init__(self, intLevel):
                super().__init__()

                # number of input channels of the decoder at each level, indexed by level
                # (presumably 81 correlation channels + feature channels + 2 flow + 2 upsampled feature channels)
                intChannels = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ]

                intPrevious = intChannels[intLevel + 1]
                intCurrent = intChannels[intLevel + 0]

                if intLevel < 6:
                    # only decoders that receive a previous estimate need to upsample it
                    self.netUpflow = torch.nn.ConvTranspose2d(in_channels=2, out_channels=2, kernel_size=4, stride=2, padding=1)
                    self.netUpfeat = torch.nn.ConvTranspose2d(in_channels=intPrevious + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=4, stride=2, padding=1)
                    self.fltBackwarp = [ None, None, None, 5.0, 2.5, 1.25, 0.625, None ][intLevel + 1]
                # end

                # densely connected convolutions: every layer sees the input
                # of the previous layer concatenated with its output
                intFeat = intCurrent

                for strName, intOut in [ ('netOne', 128), ('netTwo', 128), ('netThr', 96), ('netFou', 64), ('netFiv', 32) ]:
                    setattr(self, strName, torch.nn.Sequential(
                        torch.nn.Conv2d(in_channels=intFeat, out_channels=intOut, kernel_size=3, stride=1, padding=1),
                        make_activation()
                    ))

                    intFeat += intOut
                # end

                self.netSix = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=intFeat, out_channels=2, kernel_size=3, stride=1, padding=1)
                )
            # end

            def forward(self, tenOne, tenTwo, objPrevious):
                """Estimate the flow of this level.

                tenOne / tenTwo: feature maps of the first / second frame at this level.
                objPrevious: the dictionary returned by the preceding decoder, or
                    None for the decoder that runs first.

                Returns a dictionary with the estimated flow ('tenFlow') and the
                concatenated features it was computed from ('tenFeat').
                """

                if objPrevious is None:
                    tenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenOne=tenOne, tenTwo=tenTwo), negative_slope=0.1, inplace=False)

                    tenFeat = torch.cat([ tenVolume ], 1)

                else:
                    tenFlow = self.netUpflow(objPrevious['tenFlow'])
                    tenUpfeat = self.netUpfeat(objPrevious['tenFeat'])

                    # warp the second feature map with the upsampled flow before correlating it
                    tenWarped = backwarp(tenInput=tenTwo, tenFlow=tenFlow * self.fltBackwarp)

                    tenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenOne=tenOne, tenTwo=tenWarped), negative_slope=0.1, inplace=False)

                    tenFeat = torch.cat([ tenVolume, tenOne, tenFlow, tenUpfeat ], 1)
                # end

                for netLayer in [ self.netOne, self.netTwo, self.netThr, self.netFou, self.netFiv ]:
                    tenFeat = torch.cat([ netLayer(tenFeat), tenFeat ], 1)
                # end

                return {
                    'tenFlow': self.netSix(tenFeat),
                    'tenFeat': tenFeat
                }
            # end
        # end

        class Refiner(torch.nn.Module):
            """Stack of dilated convolutions that maps the final decoder features to a flow correction."""

            def __init__(self):
                super().__init__()

                # (input channels, output channels, dilation), the padding always equals the dilation
                objLayers = [
                    (81 + 32 + 2 + 2 + 128 + 128 + 96 + 64 + 32, 128, 1),
                    (128, 128, 2),
                    (128, 128, 4),
                    (128, 96, 8),
                    (96, 64, 16),
                    (64, 32, 1),
                    (32, 2, 1)
                ]

                netLayers = []

                for intLayer, (intIn, intOut, intDilation) in enumerate(objLayers):
                    netLayers.append(torch.nn.Conv2d(in_channels=intIn, out_channels=intOut, kernel_size=3, stride=1, padding=intDilation, dilation=intDilation))

                    if intLayer < len(objLayers) - 1:
                        netLayers.append(make_activation()) # no activation after the last convolution
                    # end
                # end

                self.netMain = torch.nn.Sequential(*netLayers)
            # end

            def forward(self, tenInput):
                return self.netMain(tenInput)
            # end
        # end

        self.netExtractor = Extractor()

        self.netTwo = Decoder(2)
        self.netThr = Decoder(3)
        self.netFou = Decoder(4)
        self.netFiv = Decoder(5)
        self.netSix = Decoder(6)

        self.netRefiner = Refiner()

        # NOTE(review): the weights are fetched over plain http and then deserialized, consider
        # verifying the downloaded file (for example by hash) if the source is not trusted
        objWeights = torch.hub.load_state_dict_from_url(url='http://content.sniklaus.com/github/pytorch-pwc/network-' + arguments_strModel + '.pytorch', file_name='pwc-' + arguments_strModel)

        # the keys of the stored weights use 'module' where this class uses 'net'
        self.load_state_dict({ strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in objWeights.items() })
    # end

    def forward(self, tenOne, tenTwo):
        """Return the network output for two frames.

        Both inputs are passed through the extractor, the decoders then run from
        the last feature map to the fifth from last one, and the refiner output is
        added to the flow of the final decoder.
        """

        tenOne = self.netExtractor(tenOne)
        tenTwo = self.netExtractor(tenTwo)

        objEstimate = None

        # decoder number one consumes the last feature map, number two the second to last one, and so on
        for intIndex, netDecoder in enumerate([ self.netSix, self.netFiv, self.netFou, self.netThr, self.netTwo ], 1):
            objEstimate = netDecoder(tenOne[-intIndex], tenTwo[-intIndex], objEstimate)
        # end

        return objEstimate['tenFlow'] + self.netRefiner(objEstimate['tenFeat'])
    # end
# end
netNetwork = None # the network is created lazily on the first call to estimate()

##########################################################

def estimate(tenOne, tenTwo):
    """Estimate the optical flow between two frames.

    tenOne / tenTwo: tensors that are viewed as (3, height, width), both need
        to have the same height and width.

    Returns a tensor of shape (2, height, width) on the cpu. The frames are
    resized to the next multiple of 64 in each dimension before being passed
    to the network, and the result is resized back and rescaled accordingly.
    """

    global netNetwork

    if netNetwork is None:
        netNetwork = Network().cuda().eval()
    # end

    assert(tenOne.shape[1] == tenTwo.shape[1])
    assert(tenOne.shape[2] == tenTwo.shape[2])

    intHeight = tenOne.shape[1]
    intWidth = tenOne.shape[2]

    assert(intWidth == 1024) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue
    assert(intHeight == 436) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue

    # round the resolution up to the next multiple of 64
    intPreprocessedWidth = int(math.ceil(intWidth / 64.0)) * 64
    intPreprocessedHeight = int(math.ceil(intHeight / 64.0)) * 64

    tenPreprocessed = []

    for tenFrame in [ tenOne, tenTwo ]:
        tenFrame = tenFrame.cuda().view(1, 3, intHeight, intWidth)

        tenPreprocessed.append(torch.nn.functional.interpolate(input=tenFrame, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False))
    # end

    tenPrediction = netNetwork(tenPreprocessed[0], tenPreprocessed[1])

    # resize the prediction to the input resolution, the network output is multiplied by 20.0
    tenFlow = 20.0 * torch.nn.functional.interpolate(input=tenPrediction, size=(intHeight, intWidth), mode='bilinear', align_corners=False)

    # the flow was estimated at the preprocessed resolution, so rescale each
    # component (0: horizontal, 1: vertical) to the input resolution
    fltScales = [ float(intWidth) / float(intPreprocessedWidth), float(intHeight) / float(intPreprocessedHeight) ]

    for intChannel, fltScale in enumerate(fltScales):
        tenFlow[:, intChannel, :, :] *= fltScale
    # end

    return tenFlow[0, :, :, :].cpu()
# end
##########################################################
if __name__ == '__main__':
    # read both frames as float tensors of shape (3, height, width) with values
    # scaled by 1 / 255 and the channel order reversed (the [:, :, ::-1] below)
    tenFrames = []

    for strPath in [ arguments_strOne, arguments_strTwo ]:
        with PIL.Image.open(strPath) as objImage:
            # converting to RGB makes images with an alpha channel, a palette
            # or a single channel usable, images that are RGB already are unaffected
            npyFrame = numpy.array(objImage.convert('RGB'))
        # end

        npyFrame = numpy.ascontiguousarray(npyFrame[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0))

        tenFrames.append(torch.FloatTensor(npyFrame))
    # end

    tenOne, tenTwo = tenFrames

    tenOutput = estimate(tenOne, tenTwo)

    # the output consists of the four bytes 'PIEH', the width and the height as
    # int32, and then the flow as float32 in (height, width, 2) order; the with
    # statement makes sure that the file is closed even if writing fails
    with open(arguments_strOut, 'wb') as objOutput:
        numpy.array([ 80, 73, 69, 72 ], numpy.uint8).tofile(objOutput)
        numpy.array([ tenOutput.shape[2], tenOutput.shape[1] ], numpy.int32).tofile(objOutput)
        numpy.array(tenOutput.numpy().transpose(1, 2, 0), numpy.float32).tofile(objOutput)
    # end
# end
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/winday00/pytorch-pwc.git
git@gitee.com:winday00/pytorch-pwc.git
winday00
pytorch-pwc
pytorch-pwc
master

搜索帮助