代码拉取完成,页面将自动刷新
# -*- coding: utf-8 -*-
"""
Created on Fri Jul 9 12:01:33 2021
@author: BHN
"""
from torch.utils.data import Dataset
import torchvision.transforms as transforms
#from torchsummary import summary
import cv2
import os
import json
class Single_classification_Dataset(Dataset):
    """Single-label image classification dataset built from per-class folders.

    Every file found (recursively) under ``<path>/imagenet/<class_id>`` for
    each ``class_id`` in ``classification`` becomes one sample. The sample's
    label is ``labels[class_id]['num']``, taken from a JSON label file.
    """

    def __init__(self, classification, filename='labels_1000.json', path='//harddisk//bhn//data'):
        """
        Args:
            classification (iterable of str): Class folder ids to load. Each id
                must be a key of the JSON label file.
            filename (string): Path to the JSON label file. The file maps each
                class id to a dict with a ``'num'`` entry holding that class's
                label (anything ``int()`` accepts).
            path (string): Data root. Images are looked up under
                ``<path>/imagenet/<class_id>``.

        Raises:
            KeyError: If a class id is missing from the label file.
        """
        # Convert the RGB uint8 array produced in __getitem__ into a C x H x W tensor.
        self.data_transform = transforms.Compose([
            transforms.ToTensor()
        ])
        self.classification = classification
        self.filename = filename
        self.Get_all_pic(path, self.filename)

    def Get_all_pic(self, path, filename):
        """Populate ``self.data`` with one ``{'img_path', 'label'}`` dict per file.

        Args:
            path (string): Data root containing the ``imagenet`` folder.
            filename (string): JSON label file (see ``__init__``).

        Raises:
            KeyError: If a requested class id has no entry in the label file.
        """
        # Binary mode: json.load detects UTF-8/16/32 (and a UTF-8 BOM) itself,
        # so parsing does not depend on the platform's default text encoding.
        # The ``with`` block closes the handle, which the old code leaked.
        with open(filename, 'rb') as f:
            labels_dic = json.load(f)

        data_tmp = []
        for imagenet_id in self.classification:
            # Look the label up once per class. This also fails fast on an
            # unknown id, even when its folder is missing or empty.
            label = labels_dic[imagenet_id]['num']
            class_dir = os.path.join(path, 'imagenet', imagenet_id)
            # os.walk yields nothing for a missing folder, so such a class
            # contributes no samples.
            for root, dirs, files in os.walk(class_dir, topdown=False):
                for name in files:
                    data_tmp.append({'img_path': os.path.join(root, name), 'label': label})
        self.data = data_tmp

    def __len__(self):
        """Return the number of image files collected."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``[image, label]``.

        Processing steps:
        - Read the image with OpenCV.
        - Resize it to 224x224.
        - Convert it from BGR to RGB.
        - Pass it through ``self.data_transform`` (``ToTensor``).

        The label is the stored ``'num'`` value converted with ``int``.

        Raises:
            OSError: If OpenCV cannot read the image file.
        """
        image_path = self.data[idx]['img_path']
        image = cv2.imread(image_path)
        # cv2.imread reports failure by returning None instead of raising.
        # Without this check the problem shows up as an opaque cv2.resize error.
        if image is None:
            raise OSError(f"could not read image: {image_path}")
        image = cv2.resize(image, (224, 224))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        label = int(self.data[idx]['label'])
        if self.data_transform:
            image = self.data_transform(image)
        return [image, label]
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。