MindSpore/models
eval.py 5.34 KB
Shawny committed on 2024-06-13 17:51 +08:00 · update context API
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT evaluation script.
"""
import math
import argparse
import numpy as np
import mindspore
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.nn.transformer.loss import CrossEntropyLoss
from mindspore.nn.transformer.transformer import TransformerOpParallelConfig
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.inference import generate
from src.dataset import create_dataset
from src.gpt import GPT, EvalNet, GPTWithLoss
from src.utils import GPTConfig
mindspore.set_context(mode=0)  # mode=0 selects graph (static) execution mode


def ppl_score(probs, length, is_logsoftmax=True):
    """ calculate perplexity with prob or log_prob inputs """
    probs = probs[:length]
    if is_logsoftmax:
        # log-softmax inputs: perplexity is exp(-mean(log p))
        prob = np.sum(probs) / length
        ppl = 1.0 / np.power(np.e, prob)
    else:
        # raw probabilities: perplexity is the geometric mean of 1/p
        prob = 1.0
        for p in probs:
            prob *= (1.0 / p)
        ppl = np.power(prob, 1.0 / length)
    return ppl
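
# Sanity check for ppl_score: a uniform distribution over V candidate tokens
# should give a perplexity of exactly V on both code paths. The sketch below
# uses made-up inputs and is not called anywhere in this script; it is only
# illustrative.
def _ppl_score_sanity_check():
    log_probs = np.log(np.full(8, 0.25))  # uniform over 4 tokens, 8 positions
    assert np.isclose(ppl_score(log_probs, 8, is_logsoftmax=True), 4.0)
    assert np.isclose(ppl_score(np.full(8, 0.25), 8, is_logsoftmax=False), 4.0)
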
def get_ppl(model, dataset):
    """ calculate perplexity for input dataset """
    batch_losses = []
    num_sequences = 0
    for data in dataset:
        data = data[0].asnumpy()
        input_ids = data
        input_mask = (data != 0).astype(np.float32)
        # GPTWithLoss returns a scalar cross-entropy loss for the batch
        loss = model(Tensor(input_ids, mstype.int32), Tensor(input_mask, mstype.float32)).asnumpy()
        batch_losses.append(loss * len(data))
        num_sequences += len(data)
    # perplexity = exp(mean loss); the exponent is clamped to avoid overflow
    val_loss = sum(batch_losses) / num_sequences
    ppl = math.exp(min(20, val_loss))
    return ppl
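
# Illustrative sketch of the aggregation done by get_ppl, with made-up losses:
# each batch contributes loss * batch_size, and the final perplexity is the
# exponential of the weighted mean loss, clamped at exp(20). Not called
# anywhere in this script.
def _ppl_aggregation_sketch():
    batches = [(3.2, 1), (2.9, 1), (3.5, 1)]  # hypothetical (loss, batch_size)
    total = sum(loss * n for loss, n in batches)
    count = sum(n for _, n in batches)
    return math.exp(min(20, total / count))  # == exp(3.2) for these numbers
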
def get_acc(model, dataset):
    """ calculate accuracy for input dataset """
    total_num = 0
    acc_num = 0
    for data in dataset:
        data = data[0].asnumpy()
        input_mask = (data != 0).astype(np.int32)
        length = np.sum(input_mask, 1)
        label = np.zeros(length.shape)
        for i, idx in enumerate(length):
            # hide the last non-padding token; it becomes the label
            label[i] = data[i][idx - 1]
            input_mask[i][idx - 1] = 0
            data[i][idx - 1] = 0
        logits = model(Tensor(data, mstype.int32), Tensor(input_mask, mstype.float32)).asnumpy()
        logits = logits.reshape(len(length), -1)
        predicted_label = np.zeros(len(length))
        for i, idx in enumerate(length):
            # a causal LM emits its prediction for position idx-1 at idx-2
            predicted_label[i] = logits[i][idx - 2]
        total_num += len(label)
        acc_num += sum(label == predicted_label)
    acc = acc_num / total_num
    return acc
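
# Toy illustration of the index bookkeeping in get_acc, with made-up token
# ids: the last non-padding token (position idx-1) is hidden from the model
# and becomes the label, and the model's prediction for that slot is read
# from position idx-2. Illustrative only; not called anywhere in this script.
def _acc_indexing_sketch():
    data = np.array([[15, 27, 42, 0, 0]])  # hypothetical ids, 0 is padding
    idx = int(np.sum(data != 0, 1)[0])     # 3 real tokens
    label = data[0][idx - 1]               # 42 is the token to predict
    data[0][idx - 1] = 0                   # masked out before the forward pass
    # after the forward pass, predicted_label is read from logits[0][idx - 2]
    return label, idx
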
def run_eval():
    """ evaluate scripts """
    parser = argparse.ArgumentParser(description="GPT inferencing")
    parser.add_argument('--task_type', type=str, default="", help="Evaluation task.")
    parser.add_argument('--metrics', type=str, default="acc", choices=["ppl", "acc"], help="Evaluation metrics.")
    parser.add_argument('--ckpt_path', type=str, default="", help="path of checkpoint file.")
    parser.add_argument('--data_path', type=str, default="", help="path of MindRecord file.")
    args = parser.parse_args()

    task = args.task_type
    metrics = args.metrics
    ckpt_path = args.ckpt_path
    if task not in ["generate", "lambada", "wikitext"]:
        raise ValueError("{} is not supported now".format(task))
    if metrics not in ["acc", "ppl"]:
        raise ValueError("{} is not supported now".format(metrics))

    # GPT-2 medium-sized configuration: 24 layers, 1024 hidden units, 16 heads
    config = GPTConfig(batch_size=1,
                       seq_length=1024,
                       vocab_size=50257,
                       embedding_size=1024,
                       num_layers=24,
                       num_heads=16,
                       expand_ratio=4,
                       post_layernorm_residual=False,
                       dropout_rate=0.0,
                       compute_dtype=mstype.float16,
                       use_past=False)

    ckpt_dict = load_checkpoint(ckpt_path)
    gpt = GPT(config)
    if task == "generate":
        gpt_eval = EvalNet(gpt, generate=True)
    elif metrics == "acc":
        gpt_eval = EvalNet(gpt, generate=False)
    else:
        # perplexity needs the loss head, so wrap the backbone with GPTWithLoss
        parallel_config = TransformerOpParallelConfig()
        loss = CrossEntropyLoss(parallel_config.dp_mp_config)
        gpt_eval = GPTWithLoss(gpt, loss, eos_token=0)
    gpt_eval.set_train(False)
    load_param_into_net(gpt_eval, ckpt_dict)

    if task == "generate":
        start_sentence = [6170, 318, 257]
        input_ids = np.array(start_sentence).reshape(1, -1)
        outputs = generate(gpt_eval, input_ids, config.seq_length)
        output_list = outputs.tolist()
        print("output id is ", output_list)
    else:
        data_path = args.data_path
        eval_dataset = create_dataset(config.batch_size, data_path=data_path, drop=False)
        if metrics == "acc":
            acc = get_acc(gpt_eval, eval_dataset)
            print("Accuracy is ", acc)
        elif metrics == "ppl":
            ppl = get_ppl(gpt_eval, eval_dataset)
            print("Perplexity is ", ppl)


if __name__ == "__main__":
    run_eval()
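
# Typical invocations, following the argparse flags above. The checkpoint and
# MindRecord paths below are placeholders, not files shipped with this repo:
#
#   python eval.py --task_type lambada --metrics acc \
#       --ckpt_path /path/to/gpt.ckpt --data_path /path/to/lambada.mindrecord
#   python eval.py --task_type wikitext --metrics ppl \
#       --ckpt_path /path/to/gpt.ckpt --data_path /path/to/wikitext.mindrecord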