import torch
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops

num_atom_type = 120  # including the extra mask tokens
num_chirality_tag = 3
num_bond_type = 7  # including aromatic and self-loop edges, plus the extra masked tokens
num_bond_direction = 4

# device = torch.device('cuda:0')
def _reset_weights(m):
    # Re-initialize Linear layers with Xavier-uniform weights and zero biases.
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
class GINConv(MessagePassing):
    """
    Extension of GIN aggregation to incorporate edge information by concatenation.

    Args:
        emb_dim (int): dimensionality of embeddings for nodes and edges.
        out_dim (int): dimensionality of the output node embeddings.
        aggr (str): aggregation scheme ("add", "mean", or "max").

    See https://arxiv.org/abs/1810.00826
    """
    def __init__(self, emb_dim, out_dim, aggr="add", **kwargs):
        kwargs.setdefault('aggr', aggr)
        super(GINConv, self).__init__(**kwargs)
        # Multi-layer perceptron applied to the aggregated messages.
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, 2 * emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(2 * emb_dim, out_dim),
        )
        # Embeddings for the two categorical edge features: bond type and bond direction.
        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)
        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
        # The aggregation scheme is handled by MessagePassing via the kwargs above.
    def forward(self, x, edge_index, edge_attr):
        # Add self-loops in the edge space.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        # Add features corresponding to the self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), 2, dtype=edge_attr.dtype, device=edge_attr.device)
        self_loop_attr[:, 0] = 4  # bond type reserved for self-loop edges
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0).to(x.device)
        edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])
        return self.propagate(edge_index, x=x, edge_attr=edge_embeddings)
    def message(self, x_j, edge_attr):
        # Combine the neighbor embedding with the edge embedding before aggregation.
        return x_j + edge_attr

    def update(self, aggr_out):
        # Apply the MLP to the aggregated messages.
        return self.mlp(aggr_out)

    def reset_parameters(self):
        self.mlp.apply(_reset_weights)
        # Also re-initialize the edge embeddings, matching __init__.
        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
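
# A minimal smoke test for GINConv on a toy two-atom graph. The sizes and toy
# values below are illustrative assumptions, not part of the model definition.
def _demo_gin_conv():
    conv = GINConv(emb_dim=32, out_dim=32, aggr="add")
    x = torch.randn(2, 32)                           # 2 nodes with emb_dim features
    edge_index = torch.tensor([[0, 1], [1, 0]])      # one undirected bond
    edge_attr = torch.zeros(2, 2, dtype=torch.long)  # (bond type, bond direction) per edge
    out = conv(x, edge_index, edge_attr)
    assert out.shape == (2, 32)
    return out
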
class GNN(torch.nn.Module):
    """
    Stack of GINConv layers producing node representations.

    Args:
        num_layer (int): the number of GNN layers
        pre_dim (int): input dimensionality of the first layer when pre=True
        emb_dim (int): dimensionality of embeddings
        JK (str): jumping-knowledge mode for combining the per-layer outputs:
            "last", "concat", "max" or "sum"
        drop_ratio (float): dropout rate
        pre (bool): if True, the first layer maps pre_dim to emb_dim

    Output:
        node representations
    """
    def __init__(self, num_layer, pre_dim, emb_dim, JK="last", drop_ratio=0, pre=False):
        super(GNN, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.pre = pre
        self.pre_dim = pre_dim

        self.gnns = torch.nn.ModuleList()
        if self.pre:
            # First layer projects pre_dim-dimensional inputs to emb_dim.
            self.gnns.append(GINConv(self.pre_dim, emb_dim, aggr="add"))
            for layer in range(num_layer - 1):
                self.gnns.append(GINConv(emb_dim, emb_dim, aggr="add"))
        else:
            for layer in range(num_layer):
                self.gnns.append(GINConv(emb_dim, emb_dim, aggr="add"))

        self.batch_norms = torch.nn.ModuleList()
        for layer in range(num_layer):
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))
    def forward(self, *argv):
        if len(argv) == 3:
            x, edge_index, edge_attr = argv[0], argv[1], argv[2]
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        else:
            raise ValueError("unmatched number of arguments.")

        h_list = [x]
        for layer in range(self.num_layer):
            h = self.gnns[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layer - 1:
                # No ReLU on the last layer's output.
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
            h_list.append(h)

        # Jumping-knowledge style combination of the per-layer representations.
        if self.JK == "concat":
            node_representation = torch.cat(h_list, dim=1)
        elif self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "max":
            node_representation = torch.stack(h_list, dim=0).max(dim=0)[0]
        elif self.JK == "sum":
            node_representation = torch.stack(h_list, dim=0).sum(dim=0)
        return node_representation
    def reset_parameters(self):
        for gnn in self.gnns:
            gnn.reset_parameters()
        for batch_norm in self.batch_norms:
            batch_norm.reset_parameters()
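
# A minimal smoke test for the GNN stack on a toy three-atom graph. The layer
# count, sizes, and toy graph below are illustrative assumptions only.
def _demo_gnn():
    model = GNN(num_layer=3, pre_dim=16, emb_dim=32, JK="last", drop_ratio=0.1)
    model.eval()  # eval mode so BatchNorm/dropout behave deterministically on a tiny graph
    x = torch.randn(3, 32)                                   # 3 nodes with emb_dim features
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])  # two undirected bonds
    edge_attr = torch.zeros(4, 2, dtype=torch.long)          # (bond type, bond direction) per edge
    out = model(x, edge_index, edge_attr)
    assert out.shape == (3, 32)
    return out
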
if __name__ == "__main__":
    # Run the illustrative smoke tests defined above.
    _demo_gin_conv()
    _demo_gnn()