1 Star 0 Fork 174

笨笨翔/PaddleNLP

forked from PaddlePaddle/PaddleNLP 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
train.py 4.83 KB
一键复制 编辑 原始数据 按行查看 历史
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Optional

import paddle
from paddle.metric import Accuracy
from utils import load_local_dataset

from paddlenlp.prompt import (
    AutoTemplate,
    PromptModelForSequenceClassification,
    PromptTrainer,
    PromptTuningArguments,
    SoftVerbalizer,
)
from paddlenlp.trainer import EarlyStoppingCallback, PdArgumentParser
from paddlenlp.transformers import AutoModelForMaskedLM, AutoTokenizer
from paddlenlp.utils.log import logger
# yapf: disable
@dataclass
class DataArguments:
    """Command-line arguments describing the few-shot dataset and the prompt.

    Instances are produced by ``PdArgumentParser`` in ``main``; the ``help``
    metadata of each field is intended as the help text of its CLI flag.
    """

    # Directory from which ``main`` reads ``label.txt`` and the train/dev/test splits.
    data_dir: str = field(default="./data/", metadata={"help": "Path to a dataset which includes train.txt, dev.txt, test.txt, label.txt and data.txt (optional)."})
    # Prompt forwarded unchanged to ``AutoTemplate.create_from``. ``None`` is a
    # legitimate value (the default), so the annotation must be Optional.
    prompt: Optional[str] = field(default=None, metadata={"help": "The input prompt for tuning."})
@dataclass
class ModelArguments:
    """Command-line arguments selecting the pretrained model and the export format.

    Instances are produced by ``PdArgumentParser`` in ``main``; the ``help``
    metadata of each field is intended as the help text of its CLI flag.
    """

    # Passed to both ``AutoModelForMaskedLM.from_pretrained`` and
    # ``AutoTokenizer.from_pretrained`` in ``main``.
    model_name_or_path: str = field(default="ernie-3.0-base-zh", metadata={"help": "Built-in pretrained model name or the path to local model."})
    # Forwarded as ``export_type`` to ``trainer.export_model`` when ``do_export`` is set.
    export_type: str = field(default='paddle', metadata={"help": "The type to export. Support `paddle` and `onnx`."})
# yapf: enable
def _load_label_words(label_file):
    """Read a label file into a mapping from label name to its label words.

    Each non-blank line is either ``label==word`` or a bare ``label``. For a
    bare label, the word is the last ``##``-separated component of the label
    (e.g. ``a##b`` yields the word ``b``). A label that appears on several
    lines accumulates several words.

    Args:
        label_file (str): Path to the UTF-8 encoded ``label.txt``.

    Returns:
        collections.defaultdict: Maps each label string to a list of words.
    """
    label_words = defaultdict(list)
    with open(label_file, "r", encoding="utf-8") as fp:
        for line in fp:
            line = line.strip()
            # Blank lines would otherwise register an empty-string label whose
            # word is also the empty string.
            if not line:
                continue
            data = line.split("==")
            word = data[1] if len(data) > 1 else data[0].split("##")[-1]
            label_words[data[0]].append(word)
    return label_words


def main():
    """Train, predict with and/or export a prompt model using a soft verbalizer.

    All configuration is parsed from the command line into ``ModelArguments``,
    ``DataArguments`` and ``PromptTuningArguments``. The ``do_train``,
    ``do_predict`` and ``do_export`` flags of the training arguments select
    which stages run.
    """
    # Parse the arguments.
    parser = PdArgumentParser((ModelArguments, DataArguments, PromptTuningArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    training_args.print_config(model_args, "Model")
    training_args.print_config(data_args, "Data")
    paddle.set_device(training_args.device)

    # Load the pretrained language model.
    model = AutoModelForMaskedLM.from_pretrained(model_args.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)

    # Define the template for preprocess and the verbalizer for postprocess.
    template = AutoTemplate.create_from(data_args.prompt, tokenizer, training_args.max_seq_length, model=model)
    logger.info("Using template: {}".format(template.prompt))

    label_words = _load_label_words(os.path.join(data_args.data_dir, "label.txt"))
    verbalizer = SoftVerbalizer(label_words, tokenizer, model)

    # Load the few-shot datasets; the verbalizer's label-to-id mapping is used
    # as the label list.
    train_ds, dev_ds, test_ds = load_local_dataset(
        data_path=data_args.data_dir, splits=["train", "dev", "test"], label_list=verbalizer.labels_to_ids
    )

    # Define the criterion.
    criterion = paddle.nn.CrossEntropyLoss()

    # Initialize the prompt model with the above variables.
    prompt_model = PromptModelForSequenceClassification(
        model, template, verbalizer, freeze_plm=training_args.freeze_plm, freeze_dropout=training_args.freeze_dropout
    )

    # Define the metric function.
    def compute_metrics(eval_preds):
        """Return ``{"accuracy": acc}`` for ``eval_preds.predictions`` vs ``eval_preds.label_ids``."""
        # A fresh Accuracy per call, so no state carries over between evaluations.
        metric = Accuracy()
        correct = metric.compute(paddle.to_tensor(eval_preds.predictions), paddle.to_tensor(eval_preds.label_ids))
        metric.update(correct)
        acc = metric.accumulate()
        return {"accuracy": acc}

    # Define the early-stopping callback.
    # NOTE(review): early_stopping_patience is hard-coded to 4; consider
    # exposing it as a CLI argument if it needs tuning.
    callbacks = [EarlyStoppingCallback(early_stopping_patience=4, early_stopping_threshold=0.0)]

    # Initialize the trainer.
    trainer = PromptTrainer(
        model=prompt_model,
        tokenizer=tokenizer,
        args=training_args,
        criterion=criterion,
        train_dataset=train_ds,
        eval_dataset=dev_ds,
        callbacks=callbacks,
        compute_metrics=compute_metrics,
    )

    # Training.
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=None)
        metrics = train_result.metrics
        trainer.save_model()
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Prediction.
    if training_args.do_predict:
        test_ret = trainer.predict(test_ds)
        trainer.log_metrics("test", test_ret.metrics)

    # Export static model.
    if training_args.do_export:
        export_path = os.path.join(training_args.output_dir, "export")
        trainer.export_model(export_path, export_type=model_args.export_type)
if __name__ == "__main__":
    # Script entry point: the pipeline runs only on direct execution, never on import.
    main()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/lipengxiang/PaddleNLP.git
git@gitee.com:lipengxiang/PaddleNLP.git
lipengxiang
PaddleNLP
PaddleNLP
develop

搜索帮助