代码拉取完成,页面将自动刷新
本项目跟bert-classification-inference-trt搭配使用。
python train.py
即可训练模型,训练结束后会产生SavedModel。以下环境只代表能跑通的版本,其他版本也可能跑通。换句话说,这是跑通的充分条件,不是必要条件。
2张 RTX 3090。
环境 | 版本 |
---|---|
Python | 3.7 |
CUDA 驱动 | 11.4 |
CUDA API | 11.1 |
cuDNN | 8.0.4 |
pip install -r requirements.txt
数据是一个文本文件,一行一个样例,形式是标签\t文本
,如(示例数据来自THUCNews):
科技 读图选机 佳能中端产品HF M40全面评测 。读图选机 佳能中端产品HF M40全面评测作者: 陈道道日期: 2011-07-08文章属性:资讯。 【IT168 资讯】佳能HF M40读图选机。
社会 2名盗贼踩点尾随官员入室盗窃。 本报讯(记者王鹏昊)男子冯某和赵某,开锁学校毕业后,在昌平某镇政府蹲点跟踪尾随领导,摸清住址后伺机入室盗窃。昨天,昌平检察院透露,冯赵二人因涉嫌盗窃已被批捕。。 检方指控,男子冯某在唐山一家开锁学校培训期间与同学赵某相识。今年五月,来京打工的冯某找到赵某想用二人的开锁技术发财。为了不走空,二人将目标锁定为公务员。。 按照计划,二人来到昌平某镇政府,在镇政府的公开栏中有多名主要领导的信息及照片,二人用相机拍照记录。
股票 三一重工1亿欧元建欧洲研发制造基地。。 迄今为止中国在欧洲最大的一笔投资项目。 人民网长沙2月1日电 1月29日,三一重工与德国北威州政府投资协议,三一重工将在德国北威州的科隆市投资1亿欧元建设研发中心及机械制造基地,这是迄今为止中国在欧洲最大的一笔投资项目。
科技 女生使用的音质好机 三星YP-Q2现569元。 作者:白瑞。 三星YP-Q2其续航能力可达50小时音频播放或者4小时视频播放。拥有黑、白两种颜色,外观设计一如既往的清新明快。4GB内存的售价仅569元。。 三星YP-Q2。 容量:4GB 报价:569元。 编辑点评:拥有黑、白两款色彩的三星YP-Q2搭载了2.4英寸320×240分辨率屏幕,机身三围尺寸为49×101×9.9mm。
股票 传美政府将等通用汽车公布财报后决定售股事宜。 新浪财经讯 北京时间4月19日上午消息,据外电报道,一位知情人士透露,美国财政部将等通用汽车公司公布第一季度财报之后再决定是否出售所持的该公司部分股份。。 该人士称,美国财政部从5月22日起就能出售剩余的5亿股通用汽车股票,相当于通用汽车的33%股份。
科技 世博网络志愿者“接力赛”开跑。 本报讯(记者戎明迈)今日,21世纪最大规模的网络志愿者活动“宝马-腾讯世博网络志愿者接力”正式进入传递阶段。首批20100位志愿者从全国近4亿网友中脱颖而出,以腾讯世博网络志愿者“第一棒”的身份,同时向各自的Q Q好友发出活动邀请,号召更多的网友参与到传播世博信息、传递世博精神的行列中来。
注意事项:
nvidia-smi
,查看空闲的GPU。export CUDA_VISIBLE_DEVICES=1
,如果只有GPU 2空闲,则=
后面填2
。如果想用两块卡,=
后面填2,3
(或1,2
)。支持多GPU分布式训练。python train.py -t ../data/THUCNews.shuf5w.txt.train.txt -v ../data/THUCNews.shuf5w.txt.valid.txt -p ../bert_zh_L-12_H-768_A-12_2/ -e 3
train.py
。saved_model
目录就是最新训练好的SavedModel。训练结束后,如果出现以下错误,可忽略:
Exception ignored in: <function Pool.__del__ at 0x7fa4e283a550>
Traceback (most recent call last):
File "/root/miniconda3/lib/python3.8/multiprocessing/pool.py", line 268, in __del__
self._change_notifier.put(None)
File "/root/miniconda3/lib/python3.8/multiprocessing/queues.py", line 368, in put
self._writer.send_bytes(obj)
File "/root/miniconda3/lib/python3.8/multiprocessing/connection.py", line 200, in send_bytes
self._send_bytes(m[offset:offset + size])
File "/root/miniconda3/lib/python3.8/multiprocessing/connection.py", line 411, in _send_bytes
self._send(header + buf)
File "/root/miniconda3/lib/python3.8/multiprocessing/connection.py", line 368, in _send
n = write(self._handle, buf)
OSError: [Errno 9] Bad file descriptor
steps_per_epoch
表示多少步为一轮(跟过去理解的整个训练集过一遍为一轮不一样),是超参数。epoch
表示轮数,是超参数。每轮结束后会打印一次训练集准确率和验证集准确率,如果想要更频繁地打印这两个指标,则应该减少steps_per_epoch
,增加epoch
。steps_per_epoch
x epoch
,达到训练步数时,训练停止。一般来说,越大越好,直到OOM。不过,如果训练集不大,batch_size也不宜太大, 40000条样例时,64和128的效果没有明显区别。
训练时,每轮会打印一次训练集的准确率accuracy和验证集的准确率val_accuracy。如果accuracy增加而val_accuracy下降,则发生了过拟合。
ModelCheckpoint
回调会在val_accuracy最大时,存储SavedModel。
有一个现象,在训练最初的几轮,accuracy明显小于val_accuracy。有两个原因:
所以,这是正常现象。
目前提供了两周pooling选项,cls和avg。如果是cls,则文本向量是cls位的向量。如果是avg,则文本向量是其他位置的向量的平均。 这两者区别不大,使用者可以自己做做实验。在THUCNews上,cls略胜一筹。
id2label.json
。bi_cls分支为二分类做了一些定制优化。
wget storage.googleapis.com/tfhub-modules/tensorflow/bert_zh_L-12_H-768_A-12/2.tar.gz
tar xzf 2.tar.gz
mkdir bert_zh_L-12_H-768_A-12_2
mv assets bert_zh_L-12_H-768_A-12_2
mv saved_model.pb bert_zh_L-12_H-768_A-12_2
mv variables bert_zh_L-12_H-768_A-12_2
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。