代码拉取完成,页面将自动刷新
### 此资源由 58学课资源站 收集整理 ###
# 想要获取完整课件资料 请访问:58xueke.com
# 百万资源 畅享学习
#
# coding: utf-8
import torch
import numpy as np
import torch.nn as nn
from torch.utils.data import Dataset
from transformers import BertTokenizer
from tqdm import tqdm
from common import constants
class CNewsDataset(Dataset): # 标签 文本
def __init__(self, filename):
# 数据集初始化
self.labels = ['体育', '娱乐', '家居', '房产', '教育', '时尚', '时政', '游戏', '科技', '财经']
self.labels_id = list(range(len(self.labels)))
self.tokenizer = BertTokenizer.from_pretrained(constants.BERT_PATH) #加载与训练模型
self.input_ids = [] # 一个句子转成id后数组
self.token_type_ids = [] #当前文本属于A还是B,bert支持2种文本输入的
self.attention_mask = [] #掩盖pad
self.label_id = []
self.load_data(filename)
def load_data(self, filename):
# 加载数据
print('loading data from:', filename)
with open(filename, 'r', encoding='utf-8') as rf:
lines = rf.readlines()
for line in tqdm(lines, ncols=100):
label, text = line.strip().split('\t')
label_id = self.labels.index(label)
token = self.tokenizer(text, add_special_tokens=True, padding='max_length', truncation=True, max_length=512)
self.input_ids.append(np.array(token['input_ids']))
self.token_type_ids.append(np.array(token['token_type_ids']))
self.attention_mask.append(np.array(token['attention_mask']))
self.label_id.append(label_id)
def __getitem__(self, index):
return self.input_ids[index], self.token_type_ids[index], self.attention_mask[index], self.label_id[index]
def __len__(self):
return len(self.input_ids)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。