代码拉取完成,页面将自动刷新
import argparse
import mindspore
from mindspore.communication import init
from mindspore import context
from dataset.cifar100 import create_dataset
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.model import Model
from model.vgg import vgg8, vgg13, vgg16
# Lookup table from the ``--model`` command-line choice to the network
# constructor; eval_net calls the selected entry as ctor(num_classes, options).
model_dict = dict(
    vgg8=vgg8,
    vgg13=vgg13,
    vgg16=vgg16,
)
def parse_option(argv=None):
    """Parse command-line options for evaluating a pretrained VGG model.

    Besides the CLI flags, a set of fixed model-construction settings is
    attached to the returned namespace, which is later passed to the network
    constructor.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``batch_size``, ``pre_trained`` and ``model``
        from the command line, plus the fixed fields ``num_classes``,
        ``batch_norm``, ``initialize_mode``, ``padding``, ``pad_mode``,
        ``has_bias`` and ``has_dropout``.

    Raises:
        SystemExit: if the arguments are invalid, e.g. ``--pre_trained`` is
            missing or ``--model`` is not a supported name.
    """
    # The first positional argument of ArgumentParser is ``prog`` (the program
    # name shown in usage), not a description, so pass the text by keyword.
    parser = argparse.ArgumentParser(description='argument for evaluation')
    parser.add_argument('--batch_size', type=int, default=64, help='batch_size')
    # Required: the checkpoint path is always loaded during evaluation, so a
    # missing value would otherwise only surface later as load_checkpoint(None).
    parser.add_argument('--pre_trained', type=str, required=True,
                        help='pre trained model path')
    # These choices must stay in sync with the keys of the module-level model_dict.
    parser.add_argument('--model', type=str, default='vgg8',
                        choices=['vgg8', 'vgg13', 'vgg16'],
                        help='network architecture to evaluate')
    options = parser.parse_args(argv)

    # Fixed settings that are not exposed on the command line; the whole
    # namespace is handed to the model constructor.
    options.num_classes = 100
    options.batch_norm = True
    options.initialize_mode = "XavierUniform"
    options.padding = 0
    options.pad_mode = 'same'
    options.has_bias = True
    options.has_dropout = True
    return options
def eval_net(test_set, options):
    """Build the selected network, load its checkpoint and evaluate it.

    Args:
        test_set: Dataset to evaluate on (the object produced by
            ``create_dataset`` in main).
        options: Namespace from ``parse_option``; ``model``, ``num_classes``
            and ``pre_trained`` are read here, and the whole namespace is
            passed through to the network constructor.

    Returns:
        The value returned by ``Model.eval`` (it is also printed). Previously
        it was only printed; returning it lets callers use the metrics.
    """
    net = model_dict[options.model](options.num_classes, options)
    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    model = Model(net, loss_fn=loss, metrics={'acc'})

    param_dict = load_checkpoint(options.pre_trained)
    load_result = load_param_into_net(net, param_dict)
    # NOTE(review): depending on the MindSpore version, load_param_into_net
    # returns None, a list of network params that were not loaded, or a
    # (param_not_load, ckpt_not_load) tuple -- confirm against the installed
    # version. Surface unloaded params so a mismatched checkpoint is not
    # silently evaluated.
    if isinstance(load_result, tuple):
        not_loaded = load_result[0] if load_result else None
    else:
        not_loaded = load_result
    if not_loaded:
        print("warning: parameters not loaded from checkpoint: ", not_loaded)

    net.set_train(False)
    res = model.eval(test_set)
    print("result: ", res)
    return res
def main():
    """Configure MindSpore, build the CIFAR-100 test split and evaluate on it."""
    opts = parse_option()

    # Fix the global random seed.
    mindspore.common.set_seed(233)
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    # NOTE(review): collective communication is initialised although nothing in
    # the visible evaluation path uses it; presumably this expects an
    # mpirun/NCCL launch -- confirm whether it is needed for single-device eval.
    init("nccl")

    cifar_root = './data/cifar-100-python'
    eval_data, eval_size = create_dataset(cifar_root, batch_size=opts.batch_size, train=False)
    print("test size: ", eval_size)
    eval_net(eval_data, opts)


if __name__ == '__main__':
    main()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。