1 Star 0 Fork 0

huananbao / MG_CAM

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
imagenet_data.py 1.83 KB
一键复制 编辑 原始数据 按行查看 历史
huananbao 提交于 2023-06-09 08:24 . 第一次上传
# -*- coding: utf-8 -*-
"""
Created on Fri Jul 9 12:01:33 2021
@author: BHN
"""
from torch.utils.data import Dataset
import torchvision.transforms as transforms
#from torchsummary import summary
import cv2
import os
import json
class Single_classification_Dataset(Dataset):
    """Image-classification dataset over a chosen subset of ImageNet class folders.

    Every file found (recursively) under ``<path>/imagenet/<imagenet_id>`` for
    each id in ``classification`` becomes one sample.  Its label is the ``'num'``
    entry of that id in the JSON label file.  Items are returned as
    ``[image, label]``.  The image is read with OpenCV and resized to
    ``image_size``.  It is converted from BGR to RGB and passed through
    ``transforms.ToTensor()``.  The label is converted with ``int()``.
    """

    def __init__(self, classification, filename='labels_1000.json',
                 path='//harddisk//bhn//data', image_size=(224, 224)):
        """
        Args:
            classification: iterable of ImageNet class ids, i.e. sub-directory
                names under ``<path>/imagenet``, to include.  A single id given
                as a plain string is also accepted.
            filename (str): path of the JSON label file.  It must map each
                class id to a dict whose ``'num'`` entry is that class's label.
            path (str): data root that contains the ``imagenet`` directory.
            image_size (tuple): ``dsize`` handed to ``cv2.resize``, i.e.
                ``(width, height)``.  Defaults to ``(224, 224)``.
        """
        self.data_transform = transforms.Compose([
            transforms.ToTensor()
        ])
        self.classification = classification
        self.filename = filename
        self.image_size = tuple(image_size)
        self.Get_all_pic(path, self.filename)

    def Get_all_pic(self, path, filename):
        """Rebuild ``self.data`` as a list of ``{'img_path', 'label'}`` dicts.

        Walks ``<path>/imagenet/<imagenet_id>`` recursively for every id in
        ``self.classification``.  Each file found becomes one entry.  A class
        whose directory is missing or empty contributes no entries.

        Raises:
            KeyError: if a class that has at least one file is missing from the
                label file, or its entry lacks ``'num'``.
        """
        # Parse the labels and close the file immediately.  JSON text is UTF-8.
        with open(filename, encoding='utf-8') as f:
            labels_dic = json.load(f)

        class_ids = self.classification
        # Iterating a bare string would walk it character by character, so a
        # single id passed as ``str`` is treated as a one-element list.
        if isinstance(class_ids, str):
            class_ids = [class_ids]

        data_tmp = []
        for imagenet_id in class_ids:
            class_dir = os.path.join(path, 'imagenet', imagenet_id)
            # topdown=False is kept so the sample order matches earlier versions.
            for root, _dirs, files in os.walk(class_dir, topdown=False):
                for name in files:
                    data_tmp.append({
                        'img_path': os.path.join(root, name),
                        'label': labels_dic[imagenet_id]['num'],
                    })
        self.data = data_tmp

    def __len__(self):
        """Return the number of collected image files."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``[image, label]`` for sample ``idx``.

        Raises:
            IOError: if OpenCV cannot read the image file.  ``cv2.imread``
                signals failure by returning ``None`` rather than raising.
        """
        image_path = self.data[idx]['img_path']
        image = cv2.imread(image_path)
        if image is None:
            # Without this check cv2.resize fails with an opaque assertion error.
            raise IOError('cv2.imread could not read image: {}'.format(image_path))
        image = cv2.resize(image, self.image_size)
        # OpenCV decodes images in BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        label = int(self.data[idx]['label'])
        if self.data_transform:
            image = self.data_transform(image)
        return [image, label]
Python
1
https://gitee.com/huananbao/mg_cam.git
git@gitee.com:huananbao/mg_cam.git
huananbao
mg_cam
MG_CAM
master

搜索帮助