代码拉取完成,页面将自动刷新
from keypoints_Net import CoordRegression
from data_process import *
import torch.nn as nn
import torch.optim as optim
import dsntnn
from matplotlib import pyplot as plt
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def show_result(coords, picture):
coords = coords * 32.0 + 31.5
maxPixle = picture.max()
picture = picture / maxPixle
mat = np.uint8(picture)
plt.figure(num='dress')
for i in range(8):
print(i)
plt.subplot(2, 4, i + 1) # 将窗口分为两行两列四个子图,则可显示四幅图片
plt.title(str(i + 1)) # 第一幅图片标题
plt.imshow(mat[i].transpose(1, 2, 0)) # 绘制第一幅图片
plt.scatter(coords[i][:, 0], coords[i][:, 1], marker="*", s=30, color="r")
plt.show()
return
def main(path=None):
model = torch.load(path)
model = torch.nn.DataParallel(model).cuda()
model.eval()
from data_test_fashion_hw import skirtDataset
dataloader_val = DataLoader(dataset=skirtDataset, batch_size=8, shuffle=False)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 验证
with torch.no_grad():
model.eval()
for i_batch, data in enumerate(dataloader_val):
img = data
img = img.to(device)
coords, _ = model(img)
img = img.cpu()
coords = coords.cpu()
show_result(coords.detach().numpy(), img.numpy())
if __name__ == "__main__":
model_path = r"E:\model\KPDEM\skirt_kp58.pth"
main(model_path)
print("The end!")
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。