# Code pull complete; the page will refresh automatically.
import ast
import os
import pickle
from itertools import repeat, chain

import networkx as nx
import numpy as np
import pandas as pd
import torch
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Descriptors
from rdkit.Chem.rdMolDescriptors import GetMorganFingerprintAsBitVect
from torch.utils import data
from torch_geometric.data import Data
from torch_geometric.data import InMemoryDataset
def read_book(cli_path, edge_path):
    """
    Read the clique file and the clique-edge file of a dataset.

    Clique file: one clique per line, formatted as
    ``"<mol_id> <clique_id> /<literal>"`` where ``<literal>`` is a Python
    literal (for example a list of atom indices). Consecutive lines with the
    same ``<mol_id>`` are grouped into one molecule's clique list; a new
    group starts whenever the id differs from the previous line's id.
    Grouping starts from an implicit id of 0, so a file whose first id is
    not 0 yields an empty first group.

    Edge file: one entry per line, each a Python literal iterable of
    ``(x, y)`` pairs, converted to a list of ``[x, y]`` lists.

    :param cli_path: path of the clique file (UTF-8).
    :param edge_path: path of the edge file (UTF-8).
    :return: ``(cliques, all_edge)`` - the per-molecule clique lists and the
        per-line edge lists.
    :raises ValueError: if a line holds something other than a plain Python
        literal (``SyntaxError`` if it does not parse at all).
    """
    with open(cli_path, 'r', encoding='utf-8') as cli_file:
        cli_lines = cli_file.readlines()

    # ast.literal_eval (not eval) parses the literals without executing any
    # code that might be embedded in the data files.
    all_edge = []
    with open(edge_path, 'r', encoding='utf-8') as edge_file:
        for line in edge_file:
            pairs = ast.literal_eval(line.strip('\n').strip())
            all_edge.append([[x, y] for x, y in pairs])

    cliques = []
    clique = []
    mol_index = 0
    for line in cli_lines:
        fields = line.strip('\n').split(' /')
        mol_str, cli_str = fields[0].split(' ')
        mol, _cli_index = int(mol_str), int(cli_str)
        if mol != mol_index:
            # Molecule id changed: close the previous molecule's group.
            cliques.append(clique)
            clique = []
            mol_index = mol
        clique.append(ast.literal_eval(fields[1].strip()))
    cliques.append(clique)
    return cliques, all_edge
# Vocabularies for the simplified, index-encoded atom and bond features.
# A feature value is stored as its position inside the matching list.
allowable_features = {
    'possible_atomic_num_list': list(range(1, 119)),
    'possible_formal_charge_list': list(range(-5, 6)),
    'possible_chirality_list': [
        Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
        Chem.rdchem.ChiralType.CHI_OTHER,
    ],
    'possible_hybridization_list': [
        Chem.rdchem.HybridizationType.S,
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
        Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2,
        Chem.rdchem.HybridizationType.UNSPECIFIED,
    ],
    'possible_numH_list': list(range(9)),
    'possible_implicit_valence_list': list(range(7)),
    'possible_degree_list': list(range(11)),
    'possible_bonds': [
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC,
    ],
    # Bond directions carry double-bond stereo information only.
    'possible_bond_dirs': [
        Chem.rdchem.BondDir.NONE,
        Chem.rdchem.BondDir.ENDUPRIGHT,
        Chem.rdchem.BondDir.ENDDOWNRIGHT,
    ],
}
def mol_to_graph_data_obj_simple(mol):
    """
    Convert an RDKit mol object into a PyTorch Geometric ``Data`` graph.

    Node features are one row per atom produced by
    ``atom_to_feature_vector``. Each bond contributes two directed edges
    (i -> j and j -> i), both carrying
    ``[bond type index, bond direction index]`` as looked up in
    ``allowable_features``.

    :param mol: rdkit mol object
    :return: graph data object with the attributes: x, edge_index, edge_attr
    """
    # NOTE(review): atom_to_feature_vector is neither defined nor imported in
    # the visible part of this file - confirm where it comes from.
    atom_rows = [atom_to_feature_vector(atom) for atom in mol.GetAtoms()]
    x = torch.tensor(np.array(atom_rows), dtype=torch.long)

    num_bond_features = 2  # bond type, bond direction
    bonds = mol.GetBonds()
    if len(bonds) == 0:
        # No bonds: emit correctly shaped empty tensors.
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)
    else:
        bond_types = allowable_features['possible_bonds']
        bond_dirs = allowable_features['possible_bond_dirs']
        pairs = []
        features = []
        for bond in bonds:
            begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            feature = [bond_types.index(bond.GetBondType()),
                       bond_dirs.index(bond.GetBondDir())]
            pairs.extend([(begin, end), (end, begin)])
            features.extend([feature, feature])
        # COO connectivity [2, num_edges] and features [num_edges, 2]
        edge_index = torch.tensor(np.array(pairs).T, dtype=torch.long)
        edge_attr = torch.tensor(np.array(features), dtype=torch.long)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
def graph_data_obj_to_mol_simple(data_x, data_edge_index, data_edge_attr):
    """
    Rebuild an RDKit molecule from index-encoded graph tensors.

    Each row of ``data_x`` is unpacked as
    (atomic number index, chirality tag index). Every second column of
    ``data_edge_index`` (j = 0, 2, 4, ...) becomes one bond, whose type and
    direction indices come from row j of ``data_edge_attr``. All indices are
    decoded through ``allowable_features``.

    :param data_x: node feature tensor, two index columns per atom
    :param data_edge_index: edge index tensor, indexed as [row, edge]
    :param data_edge_attr: edge feature tensor, two index columns per edge
    :return: the (unsanitized) rdkit RWMol
    """
    atomic_nums = allowable_features['possible_atomic_num_list']
    chiral_tags = allowable_features['possible_chirality_list']
    bond_types = allowable_features['possible_bonds']
    bond_dirs = allowable_features['possible_bond_dirs']

    mol = Chem.RWMol()
    for atomic_num_idx, chirality_tag_idx in data_x.cpu().numpy():
        atom = Chem.Atom(atomic_nums[atomic_num_idx])
        atom.SetChiralTag(chiral_tags[chirality_tag_idx])
        mol.AddAtom(atom)

    edge_index = data_edge_index.cpu().numpy()
    edge_attr = data_edge_attr.cpu().numpy()
    for j in range(0, edge_index.shape[1], 2):
        begin, end = int(edge_index[0, j]), int(edge_index[1, j])
        bond_type_idx, bond_dir_idx = edge_attr[j]
        bond_type = bond_types[bond_type_idx]
        bond_dir = bond_dirs[bond_dir_idx]
        mol.AddBond(begin, end, bond_type)
        mol.GetBondBetweenAtoms(begin, end).SetBondDir(bond_dir)

    # Sanitization is deliberately skipped: Chem.SanitizeMol was observed to
    # fail for some molecules when aromatic bonds are present.
    return mol
def graph_data_obj_to_nx_simple(data):
    """
    Convert a PyTorch Geometric ``Data`` object into a networkx graph.

    Node i gets attributes ``atom_num_idx`` and ``chirality_tag_idx`` from
    row i of ``data.x``. Every second column of ``data.edge_index`` is added
    as an undirected edge with ``bond_type_idx`` / ``bond_dir_idx`` taken
    from the matching row of ``data.edge_attr``, unless that edge already
    exists. NB: relative stereochemistry may not survive, because networkx
    edges are unordered.

    :param data: pytorch geometric Data object
    :return: networkx Graph
    """
    graph = nx.Graph()
    node_rows = data.x.cpu().numpy()
    for node_id, (atomic_num_idx, chirality_tag_idx) in enumerate(node_rows):
        graph.add_node(node_id, atom_num_idx=atomic_num_idx,
                       chirality_tag_idx=chirality_tag_idx)

    edge_index = data.edge_index.cpu().numpy()
    edge_attr = data.edge_attr.cpu().numpy()
    for j in range(0, edge_index.shape[1], 2):
        begin, end = int(edge_index[0, j]), int(edge_index[1, j])
        bond_type_idx, bond_dir_idx = edge_attr[j]
        if graph.has_edge(begin, end):
            continue
        graph.add_edge(begin, end, bond_type_idx=bond_type_idx,
                       bond_dir_idx=bond_dir_idx)
    return graph
