1 Star 0 Fork 0

deeplearningrepos/tensorflow-deeplab-resnet

加入 Gitee
与超过 1400万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
evaluate_msc.py 5.47 KB
一键复制 编辑 原始数据 按行查看 历史
DrSleep 提交于 2017-04-23 15:42 +08:00 . added NUM_CLASSES and IGNORE_LABEL flags
"""Evaluation script for the DeepLab-ResNet network on the validation subset
of PASCAL VOC dataset.
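This script evaluates the model on the validation images (1449 by default, see
--num-steps), fusing the predictions obtained at input scales 1.0, 0.75 and 0.5.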
"""
from __future__ import print_function
import argparse
from datetime import datetime
import os
import sys
import time
import tensorflow as tf
import numpy as np
from deeplab_resnet import DeepLabResNetModel, ImageReader, prepare_label
# Three-component float32 mean handed to ImageReader in main().
# NOTE(review): presumably a per-channel pixel mean that the reader subtracts
# from every image (channel order not visible here) -- confirm in ImageReader.
IMG_MEAN = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32)
# Default for --data-dir: directory containing the PASCAL VOC dataset.
DATA_DIRECTORY = '/home/VOCdevkit'
# Default for --data-list: file listing the images to evaluate.
DATA_LIST_PATH = './dataset/val.txt'
# Default for --ignore-label; main() forwards it to ImageReader.
IGNORE_LABEL = 255
# Default for --num-classes; labels above NUM_CLASSES - 1 get zero weight in
# the mean-IoU computation in main().
NUM_CLASSES = 21
# Default for --num-steps: how many evaluation steps main() runs.
NUM_STEPS = 1449 # Number of images in the validation set.
# Default for --restore-from: checkpoint the weights are restored from.
RESTORE_FROM = './deeplab_resnet.ckpt'
def get_arguments(argv=None):
    """Parse all the arguments provided from the CLI.

    Args:
      argv: optional list of argument strings to parse instead of the
        process command line. The default (None) makes argparse read
        ``sys.argv[1:]``, which is what callers passing no argument get.

    Returns:
      An ``argparse.Namespace`` with the attributes ``data_dir``,
      ``data_list``, ``ignore_label``, ``num_classes``, ``num_steps`` and
      ``restore_from``.
    """
    parser = argparse.ArgumentParser(description="DeepLabLFOV Network")
    parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY,
                        help="Path to the directory containing the PASCAL VOC dataset.")
    parser.add_argument("--data-list", type=str, default=DATA_LIST_PATH,
                        help="Path to the file listing the images in the dataset.")
    parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL,
                        help="The index of the label to ignore during the training.")
    parser.add_argument("--num-classes", type=int, default=NUM_CLASSES,
                        help="Number of classes to predict (including background).")
    parser.add_argument("--num-steps", type=int, default=NUM_STEPS,
                        help="Number of images in the validation set.")
    parser.add_argument("--restore-from", type=str, default=RESTORE_FROM,
                        help="Where restore model parameters from.")
    # parse_args(None) falls back to sys.argv[1:], so existing callers that
    # pass nothing behave exactly as before.
    return parser.parse_args(argv)
def load(saver, sess, ckpt_path):
    """Restore trained weights from a checkpoint and report it on stdout.

    Args:
      saver: TensorFlow saver object used to perform the restore.
      sess: TensorFlow session the parameters are restored into.
      ckpt_path: path to the checkpoint file with the parameters.
    """
    saver.restore(sess, ckpt_path)
    message = "Restored model parameters from {}".format(ckpt_path)
    print(message)
def main():
    """Create the model and start the evaluation process.

    Builds the network three times on the same image (original size and
    rescaled by 0.75 and 0.5) with shared variables, fuses the three score
    maps with an element-wise maximum, restores the checkpoint given by
    ``--restore-from`` and runs ``--num-steps`` evaluation steps while
    accumulating a streaming mean IoU, which is printed at the end.

    The queue-runner threads are stopped and the session is closed even when
    an evaluation step raises.
    """
    args = get_arguments()

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Load reader.
    with tf.name_scope("create_inputs"):
        reader = ImageReader(
            args.data_dir,
            args.data_list,
            None, # No defined input size.
            False, # No random scale.
            False, # No random mirror.
            args.ignore_label,
            IMG_MEAN,
            coord)
        image, label = reader.image, reader.label
    image_batch, label_batch = tf.expand_dims(image, dim=0), tf.expand_dims(label, dim=0) # Add one batch dimension.

    # Dynamic height/width of the input, used to derive the two smaller scales.
    h_orig, w_orig = tf.to_float(tf.shape(image_batch)[1]), tf.to_float(tf.shape(image_batch)[2])
    image_batch075 = tf.image.resize_images(
        image_batch,
        tf.stack([tf.to_int32(tf.multiply(h_orig, 0.75)), tf.to_int32(tf.multiply(w_orig, 0.75))]))
    image_batch05 = tf.image.resize_images(
        image_batch,
        tf.stack([tf.to_int32(tf.multiply(h_orig, 0.5)), tf.to_int32(tf.multiply(w_orig, 0.5))]))

    # Create network. The first instance creates the variables; the two
    # rescaled instances reuse them (reuse=True).
    with tf.variable_scope('', reuse=False):
        net = DeepLabResNetModel({'data': image_batch}, is_training=False, num_classes=args.num_classes)
    with tf.variable_scope('', reuse=True):
        net075 = DeepLabResNetModel({'data': image_batch075}, is_training=False, num_classes=args.num_classes)
    with tf.variable_scope('', reuse=True):
        net05 = DeepLabResNetModel({'data': image_batch05}, is_training=False, num_classes=args.num_classes)

    # Which variables to load.
    restore_var = tf.global_variables()

    # Predictions: bring the two smaller-scale outputs to the size of the
    # full-scale output, then keep the element-wise maximum of the three.
    raw_output100 = net.layers['fc1_voc12']
    raw_output075 = tf.image.resize_images(net075.layers['fc1_voc12'], tf.shape(raw_output100)[1:3,])
    raw_output05 = tf.image.resize_images(net05.layers['fc1_voc12'], tf.shape(raw_output100)[1:3,])
    raw_output = tf.reduce_max(tf.stack([raw_output100, raw_output075, raw_output05]), axis=0)
    # Upsample the fused scores to the input resolution before taking the
    # arg-max over axis 3 (presumably the class axis -- confirm in the model).
    raw_output = tf.image.resize_bilinear(raw_output, tf.shape(image_batch)[1:3,])
    raw_output = tf.argmax(raw_output, dimension=3)
    pred = tf.expand_dims(raw_output, dim=3) # Create 4-d tensor.

    # mIoU
    pred = tf.reshape(pred, [-1,])
    gt = tf.reshape(label_batch, [-1,])
    weights = tf.cast(tf.less_equal(gt, args.num_classes - 1), tf.int32) # Ignoring all labels greater than or equal to n_classes.
    mIoU, update_op = tf.contrib.metrics.streaming_mean_iou(pred, gt, num_classes=args.num_classes, weights=weights)

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)
    sess.run(tf.local_variables_initializer())

    # Load weights.
    loader = tf.train.Saver(var_list=restore_var)
    if args.restore_from is not None:
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    try:
        # Iterate over the evaluation steps. Only the metric update is run:
        # the predictions themselves are not needed on the Python side.
        for step in range(args.num_steps):
            sess.run(update_op)
            if step % 100 == 0:
                print('step {:d}'.format(step))
        print('Mean IoU: {:.3f}'.format(mIoU.eval(session=sess)))
    finally:
        # Always stop the input threads and release the session, including
        # when an evaluation step fails.
        coord.request_stop()
        try:
            coord.join(threads)
        finally:
            sess.close()
# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/deeplearningrepos/tensorflow-deeplab-resnet.git
git@gitee.com:deeplearningrepos/tensorflow-deeplab-resnet.git
deeplearningrepos
tensorflow-deeplab-resnet
tensorflow-deeplab-resnet
master

搜索帮助