1 Star 1 Fork 1

付昌陇/MMBSSL

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
data_process.py 3.22 KB
一键复制 编辑 原始数据 按行查看 历史
fu-changlong 提交于 2023-07-03 18:30 . 提交
import random
from operator import itemgetter
import numpy as np
import torch
def set_random_seed(seed):
    """Seed every random number generator used by this module.

    Seeds Python's ``random``, NumPy, and PyTorch (the default generator,
    the current CUDA device, and all CUDA devices), then switches cuDNN to
    deterministic algorithms.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # The flag is spelled ``deterministic``. Assigning to a misspelled
    # attribute is accepted silently and leaves the real flag untouched.
    torch.backends.cudnn.deterministic = True
def process_idx(loader):
    """Attach batch-level clique node indices to every batch of ``loader``.

    For each batch, every graph's clique rows are truncated to their
    recorded length and shifted by the number of nodes in the graphs that
    precede it, so the indices address nodes of the whole batch. The result
    is stored on the batch as ``clique_batch``.

    Args:
        loader: Iterable of batches. Each batch must support ``len()`` and
            integer indexing, and each graph must expose ``x``,
            ``clique_idx`` and ``clique_slice``.

    Returns:
        A list holding the same batch objects, in iteration order.
    """
    processed = []
    for batch in loader:
        shifted = []
        offset = 0
        for t in range(len(batch)):
            # Graph ``t`` starts after all nodes of graphs ``0 .. t-1``.
            if t > 0:
                offset += batch[t - 1].x.size(0)
            graph_rows = zip(batch[t].clique_idx, batch[t].clique_slice)
            # NOTE(review): ``add_`` shifts the sliced row in place. If the
            # slice is a view of stored data, that data is modified too.
            shifted.extend(
                row[:count].add_(offset) for row, count in graph_rows
            )
        batch.clique_batch = shifted
        processed.append(batch)
    return processed
def add_Inbatch(batch):
    """Compute batch-level clique node indices for a single batch.

    Each graph's clique rows are cut to their recorded length and offset by
    the total node count of the graphs before it. The resulting list of
    tensors is stored on ``batch`` as ``clique_batch``.

    Args:
        batch: Object supporting ``len()`` and integer indexing whose items
            expose ``x``, ``clique_idx`` and ``clique_slice``.

    Returns:
        The same ``batch`` object, now carrying ``clique_batch``.
    """
    clique_batch = []
    node_offset = 0
    for position in range(len(batch)):
        # Accumulate the size of the previous graph; the first graph keeps
        # an offset of zero.
        if position:
            node_offset += batch[position - 1].x.size(0)
        rows = batch[position].clique_idx
        lengths = batch[position].clique_slice
        for row, length in zip(rows, lengths):
            # NOTE(review): the shift is applied in place on the slice, so
            # storage shared with ``clique_idx`` is updated as well.
            clique_batch.append(row[:length].add_(node_offset))
    batch.clique_batch = clique_batch
    return batch
def get_dict(graphs, clique_dict, all_edges):
    """Select the clique and edge entries that belong to ``graphs``.

    Args:
        graphs: Iterable of objects exposing an ``id`` attribute used as the
            lookup key.
        clique_dict: Container indexable by graph id.
        all_edges: Container indexable by graph id.

    Returns:
        A pair ``(cli_dict, edge_dict)``. Each element is a tuple with one
        entry per graph, in the order of ``graphs``.
    """
    idxs = [graph.id for graph in graphs]
    # ``itemgetter(*idxs)`` returns a bare item for a single id and raises
    # for no ids at all. Building the tuples explicitly keeps the result
    # shape the same for any number of graphs.
    cli_dict = tuple(clique_dict[i] for i in idxs)
    edge_dict = tuple(all_edges[i] for i in idxs)
    return cli_dict, edge_dict
def readout(clique_dict, all_edges, graphs):
    """Rewrite each graph so that its nodes are cliques.

    For the graph at position ``k``, the features of the nodes listed in
    every clique of ``clique_dict[k]`` are summed to form one new node
    feature row, ``all_edges[k]`` becomes the new ``edge_index``, and
    ``num_nodes`` is set to the number of cliques. The graphs are modified
    in place.

    Args:
        clique_dict: Per-graph collections of cliques, indexed by position
            in ``graphs``; each clique is an iterable of node indices.
        all_edges: Per-graph edge lists, indexed by position in ``graphs``.
        graphs: Iterable of graph objects exposing a tensor ``x``.

    Returns:
        A list with the same (modified) graph objects, in input order.
    """
    coarse_graphs = []
    for position, graph in enumerate(graphs):
        node_feats = graph.x.data
        # One summed feature row per clique, round-tripped through NumPy.
        pooled = [
            sum([node_feats[a] for a in members]).cpu().numpy()
            for members in clique_dict[position]
        ]
        graph.x = torch.Tensor(np.array(pooled))
        graph.edge_index = torch.Tensor(all_edges[position]).T.to(torch.long)
        if len(graph.edge_index) != 0:
            # Every edge gets the constant attribute pair (6, 3).
            # NOTE(review): the meaning of 6 and 3 is not visible here,
            # presumably reserved edge-type codes; confirm against the model.
            edge_attr = torch.zeros(graph.edge_index.size(1), 2)
            edge_attr[:, 0] = 6
            edge_attr[:, 1] = 3
            graph.edge_attr = edge_attr.to(torch.long)
        graph.num_nodes = graph.x.shape[0]
        coarse_graphs.append(graph)
    return coarse_graphs
def read_book(cli_path, edge_path):
    """Load per-molecule cliques and clique edges from two text files.

    Each line of the clique file has the form
    ``"<mol> <clique_index> /<literal>"``, where ``<literal>`` is evaluated
    as a Python expression. Consecutive lines with the same ``<mol>`` are
    collected into one group. Grouping starts at molecule id 0, so a file
    whose first id is not 0 yields a leading empty group.

    Each line of the edge file is evaluated as a Python expression that
    must produce an iterable of 2-item pairs; every pair is converted to a
    two-element list.

    Args:
        cli_path: Path of the clique file (read as UTF-8).
        edge_path: Path of the edge file (read as UTF-8).

    Returns:
        A pair ``(cliques, all_edge)``: the grouped cliques and one list of
        ``[x, y]`` pairs per edge-file line.
    """
    # SECURITY: ``eval`` executes whatever expression the files contain.
    # Only use this loader on trusted files.
    # Both files are opened with context managers so the handles are closed
    # even when parsing fails.
    with open(cli_path, 'r', encoding='utf-8') as file:
        data = file.readlines()

    all_edge = []
    with open(edge_path, 'r', encoding='utf-8') as edges:
        for e in edges:
            lst = eval(e.strip('\n'))
            # Convert the tuples into a two-dimensional list.
            all_edge.append([[x, y] for x, y in lst])

    clique = []
    cliques = []
    mol_index = 0
    for cli in data:
        index_cli = cli.strip('\n').split(' /')
        mol, cli_index = index_cli[0].split(' ')
        # ``cli_index`` is not used afterwards; converting it still rejects
        # lines whose header is not two integers.
        mol, cli_index = int(mol), int(cli_index)
        if mol != mol_index:
            # A new molecule id closes the group collected so far.
            cliques.append(clique)
            clique = []
            mol_index = mol
        clique.append(eval(index_cli[1]))
    # Flush the group of the last molecule.
    cliques.append(clique)
    return cliques, all_edge
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/fu-changlong/mmbssl.git
git@gitee.com:fu-changlong/mmbssl.git
fu-changlong
mmbssl
MMBSSL
master

搜索帮助