应用领域(Application Domain):Natural Language Processing
版本(Version):1.1
修改时间(Modified) :2021.7.17
大小(Size):552K
框架(Framework):TensorFlow 1.15.0
模型格式(Model Format):ckpt
精度(Precision):Mixed
应用级别(Categories):Official
描述(Description):基于TensorFlow框架的BERT网络在CLUE数据集上的finetune代码
BERT是一种与训练语言表示的方法,这意味着我们在大型文本语料库(如维基百科)上训练一个通用的”语言理解“模型,然后将该模型用于我们关心的下游NLP任务(如问答)。该工程提供了在CLUE数据集上finetune的方法。
参考论文:
参考实现:
适配昇腾 AI 处理器的实现:
通过Git获取对应commit_id的代码方法如下:
git clone {repository_url} # 克隆仓库的代码
cd {repository_name} # 切换到模型的代码仓目录
git checkout {branch} # 切换到对应分支
git reset --hard {commit_id} # 代码设置到对应的commit_id
cd {code_path} # 切换到模型代码所在路径,若仓库下只有该模型,则无需切换
训练超参
特性列表 | 是否支持 |
---|---|
分布式训练 | 是 |
混合精度 | 是 |
并行数据 | 否 |
昇腾910 AI处理器提供自动混合精度功能,可以针对全网中float32数据类型的算子,按照内置的优化策略,自动将部分float32的算子降低精度到float16,从而在精度损失很小的情况下提升系统性能并减少内存使用。
脚本已默认开启混合精度,设置precision_mode参数的脚本参考如下。
run_config = NPURunConfig(
model_dir=FLAGS.output_dir,
save_checkpoints_steps=FLAGS.save_checkpoints_steps,
iterations_per_loop=FLAGS.iterations_per_loop,
session_config=config,
precision_mode="allow_mix_precision",
keep_checkpoint_max=5)
硬件环境准备请参见各硬件产品文档"驱动和固件安装升级指南"。需要在硬件设备上安装与CANN版本配套的固件与驱动。
宿主机上需要安装Docker并登录Ascend Hub中心获取镜像。
当前模型支持的镜像列表如表1所示。
表 1 镜像列表
|
单击“立即下载”,并选择合适的下载方式下载源码包。
启动训练之前,首先要配置程序运行相关环境变量。
环境变量配置信息参见:
将环境变量配置到scrpts/run_*.sh中
TNEWS训练
启动单卡训练
修改test/train_full_tnews_1p.sh中的data_path,里面包括提前下载的数据集和预训练模型。
cd test
bash train_full_tnews_1p.sh
MSRANER训练
启动单卡训练
修改test/train_full_msraner_8p.sh中的data_path,里面包括提前下载的数据集和预训练模型。
cd test
bash run_8p.sh
└─BertMRPC_for_TensorFlow
├─test
| ├─train_full_tnews_1p.sh
| ├─train_full_msraner_8p.sh
├─CONTRIBUTING.md
├─create_pretraining_data.py
├─evaluate-v1.1.py
├─extract_features.py
├─gpu_environment.py
├─LICENSE
├─modeling.py
├─modeling_test.py
├─multilingual.md
├─optimization.py
├─optimization_test.py
├─README.md
├─run.sh
├─run_classifier.py
├─run_classifier_with_tfhub.py
├─run_pretraining.py
├─run_squad.py
├─tokenization.py
└─tokenization_test.py
TNEWS:
python3 run_classifier.py \
--task_name=tnews \
--do_train=true \
--do_eval=true \
--data_dir=$data_path/tnews \
--vocab_file=$data_path/chinese_L-12_H-768_A-12/vocab.txt \
--bert_config_file=$data_path/chinese_L-12_H-768_A-12/bert_config.json \
--init_checkpoint=$data_path/chinese_L-12_H-768_A-12/bert_model.ckpt \
--max_seq_length=128 \
--train_batch_size=${batch_size} \
--learning_rate=2e-5 \
--num_train_epochs=3.0 \
--output_dir=tnews_output \
MSRANER:
python3 run_ner.py \
--task_name=msraner \
--do_train=true \
--do_predict=true \
--data_path=$data_path/msraner \
--vocab_file=$data_path/chinese_L-12_H-768_A-12/vocab.txt \
--bert_config_file=$data_path/chinese_L-12_H-768_A-12/bert_config.json \
--init_checkpoint=$data_path/chinese_L-12_H-768_A-12/bert_model.ckpt \
--max_seq_length=128 \
--train_batch_size=${batch_size} \
--learning_rate=2e-5 \
--num_train_epochs=3.0 \
--output_dir=msraner_output \
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。