# fu-changlong/MMBSSL - main.py
import argparse
import copy
import logging
import os
import random
from functools import partial
from operator import itemgetter
# sys.path.append('./')
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import ParameterGrid
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.data import Batch
from torch_geometric.loader import DataListLoader, DataLoader
from torch_geometric.nn import GINConv, global_mean_pool, global_add_pool
from tqdm import tqdm
from chem.data_process import process_idx, set_random_seed
from chem.dataloader import DataLoaderMasking1, DataLoaderMasking2
from chem.loader import MoleculeDataset1
from loader_we import MoleculeDataset2
from model import GNN
from splitters import scaffold_split, random_split
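# Cosine decay schedule with linear warmup: the value rises linearly from 0 to max_val
# over `warmup_steps`, then follows a half-cosine decay down to 0 at `total_steps`.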
class CosineDecayScheduler:
def __init__(self, max_val, warmup_steps, total_steps):
self.max_val = max_val
self.warmup_steps = warmup_steps
self.total_steps = total_steps
def get(self, step):
if step < self.warmup_steps:
return self.max_val * step / self.warmup_steps
elif self.warmup_steps <= step <= self.total_steps:
return self.max_val * (1 + np.cos((step - self.warmup_steps) * np.pi /
(self.total_steps - self.warmup_steps))) / 2
else:
raise ValueError('Step ({}) > total number of steps ({}).'.format(step, self.total_steps))
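# Scaled cosine error (SCE): mean over samples of (1 - cosine_similarity(x, y)) ** alpha.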
def sce_loss(x, y, alpha=1):
x = F.normalize(x, p=2, dim=-1)
y = F.normalize(y, p=2, dim=-1)
loss = (1 - (x * y).sum(dim=-1)).pow_(alpha)
loss = loss.mean()
return loss
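# Uniformly sample a fraction `mask_rate` of the node indices of graph batch g to be masked.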
def mask(g, mask_rate=0.5):
num_nodes = g.x.shape[0]
perm = torch.randperm(num_nodes, device=g.x.device)
num_mask_nodes = int(mask_rate * num_nodes)
mask_nodes = perm[:num_mask_nodes]
return mask_nodes
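# BYOL-style dual-encoder module without node masking: an online GNN encoder trained by
# gradient descent and a frozen target GNN encoder updated only through the momentum (EMA)
# rule in update_target_network. The forward pass aligns online and target node embeddings
# of the same (unmasked) graph with the SCE loss.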
class D_CG(nn.Module):
    # rate: learning rate
def __init__(self, num_layer, pre_dim, emb_dim, JK, drop_ratio=0, mask_rate=0, random_mask=False):
super(D_CG, self).__init__()
self.drop_rate = drop_ratio
self.online_encoder = GNN(num_layer, pre_dim=pre_dim, emb_dim=emb_dim, JK=JK, drop_ratio=self.drop_rate,
pre=True)
self.target_encoder = GNN(num_layer, pre_dim=pre_dim, emb_dim=emb_dim, JK=JK, drop_ratio=0,
pre=True)
self.mask_rate = mask_rate
self.enc_mask_token = nn.Parameter(torch.zeros(1, pre_dim))
self.criterion = self.setup_loss_fn("sce", 1)
self.pool_mean = global_mean_pool
self.pool_add = global_add_pool
self.random_mask = random_mask
for param in self.target_encoder.parameters():
param.requires_grad = False
def setup_loss_fn(self, loss_fn, alpha_l):
if loss_fn == "mse":
criterion = nn.MSELoss()
elif loss_fn == "sce":
criterion = partial(sce_loss, alpha=alpha_l)
else:
raise NotImplementedError
return criterion
def trainable_parameters(self):
r"""Returns the parameters that will be updated via an optimizer."""
return list(self.online_encoder.parameters())
@torch.no_grad()
def update_target_network(self, mm):
r"""Performs a momentum update of the target network's weights.
Args:
mm (float): Momentum used in moving average update.
"""
for param_q, param_k in zip(self.online_encoder.parameters(), self.target_encoder.parameters()):
param_k.data.mul_(mm).add_(param_q.data, alpha=1. - mm)
def forward(self, graph_batch, device):
graph_batch = graph_batch.to(device)
h1 = self.online_encoder(graph_batch.x.to(device), graph_batch.edge_index, graph_batch.edge_attr)
with torch.no_grad():
h2 = self.target_encoder(graph_batch.x, graph_batch.edge_index, graph_batch.edge_attr)
loss_2 = self.criterion(h1, h2.detach())
return loss_2
def get_embed(self, graph_batch, device):
with torch.no_grad():
graph_batch = graph_batch.to(device)
h1 = self.online_encoder(graph_batch.x, graph_batch.edge_index, graph_batch.edge_attr)
h = self.pool_add(h1, graph_batch.batch)
return h.detach()
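# Masked variant of the dual-encoder module. A subset of nodes is masked (either sampled at
# random or taken from the pre-computed `mask_nodes` attribute of the batch), their features
# are replaced by a learnable mask token, and the SCE loss aligns online and target
# representations of the masked nodes (per node for random masking, on the mean-pooled
# masked-node embeddings otherwise).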
class CG(nn.Module):
    # rate: learning rate
def __init__(self, num_layer, pre_dim, emb_dim, JK, drop_ratio, mask_rate, random_mask=False):
super(CG, self).__init__()
self.drop_rate = drop_ratio
self.online_encoder = GNN(num_layer, pre_dim=pre_dim, emb_dim=emb_dim, JK=JK, drop_ratio=self.drop_rate,
pre=True)
self.target_encoder = copy.deepcopy(self.online_encoder)
self.target_encoder.reset_parameters()
self.mask_rate = mask_rate
self.enc_mask_token = nn.Parameter(torch.zeros(1, pre_dim))
self.criterion = self.setup_loss_fn("sce", 1)
self.pool_mean = global_mean_pool
self.pool_add = global_add_pool
self.random_mask = random_mask
for param in self.target_encoder.parameters():
param.requires_grad = False
def setup_loss_fn(self, loss_fn, alpha_l):
if loss_fn == "mse":
criterion = nn.MSELoss()
elif loss_fn == "sce":
criterion = partial(sce_loss, alpha=alpha_l)
else:
raise NotImplementedError
return criterion
def trainable_parameters(self):
r"""Returns the parameters that will be updated via an optimizer."""
return list(self.online_encoder.parameters())
@torch.no_grad()
def update_target_network(self, mm):
r"""Performs a momentum update of the target network's weights.
Args:
mm (float): Momentum used in moving average update.
"""
for param_q, param_k in zip(self.online_encoder.parameters(), self.target_encoder.parameters()):
param_k.data.mul_(mm).add_(param_q.data, alpha=1. - mm)
def forward(self, graph_batch, device):
if self.random_mask:
graph_batch = graph_batch.to(device)
mask_nodes = mask(graph_batch, mask_rate=0.1)
x = graph_batch.x.clone()
x[mask_nodes] = 0.0
x[mask_nodes] += self.enc_mask_token.long().to(graph_batch.x.device)
h1 = self.online_encoder(x.to(device), graph_batch.edge_index, graph_batch.edge_attr)
with torch.no_grad():
h2 = self.target_encoder(graph_batch.x, graph_batch.edge_index, graph_batch.edge_attr)
loss_2 = self.criterion(h1[mask_nodes], h2[mask_nodes].detach())
else:
graph_batch = graph_batch.to(device)
x = graph_batch.x.clone().float().to(device)
mask_nodes = graph_batch.mask_nodes.long()
x[mask_nodes] = 0.0
x[mask_nodes] += self.enc_mask_token
h1 = self.online_encoder(x, graph_batch.edge_index, graph_batch.edge_attr)
with torch.no_grad():
h2 = self.target_encoder(graph_batch.x.float(), graph_batch.edge_index, graph_batch.edge_attr)
h1_mask_pool = self.pool_mean(h1[mask_nodes], graph_batch.batch[mask_nodes])
h2_mask_pool = self.pool_mean(h2[mask_nodes].detach(), graph_batch.batch[mask_nodes])
loss_2 = self.criterion(h1_mask_pool, h2_mask_pool)
return loss_2
def get_embed(self, graph_batch, device):
with torch.no_grad():
graph_batch = graph_batch.to(device)
h1 = self.online_encoder(graph_batch.x, graph_batch.edge_index, graph_batch.edge_attr)
h = self.pool_add(h1, graph_batch.batch)
return h.detach()
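# Linear probe: a single linear layer trained on frozen graph embeddings for downstream prediction.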
class LogReg(torch.nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.graph_pred_linear = torch.nn.Linear(in_dim, out_dim)
def forward(self, h):
z = self.graph_pred_linear(h)
return z
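# Compute the mean per-task ROC-AUC over a loader. Labels are encoded as {-1, 0, +1}, where 0
# means "missing"; only valid entries are scored, with {-1, +1} mapped to {0, 1}. Tasks that
# lack either positive or negative examples are skipped.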
def computer_roc(log, model, loader, device):
y_true = []
y_scores = []
log.eval()
for step, batch in enumerate(loader):
with torch.no_grad():
z = model.get_embed(batch, device)
pred = log(z)
y_true.append(batch.y.view(pred.shape))
y_scores.append(pred)
y_true = torch.cat(y_true, dim=0).cpu().detach().numpy()
y_scores = torch.cat(y_scores, dim=0).cpu().detach().numpy()
roc_list = []
for i in range(y_true.shape[1]):
if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == -1) > 0:
is_valid = y_true[:, i] ** 2 > 0
roc_list.append(roc_auc_score((y_true[is_valid, i] + 1) / 2, y_scores[is_valid, i]))
if len(roc_list) < y_true.shape[1]:
print("some target is missing")
print("missing ratio: %f" % (1 - float(len(roc_list)) / y_true.shape[1]))
roc = sum(roc_list) / len(roc_list)
return roc
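# Linear-probe evaluation: train a logistic-regression head on frozen embeddings with a masked
# BCE-with-logits loss (entries with y == 0 are ignored), track train/val/test ROC-AUC each
# epoch, and record (val_roc, test_roc, epoch) whenever the validation ROC matches or exceeds
# the best seen so far.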
def evaluation(model, writer, loader, train_loader, val_loader, test_loader, num_tasks, device,
num_test_epoch=100):
xent = nn.BCEWithLogitsLoss(reduction="none")
log = LogReg(300, num_tasks).to(device)
opt = torch.optim.Adam(log.parameters(), lr=0.01, weight_decay=0.0)
loss_list, train_roc, val_roc, test_roc = [], [], [], []
pdar_2 = tqdm(range(num_test_epoch))
best_val_roc = 0
test_best_roc = []
model.eval()
for epoch in pdar_2:
total_loss = []
for step, batch in enumerate(loader):
with torch.no_grad():
z = model.get_embed(batch, device)
# train
log.train()
pred = log(z)
y = batch.y.view(pred.shape).to(torch.float64)
is_valid = y ** 2 > 0
loss_mat = xent(pred.double(), (y + 1) / 2)
loss_mat = torch.where(is_valid, loss_mat,
torch.zeros(loss_mat.shape).to(loss_mat.device).to(loss_mat.dtype))
opt.zero_grad()
loss = torch.sum(loss_mat) / torch.sum(is_valid)
total_loss.append(loss.item())
loss.backward()
opt.step()
avg_loss = sum(total_loss) / len(total_loss)
logging.info("Epoch: {}, Loss: {}".format(epoch, round(avg_loss, 3)))
# roc_of_train
roc_train = computer_roc(log, model, train_loader, device)
# roc_of_val
roc_val = computer_roc(log, model, val_loader, device)
# roc_of_test
roc_test = computer_roc(log, model, test_loader, device)
best_test = 0
if roc_val >= best_val_roc:
if roc_val == best_val_roc and best_test < roc_test:
best_test = roc_test
test_best_roc.append([best_val_roc, roc_test, epoch])
continue
best_test = roc_test
best_val_roc = roc_val
test_best_roc.append([best_val_roc, roc_test, epoch])
writer.add_scalar("train_loss", avg_loss, epoch)
writer.add_scalar("train_roc", roc_train, epoch)
writer.add_scalar("val_roc", roc_val, epoch)
writer.add_scalar("test_roc", roc_test, epoch)
loss_list.append(avg_loss)
train_roc.append(roc_train)
val_roc.append(roc_val)
test_roc.append(roc_test)
pdar_2.set_description(
f"Epoch: {epoch}, Loss: {round(avg_loss, 3)}, Train_roc: {round(roc_train, 3)},Val_roc: {round(roc_val, 3)},Test_roc: {round(roc_test, 3)}")
return loss_list, train_roc, val_roc, test_roc, test_best_roc
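# Self-supervised training loop: cosine schedules for the learning rate and the target-network
# momentum, Adam on the online encoder, an EMA update of the target encoder after every
# optimizer step, checkpointing of the lowest-loss model, followed by linear-probe evaluation
# whose per-epoch results are written to text files.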
def train(args, model, loader, train_loader, val_loader, test_loader, device):
dataname = args.dataset
save_name = args.save_name
num_epochs = args.epochs
num_tasks = args.num_tasks
num_test_epoch = args.test_epochs
writer = SummaryWriter(log_dir=f"./logs/{dataname}/{args.out_file}/{save_name}")
lr = args.lr
mm_l = args.mm
lr_scheduler = CosineDecayScheduler(lr, 100, num_epochs)
mm_scheduler = CosineDecayScheduler(1 - mm_l, 0, num_epochs)
optimizer = Adam(model.trainable_parameters(), lr=lr, weight_decay=1e-4)
print("开始进行训练......")
best_loss = float('inf')
pdar = tqdm(range(1, num_epochs + 1))
for epoch in pdar:
model.train()
# update learning rate
lr = lr_scheduler.get(epoch - 1)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
mm = 1 - mm_scheduler.get(epoch - 1)
losses = 0
for gs in loader:
loss = model(gs, device)
optimizer.zero_grad()
loss.backward()
optimizer.step()
model.update_target_network(mm)
losses += loss.item()
avg_loss = losses / len(loader)
if avg_loss < best_loss:
best_loss = avg_loss
torch.save(model.state_dict(), f"./result/{dataname}/{args.out_file}/{save_name}/{dataname}_best_model.pt")
        pdar.set_description(f"Epoch: {epoch}, Loss: {round(avg_loss, 5)}")
        writer.add_scalar("ContrastiveTrain_Loss", round(avg_loss, 5), epoch)
loss_list, train_roc, val_roc, test_roc, test_best_roc = evaluation(model, writer, loader, train_loader, val_loader,
test_loader, num_tasks, device,
num_test_epoch)
    # Open the output file
with open(f"./result/{dataname}/{args.out_file}/{save_name}/result_of_train.txt", 'w') as f:
        # Write the data
for i in range(len(loss_list)):
f.write(f'{loss_list[i]:.4f},{train_roc[i]:.4f},{val_roc[i]:.4f},{test_roc[i]:.4f}\n')
with open(f"./result/{dataname}/{args.out_file}/{save_name}/result_of_roc_test.txt", 'w') as f:
# 写入数据
for i in range(len(test_best_roc)):
f.write("val_roc,test_roc,epoch\n")
val_roc, test_roc, epoch = test_best_roc[i][0], test_best_roc[i][1], test_best_roc[i][2],
f.write(f'{round(val_roc, 4)},{round(test_roc, 4)},{epoch}\n')
return test_best_roc[-1]
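# Entry point: parse arguments, pick the number of tasks for the chosen dataset, build scaffold
# or random splits, then run one fine-tuning/evaluation pass per seed (loading a pre-trained
# checkpoint unless the dataset is mutag or ptc_mr) and aggregate the test ROC across seeds.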
def main():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks')
    ## necessary args
parser.add_argument('--dataset', type=str, default='bbbp')
parser.add_argument('--input_model_file', type=str,
default='./result/pre_training/Method_3_0.25dropout/zinc_standard_agent_best_model.pth',
help='filename to read the model (if there is any)')
parser.add_argument("--seeds", type=int, nargs="+", default=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
parser.add_argument('--save_name', type=str, default='exp1')
    parser.add_argument('--batch_size', type=int, default=256, help='input batch size for training (default: 256)')
parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train (default: 100)')
parser.add_argument('--pre_dim', type=int, default=2)
parser.add_argument('--test_epochs', type=int, default=100, help='number of epochs to train (default: 100)')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate (default: 0.001)')
    parser.add_argument('--num_layer', type=int, default=3, help='number of GNN message passing layers (default: 3)')
parser.add_argument('--emb_dim', type=int, default=300, help='embedding dimensions (default: 300)')
    parser.add_argument('--dropout_ratio', type=float, default=0.25, help='dropout ratio (default: 0.25)')
    parser.add_argument('--mask_ratio', type=float, default=0.0, help='mask ratio (default: 0.0)')
parser.add_argument('--JK', type=str, default="last")
parser.add_argument('--split', type=str, default="scaffold", help="random or scaffold or random_scaffold")
parser.add_argument('--num_workers', type=int, default=4, help='number of workers for dataset loading')
    # By default, 1 means no special handling of functional groups; 2 keeps long functional groups and the deduplicated short ones; 3 keeps short functional groups and does not mask long ones
parser.add_argument('--hander_mt', type=int, default=3, help='1,2,3')
    parser.add_argument('--mm', type=float, default=0.999, help='momentum for the target-network EMA update (e.g. 0.99 or 0.999)')
    parser.add_argument('--pre_mask', type=bool, default=True)
args = parser.parse_args()
print(args)
args_Dict = args.__dict__
seeds = args.seeds
dataname = args.dataset
split = args.split
batch_size = args.batch_size
num_workers = args.num_workers
if dataname == "tox21":
num_tasks = 12
elif dataname == "hiv":
num_tasks = 1
elif dataname == "muv":
num_tasks = 17
elif dataname == "bace":
num_tasks = 1
elif dataname == "bbbp":
num_tasks = 1
elif dataname == "toxcast":
num_tasks = 617
elif dataname == "sider":
num_tasks = 27
elif dataname == "clintox":
num_tasks = 2
elif dataname == "ptc_mr":
num_tasks = 1
elif dataname == "mutag":
num_tasks = 1
else:
raise ValueError("Invalid dataset name.")
device = torch.device("cuda:0")
dataset = MoleculeDataset2("./dataset/" + dataname, dataset=dataname, mask_ratio=args.mask_ratio,
hander_mt=args.hander_mt)
if split == "scaffold":
smiles_list = pd.read_csv('./dataset/' + dataname + '/processed/smiles.csv', header=None)[
0].tolist()
train_dataset, valid_dataset, test_dataset = scaffold_split(dataset, smiles_list, null_value=0,
frac_train=0.8,
frac_valid=0.1, frac_test=0.1)
print("数据集划分方法为: scaffold")
elif split == "random":
smiles_list = pd.read_csv('./dataset/' + dataname + '/processed/smiles.csv', header=None)[
0].tolist()
train_dataset, valid_dataset, test_dataset, (train_smiles, valid_smiles,
test_smiles) = random_split(dataset, task_idx=None, null_value=0,
frac_train=0.8, frac_valid=0.1,
frac_test=0.1,
seed=0,
smiles_list=smiles_list)
print("数据集划分方法为: scaffold")
else:
raise ValueError("Invalid split option.")
loader = DataLoaderMasking2(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
train_loader = DataLoaderMasking2(train_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
val_loader = DataLoaderMasking2(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoaderMasking2(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
args.num_tasks = num_tasks
roc_list = []
test_roc_list = []
j = 1
while True:
out_file = 'test'
out_file = out_file + str(j)
if not os.path.exists(f"./result/{dataname}/{out_file}"):
os.makedirs(f"./result/{dataname}/{out_file}")
args.out_file = out_file
break
j = j + 1
with open(f"./result/{dataname}/{out_file}/args.txt", 'w', encoding='utf-8') as f:
for eachArg, value in args_Dict.items():
f.writelines(eachArg + ' : ' + str(value) + '\n')
for i, seed in enumerate(seeds):
folder_name = "exp"
j = 1
while True:
new_folder_name = folder_name + str(j)
if not os.path.exists(f"./result/{dataname}/{out_file}/{new_folder_name}"):
os.makedirs(f"./result/{dataname}/{out_file}/{new_folder_name}")
args.save_name = new_folder_name
break
j += 1
print(f"####### Run {i} for seed {seed}")
set_random_seed(seed)
model = CG(args.num_layer, args.pre_dim, args.emb_dim, "last", args.dropout_ratio, mask_rate=args.mask_ratio,
random_mask=True).to(device)
if dataname not in ['mutag', 'ptc_mr']:
model.load_state_dict(torch.load(args.input_model_file))
test_roc = train(args, model, loader, train_loader, val_loader, test_loader, device)
roc_list.append(test_roc)
test_roc_list.append(test_roc[1])
final_roc, final_roc_std = np.mean(test_roc_list), np.std(test_roc_list)
print(final_roc, final_roc_std)
with open(f"./result/{dataname}/{args.out_file}/{args.save_name}/result_of_train.txt", "w") as f:
for i in roc_list:
val_roc, test_roc, epoch = i[0], i[1], i[2]
f.write(f'{round(val_roc, 4)},{round(test_roc, 4)},{epoch}\n')
f.write("# final_acc:" + str(final_roc) + "±" + str(final_roc_std) + "\n")
if __name__ == '__main__':
main()
# dataname = 'sider'
# dataset = MoleculeDataset2("./dataset/" + dataname, dataset=dataname, mask_ratio=0.24,hander_mt=1)
# print(dataset[0])
# loader = DataLoaderMasking2(dataset, batch_size=256, shuffle=False, num_workers=0)
# z = 0
# for i in loader:
# z += len(i.mask_nodes) / i.x.size(0)
# print(len(i.mask_nodes) / i.x.size(0))
# print(z / len(loader), 'pj')