1 Star 0 Fork 0

Franck2333 / DeblurGAN

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
dataset.py 5.06 KB
一键复制 编辑 原始数据 按行查看 历史
Franck2333 提交于 2021-03-16 11:20 . init this repository
import os
from copy import deepcopy
from functools import partial
from glob import glob
from hashlib import sha1
from typing import Callable, Iterable, Optional, Tuple
import cv2
import numpy as np
from glog import logger
from joblib import Parallel, cpu_count, delayed
from skimage.io import imread
from torch.utils.data import Dataset
from tqdm import tqdm
from v2 import aug
def subsample(data: Iterable, bounds: Tuple[float, float], hash_fn: Callable, n_buckets=100, salt='', verbose=True):
    """Select the samples whose hash bucket lies inside a fractional range.

    Each sample is assigned a bucket index by ``split_into_buckets``. The two
    entries of ``bounds`` are multiplied by ``n_buckets`` to obtain a
    half-open bucket interval ``[lower, upper)``; only samples whose bucket
    falls inside that interval are kept.

    Args:
        data: Samples to filter; materialised into a list before hashing.
        bounds: ``(lower, upper)`` pair, each scaled by ``n_buckets``.
        hash_fn: Callable forwarded to ``split_into_buckets``.
        n_buckets: Total number of buckets.
        salt: Value forwarded to the hash function; also mentioned in the
            log message when it is truthy.
        verbose: When truthy, the chosen bucket range is logged.

    Returns:
        ``np.array`` built from the retained samples, in their input order.
    """
    samples = list(data)
    bucket_ids = split_into_buckets(samples, n_buckets=n_buckets, salt=salt, hash_fn=hash_fn)
    lower_bound, upper_bound = (fraction * n_buckets for fraction in bounds)

    if verbose:
        message = f'Subsampling buckets from {lower_bound} to {upper_bound}, total buckets number is {n_buckets}'
        if salt:
            message += f'; salt is {salt}'
        logger.info(message)

    # Lower edge is inclusive, upper edge is exclusive.
    keep_mask = (bucket_ids >= lower_bound) & (bucket_ids < upper_bound)
    return np.array([sample for sample, keep in zip(samples, keep_mask) if keep])
def hash_from_paths(x: Tuple[str, str], salt: str = '') -> str:
    """Return a SHA-1 hex digest computed from the base names of a path pair.

    Only the final path components are used, so the directories that contain
    the two files do not influence the result.

    Args:
        x: Pair of file paths.
        salt: String appended (after an underscore) to the concatenated base
            names before hashing.

    Returns:
        The hexadecimal SHA-1 digest as a string.
    """
    first_path, second_path = x
    base_names = (os.path.basename(first_path), os.path.basename(second_path))
    payload = f'{"".join(base_names)}_{salt}'
    return sha1(payload.encode()).hexdigest()
def split_into_buckets(data: Iterable, n_buckets: int, hash_fn: Callable, salt=''):
    """Map every item of ``data`` to a bucket index in ``[0, n_buckets)``.

    Args:
        data: Items to hash.
        n_buckets: Number of buckets; used as the modulus.
        hash_fn: Called as ``hash_fn(item, salt=salt)``; its result is parsed
            as a base-16 integer.
        salt: Value passed through to ``hash_fn``.

    Returns:
        ``np.array`` holding one bucket index per item, in input order.
    """
    bucket_ids = [int(hash_fn(item, salt=salt), 16) % n_buckets for item in data]
    return np.array(bucket_ids)
def _read_img(x: str):
    """Load the image at path ``x``, trying OpenCV first and scikit-image second.

    ``cv2.imread`` returns ``None`` instead of raising when it cannot decode
    a file; in that case a warning is logged and ``skimage.io.imread`` is
    used instead.

    NOTE(review): OpenCV decodes colour images as BGR while scikit-image
    returns RGB, so the fallback path may yield a different channel order —
    confirm whether downstream code depends on it.
    """
    image = cv2.imread(x)
    if image is not None:
        return image

    logger.warning(f'Can not read image {x} with OpenCV, switching to scikit-image')
    return imread(x)
