代码拉取完成,页面将自动刷新
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Author :hhx
@Date :2022/5/19 21:35
@Description : dataloader
"""
# -*- coding:utf-8 -*-
import os
from PIL import Image
from torch.utils import data
import numpy as np
from torchvision import transforms as T
import torch
from osgeo import gdal
import matplotlib.pyplot as plt
import os
import scipy.misc
from PIL import Image
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
def nor(data):
    """Min-max normalise ``data`` to the range [0, 1].

    Args:
        data: Numeric NumPy array of any shape.

    Returns:
        ``(data - min) / (max - min)`` computed element-wise. When every
        element is equal (``max == min``) an all-zero result is returned
        instead of the NaNs that 0/0 would produce.

    Raises:
        ValueError: If ``data`` is empty (raised by ``np.min``).
    """
    lowest = np.min(data)
    span = np.max(data) - lowest
    shifted = data - lowest
    if span == 0:
        # Constant input: ``shifted`` is already all zeros. Dividing by a
        # one of the same dtype as ``span`` keeps the result dtype identical
        # to the regular path below.
        return shifted / np.ones_like(span)
    return shifted / span
# Read a TIFF file.
def readGeoTIFF(fileName):
    """Load the full extent of a raster file into an array via GDAL.

    Args:
        fileName: Path of the raster file to open.

    Returns:
        The array produced by ``dataset.ReadAsArray`` for the window starting
        at (0, 0) and spanning the whole raster width and height.

    Raises:
        IOError: If ``gdal.Open`` returns ``None`` for ``fileName``.
    """
    dataset = gdal.Open(fileName)
    if dataset is None:
        # Fail here with the file name instead of printing and then crashing
        # with an AttributeError on ``None`` a line later.
        raise IOError(fileName + "文件无法打开")
    im_width = dataset.RasterXSize  # number of columns in the raster
    im_height = dataset.RasterYSize  # number of rows in the raster
    return dataset.ReadAsArray(0, 0, im_width, im_height)
class CarIndexDateSet(data.Dataset):
    """Dataset yielding one normalised 32x32 band-ratio map per GeoTIFF.

    Each item is computed as ``(band[7] - band[3]) / (band[7] + band[3])``
    (presumably an NDVI-style index; confirm the band order against the
    data), resized to 32x32, min-max normalised with ``nor`` and passed
    through ``transforms``.

    Args:
        root: Directory containing the per-site sub-directories.
        transforms: Callable applied to the normalised array. Defaults to
            ``T.Compose([T.ToTensor()])``.
        type: ``'train'`` scans every training sub-directory and makes
            ``__getitem__`` return ``(Img, Label)``; any other value scans
            only the evaluation sub-directory and returns ``Img`` alone.
    """

    # Sub-directories of ``root`` scanned in each mode.
    _TRAIN_DIRS = ('江边', '洛阳1', '未来城校区2016-2021月度云最小', '长江存储')
    _EVAL_DIRS = ('长江存储',)

    def __init__(self, root, transforms=None, type='train'):
        self.type = type
        folders = self._TRAIN_DIRS if type == 'train' else self._EVAL_DIRS
        self.imgs = self._list_tifs(root, folders)
        # Labels are the same files as the inputs (kept as a separate list).
        self.labels = list(self.imgs)
        if transforms is None:
            # ToTensor converts the array to a tensor.
            self.transforms = T.Compose([
                T.ToTensor(),
            ])
        else:
            self.transforms = transforms

    @staticmethod
    def _list_tifs(root, folders):
        """Return paths of entries ending in 'tif', sorted within each folder."""
        paths = []
        for folder in folders:
            folder_path = os.path.join(root, folder)
            # os.listdir order is arbitrary; sort so indices are reproducible.
            paths.extend(
                os.path.join(folder_path, name)
                for name in sorted(os.listdir(folder_path))
                if name.endswith('tif')
            )
        return paths

    @staticmethod
    def _normalized_difference(bands):
        """Return (bands[7] - bands[3]) / (bands[7] + bands[3]) in floating point."""
        high = np.asarray(bands[7])
        low = np.asarray(bands[3])
        if not np.issubdtype(high.dtype, np.floating):
            # Integer rasters (unsigned ones in particular) would wrap around
            # on the subtraction, so promote before doing arithmetic.
            high = high.astype(np.float64)
            low = low.astype(np.float64)
        # Zero denominators are expected; the non-finite results are cleaned
        # up by the caller after resizing.
        with np.errstate(divide='ignore', invalid='ignore'):
            return (high - low) / (high + low)

    def __getitem__(self, index):
        """
        Return the data for one image.
        """
        bands = readGeoTIFF(self.imgs[index])
        ratio = self._normalized_difference(bands)
        ratio = np.array(Image.fromarray(ratio).resize((32, 32)))
        # Replace NaN/inf (from zero denominators) so nor() sees finite data.
        ratio[~np.isfinite(ratio)] = 0
        Img = self.transforms(nor(ratio))
        if self.type != 'train':
            return Img
        # The label is built from the same data through a second, independent
        # transforms call.
        Label = self.transforms(nor(ratio))
        return Img, Label

    def __len__(self):
        return len(self.imgs)
class CarTiffDateSet(data.Dataset):
    """Dataset yielding a normalised 10-band, 32x32 stack per GeoTIFF.

    For each file, bands 0, 8, 9, 13, 14 and 15 are dropped, the first ten
    remaining bands are resized to 32x32, the stack is min-max normalised
    as a whole with ``nor`` and passed through ``transforms``.

    Args:
        root: Directory containing the per-site sub-directories.
        transforms: Callable applied to the normalised array. Defaults to
            ``T.Compose([T.ToTensor()])``.
        type: ``'train'`` scans every training sub-directory and makes
            ``__getitem__`` return ``(Img, Label)``; any other value scans
            only ``CJCC`` and returns ``Img`` alone.
    """

    # Sub-directories of ``root`` scanned in each mode.
    _TRAIN_DIRS = ('JB', 'LY', 'WLC', 'CJCC')
    _EVAL_DIRS = ('CJCC',)
    # Band indices removed from every raster before stacking.
    _DROPPED_BANDS = [0, 8, 9, 13, 14, 15]

    def __init__(self, root, transforms=None, type='train'):
        self.type = type
        folders = self._TRAIN_DIRS if type == 'train' else self._EVAL_DIRS
        self.imgs = self._list_tifs(root, folders)
        # Labels are the same files as the inputs (kept as a separate list).
        self.labels = list(self.imgs)
        if transforms is None:
            # ToTensor converts the array to a tensor.
            self.transforms = T.Compose([
                T.ToTensor(),
            ])
        else:
            self.transforms = transforms

    @staticmethod
    def _list_tifs(root, folders):
        """Return paths of entries ending in 'tif', sorted within each folder."""
        paths = []
        for folder in folders:
            folder_path = os.path.join(root, folder)
            # os.listdir order is arbitrary; sort so indices are reproducible.
            paths.extend(
                os.path.join(folder_path, name)
                for name in sorted(os.listdir(folder_path))
                if name.endswith('tif')
            )
        return paths

    def __getitem__(self, index):
        """
        Return the data for one image.
        """
        bands = readGeoTIFF(self.imgs[index])
        bands = np.delete(bands, self._DROPPED_BANDS, axis=0)
        stack = np.zeros([10, 32, 32])
        for i in range(10):
            stack[i] = np.array(Image.fromarray(bands[i]).resize((32, 32)))
        # NOTE(review): ``.T`` reverses all axes, giving (32, 32, 10) with the
        # two spatial axes swapped relative to a plain channel-last move.
        # Kept as-is because changing it would alter the produced tensors.
        Img = self.transforms(nor(stack.T))
        if self.type != 'train':
            return Img
        # The label is built from the same data through a second, independent
        # transforms call.
        Label = self.transforms(nor(stack.T))
        return Img, Label

    def __len__(self):
        return len(self.imgs)
class CarDateSet(data.Dataset):
    """Dataset yielding RGB images loaded with PIL.

    Args:
        root: Directory whose sub-directories hold the image files.
        transforms: Callable applied to each RGB ``PIL.Image``. Defaults to
            resizing to 32x32 followed by ``T.ToTensor()``.
        type: ``'train'`` collects every entry of every sub-directory of
            ``root`` and makes ``__getitem__`` return ``(Img, Label)``; any
            other value collects only ``root/CJCC`` and returns ``Img`` alone.
    """

    def __init__(self, root, transforms=None, type='train'):
        self.type = type
        if type == 'train':
            # Plain files sitting directly in ``root`` are skipped; listing
            # them as directories would raise NotADirectoryError.
            folders = [
                name for name in sorted(os.listdir(root))
                if os.path.isdir(os.path.join(root, name))
            ]
        else:
            folders = ['CJCC']
        imgs = []
        for folder in folders:
            folder_path = os.path.join(root, folder)
            # os.listdir order is arbitrary; sort so indices are reproducible.
            imgs.extend(
                os.path.join(folder_path, name)
                for name in sorted(os.listdir(folder_path))
            )
        self.imgs = imgs
        # Labels are the same files as the inputs (kept as a separate list).
        self.labels = list(imgs)
        if transforms is None:
            self.transforms = T.Compose([
                T.Resize((32, 32)),  # resize the image to (h, w)
                T.ToTensor(),  # convert the image to a tensor
            ])
        else:
            self.transforms = transforms

    def __getitem__(self, index):
        """
        Return the data for one image.
        """
        # Open the file once and close it as soon as the pixels are converted.
        with Image.open(self.imgs[index]) as source:
            rgb = source.convert('RGB')
        Img = self.transforms(rgb)
        if self.type != 'train':
            return Img
        # The label is the same image run through a second, independent
        # transforms call; a copy keeps the two inputs distinct objects.
        Label = self.transforms(rgb.copy())
        return Img, Label

    def __len__(self):
        return len(self.imgs)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。