1 Star 0 Fork 0

王禹程/MindKD

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
eval_teacher.py 1.81 KB
一键复制 编辑 原始数据 按行查看 历史
王禹程 提交于 2022-03-08 06:48 . fix run_teacher
import argparse
import mindspore
from mindspore.communication import init
from mindspore import context
from dataset.cifar100 import create_dataset
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.model import Model
from model.vgg import vgg8, vgg13, vgg16
# Lookup table from the --model CLI choice to the VGG constructor used to
# build the network being evaluated.
model_dict = dict(
    vgg8=vgg8,
    vgg13=vgg13,
    vgg16=vgg16,
)
def parse_option(args=None):
    """Parse command-line options for evaluating a pretrained teacher network.

    Args:
        args: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for programmatic calls and tests).

    Returns:
        argparse.Namespace holding the CLI options (``batch_size``,
        ``pre_trained``, ``model``) plus fixed model hyper-parameters that
        are attached here and passed on to the model constructor.
    """
    # The first positional argument of ArgumentParser is ``prog``; pass the
    # text as ``description`` so the usage line keeps the script name.
    parser = argparse.ArgumentParser(description='argument for evaluation')
    parser.add_argument('--batch_size', type=int, default=64, help='batch_size')
    # Required: evaluation is meaningless without a checkpoint, and leaving it
    # unset would only fail later inside load_checkpoint with an obscure error.
    parser.add_argument('--pre_trained', type=str, required=True,
                        help='pre trained model path')
    parser.add_argument('--model', type=str, default='vgg8',
                        choices=['vgg8', 'vgg13', 'vgg16'],
                        help='teacher architecture to evaluate')
    options = parser.parse_args(args)
    # Fixed settings that are not exposed on the command line; they are
    # forwarded with the rest of the namespace to the model constructor.
    options.num_classes = 100
    options.batch_norm = True
    options.initialize_mode = "XavierUniform"
    options.padding = 0
    options.pad_mode = 'same'
    options.has_bias = True
    options.has_dropout = True
    return options
def eval_net(test_set, options):
    """Evaluate a pretrained teacher network on ``test_set`` and print the metrics.

    Args:
        test_set: Evaluation dataset handed to ``Model.eval``.
        options: Namespace from ``parse_option``. ``model`` selects the
            constructor in ``model_dict``, ``num_classes`` and the namespace
            itself are passed to that constructor, and ``pre_trained`` is the
            checkpoint path to load.

    Returns:
        The metrics dict returned by ``Model.eval`` (also printed).

    Raises:
        ValueError: if ``options.pre_trained`` is not set.
    """
    if not options.pre_trained:
        raise ValueError("options.pre_trained must point to a checkpoint file")
    net = model_dict[options.model](options.num_classes, options)
    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    model = Model(net, loss_fn=loss, metrics={'acc'})
    param_dict = load_checkpoint(options.pre_trained)
    # load_param_into_net reports network parameters it could not fill from
    # the checkpoint; those keep their initial values and would silently
    # invalidate the accuracy, so surface them instead of ignoring them.
    not_loaded = load_param_into_net(net, param_dict)
    if not_loaded:
        print("warning: parameters not loaded from checkpoint: ", not_loaded)
    net.set_train(False)
    res = model.eval(test_set)
    print("result: ", res)
    return res
def main():
    """Script entry point: load the CIFAR-100 test split and evaluate the teacher.

    Parses CLI options, fixes the global seed, configures MindSpore graph
    mode on GPU, builds the test dataset and runs ``eval_net`` on it.
    """
    opts = parse_option()
    mindspore.common.set_seed(233)  # fixed global seed
    context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
    # NOTE(review): this initialises NCCL collective communication although
    # nothing visible here uses distributed APIs; it may require launching
    # under mpirun — confirm it is actually needed for single-device eval.
    init("nccl")
    dataset_and_size = create_dataset('./data/cifar-100-python', train=False,
                                      batch_size=opts.batch_size)
    test_set, test_num = dataset_and_size
    print("test size: ", test_num)
    eval_net(test_set, opts)


if __name__ == '__main__':
    main()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/reku1997/mind-kd.git
git@gitee.com:reku1997/mind-kd.git
reku1997
mind-kd
MindKD
main

搜索帮助

0d507c66 1850385 C8b1a773 1850385