msModelSlim工具支持API方式的蒸馏调优。蒸馏调优时,用户只需要提供teacher模型、student模型和数据集,调用API接口完成模型的蒸馏调优过程。
目前支持MindSpore和PyTorch框架下Transformer类模型的蒸馏调优,执行前需参考环境准备完成开发环境部署、Python环境变量、所需框架及训练服务器环境变量配置。
模型蒸馏期间,用户可将原始Transformer模型、配置较小参数的Transformer模型分别作为teacher和student进行知识蒸馏。通过手动配置参数,返回一个待蒸馏的DistillDualModels模型实例,用户对其进行训练。训练完毕后,从DistillDualModels模型实例获取训练后的student模型,即通过蒸馏训练后的模型。
以下步骤以PyTorch框架的模型为例,MindSpore框架的模型仅在调用部分接口时,入参配置有所差异,使用时请参照具体的API接口说明。
用户自行准备原始Transformer模型、配置较小参数的Transformer模型,分别作为模型蒸馏调优的teacher模型和student模型。本样例以Bert为例,在ModelZoo搜索下载Bert代码和原模型权重文件。
新建待蒸馏模型的Python脚本,例如distill_model.py。编辑distill_model.py文件,导入如下接口。蒸馏API接口说明请参考蒸馏接口。
from msmodelslim.common.knowledge_distill.knowledge_distill import KnowledgeDistillConfig, get_distill_model
from msmodelslim import set_logger_level
set_logger_level("info") #根据实际情况配置
distill_config = KnowledgeDistillConfig()
distill_config.add_output_soft_label({
"t_output_idx": 1,
"s_output_idx": 1,
"loss_func": [{"func_name": "KDCrossEntropy",
"func_weight": 1,
"temperature": 1}]})
distill_model = get_distill_model(teacher_model, student_model, distill_config) #请传入teacher模型、student模型的实例
student_model = distill_model.get_student_model()
python3 distill_model.py
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。