108 Star 885 Fork 1.5K

MindSpore/models

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
eval.py 2.37 KB
一键复制 编辑 原始数据 按行查看 历史
ghost 提交于 2022-06-01 21:23 +08:00 . add ViG
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""eval"""
from mindspore import Model
from mindspore import context
from mindspore import nn
from mindspore.common import set_seed
from src.args import args
from src.tools.cell import cast_amp
from src.tools.criterion import get_criterion, NetWithLoss
from src.tools.get_misc import get_dataset, set_device, get_model, pretrained, get_train_one_step
from src.tools.optimizer import get_optimizer
set_seed(args.seed)
def main():
mode = {
0: context.GRAPH_MODE,
1: context.PYNATIVE_MODE
}
context.set_context(mode=mode[args.graph_mode], device_target=args.device_target)
context.set_context(enable_graph_kernel=False)
set_device(args)
# get model
net = get_model(args)
cast_amp(net)
criterion = get_criterion(args)
net_with_loss = NetWithLoss(net, criterion)
if args.pretrained:
pretrained(args, net)
data = get_dataset(args, training=False)
batch_num = data.val_dataset.get_dataset_size()
optimizer = get_optimizer(args, net, batch_num)
# save a yaml file to read to record parameters
net_with_loss = get_train_one_step(args, net_with_loss, optimizer)
eval_network = nn.WithEvalCell(net, criterion, args.amp_level in ["O2", "O3", "auto"])
eval_indexes = [0, 1, 2]
eval_metrics = {'Loss': nn.Loss(),
'Top1-Acc': nn.Top1CategoricalAccuracy(),
'Top5-Acc': nn.Top5CategoricalAccuracy()}
model = Model(net_with_loss, metrics=eval_metrics,
eval_network=eval_network,
eval_indexes=eval_indexes)
print(f"=> begin eval")
results = model.eval(data.val_dataset)
print(f"=> eval results:{results}")
print(f"=> eval success")
if __name__ == '__main__':
main()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/mindspore/models.git
git@gitee.com:mindspore/models.git
mindspore
models
models
master

搜索帮助