import argparse
import copy
import logging
import os
import random
from functools import partial
# sys.path.append('./')
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import ParameterGrid
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.data import Batch
from torch_geometric.loader import DataListLoader, DataLoader
from torch_geometric.nn import GINConv, global_mean_pool, global_add_pool
from tqdm import tqdm
from chem.data_process import process_idx, set_random_seed
from chem.dataloader import DataLoaderMasking1, DataLoaderMasking2
from chem.loader import MoleculeDataset1
from loader_we import MoleculeDataset2
from model import GNN
from splitters import scaffold_split
# torch.manual_seed(0)
# np.random.seed(0)
#
# if torch.cuda.is_available():
# torch.cuda.manual_seed_all(0)
# device = torch.device('cuda:0')
class CosineDecayScheduler:
def __init__(self, max_val, warmup_steps, total_steps):
self.max_val = max_val
self.warmup_steps = warmup_steps
self.total_steps = total_steps
def get(self, step):
if step < self.warmup_steps:
return self.max_val * step / self.warmup_steps
elif self.warmup_steps <= step <= self.total_steps:
return self.max_val * (1 + np.cos((step - self.warmup_steps) * np.pi /
(self.total_steps - self.warmup_steps))) / 2
else:
raise ValueError('Step ({}) > total number of steps ({}).'.format(step, self.total_steps))
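
# A minimal sanity check for the scheduler above (illustrative only, never
# called by the training code): with max_val=1.0, warmup_steps=10 and
# total_steps=100, the value warms up linearly to the peak and then follows a
# half-cosine down to zero.
def _demo_cosine_decay_scheduler():
    sched = CosineDecayScheduler(max_val=1.0, warmup_steps=10, total_steps=100)
    assert sched.get(0) == 0.0           # start of the linear warmup
    assert sched.get(10) == 1.0          # peak, right after warmup
    assert abs(sched.get(100)) < 1e-9    # fully decayed at the last step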
def sce_loss(x, y, alpha=1):
x = F.normalize(x, p=2, dim=-1)
y = F.normalize(y, p=2, dim=-1)
loss = (1 - (x * y).sum(dim=-1)).pow_(alpha)
loss = loss.mean()
return loss
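
# Quick illustration of the scaled cosine error above (illustrative only, not
# called anywhere): after normalisation, identical directions give a loss of 0
# and orthogonal directions give (1 - 0) ** alpha = 1.
def _demo_sce_loss():
    a = torch.tensor([[1.0, 0.0]])
    b = torch.tensor([[0.0, 1.0]])
    assert torch.isclose(sce_loss(a, a), torch.tensor(0.0), atol=1e-6)
    assert torch.isclose(sce_loss(a, b), torch.tensor(1.0), atol=1e-6)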
def mask(batch, mask_rate):
    """Clique-level masking: sample `mask_rate` of the cliques (functional
    groups) and return the indices of every atom inside the sampled cliques.
    Assumes `batch.clique_batch` is a sequence of per-clique atom-index tensors."""
    clique_idx = batch.clique_batch
    num_cliques = len(clique_idx)
    perm = torch.randperm(num_cliques, device=batch.x.device)
    num_mask_cliques = max(1, int(mask_rate * num_cliques))  # always mask at least one clique
    # Indexing with plain ints avoids the itemgetter edge cases when only a
    # single clique is sampled (itemgetter then returns an item, not a tuple).
    mask_cliques = [clique_idx[i] for i in perm[:num_mask_cliques].tolist()]
    t_mask_nodes = torch.cat([t.view(-1) for t in mask_cliques])
    return t_mask_nodes
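
# Illustration of clique-level masking on a hypothetical batch (not real
# data): three 2-atom cliques with a mask rate of about one third, so the
# atoms of exactly one clique are returned.
def _demo_mask():
    class _FakeBatch:
        x = torch.zeros(6, 2)
        clique_batch = [torch.tensor([0, 1]), torch.tensor([2, 3]), torch.tensor([4, 5])]
    masked = mask(_FakeBatch(), mask_rate=0.34)
    assert masked.numel() == 2  # all atoms of the single sampled clique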
class CG(nn.Module):
    # mask_rate: fraction of cliques to mask per graph when the dataset does
    # not provide precomputed mask nodes (pre_mask=False).
def __init__(self, num_layer, pre_dim, emb_dim, JK, drop_ratio, mask_rate, pre_mask=False):
super(CG, self).__init__()
self.graphs = []
self.drop_rate = drop_ratio
self.online_encoder = GNN(num_layer, pre_dim=pre_dim, emb_dim=emb_dim, JK=JK, drop_ratio=self.drop_rate,
pre=True)
self.target_encoder = copy.deepcopy(self.online_encoder)
self.target_encoder.reset_parameters()
self.mask_rate = mask_rate
self.enc_mask_token = nn.Parameter(torch.zeros(1, pre_dim))
self.criterion = self.setup_loss_fn("sce", 1)
self.pool = global_add_pool
self.pool_mean = global_mean_pool
self.pre_mask = pre_mask
# stop gradient
for param in self.target_encoder.parameters():
param.requires_grad = False
def setup_loss_fn(self, loss_fn, alpha_l):
if loss_fn == "mse":
criterion = nn.MSELoss()
elif loss_fn == "sce":
criterion = partial(sce_loss, alpha=alpha_l)
else:
raise NotImplementedError
return criterion
def trainable_parameters(self):
r"""Returns the parameters that will be updated via an optimizer."""
return list(self.online_encoder.parameters())
@torch.no_grad()
def update_target_network(self, mm):
r"""Performs a momentum update of the target network's weights.
Args:
mm (float): Momentum used in moving average update.
"""
for param_q, param_k in zip(self.online_encoder.parameters(), self.target_encoder.parameters()):
param_k.data.mul_(mm).add_(param_q.data, alpha=1. - mm)
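        # Illustrative numbers: with mm = 0.999 every call moves each target
        # weight 0.1% toward its online counterpart,
        #     param_k <- 0.999 * param_k + 0.001 * param_q,
        # so the target network is a slowly trailing average of the online one.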
def forward(self, graph_batch, device):
graph_batch = graph_batch.to(device)
x = graph_batch.x.clone().float().to(device)
if self.pre_mask:
mask_nodes = graph_batch.mask_nodes.long()
else:
mask_nodes = mask(graph_batch, self.mask_rate)
x[mask_nodes] = 0.0
x[mask_nodes] += self.enc_mask_token
h1 = self.online_encoder(x, graph_batch.edge_index, graph_batch.edge_attr)
with torch.no_grad():
h2 = self.target_encoder(graph_batch.x.float(), graph_batch.edge_index, graph_batch.edge_attr)
# g1_pool = self.pool_mean(h1, graph_batch.batch)
# g2_pool = self.pool_mean(h2, graph_batch.batch)
# loss_1 = self.criterion(g1_pool, g2_pool)
h1_mask_pool = self.pool(h1[mask_nodes], graph_batch.batch[mask_nodes])
h2_mask_pool = self.pool(h2[mask_nodes].detach(), graph_batch.batch[mask_nodes])
loss_2 = self.criterion(h1_mask_pool, h2_mask_pool)
return loss_2
def get_embed(self, graph_batch, device):
with torch.no_grad():
graph_batch = graph_batch.to(device)
h1 = self.online_encoder(graph_batch.x, graph_batch.edge_index, graph_batch.edge_attr)
h = self.pool(h1, graph_batch.batch)
return h.detach()
class LogReg(torch.nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.graph_pred_linear = torch.nn.Linear(in_dim, out_dim)
def forward(self, h):
z = self.graph_pred_linear(h)
# z = torch.sigmoid(z)
return z
def compute_roc(log, model, loader, device):
y_true = []
y_scores = []
log.eval()
for step, batch in enumerate(loader):
z = model.get_embed(batch, device)
with torch.no_grad():
pred = log(z)
y_true.append(batch.y.view(pred.shape))
y_scores.append(pred)
y_true = torch.cat(y_true, dim=0).cpu().detach().numpy()
y_scores = torch.cat(y_scores, dim=0).cpu().detach().numpy()
roc_list = []
for i in range(y_true.shape[1]):
# AUC is only defined when there is at least one positive data.
if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == -1) > 0:
is_valid = y_true[:, i] ** 2 > 0
roc_list.append(roc_auc_score((y_true[is_valid, i] + 1) / 2, y_scores[is_valid, i]))
    if len(roc_list) < y_true.shape[1]:
        logging.warning("Some targets have no positive or negative labels and are skipped.")
    roc = sum(roc_list) / len(roc_list)
return roc
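
# Label-encoding convention used throughout this file: y takes values in
# {-1, 0, 1}, where -1 = negative, 1 = positive and 0 = missing. Hence
# `y ** 2 > 0` selects the labelled entries and `(y + 1) / 2` maps {-1, 1}
# onto the {0, 1} targets expected by roc_auc_score and BCEWithLogitsLoss.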
def evaluation(model, writer, loader, train_loader, val_loader, test_loader, num_tasks, device,
num_test_epoch=100):
xent = nn.BCEWithLogitsLoss(reduction="none")
    log = LogReg(300, num_tasks).to(device)  # 300 matches the encoder's default emb_dim
opt = torch.optim.Adam(log.parameters(), lr=0.01, weight_decay=0.0)
loss_list, train_roc, val_roc, test_roc = [], [], [], []
model.eval()
pdar_2 = tqdm(range(num_test_epoch))
    best_val_roc = 0
    best_test_roc = 0
    test_best_roc = []
for epoch in pdar_2:
total_loss = []
for step, batch in enumerate(loader):
z = model.get_embed(batch, device)
# train
log.train()
pred = log(z)
y = batch.y.view(pred.shape).to(torch.float64)
is_valid = y ** 2 > 0
loss_mat = xent(pred.double(), (y + 1) / 2)
loss_mat = torch.where(is_valid, loss_mat,
torch.zeros(loss_mat.shape).to(loss_mat.device).to(loss_mat.dtype))
opt.zero_grad()
loss = torch.sum(loss_mat) / torch.sum(is_valid)
total_loss.append(loss.item())
loss.backward()
opt.step()
avg_loss = sum(total_loss) / len(total_loss)
logging.info("Epoch: {}, Loss: {}".format(epoch, round(avg_loss, 3)))
        # ROC on the train / validation / test splits
        roc_train = compute_roc(log, model, train_loader, device)
        roc_val = compute_roc(log, model, val_loader, device)
        roc_test = compute_roc(log, model, test_loader, device)
        # Record a checkpoint whenever validation ROC improves (ties broken by
        # the higher test ROC). The previous version reset the best test ROC
        # every epoch and skipped logging on ties via `continue`.
        if roc_val > best_val_roc or (roc_val == best_val_roc and roc_test > best_test_roc):
            best_val_roc = roc_val
            best_test_roc = roc_test
            test_best_roc.append([best_val_roc, roc_test, epoch])
writer.add_scalar("train_loss", avg_loss, epoch)
writer.add_scalar("train_roc", roc_train, epoch)
writer.add_scalar("val_roc", roc_val, epoch)
writer.add_scalar("test_roc", roc_test, epoch)
loss_list.append(avg_loss)
train_roc.append(roc_train)
val_roc.append(roc_val)
test_roc.append(roc_test)
        pdar_2.set_description(
            f"Epoch: {epoch}, Loss: {round(avg_loss, 3)}, Train_roc: {round(roc_train, 3)}, "
            f"Val_roc: {round(roc_val, 3)}, Test_roc: {round(roc_test, 3)}")
return loss_list, train_roc, val_roc, test_roc, test_best_roc
def train(args, model, loader, train_loader, val_loader, test_loader, device):
dataname = args.dataset
save_name = args.save_name
num_epochs = args.epochs
num_tasks = args.num_tasks
num_test_epoch = args.test_epochs
writer = SummaryWriter(log_dir=f"./logs/{dataname}/{args.out_file}/{save_name}")
lr = args.lr
mm_l = args.mm
lr_scheduler = CosineDecayScheduler(lr, 100, num_epochs)
mm_scheduler = CosineDecayScheduler(1 - mm_l, 0, num_epochs)
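    # Rough behaviour of the two schedules (for the defaults lr=0.001,
    # mm=0.999): the learning rate ramps up linearly over the first 100 steps
    # and then follows a half-cosine decay, while the EMA momentum used below,
    # mm = 1 - mm_scheduler.get(step), rises from 0.999 at the first epoch
    # toward 1.0 at the last, gradually freezing the target network.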
optimizer = Adam(model.trainable_parameters(), lr=lr, weight_decay=1e-4)
print("开始进行训练......")
best_loss = float('inf')
pdar = tqdm(range(1, num_epochs + 1))
for epoch in pdar:
model.train()
# update learning rate
lr = lr_scheduler.get(epoch - 1)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
mm = 1 - mm_scheduler.get(epoch - 1)
losses = 0
for gs in loader:
loss = model(gs, device)
optimizer.zero_grad()
loss.backward()
optimizer.step()
model.update_target_network(mm)
losses += loss.item()
avg_loss = losses / len(loader)
if avg_loss < best_loss:
best_loss = avg_loss
torch.save(model.state_dict(), f"./result/{dataname}/{args.out_file}/{save_name}/{dataname}_best_model.pt")
        # Use avg_loss (losses / len(loader)); the loop iterates over `loader`,
        # so dividing by len(train_loader) here was wrong.
        pdar.set_description(f"Epoch: {epoch}, Loss: {round(avg_loss, 5)}")
        writer.add_scalar("ContrastiveTrain_Loss", round(avg_loss, 5), epoch)
loss_list, train_roc, val_roc, test_roc, test_best_roc = evaluation(model, writer, loader, train_loader, val_loader,
test_loader, num_tasks, device,
num_test_epoch)
    # Write per-epoch metrics: loss, train/val/test ROC.
    with open(f"./result/{dataname}/{args.out_file}/{save_name}/result_of_train.txt", 'w') as f:
        for i in range(len(loss_list)):
            f.write(f'{loss_list[i]:.4f},{train_roc[i]:.4f},{val_roc[i]:.4f},{test_roc[i]:.4f}\n')
    # Write the header once, then one row per checkpoint where validation ROC
    # improved (the previous version rewrote the header inside the loop and
    # shadowed the val_roc / test_roc lists above).
    with open(f"./result/{dataname}/{args.out_file}/{save_name}/result_of_roc_test.txt", 'w') as f:
        f.write("val_roc,test_roc,epoch\n")
        for best_val, best_test, best_epoch in test_best_roc:
            f.write(f'{round(best_val, 4)},{round(best_test, 4)},{best_epoch}\n')
return test_best_roc[-1]
def grid_search(args, device, loader, train_loader, val_loader, test_loader):
"""
网格搜索函数
:param args: 参数对象
:param device: 计算设备
:param train_loader: 训练集 DataLoader
:param val_loader: 验证集 DataLoader
:param test_loader: 测试集 DataLoader
:return: 最优参数和最优得分
"""
param_grid = {
'lr': [0.001, 0.01, 0.05, 0.005],
'dropout_ratio': [0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5],
# 'emb_dim': [128, 256, 512],
# 'num_layer': [3, 5, 7],
# 'pre_dim': [2, 4, 6],
"batch_size": [32, 64, 128, 256],
# "hander_mt": [1,2,3],
        'mask_ratio': [0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]
}
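    # Note: the grid above spans 4 * 9 * 4 * 9 = 1296 configurations, and each
    # one runs full pre-training plus a linear evaluation, so trim the lists
    # above for a quick pass.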
best_score = -np.inf
best_params = None
for params in ParameterGrid(param_grid):
set_random_seed(0)
for k, v in params.items():
setattr(args, k, v)
model = CG(args.num_layer, args.pre_dim, args.emb_dim, args.JK, args.dropout_ratio, args.mask_ratio,
pre_mask=args.pre_mask).to(device)
model.load_state_dict(torch.load(args.input_model_file))
test_roc = train(args, model, loader, train_loader, val_loader, test_loader, device)
if test_roc[1] > best_score:
best_score = test_roc[1]
best_params = params
print(f"params: {params}, test_roc: {test_roc[1]}")
print(f"best params: {best_params}, best score: {best_score}")
return best_params, best_score
def main():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks')
    # necessary args
parser.add_argument('--dataset', type=str, default='clintox')
parser.add_argument('--input_model_file', type=str,
default='./result/pre_training/Method2/zinc_standard_agent_best_model.pth',
help='filename to read the model (if there is any)')
parser.add_argument("--seeds", type=int, nargs="+", default=[0])
parser.add_argument('--save_name', type=str, default='exp1')
    parser.add_argument('--batch_size', type=int, default=128, help='input batch size for training (default: 128)')
parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train (default: 100)')
parser.add_argument('--pre_dim', type=int, default=2)
    parser.add_argument('--test_epochs', type=int, default=100, help='number of linear-evaluation epochs (default: 100)')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate (default: 0.001)')
    parser.add_argument('--num_layer', type=int, default=3, help='number of GNN message passing layers (default: 3)')
parser.add_argument('--emb_dim', type=int, default=300, help='embedding dimensions (default: 300)')
    parser.add_argument('--dropout_ratio', type=float, default=0.20, help='dropout ratio (default: 0.20)')
    parser.add_argument('--mask_ratio', type=float, default=0.37, help='mask ratio (default: 0.37)')
parser.add_argument('--JK', type=str, default="last")
parser.add_argument('--split', type=str, default="scaffold", help="random or scaffold or random_scaffold")
parser.add_argument('--num_workers', type=int, default=8, help='number of workers for dataset loading')
    # hander_mt: 1 = no special handling of functional groups; 2 = keep long
    # functional groups plus the deduplicated short ones; 3 = keep short
    # functional groups and do not mask the long ones.
    parser.add_argument('--hander_mt', type=int, default=1, help='functional-group handling mode: 1, 2 or 3')
    # NOTE: argparse's type=bool treats any non-empty string (even "False") as
    # True; change the default here rather than passing a value on the CLI.
    parser.add_argument('--pre_mask', type=bool, default=True, help='use precomputed mask nodes from the dataset')
args = parser.parse_args()
args_Dict = args.__dict__
seeds = args.seeds
print(args)
dataname = args.dataset
split = args.split
batch_size = args.batch_size
num_workers = args.num_workers
if dataname == "tox21":
num_tasks = 12
elif dataname == "hiv":
num_tasks = 1
elif dataname == "muv":
num_tasks = 17
elif dataname == "bace":
num_tasks = 1
elif dataname == "bbbp":
num_tasks = 1
elif dataname == "toxcast":
num_tasks = 617
elif dataname == "sider":
num_tasks = 27
elif dataname == "clintox":
num_tasks = 2
else:
raise ValueError("Invalid dataset name.")
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if args.pre_mask:
dataset = MoleculeDataset2("./dataset/" + dataname, dataset=dataname, mask_ratio=args.mask_ratio,
hander_mt=args.hander_mt)
if split == "scaffold":
            # Keep the smiles path consistent with the "./dataset" root used above.
            smiles_list = pd.read_csv('./dataset/' + dataname + '/processed/smiles.csv', header=None)[0].tolist()
train_dataset, valid_dataset, test_dataset = scaffold_split(dataset, smiles_list, null_value=0,
frac_train=0.8,
frac_valid=0.1, frac_test=0.1)
print("数据集划分方法为: scaffold")
else:
raise ValueError("Invalid split option.")
loader = DataLoaderMasking2(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
train_loader = DataLoaderMasking2(train_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
val_loader = DataLoaderMasking2(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoaderMasking2(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
else:
dataset = MoleculeDataset1("../dataset/" + dataname, dataset=dataname)
if split == "scaffold":
smiles_list = pd.read_csv('../dataset/' + dataname + '/processed/smiles.csv', header=None)[
0].tolist()
train_dataset, valid_dataset, test_dataset = scaffold_split(dataset, smiles_list, null_value=0,
frac_train=0.8,
frac_valid=0.1, frac_test=0.1)
print("数据集划分方法为: scaffold")
else:
raise ValueError("Invalid split option.")
loader = DataLoaderMasking1(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
train_loader = DataLoaderMasking1(train_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
val_loader = DataLoaderMasking1(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoaderMasking1(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
args.num_tasks = num_tasks
args.mm = 0.999
roc_list = []
test_roc_list = []
    # Find the first unused "testN" output folder.
    j = 1
    while True:
        out_file = 'test' + str(j)
        if not os.path.exists(f"./result/{dataname}/{out_file}"):
            os.makedirs(f"./result/{dataname}/{out_file}")
            args.out_file = out_file
            break
        j = j + 1
with open(f"./result/{dataname}/{out_file}/args.txt", 'w', encoding='utf-8') as f:
for eachArg, value in args_Dict.items():
f.writelines(eachArg + ' : ' + str(value) + '\n')
for i, seed in enumerate(seeds):
folder_name = "exp"
j = 1
while True:
new_folder_name = folder_name + str(j)
if not os.path.exists(f"./result/{dataname}/{out_file}/{new_folder_name}"):
os.makedirs(f"./result/{dataname}/{out_file}/{new_folder_name}")
args.save_name = new_folder_name
break
j += 1
print(f"####### Run {i} for seed {seed}")
set_random_seed(seed)
#
grid_search(args, device, loader, train_loader, val_loader, test_loader)
# for k,v in best_params.items():
# setattr(args,k,v)
#
# model = CG(args.num_layer, args.pre_dim, args.emb_dim, "last", args.dropout_ratio, args.mask_ratio,
# pre_mask=args.pre_mask).to(device)
# model.load_state_dict(torch.load(args.input_model_file))
# test_roc = train(args, model, loader, train_loader, val_loader, test_loader, device)
# roc_list.append(test_roc)
# test_roc_list.append(test_roc[1])
# final_roc, final_roc_std = np.mean(test_roc_list), np.std(test_roc_list)
# print(final_roc, final_roc_std)
#
# with open(f"./result/{dataname}/{args.out_file}/{args.save_name}/result_of_train.txt", "w") as f:
# for i in roc_list:
# val_roc, test_roc, epoch = i[0], i[1], i[2]
# f.write(f'{round(val_roc, 4)},{round(test_roc, 4)},{epoch}\n')
# f.write("# final_acc:" + str(final_roc) + "±" + str(final_roc_std) + "\n")
if __name__ == '__main__':
main()
# dataname = 'sider'
# dataset = MoleculeDataset2("./dataset/" + dataname, dataset=dataname, mask_ratio=0.24,hander_mt=1)
# print(dataset[0])
# loader = DataLoaderMasking2(dataset, batch_size=256, shuffle=False, num_workers=0)
# z = 0
# for i in loader:
# z += len(i.mask_nodes) / i.x.size(0)
# print(len(i.mask_nodes) / i.x.size(0))
# print(z / len(loader), 'pj')