def nx_to_graph_data_obj_simple(G):
    """
    Convert a networkx graph back into a PyTorch Geometric ``Data`` object.

    Assumes node indices are numbered 0 .. num_nodes - 1; node rows are taken
    in ``G.nodes`` iteration order. Each undirected edge becomes two directed
    edges sharing ``[bond_type_idx, bond_dir_idx]``. NB: relative
    stereochemistry may not survive, because networkx edges are unordered.

    :param G: networkx graph whose nodes carry ``atom_num_idx`` /
        ``chirality_tag_idx`` and whose edges carry ``bond_type_idx`` /
        ``bond_dir_idx``
    :return: pytorch geometric Data object
    """
    num_bond_features = 2  # bond type, bond direction

    node_rows = [[attrs['atom_num_idx'], attrs['chirality_tag_idx']]
                 for _, attrs in G.nodes(data=True)]
    x = torch.tensor(np.array(node_rows), dtype=torch.long)

    if len(G.edges()) == 0:
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)
    else:
        pairs, features = [], []
        for i, j, attrs in G.edges(data=True):
            feature = [attrs['bond_type_idx'], attrs['bond_dir_idx']]
            pairs += [(i, j), (j, i)]
            features += [feature, feature]
        edge_index = torch.tensor(np.array(pairs).T, dtype=torch.long)
        edge_attr = torch.tensor(np.array(features), dtype=torch.long)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
def get_gasteiger_partial_charges(mol, n_iter=12):
    """
    Compute Gasteiger partial charges for every atom of ``mol``.

    The charges are computed in place on ``mol`` (stored as the
    ``_GasteigerCharge`` atom property) and then read back.

    :param mol: rdkit mol object
    :param n_iter: number of iterations. Default 12
    :return: list of floats, one charge per atom in atom order
    """
    Chem.rdPartialCharges.ComputeGasteigerCharges(
        mol, nIter=n_iter, throwOnParamFailure=True)
    charges = []
    for atom in mol.GetAtoms():
        charges.append(float(atom.GetProp('_GasteigerCharge')))
    return charges
def create_standardized_mol_id(smiles):
    """
    Build a stereochemistry-free InChI identifier for a SMILES string.

    :param smiles: SMILES string
    :return: InChI string, or None when ``check_smiles_validity`` rejects the
        input or RDKit cannot re-parse the stereo-stripped SMILES. If the
        SMILES contains several species ('.'), the InChI of the molecule
        picked by ``get_largest_mol`` is returned.
    """
    if not check_smiles_validity(smiles):
        return None
    # Round-trip through RDKit with isomericSmiles=False to drop stereo info.
    smiles = AllChem.MolToSmiles(AllChem.MolFromSmiles(smiles),
                                 isomericSmiles=False)
    mol = AllChem.MolFromSmiles(smiles)
    if mol is None:
        # Some stereo-stripped SMILES fail to re-parse (observed e.g. for a
        # molecule containing an '[al]' atom), so give up instead of crashing.
        return None
    if '.' in smiles:
        # Multiple species: identify the sample by its largest molecule.
        largest_mol = get_largest_mol(split_rdkit_mol_obj(mol))
        return AllChem.MolToInchi(largest_mol)
    return AllChem.MolToInchi(mol)
def read_clipue(dataset, root="./dataset"):
    """
    Load the clique and clique-edge files of ``dataset``.

    Reads ``<root>/<dataset>/processed/clique_dict.txt`` and
    ``<root>/<dataset>/processed/edge_dict.txt`` with ``read_book``.

    :param dataset: dataset name (a sub-directory of ``root``).
    :param root: base dataset directory; defaults to "./dataset", the
        location that was previously hard-coded.
    :return: ``(clique_dict, all_edges)`` as returned by ``read_book``.
    """
    processed_dir = os.path.join(root, dataset, "processed")
    return read_book(os.path.join(processed_dir, "clique_dict.txt"),
                     os.path.join(processed_dir, "edge_dict.txt"))
def pad_clique(clique, max_len):
    """
    Right-pad (or truncate) every member sequence of ``clique`` to ``max_len``.

    Missing positions are filled with -1; members longer than ``max_len``
    keep only their first ``max_len`` entries.

    :param clique: sequence of indexable sequences, one per clique
    :param max_len: target length of each output row
    :return: list of lists, one row per clique
    """
    padded = []
    for members in clique:
        size = len(members)
        padded.append([members[pos] if pos < size else -1
                       for pos in range(max_len)])
    return padded
def max_len(cliques):
    """
    Return the size of the largest clique across all molecules.

    :param cliques: list of per-molecule clique lists (as built by
        ``read_book``)
    :return: length of the longest clique, or 0 if there are none. A
        molecule with no cliques (e.g. the empty leading group ``read_book``
        produces when ids do not start at 0) is simply skipped rather than
        making ``max()`` raise ``ValueError``.
    """
    return max((len(members)
                for mol_cliques in cliques
                for members in mol_cliques),
               default=0)
def get_cli_length(clique):
    """Return the size of each clique in ``clique``, in order."""
    return list(map(len, clique))
