# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation script"""
import argparse
import os
import warnings
from time import time

from mindspore import context, Tensor
from mindspore import dataset as de
from mindspore import load_checkpoint
from mindspore import load_param_into_net

from src.core.eval_utils import get_official_eval_result
from src.predict import predict
from src.predict import predict_kitti_to_anno
from src.utils import get_config
from src.utils import get_model_dataset
from src.utils import get_params_for_net

warnings.filterwarnings('ignore')


def run_evaluate(args):
    """run evaluate"""
    cfg_path = args.cfg_path
    ckpt_path = args.ckpt_path

    cfg = get_config(cfg_path)

    device_id = int(os.getenv('DEVICE_ID', '0'))
    device_target = args.device_target
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=device_target,
                        device_id=device_id)

    model_cfg = cfg['model']
    center_limit_range = model_cfg['post_center_limit_range']

    pointpillarsnet, eval_dataset, box_coder = get_model_dataset(cfg, False)

    params = load_checkpoint(ckpt_path)
    new_params = get_params_for_net(params)
    load_param_into_net(pointpillarsnet, new_params)

    eval_input_cfg = cfg['eval_input_reader']
    eval_column_names = eval_dataset.data_keys

    ds = de.GeneratorDataset(
        eval_dataset,
        column_names=eval_column_names,
        python_multiprocessing=True,
        num_parallel_workers=6,
        max_rowsize=100,
        shuffle=False
    )
    batch_size = eval_input_cfg['batch_size']
    ds = ds.batch(batch_size, drop_remainder=False)
    data_loader = ds.create_dict_iterator(num_epochs=1, output_numpy=True)

    class_names = list(eval_input_cfg['class_names'])

    dt_annos = []
    gt_annos = [info["annos"] for info in eval_dataset.kitti_infos]

    log_freq = 100
    len_dataset = len(eval_dataset)
    start = time()
    for i, data in enumerate(data_loader):
        voxels = data["voxels"]
        num_points = data["num_points"]
        coors = data["coordinates"]
        bev_map = data.get('bev_map', False)

        preds = pointpillarsnet(Tensor(voxels), Tensor(num_points), Tensor(coors), Tensor(bev_map))

        if len(preds) == 2:
            preds = {
                'box_preds': preds[0].asnumpy(),
                'cls_preds': preds[1].asnumpy(),
            }
        else:
            preds = {
                'box_preds': preds[0].asnumpy(),
                'cls_preds': preds[1].asnumpy(),
                'dir_cls_preds': preds[2].asnumpy()
            }

        preds = predict(data, preds, model_cfg, box_coder)

        dt_annos += predict_kitti_to_anno(preds, data, class_names, center_limit_range)

        if i % log_freq == 0 and i > 0:
            time_used = time() - start
            print(f'processed: {i * batch_size}/{len_dataset} imgs, time elapsed: {time_used} s', flush=True)

    result = get_official_eval_result(
        gt_annos,
        dt_annos,
        class_names,
    )

    print(result)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg_path', required=True, help='Path to config file.')
    parser.add_argument('--ckpt_path', required=True, help='Path to checkpoint.')
    parser.add_argument('--device_target', default='GPU', help='Device target, e.g. GPU or Ascend.')
    parser.add_argument('--is_modelarts', default='0',
                        help='Run on ModelArts ("1") or a local machine ("0").')
    parser.add_argument('--data_url', default='',
                        help='OBS path to the dataset (used only when --is_modelarts is "1").')
    parser.add_argument('--train_url', default='',
                        help='OBS output path (not used by this script).')

    parse_args = parser.parse_args()

    if parse_args.is_modelarts == '1':
        import moxing as mox

        # Copy the dataset from OBS to the local container before evaluation.
        data_dir = '/home/work/user-job-dir/data'
        if not os.path.exists(data_dir):
            os.mkdir(data_dir)
        obs_data_url = parse_args.data_url
        mox.file.copy_parallel(obs_data_url, data_dir)
        print("Successfully downloaded {} to {}".format(obs_data_url, data_dir))

    run_evaluate(parse_args)