1 Star 0 Fork 0

ptz / tflite_python

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
model_resize.py 2.49 KB
一键复制 编辑 原始数据 按行查看 历史
#!/usr/bin/python3
import os, sys
import cv2
import numpy as np
import glob
from tflite_runtime.interpreter import Interpreter
# --- Script configuration -------------------------------------------------
# Directory (relative to the working directory) holding the model files.
MODEL_NAME = "model"
# File names of the TFLite graph and its label map inside MODEL_NAME.
GRAPH_NAME = "detect.tflite"
LABELMAP_NAME = "labelmap.txt"

# Confidence cut-off for detections (defined here; not read in this section).
min_conf_threshold = 0.5

# Input selection: a non-empty IM_DIR takes precedence over IM_NAME.
IM_NAME = "images/out.jpg"
IM_DIR = None

# All relative paths in this script are resolved against this directory.
CWD_PATH = os.getcwd()
# Build the list of input image paths.
# Both names are pre-bound so that leaving IM_DIR and IM_NAME unset yields an
# empty list instead of a NameError at the print below.
PATH_TO_IMAGES = None
images = []
if IM_DIR:
    # Directory mode: take every entry inside the directory.
    PATH_TO_IMAGES = os.path.join(CWD_PATH, IM_DIR)
    # glob returns matches in arbitrary order; sort so that the "first image"
    # is the same on every run.
    images = sorted(glob.glob(PATH_TO_IMAGES + '/*'))
elif IM_NAME:
    # Single-file mode: glob yields [] when the file does not exist.
    PATH_TO_IMAGES = os.path.join(CWD_PATH, IM_NAME)
    images = sorted(glob.glob(PATH_TO_IMAGES))
print("images: ", images)
# Absolute paths to the TFLite graph and to the label map.
PATH_TO_CKPT = os.path.join(CWD_PATH, MODEL_NAME, GRAPH_NAME)
PATH_TO_LABELS = os.path.join(CWD_PATH, MODEL_NAME, LABELMAP_NAME)

# One label per line, surrounding whitespace stripped.
with open(PATH_TO_LABELS, 'r') as f:
    labels = [line.strip() for line in f]

# Drop a leading '???' entry (presumably a placeholder row in some label
# maps -- confirm against the model in use). The emptiness guard keeps an
# empty label file from raising IndexError here.
if labels and labels[0] == '???':
    del labels[0]
# Load the TFLite graph and allocate its tensors before querying metadata.
interpreter = Interpreter(model_path=PATH_TO_CKPT)
interpreter.allocate_tensors()

# Input/output tensor descriptions reported by the interpreter.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# The first input's shape is read as [batch, height, width, channels].
_input_spec = input_details[0]
height, width = _input_spec['shape'][1], _input_spec['shape'][2]
print("input_details: ", input_details)
print("output_details: ", output_details)

# A float32 input means the pixel data is normalized before inference.
floating_model = _input_spec['dtype'] == np.float32
input_mean = input_std = 127.5
# Only the first resolved image is processed.
if not images:
    # Fail with an actionable message instead of a bare IndexError.
    raise FileNotFoundError(
        "no input image found; check IM_DIR / IM_NAME and the working directory")
image_path = images[0]

# Load the image. cv2.imread signals an unreadable file by returning None
# rather than raising, so check explicitly.
image = cv2.imread(image_path)
if image is None:
    raise ValueError("could not read image: {}".format(image_path))

# NOTE(review): the BGR->RGB conversion and the resize to (width, height)
# are disabled in this script, so the image is fed to the model exactly as
# loaded. It must therefore already have the model's input height/width --
# confirm this is intended for the images being tested.
image_resized = image

# Add the leading batch axis: HxWxC -> 1xHxWxC.
input_data = np.expand_dims(image_resized, axis=0)

# Normalize pixel values if using a floating model (i.e. non-quantized):
# (x - 127.5) / 127.5.
if floating_model:
    input_data = (np.float32(input_data) - input_mean) / input_std

# Perform the actual detection by running the model with the image as input.
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
# Fetch the detection outputs, keeping only the first (and only) batch entry.
# Output slots are read as: 0 = bounding box coordinates, 1 = class indices,
# 2 = confidence scores. The fourth output (detection count) is not read.
# NOTE(review): this slot order is an assumption about the exported model --
# confirm against the printed output_details.
boxes, classes, scores = (
    interpreter.get_tensor(output_details[slot]['index'])[0]
    for slot in (0, 1, 2)
)

# Dump each result under its banner line, then the number of score entries.
print("boxes =====================", boxes, sep="\n")
print("classes =====================", classes, sep="\n")
print("scores =====================", scores, sep="\n")
print(len(scores))
Python
1
https://gitee.com/ptz1986/tflite_python.git
git@gitee.com:ptz1986/tflite_python.git
ptz1986
tflite_python
tflite_python
master

搜索帮助