import torch
import cv2
import numpy as np
from tracker import Tracker
from utils.headpose import get_head_pose
import time
import onnxruntime as rt
class Detector:
    def __init__(self, detection_size=(160, 160)):
        # The ONNX landmark model expects a fixed-size face crop (default 160x160).
        self.sess = rt.InferenceSession("pretrained_weights/slim_160_latest.onnx")
        self.input_name = self.sess.get_inputs()[0].name
        self.detection_size = detection_size
        self.tracker = Tracker()

    def crop_image(self, orig, bbox):
        # Expand the face box by 20% on each side, clamp it to the image bounds,
        # and resize the crop to the model's input size.
        bbox = bbox.copy()
        image = orig.copy()
        bbox_width = bbox[2] - bbox[0]
        bbox_height = bbox[3] - bbox[1]
        face_width = (1 + 2 * 0.2) * bbox_width
        face_height = (1 + 2 * 0.2) * bbox_height
        center = [(bbox[0] + bbox[2]) // 2, (bbox[1] + bbox[3]) // 2]
        bbox[0] = max(0, center[0] - face_width // 2)
        bbox[1] = max(0, center[1] - face_height // 2)
        bbox[2] = min(image.shape[1], center[0] + face_width // 2)
        bbox[3] = min(image.shape[0], center[1] + face_height // 2)
        bbox = bbox.astype(int)  # np.int is removed in NumPy >= 1.24; use the builtin int
        crop_image = image[bbox[1]:bbox[3], bbox[0]:bbox[2], :]
        h, w, _ = crop_image.shape
        crop_image = cv2.resize(crop_image, self.detection_size)
        # (h, w) is the crop size in original-image pixels; (bbox[1], bbox[0]) is the
        # crop's top-left corner, used later to map landmarks back to image coordinates.
        return crop_image, [h, w, bbox[1], bbox[0]]
    def detect(self, img, bbox):
        crop_image, detail = self.crop_image(img, bbox)
        # Normalize pixel values to roughly [-1, 1] and convert HWC -> NCHW with a batch dim.
        crop_image = (crop_image - 127.0) / 127.0
        crop_image = np.array([np.transpose(crop_image, (2, 0, 1))]).astype(np.float32)
        start = time.time()
        raw = self.sess.run(None, {self.input_name: crop_image})[0][0]
        end = time.time()
        print("ONNX Inference Time: {:.6f}".format(end - start))
        # The first 136 outputs are 68 (x, y) landmark pairs normalized to the crop;
        # scale by the crop size and add the crop offset to recover image coordinates.
        landmark = raw[0:136].reshape((-1, 2))
        landmark[:, 0] = landmark[:, 0] * detail[1] + detail[3]
        landmark[:, 1] = landmark[:, 1] * detail[0] + detail[2]
        # Stabilize the landmarks with the tracker, then estimate head pose (pitch/roll/yaw).
        landmark = self.tracker.track(img, landmark)
        _, PRY_3d = get_head_pose(landmark, img)
        return landmark, PRY_3d[:, 0]
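

# Minimal usage sketch (not part of the original file): it assumes a test image
# "sample.jpg" and a hand-specified face bounding box; in practice the bbox would
# come from a separate face detector upstream of this landmark model.
if __name__ == "__main__":
    detector = Detector()
    frame = cv2.imread("sample.jpg")  # hypothetical test image
    face_bbox = np.array([120, 80, 360, 340], dtype=np.float32)  # x1, y1, x2, y2 (hypothetical)
    landmarks, pry = detector.detect(frame, face_bbox)
    print("pitch/roll/yaw:", pry)
    # Draw the predicted landmarks and save the result for a quick visual check.
    for (x, y) in landmarks.astype(int):
        cv2.circle(frame, (int(x), int(y)), 2, (0, 255, 0), -1)
    cv2.imwrite("sample_landmarks.jpg", frame)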