1 Star 0 Fork 0

xiong/LLMTuner

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
trainer.py 2.94 KB
一键复制 编辑 原始数据 按行查看 历史
zejunwang1 提交于 2年前 . initial commit
# coding=utf-8
import os
import torch
import transformers
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
class Trainer(transformers.Trainer):
# Avoid saving irrelevant files with large capacity
def _save_checkpoint(self, model, trial, metrics=None):
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
# want to save except FullyShardedDDP.
# assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
# Save model checkpoint
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
if self.hp_search_backend is None and trial is None:
self.store_flos()
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
self.save_model(output_dir, _internal_call=True)
# Determine the new best metric / best model checkpoint
if metrics is not None and self.args.metric_for_best_model is not None:
metric_to_check = self.args.metric_for_best_model
if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"
metric_value = metrics[metric_to_check]
operator = np.greater if self.args.greater_is_better else np.less
if (
self.state.best_metric is None
or self.state.best_model_checkpoint is None
or operator(metric_value, self.state.best_metric)
):
self.state.best_metric = metric_value
self.state.best_model_checkpoint = output_dir
# Save the Trainer state
if self.args.should_save:
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
if self.args.push_to_hub:
self._push_from_checkpoint(output_dir)
# Maybe delete some older checkpoints.
if self.args.should_save:
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
class LoraTrainer(transformers.Trainer):
# Only save LoRA weights
def _save(self, output_dir=None, state_dict=None):
# If we are executing this function, we are the process zero, so we don't check for that.
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
print(f"Saving model checkpoint to {output_dir}")
self.model.save_pretrained(
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/hacker_xwj/LLMTuner.git
git@gitee.com:hacker_xwj/LLMTuner.git
hacker_xwj
LLMTuner
LLMTuner
main

搜索帮助