108 Star 871 Fork 1.5K

MindSpore/models

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
train.py 4.82 KB
一键复制 编辑 原始数据 按行查看 历史
yide12 提交于 2024-08-05 14:57 +08:00 . change_usage_set_seed
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train Retinaface_resnet50."""
from __future__ import print_function
import math
import argparse
import mindspore
from mindspore import ParallelMode
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.communication.management import init, get_rank, get_group_size
from src.config import cfg_res50
from src.network import RetinaFace, RetinaFaceWithLossCell, TrainingWrapper, resnet50
from src.loss import MultiBoxLoss
from src.dataset import create_dataset
from src.lr_schedule import adjust_learning_rate
def _init_context(is_distributed):
    """Configure the MindSpore runtime and, optionally, data-parallel communication.

    Args:
        is_distributed (bool): When true, initialise NCCL and enable data-parallel
            training across every rank of the communication group, averaging gradients.
    """
    # mode=0 is MindSpore's graph mode.
    mindspore.set_context(mode=0, device_target='GPU', save_graphs=False)
    if mindspore.get_context("device_target") == "GPU":
        # Enable graph kernel fusion (only requested for the GPU target).
        mindspore.set_context(enable_graph_kernel=True, graph_kernel_flags="--enable_parallel_fusion")
    if is_distributed:
        init("nccl")
        mindspore.set_auto_parallel_context(device_num=get_group_size(),
                                            parallel_mode=ParallelMode.DATA_PARALLEL,
                                            gradients_mean=True)


def _build_retinaface(cfg):
    """Build the RetinaFace network on a ResNet-50 backbone and load any configured weights.

    Backbone weights from cfg['pretrain_path'] are loaded only when cfg['pretrain'] is truthy
    and no resume checkpoint is configured. When cfg['resume_net'] is set, that checkpoint is
    loaded into the whole RetinaFace network instead.

    Args:
        cfg (dict): Training configuration; reads 'pretrain', 'pretrain_path' and 'resume_net'.

    Returns:
        RetinaFace: The network in training mode (without the loss wrapper).
    """
    # NOTE(review): 1001 is presumably the classifier width expected by the pretrained
    # ResNet-50 checkpoint — confirm against src.network.resnet50.
    backbone = resnet50(1001)
    backbone.set_train(True)
    if cfg['pretrain'] and cfg['resume_net'] is None:
        pretrained_res50 = cfg['pretrain_path']
        param_dict_res50 = mindspore.load_checkpoint(pretrained_res50)
        mindspore.load_param_into_net(backbone, param_dict_res50)
        print('Load resnet50 from [{}] done.'.format(pretrained_res50))

    net = RetinaFace(phase='train', backbone=backbone)
    net.set_train(True)
    if cfg['resume_net'] is not None:
        pretrain_model_path = cfg['resume_net']
        param_dict_retinaface = mindspore.load_checkpoint(pretrain_model_path)
        mindspore.load_param_into_net(net, param_dict_retinaface)
        print('Resume Model from [{}] Done.'.format(cfg['resume_net']))
    return net


def _build_optimizer(cfg, params, lr):
    """Create the optimizer selected by cfg['optim'].

    Both supported optimizers use cfg['momentum'] and cfg['weight_decay'].

    Args:
        cfg (dict): Training configuration; reads 'optim', 'momentum' and 'weight_decay'.
        params (list): Trainable parameters to optimize.
        lr: Learning-rate value or per-step schedule, as produced by adjust_learning_rate.

    Returns:
        The MindSpore optimizer instance.

    Raises:
        ValueError: If cfg['optim'] is neither 'momentum' nor 'sgd'.
    """
    optim = cfg['optim']
    momentum = cfg['momentum']
    weight_decay = cfg['weight_decay']
    if optim == 'momentum':
        # weight_decay is applied here too, for consistency with the 'sgd' branch;
        # previously it was silently ignored for this optimizer.
        return mindspore.nn.Momentum(params, lr, momentum, weight_decay=weight_decay)
    if optim == 'sgd':
        return mindspore.nn.SGD(params=params, learning_rate=lr, momentum=momentum,
                                weight_decay=weight_decay, loss_scale=1)
    raise ValueError("optim '{}' is not defined; expected 'momentum' or 'sgd'.".format(optim))


def train(cfg, args):
    """Train RetinaFace (ResNet-50 backbone) with the settings in ``cfg``.

    Args:
        cfg (dict): Training configuration (e.g. ``cfg_res50``). Keys read here:
            'batch_size', 'epoch', 'initial_lr', 'gamma', 'training_dataset', 'decay1',
            'decay2', 'num_workers', 'num_anchor', 'pretrain', 'pretrain_path', 'resume_net',
            'optim', 'momentum', 'weight_decay', 'warmup_epoch', 'save_checkpoint_steps',
            'keep_checkpoint_max' and 'ckpt_path'. This function does not write to ``cfg``.
        args (argparse.Namespace): Parsed CLI arguments; only ``args.is_distributed`` is used.

    Raises:
        ValueError: If ``cfg['optim']`` is neither 'momentum' nor 'sgd'.
    """
    _init_context(args.is_distributed)

    ckpt_path = cfg['ckpt_path']
    if args.is_distributed:
        # Per-rank checkpoint directory. Built locally instead of being written back
        # into cfg, so the shared config dict is not mutated (a repeated call would
        # otherwise append the suffix twice).
        ckpt_path = ckpt_path + "ckpt_" + str(get_rank()) + "/"

    batch_size = cfg['batch_size']
    max_epoch = cfg['epoch']
    ds_train = create_dataset(cfg['training_dataset'], cfg, batch_size, multiprocessing=True,
                              num_worker=cfg['num_workers'], is_distribute=args.is_distributed)
    dataset_size = ds_train.get_dataset_size()
    print('dataset size is : \n', dataset_size)
    steps_per_epoch = math.ceil(dataset_size)

    # NOTE(review): presumably 2 classes = background/face and 7 = negative:positive
    # ratio for hard-negative mining inside MultiBoxLoss — confirm in src.loss.
    num_classes = 2
    negative_ratio = 7
    multibox_loss = MultiBoxLoss(num_classes, cfg['num_anchor'], negative_ratio, batch_size)

    net = RetinaFaceWithLossCell(_build_retinaface(cfg), multibox_loss, cfg)

    stepvalues = (cfg['decay1'], cfg['decay2'])
    lr = adjust_learning_rate(cfg['initial_lr'], cfg['gamma'], stepvalues, steps_per_epoch, max_epoch,
                              warmup_epoch=cfg['warmup_epoch'])
    opt = _build_optimizer(cfg, net.trainable_params(), lr)

    net = TrainingWrapper(net, opt)
    model = Model(net)

    config_ck = CheckpointConfig(save_checkpoint_steps=cfg['save_checkpoint_steps'],
                                 keep_checkpoint_max=cfg['keep_checkpoint_max'])
    ckpoint_cb = ModelCheckpoint(prefix="RetinaFace", directory=ckpt_path, config=config_ck)
    time_cb = TimeMonitor(data_size=dataset_size)
    callback_list = [LossMonitor(), time_cb, ckpoint_cb]

    print("============== Starting Training ==============")
    model.train(max_epoch, ds_train, callbacks=callback_list,
                dataset_sink_mode=True)
if __name__ == '__main__':
    # Script entry point: parse CLI flags, seed the RNGs, echo the config, then train.
    cli_parser = argparse.ArgumentParser('MindSpore RetinaFace training')
    cli_parser.add_argument('--is_distributed', action='store_true', help='distributed training')
    # parse_known_args ignores unrecognised command-line arguments instead of erroring out.
    cli_args = cli_parser.parse_known_args()[0]

    train_cfg = cfg_res50
    mindspore.set_seed(train_cfg.get('seed'))
    print('train config:\n', train_cfg)
    train(cfg=train_cfg, args=cli_args)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/mindspore/models.git
git@gitee.com:mindspore/models.git
mindspore
models
models
r2.3

搜索帮助