class MoleculeDataset1(InMemoryDataset):
    """PyG ``InMemoryDataset`` of molecule graphs built by ``process``.

    For zinc_standard_agent, tox21, hiv, bace, bbbp, clintox, muv, sider and
    toxcast every graph additionally carries the per-molecule clique tensors
    ``clique_slice``, ``clique_idx`` and ``clique_edge``
    (see ``_attach_cliques``).
    """

    def __init__(self,
                 root,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None,
                 dataset='zinc250k',
                 empty=False):
        """
        Adapted from qm9.py. Disabled the download functionality.
        :param root: directory of the dataset, containing a raw and processed
        dir. The raw dir should contain the file containing the smiles, and the
        processed dir can either be empty or contain a previously processed
        file
        :param transform: forwarded to ``InMemoryDataset``
        :param pre_transform: forwarded to ``InMemoryDataset``; applied to
        every graph at the end of ``process``
        :param pre_filter: forwarded to ``InMemoryDataset``; graphs for which
        it returns a falsy value are dropped at the end of ``process``
        :param dataset: name of the dataset. ``process`` implements
        zinc_standard_agent, zinc_sample, chembl_filtered, tox21, hiv, bace,
        bbbp, clintox, esol, freesolv, lipophilicity, muv, pcba,
        pcba_pretrain, sider, toxcast, ptc_mr and mutag; any other name
        (including the default 'zinc250k') raises ValueError there
        :param empty: if True, then will not load any data obj. For
        initializing empty dataset
        """
        # Set before super().__init__ - presumably InMemoryDataset's
        # constructor runs process() (which reads self.dataset) when the
        # processed file is missing.
        self.dataset = dataset
        self.root = root
        super(MoleculeDataset1, self).__init__(root, transform, pre_transform,
                                               pre_filter)
        self.transform, self.pre_transform, self.pre_filter = \
            transform, pre_transform, pre_filter
        if not empty:
            self.data, self.slices = torch.load(self.processed_paths[0])

    def get(self, idx):
        """Slice graph ``idx`` back out of the collated ``self.data``."""
        data = Data()
        keys = self.data.keys
        # ``keys`` is a property in older PyG releases and a method in newer
        # ones; accept both.
        if callable(keys):
            keys = keys()
        for key in keys:
            item, slices = self.data[key], self.slices[key]
            s = list(repeat(slice(None), item.dim()))
            s[data.__cat_dim__(key, item)] = slice(slices[idx],
                                                   slices[idx + 1])
            # Index with a tuple: indexing with a list of slices is a
            # legacy form.
            data[key] = item[tuple(s)]
        return data

    @property
    def raw_file_names(self):
        """All files currently present in ``raw_dir``."""
        return os.listdir(self.raw_dir)

    @property
    def processed_file_names(self):
        return 'geometric_data_processed.pt'

    def download(self):
        raise NotImplementedError('Must indicate valid location of raw data. '
                                  'No download allowed')

    @staticmethod
    def _attach_cliques(data, mol_cliques, mol_edges, max_length):
        """Attach one molecule's clique tensors to ``data``.

        Sets ``clique_slice`` (size of every clique), ``clique_idx`` (clique
        members padded with -1 to ``max_length``) and ``clique_edge``.
        """
        data.clique_slice = torch.tensor(get_cli_length(mol_cliques))
        data.clique_idx = torch.tensor(pad_clique(mol_cliques, max_length))
        data.clique_edge = torch.tensor(mol_edges)

    def process(self):
        """Featurize the raw data of ``self.dataset`` and save the collation.

        Also writes the SMILES of the kept molecules to
        ``<processed_dir>/smiles.csv``.
        """
        data_smiles_list = []
        data_list = []

        if self.dataset == 'zinc_standard_agent':
            input_path = './dataset/zinc_standard_agent/raw/testfile.csv.gz'
            clique_dict, all_edges = read_clipue(self.dataset)
            # NOTE(review): the first clique group is dropped here but
            # all_edges is not shifted - confirm both files stay aligned by
            # molecule index.
            del clique_dict[0]
            max_length = max_len(clique_dict)
            input_df = pd.read_csv(input_path, sep=',', compression='gzip',
                                   dtype='str')
            smiles_list = list(input_df['smiles'])
            num_skipped = 0
            for i, s in enumerate(smiles_list):
                # each example contains a single species
                try:
                    rdkit_mol = AllChem.MolFromSmiles(s)
                    if rdkit_mol is not None:  # ignore invalid mol objects
                        data = mol_to_graph_data_obj_simple(rdkit_mol)
                        self._attach_cliques(data, clique_dict[i],
                                             all_edges[i], max_length)
                        # id is the row index of the mol in the input file
                        data.id = torch.tensor([i])
                        data_list.append(data)
                        data_smiles_list.append(s)
                except Exception:
                    # Best-effort: skip molecules that fail to featurize (or
                    # lack clique data), but keep the count visible.
                    num_skipped += 1
            if num_skipped:
                print('skipped %d molecules that failed to process'
                      % num_skipped)
        elif self.dataset == "zinc_sample":
            input_path = self.raw_paths[0]
            with open(input_path, "r") as f:
                raw_lines = f.readlines()
            # NOTE(review): the clique files are read but not attached to the
            # graphs in this branch.
            read_clipue(self.dataset)
            num_skipped = 0
            for i, s in enumerate([x.strip() for x in raw_lines]):
                try:
                    rdkit_mol = AllChem.MolFromSmiles(s)
                    if rdkit_mol is not None:
                        data = mol_to_graph_data_obj_simple(rdkit_mol)
                        # id is the line index of the mol in the input file
                        data.id = torch.tensor([i])
                        data_list.append(data)
                        data_smiles_list.append(s)
                except Exception:
                    num_skipped += 1
            if num_skipped:
                print('skipped %d molecules that failed to process'
                      % num_skipped)
        elif self.dataset == 'chembl_filtered':
            # get downstream validation/test molecules to exclude
            from splitters import scaffold_split
            downstream_dir = [
                'chem_dataset/dataset/bace',
                'chem_dataset/dataset/bbbp',
                'chem_dataset/dataset/clintox',
                'chem_dataset/dataset/hiv',
                'chem_dataset/dataset/muv',
                'chem_dataset/dataset/sider',
                'chem_dataset/dataset/tox21',
                'chem_dataset/dataset/toxcast'
            ]
            downstream_inchi_set = set()
            for d_path in downstream_dir:
                print(d_path)
                # The dataset name is the last path component; split('/')[1]
                # returned 'dataset' for these three-level paths.
                dataset_name = os.path.basename(os.path.normpath(d_path))
                downstream_dataset = MoleculeDataset1(d_path,
                                                      dataset=dataset_name)
                downstream_smiles = pd.read_csv(
                    os.path.join(d_path, 'processed', 'smiles.csv'),
                    header=None)[0].tolist()
                assert len(downstream_dataset) == len(downstream_smiles)
                _, _, _, (train_smiles, valid_smiles, test_smiles) = \
                    scaffold_split(downstream_dataset, downstream_smiles,
                                   task_idx=None, null_value=0,
                                   frac_train=0.8, frac_valid=0.1,
                                   frac_test=0.1, return_smiles=True)
                # remove both test and validation molecules
                remove_smiles = test_smiles + valid_smiles
                downstream_inchis = []
                for smiles in remove_smiles:
                    # Record the InChI of every species, not only the largest
                    # one (create_standardized_mol_id keeps only the largest
                    # when its input has multiple species).
                    for s in smiles.split('.'):
                        downstream_inchis.append(create_standardized_mol_id(s))
                downstream_inchi_set.update(downstream_inchis)

            smiles_list, rdkit_mol_objs, folds, labels = \
                _load_chembl_with_labels_dataset(os.path.join(self.root, 'raw'))
            print('processing')
            for i in range(len(rdkit_mol_objs)):
                rdkit_mol = rdkit_mol_objs[i]
                if rdkit_mol is None:
                    continue
                mw = Descriptors.MolWt(rdkit_mol)
                if not 50 <= mw <= 900:
                    continue
                inchi = create_standardized_mol_id(smiles_list[i])
                if inchi is None or inchi in downstream_inchi_set:
                    continue
                data = mol_to_graph_data_obj_simple(rdkit_mol)
                # id is the index of the mol in the dataset
                data.id = torch.tensor([i])
                data.y = torch.tensor(labels[i, :])
                # fold information
                if i in folds[0]:
                    data.fold = torch.tensor([0])
                elif i in folds[1]:
                    data.fold = torch.tensor([1])
                else:
                    data.fold = torch.tensor([2])
                data_list.append(data)
                data_smiles_list.append(smiles_list[i])
        elif self.dataset == 'tox21':
            clique_dict, all_edges = read_clipue(self.dataset)
            max_length = max_len(clique_dict)
            smiles_list, rdkit_mol_objs, labels = \
                _load_tox21_dataset(self.raw_paths[0])
            for i in range(len(smiles_list)):
                data = mol_to_graph_data_obj_simple(rdkit_mol_objs[i])
                self._attach_cliques(data, clique_dict[i], all_edges[i],
                                     max_length)
                data.id = torch.tensor([i])  # index of the mol in the dataset
                data.y = torch.tensor(labels[i, :])
                data_list.append(data)
                data_smiles_list.append(smiles_list[i])
        elif self.dataset == 'hiv':
            clique_dict, all_edges = read_clipue(self.dataset)
            max_length = max_len(clique_dict)
            smiles_list, rdkit_mol_objs, labels = \
                _load_hiv_dataset(self.raw_paths[0])
            for i in range(len(smiles_list)):
                data = mol_to_graph_data_obj_simple(rdkit_mol_objs[i])
                self._attach_cliques(data, clique_dict[i], all_edges[i],
                                     max_length)
                data.id = torch.tensor([i])  # index of the mol in the dataset
                data.y = torch.tensor([labels[i]])
                data_list.append(data)
                data_smiles_list.append(smiles_list[i])
        elif self.dataset == 'bace':
            clique_dict, all_edges = read_clipue(self.dataset)
            max_length = max_len(clique_dict)
            smiles_list, rdkit_mol_objs, folds, labels = \
                _load_bace_dataset(self.raw_paths[0])
            for i in range(len(smiles_list)):
                data = mol_to_graph_data_obj_simple(rdkit_mol_objs[i])
                self._attach_cliques(data, clique_dict[i], all_edges[i],
                                     max_length)
                data.id = torch.tensor([i])  # index of the mol in the dataset
                data.y = torch.tensor([labels[i]])
                data.fold = torch.tensor([folds[i]])
                data_list.append(data)
                data_smiles_list.append(smiles_list[i])
        elif self.dataset == 'bbbp':
            clique_dict, all_edges = read_clipue(self.dataset)
            max_length = max_len(clique_dict)
            smiles_list, rdkit_mol_objs, labels = \
                _load_bbbp_dataset(self.raw_paths[0])
            for i in range(len(smiles_list)):
                rdkit_mol = rdkit_mol_objs[i]
                if rdkit_mol is None:
                    continue
                data = mol_to_graph_data_obj_simple(rdkit_mol)
                self._attach_cliques(data, clique_dict[i], all_edges[i],
                                     max_length)
                data.id = torch.tensor([i])  # index of the mol in the dataset
                data.y = torch.tensor([labels[i]])
                data_list.append(data)
                data_smiles_list.append(smiles_list[i])
        elif self.dataset == 'clintox':
            clique_dict, all_edges = read_clipue(self.dataset)
            max_length = max_len(clique_dict)
            smiles_list, rdkit_mol_objs, labels = \
                _load_clintox_dataset(self.raw_paths[0])
            for i in range(len(smiles_list)):
                rdkit_mol = rdkit_mol_objs[i]
                if rdkit_mol is None:
                    continue
                data = mol_to_graph_data_obj_simple(rdkit_mol)
                self._attach_cliques(data, clique_dict[i], all_edges[i],
                                     max_length)
                data.id = torch.tensor([i])  # index of the mol in the dataset
                data.y = torch.tensor(labels[i, :])
                data_list.append(data)
                data_smiles_list.append(smiles_list[i])
        elif self.dataset == 'esol':
            # NOTE(review): clique files are read but unused in this branch.
            read_clipue(self.dataset)
            smiles_list, rdkit_mol_objs, labels = \
                _load_esol_dataset(self.raw_paths[0])
            for i in range(len(smiles_list)):
                data = mol_to_graph_data_obj_simple(rdkit_mol_objs[i])
                data.id = torch.tensor([i])  # index of the mol in the dataset
                data.y = torch.tensor([labels[i]])
                data_list.append(data)
                data_smiles_list.append(smiles_list[i])
        elif self.dataset == 'freesolv':
            # NOTE(review): clique files are read but unused in this branch.
            read_clipue(self.dataset)
            smiles_list, rdkit_mol_objs, labels = \
                _load_freesolv_dataset(self.raw_paths[0])
            for i in range(len(smiles_list)):
                data = mol_to_graph_data_obj_simple(rdkit_mol_objs[i])
                data.id = torch.tensor([i])  # index of the mol in the dataset
                data.y = torch.tensor([labels[i]])
                data_list.append(data)
                data_smiles_list.append(smiles_list[i])
        elif self.dataset == 'lipophilicity':
            # NOTE(review): clique files are read but unused in this branch.
            read_clipue(self.dataset)
            smiles_list, rdkit_mol_objs, labels = \
                _load_lipophilicity_dataset(self.raw_paths[0])
            for i in range(len(smiles_list)):
                data = mol_to_graph_data_obj_simple(rdkit_mol_objs[i])
                data.id = torch.tensor([i])  # index of the mol in the dataset
                data.y = torch.tensor([labels[i]])
                data_list.append(data)
                data_smiles_list.append(smiles_list[i])
        elif self.dataset == 'muv':
            clique_dict, all_edges = read_clipue(self.dataset)
            max_length = max_len(clique_dict)
            smiles_list, rdkit_mol_objs, labels = \
                _load_muv_dataset(self.raw_paths[0])
            for i in range(len(smiles_list)):
                data = mol_to_graph_data_obj_simple(rdkit_mol_objs[i])
                self._attach_cliques(data, clique_dict[i], all_edges[i],
                                     max_length)
                data.id = torch.tensor([i])  # index of the mol in the dataset
                data.y = torch.tensor(labels[i, :])
                data_list.append(data)
                data_smiles_list.append(smiles_list[i])
        elif self.dataset == 'pcba':
            # NOTE(review): clique files are read but unused in this branch.
            read_clipue(self.dataset)
            smiles_list, rdkit_mol_objs, labels = \
                _load_pcba_dataset(self.raw_paths[0])
            for i in range(len(smiles_list)):
                data = mol_to_graph_data_obj_simple(rdkit_mol_objs[i])
                data.id = torch.tensor([i])  # index of the mol in the dataset
                data.y = torch.tensor(labels[i, :])
                data_list.append(data)
                data_smiles_list.append(smiles_list[i])
        elif self.dataset == 'pcba_pretrain':
            smiles_list, rdkit_mol_objs, labels = \
                _load_pcba_dataset(self.raw_paths[0])
            downstream_inchi = set(pd.read_csv(
                os.path.join(self.root, 'downstream_mol_inchi_may_24_2019'),
                sep=',', header=None)[0])
            for i in range(len(smiles_list)):
                # remove examples with multiple species
                if '.' in smiles_list[i]:
                    continue
                rdkit_mol = rdkit_mol_objs[i]
                mw = Descriptors.MolWt(rdkit_mol)
                if not 50 <= mw <= 900:
                    continue
                inchi = create_standardized_mol_id(smiles_list[i])
                if inchi is None or inchi in downstream_inchi:
                    continue
                data = mol_to_graph_data_obj_simple(rdkit_mol)
                data.id = torch.tensor([i])  # index of the mol in the dataset
                data.y = torch.tensor(labels[i, :])
                data_list.append(data)
                data_smiles_list.append(smiles_list[i])
        elif self.dataset == 'sider':
            clique_dict, all_edges = read_clipue(self.dataset)
            max_length = max_len(clique_dict)
            smiles_list, rdkit_mol_objs, labels = \
                _load_sider_dataset(self.raw_paths[0])
            for i in range(len(smiles_list)):
                data = mol_to_graph_data_obj_simple(rdkit_mol_objs[i])
                self._attach_cliques(data, clique_dict[i], all_edges[i],
                                     max_length)
                data.id = torch.tensor([i])  # index of the mol in the dataset
                data.y = torch.tensor(labels[i, :])
                data_list.append(data)
                data_smiles_list.append(smiles_list[i])
        elif self.dataset == 'toxcast':
            clique_dict, all_edges = read_clipue(self.dataset)
            max_length = max_len(clique_dict)
            smiles_list, rdkit_mol_objs, labels = \
                _load_toxcast_dataset("./dataset/toxcast/raw/toxcast_data.csv")
            for i in range(len(smiles_list)):
                rdkit_mol = rdkit_mol_objs[i]
                if rdkit_mol is None:
                    continue
                data = mol_to_graph_data_obj_simple(rdkit_mol)
                self._attach_cliques(data, clique_dict[i], all_edges[i],
                                     max_length)
                data.id = torch.tensor([i])  # index of the mol in the dataset
                data.y = torch.tensor(labels[i, :])
                data_list.append(data)
                data_smiles_list.append(smiles_list[i])
        elif self.dataset == 'ptc_mr':
            input_path = self.raw_paths[0]
            input_df = pd.read_csv(input_path, sep=',', header=None,
                                   names=['id', 'label', 'smiles'])
            smiles_list = input_df['smiles']
            labels = input_df['label'].values
            for i in range(len(smiles_list)):
                rdkit_mol = AllChem.MolFromSmiles(smiles_list[i])
                if rdkit_mol is None:  # ignore invalid mol objects
                    continue
                data = mol_to_graph_data_obj_simple(rdkit_mol)
                data.id = torch.tensor([i])
                data.y = torch.tensor([labels[i]])
                data_list.append(data)
                data_smiles_list.append(smiles_list[i])
        elif self.dataset == 'mutag':
            smiles_path = os.path.join(self.root, 'raw', 'mutag_188_data.can')
            labels_path = os.path.join(self.root, 'raw',
                                       'mutag_188_target.txt')
            smiles_list = pd.read_csv(smiles_path, sep=' ', header=None)[0]
            labels = pd.read_csv(labels_path, header=None)[0].values
            for i in range(len(smiles_list)):
                rdkit_mol = AllChem.MolFromSmiles(smiles_list[i])
                if rdkit_mol is None:  # ignore invalid mol objects
                    continue
                data = mol_to_graph_data_obj_simple(rdkit_mol)
                data.id = torch.tensor([i])
                data.y = torch.tensor([labels[i]])
                data_list.append(data)
                data_smiles_list.append(smiles_list[i])
        else:
            raise ValueError('Invalid dataset name')

        # NB: data_smiles_list is not filtered by pre_filter below.
        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]
        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        # write data_smiles_list in processed paths
        pd.Series(data_smiles_list).to_csv(
            os.path.join(self.processed_dir, 'smiles.csv'),
            index=False, header=False)

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
# NB: only properly tested when dataset_1 is chembl_with_labels and dataset_2
# is pcba_pretrain
def merge_dataset_objs(dataset_1, dataset_2):
    """
    Naively merge 2 molecule dataset objects, ignoring molecule identities.

    Both datasets are assumed to have multi-dimensional y labels; each y is
    zero-padded so that both parts share one label space. E.g. if dataset_1
    has y dim 1310 and dataset_2 has y dim 128, every merged y has dim 1438:
    dataset_1 rows get 128 trailing zeros, dataset_2 rows get 1310 leading
    zeros. Every y is cast to long before padding, so both halves share one
    dtype.

    :return: pytorch geometric dataset obj, with the x, edge_attr, edge_index,
    new y attributes only
    """
    d_1_y_dim = dataset_1[0].y.size()[0]
    d_2_y_dim = dataset_2[0].y.size()[0]

    data_list = []
    # keep only x, edge_attr, edge_index and the padded y
    for d in dataset_1:
        new_y = torch.cat([d.y.long(),
                           torch.zeros(d_2_y_dim, dtype=torch.long)])
        data_list.append(Data(x=d.x, edge_index=d.edge_index,
                              edge_attr=d.edge_attr, y=new_y))
    for d in dataset_2:
        new_y = torch.cat([torch.zeros(d_1_y_dim, dtype=torch.long),
                           d.y.long()])
        data_list.append(Data(x=d.x, edge_index=d.edge_index,
                              edge_attr=d.edge_attr, y=new_y))

    # Create an 'empty' dataset obj. This uses a dataset/root pair that is
    # expected to be processed already.
    new_dataset = MoleculeDataset1(root='dataset/chembl_with_labels',
                                   dataset='chembl_with_labels', empty=True)
    # collate manually
    new_dataset.data, new_dataset.slices = new_dataset.collate(data_list)
    return new_dataset
