代码拉取完成,页面将自动刷新
import os
import copy
import json
import logging
import torch
from torch.utils.data import TensorDataset
from bert_finetune_re.utils import get_re_labels
# 日志对象初始化
logger = logging.getLogger(__name__)
#############################
# add entity markers
#############################
class InputExample(object):
"""
定义InputExample类数据
"""
def __init__(self, guid, words, re_label_id=None, head_entity_pos=None, tail_entity_pos=None):
# 每个样本的独特序号
self.guid = guid
# 原文本
self.words = words
# 关系分类的标签
self.re_label_id = re_label_id
# 头实体的位置信息
self.head_entity_pos = head_entity_pos # tuple, (s, e)
# 尾实体的位置信息
self.tail_entity_pos = tail_entity_pos # tuple, (s, e)
def __repr__(self):
"""
这里重写我们的输入信息
"""
return str(self.to_json_string())
def to_dict(self):
"""
将此实例序列化到Python字典中
__dict__:类的静态函数、类函数、普通函数、全局变量以及一些内置的属性都是放在类__dict__里的
对象实例的__dict__中存储了一些self.xxx的一些东西
"""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""
类的属性等信息(字典格式)dump进入json string
json.dumps()函数将python对象编码成JSON字符串
indent=2.文件格式中加入了换行与缩进
"""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
class InputFeatures(object):
"""
定义输入到BERT模型的InputFeatures类数据
"""
def __init__(self, input_ids,
attention_mask,
token_type_ids,
re_label_id,
head_entity_pos=None,
tail_entity_pos=None,
):
# 输入样本序列在bert词表里的索引,可以直接喂给nn.embedding
self.input_ids = input_ids
# 注意力mask,padding的部分为0,其他为1
self.attention_mask = attention_mask
# 表示每个token属于句子1还是句子2
self.token_type_ids = token_type_ids
# 关系分类标签
self.re_label_id = re_label_id
# 头实体位置信息作为模型输入
self.head_entity_pos = head_entity_pos
# 尾实体位置信息作为模型输入
self.tail_entity_pos = tail_entity_pos
def __repr__(self):
"""
这里重写我们的输入信息
"""
return str(self.to_json_string())
def to_dict(self):
"""
将此实例序列化到Python字典中
__dict__:类的静态函数、类函数、普通函数、全局变量以及一些内置的属性都是放在类__dict__里的
对象实例的__dict__中存储了一些self.xxx的一些东西
"""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""
类的属性等信息(字典格式)dump进入json string
json.dumps()函数将python对象编码成JSON字符串
indent=2.文件格式中加入了换行与缩进
"""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
class ReProcessor(object):
"""
关系分类任务的数据处理器
"""
def __init__(self, args):
# 参数
self.args = args
# 获得关系分类标签
self.re_label2id = get_re_labels(args)
# 输入文本文件
self.input_text_file = 'semeval.txt'
@classmethod
def _read_file(cls, input_file, quotechar=None):
"""
逐行读取输入文件
:param input_file: 输入文件路径
:param quotechar:
:return: 句子列表
"""
with open(input_file, "r", encoding="utf-8") as f:
lines = []
for line in f:
lines.append(line.strip())
return lines
def _create_examples(self, lines, set_type):
"""
为训练集与验证集构建example
:param lines: 句子列表
:param set_type: 区分训练集与验证集
:return: 处理后的InputExample类数据
"""
examples = []
for i, line in enumerate(lines):
# 每个样本的独特序号
guid = "%s-%s" % (set_type, i)
# 去掉首尾空格
line = line.strip()
if not line:
continue
# json.loads()函数将已编码的JSON字符串解码成python的dict对象
line = json.loads(line)
# 关系标签
# 获得"relation"关系标签,如果指定的键"relation"不存在,则返回默认值"Other"
re_label = line.get("relation", "Other")
# 获得关系标签对应的索引
re_label_id = self.re_label2id[re_label]
# 获得输入文本
words = line["token"]
# 头实体的位置
head_entity_pos = line["h"]["pos"]
# 头实体
head_entity_mention = line["h"]["name"]
# 断言:头实体与其位置是否对应
assert " ".join(words[head_entity_pos[0]: head_entity_pos[1]]) == head_entity_mention
# 尾实体的位置
tail_entity_pos = line["t"]["pos"]
# 尾实体
tail_entity_mention = line["t"]["name"]
# 如果尾实体与其位置不对应,则分别打印两个位置
if " ".join(words[tail_entity_pos[0]: tail_entity_pos[1]]) != tail_entity_mention:
print(" ".join(words[tail_entity_pos[0]: tail_entity_pos[1]]))
print(tail_entity_mention)
# 断言:尾实体与其位置是否对应
assert " ".join(words[tail_entity_pos[0]: tail_entity_pos[1]]) == tail_entity_mention
# 将每一行JSON文件转化成InputExample类数据
examples.append(
InputExample(
guid,
words,
re_label_id=re_label_id,
head_entity_pos=head_entity_pos,
tail_entity_pos=tail_entity_pos,
)
)
return examples
def get_examples(self, mode):
"""
获得样例数据
Args:
mode: train, dev, test
"""
# 拼接数据路径
data_path = os.path.join(self.args.data_dir, self.args.task, mode)
# 写入日志
logger.info("LOOKING AT {}".format(data_path))
# 构建InputExample类样例数据
return self._create_examples(lines=self._read_file(os.path.join(data_path, self.input_text_file)),
set_type=mode)
# 如果有多个数据集,则数据集的processor可以通过映射得到
processors = {
"semeval10": ReProcessor,
}
def convert_examples_to_features(examples,
max_seq_len,
tokenizer,
cls_token_segment_id=0,
pad_token_segment_id=0,
sequence_a_segment_id=0,
mask_padding_with_zero=True,
head_entity_start_token="[unused0]",
head_entity_end_token="[unused2]",
tail_entity_start_token="[unused1]",
tail_entity_end_token="[unused3]",
):
"""
将example数据转化为BERT模型需要的输入格式
head_entity_start_token、head_entity_end_token、tail_entity_start_token、tail_entity_end_token为entity markers
:param examples: 样例数据
:param max_seq_len: 最大序列长度
:param tokenizer: 分词
:param cls_token_segment_id: 0
:param pad_token_segment_id: 0
:param sequence_a_segment_id: 0
:param mask_padding_with_zero: True
:param head_entity_start_token: [unused0]
:param head_entity_end_token: [unused2]
:param tail_entity_start_token: [unused1]
:param tail_entity_end_token: [unused3]
:return: 返回BERT模型的输入数据
"""
# 基于当前分词模型的设置
cls_token = tokenizer.cls_token
sep_token = tokenizer.sep_token
unk_token = tokenizer.unk_token
pad_token_id = tokenizer.pad_token_id
features = []
# 循环遍历每一个样例数据
for (ex_index, example) in enumerate(examples):
# 每隔1000个数据,写入日志
if ex_index % 1000 == 0:
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
# 关系标签id
re_label_id = int(example.re_label_id)
# 头尾实体的位置
head_entity_start_pos = None
tail_entity_start_pos = None
head_entity_end_pos = None
tail_entity_end_pos = None
# Tokenize word by word (for RE task)
tokens = []
# 遍历每一个样例数据中的每一个单词
for i, word in enumerate(example.words):
# 使用tokenize()函数对文本进行tokenization之后,返回的分词的token词
word_tokens = tokenizer.tokenize(word)
# 如果碰到头实体的开始位置
if example.head_entity_pos[0] == i:
# 加入特殊符号
tokens.append(head_entity_start_token)
# 记录头实体的位置
head_entity_start_pos = len(tokens)
# 如果碰到尾实体的开始位置
if example.tail_entity_pos[0] == i:
tokens.append(tail_entity_start_token)
tail_entity_start_pos = len(tokens)
# 拓展tokens序列
tokens.extend(word_tokens)
# 如果碰到头实体的结束位置
if example.head_entity_pos[1] == i + 1:
head_entity_end_pos = len(tokens)
tokens.append(head_entity_end_token)
# 如果碰到尾实体的结束位置
if example.tail_entity_pos[1] == i + 1:
tail_entity_end_pos = len(tokens)
tokens.append(tail_entity_end_token)
# 如果句长太长,需要对句子进行截断,导致头尾实体不全,则删除
if not head_entity_start_pos or not tail_entity_start_pos \
or not head_entity_end_pos or not tail_entity_end_pos:
continue
# 添加[CLS]和[SEP]
special_tokens_count = 2
# 如果头尾实体不全,则删除
if head_entity_end_pos > max_seq_len - special_tokens_count - 1:
continue
if tail_entity_end_pos > max_seq_len - special_tokens_count - 1:
continue
# 如果句子过长,则进行截断
if len(tokens) > max_seq_len - special_tokens_count:
# token词表也进行相应的截断
tokens = tokens[:(max_seq_len - special_tokens_count)]
# Add [SEP] token
tokens += [sep_token]
token_type_ids = [sequence_a_segment_id] * len(tokens)
# Add [CLS] token
tokens = [cls_token] + tokens
token_type_ids = [cls_token_segment_id] + token_type_ids
# 将tokens转化为bert词表中对应的id
input_ids = tokenizer.convert_tokens_to_ids(tokens)
# 注意力mask,句子中存在的部分为1
attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
# 需要填充的序列长度
padding_length = max_seq_len - len(input_ids)
# 输入样本序列在bert词表里的索引
input_ids = input_ids + ([pad_token_id] * padding_length)
# 注意力mask,padding的部分为0,其他为1
attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
# token_type_ids表示每个token属于句子1还是句子2
token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
# 验证长度是否填充至最长序列
assert len(input_ids) == max_seq_len, "Error with input length {} vs {}".format(len(input_ids), max_seq_len)
assert len(attention_mask) == max_seq_len, "Error with attention mask length {} vs {}".format(len(attention_mask), max_seq_len)
assert len(token_type_ids) == max_seq_len, "Error with token type length {} vs {}".format(len(token_type_ids), max_seq_len)
# 头实体位置索引
head_entity_pos = [head_entity_start_pos, head_entity_end_pos]
# 尾实体位置索引
tail_entity_pos = [tail_entity_start_pos, tail_entity_end_pos]
# 如果是前5个数据,则记录日志
if ex_index < 5:
logger.info("*** Example ***")
logger.info("guid: %s" % example.guid)
logger.info("tokens: %s" % " ".join([str(x) for x in tokens]))
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
logger.info("token_type_ids: %s" % " ".join([str(x) for x in token_type_ids]))
logger.info("re_label_id: %d" % (re_label_id))
logger.info("head_entity_pos: (%d, %d)" % (head_entity_pos[0], head_entity_pos[1]))
logger.info("tail_entity_pos: (%d, %d)" % (tail_entity_pos[0], tail_entity_pos[1]))
# 构造InputFeatures类BERT模型输入数据
features.append(
InputFeatures(input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
re_label_id=re_label_id,
head_entity_pos=head_entity_pos,
tail_entity_pos=tail_entity_pos,
))
return features
def load_and_cache_examples(args, tokenizer, mode):
"""
将数据转化为cache文件,方便下一次快速加载
:param args: 参数
:param tokenizer: 分词
:param mode: 区分训练、验证、测试
:return:
"""
# 加载醒目数据处理器
processor = processors[args.task](args)
# 拼接cach文件路径
cached_features_file = os.path.join(
args.data_dir,
'cached_{}_{}_{}_{}'.format(
mode,
args.task,
list(filter(None, args.model_name_or_path.split("/"))).pop(),
args.max_seq_len
)
)
# 如果路径存在,则加载cach文件
if os.path.exists(cached_features_file):
logger.info("Loading features from cached file %s", cached_features_file)
features = torch.load(cached_features_file)
# 如果路径不存在
else:
# Load data features from dataset file
logger.info("Creating features from dataset file at %s", args.data_dir)
# 区分数据集
if mode == "train":
examples = processor.get_examples("train")
elif mode == "dev":
examples = processor.get_examples("dev")
elif mode == "test":
examples = processor.get_examples("test")
else:
raise Exception("For mode, Only train, dev, test is available")
# 在计算交叉熵损失时忽略的索引:-100
pad_token_label_id = args.ignore_index
# 将example数据转化为features数据
features = convert_examples_to_features(examples, args.max_seq_len, tokenizer,)
logger.info("Saving features into cached file %s", cached_features_file)
# 将数据保存至cach路径中
torch.save(features, cached_features_file)
# Convert to Tensors and build dataset
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
all_re_label_ids = torch.tensor([f.re_label_id for f in features], dtype=torch.long)
all_head_entity_pos = torch.tensor([f.head_entity_pos for f in features], dtype=torch.long)
all_tail_entity_pos = torch.tensor([f.tail_entity_pos for f in features], dtype=torch.long)
# 构造dataset
dataset = TensorDataset(all_input_ids, all_attention_mask,
all_token_type_ids, all_re_label_ids,
all_head_entity_pos, all_tail_entity_pos,
)
# 返回dataset
return dataset
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。