代码拉取完成,页面将自动刷新
# 1.输入图像预处理,包括尺寸,旋转。
# 2.真实值ground truth变形,shape = (w,h,kp_num) = (64, 64, n)
# 3.返回一个发生器,用于给模型做输入,以及输出时做损失计算。
import os
import numpy as np
import pandas as pd
import torch
from skimage import io, transform # 用于图像的IO和变换
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from matplotlib import pyplot as plt
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class TestDataSet(Dataset):
"""服装关键点标记数据集"""
def __init__(self, csv_file, root_dir, num, transforms_img=None):
"""
初始化数据集
:param csv_file: 带标记的csv文件,为数据-category-标签coordination 成对组成的文件
:param root_dir: 图像数据目录
:param transform(callable,optional): 一个样本上的可用可选变换
"""
self.data_info = self.get_file_info(csv_file)
self.root_dir = root_dir
self.transform_img = transforms_img
def __len__(self):
return len(self.data_info[0])
def __getitem__(self, idx):
H, W = 64.0, 64.0
img_id = self.data_info[idx]
img_id = os.path.join(self.root_dir, img_id)
image = io.imread(img_id)
h, w, c = image.shape
image = self.change_img_size(image, H, W)
image = self.transform_img(image)
return image.float()
@staticmethod
def get_file_info(file_path):
file_info = pd.read_csv(file_path)
img_list = file_info.iloc[:, 0]
return img_list
@staticmethod
def change_img_size(image, h, w):
return transform.resize(image, (h, w))
class ToTensor(object):
"""将样本中的ndarrays转换为Tensors."""
def __call__(self, sample):
return torch.from_numpy(sample)
transform_img = transforms.Compose([
transforms.ToTensor(), # 将图像(Image)转成Tensor
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 将tensor标准化[-1,1]
])
dressDataset = TestDataSet(csv_file=r"E:\Datasets\Fashion\Fashion AI-keypoints\test\test_dress.csv",
root_dir=r"E:\Datasets\Fashion\Fashion AI-keypoints\test",
num=15,
transforms_img=transform_img,
)
skirtDataset = TestDataSet(csv_file=r"E:\Datasets\Fashion\Fashion AI-keypoints\test\test_skirt.csv",
root_dir=r"E:\Datasets\Fashion\Fashion AI-keypoints\test",
num=4,
transforms_img=transform_img,
)
trousersDataset = TestDataSet(csv_file=r"E:\Datasets\Fashion\Fashion AI-keypoints\test\test_trousers.csv",
root_dir=r"E:\Datasets\Fashion\Fashion AI-keypoints\test",
num=7,
transforms_img=transform_img,
)
outwearDataset = TestDataSet(csv_file=r"E:\Datasets\Fashion\Fashion AI-keypoints\test\test_outwear.csv",
root_dir=r"E:\Datasets\Fashion\Fashion AI-keypoints\test",
num=15,
transforms_img=transform_img,
)
blouseDataset = TestDataSet(csv_file=r"E:\Datasets\Fashion\Fashion AI-keypoints\test\test_blouse.csv",
root_dir=r"E:\Datasets\Fashion\Fashion AI-keypoints\test",
num=13,
transforms_img=transform_img,
)
dayiDataset = TestDataSet(csv_file=r"E:\Datasets\Fashion\Fashion AI-keypoints\test\test_dayi.csv",
root_dir=r"E:\Datasets\Fashion\Fashion AI-keypoints\test",
num=15,
transforms_img=transform_img,
)
dataloader = DataLoader(dataset=trousersDataset, batch_size=8, shuffle=True)
if __name__ == "__main__":
# test
for i_batch, data in enumerate(dataloader):
img = data
# print(type(img), landmarks)
if i_batch == 0:
picture = img.numpy()
maxPixle = picture.max()
picture = picture / maxPixle
mat = np.uint8(picture)
plt.figure(num='dress')
for i in range(8):
plt.subplot(2, 4, i + 1) # 将窗口分为两行两列四个子图,则可显示四幅图片
plt.title(str(i + 1)) # 第一幅图片标题
plt.imshow(mat[i].transpose(1, 2, 0)) # 绘制第一幅图片
plt.show()
print(img.shape)
break
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。