1 Star 1 Fork 1

郭丰/pytorch-superpoint

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
Val_model_subpixel.py 3.41 KB
一键复制 编辑 原始数据 按行查看 历史
You-Yi Jau 提交于 5年前 . clean up base paths
"""script for subpixel experiment (not tested)
"""
import numpy as np
import torch
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import torch.optim
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from tqdm import tqdm
from utils.loader import dataLoader, modelLoader, pretrainedLoader
import logging
from utils.tools import dict_update
from utils.utils import labels2Dto3D, flattenDetection, labels2Dto3D_flattened
from utils.utils import pltImshow, saveImg
from utils.utils import precisionRecall_torch
from utils.utils import save_checkpoint
from pathlib import Path
class Val_model_subpixel(object):
    """Inference wrapper around a pretrained subpixel network.

    Builds the network named in ``config``, restores its weights from a
    checkpoint, and runs it on image patches cut out around given points.

    NOTE: ``torch.no_grad()`` is applied to the tensor-handling methods, not
    to the class. Decorating the class itself rebinds the class name to a
    wrapper function (breaking ``isinstance`` and subclassing), only covers
    construction, and is deprecated by PyTorch.

    Args:
        config: mapping with the keys ``'name'`` (model name handed to
            ``modelLoader``), ``'params'`` (keyword arguments for the model;
            ``extract_patches`` also reads ``'patch_size'`` from it) and
            ``'pretrained'`` (path of the checkpoint file).
        device: device the network and the patch inputs are moved to.
        verbose: accepted for interface compatibility; currently unused.
    """

    def __init__(self, config, device='cpu', verbose=False):
        self.config = config
        self.model = self.config['name']
        self.params = self.config['params']
        self.weights_path = self.config['pretrained']
        self.device = device

    def loadModel(self):
        """Build the network, restore its weights and move it to ``self.device``.

        The checkpoint must store the weights under ``'model_state_dict'``.
        The network is left in evaluation mode.
        """
        from utils.loader import modelLoader

        self.net = modelLoader(model=self.model, **self.params)

        # Deserialize onto CPU storage first so that a checkpoint written on
        # a GPU machine can still be opened on a CPU-only host.
        checkpoint = torch.load(self.weights_path,
                                map_location=lambda storage, loc: storage)
        self.net.load_state_dict(checkpoint['model_state_dict'])

        self.net = self.net.to(self.device)
        # This class is inference-only: switch BatchNorm/Dropout style layers
        # to their evaluation behaviour.
        self.net.eval()
        logging.info('successfully load pretrained model from: %s', self.weights_path)

    @torch.no_grad()
    def extract_patches(self, label_idx, img):
        """Cut patches out of ``img`` around the points in ``label_idx``.

        Both inputs are moved to ``self.device`` before extraction; the patch
        size is read from ``config['params']['patch_size']``.

        Args:
            label_idx: tensor [N, 4], rows are (batch, 0, y, x).
            img: tensor [batch, channel(1), H, W].

        Returns:
            The result of ``utils.losses.extract_patches`` for these inputs.
        """
        from utils.losses import extract_patches

        patch_size = self.config['params']['patch_size']
        return extract_patches(label_idx.to(self.device), img.to(self.device),
                               patch_size=patch_size)

    @torch.no_grad()
    def run(self, patches):
        """Run the loaded network on ``patches`` without tracking gradients.

        ``loadModel`` must have been called first so that ``self.net`` exists.

        Args:
            patches: input passed unchanged to the network.

        Returns:
            The network output for ``patches``.
        """
        return self.net(patches)
if __name__ == '__main__':
    # Smoke-test script: load the config, build the hpatches test loader and
    # push a couple of hard-coded points through the subpixel network.
    import yaml

    def points_to_4d(points):
        """Prepend two zero columns to [N, 2] points, giving float [N, 4] rows.

        The result follows the (batch, 0, y, x) layout that
        ``Val_model_subpixel.extract_patches`` documents, with batch index 0.
        """
        zeros = torch.zeros(points.shape[0], 1).float()
        return torch.cat((zeros, zeros, points.float()), dim=1)

    # filename = 'configs/magicpoint_shapes_subpix.yaml'
    filename = 'configs/magicpoint_repeatability.yaml'

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.set_default_tensor_type(torch.FloatTensor)

    with open(filename, 'r') as f:
        # safe_load: a bare yaml.load(f) needs an explicit Loader on
        # PyYAML >= 6 and can construct arbitrary objects on older versions.
        config = yaml.safe_load(f)

    # data loading
    from utils.loader import dataLoader_test as dataLoader
    data = dataLoader(config, dataset='hpatches')
    test_loader = data['test_loader']

    # The agent does not depend on the sample, so build it and load the
    # checkpoint once instead of once per batch.
    val_agent = Val_model_subpixel(config['subpixel'], device=device)
    val_agent.loadModel()

    # only the first two batches are processed
    for i, sample in tqdm(enumerate(test_loader)):
        if i > 1:
            break

        # points from heatmap (hard-coded placeholders for now)
        img = sample['image']
        print("image: ", img.shape)
        points = torch.tensor([[1, 2], [3, 4]])

        # concat points to be (batch, 0, y, x)
        label_idx = points_to_4d(points)
        patches = val_agent.extract_patches(label_idx, img)
        points_res = val_agent.run(patches)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/wind_x/pytorch-superpoint.git
git@gitee.com:wind_x/pytorch-superpoint.git
wind_x
pytorch-superpoint
pytorch-superpoint
master

搜索帮助