Ai
1 Star 5 Fork 1

cubone/learnopencv

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
test.py 2.34 KB
一键复制 编辑 原始数据 按行查看 历史
Koriukina, Valeriia 提交于 2020-07-16 06:30 +08:00 . added face mask overlay post
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Created by Tianheng Cheng(tianhengcheng@gmail.com)
# ------------------------------------------------------------------------------
import os
import pprint
import argparse
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
import lib.models as models
from lib.config import config, update_config
from lib.utils import utils
from lib.datasets import get_dataset
from lib.core import function
def parse_args():
    """Parse command-line options and merge them into the global config.

    Both options are required:
      --cfg         experiment configuration filename
      --model-file  file holding the model parameters to evaluate

    Returns:
        The parsed ``argparse.Namespace``. As a side effect, the module-level
        ``config`` is updated through ``update_config(config, args)``.
    """
    # This is the evaluation entry point, so the help text says "Test"
    # (the description previously read "Train", copied from the train script).
    parser = argparse.ArgumentParser(description='Test Face Alignment')

    parser.add_argument('--cfg', help='experiment configuration filename',
                        required=True, type=str)
    parser.add_argument('--model-file', help='model parameters',
                        required=True, type=str)

    args = parser.parse_args()
    update_config(config, args)
    return args
def main():
    """Run inference with a saved checkpoint and store the predictions.

    Steps:
      1. Parse CLI arguments (which also updates the global ``config``).
      2. Create a logger for the 'test' phase and log arguments and config.
      3. Apply the cuDNN flags from ``config.CUDNN``.
      4. Build the network with ``MODEL.INIT_WEIGHTS`` forced to False, wrap
         it in ``nn.DataParallel`` over ``config.GPUS`` and move it to CUDA.
      5. Load the weights from ``--model-file``.
      6. Run ``function.inference`` on the dataset built with
         ``is_train=False`` and save the returned predictions as
         ``predictions.pth`` in the logger's output directory.
    """
    args = parse_args()

    # The third return value (tensorboard log dir) is not needed here.
    logger, final_output_dir, _ = \
        utils.create_logger(config, args.cfg, 'test')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    cudnn.benchmark = config.CUDNN.BENCHMARK
    # Fixed attribute name: it was misspelled "determinstic", which only
    # created an unused attribute and left the real flag untouched.
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    # Weights come from the checkpoint below, so disable weight
    # initialisation in the (otherwise frozen) config before building.
    config.defrost()
    config.MODEL.INIT_WEIGHTS = False
    config.freeze()
    model = models.get_face_alignment_net(config)

    gpus = list(config.GPUS)
    model = nn.DataParallel(model, device_ids=gpus).cuda()

    # load model
    # NOTE(security): torch.load unpickles the file; only pass checkpoints
    # from a trusted source as --model-file.
    state_dict = torch.load(args.model_file)
    if 'state_dict' in state_dict:
        # Checkpoint nests the weights under 'state_dict'; they are loaded
        # into the DataParallel wrapper (presumably keys carry the
        # "module." prefix -- confirm against the training script).
        state_dict = state_dict['state_dict']
        model.load_state_dict(state_dict)
    else:
        # Bare state dict: load into the wrapped module directly.
        model.module.load_state_dict(state_dict)

    dataset_type = get_dataset(config)

    test_loader = DataLoader(
        dataset=dataset_type(config, is_train=False),
        batch_size=config.TEST.BATCH_SIZE_PER_GPU * len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    # The first return value (nme) is not used by this script.
    _, predictions = function.inference(config, test_loader, model)

    torch.save(predictions, os.path.join(final_output_dir, 'predictions.pth'))
# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/cubone/learnopencv.git
git@gitee.com:cubone/learnopencv.git
cubone
learnopencv
learnopencv
master

搜索帮助