def create_circular_fingerprint(mol, radius, size, chirality):
    """
    Compute a folded Morgan (circular) fingerprint of ``mol``.

    :param mol: rdkit mol object
    :param radius: Morgan radius
    :param size: number of bits in the folded bit vector
    :param chirality: passed as ``useChirality`` to RDKit
    :return: np array of morgan fingerprint bits
    """
    return np.array(GetMorganFingerprintAsBitVect(
        mol, radius, nBits=size, useChirality=chirality))
class MoleculeFingerprintDataset(data.Dataset):
def __init__(self, root, dataset, radius, size, chirality=True):
"""
Create dataset object containing list of dicts, where each dict
contains the circular fingerprint of the molecule, label, id,
and possibly precomputed fold information
:param root: directory of the dataset, containing a raw and
processed_fp dir. The raw dir should contain the file containing the
smiles, and the processed_fp dir can either be empty or a
previously processed file
:param dataset: name of dataset. Currently only implemented for
tox21, hiv, chembl_with_labels
:param radius: radius of the circular fingerprints
:param size: size of the folded fingerprint vector
:param chirality: if True, fingerprint includes chirality information
"""
self.dataset = dataset
self.root = root
self.radius = radius
self.size = size
self.chirality = chirality
self._load()
def _process(self):
data_smiles_list = []
data_list = []
if self.dataset == 'chembl_with_labels':
smiles_list, rdkit_mol_objs, folds, labels = \
_load_chembl_with_labels_dataset(os.path.join(self.root, 'raw'))
print('processing')
for i in range(len(rdkit_mol_objs)):
print(i)
rdkit_mol = rdkit_mol_objs[i]
if rdkit_mol != None:
# # convert aromatic bonds to double bonds
fp_arr = create_circular_fingerprint(rdkit_mol,
self.radius,
self.size, self.chirality)
fp_arr = torch.tensor(fp_arr)
# manually add mol id
id = torch.tensor([i]) # id here is the index of the mol in
# the dataset
y = torch.tensor(labels[i, :])
# fold information
if i in folds[0]:
fold = torch.tensor([0])
elif i in folds[1]:
fold = torch.tensor([1])
else:
fold = torch.tensor([2])
data_list.append({'fp_arr': fp_arr, 'id': id, 'y': y,
'fold': fold})
data_smiles_list.append(smiles_list[i])
elif self.dataset == 'tox21':
smiles_list, rdkit_mol_objs, labels = \
_load_tox21_dataset(os.path.join(self.root, 'raw/tox21.csv'))
print('processing')
for i in range(len(smiles_list)):
print(i)
rdkit_mol = rdkit_mol_objs[i]
## convert aromatic bonds to double bonds
fp_arr = create_circular_fingerprint(rdkit_mol,
self.radius,
self.size,
self.chirality)
fp_arr = torch.tensor(fp_arr)
# manually add mol id
id = torch.tensor([i]) # id here is the index of the mol in
# the dataset
y = torch.tensor(labels[i, :])
data_list.append({'fp_arr': fp_arr, 'id': id, 'y': y})
data_smiles_list.append(smiles_list[i])
elif self.dataset == 'hiv':
smiles_list, rdkit_mol_objs, labels = \
_load_hiv_dataset(os.path.join(self.root, 'raw/HIV.csv'))
print('processing')
for i in range(len(smiles_list)):
print(i)
rdkit_mol = rdkit_mol_objs[i]
# # convert aromatic bonds to double bonds
fp_arr = create_circular_fingerprint(rdkit_mol,
self.radius,
self.size,
self.chirality)
fp_arr = torch.tensor(fp_arr)
# manually add mol id
id = torch.tensor([i]) # id here is the index of the mol in
# the dataset
y = torch.tensor([labels[i]])
data_list.append({'fp_arr': fp_arr, 'id': id, 'y': y})
data_smiles_list.append(smiles_list[i])
else:
raise ValueError('Invalid dataset name')
# save processed data objects and smiles
processed_dir = os.path.join(self.root, 'processed_fp')
data_smiles_series = pd.Series(data_smiles_list)
data_smiles_series.to_csv(os.path.join(processed_dir, 'smiles.csv'),
index=False,
header=False)
with open(os.path.join(processed_dir,
'fingerprint_data_processed.pkl'),
'wb') as f:
pickle.dump(data_list, f)
def _load(self):
    """Populate ``self.data_list`` from the cached fingerprint pickle.

    The cache file is ``<self.root>/processed_fp/fingerprint_data_processed.pkl``.
    If it does not exist yet, ``self._process()`` is called to build it and the
    freshly written file is then loaded.

    Raises:
        FileNotFoundError: if ``self._process()`` returns without having
            written the cache file.
    """
    processed_dir = os.path.join(self.root, 'processed_fp')
    processed_path = os.path.join(processed_dir,
                                  'fingerprint_data_processed.pkl')
    if not os.path.isfile(processed_path):
        # _process() writes into processed_dir without creating it, so make
        # sure the directory exists before asking it to build the cache.
        os.makedirs(processed_dir, exist_ok=True)
        self._process()
    # NOTE(review): pickle.load can execute arbitrary code; only load cache
    # files produced by this code.
    with open(processed_path, 'rb') as f:
        self.data_list = pickle.load(f)
