diff --git a/MindChem/applications/unimol/README.md b/MindChem/applications/unimol/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..eb510c599430724f7df0cf0a98f971fee105d991
--- /dev/null
+++ b/MindChem/applications/unimol/README.md
@@ -0,0 +1,117 @@
+# Uni-Mol tools for various predictions and downstream tasks.
+
+Documentation of Uni-Mol tools is available at https://unimol.readthedocs.io/en/latest/
+
+
+
+
+
+## 描述
+unimol_tools 是一个 Python 包,集成了 Uni‑Mol1 和 Uni-Mol2,用于做分子 / 材料数据的性质预测、表示 (representation)、下游任务 (downstream)。它将 Uni‑Mol 的功能封装起来,降低使用门槛,让用户可以比较方便地在已有数据集上进行训练、预测、表示抽取等操作,而不必深入 Uni‑Mol 的底层实现。因为是轻量、独立的工具包 (相比原始 Uni‑Mol repository 更轻、更简洁):即使不使用GPU或NPU,甚至在 CPU 环境下也可以运行。
+
+本仓库提供了基于 MindSpore 的 unimol_tools 实现,改写自原始的 [Uni-Mol](https://github.com/deepmodeling/Uni-Mol) 仓库
+
+---
+
+## 快速开始 / 安装
+
+基础环境要求:
+
+```text
+python >= 3.10
+mindspore == 2.7.1
+CANN == 8.0.rc1
+```
+
+克隆 MindScience 仓库:
+
+```bash
+git clone https://gitee.com/mindspore/mindscience.git
+```
+
+模型权重下载:
+
+Uni-Molv1和Uni-Molv2权重可分别在[Weiland/Uni-Molv1](https://modelers.cn/models/Weiland/Uni-Molv1/tree/master)和[Weiland/Uni-Molv2](https://modelers.cn/models/Weiland/Uni-Molv2/tree/main/modelzoo)下载
+
+### 配置 Python 环境
+
+安装依赖包:
+
+```bash
+pip install -r requirements.txt
+```
+
+---
+
+## 使用方法
+
+### unimol_tools 基本用法
+
+以下示例可在仓库根目录运行:
+
+### Molecule property prediction
+运行过程中若没有检测到所需的权重,程序会自动将其下载到unimol_tools/weights目录下
+```python
+from unimol_tools import MolTrain, MolPredict
+# 训练
+clf = MolTrain(
+ task='regression',
+ data_type='molecule', #目前支持oled和molecule
+ epochs=20,
+ kfold=5,
+ batch_size=10,
+ metrics='mse',
+ model_name='unimolv1', #可选unimolv1或unimolv2
+ model_size='84m', #unimolv2权重:84m/164m/310m/570m/1.1B
+)
+pred = clf.fit(data = train_data)
+# currently support data with smiles based csv/txt file, and
+# custom dict of {'atoms':[['C','C'],['C','H','O']], 'coordinates':[coordinates_1,coordinates_2]}
+# 预测
+clf = MolPredict(load_model='../exp')
+res = clf.predict(data = data)
+```
+项目中准备了一个训练的小例子。用户可运行下面的脚本来进行训练。
+```
+python train_Esipt.py
+```
+该脚本所用的数据集存于文件data_S0N_dE_updated.pkl中。
+
+
+### Unimol molecule- and atom-level representation
+```python
+import numpy as np
+from unimol_tools import UniMolRepr
+# single smiles unimol representation
+clf = UniMolRepr(data_type='molecule', remove_hs=False)
+smiles = 'c1ccc(cc1)C2=NCC(=O)Nc3c2cc(cc3)[N+](=O)[O]'
+smiles_list = [smiles]
+unimol_repr = clf.get_repr(smiles_list, return_atomic_reprs=True)
+# CLS token repr
+print(np.array(unimol_repr['cls_repr']).shape)
+# atomic level repr, align with rdkit mol.GetAtoms()
+print(np.array(unimol_repr['atomic_reprs']).shape)
+```
+
+Please kindly cite our papers if you use the data/code/model.
+```
+@inproceedings{
+ zhou2023unimol,
+ title={Uni-Mol: A Universal 3D Molecular Representation Learning Framework},
+ author={Gengmo Zhou and Zhifeng Gao and Qiankun Ding and Hang Zheng and Hongteng Xu and Zhewei Wei and Linfeng Zhang and Guolin Ke},
+ booktitle={The Eleventh International Conference on Learning Representations },
+ year={2023},
+ url={https://openreview.net/forum?id=6K2RM6wVqKu}
+}
+@article{gao2023uni,
+ title={Uni-qsar: an auto-ml tool for molecular property prediction},
+ author={Gao, Zhifeng and Ji, Xiaohong and Zhao, Guojiang and Wang, Hongshuai and Zheng, Hang and Ke, Guolin and Zhang, Linfeng},
+ journal={arXiv preprint arXiv:2304.12239},
+ year={2023}
+}
+```
+
+License
+-------
+> Modified from [Uni-Mol](https://github.com/deepmodeling/Uni-Mol)
+> Original license: MIT License
diff --git a/MindChem/applications/unimol/data_S0N_dE_updated.pkl b/MindChem/applications/unimol/data_S0N_dE_updated.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..904a08f6755058f67f659e73c89333dd991988a0
Binary files /dev/null and b/MindChem/applications/unimol/data_S0N_dE_updated.pkl differ
diff --git a/MindChem/applications/unimol/download_weights.sh b/MindChem/applications/unimol/download_weights.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f210f1272dfdec7d55ee526e16b82e9f1ce9b958
--- /dev/null
+++ b/MindChem/applications/unimol/download_weights.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+set -e
+
+DOWNLOAD_DIR="./unimol_tools/weights"
+
+if ! command -v wget &> /dev/null
+then
+ echo "Error: wget could not be found. Please install wget (sudo apt-get install wget)"
+ exit 1
+fi
+
+#目前只支持下载Uni-Mol1的模型权重,由于Uni-Mol2的训练流程还未调通,暂时无法提供下载
+# 需要把已经转换的mol_pre_all_h_220816.ckpt上传到服务器后再放开下面的注释
+# wget -P "${DOWNLOAD_DIR}/ca_model_weights/" \
+# https://tools.mindspore.cn/dataset/workspace/mindspore_ckpt/ckpt/unimol/model_weights/mol_pre_all_h_220816.ckpt
+
diff --git a/MindChem/applications/unimol/img/unimol2_arch.jpg b/MindChem/applications/unimol/img/unimol2_arch.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e98b48897db43d7c6427141c73c5b741cd888645
Binary files /dev/null and b/MindChem/applications/unimol/img/unimol2_arch.jpg differ
diff --git a/MindChem/applications/unimol/requirements.txt b/MindChem/applications/unimol/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ca40a7d0a6884014ce385f58c1b525a51c81cc40
--- /dev/null
+++ b/MindChem/applications/unimol/requirements.txt
@@ -0,0 +1,12 @@
+numpy==1.22.4
+pandas==1.4.0
+scikit-learn==1.5.0
+mindspore == 2.7.1
+joblib
+rdkit
+pyyaml
+addict
+tqdm
+numba
+openmind_hub
+sympy
\ No newline at end of file
diff --git a/MindChem/applications/unimol/setup.py b/MindChem/applications/unimol/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f22271e41bcc5d0bdec4d8d677a5c3e84537a26
--- /dev/null
+++ b/MindChem/applications/unimol/setup.py
@@ -0,0 +1,50 @@
+"""Install script for setuptools."""
+
+from setuptools import find_packages
+from setuptools import setup
+
+setup(
+ name="unimol_tools",
+ version="0.1.3.post1",
+ description=(
+ "unimol_tools is a Python package for property prediction with Uni-Mol in molecule, materials and protein."
+ ),
+ long_description=open('README.md').read(),
+ long_description_content_type='text/markdown',
+ author="DP Technology",
+ author_email="unimol@dp.tech",
+ license="The MIT License",
+ url="https://github.com/deepmodeling/Uni-Mol/unimol_tools",
+ packages=find_packages(
+ where='.',
+ exclude=[
+ "build",
+ "dist",
+ ],
+ ),
+ install_requires=[
+ "numpy<2.0.0,>=1.22.4",
+ "pandas<2.0.0",
+        "mindspore",
+ "joblib",
+ "rdkit",
+ "pyyaml",
+ "addict",
+ "scikit-learn",
+ "numba",
+ "tqdm",
+ ],
+    python_requires=">=3.10",
+    include_package_data=True,
+    classifiers=[
+        "Development Status :: 5 - Production/Stable",
+        "Intended Audience :: Science/Research",
+        "License :: OSI Approved :: MIT License",
+        "Operating System :: POSIX :: Linux",
+        "Programming Language :: Python :: 3",
+        "Programming Language :: Python :: 3.10",
+        "Programming Language :: Python :: 3.11",
+        "Programming Language :: Python :: 3.12",
+        "Topic :: Scientific/Engineering :: Artificial Intelligence",
+    ],
+)
diff --git a/MindChem/applications/unimol/train_Esipt.py b/MindChem/applications/unimol/train_Esipt.py
new file mode 100644
index 0000000000000000000000000000000000000000..85e43c81f17d2bafd72e2c3adb0ab898f3d1c9d3
--- /dev/null
+++ b/MindChem/applications/unimol/train_Esipt.py
@@ -0,0 +1,47 @@
+def main():
+ import pickle
+ with open('data_S0N_dE_updated.pkl', 'rb') as f:
+ datas_ = pickle.load(f)
+
+ datas = {}
+ datas['atoms'] = []
+ datas['coordinates'] = []
+ datas['deltaE*'] = []
+ datas['extra_features'] = [] # 新增
+
+ for x in datas_:
+ datas['atoms'].append(x['elements'])
+ datas['coordinates'].append(x['coordinates'])
+ datas['deltaE*'].append(x['deltaE*'])
+ datas['extra_features'].append(x['Add_feature']) # 新增
+
+ train_data = {}
+ train_data['atoms'] = datas['atoms'][:]
+ train_data['coordinates'] = datas['coordinates'][:]
+ train_data['deltaE*'] = datas['deltaE*'][:]
+ train_data['extra_features'] = datas['extra_features'][:]
+
+ from unimol_tools import MolTrain, MolPredict
+
+ clf = MolTrain(
+ task='regression',
+ target_cols=['deltaE*'],
+ data_type='molecule',
+ epochs=20,
+ kfold=5,
+ batch_size=10,
+ metrics='mse',
+ model_name='unimolv1',
+ model_size='84m',
+ )
+ pred = clf.fit(data = train_data)
+
+ # 加载模型进行预测
+ # clf = MolPredict(load_model='./exp')
+ # res = clf.predict(data = train_data)
+ # print(res)
+
+ # currently support data with smiles based csv/txt file
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/MindChem/applications/unimol/unimol_tools/__init__.py b/MindChem/applications/unimol/unimol_tools/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..18c3e995cab6d5d42ae3371dad039010a65a138d
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/__init__.py
@@ -0,0 +1,3 @@
+from .predict import MolPredict
+from .predictor import UniMolRepr
+from .train import MolTrain
diff --git a/MindChem/applications/unimol/unimol_tools/config/__init__.py b/MindChem/applications/unimol/unimol_tools/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a71b837c402179b1e0426d016c745fe9926a893
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/config/__init__.py
@@ -0,0 +1 @@
+from .model_config import MODEL_CONFIG, MODEL_CONFIG_V2
diff --git a/MindChem/applications/unimol/unimol_tools/config/default.yaml b/MindChem/applications/unimol/unimol_tools/config/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..317d320d6f4b7b686bd5b662e0341b97ce810b92
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/config/default.yaml
@@ -0,0 +1,23 @@
+### data
+smiles_col: "SMILES"
+target_col_prefix: "TARGET"
+target_normalize: "auto"
+anomaly_clean: False
+smi_strict: False
+### model
+model_name: "unimolv1"
+### trainer
+split_method: "5fold_random"
+split_seed: 42
+seed: 42
+logger_level: 1
+patience: 10
+max_epochs: 100
+learning_rate: 1e-4
+warmup_ratio: 0.03
+batch_size: 16
+max_norm: 5.0
+use_cuda: True
+use_amp: True
+use_ddp: True
+use_gpu: 0, 1
\ No newline at end of file
diff --git a/MindChem/applications/unimol/unimol_tools/config/model_config.py b/MindChem/applications/unimol/unimol_tools/config/model_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..daaea3a2ccbe3900d843fe2d60257c18f2148b63
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/config/model_config.py
@@ -0,0 +1,26 @@
+MODEL_CONFIG = {
+ "weight": {
+ "protein": "poc_pre_220816.ckpt",
+ "molecule_no_h": "mol_pre_no_h_220816.ckpt",
+ "molecule_all_h": "mol_pre_all_h_220816.ckpt",
+ "crystal": "mp_all_h_230313.ckpt",
+ "oled": "oled_pre_no_h_230101.ckpt",
+ },
+ "dict": {
+ "protein": "poc.dict.txt",
+ "molecule_no_h": "mol.dict.txt",
+ "molecule_all_h": "mol.dict.txt",
+ "crystal": "mp.dict.txt",
+ "oled": "oled.dict.txt",
+ },
+}
+
+MODEL_CONFIG_V2 = {
+ 'weight': {
+ '84m': 'modelzoo/84M/checkpoint.ckpt',
+ '164m': 'modelzoo/164M/checkpoint.ckpt',
+ '310m': 'modelzoo/310M/checkpoint.ckpt',
+ '570m': 'modelzoo/570M/checkpoint.ckpt',
+ '1.1B': 'modelzoo/1.1B/checkpoint.ckpt',
+ },
+}
diff --git a/MindChem/applications/unimol/unimol_tools/data/__init__.py b/MindChem/applications/unimol/unimol_tools/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c872a71ede54504a98885d04979001e3c7e3e7b
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/data/__init__.py
@@ -0,0 +1,2 @@
+from .datahub import DataHub
+from .dictionary import Dictionary
diff --git a/MindChem/applications/unimol/unimol_tools/data/conformer.py b/MindChem/applications/unimol/unimol_tools/data/conformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e80f4e1babe943d5fe2192d64596ee1a4cfd9621
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/data/conformer.py
@@ -0,0 +1,732 @@
+# Copyright (c) DP Technology.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import absolute_import, division, print_function
+
+import os
+import warnings
+
+import numpy as np
+from rdkit import Chem, RDLogger
+from rdkit.Chem import AllChem
+from scipy.spatial import distance_matrix
+
+RDLogger.DisableLog('rdApp.*')
+warnings.filterwarnings(action='ignore')
+from multiprocessing import Pool
+
+from numba import njit
+from tqdm import tqdm
+
+from ..config import MODEL_CONFIG
+from ..utils import logger
+from ..weights import WEIGHT_DIR, weight_download
+from .dictionary import Dictionary
+
+# https://github.com/snap-stanford/ogb/blob/master/ogb/utils/features.py
+# allowable multiple choice node and edge features
+allowable_features = {
+ "possible_atomic_num_list": list(range(1, 119)) + ["misc"],
+ "possible_chirality_list": [
+ "CHI_UNSPECIFIED",
+ "CHI_TETRAHEDRAL_CW",
+ "CHI_TETRAHEDRAL_CCW",
+ "CHI_TRIGONALBIPYRAMIDAL",
+ "CHI_OCTAHEDRAL",
+ "CHI_SQUAREPLANAR",
+ "CHI_OTHER",
+ ],
+ "possible_degree_list": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, "misc"],
+ "possible_formal_charge_list": [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, "misc"],
+ "possible_numH_list": [0, 1, 2, 3, 4, 5, 6, 7, 8, "misc"],
+ "possible_number_radical_e_list": [0, 1, 2, 3, 4, "misc"],
+ "possible_hybridization_list": ["SP", "SP2", "SP3", "SP3D", "SP3D2", "misc"],
+ "possible_is_aromatic_list": [False, True],
+ "possible_is_in_ring_list": [False, True],
+ "possible_bond_type_list": ["SINGLE", "DOUBLE", "TRIPLE", "AROMATIC", "misc"],
+ "possible_bond_stereo_list": [
+ "STEREONONE",
+ "STEREOZ",
+ "STEREOE",
+ "STEREOCIS",
+ "STEREOTRANS",
+ "STEREOANY",
+ ],
+ "possible_is_conjugated_list": [False, True],
+}
+
+
+class ConformerGen(object):
+ '''
+ This class designed to generate conformers for molecules represented as SMILES strings using provided parameters and configurations. The `transform` method uses multiprocessing to speed up the conformer generation process.
+ '''
+
+ def __init__(self, **params):
+ """
+ Initializes the neural network model based on the provided model name and parameters.
+
+ :param model_name: (str) The name of the model to initialize.
+ :param params: Additional parameters for model configuration.
+
+ :return: An instance of the specified neural network model.
+ :raises ValueError: If the model name is not recognized.
+ """
+ self._init_features(**params)
+
+ def _init_features(self, **params):
+ """
+ Initializes the features of the ConformerGen object based on provided parameters.
+
+ :param params: Arbitrary keyword arguments for feature configuration.
+ These can include the random seed, maximum number of atoms, data type,
+ generation method, generation mode, and whether to remove hydrogens.
+ """
+ self.seed = params.get('seed', 42)
+ self.max_atoms = params.get('max_atoms', 256)
+ self.data_type = params.get('data_type', 'molecule')
+ self.method = params.get('method', 'rdkit_random')
+ self.mode = params.get('mode', 'fast')
+ self.remove_hs = params.get('remove_hs', False)
+ if self.data_type == 'molecule':
+ name = "no_h" if self.remove_hs else "all_h"
+ name = self.data_type + '_' + name
+ self.dict_name = MODEL_CONFIG['dict'][name]
+ else:
+ self.dict_name = MODEL_CONFIG['dict'][self.data_type]
+ if not os.path.exists(os.path.join(WEIGHT_DIR, self.dict_name)):
+ weight_download(self.dict_name, WEIGHT_DIR)
+ self.dictionary = Dictionary.load(os.path.join(WEIGHT_DIR, self.dict_name))
+ self.dictionary.add_symbol("[MASK]", is_special=True)
+ if os.name == 'posix':
+ self.multi_process = params.get('multi_process', True)
+ else:
+ self.multi_process = params.get('multi_process', False)
+ if self.multi_process:
+ logger.warning(
+ 'Please use "if __name__ == "__main__":" to wrap the main function when using multi_process on Windows.'
+ )
+
+ def single_process(self, smiles):
+ """
+ Processes a single SMILES string to generate conformers using the specified method.
+
+ :param smiles: (str) The SMILES string representing the molecule.
+ :return: A unimolecular data representation (dictionary) of the molecule.
+ :raises ValueError: If the conformer generation method is unrecognized.
+ """
+ if self.method == 'rdkit_random':
+ atoms, coordinates, mol = inner_smi2coords(
+ smiles, seed=self.seed, mode=self.mode, remove_hs=self.remove_hs
+ )
+ feat = coords2unimol(
+ atoms,
+ coordinates,
+ self.dictionary,
+ self.max_atoms,
+ remove_hs=self.remove_hs,
+ )
+ return feat, mol
+ else:
+ raise ValueError(
+ 'Unknown conformer generation method: {}'.format(self.method)
+ )
+
+ def transform_raw(self, atoms_list, coordinates_list):
+
+ inputs = []
+ for atoms, coordinates in zip(atoms_list, coordinates_list):
+ inputs.append(
+ coords2unimol(
+ atoms,
+ coordinates,
+ self.dictionary,
+ self.max_atoms,
+ remove_hs=self.remove_hs,
+ )
+ )
+ return inputs
+
+ def transform_mols(self, mols_list):
+ inputs = []
+ for mol in mols_list:
+ atoms = np.array([atom.GetSymbol() for atom in mol.GetAtoms()])
+ coordinates = mol.GetConformer().GetPositions().astype(np.float32)
+ inputs.append(
+ coords2unimol(
+ atoms,
+ coordinates,
+ self.dictionary,
+ self.max_atoms,
+ remove_hs=self.remove_hs,
+ )
+ )
+ return inputs
+
+ def transform(self, smiles_list):
+ logger.info('Start generating conformers...')
+ if self.multi_process:
+ pool = Pool(processes=min(8, os.cpu_count()))
+ results = [
+ item for item in tqdm(pool.imap(self.single_process, smiles_list))
+ ]
+ pool.close()
+ else:
+ results = [self.single_process(smiles) for smiles in tqdm(smiles_list)]
+
+ inputs, mols = zip(*results)
+ inputs = list(inputs)
+ mols = list(mols)
+
+ failed_conf = [(item['src_coord'] == 0.0).all() for item in inputs]
+ logger.info(
+ 'Succeeded in generating conformers for {:.2f}% of molecules.'.format(
+ (1 - np.mean(failed_conf)) * 100
+ )
+ )
+ failed_conf_indices = [
+ index for index, value in enumerate(failed_conf) if value
+ ]
+ if len(failed_conf_indices) > 0:
+ logger.info('Failed conformers indices: {}'.format(failed_conf_indices))
+ logger.debug(
+ 'Failed conformers SMILES: {}'.format(
+ [smiles_list[index] for index in failed_conf_indices]
+ )
+ )
+
+ failed_conf_3d = [(item['src_coord'][:, 2] == 0.0).all() for item in inputs]
+ logger.info(
+ 'Succeeded in generating 3d conformers for {:.2f}% of molecules.'.format(
+ (1 - np.mean(failed_conf_3d)) * 100
+ )
+ )
+ failed_conf_3d_indices = [
+ index for index, value in enumerate(failed_conf_3d) if value
+ ]
+ if len(failed_conf_3d_indices) > 0:
+ logger.info(
+ 'Failed 3d conformers indices: {}'.format(failed_conf_3d_indices)
+ )
+ logger.debug(
+ 'Failed 3d conformers SMILES: {}'.format(
+ [smiles_list[index] for index in failed_conf_3d_indices]
+ )
+ )
+ return inputs, mols
+
+
+def inner_smi2coords(smi, seed=42, mode='fast', remove_hs=True, return_mol=False):
+ '''
+ This function is responsible for converting a SMILES (Simplified Molecular Input Line Entry System) string into 3D coordinates for each atom in the molecule. It also allows for the generation of 2D coordinates if 3D conformation generation fails, and optionally removes hydrogen atoms and their coordinates from the resulting data.
+
+ :param smi: (str) The SMILES representation of the molecule.
+ :param seed: (int, optional) The random seed for conformation generation. Defaults to 42.
+ :param mode: (str, optional) The mode of conformation generation, 'fast' for quick generation, 'heavy' for more attempts. Defaults to 'fast'.
+ :param remove_hs: (bool, optional) Whether to remove hydrogen atoms from the final coordinates. Defaults to True.
+
+ :return: A tuple containing the list of atom symbols and their corresponding 3D coordinates.
+ :raises AssertionError: If no atoms are present in the molecule or if the coordinates do not align with the atom count.
+ '''
+ mol = Chem.MolFromSmiles(smi)
+ mol = AllChem.AddHs(mol)
+ atoms = [atom.GetSymbol() for atom in mol.GetAtoms()]
+ assert len(atoms) > 0, 'No atoms in molecule: {}'.format(smi)
+ try:
+ # will random generate conformer with seed equal to -1. else fixed random seed.
+ res = AllChem.EmbedMolecule(mol, randomSeed=seed)
+ if res == 0:
+ try:
+ # some conformer can not use MMFF optimize
+ AllChem.MMFFOptimizeMolecule(mol)
+ coordinates = mol.GetConformer().GetPositions().astype(np.float32)
+ except:
+ coordinates = mol.GetConformer().GetPositions().astype(np.float32)
+ ## for fast test... ignore this ###
+ elif res == -1 and mode == 'heavy':
+ AllChem.EmbedMolecule(mol, maxAttempts=5000, randomSeed=seed)
+ try:
+ # some conformer can not use MMFF optimize
+ AllChem.MMFFOptimizeMolecule(mol)
+ coordinates = mol.GetConformer().GetPositions().astype(np.float32)
+ except:
+ AllChem.Compute2DCoords(mol)
+ coordinates_2d = mol.GetConformer().GetPositions().astype(np.float32)
+ coordinates = coordinates_2d
+ else:
+ AllChem.Compute2DCoords(mol)
+ coordinates_2d = mol.GetConformer().GetPositions().astype(np.float32)
+ coordinates = coordinates_2d
+ except:
+ print("Failed to generate conformer, replace with zeros.")
+ coordinates = np.zeros((len(atoms), 3))
+
+ if return_mol:
+ return mol # for unimolv2
+
+ assert len(atoms) == len(
+ coordinates
+ ), "coordinates shape is not align with {}".format(smi)
+ if remove_hs:
+ idx = [i for i, atom in enumerate(atoms) if atom != 'H']
+ atoms_no_h = [atom for atom in atoms if atom != 'H']
+ coordinates_no_h = coordinates[idx]
+ assert len(atoms_no_h) == len(
+ coordinates_no_h
+ ), "coordinates shape is not align with {}".format(smi)
+ return atoms_no_h, coordinates_no_h, mol
+ else:
+ return atoms, coordinates, mol
+
+
+def inner_coords(atoms, coordinates, remove_hs=True):
+ """
+ Processes a list of atoms and their corresponding coordinates to remove hydrogen atoms if specified.
+ This function takes a list of atom symbols and their corresponding coordinates and optionally removes hydrogen atoms from the output. It includes assertions to ensure the integrity of the data and uses numpy for efficient processing of the coordinates.
+
+ :param atoms: (list) A list of atom symbols (e.g., ['C', 'H', 'O']).
+ :param coordinates: (list of tuples or list of lists) Coordinates corresponding to each atom in the `atoms` list.
+ :param remove_hs: (bool, optional) A flag to indicate whether hydrogen atoms should be removed from the output.
+ Defaults to True.
+
+ :return: A tuple containing two elements; the filtered list of atom symbols and their corresponding coordinates.
+ If `remove_hs` is False, the original lists are returned.
+
+ :raises AssertionError: If the length of `atoms` list does not match the length of `coordinates` list.
+ """
+ assert len(atoms) == len(coordinates), "coordinates shape is not align atoms"
+ coordinates = np.array(coordinates).astype(np.float32)
+ if remove_hs:
+ idx = [i for i, atom in enumerate(atoms) if atom != 'H']
+ atoms_no_h = [atom for atom in atoms if atom != 'H']
+ coordinates_no_h = coordinates[idx]
+ assert len(atoms_no_h) == len(
+ coordinates_no_h
+ ), "coordinates shape is not align with atoms"
+ return atoms_no_h, coordinates_no_h
+ else:
+ return atoms, coordinates
+
+
+def coords2unimol(
+ atoms, coordinates, dictionary, max_atoms=256, remove_hs=True, **params
+):
+ """
+ Converts atom symbols and coordinates into a unified molecular representation.
+
+ :param atoms: (list) List of atom symbols.
+ :param coordinates: (ndarray) Array of atomic coordinates.
+ :param dictionary: (Dictionary) An object that maps atom symbols to unique integers.
+ :param max_atoms: (int) The maximum number of atoms to consider for the molecule.
+ :param remove_hs: (bool) Whether to remove hydrogen atoms from the representation.
+ :param params: Additional parameters.
+
+ :return: A dictionary containing the molecular representation with tokens, distances, coordinates, and edge types.
+ """
+ atoms, coordinates = inner_coords(atoms, coordinates, remove_hs=remove_hs)
+ atoms = np.array(atoms)
+ coordinates = np.array(coordinates).astype(np.float32)
+ # cropping atoms and coordinates
+ if len(atoms) > max_atoms:
+ idx = np.random.choice(len(atoms), max_atoms, replace=False)
+ atoms = atoms[idx]
+ coordinates = coordinates[idx]
+ # tokens padding
+ src_tokens = np.array(
+ [dictionary.bos()]
+ + [dictionary.index(atom) for atom in atoms]
+ + [dictionary.eos()]
+ )
+ src_distance = np.zeros((len(src_tokens), len(src_tokens)))
+ # coordinates normalize & padding
+ src_coord = coordinates - coordinates.mean(axis=0)
+ src_coord = np.concatenate([np.zeros((1, 3)), src_coord, np.zeros((1, 3))], axis=0)
+ # distance matrix
+ src_distance = distance_matrix(src_coord, src_coord)
+ # edge type
+ src_edge_type = src_tokens.reshape(-1, 1) * len(dictionary) + src_tokens.reshape(
+ 1, -1
+ )
+
+ return {
+ 'src_tokens': src_tokens.astype(int),
+ 'src_distance': src_distance.astype(np.float32),
+ 'src_coord': src_coord.astype(np.float32),
+ 'src_edge_type': src_edge_type.astype(int),
+ }
+
+
+class UniMolV2Feature(object):
+ '''
+ This class is responsible for generating features for molecules represented as SMILES strings. It uses the ConformerGen class to generate conformers for the molecules and converts the resulting atom symbols and coordinates into a unified molecular representation.
+ '''
+
+ def __init__(self, **params):
+ """
+ Initializes the neural network model based on the provided model name and parameters.
+
+ :param model_name: (str) The name of the model to initialize.
+ :param params: Additional parameters for model configuration.
+
+ :return: An instance of the specified neural network model.
+ :raises ValueError: If the model name is not recognized.
+ """
+ self._init_features(**params)
+
+ def _init_features(self, **params):
+ """
+ Initializes the features of the UniMolV2Feature object based on provided parameters.
+
+ :param params: Arbitrary keyword arguments for feature configuration.
+ These can include the random seed, maximum number of atoms, data type,
+ generation method, generation mode, and whether to remove hydrogens.
+ """
+ self.seed = params.get('seed', 42)
+ self.max_atoms = params.get('max_atoms', 128)
+ self.data_type = params.get('data_type', 'molecule')
+ self.method = params.get('method', 'rdkit_random')
+ self.mode = params.get('mode', 'fast')
+ self.remove_hs = params.get('remove_hs', True)
+ if os.name == 'posix':
+ self.multi_process = params.get('multi_process', True)
+ else:
+ self.multi_process = params.get('multi_process', False)
+ if self.multi_process:
+ logger.warning(
+ 'Please use "if __name__ == "__main__":" to wrap the main function when using multi_process on Windows.'
+ )
+
+ def single_process(self, smiles):
+ """
+ Processes a single SMILES string to generate conformers using the specified method.
+
+ :param smiles: (str) The SMILES string representing the molecule.
+ :return: A unimolecular data representation (dictionary) of the molecule.
+ :raises ValueError: If the conformer generation method is unrecognized.
+ """
+ if self.method == 'rdkit_random':
+ mol = inner_smi2coords(
+ smiles,
+ seed=self.seed,
+ mode=self.mode,
+ remove_hs=self.remove_hs,
+ return_mol=True,
+ )
+ feat = mol2unimolv2(mol, self.max_atoms, remove_hs=self.remove_hs)
+ return feat, mol
+ else:
+ raise ValueError(
+ 'Unknown conformer generation method: {}'.format(self.method)
+ )
+
+ def transform_raw(self, atoms_list, coordinates_list):
+
+ inputs = []
+ for atoms, coordinates in zip(atoms_list, coordinates_list):
+ mol = create_mol_from_atoms_and_coords(atoms, coordinates)
+ inputs.append(mol2unimolv2(mol, self.max_atoms, remove_hs=self.remove_hs))
+ return inputs
+
+ def transform_mols(self, mols_list):
+ inputs = []
+ for mol in mols_list:
+ inputs.append(mol2unimolv2(mol, self.max_atoms, remove_hs=self.remove_hs))
+ return inputs
+
+ def transform(self, smiles_list):
+ logger.info('Start generating conformers...')
+ if self.multi_process:
+ pool = Pool(processes=min(8, os.cpu_count()))
+ results = [
+ item for item in tqdm(pool.imap(self.single_process, smiles_list))
+ ]
+ pool.close()
+ else:
+ results = [self.single_process(smiles) for smiles in tqdm(smiles_list)]
+
+ inputs, mols = zip(*results)
+ inputs = list(inputs)
+ mols = list(mols)
+
+ failed_conf = [(item['src_coord'] == 0.0).all() for item in inputs]
+ logger.info(
+ 'Succeeded in generating conformers for {:.2f}% of molecules.'.format(
+ (1 - np.mean(failed_conf)) * 100
+ )
+ )
+ failed_conf_indices = [
+ index for index, value in enumerate(failed_conf) if value
+ ]
+ if len(failed_conf_indices) > 0:
+ logger.info('Failed conformers indices: {}'.format(failed_conf_indices))
+ logger.debug(
+ 'Failed conformers SMILES: {}'.format(
+ [smiles_list[index] for index in failed_conf_indices]
+ )
+ )
+
+ failed_conf_3d = [(item['src_coord'][:, 2] == 0.0).all() for item in inputs]
+ logger.info(
+ 'Succeeded in generating 3d conformers for {:.2f}% of molecules.'.format(
+ (1 - np.mean(failed_conf_3d)) * 100
+ )
+ )
+ failed_conf_3d_indices = [
+ index for index, value in enumerate(failed_conf_3d) if value
+ ]
+ if len(failed_conf_3d_indices) > 0:
+ logger.info(
+ 'Failed 3d conformers indices: {}'.format(failed_conf_3d_indices)
+ )
+ logger.debug(
+ 'Failed 3d conformers SMILES: {}'.format(
+ [smiles_list[index] for index in failed_conf_3d_indices]
+ )
+ )
+
+ return inputs, mols
+
+
+def create_mol_from_atoms_and_coords(atoms, coordinates):
+ """
+ Creates an RDKit molecule object from a list of atom symbols and their corresponding coordinates.
+
+ :param atoms: (list) Atom symbols for the molecule.
+ :param coordinates: (list) Atomic coordinates for the molecule.
+ :return: RDKit molecule object.
+ """
+ mol = Chem.RWMol()
+ atom_indices = []
+
+ for atom in atoms:
+ atom_idx = mol.AddAtom(Chem.Atom(atom))
+ atom_indices.append(atom_idx)
+
+ conf = Chem.Conformer(len(atoms))
+ for i, coord in enumerate(coordinates):
+ conf.SetAtomPosition(i, coord)
+
+ mol.AddConformer(conf)
+ Chem.SanitizeMol(mol)
+ return mol
+
+
+def mol2unimolv2(mol, max_atoms=128, remove_hs=True, **params):
+ """
+ Converts atom symbols and coordinates into a unified molecular representation.
+
+ :param mol: (rdkit.Chem.Mol) The molecule object containing atom symbols and coordinates.
+ :param max_atoms: (int) The maximum number of atoms to consider for the molecule.
+ :param remove_hs: (bool) Whether to remove hydrogen atoms from the representation. This must be True for UniMolV2.
+ :param params: Additional parameters.
+
+ :return: A batched data containing the molecular representation.
+ """
+
+ mol = AllChem.RemoveAllHs(mol)
+ atoms = np.array([atom.GetSymbol() for atom in mol.GetAtoms()])
+ coordinates = mol.GetConformer().GetPositions().astype(np.float32)
+
+ # cropping atoms and coordinates
+ if len(atoms) > max_atoms:
+ mask = np.zeros(len(atoms), dtype=bool)
+ mask[:max_atoms] = True
+ np.random.shuffle(mask) # shuffle the mask
+ atoms = atoms[mask]
+ coordinates = coordinates[mask]
+ else:
+ mask = np.ones(len(atoms), dtype=bool)
+ # tokens padding
+ src_tokens = [AllChem.GetPeriodicTable().GetAtomicNumber(item) for item in atoms]
+ src_coord = coordinates
+ #
+ node_attr, edge_index, edge_attr = get_graph(mol)
+ feat = get_graph_features(edge_attr, edge_index, node_attr, drop_feat=0, mask=mask)
+ feat['src_tokens'] = src_tokens
+ feat['src_coord'] = src_coord
+ return feat
+
+
+def safe_index(l, e):
+ """
+ Return index of element e in list l. If e is not present, return the last index
+ """
+ try:
+ return l.index(e)
+ except:
+ return len(l) - 1
+
+
+def atom_to_feature_vector(atom):
+ """
+ Converts rdkit atom object to feature list of indices
+ :param mol: rdkit atom object
+ :return: list
+ """
+ atom_feature = [
+ safe_index(allowable_features["possible_atomic_num_list"], atom.GetAtomicNum()),
+ allowable_features["possible_chirality_list"].index(str(atom.GetChiralTag())),
+ safe_index(allowable_features["possible_degree_list"], atom.GetTotalDegree()),
+ safe_index(
+ allowable_features["possible_formal_charge_list"], atom.GetFormalCharge()
+ ),
+ safe_index(allowable_features["possible_numH_list"], atom.GetTotalNumHs()),
+ safe_index(
+ allowable_features["possible_number_radical_e_list"],
+ atom.GetNumRadicalElectrons(),
+ ),
+ safe_index(
+ allowable_features["possible_hybridization_list"],
+ str(atom.GetHybridization()),
+ ),
+ allowable_features["possible_is_aromatic_list"].index(atom.GetIsAromatic()),
+ allowable_features["possible_is_in_ring_list"].index(atom.IsInRing()),
+ ]
+ return atom_feature
+
+
+def bond_to_feature_vector(bond):
+ """
+ Converts rdkit bond object to feature list of indices
+ :param mol: rdkit bond object
+ :return: list
+ """
+ bond_feature = [
+ safe_index(
+ allowable_features["possible_bond_type_list"], str(bond.GetBondType())
+ ),
+ allowable_features["possible_bond_stereo_list"].index(str(bond.GetStereo())),
+ allowable_features["possible_is_conjugated_list"].index(bond.GetIsConjugated()),
+ ]
+ return bond_feature
+
+
+def get_graph(mol):
+ """
+ Converts SMILES string to graph Data object
+ :input: SMILES string (str)
+ :return: graph object
+ """
+ atom_features_list = []
+ for atom in mol.GetAtoms():
+ atom_features_list.append(atom_to_feature_vector(atom))
+ x = np.array(atom_features_list, dtype=np.int32)
+ # bonds
+ num_bond_features = 3 # bond type, bond stereo, is_conjugated
+ if len(mol.GetBonds()) > 0: # mol has bonds
+ edges_list = []
+ edge_features_list = []
+ for bond in mol.GetBonds():
+ i = bond.GetBeginAtomIdx()
+ j = bond.GetEndAtomIdx()
+ edge_feature = bond_to_feature_vector(bond)
+ # add edges in both directions
+ edges_list.append((i, j))
+ edge_features_list.append(edge_feature)
+ edges_list.append((j, i))
+ edge_features_list.append(edge_feature)
+ # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
+ edge_index = np.array(edges_list, dtype=np.int32).T
+ # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
+ edge_attr = np.array(edge_features_list, dtype=np.int32)
+
+ else: # mol has no bonds
+ edge_index = np.empty((2, 0), dtype=np.int32)
+ edge_attr = np.empty((0, num_bond_features), dtype=np.int32)
+ return x, edge_index, edge_attr
+
+
+def get_graph_features(edge_attr, edge_index, node_attr, drop_feat, mask):
+    """
+    Assemble the per-molecule graph feature dict consumed downstream.
+
+    :param edge_attr: (np.ndarray) [num_edges, 3] int edge features from get_graph.
+    :param edge_index: (np.ndarray) [2, num_edges] COO connectivity.
+    :param node_attr: (np.ndarray) [N, num_atom_feats] int node features;
+        column 0 (atomic number) is excluded from the embedded atom features.
+    :param drop_feat: When truthy, blank all features to fixed sentinel values.
+    :param mask: (np.ndarray) [N] boolean mask of atoms kept after cropping.
+
+    :return: (dict) keys 'atom_feat', 'atom_mask', 'edge_feat',
+        'shortest_path', 'degree', 'pair_type', 'attn_bias'.
+    """
+    # atom_feat_sizes = [128] + [16 for _ in range(8)]
+    atom_feat_sizes = [16 for _ in range(8)]
+    edge_feat_sizes = [16, 16, 16]
+    edge_attr, edge_index, x = edge_attr, edge_index, node_attr
+    N = x.shape[0]
+
+    # atom feature here.
+    # NOTE: convert_to_single_emb offsets columns in place; x[:, 1:] is a view,
+    # so x itself is modified too (x is not read again afterwards).
+    atom_feat = convert_to_single_emb(x[:, 1:], atom_feat_sizes)
+
+    # node adj matrix [N, N] bool
+    adj = np.zeros([N, N], dtype=np.int32)
+    adj[edge_index[0, :], edge_index[1, :]] = 1
+    degree = adj.sum(axis=-1)
+
+    # edge feature here; dense [N, N, 3] with 0 meaning "no edge" (hence the +1)
+    if len(edge_attr.shape) == 1:
+        edge_attr = edge_attr[:, None]
+    edge_feat = np.zeros([N, N, edge_attr.shape[-1]], dtype=np.int32)
+    edge_feat[edge_index[0, :], edge_index[1, :]] = (
+        convert_to_single_emb(edge_attr, edge_feat_sizes) + 1
+    )
+    shortest_path_result = floyd_warshall(adj)
+    # max distance is 509
+    if drop_feat:
+        # feature dropout: overwrite everything with fixed sentinel values
+        atom_feat[...] = 1
+        edge_feat[...] = 1
+        degree[...] = 1
+        shortest_path_result[...] = 511
+    else:
+        # shift by 2 (resp. 1) to reserve low indices for padding/sentinels
+        atom_feat = atom_feat + 2
+        edge_feat = edge_feat + 2
+        degree = degree + 2
+        shortest_path_result = shortest_path_result + 1
+
+    # combine, plus 1 for padding; pairwise features are masked on both axes
+    feat = {}
+    feat["atom_feat"] = atom_feat[mask]
+    feat["atom_mask"] = np.ones(N, dtype=np.int64)[mask]
+    feat["edge_feat"] = edge_feat[mask][:, mask]
+    feat["shortest_path"] = shortest_path_result[mask][:, mask]
+    feat["degree"] = degree.reshape(-1)[mask]
+    # pair-type: (atom_i, atom_j) feature pairs for every kept atom pair
+    atoms = atom_feat[..., 0]
+    pair_type = np.concatenate(
+        [
+            np.expand_dims(atoms, axis=(1, 2)).repeat(N, axis=1),
+            np.expand_dims(atoms, axis=(0, 2)).repeat(N, axis=0),
+        ],
+        axis=-1,
+    )
+    pair_type = pair_type[mask][:, mask]
+    feat["pair_type"] = convert_to_single_emb(pair_type, [128, 128])
+    # one extra row/column — presumably for a prepended special token;
+    # NOTE(review): confirm against the model's attention-bias consumer.
+    feat["attn_bias"] = np.zeros((mask.sum() + 1, mask.sum() + 1), dtype=np.float32)
+    return feat
+
+
+def convert_to_single_emb(x, sizes):
+ assert x.shape[-1] == len(sizes)
+ offset = 1
+ for i in range(len(sizes)):
+ assert (x[..., i] < sizes[i]).all()
+ x[..., i] = x[..., i] + offset
+ offset += sizes[i]
+ return x
+
+
+@njit
+def floyd_warshall(M):
+    """
+    All-pairs shortest path lengths on an adjacency matrix (numba-compiled).
+
+    :param M: (np.ndarray) [n, n] square adjacency matrix; nonzero = edge of
+        cost given by the stored value. Modified in place and also returned.
+    :return: (np.ndarray) [n, n] shortest-path distances; unreachable pairs
+        (and anything >= 510) are clamped to 510 so downstream embedding
+        lookups stay within range.
+    """
+    (nrows, ncols) = M.shape
+    assert nrows == ncols
+    n = nrows
+    # set unreachable nodes distance to 510
+    for i in range(n):
+        for j in range(n):
+            if M[i, j] == 0:
+                M[i, j] = 510
+
+    # distance from a node to itself is zero
+    for i in range(n):
+        M[i, i] = 0
+
+    # floyed algo: standard O(n^3) relaxation over intermediate node k
+    for k in range(n):
+        for i in range(n):
+            for j in range(n):
+                cost_ikkj = M[i, k] + M[k, j]
+                if M[i, j] > cost_ikkj:
+                    M[i, j] = cost_ikkj
+
+    # re-clamp: sums of two 510 sentinels can exceed 510
+    for i in range(n):
+        for j in range(n):
+            if M[i, j] >= 510:
+                M[i, j] = 510
+    return M
diff --git a/MindChem/applications/unimol/unimol_tools/data/datahub.py b/MindChem/applications/unimol/unimol_tools/data/datahub.py
new file mode 100644
index 0000000000000000000000000000000000000000..29e574b2596207db1a61b820433f1fb8d4b62cc3
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/data/datahub.py
@@ -0,0 +1,198 @@
+# Copyright (c) DP Technology.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import absolute_import, division, print_function
+
+import os
+import numpy as np
+from rdkit.Chem import PandasTools
+
+from ..utils import logger
+from .conformer import ConformerGen, UniMolV2Feature
+from .datareader import MolDataReader
+from .datascaler import TargetScaler
+from .split import Splitter
+
+
+class DataHub(object):
+ """
+ The DataHub class is responsible for storing and preprocessing data for machine learning tasks.
+ It initializes with configuration options to handle different types of tasks such as regression,
+ classification, and others. It also supports data scaling and handling molecular data.
+ """
+
+ def __init__(self, data=None, is_train=True, save_path=None, **params):
+ """
+ Initializes the DataHub instance with data and configuration for the ML task.
+
+ :param data: Initial dataset to be processed.
+ :param is_train: (bool) Indicates if the DataHub is being used for training.
+ :param save_path: (str) Path to save any necessary files, like scalers.
+ :param params: Additional parameters for data preprocessing and model configuration.
+ """
+ self.raw_data = data
+ self.is_train = is_train
+ self.save_path = save_path
+ self.task = params.get('task', None)
+ self.target_cols = params.get('target_cols', None)
+ self.multiclass_cnt = params.get('multiclass_cnt', None)
+ self.ss_method = params.get('target_normalize', 'none')
+ self.conf_cache_level = params.get('conf_cache_level', 1)
+ self._init_data(**params)
+ self._init_split(**params)
+
+ def _init_data(self, **params):
+ """
+ Initializes and preprocesses the data based on the task and parameters provided.
+
+ This method handles reading raw data, scaling targets, and transforming data for use with
+ molecular inputs. It tailors the preprocessing steps based on the task type, such as regression
+ or classification.
+
+ :param params: Additional parameters for data processing.
+ :raises ValueError: If the task type is unknown.
+ """
+ self.data = MolDataReader().read_data(self.raw_data, self.is_train, **params)
+ self.data['target_scaler'] = TargetScaler(
+ self.ss_method, self.task, self.save_path
+ )
+ if self.task == 'regression':
+ target = np.array(self.data['raw_target']).reshape(-1, 1).astype(np.float32)
+ if self.is_train:
+ self.data['target_scaler'].fit(target, self.save_path)
+ self.data['target'] = self.data['target_scaler'].transform(target)
+ else:
+ self.data['target'] = target
+ elif self.task == 'classification':
+ target = np.array(self.data['raw_target']).reshape(-1, 1).astype(np.int32)
+ self.data['target'] = target
+ elif self.task == 'multiclass':
+ target = np.array(self.data['raw_target']).reshape(-1, 1).astype(np.int32)
+ self.data['target'] = target
+ if not self.is_train:
+ self.data['multiclass_cnt'] = self.multiclass_cnt
+ elif self.task == 'multilabel_regression':
+ target = (
+ np.array(self.data['raw_target'])
+ .reshape(-1, self.data['num_classes'])
+ .astype(np.float32)
+ )
+ if self.is_train:
+ self.data['target_scaler'].fit(target, self.save_path)
+ self.data['target'] = self.data['target_scaler'].transform(target)
+ else:
+ self.data['target'] = target
+ elif self.task == 'multilabel_classification':
+ target = (
+ np.array(self.data['raw_target'])
+ .reshape(-1, self.data['num_classes'])
+ .astype(np.int32)
+ )
+ self.data['target'] = target
+ elif self.task == 'repr':
+ self.data['target'] = self.data['raw_target']
+ else:
+ raise ValueError('Unknown task: {}'.format(self.task))
+
+ if params.get('model_name', None) == 'unimolv1':
+ if 'mols' in self.data:
+ no_h_list = ConformerGen(**params).transform_mols(self.data['mols'])
+ mols = None
+ elif 'atoms' in self.data and 'coordinates' in self.data:
+ no_h_list = ConformerGen(**params).transform_raw(
+ self.data['atoms'], self.data['coordinates']
+ )
+ mols = None
+ else:
+ smiles_list = self.data["smiles"]
+ no_h_list, mols = ConformerGen(**params).transform(smiles_list)
+ elif params.get('model_name', None) == 'unimolv2':
+ if 'mols' in self.data:
+ no_h_list = UniMolV2Feature(**params).transform_mols(self.data['mols'])
+ mols = None
+ elif 'atoms' in self.data and 'coordinates' in self.data:
+ no_h_list = UniMolV2Feature(**params).transform_raw(
+ self.data['atoms'], self.data['coordinates']
+ )
+ mols = None
+ else:
+ smiles_list = self.data["smiles"]
+ no_h_list, mols = UniMolV2Feature(**params).transform(smiles_list)
+
+ self.data['unimol_input'] = no_h_list
+
+ if mols is not None:
+ self.save_mol2sdf(self.data['raw_data'], mols, params)
+
+ def _init_split(self, **params):
+
+ self.split_method = params.get('split_method', '5fold_random')
+ kfold, method = (
+ int(self.split_method.split('fold')[0]),
+ self.split_method.split('_')[-1],
+ ) # Nfold_xxxx
+ self.kfold = params.get('kfold', kfold)
+ self.method = params.get('split', method)
+ self.split_seed = params.get('split_seed', 42)
+ self.data['kfold'] = self.kfold
+ if not self.is_train:
+ return
+ self.splitter = Splitter(self.method, self.kfold, seed=self.split_seed)
+ split_nfolds = self.splitter.split(**self.data)
+ if self.kfold == 1:
+ logger.info(f"Kfold is 1, all data is used for training.")
+ else:
+ logger.info(f"Split method: {self.method}, fold: {self.kfold}")
+ nfolds = np.zeros(len(split_nfolds[0][0]) + len(split_nfolds[0][1]), dtype=int)
+ for enu, (tr_idx, te_idx) in enumerate(split_nfolds):
+ nfolds[te_idx] = enu
+ self.data['split_nfolds'] = split_nfolds
+ return split_nfolds
+
+ def save_mol2sdf(self, data, mols, params):
+ """
+ Save the conformers to a SDF file.
+
+ :param data: DataFrame containing the raw data.
+ :param mols: List of RDKit molecule objects.
+ """
+ if isinstance(self.raw_data, str):
+ base_name = os.path.splitext(os.path.basename(self.raw_data))[0]
+ elif isinstance(self.raw_data, list) or isinstance(self.raw_data, np.ndarray):
+ # If the raw_data is a list of smiles, we can use a default name.
+ base_name = 'unimol_conformers'
+ else:
+ logger.warning('Warning: raw_data is not a path or list, cannot save sdf.')
+ return
+ if params.get('sdf_save_path') is None:
+ if self.save_path is not None:
+ params['sdf_save_path'] = self.save_path
+ else:
+ return
+ save_path = os.path.join(params.get('sdf_save_path'), f"{base_name}.sdf")
+ if self.conf_cache_level == 0:
+ logger.warning(f"conf_cache_level is 0, do not save conformers.")
+ return
+ elif self.conf_cache_level == 1 and os.path.exists(save_path):
+ logger.warning(f"conf_cache_level is 1, but {save_path} exists, so do not save conformers.")
+ return
+ elif self.conf_cache_level == 2 or not os.path.exists(save_path):
+ logger.info(f"conf_cache_level is {self.conf_cache_level}, saving conformers to {save_path}.")
+ else:
+ logger.warning(f"Unknown conf_cache_level: {self.conf_cache_level}, do not saving conformers.")
+ return
+ sdf_result = data.copy()
+ sdf_result['ROMol'] = mols
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+ try:
+ PandasTools.WriteSDF(
+ sdf_result,
+ save_path,
+ properties=list(sdf_result.columns),
+ idName='RowID',
+ )
+ logger.info(f"Successfully saved sdf file to {save_path}")
+ except Exception as e:
+ logger.warning(f"Failed to write sdf file: {e}")
+ pass
diff --git a/MindChem/applications/unimol/unimol_tools/data/datareader.py b/MindChem/applications/unimol/unimol_tools/data/datareader.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ded184caef551c777660fab4daf28efa8070144
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/data/datareader.py
@@ -0,0 +1,228 @@
+# Copyright (c) DP Technology.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import absolute_import, division, print_function
+
+import os
+import pathlib
+
+import numpy as np
+import pandas as pd
+from rdkit import Chem
+from rdkit.Chem import PandasTools
+from rdkit.Chem.Scaffolds import MurckoScaffold
+
+from ..utils import logger
+
+
+class MolDataReader(object):
+ '''A class to read Mol Data.'''
+
+ def read_data(self, data=None, is_train=True, **params):
+ # TO DO
+ # 1. add anomaly detection & outlier removal.
+ # 2. add support for other file format.
+ # 3. add support for multi tasks.
+
+ """
+ Reads and preprocesses molecular data from various input formats for model training or prediction.
+ Parsing target columns
+ 1. if target_cols is not None, use target_cols as target columns.
+ 2. if target_cols is None, use all columns with prefix 'target_col_prefix' as target columns.
+ 3. use given target_cols as target columns placeholder with value -1.0 for predict
+
+ :param data: The input molecular data. Can be a file path (str), a dictionary, or a list of SMILES strings.
+ :param is_train: (bool) A flag indicating if the operation is for training. Determines data processing steps.
+ :param params: A dictionary of additional parameters for data processing.
+
+ :return: A dictionary containing processed data and related information for model consumption.
+ :raises ValueError: If the input data type is not supported or if any SMILES string is invalid (when strict).
+ """
+ task = params.get('task', None)
+ target_cols = params.get('target_cols', None)
+ smiles_col = params.get('smiles_col', 'SMILES')
+ target_col_prefix = params.get('target_col_prefix', 'TARGET')
+ anomaly_clean = params.get('anomaly_clean', False)
+ smi_strict = params.get('smi_strict', False)
+ split_group_col = params.get('split_group_col', 'scaffold')
+
+ if isinstance(data, str):
+ # load from file
+ self.data_path = data
+ if data.endswith('.sdf'):
+ # load sdf file
+ data = PandasTools.LoadSDF(data)
+ elif data.endswith('.csv'):
+ data = pd.read_csv(self.data_path)
+ else:
+ raise ValueError('Unknown file type: {}'.format(data))
+ elif isinstance(data, dict):
+ # load from dict
+ if 'target' in data:
+ label = np.array(data['target'])
+ if len(label.shape) == 1 or label.shape[1] == 1:
+ data[target_col_prefix] = label.reshape(-1)
+ else:
+ for i in range(label.shape[1]):
+ data[target_col_prefix + str(i)] = label[:, i]
+
+ _ = data.pop('target', None)
+ data = pd.DataFrame(data).rename(columns={smiles_col: 'SMILES'})
+
+ elif isinstance(data, list) or isinstance(data, np.ndarray):
+ # load from smiles list
+ data = pd.DataFrame(data, columns=['SMILES'])
+ else:
+ raise ValueError('Unknown data type: {}'.format(type(data)))
+
+ #### parsing target columns
+ #### 1. if target_cols is not None, use target_cols as target columns.
+ #### 2. if target_cols is None, use all columns with prefix 'target_col_prefix' as target columns.
+ #### 3. use given target_cols as target columns placeholder with value -1.0 for predict
+ if task == 'repr':
+ # placeholder for repr task
+ targets = None
+ target_cols = None
+ num_classes = None
+ multiclass_cnt = None
+ else:
+ if target_cols is None:
+ target_cols = [
+ item for item in data.columns if item.startswith(target_col_prefix)
+ ]
+ elif isinstance(target_cols, str):
+ target_cols = [target_col.strip() for target_col in target_cols.split(',')]
+ elif isinstance(target_cols, list):
+ pass
+ else:
+ raise ValueError(
+ 'Unknown target_cols type: {}'.format(type(target_cols))
+ )
+
+ if is_train:
+ if anomaly_clean:
+ data = self.anomaly_clean(data, task, target_cols)
+ if task == 'multiclass':
+ multiclass_cnt = int(data[target_cols].max() + 1)
+ else:
+ for col in target_cols:
+ if col not in data.columns or data[col].isnull().any():
+ data[col] = -1.0
+
+ targets = data[target_cols].values.tolist()
+ num_classes = len(target_cols)
+
+ dd = {
+ 'raw_data': data,
+ 'raw_target': targets,
+ 'num_classes': num_classes,
+ 'target_cols': target_cols,
+ 'multiclass_cnt': (
+ multiclass_cnt if task == 'multiclass' and is_train else None
+ ),
+ }
+ if smiles_col in data.columns:
+ mask = data[smiles_col].apply(
+ lambda smi: self.check_smiles(smi, is_train, smi_strict)
+ )
+ data = data[mask]
+ dd['smiles'] = data[smiles_col].tolist()
+ dd['scaffolds'] = data[smiles_col].map(self.smi2scaffold).tolist()
+ else:
+ dd['smiles'] = None
+ dd['scaffolds'] = None
+
+ if split_group_col in data.columns:
+ dd['group'] = data[split_group_col].tolist()
+ elif split_group_col == 'scaffold':
+ dd['group'] = dd['scaffolds']
+ else:
+ dd['group'] = None
+
+ if 'atoms' in data.columns and 'coordinates' in data.columns:
+ dd['atoms'] = data['atoms'].tolist()
+ dd['coordinates'] = data['coordinates'].tolist()
+
+ if 'ROMol' in data.columns:
+ dd['mols'] = data['ROMol'].tolist()
+
+ return dd
+
+ def check_smiles(self, smi, is_train, smi_strict):
+ """
+ Validates a SMILES string and decides whether it should be included based on training mode and strictness.
+
+ :param smi: (str) The SMILES string to check.
+ :param is_train: (bool) Indicates if this check is happening during training.
+ :param smi_strict: (bool) If true, invalid SMILES strings raise an error, otherwise they're logged and skipped.
+
+ :return: (bool) True if the SMILES string is valid, False otherwise.
+ :raises ValueError: If the SMILES string is invalid and strict mode is on.
+ """
+ if Chem.MolFromSmiles(smi) is None:
+ if is_train and not smi_strict:
+ logger.info(f'Illegal SMILES clean: {smi}')
+ return False
+ else:
+ raise ValueError(f'SMILES rule is illegal: {smi}')
+ return True
+
+ def smi2scaffold(self, smi):
+ """
+ Converts a SMILES string to its corresponding scaffold.
+
+ :param smi: (str) The SMILES string to convert.
+
+ :return: (str) The scaffold of the SMILES string, or the original SMILES if conversion fails.
+ """
+ try:
+ return MurckoScaffold.MurckoScaffoldSmiles(
+ smiles=smi, includeChirality=True
+ )
+ except:
+ return smi
+
+ def anomaly_clean(self, data, task, target_cols):
+ """
+ Performs anomaly cleaning on the data based on the specified task.
+
+ :param data: (DataFrame) The dataset to be cleaned.
+ :param task: (str) The type of task which determines the cleaning strategy.
+ :param target_cols: (list) The list of target columns to consider for cleaning.
+
+ :return: (DataFrame) The cleaned dataset.
+ :raises ValueError: If the provided task is not recognized.
+ """
+ if task in [
+ 'classification',
+ 'multiclass',
+ 'multilabel_classification',
+ 'multilabel_regression',
+ ]:
+ return data
+ if task == 'regression':
+ return self.anomaly_clean_regression(data, target_cols)
+ else:
+ raise ValueError('Unknown task: {}'.format(task))
+
+ def anomaly_clean_regression(self, data, target_cols):
+ """
+ Performs anomaly cleaning specifically for regression tasks using a 3-sigma threshold.
+
+ :param data: (DataFrame) The dataset to be cleaned.
+ :param target_cols: (list) The list of target columns to consider for cleaning.
+
+ :return: (DataFrame) The cleaned dataset after applying the 3-sigma rule.
+ """
+ sz = data.shape[0]
+ target_col = target_cols[0]
+ _mean, _std = data[target_col].mean(), data[target_col].std()
+ data = data[
+ (data[target_col] > _mean - 3 * _std)
+ & (data[target_col] < _mean + 3 * _std)
+ ]
+ logger.info(
+ 'Anomaly clean with 3 sigma threshold: {} -> {}'.format(sz, data.shape[0])
+ )
+ return data
diff --git a/MindChem/applications/unimol/unimol_tools/data/datascaler.py b/MindChem/applications/unimol/unimol_tools/data/datascaler.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9894840ddfc090e1c53f9a2660fe41d0fda5c36
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/data/datascaler.py
@@ -0,0 +1,218 @@
+# Copyright (c) DP Technology.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import absolute_import, division, print_function
+
+import os
+
+import joblib
+import numpy as np
+from scipy.stats import kurtosis, skew
+from sklearn.preprocessing import (
+ FunctionTransformer,
+ MaxAbsScaler,
+ MinMaxScaler,
+ Normalizer,
+ PowerTransformer,
+ QuantileTransformer,
+ RobustScaler,
+ StandardScaler,
+)
+
+from ..utils import logger
+
+
+class TargetScaler(object):
+ '''
+ A class to scale the target.
+ '''
+
+ def __init__(self, ss_method, task, load_dir=None):
+ """
+ Initializes the TargetScaler object for scaling target values.
+
+ :param ss_method: (str) The scaling method to be used.
+ :param task: (str) The type of machine learning task (e.g., 'classification', 'regression').
+ :param load_dir: (str, optional) Directory from which to load an existing scaler.
+ """
+ self.ss_method = ss_method
+ self.task = task
+ if load_dir and os.path.exists(os.path.join(load_dir, 'target_scaler.ss')):
+ self.scaler = joblib.load(os.path.join(load_dir, 'target_scaler.ss'))
+ else:
+ self.scaler = None
+
+ def transform(self, target):
+ """
+ Transforms the target values using the appropriate scaling method.
+
+ :param target: (array-like) The target values to be transformed.
+
+ :return: (array-like) The transformed target values.
+ """
+ if self.task in ['classification', 'multiclass', 'multilabel_classification']:
+ return target
+ elif self.ss_method == 'none':
+ return target
+ elif self.task == 'regression':
+ return self.scaler.transform(target)
+ elif self.task == 'multilabel_regression':
+ assert isinstance(self.scaler, list) and len(self.scaler) == target.shape[1]
+ target = np.ma.masked_invalid(target) # mask NaN value
+ new_target = np.zeros_like(target)
+ for i in range(target.shape[1]):
+ new_target[:, i] = (
+ self.scaler[i]
+ .transform(target[:, i : i + 1])
+ .reshape(
+ -1,
+ )
+ )
+ return new_target
+ else:
+ return target
+
+ def fit(self, target, dump_dir):
+ """
+ Fits the scaler to the target values and optionally saves the scaler to disk.
+
+ :param target: (array-like) The target values to fit the scaler.
+ :param dump_dir: (str) Directory where the fitted scaler will be saved.
+ """
+ if self.task in ['classification', 'multiclass', 'multilabel_classification']:
+ return
+ elif self.ss_method == 'none':
+ return
+ elif self.ss_method == 'auto':
+ if self.task == 'regression':
+ if self.is_skewed(target):
+ self.scaler = FunctionTransformer(
+ func=np.log1p, inverse_func=np.expm1
+ )
+ logger.info('Auto select robust transformer.')
+ else:
+ self.scaler = StandardScaler()
+ self.scaler.fit(target)
+ elif self.task == 'multilabel_regression':
+ self.scaler = []
+ target = np.ma.masked_invalid(target) # mask NaN value
+ for i in range(target.shape[1]):
+ if self.is_skewed(target[:, i]):
+ self.scaler.append(
+ FunctionTransformer(func=np.log1p, inverse_func=np.expm1)
+ )
+ logger.info('Auto select robust transformer.')
+ else:
+ self.scaler.append(StandardScaler())
+ self.scaler[-1].fit(target[:, i : i + 1])
+ else:
+ if self.task == 'regression':
+ self.scaler = self.scaler_choose(self.ss_method, target)
+ self.scaler.fit(target)
+ elif self.task == 'multilabel_regression':
+ self.scaler = []
+ for i in range(target.shape[1]):
+ self.scaler.append(
+ self.scaler_choose(self.ss_method, target[:, i : i + 1])
+ )
+ self.scaler[-1].fit(target[:, i : i + 1])
+ try:
+ os.remove(os.path.join(dump_dir, 'target_scaler.ss'))
+ except:
+ pass
+ os.makedirs(dump_dir, exist_ok=True)
+ joblib.dump(self.scaler, os.path.join(dump_dir, 'target_scaler.ss'))
+
+ def scaler_choose(self, method, target):
+ """
+ Selects the appropriate scaler based on the scaling method and fit it to the target.
+
+ :param method: (str) The scaling method to be used.
+
+ currently support:
+
+ - 'minmax': MinMaxScaler,
+
+ - 'standard': StandardScaler,
+
+ - 'robust': RobustScaler,
+
+ - 'maxabs': MaxAbsScaler,
+
+ - 'quantile': QuantileTransformer,
+
+ - 'power_trans': PowerTransformer,
+
+ - 'normalizer': Normalizer,
+
+ - 'log1p': FunctionTransformer,
+
+ :param target: (array-like) The target values to fit the scaler.
+ :return: The fitted scaler object.
+ """
+ if method == 'minmax':
+ scaler = MinMaxScaler()
+ elif method == 'standard':
+ scaler = StandardScaler()
+ elif method == 'robust':
+ scaler = RobustScaler()
+ elif method == 'maxabs':
+ scaler = MaxAbsScaler()
+ elif method == 'quantile':
+ scaler = QuantileTransformer()
+ elif method == 'power_trans':
+ scaler = (
+ PowerTransformer(method='box-cox')
+ if min(target) > 0
+ else PowerTransformer(method='yeo-johnson')
+ )
+ elif method == 'normalizer':
+ scaler = Normalizer()
+ elif method == 'log1p':
+ scaler = FunctionTransformer(func=np.log1p, inverse_func=np.expm1)
+ else:
+ raise ValueError('Unknown scaler method: {}'.format(method))
+ return scaler
+
+ def inverse_transform(self, target):
+ """
+ Inverse transforms the scaled target values back to their original scale.
+
+ :param target: (array-like) The target values to be inverse transformed.
+
+ :return: (array-like) The target values in their original scale.
+ """
+ if self.task in ['classification', 'multiclass', 'multilabel_classification']:
+ return target
+ if self.ss_method == 'none' or self.scaler is None:
+ return target
+ elif self.task == 'regression':
+ return self.scaler.inverse_transform(target)
+ elif self.task == 'multilabel_regression':
+ assert isinstance(self.scaler, list) and len(self.scaler) == target.shape[1]
+ new_target = np.zeros_like(target)
+ for i in range(target.shape[1]):
+ new_target[:, i] = (
+ self.scaler[i]
+ .inverse_transform(target[:, i : i + 1])
+ .reshape(
+ -1,
+ )
+ )
+ return new_target
+ else:
+ raise ValueError('Unknown scaler method: {}'.format(self.ss_method))
+
+ def is_skewed(self, target):
+ """
+ Determines whether the target values are skewed based on skewness and kurtosis metrics.
+
+ :param target: (array-like) The target values to be checked for skewness.
+
+ :return: (bool) True if the target is skewed, False otherwise.
+ """
+ if self.task in ['classification', 'multiclass', 'multilabel_classification']:
+ return False
+ else:
+ return abs(skew(target)) > 5.0 or abs(kurtosis(target)) > 20.0
diff --git a/MindChem/applications/unimol/unimol_tools/data/dictionary.py b/MindChem/applications/unimol/unimol_tools/data/dictionary.py
new file mode 100644
index 0000000000000000000000000000000000000000..48d077a2fc68248c36f20105e3795d8430f336f7
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/data/dictionary.py
@@ -0,0 +1,151 @@
+# Copyright (c) DP Technology.
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import logging
+
+import numpy as np
+
+logger = logging.getLogger(__name__) # pylint: disable=invalid-name
+
+
+class Dictionary:
+ """A mapping from symbols to consecutive integers"""
+
+ def __init__(
+ self,
+ *, # begin keyword-only arguments
+ bos="[CLS]",
+ pad="[PAD]",
+ eos="[SEP]",
+ unk="[UNK]",
+ extra_special_symbols=None,
+ ):
+ self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
+ self.symbols = []
+ self.count = []
+ self.indices = {}
+ self.specials = set()
+ self.specials.add(bos)
+ self.specials.add(unk)
+ self.specials.add(pad)
+ self.specials.add(eos)
+
+ def __eq__(self, other):
+ return self.indices == other.indices
+
+ def __getitem__(self, idx):
+ if idx < len(self.symbols):
+ return self.symbols[idx]
+ return self.unk_word
+
+ def __len__(self):
+ """Returns the number of symbols in the dictionary"""
+ return len(self.symbols)
+
+ def __contains__(self, sym):
+ return sym in self.indices
+
+ def vec_index(self, a):
+ return np.vectorize(self.index)(a)
+
+ def index(self, sym):
+ """Returns the index of the specified symbol"""
+ assert isinstance(sym, str)
+ if sym in self.indices:
+ return self.indices[sym]
+ return self.indices[self.unk_word]
+
+ def special_index(self):
+ return [self.index(x) for x in self.specials]
+
+ def add_symbol(self, word, n=1, overwrite=False, is_special=False):
+ """Adds a word to the dictionary"""
+ if is_special:
+ self.specials.add(word)
+ if word in self.indices and not overwrite:
+ idx = self.indices[word]
+ self.count[idx] = self.count[idx] + n
+ return idx
+ else:
+ idx = len(self.symbols)
+ self.indices[word] = idx
+ self.symbols.append(word)
+ self.count.append(n)
+ return idx
+
+ def bos(self):
+ """Helper to get index of beginning-of-sentence symbol"""
+ return self.index(self.bos_word)
+
+ def pad(self):
+ """Helper to get index of pad symbol"""
+ return self.index(self.pad_word)
+
+ def eos(self):
+ """Helper to get index of end-of-sentence symbol"""
+ return self.index(self.eos_word)
+
+ def unk(self):
+ """Helper to get index of unk symbol"""
+ return self.index(self.unk_word)
+
+ @classmethod
+ def load(cls, f):
+ """Loads the dictionary from a text file with the format:
+
+ ```
+
+
+ ...
+ ```
+ """
+ d = cls()
+ d.add_from_file(f)
+ return d
+
+ def add_from_file(self, f):
+ """
+ Loads a pre-existing dictionary from a text file and adds its symbols
+ to this instance.
+ """
+ if isinstance(f, str):
+ try:
+ with open(f, "r", encoding="utf-8") as fd:
+ self.add_from_file(fd)
+ except FileNotFoundError as fnfe:
+ raise fnfe
+ except UnicodeError:
+ raise Exception(
+ "Incorrect encoding detected in {}, please "
+ "rebuild the dataset".format(f)
+ )
+ return
+
+ lines = f.readlines()
+
+ for line_idx, line in enumerate(lines):
+ try:
+ splits = line.rstrip().rsplit(" ", 1)
+ line = splits[0]
+ field = splits[1] if len(splits) > 1 else str(len(lines) - line_idx)
+ if field == "#overwrite":
+ overwrite = True
+ line, field = line.rsplit(" ", 1)
+ else:
+ overwrite = False
+ count = int(field)
+ word = line
+ if word in self and not overwrite:
+ logger.info(
+ "Duplicate word found when loading Dictionary: '{}', index is {}.".format(
+ word, self.indices[word]
+ )
+ )
+ else:
+ self.add_symbol(word, n=count, overwrite=overwrite)
+ except ValueError:
+ raise ValueError(
+ "Incorrect dictionary format, expected ' [flags]'"
+ )
diff --git a/MindChem/applications/unimol/unimol_tools/data/split.py b/MindChem/applications/unimol/unimol_tools/data/split.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2d2e2df54e8a42273014c936aad8753c9bbf184
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/data/split.py
@@ -0,0 +1,108 @@
+# Copyright (c) DP Technology.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import absolute_import, division, print_function
+
+import numpy as np
+from sklearn.model_selection import GroupKFold, KFold, StratifiedKFold
+
+from ..utils import logger
+
+
+class Splitter(object):
+ """
+ The Splitter class is responsible for splitting a dataset into train and test sets
+ based on the specified method.
+ """
+
+ def __init__(self, method='random', kfold=5, seed=42, **params):
+ """
+ Initializes the Splitter with a specified split method and random seed.
+
+ :param split_method: (str) The method for splitting the dataset, in the format 'Nfold_method'.
+ Defaults to '5fold_random'.
+ :param seed: (int) Random seed for reproducibility in random splitting. Defaults to 42.
+ """
+ self.method = method
+ self.n_splits = kfold
+ self.seed = seed
+ self.splitter = self._init_split()
+
+ def _init_split(self):
+ """
+ Initializes the actual splitter object based on the specified method.
+
+ :return: The initialized splitter object.
+ :raises ValueError: If an unknown splitting method is specified.
+ """
+ if self.n_splits == 1:
+ return None
+ if self.method == 'random':
+ splitter = KFold(
+ n_splits=self.n_splits, shuffle=True, random_state=self.seed
+ )
+ elif self.method == 'scaffold' or self.method == 'group':
+ splitter = GroupKFold(n_splits=self.n_splits)
+ elif self.method == 'stratified':
+ splitter = StratifiedKFold(
+ n_splits=self.n_splits, shuffle=True, random_state=self.seed
+ )
+ elif self.method == 'select':
+ splitter = GroupKFold(n_splits=self.n_splits)
+ else:
+ raise ValueError(
+ 'Unknown splitter method: {}fold - {}'.format(
+ self.n_splits, self.method
+ )
+ )
+
+ return splitter
+
+ def split(self, smiles, target=None, group=None, scaffolds=None, **params):
+ """
+ Splits the dataset into train and test sets based on the initialized method.
+
+ :param data: The dataset to be split.
+ :param target: (optional) Target labels for stratified splitting. Defaults to None.
+ :param group: (optional) Group labels for group-based splitting. Defaults to None.
+
+ :return: An iterator yielding train and test set indices for each fold.
+ :raises ValueError: If the splitter method does not support the provided parameters.
+ """
+ if self.n_splits == 1:
+ logger.warning(
+ 'Only one fold is used for training, no splitting is performed.'
+ )
+ return [(np.arange(len(smiles)), ())]
+ if smiles is None and 'atoms' in params:
+ smiles = params['atoms']
+ logger.warning('Atoms are used as SMILES for splitting.')
+ if self.method in ['random']:
+ self.skf = self.splitter.split(smiles)
+ elif self.method in ['scaffold']:
+ self.skf = self.splitter.split(smiles, target, scaffolds)
+ elif self.method in ['group']:
+ self.skf = self.splitter.split(smiles, target, group)
+ elif self.method in ['stratified']:
+ self.skf = self.splitter.split(smiles, group)
+ elif self.method in ['select']:
+ unique_groups = np.unique(group)
+ if len(unique_groups) == self.n_splits:
+ split_folds = []
+ for unique_group in unique_groups:
+ train_idx = np.where(group != unique_group)[0]
+ test_idx = np.where(group == unique_group)[0]
+ split_folds.append((train_idx, test_idx))
+ self.split_folds = split_folds
+ return self.split_folds
+ else:
+ logger.error(
+ 'The number of unique groups is not equal to the number of splits.'
+ )
+ exit(1)
+ else:
+ logger.error('Unknown splitter method: {}'.format(self.method))
+ exit(1)
+ self.split_folds = list(self.skf)
+ return self.split_folds
diff --git a/MindChem/applications/unimol/unimol_tools/models/__init__.py b/MindChem/applications/unimol/unimol_tools/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8810021219b7f7ef5ba447121a3a60c54d689b32
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/models/__init__.py
@@ -0,0 +1 @@
+from .nnmodel import NNModel, UniMolModel, UniMolV2Model
diff --git a/MindChem/applications/unimol/unimol_tools/models/loss.py b/MindChem/applications/unimol/unimol_tools/models/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..a473a9dc0b2d5ce062eb17ee66a27c2b469523c3
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/models/loss.py
@@ -0,0 +1,260 @@
+# Copyright (c) DP Technology.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import mindspore as ms
+import mindspore.numpy as mnp
+import mindspore.ops as ops
+from mindspore import nn
+from mindspore import Tensor
+
+
+class GHM_Loss(nn.Cell):
+    """Gradient Harmonizing Mechanism (GHM) loss base class.
+
+    Re-weights each element by the inverse density of its gradient magnitude so
+    that neither very easy nor very hard examples dominate training. Subclasses
+    provide the actual loss (`_custom_loss`) and its gradient (`_custom_loss_grad`).
+    """
+
+    def __init__(self, bins=10, alpha=0.5):
+        """
+        Initializes the GHM_Loss module with the specified number of bins and alpha value.
+
+        :param bins: (int) The number of bins to divide the gradient. Defaults to 10.
+        :param alpha: (float) The smoothing parameter for updating the last bin count. Defaults to 0.5.
+        """
+        super(GHM_Loss, self).__init__()
+        self._bins = bins
+        self._alpha = alpha
+        # Running (EMA-smoothed) per-bin gradient histogram from previous steps.
+        self._last_bin_count = None
+
+    def _g2bin(self, g):
+        """
+        Maps gradient values to corresponding bin indices.
+
+        :param g: (Tensor) Gradient tensor; assumed in [0, 1] — TODO confirm
+            for subclasses other than classification.
+        :return: (Tensor) Bin indices for each gradient value.
+        """
+        # The -0.0001 keeps g == 1.0 inside the last bin.
+        return ops.floor(g * (self._bins - 0.0001)).astype(ms.int64)
+
+    def _custom_loss(self, x, target, weight):
+        """
+        Custom loss function to be implemented in subclasses.
+
+        :param x: (Tensor) Predicted values.
+        :param target: (Tensor) Ground truth labels.
+        :param weight: (Tensor) Weights for the loss.
+        :raise NotImplementedError: Indicates that the method should be implemented in subclasses.
+        """
+        raise NotImplementedError
+
+    def _custom_loss_grad(self, x, target):
+        """
+        Custom gradient computation function to be implemented in subclasses.
+
+        :param x: (Tensor) Predicted values.
+        :param target: (Tensor) Ground truth labels.
+        :raise NotImplementedError: Indicates that the method should be implemented in subclasses.
+        """
+        raise NotImplementedError
+
+    def construct(self, x, target):
+        """
+        Forward pass for computing the GHM loss.
+
+        :param x: (Tensor) Predicted values; assumed 2-D (batch, labels) since
+            N below is computed as shape[0] * shape[1].
+        :param target: (Tensor) Ground truth labels.
+        :return: (Tensor) Computed GHM loss.
+        """
+        # Per-element gradient magnitude, used as the difficulty measure.
+        g = ops.abs(self._custom_loss_grad(x, target))
+
+        bin_idx = self._g2bin(g)
+
+        # Histogram of gradient magnitudes over the bins.
+        bin_count = ops.zeros((self._bins), ms.float32)
+        for i in range(self._bins):
+            bin_count[i] = ops.equal(bin_idx, i).sum()
+
+        N = x.shape[0] * x.shape[1]
+
+        # Exponential moving average of the histogram across calls; the very
+        # first call uses the raw histogram unsmoothed.
+        if self._last_bin_count is None:
+            self._last_bin_count = bin_count
+        else:
+            bin_count = (
+                self._alpha * self._last_bin_count + (1 - self._alpha) * bin_count
+            )
+            self._last_bin_count = bin_count
+
+        nonempty_bins = ops.greater(bin_count, 0).sum()
+
+        # Gradient density; clipped away from zero to avoid division blow-up.
+        gd = bin_count * nonempty_bins
+        gd = ops.clip_by_value(gd, ms.Tensor(0.0001, ms.float32), ms.Tensor(float('inf')))
+        beta = N / gd
+
+        beta = beta.astype(x.dtype)
+
+        # Each element is weighted by the inverse density of its gradient bin.
+        return self._custom_loss(x, target, beta[bin_idx])
+
+
+class GHMC_Loss(GHM_Loss):
+ '''
+ Inherits from GHM_Loss. GHM_Loss for classification.
+ '''
+
+ def __init__(self, bins, alpha):
+ """
+ Initializes the GHMC_Loss with specified number of bins and alpha value.
+
+ :param bins: (int) Number of bins for gradient division.
+ :param alpha: (float) Smoothing parameter for bin count updating.
+ """
+ super(GHMC_Loss, self).__init__(bins, alpha)
+
+ def _custom_loss(self, x, target, weight):
+ """
+ Custom loss function for GHM classification loss.
+
+ :param x: (Tensor) Predicted values.
+ :param target: (Tensor) Ground truth labels.
+ :param weight: (Tensor) Weights for the loss.
+
+ :return: Binary cross-entropy loss with logits.
+ """
+ sigmoid = ops.Sigmoid()
+ log_sigmoid = ops.Log()
+ loss = ops.binary_cross_entropy_with_logits(x, target, weight=weight)
+ return loss
+
+ def _custom_loss_grad(self, x, target):
+ """
+ Custom gradient function for GHM classification loss.
+
+ :param x: (Tensor) Predicted values.
+ :param target: (Tensor) Ground truth labels.
+
+ :return: Gradient of the loss.
+ """
+ sigmoid = ops.Sigmoid()
+ return sigmoid(x) - target
+
+
+class GHMR_Loss(GHM_Loss):
+ '''
+ Inherits from GHM_Loss. GHM_Loss for regression
+ '''
+
+ def __init__(self, bins, alpha, mu):
+ """
+ Initializes the GHMR_Loss with specified number of bins, alpha value, and mu parameter.
+
+ :param bins: (int) Number of bins for gradient division.
+ :param alpha: (float) Smoothing parameter for bin count updating.
+ :param mu: (float) Parameter used in the GHMR loss formula.
+ """
+ super(GHMR_Loss, self).__init__(bins, alpha)
+ self._mu = mu
+
+ def _custom_loss(self, x, target, weight):
+ """
+ Custom loss function for GHM regression loss.
+
+ :param x: (Tensor) Predicted values.
+ :param target: (Tensor) Ground truth values.
+ :param weight: (Tensor) Weights for the loss.
+
+ :return: GHMR loss.
+ """
+ d = x - target
+ mu = self._mu
+ sqrt = ops.Sqrt()
+ loss = sqrt(d * d + mu * mu) - mu
+ N = x.shape[0] * x.shape[1]
+ return (loss * weight).sum() / N
+
+ def _custom_loss_grad(self, x, target):
+ """
+ Custom gradient function for GHM regression loss.
+
+ :param x: (Tensor) Predicted values.
+ :param target: (Tensor) Ground truth values.
+
+ :return: Gradient of the loss.
+ """
+ d = x - target
+ mu = self._mu
+ sqrt = ops.Sqrt()
+ return d / sqrt(d * d + mu * mu)
+
+
+def MAEwithNan(y_pred, y_true):
+ """
+ Calculates the Mean Absolute Error (MAE) loss, ignoring NaN values in the target.
+
+ :param y_pred: (Tensor) Predicted values.
+ :param y_true: (Tensor) Ground truth values, may contain NaNs.
+
+ :return: (Tensor) MAE loss computed only on non-NaN elements.
+ """
+ mask = ~ops.isinf(y_true)
+ y_pred = y_pred[mask]
+ y_true = y_true[mask]
+ mae_loss = nn.L1Loss()
+ loss = mae_loss(y_pred, y_true)
+ return loss
+
+
+def FocalLoss(y_pred, y_true, alpha=0.25, gamma=2):
+    """
+    Calculates the Focal Loss, used to address class imbalance by focusing on hard examples.
+
+    :param y_pred: (Tensor) Predicted probabilities (already sigmoid-activated).
+    :param y_true: (Tensor) Ground truth labels.
+    :param alpha: (float) Weighting factor for balancing positive and negative examples. Defaults to 0.25.
+    :param gamma: (float) Focusing parameter to scale the loss. Defaults to 2.
+
+    :return: (Tensor) Computed focal loss.
+    """
+    concat = ops.Concat(axis=1)
+    pow_op = ops.Pow()
+    log_op = ops.Log()
+    if y_pred.shape != y_true.shape:
+        y_true = y_true.flatten()
+    y_true = y_true.astype(ms.int64)
+    y_pred = y_pred.astype(ms.float32)
+    y_true = y_true.astype(ms.float32)
+    # Stack negative/positive class terms along axis 1: column 0 holds the
+    # "negative" component (1 - p), column 1 the "positive" component p.
+    y_true = ops.expand_dims(y_true, 1)
+    y_pred = ops.expand_dims(y_pred, 1)
+    y_true = concat((1 - y_true, y_true))
+    y_pred = concat((1 - y_pred, y_pred))
+    # Clamp probabilities away from 0 so log() stays finite.
+    y_pred = ops.clip_by_value(y_pred, ms.Tensor(1e-5, ms.float32), ms.Tensor(1.0, ms.float32))
+    loss = -alpha * y_true * pow_op((1 - y_pred), gamma) * log_op(y_pred)
+    # NOTE(review): mindspore.ops.sum documents its reduce parameter as `dim`,
+    # not `axis` — confirm this keyword works on the targeted MindSpore version.
+    return ops.mean(ops.sum(loss, axis=1))
+
+
+def FocalLossWithLogits(y_pred, y_true, alpha=0.25, gamma=2.0):
+ """
+ Calculates the Focal Loss using predicted logits (raw scores), automatically applying the sigmoid function.
+
+ :param y_pred: (Tensor) Predicted logits.
+ :param y_true: (Tensor) Ground truth labels, may contain NaNs.
+ :param alpha: (float) Weighting factor for balancing positive and negative examples. Defaults to 0.25.
+ :param gamma: (float) Focusing parameter to scale the loss. Defaults to 2.0.
+
+ :return: (Tensor) Computed focal loss.
+ """
+ sigmoid = ops.Sigmoid()
+ y_pred = sigmoid(y_pred)
+ mask = ~ops.isinf(y_true)
+ y_pred = y_pred[mask]
+ y_true = y_true[mask]
+ loss = FocalLoss(y_pred, y_true)
+ return loss
+
+
+def myCrossEntropyLoss(y_pred, y_true):
+ """
+ Calculates the cross-entropy loss between predictions and targets.
+
+ :param y_pred: (Tensor) Predicted logits or probabilities.
+ :param y_true: (Tensor) Ground truth labels.
+
+ :return: (Tensor) Computed cross-entropy loss.
+ """
+ if y_pred.shape != y_true.shape:
+ y_true = y_true.flatten()
+ return nn.CrossEntropyLoss()(y_pred, y_true)
diff --git a/MindChem/applications/unimol/unimol_tools/models/nnmodel.py b/MindChem/applications/unimol/unimol_tools/models/nnmodel.py
new file mode 100644
index 0000000000000000000000000000000000000000..629742b352e934c82b1ca51f7378c55c9c2664ae
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/models/nnmodel.py
@@ -0,0 +1,327 @@
+# Copyright (c) DP Technology.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import absolute_import, division, print_function
+
+import os
+
+import joblib
+import numpy as np
+import mindspore as ms
+import mindspore.ops as ops
+from mindspore import nn
+
+from ..utils import logger
+from .loss import FocalLossWithLogits, GHMC_Loss, MAEwithNan, myCrossEntropyLoss
+from .unimol import UniMolModel
+from .unimolv2 import UniMolV2Model
+
+# Maps user-facing model names to their backbone implementations.
+NNMODEL_REGISTER = {
+    'unimolv1': UniMolModel,
+    'unimolv2': UniMolV2Model,
+}
+
+# Loss-function lookup per task type.  NOTE(review): the name carries a
+# historical typo ("RREGISTER") but is referenced elsewhere, so it is kept.
+LOSS_RREGISTER = {
+    'classification': myCrossEntropyLoss,
+    'multiclass': myCrossEntropyLoss,
+    'regression': nn.MSELoss(),
+    # Multilabel classification supports several losses, chosen via `loss_key`.
+    'multilabel_classification': {
+        'bce': nn.BCEWithLogitsLoss(),
+        'ghm': GHMC_Loss(bins=10, alpha=0.5),
+        'focal': FocalLossWithLogits,
+    },
+    'multilabel_regression': MAEwithNan,
+}
+
+
+def classification_activation(x):
+    # Softmax over classes, keeping only the positive-class probabilities.
+    return ops.softmax(x, axis=-1)[:, 1:]
+
+
+def multiclass_activation(x):
+    # Full softmax distribution over all classes.
+    return ops.softmax(x, axis=-1)
+
+
+def regression_activation(x):
+    # Identity: regression outputs are used as-is.
+    return x
+
+
+def multilabel_classification_activation(x):
+    # Independent per-label probabilities.
+    return ops.sigmoid(x)
+
+
+def multilabel_regression_activation(x):
+    # Identity: regression outputs are used as-is.
+    return x
+
+
+# Output activation applied to raw model outputs, per task type.
+ACTIVATION_FN = {
+    'classification': classification_activation,
+    'multiclass': multiclass_activation,
+    'regression': regression_activation,
+    'multilabel_classification': multilabel_classification_activation,
+    'multilabel_regression': multilabel_regression_activation,
+}
+# Fixed output dimensionality for tasks where it does not depend on the data.
+OUTPUT_DIM = {
+    'classification': 2,
+    'regression': 1,
+}
+
+
+class NNModel(object):
+    """A :class:`NNModel` class is responsible for initializing the model"""
+
+    def __init__(self, data, trainer, **params):
+        """
+        Initializes the neural network model with the given data and parameters.
+
+        :param data: (dict) Contains the dataset information, including features and target scaling.
+        :param trainer: (object) An instance of a training class, responsible for managing training processes.
+        :param params: Various additional parameters used for model configuration.
+
+        The model is configured based on the task type and specific parameters provided.
+        """
+        self.data = data
+        self.num_classes = self.data['num_classes']
+        self.target_scaler = self.data['target_scaler']
+        self.features = data['unimol_input']
+        self.model_name = params.get('model_name', 'unimolv1')
+        self.data_type = params.get('data_type', 'molecule')
+        self.loss_key = params.get('loss_key', None)
+        self.trainer = trainer
+        # self.splitter = self.trainer.splitter
+        self.model_params = params.copy()
+        self.task = params['task']
+        # Output dim: fixed for classification/regression, data-driven otherwise.
+        if self.task in OUTPUT_DIM:
+            self.model_params['output_dim'] = OUTPUT_DIM[self.task]
+        elif self.task == 'multiclass':
+            self.model_params['output_dim'] = self.data['multiclass_cnt']
+        else:
+            self.model_params['output_dim'] = self.num_classes
+        # device is managed by MindSpore context
+        # Cross-validation artifacts (out-of-fold predictions / metrics), filled by run().
+        self.cv = dict()
+        self.metrics = self.trainer.metrics
+        # Multilabel classification supports several losses, selected by loss_key.
+        if self.task == 'multilabel_classification':
+            if self.loss_key is None:
+                self.loss_key = 'focal'
+            self.loss_func = LOSS_RREGISTER[self.task][self.loss_key]
+        else:
+            self.loss_func = LOSS_RREGISTER[self.task]
+        self.activation_fn = ACTIVATION_FN[self.task]
+        self.save_path = self.trainer.save_path
+        self.trainer.set_seed(self.trainer.seed)
+        self.model = self._init_model(**self.model_params)
+
+    def _init_model(self, model_name, **params):
+        """
+        Initializes the neural network model based on the provided model name and parameters.
+
+        :param model_name: (str) The name of the model to initialize.
+        :param params: Additional parameters for model configuration.
+
+        :return: An instance of the specified neural network model.
+        :raises ValueError: If the model name is not recognized.
+        """
+        if self.task in ['regression', 'multilabel_regression']:
+            params['pooler_dropout'] = 0
+            logger.debug("set pooler_dropout to 0 for regression task")
+        else:
+            pass
+        freeze_layers = params.get('freeze_layers', None)
+        freeze_layers_reversed = params.get('freeze_layers_reversed', False)
+        if model_name in NNMODEL_REGISTER:
+            model = NNMODEL_REGISTER[model_name](**params)
+            # freeze_layers may be a comma-separated string of layer-name prefixes.
+            if isinstance(freeze_layers, str):
+                freeze_layers = freeze_layers.replace(' ', '').split(',')
+            if isinstance(freeze_layers, list):
+                for layer_name, layer_param in model.named_parameters():
+                    should_freeze = any(
+                        layer_name.startswith(freeze_layer)
+                        for freeze_layer in freeze_layers
+                    )
+                    # XOR: with freeze_layers_reversed=True the listed layers are
+                    # the ones kept trainable instead of the ones frozen.
+                    layer_param.requires_grad = not (
+                        freeze_layers_reversed ^ should_freeze
+                    )
+        else:
+            raise ValueError('Unknown model: {}'.format(self.model_name))
+        return model
+
+    def collect_data(self, X, y, idx):
+        """
+        Collects and formats the training or validation data.
+
+        :param X: (np.ndarray or dict) The input features, either as a numpy array or a dictionary of tensors.
+        :param y: (np.ndarray) The target values as a numpy array.
+        :param idx: Indices to select the specific data samples.
+
+        :return: A tuple containing processed input data and target values.
+        :raises ValueError: If X is neither a numpy array nor a dictionary.
+        """
+        assert isinstance(y, np.ndarray), 'y must be numpy array'
+        if isinstance(X, np.ndarray):
+            return X[idx], y[idx]
+        elif isinstance(X, list) or isinstance(X, dict):
+            # NOTE(review): this branch calls X.items(), which only exists on
+            # dicts — a plain list argument would raise AttributeError here.
+            return {k: v[idx] for k, v in X.items()}, y[idx]
+        else:
+            raise ValueError('X must be numpy array or dict')
+
+    def run(self):
+        """
+        Executes the training process of the model. This involves data preparation,
+        model training, validation, and computing metrics for each fold in cross-validation.
+        """
+        logger.info("start training Uni-Mol:{}".format(self.model_name))
+        X = np.asarray(self.features)
+        y = np.asarray(self.data['target'])
+        group = (
+            np.asarray(self.data['group']) if self.data['group'] is not None else None
+        )
+        # Out-of-fold prediction buffer, filled per test split below.
+        if self.task == 'classification':
+            y_pred = np.zeros_like(y.reshape(y.shape[0], self.num_classes)).astype(
+                float
+            )
+        else:
+            y_pred = np.zeros((y.shape[0], self.model_params['output_dim']))
+        for fold, (tr_idx, te_idx) in enumerate(self.data['split_nfolds']):
+            X_train, y_train = X[tr_idx], y[tr_idx]
+            X_valid, y_valid = X[te_idx], y[te_idx]
+            traindataset = NNDataset(X_train, y_train)
+            validdataset = NNDataset(X_valid, y_valid)
+            if fold > 0:
+                # need to initalize model for next fold training
+                self.model = self._init_model(**self.model_params)
+
+            # TODO: move the following code to model.load_pretrained_weights
+            # Optionally warm-start each fold from a previously saved checkpoint.
+            if self.model_params.get('load_model_dir', None) is not None:
+                load_model_path = os.path.join(
+                    self.model_params['load_model_dir'], f'model_{fold}.ckpt'
+                )
+                if os.path.exists(load_model_path):
+                    try:
+                        self.model.load_pretrained_weights(load_model_path, strict=True)
+                        logger.info("load model success from {}".format(load_model_path))
+                    except Exception as e:
+                        logger.warning(f"Failed to load ckpt: {e}")
+            _y_pred = self.trainer.fit_predict(
+                self.model,
+                traindataset,
+                validdataset,
+                self.loss_func,
+                self.activation_fn,
+                self.save_path,
+                fold,
+                self.target_scaler,
+            )
+            y_pred[te_idx] = _y_pred
+
+            if 'multiclass_cnt' in self.data:
+                label_cnt = self.data['multiclass_cnt']
+            else:
+                label_cnt = None
+
+            # Per-fold metric is computed in the original (unscaled) target space.
+            logger.info(
+                "fold {0}, result {1}".format(
+                    fold,
+                    self.metrics.cal_metric(
+                        self.data['target_scaler'].inverse_transform(y_valid),
+                        self.data['target_scaler'].inverse_transform(_y_pred),
+                        label_cnt=label_cnt,
+                    ),
+                )
+            )
+
+        self.cv['pred'] = y_pred
+        self.cv['metric'] = self.metrics.cal_metric(
+            self.data['target_scaler'].inverse_transform(y),
+            self.data['target_scaler'].inverse_transform(self.cv['pred']),
+        )
+        self.dump(self.cv['pred'], self.save_path, 'cv.data')
+        self.dump(self.cv['metric'], self.save_path, 'metric.result')
+        logger.info("Uni-Mol metrics score: \n{}".format(self.cv['metric']))
+        logger.info("Uni-Mol & Metric result saved!")
+
+    def dump(self, data, dir, name):
+        """
+        Saves the specified data to a file (serialized with joblib).
+
+        :param data: The data to be saved.
+        :param dir: (str) The directory where the data will be saved.
+        :param name: (str) The name of the file to save the data.
+        """
+        path = os.path.join(dir, name)
+        if not os.path.exists(dir):
+            os.makedirs(dir)
+        joblib.dump(data, path)
+
+    def evaluate(self, trainer=None, checkpoints_path=None):
+        """
+        Evaluates the model by making predictions on the test set and averaging
+        the per-fold predictions; the ensemble result is stored in self.cv['test_pred'].
+
+        :param trainer: An optional trainer instance to use for prediction.
+        :param checkpoints_path: (str) The path to the saved model checkpoints.
+        """
+        logger.info("start predict NNModel:{}".format(self.model_name))
+        testdataset = NNDataset(self.features, np.asarray(self.data['target']))
+        for fold in range(self.data['kfold']):
+            _y_pred, _, __ = trainer.predict(
+                self.model,
+                testdataset,
+                self.loss_func,
+                self.activation_fn,
+                self.save_path,
+                fold,
+                self.target_scaler,
+                epoch=1,
+                load_model=True,
+            )
+            # Accumulate fold predictions, then average below.
+            if fold == 0:
+                y_pred = np.zeros_like(_y_pred)
+            y_pred += _y_pred
+        y_pred /= self.data['kfold']
+        self.cv['test_pred'] = y_pred
+
+    def count_parameters(self, model):
+        """
+        Counts the number of trainable parameters in the model.
+
+        :param model: The model whose parameters are to be counted.
+
+        :return: (int) The number of trainable parameters.
+        """
+        return sum(np.prod(p.shape) for p in model.get_parameters() if p.requires_grad)
+
+
+def NNDataset(data, label=None):
+    """Factory returning a :class:`SimpleDataset` wrapping `data` and `label`."""
+    return SimpleDataset(data, label)
+
+
+class SimpleDataset(object):
+
+ def __init__(self, data, label=None):
+ """
+ Initializes the dataset with data and labels.
+
+ :param data: The input data.
+ :param label: The target labels for the input data.
+ """
+ self.data = data
+ self.label = label if label is not None else np.zeros((len(data), 1))
+
+ def __getitem__(self, idx):
+ """
+ Retrieves the data item and its corresponding label at the specified index.
+
+ :param idx: (int) The index of the data item to retrieve.
+
+ :return: A tuple containing the data item and its label.
+ """
+ return self.data[idx], self.label[idx]
+
+ def __len__(self):
+ """
+ Returns the total number of items in the dataset.
+
+ :return: (int) The size of the dataset.
+ """
+ return len(self.data)
diff --git a/MindChem/applications/unimol/unimol_tools/models/transformers.py b/MindChem/applications/unimol/unimol_tools/models/transformers.py
new file mode 100644
index 0000000000000000000000000000000000000000..df7caebf8434a504f368f75655af2f2be39657c1
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/models/transformers.py
@@ -0,0 +1,421 @@
+# Copyright (c) DP Technology.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional
+
+import mindspore as ms
+import mindspore.ops as ops
+from mindspore import nn
+from mindspore import Tensor, Parameter
+
+
+def softmax_dropout(
+ input, dropout_prob, is_training=True, mask=None, bias=None, inplace=True
+):
+ """softmax dropout, and mask, bias are optional.
+
+ Args:
+ input (Tensor): input tensor
+ dropout_prob (float): dropout probability
+ is_training (bool, optional): is in training or not. Defaults to True.
+ mask (Tensor, optional): the mask tensor, use as input + mask . Defaults to None.
+ bias (Tensor, optional): the bias tensor, use as input + bias . Defaults to None.
+
+ Returns:
+ Tensor: the result after softmax
+ """
+ if not inplace:
+ # copy a input for non-inplace case
+ input = input.copy()
+ if mask is not None:
+ input = input + mask
+ if bias is not None:
+ input = input + bias
+ softmax = ops.Softmax(axis=-1)
+ dropout = nn.Dropout(p=dropout_prob)
+ return dropout(softmax(input))
+
+
+def get_activation_fn(activation):
+ """Returns the activation function corresponding to `activation`"""
+
+ if activation == "relu":
+ return ops.ReLU()
+ elif activation == "gelu":
+ return ops.GeLU()
+ elif activation == "tanh":
+ return ops.Tanh()
+ elif activation == "linear":
+ return ops.Identity()
+ else:
+ raise RuntimeError("--activation-fn {} not supported".format(activation))
+
+
+class SelfMultiheadAttention(nn.Cell):
+    """Multi-head self-attention with optional additive attention bias and key-padding mask."""
+
+    def __init__(
+        self,
+        embed_dim,
+        num_heads,
+        dropout=0.1,
+        bias=True,
+        scaling_factor=1,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+
+        self.num_heads = num_heads
+        self.dropout = dropout
+
+        self.head_dim = embed_dim // num_heads
+        assert (
+            self.head_dim * num_heads == self.embed_dim
+        ), "embed_dim must be divisible by num_heads"
+        # 1/sqrt(d * scaling_factor) attention scaling, applied to q below.
+        self.scaling = (self.head_dim * scaling_factor) ** -0.5
+
+        # Fused Q/K/V projection: a single matmul producing 3 * embed_dim channels.
+        self.in_proj = nn.Dense(embed_dim, embed_dim * 3, has_bias=bias)
+        self.out_proj = nn.Dense(embed_dim, embed_dim, has_bias=bias)
+
+    def construct(
+        self,
+        query,
+        key_padding_mask: Optional[Tensor] = None,
+        attn_bias: Optional[Tensor] = None,
+        return_attn: bool = False,
+    ) -> Tensor:
+        """Self-attention over `query` of shape (bsz, tgt_len, embed_dim)."""
+
+        bsz, tgt_len, embed_dim = query.shape
+        assert embed_dim == self.embed_dim
+
+        # Split the fused projection into equal Q, K, V chunks of size embed_dim.
+        q, k, v = self.in_proj(query).split(embed_dim, axis=-1)
+
+        # NOTE(review): `.transpose(1, 2)` assumes PyTorch-style axis swapping;
+        # MindSpore's Tensor.transpose expects a full permutation — confirm the
+        # intended (bsz, heads, len, head_dim) layout on the target version.
+        q = (
+            q.view(bsz, tgt_len, self.num_heads, self.head_dim)
+            .transpose(1, 2)
+            .view(bsz * self.num_heads, -1, self.head_dim)
+            * self.scaling
+        )
+        if k is not None:
+            k = (
+                k.view(bsz, -1, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+                .view(bsz * self.num_heads, -1, self.head_dim)
+            )
+        if v is not None:
+            v = (
+                v.view(bsz, -1, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+                .view(bsz * self.num_heads, -1, self.head_dim)
+            )
+
+        assert k is not None
+        src_len = k.shape[1]
+
+        # This is part of a workaround to get around fork/join parallelism
+        # not supporting Optional types.
+        if key_padding_mask is not None and key_padding_mask.ndim == 0:
+            key_padding_mask = None
+
+        if key_padding_mask is not None:
+            assert key_padding_mask.shape[0] == bsz
+            assert key_padding_mask.shape[1] == src_len
+
+        # Raw attention logits: (bsz * heads, tgt_len, src_len).
+        batch_matmul = ops.BatchMatMul()
+        attn_weights = batch_matmul(q, k.transpose(1, 2))
+
+        assert list(attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]
+
+        if key_padding_mask is not None:
+            # don't attend to padding symbols
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            mask_bool = key_padding_mask.unsqueeze(1).unsqueeze(2).astype(ms.bool_)
+            # Masked positions get -inf so softmax assigns them zero weight.
+            fill_tensor = ops.fill(attn_weights.dtype, attn_weights.shape, ms.Tensor(float('-inf'), dtype=attn_weights.dtype))
+            attn_weights = ops.select(mask_bool, fill_tensor, attn_weights)
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        if not return_attn:
+            attn = softmax_dropout(
+                attn_weights,
+                self.dropout,
+                self.training,
+                bias=attn_bias,
+            )
+        else:
+            # NOTE(review): this path adds attn_bias unconditionally; calling
+            # with return_attn=True and attn_bias=None would fail — confirm
+            # callers always pass a bias here.
+            attn_weights = attn_weights + attn_bias
+            attn = softmax_dropout(
+                attn_weights,
+                self.dropout,
+                self.training,
+                inplace=False,
+            )
+
+        o = batch_matmul(attn, v)
+        assert list(o.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
+
+        # Merge heads back into a single (bsz, tgt_len, embed_dim) tensor.
+        o = (
+            o.view(bsz, self.num_heads, tgt_len, self.head_dim)
+            .transpose(1, 2)
+            .view(bsz, tgt_len, embed_dim)
+        )
+        o = self.out_proj(o)
+        if not return_attn:
+            return o
+        else:
+            return o, attn_weights, attn
+
+
+class TransformerEncoderLayer(nn.Cell):
+    """
+    Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
+    models.
+    """
+
+    def __init__(
+        self,
+        embed_dim: int = 768,
+        ffn_embed_dim: int = 3072,
+        attention_heads: int = 8,
+        dropout: float = 0.1,
+        attention_dropout: float = 0.1,
+        activation_dropout: float = 0.0,
+        activation_fn: str = "gelu",
+        post_ln=False,
+    ) -> None:
+        """
+        :param embed_dim: (int) Embedding dimensionality of the layer.
+        :param ffn_embed_dim: (int) Hidden size of the feed-forward sub-layer.
+        :param attention_heads: (int) Number of self-attention heads.
+        :param dropout: (float) Dropout applied after attention and FFN outputs.
+        :param attention_dropout: (float) Dropout on the attention probabilities.
+        :param activation_dropout: (float) Dropout after the FFN activation.
+        :param activation_fn: (str) Name of the FFN activation ("gelu", "relu", ...).
+        :param post_ln: (bool) Use post-LayerNorm (True) or pre-LayerNorm (False).
+        """
+        super().__init__()
+
+        # Initialize parameters
+        self.embed_dim = embed_dim
+        self.attention_heads = attention_heads
+        self.attention_dropout = attention_dropout
+
+        self.dropout = dropout
+        self.activation_dropout = activation_dropout
+        self.activation_fn = get_activation_fn(activation_fn)
+
+        self.self_attn = SelfMultiheadAttention(
+            self.embed_dim,
+            attention_heads,
+            dropout=attention_dropout,
+        )
+        # layer norm associated with the self attention layer
+        self.self_attn_layer_norm = nn.LayerNorm([self.embed_dim])
+        self.fc1 = nn.Dense(self.embed_dim, ffn_embed_dim)
+        self.fc2 = nn.Dense(ffn_embed_dim, self.embed_dim)
+        self.final_layer_norm = nn.LayerNorm([self.embed_dim])
+        self.post_ln = post_ln
+
+    def construct(
+        self,
+        x: Tensor,
+        attn_bias: Optional[Tensor] = None,
+        padding_mask: Optional[Tensor] = None,
+        return_attn: bool = False,
+    ) -> Tensor:
+        """
+        LayerNorm is applied either before or after the self-attention/ffn
+        modules similar to the original Transformer implementation.
+        """
+        # --- self-attention sub-layer (with residual connection) ---
+        residual = x
+        if not self.post_ln:
+            x = self.self_attn_layer_norm(x)
+        # new added
+        x = self.self_attn(
+            query=x,
+            key_padding_mask=padding_mask,
+            attn_bias=attn_bias,
+            return_attn=return_attn,
+        )
+        if return_attn:
+            x, attn_weights, attn_probs = x
+        # NOTE(review): an nn.Dropout cell is constructed on every forward call;
+        # its training flag is independent of this layer's — confirm dropout is
+        # actually disabled at inference on the target MindSpore version.
+        x = nn.Dropout(p=self.dropout)(x)
+        x = residual + x
+        if self.post_ln:
+            x = self.self_attn_layer_norm(x)
+
+        # --- feed-forward sub-layer (with residual connection) ---
+        residual = x
+        if not self.post_ln:
+            x = self.final_layer_norm(x)
+        x = self.fc1(x)
+        x = self.activation_fn(x)
+        x = nn.Dropout(p=self.activation_dropout)(x)
+        x = self.fc2(x)
+        x = nn.Dropout(p=self.dropout)(x)
+        x = residual + x
+        if self.post_ln:
+            x = self.final_layer_norm(x)
+        if not return_attn:
+            return x
+        else:
+            return x, attn_weights, attn_probs
+
+
+class TransformerEncoderWithPair(nn.Cell):
+ """
+ A custom Transformer Encoder module that extends MindSpore's nn.Cell. This encoder is designed for tasks that require understanding pair relationships in sequences. It includes standard transformer encoder layers along with additional normalization and dropout features.
+
+ Attributes:
+ - emb_dropout: Dropout rate applied to the embedding layer.
+ - max_seq_len: Maximum length of the input sequences.
+ - embed_dim: Dimensionality of the embeddings.
+ - attention_heads: Number of attention heads in the transformer layers.
+ - emb_layer_norm: Layer normalization applied to the embedding layer.
+ - final_layer_norm: Optional final layer normalization.
+ - final_head_layer_norm: Optional layer normalization for the attention heads.
+ - layers: A list of transformer encoder layers.
+
+ Methods:
+ construct: Performs the forward pass of the module.
+ """
+
+    def __init__(
+        self,
+        encoder_layers: int = 6,
+        embed_dim: int = 768,
+        ffn_embed_dim: int = 3072,
+        attention_heads: int = 8,
+        emb_dropout: float = 0.1,
+        dropout: float = 0.1,
+        attention_dropout: float = 0.1,
+        activation_dropout: float = 0.0,
+        max_seq_len: int = 256,
+        activation_fn: str = "gelu",
+        post_ln: bool = False,
+        no_final_head_layer_norm: bool = False,
+    ) -> None:
+        """
+        Initializes and configures the layers and other components of the transformer encoder.
+
+        :param encoder_layers: (int) Number of encoder layers in the transformer.
+        :param embed_dim: (int) Dimensionality of the input embeddings.
+        :param ffn_embed_dim: (int) Dimensionality of the feedforward network model.
+        :param attention_heads: (int) Number of attention heads in each encoder layer.
+        :param emb_dropout: (float) Dropout rate for the embedding layer.
+        :param dropout: (float) Dropout rate for the encoder layers.
+        :param attention_dropout: (float) Dropout rate for the attention mechanisms.
+        :param activation_dropout: (float) Dropout rate for activations.
+        :param max_seq_len: (int) Maximum sequence length the model can handle.
+        :param activation_fn: (str) The activation function to use (e.g., "gelu").
+        :param post_ln: (bool) If True, applies layer normalization after the feedforward network.
+        :param no_final_head_layer_norm: (bool) If True, does not apply layer normalization to the final attention head.
+
+        """
+        super().__init__()
+        self.emb_dropout = emb_dropout
+        self.max_seq_len = max_seq_len
+        self.embed_dim = embed_dim
+        self.attention_heads = attention_heads
+        self.emb_layer_norm = nn.LayerNorm([self.embed_dim])
+        # Pre-LN configuration (post_ln=False) keeps an extra LayerNorm that is
+        # applied to the final hidden states in construct(); post-LN skips it.
+        if not post_ln:
+            self.final_layer_norm = nn.LayerNorm([self.embed_dim])
+        else:
+            self.final_layer_norm = None
+
+        # Optional LayerNorm over the per-head pair representation
+        # (normalizes the trailing dimension of size `attention_heads`).
+        if not no_final_head_layer_norm:
+            self.final_head_layer_norm = nn.LayerNorm([attention_heads])
+        else:
+            self.final_head_layer_norm = None
+
+        # Stack of identical encoder layers; each receives the shared
+        # dropout / activation configuration above.
+        self.layers = nn.CellList(
+            [
+                TransformerEncoderLayer(
+                    embed_dim=self.embed_dim,
+                    ffn_embed_dim=ffn_embed_dim,
+                    attention_heads=attention_heads,
+                    dropout=dropout,
+                    attention_dropout=attention_dropout,
+                    activation_dropout=activation_dropout,
+                    activation_fn=activation_fn,
+                    post_ln=post_ln,
+                )
+                for _ in range(encoder_layers)
+            ]
+        )
+
+    def construct(
+        self,
+        emb: Tensor,
+        attn_mask: Optional[Tensor] = None,
+        padding_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """
+        Conducts the forward pass of the transformer encoder.
+
+        :param emb: (Tensor) The input tensor of embeddings.
+        :param attn_mask: (Optional[Tensor]) Attention mask to specify positions to attend to.
+            Viewed internally as (bsz, heads, seq, seq) flattened to (bsz*heads, seq, seq).
+        :param padding_mask: (Optional[Tensor]) Mask to indicate padded elements in the input.
+
+        :return: (Tensor) The output tensor after passing through the transformer encoder layers.
+            It also returns tensors related to pair representation and normalization losses.
+        """
+        bsz = emb.shape[0]
+        seq_len = emb.shape[1]
+        x = self.emb_layer_norm(emb)
+        # NOTE(review): a fresh Dropout cell is constructed on every forward
+        # pass, so embedding dropout is applied even in eval mode — confirm intended.
+        x = nn.Dropout(p=self.emb_dropout)(x)
+        # account for padding while computing the representation
+        if padding_mask is not None:
+            x = x * (1 - padding_mask.unsqueeze(-1).astype(x.dtype))
+        # Keep the untouched inputs; they are needed below to compute
+        # delta_pair_repr (change in attention bias) after the encoder stack.
+        input_attn_mask = attn_mask
+        input_padding_mask = padding_mask
+
+        # Fold the key padding mask into the attention bias by writing
+        # `fill_val` into masked key columns, then drop the padding mask.
+        def fill_attn_mask(attn_mask, padding_mask, fill_val=float("-inf")):
+            if attn_mask is not None and padding_mask is not None:
+                # merge key_padding_mask and attn_mask
+                attn_mask = attn_mask.view(bsz, -1, seq_len, seq_len)
+                mask_bool = padding_mask.unsqueeze(1).unsqueeze(2).astype(ms.bool_)
+                fill_tensor = ops.fill(attn_mask.dtype, attn_mask.shape, ms.Tensor(float(fill_val), dtype=attn_mask.dtype))
+                attn_mask = ops.select(mask_bool, fill_tensor, attn_mask)
+                attn_mask = attn_mask.view(-1, seq_len, seq_len)
+                padding_mask = None
+            return attn_mask, padding_mask
+
+        assert attn_mask is not None
+        attn_mask, padding_mask = fill_attn_mask(attn_mask, padding_mask)
+        # Each layer returns (hidden states, updated attention bias, attn weights);
+        # the bias is threaded through the stack so it accumulates per-layer updates.
+        for i in range(len(self.layers)):
+            x, attn_mask, _ = self.layers[i](
+                x, padding_mask=padding_mask, attn_bias=attn_mask, return_attn=True
+            )
+
+        # Penalty on hidden-state norms deviating from sqrt(dim) by more than
+        # `tolerance`; used as an auxiliary regularization signal.
+        def norm_loss(x, eps=1e-10, tolerance=1.0):
+            x = x.astype(ms.float32)
+            max_norm = x.shape[-1] ** 0.5
+            reduce_sum = ops.ReduceSum(keep_dims=False)
+            norm = ops.sqrt(reduce_sum(x**2, -1) + eps)
+            error = ops.relu((norm - max_norm).abs() - tolerance)
+            return error
+
+        # Mean of `value` restricted to positions where `mask` is nonzero.
+        def masked_mean(mask, value, dim=-1, eps=1e-10):
+            reduce_sum = ops.ReduceSum(keep_dims=False)
+            numerator = reduce_sum(mask * value, dim)
+            denominator = eps + reduce_sum(mask, dim)
+            return (numerator / denominator).mean()
+
+        x_norm = norm_loss(x)
+        # token_mask is 1 for real tokens, 0 for padding.
+        if input_padding_mask is not None:
+            token_mask = 1.0 - input_padding_mask.astype(ms.float32)
+        else:
+            token_mask = ops.ones_like(x_norm)
+        x_norm = masked_mean(token_mask, x_norm)
+
+        if self.final_layer_norm is not None:
+            x = self.final_layer_norm(x)
+
+        # Net change of the pair representation produced by the encoder stack;
+        # padded entries are zeroed (fill value 0 instead of -inf).
+        delta_pair_repr = attn_mask - input_attn_mask
+        delta_pair_repr, _ = fill_attn_mask(delta_pair_repr, input_padding_mask, 0)
+        # Reshape from (bsz*heads, seq, seq) to (bsz, seq, seq, heads).
+        attn_mask = (
+            attn_mask.view(bsz, -1, seq_len, seq_len).transpose(0, 2, 3, 1)
+        )
+        delta_pair_repr = (
+            delta_pair_repr.view(bsz, -1, seq_len, seq_len)
+            .transpose(0, 2, 3, 1)
+        )
+
+        # Outer product of token masks -> mask over (query, key) pairs.
+        pair_mask = token_mask[..., None] * token_mask[..., None, :]
+        delta_pair_repr_norm = norm_loss(delta_pair_repr)
+        delta_pair_repr_norm = masked_mean(
+            pair_mask, delta_pair_repr_norm, dim=(-1, -2)
+        )
+
+        if self.final_head_layer_norm is not None:
+            delta_pair_repr = self.final_head_layer_norm(delta_pair_repr)
+
+        return x, attn_mask, delta_pair_repr, x_norm, delta_pair_repr_norm
diff --git a/MindChem/applications/unimol/unimol_tools/models/transformersv2.py b/MindChem/applications/unimol/unimol_tools/models/transformersv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..615e2bb53d8402ae9522cc802e701c7d01af7127
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/models/transformersv2.py
@@ -0,0 +1,604 @@
+# Copyright (c) DP Technology.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import Optional, Tuple
+
+import mindspore as ms
+import mindspore.ops as ops
+from mindspore import nn
+
+
+def softmax_dropout(
+    input, dropout_prob, is_training=True, mask=None, bias=None, inplace=True
+):
+    """softmax dropout, and mask, bias are optional.
+
+    Computes dropout(softmax(input + mask + bias)) along the last axis.
+    NOTE(review): `input` shadows the builtin of the same name, and the
+    Softmax/Dropout cells are rebuilt on every call — dropout is therefore
+    gated only by `is_training`, not by the enclosing cell's train/eval mode.
+
+    Args:
+        input (Tensor): input tensor
+        dropout_prob (float): dropout probability
+        is_training (bool, optional): is in training or not. Defaults to True.
+        mask (Tensor, optional): the mask tensor, use as input + mask . Defaults to None.
+        bias (Tensor, optional): the bias tensor, use as input + bias . Defaults to None.
+
+    Returns:
+        Tensor: the result after softmax
+    """
+    # When not in-place, work on a copy so the caller's tensor is untouched.
+    if not inplace:
+        input = input.copy()
+    if mask is not None:
+        input = input + mask
+    if bias is not None:
+        input = input + bias
+    softmax = ops.Softmax(axis=-1)
+    dropout = nn.Dropout(p=dropout_prob)
+    return dropout(softmax(input)) if is_training else softmax(input)
+
+
+def permute_final_dims(tensor: ms.Tensor, inds):
+    """Permute only the last ``len(inds)`` dims of ``tensor`` by ``inds``,
+    leaving every leading (batch) dimension in place.
+    """
+    zero_index = -1 * len(inds)
+    first_inds = list(range(len(tensor.shape[:zero_index])))
+    return ops.transpose(tensor, tuple(first_inds + [zero_index + i for i in inds]))
+
+
+class Dropout(nn.Cell):
+    """Thin wrapper around ``nn.Dropout`` whose ``construct`` accepts (and
+    ignores) an ``inplace`` flag for call-site compatibility.
+    """
+
+    def __init__(self, p):
+        super().__init__()
+        self.dropout = nn.Dropout(p=p)
+
+    def construct(self, x, inplace: bool = False):
+        # `inplace` is accepted for API parity but has no effect here.
+        return self.dropout(x)
+
+
+class Linear(nn.Cell):
+    """Dense (fully connected) layer wrapper over ``nn.Dense``.
+
+    NOTE(review): the ``init`` parameter is accepted for compatibility with the
+    original implementation but is currently unused — no custom weight
+    initialization is applied.
+    """
+
+    def __init__(
+        self,
+        d_in: int,
+        d_out: int,
+        bias: bool = True,
+        init: str = "default",
+    ):
+        super().__init__()
+        self.proj = nn.Dense(d_in, d_out, has_bias=bias)
+
+    def construct(self, x):
+        return self.proj(x)
+
+
+class Embedding(nn.Cell):
+    """Wrapper over ``nn.Embedding`` that casts indices to int32 and exposes
+    the embedding table as ``.weight`` (torch-style attribute name).
+    """
+
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: int = None,
+    ):
+        super().__init__()
+        self.emb = nn.Embedding(vocab_size=num_embeddings, embedding_size=embedding_dim, padding_idx=padding_idx)
+
+    @property
+    def weight(self):
+        # Expose the underlying embedding table under the torch-style name.
+        return self.emb.embedding_table
+
+    def construct(self, x):
+        # nn.Embedding expects integer indices; cast defensively.
+        return self.emb(x.astype(ms.int32))
+
+
+class Transition(nn.Cell):
+    """Position-wise feed-forward block: Linear(d_in -> n*d_in) -> GELU ->
+    Dropout -> Linear(n*d_in -> d_in). ``n`` is the hidden expansion factor.
+    """
+
+    def __init__(self, d_in, n, dropout=0.0):
+        super(Transition, self).__init__()
+        self.d_in = d_in
+        self.n = n
+        self.linear_1 = Linear(self.d_in, self.n * self.d_in)
+        self.act = nn.GELU()
+        self.linear_2 = Linear(self.n * self.d_in, d_in)
+        self.dropout = nn.Dropout(p=dropout)
+
+    def _transition(self, x):
+        x = self.linear_1(x)
+        x = self.act(x)
+        x = self.dropout(x)
+        x = self.linear_2(x)
+        return x
+
+    def construct(
+        self,
+        x: ms.Tensor,
+    ) -> ms.Tensor:
+        x = self._transition(x=x)
+        return x
+
+
+class Attention(nn.Cell):
+    """Multi-head attention with an additive bias projected from a pair
+    representation, and an optional sigmoid gating on the output.
+
+    :param q_dim/k_dim/v_dim: input feature dims for queries/keys/values.
+    :param pair_dim: feature dim of the pair tensor used to produce the bias.
+    :param head_dim: per-head dimension; total projection is head_dim*num_heads.
+    :param gating: if True, output is gated by sigmoid(linear_g(q)).
+    :param dropout: attention-probability dropout rate.
+    """
+
+    def __init__(
+        self,
+        q_dim: int,
+        k_dim: int,
+        v_dim: int,
+        pair_dim: int,
+        head_dim: int,
+        num_heads: int,
+        gating: bool = False,
+        dropout: float = 0.0,
+    ):
+        super(Attention, self).__init__()
+        self.num_heads = num_heads
+        total_dim = head_dim * self.num_heads
+        self.gating = gating
+        self.linear_q = Linear(q_dim, total_dim, bias=False)
+        self.linear_k = Linear(k_dim, total_dim, bias=False)
+        self.linear_v = Linear(v_dim, total_dim, bias=False)
+        self.linear_o = Linear(total_dim, q_dim)
+        self.linear_g = Linear(q_dim, total_dim) if self.gating else None
+        # Standard 1/sqrt(head_dim) attention scaling.
+        self.norm = head_dim**-0.5
+        self.dropout = dropout
+        # Projects the pair representation to one additive bias per head.
+        self.linear_bias = Linear(pair_dim, num_heads)
+
+    def construct(
+        self,
+        q: ms.Tensor,
+        k: ms.Tensor,
+        v: ms.Tensor,
+        pair: ms.Tensor,
+        mask: ms.Tensor = None,
+    ) -> ms.Tensor:
+        g = None
+        if self.linear_g is not None:
+            # Gate is computed from the *unprojected* query input.
+            g = self.linear_g(q)
+        q = self.linear_q(q)
+        q = q * self.norm
+        k = self.linear_k(k)
+        v = self.linear_v(v)
+        # (B, L, H*D) -> (B, H, L, D)
+        def split_heads(t):
+            bsz = t.shape[0]
+            head_dim = t.shape[-1] // self.num_heads
+            return ops.transpose(ops.reshape(t, (bsz, -1, self.num_heads, head_dim)), (0, 2, 1, 3))
+        q = split_heads(q)
+        k = split_heads(k)
+        v = split_heads(v)
+        attn = ops.matmul(q, ops.transpose(k, (0, 1, 3, 2)))
+        # Pair bias: (B, L, L, H) -> (B, H, L, L) to match attention logits.
+        bias = ops.transpose(self.linear_bias(pair), (0, 3, 1, 2))
+        attn = softmax_dropout(attn, self.dropout, True, mask=mask, bias=bias)
+        o = ops.matmul(attn, v)
+        # (B, H, L, D) -> (B, L, H*D)
+        o = ops.transpose(o, (0, 2, 1, 3))
+        bsz = o.shape[0]
+        o = ops.reshape(o, (bsz, -1, self.num_heads * (o.shape[-1])))
+        if g is not None:
+            o = ops.sigmoid(g) * o
+        o = self.linear_o(o)
+        return o
+
+
+class OuterProduct(nn.Cell):
+    """Outer-product update from atom features to the pair representation:
+    projects atoms to two d_hid vectors, forms their pairwise outer product,
+    and projects the flattened result to d_pair.
+    """
+
+    def __init__(self, d_atom, d_pair, d_hid=32):
+        super(OuterProduct, self).__init__()
+        self.d_atom = d_atom
+        self.d_pair = d_pair
+        self.d_hid = d_hid
+        self.linear_in = nn.Dense(d_atom, d_hid * 2)
+        self.linear_out = nn.Dense(d_hid**2, d_pair)
+        self.act = nn.GELU()
+
+    def _opm(self, a, b):
+        # Outer product per atom pair: (B,N,1,d,1)*(B,1,N,1,d) -> (B,N,N,d,d),
+        # then flatten the (d,d) block and project to d_pair.
+        bsz, n, d = a.shape
+        a = ops.reshape(a, (bsz, n, 1, d, 1))
+        b = ops.reshape(b, (bsz, 1, n, 1, d))
+        outer = a * b
+        outer = ops.reshape(outer, outer.shape[:-2] + (outer.shape[-2] * outer.shape[-1],))
+        outer = self.linear_out(outer)
+        return outer
+
+    def construct(
+        self,
+        m: ms.Tensor,
+        op_mask: Optional[ms.Tensor] = None,
+        op_norm: Optional[ms.Tensor] = None,
+    ) -> ms.Tensor:
+        ab = self.linear_in(m)
+        # Zero out padded atoms before forming the outer product.
+        ab = ab * op_mask
+        a, b = ops.split(ab, split_size_or_sections=ab.shape[-1] // 2, axis=-1)
+        z = self._opm(a, b)
+        # op_norm rescales each pair by the valid-atom normalization factor.
+        z = z * op_norm
+        return z
+
+
+class AtomFeature(nn.Cell):
+    """Builds initial node features from atom-type and degree embeddings and
+    prepends a learned virtual (graph-level) token to every graph.
+    """
+
+    def __init__(
+        self,
+        num_atom,
+        num_degree,
+        hidden_dim,
+    ):
+        super(AtomFeature, self).__init__()
+        self.atom_encoder = Embedding(num_atom, hidden_dim, padding_idx=0)
+        self.degree_encoder = Embedding(num_degree, hidden_dim, padding_idx=0)
+        # Single learned embedding used as the virtual-node (graph token) feature.
+        self.vnode_encoder = Embedding(1, hidden_dim)
+
+    def construct(self, batched_data, token_feat):
+        x, degree = (
+            batched_data["atom_feat"],
+            batched_data["degree"],
+        )
+        n_graph, n_node = x.shape[:2]
+        # Sum the embeddings of the per-atom feature channels (axis -2).
+        node_feature = ops.reduce_sum(self.atom_encoder(x), -2)
+        dtype = node_feature.dtype
+        degree_feature = self.degree_encoder(degree)
+        node_feature = node_feature + degree_feature + token_feat
+        # Broadcast the virtual-node embedding to every graph and prepend it.
+        graph_token_feature = ops.tile(self.vnode_encoder.weight.view(1, 1, -1), (n_graph, 1, 1))
+        graph_node_feature = ops.concat((graph_token_feature, node_feature), axis=1)
+        return graph_node_feature.astype(dtype)
+
+
+class EdgeFeature(nn.Cell):
+    """Fills the pairwise attention-bias tensor with shortest-path and edge
+    embeddings, plus a learned "virtual distance" for the graph token.
+    Index 0 of ``graph_attn_bias`` corresponds to the prepended virtual node.
+    """
+
+    def __init__(
+        self,
+        pair_dim,
+        num_edge,
+        num_spatial,
+    ):
+        super(EdgeFeature, self).__init__()
+        self.pair_dim = pair_dim
+        self.edge_encoder = Embedding(num_edge, pair_dim, padding_idx=0)
+        self.shorest_path_encoder = Embedding(num_spatial, pair_dim, padding_idx=0)
+        self.vnode_virtual_distance = Embedding(1, pair_dim)
+
+    def construct(self, batched_data, graph_attn_bias):
+        shortest_path = batched_data["shortest_path"]
+        edge_input = batched_data["edge_feat"]
+        # Real-node block gets the shortest-path embedding.
+        graph_attn_bias[:, 1:, 1:, :] = self.shorest_path_encoder(shortest_path)
+        # Virtual-node row/column get the learned virtual distance.
+        t = ops.reshape(self.vnode_virtual_distance.weight, (1, 1, self.pair_dim))
+        graph_attn_bias[:, 1:, 0, :] = t
+        graph_attn_bias[:, 0, :, :] = t
+        # Average the embeddings of the per-edge feature channels (axis -2).
+        edge_input = ops.reduce_mean(self.edge_encoder(edge_input), -2)
+        graph_attn_bias[:, 1:, 1:, :] = graph_attn_bias[:, 1:, 1:, :] + edge_input
+        return graph_attn_bias
+
+
+def gaussian(x, mean, std):
+    """Gaussian probability density of ``x`` under N(mean, std**2).
+
+    NOTE(review): pi is hard-coded to 5 decimal places (3.14159) rather than
+    math.pi — presumably to match the reference implementation; confirm.
+    """
+    pi = 3.14159
+    a = (2 * pi) ** 0.5
+    return ops.exp(-0.5 * (((x - mean) / std) ** 2)) / (a * std)
+
+
+class GaussianKernel(nn.Cell):
+    """Expands scalar pair distances into ``K`` Gaussian basis functions whose
+    means are evenly spaced on [start, stop]; each atom-pair type contributes a
+    learned affine transform (mul, bias) of the distance before expansion.
+    """
+
+    def __init__(self, K=128, num_pair=512, std_width=1.0, start=0.0, stop=9.0):
+        super().__init__()
+        self.K = K
+        mean = ops.linspace(ms.Tensor(start, ms.float32), ms.Tensor(stop, ms.float32), K)
+        # All kernels share one std derived from the spacing between means.
+        self.std = std_width * (mean[1] - mean[0])
+        # Store precomputed mean as a tensor attribute
+        # self.mean_buf = mean
+        self.mean = ms.Parameter(mean, requires_grad=False, name="mean")
+        self.mul = Embedding(num_pair, 1, padding_idx=0)
+        self.bias = Embedding(num_pair, 1, padding_idx=0)
+
+    def construct(self, x, atom_pair):
+        # Squeeze mul/bias from 5-D down to 4-D before reducing, so tensors in
+        # the backward pass stay within MindSpore's 8-dimension limit.
+        # mul = ops.abs(self.mul(atom_pair)).sum(axis=-2)
+        mul = ops.abs(ops.squeeze(self.mul(atom_pair))).sum(axis=-1)
+        # bias = self.bias(atom_pair).sum(axis=-2)
+        bias = ops.squeeze(self.bias(atom_pair)).sum(axis=-1)
+        # x = mul * ops.expand_dims(x, -1) + bias
+        x = mul * x + bias
+        # Broadcast the (affine-transformed) distance across the K kernels.
+        x = ops.tile(ops.expand_dims(x, -1), (1, 1, 1, self.K))
+        mean = ops.reshape(self.mean.astype(ms.float32), (-1,))
+        return gaussian(x.astype(ms.float32), mean, self.std)
+
+
+class NonLinear(nn.Cell):
+    """Two-layer MLP with GELU: Linear(input -> hidden) -> GELU -> Linear(hidden -> output_size).
+
+    NOTE(review): the parameter ``input`` shadows the builtin of the same name.
+    """
+
+    def __init__(self, input, output_size, hidden=None):
+        super(NonLinear, self).__init__()
+        if hidden is None:
+            hidden = input
+        self.layer1 = Linear(input, hidden)
+        self.layer2 = Linear(hidden, output_size)
+        self.gelu = nn.GELU()
+
+    def construct(self, x):
+        x = self.layer1(x)
+        x = self.gelu(x)
+        x = self.layer2(x)
+        return x
+
+    def zero_init(self):
+        # No-op placeholder kept for API compatibility with the reference code.
+        pass
+
+
+class SE3InvariantKernel(nn.Cell):
+    """Distance-based (hence SE(3)-invariant) pair features: Gaussian basis
+    expansion of interatomic distances followed by an MLP projection to
+    ``pair_dim``.
+    """
+
+    def __init__(
+        self,
+        pair_dim,
+        num_pair,
+        num_kernel,
+        std_width=1.0,
+        start=0.0,
+        stop=9.0,
+    ):
+        super(SE3InvariantKernel, self).__init__()
+        self.num_kernel = num_kernel
+        self.gaussian = GaussianKernel(self.num_kernel, num_pair, std_width=std_width, start=start, stop=stop)
+        self.out_proj = NonLinear(self.num_kernel, pair_dim)
+
+    def construct(self, dist, node_type_edge):
+        # Embedding lookup requires integer atom-pair type indices.
+        edge_feature = self.gaussian(dist, node_type_edge.astype(ms.int64))
+        edge_feature = self.out_proj(edge_feature)
+        return edge_feature
+
+
+class MovementPredictionHead(nn.Cell):
+    """Equivariant head predicting a per-atom 3-D displacement ("force") from
+    node features, pair features, and relative positions (``delta_pos``) via
+    attention whose probabilities are modulated by the displacement vectors.
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        pair_dim: int,
+        num_head: int,
+    ):
+        super().__init__()
+        self.layer_norm = nn.LayerNorm([embed_dim])
+        self.embed_dim = embed_dim
+        self.q_proj = Linear(embed_dim, embed_dim, bias=False)
+        self.k_proj = Linear(embed_dim, embed_dim, bias=False)
+        self.v_proj = Linear(embed_dim, embed_dim, bias=False)
+        self.num_head = num_head
+        self.scaling = (embed_dim // num_head) ** -0.5
+        self.force_proj1 = Linear(embed_dim, 1, bias=False)
+        self.linear_bias = Linear(pair_dim, num_head)
+        self.pair_layer_norm = nn.LayerNorm([pair_dim])
+        # Attention-probability dropout rate (fixed).
+        self.dropout = 0.1
+
+    def zero_init(self):
+        # No-op placeholder kept for API compatibility with the reference code.
+        pass
+
+    def construct(
+        self,
+        query,
+        pair,
+        attn_mask,
+        delta_pos,
+    ) -> ms.Tensor:
+        bsz, n_node, _ = query.shape
+        query = self.layer_norm(query)
+        # Q/K/V all come from the same (normalized) node features: self-attention.
+        q = ops.transpose(ops.reshape(self.q_proj(query), (bsz, n_node, self.num_head, -1)), (0, 2, 1, 3)) * self.scaling
+        k = ops.transpose(ops.reshape(self.k_proj(query), (bsz, n_node, self.num_head, -1)), (0, 2, 1, 3))
+        v = ops.transpose(ops.reshape(self.v_proj(query), (bsz, n_node, self.num_head, -1)), (0, 2, 1, 3))
+        attn = ops.matmul(q, ops.transpose(k, (0, 1, 3, 2)))
+        pair = self.pair_layer_norm(pair)
+        # Additive per-head bias from the pair representation.
+        bias = ops.transpose(self.linear_bias(pair), (0, 3, 1, 2))
+        attn_probs = softmax_dropout(
+            attn,
+            self.dropout,
+            True,
+            mask=attn_mask,
+            bias=bias,
+        )
+        attn_probs = ops.reshape(attn_probs, (bsz, self.num_head, n_node, n_node))
+        # Weight attention by the 3 relative-position components:
+        # (B, H, N, N, 3) -> (B, H, 3, N, N) so each axis attends independently.
+        rot_attn_probs = attn_probs.unsqueeze(-1) * delta_pos.unsqueeze(1).astype(attn_probs.dtype)
+        rot_attn_probs = ops.transpose(rot_attn_probs, (0, 1, 4, 2, 3))
+        x = ops.matmul(rot_attn_probs, v.unsqueeze(2))
+        # Collapse heads: (B, H, 3, N, D) -> (B, N, 3, H*D) -> (B, N, 3).
+        x = ops.transpose(x, (0, 3, 2, 1, 4))
+        x = ops.reshape(x, (bsz, n_node, 3, -1))
+        cur_force = ops.reshape(self.force_proj1(x), (bsz, n_node, 3))
+        return cur_force
+
+
+class DropPath(nn.Cell):
+    """Stochastic depth: randomly zeroes whole samples (dim 0) with probability
+    ``prob`` and rescales survivors by 1/keep_prob.
+
+    NOTE(review): there is no check of the cell's training flag, so paths are
+    dropped in eval mode too whenever prob > 0 — confirm intended.
+    """
+
+    def __init__(self, prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = prob if prob is not None else 0.0
+
+    def construct(self, x):
+        if self.drop_prob == 0.0:
+            return x
+        keep_prob = 1 - self.drop_prob
+        # One Bernoulli draw per sample, broadcast over all other dims.
+        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+        random_tensor = keep_prob + ops.uniform(shape, ms.Tensor(0.0, x.dtype), ms.Tensor(1.0, x.dtype))
+        random_tensor = ops.floor(random_tensor)
+        output = x / keep_prob * random_tensor
+        return output
+
+    def extra_repr(self) -> str:
+        return f"prob={self.drop_prob}"
+
+
+class TriangleMultiplication(nn.Cell):
+    """Triangle multiplicative update on the pair representation (AlphaFold
+    style): gated projections a, b are combined through two batched matmuls
+    (summing what appear to be the "outgoing" and "incoming" edge terms —
+    verify against the reference implementation), then normalized, projected
+    back to d_pair and gated.
+    """
+
+    def __init__(self, d_pair, d_hid):
+        super(TriangleMultiplication, self).__init__()
+        self.linear_ab_p = Linear(d_pair, d_hid * 2)
+        self.linear_ab_g = Linear(d_pair, d_hid * 2)
+        self.linear_g = Linear(d_pair, d_pair)
+        self.linear_z = Linear(d_hid, d_pair)
+        self.layer_norm_out = nn.LayerNorm([d_hid])
+
+    def construct(
+        self,
+        z: ms.Tensor,
+        mask: Optional[ms.Tensor] = None,
+    ) -> ms.Tensor:
+        # Mask padded pairs and fold in a 1/sqrt(N) normalization.
+        mask = mask.unsqueeze(-1)
+        mask = mask * (mask.shape[-2] ** -0.5)
+        g = self.linear_g(z)
+        # Sigmoid-gated projection of the pair features, split into a and b.
+        ab = self.linear_ab_p(z) * mask * ops.sigmoid(self.linear_ab_g(z))
+        a, b = ops.split(ab, split_size_or_sections=ab.shape[-1] // 2, axis=-1)
+        a1 = permute_final_dims(a, (2, 0, 1))
+        b1 = b.transpose(0, 3, 2, 1)
+        # b1 = ops.transpose(b, (0, 2, 1))
+        x = ops.matmul(a1, b1)
+        b2 = permute_final_dims(b, (2, 0, 1))
+        a2 = a.transpose(0, 3, 2, 1)
+        # a2 = ops.transpose(a, (0, 2, 1))
+        x = x + ops.matmul(a2, b2)
+        x = permute_final_dims(x, (1, 2, 0))
+        x = self.layer_norm_out(x)
+        x = self.linear_z(x)
+        # Final sigmoid-free gate from linear_g.
+        return g * x
+
+
+def get_activation_fn(activation):
+    """Return the activation Cell named by ``activation``.
+
+    :param activation: one of "relu", "gelu", "tanh", "linear".
+    :raises RuntimeError: for any other name.
+    """
+    if activation == "relu":
+        return nn.ReLU()
+    elif activation == "gelu":
+        return nn.GELU()
+    elif activation == "tanh":
+        return nn.Tanh()
+    elif activation == "linear":
+        return nn.Identity()
+    else:
+        raise RuntimeError("--activation-fn {} not supported".format(activation))
+
+
+class TransformerEncoderLayerV2(nn.Cell):
+    """Uni-Mol2 encoder layer: pre-LN self-attention with pair bias, an FFN on
+    the node track, plus three pair-track updates (outer product from nodes,
+    triangle multiplication, and a pair FFN).
+
+    :param embedding_dim: node feature dimension.
+    :param pair_dim: pair feature dimension.
+    :param pair_hidden_dim: hidden dim of the outer-product / triangle updates.
+    :param ffn_embedding_dim: FFN hidden size (converted to an integer
+        expansion factor ffn_embedding_dim // embedding_dim).
+    :param droppath_prob: if > 0, DropPath replaces plain dropout on residuals.
+    :param pair_dropout: row-shared dropout rate on the triangle update.
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int = 768,
+        pair_dim: int = 64,
+        pair_hidden_dim: int = 32,
+        ffn_embedding_dim: int = 3072,
+        num_attention_heads: int = 8,
+        dropout: float = 0.1,
+        attention_dropout: float = 0.1,
+        activation_dropout: float = 0.1,
+        activation_fn: str = "relu",
+        droppath_prob: float = 0.0,
+        pair_dropout: float = 0.25,
+    ) -> None:
+        super().__init__()
+        self.embedding_dim = embedding_dim
+        self.num_attention_heads = num_attention_heads
+        self.attention_dropout = attention_dropout
+        # Stochastic depth on residual branches when droppath_prob > 0,
+        # otherwise ordinary dropout.
+        self.dropout_module = DropPath(droppath_prob) if droppath_prob > 0.0 else Dropout(dropout)
+        self.activation_fn = get_activation_fn(activation_fn)
+        head_dim = self.embedding_dim // self.num_attention_heads
+        self.self_attn = Attention(
+            self.embedding_dim,
+            self.embedding_dim,
+            self.embedding_dim,
+            pair_dim=pair_dim,
+            head_dim=head_dim,
+            num_heads=self.num_attention_heads,
+            gating=False,
+            dropout=attention_dropout,
+        )
+        self.self_attn_layer_norm = nn.LayerNorm([self.embedding_dim])
+        self.ffn = Transition(
+            self.embedding_dim,
+            ffn_embedding_dim // self.embedding_dim,
+            dropout=activation_dropout,
+        )
+        self.final_layer_norm = nn.LayerNorm([self.embedding_dim])
+        self.x_layer_norm_opm = nn.LayerNorm([self.embedding_dim])
+        self.opm = OuterProduct(self.embedding_dim, pair_dim, d_hid=pair_hidden_dim)
+        self.pair_layer_norm_ffn = nn.LayerNorm([pair_dim])
+        self.pair_ffn = Transition(
+            pair_dim,
+            1,
+            dropout=activation_dropout,
+        )
+        self.pair_dropout = pair_dropout
+        self.pair_layer_norm_trimul = nn.LayerNorm([pair_dim])
+        self.pair_tri_mul = TriangleMultiplication(pair_dim, pair_hidden_dim)
+
+    def shared_dropout(self, x, shared_dim, dropout):
+        # Dropout with the Bernoulli mask broadcast along `shared_dim`, so a
+        # whole slice (e.g. a pair-matrix row) is kept or dropped together.
+        # NOTE(review): nn.Dropout is built per call, so it is active in eval
+        # mode too — confirm intended.
+        shape = list(x.shape)
+        shape[shared_dim] = 1
+        mask = ops.ones(tuple(shape), dtype=x.dtype)
+        return nn.Dropout(p=dropout)(mask) * x
+
+    def construct(
+        self,
+        x: ms.Tensor,
+        pair: ms.Tensor,
+        pair_mask: ms.Tensor,
+        self_attn_mask: Optional[ms.Tensor] = None,
+        op_mask: Optional[ms.Tensor] = None,
+        op_norm: Optional[ms.Tensor] = None,
+    ):
+        # --- node track: pre-LN self-attention + FFN, both residual ---
+        residual = x
+        x = self.self_attn_layer_norm(x)
+        x = self.self_attn(
+            x,
+            x,
+            x,
+            pair=pair,
+            mask=self_attn_mask,
+        )
+        x = self.dropout_module(x)
+        x = residual + x
+        x = x + self.dropout_module(self.ffn(self.final_layer_norm(x)))
+        # --- pair track: outer product, triangle multiplication, pair FFN ---
+        pair = pair + self.dropout_module(
+            self.opm(self.x_layer_norm_opm(x), op_mask, op_norm)
+        )
+        pair = pair + self.shared_dropout(
+            self.pair_tri_mul(self.pair_layer_norm_trimul(pair), pair_mask),
+            -3,
+            self.pair_dropout,
+        )
+        pair = pair + self.dropout_module(self.pair_ffn(self.pair_layer_norm_ffn(pair)))
+        return x, pair
+
+
+class TransformerEncoderWithPairV2(nn.Cell):
+    """Uni-Mol2 encoder stack: normalizes the node and pair inputs, precomputes
+    the outer-product masks/normalizers once, and threads (x, pair) through
+    ``num_encoder_layers`` TransformerEncoderLayerV2 cells. DropPath rates are
+    linearly increased from 0 to ``droppath_prob`` across layers.
+    """
+
+    def __init__(
+        self,
+        num_encoder_layers: int = 6,
+        embedding_dim: int = 768,
+        pair_dim: int = 64,
+        pair_hidden_dim: int = 32,
+        ffn_embedding_dim: int = 3072,
+        num_attention_heads: int = 8,
+        dropout: float = 0.1,
+        attention_dropout: float = 0.1,
+        activation_dropout: float = 0.0,
+        activation_fn: str = "gelu",
+        droppath_prob: float = 0.0,
+        pair_dropout: float = 0.25,
+    ) -> None:
+        super().__init__()
+        self.embedding_dim = embedding_dim
+        self.num_head = num_attention_heads
+        self.layer_norm = nn.LayerNorm([embedding_dim])
+        self.pair_layer_norm = nn.LayerNorm([pair_dim])
+        self.layers = nn.CellList([])
+        # Per-layer stochastic-depth schedule: 0 -> droppath_prob, linear.
+        droppath_probs = [float(x.asnumpy()) for x in ops.linspace(ms.Tensor(0.0, ms.float32), ms.Tensor(droppath_prob, ms.float32), num_encoder_layers)] if droppath_prob > 0 else None
+        self.layers.extend(
+            [
+                TransformerEncoderLayerV2(
+                    embedding_dim=embedding_dim,
+                    pair_dim=pair_dim,
+                    pair_hidden_dim=pair_hidden_dim,
+                    ffn_embedding_dim=ffn_embedding_dim,
+                    num_attention_heads=num_attention_heads,
+                    dropout=dropout,
+                    attention_dropout=attention_dropout,
+                    activation_dropout=activation_dropout,
+                    activation_fn=activation_fn,
+                    droppath_prob=(droppath_probs[i] if droppath_probs is not None else 0),
+                    pair_dropout=pair_dropout,
+                )
+                for i in range(num_encoder_layers)
+            ]
+        )
+
+    def construct(
+        self,
+        x,
+        pair,
+        atom_mask,
+        pair_mask,
+        attn_mask=None,
+    ) -> Tuple[ms.Tensor, ms.Tensor]:
+        x = self.layer_norm(x)
+        pair = self.pair_layer_norm(pair)
+        # Outer-product mask (zero for padded atoms) with 1/sqrt(N) scaling;
+        # computed once here and reused by every layer's OuterProduct update.
+        op_mask = atom_mask.unsqueeze(-1)
+        op_mask = op_mask * (op_mask.shape[-2] ** -0.5)
+        eps = 1e-3
+        # Compute outer product along the last dim using batched matmul: (B, N, 1) x (B, 1, N) -> (B, N, N), then expand last dim
+        op_outer = ops.matmul(op_mask, ops.transpose(op_mask, (0, 2, 1)))
+        op_norm = 1.0 / (eps + op_outer).unsqueeze(-1)
+        for layer in self.layers:
+            x, pair = layer(
+                x,
+                pair,
+                pair_mask=pair_mask,
+                self_attn_mask=attn_mask,
+                op_mask=op_mask,
+                op_norm=op_norm,
+            )
+        return x, pair
diff --git a/MindChem/applications/unimol/unimol_tools/models/unimol.py b/MindChem/applications/unimol/unimol_tools/models/unimol.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2e29076d010b8a2a57a4fb16f301c0672190838
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/models/unimol.py
@@ -0,0 +1,719 @@
+# Copyright (c) DP Technology.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import absolute_import, division, print_function
+
+import os
+
+# import argparse
+import pathlib
+
+import mindspore as ms
+import mindspore.ops as ops
+from mindspore import nn
+from addict import Dict
+import numpy as np
+
+from ..config import MODEL_CONFIG
+from ..data import Dictionary
+from ..utils import logger
+from ..weights import WEIGHT_DIR, weight_download
+from .transformers import TransformerEncoderWithPair
+
+# Registry mapping backbone identifiers (as named in the architecture config)
+# to encoder implementations.
+BACKBONE = {
+    'transformer': TransformerEncoderWithPair,
+}
+
+
+class UniMolModel(nn.Cell):
+    """
+    UniMolModel is a specialized model for molecular, protein, crystal, or MOF (Metal-Organic Frameworks) data.
+    It dynamically configures its architecture based on the type of data it is intended to work with. The model
+    supports multiple data types and incorporates various architecture configurations and pretrained weights.
+
+    Attributes:
+    - output_dim: The dimension of the output layer.
+    - data_type: The type of data the model is designed to handle.
+    - remove_hs: Flag to indicate whether hydrogen atoms are removed in molecular data.
+    - pretrain_path: Path to the pretrained model weights.
+    - dictionary: The dictionary object used for tokenization and encoding.
+    - mask_idx: Index of the mask token in the dictionary.
+    - padding_idx: Index of the padding token in the dictionary.
+    - embed_tokens: Embedding layer for token embeddings.
+    - encoder: Transformer encoder backbone of the model.
+    - gbf_proj, gbf: Layers for Gaussian basis functions or numerical embeddings.
+    - classification_head: The final classification head of the model.
+    """
+
+    def __init__(self, output_dim=2, data_type='molecule', **params):
+        """
+        Initializes the UniMolModel with specified parameters and data type.
+
+        :param output_dim: (int) The number of output dimensions (classes).
+        :param data_type: (str) The type of data (e.g., 'molecule', 'protein').
+        :param params: Additional parameters for model configuration.
+        """
+        super().__init__()
+        # Select the architecture hyperparameters matching the data type.
+        if data_type == 'molecule':
+            self.args = molecule_architecture()
+        elif data_type == 'oled':
+            self.args = oled_architecture()
+        elif data_type == 'protein':
+            self.args = protein_architecture()
+        elif data_type == 'crystal':
+            self.args = crystal_architecture()
+        else:
+            raise ValueError('Current not support data type: {}'.format(data_type))
+        self.output_dim = output_dim
+        self.data_type = data_type
+        self.remove_hs = params.get('remove_hs', False)
+        # Molecule checkpoints come in with-H / without-H variants.
+        if data_type == 'molecule':
+            name = "no_h" if self.remove_hs else "all_h"
+            name = data_type + '_' + name
+        else:
+            name = data_type
+        # Download weight / dictionary files on first use.
+        if not os.path.exists(os.path.join(WEIGHT_DIR, MODEL_CONFIG['weight'][name])):
+            weight_download(MODEL_CONFIG['weight'][name], WEIGHT_DIR)
+        if not os.path.exists(os.path.join(WEIGHT_DIR, MODEL_CONFIG['dict'][name])):
+            weight_download(MODEL_CONFIG['dict'][name], WEIGHT_DIR)
+        # Prefer MindSpore .ckpt if available alongside the original .pt
+        original_weight_path = os.path.join(WEIGHT_DIR, MODEL_CONFIG['weight'][name])
+        candidate_ckpt_path = original_weight_path[:-3] + 'ckpt' if original_weight_path.endswith('.pt') else None
+        # The UNIMOL_MS_CKPT environment variable overrides any local weight file.
+        env_override_ckpt = os.environ.get('UNIMOL_MS_CKPT', None)
+        if env_override_ckpt and os.path.exists(env_override_ckpt):
+            self.pretrain_path = env_override_ckpt
+        elif candidate_ckpt_path and os.path.exists(candidate_ckpt_path):
+            self.pretrain_path = candidate_ckpt_path
+        else:
+            self.pretrain_path = original_weight_path
+        self.dictionary = Dictionary.load(
+            os.path.join(WEIGHT_DIR, MODEL_CONFIG['dict'][name])
+        )
+        self.mask_idx = self.dictionary.add_symbol("[MASK]", is_special=True)
+        self.padding_idx = self.dictionary.pad()
+        self.embed_tokens = nn.Embedding(
+            vocab_size=len(self.dictionary),
+            embedding_size=self.args.encoder_embed_dim,
+            padding_idx=self.padding_idx,
+        )
+        self.encoder = BACKBONE[self.args.backbone](
+            encoder_layers=self.args.encoder_layers,
+            embed_dim=self.args.encoder_embed_dim,
+            ffn_embed_dim=self.args.encoder_ffn_embed_dim,
+            attention_heads=self.args.encoder_attention_heads,
+            emb_dropout=self.args.emb_dropout,
+            dropout=self.args.dropout,
+            attention_dropout=self.args.attention_dropout,
+            activation_dropout=self.args.activation_dropout,
+            max_seq_len=self.args.max_seq_len,
+            activation_fn=self.args.activation_fn,
+            no_final_head_layer_norm=self.args.delta_pair_repr_norm_loss < 0,
+        )
+        # K Gaussian kernels over pairwise distances; one edge type per
+        # (token, token) pair.
+        K = 128
+        n_edge_type = len(self.dictionary) * len(self.dictionary)
+        self.gbf_proj = NonLinearHead(
+            K, self.args.encoder_attention_heads, self.args.activation_fn
+        )
+        if self.args.kernel == 'gaussian':
+            self.gbf = GaussianLayer(K, n_edge_type)
+        else:
+            self.gbf = NumericalEmbed(K, n_edge_type)
+
+        # The following heads mirror the pretraining checkpoint's structure so
+        # its parameters load cleanly; they are not used for fine-tune logits.
+        self.lm_head = MaskLMHead(
+            embed_dim=self.args.encoder_embed_dim,
+            output_dim=len(self.dictionary),
+            activation_fn=self.args.activation_fn,
+        )
+        # pair2coord_proj exists when masked coord loss is used in PT
+        self.pair2coord_proj = NonLinearHead(
+            self.args.encoder_attention_heads, 1, self.args.activation_fn
+        )
+        # dist_head exists when masked dist loss is used in PT
+        self.dist_head = DistanceHead(
+            input_dim=self.args.encoder_attention_heads,
+            activation_fn=self.args.activation_fn,
+        )
+        """
+        # To be deprecated in the future.
+        self.classification_head = ClassificationHead(
+            input_dim=self.args.encoder_embed_dim,
+            inner_dim=self.args.encoder_embed_dim,
+            num_classes=self.output_dim,
+            activation_fn=self.args.pooler_activation_fn,
+            pooler_dropout=self.args.pooler_dropout,
+        )
+        """
+        if 'pooler_dropout' in params:
+            self.args.pooler_dropout = params['pooler_dropout']
+        self.classification_head = LinearHead(
+            input_dim=self.args.encoder_embed_dim,
+            num_classes=self.output_dim,
+            pooler_dropout=self.args.pooler_dropout,
+        )
+        self.load_pretrained_weights(path=self.pretrain_path)
+
+    def load_pretrained_weights(self, path, strict=False):
+        """
+        Loads pretrained weights into the model.
+
+        Only MindSpore ``.ckpt`` files are loaded; any other format is skipped
+        with a warning. The checkpoint's classification-head entries are
+        replaced by the freshly initialized head parameters, since the
+        fine-tune head's output size may differ from the pretraining one.
+
+        :param path: (str) Path to the pretrained weight file.
+        :param strict: accepted for API compatibility; currently unused —
+            loading always goes through ``load_param_into_net`` defaults.
+        """
+        if path is not None:
+            if path.endswith('.ckpt') and os.path.exists(path):
+                logger.info("Loading MindSpore checkpoint from {}".format(path))
+                from mindspore import load_checkpoint, load_param_into_net
+                param_dict = load_checkpoint(path)
+                param_dict["classification_head.out_proj.weight"] = self.classification_head.out_proj.weight
+                param_dict["classification_head.out_proj.bias"] = self.classification_head.out_proj.bias
+                load_param_into_net(self, param_dict)
+            else:
+                logger.warning(
+                    "Pretrained weights are not in MindSpore .ckpt format ({}). Skipping load."
+                    .format(path)
+                )
+
+    @classmethod
+    def build_model(cls, args):
+        """
+        Class method to build a new instance of the UniMolModel.
+
+        :param args: Arguments for model configuration.
+        :return: An instance of UniMolModel.
+        """
+        return cls(args)
+
+    def construct(
+        self,
+        src_tokens,
+        src_distance,
+        src_coord,
+        src_edge_type,
+        return_repr=False,
+        return_atomic_reprs=False,
+        **kwargs
+    ):
+        """
+        Defines the forward pass of the model.
+
+        :param src_tokens: Tokenized input data.
+        :param src_distance: Additional molecular features.
+        :param src_coord: Additional molecular features.
+        :param src_edge_type: Additional molecular features.
+        :param gas_id: Optional environmental features for MOFs.
+        :param gas_attr: Optional environmental features for MOFs.
+        :param pressure: Optional environmental features for MOFs.
+        :param temperature: Optional environmental features for MOFs.
+        :param return_repr: Flags to return intermediate representations.
+        :param return_atomic_reprs: Flags to return intermediate representations.
+
+        :return: Output logits or requested intermediate representations.
+        """
+        padding_mask = ops.equal(src_tokens, self.padding_idx)
+        # Drop the mask entirely when nothing is padded (saves masking work).
+        if not ops.any(padding_mask):
+            padding_mask = None
+        x = self.embed_tokens(src_tokens)
+
+        # Gaussian-basis distance features -> per-head attention bias of shape
+        # (bsz*heads, n_node, n_node).
+        def get_dist_features(dist, et):
+            n_node = dist.shape[-1]
+            gbf_feature = self.gbf(dist, et)
+            gbf_result = self.gbf_proj(gbf_feature)
+            graph_attn_bias = gbf_result
+            graph_attn_bias = ops.transpose(graph_attn_bias, (0, 3, 1, 2))
+            graph_attn_bias = graph_attn_bias.view(-1, n_node, n_node)
+            return graph_attn_bias
+
+        graph_attn_bias = get_dist_features(src_distance, src_edge_type)
+        (
+            encoder_rep,
+            _,
+            _,
+            _,
+            _,
+        ) = self.encoder(x, padding_mask=padding_mask, attn_mask=graph_attn_bias)
+        # Position 0 is the [CLS]-style token used for pooled classification.
+        cls_repr = encoder_rep[:, 0, :]
+        all_repr = encoder_rep[:, :, :]
+
+        if return_repr:
+            # Keep only real-atom tokens; ids 0/1/2 are presumably the
+            # pad/bos/eos special symbols — verify against the dictionary.
+            filtered_tensors = []
+            filtered_coords = []
+            for tokens, coord in zip(src_tokens, src_coord):
+                mask = (tokens != 0) & (tokens != 1) & (tokens != 2)
+                filtered_tensor = tokens[mask]
+                filtered_coord = coord[mask]
+                filtered_tensors.append(filtered_tensor)
+                filtered_coords.append(filtered_coord)
+
+            lengths = [
+                len(filtered_tensor) for filtered_tensor in filtered_tensors
+            ]  # Compute the lengths of the filtered tensors
+            if return_atomic_reprs:
+                cls_atomic_reprs = []
+                atomic_symbols = []
+                for i in range(len(all_repr)):
+                    # Skip the leading CLS token; take one row per real atom.
+                    atomic_reprs = encoder_rep[i, 1 : lengths[i] + 1, :]
+                    atomic_symbol = []
+                    for atomic_num in filtered_tensors[i]:
+                        atomic_symbol.append(self.dictionary.symbols[atomic_num])
+                    atomic_symbols.append(atomic_symbol)
+                    cls_atomic_reprs.append(atomic_reprs)
+                return {
+                    'cls_repr': cls_repr,
+                    'atomic_symbol': atomic_symbols,
+                    'atomic_coords': filtered_coords,
+                    'atomic_reprs': cls_atomic_reprs,
+                }
+            else:
+                return {'cls_repr': cls_repr}
+
+        logits = self.classification_head(cls_repr)
+        return logits
+
+    def batch_collate_fn(self, samples):
+        """
+        Custom collate function for batch processing non-MOF data.
+
+        Pads each field of the samples to the max length in the batch:
+        tokens to (B, L), distance/edge-type matrices to (B, L, L), and
+        coordinates to (B, L, 3).
+
+        :param samples: A list of sample data.
+
+        :return: A tuple containing a batch dictionary and labels.
+        """
+        import numpy as np
+
+        # Right-pad 1-D int sequences with pad_idx -> (B, max_len) int64 Tensor.
+        def pad_1d(arrs, pad_idx):
+            max_len = max(len(a) for a in arrs)
+            out = np.full((len(arrs), max_len), pad_idx, dtype=np.int64)
+            for i, a in enumerate(arrs):
+                out[i, : len(a)] = a
+            return ms.Tensor(out, ms.int64)
+
+        # Embed square matrices in the top-left of a (max_len, max_len) pad.
+        def pad_2d_square(arrs, pad_idx, dtype):
+            max_len = max(a.shape[0] for a in arrs)
+            out = np.full((len(arrs), max_len, max_len), pad_idx, dtype=dtype)
+            for i, a in enumerate(arrs):
+                L = a.shape[0]
+                out[i, :L, :L] = a
+            return ms.Tensor(out)
+
+        # Pad (N, 3) coordinate arrays with pad_val rows -> (B, max_len, 3).
+        def pad_coords_np(arrs, pad_val):
+            max_len = max(a.shape[0] for a in arrs)
+            out = np.full((len(arrs), max_len, 3), pad_val, dtype=np.float32)
+            for i, a in enumerate(arrs):
+                L = a.shape[0]
+                out[i, :L, :] = a
+            return ms.Tensor(out, ms.float32)
+
+        batch = {}
+        for k in samples[0][0].keys():
+            if k == 'src_coord':
+                v = pad_coords_np([samples[i][0][k].astype('float32') for i in range(len(samples))], 0.0)
+            elif k == 'src_edge_type':
+                v = pad_2d_square([samples[i][0][k].astype('int64') for i in range(len(samples))], self.padding_idx, dtype='int64')
+            elif k == 'src_distance':
+                v = pad_2d_square([samples[i][0][k].astype('float32') for i in range(len(samples))], 0.0, dtype='float32')
+            elif k == 'src_tokens':
+                v = pad_1d([samples[i][0][k].astype('int64') for i in range(len(samples))], self.padding_idx)
+            batch[k] = v
+        # Labels may be absent or non-tensorizable (e.g. inference data);
+        # fall back to None rather than failing the whole batch.
+        try:
+            label = ms.Tensor([s[1] for s in samples])
+        except Exception:
+            label = None
+        return batch, label
+
+
class LinearHead(nn.Cell):
    """Single-layer projection head with dropout applied to its input."""

    def __init__(
        self,
        input_dim,
        num_classes,
        pooler_dropout,
    ):
        """
        Build the linear head.

        :param input_dim: Dimension of input features.
        :param num_classes: Number of classes for output.
        :param pooler_dropout: Dropout probability applied before projection.
        """
        super().__init__()
        self.out_proj = nn.Dense(input_dim, num_classes)
        self.dropout = nn.Dropout(p=pooler_dropout)

    def construct(self, features, **kwargs):
        """
        Project dropped-out features to output logits.

        :param features: Input feature tensor.

        :return: Projected output tensor.
        """
        return self.out_proj(self.dropout(features))
+
+
class ClassificationHead(nn.Cell):
    """Head for sentence-level classification tasks."""

    def __init__(
        self,
        input_dim,
        inner_dim,
        num_classes,
        activation_fn,
        pooler_dropout,
    ):
        """
        Build the classification head.

        :param input_dim: Dimension of input features.
        :param inner_dim: Dimension of the inner (hidden) layer.
        :param num_classes: Number of classes for classification.
        :param activation_fn: Name of the activation function.
        :param pooler_dropout: Dropout rate applied before each linear layer.
        """
        super().__init__()
        self.dense = nn.Dense(input_dim, inner_dim)
        self.activation_fn = get_activation_fn(activation_fn)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Dense(inner_dim, num_classes)

    def construct(self, features, **kwargs):
        """
        Run dropout -> dense -> activation -> dropout -> projection.

        :param features: Input features for classification.

        :return: Classification logits.
        """
        hidden = self.dense(self.dropout(features))
        hidden = self.dropout(self.activation_fn(hidden))
        return self.out_proj(hidden)
+
+
class NonLinearHead(nn.Cell):
    """
    Two-layer perceptron used for simple prediction tasks: a linear layer,
    a nonlinearity, and a second linear layer to the output dimension.

    Attributes:
    - linear1: The first linear layer.
    - linear2: The second linear layer producing the output.
    - activation_fn: The nonlinear activation between the two layers.
    """

    def __init__(
        self,
        input_dim,
        out_dim,
        activation_fn,
        hidden=None,
    ):
        """
        Build the NonLinearHead.

        :param input_dim: Dimension of the input features.
        :param out_dim: Dimension of the output.
        :param activation_fn: The activation function to use.
        :param hidden: Hidden-layer width; defaults to ``input_dim`` when falsy.
        """
        super().__init__()
        hidden_dim = hidden if hidden else input_dim
        self.linear1 = nn.Dense(input_dim, hidden_dim)
        self.linear2 = nn.Dense(hidden_dim, out_dim)
        self.activation_fn = get_activation_fn(activation_fn)

    def construct(self, x):
        """
        Apply linear1 -> activation -> linear2.

        :param x: Input tensor.

        :return: Output tensor of width ``out_dim``.
        """
        hidden = self.activation_fn(self.linear1(x))
        return self.linear2(hidden)
+
+
def gaussian(x, mean, std):
    """
    Evaluate a Gaussian probability density elementwise.

    :param x: The input tensor.
    :param mean: Means of the Gaussian kernels.
    :param std: Standard deviations of the Gaussian kernels.
    :return: Tensor of Gaussian densities, broadcast over ``x``.
    """
    # The truncated constant 3.14159 (not math.pi) is kept deliberately:
    # pretrained Uni-Mol weights were produced with this exact value.
    z = (x - mean) / std
    return ops.exp(-0.5 * (z ** 2)) / (((2 * 3.14159) ** 0.5) * std)
+
+
def get_activation_fn(activation):
    """Returns the activation function corresponding to `activation`"""

    factories = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "tanh": nn.Tanh,
        "linear": nn.Identity,
    }
    if activation not in factories:
        raise RuntimeError("--activation-fn {} not supported".format(activation))
    # A fresh instance is created per call, matching the original behavior.
    return factories[activation]()
+
+
class GaussianLayer(nn.Cell):
    """
    Gaussian kernel expansion layer for pairwise features in graph networks.

    Attributes:
    - K: Number of Gaussian kernels.
    - means, stds: Embeddings holding kernel means and standard deviations.
    - mul, bias: Per-edge-type affine parameters applied to the input.
    """

    def __init__(self, K=128, edge_types=1024):
        """
        Build the GaussianLayer.

        :param K: Number of Gaussian kernels.
        :param edge_types: Number of different edge types to consider.
        """
        super().__init__()
        self.K = K
        self.means = nn.Embedding(1, K)
        self.stds = nn.Embedding(1, K)
        self.mul = nn.Embedding(edge_types, 1)
        self.bias = nn.Embedding(edge_types, 1)
        # Initialize with NumPy to avoid MindSpore UniformExt kernel dependency
        # on some backends. Fillers run in the same order as before (means,
        # stds, bias, mul) so the NumPy RNG stream is unchanged.
        init_specs = (
            (self.means, lambda s: np.random.uniform(0.0, 3.0, size=s)),
            (self.stds, lambda s: np.random.uniform(0.0, 3.0, size=s)),
            (self.bias, lambda s: np.zeros(s, dtype=np.float32)),
            (self.mul, lambda s: np.ones(s, dtype=np.float32)),
        )
        for emb, filler in init_specs:
            table = emb.embedding_table
            table.set_data(
                ms.Tensor(filler(table.asnumpy().shape), dtype=table.dtype)
            )

    def construct(self, x, edge_type):
        """
        Expand inputs over K Gaussian kernels after a per-edge-type affine map.

        :param x: Input tensor representing distances or other features.
        :param edge_type: Tensor indicating types of edges in the graph.

        :return: Tensor transformed by the Gaussian layer.
        """
        mul = self.mul(edge_type).astype(x.dtype)
        bias = self.bias(edge_type).astype(x.dtype)
        scaled = mul * ops.expand_dims(x, -1) + bias
        scaled = ops.tile(scaled, (1, 1, 1, self.K))
        mean = ops.reshape(self.means.embedding_table.astype(ms.float32), (-1,))
        # Keep std strictly positive to avoid division by zero.
        std = ops.reshape(self.stds.embedding_table.astype(ms.float32), (-1,)).abs() + 1e-5
        out = gaussian(scaled.astype(ms.float32), mean, std)
        return out.astype(self.means.embedding_table.dtype)
+
+
class NumericalEmbed(nn.Cell):
    """
    Numerical embedding module, typically used for embedding edge features in graph neural networks.

    Attributes:
    - K: Output dimension for embeddings.
    - mul, bias, w_edge: Embeddings for transformation parameters.
    - proj: Projection layer to transform inputs.
    - ln: Layer normalization.
    """

    def __init__(self, K=128, edge_types=1024, activation_fn='gelu'):
        """
        Initializes the NumericalEmbed module.

        :param K: Output embedding dimension.
        :param edge_types: Number of different edge types to consider.
        :param activation_fn: Activation used inside the projection head.
        """
        super().__init__()
        self.K = K
        self.mul = nn.Embedding(edge_types, 1)
        self.bias = nn.Embedding(edge_types, 1)
        self.w_edge = nn.Embedding(edge_types, K)

        self.proj = NonLinearHead(1, K, activation_fn, hidden=2 * K)
        self.ln = nn.LayerNorm([K])

        # Initialize via NumPy for consistency with GaussianLayer above and to
        # avoid relying on Tensor.fill semantics: bias starts at 0, mul at 1.
        self.bias.embedding_table.set_data(
            ms.Tensor(
                np.zeros_like(self.bias.embedding_table.asnumpy()),
                dtype=self.bias.embedding_table.dtype,
            )
        )
        self.mul.embedding_table.set_data(
            ms.Tensor(
                np.ones_like(self.mul.embedding_table.asnumpy()),
                dtype=self.mul.embedding_table.dtype,
            )
        )

    def construct(self, x, edge_type):  # edge_type, atoms
        """
        Forward pass of the NumericalEmbed module.

        :param x: Numerical edge features (e.g. distances).
        :param edge_type: Tensor indicating types of edges in the graph.

        :return: Combined edge embedding tensor.
        """
        mul = self.mul(edge_type).astype(x.dtype)
        bias = self.bias(edge_type).astype(x.dtype)
        w_edge = self.w_edge(edge_type).astype(x.dtype)
        # Gated embedding: per-edge-type affine transform squashed by sigmoid
        # and scaled by a learned edge weight.
        edge_emb = w_edge * ops.sigmoid(mul * ops.expand_dims(x, -1) + bias)

        edge_proj = ops.expand_dims(x, -1).astype(self.mul.embedding_table.dtype)
        edge_proj = self.proj(edge_proj)
        edge_proj = self.ln(edge_proj)

        h = edge_proj + edge_emb
        h = h.astype(self.mul.embedding_table.dtype)
        return h
+
+
class MaskLMHead(nn.Cell):
    """Masked-LM head projecting encoder features back to vocabulary logits."""

    def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
        """
        Build the masked-LM head.

        :param embed_dim: Encoder feature dimension.
        :param output_dim: Vocabulary size of the output projection.
        :param activation_fn: Name of the activation function.
        :param weight: Unused; kept for interface compatibility.
        """
        super().__init__()
        self.dense = nn.Dense(embed_dim, embed_dim)
        self.activation_fn = get_activation_fn(activation_fn)
        self.layer_norm = nn.LayerNorm([embed_dim])
        # Define parameters to match PT names: lm_head.weight, lm_head.bias
        self.weight = ms.Parameter(ms.Tensor(np.zeros((output_dim, embed_dim), dtype=np.float32)), name='weight')
        self.bias = ms.Parameter(ms.Tensor(np.zeros((output_dim,), dtype=np.float32)), name='bias')

    def construct(self, features, masked_tokens=None, **kwargs):
        """
        Transform features and project to vocabulary logits.

        :param features: Encoder output features.
        :param masked_tokens: Unused; kept for interface compatibility.

        :return: Vocabulary logits.
        """
        hidden = self.layer_norm(self.activation_fn(self.dense(features)))
        # project back to size of vocabulary with bias: x @ W^T + b
        return ops.matmul(hidden, ops.transpose(self.weight, (1, 0))) + self.bias
+
+
class DistanceHead(nn.Cell):
    """Head mapping pairwise features to scalar distance predictions."""

    def __init__(self, input_dim, activation_fn):
        """
        Build the distance head.

        :param input_dim: Dimension of the pairwise input features.
        :param activation_fn: Name of the activation function.
        """
        super().__init__()
        self.dense = nn.Dense(input_dim, input_dim)
        self.activation_fn = get_activation_fn(activation_fn)
        self.layer_norm = nn.LayerNorm([input_dim])
        # PT out_proj appears to map from attention_heads -> 1
        self.out_proj = nn.Dense(input_dim, 1)

    def construct(self, x):
        """
        Apply dense -> activation -> layer norm -> scalar projection.

        :param x: Pairwise feature tensor.

        :return: Scalar prediction per pair (last dim of size 1).
        """
        hidden = self.activation_fn(self.dense(x))
        return self.out_proj(self.layer_norm(hidden))
+
+
def molecule_architecture():
    """Return the default Uni-Mol v1 hyper-parameters for molecule data."""
    return Dict(
        encoder_layers=15,
        encoder_embed_dim=512,
        encoder_ffn_embed_dim=2048,
        encoder_attention_heads=64,
        dropout=0.1,
        emb_dropout=0.1,
        attention_dropout=0.1,
        activation_dropout=0.0,
        pooler_dropout=0.2,
        max_seq_len=512,
        activation_fn="gelu",
        pooler_activation_fn="tanh",
        post_ln=False,
        backbone="transformer",
        kernel="gaussian",
        # Enable final_head_layer_norm to exist and receive weights
        delta_pair_repr_norm_loss=-1.0,
    )
+
+
def protein_architecture():
    """Return the default Uni-Mol v1 hyper-parameters for protein data."""
    return Dict(
        encoder_layers=15,
        encoder_embed_dim=512,
        encoder_ffn_embed_dim=2048,
        encoder_attention_heads=64,
        dropout=0.1,
        emb_dropout=0.1,
        attention_dropout=0.1,
        activation_dropout=0.0,
        pooler_dropout=0.2,
        max_seq_len=512,
        activation_fn="gelu",
        pooler_activation_fn="tanh",
        post_ln=False,
        backbone="transformer",
        kernel="gaussian",
        delta_pair_repr_norm_loss=-1.0,
    )
+
+
def crystal_architecture():
    """Return the default Uni-Mol v1 hyper-parameters for crystal data."""
    return Dict(
        encoder_layers=8,
        encoder_embed_dim=512,
        encoder_ffn_embed_dim=2048,
        encoder_attention_heads=64,
        dropout=0.1,
        emb_dropout=0.1,
        attention_dropout=0.1,
        activation_dropout=0.0,
        pooler_dropout=0.0,
        max_seq_len=1024,
        activation_fn="gelu",
        pooler_activation_fn="tanh",
        post_ln=False,
        backbone="transformer",
        kernel="linear",
        delta_pair_repr_norm_loss=-1.0,
    )
+
+
def oled_architecture():
    """Return the default Uni-Mol v1 hyper-parameters for OLED data."""
    return Dict(
        encoder_layers=8,
        encoder_embed_dim=512,
        encoder_ffn_embed_dim=2048,
        encoder_attention_heads=64,
        dropout=0.1,
        emb_dropout=0.1,
        attention_dropout=0.1,
        activation_dropout=0.0,
        pooler_dropout=0.0,
        max_seq_len=1024,
        activation_fn="gelu",
        pooler_activation_fn="tanh",
        post_ln=False,
        backbone="transformer",
        kernel="linear",
        delta_pair_repr_norm_loss=0.0,
    )
diff --git a/MindChem/applications/unimol/unimol_tools/models/unimolv2.py b/MindChem/applications/unimol/unimol_tools/models/unimolv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ad14ecc373c6b4c9a58ab2b5b72b65a7e620583
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/models/unimolv2.py
@@ -0,0 +1,685 @@
+# Copyright (c) DP Technology.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import absolute_import, division, print_function
+
+import os
+import pathlib
+
+import mindspore as ms
+import mindspore.ops as ops
+from mindspore import nn
+from addict import Dict
+
+from ..config import MODEL_CONFIG_V2
+from ..utils import logger, pad_1d_tokens, pad_2d, pad_coords
+from ..weights import WEIGHT_DIR, weight_download_v2
+from .transformersv2 import (
+ AtomFeature,
+ EdgeFeature,
+ MovementPredictionHead,
+ SE3InvariantKernel,
+ TransformerEncoderWithPairV2,
+)
+
# Registry of encoder backbones selectable via ``args.backbone``.
BACKBONE = {
    'transformer': TransformerEncoderWithPairV2,
}
+
+
class UniMolV2Model(nn.Cell):
    """
    MindSpore port of the Uni-Mol v2 backbone for molecular property prediction.

    The network embeds discrete atom/bond graph features together with 3D
    coordinates, runs a transformer encoder with pairwise representations,
    refines atom positions with a movement-prediction head, and finally
    classifies (or regresses) from the CLS representation. The backbone
    depth/width is selected by ``model_size`` and pretrained weights are
    downloaded automatically when missing.

    Attributes:
    - output_dim: The dimension of the output layer.
    - model_size: Backbone size tag ('84m', '164m', '310m', '570m', '1.1B').
    - remove_hs: Flag to indicate whether hydrogen atoms are removed in molecular data.
    - pretrain_path: Path to the pretrained model weights.
    - embed_tokens: Embedding layer for token embeddings.
    - encoder: Transformer encoder backbone of the model.
    - atom_feature, edge_feature: Graph feature embedders for atoms and pairs.
    - se3_invariant_kernel: Distance-based SE(3)-invariant pair bias.
    - movement_pred_head: Head predicting per-atom coordinate updates.
    - classification_head: The final classification head of the model.
    """

    def __init__(self, output_dim=2, model_size='84m', **params):
        """
        Initializes the UniMolV2Model with the specified output size and backbone size.

        :param output_dim: (int) The number of output dimensions (classes).
        :param model_size: (str) Pretrained backbone size: '84m', '164m',
            '310m', '570m' or '1.1B'.
        :param params: Additional configuration, e.g. 'remove_hs' and
            'pooler_dropout'.
        """
        super().__init__()

        self.args = molecule_architecture(model_size=model_size)
        self.output_dim = output_dim
        self.model_size = model_size
        self.remove_hs = params.get('remove_hs', False)

        # Download the pretrained checkpoint on first use.
        name = model_size
        if not os.path.exists(
            os.path.join(WEIGHT_DIR, MODEL_CONFIG_V2['weight'][name])
        ):
            weight_download_v2(MODEL_CONFIG_V2['weight'][name], WEIGHT_DIR)

        self.pretrain_path = os.path.join(WEIGHT_DIR, MODEL_CONFIG_V2['weight'][name])

        # Fixed token vocabulary: index 0 is padding, 127 is the mask token.
        self.token_num = 128
        self.padding_idx = 0
        self.mask_idx = 127
        self.embed_tokens = nn.Embedding(
            vocab_size=self.token_num, embedding_size=self.args.encoder_embed_dim, padding_idx=self.padding_idx
        )

        self.encoder = BACKBONE[self.args.backbone](
            num_encoder_layers=self.args.num_encoder_layers,
            embedding_dim=self.args.encoder_embed_dim,
            pair_dim=self.args.pair_embed_dim,
            pair_hidden_dim=self.args.pair_hidden_dim,
            ffn_embedding_dim=self.args.ffn_embedding_dim,
            num_attention_heads=self.args.num_attention_heads,
            dropout=self.args.dropout,
            attention_dropout=self.args.attention_dropout,
            activation_dropout=self.args.activation_dropout,
            activation_fn=self.args.activation_fn,
            droppath_prob=self.args.droppath_prob,
            pair_dropout=self.args.pair_dropout,
        )

        # Vocabulary sizes of the discrete graph features.
        num_atom = 512
        num_degree = 128
        num_edge = 64
        num_pair = 512
        num_spatial = 512

        K = 128  # number of kernels in the SE(3)-invariant distance bias
        n_edge_type = 1  # NOTE(review): currently unused

        self.atom_feature = AtomFeature(
            num_atom=num_atom,
            num_degree=num_degree,
            hidden_dim=self.args.encoder_embed_dim,
        )

        self.edge_feature = EdgeFeature(
            pair_dim=self.args.pair_embed_dim,
            num_edge=num_edge,
            num_spatial=num_spatial,
        )

        self.se3_invariant_kernel = SE3InvariantKernel(
            pair_dim=self.args.pair_embed_dim,
            num_pair=num_pair,
            num_kernel=K,
            std_width=self.args.gaussian_std_width,
            start=self.args.gaussian_mean_start,
            stop=self.args.gaussian_mean_stop,
        )

        self.movement_pred_head = MovementPredictionHead(
            self.args.encoder_embed_dim,
            self.args.pair_embed_dim,
            self.args.encoder_attention_heads,
        )

        # Plain dict: heads registered here are NOT tracked as child Cells.
        self.classification_heads = dict()
        self.dtype = ms.float32

        """
        # To be deprecated in the future.
        self.classification_head = ClassificationHead(
            input_dim=self.args.encoder_embed_dim,
            inner_dim=self.args.encoder_embed_dim,
            num_classes=self.output_dim,
            activation_fn=self.args.pooler_activation_fn,
            pooler_dropout=self.args.pooler_dropout,
        )
        """
        if 'pooler_dropout' in params:
            self.args.pooler_dropout = params['pooler_dropout']
        self.classification_head = LinearHead(
            input_dim=self.args.encoder_embed_dim,
            num_classes=self.output_dim,
            pooler_dropout=self.args.pooler_dropout,
        )
        self.load_pretrained_weights(path=self.pretrain_path)

    def load_pretrained_weights(self, path, strict=False):
        """
        Loads pretrained weights into the model.

        The classification head entries in the checkpoint are overridden with
        this instance's freshly initialized head, since the head's output size
        depends on the downstream task.

        :param path: (str) Path to the pretrained weight file (.ckpt).
        :param strict: Unused; kept for interface compatibility.
        """
        if path is not None:
            if path.endswith('.ckpt') and os.path.exists(path):
                logger.info("Loading MindSpore checkpoint from {}".format(path))
                from mindspore import load_checkpoint, load_param_into_net
                param_dict = load_checkpoint(path)
                param_dict["classification_head.out_proj.weight"] = self.classification_head.out_proj.weight
                param_dict["classification_head.out_proj.bias"] = self.classification_head.out_proj.bias
                load_param_into_net(self, param_dict)
            else:
                logger.warning("Pretrained weights are not in MindSpore .ckpt format ({}). Skipping load.".format(path))

    @classmethod
    def build_model(cls, args):
        """
        Class method to build a new instance of the UniMolV2Model.

        :param args: Arguments for model configuration.
        :return: An instance of UniMolV2Model.
        """
        # NOTE(review): ``args`` is passed positionally and lands in
        # ``output_dim`` — confirm callers intend this.
        return cls(args)

    # Expected batch keys:
    # 'atom_feat', 'atom_mask', 'edge_feat', 'shortest_path', 'degree', 'pair_type', 'attn_bias', 'src_tokens'
    def construct(
        self,
        atom_feat,
        atom_mask,
        edge_feat,
        shortest_path,
        degree,
        pair_type,
        attn_bias,
        src_tokens,
        src_coord,
        **kwargs
    ):
        """
        Forward pass over a batch of featurized molecules.

        :param atom_feat: Discrete atom features, leading dims (n_mol, n_atom).
        :param atom_mask: Per-atom validity mask (1 = real atom, 0 = padding).
        :param edge_feat: Discrete bond/edge features.
        :param shortest_path: Shortest-path (spatial) encoding between atoms.
        :param degree: Atom degree feature.
        :param pair_type: Pair-type indices for the SE(3)-invariant kernel.
        :param attn_bias: Initial attention bias/mask including the CLS slot.
        :param src_tokens: Atom token ids (0/1/2 treated as special tokens).
        :param src_coord: 3D coordinates of atoms.
        :param return_repr: If True, return representations instead of logits.
        :param return_atomic_reprs: If True (with return_repr), also return
            per-atom representations, token ids and coordinates.

        :return: Output logits, or a dict of the requested representations.
        """

        pos = src_coord

        n_mol, n_atom = atom_feat.shape[:2]
        token_feat = self.embed_tokens(src_tokens)
        x = self.atom_feature({'atom_feat': atom_feat, 'degree': degree}, token_feat)

        dtype = self.dtype

        x = x.astype(dtype)

        # ``attn_bias`` input doubles as the additive attention mask; a fresh
        # zero tensor is expanded into the learned pair-feature bias below.
        attn_mask = attn_bias
        attn_bias = ops.zeros_like(attn_mask)
        attn_mask = ops.tile(attn_mask.unsqueeze(1), (1, self.args.encoder_attention_heads, 1, 1))
        attn_bias = ops.tile(attn_bias.unsqueeze(-1), (1, 1, 1, self.args.pair_embed_dim))
        attn_bias = self.edge_feature(
            {'shortest_path': shortest_path, 'edge_feat': edge_feat}, attn_bias
        )
        attn_mask = attn_mask.astype(self.dtype)

        # Prepend an always-valid CLS slot to the atom mask.
        atom_mask_cls = ops.concat([
            ops.ones((n_mol, 1), dtype=atom_mask.dtype),
            atom_mask,
        ], axis=1).astype(self.dtype)

        pair_mask = atom_mask_cls.unsqueeze(-1) * atom_mask_cls.unsqueeze(-2)

        def one_block(x, pos, return_x=False):
            # One round of encoding plus coordinate refinement.
            delta_pos = pos.unsqueeze(1) - pos.unsqueeze(2)
            dist = ops.sqrt(ops.reduce_sum(delta_pos * delta_pos, -1))
            attn_bias_3d = self.se3_invariant_kernel(dist, pair_type)
            # Align shape to [B, N, N, pair_dim]
            if len(attn_bias_3d.shape) == 5 and attn_bias_3d.shape[-2] == 1:
                attn_bias_3d = ops.squeeze(attn_bias_3d, axis=-2)
            # NOTE(review): ``new_attn_bias`` aliases ``attn_bias``; the
            # in-place update mutates the outer tensor. Harmless while
            # one_block is called exactly once — confirm if that changes.
            new_attn_bias = attn_bias
            new_attn_bias[:, 1:, 1:, :] = new_attn_bias[:, 1:, 1:, :] + attn_bias_3d
            new_attn_bias = new_attn_bias.astype(dtype)
            x, pair = self.encoder(
                x,
                new_attn_bias,
                atom_mask=atom_mask_cls,
                pair_mask=pair_mask,
                attn_mask=attn_mask,
            )
            # Coordinate update is computed on atom slots only (CLS excluded).
            node_output = self.movement_pred_head(
                x[:, 1:, :],
                pair[:, 1:, 1:, :],
                attn_mask[:, :, 1:, 1:],
                delta_pos,
            )
            if return_x:
                return x, pair, pos + node_output
            else:
                return pos + node_output

        x, pair, pos = one_block(x, pos, return_x=True)
        cls_repr = x[:, 0, :]
        all_repr = x[:, :, :]

        if return_repr:
            filtered_tensors = []
            filtered_coords = []

            # Drop special tokens (padding/BOS/EOS ids 0, 1, 2) per molecule.
            for tokens, coord in zip(src_tokens, src_coord):
                mask = (tokens != 0) & (tokens != 1) & (tokens != 2)
                filtered_tensor = tokens[mask]
                filtered_coord = coord[mask]
                filtered_tensors.append(filtered_tensor)
                filtered_coords.append(filtered_coord)

            lengths = [
                len(filtered_tensor) for filtered_tensor in filtered_tensors
            ]  # Compute the lengths of the filtered tensors
            if return_atomic_reprs:
                cls_atomic_reprs = []
                atomic_symbols = []
                for i in range(len(all_repr)):
                    atomic_reprs = x[i, 1 : lengths[i] + 1, :]
                    # NOTE(review): unlike the v1 model these are raw token
                    # ids, not element symbol strings — confirm consumers.
                    atomic_symbol = filtered_tensors[i]
                    atomic_symbols.append(atomic_symbol)
                    cls_atomic_reprs.append(atomic_reprs)
                return {
                    'cls_repr': cls_repr,
                    'atomic_symbol': atomic_symbols,
                    'atomic_coords': filtered_coords,
                    'atomic_reprs': cls_atomic_reprs,
                }
            else:
                return {'cls_repr': cls_repr}

        logits = self.classification_head(cls_repr)
        return logits

    def register_classification_head(
        self, name, num_classes=None, inner_dim=None, **kwargs
    ):
        """Register a classification head."""
        if name in self.classification_heads:
            # NOTE(review): previous head dimensions are not actually looked
            # up here, so the warning always fires on re-registration.
            prev_num_classes = None
            prev_inner_dim = None
            if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
                logger.warning(
                    're-registering head "{}" with num_classes {} (prev: {}) '
                    "and inner_dim {} (prev: {})".format(
                        name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
                    )
                )
        self.classification_heads[name] = ClassificationHead(
            input_dim=self.args.encoder_embed_dim,
            inner_dim=inner_dim or self.args.encoder_embed_dim,
            num_classes=num_classes,
            activation_fn=self.args.pooler_activation_fn,
            pooler_dropout=self.args.pooler_dropout,
        )

    def set_num_updates(self, num_updates):
        """State from trainer to pass along to model at every update."""
        self._num_updates = num_updates

    def get_num_updates(self):
        # NOTE(review): raises AttributeError if called before
        # set_num_updates — confirm trainers always set it first.
        return self._num_updates

    def batch_collate_fn(self, samples):
        """
        Custom collate function for batch processing non-MOF data.

        Pads every per-sample feature array to the batch-wide maximum atom
        count using ``self.padding_idx`` as the fill value.

        :param samples: A list of ``(feature_dict, label)`` tuples.

        :return: A tuple containing a batch dictionary of padded tensors and
            a label tensor (``None`` when labels cannot be tensorized).
        """
        import numpy as np

        def pad_1d(arrs, pad_idx):
            # Right-pad 1-D arrays to the max length in the batch.
            max_len = max(len(a) for a in arrs)
            out = np.full((len(arrs), max_len), pad_idx, dtype=np.int64)
            for i, a in enumerate(arrs):
                out[i, : len(a)] = a
            return ms.Tensor(out, ms.int64)

        def pad_2d_dim(arrs, pad_idx, dim=1, dtype='int64'):
            # dim == 1: pad square (L, L) matrices; otherwise pad (L, L, dim)
            # arrays along the first two axes.
            max_len = max(a.shape[0] for a in arrs)
            if dim == 1:
                out = np.full((len(arrs), max_len, max_len), pad_idx, dtype=dtype)
                for i, a in enumerate(arrs):
                    L = a.shape[0]
                    out[i, :L, :L] = a
            else:
                out = np.full((len(arrs), max_len, max_len, dim), pad_idx, dtype=dtype)
                for i, a in enumerate(arrs):
                    L = a.shape[0]
                    out[i, :L, :L, :] = a
            return ms.Tensor(out)

        def pad_coords_np(arrs, pad_val, dim=3):
            # Pad (L, dim) arrays into a (B, max_L, dim) float tensor.
            max_len = max(a.shape[0] for a in arrs)
            out = np.full((len(arrs), max_len, dim), pad_val, dtype=np.float32)
            for i, a in enumerate(arrs):
                L = a.shape[0]
                out[i, :L, :] = a
            return ms.Tensor(out, ms.float32)

        batch = {}
        # NOTE(review): a key not matched below silently re-uses ``v`` from
        # the previous iteration (NameError if it comes first) — the key set
        # produced upstream must exactly match the branches here.
        for k in samples[0][0].keys():
            if k == 'atom_feat':
                v = pad_coords_np([samples[i][0][k] for i in range(len(samples))], self.padding_idx, dim=8)
            elif k == 'atom_mask':
                v = pad_1d([samples[i][0][k] for i in range(len(samples))], self.padding_idx)
            elif k == 'edge_feat':
                v = pad_2d_dim([samples[i][0][k] for i in range(len(samples))], self.padding_idx, dim=3, dtype='int64')
            elif k == 'shortest_path':
                v = pad_2d_dim([samples[i][0][k] for i in range(len(samples))], self.padding_idx, dtype='int64')
            elif k == 'degree':
                v = pad_1d([samples[i][0][k] for i in range(len(samples))], self.padding_idx)
            elif k == 'pair_type':
                v = pad_2d_dim([samples[i][0][k] for i in range(len(samples))], self.padding_idx, dim=2, dtype='int64')
            elif k == 'attn_bias':
                v = pad_2d_dim([samples[i][0][k] for i in range(len(samples))], self.padding_idx, dtype='float32')
            elif k == 'src_tokens':
                v = pad_1d([samples[i][0][k] for i in range(len(samples))], self.padding_idx)
            elif k == 'src_coord':
                v = pad_coords_np([samples[i][0][k] for i in range(len(samples))], self.padding_idx)
            batch[k] = v
        try:
            label = ms.Tensor([s[1] for s in samples])
        except Exception:
            # Labels may be absent or non-numeric (e.g. at inference time).
            label = None
        return batch, label
+
+
class LinearHead(nn.Cell):
    """Linear output head: dropout followed by a single dense projection."""

    def __init__(
        self,
        input_dim,
        num_classes,
        pooler_dropout,
    ):
        """
        Construct the linear head.

        :param input_dim: Dimension of input features.
        :param num_classes: Number of classes for output.
        :param pooler_dropout: Dropout probability applied before projection.
        """
        super().__init__()
        self.out_proj = nn.Dense(input_dim, num_classes)
        self.dropout = nn.Dropout(p=pooler_dropout)

    def construct(self, features, **kwargs):
        """
        Forward pass for the Linear head.

        :param features: Input features.

        :return: Projected logits.
        """
        dropped = self.dropout(features)
        return self.out_proj(dropped)
+
+
class ClassificationHead(nn.Cell):
    """Head for sentence-level classification tasks."""

    def __init__(
        self,
        input_dim,
        inner_dim,
        num_classes,
        activation_fn,
        pooler_dropout,
    ):
        """
        Construct the classification head.

        :param input_dim: Dimension of input features.
        :param inner_dim: Dimension of the hidden layer.
        :param num_classes: Number of classes for classification.
        :param activation_fn: Name of the activation function.
        :param pooler_dropout: Dropout rate applied before each dense layer.
        """
        super().__init__()
        self.dense = nn.Dense(input_dim, inner_dim)
        self.activation_fn = get_activation_fn(activation_fn)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Dense(inner_dim, num_classes)

    def construct(self, features, **kwargs):
        """
        Apply dropout -> dense -> activation -> dropout -> projection.

        :param features: Input features for classification.

        :return: Classification logits.
        """
        x = self.dropout(features)
        x = self.activation_fn(self.dense(x))
        return self.out_proj(self.dropout(x))
+
+
class NonLinearHead(nn.Cell):
    """
    Small two-layer MLP: linear -> activation -> linear.

    Attributes:
    - linear1: The first linear layer.
    - linear2: The second linear layer producing the output.
    - activation_fn: The nonlinear activation between the layers.
    """

    def __init__(
        self,
        input_dim,
        out_dim,
        activation_fn,
        hidden=None,
    ):
        """
        Construct the NonLinearHead.

        :param input_dim: Dimension of the input features.
        :param out_dim: Dimension of the output.
        :param activation_fn: The activation function to use.
        :param hidden: Hidden width; defaults to ``input_dim`` when falsy.
        """
        super().__init__()
        width = hidden if hidden else input_dim
        self.linear1 = nn.Dense(input_dim, width)
        self.linear2 = nn.Dense(width, out_dim)
        self.activation_fn = get_activation_fn(activation_fn)

    def construct(self, x):
        """
        Forward pass of the NonLinearHead.

        :param x: Input tensor.

        :return: Output tensor of width ``out_dim``.
        """
        hidden = self.activation_fn(self.linear1(x))
        return self.linear2(hidden)
+
+
def gaussian(x, mean, std):
    """
    Apply a Gaussian (normal) density function elementwise.

    :param x: The input tensor.
    :param mean: The mean for the Gaussian function.
    :param std: The standard deviation for the Gaussian function.

    :return: The output tensor after applying the Gaussian function.
    """
    # 3.14159 (rather than math.pi) matches the constant used when the
    # pretrained weights were produced.
    norm = ((2 * 3.14159) ** 0.5) * std
    return ops.exp(-0.5 * ((x - mean) / std) ** 2) / norm
+
+
def get_activation_fn(activation):
    """Returns the activation function corresponding to `activation`"""

    factory = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "tanh": nn.Tanh,
        "linear": nn.Identity,
    }.get(activation)
    if factory is None:
        raise RuntimeError("--activation-fn {} not supported".format(activation))
    return factory()
+
+
class GaussianLayer(nn.Cell):
    """
    A neural network module implementing a Gaussian layer, useful in graph neural networks.

    Attributes:
    - K: Number of Gaussian kernels.
    - means, stds: Embeddings for the means and standard deviations of the Gaussian kernels.
    - mul, bias: Embeddings for scaling and bias parameters.
    """

    def __init__(self, K=128, edge_types=1024):
        """
        Initializes the GaussianLayer module.

        :param K: Number of Gaussian kernels.
        :param edge_types: Number of different edge types to consider.
        """
        super().__init__()
        # numpy is not imported at this module's top level; keep it local.
        import numpy as np

        self.K = K
        self.means = nn.Embedding(1, K)
        self.stds = nn.Embedding(1, K)
        self.mul = nn.Embedding(edge_types, 1)
        self.bias = nn.Embedding(edge_types, 1)
        # Bug fix: the previous code used torch-style APIs (``nn.init.uniform_``
        # and ``Embedding.weight``) that do not exist in MindSpore. MindSpore
        # embeddings expose ``embedding_table``; initialize via NumPy exactly
        # as the Uni-Mol v1 port does.
        self.means.embedding_table.set_data(
            ms.Tensor(
                np.random.uniform(0.0, 3.0, size=self.means.embedding_table.asnumpy().shape),
                dtype=self.means.embedding_table.dtype,
            )
        )
        self.stds.embedding_table.set_data(
            ms.Tensor(
                np.random.uniform(0.0, 3.0, size=self.stds.embedding_table.asnumpy().shape),
                dtype=self.stds.embedding_table.dtype,
            )
        )
        self.bias.embedding_table.set_data(
            ms.Tensor(
                np.zeros(self.bias.embedding_table.asnumpy().shape, dtype=np.float32),
                dtype=self.bias.embedding_table.dtype,
            )
        )
        self.mul.embedding_table.set_data(
            ms.Tensor(
                np.ones(self.mul.embedding_table.asnumpy().shape, dtype=np.float32),
                dtype=self.mul.embedding_table.dtype,
            )
        )

    def construct(self, x, edge_type):
        """
        Forward pass of the GaussianLayer.

        :param x: Input tensor representing distances or other features.
        :param edge_type: Tensor indicating types of edges in the graph.

        :return: Tensor transformed by the Gaussian layer.
        """
        mul = self.mul(edge_type).astype(x.dtype)
        bias = self.bias(edge_type).astype(x.dtype)
        x = mul * ops.expand_dims(x, -1) + bias
        # Bug fix: ``ops.tile`` does not accept -1 repeats (that is torch
        # ``expand`` semantics); keep the leading axes and repeat only the
        # new kernel axis K times.
        x = ops.tile(x, (1, 1, 1, self.K))
        mean = ops.reshape(self.means.embedding_table.astype(ms.float32), (-1,))
        std = ops.reshape(self.stds.embedding_table.astype(ms.float32), (-1,)).abs() + 1e-5
        return gaussian(x.astype(ms.float32), mean, std).astype(self.means.embedding_table.dtype)
+
+
class NumericalEmbed(nn.Cell):
    """
    Numerical embedding module, typically used for embedding edge features in graph neural networks.

    Attributes:
    - K: Output dimension for embeddings.
    - mul, bias, w_edge: Embeddings for transformation parameters.
    - proj: Projection layer to transform inputs.
    - ln: Layer normalization.
    """

    def __init__(self, K=128, edge_types=1024, activation_fn='gelu'):
        """
        Initializes the NumericalEmbed module.

        :param K: Output embedding dimension.
        :param edge_types: Number of different edge types to consider.
        :param activation_fn: Activation used inside the projection head.
        """
        super().__init__()
        self.K = K
        self.mul = nn.Embedding(edge_types, 1)
        self.bias = nn.Embedding(edge_types, 1)
        self.w_edge = nn.Embedding(edge_types, K)

        self.proj = NonLinearHead(1, K, activation_fn, hidden=2 * K)
        self.ln = nn.LayerNorm([K])

        # Initializers are omitted in MindSpore port

    def construct(self, x, edge_type):  # edge_type, atoms
        """
        Forward pass of the NumericalEmbed module.

        :param x: Numerical edge features (e.g. distances).
        :param edge_type: Tensor indicating types of edges in the graph.

        :return: Combined edge embedding tensor.
        """
        mul = self.mul(edge_type).astype(x.dtype)
        bias = self.bias(edge_type).astype(x.dtype)
        w_edge = self.w_edge(edge_type).astype(x.dtype)
        edge_emb = w_edge * ops.sigmoid(mul * ops.expand_dims(x, -1) + bias)

        # Bug fix: MindSpore nn.Embedding has no ``.weight`` attribute; its
        # parameter is ``embedding_table`` (as used in the v1 port).
        edge_proj = ops.expand_dims(x, -1).astype(self.mul.embedding_table.dtype)
        edge_proj = self.proj(edge_proj)
        edge_proj = self.ln(edge_proj)

        h = edge_proj + edge_emb
        h = h.astype(self.mul.embedding_table.dtype)
        return h
+
+
def molecule_architecture(model_size='84m'):
    """
    Return Uni-Mol v2 hyper-parameters for the requested backbone size.

    :param model_size: One of '84m', '164m', '310m', '570m', '1.1B'.
    :return: An addict ``Dict`` of architecture hyper-parameters.
    :raises ValueError: if ``model_size`` is not a known size tag.
    """
    # (num_encoder_layers, embed/ffn width, attention heads) per size tag.
    size_table = {
        '84m': (12, 768, 48),
        '164m': (24, 768, 48),
        '310m': (32, 1024, 64),
        '570m': (32, 1536, 96),
        '1.1B': (64, 1536, 96),
    }
    if model_size not in size_table:
        raise ValueError('Current not support data type: {}'.format(model_size))
    layers, width, heads = size_table[model_size]

    args = Dict()
    args.num_encoder_layers = layers
    args.encoder_embed_dim = width
    args.num_attention_heads = heads
    args.ffn_embedding_dim = width
    args.encoder_attention_heads = heads
    args.pair_embed_dim = 512
    args.pair_hidden_dim = 64
    args.dropout = 0.1
    args.attention_dropout = 0.1
    args.activation_dropout = 0.0
    args.activation_fn = "gelu"
    args.droppath_prob = 0.0
    args.pair_dropout = 0.25
    args.backbone = "transformer"
    args.gaussian_std_width = 1.0
    args.gaussian_mean_start = 0.0
    args.gaussian_mean_stop = 9.0
    args.pooler_dropout = 0.0
    args.pooler_activation_fn = "tanh"
    return args
diff --git a/MindChem/applications/unimol/unimol_tools/predict.py b/MindChem/applications/unimol/unimol_tools/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5223dd411275b96119b01ec1d879390a6a1f955
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/predict.py
@@ -0,0 +1,136 @@
+# Copyright (c) DP Technology.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import absolute_import, division, print_function
+
+import os
+import json
+
+import joblib
+import numpy as np
+
+from .data import DataHub
+from .models import NNModel
+from .tasks import Trainer
+from .utils import YamlHandler, logger
+
+
+class MolPredict(object):
+ """A :class:`MolPredict` class is responsible for interface of predicting process of molecular data."""
+
+ def __init__(self, load_model=None):
+ """
+ Initialize a :class:`MolPredict` class.
+
+ :param load_model: str, default=None, path of model to load.
+ """
+ if not load_model:
+ raise ValueError("load_model is empty")
+ self.load_model = load_model
+ config_path = os.path.join(load_model, 'config.yaml')
+ self.config = YamlHandler(config_path).read_yaml()
+ self.config.target_cols = self.config.target_cols.split(',')
+ self.task = self.config.task
+ self.target_cols = self.config.target_cols
+
+ def predict(self, data, save_path=None, metrics='none'):
+ """
+ Predict molecular data.
+
+ :param data: str or pandas.DataFrame or dict of atoms and coordinates, input data for prediction. \
+ - str: path of csv file.
+ - pandas.DataFrame: dataframe of data.
+ - dict: dict of atoms and coordinates, e.g. {'atoms': ['C', 'C', 'C'], 'coordinates': [[0, 0, 0], [0, 0, 1], [0, 0, 2]]}
+ :param save_path: str, default=None, path to save predict result.
+ :param metrics: str, default='none', metrics to evaluate model performance.
+
+ currently support:
+
+ - classification: auc, auprc, log_loss, acc, f1_score, mcc, precision, recall, cohen_kappa.
+
+ - regression: mae, pearsonr, spearmanr, mse, r2.
+
+ - multiclass: log_loss, acc.
+
+ - multilabel_classification: auc, auprc, log_loss, acc, mcc.
+
+ - multilabel_regression: mae, mse, r2.
+
+ :return y_pred: numpy.ndarray, predict result.
+ """
+ self.save_path = save_path
+ self.config['sdf_save_path'] = save_path
+ if not metrics or metrics != 'none':
+ self.config.metrics = metrics
+ ## load test data
+ self.datahub = DataHub(
+ data=data, is_train=False, save_path=self.load_model, **self.config
+ )
+ self.config.use_ddp = False
+ self.trainer = Trainer(save_path=self.load_model, **self.config)
+ self.model = NNModel(self.datahub.data, self.trainer, **self.config)
+ self.model.evaluate(self.trainer, self.load_model)
+
+ y_pred = self.model.cv['test_pred']
+ scalar = self.datahub.data['target_scaler']
+ if scalar is not None:
+ y_pred = scalar.inverse_transform(y_pred)
+
+ df = self.datahub.data['raw_data'].copy()
+ predict_cols = ['predict_' + col for col in self.target_cols]
+ if self.task == 'multiclass' and self.config.multiclass_cnt is not None:
+ prob_cols = ['prob_' + str(i) for i in range(self.config.multiclass_cnt)]
+ df[prob_cols] = y_pred
+ df[predict_cols] = np.argmax(y_pred, axis=1).reshape(-1, 1)
+ elif self.task in ['classification', 'multilabel_classification']:
+ threshold = joblib.load(
+ open(os.path.join(self.load_model, 'threshold.dat'), "rb")
+ )
+ prob_cols = ['prob_' + col for col in self.target_cols]
+ df[prob_cols] = y_pred
+ df[predict_cols] = (y_pred > threshold).astype(int)
+ else:
+ prob_cols = predict_cols
+ df[predict_cols] = y_pred
+ if self.save_path:
+ os.makedirs(self.save_path, exist_ok=True)
+ if not (df[self.target_cols] == -1.0).all().all():
+ metrics = self.trainer.metrics.cal_metric(
+ df[self.target_cols].values, df[prob_cols].values
+ )
+ logger.info("final predict metrics score: \n{}".format(metrics))
+ if self.save_path:
+ joblib.dump(metrics, os.path.join(self.save_path, 'test_metric.result'))
+ with open(os.path.join(self.save_path, 'test_metric.json'), 'w') as f:
+ json.dump(metrics, f)
+ else:
+ df.drop(self.target_cols, axis=1, inplace=True)
+ if self.save_path:
+ prefix = (
+ data.split('/')[-1].split('.')[0] if isinstance(data, str) else 'test'
+ )
+ self.save_predict(df, self.save_path, prefix)
+ logger.info("pipeline finish!")
+
+ return y_pred
+
+ def save_predict(self, data, dir, prefix):
+ """
+ Save predict result to csv file.
+
+ :param data: pandas.DataFrame, predict result.
+ :param dir: str, directory to save predict result.
+ :param prefix: str, prefix of predict result file name.
+ """
+ run_id = 0
+ if not os.path.exists(dir):
+ os.makedirs(dir)
+ else:
+ folders = [x for x in os.listdir(dir)]
+ while prefix + f'.predict.{run_id}' + '.csv' in folders:
+ run_id += 1
+ name = prefix + f'.predict.{run_id}' + '.csv'
+ path = os.path.join(dir, name)
+ data.to_csv(path)
+ logger.info("save predict result to {}".format(path))
diff --git a/MindChem/applications/unimol/unimol_tools/predictor.py b/MindChem/applications/unimol/unimol_tools/predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..6189038df913cadbf7eec159906e1c2857ce36a1
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/predictor.py
@@ -0,0 +1,148 @@
+# Copyright (c) DP Technology.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import absolute_import, division, print_function
+
+import numpy as np
+import pandas as pd
+import mindspore.dataset as ds
+import mindspore as ms
+
+from .data import DataHub
+from .models import UniMolModel, UniMolV2Model
+from .tasks import Trainer
+
+
+class MolDataset:
+ """
+ A :class:`MolDataset` class is responsible for interface of molecular dataset.
+ """
+
+ def __init__(self, data, label=None):
+ self.data = data
+ self.label = label if label is not None else np.zeros((len(data), 1))
+
+ def __getitem__(self, idx):
+ return self.data[idx], self.label[idx]
+
+ def __len__(self):
+ return len(self.data)
+
+
+class UniMolRepr(object):
+    """
+    A :class:`UniMolRepr` class is responsible for interface of molecular representation by unimol
+    """
+
+    def __init__(
+        self,
+        data_type='molecule',
+        batch_size=32,
+        remove_hs=False,
+        model_name='unimolv1',
+        model_size='84m',
+        use_cuda=True,
+        use_ddp=False,
+        use_gpu='all',
+        save_path=None,
+        **kwargs,
+    ):
+        """
+        Initialize a :class:`UniMolRepr` class.
+
+        :param data_type: str, default='molecule', currently support molecule, oled.
+        :param batch_size: int, default=32, batch size for training.
+        :param remove_hs: bool, default=False, whether to remove hydrogens in molecular.
+        :param model_name: str, default='unimolv1', currently support unimolv1, unimolv2.
+        :param model_size: str, default='84m', model size of unimolv2. Available: 84m, 164m, 310m, 570m, 1.1B.
+        :param use_cuda: bool, default=True, whether to target the Ascend device (falls back to CPU).
+        :param use_ddp: bool, default=False, whether to use distributed data parallel.
+        :param use_gpu: str, default='all', which gpu to use.
+        :param save_path: str, default=None, path to save outputs.
+        """
+        # Set MindSpore device context with safe fallback to CPU
+        device_target = 'Ascend' if use_cuda else 'CPU'
+        gpu_enabled = False
+        try:
+            ms.context.set_context(device_target=device_target)
+            gpu_enabled = device_target == 'Ascend'
+        except Exception:
+            # Requested device unavailable: silently degrade to CPU execution.
+            ms.context.set_context(device_target='CPU')
+            gpu_enabled = False
+        if model_name == 'unimolv1':
+            self.model = UniMolModel(
+                output_dim=1, data_type=data_type, remove_hs=remove_hs
+            )
+        elif model_name == 'unimolv2':
+            self.model = UniMolV2Model(output_dim=1, model_size=model_size)
+        else:
+            raise ValueError('Unknown model name: {}'.format(model_name))
+        # Representation extraction is inference-only.
+        self.model.set_train(False)
+        # 'use_cuda' stored below reflects the device actually obtained, not
+        # the one requested.
+        self.params = {
+            'data_type': data_type,
+            'batch_size': batch_size,
+            'remove_hs': remove_hs,
+            'model_name': model_name,
+            'model_size': model_size,
+            'use_cuda': gpu_enabled,
+            'use_ddp': use_ddp,
+            'use_gpu': use_gpu,
+            'save_path': save_path,
+        }
+
+    def get_repr(self, data=None, return_atomic_reprs=False):
+        """
+        Get molecular representation by unimol.
+
+        :param data: str, dict or list, default=None, input data for unimol.
+
+        - str: smiles string or path to a smiles file.
+
+        - dict: custom conformers, should take atoms and coordinates as input.
+
+        - list: list of smiles strings.
+
+        :param return_atomic_reprs: bool, default=False, whether to return atomic representations.
+
+        :return: dict of molecular representation.
+        """
+
+        if isinstance(data, str):
+            if data.endswith('.sdf'):
+                # Datahub will process sdf file.
+                pass
+            elif data.endswith('.csv'):
+                # read csv file.
+                data = pd.read_csv(data)
+                assert 'SMILES' in data.columns
+                data = data['SMILES'].values
+            else:
+                # single smiles string.
+                data = [data]
+            data = np.array(data)
+        elif isinstance(data, dict):
+            # custom conformers, should take atoms and coordinates as input.
+            assert 'atoms' in data and 'coordinates' in data
+        elif isinstance(data, list):
+            # list of smiles strings.
+            assert isinstance(data[-1], str)
+            data = np.array(data)
+        else:
+            raise ValueError('Unknown data type: {}'.format(type(data)))
+
+        datahub = DataHub(
+            data=data,
+            task='repr',
+            is_train=False,
+            **self.params,
+        )
+        # NOTE(review): Trainer.inference indexes and len()s its dataset
+        # argument directly — confirm GeneratorDataset supports that here.
+        dataset = ds.GeneratorDataset(source=MolDataset(datahub.data['unimol_input']), column_names=['data', 'label'])
+        self.trainer = Trainer(task='repr', **self.params)
+        repr_output = self.trainer.inference(
+            self.model,
+            return_repr=True,
+            return_atomic_reprs=return_atomic_reprs,
+            dataset=dataset,
+        )
+        return repr_output
diff --git a/MindChem/applications/unimol/unimol_tools/tasks/__init__.py b/MindChem/applications/unimol/unimol_tools/tasks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..260e4c8d68f6bcad54c8610cc8abea45583480c0
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/tasks/__init__.py
@@ -0,0 +1 @@
+from .trainer import Trainer
diff --git a/MindChem/applications/unimol/unimol_tools/tasks/trainer.py b/MindChem/applications/unimol/unimol_tools/tasks/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..50a6b39d56e7c6f88e3680d24f9de9d8a568010b
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/tasks/trainer.py
@@ -0,0 +1,699 @@
+# Copyright (c) DP Technology.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import absolute_import, division, print_function
+
+import os
+import time
+from functools import partial
+
+import numpy as np
+import mindspore as ms
+import mindspore.ops as ops
+from mindspore import nn, context
+from tqdm import tqdm
+
+from ..utils import Metrics, logger
+
+
+class Trainer(object):
+ """A :class:`Trainer` class is responsible for initializing the model, and managing its training, validation, and testing phases."""
+
+ def __init__(self, save_path=None, **params):
+ """
+ :param save_path: Path for saving the training outputs. Defaults to None.
+ :param params: Additional parameters for training.
+ """
+ self.save_path = save_path
+ self.task = params.get('task', None)
+
+ if self.task != 'repr':
+ self.metrics_str = params['metrics']
+ self.metrics = Metrics(self.task, self.metrics_str)
+ self._init_trainer(**params)
+
+ def _init_trainer(self, **params):
+ """
+ Initializing the trainer class to train model.
+
+ :param params: Containing training arguments.
+ """
+ ### init common params ###
+ self.split_method = params.get('split_method', '5fold_random')
+ self.split_seed = params.get('split_seed', 42)
+ self.seed = params.get('seed', 42)
+ self.set_seed(self.seed)
+ self.logger_level = int(params.get('logger_level', 1))
+ ### init NN trainer params ###
+ self.learning_rate = float(params.get('learning_rate', 1e-4))
+ self.batch_size = params.get('batch_size', 32)
+ self.max_epochs = params.get('epochs', 50)
+ self.warmup_ratio = params.get('warmup_ratio', 0.1)
+ self.patience = params.get('patience', 10)
+ self.max_norm = params.get('max_norm', 1.0)
+ self._init_dist(params)
+
+ def _init_dist(self, params):
+ device_target = 'Ascend' if params.get('use_cuda', True) else 'CPU'
+ try:
+ ms.set_device("Ascend", 0)
+ self.device = device_target
+ except Exception:
+ ms.set_device("CPU", 0)
+ self.device = 'CPU'
+ self.ddp = False
+ return
+
+ # 还没有实现分布式训练
+ def init_ddp(self, local_rank):
+ raise NotImplementedError("DDP is not supported in MindSpore trainer conversion.")
+
+ def decorate_batch(self, batch, feature_name=None):
+ """
+ Prepares a batch of data for processing by the model. This method is a wrapper that
+ delegates to a specific batch decoration method based on the data type.
+
+ :param batch: The batch of data to be processed.
+ :param feature_name: (str, optional) Name of the feature used in batch decoration. Defaults to None.
+
+ :return: The decorated batch ready for processing by the model.
+ """
+ return self.decorate_ms_batch(batch)
+
+ def decorate_graph_batch(self, batch):
+ """
+ Prepares a graph-based batch of data for processing by the model. Specifically handles
+ graph-based data structures.
+
+ :param batch: The batch of graph-based data to be processed.
+
+ :return: A tuple of (net_input, net_target) for model processing.
+ """
+ net_input, net_target = {'net_input': batch}, batch.y
+ if self.task in ['classification', 'multiclass', 'multilabel_classification']:
+ net_target = net_target.astype(ms.int32)
+ else:
+ net_target = net_target.astype(ms.float32)
+ return net_input, net_target
+
+ def decorate_ms_batch(self, batch):
+ """
+ Prepares a standard batch of data for processing by the model. Handles tensor-based data structures.
+
+ :param batch: The batch of tensor-based data to be processed.
+
+ :return: A tuple of (net_input, net_target) for model processing.
+ """
+ net_input, net_target = batch
+ if isinstance(net_input, dict):
+ net_input, net_target = net_input, net_target
+ else:
+ net_input, net_target = {'net_input': net_input}, net_target
+ if self.task == 'repr':
+ net_target = None
+ elif self.task in ['classification', 'multiclass', 'multilabel_classification']:
+ net_target = net_target.astype(ms.int32)
+ else:
+ net_target = net_target.astype(ms.float32)
+ return net_input, net_target
+
+ def fit_predict(
+ self,
+ model,
+ train_dataset,
+ valid_dataset,
+ loss_func,
+ activation_fn,
+ dump_dir,
+ fold,
+ target_scaler,
+ feature_name=None,
+ ):
+ """
+ Trains the model on the given dataset.
+
+ :param local_rank: (int) The local rank of the current process.
+ :param args: Additional arguments for training.
+ """
+ return self.fit_predict_wo_ddp(
+ model,
+ train_dataset,
+ valid_dataset,
+ loss_func,
+ activation_fn,
+ dump_dir,
+ fold,
+ target_scaler,
+ feature_name,
+ )
+
+ def fit_predict_wo_ddp(
+ self,
+ model,
+ train_dataset,
+ valid_dataset,
+ loss_func,
+ activation_fn,
+ dump_dir,
+ fold,
+ target_scaler,
+ feature_name=None,
+ ):
+ """
+ Trains the model on the given training dataset and evaluates it on the validation dataset.
+
+ :param model: The model to be trained and evaluated.
+ :param train_dataset: Dataset used for training the model.
+ :param valid_dataset: Dataset used for validating the model.
+ :param loss_func: The loss function used during training.
+ :param activation_fn: The activation function applied to the model's output.
+ :param dump_dir: Directory where the best model state is saved.
+ :param fold: The fold number in a cross-validation setting.
+ :param target_scaler: Scaler used for scaling the target variable.
+ :param feature_name: (optional) Name of the feature used in data loading. Defaults to None.
+
+ :return: Predictions made by the model on the validation dataset.
+ """
+ train_dataloader = NNDataLoader(
+ feature_name=feature_name,
+ dataset=train_dataset,
+ batch_size=self.batch_size,
+ shuffle=True,
+ collate_fn=model.batch_collate_fn,
+ distributed=False,
+ drop_last=True,
+ )
+ optimizer = nn.Adam(params=model.trainable_params(), learning_rate=self.learning_rate)
+ loss_net = _WithLossNet(model, loss_func)
+ train_step = nn.TrainOneStepCell(loss_net, optimizer)
+ train_step.set_train()
+ early_stopper = EarlyStopper(
+ self.patience, dump_dir, fold, self.metrics, self.metrics_str
+ )
+
+ for epoch in range(self.max_epochs):
+ total_trn_loss = self._train_one_epoch(
+ model,
+ train_dataloader,
+ train_step,
+ loss_func,
+ feature_name,
+ epoch,
+ )
+
+ y_preds, val_loss, metric_score = self.predict(
+ model,
+ valid_dataset,
+ loss_func,
+ activation_fn,
+ dump_dir,
+ fold,
+ target_scaler,
+ epoch,
+ load_model=False,
+ feature_name=feature_name,
+ )
+
+ self._log_epoch_results(
+ epoch, total_trn_loss, np.mean(val_loss), metric_score, optimizer
+ )
+
+ if early_stopper.early_stop_choice(
+ model, epoch, np.mean(val_loss), metric_score
+ ):
+ break
+
+ y_preds, _, _ = self.predict(
+ model,
+ valid_dataset,
+ loss_func,
+ activation_fn,
+ dump_dir,
+ fold,
+ target_scaler,
+ epoch,
+ load_model=True,
+ feature_name=feature_name,
+ )
+ return y_preds
+
+ def fit_predict_with_ddp(
+ self,
+ local_rank,
+ shared_queue,
+ model,
+ train_dataset,
+ valid_dataset,
+ loss_func,
+ activation_fn,
+ dump_dir,
+ fold,
+ target_scaler,
+ feature_name=None,
+ ):
+ raise NotImplementedError("DDP is not supported in MindSpore trainer conversion.")
+
+ def _initialize_optimizer_scheduler(self, model, train_dataloader):
+ raise NotImplementedError
+
+ def _train_one_epoch(
+ self,
+ model,
+ train_dataloader,
+ train_step,
+ loss_func,
+ feature_name,
+ epoch,
+ ):
+ model.set_train()
+ trn_loss = []
+ batch_bar = tqdm(
+ total=len(train_dataloader),
+ dynamic_ncols=True,
+ leave=False,
+ position=0,
+ desc='Train',
+ ncols=5,
+ )
+ for i, batch in enumerate(train_dataloader):
+ net_input, net_target = self.decorate_batch(batch, feature_name)
+ loss = train_step(net_input, net_target)
+ trn_loss.append(float(loss.asnumpy()))
+ avg_loss = float(sum(trn_loss) / (i + 1))
+ batch_bar.set_postfix_str(
+ f"epoch {epoch+1}/{self.max_epochs} | train_loss {avg_loss:.4f}"
+ )
+ batch_bar.update()
+ batch_bar.close()
+ return np.mean(trn_loss)
+
+ def _compute_loss(self, model, net_input, net_target, loss_func):
+ outputs = model(**net_input)
+ loss = loss_func(outputs, net_target)
+ return loss
+
+ def _backward_and_step(self, optimizer, loss, model):
+ pass
+
+ def _log_epoch_results(
+ self, epoch, total_trn_loss, total_val_loss, metric_score, optimizer
+ ):
+ _score = list(metric_score.values())[0] if metric_score is not None else float('nan')
+ _metric = list(metric_score.keys())[0] if metric_score is not None else 'metric'
+ message = (
+ f'Epoch [{epoch+1}/{self.max_epochs}] '
+ f'train_loss: {total_trn_loss:.4f}, '
+ f'val_loss: {total_val_loss:.4f}, '
+ f'val_{_metric}: {_score:.4f}'
+ )
+ logger.info(message)
+ return False
+
+ def reduce_array(self, array):
+ return array
+
+ def gather_predictions(self, y_preds, len_valid_dataset):
+ return y_preds
+
+ def predict(
+ self,
+ model,
+ dataset,
+ loss_func,
+ activation_fn,
+ dump_dir,
+ fold,
+ target_scaler=None,
+ epoch=1,
+ load_model=False,
+ feature_name=None,
+ ):
+ """
+ Executes the prediction on a given dataset using the specified model.
+
+ :param model: The model to be used for predictions.
+ :param dataset: The dataset to perform predictions on.
+ :param loss_func: The loss function used during training.
+ :param activation_fn: The activation function applied to the model's output.
+ :param dump_dir: Directory where the model state is saved.
+ :param fold: The fold number in cross-validation.
+ :param target_scaler: (optional) Scaler to inverse transform the model's output. Defaults to None.
+ :param epoch: (int) The current epoch number. Defaults to 1.
+ :param load_model: (bool) Whether to load the model from a saved state. Defaults to False.
+ :param feature_name: (str, optional) Name of the feature for data processing. Defaults to None.
+
+ :return: A tuple (y_preds, val_loss, metric_score), where y_preds are the predicted outputs,
+ val_loss is the validation loss, and metric_score is the calculated metric score.
+ """
+ model = self._prepare_model_for_prediction(model, dump_dir, fold, load_model)
+ batch_collate_fn = model.batch_collate_fn
+ dataloader = NNDataLoader(
+ feature_name=feature_name,
+ dataset=dataset,
+ batch_size=self.batch_size,
+ shuffle=False,
+ collate_fn=batch_collate_fn,
+ distributed=self.ddp,
+ valid_mode=True,
+ )
+ y_preds, val_loss, y_truths = self._perform_prediction(
+ model, dataloader, loss_func, activation_fn, load_model, epoch, feature_name
+ )
+
+ metric_score = self._calculate_metrics(
+ y_preds, y_truths, target_scaler, model, load_model
+ )
+ return y_preds, val_loss, metric_score
+
+ def _prepare_model_for_prediction(self, model, dump_dir, fold, load_model):
+ if load_model:
+ load_model_path = os.path.join(dump_dir, f'model_{fold}.ckpt')
+ model.load_pretrained_weights(load_model_path, strict=True)
+ logger.info("load model success!")
+ return model
+
+ def _perform_prediction(
+ self,
+ model,
+ dataloader,
+ loss_func,
+ activation_fn,
+ load_model,
+ epoch,
+ feature_name,
+ ):
+ model.set_train(False)
+ batch_bar = tqdm(
+ total=len(dataloader),
+ dynamic_ncols=True,
+ position=0,
+ leave=False,
+ desc='Val',
+ ncols=5,
+ )
+ val_loss = []
+ y_preds = []
+ y_truths = []
+ for i, batch in enumerate(dataloader):
+ net_input, net_target = self.decorate_batch(batch, feature_name)
+ outputs = model(**net_input)
+ if not load_model:
+ loss = loss_func(outputs, net_target)
+ val_loss.append(float(loss.asnumpy()))
+ y_preds.append(activation_fn(outputs).asnumpy())
+ y_truths.append(net_target.asnumpy())
+ if not load_model:
+ avg_vloss = float(np.sum(val_loss) / (i + 1))
+ batch_bar.set_postfix_str(
+ f"epoch {epoch+1}/{self.max_epochs} | val_loss {avg_vloss:.4f}"
+ )
+ batch_bar.update()
+ batch_bar.close()
+ y_preds = np.concatenate(y_preds)
+ y_truths = np.concatenate(y_truths)
+ return y_preds, val_loss, y_truths
+
+ def _calculate_metrics(self, y_preds, y_truths, target_scaler, model, load_model):
+ try:
+ label_cnt = model.output_dim
+ except:
+ label_cnt = None
+ if target_scaler is not None:
+ inverse_y_preds = target_scaler.inverse_transform(y_preds)
+ inverse_y_truths = target_scaler.inverse_transform(y_truths)
+ metric_score = (
+ self.metrics.cal_metric(
+ inverse_y_truths, inverse_y_preds, label_cnt=label_cnt
+ )
+ if not load_model
+ else None
+ )
+ else:
+ metric_score = (
+ self.metrics.cal_metric(y_truths, y_preds, label_cnt=label_cnt)
+ if not load_model
+ else None
+ )
+ return metric_score
+
+ def inference(
+ self,
+ model,
+ dataset,
+ return_repr=False,
+ return_atomic_reprs=False,
+ feature_name=None,
+ ):
+ """
+ Runs inference on the given dataset using the provided model. This method can return
+ various representations based on the model's output.
+
+ :param model: The neural network model to be used for inference.
+ :param dataset: The dataset on which inference is to be performed.
+ :param return_repr: (bool, optional) If True, returns class-level representations. Defaults to False.
+ :param return_atomic_reprs: (bool, optional) If True, returns atomic-level representations. Defaults to False.
+ :param feature_name: (str, optional) Name of the feature used for data loading. Defaults to None.
+
+ :return: A dictionary containing different types of representations based on the model's output and the
+ specified parameters. This can include class-level representations, atomic coordinates,
+ atomic representations, and atomic symbols.
+ """
+ return self.inference_without_ddp(
+ model, dataset, return_repr, return_atomic_reprs, feature_name
+ )
+
+ def inference_with_ddp(
+ self,
+ local_rank,
+ shared_queue,
+ model,
+ dataset,
+ return_repr=False,
+ return_atomic_reprs=False,
+ feature_name=None,
+ ):
+ raise NotImplementedError("DDP is not supported in MindSpore trainer conversion.")
+
+ def inference_without_ddp(
+ self,
+ model,
+ dataset,
+ return_repr=False,
+ return_atomic_reprs=False,
+ feature_name=None,
+ ):
+ """
+ Runs inference on the given dataset using the provided model without DistributedDataParallel (DDP).
+
+ :param model: The neural network model to be used for inference.
+ :param dataset: The dataset on which inference is to be performed.
+ :param return_repr: (bool, optional) If True, returns class-level representations. Defaults to False.
+ :param return_atomic_reprs: (bool, optional) If True, returns atomic-level representations. Defaults to False.
+ :param feature_name: (str, optional) Name of the feature used for data loading. Defaults to None.
+
+ :return: A dictionary containing different types of representations based on the model's output and the
+ specified parameters. This can include class-level representations, atomic coordinates,
+ atomic representations, and atomic symbols.
+ """
+ dataloader = NNDataLoader(
+ feature_name=feature_name,
+ dataset=dataset,
+ batch_size=self.batch_size,
+ shuffle=False,
+ collate_fn=model.batch_collate_fn,
+ distributed=False,
+ )
+ model.set_train(False)
+ repr_dict = {
+ "cls_repr": [],
+ "atomic_coords": [],
+ "atomic_reprs": [],
+ "atomic_symbol": [],
+ }
+ for batch in tqdm(dataloader):
+ net_input, _ = self.decorate_batch(batch, feature_name)
+ outputs = model(
+ **net_input,
+ return_repr=return_repr,
+ return_atomic_reprs=return_atomic_reprs,
+ )
+ assert isinstance(outputs, dict)
+ repr_dict["cls_repr"].extend(
+ item.asnumpy() for item in outputs["cls_repr"]
+ )
+ if return_atomic_reprs:
+ repr_dict["atomic_symbol"].extend(outputs["atomic_symbol"])
+ repr_dict['atomic_coords'].extend(
+ item.asnumpy() for item in outputs['atomic_coords']
+ )
+ repr_dict['atomic_reprs'].extend(
+ item.asnumpy() for item in outputs['atomic_reprs']
+ )
+
+ return repr_dict
+
+ def set_seed(self, seed):
+ """
+ Sets a random seed for mindspore and numpy to ensure reproducibility.
+ :param seed: (int) The seed number to be set.
+ """
+ ms.set_seed(seed)
+ np.random.seed(seed)
+
+
+class EarlyStopper:
+ def __init__(self, patience, dump_dir, fold, metrics, metrics_str):
+ """
+ Initializes the EarlyStopper class.
+
+ :param patience: The number of epochs to wait for an improvement before stopping.
+ :param dump_dir: Directory to save the model state.
+ :param fold: The current fold number in a cross-validation setting.
+ """
+ self.patience = patience
+ self.dump_dir = dump_dir
+ self.fold = fold
+ self.metrics = metrics
+ self.metrics_str = metrics_str
+ self.wait = 0
+ self.min_loss = float("inf")
+ self.max_loss = float("-inf")
+ self.is_early_stop = False
+
+ def early_stop_choice(self, model, epoch, loss, metric_score=None):
+ """
+ Determines if early stopping criteria are met, based on either loss improvement or custom metric score.
+
+ :param model: The model being trained.
+ :param epoch: The current epoch number.
+ :param loss: The current loss value.
+ :param metric_score: The current metric score.
+
+ :return: A boolean indicating whether early stopping should occur.
+ """
+ if not isinstance(self.metrics_str, str) or self.metrics_str in [
+ 'loss',
+ 'none',
+ '',
+ ]:
+ return self._judge_early_stop_loss(loss, model, epoch)
+ else:
+ is_early_stop, min_score, wait, max_score = self.metrics._early_stop_choice(
+ self.wait,
+ self.min_loss,
+ metric_score,
+ self.max_loss,
+ model,
+ self.dump_dir,
+ self.fold,
+ self.patience,
+ epoch,
+ )
+ self.min_loss = min_score
+ self.max_loss = max_score
+ self.wait = wait
+ self.is_early_stop = is_early_stop
+ return self.is_early_stop
+
+ def _judge_early_stop_loss(self, loss, model, epoch):
+ """
+ Determines whether early stopping should be triggered based on the loss comparison.
+
+ :param loss: The current loss value of the model.
+ :param model: The neural network model being trained.
+ :param epoch: The current epoch number.
+
+ :return: A boolean indicating whether early stopping should occur.
+ """
+ if loss <= self.min_loss:
+ self.min_loss = loss
+ self.wait = 0
+ os.makedirs(self.dump_dir, exist_ok=True)
+ from mindspore import save_checkpoint
+ save_checkpoint(model, os.path.join(self.dump_dir, f'model_{self.fold}.ckpt'))
+ else:
+ self.wait += 1
+ if self.wait >= self.patience:
+ logger.warning(f'Early stopping at epoch: {epoch+1}')
+ self.is_early_stop = True
+ return self.is_early_stop
+
+
+def NNDataLoader(
+ feature_name=None,
+ dataset=None,
+ batch_size=None,
+ shuffle=False,
+ collate_fn=None,
+ drop_last=False,
+ distributed=False,
+ valid_mode=False,
+):
+ """
+ Creates a DataLoader for neural network training or inference.
+
+ :param feature_name: (str, optional) Name of the feature used for data loading.
+ :param dataset: (Dataset, optional) The dataset from which to load the data.
+ :param batch_size: (int, optional) Number of samples per batch to load.
+ :param shuffle: (bool, optional) Whether to shuffle the data at every epoch.
+ :param collate_fn: (callable, optional) Merges a list of samples to form a mini-batch.
+ :param drop_last: (bool, optional) Set to True to drop the last incomplete batch.
+ :param distributed: (bool, optional) Set to True to enable distributed data loading.
+
+ :return: generator of batches.
+ """
+
+ indices = np.arange(len(dataset))
+ if shuffle:
+ np.random.shuffle(indices)
+ batches = []
+ batch = []
+
+ for idx in indices:
+ batch.append(dataset[idx])
+ if len(batch) == batch_size:
+ batches.append(collate_fn(batch))
+ batch = []
+
+ if len(batch) and not drop_last:
+ batches.append(collate_fn(batch))
+
+ return batches
+
+
+def get_ddp_generator(seed=3407):
+ return None
+
+
+# source from https://github.com/huggingface/transformers/blob/main/src/transformers/optimization.py#L108C1-L132C54
+def _get_linear_schedule_with_warmup_lr_lambda(
+ current_step: int, *, num_warmup_steps: int, num_training_steps: int
+):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ return max(
+ 0.0,
+ float(num_training_steps - current_step)
+ / float(max(1, num_training_steps - num_warmup_steps)),
+ )
+
+
+def get_linear_schedule_with_warmup(
+ optimizer, num_warmup_steps, num_training_steps, last_epoch=-1
+):
+ return None
+
+
+class _WithLossNet(nn.Cell):
+    """Wraps a network and a loss function into a single Cell so that
+    ``nn.TrainOneStepCell`` can drive forward + loss in one call."""
+
+    def __init__(self, network, loss_fn):
+        super().__init__()
+        self.network = network
+        self.loss_fn = loss_fn
+
+        # allow passing dict inputs
+        # NOTE(review): _grad_op is never used in this class; gradients appear
+        # to be handled by TrainOneStepCell — confirm before removing.
+        self._grad_op = ops.GradOperation(get_by_list=True)
+
+    def construct(self, net_input, net_target):
+        # net_input is a dict of keyword arguments for the wrapped network.
+        outputs = self.network(**net_input)
+        return self.loss_fn(outputs, net_target)
diff --git a/MindChem/applications/unimol/unimol_tools/train.py b/MindChem/applications/unimol/unimol_tools/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..90130f9954d4a2c7adf0bf0c2dc8477f83857e8d
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/train.py
@@ -0,0 +1,231 @@
+# Copyright (c) DP Technology.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import absolute_import, division, print_function
+
+import argparse
+import copy
+import json
+import logging
+import os
+
+import joblib
+import numpy as np
+import pandas as pd
+
+from .data import DataHub
+from .models import NNModel
+from .tasks import Trainer
+from .utils import YamlHandler, logger
+
+
class MolTrain(object):
    """A :class:`MolTrain` class is responsible for interface of training process of molecular data."""

    def __init__(
        self,
        task='classification',
        data_type='molecule',
        epochs=10,
        learning_rate=1e-4,
        batch_size=16,
        early_stopping=5,
        metrics="none",
        split='random',  # random, scaffold, group, stratified
        split_group_col='scaffold',  # only active with group split
        kfold=5,
        save_path='./exp',
        remove_hs=False,
        smiles_col='SMILES',
        target_cols=None,
        target_col_prefix='TARGET',
        target_anomaly_check=False,
        smiles_check="filter",
        target_normalize="auto",
        max_norm=5.0,
        use_cuda=True,
        use_amp=True,
        use_ddp=False,
        use_gpu="all",
        freeze_layers=None,
        freeze_layers_reversed=False,
        load_model_dir=None,  # load model for transfer learning
        model_name='unimolv1',
        model_size='84m',
        conf_cache_level=1,
        **params,
    ):
        """
        Initialize a :class:`MolTrain` class.

        :param task: str, default='classification', currently support `classification`, `regression`, `multiclass`, `multilabel_classification`, `multilabel_regression`.
        :param data_type: str, default='molecule', currently support molecule, oled.
        :param epochs: int, default=10, number of epochs to train.
        :param learning_rate: float, default=1e-4, learning rate of optimizer.
        :param batch_size: int, default=16, batch size of training.
        :param early_stopping: int, default=5, early stopping patience.
        :param metrics: str, default='none', metrics to evaluate model performance.

            currently support:

            - classification: auc, auprc, log_loss, acc, f1_score, mcc, precision, recall, cohen_kappa.

            - regression: mse, pearsonr, spearmanr, mse, r2.

            - multiclass: log_loss, acc.

            - multilabel_classification: auc, auprc, log_loss, acc, mcc.

            - multilabel_regression: mae, mse, r2.

        :param split: str, default='random', split method of training dataset. currently support: random, scaffold, group, stratified, select.

            - random: random split.

            - scaffold: split by scaffold.

            - group: split by group. `split_group_col` should be specified.

            - stratified: stratified split. `split_group_col` should be specified.

            - select: use `split_group_col` to manually select the split group. Column values of `split_group_col` should be range from 0 to kfold-1 to indicate the split group.

        :param split_group_col: str, default='scaffold', column name of group split.
        :param kfold: int, default=5, number of folds for k-fold cross validation.

            - 1: no split. all data will be used for training.

        :param save_path: str, default='./exp', path to save training results.
        :param remove_hs: bool, default=False, whether to remove hydrogens from molecules.
        :param smiles_col: str, default='SMILES', column name of SMILES.
        :param target_cols: list or str, default=None, column names of target values.
        :param target_col_prefix: str, default='TARGET', prefix of target column name.
        :param target_anomaly_check: str or bool, default=False, how to deal with anomaly target values. currently support: 'filter', 'none'/False (True is also accepted and enables filtering).
        :param smiles_check: str, default='filter', how to deal with invalid SMILES. currently support: filter, none.
        :param target_normalize: str, default='auto', how to normalize target values. 'auto' means we will choose the normalize strategy by automatic. \
            currently support: auto, minmax, standard, robust, log1p, none.
        :param max_norm: float, default=5.0, max norm of gradient clipping.
        :param use_cuda: bool, default=True, whether to use GPU.
        :param use_amp: bool, default=True, whether to use automatic mixed precision.
        :param use_ddp: bool, default=False, whether to use distributed data parallel.
        :param use_gpu: str, default='all', which GPU to use. 'all' means use all GPUs. '0,1,2' means use GPU 0, 1, 2.
        :param freeze_layers: str or list, frozen layers by startwith name list. ['encoder', 'gbf'] will freeze all the layers whose name start with 'encoder' or 'gbf'.
        :param freeze_layers_reversed: bool, default=False, inverse selection of frozen layers
        :param params: dict, default=None, other parameters.
        :param load_model_dir: str, default=None, path to load model for transfer learning.
        :param model_name: str, default='unimolv1', currently support unimolv1, unimolv2.
        :param model_size: str, default='84m', model size. work when model_name is unimolv2. Available: 84m, 164m, 310m, 570m, 1.1B.
        :param conf_cache_level: int, optional [0, 1, 2], default=1, configuration cache level to save the conformers to sdf file.
            - 0: no caching.
            - 1: cache if not exists.
            - 2: always cache.

        """
        # For transfer learning, reuse the config saved next to the source
        # model; otherwise start from the packaged default config.
        if load_model_dir is not None:
            config_path = os.path.join(load_model_dir, 'config.yaml')
            logger.info('Load config file from {}'.format(config_path))
        else:
            config_path = os.path.join(os.path.dirname(__file__), 'config/default.yaml')
        self.yamlhandler = YamlHandler(config_path)
        config = self.yamlhandler.read_yaml()
        config.task = task
        config.data_type = data_type
        config.epochs = epochs
        config.learning_rate = learning_rate
        config.batch_size = batch_size
        config.patience = early_stopping
        config.metrics = metrics
        config.split = split
        config.split_group_col = split_group_col
        config.kfold = kfold
        config.remove_hs = remove_hs
        config.smiles_col = smiles_col
        config.target_cols = target_cols
        config.target_col_prefix = target_col_prefix
        # BUGFIX: the previous expression
        #   target_anomaly_check or target_anomaly_check in ['filter']
        # evaluated to the raw argument whenever it was a truthy string, so
        # passing 'none' still enabled anomaly cleaning (and stored a str,
        # not a bool). Normalize to a strict boolean instead.
        config.anomaly_clean = target_anomaly_check in ('filter', True)
        config.smi_strict = smiles_check in ['filter']
        config.target_normalize = target_normalize
        config.max_norm = max_norm
        config.use_cuda = use_cuda
        config.use_amp = use_amp
        config.use_ddp = use_ddp
        config.use_gpu = use_gpu
        config.freeze_layers = freeze_layers
        config.freeze_layers_reversed = freeze_layers_reversed
        config.load_model_dir = load_model_dir
        config.model_name = model_name
        config.model_size = model_size
        config.conf_cache_level = conf_cache_level
        self.save_path = save_path
        self.config = config

    def fit(self, data):
        """
        Fit the model according to the given training data with multi datasource support, including SMILES csv file and custom coordinate data.

        For example: custom coordinate data.

        .. code-block:: python

            from unimol_tools import MolTrain
            import numpy as np
            custom_data ={'target':np.random.randint(2, size=100),
                        'atoms':[['C','C','H','H','H','H'] for _ in range(100)],
                        'coordinates':[np.random.randn(6,3) for _ in range(100)],
                        }

            clf = MolTrain()
            clf.fit(custom_data)
        """
        # Build the dataset (featurization, target scaling, splits) first,
        # then persist the finalized config before training starts.
        self.datahub = DataHub(
            data=data, is_train=True, save_path=self.save_path, **self.config
        )
        self.data = self.datahub.data
        self.update_and_save_config()
        self.trainer = Trainer(save_path=self.save_path, **self.config)
        self.model = NNModel(self.data, self.trainer, **self.config)
        self.model.run()
        # Cross-validation predictions are produced in scaled target space;
        # invert the scaling before computing thresholds / exposing cv_pred.
        target_scaler = self.data['target_scaler']
        y_pred = self.model.cv['pred']
        y_true = np.array(self.data['target'])
        metrics = self.trainer.metrics
        if target_scaler is not None:
            y_pred = target_scaler.inverse_transform(y_pred)
            y_true = target_scaler.inverse_transform(y_true)

        if self.config["task"] in ['classification', 'multilabel_classification']:
            # Per-target decision thresholds tuned on the CV predictions,
            # saved for reuse at prediction time.
            threshold = metrics.calculate_classification_threshold(y_true, y_pred)
            joblib.dump(threshold, os.path.join(self.save_path, 'threshold.dat'))

        self.cv_pred = y_pred
        return

    def update_and_save_config(self):
        """
        Update and save config file.

        Copies dataset-derived fields (num_classes, resolved target columns,
        multiclass count) into the config and writes it to
        ``<save_path>/config.yaml`` so prediction can reload it later.
        """
        self.config['num_classes'] = self.data['num_classes']
        self.config['target_cols'] = ','.join(self.data['target_cols'])
        if self.config['task'] == 'multiclass':
            self.config['multiclass_cnt'] = self.data['multiclass_cnt']

        self.config['split_method'] = (
            f"{self.config['kfold']}fold_{self.config['split']}"
        )
        if self.save_path is not None:
            if not os.path.exists(self.save_path):
                logger.info('Create output directory: {}'.format(self.save_path))
                os.makedirs(self.save_path)
            else:
                logger.info(
                    'Output directory already exists: {}'.format(self.save_path)
                )
                logger.info(
                    'Warning: Overwrite output directory: {}'.format(self.save_path)
                )
            out_path = os.path.join(self.save_path, 'config.yaml')
            self.yamlhandler.write_yaml(data=self.config, out_file_path=out_path)
        return
diff --git a/MindChem/applications/unimol/unimol_tools/utils/__init__.py b/MindChem/applications/unimol/unimol_tools/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c86f5cc3903087a3ea7b0965e5981d6203844edf
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/utils/__init__.py
@@ -0,0 +1,4 @@
+from .base_logger import logger
+from .config_handler import YamlHandler
+from .metrics import Metrics
+from .util import *
diff --git a/MindChem/applications/unimol/unimol_tools/utils/base_logger.py b/MindChem/applications/unimol/unimol_tools/utils/base_logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..52cb5a8906d22ddf6c200a874f1888dcf83fa631
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/utils/base_logger.py
@@ -0,0 +1,114 @@
+# Copyright (c) DP Technology.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import absolute_import, division, print_function
+
+import datetime
+import logging
+import os
+import sys
+import threading
+from logging.handlers import TimedRotatingFileHandler
+
+BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+
+
class PackagePathFilter(logging.Filter):
    """Logging filter that attaches a ``relativepath`` attribute to records.

    The attribute is the record's pathname made relative to the longest
    matching ``sys.path`` entry, or ``None`` when no entry matches. The filter
    never drops records.
    """

    def filter(self, record):
        """Annotate *record* with ``relativepath``; always return True."""
        full_path = record.pathname
        record.relativepath = None
        # Longest entries first so the most specific sys.path root wins.
        candidates = sorted(
            (os.path.abspath(entry) for entry in sys.path), key=len, reverse=True
        )
        for candidate in candidates:
            prefix = candidate if candidate.endswith(os.sep) else candidate + os.sep
            if full_path.startswith(prefix):
                record.relativepath = os.path.relpath(full_path, prefix)
                break
        return True
+
+
class Logger(object):
    """A custom logger class that provides logging functionality to console and file.

    Implemented as a thread-safe singleton via ``__new__``; note that
    ``__init__`` still re-runs on every ``Logger(...)`` call, re-reading the
    name and regenerating the log file name on the shared instance.
    """

    _instance = None
    _lock = threading.Lock()

    # shared format strings; %(relativepath)s is supplied by PackagePathFilter
    DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
    LOG_FORMAT = "%(asctime)s | %(relativepath)s | %(lineno)s | %(levelname)s | %(name)s | %(message)s"

    def __new__(cls, *args, **kwargs):
        # double-checked locking so only one instance is ever created
        if not cls._instance:
            with cls._lock:
                if not cls._instance:
                    cls._instance = super(Logger, cls).__new__(cls)
        return cls._instance

    def __init__(self, logger_name='None'):
        """
        :param logger_name: (str) The name of the logger (default: 'None')
        """
        self.logger = logging.getLogger(logger_name)
        # let handlers decide the effective level
        logging.root.setLevel(logging.NOTSET)
        # timestamped file name, fixed at construction time
        self.log_file_name = 'unimol_tools_{0}.log'.format(
            datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        )

        # side effect: creates ./logs under the current working directory
        cwd_path = os.path.abspath(os.getcwd())
        self.log_path = os.path.join(cwd_path, "logs")

        if not os.path.exists(self.log_path):
            os.makedirs(self.log_path)
        # number of rotated log files kept by the file handler
        self.backup_count = 5

        self.console_output_level = 'INFO'
        self.file_output_level = 'INFO'

        self.formatter = logging.Formatter(self.LOG_FORMAT, self.DATE_FORMAT)

    def get_logger(self):
        """
        Get the logger object.

        Handlers are attached lazily and only once (guarded by
        ``self.logger.handlers``): a console StreamHandler plus a daily
        rotating file handler (``delay=True``, so the file is only created on
        first write).

        :return: logging.Logger - a logger object.

        """
        if not self.logger.handlers:
            console_handler = logging.StreamHandler()
            console_handler.setFormatter(self.formatter)
            console_handler.setLevel(self.console_output_level)
            console_handler.addFilter(PackagePathFilter())
            self.logger.addHandler(console_handler)

            file_handler = TimedRotatingFileHandler(
                filename=os.path.join(self.log_path, self.log_file_name),
                when='D',
                interval=1,
                backupCount=self.backup_count,
                delay=True,
                encoding='utf-8',
            )
            file_handler.setFormatter(self.formatter)
            file_handler.setLevel(self.file_output_level)
            self.logger.addHandler(file_handler)
        return self.logger
+
+
# add highlight formatter to logger
class HighlightFormatter(logging.Formatter):
    """Formatter that wraps WARNING messages in a yellow ANSI escape."""

    def format(self, record):
        """Render the record, coloring the message for WARNING level only."""
        if record.levelno != logging.WARNING:
            return super().format(record)
        # NOTE: mutates record.msg in place, so any later handler that formats
        # this same record also sees the colored text.
        record.msg = "\033[93m{}\033[0m".format(record.msg)  # yellow highlight
        return super().format(record)
+
+
# Module-level singleton logger shared across unimol_tools.
logger = Logger('Uni-Mol Tools').get_logger()
logger.setLevel(logging.INFO)

# highlight warning messages in console
# NOTE(review): TimedRotatingFileHandler is also a StreamHandler subclass, so
# this swaps the formatter on the file handler too and ANSI codes end up in
# the log file as well — confirm whether that is intended.
for handler in logger.handlers:
    if isinstance(handler, logging.StreamHandler):
        handler.setFormatter(HighlightFormatter(Logger.LOG_FORMAT, Logger.DATE_FORMAT))
diff --git a/MindChem/applications/unimol/unimol_tools/utils/config_handler.py b/MindChem/applications/unimol/unimol_tools/utils/config_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..3185f9e8d077b389bb82a4e3496b92912281b2ae
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/utils/config_handler.py
@@ -0,0 +1,66 @@
+# Copyright (c) DP Technology.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import absolute_import, division, print_function
+
+import os
+
+import yaml
+from addict import Dict
+
+from .base_logger import logger
+
+
class YamlHandler:
    '''A class to read and write yaml config files, converting to/from
    ``addict.Dict`` for attribute-style access.'''

    def __init__(self, file_path):
        """
        :param file_path: (str) The yaml file path of the config.

        :raises FileExistsError: if ``file_path`` does not exist.
        """
        if not os.path.exists(file_path):
            # BUGFIX: previously raised ``FileExistsError(OSError)``, passing
            # the OSError *class* as the exception argument; carry the
            # offending path instead. NOTE(review): FileNotFoundError would be
            # the accurate type, but the raised type is kept so existing
            # callers that catch FileExistsError keep working.
            raise FileExistsError(file_path)
        self.file_path = file_path

    def read_yaml(self, encoding='utf-8'):
        """read yaml file and convert to easydict

        :param encoding: (str) encoding method uses utf-8 by default
        :return: Dict (addict), the usage of Dict is the same as dict
        """
        with open(self.file_path, encoding=encoding) as f:
            # FullLoader: trusted project config files only
            return Dict(yaml.load(f.read(), Loader=yaml.FullLoader))

    def write_yaml(self, data, out_file_path, encoding='utf-8'):
        """write dict or easydict to yaml file

        :param data: (dict or Dict(addict)) dict containing the contents of the yaml file
        :param out_file_path: (str) destination path to write to
        :param encoding: (str) encoding method uses utf-8 by default
        """
        with open(out_file_path, encoding=encoding, mode='w') as f:
            # addict.Dict must be flattened first or yaml emits python objects
            return yaml.dump(
                addict2dict(data) if isinstance(data, Dict) else data,
                stream=f,
                allow_unicode=True,
            )
+
+
def addict2dict(addict_obj):
    '''Recursively convert an ``addict.Dict`` into a plain ``dict``.

    :param addict_obj: (Dict(addict)) the addict obj that you want to convert to dict

    :return: (dict) converted result
    '''
    return {
        key: addict2dict(value) if isinstance(value, Dict) else value
        for key, value in addict_obj.items()
    }
+
+
# Ad-hoc smoke test: load the default config and print one section when this
# module is executed directly (relies on the repo-relative path).
if __name__ == '__main__':
    yaml_handler = YamlHandler('../config/default.yaml')
    config = yaml_handler.read_yaml()
    print(config.Modelhub)
diff --git a/MindChem/applications/unimol/unimol_tools/utils/metrics.py b/MindChem/applications/unimol/unimol_tools/utils/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e35b082616a8b951c54b5d6462d6d7061aad163
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/utils/metrics.py
@@ -0,0 +1,353 @@
+import copy
+import os
+
+import numpy as np
+import pandas as pd
+from scipy.stats import pearsonr, spearmanr
+from sklearn.metrics import (
+ accuracy_score,
+ average_precision_score,
+ cohen_kappa_score,
+ f1_score,
+ log_loss,
+ matthews_corrcoef,
+ mean_absolute_error,
+ mean_squared_error,
+ precision_score,
+ r2_score,
+ recall_score,
+ roc_auc_score,
+)
+
+from .base_logger import logger
+
+
def cal_nan_metric(y_true, y_pred, nan_value=None, metric_func=None):
    """Column-wise metric that ignores NaN (and optionally sentinel) targets.

    :param y_true: 2D array or DataFrame of ground-truth values, shape (n_samples, n_targets).
    :param y_pred: 2D array or DataFrame of predictions with the same shape.
    :param nan_value: optional sentinel value in ``y_true`` that also marks missing targets.
    :param metric_func: callable(y_true_1d, y_pred_1d) -> float, applied per column.

    :return: mean of the per-column metric over columns with at least one valid entry.
    :raises ValueError: if the two inputs differ in shape.
    """
    if y_true.shape != y_pred.shape:
        # BUGFIX: error message previously misspelled 'y_ture'
        raise ValueError('y_true and y_pred must have same shape')

    if isinstance(y_true, pd.DataFrame):
        y_true = y_true.to_numpy()

    if isinstance(y_pred, pd.DataFrame):
        y_pred = y_pred.to_numpy()

    # np.isnan requires a float dtype; integer targets can never hold NaN
    if not np.issubdtype(y_true.dtype, np.floating):
        y_true = y_true.astype(np.float64)

    mask = ~np.isnan(y_true)
    if nan_value is not None:
        mask = mask & (y_true != nan_value)

    num_cols = y_true.shape[1]
    result = []
    for i in range(num_cols):
        col_mask = mask[:, i]
        # skip columns where every target is missing
        if not (~col_mask).all():
            result.append(metric_func(y_true[:, i][col_mask], y_pred[:, i][col_mask]))
    # NOTE: if all columns are fully masked this yields NaN (mean of empty)
    return np.mean(result)
+
+
def multi_acc(y_true, y_pred):
    """Multiclass accuracy: fraction of rows where argmax(prob) equals the label.

    :param y_true: array of integer class labels (any shape; flattened).
    :param y_pred: 2D array of per-class scores, shape (n_samples, n_classes).
    :return: accuracy as a float in [0, 1].
    """
    labels = y_true.flatten()
    predicted = np.argmax(y_pred, axis=1)
    return np.mean(labels == predicted)
+
+
def log_loss_with_label(y_true, y_pred, labels=None):
    """sklearn ``log_loss`` wrapper that forwards an explicit label set when given.

    :param labels: optional full list of class labels, needed when a fold does
        not contain every class.
    """
    extra = {} if labels is None else {'labels': labels}
    return log_loss(y_true, y_pred, **extra)
+
+
def reg_preasonr(y_true, y_pred):
    """Pearson correlation coefficient (statistic only, p-value discarded)."""
    # NOTE: the misspelled name is part of the public METRICS_REGISTER API.
    statistic, _pvalue = pearsonr(y_true, y_pred)
    return statistic
+
+
def reg_spearmanr(y_true, y_pred):
    """Spearman rank correlation coefficient (statistic only, p-value discarded)."""
    statistic, _pvalue = spearmanr(y_true, y_pred)
    return statistic
+
+
# metric_func, is_increase, value_type
# Each entry maps a metric name to a triple:
#   - metric_func: callable(y_true, y_pred) -> float
#   - is_increase: True if larger values are better (used by early stopping)
#   - value_type: 'float' metrics consume probabilities/raw values;
#                 'int' metrics consume thresholded (0/1) predictions
METRICS_REGISTER = {
    'regression': {
        "mae": [mean_absolute_error, False, 'float'],
        "pearsonr": [reg_preasonr, True, 'float'],
        "spearmanr": [reg_spearmanr, True, 'float'],
        "mse": [mean_squared_error, False, 'float'],
        "r2": [r2_score, True, 'float'],
    },
    'classification': {
        "auroc": [roc_auc_score, True, 'float'],
        "auc": [roc_auc_score, True, 'float'],
        "auprc": [average_precision_score, True, 'float'],
        "log_loss": [log_loss, False, 'float'],
        "acc": [accuracy_score, True, 'int'],
        "f1_score": [f1_score, True, 'int'],
        "mcc": [matthews_corrcoef, True, 'int'],
        "precision": [precision_score, True, 'int'],
        "recall": [recall_score, True, 'int'],
        "cohen_kappa": [cohen_kappa_score, True, 'int'],
    },
    'multiclass': {
        "log_loss": [log_loss_with_label, False, 'float'],
        "acc": [multi_acc, True, 'int'],
    },
    'multilabel_classification': {
        "auroc": [roc_auc_score, True, 'float'],
        "auc": [roc_auc_score, True, 'float'],
        "auprc": [average_precision_score, True, 'float'],
        "log_loss": [log_loss_with_label, False, 'float'],
        "acc": [accuracy_score, True, 'int'],
        "mcc": [matthews_corrcoef, True, 'int'],
    },
    'multilabel_regression': {
        "mae": [mean_absolute_error, False, 'float'],
        "mse": [mean_squared_error, False, 'float'],
        "r2": [r2_score, True, 'float'],
    },
}

# Metrics reported per task when the user does not request a specific set;
# the first entry of each list drives early stopping.
DEFAULT_METRICS = {
    'regression': ['mse', 'mae', 'r2', 'spearmanr', 'pearsonr'],
    'classification': [
        'log_loss',
        'auc',
        'f1_score',
        'mcc',
        'acc',
        'precision',
        'recall',
    ],
    'multiclass': ['log_loss', 'acc'],
    "multilabel_classification": ['log_loss', 'auc', 'auprc'],
    "multilabel_regression": ['mse', 'mae', 'r2'],
}
+
+
class Metrics(object):
    """
    Class for calculating metrics for different tasks.

    :param task: The task type. Supported tasks are 'regression', 'multilabel_regression',
        'classification', 'multilabel_classification', and 'multiclass'.
    :param metrics_str: Comma-separated string of metric names. If provided, the specified
        metrics are calculated first (in priority order), followed by the rest of the task's
        registry. If not provided, empty, or 'none', default metrics for the task are used.
    """

    def __init__(self, task=None, metrics_str=None, **params):
        self.task = task
        # candidate thresholds; not referenced by the methods below
        self.threshold = np.arange(0, 1.0, 0.1)
        # ordered {name: (metric_func, is_increase, value_type)}
        self.metric_dict = self._init_metrics(self.task, metrics_str, **params)
        self.METRICS_REGISTER = METRICS_REGISTER[task]

    def _init_metrics(self, task, metrics_str, **params):
        # Build the ordered metric dict; the first key later drives early
        # stopping (see _early_stop_choice).
        if task not in METRICS_REGISTER:
            raise ValueError('Unknown task: {}'.format(self.task))
        if (
            not isinstance(metrics_str, str)
            or metrics_str == ''
            or metrics_str == 'none'
        ):
            # no explicit request: use the task's default metric set
            metric_dict = {
                key: METRICS_REGISTER[task][key] for key in DEFAULT_METRICS[task]
            }
        else:
            # validate every requested metric before building the dict
            for key in metrics_str.split(','):
                if key not in METRICS_REGISTER[task]:
                    raise ValueError('Unknown metric: {}'.format(key))

            # requested metrics first, then the remaining registry entries
            priority_metric_list = metrics_str.split(',')
            metric_list = priority_metric_list + [
                key for key in METRICS_REGISTER[task] if key not in priority_metric_list
            ]
            metric_dict = {key: METRICS_REGISTER[task][key] for key in metric_list}

        return metric_dict

    def cal_classification_metric(self, label, predict, nan_value=-1.0, threshold=None):
        """
        Compute every configured classification metric, skipping NaN /
        ``nan_value`` targets column-wise.

        :param label: the labels of the dataset.
        :param predict: the predict values of the model.
        :param nan_value: sentinel marking missing labels.
        :param threshold: decision threshold for 'int' metrics; 0.5 when None.
        :return: dict of metric name -> value.
        """
        res_dict = {}
        for metric_type, metric_value in self.metric_dict.items():
            metric, _, value_type = metric_value

            def nan_metric(label, predict):
                return cal_nan_metric(label, predict, nan_value, metric)

            if value_type == 'float':
                # probability-based metrics consume raw scores
                res_dict[metric_type] = nan_metric(
                    label.astype(int), predict.astype(np.float32)
                )
            elif value_type == 'int':
                # label-based metrics consume thresholded 0/1 predictions
                thre = 0.5 if threshold is None else threshold
                res_dict[metric_type] = nan_metric(
                    label.astype(int), (predict > thre).astype(int)
                )

        # TO DO : add more metrics by grid search threshold

        return res_dict

    def cal_reg_metric(self, label, predict, nan_value=-1.0):
        """
        Compute every configured regression metric, skipping NaN /
        ``nan_value`` targets column-wise.

        :param label: the labels of the dataset.
        :param predict: the predict values of the model.
        :return: dict of metric name -> value.
        """
        res_dict = {}
        for metric_type, metric_value in self.metric_dict.items():
            metric, _, _ = metric_value

            def nan_metric(label, predict):
                return cal_nan_metric(label, predict, nan_value, metric)

            res_dict[metric_type] = nan_metric(label, predict)

        return res_dict

    def cal_multiclass_metric(self, label, predict, nan_value=-1.0, label_cnt=-1):
        """
        Compute every configured multiclass metric.

        :param label: the labels of the dataset.
        :param predict: the predict values of the model.
        :param label_cnt: number of classes; when set, log_loss receives the
            explicit label list so folds missing a class still score.
        :return: dict of metric name -> value.
        """
        res_dict = {}
        for metric_type, metric_value in self.metric_dict.items():
            metric, _, _ = metric_value
            if metric_type == 'log_loss' and label_cnt is not None:
                labels = list(range(label_cnt))
                res_dict[metric_type] = metric(label, predict, labels)
            else:
                res_dict[metric_type] = metric(label, predict)

        return res_dict

    def cal_metric(self, label, predict, nan_value=-1.0, threshold=0.5, label_cnt=None):
        # Dispatch to the task-specific metric computation.
        # NOTE(review): the ``threshold`` argument is not forwarded to
        # cal_classification_metric (which then defaults to 0.5) — confirm
        # whether that is intended.
        if self.task in ['regression', 'multilabel_regression']:
            return self.cal_reg_metric(label, predict, nan_value)
        elif self.task in ['classification', 'multilabel_classification']:
            return self.cal_classification_metric(label, predict, nan_value)
        elif self.task in ['multiclass']:
            return self.cal_multiclass_metric(label, predict, nan_value, label_cnt)
        else:
            raise ValueError("We will add more tasks soon")

    def _early_stop_choice(
        self,
        wait,
        min_score,
        metric_score,
        max_score,
        model,
        dump_dir,
        fold,
        patience,
        epoch,
    ):
        # Early stopping is judged on the FIRST metric in metric_score; its
        # registry entry decides whether higher or lower is better.
        score = list(metric_score.values())[0]
        judge_metric = list(metric_score.keys())[0]
        is_increase = METRICS_REGISTER[self.task][judge_metric][1]
        if is_increase:
            is_early_stop, max_score, wait = self._judge_early_stop_increase(
                wait, score, max_score, model, dump_dir, fold, patience, epoch
            )
        else:
            is_early_stop, min_score, wait = self._judge_early_stop_decrease(
                wait, score, min_score, model, dump_dir, fold, patience, epoch
            )
        return is_early_stop, min_score, wait, max_score

    def _judge_early_stop_decrease(
        self, wait, score, min_score, model, dump_dir, fold, patience, epoch
    ):
        # For metrics where LOWER is better: checkpoint on improvement,
        # otherwise count toward patience.
        is_early_stop = False
        if score <= min_score:
            min_score = score
            wait = 0
            # NOTE(review): 'info' is unused (leftover from the torch
            # state_dict save path) — confirm before removing.
            info = {'model_state_dict': model.state_dict()}
            os.makedirs(dump_dir, exist_ok=True)
            from mindspore import save_checkpoint
            save_checkpoint(model, os.path.join(dump_dir, f'model_{fold}.ckpt'))
        elif score >= min_score:
            wait += 1
            if wait == patience:
                logger.warning(f'Early stopping at epoch: {epoch+1}')
                is_early_stop = True
        return is_early_stop, min_score, wait

    def _judge_early_stop_increase(
        self, wait, score, max_score, model, dump_dir, fold, patience, epoch
    ):
        # For metrics where HIGHER is better: checkpoint on improvement,
        # otherwise count toward patience.
        is_early_stop = False
        if score >= max_score:
            max_score = score
            wait = 0
            # NOTE(review): 'info' is unused (leftover from the torch
            # state_dict save path) — confirm before removing.
            info = {'model_state_dict': model.state_dict()}
            os.makedirs(dump_dir, exist_ok=True)
            from mindspore import save_checkpoint
            save_checkpoint(model, os.path.join(dump_dir, f'model_{fold}.ckpt'))
        elif score <= max_score:
            wait += 1
            if wait == patience:
                logger.warning(f'Early stopping at epoch: {epoch+1}')
                is_early_stop = True
        return is_early_stop, max_score, wait

    def calculate_single_classification_threshold(
        self, target, pred, metrics_key=None, step=20
    ):
        # Grid-search a decision threshold over the observed prediction range,
        # optimizing the first 'int'-valued metric in metric_dict (f1_score as
        # a fallback).
        data = copy.deepcopy(pred)
        range_min = np.min(data).item()
        range_max = np.max(data).item()

        for metric_type, metric_value in self.metric_dict.items():
            metric, is_increase, value_type = metric_value
            if value_type == 'int':
                metrics_key = metric_value
                break
        # default threshold metrics
        if metrics_key is None:
            metrics_key = METRICS_REGISTER['classification']['f1_score']
        logger.info("metrics for threshold: {0}".format(metrics_key[0].__name__))
        metrics = metrics_key[0]
        if metrics_key[1]:
            # increasing metric: pick the threshold maximizing the score
            best_metric = float('-inf')
            best_threshold = 0.5
            for threshold in np.linspace(range_min, range_max, step):
                pred_label = np.zeros_like(pred)
                pred_label[pred > threshold] = 1
                # print ("threshold: ", threshold, metric(target, pred_label))
                if metric(target, pred_label) > best_metric:
                    best_metric = metric(target, pred_label)
                    best_threshold = threshold
            logger.info(
                "best threshold: {0}, metrics: {1}".format(best_threshold, best_metric)
            )
        else:
            # decreasing metric: pick the threshold minimizing the score
            best_metric = float('inf')
            best_threshold = 0.5
            for threshold in np.linspace(range_min, range_max, step):
                pred_label = np.zeros_like(pred)
                pred_label[pred > threshold] = 1
                if metric(target, pred_label) < best_metric:
                    best_metric = metric(target, pred_label)
                    best_threshold = threshold
            logger.info(
                "best threshold: {0}, metrics: {1}".format(best_threshold, best_metric)
            )

        return best_threshold

    def calculate_classification_threshold(self, target, pred):
        # One independently-tuned threshold per target column.
        threshold = np.zeros(target.shape[1])
        for idx in range(target.shape[1]):
            threshold[idx] = self.calculate_single_classification_threshold(
                target[:, idx].reshape(-1, 1),
                pred[:, idx].reshape(-1, 1),
                metrics_key=None,
                step=20,
            )
        return threshold
diff --git a/MindChem/applications/unimol/unimol_tools/utils/util.py b/MindChem/applications/unimol/unimol_tools/utils/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..65d58f3af79f5833085dcfc2da4e0f5f6d35cbc5
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/utils/util.py
@@ -0,0 +1,116 @@
+# Copyright (c) DP Technology.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from hashlib import md5
+
def pad_1d_tokens(
    values,
    pad_idx,
    left_pad=False,
    pad_to_length=None,
    pad_to_multiple=1,
):
    """
    Pad a list of 1d token tensors into one batched 2d tensor.

    :param values: A list of 1d tensors.
    :param pad_idx: The padding index.
    :param left_pad: Whether to left pad the tensors. Defaults to False.
    :param pad_to_length: The desired length of the padded tensors. Defaults to None.
    :param pad_to_multiple: The multiple to pad the tensors to. Defaults to 1.

    :return: A padded 2d tensor of shape (len(values), padded_len).
    """
    target_len = max(v.size(0) for v in values)
    if pad_to_length is not None:
        target_len = max(target_len, pad_to_length)
    if pad_to_multiple != 1 and target_len % pad_to_multiple != 0:
        # round up to the next multiple
        target_len = int(((target_len - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
    out = values[0].new(len(values), target_len).fill_(pad_idx)

    def _fill(src, dst):
        assert dst.numel() == src.numel()
        dst.copy_(src)

    for row, v in enumerate(values):
        n = len(v)
        dst = out[row][target_len - n:] if left_pad else out[row][:n]
        _fill(v, dst)
    return out
+
+
def pad_2d(
    values,
    pad_idx,
    dim=1,
    left_pad=False,
    pad_to_length=None,
    pad_to_multiple=1,
):
    """
    Pad a list of square 2d tensors (pairwise features) into one batched tensor.

    :param values: A list of 2d tensors, each of shape (n, n) — or (n, n, dim)
        when ``dim`` > 1.
    :param pad_idx: The padding index.
    :param left_pad: Whether to pad on the left side. Defaults to False.
    :param pad_to_length: The length to pad the tensors to. If None, the maximum
        length in the list is used. Defaults to None.
    :param pad_to_multiple: The multiple to pad the tensors to. Defaults to 1.

    :return: A padded tensor of shape (len(values), L, L) or (len(values), L, L, dim).
    """
    target_len = max(v.size(0) for v in values)
    if pad_to_length is not None:
        target_len = max(target_len, pad_to_length)
    if pad_to_multiple != 1 and target_len % pad_to_multiple != 0:
        # round up to the next multiple
        target_len = int(((target_len - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
    if dim == 1:
        out = values[0].new(len(values), target_len, target_len).fill_(pad_idx)
    else:
        out = values[0].new(len(values), target_len, target_len, dim).fill_(pad_idx)

    def _fill(src, dst):
        assert dst.numel() == src.numel()
        dst.copy_(src)

    for row, v in enumerate(values):
        n = len(v)
        if left_pad:
            _fill(v, out[row][target_len - n:, target_len - n:])
        else:
            _fill(v, out[row][:n, :n])
    return out
+
+
def pad_coords(
    values,
    pad_idx,
    dim=3,
    left_pad=False,
    pad_to_length=None,
    pad_to_multiple=1,
):
    """
    Pad a list of coordinate tensors (shape (n, dim), dim defaults to 3) into
    one batched tensor.

    :param values: A list of 2d coordinate tensors.
    :param pad_idx: The value used for padding.
    :param left_pad: Whether to pad on the left side. Defaults to False.
    :param pad_to_length: The desired length of the padded tensor. Defaults to None.
    :param pad_to_multiple: The multiple to pad the tensor to. Defaults to 1.

    :return: A padded tensor of shape (len(values), L, dim).
    """
    target_len = max(v.size(0) for v in values)
    if pad_to_length is not None:
        target_len = max(target_len, pad_to_length)
    if pad_to_multiple != 1 and target_len % pad_to_multiple != 0:
        # round up to the next multiple
        target_len = int(((target_len - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
    out = values[0].new(len(values), target_len, dim).fill_(pad_idx)

    def _fill(src, dst):
        assert dst.numel() == src.numel()
        dst.copy_(src)

    for row, v in enumerate(values):
        n = len(v)
        dst = out[row][target_len - n:, :] if left_pad else out[row][:n, :]
        _fill(v, dst)
    return out
diff --git a/MindChem/applications/unimol/unimol_tools/weights/__init__.py b/MindChem/applications/unimol/unimol_tools/weights/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9f6ffadc746ff69648378c00cf6882a3ca42958
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/weights/__init__.py
@@ -0,0 +1 @@
+from .weighthub import WEIGHT_DIR, weight_download, weight_download_v2
diff --git a/MindChem/applications/unimol/unimol_tools/weights/mol.dict.txt b/MindChem/applications/unimol/unimol_tools/weights/mol.dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4130c254b4da592338b43298b49120a561dfae60
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/weights/mol.dict.txt
@@ -0,0 +1,30 @@
+[PAD]
+[CLS]
+[SEP]
+[UNK]
+C
+N
+O
+S
+H
+Cl
+F
+Br
+I
+Si
+P
+B
+Na
+K
+Al
+Ca
+Sn
+As
+Hg
+Fe
+Zn
+Cr
+Se
+Gd
+Au
+Li
\ No newline at end of file
diff --git a/MindChem/applications/unimol/unimol_tools/weights/oled.dict.txt b/MindChem/applications/unimol/unimol_tools/weights/oled.dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2775ae8972f5d444d78cc419ca61617e8fa777a4
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/weights/oled.dict.txt
@@ -0,0 +1,93 @@
+[PAD]
+[CLS]
+[SEP]
+[UNK]
+O
+H
+F
+S
+Li
+P
+N
+Mg
+C
+Si
+Cl
+Fe
+Mn
+B
+Se
+Al
+Co
+Na
+V
+Ni
+Cu
+K
+Ca
+Ba
+Ti
+Zn
+Ge
+Sr
+I
+Br
+Te
+Cr
+Mo
+Sb
+Ga
+Sn
+Bi
+La
+As
+Nb
+Rb
+W
+Y
+In
+Cs
+Ag
+Zr
+Cd
+Pb
+Nd
+Ta
+Ce
+Pd
+Pr
+Sm
+Rh
+Hg
+Tl
+Pt
+Er
+Tb
+Ru
+Sc
+U
+Dy
+Ho
+Au
+Hf
+Yb
+Ir
+Be
+Eu
+Tm
+Re
+Lu
+Gd
+Os
+Th
+Tc
+Pu
+Np
+Pm
+Xe
+Ac
+Pa
+Kr
+He
+Ne
+Ar
\ No newline at end of file
diff --git a/MindChem/applications/unimol/unimol_tools/weights/weighthub.py b/MindChem/applications/unimol/unimol_tools/weights/weighthub.py
new file mode 100644
index 0000000000000000000000000000000000000000..473d24ba0725c6729e75456bfbed15922095e047
--- /dev/null
+++ b/MindChem/applications/unimol/unimol_tools/weights/weighthub.py
@@ -0,0 +1,107 @@
+import os
+from ..utils import logger
+
# Optional dependency: openmind_hub is only needed when weights must actually
# be downloaded.  Fall back to a stub that fails loudly on use.
try:
    from openmind_hub import snapshot_download

    openmind_hub_installed = True
except ImportError:  # only a missing package triggers the stub, not other errors
    openmind_hub_installed = False

    def snapshot_download(*args, **kwargs):
        """Stub used when openmind_hub is unavailable; raises on any call."""
        raise ImportError(
            'openmind_hub is not installed. If weights are not available, please install it by running: pip install openmind_hub. Otherwise, please download the weights manually from https://modelers.cn/models/Weiland/Uni-Molv1/tree/master'
        )

# Directory that holds (or will receive) the pretrained weights; can be
# overridden through the UNIMOL_WEIGHT_DIR environment variable.
WEIGHT_DIR = os.environ.get(
    'UNIMOL_WEIGHT_DIR', os.path.dirname(os.path.abspath(__file__))
)
+
+# os.environ["HF_ENDPOINT"] = "https://hf-mirror.com" # use mirror to download weights
+
+
def log_weights_dir():
    """Report which weight directory is in effect and how it was chosen."""
    if 'UNIMOL_WEIGHT_DIR' in os.environ:
        logger.warning(
            f'Using custom weight directory from UNIMOL_WEIGHT_DIR: {WEIGHT_DIR}'
        )
        return
    logger.info(f'Weights will be downloaded to default directory: {WEIGHT_DIR}')
+
+
def weight_download(pretrain, save_path, local_dir_use_symlinks=True):
    """
    Fetch the named Uni-Mol v1 pretrained weight file into ``save_path``.

    The download is skipped when the target file already exists.

    :param pretrain: (str), The name of the pretrained model to download.
    :param save_path: (str), The directory where the weights should be saved.
    :param local_dir_use_symlinks: (bool, optional), Whether to use symlinks for the local directory. Defaults to True.
    """
    log_weights_dir()

    target = os.path.join(save_path, pretrain)
    if os.path.exists(target):
        logger.info(f'{pretrain} exists in {save_path}')
        return
    logger.info(f'Downloading {pretrain}')

    snapshot_download(
        repo_id="Weiland/Uni-Molv1",
        repo_type="model",
        revision="master",
        local_dir=save_path,
        allow_patterns=pretrain,
        local_dir_use_symlinks=local_dir_use_symlinks,
    )
+
+
def weight_download_v2(pretrain, save_path, local_dir_use_symlinks=True):
    """
    Fetch the named Uni-Mol v2 pretrained weight file into ``save_path``.

    The download is skipped when the target file already exists.

    :param pretrain: (str), The name of the pretrained model to download.
    :param save_path: (str), The directory where the weights should be saved.
    :param local_dir_use_symlinks: (bool, optional), Whether to use symlinks for the local directory. Defaults to True.
    """
    log_weights_dir()

    already_present = os.path.exists(os.path.join(save_path, pretrain))
    if already_present:
        logger.info(f'{pretrain} exists in {save_path}')
        return

    logger.info(f'Downloading {pretrain}')
    download_kwargs = dict(
        repo_id="Weiland/Uni-Molv2",
        repo_type="model",
        revision="main",
        local_dir=save_path,
        allow_patterns=pretrain,
        local_dir_use_symlinks=local_dir_use_symlinks,
    )
    snapshot_download(**download_kwargs)
+
+
# Convenience entry point: fetch every Uni-Mol v1 weight file at once.
def download_all_weights(local_dir_use_symlinks=False):
    """
    Download all available pretrained model weights into WEIGHT_DIR.

    :param local_dir_use_symlinks: (bool, optional), Whether to use symlinks for the local directory. Defaults to False.
    """
    log_weights_dir()

    logger.info(f'Downloading all weights to {WEIGHT_DIR}')
    snapshot_download(
        repo_id="Weiland/Uni-Molv1",
        local_dir=WEIGHT_DIR,
        revision="master",
        repo_type="model",
        allow_patterns='*',
        local_dir_use_symlinks=local_dir_use_symlinks,
    )


if __name__ == '__main__':
    download_all_weights()
diff --git a/MindChem/applications/unimol/xyz2mol.py b/MindChem/applications/unimol/xyz2mol.py
new file mode 100644
index 0000000000000000000000000000000000000000..52ef8aa8ad8b7d6213074e117b79754914ed1924
--- /dev/null
+++ b/MindChem/applications/unimol/xyz2mol.py
@@ -0,0 +1,829 @@
+"""
+Module for generating rdkit molobj/smiles/molecular graph from free atoms
+
+Implementation by Jan H. Jensen, based on the paper
+
+ Yeonjoon Kim and Woo Youn Kim
+ "Universal Structure Conversion Method for Organic Molecules: From Atomic Connectivity
+ to Three-Dimensional Geometry"
+ Bull. Korean Chem. Soc. 2015, Vol. 36, 1769-1777
+ DOI: 10.1002/bkcs.10334
+
+"""
+
+import copy
+import itertools
+
+from rdkit.Chem import rdmolops
+from rdkit.Chem import rdchem
+try:
+ from rdkit.Chem import rdEHTTools #requires RDKit 2019.9.1 or later
+except ImportError:
+ rdEHTTools = None
+
+from collections import defaultdict
+
+import numpy as np
+import networkx as nx
+
+from rdkit import Chem
+from rdkit.Chem import AllChem, rdmolops
+import sys
+
+global __ATOM_LIST__
+__ATOM_LIST__ = \
+ ['h', 'he',
+ 'li', 'be', 'b', 'c', 'n', 'o', 'f', 'ne',
+ 'na', 'mg', 'al', 'si', 'p', 's', 'cl', 'ar',
+ 'k', 'ca', 'sc', 'ti', 'v ', 'cr', 'mn', 'fe', 'co', 'ni', 'cu',
+ 'zn', 'ga', 'ge', 'as', 'se', 'br', 'kr',
+ 'rb', 'sr', 'y', 'zr', 'nb', 'mo', 'tc', 'ru', 'rh', 'pd', 'ag',
+ 'cd', 'in', 'sn', 'sb', 'te', 'i', 'xe',
+ 'cs', 'ba', 'la', 'ce', 'pr', 'nd', 'pm', 'sm', 'eu', 'gd', 'tb', 'dy',
+ 'ho', 'er', 'tm', 'yb', 'lu', 'hf', 'ta', 'w', 're', 'os', 'ir', 'pt',
+ 'au', 'hg', 'tl', 'pb', 'bi', 'po', 'at', 'rn',
+ 'fr', 'ra', 'ac', 'th', 'pa', 'u', 'np', 'pu']
+
+
# Allowed valences per atomic number, most common first; unknown elements
# fall back to an empty list via the defaultdict.
atomic_valence = defaultdict(list)
atomic_valence.update({
    1: [1],
    5: [3, 4],
    6: [4],
    7: [3, 4],
    8: [2, 1, 3],
    9: [1],
    14: [4],
    15: [5, 3],     # [5,4,3]
    16: [6, 3, 2],  # [6,4,2]
    17: [1],
    32: [4],
    35: [1],
    53: [1],
})

# Number of valence electrons per atomic number.
atomic_valence_electrons = {
    1: 1,
    5: 3,
    6: 4,
    7: 5,
    8: 6,
    9: 7,
    14: 4,
    15: 5,
    16: 6,
    17: 7,
    32: 4,
    35: 7,
    53: 7,
}
+
+
def str_atom(atom):
    """
    Convert an integer atomic number to its lower-case element symbol.
    """
    return __ATOM_LIST__[atom - 1]
+
+
def int_atom(atom):
    """
    Convert an element symbol string (any case) to its atomic number.
    """
    symbol = atom.lower()
    return __ATOM_LIST__.index(symbol) + 1
+
+
def get_UA(maxValence_list, valence_list):
    """
    Collect unsaturated atoms.

    Returns (UA, DU): the indices whose maximum valence exceeds the
    current valence, and the matching degrees of unsaturation.
    """
    UA = []
    DU = []
    for idx, (max_val, cur_val) in enumerate(zip(maxValence_list, valence_list)):
        slack = max_val - cur_val
        if slack > 0:
            UA.append(idx)
            DU.append(slack)
    return UA, DU
+
+
def get_BO(AC, UA, DU, valences, UA_pairs, use_graph=True):
    """
    Grow a bond-order matrix from the adjacency matrix.

    Bonds between paired unsaturated atoms are incremented repeatedly
    until the degrees of unsaturation stop changing (a fixed point).
    """
    BO = AC.copy()
    prev_DU = []

    while prev_DU != DU:
        for a, b in UA_pairs:
            BO[a, b] += 1
            BO[b, a] += 1

        prev_DU = copy.copy(DU)
        UA, DU = get_UA(valences, list(BO.sum(axis=1)))
        UA_pairs = get_UA_pairs(UA, AC, use_graph=use_graph)[0]

    return BO
+
+
def valences_not_too_large(BO, valences):
    """
    Check that no atom's total bond order exceeds its allowed valence.
    """
    bond_counts = BO.sum(axis=1)
    return all(
        n_bonds <= valence for valence, n_bonds in zip(valences, bond_counts)
    )
+
def charge_is_OK(BO, AC, charge, DU, atomic_valence_electrons, atoms, valences,
                 allow_charged_fragments=True):
    """
    Check that the formal charges implied by the bond-order matrix sum to
    the requested total molecular charge.

    :param BO: bond order matrix
    :param AC: adjacency matrix (unused; kept for a uniform signature)
    :param charge: target total molecular charge
    :param DU: degrees of unsaturation (unused; kept for a uniform signature)
    :param atomic_valence_electrons: mapping atomic number -> valence electrons
    :param atoms: list of atomic numbers
    :param valences: valence assignment under test (unused here)
    :param allow_charged_fragments: when False no per-atom charges are
        accumulated, so only a total charge of 0 can match
    :return: True if the implied total charge equals ``charge``

    Note: a dead ``q_list`` accumulator (built but never read) was removed,
    along with the ``q`` reassignments that only fed it.
    """
    # total charge
    Q = 0

    if allow_charged_fragments:

        BO_valences = list(BO.sum(axis=1))
        for i, atom in enumerate(atoms):
            q = get_atomic_charge(atom, atomic_valence_electrons[atom], BO_valences[i])
            Q += q
            if atom == 6:
                # Special-case carbons: carbene-like (2 single bonds,
                # valence 2) and certain tri-coordinated carbons.
                number_of_single_bonds_to_C = list(BO[i, :]).count(1)
                if number_of_single_bonds_to_C == 2 and BO_valences[i] == 2:
                    Q += 1
                if number_of_single_bonds_to_C == 3 and Q + 1 < charge:
                    Q += 2

    return (charge == Q)
+
def BO_is_OK(BO, AC, charge, DU, atomic_valence_electrons, atoms, valences,
             allow_charged_fragments=True):
    """
    Sanity check of a candidate bond-order matrix.

    A matrix is accepted when no valence is exceeded, the extra bond
    orders account exactly for the degrees of unsaturation, and the
    implied formal charges sum to the target molecular charge.

    args:
        BO - bond order matrix
        AC - adjacency matrix
        charge - target total charge
        DU - degrees of unsaturation

    optional:
        allow_charged_fragments - allow charged fragments

    returns:
        boolean - True if the molecule is OK, False if not
    """
    if not valences_not_too_large(BO, valences):
        return False

    sum_ok = (BO - AC).sum() == sum(DU)
    charge_ok = charge_is_OK(BO, AC, charge, DU, atomic_valence_electrons,
                             atoms, valences, allow_charged_fragments)

    return bool(sum_ok and charge_ok)
+
+
def get_atomic_charge(atom, atomic_valence_electrons, BO_valence):
    """
    Formal charge of an atom given its valence-electron count and its
    total bond order.

    H and B use their own formulas; pentavalent P and hexavalent S are
    treated as neutral hypervalent centres; everything else follows the
    octet rule.
    """
    if atom == 1:
        return 1 - BO_valence
    if atom == 5:
        return 3 - BO_valence
    if (atom == 15 and BO_valence == 5) or (atom == 16 and BO_valence == 6):
        return 0
    return atomic_valence_electrons - 8 + BO_valence
+
+
def clean_charges(mol):
    """
    Rewrite two charge-separated ring patterns into neutral resonance forms.

    This hack should not be needed anymore, but is kept just in case.

    NOTE(review): the molecule is rebuilt fragment-by-fragment with
    CombineMols, so fragment/atom ordering may change relative to the input.
    """

    Chem.SanitizeMol(mol)
    #rxn_smarts = ['[N+:1]=[*:2]-[C-:3]>>[N+0:1]-[*:2]=[C-0:3]',
    #              '[N+:1]=[*:2]-[O-:3]>>[N+0:1]-[*:2]=[O-0:3]',
    #              '[N+:1]=[*:2]-[*:3]=[*:4]-[O-:5]>>[N+0:1]-[*:2]=[*:3]-[*:4]=[O-0:5]',
    #              '[#8:1]=[#6:2]([!-:6])[*:3]=[*:4][#6-:5]>>[*-:1][*:2]([*:6])=[*:3][*:4]=[*+0:5]',
    #              '[O:1]=[c:2][c-:3]>>[*-:1][*:2][*+0:3]',
    #              '[O:1]=[C:2][C-:3]>>[*-:1][*:2]=[*+0:3]']

    # Each SMARTS shifts a negative charge out of a cross-conjugated ring so
    # that the ring itself becomes neutral.
    rxn_smarts = ['[#6,#7:1]1=[#6,#7:2][#6,#7:3]=[#6,#7:4][CX3-,NX3-:5][#6,#7:6]1=[#6,#7:7]>>'
                  '[#6,#7:1]1=[#6,#7:2][#6,#7:3]=[#6,#7:4][-0,-0:5]=[#6,#7:6]1[#6-,#7-:7]',
                  '[#6,#7:1]1=[#6,#7:2][#6,#7:3](=[#6,#7:4])[#6,#7:5]=[#6,#7:6][CX3-,NX3-:7]1>>'
                  '[#6,#7:1]1=[#6,#7:2][#6,#7:3]([#6-,#7-:4])=[#6,#7:5][#6,#7:6]=[-0,-0:7]1']

    fragments = Chem.GetMolFrags(mol,asMols=True,sanitizeFrags=False)

    for i, fragment in enumerate(fragments):
        for smarts in rxn_smarts:
            patt = Chem.MolFromSmarts(smarts.split(">>")[0])
            # Apply the transform until the reactant pattern is gone; only the
            # first product of the first product set is kept each round.
            while fragment.HasSubstructMatch(patt):
                rxn = AllChem.ReactionFromSmarts(smarts)
                ps = rxn.RunReactants((fragment,))
                fragment = ps[0][0]
                Chem.SanitizeMol(fragment)
        # Reassemble the (possibly rewritten) fragments into one molecule.
        if i == 0:
            mol = fragment
        else:
            mol = Chem.CombineMols(mol, fragment)

    return mol
+
+
def BO2mol(mol, BO_matrix, atoms, atomic_valence_electrons,
           mol_charge, allow_charged_fragments=True, use_atom_maps=False):
    """
    based on code written by Paolo Toscani

    From bond order, atoms, valence structure and total charge, generate an
    rdkit molecule.

    args:
        mol - rdkit molecule (bond-less template; atom order matches `atoms`)
        BO_matrix - bond order matrix of molecule
        atoms - list of integer atomic symbols
        atomic_valence_electrons - mapping atomic number -> valence electrons
        mol_charge - total charge of molecule

    optional:
        allow_charged_fragments - bool - assign formal charges when True,
            radical electrons otherwise

    returns:
        mol - updated rdkit molecule with bond connectivity

    raises:
        RuntimeError - when the BO matrix size and atom count disagree
    """

    l = len(BO_matrix)
    l2 = len(atoms)
    BO_valences = list(BO_matrix.sum(axis=1))

    if (l != l2):
        raise RuntimeError('sizes of adjMat ({0:d}) and Atoms {1:d} differ'.format(l, l2))

    rwMol = Chem.RWMol(mol)

    # Bond orders above 3 fall back to a single bond (dict .get default).
    bondTypeDict = {
        1: Chem.BondType.SINGLE,
        2: Chem.BondType.DOUBLE,
        3: Chem.BondType.TRIPLE
    }

    # Add one bond per upper-triangle entry with a non-zero (rounded) order.
    for i in range(l):
        for j in range(i + 1, l):
            bo = int(round(BO_matrix[i, j]))
            if (bo == 0):
                continue
            bt = bondTypeDict.get(bo, Chem.BondType.SINGLE)
            rwMol.AddBond(i, j, bt)

    mol = rwMol.GetMol()

    # Express leftover valence either as formal charges or as radicals.
    if allow_charged_fragments:
        mol = set_atomic_charges(
            mol,
            atoms,
            atomic_valence_electrons,
            BO_valences,
            BO_matrix,
            mol_charge,
            use_atom_maps)
    else:
        mol = set_atomic_radicals(mol, atoms, atomic_valence_electrons, BO_valences,
                                  use_atom_maps)

    return mol
+
+
def set_atomic_charges(mol, atoms, atomic_valence_electrons,
                       BO_valences, BO_matrix, mol_charge,
                       use_atom_maps):
    """
    Set formal charges on the rdkit atoms from the bond-order valences.

    Mirrors the bookkeeping in charge_is_OK: carbene-like carbons (two
    single bonds, valence 2) and certain tri-coordinated carbons are
    special-cased.  Returns the same mol with charges (and optionally
    1-based atom map numbers) assigned.
    """
    q = 0  # running total of assigned charge
    for i, atom in enumerate(atoms):
        a = mol.GetAtomWithIdx(i)
        if use_atom_maps:
            # 1-based atom maps preserve the input atom ordering.
            a.SetAtomMapNum(i+1)
        charge = get_atomic_charge(atom, atomic_valence_electrons[atom], BO_valences[i])
        q += charge
        if atom == 6:
            number_of_single_bonds_to_C = list(BO_matrix[i, :]).count(1)
            if number_of_single_bonds_to_C == 2 and BO_valences[i] == 2:
                # Carbene-like carbon: counted as +1 overall, atom left neutral.
                q += 1
                charge = 0
            if number_of_single_bonds_to_C == 3 and q + 1 < mol_charge:
                q += 2
                charge = 1

        if (abs(charge) > 0):
            a.SetFormalCharge(int(charge))

    #mol = clean_charges(mol)

    return mol
+
+
def set_atomic_radicals(mol, atoms, atomic_valence_electrons, BO_valences,
                        use_atom_maps):
    """
    Assign radical electrons instead of formal charges.

    The number of radical electrons equals the absolute formal charge the
    atom would otherwise carry.
    """
    for idx, atom in enumerate(atoms):
        rd_atom = mol.GetAtomWithIdx(idx)
        if use_atom_maps:
            rd_atom.SetAtomMapNum(idx + 1)
        charge = get_atomic_charge(atom,
                                   atomic_valence_electrons[atom],
                                   BO_valences[idx])
        if abs(charge) > 0:
            rd_atom.SetNumRadicalElectrons(abs(int(charge)))

    return mol
+
+
def get_bonds(UA, AC):
    """
    List the bonds (as sorted index tuples) between unsaturated atoms
    that are adjacent in AC.
    """
    return [
        tuple(sorted((i, j)))
        for k, i in enumerate(UA)
        for j in UA[k + 1:]
        if AC[i, j] == 1
    ]
+
+
def get_UA_pairs(UA, AC, use_graph=True):
    """
    Find a maximum matching of bonds between unsaturated atoms.

    With networkx a max-weight matching is returned; otherwise all bond
    combinations are enumerated and those covering the most atoms kept.
    """
    bonds = get_bonds(UA, AC)
    if not bonds:
        return [()]

    if use_graph:
        G = nx.Graph()
        G.add_edges_from(bonds)
        return [list(nx.max_weight_matching(G))]

    best_cover = 0
    UA_pairs = [()]
    for combo in itertools.combinations(bonds, int(len(UA) / 2)):
        atoms_in_combo = len({atom for pair in combo for atom in pair})
        if atoms_in_combo > best_cover:
            best_cover = atoms_in_combo
            UA_pairs = [combo]
        elif atoms_in_combo == best_cover:
            UA_pairs.append(combo)

    return UA_pairs
+
+
def AC2BO(AC, atoms, charge, allow_charged_fragments=True, use_graph=True):
    """
    implemenation of algorithm shown in Figure 2 (Kim & Kim 2015)

    UA: unsaturated atoms

    DU: degree of unsaturation (u matrix in Figure)

    best_BO: Bcurr in Figure

    Searches all consistent valence assignments for a bond-order matrix
    satisfying the valences and the total charge; returns the best
    candidate found (the plain AC matrix when nothing better turns up).
    """

    global atomic_valence
    global atomic_valence_electrons

    # Fallback result, initialized BEFORE the try block.  Previously it was
    # assigned inside `try`, so an exception raised early (e.g. max() on an
    # empty valence list for an element missing from atomic_valence) led to
    # an UnboundLocalError at the final `return best_BO`.
    best_BO = AC.copy()

    # make a list of valences, e.g. for CO: [[4],[2,1]]
    valences_list_of_lists = []
    AC_valence = list(AC.sum(axis=1))

    try:
        for i,(atomicNum,valence) in enumerate(zip(atoms,AC_valence)):
            # valence can't be smaller than number of neighbourgs
            possible_valence = [x for x in atomic_valence[atomicNum] if x >= valence]
            if not possible_valence:
                # NOTE(review): upstream aborts here (sys.exit commented out);
                # an empty list makes itertools.product below yield nothing,
                # so the unmodified AC is returned as best_BO.
                print('Valence of atom',i,'is',valence,'which bigger than allowed max',max(atomic_valence[atomicNum]),'. Stopping')
                # sys.exit()
            valences_list_of_lists.append(possible_valence)

        # convert [[4],[2,1]] to [[4,2],[4,1]]
        valences_list = itertools.product(*valences_list_of_lists)

        for valences in valences_list:

            UA, DU_from_AC = get_UA(valences, AC_valence)

            # With no unsaturated atoms the AC matrix itself may already be OK.
            check_len = (len(UA) == 0)
            if check_len:
                check_bo = BO_is_OK(AC, AC, charge, DU_from_AC,
                                    atomic_valence_electrons, atoms, valences,
                                    allow_charged_fragments=allow_charged_fragments)
            else:
                check_bo = None

            if check_len and check_bo:
                return AC, atomic_valence_electrons

            UA_pairs_list = get_UA_pairs(UA, AC, use_graph=use_graph)
            for UA_pairs in UA_pairs_list:
                BO = get_BO(AC, UA, DU_from_AC, valences, UA_pairs, use_graph=use_graph)
                status = BO_is_OK(BO, AC, charge, DU_from_AC,
                                  atomic_valence_electrons, atoms, valences,
                                  allow_charged_fragments=allow_charged_fragments)
                charge_OK = charge_is_OK(BO, AC, charge, DU_from_AC,
                                         atomic_valence_electrons, atoms, valences,
                                         allow_charged_fragments=allow_charged_fragments)

                if status:
                    # Fully consistent assignment found - done.
                    return BO, atomic_valence_electrons
                elif BO.sum() >= best_BO.sum() and valences_not_too_large(BO, valences) and charge_OK:
                    # Remember the most-bonded, charge-consistent candidate.
                    best_BO = BO.copy()

    except Exception as e:
        # NOTE(review): broad catch deliberately degrades to the best
        # candidate found so far instead of aborting; consider logging.
        print(f"\n{str(e)}")

    return best_BO, atomic_valence_electrons
+
+
def AC2mol(mol, AC, atoms, charge, allow_charged_fragments=True,
           use_graph=True, use_atom_maps=False):
    """
    Turn an adjacency matrix into rdkit molecules.

    Returns a list of resonance structures, or an empty list when the
    realized formal charge does not match the requested one.
    """
    # Convert AC matrix to bond order (BO) matrix.
    BO, atomic_valence_electrons = AC2BO(
        AC, atoms, charge,
        allow_charged_fragments=allow_charged_fragments,
        use_graph=use_graph,
    )

    # Add BO connectivity and charge info to the mol object.
    mol = BO2mol(
        mol, BO, atoms, atomic_valence_electrons, charge,
        allow_charged_fragments=allow_charged_fragments,
        use_atom_maps=use_atom_maps,
    )

    # Reject the structure if the charge could not be realized.
    if Chem.GetFormalCharge(mol) != charge:
        return []

    # BO2mol returns an arbitrary resonance form; enumerate the rest.
    supplier = rdchem.ResonanceMolSupplier(
        mol, Chem.UNCONSTRAINED_CATIONS, Chem.UNCONSTRAINED_ANIONS
    )
    return [resonance_mol for resonance_mol in supplier]
+
+
def get_proto_mol(atoms):
    """
    Build a bond-less rdkit molecule holding the given atomic numbers,
    in order.
    """
    seed = Chem.MolFromSmarts("[#" + str(atoms[0]) + "]")
    rw = Chem.RWMol(seed)
    for atomic_num in atoms[1:]:
        rw.AddAtom(Chem.Atom(atomic_num))
    return rw.GetMol()
+
+
def read_xyz_file(filename, look_for_charge=True):
    """
    Parse an XYZ file.

    :param filename: path to the .xyz file
    :param look_for_charge: when True, a ``charge=N`` token in the title
        line sets the molecular charge (fix: this flag was previously
        accepted but ignored)
    :return: (atoms, charge, xyz_coordinates) where atoms are atomic
        numbers and xyz_coordinates are [x, y, z] float triples
    """

    atomic_symbols = []
    xyz_coordinates = []
    charge = 0

    with open(filename, "r") as file:
        for line_number, line in enumerate(file):
            if line_number == 0:
                # First line: atom count (parsed for validation, not used further).
                num_atoms = int(line)
            elif line_number == 1:
                # Second line: free-form title, optionally carrying "charge=N".
                if look_for_charge and "charge=" in line:
                    charge = int(line.split("=")[1])
            else:
                atomic_symbol, x, y, z = line.split()
                atomic_symbols.append(atomic_symbol)
                xyz_coordinates.append([float(x), float(y), float(z)])

    atoms = [int_atom(atom) for atom in atomic_symbols]

    return atoms, charge, xyz_coordinates
+
+
def xyz2AC(atoms, xyz, charge, use_huckel=False):
    """
    atoms and coordinates to atom connectivity (AC)

    args:
        atoms - int atom types
        xyz - coordinates
        charge - molecule charge

    optional:
        use_huckel - use the extended Huckel method instead of
            van der Waals radii for connectivity

    returns:
        ac - atom connectivity matrix
        mol - rdkit molecule
    """
    if use_huckel:
        return xyz2AC_huckel(atoms, xyz, charge)
    return xyz2AC_vdW(atoms, xyz)
+
+
def xyz2AC_vdW(atoms, xyz):
    """
    Connectivity from scaled covalent-radius distance cutoffs.

    Returns (AC, mol): the adjacency matrix and the bond-less template
    molecule carrying the coordinates as a conformer.
    """
    # Bond-less template with the right atoms.
    template = get_proto_mol(atoms)

    # Attach the Cartesian coordinates as a conformer.
    n_atoms = template.GetNumAtoms()
    conf = Chem.Conformer(n_atoms)
    for idx in range(n_atoms):
        conf.SetAtomPosition(idx, (xyz[idx][0], xyz[idx][1], xyz[idx][2]))
    template.AddConformer(conf)

    return get_AC(template), template
+
+
def get_AC(mol, covalent_factor=1.3):
    """
    Generate the adjacency matrix from a 3D conformer.

    AC is a (num_atoms, num_atoms) matrix with 1 marking a covalent bond
    and 0 no bond; two atoms are bonded when their distance is at most the
    sum of their scaled covalent radii.

    args:
        mol - rdkit molobj with 3D conformer

    optional:
        covalent_factor - scale factor on the covalent radii (1.3 is an
            arbitrary, empirically chosen value)

    returns:
        AC - adjacent matrix
    """

    # Calculate distance matrix
    dMat = Chem.Get3DDistanceMatrix(mol)

    pt = Chem.GetPeriodicTable()
    num_atoms = mol.GetNumAtoms()

    # Precompute each atom's scaled covalent radius once; previously the
    # j-atom radius was recomputed inside the pair loop (O(n^2) lookups).
    r_cov = [
        pt.GetRcovalent(mol.GetAtomWithIdx(i).GetAtomicNum()) * covalent_factor
        for i in range(num_atoms)
    ]

    AC = np.zeros((num_atoms, num_atoms), dtype=int)
    for i in range(num_atoms):
        for j in range(i + 1, num_atoms):
            if dMat[i, j] <= r_cov[i] + r_cov[j]:
                AC[i, j] = 1
                AC[j, i] = 1

    return AC
+
+
def xyz2AC_huckel(atomicNumList, xyz, charge):
    """
    Connectivity from extended Huckel theory bond populations.

    args:
        atomicNumList - atom type list
        xyz - coordinates
        charge - molecule charge

    returns:
        ac - atom connectivity
        mol - rdkit molecule

    raises:
        ImportError - when rdkit.Chem.rdEHTTools is unavailable
    """
    # Fail with a clear message instead of an AttributeError on None when
    # the guarded rdEHTTools import at the top of the file failed.
    if rdEHTTools is None:
        raise ImportError(
            "xyz2AC_huckel requires rdkit.Chem.rdEHTTools (RDKit 2019.9.1 or later)"
        )

    mol = get_proto_mol(atomicNumList)

    conf = Chem.Conformer(mol.GetNumAtoms())
    for i in range(mol.GetNumAtoms()):
        conf.SetAtomPosition(i,(xyz[i][0],xyz[i][1],xyz[i][2]))
    mol.AddConformer(conf)

    num_atoms = len(atomicNumList)
    AC = np.zeros((num_atoms,num_atoms)).astype(int)

    mol_huckel = Chem.Mol(mol)
    mol_huckel.GetAtomWithIdx(0).SetFormalCharge(charge) #mol charge arbitrarily added to 1st atom

    passed,result = rdEHTTools.RunMol(mol_huckel)
    opop = result.GetReducedOverlapPopulationMatrix()
    tri = np.zeros((num_atoms, num_atoms))
    tri[np.tril(np.ones((num_atoms, num_atoms), dtype=bool))] = opop #lower triangular to square matrix
    for i in range(num_atoms):
        for j in range(i+1,num_atoms):
            pair_pop = abs(tri[j,i])
            if pair_pop >= 0.15: #arbitry cutoff for bond. May need adjustment
                AC[i,j] = 1
                AC[j,i] = 1

    return AC, mol
+
+
def chiral_stereo_check(mol):
    """
    Find and embed chiral information into the molecule based on the 3D
    coordinates of its (already attached) conformer.

    The mol is modified in place; nothing is returned.

    args:
        mol - rdkit molecule, with embeded conformer
    """
    # Sanitize first so ring/aromaticity info is available, then derive bond
    # stereo and atom chirality from the conformer (-1 = last conformer).
    Chem.SanitizeMol(mol)
    Chem.DetectBondStereochemistry(mol, -1)
    Chem.AssignStereochemistry(mol, flagPossibleStereoCenters=True, force=True)
    Chem.AssignAtomChiralTagsFromStructure(mol, -1)

    return
+
+
def xyz2mol(atoms, coordinates, charge=0, allow_charged_fragments=True,
            use_graph=True, use_huckel=False, embed_chiral=True,
            use_atom_maps=False):
    """
    Generate rdkit molobjs from atoms, coordinates and a total charge.

    args:
        atoms - list of atom types (int)
        coordinates - 3xN Cartesian coordinates
        charge - total charge of the system (default: 0)

    optional:
        allow_charged_fragments - alternatively radicals are made
        use_graph - use graph (networkx)
        use_huckel - use Huckel method for atom connectivity prediction
        embed_chiral - embed chiral information into the molecules

    returns:
        mols - list of rdkit molobjects
    """
    # Connectivity matrix plus a bond-less template molecule.
    AC, template = xyz2AC(atoms, coordinates, charge, use_huckel=use_huckel)

    # Promote connectivity to bond orders / charges; may yield several
    # resonance structures (or [] when the charge cannot be realized).
    candidate_mols = AC2mol(
        template, AC, atoms, charge,
        allow_charged_fragments=allow_charged_fragments,
        use_graph=use_graph,
        use_atom_maps=use_atom_maps,
    )

    # Derive stereocenters and chiral tags from the 3D coordinates.
    if embed_chiral:
        for candidate in candidate_mols:
            chiral_stereo_check(candidate)

    return candidate_mols
+
+
def main():
    """Placeholder entry point; the real CLI lives in the __main__ block."""
    return None
+
+
# Command-line interface: read an .xyz file and print the molecule(s) as an
# SDF block or as canonical SMILES.
if __name__ == "__main__":

    import argparse

    parser = argparse.ArgumentParser(usage='%(prog)s [options] molecule.xyz')
    parser.add_argument('structure', metavar='structure', type=str)
    # NOTE(review): the -s/--sdf flag is parsed but never consulted below;
    # the output format is controlled solely by -o/--output-format.
    parser.add_argument('-s', '--sdf',
                        action="store_true",
                        help="Dump sdf file")
    parser.add_argument('--ignore-chiral',
                        action="store_true",
                        help="Ignore chiral centers")
    parser.add_argument('--no-charged-fragments',
                        action="store_true",
                        help="Allow radicals to be made")
    parser.add_argument('--no-graph',
                        action="store_true",
                        help="Run xyz2mol without networkx dependencies")

    # huckel uses extended Huckel bond orders to locate bonds (requires RDKit 2019.9.1 or later)
    # otherwise van der Waals radii are used
    parser.add_argument('--use-huckel',
                        action="store_true",
                        help="Use Huckel method for atom connectivity")
    parser.add_argument('-o', '--output-format',
                        action="store",
                        type=str,
                        help="Output format [smiles,sdf] (default=sdf)")
    parser.add_argument('-c', '--charge',
                        action="store",
                        metavar="int",
                        type=int,
                        help="Total charge of the system")

    args = parser.parse_args()

    # read xyz file
    filename = args.structure

    # allow for charged fragments, alternatively radicals are made
    charged_fragments = not args.no_charged_fragments

    # quick is faster for large systems but requires networkx
    # if you don't want to install networkx set quick=False and
    # uncomment 'import networkx as nx' at the top of the file
    quick = not args.no_graph

    # chiral comment
    embed_chiral = not args.ignore_chiral

    # read atoms and coordinates. Try to find the charge
    atoms, charge, xyz_coordinates = read_xyz_file(filename)

    # huckel uses extended Huckel bond orders to locate bonds (requires RDKit 2019.9.1 or later)
    # otherwise van der Waals radii are used
    use_huckel = args.use_huckel

    # if explicit charge from args, set it
    if args.charge is not None:
        charge = int(args.charge)

    # Get the molobjs
    mols = xyz2mol(atoms, xyz_coordinates,
                   charge=charge,
                   use_graph=quick,
                   allow_charged_fragments=charged_fragments,
                   embed_chiral=embed_chiral,
                   use_huckel=use_huckel)

    # Print output
    for mol in mols:
        if args.output_format == "sdf":
            txt = Chem.MolToMolBlock(mol)
            print(txt)

        else:
            # Canonical hack: round-trip through SMILES so the printed
            # string is RDKit-canonical.
            isomeric_smiles = not args.ignore_chiral
            smiles = Chem.MolToSmiles(mol, isomericSmiles=isomeric_smiles)
            m = Chem.MolFromSmiles(smiles)
            smiles = Chem.MolToSmiles(m, isomericSmiles=isomeric_smiles)
            print(smiles)