1 Star 0 Fork 5

hjy / keyPointsDetectionMethodWithTorch

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
data_process.py 6.92 KB
一键复制 编辑 原始数据 按行查看 历史
unknown 提交于 2021-06-18 15:23 . KPDM_v2.1
# 1.输入图像预处理:以左上角为锚点,零填充/裁剪至 512×512(当前未实现旋转)。
# 2.真实值(ground truth)变换:关键点热力图同样填充/裁剪,shape = (kp_num, 512, 512),kp_num = 24。
# 3.提供 Dataset,配合 DataLoader 为模型提供输入,并在输出时用于计算损失。
import os
import numpy as np
import pandas as pd
import torch
from skimage import io, transform # 用于图像的IO和变换
from torch.utils.data import Dataset, DataLoader
from GassionHeatMap import generate_hmap_mask
from torchvision import transforms
# Run on the first CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
class KeyPointsDataSet(Dataset):
    """Clothing key-point annotation dataset.

    Each item is ``(image, landmarks, heatmap)``. The image and its key-point
    heatmaps are zero-padded / cropped onto a ``CANVAS_SIZE`` x ``CANVAS_SIZE``
    canvas anchored at the top-left corner, so that step does not move any
    pixel coordinates.
    """

    # Side length of the square canvas every image / heatmap is padded or cropped to.
    CANVAS_SIZE = 512
    # Number of key points per annotation row (CSV columns 2 .. 2 + NUM_KEYPOINTS - 1).
    NUM_KEYPOINTS = 24

    def __init__(self, csv_file, root_dir, transform_img=None, transform_heat=None):
        """Initialise the dataset.

        :param csv_file: annotation CSV; column 0 is the image path (relative to
            ``root_dir``) and columns 2..25 hold one ``a_b_c`` integer triple per
            key point.
        :param root_dir: directory the image paths are relative to.
        :param transform_img: optional callable applied to the padded HWC image
            ndarray; its result is divided by 255 and is expected to have shape
            ``(C, H, W)``.
        :param transform_heat: optional callable applied to the padded heatmap ndarray.
        """
        self.data_info = self.get_file_info(csv_file)
        self.root_dir = root_dir
        self.transform_img = transform_img
        self.transform_heat = transform_heat

    def __len__(self):
        return len(self.data_info[0])

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, landmarks, heatmap)`` as tensors."""
        img_path = os.path.join(self.root_dir, self.data_info[0][idx])
        image = io.imread(img_path)
        o_image_size = image.shape[0:2]
        # Explicit (Fortran-ordered) copy so the cached annotation table is never mutated.
        landmarks = np.array(self.data_info[1][idx], order="F")
        # Heatmaps are generated at the original resolution, then padded/cropped
        # exactly like the image so both stay pixel-aligned.
        heatmap = self.get_htmap(o_image_size, landmarks)
        image = self.change_img_size(image, self.CANVAS_SIZE)
        heatmap = self.change_heat_size(heatmap, self.CANVAS_SIZE)

        if self.transform_img is not None:
            # change_img_size returns a float array, so scaling to [0, 1] is done here.
            image = self.transform_img(image) / 255
        else:
            # HWC -> CHW, the same layout torchvision's ToTensor produces.
            image = torch.from_numpy(image.transpose(2, 0, 1).copy()) / 255

        # Padding/cropping does not move key points; only a resolution change made by
        # transform_img (relative to the canvas, not the original image) needs rescaling.
        landmarks = self._rescale_landmarks(landmarks, image.shape[1:], self.CANVAS_SIZE)

        if self.transform_heat is not None:
            heatmap = self.transform_heat(heatmap)
        else:
            heatmap = torch.tensor(heatmap)
        return image.float(), torch.tensor(landmarks), heatmap

    @staticmethod
    def _rescale_landmarks(landmarks, new_size, canvas_size=CANVAS_SIZE):
        """Return a copy of ``landmarks`` with columns 0-1 scaled by ``new_size / canvas_size``.

        The two ratios are applied in ``new_size`` order (H, W), as in the original
        code. NOTE(review): confirm this matches the coordinate column order in the
        CSV. Integer inputs are truncated back to integers on assignment.
        """
        ratio = np.asarray(tuple(new_size), dtype=float) / canvas_size
        scaled = np.array(landmarks, copy=True)
        scaled[:, 0:2] = scaled[:, 0:2] * ratio
        return scaled

    @staticmethod
    def get_file_info(file_path, num_keypoints=NUM_KEYPOINTS):
        """Parse the annotation CSV.

        :param file_path: path or buffer accepted by ``pandas.read_csv``.
        :param num_keypoints: number of key-point columns starting at column 2.
        :return: ``(img_list, landmarks)``: column 0 as a pandas Series, and an int
            array of shape ``(rows, num_keypoints, 3)`` built by splitting every
            key-point cell on ``'_'``.
        :raises ValueError: if the CSV has fewer than ``num_keypoints`` key-point columns.
        """
        file_info = pd.read_csv(file_path)
        img_list = file_info.iloc[:, 0]
        cells = file_info.iloc[:, 2:2 + num_keypoints].values
        if cells.shape[1] != num_keypoints:
            raise ValueError(
                f"expected {num_keypoints} key-point columns, found {cells.shape[1]}"
            )
        coords = [[int(v) for cell in row for v in cell.split('_')] for row in cells]
        landmarks = np.array(coords, dtype=int).reshape((-1, num_keypoints, 3))
        return img_list, landmarks

    @staticmethod
    def get_htmap(image_size, landmarks):
        """Build key-point heatmaps for an image of ``image_size`` via generate_hmap_mask."""
        return generate_hmap_mask(image_size, landmarks)

    @staticmethod
    def change_heat_size(image, size=CANVAS_SIZE):
        """Zero-pad / crop a ``(C, H, W)`` array to ``(C, size, size)``, anchored top-left."""
        chn, h, w = image.shape
        n_image = np.zeros((chn, size, size))
        hh, ww = min(h, size), min(w, size)
        n_image[:, :hh, :ww] = image[:, :hh, :ww]
        return n_image

    @staticmethod
    def change_img_size(image, size=CANVAS_SIZE):
        """Zero-pad / crop an ``(H, W, C)`` image to ``(size, size, C)``, anchored top-left.

        The result is a float64 array regardless of the input dtype.
        """
        h, w, chn = image.shape
        n_image = np.zeros((size, size, chn))
        hh, ww = min(h, size), min(w, size)
        n_image[:hh, :ww, :] = image[:hh, :ww, :]
        return n_image
class DataSet_Test(KeyPointsDataSet):
    """Unlabelled (test-split) variant of KeyPointsDataSet that yields only images."""

    def __init__(self, csv_file, root_dir, transform_img=None):
        """Initialise the test dataset.

        :param csv_file: CSV whose first column holds image paths relative to ``root_dir``.
        :param root_dir: image directory.
        :param transform_img: optional callable applied to the padded HWC image ndarray.
        """
        super().__init__(csv_file, root_dir, transform_img)

    def __len__(self):
        # Here data_info is the image-path Series itself, not a (paths, landmarks)
        # tuple, so the parent's len(self.data_info[0]) would return the character
        # count of the first path instead of the number of images.
        return len(self.data_info)

    def __getitem__(self, idx):
        """Load image ``idx``, pad/crop it, and return a float tensor scaled by 1/255."""
        img_path = os.path.join(self.root_dir, self.data_info[idx])
        image = self.change_img_size(io.imread(img_path))
        if self.transform_img is not None:
            image = self.transform_img(image) / 255
        else:
            # HWC -> CHW, the same layout torchvision's ToTensor produces.
            image = torch.from_numpy(image.transpose(2, 0, 1).copy()) / 255
        return image.float()

    @staticmethod
    def get_file_info(file_path):
        """Return column 0 of the CSV (the image paths) as a pandas Series."""
        file_info = pd.read_csv(file_path)
        return file_info.iloc[:, 0]
class ToTensor:
    """Callable that wraps a numpy ndarray sample as a torch.Tensor.

    Uses ``torch.from_numpy``, so the tensor shares memory with the array and no
    value rescaling is performed.
    """

    def __call__(self, sample):
        tensor = torch.from_numpy(sample)
        return tensor
# Image pipeline: HWC ndarray -> CHW tensor via torchvision's ToTensor.
# Optional ImageNet mean/std normalisation, currently disabled:
#   transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform_img = transforms.Compose(transforms=[transforms.ToTensor()])

# Heatmap pipeline: ndarray -> tensor via the local ToTensor (no value rescaling).
transform_heat = transforms.Compose(transforms=[ToTensor()])
if __name__ == "__main__":
    # Smoke test: pull the first two batches from the unlabelled test split.
    # The labelled KeyPointsDataSet can be exercised the same way by also passing
    # transform_heat=transform_heat and unpacking each batch as (img, landmarks, hmap).
    test_dataset = DataSet_Test(
        csv_file=r"E:\Datasets\Fashion\Fashion AI-keypoints\test\test.csv",
        root_dir=r"E:\Datasets\Fashion\Fashion AI-keypoints\test",
        transform_img=transform_img,
    )
    test_loader = DataLoader(dataset=test_dataset, batch_size=4)
    # zip stops once range(2) is exhausted, so exactly two batches are loaded.
    for _, img in zip(range(2), test_loader):
        pass
1
https://gitee.com/hjycodehjy/key-points-detection-method-with-torch.git
git@gitee.com:hjycodehjy/key-points-detection-method-with-torch.git
hjycodehjy
key-points-detection-method-with-torch
keyPointsDetectionMethodWithTorch
master

搜索帮助

53164aa7 5694891 3bd8fe86 5694891