1 Star 9 Fork 7

Tanhaiyan/keyPointsDetectionMethodWithTorch

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
test_KPDM_model.py 1.55 KB
一键复制 编辑 原始数据 按行查看 历史
Tanhaiyan 提交于 10个月前 . update test_KPDM_model.py.
from keypoints_Net import CoordRegression
from data_process import *
import torch.nn as nn
import torch.optim as optim
import dsntnn
from matplotlib import pyplot as plt
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def show_result(coords, picture):
coords = coords * 32.0 + 31.5
maxPixle = picture.max()
picture = picture / maxPixle
mat = np.uint8(picture)
plt.figure(num='dress')
for i in range(8):
print(i)
plt.subplot(2, 4, i + 1) # 将窗口分为两行两列四个子图,则可显示四幅图片
plt.title(str(i + 1)) # 第一幅图片标题
plt.imshow(mat[i].transpose(1, 2, 0)) # 绘制第一幅图片
plt.scatter(coords[i][:, 0], coords[i][:, 1], marker="*", s=30, color="r")
plt.show()
return
def main(path=None):
model = torch.load(path)
model = torch.nn.DataParallel(model).cuda()
model.eval()
from data_test_fashion_hw import skirtDataset
dataloader_val = DataLoader(dataset=skirtDataset, batch_size=8, shuffle=False)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 验证
with torch.no_grad():
model.eval()
for i_batch, data in enumerate(dataloader_val):
img = data
img = img.to(device)
coords, _ = model(img)
img = img.cpu()
coords = coords.cpu()
show_result(coords.detach().numpy(), img.numpy())
if __name__ == "__main__":
model_path = r"E:\model\KPDEM\skirt_kp58.pth"
main(model_path)
print("The end!")
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/rpr/key-points-detection-method-with-torch.git
git@gitee.com:rpr/key-points-detection-method-with-torch.git
rpr
key-points-detection-method-with-torch
keyPointsDetectionMethodWithTorch
master

搜索帮助