def __len__(self):
    """Return how many fingerprint records ``self.data_list`` holds."""
    records = self.data_list
    return len(records)
def __getitem__(self, index):
    """Index the dataset.

    If ``index`` is iterable (list, range, array, tensor, ...), a new
    ``MoleculeFingerprintDataset`` is constructed with the same root, dataset,
    radius, size and chirality. Its ``data_list`` is replaced by the selected
    records, in the order given. Any other ``index`` (an int, a slice, ...) is
    passed straight through to ``self.data_list[index]``.
    """
    if not hasattr(index, "__iter__"):
        return self.data_list[index]
    subset = MoleculeFingerprintDataset(self.root, self.dataset, self.radius,
                                        self.size, chirality=self.chirality)
    records = self.data_list
    subset.data_list = [records[k] for k in index]
    return subset
def _load_tox21_dataset(input_path):
    """Read the Tox21 CSV.

    :param input_path: path to the Tox21 csv (a ``smiles`` column plus the 12
        task columns listed below)
    :return: smiles Series, list of rdkit Mol objects, np.ndarray of shape
        (n_mols, 12) where original 0 labels become -1 and missing (NaN)
        entries become 0
    """
    frame = pd.read_csv(input_path, sep=',')
    smiles = frame['smiles']
    mols = [AllChem.MolFromSmiles(smi) for smi in smiles]
    task_columns = [
        'NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD',
        'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53',
    ]
    # 0 -> -1 must happen before NaN -> 0 so "missing" stays distinct.
    targets = frame[task_columns].replace(0, -1).fillna(0)
    assert len(smiles) == len(mols)
    assert len(smiles) == len(targets)
    return smiles, mols, targets.values
