代码拉取完成,页面将自动刷新
# coding=utf-8
import torch
from utils import *
def get_cmt(x, adj, s, mask=None):
x = x.unsqueeze(0) if x.dim() == 2 else x
adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
s = s.unsqueeze(0) if s.dim() == 2 else s
batch_size, num_nodes, _ = x.size()
if mask is not None:
mask = mask.view(batch_size, num_nodes, 1).to(x.dtype)
x, s = x * mask, s * mask
out = torch.matmul(s.transpose(1, 2), x)
out_adj = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s)
link_loss = adj - torch.matmul(s, s.transpose(1, 2)) + EPS
link_loss = torch.norm(link_loss, p=2)
link_loss = link_loss / adj.numel()
ent_loss = (-s * torch.log(s + EPS)).sum(dim=-1).mean()
return out, out_adj, link_loss, ent_loss
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。