1 Star 0 Fork 0

zeng/pytracking

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
run_training.py 1.77 KB
一键复制 编辑 原始数据 按行查看 历史
import os
import sys
import argparse
import importlib
import multiprocessing
import cv2 as cv
import torch.backends.cudnn
# Make the parent of this file's directory importable (presumably the repository
# root that contains the ``ltr`` package imported right after this) so the script
# can be run directly. The path is resolved to an absolute one so that it stays
# valid if the working directory changes later, and so that the membership check
# below is not fooled by different spellings of the same directory.
env_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if env_path not in sys.path:
    sys.path.append(env_path)
import ltr.admin.settings as ws_settings
def run_training(train_module, train_name, cudnn_benchmark=True):
    """Launch one training experiment from the ``ltr.train_settings`` package.

    The module ``ltr.train_settings.<train_module>.<train_name>`` is imported and
    its ``run`` function is called with a freshly built settings object.

    args:
        train_module: Name of the sub-package inside the "train_settings/" folder.
        train_name: Name of the settings module inside that sub-package.
        cudnn_benchmark: Value assigned to ``torch.backends.cudnn.benchmark``
            (default is True).
    """
    # Setting OpenCV's thread count to 0 is needed to avoid strange crashes
    # related to OpenCV (per the original authors).
    cv.setNumThreads(0)
    torch.backends.cudnn.benchmark = cudnn_benchmark

    print('Training: {} {}'.format(train_module, train_name))

    # Describe the experiment on the settings object handed to the train script.
    settings = ws_settings.Settings()
    settings.module_name = train_module
    settings.script_name = train_name
    settings.project_path = 'ltr/{}/{}'.format(train_module, train_name)

    settings_module_path = 'ltr.train_settings.{}.{}'.format(train_module, train_name)
    experiment = importlib.import_module(settings_module_path)
    experiment.run(settings)
def _str2bool(value):
    """Convert a command-line string such as '1'/'0' or 'true'/'false' to a bool.

    Used instead of ``type=bool``: ``bool(value)`` is True for every non-empty
    string (including '0' and 'False'), so the option could never be switched off.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if normalized in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (1/0, true/false, yes/no, on/off), got {!r}'.format(value))


def main():
    """Parse command-line arguments and start training via ``run_training``."""
    parser = argparse.ArgumentParser(description='Run a train scripts in train_settings.')
    parser.add_argument('train_module', type=str, help='Name of module in the "train_settings/" folder.')
    parser.add_argument('train_name', type=str, help='Name of the train settings file.')
    # The non-string default (True) is used as-is; argparse only passes
    # command-line strings through the ``type`` converter.
    parser.add_argument('--cudnn_benchmark', type=_str2bool, default=True, help='Set cudnn benchmark on (1) or off (0) (default is on).')
    args = parser.parse_args()
    run_training(args.train_module, args.train_name, args.cudnn_benchmark)
if __name__ == '__main__':
    # Make multiprocessing use the 'spawn' start method for the rest of the run
    # (force=True replaces any start method already set), then run the CLI.
    multiprocessing.set_start_method(method='spawn', force=True)
    main()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/zengzheming/pytracking.git
git@gitee.com:zengzheming/pytracking.git
zengzheming
pytracking
pytracking
master

搜索帮助