def _load_hiv_dataset(input_path):
    """Read the HIV CSV.

    :param input_path: path to the HIV csv (``smiles`` and ``HIV_active``
        columns)
    :return: smiles Series, list of rdkit Mol objects, 1-D np.ndarray of
        labels where original 0 values become -1
    """
    frame = pd.read_csv(input_path, sep=',')
    smiles = frame['smiles']
    mols = [AllChem.MolFromSmiles(smi) for smi in smiles]
    # No NaN handling: the label column is assumed to be complete.
    targets = frame['HIV_active'].replace(0, -1)
    assert len(smiles) == len(mols)
    assert len(smiles) == len(targets)
    return smiles, mols, targets.values
def _load_bace_dataset(input_path):
    """Read the BACE CSV together with its predefined split.

    :param input_path: path to the BACE csv (``mol``, ``Class`` and ``Model``
        columns)
    :return: smiles Series, list of rdkit Mol objects, np.ndarray with one
        fold id per row (Train -> 0, Valid -> 1, Test -> 2), np.ndarray of
        labels where original 0 values become -1
    """
    frame = pd.read_csv(input_path, sep=',')
    smiles = frame['mol']
    mols = [AllChem.MolFromSmiles(smi) for smi in smiles]
    # No NaN handling: the label column is assumed to be complete.
    targets = frame['Class'].replace(0, -1)
    split = frame['Model']
    # One replace per split name, in Train/Valid/Test order.
    for fold_id, fold_name in enumerate(('Train', 'Valid', 'Test')):
        split = split.replace(fold_name, fold_id)
    assert len(smiles) == len(mols)
    assert len(smiles) == len(targets)
    assert len(smiles) == len(split)
    return smiles, mols, split.values, targets.values
def _load_bbbp_dataset(input_path):
    """Read the BBBP CSV.

    :param input_path: path to the BBBP csv (``smiles`` and ``p_np`` columns)
    :return: list of SMILES regenerated with ``AllChem.MolToSmiles`` (None
        where the input SMILES could not be parsed), list of rdkit Mol
        objects (None where parsing failed), 1-D np.ndarray of labels where
        original 0 values become -1
    """
    input_df = pd.read_csv(input_path, sep=',')
    smiles_list = input_df['smiles']
    rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list]
    # Shallow copy of the parsed molecules; failed parses stay as None.
    preprocessed_rdkit_mol_objs_list = list(rdkit_mol_objs_list)
    # Only regenerate SMILES for molecules that actually parsed.
    preprocessed_smiles_list = [
        AllChem.MolToSmiles(m) if m is not None else None
        for m in preprocessed_rdkit_mol_objs_list
    ]
    # No NaN handling: the label column is assumed to be complete.
    labels = input_df['p_np'].replace(0, -1)
    assert len(smiles_list) == len(preprocessed_rdkit_mol_objs_list)
    assert len(smiles_list) == len(preprocessed_smiles_list)
    assert len(smiles_list) == len(labels)
    return preprocessed_smiles_list, preprocessed_rdkit_mol_objs_list, \
        labels.values
def _load_clintox_dataset(input_path):
    """Read the ClinTox CSV.

    :param input_path: path to the ClinTox csv (``smiles``, ``FDA_APPROVED``
        and ``CT_TOX`` columns)
    :return: list of SMILES regenerated with ``AllChem.MolToSmiles`` (None
        where the input SMILES could not be parsed), list of rdkit Mol
        objects (None where parsing failed), np.ndarray of shape (n_mols, 2)
        where original 0 labels become -1
    """
    input_df = pd.read_csv(input_path, sep=',')
    smiles_list = input_df['smiles']
    rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list]
    # Shallow copy of the parsed molecules; failed parses stay as None.
    preprocessed_rdkit_mol_objs_list = list(rdkit_mol_objs_list)
    # Only regenerate SMILES for molecules that actually parsed.
    preprocessed_smiles_list = [
        AllChem.MolToSmiles(m) if m is not None else None
        for m in preprocessed_rdkit_mol_objs_list
    ]
    tasks = ['FDA_APPROVED', 'CT_TOX']
    # No NaN handling: the label columns are assumed to be complete.
    labels = input_df[tasks].replace(0, -1)
    assert len(smiles_list) == len(preprocessed_rdkit_mol_objs_list)
    assert len(smiles_list) == len(preprocessed_smiles_list)
    assert len(smiles_list) == len(labels)
    return preprocessed_smiles_list, preprocessed_rdkit_mol_objs_list, \
        labels.values
# input_path = 'dataset/clintox/raw/clintox.csv'
# smiles_list, rdkit_mol_objs_list, labels = _load_clintox_dataset(input_path)
def _load_esol_dataset(input_path):
    """Read the ESOL solubility CSV (regression).

    NB: some entries contain several species; they are parsed as-is here.

    :param input_path: path to the ESOL csv (``smiles`` and
        ``measured log solubility in mols per litre`` columns)
    :return: smiles Series, list of rdkit Mol objects, 1-D np.ndarray of
        regression targets
    """
    frame = pd.read_csv(input_path, sep=',')
    smiles = frame['smiles']
    mols = [AllChem.MolFromSmiles(smi) for smi in smiles]
    targets = frame['measured log solubility in mols per litre']
    assert len(smiles) == len(mols)
    assert len(smiles) == len(targets)
    return smiles, mols, targets.values
