1 Star 7 Fork 1

quarky/ABSA-PyTorch

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
aoa.py 2.06 KB
一键复制 编辑 原始数据 按行查看 历史
# -*- coding: utf-8 -*-
# file: aoa.py
# author: gene_zc <gene_zhangchen@163.com>
# Copyright (C) 2018. All Rights Reserved.
from layers.dynamic_rnn import DynamicLSTM
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class AOA(nn.Module):
    """Attention-over-Attention (AOA) classifier for aspect-based sentiment.

    The context and the aspect are embedded with a shared embedding table and
    encoded by two separate bidirectional ``DynamicLSTM`` layers. A
    context-by-aspect interaction matrix is turned into two attentions:
    column-wise (over context positions) and row-wise (over aspect
    positions). Combining them gives one weight per context position. The
    weighted sum of the context states is classified by a linear layer.

    Args:
        embedding_matrix: Pretrained embedding table. It is converted to a
            float tensor and loaded with ``nn.Embedding.from_pretrained``,
            which freezes the weights by default.
        opt: Options object. This class reads ``opt.embed_dim``,
            ``opt.hidden_dim`` and ``opt.polarities_dim``.
    """

    def __init__(self, embedding_matrix, opt):
        super().__init__()
        self.opt = opt
        self.embed = nn.Embedding.from_pretrained(
            torch.tensor(embedding_matrix, dtype=torch.float))
        self.ctx_lstm = DynamicLSTM(opt.embed_dim, opt.hidden_dim, num_layers=1,
                                    batch_first=True, bidirectional=True)
        self.asp_lstm = DynamicLSTM(opt.embed_dim, opt.hidden_dim, num_layers=1,
                                    batch_first=True, bidirectional=True)
        # Bidirectional encoders concatenate both directions -> 2 * hidden_dim.
        self.dense = nn.Linear(2 * opt.hidden_dim, opt.polarities_dim)

    def forward(self, inputs):
        """Return unnormalized polarity scores for a batch.

        Args:
            inputs: Indexable whose element 0 is the context token-index
                tensor and element 1 the aspect token-index tensor, each
                ``batch_size x seq_len``. Index 0 is treated as padding when
                computing sequence lengths.

        Returns:
            Tensor of shape ``batch_size x polarities_dim`` (output of
            ``self.dense``; no softmax applied).
        """
        text_indices = inputs[0]  # batch_size x seq_len
        aspect_indices = inputs[1]  # batch_size x seq_len
        # Non-zero ids are real tokens; their per-row count is the length
        # handed to the dynamic (length-aware) LSTMs.
        ctx_len = torch.sum(text_indices != 0, dim=1)
        asp_len = torch.sum(aspect_indices != 0, dim=1)

        # Embed the context ids read above (the previous code referenced an
        # undefined `text_raw_indices`, which raised NameError on every call).
        ctx = self.embed(text_indices)  # batch_size x seq_len x embed_dim
        asp = self.embed(aspect_indices)  # batch_size x seq_len x embed_dim
        ctx_out, (_, _) = self.ctx_lstm(ctx, ctx_len)  # batch_size x (ctx) seq_len x 2*hidden_dim
        asp_out, (_, _) = self.asp_lstm(asp, asp_len)  # batch_size x (asp) seq_len x 2*hidden_dim

        # Pairwise dot products between every context and aspect state.
        interaction_mat = torch.matmul(ctx_out, torch.transpose(asp_out, 1, 2))  # batch_size x (ctx) seq_len x (asp) seq_len
        # NOTE(review): no padding mask is applied before either softmax, so
        # any padded positions present in ctx_out/asp_out receive non-zero
        # attention. Confirm whether that is intended.
        alpha = F.softmax(interaction_mat, dim=1)  # col-wise: per aspect position, a distribution over context positions
        beta = F.softmax(interaction_mat, dim=2)  # row-wise: per context position, a distribution over aspect positions
        beta_avg = beta.mean(dim=1, keepdim=True)  # batch_size x 1 x (asp) seq_len
        # gamma weights context positions by alpha, mixed across aspect
        # positions with the averaged row-wise attention.
        gamma = torch.matmul(alpha, beta_avg.transpose(1, 2))  # batch_size x (ctx) seq_len x 1
        weighted_sum = torch.matmul(torch.transpose(ctx_out, 1, 2), gamma).squeeze(-1)  # batch_size x 2*hidden_dim
        out = self.dense(weighted_sum)  # batch_size x polarity_dim
        return out
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/quarky/ABSA-PyTorch.git
git@gitee.com:quarky/ABSA-PyTorch.git
quarky
ABSA-PyTorch
ABSA-PyTorch
master

搜索帮助