1 Star 9 Fork 7

Tanhaiyan/keyPointsDetectionMethodWithTorch

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
data_process_fashion_hw.py 5.28 KB
一键复制 编辑 原始数据 按行查看 历史
# 1.输入图像预处理:把图像缩放到统一尺寸(默认 64×64;旋转尚未实现)。
# 2.真实值 ground truth 随图像缩放同步变换,每个样本为 shape = (kp_num, 3) 的关键点数组。
# 3.通过 DataLoader 按批提供数据,用于模型输入以及输出时的损失计算。
import os
import numpy as np
import pandas as pd
import torch
from skimage import io, transform # 用于图像的IO和变换
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from matplotlib import pyplot as plt
# Prefer the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
class KeyPointsDataSet(Dataset):
    """Fashion keypoint-annotation dataset.

    Each sample is an image resized to ``output_size`` together with its
    keypoint table, whose first two columns are rescaled from the original
    image size to the resized image size.
    """

    # Default (height, width) that images are resized to.
    output_size = (64, 64)

    def __init__(self, csv_file, root_dir, num, transforms_img=None, output_size=(64, 64)):
        """Initialise the dataset.

        :param csv_file: annotation CSV (path or buffer accepted by
            ``pandas.read_csv``). Column 0 is the image path relative to
            ``root_dir``. Column 1 is skipped (presumably the category).
            Columns 2 .. num+1 each hold one ``"a_b_c"`` keypoint cell.
        :param root_dir: directory that the image paths in ``csv_file`` are
            relative to.
        :param num: number of keypoint columns to read from each row.
        :param transforms_img: optional callable applied to the resized image
            ndarray; it must return a tensor. When ``None``, the image is
            converted to a (C, H, W) tensor directly.
        :param output_size: (height, width) that every image is resized to.
        """
        self.data_info = self.get_file_info(csv_file, num)
        self.root_dir = root_dir
        self.transform_img = transforms_img
        self.output_size = (int(output_size[0]), int(output_size[1]))

    def __len__(self):
        """Return the number of annotated images (data rows in the CSV)."""
        return len(self.data_info[0])

    def __getitem__(self, idx):
        """Return ``(image, landmarks)`` for row ``idx``.

        ``image`` is a float32 tensor. ``landmarks`` is a tensor of shape
        (num, 3): columns 0 and 1 are rescaled to ``output_size`` and
        column 2 is passed through unchanged.
        """
        out_h, out_w = self.output_size
        # .iloc is positional, so it does not depend on the Series' index labels.
        img_path = os.path.join(self.root_dir, self.data_info[0].iloc[idx])
        image = io.imread(img_path)
        # shape[:2] also accepts single-channel (2-D) images, which the
        # previous ``h, w, c = image.shape`` unpacking rejected.
        h, w = image.shape[:2]
        # Always copy. np.asfortranarray returned a *view* when the row was
        # already Fortran-contiguous (num == 1), so the in-place rescaling
        # below corrupted the cached annotations on every access.
        landmarks = np.array(self.data_info[1][idx], copy=True)
        image = self.change_img_size(image, out_h, out_w)
        # NOTE: landmarks keep their integer dtype, so scaled values are
        # truncated toward zero (values in (-1, 1) become 0).
        landmarks[:, 0] = landmarks[:, 0] * out_w / w
        landmarks[:, 1] = landmarks[:, 1] * out_h / h
        if self.transform_img is not None:
            image = self.transform_img(image)
        else:
            image = self._to_chw_tensor(image)
        return image.float(), torch.tensor(landmarks)

    @staticmethod
    def get_file_info(file_path, num):
        """Parse the annotation CSV.

        :param file_path: path or buffer accepted by ``pandas.read_csv``.
        :param num: number of keypoint columns, starting at column 2.
        :return: ``(img_list, landmarks)``. ``img_list`` is the first column
            as a pandas Series. ``landmarks`` is an integer array of shape
            (rows, num, 3) parsed from the ``"a_b_c"`` cells.
        :raises ValueError: if fewer than ``num`` keypoint columns exist, or
            if a cell does not hold exactly three ``_``-separated integers.
        """
        file_info = pd.read_csv(file_path)
        img_list = file_info.iloc[:, 0]
        cells = file_info.iloc[:, 2:num + 2].values
        # Without this check, reshape() below could silently misalign keypoints.
        if cells.shape[1] != num:
            raise ValueError(
                f"expected {num} keypoint columns, found {cells.shape[1]}"
            )
        rows = []
        for row in cells:
            parsed = []
            for cell in row:
                parts = str(cell).split('_')
                if len(parts) != 3:
                    raise ValueError(f"expected an 'a_b_c' keypoint cell, got {cell!r}")
                parsed.append([int(part) for part in parts])
            rows.append(parsed)
        landmarks = np.array(rows, dtype=int).reshape((-1, num, 3))
        return img_list, landmarks

    @staticmethod
    def change_img_size(image, h, w):
        """Resize ``image`` to ``h`` x ``w`` pixels with skimage's resize."""
        return transform.resize(image, (h, w))

    @staticmethod
    def _to_chw_tensor(image):
        """Convert an (H, W) or (H, W, C) ndarray to a (C, H, W) tensor."""
        if image.ndim == 2:
            image = image[:, :, np.newaxis]
        return torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))
