From c05b1f383c3caabf55f5253ba4b9b0658e172e4c Mon Sep 17 00:00:00 2001 From: Ryan Date: Fri, 8 Apr 2022 10:10:49 +0800 Subject: [PATCH] =?UTF-8?q?VoxelPose=E6=A8=A1=E5=9E=8B1P=E6=80=A7=E8=83=BD?= =?UTF-8?q?=E8=84=9A=E6=9C=AC=E5=8E=BB=E6=8E=89Eval=E9=83=A8=E5=88=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit VoxelPose模型1P性能脚本去掉Eval部分 --- .../pose_estimation/VoxelPose/run/train_3d.py | 31 ++++++++++--------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/PyTorch/contrib/cv/pose_estimation/VoxelPose/run/train_3d.py b/PyTorch/contrib/cv/pose_estimation/VoxelPose/run/train_3d.py index 44ef94dd71..fbfd3a7755 100644 --- a/PyTorch/contrib/cv/pose_estimation/VoxelPose/run/train_3d.py +++ b/PyTorch/contrib/cv/pose_estimation/VoxelPose/run/train_3d.py @@ -251,22 +251,23 @@ def main(): # lr_scheduler.step() train_3d(config, model, optimizer, train_loader, epoch, final_output_dir, writer_dict, len(gpus), device=device, is_master_node=args.is_master_node, use_apex=args.apex) - precision = validate_3d(config, model, test_loader, final_output_dir, device=device, is_master_node=args.is_master_node) + if args.distributed: + precision = validate_3d(config, model, test_loader, final_output_dir, device=device, is_master_node=args.is_master_node) - if precision > best_precision: - best_precision = precision - best_model = True - else: - best_model = False - if args.is_master_node: - logger.info('=> saving checkpoint to {} (Best: {})'.format(final_output_dir, best_model)) - model_copy=copy.deepcopy(model).cpu() - save_checkpoint({ - 'epoch': epoch + 1, - 'state_dict': model_copy.module.state_dict(), - 'precision': best_precision, - 'optimizer': optimizer.state_dict(), - }, best_model, final_output_dir) + if precision > best_precision: + best_precision = precision + best_model = True + else: + best_model = False + if args.is_master_node: + logger.info('=> saving checkpoint to {} (Best: {})'.format(final_output_dir, best_model)) + model_copy=copy.deepcopy(model).cpu() + save_checkpoint({ + 'epoch': epoch + 1, + 'state_dict': model_copy.module.state_dict(), + 'precision': best_precision, + 'optimizer': optimizer.state_dict(), + }, best_model, final_output_dir) final_model_state_file = os.path.join(final_output_dir, 'final_state.pth.tar') -- Gitee