2 Star 0 Fork 0

TinyMind/quiz-object-detection

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
inference.py 2.80 KB
一键复制 编辑 原始数据 按行查看 历史
dwSun 提交于 7年前 . fix path and training step.
import argparse
import os
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from PIL import Image
from utils import visualization_utils as vis_util
from utils import label_map_util
# Passed as max_num_classes when the label map is converted into categories.
NUM_CLASSES = 5
def parse_args(check=True, argv=None):
    """Parse the command-line options that locate the model and data files.

    Args:
        check: Unused by this function; retained so existing callers that
            pass it keep working.
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            original behaviour.

    Returns:
        A ``(FLAGS, unparsed)`` tuple: the parsed namespace carrying
        ``output_dir`` and ``dataset_dir``, and the list of argument strings
        that argparse did not recognise.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True)
    parser.add_argument('--dataset_dir', type=str, required=True)
    # parse_known_args (rather than parse_args) so unrecognised flags are
    # returned to the caller instead of aborting the program.
    FLAGS, unparsed = parser.parse_known_args(argv)
    return FLAGS, unparsed
def load_image_into_numpy_array(image):
    """Return a PIL image as a writable ``(height, width, 3)`` uint8 array.

    The image is converted to RGB first, so greyscale, palette, RGBA or CMYK
    inputs yield three channels instead of failing on the reshape the
    previous implementation relied on. ``np.array`` makes a copy, so the
    result can be drawn on in place.
    """
    return np.array(image.convert('RGB'), dtype=np.uint8)


def _load_frozen_graph(path_to_ckpt):
    """Import the serialized GraphDef at *path_to_ckpt* into a new tf.Graph."""
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(path_to_ckpt, 'rb') as fid:
            od_graph_def.ParseFromString(fid.read())
        # name='' keeps the tensor names unprefixed so they can be looked up
        # as 'image_tensor:0', 'detection_boxes:0', etc.
        tf.import_graph_def(od_graph_def, name='')
    return detection_graph


def _load_category_index(path_to_labels):
    """Build the category index used for labelling from a label-map file."""
    label_map = label_map_util.load_labelmap(path_to_labels)
    categories = label_map_util.convert_label_map_to_categories(
        label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
    return label_map_util.create_category_index(categories)


def main():
    """Run detection on ``test.jpg`` and save the annotated image.

    Reads ``frozen_inference_graph.pb`` from ``--output_dir`` and
    ``labels_items.txt`` plus ``test.jpg`` from ``--dataset_dir``; writes
    ``output.png`` into ``--output_dir``.
    """
    flags, _unparsed = parse_args()

    path_to_ckpt = os.path.join(flags.output_dir, 'frozen_inference_graph.pb')
    path_to_labels = os.path.join(flags.dataset_dir, 'labels_items.txt')
    test_img_path = os.path.join(flags.dataset_dir, 'test.jpg')

    detection_graph = _load_frozen_graph(path_to_ckpt)
    category_index = _load_category_index(path_to_labels)

    with detection_graph.as_default():
        with tf.Session(graph=detection_graph) as sess:
            image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
            detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
            detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
            detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
            num_detections = detection_graph.get_tensor_by_name('num_detections:0')

            # Context manager so the underlying file handle is released as
            # soon as the pixel data has been copied out.
            with Image.open(test_img_path) as image:
                image_np = load_image_into_numpy_array(image)

            # The graph is fed a batch, so add a leading axis of size 1.
            image_np_expanded = np.expand_dims(image_np, axis=0)
            boxes, scores, classes, _num = sess.run(
                [detection_boxes, detection_scores, detection_classes, num_detections],
                feed_dict={image_tensor: image_np_expanded})

            # np.squeeze drops the size-1 batch axis added above; the helper
            # receives image_np itself and the same array is saved afterwards.
            vis_util.visualize_boxes_and_labels_on_image_array(
                image_np,
                np.squeeze(boxes),
                np.squeeze(classes).astype(np.int32),
                np.squeeze(scores),
                category_index,
                use_normalized_coordinates=True,
                line_thickness=8)
            plt.imsave(os.path.join(flags.output_dir, 'output.png'), image_np)


if __name__ == '__main__':
    main()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/ai100/quiz-object-detection.git
git@gitee.com:ai100/quiz-object-detection.git
ai100
quiz-object-detection
quiz-object-detection
master

搜索帮助