1 Star 1 Fork 0

NLP小学生 / bert-classification-train-tf2

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
贡献代码
同步代码
取消
提示: 由于 Git 不支持空文件夾,创建文件夹后会生成空的 .keep 文件
Loading...
README

基于TensorFlow2的BERT文本分类

本项目跟bert-classification-inference-trt搭配使用。

特性

  • 环境简单。依赖少。
  • 使用简单。使用者只需要准备自己的训练数据,然后运行python train.py即可训练模型,训练结束后会产生SavedModel。
  • 参数清晰。只保留了最少的必要超参数。
  • 支持单机多卡训练。单卡和多卡使用的代码一样。

环境准备

以下环境只代表能跑通的版本,其他版本也可能跑通。换句话说,这是跑通的充分条件,不是必要条件。

显卡

2张 RTX 3090。

Python和CUDA

环境 版本
Python 3.7
CUDA 驱动 11.4
CUDA API 11.1
cuDNN 8.0.4

依赖

pip install -r requirements.txt

数据

数据是一个文本文件,一行一个样例,形式是标签\t文本,如(示例数据来自THUCNews):

科技	读图选机 佳能中端产品HF M40全面评测 。读图选机 佳能中端产品HF M40全面评测作者: 陈道道日期: 2011-07-08文章属性:资讯。 【IT168 资讯】佳能HF M40读图选机。
社会	2名盗贼踩点尾随官员入室盗窃。 本报讯(记者王鹏昊)男子冯某和赵某,开锁学校毕业后,在昌平某镇政府蹲点跟踪尾随领导,摸清住址后伺机入室盗窃。昨天,昌平检察院透露,冯赵二人因涉嫌盗窃已被批捕。。 检方指控,男子冯某在唐山一家开锁学校培训期间与同学赵某相识。今年五月,来京打工的冯某找到赵某想用二人的开锁技术发财。为了不走空,二人将目标锁定为公务员。。 按照计划,二人来到昌平某镇政府,在镇政府的公开栏中有多名主要领导的信息及照片,二人用相机拍照记录。
股票	三一重工1亿欧元建欧洲研发制造基地。。 迄今为止中国在欧洲最大的一笔投资项目。 人民网长沙2月1日电 1月29日,三一重工与德国北威州政府投资协议,三一重工将在德国北威州的科隆市投资1亿欧元建设研发中心及机械制造基地,这是迄今为止中国在欧洲最大的一笔投资项目。
科技	女生使用的音质好机 三星YP-Q2现569元。 作者:白瑞。 三星YP-Q2其续航能力可达50小时音频播放或者4小时视频播放。拥有黑、白两种颜色,外观设计一如既往的清新明快。4GB内存的售价仅569元。。 三星YP-Q2。 容量:4GB 报价:569元。 编辑点评:拥有黑、白两款色彩的三星YP-Q2搭载了2.4英寸320×240分辨率屏幕,机身三围尺寸为49×101×9.9mm。
股票	传美政府将等通用汽车公布财报后决定售股事宜。 新浪财经讯 北京时间4月19日上午消息,据外电报道,一位知情人士透露,美国财政部将等通用汽车公司公布第一季度财报之后再决定是否出售所持的该公司部分股份。。 该人士称,美国财政部从5月22日起就能出售剩余的5亿股通用汽车股票,相当于通用汽车的33%股份。
科技	世博网络志愿者“接力赛”开跑。 本报讯(记者戎明迈)今日,21世纪最大规模的网络志愿者活动“宝马-腾讯世博网络志愿者接力”正式进入传递阶段。首批20100位志愿者从全国近4亿网友中脱颖而出,以腾讯世博网络志愿者“第一棒”的身份,同时向各自的Q Q好友发出活动邀请,号召更多的网友参与到传播世博信息、传递世博精神的行列中来。

注意事项:

  • 数据包括训练集、验证集和测试集。训练集用于训练模型,验证集用于查看是否过拟合和选择模型,测试集用于评估。验证集和测试集要符合真实数据分布,训练集不必。
  • 训练集要事先shuffle。

