108 Star 868 Fork 1.5K

MindSpore/models

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
eval.py 4.62 KB
一键复制 编辑 原始数据 按行查看 历史
JessonGuo 提交于 3年前 . pointpillars
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation script.

Loads a PointPillars checkpoint and runs it over the evaluation split
described by the config file. The detections are converted to KITTI-style
annotations, and the report returned by ``get_official_eval_result`` is
printed.

Usage::

    python eval.py --cfg_path <config> --ckpt_path <checkpoint> [--device_target GPU]
"""
import argparse
import os
import warnings
from time import time
from mindspore import context, Tensor
from mindspore import dataset as de
from mindspore import load_checkpoint
from mindspore import load_param_into_net
from src.core.eval_utils import get_official_eval_result
from src.predict import predict
from src.predict import predict_kitti_to_anno
from src.utils import get_config
from src.utils import get_model_dataset
from src.utils import get_params_for_net
# Silence every Python warning for the whole process, including warnings
# raised by imported libraries. Remove this line when debugging.
warnings.filterwarnings('ignore')
def _outputs_to_numpy(outputs):
    """Convert raw network outputs into the numpy dict consumed by ``predict``.

    Args:
        outputs: Indexable sequence of tensors returned by the network.
            Exactly two outputs are read as ``(box_preds, cls_preds)``.
            Otherwise the first three are read as
            ``(box_preds, cls_preds, dir_cls_preds)`` and any further
            outputs are ignored.

    Returns:
        dict: Mapping from prediction name to ``outputs[k].asnumpy()``.

    Raises:
        IndexError: If fewer outputs are present than the chosen key set needs.
    """
    if len(outputs) == 2:
        keys = ('box_preds', 'cls_preds')
    else:
        keys = ('box_preds', 'cls_preds', 'dir_cls_preds')
    # Index explicitly rather than zip(), so a too-short output tuple still
    # raises IndexError instead of silently producing a partial dict.
    return {key: outputs[idx].asnumpy() for idx, key in enumerate(keys)}


def run_evaluate(args):
    """Evaluate a checkpoint on the configured eval split and print the report.

    Args:
        args: Parsed command-line namespace. Reads ``cfg_path``, ``ckpt_path``
            and ``device_target``.

    The device id comes from the ``DEVICE_ID`` environment variable (default
    ``0``), and the graph runs in ``GRAPH_MODE``. For each batch, detections
    are converted with ``predict_kitti_to_anno`` and accumulated. The result
    of ``get_official_eval_result`` over all ground-truth and detected
    annotations is printed to stdout.
    """
    cfg = get_config(args.cfg_path)
    device_id = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=device_id)

    model_cfg = cfg['model']
    center_limit_range = model_cfg['post_center_limit_range']
    # The second argument is presumably an ``is_training`` flag; confirm in src.utils.
    pointpillarsnet, eval_dataset, box_coder = get_model_dataset(cfg, False)
    params = load_checkpoint(args.ckpt_path)
    load_param_into_net(pointpillarsnet, get_params_for_net(params))

    eval_input_cfg = cfg['eval_input_reader']
    ds = de.GeneratorDataset(
        eval_dataset,
        column_names=eval_dataset.data_keys,
        python_multiprocessing=True,
        num_parallel_workers=6,
        max_rowsize=100,
        # Keep dataset order. This is presumably required so that dt_annos
        # stays aligned with gt_annos (built from kitti_infos in order).
        shuffle=False,
    )
    batch_size = eval_input_cfg['batch_size']
    # Keep the final partial batch so that every sample is evaluated.
    ds = ds.batch(batch_size, drop_remainder=False)
    data_loader = ds.create_dict_iterator(num_epochs=1, output_numpy=True)

    class_names = list(eval_input_cfg['class_names'])
    gt_annos = [info["annos"] for info in eval_dataset.kitti_infos]
    dt_annos = []
    log_freq = 100
    len_dataset = len(eval_dataset)
    start = time()
    for i, data in enumerate(data_loader):
        # If the batch has no 'bev_map' entry, pass a False placeholder in its place.
        bev_map = data.get('bev_map', False)
        outputs = pointpillarsnet(
            Tensor(data["voxels"]),
            Tensor(data["num_points"]),
            Tensor(data["coordinates"]),
            Tensor(bev_map),
        )
        preds = predict(data, _outputs_to_numpy(outputs), model_cfg, box_coder)
        dt_annos += predict_kitti_to_anno(preds, data, class_names, center_limit_range)
        if i % log_freq == 0 and i > 0:
            # After 0-based batch ``i``, ``i + 1`` batches are done. The last
            # batch may be partial, so cap the count at the dataset size.
            processed = min((i + 1) * batch_size, len_dataset)
            time_used = time() - start
            print(f'processed: {processed}/{len_dataset} imgs, time elapsed: {time_used} s',
                  flush=True)

    result = get_official_eval_result(gt_annos, dt_annos, class_names)
    print(result)
if __name__ == '__main__':
    # Command-line entry point: parse arguments, optionally stage data on
    # ModelArts, then run the evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg_path', required=True, help='Path to config file.')
    parser.add_argument('--ckpt_path', required=True, help='Path to checkpoint.')
    parser.add_argument('--device_target', default='GPU', help='device target')
    parser.add_argument('--is_modelarts', default='0', help='')
    parser.add_argument('--data_url', default='', help='')
    parser.add_argument('--train_url', default='', help='')
    parse_args = parser.parse_args()
    if parse_args.is_modelarts == '1':
        # ModelArts only: copy the data at --data_url (an OBS location, judging
        # by the variable name) into the local job directory before evaluating.
        import moxing as mox
        data_dir = '/home/work/user-job-dir/data'
        # makedirs(exist_ok=True) also creates missing parent directories and
        # avoids the race between a separate exists() check and mkdir().
        os.makedirs(data_dir, exist_ok=True)
        obs_data_url = parse_args.data_url
        mox.file.copy_parallel(obs_data_url, data_dir)
        print("Successfully Download {} to {}".format(obs_data_url, data_dir))
    run_evaluate(parse_args)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/mindspore/models.git
git@gitee.com:mindspore/models.git
mindspore
models
models
master

搜索帮助