class ToTensor:
    """Callable that wraps a NumPy ndarray sample as a torch Tensor.

    ``torch.from_numpy`` shares memory with the input array, so no copy is made.
    """

    def __call__(self, sample):
        converted = torch.from_numpy(sample)
        return converted
transform_img = transforms.Compose(
    [
        # Convert the resized HWC ndarray into a CHW tensor.
        transforms.ToTensor(),
        # Optional per-channel standardisation with ImageNet mean/std. Note
        # this does not map [0, 1] inputs onto [-1, 1]; e.g. channel 0 spans
        # roughly [-2.1, 2.2].
        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
# Root of the Fashion AI keypoint training split. Set the FASHION_AI_TRAIN_DIR
# environment variable to run the module on a machine with a different layout.
_TRAIN_DIR = os.environ.get(
    "FASHION_AI_TRAIN_DIR",
    r"D:\TanHaiyan\Datasets\Fashion\Fashion AI-keypoints\train",
)


def _make_dataset(category, num):
    """Build a KeyPointsDataSet from train_<category>.csv with ``num`` keypoint columns."""
    return KeyPointsDataSet(
        csv_file=os.path.join(_TRAIN_DIR, f"train_{category}.csv"),
        root_dir=_TRAIN_DIR,
        num=num,
        transforms_img=transform_img,
    )


# One dataset per garment category; the second argument is the number of
# keypoint columns read from that category's CSV.
dressDataset = _make_dataset("dress", 15)
skirtDataset = _make_dataset("skirt", 4)
trousersDataset = _make_dataset("trousers", 7)
outwearDataset = _make_dataset("outwear", 15)
blouseDataset = _make_dataset("blouse", 13)
dataloader = DataLoader(dataset=blouseDataset, batch_size=8, shuffle=True)
if __name__ == "__main__":
# test
for i_batch, data in enumerate(dataloader):
img, landmarks = data
#
if i_batch == 0:
# print(img, landmarks)
mark = landmarks.numpy()
picture = img.numpy()
maxPixle = picture.max()
picture = picture / maxPixle
mat = np.uint8(picture)
plt.figure(num='dress')
for i in range(8):
plt.subplot(2, 4, i + 1) # 将窗口分为两行两列四个子图,则可显示四幅图片
plt.title(str(i + 1)) # 第一幅图片标题
plt.imshow(mat[i].transpose(1, 2, 0)) # 绘制第一幅图片
plt.scatter(mark[i][:, 0], mark[i][:, 1], marker=".", s=50, color="r")
plt.show()
print(img.shape, landmarks.shape)
break
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/rpr/key-points-detection-method-with-torch.git
git@gitee.com:rpr/key-points-detection-method-with-torch.git
rpr
key-points-detection-method-with-torch
keyPointsDetectionMethodWithTorch
master

搜索帮助