训练

  1. 执行nvidia-smi,查看空闲的GPU。
  2. 执行export CUDA_VISIBLE_DEVICES=1,如果只有GPU 2空闲,则=后面填2。如果想用两块卡,=后面填2,3(或1,2)。支持多GPU分布式训练。
  3. 开始训练,执行
    python train.py -t ../data/THUCNews.shuf5w.txt.train.txt -v ../data/THUCNews.shuf5w.txt.valid.txt -p ../bert_zh_L-12_H-768_A-12_2/ -e 3
    可忽略WARNING和UserWarning。更多超参数查看train.py
  4. 等待训练完成,saved_model目录就是最新训练好的SavedModel。

训练结束后,如果出现以下错误,可忽略:

Exception ignored in: <function Pool.__del__ at 0x7fa4e283a550>
Traceback (most recent call last):
  File "/root/miniconda3/lib/python3.8/multiprocessing/pool.py", line 268, in __del__
    self._change_notifier.put(None)
  File "/root/miniconda3/lib/python3.8/multiprocessing/queues.py", line 368, in put
    self._writer.send_bytes(obj)
  File "/root/miniconda3/lib/python3.8/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/root/miniconda3/lib/python3.8/multiprocessing/connection.py", line 411, in _send_bytes
    self._send(header + buf)
  File "/root/miniconda3/lib/python3.8/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
OSError: [Errno 9] Bad file descriptor

原理和细节

超参数详细解释

  • 一个step对应一个batch的参数更新。
  • steps_per_epoch表示多少步为一轮(跟过去理解的整个训练集过一遍为一轮不一样),是超参数。
  • epoch表示轮数,是超参数。每轮结束后会打印一次训练集准确率和验证集准确率,如果想要更频繁地打印这两个指标,则应该减少steps_per_epoch,增加epoch
  • num_train_steps是训练步数,等于steps_per_epoch x epoch,达到训练步数时,训练停止。
  • num_train_steps x batch_size应该尽量是训练集大小的整数倍(不是也没关系)。

如何选取batch_size

一般来说,越大越好,直到OOM。不过,如果训练集不大,batch_size也不宜太大, 40000条样例时,64和128的效果没有明显区别。

如何识别过拟合

训练时,每轮会打印一次训练集的准确率accuracy和验证集的准确率val_accuracy。如果accuracy增加而val_accuracy下降,则发生了过拟合。 ModelCheckpoint回调会在val_accuracy最大时,存储SavedModel。

有一个现象,在训练最初的几轮,accuracy明显小于val_accuracy。有两个原因:

  • accuracy的计算是训练时每个step累加的,而val_accuracy是这轮结束后才计算的。
  • 训练时开启了Dropout。

所以,这是正常现象。

pooling方法选哪个?

目前提供了两周pooling选项,cls和avg。如果是cls,则文本向量是cls位的向量。如果是avg,则文本向量是其他位置的向量的平均。 这两者区别不大,使用者可以自己做做实验。在THUCNews上,cls略胜一筹。

注意事项

  • 该项目跟部署项目搭配使用,所以代码中会有一些多余的跟训练无关的命令,比如存储id2label.json

bi_cls分支

bi_cls分支为二分类做了一些定制优化。

附录

附录1 下载BERT预训练模型

wget storage.googleapis.com/tfhub-modules/tensorflow/bert_zh_L-12_H-768_A-12/2.tar.gz
tar xzf 2.tar.gz
mkdir bert_zh_L-12_H-768_A-12_2
mv assets bert_zh_L-12_H-768_A-12_2
mv saved_model.pb bert_zh_L-12_H-768_A-12_2
mv variables bert_zh_L-12_H-768_A-12_2

空文件

简介

取消

发行版

暂无发行版

贡献者

全部

近期动态

加载更多
不能加载更多了
1
https://gitee.com/nlp_pupil/bert-classification-train-tf2.git
git@gitee.com:nlp_pupil/bert-classification-train-tf2.git
nlp_pupil
bert-classification-train-tf2
bert-classification-train-tf2
master

搜索帮助