# test_Fuji.py: Learning-to-See-in-the-Dark test script for Fuji X-Trans raw images

from __future__ import division
import os
import glob
import numpy as np
import rawpy
import scipy.misc  # used below for scipy.misc.toimage (requires scipy < 1.2)
import tensorflow as tf
import tensorflow.contrib.slim as slim

input_dir = './dataset/Fuji/short/'
gt_dir = './dataset/Fuji/long/'
checkpoint_dir = './checkpoint/Fuji/'
result_dir = './result_Fuji/'
# get test IDs
test_fns = glob.glob(gt_dir + '1*.RAF')
test_ids = [int(os.path.basename(test_fn)[0:5]) for test_fn in test_fns]
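# In the SID dataset, scene IDs starting with '1' form the test split,
# so the glob above collects test scenes only.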

def lrelu(x):
    # Leaky ReLU: identity for positive inputs, slope 0.2 for negative inputs.
    return tf.maximum(x * 0.2, x)

def upsample_and_concat(x1, x2, output_channels, in_channels):
    # Learned 2x upsampling via transposed convolution, then concatenation
    # with the matching encoder feature map (skip connection).
    pool_size = 2
    deconv_filter = tf.Variable(
        tf.truncated_normal([pool_size, pool_size, output_channels, in_channels], stddev=0.02))
    deconv = tf.nn.conv2d_transpose(x1, deconv_filter, tf.shape(x2), strides=[1, pool_size, pool_size, 1])

    deconv_output = tf.concat([deconv, x2], 3)
    deconv_output.set_shape([None, None, None, output_channels * 2])

    return deconv_output
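
# Design note: the 2x upsampling here is learned (transposed conv initialized
# from a truncated normal). A common alternative sketch (an assumption, not the
# authors' method) is nearest-neighbor resize followed by a 3x3 conv, which
# tends to avoid checkerboard artifacts at some cost in flexibility.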

def network(input):
    # U-Net: four 2x-downsampling encoder stages, a bottleneck, and four
    # decoder stages with skip connections.
    conv1 = slim.conv2d(input, 32, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv1_1')
    conv1 = slim.conv2d(conv1, 32, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv1_2')
    pool1 = slim.max_pool2d(conv1, [2, 2], padding='SAME')

    conv2 = slim.conv2d(pool1, 64, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv2_1')
    conv2 = slim.conv2d(conv2, 64, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv2_2')
    pool2 = slim.max_pool2d(conv2, [2, 2], padding='SAME')

    conv3 = slim.conv2d(pool2, 128, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv3_1')
    conv3 = slim.conv2d(conv3, 128, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv3_2')
    pool3 = slim.max_pool2d(conv3, [2, 2], padding='SAME')

    conv4 = slim.conv2d(pool3, 256, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv4_1')
    conv4 = slim.conv2d(conv4, 256, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv4_2')
    pool4 = slim.max_pool2d(conv4, [2, 2], padding='SAME')

    conv5 = slim.conv2d(pool4, 512, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv5_1')
    conv5 = slim.conv2d(conv5, 512, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv5_2')

    up6 = upsample_and_concat(conv5, conv4, 256, 512)
    conv6 = slim.conv2d(up6, 256, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv6_1')
    conv6 = slim.conv2d(conv6, 256, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv6_2')

    up7 = upsample_and_concat(conv6, conv3, 128, 256)
    conv7 = slim.conv2d(up7, 128, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv7_1')
    conv7 = slim.conv2d(conv7, 128, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv7_2')

    up8 = upsample_and_concat(conv7, conv2, 64, 128)
    conv8 = slim.conv2d(up8, 64, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv8_1')
    conv8 = slim.conv2d(conv8, 64, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv8_2')

    up9 = upsample_and_concat(conv8, conv1, 32, 64)
    conv9 = slim.conv2d(up9, 32, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv9_1')
    conv9 = slim.conv2d(conv9, 32, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv9_2')

    conv10 = slim.conv2d(conv9, 27, [1, 1], rate=1, activation_fn=None, scope='g_conv10')
    out = tf.depth_to_space(conv10, 3)
    return out
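
# Shape arithmetic (a sketch, assuming the pack_raw layout below): conv10 has
# 27 channels, and tf.depth_to_space(conv10, 3) rearranges each 27-channel
# pixel into a 3x3 block of RGB pixels, undoing the 3x spatial downsampling
# that pack_raw applies to the X-Trans mosaic.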

def pack_raw(raw):
    # pack an X-Trans image into 9 channels at 1/3 spatial resolution
    im = raw.raw_image_visible.astype(np.float32)
    # subtract the black level (1024) and normalize by the 14-bit range
    im = np.maximum(im - 1024, 0) / (16383 - 1024)

    img_shape = im.shape
    H = (img_shape[0] // 6) * 6  # crop to a multiple of the 6x6 X-Trans tile
    W = (img_shape[1] // 6) * 6

    out = np.zeros((H // 3, W // 3, 9))

    # channel 0: R
    out[0::2, 0::2, 0] = im[0:H:6, 0:W:6]
    out[0::2, 1::2, 0] = im[0:H:6, 4:W:6]
    out[1::2, 0::2, 0] = im[3:H:6, 1:W:6]
    out[1::2, 1::2, 0] = im[3:H:6, 3:W:6]

    # channel 1: G
    out[0::2, 0::2, 1] = im[0:H:6, 2:W:6]
    out[0::2, 1::2, 1] = im[0:H:6, 5:W:6]
    out[1::2, 0::2, 1] = im[3:H:6, 2:W:6]
    out[1::2, 1::2, 1] = im[3:H:6, 5:W:6]

    # channel 2: B
    out[0::2, 0::2, 2] = im[0:H:6, 1:W:6]
    out[0::2, 1::2, 2] = im[0:H:6, 3:W:6]
    out[1::2, 0::2, 2] = im[3:H:6, 0:W:6]
    out[1::2, 1::2, 2] = im[3:H:6, 4:W:6]

    # channel 3: R
    out[0::2, 0::2, 3] = im[1:H:6, 2:W:6]
    out[0::2, 1::2, 3] = im[2:H:6, 5:W:6]
    out[1::2, 0::2, 3] = im[5:H:6, 2:W:6]
    out[1::2, 1::2, 3] = im[4:H:6, 5:W:6]

    # channel 4: B
    out[0::2, 0::2, 4] = im[2:H:6, 2:W:6]
    out[0::2, 1::2, 4] = im[1:H:6, 5:W:6]
    out[1::2, 0::2, 4] = im[4:H:6, 2:W:6]
    out[1::2, 1::2, 4] = im[5:H:6, 5:W:6]

    # channels 5-8: G
    out[:, :, 5] = im[1:H:3, 0:W:3]
    out[:, :, 6] = im[1:H:3, 1:W:3]
    out[:, :, 7] = im[2:H:3, 0:W:3]
    out[:, :, 8] = im[2:H:3, 1:W:3]

    return out
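
# A minimal sanity check, as a sketch (not part of the original script); it
# assumes a Fuji X-Trans raw with the usual 14-bit levels handled above.
def _check_packed(packed):
    # pack_raw emits a 3x-downsampled image with 9 channels (2 R, 2 B, and
    # 5 G planes from the 6x6 X-Trans tile), normalized to [0, 1].
    assert packed.ndim == 3 and packed.shape[2] == 9
    assert packed.min() >= 0.0 and packed.max() <= 1.0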

sess = tf.Session()
in_image = tf.placeholder(tf.float32, [None, None, None, 9])
gt_image = tf.placeholder(tf.float32, [None, None, None, 3])  # defined for parity with training; unused below
out_image = network(in_image)

saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt:
    print('loaded ' + ckpt.model_checkpoint_path)
    saver.restore(sess, ckpt.model_checkpoint_path)

if not os.path.isdir(result_dir + 'final/'):
    os.makedirs(result_dir + 'final/')

for test_id in test_ids:
    # test the first image in each sequence
    in_files = glob.glob(input_dir + '%05d_00*.RAF' % test_id)
    for k in range(len(in_files)):
        in_path = in_files[k]
        in_fn = os.path.basename(in_path)
        print(in_fn)
        gt_files = glob.glob(gt_dir + '%05d_00*.RAF' % test_id)
        gt_path = gt_files[0]
        gt_fn = os.path.basename(gt_path)

        # exposure times are encoded in the filenames, e.g. '10003_00_0.1s.RAF'
        in_exposure = float(in_fn[9:-5])
        gt_exposure = float(gt_fn[9:-5])
        ratio = min(gt_exposure / in_exposure, 300)  # amplification ratio, capped at 300

        raw = rawpy.imread(in_path)
        input_full = np.expand_dims(pack_raw(raw), axis=0) * ratio

        im = raw.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
        # scale_full = np.expand_dims(np.float32(im/65535.0),axis = 0)*ratio  # scale the low-light image using the same ratio
        scale_full = np.expand_dims(np.float32(im / 65535.0), axis=0)

        gt_raw = rawpy.imread(gt_path)
        im = gt_raw.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
        gt_full = np.expand_dims(np.float32(im / 65535.0), axis=0)

        input_full = np.minimum(input_full, 1.0)

        output = sess.run(out_image, feed_dict={in_image: input_full})
        output = np.minimum(np.maximum(output, 0), 1)

        _, H, W, _ = output.shape
        output = output[0, :, :, :]
        gt_full = gt_full[0, 0:H, 0:W, :]
        scale_full = scale_full[0, 0:H, 0:W, :]
        # scale the low-light image to the same mean as the ground truth
        scale_full = scale_full * np.mean(gt_full) / np.mean(scale_full)

        scipy.misc.toimage(output * 255, high=255, low=0, cmin=0, cmax=255).save(
            result_dir + 'final/%05d_00_%d_out.png' % (test_id, ratio))
        scipy.misc.toimage(scale_full * 255, high=255, low=0, cmin=0, cmax=255).save(
            result_dir + 'final/%05d_00_%d_scale.png' % (test_id, ratio))
        scipy.misc.toimage(gt_full * 255, high=255, low=0, cmin=0, cmax=255).save(
            result_dir + 'final/%05d_00_%d_gt.png' % (test_id, ratio))
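
# Usage note (an assumption based on the paths above, not from the original
# file): with the SID Fuji raws under ./dataset/Fuji/ and a trained checkpoint
# under ./checkpoint/Fuji/, running `python test_Fuji.py` writes
# <id>_00_<ratio>_{out,scale,gt}.png files into ./result_Fuji/final/.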