Ai
1 Star 0 Fork 0

horn-learn/BertClassifier

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
model.py 1.26 KB
一键复制 编辑 原始数据 按行查看 历史
hornlive 提交于 2024-06-24 03:04 +08:00 . init
### 此资源由 58学课资源站 收集整理 ###
# 想要获取完整课件资料 请访问:58xueke.com
# 百万资源 畅享学习
#
# coding: utf-8
import torch
import torch.nn as nn
from transformers import BertModel
# Bert
class BertClassifier(nn.Module):
def __init__(self, bert_config, num_labels):
super().__init__()
# 定义BERT模型
self.bert = BertModel(config=bert_config)
# 定义分类器
self.classifier = nn.Linear(bert_config.hidden_size, num_labels)
def forward(self, input_ids, attention_mask, token_type_ids):
# BERT的输出
# 分为两个部分,第一个元素是输入序列所有 token 的 Embedding 向量层,第二个变量是[CLS]位的隐层信息
# [CLS]id[SEP] [4 768] [[1_CLS], [2_CLS], [], []]
bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
# 取[CLS]位置的pooled output [4, 768]
pooled = [1]
# 分类 [CLS] [] [] []
# [4, 512]
# [4, 512, 768]
# [4, 512, 768]
# [CLS]
# [4, 768] * [768, 10] = [4, 10]
logits = self.classifier(pooled)
# 返回softmax后结果
return torch.softmax(logits, dim=1)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/horn-learn/bert-classifier.git
git@gitee.com:horn-learn/bert-classifier.git
horn-learn
bert-classifier
BertClassifier
master

搜索帮助