Fetch the repository succeeded.
# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import os
import cv2
import glob
import torch
import tensorrt
import numpy as np
import pycuda.driver as cuda
def eval_batch(batch_score, batch_label):
    """Count top-1 and top-5 hits for one batch of classifier scores.

    Args:
        batch_score: per-sample class scores; row ``i`` holds the scores of
            sample ``i``. Assumed to be a 2-D (batch, num_classes) numpy
            array or array-like -- TODO confirm against callers.
        batch_label: iterable of ground-truth class indices, one per sample,
            in the same order as the rows of ``batch_score``.

    Returns:
        Tuple ``(top1, top5)``: the number of samples whose label is the
        highest-scoring class, and the number whose label is among the
        (up to) five highest-scoring classes.
    """
    # torch.as_tensor converts in one step. The former
    # torch.tensor(torch.from_numpy(...)) copy-constructed a tensor from a
    # tensor, which triggers a PyTorch UserWarning and an extra copy.
    scores = torch.as_tensor(batch_score, dtype=torch.float32)
    # topk raises if k exceeds the size of the last dimension, so clamp it
    # for models with fewer than five classes.
    k = min(5, scores.shape[-1])
    _, indices = scores.topk(k)
    top1, top5 = 0, 0
    for idx, label in enumerate(batch_label):
        candidates = indices[idx]
        if label == candidates[0]:
            top1 += 1
        if label in candidates:
            top5 += 1
    return top1, top5
def create_engine_context(engine_path, logger):
    """Deserialize an engine file and create an execution context for it.

    Args:
        engine_path: path to the serialized engine; the file is read in
            binary mode.
        logger: logger handed to ``tensorrt.Runtime`` (presumably a
            ``tensorrt.Logger`` -- confirm against callers).

    Returns:
        Tuple ``(engine, context)``: the deserialized engine and an
        execution context created from it.

    Raises:
        RuntimeError: if the runtime, the engine, or the execution context
            comes back falsy. These checks were previously ``assert``
            statements, which are removed under ``python -O``, so a failed
            step would have gone unnoticed until a later call on the bad
            value.
    """
    with open(engine_path, "rb") as f:
        runtime = tensorrt.Runtime(logger)
        if not runtime:
            raise RuntimeError("failed to create TensorRT runtime")
        engine = runtime.deserialize_cuda_engine(f.read())
    if not engine:
        raise RuntimeError(f"failed to deserialize engine from {engine_path!r}")
    context = engine.create_execution_context()
    if not context:
        raise RuntimeError("failed to create execution context")
    return engine, context
def get_io_bindings(engine):
    """Allocate a device buffer for every engine binding and describe it.

    Uses the binding-index API (``num_bindings``, ``binding_is_input``,
    ``get_binding_name/dtype/shape``) of the imported ``tensorrt`` module.

    Args:
        engine: deserialized engine, e.g. the first value returned by
            ``create_engine_context``.

    Returns:
        Tuple ``(inputs, outputs, allocations)``. ``inputs`` and ``outputs``
        are lists of dicts with keys ``"index"``, ``"name"``, ``"dtype"``
        (a numpy dtype), ``"shape"`` (a list) and ``"allocation"`` (the
        buffer returned by ``cuda.mem_alloc``). ``allocations`` holds every
        buffer in binding-index order.
    """
    inputs = []
    outputs = []
    allocations = []
    for i in range(engine.num_bindings):
        is_input = engine.binding_is_input(i)
        name = engine.get_binding_name(i)
        # Convert the engine dtype to a numpy dtype once and reuse it.
        dtype = np.dtype(tensorrt.nptype(engine.get_binding_dtype(i)))
        shape = engine.get_binding_shape(i)
        # Buffer size in bytes: element size times the product of all dims.
        # NOTE(review): a dynamic dimension (-1) would make this negative;
        # presumably the engines used here have static shapes -- confirm.
        size = dtype.itemsize
        for s in shape:
            size *= s
        allocation = cuda.mem_alloc(size)
        binding = {
            "index": i,
            "name": name,
            "dtype": dtype,
            "shape": list(shape),
            "allocation": allocation,
        }
        print(f"binding {i}, name : {name} dtype : {dtype} shape : {list(shape)}")
        allocations.append(allocation)
        if is_input:
            inputs.append(binding)
        else:
            outputs.append(binding)
    return inputs, outputs, allocations
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。