# input_path = 'dataset/esol/raw/delaney-processed.csv'
# smiles_list, rdkit_mol_objs_list, labels = _load_esol_dataset(input_path)
def _load_freesolv_dataset(input_path):
    """Read the FreeSolv CSV (regression).

    :param input_path: path to the FreeSolv csv (``smiles`` and ``expt``
        columns)
    :return: smiles Series, list of rdkit Mol objects, 1-D np.ndarray of
        regression targets taken from ``expt``
    """
    frame = pd.read_csv(input_path, sep=',')
    smiles = frame['smiles']
    mols = [AllChem.MolFromSmiles(smi) for smi in smiles]
    targets = frame['expt']
    assert len(smiles) == len(mols)
    assert len(smiles) == len(targets)
    return smiles, mols, targets.values
def _load_lipophilicity_dataset(input_path):
    """Read the Lipophilicity CSV (regression).

    :param input_path: path to the Lipophilicity csv (``smiles`` and ``exp``
        columns)
    :return: smiles Series, list of rdkit Mol objects, 1-D np.ndarray of
        regression targets taken from ``exp``
    """
    frame = pd.read_csv(input_path, sep=',')
    smiles = frame['smiles']
    mols = [AllChem.MolFromSmiles(smi) for smi in smiles]
    targets = frame['exp']
    assert len(smiles) == len(mols)
    assert len(smiles) == len(targets)
    return smiles, mols, targets.values
def _load_muv_dataset(input_path):
    """Read the MUV CSV.

    :param input_path: path to the MUV csv (a ``smiles`` column plus the 17
        task columns listed below)
    :return: smiles Series, list of rdkit Mol objects, np.ndarray of shape
        (n_mols, 17) where original 0 labels become -1 and missing (NaN)
        entries become 0
    """
    frame = pd.read_csv(input_path, sep=',')
    smiles = frame['smiles']
    mols = [AllChem.MolFromSmiles(smi) for smi in smiles]
    task_columns = [
        'MUV-466', 'MUV-548', 'MUV-600', 'MUV-644', 'MUV-652', 'MUV-689',
        'MUV-692', 'MUV-712', 'MUV-713', 'MUV-733', 'MUV-737', 'MUV-810',
        'MUV-832', 'MUV-846', 'MUV-852', 'MUV-858', 'MUV-859',
    ]
    # 0 -> -1 must happen before NaN -> 0 so "missing" stays distinct.
    targets = frame[task_columns].replace(0, -1).fillna(0)
    assert len(smiles) == len(mols)
    assert len(smiles) == len(targets)
    return smiles, mols, targets.values
def _load_sider_dataset(input_path):
    """Read the SIDER CSV.

    :param input_path: path to the SIDER csv (a ``smiles`` column plus the 27
        side-effect columns listed below)
    :return: smiles Series, list of rdkit Mol objects, np.ndarray of shape
        (n_mols, 27) where original 0 labels become -1 (no NaN filling)
    """
    frame = pd.read_csv(input_path, sep=',')
    smiles = frame['smiles']
    mols = [AllChem.MolFromSmiles(smi) for smi in smiles]
    task_columns = [
        'Hepatobiliary disorders',
        'Metabolism and nutrition disorders',
        'Product issues',
        'Eye disorders',
        'Investigations',
        'Musculoskeletal and connective tissue disorders',
        'Gastrointestinal disorders',
        'Social circumstances',
        'Immune system disorders',
        'Reproductive system and breast disorders',
        'Neoplasms benign, malignant and unspecified (incl cysts and polyps)',
        'General disorders and administration site conditions',
        'Endocrine disorders',
        'Surgical and medical procedures',
        'Vascular disorders',
        'Blood and lymphatic system disorders',
        'Skin and subcutaneous tissue disorders',
        'Congenital, familial and genetic disorders',
        'Infections and infestations',
        'Respiratory, thoracic and mediastinal disorders',
        'Psychiatric disorders',
        'Renal and urinary disorders',
        'Pregnancy, puerperium and perinatal conditions',
        'Ear and labyrinth disorders',
        'Cardiac disorders',
        'Nervous system disorders',
        'Injury, poisoning and procedural complications',
    ]
    targets = frame[task_columns].replace(0, -1)
    assert len(smiles) == len(mols)
    assert len(smiles) == len(targets)
    return smiles, mols, targets.values
def _load_toxcast_dataset(input_path):
    """Read the ToxCast CSV.

    NB: some examples have multiple species, and some SMILES are invalid.

    :param input_path: path to the ToxCast csv; the first column must hold the
        SMILES (read as ``smiles``) and every following column is a task
    :return: list of SMILES regenerated with ``AllChem.MolToSmiles`` (None
        where the input SMILES could not be parsed), list of rdkit Mol
        objects (None where parsing failed), np.ndarray of shape
        (n_mols, n_tasks) where original 0 labels become -1 and missing (NaN)
        entries become 0
    """
    input_df = pd.read_csv(input_path, sep=',')
    smiles_list = input_df['smiles']
    rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list]
    # SMILES that cannot be converted to an rdkit Mol come back as None and
    # are kept as None placeholders.
    preprocessed_rdkit_mol_objs_list = list(rdkit_mol_objs_list)
    preprocessed_smiles_list = [
        AllChem.MolToSmiles(m) if m is not None else None
        for m in preprocessed_rdkit_mol_objs_list
    ]
    # Every column after the first (the SMILES column) is a task.
    tasks = list(input_df.columns)[1:]
    # 0 -> -1 must happen before NaN -> 0 so "missing" stays distinct.
    labels = input_df[tasks].replace(0, -1).fillna(0)
    assert len(smiles_list) == len(preprocessed_rdkit_mol_objs_list)
    assert len(smiles_list) == len(preprocessed_smiles_list)
    assert len(smiles_list) == len(labels)
    return preprocessed_smiles_list, preprocessed_rdkit_mol_objs_list, \
        labels.values
def _load_chembl_with_labels_dataset(root_path):
    """
    Data from 'Large-scale comparison of machine learning methods for drug target prediction on ChEMBL'
    :param root_path: path to the folder containing the reduced chembl dataset
    :return: list of smiles (None where no usable molecule remains),
        preprocessed rdkit mol obj list (None for the same entries),
        list of per-fold lists of sample indices (three folds are assumed by
        the length check below), dense np.ndarray containing the labels
    """
    # adapted from https://github.com/ml-jku/lsc/blob/master/pythonCode/lstm/loadData.py
    # first need to download the files and unzip:
    # wget http://bioinf.jku.at/research/lsc/chembl20/dataPythonReduced.zip
    # unzip and rename to chembl_with_labels
    # wget http://bioinf.jku.at/research/lsc/chembl20/dataPythonReduced/chembl20Smiles.pckl
    # into the dataPythonReduced directory
    # wget http://bioinf.jku.at/research/lsc/chembl20/dataPythonReduced/chembl20LSTM.pckl
    # NOTE(review): pickle.load can execute arbitrary code; only use these
    # files from the trusted download source above.

    # 1. load folds and labels
    with open(os.path.join(root_path, 'folds0.pckl'), 'rb') as f:
        folds = pickle.load(f)
    with open(os.path.join(root_path, 'labelsHard.pckl'), 'rb') as f:
        targetMat = pickle.load(f)
        sampleAnnInd = pickle.load(f)
        # Third object in the file (target annotation index); read to keep
        # the file layout check intact but not used below.
        pickle.load(f)
    targetMat = targetMat.copy().tocsr()
    targetMat.sort_indices()
    # Keep, per fold, only the sample indices present in sampleAnnInd's index.
    folds = [np.intersect1d(fold, sampleAnnInd.index.values).tolist() for fold in folds]
    # Dense label matrix; possible values are {-1, 0, 1}. `.toarray()` is the
    # documented API; the `.A` shorthand is deprecated in newer SciPy.
    denseOutputData = targetMat.toarray()

    # 2. load structures
    with open(os.path.join(root_path, 'chembl20LSTM.pckl'), 'rb') as f:
        rdkitArr = pickle.load(f)
    assert len(rdkitArr) == denseOutputData.shape[0]
    assert len(rdkitArr) == len(folds[0]) + len(folds[1]) + len(folds[2])

    preprocessed_rdkitArr = []
    print('preprocessing')
    for i, m in enumerate(rdkitArr):
        print(i)
        if m is None:
            preprocessed_rdkitArr.append(None)
            continue
        mol_species_list = split_rdkit_mol_obj(m)
        if len(mol_species_list) == 0:
            preprocessed_rdkitArr.append(None)
            continue
        largest_mol = get_largest_mol(mol_species_list)
        # Molecules whose largest species has at most 2 atoms are dropped.
        if len(largest_mol.GetAtoms()) <= 2:
            preprocessed_rdkitArr.append(None)
        else:
            preprocessed_rdkitArr.append(largest_mol)
    assert len(preprocessed_rdkitArr) == denseOutputData.shape[0]
    # Some entries of rdkitArr are empty, so keep None placeholders aligned
    # with the label rows.
    smiles_list = [AllChem.MolToSmiles(m) if m is not None else None
                   for m in preprocessed_rdkitArr]
    assert len(preprocessed_rdkitArr) == len(smiles_list)
    return smiles_list, preprocessed_rdkitArr, folds, denseOutputData
