Ascend/MindSpeed-LLM
convert_ckpt.py 5.53 KB
轩辕敏峥 committed on 2025-09-16 10:37 +08:00: !3266 [pytorch][sh] seed-oss pretrain/sft
# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
import argparse
import importlib
import os
import sys
from functools import wraps
import logging as logger
import torch.multiprocessing as mp
from mindspeed_llm import megatron_adaptor
import pretrain_gpt
from mindspeed_llm.tasks.posttrain.orm.orm_trainer import ORMTrainer
MODULE_ROOT = "mindspeed_llm.tasks.checkpoint"

def load_plugin(plugin_type, name):
    if name == '':
        module_name = f"{MODULE_ROOT}.{plugin_type}"
    else:
        module_name = f"{MODULE_ROOT}.{plugin_type}_{name}"
    try:
        plugin = importlib.import_module(module_name)
    except ModuleNotFoundError:
        module_name = f"{MODULE_ROOT}.{name}"
        try:
            plugin = importlib.import_module(module_name)
        except ModuleNotFoundError:
            sys.exit(f"Unable to load {plugin_type} plugin {name}. Exiting.")

    if not hasattr(plugin, 'add_arguments'):
        sys.exit(f"{module_name} module is not a plugin. Exiting.")

    logger.info(f"Loaded {module_name} as the {plugin_type}.")
    return plugin
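
# Illustration of the resolution order above (the plugin names here are examples
# of what main() passes in, not an exhaustive list):
#   load_plugin('loader', 'hf')   tries mindspeed_llm.tasks.checkpoint.loader_hf,
#                                 then falls back to mindspeed_llm.tasks.checkpoint.hf
#   load_plugin('saver', '')      tries mindspeed_llm.tasks.checkpoint.saver
# Whichever module is resolved must define add_arguments() to be accepted as a plugin.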

def main():
    parser = argparse.ArgumentParser(description="Megatron Checkpoint Utility Arguments",
                                     allow_abbrev=False, conflict_handler='resolve')

    parser.add_argument('--model-type', type=str, required=True,
                        choices=['GPT', 'BERT'],
                        help='Type of the model')
    parser.add_argument('--loader', type=str, default='megatron',
                        help='Module name to load checkpoint, should be on python path')
    parser.add_argument('--load-model-type', type=str, nargs='?',
                        default=None, const=None, choices=['hf', 'mg'],
                        help='Module name to load checkpoint, should be on python path')
    parser.add_argument('--saver', type=str, default='megatron',
                        help='Module name to save checkpoint, should be on python path')
    parser.add_argument('--load-dir', type=str, required=True,
                        help='Directory to load model checkpoint from')
    parser.add_argument('--save-dir', type=str, required=True,
                        help='Directory to save model checkpoint to')
    parser.add_argument('--max-queue-size', type=int, default=50,
                        help='Maximum number of tensors in the queue')
    parser.add_argument('--no-checking', action='store_false',
                        help='Do not perform checking on the name and ordering of weights',
                        dest='checking')
    parser.add_argument('--spec', type=str, default=None, nargs='*',
                        help='Specify the <module_location function_name> pair '
                             'that returns a spec to customize transformer layer, depending on the use case.')
    parser.add_argument('--model-type-hf', type=str, default="llama2",
                        choices=['baichuan', 'baichuan2', 'llama2', 'mixtral', 'chatglm3', 'gemma', 'gemma2', 'qwen3',
                                 'bloom', 'bloom_3b', 'qwen', 'internlm2', 'deepseek2', 'minicpm', 'minicpm3', 'minicpm-moe',
                                 'deepseek2-lite', 'qwen2-moe', 'qwen3-moe', 'phi3.5', 'phi3.5-moe', 'hunyuan', 'glm4', 'seed-oss'],
                        help='model type of huggingface')
    parser.add_argument('--ckpt-cfg-path', type=str, default="configs/checkpoint/model_cfg.json",
                        help="Path to the config directory. If not specified, the default path in the repository will be used.")
    parser.add_argument('--qlora-nf4', action='store_true',
                        help='use bitsandbytes nf4 to quantize model.')
    parser.add_argument('--orm', action="store_true", default=False,
                        help='Specify the ORM ckpt conversion, convert additional rm_head layer in ORM.')
    parser.add_argument('--save-lora-to-hf', action='store_true', default=False,
                        help='Enable only save lora-checkpoint to hf')
    parser.add_argument('--load-checkpoint-loosely', action='store_true', default=False,
                        help='Enable loading checkpoint not strictly.')
    parser.add_argument('--ckpt-format', default='torch',
                        choices=['torch', 'torch_dist', 'zarr'],
                        help='Checkpoint format to use.')

    known_args, _ = parser.parse_known_args()

    use_saver = known_args.load_model_type is None
    if use_saver:
        loader = load_plugin('loader', known_args.loader)
        saver = load_plugin('saver', known_args.saver)
    else:
        loader = load_plugin('loader', known_args.load_model_type)
        saver = load_plugin('saver', '')

    loader.add_arguments(parser)
    saver.add_arguments(parser)
    args = parser.parse_args()

    queue = mp.Queue(maxsize=args.max_queue_size)
    model_provider = ORMTrainer.model_provider if args.orm else pretrain_gpt.model_provider

    if args.orm and not args.use_mcore_models:
        raise AssertionError("Currently Outcome Reward Model only support Mcore models")

    logger.info("Starting saver...")
    saver_proc = mp.Process(target=saver.save_model_checkpoint, args=(model_provider, queue, args))
    saver_proc.start()

    logger.info("Starting loader...")
    loader.load_checkpoint(model_provider, queue, args)

    logger.info("Waiting for saver to complete...")
    saver_proc.join()

    if saver_proc.exitcode is not None and saver_proc.exitcode != 0:
        logger.error(f"saver process exited with error code {saver_proc.exitcode}")
        sys.exit(saver_proc.exitcode)

if __name__ == '__main__':
    main()
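
For reference, a minimal invocation sketch, assuming MindSpeed-LLM and its dependencies are installed. The flags shown are defined in this script; the directories, the model family, and any extra plugin-specific options that the loader/saver plugins register via add_arguments (for example parallelism sizes) are placeholders that depend on your setup.

    # Hypothetical wrapper around the script above: convert a HuggingFace llama2
    # checkpoint into a Megatron-style one. Paths below are placeholders.
    import subprocess

    subprocess.run([
        "python", "convert_ckpt.py",
        "--model-type", "GPT",                       # this script accepts GPT or BERT
        "--load-model-type", "hf",                   # resolves the loader_hf plugin
        "--model-type-hf", "llama2",                 # HuggingFace model family
        "--load-dir", "./model_from_hf/llama-2-7b",  # placeholder input directory
        "--save-dir", "./model_weights/llama-2-7b",  # placeholder output directory
    ], check=True)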