# -*- coding: utf-8 -*-
# file: data_utils.py
# author: songyouwei <youwei0314@gmail.com>
# Copyright (C) 2018. All Rights Reserved.
import os
import pickle
import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer
def build_tokenizer(fnames, max_seq_len, dat_fname):
    """Build a Tokenizer fitted on the given dataset files, or load a cached one."""
    if os.path.exists(dat_fname):
        print('loading tokenizer:', dat_fname)
        with open(dat_fname, 'rb') as f:
            tokenizer = pickle.load(f)
    else:
        # Dataset files store each sample as three lines: the sentence with the
        # aspect replaced by "$T$", the aspect term, and the polarity label.
        text = ''
        for fname in fnames:
            with open(fname, 'r', encoding='utf-8', newline='\n', errors='ignore') as fin:
                lines = fin.readlines()
            for i in range(0, len(lines), 3):
                text_left, _, text_right = [s.lower().strip() for s in lines[i].partition("$T$")]
                aspect = lines[i + 1].lower().strip()
                text_raw = text_left + " " + aspect + " " + text_right
                text += text_raw + " "
        tokenizer = Tokenizer(max_seq_len)
        tokenizer.fit_on_text(text)
        with open(dat_fname, 'wb') as f:
            pickle.dump(tokenizer, f)
    return tokenizer

def _load_word_vec(path, word2idx=None, embed_dim=300):
    """Read a GloVe-style text file, returning {word: vector} for the given vocabulary."""
    word_vec = {}
    with open(path, 'r', encoding='utf-8', newline='\n', errors='ignore') as fin:
        for line in fin:
            tokens = line.rstrip().split()
            # the last embed_dim tokens are the vector; everything before them is
            # the word itself (which may contain spaces in some embedding files)
            word, vec = ' '.join(tokens[:-embed_dim]), tokens[-embed_dim:]
            if word2idx is None or word in word2idx:
                word_vec[word] = np.asarray(vec, dtype='float32')
    return word_vec

def build_embedding_matrix(word2idx, embed_dim, dat_fname):
    """Build an embedding matrix aligned with the tokenizer's vocabulary, or load a cached one."""
    if os.path.exists(dat_fname):
        print('loading embedding_matrix:', dat_fname)
        with open(dat_fname, 'rb') as f:
            embedding_matrix = pickle.load(f)
    else:
        print('loading word vectors...')
        # idx 0 (padding) and idx len(word2idx)+1 (out-of-vocabulary) stay all-zeros
        embedding_matrix = np.zeros((len(word2idx) + 2, embed_dim))
        fname = './glove.twitter.27B/glove.twitter.27B.' + str(embed_dim) + 'd.txt' \
            if embed_dim != 300 else './glove.42B.300d.txt'
        word_vec = _load_word_vec(fname, word2idx=word2idx, embed_dim=embed_dim)
        print('building embedding_matrix:', dat_fname)
        for word, i in word2idx.items():
            vec = word_vec.get(word)
            if vec is not None:
                # words not found in the embedding file keep their all-zero row
                embedding_matrix[i] = vec
        with open(dat_fname, 'wb') as f:
            pickle.dump(embedding_matrix, f)
    return embedding_matrix

def pad_and_truncate(sequence, maxlen, dtype='int64', padding='post', truncating='post', value=0):
    """Truncate a sequence to maxlen and pad it with `value` on the chosen side."""
    x = (np.ones(maxlen) * value).astype(dtype)
    if truncating == 'pre':
        trunc = sequence[-maxlen:]
    else:
        trunc = sequence[:maxlen]
    trunc = np.asarray(trunc, dtype=dtype)
    if padding == 'post':
        x[:len(trunc)] = trunc
    else:
        x[-len(trunc):] = trunc
    return x
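
# A quick illustration of the padding/truncating modes (example values, not part
# of the original module; shown only to document the semantics):
#   pad_and_truncate([1, 2, 3], 5)                             -> [1, 2, 3, 0, 0]
#   pad_and_truncate([1, 2, 3], 5, padding='pre')              -> [0, 0, 1, 2, 3]
#   pad_and_truncate([1, 2, 3, 4, 5, 6], 5)                    -> [1, 2, 3, 4, 5]
#   pad_and_truncate([1, 2, 3, 4, 5, 6], 5, truncating='pre')  -> [2, 3, 4, 5, 6]
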
class Tokenizer(object):
    def __init__(self, max_seq_len, lower=True):
        self.lower = lower
        self.max_seq_len = max_seq_len
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 1  # index 0 is reserved for padding

    def fit_on_text(self, text):
        if self.lower:
            text = text.lower()
        words = text.split()
        for word in words:
            if word not in self.word2idx:
                self.word2idx[word] = self.idx
                self.idx2word[self.idx] = word
                self.idx += 1

    def text_to_sequence(self, text, reverse=False, padding='post', truncating='post'):
        if self.lower:
            text = text.lower()
        words = text.split()
        # out-of-vocabulary words map to len(word2idx)+1, the last row of the embedding matrix
        unknownidx = len(self.word2idx) + 1
        sequence = [self.word2idx[w] if w in self.word2idx else unknownidx for w in words]
        if len(sequence) == 0:
            sequence = [0]
        if reverse:
            sequence = sequence[::-1]
        return pad_and_truncate(sequence, self.max_seq_len, padding=padding, truncating=truncating)

class Tokenizer4Bert:
    def __init__(self, max_seq_len, pretrained_bert_name):
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_bert_name)
        self.max_seq_len = max_seq_len

    def text_to_sequence(self, text, reverse=False, padding='post', truncating='post'):
        # WordPiece-tokenize the text, then map tokens to vocabulary ids
        sequence = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
        if len(sequence) == 0:
            sequence = [0]
        if reverse:
            sequence = sequence[::-1]
        return pad_and_truncate(sequence, self.max_seq_len, padding=padding, truncating=truncating)
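
# Note: BertTokenizer keeps bracketed special tokens intact, so callers below can
# embed them directly in the text. For example, with the 'bert-base-uncased'
# checkpoint (an assumed choice, not fixed by this module),
# text_to_sequence("[CLS] great food [SEP]") yields a sequence starting with the
# [CLS] id (101) and containing the [SEP] id (102) before the zero padding.
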
class ABSADataset(Dataset):
    def __init__(self, fname, tokenizer):
        with open(fname, 'r', encoding='utf-8', newline='\n', errors='ignore') as fin:
            lines = fin.readlines()
        # pre-computed dependency graphs, keyed by each sample's starting line index
        with open(fname + '.graph', 'rb') as fin:
            idx2graph = pickle.load(fin)
        all_data = []
        for i in range(0, len(lines), 3):
            text_left, _, text_right = [s.lower().strip() for s in lines[i].partition("$T$")]
            aspect = lines[i + 1].lower().strip()
            polarity = lines[i + 2].strip()
            text_indices = tokenizer.text_to_sequence(text_left + " " + aspect + " " + text_right)
            context_indices = tokenizer.text_to_sequence(text_left + " " + text_right)
            left_indices = tokenizer.text_to_sequence(text_left)
            left_with_aspect_indices = tokenizer.text_to_sequence(text_left + " " + aspect)
            right_indices = tokenizer.text_to_sequence(text_right, reverse=True)
            right_with_aspect_indices = tokenizer.text_to_sequence(aspect + " " + text_right, reverse=True)
            aspect_indices = tokenizer.text_to_sequence(aspect)
            left_len = np.sum(left_indices != 0)
            aspect_len = np.sum(aspect_indices != 0)
            # [start, end] token positions of the aspect within the sentence
            aspect_boundary = np.asarray([left_len, left_len + aspect_len - 1], dtype=np.int64)
            # shift labels from {-1, 0, 1} to {0, 1, 2} for classification
            polarity = int(polarity) + 1
            text_len = np.sum(text_indices != 0)
            # BERT-style "[CLS] sentence [SEP] aspect [SEP]" input with matching segment ids
            concat_bert_indices = tokenizer.text_to_sequence('[CLS] ' + text_left + " " + aspect + " " + text_right + ' [SEP] ' + aspect + " [SEP]")
            concat_segments_indices = [0] * (text_len + 2) + [1] * (aspect_len + 1)
            concat_segments_indices = pad_and_truncate(concat_segments_indices, tokenizer.max_seq_len)
            text_bert_indices = tokenizer.text_to_sequence("[CLS] " + text_left + " " + aspect + " " + text_right + " [SEP]")
            aspect_bert_indices = tokenizer.text_to_sequence("[CLS] " + aspect + " [SEP]")
            # zero-pad the adjacency matrix up to max_seq_len x max_seq_len
            dependency_graph = np.pad(idx2graph[i],
                                      ((0, tokenizer.max_seq_len - idx2graph[i].shape[0]),
                                       (0, tokenizer.max_seq_len - idx2graph[i].shape[0])),
                                      'constant')
            data = {
                'concat_bert_indices': concat_bert_indices,
                'concat_segments_indices': concat_segments_indices,
                'text_bert_indices': text_bert_indices,
                'aspect_bert_indices': aspect_bert_indices,
                'text_indices': text_indices,
                'context_indices': context_indices,
                'left_indices': left_indices,
                'left_with_aspect_indices': left_with_aspect_indices,
                'right_indices': right_indices,
                'right_with_aspect_indices': right_with_aspect_indices,
                'aspect_indices': aspect_indices,
                'aspect_boundary': aspect_boundary,
                'dependency_graph': dependency_graph,
                'polarity': polarity,
            }
            all_data.append(data)
        self.data = all_data

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)
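

# Minimal usage sketch (hypothetical paths and hyper-parameters; this block is
# illustrative only and assumes train/test files in the three-line format parsed
# above, plus a pre-computed <fname>.graph pickle of dependency adjacency matrices):
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    tokenizer = build_tokenizer(
        fnames=['./datasets/train.txt', './datasets/test.txt'],
        max_seq_len=85,
        dat_fname='tokenizer.dat')
    embedding_matrix = build_embedding_matrix(
        word2idx=tokenizer.word2idx,
        embed_dim=300,
        dat_fname='300_embedding_matrix.dat')  # expects ./glove.42B.300d.txt on disk
    train_set = ABSADataset('./datasets/train.txt', tokenizer)
    train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
    for batch in train_loader:
        print(batch['text_indices'].shape, batch['polarity'].shape)
        break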