# root_path = 'dataset/chembl_with_labels'
def check_smiles_validity(smiles):
    """Return True if RDKit parses ``smiles`` into a (truthy) molecule.

    An ``Exception`` raised while parsing is treated as "invalid" and yields
    False. KeyboardInterrupt and SystemExit are no longer swallowed.

    :param smiles: SMILES string to check
    :return: bool
    """
    try:
        m = Chem.MolFromSmiles(smiles)
    except Exception:
        return False
    return bool(m)
def split_rdkit_mol_obj(mol):
    """Break an rdkit Mol into one Mol per species.

    The molecule is written as isomeric SMILES and split on '.'. Each fragment
    accepted by ``check_smiles_validity`` is parsed into its own Mol.
    Fragments that fail the check are dropped, so the result may be empty. A
    single-species input yields a one-element list.

    :param mol: rdkit Mol object
    :return: list of rdkit Mol objects, in fragment order
    """
    fragments = AllChem.MolToSmiles(mol, isomericSmiles=True).split('.')
    return [AllChem.MolFromSmiles(frag) for frag in fragments
            if check_smiles_validity(frag)]
def get_largest_mol(mol_list):
    """Return the molecule with the most atoms.

    Ties are resolved in favour of the earliest entry in ``mol_list``.

    :param mol_list: non-empty list of rdkit Mol objects
    :return: the element with the largest ``len(m.GetAtoms())``
    :raises ValueError: if ``mol_list`` is empty
    """
    # max() with a key returns the first maximal item, which matches the
    # list.index(max(...)) lookup this replaces.
    return max(mol_list, key=lambda candidate: len(candidate.GetAtoms()))
def create_all_datasets():
    """Instantiate and print every dataset this module knows how to build.

    For each downstream benchmark name, ``dataset/<name>/processed`` is
    created if missing and ``MoleculeDataset1(root, dataset=name)`` is
    instantiated and printed. Then the ``chembl_filtered`` and
    ``zinc_standard_agent`` datasets under ``chem_dataset/dataset/`` are
    instantiated and printed.
    NOTE(review): the downstream roots ("dataset/...") and the last two roots
    ("chem_dataset/dataset/...") use different prefixes. Confirm this is
    intended.
    """
    downstream_names = (
        'bace', 'bbbp', 'clintox', 'esol', 'freesolv', 'hiv',
        'lipophilicity', 'muv', 'sider', 'tox21', 'toxcast',
    )
    for name in downstream_names:
        print(name)
        root = "dataset/" + name
        os.makedirs(root + "/processed", exist_ok=True)
        print(MoleculeDataset1(root, dataset=name))
    for extra_name in ("chembl_filtered", "zinc_standard_agent"):
        print(MoleculeDataset1(root="chem_dataset/dataset/" + extra_name,
                               dataset=extra_name))
import numpy as np
import torch
from ogb.utils.features import (atom_to_feature_vector,
bond_to_feature_vector)
from rdkit import Chem
def smiles2graph(smiles_string):
    """
    Converts SMILES string to graph Data object
    :input: SMILES string (str)
    :return: torch_geometric ``Data`` with
        x: long tensor, one row of ogb ``atom_to_feature_vector`` per atom;
        edge_index: long tensor [2, num_edges], each bond stored in both
            directions;
        edge_attr: long tensor [num_edges, num_edge_features], one row of ogb
            ``bond_to_feature_vector`` per directed edge
            (shape [0, 3] for molecules without bonds)
    :raises ValueError: if RDKit cannot parse ``smiles_string``
    """
    mol = Chem.MolFromSmiles(smiles_string)
    if mol is None:
        # MolFromSmiles returns None on parse failure; fail with a clear
        # message instead of an AttributeError on mol.GetAtoms().
        raise ValueError(f"Invalid SMILES string: {smiles_string!r}")
    # atoms
    atom_features_list = [atom_to_feature_vector(atom) for atom in mol.GetAtoms()]
    x = torch.tensor(np.array(atom_features_list), dtype=torch.long)
    # bonds
    num_bond_features = 3  # bond type, bond stereo, is_conjugated
    if len(mol.GetBonds()) > 0:  # mol has bonds
        edges_list = []
        edge_features_list = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_feature = bond_to_feature_vector(bond)
            # add edges in both directions
            edges_list.append((i, j))
            edge_features_list.append(edge_feature)
            edges_list.append((j, i))
            edge_features_list.append(edge_feature)
        # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
        edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long)
        # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = torch.tensor(np.array(edge_features_list),
                                 dtype=torch.long)
    else:  # mol has no bonds
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)
    graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    return graph
from torch_geometric.data import Data
# test MoleculeDataset object
# Ad-hoc manual check, run only when this file is executed as a script.
# NOTE(review): despite the heading above, the live code only exercises
# pad_clique; the dataset-building experiment below is commented out.
if __name__ == "__main__":
    # NOTE(review): pad_clique is not defined in this part of the file;
    # presumably it is defined earlier in the module. TODO confirm.
    z = pad_clique([[[1, 2], [1, 2, 3, 4, 5, 6], [3, 4, 5]], [[1, 2], [1, 2, 3, 4, 5, 6, 3, 3, 4, 5], [3, 4, 5]]])
    # NOTE(review): `z` is never used, and `max_len` is not assigned in this
    # block. This raises NameError unless `max_len` is a module-level global
    # (perhaps set by pad_clique). Confirm.
    print(max_len)
    # Disabled experiment: convert BBBP SMILES into graphs with smiles2graph,
    # track the max values of the first two edge_attr columns, and save the
    # graphs to chem_dataset/dataset/bbbp/processed/bbbp_data.pt.
    # print(x)
    # dataset = []
    # # create_all_datasets()
    # smiles_list, rdkit_mol_objs_list, labels = _load_bbbp_dataset("./chem_dataset/dataset/bbbp/raw/BBBP.csv")
    # # print(smiles_list,rdkit_mol_objs_list,labels)
    # s1 = []
    # s2 = []
    # for i in range(len(smiles_list)):
    # if type(smiles_list[i]) != str:
    # continue
    # graph = smiles2graph(smiles_list[i])
    # graph.y = torch.tensor(labels[i]).unsqueeze(0)
    # s1.append(graph.edge_attr[:, 0].max())
    # s2.append(graph.edge_attr[:, 1].max())
    # graph.num_nodes = graph.x.size(0)
    # dataset.append(graph)
    # print(np.array(s1).max(), np.array(s2).max())
    # torch.save(dataset, "chem_dataset/dataset/bbbp/processed/bbbp_data.pt")
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。