class PairedDataset(Dataset):
    """Dataset of position-aligned image pairs.

    Item ``idx`` is built from ``data_a[idx]`` and ``data_b[idx]``: the images
    are read from disk (unless they were preloaded), passed together through
    ``transform_fn``, the ``a`` image is optionally passed through
    ``corrupt_fn``, and both are then normalised and transposed. The result
    is a dict with the keys ``'a'`` and ``'b'``.
    """

    def __init__(self,
                 files_a: Tuple[str],
                 files_b: Tuple[str],
                 transform_fn: Callable,
                 normalize_fn: Callable,
                 corrupt_fn: Optional[Callable] = None,
                 preload: bool = True,
                 preload_size: Optional[int] = 0,
                 verbose=True):
        """Store the file lists and callables and optionally preload all images.

        Args:
            files_a: Paths of the ``a`` images.
            files_b: Paths of the ``b`` images; must have the same length as
                ``files_a``.
            transform_fn: Called as ``transform_fn(a, b)``; must return the
                transformed pair.
            normalize_fn: Called as ``normalize_fn(a, b)``; must return an
                iterable of two arrays.
            corrupt_fn: Optional callable applied to the ``a`` image only,
                after ``transform_fn``.
            preload: When truthy, every image is read during construction.
            preload_size: When truthy, preloaded images are rescaled (see
                ``_preload``).
            verbose: When truthy, a progress bar is shown while preloading.
        """
        assert len(files_a) == len(files_b)

        self.preload = preload
        self.data_a = files_a
        self.data_b = files_b
        self.verbose = verbose
        self.corrupt_fn = corrupt_fn
        self.transform_fn = transform_fn
        self.normalize_fn = normalize_fn
        logger.info(f'Dataset has been created with {len(self.data_a)} samples')

        if preload:
            preload_fn = partial(self._bulk_preload, preload_size=preload_size)
            if files_a == files_b:
                # Both sides point at the same files: read them once and
                # share the resulting list between ``data_a`` and ``data_b``.
                self.data_a = self.data_b = preload_fn(self.data_a)
            else:
                self.data_a, self.data_b = map(preload_fn, (self.data_a, self.data_b))
            self.preload = True

    def _bulk_preload(self, data: Iterable[str], preload_size: int):
        """Read every path in ``data`` via ``_preload`` using a thread pool.

        One joblib job is created per path and executed with ``cpu_count()``
        workers on the ``threading`` backend. The progress bar is disabled
        when ``self.verbose`` is falsy.

        Returns:
            The list produced by ``joblib.Parallel``, one image per path.
        """
        jobs = [delayed(self._preload)(x, preload_size=preload_size) for x in data]
        jobs = tqdm(jobs, desc='preloading images', disable=not self.verbose)
        return Parallel(n_jobs=cpu_count(), backend='threading')(jobs)

    @staticmethod
    def _preload(x: str, preload_size: int):
        """Read one image and, if ``preload_size`` is truthy, rescale it.

        The same scale factor is applied to both axes. It is the larger of
        ``preload_size / h`` and ``preload_size / w``, i.e. the factor
        required by the shorter side, so after resizing neither of the first
        two dimensions should be smaller than ``preload_size``; this is
        checked with an assertion.
        """
        img = _read_img(x)
        if preload_size:
            h, w, *_ = img.shape
            h_scale = preload_size / h
            w_scale = preload_size / w
            scale = max(h_scale, w_scale)
            img = cv2.resize(img, fx=scale, fy=scale, dsize=None)
            assert min(img.shape[:2]) >= preload_size, f'weird img shape: {img.shape}'
        return img

    def _preprocess(self, img, res):
        """Normalise the pair and move axis 2 of each array to the front.

        Returns a lazy ``map`` over the arrays returned by
        ``self.normalize_fn(img, res)``; callers unpack it into two values.
        """
        def transpose(x):
            # Axis order (2, 0, 1) — presumably HWC -> CHW; confirm against
            # what ``normalize_fn`` returns.
            return np.transpose(x, (2, 0, 1))

        return map(transpose, self.normalize_fn(img, res))

    def __len__(self):
        """Return the number of pairs."""
        return len(self.data_a)

    def __getitem__(self, idx):
        """Build the processed pair at position ``idx``.

        Returns:
            ``{'a': a, 'b': b}`` with both arrays already normalised and
            transposed by ``_preprocess``.
        """
        a, b = self.data_a[idx], self.data_b[idx]
        if not self.preload:
            # Without preloading, the stored entries are paths, not images.
            a, b = map(_read_img, (a, b))
        a, b = self.transform_fn(a, b)
        if self.corrupt_fn is not None:
            a = self.corrupt_fn(a)
        a, b = self._preprocess(a, b)
        return {'a': a, 'b': b}

    @classmethod
    def from_config(cls, config):
        """Build a dataset from a config mapping.

        Keys read from ``config``: ``'files_a'`` and ``'files_b'`` (glob
        patterns, expanded recursively and sorted), ``'size'``, ``'scope'``,
        ``'crop'``, ``'corrupt'``, ``'preload'`` and ``'preload_size'``.
        Optional keys: ``'verbose'`` (default ``True``) and ``'bounds'``
        (default ``(0, 1)``), the latter being forwarded to ``subsample``.

        Raises:
            ValueError: If the two patterns match a different number of
                files, or if no pair is left after subsampling.
        """
        config = deepcopy(config)
        files_a, files_b = (sorted(glob(config[key], recursive=True)) for key in ('files_a', 'files_b'))

        # Pairs are formed purely by position in the sorted lists. ``zip``
        # would silently drop the surplus files of the longer list, hiding a
        # mismatch and potentially pairing unrelated images, so fail early.
        if len(files_a) != len(files_b):
            raise ValueError(
                f"'files_a' matched {len(files_a)} files but 'files_b' matched {len(files_b)}; "
                'both patterns must match the same number of files'
            )

        transform_fn = aug.get_transforms(size=config['size'], scope=config['scope'], crop=config['crop'])
        normalize_fn = aug.get_normalize()
        corrupt_fn = aug.get_corrupt_function(config['corrupt'])

        hash_fn = hash_from_paths
        # TODO: add more hash functions
        verbose = config.get('verbose', True)
        data = subsample(data=zip(files_a, files_b),
                         bounds=config.get('bounds', (0, 1)),
                         hash_fn=hash_fn,
                         verbose=verbose)

        # Unzipping an empty selection would fail with an opaque unpacking
        # error; report the actual problem instead.
        if len(data) == 0:
            raise ValueError('No image pairs left after subsampling; check the file patterns and bounds')

        files_a, files_b = map(list, zip(*data))

        return cls(files_a=files_a,
                   files_b=files_b,
                   preload=config['preload'],
                   preload_size=config['preload_size'],
                   corrupt_fn=corrupt_fn,
                   normalize_fn=normalize_fn,
                   transform_fn=transform_fn,
                   verbose=verbose)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/franck2333/deblur-gan.git
git@gitee.com:franck2333/deblur-gan.git
franck2333
deblur-gan
DeblurGAN
master

搜索帮助

344bd9b3 5694891 D2dac590 5694891