1 Star 1 Fork 0

Haixu He/自编码器提取特征

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
utils.py 7.04 KB
一键复制 编辑 原始数据 按行查看 历史
Haixu He 提交于 2022-05-29 18:18 . add Segment
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Author :hhx
@Date :2022/5/19 21:35
@Description : dataloader
"""
# -*- coding:utf-8 -*-
import os
from PIL import Image
from torch.utils import data
import numpy as np
from torchvision import transforms as T
import torch
from osgeo import gdal
import matplotlib.pyplot as plt
import os
import scipy.misc
from PIL import Image
# NOTE(review): presumably a workaround for the Intel OpenMP abort
# ("OMP: Error #15 ... already initialized") that can happen when several of
# the libraries imported above each ship their own OpenMP runtime; setting
# this lets the duplicate runtime load instead of aborting. Confirm it is
# still required, since running two OpenMP runtimes is itself unsupported.
os.environ.update({"KMP_DUPLICATE_LIB_OK": "TRUE"})
def nor(data):
    """Min-max normalise ``data`` into the range [0, 1].

    The minimum and maximum are taken over the whole array (all axes at
    once), so multi-band input is scaled jointly rather than per band.

    Args:
        data: Array-like of numbers.

    Returns:
        ``(data - min) / (max - min)`` as a NumPy array. A constant input
        (``max == min``) is mapped to all zeros instead of the 0/0 = NaN the
        plain formula would produce. NaNs already present in ``data`` still
        propagate, because ``np.min``/``np.max`` return NaN for them.

    Raises:
        ValueError: If ``data`` is empty (raised by ``np.min``).
    """
    data = np.asarray(data)
    lo = np.min(data)
    hi = np.max(data)
    span = hi - lo
    if span == 0:
        # Every element equals the minimum, so (data - lo) is all zeros.
        # Divide by a one of the same dtype to avoid 0/0 while keeping the
        # result dtype identical to the regular branch.
        span = np.ones_like(span)
    return (data - lo) / span
# Read a TIFF file.
def readGeoTIFF(fileName):
    """Read every band of a GDAL-readable raster (e.g. a GeoTIFF) into memory.

    Args:
        fileName: Path of the raster file.

    Returns:
        The array returned by GDAL's ``Dataset.ReadAsArray`` for the full
        raster extent. Callers in this module index it band-first
        (``Img[7]``, ``np.delete(..., axis=0)``).

    Raises:
        OSError: If GDAL cannot open the file. Previously a message was only
            printed, and the function then crashed with an unrelated
            ``AttributeError`` on ``None.RasterXSize``.
    """
    dataset = gdal.Open(fileName)
    if dataset is None:
        raise OSError(fileName + "文件无法打开")
    im_width = dataset.RasterXSize  # number of columns of the raster grid
    im_height = dataset.RasterYSize  # number of rows of the raster grid
    # Read the whole extent starting at the top-left pixel (0, 0).
    return dataset.ReadAsArray(0, 0, im_width, im_height)
class CarIndexDateSet(data.Dataset):
    """Autoencoder dataset yielding one 32x32 NDVI patch per GeoTIFF.

    Each sample is built from one ``*tif`` file. An NDVI image is computed
    as ``(band7 - band3) / (band7 + band3)`` (presumably NIR and red;
    confirm against the sensor's band order). Undefined pixels are set to 0,
    and the image is resized to 32x32, min-max normalised to [0, 1] and
    passed through ``transforms``. The reconstruction target is the same
    NDVI patch.

    Args:
        root: Directory containing the site sub-directories listed in
            ``TRAIN_DIRS`` / ``TEST_DIRS``.
        transforms: Callable applied to the normalised array. Defaults to
            ``T.ToTensor()``.
        type: ``'train'`` returns ``(input, target)`` pairs from every site in
            ``TRAIN_DIRS``; any other value returns only the input, drawn
            from the sites in ``TEST_DIRS``.
    """

    # Site sub-directories (relative to ``root``) scanned for each split.
    TRAIN_DIRS = ('江边', '洛阳1', '未来城校区2016-2021月度云最小', '长江存储')
    TEST_DIRS = ('长江存储',)

    def __init__(self, root, transforms=None, type='train'):
        self.type = type
        site_dirs = self.TRAIN_DIRS if type == 'train' else self.TEST_DIRS
        imgs = []
        for site in site_dirs:
            site_path = os.path.join(root, site)
            # Sort so that the sample order (index -> file) is reproducible;
            # os.listdir returns entries in arbitrary order.
            imgs.extend(
                os.path.join(site_path, name)
                for name in sorted(os.listdir(site_path))
                if name.endswith('tif')
            )
        self.imgs = imgs
        # Autoencoder: every sample's target file is the sample itself.
        self.labels = list(imgs)
        if transforms is None:
            self.transforms = T.Compose([
                T.ToTensor(),  # array -> tensor; values are already in [0, 1] via nor()
            ])
        else:
            self.transforms = transforms

    def __getitem__(self, index):
        """Return one sample: ``(input, target)`` when training, else ``input``."""
        # The label path equals the image path, and the target is rebuilt
        # from the same NDVI, so the raster only needs to be read once.
        Img = readGeoTIFF(self.imgs[index])
        # Promote to float before the band arithmetic so that integer rasters
        # cannot wrap around on subtraction or overflow on addition.
        nir = Img[7].astype(np.float64)
        red = Img[3].astype(np.float64)
        with np.errstate(divide='ignore', invalid='ignore'):
            NDVI = (nir - red) / (nir + red)
        # Zero the undefined pixels (0/0 -> NaN, x/0 -> inf) *before*
        # resizing, so the resampling filter cannot spread them into
        # neighbouring pixels.
        NDVI[~np.isfinite(NDVI)] = 0
        NDVI = np.array(Image.fromarray(NDVI).resize((32, 32)))
        normed = nor(NDVI)
        Img = self.transforms(normed)
        # Give the target its own buffer: ToTensor may wrap the array without
        # copying, and in-place edits of the input must not leak into it.
        Label = self.transforms(normed.copy())
        if self.type == 'train':
            return Img, Label
        return Img

    def __len__(self):
        """Number of raster files in this split."""
        return len(self.imgs)
class CarTiffDateSet(data.Dataset):
    """Autoencoder dataset yielding one multi-band 32x32 patch per GeoTIFF.

    For every ``*tif`` file:
    - the bands listed in ``DROP_BANDS`` are removed;
    - the first ``N_BANDS`` remaining bands are each resized to
      ``PATCH_SIZE`` x ``PATCH_SIZE``;
    - the stack is min-max normalised jointly to [0, 1] and passed through
      ``transforms`` as an ``(H, W, C)`` array.

    The reconstruction target is the same patch.

    Args:
        root: Directory containing the site sub-directories listed in
            ``TRAIN_DIRS`` / ``TEST_DIRS``.
        transforms: Callable applied to the normalised ``(H, W, C)`` array.
            Defaults to ``T.ToTensor()``.
        type: ``'train'`` returns ``(input, target)`` pairs from every site in
            ``TRAIN_DIRS``; any other value returns only the input, drawn
            from the sites in ``TEST_DIRS``.
    """

    # Site sub-directories (relative to ``root``) scanned for each split.
    TRAIN_DIRS = ('JB', 'LY', 'WLC', 'CJCC')
    TEST_DIRS = ('CJCC',)
    # Band indices (axis 0 of the raster array) dropped before resizing.
    DROP_BANDS = (0, 8, 9, 13, 14, 15)
    # Number of bands kept after dropping, and the output patch edge length.
    N_BANDS = 10
    PATCH_SIZE = 32

    def __init__(self, root, transforms=None, type='train'):
        self.type = type
        site_dirs = self.TRAIN_DIRS if type == 'train' else self.TEST_DIRS
        imgs = []
        for site in site_dirs:
            site_path = os.path.join(root, site)
            # Sort so that the sample order (index -> file) is reproducible;
            # os.listdir returns entries in arbitrary order.
            imgs.extend(
                os.path.join(site_path, name)
                for name in sorted(os.listdir(site_path))
                if name.endswith('tif')
            )
        self.imgs = imgs
        # Autoencoder: every sample's target file is the sample itself.
        self.labels = list(imgs)
        if transforms is None:
            self.transforms = T.Compose([
                T.ToTensor(),  # (H, W, C) array -> (C, H, W) tensor
            ])
        else:
            self.transforms = transforms

    def __getitem__(self, index):
        """Return one sample: ``(input, target)`` when training, else ``input``."""
        # The label path equals the image path, and the target is rebuilt
        # from the same bands, so the raster only needs to be read once.
        bands = readGeoTIFF(self.imgs[index])
        bands = np.delete(bands, self.DROP_BANDS, axis=0)
        size = self.PATCH_SIZE
        stack = np.zeros([self.N_BANDS, size, size])
        for b in range(self.N_BANDS):
            stack[b] = np.array(Image.fromarray(bands[b]).resize((size, size)))
        # Move bands last, (C, H, W) -> (H, W, C), as ToTensor expects. The
        # previous ``.T`` reversed *all* axes, so it also swapped the rows
        # and columns of every patch.
        hwc = np.transpose(stack, (1, 2, 0))
        normed = nor(hwc)
        Img = self.transforms(normed)
        # Give the target its own buffer so that in-place edits of the input
        # tensor can never leak into it.
        Label = self.transforms(normed.copy())
        if self.type == 'train':
            return Img, Label
        return Img

    def __len__(self):
        """Number of raster files in this split."""
        return len(self.imgs)
class CarDateSet(data.Dataset):
    """Autoencoder dataset over ordinary (PIL-readable) image files.

    Every image is converted to RGB and passed through ``transforms``
    (default: resize to 32x32, then ``ToTensor``). The reconstruction target
    is the image at the matching entry of ``labels``, which is the same file.

    Args:
        root: In train mode every sub-directory of ``root`` is scanned;
            otherwise only ``root/TEST_DIR`` is used.
        transforms: Callable applied to each RGB PIL image.
        type: ``'train'`` returns ``(input, target)`` pairs; any other value
            returns only the input.
    """

    # Sub-directory of ``root`` used for the non-training split.
    TEST_DIR = 'CJCC'

    def __init__(self, root, transforms=None, type='train'):
        self.type = type
        if type == 'train':
            # Every sub-directory of root is a site. Stray files at the top
            # level are skipped; previously os.listdir crashed on them.
            site_dirs = [
                os.path.join(root, entry)
                for entry in sorted(os.listdir(root))
                if os.path.isdir(os.path.join(root, entry))
            ]
        else:
            site_dirs = [os.path.join(root, self.TEST_DIR)]
        # Sorted listing keeps the index -> file mapping reproducible.
        imgs = [
            os.path.join(site, name)
            for site in site_dirs
            for name in sorted(os.listdir(site))
        ]
        self.imgs = imgs
        # Autoencoder: every sample's target file is the sample itself.
        self.labels = list(imgs)
        if transforms is None:
            self.transforms = T.Compose([
                T.Resize((32, 32)),  # resize the image to (h, w) = (32, 32)
                T.ToTensor(),  # PIL image -> tensor scaled to [0, 1]
            ])
        else:
            self.transforms = transforms

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` with PIL and return an RGB copy detached from the file."""
        # The context manager closes the file handle. convert() returns a new,
        # fully loaded image, so the result stays valid after closing.
        with Image.open(path) as im:
            return im.convert('RGB')

    def __getitem__(self, index):
        """Return one sample: ``(input, target)`` when training, else ``input``."""
        Img = self.transforms(self._load_rgb(self.imgs[index]))
        # The target is read from the label path. The original computed
        # ``label`` but then re-opened the image path instead.
        Label = self.transforms(self._load_rgb(self.labels[index]))
        if self.type == 'train':
            return Img, Label
        return Img

    def __len__(self):
        """Number of image files in this split."""
        return len(self.imgs)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/HaixuHe/Feature-from-encoder.git
git@gitee.com:HaixuHe/Feature-from-encoder.git
HaixuHe
Feature-from-encoder
自编码器提取特征
master

搜索帮助