1 Star 2 Fork 0

solaris/HRNR

加入 Gitee
与超过 1400万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
cmt_gen.py 731 Bytes
一键复制 编辑 原始数据 按行查看 历史
solaris 提交于 2020-09-03 17:07 +08:00 . first commit
# coding=utf-8
import torch
from utils import *
def get_cmt(x, adj, s, mask=None):
x = x.unsqueeze(0) if x.dim() == 2 else x
adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
s = s.unsqueeze(0) if s.dim() == 2 else s
batch_size, num_nodes, _ = x.size()
if mask is not None:
mask = mask.view(batch_size, num_nodes, 1).to(x.dtype)
x, s = x * mask, s * mask
out = torch.matmul(s.transpose(1, 2), x)
out_adj = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s)
link_loss = adj - torch.matmul(s, s.transpose(1, 2)) + EPS
link_loss = torch.norm(link_loss, p=2)
link_loss = link_loss / adj.numel()
ent_loss = (-s * torch.log(s + EPS)).sum(dim=-1).mean()
return out, out_adj, link_loss, ent_loss
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/solaris_wn/HRNR.git
git@gitee.com:solaris_wn/HRNR.git
solaris_wn
HRNR
HRNR
master

搜索帮助