1 Star 0 Fork 0

greitzmann/Keras-Image-Super-Resolution

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
data.py 2.01 KB
一键复制 编辑 原始数据 按行查看 历史
hieubkset 提交于 6年前 . Update data.py
import random
import pathlib
import tensorflow as tf
def preprocess_image(image, ext):
    """Decode an encoded PNG/JPEG image and normalize it to [-1, 1].

    Args:
        image: Encoded image contents, e.g. the output of ``tf.read_file``.
        ext: File extension selecting the decoder. ``.png`` uses the PNG
            decoder; ``.jpg`` and ``.jpeg`` use the JPEG decoder. Matching
            is case-insensitive, so ``.JPEG``, ``.JPG`` and ``.PNG`` also work.

    Returns:
        A 3-channel ``tf.float32`` image tensor with values in [-1, 1].

    Raises:
        ValueError: if ``ext`` is not a supported image extension.
    """
    normalized_ext = ext.lower() if isinstance(ext, str) else ext
    if normalized_ext not in ('.png', '.jpg', '.jpeg'):
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would let unknown formats silently fall
        # through to the JPEG decoder below.
        raise ValueError(
            "Unsupported image extension {!r}; expected one of "
            "'.png', '.jpg', '.jpeg' (case-insensitive)".format(ext))
    if normalized_ext == '.png':
        image = tf.image.decode_png(image, channels=3)
    else:
        image = tf.image.decode_jpeg(image, channels=3)
    # convert_image_dtype rescales the decoded uint8 image to [0, 1];
    # the shift and scale below map that range onto [-1, 1].
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    image -= 0.5
    image /= 0.5
    return image
def load_and_preprocess_image(image_path, ext):
    """Read an image file from disk and return it decoded and normalized.

    Args:
        image_path: Path of the image file (a string or string tensor).
        ext: File extension forwarded to ``preprocess_image`` to pick the
            decoder.

    Returns:
        Whatever ``preprocess_image`` returns for the file's raw contents.
    """
    return preprocess_image(tf.read_file(image_path), ext)
def get_sorted_image_path(path, ext):
    """List the files directly inside ``path`` whose names end with ``ext``.

    The search is not recursive: the glob pattern is ``"*" + ext``, applied
    to ``path`` itself. The result is a list of path strings sorted
    lexicographically. ``get_dataset`` pairs LR/HR files by this ordering.

    Args:
        path: Directory to search.
        ext: Filename suffix to match, e.g. ``'.png'``.

    Returns:
        A sorted list of matching file paths as strings.
    """
    pattern = "*" + ext
    directory = pathlib.Path(path)
    return sorted(str(match) for match in directory.glob(pattern))
def get_dataset(lr_path, hr_path, ext, seed=None, num_parallel_calls=8):
    """Build a dataset of (low-resolution, high-resolution) image pairs.

    Images are paired by position after sorting the file paths found in each
    directory. The LR and HR directories must therefore contain the same
    number of images, in matching sorted order. The pairs are shuffled once,
    in Python, before the dataset is created.

    Args:
        lr_path: Directory containing the low-resolution images.
        hr_path: Directory containing the high-resolution images.
        ext: File extension shared by all images, e.g. ``'.png'``.
        seed: Optional seed for the one-time pair shuffle. ``None`` (the
            default) uses the global ``random`` module state, as before.
        num_parallel_calls: Number of parallel decode calls passed to
            ``Dataset.map`` (default 8, the previous hard-coded value).

    Returns:
        A ``(dataset, n_pairs)`` tuple. ``dataset`` yields
        ``(lr_image, hr_image)`` produced by ``load_and_preprocess_image``.
        ``n_pairs`` is the number of image pairs.

    Raises:
        ValueError: if either directory has no matching images, or if the two
            directories contain different numbers of images.
    """
    lr_sorted_paths = get_sorted_image_path(lr_path, ext)
    hr_sorted_paths = get_sorted_image_path(hr_path, ext)
    if not lr_sorted_paths:
        raise ValueError(
            "No '*{}' images found in low-resolution directory {!r}".format(
                ext, lr_path))
    if not hr_sorted_paths:
        raise ValueError(
            "No '*{}' images found in high-resolution directory {!r}".format(
                ext, hr_path))
    if len(lr_sorted_paths) != len(hr_sorted_paths):
        # zip() would silently drop the extras and misalign the pairs.
        raise ValueError(
            "Found {} low-resolution images in {!r} but {} high-resolution "
            "images in {!r}; the directories must match one-to-one".format(
                len(lr_sorted_paths), lr_path, len(hr_sorted_paths), hr_path))

    pairs = list(zip(lr_sorted_paths, hr_sorted_paths))
    if seed is None:
        random.shuffle(pairs)
    else:
        random.Random(seed).shuffle(pairs)
    shuffled_lr = [lr for lr, _ in pairs]
    shuffled_hr = [hr for _, hr in pairs]
    ds = tf.data.Dataset.from_tensor_slices((shuffled_lr, shuffled_hr))

    def _load_pair(lr_image_path, hr_image_path):
        # Decode both images of one pair using the shared extension.
        return (load_and_preprocess_image(lr_image_path, ext),
                load_and_preprocess_image(hr_image_path, ext))

    lr_hr_ds = ds.map(_load_pair, num_parallel_calls=num_parallel_calls)
    return lr_hr_ds, len(pairs)
def load_train_dataset(lr_path, hr_path, ext, batch_size):
    """Create a one-shot iterator over batched training (LR, HR) image pairs.

    The dataset from ``get_dataset`` is batched and then repeated with no
    count. ``make_one_shot_iterator`` is the TF 1.x ``Dataset`` iterator API.

    Args:
        lr_path: Directory of low-resolution training images.
        hr_path: Directory of high-resolution training images.
        ext: Image file extension, e.g. ``'.png'``.
        batch_size: Number of pairs per batch.

    Returns:
        ``(iterator, n_data)``, where ``n_data`` is the number of image pairs
        reported by ``get_dataset``.
    """
    dataset, n_data = get_dataset(lr_path, hr_path, ext)
    iterator = dataset.batch(batch_size).repeat().make_one_shot_iterator()
    return iterator, n_data
def load_test_dataset(lr_path, hr_path, ext, batch_size):
    """Create a batched, repeated dataset of validation (LR, HR) image pairs.

    Unlike ``load_train_dataset``, this returns the dataset itself rather than
    an iterator. NOTE(review): ``get_dataset`` shuffles the pairs, so
    evaluation order is not fixed across runs. Confirm that is intended.

    Args:
        lr_path: Directory of low-resolution validation images.
        hr_path: Directory of high-resolution validation images.
        ext: Image file extension, e.g. ``'.png'``.
        batch_size: Number of pairs per batch.

    Returns:
        ``(dataset, n_data)``, where ``n_data`` is the number of image pairs
        reported by ``get_dataset``.
    """
    dataset, n_data = get_dataset(lr_path, hr_path, ext)
    return dataset.batch(batch_size).repeat(), n_data
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/greitzmann/Keras-Image-Super-Resolution.git
git@gitee.com:greitzmann/Keras-Image-Super-Resolution.git
greitzmann
Keras-Image-Super-Resolution
Keras-Image-Super-Resolution
master

搜索帮助