代码拉取完成,页面将自动刷新
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Created by Tianheng Cheng(tianhengcheng@gmail.com)
# ------------------------------------------------------------------------------
import os
import pprint
import argparse
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
import lib.models as models
from lib.config import config, update_config
from lib.utils import utils
from lib.datasets import get_dataset
from lib.core import function
def parse_args():
    """Parse the command line and fold the result into the global config.

    Two options are required:

    * ``--cfg``: experiment configuration filename.
    * ``--model-file``: file holding the model parameters.

    Returns:
        The ``argparse.Namespace`` produced by the parser. Before returning,
        it is also passed to ``update_config`` together with the
        module-level ``config`` object.
    """
    cli = argparse.ArgumentParser(description='Train Face Alignment')

    cli.add_argument(
        '--cfg',
        type=str,
        required=True,
        help='experiment configuration filename',
    )
    cli.add_argument(
        '--model-file',
        type=str,
        required=True,
        help='model parameters',
    )

    parsed = cli.parse_args()
    # Apply the parsed options to the shared config before handing them back.
    update_config(config, parsed)
    return parsed
def main():
    """Run inference with a saved face-alignment model and store the predictions.

    Steps, in order:

    1. Parse the command line (which also updates the global ``config``).
    2. Create the logger / output directory and log the arguments and config.
    3. Apply the cuDNN flags from ``config.CUDNN``.
    4. Build the network, wrap it in ``nn.DataParallel`` on ``config.GPUS``
       and load the weights from ``--model-file``.
    5. Build a ``DataLoader`` over the configured dataset with
       ``is_train=False`` and call ``function.inference``.
    6. Save the returned predictions to ``predictions.pth`` inside the
       output directory returned by ``utils.create_logger``.
    """
    args = parse_args()

    # The third return value (tensorboard log dir) is not needed here.
    logger, final_output_dir, _ = utils.create_logger(config, args.cfg, 'test')
    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    cudnn.benchmark = config.CUDNN.BENCHMARK
    # Fixed: this was spelled ``determinstic``, which merely set an unrelated
    # attribute on the module, so the configured value never took effect.
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    # Turn INIT_WEIGHTS off before building the network.
    # NOTE(review): presumably this skips pretrained-weight initialisation
    # because the weights are loaded from the checkpoint below -- confirm
    # against get_face_alignment_net.
    config.defrost()
    config.MODEL.INIT_WEIGHTS = False
    config.freeze()
    model = models.get_face_alignment_net(config)

    gpus = list(config.GPUS)
    model = nn.DataParallel(model, device_ids=gpus).cuda()

    # Load model parameters.
    # NOTE: torch.load unpickles the file; only point --model-file at
    # checkpoints from a trusted source.
    checkpoint = torch.load(args.model_file)
    if 'state_dict' in checkpoint:
        # Wrapped checkpoint: its 'state_dict' entry is loaded into the
        # DataParallel wrapper itself.
        model.load_state_dict(checkpoint['state_dict'])
    else:
        # Otherwise the loaded object is used directly as the state dict of
        # the wrapped module.
        model.module.load_state_dict(checkpoint)

    dataset_type = get_dataset(config)

    test_loader = DataLoader(
        dataset=dataset_type(config, is_train=False),
        # The configured batch size is per GPU; scale by the device count.
        batch_size=config.TEST.BATCH_SIZE_PER_GPU * len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    # The first return value (bound to ``nme`` originally) is not used here.
    _, predictions = function.inference(config, test_loader, model)

    torch.save(predictions, os.path.join(final_output_dir, 'predictions.pth'))
# Script entry point: run the test pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。