1 Star 0 Fork 0

ErzongXie/tensorrt_demos

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
trt_yolo_mjpeg.py 3.13 KB
一键复制 编辑 原始数据 按行查看 历史
JK Jung 提交于 2022-04-03 14:29 . Add support of 'yolov4-p5' models
"""trt_yolo_mjpeg.py

MJPEG version of trt_yolo.py.

Reads frames from a camera/video source (options added by
utils.camera.add_camera_args), runs the TensorRT YOLO engine loaded from
yolo/<model>.trt on each frame, draws the detections plus an FPS
estimate, and publishes the annotated frames through an MjpegServer
listening on --mjpeg_port (default 8080).
"""
import os
import time
import argparse
import cv2
import pycuda.autoinit # This is needed for initializing CUDA driver
from utils.yolo_classes import get_cls_dict
from utils.camera import add_camera_args, Camera
from utils.display import show_fps
from utils.visualization import BBoxVisualization
from utils.mjpeg import MjpegServer
from utils.yolo_with_plugins import TrtYOLO
def parse_args():
    """Build the command-line interface and return the parsed arguments.

    The camera/video-source options are contributed by ``add_camera_args``;
    the detector options (category count, model name, letterboxing) and the
    MJPEG server port are added here.

    # Returns
      argparse.Namespace with the parsed options.
    """
    parser = add_camera_args(
        argparse.ArgumentParser(description='MJPEG version of trt_yolo'))

    parser.add_argument('-c', '--category_num', type=int, default=80,
                        help='number of object categories [80]')

    model_help = ('[yolov3-tiny|yolov3|yolov3-spp|yolov4-tiny|yolov4|'
                  'yolov4-csp|yolov4x-mish|yolov4-p5]-[{dimension}], where '
                  '{dimension} could be either a single number (e.g. '
                  '288, 416, 608) or 2 numbers, WxH (e.g. 416x256)')
    parser.add_argument('-m', '--model', type=str, required=True,
                        help=model_help)

    parser.add_argument('-l', '--letter_box', action='store_true',
                        help='inference with letterboxed image [False]')
    parser.add_argument('-p', '--mjpeg_port', type=int, default=8080,
                        help='MJPEG server port [8080]')

    return parser.parse_args()
def loop_and_detect(cam, trt_yolo, conf_th, vis, mjpeg_server):
    """Continuously capture images from camera and do object detection.

    Each frame is passed to the detector, annotated with the detected boxes
    and the current FPS estimate, and handed to the MJPEG server. The loop
    ends when ``cam.read()`` returns None.

    # Arguments
      cam: the camera instance (video source).
      trt_yolo: the TRT YOLO object detector instance.
      conf_th: confidence/score threshold for object detection.
      vis: for visualization (draws the bounding boxes).
      mjpeg_server: server that receives each annotated frame through
        ``send_img``.
    """
    fps = 0.0
    # Use a monotonic high-resolution clock for frame intervals:
    # time.time() follows the wall clock, which can be stepped backwards or
    # stand still, producing a negative FPS or a ZeroDivisionError.
    tic = time.perf_counter()
    while True:
        img = cam.read()
        if img is None:
            break
        boxes, confs, clss = trt_yolo.detect(img, conf_th)
        img = vis.draw_bboxes(img, boxes, confs, clss)
        img = show_fps(img, fps)
        mjpeg_server.send_img(img)
        toc = time.perf_counter()
        elapsed = toc - tic
        tic = toc
        if elapsed <= 0.0:
            # Interval too small for the clock to resolve; keep the previous
            # estimate instead of dividing by zero.
            continue
        curr_fps = 1.0 / elapsed
        # calculate an exponentially decaying average of fps number
        fps = curr_fps if fps == 0.0 else (fps * 0.95 + curr_fps * 0.05)
def main():
    """Entry point: validate arguments, set up the pipeline, and stream.

    Raises SystemExit with an error message when the category count is not
    positive, when the engine file yolo/<model>.trt is missing, or when the
    camera cannot be opened. Once the camera is open it is always released,
    even if building the detector or the MJPEG server fails.
    """
    import traceback

    args = parse_args()
    if args.category_num <= 0:
        raise SystemExit('ERROR: bad category_num (%d)!' % args.category_num)
    if not os.path.isfile('yolo/%s.trt' % args.model):
        raise SystemExit('ERROR: file (yolo/%s.trt) not found!' % args.model)

    cam = Camera(args)
    if not cam.isOpened():
        raise SystemExit('ERROR: failed to open camera!')

    # Outer try guarantees the camera is released even if detector or
    # server construction raises (previously those ran outside any try).
    try:
        cls_dict = get_cls_dict(args.category_num)
        vis = BBoxVisualization(cls_dict)
        trt_yolo = TrtYOLO(args.model, args.category_num, args.letter_box)

        mjpeg_server = MjpegServer(port=args.mjpeg_port)
        print('MJPEG server started...')
        try:
            loop_and_detect(cam, trt_yolo, conf_th=0.3, vis=vis,
                            mjpeg_server=mjpeg_server)
        except Exception:
            # Report the full traceback (type, message, location); print(e)
            # alone loses the type and is blank for message-less errors.
            traceback.print_exc()
        finally:
            # Stop the server before releasing the camera, as before.
            mjpeg_server.shutdown()
    finally:
        cam.release()
# Run the detector/streamer only when executed as a script, not on import.
if __name__ == '__main__':
    main()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/erzongxie/tensorrt_demos.git
git@gitee.com:erzongxie/tensorrt_demos.git
erzongxie
tensorrt_demos
tensorrt_demos
master

搜索帮助

Dd8185d8 1850385 E526c682 1850385